TB-Guard-XAI / evaluate_model.py
Vignesh-19's picture
Upload folder using huggingface_hub
a041069 verified
# Comprehensive Model Evaluation β€” Optimized for GPU
import torch
import numpy as np
from pathlib import Path
from torch.utils.data import DataLoader
from sklearn.metrics import (
roc_auc_score, accuracy_score, precision_recall_fscore_support,
confusion_matrix, roc_curve, classification_report
)
import matplotlib.pyplot as plt
import json
from tqdm import tqdm
from ensemble_models import load_ensemble
from preprocessing import PreprocessedDataset, get_val_transforms
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODELS_DIR = Path("models")
PROCESSED_DIR = Path("datasets_processed")
OUTPUTS_DIR = Path("outputs/evaluation")
OUTPUTS_DIR.mkdir(parents=True, exist_ok=True)
# Maximize GPU utilization
BATCH_SIZE = 64 # Large batch to fill 16GB GPU
MC_SAMPLES = 20 # MC Dropout iterations
def load_dataset_split(split_dir):
"""Load images and labels"""
image_paths = []
labels = []
for cls, label in [("TB", 1), ("Normal", 0)]:
cls_dir = split_dir / cls
for img_path in cls_dir.glob("*"):
if img_path.suffix.lower() in ['.png', '.jpg', '.jpeg']:
image_paths.append(img_path)
labels.append(label)
return image_paths, labels
def evaluate_with_uncertainty_batched(model, dataloader, n_samples=20):
"""Batched MC Dropout evaluation β€” fast, uses full GPU"""
model.eval()
model.dropout.train() # Enable only dropout
all_means = []
all_stds = []
all_labels = []
with torch.no_grad(), torch.cuda.amp.autocast():
for images, labels in tqdm(dataloader, desc="Evaluating"):
images = images.to(DEVICE, non_blocking=True)
# Run MC Dropout samples in batch
batch_preds = []
for _ in range(n_samples):
pred = model._forward_with_dropout(images)
batch_preds.append(pred)
# Stack: [n_samples, batch_size]
batch_preds = torch.stack(batch_preds)
mean_pred = batch_preds.mean(dim=0).cpu().numpy()
std_pred = batch_preds.std(dim=0).cpu().numpy()
all_means.extend(mean_pred)
all_stds.extend(std_pred)
all_labels.extend(labels.numpy())
return np.array(all_means), np.array(all_stds), np.array(all_labels)
def calculate_calibration(predictions, labels, n_bins=10):
"""Calculate calibration metrics"""
bin_boundaries = np.linspace(0, 1, n_bins + 1)
bin_lowers = bin_boundaries[:-1]
bin_uppers = bin_boundaries[1:]
accuracies = []
confidences = []
bin_counts = []
for bin_lower, bin_upper in zip(bin_lowers, bin_uppers):
in_bin = (predictions >= bin_lower) & (predictions < bin_upper)
prop_in_bin = in_bin.mean()
if prop_in_bin > 0:
accuracy_in_bin = labels[in_bin].mean()
avg_confidence_in_bin = predictions[in_bin].mean()
accuracies.append(accuracy_in_bin)
confidences.append(avg_confidence_in_bin)
bin_counts.append(in_bin.sum())
else:
accuracies.append(0)
confidences.append(0)
bin_counts.append(0)
# Expected Calibration Error
ece = np.sum([
(bin_counts[i] / len(predictions)) * abs(accuracies[i] - confidences[i])
for i in range(n_bins)
])
return {
'ece': ece,
'accuracies': accuracies,
'confidences': confidences,
'bin_counts': bin_counts
}
def plot_calibration(calibration_data, save_path):
"""Plot reliability diagram"""
fig, ax = plt.subplots(figsize=(8, 8))
confidences = calibration_data['confidences']
accuracies = calibration_data['accuracies']
ax.plot([0, 1], [0, 1], 'k--', label='Perfect Calibration')
ax.plot(confidences, accuracies, 'o-', label=f'Model (ECE: {calibration_data["ece"]:.3f})')
ax.set_xlabel('Confidence', fontsize=12)
ax.set_ylabel('Accuracy', fontsize=12)
ax.set_title('Reliability Diagram', fontsize=14)
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(save_path, dpi=150)
plt.close()
def plot_roc_curve(labels, predictions, save_path):
"""Plot ROC curve"""
fpr, tpr, thresholds = roc_curve(labels, predictions)
auc = roc_auc_score(labels, predictions)
fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(fpr, tpr, label=f'ROC Curve (AUC: {auc:.3f})')
ax.plot([0, 1], [0, 1], 'k--', label='Random')
ax.set_xlabel('False Positive Rate', fontsize=12)
ax.set_ylabel('True Positive Rate', fontsize=12)
ax.set_title('ROC Curve', fontsize=14)
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(save_path, dpi=150)
plt.close()
def plot_uncertainty_distribution(uncertainties, labels, save_path):
"""Plot uncertainty distribution"""
fig, ax = plt.subplots(figsize=(10, 6))
tb_uncertainties = uncertainties[labels == 1]
normal_uncertainties = uncertainties[labels == 0]
ax.hist(tb_uncertainties, bins=30, alpha=0.5, label='TB', color='red')
ax.hist(normal_uncertainties, bins=30, alpha=0.5, label='Normal', color='blue')
ax.set_xlabel('Uncertainty (Std Dev)', fontsize=12)
ax.set_ylabel('Count', fontsize=12)
ax.set_title('Prediction Uncertainty Distribution', fontsize=14)
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(save_path, dpi=150)
plt.close()
def analyze_failure_cases(predictions, uncertainties, labels, image_paths, threshold=0.5):
"""Analyze failure cases"""
preds_binary = (predictions > threshold).astype(int)
failures = preds_binary != labels
failure_indices = np.where(failures)[0]
failure_cases = []
for idx in failure_indices:
failure_cases.append({
"image": str(image_paths[idx]),
"true_label": "TB" if labels[idx] == 1 else "Normal",
"predicted_label": "TB" if preds_binary[idx] == 1 else "Normal",
"probability": float(predictions[idx]),
"uncertainty": float(uncertainties[idx])
})
# Sort by uncertainty
failure_cases.sort(key=lambda x: x['uncertainty'], reverse=True)
return failure_cases
def main():
print("="*60)
print("Comprehensive Model Evaluation")
print("="*60)
# Load model
print("\nLoading model...")
model = load_ensemble(MODELS_DIR / "ensemble_best.pth", DEVICE)
# Load training results for threshold
with open(MODELS_DIR / "training_results.json") as f:
results = json.load(f)
threshold = results.get("best_threshold", 0.5)
print(f"Using threshold: {threshold:.3f}")
print(f"Batch size: {BATCH_SIZE}")
print(f"MC Dropout samples: {MC_SAMPLES}")
# Evaluate on test set
print("\nEvaluating on test set...")
test_paths, test_labels = load_dataset_split(PROCESSED_DIR / "test")
test_dataset = PreprocessedDataset(
test_paths, test_labels,
transforms=get_val_transforms(),
use_preprocessing=True
)
# Use DataLoader for batched processing
test_loader = DataLoader(
test_dataset, batch_size=BATCH_SIZE,
num_workers=0, pin_memory=True, shuffle=False
)
predictions, uncertainties, labels = evaluate_with_uncertainty_batched(
model, test_loader, n_samples=MC_SAMPLES
)
# Calculate metrics
print("\nCalculating metrics...")
preds_binary = (predictions > threshold).astype(int)
acc = accuracy_score(labels, preds_binary)
auc = roc_auc_score(labels, predictions)
precision, recall, f1, _ = precision_recall_fscore_support(labels, preds_binary, average='binary')
cm = confusion_matrix(labels, preds_binary)
tn, fp, fn, tp = cm.ravel()
specificity = tn / (tn + fp)
sensitivity = tp / (tp + fn)
# Calibration
print("Calculating calibration...")
calibration_data = calculate_calibration(predictions, labels)
# Results
evaluation_results = {
"test_metrics": {
"accuracy": float(acc),
"auc": float(auc),
"precision": float(precision),
"recall": float(recall),
"sensitivity": float(sensitivity),
"specificity": float(specificity),
"f1": float(f1)
},
"confusion_matrix": {
"true_negative": int(tn),
"false_positive": int(fp),
"false_negative": int(fn),
"true_positive": int(tp)
},
"calibration": {
"ece": float(calibration_data['ece'])
},
"uncertainty": {
"mean": float(uncertainties.mean()),
"std": float(uncertainties.std()),
"min": float(uncertainties.min()),
"max": float(uncertainties.max())
},
"threshold": float(threshold)
}
# Print results
print("\n" + "="*60)
print("TEST SET RESULTS")
print("="*60)
print(f"\nAccuracy: {acc:.4f}")
print(f"AUC: {auc:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall/Sensitivity: {recall:.4f}")
print(f"Specificity: {specificity:.4f}")
print(f"F1 Score: {f1:.4f}")
print(f"\nExpected Calibration Error: {calibration_data['ece']:.4f}")
print(f"\nConfusion Matrix:")
print(f" TN: {tn}, FP: {fp}")
print(f" FN: {fn}, TP: {tp}")
# Generate plots
print("\nGenerating plots...")
plot_calibration(calibration_data, OUTPUTS_DIR / "calibration.png")
plot_roc_curve(labels, predictions, OUTPUTS_DIR / "roc_curve.png")
plot_uncertainty_distribution(uncertainties, labels, OUTPUTS_DIR / "uncertainty_dist.png")
# Failure analysis
print("\nAnalyzing failure cases...")
failure_cases = analyze_failure_cases(predictions, uncertainties, labels, test_paths, threshold)
print(f"Total failures: {len(failure_cases)}")
if failure_cases:
print(f"Top 5 uncertain failures:")
for i, case in enumerate(failure_cases[:5], 1):
print(f" {i}. {Path(case['image']).name}")
print(f" True: {case['true_label']}, Pred: {case['predicted_label']}")
print(f" Prob: {case['probability']:.3f}, Uncertainty: {case['uncertainty']:.3f}")
evaluation_results['failure_cases'] = failure_cases
# Save results
with open(OUTPUTS_DIR / "evaluation_results.json", 'w') as f:
json.dump(evaluation_results, f, indent=2)
print(f"\nβœ… Evaluation complete!")
print(f"πŸ“ Results saved to: {OUTPUTS_DIR}")
print(f"πŸ“Š Plots: calibration.png, roc_curve.png, uncertainty_dist.png")
print(f"πŸ“„ Full results: evaluation_results.json")
if __name__ == "__main__":
main()