"""forward response prediction eval (tables 2,3,10,11). for each test perturbation: push a matched control sample through the predictor to get a predicted post-perturbation population, then compare to the true one. gene-space metrics use the pca decoder (affine, so mean-then-decode = decode-then-mean). distributional metrics are computed in embedding space. """ from __future__ import annotations import numpy as np from src.evaluation import metrics as M def _gene_effect(mean_emb, control_mean_emb, components): """affine decode of an embedding-space shift to hvg gene-space effect.""" return (mean_emb - control_mean_emb) @ components def evaluate_forward(predictor, data, test_perts, control_idx_pool, n_ctrl=200, de_k=20, seed=0, max_perts=None, dist_cap=300): """aggregated forward metrics over test_perts. dist_cap subsamples populations for the quadratic mmd / wasserstein metrics.""" rng = np.random.default_rng(seed) comp = data.pca_components # (d, h) ctrl_emb_all = data.emb[control_idx_pool] ctrl_mean_emb = ctrl_emb_all.mean(0) perts = list(test_perts) if max_perts: perts = perts[:max_perts] rows = {k: [] for k in ["mse", "r2", "pearson", "spearman", "de_corr", "mmd", "wasserstein"]} for p in perts: true_idx = data.pert_to_idx[p] true_pop = data.emb[true_idx] c0 = ctrl_emb_all[rng.choice(len(ctrl_emb_all), min(n_ctrl, len(ctrl_emb_all)), replace=False)] pred_pop = predictor.population(p, c0) pred_eff = _gene_effect(pred_pop.mean(0), ctrl_mean_emb, comp) true_eff = _gene_effect(true_pop.mean(0), ctrl_mean_emb, comp) # decoded gene-space expression vectors for mse/r2 pred_genes = pred_pop.mean(0) @ comp + data.pca_mean true_genes = true_pop.mean(0) @ comp + data.pca_mean rows["mse"].append(M.mse(pred_genes, true_genes)) rows["r2"].append(M.r2(pred_eff, true_eff)) rows["pearson"].append(M.pearson(pred_eff, true_eff)) rows["spearman"].append(M.spearman(pred_eff, true_eff)) rows["de_corr"].append(M.de_gene_correlation(pred_eff, true_eff, k=de_k)) tp = true_pop if len(true_pop) <= dist_cap else true_pop[rng.choice(len(true_pop), dist_cap, replace=False)] pp = pred_pop if len(pred_pop) <= dist_cap else pred_pop[rng.choice(len(pred_pop), dist_cap, replace=False)] rows["mmd"].append(M.mmd_rbf(pp, tp, seed=seed)) rows["wasserstein"].append(M.sliced_wasserstein(pp, tp, seed=seed)) out = {k: float(np.mean(v)) for k, v in rows.items()} out["_per_pert"] = {k: rows[k] for k in rows} # kept for bootstrap cis / paired tests out["n_perts"] = len(perts) return out