panacea-api / src/evaluation/metrics.py
DTanzillo's picture
Upload folder using huggingface_hub
a4b5ecb verified
# Generated by Claude Code -- 2026-02-08
"""Evaluation metrics for conjunction prediction models."""
import numpy as np
from sklearn.metrics import (
average_precision_score,
roc_auc_score,
f1_score,
precision_recall_curve,
mean_absolute_error,
mean_squared_error,
classification_report,
)
def find_optimal_threshold(y_true: np.ndarray, y_prob: np.ndarray) -> tuple[float, float]:
    """Locate the probability cutoff with the highest F1 score.

    Walks the precision-recall curve and returns ``(threshold, f1)`` for
    the candidate threshold that maximizes F1.
    """
    prec, rec, thr = precision_recall_curve(y_true, y_prob)
    # sklearn emits one more (precision, recall) pair than thresholds --
    # the final sentinel point (p=1, r=0) -- so drop it before scoring.
    p, r = prec[:-1], rec[:-1]
    # The 1e-8 epsilon guards against division by zero when p + r == 0.
    f1_all = 2 * (p * r) / (p + r + 1e-8)
    winner = np.argmax(f1_all)
    return float(thr[winner]), float(f1_all[winner])
def evaluate_risk(y_true: np.ndarray, y_prob: np.ndarray, threshold: float = 0.5) -> dict:
    """
    Evaluate risk classification predictions.

    Args:
        y_true: binary ground truth labels
        y_prob: predicted probabilities
        threshold: classification threshold (used for f1_at_50, and the
            fallback reported threshold when there are no positive labels)

    Returns: dict of metrics including optimal-threshold F1.  The key set
    is identical on every call: recall_at_prec_{30,50,70} are reported as
    0.0 when there are no positive labels (previously those keys were
    simply absent, which breaks downstream aggregation across result dicts).
    """
    y_pred_fixed = (y_prob >= threshold).astype(int)
    # Degenerate-label guard, computed once and reused (the original
    # re-evaluated `y_true.sum() > 0` three separate times).
    has_positives = bool(y_true.sum() > 0)
    results = {
        # AUC-PR / AUC-ROC are undefined for single-class labels; report 0.0.
        "auc_pr": float(average_precision_score(y_true, y_prob)) if has_positives else 0.0,
        "auc_roc": float(roc_auc_score(y_true, y_prob)) if len(np.unique(y_true)) > 1 else 0.0,
        "f1_at_50": float(f1_score(y_true, y_pred_fixed, zero_division=0)),
        "n_positive": int(y_true.sum()),
        "n_total": int(len(y_true)),
        "pos_rate": float(y_true.mean()),
    }
    if has_positives:
        # Threshold that maximizes F1 along the precision-recall curve.
        opt_threshold, opt_f1 = find_optimal_threshold(y_true, y_prob)
        results["f1"] = opt_f1
        results["optimal_threshold"] = opt_threshold
        results["threshold"] = opt_threshold
        # Best achievable recall subject to a minimum precision level.
        precisions, recalls, _ = precision_recall_curve(y_true, y_prob)
        for target_precision in (0.3, 0.5, 0.7):
            mask = precisions >= target_precision
            recall = float(recalls[mask].max()) if mask.any() else 0.0
            results[f"recall_at_prec_{int(target_precision * 100)}"] = recall
    else:
        # No positives: an F1-optimal threshold search is meaningless, so
        # fall back to the fixed-threshold F1 and zero recall everywhere.
        results["f1"] = results["f1_at_50"]
        results["optimal_threshold"] = threshold
        results["threshold"] = threshold
        for target_precision in (0.3, 0.5, 0.7):
            results[f"recall_at_prec_{int(target_precision * 100)}"] = 0.0
    return results
def evaluate_miss_distance(y_true_log: np.ndarray, y_pred_log: np.ndarray) -> dict:
    """
    Evaluate miss distance regression (log-scale).

    Args:
        y_true_log: log1p(miss_distance_km) ground truth
        y_pred_log: log1p(miss_distance_km) predictions

    Returns: dict of metrics (log-scale MAE/RMSE plus km-scale mean and
    median absolute errors)
    """
    err_log = y_true_log - y_pred_log
    # Plain numpy reductions replace sklearn's mean_absolute_error /
    # mean_squared_error -- identical math for unweighted 1-D inputs,
    # one less heavyweight call on this path.
    mae_log = float(np.mean(np.abs(err_log)))
    rmse_log = float(np.sqrt(np.mean(err_log ** 2)))
    # Convert back to km (expm1 inverts log1p) for interpretable metrics;
    # compute the absolute-error vector once for both mean and median.
    abs_err_km = np.abs(np.expm1(y_true_log) - np.expm1(y_pred_log))
    return {
        "mae_log": mae_log,
        "rmse_log": rmse_log,
        "mae_km": float(np.mean(abs_err_km)),
        "median_abs_error_km": float(np.median(abs_err_km)),
    }
def full_evaluation(
    model_name: str,
    y_risk_true: np.ndarray,
    y_risk_prob: np.ndarray,
    y_miss_true_log: np.ndarray,
    y_miss_pred_log: np.ndarray,
) -> dict:
    """Run the full evaluation suite for a model.

    Scores both heads -- risk classification and miss-distance
    regression -- prints a formatted summary to stdout, and returns all
    metrics merged into a single flat dict keyed alongside the model name.
    """
    risk = evaluate_risk(y_risk_true, y_risk_prob)
    dist = evaluate_miss_distance(y_miss_true_log, y_miss_pred_log)
    bar = "=" * 60
    print(f"\n{bar}")
    print(f" {model_name}")
    print(bar)
    print(" Risk Classification:")
    print(f" AUC-PR: {risk['auc_pr']:.4f}")
    print(f" AUC-ROC: {risk['auc_roc']:.4f}")
    print(f" F1 (opt): {risk['f1']:.4f} (threshold={risk.get('optimal_threshold', 0.5):.3f})")
    print(f" F1 (0.50): {risk['f1_at_50']:.4f}")
    print(f" Positives: {risk['n_positive']}/{risk['n_total']} ({risk['pos_rate']:.1%})")
    print(" Miss Distance:")
    print(f" MAE (log): {dist['mae_log']:.4f}")
    print(f" MAE (km): {dist['mae_km']:.2f}")
    print(f" Median AE: {dist['median_abs_error_km']:.2f} km")
    print(bar)
    return {"model": model_name, **risk, **dist}