"""pivot training loop, algorithm 1.

per minibatch of matched (c0, c1) pairs with perturbation u:
embed perturbations -> sample (s,t) -> interpolants -> L_map;
sample τ -> L_tan; sample (s,r,t) -> L_semi; add L_reg;
backprop on (θ, η).
"""
from __future__ import annotations

import os
import time
from dataclasses import dataclass, field, asdict

import numpy as np
import torch

from src.data.perturb_data import PerturbData
from src.data.splits import load_split
from src.models.pivot import PIVOT
from src.training.losses import compute_losses
from src.utils.common import pick_device, save_json, set_seed


@dataclass
class TrainConfig:
    """Hyper-parameters and run options consumed by `train` and the CLI.

    Every field except `components` and `device_index` is exposed on the
    command line with a type taken from its default value (see `main`).
    """

    dataset: str = "norman"
    embedding: str = "pca"
    split: str = "perturbation"
    rep_mode: str = "gene_op"
    match: str = "batch"  # control-matching strategy
    d_pert: int = 64
    hidden: int = 512
    depth: int = 4
    dropout: float = 0.0
    lr: float = 1e-3
    weight_decay: float = 1e-5
    batch_size: int = 1024
    epochs: int = 60
    lam_tan: float = 1.0
    lam_semi: float = 0.5
    lam_reg: float = 1e-4
    grad_clip: float = 5.0
    train_frac: float = 1.0  # data-scaling ablation
    lam_dist: float = 0.0  # distributional flow loss (population mmd) weight
    n_dist_perts: int = 4  # perturbations sampled per step for the dist loss
    dist_n: int = 64  # cells per population in the dist loss
    seed: int = 0
    device_index: int | None = None
    components: list = field(default_factory=lambda: ["map", "tan", "semi"])  # ablate losses


def _ablate_lambdas(cfg: TrainConfig) -> dict:
    """Return the loss weights for this run, zeroing ablated components.

    The "map" weight is 1.0 and "tan"/"semi"/"reg" come from `cfg`. Any of
    "map", "tan", "semi" that is absent from `cfg.components` gets weight 0.0;
    "reg" is never zeroed here.
    """
    lam = {"map": 1.0, "tan": cfg.lam_tan, "semi": cfg.lam_semi, "reg": cfg.lam_reg}
    for name in ("map", "tan", "semi"):
        if name not in cfg.components:
            lam[name] = 0.0
    return lam


def make_model(data: PerturbData, cfg: TrainConfig, device) -> PIVOT:
    """Build a `PIVOT` model sized from `data` and `cfg` and move it to `device`.

    When `cfg.rep_mode == "gene_pathway_op"`, each gene in `data.genes_vocab`
    is assigned a cluster id from `data.functional_clusters(seed=cfg.seed)`
    (genes missing from that mapping fall back to id 0), and the number of
    pathways is the largest id plus one. Otherwise no pathway table is passed.
    """
    gene_pathway, n_path = None, 0
    if cfg.rep_mode == "gene_pathway_op":
        fc = data.functional_clusters(seed=cfg.seed)
        gp = np.array([fc.get(g, 0) for g in data.genes_vocab], dtype=np.int64)
        gene_pathway, n_path = gp, int(gp.max()) + 1
    return PIVOT(
        d_state=data.d,
        n_genes=len(data.genes_vocab),
        n_ops=len(data.op_vocab),
        n_perts=len(data.perturbations),
        d_pert=cfg.d_pert,
        hidden=cfg.hidden,
        depth=cfg.depth,
        rep_mode=cfg.rep_mode,
        gene_pathway=gene_pathway,
        n_pathways=n_path,
        dropout=cfg.dropout,
    ).to(device)


def _mmd2_torch(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Squared MMD between samples `x` and `y` under a Gaussian kernel.

    The kernel is exp(-gamma * ||a - b||^2) with
    gamma = 1 / (median pairwise distance ** 2 + 1e-8), where the median is
    taken over distances between the first (at most) 128 rows of `x` and `y`.
    The within-sample means include the diagonal kernel entries.
    """
    # The bandwidth is computed without gradient tracking, so autograd treats
    # it as a constant; only the kernel evaluations below carry gradients.
    with torch.no_grad():
        d = torch.cdist(x[:128], y[:128])
        gamma = 1.0 / (d.median() ** 2 + 1e-8)

    def kernel(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        return torch.exp(-gamma * torch.cdist(a, b) ** 2)

    return kernel(x, x).mean() + kernel(y, y).mean() - 2 * kernel(x, y).mean()


def _dist_loss(model, data, emb, dist_cells, dist_pert_list, cfg, rng, device,
               build_tensors) -> torch.Tensor:
    """Mean squared MMD between predicted and observed perturbed populations.

    Samples up to `cfg.n_dist_perts` perturbations (without replacement) from
    `dist_pert_list`. For each one it draws up to `cfg.dist_n` of its training
    cells, draws the same number of cells from `data.control_idx` (with
    replacement), pushes the controls through `model.flow.endpoint` using the
    perturbation's embedding, and compares the result with the observed cells
    via `_mmd2_torch`.

    `build_tensors` is `src.models.encoders.build_pert_tensors`; it is passed
    in because `train` imports it lazily.
    """
    terms = []
    n_sample = min(cfg.n_dist_perts, len(dist_pert_list))
    # NOTE: the order of rng draws (perturbations, then per perturbation:
    # perturbed cells, then controls) determines reproducibility for a seed.
    for p in rng.choice(dist_pert_list, size=n_sample, replace=False):
        pc = dist_cells[p]
        c1d = emb[rng.choice(pc, min(cfg.dist_n, len(pc)), replace=False)]
        c0d = emb[rng.choice(data.control_idx, c1d.shape[0], replace=True)]
        gd, od, md, pidd = build_tensors(data, [p], device=device)
        # One embedding for the single perturbation, broadcast over the cells.
        ed = model.encode(gd, od, md, pidd).expand(c0d.shape[0], -1)
        chat = model.flow.endpoint(c0d, ed)
        terms.append(_mmd2_torch(chat, c1d))
    return torch.stack(terms).mean()


def train(cfg: TrainConfig, data: PerturbData | None = None, verbose: bool = True,
          log_every: int = 10):
    """Train a PIVOT model and return it together with run information.

    Args:
        cfg: run configuration.
        data: dataset to train on; when None it is loaded from
            ``data/processed/<cfg.dataset>`` with `cfg.embedding`.
        verbose: print a header, periodic epoch summaries and the duration.
        log_every: print every `log_every`-th epoch (the final epoch is
            always printed). 0 prints only the final epoch.

    Returns:
        ``(model, info)`` where `info` has keys "history" (list of per-epoch
        dicts holding the mean of every loss component over the epoch's
        batches), "duration_s", "device" and "n_train_cells".

    Raises:
        ValueError: if the training split contains no perturbed cells, or if
            `cfg.lam_dist > 0` but `cfg.n_dist_perts < 1` or no perturbation
            has at least 8 training cells.
    """
    set_seed(cfg.seed)
    device = pick_device(cfg.device_index)
    if data is None:
        data = PerturbData(os.path.join("data/processed", cfg.dataset),
                           embedding=cfg.embedding)
    split = load_split(data.dir, cfg.split)
    rng = np.random.default_rng(cfg.seed)

    emb = torch.as_tensor(data.emb, device=device)
    labels = data.obs["perturbation"].values

    # perturbed training cells (controls used only as matching pool)
    train_idx = split["train_idx"]
    is_ctrl = data.is_control[train_idx]
    pert_train = train_idx[~is_ctrl]
    if len(pert_train) == 0:
        # Without this, no batch would run and the epoch summary would fail
        # with an IndexError on an empty loss list.
        raise ValueError(
            f"split {cfg.split!r} contains no perturbed training cells")
    if cfg.train_frac < 1.0:
        n = max(1, int(cfg.train_frac * len(pert_train)))
        pert_train = rng.choice(pert_train, size=n, replace=False)

    if verbose:
        print(f"train cells: {len(pert_train)} perturbed | device {device} | "
              f"split {cfg.split} | match {cfg.match} | rep {cfg.rep_mode}")

    model = make_model(data, cfg, device)
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=cfg.epochs)
    lam = _ablate_lambdas(cfg)

    # Imported lazily, as in the original code.
    # NOTE(review): presumably to avoid an import cycle — confirm.
    from src.models.encoders import build_pert_tensors

    # precompute per-perturbation train cell lists for the distributional loss
    use_dist = cfg.lam_dist > 0
    dist_cells = None
    dist_pert_list = []
    if use_dist:
        if cfg.n_dist_perts < 1:
            raise ValueError("n_dist_perts must be >= 1 when lam_dist > 0")
        tr_labels = labels[pert_train]
        dist_cells = {p: pert_train[tr_labels == p] for p in np.unique(tr_labels)}
        # Populations with fewer than 8 cells are excluded from the dist loss.
        dist_cells = {p: v for p, v in dist_cells.items() if len(v) >= 8}
        dist_pert_list = list(dist_cells.keys())
        if not dist_pert_list:
            # torch.stack on an empty list would otherwise fail on step one.
            raise ValueError(
                "lam_dist > 0 but no perturbation has at least 8 training "
                "cells; lower the threshold data-side or set lam_dist=0")

    history = []
    t0 = time.time()
    last_epoch = cfg.epochs - 1
    for epoch in range(cfg.epochs):
        model.train()
        perm = rng.permutation(pert_train)
        ep_losses = []
        for b in range(0, len(perm), cfg.batch_size):
            batch_idx = perm[b: b + cfg.batch_size]
            # presumably one matched control per perturbed cell — confirm
            # against PerturbData.sample_controls
            ctrl_idx = data.sample_controls(batch_idx, cfg.match, rng)
            c0 = emb[ctrl_idx]
            c1 = emb[batch_idx]
            blabels = labels[batch_idx]
            g, o, mask, pid = build_pert_tensors(data, list(blabels), device=device)
            e = model.encode(g, o, mask, pid)
            total, comp = compute_losses(model.flow, e, c0, c1, lam)

            if use_dist:
                Ldist = _dist_loss(model, data, emb, dist_cells, dist_pert_list,
                                   cfg, rng, device, build_pert_tensors)
                total = total + cfg.lam_dist * Ldist
                comp["dist"] = Ldist.item()

            opt.zero_grad()
            total.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
            opt.step()
            ep_losses.append(comp)
        sched.step()

        # Unweighted mean over batches of every component reported by the
        # first batch (all batches are assumed to report the same keys).
        m = {k: float(np.mean([d[k] for d in ep_losses])) for k in ep_losses[0]}
        history.append(m)

        # log_every == 0 would divide by zero; treat it as "final epoch only".
        periodic = log_every != 0 and epoch % log_every == 0
        if verbose and (periodic or epoch == last_epoch):
            line = (f" ep {epoch:3d} total={m['total']:.4f} map={m['map']:.4f} "
                    f"tan={m['tan']:.4f} semi={m['semi']:.4f}")
            if "dist" in m:
                line += f" dist={m['dist']:.4f}"
            print(line)

    dur = time.time() - t0
    if verbose:
        print(f"trained in {dur:.1f}s")
    return model, {"history": history, "duration_s": dur, "device": str(device),
                   "n_train_cells": int(len(pert_train))}


def save_checkpoint(model, cfg: TrainConfig, info: dict, out_dir: str):
    """Write ``model.pt``, ``config.json`` and ``train_info.json`` to `out_dir`.

    The directory is created if missing. The model is stored as its
    ``state_dict``; `cfg` is serialised via `dataclasses.asdict`.
    """
    os.makedirs(out_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(out_dir, "model.pt"))
    save_json(asdict(cfg), os.path.join(out_dir, "config.json"))
    save_json(info, os.path.join(out_dir, "train_info.json"))


def main(argv: list | None = None) -> None:
    """Command-line entry point: parse options, train, and save a checkpoint.

    One ``--flag`` is generated per `TrainConfig` field (underscores become
    dashes). Options left at None (``--components``, ``--device-index``) are
    not forwarded, so the dataclass defaults apply. `argv` defaults to
    ``sys.argv[1:]`` when None.
    """
    import argparse
    import dataclasses

    ap = argparse.ArgumentParser()
    for f in dataclasses.fields(TrainConfig):
        flag = f"--{f.name.replace('_','-')}"
        if f.name == "components":
            ap.add_argument("--components", nargs="+", default=None)
        elif f.type == "int | None" or f.name == "device_index":
            ap.add_argument(flag, type=int, default=None)
        else:
            # The CLI type is inferred from the field's default value.
            typ = type(f.default)
            ap.add_argument(flag, type=typ, default=f.default)
    ap.add_argument("--out", default="experiments/exp_001")
    args = ap.parse_args(argv)

    kw = {k: v for k, v in vars(args).items() if k != "out" and v is not None}
    cfg = TrainConfig(**kw)
    model, info = train(cfg)
    save_checkpoint(model, cfg, info, args.out)
    print("saved ->", args.out)


if __name__ == "__main__":
    main()