File size: 1,841 Bytes
52cebf7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import yaml
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

# Evaluate a fine-tuned DistilBERT sequence classifier on the samples listed in
# test_data.yaml, print the overall accuracy, then print every prediction.

# Load test data from YAML file.
# As read below, the file must have a top-level "test_data" key holding a list
# of mappings, each with a "text" and a "label" entry. An empty/null list is
# treated as "no samples" instead of crashing later on.
# NOTE(review): labels are compared against the strings "Positive"/"Negative";
# confirm the YAML stores labels in exactly that form.
with open("test_data.yaml", "r", encoding="utf-8") as file:
    test_data = yaml.safe_load(file)["test_data"] or []

# Load pre-trained model and tokenizer
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

# Load your fine-tuned model weights.
# map_location="cpu": the model is never moved off the CPU in this script, so
# remap the checkpoint tensors there; otherwise a checkpoint saved on a GPU
# fails to load on a CPU-only machine.
# NOTE(security): torch.load unpickles the file; only load checkpoints that
# come from a trusted source.
model.load_state_dict(
    torch.load("path/to/your/fine-tuned/model.pth", map_location="cpu")
)
model.eval()

# Run inference exactly once per sample and keep the results, so the accuracy
# computation and the per-sample report below share the same predictions
# instead of pushing every text through the model twice.
results = []
with torch.no_grad():  # inference only: skip autograd bookkeeping
    for sample in test_data:
        text = sample["text"]
        expected_label = sample["label"]

        # Tokenize and encode input text
        inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)

        # Get model predictions: class index 1 -> "Positive", 0 -> "Negative"
        logits = model(**inputs).logits
        predicted_label = "Positive" if logits.argmax().item() else "Negative"

        results.append((text, expected_label, predicted_label))

# Calculate accuracy
total_samples = len(results)
correct_predictions = sum(
    1 for _, expected, predicted in results if predicted == expected
)

if total_samples:
    accuracy = correct_predictions / total_samples
    print(f"Accuracy on test data: {accuracy * 100:.2f}%")
else:
    # Avoid ZeroDivisionError when the YAML file lists no samples.
    print("Accuracy on test data: n/a (no test samples found)")

# Demonstrate model predictions
print("\nModel Predictions:")
for text, expected_label, predicted_label in results:
    print(f"Text: {text}")
    print(f"Expected Label: {expected_label}")
    print(f"Predicted Label: {predicted_label}")
    print()