"""forward response prediction eval (tables 2, 3, 10, 11).

for each test perturbation: push a random subsample of the control pool through
the predictor to get a predicted post-perturbation population, then compare it to
the observed one. gene-space metrics use the pca decoder (affine, so
mean-then-decode = decode-then-mean). distributional metrics (mmd, sliced
wasserstein) are computed in embedding space.
"""
| from __future__ import annotations |
|
|
| import numpy as np |
|
|
| from src.evaluation import metrics as M |
|
|
|
|
| def _gene_effect(mean_emb, control_mean_emb, components): |
| """affine decode of an embedding-space shift to hvg gene-space effect.""" |
| return (mean_emb - control_mean_emb) @ components |
|
|
|
|
def _cap_rows(pop, cap, rng):
    """return pop unchanged if it has at most cap rows, else cap rows drawn without replacement."""
    if len(pop) <= cap:
        return pop
    return pop[rng.choice(len(pop), cap, replace=False)]


def evaluate_forward(predictor, data, test_perts, control_idx_pool, n_ctrl=200,
                     de_k=20, seed=0, max_perts=None, dist_cap=300):
    """aggregated forward metrics over test_perts.

    for each perturbation p, up to n_ctrl control cells (drawn without replacement
    from control_idx_pool) are passed to predictor.population(p, ctrl); the
    predicted population is compared with data.emb[data.pert_to_idx[p]].

    gene-space metrics (mse / r2 / pearson / spearman / de_corr) compare population
    means decoded through data.pca_components. r2 / pearson / spearman / de_corr
    use effects relative to the mean of the full control pool; mse uses the
    absolute decoded means (pca_mean added). distributional metrics
    (mmd / wasserstein) are computed in embedding space, with each population
    subsampled to at most dist_cap rows because those metrics are quadratic.

    args:
        predictor: object exposing population(pert, control_embeddings).
        data: object exposing emb, pert_to_idx, pca_components and pca_mean.
        test_perts: iterable of keys into data.pert_to_idx.
        control_idx_pool: indices into data.emb of the control cells.
        n_ctrl: control cells drawn per perturbation (capped at the pool size).
        de_k: k forwarded to M.de_gene_correlation.
        seed: seeds control sampling and dist_cap subsampling (via two independent
            streams; expected to be an int or None) and is forwarded to the
            mmd / wasserstein metrics.
        max_perts: if truthy, only the first max_perts perturbations are evaluated.
        dist_cap: row cap applied before the mmd / wasserstein metrics.

    returns:
        dict mapping each metric name to its mean over perturbations (nan when no
        perturbations were evaluated), plus "_per_pert" (metric name -> list of
        per-perturbation values in evaluation order) and "n_perts".
    """
    # two independent random streams. control draws feed the predictor and hence
    # every metric, so they must not depend on whether earlier populations needed
    # dist_cap subsampling; with a shared stream, changing dist_cap would silently
    # change the gene-space metrics as well. ctrl_rng is seeded exactly like the
    # original single rng, so control draws are unchanged when no subsampling occurs.
    ctrl_rng = np.random.default_rng(seed)
    sub_rng = np.random.default_rng(np.random.SeedSequence(seed).spawn(1)[0])

    comp = data.pca_components
    ctrl_emb_all = data.emb[control_idx_pool]
    ctrl_mean_emb = ctrl_emb_all.mean(0)
    n_draw = min(n_ctrl, len(ctrl_emb_all))
    perts = list(test_perts)
    if max_perts:
        perts = perts[:max_perts]

    rows = {k: [] for k in ["mse", "r2", "pearson", "spearman", "de_corr", "mmd", "wasserstein"]}
    for p in perts:
        true_pop = data.emb[data.pert_to_idx[p]]
        ctrl = ctrl_emb_all[ctrl_rng.choice(len(ctrl_emb_all), n_draw, replace=False)]
        pred_pop = predictor.population(p, ctrl)

        pred_mean = pred_pop.mean(0)
        true_mean = true_pop.mean(0)
        # both effects share the full-pool control baseline so they are comparable
        pred_eff = _gene_effect(pred_mean, ctrl_mean_emb, comp)
        true_eff = _gene_effect(true_mean, ctrl_mean_emb, comp)
        pred_genes = pred_mean @ comp + data.pca_mean
        true_genes = true_mean @ comp + data.pca_mean

        rows["mse"].append(M.mse(pred_genes, true_genes))
        rows["r2"].append(M.r2(pred_eff, true_eff))
        rows["pearson"].append(M.pearson(pred_eff, true_eff))
        rows["spearman"].append(M.spearman(pred_eff, true_eff))
        rows["de_corr"].append(M.de_gene_correlation(pred_eff, true_eff, k=de_k))

        tp = _cap_rows(true_pop, dist_cap, sub_rng)
        pp = _cap_rows(pred_pop, dist_cap, sub_rng)
        rows["mmd"].append(M.mmd_rbf(pp, tp, seed=seed))
        rows["wasserstein"].append(M.sliced_wasserstein(pp, tp, seed=seed))

    # explicit nan for the empty case avoids numpy's "mean of empty slice" warning
    out = {k: float(np.mean(v)) if v else float("nan") for k, v in rows.items()}
    out["_per_pert"] = dict(rows)
    out["n_perts"] = len(perts)
    return out
|
|