# Author: Yusuf
# Feature: per-class accuracy
# Commit: ed657fc
import torch
from torch.nn import CrossEntropyLoss
import numpy as np
import matplotlib.pyplot as plt
"""
Evaluates a trained model on a dataloader that returns batches like:
batch["image"] -> Tensor [B, 3, 256, 256]
batch["label"] -> Tensor [B]
"""
def make_predictions(model, dataloader, device):
    """Run inference over ``dataloader`` and collect aggregate metrics.

    Parameters
    ----------
    model : torch.nn.Module
        Trained classifier; called as ``model(images)`` and expected to
        return class logits of shape [B, num_classes].
    dataloader : iterable
        Yields dicts with keys ``"image"`` ([B, 3, 256, 256] per the note
        above — TODO confirm) and ``"label"`` ([B]).
    device : torch.device or str
        Device to move each batch to before the forward pass.

    Returns
    -------
    dict with keys:
        "accuracy"    : float, overall fraction of correct predictions.
        "loss"        : float, per-sample mean cross-entropy loss.
        "predictions" : np.ndarray of predicted class indices.
        "labels"      : np.ndarray of ground-truth class indices.

    Raises
    ------
    ValueError
        If the dataloader yields no samples (previously this surfaced as
        an opaque ZeroDivisionError).
    """
    model.eval()
    criterion = CrossEntropyLoss()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for batch in dataloader:
            # Move tensors to device
            images = batch["image"].to(device)
            labels = batch["label"].to(device).long()
            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)
            preds = outputs.argmax(dim=1)
            # criterion returns the batch mean; weight by batch size so the
            # final average is a true per-sample mean even with a ragged
            # last batch.
            total_loss += loss.item() * images.size(0)
            total_correct += (preds == labels).sum().item()
            total_samples += labels.size(0)
            # Accumulate all predictions and labels
            all_preds.extend(preds.tolist())
            all_labels.extend(labels.tolist())
    if total_samples == 0:
        raise ValueError("dataloader yielded no samples")
    accuracy = total_correct / total_samples
    avg_loss = total_loss / total_samples
    return {
        "accuracy": accuracy,
        "loss": avg_loss,
        "predictions": np.array(all_preds),
        "labels": np.array(all_labels),
    }
def class_accuracies(labels, preds, num_classes):
    """Compute the accuracy of each class separately.

    Parameters
    ----------
    labels : array-like of int
        Ground-truth class indices, each in [0, num_classes).
    preds : array-like of int
        Predicted class indices, same length as ``labels``.
    num_classes : int
        Number of classes; fixes the length of the returned array.

    Returns
    -------
    np.ndarray of float, shape [num_classes]
        Per-class accuracy rounded to 4 decimal places; classes with no
        ground-truth samples get 0.0.
    """
    # Force int dtype so an empty input still works with np.bincount
    # (np.asarray([]) would default to float64 and fail).
    labels = np.asarray(labels, dtype=int)
    preds = np.asarray(preds, dtype=int)
    # Vectorized tallies replace the per-sample Python loop.
    counts = np.bincount(labels, minlength=num_classes)
    correct = np.bincount(labels[labels == preds], minlength=num_classes)
    accuracies = np.zeros(num_classes, dtype=float)
    seen = counts > 0
    accuracies[seen] = np.round(correct[seen] / counts[seen], 4)
    return accuracies
def plot_class_accuracies(accuracies, class_names):
    """Render a bar chart of per-class accuracy.

    Parameters
    ----------
    accuracies : sequence of float
        One accuracy value per class, expected in [0, 1].
    class_names : sequence of str
        Tick labels for the bars, same length as ``accuracies``.

    Returns
    -------
    matplotlib.figure.Figure
        The figure containing the chart (not shown or saved here).
    """
    figure, axis = plt.subplots(figsize=(12, 6))
    axis.bar(class_names, accuracies)
    axis.set_title("Per-Class Accuracy")
    axis.set_xlabel("Class")
    axis.set_ylabel("Accuracy")
    # Accuracy is a fraction, so pin the y-axis to [0, 1].
    axis.set_ylim(0, 1.0)
    # Rotate tick labels so long class names stay readable.
    plt.xticks(rotation=90)
    plt.tight_layout()
    return figure