jdavis commited on
Commit
e24f1da
·
verified ·
1 Parent(s): 1941ddf

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +202 -0
app.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+ import torch
4
+ import random
5
+ from diffusers import FluxFillPipeline
6
+ from PIL import Image
7
+ import io
8
+
9
# Constants
MAX_SEED = np.iinfo(np.int32).max  # upper bound for the seed slider / random seed draw
MAX_IMAGE_SIZE = 2048  # NOTE(review): unused in the visible code — presumably an intended size cap; confirm

# Setting page config
# NOTE: st.set_page_config must be the first Streamlit command executed in the script.
st.set_page_config(
    page_title="FLUX.1 Fill [dev]",
    layout="wide"
)

# Title and description
st.markdown("""
# FLUX.1 Fill [dev]
12B param rectified flow transformer structural conditioning tuned, guidance-distilled from [FLUX.1 [pro]](https://blackforestlabs.ai/)
[[non-commercial license](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)] [[blog](https://blackforestlabs.ai/announcing-black-forest-labs/)] [[model](https://huggingface.co/black-forest-labs/FLUX.1-dev)]
""")

# Get Hugging Face token
# The FLUX.1-Fill-dev repository is gated, so a user-supplied token is required
# before anything else can run; st.stop() halts the script until one is entered.
hf_token = st.text_input("Enter your Hugging Face token (needed to access FLUX.1-Fill-dev)", type="password")
if not hf_token:
    st.warning("You need to provide your Hugging Face token to access this model")
    st.markdown("1. Sign up/login at [Hugging Face](https://huggingface.co/)")
    st.markdown("2. Generate a token at https://huggingface.co/settings/tokens")
    st.markdown("3. Request access to [FLUX.1-Fill-dev](https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev)")
    st.stop()
34
+
35
# Load the model (cached by Streamlit so the 12B-param pipeline is only
# downloaded/initialised once per session, not on every script rerun).
@st.cache_resource
def load_model(token):
    """Load FluxFillPipeline onto GPU when available, otherwise CPU.

    Args:
        token: Hugging Face access token granting access to the gated
            ``black-forest-labs/FLUX.1-Fill-dev`` repository.

    Returns:
        The pipeline moved to the selected device.

    On any loading failure the error is shown in the UI and the Streamlit
    script is halted via ``st.stop()``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    try:
        # BUG FIX: `use_auth_token` is deprecated (and removed in recent
        # huggingface_hub/diffusers releases); the supported kwarg is `token`.
        return FluxFillPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-Fill-dev",
            torch_dtype=torch.bfloat16,
            token=token
        ).to(device)
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        st.stop()

pipe = load_model(hf_token)
50
+
51
def calculate_optimal_dimensions(image: "Image.Image") -> "tuple[int, int]":
    """Choose pipeline-friendly output dimensions for *image*.

    Scales the image so its longer side is 1024 px, snaps both sides to
    multiples of 8 (a requirement of the diffusion model), and caps the
    aspect ratio to the [9/16, 16/9] range.

    Args:
        image: anything exposing a PIL-style ``.size`` -> (width, height).

    Returns:
        (width, height) as plain ints.
    """
    # Extract the original dimensions
    original_width, original_height = image.size

    # Set constants
    MIN_ASPECT_RATIO = 9 / 16
    MAX_ASPECT_RATIO = 16 / 9
    FIXED_DIMENSION = 1024

    # Calculate the aspect ratio of the original image
    original_aspect_ratio = original_width / original_height

    # Fix the longer dimension at 1024 and scale the other proportionally.
    if original_aspect_ratio > 1:  # Wider than tall
        width = FIXED_DIMENSION
        height = round(FIXED_DIMENSION / original_aspect_ratio)
    else:  # Taller than wide (or square)
        height = FIXED_DIMENSION
        width = round(FIXED_DIMENSION * original_aspect_ratio)

    # Ensure dimensions are multiples of 8
    width = (width // 8) * 8
    height = (height // 8) * 8

    # Enforce aspect ratio limits.
    # BUG FIX: the original expressions mixed float division with floor
    # division and produced *float* dimensions (e.g. 904.0); cast back to
    # int so callers receive integer pixel sizes. Values are numerically
    # identical to the originals.
    calculated_aspect_ratio = width / height
    if calculated_aspect_ratio > MAX_ASPECT_RATIO:
        width = int(height * MAX_ASPECT_RATIO // 8) * 8
    elif calculated_aspect_ratio < MIN_ASPECT_RATIO:
        height = int(width / MIN_ASPECT_RATIO // 8) * 8

    # NOTE(review): both clamps below are no-ops as written — when the
    # dimension equals FIXED_DIMENSION (1024), max(1024, 576) is 1024.
    # Kept byte-equivalent to preserve behaviour; the original intent
    # ("remain above the minimum dimensions") should be revisited.
    width = max(width, 576) if width == FIXED_DIMENSION else width
    height = max(height, 576) if height == FIXED_DIMENSION else height

    return width, height
87
+
88
# Create two columns for layout: col1 holds the inputs, col2 the result.
col1, col2 = st.columns([1, 1])

with col1:
    # Upload image
    uploaded_file = st.file_uploader("Upload an image for inpainting", type=["jpg", "jpeg", "png"])

    if uploaded_file is not None:
        # Display the uploaded image
        image = Image.open(uploaded_file).convert("RGB")
        st.image(image, caption="Uploaded Image", use_column_width=True)

        # Canvas for creating mask; white freehand strokes become the
        # inpainting mask consumed in col2.
        st.write("Draw on the image to create a mask for inpainting")
        # Imported here, at the point of use, rather than at the top of the file.
        from streamlit_drawable_canvas import st_canvas
        canvas_result = st_canvas(
            fill_color="white",
            stroke_width=10,
            stroke_color="white",
            background_color="transparent",
            background_image=image,
            # NOTE(review): the canvas uses a fixed height of 600, which may
            # not match the uploaded image's pixel grid — confirm mask scaling.
            update_streamlit=True,
            height=600,
            drawing_mode="freedraw",
            key="canvas",
        )

    # Prompt input
    prompt = st.text_input("Enter your prompt")

    # Example prompts
    examples = [
        "a tiny astronaut hatching from an egg on the moon",
        "a cat holding a sign that says hello world",
        "an anime illustration of a wiener schnitzel",
    ]

    example_prompt = st.selectbox("Or select an example prompt", [""] + examples)
    # The example only applies when the free-text prompt is empty.
    if example_prompt and not prompt:
        prompt = example_prompt

    # Advanced settings with expander
    with st.expander("Advanced Settings"):
        randomize_seed = st.checkbox("Randomize seed", value=True)

        if not randomize_seed:
            seed = st.slider("Seed", 0, MAX_SEED, 0)
        else:
            # NOTE(review): a fresh random seed is drawn on every Streamlit
            # rerun (i.e. any widget interaction) — confirm this is intended.
            seed = random.randint(0, MAX_SEED)

        guidance_scale = st.slider("Guidance Scale", 1.0, 30.0, 3.5, 0.5)
        num_inference_steps = st.slider("Number of inference steps", 1, 50, 28)

    # Run button
    run_button = st.button("Generate")
143
+
144
with col2:
    if uploaded_file is not None:
        st.write("Result will appear here")

        # Only run once the user has drawn a mask, typed a prompt, and
        # pressed Generate.
        if run_button and prompt and canvas_result.image_data is not None:
            with st.spinner("Generating image..."):
                # Create mask from canvas: the RGBA canvas array is converted
                # to a single-channel ("L") image so strokes become the mask.
                # NOTE(review): the canvas is rendered at a fixed height, so
                # its pixel grid may differ from the uploaded image; verify
                # mask alignment after the pipeline resizes to (width, height).
                mask_data = canvas_result.image_data
                mask = Image.fromarray(mask_data.astype(np.uint8)).convert("L")

                # Calculate dimensions
                width, height = calculate_optimal_dimensions(image)

                # Progress bar
                progress_bar = st.progress(0)

                # BUG FIX: the original passed `callback=update_progress` with a
                # (step, total_steps) signature. FluxFillPipeline.__call__ has no
                # `callback` parameter (TypeError at runtime); diffusers
                # pipelines expose `callback_on_step_end`, which is invoked as
                # callback(pipe, step, timestep, callback_kwargs) and must
                # return the (possibly modified) callback_kwargs dict.
                def update_progress(pipeline, step, timestep, callback_kwargs):
                    progress_bar.progress(min((step + 1) / num_inference_steps, 1.0))
                    return callback_kwargs

                try:
                    result_image = pipe(
                        prompt=prompt,
                        image=image,
                        mask_image=mask,
                        height=int(height),
                        width=int(width),
                        guidance_scale=guidance_scale,
                        num_inference_steps=num_inference_steps,
                        # CPU generator keeps seeding device-independent.
                        generator=torch.Generator("cpu").manual_seed(seed),
                        callback_on_step_end=update_progress,
                    ).images[0]

                    # Update final progress
                    progress_bar.progress(1.0)

                    # Display the result
                    st.image(result_image, caption="Generated Result", use_column_width=True)

                    # Add download button (PNG bytes rendered in-memory).
                    buf = io.BytesIO()
                    result_image.save(buf, format="PNG")
                    st.download_button(
                        label="Download result",
                        data=buf.getvalue(),
                        file_name="flux_fill_result.png",
                        mime="image/png",
                    )

                    # Display used seed so runs can be reproduced.
                    st.write(f"Seed used: {seed}")

                except Exception as e:
                    st.error(f"An error occurred: {str(e)}")
198
+
199
# Placeholder shown in the result column until an image has been uploaded.
# (Calling .write() on the column object is equivalent to `with col2:`.)
if uploaded_file is None:
    col2.write("Please upload an image first")