# PIVOT/src/training/train.py
# bryan7264's picture
# pivot: code + trained checkpoints (norman, replogle k562)
# 3b4941f verified
# Raw
# History Blame
# 7.94 kB
"""pivot training loop, algorithm 1.
per minibatch of matched (c0, c1) pairs with perturbation u:
embed perturbations -> sample (s,t) -> interpolants -> L_map;
sample τ -> L_tan; sample (s,r,t) -> L_semi; add L_reg; backprop on (θ, η).
"""
from __future__ import annotations
import os
import time
from dataclasses import dataclass, field, asdict
import numpy as np
import torch
from src.data.perturb_data import PerturbData
from src.data.splits import load_split
from src.models.pivot import PIVOT
from src.training.losses import compute_losses
from src.utils.common import pick_device, save_json, set_seed
@dataclass
class TrainConfig:
    """Hyperparameters and run options for one PIVOT training run.

    Every field has a default, so ``TrainConfig()`` is a complete
    configuration. The command-line entry point at the bottom of this file
    exposes each field as a ``--flag`` (underscores become hyphens).
    """
    # --- data selection ---
    dataset: str = "norman"  # sub-directory of data/processed loaded when train() gets no PerturbData
    embedding: str = "pca"  # passed to PerturbData as `embedding=`
    split: str = "perturbation"  # split name handed to load_split
    rep_mode: str = "gene_op"  # perturbation representation mode; "gene_pathway_op" also builds a gene->pathway table
    match: str = "batch" # control-matching strategy
    # --- model size (forwarded to the PIVOT constructor) ---
    d_pert: int = 64
    hidden: int = 512
    depth: int = 4
    dropout: float = 0.0
    # --- optimisation ---
    lr: float = 1e-3  # AdamW learning rate
    weight_decay: float = 1e-5  # AdamW weight decay
    batch_size: int = 1024  # perturbed cells per minibatch
    epochs: int = 60  # number of epochs; also T_max of the cosine LR schedule
    # --- loss weights (L_map always has weight 1.0 unless ablated) ---
    lam_tan: float = 1.0  # weight of the "tan" loss term
    lam_semi: float = 0.5  # weight of the "semi" loss term
    lam_reg: float = 1e-4  # weight of the "reg" loss term (never ablated by `components`)
    grad_clip: float = 5.0  # max gradient norm passed to clip_grad_norm_
    train_frac: float = 1.0 # data-scaling ablation
    lam_dist: float = 0.0 # distributional flow loss (population mmd) weight
    n_dist_perts: int = 4 # perturbations sampled per step for the dist loss
    dist_n: int = 64 # cells per population in the dist loss
    # --- run control ---
    seed: int = 0  # seeds set_seed, the NumPy generator and functional_clusters
    device_index: int | None = None  # forwarded to pick_device
    components: list = field(default_factory=lambda: ["map", "tan", "semi"]) # ablate losses
def _ablate_lambdas(cfg: TrainConfig) -> dict:
    """Return the per-term loss weights, zeroing any term that is ablated.

    The "map", "tan" and "semi" terms keep their configured weight only when
    their name appears in ``cfg.components``; otherwise the weight is 0.0.
    The "reg" weight is always taken from ``cfg.lam_reg``.
    """
    enabled = cfg.components
    base = {"map": 1.0, "tan": cfg.lam_tan, "semi": cfg.lam_semi}
    weights = {
        term: (weight if term in enabled else 0.0)
        for term, weight in base.items()
    }
    weights["reg"] = cfg.lam_reg
    return weights
def make_model(data: PerturbData, cfg: TrainConfig, device) -> PIVOT:
    """Build a PIVOT model sized for ``data`` and move it to ``device``.

    In the "gene_pathway_op" representation mode a gene -> pathway id table
    is built from ``data.functional_clusters``; genes missing from that
    mapping are assigned id 0. In every other mode no table is passed and the
    pathway count is 0.
    """
    if cfg.rep_mode == "gene_pathway_op":
        clusters = data.functional_clusters(seed=cfg.seed)
        pathway_ids = [clusters.get(gene, 0) for gene in data.genes_vocab]
        gene_pathway = np.array(pathway_ids, dtype=np.int64)
        # ids are used as indices, so the pathway count is the largest id + 1
        n_pathways = int(gene_pathway.max()) + 1
    else:
        gene_pathway = None
        n_pathways = 0

    model = PIVOT(
        d_state=data.d,
        n_genes=len(data.genes_vocab),
        n_ops=len(data.op_vocab),
        n_perts=len(data.perturbations),
        d_pert=cfg.d_pert,
        hidden=cfg.hidden,
        depth=cfg.depth,
        rep_mode=cfg.rep_mode,
        gene_pathway=gene_pathway,
        n_pathways=n_pathways,
        dropout=cfg.dropout,
    )
    return model.to(device)
def _mmd2_torch(x, y):
    """Biased squared-MMD estimate between samples ``x`` and ``y`` (RBF kernel).

    The kernel bandwidth is derived from the median pairwise distance between
    at most the first 128 rows of each sample. It is computed under
    ``no_grad`` so it behaves as a constant during backpropagation, while the
    kernel evaluations themselves stay differentiable.
    """
    with torch.no_grad():
        d = torch.cdist(x[:128], y[:128])
        gamma = 1.0 / (d.median() ** 2 + 1e-8)  # epsilon guards a zero median

    def kernel(a, b):
        return torch.exp(-gamma * torch.cdist(a, b) ** 2)

    return kernel(x, x).mean() + kernel(y, y).mean() - 2 * kernel(x, y).mean()


def train(cfg: TrainConfig, data: PerturbData | None = None, verbose: bool = True, log_every: int = 10):
    """Train a PIVOT model and return it together with a summary of the run.

    Args:
        cfg: hyperparameters and run options.
        data: dataset to train on; when None it is loaded from
            ``data/processed/<cfg.dataset>`` with ``cfg.embedding``.
        verbose: print a header line, periodic epoch summaries and the total
            training time.
        log_every: epoch interval between summaries. The final epoch is
            always reported; a value <= 0 reports only the final epoch.

    Returns:
        ``(model, info)`` where ``info`` contains ``"history"`` (one dict per
        epoch with the mean of every loss component over its batches),
        ``"duration_s"``, ``"device"`` and ``"n_train_cells"``.

    Raises:
        ValueError: if the split has no perturbed training cells, or if the
            distributional loss is enabled (``cfg.lam_dist > 0``) but cannot
            be formed (no perturbation with enough cells, or
            ``cfg.n_dist_perts`` / ``cfg.dist_n`` below 1).
    """
    set_seed(cfg.seed)
    device = pick_device(cfg.device_index)
    if data is None:
        data = PerturbData(os.path.join("data/processed", cfg.dataset), embedding=cfg.embedding)
    split = load_split(data.dir, cfg.split)
    rng = np.random.default_rng(cfg.seed)
    # NOTE(review): the tensor keeps data.emb's dtype — confirm it matches the
    # model's parameter dtype.
    emb = torch.as_tensor(data.emb, device=device)
    labels = data.obs["perturbation"].values

    # perturbed training cells (controls used only as matching pool)
    train_idx = split["train_idx"]
    is_ctrl = data.is_control[train_idx]
    pert_train = train_idx[~is_ctrl]
    if len(pert_train) == 0:
        # Fail early and clearly; otherwise the epoch summary below would hit
        # an IndexError on an empty list of batch losses.
        raise ValueError(f"split {cfg.split!r} contains no perturbed training cells")
    if cfg.train_frac < 1.0:
        n = max(1, int(cfg.train_frac * len(pert_train)))
        pert_train = rng.choice(pert_train, size=n, replace=False)
    if verbose:
        print(f"train cells: {len(pert_train)} perturbed | device {device} | "
              f"split {cfg.split} | match {cfg.match} | rep {cfg.rep_mode}")

    model = make_model(data, cfg, device)
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=cfg.epochs)
    lam = _ablate_lambdas(cfg)
    # Imported locally as in the original layout — presumably to avoid an
    # import cycle; confirm before hoisting to module level.
    from src.models.encoders import build_pert_tensors

    # precompute per-perturbation train cell lists for the distributional loss
    use_dist = cfg.lam_dist > 0
    dist_cells = None
    dist_pert_list = []
    if use_dist:
        if cfg.n_dist_perts < 1 or cfg.dist_n < 1:
            raise ValueError("lam_dist > 0 requires n_dist_perts >= 1 and dist_n >= 1")
        min_cells = 8  # populations smaller than this are skipped for the MMD term
        tr_labels = labels[pert_train]
        dist_cells = {}
        for p in np.unique(tr_labels):
            cells = pert_train[tr_labels == p]
            if len(cells) >= min_cells:
                dist_cells[p] = cells
        dist_pert_list = list(dist_cells.keys())
        if not dist_pert_list:
            # Without this check the first batch would call torch.stack on an
            # empty list and fail with an opaque RuntimeError.
            raise ValueError(
                f"lam_dist > 0 but no perturbation has at least {min_cells} training cells"
            )

    history = []
    t0 = time.time()
    for epoch in range(cfg.epochs):
        model.train()
        perm = rng.permutation(pert_train)
        ep_losses = []
        for b in range(0, len(perm), cfg.batch_size):
            batch_idx = perm[b: b + cfg.batch_size]
            ctrl_idx = data.sample_controls(batch_idx, cfg.match, rng)
            c0 = emb[ctrl_idx]
            c1 = emb[batch_idx]
            blabels = labels[batch_idx]
            g, o, mask, pid = build_pert_tensors(data, list(blabels), device=device)
            e = model.encode(g, o, mask, pid)
            total, comp = compute_losses(model.flow, e, c0, c1, lam)

            if use_dist:
                terms = []
                n_pick = min(cfg.n_dist_perts, len(dist_pert_list))
                for p in rng.choice(dist_pert_list, size=n_pick, replace=False):
                    pc = dist_cells[p]
                    c1d = emb[rng.choice(pc, min(cfg.dist_n, len(pc)), replace=False)]
                    # NOTE(review): controls are drawn from data.control_idx,
                    # not from the split's training indices — confirm this
                    # cannot pull in held-out control cells.
                    c0d = emb[rng.choice(data.control_idx, c1d.shape[0], replace=True)]
                    gd, od, md, pidd = build_pert_tensors(data, [p], device=device)
                    # one embedding for the perturbation, shared by every control cell
                    ed = model.encode(gd, od, md, pidd).expand(c0d.shape[0], -1)
                    chat = model.flow.endpoint(c0d, ed)
                    terms.append(_mmd2_torch(chat, c1d))
                Ldist = torch.stack(terms).mean()
                total = total + cfg.lam_dist * Ldist
                comp["dist"] = Ldist.item()

            opt.zero_grad()
            total.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
            opt.step()
            ep_losses.append(comp)

        sched.step()
        m = {k: float(np.mean([d[k] for d in ep_losses])) for k in ep_losses[0]}
        history.append(m)

        is_last = epoch == cfg.epochs - 1
        # log_every <= 0 must not reach the modulo (ZeroDivisionError)
        on_interval = log_every > 0 and epoch % log_every == 0
        if verbose and (on_interval or is_last):
            line = (f" ep {epoch:3d} total={m['total']:.4f} map={m['map']:.4f} "
                    f"tan={m['tan']:.4f} semi={m['semi']:.4f}")
            if "dist" in m:
                line += f" dist={m['dist']:.4f}"
            print(line)

    dur = time.time() - t0
    if verbose:
        print(f"trained in {dur:.1f}s")
    return model, {"history": history, "duration_s": dur, "device": str(device),
                   "n_train_cells": int(len(pert_train))}
def save_checkpoint(model, cfg: TrainConfig, info: dict, out_dir: str):
    """Write the model weights, the config and the training info to ``out_dir``.

    The directory is created if needed. Three files are written:
    ``model.pt`` (the model's state dict), ``config.json`` (the config as a
    dict) and ``train_info.json`` (``info`` as given).
    """
    os.makedirs(out_dir, exist_ok=True)

    weights_path = os.path.join(out_dir, "model.pt")
    torch.save(model.state_dict(), weights_path)

    json_payloads = (
        (asdict(cfg), "config.json"),
        (info, "train_info.json"),
    )
    for payload, filename in json_payloads:
        save_json(payload, os.path.join(out_dir, filename))
if __name__ == "__main__":
    # Command-line entry point: one --flag per TrainConfig field (underscores
    # become hyphens) plus --out for the checkpoint directory.
    import argparse
    import dataclasses

    parser = argparse.ArgumentParser()
    for spec in dataclasses.fields(TrainConfig):
        if spec.name == "components":
            # list-valued; None means "keep the dataclass default"
            parser.add_argument("--components", nargs="+", default=None)
            continue
        flag = f"--{spec.name.replace('_','-')}"
        if spec.name == "device_index" or spec.type == "int | None":
            # optional int: its default (None) carries no usable type
            parser.add_argument(flag, type=int, default=None)
            continue
        # every remaining field is parsed with the type of its default value
        parser.add_argument(flag, type=type(spec.default), default=spec.default)
    parser.add_argument("--out", default="experiments/exp_001")
    cli = parser.parse_args()

    # Drop --out and unset (None) options so TrainConfig falls back to its defaults.
    overrides = {
        key: value
        for key, value in vars(cli).items()
        if key != "out" and value is not None
    }
    cfg = TrainConfig(**overrides)
    model, info = train(cfg)
    save_checkpoint(model, cfg, info, cli.out)
    print("saved ->", cli.out)