"""Shared metric library. Built once, used by every table.

Pointwise / gene-level: MSE, R2, Pearson, Spearman, DE-gene correlation.
Distributional: MMD (RBF), sliced-Wasserstein.
Nomination: top-k, NDCG@k (graded), true/median rank, functional (pathway)
recovery, partial overlap (combos), synergy correlation.
"""
from __future__ import annotations

import numpy as np
from scipy.stats import pearsonr, spearmanr


# ---------- pointwise / gene-level ----------
def mse(pred, true) -> float:
    """Return the mean squared error between `pred` and `true` over all elements."""
    return float(np.mean((np.asarray(pred) - np.asarray(true)) ** 2))


def r2(pred, true) -> float:
    """Return the coefficient of determination of `pred` against `true`.

    The total sum of squares is taken around the global mean of `true`
    (all elements pooled) and padded with 1e-12 so a constant `true` does not
    divide by zero.
    """
    true = np.asarray(true)
    pred = np.asarray(pred)
    ss_res = np.sum((true - pred) ** 2)
    ss_tot = np.sum((true - true.mean()) ** 2) + 1e-12
    return float(1 - ss_res / ss_tot)


def _is_degenerate(pred: np.ndarray, true: np.ndarray) -> bool:
    """Return True when a correlation between the flattened inputs is undefined."""
    # Fewer than two samples: scipy raises, and std() of an empty array is NaN.
    if min(pred.size, true.size) < 2:
        return True
    # (Near-)constant input: the correlation denominator is zero.
    return bool(pred.std() < 1e-12 or true.std() < 1e-12)


def pearson(pred, true) -> float:
    """Return Pearson r between the flattened inputs.

    Returns 0.0 instead of NaN / an exception when either input is
    (near-)constant or has fewer than two elements.
    """
    pred, true = np.asarray(pred).ravel(), np.asarray(true).ravel()
    if _is_degenerate(pred, true):
        return 0.0
    return float(pearsonr(pred, true)[0])


def spearman(pred, true) -> float:
    """Return Spearman rank correlation between the flattened inputs.

    Returns 0.0 instead of NaN / an exception when either input is
    (near-)constant or has fewer than two elements.
    """
    pred, true = np.asarray(pred).ravel(), np.asarray(true).ravel()
    if _is_degenerate(pred, true):
        return 0.0
    return float(spearmanr(pred, true)[0])


def de_gene_correlation(pred_effect, true_effect, k: int = 20) -> float:
    """Pearson r over the top-k differentially-expressed genes (by |true effect|)."""
    true_effect = np.asarray(true_effect).ravel()
    pred_effect = np.asarray(pred_effect).ravel()
    # Indices of the k largest absolute true effects.
    de = np.argsort(-np.abs(true_effect))[:k]
    return pearson(pred_effect[de], true_effect[de])


# ---------- distributional ----------
def mmd_rbf(X, Y, gamma: float | None = None, max_n: int = 500, seed: int = 0) -> float:
    """Return the squared MMD between samples `X` and `Y` under an RBF kernel.

    Each sample set is randomly subsampled (without replacement, seeded by
    `seed`) to at most `max_n` rows. When `gamma` is None it is set by the
    median heuristic: 1 / median squared distance between the first 200 rows
    of the (subsampled) X and Y. Kernel means include the diagonal terms.
    Inputs are passed to scipy's `cdist`, so they are expected to be 2-D with
    the same number of columns.
    """
    # Lazy import: scipy.spatial is only needed by this metric.
    from scipy.spatial.distance import cdist

    rng = np.random.default_rng(seed)
    X = np.asarray(X)
    Y = np.asarray(Y)
    if len(X) > max_n:
        X = X[rng.choice(len(X), max_n, replace=False)]
    if len(Y) > max_n:
        Y = Y[rng.choice(len(Y), max_n, replace=False)]
    if gamma is None:
        med = np.median(cdist(X[:200], Y[:200]) ** 2) + 1e-9
        gamma = 1.0 / med
    Kxx = np.exp(-gamma * cdist(X, X) ** 2)
    Kyy = np.exp(-gamma * cdist(Y, Y) ** 2)
    Kxy = np.exp(-gamma * cdist(X, Y) ** 2)
    return float(Kxx.mean() + Kyy.mean() - 2 * Kxy.mean())


def sliced_wasserstein(X, Y, n_proj: int = 50, seed: int = 0) -> float:
    """Sliced-Wasserstein-1 distance (multivariate, standard approximation).

    Projects both 2-D sample sets onto `n_proj` random unit directions
    (seeded by `seed`) and averages the mean absolute difference between the
    projected quantile functions, evaluated on a grid whose size is the
    smaller of the two sample counts.
    """
    rng = np.random.default_rng(seed)
    X = np.asarray(X)
    Y = np.asarray(Y)
    d = X.shape[1]
    dirs = rng.normal(size=(n_proj, d))
    dirs /= np.linalg.norm(dirs, axis=1, keepdims=True) + 1e-12
    tot = 0.0
    for w in dirs:
        xp = np.sort(X @ w)
        yp = np.sort(Y @ w)
        # Common quantile grid so sample sets of different size are comparable.
        m = min(len(xp), len(yp))
        grid = np.linspace(0, 1, m)
        xq = np.quantile(xp, grid)
        yq = np.quantile(yp, grid)
        tot += np.mean(np.abs(xq - yq))
    return float(tot / n_proj)


# ---------- nomination ----------
def top_k_accuracy(ranked_labels, true_label, k: int) -> float:
    """Return 1.0 if `true_label` is among the first `k` ranked labels, else 0.0."""
    return float(true_label in list(ranked_labels)[:k])


def rank_of_true(ranked_labels, true_label) -> int:
    """Return the 1-based rank of `true_label`, or len(ranking) + 1 if absent."""
    ranked = list(ranked_labels)
    try:
        return ranked.index(true_label) + 1
    except ValueError:
        return len(ranked) + 1


def ndcg_at_k(ranked_labels, relevance: dict, k: int = 10) -> float:
    """Graded NDCG@k. `relevance` maps label -> gain (e.g. 1 exact, 0.5 same-pathway).

    Labels missing from `relevance` contribute zero gain. The ideal DCG is
    built from the k largest gains in `relevance` and padded with 1e-12 so an
    empty / all-zero relevance map yields 0.0 rather than a division by zero.
    """
    top = list(ranked_labels)[:k]
    dcg = sum(relevance.get(label, 0.0) / np.log2(i + 2) for i, label in enumerate(top))
    ideal = sorted(relevance.values(), reverse=True)[:k]
    idcg = sum(g / np.log2(i + 2) for i, g in enumerate(ideal)) + 1e-12
    return float(dcg / idcg)


def exact_relevance(true_label) -> dict:
    """Return a relevance map giving gain 1.0 to `true_label` only."""
    return {true_label: 1.0}


def functional_relevance(true_label, parse_fn, gene_cluster: dict, candidate_labels) -> dict:
    """Graded relevance: 1.0 exact match, 0.5 if shares a functional cluster with the target.

    `parse_fn(label)` must return the genes of a label; `gene_cluster` maps
    gene -> cluster id. Genes without a cluster annotation never count as a
    shared cluster. Candidates with no match are omitted from the result.
    """
    tgt_clusters = {gene_cluster.get(g) for g in parse_fn(true_label)}
    # Unannotated genes map to None; without this, every unannotated candidate
    # would spuriously "share a cluster" with an unannotated target.
    tgt_clusters.discard(None)
    rel = {}
    for label in candidate_labels:
        if label == true_label:
            rel[label] = 1.0
        elif any(gene_cluster.get(g) in tgt_clusters for g in parse_fn(label)):
            rel[label] = 0.5
    return rel


# ---------- combinatorial ----------
def partial_overlap(pred_genes, true_genes) -> float:
    """Return the fraction of distinct true genes that were predicted.

    Returns 0.0 when `true_genes` is empty (denominator is floored at 1).
    """
    a, b = set(pred_genes), set(true_genes)
    return len(a & b) / max(len(b), 1)


def synergy_correlation(pred_effects, additive_effects, true_effects) -> float:
    """Corr between predicted and true non-additive (synergy) components.

    The synergy component is the effect minus `additive_effects`; the result
    is `pearson` over the flattened residuals.
    """
    pred_syn = np.asarray(pred_effects) - np.asarray(additive_effects)
    true_syn = np.asarray(true_effects) - np.asarray(additive_effects)
    return pearson(pred_syn.ravel(), true_syn.ravel())


# ---------- bootstrap CI / paired test ----------
def bootstrap_ci(values, n_boot: int = 2000, alpha: float = 0.05, seed: int = 0):
    """Return (mean, lo, hi): the sample mean and its percentile-bootstrap CI.

    `lo` / `hi` are the alpha/2 and 1 - alpha/2 quantiles of `n_boot`
    resampled means (seeded by `seed`).
    """
    rng = np.random.default_rng(seed)
    values = np.asarray(values, dtype=float)
    # One rng.choice call per replicate keeps seeded results reproducible.
    means = [rng.choice(values, len(values), replace=True).mean() for _ in range(n_boot)]
    lo, hi = np.quantile(means, [alpha / 2, 1 - alpha / 2])
    return float(values.mean()), float(lo), float(hi)


def paired_bootstrap_pvalue(a, b, n_boot: int = 2000, seed: int = 0):
    """Two-sided p-value that mean(a) != mean(b) for paired samples a, b.

    Returns (p_value, observed_mean_difference). The paired differences are
    centered to simulate the null; the p-value is the add-one-smoothed share of
    resampled means at least as extreme as the observed mean difference.
    """
    rng = np.random.default_rng(seed)
    a = np.asarray(a, float)
    b = np.asarray(b, float)
    diff = a - b
    obs = diff.mean()
    centered = diff - obs
    count = 0
    for _ in range(n_boot):
        s = rng.choice(centered, len(centered), replace=True).mean()
        if abs(s) >= abs(obs):
            count += 1
    # +1 in numerator and denominator so the p-value is never exactly zero.
    return float((count + 1) / (n_boot + 1)), float(obs)