File size: 4,852 Bytes
a4b5ecb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# Generated by Claude Code -- 2026-02-08
"""Evaluation metrics for conjunction prediction models."""

import numpy as np
from sklearn.metrics import (
    average_precision_score,
    roc_auc_score,
    f1_score,
    precision_recall_curve,
    mean_absolute_error,
    mean_squared_error,
    classification_report,
)


def find_optimal_threshold(y_true: np.ndarray, y_prob: np.ndarray) -> tuple[float, float]:
    """Locate the decision threshold with the highest F1 on the PR curve.

    Args:
        y_true: binary ground-truth labels
        y_prob: predicted probabilities

    Returns:
        (threshold, f1) for the best-scoring point on the precision-recall curve.
    """
    prec, rec, thr = precision_recall_curve(y_true, y_prob)
    # sklearn appends a final (precision=1, recall=0) point with no
    # corresponding threshold, so drop it before scoring.
    prec = prec[:-1]
    rec = rec[:-1]
    # Epsilon guards against 0/0 where precision and recall are both zero.
    f1_curve = (2 * prec * rec) / (prec + rec + 1e-8)
    best = np.argmax(f1_curve)
    return float(thr[best]), float(f1_curve[best])


def evaluate_risk(y_true: np.ndarray, y_prob: np.ndarray, threshold: float = 0.5) -> dict:
    """Evaluate risk classification predictions.

    Args:
        y_true: binary ground-truth labels
        y_prob: predicted probabilities
        threshold: classification threshold (used for ``f1_at_50`` and as the
            reported threshold when no positives exist)

    Returns:
        dict of metrics. All keys — including ``recall_at_prec_{30,50,70}`` —
        are always present so downstream consumers see a stable schema;
        degenerate cases (no positives, single class, empty input) yield 0.0
        rather than NaN or missing keys.
    """
    n_total = int(len(y_true))
    n_positive = int(y_true.sum())
    y_pred_fixed = (y_prob >= threshold).astype(int)

    results = {
        # AUC-PR is undefined with no positives; report 0.0 instead of NaN.
        "auc_pr": float(average_precision_score(y_true, y_prob)) if n_positive > 0 else 0.0,
        # ROC-AUC needs both classes present.
        "auc_roc": float(roc_auc_score(y_true, y_prob)) if len(np.unique(y_true)) > 1 else 0.0,
        "f1_at_50": float(f1_score(y_true, y_pred_fixed, zero_division=0)),
        "n_positive": n_positive,
        "n_total": n_total,
        # Guard empty input: np.mean([]) is NaN and emits a RuntimeWarning.
        "pos_rate": float(y_true.mean()) if n_total > 0 else 0.0,
    }

    # Find the threshold that maximizes F1 (falls back to the fixed-threshold
    # score when the PR curve is undefined because there are no positives).
    if n_positive > 0:
        opt_threshold, opt_f1 = find_optimal_threshold(y_true, y_prob)
    else:
        opt_threshold, opt_f1 = threshold, results["f1_at_50"]
    results["f1"] = opt_f1
    results["optimal_threshold"] = opt_threshold
    results["threshold"] = opt_threshold

    # Recall at fixed precision levels. Keys are always emitted (0.0 when
    # unreachable) so the result schema does not vary with the data.
    if n_positive > 0:
        precisions, recalls, _ = precision_recall_curve(y_true, y_prob)
    for target_precision in (0.3, 0.5, 0.7):
        key = f"recall_at_prec_{int(target_precision * 100)}"
        best_recall = 0.0
        if n_positive > 0:
            mask = precisions >= target_precision
            if mask.any():
                best_recall = float(recalls[mask].max())
        results[key] = best_recall

    return results


def evaluate_miss_distance(y_true_log: np.ndarray, y_pred_log: np.ndarray) -> dict:
    """Evaluate miss-distance regression in log space.

    Args:
        y_true_log: log1p(miss_distance_km) ground truth
        y_pred_log: log1p(miss_distance_km) predictions

    Returns:
        dict with log-scale MAE/RMSE plus km-scale MAE and median absolute
        error for interpretability.
    """
    # Invert the log1p transform so the km-scale metrics are human-readable.
    true_km = np.expm1(y_true_log)
    pred_km = np.expm1(y_pred_log)

    return {
        "mae_log": float(mean_absolute_error(y_true_log, y_pred_log)),
        "rmse_log": float(np.sqrt(mean_squared_error(y_true_log, y_pred_log))),
        "mae_km": float(mean_absolute_error(true_km, pred_km)),
        "median_abs_error_km": float(np.median(np.abs(true_km - pred_km))),
    }


def full_evaluation(
    model_name: str,
    y_risk_true: np.ndarray,
    y_risk_prob: np.ndarray,
    y_miss_true_log: np.ndarray,
    y_miss_pred_log: np.ndarray,
) -> dict:
    """Run the full evaluation suite (risk + miss distance) for one model.

    Prints a formatted report to stdout and returns every metric merged into
    a single dict, with the model name stored under the "model" key.
    """
    risk = evaluate_risk(y_risk_true, y_risk_prob)
    miss = evaluate_miss_distance(y_miss_true_log, y_miss_pred_log)
    combined = {"model": model_name, **risk, **miss}

    bar = "=" * 60
    print(f"\n{bar}")
    print(f"  {model_name}")
    print(bar)
    print("  Risk Classification:")
    print(f"    AUC-PR:     {risk['auc_pr']:.4f}")
    print(f"    AUC-ROC:    {risk['auc_roc']:.4f}")
    opt_thr = risk.get("optimal_threshold", 0.5)
    print(f"    F1 (opt):   {risk['f1']:.4f}  (threshold={opt_thr:.3f})")
    print(f"    F1 (0.50):  {risk['f1_at_50']:.4f}")
    print(
        f"    Positives:  {risk['n_positive']}/{risk['n_total']} "
        f"({risk['pos_rate']:.1%})"
    )
    print("  Miss Distance:")
    print(f"    MAE (log): {miss['mae_log']:.4f}")
    print(f"    MAE (km):  {miss['mae_km']:.2f}")
    print(f"    Median AE: {miss['median_abs_error_km']:.2f} km")
    print(bar)

    return combined