File size: 3,495 Bytes
4937cba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
"""Model evaluation utilities."""

from __future__ import annotations

from typing import Any

import numpy as np
from sklearn.metrics import (
    average_precision_score,
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
)


def _safe_roc_auc(y_true, y_pred_proba) -> float:
    """Return ROC AUC as a plain float, or NaN if it cannot be computed.

    Any ``ValueError`` raised by ``roc_auc_score`` (presumably inputs it
    cannot score, such as a single-class ``y_true``) is turned into ``nan``.
    Callers therefore always get a float back.
    """
    try:
        score = roc_auc_score(y_true, y_pred_proba)
    except ValueError:
        return float("nan")
    return float(score)


def _safe_pr_auc(y_true, y_pred_proba) -> float:
    """Return average precision (PR AUC) as a plain float, or NaN on failure.

    A ``ValueError`` from ``average_precision_score`` is reported as ``nan``
    instead of propagating, mirroring the ROC AUC helper.
    """
    try:
        score = average_precision_score(y_true, y_pred_proba)
    except ValueError:
        return float("nan")
    return float(score)


def calculate_metrics(y_true, y_pred, y_pred_proba) -> dict[str, Any]:
    """Calculate classification metrics used for model comparison.

    Args:
        y_true: Ground-truth labels.
        y_pred: Hard predicted labels.
        y_pred_proba: Scores for the positive class. These are used only for
            ROC AUC and PR AUC.

    Returns:
        A dict with the following keys:

        - ``precision``, ``recall``, ``f1``: floats computed with sklearn's
          default binary averaging; zero-division cases are reported as 0.
        - ``roc_auc``, ``pr_auc``: floats, or ``nan`` when sklearn rejects
          the input.
        - ``confusion_matrix``: nested lists. For 0/1 labels this is always
          2x2, even when only one class appears in the data.
    """
    # Without explicit labels, confusion_matrix infers them from the data.
    # If y_true and y_pred both hold a single class, it collapses to 1x1,
    # which breaks consumers that expect [[tn, fp], [fn, tp]]. Pin the labels
    # for 0/1 targets; leave any other encoding to sklearn's inference.
    observed = np.union1d(np.asarray(y_true).ravel(), np.asarray(y_pred).ravel())
    cm_labels = [0, 1] if np.isin(observed, [0, 1]).all() else None
    cm = confusion_matrix(y_true, y_pred, labels=cm_labels)
    return {
        "precision": float(precision_score(y_true, y_pred, zero_division=0)),
        "recall": float(recall_score(y_true, y_pred, zero_division=0)),
        "f1": float(f1_score(y_true, y_pred, zero_division=0)),
        "roc_auc": _safe_roc_auc(y_true, y_pred_proba),
        "pr_auc": _safe_pr_auc(y_true, y_pred_proba),
        "confusion_matrix": cm.tolist(),
    }


def rank_models(results: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Sort candidate model results by recall, then precision, then roc_auc.

    Each result must carry a ``"metrics"`` mapping with ``recall``,
    ``precision`` and ``roc_auc`` entries. Results are ordered best-first
    (descending) and returned as a new list; the input is not modified.

    A NaN metric (``roc_auc`` is NaN whenever it could not be computed) ranks
    below every real value.
    """

    def _sort_key(result: dict[str, Any]) -> tuple[float, ...]:
        metrics = result["metrics"]
        # NaN compares False against everything, which violates the ordering
        # sort relies on and lets NaN rows float above real scores. Map it
        # to -inf so undefined metrics always lose ties.
        return tuple(
            float("-inf") if np.isnan(metrics[name]) else metrics[name]
            for name in ("recall", "precision", "roc_auc")
        )

    return sorted(results, key=_sort_key, reverse=True)


def calculate_metrics_at_threshold(
    y_true,
    y_pred_proba,
    *,
    threshold: float,
) -> dict[str, Any]:
    """Binarize scores at ``threshold`` and compute metrics on the result.

    A score greater than or equal to ``threshold`` becomes a predicted ``1``;
    every other score becomes ``0``. The returned dict is the output of
    ``calculate_metrics`` plus a ``"threshold"`` entry holding the cut-off as
    a float. The raw scores are still passed through for the AUC metrics.
    """
    is_positive = np.asarray(y_pred_proba) >= threshold
    hard_labels = is_positive.astype(int)
    result = calculate_metrics(y_true, hard_labels, y_pred_proba)
    result["threshold"] = float(threshold)
    return result


def evaluate_thresholds(
    y_true,
    y_pred_proba,
    *,
    thresholds: list[float] | None = None,
    min_threshold: float = 0.01,
    max_threshold: float = 0.99,
    grid_size: int = 99,
) -> list[dict[str, Any]]:
    """Compute threshold-specific metrics for every candidate cut-off.

    If ``thresholds`` is given, it is used as-is, and an empty list yields no
    results. If it is ``None``, ``grid_size`` evenly spaced values from
    ``min_threshold`` to ``max_threshold`` (inclusive, via ``np.linspace``)
    are used instead. Returns one metrics dict per threshold, in grid order.
    """
    if thresholds is not None:
        grid = thresholds
    else:
        grid = np.linspace(min_threshold, max_threshold, grid_size).tolist()
    return [
        calculate_metrics_at_threshold(y_true, y_pred_proba, threshold=cutoff)
        for cutoff in grid
    ]


def select_best_threshold(
    y_true,
    y_pred_proba,
    *,
    min_recall: float = 0.90,
    min_threshold: float = 0.01,
    max_threshold: float = 0.99,
    grid_size: int = 99,
) -> dict[str, Any]:
    """Select threshold by maximizing precision while meeting recall target.

    Every threshold on the grid is evaluated with ``evaluate_thresholds``.
    Selection depends on whether any threshold reaches ``min_recall``:

    * If at least one does (``selection_reason == "meets_min_recall"``), the
      one with the highest precision is chosen. Ties are broken by f1, then
      by recall.
    * If none does (``selection_reason == "fallback_max_recall"``), the
      threshold with the highest recall is chosen, i.e. the one closest to
      the target. Ties are broken by precision, then by f1.

    Remaining ties go to the earliest threshold in grid order.

    Returns:
        A dict with the selection reason, the recall target, the chosen
        threshold and its metrics, the grid size, and the full list of
        per-threshold evaluations.

    Raises:
        ValueError: If ``grid_size`` is less than 1, leaving no candidates.
    """
    if grid_size < 1:
        raise ValueError(f"grid_size must be at least 1, got {grid_size}")

    evaluations = evaluate_thresholds(
        y_true,
        y_pred_proba,
        min_threshold=min_threshold,
        max_threshold=max_threshold,
        grid_size=grid_size,
    )

    feasible = [m for m in evaluations if m["recall"] >= min_recall]
    if feasible:
        selection_reason = "meets_min_recall"
        best = max(feasible, key=lambda m: (m["precision"], m["f1"], m["recall"]))
    else:
        selection_reason = "fallback_max_recall"
        # No threshold reaches the target. Ranking by precision here would
        # favour the highest threshold, i.e. the lowest recall. Instead take
        # the highest-recall threshold, as the reason label promises.
        best = max(evaluations, key=lambda m: (m["recall"], m["precision"], m["f1"]))

    return {
        "selection_reason": selection_reason,
        "min_recall_target": float(min_recall),
        "selected_threshold": float(best["threshold"]),
        "selected_metrics": best,
        "threshold_grid_size": int(grid_size),
        "thresholds_evaluated": evaluations,
    }