"""Render a Grad-CAM heatmap for one PneumoniaMNIST validation X-ray.

Loads a fine-tuned 2-class ResNet-50 from ``baseline_resnet50.pth``. Picks the
first validation sample labelled as pneumonia and computes Grad-CAM on the last
block of ``layer4``. Saves a side-by-side figure (original / overlay) as
``gradcam_pneumonia.png`` next to the dataset.
"""
import os

import matplotlib.pyplot as plt
import medmnist
import numpy as np
import torch
import torch.nn as nn
from medmnist import INFO
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from torchvision import models, transforms

# Default location of the MedMNIST data and the trained weights; override via
# main(dataset_root=...).
DEFAULT_DATASET_ROOT = r"C:\Users\USER\Downloads\MedMNIST_Data"

# Class index treated as "pneumonia". It is used both to pick the sample and
# as the Grad-CAM target logit.
PNEUMONIA_CLASS = 1


def _load_model(weights_path, device):
    """Rebuild the 2-class ResNet-50, load *weights_path*, and put it in eval mode."""
    model = models.resnet50()
    model.fc = nn.Linear(model.fc.in_features, 2)
    state_dict = torch.load(weights_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()  # freeze dropout/batch-norm statistics for inference
    return model


def _find_first_sample(dataset, label):
    """Return ``(index, image)`` for the first sample in *dataset* whose label is *label*.

    Raises:
        ValueError: if no sample carries *label*. A silent fallback to index 0
            would produce a figure captioned with the wrong class.
    """
    # The dataset is indexed explicitly because only __len__/__getitem__ are used here.
    for idx in range(len(dataset)):
        img, target = dataset[idx]
        if target[0] == label:
            return idx, img
    raise ValueError(f"No sample with label {label} found in the dataset")


def _save_figure(rgb_img, visualization, save_path):
    """Plot the original image and the Grad-CAM overlay side by side and save them."""
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    panels = (
        (rgb_img, "Original X-Ray (Pneumonia)"),
        (visualization, "Grad-CAM Heatmap"),
    )
    for ax, (image, title) in zip(axes, panels):
        ax.imshow(image)
        ax.set_title(title, fontweight='bold')
        ax.axis('off')
    fig.tight_layout()
    fig.savefig(save_path, dpi=300)


def main(dataset_root=DEFAULT_DATASET_ROOT):
    """Generate and save the Grad-CAM figure for the first pneumonia validation sample.

    Args:
        dataset_root: directory holding the MedMNIST files and
            ``baseline_resnet50.pth``. The output PNG is written here too.

    Raises:
        ValueError: if the validation split contains no pneumonia sample.
    """
    # 1. Hardware setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Generating Grad-CAM on: {device}")

    info = INFO['pneumoniamnist']
    DataClass = getattr(medmnist, info['python_class'])

    # 2. Rebuild and load the frozen model
    print("Loading Baseline ResNet50...")
    weights_path = os.path.join(dataset_root, 'baseline_resnet50.pth')
    model = _load_model(weights_path, device)

    # 3. Grab a sample image (no augmentation; the dataset has no transform, so items are raw images)
    val_dataset_raw = DataClass(split='val', download=False, size=224, root=dataset_root)

    # Model-input preprocessing: replicate the grayscale channel to 3 channels for
    # ResNet and normalize to [-1, 1].
    # NOTE(review): assumes these mean/std values match the training pipeline — confirm.
    transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    _, raw_img = _find_first_sample(val_dataset_raw, PNEUMONIA_CLASS)

    # RGB float image in [0, 1], as required by show_cam_on_image for the overlay.
    rgb_img = np.array(raw_img.convert('RGB'), dtype=np.float32) / 255.0

    # Batch of one, moved to the selected device (GPU or CPU).
    input_tensor = transform(raw_img).unsqueeze(0).to(device)

    # 4. Generate the heatmap for the pneumonia logit.
    # layer4[-1] is the final convolutional block before the classification head.
    target_layers = [model.layer4[-1]]
    targets = [ClassifierOutputTarget(PNEUMONIA_CLASS)]
    # Used as a context manager so the hooks GradCAM registers on the model are released.
    with GradCAM(model=model, target_layers=target_layers) as cam:
        grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :]

    # Blend the colour-mapped heatmap onto the X-ray.
    visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)

    # 5. Plot and save for the report
    print("Generating Figure...")
    save_path = os.path.join(dataset_root, 'gradcam_pneumonia.png')
    _save_figure(rgb_img, visualization, save_path)
    print(f"Success! Image saved to: {save_path}")
    plt.show()


if __name__ == '__main__':
    main()