Spaces:
Running
Running
| import os | |
| import sys | |
| # Set critical environment variables first | |
| os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" | |
| os.environ["WATCHDOG_OPTIONAL"] = "1" | |
| os.environ["PYTORCH_JIT"] = "0" | |
| # Import third party modules | |
| import streamlit as st | |
| import numpy as np | |
| import random | |
| from PIL import Image | |
| import io | |
| import time | |
| # Set up imports for huggingface_hub | |
| # Import what we can, but handle potential import errors | |
| try: | |
| from huggingface_hub import HfApi, HfFolder, login | |
| except ImportError as e: | |
| st.error(f"Error importing from huggingface_hub: {e}") | |
| # Configure Hugging Face cache and environment | |
| os.environ["HF_HOME"] = os.path.join(os.getcwd(), ".cache/huggingface") | |
| # Import PyTorch after environment setup | |
| import torch | |
| from diffusers import FluxFillPipeline | |
| import warnings | |
| warnings.filterwarnings("ignore", message=".*add_prefix_space.*") | |
# ---------------------------------------------------------------------------
# Module-wide constants
# ---------------------------------------------------------------------------
# Upper bound for the seed slider and for randomly drawn seeds.
MAX_SEED = np.iinfo(np.int32).max
# Not referenced elsewhere in the visible code.
MAX_IMAGE_SIZE = 2048

# Browser-tab title and full-width layout for the whole page.
st.set_page_config(layout="wide", page_title="FLUX.1 Fill [dev]")

# Markdown rendered as the main page header.
_HEADER_MARKDOWN = """
# FLUX.1 Fill [dev]
12B param rectified flow transformer structural conditioning tuned, guidance-distilled from [FLUX.1 [pro]](https://blackforestlabs.ai/)
[[non-commercial license](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)] [[blog](https://blackforestlabs.ai/announcing-black-forest-labs/)] [[model](https://huggingface.co/black-forest-labs/FLUX.1-dev)]
"""

# Markdown rendered in the sidebar: how to obtain access to the gated model.
_SETUP_MARKDOWN = """
## Important Setup Information
This app uses the FLUX.1-Fill-dev model which requires special access:
1. Sign up/login at [Hugging Face](https://huggingface.co/)
2. Request access to [FLUX.1-Fill-dev](https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev) by clicking 'Access repository'
3. Wait for approval from model owners
### For Hugging Face Spaces Setup:
1. Go to your Space settings > Secrets
2. Add a new secret with the name `HF_TOKEN`
3. Set its value to your Hugging Face API token (found in your account settings)
"""

st.markdown(_HEADER_MARKDOWN)
st.sidebar.markdown(_SETUP_MARKDOWN)
def get_hf_token():
    """Return the first non-blank Hugging Face token found in the environment.

    Candidate variable names are checked in a fixed priority order. The
    sidebar reports which variable supplied the token, or shows a warning
    when none of them is set to a non-blank value.

    Returns:
        The token with surrounding whitespace stripped, or ``None``.
    """
    candidate_names = (
        'HF_TOKEN',
        'HUGGINGFACE_TOKEN',
        'HUGGING_FACE_HUB_TOKEN',
        'HF_API_TOKEN',
        'HUGGINGFACE_API_TOKEN',
        'HUGGINGFACE_HUB_TOKEN',
    )
    for name in candidate_names:
        raw_value = os.environ.get(name, "")
        if not raw_value.strip():
            # Unset or whitespace-only: try the next candidate.
            continue
        st.sidebar.success(f"Found token in {name}")
        return raw_value.strip()

    st.sidebar.warning("No Hugging Face token found in environment variables")
    return None
# Help text shown when the Hub rejects the download (missing token / gated access).
_ACCESS_DENIED_HELP = """
Access Denied: You need to:
1. Request access to the model at https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev
2. Set up your Hugging Face token in Spaces:
- Go to your Space settings > Secrets
- Add a new secret with name 'HF_TOKEN'
- Set its value to your Hugging Face API token
3. Wait for approval from model owners
Note: You can find your token at https://huggingface.co/settings/tokens
"""

# Help text shown when PyTorch reports a class-instantiation failure.
_TORCH_INIT_HELP = """
PyTorch class initialization error. Try restarting the app.
If the error persists, try accessing the app from a different browser.
"""


def _explain_load_failure(exc):
    """Show the load error plus a remediation hint when the message matches a known cause."""
    message = str(exc)
    lowered = message.lower()
    st.error(f"Failed to load model: {exc}")
    access_problem = any(code in message for code in ("401", "403")) or any(
        word in lowered for word in ("access", "denied", "gated")
    )
    if access_problem:
        st.error(_ACCESS_DENIED_HELP)
    elif "Tried to instantiate class" in message:
        st.error(_TORCH_INIT_HELP)


@st.cache_resource(show_spinner=False)
def load_model():
    """Load the FLUX.1-Fill-dev inpainting pipeline and move it to the best device.

    Cached with ``st.cache_resource`` so the large download and device transfer
    happen once per server process rather than on every Streamlit rerun (every
    widget interaction reruns the whole script). The caller shows its own
    spinner, hence ``show_spinner=False``.

    Returns:
        The ``FluxFillPipeline`` in bfloat16, on CUDA when available, else CPU.

    On failure the error (and a hint for access or PyTorch-init problems) is
    shown in the UI and the script run is halted with ``st.stop()``; nothing
    is returned or cached in that case.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    st.info(f"Using device: {device}")

    token = get_hf_token()
    st.info(f"Token available: {'Yes' if token else 'No'}")

    try:
        model = FluxFillPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-Fill-dev",
            token=token,
            torch_dtype=torch.bfloat16,
            revision="main",
        )
        # Inside the try: moving to the device can also fail (e.g. out of memory).
        model = model.to(device)
    except Exception as e:
        _explain_load_failure(e)
        st.stop()

    st.success("Model loaded successfully!")
    return model
# Load (or fetch the already-cached) pipeline before drawing any UI that needs it.
# load_model() shows its own success message, so none is repeated here.
with st.spinner("Loading model..."):
    try:
        pipe = load_model()
    except Exception as e:
        # Fallback: load_model() already reports download failures and stops
        # the run; this catches anything else it might raise.
        st.error(f"Failed to load model: {str(e)}")
        st.stop()
def calculate_optimal_dimensions(image: "Image.Image"):
    """Pick a generation size for *image*: longer side 1024 px, aspect ratio kept.

    The shorter side is scaled to preserve the original aspect ratio, floored
    to a multiple of 8, and then raised to at least 576 px. The output aspect
    ratio therefore never exceeds 16:9 (or 9:16). Square inputs map to
    1024x1024.

    Args:
        image: Object with a ``size`` attribute of ``(width, height)``,
            e.g. a PIL image.

    Returns:
        ``(width, height)`` as ints.

    Raises:
        ValueError: If either input dimension is not positive.
    """
    import math

    original_width, original_height = image.size
    if original_width <= 0 or original_height <= 0:
        raise ValueError(f"Image dimensions must be positive, got {image.size}")

    FIXED_DIMENSION = 1024  # length of the longer output side
    MULTIPLE = 8  # both output sides are multiples of this
    MIN_ASPECT_RATIO = 9 / 16  # most extreme short/long ratio allowed (16:9 or 9:16)

    # Smallest allowed short side (576 here), rounded up to stay a multiple of 8.
    min_short_side = math.ceil(FIXED_DIMENSION * MIN_ASPECT_RATIO / MULTIPLE) * MULTIPLE

    original_aspect_ratio = original_width / original_height
    if original_aspect_ratio > 1:  # wider than tall: width is the fixed side
        short_side = round(FIXED_DIMENSION / original_aspect_ratio)
    else:  # taller than wide, or square: height is the fixed side
        short_side = round(FIXED_DIMENSION * original_aspect_ratio)

    # Align to the multiple, then enforce the aspect limit by raising the short
    # side. Shrinking the fixed side instead would produce undersized outputs
    # (e.g. 592x336 for a 3:1 image).
    short_side = max((short_side // MULTIPLE) * MULTIPLE, min_short_side)

    if original_aspect_ratio > 1:
        return FIXED_DIMENSION, short_side
    return short_side, FIXED_DIMENSION
# Two equal-width columns: inputs on the left, generated result on the right.
col1, col2 = st.columns([1, 1])

# Imported at module level (not inside the upload branch) because the result
# column below also builds its generation mask with ImageDraw.
from PIL import ImageDraw

with col1:
    uploaded_file = st.file_uploader("Upload an image for inpainting", type=["jpg", "jpeg", "png"])
    if uploaded_file is not None:
        # Normalise to RGB so the preview and the pipeline see one pixel format.
        image = Image.open(uploaded_file).convert("RGB")
        st.image(image, caption="Uploaded Image", use_container_width=True)

        # The inpainting region is an axis-aligned rectangle chosen with sliders.
        st.write("Select an area to inpaint:")
        img_width, img_height = image.size

        col_sliders1, col_sliders2 = st.columns(2)
        with col_sliders1:
            # Stop one pixel short of the far edge: X2/Y2 use x1/y1 as their
            # min_value, and st.slider rejects min_value == max_value.
            x1 = st.slider("Left edge (X1)", 0, img_width - 1, img_width // 4)
            y1 = st.slider("Top edge (Y1)", 0, img_height - 1, img_height // 4)
        with col_sliders2:
            x2 = st.slider("Right edge (X2)", x1, img_width, min(x1 + img_width // 2, img_width))
            y2 = st.slider("Bottom edge (Y2)", y1, img_height, min(y1 + img_height // 2, img_height))

        # Preview: blend the selected rectangle 50/50 with white. The paste-mask
        # value (128) sets the blend strength; a 255 mask would paint it solid white.
        overlay_alpha = Image.new("L", image.size, 0)
        ImageDraw.Draw(overlay_alpha).rectangle([(x1, y1), (x2, y2)], fill=128)
        masked_preview = image.copy()
        masked_preview.paste((255, 255, 255), (0, 0), overlay_alpha)
        st.image(masked_preview, caption="Area to inpaint (white overlay)", use_container_width=True)

        # A typed prompt takes precedence over the example picker.
        prompt = st.text_input("Enter your prompt")
        examples = [
            "a tiny astronaut hatching from an egg on the moon",
            "a cat holding a sign that says hello world",
            "an anime illustration of a wiener schnitzel",
        ]
        example_prompt = st.selectbox("Or select an example prompt", [""] + examples)
        if example_prompt and not prompt:
            prompt = example_prompt

        with st.expander("Advanced Settings"):
            randomize_seed = st.checkbox("Randomize seed", value=True)
            if not randomize_seed:
                seed = st.slider("Seed", 0, MAX_SEED, 0)
            else:
                # A fresh seed is drawn on every script rerun while this is checked.
                seed = random.randint(0, MAX_SEED)
            guidance_scale = st.slider("Guidance Scale", 1.0, 30.0, 3.5, 0.5)
            num_inference_steps = st.slider("Number of inference steps", 1, 50, 28)

        run_button = st.button("Generate")
with col2:
    if uploaded_file is not None:
        st.write("Result will appear here")
        if run_button and not prompt:
            st.warning("Please enter a prompt (or pick an example) before generating.")
        if run_button and prompt:
            with st.spinner("Generating image..."):
                # Generation mask: the selected rectangle is 255, everything else 0.
                mask = Image.new("L", image.size, 0)
                ImageDraw.Draw(mask).rectangle([(x1, y1), (x2, y2)], fill=255)

                width, height = calculate_optimal_dimensions(image)
                progress_bar = st.progress(0)
                try:
                    progress_text = st.empty()
                    debug_info = st.empty()
                    debug_info.info(f"Model type: {pipe.__class__.__name__}")

                    progress_bar.progress(0.1)
                    progress_text.text("Preparing image and mask...")
                    common_params = {
                        "prompt": prompt,
                        "image": image,
                        "mask_image": mask,
                        "num_inference_steps": num_inference_steps,
                        "guidance_scale": guidance_scale,
                        "generator": torch.Generator("cpu").manual_seed(seed),
                    }

                    try:
                        progress_text.text("Running generation...")
                        progress_bar.progress(0.2)
                        # int() keeps this robust if the dimension helper returns floats.
                        result = pipe(**common_params, height=int(height), width=int(width))
                    except Exception as e:
                        debug_info.warning(f"First attempt failed: {str(e)}")
                        progress_text.text("Retrying with adjusted parameters...")
                        # The failed attempt may already have drawn from the
                        # generator; reseed so the retry matches the displayed seed.
                        common_params["generator"] = torch.Generator("cpu").manual_seed(seed)
                        # Retry without explicit dimensions (pipeline defaults).
                        result = pipe(**common_params)

                    result_image = result.images[0]

                    progress_bar.progress(1.0)
                    progress_text.text("Complete!")
                    debug_info.empty()

                    st.image(result_image, caption="Generated Result", use_container_width=True)

                    buf = io.BytesIO()
                    result_image.save(buf, format="PNG")
                    st.download_button(
                        label="Download result",
                        data=buf.getvalue(),
                        file_name="inpaint_result.png",
                        mime="image/png",
                    )
                    st.write(f"Seed used: {seed}")
                except Exception as e:
                    st.error(f"An error occurred during generation: {str(e)}")
                    st.error("Try adjusting the parameters or using a different image.")
# Until an image is uploaded, the result column only shows a prompt to upload one.
if uploaded_file is None:
    col2.write("Please upload an image first")