prithivMLmods committed on
Commit
62b9762
·
verified ·
1 Parent(s): cef7b4f

update app

Browse files
Files changed (1) hide show
  1. app.py +130 -57
app.py CHANGED
@@ -83,8 +83,8 @@ class OrangeRedTheme(Soft):
83
 
84
  orange_red_theme = OrangeRedTheme()
85
 
86
- dtype = torch.bfloat16
87
- device = "cuda" if torch.cuda.is_available() else "cpu"
88
 
89
  MAX_SEED = np.iinfo(np.int32).max
90
  MAX_IMAGE_SIZE = 1024
@@ -114,78 +114,99 @@ pipe_small_decoder.enable_model_cpu_offload()
114
  pipe_lock_standard = threading.Lock()
115
  pipe_lock_small = threading.Lock()
116
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  def update_dimensions_from_image(image_list):
118
- if image_list is None or len(image_list) == 0:
 
 
 
 
119
  return 1024, 1024
120
 
 
121
  item = image_list[0]
122
  img = item[0] if isinstance(item, tuple) else item
123
 
124
  if isinstance(img, str):
125
  img = Image.open(img).convert("RGB")
 
 
126
 
127
- iw, ih = img.size
128
- aspect_ratio = iw / ih
129
 
130
- if aspect_ratio >= 1:
131
- new_width = 1024
132
- new_height = int(1024 / aspect_ratio)
133
- else:
134
- new_height = 1024
135
- new_width = int(1024 * aspect_ratio)
136
-
137
- new_width = max(256, min(1024, round(new_width / 8) * 8))
138
- new_height = max(256, min(1024, round(new_height / 8) * 8))
139
- return new_width, new_height
140
-
141
- def get_example_items():
142
- example_prompts = {
143
- "1.jpg": "Change the weather to stormy.",
144
- "2.jpg": "Transform the scene into a snowy winter day while preserving the original subject identity, framing, and composition.",
145
- "3.jpg": "Relight the image with soft golden sunset lighting while keeping all structures and subject details consistent.",
146
- "4.jpg": "Make the texture high-resolution.",
147
- }
148
- items = []
149
- if EXAMPLES_DIR.exists():
150
- for name in sorted(os.listdir(EXAMPLES_DIR)):
151
- if name.lower().endswith((".png", ".jpg", ".jpeg", ".webp")):
152
- items.append({
153
- "file": name,
154
- "path": str(EXAMPLES_DIR / name),
155
- "prompt": example_prompts.get(
156
- name, "Edit this image while preserving composition."
157
- ),
158
- })
159
- return items
160
 
161
- def parse_input_images(input_images):
162
- """Safely parse gallery / filepath / PIL inputs → list[PIL.Image] or None."""
 
 
 
 
163
  if input_images is None:
164
  return None
 
 
 
165
  if isinstance(input_images, str):
166
- return [Image.open(input_images).convert("RGB")] if os.path.exists(input_images) else None
167
- if isinstance(input_images, list) and len(input_images) > 0:
168
- parsed = []
 
 
169
  for item in input_images:
170
  try:
171
  src = item[0] if isinstance(item, tuple) else item
172
  if isinstance(src, str):
173
- parsed.append(Image.open(src).convert("RGB"))
174
  elif isinstance(src, Image.Image):
175
- parsed.append(src.convert("RGB"))
176
  elif hasattr(src, "name"):
177
- parsed.append(Image.open(src.name).convert("RGB"))
178
  except Exception as e:
179
  print(f"Skipping invalid image: {e}")
180
- return parsed or None
181
- return None
182
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  def run_pipeline(pipe, lock, kwargs, seed):
184
  with lock:
185
  gen = torch.Generator(device="cpu").manual_seed(seed)
186
  result = pipe(**kwargs, generator=gen).images[0]
187
  return result
188
 
 
 
189
  @spaces.GPU(duration=120)
190
  def infer(
191
  prompt,
@@ -201,13 +222,38 @@ def infer(
201
  gc.collect()
202
  torch.cuda.empty_cache()
203
 
204
- if not prompt or prompt.strip() == "":
205
  raise gr.Error("Please enter a prompt.")
206
 
207
  if randomize_seed:
208
  seed = random.randint(0, MAX_SEED)
209
 
210
- image_list = parse_input_images(input_images)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
 
212
  shared_kwargs = dict(
213
  prompt=prompt,
@@ -222,8 +268,12 @@ def infer(
222
  progress(0.05, desc="⚡ Launching both pipelines simultaneously...")
223
 
224
  with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
225
- future_std = executor.submit(run_pipeline, pipe_standard, pipe_lock_standard, shared_kwargs, seed)
226
- future_small = executor.submit(run_pipeline, pipe_small_decoder, pipe_lock_small, shared_kwargs, seed)
 
 
 
 
227
  concurrent.futures.wait(
228
  [future_std, future_small],
229
  return_when=concurrent.futures.ALL_COMPLETED,
@@ -254,6 +304,27 @@ def infer_example(prompt):
254
  )
255
  return out_std, out_small, seed_used
256
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
  EXAMPLE_ITEMS = get_example_items()
258
 
259
  css = """
@@ -284,11 +355,13 @@ with gr.Blocks() as demo:
284
  elem_id="main-title",
285
  )
286
  gr.Markdown(
287
- "Compare **FLUX.2-klein-4B** side-by-side with [samll decoder](https://huggingface.co/black-forest-labs/FLUX.2-small-decoder)."
 
288
  )
289
 
290
  with gr.Row(equal_height=True):
291
 
 
292
  with gr.Column():
293
  input_images = gr.Gallery(
294
  label="Input Images",
@@ -304,9 +377,10 @@ with gr.Blocks() as demo:
304
  show_label=True,
305
  placeholder="e.g., A black cat holding a sign that says hello world...",
306
  )
307
-
308
  run_button = gr.Button("Run Comparison", variant="primary")
309
 
 
310
  with gr.Column():
311
  with gr.Row():
312
  with gr.Column():
@@ -317,7 +391,6 @@ with gr.Blocks() as demo:
317
  format="png",
318
  height=250,
319
  )
320
-
321
  with gr.Column():
322
  result_small = gr.Image(
323
  label="Small Decoder",
@@ -329,7 +402,7 @@ with gr.Blocks() as demo:
329
 
330
  seed_output = gr.Number(label="Seed Used", precision=0, visible=False)
331
 
332
- with gr.Accordion("Advanced Settings", open=False, visible=False):
333
  seed = gr.Slider(
334
  label="Seed",
335
  minimum=0,
@@ -390,7 +463,8 @@ with gr.Blocks() as demo:
390
  "[*](https://huggingface.co/black-forest-labs/FLUX.2-klein-4B) "
391
  "Experimental Space — FLUX.2 [klein] 4B VAE Decoder Comparison."
392
  )
393
-
 
394
  input_images.upload(
395
  fn=update_dimensions_from_image,
396
  inputs=[input_images],
@@ -415,9 +489,8 @@ with gr.Blocks() as demo:
415
 
416
  if __name__ == "__main__":
417
  demo.queue(max_size=20).launch(
418
- theme=orange_red_theme,
419
- mcp_server=True,
420
- css=css,
421
  ssr_mode=False,
422
  show_error=True,
423
  )
 
83
 
84
  orange_red_theme = OrangeRedTheme()
85
 
86
+ dtype = torch.bfloat16
87
+ device = "cuda" if torch.cuda.is_available() else "cpu"
88
 
89
  MAX_SEED = np.iinfo(np.int32).max
90
  MAX_IMAGE_SIZE = 1024
 
114
  pipe_lock_standard = threading.Lock()
115
  pipe_lock_small = threading.Lock()
116
 
117
+
118
+ # ── dimension helper ────────────────────────────────────────────────────────
def calc_dimensions(
    pil_img: "Image.Image",
    *,
    max_side: int = 1024,
    min_side: int = 256,
    multiple: int = 8,
):
    """Return a ``(width, height)`` pair for the pipeline based on an image's aspect ratio.

    The longer side is set to ``max_side`` and the other side is scaled by the
    aspect ratio. Both values are then rounded to the nearest ``multiple`` and
    clamped to ``[min_side, max_side]``. Because of the clamp, very elongated
    images do not keep their exact aspect ratio.

    Args:
        pil_img: Any object exposing a ``.size`` attribute of ``(width, height)``.
        max_side: Length of the longer side and upper clamp for both sides.
        min_side: Lower clamp for both sides.
        multiple: Grid the result is snapped to. ``max_side`` and ``min_side``
            are expected to be multiples of it so the clamp keeps the snap.

    Returns:
        ``(width, height)`` as integers. A degenerate image (zero width or
        height) yields ``(max_side, max_side)`` instead of raising
        ``ZeroDivisionError``.
    """

    def snap(value):
        # Round to the nearest grid step first, then clamp into the allowed range.
        return max(min_side, min(max_side, round(value / multiple) * multiple))

    iw, ih = pil_img.size
    if iw <= 0 or ih <= 0:
        # No meaningful aspect ratio; fall back to a square at the maximum size.
        return snap(max_side), snap(max_side)

    aspect = iw / ih
    if aspect >= 1:  # landscape or square: width is the long side
        new_width, new_height = max_side, round(max_side / aspect)
    else:  # portrait: height is the long side
        new_width, new_height = round(max_side * aspect), max_side

    return snap(new_width), snap(new_height)
139
+
140
+
def update_dimensions_from_image(image_list):
    """Compute width/height values from the first image in a gallery value.

    Args:
        image_list: The gallery value. Items may be bare images/paths or
            tuples whose first element is the image/path.
            NOTE(review): the exact item type depends on how the Gallery
            component is configured elsewhere in the file — confirm.

    Returns:
        ``(width, height)`` from ``calc_dimensions`` for the first image, or
        ``(1024, 1024)`` when the list is empty, the first item is of an
        unsupported type, or the file cannot be read.
    """
    default_size = (1024, 1024)
    if not image_list:
        return default_size

    first = image_list[0]
    img = first[0] if isinstance(first, tuple) else first

    if isinstance(img, str):
        # Only ``.size`` is needed, so avoid decoding/converting the whole
        # image, and close the file handle as soon as the size is known.
        try:
            with Image.open(img) as opened:
                return calc_dimensions(opened)
        except OSError as err:
            # An unreadable upload should not break the UI callback.
            print(f"Could not read uploaded image: {err}")
            return default_size

    if isinstance(img, Image.Image):
        return calc_dimensions(img)

    return default_size
 
159
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
 
161
+ # ── image parser ─────────────────────────────────────────────────────────────
def _load_rgb(src):
    """Return ``src`` as an RGB PIL image, or ``None`` if its type is unsupported."""
    if isinstance(src, str):
        path = src
    elif isinstance(src, Image.Image):
        return src.convert("RGB")
    elif hasattr(src, "name"):
        path = src.name
    else:
        return None

    # ``convert`` returns an independent image, so the file handle opened by
    # ``Image.open`` can be closed immediately instead of lingering until GC.
    with Image.open(path) as opened:
        return opened.convert("RGB")


def parse_and_resize_images(input_images, width: int, height: int):
    """Parse the gallery input and resize every image to ``(width, height)``.

    Args:
        input_images: ``None``, a single file path, a single PIL image, or a
            list whose items are paths, PIL images, objects with a ``.name``
            path attribute, or tuples whose first element is one of those.
        width: Target width in pixels (coerced with ``int``).
        height: Target height in pixels (coerced with ``int``).

    Returns:
        A list of RGB PIL images resized with LANCZOS resampling, or ``None``
        when there is no input or no item could be loaded. Items that fail to
        load are reported with ``print`` and skipped.
    """
    if input_images is None:
        return None

    if isinstance(input_images, list):
        candidates = input_images
    elif isinstance(input_images, (str, Image.Image)):
        candidates = [input_images]
    else:
        candidates = []

    raw_list = []
    for item in candidates:
        # Best-effort by design: one bad item must not abort the whole batch.
        try:
            rgb = _load_rgb(item[0] if isinstance(item, tuple) else item)
        except Exception as e:
            print(f"Skipping invalid image: {e}")
            continue
        if rgb is not None:
            raw_list.append(rgb)

    if not raw_list:
        return None

    # Every image is resized to the exact dimensions the pipeline is run with.
    # Integer coercion guards against float values (e.g. from UI sliders).
    target_size = (int(width), int(height))
    return [img.resize(target_size, Image.LANCZOS) for img in raw_list]
199
+
200
+
201
+ # ── pipeline runner ───────────────────────────────────────────────────────────
def run_pipeline(pipe, lock, kwargs, seed):
    """Invoke ``pipe`` while holding ``lock`` and return its first image.

    A fresh CPU ``torch.Generator`` seeded with ``seed`` is created for the
    call and passed as ``generator`` alongside ``kwargs``.
    """
    with lock:
        generator = torch.Generator(device="cpu")
        generator.manual_seed(seed)
        output = pipe(generator=generator, **kwargs)
        return output.images[0]
207
 
208
+
209
+ # ── main inference ────────────────────────────────────────────────────────────
210
  @spaces.GPU(duration=120)
211
  def infer(
212
  prompt,
 
222
  gc.collect()
223
  torch.cuda.empty_cache()
224
 
225
+ if not prompt or not prompt.strip():
226
  raise gr.Error("Please enter a prompt.")
227
 
228
  if randomize_seed:
229
  seed = random.randint(0, MAX_SEED)
230
 
231
+ # ── width / height: derive from the first uploaded image if present ──
232
+ image_list = None
233
+ if input_images:
234
+ # Re-derive dimensions from the actual first image so they are
235
+ # always consistent with what the pipeline will receive.
236
+ item = (
237
+ input_images[0][0]
238
+ if isinstance(input_images[0], tuple)
239
+ else input_images[0]
240
+ )
241
+ if isinstance(item, str):
242
+ first_pil = Image.open(item).convert("RGB")
243
+ elif isinstance(item, Image.Image):
244
+ first_pil = item.convert("RGB")
245
+ else:
246
+ first_pil = None
247
+
248
+ if first_pil is not None:
249
+ width, height = calc_dimensions(first_pil)
250
+
251
+ # parse + resize all images to the final (width, height)
252
+ image_list = parse_and_resize_images(input_images, width, height)
253
+
254
+ # ensure dims are multiples of 8 even for text-only runs
255
+ width = max(256, min(MAX_IMAGE_SIZE, round(int(width) / 8) * 8))
256
+ height = max(256, min(MAX_IMAGE_SIZE, round(int(height) / 8) * 8))
257
 
258
  shared_kwargs = dict(
259
  prompt=prompt,
 
268
  progress(0.05, desc="⚡ Launching both pipelines simultaneously...")
269
 
270
  with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
271
+ future_std = executor.submit(
272
+ run_pipeline, pipe_standard, pipe_lock_standard, shared_kwargs, seed
273
+ )
274
+ future_small = executor.submit(
275
+ run_pipeline, pipe_small_decoder, pipe_lock_small, shared_kwargs, seed
276
+ )
277
  concurrent.futures.wait(
278
  [future_std, future_small],
279
  return_when=concurrent.futures.ALL_COMPLETED,
 
304
  )
305
  return out_std, out_small, seed_used
306
 
307
+
def get_example_items():
    """Collect example images from ``EXAMPLES_DIR`` with their edit prompts.

    Returns:
        A list of dicts with keys ``"file"``, ``"path"`` and ``"prompt"``, one
        per image file (``.png``, ``.jpg``, ``.jpeg``, ``.webp``,
        case-insensitive) in name-sorted order. Files without a dedicated
        prompt get a generic one. The list is empty when the directory does
        not exist.
    """
    fallback_prompt = "Edit this image while preserving composition."
    prompts_by_file = {
        "1.jpg": "Change the weather to stormy.",
        "2.jpg": "Transform the scene into a snowy winter day while preserving the original subject identity, framing, and composition.",
        "3.jpg": "Relight the image with soft golden sunset lighting while keeping all structures and subject details consistent.",
        "4.jpg": "Make the texture high-resolution.",
    }
    image_suffixes = (".png", ".jpg", ".jpeg", ".webp")

    if not EXAMPLES_DIR.exists():
        return []

    return [
        {
            "file": fname,
            "path": str(EXAMPLES_DIR / fname),
            "prompt": prompts_by_file.get(fname, fallback_prompt),
        }
        for fname in sorted(os.listdir(EXAMPLES_DIR))
        if fname.lower().endswith(image_suffixes)
    ]
327
+
328
  EXAMPLE_ITEMS = get_example_items()
329
 
330
  css = """
 
355
  elem_id="main-title",
356
  )
357
  gr.Markdown(
358
+ "Compare **FLUX.2-klein-4B** side-by-side with "
359
+ "[small decoder](https://huggingface.co/black-forest-labs/FLUX.2-small-decoder)."
360
  )
361
 
362
  with gr.Row(equal_height=True):
363
 
364
+ # ── LEFT COLUMN: inputs ─────────────────────────────────────────
365
  with gr.Column():
366
  input_images = gr.Gallery(
367
  label="Input Images",
 
377
  show_label=True,
378
  placeholder="e.g., A black cat holding a sign that says hello world...",
379
  )
380
+
381
  run_button = gr.Button("Run Comparison", variant="primary")
382
 
383
+ # ── RIGHT COLUMN: outputs ───────────────────────────────────────
384
  with gr.Column():
385
  with gr.Row():
386
  with gr.Column():
 
391
  format="png",
392
  height=250,
393
  )
 
394
  with gr.Column():
395
  result_small = gr.Image(
396
  label="Small Decoder",
 
402
 
403
  seed_output = gr.Number(label="Seed Used", precision=0, visible=False)
404
 
405
+ with gr.Accordion("Advanced Settings", open=False):
406
  seed = gr.Slider(
407
  label="Seed",
408
  minimum=0,
 
463
  "[*](https://huggingface.co/black-forest-labs/FLUX.2-klein-4B) "
464
  "Experimental Space — FLUX.2 [klein] 4B VAE Decoder Comparison."
465
  )
466
+
467
+ # ── events ────────────────────────────────────────────────────────────────
468
  input_images.upload(
469
  fn=update_dimensions_from_image,
470
  inputs=[input_images],
 
489
 
if __name__ == "__main__":
    # Enable request queuing (at most 20 pending) before starting the app.
    queued_app = demo.queue(max_size=20)
    queued_app.launch(
        theme=orange_red_theme,
        css=css,
        mcp_server=True,
        ssr_mode=False,
        show_error=True,
    )