# Kaggle notebook preamble. This Python 3 environment is defined by the
# kaggle/python Docker image (https://github.com/kaggle/docker-python).
# Input files live under the read-only /kaggle/input tree. Anything written to
# /kaggle/working (up to 20GB) is kept as output by "Save & Run All";
# /kaggle/temp is scratch space that is discarded after the session.
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np  # linear algebra
import pandas as pd  # data processing, CSV file I/O (e.g. pd.read_csv)
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from scipy.ndimage import gaussian_filter

# Echo every available input file so the notebook log shows which data exists.
for input_file in (
    os.path.join(folder, name)
    for folder, _, names in os.walk('/kaggle/input')
    for name in names
):
    print(input_file)

# -------------------------------------------
# Configuration
# -------------------------------------------
OUTPUT_DIR = "/kaggle/working"
GRADCAM_OUTPUT_DIR = os.path.join(OUTPUT_DIR, 'gradcam_plus_plus_results')
os.makedirs(GRADCAM_OUTPUT_DIR, exist_ok=True)

print("\n".join([
    "=" * 80,
    "HIERARCHICAL TRI-HEAD GRAD-CAM++ CONFIGURATION",
    "=" * 80,
    f"Output directory: {GRADCAM_OUTPUT_DIR}",
    "=" * 80,
]))

# -------------------------------------------
# Class Mappings
# -------------------------------------------
# Disease-head class index -> disease/dataset name.
DISEASE_CLASS_MAPPING = dict(enumerate([
    "Breast_cancer",
    "annrbc-anemia_processed",
    "colon_processed",
    "leukemia_processed",
    "lung_processed",
    "oral-cancer_processed",
    "ovarian-cancer_processed",
    "sickle-cell-new_processed",
    "thalassemia_processed",
]))

# Severity-head class index -> severity label.
SEVERITY_CLASS_MAPPING = dict(enumerate(["Normal", "Abnormal"]))
# -------------------------------------------
# Grad-CAM Implementation for ViT (handles tuple outputs)
# -------------------------------------------
# NOTE: the "Grad-CAM++" wording kept in titles, file names and function names
# below is historical. The maps are computed with standard Grad-CAM: channel
# weights are plain mean gradients, with no second-order terms.


def _square_grid_size(num_tokens):
    """
    Return the side length of the square patch grid formed by ``num_tokens``.

    Raises:
        ValueError: if ``num_tokens`` is not a positive perfect square (e.g. the
            backbone emits extra non-patch tokens or the input is not square).
            This replaces the opaque error ``reshape`` would otherwise raise.
    """
    import math

    grid_size = math.isqrt(num_tokens) if num_tokens > 0 else 0
    if grid_size == 0 or grid_size * grid_size != num_tokens:
        raise ValueError(
            f"Cannot arrange {num_tokens} patch tokens into a square grid"
        )
    return grid_size


class GradCAM:
    """
    Standard Grad-CAM for Vision Transformer blocks.

    Hooks ``target_layer`` to capture its output activations (forward hook) and
    the gradient of the target logit with respect to that output (full backward
    hook). Each channel is then weighted by its mean gradient over tokens.

    Hooks stay registered until :meth:`remove_hooks` is called. The instance can
    also be used as a context manager, which removes the hooks on exit even if an
    exception is raised.
    """

    def __init__(self, model, target_layer):
        """
        Args:
            model: The hierarchical model; only ``model.zero_grad()`` is called on it here.
            target_layer: Module to hook (typically the last transformer block).
        """
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        self.handlers = []  # handles returned by register_*_hook, used for removal
        self._register_hooks()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.remove_hooks()
        return False  # never swallow exceptions

    def _register_hooks(self):
        """Register a forward hook and a full backward hook on the target layer."""

        def forward_hook(module, input, output):
            # DINOv2 blocks return a tuple; the hidden states are its first element.
            if isinstance(output, tuple):
                self.activations = output[0].detach()
                print(f" šŸŖ Forward hook: Captured from tuple, shape {output[0].shape}")
            else:
                self.activations = output.detach()
                print(f" šŸŖ Forward hook: Captured tensor, shape {output.shape}")

        def backward_hook(module, grad_input, grad_output):
            # grad_output mirrors the forward output; keep the first element's gradient.
            if isinstance(grad_output, tuple):
                grad = grad_output[0]
                if grad is not None:
                    self.gradients = grad.detach()
                    print(f" šŸŖ Backward hook: Captured from tuple, shape {grad.shape}")
            elif grad_output is not None:
                self.gradients = grad_output.detach()
                print(f" šŸŖ Backward hook: Captured tensor, shape {grad_output.shape}")

        self.handlers.append(self.target_layer.register_forward_hook(forward_hook))
        self.handlers.append(self.target_layer.register_full_backward_hook(backward_hook))

    def remove_hooks(self):
        """Remove all registered hooks; safe to call more than once."""
        for handle in self.handlers:
            handle.remove()
        self.handlers.clear()

    def generate_cam(self, class_idx, logits, apply_relu=False):
        """
        Compute a Grad-CAM map for ``class_idx`` from ``logits``.

        ``logits`` must come from a forward pass made while this instance's hooks
        were registered. Only row 0 is targeted (``one_hot[0, class_idx]``), so a
        batch of one is assumed.

        Args:
            class_idx: Target class index (column of ``logits``).
            logits: Head output tensor.
            apply_relu: If True, drop negative token scores before normalising
                (classic Grad-CAM). Defaults to False, which keeps the existing
                behaviour of min-max normalising the signed map.

        Returns:
            np.ndarray of shape (grid, grid), min-max scaled to [0, 1].
            Returns a (14, 14) array of zeros if the hooks captured nothing, and
            an all-zero map if the CAM has no variation.

        Raises:
            ValueError: if the tokens left after removing CLS do not form a
                square grid.
        """
        self.model.zero_grad()

        # Backpropagate only the target logit.
        one_hot = torch.zeros_like(logits)
        one_hot[0, class_idx] = 1
        logits.backward(gradient=one_hot, retain_graph=True)

        # ---- Gradient-flow diagnostics ----
        print(f"\n šŸ” GRADIENT FLOW CHECK:")
        print(f" {'='*60}")
        if self.gradients is None:
            print(f" āŒ CRITICAL: No gradients captured!")
            return np.zeros((14, 14))
        print(f" āœ… Gradients captured: {self.gradients.shape}")
        print(f" Min: {self.gradients.min().item():.6f}, Max: {self.gradients.max().item():.6f}")
        print(f" Mean: {self.gradients.mean().item():.6f}, Std: {self.gradients.std().item():.6f}")

        if self.activations is None:
            print(f" āŒ CRITICAL: No activations captured!")
            return np.zeros((14, 14))
        print(f" āœ… Activations captured: {self.activations.shape}")
        print(f" Min: {self.activations.min().item():.6f}, Max: {self.activations.max().item():.6f}")
        print(f" {'='*60}\n")

        # Step 1: channel weights = mean gradient over the token axis (dim 1).
        # Presumably gradients/activations are [B, N, D] -> weights [B, 1, D].
        weights = self.gradients.mean(dim=1, keepdim=True)
        print(f" šŸ“Š Weights (channel importance):")
        print(f" Shape: {weights.shape}")
        print(f" Min: {weights.min().item():.6f}, Max: {weights.max().item():.6f}")

        # Step 2: weighted sum over channels (dim 2) -> one score per token.
        cam = (weights * self.activations).sum(dim=2)
        print(f" šŸ“Š CAM before ReLU:")
        print(f" Shape: {cam.shape}")
        print(f" Min: {cam.min().item():.6f}, Max: {cam.max().item():.6f}")

        # Step 3: optional ReLU (keep only positive contributions).
        if apply_relu:
            cam = F.relu(cam)
            print(f" šŸ“Š CAM after ReLU:")
            print(f" Min: {cam.min().item():.6f}, Max: {cam.max().item():.6f}")
        else:
            print(" CAM ReLU skipped (apply_relu=False); keeping signed scores")

        # Step 4: drop the batch dimension.
        cam = cam[0]

        # Step 5: drop the CLS token (first token).
        if cam.shape[0] > 1:
            cam = cam[1:]
            print(f" šŸŽÆ Removed CLS token, remaining tokens: {cam.shape[0]}")

        # Step 6: reshape the token scores onto the square patch grid.
        grid_size = _square_grid_size(cam.shape[0])
        print(f" šŸ“ Grid size: {grid_size}x{grid_size}")
        cam = cam.reshape(grid_size, grid_size)

        # Step 7: min-max normalise to [0, 1].
        cam_min = cam.min()
        cam_max = cam.max()
        if cam_max > cam_min:
            cam = (cam - cam_min) / (cam_max - cam_min)
            print(f" āœ… Normalized CAM: min={cam.min().item():.4f}, max={cam.max().item():.4f}")
        else:
            print(f" āš ļø WARNING: No variation in CAM (all same value)")
            cam = torch.zeros_like(cam)

        print(f" šŸ“Š Final CAM mean: {cam.mean().item():.6f}\n")
        return cam.cpu().numpy()


class GradCAMVisionTransformer:
    """
    Apply :class:`GradCAM` to the last block of the model's ViT backbone
    (``model.backbone.vit.encoder.layer[-1]``) for any of the three heads.
    """

    def __init__(self, model, device):
        self.model = model
        self.device = device
        self.model.eval()
        self.target_layer = self._find_target_layer()
        print(f" šŸŽÆ Target layer for Grad-CAM: {self.target_layer}")

    def _find_target_layer(self):
        """
        Return ``model.backbone.vit.encoder.layer[-1]``.

        Raises:
            ValueError: if any part of that attribute path is missing.
        """
        print("\nšŸ” Searching for target layer in DINOv2 architecture...")
        try:
            if not hasattr(self.model, 'backbone'):
                raise AttributeError("No backbone found")
            vit_model = self.model.backbone.vit
            print(f"āœ… Found backbone.vit: {type(vit_model).__name__}")

            if not (hasattr(vit_model, 'encoder') and hasattr(vit_model.encoder, 'layer')):
                raise AttributeError("No encoder.layer found")
            num_layers = len(vit_model.encoder.layer)
            last_layer = vit_model.encoder.layer[-1]
            print(f"āœ… Found encoder with {num_layers} layers")
            print(f"āœ… Target layer: encoder.layer[-1] (layer {num_layers-1})")
            return last_layer
        except AttributeError as e:
            print(f"āŒ Error: {e}")
            raise ValueError("Could not find DINOv2 encoder layers") from e

    def generate_heatmap(self, image_tensor, target_class_idx, head_type='disease'):
        """
        Generate a smoothed Grad-CAM heatmap for one head.

        Args:
            image_tensor: Input image tensor indexed as [1, C, H, W].
            target_class_idx: Class index to explain.
            head_type: 'disease', 'severity' or 'stage'. For 'severity', the
                severity logits of the currently predicted disease are used.

        Returns:
            np.ndarray of shape (H, W): the CAM is bicubic-resized to the tensor's
            spatial size, then Gaussian-smoothed (sigma=2). Values are not
            re-clipped afterwards.

        Raises:
            ValueError: for an unknown ``head_type``.
        """
        # Leaf copy on the target device that tracks gradients, so the full
        # backward hook on the target block receives a gradient.
        image_tensor = image_tensor.to(self.device).detach().requires_grad_(True)

        # The context manager removes the hooks even if the forward/backward pass
        # or the CAM computation raises. Otherwise stale hooks would stay on the
        # shared target layer and pile up across calls.
        with GradCAM(self.model, self.target_layer) as gradcam:
            disease_logits, severity_logits, stage_logits, _ = self.model(
                [image_tensor], enable_gradients=True
            )

            if head_type == 'disease':
                logits = disease_logits
            elif head_type == 'severity':
                disease_pred_idx = disease_logits.argmax(dim=1).item()
                disease_name = DISEASE_CLASS_MAPPING.get(disease_pred_idx, f"Unknown_{disease_pred_idx}")
                logits = severity_logits[disease_name]
            elif head_type == 'stage':
                logits = stage_logits
            else:
                raise ValueError(f"Unknown head_type: {head_type}")

            cam = gradcam.generate_cam(target_class_idx, logits)

        # Resize to the input resolution, then smooth.
        H, W = image_tensor.shape[2], image_tensor.shape[3]
        cam_resized = cv2.resize(cam, (W, H), interpolation=cv2.INTER_CUBIC)
        return gaussian_filter(cam_resized, sigma=2)


# -------------------------------------------
# Helper Functions
# -------------------------------------------
def create_gradcam_overlay(image_array, heatmap, alpha=0.5, colormap='jet'):
    """
    Blend a colour-mapped heatmap onto an image.

    Args:
        image_array: Original image as a numpy array (H, W, 3), assumed 0-255.
        heatmap: Heatmap (h, w); resized to (H, W) if needed and clipped to [0, 1].
        alpha: Weight of the heatmap in the blend.
        colormap: Matplotlib colormap name.

    Returns:
        Overlay as a float numpy array (H, W, 3) in [0, 1].
    """
    import matplotlib

    img_normalized = image_array.astype(np.float32) / 255.0

    target_h, target_w = img_normalized.shape[:2]
    if heatmap.shape != (target_h, target_w):
        print(f" šŸ“ Resizing heatmap from {heatmap.shape} to ({target_h}, {target_w})")
        heatmap = cv2.resize(heatmap, (target_w, target_h), interpolation=cv2.INTER_CUBIC)

    heatmap = np.clip(heatmap, 0, 1)

    cmap = matplotlib.colormaps.get_cmap(colormap)
    heatmap_colored = cmap(heatmap)[:, :, :3]  # drop the alpha channel

    overlay = img_normalized * (1 - alpha) + heatmap_colored * alpha
    return np.clip(overlay, 0, 1)


def _openslide_thumbnail(img_path, target_size):
    """Return an OpenSlide thumbnail bounded by target_size, always closing the slide."""
    # NOTE(review): `openslide` is not imported in the visible code; it must be
    # imported elsewhere in the notebook before slide formats are loaded.
    slide = openslide.OpenSlide(img_path)
    try:
        return slide.get_thumbnail((target_size, target_size))
    finally:
        slide.close()


def _load_rgb_image(img_path, target_size):
    """
    Open ``img_path`` as an RGB PIL image of roughly ``target_size`` pixels.

    .svs/.ndpi files use an OpenSlide thumbnail. Its size is bounded by
    target_size, but it may not be square. .tif files try OpenSlide first and
    fall back to PIL. Every other format is opened with PIL and resized to
    exactly (target_size, target_size).
    """
    lower_path = img_path.lower()
    if lower_path.endswith(('.svs', '.ndpi')):
        return _openslide_thumbnail(img_path, target_size)
    if lower_path.endswith('.tif'):
        try:
            return _openslide_thumbnail(img_path, target_size)
        except Exception:
            # Deliberate fallback: OpenSlide unavailable or cannot read this TIFF.
            pass
    with Image.open(img_path) as img:
        img_pil = img.convert('RGB')
    return img_pil.resize((target_size, target_size), Image.BILINEAR)


def load_and_preprocess_image(img_path, target_size=224):
    """
    Load an image for both visualisation and model input.

    Returns:
        (img_pil, img_array, img_tensor), where img_tensor is
        ``test_transform(img_pil)`` with a batch dimension added.
        Returns (None, None, None) on any failure; the error is printed.
    """
    # NOTE(review): relies on a global `test_transform` defined elsewhere.
    try:
        img_pil = _load_rgb_image(img_path, target_size)
        img_array = np.array(img_pil)
        img_tensor = test_transform(img_pil).unsqueeze(0)
        return img_pil, img_array, img_tensor
    except Exception as e:
        print(f"Error loading image {img_path}: {e}")
        return None, None, None


def extract_stage_label(img_path):
    """
    Extract the stage label from the folder directly above ``test``.

    Expected layouts:
        .../abnormal/test/image.png          -> no stage      -> returns -1
        .../abnormal/<stage>/test/image.png  -> stage present -> returns "<stage>"

    Returns -1 when there is no ``test`` folder, when nothing precedes it, or
    when the folder above it is empty or ``abnormal`` (case-insensitive).
    Otherwise returns that folder name as a string.
    """
    path_parts = img_path.split(os.sep)
    try:
        test_idx = path_parts.index('test')
    except ValueError:
        return -1

    # Guard test_idx == 0: index -1 would wrap around to the file name.
    candidate = path_parts[test_idx - 1] if test_idx > 0 else ''
    if not candidate or candidate.lower() == 'abnormal':
        return -1
    return candidate


def _parse_path_metadata(img_path):
    """
    Return (dataset_name, true_label, stage_label) parsed from an image path.

    dataset_name is the last known disease folder seen before the first
    "normal"/"abnormal" component ("Unknown" if none). true_label is that
    component ("unknown" if none). stage_label comes from extract_stage_label.
    """
    known_datasets = set(DISEASE_CLASS_MAPPING.values())
    dataset_name = "Unknown"
    true_label = "unknown"
    for part in img_path.split(os.sep):
        if part in known_datasets:
            dataset_name = part
        if part in ("normal", "abnormal"):
            true_label = part
            break
    return dataset_name, true_label, extract_stage_label(img_path)


def _summarize_logits(logits):
    """Return (argmax index, its softmax probability, full probability vector) for row 0."""
    probs = F.softmax(logits, dim=1)
    pred_idx = logits.argmax(dim=1).item()
    return pred_idx, probs[0, pred_idx].item(), probs[0].cpu().numpy()


def _activation_stats(heatmap):
    """Mean and max of a heatmap as plain floats."""
    return {'mean': float(heatmap.mean()), 'max': float(heatmap.max())}


# -------------------------------------------
# Main Tri-Head Grad-CAM++ Analysis
# -------------------------------------------
def run_tri_head_gradcam_plus_plus_analysis(model, device, collected_images):
    """
    Run Grad-CAM on the disease head (Level 1), the severity head (Level 2) and,
    when applicable, the stage head (Level 3) for every image path.

    The stage head is analysed only when the ground truth parsed from the path is
    "abnormal", a stage folder was found (stage label != -1), and the model
    returned stage logits. Per-image failures are printed and skipped.

    Returns:
        A list with one result dict per successfully processed image, or an
        empty list if ``collected_images`` is empty.
    """
    print("\n" + "="*80)
    print("šŸ” STARTING TRI-HEAD GRAD-CAM++ ANALYSIS")
    print("="*80)

    if not collected_images:
        print("āŒ No images provided")
        return []

    print(f"āœ… Processing {len(collected_images)} images\n")

    gradcam_wrapper = GradCAMVisionTransformer(model, device)
    all_results = []

    for idx, img_path in enumerate(collected_images):
        print(f"\n[{idx+1}/{len(collected_images)}] Processing: {os.path.basename(img_path)}")

        try:
            _, img_array, img_tensor = load_and_preprocess_image(img_path, target_size=224)
            if img_tensor is None:
                print(f" āŒ Failed to load image")
                continue

            dataset_name, true_label, stage_label = _parse_path_metadata(img_path)

            # ===== STEP 1: Get predictions from all heads =====
            with torch.no_grad():
                disease_logits, severity_logits_dict, stage_logits, _ = model([img_tensor.to(device)])

            # Level 1: disease
            disease_pred_idx, disease_confidence, disease_all_probs = _summarize_logits(disease_logits)
            disease_name = DISEASE_CLASS_MAPPING.get(disease_pred_idx, f"Unknown_{disease_pred_idx}")
            print(f" šŸ“Š Level 1 (Disease): {disease_name}")
            print(f" Index: {disease_pred_idx}, Confidence: {disease_confidence:.4f}")

            # Level 2: severity head of the predicted disease
            severity_pred_idx, severity_confidence, severity_all_probs = _summarize_logits(
                severity_logits_dict[disease_name]
            )
            severity_label_text = SEVERITY_CLASS_MAPPING.get(severity_pred_idx, f"Unknown_{severity_pred_idx}")
            print(f" šŸ“Š Level 2 (Severity): {severity_label_text}")
            print(f" Index: {severity_pred_idx}, Confidence: {severity_confidence:.4f}")

            # Level 3: stage (only if the model produced stage logits)
            stage_pred_idx = None
            stage_confidence = None
            stage_all_probs = None
            if stage_logits is not None:
                stage_pred_idx, stage_confidence, stage_all_probs = _summarize_logits(stage_logits)
                print(f" šŸ“Š Level 3 (Stage): Stage {stage_pred_idx}")
                print(f" Confidence: {stage_confidence:.4f}")
                print(f" Ground Truth Stage: {stage_label}")

            # ===== STEP 2: Disease head heatmap =====
            print(f"\n šŸ”„ Generating Grad-CAM++ for Disease Head...")
            disease_heatmap = gradcam_wrapper.generate_heatmap(
                img_tensor.clone(), disease_pred_idx, head_type='disease'
            )
            disease_overlay = create_gradcam_overlay(img_array, disease_heatmap, alpha=0.5)
            print(f" āœ… Disease heatmap generated")
            print(f" Min: {disease_heatmap.min():.4f}, Max: {disease_heatmap.max():.4f}")

            # ===== STEP 3: Severity head heatmap =====
            print(f" šŸ”„ Generating Grad-CAM++ for Severity Head...")
            severity_heatmap = gradcam_wrapper.generate_heatmap(
                img_tensor.clone(), severity_pred_idx, head_type='severity'
            )
            severity_overlay = create_gradcam_overlay(img_array, severity_heatmap, alpha=0.5)
            print(f" āœ… Severity heatmap generated")
            print(f" Min: {severity_heatmap.min():.4f}, Max: {severity_heatmap.max():.4f}")

            # ===== STEP 4: Stage head heatmap (conditional) =====
            stage_heatmap = None
            stage_overlay = None
            include_stage_analysis = (
                true_label == "abnormal" and stage_label != -1 and stage_logits is not None
            )
            if include_stage_analysis:
                print(f" šŸ”„ Generating Grad-CAM++ for Stage Head (GT: abnormal, Stage: {stage_label})...")
                stage_heatmap = gradcam_wrapper.generate_heatmap(
                    img_tensor.clone(), stage_pred_idx, head_type='stage'
                )
                stage_overlay = create_gradcam_overlay(img_array, stage_heatmap, alpha=0.5)
                print(f" āœ… Stage heatmap generated")
                print(f" Min: {stage_heatmap.min():.4f}, Max: {stage_heatmap.max():.4f}")
            else:
                reason = []
                if true_label != "abnormal":
                    reason.append(f"true_label='{true_label}'")
                if stage_label == -1:
                    reason.append("stage_label=-1")
                if stage_logits is None:
                    reason.append("stage_logits=None")
                print(f" ā­ļø Skipping Stage Head Analysis ({', '.join(reason)})")

            # ===== STEP 5: Union heatmap (mean of the analysed heads) =====
            if include_stage_analysis:
                union_heatmap = (disease_heatmap + severity_heatmap + stage_heatmap) / 3.0
                print(f" šŸ“Š Union Heatmap: Average of 3 heads (Disease + Severity + Stage)")
            else:
                union_heatmap = (disease_heatmap + severity_heatmap) / 2.0
                print(f" šŸ“Š Union Heatmap: Average of 2 heads (Disease + Severity)")
            union_overlay = create_gradcam_overlay(img_array, union_heatmap, alpha=0.5)

            # ===== STEP 6: Statistics =====
            disease_stats = _activation_stats(disease_heatmap)
            severity_stats = _activation_stats(severity_heatmap)
            stage_stats = _activation_stats(stage_heatmap) if stage_heatmap is not None else None
            union_stats = _activation_stats(union_heatmap)

            # ===== STEP 7: Result dictionary =====
            result_dict = {
                # Image information
                'filename': os.path.basename(img_path),
                'full_path': img_path,
                'dataset_name': dataset_name,
                'true_label': true_label,
                'stage_label': stage_label,
                'include_stage_analysis': include_stage_analysis,
                'image': img_array,
                # Level 1: disease head
                'level1_disease': {
                    'predicted_class': disease_name,
                    'predicted_idx': disease_pred_idx,
                    'confidence': disease_confidence,
                    'all_probabilities': disease_all_probs,
                    'heatmap_raw': disease_heatmap,
                    'heatmap_overlay': disease_overlay,
                    'activation_stats': disease_stats,
                },
                # Level 2: severity head
                'level2_severity': {
                    'predicted_class': severity_label_text,
                    'predicted_idx': severity_pred_idx,
                    'confidence': severity_confidence,
                    'all_probabilities': severity_all_probs,
                    'heatmap_raw': severity_heatmap,
                    'heatmap_overlay': severity_overlay,
                    'activation_stats': severity_stats,
                },
                # Level 3: stage head (heatmap fields are None when not analysed)
                'level3_stage': {
                    'predicted_idx': stage_pred_idx,
                    'confidence': stage_confidence,
                    'all_probabilities': stage_all_probs,
                    'heatmap_raw': stage_heatmap,
                    'heatmap_overlay': stage_overlay,
                    'activation_stats': stage_stats,
                },
                # Union of the analysed heads
                'union': {
                    'heatmap_raw': union_heatmap,
                    'heatmap_overlay': union_overlay,
                    'num_heads_averaged': 3 if include_stage_analysis else 2,
                    'activation_stats': union_stats,
                },
                # Legacy flat fields
                'disease_heatmap': disease_heatmap,
                'severity_heatmap': severity_heatmap,
                'stage_heatmap': stage_heatmap,
                'disease_pred': disease_name,
                'disease_idx': disease_pred_idx,
                'disease_conf': disease_confidence,
                'severity_pred': severity_label_text,
                'severity_idx': severity_pred_idx,
                'severity_conf': severity_confidence,
                'stage_pred_idx': stage_pred_idx,
                'stage_conf': stage_confidence,
            }

            all_results.append(result_dict)
            print(f" āœ… Completed tri-head Grad-CAM++ analysis")

        except Exception as e:
            # Best-effort batch: report the failure and move on to the next image.
            print(f" āŒ Error: {e}")
            import traceback
            traceback.print_exc()
            continue

    print("\n" + "="*80)
    print("āœ… TRI-HEAD GRAD-CAM++ ANALYSIS COMPLETE")
    print(f"šŸ“¦ Generated {len(all_results)} comprehensive result dictionaries")
    print("="*80)

    return all_results


# -------------------------------------------
# Visualization Function
# -------------------------------------------
def _draw_heatmap_panel(ax, image, heatmap, title, cmap):
    """Draw ``heatmap`` (alpha 0.5, range [0, 1]) over ``image`` on ``ax`` with a colour bar."""
    h, w = image.shape[:2]
    # Give both layers the image's pixel extent so they stay aligned even when
    # the heatmap resolution differs from the image's (e.g. non-square thumbnails).
    extent = (-0.5, w - 0.5, h - 0.5, -0.5)
    ax.imshow(image, extent=extent)
    im = ax.imshow(heatmap, cmap=cmap, alpha=0.5, vmin=0, vmax=1, extent=extent)
    cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label('Activation', rotation=270, labelpad=15)
    ax.set_title(title, fontsize=10, fontweight='bold', pad=10)
    ax.axis('off')


def display_tri_head_gradcam_grid(results):
    """
    Show and save a grid with one row per result and five columns:
    [Original | Disease | Severity | Stage (or placeholder) | Union].

    Also prints per-image statistics and saves the figure as a PNG in
    GRADCAM_OUTPUT_DIR.
    """
    if not results:
        print("No results to display")
        return

    num_images = len(results)
    max_cols = 5

    # squeeze=False always yields a 2-D axes array, including the single-row case.
    fig, axes = plt.subplots(num_images, max_cols, figsize=(35, 7 * num_images), squeeze=False)
    cmap = plt.cm.jet

    for i, result in enumerate(results):
        has_stage = result['include_stage_analysis']
        image = result['image']
        disease = result['level1_disease']
        severity = result['level2_severity']
        stage = result['level3_stage']
        union = result['union']

        # Column 1: original image with metadata
        axes[i, 0].imshow(image)
        axes[i, 0].set_title(
            f"Original Image {i+1}\n"
            f"Dataset: {result['dataset_name']}\n"
            f"True Label: {result['true_label']}\n"
            f"Stage GT: {result['stage_label']}\n"
            f"File: {result['filename']}",
            fontsize=10, fontweight='bold', pad=10,
        )
        axes[i, 0].axis('off')

        # Column 2: disease head (Level 1)
        _draw_heatmap_panel(
            axes[i, 1], image, disease['heatmap_raw'],
            f"Level 1: Disease Head\n"
            f"Predicted: {disease['predicted_class']}\n"
            f"Confidence: {disease['confidence']:.4f}\n"
            f"Mean Act: {disease['activation_stats']['mean']:.4f}",
            cmap,
        )

        # Column 3: severity head (Level 2)
        _draw_heatmap_panel(
            axes[i, 2], image, severity['heatmap_raw'],
            f"Level 2: Severity Head\n"
            f"Predicted: {severity['predicted_class']}\n"
            f"Confidence: {severity['confidence']:.4f}\n"
            f"Mean Act: {severity['activation_stats']['mean']:.4f}",
            cmap,
        )

        # Column 4: stage head (Level 3) or a placeholder explaining why not
        if has_stage:
            _draw_heatmap_panel(
                axes[i, 3], image, stage['heatmap_raw'],
                f"Level 3: Stage Head\n"
                f"Predicted: Stage {stage['predicted_idx']}\n"
                f"Confidence: {stage['confidence']:.4f}\n"
                f"Mean Act: {stage['activation_stats']['mean']:.4f}",
                cmap,
            )
        else:
            reasons = []
            if result['true_label'] != 'abnormal':
                reasons.append(f"Reason: GT={result['true_label']}\n")
            if result['stage_label'] == -1:
                reasons.append(f"Stage={result['stage_label']}")
            if not reasons:
                # Only remaining skip condition: the model returned no stage logits.
                reasons.append("Reason: no stage head output")
            axes[i, 3].text(
                0.5, 0.5,
                "Stage Analysis\nNot Applicable\n\n" + "".join(reasons),
                ha='center', va='center', fontsize=12, color='gray',
                transform=axes[i, 3].transAxes,
            )
            axes[i, 3].axis('off')

        # Column 5: union of the analysed heads
        union_title = (
            f"Union: Combined Grad-CAM++\n"
            f"Averaged {union['num_heads_averaged']} Heads\n"
            f"Disease: {disease['predicted_class']}\n"
            f"Severity: {severity['predicted_class']}"
        )
        if has_stage:
            union_title += f"\nStage: {stage['predicted_idx']}"
        union_title += f"\nMean Act: {union['activation_stats']['mean']:.4f}"
        _draw_heatmap_panel(axes[i, 4], image, union['heatmap_raw'], union_title, cmap)

        # Per-image statistics
        print(f"\nšŸ“Š Image {i+1} ({result['filename']}) Statistics:")
        print(f" Disease Head: {disease['predicted_class']} "
              f"({disease['confidence']:.4f})")
        print(f" Mean Activation: {disease['activation_stats']['mean']:.4f}, "
              f"Max: {disease['activation_stats']['max']:.4f}")
        print(f" Severity Head: {severity['predicted_class']} "
              f"({severity['confidence']:.4f})")
        print(f" Mean Activation: {severity['activation_stats']['mean']:.4f}, "
              f"Max: {severity['activation_stats']['max']:.4f}")
        if has_stage:
            print(f" Stage Head: Stage {stage['predicted_idx']} "
                  f"({stage['confidence']:.4f})")
            print(f" Mean Activation: {stage['activation_stats']['mean']:.4f}, "
                  f"Max: {stage['activation_stats']['max']:.4f}")
        else:
            print(f" Stage Head: Not analyzed (GT: {result['true_label']}, Stage: {result['stage_label']})")
        print(f" Union Heatmap ({union['num_heads_averaged']} heads):")
        print(f" Mean Activation: {union['activation_stats']['mean']:.4f}, "
              f"Max: {union['activation_stats']['max']:.4f}")

    plt.suptitle(
        'Hierarchical Model - Tri-Head Grad-CAM++ Analysis with Union\n'
        'Level 1: Disease | Level 2: Severity | Level 3: Stage (Conditional) | Union: Combined Analysis\n'
        'Red = High Activation | Blue = Low Activation',
        fontsize=16, fontweight='bold', y=0.998
    )
    plt.tight_layout()

    grid_save_path = os.path.join(GRADCAM_OUTPUT_DIR, 'tri_head_union_gradcam_plus_plus_analysis.png')
    plt.savefig(grid_save_path, dpi=150, bbox_inches='tight')
    print(f"\nāœ… Grid saved to: {grid_save_path}")
    plt.show()
    # Release the (potentially very large) figure once it has been shown and saved.
    plt.close(fig)


def _build_summary_row(r):
    """Flatten one analysis result dict into a row for the summary CSV."""
    row = {
        'filename': r['filename'],
        'dataset': r['dataset_name'],
        'true_label': r['true_label'],
        'stage_gt': r['stage_label'],
        'disease_predicted': r['level1_disease']['predicted_class'],
        'disease_confidence': r['level1_disease']['confidence'],
        'disease_mean_activation': r['level1_disease']['activation_stats']['mean'],
        'severity_predicted': r['level2_severity']['predicted_class'],
        'severity_confidence': r['level2_severity']['confidence'],
        'severity_mean_activation': r['level2_severity']['activation_stats']['mean'],
    }
    if r['include_stage_analysis']:
        row.update({
            'stage_predicted': r['level3_stage']['predicted_idx'],
            'stage_confidence': r['level3_stage']['confidence'],
            'stage_mean_activation': r['level3_stage']['activation_stats']['mean'],
        })
    else:
        row.update({
            'stage_predicted': 'N/A',
            'stage_confidence': 'N/A',
            'stage_mean_activation': 'N/A',
        })
    row['union_heads_averaged'] = r['union']['num_heads_averaged']
    row['union_mean_activation'] = r['union']['activation_stats']['mean']
    return row


# -------------------------------------------
# Execute Analysis
# -------------------------------------------
print("\n" + "="*80)
print("CHECKING FOR COLLECTED IMAGES")
print("="*80)

try:
    # NOTE(review): relies on globals `collected_images`, `model` and `device`
    # created in earlier notebook cells.
    if 'collected_images' in locals() or 'collected_images' in globals():
        print(f"āœ… Found collected_images with {len(collected_images)} images\n")

        gradcam_results = run_tri_head_gradcam_plus_plus_analysis(
            model, device, collected_images
        )

        if gradcam_results:
            display_tri_head_gradcam_grid(gradcam_results)

            summary_df = pd.DataFrame([_build_summary_row(r) for r in gradcam_results])
            summary_path = os.path.join(GRADCAM_OUTPUT_DIR, 'tri_head_union_gradcam_plus_plus_summary.csv')
            summary_df.to_csv(summary_path, index=False)
            print(f"\nāœ… Summary saved to: {summary_path}")

            print("\n" + "="*80)
            print("TRI-HEAD GRAD-CAM++ ANALYSIS SUMMARY")
            print("="*80)
            print(summary_df.to_string(index=False))
            print("="*80)

            # Describe the structure of each result dict for later cells.
            print("\n" + "="*80)
            print("šŸ“¦ RESULTS STRUCTURE")
            print("="*80)
            print("Each result dictionary contains:")
            print(" - filename, full_path, dataset_name, true_label, stage_label")
            print(" - include_stage_analysis: boolean flag")
            print(" - image: original image array")
            print(" - level1_disease: {")
            print(" predicted_class, predicted_idx, confidence, all_probabilities")
            print(" heatmap_raw, heatmap_overlay, activation_stats")
            print(" }")
            print(" - level2_severity: {")
            print(" predicted_class, predicted_idx, confidence, all_probabilities")
            print(" heatmap_raw, heatmap_overlay, activation_stats")
            print(" }")
            print(" - level3_stage: {")
            print(" predicted_idx, confidence, all_probabilities")
            print(" heatmap_raw (None if not analyzed), heatmap_overlay (None if not analyzed)")
            print(" activation_stats (None if not analyzed)")
            print(" }")
            print(" - union: {")
            print(" heatmap_raw (average of 2 or 3 heads)")
            print(" heatmap_overlay, num_heads_averaged, activation_stats")
            print(" }")
            print("="*80)
            print(f"\nāœ… gradcam_results variable contains {len(gradcam_results)} dictionaries")
            print(" Use gradcam_results in the next cell for further analysis!")
        else:
            print("\nāŒ No results generated")
    else:
        print("āŒ collected_images not found!")

except Exception as e:
    print(f"āŒ Error: {e}")
    import traceback
    traceback.print_exc()


# ===========================================
# Attention visualisation cell
# ===========================================
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import cv2
from scipy.ndimage import zoom, gaussian_filter
import os

# -------------------------------------------
# Configuration
# -------------------------------------------
OUTPUT_DIR = "/kaggle/working"
ATTENTION_OUTPUT_DIR = os.path.join(OUTPUT_DIR, 'attention_results')
os.makedirs(ATTENTION_OUTPUT_DIR, exist_ok=True)

print("="*80)
print("ATTENTION VISUALIZATION CONFIGURATION")
print("="*80)
print(f"Output directory: {ATTENTION_OUTPUT_DIR}")
print("="*80)


# -------------------------------------------
# Attention Extraction Wrapper
# -------------------------------------------
class AttentionExtractor(nn.Module):
    """
    Thin wrapper around the hierarchical model that also stores the attention
    weights from the last forward pass in ``self.attention_weights``.
    """

    def __init__(self, phase3_model):
        super().__init__()
        self.phase3_model = phase3_model
        self.attention_weights = None
        self.tile_features = None  # not populated in this file

    def forward(self, tiles):
        """Run the wrapped model on ``tiles``, returning its four outputs unchanged."""
        disease_logits, severity_logits, stage_logits, attention_weights = self.phase3_model(tiles)
        self.attention_weights = attention_weights
        return disease_logits, severity_logits, stage_logits, attention_weights


attention_extractor = AttentionExtractor(model).to(device)
attention_extractor.eval()
print("\nāœ… Attention extractor created successfully\n")


# -------------------------------------------
# Helper Functions
# -------------------------------------------
def extract_attention_map(model, preprocessed_image, device):
    """
    Run ``model`` on one image treated as a single tile and return its attention.

    Args:
        model: AttentionExtractor (or any callable taking a list of tile batches
            and returning four outputs, the last being attention weights).
        preprocessed_image: Tensor (C, H, W) or already batched (B, C, H, W).
        device: Torch device.

    Returns:
        (attention_np, disease_logits, severity_logits, stage_logits), where
        attention_np is the squeezed attention as an array of at least 1-D.
        Returns (None, None, None, None) on failure; the error is printed.
    """
    model.eval()
    try:
        image_batch = preprocessed_image.unsqueeze(0) if preprocessed_image.dim() == 3 else preprocessed_image
        # The model takes a list of tile batches; here the image is one tile batch.
        tiles_list = [image_batch.to(device)]

        with torch.no_grad():
            disease_logits, severity_logits, stage_logits, attention_weights = model(tiles_list)

        attention_np = attention_weights.squeeze().cpu().numpy()
        if attention_np.ndim == 0:
            attention_np = np.array([attention_np.item()])

        print(f" Extracted attention weights: shape={attention_np.shape}, value={attention_np}")
        return attention_np, disease_logits, severity_logits, stage_logits

    except Exception as e:
        print(f" Error extracting attention: {e}")
        import traceback
        traceback.print_exc()
        return None, None, None, None


def create_uniform_attention_heatmap(attention_weight, image_shape):
    """
    Build a heatmap of shape ``image_shape`` filled with ``attention_weight``.

    Because the whole image is one tile, the map is constant. It is divided by
    its max when positive, which yields an all-ones map.
    """
    heatmap = np.full(image_shape, attention_weight, dtype=np.float32)
    if heatmap.max() > 0:
        heatmap = heatmap / heatmap.max()
    return heatmap


def extract_patch_level_attention(model_backbone, preprocessed_image, device, patch_size=16):
    """
    Patch-level saliency from the ViT backbone: the L2 norm of each patch token's
    final hidden state (``model_backbone.vit(...).last_hidden_state``, CLS removed).
    This is an embedding-norm proxy, not self-attention weights.

    Args:
        model_backbone: Object exposing ``.vit`` (called with ``pixel_values=``).
        preprocessed_image: Tensor (C, H, W) or (1, C, H, W).
        device: Torch device.
        patch_size: Unused; the grid size is inferred from the token count.
            Kept for API compatibility.

    Returns:
        2-D (grid, grid) numpy array, or None on failure (including a
        non-square token count); the error is printed.
    """
    try:
        image_batch = preprocessed_image.unsqueeze(0) if preprocessed_image.dim() == 3 else preprocessed_image
        image_batch = image_batch.to(device)

        with torch.no_grad():
            outputs = model_backbone.vit(pixel_values=image_batch)
            hidden_states = outputs.last_hidden_state
            patch_tokens = hidden_states[:, 1:, :]  # drop the CLS token at index 0
            patch_importance = torch.norm(patch_tokens, p=2, dim=2).squeeze().cpu().numpy()

        # .size also works when squeeze() collapsed a single patch to 0-d.
        grid_size = _square_grid_size(patch_importance.size)
        attention_grid = patch_importance.reshape(grid_size, grid_size)

        print(f" Extracted patch-level attention: {grid_size}x{grid_size} patches")
        return attention_grid

    except Exception as e:
        print(f" Error extracting patch attention: {e}")
        import traceback
        traceback.print_exc()
        return None


def create_patch_attention_heatmap(patch_attention_grid, target_shape):
    """
    Upsample a patch grid to ``target_shape`` (bilinear zoom), min-max normalise
    it (with 1e-8 guarding a constant grid), then Gaussian-smooth it (sigma=5).
    """
    zoom_factors = (target_shape[0] / patch_attention_grid.shape[0],
                    target_shape[1] / patch_attention_grid.shape[1])
    heatmap = zoom(patch_attention_grid, zoom_factors, order=1)
    heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)
    return gaussian_filter(heatmap, sigma=5)


def load_image_for_attention(img_path, target_size=768):
    """
    Load an RGB image for attention visualisation.

    Returns:
        (img_pil, img_array), or (None, None) on failure; the error is printed.
    """
    try:
        img_pil = _load_rgb_image(img_path, target_size)
        return img_pil, np.array(img_pil)
    except Exception as e:
        print(f"Error loading image {img_path}: {e}")
        return None, None


# -------------------------------------------
# Main Attention Extraction Function
# -------------------------------------------
# -------------------------------------------
def run_attention_analysis(attention_model, device, collected_images, processed_images,
                           main_class_mapping, stage_class_mapping):
    """
    Extract and visualize attention for each collected image.

    For every ``(path, tensor)`` pair this:
      1. loads a display copy of the image,
      2. runs the model to get the MIL attention weight and logits,
      3. builds a patch-level heatmap from the ViT backbone (falling back to
         a uniform heatmap),
      4. records the top disease prediction, and
      5. saves the heatmap as ``.npy`` under ``ATTENTION_OUTPUT_DIR``.

    Args:
        attention_model: AttentionExtractor model; its
            ``phase3_model.backbone`` is used for patch-level attention.
        device: torch device.
        collected_images: list of image paths.
        processed_images: list of preprocessed tensors aligned with
            ``collected_images``. ``zip`` stops at the shorter list.
        main_class_mapping: dict mapping disease class indices to names.
        stage_class_mapping: dict mapping stage indices to names. Currently
            unused here.

    Returns:
        list of result dicts, one per successfully processed image. Images
        that fail at any step are reported and skipped.
    """
    # NOTE(review): ATTENTION_OUTPUT_DIR is not defined in this chunk; it is
    # presumably set in an earlier cell - confirm it exists before running.
    print("\n" + "="*80)
    print("šŸŽÆ STARTING ATTENTION WEIGHT EXTRACTION AND VISUALIZATION")
    print("="*80)
    print("ā„¹ļø Note: Each image is treated as a single tile")
    print("ā„¹ļø Using patch-level attention from ViT backbone for visualization")

    if not collected_images or not processed_images:
        print("āŒ No images or preprocessed data provided")
        return []

    print(f"āœ… Processing {len(collected_images)} images\n")

    all_results = []

    for idx, (img_path, preprocessed_image) in enumerate(zip(collected_images, processed_images)):
        print(f"\n[{idx+1}/{len(collected_images)}] Processing: {os.path.basename(img_path)}")
        print(f" Path: {img_path}")
        print(f" Image shape: {preprocessed_image.shape}")

        try:
            # Display copy of the original image
            _, img_array = load_image_for_attention(img_path, target_size=768)
            if img_array is None:
                print(" āŒ Failed to load image")
                continue

            # MIL-level attention (a single value for a single tile)
            attention_weights, disease_logits, severity_logits, stage_logits = extract_attention_map(
                attention_model, preprocessed_image, device
            )
            if attention_weights is None:
                print(" āŒ Failed to extract attention")
                continue

            # Patch-level attention from the ViT backbone
            patch_attention = extract_patch_level_attention(
                attention_model.phase3_model.backbone, preprocessed_image, device
            )

            if patch_attention is not None:
                attention_heatmap = create_patch_attention_heatmap(
                    patch_attention, img_array.shape[:2]
                )
                print(" āœ… Created patch-level attention heatmap")
            else:
                # Fallback: uniform heatmap carrying the MIL attention weight
                attention_heatmap = create_uniform_attention_heatmap(
                    attention_weights[0], img_array.shape[:2]
                )
                print(" ā„¹ļø Using uniform attention heatmap")

            # Top-1 disease prediction (row 0 of the batch)
            with torch.no_grad():
                disease_probs = F.softmax(disease_logits, dim=1)
                disease_pred_idx = torch.argmax(disease_probs, dim=1).item()
                disease_confidence = disease_probs[0, disease_pred_idx].item()

            predicted_class_name = main_class_mapping.get(
                disease_pred_idx, f"Unknown_Class_{disease_pred_idx}"
            )

            # Severity is parsed from a "_normal"/"_abnormal" suffix in the
            # class name.
            # NOTE(review): none of the DISEASE_CLASS_MAPPING names carry such
            # a suffix, so this yields "Unknown" for them. severity_logits and
            # stage_logits are returned by the model but not used here.
            if "_normal" in predicted_class_name:
                predicted_disease = predicted_class_name.replace("_normal", "")
                severity_label = "Normal"
            elif "_abnormal" in predicted_class_name:
                predicted_disease = predicted_class_name.replace("_abnormal", "")
                severity_label = "Abnormal"
            else:
                predicted_disease = predicted_class_name
                severity_label = "Unknown"

            print(f" šŸ“Š Prediction: {predicted_class_name}")
            print(f" Confidence: {disease_confidence:.4f}")
            print(f" MIL Attention Weight: {attention_weights[0]:.4f}")

            # Dataset name: the last path component mentioning "processed" or
            # "cancer" before the label folder. True label: the first
            # "normal"/"abnormal" folder, which also stops the scan.
            path_parts = img_path.split(os.sep)
            dataset_name = "Unknown"
            true_label = "unknown"
            for part in path_parts:
                if "processed" in part or "cancer" in part.lower():
                    dataset_name = part
                if part in ["normal", "abnormal"]:
                    true_label = part
                    break

            result = {
                'image': img_array,
                'attention_heatmap': attention_heatmap,
                'mil_attention_weight': attention_weights[0],
                'true_label': true_label,
                'dataset_name': dataset_name,
                'predicted_class': predicted_class_name,
                'predicted_disease': predicted_disease,
                'severity': severity_label,
                'class_idx': disease_pred_idx,
                'confidence': disease_confidence,
                'filename': os.path.basename(img_path),
                'full_path': img_path
            }
            all_results.append(result)

            # Save the raw heatmap for later OpenCV processing
            heatmap_filename = f"attention_heatmap_{idx+1}_{os.path.splitext(os.path.basename(img_path))[0]}.npy"
            heatmap_path = os.path.join(ATTENTION_OUTPUT_DIR, heatmap_filename)
            np.save(heatmap_path, attention_heatmap)
            print(f" āœ… Attention heatmap saved to: {heatmap_filename}")

            print(" āœ… Completed analysis")

        except Exception as e:
            print(f" āŒ Error: {e}")
            import traceback
            traceback.print_exc()
            continue

    print("\n" + "="*80)
    print("āœ… ATTENTION EXTRACTION COMPLETE")
    print(f"šŸ“ Results saved to: {ATTENTION_OUTPUT_DIR}")
    print("="*80)

    return all_results


# -------------------------------------------
# Visualization Function
# -------------------------------------------
def display_attention_grid(results):
    """
    Show each result as a row: original image, heatmap, and overlay.

    Also prints per-image statistics: the share of heatmap pixels above 0.7,
    in (0.4, 0.7], and at or below 0.4. The grid is saved as
    ``attention_visualization_grid.png`` in ``ATTENTION_OUTPUT_DIR``, shown,
    and then closed so repeated calls do not accumulate open figures.

    Args:
        results: list of dicts as returned by ``run_attention_analysis``.
    """
    if not results:
        print("No results to display")
        return

    num_images = len(results)

    # 3 columns: original, heatmap, overlay
    fig, axes = plt.subplots(num_images, 3, figsize=(18, 6 * num_images))

    # plt.subplots returns a 1-D axes array for a single row; keep 2-D indexing.
    if num_images == 1:
        axes = axes.reshape(1, -1)

    # 'jet' colormap: blue (low) to red (high)
    cmap = plt.cm.jet

    for i, result in enumerate(results):
        heatmap = result['attention_heatmap']

        # Column 1: original image
        axes[i, 0].imshow(result['image'])
        title_text = (
            f"Original Image {i+1}\n"
            f"Dataset: {result['dataset_name']}\n"
            f"True Label: {result['true_label']}\n"
            f"File: {result['filename'][:30]}..."
        )
        axes[i, 0].set_title(title_text, fontsize=10, fontweight='bold', pad=10)
        axes[i, 0].axis('off')

        # Column 2: attention heatmap with colorbar
        im = axes[i, 1].imshow(heatmap, cmap=cmap)
        cbar = plt.colorbar(im, ax=axes[i, 1], fraction=0.046, pad=0.04)
        cbar.set_label('Attention Weight', rotation=270, labelpad=15)
        heatmap_title = (
            f"Attention Heatmap {i+1}\n"
            f"Patch-Level Importance\n"
            f"MIL Weight: {result['mil_attention_weight']:.4f}"
        )
        axes[i, 1].set_title(heatmap_title, fontsize=10, fontweight='bold', pad=10)
        axes[i, 1].axis('off')

        # Column 3: heatmap overlaid on the image
        axes[i, 2].imshow(result['image'])
        axes[i, 2].imshow(heatmap, cmap=cmap, alpha=0.5)
        overlay_title = (
            f"Overlay {i+1}\n"
            f"Predicted: {result['predicted_class']}\n"
            f"Confidence: {result['confidence']:.4f}"
        )
        axes[i, 2].set_title(overlay_title, fontsize=10, fontweight='bold', pad=10)
        axes[i, 2].axis('off')

        # Percent of pixels in each attention band
        high_attention = np.sum(heatmap > 0.7) / heatmap.size * 100
        medium_attention = np.sum((heatmap > 0.4) & (heatmap <= 0.7)) / heatmap.size * 100
        low_attention = np.sum(heatmap <= 0.4) / heatmap.size * 100

        print(f"\nšŸ“Š Image {i+1} ({result['filename']}) Attention Statistics:")
        print(f" Predicted: {result['predicted_class']}")
        print(f" Confidence: {result['confidence']:.4f}")
        print(f" MIL Attention Weight: {result['mil_attention_weight']:.4f}")
        print(f" High attention regions (>0.7): {high_attention:.1f}%")
        print(f" Medium attention regions (0.4-0.7): {medium_attention:.1f}%")
        print(f" Low attention regions (<0.4): {low_attention:.1f}%")

    plt.suptitle(
        'Hierarchical Model - Patch-Level Attention Visualization\n'
        'Warmer colors (red/yellow) indicate higher attention | Cooler colors (blue) indicate lower attention',
        fontsize=16, fontweight='bold', y=0.998
    )
    plt.tight_layout()

    grid_save_path = os.path.join(ATTENTION_OUTPUT_DIR, 'attention_visualization_grid.png')
    plt.savefig(grid_save_path, dpi=150, bbox_inches='tight')
    print(f"\nāœ… Grid visualization saved to: {grid_save_path}")

    plt.show()
    # Release the figure; pyplot otherwise keeps every figure alive.
    plt.close(fig)


# -------------------------------------------
# Execute Attention Analysis
# -------------------------------------------
print("\n" + "="*80)
print("CHECKING FOR COLLECTED AND PROCESSED IMAGES")
print("="*80)

try:
    # Variables produced by earlier notebook cells
    if 'collected_images' in locals() or 'collected_images' in globals():
        if 'processed_images' in locals() or 'processed_images' in globals():
            print(f"āœ… Found collected_images: {len(collected_images)} images")
            print(f"āœ… Found processed_images: {len(processed_images)} tensors\n")

            # NOTE(review): STAGE_CLASS_MAPPING is assigned later in this file;
            # it must already exist (e.g. from a previous run) at this point.
            attention_results = run_attention_analysis(
                attention_extractor,
                device,
                collected_images,
                processed_images,
                DISEASE_CLASS_MAPPING,
                STAGE_CLASS_MAPPING
            )

            if attention_results:
                display_attention_grid(attention_results)

                # Per-image summary table
                results_summary = []
                for r in attention_results:
                    high_attn = np.sum(r['attention_heatmap'] > 0.7) / r['attention_heatmap'].size * 100
                    results_summary.append({
                        'filename': r['filename'],
                        'dataset': r['dataset_name'],
                        'true_label': r['true_label'],
                        'predicted_class': r['predicted_class'],
                        'confidence': r['confidence'],
                        'mil_attention_weight': r['mil_attention_weight'],
                        'high_attention_area_%': high_attn,
                        'max_attention': r['attention_heatmap'].max(),
                        'mean_attention': r['attention_heatmap'].mean()
                    })

                summary_df = pd.DataFrame(results_summary)
                summary_path = os.path.join(ATTENTION_OUTPUT_DIR, 'attention_summary.csv')
                summary_df.to_csv(summary_path, index=False)
                print(f"\nāœ… Summary saved to: {summary_path}")

                print("\n" + "="*80)
                print("ATTENTION ANALYSIS SUMMARY")
                print("="*80)
                print(summary_df.to_string(index=False))
                print("="*80)

                # `attention_results` is consumed by the next cell
                print("\n" + "="*80)
                print("ATTENTION RESULTS READY FOR OPENCV PROCESSING")
                print("="*80)
                print(f"āœ… Variable 'attention_results' contains {len(attention_results)} results")
                print("āœ… Each result includes:")
                print(" - Original image")
                print(" - Patch-level attention heatmap (smoothed)")
                print(" - MIL attention weight")
                print(" - Predictions and metadata")
                print("\nšŸ’” Use 'attention_results' in the next cell for OpenCV feature extraction")
                print("="*80)
            else:
                print("\nāŒ No results generated")
        else:
            print("āŒ processed_images not found!")
            print("Please run the preprocessing cell first")
    else:
        print("āŒ collected_images not found!")
        print("Please run the image collection cell first")

except Exception as e:
    print(f"āŒ Error: {e}")
    import traceback
    traceback.print_exc()


import numpy as np
import cv2
from sklearn.cluster import DBSCAN
import matplotlib.pyplot as plt
import os
from scipy.ndimage import maximum_filter
from skimage.feature import graycomatrix, graycoprops
from openai import OpenAI
import time
from scipy.stats import spearmanr


# ================================================================
# HEATMAP FEATURE EXTRACTOR CLASS
# Adapted for Hierarchical Model Pipeline
# ================================================================
class HeatmapFeatureExtractor:
    """
    Derive descriptive features from one attention-analysis result.

    Wraps a single entry of ``attention_results`` and offers four analyses of
    its heatmap:
      1. brightest-region and spatial layout,
      2. scatter (DBSCAN clusters of high-attention pixels),
      3. dominant color of the attended pixels,
      4. GLCM texture of the attended region.
    """

    def __init__(self, attention_result):
        """
        Args:
            attention_result: one entry from ``attention_results``; must
                contain the keys read below.
        """
        self.heatmap = attention_result['attention_heatmap']
        self.original_image = attention_result['image']
        self.prediction_info = {
            'predicted_class': attention_result['predicted_class'],
            'predicted_disease': attention_result['predicted_disease'],
            'severity': attention_result['severity'],
            'confidence': attention_result['confidence'],
            'class_idx': attention_result['class_idx'],
            'mil_attention': attention_result['mil_attention_weight']
        }
        self.true_label = attention_result['true_label']
        self.dataset_name = attention_result['dataset_name']
        self.filename = attention_result['filename']

    # ---------------------------------------------------------
    # METHOD 1: Brightest Region Analysis
    # ---------------------------------------------------------
    def get_brightest_region(self):
        """
        Describe the high-attention regions of the min-max normalized heatmap.

        Returns:
            dict with ``primary_hotspot`` (pixel, position, intensity),
            ``secondary_hotspots`` (list), ``attention_pattern`` (str),
            ``spatial_coverage`` (percent per radial zone) and
            ``hotspot_count`` (1 + number of secondary hotspots).
        """
        heatmap = self.heatmap.astype(float)
        H, W = heatmap.shape
        heatmap_norm = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)

        # 1. Primary hotspot: the global maximum
        y_bright, x_bright = np.unravel_index(np.argmax(heatmap_norm), heatmap_norm.shape)
        intensity_bright = heatmap_norm[y_bright, x_bright]
        position_bright = self._get_anatomical_position(y_bright, x_bright, H, W)

        primary_hotspot = {
            "pixel": (int(y_bright), int(x_bright)),
            "position": position_bright,
            "intensity": float(intensity_bright)
        }

        # 2. Secondary hotspots
        secondary_hotspots = self._find_secondary_hotspots(heatmap_norm, H, W, threshold=0.6)

        # 3. Attention pattern
        attention_pattern = self._determine_attention_pattern(heatmap_norm, H, W)

        # 4. Spatial coverage
        spatial_coverage = self._calculate_spatial_coverage(heatmap_norm, H, W)

        # 5. Hotspot count
        hotspot_count = 1 + len(secondary_hotspots)

        return {
            "primary_hotspot": primary_hotspot,
            "secondary_hotspots": secondary_hotspots,
            "attention_pattern": attention_pattern,
            "spatial_coverage": spatial_coverage,
            "hotspot_count": hotspot_count
        }

    def _get_anatomical_position(self, y, x, H, W):
        """
        Map a pixel to a coarse textual position.

        Examples: "center", "upper-left", "lower-center". The suffix
        " (periphery)" or " (core)" is appended based on the distance from
        the image center, relative to the half-diagonal.
        """
        y_rel = y / H
        x_rel = x / W

        center_y, center_x = H / 2, W / 2
        dist_from_center = np.sqrt((y - center_y)**2 + (x - center_x)**2)
        max_dist = np.sqrt((H/2)**2 + (W/2)**2)
        dist_ratio = dist_from_center / max_dist

        center_threshold_inner = 0.35
        center_threshold_outer = 0.65
        periphery_threshold = 0.75

        if y_rel < center_threshold_inner:
            vert = "upper"
        elif y_rel > center_threshold_outer:
            vert = "lower"
        else:
            vert = "mid"

        if x_rel < center_threshold_inner:
            horiz = "left"
        elif x_rel > center_threshold_outer:
            horiz = "right"
        else:
            horiz = "center"

        if horiz == "center" and vert == "mid":
            position = "center"
        elif horiz == "center":
            position = f"{vert}-center"
        elif vert == "mid":
            position = f"{horiz}-center"
        else:
            position = f"{vert}-{horiz}"

        if dist_ratio > periphery_threshold:
            position = f"{position} (periphery)"
        elif dist_ratio < 0.3:
            position = f"{position} (core)"

        return position

    def _find_secondary_hotspots(self, heatmap_norm, H, W, threshold=0.6, min_distance=20):
        """
        Return up to 4 local maxima other than the primary hotspot.

        Only the 5 strongest local maxima above ``threshold * max`` are
        examined. Peaks closer than ``min_distance`` pixels to the primary
        hotspot or to an already accepted peak are skipped.
        """
        secondary = []

        neighborhood_size = max(10, min(H, W) // 20)
        local_max = maximum_filter(heatmap_norm, size=neighborhood_size)

        peaks = (heatmap_norm == local_max) & (heatmap_norm > threshold * heatmap_norm.max())
        peak_coords = np.argwhere(peaks)
        peak_intensities = heatmap_norm[peaks]

        # Strongest first
        sorted_indices = np.argsort(peak_intensities)[::-1]

        primary_y, primary_x = np.unravel_index(np.argmax(heatmap_norm), heatmap_norm.shape)

        for idx in sorted_indices[:5]:
            y, x = peak_coords[idx]

            if np.sqrt((y - primary_y)**2 + (x - primary_x)**2) < min_distance:
                continue

            too_close = False
            for existing in secondary:
                ey, ex = existing["pixel"]
                if np.sqrt((y - ey)**2 + (x - ex)**2) < min_distance:
                    too_close = True
                    break

            if too_close:
                continue

            position = self._get_anatomical_position(y, x, H, W)
            intensity = float(heatmap_norm[y, x])

            secondary.append({
                "pixel": (int(y), int(x)),
                "position": position,
                "intensity": intensity
            })

        return secondary

    def _determine_attention_pattern(self, heatmap_norm, H, W):
        """
        Classify the overall attention layout from radial-zone means.

        Zones are core (< 0.3), mid (0.3-0.7) and periphery (>= 0.7), as
        fractions of the half-diagonal. The first matching rule wins.
        """
        center_y, center_x = H // 2, W // 2
        Y, X = np.ogrid[:H, :W]
        dist_from_center = np.sqrt((Y - center_y)**2 + (X - center_x)**2)
        max_dist = np.sqrt((H/2)**2 + (W/2)**2)

        core_mask = dist_from_center < (max_dist * 0.3)
        mid_mask = (dist_from_center >= max_dist * 0.3) & (dist_from_center < max_dist * 0.7)
        periphery_mask = dist_from_center >= (max_dist * 0.7)

        core_attention = np.mean(heatmap_norm[core_mask])
        mid_attention = np.mean(heatmap_norm[mid_mask])
        periphery_attention = np.mean(heatmap_norm[periphery_mask])

        # Fraction of pixels with strong attention
        high_attention_pixels = np.sum(heatmap_norm > 0.7) / heatmap_norm.size

        if core_attention > 0.7 and core_attention > mid_attention * 1.5:
            return "centralized (focused on center)"
        elif periphery_attention > 0.7 and periphery_attention > core_attention * 1.5:
            return "peripheral (focused on edges)"
        elif mid_attention > core_attention and mid_attention > periphery_attention:
            return "ring-like (donut pattern)"
        elif high_attention_pixels > 0.5:
            return "diffuse (spread across image)"
        elif 0.1 < high_attention_pixels < 0.3:
            return "focal (single concentrated region)"
        else:
            return "scattered (multiple regions)"

    def _calculate_spatial_coverage(self, heatmap_norm, H, W):
        """
        Return the percent of total attention falling in each radial zone.

        Zones are the same as in ``_determine_attention_pattern``. All values
        are 0 when the heatmap sums to 0.
        """
        center_y, center_x = H // 2, W // 2
        Y, X = np.ogrid[:H, :W]
        dist_from_center = np.sqrt((Y - center_y)**2 + (X - center_x)**2)
        max_dist = np.sqrt((H/2)**2 + (W/2)**2)

        core_mask = dist_from_center < (max_dist * 0.3)
        mid_mask = (dist_from_center >= max_dist * 0.3) & (dist_from_center < max_dist * 0.7)
        periphery_mask = dist_from_center >= (max_dist * 0.7)

        total_attention = np.sum(heatmap_norm)
        core_sum = np.sum(heatmap_norm[core_mask])
        mid_sum = np.sum(heatmap_norm[mid_mask])
        periphery_sum = np.sum(heatmap_norm[periphery_mask])

        return {
            "center_attention": float(core_sum / total_attention * 100) if total_attention > 0 else 0,
            "mid_region_attention": float(mid_sum / total_attention * 100) if total_attention > 0 else 0,
            "periphery_attention": float(periphery_sum / total_attention * 100) if total_attention > 0 else 0
        }

    # ---------------------------------------------------------
    # METHOD 2: Scatter Analysis
    # ---------------------------------------------------------
    def get_activation_scatter(self, threshold_ratio=0.6):
        """
        Cluster high-attention pixels with DBSCAN and rate how scattered they are.

        Returns:
            dict with ``scatter_level`` ("low" for 1 cluster, "medium" for 2-3,
            otherwise "high"), ``num_clusters`` and ``clusters_sizes``. When
            no pixel passes the threshold the level is "low" with 0 clusters.
            When every pixel is DBSCAN noise (0 clusters), the level falls
            through to "high".
        """
        heatmap = self.heatmap.astype(float)
        heatmap_norm = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-6)

        thresh = threshold_ratio * heatmap_norm.max()
        high_pixels = np.argwhere(heatmap_norm >= thresh)

        if len(high_pixels) == 0:
            return {
                "scatter_level": "low",
                "num_clusters": 0,
                "clusters_sizes": []
            }

        clustering = DBSCAN(eps=8, min_samples=20).fit(high_pixels)
        labels = clustering.labels_

        # Label -1 marks DBSCAN noise points
        unique_labels = [lb for lb in np.unique(labels) if lb != -1]
        cluster_sizes = [int(np.sum(labels == lb)) for lb in unique_labels]
        num_clusters = len(unique_labels)

        if num_clusters == 1:
            scatter = "low"
        elif 2 <= num_clusters <= 3:
            scatter = "medium"
        else:
            scatter = "high"

        return {
            "scatter_level": scatter,
            "num_clusters": num_clusters,
            "clusters_sizes": cluster_sizes
        }

    # ---------------------------------------------------------
    # METHOD 3: Dominant Color Analysis
    # ---------------------------------------------------------
    @staticmethod
    def _no_color_result():
        """Fresh placeholder returned when no attended pixels can be colored."""
        return {
            "dominant_color_rgb": None,
            "dominant_color_hsv": None,
            "dominant_color_name": "none",
            "color_confidence": 0.0
        }

    def get_dominant_focus_color(self, threshold_ratio=0.6, k_clusters=5):
        """
        Find the dominant color of the image inside the high-attention region.

        The region is ``heatmap_norm >= threshold_ratio``, closed
        morphologically and hole-filled via its external contours. Its pixels
        are k-means clustered, and the largest cluster with HSV saturation
        > 30 wins (otherwise the largest cluster). Treats
        ``self.original_image`` as RGB, matching the RGB->HSV conversions.

        Returns:
            dict with RGB/HSV lists, a color name and the winning cluster's
            share of pixels (percent), or a "none" placeholder.
        """
        heatmap = self.heatmap.astype(float)
        orig = self.original_image.copy()

        heatmap_norm = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-6)
        mask = (heatmap_norm >= threshold_ratio).astype(np.uint8) * 255

        if np.sum(mask) == 0:
            return self._no_color_result()

        kernel = np.ones((3, 3), np.uint8)
        mask_clean = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel, iterations=2)

        contours, _ = cv2.findContours(mask_clean, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        if len(contours) == 0:
            return self._no_color_result()

        # Filled external contours (holes inside regions are included)
        activation_mask = np.zeros_like(mask_clean)
        cv2.drawContours(activation_mask, contours, -1, 255, -1)

        focus_pixels = orig[activation_mask == 255]

        if len(focus_pixels) < 10:
            return self._no_color_result()

        Z = np.float32(focus_pixels)
        criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 20, 1.0)
        K = k_clusters

        _, labels, centers = cv2.kmeans(
            Z, K, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS
        )

        counts = np.bincount(labels.flatten())
        sorted_indices = np.argsort(counts)[::-1]

        dominant_index = None
        dominant_color = None

        # Prefer the largest cluster that is not near-gray (saturation > 30)
        for idx in sorted_indices:
            candidate_color = centers[idx].astype(int)
            hsv_color = cv2.cvtColor(
                np.uint8([[candidate_color]]), cv2.COLOR_RGB2HSV
            )[0][0]

            if hsv_color[1] > 30:
                dominant_index = idx
                dominant_color = candidate_color
                break

        if dominant_index is None:
            dominant_index = sorted_indices[0]
            dominant_color = centers[dominant_index].astype(int)
            hsv_color = cv2.cvtColor(
                np.uint8([[dominant_color]]), cv2.COLOR_RGB2HSV
            )[0][0]

        color_confidence = counts[dominant_index] / len(labels) * 100

        dominant_name = self._map_color_to_name(
            dominant_color.tolist(), hsv_color.tolist()
        )

        return {
            "dominant_color_rgb": dominant_color.tolist(),
            "dominant_color_hsv": hsv_color.tolist(),
            "dominant_color_name": dominant_name,
            "color_confidence": float(color_confidence)
        }

    def _map_color_to_name(self, rgb, hsv=None):
        """
        Name an RGB color using OpenCV 8-bit HSV.

        In OpenCV 8-bit HSV, H is in 0-179 and S, V are in 0-255. Low
        saturation maps to grays; otherwise hue bands are checked in order.
        Red hues (H <= 10) have their own branch, separate from the
        pink/magenta band (160-180), so saturated reds are not called
        magenta.
        """
        r, g, b = rgb

        if hsv is None:
            hsv_array = cv2.cvtColor(np.uint8([[rgb]]), cv2.COLOR_RGB2HSV)[0][0]
            h, s, v = hsv_array.tolist()
        else:
            h, s, v = hsv

        # Achromatic colors
        if s < 30:
            if v > 200:
                return "white / very light"
            elif v > 150:
                return "light gray / pale"
            elif v > 80:
                return "gray"
            else:
                return "dark gray / black"

        if v < 60:
            return "very dark / black"

        if 130 <= h <= 160:
            if s > 100:
                return "purple / violet"
            else:
                return "light purple / lavender"

        # Pink / magenta at the top of the hue circle
        if 160 <= h <= 180:
            if v > 180 and s < 100:
                return "pink / light red"
            elif s > 150:
                return "magenta / bright pink"
            else:
                return "pink / rose"

        # Red at the bottom of the hue circle; pale reds still read as pink
        if h <= 10:
            if v > 180 and s < 100:
                return "pink / light red"
            elif v < 150:
                return "dark red / maroon"
            else:
                return "red / crimson"

        if 10 <= h < 25:
            if v < 130:
                return "brown / dark tan"
            else:
                return "orange / tan"

        if 25 <= h < 40:
            if s < 80:
                return "beige / cream"
            else:
                return "yellow / golden"

        if 40 <= h < 80:
            if v > 180:
                return "light green / pale green"
            elif v > 120:
                return "green"
            else:
                return "dark green"

        if 80 <= h < 100:
            return "cyan / turquoise"

        if 100 <= h < 130:
            if s > 150:
                return "blue / deep blue"
            elif v > 180:
                return "light blue / sky blue"
            else:
                return "blue"

        # Fallback by strongest channel
        max_channel = max(r, g, b)
        if max_channel == r:
            return "reddish tones"
        elif max_channel == g:
            return "greenish tones"
        elif max_channel == b:
            return "bluish tones"
        else:
            return "mixed color region"

    # ---------------------------------------------------------
    # METHOD 4: Texture Analysis
    # ---------------------------------------------------------
    @staticmethod
    def _texture_placeholder(classification, description):
        """Fresh zero-valued texture result used when analysis cannot run."""
        return {
            "texture_classification": classification,
            "texture_description": description,
            "texture_scores": {
                "uniformity": 0,
                "organization": 0,
                "complexity": 0,
                "smoothness": 0
            },
            "glcm_features": {
                "contrast": 0.0,
                "correlation": 0.0,
                "energy": 0.0,
                "homogeneity": 0.0
            }
        }

    def get_texture_analysis(self, threshold_ratio=0.6, min_region_pixels=100):
        """
        Describe the texture of the high-attention region using a GLCM.

        The attended region's bounding box is converted to grayscale, and
        pixels outside the region are set to 0. The result is quantized to
        64 levels, and the GLCM (distance 1, four angles, symmetric,
        normalized) is summarized by contrast, correlation, energy and
        homogeneity averaged over angles.
        NOTE(review): the zeroed out-of-region pixels still enter the GLCM
        and can bias the features.

        Args:
            threshold_ratio: normalized-heatmap threshold defining the region.
            min_region_pixels: minimum number of thresholded pixels required
                before analysis is attempted.

        Returns:
            dict with classification, description, 0-100 scores and the raw
            GLCM features, or a zero-valued placeholder on failure.
        """
        heatmap = self.heatmap.astype(float)
        orig = self.original_image.copy()

        heatmap_norm = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-6)
        mask = (heatmap_norm >= threshold_ratio).astype(np.uint8) * 255

        # Count pixels: summing a 0/255 mask would only trip at zero pixels.
        if np.count_nonzero(mask) < min_region_pixels:
            return self._texture_placeholder(
                "insufficient data",
                "Not enough attention data to analyze texture"
            )

        kernel = np.ones((3, 3), np.uint8)
        mask_clean = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel, iterations=2)

        contours, _ = cv2.findContours(mask_clean, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        if len(contours) == 0:
            return self._texture_placeholder(
                "no region detected",
                "No focused region detected"
            )

        activation_mask = np.zeros_like(mask_clean)
        cv2.drawContours(activation_mask, contours, -1, 255, -1)

        y_coords, x_coords = np.where(activation_mask == 255)

        if len(y_coords) == 0:
            return self._texture_placeholder(
                "invalid region",
                "Invalid region for texture analysis"
            )

        y_min, y_max = y_coords.min(), y_coords.max()
        x_min, x_max = x_coords.min(), x_coords.max()

        region_rgb = orig[y_min:y_max+1, x_min:x_max+1]
        region_mask = activation_mask[y_min:y_max+1, x_min:x_max+1]

        region_gray = cv2.cvtColor(region_rgb, cv2.COLOR_RGB2GRAY)

        # Zero out pixels outside the attended region
        region_gray_masked = region_gray.copy()
        region_gray_masked[region_mask == 0] = 0

        # 256 -> 64 gray levels to keep the GLCM small
        region_quantized = (region_gray_masked / 4).astype(np.uint8)

        # Immediate neighbors at 0, 45, 90, 135 degrees
        distances = [1]
        angles = [0, np.pi/4, np.pi/2, 3*np.pi/4]

        try:
            glcm = graycomatrix(
                region_quantized,
                distances=distances,
                angles=angles,
                levels=64,
                symmetric=True,
                normed=True
            )

            # Average each property over the four angles
            contrast = float(graycoprops(glcm, 'contrast')[0].mean())
            correlation = float(graycoprops(glcm, 'correlation')[0].mean())
            energy = float(graycoprops(glcm, 'energy')[0].mean())
            homogeneity = float(graycoprops(glcm, 'homogeneity')[0].mean())

        except Exception as e:
            print(f" Warning: GLCM computation failed: {e}")
            return self._texture_placeholder(
                "computation error",
                "Error computing texture features"
            )

        uniformity_score = int(energy * 100)
        # Correlation lies in [-1, 1]; map to [0, 100]
        organization_score = int(max(0, min(100, (correlation + 1) * 50)))
        complexity_score = int((1 - energy) * 100)
        smoothness_score = int(homogeneity * 100)

        classification, description = self._classify_texture(
            contrast, correlation, energy, homogeneity
        )

        return {
            "texture_classification": classification,
            "texture_description": description,
            "texture_scores": {
                "uniformity": uniformity_score,
                "organization": organization_score,
                "complexity": complexity_score,
                "smoothness": smoothness_score
            },
            "glcm_features": {
                "contrast": round(contrast, 2),
                "correlation": round(correlation, 3),
                "energy": round(energy, 3),
                "homogeneity": round(homogeneity, 3)
            }
        }

    def _classify_texture(self, contrast, correlation, energy, homogeneity):
        """
        Rule-based texture label from GLCM features; the first match wins.

        Returns:
            ``(classification, description)`` tuple.
        """
        if contrast < 100 and homogeneity > 0.8:
            classification = "uniform and smooth"
            description = "Model focused on a region with smooth, uniform texture showing consistent patterns with minimal variation"
        elif correlation > 0.7 and energy > 0.3:
            classification = "structured and regular"
            description = "Model focused on a region with organized, structured patterns exhibiting regular, repeating elements"
        elif contrast > 400 and correlation < 0.4:
            classification = "irregular and chaotic"
            description = "Model focused on a region with irregular, chaotic texture displaying highly variable patterns with no clear organization"
        elif homogeneity < 0.5:
            classification = "rough and coarse"
            description = "Model focused on a region with rough, coarse texture showing sharp intensity changes and abrupt transitions"
        elif energy < 0.2:
            classification = "complex and varied"
            description = "Model focused on a region with complex, varied texture containing multiple different patterns and high visual diversity"
        else:
            classification = "moderate texture"
            description = "Model focused on a region with moderate texture complexity showing intermediate characteristics"

        return classification, description


import warnings
warnings.filterwarnings("ignore")

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image
import os
import numpy as np
import pandas as pd
from transformers import Dinov2Model
import openslide
from tqdm import tqdm
import logging
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec
from openai import OpenAI
import time

logging.basicConfig(level=logging.INFO)

MODEL_PATH = "/kaggle/input/models/ulimaank/updated-diagnostic-model-jan-18/other/default/1/phase3_mil_best.pth"
OUTPUT_DIR = "/kaggle/working"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}, GPUs: {torch.cuda.device_count()}")

# Disease class names; index i corresponds to DISEASE_CLASS_MAPPING[i].
DISEASE_NAMES = [
    'Breast_cancer',
    'annrbc-anemia_processed',
    'colon_processed',
    'leukemia_processed',
    'lung_processed',
    'oral-cancer_processed',
    'ovarian-cancer_processed',
    'sickle-cell-new_processed',
    'thalassemia_processed',
]

# Stage/subtype index -> "<disease> - <subtype>" label. predict_image() uses this
# to name the argmax of the stage head.
STAGE_NAMES = {
    0: 'Breast_cancer - ductal_carcinoma',
    1: 'Breast_cancer - lobular_carcinoma',
    2: 'Breast_cancer - mucinous_carcinoma',
    3: 'Breast_cancer - papillary_carcinoma',
    4: 'leukemia_processed - Early',
    5: 'leukemia_processed - Pre',
    6: 'leukemia_processed - Pro',
    7: 'lung_processed - lung_aca',
    8: 'lung_processed - lung_scc',
    9: 'ovarian-cancer_processed - CC',
    10: 'ovarian-cancer_processed - EC',
    11: 'ovarian-cancer_processed - HGSC',
    12: 'ovarian-cancer_processed - LGSC',
    13: 'ovarian-cancer_processed - MC'
}

# Derived from DISEASE_NAMES so the index->name mapping cannot drift from the list.
DISEASE_CLASS_MAPPING = dict(enumerate(DISEASE_NAMES))
STAGE_CLASS_MAPPING = STAGE_NAMES

# Every extracted tile (or whole image) is resized to TARGET_SIZE x TARGET_SIZE
# before test_transform is applied in simple_collate().
TARGET_SIZE = 256
standardize_transform = transforms.Resize((TARGET_SIZE, TARGET_SIZE))


# ================================================================
# DATA COLLECTION
# ================================================================
def collect_images_from_folder(folder_path):
    """Recursively collect slide/image file paths under ``folder_path``.

    Extension matching is case-insensitive. Paths are returned in
    ``os.walk`` traversal order.

    Args:
        folder_path: Root directory to search.

    Returns:
        List of full file paths with a supported extension.
    """
    valid_extensions = ('.svs', '.tif', '.ndpi', '.png', '.jpg', '.jpeg', '.tiff')
    return [
        os.path.join(root, name)
        for root, _dirs, files in os.walk(folder_path)
        for name in files
        if name.lower().endswith(valid_extensions)
    ]


# ================================================================
# DATASET
# ================================================================
class SimpleSlideDataset(Dataset):
    """Dataset yielding ``(tiles, slide_path)`` for each input path.

    ``.svs``/``.ndpi`` files, and ``.tif`` files that OpenSlide can read, are cut
    into a row-major grid of ``tile_size`` x ``tile_size`` level-0 regions. Tiling
    stops after ``max_tiles`` tiles. Every other file, or a ``.tif`` that OpenSlide
    fails on, is loaded with PIL as a single image. Each tile is passed through
    ``standardize_transform``.

    On any failure the error is logged and ``([], slide_path)`` is returned.
    ``simple_collate`` drops such samples.
    """

    def __init__(self, image_paths, tile_size=224, max_tiles=1000):
        self.image_paths = image_paths
        self.tile_size = tile_size
        self.max_tiles = max_tiles

    def __len__(self):
        return len(self.image_paths)

    def _read_slide_tiles(self, slide_path):
        """Read up to ``max_tiles`` grid tiles with OpenSlide.

        The slide handle is closed even if opening a region raises.
        """
        tiles = []
        slide = openslide.OpenSlide(slide_path)
        try:
            width, height = slide.dimensions
            for y in range(0, height, self.tile_size):
                for x in range(0, width, self.tile_size):
                    if len(tiles) >= self.max_tiles:
                        return tiles
                    region = slide.read_region((x, y), 0, (self.tile_size, self.tile_size))
                    tiles.append(standardize_transform(region.convert('RGB')))
        finally:
            slide.close()
        return tiles

    @staticmethod
    def _read_whole_image(image_path):
        """Load ``image_path`` with PIL as a single resized RGB tile (file closed afterwards)."""
        with Image.open(image_path) as img:
            return [standardize_transform(img.convert('RGB'))]

    def __getitem__(self, idx):
        slide_path = self.image_paths[idx]
        lowered = slide_path.lower()
        try:
            if lowered.endswith(('.svs', '.ndpi')):
                tiles = self._read_slide_tiles(slide_path)
            elif lowered.endswith('.tif'):
                # A .tif may be a pyramidal WSI or a plain TIFF. If OpenSlide cannot
                # open or read it (OpenSlideError or anything else), treat it as an
                # ordinary image.
                try:
                    tiles = self._read_slide_tiles(slide_path)
                except Exception:
                    tiles = self._read_whole_image(slide_path)
            else:
                tiles = self._read_whole_image(slide_path)

            if not tiles:
                raise ValueError("No tiles extracted")
            return tiles, slide_path
        except Exception as e:
            logging.error(f"Error processing slide {slide_path}: {e}")
            return [], slide_path


# Test-time preprocessing applied to each standardized tile: resize to 224,
# convert to tensor, normalize with ImageNet mean/std.
test_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


def simple_collate(batch):
    """Collate ``(tiles, path)`` samples and drop samples that have no tiles.

    Returns:
        ``(processed_tiles, paths)``, where ``processed_tiles[i]`` is
        ``torch.stack`` of ``test_transform`` applied to every tile of
        ``paths[i]``. Returns ``([], [])`` when no sample in the batch has tiles.
    """
    valid_batch = [item for item in batch if item[0]]
    if not valid_batch:
        return [], []
    tiles_list, paths = zip(*valid_batch)
    processed_tiles = [torch.stack([test_transform(tile) for tile in tiles]) for tiles in tiles_list]
    return processed_tiles, list(paths)


# ================================================================
# MODEL ARCHITECTURE
================================================================ class ViTBackbone(nn.Module): def __init__(self): super().__init__() self.vit = Dinov2Model.from_pretrained("owkin/phikon-v2") def forward(self, x): return self.vit(pixel_values=x).last_hidden_state[:, 0] class ClassificationHead(nn.Module): def __init__(self, in_dim=1024, num_classes=2, hidden_dim=512): super().__init__() self.classifier = nn.Sequential( nn.Linear(in_dim, hidden_dim), nn.ReLU(), nn.Dropout(0.3), nn.Linear(hidden_dim, num_classes) ) def forward(self, x): return self.classifier(x) class HierarchicalMILAggregator(nn.Module): def __init__(self, embed_dim=1024, num_heads=8, num_layers=2, num_diseases=6, num_stage_classes=0, disease_names=None): super().__init__() self.pre_norm = nn.LayerNorm(embed_dim) encoder_layer = nn.TransformerEncoderLayer( d_model=embed_dim, nhead=num_heads, batch_first=True, dropout=0.1) self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) self.attention = nn.Sequential(nn.Linear(embed_dim, 256), nn.Tanh(), nn.Linear(256, 1)) self.disease_head = ClassificationHead(embed_dim, num_diseases) self.severity_heads = nn.ModuleDict() for name in disease_names: self.severity_heads[name] = ClassificationHead(embed_dim, 2) self.stage_head = ClassificationHead(embed_dim, num_stage_classes) if num_stage_classes > 0 else None self.disease_name_to_idx = {n: i for i, n in enumerate(disease_names)} self.idx_to_disease_name = {i: n for n, i in self.disease_name_to_idx.items()} self.disease_names = disease_names def forward(self, tile_features): normalized = self.pre_norm(tile_features) aggregated = self.transformer(normalized) attn_scores = self.attention(aggregated) attn_weights = torch.softmax(attn_scores.squeeze(-1), dim=1) weighted = torch.sum(aggregated * attn_weights.unsqueeze(-1), dim=1) disease_logits = self.disease_head(weighted) severity_logits = {n: self.severity_heads[n](weighted) for n in self.disease_names} stage_logits = 
self.stage_head(weighted) if self.stage_head is not None else None return disease_logits, severity_logits, stage_logits, attn_weights class Phase3Model(nn.Module): def __init__(self, backbone, num_diseases=6, num_stage_classes=0, disease_names=None): super().__init__() self.backbone = backbone for param in self.backbone.parameters(): param.requires_grad = False self.aggregator = HierarchicalMILAggregator( num_diseases=num_diseases, num_stage_classes=num_stage_classes, disease_names=disease_names ) def forward(self, tiles, enable_gradients=False): all_features = [] for batch_tiles in tiles: if batch_tiles.numel() == 0: continue batch_tiles = batch_tiles.to(next(self.backbone.parameters()).device) if enable_gradients: batch_features = self.backbone(batch_tiles) else: with torch.no_grad(): batch_features = self.backbone(batch_tiles) all_features.append(batch_features) if not all_features: raise ValueError("No valid tile features could be extracted.") all_features = torch.stack(all_features) return self.aggregator(all_features) # ================================================================ # PREDICTION # ================================================================ def predict_image(model, tiles, disease_names, stage_names): model.eval() try: with torch.no_grad(): disease_logits, severity_logits, stage_logits, _ = model(tiles) disease_probs = F.softmax(disease_logits, dim=1) disease_pred_idx = torch.argmax(disease_probs, dim=1).item() disease_confidence = disease_probs[0, disease_pred_idx].item() predicted_disease = disease_names[disease_pred_idx] severity_probs = F.softmax(severity_logits[predicted_disease], dim=1) severity_pred = torch.argmax(severity_probs, dim=1).item() severity_confidence = severity_probs[0, severity_pred].item() severity_label = "Normal" if severity_pred == 0 else "Abnormal" stage_label = "N/A" stage_confidence = 0.0 if severity_pred == 1 and stage_logits is not None: stage_probs = F.softmax(stage_logits, dim=1) stage_pred_idx = 
torch.argmax(stage_probs, dim=1).item() stage_confidence = stage_probs[0, stage_pred_idx].item() stage_label = stage_names.get(stage_pred_idx, f"Stage_{stage_pred_idx}") return { 'disease': predicted_disease, 'disease_confidence': disease_confidence, 'severity': severity_label, 'severity_confidence': severity_confidence, 'stage': stage_label, 'stage_confidence': stage_confidence } except Exception as e: logging.error(f"Error during prediction: {e}") return None # ================================================================ # GPT EXPLANATION GENERATOR # ================================================================ def generate_comprehensive_explanation(comprehensive_data): """ Calls GPT-4o-mini to convert technical XAI metrics into a human-friendly explanation. Falls back to a template string if the API call fails. Only references Attention and GradCAM heatmaps. """ try: client = OpenAI(api_key=os.getenv('OPENAI_API_KEY')) prompt = f"""You are an AI explainability assistant helping users understand how a hierarchical medical image classification model made its decision. Convert the following technical analysis into a clear, accessible explanation. 
HIERARCHICAL MODEL PREDICTION: - Region: {comprehensive_data['predicted_disease']} ({comprehensive_data['gradcam_disease_conf']:.1%} confidence) - Status Level: {comprehensive_data['gradcam_severity']} ({comprehensive_data['gradcam_severity_conf']:.1%} confidence) - Stage Level: {comprehensive_data['predicted_stage']} ({comprehensive_data['stage_confidence']:.1%} confidence) GRADCAM ANALYSIS (Gradient-weighted Class Activation Mapping): - Note: Bright/warm regions in GradCAM indicate areas that most strongly influenced the model's prediction SPATIAL ATTENTION PATTERN AND VISUAL CHARACTERISTICS (from Attention Heatmap): - Primary Focus: {comprehensive_data['primary_position']} (intensity: {comprehensive_data['primary_intensity']:.2f}) - Attention Hotspots: {comprehensive_data['hotspot_count']} - Spatial Distribution: Center {comprehensive_data['center_attention']:.1f}%, Mid-region {comprehensive_data['mid_attention']:.1f}%, Periphery {comprehensive_data['periphery_attention']:.1f}% - Clustering: {comprehensive_data['scatter_level']} scatter level with {comprehensive_data['num_clusters']} clusters - Dominant Color: {comprehensive_data['dominant_color']} ({comprehensive_data['color_confidence']:.1f}% confidence) - Texture Pattern: {comprehensive_data['texture_classification']} - Texture Scores: Uniformity {comprehensive_data['uniformity']}/100, Organization {comprehensive_data['organization']}/100, Complexity {comprehensive_data['complexity']}/100, Smoothness {comprehensive_data['smoothness']}/100 CRITICAL INSTRUCTIONS: 1. Write in clear, accessible language for someone without medical or technical expertise 2. Ground ALL statements in the provided data - do NOT add medical interpretations or diagnoses 3. Explain how the two explainability methods (Attention Heatmap and GradCAM) show WHERE the model focused 4. Describe WHAT visual patterns were detected, not WHY medically 5. Keep it concise but informative (under 100 words) 6. Structure with clear sections 7. 
Make it conversational but professional 8. Visual Characteristics and Spatial Attention Pattern were taken from Attention Heatmap Generate a comprehensive explanation covering: what the model decided, where it looked, what the Attention and GradCAM methods revealed, what visual characteristics were important, and how confident we can be in the decision. Format as natural paragraphs, not bullet points.""" response = client.chat.completions.create( model="gpt-4o-mini", messages=[ {"role": "system", "content": "You are an expert at explaining complex AI model decisions in simple, clear language. You help users understand model behavior without making medical claims."}, {"role": "user", "content": prompt} ], temperature=0.7, max_tokens=700 ) return response.choices[0].message.content.strip() except Exception as e: logging.warning(f"OpenAI API call failed: {e}. Using fallback template.") return ( f"MODEL DECISION SUMMARY\n\n" f"The model classified this as '{comprehensive_data['predicted_disease']}' " f"with severity '{comprehensive_data['gradcam_severity']}' " f"and stage '{comprehensive_data['predicted_stage']}' " f"({comprehensive_data['stage_confidence']:.1%} confidence).\n\n" f"ATTENTION ANALYSIS\n\n" f"Primary focus: {comprehensive_data['primary_position']} region. " f"Attention shows {comprehensive_data['scatter_level']} scatter across " f"{comprehensive_data['num_clusters']} clusters. " f"Distribution - Center: {comprehensive_data['center_attention']:.1f}%, " f"Mid: {comprehensive_data['mid_attention']:.1f}%, " f"Periphery: {comprehensive_data['periphery_attention']:.1f}%.\n\n" f"GRADCAM ANALYSIS\n\n" f"GradCAM confidence: {comprehensive_data['gradcam_disease_conf']:.1%} (disease), " f"{comprehensive_data['gradcam_severity_conf']:.1%} (severity).\n\n" f"VISUAL PATTERNS\n\n" f"Dominant color: {comprehensive_data['dominant_color']}. 
" f"Texture: {comprehensive_data['texture_classification']} " f"(uniformity {comprehensive_data['uniformity']}/100, " f"smoothness {comprehensive_data['smoothness']}/100)." ) # ================================================================ # RENDER EXPLANATION TEXT -> RGB NUMPY ARRAY FOR ROW 3 # ================================================================ def _render_explanation_to_image(explanation_text, figsize=(16, 4)): """ Renders a plain-text explanation string into an (H, W, 3) uint8 numpy array that fills the entire Row 3 panel in display_prediction(). """ fig, ax = plt.subplots(figsize=figsize, facecolor='#0F0F2A') fig.subplots_adjust(left=0, right=1, top=1, bottom=0) ax.set_facecolor('#0F0F2A') ax.set_xlim(0, 1) ax.set_ylim(0, 1) ax.axis('off') ax.add_patch(mpatches.FancyBboxPatch( (0.0, 0.0), 1.0, 1.0, boxstyle="round,pad=0.01", linewidth=2, edgecolor='#3498DB', facecolor='#16213E', transform=ax.transAxes, clip_on=False )) ax.text( 0.5, 0.93, 'Textual Explanation', ha='center', va='top', fontsize=11, fontweight='bold', color='#3498DB', transform=ax.transAxes ) ax.add_line(plt.Line2D( [0.02, 0.98], [0.855, 0.855], transform=ax.transAxes, color='#3498DB', linewidth=1.0 )) ax.text( 0.02, 0.83, explanation_text, va='top', ha='left', fontsize=10, color='#E0E0F0', family='monospace', wrap=True, transform=ax.transAxes ) fig.canvas.draw() buf = fig.canvas.buffer_rgba() img_array = np.frombuffer(buf, dtype=np.uint8).reshape( fig.canvas.get_width_height()[::-1] + (4,) ) plt.close(fig) return img_array[:, :, :3] # ================================================================ # ATTENTION OVERLAY HELPERS # ================================================================ def _build_attention_overlay(img_array, heatmap_raw): """ Takes img_array (H, W, 3) uint8 and heatmap_raw (H, W) float, returns an RGB overlay as np.ndarray in [0, 1]. 
""" import cv2 img_norm = img_array.astype(np.float32) / 255.0 hm = heatmap_raw.astype(np.float32) h, w = img_norm.shape[:2] if hm.shape != (h, w): hm = cv2.resize(hm, (w, h), interpolation=cv2.INTER_CUBIC) hm_min, hm_max = hm.min(), hm.max() if hm_max > hm_min: hm = (hm - hm_min) / (hm_max - hm_min) cmap = matplotlib.colormaps.get_cmap('jet') hm_colored = cmap(hm)[:, :, :3] overlay = img_norm * 0.5 + hm_colored * 0.5 return np.clip(overlay, 0, 1) def _preprocess_images_for_attention(image_paths): processed = [] for p in image_paths: try: img_pil = Image.open(p).convert('RGB').resize((224, 224), Image.BILINEAR) tensor = test_transform(img_pil) processed.append(tensor) except Exception as e: logging.warning(f"Could not preprocess {p} for attention: {e}") processed.append(torch.zeros(3, 224, 224)) return processed # ================================================================ # DISPLAY FUNCTION - 3-ROW LAYOUT # Row 1 : Original Image | Diagnostic Report # Row 2 : Attention Heatmap | GradCAM Heatmap # Row 3 : GPT-Generated Human-Friendly Text Explanation # ================================================================ def display_prediction(image_path, prediction, heatmap_images=None, heatmap_titles=None, explanation_image=None): severity = prediction['severity'] accent = '#E74C3C' if severity == 'Abnormal' else '#2ECC71' bg_color = '#1A1A2E' panel_color = '#16213E' border_dim = '#2A2A4A' # Exactly 2 heatmaps: Attention + GradCAM if heatmap_titles is None: heatmap_titles = ['Attention Heatmap', 'GradCAM Heatmap'] fig = plt.figure(figsize=(16, 14), facecolor=bg_color) outer_gs = GridSpec(3, 1, figure=fig, height_ratios=[5, 4, 3], hspace=0.08) # ===== ROW 1: Original Image | Diagnostic Report ===== row1_gs = GridSpecFromSubplotSpec(1, 2, subplot_spec=outer_gs[0], width_ratios=[1, 1.2], wspace=0.05) ax_img = fig.add_subplot(row1_gs[0]) ax_img.set_facecolor(bg_color) try: img = Image.open(image_path).convert('RGB') ax_img.imshow(img) except Exception: 
ax_img.text(0.5, 0.5, 'WSI / Slide\n(preview unavailable)', ha='center', va='center', color='white', fontsize=13, transform=ax_img.transAxes) for spine in ax_img.spines.values(): spine.set_edgecolor(accent) spine.set_linewidth(3) ax_img.set_xticks([]) ax_img.set_yticks([]) ax_img.set_title(os.path.basename(image_path), color='white', fontsize=11, pad=8, fontweight='bold') ax_info = fig.add_subplot(row1_gs[1]) ax_info.set_facecolor(bg_color) ax_info.set_xlim(0, 1) ax_info.set_ylim(0, 1) ax_info.axis('off') ax_info.text(0.5, 0.96, 'Diagnostic Report', ha='center', va='top', fontsize=15, fontweight='bold', color='white', transform=ax_info.transAxes) divider = plt.Line2D([0.05, 0.95], [0.89, 0.89], transform=ax_info.transAxes, color=accent, linewidth=1.5) ax_info.add_line(divider) def draw_card(ax, y, label, value, confidence, color): ax.add_patch(mpatches.FancyBboxPatch( (0.04, y - 0.11), 0.92, 0.14, boxstyle="round,pad=0.01", linewidth=1.5, edgecolor=color, facecolor=panel_color, transform=ax.transAxes, clip_on=False )) ax.text(0.10, y - 0.01, label.upper(), ha='left', va='center', fontsize=8, color='#A0A0C0', fontweight='bold', transform=ax.transAxes) ax.text(0.10, y - 0.05, value, ha='left', va='center', fontsize=13, color='white', fontweight='bold', transform=ax.transAxes) if confidence > 0: bar_y = y - 0.09 ax.add_patch(mpatches.FancyBboxPatch( (0.08, bar_y), 0.60, 0.015, boxstyle="round,pad=0.001", linewidth=0, facecolor='#0F3460', transform=ax.transAxes, clip_on=False )) ax.add_patch(mpatches.FancyBboxPatch( (0.08, bar_y), 0.60 * confidence, 0.015, boxstyle="round,pad=0.001", linewidth=0, facecolor=color, transform=ax.transAxes, clip_on=False )) ax.text(0.72, bar_y + 0.007, f'{confidence:.1%}', ha='left', va='center', fontsize=9, color=color, fontweight='bold', transform=ax.transAxes) draw_card(ax_info, 0.78, 'Region', prediction['disease'].replace('_processed', '').replace('_', ' ').title(), prediction['disease_confidence'], '#3498DB') draw_card(ax_info, 0.57, 
'Status', severity, prediction['severity_confidence'], accent) stage_val = prediction['stage'] stage_conf = prediction['stage_confidence'] if stage_val == 'N/A': stage_display = 'N/A (Normal)' stage_conf = 0 else: stage_display = (stage_val.split(' - ')[-1].replace('_', ' ').title() if ' - ' in stage_val else stage_val) draw_card(ax_info, 0.36, 'Stage / Subtype', stage_display, stage_conf, '#F39C12') # ===== ROW 2: Attention Heatmap | GradCAM Heatmap (2 columns) ===== row2_gs = GridSpecFromSubplotSpec(1, 2, subplot_spec=outer_gs[1], wspace=0.06) for col_idx in range(2): ax_hm = fig.add_subplot(row2_gs[col_idx]) ax_hm.set_facecolor(panel_color) if heatmap_images and col_idx < len(heatmap_images) and heatmap_images[col_idx] is not None: hm = heatmap_images[col_idx] ax_hm.imshow(hm if isinstance(hm, np.ndarray) else np.array(hm)) else: ax_hm.set_xlim(0, 1) ax_hm.set_ylim(0, 1) ax_hm.add_patch(mpatches.FancyBboxPatch( (0.05, 0.05), 0.90, 0.90, boxstyle="round,pad=0.02", linewidth=1.5, linestyle='--', edgecolor='#4A4A6A', facecolor='#0F0F2A', transform=ax_hm.transAxes, clip_on=False )) ax_hm.text(0.5, 0.5, '[ Heatmap\nPlaceholder ]', ha='center', va='center', color='#4A4A6A', fontsize=9, fontstyle='italic', transform=ax_hm.transAxes) title = heatmap_titles[col_idx] if col_idx < len(heatmap_titles) else f'Heatmap {col_idx+1}' ax_hm.set_title(title, color='#A0A0C0', fontsize=9, fontweight='bold', pad=5) for spine in ax_hm.spines.values(): spine.set_edgecolor(border_dim) spine.set_linewidth(1.2) ax_hm.set_xticks([]) ax_hm.set_yticks([]) # ===== ROW 3: GPT Explanation ===== ax_text = fig.add_subplot(outer_gs[2]) ax_text.set_facecolor(panel_color) if explanation_image is not None: exp_img = explanation_image if isinstance(explanation_image, np.ndarray) \ else np.array(explanation_image) ax_text.imshow(exp_img, aspect='auto') ax_text.set_xticks([]) ax_text.set_yticks([]) else: ax_text.set_xlim(0, 1) ax_text.set_ylim(0, 1) ax_text.axis('off') 
ax_text.add_patch(mpatches.FancyBboxPatch( (0.01, 0.05), 0.98, 0.90, boxstyle="round,pad=0.02", linewidth=1.5, linestyle='--', edgecolor='#4A4A6A', facecolor='#0F0F2A', transform=ax_text.transAxes, clip_on=False )) ax_text.text(0.5, 0.80, 'Human-Friendly Text Explanation', ha='center', va='center', color='#A0A0C0', fontsize=11, fontweight='bold', transform=ax_text.transAxes) ax_text.text(0.5, 0.42, '[ Textual Explanation Placeholder ]\n\n' 'The model focused on [extracted features]\n' 'because [rule-based reasoning] ...', ha='center', va='center', color='#3A3A5A', fontsize=10, fontstyle='italic', transform=ax_text.transAxes) for spine in ax_text.spines.values(): spine.set_edgecolor(border_dim) spine.set_linewidth(1.2) plt.tight_layout(pad=1.2) plt.show() print() # ================================================================ # MODEL LOADER # ================================================================ def load_model(): if not os.path.exists(MODEL_PATH): print(f"ERROR: Model not found at {MODEL_PATH}") return None, None checkpoint = torch.load(MODEL_PATH, map_location=device) state_dict = checkpoint['model_state_dict'] if list(state_dict.keys())[0].startswith('module.'): state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()} severity_head_names = set() for k in state_dict.keys(): if k.startswith('aggregator.severity_heads.'): parts = k.split('.') if len(parts) > 2: severity_head_names.add(parts[2]) disease_names = sorted(list(severity_head_names)) num_diseases = state_dict['aggregator.disease_head.classifier.3.weight'].shape[0] num_stage_classes = ( state_dict['aggregator.stage_head.classifier.3.weight'].shape[0] if 'aggregator.stage_head.classifier.3.weight' in state_dict else 0 ) print(f"\nModel config -> diseases: {num_diseases} | stages: {num_stage_classes}") print(f"Classes: {', '.join(disease_names)}\n") backbone = ViTBackbone() model = Phase3Model(backbone, num_diseases=num_diseases, num_stage_classes=num_stage_classes, 
disease_names=disease_names).to(device) model.load_state_dict(state_dict, strict=True) model.eval() print("Model loaded successfully!\n") return model, disease_names # ================================================================ # MAIN INFERENCE PIPELINE # ================================================================ def run_inference(): print("\n" + "=" * 70) print(" HIERARCHICAL MIL MODEL - PATHOLOGY INFERENCE") print("=" * 70) print("\nOptions:") print(" 1. Single image (provide full path to one image file)") print(" 2. Folder (provide path to a folder; all images processed)") choice = input("\nSelect option (1 / 2): ").strip() if choice == '1': image_path = input("Enter image path: ").strip() if not os.path.isfile(image_path): print(f"ERROR: File not found -> {image_path}") return all_images = [image_path] elif choice == '2': folder_path = input("Enter folder path: ").strip() if not os.path.isdir(folder_path): print(f"ERROR: Folder not found -> {folder_path}") return all_images = collect_images_from_folder(folder_path) if not all_images: print("No valid images found in the folder.") return print(f"Found {len(all_images)} image(s).") else: print("Invalid option.") return # ---------------------------------------------------------- # Load model # ---------------------------------------------------------- print("\nLoading model ...") model, disease_names = load_model() if model is None: return # ---------------------------------------------------------- # STEP 1: Run Attention and GradCAM analyses only # ---------------------------------------------------------- print("\n" + "=" * 70) print("Running Attention Analysis ...") print("=" * 70) processed_images = _preprocess_images_for_attention(all_images) attn_results = run_attention_analysis( attention_extractor, device, all_images, processed_images, DISEASE_CLASS_MAPPING, STAGE_CLASS_MAPPING ) print("\n" + "=" * 70) print("Running GradCAM Analysis ...") print("=" * 70) gradcam_results = 
run_tri_head_gradcam_plus_plus_analysis(model, device, all_images) # ---------------------------------------------------------- # STEP 2: Feature extraction from attention heatmaps # ---------------------------------------------------------- print("\n" + "=" * 70) print("Running Feature Extraction ...") print("=" * 70) explanations_list = [] for i, attention_result in enumerate(attn_results): print(f" [{i+1}/{len(attn_results)}] Extracting features: {attention_result['filename']}") extractor = HeatmapFeatureExtractor(attention_result) bright = extractor.get_brightest_region() scatter = extractor.get_activation_scatter() dom_color = extractor.get_dominant_focus_color() texture = extractor.get_texture_analysis() explanations_list.append({ "brightest": bright, "scatter": scatter, "dominant_color": dom_color, "texture": texture, }) print(f" āœ… Position={bright['primary_hotspot']['position']}, " f"Scatter={scatter['scatter_level']}, " f"Color={dom_color['dominant_color_name']}, " f"Texture={texture['texture_classification']}") # ---------------------------------------------------------- # STEP 3: Standard inference loop + GPT explanation + display # ---------------------------------------------------------- dataset = SimpleSlideDataset(all_images) dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=simple_collate, num_workers=2, pin_memory=True) results = [] for batch_idx, batch in enumerate(tqdm(dataloader, desc="Running inference")): tiles, paths = batch if not tiles or not paths: continue slide_path = paths[0] try: img_idx = all_images.index(slide_path) except ValueError: img_idx = batch_idx prediction = predict_image(model, tiles, disease_names, STAGE_NAMES) if prediction is None: print(f"Failed to process: {slide_path}") continue # ---- Build the 2 overlay images: Attention + GradCAM ---- attn_overlay = None if attn_results and img_idx < len(attn_results): ar = attn_results[img_idx] attn_overlay = _build_attention_overlay(ar['image'], 
ar['attention_heatmap']) gradcam_overlay = None if gradcam_results and img_idx < len(gradcam_results): gradcam_overlay = gradcam_results[img_idx]['union']['heatmap_overlay'] # ---- Build comprehensive_data dict for GPT ---- exp = explanations_list[img_idx] bright = exp['brightest'] scatter_res = exp['scatter'] dom = exp['dominant_color'] texture = exp['texture'] # GradCAM-sourced confidence and severity values gradcam_disease_conf = gradcam_results[img_idx]['level1_disease']['confidence'] \ if gradcam_results else 0.0 gradcam_severity_conf = gradcam_results[img_idx]['level2_severity']['confidence'] \ if gradcam_results else 0.0 gradcam_severity = gradcam_results[img_idx]['level2_severity']['predicted_class'] \ if gradcam_results else 'N/A' comprehensive_data = { 'predicted_disease': prediction['disease'], 'gradcam_severity': gradcam_severity, 'predicted_stage': prediction['stage'], 'stage_confidence': prediction['stage_confidence'], 'gradcam_disease_conf': gradcam_disease_conf, 'gradcam_severity_conf': gradcam_severity_conf, 'primary_position': bright['primary_hotspot']['position'], 'primary_intensity': bright['primary_hotspot']['intensity'], 'hotspot_count': bright['hotspot_count'], 'center_attention': bright['spatial_coverage']['center_attention'], 'mid_attention': bright['spatial_coverage']['mid_region_attention'], 'periphery_attention': bright['spatial_coverage']['periphery_attention'], 'scatter_level': scatter_res['scatter_level'], 'num_clusters': scatter_res['num_clusters'], 'dominant_color': dom['dominant_color_name'], 'color_confidence': dom['color_confidence'], 'texture_classification': texture['texture_classification'], 'uniformity': texture['texture_scores']['uniformity'], 'organization': texture['texture_scores']['organization'], 'complexity': texture['texture_scores']['complexity'], 'smoothness': texture['texture_scores']['smoothness'], } # ---- Generate GPT explanation -> render to image for Row 3 ---- print(f"\n Generating GPT explanation for image 
{img_idx + 1} ...") explanation_text = generate_comprehensive_explanation(comprehensive_data) explanation_image = _render_explanation_to_image(explanation_text) time.sleep(0.5) # ---- Display: Original | Report | Attention | GradCAM | Explanation ---- display_prediction( image_path = slide_path, prediction = prediction, heatmap_images = [attn_overlay, gradcam_overlay], heatmap_titles = ['Attention Heatmap', 'GradCAM Heatmap'], explanation_image = explanation_image ) results.append({ 'image_path': slide_path, 'image_name': os.path.basename(slide_path), 'predicted_disease': prediction['disease'], 'disease_confidence': prediction['disease_confidence'], 'predicted_severity': prediction['severity'], 'severity_confidence': prediction['severity_confidence'], 'predicted_stage': prediction['stage'], 'stage_confidence': prediction['stage_confidence'] }) if results: df = pd.DataFrame(results) out_path = os.path.join(OUTPUT_DIR, "inference_results.csv") df.to_csv(out_path, index=False) print(f"\nResults saved -> {out_path}") print(f"Total processed: {len(results)}") print("\n" + "=" * 70) print("DONE") print("=" * 70) if __name__ == "__main__": run_inference()