"""
Perturbation Evaluation Suite for MuLGIT-Perturb.

Implements the standard perturbation benchmark metrics:
    1. DES@K           — Differential Expression Score at rank K
    2. Pearson-Δ       — Correlation between predicted and true expression deltas
    3. Direction-match — Fraction of genes with correct sign of change
    4. PDS             — Perturbation Discrimination Score
    5. RMSE/MAE        — Raw expression reconstruction error
    6. Spearman-sig    — Rank correlation of significant DE genes

All metrics follow PerturBench conventions (arxiv:2408.10609).
"""

import torch
import torch.nn.functional as F
import numpy as np
from typing import Optional, Dict, List, Tuple
from scipy.stats import pearsonr, spearmanr

try:
    # Not used by this module itself; kept importable for backward
    # compatibility without making scikit-learn a hard dependency.
    from sklearn.metrics import roc_auc_score  # noqa: F401
except ImportError:  # pragma: no cover - optional dependency
    roc_auc_score = None


# ─── Private helpers ────────────────────────────────────────────────────


def _to_numpy(t: torch.Tensor) -> np.ndarray:
    """Detach ``t`` from autograd and return it as a NumPy array on the CPU."""
    return t.detach().cpu().numpy()


def _safe_corr(corr_fn, x: np.ndarray, y: np.ndarray) -> float:
    """Return the statistic of ``corr_fn(x, y)``, or 0.0 when it is undefined.

    A correlation is undefined for fewer than two points or when either input
    is constant (scipy warns and returns NaN, or raises for < 2 points).
    """
    if x.size < 2 or np.all(x == x[0]) or np.all(y == y[0]):
        return 0.0
    stat, _ = corr_fn(x, y)
    return float(stat)


def _group_masks(perturbation_ids: List[str], n_rows: int) -> Dict[str, torch.Tensor]:
    """Map each perturbation id (first-appearance order) to a boolean row mask.

    Raises:
        ValueError: if ``perturbation_ids`` does not have one entry per row.
    """
    if len(perturbation_ids) != n_rows:
        raise ValueError(
            f"perturbation_ids has {len(perturbation_ids)} entries but the "
            f"inputs have {n_rows} rows"
        )
    return {
        pert: torch.tensor([p == pert for p in perturbation_ids], dtype=torch.bool)
        for pert in dict.fromkeys(perturbation_ids)
    }


def _as_single_profile(t: torch.Tensor, name: str) -> torch.Tensor:
    """Return a detached 1-D view of a (G,) or (1, G) tensor.

    Raises:
        ValueError: if ``t`` holds more than one sample.
    """
    if t.dim() == 2 and t.shape[0] == 1:
        t = t.squeeze(0)
    if t.dim() != 1:
        raise ValueError(
            f"{name} must be (G,) or (1, G) for a single sample; got {tuple(t.shape)}"
        )
    return t.detach()


class PerturbationEvaluator:
    """
    Evaluates perturbation response predictions against ground truth.

    Usage:
        evaluator = PerturbationEvaluator()
        metrics = evaluator.evaluate(delta_pred, delta_true)
        print(metrics)
    """

    def __init__(self, n_top_genes: Optional[List[int]] = None):
        """
        Args:
            n_top_genes: K values for the DES@K metric. ``None`` (or an empty
                list) selects the defaults [20, 50, 100, 200].
        """
        self.n_top_genes = n_top_genes or [20, 50, 100, 200]

    def evaluate(
        self,
        delta_pred: torch.Tensor,
        delta_true: torch.Tensor,
        sigma2_pred: Optional[torch.Tensor] = None,
    ) -> Dict[str, float]:
        """
        Compute all evaluation metrics.

        Args:
            delta_pred: (B, G) or (G,) predicted expression change
            delta_true: (B, G) or (G,) true expression change; must match the
                shape of ``delta_pred`` once both are promoted to 2-D
            sigma2_pred: (B, G) or (G,) predicted variance (optional, for calibration)

        Returns:
            metrics: dict of metric name → value. Per-sample metrics are
            averaged over the batch: ``pearson_delta`` (+ ``_std``),
            ``spearman_sig``, ``direction_match``, ``des@K`` (+ ``_std``) for
            each K in ``n_top_genes``, ``rmse`` and ``mae``. When
            ``sigma2_pred`` is given, calibration metrics are added.

        Raises:
            ValueError: if the inputs are not 1-D/2-D or their shapes differ.
        """
        metrics: Dict[str, float] = {}

        # Ensure 2D: a single (G,) profile is treated as a batch of one.
        if delta_pred.dim() == 1:
            delta_pred = delta_pred.unsqueeze(0)
        if delta_true.dim() == 1:
            delta_true = delta_true.unsqueeze(0)
        if delta_pred.dim() != 2 or delta_pred.shape != delta_true.shape:
            # Previously a batch-size mismatch silently compared the wrong rows.
            raise ValueError(
                "delta_pred and delta_true must both be (B, G) or (G,) with "
                f"matching shapes; got {tuple(delta_pred.shape)} and "
                f"{tuple(delta_true.shape)}"
            )

        # Detach so autograd-tracked model outputs can be converted to NumPy.
        pred_np = _to_numpy(delta_pred)
        true_np = _to_numpy(delta_true)
        n_genes = pred_np.shape[1]

        # ── Per-sample metrics ─────────────────────────────────────
        pearson_deltas: List[float] = []
        spearman_sigs: List[float] = []
        direction_matches: List[float] = []
        des_scores: Dict[int, List[float]] = {k: [] for k in self.n_top_genes}
        rmses: List[float] = []
        maes: List[float] = []

        for dp, dt in zip(pred_np, true_np):
            abs_dt = np.abs(dt)

            # Pearson-Δ: correlation of predicted vs true deltas
            # (0.0 when undefined, e.g. a constant prediction).
            pearson_deltas.append(_safe_corr(pearsonr, dp, dt))

            # Spearman-sig: rank correlation on "significant" genes,
            # taken as those above the 90th percentile of |true Δ|.
            sig_mask = abs_dt > np.percentile(abs_dt, 90)
            if sig_mask.sum() >= 3:
                spearman_sigs.append(_safe_corr(spearmanr, dp[sig_mask], dt[sig_mask]))
            else:
                spearman_sigs.append(0.0)

            # Direction-match: fraction of genes with the correct sign,
            # ignoring genes whose true change is exactly zero.
            nonzero = dt != 0
            if nonzero.any():
                direction_matches.append(
                    float(np.mean(np.sign(dp[nonzero]) == np.sign(dt[nonzero])))
                )
            else:
                direction_matches.append(0.5)

            # DES@K: fraction of the true top-K DEGs (by |Δ|) recovered in the
            # predicted top-K. K is capped at G so that a perfect prediction
            # scores 1.0 even when fewer than K genes exist.
            true_rank = np.argsort(abs_dt)[::-1]
            pred_rank = np.argsort(np.abs(dp))[::-1]
            for k in self.n_top_genes:
                k_eff = min(k, n_genes)
                if k_eff <= 0:
                    des_scores[k].append(0.0)
                    continue
                overlap = len(set(pred_rank[:k_eff].tolist()) & set(true_rank[:k_eff].tolist()))
                des_scores[k].append(overlap / k_eff)

            # RMSE / MAE
            diff = dp - dt
            rmses.append(np.sqrt(np.mean(diff ** 2)))
            maes.append(np.mean(np.abs(diff)))

        # ── Aggregate ─────────────────────────────────────────────
        metrics["pearson_delta"] = float(np.mean(pearson_deltas))
        metrics["pearson_delta_std"] = float(np.std(pearson_deltas))
        metrics["spearman_sig"] = float(np.mean(spearman_sigs))
        metrics["direction_match"] = float(np.mean(direction_matches))
        for k, scores in des_scores.items():
            metrics[f"des@{k}"] = float(np.mean(scores))
            metrics[f"des@{k}_std"] = float(np.std(scores))
        metrics["rmse"] = float(np.mean(rmses))
        metrics["mae"] = float(np.mean(maes))

        # ── Uncertainty calibration (if σ² provided) ──────────────
        if sigma2_pred is not None:
            metrics.update(self._evaluate_calibration(delta_pred, delta_true, sigma2_pred))

        return metrics

    def _evaluate_calibration(
        self,
        delta_pred: torch.Tensor,
        delta_true: torch.Tensor,
        sigma2_pred: torch.Tensor,
    ) -> Dict[str, float]:
        """
        Evaluate uncertainty calibration.

        For a well-calibrated Gaussian model, ~68% of true values should fall
        within ±1σ of the mean, ~95% within ±2σ and ~99.7% within ±3σ.

        Returns:
            ``calibration_{n}sigma`` (observed fraction) and
            ``calibration_error_{n}sigma`` (|observed − expected|) for
            n ∈ {1, 2, 3}, plus ``avg_calibration_error`` and
            ``mean_predicted_std``.
        """
        dp = _to_numpy(delta_pred)
        dt = _to_numpy(delta_true)
        # Clamp before sqrt so zero/negative variances do not yield NaN.
        sp = _to_numpy(torch.sqrt(sigma2_pred.clamp(min=1e-6)))

        # Z-scores: (true - pred) / sigma
        z_scores = (dt - dp) / (sp + 1e-6)
        abs_z = np.abs(z_scores)

        # Expected two-sided Gaussian coverage for ±nσ.
        expected = {1: 0.6827, 2: 0.9545, 3: 0.9973}

        metrics: Dict[str, float] = {}
        for n_sigma, expected_fraction in expected.items():
            observed_fraction = float(np.mean(abs_z <= n_sigma))
            metrics[f"calibration_{n_sigma}sigma"] = observed_fraction
            metrics[f"calibration_error_{n_sigma}sigma"] = abs(observed_fraction - expected_fraction)

        # Average calibration error
        metrics["avg_calibration_error"] = float(np.mean([
            metrics[f"calibration_error_{n}sigma"] for n in expected
        ]))

        # Mean predicted uncertainty
        metrics["mean_predicted_std"] = float(np.mean(sp))

        return metrics

    def evaluate_per_perturbation(
        self,
        delta_pred: torch.Tensor,
        delta_true: torch.Tensor,
        perturbation_ids: List[str],
    ) -> Dict[str, Dict[str, float]]:
        """
        Evaluate metrics separately for each perturbation.

        Args:
            delta_pred: (N, G) predicted deltas
            delta_true: (N, G) true deltas
            perturbation_ids: list of perturbation identifiers (length N)

        Returns:
            per_pert_metrics: {pert_id: {metric: value}}, ordered by first
            appearance in ``perturbation_ids``. Perturbations with fewer than
            two samples are skipped.

        Raises:
            ValueError: if ``perturbation_ids`` does not have length N.
        """
        masks = _group_masks(perturbation_ids, delta_pred.shape[0])
        per_pert_metrics: Dict[str, Dict[str, float]] = {}

        for pert_id, mask in masks.items():
            if mask.sum() < 2:
                continue
            per_pert_metrics[pert_id] = self.evaluate(delta_pred[mask], delta_true[mask])

        return per_pert_metrics

    def pds(
        self,
        delta_pred: torch.Tensor,
        delta_true: torch.Tensor,
        perturbation_ids: List[str],
    ) -> float:
        """
        Perturbation Discrimination Score (PDS).

        For each ordered pair of perturbations (A, B), the model's mean
        prediction for A should be closer (in MSE) to A's mean true pattern
        than to B's.

        PDS = P(distance(pred_A, true_A) < distance(pred_A, true_B))

        1.0 = perfect discrimination, 0.5 = random. Returns 1.0 when fewer
        than two perturbations, or no eligible pair (each perturbation needs
        at least two samples), are available.

        Reference: PerturBench (arxiv:2408.10609)

        Raises:
            ValueError: if ``perturbation_ids`` does not have one entry per row.
        """
        masks = _group_masks(perturbation_ids, delta_pred.shape[0])
        if len(masks) < 2:
            return 1.0

        # Only perturbations with at least two samples take part. Group means
        # are computed once instead of once per pair.
        eligible = [p for p, m in masks.items() if m.sum() >= 2]
        pred_mean = {p: delta_pred[masks[p]].mean(dim=0) for p in eligible}
        true_mean = {p: delta_true[masks[p]].mean(dim=0) for p in eligible}

        correct = 0
        total = 0
        for pert_a in eligible:
            dist_to_a = F.mse_loss(pred_mean[pert_a], true_mean[pert_a]).item()
            for pert_b in eligible:
                if pert_b == pert_a:
                    continue
                # Is pred_A strictly closer to true_A than to true_B?
                dist_to_b = F.mse_loss(pred_mean[pert_a], true_mean[pert_b]).item()
                if dist_to_a < dist_to_b:
                    correct += 1
                total += 1

        if total == 0:
            return 1.0
        return correct / total

    def evaluate_sample(
        self,
        delta_pred: torch.Tensor,
        delta_true: torch.Tensor,
    ) -> Dict:
        """
        Detailed evaluation of a single sample, including per-gene rankings.

        Args:
            delta_pred: (G,) or (1, G) predicted expression change
            delta_true: (G,) or (1, G) true expression change

        Returns:
            dict with the ``evaluate`` metrics plus ``top10_true_genes`` and
            ``top10_pred_genes`` (gene indices sorted by descending |Δ|) and
            ``top10_overlap`` (size of their intersection) for downstream
            analysis.

        Raises:
            ValueError: if either input holds more than one sample.
        """
        dp = _as_single_profile(delta_pred, "delta_pred")
        dt = _as_single_profile(delta_true, "delta_true")

        metrics = self.evaluate(dp.unsqueeze(0), dt.unsqueeze(0))

        # Top genes by magnitude of change, truth vs prediction.
        true_top10 = dt.abs().argsort(descending=True)[:10].tolist()
        pred_top10 = dp.abs().argsort(descending=True)[:10].tolist()

        metrics["top10_true_genes"] = true_top10
        metrics["top10_pred_genes"] = pred_top10
        metrics["top10_overlap"] = len(set(true_top10) & set(pred_top10))

        return metrics


# ─── Convenience Functions ──────────────────────────────────────────────


def quick_evaluate(
    delta_pred: torch.Tensor,
    delta_true: torch.Tensor,
) -> Dict[str, float]:
    """Quick evaluation with default settings."""
    evaluator = PerturbationEvaluator()
    return evaluator.evaluate(delta_pred, delta_true)


def print_metrics(metrics: Dict[str, float], prefix: str = ""):
    """Pretty-print evaluation metrics.

    Keys containing "std" or "error" are omitted. Numeric values are shown
    with four decimals; anything else (e.g. the gene-index lists returned by
    ``evaluate_sample``) is printed with ``str``.
    """
    print(f"\n{prefix} Perturbation Evaluation Results")
    print("=" * 50)
    for key, value in metrics.items():
        if "std" in key or "error" in key:
            continue
        if isinstance(value, (int, float, np.number)):
            print(f"  {key:30s}: {value:.4f}")
        else:
            print(f"  {key:30s}: {value}")
    print("=" * 50)