Zorrojurro commited on
Commit
d9ca01e
·
verified ·
1 Parent(s): 098c64c

Upload src/evaluation/metrics.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/evaluation/metrics.py +86 -0
src/evaluation/metrics.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Evaluation metrics for anomaly detection performance.
3
+
4
+ Computes accuracy, precision, recall, F1-score, and AUC-ROC.
5
+ """
6
+
7
+ import numpy as np
8
+ from typing import Optional, List
9
+ from sklearn.metrics import (
10
+ accuracy_score,
11
+ precision_score,
12
+ recall_score,
13
+ f1_score,
14
+ roc_auc_score,
15
+ confusion_matrix,
16
+ classification_report,
17
+ )
18
+
19
+
20
+ class MetricsCalculator:
21
+ """Calculates all evaluation metrics for binary classification."""
22
+
23
+ @staticmethod
24
+ def compute_all(
25
+ y_true: list | np.ndarray,
26
+ y_pred: list | np.ndarray,
27
+ y_scores: Optional[list | np.ndarray] = None,
28
+ ) -> dict:
29
+ """
30
+ Compute all metrics.
31
+
32
+ Args:
33
+ y_true: Ground-truth labels (0 = normal, 1 = abnormal).
34
+ y_pred: Predicted labels.
35
+ y_scores: Predicted probabilities for the positive class
36
+ (required for AUC-ROC).
37
+
38
+ Returns:
39
+ Dictionary of metric_name → value.
40
+ """
41
+ y_true = np.asarray(y_true)
42
+ y_pred = np.asarray(y_pred)
43
+
44
+ metrics = {
45
+ "accuracy": accuracy_score(y_true, y_pred),
46
+ "precision": precision_score(y_true, y_pred, zero_division=0),
47
+ "recall": recall_score(y_true, y_pred, zero_division=0),
48
+ "f1_score": f1_score(y_true, y_pred, zero_division=0),
49
+ }
50
+
51
+ if y_scores is not None and len(np.unique(y_true)) > 1:
52
+ metrics["auc_roc"] = roc_auc_score(y_true, y_scores)
53
+
54
+ return metrics
55
+
56
+ @staticmethod
57
+ def get_confusion_matrix(
58
+ y_true: list | np.ndarray,
59
+ y_pred: list | np.ndarray,
60
+ ) -> np.ndarray:
61
+ """Return confusion matrix as a 2×2 numpy array."""
62
+ return confusion_matrix(y_true, y_pred)
63
+
64
+ @staticmethod
65
+ def get_classification_report(
66
+ y_true: list | np.ndarray,
67
+ y_pred: list | np.ndarray,
68
+ target_names: list = None,
69
+ ) -> str:
70
+ """Return a formatted classification report."""
71
+ if target_names is None:
72
+ target_names = ["Normal", "Abnormal"]
73
+ return classification_report(
74
+ y_true, y_pred, target_names=target_names, zero_division=0
75
+ )
76
+
77
+ @staticmethod
78
+ def format_metrics(metrics: dict) -> str:
79
+ """Pretty-print metrics to a formatted string."""
80
+ lines = []
81
+ for name, value in metrics.items():
82
+ if isinstance(value, float):
83
+ lines.append(f" {name:>15s}: {value:.4f}")
84
+ else:
85
+ lines.append(f" {name:>15s}: {value}")
86
+ return "\n".join(lines)