Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| import cv2 | |
| import matplotlib.pyplot as plt | |
| from pathlib import Path | |
| import pandas as pd | |
# ==========================================
# CONFIGURATION
# ==========================================
# All paths below are derived from this root; change it to relocate the project.
PROJECT_ROOT = Path(r"d:\Downloads\BrainGrowthProject")

# state_dict checkpoint loaded into CustomLightCNN by generate_gradcam().
MODEL_PATH = PROJECT_ROOT.joinpath(
    "outputs", "cnn_production", "checkpoints", "best_custom_cnn.pth"
)
# CSV read by generate_gradcam(); it uses the 'dementia_label' and 'slice_path' columns.
MAPPING_CSV = PROJECT_ROOT.joinpath(
    "master_outputs", "mappings", "cnn_slice_mapping.csv"
)
# Output folder for the Grad-CAM overlay PNGs; created as soon as the module is imported.
GRAD_OUT = PROJECT_ROOT.joinpath("outputs", "cnn_production", "gradcam")
GRAD_OUT.mkdir(exist_ok=True, parents=True)
# 1. Architecture (must match training so the checkpoint's state_dict loads)
class CustomLightCNN(nn.Module):
    """Small three-stage CNN with a built-in Grad-CAM gradient hook.

    ``features`` is three Conv -> BatchNorm -> ReLU -> MaxPool(2) stages
    (1 -> 8 -> 16 -> 32 channels). ``classifier`` flattens a 32x28x28 map and
    produces 2 logits. Because of three 2x pools, the flatten size only fits
    single-channel inputs whose spatial size pools down to 28x28 (e.g. 224x224).

    During ``forward`` a hook is attached to the output of the last ReLU,
    just before the final MaxPool. After ``backward()``, the gradient with
    respect to those activations can be read with ``get_activations_gradient()``.
    """

    # Modules [0, _CAM_SPLIT) end with the third stage's ReLU (index 10); index 11
    # is the final MaxPool. Used identically by forward() and get_activations()
    # so the hooked gradients and the returned activations refer to the same tensor.
    _CAM_SPLIT = 11

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3, padding=1),
            nn.BatchNorm2d(8), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(8, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(32 * 28 * 28, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 2),
        )
        # Gradient w.r.t. the Grad-CAM activations, set by the most recent backward pass.
        self.gradients = None

    def activations_hook(self, grad):
        """Tensor hook: remember the gradient flowing into the Grad-CAM activations."""
        self.gradients = grad

    def forward(self, x):
        """Return the 2-class logits for ``x`` and arm the Grad-CAM hook."""
        x = self.features[: self._CAM_SPLIT](x)  # up to and including the last ReLU
        # register_hook raises on tensors that don't require grad (e.g. under
        # torch.no_grad()), so only attach it when autograd is tracking x.
        if x.requires_grad:
            x.register_hook(self.activations_hook)
        x = self.features[self._CAM_SPLIT:](x)  # final MaxPool
        return self.classifier(x)

    def get_activations_gradient(self):
        """Return the gradient captured by the last backward pass, or None."""
        return self.gradients

    def get_activations(self, x):
        """Return the Grad-CAM target activations (output of the last ReLU) for ``x``."""
        return self.features[: self._CAM_SPLIT](x)
def _gradcam_heatmap(model, img_tensor, target_class):
    """Compute a Grad-CAM map in [0, 1] for logit ``target_class`` of ``model``.

    ``img_tensor`` is assumed to be a single-image batch (1, 1, H, W). The
    returned 2-D tensor has the spatial size of ``model.get_activations`` output.
    """
    model.zero_grad()
    pred = model(img_tensor)
    pred[:, target_class].backward()

    gradients = model.get_activations_gradient()
    # Per-channel importance: spatially averaged gradient of the target logit.
    channel_weights = gradients.mean(dim=(0, 2, 3))
    with torch.no_grad():
        activations = model.get_activations(img_tensor)
    cam = (activations * channel_weights.view(1, -1, 1, 1)).mean(dim=1).squeeze()
    cam = cam.clamp(min=0)  # keep only evidence that increases the target logit
    return cam / (cam.max() + 1e-8)


def _overlay_heatmap(img_gray, cam):
    """Blend a JET-colored ``cam`` over grayscale ``img_gray``; returns a BGR image."""
    height, width = img_gray.shape[:2]
    # cv2.resize takes (width, height). Matching the source image guarantees
    # that addWeighted receives equally sized inputs.
    cam_img = cv2.resize(cam.detach().cpu().numpy(), (width, height))
    cam_img = cv2.applyColorMap(np.uint8(255 * cam_img), cv2.COLORMAP_JET)
    base = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR)
    return cv2.addWeighted(base, 0.6, cam_img, 0.4, 0)


def generate_gradcam(sample_index=10, target_class=1):
    """Save Grad-CAM overlays for one Healthy and one Impaired slice.

    Loads the ``CustomLightCNN`` weights from ``MODEL_PATH`` and reads
    ``MAPPING_CSV``. For each class (``dementia_label`` 0 = Healthy,
    1 = Impaired) it takes the ``sample_index``-th row and computes a
    Grad-CAM map for logit ``target_class``. The colored overlay is written
    into ``GRAD_OUT``.

    Args:
        sample_index: Positional row within each class subset to visualize.
        target_class: Logit whose gradient drives the map (1 = 'Impaired').

    Raises:
        IndexError: If a class has fewer than ``sample_index + 1`` rows.
        FileNotFoundError: If a slice image cannot be read by OpenCV.
        OSError: If an overlay image cannot be written.
    """
    print("--- Phase 3: Generating Neural Explainability Maps (Grad-CAM) ---")
    device = torch.device("cpu")

    model = CustomLightCNN()
    # torch.load unpickles the checkpoint; only point MODEL_PATH at trusted files.
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    model.eval()

    df = pd.read_csv(MAPPING_CSV)
    samples = {
        "Healthy": df[df["dementia_label"] == 0].iloc[sample_index]["slice_path"],
        "Impaired": df[df["dementia_label"] == 1].iloc[sample_index]["slice_path"],
    }

    for label, path in samples.items():
        img_orig = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
        if img_orig is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read slice image: {path}")
        img_tensor = torch.tensor(img_orig).float().unsqueeze(0).unsqueeze(0) / 255.0

        cam = _gradcam_heatmap(model, img_tensor, target_class)
        superimposed_img = _overlay_heatmap(img_orig, cam)

        # Presumably filenames look like "<prefix>_<subject>_..."; fall back to
        # the bare stem when there is no underscore-separated subject field.
        name_parts = Path(path).name.split("_")
        subj_id = name_parts[1] if len(name_parts) > 1 else Path(path).stem
        out_path = GRAD_OUT / f"gradcam_{subj_id}_{label.lower()}.png"
        if not cv2.imwrite(str(out_path), superimposed_img):
            raise OSError(f"cv2.imwrite failed for {out_path}")
        print(f"Saved: {out_path}")
if __name__ == "__main__":
    # Script entry point: produce and save the Grad-CAM overlays.
    generate_gradcam()