# CCS229_ALA / src/evaluate.py
# Uploaded by Gillie2004 ("Upload 4 files", commit 2222d7e).
from sklearn.metrics import classification_report, confusion_matrix
import torch
def evaluate_model(model, dataloader, device):
    """Run ``model`` over ``dataloader`` and summarise its classification quality.

    Parameters
    ----------
    model : torch.nn.Module
        Classifier whose forward output is reduced with ``argmax(dim=1)`` to
        obtain predicted class indices (presumably logits shaped
        ``(batch, num_classes)`` -- confirm against the model definition).
    dataloader : iterable
        Yields ``(inputs, labels)`` pairs of tensors.
    device : torch.device or str
        Device the inputs are moved to before the forward pass.

    Returns
    -------
    tuple
        ``(report, matrix)``: the output of
        ``sklearn.metrics.classification_report`` and of
        ``sklearn.metrics.confusion_matrix``, both computed on the true vs.
        predicted labels collected over every batch.

    Raises
    ------
    ValueError
        If ``dataloader`` yields no batches.

    The model's previous train/eval mode is restored before returning.
    """
    was_training = model.training
    model.eval()
    true_batches, pred_batches = [], []
    try:
        with torch.no_grad():
            for x, y in dataloader:
                # Only the inputs need to live on ``device``; labels are only
                # ever consumed on the CPU, so moving them to the device and
                # back would be wasted transfer work.
                preds = model(x.to(device)).argmax(dim=1)
                true_batches.append(y.cpu())
                pred_batches.append(preds.cpu())
    finally:
        # Restore the caller's mode so a training loop that evaluates between
        # epochs does not silently keep training with layers such as
        # dropout / batch-norm in eval mode.
        model.train(was_training)

    if not true_batches:
        raise ValueError("dataloader yielded no batches to evaluate")

    # One concatenation per side instead of growing Python lists element-wise.
    y_true = torch.cat(true_batches).numpy()
    y_pred = torch.cat(pred_batches).numpy()
    return classification_report(y_true, y_pred), confusion_matrix(y_true, y_pred)