OncoVision-X / src /evaluation /visualizations.py
adityasync's picture
Clean OncoVision-X deployment with LFS
8960670
#!/usr/bin/env python3
"""
Visualization module for DCA-Net evaluation.
Generates all plots needed for the research paper.
"""
import numpy as np
import matplotlib
matplotlib.use('Agg') # Non-interactive backend for server
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
roc_curve, auc, precision_recall_curve, average_precision_score,
confusion_matrix
)
from pathlib import Path
import json
import logging
# Publication-quality matplotlib defaults, applied once at import time.
_PUBLICATION_RC = {
    # Text sizes
    'font.size': 12,
    'axes.labelsize': 14,
    'axes.titlesize': 15,
    'xtick.labelsize': 11,
    'ytick.labelsize': 11,
    'legend.fontsize': 11,
    # On-screen figure geometry
    'figure.figsize': (8, 6),
    'figure.dpi': 150,
    # Saved-file settings
    'savefig.bbox': 'tight',
    'savefig.dpi': 300,
}
plt.rcParams.update(_PUBLICATION_RC)
def plot_roc_curve(labels, probs, output_path):
    """Draw the ROC curve for ``probs`` against ``labels`` and save it.

    Args:
        labels: ground-truth labels, passed straight to ``roc_curve``.
        probs: predicted scores, passed straight to ``roc_curve``.
        output_path: destination handed to ``fig.savefig``.

    Returns:
        The area under the plotted curve, as computed by ``auc``.
    """
    false_pos_rate, true_pos_rate, _ = roc_curve(labels, probs)
    area = auc(false_pos_rate, true_pos_rate)

    fig, ax = plt.subplots()
    # Model curve first, then the chance diagonal, so the legend order is kept.
    ax.plot(
        false_pos_rate,
        true_pos_rate,
        color='#2563EB',
        lw=2.5,
        label=f'DCA-Net (AUC = {area:.4f})',
    )
    ax.plot([0, 1], [0, 1], 'k--', lw=1, alpha=0.5, label='Random')

    ax.set(
        xlim=[0.0, 1.0],
        ylim=[0.0, 1.05],  # small headroom above TPR = 1
        xlabel='False Positive Rate',
        ylabel='True Positive Rate',
        title='ROC Curve — Lung Nodule Classification',
    )
    ax.legend(loc='lower right')
    ax.grid(True, alpha=0.3)

    fig.savefig(output_path)
    plt.close(fig)  # release the figure; the module runs on a non-interactive backend
    return area
def plot_precision_recall_curve(labels, probs, output_path):
    """Draw the precision-recall curve and save it to ``output_path``.

    Args:
        labels: ground-truth labels, passed to the sklearn metric functions.
        probs: predicted scores, passed to the sklearn metric functions.
        output_path: destination handed to ``fig.savefig``.

    Returns:
        The average precision reported by ``average_precision_score``.
    """
    avg_precision = average_precision_score(labels, probs)
    precision_vals, recall_vals, _ = precision_recall_curve(labels, probs)

    fig, ax = plt.subplots()
    ax.plot(
        recall_vals,
        precision_vals,
        color='#16A34A',
        lw=2.5,
        label=f'DCA-Net (AP = {avg_precision:.4f})',
    )
    ax.set(
        xlabel='Recall (Sensitivity)',
        ylabel='Precision',
        title='Precision-Recall Curve',
    )
    ax.legend(loc='upper right')
    ax.grid(True, alpha=0.3)

    fig.savefig(output_path)
    plt.close(fig)
    return avg_precision
def plot_confusion_matrix(labels, probs, output_path, threshold=0.5):
    """Render the 2x2 confusion matrix as an annotated heatmap and save it.

    Args:
        labels: ground-truth labels.
        probs: predicted scores; a score ``>= threshold`` counts as class 1.
        output_path: destination handed to ``fig.savefig``.
        threshold: decision threshold applied to ``probs``.
    """
    class_names = ['Negative', 'Positive']
    predicted = (probs >= threshold).astype(int)
    # Fix the label order so the matrix is always 2x2, even if a class is absent.
    matrix = confusion_matrix(labels, predicted, labels=[0, 1])

    fig, ax = plt.subplots(figsize=(7, 6))
    sns.heatmap(
        matrix,
        annot=True,
        fmt='d',
        cmap='Blues',
        ax=ax,
        xticklabels=class_names,
        yticklabels=class_names,
        annot_kws={'size': 16},
    )
    ax.set_xlabel('Predicted Label')
    ax.set_ylabel('True Label')
    ax.set_title(f'Confusion Matrix (threshold = {threshold})')

    fig.savefig(output_path)
    plt.close(fig)
def plot_froc_curve(labels, probs, output_path, n_scans=None):
    """Plot an FROC-style curve: sensitivity versus false positives per scan.

    The curve is derived from ``roc_curve``: the false positive rate is scaled
    by the number of negative samples and divided by a scan count.

    Args:
        labels: ground-truth labels; entries equal to 0 are counted as negatives.
        probs: predicted scores, passed to ``roc_curve``.
        output_path: destination handed to ``fig.savefig``.
        n_scans: number of scans the samples were drawn from. When omitted,
            the legacy denominator (number of distinct label values) is used
            so existing callers get the same curve as before.

    Raises:
        ValueError: if ``n_scans`` is given but is not positive.
    """
    labels = np.asarray(labels)  # allow plain lists; `labels == 0` needs an array
    fpr, tpr, _ = roc_curve(labels, probs)
    n_neg = (labels == 0).sum()

    if n_scans is None:
        # NOTE(review): this legacy fallback divides by the number of distinct
        # label values, which is not a scan count. Pass n_scans to get a true
        # "false positives per scan" axis.
        n_scans = max(len(np.unique(labels)), 1)
    elif n_scans <= 0:
        raise ValueError(f'n_scans must be positive, got {n_scans}')

    # Convert FPR to average FP count per scan
    fp_per_scan = fpr * n_neg / n_scans

    # Standard FROC reference points
    ref_fps = [0.125, 0.25, 0.5, 1, 2, 4, 8]

    fig, ax = plt.subplots()
    ax.plot(fp_per_scan, tpr, color='#DC2626', lw=2.5, label='DCA-Net')

    # Mark reference points
    for fp_ref in ref_fps:
        # Take the last operating point whose FP/scan does not exceed the
        # reference. side='left' would select a point with more false
        # positives than the label claims and so overstate sensitivity.
        idx = np.searchsorted(fp_per_scan, fp_ref, side='right') - 1
        idx = min(max(idx, 0), len(tpr) - 1)
        ax.plot(fp_ref, tpr[idx], 'ko', markersize=5)
        ax.annotate(f'{tpr[idx]:.2f}', (fp_ref, tpr[idx]),
                    textcoords="offset points", xytext=(5, 5), fontsize=9)

    ax.set_xscale('log')
    ax.set_xlim([0.1, 100])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('Average False Positives per Scan')
    ax.set_ylabel('Sensitivity (True Positive Rate)')
    ax.set_title('FROC Curve')
    ax.legend()
    ax.grid(True, alpha=0.3, which='both')
    fig.savefig(output_path)
    plt.close(fig)
def plot_calibration_diagram(labels, probs, output_path, n_bins=10):
    """Plot a reliability / calibration diagram with a prediction histogram.

    Predictions are grouped into ``n_bins`` equal-width bins over [0, 1]. For
    each non-empty bin a bar is drawn at the mean predicted probability, with
    height equal to the mean of the labels in that bin.

    Args:
        labels: ground-truth labels (the bar height is their per-bin mean, so
            0/1 labels are presumably expected).
        probs: predicted probabilities.
        output_path: destination handed to ``fig.savefig``.
        n_bins: number of equal-width bins on [0, 1].
    """
    labels = np.asarray(labels)
    probs = np.asarray(probs)

    bin_boundaries = np.linspace(0, 1, n_bins + 1)
    bin_centers = []
    bin_accuracies = []
    for i in range(n_bins):
        lower, upper = bin_boundaries[i], bin_boundaries[i + 1]
        if i == n_bins - 1:
            # Close the last bin on the right so predictions of exactly 1.0
            # are kept, matching the histogram drawn below.
            mask = (probs >= lower) & (probs <= upper)
        else:
            mask = (probs >= lower) & (probs < upper)
        if not mask.any():
            continue
        bin_centers.append(probs[mask].mean())
        bin_accuracies.append(labels[mask].mean())

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 9),
                                   gridspec_kw={'height_ratios': [3, 1]})

    # Reliability diagram
    ax1.plot([0, 1], [0, 1], 'k--', lw=1, label='Perfect calibration')
    ax1.bar(bin_centers, bin_accuracies, width=1/n_bins * 0.8,
            color='#7C3AED', alpha=0.7, edgecolor='black', label='DCA-Net')
    ax1.set_xlabel('Mean Predicted Probability')
    ax1.set_ylabel('Fraction of Positives')
    ax1.set_title('Calibration / Reliability Diagram')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    ax1.set_xlim([0, 1])
    ax1.set_ylim([0, 1])

    # Histogram of predictions
    ax2.hist(probs, bins=n_bins, range=(0, 1), color='#7C3AED',
             alpha=0.6, edgecolor='black')
    ax2.set_xlabel('Predicted Probability')
    ax2.set_ylabel('Count')
    ax2.set_title('Prediction Distribution')

    fig.tight_layout()
    fig.savefig(output_path)
    plt.close(fig)
def plot_uncertainty_distribution(mean_probs, confidences, labels, output_path):
    """Plot confidence histograms split by correctness, plus confidence vs accuracy.

    A prediction is the mean probability thresholded at ``> 0.5``; it is
    "correct" when it equals the label. The right-hand panel bins the
    confidence scores into ten equal-width bins over [0, 1] and plots the mean
    confidence of each non-empty bin against its accuracy.

    Args:
        mean_probs: mean predicted probabilities.
        confidences: per-sample confidence scores (binned over [0, 1]).
        labels: ground-truth labels compared against the thresholded predictions.
        output_path: destination handed to ``fig.savefig``.
    """
    mean_probs = np.asarray(mean_probs)
    confidences = np.asarray(confidences)
    labels = np.asarray(labels)

    preds = (mean_probs > 0.5).astype(int)
    correct = (preds == labels)

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Confidence distribution
    axes[0].hist(confidences[correct], bins=20, alpha=0.7, color='#16A34A',
                 label='Correct', edgecolor='black')
    axes[0].hist(confidences[~correct], bins=20, alpha=0.7, color='#DC2626',
                 label='Incorrect', edgecolor='black')
    axes[0].set_xlabel('Confidence Score')
    axes[0].set_ylabel('Count')
    axes[0].set_title('Confidence Distribution')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)

    # Confidence vs accuracy scatter
    conf_bins = np.linspace(0, 1, 11)
    n_conf_bins = len(conf_bins) - 1
    bin_accs = []
    bin_confs = []
    for i in range(n_conf_bins):
        lower, upper = conf_bins[i], conf_bins[i + 1]
        if i == n_conf_bins - 1:
            # Close the last bin so samples with confidence exactly 1.0 are
            # not silently dropped from the scatter.
            mask = (confidences >= lower) & (confidences <= upper)
        else:
            mask = (confidences >= lower) & (confidences < upper)
        if mask.any():
            bin_confs.append(confidences[mask].mean())
            bin_accs.append(correct[mask].mean())

    axes[1].plot([0, 1], [0, 1], 'k--', alpha=0.5, label='Ideal')
    axes[1].scatter(bin_confs, bin_accs, s=80, color='#2563EB',
                    edgecolor='black', zorder=5)
    axes[1].set_xlabel('Mean Confidence')
    axes[1].set_ylabel('Accuracy')
    axes[1].set_title('Confidence vs Accuracy')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)
    axes[1].set_xlim([0, 1])
    axes[1].set_ylim([0, 1])

    fig.tight_layout()
    fig.savefig(output_path)
    plt.close(fig)
def _parse_epoch_summary(line):
    """Parse one '|'-separated epoch summary line into a single record.

    Returns ``(epoch, train_loss, val_loss, val_acc)``, or ``None`` when the
    epoch or either loss is missing or unparseable. ``val_acc`` is NaN when
    the line carries no parseable 'Val Acc:' field.
    """
    fields = {}
    for part in line.strip().split('|'):
        part = part.strip()
        try:
            if part.startswith('Epoch'):
                # e.g. "Epoch 3/50" -> 3
                fields['epoch'] = int(
                    part.split('/')[0].replace('Epoch', '').strip())
            elif 'Train Loss:' in part:
                fields['train_loss'] = float(part.split(':')[1].strip())
            elif 'Val Loss:' in part:
                fields['val_loss'] = float(part.split(':')[1].strip())
            elif 'Val Acc:' in part:
                fields['val_acc'] = float(part.split(':')[1].strip())
        except (ValueError, IndexError):
            # Leave the field unset; completeness is checked below.
            continue

    if any(key not in fields for key in ('epoch', 'train_loss', 'val_loss')):
        return None
    return (fields['epoch'], fields['train_loss'], fields['val_loss'],
            fields.get('val_acc', float('nan')))


def plot_training_curves(log_path, output_path):
    """Plot training loss and validation curves from a training log.

    Reads the log file and parses epoch-level summary lines, i.e. lines that
    contain both 'Train Loss:' and 'Val Loss:' with '|'-separated fields. Each
    line is parsed as a unit, so every plotted value stays paired with the
    epoch it was logged for; lines missing the epoch or either loss are
    skipped.

    Nothing is plotted or saved when the log file does not exist or contains
    no usable summary line.

    Args:
        log_path: path to the training log file.
        output_path: destination handed to ``fig.savefig``.
    """
    log_file = Path(log_path)
    if not log_file.exists():
        return

    records = []
    # errors='replace' keeps a stray undecodable byte from aborting the plot.
    with open(log_file, 'r', encoding='utf-8', errors='replace') as f:
        for line in f:
            if 'Train Loss:' in line and 'Val Loss:' in line:
                record = _parse_epoch_summary(line)
                if record is not None:
                    records.append(record)

    if not records:
        return

    epochs, train_losses, val_losses, val_accs = zip(*records)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    # Loss curves
    ax1.plot(epochs, train_losses, '-o', color='#2563EB',
             label='Train Loss', markersize=4)
    ax1.plot(epochs, val_losses, '-s', color='#DC2626',
             label='Val Loss', markersize=4)
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.set_title('Training & Validation Loss')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Accuracy curve (only when at least one epoch reported an accuracy)
    if not np.all(np.isnan(val_accs)):
        ax2.plot(epochs, val_accs, '-^', color='#16A34A',
                 label='Val Accuracy', markersize=4)
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy')
    ax2.set_title('Validation Accuracy')
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(output_path)
    plt.close(fig)
def plot_subgroup_analysis(labels, probs, metadata_df, output_path):
    """Bar-plot sensitivity per subgroup of the positive samples.

    When ``metadata_df`` has a 'diameter_mm' column, positives are grouped
    into four diameter ranges and each bar is the fraction of that group
    predicted positive (score ``> 0.5``). Otherwise positives are grouped by
    the quartiles of their own predicted scores. Groups with no samples get a
    bar of height 0 labelled ``n=0``.

    Nothing is plotted or saved in the fallback path when there are no
    positive labels.

    Args:
        labels: ground-truth labels; entries equal to 1 are the positives.
        probs: predicted probabilities.
        metadata_df: optional DataFrame; only its 'diameter_mm' column is read.
        output_path: destination handed to ``fig.savefig``.
    """
    preds = (probs > 0.5).astype(int)
    sensitivities = []
    counts = []

    has_sizes = metadata_df is not None and 'diameter_mm' in metadata_df.columns
    if has_sizes:
        # Group true nodules by diameter range.
        size_labels = ['Tiny (<4mm)', 'Small (4-6mm)', 'Medium (6-10mm)', 'Large (>10mm)']
        edges = [0, 4, 6, 10, float('inf')]
        for lower, upper in zip(edges[:-1], edges[1:]):
            in_group = ((metadata_df['diameter_mm'] >= lower) &
                        (metadata_df['diameter_mm'] < upper) &
                        (labels == 1))
            group_size = in_group.sum()
            if group_size > 0:
                sensitivities.append((preds[in_group] == 1).mean())
                counts.append(group_size)
            else:
                sensitivities.append(0)
                counts.append(0)
    else:
        # No size information: group positives by quartile of their own score.
        positives = labels == 1
        if positives.sum() == 0:
            return
        pos_scores = probs[positives]
        cut_points = np.percentile(pos_scores, [25, 50, 75])
        size_labels = ['Q1 (hardest)', 'Q2', 'Q3', 'Q4 (easiest)']
        # Upper edge 1.01 so a score of exactly 1.0 lands in the last group.
        edges = [0] + list(cut_points) + [1.01]
        for lower, upper in zip(edges[:-1], edges[1:]):
            in_group = (pos_scores >= lower) & (pos_scores < upper)
            group_size = in_group.sum()
            if group_size > 0:
                sensitivities.append((pos_scores[in_group] > 0.5).mean())
                counts.append(group_size)
            else:
                sensitivities.append(0)
                counts.append(0)

    fig, ax = plt.subplots(figsize=(10, 6))
    palette = ['#EF4444', '#F59E0B', '#10B981', '#3B82F6']
    bars = ax.bar(size_labels, sensitivities, color=palette, edgecolor='black',
                  alpha=0.8)

    # Annotate each bar with its sample count, just above the bar top.
    for rect, n_samples in zip(bars, counts):
        x_mid = rect.get_x() + rect.get_width() / 2
        ax.text(x_mid, rect.get_height() + 0.02, f'n={n_samples}',
                ha='center', va='bottom', fontsize=11)

    ax.set_ylabel('Sensitivity')
    ax.set_title('Sensitivity by Nodule Subgroup')
    ax.set_ylim([0, 1.15])
    ax.grid(True, alpha=0.3, axis='y')

    fig.savefig(output_path)
    plt.close(fig)
def plot_score_distribution(labels, probs, output_path):
    """Overlay density histograms of the scores for label 0 and label 1.

    Args:
        labels: ground-truth labels used to split ``probs`` by class.
        probs: predicted probabilities.
        output_path: destination handed to ``fig.savefig``.
    """
    # (label value, bar colour, legend entry), drawn in this order.
    class_styles = (
        (0, '#3B82F6', 'Negative'),
        (1, '#EF4444', 'Positive'),
    )

    fig, ax = plt.subplots(figsize=(10, 6))
    for class_value, colour, legend_name in class_styles:
        ax.hist(probs[labels == class_value], bins=50, alpha=0.6, color=colour,
                label=legend_name, edgecolor='black', density=True)

    # Mark the 0.5 cut-off drawn as the decision boundary.
    ax.axvline(x=0.5, color='black', linestyle='--', lw=1.5, alpha=0.7,
               label='Decision boundary')

    ax.set_xlabel('Predicted Probability')
    ax.set_ylabel('Density')
    ax.set_title('Prediction Score Distribution')
    ax.legend()
    ax.grid(True, alpha=0.3)

    fig.savefig(output_path)
    plt.close(fig)
def generate_all_plots(labels, probs, output_dir, mean_probs=None,
                       confidences=None, metadata_df=None, log_path=None):
    """Generate all evaluation plots and save them to output_dir.

    Each plot is generated independently: a failure in one is logged as a
    warning on the 'dca-net' logger and does not stop the others.

    Args:
        labels: numpy array of ground truth labels
        probs: numpy array of predicted probabilities
        output_dir: directory to save plots (created if missing)
        mean_probs: MC Dropout mean predictions (optional)
        confidences: MC Dropout confidence scores (optional)
        metadata_df: DataFrame with sample metadata (optional)
        log_path: path to training log file (optional)

    Returns:
        dict: plot name -> output path, for every plot whose output file
        exists after its plotting function ran.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    logger = logging.getLogger('dca-net')
    plots = {}

    # Helper to safely generate plots; the output path is the last positional arg.
    def _safe_plot(name, func, *args, **kwargs):
        try:
            func(*args, **kwargs)
        except Exception as e:
            logger.warning(" Failed to generate %s: %s", name, e)
            return
        out_path = args[-1] if args else ''
        # Some plot functions return early without saving (e.g. no parseable
        # epochs, no positive samples), so only report files that exist.
        # NOTE: a file left over from an earlier run also counts as existing.
        if out_path and Path(out_path).exists():
            plots[name] = str(out_path)
        else:
            logger.warning(" Skipped %s: no plot was written", name)

    # 1. ROC Curve
    logger.info(" Generating ROC curve...")
    _safe_plot('roc_curve', plot_roc_curve, labels, probs,
               output_dir / 'roc_curve.png')

    # 2. Precision-Recall Curve
    logger.info(" Generating PR curve...")
    _safe_plot('pr_curve', plot_precision_recall_curve, labels, probs,
               output_dir / 'pr_curve.png')

    # 3. Confusion Matrix
    logger.info(" Generating confusion matrix...")
    _safe_plot('confusion_matrix', plot_confusion_matrix, labels, probs,
               output_dir / 'confusion_matrix.png')

    # 4. FROC Curve
    logger.info(" Generating FROC curve...")
    _safe_plot('froc_curve', plot_froc_curve, labels, probs,
               output_dir / 'froc_curve.png')

    # 5. Calibration Diagram
    logger.info(" Generating calibration diagram...")
    _safe_plot('calibration_diagram', plot_calibration_diagram, labels, probs,
               output_dir / 'calibration_diagram.png')

    # 6. Score Distribution
    logger.info(" Generating score distribution...")
    _safe_plot('score_distribution', plot_score_distribution, labels, probs,
               output_dir / 'score_distribution.png')

    # 7. Uncertainty Distribution (if MC Dropout was run)
    if mean_probs is not None and confidences is not None:
        logger.info(" Generating uncertainty plots...")
        _safe_plot('uncertainty_distribution', plot_uncertainty_distribution,
                   mean_probs, confidences, labels,
                   output_dir / 'uncertainty_distribution.png')

    # 8. Training Curves (if log file provided)
    if log_path and Path(log_path).exists():
        logger.info(" Generating training curves...")
        _safe_plot('training_curves', plot_training_curves, log_path,
                   output_dir / 'training_curves.png')

    # 9. Subgroup Analysis
    logger.info(" Generating subgroup analysis...")
    _safe_plot('subgroup_analysis', plot_subgroup_analysis, labels, probs,
               metadata_df, output_dir / 'subgroup_analysis.png')

    logger.info(" All plots saved to %s/", output_dir)
    return plots