"""Model evaluation utilities.""" from __future__ import annotations from typing import Any import numpy as np from sklearn.metrics import ( average_precision_score, confusion_matrix, f1_score, precision_score, recall_score, roc_auc_score, ) def _safe_roc_auc(y_true, y_pred_proba) -> float: try: return float(roc_auc_score(y_true, y_pred_proba)) except ValueError: return float("nan") def _safe_pr_auc(y_true, y_pred_proba) -> float: try: return float(average_precision_score(y_true, y_pred_proba)) except ValueError: return float("nan") def calculate_metrics(y_true, y_pred, y_pred_proba) -> dict[str, Any]: """Calculate classification metrics used for model comparison.""" cm = confusion_matrix(y_true, y_pred) return { "precision": float(precision_score(y_true, y_pred, zero_division=0)), "recall": float(recall_score(y_true, y_pred, zero_division=0)), "f1": float(f1_score(y_true, y_pred, zero_division=0)), "roc_auc": _safe_roc_auc(y_true, y_pred_proba), "pr_auc": _safe_pr_auc(y_true, y_pred_proba), "confusion_matrix": cm.tolist(), } def rank_models(results: list[dict[str, Any]]) -> list[dict[str, Any]]: """Sort candidate model results by recall, then precision, then roc_auc.""" return sorted( results, key=lambda r: (r["metrics"]["recall"], r["metrics"]["precision"], r["metrics"]["roc_auc"]), reverse=True, ) def calculate_metrics_at_threshold( y_true, y_pred_proba, *, threshold: float, ) -> dict[str, Any]: """Compute metrics using a probability threshold.""" y_pred = (np.asarray(y_pred_proba) >= threshold).astype(int) metrics = calculate_metrics(y_true, y_pred, y_pred_proba) metrics["threshold"] = float(threshold) return metrics def evaluate_thresholds( y_true, y_pred_proba, *, thresholds: list[float] | None = None, min_threshold: float = 0.01, max_threshold: float = 0.99, grid_size: int = 99, ) -> list[dict[str, Any]]: """Evaluate model metrics across threshold grid.""" if thresholds is None: thresholds = np.linspace(min_threshold, max_threshold, grid_size).tolist() return [ calculate_metrics_at_threshold(y_true, y_pred_proba, threshold=t) for t in thresholds ] def select_best_threshold( y_true, y_pred_proba, *, min_recall: float = 0.90, min_threshold: float = 0.01, max_threshold: float = 0.99, grid_size: int = 99, ) -> dict[str, Any]: """Select threshold by maximizing precision while meeting recall target.""" evaluations = evaluate_thresholds( y_true, y_pred_proba, min_threshold=min_threshold, max_threshold=max_threshold, grid_size=grid_size, ) feasible = [m for m in evaluations if m["recall"] >= min_recall] search_space = feasible if feasible else evaluations selection_reason = "meets_min_recall" if feasible else "fallback_max_recall" best = sorted( search_space, key=lambda m: (m["precision"], m["f1"], m["recall"]), reverse=True, )[0] return { "selection_reason": selection_reason, "min_recall_target": float(min_recall), "selected_threshold": float(best["threshold"]), "selected_metrics": best, "threshold_grid_size": int(grid_size), "thresholds_evaluated": evaluations, }