File size: 959 Bytes
162ca90 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
from typing import List, Dict
import numpy as np
import evaluate
def compute_metrics_sentiment(eval_pred):
    """Compute classification accuracy for a sentiment eval step.

    Args:
        eval_pred: A ``(logits, labels)`` pair, as passed by the HF
            ``Trainer``; ``logits`` has shape ``(batch, num_classes)``
            and ``labels`` holds the gold class indices.

    Returns:
        dict: ``{"accuracy": <python float in [0, 1]>}``.
    """
    logits, labels = eval_pred
    # Predicted class = index of the largest logit per example.
    predicted = np.argmax(logits, axis=-1)
    accuracy = float(np.mean(predicted == labels))
    return {"accuracy": accuracy}
def compute_metrics_ner(eval_pred, label_list: List[str]) -> Dict[str, float]:
    """Compute entity-level precision/recall/F1/accuracy with seqeval.

    Args:
        eval_pred: A ``(logits, labels)`` pair from the HF ``Trainer``;
            ``logits`` has shape ``(batch, seq_len, num_tags)`` and
            ``labels`` holds per-token tag ids, with ``-100`` marking
            special/sub-word tokens that must be ignored.
        label_list: Mapping from tag id to tag string (e.g. ``"B-PER"``).

    Returns:
        dict with keys ``"precision"``, ``"recall"``, ``"f1"``,
        ``"accuracy"`` taken from seqeval's overall scores.
    """
    # Load the seqeval metric once and memoize it on the function object:
    # the original reloaded it on every evaluation call, which hits the
    # evaluate cache/disk each epoch for no benefit.
    if not hasattr(compute_metrics_ner, "_seqeval"):
        compute_metrics_ner._seqeval = evaluate.load("seqeval")
    seqeval = compute_metrics_ner._seqeval

    logits, labels = eval_pred
    preds = logits.argmax(-1)

    # Build both tag-string sequences in a single pass, dropping every
    # position whose gold label is -100 (ignored token).
    true_preds: List[List[str]] = []
    true_labels: List[List[str]] = []
    for pred_row, label_row in zip(preds, labels):
        pred_tags: List[str] = []
        label_tags: List[str] = []
        for pred_id, label_id in zip(pred_row, label_row):
            if label_id != -100:
                pred_tags.append(label_list[pred_id])
                label_tags.append(label_list[label_id])
        true_preds.append(pred_tags)
        true_labels.append(label_tags)

    results = seqeval.compute(predictions=true_preds, references=true_labels)
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }
|