supcon / code /src /metrics.py
IGandarillas1's picture
add code
0797029
raw
history blame contribute delete
670 Bytes
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
def compute_metrics_bce(eval_pred):
    """Compute binary-classification metrics from thresholded model outputs.

    Args:
        eval_pred: A pair ``(logits, labels)``. Each element is converted
            to a NumPy array and flattened, so any shape is accepted as
            long as both flatten to the same length.

    Returns:
        A dict with the keys ``"accuracy"``, ``"f1"``, ``"precision"`` and
        ``"recall"``; the last three treat label ``1`` as the positive class.
    """
    logits, labels = eval_pred

    # Threshold a private copy: writing 0/1 into ``logits`` itself would
    # overwrite the caller's array, which may be reused after this call.
    predictions = np.array(logits)
    # NOTE(review): a 0.5 cut-off implies these values are probabilities;
    # if they are raw pre-sigmoid logits the equivalent cut-off is 0.
    # Confirm against the model/caller.
    predictions[predictions >= 0.5] = 1
    predictions[predictions < 0.5] = 0
    predictions = predictions.reshape(-1)
    labels = np.asarray(labels).reshape(-1)

    accuracy = accuracy_score(labels, predictions)
    f1 = f1_score(labels, predictions, pos_label=1, average='binary')
    precision = precision_score(labels, predictions, pos_label=1, average='binary')
    recall = recall_score(labels, predictions, pos_label=1, average='binary')
    return {"accuracy": accuracy, "f1": f1, "precision": precision, "recall": recall}