File size: 1,141 Bytes
137457a
 
db2b094
 
137457a
 
 
 
 
 
 
 
 
 
 
 
db2b094
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dea7034
db2b094
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from transformers import pipeline, AutoTokenizer
from split_data import make_test_data
import torch
from torchmetrics.classification import MulticlassConfusionMatrix
# Evaluate the fine-tuned sentiment model stored in ./finetuned on the test
# split produced by make_test_data() and print its confusion matrix.

# Single mapping from the pipeline's string labels to the integer class ids
# expected in the dataset's "label" column.
# NOTE(review): assumes the model's id2label emits exactly "POSITIVE" /
# "NEGATIVE" (a model without id2label would emit "LABEL_0"/"LABEL_1") and
# that the dataset encodes them as 1 / 0 — confirm against the training setup.
label2id = {
    "POSITIVE": 1,
    "NEGATIVE": 0,
}
# Alias kept for backward compatibility; this used to be a second,
# separately maintained copy of the same mapping.
label_to_id = label2id

# Load tokenizer and classifier; inputs longer than 512 tokens are truncated.
tokenizer = AutoTokenizer.from_pretrained("./finetuned")
classifier = pipeline(
    "sentiment-analysis",
    model="./finetuned",
    tokenizer=tokenizer,
    max_length=512,
    truncation=True,
)

test_data = make_test_data()
# Materialize as plain lists so the pipeline / torch.tensor receive the same
# input type whether the columns come back as lists, Series, or other sequences.
texts = list(test_data["text"])
true_labels = test_data["label"]

# Run inference exactly once over the whole split. For a list input the
# pipeline returns one {"label": ..., "score": ...} dict per text, in order.
# (Previously every text was classified twice: once here, with the result
# discarded, and again one-by-one in a loop.)
results = classifier(texts)

# Convert predicted string labels to numeric ids; an unexpected label name
# raises KeyError here rather than silently producing a wrong matrix.
predicted_labels = [label_to_id[prediction["label"]] for prediction in results]

# dtype pinned to long so an empty split still yields an integer tensor.
predicted_tensor = torch.tensor(predicted_labels, dtype=torch.long)
true_tensor = torch.tensor(list(true_labels), dtype=torch.long)

# Get the confusion matrix using torchmetrics (call order is preds, target).
# torchmetrics convention: rows are true classes, columns are predicted
# classes, indexed by the ids in label2id (0 = NEGATIVE, 1 = POSITIVE).
confusion_matrix = MulticlassConfusionMatrix(num_classes=len(label2id))(
    predicted_tensor, true_tensor
)

print("Confusion Matrix")
print(confusion_matrix)