# MedAI-ACM / src/analysis/analyze_2.py
# Tirath5504 — deploy (bf07f10)
import os
import sys
import argparse
import time
from pathlib import Path
from typing import List, Dict
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.models as tvmodels
import timm
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix
import cv2
import csv
import matplotlib.pyplot as plt
# Import necessary modules for Grad-CAM
from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
# Add parent directory to path for imports
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
from src.utils import get_device, get_model, get_transforms
# Resolve the compute device once at import time; every function below reads it.
DEVICE = get_device()
print("Using device: {}".format(DEVICE))
# ----------------------------- Dataset (Reusing logic from pipeline.py) -----------------------------
class FractureDataset(Dataset):
    """Image dataset that yields model inputs plus raw pixels for CAM overlays.

    Each entry must provide ``'image_path'`` and ``'label'`` keys. ``df`` may be
    a sequence of such mappings (e.g. rows from ``csv.DictReader``) or any
    object exposing a pandas-style positional ``.iloc`` indexer.

    ``__getitem__`` returns ``(img, label, img_path, raw_img)`` where ``img`` is
    the (optionally transformed) image, ``label`` is ``int(row['label'])``,
    ``img_path`` is the resolved path, and ``raw_img`` is the untransformed RGB
    image as a float32 array scaled to [0, 1].
    """

    def __init__(self, df, img_root: str = '.', transform=None):
        self.entries = df
        self.img_root = img_root
        self.transform = transform
        # CSV paths may start with this folder name; it is stripped before
        # joining with img_root (presumably because img_root already points
        # inside that folder — confirm against the dataset layout).
        self.redundant_prefix = 'balanced_augmented_dataset/'
        self.redundant_prefix_len = len(self.redundant_prefix)

    def __len__(self):
        """Number of entries in the dataset."""
        return len(self.entries)

    def _row(self, idx):
        """Return entry ``idx``; DataFrame-like inputs are indexed positionally."""
        if hasattr(self.entries, 'iloc'):
            return self.entries.iloc[idx]
        return self.entries[idx]

    def _resolve_path(self, img_path: str) -> str:
        """Strip the redundant prefix and anchor relative paths at ``img_root``."""
        if img_path.startswith(self.redundant_prefix):
            img_path = img_path[self.redundant_prefix_len:]
        if not os.path.isabs(img_path):
            img_path = os.path.join(self.img_root, img_path)
        return img_path

    def __getitem__(self, idx):
        row = self._row(idx)
        img_path = self._resolve_path(row['image_path'])
        # convert() returns a fully loaded copy, so the source file handle can
        # be closed immediately instead of waiting for garbage collection.
        with Image.open(img_path) as src:
            img = src.convert('RGB')
        # Untransformed pixels in [0, 1], kept for the CAM overlay.
        raw_img = np.array(img).astype(np.float32) / 255.0
        label = int(row['label'])
        if self.transform:
            img = self.transform(img)
        return img, label, img_path, raw_img
# ----------------------------- Model selection with Grad-CAM target layers -----------------------------
def get_model_with_target_layer(name: str, num_classes: int, pretrained: bool = True):
    """Build ``name`` via ``get_model`` and select the layer Grad-CAM should hook.

    Args:
        name: Architecture identifier, matched case-insensitively by prefix
            against 'swin', 'convnext' and 'densenet'.
        num_classes: Forwarded to ``get_model``.
        pretrained: Forwarded to ``get_model``.

    Returns:
        ``(model, target_layer)``.

    Raises:
        ValueError: If no known prefix matches (checked after the model is built).
    """
    net = get_model(name, num_classes, pretrained=pretrained)
    key = name.lower()
    layer_pickers = (
        # Swin: second norm of the final block in the final stage.
        ('swin', lambda m: m.layers[-1].blocks[-1].norm2),
        # ConvNeXt: the whole final stage.
        ('convnext', lambda m: m.stages[-1]),
        # DenseNet: the final feature norm.
        ('densenet', lambda m: m.features.norm5),
    )
    for prefix, pick_layer in layer_pickers:
        if key.startswith(prefix):
            return net, pick_layer(net)
    raise ValueError(f'Unknown target layer for model: {key}')
# ----------------------------- Helpers: CSV loader -----------------------------
def load_csv_like(path: str) -> List[Dict]:
    """Read a CSV file with a header row into a list of row dicts.

    Args:
        path: Path to the CSV file.

    Returns:
        One dict per data row, keyed by the header's column names; values are
        the raw, unparsed strings from the file.
    """
    # newline='' is required by the csv module so quoted fields containing
    # line breaks are preserved verbatim; 'utf-8-sig' additionally drops a
    # leading byte-order mark that would otherwise be glued onto the first
    # column name (reads BOM-less UTF-8 exactly like 'utf8').
    with open(path, 'r', encoding='utf-8-sig', newline='') as f:
        return list(csv.DictReader(f))
# ----------------------------- Grad-CAM Analysis -----------------------------
def analyze(args):
device = DEVICE
# Load CSVs
test_rows = load_csv_like(args.test_csv)
# Get model and the target layer for Grad-CAM
model, target_layer = get_model_with_target_layer(args.model, args.num_classes, pretrained=False)
model.to(device)
# Load checkpoint weights
ck = torch.load(args.checkpoint, map_location=device)
model.load_state_dict(ck['model_state_dict'])
model.eval()
print(f'Loaded model from {args.checkpoint} onto {device}.')
# Data setup
test_tf = get_transforms('val', args.img_size)
test_ds = FractureDataset(test_rows, img_root=args.img_root, transform=test_tf)
test_loader = DataLoader(test_ds, batch_size=1, shuffle=False) # Use batch size 1 for accurate CAM per image
# Initialize Grad-CAM
cam = GradCAM(model=model, target_layers=[target_layer], use_cuda=(device.type == 'cuda'))
# Setup output directory
os.makedirs(args.out_dir, exist_ok=True)
class_names = args.class_names.split(',')
print(f"Starting Grad-CAM analysis on {len(test_ds)} images...")
for i, (imgs, labels, img_paths, raw_imgs) in enumerate(test_loader):
imgs = imgs.to(device)
true_label = labels.item()
# 1. Prediction and Target Setup
with torch.no_grad():
outputs = model(imgs)
predicted_label = outputs.softmax(dim=1).argmax(dim=1).item()
# Set the target to the PREDICTED class for visualization
targets = [ClassifierOutputTarget(predicted_label)]
# 2. Generate CAM
grayscale_cam = cam(input_tensor=imgs, targets=targets)
grayscale_cam = grayscale_cam[0, :]
# 3. Visualization
# raw_img is the unnormalized image [0, 1]
raw_img_for_viz = raw_imgs.squeeze(0).numpy()
visualization = show_cam_on_image(raw_img_for_viz, grayscale_cam, use_rgb=True)
# Convert to PIL Image for saving
visualization_pil = Image.fromarray(cv2.cvtColor((visualization * 255).astype(np.uint8), cv2.COLOR_RGB2BGR))
# 4. Save
path_obj = Path(img_paths[0])
class_name = class_names[true_label]
# Define saving path
save_dir = os.path.join(args.out_dir, class_name)
os.makedirs(save_dir, exist_ok=True)
# Determine the name with prediction/truth info
pred_class_name = class_names[predicted_label]
file_name = f'CAM_T{class_name}_P{pred_class_name}_{path_obj.name}'
save_path = os.path.join(save_dir, file_name)
visualization_pil.save(save_path)
if i % 10 == 0:
print(f"Processed {i+1}/{len(test_ds)}. Saved to: {save_path}")
print("Grad-CAM analysis complete. Results saved to:", args.out_dir)
# ----------------------------- Main -----------------------------
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Run Grad-CAM analysis on test data.')
parser.add_argument('--checkpoint', type=str, required=True, help='Path to the model checkpoint (e.g., outputs/swin_mps/best.pth)')
parser.add_argument('--test-csv', type=str, required=True, help='Path to the test CSV file.')
parser.add_argument('--img-root', type=str, default='.', help='Root directory for images.')
parser.add_argument('--model', type=str, default='swin', choices=['swin','convnext'])
parser.add_argument('--num-classes', type=int, default=8)
parser.add_argument('--img-size', type=int, default=224)
parser.add_argument('--out-dir', type=str, default='outputs/analysis', help='Directory to save CAM visualizations.')
parser.add_argument('--class-names', type=str, required=True,
help='Comma-separated list of class names (e.g., "A,B,C")')
args = parser.parse_args()
# Check for required library dependencies
try:
import pytorch_grad_cam
except ImportError:
print("ERROR: pytorch-grad-cam library not found. Please install it:")
print("pip install pytorch-grad-cam")
exit(1)
analyze(args)