import os
import sys
from pathlib import Path

# Add project root to sys.path so project-local modules resolve when this
# file is run directly as a script.
sys.path.append(str(Path(__file__).parent.parent))

import matplotlib  # noqa: E402

matplotlib.use("Agg")  # Use headless backend (plots are only saved to disk)

import matplotlib.pyplot as plt  # noqa: E402
import torch  # noqa: E402
import torch.nn as nn  # noqa: E402
import yaml  # noqa: E402
from sklearn.metrics import (  # noqa: E402
    ConfusionMatrixDisplay,
    classification_report,
    confusion_matrix,
)


def load_config(config_path="config.yaml"):
    """Load and return the parsed YAML configuration at *config_path*.

    The path is resolved relative to the current working directory.
    ``yaml.safe_load`` is used, so the file cannot construct arbitrary
    Python objects.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)


# NOTE: the configuration is read at import time, so importing this module
# requires a ``config.yaml`` (with ``classes`` and ``device`` keys) in the
# current working directory.
config = load_config()
CLASSES = config["classes"]


def get_device(config_device):
    """Resolve the configured device string.

    ``"auto"`` selects ``"cuda"`` when CUDA is available and ``"cpu"``
    otherwise; any other value is returned unchanged.
    """
    if config_device == "auto":
        return "cuda" if torch.cuda.is_available() else "cpu"
    return config_device


DEVICE = get_device(config["device"])


def evaluate(model, data_loader, device=DEVICE, save_dir="models/plots"):
    """
    Evaluates a PyTorch model on a given DataLoader.

    Prints the average loss, accuracy and a per-class classification report,
    and saves a confusion-matrix plot to ``<save_dir>/confusion_matrix.png``.

    Args:
        model: The PyTorch model to evaluate. It is moved to ``device`` and
            left in eval mode.
        data_loader: The DataLoader providing ``(images, labels)`` batches.
        device: The device to run evaluation on (e.g., 'cuda', 'cpu').
        save_dir: Directory to save plots (created if missing).

    Returns:
        avg_loss (float): The average per-sample loss over the dataset.
        accuracy (float): The classification accuracy (0.0 to 1.0).

    Raises:
        ValueError: If ``data_loader`` yields no samples.
    """
    model.to(device)
    # NOTE: the model is left in eval mode; callers that resume training must
    # call ``model.train()`` themselves.
    model.eval()
    criterion = nn.CrossEntropyLoss()

    total_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            batch_size = labels.size(0)

            # CrossEntropyLoss (default reduction="mean") returns the batch
            # mean; re-weight by batch size so a smaller final batch does not
            # skew the dataset-level average.
            total_loss += criterion(outputs, labels).item() * batch_size

            predicted = outputs.argmax(dim=1)
            total += batch_size
            correct += (predicted == labels).sum().item()
            all_preds.extend(predicted.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

    if total == 0:
        raise ValueError("data_loader yielded no samples; cannot evaluate.")

    avg_loss = total_loss / total
    accuracy = correct / total

    print("\nEvaluation Results:")
    print(f"Average Loss: {avg_loss:.4f}")
    print(f"Accuracy: {accuracy:.4f}")

    # Pass the full label set explicitly so classes absent from this split
    # still appear in the report and the confusion matrix.
    label_ids = list(range(len(CLASSES)))

    # Classification Report
    print("\nClassification Report:")
    report = classification_report(
        all_labels,
        all_preds,
        target_names=CLASSES,
        labels=label_ids,
        zero_division=0,
    )
    print(report)

    # Confusion Matrix
    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)

    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, "confusion_matrix.png")

    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=CLASSES)
        disp.plot(cmap=plt.cm.Blues, ax=ax, xticks_rotation=45)
        ax.set_title("Confusion Matrix")
        fig.tight_layout()
        fig.savefig(save_path)
    finally:
        # Always release the figure: repeated calls (e.g. once per epoch)
        # would otherwise accumulate open figures and leak memory.
        plt.close(fig)
    print(f"\nConfusion matrix saved to {save_path}")

    return avg_loss, accuracy