"""
Evaluation functions for Pneumonia classification.
"""
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
from typing import Dict, Tuple
from sklearn.metrics import (
accuracy_score, precision_score, recall_score, f1_score,
roc_auc_score, confusion_matrix, classification_report
)
from .config import CLASS_NAMES
def predict_proba(
    model: nn.Module,
    loader: DataLoader,
    device: torch.device
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Get predictions, probabilities, and true labels."""
    model.eval()
    all_probs, all_preds, all_labels = [], [], []
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            outputs = model(images)
            # Single-logit binary head: sigmoid gives P(positive class).
            probs = torch.sigmoid(outputs).cpu().numpy()
            preds = (probs > 0.5).astype(int)
            all_probs.extend(probs.flatten())
            all_preds.extend(preds.flatten())
            all_labels.extend(labels.numpy())
    return np.array(all_probs), np.array(all_preds), np.array(all_labels)

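# NOTE: the 0.5 cut-off above is a fixed operating point. Because
# predict_proba also returns the raw probabilities, other thresholds can be
# evaluated without re-running the model. A minimal sketch (the loader name
# below is illustrative, not something defined in this module):
#
#     probs, _, labels = predict_proba(model, val_loader, device)
#     for t in (0.3, 0.5, 0.7):
#         preds_t = (probs > t).astype(int)
#         print(f"threshold={t}: F1={f1_score(labels, preds_t):.4f}")
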
def compute_metrics(y_true: np.ndarray, y_pred: np.ndarray, y_proba: np.ndarray) -> Dict:
    """Compute all evaluation metrics."""
    return {
        'accuracy': accuracy_score(y_true, y_pred),
        'precision': precision_score(y_true, y_pred),
        'recall': recall_score(y_true, y_pred),
        'f1': f1_score(y_true, y_pred),
        'roc_auc': roc_auc_score(y_true, y_proba),
        'confusion_matrix': confusion_matrix(y_true, y_pred)
    }

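# A quick sanity check of compute_metrics on toy arrays (illustrative
# values, not drawn from any dataset):
#
#     y_true  = np.array([0, 1, 1, 0])
#     y_proba = np.array([0.2, 0.9, 0.4, 0.1])
#     y_pred  = (y_proba > 0.5).astype(int)   # -> [0, 1, 0, 0]
#     compute_metrics(y_true, y_pred, y_proba)['accuracy']  # -> 0.75
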
def evaluate_model(
    model: nn.Module,
    loader: DataLoader,
    device: torch.device
) -> Dict:
    """Full evaluation on a dataset."""
    probs, preds, labels = predict_proba(model, loader, device)
    metrics = compute_metrics(labels, preds, probs)

    print("=" * 50)
    print("EVALUATION RESULTS")
    print("=" * 50)
    print(f"Accuracy:  {metrics['accuracy']:.4f}")
    print(f"Precision: {metrics['precision']:.4f}")
    print(f"Recall:    {metrics['recall']:.4f}")
    print(f"F1 Score:  {metrics['f1']:.4f}")
    print(f"ROC-AUC:   {metrics['roc_auc']:.4f}")

    print("\nConfusion Matrix:")
    # Rows are true classes, columns are predicted classes.
    print(f"{'':>10} {CLASS_NAMES[0]:>10} {CLASS_NAMES[1]:>10}")
    for i, row in enumerate(metrics['confusion_matrix']):
        print(f"{CLASS_NAMES[i]:>10} {row[0]:>10} {row[1]:>10}")

    print("\nClassification Report:")
    print(classification_report(labels, preds, target_names=CLASS_NAMES))
    return metrics

def get_predictions_with_paths(
    model: nn.Module,
    dataset,
    device: torch.device
) -> list:
    """Get predictions with image paths for error analysis."""
    model.eval()
    results = []
    with torch.no_grad():
        for idx in range(len(dataset)):
            image, label = dataset[idx]
            image = image.unsqueeze(0).to(device)  # add batch dimension
            output = model(image)
            prob = torch.sigmoid(output).item()
            pred = 1 if prob > 0.5 else 0
            label = int(label)  # the dataset may return a tensor label
            results.append({
                'path': dataset.image_paths[idx],
                'true_label': label,
                'pred_label': pred,
                'probability': prob,
                'correct': pred == label
            })
    return results
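
# Example usage (a sketch: `build_model`, `get_test_loader`, `test_dataset`,
# and `checkpoint.pt` are assumptions about the surrounding project, not
# names defined in this package):
#
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     model = build_model().to(device)
#     model.load_state_dict(torch.load("checkpoint.pt", map_location=device))
#     metrics = evaluate_model(model, get_test_loader(), device)
#
#     # Collect misclassified images for error analysis:
#     errors = [r for r in get_predictions_with_paths(model, test_dataset, device)
#               if not r['correct']]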