| """baseline forward models. |
| |
| all baselines are effect predictors in cell-state embedding space: |
| predict_effect(label) -> delta, and predict_endpoint(label, c0) = c0 + delta. |
| this unifies forward eval and the inverse "+ranking" wrappers (rank |
| candidates by the reward of their predicted endpoint). |
| |
| implemented: Random, MeanControl, GlobalAverageEffect, Additive, |
| LinearRidge, NearestCentroid, EndpointMLP, KNNLatent, ConditionalMLP |
| (the last overrides predict_endpoint to condition on c0). heavy/foundation |
| baselines are out of scope and handled as n/r rows by the runner, never faked. |
| """ |
| from __future__ import annotations |
|
|
| import numpy as np |
|
|
| from src.data.perturb_data import PerturbData |
|
|
|
|
def training_effects(
    data: PerturbData, train_perts, train_idx
) -> tuple[dict[str, np.ndarray], np.ndarray]:
    """embedding-space effect (pert mean - control mean) per training perturbation.

    Args:
        data: dataset exposing ``emb`` (indexable by an array of cell indices),
            ``control_idx`` and ``pert_to_idx`` (label -> cell indices).
        train_perts: perturbation labels to compute an effect for.
        train_idx: cell indices that belong to the training split (any array-like).

    Returns:
        ``(eff, cmean)``: ``eff`` maps each label to its effect vector and
        ``cmean`` is the control mean every effect is measured against.
    """
    train_idx = np.asarray(train_idx)

    def _in_train(cells):
        # keep only the cells that are part of the training split, in their original order
        cells = np.asarray(cells)
        return cells[np.isin(cells, train_idx)]

    ctrl = _in_train(data.control_idx)
    # no control cell in the training split: fall back to every control cell
    cmean = data.emb[ctrl].mean(0) if len(ctrl) else data.emb[data.control_idx].mean(0)

    eff = {}
    for p in train_perts:
        idx = _in_train(data.pert_to_idx[p])
        if len(idx) == 0:
            # same fallback for a perturbation with no training cells
            idx = data.pert_to_idx[p]
        eff[p] = data.emb[idx].mean(0) - cmean
    return eff, cmean
|
|
|
|
class EffectPredictor:
    """common base for the baselines: fit once, then predict an embedding-space effect.

    ``fit`` stores the dataset handle, the per-perturbation training effects and
    the control mean, then runs the ``_fit_extra`` hook. subclasses implement
    ``predict_effect`` (and may override ``predict_endpoint``).
    """

    name = "base"

    def fit(self, data, train_perts, train_idx):
        """store training statistics on the instance and return ``self``."""
        self.data = data
        effects, control_mean = training_effects(data, train_perts, train_idx)
        self.eff, self.cmean = effects, control_mean
        self.train_perts = list(train_perts)
        self.d = data.d
        self._fit_extra()
        return self

    def _fit_extra(self):
        """hook for subclass-specific fitting; the base version does nothing."""
        return None

    def predict_effect(self, label) -> np.ndarray:
        """return the predicted effect vector for ``label``; subclasses must override."""
        raise NotImplementedError

    def predict_endpoint(self, label, c0: np.ndarray) -> np.ndarray:
        """shift every row of ``c0`` by the predicted effect for ``label``."""
        delta = self.predict_effect(label)
        # add a leading axis so the single effect vector broadcasts over the rows of c0
        return c0 + delta[np.newaxis, :]
|
|
|
|
class Random(EffectPredictor):
    """predicts the effect of a randomly drawn training perturbation."""

    name = "Random"

    def _fit_extra(self):
        """stack the training effects into one matrix and create the sampler."""
        # fixed seed: every fit starts from the same generator state
        self._rng = np.random.default_rng(0)
        if self.eff:
            self._effs = np.stack(list(self.eff.values()))
        else:
            # no training effects: a single all-zero row keeps sampling valid
            self._effs = np.zeros((1, self.d))

    def predict_effect(self, label):
        """return one stored effect row picked by the generator; ``label`` is ignored."""
        pick = self._rng.integers(len(self._effs))
        return self._effs[pick]
|
|
|
|
class MeanControl(EffectPredictor):
    """null baseline: every perturbation is predicted to have zero effect."""

    name = "MeanControl"

    def predict_effect(self, label):
        """return a float32 zero vector of length ``self.d``; ``label`` is ignored."""
        no_effect = np.zeros(self.d, dtype=np.float32)
        return no_effect
|
|
|
|
class GlobalAverageEffect(EffectPredictor):
    """predicts the mean of all training effects for every perturbation."""

    name = "AvgPerturbationEffect"

    def _fit_extra(self):
        """cache the mean training effect (zeros when there are no training effects)."""
        if not self.eff:
            self._mean = np.zeros(self.d)
            return
        all_effects = list(self.eff.values())
        self._mean = np.mean(all_effects, axis=0)

    def predict_effect(self, label):
        """return the cached mean effect; ``label`` is ignored."""
        return self._mean
|
|
|
|
class Additive(EffectPredictor):
    """sums the effects of the single-gene training perturbations named in a label."""

    name = "Additive"

    def _fit_extra(self):
        """index single-gene effects by gene and cache the mean effect as a fallback."""
        self._single = {}
        for pert, effect in self.eff.items():
            genes = self.data.parse(pert)
            if len(genes) != 1:
                # only labels that parse to exactly one gene define a single-gene effect
                continue
            self._single[genes[0]] = effect
        if self.eff:
            self._fallback = np.mean(list(self.eff.values()), axis=0)
        else:
            self._fallback = np.zeros(self.d)

    def predict_effect(self, label):
        """add up the known single-gene effects; use the fallback when none is known."""
        known = []
        for gene in self.data.parse(label):
            if gene in self._single:
                known.append(self._single[gene])
        if not known:
            return self._fallback
        return np.sum(known, axis=0)
|
|
|
|
class LinearRidge(EffectPredictor):
    """ridge regression from a gene multi-hot encoding of the label to the effect."""

    name = "LinearResponse"

    def _fit_extra(self):
        """build the multi-hot design matrix over the gene vocabulary and fit the ridge model."""
        from sklearn.linear_model import Ridge

        vocab = self.data.genes_vocab
        gene_col = {gene: col for col, gene in enumerate(vocab)}
        n_perts = len(self.train_perts)
        design = np.zeros((n_perts, len(vocab)), dtype=np.float32)
        targets = np.zeros((n_perts, self.d), dtype=np.float32)
        for row, pert in enumerate(self.train_perts):
            for gene in self.data.parse(pert):
                col = gene_col.get(gene)
                # genes outside the vocabulary contribute nothing to the encoding
                if col is not None:
                    design[row, col] = 1.0
            targets[row] = self.eff[pert]
        self._gid = gene_col
        self._model = Ridge(alpha=1.0).fit(design, targets)

    def predict_effect(self, label):
        """encode ``label`` as a one-row multi-hot vector and return the model's prediction."""
        multi_hot = np.zeros((1, len(self._gid)), dtype=np.float32)
        for gene in self.data.parse(label):
            col = self._gid.get(gene)
            if col is not None:
                multi_hot[0, col] = 1.0
        prediction = self._model.predict(multi_hot)
        return prediction[0]
|
|
|
|
class NearestCentroid(EffectPredictor):
    """returns the effect of the training perturbation whose gene set is most similar."""

    name = "NearestPerturbationCentroid"

    def _fit_extra(self):
        """cache the parsed gene set of every training perturbation."""
        self._sets = {}
        for pert in self.train_perts:
            self._sets[pert] = set(self.data.parse(pert))

    def predict_effect(self, label):
        """pick the training perturbation with the highest jaccard overlap with ``label``.

        ties go to the earliest training perturbation; with no training
        perturbations a zero vector is returned.
        """
        if not self._sets:
            return np.zeros(self.d)
        query = set(self.data.parse(label))

        def jaccard(pert):
            genes = self._sets[pert]
            # max(..., 1) guards the division when both gene sets are empty
            return len(query & genes) / max(len(query | genes), 1)

        # max() keeps the first maximal element, i.e. the earliest pert on ties
        winner = max(self._sets, key=jaccard)
        return self.eff[winner]
|
|
|
|
class EndpointMLP(EffectPredictor):
    """mlp from a gene multi-hot encoding of the label to the effect vector."""

    name = "EndpointMLP"

    def _fit_extra(self):
        """encode the training perturbations and train the network full-batch."""
        import torch
        import torch.nn as nn

        vocab = self.data.genes_vocab
        gene_col = {gene: col for col, gene in enumerate(vocab)}
        n_perts = len(self.train_perts)
        features = np.zeros((n_perts, len(vocab)), dtype=np.float32)
        targets = np.zeros((n_perts, self.d), dtype=np.float32)
        for row, pert in enumerate(self.train_perts):
            for gene in self.data.parse(pert):
                col = gene_col.get(gene)
                # genes outside the vocabulary contribute nothing to the encoding
                if col is not None:
                    features[row, col] = 1.0
            targets[row] = self.eff[pert]
        self._gid = gene_col

        device = "cuda" if torch.cuda.is_available() else "cpu"
        self._dev = device
        layers = [
            nn.Linear(len(vocab), 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, self.d),
        ]
        model = nn.Sequential(*layers).to(device)
        inputs = torch.as_tensor(features, device=device)
        labels = torch.as_tensor(targets, device=device)
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
        for _step in range(800):
            optimizer.zero_grad()
            residual = model(inputs) - labels
            # squared error summed over embedding dims, averaged over perturbations
            loss = residual.pow(2).sum(-1).mean()
            loss.backward()
            optimizer.step()
        model.eval()
        self._net = model

    def predict_effect(self, label):
        """encode ``label`` as a one-row multi-hot vector and return the network output."""
        import torch

        multi_hot = np.zeros((1, len(self._gid)), dtype=np.float32)
        for gene in self.data.parse(label):
            col = self._gid.get(gene)
            if col is not None:
                multi_hot[0, col] = 1.0
        batch = torch.as_tensor(multi_hot, device=self._dev)
        with torch.no_grad():
            output = self._net(batch)
        return output.cpu().numpy()[0]
|
|
|
|
class KNNLatent(EffectPredictor):
    """averages the effects of the K training perturbations with the most similar gene sets.

    similarity is the jaccard overlap between parsed gene sets. only training
    perturbations that share at least one gene with the query count as
    neighbours; with no such neighbour the mean training effect is returned.
    """

    name = "kNN-latent"
    K = 5

    def _fit_extra(self):
        """cache the parsed gene set of every training perturbation."""
        self._sets = {p: set(self.data.parse(p)) for p in self.train_perts}

    def predict_effect(self, label):
        """return the mean effect of up to ``K`` overlapping neighbours of ``label``."""
        gq = set(self.data.parse(label))
        scored = []
        for p, gs in self._sets.items():
            # max(..., 1) guards the division when both gene sets are empty
            sim = len(gq & gs) / max(len(gq | gs), 1)
            # zero-overlap perturbations are not neighbours: letting them fill the
            # top-K would mix in unrelated effects picked only by label sort order
            if sim > 0:
                scored.append((sim, p))
        if not scored:
            # nothing overlaps the query: fall back to the mean training effect
            return np.mean(list(self.eff.values()), axis=0) if self.eff else np.zeros(self.d)
        # highest similarity first; ties are ordered by label, descending
        scored.sort(reverse=True)
        return np.mean([self.eff[p] for _, p in scored[: self.K]], axis=0)
|
|
|
|
class ConditionalMLP(EffectPredictor):
    """c0-conditional endpoint predictor: mlp([c0, gene multi-hot]) -> c1 in embedding space.
    a fair neural competitor to pivot (cell-state dependent, unlike the constant-effect mlp).

    training pairs are built only from cells in the training split (``train_idx``
    passed to ``fit``); a perturbation with no training cells falls back to all
    of its cells, mirroring ``training_effects``.
    """

    name = "ConditionalMLP"

    def _fit_extra(self):
        """build (control embedding, multi-hot) -> perturbed embedding pairs and train the mlp."""
        import torch
        import torch.nn as nn

        genes = self.data.genes_vocab
        gid = {g: i for i, g in enumerate(genes)}
        self._gid = gid
        # set by fit(); None only if _fit_extra is reached without going through fit()
        train_idx = getattr(self, "_train_idx", None)
        if train_idx is not None:
            train_idx = np.asarray(train_idx)

        rng = np.random.default_rng(0)
        rows_c0, rows_c1, rows_x = [], [], []
        for p in self.train_perts:
            idx = self.data.pert_to_idx[p]
            if train_idx is not None:
                cells = np.asarray(idx)
                kept = cells[np.isin(cells, train_idx)]
                # restrict to training cells so held-out cells never become targets;
                # keep every cell only when the perturbation has no training cells
                if len(kept):
                    idx = kept
            c1 = self.data.emb[idx]
            # NOTE(review): sample_controls may still return controls outside the
            # training split — confirm against its implementation
            ctrl = self.data.sample_controls(idx, "batch", rng)
            c0 = self.data.emb[ctrl]
            x = np.zeros((len(idx), len(genes)), dtype=np.float32)
            for g in self.data.parse(p):
                if g in gid:
                    x[:, gid[g]] = 1.0
            rows_c0.append(c0)
            rows_c1.append(c1)
            rows_x.append(x)
        C0 = np.concatenate(rows_c0)
        C1 = np.concatenate(rows_c1)
        X = np.concatenate(rows_x)

        dev = "cuda" if torch.cuda.is_available() else "cpu"
        self._dev = dev
        net = nn.Sequential(
            nn.Linear(self.d + len(genes), 512), nn.SiLU(),
            nn.Linear(512, 512), nn.SiLU(),
            nn.Linear(512, self.d),
        ).to(dev)
        inp = torch.as_tensor(np.concatenate([C0, X], 1), device=dev)
        tgt = torch.as_tensor(C1, device=dev)
        opt = torch.optim.AdamW(net.parameters(), lr=1e-3, weight_decay=1e-5)
        n = inp.shape[0]
        for _ in range(40):
            # fresh shuffle each epoch, then mini-batches of up to 2048 pairs
            perm = torch.randperm(n, device=dev)
            for b in range(0, n, 2048):
                bi = perm[b:b + 2048]
                opt.zero_grad()
                loss = ((net(inp[bi]) - tgt[bi]) ** 2).sum(-1).mean()
                loss.backward()
                opt.step()
        net.eval()
        self._net = net

    def fit(self, data, train_perts, train_idx):
        """remember ``train_idx`` for ``_fit_extra``, then run the base fit."""
        self._train_idx = train_idx
        return super().fit(data, train_perts, train_idx)

    def predict_endpoint(self, label, c0):
        """predict the endpoint for every row of ``c0`` under perturbation ``label``."""
        import torch

        x = np.zeros((1, len(self._gid)), dtype=np.float32)
        for g in self.data.parse(label):
            if g in self._gid:
                x[0, self._gid[g]] = 1.0
        # the same multi-hot row is paired with every starting state
        X = np.repeat(x, len(c0), axis=0)
        inp = torch.as_tensor(np.concatenate([c0.astype(np.float32), X], 1), device=self._dev)
        with torch.no_grad():
            return self._net(inp).cpu().numpy()

    def predict_effect(self, label):
        """mean predicted shift for ``label`` over the first 256 control cells."""
        c0 = self.data.emb[self.data.control_idx[:256]]
        return self.predict_endpoint(label, c0).mean(0) - c0.mean(0)
|
|
|
|
# registry: one default-constructed instance per baseline, keyed by its ``name``
BASELINES = {
    instance.name: instance
    for instance in (
        Random(),
        MeanControl(),
        GlobalAverageEffect(),
        Additive(),
        LinearRidge(),
        NearestCentroid(),
        EndpointMLP(),
        KNNLatent(),
        ConditionalMLP(),
    )
}
|
|
|
|
def build_baseline(name: str) -> EffectPredictor:
    """return a fresh, unfitted instance of the baseline registered under ``name``.

    raises KeyError (listing the registered names) when ``name`` is unknown.
    """
    try:
        template = BASELINES[name]
    except KeyError:
        raise KeyError(
            f"unknown baseline {name!r}; available: {sorted(BASELINES)}"
        ) from None
    # instantiate the registered instance's class so fitted state is never shared
    return type(template)()
|
|