| """shared metric library. built once, used by every table. |
| |
| pointwise/gene-level: mse, r2, pearson, spearman, de-gene correlation. |
| distributional: mmd (rbf), sliced-wasserstein. |
| nomination: top-k, ndcg@k (graded), true/median rank, functional (pathway) recovery, |
| partial overlap (combos), synergy correlation. |
| """ |
| from __future__ import annotations |
|
|
| import numpy as np |
| from scipy.stats import pearsonr, spearmanr |
|
|
|
|
| |
def mse(pred, true):
    """Return the mean squared error between ``pred`` and ``true``.

    Both inputs are converted with ``np.asarray`` and the squared
    difference is averaged over every element.
    """
    residual = np.asarray(pred) - np.asarray(true)
    squared = residual ** 2
    return float(np.mean(squared))
|
|
|
|
def r2(pred, true):
    """Return the coefficient of determination of ``pred`` against ``true``.

    Computed as ``1 - SS_res / SS_tot`` where the total sum of squares is
    taken around the mean of all elements of ``true``. A tiny constant is
    added to ``SS_tot`` so a constant ``true`` does not divide by zero.
    """
    observed = np.asarray(true)
    predicted = np.asarray(pred)
    residual_ss = np.sum((observed - predicted) ** 2)
    centered = observed - observed.mean()
    total_ss = np.sum(centered ** 2) + 1e-12
    ratio = residual_ss / total_ss
    return float(1 - ratio)
|
|
|
|
def pearson(pred, true):
    """Return the Pearson correlation between ``pred`` and ``true``.

    Both inputs are flattened before comparison. The correlation is
    undefined when the inputs are empty or when either one is (numerically)
    constant; ``0.0`` is returned in those cases.
    """
    pred = np.asarray(pred).ravel()
    true = np.asarray(true).ravel()
    # Check for empty input before calling ``std``: the std of an empty array
    # is NaN, which would slip past the constant-input guard below and make
    # ``pearsonr`` raise. Mismatched lengths are still left to ``pearsonr``.
    if pred.size == 0 and true.size == 0:
        return 0.0
    if pred.std() < 1e-12 or true.std() < 1e-12:
        return 0.0
    return float(pearsonr(pred, true)[0])
|
|
|
|
def spearman(pred, true):
    """Return the Spearman rank correlation between ``pred`` and ``true``.

    Both inputs are flattened before comparison. The correlation is
    undefined when the inputs are empty or when either one is (numerically)
    constant; ``0.0`` is returned in those cases.
    """
    pred = np.asarray(pred).ravel()
    true = np.asarray(true).ravel()
    # Check for empty input before calling ``std``: the std of an empty array
    # is NaN, which would slip past the constant-input guard below and reach
    # ``spearmanr`` with nothing to rank.
    if pred.size == 0 and true.size == 0:
        return 0.0
    if pred.std() < 1e-12 or true.std() < 1e-12:
        return 0.0
    return float(spearmanr(pred, true)[0])
|
|
|
|
def de_gene_correlation(pred_effect, true_effect, k: int = 20):
    """Pearson r restricted to the ``k`` genes with the largest |true effect|.

    Both effect vectors are flattened; genes are ranked by the absolute
    value of ``true_effect`` (descending) and the first ``k`` indices are
    used to subset both vectors before correlating them.
    """
    observed = np.asarray(true_effect).ravel()
    predicted = np.asarray(pred_effect).ravel()
    # Negating the magnitudes makes the ascending argsort rank largest first.
    ranking = np.argsort(-np.abs(observed))
    top_genes = ranking[:k]
    return pearson(predicted[top_genes], observed[top_genes])
|
|
|
|
| |
def mmd_rbf(X, Y, gamma: float | None = None, max_n: int = 500, seed: int = 0):
    """Return a biased estimate of squared MMD between samples ``X`` and ``Y``.

    Uses the RBF kernel ``exp(-gamma * ||a - b||^2)`` and averages the full
    kernel matrices (diagonal included): ``mean(Kxx) + mean(Kyy) - 2 mean(Kxy)``.

    Args:
        X, Y: sample arrays, one row per sample. A 1-D input is treated as a
            column of scalar samples.
        gamma: kernel bandwidth. When ``None`` it is set to one over the
            median squared distance between the first 200 rows of ``X`` and
            the first 200 rows of ``Y`` (after any subsampling).
        max_n: each sample set larger than this is randomly subsampled
            (without replacement) down to ``max_n`` rows.
        seed: seed for the subsampling generator.
    """
    # Imported locally (once) because the module does not import it at the top.
    from scipy.spatial.distance import cdist

    rng = np.random.default_rng(seed)
    X = np.asarray(X)
    Y = np.asarray(Y)
    # ``cdist`` requires 2-D input; promote a vector of scalars to (n, 1).
    if X.ndim == 1:
        X = X[:, None]
    if Y.ndim == 1:
        Y = Y[:, None]
    if len(X) > max_n:
        X = X[rng.choice(len(X), max_n, replace=False)]
    if len(Y) > max_n:
        Y = Y[rng.choice(len(Y), max_n, replace=False)]
    if gamma is None:
        # Median heuristic; the small constant keeps gamma finite when all
        # sampled distances are zero.
        med = np.median(cdist(X[:200], Y[:200]) ** 2) + 1e-9
        gamma = 1.0 / med
    Kxx = np.exp(-gamma * cdist(X, X) ** 2)
    Kyy = np.exp(-gamma * cdist(Y, Y) ** 2)
    Kxy = np.exp(-gamma * cdist(X, Y) ** 2)
    return float(Kxx.mean() + Kyy.mean() - 2 * Kxy.mean())
|
|
|
|
def sliced_wasserstein(X, Y, n_proj: int = 50, seed: int = 0):
    """Sliced Wasserstein-1 distance between samples ``X`` and ``Y``.

    Draws ``n_proj`` random unit directions, projects both sample sets onto
    each one, and averages the mean absolute difference between matched
    quantiles of the two 1-D projections.

    Args:
        X, Y: 2-D sample arrays (rows are samples) with the same number of
            columns; the row counts may differ.
        n_proj: number of random projection directions.
        seed: seed for the direction generator.
    """
    rng = np.random.default_rng(seed)
    X = np.asarray(X)
    Y = np.asarray(Y)
    d = X.shape[1]
    dirs = rng.normal(size=(n_proj, d))
    # The small constant guards against a zero-norm direction.
    dirs /= np.linalg.norm(dirs, axis=1, keepdims=True) + 1e-12
    # The quantile grid depends only on the sample counts, so build it once.
    # Using the smaller count lets unequal-sized samples be compared.
    m = min(len(X), len(Y))
    levels = np.linspace(0, 1, m)
    tot = 0.0
    for w in dirs:
        # np.quantile orders the data itself, so no explicit sort is needed.
        xq = np.quantile(X @ w, levels)
        yq = np.quantile(Y @ w, levels)
        tot += np.mean(np.abs(xq - yq))
    return float(tot / n_proj)
|
|
|
|
| |
def top_k_accuracy(ranked_labels, true_label, k: int) -> float:
    """Return 1.0 if ``true_label`` is among the first ``k`` ranked labels, else 0.0."""
    shortlist = list(ranked_labels)[:k]
    hit = true_label in shortlist
    return float(hit)
|
|
|
|
def rank_of_true(ranked_labels, true_label) -> int:
    """Return the 1-based rank of ``true_label`` in ``ranked_labels``.

    When the label is absent, the rank is one past the end of the list.
    """
    ordering = list(ranked_labels)
    if true_label not in ordering:
        return len(ordering) + 1
    return ordering.index(true_label) + 1
|
|
|
|
def ndcg_at_k(ranked_labels, relevance: dict, k: int = 10) -> float:
    """Graded NDCG@k.

    ``relevance`` maps label -> gain (e.g. 1 for an exact hit, 0.5 for a
    same-pathway hit); labels missing from it contribute zero gain. The
    ideal ordering is the ``k`` largest gains in ``relevance``.
    """

    def discounted_sum(gains):
        # Position p (1-based) is discounted by log2(p + 1).
        return sum(g / np.log2(pos + 1) for pos, g in enumerate(gains, start=1))

    shortlist = list(ranked_labels)[:k]
    achieved = [relevance.get(label, 0.0) for label in shortlist]
    best_possible = sorted(relevance.values(), reverse=True)[:k]

    dcg = discounted_sum(achieved)
    # The small constant avoids dividing by zero when no label has any gain.
    idcg = discounted_sum(best_possible) + 1e-12
    return float(dcg / idcg)
|
|
|
|
def exact_relevance(true_label) -> dict:
    """Return a relevance mapping that gives gain 1.0 to ``true_label`` only."""
    relevance = {}
    relevance[true_label] = 1.0
    return relevance
|
|
|
|
def functional_relevance(true_label, parse_fn, gene_cluster: dict, candidate_labels) -> dict:
    """Build a graded relevance mapping over ``candidate_labels``.

    A candidate equal to ``true_label`` gets 1.0. Any other candidate gets
    0.5 when at least one of its genes is assigned to a cluster that one of
    the target's genes is also assigned to. Candidates with no relation are
    omitted from the result.

    Args:
        true_label: the target label.
        parse_fn: callable turning a label into an iterable of genes.
        gene_cluster: mapping gene -> cluster id. Genes absent from it (or
            mapped to ``None``) are treated as having no cluster.
        candidate_labels: iterable of labels to grade.
    """
    target_clusters = {gene_cluster.get(gene) for gene in parse_fn(true_label)}
    # An unmapped gene yields None. None is "no cluster", not a shared one;
    # keeping it would grade every candidate with an unmapped gene as 0.5.
    target_clusters.discard(None)

    rel = {}
    for label in candidate_labels:
        if label == true_label:
            rel[label] = 1.0
            continue
        candidate_clusters = (gene_cluster.get(gene) for gene in parse_fn(label))
        if any(c is not None and c in target_clusters for c in candidate_clusters):
            rel[label] = max(rel.get(label, 0.0), 0.5)
    return rel
|
|
|
|
| |
def partial_overlap(pred_genes, true_genes) -> float:
    """Return the fraction of distinct true genes that were also predicted.

    Duplicates are ignored. An empty ``true_genes`` yields 0.0 because the
    denominator is floored at one.
    """
    predicted = set(pred_genes)
    expected = set(true_genes)
    shared = predicted.intersection(expected)
    denominator = max(len(expected), 1)
    return len(shared) / denominator
|
|
|
|
def synergy_correlation(pred_effects, additive_effects, true_effects):
    """Correlate predicted and true non-additive (synergy) components.

    The synergy component is the effect minus the additive expectation;
    both residuals are flattened and passed to ``pearson``.
    """
    baseline = np.asarray(additive_effects)
    predicted_synergy = np.asarray(pred_effects) - baseline
    observed_synergy = np.asarray(true_effects) - baseline
    return pearson(predicted_synergy.ravel(), observed_synergy.ravel())
|
|
|
|
| |
def bootstrap_ci(values, n_boot: int = 2000, alpha: float = 0.05, seed: int = 0):
    """Percentile bootstrap confidence interval for the mean of ``values``.

    Resamples ``values`` with replacement ``n_boot`` times and takes the
    ``alpha / 2`` and ``1 - alpha / 2`` quantiles of the resampled means.

    Returns:
        ``(mean, lower, upper)`` as plain floats.
    """
    generator = np.random.default_rng(seed)
    data = np.asarray(values, dtype=float)
    resampled_means = []
    for _ in range(n_boot):
        draw = generator.choice(data, len(data), replace=True)
        resampled_means.append(draw.mean())
    tail = alpha / 2
    lower, upper = np.quantile(resampled_means, [tail, 1 - tail])
    return float(data.mean()), float(lower), float(upper)
|
|
|
|
def paired_bootstrap_pvalue(a, b, n_boot: int = 2000, seed: int = 0):
    """Two-sided bootstrap p-value that mean(a) != mean(b) for paired samples.

    The paired differences are centered on zero to simulate the null, then
    resampled with replacement ``n_boot`` times. The p-value is the
    add-one-smoothed share of resampled means at least as extreme as the
    observed mean difference.

    Returns:
        ``(p_value, observed_mean_difference)`` as plain floats.
    """
    generator = np.random.default_rng(seed)
    first = np.asarray(a, float)
    second = np.asarray(b, float)
    deltas = first - second
    observed = deltas.mean()
    null_deltas = deltas - observed
    threshold = abs(observed)

    extreme = 0
    for _ in range(n_boot):
        resampled_mean = generator.choice(null_deltas, len(null_deltas), replace=True).mean()
        if not abs(resampled_mean) < threshold:
            extreme += 1

    p_value = (extreme + 1) / (n_boot + 1)
    return float(p_value), float(observed)
|
|