#!/usr/bin/env python3
"""
CFM training with unconditional JiT (``jit.JiT``).

Mirrors train_cfm_unet.py (data, TensorBoard, checkpoints); model + YAML differ.
JiT expects forward(x, t); torchdyn NeuralODE calls f(t, x) — use CFMFlowWrapper.
"""

from __future__ import annotations

import argparse
import os
from dataclasses import dataclass, fields
from pathlib import Path
from typing import Any

import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import v2
from torchdyn.core import NeuralODE
from torchcfm.conditional_flow_matching import ConditionalFlowMatcher

from jit import JiT

try:
    import yaml  # type: ignore[import-untyped]
except ImportError as e:  # pragma: no cover
    # The YAML config is always read (a default path is used when --config is
    # omitted), so PyYAML is a hard requirement of this script.
    raise ImportError("Please `pip install pyyaml`: JiT/CFM hyperparameters are read from a YAML config.") from e

# Reuse dataset helpers from UNet trainer (same CLI for data)
from train_unet import load_training_dataset


def parse_args() -> argparse.Namespace:
    """Parse and validate command-line options (data, loop, logging, checkpoints)."""
    p = argparse.ArgumentParser(description="Train unconditional JiT with Conditional Flow Matching")
    p.add_argument(
        "--dataset",
        type=str,
        default="imagenette",
        choices=["cifar10", "imagenette"],
        help="Training dataset",
    )
    p.add_argument("--data-root", type=str, default=".", help="Root for dataset download/cache")
    p.add_argument("--cifar-split", type=str, default="train", choices=["train", "test"])
    p.add_argument("--imagenette-split", type=str, default="train", choices=["train", "val"])
    p.add_argument("--imagenette-size", type=str, default="160px", choices=["160px", "320px", "full"])
    p.add_argument("--single-class", action="store_true")
    p.add_argument("--class-id", type=int, default=0)
    p.add_argument("--batch-size", type=int, default=64)
    p.add_argument("--num-workers", type=int, default=4)
    p.add_argument("--epochs", type=int, default=30)
    p.add_argument("--device", type=str, default=None, help="cuda | cpu (default: auto)")
    p.add_argument("--log-interval", type=int, default=100)
    p.add_argument("--seed", type=int, default=0)
    p.add_argument("--save-dir", type=str, default="./runs/cfm_jit/checkpoints")
    p.add_argument(
        "--save-every",
        type=int,
        default=30,
        help="Write model_epoch_<ep>.pt every N epochs; the final epoch is always written",
    )
    p.add_argument(
        "--log-dir",
        type=str,
        default="./runs/cfm_jit/tensorboard",
    )
    p.add_argument("--run-name", type=str, default=None)
    p.add_argument(
        "--config",
        type=str,
        default=None,
        help="YAML with JiT + CFM hyperparameters (default: jit_config.yaml next to this script)",
    )
    args = p.parse_args()
    # Both values are used as modulo divisors in the training loop; 0 would crash mid-run.
    if args.log_interval < 1:
        p.error("--log-interval must be >= 1")
    if args.save_every < 1:
        p.error("--save-every must be >= 1")
    return args


def _dim_from_yaml(value: Any) -> tuple[int, int, int]:
    """Convert a YAML ``[C, H, W]`` sequence into an int 3-tuple.

    Raises:
        ValueError: if ``value`` is not a list/tuple of exactly three items.
    """
    if isinstance(value, (list, tuple)) and len(value) == 3:
        return (int(value[0]), int(value[1]), int(value[2]))
    raise ValueError("YAML 'dim' must be [C, H, W]")


@dataclass
class JiTTrainConfig:
    """Typed view of the JiT + CFM YAML config (every field is a required YAML key)."""

    # --- CFM / optimisation / sampling ---
    sigma: float  # passed to ConditionalFlowMatcher(sigma=...)
    dim: tuple[int, int, int]  # (C, H, W); H and W must equal input_size
    lr: float
    weight_decay: float
    inference_steps: int  # Euler steps used when sampling the visualisation grid
    vis_batch_size: int  # number of images sampled per epoch for TensorBoard
    # --- JiT architecture (forwarded verbatim to JiT(...)) ---
    input_size: int
    patch_size: int
    hidden_size: int
    depth: int
    num_heads: int
    mlp_ratio: float
    attn_drop: float
    proj_drop: float
    bottleneck_dim: int
    in_context_len: int
    in_context_start: int


# Derived from the dataclass so the required-key list can never drift from the fields.
REQUIRED_JIT_YAML_KEYS = tuple(f.name for f in fields(JiTTrainConfig))


def load_jit_config_yaml(path: str | os.PathLike[str]) -> JiTTrainConfig:
    """Read and validate a JiT/CFM YAML config.

    Raises:
        FileNotFoundError: if ``path`` is not an existing file.
        ValueError: if the YAML is not a mapping, a required key is missing,
            ``dim`` is malformed, or ``dim``'s H/W disagree with ``input_size``.
    """
    path = Path(path)
    if not path.is_file():
        raise FileNotFoundError(f"Config file not found: {path.resolve()}")
    with open(path, encoding="utf-8") as f:
        raw = yaml.safe_load(f)  # safe_load: never constructs arbitrary Python objects
    if not isinstance(raw, dict):  # also covers an empty file (safe_load -> None)
        raise ValueError(f"Config must be a YAML mapping: {path}")
    missing = [k for k in REQUIRED_JIT_YAML_KEYS if k not in raw]
    if missing:
        raise ValueError(f"Missing keys in {path}: {missing}")

    dim = _dim_from_yaml(raw["dim"])
    input_size = int(raw["input_size"])
    if dim[1] != input_size or dim[2] != input_size:
        raise ValueError(f"dim {dim} must match input_size×input_size ({input_size})")

    return JiTTrainConfig(
        sigma=float(raw["sigma"]),
        dim=dim,
        lr=float(raw["lr"]),
        weight_decay=float(raw["weight_decay"]),
        inference_steps=int(raw["inference_steps"]),
        vis_batch_size=int(raw["vis_batch_size"]),
        input_size=input_size,
        patch_size=int(raw["patch_size"]),
        hidden_size=int(raw["hidden_size"]),
        depth=int(raw["depth"]),
        num_heads=int(raw["num_heads"]),
        mlp_ratio=float(raw["mlp_ratio"]),
        attn_drop=float(raw["attn_drop"]),
        proj_drop=float(raw["proj_drop"]),
        bottleneck_dim=int(raw["bottleneck_dim"]),
        in_context_len=int(raw["in_context_len"]),
        in_context_start=int(raw["in_context_start"]),
    )


def build_jit(cfg: JiTTrainConfig) -> JiT:
    """Instantiate JiT from the config; input channels come from ``cfg.dim[0]``."""
    return JiT(
        input_size=cfg.input_size,
        patch_size=cfg.patch_size,
        in_channels=cfg.dim[0],
        hidden_size=cfg.hidden_size,
        depth=cfg.depth,
        num_heads=cfg.num_heads,
        mlp_ratio=cfg.mlp_ratio,
        attn_drop=cfg.attn_drop,
        proj_drop=cfg.proj_drop,
        bottleneck_dim=cfg.bottleneck_dim,
        in_context_len=cfg.in_context_len,
        in_context_start=cfg.in_context_start,
    )


class CFMFlowWrapper(nn.Module):
    """
    Adapt JiT's ``forward(x, t)`` to the ``f(t, x)`` vector-field call convention
    used by torchdyn's NeuralODE.

    ``t`` is broadcast to one value per batch element (JiT takes ``t`` of shape (N,)).
    Extra positional/keyword arguments (e.g. torchdyn's ``args``) are accepted and ignored.
    """

    def __init__(self, model: JiT):
        super().__init__()
        self.model = model

    def forward(self, t: torch.Tensor, x: torch.Tensor, y=None, *args, **kwargs) -> torch.Tensor:
        batch = x.shape[0]
        t_flat = torch.as_tensor(t, device=x.device, dtype=torch.float32).reshape(-1)
        if t_flat.numel() == 1:
            # Scalar time from the ODE solver -> same t for every sample.
            t_flat = t_flat.expand(batch)
        elif t_flat.shape[0] != batch:
            # NOTE(review): longer t is silently truncated; a shorter non-scalar t
            # will reach the model with a mismatched size.
            t_flat = t_flat[:batch]
        return self.model(x, t_flat)


def _build_transforms(input_size: int) -> v2.Compose:
    """Image -> float32 tensor in [0, 1], resized to ``input_size`` x ``input_size``.

    Pixels stay in [0, 1] (no normalisation) because sampled images are clamped to
    that range before logging.
    """
    return v2.Compose(
        [
            # ToImage + ToDtype(scale=True) is torchvision's documented replacement
            # for the deprecated v2.ToTensor().
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Resize((input_size, input_size)),
        ]
    )


def _train_one_epoch(
    net_model: JiT,
    fm: ConditionalFlowMatcher,
    loader: DataLoader,
    optim: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.LRScheduler,
    writer: SummaryWriter,
    device: torch.device,
    global_step: int,
    log_interval: int,
) -> tuple[float, int]:
    """Run one CFM pass over ``loader``.

    Returns:
        (mean per-batch loss, updated global step). The mean is 0.0 for an empty loader.
    """
    net_model.train()
    epoch_loss = 0.0
    num_batches = 0
    for data in loader:
        # Presumably load_training_dataset yields (image, label); only the image is used.
        x1 = data[0].to(device, non_blocking=True)
        x0 = torch.randn_like(x1)  # Gaussian source samples
        t, xt, ut = fm.sample_location_and_conditional_flow(x0, x1)
        vt = net_model(xt, t.reshape(-1).float())
        loss = torch.mean((vt - ut) ** 2)

        optim.zero_grad(set_to_none=True)
        loss.backward()
        optim.step()
        scheduler.step()  # stepped per optimizer step, not per epoch

        loss_val = loss.item()  # one device->host sync per step
        epoch_loss += loss_val
        num_batches += 1
        writer.add_scalar("train/loss_step", loss_val, global_step)
        writer.add_scalar("train/lr", scheduler.get_last_lr()[0], global_step)
        if global_step % log_interval == 0:
            print(f"[step {global_step}] loss = {loss_val:.6f}")
        global_step += 1
    return epoch_loss / max(num_batches, 1), global_step


@torch.no_grad()
def _sample_grid(
    node: NeuralODE,
    t_span: torch.Tensor,
    num_samples: int,
    dim: tuple[int, int, int],
    device: torch.device,
) -> torch.Tensor:
    """Integrate noise through the learned flow and return an image grid (on CPU)."""
    c, h, w = dim
    x_vis = torch.randn(num_samples, c, h, w, device=device)
    traj = node.trajectory(x_vis, t_span=t_span)
    x_final = traj[-1].clamp(0.0, 1.0).cpu()  # last time point == t=1 (data side)
    return torchvision.utils.make_grid(x_final, nrow=4, padding=2, normalize=False)


def main() -> None:
    """Load config + data, train JiT with CFM, log to TensorBoard, save checkpoints."""
    args = parse_args()
    default_cfg = Path(__file__).resolve().parent / "jit_config.yaml"
    config_path = Path(args.config).resolve() if args.config else default_cfg
    cfg = load_jit_config_yaml(config_path)
    print(f"Loaded JiT config from: {config_path}")

    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    device = torch.device(args.device or ("cuda" if torch.cuda.is_available() else "cpu"))
    print(f"Using device: {device}")

    os.makedirs(args.save_dir, exist_ok=True)
    tb_dir = os.path.join(args.log_dir, args.run_name) if args.run_name else args.log_dir
    os.makedirs(tb_dir, exist_ok=True)

    writer = SummaryWriter(log_dir=tb_dir)
    try:  # make sure event files are flushed/closed even if training raises
        writer.add_text("config/args", str(vars(args)), 0)
        writer.add_text("config/jit_yaml", config_path.read_text(encoding="utf-8"), 0)

        train_dataset = load_training_dataset(args, _build_transforms(cfg.input_size))
        print(f"Dataset: {args.dataset}, size={len(train_dataset)}")

        train_loader = DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            pin_memory=device.type == "cuda",
        )
        total_optimizer_steps = len(train_loader) * args.epochs

        fm = ConditionalFlowMatcher(sigma=cfg.sigma)
        net_model = build_jit(cfg).to(device)
        ode_net = CFMFlowWrapper(net_model)

        optim = torch.optim.AdamW(net_model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
        # Factors spelled out (they are LinearLR's defaults): the LR ramps UP from
        # lr/3 to lr across the whole run.
        # NOTE(review): confirm this is intended; a linear decay would be start_factor=1.0, end_factor<1.
        scheduler = torch.optim.lr_scheduler.LinearLR(
            optim,
            start_factor=1.0 / 3,
            end_factor=1.0,
            total_iters=max(total_optimizer_steps, 1),
        )

        t_span = torch.linspace(0, 1, cfg.inference_steps + 1, device=device)
        node = NeuralODE(ode_net, solver="euler")  # built once, reused every epoch

        global_step = 0
        best_loss = float("inf")
        for ep in range(args.epochs):
            avg_epoch_loss, global_step = _train_one_epoch(
                net_model,
                fm,
                train_loader,
                optim,
                scheduler,
                writer,
                device,
                global_step,
                args.log_interval,
            )
            writer.add_scalar("train/loss_epoch", avg_epoch_loss, ep)
            print(f"[epoch {ep}] avg loss = {avg_epoch_loss:.6f}")

            net_model.eval()
            grid = _sample_grid(node, t_span, cfg.vis_batch_size, cfg.dim, device)
            writer.add_image("samples/neural_ode_final", grid, ep)

            state = net_model.state_dict()
            # Periodic snapshot; the final epoch is always kept (with the old
            # hard-coded `ep % 30` and 30 epochs, only epoch 0 was ever written).
            if ep % args.save_every == 0 or ep == args.epochs - 1:
                torch.save(state, os.path.join(args.save_dir, f"model_epoch_{ep}.pt"))

            # `<` is False for NaN, so a NaN epoch never becomes the recorded best
            # (previously a NaN at epoch 0 blocked every later update).
            improved = avg_epoch_loss < best_loss
            if improved:
                best_loss = avg_epoch_loss
            if improved or ep == 0:
                torch.save(state, os.path.join(args.save_dir, "model_best.pt"))
    finally:
        writer.close()

    print(f"Done. Checkpoints: {args.save_dir}")
    print(f"TensorBoard: tensorboard --logdir {tb_dir}")


if __name__ == "__main__":
    main()