# Generated by Claude Code -- 2026-02-08
"""Evaluation metrics for conjunction prediction models."""

import numpy as np
from sklearn.metrics import (
    average_precision_score,
    roc_auc_score,
    f1_score,
    precision_recall_curve,
    mean_absolute_error,
    mean_squared_error,
    classification_report,
)


def _best_f1_threshold(
    precisions: np.ndarray, recalls: np.ndarray, thresholds: np.ndarray
) -> tuple[float, float]:
    """Return (threshold, f1) of the point maximizing F1 on a PR curve.

    precision_recall_curve returns len(thresholds) == len(precisions) - 1;
    the final (precision=1, recall=0) sentinel point has no threshold, so
    both arrays are sliced with [:-1]. The 1e-8 epsilon guards against a
    0/0 division where precision and recall are both zero.
    """
    f1_scores = 2 * (precisions[:-1] * recalls[:-1]) / (
        precisions[:-1] + recalls[:-1] + 1e-8
    )
    best_idx = int(np.argmax(f1_scores))
    return float(thresholds[best_idx]), float(f1_scores[best_idx])


def find_optimal_threshold(
    y_true: np.ndarray, y_prob: np.ndarray
) -> tuple[float, float]:
    """Find the threshold that maximizes F1 score on the precision-recall curve.

    Args:
        y_true: binary ground truth labels.
        y_prob: predicted probabilities.

    Returns:
        (best_threshold, best_f1) as floats.
    """
    precisions, recalls, thresholds = precision_recall_curve(y_true, y_prob)
    return _best_f1_threshold(precisions, recalls, thresholds)


def evaluate_risk(
    y_true: np.ndarray, y_prob: np.ndarray, threshold: float = 0.5
) -> dict:
    """
    Evaluate risk classification predictions.

    Args:
        y_true: binary ground truth labels
        y_prob: predicted probabilities
        threshold: classification threshold (used for f1_at_50)

    Returns:
        dict of metrics including optimal threshold F1 and recall at
        fixed precision levels (recall_at_prec_30/50/70).
    """
    y_pred_fixed = (y_prob >= threshold).astype(int)
    has_positives = y_true.sum() > 0

    results = {
        # AUC-PR is undefined without positives; fall back to 0.0.
        "auc_pr": float(average_precision_score(y_true, y_prob)) if has_positives else 0.0,
        # ROC AUC needs both classes present; fall back to 0.0.
        "auc_roc": float(roc_auc_score(y_true, y_prob)) if len(np.unique(y_true)) > 1 else 0.0,
        "f1_at_50": float(f1_score(y_true, y_pred_fixed, zero_division=0)),
        "n_positive": int(y_true.sum()),
        "n_total": int(len(y_true)),
        "pos_rate": float(y_true.mean()),
    }

    precision_levels = [0.3, 0.5, 0.7]
    if has_positives:
        # Compute the PR curve once and reuse it for both the optimal
        # threshold and the recall@precision metrics (previously it was
        # computed twice).
        precisions, recalls, thresholds = precision_recall_curve(y_true, y_prob)
        opt_threshold, opt_f1 = _best_f1_threshold(precisions, recalls, thresholds)
        results["f1"] = opt_f1
        results["optimal_threshold"] = opt_threshold
        results["threshold"] = opt_threshold
        # Best achievable recall at each minimum-precision level.
        for target_precision in precision_levels:
            mask = precisions >= target_precision
            key = f"recall_at_prec_{int(target_precision * 100)}"
            results[key] = float(recalls[mask].max()) if mask.any() else 0.0
    else:
        # Degenerate case: no positive labels, so a PR curve is meaningless.
        results["f1"] = results["f1_at_50"]
        results["optimal_threshold"] = threshold
        results["threshold"] = threshold
        # Emit the recall@precision keys anyway so the result schema is
        # stable regardless of label balance.
        for target_precision in precision_levels:
            results[f"recall_at_prec_{int(target_precision * 100)}"] = 0.0

    return results


def evaluate_miss_distance(y_true_log: np.ndarray, y_pred_log: np.ndarray) -> dict:
    """
    Evaluate miss distance regression (log-scale).

    Args:
        y_true_log: log1p(miss_distance_km) ground truth
        y_pred_log: log1p(miss_distance_km) predictions

    Returns:
        dict of metrics
    """
    mae_log = float(mean_absolute_error(y_true_log, y_pred_log))
    rmse_log = float(np.sqrt(mean_squared_error(y_true_log, y_pred_log)))

    # Convert back to km for interpretable metrics
    y_true_km = np.expm1(y_true_log)
    y_pred_km = np.expm1(y_pred_log)
    mae_km = float(mean_absolute_error(y_true_km, y_pred_km))

    return {
        "mae_log": mae_log,
        "rmse_log": rmse_log,
        "mae_km": mae_km,
        "median_abs_error_km": float(np.median(np.abs(y_true_km - y_pred_km))),
    }


def full_evaluation(
    model_name: str,
    y_risk_true: np.ndarray,
    y_risk_prob: np.ndarray,
    y_miss_true_log: np.ndarray,
    y_miss_pred_log: np.ndarray,
) -> dict:
    """Run full evaluation suite for a model.

    Combines classification and regression metrics into one flat dict
    (keyed by "model" plus the keys from evaluate_risk and
    evaluate_miss_distance) and prints a human-readable summary.
    """
    risk_metrics = evaluate_risk(y_risk_true, y_risk_prob)
    miss_metrics = evaluate_miss_distance(y_miss_true_log, y_miss_pred_log)
    results = {"model": model_name, **risk_metrics, **miss_metrics}

    print(f"\n{'='*60}")
    print(f"  {model_name}")
    print(f"{'='*60}")
    print(f"  Risk Classification:")
    print(f"    AUC-PR:    {risk_metrics['auc_pr']:.4f}")
    print(f"    AUC-ROC:   {risk_metrics['auc_roc']:.4f}")
    print(f"    F1 (opt):  {risk_metrics['f1']:.4f} "
          f"(threshold={risk_metrics.get('optimal_threshold', 0.5):.3f})")
    print(f"    F1 (0.50): {risk_metrics['f1_at_50']:.4f}")
    print(f"    Positives: {risk_metrics['n_positive']}/{risk_metrics['n_total']} "
          f"({risk_metrics['pos_rate']:.1%})")
    print(f"  Miss Distance:")
    print(f"    MAE (log): {miss_metrics['mae_log']:.4f}")
    print(f"    MAE (km):  {miss_metrics['mae_km']:.2f}")
    print(f"    Median AE: {miss_metrics['median_abs_error_km']:.2f} km")
    print(f"{'='*60}")

    return results