"""Training loop shared across all four models. Differences across runs are entirely in the model registered under `config.model`. """ from __future__ import annotations import json import math import os import time from dataclasses import dataclass, field from pathlib import Path import numpy as np import torch import yaml from torch.utils.data import DataLoader from .data import MIMICAlignedDataset, collate_with_dt, split_by_subject from .ema import ema_tau from .models import MODEL_REGISTRY, ModelConfig from .monitor import CollapseMonitor, cross_modal_cosine, effective_rank @dataclass class TrainConfig: run_name: str = "debug" model: str = "F" # one of A, B, C, F epochs: int = 100 batch_size: int = 64 lr: float = 1e-4 weight_decay: float = 0.04 warmup_epochs: int = 10 ema_start: float = 0.996 ema_end: float = 0.9999 ema_warmup_frac: float = 0.30 grad_clip: float = 1.0 log_every: int = 100 ckpt_every_epochs: int = 5 seed: int = 0 wandb_project: str = "physiojepa" wandb_mode: str = "online" wandb_entity: str | None = None output_dir: str = "runs" index_path: str = "cache/mimic_index.json" shard_roots: list[str] = field(default_factory=list) num_workers: int = 4 amp: bool = True # controls for Δt sampling inside collate_fn log_uniform_frac: float = 0.6 # window-level subsetting (for fast iteration / K2 gate runs) subset_frac: float = 1.0 # ablation knobs forwarded to ModelConfig pred_depth: int = 4 query_mode: str = "learned" mask_ratio: float = 0.50 # precomputed mmap dataset (overrides shard_roots + index_path if set) fast_cache_dir: str = "" def load_yaml_config(path: str) -> TrainConfig: with open(path, "r") as f: d = yaml.safe_load(f) return TrainConfig(**d) class _Collator: """Top-level callable so DataLoader workers can serialize it across fork.""" def __init__(self, log_uniform_frac: float, seed: int): self.log_uniform_frac = log_uniform_frac self.seed = seed self._rng = None def __call__(self, items): if self._rng is None: self._rng = np.random.default_rng(self.seed + os.getpid()) return collate_with_dt(items, log_uniform_frac=self.log_uniform_frac, rng=self._rng) def _build_dataloaders(cfg: TrainConfig) -> tuple[DataLoader, DataLoader, list[str]]: if cfg.fast_cache_dir: from .data_fast import MIMICFastDataset cache_dir = Path(cfg.fast_cache_dir) import json meta = json.loads((cache_dir / "windows_meta.json").read_text()) subjects = sorted(set(meta["subjects"])) train_subj, val_subj = split_by_subject(subjects, frac=0.9, seed=cfg.seed) train_ds = MIMICFastDataset(cache_dir, subjects_allow=train_subj) val_ds = MIMICFastDataset(cache_dir, subjects_allow=val_subj) else: shard_roots = [Path(p) for p in cfg.shard_roots] ds_full = MIMICAlignedDataset( shard_roots=shard_roots, index_path=Path(cfg.index_path), build_index=not Path(cfg.index_path).exists(), ) subjects = sorted({r["subject_id"] for r in ds_full.index}) train_subj, val_subj = split_by_subject(subjects, frac=0.9, seed=cfg.seed) train_ds = MIMICAlignedDataset( shard_roots, Path(cfg.index_path), build_index=False, subjects_allow=train_subj, subset_frac=cfg.subset_frac, subset_seed=cfg.seed, ) val_ds = MIMICAlignedDataset( shard_roots, Path(cfg.index_path), build_index=False, subjects_allow=val_subj, ) collate = _Collator(cfg.log_uniform_frac, cfg.seed) train_loader = DataLoader( train_ds, batch_size=cfg.batch_size, shuffle=True, num_workers=cfg.num_workers, collate_fn=collate, drop_last=True, persistent_workers=cfg.num_workers > 0, ) val_loader = DataLoader( val_ds, batch_size=cfg.batch_size, shuffle=False, num_workers=max(cfg.num_workers, 1), collate_fn=collate, drop_last=False, ) return train_loader, val_loader, subjects def _cosine_lr(step: int, total_steps: int, base: float, warmup_steps: int) -> float: if step < warmup_steps: return base * (step + 1) / max(1, warmup_steps) progress = (step - warmup_steps) / max(1, total_steps - warmup_steps) return 0.5 * base * (1 + math.cos(math.pi * progress)) def train(cfg: TrainConfig) -> dict: import wandb torch.manual_seed(cfg.seed) np.random.seed(cfg.seed) device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")) train_loader, val_loader, subjects = _build_dataloaders(cfg) print(f"[trainer] device={device} n_train_windows={len(train_loader.dataset)} " f"n_val_windows={len(val_loader.dataset)} subjects={len(subjects)}") mcfg = ModelConfig( pred_depth=cfg.pred_depth, query_mode=cfg.query_mode, mask_ratio=cfg.mask_ratio, ) model = MODEL_REGISTRY[cfg.model](mcfg).to(device) opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay) scaler = torch.amp.GradScaler(device.type) if cfg.amp and device.type == "cuda" else None total_steps = cfg.epochs * len(train_loader) warmup_steps = cfg.warmup_epochs * len(train_loader) wandb.init(project=cfg.wandb_project, name=cfg.run_name, config=cfg.__dict__, mode=cfg.wandb_mode, entity=cfg.wandb_entity) monitor = CollapseMonitor() step = 0 out_root = Path(cfg.output_dir) / cfg.run_name out_root.mkdir(parents=True, exist_ok=True) aborted = False for epoch in range(cfg.epochs): model.train(True) for batch in train_loader: # move to device for k in ("ecg", "ppg", "dt_seconds", "ptt_ms"): if k in batch and isinstance(batch[k], torch.Tensor): batch[k] = batch[k].to(device) # lr schedule lr_now = _cosine_lr(step, total_steps, cfg.lr, warmup_steps) for g in opt.param_groups: g["lr"] = lr_now opt.zero_grad(set_to_none=True) if scaler is not None: with torch.amp.autocast("cuda"): out = model.step(batch) scaler.scale(out["loss"]).backward() scaler.unscale_(opt) torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip) scaler.step(opt) scaler.update() else: out = model.step(batch) out["loss"].backward() torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip) opt.step() # EMA update tau = ema_tau(step, total_steps, cfg.ema_start, cfg.ema_end, cfg.ema_warmup_frac) for online, tgt in model.targets(): tgt.update(online, tau) if step % cfg.log_every == 0: metrics = { "step": step, "epoch": epoch, "lr": lr_now, "tau": tau, "loss": float(out["loss"].detach().item()), "L_cross": float(out.get("L_cross", torch.tensor(0.0)).item()), "L_self": float(out.get("L_self", torch.tensor(0.0)).item()), } z_e = out.get("z_ecg") if z_e is not None and z_e.shape[0] > 1: metrics["ecg_latent_var"] = float(z_e.var(dim=0).mean().item()) metrics["ecg_eff_rank"] = effective_rank(z_e) z_p_pred = out.get("z_pred") z_p_tgt = out.get("z_ppg") if z_p_pred is not None and z_p_tgt is not None and z_p_pred.shape[0] > 1: cosine = cross_modal_cosine(z_p_pred, z_p_tgt) metrics["cross_modal_cosine"] = cosine if monitor.update(cosine): print(f"[trainer] COLLAPSE DETECTED at step={step} cosine={cosine:.4f}") aborted = True wandb.log(metrics, step=step) print(f"[step {step}] loss={metrics['loss']:.4f} " f"L_cross={metrics['L_cross']:.4f} L_self={metrics['L_self']:.4f} " f"tau={tau:.4f}") step += 1 if aborted: break if aborted: break if (epoch + 1) % cfg.ckpt_every_epochs == 0 or epoch == cfg.epochs - 1: ckpt = out_root / f"ckpt_epoch{epoch + 1:03d}.pt" torch.save({"model": model.state_dict(), "cfg": cfg.__dict__, "epoch": epoch + 1, "step": step}, ckpt) print(f"[trainer] saved {ckpt}") final_ckpt = out_root / "ckpt_final.pt" torch.save({"model": model.state_dict(), "cfg": cfg.__dict__, "aborted": aborted, "step": step}, final_ckpt) wandb.finish() return {"aborted": aborted, "final_step": step, "ckpt": str(final_ckpt)}