#!/usr/bin/env python3
"""
Evaluation module for DCA-Net.

Computes classification metrics, FROC, calibration, uncertainty, and
subgroup analysis. Generates all publication-quality plots.
"""

import numpy as np
import pandas as pd
from sklearn.metrics import (
    roc_auc_score, average_precision_score, roc_curve,
    precision_recall_curve, f1_score, accuracy_score,
    confusion_matrix, classification_report
)
import torch
import torch.nn as nn
import logging
from tqdm import tqdm
from pathlib import Path
import json

from src.evaluation.visualizations import generate_all_plots


class Evaluator:
    """Comprehensive evaluation for lung nodule classification.

    Computes:
        - AUC-ROC, AUC-PR
        - Sensitivity, Specificity, F1
        - FROC (sensitivity at different FP rates)
        - Expected Calibration Error (ECE)
        - MC Dropout uncertainty metrics

    Plots (including any subgroup analysis driven by the metadata CSV) are
    produced by ``generate_all_plots`` from the visualizations module.

    Args:
        model: DCANet model (or DataParallel wrapped)
        test_loader: DataLoader for test set; each batch is unpacked as
            ``(nodule, context, labels)``
        device: torch device (defaults to CUDA when available, else CPU)
        logger: Logger instance (defaults to the ``'dca-net'`` logger)
    """

    # Assumed number of candidates per scan (spec Section 2.3). Used to turn
    # a candidate count into an approximate scan count for FP-per-scan figures.
    CANDIDATES_PER_SCAN = 30.0

    # Predictions whose confidence falls below this value count as "uncertain".
    LOW_CONFIDENCE_THRESHOLD = 0.7

    def __init__(self, model, test_loader, device=None, logger=None):
        self.model = model
        self.test_loader = test_loader
        self.device = device or torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.logger = logger or logging.getLogger('dca-net')

    # ------------------------------------------------------------------
    # Small shared helpers
    # ------------------------------------------------------------------
    def _unwrap_model(self):
        """Return the underlying model, stripping an ``nn.DataParallel`` wrapper."""
        if isinstance(self.model, nn.DataParallel):
            return self.model.module
        return self.model

    @classmethod
    def _estimate_scan_count(cls, n_candidates):
        """Approximate the number of scans from a candidate count (never < 1)."""
        return max(n_candidates / cls.CANDIDATES_PER_SCAN, 1.0)

    # ------------------------------------------------------------------
    # Data collection
    # ------------------------------------------------------------------
    @torch.no_grad()
    def collect_predictions(self):
        """Run model on test set, collect predictions and labels.

        Returns:
            (probs, labels): numpy arrays with one entry per test sample;
            ``probs`` are sigmoid outputs of the model logits.
        """
        self.model.eval()
        all_probs = []
        all_labels = []

        for nodule, context, labels in tqdm(self.test_loader, desc="Evaluating"):
            nodule = nodule.to(self.device)
            context = context.to(self.device)

            logits = self.model(nodule, context)
            probs = torch.sigmoid(logits.squeeze(-1))

            # atleast_1d: squeeze(-1) collapses a lone sample to a 0-d tensor
            # when the model already returns shape (B,); a 0-d array cannot be
            # iterated by list.extend.
            all_probs.extend(np.atleast_1d(probs.cpu().numpy()))
            all_labels.extend(np.atleast_1d(labels.numpy()))

        return np.array(all_probs), np.array(all_labels)

    @torch.no_grad()
    def collect_uncertainty(self, mc_passes=5):
        """Run MC Dropout uncertainty estimation on test set.

        Runs under ``torch.no_grad()`` like ``collect_predictions``; only the
        forward outputs are needed here.

        Args:
            mc_passes: Number of stochastic forward passes

        Returns:
            mean_probs, confidences, labels: numpy arrays
        """
        # MC Dropout lives on the raw model, not on the DataParallel wrapper.
        raw_model = self._unwrap_model()

        self.logger.info(f" Running MC Dropout ({mc_passes} passes)...")

        all_mean_probs = []
        all_confidences = []
        all_labels = []

        for nodule, context, labels in tqdm(self.test_loader, desc="MC Dropout"):
            nodule = nodule.to(self.device)
            context = context.to(self.device)

            # NOTE(review): mc_passes is only logged above; it is not forwarded
            # because the signature of predict_with_uncertainty is not visible
            # from this module. Confirm the parameter name and wire it through,
            # otherwise the logged pass count may not match what the model runs.
            mean_prob, confidence = raw_model.predict_with_uncertainty(
                nodule, context
            )

            all_mean_probs.extend(np.atleast_1d(mean_prob.cpu().numpy()))
            all_confidences.extend(np.atleast_1d(confidence.cpu().numpy()))
            all_labels.extend(np.atleast_1d(labels.numpy()))

        return (np.array(all_mean_probs),
                np.array(all_confidences),
                np.array(all_labels))

    # ------------------------------------------------------------------
    # Metrics
    # ------------------------------------------------------------------
    def compute_metrics(self, probs, labels, threshold=0.5):
        """Compute all classification metrics.

        Args:
            probs: Predicted positive-class probabilities
            labels: Ground-truth binary labels (0/1)
            threshold: A sample is predicted positive when ``prob >= threshold``

        Returns:
            dict of plain Python floats/ints (JSON serialisable). AUC-ROC and
            AUC-PR are reported as 0.0 when only one class is present.
        """
        probs = np.asarray(probs)
        labels = np.asarray(labels)
        preds = (probs >= threshold).astype(int)

        # Ranking metrics are undefined with a single class; report 0.0.
        has_both = len(np.unique(labels)) > 1
        auc_roc = roc_auc_score(labels, probs) if has_both else 0.0
        auc_pr = average_precision_score(labels, probs) if has_both else 0.0

        f1 = f1_score(labels, preds, zero_division=0)
        acc = accuracy_score(labels, preds)

        # labels=[0, 1] forces a 2x2 matrix even if a class is missing.
        tn, fp, fn, tp = confusion_matrix(labels, preds, labels=[0, 1]).ravel()
        sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0

        # FP per scan (assume ~30 candidates/scan, per spec Section 2.3)
        total_scans = self._estimate_scan_count(len(labels))
        fp_per_scan = float(fp) / total_scans

        return {
            'auc_roc': float(auc_roc),
            'auc_pr': float(auc_pr),
            'f1_score': float(f1),
            'accuracy': float(acc),
            'sensitivity': float(sensitivity),
            'specificity': float(specificity),
            'precision': float(precision),
            'fp_per_scan': float(fp_per_scan),
            'true_positives': int(tp),
            'false_positives': int(fp),
            'true_negatives': int(tn),
            'false_negatives': int(fn),
            'threshold': float(threshold),
        }

    def compute_froc(self, probs, labels, fp_rates=(0.5, 1, 2, 4, 8),
                     n_scans=None):
        """Compute FROC — sensitivity at specific false positives per scan.

        Each requested rate is converted to a total false-positive budget
        (``fp_rate * n_scans``) and then to a false-positive rate over the
        negatives. The reported sensitivity is the best one reachable without
        exceeding that budget.

        Args:
            probs: Predicted positive-class probabilities
            labels: Ground-truth binary labels (0/1)
            fp_rates: False positives per scan at which to report sensitivity
            n_scans: Number of scans in the evaluation set. When omitted it is
                estimated from the candidate count with the same
                candidates-per-scan assumption ``compute_metrics`` uses.

        Returns:
            dict mapping ``'sensitivity_at_{rate}fp'`` to a float; all 0.0 when
            only one class is present.
        """
        probs = np.asarray(probs)
        labels = np.asarray(labels)

        if len(np.unique(labels)) < 2:
            return {f'sensitivity_at_{fp}fp': 0.0 for fp in fp_rates}

        if n_scans is None:
            n_scans = self._estimate_scan_count(len(labels))

        fpr, tpr, _ = roc_curve(labels, probs)
        n_neg = max(int((labels == 0).sum()), 1)

        froc = {}
        for fp_rate in fp_rates:
            target_fpr = min(fp_rate * n_scans / n_neg, 1.0)
            # Last ROC point whose FPR does not exceed the target. fpr starts
            # at 0, so the index is >= 0 for any non-negative rate; the clamp
            # only protects against a negative rate being passed in.
            idx = int(np.searchsorted(fpr, target_fpr, side='right')) - 1
            idx = min(max(idx, 0), len(tpr) - 1)
            froc[f'sensitivity_at_{fp_rate}fp'] = float(tpr[idx])

        return froc

    def compute_ece(self, probs, labels, n_bins=10):
        """Compute Expected Calibration Error.

        Samples are grouped into ``n_bins`` equal-width probability bins; ECE
        is the sample-weighted mean of |mean predicted probability - observed
        positive fraction| over the non-empty bins.

        Args:
            probs: Predicted positive-class probabilities in [0, 1]
            labels: Ground-truth binary labels (0/1)
            n_bins: Number of equal-width bins

        Returns:
            float ECE; 0.0 for empty input.
        """
        probs = np.asarray(probs, dtype=float)
        labels = np.asarray(labels, dtype=float)
        if probs.size == 0:
            return 0.0

        edges = np.linspace(0, 1, n_bins + 1)
        weighted_gap = 0.0
        for i in range(n_bins):
            lower, upper = edges[i], edges[i + 1]
            if i == n_bins - 1:
                # Close the last bin on the right so a saturated prediction of
                # exactly 1.0 is counted instead of silently dropped.
                in_bin = (probs >= lower) & (probs <= upper)
            else:
                in_bin = (probs >= lower) & (probs < upper)

            count = int(in_bin.sum())
            if count == 0:
                continue
            bin_conf = probs[in_bin].mean()
            bin_acc = labels[in_bin].mean()
            weighted_gap += count * abs(bin_conf - bin_acc)

        return float(weighted_gap / probs.size)

    def compute_uncertainty_metrics(self, mean_probs, confidences, labels):
        """Compute uncertainty-specific metrics.

        Args:
            mean_probs: Mean predicted probability per sample (MC Dropout)
            confidences: Confidence score per sample
            labels: Ground-truth binary labels (0/1)

        Returns:
            dict with mean confidence overall / for correct / for incorrect
            predictions, the count and ratio of low-confidence cases, and the
            fraction of misclassified cases that were low-confidence. All
            values are 0 for empty input.
        """
        mean_probs = np.asarray(mean_probs)
        confidences = np.asarray(confidences)
        labels = np.asarray(labels)

        n_total = len(confidences)
        if n_total == 0:
            return {
                'mean_confidence': 0.0,
                'mean_confidence_correct': 0.0,
                'mean_confidence_incorrect': 0.0,
                'uncertain_cases_ratio': 0.0,
                'uncertain_cases_count': 0,
                'misclassified_flagged_by_uncertainty': 0.0,
            }

        # Same decision rule as compute_metrics at its default threshold.
        preds = (mean_probs >= 0.5).astype(int)
        correct = (preds == labels)
        wrong = ~correct
        low_conf = confidences < self.LOW_CONFIDENCE_THRESHOLD

        n_correct = int(correct.sum())
        n_wrong = int(wrong.sum())
        n_low_conf = int(low_conf.sum())

        metrics = {
            'mean_confidence': float(confidences.mean()),
            'mean_confidence_correct': float(
                confidences[correct].mean() if n_correct > 0 else 0
            ),
            'mean_confidence_incorrect': float(
                confidences[wrong].mean() if n_wrong > 0 else 0
            ),
            'uncertain_cases_ratio': float(n_low_conf / n_total),
            'uncertain_cases_count': n_low_conf,
            'misclassified_flagged_by_uncertainty': 0.0,
        }

        # Key metric: what fraction of misclassified cases had low confidence?
        if n_wrong > 0:
            flagged = int((wrong & low_conf).sum())
            metrics['misclassified_flagged_by_uncertainty'] = float(
                flagged / n_wrong
            )

        return metrics

    # ------------------------------------------------------------------
    # Pipeline
    # ------------------------------------------------------------------
    def evaluate(self, output_dir=None, run_uncertainty=True,
                 metadata_csv=None, training_log=None):
        """Run full evaluation pipeline with all metrics and plots.

        Args:
            output_dir: Directory to save results and plots (nothing is written
                when omitted)
            run_uncertainty: Whether to run MC Dropout uncertainty estimation
            metadata_csv: Path to metadata CSV with diameter info
            training_log: Path to training log for training curves

        Returns:
            dict: All computed metrics
        """
        self.logger.info("\n" + "=" * 60)
        self.logger.info("COMPREHENSIVE EVALUATION")
        self.logger.info("=" * 60)

        # ── 1. Collect predictions ──
        self.logger.info("\n1. Collecting predictions...")
        probs, labels = self.collect_predictions()
        self.logger.info(f" Samples: {len(probs)}")
        self.logger.info(f" Positives: {(labels == 1).sum()}")
        self.logger.info(f" Negatives: {(labels == 0).sum()}")

        # ── 2. Classification metrics ──
        self.logger.info("\n2. Computing classification metrics...")
        metrics = self.compute_metrics(probs, labels)

        # ── 3. FROC ──
        self.logger.info("3. Computing FROC...")
        metrics.update(self.compute_froc(probs, labels))

        # ── 4. Calibration ──
        self.logger.info("4. Computing calibration (ECE)...")
        metrics['ece'] = self.compute_ece(probs, labels)

        # ── 5. Uncertainty ──
        mean_probs = None
        confidences = None

        if run_uncertainty and not hasattr(self._unwrap_model(),
                                           'predict_with_uncertainty'):
            self.logger.warning("Model lacks 'predict_with_uncertainty' method. Skipping uncertainty estimation.")
            run_uncertainty = False

        if run_uncertainty:
            self.logger.info("5. Running uncertainty estimation...")
            mean_probs, confidences, unc_labels = self.collect_uncertainty()
            metrics['uncertainty'] = self.compute_uncertainty_metrics(
                mean_probs, confidences, unc_labels
            )
        else:
            self.logger.info("5. Skipping uncertainty estimation")

        # ── 6. Log all results ──
        self._log_summary(metrics)

        # ── 7. Save results and generate plots ──
        if output_dir:
            self._save_outputs(
                Path(output_dir), metrics, probs, labels,
                mean_probs, confidences, metadata_csv, training_log,
            )

        return metrics

    def _log_summary(self, metrics):
        """Log the headline metrics as a short results summary."""
        self.logger.info("\n" + "=" * 60)
        self.logger.info("RESULTS SUMMARY")
        self.logger.info("=" * 60)

        result_lines = [
            f" AUC-ROC: {metrics['auc_roc']:.4f}",
            f" AUC-PR: {metrics['auc_pr']:.4f}",
            f" Sensitivity: {metrics['sensitivity']:.4f}",
            f" Specificity: {metrics['specificity']:.4f}",
            f" Precision: {metrics['precision']:.4f}",
            f" F1-Score: {metrics['f1_score']:.4f}",
            f" Accuracy: {metrics['accuracy']:.4f}",
            f" ECE: {metrics['ece']:.4f}",
        ]

        # Each FROC point is looked up on its own so a missing key cannot
        # raise while logging.
        for fp in (1, 4):
            key = f'sensitivity_at_{fp}fp'
            if key in metrics:
                result_lines.append(f" Sens@{fp}FP: {metrics[key]:.4f}")

        if 'uncertainty' in metrics:
            um = metrics['uncertainty']
            result_lines.extend([
                f" Mean Conf (correct): {um['mean_confidence_correct']:.4f}",
                f" Mean Conf (incorrect): {um['mean_confidence_incorrect']:.4f}",
                f" Misclassified flagged: {um['misclassified_flagged_by_uncertainty']:.2%}",
            ])

        for line in result_lines:
            self.logger.info(line)

    def _save_outputs(self, output_dir, metrics, probs, labels,
                      mean_probs, confidences, metadata_csv, training_log):
        """Write metrics JSON, raw predictions, plots, plot index and text report."""
        output_dir.mkdir(parents=True, exist_ok=True)
        figures_dir = output_dir / 'figures'
        figures_dir.mkdir(exist_ok=True)

        # Save metrics JSON
        with open(output_dir / 'evaluation_results.json', 'w',
                  encoding='utf-8') as f:
            json.dump(metrics, f, indent=2)

        # Save raw predictions (empty arrays stand in for skipped uncertainty)
        np.savez(
            output_dir / 'predictions.npz',
            probs=probs,
            labels=labels,
            mean_probs=mean_probs if mean_probs is not None else [],
            confidences=confidences if confidences is not None else []
        )

        # Load metadata for subgroup analysis
        metadata_df = None
        if metadata_csv:
            if Path(metadata_csv).exists():
                metadata_df = pd.read_csv(metadata_csv)
            else:
                # Say so instead of silently producing plots without metadata.
                self.logger.warning(
                    f"Metadata CSV not found: {metadata_csv}. "
                    "Continuing without metadata."
                )

        # Generate all plots
        self.logger.info("\n6. Generating plots...")
        plot_paths = generate_all_plots(
            labels=labels,
            probs=probs,
            output_dir=figures_dir,
            mean_probs=mean_probs,
            confidences=confidences,
            metadata_df=metadata_df,
            log_path=training_log,
        )

        # Save plot index
        # NOTE(review): assumes generate_all_plots returns a JSON-serialisable
        # value (e.g. str paths) — confirm against the visualizations module.
        with open(output_dir / 'plot_index.json', 'w', encoding='utf-8') as f:
            json.dump(plot_paths, f, indent=2)

        # Generate text report
        self._generate_text_report(metrics, output_dir / 'evaluation_report.txt')

        self.logger.info(f"\n✅ All results saved to {output_dir}/")
        self.logger.info(" - evaluation_results.json")
        self.logger.info(" - predictions.npz")
        self.logger.info(" - evaluation_report.txt")
        self.logger.info(f" - figures/ ({len(plot_paths)} plots)")

    def _generate_text_report(self, metrics, output_path):
        """Generate a human-readable text report.

        Args:
            metrics: Metrics dict as returned by ``evaluate``
            output_path: Destination file; written as UTF-8 because the report
                contains non-ASCII characters
        """
        lines = [
            "=" * 60,
            "DCA-NET EVALUATION REPORT",
            "Lung Nodule Classification — LUNA16 Dataset",
            "=" * 60,
            "",
            "CLASSIFICATION METRICS",
            "-" * 40,
            f" AUC-ROC: {metrics['auc_roc']:.4f}",
            f" AUC-PR: {metrics['auc_pr']:.4f}",
            f" Sensitivity: {metrics['sensitivity']:.4f}",
            f" Specificity: {metrics['specificity']:.4f}",
            f" Precision: {metrics['precision']:.4f}",
            f" F1-Score: {metrics['f1_score']:.4f}",
            f" Accuracy: {metrics['accuracy']:.4f}",
            "",
            "CONFUSION MATRIX",
            "-" * 40,
            f" True Positives: {metrics['true_positives']}",
            f" False Positives: {metrics['false_positives']}",
            f" True Negatives: {metrics['true_negatives']}",
            f" False Negatives: {metrics['false_negatives']}",
            "",
            "FROC ANALYSIS",
            "-" * 40,
        ]

        # Report every FROC point present in the metrics (in insertion order)
        # instead of a fixed list, so custom fp_rates show up as well.
        prefix, suffix = 'sensitivity_at_', 'fp'
        for key, value in metrics.items():
            if key.startswith(prefix) and key.endswith(suffix):
                fp = key[len(prefix):-len(suffix)]
                lines.append(f" Sensitivity @ {fp} FP/scan: {value:.4f}")

        lines.extend([
            "",
            "CALIBRATION",
            "-" * 40,
            f" ECE: {metrics['ece']:.4f}",
        ])

        if 'uncertainty' in metrics:
            um = metrics['uncertainty']
            lines.extend([
                "",
                "UNCERTAINTY ANALYSIS",
                "-" * 40,
                f" Mean Confidence (all): {um['mean_confidence']:.4f}",
                f" Mean Confidence (correct): {um['mean_confidence_correct']:.4f}",
                f" Mean Confidence (incorrect): {um['mean_confidence_incorrect']:.4f}",
                f" Uncertain cases (<0.7): {um['uncertain_cases_count']} "
                f"({um['uncertain_cases_ratio']:.1%})",
                f" Misclassified flagged: "
                f"{um['misclassified_flagged_by_uncertainty']:.1%}",
            ])

        lines.extend(["", "=" * 60])

        with open(output_path, 'w', encoding='utf-8') as f:
            f.write('\n'.join(lines))