Spaces:
Sleeping
Sleeping
File size: 2,357 Bytes
ed657fc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
import torch
from torch.nn import CrossEntropyLoss
import numpy as np
import matplotlib.pyplot as plt
def make_predictions(model, dataloader, device):
    """Evaluate ``model`` over every batch of ``dataloader`` without gradients.

    Each batch is expected to be a mapping like::

        batch["image"] -> Tensor [B, 3, 256, 256]
        batch["label"] -> Tensor [B]

    (shapes are the intended format; they are not checked here).

    Args:
        model: a ``torch.nn.Module`` whose output has class scores along
            dim 1 (it is fed to ``CrossEntropyLoss`` and ``argmax(dim=1)``).
            The model is switched to eval mode and left in eval mode.
        dataloader: iterable of batches as described above.
        device: target passed to ``Tensor.to`` before the forward pass.

    Returns:
        dict with:
            "accuracy": fraction of samples whose argmax matches the label.
            "loss": per-sample mean cross-entropy over the whole dataloader.
            "predictions": numpy array of predicted class indices, in order.
            "labels": numpy array of ground-truth labels, in order.

    Raises:
        ValueError: if the dataloader yields no samples at all.
    """
    model.eval()
    criterion = CrossEntropyLoss()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for batch in dataloader:
            # Move tensors to device
            images = batch["image"].to(device)
            labels = batch["label"].to(device).long()
            batch_size = labels.size(0)
            if batch_size == 0:
                # A mean-reduced loss over zero elements is NaN and would
                # poison total_loss (NaN * 0 is still NaN); skip it.
                continue
            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)
            preds = outputs.argmax(dim=1)
            # CrossEntropyLoss averages over the batch by default; re-weight
            # by batch size so a smaller final batch does not skew the mean.
            total_loss += loss.item() * batch_size
            total_correct += (preds == labels).sum().item()
            total_samples += batch_size
            # Accumulate all predictions and labels (tolist copies to host)
            all_preds.extend(preds.tolist())
            all_labels.extend(labels.tolist())
    if total_samples == 0:
        raise ValueError("dataloader yielded no samples; cannot compute metrics")
    accuracy = total_correct / total_samples
    avg_loss = total_loss / total_samples
    return {
        "accuracy": accuracy,
        "loss": avg_loss,
        "predictions": np.array(all_preds),
        "labels": np.array(all_labels),
    }
def class_accuracies(labels, preds, num_classes):
    """Return the per-class accuracy (recall) of ``preds`` against ``labels``.

    Args:
        labels: 1-D sequence/array of integer ground-truth class indices.
        preds: predicted class indices, same length as ``labels``.
        num_classes: number of classes; valid labels are ``0..num_classes-1``.

    Returns:
        float ndarray of length ``num_classes``. Entry ``i`` is
        ``correct_i / count_i`` rounded to 4 decimals, or ``0.0`` when class
        ``i`` never occurs in ``labels``.

    Raises:
        ValueError: if ``labels`` and ``preds`` differ in shape, or a label
            lies outside ``[0, num_classes)``.
    """
    labels = np.asarray(labels)
    preds = np.asarray(preds)
    if labels.shape != preds.shape:
        raise ValueError(
            f"labels and preds must have the same shape, got {labels.shape} and {preds.shape}"
        )
    accuracies = np.zeros(num_classes, dtype=float)
    if labels.size == 0:
        return accuracies
    # Validate explicitly: a negative label would otherwise be counted against
    # a wrong class (negative indexing) instead of failing.
    if labels.min() < 0 or labels.max() >= num_classes:
        raise ValueError(f"labels must lie in [0, {num_classes}), got range "
                         f"[{labels.min()}, {labels.max()}]")
    counts = np.bincount(labels, minlength=num_classes)
    correct = np.bincount(labels[labels == preds], minlength=num_classes)
    seen = counts > 0
    # Classes with no samples keep their initial 0.0.
    accuracies[seen] = np.round(correct[seen] / counts[seen], 4)
    return accuracies
def plot_class_accuracies(accuracies, class_names):
    """Draw a bar chart of per-class accuracy and return the figure.

    Args:
        accuracies: one value in [0, 1] per class, e.g. the output of
            ``class_accuracies``.
        class_names: label for each bar, same length as ``accuracies``.

    Returns:
        The ``matplotlib.figure.Figure``. It is created through pyplot and
        stays open; the caller is responsible for closing it.

    Raises:
        ValueError: if ``accuracies`` and ``class_names`` differ in length.
    """
    if len(accuracies) != len(class_names):
        raise ValueError(
            f"got {len(accuracies)} accuracies but {len(class_names)} class names"
        )
    # Plot at integer positions and attach names as tick labels: passing the
    # names directly makes a categorical axis where duplicate names collapse
    # onto one position and numeric names are treated as coordinates.
    positions = np.arange(len(class_names))
    fig, ax = plt.subplots(figsize=(12, 6))
    ax.set_title("Per-Class Accuracy")
    ax.set_xlabel("Class")
    ax.set_ylabel("Accuracy")
    ax.set_ylim(0, 1.0)
    ax.bar(positions, accuracies)
    ax.set_xticks(positions)
    ax.set_xticklabels([str(name) for name in class_names], rotation=90)
    # Operate on this figure explicitly rather than pyplot's "current" one.
    fig.tight_layout()
    return fig
|