| |
| """ |
| Standalone CFM + UNet training (extracted from close_form_mva_gen_proj.ipynb). |
| Logs loss and sample images to TensorBoard. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import os |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Any |
|
|
| import torch |
| import torchvision |
| from torch.utils.data import DataLoader, Subset |
| from torch.utils.tensorboard import SummaryWriter |
| from torchvision.transforms import v2 |
|
|
| from torchcfm.conditional_flow_matching import ConditionalFlowMatcher |
| from torchcfm.models.unet.unet import UNetModelWrapper |
| from torchdyn.core import NeuralODE |
|
|
# Fail fast with an actionable message when PyYAML is absent: the
# hyperparameter config is always loaded from a YAML file below.
try:
    import yaml
except ImportError as e:
    raise ImportError("Please `pip install pyyaml` to use --config.") from e
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Build and parse the command-line interface for CFM + UNet training."""
    parser = argparse.ArgumentParser(
        description="Train UNet with Conditional Flow Matching (CIFAR-10 or Imagenette)"
    )

    # --- dataset selection & filtering ---------------------------------
    parser.add_argument(
        "--dataset",
        type=str,
        choices=["cifar10", "imagenette"],
        default="imagenette",
        help="Training dataset: CIFAR-10 or Imagenette (both 10 classes, 32x32 after transforms)",
    )
    parser.add_argument("--data-root", type=str, default=".", help="Root for dataset download/cache")
    parser.add_argument("--cifar-split", type=str, choices=["train", "test"], default="train")
    parser.add_argument("--imagenette-split", type=str, choices=["train", "val"], default="train")
    parser.add_argument("--imagenette-size", type=str, choices=["160px", "320px", "full"], default="160px")
    parser.add_argument(
        "--single-class",
        action="store_true",
        help="Keep only samples whose label equals --class-id (CIFAR-10 / Imagenette class index 0..9)",
    )
    parser.add_argument(
        "--class-id",
        type=int,
        default=0,
        help="Label to keep when --single-class is set (0..9 for both datasets)",
    )
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--num-workers", type=int, default=4)

    # --- training schedule ---------------------------------------------
    parser.add_argument("--epochs", type=int, default=30)
    parser.add_argument("--device", type=str, default=None, help="cuda | cpu (default: auto)")
    parser.add_argument("--log-interval", type=int, default=100, help="Print / log batch loss every N steps")
    parser.add_argument("--seed", type=int, default=0)

    # --- checkpoints ----------------------------------------------------
    parser.add_argument("--save-dir", type=str, default="./runs/cfm_unet/checkpoints", help="Directory for .pt files")

    # --- TensorBoard logging -------------------------------------------
    parser.add_argument(
        "--log-dir",
        type=str,
        default="./runs/cfm_unet/tensorboard",
        help="TensorBoard log directory (also used if --run-name is set)",
    )
    parser.add_argument("--run-name", type=str, default=None, help="Subfolder under log-dir for this run")

    # --- model / CFM hyperparameters -----------------------------------
    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help="YAML with UNet + CFM hyperparameters (default: unet_config.yaml next to this script)",
    )
    parser.add_argument(
        "--data-percent",
        type=int,
        default=100,
        choices=[10, 20, 30, 60, 80, 100],
        help="Use only this percentage of the (possibly filtered) training dataset.",
    )
    return parser.parse_args()
|
|
|
|
| def _parse_int_list(s: str) -> list[int]: |
| return [int(x.strip()) for x in s.split(",") if x.strip()] |
|
|
|
|
def _parse_dim(s: str) -> tuple[int, int, int]:
    """Parse a 'C,H,W' string into a 3-tuple of ints.

    Raises ValueError when the string does not contain exactly three ints.
    """
    values = _parse_int_list(s)
    if len(values) != 3:
        raise ValueError("--dim must be three integers C,H,W")
    c, h, w = values
    return (c, h, w)
|
|
|
|
@dataclass
class TrainConfig:
    """Hyperparameters for CFM training and the UNet architecture.

    Populated from a YAML file by load_unet_config_yaml().
    """

    sigma: float  # noise scale passed to ConditionalFlowMatcher
    dim: tuple[int, int, int]  # image shape as (C, H, W)
    lr: float  # AdamW learning rate
    weight_decay: float  # AdamW weight decay
    inference_steps: int  # number of Euler steps when sampling the ODE
    vis_batch_size: int  # images sampled per epoch for TensorBoard previews
    num_res_blocks: int  # UNet residual blocks per resolution level
    num_channels: int  # UNet base channel width
    channel_mult: list[int]  # per-level channel multipliers
    num_heads: int  # attention heads in the UNet
    num_head_channels: int  # channels per attention head
    attention_resolutions: str  # resolutions at which attention is applied
    dropout: float  # UNet dropout probability
|
|
|
| def _dim_from_yaml(value: Any) -> tuple[int, int, int]: |
| if isinstance(value, (list, tuple)) and len(value) == 3: |
| return (int(value[0]), int(value[1]), int(value[2])) |
| if isinstance(value, str): |
| return _parse_dim(value) |
| raise ValueError("YAML 'dim' must be [C,H,W] or a string like '3,32,32'") |
|
|
|
|
| def _channel_mult_from_yaml(value: Any) -> list[int]: |
| if isinstance(value, (list, tuple)): |
| return [int(x) for x in value] |
| if isinstance(value, str): |
| return _parse_int_list(value) |
| raise ValueError("YAML 'channel_mult' must be a list of ints or a comma-separated string") |
|
|
|
|
# Keys that must all be present in the hyperparameter YAML;
# load_unet_config_yaml raises ValueError listing any that are missing.
REQUIRED_YAML_KEYS = (
    "sigma",
    "dim",
    "lr",
    "weight_decay",
    "inference_steps",
    "vis_batch_size",
    "num_res_blocks",
    "num_channels",
    "channel_mult",
    "num_heads",
    "num_head_channels",
    "attention_resolutions",
    "dropout",
)
|
|
|
|
def load_unet_config_yaml(path: str | os.PathLike[str]) -> TrainConfig:
    """Load and validate the UNet/CFM hyperparameter YAML into a TrainConfig.

    Raises:
        FileNotFoundError: when *path* does not exist.
        ValueError: when the YAML is not a mapping or required keys are missing.
    """
    cfg_path = Path(path)
    if not cfg_path.is_file():
        raise FileNotFoundError(f"Config file not found: {cfg_path.resolve()}")

    with open(cfg_path, encoding="utf-8") as fh:
        raw = yaml.safe_load(fh)
    if raw is None or not isinstance(raw, dict):
        raise ValueError(f"Config must be a YAML mapping: {cfg_path}")

    missing = [key for key in REQUIRED_YAML_KEYS if key not in raw]
    if missing:
        raise ValueError(f"Missing keys in {cfg_path}: {missing}")

    # Coerce every field explicitly so a malformed YAML value fails here
    # with a clear type error rather than deep inside training.
    return TrainConfig(
        sigma=float(raw["sigma"]),
        dim=_dim_from_yaml(raw["dim"]),
        lr=float(raw["lr"]),
        weight_decay=float(raw["weight_decay"]),
        inference_steps=int(raw["inference_steps"]),
        vis_batch_size=int(raw["vis_batch_size"]),
        num_res_blocks=int(raw["num_res_blocks"]),
        num_channels=int(raw["num_channels"]),
        channel_mult=_channel_mult_from_yaml(raw["channel_mult"]),
        num_heads=int(raw["num_heads"]),
        num_head_channels=int(raw["num_head_channels"]),
        attention_resolutions=str(raw["attention_resolutions"]),
        dropout=float(raw["dropout"]),
    )
|
|
|
|
| NUM_CLASSES = {"cifar10": 10, "imagenette": 10} |
|
|
|
|
| def _targets_list(dataset: torch.utils.data.Dataset) -> list[int]: |
| if hasattr(dataset, "targets"): |
| t = dataset.targets |
| return list(t) if not isinstance(t, list) else t |
| return [int(dataset[i][1]) for i in range(len(dataset))] |
|
|
|
|
def _maybe_single_class(
    dataset: torch.utils.data.Dataset,
    *,
    single_class: bool,
    class_id: int,
    dataset_name: str,
) -> torch.utils.data.Dataset:
    """Optionally restrict *dataset* to samples whose label == *class_id*.

    Returns the dataset untouched when *single_class* is False; otherwise a
    Subset containing only matching samples. Raises ValueError for an
    out-of-range class id and RuntimeError when no samples match.
    """
    # Looked up before the early return so an unknown dataset name always
    # raises KeyError, even when no filtering is requested.
    n_classes = NUM_CLASSES[dataset_name]
    if not single_class:
        return dataset
    if not (0 <= class_id < n_classes):
        raise ValueError(f"--class-id must be in [0, {n_classes - 1}] for {dataset_name}")
    labels = _targets_list(dataset)
    keep = [idx for idx, label in enumerate(labels) if int(label) == class_id]
    if not keep:
        raise RuntimeError(f"No samples found for class_id={class_id}")
    print(f"Single-class filter: dataset={dataset_name}, class_id={class_id}, n_samples={len(keep)}")
    return Subset(dataset, keep)
|
|
|
|
def load_training_dataset(args: argparse.Namespace, transforms: v2.Compose) -> torch.utils.data.Dataset:
    """Build the dataset chosen by the CLI (downloading if needed) and apply
    the optional single-class filter."""
    name = args.dataset
    if name == "cifar10":
        base: torch.utils.data.Dataset = torchvision.datasets.CIFAR10(
            root=args.data_root,
            train=(args.cifar_split == "train"),
            download=True,
            transform=transforms,
        )
    elif name == "imagenette":
        base = torchvision.datasets.Imagenette(
            args.data_root,
            split=args.imagenette_split,
            size=args.imagenette_size,
            download=True,
            transform=transforms,
        )
    else:
        raise ValueError(f"Unknown dataset: {name}")

    return _maybe_single_class(
        base,
        single_class=args.single_class,
        class_id=args.class_id,
        dataset_name=name,
    )
|
|
|
|
def main() -> None:
    """Train a UNet velocity field with Conditional Flow Matching.

    Loads hyperparameters from YAML, builds the dataset/dataloader, runs the
    CFM regression loop, samples images each epoch via a Neural ODE, and logs
    losses/images to TensorBoard while saving checkpoints.
    """
    args = parse_args()
    # Resolve the hyperparameter YAML: explicit --config wins, otherwise the
    # unet_config.yaml sitting next to this script.
    default_cfg = Path(__file__).resolve().parent / "unet_config.yaml"
    config_path = Path(args.config).resolve() if args.config else default_cfg
    cfg = load_unet_config_yaml(config_path)
    print(f"Loaded UNet config from: {config_path}")
    # Seed CPU and (if present) all CUDA devices for reproducibility.
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    device = torch.device(args.device or ("cuda" if torch.cuda.is_available() else "cpu"))
    print(f"Using device: {device}")

    os.makedirs(args.save_dir, exist_ok=True)

    # One TensorBoard run directory per --run-name (or the bare log dir);
    # the CLI args and raw YAML are stored as text for later inspection.
    tb_dir = os.path.join(args.log_dir, args.run_name) if args.run_name else args.log_dir
    os.makedirs(tb_dir, exist_ok=True)
    writer = SummaryWriter(log_dir=tb_dir)
    writer.add_text("config/args", str(vars(args)), 0)
    writer.add_text("config/unet_yaml", config_path.read_text(encoding="utf-8"), 0)

    # NOTE(review): Normalize with mean 0 / std 1 is a no-op, so images stay
    # in [0, 1] — consistent with the clamp(0, 1) applied to samples below.
    transforms = v2.Compose(
        [
            v2.ToTensor(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Resize((32,32)),
            v2.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
        ]
    )
    train_dataset = load_training_dataset(args, transforms)
    print(f"Dataset: {args.dataset}, size={len(train_dataset)}")

    # Optionally keep only a seeded random fraction of the (filtered) dataset;
    # the chosen indices are persisted so the subset can be reproduced later.
    orig_len = len(train_dataset)
    if args.data_percent < 100:
        new_len = max(1, int(orig_len * args.data_percent / 100.0))

        g = torch.Generator()
        g.manual_seed(args.seed)

        perm = torch.randperm(orig_len, generator=g)
        indices = perm[:new_len].tolist()
        torch.save(perm[:new_len], os.path.join(args.save_dir, "indices.pt"))
        train_dataset = Subset(train_dataset, indices)
        print(f"Subsampled dataset: {args.data_percent}% -> {len(train_dataset)} samples")
    else:
        print(f"Using full dataset: {orig_len} samples")

    dummy_dataloader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=device.type == "cuda",  # pinned host memory only helps CUDA transfers
    )

    # The LR scheduler is stepped once per batch, so it needs the total
    # number of optimizer steps over the whole run.
    total_optimizer_steps = len(dummy_dataloader) * args.epochs

    fm = ConditionalFlowMatcher(sigma=cfg.sigma)
    net_model = UNetModelWrapper(
        dim=cfg.dim,
        num_res_blocks=cfg.num_res_blocks,
        num_channels=cfg.num_channels,
        channel_mult=cfg.channel_mult,
        num_heads=cfg.num_heads,
        num_head_channels=cfg.num_head_channels,
        attention_resolutions=cfg.attention_resolutions,
        dropout=cfg.dropout,
    ).to(device)

    optim = torch.optim.AdamW(net_model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)

    # max(..., 1) guards against a zero-length dataloader.
    scheduler = torch.optim.lr_scheduler.LinearLR(optim, total_iters=max(total_optimizer_steps, 1))
    # Integration grid for sampling: inference_steps Euler intervals over [0, 1].
    t_span = torch.linspace(0, 1, cfg.inference_steps + 1, device=device)

    c, h, w = cfg.dim
    global_step = 0
    best_loss = float("inf")

    for ep in range(args.epochs):
        net_model.train()
        epoch_loss = 0.0
        num_batches = 0

        for data in dummy_dataloader:
            # CFM regression step: draw Gaussian noise x0, sample a time t,
            # an interpolated point x_t, and the target velocity u_t, then
            # fit the network output v_t to u_t with MSE.
            x1 = data[0].to(device, non_blocking=True)
            x0 = torch.randn_like(x1)
            t, xt, ut = fm.sample_location_and_conditional_flow(x0, x1)
            vt = net_model(t, xt)
            loss = torch.mean((vt - ut) ** 2)

            optim.zero_grad(set_to_none=True)
            loss.backward()
            optim.step()
            scheduler.step()  # per-batch (not per-epoch) LR schedule

            epoch_loss += loss.item()
            num_batches += 1

            writer.add_scalar("train/loss_step", loss.item(), global_step)
            writer.add_scalar("train/lr", scheduler.get_last_lr()[0], global_step)

            if global_step % args.log_interval == 0:
                print(f"[step {global_step}] loss = {loss.item():.6f}")

            global_step += 1

        avg_epoch_loss = epoch_loss / max(num_batches, 1)
        writer.add_scalar("train/loss_epoch", avg_epoch_loss, ep)
        print(f"[epoch {ep}] avg loss = {avg_epoch_loss:.6f}")

        # Epoch-end visualization: integrate the learned ODE from Gaussian
        # noise with an Euler solver and log the final frame to TensorBoard.
        net_model.eval()
        node = NeuralODE(net_model, solver="euler")
        with torch.no_grad():
            x_vis = torch.randn(cfg.vis_batch_size, c, h, w, device=device)
            traj = node.trajectory(x_vis, t_span=t_span)
            x_final = traj[-1]
            x_final = x_final.clamp(0.0, 1.0).cpu()
            grid = torchvision.utils.make_grid(x_final, nrow=4, padding=2, normalize=False)
            writer.add_image("samples/neural_ode_final", grid, ep)

        # Periodic checkpoint every 30 epochs (includes epoch 0)...
        if ep % 30 == 0:
            ckpt_path = os.path.join(args.save_dir, f"model_epoch_{ep}.pt")
            torch.save(net_model.state_dict(), ckpt_path)

        # ...plus a rolling best-average-loss checkpoint.
        if ep == 0 or avg_epoch_loss < best_loss:
            best_loss = avg_epoch_loss
            torch.save(net_model.state_dict(), os.path.join(args.save_dir, "model_best.pt"))

    writer.close()
    print(f"Done. Checkpoints: {args.save_dir}")
    print(f"TensorBoard: tensorboard --logdir {tb_dir}")
|
|
|
|
# Script entry point: only run training when executed directly.
if __name__ == "__main__":
    main()
|
|