| """pivot training loop, algorithm 1. |
| |
| per minibatch of matched (c0, c1) pairs with perturbation u: |
| embed perturbations -> sample (s,t) -> interpolants -> L_map; |
| sample τ -> L_tan; sample (s,r,t) -> L_semi; add L_reg; backprop on (θ, η). |
| """ |
| from __future__ import annotations |
|
|
| import os |
| import time |
| from dataclasses import dataclass, field, asdict |
|
|
| import numpy as np |
| import torch |
|
|
| from src.data.perturb_data import PerturbData |
| from src.data.splits import load_split |
| from src.models.pivot import PIVOT |
| from src.training.losses import compute_losses |
| from src.utils.common import pick_device, save_json, set_seed |
|
|
|
|
@dataclass
class TrainConfig:
    """Hyper-parameters and run options for one PIVOT training run.

    Field order is part of the interface (positional construction and the
    generated command-line flags both follow it).
    """

    # --- data selection -------------------------------------------------
    dataset: str = "norman"        # sub-directory of data/processed to load
    embedding: str = "pca"         # forwarded to PerturbData(embedding=...)
    split: str = "perturbation"    # split name handed to load_split
    rep_mode: str = "gene_op"      # perturbation representation; "gene_pathway_op" adds pathway ids
    match: str = "batch"           # control-matching mode passed to data.sample_controls

    # --- model size -----------------------------------------------------
    d_pert: int = 64
    hidden: int = 512
    depth: int = 4
    dropout: float = 0.0

    # --- optimisation ---------------------------------------------------
    lr: float = 1e-3               # AdamW learning rate
    weight_decay: float = 1e-5     # AdamW weight decay
    batch_size: int = 1024
    epochs: int = 60               # also T_max of the cosine LR schedule

    # --- loss weights ---------------------------------------------------
    lam_tan: float = 1.0
    lam_semi: float = 0.5
    lam_reg: float = 1e-4
    grad_clip: float = 5.0         # max gradient norm for clip_grad_norm_
    train_frac: float = 1.0        # < 1.0 sub-samples the perturbed training cells

    # --- optional distribution (MMD) term -------------------------------
    lam_dist: float = 0.0          # weight of the MMD term; 0 disables it
    n_dist_perts: int = 4          # perturbations sampled per batch for the MMD term
    dist_n: int = 64               # max cells drawn per sampled perturbation

    # --- run control ----------------------------------------------------
    seed: int = 0
    device_index: int | None = None  # handed to pick_device
    # Loss components left switched on; any of "map"/"tan"/"semi" that is
    # missing here gets a zero weight.
    components: list = field(default_factory=lambda: ["map", "tan", "semi"])
|
|
|
|
def _ablate_lambdas(cfg: TrainConfig) -> dict:
    """Return the loss weights for this run, zeroing ablated components.

    The "map" term has a base weight of 1.0; "tan" and "semi" take their
    weights from ``cfg``.  Any of those three that is absent from
    ``cfg.components`` is set to 0.0.  The "reg" weight is always
    ``cfg.lam_reg`` and cannot be ablated through ``components``.
    """
    base_weights = (
        ("map", 1.0),
        ("tan", cfg.lam_tan),
        ("semi", cfg.lam_semi),
    )
    weights = {
        name: (value if name in cfg.components else 0.0)
        for name, value in base_weights
    }
    weights["reg"] = cfg.lam_reg
    return weights
|
|
|
|
def make_model(data: PerturbData, cfg: TrainConfig, device) -> PIVOT:
    """Build a PIVOT model sized for ``data`` and move it to ``device``.

    When ``cfg.rep_mode`` is "gene_pathway_op", each gene in
    ``data.genes_vocab`` is assigned the cluster id returned by
    ``data.functional_clusters`` (genes without an entry fall back to id 0),
    and the number of pathways is the largest id plus one.  For every other
    representation mode no pathway table is passed.
    """
    pathway_ids = None
    pathway_count = 0
    if cfg.rep_mode == "gene_pathway_op":
        clusters = data.functional_clusters(seed=cfg.seed)
        pathway_ids = np.array(
            [clusters.get(gene, 0) for gene in data.genes_vocab],
            dtype=np.int64,
        )
        pathway_count = int(pathway_ids.max()) + 1

    net = PIVOT(
        d_state=data.d,
        n_genes=len(data.genes_vocab),
        n_ops=len(data.op_vocab),
        n_perts=len(data.perturbations),
        d_pert=cfg.d_pert,
        hidden=cfg.hidden,
        depth=cfg.depth,
        rep_mode=cfg.rep_mode,
        gene_pathway=pathway_ids,
        n_pathways=pathway_count,
        dropout=cfg.dropout,
    )
    return net.to(device)
|
|
|
|
def train(cfg: TrainConfig, data: PerturbData | None = None, verbose: bool = True, log_every: int = 10):
    """Fit a PIVOT model on matched (control, perturbed) cell pairs.

    Every epoch shuffles the perturbed training cells and, per minibatch,
    draws controls with ``data.sample_controls``, encodes the perturbation
    labels, evaluates ``compute_losses`` on the flow, optionally adds an MMD
    distribution term (``cfg.lam_dist > 0``) and takes one gradient-clipped
    AdamW step.  The cosine learning-rate schedule is stepped once per epoch.

    Args:
        cfg: Run configuration.
        data: Dataset to train on.  When ``None`` it is loaded from
            ``data/processed/<cfg.dataset>`` using ``cfg.embedding``.
        verbose: Print a setup summary, periodic epoch losses and the
            total duration.
        log_every: Print every ``log_every``-th epoch; the last epoch is
            always printed.  ``0`` prints the last epoch only.

    Returns:
        ``(model, info)``.  ``info`` contains ``history`` (one dict per epoch
        with each loss component averaged over the epoch's batches),
        ``duration_s``, ``device`` and ``n_train_cells``.

    Raises:
        ValueError: If ``cfg.epochs > 0`` and either the split holds no
            perturbed training cells, or ``cfg.lam_dist > 0`` while no
            perturbation has enough training cells for the MMD term.
    """
    # Function-scope import kept from the original code.
    # NOTE(review): presumably avoids an import cycle -- confirm before hoisting.
    from src.models.encoders import build_pert_tensors

    min_dist_cells = 8      # a perturbation needs this many cells to enter the MMD term
    bandwidth_rows = 128    # rows of each sample used for the median-distance bandwidth

    set_seed(cfg.seed)
    device = pick_device(cfg.device_index)
    if data is None:
        data = PerturbData(os.path.join("data/processed", cfg.dataset), embedding=cfg.embedding)
    split = load_split(data.dir, cfg.split)
    rng = np.random.default_rng(cfg.seed)

    emb = torch.as_tensor(data.emb, device=device)
    labels = data.obs["perturbation"].values

    # Only perturbed cells are iterated over; controls are sampled per batch.
    train_idx = split["train_idx"]
    is_ctrl = data.is_control[train_idx]
    pert_train = train_idx[~is_ctrl]
    if cfg.epochs > 0 and len(pert_train) == 0:
        # Previously surfaced as an IndexError when averaging an empty epoch.
        raise ValueError(
            f"split {cfg.split!r} contains no perturbed training cells; nothing to train on"
        )
    if cfg.train_frac < 1.0:
        n = max(1, int(cfg.train_frac * len(pert_train)))
        pert_train = rng.choice(pert_train, size=n, replace=False)
    if verbose:
        print(f"train cells: {len(pert_train)} perturbed | device {device} | "
              f"split {cfg.split} | match {cfg.match} | rep {cfg.rep_mode}")

    model = make_model(data, cfg, device)
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=cfg.epochs)
    lam = _ablate_lambdas(cfg)

    # Per-perturbation pools of training cells for the optional MMD term.
    dist_cells = None
    dist_pert_list = []
    if cfg.lam_dist > 0:
        tr_labels = labels[pert_train]
        dist_cells = {p: pert_train[tr_labels == p] for p in np.unique(tr_labels)}
        dist_cells = {p: v for p, v in dist_cells.items() if len(v) >= min_dist_cells}
        dist_pert_list = list(dist_cells.keys())
        if cfg.epochs > 0 and not dist_pert_list:
            # Previously failed inside the first batch with torch.stack([]).
            raise ValueError(
                f"lam_dist > 0 but no perturbation has at least {min_dist_cells} "
                "training cells for the distribution loss"
            )

    def _mmd2_torch(x, y):
        """Squared MMD between samples ``x`` and ``y`` with an RBF kernel.

        The kernel means are taken over the full kernel matrices (diagonals
        included).  The bandwidth comes from the median pairwise distance
        between the leading rows of each sample and is computed without
        gradient tracking, so gradients flow only through the kernel values.
        """
        with torch.no_grad():
            d = torch.cdist(x[: min(bandwidth_rows, len(x))], y[: min(bandwidth_rows, len(y))])
            gamma = 1.0 / (d.median() ** 2 + 1e-8)

        def rbf(a, b):
            return torch.exp(-gamma * torch.cdist(a, b) ** 2)

        return rbf(x, x).mean() + rbf(y, y).mean() - 2 * rbf(x, y).mean()

    history = []
    t0 = time.time()
    for epoch in range(cfg.epochs):
        model.train()
        perm = rng.permutation(pert_train)
        ep_losses = []
        for b in range(0, len(perm), cfg.batch_size):
            batch_idx = perm[b: b + cfg.batch_size]
            ctrl_idx = data.sample_controls(batch_idx, cfg.match, rng)
            c0 = emb[ctrl_idx]
            c1 = emb[batch_idx]
            blabels = labels[batch_idx]
            g, o, mask, pid = build_pert_tensors(data, list(blabels), device=device)
            e = model.encode(g, o, mask, pid)
            total, comp = compute_losses(model.flow, e, c0, c1, lam)
            if cfg.lam_dist > 0:
                # Compare predicted endpoints with observed cells for a few
                # randomly chosen perturbations.
                terms = []
                for p in rng.choice(dist_pert_list, size=min(cfg.n_dist_perts, len(dist_pert_list)),
                                    replace=False):
                    pc = dist_cells[p]
                    c1d = emb[rng.choice(pc, min(cfg.dist_n, len(pc)), replace=False)]
                    c0d = emb[rng.choice(data.control_idx, c1d.shape[0], replace=True)]
                    gd, od, md, pidd = build_pert_tensors(data, [p], device=device)
                    # One embedding for the perturbation, repeated for every control.
                    ed = model.encode(gd, od, md, pidd).expand(c0d.shape[0], -1)
                    chat = model.flow.endpoint(c0d, ed)
                    terms.append(_mmd2_torch(chat, c1d))
                Ldist = torch.stack(terms).mean()
                total = total + cfg.lam_dist * Ldist
                # NOTE(review): comp["total"] is left as reported by
                # compute_losses, so the logged total does not include this
                # term -- confirm that is intended.
                comp["dist"] = Ldist.item()
            opt.zero_grad()
            total.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
            opt.step()
            ep_losses.append(comp)
        sched.step()
        m = {k: float(np.mean([d[k] for d in ep_losses])) for k in ep_losses[0]}
        history.append(m)
        # log_every == 0 would divide by zero; treat it as "last epoch only".
        is_log_epoch = epoch == cfg.epochs - 1 or (log_every != 0 and epoch % log_every == 0)
        if verbose and is_log_epoch:
            print(f"  ep {epoch:3d} total={m['total']:.4f} map={m['map']:.4f} "
                  f"tan={m['tan']:.4f} semi={m['semi']:.4f}")
    dur = time.time() - t0
    if verbose:
        print(f"trained in {dur:.1f}s")
    return model, {"history": history, "duration_s": dur, "device": str(device),
                   "n_train_cells": int(len(pert_train))}
|
|
|
|
def save_checkpoint(model, cfg: TrainConfig, info: dict, out_dir: str):
    """Write a training run to ``out_dir``, creating the directory if needed.

    Three files are produced: ``model.pt`` (the model's state dict),
    ``config.json`` (the config as a plain dict) and ``train_info.json``
    (the ``info`` dict as given).
    """
    os.makedirs(out_dir, exist_ok=True)

    weights_path = os.path.join(out_dir, "model.pt")
    torch.save(model.state_dict(), weights_path)

    json_outputs = (
        ("config.json", asdict(cfg)),
        ("train_info.json", info),
    )
    for filename, payload in json_outputs:
        save_json(payload, os.path.join(out_dir, filename))
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: one flag per TrainConfig field (underscores
    # become dashes), plus --out for the checkpoint directory.
    import argparse
    import dataclasses

    parser = argparse.ArgumentParser()
    for spec in dataclasses.fields(TrainConfig):
        if spec.name == "components":
            # List-valued field; None means "keep the dataclass default".
            parser.add_argument("--components", nargs="+", default=None)
            continue
        flag = f"--{spec.name.replace('_','-')}"
        if spec.type == "int | None" or spec.name == "device_index":
            # Optional integer: omitted on the command line means None.
            parser.add_argument(flag, type=int, default=None)
            continue
        # Every remaining field is parsed with the type of its default value.
        parser.add_argument(flag, type=type(spec.default), default=spec.default)
    parser.add_argument("--out", default="experiments/exp_001")
    args = parser.parse_args()

    # Drop --out and unset (None) options so the dataclass defaults apply.
    overrides = {
        key: value
        for key, value in vars(args).items()
        if key != "out" and value is not None
    }
    cfg = TrainConfig(**overrides)
    model, info = train(cfg)
    save_checkpoint(model, cfg, info, args.out)
    print("saved ->", args.out)
|
|