File size: 2,741 Bytes
"""forward response prediction eval (tables 2, 3, 10, 11).
for each test perturbation: push a matched control sample through the predictor to
get a predicted post-perturbation population, then compare it to the true one.
gene-space metrics use the pca decoder (affine, so mean-then-decode =
decode-then-mean). distributional metrics (mmd, sliced wasserstein) are computed
in embedding space.
"""
from __future__ import annotations
import numpy as np
from src.evaluation import metrics as M
def _gene_effect(mean_emb, control_mean_emb, components):
"""affine decode of an embedding-space shift to hvg gene-space effect."""
return (mean_emb - control_mean_emb) @ components
def _cap_rows(pop, cap, rng):
    """Return at most ``cap`` rows of ``pop``, drawn without replacement.

    ``cap=None`` disables the cap. ``rng`` is only consumed when subsampling
    actually happens, matching the original inline logic draw-for-draw.
    """
    if cap is None or len(pop) <= cap:
        return pop
    return pop[rng.choice(len(pop), cap, replace=False)]


def evaluate_forward(predictor, data, test_perts, control_idx_pool, n_ctrl=200,
                     de_k=20, seed=0, max_perts=None, dist_cap=300):
    """Aggregate forward-prediction metrics over ``test_perts``.

    For each perturbation ``p``, a random sample of up to ``n_ctrl`` control
    cells is passed to ``predictor.population(p, controls)`` and the predicted
    population is compared with the observed cells of ``p``.

    Args:
        predictor: object with ``population(pert, control_emb)`` returning a
            predicted population whose rows are cells (presumably in the same
            embedding space as ``data.emb`` -- confirm against predictors).
        data: dataset exposing ``emb``, ``pert_to_idx``, ``pca_components``
            (d, h) and ``pca_mean``.
        test_perts: iterable of perturbation keys into ``data.pert_to_idx``.
        control_idx_pool: indices of control cells in ``data.emb``.
        n_ctrl: control sample size per perturbation (capped at the pool size).
        de_k: ``k`` forwarded to ``M.de_gene_correlation``.
        seed: seeds the local rng and is forwarded to the mmd / wasserstein
            metrics.
        max_perts: if truthy, only the first ``max_perts`` perturbations are
            evaluated; ``None`` (or ``0``) evaluates all of them.
        dist_cap: maximum population size fed to the quadratic mmd /
            sliced-wasserstein metrics; ``None`` disables subsampling.

    Returns:
        dict mapping each metric name to its mean over perturbations (nan when
        no perturbation was evaluated), plus ``"_per_pert"`` (per-metric lists)
        and ``"n_perts"``.
    """
    rng = np.random.default_rng(seed)
    comp = data.pca_components  # (d, h)
    ctrl_emb_all = data.emb[control_idx_pool]
    ctrl_mean_emb = ctrl_emb_all.mean(0)
    n_pool = len(ctrl_emb_all)
    perts = list(test_perts)
    if max_perts:
        perts = perts[:max_perts]
    rows = {
        k: []
        for k in ["mse", "r2", "pearson", "spearman", "de_corr", "mmd", "wasserstein"]
    }
    for p in perts:
        true_pop = data.emb[data.pert_to_idx[p]]
        c0 = ctrl_emb_all[rng.choice(n_pool, min(n_ctrl, n_pool), replace=False)]
        pred_pop = predictor.population(p, c0)
        # each population mean feeds both the effect and the decoded profile
        pred_mean = pred_pop.mean(0)
        true_mean = true_pop.mean(0)
        # effects are shifts from the full control-pool mean, decoded to genes
        pred_eff = _gene_effect(pred_mean, ctrl_mean_emb, comp)
        true_eff = _gene_effect(true_mean, ctrl_mean_emb, comp)
        # decoded gene-space expression vectors for mse
        pred_genes = pred_mean @ comp + data.pca_mean
        true_genes = true_mean @ comp + data.pca_mean
        rows["mse"].append(M.mse(pred_genes, true_genes))
        rows["r2"].append(M.r2(pred_eff, true_eff))
        rows["pearson"].append(M.pearson(pred_eff, true_eff))
        rows["spearman"].append(M.spearman(pred_eff, true_eff))
        rows["de_corr"].append(M.de_gene_correlation(pred_eff, true_eff, k=de_k))
        # subsample true first, then predicted, to keep the rng draw order fixed
        tp = _cap_rows(true_pop, dist_cap, rng)
        pp = _cap_rows(pred_pop, dist_cap, rng)
        rows["mmd"].append(M.mmd_rbf(pp, tp, seed=seed))
        rows["wasserstein"].append(M.sliced_wasserstein(pp, tp, seed=seed))
    # explicit nan for an empty run avoids numpy's "mean of empty slice" warning
    out = {k: float(np.mean(v)) if v else float("nan") for k, v in rows.items()}
    out["_per_pert"] = {k: rows[k] for k in rows}  # kept for bootstrap cis / paired tests
    out["n_perts"] = len(perts)
    return out
|