Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Visualization module for DCA-Net evaluation. | |
| Generates all plots needed for the research paper. | |
| """ | |
| import numpy as np | |
| import matplotlib | |
| matplotlib.use('Agg') # Non-interactive backend for server | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| from sklearn.metrics import ( | |
| roc_curve, auc, precision_recall_curve, average_precision_score, | |
| confusion_matrix | |
| ) | |
| from pathlib import Path | |
| import json | |
| import logging | |
# Publication-quality Matplotlib defaults, applied once when the module loads.
_PUBLICATION_RC = {
    # Typography
    'font.size': 12,
    'axes.labelsize': 14,
    'axes.titlesize': 15,
    'xtick.labelsize': 11,
    'ytick.labelsize': 11,
    'legend.fontsize': 11,
    # On-screen figure geometry
    'figure.figsize': (8, 6),
    'figure.dpi': 150,
    # Saved-file settings
    'savefig.bbox': 'tight',
    'savefig.dpi': 300,
}
plt.rcParams.update(_PUBLICATION_RC)
def plot_roc_curve(labels, probs, output_path):
    """Draw the ROC curve, annotate it with its AUC, and save the figure.

    Args:
        labels: ground-truth labels, passed straight to ``roc_curve``.
        probs: predicted scores, passed straight to ``roc_curve``.
        output_path: file the figure is written to.

    Returns:
        The area under the ROC curve.
    """
    false_pos_rate, true_pos_rate, _thresholds = roc_curve(labels, probs)
    area = auc(false_pos_rate, true_pos_rate)

    figure, axis = plt.subplots()
    axis.plot(
        false_pos_rate,
        true_pos_rate,
        color='#2563EB',
        lw=2.5,
        label=f'DCA-Net (AUC = {area:.4f})',
    )
    # The diagonal is the chance-level reference.
    axis.plot([0, 1], [0, 1], 'k--', lw=1, alpha=0.5, label='Random')
    axis.set(
        xlim=[0.0, 1.0],
        ylim=[0.0, 1.05],
        xlabel='False Positive Rate',
        ylabel='True Positive Rate',
        title='ROC Curve — Lung Nodule Classification',
    )
    axis.legend(loc='lower right')
    axis.grid(True, alpha=0.3)

    figure.savefig(output_path)
    plt.close(figure)
    return area
def plot_precision_recall_curve(labels, probs, output_path):
    """Draw the precision-recall curve, annotate it with AP, and save it.

    Args:
        labels: ground-truth labels, passed to the scikit-learn metrics.
        probs: predicted scores, passed to the scikit-learn metrics.
        output_path: file the figure is written to.

    Returns:
        The average precision score.
    """
    precision_vals, recall_vals, _thresholds = precision_recall_curve(labels, probs)
    avg_precision = average_precision_score(labels, probs)

    figure, axis = plt.subplots()
    axis.plot(
        recall_vals,
        precision_vals,
        color='#16A34A',
        lw=2.5,
        label=f'DCA-Net (AP = {avg_precision:.4f})',
    )
    axis.set(
        xlabel='Recall (Sensitivity)',
        ylabel='Precision',
        title='Precision-Recall Curve',
    )
    axis.legend(loc='upper right')
    axis.grid(True, alpha=0.3)

    figure.savefig(output_path)
    plt.close(figure)
    return avg_precision
def plot_confusion_matrix(labels, probs, output_path, threshold=0.5):
    """Render the two-class confusion matrix as an annotated heatmap.

    Args:
        labels: ground-truth labels (classes 0 and 1 are tabulated).
        probs: array of predicted scores; a score ``>= threshold`` counts
            as a positive prediction.
        output_path: file the figure is written to.
        threshold: decision threshold applied to ``probs``.
    """
    class_names = ['Negative', 'Positive']
    hard_preds = (probs >= threshold).astype(int)
    # Fix the label order so the matrix is always 2x2, even if a class is absent.
    matrix = confusion_matrix(labels, hard_preds, labels=[0, 1])

    figure, axis = plt.subplots(figsize=(7, 6))
    sns.heatmap(
        matrix,
        annot=True,
        fmt='d',
        cmap='Blues',
        ax=axis,
        xticklabels=class_names,
        yticklabels=class_names,
        annot_kws={'size': 16},
    )
    axis.set(
        xlabel='Predicted Label',
        ylabel='True Label',
        title=f'Confusion Matrix (threshold = {threshold})',
    )

    figure.savefig(output_path)
    plt.close(figure)
def plot_froc_curve(labels, probs, output_path, n_scans=None):
    """Plot an FROC-style curve: sensitivity versus false positives per scan.

    The curve is derived from the ROC operating points: each false positive
    rate is turned into a false positive count (``fpr * n_negatives``) and
    divided by the number of scans.

    Args:
        labels: ground-truth labels (0 = negative, 1 = positive).
        probs: predicted scores for the positive class.
        output_path: file the figure is written to.
        n_scans: number of scans the samples come from. When omitted, the
            historical denominator (the number of distinct label values) is
            used so existing callers keep producing the same curve, and a
            warning is logged because that x-axis is NOT a true per-scan rate.

    Raises:
        ValueError: if ``n_scans`` is given but is not positive.
    """
    labels = np.asarray(labels)
    probs = np.asarray(probs)

    fpr, tpr, _ = roc_curve(labels, probs)
    n_neg = int((labels == 0).sum())

    if n_scans is None:
        logging.getLogger('dca-net').warning(
            "plot_froc_curve: n_scans not provided; the x-axis uses a legacy "
            "denominator and is not a true false-positives-per-scan rate."
        )
        denominator = max(len(np.unique(labels)), 1)
    else:
        if n_scans <= 0:
            raise ValueError(f"n_scans must be positive, got {n_scans!r}")
        denominator = n_scans

    # Convert FPR to an average false positive count per scan.
    fp_per_scan = fpr * n_neg / denominator

    # Standard FROC reference points
    ref_fps = [0.125, 0.25, 0.5, 1, 2, 4, 8]

    fig, ax = plt.subplots()
    ax.plot(fp_per_scan, tpr, color='#DC2626', lw=2.5, label='DCA-Net')

    # Mark reference points. For each budget take the LAST operating point
    # whose false positive count does not exceed it; picking the first point
    # at or above the budget would report sensitivity bought with more false
    # positives than allowed.
    last_idx = len(tpr) - 1
    for fp_ref in ref_fps:
        idx = int(np.searchsorted(fp_per_scan, fp_ref, side='right')) - 1
        idx = min(max(idx, 0), last_idx)
        ax.plot(fp_ref, tpr[idx], 'ko', markersize=5)
        ax.annotate(f'{tpr[idx]:.2f}', (fp_ref, tpr[idx]),
                    textcoords="offset points", xytext=(5, 5), fontsize=9)

    ax.set_xscale('log')
    ax.set_xlim([0.1, 100])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('Average False Positives per Scan')
    ax.set_ylabel('Sensitivity (True Positive Rate)')
    ax.set_title('FROC Curve')
    ax.legend()
    ax.grid(True, alpha=0.3, which='both')
    fig.savefig(output_path)
    plt.close(fig)
def plot_calibration_diagram(labels, probs, output_path, n_bins=10):
    """Plot a reliability (calibration) diagram with a score histogram below.

    Predictions are grouped into ``n_bins`` equal-width bins over [0, 1]. For
    each non-empty bin a bar is drawn at the bin's mean predicted probability
    with height equal to the fraction of positives in that bin.

    Args:
        labels: ground-truth labels (0/1).
        probs: predicted probabilities.
        output_path: file the figure is written to.
        n_bins: number of equal-width probability bins (must be >= 1).

    Raises:
        ValueError: if ``n_bins`` is smaller than 1.
    """
    if n_bins < 1:
        raise ValueError(f"n_bins must be at least 1, got {n_bins!r}")

    labels = np.asarray(labels)
    probs = np.asarray(probs)

    bin_boundaries = np.linspace(0, 1, n_bins + 1)
    bin_centers = []
    bin_accuracies = []
    for i in range(n_bins):
        lower, upper = bin_boundaries[i], bin_boundaries[i + 1]
        if i == n_bins - 1:
            # Close the last bin on the right so probs == 1.0 are counted,
            # matching the histogram drawn in the lower panel.
            mask = (probs >= lower) & (probs <= upper)
        else:
            mask = (probs >= lower) & (probs < upper)
        if mask.sum() == 0:
            continue
        bin_centers.append(probs[mask].mean())
        bin_accuracies.append(labels[mask].mean())

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 9),
                                   gridspec_kw={'height_ratios': [3, 1]})

    # Reliability diagram
    ax1.plot([0, 1], [0, 1], 'k--', lw=1, label='Perfect calibration')
    ax1.bar(bin_centers, bin_accuracies, width=1/n_bins * 0.8,
            color='#7C3AED', alpha=0.7, edgecolor='black', label='DCA-Net')
    ax1.set_xlabel('Mean Predicted Probability')
    ax1.set_ylabel('Fraction of Positives')
    ax1.set_title('Calibration / Reliability Diagram')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    ax1.set_xlim([0, 1])
    ax1.set_ylim([0, 1])

    # Histogram of predictions
    ax2.hist(probs, bins=n_bins, range=(0, 1), color='#7C3AED',
             alpha=0.6, edgecolor='black')
    ax2.set_xlabel('Predicted Probability')
    ax2.set_ylabel('Count')
    ax2.set_title('Prediction Distribution')

    fig.tight_layout()
    fig.savefig(output_path)
    plt.close(fig)
def plot_uncertainty_distribution(mean_probs, confidences, labels, output_path):
    """Plot confidence histograms and a confidence-vs-accuracy panel.

    Left panel: histograms of the confidence scores, split by whether the
    thresholded prediction (``mean_probs > 0.5``) matches the label.
    Right panel: for ten equal-width confidence bins over [0, 1], the mean
    confidence against the accuracy of the samples in that bin.

    Args:
        mean_probs: mean predicted probabilities.
        confidences: per-sample confidence scores, expected in [0, 1].
        labels: ground-truth labels (0/1).
        output_path: file the figure is written to.
    """
    mean_probs = np.asarray(mean_probs)
    confidences = np.asarray(confidences)
    labels = np.asarray(labels)

    preds = (mean_probs > 0.5).astype(int)
    correct = (preds == labels)

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Confidence distribution
    axes[0].hist(confidences[correct], bins=20, alpha=0.7, color='#16A34A',
                 label='Correct', edgecolor='black')
    axes[0].hist(confidences[~correct], bins=20, alpha=0.7, color='#DC2626',
                 label='Incorrect', edgecolor='black')
    axes[0].set_xlabel('Confidence Score')
    axes[0].set_ylabel('Count')
    axes[0].set_title('Confidence Distribution')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)

    # Confidence vs accuracy scatter
    conf_bins = np.linspace(0, 1, 11)
    last_bin = len(conf_bins) - 2
    bin_accs = []
    bin_confs = []
    for i, (lower, upper) in enumerate(zip(conf_bins[:-1], conf_bins[1:])):
        if i == last_bin:
            # Close the last bin on the right so confidence == 1.0 is kept.
            mask = (confidences >= lower) & (confidences <= upper)
        else:
            mask = (confidences >= lower) & (confidences < upper)
        if mask.sum() > 0:
            bin_confs.append(confidences[mask].mean())
            bin_accs.append(correct[mask].mean())

    axes[1].plot([0, 1], [0, 1], 'k--', alpha=0.5, label='Ideal')
    axes[1].scatter(bin_confs, bin_accs, s=80, color='#2563EB',
                    edgecolor='black', zorder=5)
    axes[1].set_xlabel('Mean Confidence')
    axes[1].set_ylabel('Accuracy')
    axes[1].set_title('Confidence vs Accuracy')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)
    axes[1].set_xlim([0, 1])
    axes[1].set_ylim([0, 1])

    fig.tight_layout()
    fig.savefig(output_path)
    plt.close(fig)
def _parse_epoch_summary(line):
    """Parse one '... Epoch N/M | Train Loss: x | Val Loss: y | ...' line.

    Returns:
        ``(epoch, metrics)`` where ``metrics`` maps each recognised key
        ('Train Loss:', 'Val Loss:', 'Val Acc:') to its parsed float, or
        ``None`` when the line carries no parsable epoch number.
    """
    epoch = None
    metrics = {}
    for field in line.strip().split('|'):
        field = field.strip()
        if field.startswith('Epoch'):
            try:
                epoch = int(field.split('/')[0].replace('Epoch', '').strip())
            except ValueError:
                return None
            continue
        for key in ('Train Loss:', 'Val Loss:', 'Val Acc:'):
            if key in field:
                try:
                    # Split on the key itself so other colons in the field
                    # (e.g. a timestamp) cannot shift the value position.
                    metrics[key] = float(field.split(key, 1)[1].strip())
                except ValueError:
                    pass
                break
    if epoch is None:
        return None
    return epoch, metrics


def plot_training_curves(log_path, output_path):
    """Plot training/validation loss and validation accuracy from a log file.

    Only lines containing both 'Train Loss:' and 'Val Loss:' are considered.
    Every parsed value is stored together with the epoch from its own line,
    so an unparsable field on one line cannot shift later points onto the
    wrong epoch.

    Args:
        log_path: path to the training log.
        output_path: file the figure is written to.

    Returns:
        None. Nothing is written when the log does not exist or contains no
        parsable epoch summaries.
    """
    if not Path(log_path).exists():
        return

    # metric key -> (epochs, values), kept pairwise aligned
    series = {
        'Train Loss:': ([], []),
        'Val Loss:': ([], []),
        'Val Acc:': ([], []),
    }
    with open(log_path, 'r', encoding='utf-8', errors='replace') as f:
        for line in f:
            if 'Train Loss:' not in line or 'Val Loss:' not in line:
                continue
            parsed = _parse_epoch_summary(line)
            if parsed is None:
                continue
            epoch, metrics = parsed
            for key, value in metrics.items():
                xs, ys = series[key]
                xs.append(epoch)
                ys.append(value)

    if not any(ys for _, ys in series.values()):
        return

    train_x, train_y = series['Train Loss:']
    val_x, val_y = series['Val Loss:']
    acc_x, acc_y = series['Val Acc:']

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    # Loss curves
    ax1.plot(train_x, train_y, '-o', color='#2563EB',
             label='Train Loss', markersize=4)
    ax1.plot(val_x, val_y, '-s', color='#DC2626',
             label='Val Loss', markersize=4)
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.set_title('Training & Validation Loss')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Accuracy curve
    if acc_y:
        ax2.plot(acc_x, acc_y, '-^', color='#16A34A',
                 label='Val Accuracy', markersize=4)
        # Only build a legend when there is something to list.
        ax2.legend()
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy')
    ax2.set_title('Validation Accuracy')
    ax2.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(output_path)
    plt.close(fig)
def plot_subgroup_analysis(labels, probs, metadata_df, output_path):
    """Bar-plot sensitivity for four subgroups of the positive samples.

    When ``metadata_df`` has a 'diameter_mm' column, positives are grouped
    into four diameter ranges. Otherwise positives are grouped by the
    quartiles of their own predicted scores. In both cases a prediction
    counts as positive when its score is ``> 0.5``; empty groups are drawn
    with sensitivity 0 and ``n=0``.

    Args:
        labels: array of ground-truth labels (0/1).
        probs: array of predicted scores.
        metadata_df: DataFrame with per-sample metadata, or None.
        output_path: file the figure is written to.

    Returns:
        None. In the quartile fallback nothing is written when there are no
        positive samples.
    """
    hard_preds = (probs > 0.5).astype(int)
    has_sizes = metadata_df is not None and 'diameter_mm' in metadata_df.columns

    sensitivities, counts = [], []
    if has_sizes:
        size_labels = ['Tiny (<4mm)', 'Small (4-6mm)', 'Medium (6-10mm)', 'Large (>10mm)']
        edges = [0, 4, 6, 10, float('inf')]
        diameters = metadata_df['diameter_mm']
        is_positive = (labels == 1)
        for lower, upper in zip(edges[:-1], edges[1:]):
            selected = (diameters >= lower) & (diameters < upper) & is_positive
            n_selected = selected.sum()
            if n_selected > 0:
                sensitivities.append((hard_preds[selected] == 1).mean())
                counts.append(n_selected)
            else:
                sensitivities.append(0)
                counts.append(0)
    else:
        # No size information: fall back to quartiles of the positives' scores.
        is_positive = labels == 1
        if is_positive.sum() == 0:
            return
        pos_scores = probs[is_positive]
        q25, q50, q75 = np.percentile(pos_scores, [25, 50, 75])
        size_labels = ['Q1 (hardest)', 'Q2', 'Q3', 'Q4 (easiest)']
        # Upper edge 1.01 keeps scores equal to 1.0 inside the last group.
        edges = [0, q25, q50, q75, 1.01]
        for lower, upper in zip(edges[:-1], edges[1:]):
            selected = (pos_scores >= lower) & (pos_scores < upper)
            n_selected = selected.sum()
            if n_selected > 0:
                sensitivities.append((pos_scores[selected] > 0.5).mean())
                counts.append(n_selected)
            else:
                sensitivities.append(0)
                counts.append(0)

    figure, axis = plt.subplots(figsize=(10, 6))
    palette = ['#EF4444', '#F59E0B', '#10B981', '#3B82F6']
    rects = axis.bar(size_labels, sensitivities, color=palette,
                     edgecolor='black', alpha=0.8)

    # Annotate every bar with its sample count.
    for rect, n_samples in zip(rects, counts):
        x_mid = rect.get_x() + rect.get_width() / 2
        y_top = rect.get_height() + 0.02
        axis.text(x_mid, y_top, f'n={n_samples}',
                  ha='center', va='bottom', fontsize=11)

    axis.set_ylabel('Sensitivity')
    axis.set_title('Sensitivity by Nodule Subgroup')
    axis.set_ylim([0, 1.15])
    axis.grid(True, alpha=0.3, axis='y')

    figure.savefig(output_path)
    plt.close(figure)
def plot_score_distribution(labels, probs, output_path):
    """Overlay density histograms of predicted scores for each class.

    Args:
        labels: array of ground-truth labels (0/1).
        probs: array of predicted scores.
        output_path: file the figure is written to.
    """
    # (label value, bar colour, legend entry), drawn in this order
    class_styles = (
        (0, '#3B82F6', 'Negative'),
        (1, '#EF4444', 'Positive'),
    )

    figure, axis = plt.subplots(figsize=(10, 6))
    for class_value, colour, legend_name in class_styles:
        axis.hist(
            probs[labels == class_value],
            bins=50,
            alpha=0.6,
            color=colour,
            label=legend_name,
            edgecolor='black',
            density=True,
        )

    # Vertical marker at the 0.5 decision threshold.
    axis.axvline(x=0.5, color='black', linestyle='--', lw=1.5, alpha=0.7,
                 label='Decision boundary')
    axis.set(
        xlabel='Predicted Probability',
        ylabel='Density',
        title='Prediction Score Distribution',
    )
    axis.legend()
    axis.grid(True, alpha=0.3)

    figure.savefig(output_path)
    plt.close(figure)
def generate_all_plots(labels, probs, output_dir, mean_probs=None,
                       confidences=None, metadata_df=None, log_path=None):
    """Generate all evaluation plots and save them to output_dir.

    Each plot is generated on a best-effort basis: a failure in one plot is
    logged as a warning and the remaining plots are still attempted.

    Args:
        labels: ground truth labels (converted to a numpy array)
        probs: predicted probabilities (converted to a numpy array)
        output_dir: directory to save plots (created if missing)
        mean_probs: MC Dropout mean predictions (optional)
        confidences: MC Dropout confidence scores (optional)
        metadata_df: DataFrame with sample metadata (optional)
        log_path: path to training log file (optional)

    Returns:
        dict: plot name -> path, for plots whose output file exists after
        the plotting call (failed or skipped plots are omitted)
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    logger = logging.getLogger('dca-net')
    plots = {}

    # The plot functions rely on numpy boolean indexing, so normalise once
    # here; lists or pandas Series would otherwise fail or misalign.
    labels = np.asarray(labels)
    probs = np.asarray(probs)

    # Helper to safely generate plots
    def _safe_plot(name, func, *args, **kwargs):
        try:
            func(*args, **kwargs)
        except Exception as e:
            logger.warning(f" Failed to generate {name}: {e}")
            return
        if not args:
            plots[name] = ''
            return
        target = Path(args[-1])
        # Some plot functions return early without writing anything (e.g. no
        # parsable epochs, no positive samples); do not report those.
        if target.exists():
            plots[name] = str(target)
        else:
            logger.warning(f" {name} produced no output file")

    # 1-6. Plots that only need labels and probabilities
    standard_plots = [
        ('roc_curve', " Generating ROC curve...",
         plot_roc_curve, 'roc_curve.png'),
        ('pr_curve', " Generating PR curve...",
         plot_precision_recall_curve, 'pr_curve.png'),
        ('confusion_matrix', " Generating confusion matrix...",
         plot_confusion_matrix, 'confusion_matrix.png'),
        ('froc_curve', " Generating FROC curve...",
         plot_froc_curve, 'froc_curve.png'),
        ('calibration_diagram', " Generating calibration diagram...",
         plot_calibration_diagram, 'calibration_diagram.png'),
        ('score_distribution', " Generating score distribution...",
         plot_score_distribution, 'score_distribution.png'),
    ]
    for name, message, func, filename in standard_plots:
        logger.info(message)
        _safe_plot(name, func, labels, probs, output_dir / filename)

    # 7. Uncertainty Distribution (if MC Dropout was run)
    if mean_probs is not None and confidences is not None:
        logger.info(" Generating uncertainty plots...")
        _safe_plot('uncertainty_distribution', plot_uncertainty_distribution,
                   np.asarray(mean_probs), np.asarray(confidences), labels,
                   output_dir / 'uncertainty_distribution.png')

    # 8. Training Curves (if log file provided)
    if log_path and Path(log_path).exists():
        logger.info(" Generating training curves...")
        _safe_plot('training_curves', plot_training_curves, log_path,
                   output_dir / 'training_curves.png')

    # 9. Subgroup Analysis
    logger.info(" Generating subgroup analysis...")
    _safe_plot('subgroup_analysis', plot_subgroup_analysis, labels, probs,
               metadata_df, output_dir / 'subgroup_analysis.png')

    logger.info(f" All plots saved to {output_dir}/")
    return plots