# Snapshot metadata (from automated deploy): github-actions[bot]
# deploy: sync snapshot from github
# commit: 4937cba
"""Model evaluation utilities."""
from __future__ import annotations
from typing import Any
import numpy as np
from sklearn.metrics import (
average_precision_score,
confusion_matrix,
f1_score,
precision_score,
recall_score,
roc_auc_score,
)
def _safe_roc_auc(y_true, y_pred_proba) -> float:
    """Return the ROC AUC score, or NaN when it is undefined.

    ``roc_auc_score`` raises ``ValueError`` when it cannot be computed
    (e.g. only one class present in ``y_true``); NaN is returned instead
    so metric dictionaries stay uniformly shaped.
    """
    try:
        score = roc_auc_score(y_true, y_pred_proba)
    except ValueError:
        return float("nan")
    return float(score)
def _safe_pr_auc(y_true, y_pred_proba) -> float:
    """Return the average-precision (PR AUC) score, or NaN when undefined.

    Mirrors ``_safe_roc_auc``: a ``ValueError`` from sklearn (e.g. a
    degenerate label vector) is converted to NaN rather than propagated.
    """
    try:
        score = average_precision_score(y_true, y_pred_proba)
    except ValueError:
        return float("nan")
    return float(score)
def calculate_metrics(y_true, y_pred, y_pred_proba) -> dict[str, Any]:
    """Calculate classification metrics used for model comparison.

    Thresholded metrics (precision/recall/f1, confusion matrix) use the
    hard labels in ``y_pred``; ranking metrics (ROC AUC, PR AUC) use the
    scores in ``y_pred_proba``. AUC values are NaN when undefined.
    """
    precision = float(precision_score(y_true, y_pred, zero_division=0))
    recall = float(recall_score(y_true, y_pred, zero_division=0))
    f1 = float(f1_score(y_true, y_pred, zero_division=0))
    matrix = confusion_matrix(y_true, y_pred)
    return {
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "roc_auc": _safe_roc_auc(y_true, y_pred_proba),
        "pr_auc": _safe_pr_auc(y_true, y_pred_proba),
        # tolist() keeps the result JSON-serializable.
        "confusion_matrix": matrix.tolist(),
    }
def rank_models(results: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Sort candidate model results by recall, then precision, then roc_auc."""
return sorted(
results,
key=lambda r: (r["metrics"]["recall"], r["metrics"]["precision"], r["metrics"]["roc_auc"]),
reverse=True,
)
def calculate_metrics_at_threshold(
    y_true,
    y_pred_proba,
    *,
    threshold: float,
) -> dict[str, Any]:
    """Compute metrics using a probability threshold.

    Probabilities at or above ``threshold`` become the positive class;
    the resulting metrics dict additionally records the threshold used.
    """
    probabilities = np.asarray(y_pred_proba)
    hard_labels = (probabilities >= threshold).astype(int)
    result = calculate_metrics(y_true, hard_labels, y_pred_proba)
    result["threshold"] = float(threshold)
    return result
def evaluate_thresholds(
    y_true,
    y_pred_proba,
    *,
    thresholds: list[float] | None = None,
    min_threshold: float = 0.01,
    max_threshold: float = 0.99,
    grid_size: int = 99,
) -> list[dict[str, Any]]:
    """Evaluate model metrics across threshold grid.

    When ``thresholds`` is not supplied, an evenly spaced grid of
    ``grid_size`` values over [min_threshold, max_threshold] is used.
    Returns one metrics dict per threshold, in grid order.
    """
    grid = thresholds
    if grid is None:
        grid = np.linspace(min_threshold, max_threshold, grid_size).tolist()
    evaluations: list[dict[str, Any]] = []
    for cutoff in grid:
        evaluations.append(
            calculate_metrics_at_threshold(y_true, y_pred_proba, threshold=cutoff)
        )
    return evaluations
def select_best_threshold(
    y_true,
    y_pred_proba,
    *,
    min_recall: float = 0.90,
    min_threshold: float = 0.01,
    max_threshold: float = 0.99,
    grid_size: int = 99,
) -> dict[str, Any]:
    """Select threshold by maximizing precision while meeting recall target.

    Evaluates a grid of thresholds; among those whose recall is at least
    ``min_recall``, picks the one with the highest precision (ties broken
    by f1, then recall). If no threshold meets the recall target, falls
    back to the threshold with the highest recall (ties broken by
    precision, then f1) — i.e. the closest achievable to the target,
    matching the ``"fallback_max_recall"`` selection reason.

    Returns a dict with the selection reason, the chosen threshold and
    its metrics, and the full per-threshold evaluation list.
    """
    evaluations = evaluate_thresholds(
        y_true,
        y_pred_proba,
        min_threshold=min_threshold,
        max_threshold=max_threshold,
        grid_size=grid_size,
    )
    feasible = [m for m in evaluations if m["recall"] >= min_recall]
    if feasible:
        selection_reason = "meets_min_recall"
        # Recall floor is satisfied: optimize precision.
        best = max(feasible, key=lambda m: (m["precision"], m["f1"], m["recall"]))
    else:
        selection_reason = "fallback_max_recall"
        # Fix: the previous fallback still maximized precision, which
        # contradicted the "fallback_max_recall" reason reported to
        # callers. With no feasible threshold, maximize recall instead.
        best = max(evaluations, key=lambda m: (m["recall"], m["precision"], m["f1"]))
    return {
        "selection_reason": selection_reason,
        "min_recall_target": float(min_recall),
        "selected_threshold": float(best["threshold"]),
        "selected_metrics": best,
        "threshold_grid_size": int(grid_size),
        "thresholds_evaluated": evaluations,
    }