Zorrojurro's picture
Upload src/evaluation/metrics.py with huggingface_hub
d9ca01e verified
"""
Evaluation metrics for anomaly detection performance.
Computes accuracy, precision, recall, F1-score, and AUC-ROC.
"""
import numpy as np
from typing import Optional, List
from sklearn.metrics import (
accuracy_score,
precision_score,
recall_score,
f1_score,
roc_auc_score,
confusion_matrix,
classification_report,
)
class MetricsCalculator:
"""Calculates all evaluation metrics for binary classification."""
@staticmethod
def compute_all(
y_true: list | np.ndarray,
y_pred: list | np.ndarray,
y_scores: Optional[list | np.ndarray] = None,
) -> dict:
"""
Compute all metrics.
Args:
y_true: Ground-truth labels (0 = normal, 1 = abnormal).
y_pred: Predicted labels.
y_scores: Predicted probabilities for the positive class
(required for AUC-ROC).
Returns:
Dictionary of metric_name → value.
"""
y_true = np.asarray(y_true)
y_pred = np.asarray(y_pred)
metrics = {
"accuracy": accuracy_score(y_true, y_pred),
"precision": precision_score(y_true, y_pred, zero_division=0),
"recall": recall_score(y_true, y_pred, zero_division=0),
"f1_score": f1_score(y_true, y_pred, zero_division=0),
}
if y_scores is not None and len(np.unique(y_true)) > 1:
metrics["auc_roc"] = roc_auc_score(y_true, y_scores)
return metrics
@staticmethod
def get_confusion_matrix(
y_true: list | np.ndarray,
y_pred: list | np.ndarray,
) -> np.ndarray:
"""Return confusion matrix as a 2×2 numpy array."""
return confusion_matrix(y_true, y_pred)
@staticmethod
def get_classification_report(
y_true: list | np.ndarray,
y_pred: list | np.ndarray,
target_names: list = None,
) -> str:
"""Return a formatted classification report."""
if target_names is None:
target_names = ["Normal", "Abnormal"]
return classification_report(
y_true, y_pred, target_names=target_names, zero_division=0
)
@staticmethod
def format_metrics(metrics: dict) -> str:
"""Pretty-print metrics to a formatted string."""
lines = []
for name, value in metrics.items():
if isinstance(value, float):
lines.append(f" {name:>15s}: {value:.4f}")
else:
lines.append(f" {name:>15s}: {value}")
return "\n".join(lines)