| """combinatorial nomination eval (tables 7, 19). |
| |
| target is a held-out multi-gene perturbation. methods: pivot greedy (alg 6), |
| pivot endpoint-ranking over the observed-combo vocab, and optionally guidance. |
| metrics: exact top-1/top-5 (gene-set match), partial overlap, endpoint distance, |
| ndcg@10, and synergy correlation vs the additive baseline. |
| """ |
| from __future__ import annotations |
|
|
| import numpy as np |
|
|
| from src.evaluation import inference as inf |
| from src.evaluation import metrics as M |
| from src.evaluation.rewards import Reward |
| from src.experiments.nomination_eval import numpy_score, rank_candidates |
|
|
|
|
def evaluate_combinatorial(pivot_predictor, data, targets, combo_candidates, gene_pool,
                           control_pool, model, device, reward_kind="centroid",
                           n_ctrl=128, max_size=2, additive=None, seed=0):
    """score greedy and endpoint-ranking nomination on held-out combo targets.

    for every target perturbation a fresh control subsample is drawn, then:
      * `inf.greedy_combinatorial` picks genes from `gene_pool`; the result is
        scored by exact set equality and partial overlap with the genes
        obtained from `data.parse(target)`.
      * `rank_candidates` orders `combo_candidates`; the ordering is scored by
        top-1/top-5 accuracy, ndcg@10, partial overlap of the top candidate,
        and the l2 distance between the top candidate's predicted population
        mean and the target centroid.
      * when `additive` is given, a synergy correlation is computed from the
        predicted, additive and true effects, each multiplied by
        `data.pca_components`.

    args:
        pivot_predictor: object exposing `population(pert, controls)`; also
            passed to `rank_candidates`.
        data: dataset object exposing `emb`, `pert_to_idx`, `parse` and, only
            when `additive` is given, `pca_components`.
        targets: iterable of held-out perturbation names. a one-shot iterator
            is accepted, since the count is taken from the processed targets.
        combo_candidates: candidates handed to `rank_candidates`; the top one
            is parsed with `data.parse`, so presumably perturbation names.
        gene_pool: forwarded to `inf.greedy_combinatorial`.
        control_pool: index into `data.emb` selecting the control rows.
        model: forwarded to `inf.greedy_combinatorial`.
        device: used for the control tensor, the reward, and the greedy call.
        reward_kind: reward type given to `Reward` and to the ranking config.
        n_ctrl: upper bound on control rows sampled (without replacement) per
            target.
        max_size: forwarded to `inf.greedy_combinatorial`.
        additive: optional baseline exposing `predict_effect(pert)`; `None`
            skips the synergy metric.
        seed: seeds the numpy generator used for control subsampling.

    returns:
        dict mapping every metric name to its mean over targets (`None` when
        nothing was recorded for it), plus `_per` holding the per-target lists
        and `n_targets` holding the number of targets processed.
    """
    import torch  # function-local import

    rng = np.random.default_rng(seed)
    ctrl_all = data.emb[control_pool]
    # loop-invariant: how many control rows each target's subsample draws
    n_draw = min(n_ctrl, len(ctrl_all))

    # synergy-only invariants; computed once instead of once (or twice) per target
    with_synergy = additive is not None
    if with_synergy:
        comp = data.pca_components
        ctrl_mean = ctrl_all.mean(0)

    res = {k: [] for k in ("greedy_exact1", "greedy_exact5", "greedy_overlap",
                           "rank_exact1", "rank_exact5", "rank_ndcg", "rank_overlap",
                           "endpoint_dist", "synergy_corr")}

    for p in targets:
        true_genes = set(data.parse(p))
        # index the target's embeddings once; reused for the centroid and both
        # reward/score configurations below
        target_emb = data.emb[data.pert_to_idx[p]]
        c_star = target_emb.mean(0)
        c0 = ctrl_all[rng.choice(len(ctrl_all), n_draw, replace=False)]
        sk = dict(kind=reward_kind, c_star=c_star, target_sample=target_emb,
                  device=device)

        # --- greedy construction from the gene pool ---
        c0t = torch.as_tensor(c0, dtype=torch.float32, device=device)
        rew = Reward(reward_kind, target_c=c_star, target_sample=target_emb,
                     device=device)
        chosen, _, _ = inf.greedy_combinatorial(model, data, gene_pool, c0t, rew,
                                                max_size=max_size, device=device)
        greedy_hit = float(set(chosen) == true_genes)
        res["greedy_exact1"].append(greedy_hit)
        # only a single greedy gene set is scored, so the top-5 entry mirrors top-1
        res["greedy_exact5"].append(greedy_hit)
        res["greedy_overlap"].append(M.partial_overlap(chosen, true_genes))

        # --- endpoint ranking over the candidate combos ---
        ranked, _ = rank_candidates(pivot_predictor, combo_candidates, c0, sk)
        # NOTE(review): these are scored against the target name `p`; whether a
        # candidate listing the same genes in another order counts as a hit
        # depends on M.top_k_accuracy / M.exact_relevance -- confirm.
        res["rank_exact1"].append(M.top_k_accuracy(ranked, p, 1))
        res["rank_exact5"].append(M.top_k_accuracy(ranked, p, 5))
        res["rank_ndcg"].append(M.ndcg_at_k(ranked, M.exact_relevance(p), 10))
        res["rank_overlap"].append(M.partial_overlap(data.parse(ranked[0]), true_genes))

        # distance between the top candidate's predicted mean and the target centroid
        top1_pop = pivot_predictor.population(ranked[0], c0)
        res["endpoint_dist"].append(float(np.linalg.norm(top1_pop.mean(0) - c_star)))

        # --- synergy vs the additive baseline (optional) ---
        if with_synergy:
            # predicted and true effects share the full-control-pool mean as baseline
            pred_eff = (pivot_predictor.population(p, c0).mean(0) - ctrl_mean) @ comp
            true_eff = (c_star - ctrl_mean) @ comp
            add_eff = additive.predict_effect(p) @ comp
            res["synergy_corr"].append(M.synergy_correlation(pred_eff, add_eff, true_eff))

    # metrics with no recorded values (no targets, or synergy skipped) report None
    agg = {k: (float(np.mean(v)) if v else None) for k, v in res.items()}
    agg["_per"] = res
    # one greedy entry is appended per processed target, so this equals
    # len(targets) for sized inputs and also works for one-shot iterables
    agg["n_targets"] = len(res["greedy_exact1"])
    return agg
|
|