NoobNovel's picture
DDIM face generation — full project
0ca4c93
r"""Inference: load a checkpoint and generate samples, a trajectory GIF, or an interpolation grid.

Usage:
    # 16 random faces with DDIM 50 steps
    python3 sample.py --ckpt checkpoints/stage-256_best.pt --num 16 --steps 50

    # save denoising trajectory as a GIF
    python3 sample.py --ckpt checkpoints/stage-256_best.pt --trajectory \
        --num 4 --steps 50 --out samples/traj.gif

    # interpolate between two random latents (8 frames, slerp)
    python3 sample.py --ckpt checkpoints/stage-256_best.pt --interpolate 8 \
        --out samples/interp.png

    # DDPM-1000 vs DDIM-50 side-by-side
    python3 sample.py --ckpt ... --compare-ddpm --num 4
"""
from __future__ import annotations
import argparse
import os
from typing import Optional
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")
import torch
from config import Config
from models.unet import UNet
from models.diffusion import GaussianDiffusion, EMA
from utils.visualize import (save_image_grid, trajectory_to_gif,
interpolate_latents, make_grid)
from PIL import Image
# ---------------------------------------------------------------------------
def load_run(ckpt_path: str, device: torch.device, prefer_ema: bool = True):
    """Rebuild the config, UNet and diffusion object stored in a checkpoint.

    Args:
        ckpt_path: File handed to ``torch.load``. The loaded payload is read
            for the keys ``"config"`` and ``"model"``, and optionally ``"ema"``.
        device: Used as ``map_location`` for loading and as the ``.to()``
            target for both the network and the diffusion object.
        prefer_ema: When true and the payload holds a non-``None`` ``"ema"``
            entry, that state dict is loaded instead of ``"model"``.

    Returns:
        The tuple ``(cfg, model, diffusion)``; ``model`` has had ``eval()``
        called on it.
    """
    # NOTE(review): torch.load unpickles the file, so only open checkpoints
    # from a trusted source.
    checkpoint = torch.load(ckpt_path, map_location=device)
    cfg = Config(**checkpoint["config"])

    # The network architecture is taken entirely from the saved config.
    unet_kwargs = {
        "image_size": cfg.image_size,
        "in_channels": cfg.in_channels,
        "base_channels": cfg.base_channels,
        "channel_mults": cfg.channel_mults,
        "num_res_blocks": cfg.num_res_blocks,
        "attn_resolutions": cfg.attn_resolutions,
        "time_embed_dim": cfg.time_embed_dim,
        "dropout": cfg.dropout,
    }
    model = UNet(**unet_kwargs).to(device)

    # Fall back to the raw weights when EMA is not wanted or not stored.
    if not prefer_ema or checkpoint.get("ema") is None:
        state, notice = checkpoint["model"], "[sample] loaded raw weights"
    else:
        state, notice = checkpoint["ema"], "[sample] loaded EMA weights"
    model.load_state_dict(state, strict=True)
    print(notice)
    model.eval()

    diffusion = GaussianDiffusion(
        timesteps=cfg.timesteps,
        beta_start=cfg.beta_start,
        beta_end=cfg.beta_end,
        schedule=cfg.beta_schedule,
    )
    diffusion = diffusion.to(device)
    return cfg, model, diffusion
# ---------------------------------------------------------------------------
def _int_at_least(minimum: int):
    """Build an argparse ``type`` callable that accepts integers >= *minimum*."""

    def _parse(text: str) -> int:
        try:
            value = int(text)
        except ValueError:
            raise argparse.ArgumentTypeError(f"invalid int value: {text!r}") from None
        if value < minimum:
            raise argparse.ArgumentTypeError(f"must be >= {minimum}, got {value}")
        return value

    return _parse


def parse_args(argv: Optional[list[str]] = None):
    """Parse the command-line options of the sampling script.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        The ``argparse.Namespace`` with attributes ``ckpt``, ``num``,
        ``steps``, ``eta``, ``seed``, ``out``, ``no_ema``, ``device``,
        ``trajectory``, ``interpolate`` and ``compare_ddpm``.

    Raises:
        SystemExit: Raised by argparse (exit code 2) on a missing ``--ckpt``
            or an out-of-range integer option.
    """
    p = argparse.ArgumentParser(
        description="Generate samples from a trained diffusion checkpoint.",
    )
    p.add_argument("--ckpt", required=True, help="checkpoint file to load")
    p.add_argument("--num", type=_int_at_least(1), default=16,
                   help="number of images to generate (default: 16)")
    p.add_argument("--steps", type=_int_at_least(1), default=50,
                   help="number of sampling steps (default: 50)")
    p.add_argument("--eta", type=float, default=0.0,
                   help="eta passed to the DDIM sampler (default: 0.0)")
    p.add_argument("--seed", type=int, default=None,
                   help="seed for torch.manual_seed; unseeded when omitted")
    p.add_argument("--out", type=str, default=None,
                   help="output file; a mode-specific default name is used when omitted")
    p.add_argument("--no-ema", action="store_true",
                   help="load the raw model weights instead of the EMA weights")
    p.add_argument("--device", type=str, default=None,
                   help="torch device string (e.g. 'cpu', 'mps'); auto-selected when omitted")
    # mode flags -- main() checks them in the order --interpolate,
    # --trajectory, --compare-ddpm and runs only the first one that is set.
    p.add_argument("--trajectory", action="store_true",
                   help="save denoising trajectory as a GIF")
    # 0 keeps interpolation disabled; negative counts are rejected instead of
    # being silently treated as "disabled".
    p.add_argument("--interpolate", type=_int_at_least(0), default=0,
                   help="number of interpolation frames between two latents")
    p.add_argument("--compare-ddpm", action="store_true",
                   help="generate DDIM-N vs DDPM-T side-by-side comparison")
    return p.parse_args(argv)
# ---------------------------------------------------------------------------
def _default_device() -> str:
    """Pick the device string used when --device is omitted: CUDA, then MPS, then CPU."""
    if torch.cuda.is_available():
        return "cuda"
    # Guarded lookup so a torch build without the MPS backend module does not crash here.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return "mps"
    return "cpu"


def _output_path(explicit: Optional[str], out_dir: str, default_name: str) -> str:
    """Return the file to write and make sure its parent directory exists."""
    path = explicit or os.path.join(out_dir, default_name)
    parent = os.path.dirname(path)
    if parent:  # a bare filename lands in the current directory
        os.makedirs(parent, exist_ok=True)
    return path


def main():
    """Run one sampling mode selected by the command-line flags.

    Loads the checkpoint, optionally seeds torch, then runs exactly one of
    (checked in this order): latent interpolation, trajectory GIF,
    DDIM-vs-DDPM comparison, or the default sample grid. The output path is
    resolved (and its directory created) before sampling starts, so an
    unwritable location fails before the expensive part.
    """
    args = parse_args()
    device = torch.device(args.device or _default_device())
    cfg, model, diffusion = load_run(args.ckpt, device, prefer_ema=not args.no_ema)
    print(f"[sample] image_size={cfg.image_size} run={cfg.run_name} device={device}")
    if args.seed is not None:
        torch.manual_seed(args.seed)

    shape = (args.num, cfg.in_channels, cfg.image_size, cfg.image_size)
    out_dir = cfg.sample_dir

    # Only the first requested mode runs; say so instead of ignoring the rest silently.
    requested = [
        flag
        for flag, enabled in (
            ("--interpolate", args.interpolate > 0),
            ("--trajectory", args.trajectory),
            ("--compare-ddpm", args.compare_ddpm),
        )
        if enabled
    ]
    if len(requested) > 1:
        print(f"[sample] warning: {', '.join(requested)} given together; "
              f"only {requested[0]} will run")

    # ---- interpolation -------------------------------------------------
    if args.interpolate > 0:
        n = args.interpolate
        out = _output_path(args.out, out_dir, f"interp_{n}.png")
        z1 = torch.randn(1, *shape[1:], device=device)
        z2 = torch.randn(1, *shape[1:], device=device)
        # assumes interpolate_latents returns (n, 1, C, H, W) so that
        # squeeze(1) yields (n, C, H, W) -- TODO confirm against utils.visualize
        latents = interpolate_latents(z1.cpu(), z2.cpu(), num_steps=n).squeeze(1).to(device)
        # All n frames are denoised together in a single batched call.
        with torch.no_grad():
            samples = diffusion.ddim_sample(
                model, (n, *shape[1:]), num_steps=args.steps, eta=args.eta,
                x_T=latents, device=device,
            )
        save_image_grid(samples.cpu(), out, nrow=n)
        print(f"[sample] interpolation saved -> {out}")
        return

    # ---- trajectory GIF ------------------------------------------------
    if args.trajectory:
        out = _output_path(args.out, out_dir, f"traj_{args.steps}.gif")
        x_T = torch.randn(shape, device=device)
        with torch.no_grad():
            _, traj = diffusion.ddim_sample(
                model, shape, num_steps=args.steps, eta=args.eta,
                x_T=x_T, device=device,
                return_trajectory=True, trajectory_stride=1,
            )
        trajectory_to_gif(traj, out, fps=10)
        print(f"[sample] trajectory saved -> {out}")
        return

    # ---- DDIM vs DDPM comparison --------------------------------------
    if args.compare_ddpm:
        out = _output_path(args.out, out_dir, f"compare_ddim{args.steps}_vs_ddpm.png")
        x_T = torch.randn(shape, device=device)
        with torch.no_grad():
            ddim = diffusion.ddim_sample(model, shape, num_steps=args.steps,
                                         eta=args.eta, x_T=x_T.clone(), device=device)
            # The "DDPM" row reuses ddim_sample with eta=1.0 over all
            # cfg.timesteps steps, starting from the same initial noise.
            ddpm = diffusion.ddim_sample(model, shape, num_steps=cfg.timesteps,
                                         eta=1.0, x_T=x_T.clone(), device=device)
        # stack as 2 rows: nrow=args.num puts each sampler on its own row
        side = torch.cat([ddim.cpu(), ddpm.cpu()], dim=0)
        save_image_grid(side, out, nrow=args.num)
        print(f"[sample] comparison saved -> {out} (top: DDIM-{args.steps}, bottom: DDPM-{cfg.timesteps})")
        return

    # ---- default: simple grid -----------------------------------------
    out = _output_path(args.out, out_dir, f"samples_n{args.num}_s{args.steps}.png")
    with torch.no_grad():
        samples = diffusion.ddim_sample(
            model, shape, num_steps=args.steps, eta=args.eta, device=device,
        )
    save_image_grid(samples.cpu(), out)
    print(f"[sample] grid saved -> {out}")


if __name__ == "__main__":
    main()