NoobNovel's picture
DDIM face generation — full project
0ca4c93
"""DDPM training loop with W&B logging, EMA, auto-resume from checkpoint.
Usage:
python3 train.py --image-size 64 --epochs 50
python3 train.py --image-size 256 --epochs 200 --resume
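The checkpoint policy: `<run_name>_latest.pt` is written every
`cfg.ckpt_every_epochs` epochs and when training is interrupted by SIGINT/SIGTERM;
`<run_name>_best.pt` is written at those same points whenever the epoch's average
loss improves on the best seen so far. With `--resume`, `<run_name>_latest.pt`
under `cfg.ckpt_dir` is loaded to restore model + optimizer + EMA + epoch + step +
best loss.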
"""
from __future__ import annotations
import argparse
import math
import os
import random
import signal
import sys
import time
from typing import Optional
# Apple Silicon: a few ops still fall back to CPU. Enable that by default.
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from config import Config, get_default_config
from models.unet import UNet
from models.diffusion import GaussianDiffusion, EMA, AdamW
from utils.dataset import make_dataloader, denormalize
from utils.visualize import make_grid, save_image_grid
# ---------------------------------------------------------------------------
# Util
# ---------------------------------------------------------------------------
def seed_everything(seed: int):
    """Seed the Python, NumPy and torch RNGs, plus all CUDA devices when CUDA is available."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
def build_model(cfg: Config) -> UNet:
    """Instantiate the UNet from the architecture fields of ``cfg``."""
    arch = dict(
        image_size=cfg.image_size,
        in_channels=cfg.in_channels,
        base_channels=cfg.base_channels,
        channel_mults=cfg.channel_mults,
        num_res_blocks=cfg.num_res_blocks,
        attn_resolutions=cfg.attn_resolutions,
        time_embed_dim=cfg.time_embed_dim,
        dropout=cfg.dropout,
    )
    return UNet(**arch)
def build_diffusion(cfg: Config) -> GaussianDiffusion:
    """Construct the GaussianDiffusion process from cfg's timestep and beta-schedule settings."""
    schedule_kwargs = {
        "timesteps": cfg.timesteps,
        "beta_start": cfg.beta_start,
        "beta_end": cfg.beta_end,
        "schedule": cfg.beta_schedule,
    }
    return GaussianDiffusion(**schedule_kwargs)
def latest_ckpt_path(cfg: Config) -> str:
    """Return ``<ckpt_dir>/<run_name>_latest.pt``, the rolling resume checkpoint."""
    filename = f"{cfg.run_name}_latest.pt"
    return os.path.join(cfg.ckpt_dir, filename)
def best_ckpt_path(cfg: Config) -> str:
    """Return ``<ckpt_dir>/<run_name>_best.pt``, the lowest-loss checkpoint."""
    filename = f"{cfg.run_name}_best.pt"
    return os.path.join(cfg.ckpt_dir, filename)
def save_checkpoint(path: str, *, model, optimizer, ema, epoch, step, best_loss, cfg: Config):
    """Write a resumable training checkpoint to ``path``.

    The payload holds the model/optimizer/EMA state dicts, epoch, step, best
    loss and ``cfg.to_dict()``. It is written with ``torch.save`` to
    ``path + ".tmp"`` and then moved over ``path`` with ``os.replace``. As a
    result, a write that dies midway leaves any previous file at ``path``
    untouched. Parent directories are created as needed. ``ema`` may be None,
    in which case ``None`` is stored under ``"ema"``.
    """
    target_dir = os.path.dirname(path) or "."
    os.makedirs(target_dir, exist_ok=True)
    payload = dict(
        model=model.state_dict(),
        optimizer=optimizer.state_dict(),
        ema=None if ema is None else ema.state_dict(),
        epoch=epoch,
        step=step,
        best_loss=best_loss,
        config=cfg.to_dict(),
    )
    staging_path = path + ".tmp"
    torch.save(payload, staging_path)
    os.replace(staging_path, path)
def load_checkpoint(path: str, *, model, optimizer, ema, device):
    """Restore training state from a checkpoint produced by ``save_checkpoint``.

    Tensors are mapped onto ``device``. The ``"model"`` entry is required.
    Optimizer and EMA state are restored only when both the target object and
    the saved entry are present. Returns the whole payload dict so callers can
    read bookkeeping fields such as ``epoch``, ``step`` and ``best_loss``.
    """
    # NOTE(review): on recent PyTorch releases torch.load defaults to
    # weights_only=True; that works only while the payload (incl. cfg.to_dict())
    # contains plain Python / tensor types — confirm for the Config in use.
    state = torch.load(path, map_location=device)
    model.load_state_dict(state["model"])
    restore_optim = optimizer is not None and "optimizer" in state
    if restore_optim:
        optimizer.load_state_dict(state["optimizer"])
    if ema is None:
        return state
    ema_state = state.get("ema")
    if ema_state is not None:
        ema.load_state_dict(ema_state)
    return state
# ---------------------------------------------------------------------------
# Main loop
# ---------------------------------------------------------------------------
def _init_wandb(cfg: Config, enabled: bool):
    """Start a Weights & Biases run when ``enabled``.

    Returns ``(wandb_module, run)`` on success and ``(None, None)`` when W&B is
    disabled. Any exception raised while importing or initialising W&B is
    printed and W&B stays off; training continues without it.
    """
    if not enabled:
        return None, None
    try:
        import wandb

        run = wandb.init(
            project=cfg.wandb_project,
            name=cfg.run_name,
            config=cfg.to_dict(),
            resume="allow",
        )
    except Exception as e:  # noqa: BLE001 - experiment logging is optional, never fatal
        print(f"[train] wandb disabled ({e})")
        return None, None
    return wandb, run


def _sample_grid(cfg: Config, diffusion, ema, shape, x_T, device, epoch: int) -> str:
    """Render a DDIM sample grid from the EMA weights and return the PNG path.

    The EMA weights are copied into a freshly built UNet, so the training model
    is not touched. ``x_T`` is cloned in case ``ddim_sample`` modifies its input
    in place; this keeps the caller's fixed noise identical across grids.
    """
    ema_model = build_model(cfg).to(device)
    ema.copy_to(ema_model)
    ema_model.eval()
    with torch.no_grad():
        samples = diffusion.ddim_sample(
            ema_model, shape, num_steps=cfg.ddim_steps, eta=cfg.ddim_eta,
            x_T=x_T.clone(), device=device,
        )
    sample_path = os.path.join(cfg.sample_dir, f"{cfg.run_name}_epoch{epoch:04d}.png")
    save_image_grid(samples.cpu(), sample_path, nrow=4)
    return sample_path


def train(cfg: Config, args):
    """Run the training loop described by ``cfg``.

    Args:
        cfg: stage configuration (data, model, diffusion, optimiser, logging and
            checkpoint settings).
        args: parsed CLI namespace; ``limit``, ``resume`` and ``no_wandb`` are read.

    Checkpoints:
        ``<run_name>_latest.pt`` is written every ``cfg.ckpt_every_epochs``
        epochs, after the final epoch, and when SIGINT/SIGTERM is received.
        ``<run_name>_best.pt`` is written at those same points whenever the
        epoch's average loss beats the best so far.

    Signals:
        A first SIGINT/SIGTERM finishes the current batch, saves and exits. A
        second one raises ``KeyboardInterrupt`` immediately. The previous signal
        handlers are restored on return.
    """
    seed_everything(cfg.seed)
    device = torch.device(cfg.device)
    print(f"[train] device={device} run={cfg.run_name} image={cfg.image_size}")

    # ---- data --------------------------------------------------------
    loader = make_dataloader(
        root=cfg.data_dir,
        image_size=cfg.image_size,
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
        augment=True,
        limit=args.limit,
    )
    print(f"[train] dataset images={len(loader.dataset)} batches/epoch={len(loader)}")

    # ---- model + diffusion ------------------------------------------
    model = build_model(cfg).to(device)
    diffusion = build_diffusion(cfg).to(device)
    optimizer = AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    ema = EMA(model, decay=cfg.ema_decay)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"[train] model params={n_params/1e6:.1f}M")

    # ---- resume ------------------------------------------------------
    start_epoch = 0
    global_step = 0
    best_loss = math.inf
    ckpt = latest_ckpt_path(cfg)
    if args.resume and os.path.isfile(ckpt):
        payload = load_checkpoint(ckpt, model=model, optimizer=optimizer, ema=ema, device=device)
        start_epoch = payload.get("epoch", 0) + 1
        global_step = payload.get("step", 0)
        best_loss = payload.get("best_loss", math.inf)
        print(f"[train] resumed from {ckpt} at epoch={start_epoch} step={global_step}")
    elif args.resume:
        print(f"[train] --resume: no checkpoint at {ckpt}, starting from scratch")

    # ---- W&B ---------------------------------------------------------
    wb, wandb_run = _init_wandb(cfg, cfg.use_wandb and not args.no_wandb)
    use_wandb = wb is not None

    # ---- graceful shutdown so we always save latest ------------------
    interrupted = {"flag": False}

    def _on_sig(signum, frame):
        if interrupted["flag"]:
            # Second signal: abort now instead of swallowing every further Ctrl-C.
            raise KeyboardInterrupt
        interrupted["flag"] = True
        print("\n[train] caught signal, finishing batch then saving checkpoint...")

    # signal.signal returns the handler it replaced; restored in `finally` below.
    prev_handlers = {sig: signal.signal(sig, _on_sig) for sig in (signal.SIGINT, signal.SIGTERM)}

    # ---- training ----------------------------------------------------
    sample_shape = (min(16, cfg.batch_size), cfg.in_channels, cfg.image_size, cfg.image_size)
    fixed_noise = torch.randn(sample_shape, device=device)
    try:
        for epoch in range(start_epoch, cfg.epochs):
            model.train()
            epoch_loss = 0.0
            epoch_count = 0
            pbar = tqdm(loader, desc=f"epoch {epoch}", dynamic_ncols=True)
            for batch in pbar:
                batch = batch.to(device, non_blocking=True)
                loss = diffusion.training_loss(model, batch)
                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                if cfg.grad_clip:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
                optimizer.step()
                ema.update(model)
                global_step += 1
                loss_v = loss.item()
                epoch_loss += loss_v
                epoch_count += 1
                pbar.set_postfix(loss=f"{loss_v:.4f}", step=global_step)
                if use_wandb and global_step % cfg.log_every == 0:
                    wb.log({"loss": loss_v, "epoch": epoch}, step=global_step)
                if interrupted["flag"]:
                    break

            # With zero batches there is no loss. NaN compares False against
            # best_loss, so an empty epoch can never be recorded as the "best"
            # (previously 0.0 was).
            avg_loss = epoch_loss / epoch_count if epoch_count else math.nan
            print(f"[train] epoch {epoch} avg_loss={avg_loss:.4f}")
            if use_wandb:
                wb.log({"epoch_loss": avg_loss, "epoch": epoch}, step=global_step)

            # Read the flag once so the "save" decision and the "exit" decision agree.
            stopping = interrupted["flag"]

            # ---- checkpoint: before sampling, so a sampling failure or a
            # hard kill that may follow SIGTERM cannot cost this epoch --------
            is_last_epoch = epoch == cfg.epochs - 1
            if (epoch + 1) % cfg.ckpt_every_epochs == 0 or is_last_epoch or stopping:
                # Update best_loss BEFORE writing latest, so latest carries the
                # true best. Otherwise, after a resume, a worse epoch could
                # overwrite best.pt.
                improved = avg_loss < best_loss
                if improved:
                    best_loss = avg_loss
                save_checkpoint(
                    latest_ckpt_path(cfg),
                    model=model, optimizer=optimizer, ema=ema,
                    epoch=epoch, step=global_step,
                    best_loss=best_loss, cfg=cfg,
                )
                if improved:
                    save_checkpoint(
                        best_ckpt_path(cfg),
                        model=model, optimizer=optimizer, ema=ema,
                        epoch=epoch, step=global_step,
                        best_loss=best_loss, cfg=cfg,
                    )
            if stopping:
                print("[train] saved and exiting")
                break

            # ---- sample grid at every Nth epoch -------------------------
            if (epoch + 1) % cfg.sample_every_epochs == 0 or epoch == 0:
                sample_path = _sample_grid(
                    cfg, diffusion, ema, sample_shape, fixed_noise, device, epoch,
                )
                if use_wandb:
                    wb.log({"samples": wb.Image(sample_path)}, step=global_step)

        if wandb_run is not None:
            wandb_run.finish()
    finally:
        for sig, handler in prev_handlers.items():
            # None means the old handler was not installed from Python; it
            # cannot be passed back to signal.signal.
            if handler is not None:
                signal.signal(sig, handler)
# ---------------------------------------------------------------------------
def parse_args():
    """Parse the training CLI.

    Override flags left unset stay ``None``, so the stage config's defaults
    apply to them.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--image-size", type=int, default=64, choices=[64, 128, 256])
    # Optional numeric overrides; None means "keep the stage default".
    for flag, kind in (
        ("--epochs", int),
        ("--batch-size", int),
        ("--lr", float),
        ("--num-workers", int),
    ):
        parser.add_argument(flag, type=kind, default=None)
    parser.add_argument("--limit", type=int, default=None, help="cap dataset size (smoke tests)")
    parser.add_argument("--resume", action="store_true", help="auto-load <run>_latest.pt if present")
    parser.add_argument("--no-wandb", action="store_true")
    parser.add_argument("--run-name", type=str, default=None)
    return parser.parse_args()
def main():
    """CLI entry point: parse flags, build the stage config, create output dirs, then train."""
    args = parse_args()
    # Only flags the user actually supplied override the stage defaults.
    overrides = {}
    for key in ("epochs", "batch_size", "lr", "num_workers", "run_name"):
        value = getattr(args, key)
        if value is not None:
            overrides[key] = value
    cfg = Config.for_stage(args.image_size, **overrides)
    for directory in (cfg.ckpt_dir, cfg.sample_dir):
        os.makedirs(directory, exist_ok=True)
    train(cfg, args)


if __name__ == "__main__":
    main()