apolinario committed on
Commit
18982d7
·
1 Parent(s): 5655c6f

Preload both PiD checkpoints (2k and 2kto4k) at startup and pick one at request time based on resolution: above 512 → 2kto4k (4K target), otherwise → 2k

Browse files
Files changed (1) hide show
  1. app.py +27 -13
app.py CHANGED
@@ -32,6 +32,7 @@ snapshot_download(
32
  local_dir=PID_REPO_DIR,
33
  allow_patterns=[
34
  "checkpoints/PiD_res2k_sr4x_official_flux_distill_4step/*",
 
35
  "checkpoints/ae.safetensors",
36
  ],
37
  )
@@ -48,7 +49,6 @@ from pid._src.utils.model_loader import load_model_from_checkpoint
48
 
49
  DTYPE = torch.bfloat16
50
  BACKBONE = "zimage"
51
- CKPT_TYPE = "2k"
52
  SR_SCALE = 4
53
  PID_INFERENCE_STEPS = 4
54
 
@@ -105,19 +105,32 @@ from diffusers import AutoencoderTiny
105
  taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=DTYPE).to("cuda")
106
  taef1.eval()
107
 
108
- print("[pid] loading PiD decoder...", flush=True)
109
- pid_meta = get_pid_checkpoint(BACKBONE, CKPT_TYPE)
110
- pid_model, _pid_cfg = load_model_from_checkpoint(
111
- experiment_name=pid_meta.experiment,
112
- checkpoint_path=pid_meta.checkpoint_path,
113
- config_file="pid/_src/configs/pid/config.py",
114
- enable_fsdp=False,
115
- strict=False,
116
- )
117
- pid_model.eval()
 
 
 
 
 
 
 
 
118
  print("[pid] ready", flush=True)
119
 
120
 
 
 
 
 
 
121
  def _latent_to_pil(tensor: torch.Tensor) -> Image.Image:
122
  """PiD output is (C, T, H, W) with T=1 for image -> PIL.Image."""
123
  if tensor.dim() == 4:
@@ -145,7 +158,7 @@ def _pid_pixel_to_pil(x: torch.Tensor) -> Image.Image:
145
  return Image.fromarray(arr)
146
 
147
 
148
- def _pid_stream(latent: torch.Tensor, baseline_01: torch.Tensor, sigma: float, caption: str, num_steps: int = PID_INFERENCE_STEPS):
149
  """Reimplementation of PiDDistillModel.generate_samples_from_batch that yields
150
  the current pixel-space tensor after each of the `num_steps` student-sampler
151
  iterations. Final yield is the clean output."""
@@ -310,7 +323,8 @@ def generate(
310
  # ---- PiD upscaling on the final latent, streaming the 4 internal steps ----
311
  final_sigma = float(pipeline.scheduler.sigmas[-1].item())
312
  pid_img = None
313
- for k, total, x in _pid_stream(final_latent, baseline_01, final_sigma, prompt):
 
314
  pid_img = _pid_pixel_to_pil(x)
315
  yield (
316
  gr.update(visible=True, value=pid_img, label=f"Upscaling with PiD — step {k}/{total}"),
 
32
  local_dir=PID_REPO_DIR,
33
  allow_patterns=[
34
  "checkpoints/PiD_res2k_sr4x_official_flux_distill_4step/*",
35
+ "checkpoints/PiD_res2kto4k_sr4x_official_flux_distill_4step/*",
36
  "checkpoints/ae.safetensors",
37
  ],
38
  )
 
49
 
50
  DTYPE = torch.bfloat16
51
  BACKBONE = "zimage"
 
52
  SR_SCALE = 4
53
  PID_INFERENCE_STEPS = 4
54
 
 
105
  taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=DTYPE).to("cuda")
106
  taef1.eval()
107
 
108
+ def _load_pid(ckpt_type: str):
109
+ meta = get_pid_checkpoint(BACKBONE, ckpt_type)
110
+ print(f"[pid] loading PiD decoder ({ckpt_type})...", flush=True)
111
+ model, _ = load_model_from_checkpoint(
112
+ experiment_name=meta.experiment,
113
+ checkpoint_path=meta.checkpoint_path,
114
+ config_file="pid/_src/configs/pid/config.py",
115
+ enable_fsdp=False,
116
+ strict=False,
117
+ )
118
+ model.eval()
119
+ return model
120
+
121
+
122
+ pid_models = {
123
+ "2k": _load_pid("2k"),
124
+ "2kto4k": _load_pid("2kto4k"),
125
+ }
126
  print("[pid] ready", flush=True)
127
 
128
 
129
+ def _pick_pid_model(resolution: int):
130
+ """2k decoder is trained at 2048px (sweet spot 512 → 2048); 2kto4k handles 1024 → 4K."""
131
+ return pid_models["2kto4k"] if resolution > 512 else pid_models["2k"]
132
+
133
+
134
  def _latent_to_pil(tensor: torch.Tensor) -> Image.Image:
135
  """PiD output is (C, T, H, W) with T=1 for image -> PIL.Image."""
136
  if tensor.dim() == 4:
 
158
  return Image.fromarray(arr)
159
 
160
 
161
+ def _pid_stream(pid_model, latent: torch.Tensor, baseline_01: torch.Tensor, sigma: float, caption: str, num_steps: int = PID_INFERENCE_STEPS):
162
  """Reimplementation of PiDDistillModel.generate_samples_from_batch that yields
163
  the current pixel-space tensor after each of the `num_steps` student-sampler
164
  iterations. Final yield is the clean output."""
 
323
  # ---- PiD upscaling on the final latent, streaming the 4 internal steps ----
324
  final_sigma = float(pipeline.scheduler.sigmas[-1].item())
325
  pid_img = None
326
+ pid_model = _pick_pid_model(H)
327
+ for k, total, x in _pid_stream(pid_model, final_latent, baseline_01, final_sigma, prompt):
328
  pid_img = _pid_pixel_to_pil(x)
329
  yield (
330
  gr.update(visible=True, value=pid_img, label=f"Upscaling with PiD — step {k}/{total}"),