| """ |
| Perturbation Evaluation Suite for MuLGIT-Perturb. |
| |
| Implements the standard perturbation benchmark metrics: |
| 1. DES@K — Differential Expression Score at rank K |
| 2. Pearson-Δ — Correlation between predicted and true expression deltas |
| 3. Direction-match — Fraction of genes with correct sign of change |
| 4. PDS — Perturbation Discrimination Score |
| 5. RMSE/MAE — Raw expression reconstruction error |
| 6. Spearman-sig — Rank correlation of significant DE genes |
| |
| All metrics follow PerturBench conventions (arxiv:2408.10609). |
| """ |
|
|
| import torch |
| import torch.nn.functional as F |
| import numpy as np |
| from typing import Optional, Dict, List, Tuple |
| from scipy.stats import pearsonr, spearmanr |
| from sklearn.metrics import roc_auc_score |
|
|
|
|
class PerturbationEvaluator:
    """
    Evaluates perturbation response predictions against ground truth.

    Usage:
        evaluator = PerturbationEvaluator()
        metrics = evaluator.evaluate(delta_pred, delta_true)
        print(metrics)
    """

    def __init__(self, n_top_genes: Optional[List[int]] = None):
        """
        Args:
            n_top_genes: rank cut-offs K used for the DES@K metrics. A falsy
                value (None or an empty list) selects [20, 50, 100, 200].
        """
        self.n_top_genes = n_top_genes or [20, 50, 100, 200]

    @staticmethod
    def _safe_correlation(corr_fn, x: np.ndarray, y: np.ndarray) -> float:
        """Return corr_fn(x, y)[0], or 0.0 when the correlation is undefined."""
        # A constant vector (e.g. an all-zero baseline prediction) has zero
        # variance, so scipy would return NaN and poison the batch mean.
        # Fall back to 0.0, the same neutral value used when there are too
        # few significant genes for the Spearman metric.
        if x.size < 2 or np.ptp(x) == 0 or np.ptp(y) == 0:
            return 0.0
        value, _ = corr_fn(x, y)
        return float(value)

    @staticmethod
    def _group_indices(perturbation_ids: List[str], n_rows: int) -> Dict[str, List[int]]:
        """Map each perturbation id to its row indices, in first-seen order."""
        if len(perturbation_ids) != n_rows:
            raise ValueError(
                f"perturbation_ids has length {len(perturbation_ids)} "
                f"but the delta tensors have {n_rows} rows"
            )
        groups: Dict[str, List[int]] = {}
        for row, pert_id in enumerate(perturbation_ids):
            groups.setdefault(pert_id, []).append(row)
        return groups

    def evaluate(
        self,
        delta_pred: torch.Tensor,
        delta_true: torch.Tensor,
        sigma2_pred: Optional[torch.Tensor] = None,
    ) -> Dict[str, float]:
        """
        Compute all evaluation metrics.

        Per-sample metrics are computed row by row and then averaged over
        the batch.

        Args:
            delta_pred: (B, G) or (G,) predicted expression change
            delta_true: (B, G) or (G,) true expression change
            sigma2_pred: (B, G) or (G,) predicted variance (optional, for calibration)

        Returns:
            metrics: dict of metric name -> value

        Raises:
            ValueError: if delta_pred and delta_true differ in shape, or are
                not 1-D/2-D.
        """
        metrics: Dict[str, float] = {}

        # Promote single samples to a batch of one.
        if delta_pred.dim() == 1:
            delta_pred = delta_pred.unsqueeze(0)
        if delta_true.dim() == 1:
            delta_true = delta_true.unsqueeze(0)

        if delta_pred.shape != delta_true.shape:
            raise ValueError(
                f"delta_pred shape {tuple(delta_pred.shape)} does not match "
                f"delta_true shape {tuple(delta_true.shape)}"
            )

        _, n_genes = delta_pred.shape

        # detach() so tensors that still carry gradients can be converted.
        pred_np = delta_pred.detach().cpu().numpy()
        true_np = delta_true.detach().cpu().numpy()

        pearson_deltas = []
        spearman_sigs = []
        direction_matches = []
        des_scores = {k: [] for k in self.n_top_genes}
        rmses = []
        maes = []

        for dp, dt in zip(pred_np, true_np):
            # Pearson correlation of predicted vs. true deltas.
            pearson_deltas.append(self._safe_correlation(pearsonr, dp, dt))

            # Spearman correlation restricted to "significant" genes: those
            # whose true |delta| exceeds the sample's 90th percentile.
            abs_true = np.abs(dt)
            sig_mask = abs_true > np.percentile(abs_true, 90)
            if sig_mask.sum() >= 3:
                sr = self._safe_correlation(spearmanr, dp[sig_mask], dt[sig_mask])
            else:
                sr = 0.0
            spearman_sigs.append(sr)

            # Direction match, scored only on genes that actually changed.
            sign_match = (np.sign(dp) == np.sign(dt)).astype(float)
            nonzero_mask = dt != 0
            if nonzero_mask.sum() > 0:
                dm = sign_match[nonzero_mask].mean()
            else:
                dm = 0.5  # no changed genes: report chance level
            direction_matches.append(dm)

            # DES@K: overlap between the top-K genes by |delta|. The ranked
            # index arrays are kept as arrays so they can be sliced per K.
            true_order = np.argsort(abs_true)[::-1]
            pred_order = np.argsort(np.abs(dp))[::-1]
            for k in self.n_top_genes:
                # Cap K at the number of genes so that a perfect prediction
                # still scores 1.0 when K exceeds the gene count.
                k_eff = min(k, n_genes)
                if k_eff <= 0:
                    des_scores[k].append(0.0)
                    continue
                overlap = len(
                    set(pred_order[:k_eff].tolist()) & set(true_order[:k_eff].tolist())
                )
                des_scores[k].append(overlap / k_eff)

            # Raw reconstruction error.
            residual = dp - dt
            rmses.append(np.sqrt(np.mean(residual ** 2)))
            maes.append(np.mean(np.abs(residual)))

        # Aggregate over the batch.
        metrics["pearson_delta"] = float(np.mean(pearson_deltas))
        metrics["pearson_delta_std"] = float(np.std(pearson_deltas))
        metrics["spearman_sig"] = float(np.mean(spearman_sigs))
        metrics["direction_match"] = float(np.mean(direction_matches))

        for k, scores in des_scores.items():
            metrics[f"des@{k}"] = float(np.mean(scores))
            metrics[f"des@{k}_std"] = float(np.std(scores))

        metrics["rmse"] = float(np.mean(rmses))
        metrics["mae"] = float(np.mean(maes))

        # Optional uncertainty calibration metrics.
        if sigma2_pred is not None:
            metrics.update(self._evaluate_calibration(delta_pred, delta_true, sigma2_pred))

        return metrics

    def _evaluate_calibration(
        self,
        delta_pred: torch.Tensor,
        delta_true: torch.Tensor,
        sigma2_pred: torch.Tensor,
    ) -> Dict[str, float]:
        """
        Evaluate uncertainty calibration.

        For a well-calibrated model, ~68% of true values should fall within
        ±1σ of the mean, ~95% within ±2σ.

        Returns:
            dict with the observed coverage and absolute coverage error at
            1, 2 and 3 sigma, their average error, and the mean predicted
            standard deviation.
        """
        dp = delta_pred.detach().cpu().numpy()
        dt = delta_true.detach().cpu().numpy()
        # Clamp the variance before the square root so it is never negative.
        sp = torch.sqrt(sigma2_pred.detach().clamp(min=1e-6)).cpu().numpy()

        # Standardised residuals.
        z_scores = (dt - dp) / (sp + 1e-6)

        # Expected coverage of a Gaussian at +/- n sigma.
        expected_coverage = {1: 0.6827, 2: 0.9545, 3: 0.9973}

        metrics: Dict[str, float] = {}
        for n_sigma, expected_fraction in expected_coverage.items():
            observed_fraction = float(np.mean(np.abs(z_scores) <= n_sigma))
            metrics[f"calibration_{n_sigma}sigma"] = observed_fraction
            metrics[f"calibration_error_{n_sigma}sigma"] = abs(
                observed_fraction - expected_fraction
            )

        metrics["avg_calibration_error"] = float(np.mean([
            metrics[f"calibration_error_{n}sigma"] for n in expected_coverage
        ]))

        metrics["mean_predicted_std"] = float(np.mean(sp))

        return metrics

    def evaluate_per_perturbation(
        self,
        delta_pred: torch.Tensor,
        delta_true: torch.Tensor,
        perturbation_ids: List[str],
    ) -> Dict[str, Dict[str, float]]:
        """
        Evaluate metrics separately for each perturbation.

        Perturbations with fewer than two samples are skipped.

        Args:
            delta_pred: (N, G) predicted deltas
            delta_true: (N, G) true deltas
            perturbation_ids: list of perturbation identifiers (length N)

        Returns:
            per_pert_metrics: {pert_id: {metric: value}}, in first-seen order

        Raises:
            ValueError: if len(perturbation_ids) != N.
        """
        groups = self._group_indices(perturbation_ids, delta_pred.shape[0])
        per_pert_metrics: Dict[str, Dict[str, float]] = {}

        for pert_id, rows in groups.items():
            if len(rows) < 2:
                continue
            dp = delta_pred[torch.tensor(rows, dtype=torch.long, device=delta_pred.device)]
            dt = delta_true[torch.tensor(rows, dtype=torch.long, device=delta_true.device)]
            per_pert_metrics[pert_id] = self.evaluate(dp, dt)

        return per_pert_metrics

    def pds(
        self,
        delta_pred: torch.Tensor,
        delta_true: torch.Tensor,
        perturbation_ids: List[str],
    ) -> float:
        """
        Perturbation Discrimination Score (PDS).

        For each pair of perturbations (A, B), the model should predict
        patterns that are more similar to A's true pattern than B's true pattern.

        PDS = P(distance(pred_A, true_A) < distance(pred_A, true_B))

        Distances are mean squared errors between per-perturbation mean
        profiles. Perturbations with fewer than two samples are ignored, and
        a tie counts as a failure to discriminate.

        1.0 = perfect discrimination, 0.5 = random.

        Reference: PerturBench (arxiv:2408.10609)

        Returns:
            Fraction of ordered pairs (A, B) discriminated correctly, or 1.0
            when there are no pairs to compare.

        Raises:
            ValueError: if len(perturbation_ids) differs from the row count.
        """
        groups = self._group_indices(perturbation_ids, delta_pred.shape[0])
        if len(groups) < 2:
            return 1.0

        # Compute every eligible perturbation's mean profiles once instead of
        # rebuilding masks and means inside the pair loop.
        mean_profiles: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}
        for pert_id, rows in groups.items():
            if len(rows) < 2:
                continue
            pred_rows = torch.tensor(rows, dtype=torch.long, device=delta_pred.device)
            true_rows = torch.tensor(rows, dtype=torch.long, device=delta_true.device)
            mean_profiles[pert_id] = (
                delta_pred[pred_rows].mean(dim=0),
                delta_true[true_rows].mean(dim=0),
            )

        correct = 0
        total = 0
        for pert_a, (pred_a, true_a) in mean_profiles.items():
            dist_to_a = F.mse_loss(pred_a, true_a).item()
            for pert_b, (_, true_b) in mean_profiles.items():
                if pert_a == pert_b:
                    continue
                dist_to_b = F.mse_loss(pred_a, true_b).item()
                if dist_to_a < dist_to_b:
                    correct += 1
                total += 1

        if total == 0:
            return 1.0

        return correct / total

    def evaluate_sample(
        self,
        delta_pred: torch.Tensor,
        delta_true: torch.Tensor,
    ) -> Dict:
        """
        Detailed evaluation of a single sample, including per-gene rankings.

        Args:
            delta_pred: (G,) or (1, G) predicted expression change
            delta_true: (G,) or (1, G) true expression change

        Returns:
            dict with the metrics from evaluate() plus the indices of the ten
            genes with the largest true and predicted |delta| and the size of
            their overlap, for downstream analysis.

        Raises:
            ValueError: if an input is not a single sample.
        """
        # Accept both (G,) and (1, G); flatten to 1-D before delegating so
        # evaluate() never sees a 3-D tensor.
        dp = delta_pred.squeeze(0) if delta_pred.dim() == 2 else delta_pred
        dt = delta_true.squeeze(0) if delta_true.dim() == 2 else delta_true
        if dp.dim() != 1 or dt.dim() != 1:
            raise ValueError("evaluate_sample expects a single sample of shape (G,) or (1, G)")

        metrics = self.evaluate(dp, dt)

        # Gene rankings by absolute change, largest first.
        top_true = dt.detach().abs().argsort(descending=True)[:10].tolist()
        top_pred = dp.detach().abs().argsort(descending=True)[:10].tolist()

        metrics["top10_true_genes"] = top_true
        metrics["top10_pred_genes"] = top_pred
        metrics["top10_overlap"] = len(set(top_true) & set(top_pred))

        return metrics
|
|
|
|
| |
|
|
|
|
def quick_evaluate(
    delta_pred: torch.Tensor,
    delta_true: torch.Tensor,
) -> Dict[str, float]:
    """
    Run the full metric suite with a default-configured evaluator.

    Convenience wrapper: builds a PerturbationEvaluator with no arguments
    and returns the result of its evaluate() call on the given tensors.
    """
    return PerturbationEvaluator().evaluate(delta_pred, delta_true)
|
|
|
|
def print_metrics(metrics: Dict[str, float], prefix: str = "") -> None:
    """
    Pretty-print evaluation metrics.

    Entries whose key contains "std" or "error" are omitted from the
    listing. Numeric values are shown with four decimals; values that do
    not support float formatting (for example the gene-index lists added by
    PerturbationEvaluator.evaluate_sample) are printed with str() instead of
    raising.

    Args:
        metrics: mapping of metric name -> value
        prefix: text placed before the header title
    """
    print(f"\n{prefix} Perturbation Evaluation Results")
    print("=" * 50)
    for key, value in metrics.items():
        if "std" in key or "error" in key:
            continue
        try:
            rendered = f"{value:.4f}"
        except (TypeError, ValueError):
            # Lists, dicts, strings, ... cannot take a float format spec.
            rendered = str(value)
        print(f" {key:30s}: {rendered}")
    print("=" * 50)
|
|