Spaces:
Sleeping
Sleeping
File size: 6,444 Bytes
"""Inference: load a checkpoint and generate samples / trajectory / interp grid.

Usage:
    # 16 random faces with DDIM, 50 steps
    python3 sample.py --ckpt checkpoints/stage-256_best.pt --num 16 --steps 50

    # save the denoising trajectory as a GIF
    python3 sample.py --ckpt checkpoints/stage-256_best.pt --trajectory \
        --num 4 --steps 50 --out samples/traj.gif

    # interpolate between two random latents (8 frames, slerp)
    python3 sample.py --ckpt checkpoints/stage-256_best.pt --interpolate 8 \
        --out samples/interp.png

    # DDIM-50 (top row) vs DDPM over all cfg.timesteps steps (bottom row)
    python3 sample.py --ckpt checkpoints/stage-256_best.pt --compare-ddpm --num 4
"""
from __future__ import annotations
import argparse
import os
from typing import Optional
# Ask PyTorch to run ops the MPS (Apple GPU) backend does not implement on the
# CPU instead of raising. setdefault keeps any value the user already exported.
# This sits above `import torch` on purpose (presumably the variable is read
# when torch initialises -- keep it before the import).
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")
import torch
from config import Config
from models.unet import UNet
from models.diffusion import GaussianDiffusion, EMA
from utils.visualize import (save_image_grid, trajectory_to_gif,
interpolate_latents, make_grid)
from PIL import Image
# ---------------------------------------------------------------------------
def load_run(ckpt_path: str, device: torch.device, prefer_ema: bool = True):
    """Rebuild the config, the denoising UNet and the diffusion process from a checkpoint.

    The checkpoint is read as a mapping with a ``"config"`` entry (keyword
    arguments for ``Config``), a ``"model"`` state dict and an optional
    ``"ema"`` state dict.

    Args:
        ckpt_path: file handed to ``torch.load``.
        device: ``map_location`` for the load and the device both returned
            modules are moved to.
        prefer_ema: load the ``"ema"`` weights when that entry exists and is
            not ``None``; otherwise load ``"model"``.

    Returns:
        ``(cfg, model, diffusion)`` with ``model`` already switched to eval mode.
    """
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(ckpt_path, map_location=device)
    cfg = Config(**checkpoint["config"])

    # Every UNet hyper-parameter has the same name on the saved config, so the
    # architecture is rebuilt exactly and strict=True loading below can match.
    arch = {
        field: getattr(cfg, field)
        for field in (
            "image_size",
            "in_channels",
            "base_channels",
            "channel_mults",
            "num_res_blocks",
            "attn_resolutions",
            "time_embed_dim",
            "dropout",
        )
    }
    net = UNet(**arch).to(device)

    has_ema = checkpoint.get("ema") is not None
    if not (prefer_ema and has_ema):
        net.load_state_dict(checkpoint["model"], strict=True)
        print("[sample] loaded raw weights")
    else:
        net.load_state_dict(checkpoint["ema"], strict=True)
        print("[sample] loaded EMA weights")
    net.eval()

    process = GaussianDiffusion(
        timesteps=cfg.timesteps,
        beta_start=cfg.beta_start,
        beta_end=cfg.beta_end,
        schedule=cfg.beta_schedule,
    ).to(device)
    return cfg, net, process
# ---------------------------------------------------------------------------
def _positive_int(text: str) -> int:
    """``argparse`` type converter that accepts only integers >= 1."""
    value = int(text)  # a ValueError here is reported by argparse as an invalid value
    if value < 1:
        raise argparse.ArgumentTypeError(f"expected a positive integer, got {value}")
    return value


def parse_args(argv=None):
    """Parse the sampling command line.

    Args:
        argv: arguments to parse; ``None`` (the default) reads ``sys.argv[1:]``,
            so existing ``parse_args()`` calls behave exactly as before.

    Returns:
        ``argparse.Namespace`` holding the options defined below.
    """
    p = argparse.ArgumentParser(
        description="Load a checkpoint and generate samples, a denoising "
                    "trajectory GIF, a latent interpolation, or a "
                    "DDIM-vs-DDPM comparison.",
    )
    p.add_argument("--ckpt", required=True, help="checkpoint file to load")
    # --num / --steps must be >= 1: zero or negative values used to be accepted
    # here and only failed later inside sampling or grid saving.
    p.add_argument("--num", type=_positive_int, default=16,
                   help="number of images to sample (default: 16)")
    p.add_argument("--steps", type=_positive_int, default=50,
                   help="number of DDIM sampling steps (default: 50)")
    p.add_argument("--eta", type=float, default=0.0,
                   help="DDIM eta (stochasticity; 0.0 = deterministic DDIM)")
    p.add_argument("--seed", type=int, default=None,
                   help="seed passed to torch.manual_seed")
    p.add_argument("--out", type=str, default=None,
                   help="output file path (default: a mode-specific name)")
    p.add_argument("--no-ema", action="store_true",
                   help="load raw weights even when EMA weights are available")
    p.add_argument("--device", type=str, default=None,
                   help="torch device string, e.g. cpu / mps / cuda "
                        "(default: auto-detect)")
    # mode flags
    p.add_argument("--trajectory", action="store_true",
                   help="save denoising trajectory as a GIF")
    p.add_argument("--interpolate", type=int, default=0,
                   help="number of interpolation frames between two latents")
    p.add_argument("--compare-ddpm", action="store_true",
                   help="generate DDIM-N vs DDPM-T side-by-side comparison")
    return p.parse_args(argv)
# ---------------------------------------------------------------------------
def _default_device() -> torch.device:
    """Return the first available backend: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def _resolve_out(explicit: Optional[str], default_dir: str, default_name: str) -> str:
    """Return the output file path, creating its parent directory if needed.

    ``explicit`` (the ``--out`` value) wins when non-empty; otherwise the file
    is ``default_dir/default_name``. The parent of the *final* path is created
    so ``--out`` may point into a directory that does not exist yet. Callers
    resolve this before sampling, so a bad path fails fast.
    """
    path = explicit or os.path.join(default_dir, default_name)
    parent = os.path.dirname(path)
    if parent:  # a bare file name is written into the current directory
        os.makedirs(parent, exist_ok=True)
    return path


def _run_interpolation(args, model, diffusion, shape, device, out_dir) -> None:
    """Sample ``args.interpolate`` frames from latents interpolated between two random noises."""
    n = args.interpolate
    out = _resolve_out(args.out, out_dir, f"interp_{n}.png")
    z1 = torch.randn(1, *shape[1:], device=device)
    z2 = torch.randn(1, *shape[1:], device=device)
    # NOTE(review): squeeze(1) assumes interpolate_latents returns
    # (n, 1, C, H, W) for (1, C, H, W) endpoints -- confirm in utils.visualize.
    latents = interpolate_latents(z1.cpu(), z2.cpu(), num_steps=n).squeeze(1).to(device)
    # All n frames are denoised together as one batch in a single sampling call.
    with torch.no_grad():
        samples = diffusion.ddim_sample(
            model, (n, *shape[1:]), num_steps=args.steps, eta=args.eta,
            x_T=latents, device=device,
        )
    save_image_grid(samples.cpu(), out, nrow=n)
    print(f"[sample] interpolation saved -> {out}")


def _run_trajectory(args, model, diffusion, shape, device, out_dir) -> None:
    """Sample ``shape[0]`` images and save the recorded trajectory (stride 1) as a GIF."""
    out = _resolve_out(args.out, out_dir, f"traj_{args.steps}.gif")
    x_T = torch.randn(shape, device=device)
    with torch.no_grad():
        _, traj = diffusion.ddim_sample(
            model, shape, num_steps=args.steps, eta=args.eta,
            x_T=x_T, device=device,
            return_trajectory=True, trajectory_stride=1,
        )
    trajectory_to_gif(traj, out, fps=10)
    print(f"[sample] trajectory saved -> {out}")


def _run_compare(args, model, diffusion, shape, device, out_dir, timesteps: int) -> None:
    """Save DDIM-``args.steps`` samples above full-length "DDPM" samples drawn from identical noise."""
    out = _resolve_out(args.out, out_dir, f"compare_ddim{args.steps}_vs_ddpm.png")
    x_T = torch.randn(shape, device=device)
    with torch.no_grad():
        # Each sampler gets its own copy so neither can alter the shared x_T.
        ddim = diffusion.ddim_sample(model, shape, num_steps=args.steps,
                                     eta=args.eta, x_T=x_T.clone(), device=device)
        # "DDPM" is DDIM over all `timesteps` steps with eta=1.0 -- the DDIM-paper
        # setting that recovers DDPM sampling (assuming ddim_sample follows it).
        ddpm = diffusion.ddim_sample(model, shape, num_steps=timesteps,
                                     eta=1.0, x_T=x_T.clone(), device=device)
    # nrow == batch size puts the DDIM batch on the first row, DDPM on the second.
    side = torch.cat([ddim.cpu(), ddpm.cpu()], dim=0)
    save_image_grid(side, out, nrow=args.num)
    print(f"[sample] comparison saved -> {out} (top: DDIM-{args.steps}, bottom: DDPM-{timesteps})")


def _run_grid(args, model, diffusion, shape, device, out_dir) -> None:
    """Save a grid of ``args.num`` samples (no x_T given, so ddim_sample picks the start noise)."""
    out = _resolve_out(args.out, out_dir, f"samples_n{args.num}_s{args.steps}.png")
    with torch.no_grad():
        samples = diffusion.ddim_sample(
            model, shape, num_steps=args.steps, eta=args.eta, device=device,
        )
    save_image_grid(samples.cpu(), out)
    print(f"[sample] grid saved -> {out}")


def main():
    """CLI entry point: load a checkpoint and run exactly one sampling mode.

    Mode precedence: ``--interpolate N`` (N > 0), then ``--trajectory``, then
    ``--compare-ddpm``; with none of these a plain sample grid is saved.
    Outputs go to ``--out`` or, by default, into ``cfg.sample_dir``.
    """
    args = parse_args()
    device = torch.device(args.device) if args.device else _default_device()
    cfg, model, diffusion = load_run(args.ckpt, device, prefer_ema=not args.no_ema)
    print(f"[sample] image_size={cfg.image_size} run={cfg.run_name} device={device}")
    # Seed after model construction so weight init does not shift the sampling RNG stream.
    if args.seed is not None:
        torch.manual_seed(args.seed)

    shape = (args.num, cfg.in_channels, cfg.image_size, cfg.image_size)
    out_dir = cfg.sample_dir

    requested = [
        flag
        for flag, on in (
            ("--interpolate", args.interpolate > 0),
            ("--trajectory", args.trajectory),
            ("--compare-ddpm", args.compare_ddpm),
        )
        if on
    ]
    if len(requested) > 1:
        # The modes are exclusive; say which one wins instead of silently dropping the rest.
        print(f"[sample] warning: {', '.join(requested)} given together; "
              f"only {requested[0]} will run")

    if args.interpolate > 0:
        _run_interpolation(args, model, diffusion, shape, device, out_dir)
    elif args.trajectory:
        _run_trajectory(args, model, diffusion, shape, device, out_dir)
    elif args.compare_ddpm:
        _run_compare(args, model, diffusion, shape, device, out_dir, cfg.timesteps)
    else:
        _run_grid(args, model, diffusion, shape, device, out_dir)


if __name__ == "__main__":
    main()
|