jdavis commited on
Commit
f07e450
·
verified ·
1 Parent(s): 657d0cd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +141 -131
app.py CHANGED
@@ -4,51 +4,27 @@ import torch
4
  import random
5
  from diffusers import FluxFillPipeline
6
  from PIL import Image
7
- import io
 
 
 
 
8
 
9
  # Constants
10
  MAX_SEED = np.iinfo(np.int32).max
11
  MAX_IMAGE_SIZE = 2048
12
 
13
- # Setting page config
14
- st.set_page_config(
15
- page_title="FLUX.1 Fill [dev]",
16
- layout="wide"
17
- )
18
-
19
- # Title and description
20
- st.markdown("""
21
- # FLUX.1 Fill [dev]
22
- 12B param rectified flow transformer structural conditioning tuned, guidance-distilled from [FLUX.1 [pro]](https://blackforestlabs.ai/)
23
- [[non-commercial license](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)] [[blog](https://blackforestlabs.ai/announcing-black-forest-labs/)] [[model](https://huggingface.co/black-forest-labs/FLUX.1-dev)]
24
- """)
25
-
26
- # Get Hugging Face token
27
- hf_token = st.text_input("Enter your Hugging Face token (needed to access FLUX.1-Fill-dev)", type="password")
28
- if not hf_token:
29
- st.warning("You need to provide your Hugging Face token to access this model")
30
- st.markdown("1. Sign up/login at [Hugging Face](https://huggingface.co/)")
31
- st.markdown("2. Generate a token at https://huggingface.co/settings/tokens")
32
- st.markdown("3. Request access to [FLUX.1-Fill-dev](https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev)")
33
- st.stop()
34
-
35
- # Load the model
36
  @st.cache_resource
37
- def load_model(token):
38
- device = "cuda" if torch.cuda.is_available() else "cpu"
39
- try:
40
- return FluxFillPipeline.from_pretrained(
41
- "black-forest-labs/FLUX.1-Fill-dev",
42
- torch_dtype=torch.bfloat16,
43
- use_auth_token=token
44
- ).to(device)
45
- except Exception as e:
46
- st.error(f"Error loading model: {str(e)}")
47
- st.stop()
48
-
49
- pipe = load_model(hf_token)
50
-
51
- def calculate_optimal_dimensions(image: Image.Image):
52
  # Extract the original dimensions
53
  original_width, original_height = image.size
54
 
@@ -83,120 +59,154 @@ def calculate_optimal_dimensions(image: Image.Image):
83
  width = max(width, 576) if width == FIXED_DIMENSION else width
84
  height = max(height, 576) if height == FIXED_DIMENSION else height
85
 
86
- return width, height
87
-
88
- # Create two columns for layout
89
- col1, col2 = st.columns([1, 1])
90
 
91
- with col1:
92
- # Upload image
93
- uploaded_file = st.file_uploader("Upload an image for inpainting", type=["jpg", "jpeg", "png"])
94
 
95
  if uploaded_file is not None:
96
- # Display the uploaded image
97
  image = Image.open(uploaded_file).convert("RGB")
98
- st.image(image, caption="Uploaded Image", use_column_width=True)
 
 
 
 
 
 
 
99
 
100
- # Canvas for creating mask
101
- st.write("Draw on the image to create a mask for inpainting")
 
102
  from streamlit_drawable_canvas import st_canvas
 
103
  canvas_result = st_canvas(
104
- fill_color="white",
105
  stroke_width=10,
106
- stroke_color="white",
107
- background_color="transparent",
108
  background_image=image,
109
- update_streamlit=True,
110
- height=600,
111
  drawing_mode="freedraw",
112
  key="canvas",
113
  )
114
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  # Prompt input
116
  prompt = st.text_input("Enter your prompt")
117
 
118
  # Example prompts
119
- examples = [
120
- "a tiny astronaut hatching from an egg on the moon",
121
- "a cat holding a sign that says hello world",
122
- "an anime illustration of a wiener schnitzel",
123
- ]
124
-
125
- example_prompt = st.selectbox("Or select an example prompt", [""] + examples)
126
- if example_prompt and not prompt:
127
- prompt = example_prompt
128
-
129
- # Advanced settings with expander
130
  with st.expander("Advanced Settings"):
 
131
  randomize_seed = st.checkbox("Randomize seed", value=True)
132
 
133
- if not randomize_seed:
134
- seed = st.slider("Seed", 0, MAX_SEED, 0)
135
- else:
136
- seed = random.randint(0, MAX_SEED)
137
-
138
  guidance_scale = st.slider("Guidance Scale", 1.0, 30.0, 3.5, 0.5)
139
- num_inference_steps = st.slider("Number of inference steps", 1, 50, 28)
140
-
141
  # Run button
142
- run_button = st.button("Generate")
143
-
144
- with col2:
145
- if uploaded_file is not None:
146
- st.write("Result will appear here")
147
-
148
- if run_button and prompt and canvas_result.image_data is not None:
149
- with st.spinner("Generating image..."):
150
- # Create mask from canvas
151
- mask_data = canvas_result.image_data
152
- mask = Image.fromarray(mask_data.astype(np.uint8)).convert("L")
153
-
154
- # Calculate dimensions
155
- width, height = calculate_optimal_dimensions(image)
 
 
 
156
 
157
- # Progress bar
158
- progress_bar = st.progress(0)
 
159
 
160
- # Generate the image
161
- def update_progress(step, total_steps):
162
- progress_bar.progress(step / total_steps)
163
 
164
- try:
165
- result_image = pipe(
166
- prompt=prompt,
167
- image=image,
168
- mask_image=mask,
169
- height=int(height),
170
- width=int(width),
171
- guidance_scale=guidance_scale,
172
- num_inference_steps=num_inference_steps,
173
- generator=torch.Generator("cpu").manual_seed(seed),
174
- callback=update_progress
175
- ).images[0]
176
-
177
- # Update final progress
178
- progress_bar.progress(1.0)
179
-
180
- # Display the result
181
- st.image(result_image, caption="Generated Result", use_column_width=True)
182
-
183
- # Add download button
184
- buf = io.BytesIO()
185
- result_image.save(buf, format="PNG")
186
- st.download_button(
187
- label="Download result",
188
- data=buf.getvalue(),
189
- file_name="flux_fill_result.png",
190
- mime="image/png",
191
- )
192
-
193
- # Display used seed
194
- st.write(f"Seed used: {seed}")
195
-
196
- except Exception as e:
197
- st.error(f"An error occurred: {str(e)}")
198
-
199
- # If no image is uploaded
200
- if uploaded_file is None:
201
- with col2:
202
- st.write("Please upload an image first")
 
4
  import random
5
  from diffusers import FluxFillPipeline
6
  from PIL import Image
7
+ from io import BytesIO
8
+ import base64
9
+
10
+ # Set page configuration
11
+ st.set_page_config(page_title="FLUX.1 Fill [dev]", layout="wide")
12
 
13
  # Constants
14
  MAX_SEED = np.iinfo(np.int32).max
15
  MAX_IMAGE_SIZE = 2048
16
 
17
+ # Initialize the model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  @st.cache_resource
19
+ def load_model():
20
+ pipe = FluxFillPipeline.from_pretrained(
21
+ "black-forest-labs/FLUX.1-Fill-dev",
22
+ torch_dtype=torch.bfloat16
23
+ ).to("cuda")
24
+ return pipe
25
+
26
+ # Function to calculate optimal dimensions
27
+ def calculate_optimal_dimensions(image):
 
 
 
 
 
 
28
  # Extract the original dimensions
29
  original_width, original_height = image.size
30
 
 
59
  width = max(width, 576) if width == FIXED_DIMENSION else width
60
  height = max(height, 576) if height == FIXED_DIMENSION else height
61
 
62
+ return int(width), int(height)
 
 
 
63
 
64
+ # Custom component for image editing and masking
65
+ def image_editor():
66
+ uploaded_file = st.file_uploader("Upload an image for inpainting", type=["png", "jpg", "jpeg"])
67
 
68
  if uploaded_file is not None:
 
69
  image = Image.open(uploaded_file).convert("RGB")
70
+ width, height = image.size
71
+
72
+ # Create a placeholder for the canvas
73
+ canvas_placeholder = st.empty()
74
+
75
+ # Display instructions
76
+ st.markdown("### Draw a mask on the image")
77
+ st.markdown("Draw on the areas you want to edit (white areas will be inpainted)")
78
 
79
+ # Create canvas for mask drawing
80
+ # Note: We need to import the components within function to avoid import errors
81
+ import streamlit_drawable_canvas
82
  from streamlit_drawable_canvas import st_canvas
83
+
84
  canvas_result = st_canvas(
85
+ fill_color="rgba(255, 255, 255, 0.3)",
86
  stroke_width=10,
87
+ stroke_color="#FFFFFF",
 
88
  background_image=image,
89
+ height=min(600, height),
90
+ width=min(1000, width),
91
  drawing_mode="freedraw",
92
  key="canvas",
93
  )
94
 
95
+ if canvas_result.image_data is not None:
96
+ # Extract mask from canvas
97
+ mask_data = canvas_result.image_data
98
+ mask = Image.fromarray((mask_data[:, :, 3] > 0).astype(np.uint8) * 255)
99
+
100
+ return {
101
+ "background": image,
102
+ "mask": mask,
103
+ "has_mask": (mask.getextrema()[1] > 0)
104
+ }
105
+
106
+ return None
107
+
108
+ # Function to run inference
109
+ def run_inference(image_data, prompt, seed, randomize_seed, guidance_scale, num_inference_steps):
110
+ if randomize_seed:
111
+ seed = random.randint(0, MAX_SEED)
112
+
113
+ pipe = load_model()
114
+ width, height = calculate_optimal_dimensions(image_data["background"])
115
+
116
+ with st.spinner("Generating image..."):
117
+ generated_image = pipe(
118
+ prompt=prompt,
119
+ image=image_data["background"],
120
+ mask_image=image_data["mask"],
121
+ height=height,
122
+ width=width,
123
+ guidance_scale=guidance_scale,
124
+ num_inference_steps=num_inference_steps,
125
+ generator=torch.Generator("cpu").manual_seed(seed)
126
+ ).images[0]
127
+
128
+ return generated_image, seed
129
+
130
+ # Example prompts
131
+ examples = [
132
+ "a tiny astronaut hatching from an egg on the moon",
133
+ "a cat holding a sign that says hello world",
134
+ "an anime illustration of a wiener schnitzel",
135
+ ]
136
+
137
+ # Main UI
138
+ def main():
139
+ # Header
140
+ st.title("FLUX.1 Fill [dev]")
141
+ st.markdown("""
142
+ 12B param rectified flow transformer structural conditioning tuned, guidance-distilled from [FLUX.1 [pro]](https://blackforestlabs.ai/)
143
+ [[non-commercial license](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)] [[blog](https://blackforestlabs.ai/announcing-black-forest-labs/)] [[model](https://huggingface.co/black-forest-labs/FLUX.1-dev)]
144
+ """)
145
+
146
+ # Create columns for layout
147
+ col1, col2 = st.columns([1, 1])
148
+
149
+ with col1:
150
+ # Image editor
151
+ image_data = image_editor()
152
+
153
  # Prompt input
154
  prompt = st.text_input("Enter your prompt")
155
 
156
  # Example prompts
157
+ st.subheader("Example Prompts")
158
+ for i, example in enumerate(examples):
159
+ if st.button(f"Use Example: {example}", key=f"example_{i}"):
160
+ st.session_state.prompt = example
161
+ st.experimental_rerun()
162
+
163
+ # Advanced settings
 
 
 
 
164
  with st.expander("Advanced Settings"):
165
+ seed = st.slider("Seed", 0, MAX_SEED, 0)
166
  randomize_seed = st.checkbox("Randomize seed", value=True)
167
 
 
 
 
 
 
168
  guidance_scale = st.slider("Guidance Scale", 1.0, 30.0, 3.5, 0.5)
169
+ num_inference_steps = st.slider("Number of inference steps", 1, 50, 28, 1)
170
+
171
  # Run button
172
+ run_button = st.button("Run")
173
+
174
+ # Result display in second column
175
+ with col2:
176
+ # Check if we need to run inference
177
+ if image_data is not None and prompt and run_button:
178
+ if not image_data.get("has_mask", False):
179
+ st.error("Please draw a mask on the image")
180
+ else:
181
+ result_image, result_seed = run_inference(
182
+ image_data,
183
+ prompt,
184
+ seed,
185
+ randomize_seed,
186
+ guidance_scale,
187
+ num_inference_steps
188
+ )
189
 
190
+ # Update seed if it was randomized
191
+ if randomize_seed:
192
+ st.session_state.seed = result_seed
193
 
194
+ # Display result
195
+ st.subheader("Generated Image")
196
+ st.image(result_image, use_column_width=True)
197
 
198
+ # Download button
199
+ buf = BytesIO()
200
+ result_image.save(buf, format="PNG")
201
+ byte_im = buf.getvalue()
202
+ download_button_str = f"""
203
+ <a href="data:image/png;base64,{base64.b64encode(byte_im).decode()}" download="generated_image.png">
204
+ <div style="display: inline-flex; align-items: center; background-color: #4CAF50; color: white; padding: 0.5em 1em; border-radius: 4px; text-decoration: none; font-weight: bold; margin-top: 10px;">
205
+ ⬇️ Download Image
206
+ </div>
207
+ </a>
208
+ """
209
+ st.markdown(download_button_str, unsafe_allow_html=True)
210
+
211
+ if __name__ == "__main__":
212
+ main()