""" Evaluation functions for Pneumonia classification. """ import torch import torch.nn as nn from torch.utils.data import DataLoader import numpy as np from typing import Dict, Tuple from sklearn.metrics import ( accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix, classification_report ) from .config import CLASS_NAMES def predict_proba( model: nn.Module, loader: DataLoader, device: torch.device ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """Get predictions, probabilities, and true labels.""" model.eval() all_probs, all_preds, all_labels = [], [], [] with torch.no_grad(): for images, labels in loader: images = images.to(device) outputs = model(images) probs = torch.sigmoid(outputs).cpu().numpy() preds = (probs > 0.5).astype(int) all_probs.extend(probs.flatten()) all_preds.extend(preds.flatten()) all_labels.extend(labels.numpy()) return np.array(all_probs), np.array(all_preds), np.array(all_labels) def compute_metrics(y_true: np.ndarray, y_pred: np.ndarray, y_proba: np.ndarray) -> Dict: """Compute all evaluation metrics.""" return { 'accuracy': accuracy_score(y_true, y_pred), 'precision': precision_score(y_true, y_pred), 'recall': recall_score(y_true, y_pred), 'f1': f1_score(y_true, y_pred), 'roc_auc': roc_auc_score(y_true, y_proba), 'confusion_matrix': confusion_matrix(y_true, y_pred) } def evaluate_model( model: nn.Module, loader: DataLoader, device: torch.device ) -> Dict: """Full evaluation on a dataset.""" probs, preds, labels = predict_proba(model, loader, device) metrics = compute_metrics(labels, preds, probs) print("=" * 50) print("EVALUATION RESULTS") print("=" * 50) print(f"Accuracy: {metrics['accuracy']:.4f}") print(f"Precision: {metrics['precision']:.4f}") print(f"Recall: {metrics['recall']:.4f}") print(f"F1 Score: {metrics['f1']:.4f}") print(f"ROC-AUC: {metrics['roc_auc']:.4f}") print("\nConfusion Matrix:") print(f" {CLASS_NAMES[0]:>10} {CLASS_NAMES[1]:>10}") for i, row in enumerate(metrics['confusion_matrix']): print(f" {CLASS_NAMES[i]:>10} {row[0]:>10} {row[1]:>10}") print("\nClassification Report:") print(classification_report(labels, preds, target_names=CLASS_NAMES)) return metrics def get_predictions_with_paths( model: nn.Module, dataset, device: torch.device ) -> list: """Get predictions with image paths for error analysis.""" model.eval() results = [] with torch.no_grad(): for idx in range(len(dataset)): image, label = dataset[idx] image = image.unsqueeze(0).to(device) output = model(image) prob = torch.sigmoid(output).item() pred = 1 if prob > 0.5 else 0 results.append({ 'path': dataset.image_paths[idx], 'true_label': label, 'pred_label': pred, 'probability': prob, 'correct': pred == label }) return results