# NeuroVision-App/models/cnn/explainability.py
# Uploaded by Prasannata via huggingface_hub ("Upload folder using huggingface_hub"),
# commit 45ed07c (verified).
import os
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
# ==========================================
# CONFIGURATION
# ==========================================
# Project root. Set the BRAIN_GROWTH_ROOT environment variable to run on a
# machine other than the original author's; the default is the author's local
# Windows path. An empty variable falls back to the default.
PROJECT_ROOT = Path(os.environ.get("BRAIN_GROWTH_ROOT") or r"d:\Downloads\BrainGrowthProject")
MODEL_PATH = PROJECT_ROOT / "outputs" / "cnn_production" / "checkpoints" / "best_custom_cnn.pth"
MAPPING_CSV = PROJECT_ROOT / "master_outputs" / "mappings" / "cnn_slice_mapping.csv"
GRAD_OUT = PROJECT_ROOT / "outputs" / "cnn_production" / "gradcam"
# Created at import time so generate_gradcam() can write overlays directly.
GRAD_OUT.mkdir(parents=True, exist_ok=True)
# 1. Architecture (must match training: module names/order are checked by load_state_dict)
class CustomLightCNN(nn.Module):
    """Lightweight 3-stage CNN producing 2 logits, with a Grad-CAM gradient hook.

    ``features`` is three [Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d(2)] stages
    (1 -> 8 -> 16 -> 32 channels). ``classifier`` flattens a 32x28x28 map into
    2 logits. Inputs must therefore be single-channel images whose spatial size
    the three 2x pools reduce to 28x28, e.g. ``(N, 1, 224, 224)``.

    Attribute names and module order must not change, because checkpoints are
    restored with ``load_state_dict``.
    """

    # ``features`` holds 12 modules and the last ReLU sits at index 10.
    # Grad-CAM targets the output of ``features[:_CAM_SPLIT]``: everything
    # through that ReLU, stopping before the final MaxPool.
    _CAM_SPLIT = 11

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3, padding=1),
            nn.BatchNorm2d(8), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(8, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(32 * 28 * 28, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 2),
        )
        # Gradient of the backpropagated score w.r.t. the Grad-CAM target map.
        # Set by ``activations_hook`` during ``backward()``; None until then.
        self.gradients = None

    def activations_hook(self, grad):
        """Tensor hook: store the gradient flowing into the target feature map."""
        self.gradients = grad

    def forward(self, x):
        """Return class logits for ``x``, hooking the Grad-CAM target map when possible."""
        x = self.features[: self._CAM_SPLIT](x)
        # register_hook raises on tensors that don't require grad (e.g. under
        # torch.no_grad()), so only attach it when a backward pass can happen.
        if x.requires_grad:
            x.register_hook(self.activations_hook)
        x = self.features[self._CAM_SPLIT :](x)  # final MaxPool
        return self.classifier(x)

    def get_activations_gradient(self):
        """Return the gradient captured by the most recent backward pass (or None)."""
        return self.gradients

    def get_activations(self, x):
        """Return the Grad-CAM target activations (through the last ReLU) for ``x``."""
        return self.features[: self._CAM_SPLIT](x)
def _compute_gradcam(model, img_tensor, target_class=1):
    """Return a Grad-CAM heatmap (float ndarray in [0, 1]) for ``target_class``.

    ``model`` must expose the ``CustomLightCNN`` hook API (``forward``
    registering a gradient hook, ``get_activations_gradient``,
    ``get_activations``). The map has the spatial size of the hooked feature
    map, before any upsampling.
    """
    with torch.enable_grad():  # the hook needs autograd even if the caller disabled it
        model.zero_grad()  # avoid accumulating parameter grads across samples
        pred = model(img_tensor)
        pred[:, target_class].sum().backward()
    gradients = model.get_activations_gradient()
    if gradients is None:
        raise RuntimeError("Grad-CAM hook did not capture gradients")
    # Global-average-pool the gradients: one importance weight per channel.
    weights = gradients.mean(dim=(0, 2, 3))
    activations = model.get_activations(img_tensor).detach()
    heatmap = (activations * weights.view(1, -1, 1, 1)).mean(dim=1).squeeze()
    heatmap = torch.clamp(heatmap, min=0)  # keep only positive evidence
    heatmap = heatmap / (heatmap.max() + 1e-8)
    return heatmap.detach().cpu().numpy()


def _overlay_heatmap(img_gray, heatmap):
    """Blend a [0, 1] heatmap (JET colormap, 40%) over a grayscale uint8 image (60%)."""
    height, width = img_gray.shape[:2]
    heatmap = cv2.resize(heatmap, (width, height))  # dsize is (width, height)
    heatmap = np.clip(255 * heatmap, 0, 255).astype(np.uint8)
    colored = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    orig_color = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR)
    return cv2.addWeighted(orig_color, 0.6, colored, 0.4, 0)


def generate_gradcam(target_class=1):
    """Write Grad-CAM overlays for one Healthy and one Impaired slice to GRAD_OUT.

    Loads the checkpoint at MODEL_PATH and takes the 11th row (``iloc[10]``)
    of each ``dementia_label`` group from MAPPING_CSV. For each slice it
    explains logit ``target_class`` (default 1, the 'Impaired' class) and
    saves ``gradcam_<subject>_<label>.png``.

    Raises:
        FileNotFoundError: if a slice image cannot be read.
        OSError: if an overlay cannot be written.
    """
    print("--- Phase 3: Generating Neural Explainability Maps (Grad-CAM) ---")
    device = torch.device("cpu")
    # Load Model
    model = CustomLightCNN()
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints
    # (consider weights_only=True if the installed torch supports it).
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    model.eval()
    # Load mapping and pick one Healthy and one Impaired sample
    df = pd.read_csv(MAPPING_CSV)
    samples = {
        "Healthy": df[df['dementia_label'] == 0].iloc[10]['slice_path'],
        "Impaired": df[df['dementia_label'] == 1].iloc[10]['slice_path'],
    }
    for label, path in samples.items():
        img_orig = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if img_orig is None:
            raise FileNotFoundError(f"Could not read slice image: {path}")
        img_tensor = torch.from_numpy(img_orig).float().unsqueeze(0).unsqueeze(0) / 255.0
        heatmap = _compute_gradcam(model, img_tensor, target_class)
        superimposed_img = _overlay_heatmap(img_orig, heatmap)
        # Save
        subj_id = Path(path).name.split("_")[1]
        out_path = GRAD_OUT / f"gradcam_{subj_id}_{label.lower()}.png"
        if not cv2.imwrite(str(out_path), superimposed_img):
            raise OSError(f"Failed to write {out_path}")
        print(f"Saved: {out_path}")
if __name__ == "__main__":
    # Runs only when executed as a script, not when this module is imported
    # (e.g. to reuse CustomLightCNN elsewhere).
    generate_gradcam()