prithivMLmods commited on
Commit
ee702f3
·
verified ·
1 Parent(s): ccaf792

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -466
app.py DELETED
@@ -1,466 +0,0 @@
1
- import os
2
- import gc
3
- import gradio as gr
4
- import numpy as np
5
- import random
6
- import spaces
7
- import torch
8
- from diffusers import Flux2KleinPipeline, AutoencoderKLFlux2
9
- from PIL import Image
10
- from pathlib import Path
11
- from typing import Iterable
12
-
13
- from gradio.themes import Soft
14
- from gradio.themes.utils import colors, fonts, sizes
15
-
16
- colors.orange_red = colors.Color(
17
- name="orange_red",
18
- c50="#FFF0E5",
19
- c100="#FFE0CC",
20
- c200="#FFC299",
21
- c300="#FFA366",
22
- c400="#FF8533",
23
- c500="#FF4500",
24
- c600="#E63E00",
25
- c700="#CC3700",
26
- c800="#B33000",
27
- c900="#992900",
28
- c950="#802200",
29
- )
30
-
31
- class OrangeRedTheme(Soft):
32
- def __init__(
33
- self,
34
- *,
35
- primary_hue: colors.Color | str = colors.gray,
36
- secondary_hue: colors.Color | str = colors.orange_red,
37
- neutral_hue: colors.Color | str = colors.slate,
38
- text_size: sizes.Size | str = sizes.text_lg,
39
- font: fonts.Font | str | Iterable[fonts.Font | str] = (
40
- fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
41
- ),
42
- font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
43
- fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
44
- ),
45
- ):
46
- super().__init__(
47
- primary_hue=primary_hue,
48
- secondary_hue=secondary_hue,
49
- neutral_hue=neutral_hue,
50
- text_size=text_size,
51
- font=font,
52
- font_mono=font_mono,
53
- )
54
- super().set(
55
- background_fill_primary="*primary_50",
56
- background_fill_primary_dark="*primary_900",
57
- body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
58
- body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
59
- button_primary_text_color="white",
60
- button_primary_text_color_hover="white",
61
- button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
62
- button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
63
- button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_700)",
64
- button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_600)",
65
- button_secondary_text_color="black",
66
- button_secondary_text_color_hover="white",
67
- button_secondary_background_fill="linear-gradient(90deg, *primary_300, *primary_300)",
68
- button_secondary_background_fill_hover="linear-gradient(90deg, *primary_400, *primary_400)",
69
- button_secondary_background_fill_dark="linear-gradient(90deg, *primary_500, *primary_600)",
70
- button_secondary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)",
71
- slider_color="*secondary_500",
72
- slider_color_dark="*secondary_600",
73
- block_title_text_weight="600",
74
- block_border_width="3px",
75
- block_shadow="*shadow_drop_lg",
76
- button_primary_shadow="*shadow_drop_lg",
77
- button_large_padding="11px",
78
- color_accent_soft="*primary_100",
79
- block_label_background_fill="*primary_200",
80
- )
81
-
82
- orange_red_theme = OrangeRedTheme()
83
-
84
- dtype = torch.bfloat16
85
- device = "cuda" if torch.cuda.is_available() else "cpu"
86
-
87
- MAX_SEED = np.iinfo(np.int32).max
88
- MAX_IMAGE_SIZE = 1024
89
- EXAMPLES_DIR = Path("examples")
90
-
91
- # ── Load standard pipeline ──────────────────────────────────────────────────
92
- print("Loading 4B Distilled model (Standard VAE)...")
93
- pipe_standard = Flux2KleinPipeline.from_pretrained(
94
- "black-forest-labs/FLUX.2-klein-4B",
95
- torch_dtype=dtype,
96
- )
97
- pipe_standard.enable_model_cpu_offload()
98
-
99
- # ── Load small decoder VAE ───────────────────────────────────────────────────
100
- print("Loading Small Decoder VAE...")
101
- vae_small = AutoencoderKLFlux2.from_pretrained(
102
- "black-forest-labs/FLUX.2-small-decoder",
103
- torch_dtype=dtype,
104
- )
105
-
106
- # ── Load small-decoder pipeline ──────────────────────────────────────────────
107
- print("Loading 4B Distilled model (Small Decoder VAE)...")
108
- pipe_small_decoder = Flux2KleinPipeline.from_pretrained(
109
- "black-forest-labs/FLUX.2-klein-4B",
110
- vae=vae_small,
111
- torch_dtype=dtype,
112
- )
113
- pipe_small_decoder.enable_model_cpu_offload()
114
-
115
- # ────────────────────────────────────────────────────────────────────────────
116
- def calc_dimensions(pil_img: Image.Image):
117
- iw, ih = pil_img.size
118
- aspect = iw / ih
119
-
120
- if aspect >= 1:
121
- new_width = 1024
122
- new_height = int(round(1024 / aspect))
123
- else:
124
- new_height = 1024
125
- new_width = int(round(1024 * aspect))
126
-
127
- new_width = max(256, min(1024, round(new_width / 8) * 8))
128
- new_height = max(256, min(1024, round(new_height / 8) * 8))
129
- return new_width, new_height
130
-
131
-
132
- def update_dimensions_from_image(image_list):
133
- if not image_list:
134
- return 1024, 1024
135
-
136
- item = image_list[0]
137
- img = item[0] if isinstance(item, tuple) else item
138
-
139
- if isinstance(img, str):
140
- img = Image.open(img).convert("RGB")
141
- elif not isinstance(img, Image.Image):
142
- return 1024, 1024
143
-
144
- return calc_dimensions(img)
145
-
146
-
147
- def parse_and_resize_images(input_images, width: int, height: int):
148
- if input_images is None:
149
- return None
150
-
151
- raw_list = []
152
-
153
- if isinstance(input_images, str):
154
- if os.path.exists(input_images):
155
- raw_list = [Image.open(input_images).convert("RGB")]
156
- elif isinstance(input_images, Image.Image):
157
- raw_list = [input_images.convert("RGB")]
158
- elif isinstance(input_images, list):
159
- for item in input_images:
160
- try:
161
- src = item[0] if isinstance(item, tuple) else item
162
- if isinstance(src, str):
163
- raw_list.append(Image.open(src).convert("RGB"))
164
- elif isinstance(src, Image.Image):
165
- raw_list.append(src.convert("RGB"))
166
- elif hasattr(src, "name"):
167
- raw_list.append(Image.open(src.name).convert("RGB"))
168
- except Exception as e:
169
- print(f"Skipping invalid image: {e}")
170
-
171
- if not raw_list:
172
- return None
173
-
174
- resized = [
175
- img.resize((width, height), Image.LANCZOS)
176
- for img in raw_list
177
- ]
178
- return resized
179
-
180
-
181
- def run_pipeline(pipe, kwargs, seed):
182
- """Run a single pipeline — no locks needed, purely sequential."""
183
- gen = torch.Generator(device="cpu").manual_seed(seed)
184
- result = pipe(**kwargs, generator=gen).images[0]
185
- return result
186
-
187
-
188
- @spaces.GPU(duration=120)
189
- def infer(
190
- prompt,
191
- input_images=None,
192
- seed=42,
193
- randomize_seed=False,
194
- width=1024,
195
- height=1024,
196
- num_inference_steps=4,
197
- guidance_scale=1.0,
198
- progress=gr.Progress(track_tqdm=True),
199
- ):
200
- gc.collect()
201
- torch.cuda.empty_cache()
202
-
203
- if not prompt or not prompt.strip():
204
- raise gr.Error("Please enter a prompt.")
205
-
206
- if randomize_seed:
207
- seed = random.randint(0, MAX_SEED)
208
-
209
- # ── Derive dimensions from the first uploaded image if present ───────────
210
- image_list = None
211
- if input_images:
212
- item = (
213
- input_images[0][0]
214
- if isinstance(input_images[0], tuple)
215
- else input_images[0]
216
- )
217
- if isinstance(item, str):
218
- first_pil = Image.open(item).convert("RGB")
219
- elif isinstance(item, Image.Image):
220
- first_pil = item.convert("RGB")
221
- else:
222
- first_pil = None
223
-
224
- if first_pil is not None:
225
- width, height = calc_dimensions(first_pil)
226
-
227
- image_list = parse_and_resize_images(input_images, width, height)
228
-
229
- # ensure dims are multiples of 8
230
- width = max(256, min(MAX_IMAGE_SIZE, round(int(width) / 8) * 8))
231
- height = max(256, min(MAX_IMAGE_SIZE, round(int(height) / 8) * 8))
232
-
233
- shared_kwargs = dict(
234
- prompt=prompt,
235
- height=height,
236
- width=width,
237
- num_inference_steps=num_inference_steps,
238
- guidance_scale=guidance_scale,
239
- )
240
- if image_list is not None:
241
- shared_kwargs["image"] = image_list
242
-
243
- # ── Pipeline 1: Standard Decoder ─────────────────────────────────────────
244
- progress(0.10, desc="Running Pipeline 1 / 2 — Standard Decoder...")
245
- out_standard = run_pipeline(pipe_standard, shared_kwargs, seed)
246
-
247
- gc.collect()
248
- torch.cuda.empty_cache()
249
-
250
- # ── Pipeline 2: Small Decoder ─────────────────────────────────────────────
251
- progress(0.55, desc="Running Pipeline 2 / 2 — Small Decoder...")
252
- out_small = run_pipeline(pipe_small_decoder, shared_kwargs, seed)
253
-
254
- gc.collect()
255
- torch.cuda.empty_cache()
256
-
257
- progress(1.00, desc="✅ Both pipelines complete!")
258
-
259
- return out_standard, out_small, seed
260
-
261
-
262
- @spaces.GPU(duration=120)
263
- def infer_example(prompt):
264
- out_std, out_small, seed_used = infer(
265
- prompt=prompt,
266
- input_images=None,
267
- seed=0,
268
- randomize_seed=True,
269
- width=1024,
270
- height=1024,
271
- num_inference_steps=4,
272
- guidance_scale=1.0,
273
- )
274
- return out_std, out_small, seed_used
275
-
276
-
277
- def get_example_items():
278
- example_prompts = {
279
- "1.jpg": "Change the weather to stormy.",
280
- "2.jpg": "Transform the scene into a snowy winter day while preserving the original subject identity, framing, and composition.",
281
- "3.jpg": "Relight the image with soft golden sunset lighting while keeping all structures and subject details consistent.",
282
- "4.jpg": "Make the texture high-resolution.",
283
- }
284
- items = []
285
- if EXAMPLES_DIR.exists():
286
- for name in sorted(os.listdir(EXAMPLES_DIR)):
287
- if name.lower().endswith((".png", ".jpg", ".jpeg", ".webp")):
288
- items.append({
289
- "file": name,
290
- "path": str(EXAMPLES_DIR / name),
291
- "prompt": example_prompts.get(
292
- name, "Edit this image while preserving composition."
293
- ),
294
- })
295
- return items
296
-
297
- EXAMPLE_ITEMS = get_example_items()
298
-
299
- css = """
300
- #col-container {
301
- margin: 0 auto;
302
- max-width: 1100px;
303
- }
304
- #main-title h1 {
305
- font-size: 2.4em !important;
306
- }
307
- .vae-badge {
308
- font-weight: 700;
309
- font-size: 0.95em;
310
- text-align: center;
311
- padding: 4px 16px;
312
- border-radius: 20px;
313
- display: block;
314
- margin-bottom: 6px;
315
- }
316
- """
317
-
318
- with gr.Blocks() as demo:
319
-
320
- with gr.Column(elem_id="col-container"):
321
-
322
- gr.Markdown(
323
- "# **Flux.2-4B-Decoder-Comparator**",
324
- elem_id="main-title",
325
- )
326
- gr.Markdown(
327
- "Compare **FLUX.2-klein-4B** side-by-side with "
328
- "[small decoder](https://huggingface.co/black-forest-labs/FLUX.2-small-decoder). "
329
- "Both pipelines run **one after the other** using the **same seed and latents** — "
330
- "only the VAE decoder differs."
331
- )
332
-
333
- with gr.Row(equal_height=True):
334
-
335
- with gr.Column():
336
- input_images = gr.Gallery(
337
- label="Input Images",
338
- type="pil",
339
- columns=2,
340
- rows=1,
341
- height=300,
342
- allow_preview=True,
343
- )
344
-
345
- prompt = gr.Text(
346
- label="Prompt",
347
- max_lines=1,
348
- show_label=True,
349
- placeholder="e.g., A black cat holding a sign that says hello world...",
350
- )
351
-
352
- run_button = gr.Button("Run Comparison", variant="primary")
353
-
354
- with gr.Column():
355
- with gr.Row():
356
- with gr.Column():
357
- result_standard = gr.Image(
358
- label="① Standard Decoder (runs first)",
359
- show_label=True,
360
- interactive=False,
361
- format="png",
362
- height=250,
363
- )
364
- with gr.Column():
365
- result_small = gr.Image(
366
- label="② Small Decoder (runs second)",
367
- show_label=True,
368
- interactive=False,
369
- format="png",
370
- height=250,
371
- )
372
-
373
- seed_output = gr.Number(label="Seed Used", precision=0, visible=False)
374
-
375
- with gr.Accordion("Advanced Settings", open=False):
376
- seed = gr.Slider(
377
- label="Seed",
378
- minimum=0,
379
- maximum=MAX_SEED,
380
- step=1,
381
- value=0,
382
- )
383
- randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
384
-
385
- with gr.Row():
386
- width = gr.Slider(
387
- label="Width",
388
- minimum=256,
389
- maximum=MAX_IMAGE_SIZE,
390
- step=8,
391
- value=1024,
392
- )
393
- height_slider = gr.Slider(
394
- label="Height",
395
- minimum=256,
396
- maximum=MAX_IMAGE_SIZE,
397
- step=8,
398
- value=1024,
399
- )
400
-
401
- with gr.Row():
402
- num_inference_steps = gr.Slider(
403
- label="Inference Steps",
404
- minimum=1,
405
- maximum=20,
406
- step=1,
407
- value=4,
408
- )
409
- guidance_scale = gr.Slider(
410
- label="Guidance Scale",
411
- minimum=0.0,
412
- maximum=10.0,
413
- step=0.1,
414
- value=1.0,
415
- )
416
-
417
- gr.Examples(
418
- examples=[
419
- [["examples/I1.jpg", "examples/I2.jpg"], "Make her wear these glasses in Image 2."],
420
- [["examples/1.jpg"], "Change the weather to stormy."],
421
- [["examples/2.jpg"], "Transform the scene into a snowy winter day while preserving the original subject identity, framing, and composition."],
422
- [["examples/3.jpg"], "Relight the image with soft golden sunset lighting while keeping all structures and subject details consistent."],
423
- [["examples/4.jpg"], "Make the texture high-resolution."],
424
- ],
425
- inputs=[input_images, prompt],
426
- outputs=[result_standard, result_small, seed_output],
427
- fn=infer_example,
428
- cache_examples=False,
429
- label="Examples",
430
- )
431
-
432
- gr.Markdown(
433
- "[*](https://huggingface.co/black-forest-labs/FLUX.2-klein-4B) "
434
- "Experimental Space — FLUX.2 [klein] 4B VAE Decoder Comparison."
435
- )
436
-
437
- input_images.upload(
438
- fn=update_dimensions_from_image,
439
- inputs=[input_images],
440
- outputs=[width, height_slider],
441
- )
442
-
443
- gr.on(
444
- triggers=[run_button.click, prompt.submit],
445
- fn=infer,
446
- inputs=[
447
- prompt,
448
- input_images,
449
- seed,
450
- randomize_seed,
451
- width,
452
- height_slider,
453
- num_inference_steps,
454
- guidance_scale,
455
- ],
456
- outputs=[result_standard, result_small, seed_output],
457
- )
458
-
459
- if __name__ == "__main__":
460
- demo.queue(max_size=20).launch(
461
- theme=orange_red_theme,
462
- css=css,
463
- mcp_server=True,
464
- ssr_mode=False,
465
- show_error=True,
466
- )