multimodalart committed
Commit 8d760b2 (verified) · Parent: 850d0d4

Update app.py

Files changed (1): app.py (+37 −11)
app.py CHANGED
@@ -23,6 +23,9 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 1024
 
+# Pre-shifted custom sigmas for 8-step turbo inference
+TURBO_SIGMAS = [1.0, 0.6509, 0.4374, 0.2932, 0.1893, 0.1108, 0.0495, 0.00031]
+
 hf_client = InferenceClient(
     api_key=os.environ.get("HF_TOKEN"),
 )
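
Note on the schedule above: the eight sigmas run from pure noise (1.0) down to nearly clean (0.00031), and the spacing is front-loaded, the first step covering a drop of about 0.35 in sigma while the last covers about 0.05. The exact values ship with fal's Turbo release. As a sketch only, schedules of this "pre-shifted" shape are commonly built from a uniform grid with the flow-matching time-shift transform; the shift factor below is an illustrative assumption and does not reproduce fal's tuned values:

import numpy as np

def shifted_sigmas(num_steps: int, shift: float = 0.5) -> list[float]:
    # Uniform grid from 1.0 (pure noise) down toward 0, one value per step
    sigmas = np.linspace(1.0, 1.0 / num_steps, num_steps)
    # Flow-matching time shift: sigma' = shift*sigma / (1 + (shift-1)*sigma).
    # The endpoints sigma=1 and sigma=0 are fixed points; shift < 1 bends
    # the schedule toward the low-noise end, shift > 1 toward high noise.
    return (shift * sigmas / (1 + (shift - 1) * sigmas)).tolist()

print(shifted_sigmas(8))  # same length and monotone shape as TURBO_SIGMAS
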
@@ -78,6 +81,13 @@ pipe = Flux2Pipeline.from_pretrained(
     transformer=dit,
     torch_dtype=torch.bfloat16
 )
+
+# Load the Turbo LoRA
+pipe.load_lora_weights(
+    "fal/FLUX.2-Turbo",
+    weight_name="flux.2-turbo-lora.safetensors"
+)
+
 pipe.to(device)
 
 # Pull pre-compiled Flux2 Transformer blocks from HF hub
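
Scope note: load_lora_weights activates the adapter globally, and the use_turbo checkbox added further down only swaps the sigma schedule, so even "non-turbo" requests sample through the LoRA-patched weights. If per-request control were wanted, the standard diffusers LoRA switches could apply, assuming Flux2Pipeline inherits the usual LoRA loader mixin; a sketch, not part of this commit, with adapter_name as a hypothetical addition:

pipe.load_lora_weights(
    "fal/FLUX.2-Turbo",
    weight_name="flux.2-turbo-lora.safetensors",
    adapter_name="turbo",   # naming the adapter makes it addressable later
)
pipe.disable_lora()          # serve one request from the un-adapted base model
pipe.enable_lora()           # re-activate the turbo adapter afterwards
# pipe.fuse_lora()           # alternative: bake the adapter into the base
#                            # weights for less overhead, at the cost of making
#                            # per-request toggling impractical
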
@@ -157,14 +167,17 @@ def update_dimensions_from_image(image_list):
 
     return new_width, new_height
 
-# Updated duration function to match generate_image arguments (including progress)
-def get_duration(prompt_embeds, image_list, width, height, num_inference_steps, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
+# Updated duration function for Turbo (much faster with fewer steps)
+def get_duration(prompt_embeds, image_list, width, height, num_inference_steps, guidance_scale, seed, use_turbo, progress=gr.Progress(track_tqdm=True)):
     num_images = 0 if image_list is None else len(image_list)
     step_duration = 1 + 0.8 * num_images
+    # Turbo mode uses fewer steps, so shorter duration
+    if use_turbo:
+        return max(30, 8 * step_duration + 10)  # Fixed 8 steps for turbo
     return max(65, num_inference_steps * step_duration + 10)
 
 @spaces.GPU(duration=get_duration)
-def generate_image(prompt_embeds, image_list, width, height, num_inference_steps, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
+def generate_image(prompt_embeds, image_list, width, height, num_inference_steps, guidance_scale, seed, use_turbo, progress=gr.Progress(track_tqdm=True)):
     # Move embeddings to GPU only when inside the GPU decorated function
     prompt_embeds = prompt_embeds.to(device)
 
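
On ZeroGPU Spaces, a callable passed as duration= is invoked with the same arguments as the decorated function, which is why get_duration gains the use_turbo parameter in lockstep with generate_image. Plugging in the constants shows the GPU budgets the Space now requests; a quick check mirroring the formulas above:

def budget(steps: int, num_images: int, turbo: bool) -> float:
    # step_duration grows with the number of reference images being edited
    step_duration = 1 + 0.8 * num_images
    if turbo:
        return max(30, 8 * step_duration + 10)
    return max(65, steps * step_duration + 10)

print(budget(8, 0, True))    # 30   -> turbo text-to-image hits the 30 s floor
print(budget(8, 3, True))    # 37.2 -> turbo editing with three input images
print(budget(50, 0, False))  # 65   -> the old 65 s floor still applies
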
@@ -173,13 +186,19 @@ def generate_image(prompt_embeds, image_list, width, height, num_inference_steps
     pipe_kwargs = {
         "prompt_embeds": prompt_embeds,
         "image": image_list,
-        "num_inference_steps": num_inference_steps,
         "guidance_scale": guidance_scale,
         "generator": generator,
         "width": width,
         "height": height,
     }
 
+    # Use Turbo sigmas or regular inference steps
+    if use_turbo:
+        pipe_kwargs["sigmas"] = TURBO_SIGMAS
+        pipe_kwargs["num_inference_steps"] = 8  # Turbo always uses 8 steps
+    else:
+        pipe_kwargs["num_inference_steps"] = num_inference_steps
+
     # Progress bar for the actual generation steps
     if progress:
         progress(0, desc="Starting generation...")
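
Passing sigmas hands the scheduler an explicit noise schedule in place of the one it would derive from num_inference_steps, and setting num_inference_steps to 8, the length of TURBO_SIGMAS, keeps step accounting and progress reporting consistent. The scheduler-level effect looks roughly like the following, assuming a FlowMatch-style scheduler on a recent diffusers; Flux2Pipeline's internal plumbing may differ:

from diffusers import FlowMatchEulerDiscreteScheduler

sched = FlowMatchEulerDiscreteScheduler()
# An explicit sigma list overrides the default num_inference_steps grid
sched.set_timesteps(sigmas=TURBO_SIGMAS)
print(sched.timesteps)  # eight denoising timesteps, one per sigma
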
@@ -187,7 +206,7 @@ def generate_image(prompt_embeds, image_list, width, height, num_inference_steps
     image = pipe(**pipe_kwargs).images[0]
     return image
 
-def infer(prompt, input_images=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=50, guidance_scale=2.5, prompt_upsampling=False, progress=gr.Progress(track_tqdm=True)):
+def infer(prompt, input_images=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=8, guidance_scale=2.5, prompt_upsampling=False, use_turbo=True, progress=gr.Progress(track_tqdm=True)):
 
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
@@ -221,7 +240,8 @@ def infer(prompt, input_images=None, seed=42, randomize_seed=False, width=1024,
         height,
         num_inference_steps,
         guidance_scale,
-        seed,
+        seed,
+        use_turbo,
         progress
     )
 
@@ -252,8 +272,8 @@ css="""
 with gr.Blocks() as demo:
 
     with gr.Column(elem_id="col-container"):
-        gr.Markdown(f"""# FLUX.2 [dev]
-FLUX.2 [dev] is a 32B model rectified flow capable of generating, editing and combining images based on text instructions model [[model](https://huggingface.co/black-forest-labs/FLUX.2-dev)], [[blog](https://bfl.ai/blog/flux-2)]
+        gr.Markdown(f"""# FLUX.2 [dev] Turbo
+FLUX.2 [dev] with [Turbo LoRA by fal](https://huggingface.co/fal/FLUX.2-Turbo) - a 32B rectified flow model capable of generating, editing and combining images based on text instructions in just 8 steps [[model](https://huggingface.co/black-forest-labs/FLUX.2-dev)], [[blog](https://bfl.ai/blog/flux-2)]
         """)
     with gr.Row():
         with gr.Column():
@@ -278,6 +298,12 @@ FLUX.2 [dev] is a 32B model rectified flow capable of generating, editing and co
             )
 
             with gr.Accordion("Advanced Settings", open=False):
+                use_turbo = gr.Checkbox(
+                    label="Use Turbo Mode (8 steps)",
+                    value=True,
+                    info="Enable Turbo LoRA for fast 8-step generation"
+                )
+
                 prompt_upsampling = gr.Checkbox(
                     label="Prompt Upsampling",
                     value=True,
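
The checkbox only advertises that the steps slider is ignored; a possible follow-up, not in this commit, would grey the slider out while Turbo mode is on, using a standard Gradio change listener inside the same gr.Blocks() context:

use_turbo.change(
    fn=lambda on: gr.update(interactive=not on),  # lock the slider in turbo mode
    inputs=use_turbo,
    outputs=num_inference_steps,
)
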
@@ -315,7 +341,7 @@ FLUX.2 [dev] is a 32B model rectified flow capable of generating, editing and co
                 with gr.Row():
 
                     num_inference_steps = gr.Slider(
-                        label="Number of inference steps",
+                        label="Number of inference steps (ignored in Turbo mode)",
                         minimum=1,
                         maximum=100,
                         step=1,
@@ -327,7 +353,7 @@ FLUX.2 [dev] is a 32B model rectified flow capable of generating, editing and co
                         minimum=0.0,
                         maximum=10.0,
                         step=0.1,
-                        value=4,
+                        value=2.5,
                     )
 
 
@@ -363,7 +389,7 @@ FLUX.2 [dev] is a 32B model rectified flow capable of generating, editing and co
     gr.on(
         triggers=[run_button.click, prompt.submit],
         fn=infer,
-        inputs=[prompt, input_images, seed, randomize_seed, width, height, num_inference_steps, guidance_scale, prompt_upsampling],
+        inputs=[prompt, input_images, seed, randomize_seed, width, height, num_inference_steps, guidance_scale, prompt_upsampling, use_turbo],
         outputs=[result, seed]
     )
 
 
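The inputs list feeds infer positionally, so use_turbo must sit exactly where the new parameter sits in the signature, and gr.Progress is injected by Gradio itself rather than listed. A sanity check of that alignment, hypothetical and not part of the commit:

import inspect

ui_inputs = ["prompt", "input_images", "seed", "randomize_seed",
             "width", "height", "num_inference_steps", "guidance_scale",
             "prompt_upsampling", "use_turbo"]
sig = [p for p in inspect.signature(infer).parameters if p != "progress"]
assert ui_inputs == sig, "Gradio inputs out of sync with infer()"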
395