| """Train TactileVAE on the fota_unlabeled parquet dataset. |
| |
| Run: |
| python tactile_vae/script/train_vae.py --config tactile_vae/config/train_vae.yaml |
| |
| Each run lives in `<runs_root>/<run_id>/`. Re-launching with the same |
| `run_id` auto-resumes from `ckpt_last.pt` in that directory (override with |
| `--no-resume`, or `--resume-from <path>` to resume from a specific file). |
| |
| Writes into `<runs_root>/<run_id>/`: |
| - metrics.csv per-step training + per-eval validation metrics |
| - samples/step_*.png input vs. reconstruction grid (val-set images) |
| - ckpt_last.pt most recent checkpoint |
| - ckpt_step_*.pt periodic checkpoints (rotated; keep_last_ckpts) |
| - ckpt_best.pt lowest monitored validation metric (default: val/total) |
| - run.log stdout mirror |
| - config.snapshot.yaml the resolved config (first launch — preserved on resume) |
| |
| Checkpoints are saved as: |
| {"state_dict": ..., "optimizer": ..., "scaler": ... (only when a GradScaler is in use), |
| "step": int, "epoch": int, "config": dict, "best_val_metric": float, |
| "best_metric_name": str, "best_val_recon": float (legacy alias of best_val_metric)} |
| which `tactile_vae.model.load_pretrained` can re-open via its `state_dict` key. |
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| import csv |
| import datetime as dt |
| import math |
| import os |
| import random |
| import signal |
| import sys |
| import time |
| from collections import deque |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Any |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import yaml |
| from PIL import Image |
| from torch.utils.data import DataLoader |
| import torch.nn.functional as F |
|
|
| try: |
| import wandb |
| except ImportError: |
| wandb = None |
|
|
| _REPO_ROOT = Path(__file__).resolve().parents[2] |
| if str(_REPO_ROOT) not in sys.path: |
| sys.path.insert(0, str(_REPO_ROOT)) |
|
|
| from tactile_vae.dataset import ( |
| ColorJitterConfig, |
| ParquetFileShuffleSampler, |
| TactileParquetDataset, |
| ) |
| from tactile_vae.model import TactileVAE, VAELoss |
|
|
|
|
| |
| |
| |
|
|
def _resolve_path(p: str | None) -> Path | None:
    """Turn a config path string into a `Path`, anchoring relative ones at the repo root.

    `None` passes through unchanged. An absolute path is returned as given
    (it is not resolved); a relative path is joined onto `_REPO_ROOT` and
    resolved.
    """
    if p is None:
        return None
    candidate = Path(p)
    if candidate.is_absolute():
        return candidate
    return (_REPO_ROOT / candidate).resolve()
|
|
|
|
def _autogenerate_run_id() -> str:
    """Build a default run id of the form `run_YYYY-MM-DD_HH-MM-SS` from the local clock."""
    stamp = dt.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    return f"run_{stamp}"
|
|
|
|
def load_config(path: Path) -> dict:
    """Read the YAML config at `path` and normalise its run id and filesystem paths.

    - A missing or empty `run_id` is replaced by a timestamp-based one.
    - `output_dir` is taken from the config when set; otherwise it is derived
      as `<runs_root>/<run_id>` (`runs_root` defaults to "runs").
    - `output_dir`, `data.root`, `data.splits_path` and `train.resume_from`
      are passed through `_resolve_path` and stored back as strings.

    Returns the (mutated) config dict.
    """
    with path.open() as handle:
        cfg = yaml.safe_load(handle)

    if not cfg.get("run_id"):
        cfg["run_id"] = _autogenerate_run_id()

    # An explicit output_dir wins; otherwise the run lives under runs_root/run_id.
    explicit_out = cfg.get("output_dir")
    if explicit_out:
        out_dir = _resolve_path(explicit_out)
    else:
        out_dir = _resolve_path(cfg.get("runs_root", "runs")) / cfg["run_id"]
    cfg["output_dir"] = str(out_dir)

    data_cfg = cfg["data"]
    data_cfg["root"] = str(_resolve_path(data_cfg["root"]))
    splits_path = data_cfg.get("splits_path")
    if splits_path:
        data_cfg["splits_path"] = str(_resolve_path(splits_path))

    train_cfg = cfg["train"]
    resume_from = train_cfg.get("resume_from")
    if resume_from:
        train_cfg["resume_from"] = str(_resolve_path(resume_from))
    return cfg
|
|
|
|
def _maybe_autoresume(cfg: dict, *, allow_autoresume: bool) -> None:
    """Point `cfg["train"]["resume_from"]` at `<output_dir>/ckpt_last.pt` when appropriate.

    Nothing happens when the user already pinned `resume_from`, when
    auto-resume is disallowed, or when the run directory has no
    `ckpt_last.pt`. Mutates `cfg["train"]` in place.
    """
    train_cfg = cfg["train"]
    if train_cfg.get("resume_from"):
        return
    if not allow_autoresume:
        return
    candidate = Path(cfg["output_dir"]) / "ckpt_last.pt"
    if candidate.exists():
        train_cfg["resume_from"] = str(candidate)
|
|
|
|
def set_seed(seed: int) -> None:
    """Seed Python's `random`, NumPy and PyTorch (plus all CUDA devices when CUDA is available)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
|
|
|
|
def pick_device(spec: str) -> torch.device:
    """Map a device spec to a `torch.device`; "auto" means CUDA if available, else CPU."""
    if spec != "auto":
        return torch.device(spec)
    chosen = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(chosen)
|
|
|
|
def init_wandb(config: dict, output_dir: Path) -> Any:
    """Start a wandb run when the `wandb` package is importable and `WANDB_PROJECT` is set.

    The wandb run id and display name come from `WANDB_RUN_ID` / `WANDB_NAME`
    when set, otherwise from `config["run_id"]`; together with
    `resume="allow"` this lets a re-launch with the same run id continue the
    same wandb run. `WANDB_ENTITY` and `WANDB_MODE` (default "online") are
    read from the environment as well, and wandb files go under `output_dir`.

    Returns the wandb run handle, or None when wandb is unavailable or
    `WANDB_PROJECT` is unset/empty.
    """
    if wandb is None:
        return None
    env = os.environ
    if not env.get("WANDB_PROJECT"):
        return None
    return wandb.init(
        project=env["WANDB_PROJECT"],
        entity=env.get("WANDB_ENTITY"),
        id=env.get("WANDB_RUN_ID") or config["run_id"],
        name=env.get("WANDB_NAME") or config["run_id"],
        resume="allow",
        config=config,
        mode=env.get("WANDB_MODE", "online"),
        dir=str(output_dir),
    )
|
|
|
|
class TeeLogger:
    """Stand-in for `sys.stdout` that also appends everything written to a file.

    The stream that was `sys.stdout` at construction time keeps receiving
    every write; the same text is appended to `path`. Attributes this class
    does not define itself (`isatty`, `fileno`, `encoding`, ...) are forwarded
    to that wrapped stream, so code that probes `sys.stdout` keeps working
    after the swap.
    """

    def __init__(self, path: Path):
        # Capture the wrapped stream first so attribute delegation works even
        # if creating or opening the log file fails below.
        self._stdout = sys.stdout
        path.parent.mkdir(parents=True, exist_ok=True)
        # Append mode keeps the log of earlier launches; buffering=1 is line-buffered.
        self._fh = path.open("a", buffering=1)

    def write(self, msg: str) -> int:
        """Write `msg` to the wrapped stream and the log file; return the character count."""
        self._stdout.write(msg)
        if not self._fh.closed:
            self._fh.write(msg)
        return len(msg)

    def flush(self) -> None:
        """Flush the wrapped stream and, if still open, the log file."""
        self._stdout.flush()
        if not self._fh.closed:
            self._fh.flush()

    def close(self) -> None:
        """Flush and close the log file (idempotent); the wrapped stream is left open."""
        if not self._fh.closed:
            self._fh.flush()
            self._fh.close()

    def __getattr__(self, name: str) -> Any:
        # Only called for attributes not found the normal way. Going through
        # object.__getattribute__ avoids infinite recursion if `_stdout` was
        # never set (e.g. on a partially constructed instance).
        wrapped = object.__getattribute__(self, "_stdout")
        return getattr(wrapped, name)
|
|
|
|
| |
| |
| |
|
|
def build_datasets(data_cfg: dict) -> tuple[TactileParquetDataset, TactileParquetDataset]:
    """Construct the train and val `TactileParquetDataset`s from the `data` config section.

    Both splits receive the same root / image_size / cache_files /
    splits_path / return_meta (and meta_columns when configured). A
    `ColorJitterConfig` built from `data.color_jitter` is passed to the train
    split only; the val split always gets `color_jitter=None`.

    Returns `(train_ds, val_ds)`.
    """
    shared_kwargs = {
        "root": data_cfg["root"],
        "image_size": data_cfg["image_size"],
        "cache_files": data_cfg.get("cache_files", 1),
        "splits_path": data_cfg.get("splits_path"),
        "return_meta": data_cfg.get("return_meta", False),
    }
    meta_columns = data_cfg.get("meta_columns")
    if meta_columns:
        # Only forwarded when set, so the dataset's own default applies otherwise.
        shared_kwargs["meta_columns"] = meta_columns

    jitter_kwargs = data_cfg.get("color_jitter")
    train_jitter = None
    if jitter_kwargs:
        train_jitter = ColorJitterConfig(**jitter_kwargs)

    training_set = TactileParquetDataset(split="train", color_jitter=train_jitter, **shared_kwargs)
    validation_set = TactileParquetDataset(split="val", color_jitter=None, **shared_kwargs)
    return training_set, validation_set
|
|
|
|
def build_model(model_cfg: dict) -> TactileVAE:
    """Instantiate `TactileVAE`, forwarding every entry of `model_cfg` as a keyword argument."""
    vae = TactileVAE(**model_cfg)
    return vae
|
|
|
|
class ConfigurablePerceptualVAELoss(nn.Module):
    """VAE loss with a configurable perceptual term: SSIM or LPIPS.

    `loss_cfg["perceptual_type"]` (default "ssim", case-insensitive) selects
    the mode:

    - "ssim": the whole computation is delegated to `VAELoss(**loss_cfg)`.
    - "lpips": `recon + lpips_weight * LPIPS + beta * KL`, computed here with
      a frozen AlexNet-backed LPIPS network. `recon` is L1 or MSE
      (`recon_type`, default "l1"); `lpips_weight` falls back to
      `ssim_weight` and then to 0.1; `beta` defaults to 1e-3.

    `aux_key` holds the selected perceptual type ("ssim" or "lpips").

    Raises:
        ValueError: for an unknown `perceptual_type`, or (LPIPS mode) an
            unknown `recon_type`. Both are checked at construction time.
        ImportError: when LPIPS is requested but `lpips` is not installed.
    """

    def __init__(self, loss_cfg: dict):
        super().__init__()
        self.perceptual_type = str(loss_cfg.get("perceptual_type", "ssim")).lower()
        if self.perceptual_type not in {"ssim", "lpips"}:
            raise ValueError(
                f"loss.perceptual_type must be one of [ssim, lpips], got: {self.perceptual_type!r}"
            )
        self.aux_key = self.perceptual_type
        self.ssim_impl: VAELoss | None = None
        self.lpips_impl: nn.Module | None = None

        if self.perceptual_type == "ssim":
            # NOTE(review): the full loss_cfg (including `perceptual_type`) is
            # forwarded; presumably VAELoss accepts those keys — confirm.
            self.ssim_impl = VAELoss(**loss_cfg)
        else:
            self.beta = float(loss_cfg.get("beta", 1e-3))
            self.recon_type = str(loss_cfg.get("recon_type", "l1")).lower()
            # Fail fast: reject a bad recon_type here instead of on the first
            # forward pass, after data loading and model construction.
            if self.recon_type not in {"l1", "mse"}:
                raise ValueError(
                    f"loss.recon_type must be one of [l1, mse], got: {self.recon_type!r}"
                )
            self.lpips_weight = float(
                loss_cfg.get("lpips_weight", loss_cfg.get("ssim_weight", 0.1))
            )
            try:
                import lpips
            except ImportError as exc:
                raise ImportError(
                    "LPIPS loss requested but `lpips` is not installed. "
                    "Install with: pip install lpips"
                ) from exc
            self.lpips_impl = lpips.LPIPS(net="alex")
            # The perceptual network is a fixed metric: eval mode, no gradients
            # to its own weights.
            self.lpips_impl.eval()
            for p in self.lpips_impl.parameters():
                p.requires_grad = False

    def forward(
        self,
        x_hat: torch.Tensor,
        x: torch.Tensor,
        mu: torch.Tensor,
        logvar: torch.Tensor,
    ) -> dict[str, torch.Tensor]:
        """Compute the loss terms for a reconstruction `x_hat` of `x`.

        In SSIM mode this returns whatever `VAELoss` returns. In LPIPS mode it
        returns a dict with keys "total", "recon", "recon_total", "lpips" and
        "kl".
        """
        if self.perceptual_type == "ssim":
            assert self.ssim_impl is not None
            return self.ssim_impl(x_hat, x, mu, logvar)

        if self.recon_type == "l1":
            recon = F.l1_loss(x_hat, x)
        elif self.recon_type == "mse":
            recon = F.mse_loss(x_hat, x)
        else:
            # Still guarded in case `recon_type` was reassigned after construction.
            raise ValueError(f"loss.recon_type must be one of [l1, mse], got: {self.recon_type!r}")

        # LPIPS runs in float32 with autocast disabled. Inputs are rescaled
        # with 2*x - 1 (assumes images are in [0, 1] — TODO confirm).
        with torch.amp.autocast(device_type=x_hat.device.type, enabled=False):
            x_hat_lp = (2.0 * x_hat.float()) - 1.0
            x_lp = (2.0 * x.float()) - 1.0
            assert self.lpips_impl is not None
            lpips_val = self.lpips_impl(x_hat_lp, x_lp).mean()
        recon_total = recon + self.lpips_weight * lpips_val
        # KL(q(z|x) || N(0, I)) for a diagonal Gaussian, averaged over all elements.
        kl = (-0.5 * (1 + logvar - mu.pow(2) - logvar.exp())).mean()
        total = recon_total + self.beta * kl
        return {
            "total": total,
            "recon": recon,
            "recon_total": recon_total,
            "lpips": lpips_val,
            "kl": kl,
        }
|
|
|
|
def build_loss(loss_cfg: dict) -> nn.Module:
    """Create the training criterion from the `loss` config section."""
    criterion = ConfigurablePerceptualVAELoss(loss_cfg)
    return criterion
|
|
|
|
def build_optimizer(params, optim_cfg: dict) -> torch.optim.Optimizer:
    """Create an AdamW optimizer over `params` from the `optim` config section.

    `lr` is required; `weight_decay` defaults to 0.0, `betas` to (0.9, 0.95)
    and `eps` to 1e-8.
    """
    hyper = {
        "lr": optim_cfg["lr"],
        "weight_decay": optim_cfg.get("weight_decay", 0.0),
        # YAML yields a list; normalise to a tuple.
        "betas": tuple(optim_cfg.get("betas", (0.9, 0.95))),
        "eps": optim_cfg.get("eps", 1e-8),
    }
    return torch.optim.AdamW(params, **hyper)
|
|
|
|
def lr_at_step(step: int, base_lr: float, total_steps: int, sched_cfg: dict) -> float:
    """Return the learning rate to use at (0-based) `step`.

    Schedule:
      - Linear warmup over `warmup_steps` (default 0): step `s` gets
        `base_lr * (s + 1) / warmup_steps`.
      - Afterwards, `type: constant` (the default) holds `base_lr`, and
        `type: cosine` decays from `base_lr` to `base_lr * min_lr_ratio`
        (default 0.1) over the remaining `total_steps - warmup_steps` steps,
        staying at the floor beyond `total_steps`.

    Raises:
        ValueError: if `sched_cfg["type"]` is neither "constant" nor "cosine".
            This is checked before the warmup branch so a typo fails on the
            very first step rather than only once warmup has finished.
    """
    warmup = int(sched_cfg.get("warmup_steps", 0))
    sched = sched_cfg.get("type", "constant")
    if sched not in ("constant", "cosine"):
        raise ValueError(f"unknown scheduler type: {sched}")

    if step < warmup:
        return base_lr * (step + 1) / max(1, warmup)
    if sched == "constant":
        return base_lr

    # Cosine decay; progress is clamped so steps outside [warmup, total_steps]
    # map to the schedule's end points.
    min_ratio = float(sched_cfg.get("min_lr_ratio", 0.1))
    progress = (step - warmup) / max(1, total_steps - warmup)
    progress = min(max(progress, 0.0), 1.0)
    cos = 0.5 * (1.0 + math.cos(math.pi * progress))
    return base_lr * (min_ratio + (1 - min_ratio) * cos)
|
|
|
|
| |
| |
| |
|
|
@dataclass
class MetricAccum:
    """Running weighted mean of a scalar metric."""

    # Sum of `value * count` over all `add` calls.
    sum: float = 0.0
    # Total weight (sum of `count`) seen so far.
    n: int = 0

    def add(self, v: float, count: int = 1) -> None:
        """Fold in value `v` with weight `count`."""
        weighted = v * count
        self.sum += weighted
        self.n += count

    def mean(self) -> float:
        """Return the weighted mean, or NaN when nothing has been added."""
        if not self.n:
            return float("nan")
        return self.sum / self.n
|
|
|
|
@torch.no_grad()
def run_validation(
    model: TactileVAE,
    criterion: nn.Module,
    loader: DataLoader,
    device: torch.device,
    max_batches: int,
) -> dict[str, float]:
    """Evaluate `model` on up to `max_batches` batches of `loader`.

    The model is switched to eval mode for the pass, called with
    `sample=False`, and switched back to train mode afterwards. Every entry of
    the dict returned by `criterion` is averaged over the evaluated batches,
    weighted by batch size.

    Returns a dict mapping `"val/<loss key>"` to that mean.
    """
    model.eval()
    accs: dict[str, MetricAccum] = {}
    for batch_idx, batch in enumerate(loader):
        if batch_idx >= max_batches:
            break
        x = batch.to(device, non_blocking=True)
        out = model(x, sample=False)
        losses = criterion(out["x_hat"], x, out["mu"], out["logvar"])
        batch_size = x.shape[0]
        for name, value in losses.items():
            acc = accs.get(name)
            if acc is None:
                acc = accs[name] = MetricAccum()
            acc.add(value.item(), batch_size)
    model.train()

    results: dict[str, float] = {}
    for name, acc in accs.items():
        results[f"val/{name}"] = acc.mean()
    return results
|
|
|
|
def _to_uint8_hwc(t: torch.Tensor) -> np.ndarray:
    """Convert a 3-D tensor to a uint8 NumPy array with axis 0 moved last (CHW -> HWC).

    Values are clamped to [0, 1], scaled by 255 and truncated to uint8.
    """
    clamped = t.detach().cpu().clamp(0, 1)
    hwc = clamped.permute(1, 2, 0).numpy()
    scaled = hwc * 255
    return scaled.astype(np.uint8)
|
|
|
|
@torch.no_grad()
def save_sample_grid(
    model: TactileVAE,
    val_ds: TactileParquetDataset,
    device: torch.device,
    out_path: Path,
    n: int,
    rng_state: np.random.Generator,
) -> tuple[list[np.ndarray], list[np.ndarray]]:
    """Reconstruct up to `n` random val images and save a target/reconstruction grid.

    The PNG written to `out_path` has the targets in the top row and the
    matching reconstructions in the bottom row. Indices are drawn without
    replacement from `rng_state`, so at most `len(val_ds)` images are used;
    when that leaves nothing to draw (`n <= 0` or an empty dataset) no file is
    written and two empty lists are returned.

    The model is put in eval mode for the forward pass (`sample=False`) and
    always switched back to train mode afterwards, even if the pass raises.

    Returns:
        `(targets, reconstructions)` as lists of HWC uint8 arrays for wandb logging.
    """
    out_path.parent.mkdir(parents=True, exist_ok=True)

    # choice(..., replace=False) rejects size > population, and stacking zero
    # tensors fails, so clamp the request to what the dataset can provide.
    count = max(0, min(n, len(val_ds)))
    if count == 0:
        return [], []

    indices = rng_state.choice(len(val_ds), size=count, replace=False).tolist()
    # assumes val_ds[i] yields a single image tensor (return_meta off) — TODO confirm
    imgs = torch.stack([val_ds[i] for i in indices]).to(device, non_blocking=True)
    model.eval()
    try:
        recon = model(imgs, sample=False)["x_hat"]
    finally:
        model.train()

    targets = [_to_uint8_hwc(img) for img in imgs]
    recons = [_to_uint8_hwc(img) for img in recon]

    # Two rows (targets on top, reconstructions below), one column per sample.
    h = w = val_ds.image_size
    canvas = np.zeros((2 * h, count * w, 3), dtype=np.uint8)
    for col, (tgt, rec) in enumerate(zip(targets, recons)):
        canvas[:h, col * w:(col + 1) * w] = tgt
        canvas[h:, col * w:(col + 1) * w] = rec
    Image.fromarray(canvas).save(out_path)

    return targets, recons
|
|
|
|
def save_checkpoint(
    path: Path,
    *,
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scaler: torch.amp.GradScaler | None,
    step: int,
    epoch: int,
    config: dict,
    best_val_metric: float,
    best_metric_name: str,
) -> None:
    """Write a training checkpoint to `path`.

    The payload holds the model and optimizer state dicts, the step/epoch
    counters, the config, and the best validation metric with its name. The
    metric is stored a second time under `best_val_recon`, a key the resume
    code falls back to. The scaler state is included only when a scaler is
    given.

    The file is first written as `<path>.tmp` and then moved over `path` with
    `os.replace`, so `path` is never left half-written.
    """
    path.parent.mkdir(parents=True, exist_ok=True)

    payload: dict[str, Any] = dict(
        state_dict=model.state_dict(),
        optimizer=optimizer.state_dict(),
        step=step,
        epoch=epoch,
        config=config,
        best_val_metric=best_val_metric,
        best_metric_name=best_metric_name,
        best_val_recon=best_val_metric,
    )
    if scaler is not None:
        payload["scaler"] = scaler.state_dict()

    staging = path.with_suffix(path.suffix + ".tmp")
    torch.save(payload, staging)
    os.replace(staging, path)
|
|
|
|
def rotate_periodic_ckpts(out_dir: Path, keep: int) -> None:
    """Delete the oldest `ckpt_step_<N>.pt` files in `out_dir`, keeping the `keep` newest.

    Age is decided by the integer step in the file name rather than by string
    order, so the result stays correct once the step outgrows its zero-padded
    width. Files matching the glob whose suffix is not an integer are left
    untouched. A negative `keep` is treated as 0 (delete every periodic
    checkpoint).
    """
    prefix = "ckpt_step_"
    numbered: list[tuple[int, Path]] = []
    for ckpt in out_dir.glob(f"{prefix}*.pt"):
        try:
            step = int(ckpt.stem[len(prefix):])
        except ValueError:
            # Not a name this script would have produced; never delete it.
            continue
        numbered.append((step, ckpt))

    numbered.sort(key=lambda item: item[0])
    excess = max(0, len(numbered) - max(0, keep))
    for _, stale in numbered[:excess]:
        stale.unlink(missing_ok=True)
|
|
|
|
| |
| |
| |
|
|
def train(config: dict) -> None:
    """Train a TactileVAE end to end according to `config`.

    Steps, in order:
      1. Seed RNGs, pick the device, create `output_dir` and mirror stdout
         into `<output_dir>/run.log`.
      2. Optionally start a wandb run; write `config.snapshot.yaml` unless it
         already exists (so the first launch's snapshot survives a resume).
      3. Build datasets, data loaders, model, loss and optimizer.
      4. Restore model / optimizer / counters from `train.resume_from` if set.
      5. Iterate epochs and batches until `total_steps` is reached or a
         SIGTERM/SIGINT requests a stop; along the way log windowed training
         metrics, run validation, save sample grids and write checkpoints at
         their configured step intervals.
      6. On a normal exit, write `ckpt_last.pt` and finish the wandb run.

    Autocast only ever uses bfloat16 (any other `train.amp_dtype` is
    overridden), and no GradScaler is created, so `scaler` stays `None`.

    Args:
        config: Resolved config dict. Top-level keys read here are `seed`,
            `device`, `output_dir`, `run_id`, plus the sections `data`,
            `model`, `loss`, `optim`, `scheduler` and `train`.

    Raises:
        KeyError: if `train.best_metric` (or the name stored in the resumed
            checkpoint) is not among the validation metrics.
    """
    set_seed(config["seed"])
    device = pick_device(config["device"])
    out_dir = Path(config["output_dir"])
    out_dir.mkdir(parents=True, exist_ok=True)

    # From here on every print() also lands in run.log.
    sys.stdout = TeeLogger(out_dir / "run.log")
    print("== Tactile VAE training ==")
    print(f"run_id: {config['run_id']} device: {device}")
    print(f"output_dir: {out_dir}")
    if config["train"].get("resume_from"):
        print(f"resume_from: {config['train']['resume_from']}")

    wandb_run = init_wandb(config, out_dir)
    if wandb_run is not None:
        print(f"wandb: project={os.environ.get('WANDB_PROJECT')} "
              f"run_id={wandb_run.id} url={wandb_run.url}")
    else:
        print("wandb: disabled (set WANDB_PROJECT to enable)")

    # Keep the snapshot from the first launch; a resume must not overwrite it.
    snap = out_dir / "config.snapshot.yaml"
    if not snap.exists():
        with snap.open("w") as f:
            yaml.safe_dump(config, f, sort_keys=False)

    train_ds, val_ds = build_datasets(config["data"])
    print(f"datasets: train={len(train_ds):,} val={len(val_ds):,}")

    tcfg = config["train"]
    train_sampler = ParquetFileShuffleSampler(train_ds, seed=config["seed"])
    train_loader = DataLoader(
        train_ds,
        batch_size=tcfg["batch_size"],
        sampler=train_sampler,
        num_workers=tcfg["num_workers"],
        pin_memory=device.type == "cuda",
        drop_last=True,
        persistent_workers=tcfg["num_workers"] > 0,
        prefetch_factor=2 if tcfg["num_workers"] > 0 else None,
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=tcfg["batch_size"],
        shuffle=False,
        num_workers=max(2, tcfg["num_workers"] // 2),
        pin_memory=device.type == "cuda",
        drop_last=False,
    )
    steps_per_epoch = len(train_loader)
    # An explicit, non-zero max_steps wins over epochs * steps_per_epoch.
    total_steps = (
        tcfg["max_steps"]
        if tcfg.get("max_steps")
        else steps_per_epoch * tcfg["epochs"]
    )
    print(f"steps/epoch={steps_per_epoch:,} total_steps={total_steps:,}")

    model = build_model(config["model"]).to(device)
    criterion = build_loss(config["loss"]).to(device)
    optimizer = build_optimizer(model.parameters(), config["optim"])

    n_params = sum(p.numel() for p in model.parameters())
    print(f"model: {model.__class__.__name__} params={n_params:,}")

    use_amp = bool(tcfg.get("amp", False)) and device.type == "cuda"
    amp_dtype_cfg = str(tcfg.get("amp_dtype", "bf16")).lower()
    if amp_dtype_cfg not in {"bf16", "bfloat16"}:
        print(f"[info] overriding train.amp_dtype={amp_dtype_cfg!r} to 'bf16' (enforced)")
    amp_dtype = torch.bfloat16
    if not use_amp:
        amp_dtype = torch.float32

    # bf16 autocast is used without loss scaling, so there is no GradScaler.
    # The scaler-aware branches below are kept so one could be plugged in.
    scaler = None

    # Signal handlers only set a flag; the loop checks it between steps so a
    # stop request still ends with a consistent ckpt_last.pt.
    stop_requested = False

    def _request_stop(signum: int, _frame) -> None:
        nonlocal stop_requested
        stop_requested = True
        try:
            sig_name = signal.Signals(signum).name
        except ValueError:
            sig_name = str(signum)
        print(f"[signal] received {sig_name}; stopping after current step and saving ckpt_last.pt")

    prev_sigterm = signal.getsignal(signal.SIGTERM)
    prev_sigint = signal.getsignal(signal.SIGINT)
    signal.signal(signal.SIGTERM, _request_stop)
    signal.signal(signal.SIGINT, _request_stop)

    # Counters start fresh unless a checkpoint is restored.
    step = 0
    epoch_start = 0
    best_metric_name = str(tcfg.get("best_metric", "val/total"))
    best_val_metric = float("inf")
    if tcfg.get("resume_from"):
        ckpt = torch.load(tcfg["resume_from"], map_location=device)
        model.load_state_dict(ckpt["state_dict"])
        optimizer.load_state_dict(ckpt["optimizer"])
        if scaler is not None and "scaler" in ckpt:
            scaler.load_state_dict(ckpt["scaler"])
        step = int(ckpt.get("step", 0))
        epoch_start = int(ckpt.get("epoch", 0))
        # Older checkpoints only carry the legacy `best_val_recon` key.
        best_val_metric = float(
            ckpt.get("best_val_metric", ckpt.get("best_val_recon", float("inf")))
        )
        best_metric_name = str(ckpt.get("best_metric_name", best_metric_name))
        print(f"resumed from {tcfg['resume_from']} @ step={step} epoch={epoch_start}")

    # Name of the perceptual loss entry ("ssim" or "lpips"); also a CSV column.
    aux_metric_name = str(getattr(criterion, "aux_key", "ssim"))
    metric_keys = ("total", "recon", "recon_total", aux_metric_name, "kl")

    # metrics.csv is appended to across resumes; the header is written once.
    metrics_csv = out_dir / "metrics.csv"
    new_csv = not metrics_csv.exists()
    csv_fh = metrics_csv.open("a", newline="", buffering=1)
    csv_writer = csv.writer(csv_fh)
    if new_csv:
        csv_writer.writerow(
            ["step", "epoch", "lr", "split",
             "loss_total", "recon", "recon_total", aux_metric_name, "kl", "throughput"]
        )

    # Sliding windows (last `log_every` steps) behind the periodic log line.
    running = {k: deque(maxlen=tcfg["log_every"]) for k in metric_keys}
    grad_norm_running: deque[float] = deque(maxlen=tcfg["log_every"])
    sample_rng = np.random.default_rng(config["seed"] + 1)
    t_window = time.time()
    samples_in_window = 0

    base_lr = config["optim"]["lr"]
    model.train()

    done = False
    # Pre-bind `epoch`: when resuming at or past the last epoch the loop body
    # never runs, yet the final checkpoint below still needs a value.
    epoch = epoch_start
    try:
        for epoch in range(epoch_start, tcfg["epochs"]):
            train_sampler.set_epoch(epoch)
            epoch_accs = {k: MetricAccum() for k in metric_keys}
            epoch_samples = 0
            for batch in train_loader:
                if step >= total_steps or stop_requested:
                    done = True
                    break

                lr = lr_at_step(step, base_lr, total_steps, config["scheduler"])
                for g in optimizer.param_groups:
                    g["lr"] = lr

                x = batch.to(device, non_blocking=True)
                optimizer.zero_grad(set_to_none=True)

                with torch.amp.autocast(device.type, dtype=amp_dtype, enabled=use_amp):
                    out = model(x)
                    losses = criterion(out["x_hat"], x, out["mu"], out["logvar"])
                    loss = losses["total"]

                # A non-finite loss would poison the weights; drop the batch
                # without counting it as a step.
                if not torch.isfinite(loss).item():
                    print(
                        f"[warn] non-finite loss at step={step + 1}, epoch={epoch}; "
                        "skipping optimizer step"
                    )
                    optimizer.zero_grad(set_to_none=True)
                    continue

                grad_norm_value: float | None = None
                if scaler is not None:
                    scaler.scale(loss).backward()
                    # Unscale first so clipping sees true gradient magnitudes.
                    scaler.unscale_(optimizer)
                else:
                    loss.backward()
                if tcfg.get("gradient_clip_norm"):
                    gn = torch.nn.utils.clip_grad_norm_(
                        model.parameters(),
                        tcfg["gradient_clip_norm"],
                    )
                    grad_norm_value = float(gn.item())
                if scaler is not None:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()

                bs = x.shape[0]
                for k, dq in running.items():
                    dq.append(losses[k].item())
                for k, acc in epoch_accs.items():
                    acc.add(losses[k].item(), bs)
                if grad_norm_value is not None and math.isfinite(grad_norm_value):
                    grad_norm_running.append(grad_norm_value)
                samples_in_window += bs
                epoch_samples += bs
                step += 1

                # ---- periodic training log (also on the very first step) ----
                if step % tcfg["log_every"] == 0 or step == 1:
                    now = time.time()
                    throughput = samples_in_window / max(1e-6, now - t_window)
                    means = {k: sum(dq) / len(dq) for k, dq in running.items()}
                    gclip_mean = (
                        sum(grad_norm_running) / len(grad_norm_running)
                        if grad_norm_running
                        else float("nan")
                    )
                    print(
                        f"step {step:>7} | ep {epoch:>3} | lr {lr:.2e} | "
                        f"total {means['total']:.4f} recon {means['recon']:.4f} "
                        f"{aux_metric_name} {means[aux_metric_name]:.4f} kl {means['kl']:.4f} "
                        f"gclip {gclip_mean:.3f} | "
                        f"{throughput:.0f} img/s"
                    )
                    csv_writer.writerow([
                        step, epoch, f"{lr:.6g}", "train",
                        f"{means['total']:.6g}", f"{means['recon']:.6g}",
                        f"{means['recon_total']:.6g}", f"{means[aux_metric_name]:.6g}",
                        f"{means['kl']:.6g}", f"{throughput:.1f}",
                    ])
                    if wandb_run is not None:
                        wandb_run.log({
                            "train/total": means["total"],
                            "train/recon": means["recon"],
                            "train/recon_total": means["recon_total"],
                            f"train/{aux_metric_name}": means[aux_metric_name],
                            "train/kl": means["kl"],
                            "train/throughput_img_per_s": throughput,
                            "train/lr": lr,
                            "epoch": epoch,
                        }, step=step)
                    t_window = now
                    samples_in_window = 0

                # ---- periodic validation + best-checkpoint tracking ----
                if step % tcfg["val_every_steps"] == 0:
                    vmetrics = run_validation(
                        model, criterion, val_loader, device,
                        max_batches=tcfg["num_val_batches"],
                    )
                    print(
                        f" [val @ step {step}] "
                        + " ".join(f"{k.split('/')[-1]}={v:.4f}" for k, v in vmetrics.items())
                    )
                    csv_writer.writerow([
                        step, epoch, f"{lr:.6g}", "val",
                        f"{vmetrics['val/total']:.6g}", f"{vmetrics['val/recon']:.6g}",
                        f"{vmetrics['val/recon_total']:.6g}",
                        f"{vmetrics[f'val/{aux_metric_name}']:.6g}",
                        f"{vmetrics['val/kl']:.6g}", "",
                    ])
                    if wandb_run is not None:
                        wandb_run.log(vmetrics, step=step)
                    if best_metric_name not in vmetrics:
                        raise KeyError(
                            f"train.best_metric={best_metric_name!r} not found in validation metrics "
                            f"{sorted(vmetrics.keys())}"
                        )
                    # Lower is better for the monitored metric.
                    if vmetrics[best_metric_name] < best_val_metric:
                        best_val_metric = vmetrics[best_metric_name]
                        save_checkpoint(
                            out_dir / "ckpt_best.pt",
                            model=model, optimizer=optimizer, scaler=scaler,
                            step=step, epoch=epoch, config=config,
                            best_val_metric=best_val_metric,
                            best_metric_name=best_metric_name,
                        )
                        print(
                            f" -> new best {best_metric_name}={best_val_metric:.4f}, "
                            "saved ckpt_best.pt"
                        )

                # ---- periodic reconstruction samples ----
                if step % tcfg["sample_every_steps"] == 0:
                    sample_path = out_dir / "samples" / f"step_{step:07d}.png"
                    targets, recons = save_sample_grid(
                        model, val_ds, device,
                        out_path=sample_path,
                        n=tcfg["num_sample_images"],
                        rng_state=sample_rng,
                    )
                    if wandb_run is not None:
                        wandb_run.log({
                            "samples/target": [
                                wandb.Image(tgt, caption=f"sample {i}")
                                for i, tgt in enumerate(targets)
                            ],
                            "samples/reconstruction": [
                                wandb.Image(rec, caption=f"sample {i}")
                                for i, rec in enumerate(recons)
                            ],
                        }, step=step)

                # ---- periodic checkpoints (numbered + ckpt_last), with rotation ----
                if step % tcfg["ckpt_every_steps"] == 0:
                    save_checkpoint(
                        out_dir / f"ckpt_step_{step:07d}.pt",
                        model=model, optimizer=optimizer, scaler=scaler,
                        step=step, epoch=epoch, config=config,
                        best_val_metric=best_val_metric,
                        best_metric_name=best_metric_name,
                    )
                    save_checkpoint(
                        out_dir / "ckpt_last.pt",
                        model=model, optimizer=optimizer, scaler=scaler,
                        step=step, epoch=epoch, config=config,
                        best_val_metric=best_val_metric,
                        best_metric_name=best_metric_name,
                    )
                    rotate_periodic_ckpts(out_dir, tcfg["keep_last_ckpts"])
                    print(f" saved ckpt_step_{step:07d}.pt")

            # ---- end-of-epoch summary (skipped if no step ran this epoch) ----
            if epoch_samples > 0:
                epoch_means = {k: acc.mean() for k, acc in epoch_accs.items()}
                print(
                    f"[epoch {epoch} end @ step {step}] "
                    + " ".join(f"{k}={v:.4f}" for k, v in epoch_means.items())
                )
                csv_writer.writerow([
                    step, epoch, f"{lr:.6g}", "train_epoch",
                    f"{epoch_means['total']:.6g}", f"{epoch_means['recon']:.6g}",
                    f"{epoch_means['recon_total']:.6g}", f"{epoch_means[aux_metric_name]:.6g}",
                    f"{epoch_means['kl']:.6g}", "",
                ])
                if wandb_run is not None:
                    wandb_run.log(
                        {f"epoch_train/{k}": v for k, v in epoch_means.items()} | {"epoch": epoch},
                        step=step,
                    )

            if done:
                break
    finally:
        signal.signal(signal.SIGTERM, prev_sigterm)
        signal.signal(signal.SIGINT, prev_sigint)
        # Close the metrics file on every exit path, including exceptions.
        csv_fh.close()

    save_checkpoint(
        out_dir / "ckpt_last.pt",
        model=model, optimizer=optimizer, scaler=scaler,
        step=step, epoch=epoch, config=config,
        best_val_metric=best_val_metric,
        best_metric_name=best_metric_name,
    )
    if wandb_run is not None:
        wandb_run.summary["best_val_metric"] = best_val_metric
        wandb_run.summary["best_metric_name"] = best_metric_name
        wandb_run.summary["final_step"] = step
        wandb_run.finish()
    print(f"done. step={step} best_{best_metric_name}={best_val_metric:.4f}")
|
|
|
|
| |
|
|
def parse_args() -> argparse.Namespace:
    """Parse the training script's command line.

    `--config` defaults to `config/train_vae.yaml` one directory above this
    file's directory. `--run-id`, `--output-dir` and `--resume-from` are
    optional string overrides (default None); `--no-resume` is a flag.
    """
    default_config = Path(__file__).resolve().parents[1] / "config" / "train_vae.yaml"

    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=Path, default=default_config)

    string_overrides = (
        ("--run-id",
         "override run_id from the config (output_dir becomes runs_root/<run-id>)"),
        ("--output-dir",
         "override output_dir directly (bypasses runs_root/run_id derivation)"),
        ("--resume-from",
         "resume from a specific checkpoint path (overrides auto-resume)"),
    )
    for flag, help_text in string_overrides:
        parser.add_argument(flag, type=str, default=None, help=help_text)

    parser.add_argument(
        "--no-resume",
        action="store_true",
        help="start fresh even if <output_dir>/ckpt_last.pt exists",
    )
    return parser.parse_args()
|
|
|
|
def main() -> int:
    """CLI entry point: load the config, apply CLI overrides, then train.

    `--run-id` / `--output-dir` are applied to the raw YAML *before*
    `load_config` runs, so the derived `output_dir` reflects them. The
    overridden config is round-tripped through a temporary YAML file because
    `load_config` takes a path. That file lives in the system temp directory,
    not next to the config, so read-only config directories work and nothing
    stray is left beside the config. `load_config` resolves relative paths
    against the repo root, so the temp file's location does not affect the
    result.

    `--resume-from` pins the checkpoint to resume from; otherwise
    `ckpt_last.pt` in the output directory is used when present, unless
    `--no-resume` is given.

    Returns:
        0 on success (used as the process exit code).
    """
    import tempfile

    args = parse_args()

    with args.config.open() as f:
        raw_cfg = yaml.safe_load(f)
    if args.run_id:
        raw_cfg["run_id"] = args.run_id
    if args.output_dir:
        raw_cfg["output_dir"] = args.output_dir

    fd, tmp_name = tempfile.mkstemp(prefix="train_vae_cli_override_", suffix=".yaml")
    tmp = Path(tmp_name)
    try:
        # os.fdopen takes ownership of the descriptor returned by mkstemp.
        with os.fdopen(fd, "w") as f:
            yaml.safe_dump(raw_cfg, f, sort_keys=False)
        cfg = load_config(tmp)
    finally:
        tmp.unlink(missing_ok=True)

    if args.resume_from:
        cfg["train"]["resume_from"] = str(_resolve_path(args.resume_from))
    _maybe_autoresume(cfg, allow_autoresume=not args.no_resume)

    train(cfg)
    return 0
|
|
|
|
# Script entry point: run main() and use its return value as the exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|
|