sql-error-classifier / src /hf_metrics.py
nishu08's picture
Deploy CodeBERT inference Space
7aae828 verified
Raw
History Blame Contribute Delete
1.53 kB
"""Evaluation metrics for multi-label SQL error classification."""
from __future__ import annotations
from typing import Dict
import numpy as np
from sklearn.metrics import (
accuracy_score,
f1_score,
hamming_loss,
precision_score,
recall_score,
)
def sigmoid(x: np.ndarray) -> np.ndarray:
    """Apply the logistic function ``1 / (1 + exp(-x))`` element-wise.

    The textbook formula overflows inside ``np.exp`` when ``x`` is a large
    negative number (NumPy emits a RuntimeWarning, or raises under
    ``np.errstate(over="raise")``).  This version only ever exponentiates
    ``-|x|``, which lies in ``(-inf, 0]``, so the exponential stays in
    ``[0, 1]`` and cannot overflow.

    Args:
        x: Scalar or array of real-valued scores (logits).

    Returns:
        Values in ``[0, 1]`` with the same shape as ``x``.  Floating-point
        inputs keep their dtype; non-float inputs are computed in float64.
        A scalar input yields a NumPy scalar rather than a 0-d array.
    """
    x = np.asarray(x)
    if not np.issubdtype(x.dtype, np.floating):
        # Integer logits: promote once so the arithmetic below is in floats.
        x = x.astype(np.float64)

    # decay = exp(-|x|) is always in [0, 1].
    decay = np.exp(-np.abs(x))

    # x >= 0:  1 / (1 + exp(-x))       == 1 / (1 + decay)
    # x <  0:  exp(x) / (1 + exp(x))   == decay / (1 + decay)
    result = np.where(x >= 0, 1.0 / (1.0 + decay), decay / (1.0 + decay))

    # Indexing with () turns a 0-d result back into a scalar and is a no-op
    # (a view) for real arrays.
    return result[()]
def compute_multilabel_metrics(
    logits: np.ndarray,
    labels: np.ndarray,
    threshold: float = 0.5,
) -> Dict[str, float]:
    """Score thresholded sigmoid predictions against multi-label targets.

    ``logits`` are passed through ``sigmoid``; every probability that is
    greater than or equal to ``threshold`` becomes a predicted ``1``.
    ``labels`` are cast to ``int`` before comparison.

    Args:
        logits: Raw model scores.  Presumably shaped
            ``(n_samples, n_labels)`` -- the subset-accuracy step reduces
            over axis 1, so at least two dimensions are required.
        labels: Ground-truth indicators with the same shape as ``logits``.
        threshold: Probability cut-off for a positive prediction.

    Returns:
        A dict of plain Python floats keyed by ``accuracy``, ``f1_macro``,
        ``f1_micro``, ``precision_macro``, ``recall_macro``,
        ``hamming_loss`` and ``subset_accuracy``.
    """
    probabilities = sigmoid(logits)
    predicted = (probabilities >= threshold).astype(int)
    targets = labels.astype(int)

    # Shared keyword sets; zero_division=0 makes labels with no positive
    # predictions/targets score 0 instead of triggering a warning.
    macro = {"average": "macro", "zero_division": 0}
    micro = {"average": "micro", "zero_division": 0}

    # NOTE(review): for multilabel indicator input scikit-learn documents
    # accuracy_score as the exact-match ratio, so this should coincide with
    # the hand-rolled subset accuracy below -- confirm if both are needed.
    accuracy = accuracy_score(targets, predicted)
    f1_macro = f1_score(targets, predicted, **macro)
    f1_micro = f1_score(targets, predicted, **micro)
    precision_macro = precision_score(targets, predicted, **macro)
    recall_macro = recall_score(targets, predicted, **macro)
    hamming = hamming_loss(targets, predicted)

    # A row counts as correct only when every label in it matches.
    exact_rows = np.all(predicted == targets, axis=1)
    subset_accuracy = exact_rows.mean()

    return {
        "accuracy": float(accuracy),
        "f1_macro": float(f1_macro),
        "f1_micro": float(f1_micro),
        "precision_macro": float(precision_macro),
        "recall_macro": float(recall_macro),
        "hamming_loss": float(hamming),
        "subset_accuracy": float(subset_accuracy),
    }
def build_compute_metrics(threshold: float = 0.5):
    """Create a ``compute_metrics`` callback bound to ``threshold``.

    Intended for the Hugging Face Trainer ``compute_metrics`` hook.  The
    returned function takes one argument that unpacks into exactly two
    items, ``(logits, labels)``, and forwards them to
    ``compute_multilabel_metrics`` together with the captured threshold.

    Args:
        threshold: Probability cut-off handed to
            ``compute_multilabel_metrics`` on every call.

    Returns:
        A one-argument callable that returns the metrics dict.
    """

    def compute_metrics(eval_pred) -> Dict[str, float]:
        # Two-way unpacking on purpose: anything that does not split into
        # exactly (logits, labels) fails here rather than further down.
        raw_scores, gold = eval_pred
        return compute_multilabel_metrics(raw_scores, gold, threshold=threshold)

    return compute_metrics