# my-mnist-hf/examples/metrics.py
# Author: dsaint31 · commit fab639f ("Release custom MNIST model")
# hf_custom_proj/examples/metrics.py
import functools

import evaluate
import numpy as np
@functools.lru_cache(maxsize=None)
def _load_f1_metric():
    """Load the ``evaluate`` F1 metric once and reuse it on later calls."""
    # evaluate.load() resolves the metric module (possibly from the Hub or a
    # local cache) each time it is called; Trainer invokes compute_metrics at
    # every evaluation, so load lazily a single time instead of per call.
    return evaluate.load("f1")


def compute_metrics(eval_pred):
    """Compute macro-averaged F1 for a Hugging Face ``Trainer`` evaluation.

    Args:
        eval_pred: A ``transformers.EvalPrediction`` or a plain
            ``(logits, labels)`` sequence. Any elements after the first two
            (e.g. ``inputs`` when ``include_inputs_for_metrics`` is enabled)
            are ignored.

            - logits : np.ndarray, shape (N, num_classes). May also arrive as
              a tuple of arrays whose first element holds the logits.
            - labels : np.ndarray, shape (N,)

    Returns:
        dict: The result of the F1 metric's ``compute`` call with
        ``average="macro"`` (for ``evaluate``'s F1 this is ``{"f1": ...}``).
    """
    # Star-unpacking accepts both a plain 2-tuple and an EvalPrediction that
    # yields extra trailing fields, where a strict 2-way unpack would raise.
    logits, labels, *_ = eval_pred

    # When the model returns extra outputs (hidden states, attentions, ...),
    # Trainer passes a tuple of arrays; the logits come first.
    if isinstance(logits, tuple):
        logits = logits[0]

    # Trainer hands over NumPy arrays; the predicted class is the arg-max
    # over the class axis.
    predictions = np.argmax(logits, axis=-1)
    return _load_f1_metric().compute(
        predictions=predictions,
        references=labels,
        average="macro",
    )