# NOTE(review): the original paste carried tool output above the code
# ("Spaces:", two "Runtime error" lines); preserved as a comment so the
# file parses. The runtime errors are addressed inline below.
import argparse
import csv
import os
import sys
import time
from pathlib import Path
from typing import Dict, List

import cv2
import matplotlib.pyplot as plt
import numpy as np
import timm
import torch
import torch.nn as nn
import torchvision.models as tvmodels
import torchvision.transforms as T
from PIL import Image
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support
from torch.utils.data import DataLoader, Dataset

# Grad-CAM: CAM algorithms, prediction targets, and heatmap overlay helper
from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
# Add the repository root (two levels up from this file) to sys.path so the
# `src.*` package imports below resolve when the script is run directly.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
from src.utils import get_device, get_model, get_transforms

# Resolve the compute device once at import time (whatever get_device selects).
DEVICE = get_device()
print(f"Using device: {DEVICE}")
# ----------------------------- Dataset (Reusing logic from pipeline.py) -----------------------------
class FractureDataset(Dataset):
    """Dataset over CSV rows of (image_path, label).

    Each row is a mapping with at least 'image_path' and 'label' keys (as
    produced by load_csv_like). __getitem__ returns the transformed image,
    the integer label, the resolved image path, and the raw [0, 1] RGB image
    for CAM overlay visualization.
    """

    def __init__(self, df, img_root: str = '.', transform=None):
        # `df` is an integer-indexable sequence of dict-like rows (list of
        # dicts from load_csv_like) — NOTE(review): confirm no caller passes
        # a pandas DataFrame, which would need .iloc access here.
        self.entries = df
        self.img_root = img_root
        self.transform = transform
        # CRITICAL PATH FIX: some CSVs store paths with this redundant prefix;
        # strip it so paths resolve correctly against img_root.
        self.redundant_prefix = 'balanced_augmented_dataset/'
        self.redundant_prefix_len = len(self.redundant_prefix)

    def __len__(self) -> int:
        # BUG FIX: original was `len(len(self.entries))`, which always raises
        # TypeError ("object of type 'int' has no len()") — the source of the
        # reported runtime error.
        return len(self.entries)

    def __getitem__(self, idx):
        row = self.entries[idx]
        img_path = row['image_path']
        # PATH CLEANING FIX: strip the redundant dataset prefix if present.
        if img_path.startswith(self.redundant_prefix):
            img_path = img_path[self.redundant_prefix_len:]
        if not os.path.isabs(img_path):
            img_path = os.path.join(self.img_root, img_path)
        img = Image.open(img_path).convert('RGB')
        # Keep an unnormalized float copy in [0, 1] for CAM overlay rendering.
        raw_img = np.array(img).astype(np.float32) / 255.0
        label = int(row['label'])
        if self.transform:
            img = self.transform(img)
        return img, label, img_path, raw_img
# ----------------------------- Model selection with Grad-CAM target layers -----------------------------
def get_model_with_target_layer(name: str, num_classes: int, pretrained: bool = True):
    """Build a model via get_model and pick the layer Grad-CAM should hook.

    Supported families and their hooked layers:
      - swin:     norm2 of the last block in the last stage
      - convnext: the final stage (deepest conv feature map)
      - densenet: features.norm5 (final batch-norm over dense features)

    Raises ValueError for any other model family.
    """
    model = get_model(name, num_classes, pretrained=pretrained)
    lowered = name.lower()
    if lowered.startswith('swin'):
        # Last transformer block of the final Swin stage.
        return model, model.layers[-1].blocks[-1].norm2
    if lowered.startswith('convnext'):
        # Whole final ConvNeXt stage serves as the target feature map.
        return model, model.stages[-1]
    if lowered.startswith('densenet'):
        return model, model.features.norm5
    raise ValueError(f'Unknown target layer for model: {lowered}')
# ----------------------------- Helpers: CSV loader -----------------------------
def load_csv_like(path: str) -> List[Dict]:
    """Read a CSV file with a header row and return each row as a dict."""
    with open(path, 'r', encoding='utf8') as f:
        return list(csv.DictReader(f))
# ----------------------------- Grad-CAM Analysis -----------------------------
def analyze(args):
    """Run Grad-CAM over every test image and save heatmap overlays.

    Loads the checkpointed model, computes a CAM for each image targeting the
    PREDICTED class, overlays it on the raw image, and writes one file per
    image under out_dir/<true_class_name>/.
    """
    device = DEVICE
    # Load test-set rows (list of dicts with 'image_path' and 'label').
    test_rows = load_csv_like(args.test_csv)
    # Build the model and locate the layer Grad-CAM should hook.
    model, target_layer = get_model_with_target_layer(args.model, args.num_classes, pretrained=False)
    model.to(device)
    # Restore trained weights.
    ck = torch.load(args.checkpoint, map_location=device)
    model.load_state_dict(ck['model_state_dict'])
    model.eval()
    print(f'Loaded model from {args.checkpoint} onto {device}.')
    # Data setup (validation-time transforms: no augmentation).
    test_tf = get_transforms('val', args.img_size)
    test_ds = FractureDataset(test_rows, img_root=args.img_root, transform=test_tf)
    test_loader = DataLoader(test_ds, batch_size=1, shuffle=False)  # batch size 1: one CAM per image
    # Initialize Grad-CAM. FIX: `use_cuda` was removed in recent
    # pytorch-grad-cam releases (the device is inferred from the model),
    # which made this constructor raise TypeError; fall back gracefully so
    # both old and new library versions work.
    try:
        cam = GradCAM(model=model, target_layers=[target_layer], use_cuda=(device.type == 'cuda'))
    except TypeError:
        cam = GradCAM(model=model, target_layers=[target_layer])
    # Setup output directory.
    os.makedirs(args.out_dir, exist_ok=True)
    class_names = args.class_names.split(',')
    print(f"Starting Grad-CAM analysis on {len(test_ds)} images...")
    for i, (imgs, labels, img_paths, raw_imgs) in enumerate(test_loader):
        imgs = imgs.to(device)
        true_label = labels.item()
        # 1. Prediction (no grad needed here; the CAM call runs its own backward pass).
        with torch.no_grad():
            outputs = model(imgs)
            predicted_label = outputs.softmax(dim=1).argmax(dim=1).item()
        # Target the PREDICTED class so the heatmap explains the model's decision.
        targets = [ClassifierOutputTarget(predicted_label)]
        # 2. Generate CAM (H x W float map at the network input resolution).
        grayscale_cam = cam(input_tensor=imgs, targets=targets)
        grayscale_cam = grayscale_cam[0, :]
        # 3. Visualization. raw_img is the unnormalized [0, 1] image at its
        # ORIGINAL resolution; FIX: resize it to the CAM's resolution first,
        # since show_cam_on_image requires matching shapes.
        raw_img_for_viz = raw_imgs.squeeze(0).numpy()
        if raw_img_for_viz.shape[:2] != grayscale_cam.shape[:2]:
            raw_img_for_viz = cv2.resize(raw_img_for_viz, (grayscale_cam.shape[1], grayscale_cam.shape[0]))
        # show_cam_on_image already returns a uint8 RGB image in [0, 255].
        # FIX: the original multiplied by 255 again (uint8 overflow) and
        # converted RGB->BGR before handing the array to PIL (which expects
        # RGB), producing corrupted, channel-swapped output files.
        visualization = show_cam_on_image(raw_img_for_viz, grayscale_cam, use_rgb=True)
        visualization_pil = Image.fromarray(visualization)
        # 4. Save as out_dir/<true_class>/CAM_T<true>_P<pred>_<original_name>.
        path_obj = Path(img_paths[0])
        class_name = class_names[true_label]
        save_dir = os.path.join(args.out_dir, class_name)
        os.makedirs(save_dir, exist_ok=True)
        pred_class_name = class_names[predicted_label]
        file_name = f'CAM_T{class_name}_P{pred_class_name}_{path_obj.name}'
        save_path = os.path.join(save_dir, file_name)
        visualization_pil.save(save_path)
        if i % 10 == 0:
            print(f"Processed {i+1}/{len(test_ds)}. Saved to: {save_path}")
    print("Grad-CAM analysis complete. Results saved to:", args.out_dir)
# ----------------------------- Main -----------------------------
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Run Grad-CAM analysis on test data.')
    parser.add_argument('--checkpoint', type=str, required=True, help='Path to the model checkpoint (e.g., outputs/swin_mps/best.pth)')
    parser.add_argument('--test-csv', type=str, required=True, help='Path to the test CSV file.')
    parser.add_argument('--img-root', type=str, default='.', help='Root directory for images.')
    # FIX: get_model_with_target_layer also supports densenet, but the CLI
    # choices omitted it; adding it is backward-compatible (default unchanged).
    parser.add_argument('--model', type=str, default='swin', choices=['swin', 'convnext', 'densenet'])
    parser.add_argument('--num-classes', type=int, default=8)
    parser.add_argument('--img-size', type=int, default=224)
    parser.add_argument('--out-dir', type=str, default='outputs/analysis', help='Directory to save CAM visualizations.')
    parser.add_argument('--class-names', type=str, required=True,
                        help='Comma-separated list of class names (e.g., "A,B,C")')
    args = parser.parse_args()
    # Dependency check. NOTE(review): the module-level `from pytorch_grad_cam
    # import ...` at the top of the file would already have raised before this
    # point, so this guard is effectively redundant; kept only for the
    # friendly install hint.
    try:
        import pytorch_grad_cam
    except ImportError:
        print("ERROR: pytorch-grad-cam library not found. Please install it:")
        print("pip install pytorch-grad-cam")
        exit(1)
    analyze(args)