File size: 3,415 Bytes
"""Combinatorial nomination eval (tables 7, 19).

The target is a held-out multi-gene perturbation. Methods: pivot greedy (alg 6),
pivot endpoint-ranking over the observed-combo vocab, and optionally guidance.
Metrics: exact top-1/top-5 (gene-set match), partial overlap, endpoint distance,
ndcg@10, and synergy correlation vs the additive baseline.
"""
from __future__ import annotations
import numpy as np
from src.evaluation import inference as inf
from src.evaluation import metrics as M
from src.evaluation.rewards import Reward
from src.experiments.nomination_eval import numpy_score, rank_candidates
def evaluate_combinatorial(pivot_predictor, data, targets, combo_candidates, gene_pool,
                           control_pool, model, device, reward_kind="centroid",
                           n_ctrl=128, max_size=2, additive=None, seed=0):
    """Evaluate greedy and ranking-based nomination on held-out combo targets.

    For every target perturbation ``p`` in ``targets`` this:

    1. draws a control subsample ``c0`` (without replacement) from
       ``data.emb[control_pool]``;
    2. runs ``inf.greedy_combinatorial`` from ``c0`` and compares the chosen
       genes with ``data.parse(p)``;
    3. ranks ``combo_candidates`` with ``rank_candidates`` and scores the
       ranking against ``p`` (top-1, top-5, ndcg@10, overlap of the top hit);
    4. measures the L2 distance between the mean of
       ``pivot_predictor.population(ranked[0], c0)`` and the target mean;
    5. if ``additive`` is given, appends ``M.synergy_correlation`` computed
       from predicted, additive and true effects projected through
       ``data.pca_components``.

    Args:
        pivot_predictor: Passed to ``rank_candidates``; must also expose
            ``population(name, c0)`` whose result supports ``.mean(0)``.
        data: Dataset object providing ``emb``, ``pert_to_idx``, ``parse`` and
            (only when ``additive`` is given) ``pca_components``.
            NOTE(review): ``emb`` is presumably a (cells, dim) array indexed
            by row -- confirm against the dataset class.
        targets: Sized iterable of perturbation keys valid for
            ``data.pert_to_idx`` and ``data.parse``.
        combo_candidates: Candidate vocabulary forwarded to ``rank_candidates``.
        gene_pool: Gene pool forwarded to ``inf.greedy_combinatorial``.
        control_pool: Index into ``data.emb`` selecting the control rows.
        model: Model forwarded to ``inf.greedy_combinatorial``.
        device: Device used for the control tensor, the reward and greedy search.
        reward_kind: Reward kind used both for ``Reward`` and for the scoring
            kwargs handed to ``rank_candidates``.
        n_ctrl: Upper bound on the number of control rows drawn per target.
        max_size: Maximum combination size forwarded to greedy search.
        additive: Optional object with ``predict_effect(p)``; when ``None`` the
            synergy metric is skipped and aggregates to ``None``.
        seed: Seed for the control-subsampling random generator.

    Returns:
        dict: one entry per metric name holding the mean over targets (or
        ``None`` when no value was recorded), plus ``"_per"`` with the raw
        per-target lists and ``"n_targets"`` with ``len(targets)``.
    """
    import torch  # kept function-local, as in the original module

    rng = np.random.default_rng(seed)
    ctrl_all = data.emb[control_pool]
    n_ctrl_rows = len(ctrl_all)
    n_draw = min(n_ctrl, n_ctrl_rows)  # never request more rows than exist

    res = {k: [] for k in ["greedy_exact1", "greedy_exact5", "greedy_overlap",
                           "rank_exact1", "rank_exact5", "rank_ndcg", "rank_overlap",
                           "endpoint_dist", "synergy_corr"]}

    # Loop-invariant inputs of the synergy metric; resolved lazily on first use
    # so they are only touched when a synergy value is actually computed.
    comp = None
    ctrl_mean = None

    for p in targets:
        true_genes = set(data.parse(p))
        # Gather the target's embedding rows once and reuse them below.
        target_emb = data.emb[data.pert_to_idx[p]]
        c_star = target_emb.mean(0)
        c0 = ctrl_all[rng.choice(n_ctrl_rows, n_draw, replace=False)]
        sk = dict(kind=reward_kind, c_star=c_star, target_sample=target_emb,
                  device=device)

        # greedy (alg 6)
        c0t = torch.as_tensor(c0, dtype=torch.float32, device=device)
        rew = Reward(reward_kind, target_c=c_star, target_sample=target_emb,
                     device=device)
        chosen, _, _ = inf.greedy_combinatorial(model, data, gene_pool, c0t, rew,
                                                max_size=max_size, device=device)
        greedy_hit = float(set(chosen) == true_genes)
        res["greedy_exact1"].append(greedy_hit)
        # greedy only yields one set, so reuse exact1 as the top-5 proxy
        res["greedy_exact5"].append(greedy_hit)
        res["greedy_overlap"].append(M.partial_overlap(chosen, true_genes))

        # ranking over observed combos
        ranked, _ = rank_candidates(pivot_predictor, combo_candidates, c0, sk)
        res["rank_exact1"].append(M.top_k_accuracy(ranked, p, 1))
        res["rank_exact5"].append(M.top_k_accuracy(ranked, p, 5))
        res["rank_ndcg"].append(M.ndcg_at_k(ranked, M.exact_relevance(p), 10))
        # partial overlap of best-ranked combo's genes
        res["rank_overlap"].append(M.partial_overlap(data.parse(ranked[0]), true_genes))

        # distance between the predicted endpoint of the top-ranked combo and
        # the target mean
        top1_pop = pivot_predictor.population(ranked[0], c0)
        res["endpoint_dist"].append(float(np.linalg.norm(top1_pop.mean(0) - c_star)))

        # synergy: predicted vs true non-additive component (gene space), if additive given
        if additive is not None:
            if comp is None:
                comp = data.pca_components
                ctrl_mean = ctrl_all.mean(0)
            pred_eff = (pivot_predictor.population(p, c0).mean(0) - ctrl_mean) @ comp
            true_eff = (c_star - ctrl_mean) @ comp
            add_eff = additive.predict_effect(p) @ comp
            res["synergy_corr"].append(M.synergy_correlation(pred_eff, add_eff, true_eff))

    # Mean per metric; metrics with no recorded values aggregate to None.
    agg = {k: (float(np.mean(v)) if v else None) for k, v in res.items()}
    agg["_per"] = res
    agg["n_targets"] = len(targets)
    return agg
|