| |
| |
| |
|
|
| import numpy as np |
| import pandas as pd |
|
|
| |
| |
|
|
| import os |
| for dirname, _, filenames in os.walk('/kaggle/input'): |
| for filename in filenames: |
| print(os.path.join(dirname, filename)) |
|
|
| |
| |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import matplotlib.pyplot as plt |
| import numpy as np |
| from PIL import Image |
| import cv2 |
| import os |
| import pandas as pd |
| from scipy.ndimage import gaussian_filter |
|
|
| |
| |
| |
| OUTPUT_DIR = "/kaggle/working" |
| GRADCAM_OUTPUT_DIR = os.path.join(OUTPUT_DIR, 'gradcam_plus_plus_results') |
| os.makedirs(GRADCAM_OUTPUT_DIR, exist_ok=True) |
|
|
| print("="*80) |
| print("HIERARCHICAL TRI-HEAD GRAD-CAM++ CONFIGURATION") |
| print("="*80) |
| print(f"Output directory: {GRADCAM_OUTPUT_DIR}") |
| print("="*80) |
|
|
| |
| |
| |
| DISEASE_CLASS_MAPPING = { |
| 0: "Breast_cancer", |
| 1: "annrbc-anemia_processed", |
| 2: "colon_processed", |
| 3: "leukemia_processed", |
| 4: "lung_processed", |
| 5: "oral-cancer_processed", |
| 6: "ovarian-cancer_processed", |
| 7: "sickle-cell-new_processed", |
| 8: "thalassemia_processed", |
| } |
|
|
| SEVERITY_CLASS_MAPPING = { |
| 0: "Normal", |
| 1: "Abnormal", |
| } |
|
|
| |
| |
| |
| |
| |
| |
| class GradCAM: |
| """ |
| Standard Grad-CAM implementation for Vision Transformers |
| Simplified - no second-order gradients, just straightforward CAM |
| """ |
| def __init__(self, model, target_layer): |
| """ |
| Args: |
| model: Your Phase3 hierarchical model |
| target_layer: The layer to hook (typically last transformer block) |
| """ |
| self.model = model |
| self.target_layer = target_layer |
| self.gradients = None |
| self.activations = None |
| |
| |
| self.handlers = [] |
| self._register_hooks() |
| |
| def _register_hooks(self): |
| """Register forward and backward hooks on target layer""" |
| def forward_hook(module, input, output): |
| |
| if isinstance(output, tuple): |
| self.activations = output[0].detach() |
| print(f" 🪝 Forward hook: Captured from tuple, shape {output[0].shape}") |
| else: |
| self.activations = output.detach() |
| print(f" 🪝 Forward hook: Captured tensor, shape {output.shape}") |
| |
| def backward_hook(module, grad_input, grad_output): |
| |
| if isinstance(grad_output, tuple): |
| grad = grad_output[0] |
| if grad is not None: |
| self.gradients = grad.detach() |
| print(f" 🪝 Backward hook: Captured from tuple, shape {grad.shape}") |
| else: |
| if grad_output is not None: |
| self.gradients = grad_output.detach() |
| print(f" 🪝 Backward hook: Captured tensor, shape {grad_output.shape}") |
| |
| |
| self.handlers.append( |
| self.target_layer.register_forward_hook(forward_hook) |
| ) |
| self.handlers.append( |
| self.target_layer.register_full_backward_hook(backward_hook) |
| ) |
| |
| def remove_hooks(self): |
| """Remove all hooks""" |
| for handle in self.handlers: |
| handle.remove() |
| |
| def generate_cam(self, class_idx, logits): |
| """ |
| Generate standard Grad-CAM heatmap |
| |
| Args: |
| class_idx: Target class index |
| logits: Model output logits |
| |
| Returns: |
| cam: Grad-CAM heatmap (H, W) |
| """ |
| |
| self.model.zero_grad() |
| |
| |
| one_hot = torch.zeros_like(logits) |
| one_hot[0, class_idx] = 1 |
| logits.backward(gradient=one_hot, retain_graph=True) |
| |
| |
| |
| |
| print(f"\n 🔍 GRADIENT FLOW CHECK:") |
| print(f" {'='*60}") |
| |
| |
| if self.gradients is None: |
| print(f" ❌ CRITICAL: No gradients captured!") |
| return np.zeros((14, 14)) |
| else: |
| print(f" ✅ Gradients captured: {self.gradients.shape}") |
| print(f" Min: {self.gradients.min().item():.6f}, Max: {self.gradients.max().item():.6f}") |
| print(f" Mean: {self.gradients.mean().item():.6f}, Std: {self.gradients.std().item():.6f}") |
| |
| |
| if self.activations is None: |
| print(f" ❌ CRITICAL: No activations captured!") |
| return np.zeros((14, 14)) |
| else: |
| print(f" ✅ Activations captured: {self.activations.shape}") |
| print(f" Min: {self.activations.min().item():.6f}, Max: {self.activations.max().item():.6f}") |
| |
| print(f" {'='*60}\n") |
| |
| |
| |
| |
| |
| |
| |
| |
| weights = self.gradients.mean(dim=1, keepdim=True) |
| |
| print(f" 📊 Weights (channel importance):") |
| print(f" Shape: {weights.shape}") |
| print(f" Min: {weights.min().item():.6f}, Max: {weights.max().item():.6f}") |
| |
| |
| |
| |
| |
| cam = (weights * self.activations).sum(dim=2) |
| |
| print(f" 📊 CAM before ReLU:") |
| print(f" Shape: {cam.shape}") |
| print(f" Min: {cam.min().item():.6f}, Max: {cam.max().item():.6f}") |
| |
| |
| |
| |
| print(f" 📊 CAM after ReLU:") |
| print(f" Min: {cam.min().item():.6f}, Max: {cam.max().item():.6f}") |
| |
| |
| cam = cam[0] |
| |
| |
| if cam.shape[0] > 1: |
| cam = cam[1:] |
| print(f" 🎯 Removed CLS token, remaining tokens: {cam.shape[0]}") |
| |
| |
| grid_size = int(np.sqrt(cam.shape[0])) |
| print(f" 📐 Grid size: {grid_size}x{grid_size}") |
| |
| cam = cam.reshape(grid_size, grid_size) |
| |
| |
| cam_min = cam.min() |
| cam_max = cam.max() |
| |
| if cam_max > cam_min: |
| cam = (cam - cam_min) / (cam_max - cam_min) |
| print(f" ✅ Normalized CAM: min={cam.min().item():.4f}, max={cam.max().item():.4f}") |
| else: |
| print(f" ⚠️ WARNING: No variation in CAM (all same value)") |
| cam = torch.zeros_like(cam) |
| |
| print(f" 📊 Final CAM mean: {cam.mean().item():.6f}\n") |
| |
| return cam.cpu().numpy() |
|
|
|
|
| class GradCAMVisionTransformer: |
| """ |
| Wrapper to apply standard Grad-CAM to Vision Transformer models |
| """ |
| def __init__(self, model, device): |
| self.model = model |
| self.device = device |
| self.model.eval() |
| |
| |
| self.target_layer = self._find_target_layer() |
| print(f" 🎯 Target layer for Grad-CAM: {self.target_layer}") |
| |
| def _find_target_layer(self): |
| """ |
| Find the last transformer block in DINOv2/Phikon model |
| """ |
| print("\n🔍 Searching for target layer in DINOv2 architecture...") |
| |
| try: |
| if hasattr(self.model, 'backbone'): |
| vit_model = self.model.backbone.vit |
| print(f"✅ Found backbone.vit: {type(vit_model).__name__}") |
| else: |
| raise AttributeError("No backbone found") |
| |
| if hasattr(vit_model, 'encoder') and hasattr(vit_model.encoder, 'layer'): |
| num_layers = len(vit_model.encoder.layer) |
| last_layer = vit_model.encoder.layer[-1] |
| print(f"✅ Found encoder with {num_layers} layers") |
| print(f"✅ Target layer: encoder.layer[-1] (layer {num_layers-1})") |
| return last_layer |
| else: |
| raise AttributeError("No encoder.layer found") |
| |
| except AttributeError as e: |
| print(f"❌ Error: {e}") |
| raise ValueError("Could not find DINOv2 encoder layers") |
| |
| def generate_heatmap(self, image_tensor, target_class_idx, head_type='disease'): |
| """ |
| Generate Grad-CAM heatmap for a specific head |
| |
| Args: |
| image_tensor: Input image tensor [1, 3, H, W] |
| target_class_idx: Target class index |
| head_type: 'disease', 'severity', or 'stage' |
| |
| Returns: |
| heatmap: Grad-CAM heatmap resized to input image size |
| """ |
| |
| gradcam = GradCAM(self.model, self.target_layer) |
| |
| |
| image_tensor = image_tensor.to(self.device) |
| image_tensor.requires_grad = True |
| |
| |
| disease_logits, severity_logits, stage_logits, _ = self.model([image_tensor], enable_gradients=True) |
| |
| if head_type == 'disease': |
| logits = disease_logits |
| elif head_type == 'severity': |
| disease_pred_idx = disease_logits.argmax(dim=1).item() |
| disease_name = DISEASE_CLASS_MAPPING.get(disease_pred_idx, f"Unknown_{disease_pred_idx}") |
| logits = severity_logits[disease_name] |
| elif head_type == 'stage': |
| logits = stage_logits |
| else: |
| raise ValueError(f"Unknown head_type: {head_type}") |
| |
| |
| cam = gradcam.generate_cam(target_class_idx, logits) |
| |
| |
| gradcam.remove_hooks() |
| |
| |
| H, W = image_tensor.shape[2], image_tensor.shape[3] |
| cam_resized = cv2.resize(cam, (W, H), interpolation=cv2.INTER_CUBIC) |
| |
| |
| cam_smooth = gaussian_filter(cam_resized, sigma=2) |
| |
| return cam_smooth |
| |
| |
| |
| def create_gradcam_overlay(image_array, heatmap, alpha=0.5, colormap='jet'): |
| """ |
| Create a visual overlay of Grad-CAM++ heatmap on original image |
| Uses red-yellow colormap like traditional Grad-CAM |
| |
| Args: |
| image_array: Original image as numpy array (H, W, 3) |
| heatmap: Grad-CAM++ heatmap (H, W), values in [0, 1] |
| alpha: Transparency of heatmap overlay |
| colormap: Matplotlib colormap name |
| |
| Returns: |
| Overlayed image as numpy array (H, W, 3) in range [0, 1] |
| """ |
| |
| img_normalized = image_array.astype(np.float32) / 255.0 |
| |
| |
| target_h, target_w = img_normalized.shape[:2] |
| if heatmap.shape != (target_h, target_w): |
| print(f" 📐 Resizing heatmap from {heatmap.shape} to ({target_h}, {target_w})") |
| heatmap = cv2.resize(heatmap, (target_w, target_h), interpolation=cv2.INTER_CUBIC) |
| |
| |
| heatmap = np.clip(heatmap, 0, 1) |
| |
| |
| import matplotlib |
| cmap = matplotlib.colormaps.get_cmap(colormap) |
| heatmap_colored = cmap(heatmap)[:, :, :3] |
| |
| |
| overlay = img_normalized * (1 - alpha) + heatmap_colored * alpha |
| overlay = np.clip(overlay, 0, 1) |
| |
| return overlay |
|
|
| def load_and_preprocess_image(img_path, target_size=224): |
| """Load image and preprocess for both visualization and model input""" |
| try: |
| if img_path.lower().endswith(('.svs', '.ndpi')): |
| slide = openslide.OpenSlide(img_path) |
| img_pil = slide.get_thumbnail((target_size, target_size)) |
| slide.close() |
| elif img_path.lower().endswith('.tif'): |
| try: |
| slide = openslide.OpenSlide(img_path) |
| img_pil = slide.get_thumbnail((target_size, target_size)) |
| slide.close() |
| except: |
| img_pil = Image.open(img_path).convert('RGB') |
| img_pil = img_pil.resize((target_size, target_size), Image.BILINEAR) |
| else: |
| img_pil = Image.open(img_path).convert('RGB') |
| img_pil = img_pil.resize((target_size, target_size), Image.BILINEAR) |
| |
| img_array = np.array(img_pil) |
| img_tensor = test_transform(img_pil).unsqueeze(0) |
| |
| return img_pil, img_array, img_tensor |
| |
| except Exception as e: |
| print(f"Error loading image {img_path}: {e}") |
| return None, None, None |
|
|
|
|
| def extract_stage_label(img_path): |
| """ |
| Extract stage label using strict hierarchy rules. |
| |
| Valid structures: |
| - .../abnormal/test/image.png -> no stage -> return -1 |
| - .../abnormal/<stage_name>/test/image.png -> stage exists -> return <stage_name> |
| """ |
| path_parts = img_path.split(os.sep) |
| |
| try: |
| |
| test_idx = path_parts.index('test') |
| |
| |
| candidate = path_parts[test_idx - 1] |
| |
| |
| if candidate.lower() == 'abnormal': |
| return -1 |
| |
| |
| return candidate |
| |
| except (ValueError, IndexError): |
| |
| return -1 |
|
|
|
|
| |
| |
| |
| def run_tri_head_gradcam_plus_plus_analysis(model, device, collected_images): |
| """ |
| Run Grad-CAM++ analysis on disease head (Level 1), severity head (Level 2), and stage head (Level 3) |
| Stage head is only analyzed when ground truth is abnormal and stage label is not -1 |
| Returns comprehensive dictionaries with all analysis results |
| """ |
| print("\n" + "="*80) |
| print("🔍 STARTING TRI-HEAD GRAD-CAM++ ANALYSIS") |
| print("="*80) |
| |
| if not collected_images: |
| print("❌ No images provided") |
| return [] |
| |
| print(f"✅ Processing {len(collected_images)} images\n") |
| |
| |
| gradcam_wrapper = GradCAMVisionTransformer(model, device) |
|
|
| |
| all_results = [] |
| |
| for idx, img_path in enumerate(collected_images): |
| print(f"\n[{idx+1}/{len(collected_images)}] Processing: {os.path.basename(img_path)}") |
| |
| try: |
| |
| img_pil, img_array, img_tensor = load_and_preprocess_image(img_path, target_size=224) |
| |
| if img_tensor is None: |
| print(f" ❌ Failed to load image") |
| continue |
| |
| |
| path_parts = img_path.split(os.sep) |
| dataset_name = "Unknown" |
| true_label = "unknown" |
| stage_label = extract_stage_label(img_path) |
| |
| for part in path_parts: |
| if part in ["ovarian-cancer_processed", "oral-cancer_processed", |
| "Breast_cancer", "colon_processed", "lung_processed", |
| "annrbc-anemia_processed", "leukemia_processed", |
| "sickle-cell-new_processed", "thalassemia_processed"]: |
| dataset_name = part |
| if part in ["normal", "abnormal"]: |
| true_label = part |
| break |
| |
| |
| with torch.no_grad(): |
| img_tensor_device = img_tensor.to(device) |
| disease_logits, severity_logits_dict, stage_logits, _ = model([img_tensor_device]) |
| |
| |
| disease_pred_idx = disease_logits.argmax(dim=1).item() |
| disease_probs = F.softmax(disease_logits, dim=1) |
| disease_confidence = disease_probs[0, disease_pred_idx].item() |
| disease_name = DISEASE_CLASS_MAPPING.get(disease_pred_idx, f"Unknown_{disease_pred_idx}") |
| disease_all_probs = disease_probs[0].cpu().numpy() |
| |
| print(f" 📊 Level 1 (Disease): {disease_name}") |
| print(f" Index: {disease_pred_idx}, Confidence: {disease_confidence:.4f}") |
| |
| |
| severity_logits = severity_logits_dict[disease_name] |
| severity_pred_idx = severity_logits.argmax(dim=1).item() |
| severity_probs = F.softmax(severity_logits, dim=1) |
| severity_confidence = severity_probs[0, severity_pred_idx].item() |
| severity_label_text = SEVERITY_CLASS_MAPPING.get(severity_pred_idx, f"Unknown_{severity_pred_idx}") |
| severity_all_probs = severity_probs[0].cpu().numpy() |
| |
| print(f" 📊 Level 2 (Severity): {severity_label_text}") |
| print(f" Index: {severity_pred_idx}, Confidence: {severity_confidence:.4f}") |
| |
| |
| stage_pred_idx = None |
| stage_confidence = None |
| stage_all_probs = None |
| |
| if stage_logits is not None: |
| stage_pred_idx = stage_logits.argmax(dim=1).item() |
| stage_probs = F.softmax(stage_logits, dim=1) |
| stage_confidence = stage_probs[0, stage_pred_idx].item() |
| stage_all_probs = stage_probs[0].cpu().numpy() |
| |
| print(f" 📊 Level 3 (Stage): Stage {stage_pred_idx}") |
| print(f" Confidence: {stage_confidence:.4f}") |
| print(f" Ground Truth Stage: {stage_label}") |
| |
| |
| print(f"\n 🔥 Generating Grad-CAM++ for Disease Head...") |
| disease_heatmap = gradcam_wrapper.generate_heatmap( |
| img_tensor.clone(), |
| disease_pred_idx, |
| head_type='disease' |
| ) |
| |
| |
| disease_overlay = create_gradcam_overlay(img_array, disease_heatmap, alpha=0.5) |
| |
| print(f" ✅ Disease heatmap generated") |
| print(f" Min: {disease_heatmap.min():.4f}, Max: {disease_heatmap.max():.4f}") |
| |
| |
| print(f" 🔥 Generating Grad-CAM++ for Severity Head...") |
| severity_heatmap = gradcam_wrapper.generate_heatmap( |
| img_tensor.clone(), |
| severity_pred_idx, |
| head_type='severity' |
| ) |
| |
| |
| severity_overlay = create_gradcam_overlay(img_array, severity_heatmap, alpha=0.5) |
| |
| print(f" ✅ Severity heatmap generated") |
| print(f" Min: {severity_heatmap.min():.4f}, Max: {severity_heatmap.max():.4f}") |
| |
| |
| stage_heatmap = None |
| stage_overlay = None |
| include_stage_analysis = False |
| |
| |
| if true_label == "abnormal" and stage_label != -1 and stage_logits is not None: |
| include_stage_analysis = True |
| print(f" 🔥 Generating Grad-CAM++ for Stage Head (GT: abnormal, Stage: {stage_label})...") |
| |
| stage_heatmap = gradcam_wrapper.generate_heatmap( |
| img_tensor.clone(), |
| stage_pred_idx, |
| head_type='stage' |
| ) |
| |
| |
| stage_overlay = create_gradcam_overlay(img_array, stage_heatmap, alpha=0.5) |
| |
| print(f" ✅ Stage heatmap generated") |
| print(f" Min: {stage_heatmap.min():.4f}, Max: {stage_heatmap.max():.4f}") |
| else: |
| reason = [] |
| if true_label != "abnormal": |
| reason.append(f"true_label='{true_label}'") |
| if stage_label == -1: |
| reason.append("stage_label=-1") |
| if stage_logits is None: |
| reason.append("stage_logits=None") |
| print(f" ⏭️ Skipping Stage Head Analysis ({', '.join(reason)})") |
| |
| |
| if include_stage_analysis: |
| |
| union_heatmap = (disease_heatmap + severity_heatmap + stage_heatmap) / 3.0 |
| print(f" 📊 Union Heatmap: Average of 3 heads (Disease + Severity + Stage)") |
| else: |
| |
| union_heatmap = (disease_heatmap + severity_heatmap) / 2.0 |
| print(f" 📊 Union Heatmap: Average of 2 heads (Disease + Severity)") |
| |
| union_overlay = create_gradcam_overlay(img_array, union_heatmap, alpha=0.5) |
| |
| |
| disease_mean_activation = float(disease_heatmap.mean()) |
| disease_max_activation = float(disease_heatmap.max()) |
| |
| severity_mean_activation = float(severity_heatmap.mean()) |
| severity_max_activation = float(severity_heatmap.max()) |
| |
| stage_mean_activation = None |
| stage_max_activation = None |
| if stage_heatmap is not None: |
| stage_mean_activation = float(stage_heatmap.mean()) |
| stage_max_activation = float(stage_heatmap.max()) |
| |
| union_mean_activation = float(union_heatmap.mean()) |
| union_max_activation = float(union_heatmap.max()) |
| |
| |
| result_dict = { |
| |
| 'filename': os.path.basename(img_path), |
| 'full_path': img_path, |
| 'dataset_name': dataset_name, |
| 'true_label': true_label, |
| 'stage_label': stage_label, |
| 'include_stage_analysis': include_stage_analysis, |
| |
| |
| 'image': img_array, |
| |
| |
| 'level1_disease': { |
| 'predicted_class': disease_name, |
| 'predicted_idx': disease_pred_idx, |
| 'confidence': disease_confidence, |
| 'all_probabilities': disease_all_probs, |
| 'heatmap_raw': disease_heatmap, |
| 'heatmap_overlay': disease_overlay, |
| 'activation_stats': { |
| 'mean': disease_mean_activation, |
| 'max': disease_max_activation, |
| } |
| }, |
| |
| |
| 'level2_severity': { |
| 'predicted_class': severity_label_text, |
| 'predicted_idx': severity_pred_idx, |
| 'confidence': severity_confidence, |
| 'all_probabilities': severity_all_probs, |
| 'heatmap_raw': severity_heatmap, |
| 'heatmap_overlay': severity_overlay, |
| 'activation_stats': { |
| 'mean': severity_mean_activation, |
| 'max': severity_max_activation, |
| } |
| }, |
| |
| |
| 'level3_stage': { |
| 'predicted_idx': stage_pred_idx, |
| 'confidence': stage_confidence, |
| 'all_probabilities': stage_all_probs, |
| 'heatmap_raw': stage_heatmap, |
| 'heatmap_overlay': stage_overlay, |
| 'activation_stats': { |
| 'mean': stage_mean_activation, |
| 'max': stage_max_activation, |
| } if stage_heatmap is not None else None |
| }, |
| |
| |
| 'union': { |
| 'heatmap_raw': union_heatmap, |
| 'heatmap_overlay': union_overlay, |
| 'num_heads_averaged': 3 if include_stage_analysis else 2, |
| 'activation_stats': { |
| 'mean': union_mean_activation, |
| 'max': union_max_activation, |
| } |
| }, |
| |
| |
| 'disease_heatmap': disease_heatmap, |
| 'severity_heatmap': severity_heatmap, |
| 'stage_heatmap': stage_heatmap, |
| 'disease_pred': disease_name, |
| 'disease_idx': disease_pred_idx, |
| 'disease_conf': disease_confidence, |
| 'severity_pred': severity_label_text, |
| 'severity_idx': severity_pred_idx, |
| 'severity_conf': severity_confidence, |
| 'stage_pred_idx': stage_pred_idx, |
| 'stage_conf': stage_confidence, |
| } |
| |
| all_results.append(result_dict) |
| |
| print(f" ✅ Completed tri-head Grad-CAM++ analysis") |
| |
| except Exception as e: |
| print(f" ❌ Error: {e}") |
| import traceback |
| traceback.print_exc() |
| continue |
| |
| print("\n" + "="*80) |
| print("✅ TRI-HEAD GRAD-CAM++ ANALYSIS COMPLETE") |
| print(f"📦 Generated {len(all_results)} comprehensive result dictionaries") |
| print("="*80) |
| |
| return all_results |
|
|
|
|
| |
| |
| |
| def display_tri_head_gradcam_grid(results): |
| """ |
| Display grid: each row = one image with 5 columns (or 4 if no stage analysis) |
| [Original | Disease Grad-CAM++ | Severity Grad-CAM++ | Stage Grad-CAM++ (if available) | Union] |
| """ |
| if not results: |
| print("No results to display") |
| return |
| |
| num_images = len(results) |
| max_cols = 5 |
| |
| |
| fig, axes = plt.subplots(num_images, max_cols, figsize=(35, 7 * num_images)) |
| |
| |
| if num_images == 1: |
| axes = axes.reshape(1, -1) |
| |
| cmap = plt.cm.jet |
| |
| for i, result in enumerate(results): |
| has_stage = result['include_stage_analysis'] |
| |
| |
| axes[i, 0].imshow(result['image']) |
| title_text = ( |
| f"Original Image {i+1}\n" |
| f"Dataset: {result['dataset_name']}\n" |
| f"True Label: {result['true_label']}\n" |
| f"Stage GT: {result['stage_label']}\n" |
| f"File: {result['filename']}" |
| ) |
| axes[i, 0].set_title(title_text, fontsize=10, fontweight='bold', pad=10) |
| axes[i, 0].axis('off') |
| |
| |
| axes[i, 1].imshow(result['image']) |
| |
| disease_heatmap = result['level1_disease']['heatmap_raw'] |
| |
| im1 = axes[i, 1].imshow( |
| disease_heatmap, |
| cmap=cmap, |
| alpha=0.5, |
| vmin=0, |
| vmax=1 |
| ) |
| |
| cbar1 = plt.colorbar(im1, ax=axes[i, 1], fraction=0.046, pad=0.04) |
| cbar1.set_label('Activation', rotation=270, labelpad=15) |
| |
| disease_title = ( |
| f"Level 1: Disease Head\n" |
| f"Predicted: {result['level1_disease']['predicted_class']}\n" |
| f"Confidence: {result['level1_disease']['confidence']:.4f}\n" |
| f"Mean Act: {result['level1_disease']['activation_stats']['mean']:.4f}" |
| ) |
| axes[i, 1].set_title(disease_title, fontsize=10, fontweight='bold', pad=10) |
| axes[i, 1].axis('off') |
| |
| |
| axes[i, 2].imshow(result['image']) |
| |
| severity_heatmap = result['level2_severity']['heatmap_raw'] |
| |
| im2 = axes[i, 2].imshow( |
| severity_heatmap, |
| cmap=cmap, |
| alpha=0.5, |
| vmin=0, |
| vmax=1 |
| ) |
| |
| cbar2 = plt.colorbar(im2, ax=axes[i, 2], fraction=0.046, pad=0.04) |
| cbar2.set_label('Activation', rotation=270, labelpad=15) |
| |
| severity_title = ( |
| f"Level 2: Severity Head\n" |
| f"Predicted: {result['level2_severity']['predicted_class']}\n" |
| f"Confidence: {result['level2_severity']['confidence']:.4f}\n" |
| f"Mean Act: {result['level2_severity']['activation_stats']['mean']:.4f}" |
| ) |
| axes[i, 2].set_title(severity_title, fontsize=10, fontweight='bold', pad=10) |
| axes[i, 2].axis('off') |
| |
| |
| if has_stage: |
| axes[i, 3].imshow(result['image']) |
| |
| stage_heatmap = result['level3_stage']['heatmap_raw'] |
| |
| im3 = axes[i, 3].imshow( |
| stage_heatmap, |
| cmap=cmap, |
| alpha=0.5, |
| vmin=0, |
| vmax=1 |
| ) |
| |
| cbar3 = plt.colorbar(im3, ax=axes[i, 3], fraction=0.046, pad=0.04) |
| cbar3.set_label('Activation', rotation=270, labelpad=15) |
| |
| stage_title = ( |
| f"Level 3: Stage Head\n" |
| f"Predicted: Stage {result['level3_stage']['predicted_idx']}\n" |
| f"Confidence: {result['level3_stage']['confidence']:.4f}\n" |
| f"Mean Act: {result['level3_stage']['activation_stats']['mean']:.4f}" |
| ) |
| axes[i, 3].set_title(stage_title, fontsize=10, fontweight='bold', pad=10) |
| axes[i, 3].axis('off') |
| else: |
| |
| axes[i, 3].text( |
| 0.5, 0.5, |
| "Stage Analysis\nNot Applicable\n\n" + |
| (f"Reason: GT={result['true_label']}\n" if result['true_label'] != 'abnormal' else "") + |
| (f"Stage={result['stage_label']}" if result['stage_label'] == -1 else ""), |
| ha='center', va='center', |
| fontsize=12, color='gray', |
| transform=axes[i, 3].transAxes |
| ) |
| axes[i, 3].axis('off') |
| |
| |
| axes[i, 4].imshow(result['image']) |
| |
| union_heatmap = result['union']['heatmap_raw'] |
| |
| im4 = axes[i, 4].imshow( |
| union_heatmap, |
| cmap=cmap, |
| alpha=0.5, |
| vmin=0, |
| vmax=1 |
| ) |
| |
| cbar4 = plt.colorbar(im4, ax=axes[i, 4], fraction=0.046, pad=0.04) |
| cbar4.set_label('Activation', rotation=270, labelpad=15) |
| |
| union_title = ( |
| f"Union: Combined Grad-CAM++\n" |
| f"Averaged {result['union']['num_heads_averaged']} Heads\n" |
| f"Disease: {result['level1_disease']['predicted_class']}\n" |
| f"Severity: {result['level2_severity']['predicted_class']}" |
| ) |
| if has_stage: |
| union_title += f"\nStage: {result['level3_stage']['predicted_idx']}" |
| union_title += f"\nMean Act: {result['union']['activation_stats']['mean']:.4f}" |
| |
| axes[i, 4].set_title(union_title, fontsize=10, fontweight='bold', pad=10) |
| axes[i, 4].axis('off') |
| |
| |
| print(f"\n📊 Image {i+1} ({result['filename']}) Statistics:") |
| print(f" Disease Head: {result['level1_disease']['predicted_class']} " |
| f"({result['level1_disease']['confidence']:.4f})") |
| print(f" Mean Activation: {result['level1_disease']['activation_stats']['mean']:.4f}, " |
| f"Max: {result['level1_disease']['activation_stats']['max']:.4f}") |
| |
| print(f" Severity Head: {result['level2_severity']['predicted_class']} " |
| f"({result['level2_severity']['confidence']:.4f})") |
| print(f" Mean Activation: {result['level2_severity']['activation_stats']['mean']:.4f}, " |
| f"Max: {result['level2_severity']['activation_stats']['max']:.4f}") |
| |
| if has_stage: |
| print(f" Stage Head: Stage {result['level3_stage']['predicted_idx']} " |
| f"({result['level3_stage']['confidence']:.4f})") |
| print(f" Mean Activation: {result['level3_stage']['activation_stats']['mean']:.4f}, " |
| f"Max: {result['level3_stage']['activation_stats']['max']:.4f}") |
| else: |
| print(f" Stage Head: Not analyzed (GT: {result['true_label']}, Stage: {result['stage_label']})") |
| |
| print(f" Union Heatmap ({result['union']['num_heads_averaged']} heads):") |
| print(f" Mean Activation: {result['union']['activation_stats']['mean']:.4f}, " |
| f"Max: {result['union']['activation_stats']['max']:.4f}") |
| |
| plt.suptitle( |
| 'Hierarchical Model - Tri-Head Grad-CAM++ Analysis with Union\n' |
| 'Level 1: Disease | Level 2: Severity | Level 3: Stage (Conditional) | Union: Combined Analysis\n' |
| 'Red = High Activation | Blue = Low Activation', |
| fontsize=16, |
| fontweight='bold', |
| y=0.998 |
| ) |
| |
| plt.tight_layout() |
| |
| |
| grid_save_path = os.path.join(GRADCAM_OUTPUT_DIR, 'tri_head_union_gradcam_plus_plus_analysis.png') |
| plt.savefig(grid_save_path, dpi=150, bbox_inches='tight') |
| print(f"\n✅ Grid saved to: {grid_save_path}") |
| |
| plt.show() |
|
|
|
|
| |
| |
| |
| print("\n" + "="*80) |
| print("CHECKING FOR COLLECTED IMAGES") |
| print("="*80) |
|
|
| try: |
| if 'collected_images' in locals() or 'collected_images' in globals(): |
| print(f"✅ Found collected_images with {len(collected_images)} images\n") |
| |
| |
| gradcam_results = run_tri_head_gradcam_plus_plus_analysis( |
| model, |
| device, |
| collected_images |
| ) |
| |
| |
| if gradcam_results: |
| display_tri_head_gradcam_grid(gradcam_results) |
| |
| |
| results_summary = [] |
| for r in gradcam_results: |
| summary_row = { |
| 'filename': r['filename'], |
| 'dataset': r['dataset_name'], |
| 'true_label': r['true_label'], |
| 'stage_gt': r['stage_label'], |
| 'disease_predicted': r['level1_disease']['predicted_class'], |
| 'disease_confidence': r['level1_disease']['confidence'], |
| 'disease_mean_activation': r['level1_disease']['activation_stats']['mean'], |
| 'severity_predicted': r['level2_severity']['predicted_class'], |
| 'severity_confidence': r['level2_severity']['confidence'], |
| 'severity_mean_activation': r['level2_severity']['activation_stats']['mean'], |
| } |
| |
| if r['include_stage_analysis']: |
| summary_row.update({ |
| 'stage_predicted': r['level3_stage']['predicted_idx'], |
| 'stage_confidence': r['level3_stage']['confidence'], |
| 'stage_mean_activation': r['level3_stage']['activation_stats']['mean'], |
| }) |
| else: |
| summary_row.update({ |
| 'stage_predicted': 'N/A', |
| 'stage_confidence': 'N/A', |
| 'stage_mean_activation': 'N/A', |
| }) |
| |
| summary_row['union_heads_averaged'] = r['union']['num_heads_averaged'] |
| summary_row['union_mean_activation'] = r['union']['activation_stats']['mean'] |
| |
| results_summary.append(summary_row) |
| |
| summary_df = pd.DataFrame(results_summary) |
| summary_path = os.path.join(GRADCAM_OUTPUT_DIR, 'tri_head_union_gradcam_plus_plus_summary.csv') |
| summary_df.to_csv(summary_path, index=False) |
| print(f"\n✅ Summary saved to: {summary_path}") |
| |
| print("\n" + "="*80) |
| print("TRI-HEAD GRAD-CAM++ ANALYSIS SUMMARY") |
| print("="*80) |
| print(summary_df.to_string(index=False)) |
| print("="*80) |
| |
| |
| print("\n" + "="*80) |
| print("📦 RESULTS STRUCTURE") |
| print("="*80) |
| print("Each result dictionary contains:") |
| print(" - filename, full_path, dataset_name, true_label, stage_label") |
| print(" - include_stage_analysis: boolean flag") |
| print(" - image: original image array") |
| print(" - level1_disease: {") |
| print(" predicted_class, predicted_idx, confidence, all_probabilities") |
| print(" heatmap_raw, heatmap_overlay, activation_stats") |
| print(" }") |
| print(" - level2_severity: {") |
| print(" predicted_class, predicted_idx, confidence, all_probabilities") |
| print(" heatmap_raw, heatmap_overlay, activation_stats") |
| print(" }") |
| print(" - level3_stage: {") |
| print(" predicted_idx, confidence, all_probabilities") |
| print(" heatmap_raw (None if not analyzed), heatmap_overlay (None if not analyzed)") |
| print(" activation_stats (None if not analyzed)") |
| print(" }") |
| print(" - union: {") |
| print(" heatmap_raw (average of 2 or 3 heads)") |
| print(" heatmap_overlay, num_heads_averaged, activation_stats") |
| print(" }") |
| print("="*80) |
| |
| print(f"\n✅ gradcam_results variable contains {len(gradcam_results)} dictionaries") |
| print(" Use gradcam_results in the next cell for further analysis!") |
| |
| else: |
| print("\n❌ No results generated") |
| else: |
| print("❌ collected_images not found!") |
| |
| except Exception as e: |
| print(f"❌ Error: {e}") |
| import traceback |
| traceback.print_exc() |
|
|
|
|
|
|
|
|
| import torch |
| import torch.nn.functional as F |
| import matplotlib.pyplot as plt |
| import numpy as np |
| from PIL import Image |
| import cv2 |
| from scipy.ndimage import zoom, gaussian_filter |
| import os |
|
|
| |
| |
| |
| OUTPUT_DIR="/kaggle/working" |
| ATTENTION_OUTPUT_DIR = os.path.join(OUTPUT_DIR, 'attention_results') |
| os.makedirs(ATTENTION_OUTPUT_DIR, exist_ok=True) |
|
|
| print("="*80) |
| print("ATTENTION VISUALIZATION CONFIGURATION") |
| print("="*80) |
| print(f"Output directory: {ATTENTION_OUTPUT_DIR}") |
| print("="*80) |
|
|
| |
| |
| |
| class AttentionExtractor(nn.Module): |
| """ |
| Wrapper to extract attention weights from the hierarchical model |
| """ |
| def __init__(self, phase3_model): |
| super().__init__() |
| self.phase3_model = phase3_model |
| self.attention_weights = None |
| self.tile_features = None |
| |
| def forward(self, tiles): |
| """ |
| Extract attention weights and tile features from the model |
| """ |
| |
| disease_logits, severity_logits, stage_logits, attention_weights = self.phase3_model(tiles) |
| |
| |
| self.attention_weights = attention_weights |
| |
| return disease_logits, severity_logits, stage_logits, attention_weights |
|
|
| |
| attention_extractor = AttentionExtractor(model).to(device) |
| attention_extractor.eval() |
|
|
| print("\n✅ Attention extractor created successfully\n") |
|
|
| |
| |
| |
| def extract_attention_map(model, preprocessed_image, device): |
| """ |
| Extract attention weights from the model for a single image |
| Since each image is a single tile, we get attention for that single representation |
| |
| Args: |
| model: AttentionExtractor model |
| preprocessed_image: Preprocessed tensor (single image, not tiled) |
| device: torch device |
| |
| Returns: |
| attention_weights: numpy array of attention weight (single value for single tile) |
| disease_logits, severity_logits, stage_logits: model outputs |
| """ |
| model.eval() |
| |
| try: |
| |
| |
| if preprocessed_image.dim() == 3: |
| |
| image_batch = preprocessed_image.unsqueeze(0) |
| else: |
| |
| image_batch = preprocessed_image |
| |
| |
| |
| tiles_list = [image_batch.to(device)] |
| |
| with torch.no_grad(): |
| disease_logits, severity_logits, stage_logits, attention_weights = model(tiles_list) |
| |
| |
| |
| attention_np = attention_weights.squeeze().cpu().numpy() |
| |
| |
| if attention_np.ndim == 0: |
| attention_np = np.array([attention_np.item()]) |
| |
| print(f" Extracted attention weights: shape={attention_np.shape}, value={attention_np}") |
| |
| return attention_np, disease_logits, severity_logits, stage_logits |
| |
| except Exception as e: |
| print(f" Error extracting attention: {e}") |
| import traceback |
| traceback.print_exc() |
| return None, None, None, None |
|
|
| def create_uniform_attention_heatmap(attention_weight, image_shape): |
| """ |
| Create a uniform attention heatmap for a single tile (entire image) |
| Since the whole image is one tile, the attention is uniform across it |
| |
| Args: |
| attention_weight: single attention weight value |
| image_shape: tuple (height, width) of image |
| |
| Returns: |
| heatmap: 2D array with uniform attention value |
| """ |
| |
| |
| heatmap = np.full(image_shape, attention_weight, dtype=np.float32) |
| |
| |
| if heatmap.max() > 0: |
| heatmap = heatmap / heatmap.max() |
| |
| return heatmap |
|
|
| def extract_patch_level_attention(model_backbone, preprocessed_image, device, patch_size=16): |
| """ |
| Extract patch-level attention from ViT backbone |
| ViT processes image as patches, we can visualize their importance |
| |
| Args: |
| model_backbone: ViT backbone model |
| preprocessed_image: Preprocessed tensor |
| device: torch device |
| patch_size: ViT patch size (default 16 for most ViTs) |
| |
| Returns: |
| patch_attention_map: 2D heatmap showing patch-level importance |
| """ |
| try: |
| if preprocessed_image.dim() == 3: |
| image_batch = preprocessed_image.unsqueeze(0).to(device) |
| else: |
| image_batch = preprocessed_image.to(device) |
| |
| with torch.no_grad(): |
| |
| outputs = model_backbone.vit(pixel_values=image_batch) |
| |
| |
| |
| hidden_states = outputs.last_hidden_state |
| |
| |
| patch_tokens = hidden_states[:, 1:, :] |
| |
| |
| patch_importance = torch.norm(patch_tokens, p=2, dim=2).squeeze().cpu().numpy() |
| |
| |
| |
| num_patches = len(patch_importance) |
| grid_size = int(np.sqrt(num_patches)) |
| |
| |
| attention_grid = patch_importance.reshape(grid_size, grid_size) |
| |
| print(f" Extracted patch-level attention: {grid_size}x{grid_size} patches") |
| |
| return attention_grid |
| |
| except Exception as e: |
| print(f" Error extracting patch attention: {e}") |
| import traceback |
| traceback.print_exc() |
| return None |
|
|
| def create_patch_attention_heatmap(patch_attention_grid, target_shape): |
| """ |
| Upsample patch-level attention to image dimensions |
| |
| Args: |
| patch_attention_grid: 2D grid of patch attention values |
| target_shape: tuple (height, width) for output |
| |
| Returns: |
| heatmap: upsampled attention heatmap |
| """ |
| |
| zoom_factors = (target_shape[0] / patch_attention_grid.shape[0], |
| target_shape[1] / patch_attention_grid.shape[1]) |
| |
| |
| heatmap = zoom(patch_attention_grid, zoom_factors, order=1) |
| |
| |
| heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8) |
| |
| |
| heatmap = gaussian_filter(heatmap, sigma=5) |
| |
| return heatmap |
|
|
| def load_image_for_attention(img_path, target_size=768): |
| """Load and resize image for attention visualization""" |
| try: |
| if img_path.lower().endswith(('.svs', '.ndpi')): |
| slide = openslide.OpenSlide(img_path) |
| img_pil = slide.get_thumbnail((target_size, target_size)) |
| slide.close() |
| elif img_path.lower().endswith('.tif'): |
| try: |
| slide = openslide.OpenSlide(img_path) |
| img_pil = slide.get_thumbnail((target_size, target_size)) |
| slide.close() |
| except: |
| img_pil = Image.open(img_path).convert('RGB') |
| img_pil = img_pil.resize((target_size, target_size), Image.BILINEAR) |
| else: |
| img_pil = Image.open(img_path).convert('RGB') |
| img_pil = img_pil.resize((target_size, target_size), Image.BILINEAR) |
| |
| img_array = np.array(img_pil) |
| return img_pil, img_array |
| |
| except Exception as e: |
| print(f"Error loading image {img_path}: {e}") |
| return None, None |
|
|
| |
| |
| |
| def run_attention_analysis(attention_model, device, collected_images, processed_images, |
| main_class_mapping, stage_class_mapping): |
| """ |
| Extract and visualize attention weights from the hierarchical model |
| Uses patch-level attention from ViT backbone since images are single tiles |
| |
| Args: |
| attention_model: AttentionExtractor model |
| device: torch device |
| collected_images: list of image paths |
| processed_images: list of preprocessed tensors (single images) |
| main_class_mapping: dictionary mapping class indices to names |
| stage_class_mapping: dictionary mapping stage indices to names |
| |
| Returns: |
| list of results dictionaries |
| """ |
| print("\n" + "="*80) |
| print("🎯 STARTING ATTENTION WEIGHT EXTRACTION AND VISUALIZATION") |
| print("="*80) |
| print("ℹ️ Note: Each image is treated as a single tile") |
| print("ℹ️ Using patch-level attention from ViT backbone for visualization") |
| |
| if not collected_images or not processed_images: |
| print("❌ No images or preprocessed data provided") |
| return [] |
| |
| print(f"✅ Processing {len(collected_images)} images\n") |
| |
| all_results = [] |
| |
| for idx, (img_path, preprocessed_image) in enumerate(zip(collected_images, processed_images)): |
| print(f"\n[{idx+1}/{len(collected_images)}] Processing: {os.path.basename(img_path)}") |
| print(f" Path: {img_path}") |
| print(f" Image shape: {preprocessed_image.shape}") |
| |
| try: |
| |
| img_pil, img_array = load_image_for_attention(img_path, target_size=768) |
| |
| if img_array is None: |
| print(f" ❌ Failed to load image") |
| continue |
| |
| |
| attention_weights, disease_logits, severity_logits, stage_logits = extract_attention_map( |
| attention_model, |
| preprocessed_image, |
| device |
| ) |
| |
| if attention_weights is None: |
| print(f" ❌ Failed to extract attention") |
| continue |
| |
| |
| patch_attention = extract_patch_level_attention( |
| attention_model.phase3_model.backbone, |
| preprocessed_image, |
| device |
| ) |
| |
| if patch_attention is not None: |
| |
| attention_heatmap = create_patch_attention_heatmap( |
| patch_attention, |
| img_array.shape[:2] |
| ) |
| print(f" ✅ Created patch-level attention heatmap") |
| else: |
| |
| attention_heatmap = create_uniform_attention_heatmap( |
| attention_weights[0], |
| img_array.shape[:2] |
| ) |
| print(f" ℹ️ Using uniform attention heatmap") |
| |
| |
| with torch.no_grad(): |
| disease_probs = F.softmax(disease_logits, dim=1) |
| disease_pred_idx = torch.argmax(disease_probs, dim=1).item() |
| disease_confidence = disease_probs[0, disease_pred_idx].item() |
| |
| predicted_class_name = main_class_mapping.get( |
| disease_pred_idx, |
| f"Unknown_Class_{disease_pred_idx}" |
| ) |
| |
| |
| if "_normal" in predicted_class_name: |
| predicted_disease = predicted_class_name.replace("_normal", "") |
| severity_label = "Normal" |
| elif "_abnormal" in predicted_class_name: |
| predicted_disease = predicted_class_name.replace("_abnormal", "") |
| severity_label = "Abnormal" |
| else: |
| predicted_disease = predicted_class_name |
| severity_label = "Unknown" |
| |
| print(f" 📊 Prediction: {predicted_class_name}") |
| print(f" Confidence: {disease_confidence:.4f}") |
| print(f" MIL Attention Weight: {attention_weights[0]:.4f}") |
| |
| |
| path_parts = img_path.split(os.sep) |
| dataset_name = "Unknown" |
| true_label = "unknown" |
| |
| for part in path_parts: |
| if "processed" in part or "cancer" in part.lower(): |
| dataset_name = part |
| if part in ["normal", "abnormal"]: |
| true_label = part |
| break |
| |
| |
| result = { |
| 'image': img_array, |
| 'attention_heatmap': attention_heatmap, |
| 'mil_attention_weight': attention_weights[0], |
| 'true_label': true_label, |
| 'dataset_name': dataset_name, |
| 'predicted_class': predicted_class_name, |
| 'predicted_disease': predicted_disease, |
| 'severity': severity_label, |
| 'class_idx': disease_pred_idx, |
| 'confidence': disease_confidence, |
| 'filename': os.path.basename(img_path), |
| 'full_path': img_path |
| } |
| |
| all_results.append(result) |
| |
| |
| heatmap_filename = f"attention_heatmap_{idx+1}_{os.path.splitext(os.path.basename(img_path))[0]}.npy" |
| heatmap_path = os.path.join(ATTENTION_OUTPUT_DIR, heatmap_filename) |
| np.save(heatmap_path, attention_heatmap) |
| |
| print(f" ✅ Attention heatmap saved to: {heatmap_filename}") |
| print(f" ✅ Completed analysis") |
| |
| except Exception as e: |
| print(f" ❌ Error: {e}") |
| import traceback |
| traceback.print_exc() |
| continue |
| |
| print("\n" + "="*80) |
| print("✅ ATTENTION EXTRACTION COMPLETE") |
| print(f"📁 Results saved to: {ATTENTION_OUTPUT_DIR}") |
| print("="*80) |
| |
| return all_results |
|
|
| |
| |
| |
| def display_attention_grid(results): |
| """ |
| Display grid with original images and attention heatmap overlays |
| """ |
| if not results: |
| print("No results to display") |
| return |
| |
| num_images = len(results) |
| |
| |
| fig, axes = plt.subplots(num_images, 3, figsize=(18, 6 * num_images)) |
| |
| |
| if num_images == 1: |
| axes = axes.reshape(1, -1) |
| |
| |
| cmap = plt.cm.jet |
| |
| for i, result in enumerate(results): |
| |
| axes[i, 0].imshow(result['image']) |
| title_text = ( |
| f"Original Image {i+1}\n" |
| f"Dataset: {result['dataset_name']}\n" |
| f"True Label: {result['true_label']}\n" |
| f"File: {result['filename'][:30]}..." |
| ) |
| axes[i, 0].set_title(title_text, fontsize=10, fontweight='bold', pad=10) |
| axes[i, 0].axis('off') |
| |
| |
| im = axes[i, 1].imshow(result['attention_heatmap'], cmap=cmap) |
| cbar = plt.colorbar(im, ax=axes[i, 1], fraction=0.046, pad=0.04) |
| cbar.set_label('Attention Weight', rotation=270, labelpad=15) |
| |
| heatmap_title = ( |
| f"Attention Heatmap {i+1}\n" |
| f"Patch-Level Importance\n" |
| f"MIL Weight: {result['mil_attention_weight']:.4f}" |
| ) |
| axes[i, 1].set_title(heatmap_title, fontsize=10, fontweight='bold', pad=10) |
| axes[i, 1].axis('off') |
| |
| |
| axes[i, 2].imshow(result['image']) |
| axes[i, 2].imshow(result['attention_heatmap'], cmap=cmap, alpha=0.5) |
| |
| overlay_title = ( |
| f"Overlay {i+1}\n" |
| f"Predicted: {result['predicted_class']}\n" |
| f"Confidence: {result['confidence']:.4f}" |
| ) |
| axes[i, 2].set_title(overlay_title, fontsize=10, fontweight='bold', pad=10) |
| axes[i, 2].axis('off') |
| |
| |
| high_attention = np.sum(result['attention_heatmap'] > 0.7) / result['attention_heatmap'].size * 100 |
| medium_attention = np.sum((result['attention_heatmap'] > 0.4) & |
| (result['attention_heatmap'] <= 0.7)) / result['attention_heatmap'].size * 100 |
| low_attention = np.sum(result['attention_heatmap'] <= 0.4) / result['attention_heatmap'].size * 100 |
| |
| print(f"\n📊 Image {i+1} ({result['filename']}) Attention Statistics:") |
| print(f" Predicted: {result['predicted_class']}") |
| print(f" Confidence: {result['confidence']:.4f}") |
| print(f" MIL Attention Weight: {result['mil_attention_weight']:.4f}") |
| print(f" High attention regions (>0.7): {high_attention:.1f}%") |
| print(f" Medium attention regions (0.4-0.7): {medium_attention:.1f}%") |
| print(f" Low attention regions (<0.4): {low_attention:.1f}%") |
| |
| plt.suptitle( |
| 'Hierarchical Model - Patch-Level Attention Visualization\n' |
| 'Warmer colors (red/yellow) indicate higher attention | Cooler colors (blue) indicate lower attention', |
| fontsize=16, |
| fontweight='bold', |
| y=0.998 |
| ) |
| |
| plt.tight_layout() |
| |
| |
| grid_save_path = os.path.join(ATTENTION_OUTPUT_DIR, 'attention_visualization_grid.png') |
| plt.savefig(grid_save_path, dpi=150, bbox_inches='tight') |
| print(f"\n✅ Grid visualization saved to: {grid_save_path}") |
| |
| plt.show() |
|
|
| |
| |
| |
| print("\n" + "="*80) |
| print("CHECKING FOR COLLECTED AND PROCESSED IMAGES") |
| print("="*80) |
|
|
| try: |
| |
| if 'collected_images' in locals() or 'collected_images' in globals(): |
| if 'processed_images' in locals() or 'processed_images' in globals(): |
| print(f"✅ Found collected_images: {len(collected_images)} images") |
| print(f"✅ Found processed_images: {len(processed_images)} tensors\n") |
| |
| |
| attention_results = run_attention_analysis( |
| attention_extractor, |
| device, |
| collected_images, |
| processed_images, |
| DISEASE_CLASS_MAPPING, |
| STAGE_CLASS_MAPPING |
| ) |
| |
| |
| if attention_results: |
| display_attention_grid(attention_results) |
| |
| |
| results_summary = [] |
| for r in attention_results: |
| high_attn = np.sum(r['attention_heatmap'] > 0.7) / r['attention_heatmap'].size * 100 |
| |
| results_summary.append({ |
| 'filename': r['filename'], |
| 'dataset': r['dataset_name'], |
| 'true_label': r['true_label'], |
| 'predicted_class': r['predicted_class'], |
| 'confidence': r['confidence'], |
| 'mil_attention_weight': r['mil_attention_weight'], |
| 'high_attention_area_%': high_attn, |
| 'max_attention': r['attention_heatmap'].max(), |
| 'mean_attention': r['attention_heatmap'].mean() |
| }) |
| |
| summary_df = pd.DataFrame(results_summary) |
| summary_path = os.path.join(ATTENTION_OUTPUT_DIR, 'attention_summary.csv') |
| summary_df.to_csv(summary_path, index=False) |
| print(f"\n✅ Summary saved to: {summary_path}") |
| |
| print("\n" + "="*80) |
| print("ATTENTION ANALYSIS SUMMARY") |
| print("="*80) |
| print(summary_df.to_string(index=False)) |
| print("="*80) |
| |
| |
| print("\n" + "="*80) |
| print("ATTENTION RESULTS READY FOR OPENCV PROCESSING") |
| print("="*80) |
| print(f"✅ Variable 'attention_results' contains {len(attention_results)} results") |
| print("✅ Each result includes:") |
| print(" - Original image") |
| print(" - Patch-level attention heatmap (smoothed)") |
| print(" - MIL attention weight") |
| print(" - Predictions and metadata") |
| print("\n💡 Use 'attention_results' in the next cell for OpenCV feature extraction") |
| print("="*80) |
| else: |
| print("\n❌ No results generated") |
| else: |
| print("❌ processed_images not found!") |
| print("Please run the preprocessing cell first") |
| else: |
| print("❌ collected_images not found!") |
| print("Please run the image collection cell first") |
| |
| except Exception as e: |
| print(f"❌ Error: {e}") |
| import traceback |
| traceback.print_exc() |
| import numpy as np |
| import cv2 |
| from sklearn.cluster import DBSCAN |
| import matplotlib.pyplot as plt |
| import os |
| from scipy.ndimage import maximum_filter |
| from skimage.feature import graycomatrix, graycoprops |
| from openai import OpenAI |
| import time |
| from scipy.stats import spearmanr |
|
|
| |
| |
| |
| |
| class HeatmapFeatureExtractor: |
| |
| def __init__(self, attention_result): |
| """ |
| attention_result is one entry from attention_results |
| """ |
| self.heatmap = attention_result['attention_heatmap'] |
| self.original_image = attention_result['image'] |
| self.prediction_info = { |
| 'predicted_class': attention_result['predicted_class'], |
| 'predicted_disease': attention_result['predicted_disease'], |
| 'severity': attention_result['severity'], |
| 'confidence': attention_result['confidence'], |
| 'class_idx': attention_result['class_idx'], |
| 'mil_attention': attention_result['mil_attention_weight'] |
| } |
| self.true_label = attention_result['true_label'] |
| self.dataset_name = attention_result['dataset_name'] |
| self.filename = attention_result['filename'] |
| |
| |
| |
| |
| def get_brightest_region(self): |
| """ |
| IMPROVED: Comprehensive analysis of high-attention regions |
| """ |
| heatmap = self.heatmap.astype(float) |
| H, W = heatmap.shape |
| |
| heatmap_norm = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8) |
| |
| |
| brightest_idx = np.unravel_index(np.argmax(heatmap_norm), heatmap_norm.shape) |
| y_bright, x_bright = brightest_idx |
| intensity_bright = heatmap_norm[y_bright, x_bright] |
| |
| position_bright = self._get_anatomical_position(y_bright, x_bright, H, W) |
| |
| primary_hotspot = { |
| "pixel": (int(y_bright), int(x_bright)), |
| "position": position_bright, |
| "intensity": float(intensity_bright) |
| } |
| |
| |
| secondary_hotspots = self._find_secondary_hotspots(heatmap_norm, H, W, threshold=0.6) |
| |
| |
| attention_pattern = self._determine_attention_pattern(heatmap_norm, H, W) |
| |
| |
| spatial_coverage = self._calculate_spatial_coverage(heatmap_norm, H, W) |
| |
| |
| hotspot_count = 1 + len(secondary_hotspots) |
| |
| return { |
| "primary_hotspot": primary_hotspot, |
| "secondary_hotspots": secondary_hotspots, |
| "attention_pattern": attention_pattern, |
| "spatial_coverage": spatial_coverage, |
| "hotspot_count": hotspot_count |
| } |
|
|
| def _get_anatomical_position(self, y, x, H, W): |
| """Convert pixel coordinates to descriptive position""" |
| y_rel = y / H |
| x_rel = x / W |
| |
| center_y, center_x = H / 2, W / 2 |
| dist_from_center = np.sqrt((y - center_y)**2 + (x - center_x)**2) |
| max_dist = np.sqrt((H/2)**2 + (W/2)**2) |
| dist_ratio = dist_from_center / max_dist |
| |
| center_threshold_inner = 0.35 |
| center_threshold_outer = 0.65 |
| periphery_threshold = 0.75 |
| |
| if y_rel < center_threshold_inner: |
| vert = "upper" |
| elif y_rel > center_threshold_outer: |
| vert = "lower" |
| else: |
| vert = "mid" |
| |
| if x_rel < center_threshold_inner: |
| horiz = "left" |
| elif x_rel > center_threshold_outer: |
| horiz = "right" |
| else: |
| horiz = "center" |
| |
| if horiz == "center" and vert == "mid": |
| position = "center" |
| elif horiz == "center": |
| position = f"{vert}-center" |
| elif vert == "mid": |
| position = f"{horiz}-center" |
| else: |
| position = f"{vert}-{horiz}" |
| |
| if dist_ratio > periphery_threshold: |
| position = f"{position} (periphery)" |
| elif dist_ratio < 0.3: |
| position = f"{position} (core)" |
| |
| return position |
| |
| def _find_secondary_hotspots(self, heatmap_norm, H, W, threshold=0.6, min_distance=20): |
| """Find additional significant attention regions""" |
| secondary = [] |
| |
| neighborhood_size = max(10, min(H, W) // 20) |
| local_max = maximum_filter(heatmap_norm, size=neighborhood_size) |
| |
| peaks = (heatmap_norm == local_max) & (heatmap_norm > threshold * heatmap_norm.max()) |
| peak_coords = np.argwhere(peaks) |
| peak_intensities = heatmap_norm[peaks] |
| sorted_indices = np.argsort(peak_intensities)[::-1] |
| |
| primary_y, primary_x = np.unravel_index(np.argmax(heatmap_norm), heatmap_norm.shape) |
| |
| for idx in sorted_indices[:5]: |
| y, x = peak_coords[idx] |
| |
| if np.sqrt((y - primary_y)**2 + (x - primary_x)**2) < min_distance: |
| continue |
| |
| too_close = False |
| for existing in secondary: |
| ey, ex = existing["pixel"] |
| if np.sqrt((y - ey)**2 + (x - ex)**2) < min_distance: |
| too_close = True |
| break |
| |
| if too_close: |
| continue |
| |
| position = self._get_anatomical_position(y, x, H, W) |
| intensity = float(heatmap_norm[y, x]) |
| |
| secondary.append({ |
| "pixel": (int(y), int(x)), |
| "position": position, |
| "intensity": intensity |
| }) |
| |
| return secondary |
|
|
| def _determine_attention_pattern(self, heatmap_norm, H, W): |
| """Determine overall attention distribution pattern""" |
| center_y, center_x = H // 2, W // 2 |
| |
| Y, X = np.ogrid[:H, :W] |
| dist_from_center = np.sqrt((Y - center_y)**2 + (X - center_x)**2) |
| max_dist = np.sqrt((H/2)**2 + (W/2)**2) |
| |
| core_mask = dist_from_center < (max_dist * 0.3) |
| mid_mask = (dist_from_center >= max_dist * 0.3) & (dist_from_center < max_dist * 0.7) |
| periphery_mask = dist_from_center >= (max_dist * 0.7) |
| |
| core_attention = np.mean(heatmap_norm[core_mask]) |
| mid_attention = np.mean(heatmap_norm[mid_mask]) |
| periphery_attention = np.mean(heatmap_norm[periphery_mask]) |
| |
| high_attention_pixels = np.sum(heatmap_norm > 0.7) / heatmap_norm.size |
| |
| if core_attention > 0.7 and core_attention > mid_attention * 1.5: |
| return "centralized (focused on center)" |
| elif periphery_attention > 0.7 and periphery_attention > core_attention * 1.5: |
| return "peripheral (focused on edges)" |
| elif mid_attention > core_attention and mid_attention > periphery_attention: |
| return "ring-like (donut pattern)" |
| elif high_attention_pixels > 0.5: |
| return "diffuse (spread across image)" |
| elif high_attention_pixels > 0.1 and high_attention_pixels < 0.3: |
| return "focal (single concentrated region)" |
| else: |
| return "scattered (multiple regions)" |
|
|
| def _calculate_spatial_coverage(self, heatmap_norm, H, W): |
| """Calculate percentage of attention in each spatial region""" |
| center_y, center_x = H // 2, W // 2 |
| |
| Y, X = np.ogrid[:H, :W] |
| dist_from_center = np.sqrt((Y - center_y)**2 + (X - center_x)**2) |
| max_dist = np.sqrt((H/2)**2 + (W/2)**2) |
| |
| core_mask = dist_from_center < (max_dist * 0.3) |
| mid_mask = (dist_from_center >= max_dist * 0.3) & (dist_from_center < max_dist * 0.7) |
| periphery_mask = dist_from_center >= (max_dist * 0.7) |
| |
| total_attention = np.sum(heatmap_norm) |
| core_sum = np.sum(heatmap_norm[core_mask]) |
| mid_sum = np.sum(heatmap_norm[mid_mask]) |
| periphery_sum = np.sum(heatmap_norm[periphery_mask]) |
| |
| return { |
| "center_attention": float(core_sum / total_attention * 100) if total_attention > 0 else 0, |
| "mid_region_attention": float(mid_sum / total_attention * 100) if total_attention > 0 else 0, |
| "periphery_attention": float(periphery_sum / total_attention * 100) if total_attention > 0 else 0 |
| } |
|
|
| |
| |
| |
| def get_activation_scatter(self, threshold_ratio=0.6): |
| """Determine if heatmap is focused or scattered""" |
| heatmap = self.heatmap.astype(float) |
| H, W = heatmap.shape |
|
|
| heatmap_norm = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-6) |
| thresh = threshold_ratio * heatmap_norm.max() |
| high_pixels = np.argwhere(heatmap_norm >= thresh) |
|
|
| if len(high_pixels) == 0: |
| return { |
| "scatter_level": "low", |
| "num_clusters": 0, |
| "clusters_sizes": [] |
| } |
|
|
| clustering = DBSCAN(eps=8, min_samples=20).fit(high_pixels) |
| labels = clustering.labels_ |
| unique_labels = [lb for lb in np.unique(labels) if lb != -1] |
|
|
| cluster_sizes = [] |
| for lb in unique_labels: |
| cluster_sizes.append(int(np.sum(labels == lb))) |
|
|
| num_clusters = len(unique_labels) |
|
|
| if num_clusters == 1: |
| scatter = "low" |
| elif 2 <= num_clusters <= 3: |
| scatter = "medium" |
| else: |
| scatter = "high" |
|
|
| return { |
| "scatter_level": scatter, |
| "num_clusters": num_clusters, |
| "clusters_sizes": cluster_sizes |
| } |
| |
| |
| |
| |
| def get_dominant_focus_color(self, threshold_ratio=0.6, k_clusters=5): |
| """Detect dominant color in attention-focused regions""" |
| heatmap = self.heatmap.astype(float) |
| orig = self.original_image.copy() |
| |
| H, W = heatmap.shape |
| heatmap_norm = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-6) |
| mask = (heatmap_norm >= threshold_ratio).astype(np.uint8) * 255 |
| |
| if np.sum(mask) == 0: |
| return { |
| "dominant_color_rgb": None, |
| "dominant_color_hsv": None, |
| "dominant_color_name": "none", |
| "color_confidence": 0.0 |
| } |
| |
| kernel = np.ones((3, 3), np.uint8) |
| mask_clean = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel, iterations=2) |
| contours, _ = cv2.findContours(mask_clean, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) |
|
|
| if len(contours) == 0: |
| return { |
| "dominant_color_rgb": None, |
| "dominant_color_hsv": None, |
| "dominant_color_name": "none", |
| "color_confidence": 0.0 |
| } |
| |
| activation_mask = np.zeros_like(mask_clean) |
| cv2.drawContours(activation_mask, contours, -1, 255, -1) |
| focus_pixels = orig[activation_mask == 255] |
| |
| if len(focus_pixels) < 10: |
| return { |
| "dominant_color_rgb": None, |
| "dominant_color_hsv": None, |
| "dominant_color_name": "none", |
| "color_confidence": 0.0 |
| } |
|
|
| Z = np.float32(focus_pixels) |
| criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 20, 1.0) |
| K = k_clusters |
| |
| _, labels, centers = cv2.kmeans( |
| Z, K, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS |
| ) |
| |
| counts = np.bincount(labels.flatten()) |
| sorted_indices = np.argsort(counts)[::-1] |
| |
| dominant_index = None |
| dominant_color = None |
| |
| for idx in sorted_indices: |
| candidate_color = centers[idx].astype(int) |
| hsv_color = cv2.cvtColor( |
| np.uint8([[candidate_color]]), |
| cv2.COLOR_RGB2HSV |
| )[0][0] |
| |
| if hsv_color[1] > 30: |
| dominant_index = idx |
| dominant_color = candidate_color |
| break |
| |
| if dominant_index is None: |
| dominant_index = sorted_indices[0] |
| dominant_color = centers[dominant_index].astype(int) |
| |
| hsv_color = cv2.cvtColor( |
| np.uint8([[dominant_color]]), |
| cv2.COLOR_RGB2HSV |
| )[0][0] |
| |
| color_confidence = counts[dominant_index] / len(labels) * 100 |
| dominant_name = self._map_color_to_name( |
| dominant_color.tolist(), |
| hsv_color.tolist() |
| ) |
| |
| return { |
| "dominant_color_rgb": dominant_color.tolist(), |
| "dominant_color_hsv": hsv_color.tolist(), |
| "dominant_color_name": dominant_name, |
| "color_confidence": float(color_confidence) |
| } |
|
|
| def _map_color_to_name(self, rgb, hsv=None): |
| """Enhanced color naming using HSV color space""" |
| r, g, b = rgb |
| |
| if hsv is None: |
| hsv_array = cv2.cvtColor(np.uint8([[rgb]]), cv2.COLOR_RGB2HSV)[0][0] |
| h, s, v = hsv_array.tolist() |
| else: |
| h, s, v = hsv |
| |
| if s < 30: |
| if v > 200: |
| return "white / very light" |
| elif v > 150: |
| return "light gray / pale" |
| elif v > 80: |
| return "gray" |
| else: |
| return "dark gray / black" |
| |
| if v < 60: |
| return "very dark / black" |
| |
| if 130 <= h <= 160: |
| if s > 100: |
| return "purple / violet" |
| else: |
| return "light purple / lavender" |
| |
| if 160 <= h <= 180 or h <= 10: |
| if v > 180 and s < 100: |
| return "pink / light red" |
| elif s > 150: |
| return "magenta / bright pink" |
| else: |
| return "pink / rose" |
| |
| if h <= 10: |
| if v < 150: |
| return "dark red / maroon" |
| else: |
| return "red / crimson" |
| |
| if 10 <= h < 25: |
| if v < 130: |
| return "brown / dark tan" |
| else: |
| return "orange / tan" |
| |
| if 25 <= h < 40: |
| if s < 80: |
| return "beige / cream" |
| else: |
| return "yellow / golden" |
| |
| if 40 <= h < 80: |
| if v > 180: |
| return "light green / pale green" |
| elif v > 120: |
| return "green" |
| else: |
| return "dark green" |
| |
| if 80 <= h < 100: |
| return "cyan / turquoise" |
| |
| if 100 <= h < 130: |
| if s > 150: |
| return "blue / deep blue" |
| elif v > 180: |
| return "light blue / sky blue" |
| else: |
| return "blue" |
| |
| max_channel = max(r, g, b) |
| if max_channel == r: |
| return "reddish tones" |
| elif max_channel == g: |
| return "greenish tones" |
| elif max_channel == b: |
| return "bluish tones" |
| else: |
| return "mixed color region" |
|
|
| |
| |
| |
| def get_texture_analysis(self, threshold_ratio=0.6): |
| """ |
| Analyze texture patterns in high-attention regions using GLCM |
| |
| Returns generic image-based descriptions without medical assumptions |
| """ |
| heatmap = self.heatmap.astype(float) |
| orig = self.original_image.copy() |
| |
| H, W = heatmap.shape |
| heatmap_norm = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-6) |
| mask = (heatmap_norm >= threshold_ratio).astype(np.uint8) * 255 |
| |
| if np.sum(mask) < 100: |
| return { |
| "texture_classification": "insufficient data", |
| "texture_description": "Not enough attention data to analyze texture", |
| "texture_scores": { |
| "uniformity": 0, |
| "organization": 0, |
| "complexity": 0, |
| "smoothness": 0 |
| }, |
| "glcm_features": { |
| "contrast": 0.0, |
| "correlation": 0.0, |
| "energy": 0.0, |
| "homogeneity": 0.0 |
| } |
| } |
| |
| |
| kernel = np.ones((3, 3), np.uint8) |
| mask_clean = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel, iterations=2) |
| |
| |
| contours, _ = cv2.findContours(mask_clean, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) |
| |
| if len(contours) == 0: |
| return { |
| "texture_classification": "no region detected", |
| "texture_description": "No focused region detected", |
| "texture_scores": { |
| "uniformity": 0, |
| "organization": 0, |
| "complexity": 0, |
| "smoothness": 0 |
| }, |
| "glcm_features": { |
| "contrast": 0.0, |
| "correlation": 0.0, |
| "energy": 0.0, |
| "homogeneity": 0.0 |
| } |
| } |
| |
| |
| activation_mask = np.zeros_like(mask_clean) |
| cv2.drawContours(activation_mask, contours, -1, 255, -1) |
| |
| |
| y_coords, x_coords = np.where(activation_mask == 255) |
| if len(y_coords) == 0: |
| return { |
| "texture_classification": "invalid region", |
| "texture_description": "Invalid region for texture analysis", |
| "texture_scores": { |
| "uniformity": 0, |
| "organization": 0, |
| "complexity": 0, |
| "smoothness": 0 |
| }, |
| "glcm_features": { |
| "contrast": 0.0, |
| "correlation": 0.0, |
| "energy": 0.0, |
| "homogeneity": 0.0 |
| } |
| } |
| |
| y_min, y_max = y_coords.min(), y_coords.max() |
| x_min, x_max = x_coords.min(), x_coords.max() |
| |
| |
| region_rgb = orig[y_min:y_max+1, x_min:x_max+1] |
| region_mask = activation_mask[y_min:y_max+1, x_min:x_max+1] |
| |
| |
| region_gray = cv2.cvtColor(region_rgb, cv2.COLOR_RGB2GRAY) |
| |
| |
| region_gray_masked = region_gray.copy() |
| region_gray_masked[region_mask == 0] = 0 |
| |
| |
| region_quantized = (region_gray_masked / 4).astype(np.uint8) |
| |
| |
| |
| |
| distances = [1] |
| angles = [0, np.pi/4, np.pi/2, 3*np.pi/4] |
| |
| try: |
| glcm = graycomatrix( |
| region_quantized, |
| distances=distances, |
| angles=angles, |
| levels=64, |
| symmetric=True, |
| normed=True |
| ) |
| |
| |
| contrast = float(graycoprops(glcm, 'contrast')[0].mean()) |
| correlation = float(graycoprops(glcm, 'correlation')[0].mean()) |
| energy = float(graycoprops(glcm, 'energy')[0].mean()) |
| homogeneity = float(graycoprops(glcm, 'homogeneity')[0].mean()) |
| |
| except Exception as e: |
| print(f" Warning: GLCM computation failed: {e}") |
| return { |
| "texture_classification": "computation error", |
| "texture_description": "Error computing texture features", |
| "texture_scores": { |
| "uniformity": 0, |
| "organization": 0, |
| "complexity": 0, |
| "smoothness": 0 |
| }, |
| "glcm_features": { |
| "contrast": 0.0, |
| "correlation": 0.0, |
| "energy": 0.0, |
| "homogeneity": 0.0 |
| } |
| } |
| |
| |
| uniformity_score = int(energy * 100) |
| organization_score = int(max(0, min(100, (correlation + 1) * 50))) |
| complexity_score = int((1 - energy) * 100) |
| smoothness_score = int(homogeneity * 100) |
| |
| |
| classification, description = self._classify_texture( |
| contrast, correlation, energy, homogeneity |
| ) |
| |
| return { |
| "texture_classification": classification, |
| "texture_description": description, |
| "texture_scores": { |
| "uniformity": uniformity_score, |
| "organization": organization_score, |
| "complexity": complexity_score, |
| "smoothness": smoothness_score |
| }, |
| "glcm_features": { |
| "contrast": round(contrast, 2), |
| "correlation": round(correlation, 3), |
| "energy": round(energy, 3), |
| "homogeneity": round(homogeneity, 3) |
| } |
| } |
| |
| def _classify_texture(self, contrast, correlation, energy, homogeneity): |
| """ |
| Classify texture based on GLCM features |
| Returns (classification, description) tuple |
| """ |
| |
| if contrast < 100 and homogeneity > 0.8: |
| classification = "uniform and smooth" |
| description = "Model focused on a region with smooth, uniform texture showing consistent patterns with minimal variation" |
| |
| elif correlation > 0.7 and energy > 0.3: |
| classification = "structured and regular" |
| description = "Model focused on a region with organized, structured patterns exhibiting regular, repeating elements" |
| |
| elif contrast > 400 and correlation < 0.4: |
| classification = "irregular and chaotic" |
| description = "Model focused on a region with irregular, chaotic texture displaying highly variable patterns with no clear organization" |
| |
| elif homogeneity < 0.5: |
| classification = "rough and coarse" |
| description = "Model focused on a region with rough, coarse texture showing sharp intensity changes and abrupt transitions" |
| |
| elif energy < 0.2: |
| classification = "complex and varied" |
| description = "Model focused on a region with complex, varied texture containing multiple different patterns and high visual diversity" |
| |
| else: |
| classification = "moderate texture" |
| description = "Model focused on a region with moderate texture complexity showing intermediate characteristics" |
| |
| return classification, description |
| import warnings |
| warnings.filterwarnings("ignore") |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import Dataset, DataLoader |
| import torchvision.transforms as transforms |
| from PIL import Image |
| import os |
| import numpy as np |
| import pandas as pd |
| from transformers import Dinov2Model |
| import openslide |
| from tqdm import tqdm |
| import logging |
| import matplotlib |
| import matplotlib.pyplot as plt |
| import matplotlib.patches as mpatches |
| from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec |
| from openai import OpenAI |
| import time |
|
|
| logging.basicConfig(level=logging.INFO) |
|
|
| MODEL_PATH = "/kaggle/input/models/ulimaank/updated-diagnostic-model-jan-18/other/default/1/phase3_mil_best.pth" |
| OUTPUT_DIR = "/kaggle/working" |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| print(f"Device: {device}, GPUs: {torch.cuda.device_count()}") |
|
|
| DISEASE_NAMES = [ |
| 'Breast_cancer', 'annrbc-anemia_processed', 'colon_processed', |
| 'leukemia_processed', 'lung_processed', 'oral-cancer_processed', |
| 'ovarian-cancer_processed', 'sickle-cell-new_processed', 'thalassemia_processed' |
| ] |
|
|
| STAGE_NAMES = { |
| 0: 'Breast_cancer - ductal_carcinoma', |
| 1: 'Breast_cancer - lobular_carcinoma', |
| 2: 'Breast_cancer - mucinous_carcinoma', |
| 3: 'Breast_cancer - papillary_carcinoma', |
| 4: 'leukemia_processed - Early', |
| 5: 'leukemia_processed - Pre', |
| 6: 'leukemia_processed - Pro', |
| 7: 'lung_processed - lung_aca', |
| 8: 'lung_processed - lung_scc', |
| 9: 'ovarian-cancer_processed - CC', |
| 10: 'ovarian-cancer_processed - EC', |
| 11: 'ovarian-cancer_processed - HGSC', |
| 12: 'ovarian-cancer_processed - LGSC', |
| 13: 'ovarian-cancer_processed - MC' |
| } |
|
|
| DISEASE_CLASS_MAPPING = { |
| 0: "Breast_cancer", |
| 1: "annrbc-anemia_processed", |
| 2: "colon_processed", |
| 3: "leukemia_processed", |
| 4: "lung_processed", |
| 5: "oral-cancer_processed", |
| 6: "ovarian-cancer_processed", |
| 7: "sickle-cell-new_processed", |
| 8: "thalassemia_processed", |
| } |
|
|
| STAGE_CLASS_MAPPING = STAGE_NAMES |
|
|
| TARGET_SIZE = 256 |
| standardize_transform = transforms.Resize((TARGET_SIZE, TARGET_SIZE)) |
|
|
|
|
| |
| |
| |
|
|
| def collect_images_from_folder(folder_path): |
| images = [] |
| valid_extensions = ('.svs', '.tif', '.ndpi', '.png', '.jpg', '.jpeg', '.tiff') |
| for root, dirs, files in os.walk(folder_path): |
| for f in files: |
| if f.lower().endswith(valid_extensions): |
| images.append(os.path.join(root, f)) |
| return images |
|
|
|
|
| |
| |
| |
|
|
| class SimpleSlideDataset(Dataset): |
| def __init__(self, image_paths, tile_size=224, max_tiles=1000): |
| self.image_paths = image_paths |
| self.tile_size = tile_size |
| self.max_tiles = max_tiles |
|
|
| def __len__(self): |
| return len(self.image_paths) |
|
|
| def __getitem__(self, idx): |
| slide_path = self.image_paths[idx] |
| tiles = [] |
| try: |
| if slide_path.lower().endswith(('.svs', '.ndpi')): |
| slide = openslide.OpenSlide(slide_path) |
| width, height = slide.dimensions |
| for y in range(0, height, self.tile_size): |
| for x in range(0, width, self.tile_size): |
| if len(tiles) >= self.max_tiles: |
| break |
| tile = slide.read_region((x, y), 0, (self.tile_size, self.tile_size)).convert('RGB') |
| tiles.append(standardize_transform(tile)) |
| if len(tiles) >= self.max_tiles: |
| break |
| slide.close() |
| elif slide_path.lower().endswith('.tif'): |
| try: |
| slide = openslide.OpenSlide(slide_path) |
| width, height = slide.dimensions |
| for y in range(0, height, self.tile_size): |
| for x in range(0, width, self.tile_size): |
| if len(tiles) >= self.max_tiles: |
| break |
| tile = slide.read_region((x, y), 0, (self.tile_size, self.tile_size)).convert('RGB') |
| tiles.append(standardize_transform(tile)) |
| if len(tiles) >= self.max_tiles: |
| break |
| slide.close() |
| except openslide.OpenSlideError: |
| tiles = [standardize_transform(Image.open(slide_path).convert('RGB'))] |
| except Exception: |
| tiles = [standardize_transform(Image.open(slide_path).convert('RGB'))] |
| else: |
| tiles = [standardize_transform(Image.open(slide_path).convert('RGB'))] |
|
|
| if not tiles: |
| raise ValueError("No tiles extracted") |
| return tiles, slide_path |
| except Exception as e: |
| logging.error(f"Error processing slide {slide_path}: {e}") |
| return [], slide_path |
|
|
|
|
| test_transform = transforms.Compose([ |
| transforms.Resize(224), |
| transforms.ToTensor(), |
| transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) |
| ]) |
|
|
|
|
| def simple_collate(batch): |
| valid_batch = [item for item in batch if item[0]] |
| if not valid_batch: |
| return [], [] |
| tiles_list, paths = zip(*valid_batch) |
| processed_tiles = [torch.stack([test_transform(tile) for tile in tiles]) for tiles in tiles_list] |
| return processed_tiles, list(paths) |
|
|
|
|
| |
| |
| |
|
|
| class ViTBackbone(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.vit = Dinov2Model.from_pretrained("owkin/phikon-v2") |
|
|
| def forward(self, x): |
| return self.vit(pixel_values=x).last_hidden_state[:, 0] |
|
|
|
|
| class ClassificationHead(nn.Module): |
| def __init__(self, in_dim=1024, num_classes=2, hidden_dim=512): |
| super().__init__() |
| self.classifier = nn.Sequential( |
| nn.Linear(in_dim, hidden_dim), |
| nn.ReLU(), |
| nn.Dropout(0.3), |
| nn.Linear(hidden_dim, num_classes) |
| ) |
|
|
| def forward(self, x): |
| return self.classifier(x) |
|
|
|
|
| class HierarchicalMILAggregator(nn.Module): |
| def __init__(self, embed_dim=1024, num_heads=8, num_layers=2, |
| num_diseases=6, num_stage_classes=0, disease_names=None): |
| super().__init__() |
| self.pre_norm = nn.LayerNorm(embed_dim) |
| encoder_layer = nn.TransformerEncoderLayer( |
| d_model=embed_dim, nhead=num_heads, batch_first=True, dropout=0.1) |
| self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) |
| self.attention = nn.Sequential(nn.Linear(embed_dim, 256), nn.Tanh(), nn.Linear(256, 1)) |
| self.disease_head = ClassificationHead(embed_dim, num_diseases) |
| self.severity_heads = nn.ModuleDict() |
| for name in disease_names: |
| self.severity_heads[name] = ClassificationHead(embed_dim, 2) |
| self.stage_head = ClassificationHead(embed_dim, num_stage_classes) if num_stage_classes > 0 else None |
| self.disease_name_to_idx = {n: i for i, n in enumerate(disease_names)} |
| self.idx_to_disease_name = {i: n for n, i in self.disease_name_to_idx.items()} |
| self.disease_names = disease_names |
|
|
| def forward(self, tile_features): |
| normalized = self.pre_norm(tile_features) |
| aggregated = self.transformer(normalized) |
| attn_scores = self.attention(aggregated) |
| attn_weights = torch.softmax(attn_scores.squeeze(-1), dim=1) |
| weighted = torch.sum(aggregated * attn_weights.unsqueeze(-1), dim=1) |
| disease_logits = self.disease_head(weighted) |
| severity_logits = {n: self.severity_heads[n](weighted) for n in self.disease_names} |
| stage_logits = self.stage_head(weighted) if self.stage_head is not None else None |
| return disease_logits, severity_logits, stage_logits, attn_weights |
|
|
|
|
| class Phase3Model(nn.Module): |
| def __init__(self, backbone, num_diseases=6, num_stage_classes=0, disease_names=None): |
| super().__init__() |
| self.backbone = backbone |
| for param in self.backbone.parameters(): |
| param.requires_grad = False |
| self.aggregator = HierarchicalMILAggregator( |
| num_diseases=num_diseases, |
| num_stage_classes=num_stage_classes, |
| disease_names=disease_names |
| ) |
|
|
| def forward(self, tiles, enable_gradients=False): |
| all_features = [] |
| for batch_tiles in tiles: |
| if batch_tiles.numel() == 0: |
| continue |
| batch_tiles = batch_tiles.to(next(self.backbone.parameters()).device) |
| if enable_gradients: |
| batch_features = self.backbone(batch_tiles) |
| else: |
| with torch.no_grad(): |
| batch_features = self.backbone(batch_tiles) |
| all_features.append(batch_features) |
| if not all_features: |
| raise ValueError("No valid tile features could be extracted.") |
| all_features = torch.stack(all_features) |
| return self.aggregator(all_features) |
|
|
|
|
| |
| |
| |
|
|
| def predict_image(model, tiles, disease_names, stage_names): |
| model.eval() |
| try: |
| with torch.no_grad(): |
| disease_logits, severity_logits, stage_logits, _ = model(tiles) |
| disease_probs = F.softmax(disease_logits, dim=1) |
| disease_pred_idx = torch.argmax(disease_probs, dim=1).item() |
| disease_confidence = disease_probs[0, disease_pred_idx].item() |
| predicted_disease = disease_names[disease_pred_idx] |
|
|
| severity_probs = F.softmax(severity_logits[predicted_disease], dim=1) |
| severity_pred = torch.argmax(severity_probs, dim=1).item() |
| severity_confidence = severity_probs[0, severity_pred].item() |
| severity_label = "Normal" if severity_pred == 0 else "Abnormal" |
|
|
| stage_label = "N/A" |
| stage_confidence = 0.0 |
| if severity_pred == 1 and stage_logits is not None: |
| stage_probs = F.softmax(stage_logits, dim=1) |
| stage_pred_idx = torch.argmax(stage_probs, dim=1).item() |
| stage_confidence = stage_probs[0, stage_pred_idx].item() |
| stage_label = stage_names.get(stage_pred_idx, f"Stage_{stage_pred_idx}") |
|
|
| return { |
| 'disease': predicted_disease, |
| 'disease_confidence': disease_confidence, |
| 'severity': severity_label, |
| 'severity_confidence': severity_confidence, |
| 'stage': stage_label, |
| 'stage_confidence': stage_confidence |
| } |
| except Exception as e: |
| logging.error(f"Error during prediction: {e}") |
| return None |
|
|
|
|
| |
| |
| |
|
|
| def generate_comprehensive_explanation(comprehensive_data): |
| """ |
| Calls GPT-4o-mini to convert technical XAI metrics into a human-friendly |
| explanation. Falls back to a template string if the API call fails. |
| Only references Attention and GradCAM heatmaps. |
| """ |
| try: |
| client = OpenAI(api_key=os.getenv('OPENAI_API_KEY')) |
|
|
| prompt = f"""You are an AI explainability assistant helping users understand how a hierarchical medical image classification model made its decision. Convert the following technical analysis into a clear, accessible explanation. |
| |
| HIERARCHICAL MODEL PREDICTION: |
| - Region: {comprehensive_data['predicted_disease']} ({comprehensive_data['gradcam_disease_conf']:.1%} confidence) |
| - Status Level: {comprehensive_data['gradcam_severity']} ({comprehensive_data['gradcam_severity_conf']:.1%} confidence) |
| - Stage Level: {comprehensive_data['predicted_stage']} ({comprehensive_data['stage_confidence']:.1%} confidence) |
| |
| GRADCAM ANALYSIS (Gradient-weighted Class Activation Mapping): |
| - Note: Bright/warm regions in GradCAM indicate areas that most strongly influenced the model's prediction |
| |
| SPATIAL ATTENTION PATTERN AND VISUAL CHARACTERISTICS (from Attention Heatmap): |
| - Primary Focus: {comprehensive_data['primary_position']} (intensity: {comprehensive_data['primary_intensity']:.2f}) |
| - Attention Hotspots: {comprehensive_data['hotspot_count']} |
| - Spatial Distribution: Center {comprehensive_data['center_attention']:.1f}%, Mid-region {comprehensive_data['mid_attention']:.1f}%, Periphery {comprehensive_data['periphery_attention']:.1f}% |
| - Clustering: {comprehensive_data['scatter_level']} scatter level with {comprehensive_data['num_clusters']} clusters |
| - Dominant Color: {comprehensive_data['dominant_color']} ({comprehensive_data['color_confidence']:.1f}% confidence) |
| - Texture Pattern: {comprehensive_data['texture_classification']} |
| - Texture Scores: Uniformity {comprehensive_data['uniformity']}/100, Organization {comprehensive_data['organization']}/100, Complexity {comprehensive_data['complexity']}/100, Smoothness {comprehensive_data['smoothness']}/100 |
| |
| CRITICAL INSTRUCTIONS: |
| 1. Write in clear, accessible language for someone without medical or technical expertise |
| 2. Ground ALL statements in the provided data - do NOT add medical interpretations or diagnoses |
| 3. Explain how the two explainability methods (Attention Heatmap and GradCAM) show WHERE the model focused |
| 4. Describe WHAT visual patterns were detected, not WHY medically |
| 5. Keep it concise but informative (under 100 words) |
| 6. Structure with clear sections |
| 7. Make it conversational but professional |
| 8. Visual Characteristics and Spatial Attention Pattern were taken from Attention Heatmap |
| |
| Generate a comprehensive explanation covering: what the model decided, where it looked, what the Attention and GradCAM methods revealed, what visual characteristics were important, and how confident we can be in the decision. |
| |
| Format as natural paragraphs, not bullet points.""" |
|
|
| response = client.chat.completions.create( |
| model="gpt-4o-mini", |
| messages=[ |
| {"role": "system", "content": "You are an expert at explaining complex AI model decisions in simple, clear language. You help users understand model behavior without making medical claims."}, |
| {"role": "user", "content": prompt} |
| ], |
| temperature=0.7, |
| max_tokens=700 |
| ) |
| return response.choices[0].message.content.strip() |
|
|
| except Exception as e: |
| logging.warning(f"OpenAI API call failed: {e}. Using fallback template.") |
| return ( |
| f"MODEL DECISION SUMMARY\n\n" |
| f"The model classified this as '{comprehensive_data['predicted_disease']}' " |
| f"with severity '{comprehensive_data['gradcam_severity']}' " |
| f"and stage '{comprehensive_data['predicted_stage']}' " |
| f"({comprehensive_data['stage_confidence']:.1%} confidence).\n\n" |
| f"ATTENTION ANALYSIS\n\n" |
| f"Primary focus: {comprehensive_data['primary_position']} region. " |
| f"Attention shows {comprehensive_data['scatter_level']} scatter across " |
| f"{comprehensive_data['num_clusters']} clusters. " |
| f"Distribution - Center: {comprehensive_data['center_attention']:.1f}%, " |
| f"Mid: {comprehensive_data['mid_attention']:.1f}%, " |
| f"Periphery: {comprehensive_data['periphery_attention']:.1f}%.\n\n" |
| f"GRADCAM ANALYSIS\n\n" |
| f"GradCAM confidence: {comprehensive_data['gradcam_disease_conf']:.1%} (disease), " |
| f"{comprehensive_data['gradcam_severity_conf']:.1%} (severity).\n\n" |
| f"VISUAL PATTERNS\n\n" |
| f"Dominant color: {comprehensive_data['dominant_color']}. " |
| f"Texture: {comprehensive_data['texture_classification']} " |
| f"(uniformity {comprehensive_data['uniformity']}/100, " |
| f"smoothness {comprehensive_data['smoothness']}/100)." |
| ) |
|
|
|
|
| |
| |
| |
|
|
| def _render_explanation_to_image(explanation_text, figsize=(16, 4)): |
| """ |
| Renders a plain-text explanation string into an (H, W, 3) uint8 numpy |
| array that fills the entire Row 3 panel in display_prediction(). |
| """ |
| fig, ax = plt.subplots(figsize=figsize, facecolor='#0F0F2A') |
| fig.subplots_adjust(left=0, right=1, top=1, bottom=0) |
| ax.set_facecolor('#0F0F2A') |
| ax.set_xlim(0, 1) |
| ax.set_ylim(0, 1) |
| ax.axis('off') |
|
|
| ax.add_patch(mpatches.FancyBboxPatch( |
| (0.0, 0.0), 1.0, 1.0, |
| boxstyle="round,pad=0.01", |
| linewidth=2, |
| edgecolor='#3498DB', |
| facecolor='#16213E', |
| transform=ax.transAxes, |
| clip_on=False |
| )) |
|
|
| ax.text( |
| 0.5, 0.93, |
| 'Textual Explanation', |
| ha='center', va='top', |
| fontsize=11, |
| fontweight='bold', |
| color='#3498DB', |
| transform=ax.transAxes |
| ) |
|
|
| ax.add_line(plt.Line2D( |
| [0.02, 0.98], [0.855, 0.855], |
| transform=ax.transAxes, |
| color='#3498DB', |
| linewidth=1.0 |
| )) |
|
|
| ax.text( |
| 0.02, 0.83, |
| explanation_text, |
| va='top', ha='left', |
| fontsize=10, |
| color='#E0E0F0', |
| family='monospace', |
| wrap=True, |
| transform=ax.transAxes |
| ) |
|
|
| fig.canvas.draw() |
| buf = fig.canvas.buffer_rgba() |
| img_array = np.frombuffer(buf, dtype=np.uint8).reshape( |
| fig.canvas.get_width_height()[::-1] + (4,) |
| ) |
| plt.close(fig) |
| return img_array[:, :, :3] |
|
|
|
|
| |
| |
| |
|
|
| def _build_attention_overlay(img_array, heatmap_raw): |
| """ |
| Takes img_array (H, W, 3) uint8 and heatmap_raw (H, W) float, |
| returns an RGB overlay as np.ndarray in [0, 1]. |
| """ |
| import cv2 |
| img_norm = img_array.astype(np.float32) / 255.0 |
| hm = heatmap_raw.astype(np.float32) |
|
|
| h, w = img_norm.shape[:2] |
| if hm.shape != (h, w): |
| hm = cv2.resize(hm, (w, h), interpolation=cv2.INTER_CUBIC) |
|
|
| hm_min, hm_max = hm.min(), hm.max() |
| if hm_max > hm_min: |
| hm = (hm - hm_min) / (hm_max - hm_min) |
|
|
| cmap = matplotlib.colormaps.get_cmap('jet') |
| hm_colored = cmap(hm)[:, :, :3] |
| overlay = img_norm * 0.5 + hm_colored * 0.5 |
| return np.clip(overlay, 0, 1) |
|
|
|
|
| def _preprocess_images_for_attention(image_paths): |
| processed = [] |
| for p in image_paths: |
| try: |
| img_pil = Image.open(p).convert('RGB').resize((224, 224), Image.BILINEAR) |
| tensor = test_transform(img_pil) |
| processed.append(tensor) |
| except Exception as e: |
| logging.warning(f"Could not preprocess {p} for attention: {e}") |
| processed.append(torch.zeros(3, 224, 224)) |
| return processed |
|
|
|
|
| |
| |
| |
| |
| |
| |
|
|
| def display_prediction(image_path, prediction, |
| heatmap_images=None, |
| heatmap_titles=None, |
| explanation_image=None): |
| severity = prediction['severity'] |
| accent = '#E74C3C' if severity == 'Abnormal' else '#2ECC71' |
| bg_color = '#1A1A2E' |
| panel_color = '#16213E' |
| border_dim = '#2A2A4A' |
|
|
| |
| if heatmap_titles is None: |
| heatmap_titles = ['Attention Heatmap', 'GradCAM Heatmap'] |
|
|
| fig = plt.figure(figsize=(16, 14), facecolor=bg_color) |
| outer_gs = GridSpec(3, 1, figure=fig, |
| height_ratios=[5, 4, 3], |
| hspace=0.08) |
|
|
| |
| row1_gs = GridSpecFromSubplotSpec(1, 2, |
| subplot_spec=outer_gs[0], |
| width_ratios=[1, 1.2], |
| wspace=0.05) |
|
|
| ax_img = fig.add_subplot(row1_gs[0]) |
| ax_img.set_facecolor(bg_color) |
| try: |
| img = Image.open(image_path).convert('RGB') |
| ax_img.imshow(img) |
| except Exception: |
| ax_img.text(0.5, 0.5, 'WSI / Slide\n(preview unavailable)', |
| ha='center', va='center', color='white', fontsize=13, |
| transform=ax_img.transAxes) |
| for spine in ax_img.spines.values(): |
| spine.set_edgecolor(accent) |
| spine.set_linewidth(3) |
| ax_img.set_xticks([]) |
| ax_img.set_yticks([]) |
| ax_img.set_title(os.path.basename(image_path), color='white', |
| fontsize=11, pad=8, fontweight='bold') |
|
|
| ax_info = fig.add_subplot(row1_gs[1]) |
| ax_info.set_facecolor(bg_color) |
| ax_info.set_xlim(0, 1) |
| ax_info.set_ylim(0, 1) |
| ax_info.axis('off') |
| ax_info.text(0.5, 0.96, 'Diagnostic Report', |
| ha='center', va='top', fontsize=15, fontweight='bold', |
| color='white', transform=ax_info.transAxes) |
| divider = plt.Line2D([0.05, 0.95], [0.89, 0.89], |
| transform=ax_info.transAxes, |
| color=accent, linewidth=1.5) |
| ax_info.add_line(divider) |
|
|
| def draw_card(ax, y, label, value, confidence, color): |
| ax.add_patch(mpatches.FancyBboxPatch( |
| (0.04, y - 0.11), 0.92, 0.14, |
| boxstyle="round,pad=0.01", |
| linewidth=1.5, edgecolor=color, |
| facecolor=panel_color, transform=ax.transAxes, clip_on=False |
| )) |
| ax.text(0.10, y - 0.01, label.upper(), |
| ha='left', va='center', fontsize=8, color='#A0A0C0', |
| fontweight='bold', transform=ax.transAxes) |
| ax.text(0.10, y - 0.05, value, |
| ha='left', va='center', fontsize=13, color='white', |
| fontweight='bold', transform=ax.transAxes) |
| if confidence > 0: |
| bar_y = y - 0.09 |
| ax.add_patch(mpatches.FancyBboxPatch( |
| (0.08, bar_y), 0.60, 0.015, |
| boxstyle="round,pad=0.001", linewidth=0, |
| facecolor='#0F3460', transform=ax.transAxes, clip_on=False |
| )) |
| ax.add_patch(mpatches.FancyBboxPatch( |
| (0.08, bar_y), 0.60 * confidence, 0.015, |
| boxstyle="round,pad=0.001", linewidth=0, |
| facecolor=color, transform=ax.transAxes, clip_on=False |
| )) |
| ax.text(0.72, bar_y + 0.007, f'{confidence:.1%}', |
| ha='left', va='center', fontsize=9, color=color, |
| fontweight='bold', transform=ax.transAxes) |
|
|
| draw_card(ax_info, 0.78, |
| 'Region', |
| prediction['disease'].replace('_processed', '').replace('_', ' ').title(), |
| prediction['disease_confidence'], '#3498DB') |
| draw_card(ax_info, 0.57, 'Status', severity, |
| prediction['severity_confidence'], accent) |
|
|
| stage_val = prediction['stage'] |
| stage_conf = prediction['stage_confidence'] |
| if stage_val == 'N/A': |
| stage_display = 'N/A (Normal)' |
| stage_conf = 0 |
| else: |
| stage_display = (stage_val.split(' - ')[-1].replace('_', ' ').title() |
| if ' - ' in stage_val else stage_val) |
| draw_card(ax_info, 0.36, 'Stage / Subtype', stage_display, stage_conf, '#F39C12') |
|
|
| |
| row2_gs = GridSpecFromSubplotSpec(1, 2, |
| subplot_spec=outer_gs[1], |
| wspace=0.06) |
| for col_idx in range(2): |
| ax_hm = fig.add_subplot(row2_gs[col_idx]) |
| ax_hm.set_facecolor(panel_color) |
|
|
| if heatmap_images and col_idx < len(heatmap_images) and heatmap_images[col_idx] is not None: |
| hm = heatmap_images[col_idx] |
| ax_hm.imshow(hm if isinstance(hm, np.ndarray) else np.array(hm)) |
| else: |
| ax_hm.set_xlim(0, 1) |
| ax_hm.set_ylim(0, 1) |
| ax_hm.add_patch(mpatches.FancyBboxPatch( |
| (0.05, 0.05), 0.90, 0.90, |
| boxstyle="round,pad=0.02", |
| linewidth=1.5, linestyle='--', |
| edgecolor='#4A4A6A', facecolor='#0F0F2A', |
| transform=ax_hm.transAxes, clip_on=False |
| )) |
| ax_hm.text(0.5, 0.5, '[ Heatmap\nPlaceholder ]', |
| ha='center', va='center', |
| color='#4A4A6A', fontsize=9, fontstyle='italic', |
| transform=ax_hm.transAxes) |
|
|
| title = heatmap_titles[col_idx] if col_idx < len(heatmap_titles) else f'Heatmap {col_idx+1}' |
| ax_hm.set_title(title, color='#A0A0C0', fontsize=9, fontweight='bold', pad=5) |
| for spine in ax_hm.spines.values(): |
| spine.set_edgecolor(border_dim) |
| spine.set_linewidth(1.2) |
| ax_hm.set_xticks([]) |
| ax_hm.set_yticks([]) |
|
|
| |
| ax_text = fig.add_subplot(outer_gs[2]) |
| ax_text.set_facecolor(panel_color) |
|
|
| if explanation_image is not None: |
| exp_img = explanation_image if isinstance(explanation_image, np.ndarray) \ |
| else np.array(explanation_image) |
| ax_text.imshow(exp_img, aspect='auto') |
| ax_text.set_xticks([]) |
| ax_text.set_yticks([]) |
| else: |
| ax_text.set_xlim(0, 1) |
| ax_text.set_ylim(0, 1) |
| ax_text.axis('off') |
| ax_text.add_patch(mpatches.FancyBboxPatch( |
| (0.01, 0.05), 0.98, 0.90, |
| boxstyle="round,pad=0.02", |
| linewidth=1.5, linestyle='--', |
| edgecolor='#4A4A6A', facecolor='#0F0F2A', |
| transform=ax_text.transAxes, clip_on=False |
| )) |
| ax_text.text(0.5, 0.80, |
| 'Human-Friendly Text Explanation', |
| ha='center', va='center', |
| color='#A0A0C0', fontsize=11, fontweight='bold', |
| transform=ax_text.transAxes) |
| ax_text.text(0.5, 0.42, |
| '[ Textual Explanation Placeholder ]\n\n' |
| 'The model focused on [extracted features]\n' |
| 'because [rule-based reasoning] ...', |
| ha='center', va='center', |
| color='#3A3A5A', fontsize=10, fontstyle='italic', |
| transform=ax_text.transAxes) |
|
|
| for spine in ax_text.spines.values(): |
| spine.set_edgecolor(border_dim) |
| spine.set_linewidth(1.2) |
|
|
| plt.tight_layout(pad=1.2) |
| plt.show() |
| print() |
|
|
|
|
| |
| |
| |
|
|
| def load_model(): |
| if not os.path.exists(MODEL_PATH): |
| print(f"ERROR: Model not found at {MODEL_PATH}") |
| return None, None |
|
|
| checkpoint = torch.load(MODEL_PATH, map_location=device) |
| state_dict = checkpoint['model_state_dict'] |
|
|
| if list(state_dict.keys())[0].startswith('module.'): |
| state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()} |
|
|
| severity_head_names = set() |
| for k in state_dict.keys(): |
| if k.startswith('aggregator.severity_heads.'): |
| parts = k.split('.') |
| if len(parts) > 2: |
| severity_head_names.add(parts[2]) |
| disease_names = sorted(list(severity_head_names)) |
|
|
| num_diseases = state_dict['aggregator.disease_head.classifier.3.weight'].shape[0] |
| num_stage_classes = ( |
| state_dict['aggregator.stage_head.classifier.3.weight'].shape[0] |
| if 'aggregator.stage_head.classifier.3.weight' in state_dict else 0 |
| ) |
|
|
| print(f"\nModel config -> diseases: {num_diseases} | stages: {num_stage_classes}") |
| print(f"Classes: {', '.join(disease_names)}\n") |
|
|
| backbone = ViTBackbone() |
| model = Phase3Model(backbone, num_diseases=num_diseases, |
| num_stage_classes=num_stage_classes, |
| disease_names=disease_names).to(device) |
| model.load_state_dict(state_dict, strict=True) |
| model.eval() |
| print("Model loaded successfully!\n") |
| return model, disease_names |
|
|
|
|
| |
| |
| |
|
|
| def run_inference(): |
| print("\n" + "=" * 70) |
| print(" HIERARCHICAL MIL MODEL - PATHOLOGY INFERENCE") |
| print("=" * 70) |
| print("\nOptions:") |
| print(" 1. Single image (provide full path to one image file)") |
| print(" 2. Folder (provide path to a folder; all images processed)") |
|
|
| choice = input("\nSelect option (1 / 2): ").strip() |
|
|
| if choice == '1': |
| image_path = input("Enter image path: ").strip() |
| if not os.path.isfile(image_path): |
| print(f"ERROR: File not found -> {image_path}") |
| return |
| all_images = [image_path] |
| elif choice == '2': |
| folder_path = input("Enter folder path: ").strip() |
| if not os.path.isdir(folder_path): |
| print(f"ERROR: Folder not found -> {folder_path}") |
| return |
| all_images = collect_images_from_folder(folder_path) |
| if not all_images: |
| print("No valid images found in the folder.") |
| return |
| print(f"Found {len(all_images)} image(s).") |
| else: |
| print("Invalid option.") |
| return |
|
|
| |
| |
| |
| print("\nLoading model ...") |
| model, disease_names = load_model() |
| if model is None: |
| return |
|
|
| |
| |
| |
| print("\n" + "=" * 70) |
| print("Running Attention Analysis ...") |
| print("=" * 70) |
| processed_images = _preprocess_images_for_attention(all_images) |
| attn_results = run_attention_analysis( |
| attention_extractor, |
| device, |
| all_images, |
| processed_images, |
| DISEASE_CLASS_MAPPING, |
| STAGE_CLASS_MAPPING |
| ) |
|
|
| print("\n" + "=" * 70) |
| print("Running GradCAM Analysis ...") |
| print("=" * 70) |
| gradcam_results = run_tri_head_gradcam_plus_plus_analysis(model, device, all_images) |
|
|
| |
| |
| |
| print("\n" + "=" * 70) |
| print("Running Feature Extraction ...") |
| print("=" * 70) |
|
|
| explanations_list = [] |
|
|
| for i, attention_result in enumerate(attn_results): |
| print(f" [{i+1}/{len(attn_results)}] Extracting features: {attention_result['filename']}") |
|
|
| extractor = HeatmapFeatureExtractor(attention_result) |
|
|
| bright = extractor.get_brightest_region() |
| scatter = extractor.get_activation_scatter() |
| dom_color = extractor.get_dominant_focus_color() |
| texture = extractor.get_texture_analysis() |
|
|
| explanations_list.append({ |
| "brightest": bright, |
| "scatter": scatter, |
| "dominant_color": dom_color, |
| "texture": texture, |
| }) |
|
|
| print(f" ✅ Position={bright['primary_hotspot']['position']}, " |
| f"Scatter={scatter['scatter_level']}, " |
| f"Color={dom_color['dominant_color_name']}, " |
| f"Texture={texture['texture_classification']}") |
|
|
| |
| |
| |
| dataset = SimpleSlideDataset(all_images) |
| dataloader = DataLoader(dataset, batch_size=1, shuffle=False, |
| collate_fn=simple_collate, num_workers=2, pin_memory=True) |
|
|
| results = [] |
|
|
| for batch_idx, batch in enumerate(tqdm(dataloader, desc="Running inference")): |
| tiles, paths = batch |
| if not tiles or not paths: |
| continue |
|
|
| slide_path = paths[0] |
|
|
| try: |
| img_idx = all_images.index(slide_path) |
| except ValueError: |
| img_idx = batch_idx |
|
|
| prediction = predict_image(model, tiles, disease_names, STAGE_NAMES) |
| if prediction is None: |
| print(f"Failed to process: {slide_path}") |
| continue |
|
|
| |
| attn_overlay = None |
| if attn_results and img_idx < len(attn_results): |
| ar = attn_results[img_idx] |
| attn_overlay = _build_attention_overlay(ar['image'], ar['attention_heatmap']) |
|
|
| gradcam_overlay = None |
| if gradcam_results and img_idx < len(gradcam_results): |
| gradcam_overlay = gradcam_results[img_idx]['union']['heatmap_overlay'] |
|
|
| |
| exp = explanations_list[img_idx] |
| bright = exp['brightest'] |
| scatter_res = exp['scatter'] |
| dom = exp['dominant_color'] |
| texture = exp['texture'] |
|
|
| |
| gradcam_disease_conf = gradcam_results[img_idx]['level1_disease']['confidence'] \ |
| if gradcam_results else 0.0 |
| gradcam_severity_conf = gradcam_results[img_idx]['level2_severity']['confidence'] \ |
| if gradcam_results else 0.0 |
| gradcam_severity = gradcam_results[img_idx]['level2_severity']['predicted_class'] \ |
| if gradcam_results else 'N/A' |
|
|
| comprehensive_data = { |
| 'predicted_disease': prediction['disease'], |
| 'gradcam_severity': gradcam_severity, |
| 'predicted_stage': prediction['stage'], |
| 'stage_confidence': prediction['stage_confidence'], |
| 'gradcam_disease_conf': gradcam_disease_conf, |
| 'gradcam_severity_conf': gradcam_severity_conf, |
| 'primary_position': bright['primary_hotspot']['position'], |
| 'primary_intensity': bright['primary_hotspot']['intensity'], |
| 'hotspot_count': bright['hotspot_count'], |
| 'center_attention': bright['spatial_coverage']['center_attention'], |
| 'mid_attention': bright['spatial_coverage']['mid_region_attention'], |
| 'periphery_attention': bright['spatial_coverage']['periphery_attention'], |
| 'scatter_level': scatter_res['scatter_level'], |
| 'num_clusters': scatter_res['num_clusters'], |
| 'dominant_color': dom['dominant_color_name'], |
| 'color_confidence': dom['color_confidence'], |
| 'texture_classification': texture['texture_classification'], |
| 'uniformity': texture['texture_scores']['uniformity'], |
| 'organization': texture['texture_scores']['organization'], |
| 'complexity': texture['texture_scores']['complexity'], |
| 'smoothness': texture['texture_scores']['smoothness'], |
| } |
|
|
| |
| print(f"\n Generating GPT explanation for image {img_idx + 1} ...") |
| explanation_text = generate_comprehensive_explanation(comprehensive_data) |
| explanation_image = _render_explanation_to_image(explanation_text) |
| time.sleep(0.5) |
|
|
| |
| display_prediction( |
| image_path = slide_path, |
| prediction = prediction, |
| heatmap_images = [attn_overlay, gradcam_overlay], |
| heatmap_titles = ['Attention Heatmap', 'GradCAM Heatmap'], |
| explanation_image = explanation_image |
| ) |
|
|
| results.append({ |
| 'image_path': slide_path, |
| 'image_name': os.path.basename(slide_path), |
| 'predicted_disease': prediction['disease'], |
| 'disease_confidence': prediction['disease_confidence'], |
| 'predicted_severity': prediction['severity'], |
| 'severity_confidence': prediction['severity_confidence'], |
| 'predicted_stage': prediction['stage'], |
| 'stage_confidence': prediction['stage_confidence'] |
| }) |
|
|
| if results: |
| df = pd.DataFrame(results) |
| out_path = os.path.join(OUTPUT_DIR, "inference_results.csv") |
| df.to_csv(out_path, index=False) |
| print(f"\nResults saved -> {out_path}") |
| print(f"Total processed: {len(results)}") |
|
|
| print("\n" + "=" * 70) |
| print("DONE") |
| print("=" * 70) |
|
|
|
|
| if __name__ == "__main__": |
| run_inference() |