# models/explainability.py
import torch
import torch.nn.functional as F
import numpy as np
import cv2
from typing import Optional


class GradCAM:
    """
    Gradient-weighted Class Activation Mapping (Grad-CAM).

    Highlights which regions of the input image contributed most to the
    model's prediction for a given class.
    """
    def __init__(self, model: torch.nn.Module, target_layer: Optional[str] = None):
        """
        Args:
            model: The neural network
            target_layer: Name of the layer to compute the CAM on
                (usually the last convolutional layer)
        """
        self.model = model
        self.gradients = None
        self.activations = None

        # Auto-detect the target layer if not specified
        if target_layer is None:
            # Default to the last ConvNeXt stage
            self.target_layer = model.convnext.stages[-1]
        else:
            self.target_layer = dict(model.named_modules())[target_layer]

        # Register hooks; keep the handles so the hooks can be removed later
        self._fwd_handle = self.target_layer.register_forward_hook(
            self._save_activation
        )
        self._bwd_handle = self.target_layer.register_full_backward_hook(
            self._save_gradient
        )
    def _save_activation(self, module, input, output):
        """Save forward activations."""
        self.activations = output.detach()

    def _save_gradient(self, module, grad_input, grad_output):
        """Save backward gradients."""
        self.gradients = grad_output[0].detach()
    def generate_cam(
        self,
        image: torch.Tensor,
        target_class: Optional[int] = None
    ) -> np.ndarray:
        """
        Generate a Class Activation Map.

        Args:
            image: Input image tensor of shape [1, 3, H, W]
            target_class: Class to generate the CAM for (None = predicted class)

        Returns:
            cam: Activation map at the target layer's resolution [h, w],
                normalized to the range 0-1
        """
        self.model.eval()

        # Forward pass (the forward hook captures the target layer's activations)
        output = self.model(image)

        # Use the predicted class if none is specified
        if target_class is None:
            target_class = output.argmax(dim=1).item()

        # Backward pass on the target class score; the backward hook
        # captures the gradients at the target layer
        self.model.zero_grad()
        output[0, target_class].backward()

        # Drop the batch dimension: [C, h, w]
        gradients = self.gradients[0]
        activations = self.activations[0]

        # Global average pooling of the gradients gives per-channel weights
        weights = gradients.mean(dim=(1, 2))  # [C]

        # Weighted sum of activation maps, computed on the activations'
        # device (a plain torch.zeros here would break on GPU inputs)
        cam = (weights[:, None, None] * activations).sum(dim=0)

        # ReLU: keep only features with a positive influence on the class
        cam = F.relu(cam)

        # Normalize to 0-1
        cam = cam.cpu().numpy()
        cam = cam - cam.min()
        if cam.max() > 0:
            cam = cam / cam.max()
        return cam
    def overlay_cam_on_image(
        self,
        image: np.ndarray,  # [H, W, 3] RGB, uint8
        cam: np.ndarray,    # [h, w] in 0-1
        alpha: float = 0.5,
        colormap: int = cv2.COLORMAP_JET
    ) -> np.ndarray:
        """
        Overlay the CAM heatmap on the original image.

        Returns:
            overlay: [H, W, 3] RGB image with the heatmap blended in
        """
        H, W = image.shape[:2]

        # Resize the CAM up to the image resolution
        cam_resized = cv2.resize(cam, (W, H))

        # Convert to a color heatmap (OpenCV colormaps output BGR)
        heatmap = cv2.applyColorMap(np.uint8(255 * cam_resized), colormap)
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

        # Blend with the original image
        overlay = (alpha * heatmap + (1 - alpha) * image).astype(np.uint8)
        return overlay
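

# Convenience wrapper (a sketch, not part of the original API): chains the two
# GradCAM steps for a single image. Assumes `model` follows the SkinProAI
# interface that GradCAM's auto-detection expects (a .convnext backbone).
def explain_prediction(
    model: torch.nn.Module,
    image_tensor: torch.Tensor,  # [1, 3, H, W], preprocessed for the model
    image_rgb: np.ndarray,       # [H, W, 3] uint8, the display-ready image
    target_class: Optional[int] = None,
) -> np.ndarray:
    """Return a Grad-CAM overlay for one image."""
    cam_generator = GradCAM(model)
    try:
        cam = cam_generator.generate_cam(image_tensor, target_class)
    finally:
        # Remove the hooks so repeated calls don't accumulate them
        cam_generator._fwd_handle.remove()
        cam_generator._bwd_handle.remove()
    return cam_generator.overlay_cam_on_image(image_rgb, cam)
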

class AttentionVisualizer:
    """Visualize MedSigLIP attention maps."""

    def __init__(self, model):
        self.model = model
    def get_attention_maps(self, image: torch.Tensor) -> np.ndarray:
        """
        Extract attention maps from MedSigLIP.

        Returns:
            attention: [H, W] attention weights over the patch grid
        """
        # Forward pass (no gradients needed for attention extraction)
        with torch.no_grad():
            _ = self.model(image)

        # Last-layer attention from MedSigLIP,
        # shape: [batch, num_heads, seq_len, seq_len]
        attention = self.model.medsiglip_features

        # Averaging across heads and mapping token attention back to the
        # spatial patch grid is model-dependent; adapt this to the specific
        # MedSigLIP implementation in use.
        # PLACEHOLDER: returns random values until that mapping is implemented.
        return np.random.rand(14, 14)
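
    # Sketch of one common mapping (an assumption, not the MedSigLIP API):
    # for a ViT-style encoder with a CLS token at index 0, average the last
    # layer's attention over heads and reshape the CLS->patch row to the grid.
    @staticmethod
    def _cls_attention_grid(attn: torch.Tensor, grid_size: int) -> np.ndarray:
        """attn: [batch, num_heads, seq_len, seq_len]; returns [grid, grid]."""
        # Average over heads for the first batch item: [seq_len, seq_len]
        attn_mean = attn[0].mean(dim=0)
        # CLS token's attention to each patch token (drop the CLS column)
        cls_to_patch = attn_mean[0, 1:]
        return cls_to_patch.reshape(grid_size, grid_size).cpu().numpy()
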
    def overlay_attention(
        self,
        image: np.ndarray,
        attention: np.ndarray,
        alpha: float = 0.6
    ) -> np.ndarray:
        """Overlay an attention map on the image."""
        H, W = image.shape[:2]

        # Resize the attention map to the image resolution
        attention_resized = cv2.resize(attention, (W, H))

        # Normalize to 0-1
        attention_resized = attention_resized - attention_resized.min()
        if attention_resized.max() > 0:
            attention_resized = attention_resized / attention_resized.max()

        # Create a colored overlay (OpenCV colormaps output BGR)
        heatmap = cv2.applyColorMap(
            np.uint8(255 * attention_resized),
            cv2.COLORMAP_VIRIDIS
        )
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

        # Blend with the original image
        overlay = (alpha * heatmap + (1 - alpha) * image).astype(np.uint8)
        return overlay
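

# Smoke-test sketch (illustrative only): exercises GradCAM end-to-end on a
# tiny stand-in CNN so the plumbing can be checked without the real SkinProAI
# model. The dummy architecture and the "3" layer name below are hypothetical.
if __name__ == "__main__":
    import torch.nn as nn

    dummy = nn.Sequential(
        nn.Conv2d(3, 8, 3, padding=1),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(7),
        nn.Conv2d(8, 8, 3, padding=1),  # layer "3": the target conv layer
        nn.Flatten(),
        nn.Linear(8 * 7 * 7, 5),
    )
    cam_gen = GradCAM(dummy, target_layer="3")

    x = torch.randn(1, 3, 224, 224)
    cam = cam_gen.generate_cam(x)

    rgb = np.zeros((224, 224, 3), dtype=np.uint8)
    overlay = cam_gen.overlay_cam_on_image(rgb, cam)
    print("CAM shape:", cam.shape, "overlay shape:", overlay.shape)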