#!/usr/bin/env python3
"""
Training script for a conditional DDPM on The Well datasets.
Includes periodic evaluation with WandB video logging.

Usage:
    python train_diffusion.py --dataset turbulent_radiative_layer_2D --wandb
    python train_diffusion.py --dataset active_matter --batch_size 4 --wandb
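
Resume from a checkpoint (illustrative path; checkpoints are written to
--ckpt_dir as diffusion_epNNNN.pt):
    python train_diffusion.py --dataset turbulent_radiative_layer_2D --resume checkpoints/diffusion/diffusion_ep0004.pt --wandb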
"""
import argparse
import logging
import math
import os
import time

import torch
import torch.nn as nn
from torch.amp import GradScaler, autocast
from tqdm import tqdm

from data_pipeline import create_dataloader, prepare_batch, get_channel_info
from unet import UNet
from diffusion import GaussianDiffusion

# --- logging setup (suppress noisy library logs) ---
logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("train_diffusion")
logger.setLevel(logging.INFO)
_h = logging.StreamHandler()
_h.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(message)s", datefmt="%H:%M:%S"))
logger.addHandler(_h)
logger.propagate = False

# Also let eval_utils log through us
logging.getLogger("eval_utils").setLevel(logging.INFO)
logging.getLogger("eval_utils").addHandler(_h)
logging.getLogger("eval_utils").propagate = False

def cosine_lr(step, warmup, total, base_lr, min_lr=1e-6):
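    """Linear warmup from 0 to base_lr over `warmup` steps, then cosine decay
    to min_lr by `total` steps."""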
    if step < warmup:
        return base_lr * step / max(warmup, 1)
    progress = (step - warmup) / max(total - warmup, 1)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(progress * math.pi))

def train(args):
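    """Train a conditional DDPM on one Well dataset, with periodic evaluation
    and checkpointing."""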
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Device: {device}")

    # ---- WandB ----
    wandb_run = None
    if args.wandb:
        import wandb

        wandb_run = wandb.init(
            project="the-well-diffusion",
            name=f"{args.dataset}_bs{args.batch_size}_lr{args.lr}",
            config=vars(args),
        )
        logger.info(f"WandB run: {wandb_run.url}")

    # ---- Data: train ----
    logger.info(f"Loading training data: {args.dataset} (streaming={args.streaming})")
    train_loader, train_dataset = create_dataloader(
        dataset_name=args.dataset,
        split="train",
        batch_size=args.batch_size,
        n_steps_input=args.n_input,
        n_steps_output=args.n_output,
        num_workers=args.workers,
        streaming=args.streaming,
        local_path=args.local_path,
    )
    ch_info = get_channel_info(train_dataset)
    logger.info(f"Channel info: {ch_info}")
    c_in = ch_info["input_channels"]    # channels of the conditioning (input) frames
    c_out = ch_info["output_channels"]  # channels of the target frames to denoise

    # ---- Data: validation (single-step) ----
    logger.info("Loading validation data...")
    val_loader, _ = create_dataloader(
        dataset_name=args.dataset,
        split="valid",
        batch_size=args.batch_size,
        n_steps_input=args.n_input,
        n_steps_output=args.n_output,
        num_workers=0,
        streaming=args.streaming,
        local_path=args.local_path,
    )

    # ---- Data: rollout validation (multi-step output for GT comparison) ----
    logger.info(f"Loading rollout data (n_steps_output={args.n_rollout})...")
    rollout_loader, _ = create_dataloader(
        dataset_name=args.dataset,
        split="valid",
        batch_size=1,
        n_steps_input=args.n_input,
        n_steps_output=args.n_rollout,
        num_workers=0,
        streaming=args.streaming,
        local_path=args.local_path,
    )

    # ---- Model ----
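    # Conditioning appears to be by channel concatenation: the channel arithmetic
    # below (in_channels = c_out + c_in, out_channels = c_out) implies the UNet
    # sees the noisy target stacked with the clean conditioning frames and
    # predicts the target channels. (Inferred from the shapes; see diffusion.py
    # for the actual mechanism.)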
    unet = UNet(
        in_channels=c_out + c_in,
        out_channels=c_out,
        base_ch=args.base_ch,
        ch_mults=tuple(args.ch_mults),
        n_res=args.n_res,
        attn_levels=tuple(args.attn_levels),
        dropout=args.dropout,
    )
    diffusion = GaussianDiffusion(unet, timesteps=args.timesteps).to(device)
    n_params = sum(p.numel() for p in diffusion.parameters() if p.requires_grad)
    logger.info(f"Model parameters: {n_params:,}")
    if wandb_run:
        wandb_run.summary["n_params"] = n_params

    # ---- Optimizer ----
    optimizer = torch.optim.AdamW(diffusion.parameters(), lr=args.lr, weight_decay=args.wd)
    scaler = GradScaler("cuda", enabled=args.amp)
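    # Note: with bfloat16 autocast (used in the training step below) loss scaling
    # is not strictly needed, since bf16 keeps float32's exponent range; the
    # GradScaler is kept so that switching the autocast dtype to float16 would
    # require no other changes.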

    # ---- Checkpoint resume ----
    start_epoch = 0
    global_step = 0
    if args.resume and os.path.exists(args.resume):
        # weights_only=False because the checkpoint also stores the args/ch_info dicts
        ckpt = torch.load(args.resume, map_location=device, weights_only=False)
        diffusion.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optimizer"])
        scaler.load_state_dict(ckpt["scaler"])
        start_epoch = ckpt["epoch"] + 1
        global_step = ckpt["global_step"]
        logger.info(f"Resumed from epoch {start_epoch}, step {global_step}")

    # ---- Training loop ----
    os.makedirs(args.ckpt_dir, exist_ok=True)
    total_steps = args.epochs * len(train_loader)
    logger.info(f"Starting training: {args.epochs} epochs, ~{total_steps} steps")
    logger.info(f"Eval every {args.eval_every} epochs, rollout {args.n_rollout} steps")

    lr = args.lr  # initialized here so the epoch summary log is defined even if every batch is skipped
    for epoch in range(start_epoch, args.epochs):
        diffusion.train()
        epoch_loss = 0.0
        n_batches = 0
        t0 = time.time()
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}", leave=False)
        for batch in pbar:
            try:
                x_cond, x_target = prepare_batch(batch, device)
            except Exception as e:
                logger.warning(f"Batch error: {e}, skipping")
                continue

            lr = cosine_lr(global_step, args.warmup, total_steps, args.lr)
            for pg in optimizer.param_groups:
                pg["lr"] = lr

            optimizer.zero_grad(set_to_none=True)
            with autocast(device_type="cuda", dtype=torch.bfloat16, enabled=args.amp):
                loss = diffusion.training_loss(x_target, x_cond)
            scaler.scale(loss).backward()
            # Unscale before clipping so the threshold applies to true gradient norms
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(diffusion.parameters(), args.grad_clip)
            scaler.step(optimizer)
            scaler.update()

            epoch_loss += loss.item()
            n_batches += 1
            global_step += 1
            pbar.set_postfix(loss=f"{loss.item():.4f}", lr=f"{lr:.2e}")
            if wandb_run and global_step % 20 == 0:
                wandb_run.log({"train/loss": loss.item(), "train/lr": lr}, step=global_step)

        avg_loss = epoch_loss / max(n_batches, 1)
        elapsed = time.time() - t0
        logger.info(
            f"Epoch {epoch}: loss={avg_loss:.4f}, batches={n_batches}, "
            f"time={elapsed:.1f}s, lr={lr:.2e}"
        )
        if wandb_run:
            wandb_run.log({"train/epoch_loss": avg_loss, "epoch": epoch}, step=global_step)

        # ---- Evaluation with video logging ----
        if (epoch + 1) % args.eval_every == 0:
            # Imported lazily so eval-only dependencies are not required at startup
            from eval_utils import run_evaluation

            logger.info("=" * 40)
            logger.info(f"EVALUATION at epoch {epoch}")
            logger.info("=" * 40)
            eval_metrics = run_evaluation(
                model=diffusion,
                val_loader=val_loader,
                rollout_loader=rollout_loader,
                device=device,
                global_step=global_step,
                wandb_run=wandb_run,
                n_val_batches=args.eval_batches,
                n_rollout=args.n_rollout,
                ddim_steps=args.ddim_steps,
            )
            logger.info(
                f" val/mse={eval_metrics['val/mse']:.6f}, "
                f"rollout_mse_mean={eval_metrics['val/rollout_mse_mean']:.6f}"
            )
            logger.info("=" * 40)

        # ---- Checkpoint ----
        if (epoch + 1) % args.save_every == 0 or epoch == args.epochs - 1:
            ckpt_path = os.path.join(args.ckpt_dir, f"diffusion_ep{epoch:04d}.pt")
            torch.save(
                {
                    "epoch": epoch,
                    "global_step": global_step,
                    "model": diffusion.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "scaler": scaler.state_dict(),
                    "args": vars(args),
                    "ch_info": ch_info,
                },
                ckpt_path,
            )
            logger.info(f"Saved {ckpt_path}")

    if wandb_run:
        wandb_run.finish()
    logger.info("Training complete.")

def main():
    p = argparse.ArgumentParser(description="Train conditional DDPM on The Well")
    # Data
    p.add_argument("--dataset", default="turbulent_radiative_layer_2D")
    p.add_argument("--streaming", action="store_true", default=True)
    p.add_argument("--no-streaming", dest="streaming", action="store_false")
    p.add_argument("--local_path", default=None)
    p.add_argument("--batch_size", type=int, default=8)
    p.add_argument("--workers", type=int, default=0)
    p.add_argument("--n_input", type=int, default=1)
    p.add_argument("--n_output", type=int, default=1)
    # Model
    p.add_argument("--base_ch", type=int, default=64)
    p.add_argument("--ch_mults", type=int, nargs="+", default=[1, 2, 4, 8])
    p.add_argument("--n_res", type=int, default=2)
    p.add_argument("--attn_levels", type=int, nargs="+", default=[3])
    p.add_argument("--dropout", type=float, default=0.1)
    p.add_argument("--timesteps", type=int, default=1000)
    # Optimization
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--wd", type=float, default=0.01)
    p.add_argument("--warmup", type=int, default=1000)
    p.add_argument("--grad_clip", type=float, default=1.0)
    p.add_argument("--amp", action="store_true", default=True)
    p.add_argument("--no-amp", dest="amp", action="store_false")
    p.add_argument("--epochs", type=int, default=100)
    # Evaluation
    p.add_argument("--eval_every", type=int, default=5, help="Eval every N epochs")
    p.add_argument("--eval_batches", type=int, default=4, help="Val batches for MSE")
    p.add_argument("--n_rollout", type=int, default=20, help="Rollout steps for video")
    p.add_argument("--ddim_steps", type=int, default=50, help="DDIM steps for eval sampling")
    # Checkpointing
    p.add_argument("--ckpt_dir", default="checkpoints/diffusion")
    p.add_argument("--save_every", type=int, default=5)
    p.add_argument("--resume", default=None)
    # Logging
    p.add_argument("--wandb", action="store_true", default=False)
    args = p.parse_args()
    train(args)

if __name__ == "__main__":
    main()