File size: 8,413 Bytes
3b4941f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
"""desired-state nomination eval (tables 4-7,13,14).

a target state is induced by a held-out perturbation p: c* = centroid of p's cells.
given c*, a method nominates perturbations; we score exact/functional recovery,
ranking quality, endpoint distance, and specificity.

ranking uses a numpy reward scorer (works for any predictor). pivot reward guidance
(Alg 4) also uses the differentiable torch reward and the model.
"""
from __future__ import annotations

import numpy as np
import torch

from src.evaluation import inference as inf
from src.evaluation import metrics as M
from src.evaluation.rewards import Reward


# numpy reward scorers (for ranking; mirror rewards.py)
def numpy_score(pop: np.ndarray, kind: str, c_star=None, target_sample=None, classifier=None,
                device="cpu", control_ref=None):
    """score a population `pop` against a target; returns a python float.

    kind selects the reward:
      'target' / 'centroid': negative mean (over rows) of the squared euclidean
          distance to c_star.
      'cosine': mean cosine between each row's offset and c_star's offset, both
          measured from control_ref (or from pop's own mean when control_ref is None).
      'nn_target': negative mean squared distance from each row of pop to its
          nearest row of target_sample.
      'mmd' / 'wasserstein': negated M.mmd_rbf / M.sliced_wasserstein of
          (pop, target_sample).
      'classifier': mean log-sigmoid of classifier(pop), computed under no_grad
          on `device`.
      'combined': centroid term + classifier term.

    raises ValueError(kind) for any other kind.
    """
    centroid_term = None
    if kind in ("target", "centroid", "combined"):
        sq_dist = np.sum((pop - c_star) ** 2, axis=1)
        centroid_term = -np.mean(sq_dist)
    if kind in ("target", "centroid"):
        return float(centroid_term)

    if kind == "cosine":
        origin = pop.mean(0) if control_ref is None else control_ref
        row_dirs = pop - origin
        target_dir = c_star - origin
        # 1e-8 keeps the normalisation finite for zero-length offsets
        row_dirs = row_dirs / (np.linalg.norm(row_dirs, axis=1, keepdims=True) + 1e-8)
        target_dir = target_dir / (np.linalg.norm(target_dir) + 1e-8)
        return float(np.mean(row_dirs @ target_dir))

    if kind == "nn_target":
        from scipy.spatial.distance import cdist
        nearest = cdist(pop, target_sample).min(axis=1)
        return float(-np.mean(nearest ** 2))

    if kind == "mmd":
        return float(-M.mmd_rbf(pop, target_sample))

    if kind == "wasserstein":
        return float(-M.sliced_wasserstein(pop, target_sample))

    if kind not in ("classifier", "combined"):
        raise ValueError(kind)

    with torch.no_grad():
        batch = torch.as_tensor(pop, dtype=torch.float32, device=device)
        clf_term = torch.nn.functional.logsigmoid(classifier(batch)).mean().item()
    if kind == "classifier":
        return float(clf_term)
    return float(centroid_term + clf_term)  # combined


# reward kinds for which rank_candidates hands a PivotPredictor to
# inf.endpoint_ranking; any other kind (or any other predictor) is scored
# one candidate at a time with numpy_score.
_PER_ROW_REWARDS = ("centroid", "target", "nn_target", "classifier", "combined", "cosine")


def rank_candidates(predictor, candidates, c0, score_kwargs):
    """order candidates by the reward of their predicted endpoint.

    a PivotPredictor scored with one of _PER_ROW_REWARDS is delegated to
    inf.endpoint_ranking, using a Reward built from score_kwargs. every other
    combination (baselines, mmd/wass) calls numpy_score on
    predictor.population(c, c0) for each candidate and sorts by descending score.

    returns (labels, scores) where scores is a numpy array aligned with labels.
    on the delegated path the order is whatever inf.endpoint_ranking returns
    (presumably best first -- confirm against that helper).
    """
    from src.experiments.predictors import PivotPredictor
    from src.evaluation import inference as inf

    kind = score_kwargs["kind"]
    use_batched = isinstance(predictor, PivotPredictor) and kind in _PER_ROW_REWARDS

    if not use_batched:
        per_candidate = [numpy_score(predictor.population(c, c0), **score_kwargs)
                         for c in candidates]
        scores = np.array(per_candidate)
        order = np.argsort(-scores)  # descending reward
        return [candidates[i] for i in order], scores[order]

    dev = score_kwargs.get("device", "cpu")
    reward = Reward(
        kind,
        target_c=score_kwargs.get("c_star"),
        target_sample=score_kwargs.get("target_sample"),
        classifier=score_kwargs.get("classifier"),
        device=dev,
        control_ref=score_kwargs.get("control_ref"),
    )
    controls = torch.as_tensor(c0, dtype=torch.float32, device=dev)
    ranked = inf.endpoint_ranking(predictor.model, predictor.data, candidates,
                                  controls, reward, device=dev)
    labels = [label for label, _ in ranked]
    values = np.array([value for _, value in ranked])
    return labels, values


def build_target(data, p, control_mean_emb):
    """return (centroid, cells) describing the state induced by perturbation p.

    cells are the entries of data.emb selected by data.pert_to_idx[p]; the
    centroid is their mean along axis 0. control_mean_emb is accepted but
    not used here.
    """
    cells = data.emb[data.pert_to_idx[p]]
    centroid = cells.mean(0)
    return centroid, cells


def evaluate_nomination(predictor, data, targets, candidates, control_pool,
                        reward_kind="centroid", method="ranking", n_ctrl=128,
                        gene_cluster=None, model=None, device="cpu",
                        guidance_steps=25, guidance_step=0.5, k_nearest=10,
                        classifier=None, seed=0,
                        guidance_init="warm", guidance_normalize=True, rerank=True):
    """evaluate single-perturbation nomination. returns aggregated metric dict.

    for every target p: build the target state from p's cells, draw up to n_ctrl
    control cells (without replacement), rank the candidates by reward, optionally
    replace the head of the ranking with guidance nominations, then score the
    ranking and the predicted endpoint of the top-1 nomination.

    method: 'ranking' (Alg 3) | 'guidance' (Alg 4+5, pivot only; needs `model`).
    guidance_init: 'warm' (top-ranked) | 'random' | 'mean_top' (mean of top-k cand embeddings).
    guidance_normalize: normalize the jacobian-pullback step (Alg 4).
    rerank: if true, rerank k-nn by endpoint reward (Alg 5); else project to single nearest.
    gene_cluster: when given, func_top5 / func_ndcg are computed as well.
    classifier: passed to the reward; when given, 'clf' records the mean sigmoid
        of its output on the top-1 population.
    seed: seeds the generator used for control subsampling, the 'random'
        guidance init and the off-target sample.

    returns a dict holding the mean over targets of each metric (None when a
    metric was never recorded), plus:
      '_per': the per-target lists,
      'n_targets': len(targets),
      'random_top5': chance level of top-5 recovery, min(1, 5 / len(candidates)),
          or None when there are no candidates.

    raises ValueError when method == 'guidance' and model is None.
    """
    if method == "guidance" and model is None:
        # explicit check rather than `assert`, which is removed under `python -O`;
        # done up front so a misconfigured call fails before any ranking work
        raise ValueError("method='guidance' requires a model")

    rng = np.random.default_rng(seed)
    ctrl_all = data.emb[control_pool]
    ctrl_mean = ctrl_all.mean(0)
    res = {k: [] for k in ["top1", "top5", "ndcg", "func_top5", "func_ndcg",
                            "rank", "endpoint_dist", "spec_margin", "off_dist", "clf"]}
    for p in targets:
        c_star, tgt_sample = build_target(data, p, ctrl_mean)
        c0 = ctrl_all[rng.choice(len(ctrl_all), min(n_ctrl, len(ctrl_all)), replace=False)]
        sk = dict(kind=reward_kind, c_star=c_star, target_sample=tgt_sample,
                  classifier=classifier, device=device, control_ref=ctrl_mean)

        ranked, _ = rank_candidates(predictor, candidates, c0, sk)

        if method == "guidance":
            c0t = torch.as_tensor(c0, dtype=torch.float32, device=device)
            # guidance uses the centroid reward in place of 'combined'
            rew = Reward(reward_kind if reward_kind != "combined" else "centroid",
                         target_c=c_star, target_sample=tgt_sample,
                         classifier=classifier, device=device, control_ref=ctrl_mean)
            if guidance_init == "random":
                e_init = inf.encode_label(model, data, candidates[rng.integers(len(candidates))], device)
            elif guidance_init == "mean_top":
                from src.models.pivot import candidate_embeddings
                E = candidate_embeddings(model, data, ranked[:k_nearest], device)
                e_init = torch.as_tensor(E.mean(0, keepdims=True), device=device)
            else:  # warm
                e_init = inf.encode_label(model, data, ranked[0], device)
            e_star = inf.reward_guidance(model, c0t, rew, e_init, steps=guidance_steps,
                                         step_size=guidance_step, normalize=guidance_normalize)
            if rerank:
                top = inf.project_and_rerank(model, data, candidates, e_star, c0t, rew,
                                             k_nearest=k_nearest, topk=k_nearest, device=device)
                guided = [l for l, _ in top]
            else:
                from src.models.pivot import candidate_embeddings
                E = torch.as_tensor(candidate_embeddings(model, data, candidates, device), device=device)
                nn = torch.argsort(torch.cdist(e_star.view(1, -1), E).squeeze(0)).cpu().numpy()
                guided = [candidates[i] for i in nn[:k_nearest]]
            ranked = guided + [r for r in ranked if r not in guided]  # full order for rank metrics

        # metrics
        res["top1"].append(M.top_k_accuracy(ranked, p, 1))
        res["top5"].append(M.top_k_accuracy(ranked, p, 5))
        res["ndcg"].append(M.ndcg_at_k(ranked, M.exact_relevance(p), 10))
        res["rank"].append(M.rank_of_true(ranked, p))
        if gene_cluster is not None:
            frel = M.functional_relevance(p, data.parse, gene_cluster, candidates)
            res["func_top5"].append(float(any(frel.get(l, 0) >= 0.5 for l in ranked[:5])))
            res["func_ndcg"].append(M.ndcg_at_k(ranked, frel, 10))

        # endpoint distance of nominated top-1 + specificity
        top1_pop = predictor.population(ranked[0], c0)
        top1_centroid = top1_pop.mean(0)
        endpoint_dist = np.linalg.norm(top1_centroid - c_star)
        res["endpoint_dist"].append(float(endpoint_dist))

        # off-state distance: mean dist of top-1 endpoint to other perturbation centroids.
        # indices are sampled instead of labels so labels are never coerced into a
        # numpy array; the size is bounded by the pool so duplicates of p cannot
        # make the draw exceed the pool.
        off_pool = [q for q in candidates if q != p]
        n_off = min(20, len(candidates) - 1, len(off_pool))
        picks = rng.choice(len(off_pool), size=n_off, replace=False)
        others = [off_pool[i] for i in picks]
        off = np.mean([np.linalg.norm(top1_centroid - data.emb[data.pert_to_idx[q]].mean(0))
                       for q in others])
        res["off_dist"].append(float(off))
        # positive margin: the endpoint is closer to the target than to the sampled others
        res["spec_margin"].append(float(off - endpoint_dist))

        if classifier is not None:
            with torch.no_grad():
                res["clf"].append(float(torch.sigmoid(
                    classifier(torch.as_tensor(top1_pop, dtype=torch.float32, device=device))
                ).mean().item()))

    agg = {k: (float(np.mean(v)) if v else None) for k, v in res.items()}
    agg["_per"] = res
    agg["n_targets"] = len(targets)
    # chance level cannot exceed 1 (fewer than 5 candidates) and is undefined
    # when there are no candidates
    n_cand = len(candidates)
    agg["random_top5"] = min(1.0, 5.0 / n_cand) if n_cand else None
    return agg