jdavis commited on
Commit
beec38c
·
verified ·
1 Parent(s): 8f4c53d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +333 -1
app.py CHANGED
@@ -1,3 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  @st.cache_resource(show_spinner=False)
2
  def load_model():
3
  """Load the model with a simplified approach using the required token"""
@@ -33,4 +118,251 @@ def load_model():
33
 
34
  Note: You can find your token at https://huggingface.co/settings/tokens
35
  """)
36
- st.stop()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ # Set critical environment variables first
5
+ os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
6
+ os.environ["WATCHDOG_OPTIONAL"] = "1"
7
+ os.environ["PYTORCH_JIT"] = "0"
8
+
9
+ # Import third party modules
10
+ import streamlit as st
11
+ import numpy as np
12
+ import random
13
+ from PIL import Image
14
+ import io
15
+ import time
16
+
17
+ # Set up imports for huggingface_hub
18
+ # Import what we can, but handle potential import errors
19
+ try:
20
+ from huggingface_hub import HfApi, HfFolder, login
21
+ except ImportError as e:
22
+ st.error(f"Error importing from huggingface_hub: {e}")
23
+
24
+ # Configure Hugging Face cache and environment
25
+ os.environ["HF_HOME"] = os.path.join(os.getcwd(), ".cache/huggingface")
26
+
27
+ # Import PyTorch after environment setup
28
+ import torch
29
+ from diffusers import FluxFillPipeline
30
+
31
+ # Constants
32
+ MAX_SEED = np.iinfo(np.int32).max
33
+ MAX_IMAGE_SIZE = 2048
34
+
35
+ # Setting page config
36
+ st.set_page_config(
37
+ page_title="FLUX.1 Fill [dev]",
38
+ layout="wide"
39
+ )
40
+
41
+ # Title and description
42
+ st.markdown("""
43
+ # FLUX.1 Fill [dev]
44
+ 12B param rectified flow transformer structural conditioning tuned, guidance-distilled from [FLUX.1 [pro]](https://blackforestlabs.ai/)
45
+ [[non-commercial license](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)] [[blog](https://blackforestlabs.ai/announcing-black-forest-labs/)] [[model](https://huggingface.co/black-forest-labs/FLUX.1-dev)]
46
+ """)
47
+
48
+ # Add simple instructions
49
+ st.sidebar.markdown("""
50
+ ## Important Setup Information
51
+
52
+ This app uses the FLUX.1-Fill-dev model which requires special access:
53
+
54
+ 1. Sign up/login at [Hugging Face](https://huggingface.co/)
55
+ 2. Request access to [FLUX.1-Fill-dev](https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev) by clicking 'Access repository'
56
+ 3. Wait for approval from model owners
57
+
58
+ ### For Hugging Face Spaces Setup:
59
+ 1. Go to your Space settings > Secrets
60
+ 2. Add a new secret with the name `HF_TOKEN`
61
+ 3. Set its value to your Hugging Face API token (found in your account settings)
62
+ """)
63
+
64
+ # Try to get a Hugging Face token from environment variables
65
+ def get_hf_token():
66
+ # Check common environment variable names for HF tokens
67
+ token_env_vars = [
68
+ 'HF_TOKEN',
69
+ 'HUGGINGFACE_TOKEN',
70
+ 'HUGGING_FACE_HUB_TOKEN',
71
+ 'HF_API_TOKEN',
72
+ 'HUGGINGFACE_API_TOKEN',
73
+ 'HUGGINGFACE_HUB_TOKEN'
74
+ ]
75
+
76
+ for env_var in token_env_vars:
77
+ if env_var in os.environ and os.environ[env_var].strip():
78
+ st.sidebar.success(f"Found token in {env_var}")
79
+ return os.environ[env_var].strip()
80
+
81
+ # If we're here, no token was found
82
+ st.sidebar.warning("No Hugging Face token found in environment variables")
83
+
84
+ return None
85
+
86
  @st.cache_resource(show_spinner=False)
87
  def load_model():
88
  """Load the model with a simplified approach using the required token"""
 
118
 
119
  Note: You can find your token at https://huggingface.co/settings/tokens
120
  """)
121
+ st.stop()
122
+
123
+ except Exception as e:
124
+ st.error(f"Failed to load model after all attempts: {e}")
125
+
126
+ if "401" in str(e) or "access" in str(e).lower() or "denied" in str(e).lower():
127
+ st.error("""
128
+ Access Denied: You need to:
129
+ 1. Request access to the model at https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev
130
+ 2. Set up your Hugging Face token in Spaces:
131
+ - Go to your Space settings > Secrets
132
+ - Add a new secret with name 'HF_TOKEN'
133
+ - Set its value to your Hugging Face API token
134
+ 3. Wait for approval from model owners
135
+
136
+ Note: You can find your token at https://huggingface.co/settings/tokens
137
+ """)
138
+ elif "Tried to instantiate class" in str(e):
139
+ st.error("""
140
+ PyTorch class initialization error. Try restarting the app.
141
+ If the error persists, try accessing the app from a different browser.
142
+ """)
143
+ st.stop()
144
+
145
+ # Initialize model section
146
+ with st.spinner("Loading model..."):
147
+ try:
148
+ pipe = load_model()
149
+ st.success("Model loaded successfully!")
150
+ except Exception as e:
151
+ st.error(f"Failed to load model: {str(e)}")
152
+ st.stop()
153
+
154
+ def calculate_optimal_dimensions(image: Image.Image):
155
+ # Extract the original dimensions
156
+ original_width, original_height = image.size
157
+
158
+ # Set constants
159
+ MIN_ASPECT_RATIO = 9 / 16
160
+ MAX_ASPECT_RATIO = 16 / 9
161
+ FIXED_DIMENSION = 1024
162
+
163
+ # Calculate the aspect ratio of the original image
164
+ original_aspect_ratio = original_width / original_height
165
+
166
+ # Determine which dimension to fix
167
+ if original_aspect_ratio > 1: # Wider than tall
168
+ width = FIXED_DIMENSION
169
+ height = round(FIXED_DIMENSION / original_aspect_ratio)
170
+ else: # Taller than wide
171
+ height = FIXED_DIMENSION
172
+ width = round(FIXED_DIMENSION * original_aspect_ratio)
173
+
174
+ # Ensure dimensions are multiples of 8
175
+ width = (width // 8) * 8
176
+ height = (height // 8) * 8
177
+
178
+ # Enforce aspect ratio limits
179
+ calculated_aspect_ratio = width / height
180
+ if calculated_aspect_ratio > MAX_ASPECT_RATIO:
181
+ width = (height * MAX_ASPECT_RATIO // 8) * 8
182
+ elif calculated_aspect_ratio < MIN_ASPECT_RATIO:
183
+ height = (width / MIN_ASPECT_RATIO // 8) * 8
184
+
185
+ # Ensure width and height remain above the minimum dimensions
186
+ width = max(width, 576) if width == FIXED_DIMENSION else width
187
+ height = max(height, 576) if height == FIXED_DIMENSION else height
188
+
189
+ return width, height
190
+
191
+ # Create two columns for layout
192
+ col1, col2 = st.columns([1, 1])
193
+
194
+ with col1:
195
+ # Upload image
196
+ uploaded_file = st.file_uploader("Upload an image for inpainting", type=["jpg", "jpeg", "png"])
197
+
198
+ if uploaded_file is not None:
199
+ # Display the uploaded image
200
+ image = Image.open(uploaded_file).convert("RGB")
201
+ st.image(image, caption="Uploaded Image", use_container_width=True)
202
+
203
+ # Simple approach to create a mask - select a square area
204
+ st.write("Select an area to inpaint:")
205
+
206
+ # Get image dimensions
207
+ img_width, img_height = image.size
208
+
209
+ # Scale for display while maintaining aspect ratio
210
+ display_height = 600
211
+ display_width = int(img_width * (display_height / img_height))
212
+
213
+ # Create sliders for selecting the area
214
+ col_sliders1, col_sliders2 = st.columns(2)
215
+
216
+ with col_sliders1:
217
+ x1 = st.slider("Left edge (X1)", 0, img_width, img_width // 4)
218
+ y1 = st.slider("Top edge (Y1)", 0, img_height, img_height // 4)
219
+
220
+ with col_sliders2:
221
+ x2 = st.slider("Right edge (X2)", x1, img_width, min(x1 + img_width // 2, img_width))
222
+ y2 = st.slider("Bottom edge (Y2)", y1, img_height, min(y1 + img_height // 2, img_height))
223
+
224
+ # Create a copy of the image to show the mask
225
+ preview_img = image.copy()
226
+ preview_mask = Image.new("L", image.size, 0)
227
+
228
+ # Draw a white rectangle on the mask
229
+ from PIL import ImageDraw
230
+ draw = ImageDraw.Draw(preview_mask)
231
+ draw.rectangle([(x1, y1), (x2, y2)], fill=255)
232
+
233
+ # Show the mask on the image
234
+ masked_preview = image.copy()
235
+ # Add semi-transparent white overlay
236
+ overlay = Image.new("RGBA", image.size, (255, 255, 255, 128))
237
+ masked_preview.paste(overlay, (0, 0), preview_mask)
238
+
239
+ st.image(masked_preview, caption="Area to inpaint (white overlay)", use_container_width=True)
240
+
241
+ # Prompt input
242
+ prompt = st.text_input("Enter your prompt")
243
+
244
+ # Example prompts
245
+ examples = [
246
+ "a tiny astronaut hatching from an egg on the moon",
247
+ "a cat holding a sign that says hello world",
248
+ "an anime illustration of a wiener schnitzel",
249
+ ]
250
+
251
+ example_prompt = st.selectbox("Or select an example prompt", [""] + examples)
252
+ if example_prompt and not prompt:
253
+ prompt = example_prompt
254
+
255
+ # Advanced settings with expander
256
+ with st.expander("Advanced Settings"):
257
+ randomize_seed = st.checkbox("Randomize seed", value=True)
258
+
259
+ if not randomize_seed:
260
+ seed = st.slider("Seed", 0, MAX_SEED, 0)
261
+ else:
262
+ seed = random.randint(0, MAX_SEED)
263
+
264
+ guidance_scale = st.slider("Guidance Scale", 1.0, 30.0, 3.5, 0.5)
265
+ num_inference_steps = st.slider("Number of inference steps", 1, 50, 28)
266
+
267
+ # Run button
268
+ run_button = st.button("Generate")
269
+
270
+ with col2:
271
+ if uploaded_file is not None:
272
+ st.write("Result will appear here")
273
+
274
+ if run_button and prompt:
275
+ with st.spinner("Generating image..."):
276
+ # Create mask from rectangle coordinates
277
+ mask = Image.new("L", image.size, 0)
278
+ draw = ImageDraw.Draw(mask)
279
+ draw.rectangle([(x1, y1), (x2, y2)], fill=255)
280
+
281
+ # Calculate dimensions for generation
282
+ width, height = calculate_optimal_dimensions(image)
283
+
284
+ # Progress bar
285
+ progress_bar = st.progress(0)
286
+
287
+ # Generate the image
288
+ try:
289
+ # Set up progress bar updates
290
+ progress_text = st.empty()
291
+ debug_info = st.empty()
292
+
293
+ # Show parameters for debugging
294
+ debug_info.info(f"Model type: {pipe.__class__.__name__}")
295
+
296
+ # Update progress
297
+ progress_bar.progress(0.1)
298
+ progress_text.text("Preparing image and mask...")
299
+
300
+ # Make sure mask is in the right format
301
+ # Some models require masks where white (255) is the area to inpaint
302
+ mask_img = mask.convert("L")
303
+
304
+ # Prepare arguments - different models may have different parameter names
305
+ model_class_name = pipe.__class__.__name__
306
+
307
+ # Common parameters for all models
308
+ common_params = {
309
+ "prompt": prompt,
310
+ "image": image,
311
+ "mask_image": mask_img,
312
+ "num_inference_steps": num_inference_steps,
313
+ "generator": torch.Generator("cpu").manual_seed(seed)
314
+ }
315
+
316
+ # Add parameters for Flux model
317
+ common_params["guidance_scale"] = guidance_scale
318
+
319
+ # Try running generation with dimensions
320
+ try:
321
+ progress_text.text("Running generation...")
322
+ progress_bar.progress(0.2)
323
+
324
+ # First try with dimensions
325
+ common_params["height"] = int(height)
326
+ common_params["width"] = int(width)
327
+ result = pipe(**common_params)
328
+ except Exception as e:
329
+ debug_info.warning(f"First attempt failed: {str(e)}")
330
+ progress_text.text("Retrying with adjusted parameters...")
331
+
332
+ # Remove dimensions and try again
333
+ del common_params["height"]
334
+ del common_params["width"]
335
+ result = pipe(**common_params)
336
+
337
+ # Get the result image
338
+ result_image = result.images[0]
339
+
340
+ # Update final progress
341
+ progress_bar.progress(1.0)
342
+ progress_text.text("Complete!")
343
+ debug_info.empty() # Clear debug info
344
+
345
+ # Display the result
346
+ st.image(result_image, caption="Generated Result", use_column_width=True)
347
+
348
+ # Add download button
349
+ buf = io.BytesIO()
350
+ result_image.save(buf, format="PNG")
351
+ st.download_button(
352
+ label="Download result",
353
+ data=buf.getvalue(),
354
+ file_name="inpaint_result.png",
355
+ mime="image/png",
356
+ )
357
+
358
+ # Display used seed
359
+ st.write(f"Seed used: {seed}")
360
+
361
+ except Exception as e:
362
+ st.error(f"An error occurred during generation: {str(e)}")
363
+ st.error("Try adjusting the parameters or using a different image.")
364
+
365
+ # If no image is uploaded
366
+ if uploaded_file is None:
367
+ with col2:
368
+ st.write("Please upload an image first")