"""combinatorial nomination eval (tables 7, 19). target is a held-out multi-gene perturbation. methods: pivot greedy (alg 6), pivot endpoint-ranking over the observed-combo vocab, and optionally guidance. metrics: exact top-1/top-5 (gene-set match), partial overlap, endpoint distance, ndcg@10, and synergy correlation vs the additive baseline. """ from __future__ import annotations import numpy as np from src.evaluation import inference as inf from src.evaluation import metrics as M from src.evaluation.rewards import Reward from src.experiments.nomination_eval import numpy_score, rank_candidates def evaluate_combinatorial(pivot_predictor, data, targets, combo_candidates, gene_pool, control_pool, model, device, reward_kind="centroid", n_ctrl=128, max_size=2, additive=None, seed=0): import torch rng = np.random.default_rng(seed) ctrl_all = data.emb[control_pool] res = {k: [] for k in ["greedy_exact1", "greedy_exact5", "greedy_overlap", "rank_exact1", "rank_exact5", "rank_ndcg", "rank_overlap", "endpoint_dist", "synergy_corr"]} for p in targets: true_genes = set(data.parse(p)) c_star = data.emb[data.pert_to_idx[p]].mean(0) c0 = ctrl_all[rng.choice(len(ctrl_all), min(n_ctrl, len(ctrl_all)), replace=False)] sk = dict(kind=reward_kind, c_star=c_star, target_sample=data.emb[data.pert_to_idx[p]], device=device) # greedy (alg 6) c0t = torch.as_tensor(c0, dtype=torch.float32, device=device) rew = Reward(reward_kind, target_c=c_star, target_sample=data.emb[data.pert_to_idx[p]], device=device) chosen, _, _ = inf.greedy_combinatorial(model, data, gene_pool, c0t, rew, max_size=max_size, device=device) res["greedy_exact1"].append(float(set(chosen) == true_genes)) res["greedy_overlap"].append(M.partial_overlap(chosen, true_genes)) # ranking over observed combos ranked, _ = rank_candidates(pivot_predictor, combo_candidates, c0, sk) res["rank_exact1"].append(M.top_k_accuracy(ranked, p, 1)) res["rank_exact5"].append(M.top_k_accuracy(ranked, p, 5)) res["rank_ndcg"].append(M.ndcg_at_k(ranked, M.exact_relevance(p), 10)) # partial overlap of best-ranked combo's genes res["rank_overlap"].append(M.partial_overlap(data.parse(ranked[0]), true_genes)) # greedy only yields one set, so reuse exact1 as the top-5 proxy res["greedy_exact5"].append(res["greedy_exact1"][-1]) top1_pop = pivot_predictor.population(ranked[0], c0) res["endpoint_dist"].append(float(np.linalg.norm(top1_pop.mean(0) - c_star))) # synergy: predicted vs true non-additive component (gene space), if additive given if additive is not None: comp = data.pca_components pred_eff = (pivot_predictor.population(p, c0).mean(0) - ctrl_all.mean(0)) @ comp true_eff = (c_star - ctrl_all.mean(0)) @ comp add_eff = additive.predict_effect(p) @ comp res["synergy_corr"].append(M.synergy_correlation(pred_eff, add_eff, true_eff)) agg = {k: (float(np.mean(v)) if v else None) for k, v in res.items()} agg["_per"] = res agg["n_targets"] = len(targets) return agg