File size: 716 Bytes
8505a58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from app.model_utils import load_model, load_data
from sklearn.metrics import classification_report
import torch

def test_model():
    """Evaluate the loaded model on the test split and return a metrics report.

    Loads the model and tokenizer via ``load_model()``, takes the third split
    returned by ``load_data(tokenizer)`` as the test set, feeds the examples
    through the model one at a time, and uses the arg-max over the logits as
    the predicted class.

    Returns:
        dict: ``sklearn.metrics.classification_report(labels, preds,
        output_dict=True)`` comparing the true labels with the predictions.
    """
    model, tokenizer = load_model()
    _, _, test_dataset = load_data(tokenizer)

    # Inference mode: disables dropout and other train-only behaviour so the
    # predictions are deterministic. Harmless if load_model() already did it.
    model.eval()
    # Hugging Face models expose ``.device``; moving the inputs there lets the
    # evaluation also work when the model was loaded onto a GPU. Models without
    # that attribute are left as-is (inputs stay where the dataset put them).
    device = getattr(model, "device", None)

    preds, labels = [], []
    # A single no_grad context for the whole loop instead of one per item.
    with torch.no_grad():
        for item in test_dataset:
            # Presumably each item holds the tensors of one tokenized example;
            # unsqueeze(0) adds a batch dimension of size 1 — TODO confirm shape.
            input_ids = item["input_ids"].unsqueeze(0)
            attention_mask = item["attention_mask"].unsqueeze(0)
            if device is not None:
                input_ids = input_ids.to(device)
                attention_mask = attention_mask.to(device)

            output = model(input_ids=input_ids, attention_mask=attention_mask)
            # The class scores live on the last axis of the logits.
            preds.append(torch.argmax(output.logits, dim=-1).item())
            # int() accepts both a 0-d tensor and a plain Python int label.
            labels.append(int(item["label"]))

    return classification_report(labels, preds, output_dict=True)