jdavis committed on
Commit
e97b433
·
verified ·
1 Parent(s): af8041d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +142 -123
app.py CHANGED
@@ -2,46 +2,66 @@ import streamlit as st
2
  import numpy as np
3
  import torch
4
  import random
5
- import base64
6
- from io import BytesIO
7
  from diffusers import FluxFillPipeline
8
  from PIL import Image
9
- from streamlit_drawable_canvas import st_canvas
10
 
11
  # Constants
12
- MAX_SEED = np.int32().max
13
  MAX_IMAGE_SIZE = 2048
14
 
15
- # Function to convert PIL Image to data URL (to replace image_to_url)
16
- def image_to_base64_url(img):
17
- buffered = BytesIO()
18
- img.save(buffered, format="PNG")
19
- img_str = base64.b64encode(buffered.getvalue()).decode()
20
- return f"data:image/png;base64,{img_str}"
21
-
22
-
23
- # Authenticate with Hugging Face
24
- def authenticate():
25
- token = st.text_input("Enter your Hugging Face token:", type="password",
26
- help="Find your token at https://huggingface.co/settings/tokens")
27
- if token:
28
- try:
29
- login(token)
30
- st.success("Successfully authenticated with Hugging Face!")
31
- return True
32
- except Exception as e:
33
- st.error(f"Authentication failed: {e}")
34
- return False
35
- return False
36
-
37
- # Load model (wrapped in a function to load only when needed)
 
 
38
  @st.cache_resource
39
  def load_model():
40
- pipe = FluxFillPipeline.from_pretrained(
41
- "black-forest-labs/FLUX.1-Fill-dev",
42
- torch_dtype=torch.bfloat16
43
- ).to("cuda")
44
- return pipe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  def calculate_optimal_dimensions(image: Image.Image):
47
  # Extract the original dimensions
@@ -80,119 +100,118 @@ def calculate_optimal_dimensions(image: Image.Image):
80
 
81
  return width, height
82
 
83
- def infer(pipe, image, mask, prompt, seed, randomize_seed, guidance_scale, num_inference_steps):
84
- width, height = calculate_optimal_dimensions(image)
85
-
86
- if randomize_seed:
87
- seed = random.randint(0, MAX_SEED)
88
-
89
- with st.status("Generating image..."):
90
- result_image = pipe(
91
- prompt=prompt,
92
- image=image,
93
- mask_image=mask,
94
- height=height,
95
- width=width,
96
- guidance_scale=guidance_scale,
97
- num_inference_steps=num_inference_steps,
98
- generator=torch.Generator("cpu").manual_seed(seed)
99
- ).images[0]
100
-
101
- return result_image, seed
102
 
103
- def image_editor():
104
- uploaded_image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
 
105
 
106
- if uploaded_image is not None:
107
- image = Image.open(uploaded_image).convert("RGB")
108
-
109
- # Display image for reference
110
  st.image(image, caption="Uploaded Image", use_column_width=True)
111
 
112
- # Create a base64 URL for the background image
113
- background_image_url = image_to_base64_url(image)
114
-
115
- # Create the canvas for drawing masks
116
  canvas_result = st_canvas(
117
- fill_color="rgba(255, 255, 255, 0.3)",
118
  stroke_width=10,
119
- stroke_color="#FFFFFF",
120
- background_image=background_image_url,
 
 
121
  height=600,
122
  drawing_mode="freedraw",
123
  key="canvas",
124
  )
125
 
126
- return {
127
- "background": image,
128
- "mask": Image.fromarray(canvas_result.image_data[:, :, -1], mode="L") if canvas_result.image_data is not None else None
129
- }
130
-
131
- return None
132
-
133
- def main():
134
- st.set_page_config(page_title="FLUX.1 Fill", layout="wide")
135
-
136
- st.markdown("""
137
- # FLUX.1 Fill [dev]
138
- 12B param rectified flow transformer structural conditioning tuned, guidance-distilled from [FLUX.1 [pro]](https://blackforestlabs.ai/)
139
- [[non-commercial license](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)]
140
- [[blog](https://blackforestlabs.ai/announcing-black-forest-labs/)]
141
- [[model](https://huggingface.co/black-forest-labs/FLUX.1-dev)]
142
- """)
143
-
144
- # Initialize model
145
- pipe = load_model()
146
-
147
- col1, col2 = st.columns([6, 6])
148
-
149
- with col1:
150
- st.subheader("Input")
151
- image_data = image_editor()
152
-
153
- prompt = st.text_input("Prompt", placeholder="Enter your prompt")
154
 
 
155
  examples = [
156
  "a tiny astronaut hatching from an egg on the moon",
157
  "a cat holding a sign that says hello world",
158
  "an anime illustration of a wiener schnitzel",
159
  ]
160
 
161
- selected_example = st.selectbox("Or try an example prompt:", [""] + examples)
162
- if selected_example:
163
- prompt = selected_example
164
-
 
165
  with st.expander("Advanced Settings"):
166
- seed = st.slider("Seed", 0, MAX_SEED, 0)
167
  randomize_seed = st.checkbox("Randomize seed", value=True)
168
 
 
 
 
 
 
169
  guidance_scale = st.slider("Guidance Scale", 1.0, 30.0, 3.5, 0.5)
170
  num_inference_steps = st.slider("Number of inference steps", 1, 50, 28)
171
 
172
- run_button = st.button("Run")
173
-
174
- with col2:
175
- st.subheader("Result")
176
- result_container = st.empty()
177
- seed_text = st.empty()
178
-
179
- if run_button and image_data and prompt:
180
- if image_data["mask"] is not None:
181
- result_image, used_seed = infer(
182
- pipe,
183
- image_data["background"],
184
- image_data["mask"],
185
- prompt,
186
- seed,
187
- randomize_seed,
188
- guidance_scale,
189
- num_inference_steps
190
- )
191
-
192
- result_container.image(result_image, caption="Generated Image", use_column_width=True)
193
- seed_text.text(f"Seed used: {used_seed}")
194
- else:
195
- st.error("Please draw a mask on the image")
196
 
197
- if __name__ == "__main__":
198
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import numpy as np
3
  import torch
4
  import random
 
 
5
  from diffusers import FluxFillPipeline
6
  from PIL import Image
7
+ import io
8
 
9
  # Constants
10
+ MAX_SEED = np.iinfo(np.int32).max
11
  MAX_IMAGE_SIZE = 2048
12
 
13
+ # Setting page config
14
+ st.set_page_config(
15
+ page_title="FLUX.1 Fill [dev]",
16
+ layout="wide"
17
+ )
18
+
19
+ # Title and description
20
+ st.markdown("""
21
+ # FLUX.1 Fill [dev]
22
+ 12B param rectified flow transformer structural conditioning tuned, guidance-distilled from [FLUX.1 [pro]](https://blackforestlabs.ai/)
23
+ [[non-commercial license](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)] [[blog](https://blackforestlabs.ai/announcing-black-forest-labs/)] [[model](https://huggingface.co/black-forest-labs/FLUX.1-dev)]
24
+ """)
25
+
26
+ # Add simple instructions
27
+ st.sidebar.markdown("""
28
+ ## Important Setup Information
29
+
30
+ This app uses the FLUX.1-Fill-dev model which requires special access:
31
+
32
+ 1. Sign up/login at [Hugging Face](https://huggingface.co/)
33
+ 2. Run `huggingface-cli login` in your terminal (or add your token to Hugging Face Spaces secrets)
34
+ 3. Request access to [FLUX.1-Fill-dev](https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev) by clicking 'Access repository'
35
+ 4. Wait for approval from model owners
36
+ """)
37
+
38
  @st.cache_resource
39
  def load_model():
40
+ """Load the model using the Hugging Face CLI login approach"""
41
+ device = "cuda" if torch.cuda.is_available() else "cpu"
42
+ try:
43
+ # This should work if the user has done huggingface-cli login
44
+ # or if token is in HF_HOME/.huggingface/token
45
+ return FluxFillPipeline.from_pretrained(
46
+ "black-forest-labs/FLUX.1-Fill-dev",
47
+ torch_dtype=torch.bfloat16
48
+ ).to(device)
49
+ except Exception as e:
50
+ st.error(f"Error loading model: {str(e)}")
51
+ if "401 Client Error" in str(e):
52
+ st.error("""
53
+ Access Denied: You need to:
54
+ 1. Request access to the model at https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev
55
+ 2. Make sure you've run 'huggingface-cli login' or set up the token in Spaces secrets
56
+ 3. Wait for approval from model owners
57
+ """)
58
+ st.stop()
59
+
60
+ try:
61
+ pipe = load_model()
62
+ except Exception as e:
63
+ st.error(f"Failed to load model: {str(e)}")
64
+ st.stop()
65
 
66
  def calculate_optimal_dimensions(image: Image.Image):
67
  # Extract the original dimensions
 
100
 
101
  return width, height
102
 
103
+ # Create two columns for layout
104
+ col1, col2 = st.columns([1, 1])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
+ with col1:
107
+ # Upload image
108
+ uploaded_file = st.file_uploader("Upload an image for inpainting", type=["jpg", "jpeg", "png"])
109
 
110
+ if uploaded_file is not None:
111
+ # Display the uploaded image
112
+ image = Image.open(uploaded_file).convert("RGB")
 
113
  st.image(image, caption="Uploaded Image", use_column_width=True)
114
 
115
+ # Canvas for creating mask
116
+ st.write("Draw on the image to create a mask for inpainting")
117
+ from streamlit_drawable_canvas import st_canvas
 
118
  canvas_result = st_canvas(
119
+ fill_color="white",
120
  stroke_width=10,
121
+ stroke_color="white",
122
+ background_color="transparent",
123
+ background_image=image,
124
+ update_streamlit=True,
125
  height=600,
126
  drawing_mode="freedraw",
127
  key="canvas",
128
  )
129
 
130
+ # Prompt input
131
+ prompt = st.text_input("Enter your prompt")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
+ # Example prompts
134
  examples = [
135
  "a tiny astronaut hatching from an egg on the moon",
136
  "a cat holding a sign that says hello world",
137
  "an anime illustration of a wiener schnitzel",
138
  ]
139
 
140
+ example_prompt = st.selectbox("Or select an example prompt", [""] + examples)
141
+ if example_prompt and not prompt:
142
+ prompt = example_prompt
143
+
144
+ # Advanced settings with expander
145
  with st.expander("Advanced Settings"):
 
146
  randomize_seed = st.checkbox("Randomize seed", value=True)
147
 
148
+ if not randomize_seed:
149
+ seed = st.slider("Seed", 0, MAX_SEED, 0)
150
+ else:
151
+ seed = random.randint(0, MAX_SEED)
152
+
153
  guidance_scale = st.slider("Guidance Scale", 1.0, 30.0, 3.5, 0.5)
154
  num_inference_steps = st.slider("Number of inference steps", 1, 50, 28)
155
 
156
+ # Run button
157
+ run_button = st.button("Generate")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
+ with col2:
160
+ if uploaded_file is not None:
161
+ st.write("Result will appear here")
162
+
163
+ if run_button and prompt and canvas_result.image_data is not None:
164
+ with st.spinner("Generating image..."):
165
+ # Create mask from canvas
166
+ mask_data = canvas_result.image_data
167
+ mask = Image.fromarray(mask_data.astype(np.uint8)).convert("L")
168
+
169
+ # Calculate dimensions
170
+ width, height = calculate_optimal_dimensions(image)
171
+
172
+ # Progress bar
173
+ progress_bar = st.progress(0)
174
+
175
+ # Generate the image
176
+ def update_progress(step, total_steps):
177
+ progress_bar.progress(step / total_steps)
178
+
179
+ try:
180
+ result_image = pipe(
181
+ prompt=prompt,
182
+ image=image,
183
+ mask_image=mask,
184
+ height=int(height),
185
+ width=int(width),
186
+ guidance_scale=guidance_scale,
187
+ num_inference_steps=num_inference_steps,
188
+ generator=torch.Generator("cpu").manual_seed(seed),
189
+ callback=update_progress
190
+ ).images[0]
191
+
192
+ # Update final progress
193
+ progress_bar.progress(1.0)
194
+
195
+ # Display the result
196
+ st.image(result_image, caption="Generated Result", use_column_width=True)
197
+
198
+ # Add download button
199
+ buf = io.BytesIO()
200
+ result_image.save(buf, format="PNG")
201
+ st.download_button(
202
+ label="Download result",
203
+ data=buf.getvalue(),
204
+ file_name="flux_fill_result.png",
205
+ mime="image/png",
206
+ )
207
+
208
+ # Display used seed
209
+ st.write(f"Seed used: {seed}")
210
+
211
+ except Exception as e:
212
+ st.error(f"An error occurred: {str(e)}")
213
+
214
+ # If no image is uploaded
215
+ if uploaded_file is None:
216
+ with col2:
217
+ st.write("Please upload an image first")