# MVA_GenAI / train_unet.py
# Uploaded by haiphamcse via huggingface_hub (commit f729117, verified)
#!/usr/bin/env python3
"""
Standalone CFM + UNet training (extracted from close_form_mva_gen_proj.ipynb).
Logs loss and sample images to TensorBoard.
"""
from __future__ import annotations
import argparse
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import torch
import torchvision
from torch.utils.data import DataLoader, Subset
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import v2
from torchcfm.conditional_flow_matching import ConditionalFlowMatcher
from torchcfm.models.unet.unet import UNetModelWrapper
from torchdyn.core import NeuralODE
try:
import yaml # type: ignore[import-untyped]
except ImportError as e: # pragma: no cover
raise ImportError("Please `pip install pyyaml` to use --config.") from e
def parse_args() -> argparse.Namespace:
    """Build and evaluate the command-line interface for CFM + UNet training."""
    parser = argparse.ArgumentParser(description="Train UNet with Conditional Flow Matching (CIFAR-10 or Imagenette)")
    # --- data selection ---
    parser.add_argument(
        "--dataset", type=str, default="imagenette", choices=["cifar10", "imagenette"],
        help="Training dataset: CIFAR-10 or Imagenette (both 10 classes, 32x32 after transforms)",
    )
    parser.add_argument("--data-root", type=str, default=".", help="Root for dataset download/cache")
    parser.add_argument("--cifar-split", type=str, default="train", choices=["train", "test"])
    parser.add_argument("--imagenette-split", type=str, default="train", choices=["train", "val"])
    parser.add_argument("--imagenette-size", type=str, default="160px", choices=["160px", "320px", "full"])
    parser.add_argument(
        "--single-class", action="store_true",
        help="Keep only samples whose label equals --class-id (CIFAR-10 / Imagenette class index 0..9)",
    )
    parser.add_argument(
        "--class-id", type=int, default=0,
        help="Label to keep when --single-class is set (0..9 for both datasets)",
    )
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--num-workers", type=int, default=4)
    # --- training loop ---
    parser.add_argument("--epochs", type=int, default=30)
    parser.add_argument("--device", type=str, default=None, help="cuda | cpu (default: auto)")
    parser.add_argument("--log-interval", type=int, default=100, help="Print / log batch loss every N steps")
    parser.add_argument("--seed", type=int, default=0)
    # --- checkpoints ---
    parser.add_argument("--save-dir", type=str, default="./runs/cfm_unet/checkpoints", help="Directory for .pt files")
    # --- TensorBoard ---
    parser.add_argument(
        "--log-dir", type=str, default="./runs/cfm_unet/tensorboard",
        help="TensorBoard log directory (also used if --run-name is set)",
    )
    parser.add_argument("--run-name", type=str, default=None, help="Subfolder under log-dir for this run")
    # --- UNet / CFM hyperparameters (YAML) ---
    parser.add_argument(
        "--config", type=str, default=None,
        help="YAML with UNet + CFM hyperparameters (default: unet_config.yaml next to this script)",
    )
    parser.add_argument(
        "--data-percent", type=int, default=100, choices=[10, 20, 30, 60, 80, 100],
        help="Use only this percentage of the (possibly filtered) training dataset.",
    )
    return parser.parse_args()
def _parse_int_list(s: str) -> list[int]:
return [int(x.strip()) for x in s.split(",") if x.strip()]
def _parse_dim(s: str) -> tuple[int, int, int]:
parts = _parse_int_list(s)
if len(parts) != 3:
raise ValueError("--dim must be three integers C,H,W")
return (parts[0], parts[1], parts[2])
@dataclass
class TrainConfig:
    """Hyperparameters for CFM training and the UNet architecture, loaded from YAML."""

    # --- flow matching / optimization ---
    sigma: float  # noise level passed to ConditionalFlowMatcher
    dim: tuple[int, int, int]  # image shape as (C, H, W)
    lr: float  # AdamW learning rate
    weight_decay: float  # AdamW weight decay
    inference_steps: int  # number of Euler integration steps when sampling
    vis_batch_size: int  # number of samples drawn for the TensorBoard image grid
    # --- UNet architecture (forwarded verbatim to UNetModelWrapper) ---
    num_res_blocks: int
    num_channels: int
    channel_mult: list[int]
    num_heads: int
    num_head_channels: int
    attention_resolutions: str  # comma-separated resolution spec, kept as a string
    dropout: float
def _dim_from_yaml(value: Any) -> tuple[int, int, int]:
if isinstance(value, (list, tuple)) and len(value) == 3:
return (int(value[0]), int(value[1]), int(value[2]))
if isinstance(value, str):
return _parse_dim(value)
raise ValueError("YAML 'dim' must be [C,H,W] or a string like '3,32,32'")
def _channel_mult_from_yaml(value: Any) -> list[int]:
if isinstance(value, (list, tuple)):
return [int(x) for x in value]
if isinstance(value, str):
return _parse_int_list(value)
raise ValueError("YAML 'channel_mult' must be a list of ints or a comma-separated string")
# Keys that must appear in the YAML config; mirrors the fields of TrainConfig.
REQUIRED_YAML_KEYS = (
    "sigma",
    "dim",
    "lr",
    "weight_decay",
    "inference_steps",
    "vis_batch_size",
    "num_res_blocks",
    "num_channels",
    "channel_mult",
    "num_heads",
    "num_head_channels",
    "attention_resolutions",
    "dropout",
)
def load_unet_config_yaml(path: str | os.PathLike[str]) -> TrainConfig:
    """Read, validate, and coerce a YAML hyperparameter file into a TrainConfig.

    Raises FileNotFoundError when the file is absent, and ValueError when the
    YAML is not a mapping or is missing required keys.
    """
    path = Path(path)
    if not path.is_file():
        raise FileNotFoundError(f"Config file not found: {path.resolve()}")
    with open(path, encoding="utf-8") as fh:
        raw = yaml.safe_load(fh)
    if raw is None or not isinstance(raw, dict):
        raise ValueError(f"Config must be a YAML mapping: {path}")
    missing = [key for key in REQUIRED_YAML_KEYS if key not in raw]
    if missing:
        raise ValueError(f"Missing keys in {path}: {missing}")
    # Per-field coercion table, listed in TrainConfig field order so any
    # conversion error surfaces on the same key as the original code.
    coerce = {
        "sigma": float,
        "dim": _dim_from_yaml,
        "lr": float,
        "weight_decay": float,
        "inference_steps": int,
        "vis_batch_size": int,
        "num_res_blocks": int,
        "num_channels": int,
        "channel_mult": _channel_mult_from_yaml,
        "num_heads": int,
        "num_head_channels": int,
        "attention_resolutions": str,
        "dropout": float,
    }
    return TrainConfig(**{key: fn(raw[key]) for key, fn in coerce.items()})
# Both supported datasets expose exactly 10 classes (used to bound --class-id).
NUM_CLASSES = {"cifar10": 10, "imagenette": 10}
def _targets_list(dataset: torch.utils.data.Dataset) -> list[int]:
if hasattr(dataset, "targets"):
t = dataset.targets
return list(t) if not isinstance(t, list) else t
return [int(dataset[i][1]) for i in range(len(dataset))]
def _maybe_single_class(
    dataset: torch.utils.data.Dataset,
    *,
    single_class: bool,
    class_id: int,
    dataset_name: str,
) -> torch.utils.data.Dataset:
    """Return `dataset` unchanged, or a Subset restricted to samples labeled `class_id`.

    Raises ValueError for an out-of-range class_id and RuntimeError when the
    requested class has no samples.
    """
    # Looked up before the early return on purpose: an unknown dataset_name
    # fails fast with KeyError even when no filtering is requested.
    n_cls = NUM_CLASSES[dataset_name]
    if not single_class:
        return dataset
    if not 0 <= class_id < n_cls:
        raise ValueError(f"--class-id must be in [0, {n_cls - 1}] for {dataset_name}")
    labels = _targets_list(dataset)
    keep = [idx for idx, label in enumerate(labels) if int(label) == class_id]
    if not keep:
        raise RuntimeError(f"No samples found for class_id={class_id}")
    print(f"Single-class filter: dataset={dataset_name}, class_id={class_id}, n_samples={len(keep)}")
    return Subset(dataset, keep)
def load_training_dataset(args: argparse.Namespace, transforms: v2.Compose) -> torch.utils.data.Dataset:
    """Instantiate the CLI-selected dataset (downloading if needed) and apply the
    optional single-class filter."""
    name = args.dataset
    if name == "cifar10":
        base: torch.utils.data.Dataset = torchvision.datasets.CIFAR10(
            root=args.data_root,
            train=(args.cifar_split == "train"),
            download=True,
            transform=transforms,
        )
    elif name == "imagenette":
        base = torchvision.datasets.Imagenette(
            args.data_root,
            split=args.imagenette_split,
            size=args.imagenette_size,
            download=True,
            transform=transforms,
        )
    else:
        raise ValueError(f"Unknown dataset: {name}")
    return _maybe_single_class(base, single_class=args.single_class, class_id=args.class_id, dataset_name=name)
def main() -> None:
    """Run the full pipeline: load config, build data/model, train with CFM, sample, checkpoint."""
    args = parse_args()
    # Resolve the YAML config: explicit --config wins, else unet_config.yaml beside this script.
    default_cfg = Path(__file__).resolve().parent / "unet_config.yaml"
    config_path = Path(args.config).resolve() if args.config else default_cfg
    cfg = load_unet_config_yaml(config_path)
    print(f"Loaded UNet config from: {config_path}")
    # Seed CPU RNG and (if present) every CUDA device for reproducibility.
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
    device = torch.device(args.device or ("cuda" if torch.cuda.is_available() else "cpu"))
    print(f"Using device: {device}")
    os.makedirs(args.save_dir, exist_ok=True)
    # Optional per-run subfolder keeps separate TensorBoard runs from mixing.
    tb_dir = os.path.join(args.log_dir, args.run_name) if args.run_name else args.log_dir
    os.makedirs(tb_dir, exist_ok=True)
    writer = SummaryWriter(log_dir=tb_dir)
    # Record the exact CLI args and raw YAML text next to the metrics.
    writer.add_text("config/args", str(vars(args)), 0)
    writer.add_text("config/unet_yaml", config_path.read_text(encoding="utf-8"), 0)
    transforms = v2.Compose(
        [
            v2.ToTensor(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Resize((32,32)),
            # NOTE(review): mean=0 / std=1 makes this Normalize a no-op; pixels stay in [0, 1].
            v2.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
        ]
    )
    train_dataset = load_training_dataset(args, transforms)
    print(f"Dataset: {args.dataset}, size={len(train_dataset)}")
    orig_len = len(train_dataset)
    if args.data_percent < 100:
        # Deterministic (seeded) random subsample; the chosen indices are also
        # persisted to indices.pt so the subset can be reconstructed later.
        new_len = max(1, int(orig_len * args.data_percent / 100.0))
        g = torch.Generator()
        g.manual_seed(args.seed)
        perm = torch.randperm(orig_len, generator=g)
        indices = perm[:new_len].tolist()
        torch.save(perm[:new_len], os.path.join(args.save_dir, "indices.pt"))
        train_dataset = Subset(train_dataset, indices)
        print(f"Subsampled dataset: {args.data_percent}% -> {len(train_dataset)} samples")
    else:
        print(f"Using full dataset: {orig_len} samples")
    # NOTE(review): despite the name, this is the actual training dataloader.
    dummy_dataloader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=device.type == "cuda",
    )
    total_optimizer_steps = len(dummy_dataloader) * args.epochs
    fm = ConditionalFlowMatcher(sigma=cfg.sigma)
    net_model = UNetModelWrapper(
        dim=cfg.dim,
        num_res_blocks=cfg.num_res_blocks,
        num_channels=cfg.num_channels,
        channel_mult=cfg.channel_mult,
        num_heads=cfg.num_heads,
        num_head_channels=cfg.num_head_channels,
        attention_resolutions=cfg.attention_resolutions,
        dropout=cfg.dropout,
    ).to(device)
    optim = torch.optim.AdamW(net_model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    # Match notebook: one scheduler.step() per batch; span full training (not just epochs count).
    scheduler = torch.optim.lr_scheduler.LinearLR(optim, total_iters=max(total_optimizer_steps, 1))
    # Integration grid for sampling: inference_steps Euler steps over t in [0, 1].
    t_span = torch.linspace(0, 1, cfg.inference_steps + 1, device=device)
    c, h, w = cfg.dim
    global_step = 0
    best_loss = float("inf")
    for ep in range(args.epochs):
        net_model.train()
        epoch_loss = 0.0
        num_batches = 0
        for data in dummy_dataloader:
            x1 = data[0].to(device, non_blocking=True)  # images; labels (data[1]) are unused
            x0 = torch.randn_like(x1)  # Gaussian source sample
            # CFM draws a random t, a point xt on the probability path, and the target flow ut.
            t, xt, ut = fm.sample_location_and_conditional_flow(x0, x1)
            vt = net_model(t, xt)
            # Regress the predicted velocity field onto the conditional flow (MSE).
            loss = torch.mean((vt - ut) ** 2)
            optim.zero_grad(set_to_none=True)
            loss.backward()
            optim.step()
            scheduler.step()
            epoch_loss += loss.item()
            num_batches += 1
            writer.add_scalar("train/loss_step", loss.item(), global_step)
            writer.add_scalar("train/lr", scheduler.get_last_lr()[0], global_step)
            if global_step % args.log_interval == 0:
                print(f"[step {global_step}] loss = {loss.item():.6f}")
            global_step += 1
        avg_epoch_loss = epoch_loss / max(num_batches, 1)
        writer.add_scalar("train/loss_epoch", avg_epoch_loss, ep)
        print(f"[epoch {ep}] avg loss = {avg_epoch_loss:.6f}")
        # Sample trajectories (NeuralODE) — log image grid to TensorBoard
        net_model.eval()
        node = NeuralODE(net_model, solver="euler")
        with torch.no_grad():
            x_vis = torch.randn(cfg.vis_batch_size, c, h, w, device=device)
            traj = node.trajectory(x_vis, t_span=t_span)
            x_final = traj[-1]  # state at t=1, i.e. the generated images
            x_final = x_final.clamp(0.0, 1.0).cpu()  # clamp into displayable [0, 1] range
            grid = torchvision.utils.make_grid(x_final, nrow=4, padding=2, normalize=False)
            writer.add_image("samples/neural_ode_final", grid, ep)
        # NOTE(review): hard-coded cadence — with the default 30 epochs, only epoch 0 is saved here.
        if ep % 30 == 0:
            ckpt_path = os.path.join(args.save_dir, f"model_epoch_{ep}.pt")
            torch.save(net_model.state_dict(), ckpt_path)
        # Track and persist the lowest average epoch loss seen so far.
        if ep == 0 or avg_epoch_loss < best_loss:
            best_loss = avg_epoch_loss
            torch.save(net_model.state_dict(), os.path.join(args.save_dir, "model_best.pt"))
    writer.close()
    print(f"Done. Checkpoints: {args.save_dir}")
    print(f"TensorBoard: tensorboard --logdir {tb_dir}")


if __name__ == "__main__":
    main()