Bert-model-test / test.py
ganeshkonapalli's picture
Upload 8 files
8505a58 verified
raw
history blame contribute delete
716 Bytes
from app.model_utils import load_model, load_data
from sklearn.metrics import classification_report
import torch
def test_model():
    """Evaluate the loaded model on the test split and return its metrics.

    The model and tokenizer come from ``load_model()``; the test dataset is
    the third value returned by ``load_data(tokenizer)``. Every item is run
    through the model on its own (batch of one), the arg-max over the logits
    is taken as the predicted class, and predictions are compared against
    each item's ``"label"``.

    Returns:
        The dictionary produced by
        ``classification_report(labels, preds, output_dict=True)``.
    """
    model, tokenizer = load_model()
    _, _, test_dataset = load_data(tokenizer)

    # Put the model in evaluation mode so layers such as dropout are disabled;
    # otherwise predictions (and therefore the report) can vary between runs.
    # NOTE(review): assumes ``model`` is a torch ``nn.Module`` -- confirm
    # against ``load_model``.
    model.eval()

    preds, labels = [], []
    # Gradients are never needed for evaluation; enter the context once for
    # the whole loop rather than once per sample.
    with torch.no_grad():
        for item in test_dataset:
            # unsqueeze(0) prepends a batch dimension of size 1.
            # NOTE(review): assumes the item tensors are on the same device
            # as the model -- confirm against ``load_data``.
            input_ids = item["input_ids"].unsqueeze(0)
            attention_mask = item["attention_mask"].unsqueeze(0)
            label = item["label"].item()

            output = model(input_ids=input_ids, attention_mask=attention_mask)
            pred = torch.argmax(output.logits, dim=1).item()

            preds.append(pred)
            labels.append(label)

    return classification_report(labels, preds, output_dict=True)