"""PIVOT training loop (Algorithm 1).

For each minibatch of matched (c0, c1) pairs with perturbation u:
embed the perturbations -> sample (s, t) -> build interpolants -> L_map;
sample τ -> L_tan; sample (s, r, t) -> L_semi; add L_reg; backpropagate
through (θ, η).
"""
from __future__ import annotations
import os
import time
from dataclasses import dataclass, field, asdict
import numpy as np
import torch
from src.data.perturb_data import PerturbData
from src.data.splits import load_split
from src.models.pivot import PIVOT
from src.training.losses import compute_losses
from src.utils.common import pick_device, save_json, set_seed
@dataclass
class TrainConfig:
    """Hyperparameters for a single PIVOT training run.

    Fields are declared in a fixed order, which also defines the positional
    order accepted by the generated ``__init__``.
    """

    # data / representation
    dataset: str = "norman"  # loaded from data/processed/<dataset> when no data object is passed
    embedding: str = "pca"
    split: str = "perturbation"  # split name handed to load_split
    rep_mode: str = "gene_op"  # "gene_pathway_op" additionally uses functional gene clusters
    match: str = "batch"  # control-matching strategy
    # model size
    d_pert: int = 64
    hidden: int = 512
    depth: int = 4
    dropout: float = 0.0
    # optimisation
    lr: float = 1e-3
    weight_decay: float = 1e-5
    batch_size: int = 1024
    epochs: int = 60
    # loss weights (the map term itself is weighted 1.0)
    lam_tan: float = 1.0
    lam_semi: float = 0.5
    lam_reg: float = 1e-4
    grad_clip: float = 5.0  # max global gradient norm
    train_frac: float = 1.0  # data-scaling ablation: fraction of perturbed train cells kept
    # optional distributional flow loss (population MMD); 0.0 disables it
    lam_dist: float = 0.0
    n_dist_perts: int = 4  # perturbations sampled per step for the dist loss
    dist_n: int = 64  # cells per population in the dist loss
    # misc
    seed: int = 0
    device_index: int | None = None  # forwarded to pick_device
    # loss terms kept active; a fresh list is produced for every instance
    components: list = field(default_factory=["map", "tan", "semi"].copy)
def _ablate_lambdas(cfg: TrainConfig) -> dict:
lam = {"map": 1.0, "tan": cfg.lam_tan, "semi": cfg.lam_semi, "reg": cfg.lam_reg}
if "map" not in cfg.components:
lam["map"] = 0.0
if "tan" not in cfg.components:
lam["tan"] = 0.0
if "semi" not in cfg.components:
lam["semi"] = 0.0
return lam
def make_model(data: PerturbData, cfg: TrainConfig, device) -> PIVOT:
    """Construct a PIVOT model sized to ``data`` and move it to ``device``.

    When ``cfg.rep_mode == "gene_pathway_op"``, every vocabulary gene is mapped
    to its id in ``data.functional_clusters(seed=cfg.seed)`` (genes missing from
    that mapping fall back to id 0), and the pathway count is the largest id + 1.
    Other modes pass no gene->pathway table.
    """
    if cfg.rep_mode == "gene_pathway_op":
        clusters = data.functional_clusters(seed=cfg.seed)
        gene_pathway = np.array(
            [clusters.get(gene, 0) for gene in data.genes_vocab], dtype=np.int64
        )
        n_pathways = int(gene_pathway.max()) + 1
    else:
        gene_pathway, n_pathways = None, 0

    model = PIVOT(
        d_state=data.d,
        n_genes=len(data.genes_vocab),
        n_ops=len(data.op_vocab),
        n_perts=len(data.perturbations),
        d_pert=cfg.d_pert,
        hidden=cfg.hidden,
        depth=cfg.depth,
        rep_mode=cfg.rep_mode,
        gene_pathway=gene_pathway,
        n_pathways=n_pathways,
        dropout=cfg.dropout,
    )
    return model.to(device)
def _mmd2_torch(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Biased (V-statistic) squared MMD between samples ``x`` and ``y``, RBF kernel.

    The kernel bandwidth uses the median heuristic on at most the first 128 rows
    of each sample. It is computed without gradient tracking, so gradients flow
    only through the kernel evaluations.
    """
    with torch.no_grad():
        d = torch.cdist(x[:128], y[:128])
        gamma = 1.0 / (d.median() ** 2 + 1e-8)

    def k(a, b):
        return torch.exp(-gamma * torch.cdist(a, b) ** 2)

    return k(x, x).mean() + k(y, y).mean() - 2 * k(x, y).mean()


def train(cfg: TrainConfig, data: PerturbData | None = None, verbose: bool = True, log_every: int = 10):
    """Train a PIVOT model as configured by ``cfg`` (Algorithm 1).

    Args:
        cfg: run configuration.
        data: preloaded dataset. When ``None`` it is loaded from
            ``data/processed/<cfg.dataset>`` with ``cfg.embedding``.
        verbose: print a run summary and per-epoch loss lines.
        log_every: print epoch losses every ``log_every`` epochs. The final epoch
            is always printed, and ``0`` prints only the final epoch.

    Returns:
        ``(model, info)``. ``info`` contains ``history`` (per-epoch mean loss
        components), ``duration_s``, ``device`` (as a string) and
        ``n_train_cells``.

    Raises:
        ValueError: if the training split contains no perturbed cells.
    """
    import warnings

    set_seed(cfg.seed)
    device = pick_device(cfg.device_index)
    if data is None:
        data = PerturbData(os.path.join("data/processed", cfg.dataset), embedding=cfg.embedding)
    split = load_split(data.dir, cfg.split)
    rng = np.random.default_rng(cfg.seed)
    emb = torch.as_tensor(data.emb, device=device)
    labels = data.obs["perturbation"].values

    # Perturbed training cells are the targets; controls are used only as the
    # matching pool for the source side c0.
    train_idx = split["train_idx"]
    is_ctrl = data.is_control[train_idx]
    pert_train = train_idx[~is_ctrl]
    if len(pert_train) == 0:
        # Fail early with a clear message instead of an IndexError after epoch 0.
        raise ValueError(f"split {cfg.split!r} contains no perturbed training cells")
    if cfg.train_frac < 1.0:
        n = max(1, int(cfg.train_frac * len(pert_train)))
        pert_train = rng.choice(pert_train, size=n, replace=False)
    if verbose:
        print(f"train cells: {len(pert_train)} perturbed | device {device} | "
              f"split {cfg.split} | match {cfg.match} | rep {cfg.rep_mode}")

    model = make_model(data, cfg, device)
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=cfg.epochs)
    lam = _ablate_lambdas(cfg)
    from src.models.encoders import build_pert_tensors

    # Per-perturbation pools of training cells for the distributional loss.
    # Perturbations with fewer than 8 training cells are excluded.
    use_dist = cfg.lam_dist > 0
    dist_cells: dict = {}
    dist_pert_list: list = []
    n_dist_pick = 0
    if use_dist:
        tr_labels = labels[pert_train]
        dist_cells = {p: pert_train[tr_labels == p] for p in np.unique(tr_labels)}
        dist_cells = {p: v for p, v in dist_cells.items() if len(v) >= 8}
        dist_pert_list = list(dist_cells.keys())
        n_dist_pick = min(cfg.n_dist_perts, len(dist_pert_list))
        if n_dist_pick <= 0:
            # With nothing to sample, torch.stack([]) would crash mid-training.
            warnings.warn("lam_dist > 0 but no perturbation is eligible for the "
                          "distributional loss (needs >= 8 training cells and "
                          "n_dist_perts > 0); disabling it.")
            use_dist = False

    history = []
    t0 = time.time()
    for epoch in range(cfg.epochs):
        model.train()
        perm = rng.permutation(pert_train)
        ep_losses = []
        for b in range(0, len(perm), cfg.batch_size):
            batch_idx = perm[b: b + cfg.batch_size]
            ctrl_idx = data.sample_controls(batch_idx, cfg.match, rng)
            c0 = emb[ctrl_idx]
            c1 = emb[batch_idx]
            blabels = labels[batch_idx]
            g, o, mask, pid = build_pert_tensors(data, list(blabels), device=device)
            e = model.encode(g, o, mask, pid)
            total, comp = compute_losses(model.flow, e, c0, c1, lam)
            if use_dist:
                # Population-level term: push random controls through the flow
                # for a few perturbations and match the observed populations.
                terms = []
                for p in rng.choice(dist_pert_list, size=n_dist_pick, replace=False):
                    pc = dist_cells[p]
                    c1d = emb[rng.choice(pc, min(cfg.dist_n, len(pc)), replace=False)]
                    c0d = emb[rng.choice(data.control_idx, c1d.shape[0], replace=True)]
                    gd, od, md, pidd = build_pert_tensors(data, [p], device=device)
                    ed = model.encode(gd, od, md, pidd).expand(c0d.shape[0], -1)
                    chat = model.flow.endpoint(c0d, ed)
                    terms.append(_mmd2_torch(chat, c1d))
                l_dist = torch.stack(terms).mean()
                total = total + cfg.lam_dist * l_dist
                comp["dist"] = l_dist.item()
            opt.zero_grad()
            total.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
            opt.step()
            ep_losses.append(comp)
        sched.step()
        m = {k: float(np.mean([d[k] for d in ep_losses])) for k in ep_losses[0]}
        history.append(m)
        periodic = log_every != 0 and epoch % log_every == 0
        if verbose and (periodic or epoch == cfg.epochs - 1):
            line = (f" ep {epoch:3d} total={m['total']:.4f} map={m['map']:.4f} "
                    f"tan={m['tan']:.4f} semi={m['semi']:.4f}")
            if "dist" in m:
                line += f" dist={m['dist']:.4f}"
            print(line)
    dur = time.time() - t0
    if verbose:
        print(f"trained in {dur:.1f}s")
    return model, {"history": history, "duration_s": dur, "device": str(device),
                   "n_train_cells": int(len(pert_train))}
def save_checkpoint(model, cfg: TrainConfig, info: dict, out_dir: str):
    """Write a run's artefacts into ``out_dir``, creating the directory if needed.

    Writes ``model.pt`` (the model's ``state_dict`` via ``torch.save``), then
    ``config.json`` (``cfg`` as a dict) and ``train_info.json`` (``info``), both
    via ``save_json``.
    """
    os.makedirs(out_dir, exist_ok=True)

    def target(filename):
        return os.path.join(out_dir, filename)

    torch.save(model.state_dict(), target("model.pt"))
    for payload, filename in ((asdict(cfg), "config.json"), (info, "train_info.json")):
        save_json(payload, target(filename))
if __name__ == "__main__":
    import argparse
    import dataclasses

    # Expose every TrainConfig field as a --kebab-case flag. Flags left unset
    # (value None) fall back to the dataclass defaults.
    parser = argparse.ArgumentParser()
    for fld in dataclasses.fields(TrainConfig):
        flag = "--" + fld.name.replace("_", "-")
        if fld.name == "components":
            parser.add_argument("--components", nargs="+", default=None)
        elif fld.name == "device_index" or fld.type == "int | None":
            parser.add_argument(flag, type=int, default=None)
        else:
            parser.add_argument(flag, type=type(fld.default), default=fld.default)
    parser.add_argument("--out", default="experiments/exp_001")

    cli_args = parser.parse_args()
    overrides = {
        name: value
        for name, value in vars(cli_args).items()
        if name != "out" and value is not None
    }
    run_cfg = TrainConfig(**overrides)
    trained_model, run_info = train(run_cfg)
    save_checkpoint(trained_model, run_cfg, run_info, cli_args.out)
    print("saved ->", cli_args.out)