Spaces:
Build error
Build error
File size: 716 Bytes
8505a58 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 | from app.model_utils import load_model, load_data
from sklearn.metrics import classification_report
import torch
def test_model():
    """Evaluate the model on the test split and return a classification report.

    Loads the model and tokenizer via ``load_model``. Takes the third dataset
    returned by ``load_data`` as the test split, then classifies each example
    one at a time.

    Returns:
        dict: ``sklearn.metrics.classification_report(labels, preds,
        output_dict=True)``, computed from the true labels and the argmax
        predictions.

    NOTE(review): the ``test_`` prefix means pytest may collect this function
    as a test if its module matches pytest's discovery pattern.
    """
    model, tokenizer = load_model()
    _, _, test_dataset = load_data(tokenizer)

    # Switch dropout and similar layers to inference behaviour. Without this,
    # predictions depend on whatever mode load_model() left the model in.
    model.eval()
    # Build inputs on the same device as the model so a GPU-loaded model
    # does not fail with a CPU/GPU tensor mismatch.
    device = _model_device(model)

    preds, labels = [], []
    # One no_grad context for the whole loop: no autograd graph is needed.
    with torch.no_grad():
        for item in test_dataset:
            # assumes each item holds un-batched torch tensors "input_ids" and
            # "attention_mask" plus a one-element tensor "label" — TODO confirm
            # against load_data. unsqueeze(0) adds a batch dimension of 1.
            input_ids = item["input_ids"].unsqueeze(0).to(device)
            attention_mask = item["attention_mask"].unsqueeze(0).to(device)
            output = model(input_ids=input_ids, attention_mask=attention_mask)
            preds.append(torch.argmax(output.logits, dim=1).item())
            labels.append(item["label"].item())

    return classification_report(labels, preds, output_dict=True)


def _model_device(model):
    """Return the device of *model*'s first parameter, or CPU if it has none."""
    for param in model.parameters():
        return param.device
    return torch.device("cpu")
|