"""Generate Grad-CAM overlays for one Healthy and one Impaired MRI slice."""
from pathlib import Path

import cv2
import matplotlib.pyplot as plt  # noqa: F401  (currently unused)
import numpy as np
import pandas as pd
import torch
import torch.nn as nn

# ==========================================
# CONFIGURATION
# ==========================================
PROJECT_ROOT = Path(r"d:\Downloads\BrainGrowthProject")
MODEL_PATH = PROJECT_ROOT / "outputs" / "cnn_production" / "checkpoints" / "best_custom_cnn.pth"
MAPPING_CSV = PROJECT_ROOT / "master_outputs" / "mappings" / "cnn_slice_mapping.csv"
GRAD_OUT = PROJECT_ROOT / "outputs" / "cnn_production" / "gradcam"
GRAD_OUT.mkdir(parents=True, exist_ok=True)


# 1. Architecture (Must match training)
class CustomLightCNN(nn.Module):
    """Small 3-block CNN classifier instrumented for Grad-CAM.

    The layer layout (and therefore the state_dict keys) must match the
    training script, otherwise ``load_state_dict`` fails.

    Input is single-channel (``Conv2d(1, ...)``). After three 2x2 max-pools
    the spatial size must be 28x28 to fit ``nn.Linear(32 * 28 * 28, 64)``,
    e.g. a 224x224 slice.
    """

    # ``features`` layout: [Conv, BN, ReLU, Pool] x 3  ->  indices 0..11.
    # Layers [0, CAM_SPLIT) run up to and including the last ReLU (index 10);
    # Grad-CAM hooks that output. The remaining layer is the final MaxPool.
    CAM_SPLIT = 11

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3, padding=1), nn.BatchNorm2d(8), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(8, 16, kernel_size=3, padding=1), nn.BatchNorm2d(16), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(32 * 28 * 28, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 2),
        )
        # Gradient of the backpropagated output w.r.t. the target activations
        # (filled by the hook during backward()).
        self.gradients = None
        # Target-layer activations captured by the most recent forward().
        self.activations = None

    def activations_hook(self, grad):
        """Store the gradient flowing back into the Grad-CAM target layer."""
        self.gradients = grad

    def forward(self, x):
        """Run the network, capturing target activations and hooking their grad."""
        x = self.features[:self.CAM_SPLIT](x)  # up to and including last ReLU
        self.activations = x
        # register_hook raises on tensors that do not require grad
        # (e.g. under torch.no_grad()), so only hook when backprop is possible.
        if x.requires_grad:
            x.register_hook(self.activations_hook)
        x = self.features[self.CAM_SPLIT:](x)  # final MaxPool
        return self.classifier(x)

    def get_activations_gradient(self):
        """Return the gradient captured by the hook on the last backward()."""
        return self.gradients

    def get_activations(self, x):
        """Recompute target-layer activations for ``x`` (kept for compatibility)."""
        return self.features[:self.CAM_SPLIT](x)


def _pick_slice(df: pd.DataFrame, label: int, index: int) -> str:
    """Return the ``slice_path`` of the ``index``-th row with ``dementia_label == label``.

    Falls back to the last available row when the class has fewer rows than
    ``index + 1``.

    Raises:
        ValueError: if no row carries ``label``.
    """
    subset = df[df["dementia_label"] == label]
    if subset.empty:
        raise ValueError(f"No slices with dementia_label == {label} in {MAPPING_CSV}")
    return subset.iloc[min(index, len(subset) - 1)]["slice_path"]


def _load_gray(path) -> np.ndarray:
    """Read ``path`` as a grayscale image, raising instead of returning None."""
    img = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    return img


def _compute_gradcam(model: CustomLightCNN, img_tensor: torch.Tensor, target_class: int) -> np.ndarray:
    """Return a Grad-CAM map in [0, 1] at the target layer's spatial resolution."""
    model.zero_grad()  # avoid parameter-grad accumulation across samples
    logits = model(img_tensor)
    logits[:, target_class].sum().backward()

    gradients = model.get_activations_gradient()
    activations = model.activations.detach()

    # Channel weights = gradients averaged over batch and spatial dims,
    # broadcast back over (N, C, H, W) instead of looping over channels.
    weights = gradients.mean(dim=(0, 2, 3)).view(1, -1, 1, 1)
    cam = (activations * weights).mean(dim=1).squeeze(0)

    cam = torch.relu(cam)
    cam = cam / (cam.max() + 1e-8)
    return cam.numpy()


def _overlay_heatmap(img_gray: np.ndarray, cam: np.ndarray) -> np.ndarray:
    """Upscale ``cam`` to the image size, color it with JET and blend 60/40."""
    height, width = img_gray.shape[:2]
    heat = cv2.resize(cam, (width, height))  # dsize is (width, height)
    heat = np.clip(heat * 255, 0, 255).astype(np.uint8)
    heat = cv2.applyColorMap(heat, cv2.COLORMAP_JET)
    base = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR)
    return cv2.addWeighted(base, 0.6, heat, 0.4, 0)


def _subject_id(path) -> str:
    """Second ``_``-separated token of the file name, or the stem if there is none."""
    name = Path(path)
    parts = name.name.split("_")
    return parts[1] if len(parts) > 1 else name.stem


def generate_gradcam(sample_index: int = 10, target_class: int = 1):
    """Save Grad-CAM overlays for one Healthy and one Impaired slice.

    Args:
        sample_index: Row (within each class) of the mapping CSV to visualize.
        target_class: Logit whose gradient is explained. The default 1 is the
            'Impaired' class, used for both samples as in the original script.
    """
    print("--- Phase 3: Generating Neural Explainability Maps (Grad-CAM) ---")
    device = torch.device("cpu")

    # Load Model
    model = CustomLightCNN()
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    model.eval()

    # Pick one Healthy and one Impaired sample from the mapping
    df = pd.read_csv(MAPPING_CSV)
    samples = {
        "Healthy": _pick_slice(df, 0, sample_index),
        "Impaired": _pick_slice(df, 1, sample_index),
    }

    for label, path in samples.items():
        img_orig = _load_gray(path)
        img_tensor = torch.from_numpy(img_orig).float().unsqueeze(0).unsqueeze(0) / 255.0

        cam = _compute_gradcam(model, img_tensor, target_class)
        superimposed_img = _overlay_heatmap(img_orig, cam)

        out_path = GRAD_OUT / f"gradcam_{_subject_id(path)}_{label.lower()}.png"
        if not cv2.imwrite(str(out_path), superimposed_img):
            raise OSError(f"Failed to write {out_path}")
        print(f"Saved: {out_path}")


if __name__ == "__main__":
    generate_gradcam()