"""baseline forward models. all baselines are effect predictors in cell-state embedding space: predict_effect(label) -> delta, and predict_endpoint(c0) = c0 + delta. this unifies forward eval and the inverse "+ranking" wrappers (rank candidates by the reward of their predicted endpoint). implemented: Random, MeanControl, GlobalAverageEffect, Additive, LinearRidge, NearestCentroid, EndpointMLP. heavy/foundation baselines are out of scope and handled as n/r rows by the runner, never faked. """ from __future__ import annotations import numpy as np from src.data.perturb_data import PerturbData def training_effects(data: PerturbData, train_perts, train_idx) -> dict[str, np.ndarray]: """embedding-space effect (pert mean - control mean) per training perturbation.""" train_set = set(train_idx.tolist()) ctrl = np.array([i for i in data.control_idx if i in train_set]) cmean = data.emb[ctrl].mean(0) if len(ctrl) else data.emb[data.control_idx].mean(0) eff = {} for p in train_perts: idx = np.array([i for i in data.pert_to_idx[p] if i in train_set]) if len(idx) == 0: idx = data.pert_to_idx[p] eff[p] = data.emb[idx].mean(0) - cmean return eff, cmean class EffectPredictor: name = "base" def fit(self, data, train_perts, train_idx): self.data = data self.eff, self.cmean = training_effects(data, train_perts, train_idx) self.train_perts = list(train_perts) self.d = data.d self._fit_extra() return self def _fit_extra(self): pass def predict_effect(self, label) -> np.ndarray: raise NotImplementedError def predict_endpoint(self, label, c0: np.ndarray) -> np.ndarray: return c0 + self.predict_effect(label)[None, :] class Random(EffectPredictor): name = "Random" def _fit_extra(self): self._rng = np.random.default_rng(0) self._effs = np.stack(list(self.eff.values())) if self.eff else np.zeros((1, self.d)) def predict_effect(self, label): return self._effs[self._rng.integers(len(self._effs))] class MeanControl(EffectPredictor): name = "MeanControl" def predict_effect(self, label): return np.zeros(self.d, dtype=np.float32) class GlobalAverageEffect(EffectPredictor): name = "AvgPerturbationEffect" def _fit_extra(self): self._mean = np.mean(list(self.eff.values()), axis=0) if self.eff else np.zeros(self.d) def predict_effect(self, label): return self._mean class Additive(EffectPredictor): name = "Additive" def _fit_extra(self): # single-gene effects from training singles self._single = {} for p, e in self.eff.items(): g = self.data.parse(p) if len(g) == 1: self._single[g[0]] = e self._fallback = np.mean(list(self.eff.values()), axis=0) if self.eff else np.zeros(self.d) def predict_effect(self, label): genes = self.data.parse(label) parts = [self._single[g] for g in genes if g in self._single] return np.sum(parts, axis=0) if parts else self._fallback class LinearRidge(EffectPredictor): name = "LinearResponse" def _fit_extra(self): from sklearn.linear_model import Ridge genes = self.data.genes_vocab gid = {g: i for i, g in enumerate(genes)} X = np.zeros((len(self.train_perts), len(genes)), dtype=np.float32) Y = np.zeros((len(self.train_perts), self.d), dtype=np.float32) for r, p in enumerate(self.train_perts): for g in self.data.parse(p): if g in gid: X[r, gid[g]] = 1.0 Y[r] = self.eff[p] self._gid = gid self._model = Ridge(alpha=1.0).fit(X, Y) def predict_effect(self, label): x = np.zeros((1, len(self._gid)), dtype=np.float32) for g in self.data.parse(label): if g in self._gid: x[0, self._gid[g]] = 1.0 return self._model.predict(x)[0] class NearestCentroid(EffectPredictor): name = "NearestPerturbationCentroid" def _fit_extra(self): self._sets = {p: set(self.data.parse(p)) for p in self.train_perts} def predict_effect(self, label): gq = set(self.data.parse(label)) best, best_j = None, -1.0 for p, gs in self._sets.items(): j = len(gq & gs) / max(len(gq | gs), 1) if j > best_j: best_j, best = j, p return self.eff[best] if best is not None else np.zeros(self.d) class EndpointMLP(EffectPredictor): name = "EndpointMLP" def _fit_extra(self): import torch import torch.nn as nn genes = self.data.genes_vocab gid = {g: i for i, g in enumerate(genes)} X = np.zeros((len(self.train_perts), len(genes)), dtype=np.float32) Y = np.zeros((len(self.train_perts), self.d), dtype=np.float32) for r, p in enumerate(self.train_perts): for g in self.data.parse(p): if g in gid: X[r, gid[g]] = 1.0 Y[r] = self.eff[p] self._gid = gid dev = "cuda" if torch.cuda.is_available() else "cpu" self._dev = dev net = nn.Sequential(nn.Linear(len(genes), 256), nn.ReLU(), nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, self.d)).to(dev) Xt = torch.as_tensor(X, device=dev); Yt = torch.as_tensor(Y, device=dev) opt = torch.optim.Adam(net.parameters(), lr=1e-3, weight_decay=1e-5) for _ in range(800): opt.zero_grad() loss = ((net(Xt) - Yt) ** 2).sum(-1).mean() loss.backward(); opt.step() net.eval() self._net = net def predict_effect(self, label): import torch x = np.zeros((1, len(self._gid)), dtype=np.float32) for g in self.data.parse(label): if g in self._gid: x[0, self._gid[g]] = 1.0 with torch.no_grad(): return self._net(torch.as_tensor(x, device=self._dev)).cpu().numpy()[0] class KNNLatent(EffectPredictor): name = "kNN-latent" K = 5 def _fit_extra(self): self._sets = {p: set(self.data.parse(p)) for p in self.train_perts} def predict_effect(self, label): gq = set(self.data.parse(label)) sims = sorted(((len(gq & gs) / max(len(gq | gs), 1), p) for p, gs in self._sets.items()), reverse=True)[: self.K] if not sims or sims[0][0] == 0: return np.mean(list(self.eff.values()), axis=0) if self.eff else np.zeros(self.d) return np.mean([self.eff[p] for _, p in sims], axis=0) class ConditionalMLP(EffectPredictor): """c0-conditional endpoint predictor: mlp([c0, gene multi-hot]) -> c1 in embedding space. a fair neural competitor to pivot (cell-state dependent, unlike the constant-effect mlp).""" name = "ConditionalMLP" def _fit_extra(self): import torch import torch.nn as nn genes = self.data.genes_vocab gid = {g: i for i, g in enumerate(genes)} self._gid = gid train_set = set(self._train_idx.tolist()) if hasattr(self, "_train_idx") else None # build matched (c0, c1, multihot) training pairs rng = np.random.default_rng(0) rows_c0, rows_c1, rows_x = [], [], [] for p in self.train_perts: idx = self.data.pert_to_idx[p] c1 = self.data.emb[idx] ctrl = self.data.sample_controls(idx, "batch", rng) c0 = self.data.emb[ctrl] x = np.zeros((len(idx), len(genes)), dtype=np.float32) for g in self.data.parse(p): if g in gid: x[:, gid[g]] = 1.0 rows_c0.append(c0); rows_c1.append(c1); rows_x.append(x) C0 = np.concatenate(rows_c0); C1 = np.concatenate(rows_c1); X = np.concatenate(rows_x) dev = "cuda" if torch.cuda.is_available() else "cpu" self._dev = dev net = nn.Sequential(nn.Linear(self.d + len(genes), 512), nn.SiLU(), nn.Linear(512, 512), nn.SiLU(), nn.Linear(512, self.d)).to(dev) inp = torch.as_tensor(np.concatenate([C0, X], 1), device=dev) tgt = torch.as_tensor(C1, device=dev) opt = torch.optim.AdamW(net.parameters(), lr=1e-3, weight_decay=1e-5) n = inp.shape[0] for _ in range(40): perm = torch.randperm(n, device=dev) for b in range(0, n, 2048): bi = perm[b:b + 2048] opt.zero_grad() loss = ((net(inp[bi]) - tgt[bi]) ** 2).sum(-1).mean() loss.backward(); opt.step() net.eval() self._net = net def fit(self, data, train_perts, train_idx): self._train_idx = train_idx return super().fit(data, train_perts, train_idx) def predict_endpoint(self, label, c0): import torch x = np.zeros((1, len(self._gid)), dtype=np.float32) for g in self.data.parse(label): if g in self._gid: x[0, self._gid[g]] = 1.0 X = np.repeat(x, len(c0), axis=0) inp = torch.as_tensor(np.concatenate([c0.astype(np.float32), X], 1), device=self._dev) with torch.no_grad(): return self._net(inp).cpu().numpy() def predict_effect(self, label): # mean effect over a control sample (for forward effect-vector metrics) c0 = self.data.emb[self.data.control_idx[:256]] return self.predict_endpoint(label, c0).mean(0) - c0.mean(0) BASELINES = { b.name: b for b in [ Random(), MeanControl(), GlobalAverageEffect(), Additive(), LinearRidge(), NearestCentroid(), EndpointMLP(), KNNLatent(), ConditionalMLP(), ] } def build_baseline(name: str) -> EffectPredictor: cls = {b.name: type(b) for b in BASELINES.values()}[name] return cls()