prithivMLmods commited on
Commit
02534d9
·
verified ·
1 Parent(s): ee702f3

update app

Browse files
Files changed (1) hide show
  1. app.py +485 -0
app.py ADDED
@@ -0,0 +1,485 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gc
3
+ import gradio as gr
4
+ import numpy as np
5
+ import random
6
+ import spaces
7
+ import torch
8
+ from diffusers import Flux2KleinPipeline, AutoencoderKLFlux2
9
+ from PIL import Image
10
+ from pathlib import Path
11
+ import concurrent.futures
12
+ import threading
13
+ from typing import Iterable
14
+
15
+ from gradio.themes import Soft
16
+ from gradio.themes.utils import colors, fonts, sizes
17
+
18
+ colors.orange_red = colors.Color(
19
+ name="orange_red",
20
+ c50="#FFF0E5",
21
+ c100="#FFE0CC",
22
+ c200="#FFC299",
23
+ c300="#FFA366",
24
+ c400="#FF8533",
25
+ c500="#FF4500",
26
+ c600="#E63E00",
27
+ c700="#CC3700",
28
+ c800="#B33000",
29
+ c900="#992900",
30
+ c950="#802200",
31
+ )
32
+
33
+ class OrangeRedTheme(Soft):
34
+ def __init__(
35
+ self,
36
+ *,
37
+ primary_hue: colors.Color | str = colors.gray,
38
+ secondary_hue: colors.Color | str = colors.orange_red,
39
+ neutral_hue: colors.Color | str = colors.slate,
40
+ text_size: sizes.Size | str = sizes.text_lg,
41
+ font: fonts.Font | str | Iterable[fonts.Font | str] = (
42
+ fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
43
+ ),
44
+ font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
45
+ fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
46
+ ),
47
+ ):
48
+ super().__init__(
49
+ primary_hue=primary_hue,
50
+ secondary_hue=secondary_hue,
51
+ neutral_hue=neutral_hue,
52
+ text_size=text_size,
53
+ font=font,
54
+ font_mono=font_mono,
55
+ )
56
+ super().set(
57
+ background_fill_primary="*primary_50",
58
+ background_fill_primary_dark="*primary_900",
59
+ body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
60
+ body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
61
+ button_primary_text_color="white",
62
+ button_primary_text_color_hover="white",
63
+ button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
64
+ button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
65
+ button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_700)",
66
+ button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_600)",
67
+ button_secondary_text_color="black",
68
+ button_secondary_text_color_hover="white",
69
+ button_secondary_background_fill="linear-gradient(90deg, *primary_300, *primary_300)",
70
+ button_secondary_background_fill_hover="linear-gradient(90deg, *primary_400, *primary_400)",
71
+ button_secondary_background_fill_dark="linear-gradient(90deg, *primary_500, *primary_600)",
72
+ button_secondary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)",
73
+ slider_color="*secondary_500",
74
+ slider_color_dark="*secondary_600",
75
+ block_title_text_weight="600",
76
+ block_border_width="3px",
77
+ block_shadow="*shadow_drop_lg",
78
+ button_primary_shadow="*shadow_drop_lg",
79
+ button_large_padding="11px",
80
+ color_accent_soft="*primary_100",
81
+ block_label_background_fill="*primary_200",
82
+ )
83
+
84
+ orange_red_theme = OrangeRedTheme()
85
+
86
+ dtype = torch.bfloat16
87
+ device = "cuda" if torch.cuda.is_available() else "cpu"
88
+
89
+ MAX_SEED = np.iinfo(np.int32).max
90
+ MAX_IMAGE_SIZE = 1024
91
+ EXAMPLES_DIR = Path("examples")
92
+
93
+ print("Loading 4B Distilled model (Standard VAE)...")
94
+ pipe_standard = Flux2KleinPipeline.from_pretrained(
95
+ "black-forest-labs/FLUX.2-klein-4B",
96
+ torch_dtype=dtype,
97
+ )
98
+ pipe_standard.enable_model_cpu_offload()
99
+
100
+ print("Loading Small Decoder VAE...")
101
+ vae_small = AutoencoderKLFlux2.from_pretrained(
102
+ "black-forest-labs/FLUX.2-small-decoder",
103
+ torch_dtype=dtype,
104
+ )
105
+
106
+ print("Loading 4B Distilled model (Small Decoder VAE)...")
107
+ pipe_small_decoder = Flux2KleinPipeline.from_pretrained(
108
+ "black-forest-labs/FLUX.2-klein-4B",
109
+ vae=vae_small,
110
+ torch_dtype=dtype,
111
+ )
112
+ pipe_small_decoder.enable_model_cpu_offload()
113
+
114
+ pipe_lock_standard = threading.Lock()
115
+ pipe_lock_small = threading.Lock()
116
+
117
+ def calc_dimensions(pil_img: Image.Image):
118
+ """
119
+ Given a PIL image return (width, height) snapped to multiples of 8,
120
+ fitting within 1024 px on the long side, min 256 px on each side.
121
+ Uses round() so we match the reference app exactly.
122
+ """
123
+ iw, ih = pil_img.size
124
+ aspect = iw / ih
125
+
126
+ if aspect >= 1: # landscape / square
127
+ new_width = 1024
128
+ new_height = int(round(1024 / aspect))
129
+ else: # portrait
130
+ new_height = 1024
131
+ new_width = int(round(1024 * aspect))
132
+
133
+ # snap to 8-pixel grid with round(), clamp to [256, 1024]
134
+ new_width = max(256, min(1024, round(new_width / 8) * 8))
135
+ new_height = max(256, min(1024, round(new_height / 8) * 8))
136
+ return new_width, new_height
137
+
138
+
139
+ def update_dimensions_from_image(image_list):
140
+ """
141
+ Called by the gallery .upload() event.
142
+ Returns updated slider values for width and height.
143
+ """
144
+ if not image_list:
145
+ return 1024, 1024
146
+
147
+ # gallery items arrive as PIL images when type="pil"
148
+ item = image_list[0]
149
+ img = item[0] if isinstance(item, tuple) else item
150
+
151
+ if isinstance(img, str):
152
+ img = Image.open(img).convert("RGB")
153
+ elif not isinstance(img, Image.Image):
154
+ return 1024, 1024
155
+
156
+ return calc_dimensions(img)
157
+
158
+ def parse_and_resize_images(input_images, width: int, height: int):
159
+ """
160
+ Parse the gallery input and resize every frame to (width, height).
161
+ Returns a list[PIL.Image] or None.
162
+ """
163
+ if input_images is None:
164
+ return None
165
+
166
+ raw_list = []
167
+
168
+ if isinstance(input_images, str):
169
+ if os.path.exists(input_images):
170
+ raw_list = [Image.open(input_images).convert("RGB")]
171
+ elif isinstance(input_images, Image.Image):
172
+ raw_list = [input_images.convert("RGB")]
173
+ elif isinstance(input_images, list):
174
+ for item in input_images:
175
+ try:
176
+ src = item[0] if isinstance(item, tuple) else item
177
+ if isinstance(src, str):
178
+ raw_list.append(Image.open(src).convert("RGB"))
179
+ elif isinstance(src, Image.Image):
180
+ raw_list.append(src.convert("RGB"))
181
+ elif hasattr(src, "name"):
182
+ raw_list.append(Image.open(src.name).convert("RGB"))
183
+ except Exception as e:
184
+ print(f"Skipping invalid image: {e}")
185
+
186
+ if not raw_list:
187
+ return None
188
+
189
+ resized = [
190
+ img.resize((width, height), Image.LANCZOS)
191
+ for img in raw_list
192
+ ]
193
+ return resized
194
+
195
+ def run_pipeline(pipe, lock, kwargs, seed):
196
+ with lock:
197
+ gen = torch.Generator(device="cpu").manual_seed(seed)
198
+ result = pipe(**kwargs, generator=gen).images[0]
199
+ return result
200
+
201
+ @spaces.GPU(duration=120)
202
+ def infer(
203
+ prompt,
204
+ input_images=None,
205
+ seed=42,
206
+ randomize_seed=False,
207
+ width=1024,
208
+ height=1024,
209
+ num_inference_steps=4,
210
+ guidance_scale=1.0,
211
+ progress=gr.Progress(track_tqdm=True),
212
+ ):
213
+ gc.collect()
214
+ torch.cuda.empty_cache()
215
+
216
+ if not prompt or not prompt.strip():
217
+ raise gr.Error("Please enter a prompt.")
218
+
219
+ if randomize_seed:
220
+ seed = random.randint(0, MAX_SEED)
221
+
222
+ # ── width / height: derive from the first uploaded image if present ──
223
+ image_list = None
224
+ if input_images:
225
+ # Re-derive dimensions from the actual first image so they are
226
+ # always consistent with what the pipeline will receive.
227
+ item = (
228
+ input_images[0][0]
229
+ if isinstance(input_images[0], tuple)
230
+ else input_images[0]
231
+ )
232
+ if isinstance(item, str):
233
+ first_pil = Image.open(item).convert("RGB")
234
+ elif isinstance(item, Image.Image):
235
+ first_pil = item.convert("RGB")
236
+ else:
237
+ first_pil = None
238
+
239
+ if first_pil is not None:
240
+ width, height = calc_dimensions(first_pil)
241
+
242
+ # parse + resize all images to the final (width, height)
243
+ image_list = parse_and_resize_images(input_images, width, height)
244
+
245
+ # ensure dims are multiples of 8 even for text-only runs
246
+ width = max(256, min(MAX_IMAGE_SIZE, round(int(width) / 8) * 8))
247
+ height = max(256, min(MAX_IMAGE_SIZE, round(int(height) / 8) * 8))
248
+
249
+ shared_kwargs = dict(
250
+ prompt=prompt,
251
+ height=height,
252
+ width=width,
253
+ num_inference_steps=num_inference_steps,
254
+ guidance_scale=guidance_scale,
255
+ )
256
+ if image_list is not None:
257
+ shared_kwargs["image"] = image_list
258
+
259
+ progress(0.30, desc="Launching both pipelines simultaneously...")
260
+
261
+ with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
262
+ future_std = executor.submit(
263
+ run_pipeline, pipe_standard, pipe_lock_standard, shared_kwargs, seed
264
+ )
265
+ future_small = executor.submit(
266
+ run_pipeline, pipe_small_decoder, pipe_lock_small, shared_kwargs, seed
267
+ )
268
+ concurrent.futures.wait(
269
+ [future_std, future_small],
270
+ return_when=concurrent.futures.ALL_COMPLETED,
271
+ )
272
+
273
+ progress(0.80, desc="✅ Both pipelines done!")
274
+
275
+ out_standard = future_std.result()
276
+ out_small = future_small.result()
277
+
278
+ gc.collect()
279
+ torch.cuda.empty_cache()
280
+
281
+ return out_standard, out_small, seed
282
+
283
+
284
+ @spaces.GPU(duration=120)
285
+ def infer_example(prompt):
286
+ out_std, out_small, seed_used = infer(
287
+ prompt=prompt,
288
+ input_images=None,
289
+ seed=0,
290
+ randomize_seed=True,
291
+ width=1024,
292
+ height=1024,
293
+ num_inference_steps=4,
294
+ guidance_scale=1.0,
295
+ )
296
+ return out_std, out_small, seed_used
297
+
298
+
299
+ def get_example_items():
300
+ example_prompts = {
301
+ "1.jpg": "Change the weather to stormy.",
302
+ "2.jpg": "Transform the scene into a snowy winter day while preserving the original subject identity, framing, and composition.",
303
+ "3.jpg": "Relight the image with soft golden sunset lighting while keeping all structures and subject details consistent.",
304
+ "4.jpg": "Make the texture high-resolution.",
305
+ }
306
+ items = []
307
+ if EXAMPLES_DIR.exists():
308
+ for name in sorted(os.listdir(EXAMPLES_DIR)):
309
+ if name.lower().endswith((".png", ".jpg", ".jpeg", ".webp")):
310
+ items.append({
311
+ "file": name,
312
+ "path": str(EXAMPLES_DIR / name),
313
+ "prompt": example_prompts.get(
314
+ name, "Edit this image while preserving composition."
315
+ ),
316
+ })
317
+ return items
318
+
319
+ EXAMPLE_ITEMS = get_example_items()
320
+
321
+ css = """
322
+ #col-container {
323
+ margin: 0 auto;
324
+ max-width: 1100px;
325
+ }
326
+ #main-title h1 {
327
+ font-size: 2.4em !important;
328
+ }
329
+ .vae-badge {
330
+ font-weight: 700;
331
+ font-size: 0.95em;
332
+ text-align: center;
333
+ padding: 4px 16px;
334
+ border-radius: 20px;
335
+ display: block;
336
+ margin-bottom: 6px;
337
+ }
338
+ """
339
+
340
+ with gr.Blocks() as demo:
341
+
342
+ with gr.Column(elem_id="col-container"):
343
+
344
+ gr.Markdown(
345
+ "# **Flux.2-4B-Decoder-Comparator**",
346
+ elem_id="main-title",
347
+ )
348
+ gr.Markdown(
349
+ "Compare **FLUX.2-klein-4B** side-by-side with "
350
+ "[small decoder](https://huggingface.co/black-forest-labs/FLUX.2-small-decoder)."
351
+ )
352
+
353
+ with gr.Row(equal_height=True):
354
+
355
+ with gr.Column():
356
+ input_images = gr.Gallery(
357
+ label="Input Images",
358
+ type="pil",
359
+ columns=2,
360
+ rows=1,
361
+ height=300,
362
+ allow_preview=True,
363
+ )
364
+
365
+ prompt = gr.Text(
366
+ label="Prompt",
367
+ max_lines=1,
368
+ show_label=True,
369
+ placeholder="e.g., A black cat holding a sign that says hello world...",
370
+ )
371
+
372
+ run_button = gr.Button("Run Comparison", variant="primary")
373
+
374
+ with gr.Column():
375
+ with gr.Row():
376
+ with gr.Column():
377
+ result_standard = gr.Image(
378
+ label="Standard Decoder",
379
+ show_label=True,
380
+ interactive=False,
381
+ format="png",
382
+ height=250,
383
+ )
384
+ with gr.Column():
385
+ result_small = gr.Image(
386
+ label="Small Decoder",
387
+ show_label=True,
388
+ interactive=False,
389
+ format="png",
390
+ height=250,
391
+ )
392
+
393
+ seed_output = gr.Number(label="Seed Used", precision=0, visible=False)
394
+
395
+ with gr.Accordion("Advanced Settings", open=False, visible=False):
396
+ seed = gr.Slider(
397
+ label="Seed",
398
+ minimum=0,
399
+ maximum=MAX_SEED,
400
+ step=1,
401
+ value=0,
402
+ )
403
+ randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
404
+
405
+ with gr.Row():
406
+ width = gr.Slider(
407
+ label="Width",
408
+ minimum=256,
409
+ maximum=MAX_IMAGE_SIZE,
410
+ step=8,
411
+ value=1024,
412
+ )
413
+ height_slider = gr.Slider(
414
+ label="Height",
415
+ minimum=256,
416
+ maximum=MAX_IMAGE_SIZE,
417
+ step=8,
418
+ value=1024,
419
+ )
420
+
421
+ with gr.Row():
422
+ num_inference_steps = gr.Slider(
423
+ label="Inference Steps",
424
+ minimum=1,
425
+ maximum=20,
426
+ step=1,
427
+ value=4,
428
+ )
429
+ guidance_scale = gr.Slider(
430
+ label="Guidance Scale",
431
+ minimum=0.0,
432
+ maximum=10.0,
433
+ step=0.1,
434
+ value=1.0,
435
+ )
436
+
437
+ gr.Examples(
438
+ examples=[
439
+ [["examples/I1.jpg", "examples/I2.jpg"], "Make her wear these glasses in Image 2."],
440
+ [["examples/1.jpg"], "Change the weather to stormy."],
441
+ [["examples/2.jpg"], "Transform the scene into a snowy winter day while preserving the original subject identity, framing, and composition."],
442
+ [["examples/3.jpg"], "Relight the image with soft golden sunset lighting while keeping all structures and subject details consistent."],
443
+ [["examples/4.jpg"], "Make the texture high-resolution."],
444
+ ],
445
+ inputs=[input_images, prompt],
446
+ outputs=[result_standard, result_small, seed_output],
447
+ fn=infer_example,
448
+ cache_examples=False,
449
+ label="Examples",
450
+ )
451
+
452
+ gr.Markdown(
453
+ "[*](https://huggingface.co/black-forest-labs/FLUX.2-klein-4B) "
454
+ "Experimental Space — FLUX.2 [klein] 4B VAE Decoder Comparison."
455
+ )
456
+
457
+ input_images.upload(
458
+ fn=update_dimensions_from_image,
459
+ inputs=[input_images],
460
+ outputs=[width, height_slider],
461
+ )
462
+
463
+ gr.on(
464
+ triggers=[run_button.click, prompt.submit],
465
+ fn=infer,
466
+ inputs=[
467
+ prompt,
468
+ input_images,
469
+ seed,
470
+ randomize_seed,
471
+ width,
472
+ height_slider,
473
+ num_inference_steps,
474
+ guidance_scale,
475
+ ],
476
+ outputs=[result_standard, result_small, seed_output],
477
+ )
478
+
479
+ if __name__ == "__main__":
480
+ demo.queue(max_size=20).launch(
481
+ theme=orange_red_theme, css=css,
482
+ mcp_server=True,
483
+ ssr_mode=False,
484
+ show_error=True,
485
+ )