"""Desired-state nomination evaluation (tables 4-7, 13, 14).

A target state is induced by a held-out perturbation p: c* = centroid of p's cells.
Given c*, a method nominates perturbations; we score exact/functional recovery,
ranking quality, endpoint distance, and specificity.
Ranking uses a numpy reward scorer that works for any predictor; pivot predictors
scored with a per-row reward are instead ranked in a batched pass with the torch reward.
Pivot reward guidance (Alg 4) also uses the differentiable torch reward and the model.
"""
from __future__ import annotations
import numpy as np
import torch
from src.evaluation import inference as inf
from src.evaluation import metrics as M
from src.evaluation.rewards import Reward
# numpy reward scorers (for ranking; mirror rewards.py)
def numpy_score(pop: np.ndarray, kind: str, c_star=None, target_sample=None, classifier=None,
                device="cpu", control_ref=None):
    """Score a predicted population ``pop`` under the reward ``kind`` (higher is better).

    Supported kinds:
      'target' / 'centroid' -- negative mean squared distance of each row of ``pop``
                               to ``c_star``.
      'cosine'      -- mean cosine similarity between each row's offset and the target
                       offset ``c_star - ref``, where ``ref`` is ``control_ref`` or, if
                       that is None, the mean row of ``pop``.
      'nn_target'   -- negative mean squared distance of each row to its nearest row
                       of ``target_sample``.
      'mmd'         -- negated ``M.mmd_rbf(pop, target_sample)``.
      'wasserstein' -- negated ``M.sliced_wasserstein(pop, target_sample)``.
      'classifier'  -- mean log-sigmoid of ``classifier`` logits on ``pop`` (as a
                       float32 tensor on ``device``, without autograd).
      'combined'    -- the 'centroid' score plus the 'classifier' score.

    Returns:
        The score as a Python float.

    Raises:
        ValueError: for an unrecognised ``kind`` (the message is the kind itself).
    """
    def centroid_term():
        return -np.mean(np.sum((pop - c_star) ** 2, axis=1))

    def classifier_term():
        with torch.no_grad():
            batch = torch.as_tensor(pop, dtype=torch.float32, device=device)
            logits = classifier(batch)
            return torch.nn.functional.logsigmoid(logits).mean().item()

    if kind in ("target", "centroid"):
        return float(centroid_term())
    if kind == "combined":
        base = centroid_term()  # evaluated before the classifier term
        return float(base + classifier_term())
    if kind == "cosine":
        ref = pop.mean(0) if control_ref is None else control_ref
        offsets = pop - ref
        goal = c_star - ref
        # 1e-8 keeps zero-length offsets from dividing by zero
        offsets = offsets / (np.linalg.norm(offsets, axis=1, keepdims=True) + 1e-8)
        goal = goal / (np.linalg.norm(goal) + 1e-8)
        return float(np.mean(offsets @ goal))
    if kind == "nn_target":
        from scipy.spatial.distance import cdist
        nearest = cdist(pop, target_sample).min(axis=1)
        return float(-np.mean(nearest ** 2))
    if kind == "mmd":
        return float(-M.mmd_rbf(pop, target_sample))
    if kind == "wasserstein":
        return float(-M.sliced_wasserstein(pop, target_sample))
    if kind == "classifier":
        return float(classifier_term())
    raise ValueError(kind)
# Reward kinds for which rank_candidates may use the batched torch ranking path when the
# predictor is a PivotPredictor. Kinds not listed here ("mmd", "wasserstein") are always
# scored per candidate with numpy_score.
_PER_ROW_REWARDS = (
    "centroid",
    "target",
    "nn_target",
    "classifier",
    "combined",
    "cosine",
)
def rank_candidates(predictor, candidates, c0, score_kwargs):
    """Sort ``candidates`` best-first by the reward of each one's predicted endpoint.

    A PivotPredictor combined with a reward kind listed in ``_PER_ROW_REWARDS`` is
    scored in one batched pass (``inf.endpoint_ranking`` with a torch ``Reward``).
    Every other case (baselines, mmd/wasserstein) scores each candidate separately
    with ``numpy_score(predictor.population(c, c0), **score_kwargs)``.

    Args:
        predictor: model wrapper exposing ``population(label, c0)``; pivot predictors
            additionally expose ``model`` and ``data``.
        candidates: sequence of candidate labels.
        c0: starting (control) population.
        score_kwargs: keyword arguments for ``numpy_score``; must contain ``kind``.

    Returns:
        ``(labels, scores)`` -- the ranked labels and a numpy array of their scores.
    """
    from src.experiments.predictors import PivotPredictor
    from src.evaluation import inference as inf

    kind = score_kwargs["kind"]
    batched = isinstance(predictor, PivotPredictor) and kind in _PER_ROW_REWARDS
    if not batched:
        scores = np.array([
            numpy_score(predictor.population(cand, c0), **score_kwargs)
            for cand in candidates
        ])
        best_first = np.argsort(-scores)
        return [candidates[i] for i in best_first], scores[best_first]

    device = score_kwargs.get("device", "cpu")
    reward = Reward(
        kind,
        target_c=score_kwargs.get("c_star"),
        target_sample=score_kwargs.get("target_sample"),
        classifier=score_kwargs.get("classifier"),
        device=device,
        control_ref=score_kwargs.get("control_ref"),
    )
    c0_tensor = torch.as_tensor(c0, dtype=torch.float32, device=device)
    pairs = inf.endpoint_ranking(predictor.model, predictor.data, candidates, c0_tensor,
                                 reward, device=device)
    labels = [label for label, _ in pairs]
    return labels, np.array([score for _, score in pairs])
def build_target(data, p, control_mean_emb):
    """Return the desired state induced by the held-out perturbation ``p``.

    Args:
        data: object with ``emb`` (indexable by the values of ``pert_to_idx``) and
            ``pert_to_idx`` mapping a perturbation label to the indices of its cells.
        p: perturbation label whose cells define the target.
        control_mean_emb: unused; kept so existing call sites keep working.

    Returns:
        ``(c_star, target_sample)``: the mean over axis 0 of p's cell embeddings, and
        the gathered cell embeddings themselves.
    """
    # Gather once instead of twice: with a list/array index each lookup is a full copy.
    target_sample = data.emb[data.pert_to_idx[p]]
    return target_sample.mean(0), target_sample
def _guided_ranking(model, data, candidates, ranked, c0, c_star, tgt_sample, *, reward_kind,
                    classifier, control_ref, device, rng, guidance_init, guidance_steps,
                    guidance_step, guidance_normalize, k_nearest, rerank):
    """Re-rank ``ranked`` via pivot reward guidance (Alg 4) and k-NN projection (Alg 5).

    The guided nominations come first, followed by the remaining labels of ``ranked``
    in their existing order, so rank-based metrics still see a complete ordering.
    """
    c0t = torch.as_tensor(c0, dtype=torch.float32, device=device)
    # "combined" is replaced by the plain "centroid" reward for guidance
    rew = Reward(reward_kind if reward_kind != "combined" else "centroid",
                 target_c=c_star, target_sample=tgt_sample,
                 classifier=classifier, device=device, control_ref=control_ref)
    if guidance_init == "random":
        e_init = inf.encode_label(model, data, candidates[rng.integers(len(candidates))], device)
    elif guidance_init == "mean_top":
        from src.models.pivot import candidate_embeddings
        emb_top = candidate_embeddings(model, data, ranked[:k_nearest], device)
        e_init = torch.as_tensor(emb_top.mean(0, keepdims=True), device=device)
    else:  # "warm": start from the top-ranked candidate
        e_init = inf.encode_label(model, data, ranked[0], device)
    e_star = inf.reward_guidance(model, c0t, rew, e_init, steps=guidance_steps,
                                 step_size=guidance_step, normalize=guidance_normalize)
    if rerank:
        top = inf.project_and_rerank(model, data, candidates, e_star, c0t, rew,
                                     k_nearest=k_nearest, topk=k_nearest, device=device)
        guided = [label for label, _ in top]
    else:
        from src.models.pivot import candidate_embeddings
        emb_all = torch.as_tensor(candidate_embeddings(model, data, candidates, device),
                                  device=device)
        # candidates ordered by embedding distance to the guided embedding
        nearest = torch.argsort(torch.cdist(e_star.view(1, -1), emb_all).squeeze(0)).cpu().numpy()
        guided = [candidates[i] for i in nearest[:k_nearest]]
    return guided + [r for r in ranked if r not in guided]


def _off_target_dist(data, endpoint_mean, p, candidates, rng, max_others=20):
    """Mean distance from ``endpoint_mean`` to the centroids of up to ``max_others``
    candidate perturbations other than ``p``, sampled without replacement.

    Returns nan when there is no other candidate.
    """
    pool = [q for q in candidates if q != p]
    if not pool:
        return float("nan")
    # Sample positions rather than labels: Generator.choice on a label list converts it
    # to a numpy array (tuple labels would become rows); indexing keeps labels intact.
    picks = rng.choice(len(pool), size=min(max_others, len(pool)), replace=False)
    return float(np.mean([
        np.linalg.norm(endpoint_mean - data.emb[data.pert_to_idx[pool[i]]].mean(0))
        for i in picks
    ]))


def evaluate_nomination(predictor, data, targets, candidates, control_pool,
                        reward_kind="centroid", method="ranking", n_ctrl=128,
                        gene_cluster=None, model=None, device="cpu",
                        guidance_steps=25, guidance_step=0.5, k_nearest=10,
                        classifier=None, seed=0,
                        guidance_init="warm", guidance_normalize=True, rerank=True):
    """Evaluate single-perturbation nomination and return an aggregated metric dict.

    For each target ``p`` the desired state is built with ``build_target``, a random
    subset of at most ``n_ctrl`` control cells (from ``control_pool``) is the starting
    population, and candidates are ranked with ``rank_candidates``.

    Args:
        predictor: exposes ``population(label, c0)``; used for ranking and endpoints.
        data: exposes ``emb``, ``pert_to_idx`` (label -> cell indices) and ``parse``
            (passed to ``M.functional_relevance``).
        targets: held-out perturbation labels to evaluate.
        candidates: sequence of candidate labels to nominate from.
        control_pool: indices into ``data.emb`` of the control cells.
        reward_kind: reward used for ranking (see ``numpy_score``).
        method: 'ranking' (Alg 3) | 'guidance' (Alg 4+5, pivot only; needs ``model``).
        n_ctrl: maximum number of control cells sampled per target.
        gene_cluster: if given, functional recovery (func_top5, func_ndcg) is computed.
        model: pivot model used by guidance.
        device: torch device.
        guidance_steps, guidance_step: number and size of reward-guidance steps.
        k_nearest: number of candidates taken from the guided embedding; also the
            top-k used by 'mean_top' initialisation.
        classifier: optional torch classifier; used by classifier-based rewards and to
            report the mean sigmoid output on the top-1 endpoint ('clf').
        seed: seed for control subsampling, random init and off-target sampling.
        guidance_init: 'warm' (top-ranked) | 'random' | 'mean_top' (mean of top-k
            candidate embeddings).
        guidance_normalize: normalize the jacobian-pullback step (Alg 4).
        rerank: if true, rerank k-nn by endpoint reward (Alg 5); otherwise take the
            ``k_nearest`` candidates closest to the guided embedding, nearest first.

    Returns:
        dict mapping each metric to its mean over targets (None if never recorded),
        plus ``_per`` (per-target lists), ``n_targets`` and ``random_top5`` -- the
        chance top-5 accuracy ``min(5, N) / N`` for N candidates (None if N == 0).

    Raises:
        ValueError: if ``method == 'guidance'`` and ``model`` is None.
    """
    if method == "guidance" and model is None:
        # explicit check: an `assert` would be stripped under `python -O`
        raise ValueError("method='guidance' requires a model")
    rng = np.random.default_rng(seed)
    ctrl_all = data.emb[control_pool]
    ctrl_mean = ctrl_all.mean(0)
    res = {k: [] for k in ["top1", "top5", "ndcg", "func_top5", "func_ndcg",
                           "rank", "endpoint_dist", "spec_margin", "off_dist", "clf"]}
    for p in targets:
        c_star, tgt_sample = build_target(data, p, ctrl_mean)
        c0 = ctrl_all[rng.choice(len(ctrl_all), min(n_ctrl, len(ctrl_all)), replace=False)]
        sk = dict(kind=reward_kind, c_star=c_star, target_sample=tgt_sample,
                  classifier=classifier, device=device, control_ref=ctrl_mean)
        ranked, _ = rank_candidates(predictor, candidates, c0, sk)
        if method == "guidance":
            ranked = _guided_ranking(
                model, data, candidates, ranked, c0, c_star, tgt_sample,
                reward_kind=reward_kind, classifier=classifier, control_ref=ctrl_mean,
                device=device, rng=rng, guidance_init=guidance_init,
                guidance_steps=guidance_steps, guidance_step=guidance_step,
                guidance_normalize=guidance_normalize, k_nearest=k_nearest, rerank=rerank)

        # exact recovery and ranking quality
        res["top1"].append(M.top_k_accuracy(ranked, p, 1))
        res["top5"].append(M.top_k_accuracy(ranked, p, 5))
        res["ndcg"].append(M.ndcg_at_k(ranked, M.exact_relevance(p), 10))
        res["rank"].append(M.rank_of_true(ranked, p))
        # functional recovery: a top-5 hit is any candidate with relevance >= 0.5
        if gene_cluster is not None:
            frel = M.functional_relevance(p, data.parse, gene_cluster, candidates)
            res["func_top5"].append(float(any(frel.get(lab, 0) >= 0.5 for lab in ranked[:5])))
            res["func_ndcg"].append(M.ndcg_at_k(ranked, frel, 10))

        # endpoint distance of the nominated top-1 + specificity vs. other perturbations
        top1_pop = predictor.population(ranked[0], c0)
        top1_mean = top1_pop.mean(0)
        endpoint_dist = np.linalg.norm(top1_mean - c_star)
        off = _off_target_dist(data, top1_mean, p, candidates, rng)
        res["endpoint_dist"].append(float(endpoint_dist))
        res["off_dist"].append(float(off))
        res["spec_margin"].append(float(off - endpoint_dist))
        if classifier is not None:
            with torch.no_grad():
                probs = torch.sigmoid(
                    classifier(torch.as_tensor(top1_pop, dtype=torch.float32, device=device)))
                res["clf"].append(float(probs.mean().item()))

    agg = {k: (float(np.mean(v)) if v else None) for k, v in res.items()}
    agg["_per"] = res
    agg["n_targets"] = len(targets)
    n_cand = len(candidates)
    # chance of the true label landing in 5 uniform draws from N candidates; <= 1 even for N < 5
    agg["random_top5"] = min(5, n_cand) / n_cand if n_cand else None
    return agg