from app.model_utils import load_model, load_data
from sklearn.metrics import accuracy_score
import torch


def _evaluate_accuracy(model, dataset):
    """Return the fraction of *dataset* items whose argmax prediction equals the label.

    Each item is indexed with ``"input_ids"``, ``"attention_mask"`` and
    ``"label"``; the first two are treated as unbatched tensors (a batch
    dimension is added with ``unsqueeze(0)``) and ``"label"`` as a
    one-element tensor. NOTE(review): shapes assumed from the indexing
    below -- confirm against ``load_data``.

    Args:
        model: A torch module whose forward accepts ``input_ids`` and
            ``attention_mask`` keyword arguments and returns an object with
            a ``logits`` attribute.
        dataset: An iterable of such items.

    Returns:
        Accuracy as a Python ``float`` in ``[0, 1]``.

    Raises:
        ValueError: If *dataset* yields no items (accuracy is undefined).
    """
    # Switch off dropout / batch-norm updates; without this the validation
    # score is noisy and not reproducible.
    model.eval()
    # Feed inputs on the same device as the weights so a GPU-loaded model
    # works; this is a no-op when everything is on the CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    predictions = []
    labels = []
    # Enter no_grad once for the whole pass rather than per item.
    with torch.no_grad():
        for item in dataset:
            input_ids = item["input_ids"].unsqueeze(0).to(device)
            attention_mask = item["attention_mask"].unsqueeze(0).to(device)
            output = model(input_ids=input_ids, attention_mask=attention_mask)
            predictions.append(torch.argmax(output.logits, dim=1).item())
            labels.append(item["label"].item())

    if not labels:
        raise ValueError("validation dataset is empty; accuracy is undefined")
    # float() keeps the return a plain Python float regardless of the
    # sklearn version's return type.
    return float(accuracy_score(labels, predictions))


def validate_model():
    """Load the model and validation split, and report classification accuracy.

    Uses ``load_model()`` for the model/tokenizer and the second element of
    ``load_data(tokenizer)`` as the validation dataset.

    Returns:
        ``{"accuracy": float}``

    Raises:
        ValueError: If the validation dataset is empty.
    """
    model, tokenizer = load_model()
    _, val_dataset, _ = load_data(tokenizer)
    return {"accuracy": _evaluate_accuracy(model, val_dataset)}