"""Model evaluation utilities.

Evaluates a trained model on a dataloader that returns batches like:
    batch["image"] -> Tensor [B, 3, 256, 256]
    batch["label"] -> Tensor [B]
"""

import numpy as np
import torch
from torch.nn import CrossEntropyLoss


def make_predictions(model, dataloader, device):
    """Run ``model`` over ``dataloader`` and collect loss, accuracy, predictions.

    Args:
        model: ``torch.nn.Module`` producing class logits of shape [B, C].
        dataloader: iterable of dicts with "image" and "label" tensors.
        device: device to run inference on.

    Returns:
        dict with keys:
            "accuracy": fraction of correct predictions (0.0 if no samples),
            "loss": sample-weighted average loss (0.0 if no samples),
            "predictions": 1-D numpy array of predicted class indices,
            "labels": 1-D numpy array of ground-truth labels.
    """
    model.eval()
    criterion = CrossEntropyLoss()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in dataloader:
            # Move tensors to device
            images = batch["image"].to(device)
            labels = batch["label"].to(device).long()

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)
            preds = outputs.argmax(dim=1)

            # Weight the (batch-mean) loss by batch size so the final average
            # is correct even when the last batch is smaller than the others.
            total_loss += loss.item() * images.size(0)
            total_correct += (preds == labels).sum().item()
            total_samples += labels.size(0)

            # Accumulate all predictions and labels
            all_preds.extend(preds.tolist())
            all_labels.extend(labels.tolist())

    # Guard against an empty dataloader instead of raising ZeroDivisionError.
    if total_samples > 0:
        accuracy = total_correct / total_samples
        avg_loss = total_loss / total_samples
    else:
        accuracy = 0.0
        avg_loss = 0.0

    return {
        "accuracy": accuracy,
        "loss": avg_loss,
        "predictions": np.array(all_preds),
        "labels": np.array(all_labels),
    }


def class_accuracies(labels, preds, num_classes):
    """Compute per-class accuracy, rounded to 4 decimal places.

    Args:
        labels: iterable of ground-truth class indices in [0, num_classes).
        preds: iterable of predicted class indices, same length as ``labels``.
        num_classes: total number of classes.

    Returns:
        numpy float array of length ``num_classes``; classes with no
        samples get an accuracy of 0.0.
    """
    correct = np.zeros(num_classes, dtype=int)
    counts = np.zeros(num_classes, dtype=int)
    accuracies = np.zeros(num_classes, dtype=float)

    for true, pred in zip(labels, preds):
        counts[true] += 1
        if true == pred:
            correct[true] += 1

    # Calculate accuracies (rounding kept for backward compatibility with
    # existing callers/reports that expect 4-decimal values).
    for i in range(num_classes):
        if counts[i] > 0:
            accuracies[i] = round(correct[i] / counts[i], 4)
        else:
            accuracies[i] = 0.0
    return accuracies


def plot_class_accuracies(accuracies, class_names):
    """Return a matplotlib Figure with a bar chart of per-class accuracies.

    Args:
        accuracies: sequence of per-class accuracy values in [0, 1].
        class_names: sequence of class labels for the x-axis, same length.

    Returns:
        The created ``matplotlib.figure.Figure``.
    """
    # Lazy import: only the plotting helper needs matplotlib, so the
    # evaluation functions stay usable where it is not installed.
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.set_title("Per-Class Accuracy")
    ax.set_xlabel("Class")
    ax.set_ylabel("Accuracy")
    ax.set_ylim(0, 1.0)
    ax.bar(class_names, accuracies)
    plt.xticks(rotation=90)
    plt.tight_layout()
    return fig