PneumoniaAPI / src/evaluate.py
GitHub Actions
Auto-deploy from GitHub: 495db78a06be79166200269bb14d9e9b1e8906d6
af59988
"""
Evaluation functions for Pneumonia classification.
"""
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
from typing import Dict, Tuple
from sklearn.metrics import (
accuracy_score, precision_score, recall_score, f1_score,
roc_auc_score, confusion_matrix, classification_report
)
from .config import CLASS_NAMES
def predict_proba(
    model: nn.Module,
    loader: DataLoader,
    device: torch.device
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Run inference over *loader* and collect sigmoid probabilities,
    thresholded (0.5) predictions, and the ground-truth labels.

    Returns three 1-D numpy arrays: (probabilities, predictions, labels).
    """
    model.eval()
    collected_probs: list = []
    collected_preds: list = []
    collected_labels: list = []
    with torch.no_grad():
        for batch_images, batch_labels in loader:
            logits = model(batch_images.to(device))
            batch_probs = torch.sigmoid(logits).cpu().numpy()
            collected_probs.extend(batch_probs.flatten())
            collected_preds.extend((batch_probs > 0.5).astype(int).flatten())
            collected_labels.extend(batch_labels.numpy())
    return (
        np.array(collected_probs),
        np.array(collected_preds),
        np.array(collected_labels),
    )
def compute_metrics(y_true: np.ndarray, y_pred: np.ndarray, y_proba: np.ndarray) -> Dict:
    """Assemble the standard binary-classification metrics into one dict.

    Keys: accuracy, precision, recall, f1, roc_auc, confusion_matrix.
    """
    metrics = {}
    metrics['accuracy'] = accuracy_score(y_true, y_pred)
    metrics['precision'] = precision_score(y_true, y_pred)
    metrics['recall'] = recall_score(y_true, y_pred)
    metrics['f1'] = f1_score(y_true, y_pred)
    # ROC-AUC needs the raw probabilities, not the hard predictions.
    metrics['roc_auc'] = roc_auc_score(y_true, y_proba)
    metrics['confusion_matrix'] = confusion_matrix(y_true, y_pred)
    return metrics
def evaluate_model(
    model: nn.Module,
    loader: DataLoader,
    device: torch.device
) -> Dict:
    """Run inference on *loader*, print a full metrics report, and
    return the computed metrics dict."""
    probs, preds, labels = predict_proba(model, loader, device)
    metrics = compute_metrics(labels, preds, probs)
    separator = "=" * 50
    print(separator)
    print("EVALUATION RESULTS")
    print(separator)
    # (display title, metrics key) pairs for the scalar metrics.
    scalar_rows = (
        ("Accuracy", "accuracy"),
        ("Precision", "precision"),
        ("Recall", "recall"),
        ("F1 Score", "f1"),
        ("ROC-AUC", "roc_auc"),
    )
    for title, key in scalar_rows:
        print(f"{title}: {metrics[key]:.4f}")
    print("\nConfusion Matrix:")
    print(f" {CLASS_NAMES[0]:>10} {CLASS_NAMES[1]:>10}")
    # Rows of the confusion matrix are indexed by true class.
    for class_name, counts in zip(CLASS_NAMES, metrics['confusion_matrix']):
        print(f" {class_name:>10} {counts[0]:>10} {counts[1]:>10}")
    print("\nClassification Report:")
    print(classification_report(labels, preds, target_names=CLASS_NAMES))
    return metrics
def get_predictions_with_paths(
    model: nn.Module,
    dataset,
    device: torch.device
) -> list:
    """Get per-sample predictions with image paths for error analysis.

    Args:
        model: Trained binary classifier producing a single logit per image.
        dataset: Dataset yielding ``(image, label)`` pairs and exposing an
            ``image_paths`` sequence aligned with sample indices.
        device: Device to run inference on.

    Returns:
        A list of dicts, one per sample, with keys ``path``, ``true_label``,
        ``pred_label``, ``probability`` and ``correct``.
    """
    model.eval()
    results = []
    with torch.no_grad():
        for idx in range(len(dataset)):
            image, label = dataset[idx]
            output = model(image.unsqueeze(0).to(device))
            prob = torch.sigmoid(output).item()
            pred = 1 if prob > 0.5 else 0
            results.append({
                'path': dataset.image_paths[idx],
                'true_label': label,
                'pred_label': pred,
                'probability': prob,
                # bool(...) guards against datasets that return tensor
                # labels, where `pred == label` would otherwise be a 0-d
                # tensor rather than a plain Python bool.
                'correct': bool(pred == label)
            })
    return results