File size: 15,805 Bytes
0972cc0
 
 
 
 
835d190
 
 
0972cc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18982d7
0972cc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b0fc9e3
 
 
 
 
 
bcabac3
b0fc9e3
 
 
 
 
 
 
 
 
 
 
 
 
 
9014add
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b0fc9e3
0972cc0
 
afb0b5a
 
835d190
 
 
afb0b5a
 
18982d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0972cc0
 
 
18982d7
 
 
 
 
0972cc0
6cd8e25
0972cc0
6cd8e25
0972cc0
 
 
 
afb0b5a
 
 
 
 
 
 
 
 
 
 
 
 
622f4d0
 
 
 
 
 
18982d7
622f4d0
 
 
 
 
 
0972cc0
622f4d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0972cc0
622f4d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0972cc0
 
 
 
 
 
 
 
 
 
 
1e34c9b
afb0b5a
 
 
 
11d62c5
0972cc0
 
 
 
 
1e34c9b
0972cc0
 
 
 
1e34c9b
 
 
0972cc0
afb0b5a
0972cc0
d330ff4
e2f50b1
1e34c9b
0972cc0
afb0b5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e2f50b1
1e34c9b
afb0b5a
 
0972cc0
 
e2f50b1
1e34c9b
e2f50b1
 
 
 
 
0972cc0
d330ff4
 
 
622f4d0
e2f50b1
622f4d0
18982d7
 
622f4d0
 
 
 
1e34c9b
622f4d0
afb0b5a
e2f50b1
 
 
 
1e34c9b
e2f50b1
0972cc0
 
11d62c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0972cc0
 
 
1e34c9b
 
 
0972cc0
 
d29a4c4
 
 
 
bcabac3
 
0972cc0
 
d1e77f7
d29a4c4
 
171fc6b
d1e77f7
 
d29a4c4
d1e77f7
 
 
12192e5
d1e77f7
 
 
 
 
 
 
 
 
 
8c73a94
d1e77f7
 
 
 
1e34c9b
0972cc0
 
 
1e34c9b
 
0972cc0
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
import os
import sys
import subprocess
import tempfile

# Help the allocator survive the large-activation spikes during PiD pixel-space ops.
# setdefault() keeps any value the deployment has already exported, and the line
# sits above the `import torch` further down so the variable is in the environment
# before PyTorch is loaded.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

import spaces


# Upstream PiD sources; the checkout lives in a "PiD" folder next to this file.
PID_REPO_URL = "https://github.com/nv-tlabs/PiD.git"
PID_REPO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "PiD")

# First-start bootstrap: shallow-clone the repo, then install it in editable mode
# with the interpreter that is running this script.
# NOTE(review): the existence of PID_REPO_DIR is the only "already set up" marker,
# so if the clone succeeds but `pip install` fails, later starts skip the install.
# Confirm that is acceptable for the deployment target.
if not os.path.exists(PID_REPO_DIR):
    print(f"[pid] cloning {PID_REPO_URL} -> {PID_REPO_DIR}", flush=True)
    subprocess.check_call(["git", "clone", "--depth", "1", PID_REPO_URL, PID_REPO_DIR])
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", PID_REPO_DIR])

# PiD's loader resolves paths relative to CWD, so chdir into the repo root.
# The repo root is also placed first on sys.path, presumably so the `pid.*`
# imports further down resolve to this checkout.
os.chdir(PID_REPO_DIR)
sys.path.insert(0, PID_REPO_DIR)

import torch
import numpy as np
import gradio as gr
from PIL import Image
from types import SimpleNamespace
from huggingface_hub import snapshot_download

# Pull just the Flux-1 / Z-Image-compatible checkpoints from nvidia/PiD into the
# repo's expected checkpoints/ tree. local_dir is the PiD checkout itself and every
# pattern starts with "checkpoints/", so only those two decoder folders and the
# ae.safetensors file are requested. This runs at import time on every start.
snapshot_download(
    repo_id="nvidia/PiD",
    local_dir=PID_REPO_DIR,
    allow_patterns=[
        "checkpoints/PiD_res2k_sr4x_official_flux_distill_4step/*",
        "checkpoints/PiD_res2kto4k_sr4x_official_flux_distill_4step/*",
        "checkpoints/ae.safetensors",
    ],
)

from pid._src.inference.checkpoint_registry import get_pid_checkpoint
from pid._src.inference.create_dataset import XtCaptureCallback
from pid._src.inference.pipeline_registry import (
    decode_with_pipeline_vae,
    extract_latent,
    load_pipeline,
)
from pid._src.utils.model_loader import load_model_from_checkpoint


# dtype handed to load_pipeline() and TAEF1, and used when casting PiD inputs.
DTYPE = torch.bfloat16
# Backbone key passed to load_pipeline() and get_pid_checkpoint().
BACKBONE = "zimage"
# Factor applied to the baseline image's height/width to size the PiD output.
SR_SCALE = 4
# Default number of student-sampler steps used by _pid_stream().
PID_INFERENCE_STEPS = 4

# Progress marker for the logs. The actual load_pipeline() call happens further
# down, after the transformers monkeypatches below have been installed.
print("[pid] loading Z-Image pipeline...", flush=True)
# transformers 4.57's SDPA / eager mask builders both broadcast the mask
# function over (b, h, q, k) via torch.vmap, which trips ZeroGPU's
# __torch_function__ hijack when it tries to fake-allocate the indexed
# tensors. Replace vmap with explicit broadcasting — same result, same speed,
# no functorch transform context.
from transformers import masking_utils as _mu

def _broadcasting_vmap_for_bhqkv(mask_function, bh_indices: bool = True):
    def wrapped(b, h, q, k):
        if bh_indices:
            return mask_function(
                b[:, None, None, None],
                h[None, :, None, None],
                q[None, None, :, None],
                k[None, None, None, :],
            )
        return mask_function(b, h, q[:, None], k[None, :])
    return wrapped

# Swap the broadcasting implementation in for transformers' helper.
# NOTE(review): `_vmap_for_bhqkv` is a private (underscore-prefixed) name in
# transformers.masking_utils — re-check that it still exists on upgrades.
_mu._vmap_for_bhqkv = _broadcasting_vmap_for_bhqkv

# Gemma2's forward does `normalizer = torch.tensor(hidden_size**0.5, dtype=...)`
# without a device kwarg, so it lands on CPU while hidden_states is on cuda.
# Vanilla CUDA tolerates the cross-device scalar op; ZeroGPU's __torch_function__
# hijack rejects it. Force torch.tensor calls inside Gemma2.forward onto the
# embedding's device.
import transformers.models.gemma2.modeling_gemma2 as _gm

# Keep a handle on the stock implementation; the wrapper below delegates to it.
_orig_gemma2_forward = _gm.Gemma2Model.forward


def _patched_gemma2_forward(self, *args, **kwargs):
    """Run the stock Gemma2 forward with ``torch.tensor`` defaulting to the model's device.

    For the duration of the call, the global ``torch.tensor`` is replaced by a
    shim that fills in ``device=<embed_tokens weight device>`` whenever the
    caller did not pass a ``device`` keyword. The original ``torch.tensor`` is
    restored afterwards, whether the forward pass returns or raises.

    NOTE(review): the swap is process-global, so overlapping calls from several
    threads could interleave save/restore — confirm forward is never run
    concurrently.
    """
    embed_device = self.embed_tokens.weight.device
    stock_tensor = torch.tensor

    def _tensor_on_embed_device(data, *positional, **named):
        # Only fill in the device when the caller did not choose one.
        if "device" not in named:
            named["device"] = embed_device
        return stock_tensor(data, *positional, **named)

    torch.tensor = _tensor_on_embed_device
    try:
        return _orig_gemma2_forward(self, *args, **kwargs)
    finally:
        # Always undo the global swap, even if the forward pass raised.
        torch.tensor = stock_tensor


_gm.Gemma2Model.forward = _patched_gemma2_forward

# Build the Z-Image pipeline (after the transformers patches above are in place)
# and move it to the GPU. pipe_cfg is reused later for latent extraction, VAE
# decoding and extra generation kwargs.
pipeline, pipe_cfg = load_pipeline(BACKBONE, dtype=DTYPE)
pipeline.to("cuda")

# TAEF1 is used only for the per-step live previews (see _taef1_preview); the
# final baseline image is decoded through decode_with_pipeline_vae instead.
print("[pid] loading TAEF1 (fast preview decoder)...", flush=True)
from diffusers import AutoencoderTiny
taef1 = AutoencoderTiny.from_pretrained(
    "madebyollin/taef1", torch_dtype=DTYPE, low_cpu_mem_usage=False
).to("cuda")
taef1.eval()

def _load_pid(ckpt_type: str):
    """Load one PiD decoder checkpoint for the configured backbone and put it in eval mode.

    Args:
        ckpt_type: Checkpoint variant key handed to ``get_pid_checkpoint``
            together with ``BACKBONE`` (this file uses "2k" and "2kto4k").

    Returns:
        The first element returned by ``load_model_from_checkpoint``, after
        ``eval()`` has been called on it.
    """
    ckpt_meta = get_pid_checkpoint(BACKBONE, ckpt_type)
    print(f"[pid] loading PiD decoder ({ckpt_type})...", flush=True)
    load_kwargs = dict(
        experiment_name=ckpt_meta.experiment,
        checkpoint_path=ckpt_meta.checkpoint_path,
        # Relative path: presumably resolved against the PiD repo root that the
        # process chdir'd into at startup.
        config_file="pid/_src/configs/pid/config.py",
        enable_fsdp=False,
        strict=False,
    )
    decoder, _ = load_model_from_checkpoint(**load_kwargs)
    decoder.eval()
    return decoder


# Both decoder variants are loaded once at import time, in this order;
# _pick_pid_model() chooses between them per request.
pid_models = {variant: _load_pid(variant) for variant in ("2k", "2kto4k")}
print("[pid] ready", flush=True)


def _pick_pid_model(resolution: int):
    """Return the preloaded PiD decoder matching a Z-Image resolution.

    Resolutions above 512 use the "2kto4k" decoder (1024 → 4K); 512 and below
    use the "2k" decoder (trained at 2048px, sweet spot 512 → 2048).
    """
    variant = "2kto4k" if resolution > 512 else "2k"
    return pid_models[variant]


def _latent_to_pil(tensor: torch.Tensor) -> Image.Image:
    """Convert a PiD output tensor to a PIL image.

    A 4-D input is treated as (C, T, H, W) and has axis 1 squeezed (T=1 for a
    still image). Values are clamped to [-1, 1], mapped linearly onto
    [0, 255], moved to channel-last layout and truncated to uint8.
    """
    chw = tensor.squeeze(1) if tensor.dim() == 4 else tensor
    scaled = (chw.float().clamp(-1, 1) + 1) * 127.5
    pixels = scaled.permute(1, 2, 0).cpu().numpy().astype(np.uint8)
    return Image.fromarray(pixels)


def _taef1_preview(packed_latent: torch.Tensor, H: int, W: int) -> Image.Image:
    """Decode an in-flight Z-Image latent into a quick preview image with TAEF1.

    The packed latent is wrapped in an object exposing ``.images`` so that
    ``extract_latent`` can unpack it, the pipeline VAE's normalisation is
    undone (``latent / scaling_factor + shift_factor``, with the shift treated
    as 0.0 when missing or falsy), and the result is decoded by the TAEF1
    autoencoder (FLUX-1 compatible). Only the first image of the batch is
    returned.
    """
    with torch.no_grad():
        pipeline_like_output = SimpleNamespace(images=packed_latent)
        latent = extract_latent(pipeline, pipeline_like_output, pipe_cfg, H, W)

        vae_cfg = pipeline.vae.config
        scaling = vae_cfg.scaling_factor
        shift = getattr(vae_cfg, "shift_factor", None) or 0.0

        decoded = taef1.decode(latent.to(dtype=DTYPE) / scaling + shift).sample
        # [-1, 1] -> [0, 1], then scale to 8-bit in channel-last layout.
        unit_range = (decoded.float().clamp(-1, 1) + 1) / 2
        pixels = (unit_range[0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
    return Image.fromarray(pixels)


def _pid_pixel_to_pil(x: torch.Tensor) -> Image.Image:
    """Convert the first image of a PiD pixel-space batch to a PIL image.

    The input is expected to be (B, 3, H, W) in [-1, 1]. Values are clamped to
    that range, mapped linearly onto [0, 255], moved to channel-last layout
    and truncated to uint8.
    """
    first = x[0].float().clamp(-1, 1)
    hwc = ((first + 1) * 127.5).permute(1, 2, 0)
    return Image.fromarray(hwc.cpu().numpy().astype(np.uint8))


def _pid_stream(pid_model, latent: torch.Tensor, baseline_01: torch.Tensor, sigma: float, caption: str, num_steps: int = PID_INFERENCE_STEPS):
    """Reimplementation of PiDDistillModel.generate_samples_from_batch that yields
    the current pixel-space tensor after each of the `num_steps` student-sampler
    iterations. Final yield is the clean output.

    Args:
        pid_model: PiD model object. This function reads its `net`, `config`,
            `fm_trainer.timescale`, `autocast_dtype` and `tensor_kwargs`, and
            calls its private helpers `_encode_text_raw`, `_get_t_list`,
            `_net_output_to_velocity` and `_velocity_to_x0`.
        latent: Passed to the network as `lq_latent` after casting to DTYPE on cuda.
        baseline_01: Low-resolution image tensor. Its last two dims, multiplied
            by SR_SCALE, give the output size. Presumably in [0, 1], since it
            is rescaled with `* 2.0 - 1.0` — TODO confirm against the caller.
        sigma: Wrapped in a one-element float32 tensor and passed as `degrade_sigma`.
        caption: Prompt text, encoded once before the sampling loop.
        num_steps: Step count handed to `pid_model._get_t_list`.

    Yields:
        `(step_number, steps_total, x)` where `step_number` starts at 1,
        `steps_total` is `len(t_list) - 1`, and `x` is a clone of the current
        sample.
    """
    from contextlib import nullcontext

    # Batch size is fixed at a single image.
    B = 1
    lq_h, lq_w = baseline_01.shape[-2], baseline_01.shape[-1]
    img_h, img_w = lq_h * SR_SCALE, lq_w * SR_SCALE

    caption_embs, _ = pid_model._encode_text_raw([caption])
    caption_embs = caption_embs.to(**pid_model.tensor_kwargs)

    lq_video_or_image = (baseline_01 * 2.0 - 1.0).to(dtype=DTYPE, device="cuda")
    lq_latent = latent.to(dtype=DTYPE, device="cuda")
    degrade_sigma_tensor = torch.tensor([sigma], device="cuda", dtype=torch.float32)

    # Fixed seed: the starting noise (and the re-noising draws below, which share
    # this generator) do not depend on the user-facing seed.
    gen = torch.Generator(device="cuda").manual_seed(0)
    noise = torch.randn(B, 3, img_h, img_w, device="cuda", generator=gen)

    t_list = pid_model._get_t_list(device=torch.device("cuda"), num_steps=num_steps)
    # Autocast only when the model declares an autocast dtype; otherwise a no-op context.
    autocast_ctx = (
        torch.autocast("cuda", dtype=pid_model.autocast_dtype)
        if pid_model.autocast_dtype
        else nullcontext()
    )
    net = pid_model.net
    net.eval()
    timescale = pid_model.fm_trainer.timescale
    student_sample_type = pid_model.config.student_sample_type
    prediction_type = pid_model.config.prediction_type

    x = noise
    # NOTE(review): the `yield` below sits inside this with-block, so the
    # no_grad/autocast contexts stay entered while the consumer runs between
    # steps — confirm that is acceptable for every caller.
    with torch.no_grad(), autocast_ctx:
        steps_total = len(t_list) - 1
        # Walk consecutive (current, next) time pairs from t_list.
        for step_idx, (t_cur, t_next) in enumerate(zip(t_list[:-1], t_list[1:])):
            t_cur_batch = t_cur.expand(B)
            t_cur_scaled = t_cur_batch * timescale
            v_pred = net(
                x,
                t_cur_scaled,
                caption_embs,
                lq_video_or_image=lq_video_or_image,
                lq_latent=lq_latent,
                degrade_sigma=degrade_sigma_tensor,
            )
            if t_next.item() > 0:
                if student_sample_type == "ode":
                    # Deterministic update: move x by (t_next - t_cur) along the velocity.
                    v_for_step = pid_model._net_output_to_velocity(x, v_pred, t_cur_batch, prediction_type)
                    dt = t_next - t_cur
                    x = x + dt * v_for_step
                else:
                    # Stochastic update: estimate x0, then blend it with fresh
                    # noise at level t_next: (1 - t_next) * x0 + t_next * eps.
                    x0_pred = pid_model._velocity_to_x0(x, v_pred, t_cur_batch)
                    eps_infer = torch.randn(
                        x0_pred.shape, device=x0_pred.device, dtype=x0_pred.dtype, generator=gen
                    )
                    s = [B] + [1] * (x.ndim - 1)
                    t_next_bcast = t_next.reshape(1).expand(s)
                    x = (1.0 - t_next_bcast) * x0_pred + t_next_bcast * eps_infer
            else:
                # Next time is not positive: output the x0 estimate directly.
                x = pid_model._velocity_to_x0(x, v_pred, t_cur_batch)
            # Clone so later in-loop updates cannot alias what the consumer holds.
            yield step_idx + 1, steps_total, x.clone()


def _evenly_spaced_capture_steps(total_steps: int, num_captures: int) -> list[int]:
    """Pick N capture indices spread across [1, total_steps-1]. The final x0 is always added separately."""
    if num_captures <= 0:
        return []
    # avoid 0 (no forward pass yet) and total_steps (== final clean, captured separately)
    raw = np.linspace(1, max(2, total_steps - 1), num_captures + 1)[1:]
    return sorted({int(round(x)) for x in raw})


import random
import threading
import queue as _queue


def _generate_core(
    prompt: str,
    num_inference_steps: int = 28,
    guidance_scale: float = 5.0,
    seed: int = 0,
    resolution: int = 512,
    randomize_seed: bool = False,
):
    """Run Z-Image then PiD for one request, streaming UI updates as a generator.

    Every yielded item is a 3-tuple of Gradio updates, in the order the click
    handler lists its outputs: (live preview image, A/B slider, seed field).

    Stages:
        1. Z-Image runs in a background thread; its per-step callback pushes
           TAEF1 previews onto a queue that this generator drains and yields.
        2. The final latent is decoded with the pipeline VAE (the baseline image).
        3. PiD upscales the final latent, yielding one image per sampler step.
        4. The live preview is hidden and the slider shows baseline vs. PiD.

    Args:
        prompt: Text prompt; must contain non-whitespace characters.
        num_inference_steps: Z-Image step count (coerced with int()).
        guidance_scale: Guidance value (coerced with float()).
        seed: Seed for the Z-Image generator; replaced when randomize_seed is set.
        resolution: Square side length used for both height and width.
        randomize_seed: When true, draw a fresh seed in [0, 2**31 - 1].

    Raises:
        gr.Error: If the prompt is empty or whitespace only.
        Exception: Any exception raised by the pipeline thread is re-raised here.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt.")

    if randomize_seed:
        seed = random.randint(0, 2**31 - 1)
    seed = int(seed)
    num_inference_steps = int(num_inference_steps)
    H = W = int(resolution)


    # initial: show the live preview, hide the final slider
    # (the third update writes the seed actually used back into the seed field)
    yield gr.update(visible=True, value=None, label="Generating Z-Image…"), gr.update(visible=False, value=None), gr.update(value=seed)

    # ---- Run Z-Image in a thread; stream taef1 previews via a queue ----
    # Queue items are (step_index, preview_image) while running, then a single
    # (_DONE, pipeline_output_or_exception) item when the thread finishes.
    preview_q: "_queue.Queue" = _queue.Queue()
    _DONE = object()

    def streaming_cb(pipe, step_index, timestep, callback_kwargs):
        # Best effort: a failed preview is logged and skipped so generation continues.
        try:
            preview = _taef1_preview(callback_kwargs["latents"], H, W)
            preview_q.put((step_index, preview))
        except Exception as e:
            print(f"[pid] taef1 preview failed at step {step_index}: {e}", flush=True)
        return callback_kwargs

    def run_pipeline():
        gen_torch = torch.Generator(device="cuda").manual_seed(int(seed))
        gen_kwargs = dict(
            prompt=prompt,
            height=H,
            width=W,
            num_inference_steps=num_inference_steps,
            guidance_scale=float(guidance_scale),
            num_images_per_prompt=1,
            output_type="latent",
            generator=gen_torch,
            callback_on_step_end=streaming_cb,
            callback_on_step_end_tensor_inputs=["latents"],
        )
        # Applied last, so backbone-specific kwargs override the defaults above.
        gen_kwargs.update(pipe_cfg.extra_generate_kwargs)
        try:
            with torch.no_grad():
                out = pipeline(**gen_kwargs)
            preview_q.put((_DONE, out))
        except Exception as e:
            # Hand the failure to the consumer loop, which re-raises it.
            preview_q.put((_DONE, e))

    thread = threading.Thread(target=run_pipeline, daemon=True)
    thread.start()

    raw_output = None
    # Drain the queue until the sentinel arrives; each preview becomes one UI update.
    while True:
        step_index, payload = preview_q.get()
        if step_index is _DONE:
            if isinstance(payload, Exception):
                raise payload
            raw_output = payload
            break
        label = f"Generating Z-Image — step {step_index + 1}/{num_inference_steps}"
        yield gr.update(visible=True, value=payload, label=label), gr.update(visible=False), gr.update()

    thread.join()
    final_latent = extract_latent(pipeline, raw_output, pipe_cfg, H, W)

    # ---- VAE decode of the final clean latent (Z-Image baseline) ----
    yield gr.update(visible=True, label="Decoding final Z-Image…"), gr.update(visible=False), gr.update()
    with torch.no_grad():
        baseline_01 = decode_with_pipeline_vae(pipeline, final_latent, pipe_cfg)
    # First image of the batch, clamped to [0, 1] and converted to 8-bit channel-last.
    zimage_img = Image.fromarray(
        (baseline_01[0].clamp(0, 1).permute(1, 2, 0).float().cpu().numpy() * 255).astype(np.uint8)
    )

    # Free Z-Image VAE intermediates before PiD takes over the GPU
    torch.cuda.empty_cache()

    # ---- PiD upscaling on the final latent, streaming the 4 internal steps ----
    # The degradation sigma passed to PiD is the last entry of the scheduler's sigmas.
    final_sigma = float(pipeline.scheduler.sigmas[-1].item())
    pid_img = None
    pid_model = _pick_pid_model(H)
    for k, total, x in _pid_stream(pid_model, final_latent, baseline_01, final_sigma, prompt):
        pid_img = _pid_pixel_to_pil(x)
        yield (
            gr.update(visible=True, value=pid_img, label=f"Upscaling with PiD — step {k}/{total}"),
            gr.update(visible=False),
            gr.update(),
        )

    # ---- Done: hide live preview, show the A/B slider ----
    # pid_img holds the image from the last PiD step at this point.
    yield (
        gr.update(visible=False, value=None),
        gr.update(visible=True, value=(zimage_img, pid_img)),
        gr.update(),
    )


# Two decorated entrypoints: 1024 needs the full 96GB Blackwell, 512 fits the MIG.
# `size` isn't dynamic via the decorator, so we route per request from a plain dispatcher.
@spaces.GPU(duration=60)
def generate_large(*args, **kwargs):
    """GPU entrypoint decorated with ``duration=60`` and no explicit ``size``.

    Forwards all arguments to ``_generate_core`` and re-yields its updates.
    ``generate`` routes resolutions below 1024 here.
    """
    yield from _generate_core(*args, **kwargs)


@spaces.GPU(duration=90, size="xlarge")
def generate_xlarge(*args, **kwargs):
    """GPU entrypoint decorated with ``duration=90`` and ``size="xlarge"``.

    Forwards all arguments to ``_generate_core`` and re-yields its updates.
    ``generate`` routes resolutions of 1024 and above here.
    """
    yield from _generate_core(*args, **kwargs)


def generate(prompt, num_inference_steps, guidance_scale, seed, resolution, randomize_seed):
    """Route a request to the GPU entrypoint that fits its resolution.

    Resolutions of 1024 and above go to ``generate_xlarge``; everything else
    goes to ``generate_large``. All updates from the chosen entrypoint are
    re-yielded unchanged.
    """
    if int(resolution) >= 1024:
        entrypoint = generate_xlarge
    else:
        entrypoint = generate_large
    yield from entrypoint(
        prompt,
        num_inference_steps,
        guidance_scale,
        seed,
        resolution,
        randomize_seed,
    )


# Markdown header rendered at the top of the page via gr.Markdown(DESCRIPTION).
DESCRIPTION = """
# 🪄 PiD — Pixel Diffusion Decoder for Z-Image

Runs [Z-Image](https://huggingface.co/Tongyi-MAI/Z-Image) (live previews via TAEF1) then
[PiD](https://github.com/nv-tlabs/PiD)'s 4-step pixel-diffusion decoder for a 4×
super-resolved result. The slider compares Z-Image's native VAE output to the PiD upscale.
"""

# Stylesheet passed to gr.Blocks(css=CSS): caps and centres the app container
# and sets the container text colour under the .dark class.
CSS = """
.gradio-container { max-width: 1200px !important; margin: auto !important; }
.dark .gradio-container { color: var(--body-text-color); }
"""

# UI layout: header, prompt row, one image area (live preview while running,
# A/B slider when finished), a collapsed settings panel, and the click wiring.
with gr.Blocks(theme=gr.themes.Citrus(), css=CSS) as demo:
    gr.Markdown(DESCRIPTION)
    # Prompt box and Run button share a row (4:1 width ratio via `scale`).
    with gr.Row():
        prompt = gr.Textbox(
            show_label=False,
            placeholder="Describe what you want to generate…",
            value="A photorealistic close-up of a brown tabby cat wearing a woolen hat sitting on a rustic wooden table, morning light, detailed fur",
            lines=1,
            scale=4,
            container=False,
        )
        run = gr.Button("Run", variant="primary", scale=1)

    # Exactly one of these two is visible at a time; _generate_core toggles them
    # through the gr.update(visible=...) values it yields.
    live_preview = gr.Image(label="Z-Image with PiD", visible=True, show_label=True, type="pil", height=720)
    slider = gr.ImageSlider(
        label="Z-Image (left)  ↔  PiD 4× upscale (right)",
        visible=False,
        type="pil",
        height=720,
        max_height=720,
    )

    with gr.Accordion("Advanced settings", open=False):
        with gr.Row():
            # The two choices line up with the thresholds used by generate()
            # (>= 1024) and _pick_pid_model() (> 512).
            resolution = gr.Radio(label="Z-Image resolution", choices=[512, 1024], value=512, info="512 → 2048² (PiD 2k); 1024 → 4096² (PiD 2kto4k)")
            num_inference_steps = gr.Slider(label="Z-Image steps", minimum=8, maximum=50, step=1, value=28)
        with gr.Row():
            guidance_scale = gr.Slider(label="Guidance", minimum=1.0, maximum=10.0, step=0.5, value=5.0)
            seed = gr.Number(label="Seed", value=0, precision=0)
            # Checked by default, so the Seed value is replaced unless the user unticks this.
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

    # Input order must match generate()'s positional parameters; output order
    # must match the 3-tuples yielded by _generate_core. `seed` is both an input
    # and an output so the seed actually used is written back to the field.
    run.click(
        fn=generate,
        inputs=[prompt, num_inference_steps, guidance_scale, seed, resolution, randomize_seed],
        outputs=[live_preview, slider, seed],
    )

# Start the app only when executed as a script; the request queue is enabled
# before launching (the click handler is a generator that streams updates).
if __name__ == "__main__":
    demo.queue().launch()