prithivMLmods commited on
Commit
527aafd
·
verified ·
1 Parent(s): 401fd11

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +501 -0
app.py ADDED
@@ -0,0 +1,501 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gc
3
+ import gradio as gr
4
+ import numpy as np
5
+ import random
6
+ import spaces
7
+ import torch
8
+ from diffusers import Flux2KleinPipeline, AutoencoderKLFlux2
9
+ from PIL import Image
10
+ from pathlib import Path
11
+ import concurrent.futures
12
+ import threading
13
+ from typing import Iterable
14
+ from gradio.themes import Soft
15
+ from gradio.themes.utils import colors, fonts, sizes
16
+
17
+ # ── Theme ─────────────────────────────────────────────────────────────────────
18
+
19
+ colors.flux_teal = colors.Color(
20
+ name="flux_teal",
21
+ c50="#E6FAF8",
22
+ c100="#CCF5F1",
23
+ c200="#99EBE3",
24
+ c300="#66E1D5",
25
+ c400="#33D7C7",
26
+ c500="#00CDB9",
27
+ c600="#00A494",
28
+ c700="#007B6F",
29
+ c800="#00524A",
30
+ c900="#002925",
31
+ c950="#001412",
32
+ )
33
+
34
+
35
+ class FluxTheme(Soft):
36
+ def __init__(
37
+ self,
38
+ *,
39
+ primary_hue: colors.Color | str = colors.slate,
40
+ secondary_hue: colors.Color | str = colors.flux_teal,
41
+ neutral_hue: colors.Color | str = colors.slate,
42
+ text_size: sizes.Size | str = sizes.text_lg,
43
+ font: fonts.Font | str | Iterable[fonts.Font | str] = (
44
+ fonts.GoogleFont("Inter"),
45
+ "Arial",
46
+ "sans-serif",
47
+ ),
48
+ font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
49
+ fonts.GoogleFont("JetBrains Mono"),
50
+ "ui-monospace",
51
+ "monospace",
52
+ ),
53
+ ):
54
+ super().__init__(
55
+ primary_hue=primary_hue,
56
+ secondary_hue=secondary_hue,
57
+ neutral_hue=neutral_hue,
58
+ text_size=text_size,
59
+ font=font,
60
+ font_mono=font_mono,
61
+ )
62
+ super().set(
63
+ background_fill_primary="*primary_50",
64
+ background_fill_primary_dark="*primary_900",
65
+ body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
66
+ body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
67
+ button_primary_text_color="white",
68
+ button_primary_text_color_hover="white",
69
+ button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
70
+ button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
71
+ button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_700)",
72
+ button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_600)",
73
+ button_secondary_text_color="black",
74
+ button_secondary_text_color_hover="white",
75
+ button_secondary_background_fill="linear-gradient(90deg, *primary_300, *primary_300)",
76
+ button_secondary_background_fill_hover="linear-gradient(90deg, *primary_400, *primary_400)",
77
+ button_secondary_background_fill_dark="linear-gradient(90deg, *primary_500, *primary_600)",
78
+ button_secondary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)",
79
+ slider_color="*secondary_500",
80
+ slider_color_dark="*secondary_600",
81
+ block_title_text_weight="600",
82
+ block_border_width="3px",
83
+ block_shadow="*shadow_drop_lg",
84
+ button_primary_shadow="*shadow_drop_lg",
85
+ button_large_padding="11px",
86
+ color_accent_soft="*primary_100",
87
+ block_label_background_fill="*primary_200",
88
+ )
89
+
90
+
91
+ flux_theme = FluxTheme()
92
+
93
+ # ── Config ────────────────────────────────────────────────────────────────────
94
+
95
+ dtype = torch.bfloat16
96
+ device = "cuda" if torch.cuda.is_available() else "cpu"
97
+
98
+ MAX_SEED = np.iinfo(np.int32).max
99
+ MAX_IMAGE_SIZE = 1024
100
+ EXAMPLES_DIR = Path("examples")
101
+
102
+ # ── Models ────────────────────────────────────────────────────────────────────
103
+
104
+ print("Loading 4B Distilled model (Standard VAE)...")
105
+ pipe_standard = Flux2KleinPipeline.from_pretrained(
106
+ "black-forest-labs/FLUX.2-klein-4B",
107
+ torch_dtype=dtype,
108
+ )
109
+ pipe_standard.enable_model_cpu_offload()
110
+
111
+ print("Loading Small Decoder VAE...")
112
+ vae_small = AutoencoderKLFlux2.from_pretrained(
113
+ "black-forest-labs/FLUX.2-small-decoder",
114
+ torch_dtype=dtype,
115
+ )
116
+
117
+ print("Loading 4B Distilled model (Small Decoder VAE)...")
118
+ pipe_small_decoder = Flux2KleinPipeline.from_pretrained(
119
+ "black-forest-labs/FLUX.2-klein-4B",
120
+ vae=vae_small,
121
+ torch_dtype=dtype,
122
+ )
123
+ pipe_small_decoder.enable_model_cpu_offload()
124
+
125
+ pipe_lock_standard = threading.Lock()
126
+ pipe_lock_small = threading.Lock()
127
+
128
+ # ── Helpers ──────────────────────────��────────────────────────────────────────
129
+
130
+ def update_dimensions_from_image(image_list):
131
+ if image_list is None or len(image_list) == 0:
132
+ return 1024, 1024
133
+
134
+ item = image_list[0]
135
+ img = item[0] if isinstance(item, tuple) else item
136
+
137
+ if isinstance(img, str):
138
+ img = Image.open(img).convert("RGB")
139
+
140
+ iw, ih = img.size
141
+ aspect_ratio = iw / ih
142
+
143
+ if aspect_ratio >= 1:
144
+ new_width = 1024
145
+ new_height = int(1024 / aspect_ratio)
146
+ else:
147
+ new_height = 1024
148
+ new_width = int(1024 * aspect_ratio)
149
+
150
+ new_width = max(256, min(1024, round(new_width / 8) * 8))
151
+ new_height = max(256, min(1024, round(new_height / 8) * 8))
152
+ return new_width, new_height
153
+
154
+
155
+ def get_example_items():
156
+ example_prompts = {
157
+ "1.jpg": "Change the weather to stormy.",
158
+ "2.jpg": "Transform the scene into a snowy winter day while preserving the original subject identity, framing, and composition.",
159
+ "3.jpg": "Relight the image with soft golden sunset lighting while keeping all structures and subject details consistent.",
160
+ "4.jpg": "Make the texture high-resolution.",
161
+ }
162
+ items = []
163
+ if EXAMPLES_DIR.exists():
164
+ for name in sorted(os.listdir(EXAMPLES_DIR)):
165
+ if name.lower().endswith((".png", ".jpg", ".jpeg", ".webp")):
166
+ items.append({
167
+ "file": name,
168
+ "path": str(EXAMPLES_DIR / name),
169
+ "prompt": example_prompts.get(
170
+ name, "Edit this image while preserving composition."
171
+ ),
172
+ })
173
+ return items
174
+
175
+
176
+ def parse_input_images(input_images):
177
+ """Safely parse gallery / filepath / PIL inputs → list[PIL.Image] or None."""
178
+ if input_images is None:
179
+ return None
180
+ if isinstance(input_images, str):
181
+ return [Image.open(input_images).convert("RGB")] if os.path.exists(input_images) else None
182
+ if isinstance(input_images, list) and len(input_images) > 0:
183
+ parsed = []
184
+ for item in input_images:
185
+ try:
186
+ src = item[0] if isinstance(item, tuple) else item
187
+ if isinstance(src, str):
188
+ parsed.append(Image.open(src).convert("RGB"))
189
+ elif isinstance(src, Image.Image):
190
+ parsed.append(src.convert("RGB"))
191
+ elif hasattr(src, "name"):
192
+ parsed.append(Image.open(src.name).convert("RGB"))
193
+ except Exception as e:
194
+ print(f"Skipping invalid image: {e}")
195
+ return parsed or None
196
+ return None
197
+
198
+
199
+ # ── Per-pipeline worker ───────────────────────────────────────────────────────
200
+
201
+ def run_pipeline(pipe, lock, kwargs, seed):
202
+ with lock:
203
+ gen = torch.Generator(device="cpu").manual_seed(seed)
204
+ result = pipe(**kwargs, generator=gen).images[0]
205
+ return result
206
+
207
+
208
+ # ── Inference ─────────────────────────────────────────────────────────────────
209
+
210
+ @spaces.GPU(duration=120)
211
+ def infer(
212
+ prompt,
213
+ input_images=None,
214
+ seed=42,
215
+ randomize_seed=False,
216
+ width=1024,
217
+ height=1024,
218
+ num_inference_steps=4,
219
+ guidance_scale=1.0,
220
+ progress=gr.Progress(track_tqdm=True),
221
+ ):
222
+ gc.collect()
223
+ torch.cuda.empty_cache()
224
+
225
+ if not prompt or prompt.strip() == "":
226
+ raise gr.Error("Please enter a prompt.")
227
+
228
+ if randomize_seed:
229
+ seed = random.randint(0, MAX_SEED)
230
+
231
+ image_list = parse_input_images(input_images)
232
+
233
+ shared_kwargs = dict(
234
+ prompt=prompt,
235
+ height=height,
236
+ width=width,
237
+ num_inference_steps=num_inference_steps,
238
+ guidance_scale=guidance_scale,
239
+ )
240
+ if image_list is not None:
241
+ shared_kwargs["image"] = image_list
242
+
243
+ progress(0.05, desc="⚡ Launching both pipelines simultaneously...")
244
+
245
+ with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
246
+ future_std = executor.submit(run_pipeline, pipe_standard, pipe_lock_standard, shared_kwargs, seed)
247
+ future_small = executor.submit(run_pipeline, pipe_small_decoder, pipe_lock_small, shared_kwargs, seed)
248
+ concurrent.futures.wait(
249
+ [future_std, future_small],
250
+ return_when=concurrent.futures.ALL_COMPLETED,
251
+ )
252
+
253
+ progress(0.95, desc="✅ Both pipelines done!")
254
+
255
+ out_standard = future_std.result()
256
+ out_small = future_small.result()
257
+
258
+ gc.collect()
259
+ torch.cuda.empty_cache()
260
+
261
+ return out_standard, out_small, seed
262
+
263
+
264
+ # Text-only wrapper used by gr.Examples (avoids Gallery type issues entirely)
265
+ @spaces.GPU(duration=120)
266
+ def infer_example(prompt):
267
+ out_std, out_small, seed_used = infer(
268
+ prompt=prompt,
269
+ input_images=None,
270
+ seed=0,
271
+ randomize_seed=True,
272
+ width=1024,
273
+ height=1024,
274
+ num_inference_steps=4,
275
+ guidance_scale=1.0,
276
+ )
277
+ return out_std, out_small, seed_used
278
+
279
+
280
+ # ── Load examples ─────────────────────────────────────────────────────────────
281
+
282
+ EXAMPLE_ITEMS = get_example_items()
283
+
284
+ # ── CSS ───────────────────────────────────────────────────────────────────────
285
+
286
+ css = """
287
+ #col-container {
288
+ margin: 0 auto;
289
+ max-width: 1100px;
290
+ }
291
+ #main-title h1 {
292
+ font-size: 2.4em !important;
293
+ }
294
+ .vae-badge {
295
+ font-weight: 700;
296
+ font-size: 0.95em;
297
+ text-align: center;
298
+ padding: 4px 16px;
299
+ border-radius: 20px;
300
+ display: block;
301
+ margin-bottom: 6px;
302
+ }
303
+ """
304
+
305
+ # ── UI ────────────────────────────────────────────────────────────────────────
306
+
307
+ with gr.Blocks() as demo:
308
+
309
+ with gr.Column(elem_id="col-container"):
310
+
311
+ gr.Markdown(
312
+ "# ⚡ **Flux.2-4B-Encoder-Comparator**",
313
+ elem_id="main-title",
314
+ )
315
+ gr.Markdown(
316
+ "Compare **FLUX.2-klein-4B** side-by-side with two VAE decoders — "
317
+ "generated **simultaneously** from the **same seed**. "
318
+ "🟦 **Standard VAE** vs 🟩 **Small Decoder VAE** "
319
+ "([FLUX.2-small-decoder](https://huggingface.co/black-forest-labs/FLUX.2-small-decoder)) · "
320
+ "[[model](https://huggingface.co/black-forest-labs/FLUX.2-klein-4B)] · "
321
+ "[[blog](https://bfl.ai/blog/flux-2)]"
322
+ )
323
+
324
+ # ── Main two-column row ───────────────────────────────────────────────
325
+ with gr.Row(equal_height=True):
326
+
327
+ # ── Left: inputs ─────────────────────────────────────────────────
328
+ with gr.Column():
329
+ input_images = gr.Gallery(
330
+ label="Input Image(s) for Editing (optional)",
331
+ type="pil",
332
+ columns=2,
333
+ rows=1,
334
+ height=280,
335
+ allow_preview=True,
336
+ )
337
+
338
+ prompt = gr.Text(
339
+ label="Prompt",
340
+ show_label=True,
341
+ placeholder="e.g., A black cat holding a sign that says hello world...",
342
+ )
343
+
344
+ run_button = gr.Button("⚡ Run Comparison", variant="primary")
345
+
346
+ # ── Right: outputs ────────────────────────────────────────────────
347
+ with gr.Column():
348
+ with gr.Row():
349
+ with gr.Column():
350
+ gr.HTML(
351
+ '<span class="vae-badge" '
352
+ 'style="background:#dbeafe;color:#1d4ed8;">'
353
+ '🟦 Standard VAE</span>'
354
+ )
355
+ result_standard = gr.Image(
356
+ label="Standard VAE",
357
+ show_label=False,
358
+ interactive=False,
359
+ format="png",
360
+ height=280,
361
+ )
362
+
363
+ with gr.Column():
364
+ gr.HTML(
365
+ '<span class="vae-badge" '
366
+ 'style="background:#d1fae5;color:#065f46;">'
367
+ '🟩 Small Decoder VAE</span>'
368
+ )
369
+ result_small = gr.Image(
370
+ label="Small Decoder VAE",
371
+ show_label=False,
372
+ interactive=False,
373
+ format="png",
374
+ height=280,
375
+ )
376
+
377
+ seed_output = gr.Number(label="Seed Used", precision=0)
378
+
379
+ with gr.Accordion("⚙️ Advanced Settings", open=False):
380
+ seed = gr.Slider(
381
+ label="Seed",
382
+ minimum=0,
383
+ maximum=MAX_SEED,
384
+ step=1,
385
+ value=0,
386
+ )
387
+ randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
388
+
389
+ with gr.Row():
390
+ width = gr.Slider(
391
+ label="Width",
392
+ minimum=256,
393
+ maximum=MAX_IMAGE_SIZE,
394
+ step=8,
395
+ value=1024,
396
+ )
397
+ height_slider = gr.Slider(
398
+ label="Height",
399
+ minimum=256,
400
+ maximum=MAX_IMAGE_SIZE,
401
+ step=8,
402
+ value=1024,
403
+ )
404
+
405
+ with gr.Row():
406
+ num_inference_steps = gr.Slider(
407
+ label="Inference Steps",
408
+ minimum=1,
409
+ maximum=20,
410
+ step=1,
411
+ value=4,
412
+ )
413
+ guidance_scale = gr.Slider(
414
+ label="Guidance Scale",
415
+ minimum=0.0,
416
+ maximum=10.0,
417
+ step=0.1,
418
+ value=1.0,
419
+ )
420
+
421
+ # ── Examples — prompt-only, no Gallery input ──────────────────────────
422
+ # Build rows: [prompt_str] (no image column → no Gallery postprocess)
423
+ if EXAMPLE_ITEMS:
424
+ prompt_only_rows = [[item["prompt"]] for item in EXAMPLE_ITEMS]
425
+
426
+ gr.Examples(
427
+ examples=prompt_only_rows,
428
+ inputs=[prompt], # ← only Text, never Gallery
429
+ outputs=[result_standard, result_small, seed_output],
430
+ fn=infer_example, # ← wrapper with no image arg
431
+ cache_examples=False,
432
+ label="Examples",
433
+ )
434
+
435
+ # ── Visual image cards (click to load image + prompt) ─────────────────
436
+ if EXAMPLE_ITEMS:
437
+ gr.Markdown("#### 🖼️ Image Editing Examples — click to load")
438
+ with gr.Row():
439
+ for item in EXAMPLE_ITEMS:
440
+ with gr.Column(scale=1, min_width=180):
441
+ gr.Image(
442
+ value=item["path"],
443
+ show_label=False,
444
+ interactive=False,
445
+ height=150,
446
+ )
447
+ card_btn = gr.Button(
448
+ (item["prompt"][:48] + "…")
449
+ if len(item["prompt"]) > 48
450
+ else item["prompt"],
451
+ size="sm",
452
+ )
453
+
454
+ def _make_loader(p, path):
455
+ def _load():
456
+ pil = Image.open(path).convert("RGB")
457
+ return p, [(pil, None)]
458
+ return _load
459
+
460
+ card_btn.click(
461
+ fn=_make_loader(item["prompt"], item["path"]),
462
+ inputs=[],
463
+ outputs=[prompt, input_images],
464
+ )
465
+
466
+ gr.Markdown(
467
+ "[*](https://huggingface.co/black-forest-labs/FLUX.2-klein-4B) "
468
+ "Experimental Space — FLUX.2 [klein] 4B VAE Decoder Comparison."
469
+ )
470
+
471
+ # ── Events ────────────────────────────────────────────────────────────────
472
+
473
+ input_images.upload(
474
+ fn=update_dimensions_from_image,
475
+ inputs=[input_images],
476
+ outputs=[width, height_slider],
477
+ )
478
+
479
+ gr.on(
480
+ triggers=[run_button.click, prompt.submit],
481
+ fn=infer,
482
+ inputs=[
483
+ prompt,
484
+ input_images,
485
+ seed,
486
+ randomize_seed,
487
+ width,
488
+ height_slider,
489
+ num_inference_steps,
490
+ guidance_scale,
491
+ ],
492
+ outputs=[result_standard, result_small, seed_output],
493
+ )
494
+
495
+ if __name__ == "__main__":
496
+ demo.queue(max_size=20).launch(
497
+ css=css,
498
+ theme=flux_theme,
499
+ ssr_mode=False,
500
+ show_error=True,
501
+ )