PIVOT / src /evaluation /baselines.py
bryan7264's picture
pivot: code + trained checkpoints (norman, replogle k562)
3b4941f verified
Raw
History Blame
9.92 kB
"""baseline forward models.
all baselines are effect predictors in cell-state embedding space:
predict_effect(label) -> delta, and predict_endpoint(c0) = c0 + delta.
this unifies forward eval and the inverse "+ranking" wrappers (rank
candidates by the reward of their predicted endpoint).
implemented: Random, MeanControl, GlobalAverageEffect, Additive,
LinearRidge, NearestCentroid, EndpointMLP. heavy/foundation baselines
are out of scope and handled as n/r rows by the runner, never faked.
"""
from __future__ import annotations
import numpy as np
from src.data.perturb_data import PerturbData
def training_effects(
    data: PerturbData, train_perts, train_idx
) -> tuple[dict[str, np.ndarray], np.ndarray]:
    """Embedding-space effect (pert mean - control mean) per training perturbation.

    Only cells listed in ``train_idx`` are used. There are two fallbacks:
    - If no control cell is in ``train_idx``, the control mean uses all of
      ``data.control_idx``.
    - If a perturbation has no cell in ``train_idx``, its mean uses all of its cells.

    Args:
        data: dataset exposing ``emb`` (rows indexed by cell), ``control_idx`` and
            ``pert_to_idx`` (assumed to hold integer row indices into ``emb``).
        train_perts: iterable of perturbation labels to compute effects for.
        train_idx: array-like of training cell indices (list or ndarray).

    Returns:
        ``(eff, cmean)``, where ``eff`` maps each label in ``train_perts`` to
        ``mean(emb[pert cells]) - cmean``, and ``cmean`` is the control mean.
    """
    train_idx = np.asarray(train_idx)
    # intp dtype keeps empty index lists valid for fancy indexing into emb
    control_idx = np.asarray(data.control_idx, dtype=np.intp)
    train_ctrl = control_idx[np.isin(control_idx, train_idx)]
    cmean = data.emb[train_ctrl if len(train_ctrl) else control_idx].mean(0)
    eff = {}
    for p in train_perts:
        pert_cells = np.asarray(data.pert_to_idx[p], dtype=np.intp)
        kept = pert_cells[np.isin(pert_cells, train_idx)]
        idx = kept if len(kept) else pert_cells
        eff[p] = data.emb[idx].mean(0) - cmean
    return eff, cmean
class EffectPredictor:
    """Shared interface for the effect-space baselines.

    A subclass supplies ``predict_effect(label)``, which returns a shift in embedding
    space. It may also override ``_fit_extra`` for model-specific fitting.

    By default an endpoint is the starting state plus that shift, broadcast over the
    rows of ``c0``.
    """

    name = "base"

    def fit(self, data, train_perts, train_idx):
        """Store the dataset and training effects, then run the subclass fitting hook.

        Returns ``self`` so construction and fitting can be chained.
        """
        self.data = data
        effects, control_mean = training_effects(data, train_perts, train_idx)
        self.eff = effects
        self.cmean = control_mean
        self.train_perts = list(train_perts)
        self.d = data.d
        self._fit_extra()
        return self

    def _fit_extra(self):
        """Subclass hook called at the end of ``fit``; the base class needs nothing."""
        return None

    def predict_effect(self, label) -> np.ndarray:
        """Predicted embedding-space effect of ``label``; subclasses must implement."""
        raise NotImplementedError

    def predict_endpoint(self, label, c0: np.ndarray) -> np.ndarray:
        """Add the predicted effect, as a leading-axis row, to the starting states ``c0``."""
        delta = self.predict_effect(label)
        return c0 + delta[np.newaxis, :]
class Random(EffectPredictor):
    """Ignore the label and return the effect of a randomly drawn training perturbation.

    The generator is re-seeded with 0 on every fit, so a freshly fitted instance
    produces a reproducible sequence of draws.
    """

    name = "Random"

    def _fit_extra(self):
        self._rng = np.random.default_rng(0)
        if self.eff:
            self._effs = np.stack(tuple(self.eff.values()))
        else:
            # no training effects: one all-zero row keeps predict_effect well-defined
            self._effs = np.zeros((1, self.d))

    def predict_effect(self, label):
        pick = self._rng.integers(len(self._effs))
        return self._effs[pick]
class MeanControl(EffectPredictor):
    """Null baseline: predict a zero shift, so every endpoint equals its starting state."""

    name = "MeanControl"

    def predict_effect(self, label):
        no_shift = np.zeros(self.d, dtype=np.float32)
        return no_shift
class GlobalAverageEffect(EffectPredictor):
    """Predict one label-independent effect: the mean over all training effects.

    With no training effects, the prediction is an all-zero vector.
    """

    name = "AvgPerturbationEffect"

    def _fit_extra(self):
        if not self.eff:
            self._mean = np.zeros(self.d)
            return
        self._mean = np.mean(list(self.eff.values()), axis=0)

    def predict_effect(self, label):
        return self._mean
class Additive(EffectPredictor):
    """Sum the single-gene training effects of the genes in a label.

    Single-gene effects come from training perturbations that parse to exactly one
    gene. Genes without such an effect are skipped. If none of the label's genes has
    one, the mean training effect (or zeros when there are none) is returned.
    """

    name = "Additive"

    def _fit_extra(self):
        singles = {}
        for pert, effect in self.eff.items():
            parsed = self.data.parse(pert)
            if len(parsed) != 1:
                continue
            singles[parsed[0]] = effect
        self._single = singles
        if self.eff:
            self._fallback = np.mean(list(self.eff.values()), axis=0)
        else:
            self._fallback = np.zeros(self.d)

    def predict_effect(self, label):
        known = [self._single[g] for g in self.data.parse(label) if g in self._single]
        if not known:
            return self._fallback
        return np.sum(known, axis=0)
class LinearRidge(EffectPredictor):
    """Linear response baseline: ridge regression (alpha=1.0) mapping a gene multi-hot
    encoding of the label to its embedding-space training effect.

    Genes that are not in ``data.genes_vocab`` are left out of the encoding.
    """

    name = "LinearResponse"

    def _multihot_row(self, label, width):
        """Float32 vector of length ``width`` with 1.0 at the label's vocabulary columns."""
        row = np.zeros(width, dtype=np.float32)
        for gene in self.data.parse(label):
            col = self._gid.get(gene)
            if col is not None:
                row[col] = 1.0
        return row

    def _fit_extra(self):
        from sklearn.linear_model import Ridge

        vocab = self.data.genes_vocab
        self._gid = {gene: col for col, gene in enumerate(vocab)}
        n_train = len(self.train_perts)
        features = np.zeros((n_train, len(vocab)), dtype=np.float32)
        targets = np.zeros((n_train, self.d), dtype=np.float32)
        for row_no, pert in enumerate(self.train_perts):
            features[row_no] = self._multihot_row(pert, len(vocab))
            targets[row_no] = self.eff[pert]
        self._model = Ridge(alpha=1.0).fit(features, targets)

    def predict_effect(self, label):
        query = self._multihot_row(label, len(self._gid))[np.newaxis, :]
        return self._model.predict(query)[0]
class NearestCentroid(EffectPredictor):
    """Return the training effect of the perturbation whose gene set best matches a label.

    Similarity is Jaccard overlap between gene sets, and ties go to the earliest
    training perturbation.

    If no training perturbation shares a gene with the label, every score is 0 and
    any "nearest" pick would be decided by training order alone. In that case the
    mean training effect is returned instead (zeros when there are none), matching
    KNNLatent's fallback.
    """

    name = "NearestPerturbationCentroid"

    def _fit_extra(self):
        self._sets = {p: set(self.data.parse(p)) for p in self.train_perts}
        self._fallback = (
            np.mean(list(self.eff.values()), axis=0) if self.eff else np.zeros(self.d)
        )

    def predict_effect(self, label):
        gq = set(self.data.parse(label))
        # starting at 0.0 with a strict ">" means only overlapping perturbations can win
        best, best_j = None, 0.0
        for p, gs in self._sets.items():
            j = len(gq & gs) / max(len(gq | gs), 1)
            if j > best_j:
                best_j, best = j, p
        return self.eff[best] if best is not None else self._fallback
class EndpointMLP(EffectPredictor):
    """Constant-effect MLP: gene multi-hot -> embedding-space effect, independent of c0.

    Training is full-batch for ``STEPS`` steps (Adam, lr 1e-3, weight decay 1e-5) on
    (multi-hot, training effect) pairs.

    Weight init is seeded with ``SEED`` inside a forked RNG. Repeated fits therefore
    start from the same weights, and the caller's global torch RNG is left untouched.
    """

    name = "EndpointMLP"
    SEED = 0
    STEPS = 800

    def _multihot(self, label):
        """(vocab,) float32 vector with 1.0 at the label's in-vocabulary genes."""
        x = np.zeros(self._width, dtype=np.float32)
        for g in self.data.parse(label):
            col = self._gid.get(g)
            if col is not None:
                x[col] = 1.0
        return x

    def _fit_extra(self):
        import torch
        import torch.nn as nn

        genes = self.data.genes_vocab
        self._gid = {g: i for i, g in enumerate(genes)}
        # network input width; reused at predict time so both encodings always agree
        self._width = len(genes)
        X = np.zeros((len(self.train_perts), self._width), dtype=np.float32)
        Y = np.zeros((len(self.train_perts), self.d), dtype=np.float32)
        for r, p in enumerate(self.train_perts):
            X[r] = self._multihot(p)
            Y[r] = self.eff[p]
        dev = "cuda" if torch.cuda.is_available() else "cpu"
        self._dev = dev
        # layers are initialised on CPU, so forking only the CPU RNG is sufficient
        with torch.random.fork_rng(devices=[]):
            torch.manual_seed(self.SEED)
            net = nn.Sequential(
                nn.Linear(self._width, 256), nn.ReLU(),
                nn.Linear(256, 256), nn.ReLU(),
                nn.Linear(256, self.d),
            )
        net = net.to(dev)
        Xt = torch.as_tensor(X, device=dev)
        Yt = torch.as_tensor(Y, device=dev)
        opt = torch.optim.Adam(net.parameters(), lr=1e-3, weight_decay=1e-5)
        for _ in range(self.STEPS):
            opt.zero_grad()
            loss = ((net(Xt) - Yt) ** 2).sum(-1).mean()
            loss.backward()
            opt.step()
        net.eval()
        self._net = net

    def predict_effect(self, label):
        import torch

        x = self._multihot(label)[None, :]
        with torch.no_grad():
            return self._net(torch.as_tensor(x, device=self._dev)).cpu().numpy()[0]
class KNNLatent(EffectPredictor):
    """Average the training effects of the ``K`` perturbations most similar to a label.

    Similarity is Jaccard overlap between gene sets. Only perturbations that share at
    least one gene with the label count as neighbours. Zero-overlap perturbations
    would otherwise fill the top-K in arbitrary (reverse-name) order and dilute the
    estimate.

    With no overlapping neighbour, the mean training effect is returned (zeros when
    there are none).
    """

    name = "kNN-latent"
    K = 5

    def _fit_extra(self):
        self._sets = {p: set(self.data.parse(p)) for p in self.train_perts}
        self._fallback = (
            np.mean(list(self.eff.values()), axis=0) if self.eff else np.zeros(self.d)
        )

    def predict_effect(self, label):
        gq = set(self.data.parse(label))
        scored = [(len(gq & gs) / max(len(gq | gs), 1), p) for p, gs in self._sets.items()]
        neighbours = sorted((s for s in scored if s[0] > 0), reverse=True)[: self.K]
        if not neighbours:
            return self._fallback
        return np.mean([self.eff[p] for _, p in neighbours], axis=0)
class ConditionalMLP(EffectPredictor):
    """c0-conditional endpoint predictor: mlp([c0, gene multi-hot]) -> c1 in embedding space.
    a fair neural competitor to pivot (cell-state dependent, unlike the constant-effect mlp).

    Each training row combines three pieces:
    - c1: one cell of a training perturbation;
    - c0: a control drawn by ``data.sample_controls(idx, "batch", rng)``;
    - the perturbation's gene multi-hot.

    Perturbed cells are restricted to ``train_idx``. A perturbation falls back to all
    of its cells only when none are in the split, mirroring ``training_effects``.

    Weight init, control sampling and minibatch order are all seeded with ``SEED``.
    """

    name = "ConditionalMLP"
    SEED = 0
    EPOCHS = 40
    BATCH_SIZE = 2048
    HIDDEN = 512

    def fit(self, data, train_perts, train_idx):
        # remember the split before the base fit calls _fit_extra
        self._train_idx = train_idx
        return super().fit(data, train_perts, train_idx)

    def _multihot(self, label, n_rows):
        """(n_rows, vocab) float32 matrix with 1.0 at the label's in-vocabulary genes."""
        x = np.zeros((n_rows, self._width), dtype=np.float32)
        for g in self.data.parse(label):
            col = self._gid.get(g)
            if col is not None:
                x[:, col] = 1.0
        return x

    def _training_pairs(self, rng):
        """Stack (c0, multi-hot, c1) rows over all training perturbations.

        Raises:
            ValueError: if there are no training perturbations.
        """
        train_idx = getattr(self, "_train_idx", None)
        train_set = None if train_idx is None else set(np.asarray(train_idx).tolist())
        rows_c0, rows_x, rows_c1 = [], [], []
        for p in self.train_perts:
            idx = self.data.pert_to_idx[p]
            if train_set is not None:
                kept = np.array([i for i in idx if i in train_set])
                if len(kept):
                    idx = kept
            # NOTE(review): controls are not filtered to the split here; depends on sample_controls
            ctrl = self.data.sample_controls(idx, "batch", rng)
            rows_c0.append(self.data.emb[ctrl])
            rows_x.append(self._multihot(p, len(idx)))
            rows_c1.append(self.data.emb[idx])
        if not rows_c1:
            raise ValueError("ConditionalMLP.fit needs at least one training perturbation")
        return np.concatenate(rows_c0), np.concatenate(rows_x), np.concatenate(rows_c1)

    def _fit_extra(self):
        import torch
        import torch.nn as nn

        genes = self.data.genes_vocab
        self._gid = {g: i for i, g in enumerate(genes)}
        # network input width for the gene part; reused at predict time
        self._width = len(genes)
        rng = np.random.default_rng(self.SEED)
        C0, X, C1 = self._training_pairs(rng)

        dev = "cuda" if torch.cuda.is_available() else "cpu"
        self._dev = dev
        # layers are initialised on CPU, so forking only the CPU RNG is sufficient
        with torch.random.fork_rng(devices=[]):
            torch.manual_seed(self.SEED)
            net = nn.Sequential(
                nn.Linear(self.d + self._width, self.HIDDEN), nn.SiLU(),
                nn.Linear(self.HIDDEN, self.HIDDEN), nn.SiLU(),
                nn.Linear(self.HIDDEN, self.d),
            )
        net = net.to(dev)
        # float32 cast: a float64 embedding matrix would otherwise mismatch the weights
        inp = torch.as_tensor(
            np.concatenate([C0, X], 1).astype(np.float32, copy=False), device=dev
        )
        tgt = torch.as_tensor(C1.astype(np.float32, copy=False), device=dev)
        opt = torch.optim.AdamW(net.parameters(), lr=1e-3, weight_decay=1e-5)
        n = inp.shape[0]
        for _ in range(self.EPOCHS):
            # minibatch order from the seeded numpy rng (torch.randperm was unseeded)
            perm = torch.as_tensor(rng.permutation(n), device=dev)
            for b in range(0, n, self.BATCH_SIZE):
                bi = perm[b:b + self.BATCH_SIZE]
                opt.zero_grad()
                loss = ((net(inp[bi]) - tgt[bi]) ** 2).sum(-1).mean()
                loss.backward()
                opt.step()
        net.eval()
        self._net = net

    def predict_endpoint(self, label, c0):
        """Predict endpoints for a 2-D batch of starting states ``c0`` (one row per cell)."""
        import torch

        c0 = np.asarray(c0, dtype=np.float32)
        inp = np.concatenate([c0, self._multihot(label, len(c0))], 1)
        with torch.no_grad():
            return self._net(torch.as_tensor(inp, device=self._dev)).cpu().numpy()

    def predict_effect(self, label):
        # mean effect over a control sample (for forward effect-vector metrics)
        c0 = self.data.emb[self.data.control_idx[:256]]
        return self.predict_endpoint(label, c0).mean(0) - c0.mean(0)
# Registry of one unfitted prototype per baseline, keyed by the class's `name`
# attribute and kept in the order listed here.
BASELINES = {
    proto.name: proto
    for proto in (
        Random(),
        MeanControl(),
        GlobalAverageEffect(),
        Additive(),
        LinearRidge(),
        NearestCentroid(),
        EndpointMLP(),
        KNNLatent(),
        ConditionalMLP(),
    )
}
def build_baseline(name: str) -> EffectPredictor:
    """Return a fresh, unfitted instance of the baseline registered under ``name``.

    ``name`` is a key of ``BASELINES``, i.e. a baseline's ``name`` class attribute
    (e.g. ``"LinearResponse"``). This is not necessarily its Python class name.

    Raises:
        KeyError: if ``name`` is not registered; the message lists the valid names.
    """
    try:
        prototype = BASELINES[name]
    except KeyError:
        raise KeyError(
            f"unknown baseline {name!r}; available: {sorted(BASELINES)}"
        ) from None
    # instantiate the prototype's class so callers never share fitted state
    return type(prototype)()