# Bert-model-test / validate.py
# Uploaded by ganeshkonapalli ("Upload 8 files", commit 8505a58); original size 690 bytes.
from app.model_utils import load_model, load_data
from sklearn.metrics import accuracy_score
import torch
def validate_model():
model, tokenizer = load_model()
_, val_dataset, _ = load_data(tokenizer)
correct = 0
total = 0
for item in val_dataset:
input_ids = item["input_ids"].unsqueeze(0)
attention_mask = item["attention_mask"].unsqueeze(0)
label = item["label"].item()
with torch.no_grad():
output = model(input_ids=input_ids, attention_mask=attention_mask)
pred = torch.argmax(output.logits, dim=1).item()
correct += int(pred == label)
total += 1
return {"accuracy": correct / total}