| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import os |
| import time |
| import math |
| import yaml |
| import torch |
| from torch.optim import Adam |
| from tqdm.auto import tqdm |
| from torchvision import datasets, transforms |
| from torchvision.utils import save_image |
| from ema_pytorch import EMA |
| import wandb |
| import numpy as np |
| from torchvision.utils import make_grid |
|
|
|
|
| |
| try: |
| from cleanfid import fid as clean_fid |
| HAS_CLEANFID = True |
| except Exception: |
| HAS_CLEANFID = False |
|
|
| try: |
| from torch_fidelity import calculate_metrics as tf_calculate_metrics |
| HAS_TORCH_FIDELITY = True |
| except Exception: |
| HAS_TORCH_FIDELITY = False |
|
|
| from unet import UNet |
| from diffusion import GaussianDiffusion |
|
|
| |
| |
| |
|
|
|
|
def frames_to_wandb_video(frames, nrow=8, fps=6):
    """
    Convert a list of [B,C,H,W] tensors (values in [0,1]) into a W&B Video.

    Each time step is tiled into a single grid image (``nrow`` images per
    row), scaled to uint8 HxWxC, and the grids are stacked along a leading
    time axis to form the (T,H,W,C) array that ``wandb.Video`` expects.
    """
    grids = []
    for frame in frames:
        # Clamp defensively before quantizing so out-of-range values
        # don't wrap when cast to uint8.
        tiled = make_grid(frame.clamp(0, 1), nrow=nrow)
        as_uint8 = (tiled * 255.0).byte().cpu().numpy()
        grids.append(np.transpose(as_uint8, (1, 2, 0)))
    return wandb.Video(np.stack(grids, axis=0), fps=fps, format="mp4")
| |
| |
| |
|
|
|
|
def maybe_enable_cuda_speedups(cfg):
    """Opt into CUDA fast paths when a GPU is present.

    Enables TF32 matmul/cuDNN kernels (unless ``cfg["compute"]["enable_tf32"]``
    is explicitly False) and cuDNN autotuning. No-op on CPU-only hosts.
    """
    if not torch.cuda.is_available():
        return
    tf32_ok = cfg.get("compute", {}).get("enable_tf32", True)
    if tf32_ok:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
    # Autotune conv algorithms for the (fixed) input shapes used here.
    torch.backends.cudnn.benchmark = True
|
|
| |
| |
| |
|
|
|
|
def get_loader_mnist(bs, nw, img_size):
    """Build a shuffled DataLoader over the MNIST train split.

    Images are resized to ``img_size`` and converted to float32 tensors
    in [0,1]; the dataset is downloaded to ./data on first use.
    """
    pipeline = transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),
        transforms.ConvertImageDtype(torch.float32),
    ])
    train_set = datasets.MNIST(root="./data", train=True,
                               download=True, transform=pipeline)
    return torch.utils.data.DataLoader(
        train_set, batch_size=bs, shuffle=True,
        num_workers=nw, pin_memory=True)
|
|
| |
| |
| |
|
|
|
|
def log_global_grad_norm_sparsely(model, step, every=1000):
    """
    Logs a single scalar 'train/global_grad_norm' every `every` steps.

    The global norm is the L2 norm of the per-parameter gradient norms;
    parameters without gradients are skipped, and nothing is logged when
    no gradients exist at all.
    """
    if step % every:
        return
    with torch.no_grad():
        per_param = [p.grad.norm().item()
                     for p in model.parameters() if p.grad is not None]
    if not per_param:
        return
    total = float(torch.tensor(per_param).norm().item())
    wandb.log({"train/global_grad_norm": total, "step": step}, step=step)
|
|
|
|
| |
| |
| |
def ensure_real_ref_folder(dl, out_dir, max_images=50000, img_size=32, force_rgb=False):
    """
    Exports up to `max_images` real images from the dataloader to `out_dir`
    in PNG format for FID reference.

    - MNIST is 1-channel; some FID/IS tools expect 3-channel -> set force_rgb=True to replicate channels.
    - Images are already [0,1] tensors from dataloader.
    - `img_size` is accepted for interface compatibility but is not used
      here (the dataloader already delivers resized images).
    """
    os.makedirs(out_dir, exist_ok=True)

    # Treat an existing dump of at least 10% of the target as "already
    # exported" and skip the (slow) re-export.
    already = sum(1 for name in os.listdir(out_dir)
                  if name.lower().endswith(".png"))
    if already >= max_images // 10:
        return

    written = 0
    for batch, _ in dl:
        if force_rgb and batch.shape[1] == 1:
            batch = batch.repeat(1, 3, 1, 1)
        for img in batch:
            save_image(img, os.path.join(out_dir, f"{written:06d}.png"))
            written += 1
            if written >= max_images:
                return
|
|
| |
| |
| |
|
|
|
|
@torch.inference_mode()
def generate_images_to_folder(model, n_images=5000, batch_size=64, out_dir="./gen_eval", force_rgb=True):
    """
    Uses the (EMA) diffusion sampler to generate `n_images` and save as PNGs.
    Optionally tile grayscale to RGB to satisfy metric toolchains.

    Files are named 000000.png, 000001.png, ... in generation order.
    """
    os.makedirs(out_dir, exist_ok=True)
    written = 0
    while written < n_images:
        take = min(batch_size, n_images - written)
        batch = model.sample(take)
        if force_rgb and batch.shape[1] == 1:
            batch = batch.repeat(1, 3, 1, 1)
        for offset in range(take):
            save_image(batch[offset],
                       os.path.join(out_dir, f"{written + offset:06d}.png"))
        written += take
|
|
| |
| |
| |
|
|
|
|
def compute_fid_cleanfid(gen_dir, real_dir):
    """Compute FID between two image folders via clean-fid.

    Returns the score as a float, or None when clean-fid is unavailable
    or the computation fails (errors are printed, never raised).
    """
    if not HAS_CLEANFID:
        print("[metrics] clean-fid not installed; skip FID.")
        return None
    try:
        return float(clean_fid.compute_fid(gen_dir, real_dir))
    except Exception as e:
        print("[metrics] clean-fid error:", e)
        return None
|
|
|
|
def compute_inception_score_torchfidelity(gen_dir, cuda=True):
    """Compute the Inception Score of a folder of images via torch-fidelity.

    Returns (mean, std) as floats, or (None, None) when torch-fidelity is
    unavailable or the computation fails (errors are printed, never raised).
    """
    if not HAS_TORCH_FIDELITY:
        print("[metrics] torch-fidelity not installed; skip IS.")
        return None, None
    try:
        metrics = tf_calculate_metrics(
            input1=gen_dir,
            cuda=cuda and torch.cuda.is_available(),
            isc=True, fid=False, kid=False, prc=False
        )
        # Missing keys degrade to NaN rather than raising.
        mean = float(metrics.get("inception_score_mean", float("nan")))
        std = float(metrics.get("inception_score_std", float("nan")))
        return mean, std
    except Exception as e:
        print("[metrics] torch-fidelity error:", e)
        return None, None
|
|
| |
| |
| |
|
|
|
|
def main(cfg_path="config_mnist_small.yaml", seed=42):
    """Train a Gaussian diffusion model (UNet backbone) on MNIST.

    Responsibilities:
      - load the YAML config and seed torch,
      - optionally set up a Weights & Biases run,
      - build the dataloader, UNet, diffusion wrapper, Adam optimizer, EMA,
      - run the training loop with gradient accumulation and clipping,
      - periodically: log loss/throughput, sample an image grid, record
        forward/reverse trajectory videos, checkpoint, and compute FID/IS.

    Args:
        cfg_path: path to the YAML training config.
        seed: torch RNG seed for reproducibility.

    Fixes vs. previous revision:
      - the periodic sampling block was duplicated verbatim (sampled twice,
        overwrote the same PNG, logged the same grid twice); it now runs once.
      - the config file is opened with a context manager so the handle is
        closed deterministically.
    """
    torch.manual_seed(seed)

    # -- config & output dirs ---------------------------------------------
    with open(cfg_path) as cfg_file:
        cfg = yaml.safe_load(cfg_file)
    os.makedirs(cfg["train"]["ckpt_dir"], exist_ok=True)
    os.makedirs("./samples", exist_ok=True)
    maybe_enable_cuda_speedups(cfg)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # -- Weights & Biases --------------------------------------------------
    run = None
    if cfg["wandb"]["enabled"]:
        if cfg["wandb"].get("mode", "online") == "offline":
            os.environ["WANDB_MODE"] = "offline"
        wandb.login()
        run = wandb.init(
            project=cfg["project"],
            name=cfg["run_name"],
            config=cfg,
            tags=cfg["wandb"].get("tags", [])
        )
        # Mirror the key diffusion hyper-parameters under a stable prefix.
        wandb.config.update({
            "hparams/T": cfg["diffusion"]["T"],
            "hparams/beta_schedule": cfg["diffusion"]["beta_schedule"],
            "hparams/sampling_steps": cfg["diffusion"]["sampling_steps"],
            "hparams/eta": cfg["diffusion"]["eta"],
        }, allow_val_change=True)

    # -- data ---------------------------------------------------------------
    dl = get_loader_mnist(cfg["data"]["batch_size"],
                          cfg["data"]["num_workers"], cfg["data"]["image_size"])

    # -- model + diffusion wrapper ------------------------------------------
    unet = UNet(
        dim=cfg["model"]["dim"],
        dim_mults=tuple(cfg["model"]["dim_mults"]),
        channels=cfg["model"]["channels"],
        attn_heads=cfg["model"]["attn_heads"],
        attn_dim_head=cfg["model"]["attn_dim_head"],
        dropout=cfg["model"]["dropout"],
        self_condition=cfg["model"]["self_condition"],
        learned_variance=cfg["model"]["learned_variance"],
        outer_attn=cfg["model"]["outer_attn"],
    ).to(device)

    diffusion = GaussianDiffusion(
        unet,
        image_size=(cfg["data"]["image_size"], cfg["data"]["image_size"]),
        timesteps=cfg["diffusion"]["T"],
        beta_schedule=cfg["diffusion"]["beta_schedule"],
        objective=cfg["diffusion"]["objective"],
        sampling_steps=cfg["diffusion"]["sampling_steps"],
        eta=cfg["diffusion"]["eta"],
        self_condition=cfg["diffusion"]["self_condition"],
        auto_normalize=True,
        clamp_x0=cfg["diffusion"]["clamp_x0"]
    ).to(device)

    # -- optimizer & EMA ----------------------------------------------------
    opt = Adam(diffusion.parameters(),
               lr=cfg["opt"]["lr"], betas=tuple(cfg["opt"]["betas"]))

    ema = None
    if cfg.get("ema", {}).get("enabled", True):
        ema = EMA(diffusion, beta=cfg["ema"]["decay"],
                  update_every=cfg["ema"]["update_every"])
        ema.to(device)

    # -- schedule knobs -----------------------------------------------------
    max_steps = int(cfg["train"]["max_steps"])
    log_every = int(cfg["train"]["log_every"])
    grad_accum = int(cfg["train"].get("grad_accum", 1))

    global_norm_every = int(
        cfg.get("metrics", {}).get("global_norm_every", 1000))

    enable_fid = bool(cfg.get("metrics", {}).get("enable_fid", False))
    enable_is = bool(cfg.get("metrics", {}).get("enable_is", False))
    fid_every = int(cfg.get("metrics", {}).get("fid_every", 4000))
    is_every = int(cfg.get("metrics", {}).get("is_every", 4000))
    metric_n_gen = int(cfg.get("metrics", {}).get("metric_num_gen", 5000))
    metric_bs = int(cfg.get("metrics", {}).get("metric_batch_size", 64))

    step = 0
    pbar = tqdm(total=max_steps, desc="training")
    opt.zero_grad(set_to_none=True)

    last_log_time = time.perf_counter()
    last_log_step = 0

    while step < max_steps:
        for x, _ in dl:
            x = x.to(device, non_blocking=True).float()

            # Scale the loss so gradients average correctly over the
            # accumulation window.
            loss = diffusion(x) / grad_accum
            loss.backward()

            if ((step + 1) % grad_accum) == 0:
                torch.nn.utils.clip_grad_norm_(
                    diffusion.parameters(), cfg["opt"]["grad_clip"])
                opt.step()
                opt.zero_grad(set_to_none=True)
                if ema is not None:
                    ema.update()

            step += 1
            pbar.update(1)

            # -- scalar logging: loss + throughput ------------------------
            if run and step % log_every == 0:
                now = time.perf_counter()
                delta_t = max(now - last_log_time, 1e-6)
                delta_s = step - last_log_step
                ips = delta_s / delta_t
                ms_per_iter = 1000.0 / max(ips, 1e-9)

                wandb.log({
                    "train/loss": float(loss.item() * grad_accum),
                    "speed/iter_per_sec": ips,
                    "speed/ms_per_iter": ms_per_iter,
                    "step": step
                }, step=step)

                last_log_time = now
                last_log_step = step

            if run:
                log_global_grad_norm_sparsely(
                    diffusion, step, every=global_norm_every)

            # -- periodic sampling + trajectory visualizations ------------
            if step % int(cfg["diffusion"]["sample_every"]) == 0:
                diffusion.eval()
                with torch.inference_mode():
                    sampler = ema.ema_model if ema is not None else diffusion
                    t0 = time.perf_counter()
                    samples = sampler.sample(cfg["diffusion"]["sample_n"])
                    t1 = time.perf_counter()
                    path = f"./samples/mnist_step_{step}.png"
                    save_image(samples, path, nrow=8)

                    dt = max(t1 - t0, 1e-6)
                    imgs_per_sec = cfg["diffusion"]["sample_n"] / dt

                    if run:
                        wandb.log({
                            "samples_grid": wandb.Image(path),
                            "speed/sampling_imgs_per_sec": imgs_per_sec,
                            "speed/sampling_sec": dt,
                            "step": step
                        }, step=step)

                    # reverse (denoising) trajectory video
                    if cfg.get("viz", {}).get("enable_reverse_traj", False) \
                            and step % int(cfg["viz"]["reverse_every_steps"]) == 0:
                        B = int(cfg["viz"]["reverse_batch_n"])
                        C = diffusion.channels
                        H = W = cfg["data"]["image_size"]
                        _, frames_xt, _ = sampler.ddpm_sample_trajectory(
                            shape=(B, C, H, W),
                            record_every=int(
                                cfg["viz"]["reverse_record_every"]),
                            return_x0=False
                        )
                        video = frames_to_wandb_video(
                            frames_xt, nrow=min(8, B), fps=int(cfg["viz"]["video_fps"]))
                        if run:
                            wandb.log({"viz/reverse_xt": video,
                                       "step": step}, step=step)

                    # forward (noising) trajectory video on the current batch
                    if cfg.get("viz", {}).get("enable_forward_traj", False) \
                            and step % int(cfg["viz"]["forward_every_steps"]) == 0:
                        Bf = int(cfg["viz"]["forward_batch_n"])
                        x0_vis = x[:Bf].detach().cpu()
                        t_vals = cfg["viz"]["forward_t_values"]
                        frames_fwd = diffusion.forward_noising_trajectory(
                            x0=x0_vis.to(device), t_values=t_vals
                        )
                        video_fwd = frames_to_wandb_video(
                            frames_fwd, nrow=min(8, Bf), fps=int(cfg["viz"]["video_fps"]))
                        if run:
                            wandb.log({"viz/forward_xt": video_fwd,
                                       "step": step}, step=step)

                diffusion.train()

            # -- checkpointing (every 5th sampling interval) ---------------
            if step % (5 * int(cfg["diffusion"]["sample_every"])) == 0:
                save_obj = {
                    "step": step, "model": diffusion.state_dict(), "opt": opt.state_dict()}
                if ema is not None:
                    save_obj["ema"] = ema.state_dict()
                torch.save(save_obj, os.path.join(
                    cfg["train"]["ckpt_dir"], f"mnist_step_{step}.pt"))

            # -- FID / IS metrics -----------------------------------------
            # Trigger on the tighter of the two cadences that are enabled.
            if (enable_fid or enable_is) and (step % min(fid_every if enable_fid else is_every,
                                                         is_every if enable_is else fid_every) == 0):
                real_ref_dir = "./metrics_ref/mnist_train_32_rgb"
                ensure_real_ref_folder(dl, real_ref_dir, max_images=50000,
                                       img_size=cfg["data"]["image_size"], force_rgb=True)

                gen_dir = f"./metrics_gen/step_{step}"
                sampler = ema.ema_model if ema is not None else diffusion
                t0 = time.perf_counter()
                with torch.inference_mode():
                    generate_images_to_folder(sampler, n_images=metric_n_gen,
                                              batch_size=metric_bs, out_dir=gen_dir, force_rgb=True)
                t1 = time.perf_counter()
                gen_fps = metric_n_gen / max(t1 - t0, 1e-6)

                log_payload = {"step": step,
                               "metrics/gen_imgs_per_sec": gen_fps}

                if enable_fid and HAS_CLEANFID and (step % fid_every == 0):
                    fid_score = compute_fid_cleanfid(gen_dir, real_ref_dir)
                    if fid_score is not None:
                        log_payload["metrics/FID_clean"] = fid_score

                if enable_is and HAS_TORCH_FIDELITY and (step % is_every == 0):
                    is_mean, is_std = compute_inception_score_torchfidelity(
                        gen_dir, cuda=True)
                    if is_mean is not None:
                        log_payload["metrics/IS_mean"] = is_mean
                    if is_std is not None:
                        log_payload["metrics/IS_std"] = is_std

                if run and len(log_payload) > 1:
                    wandb.log(log_payload, step=step)

            if step >= max_steps:
                break

    pbar.close()
    if run:
        run.finish()
|
|
|
|
if __name__ == "__main__":
    import argparse

    # CLI entry point: only the config path is configurable here.
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str,
                        default="config_mnist_small.yaml", help="Path to YAML config")
    cli_args = parser.parse_args()
    main(cli_args.config)
|
|