File size: 1,141 Bytes
137457a db2b094 137457a db2b094 dea7034 db2b094 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 | from transformers import pipeline, AutoTokenizer
from split_data import make_test_data
import torch
from torchmetrics.classification import MulticlassConfusionMatrix
# Map the model's string labels to the integer ids used in the test data's "label" column.
label2id = {
    "POSITIVE": 1,
    "NEGATIVE": 0,
}
# Alias kept so any code referring to the older name still works; both names share one dict.
label_to_id = label2id

# Load the fine-tuned tokenizer and classifier; inputs longer than 512 tokens are truncated.
tokenizer = AutoTokenizer.from_pretrained("./finetuned")
classifier = pipeline(
    "sentiment-analysis",
    model="./finetuned",
    tokenizer=tokenizer,
    max_length=512,
    truncation=True,
)

test_data = make_test_data()
texts = test_data["text"]
true_labels = test_data["label"]

# Run inference exactly once over the whole test set. For a list of inputs the
# pipeline returns one {"label": ..., "score": ...} dict per text, in input order.
# (Previously every text was classified a second time in a per-item loop.)
results = classifier(texts)

# Convert predicted string labels to numeric ids so they are comparable with true_labels.
predicted_labels = [label2id[result["label"]] for result in results]

predicted_tensor = torch.tensor(predicted_labels)
true_tensor = torch.tensor(true_labels)

# Get the confusion matrix using TorchMetrics; the metric is called as (preds, target).
# num_classes follows the label mapping instead of being hard-coded.
confusion_matrix = MulticlassConfusionMatrix(num_classes=len(label2id))(predicted_tensor, true_tensor)
print("Confusion Matrix")
print(confusion_matrix)