# PIVOT / src/evaluation/metrics.py
# bryan7264 · pivot: code + trained checkpoints (norman, replogle k562)
# 3b4941f (verified) · 5.78 kB
"""shared metric library. built once, used by every table.
pointwise/gene-level: mse, r2, pearson, spearman, de-gene correlation.
distributional: mmd (rbf), sliced-wasserstein.
nomination: top-k, ndcg@k (graded), true/median rank, functional (pathway) recovery.
combinatorial: partial overlap (combos), synergy correlation.
uncertainty: bootstrap ci, paired bootstrap p-value.
"""
from __future__ import annotations
import numpy as np
from scipy.stats import pearsonr, spearmanr
# ---------- pointwise / gene-level ----------
def mse(pred, true):
    """mean squared error between `pred` and `true` (array-likes), as a python float."""
    residual = np.asarray(pred) - np.asarray(true)
    squared = residual ** 2
    return float(np.mean(squared))
def r2(pred, true):
    """coefficient of determination of `pred` against `true`.

    computed as 1 - ss_res / ss_tot, where ss_tot is taken around the global
    mean of `true`. 1e-12 is added to ss_tot so a constant `true` does not
    divide by zero.
    """
    y = np.asarray(true)
    y_hat = np.asarray(pred)
    residual_ss = np.sum((y - y_hat) ** 2)
    total_ss = np.sum((y - y.mean()) ** 2) + 1e-12
    return float(1 - residual_ss / total_ss)
def pearson(pred, true):
    """pearson r between the flattened `pred` and `true`.

    returns 0.0 whenever the correlation is undefined: either input is empty,
    or either input is numerically constant (std < 1e-12).
    """
    pred, true = np.asarray(pred).ravel(), np.asarray(true).ravel()
    # empty input: std() would be nan (with a warning) and pearsonr would fail,
    # so treat it like the other degenerate cases.
    if pred.size == 0 or true.size == 0:
        return 0.0
    if pred.std() < 1e-12 or true.std() < 1e-12:
        return 0.0
    return float(pearsonr(pred, true)[0])
def spearman(pred, true):
    """spearman rank correlation between the flattened `pred` and `true`.

    returns 0.0 whenever the correlation is undefined: either input is empty,
    or either input is numerically constant (std < 1e-12).
    """
    pred, true = np.asarray(pred).ravel(), np.asarray(true).ravel()
    # empty input has no defined rank correlation (and std() would be nan),
    # so treat it like the other degenerate cases.
    if pred.size == 0 or true.size == 0:
        return 0.0
    if pred.std() < 1e-12 or true.std() < 1e-12:
        return 0.0
    return float(spearmanr(pred, true)[0])
def de_gene_correlation(pred_effect, true_effect, k: int = 20):
    """pearson r restricted to the k genes with the largest |true effect|."""
    truth = np.asarray(true_effect).ravel()
    guess = np.asarray(pred_effect).ravel()
    # indices in descending order of absolute true effect; keep the first k.
    order = np.argsort(-np.abs(truth))
    top = order[:k]
    return pearson(guess[top], truth[top])
# ---------- distributional ----------
def mmd_rbf(X, Y, gamma: float | None = None, max_n: int = 500, seed: int = 0):
    """squared mmd between samples X and Y under an rbf kernel.

    X, Y: 2-d array-likes, one sample per row, same number of columns.
    gamma: the kernel is exp(-gamma * ||x - y||^2). when None, gamma is set to
        1 / (median squared distance between the first 200 rows of X and the
        first 200 rows of Y, after subsampling).
    max_n: a set with more than max_n rows is subsampled without replacement.
    seed: seed for the subsampling rng.
    returns: mean(Kxx) + mean(Kyy) - 2 * mean(Kxy) as a python float; the
        kernel means include the diagonal terms.
    """
    # single function-scope import; cdist is not imported at module level.
    from scipy.spatial.distance import cdist

    rng = np.random.default_rng(seed)
    X = np.asarray(X); Y = np.asarray(Y)
    if len(X) > max_n:
        X = X[rng.choice(len(X), max_n, replace=False)]
    if len(Y) > max_n:
        Y = Y[rng.choice(len(Y), max_n, replace=False)]
    if gamma is None:
        # median heuristic on cross distances; 1e-9 keeps gamma finite when
        # the median distance is zero.
        med = np.median(cdist(X[:200], Y[:200]) ** 2) + 1e-9
        gamma = 1.0 / med
    Kxx = np.exp(-gamma * cdist(X, X) ** 2)
    Kyy = np.exp(-gamma * cdist(Y, Y) ** 2)
    Kxy = np.exp(-gamma * cdist(X, Y) ** 2)
    return float(Kxx.mean() + Kyy.mean() - 2 * Kxy.mean())
def sliced_wasserstein(X, Y, n_proj: int = 50, seed: int = 0):
    """sliced-wasserstein-1 distance (multivariate, standard approximation).

    X, Y: 2-d array-likes, one sample per row, same number of columns.
    n_proj: number of random unit directions to average over.
    seed: seed for the direction rng.
    for each direction both samples are projected to 1-d and compared at
    m = min(len(X), len(Y)) evenly spaced quantile levels; the mean absolute
    quantile gap is then averaged over directions.
    """
    rng = np.random.default_rng(seed)
    X = np.asarray(X); Y = np.asarray(Y)
    d = X.shape[1]
    dirs = rng.normal(size=(n_proj, d))
    dirs /= np.linalg.norm(dirs, axis=1, keepdims=True) + 1e-12
    # the quantile levels depend only on the sample counts, so build them once.
    levels = np.linspace(0, 1, min(len(X), len(Y)))
    tot = 0.0
    for w in dirs:
        # np.quantile orders the data itself; no explicit sort is needed.
        xq = np.quantile(X @ w, levels)
        yq = np.quantile(Y @ w, levels)
        tot += np.mean(np.abs(xq - yq))
    return float(tot / n_proj)
# ---------- nomination ----------
def top_k_accuracy(ranked_labels, true_label, k: int) -> float:
    """1.0 if `true_label` is among the first k entries of `ranked_labels`, else 0.0."""
    shortlist = list(ranked_labels)[:k]
    hit = true_label in shortlist
    return 1.0 if hit else 0.0
def rank_of_true(ranked_labels, true_label) -> int:
    """1-based position of `true_label` in `ranked_labels`; len + 1 when it is absent."""
    ordering = list(ranked_labels)
    for position, label in enumerate(ordering, start=1):
        # first match wins, same as list.index.
        if label is true_label or label == true_label:
            return position
    return len(ordering) + 1
def ndcg_at_k(ranked_labels, relevance: dict, k: int = 10) -> float:
    """graded ndcg@k. relevance maps label -> gain (e.g. 1 exact, 0.5 same-pathway).

    labels missing from `relevance` contribute a gain of 0. the ideal dcg is
    built from the k largest gains in `relevance`, and 1e-12 is added to it so
    an all-zero relevance does not divide by zero.
    """
    head = list(ranked_labels)[:k]
    # position p (1-based) is discounted by log2(p + 1), i.e. log2(2), log2(3), ...
    gained = 0
    for log_arg, label in enumerate(head, start=2):
        gained += relevance.get(label, 0.0) / np.log2(log_arg)

    best_gains = sorted(relevance.values(), reverse=True)[:k]
    ideal_total = 0
    for log_arg, gain in enumerate(best_gains, start=2):
        ideal_total += gain / np.log2(log_arg)
    ideal_total += 1e-12

    return float(gained / ideal_total)
def exact_relevance(true_label) -> dict:
    """relevance map for exact-match scoring: gain 1.0 for `true_label`, no other entries."""
    gains = {}
    gains[true_label] = 1.0
    return gains
def functional_relevance(true_label, parse_fn, gene_cluster: dict, candidate_labels) -> dict:
    """graded relevance: 1.0 exact match, 0.5 if shares a functional cluster with the target.

    true_label: the target label.
    parse_fn: callable mapping a label to an iterable of its genes.
    gene_cluster: maps gene -> cluster id. genes that are missing from the map
        (or mapped to None) have no cluster and never produce a match.
    candidate_labels: labels to grade.
    returns: dict label -> gain, holding only the labels with a non-zero gain.
    """
    # drop None so that two unannotated genes are not treated as sharing a cluster.
    tgt_clusters = {gene_cluster.get(g) for g in parse_fn(true_label)} - {None}
    rel = {}
    for label in candidate_labels:
        if label == true_label:
            rel[label] = 1.0
        elif any(gene_cluster.get(g) in tgt_clusters for g in parse_fn(label)):
            rel[label] = 0.5
    return rel
# ---------- combinatorial ----------
def partial_overlap(pred_genes, true_genes) -> float:
    """fraction of the distinct `true_genes` that also appear in `pred_genes`.

    an empty `true_genes` yields 0.0 because the denominator is floored at 1.
    """
    predicted = set(pred_genes)
    expected = set(true_genes)
    shared = predicted.intersection(expected)
    denom = len(expected) if len(expected) > 1 else 1
    return len(shared) / denom
def synergy_correlation(pred_effects, additive_effects, true_effects):
    """pearson r between predicted and true synergy (effect minus the additive effect)."""
    baseline = np.asarray(additive_effects)
    # synergy = what remains after subtracting the additive baseline.
    predicted_synergy = np.asarray(pred_effects) - baseline
    observed_synergy = np.asarray(true_effects) - baseline
    return pearson(predicted_synergy.ravel(), observed_synergy.ravel())
# ---------- bootstrap CI / paired test ----------
def bootstrap_ci(values, n_boot: int = 2000, alpha: float = 0.05, seed: int = 0):
    """percentile bootstrap confidence interval for the mean of `values`.

    returns (mean, lo, hi), where lo and hi are the alpha/2 and 1 - alpha/2
    quantiles of n_boot resampled means (each resample is drawn with
    replacement and has the same size as `values`).
    """
    rng = np.random.default_rng(seed)
    sample = np.asarray(values, dtype=float)
    n = len(sample)
    boot_means = []
    for _ in range(n_boot):
        draw = rng.choice(sample, n, replace=True)
        boot_means.append(draw.mean())
    lower, upper = np.quantile(boot_means, [alpha / 2, 1 - alpha / 2])
    return float(sample.mean()), float(lower), float(upper)
def paired_bootstrap_pvalue(a, b, n_boot: int = 2000, seed: int = 0):
    """two-sided p-value that mean(a) != mean(b) for paired samples a,b.

    the paired differences are centered on their mean to form the null
    distribution; the p-value is the add-one-smoothed share of resampled means
    whose magnitude is at least the observed one. returns (p_value, mean of a - b).
    """
    rng = np.random.default_rng(seed)
    deltas = np.asarray(a, float) - np.asarray(b, float)
    observed = deltas.mean()
    null_deltas = deltas - observed
    n = len(null_deltas)
    threshold = abs(observed)
    extreme = sum(
        1
        for _ in range(n_boot)
        if abs(rng.choice(null_deltas, n, replace=True).mean()) >= threshold
    )
    return float((extreme + 1) / (n_boot + 1)), float(observed)