PIVOT / src /experiments /nomination_eval.py
bryan7264's picture
pivot: code + trained checkpoints (norman, replogle k562)
3b4941f verified
Raw
History Blame
8.41 kB
"""desired-state nomination eval (tables 4-7,13,14).
a target state is induced by a held-out perturbation p: c* = centroid of p's cells.
given c*, a method nominates perturbations; we score exact/functional recovery,
ranking quality, endpoint distance, and specificity.
ranking uses a numpy reward scorer (works for any predictor). pivot reward guidance
(Alg 4) also uses the differentiable torch reward and the model.
"""
from __future__ import annotations
import numpy as np
import torch
from src.evaluation import inference as inf
from src.evaluation import metrics as M
from src.evaluation.rewards import Reward
# numpy reward scorers (for ranking; mirror rewards.py)
def numpy_score(pop: np.ndarray, kind: str, c_star=None, target_sample=None, classifier=None,
                device="cpu", control_ref=None):
    """score a predicted population against a target state (higher is better).

    pop is a 2-d array with one row per predicted point. supported kinds:

    - 'target' / 'centroid': negative mean squared distance of rows to c_star.
    - 'cosine': mean cosine similarity between each row's shift from the
      reference and c_star's shift from the reference. the reference is
      control_ref, or the mean of pop when control_ref is None.
    - 'nn_target': negative mean squared distance of each row to its nearest
      row in target_sample.
    - 'mmd': negative M.mmd_rbf(pop, target_sample).
    - 'wasserstein': negative M.sliced_wasserstein(pop, target_sample).
    - 'classifier': mean log-sigmoid of classifier(pop), evaluated as a
      float32 torch tensor on `device` without gradients.
    - 'combined': the 'centroid' term plus the 'classifier' term.

    returns a python float.

    raises ValueError for an unknown kind, or when an input the chosen kind
    needs (c_star, target_sample or classifier) is None.
    """

    def require(value, name):
        # fail with a readable message instead of an opaque TypeError raised
        # from inside numpy / torch when a needed input was left as None
        if value is None:
            raise ValueError(f"reward kind {kind!r} requires `{name}`")
        return value

    def centroid_term():
        target = require(c_star, "c_star")
        return -np.mean(np.sum((pop - target) ** 2, axis=1))

    def classifier_term():
        clf = require(classifier, "classifier")
        with torch.no_grad():
            logits = clf(torch.as_tensor(pop, dtype=torch.float32, device=device))
            return torch.nn.functional.logsigmoid(logits).mean().item()

    if kind in ("target", "centroid"):
        return float(centroid_term())
    if kind == "combined":
        # centroid term is evaluated first, then the classifier term
        return float(centroid_term() + classifier_term())
    if kind == "classifier":
        return float(classifier_term())
    if kind == "cosine":
        target = require(c_star, "c_star")
        reference = control_ref if control_ref is not None else pop.mean(0)
        pop_shift = pop - reference
        target_shift = target - reference
        # the 1e-8 keeps zero-length shifts from dividing by zero
        pop_shift = pop_shift / (np.linalg.norm(pop_shift, axis=1, keepdims=True) + 1e-8)
        target_shift = target_shift / (np.linalg.norm(target_shift) + 1e-8)
        return float(np.mean(pop_shift @ target_shift))
    if kind == "nn_target":
        from scipy.spatial.distance import cdist
        sample = require(target_sample, "target_sample")
        return float(-np.mean(cdist(pop, sample).min(axis=1) ** 2))
    if kind == "mmd":
        return float(-M.mmd_rbf(pop, require(target_sample, "target_sample")))
    if kind == "wasserstein":
        return float(-M.sliced_wasserstein(pop, require(target_sample, "target_sample")))
    raise ValueError(kind)
# reward kinds that rank_candidates scores through the batched
# inf.endpoint_ranking path when the predictor is a PivotPredictor; any other
# kind is scored one candidate at a time with numpy_score.
_PER_ROW_REWARDS = ("centroid", "target", "nn_target", "classifier", "combined", "cosine")
def rank_candidates(predictor, candidates, c0, score_kwargs):
    """rank candidates by reward of predicted endpoint, best first.

    uses the batched gpu path (inf.endpoint_ranking) for pivot with per-row
    rewards; numpy per-candidate scoring otherwise (baselines, mmd/wass).

    score_kwargs must contain 'kind'. on the numpy path the whole dict is
    forwarded to numpy_score as keyword arguments; on the pivot path its
    'c_star', 'target_sample', 'classifier', 'device' and 'control_ref'
    entries (when present) configure the Reward.

    returns (labels, scores): the candidate labels in ranked order and a numpy
    array of their scores in the same order.
    """
    # imported here rather than at module level (presumably to avoid an import
    # cycle with the predictors module -- confirm before hoisting)
    from src.experiments.predictors import PivotPredictor

    kind = score_kwargs["kind"]
    if isinstance(predictor, PivotPredictor) and kind in _PER_ROW_REWARDS:
        dev = score_kwargs.get("device", "cpu")
        reward = Reward(kind, target_c=score_kwargs.get("c_star"),
                        target_sample=score_kwargs.get("target_sample"),
                        classifier=score_kwargs.get("classifier"), device=dev,
                        control_ref=score_kwargs.get("control_ref"))
        c0_tensor = torch.as_tensor(c0, dtype=torch.float32, device=dev)
        ranked = inf.endpoint_ranking(predictor.model, predictor.data, candidates,
                                      c0_tensor, reward, device=dev)
        labels = [label for label, _ in ranked]
        return labels, np.array([score for _, score in ranked])

    scores = np.array([numpy_score(predictor.population(cand, c0), **score_kwargs)
                       for cand in candidates])
    # stable sort: candidates with equal scores keep their input order, so the
    # ranking is reproducible across runs and numpy versions
    order = np.argsort(-scores, kind="stable")
    return [candidates[i] for i in order], scores[order]
def build_target(data, p, control_mean_emb):
    """build the desired state induced by perturbation p.

    returns (c_star, sample): `sample` is data.emb indexed by
    data.pert_to_idx[p], and `c_star` is its mean over axis 0.

    control_mean_emb is accepted for interface compatibility and is not used.
    """
    # index once and reuse the result for both the centroid and the sample
    sample = data.emb[data.pert_to_idx[p]]
    return sample.mean(0), sample
def evaluate_nomination(predictor, data, targets, candidates, control_pool,
                        reward_kind="centroid", method="ranking", n_ctrl=128,
                        gene_cluster=None, model=None, device="cpu",
                        guidance_steps=25, guidance_step=0.5, k_nearest=10,
                        classifier=None, seed=0,
                        guidance_init="warm", guidance_normalize=True, rerank=True):
    """evaluate single-perturbation nomination. returns aggregated metric dict.

    for every target p the desired state is built with build_target, a fresh
    subsample of at most n_ctrl rows of data.emb[control_pool] is drawn as the
    starting population, candidates are ranked with rank_candidates, and
    (optionally) the head of the ranking is replaced by the guidance result.

    method: 'ranking' (Alg 3) | 'guidance' (Alg 4+5, pivot only; needs `model`).
    guidance_init: 'warm' (top-ranked) | 'random' | 'mean_top' (mean of top-k cand
        embeddings). any other value behaves like 'warm'.
    guidance_normalize: normalize the jacobian-pullback step (Alg 4).
    rerank: if true, rerank k-nn by endpoint reward (Alg 5); else take the
        k_nearest candidates closest to the guided embedding.

    with reward_kind='combined' the guidance reward falls back to 'centroid';
    the initial ranking still uses 'combined'.

    returns a dict with the mean of each per-target metric ('top1', 'top5',
    'ndcg', 'func_top5', 'func_ndcg', 'rank', 'endpoint_dist', 'spec_margin',
    'off_dist', 'clf'; None when nothing was recorded for that metric), plus
    '_per' (the raw per-target lists), 'n_targets', and 'random_top5' (5 / number
    of candidates, capped at 1; None when there are no candidates).

    raises ValueError if method is 'guidance' and model is None.
    """
    # validate up front: an assert would vanish under `python -O`, and checking
    # before the loop avoids paying for a full ranking pass first
    if method == "guidance" and model is None:
        raise ValueError("method='guidance' requires `model`")

    rng = np.random.default_rng(seed)
    ctrl_all = data.emb[control_pool]
    ctrl_mean = ctrl_all.mean(0)
    n_draw = min(n_ctrl, len(ctrl_all))
    res = {k: [] for k in ["top1", "top5", "ndcg", "func_top5", "func_ndcg",
                           "rank", "endpoint_dist", "spec_margin", "off_dist", "clf"]}

    for p in targets:
        c_star, tgt_sample = build_target(data, p, ctrl_mean)
        c0 = ctrl_all[rng.choice(len(ctrl_all), n_draw, replace=False)]
        sk = dict(kind=reward_kind, c_star=c_star, target_sample=tgt_sample,
                  classifier=classifier, device=device, control_ref=ctrl_mean)
        ranked, _ = rank_candidates(predictor, candidates, c0, sk)

        if method == "guidance":
            c0t = torch.as_tensor(c0, dtype=torch.float32, device=device)
            rew = Reward(reward_kind if reward_kind != "combined" else "centroid",
                         target_c=c_star, target_sample=tgt_sample,
                         classifier=classifier, device=device, control_ref=ctrl_mean)
            if guidance_init == "random":
                start_label = candidates[rng.integers(len(candidates))]
                e_init = inf.encode_label(model, data, start_label, device)
            elif guidance_init == "mean_top":
                from src.models.pivot import candidate_embeddings
                top_emb = candidate_embeddings(model, data, ranked[:k_nearest], device)
                e_init = torch.as_tensor(top_emb.mean(0, keepdims=True), device=device)
            else:  # warm
                e_init = inf.encode_label(model, data, ranked[0], device)
            e_star = inf.reward_guidance(model, c0t, rew, e_init, steps=guidance_steps,
                                         step_size=guidance_step, normalize=guidance_normalize)
            if rerank:
                top = inf.project_and_rerank(model, data, candidates, e_star, c0t, rew,
                                             k_nearest=k_nearest, topk=k_nearest, device=device)
                guided = [label for label, _ in top]
            else:
                from src.models.pivot import candidate_embeddings
                cand_emb = torch.as_tensor(candidate_embeddings(model, data, candidates, device),
                                           device=device)
                dists = torch.cdist(e_star.view(1, -1), cand_emb).squeeze(0)
                nearest_order = torch.argsort(dists).cpu().numpy()
                guided = [candidates[i] for i in nearest_order[:k_nearest]]
            # guided picks first, then the remaining labels in their original
            # ranked order, so rank metrics still see a full ordering
            ranked = guided + [r for r in ranked if r not in guided]

        # metrics
        res["top1"].append(M.top_k_accuracy(ranked, p, 1))
        res["top5"].append(M.top_k_accuracy(ranked, p, 5))
        res["ndcg"].append(M.ndcg_at_k(ranked, M.exact_relevance(p), 10))
        res["rank"].append(M.rank_of_true(ranked, p))
        if gene_cluster is not None:
            frel = M.functional_relevance(p, data.parse, gene_cluster, candidates)
            res["func_top5"].append(float(any(frel.get(label, 0) >= 0.5 for label in ranked[:5])))
            res["func_ndcg"].append(M.ndcg_at_k(ranked, frel, 10))

        # endpoint distance of nominated top-1 + specificity
        top1_pop = predictor.population(ranked[0], c0)
        top1_centroid = top1_pop.mean(0)
        endpoint_dist = float(np.linalg.norm(top1_centroid - c_star))
        res["endpoint_dist"].append(endpoint_dist)

        # off-state distance: mean dist of top-1 endpoint to other perturbation
        # centroids. sample *indices* so labels stay the original python objects
        # (rng.choice on the label list would coerce them into a numpy array).
        off_pool = [q for q in candidates if q != p]
        n_off = min(20, len(candidates) - 1, len(off_pool))
        if n_off > 0:
            picks = rng.choice(len(off_pool), size=n_off, replace=False)
            off = float(np.mean([
                np.linalg.norm(top1_centroid - data.emb[data.pert_to_idx[off_pool[i]]].mean(0))
                for i in picks
            ]))
        else:
            # nothing to compare against (e.g. the only candidate is the target)
            off = float("nan")
        res["off_dist"].append(off)
        res["spec_margin"].append(float(off - endpoint_dist))

        if classifier is not None:
            with torch.no_grad():
                top1_tensor = torch.as_tensor(top1_pop, dtype=torch.float32, device=device)
                res["clf"].append(float(torch.sigmoid(classifier(top1_tensor)).mean().item()))

    agg = {k: (float(np.mean(v)) if v else None) for k, v in res.items()}
    agg["_per"] = res
    agg["n_targets"] = len(targets)
    # chance level for top-5; capped at 1 because with fewer than five
    # candidates 5/n is no longer a valid rate
    n_cand = len(candidates)
    agg["random_top5"] = min(1.0, 5.0 / n_cand) if n_cand else None
    return agg