"""desired-state nomination eval (tables 4-7, 13, 14).

a target state is induced by a held-out perturbation p: c* = centroid of p's cells.
given c*, a method nominates perturbations; we score exact/functional recovery,
ranking quality, endpoint distance, and specificity.

ranking uses a numpy reward scorer (works for any predictor). pivot reward guidance
(Alg 4) also uses the differentiable torch reward and the model.
"""
| from __future__ import annotations |
|
|
| import numpy as np |
| import torch |
|
|
| from src.evaluation import inference as inf |
| from src.evaluation import metrics as M |
| from src.evaluation.rewards import Reward |
|
|
|
|
| |
def numpy_score(pop: np.ndarray, kind: str, c_star=None, target_sample=None, classifier=None,
                device="cpu", control_ref=None):
    """score a predicted population against a target state (higher is better).

    kinds:
      'target' / 'centroid': negative mean over rows of the squared euclidean
          distance to c_star.
      'cosine': mean cosine between (row - ref) and (c_star - ref), where ref is
          control_ref, or the population mean when control_ref is None.
      'nn_target': negative mean squared distance from each row of pop to its
          nearest row in target_sample.
      'mmd': -M.mmd_rbf(pop, target_sample).
      'wasserstein': -M.sliced_wasserstein(pop, target_sample).
      'classifier': mean log-sigmoid of classifier(pop) logits.
      'combined': centroid score + classifier score.

    args:
      pop: 2-d array, one row per cell (rows are reduced along axis 1).
      kind: one of the reward kinds above.
      c_star: target centroid; required for target/centroid/combined/cosine.
      target_sample: target cells; required for nn_target/mmd/wasserstein.
      classifier: callable mapping a float32 torch tensor to logits; required
          for classifier/combined.
      device: torch device the population is moved to for the classifier.
      control_ref: optional reference point for the cosine reward.

    returns:
      the score as a python float.

    raises:
      ValueError: if kind is unknown or an input required by kind is None.
    """
    # fail early with a readable message instead of an opaque TypeError from
    # numpy/torch/scipy when a required input was not supplied.
    if kind in ("target", "centroid", "combined", "cosine") and c_star is None:
        raise ValueError(f"reward kind {kind!r} requires c_star")
    if kind in ("nn_target", "mmd", "wasserstein") and target_sample is None:
        raise ValueError(f"reward kind {kind!r} requires target_sample")
    if kind in ("classifier", "combined") and classifier is None:
        raise ValueError(f"reward kind {kind!r} requires classifier")

    centroid_score = None
    if kind in ("target", "centroid", "combined"):
        centroid_score = -np.mean(np.sum((pop - c_star) ** 2, axis=1))
        if kind != "combined":
            return float(centroid_score)
        # 'combined' falls through to the classifier term below.

    if kind == "cosine":
        ref = control_ref if control_ref is not None else pop.mean(0)
        pop_dirs = pop - ref
        target_dir = c_star - ref
        # the 1e-8 keeps the normalisation finite for zero-length vectors.
        pop_dirs = pop_dirs / (np.linalg.norm(pop_dirs, axis=1, keepdims=True) + 1e-8)
        target_dir = target_dir / (np.linalg.norm(target_dir) + 1e-8)
        return float(np.mean(pop_dirs @ target_dir))

    if kind == "nn_target":
        from scipy.spatial.distance import cdist
        nearest = cdist(pop, target_sample).min(axis=1)
        return float(-np.mean(nearest ** 2))

    if kind == "mmd":
        return float(-M.mmd_rbf(pop, target_sample))

    if kind == "wasserstein":
        return float(-M.sliced_wasserstein(pop, target_sample))

    if kind in ("classifier", "combined"):
        with torch.no_grad():
            logit = classifier(torch.as_tensor(pop, dtype=torch.float32, device=device))
            clf_score = torch.nn.functional.logsigmoid(logit).mean().item()
        if kind == "classifier":
            return float(clf_score)
        return float(centroid_score + clf_score)

    raise ValueError(f"unknown reward kind: {kind!r}")
|
|
|
|
| _PER_ROW_REWARDS = ("centroid", "target", "nn_target", "classifier", "combined", "cosine") |
|
|
|
|
def rank_candidates(predictor, candidates, c0, score_kwargs):
    """rank candidates by the reward of their predicted endpoint.

    a PivotPredictor with a per-row reward kind is ranked in one call to
    inf.endpoint_ranking using a torch Reward; every other combination
    (baselines, mmd/wasserstein) is scored per candidate with numpy_score.

    args:
      predictor: object exposing population(candidate, c0); a PivotPredictor
          additionally exposes .model and .data.
      candidates: sequence of candidate perturbation labels.
      c0: control cells used as the starting population.
      score_kwargs: dict with at least 'kind'; may also hold 'c_star',
          'target_sample', 'classifier', 'device', 'control_ref'. it is passed
          to numpy_score as keyword arguments on the numpy path.

    returns:
      (labels, scores): labels ordered best-first and the matching score array.
      on the pivot path the order is whatever inf.endpoint_ranking returns
      (assumed already best-first — TODO confirm).
    """
    # kept function-scoped as in the original (presumably to avoid an import
    # cycle — TODO confirm). `inf` is already imported at module level.
    from src.experiments.predictors import PivotPredictor

    kind = score_kwargs["kind"]
    if isinstance(predictor, PivotPredictor) and kind in _PER_ROW_REWARDS:
        device = score_kwargs.get("device", "cpu")
        reward = Reward(kind, target_c=score_kwargs.get("c_star"),
                        target_sample=score_kwargs.get("target_sample"),
                        classifier=score_kwargs.get("classifier"), device=device,
                        control_ref=score_kwargs.get("control_ref"))
        c0_t = torch.as_tensor(c0, dtype=torch.float32, device=device)
        ranked = inf.endpoint_ranking(predictor.model, predictor.data, candidates,
                                      c0_t, reward, device=device)
        labels = [label for label, _ in ranked]
        scores = np.array([score for _, score in ranked])
        return labels, scores

    scores = np.array([numpy_score(predictor.population(c, c0), **score_kwargs)
                       for c in candidates])
    # stable sort: candidates with equal scores keep their input order, so the
    # ranking is deterministic across runs and numpy versions.
    order = np.argsort(-scores, kind="stable")
    return [candidates[i] for i in order], scores[order]
|
|
|
|
def build_target(data, p, control_mean_emb):
    """build the target state induced by perturbation p.

    args:
      data: object exposing .emb (indexable array of embeddings) and
          .pert_to_idx (mapping from perturbation label to an index into .emb).
      p: perturbation label; must be a key of data.pert_to_idx.
      control_mean_emb: unused; kept so existing callers keep working.

    returns:
      (c_star, target_sample): the axis-0 mean of p's rows of data.emb, and
      those rows themselves.
    """
    # index once: the original evaluated data.emb[...] twice for the same rows.
    cells = data.emb[data.pert_to_idx[p]]
    return cells.mean(0), cells
|
|
|
|
def evaluate_nomination(predictor, data, targets, candidates, control_pool,
                        reward_kind="centroid", method="ranking", n_ctrl=128,
                        gene_cluster=None, model=None, device="cpu",
                        guidance_steps=25, guidance_step=0.5, k_nearest=10,
                        classifier=None, seed=0,
                        guidance_init="warm", guidance_normalize=True, rerank=True):
    """evaluate single-perturbation nomination. returns aggregated metric dict.

    for every target p: build the target (c*, sample) from p's cells, draw up to
    n_ctrl control cells without replacement, rank the candidates by reward, and
    (method='guidance') move the guided candidates to the front of the ranking.
    the final ranking is then scored.

    method: 'ranking' (Alg 3) | 'guidance' (Alg 4+5, pivot only; needs model).
    guidance_init: 'warm' (top-ranked) | 'random' | 'mean_top' (mean of top-k cand
        embeddings). any value other than 'random'/'mean_top' behaves as 'warm'.
    guidance_normalize: normalize the jacobian-pullback step (Alg 4).
    rerank: if true, rerank k-nn by endpoint reward (Alg 5); else take the
        k_nearest candidates closest to the guided embedding.
    gene_cluster: if given, functional metrics (func_top5, func_ndcg) are scored.
    classifier: if given, 'clf' records the mean sigmoid of its output on the
        top-1 predicted population; it is also passed to the reward.
    seed: seeds the numpy generator used for control draws, the random guidance
        init and the off-target sample.

    returns:
      dict with the mean of each metric over targets (None when a metric was
      never recorded): top1, top5, ndcg, func_top5, func_ndcg, rank,
      endpoint_dist, spec_margin, off_dist, clf; plus
        '_per': the per-target lists,
        'n_targets': len(targets),
        'random_top5': chance-level top-5 rate min(1, 5 / len(candidates)),
            or None when there are no candidates.

    raises:
      ValueError: if method == 'guidance' and model is None.
    """
    # explicit check instead of `assert`, which is stripped under `python -O`.
    if method == "guidance" and model is None:
        raise ValueError("method='guidance' requires a model")

    rng = np.random.default_rng(seed)
    ctrl_all = data.emb[control_pool]
    ctrl_mean = ctrl_all.mean(0)
    n_draw = min(n_ctrl, len(ctrl_all))  # loop-invariant control sample size
    n_candidates = len(candidates)
    res = {k: [] for k in ["top1", "top5", "ndcg", "func_top5", "func_ndcg",
                           "rank", "endpoint_dist", "spec_margin", "off_dist", "clf"]}

    for p in targets:
        c_star, tgt_sample = build_target(data, p, ctrl_mean)
        c0 = ctrl_all[rng.choice(len(ctrl_all), n_draw, replace=False)]
        sk = dict(kind=reward_kind, c_star=c_star, target_sample=tgt_sample,
                  classifier=classifier, device=device, control_ref=ctrl_mean)

        ranked, _ = rank_candidates(predictor, candidates, c0, sk)

        if method == "guidance":
            c0t = torch.as_tensor(c0, dtype=torch.float32, device=device)
            # guidance uses the centroid reward in place of 'combined'.
            rew = Reward(reward_kind if reward_kind != "combined" else "centroid",
                         target_c=c_star, target_sample=tgt_sample,
                         classifier=classifier, device=device, control_ref=ctrl_mean)

            if guidance_init == "random":
                start = candidates[rng.integers(len(candidates))]
                e_init = inf.encode_label(model, data, start, device)
            elif guidance_init == "mean_top":
                from src.models.pivot import candidate_embeddings
                top_emb = candidate_embeddings(model, data, ranked[:k_nearest], device)
                e_init = torch.as_tensor(top_emb.mean(0, keepdims=True), device=device)
            else:
                e_init = inf.encode_label(model, data, ranked[0], device)

            e_star = inf.reward_guidance(model, c0t, rew, e_init, steps=guidance_steps,
                                         step_size=guidance_step,
                                         normalize=guidance_normalize)

            if rerank:
                top = inf.project_and_rerank(model, data, candidates, e_star, c0t, rew,
                                             k_nearest=k_nearest, topk=k_nearest,
                                             device=device)
                guided = [label for label, _ in top]
            else:
                from src.models.pivot import candidate_embeddings
                cand_emb = torch.as_tensor(
                    candidate_embeddings(model, data, candidates, device), device=device)
                dists = torch.cdist(e_star.view(1, -1), cand_emb).squeeze(0)
                nearest = torch.argsort(dists).cpu().numpy()
                guided = [candidates[i] for i in nearest[:k_nearest]]

            # guided nominations first, then the rest in their reward order.
            ranked = guided + [r for r in ranked if r not in guided]

        # exact-recovery and ranking-quality metrics
        res["top1"].append(M.top_k_accuracy(ranked, p, 1))
        res["top5"].append(M.top_k_accuracy(ranked, p, 5))
        res["ndcg"].append(M.ndcg_at_k(ranked, M.exact_relevance(p), 10))
        res["rank"].append(M.rank_of_true(ranked, p))
        if gene_cluster is not None:
            frel = M.functional_relevance(p, data.parse, gene_cluster, candidates)
            res["func_top5"].append(
                float(any(frel.get(label, 0) >= 0.5 for label in ranked[:5])))
            res["func_ndcg"].append(M.ndcg_at_k(ranked, frel, 10))

        # endpoint distance of the top-1 nomination (centroid computed once)
        top1_pop = predictor.population(ranked[0], c0)
        top1_centroid = top1_pop.mean(0)
        endpoint_norm = np.linalg.norm(top1_centroid - c_star)
        res["endpoint_dist"].append(float(endpoint_norm))

        # specificity: mean distance to up to 20 other candidates' centroids.
        # indices are drawn rather than labels so labels are never coerced into
        # a numpy array; the draw consumes the generator exactly as sampling
        # the label list did.
        pool = [q for q in candidates if q != p]
        n_off = max(0, min(20, n_candidates - 1, len(pool)))
        picks = rng.choice(len(pool), size=n_off, replace=False) if n_off else []
        off_dists = [
            np.linalg.norm(top1_centroid - data.emb[data.pert_to_idx[pool[i]]].mean(0))
            for i in picks
        ]
        # no off-target candidates -> nan, as np.mean([]) gave, minus the warning.
        off = np.mean(off_dists) if off_dists else np.nan
        res["off_dist"].append(float(off))
        res["spec_margin"].append(float(off - endpoint_norm))

        if classifier is not None:
            with torch.no_grad():
                res["clf"].append(float(torch.sigmoid(
                    classifier(torch.as_tensor(top1_pop, dtype=torch.float32, device=device))
                ).mean().item()))

    agg = {k: (float(np.mean(v)) if v else None) for k, v in res.items()}
    agg["_per"] = res
    agg["n_targets"] = len(targets)
    # chance level of a top-5 hit; it is a probability, so never above 1.
    agg["random_top5"] = min(1.0, 5.0 / n_candidates) if n_candidates else None
    return agg
|
|