Spaces:
Running on Zero
Running on Zero
Commit ·
18982d7
1
Parent(s): 5655c6f
Preload both PiD checkpoints (2k + 2kto4k); pick at request time based on resolution (>512 → 2kto4k for 4K target)
Browse files
app.py
CHANGED
|
@@ -32,6 +32,7 @@ snapshot_download(
|
|
| 32 |
local_dir=PID_REPO_DIR,
|
| 33 |
allow_patterns=[
|
| 34 |
"checkpoints/PiD_res2k_sr4x_official_flux_distill_4step/*",
|
|
|
|
| 35 |
"checkpoints/ae.safetensors",
|
| 36 |
],
|
| 37 |
)
|
|
@@ -48,7 +49,6 @@ from pid._src.utils.model_loader import load_model_from_checkpoint
|
|
| 48 |
|
| 49 |
DTYPE = torch.bfloat16
|
| 50 |
BACKBONE = "zimage"
|
| 51 |
-
CKPT_TYPE = "2k"
|
| 52 |
SR_SCALE = 4
|
| 53 |
PID_INFERENCE_STEPS = 4
|
| 54 |
|
|
@@ -105,19 +105,32 @@ from diffusers import AutoencoderTiny
|
|
| 105 |
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=DTYPE).to("cuda")
|
| 106 |
taef1.eval()
|
| 107 |
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
print("[pid] ready", flush=True)
|
| 119 |
|
| 120 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
def _latent_to_pil(tensor: torch.Tensor) -> Image.Image:
|
| 122 |
"""PiD output is (C, T, H, W) with T=1 for image -> PIL.Image."""
|
| 123 |
if tensor.dim() == 4:
|
|
@@ -145,7 +158,7 @@ def _pid_pixel_to_pil(x: torch.Tensor) -> Image.Image:
|
|
| 145 |
return Image.fromarray(arr)
|
| 146 |
|
| 147 |
|
| 148 |
-
def _pid_stream(latent: torch.Tensor, baseline_01: torch.Tensor, sigma: float, caption: str, num_steps: int = PID_INFERENCE_STEPS):
|
| 149 |
"""Reimplementation of PiDDistillModel.generate_samples_from_batch that yields
|
| 150 |
the current pixel-space tensor after each of the `num_steps` student-sampler
|
| 151 |
iterations. Final yield is the clean output."""
|
|
@@ -310,7 +323,8 @@ def generate(
|
|
| 310 |
# ---- PiD upscaling on the final latent, streaming the 4 internal steps ----
|
| 311 |
final_sigma = float(pipeline.scheduler.sigmas[-1].item())
|
| 312 |
pid_img = None
|
| 313 |
-
|
|
|
|
| 314 |
pid_img = _pid_pixel_to_pil(x)
|
| 315 |
yield (
|
| 316 |
gr.update(visible=True, value=pid_img, label=f"Upscaling with PiD — step {k}/{total}"),
|
|
|
|
| 32 |
local_dir=PID_REPO_DIR,
|
| 33 |
allow_patterns=[
|
| 34 |
"checkpoints/PiD_res2k_sr4x_official_flux_distill_4step/*",
|
| 35 |
+
"checkpoints/PiD_res2kto4k_sr4x_official_flux_distill_4step/*",
|
| 36 |
"checkpoints/ae.safetensors",
|
| 37 |
],
|
| 38 |
)
|
|
|
|
| 49 |
|
| 50 |
DTYPE = torch.bfloat16
|
| 51 |
BACKBONE = "zimage"
|
|
|
|
| 52 |
SR_SCALE = 4
|
| 53 |
PID_INFERENCE_STEPS = 4
|
| 54 |
|
|
|
|
| 105 |
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=DTYPE).to("cuda")
|
| 106 |
taef1.eval()
|
| 107 |
|
| 108 |
+
def _load_pid(ckpt_type: str):
|
| 109 |
+
meta = get_pid_checkpoint(BACKBONE, ckpt_type)
|
| 110 |
+
print(f"[pid] loading PiD decoder ({ckpt_type})...", flush=True)
|
| 111 |
+
model, _ = load_model_from_checkpoint(
|
| 112 |
+
experiment_name=meta.experiment,
|
| 113 |
+
checkpoint_path=meta.checkpoint_path,
|
| 114 |
+
config_file="pid/_src/configs/pid/config.py",
|
| 115 |
+
enable_fsdp=False,
|
| 116 |
+
strict=False,
|
| 117 |
+
)
|
| 118 |
+
model.eval()
|
| 119 |
+
return model
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
pid_models = {
|
| 123 |
+
"2k": _load_pid("2k"),
|
| 124 |
+
"2kto4k": _load_pid("2kto4k"),
|
| 125 |
+
}
|
| 126 |
print("[pid] ready", flush=True)
|
| 127 |
|
| 128 |
|
| 129 |
+
def _pick_pid_model(resolution: int):
|
| 130 |
+
"""2k decoder is trained at 2048px (sweet spot 512 → 2048); 2kto4k handles 1024 → 4K."""
|
| 131 |
+
return pid_models["2kto4k"] if resolution > 512 else pid_models["2k"]
|
| 132 |
+
|
| 133 |
+
|
| 134 |
def _latent_to_pil(tensor: torch.Tensor) -> Image.Image:
|
| 135 |
"""PiD output is (C, T, H, W) with T=1 for image -> PIL.Image."""
|
| 136 |
if tensor.dim() == 4:
|
|
|
|
| 158 |
return Image.fromarray(arr)
|
| 159 |
|
| 160 |
|
| 161 |
+
def _pid_stream(pid_model, latent: torch.Tensor, baseline_01: torch.Tensor, sigma: float, caption: str, num_steps: int = PID_INFERENCE_STEPS):
|
| 162 |
"""Reimplementation of PiDDistillModel.generate_samples_from_batch that yields
|
| 163 |
the current pixel-space tensor after each of the `num_steps` student-sampler
|
| 164 |
iterations. Final yield is the clean output."""
|
|
|
|
| 323 |
# ---- PiD upscaling on the final latent, streaming the 4 internal steps ----
|
| 324 |
final_sigma = float(pipeline.scheduler.sigmas[-1].item())
|
| 325 |
pid_img = None
|
| 326 |
+
pid_model = _pick_pid_model(H)
|
| 327 |
+
for k, total, x in _pid_stream(pid_model, final_latent, baseline_01, final_sigma, prompt):
|
| 328 |
pid_img = _pid_pixel_to_pil(x)
|
| 329 |
yield (
|
| 330 |
gr.update(visible=True, value=pid_img, label=f"Upscaling with PiD — step {k}/{total}"),
|