File size: 2,741 Bytes
3b4941f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
"""forward response prediction eval (tables 2,3,10,11).

for each test perturbation: push a matched control sample through the predictor to
get a predicted post-perturbation population, then compare to the true one.
gene-space metrics use the pca decoder (affine, so mean-then-decode =
decode-then-mean). distributional metrics are computed in embedding space.
"""
from __future__ import annotations

import numpy as np

from src.evaluation import metrics as M


def _gene_effect(mean_emb, control_mean_emb, components):
    """project an embedding-space shift into hvg gene space.

    the shift is mean_emb - control_mean_emb; because the pca decoder is
    affine, its offset cancels in that difference and only the linear map
    `components` needs to be applied.
    """
    shift = mean_emb - control_mean_emb
    gene_space = shift @ components
    return gene_space


def _subsample(pop, cap, rng):
    """return pop unchanged if cap is None or pop has <= cap rows, else cap rows drawn without replacement."""
    if cap is None or len(pop) <= cap:
        return pop
    return pop[rng.choice(len(pop), cap, replace=False)]


def evaluate_forward(predictor, data, test_perts, control_idx_pool, n_ctrl=200,
                     de_k=20, seed=0, max_perts=None, dist_cap=300):
    """aggregated forward metrics over test_perts.

    for each perturbation p, a sample of control embeddings (drawn without
    replacement from data.emb[control_idx_pool]) is passed to
    predictor.population(p, ...) and the predicted population is compared to
    the observed one, data.emb[data.pert_to_idx[p]].

    args:
        predictor: object with population(pert, control_emb); its result is
            used as a row-indexable array (presumably (n_cells, d) embeddings
            -- verify against the predictor implementations).
        data: provides emb, pca_components (d, h), pca_mean and pert_to_idx.
        test_perts: iterable of perturbation keys into data.pert_to_idx.
        control_idx_pool: indices into data.emb selecting the control cells.
        n_ctrl: control cells fed to the predictor per perturbation (capped at
            the pool size).
        de_k: k forwarded to M.de_gene_correlation.
        seed: seeds the sampling rng and is forwarded to the mmd / sliced
            wasserstein metrics.
        max_perts: if truthy, only the first max_perts perturbations are
            evaluated (None or 0 both mean "all").
        dist_cap: subsamples populations for the quadratic mmd / wasserstein
            metrics; None disables subsampling.

    returns:
        dict with the mean of each metric over perturbations (nan when no
        perturbation is evaluated), "_per_pert" (per-perturbation lists, kept
        for bootstrap cis / paired tests) and "n_perts".
    """
    rng = np.random.default_rng(seed)
    comp = data.pca_components  # (d, h)
    ctrl_emb_all = data.emb[control_idx_pool]
    ctrl_mean_emb = ctrl_emb_all.mean(0)
    n_sample = min(n_ctrl, len(ctrl_emb_all))  # loop-invariant
    perts = list(test_perts)
    if max_perts:  # truthiness on purpose: 0 / None mean "no limit"
        perts = perts[:max_perts]

    rows = {k: [] for k in ["mse", "r2", "pearson", "spearman", "de_corr", "mmd", "wasserstein"]}
    for p in perts:
        true_pop = data.emb[data.pert_to_idx[p]]
        c0 = ctrl_emb_all[rng.choice(len(ctrl_emb_all), n_sample, replace=False)]
        pred_pop = predictor.population(p, c0)

        # population means are computed once and reused below
        pred_mean = pred_pop.mean(0)
        true_mean = true_pop.mean(0)
        # shifts vs. the full control-pool mean, decoded to gene space:
        # used by r2 / pearson / spearman / de_corr
        pred_eff = _gene_effect(pred_mean, ctrl_mean_emb, comp)
        true_eff = _gene_effect(true_mean, ctrl_mean_emb, comp)
        # fully decoded expression vectors (pca mean added back): used by mse only
        pred_genes = pred_mean @ comp + data.pca_mean
        true_genes = true_mean @ comp + data.pca_mean

        rows["mse"].append(M.mse(pred_genes, true_genes))
        rows["r2"].append(M.r2(pred_eff, true_eff))
        rows["pearson"].append(M.pearson(pred_eff, true_eff))
        rows["spearman"].append(M.spearman(pred_eff, true_eff))
        rows["de_corr"].append(M.de_gene_correlation(pred_eff, true_eff, k=de_k))
        # distributional metrics live in embedding space; subsample the true
        # population before the predicted one so rng draws keep a fixed order
        tp = _subsample(true_pop, dist_cap, rng)
        pp = _subsample(pred_pop, dist_cap, rng)
        rows["mmd"].append(M.mmd_rbf(pp, tp, seed=seed))
        rows["wasserstein"].append(M.sliced_wasserstein(pp, tp, seed=seed))

    if perts:
        out = {k: float(np.mean(v)) for k, v in rows.items()}
    else:
        # np.mean([]) would warn before returning nan; return nan explicitly
        out = {k: float("nan") for k in rows}
    out["_per_pert"] = {k: list(v) for k, v in rows.items()}  # kept for bootstrap cis / paired tests
    out["n_perts"] = len(perts)
    return out