Spaces:
Running on Zero
Running on Zero
| import os | |
| import sys | |
| import subprocess | |
| import tempfile | |
| from typing import Iterable | |
| import torch | |
| import numpy as np | |
| import gradio as gr | |
| from PIL import Image | |
| from types import SimpleNamespace | |
| from huggingface_hub import snapshot_download | |
| import spaces | |
| from gradio.themes import Soft | |
| from gradio.themes.utils import colors, fonts, sizes | |
| colors.orange_red = colors.Color( | |
| name="orange_red", c50="#FFF0E5", c100="#FFE0CC", c200="#FFC299", c300="#FFA366", | |
| c400="#FF8533", c500="#FF4500", c600="#E63E00", c700="#CC3700", c800="#B33000", | |
| c900="#992900", c950="#802200", | |
| ) | |
| class OrangeRedTheme(Soft): | |
| def __init__( | |
| self, *, primary_hue: colors.Color | str = colors.gray, | |
| secondary_hue: colors.Color | str = colors.orange_red, | |
| neutral_hue: colors.Color | str = colors.slate, text_size: sizes.Size | str = sizes.text_lg, | |
| font: fonts.Font | str | Iterable[fonts.Font | str] = ( | |
| fonts.GoogleFont("Outfit"), "Arial", "sans-serif", | |
| ), | |
| font_mono: fonts.Font | str | Iterable[fonts.Font | str] = ( | |
| fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace", | |
| ), | |
| ): | |
| super().__init__( | |
| primary_hue=primary_hue, secondary_hue=secondary_hue, neutral_hue=neutral_hue, | |
| text_size=text_size, font=font, font_mono=font_mono, | |
| ) | |
| super().set( | |
| background_fill_primary="*primary_50", | |
| background_fill_primary_dark="*primary_900", | |
| body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)", | |
| body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)", | |
| button_primary_text_color="white", | |
| button_primary_text_color_hover="white", | |
| button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)", | |
| button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)", | |
| button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_700)", | |
| button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_600)", | |
| slider_color="*secondary_500", | |
| slider_color_dark="*secondary_600", | |
| block_title_text_weight="600", block_border_width="3px", | |
| block_shadow="*shadow_drop_lg", button_primary_shadow="*shadow_drop_lg", | |
| button_large_padding="11px", color_accent_soft="*primary_100", | |
| block_label_background_fill="*primary_200", | |
| ) | |
| orange_red_theme = OrangeRedTheme() | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| print("CUDA_VISIBLE_DEVICES=", os.environ.get("CUDA_VISIBLE_DEVICES")) | |
| print("torch.__version__ =", torch.__version__) | |
| print("torch.version.cuda =", torch.version.cuda) | |
| print("cuda available:", torch.cuda.is_available()) | |
| print("cuda device count:", torch.cuda.device_count()) | |
| if torch.cuda.is_available(): | |
| print("current device:", torch.cuda.current_device()) | |
| print("device name:", torch.cuda.get_device_name(torch.cuda.current_device())) | |
| print("Using device:", device) | |
| # Help the allocator survive the large-activation spikes during PiD pixel-space ops | |
| os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True") | |
| PID_REPO_URL = "https://github.com/nv-tlabs/PiD.git" | |
| PID_REPO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "PiD") | |
| if not os.path.exists(PID_REPO_DIR): | |
| print(f"[pid] cloning {PID_REPO_URL} -> {PID_REPO_DIR}", flush=True) | |
| subprocess.check_call(["git", "clone", "--depth", "1", PID_REPO_URL, PID_REPO_DIR]) | |
| subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", PID_REPO_DIR]) | |
| # PiD's loader resolves paths relative to CWD, so chdir into the repo root. | |
| os.chdir(PID_REPO_DIR) | |
| sys.path.insert(0, PID_REPO_DIR) | |
| # Pull just the Flux-1 / Z-Image-compatible checkpoints from nvidia/PiD into the | |
| # repo's expected checkpoints/ tree. | |
| snapshot_download( | |
| repo_id="nvidia/PiD", | |
| local_dir=PID_REPO_DIR, | |
| allow_patterns=[ | |
| "checkpoints/PiD_res2k_sr4x_official_flux_distill_4step/*", | |
| "checkpoints/PiD_res2kto4k_sr4x_official_flux_distill_4step/*", | |
| "checkpoints/ae.safetensors", | |
| ], | |
| ) | |
| from pid._src.inference.checkpoint_registry import get_pid_checkpoint | |
| #from pid._src.inference.create_dataset import XtCaptureCallback | |
| from pid._src.inference.pipeline_registry import ( | |
| decode_with_pipeline_vae, | |
| extract_latent, | |
| load_pipeline, | |
| ) | |
| from pid._src.utils.model_loader import load_model_from_checkpoint | |
| DTYPE = torch.bfloat16 | |
| BACKBONE = "zimage" | |
| SR_SCALE = 4 | |
| PID_INFERENCE_STEPS = 4 | |
| MAX_SEED = 2**31 - 1 | |
| print("[pid] loading Z-Image pipeline...", flush=True) | |
| # transformers 4.57's SDPA / eager mask builders both broadcast the mask | |
| # function over (b, h, q, k) via torch.vmap, which trips ZeroGPU's | |
| # __torch_function__ hijack when it tries to fake-allocate the indexed | |
| # tensors. Replace vmap with explicit broadcasting — same result, same speed, | |
| # no functorch transform context. | |
| from transformers import masking_utils as _mu | |
| def _broadcasting_vmap_for_bhqkv(mask_function, bh_indices: bool = True): | |
| def wrapped(b, h, q, k): | |
| if bh_indices: | |
| return mask_function( | |
| b[:, None, None, None], | |
| h[None, :, None, None], | |
| q[None, None, :, None], | |
| k[None, None, None, :], | |
| ) | |
| return mask_function(b, h, q[:, None], k[None, :]) | |
| return wrapped | |
| _mu._vmap_for_bhqkv = _broadcasting_vmap_for_bhqkv | |
| # Gemma2's forward does `normalizer = torch.tensor(hidden_size**0.5, dtype=...)` | |
| # without a device kwarg, so it lands on CPU while hidden_states is on cuda. | |
| # Vanilla CUDA tolerates the cross-device scalar op; ZeroGPU's __torch_function__ | |
| # hijack rejects it. Force torch.tensor calls inside Gemma2.forward onto the | |
| # embedding's device. | |
| import transformers.models.gemma2.modeling_gemma2 as _gm | |
| _orig_gemma2_forward = _gm.Gemma2Model.forward | |
| def _patched_gemma2_forward(self, *args, **kwargs): | |
| _orig_tt = torch.tensor | |
| dev = self.embed_tokens.weight.device | |
| def _tt(data, *a, **kw): | |
| kw.setdefault("device", dev) | |
| return _orig_tt(data, *a, **kw) | |
| torch.tensor = _tt | |
| try: | |
| return _orig_gemma2_forward(self, *args, **kwargs) | |
| finally: | |
| torch.tensor = _orig_tt | |
| _gm.Gemma2Model.forward = _patched_gemma2_forward | |
| pipeline, pipe_cfg = load_pipeline(BACKBONE, dtype=DTYPE) | |
| pipeline.to("cuda") | |
| print("[pid] loading TAEF1 (fast preview decoder)...", flush=True) | |
| from diffusers import AutoencoderTiny | |
| taef1 = AutoencoderTiny.from_pretrained( | |
| "madebyollin/taef1", torch_dtype=DTYPE, low_cpu_mem_usage=False | |
| ).to("cuda") | |
| taef1.eval() | |
| def _load_pid(ckpt_type: str): | |
| meta = get_pid_checkpoint(BACKBONE, ckpt_type) | |
| print(f"[pid] loading PiD decoder ({ckpt_type})...", flush=True) | |
| model, _ = load_model_from_checkpoint( | |
| experiment_name=meta.experiment, | |
| checkpoint_path=meta.checkpoint_path, | |
| config_file="pid/_src/configs/pid/config.py", | |
| enable_fsdp=False, | |
| strict=False, | |
| ) | |
| model.eval() | |
| return model | |
| pid_models = { | |
| "2k": _load_pid("2k"), | |
| "2kto4k": _load_pid("2kto4k"), | |
| } | |
| print("[pid] loading FLUX.2-Klein pipeline...", flush=True) | |
| from diffusers import Flux2KleinPipeline | |
| klein_pipe = Flux2KleinPipeline.from_pretrained( | |
| "black-forest-labs/FLUX.2-klein-4B", | |
| torch_dtype=DTYPE, | |
| ).to("cuda") | |
| print("[pid] FLUX.2-Klein loaded.", flush=True) | |
| print("[pid] ready", flush=True) | |
| def _pick_pid_model(resolution: int): | |
| """2k decoder is trained at 2048px (sweet spot 512 → 2048); 2kto4k handles 1024 → 4K.""" | |
| return pid_models["2kto4k"] if resolution > 512 else pid_models["2k"] | |
| def _latent_to_pil(tensor: torch.Tensor) -> Image.Image: | |
| """PiD output is (C, T, H, W) with T=1 for image -> PIL.Image.""" | |
| if tensor.dim() == 4: | |
| tensor = tensor.squeeze(1) | |
| arr = ((tensor.float().clamp(-1, 1) + 1) * 127.5).permute(1, 2, 0).cpu().numpy().astype(np.uint8) | |
| return Image.fromarray(arr) | |
| def _taef1_preview(packed_latent: torch.Tensor, H: int, W: int) -> Image.Image: | |
| """Fast low-res decode of a Z-Image latent using TAEF1 (FLUX-1 compatible).""" | |
| with torch.no_grad(): | |
| unpacked = extract_latent(pipeline, SimpleNamespace(images=packed_latent), pipe_cfg, H, W) | |
| scale = pipeline.vae.config.scaling_factor | |
| shift = getattr(pipeline.vae.config, "shift_factor", None) or 0.0 | |
| denorm = unpacked.to(dtype=DTYPE) / scale + shift | |
| img = taef1.decode(denorm).sample | |
| img = (img.float().clamp(-1, 1) + 1) / 2 | |
| arr = (img[0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8) | |
| return Image.fromarray(arr) | |
| def _pid_pixel_to_pil(x: torch.Tensor) -> Image.Image: | |
| """PiD pixel-space tensor (B, 3, H, W) in [-1, 1] -> PIL.Image.""" | |
| arr = ((x[0].float().clamp(-1, 1) + 1) * 127.5).permute(1, 2, 0).cpu().numpy().astype(np.uint8) | |
| return Image.fromarray(arr) | |
| def _pid_stream( | |
| pid_model, | |
| latent: torch.Tensor, | |
| baseline_01: torch.Tensor, | |
| sigma: float, | |
| caption: str, | |
| num_steps: int = PID_INFERENCE_STEPS, | |
| ): | |
| """Reimplementation of PiDDistillModel.generate_samples_from_batch that yields | |
| the current pixel-space tensor after each of the `num_steps` student-sampler | |
| iterations. Final yield is the clean output.""" | |
| from contextlib import nullcontext | |
| B = 1 | |
| lq_h, lq_w = baseline_01.shape[-2], baseline_01.shape[-1] | |
| img_h, img_w = lq_h * SR_SCALE, lq_w * SR_SCALE | |
| caption_embs, _ = pid_model._encode_text_raw([caption]) | |
| caption_embs = caption_embs.to(**pid_model.tensor_kwargs) | |
| lq_video_or_image = (baseline_01 * 2.0 - 1.0).to(dtype=DTYPE, device="cuda") | |
| lq_latent = latent.to(dtype=DTYPE, device="cuda") | |
| degrade_sigma_tensor = torch.tensor([sigma], device="cuda", dtype=torch.float32) | |
| gen = torch.Generator(device="cuda").manual_seed(0) | |
| noise = torch.randn(B, 3, img_h, img_w, device="cuda", generator=gen) | |
| t_list = pid_model._get_t_list(device=torch.device("cuda"), num_steps=num_steps) | |
| autocast_ctx = ( | |
| torch.autocast("cuda", dtype=pid_model.autocast_dtype) | |
| if pid_model.autocast_dtype | |
| else nullcontext() | |
| ) | |
| net = pid_model.net | |
| net.eval() | |
| timescale = pid_model.fm_trainer.timescale | |
| student_sample_type = pid_model.config.student_sample_type | |
| prediction_type = pid_model.config.prediction_type | |
| x = noise | |
| with torch.no_grad(), autocast_ctx: | |
| steps_total = len(t_list) - 1 | |
| for step_idx, (t_cur, t_next) in enumerate(zip(t_list[:-1], t_list[1:])): | |
| t_cur_batch = t_cur.expand(B) | |
| t_cur_scaled = t_cur_batch * timescale | |
| v_pred = net( | |
| x, | |
| t_cur_scaled, | |
| caption_embs, | |
| lq_video_or_image=lq_video_or_image, | |
| lq_latent=lq_latent, | |
| degrade_sigma=degrade_sigma_tensor, | |
| ) | |
| if t_next.item() > 0: | |
| if student_sample_type == "ode": | |
| v_for_step = pid_model._net_output_to_velocity(x, v_pred, t_cur_batch, prediction_type) | |
| dt = t_next - t_cur | |
| x = x + dt * v_for_step | |
| else: | |
| x0_pred = pid_model._velocity_to_x0(x, v_pred, t_cur_batch) | |
| eps_infer = torch.randn( | |
| x0_pred.shape, device=x0_pred.device, dtype=x0_pred.dtype, generator=gen | |
| ) | |
| s = [B] + [1] * (x.ndim - 1) | |
| t_next_bcast = t_next.reshape(1).expand(s) | |
| x = (1.0 - t_next_bcast) * x0_pred + t_next_bcast * eps_infer | |
| else: | |
| x = pid_model._velocity_to_x0(x, v_pred, t_cur_batch) | |
| yield step_idx + 1, steps_total, x.clone() | |
| def _evenly_spaced_capture_steps(total_steps: int, num_captures: int) -> list[int]: | |
| """Pick N capture indices spread across [1, total_steps-1].""" | |
| if num_captures <= 0: | |
| return [] | |
| raw = np.linspace(1, max(2, total_steps - 1), num_captures + 1)[1:] | |
| return sorted({int(round(x)) for x in raw}) | |
| def _resize_to_divisible(image: Image.Image, max_side: int = 1024, div: int = 16) -> Image.Image: | |
| """Resize so the longer side ≤ max_side and both dims divisible by `div`. | |
| Never upscales the input image.""" | |
| w, h = image.size | |
| scale = min(max_side / w, max_side / h, 1.0) | |
| nw = max(div, (int(w * scale) // div) * div) | |
| nh = max(div, (int(h * scale) // div) * div) | |
| return image.resize((nw, nh), Image.LANCZOS) | |
| def _encode_image_to_latent(image_01: torch.Tensor) -> torch.Tensor: | |
| """Encode a (1, 3, H, W) [0,1] float tensor to a VAE latent via the Z-Image VAE.""" | |
| vae = pipeline.vae | |
| image_norm = image_01 * 2.0 - 1.0 # [0,1] → [-1,1] | |
| with torch.no_grad(): | |
| latent = vae.encode(image_norm.to(dtype=DTYPE, device="cuda")).latent_dist.sample() | |
| scale = vae.config.scaling_factor | |
| shift = getattr(vae.config, "shift_factor", None) or 0.0 | |
| latent = (latent - shift) * scale | |
| return latent | |
| import random | |
| import threading | |
| import queue as _queue | |
| def _generate_core( | |
| prompt: str, | |
| num_inference_steps: int = 28, | |
| guidance_scale: float = 5.0, | |
| seed: int = 0, | |
| resolution: int = 512, | |
| randomize_seed: bool = False, | |
| ): | |
| if not prompt or not prompt.strip(): | |
| raise gr.Error("Please enter a prompt.") | |
| if randomize_seed: | |
| seed = random.randint(0, 2**31 - 1) | |
| seed = int(seed) | |
| num_inference_steps = int(num_inference_steps) | |
| H = W = int(resolution) | |
| # initial: show the live preview, hide the final slider | |
| yield gr.update(visible=True, value=None, label="Generating Z-Image…"), gr.update(visible=False, value=None), gr.update(value=seed) | |
| # ---- Run Z-Image in a thread; stream taef1 previews via a queue ---- | |
| preview_q: "_queue.Queue" = _queue.Queue() | |
| _DONE = object() | |
| def streaming_cb(pipe, step_index, timestep, callback_kwargs): | |
| try: | |
| preview = _taef1_preview(callback_kwargs["latents"], H, W) | |
| preview_q.put((step_index, preview)) | |
| except Exception as e: | |
| print(f"[pid] taef1 preview failed at step {step_index}: {e}", flush=True) | |
| return callback_kwargs | |
| def run_pipeline(): | |
| gen_torch = torch.Generator(device="cuda").manual_seed(int(seed)) | |
| gen_kwargs = dict( | |
| prompt=prompt, | |
| height=H, | |
| width=W, | |
| num_inference_steps=num_inference_steps, | |
| guidance_scale=float(guidance_scale), | |
| num_images_per_prompt=1, | |
| output_type="latent", | |
| generator=gen_torch, | |
| callback_on_step_end=streaming_cb, | |
| callback_on_step_end_tensor_inputs=["latents"], | |
| ) | |
| gen_kwargs.update(pipe_cfg.extra_generate_kwargs) | |
| try: | |
| with torch.no_grad(): | |
| out = pipeline(**gen_kwargs) | |
| preview_q.put((_DONE, out)) | |
| except Exception as e: | |
| preview_q.put((_DONE, e)) | |
| thread = threading.Thread(target=run_pipeline, daemon=True) | |
| thread.start() | |
| raw_output = None | |
| while True: | |
| step_index, payload = preview_q.get() | |
| if step_index is _DONE: | |
| if isinstance(payload, Exception): | |
| raise payload | |
| raw_output = payload | |
| break | |
| label = f"Generating Z-Image — step {step_index + 1}/{num_inference_steps}" | |
| yield gr.update(visible=True, value=payload, label=label), gr.update(visible=False), gr.update() | |
| thread.join() | |
| final_latent = extract_latent(pipeline, raw_output, pipe_cfg, H, W) | |
| yield gr.update(visible=True, label="Decoding final Z-Image…"), gr.update(visible=False), gr.update() | |
| with torch.no_grad(): | |
| baseline_01 = decode_with_pipeline_vae(pipeline, final_latent, pipe_cfg) | |
| zimage_img = Image.fromarray( | |
| (baseline_01[0].clamp(0, 1).permute(1, 2, 0).float().cpu().numpy() * 255).astype(np.uint8) | |
| ) | |
| torch.cuda.empty_cache() | |
| final_sigma = float(pipeline.scheduler.sigmas[-1].item()) | |
| pid_img = None | |
| pid_model = _pick_pid_model(H) | |
| for k, total, x in _pid_stream(pid_model, final_latent, baseline_01, final_sigma, prompt): | |
| pid_img = _pid_pixel_to_pil(x) | |
| yield ( | |
| gr.update(visible=True, value=pid_img, label=f"Upscaling with PiD — step {k}/{total}"), | |
| gr.update(visible=False), | |
| gr.update(), | |
| ) | |
| yield ( | |
| gr.update(visible=False, value=None), | |
| gr.update(visible=True, value=(zimage_img, pid_img)), | |
| gr.update(), | |
| ) | |
| def generate_large(*args, **kwargs): | |
| yield from _generate_core(*args, **kwargs) | |
| def generate_xlarge(*args, **kwargs): | |
| yield from _generate_core(*args, **kwargs) | |
| def generate(prompt, num_inference_steps, guidance_scale, seed, resolution, randomize_seed): | |
| fn = generate_xlarge if int(resolution) >= 1024 else generate_large | |
| yield from fn(prompt, num_inference_steps, guidance_scale, seed, resolution, randomize_seed) | |
| def update_dimensions_on_upload(image: Image.Image): | |
| """Return markdown info string after safe resize.""" | |
| if image is None: | |
| return "_Upload an image to see its processed dimensions._" | |
| resized = _resize_to_divisible(image) | |
| ow, oh = image.size | |
| nw, nh = resized.size | |
| return ( | |
| f"**Input:** {ow} × {oh} px → " | |
| f"**Processed:** {nw} × {nh} px → " | |
| f"**PiD output:** {nw * SR_SCALE} × {nh * SR_SCALE} px" | |
| ) | |
| def _i2i_generate_core( | |
| input_image: Image.Image, | |
| prompt: str, | |
| seed: int = 0, | |
| randomize_seed: bool = True, | |
| guidance_scale: float = 1.0, | |
| steps: int = 4, | |
| ): | |
| if input_image is None: | |
| raise gr.Error("Please upload an input image.") | |
| if not prompt or not prompt.strip(): | |
| raise gr.Error("Please enter a prompt / description.") | |
| if randomize_seed: | |
| seed = random.randint(0, MAX_SEED) | |
| seed = int(seed) | |
| input_image = _resize_to_divisible(input_image.convert("RGB")) | |
| W, H = input_image.size | |
| yield ( | |
| gr.update(visible=True, value=None, label="Running FLUX.2-Klein…"), | |
| gr.update(visible=False, value=None), | |
| gr.update(value=seed), | |
| ) | |
| gen_torch = torch.Generator(device="cuda").manual_seed(seed) | |
| with torch.no_grad(): | |
| klein_out = klein_pipe( | |
| prompt=prompt, | |
| image=input_image, | |
| num_inference_steps=int(steps), | |
| guidance_scale=float(guidance_scale), | |
| generator=gen_torch, | |
| output_type="pil", | |
| ) | |
| klein_img: Image.Image = klein_out.images[0] | |
| if klein_img.size != (W, H): | |
| klein_img = klein_img.resize((W, H), Image.LANCZOS) | |
| yield ( | |
| gr.update(visible=True, value=klein_img, label="FLUX.2-Klein done — encoding for PiD…"), | |
| gr.update(visible=False), | |
| gr.update(), | |
| ) | |
| torch.cuda.empty_cache() | |
| klein_arr = np.array(klein_img).astype(np.float32) / 255.0 | |
| klein_tensor_01 = torch.from_numpy(klein_arr).permute(2, 0, 1).unsqueeze(0) | |
| final_latent = _encode_image_to_latent(klein_tensor_01) | |
| baseline_01 = klein_tensor_01.to(dtype=DTYPE, device="cuda") | |
| final_sigma = float(pipeline.scheduler.sigmas[-1].item()) | |
| pid_model = _pick_pid_model(max(H, W)) | |
| pid_img = None | |
| for k, total, x in _pid_stream( | |
| pid_model, final_latent, baseline_01, final_sigma, prompt, num_steps=PID_INFERENCE_STEPS | |
| ): | |
| pid_img = _pid_pixel_to_pil(x) | |
| yield ( | |
| gr.update(visible=True, value=pid_img, label=f"Upscaling with PiD — step {k}/{total}"), | |
| gr.update(visible=False), | |
| gr.update(), | |
| ) | |
| yield ( | |
| gr.update(visible=False, value=None), | |
| gr.update(visible=True, value=(klein_img, pid_img)), | |
| gr.update(), | |
| ) | |
| def i2i_generate(*args, **kwargs): | |
| yield from _i2i_generate_core(*args, **kwargs) | |
| # PiD upscaler supports up to 1024px input (→ 4096px output with 2kto4k model). | |
| # We clamp at 1024 to stay within VRAM budget. | |
| UPSCALER_MAX_SIDE = 1024 | |
| def _upscaler_dim_info(image: Image.Image): | |
| """Dimension markdown shown when the user uploads an image.""" | |
| if image is None: | |
| return "_Upload an image to see its upscale dimensions._" | |
| w, h = image.size | |
| scale = min(UPSCALER_MAX_SIDE / w, UPSCALER_MAX_SIDE / h, 1.0) | |
| nw = max(16, (int(w * scale) // 16) * 16) | |
| nh = max(16, (int(h * scale) // 16) * 16) | |
| out_w, out_h = nw * SR_SCALE, nh * SR_SCALE | |
| return ( | |
| f"**Input:** {w} × {h} px → " | |
| f"**Processed:** {nw} × {nh} px → " | |
| f"**Upscaled output:** {out_w} × {out_h} px " | |
| f"*({SR_SCALE}× via PiD)*" | |
| ) | |
| def _upscaler_core( | |
| input_image: Image.Image, | |
| prompt: str, | |
| ): | |
| """ | |
| Pure PiD upscaler: | |
| 1. Resize input so longer side ≤ 1024 and dims are divisible by 16. | |
| 2. Encode to VAE latent (Z-Image VAE). | |
| 3. Run PiD 4-step student sampler → 4× pixel-space output. | |
| 4. Yield live step previews, then the final A/B slider. | |
| """ | |
| if input_image is None: | |
| raise gr.Error("Please upload an image to upscale.") | |
| # caption is optional — use a generic fallback if blank | |
| caption = prompt.strip() if prompt and prompt.strip() else "high quality, detailed, sharp" | |
| img_rgb = input_image.convert("RGB") | |
| w, h = img_rgb.size | |
| scale = min(UPSCALER_MAX_SIDE / w, UPSCALER_MAX_SIDE / h, 1.0) | |
| nw = max(16, (int(w * scale) // 16) * 16) | |
| nh = max(16, (int(h * scale) // 16) * 16) | |
| if (nw, nh) != (w, h): | |
| img_rgb = img_rgb.resize((nw, nh), Image.LANCZOS) | |
| input_pil = img_rgb # clean resized input shown on the left of the slider | |
| yield ( | |
| gr.update(visible=True, value=input_pil, label="Encoding image…"), | |
| gr.update(visible=False, value=None), | |
| ) | |
| # ── Encode to VAE latent ─────────────────────────────────────────────── | |
| arr_01 = np.array(img_rgb).astype(np.float32) / 255.0 | |
| tensor_01 = torch.from_numpy(arr_01).permute(2, 0, 1).unsqueeze(0) # 1 3 H W [0,1] | |
| latent = _encode_image_to_latent(tensor_01) | |
| baseline_01 = tensor_01.to(dtype=DTYPE, device="cuda") | |
| sigma = float(pipeline.scheduler.sigmas[-1].item()) | |
| torch.cuda.empty_cache() | |
| # ── PiD 4-step upscaling ─────────────────────────────────────────────── | |
| pid_model = _pick_pid_model(max(nw, nh)) | |
| pid_img = None | |
| for k, total, x in _pid_stream( | |
| pid_model, latent, baseline_01, sigma, caption, num_steps=PID_INFERENCE_STEPS | |
| ): | |
| pid_img = _pid_pixel_to_pil(x) | |
| yield ( | |
| gr.update(visible=True, value=pid_img, label=f"Upscaling with PiD — step {k}/{total}"), | |
| gr.update(visible=False), | |
| ) | |
| # ── Done: show A/B slider ────────────────────────────────────────────── | |
| yield ( | |
| gr.update(visible=False, value=None), | |
| gr.update(visible=True, value=(input_pil, pid_img)), | |
| ) | |
| def upscaler_run(*args, **kwargs): | |
| yield from _upscaler_core(*args, **kwargs) | |
| DESCRIPTION = """ | |
| # PiD — Pixel Diffusion Decoder | |
| **Text2Image** uses [Z-Image](https://huggingface.co/Tongyi-MAI/Z-Image) (live TAEF1 previews) then [PiD](https://huggingface.co/nvidia/PiD)'s 4-step pixel-diffusion decoder for 4× super-resolution. **Image2Image** uses FLUX.2-Klein for fast image-to-image then [PiD](https://huggingface.co/nvidia/PiD) for 4× upscaling. The slider on each tab compares the base model output vs the PiD upscale. — [visit github](https://github.com/PRITHIVSAKTHIUR/PiD-Image-Upscaler). | |
| """ | |
| css = """ | |
| .gradio-container { max-width: 1200px !important; margin: auto !important; } | |
| .dark .gradio-container { color: var(--body-text-color); } | |
| """ | |
| with gr.Blocks(theme=orange_red_theme, css=css) as demo: | |
| gr.Markdown(DESCRIPTION) | |
| with gr.Tabs(): | |
| with gr.Tab("Image2ImagePiD"): | |
| gr.Markdown( | |
| "Upload any image — **[FLUX.2-Klein](https://huggingface.co/black-forest-labs/FLUX.2-klein-4B)** refines it then " | |
| "**PiD** super-resolves the result 4×. \n" | |
| "The slider compares the Klein output **(left)** to the PiD upscale **(right)**." | |
| ) | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| i2i_input = gr.Image(label="Input image", type="pil", height=380) | |
| i2i_dim_info = gr.Markdown("_Upload an image to see its processed dimensions._") | |
| i2i_prompt = gr.Textbox( | |
| label="Prompt / description", | |
| placeholder="Describe the image content or the desired style…", | |
| lines=3, | |
| ) | |
| i2i_run = gr.Button("Run", variant="primary") | |
| with gr.Accordion("Advanced Settings", open=False, visible=True): | |
| i2i_seed = gr.Slider( | |
| label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0 | |
| ) | |
| i2i_rand = gr.Checkbox(label="Randomize seed", value=True) | |
| i2i_guidance = gr.Slider( | |
| label="Guidance Scale", minimum=0.0, maximum=10.0, step=0.1, value=1.0 | |
| ) | |
| i2i_steps = gr.Slider( | |
| label="Steps", minimum=1, maximum=50, value=4, step=1 | |
| ) | |
| with gr.Column(scale=2): | |
| i2i_live = gr.Image( | |
| label="Output", visible=True, show_label=True, type="pil", height=400 | |
| ) | |
| i2i_slider = gr.ImageSlider( | |
| label="FLUX.2-Klein (left) ↔ PiD 4× upscale (right)", | |
| visible=False, | |
| type="pil", | |
| height=720, | |
| max_height=720, | |
| ) | |
| i2i_input.upload( | |
| fn=update_dimensions_on_upload, | |
| inputs=i2i_input, | |
| outputs=i2i_dim_info, | |
| ) | |
| i2i_run.click( | |
| fn=i2i_generate, | |
| inputs=[i2i_input, i2i_prompt, i2i_seed, i2i_rand, i2i_guidance, i2i_steps], | |
| outputs=[i2i_live, i2i_slider, i2i_seed], | |
| ) | |
| with gr.Tab("Text2ImagePiD"): | |
| with gr.Row(): | |
| prompt = gr.Textbox( | |
| show_label=False, | |
| placeholder="Describe what you want to generate…", | |
| value="A photorealistic Labrador retriever resting beside a campfire at night, glowing warm firelight reflecting on detailed fur, cinematic outdoor atmosphere.", | |
| max_lines=1, | |
| scale=4, | |
| container=False, | |
| ) | |
| run = gr.Button("Run", variant="primary", scale=1) | |
| live_preview = gr.Image(label="Z-Image with PiD", visible=True, show_label=True, type="pil", height=720) | |
| slider = gr.ImageSlider( | |
| label="Z-Image (left) ↔ PiD 4× upscale (right)", | |
| visible=False, | |
| type="pil", | |
| height=720, | |
| max_height=720, | |
| ) | |
| with gr.Accordion("Advanced settings", open=False): | |
| with gr.Row(): | |
| resolution = gr.Radio( | |
| label="Z-Image resolution", | |
| choices=[512, 1024], | |
| value=512, | |
| info="512 → 2048² (PiD 2k); 1024 → 4096² (PiD 2kto4k)", | |
| ) | |
| num_inference_steps = gr.Slider( | |
| label="Z-Image steps", minimum=8, maximum=50, step=1, value=28 | |
| ) | |
| with gr.Row(): | |
| guidance_scale = gr.Slider( | |
| label="Guidance", minimum=1.0, maximum=10.0, step=0.5, value=5.0 | |
| ) | |
| seed = gr.Number(label="Seed", value=0, precision=0) | |
| randomize_seed = gr.Checkbox(label="Randomize seed", value=True) | |
| run.click( | |
| fn=generate, | |
| inputs=[prompt, num_inference_steps, guidance_scale, seed, resolution, randomize_seed], | |
| outputs=[live_preview, slider, seed], | |
| ) | |
| with gr.Tab("Image-Upscaler-(preview)"): | |
| gr.Markdown( | |
| "Upload any image and **PiD** will upscale it **4×** directly — " | |
| "no text generation step needed. \n" | |
| "An optional prompt / description helps PiD produce sharper, " | |
| "more faithful detail. \n" | |
| "The slider compares the **original** (left) to the **PiD 4× upscale** (right)." | |
| ) | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| up_input = gr.Image( | |
| label="Image to upscale", | |
| type="pil", | |
| height=400, | |
| ) | |
| up_dim_info = gr.Markdown( | |
| "_Upload an image to see its upscale dimensions._" | |
| ) | |
| up_prompt = gr.Textbox( | |
| label="Optional prompt / description", | |
| placeholder="Describe the image for better detail (leave blank for auto)…", | |
| lines=3, | |
| visible=False, | |
| ) | |
| up_run = gr.Button("Upscale 4×", variant="primary") | |
| with gr.Column(scale=2): | |
| up_live = gr.Image( | |
| label="Output", | |
| visible=True, | |
| show_label=True, | |
| type="pil", | |
| height=400, | |
| ) | |
| up_slider = gr.ImageSlider( | |
| label="Original (left) ↔ PiD 4× upscale (right)", | |
| visible=False, | |
| type="pil", | |
| height=720, | |
| max_height=720, | |
| ) | |
| # live dimension info on upload | |
| up_input.upload( | |
| fn=_upscaler_dim_info, | |
| inputs=up_input, | |
| outputs=up_dim_info, | |
| ) | |
| up_run.click( | |
| fn=upscaler_run, | |
| inputs=[up_input, up_prompt], | |
| outputs=[up_live, up_slider], | |
| ) | |
| if __name__ == "__main__": | |
| demo.queue().launch(mcp_server=True, ssr_mode=False, show_error=True) |