prithivMLmods commited on
Commit
fa1b44c
·
verified ·
1 Parent(s): 166b9fb

update app

Browse files
Files changed (1) hide show
  1. app.py +814 -0
app.py ADDED
@@ -0,0 +1,814 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import subprocess
4
+ import tempfile
5
+ from typing import Iterable
6
+
7
+ import torch
8
+ import numpy as np
9
+ import gradio as gr
10
+ from PIL import Image
11
+ from types import SimpleNamespace
12
+ from huggingface_hub import snapshot_download
13
+
14
+ import spaces
15
+
16
+ from gradio.themes import Soft
17
+ from gradio.themes.utils import colors, fonts, sizes
18
+
19
# Custom "orange_red" palette, attached to gradio's colors module so the theme
# class below can use it as a default hue (`colors.orange_red`).
colors.orange_red = colors.Color(
    name="orange_red",
    c50="#FFF0E5",
    c100="#FFE0CC",
    c200="#FFC299",
    c300="#FFA366",
    c400="#FF8533",
    c500="#FF4500",
    c600="#E63E00",
    c700="#CC3700",
    c800="#B33000",
    c900="#992900",
    c950="#802200",
)
24
+
25
class OrangeRedTheme(Soft):
    """Gradio theme derived from `Soft` with gray surfaces and orange-red accents.

    All constructor arguments are keyword-only and are forwarded unchanged to
    `Soft.__init__`; a fixed set of style overrides is then applied via `set`.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.gray,
        secondary_hue: colors.Color | str = colors.orange_red,
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Outfit"),
            "Arial",
            "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )

        # Style overrides, grouped by the part of the UI they affect.
        page_fill = {
            "background_fill_primary": "*primary_50",
            "background_fill_primary_dark": "*primary_900",
            "body_background_fill": "linear-gradient(135deg, *primary_200, *primary_100)",
            "body_background_fill_dark": "linear-gradient(135deg, *primary_900, *primary_800)",
        }
        primary_button = {
            "button_primary_text_color": "white",
            "button_primary_text_color_hover": "white",
            "button_primary_background_fill": "linear-gradient(90deg, *secondary_500, *secondary_600)",
            "button_primary_background_fill_hover": "linear-gradient(90deg, *secondary_600, *secondary_700)",
            "button_primary_background_fill_dark": "linear-gradient(90deg, *secondary_600, *secondary_700)",
            "button_primary_background_fill_hover_dark": "linear-gradient(90deg, *secondary_500, *secondary_600)",
            "button_primary_shadow": "*shadow_drop_lg",
            "button_large_padding": "11px",
        }
        sliders_and_blocks = {
            "slider_color": "*secondary_500",
            "slider_color_dark": "*secondary_600",
            "block_title_text_weight": "600",
            "block_border_width": "3px",
            "block_shadow": "*shadow_drop_lg",
            "color_accent_soft": "*primary_100",
            "block_label_background_fill": "*primary_200",
        }
        super().set(**page_fill, **primary_button, **sliders_and_blocks)
59
+
60
# Single theme instance handed to gr.Blocks when the UI is built.
orange_red_theme = OrangeRedTheme()

# Prefer CUDA when torch reports it as available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Startup diagnostics printed once at import time to help debug GPU visibility.
print("CUDA_VISIBLE_DEVICES=", os.environ.get("CUDA_VISIBLE_DEVICES"))
print("torch.__version__ =", torch.__version__)
print("torch.version.cuda =", torch.version.cuda)
print("cuda available:", torch.cuda.is_available())
print("cuda device count:", torch.cuda.device_count())
if torch.cuda.is_available():
    # These queries need a CUDA device, hence the guard.
    print("current device:", torch.cuda.current_device())
    print("device name:", torch.cuda.get_device_name(torch.cuda.current_device()))

print("Using device:", device)
74
+
75
# Help the allocator survive the large-activation spikes during PiD pixel-space ops
# NOTE(review): torch has already been imported and queried for CUDA availability
# above; confirm the allocator still picks this variable up when set this late.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

PID_REPO_URL = "https://github.com/nv-tlabs/PiD.git"
# The PiD checkout lives in a "PiD" directory next to this file.
PID_REPO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "PiD")

# Only when the checkout is missing: shallow-clone PiD and install it in
# editable mode with the running interpreter.
# NOTE(review): the guard is "directory exists", so a successful clone followed
# by a failed `pip install` is not retried on the next start.
if not os.path.exists(PID_REPO_DIR):
    print(f"[pid] cloning {PID_REPO_URL} -> {PID_REPO_DIR}", flush=True)
    subprocess.check_call(["git", "clone", "--depth", "1", PID_REPO_URL, PID_REPO_DIR])
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", PID_REPO_DIR])

# PiD's loader resolves paths relative to CWD, so chdir into the repo root.
os.chdir(PID_REPO_DIR)
# Put the checkout first on sys.path so the `pid` imports below resolve to it.
sys.path.insert(0, PID_REPO_DIR)

# Pull just the Flux-1 / Z-Image-compatible checkpoints from nvidia/PiD into the
# repo's expected checkpoints/ tree.
snapshot_download(
    repo_id="nvidia/PiD",
    local_dir=PID_REPO_DIR,
    allow_patterns=[
        "checkpoints/PiD_res2k_sr4x_official_flux_distill_4step/*",
        "checkpoints/PiD_res2kto4k_sr4x_official_flux_distill_4step/*",
        "checkpoints/ae.safetensors",
    ],
)
101
+
102
+ from pid._src.inference.checkpoint_registry import get_pid_checkpoint
103
+ from pid._src.inference.create_dataset import XtCaptureCallback
104
+ from pid._src.inference.pipeline_registry import (
105
+ decode_with_pipeline_vae,
106
+ extract_latent,
107
+ load_pipeline,
108
+ )
109
+ from pid._src.utils.model_loader import load_model_from_checkpoint
110
+
111
+
112
DTYPE = torch.bfloat16  # dtype passed to the model loaders and used for tensor casts below
BACKBONE = "zimage"  # key passed to load_pipeline() and get_pid_checkpoint()
SR_SCALE = 4  # PiD output side length = input side length * SR_SCALE
PID_INFERENCE_STEPS = 4  # default `num_steps` of _pid_stream()
MAX_SEED = 2**31 - 1  # upper bound of the seed slider and of randomly drawn seeds

print("[pid] loading Z-Image pipeline...", flush=True)
119
+
120
# transformers 4.57's SDPA / eager mask builders both broadcast the mask
# function over (b, h, q, k) via torch.vmap, which trips ZeroGPU's
# __torch_function__ hijack when it tries to fake-allocate the indexed
# tensors. Replace vmap with explicit broadcasting — same result, same speed,
# no functorch transform context.
125
+ from transformers import masking_utils as _mu
126
+
127
+ def _broadcasting_vmap_for_bhqkv(mask_function, bh_indices: bool = True):
128
+ def wrapped(b, h, q, k):
129
+ if bh_indices:
130
+ return mask_function(
131
+ b[:, None, None, None],
132
+ h[None, :, None, None],
133
+ q[None, None, :, None],
134
+ k[None, None, None, :],
135
+ )
136
+ return mask_function(b, h, q[:, None], k[None, :])
137
+ return wrapped
138
+
139
+ _mu._vmap_for_bhqkv = _broadcasting_vmap_for_bhqkv
140
+
141
+ # Gemma2's forward does `normalizer = torch.tensor(hidden_size**0.5, dtype=...)`
142
+ # without a device kwarg, so it lands on CPU while hidden_states is on cuda.
143
+ # Vanilla CUDA tolerates the cross-device scalar op; ZeroGPU's __torch_function__
144
+ # hijack rejects it. Force torch.tensor calls inside Gemma2.forward onto the
145
+ # embedding's device.
146
+ import transformers.models.gemma2.modeling_gemma2 as _gm
147
+
148
+ _orig_gemma2_forward = _gm.Gemma2Model.forward
149
+
150
def _patched_gemma2_forward(self, *args, **kwargs):
    """Call the original Gemma2Model.forward with `torch.tensor` defaulting to the model's device.

    While the original forward runs, `torch.tensor` is replaced by a wrapper
    that fills in `device=<embedding weight device>` whenever the caller did
    not pass a device. The original `torch.tensor` is restored afterwards,
    including when the forward raises.
    """
    _orig_tt = torch.tensor
    # Device of the token-embedding weights; used as the default below.
    dev = self.embed_tokens.weight.device
    def _tt(data, *a, **kw):
        # Explicit `device=` arguments win; only the missing case is filled in.
        kw.setdefault("device", dev)
        return _orig_tt(data, *a, **kw)
    # This rebinds the attribute on the torch module itself, so any other code
    # calling torch.tensor while the forward runs sees the wrapper too.
    torch.tensor = _tt
    try:
        return _orig_gemma2_forward(self, *args, **kwargs)
    finally:
        torch.tensor = _orig_tt
161
+
162
+ _gm.Gemma2Model.forward = _patched_gemma2_forward
163
+
164
# Load the Z-Image pipeline and its config for the selected backbone, then
# move the pipeline to the GPU.
pipeline, pipe_cfg = load_pipeline(BACKBONE, dtype=DTYPE)
pipeline.to("cuda")

print("[pid] loading TAEF1 (fast preview decoder)...", flush=True)
from diffusers import AutoencoderTiny
# Tiny autoencoder used by _taef1_preview() for the per-step live previews.
taef1 = AutoencoderTiny.from_pretrained(
    "madebyollin/taef1", torch_dtype=DTYPE, low_cpu_mem_usage=False
).to("cuda")
taef1.eval()
173
+
174
def _load_pid(ckpt_type: str):
    """Load one PiD decoder checkpoint for the configured backbone.

    Args:
        ckpt_type: Checkpoint variant name as understood by
            `get_pid_checkpoint` (this file uses "2k" and "2kto4k").

    Returns:
        The loaded model, switched to eval mode.
    """
    checkpoint_meta = get_pid_checkpoint(BACKBONE, ckpt_type)
    print(f"[pid] loading PiD decoder ({ckpt_type})...", flush=True)

    loader_kwargs = dict(
        experiment_name=checkpoint_meta.experiment,
        checkpoint_path=checkpoint_meta.checkpoint_path,
        # Relative path: resolved against the CWD, which is the PiD checkout.
        config_file="pid/_src/configs/pid/config.py",
        enable_fsdp=False,
        strict=False,
    )
    # The loader returns a pair; only the first element (the model) is used.
    decoder, _unused = load_model_from_checkpoint(**loader_kwargs)
    decoder.eval()
    return decoder
186
+
187
+
188
# Both PiD decoder variants are loaded eagerly at import time and looked up by
# _pick_pid_model() at request time.
pid_models = {
    "2k": _load_pid("2k"),
    "2kto4k": _load_pid("2kto4k"),
}


print("[pid] loading FLUX.2-Klein pipeline...", flush=True)
from diffusers import Flux2KleinPipeline

# Pipeline used by the Image2Image tab (_i2i_generate_core).
klein_pipe = Flux2KleinPipeline.from_pretrained(
    "black-forest-labs/FLUX.2-klein-4B",
    torch_dtype=DTYPE,
).to("cuda")
print("[pid] FLUX.2-Klein loaded.", flush=True)

print("[pid] ready", flush=True)
204
+
205
+
206
def _pick_pid_model(resolution: int):
    """Choose the PiD decoder for a given input resolution.

    Inputs above 512 px use the "2kto4k" decoder (1024 -> 4K); everything else
    uses the "2k" decoder (trained at 2048 px, sweet spot 512 -> 2048).
    """
    if resolution > 512:
        variant = "2kto4k"
    else:
        variant = "2k"
    return pid_models[variant]
209
+
210
+
211
def _latent_to_pil(tensor: torch.Tensor) -> Image.Image:
    """Convert a PiD output tensor in [-1, 1] to a PIL image.

    A 4-D input is treated as (C, T, H, W) and its axis 1 is squeezed (T=1 for
    images); the remaining (C, H, W) tensor is clamped, mapped to 0..255 and
    reordered to (H, W, C).
    """
    if tensor.dim() == 4:
        tensor = tensor.squeeze(1)
    # [-1, 1] -> [0, 255]; the uint8 cast truncates the fractional part.
    scaled = (tensor.float().clamp(-1, 1) + 1) * 127.5
    height_width_channel = scaled.permute(1, 2, 0).cpu().numpy()
    return Image.fromarray(height_width_channel.astype(np.uint8))
217
+
218
+
219
def _taef1_preview(packed_latent: torch.Tensor, H: int, W: int) -> Image.Image:
    """Decode an intermediate Z-Image latent into a quick preview with TAEF1.

    The packed latent is unpacked with `extract_latent`, de-normalised with the
    pipeline VAE's scaling/shift factors, decoded by the tiny autoencoder and
    converted to an 8-bit PIL image.
    """
    with torch.no_grad():
        # extract_latent expects an object exposing `.images`.
        pipeline_like_output = SimpleNamespace(images=packed_latent)
        spatial_latent = extract_latent(pipeline, pipeline_like_output, pipe_cfg, H, W)

        scale = pipeline.vae.config.scaling_factor
        # A missing or falsy shift_factor is treated as 0.0.
        shift = getattr(pipeline.vae.config, "shift_factor", None) or 0.0
        raw_latent = spatial_latent.to(dtype=DTYPE) / scale + shift

        decoded = taef1.decode(raw_latent).sample
        # [-1, 1] -> [0, 1]
        unit_range = (decoded.float().clamp(-1, 1) + 1) / 2
        pixels = (unit_range[0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
    return Image.fromarray(pixels)
230
+
231
+
232
def _pid_pixel_to_pil(x: torch.Tensor) -> Image.Image:
    """Turn the first item of a PiD pixel-space batch (B, 3, H, W) in [-1, 1] into a PIL image."""
    first_item = x[0].float().clamp(-1, 1)
    # [-1, 1] -> [0, 255]; the uint8 cast truncates the fractional part.
    as_bytes_range = (first_item + 1) * 127.5
    channels_last = as_bytes_range.permute(1, 2, 0).cpu().numpy()
    return Image.fromarray(channels_last.astype(np.uint8))
236
+
237
+
238
def _pid_stream(
    pid_model,
    latent: torch.Tensor,
    baseline_01: torch.Tensor,
    sigma: float,
    caption: str,
    num_steps: int = PID_INFERENCE_STEPS,
):
    """Reimplementation of PiDDistillModel.generate_samples_from_batch that yields
    the current pixel-space tensor after each of the `num_steps` student-sampler
    iterations. Final yield is the clean output.

    Args:
        pid_model: PiD model object; its `net`, `config`, `fm_trainer` and
            private helper methods are used directly.
        latent: Low-resolution latent, moved to CUDA in DTYPE as `lq_latent`.
        baseline_01: Low-resolution image; rescaled with `* 2.0 - 1.0` before
            being passed to the network (assumes values in [0, 1] - confirm
            against callers).
        sigma: Degradation sigma, passed to the network as a 1-element tensor.
        caption: Text prompt encoded with `pid_model._encode_text_raw`.
        num_steps: Number of sampler steps requested from `_get_t_list`.

    Yields:
        `(step_number, steps_total, x)` where `step_number` starts at 1 and `x`
        is a clone of the current sample.
    """
    from contextlib import nullcontext

    # Batch size is fixed to one sample.
    B = 1
    lq_h, lq_w = baseline_01.shape[-2], baseline_01.shape[-1]
    # Output spatial size is the low-res size scaled by SR_SCALE.
    img_h, img_w = lq_h * SR_SCALE, lq_w * SR_SCALE

    caption_embs, _ = pid_model._encode_text_raw([caption])
    caption_embs = caption_embs.to(**pid_model.tensor_kwargs)

    lq_video_or_image = (baseline_01 * 2.0 - 1.0).to(dtype=DTYPE, device="cuda")
    lq_latent = latent.to(dtype=DTYPE, device="cuda")
    degrade_sigma_tensor = torch.tensor([sigma], device="cuda", dtype=torch.float32)

    # Fixed seed: the starting noise (and any re-noising below) is the same on every call.
    gen = torch.Generator(device="cuda").manual_seed(0)
    noise = torch.randn(B, 3, img_h, img_w, device="cuda", generator=gen)

    t_list = pid_model._get_t_list(device=torch.device("cuda"), num_steps=num_steps)
    # Autocast only when the model declares an autocast dtype.
    autocast_ctx = (
        torch.autocast("cuda", dtype=pid_model.autocast_dtype)
        if pid_model.autocast_dtype
        else nullcontext()
    )
    net = pid_model.net
    net.eval()
    timescale = pid_model.fm_trainer.timescale
    student_sample_type = pid_model.config.student_sample_type
    prediction_type = pid_model.config.prediction_type

    x = noise
    with torch.no_grad(), autocast_ctx:
        # One step per consecutive pair of entries in t_list.
        steps_total = len(t_list) - 1
        for step_idx, (t_cur, t_next) in enumerate(zip(t_list[:-1], t_list[1:])):
            t_cur_batch = t_cur.expand(B)
            t_cur_scaled = t_cur_batch * timescale
            v_pred = net(
                x,
                t_cur_scaled,
                caption_embs,
                lq_video_or_image=lq_video_or_image,
                lq_latent=lq_latent,
                degrade_sigma=degrade_sigma_tensor,
            )
            if t_next.item() > 0:
                if student_sample_type == "ode":
                    # ODE sampler: Euler step along the predicted velocity.
                    v_for_step = pid_model._net_output_to_velocity(x, v_pred, t_cur_batch, prediction_type)
                    dt = t_next - t_cur
                    x = x + dt * v_for_step
                else:
                    # Otherwise: predict x0, then blend with fresh noise at level t_next.
                    x0_pred = pid_model._velocity_to_x0(x, v_pred, t_cur_batch)
                    eps_infer = torch.randn(
                        x0_pred.shape, device=x0_pred.device, dtype=x0_pred.dtype, generator=gen
                    )
                    s = [B] + [1] * (x.ndim - 1)
                    t_next_bcast = t_next.reshape(1).expand(s)
                    x = (1.0 - t_next_bcast) * x0_pred + t_next_bcast * eps_infer
            else:
                # Next time is not positive: emit the x0 prediction directly.
                x = pid_model._velocity_to_x0(x, v_pred, t_cur_batch)
            # Clone so the consumer's tensor is not aliased with the loop variable.
            yield step_idx + 1, steps_total, x.clone()
307
+
308
+
309
+ def _evenly_spaced_capture_steps(total_steps: int, num_captures: int) -> list[int]:
310
+ """Pick N capture indices spread across [1, total_steps-1]."""
311
+ if num_captures <= 0:
312
+ return []
313
+ raw = np.linspace(1, max(2, total_steps - 1), num_captures + 1)[1:]
314
+ return sorted({int(round(x)) for x in raw})
315
+
316
+
317
def _resize_to_divisible(image: Image.Image, max_side: int = 1024, div: int = 16) -> Image.Image:
    """Resize `image` so both sides are multiples of `div` and at most `max_side`.

    The scale factor is capped at 1.0, so the image is shrunk when needed but
    not enlarged by the scaling step. Each side is then rounded down to a
    multiple of `div`, with a floor of `div` itself (a side shorter than `div`
    therefore ends up at `div`). Resampling uses LANCZOS.
    """
    src_w, src_h = image.size
    factor = min(max_side / src_w, max_side / src_h, 1.0)

    def _snap(length: int) -> int:
        # Scale, round down to a multiple of `div`, never below `div`.
        return max(div, (int(length * factor) // div) * div)

    target_size = (_snap(src_w), _snap(src_h))
    return image.resize(target_size, Image.LANCZOS)
325
+
326
+
327
def _encode_image_to_latent(image_01: torch.Tensor) -> torch.Tensor:
    """Encode a (1, 3, H, W) float image in [0, 1] with the Z-Image pipeline VAE.

    The image is mapped to [-1, 1], encoded on CUDA in DTYPE, sampled from the
    VAE's latent distribution and normalised as `(latent - shift) * scale`
    using the VAE config's shift/scaling factors.
    """
    vae = pipeline.vae
    signed_image = image_01 * 2.0 - 1.0  # [0, 1] -> [-1, 1]

    with torch.no_grad():
        vae_input = signed_image.to(dtype=DTYPE, device="cuda")
        sampled = vae.encode(vae_input).latent_dist.sample()

    scaling = vae.config.scaling_factor
    # A missing or falsy shift_factor is treated as 0.0.
    offset = getattr(vae.config, "shift_factor", None) or 0.0
    return (sampled - offset) * scaling
337
+
338
+
339
+ import random
340
+ import threading
341
+ import queue as _queue
342
+
343
def _generate_core(
    prompt: str,
    num_inference_steps: int = 28,
    guidance_scale: float = 5.0,
    seed: int = 0,
    resolution: int = 512,
    randomize_seed: bool = False,
):
    """Text-to-image generator: Z-Image with live previews, then PiD upscaling.

    Runs the Z-Image pipeline in a worker thread, streaming a TAEF1 preview
    after every denoising step, decodes the final latent with the pipeline VAE
    and then runs the PiD student sampler on it.

    Args:
        prompt: Text prompt; must be non-blank.
        num_inference_steps: Z-Image denoising steps.
        guidance_scale: Guidance scale passed to the pipeline.
        seed: Seed for the Z-Image generator; `None` falls back to 0.
        resolution: Side length of the square Z-Image output.
        randomize_seed: Draw a fresh seed in [0, MAX_SEED] instead of using `seed`.

    Yields:
        3-tuples of `gr.update` objects for (live preview image, comparison
        slider, seed field).

    Raises:
        gr.Error: If the prompt is empty or blank.
        Exception: Any exception raised by the pipeline in the worker thread is
            re-raised here.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt.")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # A cleared number field arrives as None; use the signature default then.
    seed = 0 if seed is None else int(seed)
    num_inference_steps = int(num_inference_steps)
    H = W = int(resolution)

    # initial: show the live preview, hide the final slider
    yield (
        gr.update(visible=True, value=None, label="Generating Z-Image…"),
        gr.update(visible=False, value=None),
        gr.update(value=seed),
    )

    # ---- Run Z-Image in a thread; stream taef1 previews via a queue ----
    preview_q: "_queue.Queue" = _queue.Queue()
    # Unique sentinel marking the final queue item (result or exception).
    _DONE = object()

    def streaming_cb(pipe, step_index, timestep, callback_kwargs):
        # Previews are best-effort: a failed preview is logged and must not
        # abort the generation itself.
        try:
            preview = _taef1_preview(callback_kwargs["latents"], H, W)
            preview_q.put((step_index, preview))
        except Exception as e:
            print(f"[pid] taef1 preview failed at step {step_index}: {e}", flush=True)
        return callback_kwargs

    def run_pipeline():
        gen_torch = torch.Generator(device="cuda").manual_seed(int(seed))
        gen_kwargs = dict(
            prompt=prompt,
            height=H,
            width=W,
            num_inference_steps=num_inference_steps,
            guidance_scale=float(guidance_scale),
            num_images_per_prompt=1,
            output_type="latent",
            generator=gen_torch,
            callback_on_step_end=streaming_cb,
            callback_on_step_end_tensor_inputs=["latents"],
        )
        gen_kwargs.update(pipe_cfg.extra_generate_kwargs)
        # Either the output or the exception is handed to the consumer loop,
        # so the loop below always terminates.
        try:
            with torch.no_grad():
                out = pipeline(**gen_kwargs)
            preview_q.put((_DONE, out))
        except Exception as e:
            preview_q.put((_DONE, e))

    thread = threading.Thread(target=run_pipeline, daemon=True)
    thread.start()

    raw_output = None
    while True:
        step_index, payload = preview_q.get()
        if step_index is _DONE:
            # The sentinel is the worker's last action; join on both paths.
            thread.join()
            if isinstance(payload, Exception):
                raise payload
            raw_output = payload
            break
        label = f"Generating Z-Image — step {step_index + 1}/{num_inference_steps}"
        yield gr.update(visible=True, value=payload, label=label), gr.update(visible=False), gr.update()

    final_latent = extract_latent(pipeline, raw_output, pipe_cfg, H, W)

    yield gr.update(visible=True, label="Decoding final Z-Image…"), gr.update(visible=False), gr.update()
    with torch.no_grad():
        baseline_01 = decode_with_pipeline_vae(pipeline, final_latent, pipe_cfg)
    zimage_img = Image.fromarray(
        (baseline_01[0].clamp(0, 1).permute(1, 2, 0).float().cpu().numpy() * 255).astype(np.uint8)
    )

    # Release cached blocks before the large PiD pixel-space activations.
    torch.cuda.empty_cache()

    final_sigma = float(pipeline.scheduler.sigmas[-1].item())
    pid_img = None
    pid_model = _pick_pid_model(H)
    for k, total, x in _pid_stream(pid_model, final_latent, baseline_01, final_sigma, prompt):
        pid_img = _pid_pixel_to_pil(x)
        yield (
            gr.update(visible=True, value=pid_img, label=f"Upscaling with PiD — step {k}/{total}"),
            gr.update(visible=False),
            gr.update(),
        )

    # Done: hide the live preview and show the A/B comparison slider.
    yield (
        gr.update(visible=False, value=None),
        gr.update(visible=True, value=(zimage_img, pid_img)),
        gr.update(),
    )
439
+
440
+
441
@spaces.GPU(duration=60)
def generate_large(*args, **kwargs):
    """GPU entry point (duration=60) that delegates to `_generate_core`; `generate` uses it for resolutions below 1024."""
    yield from _generate_core(*args, **kwargs)
444
+
445
+
446
@spaces.GPU(duration=90, size="xlarge")
def generate_xlarge(*args, **kwargs):
    """GPU entry point (duration=90, size="xlarge") that delegates to `_generate_core`; `generate` uses it for resolutions of 1024 and above."""
    yield from _generate_core(*args, **kwargs)
449
+
450
+
451
def generate(prompt, num_inference_steps, guidance_scale, seed, resolution, randomize_seed):
    """Dispatch a text-to-image request to the GPU entry point matching its resolution.

    Resolutions of 1024 and above go to `generate_xlarge`, everything else to
    `generate_large`; all updates from the chosen generator are re-yielded.
    """
    if int(resolution) >= 1024:
        gpu_entry_point = generate_xlarge
    else:
        gpu_entry_point = generate_large
    yield from gpu_entry_point(
        prompt,
        num_inference_steps,
        guidance_scale,
        seed,
        resolution,
        randomize_seed,
    )
454
+
455
+
456
def update_dimensions_on_upload(image: Image.Image):
    """Build the markdown line describing input, processed and PiD output sizes.

    Args:
        image: Uploaded image, or `None` when the input was cleared.

    Returns:
        A placeholder hint when `image` is `None`; otherwise a markdown string
        with the original size, the size after `_resize_to_divisible`, and
        that size multiplied by SR_SCALE.
    """
    if image is None:
        return "_Upload an image to see its processed dimensions._"
    resized = _resize_to_divisible(image)
    ow, oh = image.size
    nw, nh = resized.size
    return (
        f"**Input:** {ow} × {oh} px → "
        f"**Processed:** {nw} × {nh} px → "
        f"**PiD output:** {nw * SR_SCALE} × {nh * SR_SCALE} px"
    )
468
+
469
+
470
def _i2i_generate_core(
    input_image: Image.Image,
    prompt: str,
    seed: int = 0,
    randomize_seed: bool = True,
    guidance_scale: float = 1.0,
    steps: int = 4,
):
    """Image-to-image generator: FLUX.2-Klein refinement followed by PiD upscaling.

    Args:
        input_image: Source image; must not be `None`.
        prompt: Text prompt; must be non-blank.
        seed: Seed for the Klein generator.
        randomize_seed: Draw a fresh seed in [0, MAX_SEED] instead of using `seed`.
        guidance_scale: Guidance scale passed to the Klein pipeline.
        steps: Number of Klein inference steps.

    Yields:
        3-tuples of `gr.update` objects for (live image, comparison slider,
        seed slider).

    Raises:
        gr.Error: If the image is missing or the prompt is blank.
    """
    if input_image is None:
        raise gr.Error("Please upload an input image.")
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt / description.")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    seed = int(seed)

    input_image = _resize_to_divisible(input_image.convert("RGB"))
    W, H = input_image.size

    yield (
        gr.update(visible=True, value=None, label="Running FLUX.2-Klein…"),
        gr.update(visible=False, value=None),
        gr.update(value=seed),
    )

    gen_torch = torch.Generator(device="cuda").manual_seed(seed)
    with torch.no_grad():
        klein_out = klein_pipe(
            prompt=prompt,
            image=input_image,
            num_inference_steps=int(steps),
            guidance_scale=float(guidance_scale),
            generator=gen_torch,
            output_type="pil",
        )
    klein_img: Image.Image = klein_out.images[0]

    # Force the Klein output back to the processed input size so the PiD
    # input dimensions match what was shown to the user.
    if klein_img.size != (W, H):
        klein_img = klein_img.resize((W, H), Image.LANCZOS)

    yield (
        gr.update(visible=True, value=klein_img, label="FLUX.2-Klein done — encoding for PiD…"),
        gr.update(visible=False),
        gr.update(),
    )

    # Release cached blocks before the VAE encode and PiD pass.
    torch.cuda.empty_cache()

    # PIL (H, W, 3) uint8 -> (1, 3, H, W) float in [0, 1]
    klein_arr = np.array(klein_img).astype(np.float32) / 255.0
    klein_tensor_01 = torch.from_numpy(klein_arr).permute(2, 0, 1).unsqueeze(0)

    final_latent = _encode_image_to_latent(klein_tensor_01)
    baseline_01 = klein_tensor_01.to(dtype=DTYPE, device="cuda")
    final_sigma = float(pipeline.scheduler.sigmas[-1].item())

    # The longer side decides which PiD decoder is used.
    pid_model = _pick_pid_model(max(H, W))
    pid_img = None
    for k, total, x in _pid_stream(
        pid_model, final_latent, baseline_01, final_sigma, prompt, num_steps=PID_INFERENCE_STEPS
    ):
        pid_img = _pid_pixel_to_pil(x)
        yield (
            gr.update(visible=True, value=pid_img, label=f"Upscaling with PiD — step {k}/{total}"),
            gr.update(visible=False),
            gr.update(),
        )

    # Done: hide the live image and show the A/B comparison slider.
    yield (
        gr.update(visible=False, value=None),
        gr.update(visible=True, value=(klein_img, pid_img)),
        gr.update(),
    )
543
+
544
+
545
@spaces.GPU(duration=90, size="xlarge")
def i2i_generate(*args, **kwargs):
    """GPU entry point (duration=90, size="xlarge") for the Image2Image tab; delegates to `_i2i_generate_core`."""
    yield from _i2i_generate_core(*args, **kwargs)
548
+
549
# PiD upscaler supports up to 1024px input (→ 4096px output with 2kto4k model).
# We clamp at 1024 to stay within VRAM budget.
551
+ UPSCALER_MAX_SIDE = 1024
552
+
553
+
554
def _upscaler_dim_info(image: Image.Image):
    """Build the markdown line shown when the user uploads an image to upscale.

    The processed size is computed arithmetically (no resize is performed):
    the image is scaled down so neither side exceeds UPSCALER_MAX_SIDE, each
    side is rounded down to a multiple of 16 (minimum 16), and the output size
    is that multiplied by SR_SCALE.

    Args:
        image: Uploaded image, or `None` when the input was cleared.

    Returns:
        A placeholder hint when `image` is `None`, otherwise the markdown string.
    """
    if image is None:
        return "_Upload an image to see its upscale dimensions._"
    w, h = image.size
    scale = min(UPSCALER_MAX_SIDE / w, UPSCALER_MAX_SIDE / h, 1.0)
    nw = max(16, (int(w * scale) // 16) * 16)
    nh = max(16, (int(h * scale) // 16) * 16)
    out_w, out_h = nw * SR_SCALE, nh * SR_SCALE
    return (
        f"**Input:** {w} × {h} px → "
        f"**Processed:** {nw} × {nh} px → "
        f"**Upscaled output:** {out_w} × {out_h} px "
        f"*({SR_SCALE}× via PiD)*"
    )
569
+
570
+
571
def _upscaler_core(
    input_image: Image.Image,
    prompt: str,
):
    """
    Pure PiD upscaler:
    1. Resize input so longer side ≤ UPSCALER_MAX_SIDE and dims are divisible by 16.
    2. Encode to VAE latent (Z-Image VAE).
    3. Run PiD 4-step student sampler → 4× pixel-space output.
    4. Yield live step previews, then the final A/B slider.

    Args:
        input_image: Image to upscale; must not be `None`.
        prompt: Optional caption; a generic quality caption is used when blank.

    Yields:
        2-tuples of `gr.update` objects for (live image, comparison slider).

    Raises:
        gr.Error: If no image was provided.
    """
    if input_image is None:
        raise gr.Error("Please upload an image to upscale.")

    # caption is optional — use a generic fallback if blank
    caption = prompt.strip() if prompt and prompt.strip() else "high quality, detailed, sharp"

    # Same sizing rule as the other tabs, via the shared helper (div=16 default).
    img_rgb = _resize_to_divisible(input_image.convert("RGB"), max_side=UPSCALER_MAX_SIDE)
    nw, nh = img_rgb.size

    input_pil = img_rgb  # clean resized input shown on the left of the slider

    yield (
        gr.update(visible=True, value=input_pil, label="Encoding image…"),
        gr.update(visible=False, value=None),
    )

    # -- Encode to VAE latent --
    arr_01 = np.array(img_rgb).astype(np.float32) / 255.0
    tensor_01 = torch.from_numpy(arr_01).permute(2, 0, 1).unsqueeze(0)  # 1 3 H W [0,1]

    latent = _encode_image_to_latent(tensor_01)
    baseline_01 = tensor_01.to(dtype=DTYPE, device="cuda")
    sigma = float(pipeline.scheduler.sigmas[-1].item())

    # Release cached blocks before the PiD pass.
    torch.cuda.empty_cache()

    # -- PiD 4-step upscaling --
    # The longer side decides which PiD decoder is used.
    pid_model = _pick_pid_model(max(nw, nh))
    pid_img = None

    for k, total, x in _pid_stream(
        pid_model, latent, baseline_01, sigma, caption, num_steps=PID_INFERENCE_STEPS
    ):
        pid_img = _pid_pixel_to_pil(x)
        yield (
            gr.update(visible=True, value=pid_img, label=f"Upscaling with PiD — step {k}/{total}"),
            gr.update(visible=False),
        )

    # -- Done: show A/B slider --
    yield (
        gr.update(visible=False, value=None),
        gr.update(visible=True, value=(input_pil, pid_img)),
    )
631
+
632
+
633
@spaces.GPU(duration=90, size="xlarge")
def upscaler_run(*args, **kwargs):
    """GPU entry point (duration=90, size="xlarge") for the upscaler tab; delegates to `_upscaler_core`."""
    yield from _upscaler_core(*args, **kwargs)
636
+
637
+
638
# Markdown header rendered at the top of the app.
DESCRIPTION = """
# PiD — Pixel Diffusion Decoder

**Text2Image** uses [Z-Image](https://huggingface.co/Tongyi-MAI/Z-Image) (live TAEF1 previews) then [PiD](https://huggingface.co/nvidia/PiD)'s 4-step pixel-diffusion decoder for 4× super-resolution. **Image2Image** uses FLUX.2-Klein for fast image-to-image then [PiD](https://huggingface.co/nvidia/PiD) for 4× upscaling. The slider on each tab compares the base model output vs the PiD upscale. [@github](https://github.com/PRITHIVSAKTHIUR/PiD-Image-Upscaler).
"""

# Page-level CSS: cap and centre the container, and keep text readable in dark mode.
css = """
.gradio-container { max-width: 1200px !important; margin: auto !important; }
.dark .gradio-container { color: var(--body-text-color); }
"""
648
+
649
# UI definition: three tabs (image-to-image, text-to-image, pure upscaler).
# Each tab has a live image that streams progress and an initially hidden
# ImageSlider that the generator functions reveal with the final comparison.
with gr.Blocks(theme=orange_red_theme, css=css) as demo:

    gr.Markdown(DESCRIPTION)

    with gr.Tabs():

        with gr.Tab("Image2ImagePiD"):

            gr.Markdown(
                "Upload any image — **[FLUX.2-Klein](https://huggingface.co/black-forest-labs/FLUX.2-klein-4B)** refines it then "
                "**PiD** super-resolves the result 4×. \n"
                "The slider compares the Klein output **(left)** to the PiD upscale **(right)**."
            )

            with gr.Row():
                with gr.Column(scale=1):
                    i2i_input = gr.Image(label="Input image", type="pil", height=380)
                    i2i_dim_info = gr.Markdown("_Upload an image to see its processed dimensions._")
                    i2i_prompt = gr.Textbox(
                        label="Prompt / description",
                        placeholder="Describe the image content or the desired style…",
                        lines=3,
                    )
                    i2i_run = gr.Button("Run", variant="primary")

                    with gr.Accordion("Advanced Settings", open=False, visible=True):
                        i2i_seed = gr.Slider(
                            label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0
                        )
                        i2i_rand = gr.Checkbox(label="Randomize seed", value=True)
                        i2i_guidance = gr.Slider(
                            label="Guidance Scale", minimum=0.0, maximum=10.0, step=0.1, value=1.0
                        )
                        i2i_steps = gr.Slider(
                            label="Steps", minimum=1, maximum=50, value=4, step=1
                        )

                with gr.Column(scale=2):
                    i2i_live = gr.Image(
                        label="Output", visible=True, show_label=True, type="pil", height=400
                    )
                    i2i_slider = gr.ImageSlider(
                        label="FLUX.2-Klein (left) ↔ PiD 4× upscale (right)",
                        visible=False,
                        type="pil",
                        height=720,
                        max_height=720,
                    )

            i2i_input.upload(
                fn=update_dimensions_on_upload,
                inputs=i2i_input,
                outputs=i2i_dim_info,
            )
            # The seed slider is also an output so a randomised seed is shown to the user.
            i2i_run.click(
                fn=i2i_generate,
                inputs=[i2i_input, i2i_prompt, i2i_seed, i2i_rand, i2i_guidance, i2i_steps],
                outputs=[i2i_live, i2i_slider, i2i_seed],
            )

        with gr.Tab("Text2ImagePiD"):

            with gr.Row():
                prompt = gr.Textbox(
                    show_label=False,
                    placeholder="Describe what you want to generate…",
                    value="A photorealistic Labrador retriever resting beside a campfire at night, glowing warm firelight reflecting on detailed fur, cinematic outdoor atmosphere.",
                    max_lines=1,
                    scale=4,
                    container=False,
                )
                run = gr.Button("Run", variant="primary", scale=1)

            live_preview = gr.Image(label="Z-Image with PiD", visible=True, show_label=True, type="pil", height=720)
            slider = gr.ImageSlider(
                label="Z-Image (left) ↔ PiD 4× upscale (right)",
                visible=False,
                type="pil",
                height=720,
                max_height=720,
            )

            with gr.Accordion("Advanced settings", open=False):
                with gr.Row():
                    resolution = gr.Radio(
                        label="Z-Image resolution",
                        choices=[512, 1024],
                        value=512,
                        info="512 → 2048² (PiD 2k); 1024 → 4096² (PiD 2kto4k)",
                    )
                    num_inference_steps = gr.Slider(
                        label="Z-Image steps", minimum=8, maximum=50, step=1, value=28
                    )
                with gr.Row():
                    guidance_scale = gr.Slider(
                        label="Guidance", minimum=1.0, maximum=10.0, step=0.5, value=5.0
                    )
                    seed = gr.Number(label="Seed", value=0, precision=0)
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            # The seed field is also an output so a randomised seed is shown to the user.
            run.click(
                fn=generate,
                inputs=[prompt, num_inference_steps, guidance_scale, seed, resolution, randomize_seed],
                outputs=[live_preview, slider, seed],
            )

        with gr.Tab("Image-Upscaler-(preview)"):

            gr.Markdown(
                "Upload any image and **PiD** will upscale it **4×** directly — "
                "no text generation step needed. \n"
                "An optional prompt / description helps PiD produce sharper, "
                "more faithful detail. \n"
                "The slider compares the **original** (left) to the **PiD 4× upscale** (right)."
            )

            with gr.Row():

                with gr.Column(scale=1):
                    up_input = gr.Image(
                        label="Image to upscale",
                        type="pil",
                        height=400,
                    )
                    up_dim_info = gr.Markdown(
                        "_Upload an image to see its upscale dimensions._"
                    )
                    # Hidden in this preview tab; the handler falls back to a
                    # generic caption when the prompt is blank.
                    up_prompt = gr.Textbox(
                        label="Optional prompt / description",
                        placeholder="Describe the image for better detail (leave blank for auto)…",
                        lines=3,
                        visible=False,
                    )
                    up_run = gr.Button("Upscale 4×", variant="primary")

                with gr.Column(scale=2):
                    up_live = gr.Image(
                        label="Output",
                        visible=True,
                        show_label=True,
                        type="pil",
                        height=400,
                    )
                    up_slider = gr.ImageSlider(
                        label="Original (left) ↔ PiD 4× upscale (right)",
                        visible=False,
                        type="pil",
                        height=720,
                        max_height=720,
                    )

            # live dimension info on upload
            up_input.upload(
                fn=_upscaler_dim_info,
                inputs=up_input,
                outputs=up_dim_info,
            )

            up_run.click(
                fn=upscaler_run,
                inputs=[up_input, up_prompt],
                outputs=[up_live, up_slider],
            )
812
+
813
if __name__ == "__main__":
    # Enable the request queue (the handlers are generators that stream
    # updates), then launch with the MCP server on, SSR off and errors shown.
    demo.queue().launch(mcp_server=True, ssr_mode=False, show_error=True)