# MuLGIT / mulgit / perturb / evaluate.py
# vedatonuryilmaz's picture
# Upload mulgit/perturb/evaluate.py
# eeab3e6 verified
"""
Perturbation Evaluation Suite for MuLGIT-Perturb.
Implements the standard perturbation benchmark metrics:
1. DES@K — Differential Expression Score at rank K
2. Pearson-Δ — Correlation between predicted and true expression deltas
3. Direction-match β€” Fraction of genes with correct sign of change
4. PDS β€” Perturbation Discrimination Score
5. RMSE/MAE β€” Raw expression reconstruction error
6. Spearman-sig β€” Rank correlation of significant DE genes
All metrics follow PerturBench conventions (arxiv:2408.10609).
"""
import torch
import torch.nn.functional as F
import numpy as np
from typing import Optional, Dict, List, Tuple
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import roc_auc_score
class PerturbationEvaluator:
    """
    Evaluates perturbation response predictions against ground truth.

    Usage:
        evaluator = PerturbationEvaluator()
        metrics = evaluator.evaluate(delta_pred, delta_true)
        print(metrics)
    """

    def __init__(self, n_top_genes: Optional[List[int]] = None):
        """
        Args:
            n_top_genes: values of K used for DES@K. Falls back to
                [20, 50, 100, 200] when None or empty.

        Raises:
            ValueError: if any K is not strictly positive.
        """
        # Copy so later mutation of the caller's list cannot change the
        # evaluator's configuration.
        ks = list(n_top_genes) if n_top_genes else [20, 50, 100, 200]
        if any(k <= 0 for k in ks):
            raise ValueError(f"n_top_genes must all be positive, got {ks}")
        self.n_top_genes = ks

    # ── Internal helpers ───────────────────────────────────────

    @staticmethod
    def _to_numpy(tensor: torch.Tensor) -> np.ndarray:
        """Convert a tensor to a NumPy array (detached from autograd, on CPU)."""
        # .detach() is required: .numpy() raises on tensors that require grad.
        return tensor.detach().cpu().numpy()

    @staticmethod
    def _safe_corr(corr_fn, x: np.ndarray, y: np.ndarray) -> float:
        """
        Apply a scipy correlation function, returning 0.0 when the
        correlation is undefined (fewer than 2 points or a constant input).

        scipy returns NaN (or raises) in those cases, which would otherwise
        turn the aggregated mean into NaN.
        """
        if x.size < 2 or np.ptp(x) == 0 or np.ptp(y) == 0:
            return 0.0
        r, _ = corr_fn(x, y)
        return float(r)

    @staticmethod
    def _group_rows(perturbation_ids: List[str], n_rows: int) -> Dict[str, torch.Tensor]:
        """
        Map each perturbation id to the row indices where it occurs.

        Ids are kept in first-seen order so results are reproducible
        (iterating a set of strings is not stable across interpreter runs).

        Raises:
            ValueError: if len(perturbation_ids) != n_rows.
        """
        if len(perturbation_ids) != n_rows:
            raise ValueError(
                f"perturbation_ids has length {len(perturbation_ids)} "
                f"but the delta tensors have {n_rows} rows"
            )
        buckets: Dict[str, List[int]] = {}
        for row, pert_id in enumerate(perturbation_ids):
            buckets.setdefault(pert_id, []).append(row)
        return {
            pert_id: torch.as_tensor(rows, dtype=torch.long)
            for pert_id, rows in buckets.items()
        }

    # ── Public API ─────────────────────────────────────────────

    def evaluate(
        self,
        delta_pred: torch.Tensor,
        delta_true: torch.Tensor,
        sigma2_pred: Optional[torch.Tensor] = None,
    ) -> Dict[str, float]:
        """
        Compute all evaluation metrics.

        Args:
            delta_pred: (B, G) or (G,) predicted expression change
            delta_true: (B, G) or (G,) true expression change
            sigma2_pred: (B, G) or (G,) predicted variance (optional, for calibration)

        Returns:
            metrics: dict of metric name → value. Per-sample metrics are
            averaged over the batch; ``*_std`` entries hold the batch
            standard deviation.

        Raises:
            ValueError: if the inputs are not 1D/2D, have different shapes,
                or contain zero genes.
        """
        metrics = {}

        # Ensure 2D
        if delta_pred.dim() == 1:
            delta_pred = delta_pred.unsqueeze(0)
        if delta_true.dim() == 1:
            delta_true = delta_true.unsqueeze(0)
        if delta_pred.dim() != 2:
            raise ValueError(
                f"delta_pred must be 1D or 2D, got shape {tuple(delta_pred.shape)}"
            )
        if delta_pred.shape != delta_true.shape:
            raise ValueError(
                f"shape mismatch: delta_pred {tuple(delta_pred.shape)} "
                f"vs delta_true {tuple(delta_true.shape)}"
            )
        _, n_genes = delta_pred.shape
        if n_genes == 0:
            raise ValueError("cannot evaluate tensors with zero genes")

        pred_np = self._to_numpy(delta_pred)
        true_np = self._to_numpy(delta_true)

        # ── Per-sample metrics ─────────────────────────────────────
        pearson_deltas = []
        spearman_sigs = []
        direction_matches = []
        des_scores = {k: [] for k in self.n_top_genes}
        rmses = []
        maes = []

        for dp, dt in zip(pred_np, true_np):
            abs_true = np.abs(dt)

            # Pearson-Δ: correlation of predicted vs true deltas
            pearson_deltas.append(self._safe_corr(pearsonr, dp, dt))

            # Spearman-sig: rank correlation on "significant" genes
            # (top 10% by absolute true change)
            sig_mask = abs_true > np.percentile(abs_true, 90)
            if sig_mask.sum() >= 3:
                sr = self._safe_corr(spearmanr, dp[sig_mask], dt[sig_mask])
            else:
                sr = 0.0
            spearman_sigs.append(sr)

            # Direction-match: fraction of genes with correct sign.
            # Genes with zero true change are not counted.
            nonzero_mask = dt != 0
            if nonzero_mask.any():
                dm = (np.sign(dp) == np.sign(dt))[nonzero_mask].mean()
            else:
                dm = 0.5
            direction_matches.append(dm)

            # DES@K: fraction of true top-K DEGs recovered in predicted top-K.
            # Keep both rankings as ordered arrays; a set cannot be sliced.
            true_order = np.argsort(abs_true)[::-1]
            pred_order = np.argsort(np.abs(dp))[::-1]
            for k in self.n_top_genes:
                # Cap K at the gene count so a perfect prediction scores 1.0
                # even when fewer than K genes are available.
                k_eff = min(k, n_genes)
                overlap = len(
                    set(pred_order[:k_eff].tolist()) & set(true_order[:k_eff].tolist())
                )
                des_scores[k].append(overlap / k_eff)

            # RMSE / MAE
            residual = dp - dt
            rmses.append(np.sqrt(np.mean(residual ** 2)))
            maes.append(np.mean(np.abs(residual)))

        # ── Aggregate ─────────────────────────────────────────────
        metrics["pearson_delta"] = float(np.mean(pearson_deltas))
        metrics["pearson_delta_std"] = float(np.std(pearson_deltas))
        metrics["spearman_sig"] = float(np.mean(spearman_sigs))
        metrics["direction_match"] = float(np.mean(direction_matches))
        for k, scores in des_scores.items():
            metrics[f"des@{k}"] = float(np.mean(scores))
            metrics[f"des@{k}_std"] = float(np.std(scores))
        metrics["rmse"] = float(np.mean(rmses))
        metrics["mae"] = float(np.mean(maes))

        # ── Uncertainty calibration (if σ² provided) ──────────────
        if sigma2_pred is not None:
            metrics.update(self._evaluate_calibration(delta_pred, delta_true, sigma2_pred))

        return metrics

    def _evaluate_calibration(
        self,
        delta_pred: torch.Tensor,
        delta_true: torch.Tensor,
        sigma2_pred: torch.Tensor,
    ) -> Dict[str, float]:
        """
        Evaluate uncertainty calibration.

        For a well-calibrated model, ~68% of true values should fall within
        ±1σ of the mean, ~95% within ±2σ.

        Returns:
            dict with the observed coverage and absolute coverage error at
            1/2/3 sigma, their average error, and the mean predicted std.
        """
        dp = self._to_numpy(delta_pred)
        dt = self._to_numpy(delta_true)
        # Variance is clamped before the square root to keep sigma positive.
        sp = self._to_numpy(torch.sqrt(sigma2_pred.clamp(min=1e-6)))

        # Z-scores: (true - pred) / sigma
        z_scores = (dt - dp) / (sp + 1e-6)
        abs_z = np.abs(z_scores)

        # Expected two-sided coverage of a standard normal at n sigma.
        expected_coverage = {1: 0.6827, 2: 0.9545, 3: 0.9973}

        # Calibration metrics
        metrics = {}
        for n_sigma, expected_fraction in expected_coverage.items():
            observed_fraction = float(np.mean(abs_z <= n_sigma))
            metrics[f"calibration_{n_sigma}sigma"] = observed_fraction
            metrics[f"calibration_error_{n_sigma}sigma"] = abs(
                observed_fraction - expected_fraction
            )

        # Average calibration error
        metrics["avg_calibration_error"] = float(np.mean([
            metrics[f"calibration_error_{n}sigma"] for n in expected_coverage
        ]))

        # Mean predicted uncertainty
        metrics["mean_predicted_std"] = float(np.mean(sp))
        return metrics

    def evaluate_per_perturbation(
        self,
        delta_pred: torch.Tensor,
        delta_true: torch.Tensor,
        perturbation_ids: List[str],
    ) -> Dict[str, Dict[str, float]]:
        """
        Evaluate metrics separately for each perturbation.

        Perturbations with fewer than 2 samples are skipped.

        Args:
            delta_pred: (N, G) predicted deltas
            delta_true: (N, G) true deltas
            perturbation_ids: list of perturbation identifiers (length N)

        Returns:
            per_pert_metrics: {pert_id: {metric: value}}, keyed in the order
            each perturbation first appears in ``perturbation_ids``.

        Raises:
            ValueError: if len(perturbation_ids) != N.
        """
        groups = self._group_rows(perturbation_ids, delta_pred.shape[0])
        per_pert_metrics = {}
        for pert_id, rows in groups.items():
            if rows.numel() < 2:
                continue
            per_pert_metrics[pert_id] = self.evaluate(delta_pred[rows], delta_true[rows])
        return per_pert_metrics

    def pds(
        self,
        delta_pred: torch.Tensor,
        delta_true: torch.Tensor,
        perturbation_ids: List[str],
    ) -> float:
        """
        Perturbation Discrimination Score (PDS).

        For each ordered pair of perturbations (A, B), the mean prediction
        for A should be closer (in MSE) to A's mean true pattern than to
        B's mean true pattern.

        PDS = P(distance(pred_A, true_A) < distance(pred_A, true_B))
        1.0 = perfect discrimination, 0.5 = random.

        Perturbations with fewer than 2 samples are ignored. Returns 1.0
        when no pair can be compared.

        Reference: PerturBench (arxiv:2408.10609)

        Raises:
            ValueError: if len(perturbation_ids) != number of rows.
        """
        groups = self._group_rows(perturbation_ids, delta_pred.shape[0])
        if len(groups) < 2:
            return 1.0

        # Compute each perturbation's mean prediction and mean truth once,
        # rather than rebuilding them for every (A, B) pair.
        mean_pred = {}
        mean_true = {}
        for pert_id, rows in groups.items():
            if rows.numel() < 2:
                continue
            mean_pred[pert_id] = delta_pred[rows].mean(dim=0)
            mean_true[pert_id] = delta_true[rows].mean(dim=0)

        correct = 0
        total = 0
        for pert_a, pred_a in mean_pred.items():
            dist_to_a = F.mse_loss(pred_a, mean_true[pert_a]).item()
            for pert_b, true_b in mean_true.items():
                if pert_a == pert_b:
                    continue
                # Compare: is pred_A closer to true_A than to true_B?
                dist_to_b = F.mse_loss(pred_a, true_b).item()
                if dist_to_a < dist_to_b:
                    correct += 1
                total += 1

        if total == 0:
            return 1.0
        return correct / total

    def evaluate_sample(
        self,
        delta_pred: torch.Tensor,
        delta_true: torch.Tensor,
    ) -> Dict:
        """
        Detailed evaluation of a single sample, including gene rankings.

        Args:
            delta_pred: (G,) or (1, G) predicted expression change
            delta_true: (G,) or (1, G) true expression change

        Returns:
            dict with the metrics from ``evaluate`` plus
            ``top10_true_genes`` / ``top10_pred_genes`` (gene indices ranked
            by absolute change) and ``top10_overlap`` (size of their
            intersection).

        Raises:
            ValueError: if an input is not a single sample.
        """
        # Accept either a flat vector or a batch of exactly one sample.
        dp = delta_pred.squeeze(0) if delta_pred.dim() == 2 else delta_pred
        dt = delta_true.squeeze(0) if delta_true.dim() == 2 else delta_true
        if dp.dim() != 1 or dt.dim() != 1:
            raise ValueError(
                "evaluate_sample expects a single sample of shape (G,) or (1, G); "
                f"got {tuple(delta_pred.shape)} and {tuple(delta_true.shape)}"
            )

        metrics = self.evaluate(dp, dt)

        # Gene indices ranked by absolute change, largest first
        true_top10 = dt.abs().argsort(descending=True)[:10].tolist()
        pred_top10 = dp.abs().argsort(descending=True)[:10].tolist()

        metrics["top10_true_genes"] = true_top10
        metrics["top10_pred_genes"] = pred_top10
        metrics["top10_overlap"] = len(set(true_top10) & set(pred_top10))
        return metrics
# ─── Convenience Functions ──────────────────────────────────────────────
def quick_evaluate(
    delta_pred: torch.Tensor,
    delta_true: torch.Tensor,
) -> Dict[str, float]:
    """
    Run the full metric suite using an evaluator with default settings.

    Args:
        delta_pred: predicted expression change, forwarded to ``evaluate``
        delta_true: true expression change, forwarded to ``evaluate``

    Returns:
        The metrics dict produced by ``PerturbationEvaluator.evaluate``.
    """
    return PerturbationEvaluator().evaluate(delta_pred, delta_true)
def print_metrics(metrics: Dict[str, float], prefix: str = ""):
    """
    Pretty-print evaluation metrics.

    Entries whose key contains "std" or "error" are omitted. Numeric values
    are shown with four decimals; values that do not support float
    formatting (for example lists of gene indices) are printed with
    ``str()`` instead of raising.

    Args:
        metrics: mapping of metric name to value
        prefix: text placed before the report title
    """
    rule = "=" * 50
    print(f"\n{prefix} Perturbation Evaluation Results")
    print(rule)
    for key, value in metrics.items():
        if "std" in key or "error" in key:
            continue
        try:
            rendered = f"{value:.4f}"
        except (TypeError, ValueError):
            # Non-numeric entry: fall back to its plain string form.
            rendered = str(value)
        print(f" {key:30s}: {rendered}")
    print(rule)