File size: 6,444 Bytes
0ca4c93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
"""Inference: load a checkpoint and generate samples / trajectory / interp grid.

Usage:
    # 16 random faces with DDIM 50 steps
    python3 sample.py --ckpt checkpoints/stage-256_best.pt --num 16 --steps 50

    # save denoising trajectory as a GIF
    python3 sample.py --ckpt checkpoints/stage-256_best.pt --trajectory \
        --num 4 --steps 50 --out samples/traj.gif

    # interpolate between two random latents (8 frames, slerp)
    python3 sample.py --ckpt checkpoints/stage-256_best.pt --interpolate 8 \
        --out samples/interp.png

    # DDPM-1000 vs DDIM-50 side-by-side
    python3 sample.py --ckpt ... --compare-ddpm --num 4
"""
from __future__ import annotations

import argparse
import os
from typing import Optional

# Deliberately placed before `import torch` below. `setdefault` leaves any
# value the user already exported untouched. Per PyTorch's MPS notes, this
# flag enables CPU fallback for operators the MPS backend does not implement.
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

import torch

from config import Config
from models.unet import UNet
from models.diffusion import GaussianDiffusion, EMA
from utils.visualize import (save_image_grid, trajectory_to_gif,
                             interpolate_latents, make_grid)
from PIL import Image


# ---------------------------------------------------------------------------
def load_run(ckpt_path: str, device: torch.device, prefer_ema: bool = True):
    """Rebuild the model and diffusion process stored in a checkpoint.

    The checkpoint is read with ``torch.load`` and must be a mapping with a
    ``"config"`` entry (keyword arguments for ``Config``). Weights come from
    the ``"ema"`` entry when ``prefer_ema`` is true and that entry is present
    and not ``None``; otherwise from the ``"model"`` entry.

    Args:
        ckpt_path: path of the checkpoint file.
        device: device used as ``map_location`` and as the target of ``.to``.
        prefer_ema: use the ``"ema"`` weights when the checkpoint has them.

    Returns:
        A ``(cfg, model, diffusion)`` tuple; ``model`` is in eval mode.
    """
    # NOTE(review): depending on the torch version's `weights_only` default,
    # torch.load may unpickle arbitrary objects -- only load trusted files.
    checkpoint = torch.load(ckpt_path, map_location=device)
    cfg = Config(**checkpoint["config"])

    unet_kwargs = dict(
        image_size=cfg.image_size,
        in_channels=cfg.in_channels,
        base_channels=cfg.base_channels,
        channel_mults=cfg.channel_mults,
        num_res_blocks=cfg.num_res_blocks,
        attn_resolutions=cfg.attn_resolutions,
        time_embed_dim=cfg.time_embed_dim,
        dropout=cfg.dropout,
    )
    model = UNet(**unet_kwargs).to(device)

    # Only look at the EMA entry when the caller asked for it.
    ema_state = checkpoint.get("ema") if prefer_ema else None
    if ema_state is None:
        weights = checkpoint["model"]
        message = "[sample] loaded raw weights"
    else:
        weights = ema_state
        message = "[sample] loaded EMA weights"
    model.load_state_dict(weights, strict=True)
    print(message)
    model.eval()

    schedule_kwargs = dict(
        timesteps=cfg.timesteps,
        beta_start=cfg.beta_start,
        beta_end=cfg.beta_end,
        schedule=cfg.beta_schedule,
    )
    diffusion = GaussianDiffusion(**schedule_kwargs).to(device)
    return cfg, model, diffusion


# ---------------------------------------------------------------------------
def parse_args(argv: Optional[list[str]] = None):
    """Parse the command-line options of the sampling script.

    Args:
        argv: argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original behaviour;
            passing a list allows the parser to be driven programmatically.

    Returns:
        An ``argparse.Namespace`` with the attributes ``ckpt``, ``num``,
        ``steps``, ``eta``, ``seed``, ``out``, ``no_ema``, ``device``,
        ``trajectory``, ``interpolate`` and ``compare_ddpm``.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--ckpt", required=True,
                   help="path of the checkpoint file to load")
    p.add_argument("--num", type=int, default=16,
                   help="number of images to sample (batch size)")
    p.add_argument("--steps", type=int, default=50,
                   help="number of DDIM sampling steps")
    p.add_argument("--eta", type=float, default=0.0,
                   help="eta value passed to the DDIM sampler")
    p.add_argument("--seed", type=int, default=None,
                   help="seed for torch.manual_seed (unseeded when omitted)")
    p.add_argument("--out", type=str, default=None,
                   help="output file path (default: a name inside the "
                        "checkpoint config's sample directory)")
    p.add_argument("--no-ema", action="store_true",
                   help="load the raw model weights instead of EMA weights")
    p.add_argument("--device", type=str, default=None,
                   help="torch device string (default: mps if available, "
                        "otherwise cpu)")
    # mode flags
    p.add_argument("--trajectory", action="store_true",
                   help="save denoising trajectory as a GIF")
    p.add_argument("--interpolate", type=int, default=0,
                   help="number of interpolation frames between two latents")
    p.add_argument("--compare-ddpm", action="store_true",
                   help="generate DDIM-N vs DDPM-T side-by-side comparison")
    return p.parse_args(argv)


# ---------------------------------------------------------------------------
def _resolve_output_path(requested: Optional[str], out_dir: str, default_name: str) -> str:
    """Return the file path to write, creating its parent directory if needed."""
    path = requested or os.path.join(out_dir, default_name)
    parent = os.path.dirname(path)
    if parent:  # empty when the path is a bare file name in the cwd
        os.makedirs(parent, exist_ok=True)
    return path


def main():
    """Command-line entry point: load a checkpoint and write one output file.

    Exactly one mode runs per invocation, checked in this order:
    ``--interpolate N`` (N > 0), ``--trajectory``, ``--compare-ddpm``, and
    finally the default sample grid. The output goes to ``--out`` when given,
    otherwise to a mode-specific file name inside ``cfg.sample_dir``.

    The output path is resolved, and its parent directory created, before
    sampling starts, so an ``--out`` in a directory that does not exist yet
    no longer fails only after sampling has finished.
    """
    args = parse_args()
    device = torch.device(args.device or ("mps" if torch.backends.mps.is_available() else "cpu"))
    cfg, model, diffusion = load_run(args.ckpt, device, prefer_ema=not args.no_ema)
    print(f"[sample] image_size={cfg.image_size} run={cfg.run_name} device={device}")

    if args.seed is not None:
        torch.manual_seed(args.seed)

    shape = (args.num, cfg.in_channels, cfg.image_size, cfg.image_size)
    out_dir = cfg.sample_dir
    os.makedirs(out_dir, exist_ok=True)

    # ---- interpolation -------------------------------------------------
    if args.interpolate > 0:
        n = args.interpolate
        out = _resolve_output_path(args.out, out_dir, f"interp_{n}.png")
        z1 = torch.randn(1, *shape[1:], device=device)
        z2 = torch.randn(1, *shape[1:], device=device)
        # Interpolation runs on CPU copies; the result is moved back to `device`.
        latents = interpolate_latents(z1.cpu(), z2.cpu(), num_steps=n).squeeze(1).to(device)
        # latents is expected to be (n, C, H, W), matching the shape passed
        # below; all n frames are denoised together in one batched call.
        with torch.no_grad():
            samples = diffusion.ddim_sample(
                model, (n, *shape[1:]), num_steps=args.steps, eta=args.eta,
                x_T=latents, device=device,
            )
        save_image_grid(samples.cpu(), out, nrow=n)
        print(f"[sample] interpolation saved -> {out}")
        return

    # ---- trajectory GIF ------------------------------------------------
    if args.trajectory:
        out = _resolve_output_path(args.out, out_dir, f"traj_{args.steps}.gif")
        x_T = torch.randn(shape, device=device)
        with torch.no_grad():
            # Only the trajectory is used; the final sample is discarded.
            _, traj = diffusion.ddim_sample(
                model, shape, num_steps=args.steps, eta=args.eta,
                x_T=x_T, device=device,
                return_trajectory=True, trajectory_stride=1,
            )
        trajectory_to_gif(traj, out, fps=10)
        print(f"[sample] trajectory saved -> {out}")
        return

    # ---- DDIM vs DDPM comparison --------------------------------------
    if args.compare_ddpm:
        out = _resolve_output_path(
            args.out, out_dir, f"compare_ddim{args.steps}_vs_ddpm.png")
        # Both runs start from the same noise (cloned per call). The "DDPM"
        # row is produced by the DDIM sampler run for cfg.timesteps steps
        # with eta=1.0.
        x_T = torch.randn(shape, device=device)
        with torch.no_grad():
            ddim = diffusion.ddim_sample(model, shape, num_steps=args.steps,
                                         eta=args.eta, x_T=x_T.clone(), device=device)
            ddpm = diffusion.ddim_sample(model, shape, num_steps=cfg.timesteps,
                                         eta=1.0, x_T=x_T.clone(), device=device)
        # stack as 2 rows
        side = torch.cat([ddim.cpu(), ddpm.cpu()], dim=0)
        save_image_grid(side, out, nrow=args.num)
        print(f"[sample] comparison saved -> {out}  (top: DDIM-{args.steps}, bottom: DDPM-{cfg.timesteps})")
        return

    # ---- default: simple grid -----------------------------------------
    out = _resolve_output_path(
        args.out, out_dir, f"samples_n{args.num}_s{args.steps}.png")
    with torch.no_grad():
        samples = diffusion.ddim_sample(
            model, shape, num_steps=args.steps, eta=args.eta, device=device,
        )
    save_image_grid(samples.cpu(), out)
    print(f"[sample] grid saved -> {out}")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()