Spaces:
Sleeping
Sleeping
| import torch | |
| from torch.nn import CrossEntropyLoss | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
def make_predictions(model, dataloader, device):
    """Evaluate ``model`` on every batch of ``dataloader`` and collect predictions.

    Each batch yielded by ``dataloader`` must be a mapping with:

    * ``batch["image"]`` -- input tensor fed directly to ``model``
      (e.g. ``[B, 3, 256, 256]`` images)
    * ``batch["label"]`` -- integer class targets of shape ``[B]``

    Args:
        model: Module mapping ``batch["image"]`` to class logits ``[B, C]``.
        dataloader: Iterable of batch dicts as described above.
        device: Device the image and label tensors are moved to.

    Returns:
        dict with keys:

        * ``"accuracy"`` -- fraction of samples where ``argmax(logits) == label``
        * ``"loss"`` -- mean cross-entropy over all samples
        * ``"predictions"`` -- int64 numpy array of predicted class indices
        * ``"labels"`` -- int64 numpy array of the true labels

        If the dataloader yields no samples, ``"accuracy"`` and ``"loss"``
        are ``nan`` and both arrays are empty.

    The model is switched to eval mode for the pass. If it was in training
    mode beforehand, training mode is restored on exit, even on error.
    """
    was_training = getattr(model, "training", False)
    model.eval()
    # Summed (not averaged) per batch, so dividing by the sample count at the
    # end gives the exact per-sample mean even when the last batch is smaller.
    criterion = CrossEntropyLoss(reduction="sum")

    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    all_preds = []
    all_labels = []

    try:
        with torch.no_grad():
            for batch in dataloader:
                images = batch["image"].to(device)
                labels = batch["label"].to(device).long()

                outputs = model(images)
                preds = outputs.argmax(dim=1)

                total_loss += criterion(outputs, labels).item()
                total_correct += (preds == labels).sum().item()
                total_samples += labels.size(0)

                all_preds.extend(preds.tolist())
                all_labels.extend(labels.tolist())
    finally:
        # Don't leave a model that was mid-training stuck in eval mode.
        if was_training:
            model.train()

    if total_samples == 0:
        accuracy = avg_loss = float("nan")
    else:
        accuracy = total_correct / total_samples
        avg_loss = total_loss / total_samples

    return {
        "accuracy": accuracy,
        "loss": avg_loss,
        "predictions": np.array(all_preds, dtype=np.int64),
        "labels": np.array(all_labels, dtype=np.int64),
    }
def class_accuracies(labels, preds, num_classes, decimals=4):
    """Compute the accuracy of each class separately.

    Args:
        labels: Sequence/array of true integer class indices.
        preds: Sequence/array of predicted class indices, same length as
            ``labels``.
        num_classes: Number of classes; every label must lie in
            ``[0, num_classes)``.
        decimals: Decimal places each accuracy is rounded to (default 4,
            matching the previous behaviour). Pass ``None`` to skip rounding.

    Returns:
        Float numpy array of length ``num_classes``. Entry ``i`` is
        ``(# samples of class i predicted correctly) / (# samples of class i)``.
        Classes with no samples get ``0.0``.

    Raises:
        ValueError: if ``labels`` and ``preds`` have different lengths.
        IndexError: if any label is outside ``[0, num_classes)``.
    """
    labels = np.asarray(labels, dtype=np.int64).ravel()
    preds = np.asarray(preds).ravel()
    if labels.shape != preds.shape:
        raise ValueError(
            f"labels and preds must have the same length "
            f"({labels.size} != {preds.size})"
        )
    # Validate explicitly: a negative label would otherwise wrap around via
    # negative indexing and be silently counted toward the last class.
    if labels.size and (labels.min() < 0 or labels.max() >= num_classes):
        raise IndexError(
            f"labels must lie in [0, {num_classes}), "
            f"got range [{labels.min()}, {labels.max()}]"
        )

    counts = np.bincount(labels, minlength=num_classes)
    correct = np.bincount(labels[labels == preds], minlength=num_classes)

    accuracies = np.zeros(num_classes, dtype=float)
    seen = counts > 0
    accuracies[seen] = correct[seen] / counts[seen]
    if decimals is not None:
        accuracies = np.round(accuracies, decimals)
    return accuracies
def plot_class_accuracies(accuracies, class_names):
    """Draw a bar chart of per-class accuracy.

    Args:
        accuracies: Sequence of accuracy values, one per class; the y-axis is
            fixed to ``[0, 1]``.
        class_names: Class labels in the same order as ``accuracies``;
            converted to ``str`` for the tick labels.

    Returns:
        The matplotlib Figure holding the chart. It is created with
        ``plt.subplots``, so it stays registered with pyplot until the caller
        closes it (e.g. ``plt.close(fig)``).

    Raises:
        ValueError: if ``accuracies`` and ``class_names`` differ in length.
    """
    names = [str(name) for name in class_names]
    if len(names) != len(accuracies):
        raise ValueError(
            f"got {len(accuracies)} accuracies but {len(names)} class names"
        )

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.set_title("Per-Class Accuracy")
    ax.set_xlabel("Class")
    ax.set_ylabel("Accuracy")
    ax.set_ylim(0, 1.0)

    # Plot at integer positions and attach the names as tick labels. Passing the
    # names directly would make a categorical axis, which merges duplicate names
    # into one bar and places numeric-looking names by value.
    positions = np.arange(len(names))
    ax.bar(positions, accuracies)
    ax.set_xticks(positions)
    ax.set_xticklabels(names, rotation=90)

    # Operate on this figure explicitly rather than pyplot's "current" figure.
    fig.tight_layout()
    return fig