Spaces:
Sleeping
Sleeping
# Generated by Claude Code -- 2026-02-08
"""Evaluation metrics for conjunction prediction models."""
import numpy as np
from sklearn.metrics import (
    average_precision_score,
    roc_auc_score,
    f1_score,
    precision_recall_curve,
    mean_absolute_error,
    mean_squared_error,
    classification_report,
)
def find_optimal_threshold(y_true: np.ndarray, y_prob: np.ndarray) -> tuple[float, float]:
    """Locate the probability cutoff that maximizes F1 along the PR curve.

    Args:
        y_true: binary ground truth labels.
        y_prob: predicted probabilities.

    Returns:
        (threshold, f1) for the F1-maximizing point on the curve.
    """
    prec, rec, cand_cuts = precision_recall_curve(y_true, y_prob)
    # The curve carries one more (precision, recall) point than thresholds,
    # so drop the trailing point before scoring each candidate cutoff.
    p, r = prec[:-1], rec[:-1]
    # Small epsilon keeps the division defined when precision + recall == 0.
    f1_per_cut = 2 * (p * r) / (p + r + 1e-8)
    winner = int(np.argmax(f1_per_cut))
    return float(cand_cuts[winner]), float(f1_per_cut[winner])
def evaluate_risk(y_true: np.ndarray, y_prob: np.ndarray, threshold: float = 0.5) -> dict:
    """
    Evaluate risk classification predictions.

    Args:
        y_true: binary ground truth labels
        y_prob: predicted probabilities
        threshold: classification threshold (used for f1_at_50)

    Returns: dict of metrics including optimal threshold F1 and
        recall-at-fixed-precision values (only when positives exist)
    """
    # Hoist the positive count: the original recomputed y_true.sum() three
    # times and branched on it in two separate places.
    n_positive = int(y_true.sum())
    y_pred_fixed = (y_prob >= threshold).astype(int)
    results = {
        # AUC-PR is undefined with no positives; AUC-ROC needs both classes.
        "auc_pr": float(average_precision_score(y_true, y_prob)) if n_positive > 0 else 0.0,
        "auc_roc": float(roc_auc_score(y_true, y_prob)) if len(np.unique(y_true)) > 1 else 0.0,
        "f1_at_50": float(f1_score(y_true, y_pred_fixed, zero_division=0)),
        "n_positive": n_positive,
        "n_total": int(len(y_true)),
        "pos_rate": float(y_true.mean()),
    }
    if n_positive > 0:
        # Optimal threshold that maximizes F1.
        opt_threshold, opt_f1 = find_optimal_threshold(y_true, y_prob)
        results["f1"] = opt_f1
        results["optimal_threshold"] = opt_threshold
        results["threshold"] = opt_threshold
        # Recall at fixed precision levels (thresholds array is unused here).
        precisions, recalls, _ = precision_recall_curve(y_true, y_prob)
        for target_precision in [0.3, 0.5, 0.7]:
            mask = precisions >= target_precision
            # NOTE: the PR curve always ends at (precision=1, recall=0), so
            # mask is never empty in practice; the guard is kept for safety.
            best_recall = float(recalls[mask].max()) if mask.any() else 0.0
            results[f"recall_at_prec_{int(target_precision*100)}"] = best_recall
    else:
        # No positives: fall back to the fixed-threshold F1 and threshold.
        results["f1"] = results["f1_at_50"]
        results["optimal_threshold"] = threshold
        results["threshold"] = threshold
    return results
def evaluate_miss_distance(y_true_log: np.ndarray, y_pred_log: np.ndarray) -> dict:
    """
    Evaluate miss distance regression (log-scale).

    Args:
        y_true_log: log1p(miss_distance_km) ground truth
        y_pred_log: log1p(miss_distance_km) predictions

    Returns: dict of metrics
    """
    # Log-space errors first, then invert the log1p transform so the
    # remaining metrics are reported in interpretable kilometers.
    log_mae = float(mean_absolute_error(y_true_log, y_pred_log))
    log_rmse = float(np.sqrt(mean_squared_error(y_true_log, y_pred_log)))
    true_km = np.expm1(y_true_log)
    pred_km = np.expm1(y_pred_log)
    km_abs_err = np.abs(true_km - pred_km)
    return {
        "mae_log": log_mae,
        "rmse_log": log_rmse,
        "mae_km": float(mean_absolute_error(true_km, pred_km)),
        "median_abs_error_km": float(np.median(km_abs_err)),
    }
def full_evaluation(
    model_name: str,
    y_risk_true: np.ndarray,
    y_risk_prob: np.ndarray,
    y_miss_true_log: np.ndarray,
    y_miss_pred_log: np.ndarray,
) -> dict:
    """Run full evaluation suite for a model."""
    rm = evaluate_risk(y_risk_true, y_risk_prob)
    mm = evaluate_miss_distance(y_miss_true_log, y_miss_pred_log)
    combined = {"model": model_name, **rm, **mm}

    # Human-readable summary of both metric groups.
    bar = "=" * 60
    print(f"\n{bar}")
    print(f" {model_name}")
    print(f"{bar}")
    print(" Risk Classification:")
    print(f" AUC-PR: {rm['auc_pr']:.4f}")
    print(f" AUC-ROC: {rm['auc_roc']:.4f}")
    print(f" F1 (opt): {rm['f1']:.4f} (threshold={rm.get('optimal_threshold', 0.5):.3f})")
    print(f" F1 (0.50): {rm['f1_at_50']:.4f}")
    print(f" Positives: {rm['n_positive']}/{rm['n_total']} ({rm['pos_rate']:.1%})")
    print(" Miss Distance:")
    print(f" MAE (log): {mm['mae_log']:.4f}")
    print(f" MAE (km): {mm['mae_km']:.2f}")
    print(f" Median AE: {mm['median_abs_error_km']:.2f} km")
    print(f"{bar}")
    return combined