#!/usr/bin/env python3
"""
Standalone CFM + UNet training (extracted from close_form_mva_gen_proj.ipynb).
Logs loss and sample images to TensorBoard.
"""

from __future__ import annotations

import argparse
import os
from collections.abc import Sequence
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import torch
import torchvision
from torch.utils.data import DataLoader, Subset
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import v2

from torchcfm.conditional_flow_matching import ConditionalFlowMatcher
from torchcfm.models.unet.unet import UNetModelWrapper
from torchdyn.core import NeuralODE

try:
    import yaml  # type: ignore[import-untyped]
except ImportError as e:  # pragma: no cover
    raise ImportError("Please `pip install pyyaml` to use --config.") from e


# Spatial size every training image is resized to by the data transforms.
IMAGE_SIZE = (32, 32)


def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the training script.

    Args:
        argv: Argument list to parse; ``None`` (the default) reads ``sys.argv``.

    Returns:
        The populated namespace. ``--log-interval`` and ``--ckpt-interval`` are
        validated to be >= 1 because both are used as modulo divisors in the
        training loop.
    """
    p = argparse.ArgumentParser(description="Train UNet with Conditional Flow Matching (CIFAR-10 or Imagenette)")

    # Data
    p.add_argument(
        "--dataset",
        type=str,
        default="imagenette",
        choices=["cifar10", "imagenette"],
        help="Training dataset: CIFAR-10 or Imagenette (both 10 classes, 32x32 after transforms)",
    )
    p.add_argument("--data-root", type=str, default=".", help="Root for dataset download/cache")
    p.add_argument("--cifar-split", type=str, default="train", choices=["train", "test"])
    p.add_argument("--imagenette-split", type=str, default="train", choices=["train", "val"])
    p.add_argument("--imagenette-size", type=str, default="160px", choices=["160px", "320px", "full"])
    p.add_argument(
        "--single-class",
        action="store_true",
        help="Keep only samples whose label equals --class-id (CIFAR-10 / Imagenette class index 0..9)",
    )
    p.add_argument(
        "--class-id",
        type=int,
        default=0,
        help="Label to keep when --single-class is set (0..9 for both datasets)",
    )
    p.add_argument("--batch-size", type=int, default=64)
    p.add_argument("--num-workers", type=int, default=4)

    # Training
    p.add_argument("--epochs", type=int, default=30)
    p.add_argument("--device", type=str, default=None, help="cuda | cpu (default: auto)")
    p.add_argument("--log-interval", type=int, default=100, help="Print / log batch loss every N steps")
    p.add_argument("--seed", type=int, default=0)

    # Checkpoints
    p.add_argument("--save-dir", type=str, default="./runs/cfm_unet/checkpoints", help="Directory for .pt files")
    p.add_argument(
        "--ckpt-interval",
        type=int,
        default=30,
        help="Save model_epoch_<N>.pt every N epochs (the final epoch is always saved)",
    )

    # TensorBoard
    p.add_argument(
        "--log-dir",
        type=str,
        default="./runs/cfm_unet/tensorboard",
        help="TensorBoard log directory (also used if --run-name is set)",
    )
    p.add_argument("--run-name", type=str, default=None, help="Subfolder under log-dir for this run")

    # UNet / CFM (YAML)
    p.add_argument(
        "--config",
        type=str,
        default=None,
        help="YAML with UNet + CFM hyperparameters (default: unet_config.yaml next to this script)",
    )
    p.add_argument(
        "--data-percent",
        type=int,
        default=100,
        choices=[10, 20, 30, 60, 80, 100],
        help="Use only this percentage of the (possibly filtered) training dataset.",
    )

    args = p.parse_args(argv)
    # Fail early with a usage error instead of a ZeroDivisionError mid-training.
    if args.log_interval < 1:
        p.error("--log-interval must be >= 1")
    if args.ckpt_interval < 1:
        p.error("--ckpt-interval must be >= 1")
    return args


def _parse_int_list(s: str) -> list[int]:
    """Split a comma-separated string into ints, ignoring empty fields."""
    return [int(x.strip()) for x in s.split(",") if x.strip()]


def _parse_dim(s: str) -> tuple[int, int, int]:
    """Parse a ``"C,H,W"`` string into a 3-tuple of ints.

    Raises:
        ValueError: If the string does not contain exactly three integers.
    """
    parts = _parse_int_list(s)
    if len(parts) != 3:
        # There is no --dim CLI flag; the value comes from the YAML config.
        raise ValueError("dim must be three integers C,H,W")
    return (parts[0], parts[1], parts[2])


@dataclass
class TrainConfig:
    """Hyperparameters read from the YAML config (flow matcher, optimizer, UNet)."""

    # Passed to ConditionalFlowMatcher(sigma=...).
    sigma: float
    # (C, H, W) handed to the UNet and used for the noise drawn when sampling.
    dim: tuple[int, int, int]
    # AdamW learning rate and weight decay.
    lr: float
    weight_decay: float
    # Number of integration intervals on [0, 1] used when sampling.
    inference_steps: int
    # Number of images sampled for the TensorBoard grid each epoch.
    vis_batch_size: int
    # Remaining fields are forwarded unchanged to UNetModelWrapper.
    num_res_blocks: int
    num_channels: int
    channel_mult: list[int]
    num_heads: int
    num_head_channels: int
    attention_resolutions: str
    dropout: float


def _dim_from_yaml(value: Any) -> tuple[int, int, int]:
    """Coerce the YAML ``dim`` entry (3-element sequence or ``"C,H,W"`` string)."""
    if isinstance(value, (list, tuple)) and len(value) == 3:
        return (int(value[0]), int(value[1]), int(value[2]))
    if isinstance(value, str):
        return _parse_dim(value)
    raise ValueError("YAML 'dim' must be [C,H,W] or a string like '3,32,32'")


def _channel_mult_from_yaml(value: Any) -> list[int]:
    """Coerce the YAML ``channel_mult`` entry (sequence or comma-separated string)."""
    if isinstance(value, (list, tuple)):
        return [int(x) for x in value]
    if isinstance(value, str):
        return _parse_int_list(value)
    raise ValueError("YAML 'channel_mult' must be a list of ints or a comma-separated string")


def _attention_resolutions_from_yaml(value: Any) -> str:
    """Coerce the YAML ``attention_resolutions`` entry to a string.

    A YAML sequence such as ``[16, 8]`` becomes ``"16,8"``; a plain ``str()``
    would produce ``"[16, 8]"``. Scalars and strings are passed through
    ``str()`` unchanged.
    NOTE(review): assumes the UNet expects a comma-separated string -- confirm
    against the installed torchcfm version.
    """
    if isinstance(value, (list, tuple)):
        return ",".join(str(int(x)) for x in value)
    return str(value)


REQUIRED_YAML_KEYS = (
    "sigma",
    "dim",
    "lr",
    "weight_decay",
    "inference_steps",
    "vis_batch_size",
    "num_res_blocks",
    "num_channels",
    "channel_mult",
    "num_heads",
    "num_head_channels",
    "attention_resolutions",
    "dropout",
)


def load_unet_config_yaml(path: str | os.PathLike[str]) -> TrainConfig:
    """Load and validate the YAML hyperparameter file.

    Args:
        path: Location of the YAML file.

    Returns:
        A ``TrainConfig`` with every value cast to its declared type.

    Raises:
        FileNotFoundError: If ``path`` is not an existing file.
        ValueError: If the document is not a mapping, a required key is
            missing, or a value cannot be coerced.
    """
    path = Path(path)
    if not path.is_file():
        raise FileNotFoundError(f"Config file not found: {path.resolve()}")
    with open(path, encoding="utf-8") as f:
        raw = yaml.safe_load(f)
    # An empty file loads as None, which is covered by the isinstance check.
    if not isinstance(raw, dict):
        raise ValueError(f"Config must be a YAML mapping: {path}")
    missing = [k for k in REQUIRED_YAML_KEYS if k not in raw]
    if missing:
        raise ValueError(f"Missing keys in {path}: {missing}")
    return TrainConfig(
        sigma=float(raw["sigma"]),
        dim=_dim_from_yaml(raw["dim"]),
        lr=float(raw["lr"]),
        weight_decay=float(raw["weight_decay"]),
        inference_steps=int(raw["inference_steps"]),
        vis_batch_size=int(raw["vis_batch_size"]),
        num_res_blocks=int(raw["num_res_blocks"]),
        num_channels=int(raw["num_channels"]),
        channel_mult=_channel_mult_from_yaml(raw["channel_mult"]),
        num_heads=int(raw["num_heads"]),
        num_head_channels=int(raw["num_head_channels"]),
        attention_resolutions=_attention_resolutions_from_yaml(raw["attention_resolutions"]),
        dropout=float(raw["dropout"]),
    )


NUM_CLASSES = {"cifar10": 10, "imagenette": 10}


def _targets_list(dataset: torch.utils.data.Dataset) -> list[int]:
    """Return one label per sample, avoiding image decoding when possible.

    Lookup order:
      1. a public ``targets`` attribute;
      2. a private ``_samples`` list of ``(path, label)`` pairs -- presumably
         what torchvision's Imagenette keeps; verify on torchvision upgrades;
      3. indexing every sample, which runs the full transform pipeline and is
         slow, but always works.
    """
    if hasattr(dataset, "targets"):
        t = dataset.targets
        return list(t) if not isinstance(t, list) else t

    samples = getattr(dataset, "_samples", None)
    if isinstance(samples, (list, tuple)) and len(samples) == len(dataset):
        try:
            return [int(label) for _, label in samples]
        except (TypeError, ValueError):
            pass  # Unexpected layout: fall back to the slow, always-correct path.

    return [int(dataset[i][1]) for i in range(len(dataset))]


def _maybe_single_class(
    dataset: torch.utils.data.Dataset,
    *,
    single_class: bool,
    class_id: int,
    dataset_name: str,
) -> torch.utils.data.Dataset:
    """Optionally restrict ``dataset`` to the samples labelled ``class_id``.

    Args:
        dataset: Dataset to filter.
        single_class: When false the dataset is returned untouched.
        class_id: Label to keep.
        dataset_name: Key into ``NUM_CLASSES`` used to validate ``class_id``.

    Raises:
        ValueError: If ``class_id`` is outside the dataset's label range.
        RuntimeError: If no sample carries the requested label.
    """
    n_cls = NUM_CLASSES[dataset_name]
    if not single_class:
        return dataset
    if class_id < 0 or class_id >= n_cls:
        raise ValueError(f"--class-id must be in [0, {n_cls - 1}] for {dataset_name}")
    targets = _targets_list(dataset)
    indices = [i for i, y in enumerate(targets) if int(y) == class_id]
    if not indices:
        raise RuntimeError(f"No samples found for class_id={class_id}")
    print(f"Single-class filter: dataset={dataset_name}, class_id={class_id}, n_samples={len(indices)}")
    return Subset(dataset, indices)


def load_training_dataset(args: argparse.Namespace, transforms: v2.Compose) -> torch.utils.data.Dataset:
    """Build the dataset selected by ``args.dataset`` and apply the class filter.

    Args:
        args: Parsed CLI options (dataset name, root, split, size, class filter).
        transforms: Transform pipeline applied to every image.

    Raises:
        ValueError: If ``args.dataset`` is not a known dataset name.
    """
    name = args.dataset
    if name == "cifar10":
        ds: torch.utils.data.Dataset = torchvision.datasets.CIFAR10(
            root=args.data_root,
            train=(args.cifar_split == "train"),
            download=True,
            transform=transforms,
        )
    elif name == "imagenette":
        imagenette_kwargs = {
            "split": args.imagenette_split,
            "size": args.imagenette_size,
            "transform": transforms,
        }
        try:
            ds = torchvision.datasets.Imagenette(args.data_root, download=True, **imagenette_kwargs)
        except RuntimeError:
            # Some torchvision releases raise when download=True and the
            # extracted directory already exists; retry against the local copy.
            # A genuinely missing dataset still raises from this second call.
            ds = torchvision.datasets.Imagenette(args.data_root, download=False, **imagenette_kwargs)
    else:
        raise ValueError(f"Unknown dataset: {name}")

    ds = _maybe_single_class(ds, single_class=args.single_class, class_id=args.class_id, dataset_name=name)
    return ds


def _build_transforms() -> v2.Compose:
    """Return the image pipeline: float32 in [0, 1], resized to ``IMAGE_SIZE``."""
    return v2.Compose(
        [
            # ToImage + ToDtype(scale=True) replaces the deprecated v2.ToTensor().
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Resize(IMAGE_SIZE),
            # Identity normalisation (mean 0, std 1), kept from the notebook.
            v2.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
        ]
    )


def _subsample_dataset(
    dataset: torch.utils.data.Dataset,
    *,
    percent: int,
    seed: int,
    save_dir: str,
) -> torch.utils.data.Dataset:
    """Keep a seeded random ``percent`` of ``dataset`` and save the chosen indices."""
    orig_len = len(dataset)
    if percent >= 100:
        print(f"Using full dataset: {orig_len} samples")
        return dataset

    new_len = max(1, int(orig_len * percent / 100.0))
    g = torch.Generator()
    g.manual_seed(seed)
    # clone(): saving a slice would serialise the storage of the whole permutation.
    chosen = torch.randperm(orig_len, generator=g)[:new_len].clone()
    torch.save(chosen, os.path.join(save_dir, "indices.pt"))
    subset = Subset(dataset, chosen.tolist())
    print(f"Subsampled dataset: {percent}% -> {len(subset)} samples")
    return subset


def _log_samples(
    writer: SummaryWriter,
    net_model: torch.nn.Module,
    cfg: TrainConfig,
    t_span: torch.Tensor,
    device: torch.device,
    epoch: int,
) -> None:
    """Integrate noise through the model with Euler steps and log an image grid."""
    net_model.eval()
    node = NeuralODE(net_model, solver="euler")
    c, h, w = cfg.dim
    with torch.no_grad():
        x_vis = torch.randn(cfg.vis_batch_size, c, h, w, device=device)
        traj = node.trajectory(x_vis, t_span=t_span)
        # Last trajectory entry is the state at t=1; clamp to the data range [0, 1].
        x_final = traj[-1].clamp(0.0, 1.0).cpu()
    grid = torchvision.utils.make_grid(x_final, nrow=4, padding=2, normalize=False)
    writer.add_image("samples/neural_ode_final", grid, epoch)


def _run_training(
    args: argparse.Namespace,
    cfg: TrainConfig,
    config_path: Path,
    device: torch.device,
    writer: SummaryWriter,
) -> None:
    """Build data, model and optimizer, then run the epoch loop with logging."""
    writer.add_text("config/args", str(vars(args)), 0)
    writer.add_text("config/unet_yaml", config_path.read_text(encoding="utf-8"), 0)

    if tuple(cfg.dim[1:]) != IMAGE_SIZE:
        print(
            f"Warning: config dim {cfg.dim} does not match the {IMAGE_SIZE[0]}x{IMAGE_SIZE[1]} "
            "training images; samples are drawn at the config resolution."
        )

    train_dataset = load_training_dataset(args, _build_transforms())
    print(f"Dataset: {args.dataset}, size={len(train_dataset)}")
    train_dataset = _subsample_dataset(
        train_dataset,
        percent=args.data_percent,
        seed=args.seed,
        save_dir=args.save_dir,
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=device.type == "cuda",
    )
    total_optimizer_steps = len(train_loader) * args.epochs

    fm = ConditionalFlowMatcher(sigma=cfg.sigma)
    net_model = UNetModelWrapper(
        dim=cfg.dim,
        num_res_blocks=cfg.num_res_blocks,
        num_channels=cfg.num_channels,
        channel_mult=cfg.channel_mult,
        num_heads=cfg.num_heads,
        num_head_channels=cfg.num_head_channels,
        attention_resolutions=cfg.attention_resolutions,
        dropout=cfg.dropout,
    ).to(device)

    optim = torch.optim.AdamW(net_model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    # Match notebook: one scheduler.step() per batch; span full training (not just epochs count).
    # NOTE: with LinearLR's default factors this is a warm-up towards cfg.lr, not a decay.
    scheduler = torch.optim.lr_scheduler.LinearLR(optim, total_iters=max(total_optimizer_steps, 1))

    t_span = torch.linspace(0, 1, cfg.inference_steps + 1, device=device)

    global_step = 0
    best_loss = float("inf")
    last_epoch = args.epochs - 1

    for ep in range(args.epochs):
        net_model.train()
        epoch_loss = 0.0
        num_batches = 0

        for data in train_loader:
            x1 = data[0].to(device, non_blocking=True)
            x0 = torch.randn_like(x1)
            t, xt, ut = fm.sample_location_and_conditional_flow(x0, x1)
            vt = net_model(t, xt)
            loss = torch.mean((vt - ut) ** 2)

            optim.zero_grad(set_to_none=True)
            loss.backward()
            optim.step()
            scheduler.step()

            # One host/device sync per step instead of one per use.
            loss_value = loss.item()
            epoch_loss += loss_value
            num_batches += 1

            writer.add_scalar("train/loss_step", loss_value, global_step)
            writer.add_scalar("train/lr", scheduler.get_last_lr()[0], global_step)
            if global_step % args.log_interval == 0:
                print(f"[step {global_step}] loss = {loss_value:.6f}")
            global_step += 1

        avg_epoch_loss = epoch_loss / max(num_batches, 1)
        writer.add_scalar("train/loss_epoch", avg_epoch_loss, ep)
        print(f"[epoch {ep}] avg loss = {avg_epoch_loss:.6f}")

        # Sample trajectories (NeuralODE) -- log image grid to TensorBoard
        _log_samples(writer, net_model, cfg, t_span, device, ep)

        # Periodic checkpoint; the final epoch is always saved so the end
        # state of training is never lost.
        if ep % args.ckpt_interval == 0 or ep == last_epoch:
            ckpt_path = os.path.join(args.save_dir, f"model_epoch_{ep}.pt")
            torch.save(net_model.state_dict(), ckpt_path)
        # `ep == 0` guarantees model_best.pt exists even if the first loss is NaN.
        if ep == 0 or avg_epoch_loss < best_loss:
            best_loss = avg_epoch_loss
            torch.save(net_model.state_dict(), os.path.join(args.save_dir, "model_best.pt"))


def main() -> None:
    """Entry point: load the config, set up device and logging, then train."""
    args = parse_args()

    default_cfg = Path(__file__).resolve().parent / "unet_config.yaml"
    config_path = Path(args.config).resolve() if args.config else default_cfg
    cfg = load_unet_config_yaml(config_path)
    print(f"Loaded UNet config from: {config_path}")

    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    device = torch.device(args.device or ("cuda" if torch.cuda.is_available() else "cpu"))
    print(f"Using device: {device}")

    os.makedirs(args.save_dir, exist_ok=True)
    tb_dir = os.path.join(args.log_dir, args.run_name) if args.run_name else args.log_dir
    os.makedirs(tb_dir, exist_ok=True)

    writer = SummaryWriter(log_dir=tb_dir)
    try:
        _run_training(args, cfg, config_path, device, writer)
    finally:
        # Flush and close the event file even when training raises.
        writer.close()

    print(f"Done. Checkpoints: {args.save_dir}")
    print(f"TensorBoard: tensorboard --logdir {tb_dir}")


if __name__ == "__main__":
    main()