jdavis commited on
Commit
8f4c53d
·
verified ·
1 Parent(s): 9b81d07

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -371
app.py CHANGED
@@ -1,94 +1,6 @@
1
- import os
2
- import sys
3
-
4
- # Set critical environment variables first
5
- os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
6
- os.environ["WATCHDOG_OPTIONAL"] = "1"
7
- os.environ["PYTORCH_JIT"] = "0"
8
-
9
- # Import third party modules
10
- import streamlit as st
11
- import numpy as np
12
- import random
13
- from PIL import Image
14
- import io
15
- import time
16
-
17
- # Set up imports for huggingface_hub
18
- # Import what we can, but handle potential import errors
19
- try:
20
- from huggingface_hub import HfApi, HfFolder, login
21
- except ImportError as e:
22
- st.error(f"Error importing from huggingface_hub: {e}")
23
-
24
- # Configure Hugging Face cache and environment
25
- os.environ["HF_HOME"] = os.path.join(os.getcwd(), ".cache/huggingface")
26
-
27
- # Import PyTorch after environment setup
28
- import torch
29
- from diffusers import FluxFillPipeline
30
-
31
- # Constants
32
- MAX_SEED = np.iinfo(np.int32).max
33
- MAX_IMAGE_SIZE = 2048
34
-
35
- # Setting page config
36
- st.set_page_config(
37
- page_title="FLUX.1 Fill [dev]",
38
- layout="wide"
39
- )
40
-
41
- # Title and description
42
- st.markdown("""
43
- # FLUX.1 Fill [dev]
44
- 12B param rectified flow transformer structural conditioning tuned, guidance-distilled from [FLUX.1 [pro]](https://blackforestlabs.ai/)
45
- [[non-commercial license](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)] [[blog](https://blackforestlabs.ai/announcing-black-forest-labs/)] [[model](https://huggingface.co/black-forest-labs/FLUX.1-dev)]
46
- """)
47
-
48
- # Add simple instructions
49
- st.sidebar.markdown("""
50
- ## Important Setup Information
51
-
52
- This app uses the FLUX.1-Fill-dev model which requires special access:
53
-
54
- 1. Sign up/login at [Hugging Face](https://huggingface.co/)
55
- 2. Request access to [FLUX.1-Fill-dev](https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev) by clicking 'Access repository'
56
- 3. Wait for approval from model owners
57
-
58
- ### For Hugging Face Spaces Setup:
59
- 1. Go to your Space settings > Secrets
60
- 2. Add a new secret with the name `HF_TOKEN`
61
- 3. Set its value to your Hugging Face API token (found in your account settings)
62
- """)
63
-
64
- # Try to get a Hugging Face token from environment variables
65
- def get_hf_token():
66
- # Check common environment variable names for HF tokens
67
- token_env_vars = [
68
- 'HF_TOKEN',
69
- 'HUGGINGFACE_TOKEN',
70
- 'HUGGING_FACE_HUB_TOKEN',
71
- 'HF_API_TOKEN',
72
- 'HUGGINGFACE_API_TOKEN',
73
- 'HUGGINGFACE_HUB_TOKEN'
74
- ]
75
-
76
- for env_var in token_env_vars:
77
- if env_var in os.environ and os.environ[env_var].strip():
78
- st.sidebar.success(f"Found token in {env_var}")
79
- return os.environ[env_var].strip()
80
-
81
- # If we're here, no token was found
82
- st.sidebar.warning("No Hugging Face token found in environment variables")
83
-
84
- return None
85
-
86
  @st.cache_resource(show_spinner=False)
87
  def load_model():
88
- """Load the model with the simplest approach possible"""
89
- # Set up basic logging
90
- st.info("Preparing to load FLUX.1-Fill-dev model...")
91
-
92
  # Get device
93
  device = "cuda" if torch.cuda.is_available() else "cpu"
94
  st.info(f"Using device: {device}")
@@ -97,62 +9,17 @@ def load_model():
97
  token = get_hf_token()
98
  st.info(f"Token available: {'Yes' if token else 'No'}")
99
 
100
- # Set up progress indicator
101
- progress = st.empty()
102
-
103
- # Ignore transformers warnings
104
- import transformers
105
- transformers.logging.set_verbosity_error()
106
-
107
- # Create 4 attempts with different approaches
108
  try:
109
- # Attempt 1: Just the basics
110
- progress.info("Loading model (attempt 1/4): Basic parameters")
111
- try:
112
- model = FluxFillPipeline.from_pretrained(
113
- "black-forest-labs/FLUX.1-Fill-dev",
114
- token=token
115
- )
116
- st.success("Model loaded successfully with basic parameters!")
117
- return model.to(device)
118
- except Exception as e1:
119
- progress.warning(f"Basic loading failed: {e1}")
120
-
121
- # Attempt 2: With use_auth_token
122
- progress.info("Loading model (attempt 2/4): Using use_auth_token")
123
- try:
124
- model = FluxFillPipeline.from_pretrained(
125
- "black-forest-labs/FLUX.1-Fill-dev",
126
- use_auth_token=token
127
- )
128
- st.success("Model loaded successfully with use_auth_token!")
129
- return model.to(device)
130
- except Exception as e2:
131
- progress.warning(f"Loading with use_auth_token failed: {e2}")
132
-
133
- # Attempt 3: With float32 (more compatible)
134
- progress.info("Loading model (attempt 3/4): Using float32 dtype")
135
- try:
136
- model = FluxFillPipeline.from_pretrained(
137
- "black-forest-labs/FLUX.1-Fill-dev",
138
- token=token,
139
- torch_dtype=torch.float32
140
- )
141
- st.success("Model loaded successfully with float32!")
142
- return model.to(device)
143
- except Exception as e3:
144
- progress.warning(f"Loading with float32 failed: {e3}")
145
-
146
- # Attempt 4: Minimal parameters
147
- progress.info("Loading model (attempt 4/4): Minimal approach")
148
  model = FluxFillPipeline.from_pretrained(
149
- "black-forest-labs/FLUX.1-Fill-dev"
 
 
150
  )
151
- st.success("Model loaded successfully with minimal parameters!")
152
  return model.to(device)
153
-
154
  except Exception as e:
155
- st.error(f"Failed to load model after all attempts: {e}")
156
 
157
  if "401" in str(e) or "access" in str(e).lower() or "denied" in str(e).lower():
158
  st.error("""
@@ -166,234 +33,4 @@ def load_model():
166
 
167
  Note: You can find your token at https://huggingface.co/settings/tokens
168
  """)
169
- elif "Tried to instantiate class" in str(e):
170
- st.error("""
171
- PyTorch class initialization error. Try restarting the app.
172
- If the error persists, try accessing the app from a different browser.
173
- """)
174
- st.stop()
175
-
176
- # Initialize model section
177
- with st.spinner("Loading model..."):
178
- try:
179
- pipe = load_model()
180
- st.success("Model loaded successfully!")
181
- except Exception as e:
182
- st.error(f"Failed to load model: {str(e)}")
183
- st.stop()
184
-
185
- def calculate_optimal_dimensions(image: Image.Image):
186
- # Extract the original dimensions
187
- original_width, original_height = image.size
188
-
189
- # Set constants
190
- MIN_ASPECT_RATIO = 9 / 16
191
- MAX_ASPECT_RATIO = 16 / 9
192
- FIXED_DIMENSION = 1024
193
-
194
- # Calculate the aspect ratio of the original image
195
- original_aspect_ratio = original_width / original_height
196
-
197
- # Determine which dimension to fix
198
- if original_aspect_ratio > 1: # Wider than tall
199
- width = FIXED_DIMENSION
200
- height = round(FIXED_DIMENSION / original_aspect_ratio)
201
- else: # Taller than wide
202
- height = FIXED_DIMENSION
203
- width = round(FIXED_DIMENSION * original_aspect_ratio)
204
-
205
- # Ensure dimensions are multiples of 8
206
- width = (width // 8) * 8
207
- height = (height // 8) * 8
208
-
209
- # Enforce aspect ratio limits
210
- calculated_aspect_ratio = width / height
211
- if calculated_aspect_ratio > MAX_ASPECT_RATIO:
212
- width = (height * MAX_ASPECT_RATIO // 8) * 8
213
- elif calculated_aspect_ratio < MIN_ASPECT_RATIO:
214
- height = (width / MIN_ASPECT_RATIO // 8) * 8
215
-
216
- # Ensure width and height remain above the minimum dimensions
217
- width = max(width, 576) if width == FIXED_DIMENSION else width
218
- height = max(height, 576) if height == FIXED_DIMENSION else height
219
-
220
- return width, height
221
-
222
- # Create two columns for layout
223
- col1, col2 = st.columns([1, 1])
224
-
225
- with col1:
226
- # Upload image
227
- uploaded_file = st.file_uploader("Upload an image for inpainting", type=["jpg", "jpeg", "png"])
228
-
229
- if uploaded_file is not None:
230
- # Display the uploaded image
231
- image = Image.open(uploaded_file).convert("RGB")
232
- st.image(image, caption="Uploaded Image", use_container_width=True)
233
-
234
- # Simple approach to create a mask - select a square area
235
- st.write("Select an area to inpaint:")
236
-
237
- # Get image dimensions
238
- img_width, img_height = image.size
239
-
240
- # Scale for display while maintaining aspect ratio
241
- display_height = 600
242
- display_width = int(img_width * (display_height / img_height))
243
-
244
- # Create sliders for selecting the area
245
- col_sliders1, col_sliders2 = st.columns(2)
246
-
247
- with col_sliders1:
248
- x1 = st.slider("Left edge (X1)", 0, img_width, img_width // 4)
249
- y1 = st.slider("Top edge (Y1)", 0, img_height, img_height // 4)
250
-
251
- with col_sliders2:
252
- x2 = st.slider("Right edge (X2)", x1, img_width, min(x1 + img_width // 2, img_width))
253
- y2 = st.slider("Bottom edge (Y2)", y1, img_height, min(y1 + img_height // 2, img_height))
254
-
255
- # Create a copy of the image to show the mask
256
- preview_img = image.copy()
257
- preview_mask = Image.new("L", image.size, 0)
258
-
259
- # Draw a white rectangle on the mask
260
- from PIL import ImageDraw
261
- draw = ImageDraw.Draw(preview_mask)
262
- draw.rectangle([(x1, y1), (x2, y2)], fill=255)
263
-
264
- # Show the mask on the image
265
- masked_preview = image.copy()
266
- # Add semi-transparent white overlay
267
- overlay = Image.new("RGBA", image.size, (255, 255, 255, 128))
268
- masked_preview.paste(overlay, (0, 0), preview_mask)
269
-
270
- st.image(masked_preview, caption="Area to inpaint (white overlay)", use_container_width=True)
271
-
272
- # Prompt input
273
- prompt = st.text_input("Enter your prompt")
274
-
275
- # Example prompts
276
- examples = [
277
- "a tiny astronaut hatching from an egg on the moon",
278
- "a cat holding a sign that says hello world",
279
- "an anime illustration of a wiener schnitzel",
280
- ]
281
-
282
- example_prompt = st.selectbox("Or select an example prompt", [""] + examples)
283
- if example_prompt and not prompt:
284
- prompt = example_prompt
285
-
286
- # Advanced settings with expander
287
- with st.expander("Advanced Settings"):
288
- randomize_seed = st.checkbox("Randomize seed", value=True)
289
-
290
- if not randomize_seed:
291
- seed = st.slider("Seed", 0, MAX_SEED, 0)
292
- else:
293
- seed = random.randint(0, MAX_SEED)
294
-
295
- guidance_scale = st.slider("Guidance Scale", 1.0, 30.0, 3.5, 0.5)
296
- num_inference_steps = st.slider("Number of inference steps", 1, 50, 28)
297
-
298
- # Run button
299
- run_button = st.button("Generate")
300
-
301
- with col2:
302
- if uploaded_file is not None:
303
- st.write("Result will appear here")
304
-
305
- if run_button and prompt:
306
- with st.spinner("Generating image..."):
307
- # Create mask from rectangle coordinates
308
- mask = Image.new("L", image.size, 0)
309
- draw = ImageDraw.Draw(mask)
310
- draw.rectangle([(x1, y1), (x2, y2)], fill=255)
311
-
312
- # Calculate dimensions for generation
313
- width, height = calculate_optimal_dimensions(image)
314
-
315
- # Progress bar
316
- progress_bar = st.progress(0)
317
-
318
- # Generate the image
319
- try:
320
- # Set up progress bar updates
321
- progress_text = st.empty()
322
- debug_info = st.empty()
323
-
324
- # Show parameters for debugging
325
- debug_info.info(f"Model type: {pipe.__class__.__name__}")
326
-
327
- # Update progress
328
- progress_bar.progress(0.1)
329
- progress_text.text("Preparing image and mask...")
330
-
331
- # Make sure mask is in the right format
332
- # Some models require masks where white (255) is the area to inpaint
333
- mask_img = mask.convert("L")
334
-
335
- # Prepare arguments - different models may have different parameter names
336
- model_class_name = pipe.__class__.__name__
337
-
338
- # Common parameters for all models
339
- common_params = {
340
- "prompt": prompt,
341
- "image": image,
342
- "mask_image": mask_img,
343
- "num_inference_steps": num_inference_steps,
344
- "generator": torch.Generator("cpu").manual_seed(seed)
345
- }
346
-
347
- # Add parameters for Flux model
348
- common_params["guidance_scale"] = guidance_scale
349
-
350
- # Try running generation with dimensions
351
- try:
352
- progress_text.text("Running generation...")
353
- progress_bar.progress(0.2)
354
-
355
- # First try with dimensions
356
- common_params["height"] = int(height)
357
- common_params["width"] = int(width)
358
- result = pipe(**common_params)
359
- except Exception as e:
360
- debug_info.warning(f"First attempt failed: {str(e)}")
361
- progress_text.text("Retrying with adjusted parameters...")
362
-
363
- # Remove dimensions and try again
364
- del common_params["height"]
365
- del common_params["width"]
366
- result = pipe(**common_params)
367
-
368
- # Get the result image
369
- result_image = result.images[0]
370
-
371
- # Update final progress
372
- progress_bar.progress(1.0)
373
- progress_text.text("Complete!")
374
- debug_info.empty() # Clear debug info
375
-
376
- # Display the result
377
- st.image(result_image, caption="Generated Result", use_column_width=True)
378
-
379
- # Add download button
380
- buf = io.BytesIO()
381
- result_image.save(buf, format="PNG")
382
- st.download_button(
383
- label="Download result",
384
- data=buf.getvalue(),
385
- file_name="inpaint_result.png",
386
- mime="image/png",
387
- )
388
-
389
- # Display used seed
390
- st.write(f"Seed used: {seed}")
391
-
392
- except Exception as e:
393
- st.error(f"An error occurred during generation: {str(e)}")
394
- st.error("Try adjusting the parameters or using a different image.")
395
-
396
- # If no image is uploaded
397
- if uploaded_file is None:
398
- with col2:
399
- st.write("Please upload an image first")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  @st.cache_resource(show_spinner=False)
2
  def load_model():
3
+ """Load the model with a simplified approach using the required token"""
 
 
 
4
  # Get device
5
  device = "cuda" if torch.cuda.is_available() else "cpu"
6
  st.info(f"Using device: {device}")
 
9
  token = get_hf_token()
10
  st.info(f"Token available: {'Yes' if token else 'No'}")
11
 
 
 
 
 
 
 
 
 
12
  try:
13
+ # Use the same parameters as the Gradio version, just with token
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  model = FluxFillPipeline.from_pretrained(
15
+ "black-forest-labs/FLUX.1-Fill-dev",
16
+ token=token,
17
+ torch_dtype=torch.bfloat16
18
  )
19
+ st.success("Model loaded successfully!")
20
  return model.to(device)
 
21
  except Exception as e:
22
+ st.error(f"Failed to load model: {e}")
23
 
24
  if "401" in str(e) or "access" in str(e).lower() or "denied" in str(e).lower():
25
  st.error("""
 
33
 
34
  Note: You can find your token at https://huggingface.co/settings/tokens
35
  """)
36
+ st.stop()