OncoVision-X / src /evaluation /evaluator.py
adityasync's picture
Clean OncoVision-X deployment with LFS
8960670
#!/usr/bin/env python3
"""
Evaluation module for DCA-Net.
Computes classification metrics, FROC, calibration, uncertainty, and subgroup analysis.
Generates all publication-quality plots.
"""
import numpy as np
import pandas as pd
from sklearn.metrics import (
roc_auc_score, average_precision_score, roc_curve,
precision_recall_curve, f1_score, accuracy_score,
confusion_matrix, classification_report
)
import torch
import torch.nn as nn
import logging
from tqdm import tqdm
from pathlib import Path
import json
from src.evaluation.visualizations import generate_all_plots
class Evaluator:
"""Comprehensive evaluation for lung nodule classification.
Computes:
- AUC-ROC, AUC-PR
- Sensitivity, Specificity, F1
- FROC (sensitivity at different FP rates)
- Expected Calibration Error (ECE)
- MC Dropout uncertainty metrics
- Subgroup analysis
Generates all plots via visualizations module.
Args:
model: DCANet model (or DataParallel wrapped)
test_loader: DataLoader for test set
device: torch device
logger: Logger instance
"""
def __init__(self, model, test_loader, device=None, logger=None):
self.model = model
self.test_loader = test_loader
self.device = device or torch.device(
"cuda" if torch.cuda.is_available() else "cpu"
)
self.logger = logger or logging.getLogger('dca-net')
@torch.no_grad()
def collect_predictions(self):
"""Run model on test set, collect predictions and labels."""
self.model.eval()
all_probs = []
all_labels = []
for nodule, context, labels in tqdm(self.test_loader, desc="Evaluating"):
nodule = nodule.to(self.device)
context = context.to(self.device)
logits = self.model(nodule, context)
probs = torch.sigmoid(logits.squeeze(-1))
all_probs.extend(probs.cpu().numpy())
all_labels.extend(labels.numpy())
return np.array(all_probs), np.array(all_labels)
def collect_uncertainty(self, mc_passes=5):
"""Run MC Dropout uncertainty estimation on test set.
Args:
mc_passes: Number of stochastic forward passes
Returns:
mean_probs, confidences, labels: numpy arrays
"""
# Get the raw model (unwrap DataParallel)
raw_model = self.model
if isinstance(raw_model, nn.DataParallel):
raw_model = raw_model.module
self.logger.info(f" Running MC Dropout ({mc_passes} passes)...")
all_mean_probs = []
all_confidences = []
all_labels = []
for nodule, context, labels in tqdm(self.test_loader, desc="MC Dropout"):
nodule = nodule.to(self.device)
context = context.to(self.device)
mean_prob, confidence = raw_model.predict_with_uncertainty(
nodule, context
)
all_mean_probs.extend(mean_prob.cpu().numpy())
all_confidences.extend(confidence.cpu().numpy())
all_labels.extend(labels.numpy())
return (np.array(all_mean_probs),
np.array(all_confidences),
np.array(all_labels))
def compute_metrics(self, probs, labels, threshold=0.5):
"""Compute all classification metrics."""
preds = (probs >= threshold).astype(int)
# Handle edge case: single class
has_both = len(np.unique(labels)) > 1
auc_roc = roc_auc_score(labels, probs) if has_both else 0.0
auc_pr = average_precision_score(labels, probs) if has_both else 0.0
f1 = f1_score(labels, preds, zero_division=0)
acc = accuracy_score(labels, preds)
tn, fp, fn, tp = confusion_matrix(labels, preds, labels=[0, 1]).ravel()
sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0.0
specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0
precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
# FP per scan (assume ~30 candidates/scan, per spec Section 2.3)
total_scans = max(len(labels) / 30.0, 1.0)
fp_per_scan = float(fp) / total_scans
metrics = {
'auc_roc': float(auc_roc),
'auc_pr': float(auc_pr),
'f1_score': float(f1),
'accuracy': float(acc),
'sensitivity': float(sensitivity),
'specificity': float(specificity),
'precision': float(precision),
'fp_per_scan': float(fp_per_scan),
'true_positives': int(tp),
'false_positives': int(fp),
'true_negatives': int(tn),
'false_negatives': int(fn),
'threshold': float(threshold),
}
return metrics
def compute_froc(self, probs, labels, fp_rates=[0.5, 1, 2, 4, 8]):
"""Compute FROC — sensitivity at specific false positive rates."""
if len(np.unique(labels)) < 2:
return {f'sensitivity_at_{fp}fp': 0.0 for fp in fp_rates}
fpr, tpr, _ = roc_curve(labels, probs)
n_neg = (labels == 0).sum()
froc = {}
for fp_rate in fp_rates:
target_fpr = min(fp_rate / max(n_neg, 1), 1.0)
idx = np.searchsorted(fpr, target_fpr)
idx = min(idx, len(tpr) - 1)
froc[f'sensitivity_at_{fp_rate}fp'] = float(tpr[idx])
return froc
def compute_ece(self, probs, labels, n_bins=10):
"""Compute Expected Calibration Error."""
bin_boundaries = np.linspace(0, 1, n_bins + 1)
ece = 0.0
for i in range(n_bins):
mask = (probs >= bin_boundaries[i]) & (probs < bin_boundaries[i + 1])
if mask.sum() == 0:
continue
bin_conf = probs[mask].mean()
bin_acc = labels[mask].mean()
ece += mask.sum() * abs(bin_conf - bin_acc)
ece /= len(probs)
return float(ece)
def compute_uncertainty_metrics(self, mean_probs, confidences, labels):
"""Compute uncertainty-specific metrics."""
preds = (mean_probs > 0.5).astype(int)
correct = (preds == labels)
metrics = {
'mean_confidence': float(confidences.mean()),
'mean_confidence_correct': float(
confidences[correct].mean() if correct.sum() > 0 else 0
),
'mean_confidence_incorrect': float(
confidences[~correct].mean() if (~correct).sum() > 0 else 0
),
'uncertain_cases_ratio': float(
(confidences < 0.7).sum() / len(confidences)
),
'uncertain_cases_count': int((confidences < 0.7).sum()),
'misclassified_flagged_by_uncertainty': 0.0,
}
# Key metric: what fraction of misclassified cases had low confidence?
if (~correct).sum() > 0:
flagged = ((~correct) & (confidences < 0.7)).sum()
metrics['misclassified_flagged_by_uncertainty'] = float(
flagged / (~correct).sum()
)
return metrics
def evaluate(self, output_dir=None, run_uncertainty=True,
metadata_csv=None, training_log=None):
"""Run full evaluation pipeline with all metrics and plots.
Args:
output_dir: Directory to save results and plots
run_uncertainty: Whether to run MC Dropout uncertainty estimation
metadata_csv: Path to metadata CSV with diameter info
training_log: Path to training log for training curves
Returns:
dict: All computed metrics
"""
self.logger.info("\n" + "=" * 60)
self.logger.info("COMPREHENSIVE EVALUATION")
self.logger.info("=" * 60)
# ── 1. Collect predictions ──
self.logger.info("\n1. Collecting predictions...")
probs, labels = self.collect_predictions()
self.logger.info(f" Samples: {len(probs)}")
self.logger.info(f" Positives: {(labels == 1).sum()}")
self.logger.info(f" Negatives: {(labels == 0).sum()}")
# ── 2. Classification metrics ──
self.logger.info("\n2. Computing classification metrics...")
metrics = self.compute_metrics(probs, labels)
# ── 3. FROC ──
self.logger.info("3. Computing FROC...")
froc = self.compute_froc(probs, labels)
metrics.update(froc)
# ── 4. Calibration ──
self.logger.info("4. Computing calibration (ECE)...")
metrics['ece'] = self.compute_ece(probs, labels)
# ── 5. Uncertainty ──
mean_probs = None
confidences = None
raw_model = self.model
if isinstance(raw_model, nn.DataParallel):
raw_model = raw_model.module
if run_uncertainty and not hasattr(raw_model, 'predict_with_uncertainty'):
self.logger.warning("Model lacks 'predict_with_uncertainty' method. Skipping uncertainty estimation.")
run_uncertainty = False
if run_uncertainty:
self.logger.info("5. Running uncertainty estimation...")
mean_probs, confidences, unc_labels = self.collect_uncertainty()
uncertainty_metrics = self.compute_uncertainty_metrics(
mean_probs, confidences, unc_labels
)
metrics['uncertainty'] = uncertainty_metrics
else:
self.logger.info("5. Skipping uncertainty estimation")
# ── 6. Log all results ──
self.logger.info("\n" + "=" * 60)
self.logger.info("RESULTS SUMMARY")
self.logger.info("=" * 60)
result_lines = [
f" AUC-ROC: {metrics['auc_roc']:.4f}",
f" AUC-PR: {metrics['auc_pr']:.4f}",
f" Sensitivity: {metrics['sensitivity']:.4f}",
f" Specificity: {metrics['specificity']:.4f}",
f" Precision: {metrics['precision']:.4f}",
f" F1-Score: {metrics['f1_score']:.4f}",
f" Accuracy: {metrics['accuracy']:.4f}",
f" ECE: {metrics['ece']:.4f}",
]
if 'sensitivity_at_1fp' in metrics:
result_lines.append(f" Sens@1FP: {metrics['sensitivity_at_1fp']:.4f}")
result_lines.append(f" Sens@4FP: {metrics['sensitivity_at_4fp']:.4f}")
if 'uncertainty' in metrics:
um = metrics['uncertainty']
result_lines.extend([
f" Mean Conf (correct): {um['mean_confidence_correct']:.4f}",
f" Mean Conf (incorrect): {um['mean_confidence_incorrect']:.4f}",
f" Misclassified flagged: {um['misclassified_flagged_by_uncertainty']:.2%}",
])
for line in result_lines:
self.logger.info(line)
# ── 7. Save results and generate plots ──
if output_dir:
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
figures_dir = output_dir / 'figures'
figures_dir.mkdir(exist_ok=True)
# Save metrics JSON
with open(output_dir / 'evaluation_results.json', 'w') as f:
json.dump(metrics, f, indent=2)
# Save raw predictions
np.savez(
output_dir / 'predictions.npz',
probs=probs, labels=labels,
mean_probs=mean_probs if mean_probs is not None else [],
confidences=confidences if confidences is not None else []
)
# Load metadata for subgroup analysis
metadata_df = None
if metadata_csv and Path(metadata_csv).exists():
metadata_df = pd.read_csv(metadata_csv)
# Generate all plots
self.logger.info("\n6. Generating plots...")
plot_paths = generate_all_plots(
labels=labels,
probs=probs,
output_dir=figures_dir,
mean_probs=mean_probs,
confidences=confidences,
metadata_df=metadata_df,
log_path=training_log,
)
# Save plot index
with open(output_dir / 'plot_index.json', 'w') as f:
json.dump(plot_paths, f, indent=2)
# Generate text report
self._generate_text_report(metrics, output_dir / 'evaluation_report.txt')
self.logger.info(f"\n✅ All results saved to {output_dir}/")
self.logger.info(f" - evaluation_results.json")
self.logger.info(f" - predictions.npz")
self.logger.info(f" - evaluation_report.txt")
self.logger.info(f" - figures/ ({len(plot_paths)} plots)")
return metrics
def _generate_text_report(self, metrics, output_path):
"""Generate a human-readable text report."""
lines = [
"=" * 60,
"DCA-NET EVALUATION REPORT",
"Lung Nodule Classification — LUNA16 Dataset",
"=" * 60,
"",
"CLASSIFICATION METRICS",
"-" * 40,
f" AUC-ROC: {metrics['auc_roc']:.4f}",
f" AUC-PR: {metrics['auc_pr']:.4f}",
f" Sensitivity: {metrics['sensitivity']:.4f}",
f" Specificity: {metrics['specificity']:.4f}",
f" Precision: {metrics['precision']:.4f}",
f" F1-Score: {metrics['f1_score']:.4f}",
f" Accuracy: {metrics['accuracy']:.4f}",
"",
"CONFUSION MATRIX",
"-" * 40,
f" True Positives: {metrics['true_positives']}",
f" False Positives: {metrics['false_positives']}",
f" True Negatives: {metrics['true_negatives']}",
f" False Negatives: {metrics['false_negatives']}",
"",
"FROC ANALYSIS",
"-" * 40,
]
for fp in [0.5, 1, 2, 4, 8]:
key = f'sensitivity_at_{fp}fp'
if key in metrics:
lines.append(f" Sensitivity @ {fp} FP/scan: {metrics[key]:.4f}")
lines.extend([
"",
"CALIBRATION",
"-" * 40,
f" ECE: {metrics['ece']:.4f}",
])
if 'uncertainty' in metrics:
um = metrics['uncertainty']
lines.extend([
"",
"UNCERTAINTY ANALYSIS",
"-" * 40,
f" Mean Confidence (all): {um['mean_confidence']:.4f}",
f" Mean Confidence (correct): {um['mean_confidence_correct']:.4f}",
f" Mean Confidence (incorrect): {um['mean_confidence_incorrect']:.4f}",
f" Uncertain cases (<0.7): {um['uncertain_cases_count']} "
f"({um['uncertain_cases_ratio']:.1%})",
f" Misclassified flagged: "
f"{um['misclassified_flagged_by_uncertainty']:.1%}",
])
lines.extend(["", "=" * 60])
with open(output_path, 'w') as f:
f.write('\n'.join(lines))