File size: 3,415 Bytes
3b4941f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
"""Combinatorial nomination evaluation (tables 7 and 19).

Each target is a held-out multi-gene perturbation. Methods: pivot greedy
selection (alg 6), pivot endpoint ranking over the vocabulary of observed
combinations, and optionally guidance.
Metrics: exact top-1/top-5 gene-set match (greedy yields a single set, so its
top-5 equals its top-1), partial gene overlap, endpoint distance, NDCG@10, and,
when an additive baseline is supplied, synergy correlation against it.
"""
from __future__ import annotations

import numpy as np

from src.evaluation import inference as inf
from src.evaluation import metrics as M
from src.evaluation.rewards import Reward
from src.experiments.nomination_eval import numpy_score, rank_candidates


def evaluate_combinatorial(pivot_predictor, data, targets, combo_candidates, gene_pool,
                           control_pool, model, device, reward_kind="centroid",
                           n_ctrl=128, max_size=2, additive=None, seed=0):
    """Evaluate combinatorial nomination on held-out multi-gene perturbations.

    For each target perturbation ``p`` a random subsample of control cells is
    drawn. From that same start state two methods are scored:

    * pivot greedy selection via ``inf.greedy_combinatorial`` over ``gene_pool``;
    * endpoint ranking of ``combo_candidates`` via ``rank_candidates``.

    Args:
        pivot_predictor: object exposing ``population(pert, c0)``. Used by the
            ranking, endpoint-distance and synergy metrics.
        data: dataset exposing ``emb``, ``pert_to_idx``, ``parse``, and (only
            when ``additive`` is given) ``pca_components``.
        targets: iterable of held-out perturbation names. It is materialized to
            a list, so any iterable (including a generator) is accepted.
        combo_candidates: candidate perturbations passed to ``rank_candidates``.
        gene_pool: gene pool passed to ``inf.greedy_combinatorial``.
        control_pool: indexer into ``data.emb`` selecting the control cells.
        model: model passed to ``inf.greedy_combinatorial``.
        device: torch device for the greedy start tensor and the reward.
        reward_kind: reward kind forwarded to ``Reward`` and the ranking score.
        n_ctrl: maximum number of control cells sampled (without replacement)
            per target.
        max_size: maximum gene-set size for the greedy search.
        additive: optional additive baseline exposing ``predict_effect(pert)``.
            When ``None``, the synergy metric is skipped.
        seed: seed for the control-cell subsampling RNG.

    Returns:
        dict mapping each metric name to its mean over targets (``None`` for a
        metric with no recorded values, e.g. ``synergy_corr`` without
        ``additive``). It also holds ``"_per"`` (the raw per-target lists) and
        ``"n_targets"``.
    """
    import torch

    # Materialize once: the loop consumes ``targets`` and ``len`` is taken at
    # the end, which would raise TypeError for a generator or other unsized
    # iterable.
    targets = list(targets)

    rng = np.random.default_rng(seed)
    ctrl_all = data.emb[control_pool]
    # Loop-invariant quantities, computed once rather than per target.
    n_sample = min(n_ctrl, len(ctrl_all))
    ctrl_mean = ctrl_all.mean(0)
    comp = data.pca_components if additive is not None else None

    res = {k: [] for k in ["greedy_exact1", "greedy_exact5", "greedy_overlap",
                           "rank_exact1", "rank_exact5", "rank_ndcg", "rank_overlap",
                           "endpoint_dist", "synergy_corr"]}
    for p in targets:
        true_genes = set(data.parse(p))
        target_sample = data.emb[data.pert_to_idx[p]]
        c_star = target_sample.mean(0)
        c0 = ctrl_all[rng.choice(len(ctrl_all), n_sample, replace=False)]
        sk = dict(kind=reward_kind, c_star=c_star, target_sample=target_sample,
                  device=device)

        # greedy (alg 6)
        c0t = torch.as_tensor(c0, dtype=torch.float32, device=device)
        rew = Reward(reward_kind, target_c=c_star, target_sample=target_sample,
                     device=device)
        chosen, _, _ = inf.greedy_combinatorial(model, data, gene_pool, c0t, rew,
                                                 max_size=max_size, device=device)
        greedy_hit = float(set(chosen) == true_genes)
        res["greedy_exact1"].append(greedy_hit)
        # greedy yields a single set, so exact top-1 doubles as the top-5 proxy
        res["greedy_exact5"].append(greedy_hit)
        res["greedy_overlap"].append(M.partial_overlap(chosen, true_genes))

        # ranking over observed combos
        ranked, _ = rank_candidates(pivot_predictor, combo_candidates, c0, sk)
        res["rank_exact1"].append(M.top_k_accuracy(ranked, p, 1))
        res["rank_exact5"].append(M.top_k_accuracy(ranked, p, 5))
        res["rank_ndcg"].append(M.ndcg_at_k(ranked, M.exact_relevance(p), 10))
        # partial overlap of the best-ranked combo's genes with the truth
        res["rank_overlap"].append(M.partial_overlap(data.parse(ranked[0]), true_genes))

        # distance between the top-ranked candidate's predicted mean endpoint
        # and the observed target centroid
        top1_pop = pivot_predictor.population(ranked[0], c0)
        res["endpoint_dist"].append(float(np.linalg.norm(top1_pop.mean(0) - c_star)))

        # synergy: predicted vs true non-additive component, projected through
        # data.pca_components; only when an additive baseline is given
        if additive is not None:
            pred_eff = (pivot_predictor.population(p, c0).mean(0) - ctrl_mean) @ comp
            true_eff = (c_star - ctrl_mean) @ comp
            add_eff = additive.predict_effect(p) @ comp
            res["synergy_corr"].append(M.synergy_correlation(pred_eff, add_eff, true_eff))

    agg = {k: (float(np.mean(v)) if v else None) for k, v in res.items()}
    agg["_per"] = res
    agg["n_targets"] = len(targets)
    return agg