File size: 14,355 Bytes
beec38c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
011b800
 
 
beec38c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6152816
f07e450
8f4c53d
9b81d07
e97b433
9b81d07
6152816
9b81d07
 
 
 
e97b433
011b800
9b81d07
8f4c53d
 
011b800
0e1643e
9b81d07
8f4c53d
9b81d07
e97b433
8f4c53d
9b81d07
 
e97b433
 
 
831ef02
 
 
 
e97b433
831ef02
 
e97b433
beec38c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
import os
import sys

# Set critical environment variables first, before the third-party imports
# below, so that libraries which read them when imported actually see them.
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
os.environ["WATCHDOG_OPTIONAL"] = "1"
os.environ["PYTORCH_JIT"] = "0"

# Configure the Hugging Face cache location. huggingface_hub derives its cache
# paths from HF_HOME when it is first imported (diffusers imports it as well),
# so this has to be set before either package is imported to take effect.
os.environ["HF_HOME"] = os.path.join(os.getcwd(), ".cache/huggingface")

# Import third party modules
import streamlit as st
import numpy as np
import random
from PIL import Image
import io
import time

# Import what we can from huggingface_hub. Any failure is reported only after
# st.set_page_config() below, because (at least in older Streamlit versions)
# that call must be the first Streamlit command executed by the script.
_hf_import_error = None
try:
    from huggingface_hub import HfApi, HfFolder, login
except ImportError as e:
    # `e` is unbound once the except clause ends, so keep our own reference.
    _hf_import_error = e

# Import PyTorch after environment setup
import torch
from diffusers import FluxFillPipeline

import warnings
warnings.filterwarnings("ignore", message=".*add_prefix_space.*")

# Constants
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048

# Setting page config
st.set_page_config(
    page_title="FLUX.1 Fill [dev]",
    layout="wide"
)

if _hf_import_error is not None:
    st.error(f"Error importing from huggingface_hub: {_hf_import_error}")

# Title and description shown at the top of the page. The model link points
# at the Fill variant, which is the checkpoint load_model() downloads.
st.markdown("""
# FLUX.1 Fill [dev]
12B param rectified flow transformer for filling (inpainting) regions of an existing image from a text description, guidance-distilled from [FLUX.1 [pro]](https://blackforestlabs.ai/)  
[[non-commercial license](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)] [[blog](https://blackforestlabs.ai/announcing-black-forest-labs/)] [[model](https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev)]
""")

# Sidebar: how to obtain gated-model access and supply a token.
st.sidebar.markdown("""
## Important Setup Information

This app uses the FLUX.1-Fill-dev model which requires special access:

1. Sign up/login at [Hugging Face](https://huggingface.co/)
2. Request access to [FLUX.1-Fill-dev](https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev) by clicking 'Access repository'
3. Wait for approval from model owners

### For Hugging Face Spaces Setup:
1. Go to your Space settings > Secrets
2. Add a new secret with the name `HF_TOKEN`
3. Set its value to your Hugging Face API token (found in your account settings)
""")

# Try to get a Hugging Face token from environment variables
def get_hf_token():
    """Return a Hugging Face access token read from the environment, if any.

    Several commonly used variable names are checked in a fixed order. The
    first one that holds a non-blank value wins. A sidebar message reports
    which variable supplied the token, or warns when none did.

    Returns:
        The whitespace-stripped token string, or ``None`` when no candidate
        variable is set to a non-blank value.
    """
    candidate_names = (
        'HF_TOKEN',
        'HUGGINGFACE_TOKEN',
        'HUGGING_FACE_HUB_TOKEN',
        'HF_API_TOKEN',
        'HUGGINGFACE_API_TOKEN',
        'HUGGINGFACE_HUB_TOKEN',
    )

    for name in candidate_names:
        value = os.environ.get(name, "").strip()
        if value:
            st.sidebar.success(f"Found token in {name}")
            return value

    # Fell through every candidate without finding a usable value.
    st.sidebar.warning("No Hugging Face token found in environment variables")
    return None

@st.cache_resource(show_spinner=False)
def load_model():
    """Load the FLUX.1-Fill-dev pipeline in bfloat16 and move it to a device.

    Decorated with ``st.cache_resource``, so Streamlit reuses the returned
    pipeline instead of reloading it on every script rerun. CUDA is used when
    ``torch.cuda.is_available()``; otherwise the CPU is used.

    On any failure, the error is shown in the page, followed by a hint when
    the message looks like an access/authorisation problem or a PyTorch
    class-instantiation problem. The script run is then halted with
    ``st.stop()``.

    Returns:
        The ``FluxFillPipeline`` instance, already on the selected device.
    """
    # Get device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    st.info(f"Using device: {device}")

    # Get token (None when not configured; public download will then fail for
    # this gated repository and be reported below).
    token = get_hf_token()
    st.info(f"Token available: {'Yes' if token else 'No'}")

    try:
        model = FluxFillPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-Fill-dev",
            token=token,
            torch_dtype=torch.bfloat16,
            revision="main"
        )
        # Keep the device move inside the try so that failures here get the
        # same reporting as download/authorisation failures.
        model = model.to(device)
    except Exception as e:
        st.error(f"Failed to load model: {e}")

        error_text = str(e)
        if "401" in error_text or "access" in error_text.lower() or "denied" in error_text.lower():
            st.error("""
            Access Denied: You need to:
            1. Request access to the model at https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev
            2. Set up your Hugging Face token in Spaces:
               - Go to your Space settings > Secrets
               - Add a new secret with name 'HF_TOKEN'
               - Set its value to your Hugging Face API token
            3. Wait for approval from model owners
            
            Note: You can find your token at https://huggingface.co/settings/tokens
            """)
        elif "Tried to instantiate class" in error_text:
            st.error("""
            PyTorch class initialization error. Try restarting the app.
            If the error persists, try accessing the app from a different browser.
            """)
        st.stop()

    # Only report success once the pipeline is actually on the target device.
    st.success("Model loaded successfully!")
    return model

# Initialize model section: obtain the (cached) pipeline before building the UI.
# A failure that escapes load_model() is reported and halts this script run.
with st.spinner("Loading model..."):
    try:
        pipe = load_model()
        st.success("Model loaded successfully!")
    except Exception as load_error:
        st.error(f"Failed to load model: {load_error!s}")
        st.stop()

def calculate_optimal_dimensions(
    image: "Image.Image",
    *,
    fixed_dimension: int = 1024,
    min_aspect_ratio: float = 9 / 16,
    max_aspect_ratio: float = 16 / 9,
    multiple: int = 8,
):
    """Pick a generation size that roughly preserves *image*'s aspect ratio.

    The image's width/height ratio is first clamped to
    ``[min_aspect_ratio, max_aspect_ratio]``. For wider-than-tall images the
    width is then set to ``fixed_dimension`` and the height derived from the
    ratio; otherwise the height is fixed and the width derived. Both sides
    are finally rounded down to a multiple of ``multiple``.

    With the defaults, every result has one side equal to 1024 and the other
    between 576 and 1024. Extreme aspect ratios therefore no longer shrink
    the long side or collapse the short side to 0.

    NOTE(review): diffusers' Flux pipelines expect sides divisible by 16 and
    resize otherwise; consider passing ``multiple=16``.

    Args:
        image: Any object whose ``size`` attribute is ``(width, height)``,
            such as a PIL image.
        fixed_dimension: Length of the longer output side.
        min_aspect_ratio: Smallest allowed width/height ratio.
        max_aspect_ratio: Largest allowed width/height ratio.
        multiple: Both output sides are rounded down to a multiple of this.

    Returns:
        ``(width, height)`` as ints.

    Raises:
        ZeroDivisionError: if the image height is 0.
    """
    original_width, original_height = image.size

    # Clamp the ratio *before* sizing, so the derived side already respects the
    # bounds and never needs the long side to be shrunk afterwards.
    aspect_ratio = original_width / original_height
    aspect_ratio = min(max(aspect_ratio, min_aspect_ratio), max_aspect_ratio)

    if aspect_ratio > 1:  # Wider than tall
        width = fixed_dimension
        height = round(fixed_dimension / aspect_ratio)
    else:  # Taller than wide (or square)
        height = fixed_dimension
        width = round(fixed_dimension * aspect_ratio)

    # Round down to the required multiple, but never to zero.
    width = max(multiple, (width // multiple) * multiple)
    height = max(multiple, (height // multiple) * multiple)

    return int(width), int(height)

# Create two columns for layout: inputs on the left, result on the right.
col1, col2 = st.columns([1, 1])

with col1:
    # Upload image
    uploaded_file = st.file_uploader("Upload an image for inpainting", type=["jpg", "jpeg", "png"])

    if uploaded_file is not None:
        # Display the uploaded image
        image = Image.open(uploaded_file).convert("RGB")
        st.image(image, caption="Uploaded Image", use_container_width=True)

        # Simple approach to create a mask - select a rectangular area
        st.write("Select an area to inpaint:")

        # Slider values are in original-image pixel coordinates.
        img_width, img_height = image.size

        # Create sliders for selecting the area
        col_sliders1, col_sliders2 = st.columns(2)

        with col_sliders1:
            x1 = st.slider("Left edge (X1)", 0, img_width, img_width // 4)
            y1 = st.slider("Top edge (Y1)", 0, img_height, img_height // 4)

        with col_sliders2:
            # The right/bottom sliders start at the left/top edge so the
            # rectangle can never be inverted.
            x2 = st.slider("Right edge (X2)", x1, img_width, min(x1 + img_width // 2, img_width))
            y2 = st.slider("Bottom edge (Y2)", y1, img_height, min(y1 + img_height // 2, img_height))

        # This import executes at module level, so ImageDraw is also available
        # to the generation code further down the script.
        from PIL import ImageDraw

        # Preview mask uses 128 (not 255) inside the selection: Image.paste
        # copies the source as-is where the mask is 255, so 128 is what makes
        # the white overlay blend at ~50% instead of hiding the area entirely.
        preview_mask = Image.new("L", image.size, 0)
        ImageDraw.Draw(preview_mask).rectangle([(x1, y1), (x2, y2)], fill=128)

        # Show the mask on the image as a semi-transparent white overlay
        masked_preview = image.copy()
        overlay = Image.new("RGB", image.size, (255, 255, 255))
        masked_preview.paste(overlay, (0, 0), preview_mask)

        st.image(masked_preview, caption="Area to inpaint (white overlay)", use_container_width=True)

        # Prompt input
        prompt = st.text_input("Enter your prompt")

        # Example prompts
        examples = [
            "a tiny astronaut hatching from an egg on the moon",
            "a cat holding a sign that says hello world",
            "an anime illustration of a wiener schnitzel",
        ]

        # A typed prompt takes precedence over a selected example.
        example_prompt = st.selectbox("Or select an example prompt", [""] + examples)
        if example_prompt and not prompt:
            prompt = example_prompt

        # Advanced settings with expander
        with st.expander("Advanced Settings"):
            randomize_seed = st.checkbox("Randomize seed", value=True)

            if not randomize_seed:
                seed = st.slider("Seed", 0, MAX_SEED, 0)
            else:
                # A fresh seed is drawn on every script rerun.
                seed = random.randint(0, MAX_SEED)

            guidance_scale = st.slider("Guidance Scale", 1.0, 30.0, 3.5, 0.5)
            num_inference_steps = st.slider("Number of inference steps", 1, 50, 28)

        # Run button
        run_button = st.button("Generate")

# Right-hand column: run the inpainting pipeline and show the result.
with col2:
    if uploaded_file is None:
        # If no image is uploaded there is nothing to inpaint yet.
        st.write("Please upload an image first")
    else:
        st.write("Result will appear here")

        if run_button and not prompt:
            # Previously a click without a prompt silently did nothing.
            st.warning("Please enter a prompt or select an example prompt.")
        elif run_button:
            with st.spinner("Generating image..."):
                # Generation mask: 255 (white) marks the rectangle to repaint.
                mask = Image.new("L", image.size, 0)
                ImageDraw.Draw(mask).rectangle([(x1, y1), (x2, y2)], fill=255)

                # Calculate dimensions for generation
                width, height = calculate_optimal_dimensions(image)

                # Progress bar
                progress_bar = st.progress(0)

                # Generate the image
                try:
                    progress_text = st.empty()
                    debug_info = st.empty()

                    # Show parameters for debugging
                    debug_info.info(f"Model type: {pipe.__class__.__name__}")

                    progress_bar.progress(0.1)
                    progress_text.text("Preparing image and mask...")

                    # Arguments shared by both attempts. The generator is
                    # deliberately NOT included: each attempt gets a freshly
                    # seeded one (see below).
                    pipe_params = {
                        "prompt": prompt,
                        "image": image,
                        "mask_image": mask,
                        "num_inference_steps": num_inference_steps,
                        "guidance_scale": guidance_scale,
                    }

                    try:
                        progress_text.text("Running generation...")
                        progress_bar.progress(0.2)

                        # First try with explicit dimensions
                        result = pipe(
                            **pipe_params,
                            height=int(height),
                            width=int(width),
                            generator=torch.Generator("cpu").manual_seed(seed),
                        )
                    except Exception as e:
                        debug_info.warning(f"First attempt failed: {str(e)}")
                        progress_text.text("Retrying with adjusted parameters...")

                        # Retry without height/width (pipeline defaults apply).
                        # Re-seed a new generator: the first attempt may have
                        # advanced the old one, which would make the output
                        # differ from what the displayed seed reproduces.
                        result = pipe(
                            **pipe_params,
                            generator=torch.Generator("cpu").manual_seed(seed),
                        )

                    # Get the result image
                    result_image = result.images[0]

                    # Update final progress
                    progress_bar.progress(1.0)
                    progress_text.text("Complete!")
                    debug_info.empty()  # Clear debug info

                    # Display the result (same width option as the other images)
                    st.image(result_image, caption="Generated Result", use_container_width=True)

                    # Add download button
                    buf = io.BytesIO()
                    result_image.save(buf, format="PNG")
                    st.download_button(
                        label="Download result",
                        data=buf.getvalue(),
                        file_name="inpaint_result.png",
                        mime="image/png",
                    )

                    # Display used seed
                    st.write(f"Seed used: {seed}")

                except Exception as e:
                    st.error(f"An error occurred during generation: {str(e)}")
                    st.error("Try adjusting the parameters or using a different image.")