prithivMLmods commited on
Commit
ccaf792
·
verified ·
1 Parent(s): 68dd889

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -52
app.py CHANGED
@@ -8,8 +8,6 @@ import torch
8
  from diffusers import Flux2KleinPipeline, AutoencoderKLFlux2
9
  from PIL import Image
10
  from pathlib import Path
11
- import concurrent.futures
12
- import threading
13
  from typing import Iterable
14
 
15
  from gradio.themes import Soft
@@ -90,6 +88,7 @@ MAX_SEED = np.iinfo(np.int32).max
90
  MAX_IMAGE_SIZE = 1024
91
  EXAMPLES_DIR = Path("examples")
92
 
 
93
  print("Loading 4B Distilled model (Standard VAE)...")
94
  pipe_standard = Flux2KleinPipeline.from_pretrained(
95
  "black-forest-labs/FLUX.2-klein-4B",
@@ -97,12 +96,14 @@ pipe_standard = Flux2KleinPipeline.from_pretrained(
97
  )
98
  pipe_standard.enable_model_cpu_offload()
99
 
 
100
  print("Loading Small Decoder VAE...")
101
  vae_small = AutoencoderKLFlux2.from_pretrained(
102
  "black-forest-labs/FLUX.2-small-decoder",
103
  torch_dtype=dtype,
104
  )
105
 
 
106
  print("Loading 4B Distilled model (Small Decoder VAE)...")
107
  pipe_small_decoder = Flux2KleinPipeline.from_pretrained(
108
  "black-forest-labs/FLUX.2-klein-4B",
@@ -111,40 +112,27 @@ pipe_small_decoder = Flux2KleinPipeline.from_pretrained(
111
  )
112
  pipe_small_decoder.enable_model_cpu_offload()
113
 
114
- pipe_lock_standard = threading.Lock()
115
- pipe_lock_small = threading.Lock()
116
-
117
  def calc_dimensions(pil_img: Image.Image):
118
- """
119
- Given a PIL image return (width, height) snapped to multiples of 8,
120
- fitting within 1024 px on the long side, min 256 px on each side.
121
- Uses round() so we match the reference app exactly.
122
- """
123
  iw, ih = pil_img.size
124
  aspect = iw / ih
125
 
126
- if aspect >= 1: # landscape / square
127
  new_width = 1024
128
  new_height = int(round(1024 / aspect))
129
- else: # portrait
130
  new_height = 1024
131
  new_width = int(round(1024 * aspect))
132
 
133
- # snap to 8-pixel grid with round(), clamp to [256, 1024]
134
  new_width = max(256, min(1024, round(new_width / 8) * 8))
135
  new_height = max(256, min(1024, round(new_height / 8) * 8))
136
  return new_width, new_height
137
 
138
 
139
  def update_dimensions_from_image(image_list):
140
- """
141
- Called by the gallery .upload() event.
142
- Returns updated slider values for width and height.
143
- """
144
  if not image_list:
145
  return 1024, 1024
146
 
147
- # gallery items arrive as PIL images when type="pil"
148
  item = image_list[0]
149
  img = item[0] if isinstance(item, tuple) else item
150
 
@@ -155,11 +143,8 @@ def update_dimensions_from_image(image_list):
155
 
156
  return calc_dimensions(img)
157
 
 
158
  def parse_and_resize_images(input_images, width: int, height: int):
159
- """
160
- Parse the gallery input and resize every frame to (width, height).
161
- Returns a list[PIL.Image] or None.
162
- """
163
  if input_images is None:
164
  return None
165
 
@@ -192,12 +177,14 @@ def parse_and_resize_images(input_images, width: int, height: int):
192
  ]
193
  return resized
194
 
195
- def run_pipeline(pipe, lock, kwargs, seed):
196
- with lock:
197
- gen = torch.Generator(device="cpu").manual_seed(seed)
198
- result = pipe(**kwargs, generator=gen).images[0]
 
199
  return result
200
 
 
201
  @spaces.GPU(duration=120)
202
  def infer(
203
  prompt,
@@ -219,11 +206,9 @@ def infer(
219
  if randomize_seed:
220
  seed = random.randint(0, MAX_SEED)
221
 
222
- # ── width / height: derive from the first uploaded image if present ──
223
  image_list = None
224
  if input_images:
225
- # Re-derive dimensions from the actual first image so they are
226
- # always consistent with what the pipeline will receive.
227
  item = (
228
  input_images[0][0]
229
  if isinstance(input_images[0], tuple)
@@ -239,10 +224,9 @@ def infer(
239
  if first_pil is not None:
240
  width, height = calc_dimensions(first_pil)
241
 
242
- # parse + resize all images to the final (width, height)
243
  image_list = parse_and_resize_images(input_images, width, height)
244
 
245
- # ensure dims are multiples of 8 even for text-only runs
246
  width = max(256, min(MAX_IMAGE_SIZE, round(int(width) / 8) * 8))
247
  height = max(256, min(MAX_IMAGE_SIZE, round(int(height) / 8) * 8))
248
 
@@ -256,28 +240,22 @@ def infer(
256
  if image_list is not None:
257
  shared_kwargs["image"] = image_list
258
 
259
- progress(0.30, desc="Launching both pipelines simultaneously...")
 
 
260
 
261
- with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
262
- future_std = executor.submit(
263
- run_pipeline, pipe_standard, pipe_lock_standard, shared_kwargs, seed
264
- )
265
- future_small = executor.submit(
266
- run_pipeline, pipe_small_decoder, pipe_lock_small, shared_kwargs, seed
267
- )
268
- concurrent.futures.wait(
269
- [future_std, future_small],
270
- return_when=concurrent.futures.ALL_COMPLETED,
271
- )
272
-
273
- progress(0.80, desc="✅ Both pipelines done!")
274
 
275
- out_standard = future_std.result()
276
- out_small = future_small.result()
 
277
 
278
  gc.collect()
279
  torch.cuda.empty_cache()
280
 
 
 
281
  return out_standard, out_small, seed
282
 
283
 
@@ -347,7 +325,9 @@ with gr.Blocks() as demo:
347
  )
348
  gr.Markdown(
349
  "Compare **FLUX.2-klein-4B** side-by-side with "
350
- "[small decoder](https://huggingface.co/black-forest-labs/FLUX.2-small-decoder)."
 
 
351
  )
352
 
353
  with gr.Row(equal_height=True):
@@ -375,7 +355,7 @@ with gr.Blocks() as demo:
375
  with gr.Row():
376
  with gr.Column():
377
  result_standard = gr.Image(
378
- label="Standard Decoder",
379
  show_label=True,
380
  interactive=False,
381
  format="png",
@@ -383,7 +363,7 @@ with gr.Blocks() as demo:
383
  )
384
  with gr.Column():
385
  result_small = gr.Image(
386
- label="Small Decoder",
387
  show_label=True,
388
  interactive=False,
389
  format="png",
@@ -392,7 +372,7 @@ with gr.Blocks() as demo:
392
 
393
  seed_output = gr.Number(label="Seed Used", precision=0, visible=False)
394
 
395
- with gr.Accordion("Advanced Settings", open=False, visible=False):
396
  seed = gr.Slider(
397
  label="Seed",
398
  minimum=0,
@@ -478,7 +458,8 @@ with gr.Blocks() as demo:
478
 
479
  if __name__ == "__main__":
480
  demo.queue(max_size=20).launch(
481
- theme=orange_red_theme, css=css,
 
482
  mcp_server=True,
483
  ssr_mode=False,
484
  show_error=True,
 
8
  from diffusers import Flux2KleinPipeline, AutoencoderKLFlux2
9
  from PIL import Image
10
  from pathlib import Path
 
 
11
  from typing import Iterable
12
 
13
  from gradio.themes import Soft
 
88
  MAX_IMAGE_SIZE = 1024
89
  EXAMPLES_DIR = Path("examples")
90
 
91
+ # ── Load standard pipeline ──────────────────────────────────────────────────
92
  print("Loading 4B Distilled model (Standard VAE)...")
93
  pipe_standard = Flux2KleinPipeline.from_pretrained(
94
  "black-forest-labs/FLUX.2-klein-4B",
 
96
  )
97
  pipe_standard.enable_model_cpu_offload()
98
 
99
+ # ── Load small decoder VAE ───────────────────────────────────────────────────
100
  print("Loading Small Decoder VAE...")
101
  vae_small = AutoencoderKLFlux2.from_pretrained(
102
  "black-forest-labs/FLUX.2-small-decoder",
103
  torch_dtype=dtype,
104
  )
105
 
106
+ # ── Load small-decoder pipeline ──────────────────────────────────────────────
107
  print("Loading 4B Distilled model (Small Decoder VAE)...")
108
  pipe_small_decoder = Flux2KleinPipeline.from_pretrained(
109
  "black-forest-labs/FLUX.2-klein-4B",
 
112
  )
113
  pipe_small_decoder.enable_model_cpu_offload()
114
 
115
+ # ────────────────────────────────────────────────────────────────────────────
 
 
116
  def calc_dimensions(pil_img: Image.Image):
 
 
 
 
 
117
  iw, ih = pil_img.size
118
  aspect = iw / ih
119
 
120
+ if aspect >= 1:
121
  new_width = 1024
122
  new_height = int(round(1024 / aspect))
123
+ else:
124
  new_height = 1024
125
  new_width = int(round(1024 * aspect))
126
 
 
127
  new_width = max(256, min(1024, round(new_width / 8) * 8))
128
  new_height = max(256, min(1024, round(new_height / 8) * 8))
129
  return new_width, new_height
130
 
131
 
132
  def update_dimensions_from_image(image_list):
 
 
 
 
133
  if not image_list:
134
  return 1024, 1024
135
 
 
136
  item = image_list[0]
137
  img = item[0] if isinstance(item, tuple) else item
138
 
 
143
 
144
  return calc_dimensions(img)
145
 
146
+
147
  def parse_and_resize_images(input_images, width: int, height: int):
 
 
 
 
148
  if input_images is None:
149
  return None
150
 
 
177
  ]
178
  return resized
179
 
180
+
181
+ def run_pipeline(pipe, kwargs, seed):
182
+ """Run a single pipeline — no locks needed, purely sequential."""
183
+ gen = torch.Generator(device="cpu").manual_seed(seed)
184
+ result = pipe(**kwargs, generator=gen).images[0]
185
  return result
186
 
187
+
188
  @spaces.GPU(duration=120)
189
  def infer(
190
  prompt,
 
206
  if randomize_seed:
207
  seed = random.randint(0, MAX_SEED)
208
 
209
+ # ── Derive dimensions from the first uploaded image if present ───────────
210
  image_list = None
211
  if input_images:
 
 
212
  item = (
213
  input_images[0][0]
214
  if isinstance(input_images[0], tuple)
 
224
  if first_pil is not None:
225
  width, height = calc_dimensions(first_pil)
226
 
 
227
  image_list = parse_and_resize_images(input_images, width, height)
228
 
229
+ # ensure dims are multiples of 8
230
  width = max(256, min(MAX_IMAGE_SIZE, round(int(width) / 8) * 8))
231
  height = max(256, min(MAX_IMAGE_SIZE, round(int(height) / 8) * 8))
232
 
 
240
  if image_list is not None:
241
  shared_kwargs["image"] = image_list
242
 
243
+ # ── Pipeline 1: Standard Decoder ─────────────────────────────────────────
244
+ progress(0.10, desc="Running Pipeline 1 / 2 — Standard Decoder...")
245
+ out_standard = run_pipeline(pipe_standard, shared_kwargs, seed)
246
 
247
+ gc.collect()
248
+ torch.cuda.empty_cache()
 
 
 
 
 
 
 
 
 
 
 
249
 
250
+ # ── Pipeline 2: Small Decoder ─────────────────────────────────────────────
251
+ progress(0.55, desc="Running Pipeline 2 / 2 — Small Decoder...")
252
+ out_small = run_pipeline(pipe_small_decoder, shared_kwargs, seed)
253
 
254
  gc.collect()
255
  torch.cuda.empty_cache()
256
 
257
+ progress(1.00, desc="✅ Both pipelines complete!")
258
+
259
  return out_standard, out_small, seed
260
 
261
 
 
325
  )
326
  gr.Markdown(
327
  "Compare **FLUX.2-klein-4B** side-by-side with "
328
+ "[small decoder](https://huggingface.co/black-forest-labs/FLUX.2-small-decoder). "
329
+ "Both pipelines run **one after the other** using the **same seed and latents** — "
330
+ "only the VAE decoder differs."
331
  )
332
 
333
  with gr.Row(equal_height=True):
 
355
  with gr.Row():
356
  with gr.Column():
357
  result_standard = gr.Image(
358
+ label="Standard Decoder (runs first)",
359
  show_label=True,
360
  interactive=False,
361
  format="png",
 
363
  )
364
  with gr.Column():
365
  result_small = gr.Image(
366
+ label="Small Decoder (runs second)",
367
  show_label=True,
368
  interactive=False,
369
  format="png",
 
372
 
373
  seed_output = gr.Number(label="Seed Used", precision=0, visible=False)
374
 
375
+ with gr.Accordion("Advanced Settings", open=False):
376
  seed = gr.Slider(
377
  label="Seed",
378
  minimum=0,
 
458
 
459
  if __name__ == "__main__":
460
  demo.queue(max_size=20).launch(
461
+ theme=orange_red_theme,
462
+ css=css,
463
  mcp_server=True,
464
  ssr_mode=False,
465
  show_error=True,