Spaces:
Runtime error
Runtime error
File size: 7,772 Bytes
bf07f10 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 |
import os
import sys
import argparse
import time
from pathlib import Path
from typing import List, Dict
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.models as tvmodels
import timm
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix
import cv2
import csv
import matplotlib.pyplot as plt
# Import necessary modules for Grad-CAM
from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
# Add parent directory to path for imports
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
from src.utils import get_device, get_model, get_transforms
# Resolve the compute device once at import time via the project helper;
# presumably prefers GPU/MPS over CPU — behavior defined in src.utils.get_device.
DEVICE = get_device()
print(f"Using device: {DEVICE}")
# ----------------------------- Dataset (Reusing logic from pipeline.py) -----------------------------
class FractureDataset(Dataset):
    """Dataset over fracture-image rows of (image_path, label).

    Each entry in ``df`` is a mapping with at least ``'image_path'`` and
    ``'label'`` keys (e.g. dicts from ``csv.DictReader``). ``__getitem__``
    returns the transformed image tensor, the integer label, the resolved
    path, and the raw [0, 1] float image for CAM overlay rendering.
    """

    def __init__(self, df, img_root: str = '.', transform=None):
        self.entries = df
        self.img_root = img_root
        self.transform = transform
        # CRITICAL PATH FIX: some CSVs store paths with this redundant
        # directory prefix; it must be stripped before joining img_root.
        self.redundant_prefix = 'balanced_augmented_dataset/'
        self.redundant_prefix_len = len(self.redundant_prefix)

    def __len__(self):
        # BUG FIX: was `len(len(self.entries))`, which raises TypeError
        # (len() of an int). Return the row count directly.
        return len(self.entries)

    def __getitem__(self, idx):
        row = self.entries[idx]
        img_path = row['image_path']
        # PATH CLEANING FIX: strip the redundant prefix before resolving.
        if img_path.startswith(self.redundant_prefix):
            img_path = img_path[self.redundant_prefix_len:]
        if not os.path.isabs(img_path):
            img_path = os.path.join(self.img_root, img_path)
        img = Image.open(img_path).convert('RGB')
        # Keep an unnormalized [0, 1] copy for visualization overlays.
        raw_img = np.array(img).astype(np.float32) / 255.0
        label = int(row['label'])
        if self.transform:
            img = self.transform(img)
        return img, label, img_path, raw_img
# ----------------------------- Model selection with Grad-CAM target layers -----------------------------
def get_model_with_target_layer(name: str, num_classes: int, pretrained: bool=True):
    """Build a classifier and select the layer Grad-CAM should hook.

    Returns a (model, target_layer) pair; raises ValueError for
    architectures without a known CAM target layer.
    """
    model = get_model(name, num_classes, pretrained=pretrained)
    name = name.lower()
    # Dispatch table: architecture prefix -> accessor for its CAM layer.
    selectors = (
        # Swin: final norm of the last block in the last stage.
        ('swin', lambda m: m.layers[-1].blocks[-1].norm2),
        # ConvNeXt: the last stage of the backbone.
        ('convnext', lambda m: m.stages[-1]),
        # DenseNet: the final batch norm of the feature extractor.
        ('densenet', lambda m: m.features.norm5),
    )
    for prefix, pick in selectors:
        if name.startswith(prefix):
            return model, pick(model)
    raise ValueError(f'Unknown target layer for model: {name}')
# ----------------------------- Helpers: CSV loader -----------------------------
def load_csv_like(path: str) -> List[Dict]:
    """Read *path* as a header-row CSV and return one dict per data row."""
    with open(path, 'r', encoding='utf8') as fh:
        return list(csv.DictReader(fh))
# ----------------------------- Grad-CAM Analysis -----------------------------
def analyze(args):
    """Run Grad-CAM over every test image and save heatmap overlays.

    Loads the checkpoint onto DEVICE, hooks the model's CAM target layer,
    and writes one overlay per sample under
    ``out_dir/<true_class>/CAM_T<true>_P<pred>_<original_name>``.
    """
    device = DEVICE
    # Load test-split metadata (one dict per CSV row).
    test_rows = load_csv_like(args.test_csv)
    # Get model and the target layer for Grad-CAM.
    model, target_layer = get_model_with_target_layer(args.model, args.num_classes, pretrained=False)
    model.to(device)
    # Restore trained weights from the checkpoint.
    ck = torch.load(args.checkpoint, map_location=device)
    model.load_state_dict(ck['model_state_dict'])
    model.eval()
    print(f'Loaded model from {args.checkpoint} onto {device}.')
    # Data setup; batch size 1 so each CAM maps to exactly one image.
    test_tf = get_transforms('val', args.img_size)
    test_ds = FractureDataset(test_rows, img_root=args.img_root, transform=test_tf)
    test_loader = DataLoader(test_ds, batch_size=1, shuffle=False)
    # NOTE(review): `use_cuda` was removed in pytorch-grad-cam >= 1.5 —
    # confirm the pinned version still accepts this keyword.
    cam = GradCAM(model=model, target_layers=[target_layer], use_cuda=(device.type == 'cuda'))
    # Setup output directory.
    os.makedirs(args.out_dir, exist_ok=True)
    class_names = args.class_names.split(',')
    print(f"Starting Grad-CAM analysis on {len(test_ds)} images...")
    for i, (imgs, labels, img_paths, raw_imgs) in enumerate(test_loader):
        imgs = imgs.to(device)
        true_label = labels.item()
        # 1. Prediction: the CAM target is the PREDICTED class.
        with torch.no_grad():
            outputs = model(imgs)
        predicted_label = outputs.softmax(dim=1).argmax(dim=1).item()
        targets = [ClassifierOutputTarget(predicted_label)]
        # 2. Generate CAM (cam() manages its own gradient bookkeeping).
        grayscale_cam = cam(input_tensor=imgs, targets=targets)
        grayscale_cam = grayscale_cam[0, :]
        # 3. Visualization: raw_img is the unnormalized image in [0, 1].
        raw_img_for_viz = raw_imgs.squeeze(0).numpy()
        visualization = show_cam_on_image(raw_img_for_viz, grayscale_cam, use_rgb=True)
        # BUG FIX: show_cam_on_image already returns a uint8 RGB array in
        # [0, 255]. The old `(visualization * 255)` overflowed uint8
        # (producing garbage pixels) and the RGB->BGR cvtColor swapped
        # channels when saved through PIL, which interprets arrays as RGB.
        # Save the returned array directly.
        visualization_pil = Image.fromarray(visualization)
        # 4. Save under a per-true-class directory, encoding truth and
        # prediction in the file name.
        path_obj = Path(img_paths[0])
        class_name = class_names[true_label]
        save_dir = os.path.join(args.out_dir, class_name)
        os.makedirs(save_dir, exist_ok=True)
        pred_class_name = class_names[predicted_label]
        file_name = f'CAM_T{class_name}_P{pred_class_name}_{path_obj.name}'
        save_path = os.path.join(save_dir, file_name)
        visualization_pil.save(save_path)
        if i % 10 == 0:
            print(f"Processed {i+1}/{len(test_ds)}. Saved to: {save_path}")
    print("Grad-CAM analysis complete. Results saved to:", args.out_dir)
# ----------------------------- Main -----------------------------
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Run Grad-CAM analysis on test data.')
    parser.add_argument('--checkpoint', type=str, required=True, help='Path to the model checkpoint (e.g., outputs/swin_mps/best.pth)')
    parser.add_argument('--test-csv', type=str, required=True, help='Path to the test CSV file.')
    parser.add_argument('--img-root', type=str, default='.', help='Root directory for images.')
    # CONSISTENCY FIX: get_model_with_target_layer also supports densenet,
    # so accept it here too (backward compatible: old choices still work).
    parser.add_argument('--model', type=str, default='swin', choices=['swin', 'convnext', 'densenet'])
    parser.add_argument('--num-classes', type=int, default=8)
    parser.add_argument('--img-size', type=int, default=224)
    parser.add_argument('--out-dir', type=str, default='outputs/analysis', help='Directory to save CAM visualizations.')
    parser.add_argument('--class-names', type=str, required=True,
                        help='Comma-separated list of class names (e.g., "A,B,C")')
    args = parser.parse_args()
    # NOTE(review): pytorch_grad_cam is already imported at module level, so
    # a missing install fails before this check runs; kept only as a
    # friendlier message should the top-level import ever be deferred.
    try:
        import pytorch_grad_cam
    except ImportError:
        print("ERROR: pytorch-grad-cam library not found. Please install it:")
        print("pip install pytorch-grad-cam")
        # IDIOM FIX: sys.exit is the supported way to exit a script; the
        # `exit` builtin is injected by site and may be absent under -S.
        sys.exit(1)
    analyze(args)