"""Desired-state nomination eval (tables 4-7, 13, 14).

A target state is induced by a held-out perturbation p: c* = centroid of p's
cells. Given c*, a method nominates perturbations; we score exact/functional
recovery, ranking quality, endpoint distance, and specificity.

Ranking uses a numpy reward scorer (works for any predictor). Pivot reward
guidance (Alg 4) also uses the differentiable torch reward and the model.
"""
from __future__ import annotations

import numpy as np
import torch

from src.evaluation import inference as inf
from src.evaluation import metrics as M
from src.evaluation.rewards import Reward


# numpy reward scorers (for ranking; mirror rewards.py)
def numpy_score(pop: np.ndarray, kind: str, c_star=None, target_sample=None,
                classifier=None, device="cpu", control_ref=None):
    """Return a scalar reward for a predicted population (larger is better).

    Args:
        pop: 2-D array of predicted endpoints, one row per sample.
        kind: reward name. One of "target", "centroid", "combined", "cosine",
            "nn_target", "mmd", "wasserstein", "classifier".
        c_star: target centroid; read by "target"/"centroid"/"combined"/"cosine".
        target_sample: target cells; read by "nn_target"/"mmd"/"wasserstein".
        classifier: callable mapping a float32 tensor to logits; read by
            "classifier"/"combined".
        device: torch device the classifier input is placed on.
        control_ref: reference point subtracted before the cosine reward;
            falls back to the mean of ``pop`` when None.

    Returns:
        The reward as a Python float.

    Raises:
        ValueError: if ``kind`` is not one of the names above.
    """
    if kind in ("target", "centroid", "combined"):
        # negative mean squared distance of each row to the target centroid
        centroid_term = -np.mean(np.sum((pop - c_star) ** 2, axis=1))
        if kind != "combined":
            return float(centroid_term)
        # "combined" falls through to the classifier branch below
    if kind == "cosine":
        cref = control_ref if control_ref is not None else pop.mean(0)
        pe = pop - cref
        te = c_star - cref
        # the epsilon keeps zero-length vectors from dividing by zero
        pe = pe / (np.linalg.norm(pe, axis=1, keepdims=True) + 1e-8)
        te = te / (np.linalg.norm(te) + 1e-8)
        return float(np.mean(pe @ te))
    if kind == "nn_target":
        from scipy.spatial.distance import cdist
        # squared distance from each row to its nearest target sample
        return float(-np.mean(cdist(pop, target_sample).min(axis=1) ** 2))
    if kind == "mmd":
        return float(-M.mmd_rbf(pop, target_sample))
    if kind == "wasserstein":
        return float(-M.sliced_wasserstein(pop, target_sample))
    if kind in ("classifier", "combined"):
        with torch.no_grad():
            logit = classifier(torch.as_tensor(pop, dtype=torch.float32, device=device))
            rc = torch.nn.functional.logsigmoid(logit).mean().item()
        if kind == "classifier":
            return float(rc)
        return float(centroid_term + rc)  # combined
    raise ValueError(kind)


# reward kinds that rank_candidates sends down the batched PivotPredictor path
_PER_ROW_REWARDS = ("centroid", "target", "nn_target", "classifier", "combined", "cosine")


def rank_candidates(predictor, candidates, c0, score_kwargs):
    """Rank candidates by the reward of their predicted endpoint.

    Uses the batched path (``inf.endpoint_ranking``) for a PivotPredictor with a
    per-row reward; otherwise scores each candidate with ``numpy_score``
    (baselines, mmd/wasserstein).

    Args:
        predictor: a PivotPredictor (exposing ``model`` and ``data``) or any
            object exposing ``population(candidate, c0)``.
        candidates: sequence of candidate labels.
        c0: control cells used as the starting population.
        score_kwargs: keyword arguments for ``numpy_score``; must contain "kind".

    Returns:
        ``(labels, scores)``. On the numpy path both are sorted by descending
        score. On the batched path they follow the order returned by
        ``inf.endpoint_ranking`` (presumably best-first; confirm there).
    """
    # imported here rather than at module level (presumably to avoid an import cycle)
    from src.experiments.predictors import PivotPredictor

    kind = score_kwargs["kind"]
    if isinstance(predictor, PivotPredictor) and kind in _PER_ROW_REWARDS:
        dev = score_kwargs.get("device", "cpu")
        rew = Reward(kind, target_c=score_kwargs.get("c_star"),
                     target_sample=score_kwargs.get("target_sample"),
                     classifier=score_kwargs.get("classifier"), device=dev,
                     control_ref=score_kwargs.get("control_ref"))
        c0t = torch.as_tensor(c0, dtype=torch.float32, device=dev)
        ranked = inf.endpoint_ranking(predictor.model, predictor.data, candidates,
                                      c0t, rew, device=dev)
        return [label for label, _ in ranked], np.array([s for _, s in ranked])
    scores = np.array([numpy_score(predictor.population(c, c0), **score_kwargs)
                       for c in candidates])
    order = np.argsort(-scores)
    return [candidates[i] for i in order], scores[order]


def build_target(data, p, control_mean_emb):
    """Return ``(centroid, cells)`` of the embeddings indexed by perturbation ``p``.

    ``control_mean_emb`` is accepted for interface compatibility and not used.
    """
    cells = data.emb[data.pert_to_idx[p]]
    return cells.mean(0), cells


def evaluate_nomination(predictor, data, targets, candidates, control_pool,
                        reward_kind="centroid", method="ranking", n_ctrl=128,
                        gene_cluster=None, model=None, device="cpu",
                        guidance_steps=25, guidance_step=0.5, k_nearest=10,
                        classifier=None, seed=0, guidance_init="warm",
                        guidance_normalize=True, rerank=True):
    """Evaluate single-perturbation nomination and return aggregated metrics.

    Args:
        predictor: object exposing ``population(label, c0)``; see rank_candidates.
        data: dataset exposing ``emb``, ``pert_to_idx`` and ``parse``.
        targets: sized collection of held-out perturbation labels.
        candidates: sequence of candidate labels to rank.
        control_pool: index into ``data.emb`` selecting the control cells.
        reward_kind: reward name used for ranking. For guidance, "combined" is
            replaced by "centroid".
        method: 'ranking' (Alg 3) | 'guidance' (Alg 4+5, pivot only). Any value
            other than 'guidance' evaluates the plain ranking.
        n_ctrl: maximum number of control cells sampled (without replacement)
            per target.
        gene_cluster: when not None, functional metrics are also computed.
        model: pivot model; required when ``method == 'guidance'``.
        device: torch device for tensors built here.
        guidance_steps, guidance_step: passed to ``inf.reward_guidance`` as
            ``steps`` and ``step_size``.
        k_nearest: number of guided nominations placed at the head of the ranking.
        classifier: optional classifier used by the reward and the "clf" metric.
        seed: seed of the numpy generator used for all sampling here.
        guidance_init: 'warm' (top-ranked) | 'random' | 'mean_top' (mean of
            top-k cand embeddings). Unrecognised values behave as 'warm'.
        guidance_normalize: normalize the jacobian-pullback step (Alg 4).
        rerank: if true, rerank k-nn by endpoint reward (Alg 5); else project to
            the nearest candidate embeddings.

    Returns:
        Dict mapping each metric name to its mean over targets (None when no
        value was recorded), plus "_per" (the per-target lists), "n_targets",
        and "random_top5" (chance level of a top-5 hit, None without candidates).

    Raises:
        ValueError: if ``method == 'guidance'`` and ``model`` is None.
    """
    rng = np.random.default_rng(seed)
    ctrl_all = data.emb[control_pool]
    ctrl_mean = ctrl_all.mean(0)
    res = {k: [] for k in ["top1", "top5", "ndcg", "func_top5", "func_ndcg", "rank",
                           "endpoint_dist", "spec_margin", "off_dist", "clf"]}
    for p in targets:
        c_star, tgt_sample = build_target(data, p, ctrl_mean)
        c0 = ctrl_all[rng.choice(len(ctrl_all), min(n_ctrl, len(ctrl_all)), replace=False)]
        sk = dict(kind=reward_kind, c_star=c_star, target_sample=tgt_sample,
                  classifier=classifier, device=device, control_ref=ctrl_mean)
        ranked, _ = rank_candidates(predictor, candidates, c0, sk)

        if method == "guidance":
            # explicit raise instead of assert, so the check survives `python -O`
            if model is None:
                raise ValueError("method='guidance' requires a model")
            c0t = torch.as_tensor(c0, dtype=torch.float32, device=device)
            rew = Reward(reward_kind if reward_kind != "combined" else "centroid",
                         target_c=c_star, target_sample=tgt_sample,
                         classifier=classifier, device=device, control_ref=ctrl_mean)
            if guidance_init == "random":
                e_init = inf.encode_label(model, data,
                                          candidates[rng.integers(len(candidates))], device)
            elif guidance_init == "mean_top":
                from src.models.pivot import candidate_embeddings
                E = candidate_embeddings(model, data, ranked[:k_nearest], device)
                e_init = torch.as_tensor(E.mean(0, keepdims=True), device=device)
            else:  # warm
                e_init = inf.encode_label(model, data, ranked[0], device)
            e_star = inf.reward_guidance(model, c0t, rew, e_init, steps=guidance_steps,
                                         step_size=guidance_step,
                                         normalize=guidance_normalize)
            if rerank:
                top = inf.project_and_rerank(model, data, candidates, e_star, c0t, rew,
                                             k_nearest=k_nearest, topk=k_nearest,
                                             device=device)
                guided = [label for label, _ in top]
            else:
                from src.models.pivot import candidate_embeddings
                E = torch.as_tensor(candidate_embeddings(model, data, candidates, device),
                                    device=device)
                # candidate indices ordered by distance to the guided embedding
                nearest = torch.argsort(
                    torch.cdist(e_star.view(1, -1), E).squeeze(0)).cpu().numpy()
                guided = [candidates[i] for i in nearest[:k_nearest]]
            # guided nominations first, then the remaining ranking: full order
            # for rank metrics
            ranked = guided + [r for r in ranked if r not in guided]

        # metrics
        res["top1"].append(M.top_k_accuracy(ranked, p, 1))
        res["top5"].append(M.top_k_accuracy(ranked, p, 5))
        res["ndcg"].append(M.ndcg_at_k(ranked, M.exact_relevance(p), 10))
        res["rank"].append(M.rank_of_true(ranked, p))
        if gene_cluster is not None:
            frel = M.functional_relevance(p, data.parse, gene_cluster, candidates)
            res["func_top5"].append(
                float(any(frel.get(label, 0) >= 0.5 for label in ranked[:5])))
            res["func_ndcg"].append(M.ndcg_at_k(ranked, frel, 10))

        # endpoint distance of nominated top-1 + specificity
        top1_pop = predictor.population(ranked[0], c0)
        top1_mean = top1_pop.mean(0)
        endpoint_norm = np.linalg.norm(top1_mean - c_star)
        res["endpoint_dist"].append(float(endpoint_norm))

        # off-state distance: mean dist of top-1 endpoint to other perturbation
        # centroids. Indices are sampled instead of labels so the labels keep
        # their original Python type; the size is also capped by the pool size
        # so duplicate copies of p in candidates cannot over-request.
        pool = [q for q in candidates if q != p]
        n_off = min(20, len(candidates) - 1, len(pool))
        picks = rng.choice(len(pool), size=n_off, replace=False)
        if len(picks):
            off = np.mean([
                np.linalg.norm(top1_mean - data.emb[data.pert_to_idx[pool[i]]].mean(0))
                for i in picks
            ])
        else:
            # no other perturbation to compare against: record NaN explicitly
            off = np.nan
        res["off_dist"].append(float(off))
        res["spec_margin"].append(float(off - endpoint_norm))

        if classifier is not None:
            with torch.no_grad():
                res["clf"].append(float(torch.sigmoid(
                    classifier(torch.as_tensor(top1_pop, dtype=torch.float32,
                                               device=device))
                ).mean().item()))

    agg = {k: (float(np.mean(v)) if v else None) for k, v in res.items()}
    agg["_per"] = res
    agg["n_targets"] = len(targets)
    # chance level of a top-5 hit; cannot exceed 1 when there are < 5 candidates
    n_cand = len(candidates)
    agg["random_top5"] = min(1.0, 5.0 / n_cand) if n_cand else None
    return agg