"""
Conditional Flow Matching (CFM) training with the unconditional JiT model (jit.JiT).

Mirrors train_cfm_unet.py for data loading, TensorBoard logging, and
checkpointing; only the model and its YAML hyperparameter file differ.

JiT expects forward(x, t), while torchdyn's NeuralODE calls f(t, x); the
CFMFlowWrapper below bridges the two call signatures.
"""
|
|
from __future__ import annotations

import argparse
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import v2
from torchdyn.core import NeuralODE

from torchcfm.conditional_flow_matching import ConditionalFlowMatcher

from jit import JiT
|
|
try:
    import yaml
except ImportError as e:
    raise ImportError("Please `pip install pyyaml` to use --config.") from e

from train_unet import load_training_dataset
|
|
def parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser(description="Train unconditional JiT with Conditional Flow Matching")

    p.add_argument(
        "--dataset",
        type=str,
        default="imagenette",
        choices=["cifar10", "imagenette"],
        help="Training dataset",
    )
    p.add_argument("--data-root", type=str, default=".", help="Root for dataset download/cache")
    p.add_argument("--cifar-split", type=str, default="train", choices=["train", "test"])
    p.add_argument("--imagenette-split", type=str, default="train", choices=["train", "val"])
    p.add_argument("--imagenette-size", type=str, default="160px", choices=["160px", "320px", "full"])
    p.add_argument("--single-class", action="store_true")
    p.add_argument("--class-id", type=int, default=0)
    p.add_argument("--batch-size", type=int, default=64)
    p.add_argument("--num-workers", type=int, default=4)

    p.add_argument("--epochs", type=int, default=30)
    p.add_argument("--device", type=str, default=None, help="cuda | cpu (default: auto)")
    p.add_argument("--log-interval", type=int, default=100)
    p.add_argument("--seed", type=int, default=0)

    p.add_argument("--save-dir", type=str, default="./runs/cfm_jit/checkpoints")
    p.add_argument(
        "--log-dir",
        type=str,
        default="./runs/cfm_jit/tensorboard",
    )
    p.add_argument("--run-name", type=str, default=None)

    p.add_argument(
        "--config",
        type=str,
        default=None,
        help="YAML with JiT + CFM hyperparameters (default: jit_config.yaml next to this script)",
    )

    return p.parse_args()
|
|
def _dim_from_yaml(value: Any) -> tuple[int, int, int]:
    if isinstance(value, (list, tuple)) and len(value) == 3:
        return (int(value[0]), int(value[1]), int(value[2]))
    raise ValueError("YAML 'dim' must be [C, H, W]")
|
|
@dataclass
class JiTTrainConfig:
    sigma: float
    dim: tuple[int, int, int]
    lr: float
    weight_decay: float
    inference_steps: int
    vis_batch_size: int
    input_size: int
    patch_size: int
    hidden_size: int
    depth: int
    num_heads: int
    mlp_ratio: float
    attn_drop: float
    proj_drop: float
    bottleneck_dim: int
    in_context_len: int
    in_context_start: int
|
|
REQUIRED_JIT_YAML_KEYS = (
    "sigma",
    "dim",
    "lr",
    "weight_decay",
    "inference_steps",
    "vis_batch_size",
    "input_size",
    "patch_size",
    "hidden_size",
    "depth",
    "num_heads",
    "mlp_ratio",
    "attn_drop",
    "proj_drop",
    "bottleneck_dim",
    "in_context_len",
    "in_context_start",
)
|
|
def load_jit_config_yaml(path: str | os.PathLike[str]) -> JiTTrainConfig:
    path = Path(path)
    if not path.is_file():
        raise FileNotFoundError(f"Config file not found: {path.resolve()}")

    with open(path, encoding="utf-8") as f:
        raw = yaml.safe_load(f)
    if raw is None or not isinstance(raw, dict):
        raise ValueError(f"Config must be a YAML mapping: {path}")

    missing = [k for k in REQUIRED_JIT_YAML_KEYS if k not in raw]
    if missing:
        raise ValueError(f"Missing keys in {path}: {missing}")

    dim = _dim_from_yaml(raw["dim"])
    input_size = int(raw["input_size"])
    if dim[1] != input_size or dim[2] != input_size:
        raise ValueError(f"dim {dim} must match input_size×input_size ({input_size})")

    return JiTTrainConfig(
        sigma=float(raw["sigma"]),
        dim=dim,
        lr=float(raw["lr"]),
        weight_decay=float(raw["weight_decay"]),
        inference_steps=int(raw["inference_steps"]),
        vis_batch_size=int(raw["vis_batch_size"]),
        input_size=input_size,
        patch_size=int(raw["patch_size"]),
        hidden_size=int(raw["hidden_size"]),
        depth=int(raw["depth"]),
        num_heads=int(raw["num_heads"]),
        mlp_ratio=float(raw["mlp_ratio"]),
        attn_drop=float(raw["attn_drop"]),
        proj_drop=float(raw["proj_drop"]),
        bottleneck_dim=int(raw["bottleneck_dim"]),
        in_context_len=int(raw["in_context_len"]),
        in_context_start=int(raw["in_context_start"]),
    )
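
# Example jit_config.yaml covering every required key (illustrative values
# only, not tuned recommendations):
#   sigma: 0.0
#   dim: [3, 32, 32]
#   lr: 1.0e-4
#   weight_decay: 0.0
#   inference_steps: 100
#   vis_batch_size: 16
#   input_size: 32
#   patch_size: 4
#   hidden_size: 384
#   depth: 6
#   num_heads: 6
#   mlp_ratio: 4.0
#   attn_drop: 0.0
#   proj_drop: 0.0
#   bottleneck_dim: 64
#   in_context_len: 0
#   in_context_start: 0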
|
|
|
|
def build_jit(cfg: JiTTrainConfig) -> JiT:
    c = cfg.dim[0]
    return JiT(
        input_size=cfg.input_size,
        patch_size=cfg.patch_size,
        in_channels=c,
        hidden_size=cfg.hidden_size,
        depth=cfg.depth,
        num_heads=cfg.num_heads,
        mlp_ratio=cfg.mlp_ratio,
        attn_drop=cfg.attn_drop,
        proj_drop=cfg.proj_drop,
        bottleneck_dim=cfg.bottleneck_dim,
        in_context_len=cfg.in_context_len,
        in_context_start=cfg.in_context_start,
    )
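
# Note (sketch): patch-based transformers like JiT split the input into
# (input_size / patch_size) ** 2 tokens, so input_size should be divisible by
# patch_size; e.g. input_size=32 with patch_size=4 yields 8 * 8 = 64 tokens.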
|
|
|
|
class CFMFlowWrapper(nn.Module):
    """
    Adapter between two call conventions: torchdyn's NeuralODE calls
    f(t, x), where t may be a scalar, while JiT expects forward(x, t)
    with t of shape (N,).
    """

    def __init__(self, model: JiT):
        super().__init__()
        self.model = model

    def forward(self, t: torch.Tensor, x: torch.Tensor, y=None, *args, **kwargs) -> torch.Tensor:
        batch = x.shape[0]
        t_flat = torch.as_tensor(t, device=x.device, dtype=torch.float32).reshape(-1)
        if t_flat.numel() == 1:
            # Scalar solver time: broadcast to one entry per sample.
            t_flat = t_flat.expand(batch)
        elif t_flat.shape[0] != batch:
            # Defensive fallback for unexpected time shapes from the solver.
            t_flat = t_flat[:batch]
        return self.model(x, t_flat)
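
# Usage sketch (assumes a trained `model` over 3x32x32 images):
#   node = NeuralODE(CFMFlowWrapper(model), solver="euler")
#   traj = node.trajectory(torch.randn(8, 3, 32, 32), t_span=torch.linspace(0, 1, 101))
#   samples = traj[-1]  # state at t=1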
|
|
|
|
def main() -> None:
    args = parse_args()
    default_cfg = Path(__file__).resolve().parent / "jit_config.yaml"
    config_path = Path(args.config).resolve() if args.config else default_cfg
    cfg = load_jit_config_yaml(config_path)
    print(f"Loaded JiT config from: {config_path}")

    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    device = torch.device(args.device or ("cuda" if torch.cuda.is_available() else "cpu"))
    print(f"Using device: {device}")

    os.makedirs(args.save_dir, exist_ok=True)

    tb_dir = os.path.join(args.log_dir, args.run_name) if args.run_name else args.log_dir
    os.makedirs(tb_dir, exist_ok=True)
    writer = SummaryWriter(log_dir=tb_dir)
    writer.add_text("config/args", str(vars(args)), 0)
    writer.add_text("config/jit_yaml", config_path.read_text(encoding="utf-8"), 0)
|
|
    # v2.ToImage() + ToDtype(scale=True) is the v2 replacement for the
    # deprecated v2.ToTensor(): PIL image -> float32 tensor in [0, 1].
    transforms = v2.Compose(
        [
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Resize((cfg.input_size, cfg.input_size)),
            # Identity normalization: pixels stay in [0, 1], matching the
            # clamp applied to generated samples below.
            v2.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
        ]
    )
    train_dataset = load_training_dataset(args, transforms)
    print(f"Dataset: {args.dataset}, size={len(train_dataset)}")

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=device.type == "cuda",
    )

    total_optimizer_steps = len(train_dataloader) * args.epochs
|
|
    fm = ConditionalFlowMatcher(sigma=cfg.sigma)
    net_model = build_jit(cfg).to(device)
    ode_net = CFMFlowWrapper(net_model)

    optim = torch.optim.AdamW(net_model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    # Note: with its default factors, LinearLR ramps the LR up from lr/3 to lr
    # over total_iters; pass explicit start/end factors for a decay schedule.
    scheduler = torch.optim.lr_scheduler.LinearLR(optim, total_iters=max(total_optimizer_steps, 1))
    # inference_steps + 1 time points => inference_steps Euler steps from t=0 to t=1.
    t_span = torch.linspace(0, 1, cfg.inference_steps + 1, device=device)

    c, h, w = cfg.dim
    global_step = 0
    best_loss = float("inf")
|
|
    for ep in range(args.epochs):
        net_model.train()
        epoch_loss = 0.0
        num_batches = 0

        for data in train_dataloader:
            x1 = data[0].to(device, non_blocking=True)
            # CFM step: pair each image x1 with Gaussian noise x0, sample
            # (t, x_t, u_t) along the conditional path, and regress the
            # network's predicted velocity onto the target flow u_t.
            x0 = torch.randn_like(x1)
            t, xt, ut = fm.sample_location_and_conditional_flow(x0, x1)
            t_b = t.reshape(-1).float()
            vt = net_model(xt, t_b)
            loss = torch.mean((vt - ut) ** 2)

            optim.zero_grad(set_to_none=True)
            loss.backward()
            optim.step()
            scheduler.step()

            epoch_loss += loss.item()
            num_batches += 1

            writer.add_scalar("train/loss_step", loss.item(), global_step)
            writer.add_scalar("train/lr", scheduler.get_last_lr()[0], global_step)

            if global_step % args.log_interval == 0:
                print(f"[step {global_step}] loss = {loss.item():.6f}")

            global_step += 1

        avg_epoch_loss = epoch_loss / max(num_batches, 1)
        writer.add_scalar("train/loss_epoch", avg_epoch_loss, ep)
        print(f"[epoch {ep}] avg loss = {avg_epoch_loss:.6f}")
|
|
        # End-of-epoch visualization: integrate the learned flow from pure noise.
        net_model.eval()
        node = NeuralODE(ode_net, solver="euler")
        with torch.no_grad():
            x_vis = torch.randn(cfg.vis_batch_size, c, h, w, device=device)
            traj = node.trajectory(x_vis, t_span=t_span)  # (steps + 1, N, C, H, W)
            x_final = traj[-1].clamp(0.0, 1.0).cpu()
            grid = torchvision.utils.make_grid(x_final, nrow=4, padding=2, normalize=False)
            writer.add_image("samples/neural_ode_final", grid, ep)

        # Periodic checkpoint every 30 epochs, plus one at the final epoch so
        # short runs still end with a saved model.
        if ep % 30 == 0 or ep == args.epochs - 1:
            ckpt_path = os.path.join(args.save_dir, f"model_epoch_{ep}.pt")
            torch.save(net_model.state_dict(), ckpt_path)

        # best_loss starts at +inf, so the first epoch always saves.
        if avg_epoch_loss < best_loss:
            best_loss = avg_epoch_loss
            torch.save(net_model.state_dict(), os.path.join(args.save_dir, "model_best.pt"))
|
|
    writer.close()
    print(f"Done. Checkpoints: {args.save_dir}")
    print(f"TensorBoard: tensorboard --logdir {tb_dir}")
|
|
if __name__ == "__main__":
    main()
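
# Sampling from a saved checkpoint afterwards (sketch; assumes the training
# YAML is unchanged):
#   cfg = load_jit_config_yaml("jit_config.yaml")
#   model = build_jit(cfg)
#   model.load_state_dict(torch.load("runs/cfm_jit/checkpoints/model_best.pt", map_location="cpu"))
#   node = NeuralODE(CFMFlowWrapper(model.eval()), solver="euler")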
|
|