File size: 7,772 Bytes
bf07f10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
import os
import sys
import argparse
import time
from pathlib import Path
from typing import List, Dict

import numpy as np
from PIL import Image

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.models as tvmodels
import timm

from sklearn.metrics import precision_recall_fscore_support, confusion_matrix
import cv2
import csv
import matplotlib.pyplot as plt

# Import necessary modules for Grad-CAM
from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

# Add parent directory to path for imports
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))

from src.utils import get_device, get_model, get_transforms

# Pick the compute device once at import time so every consumer in this
# script shares the same DEVICE.  NOTE(review): get_device() comes from
# src.utils; the exact selection logic (cuda/mps/cpu order) is not visible
# here — confirm against src/utils.py.
DEVICE = get_device()
print(f"Using device: {DEVICE}")

# ----------------------------- Dataset (Reusing logic from pipeline.py) -----------------------------

class FractureDataset(Dataset):
    """Dataset of fracture images backed by a list of CSV row dicts.

    Each row must carry at least ``image_path`` and ``label`` (parseable as
    int).  ``__getitem__`` returns ``(img, label, resolved_path, raw_img)``
    where ``raw_img`` is the unnormalized HxWx3 float array in [0, 1] kept
    for CAM overlay visualization.
    """

    def __init__(self, df, img_root: str = '.', transform=None):
        self.entries = df
        self.img_root = img_root
        self.transform = transform
        # CRITICAL PATH FIX: some CSVs store paths with this redundant
        # directory prefix; strip it before joining with img_root.
        self.redundant_prefix = 'balanced_augmented_dataset/'
        self.redundant_prefix_len = len(self.redundant_prefix)

    def __len__(self):
        # BUG FIX: original was ``len(len(self.entries))`` which raises
        # TypeError (len() of an int) on every call.
        return len(self.entries)

    def __getitem__(self, idx):
        row = self.entries[idx]
        img_path = row['image_path']

        # PATH CLEANING FIX: strip the redundant prefix before resolving.
        if img_path.startswith(self.redundant_prefix):
            img_path = img_path[self.redundant_prefix_len:]

        if not os.path.isabs(img_path):
            img_path = os.path.join(self.img_root, img_path)

        img = Image.open(img_path).convert('RGB')

        # Raw [0, 1] float copy returned for visualization purposes.
        raw_img = np.array(img).astype(np.float32) / 255.0

        label = int(row['label'])
        if self.transform:
            img = self.transform(img)

        return img, label, img_path, raw_img


# ----------------------------- Model selection with Grad-CAM target layers -----------------------------

def get_model_with_target_layer(name: str, num_classes: int, pretrained: bool=True):
    """Build a classifier and locate the layer Grad-CAM should hook.

    Returns a ``(model, target_layer)`` pair; raises ValueError for
    architectures without a known Grad-CAM target layer.
    """
    model = get_model(name, num_classes, pretrained=pretrained)
    key = name.lower()

    if key.startswith('swin'):
        # Swin: final norm of the last block in the last stage.
        return model, model.layers[-1].blocks[-1].norm2
    if key.startswith('convnext'):
        # ConvNext: the whole last stage of the feature extractor.
        return model, model.stages[-1]
    if key.startswith('densenet'):
        # DenseNet: the final batch norm of the feature trunk.
        return model, model.features.norm5

    raise ValueError(f'Unknown target layer for model: {key}')


# ----------------------------- Helpers: CSV loader -----------------------------

def load_csv_like(path: str) -> List[Dict]:
    """Read *path* as a header-bearing CSV and return its rows as dicts."""
    with open(path, 'r', encoding='utf8') as handle:
        return list(csv.DictReader(handle))

# ----------------------------- Grad-CAM Analysis -----------------------------

def analyze(args):
    """Run Grad-CAM over every image in the test CSV and save overlays.

    Loads the checkpoint, hooks the model's CAM target layer, then for each
    test image computes a CAM for the *predicted* class and writes the
    heat-map overlay to ``args.out_dir/<true_class_name>/``.
    """
    device = DEVICE

    # Load test rows (list of dicts with image_path / label columns).
    test_rows = load_csv_like(args.test_csv)

    # Build the model and locate the layer Grad-CAM should hook.
    model, target_layer = get_model_with_target_layer(args.model, args.num_classes, pretrained=False)
    model.to(device)

    # Restore trained weights from the checkpoint.
    ck = torch.load(args.checkpoint, map_location=device)
    model.load_state_dict(ck['model_state_dict'])
    model.eval()
    print(f'Loaded model from {args.checkpoint} onto {device}.')

    # Data setup; batch size 1 so each CAM maps to exactly one image.
    test_tf = get_transforms('val', args.img_size)
    test_ds = FractureDataset(test_rows, img_root=args.img_root, transform=test_tf)
    test_loader = DataLoader(test_ds, batch_size=1, shuffle=False)

    # NOTE(review): recent pytorch-grad-cam releases removed the ``use_cuda``
    # kwarg (the device is inferred from the model) — confirm the installed
    # version still accepts it before upgrading.
    cam = GradCAM(model=model, target_layers=[target_layer], use_cuda=(device.type == 'cuda'))

    os.makedirs(args.out_dir, exist_ok=True)

    class_names = args.class_names.split(',')

    print(f"Starting Grad-CAM analysis on {len(test_ds)} images...")

    for i, (imgs, labels, img_paths, raw_imgs) in enumerate(test_loader):
        imgs = imgs.to(device)
        true_label = labels.item()

        # 1. Forward pass to pick the class to explain (the prediction).
        with torch.no_grad():
            outputs = model(imgs)
            predicted_label = outputs.softmax(dim=1).argmax(dim=1).item()

        # Target the PREDICTED class for visualization.
        targets = [ClassifierOutputTarget(predicted_label)]

        # 2. Generate the CAM for this single image.
        grayscale_cam = cam(input_tensor=imgs, targets=targets)
        grayscale_cam = grayscale_cam[0, :]

        # 3. Overlay the CAM on the raw (unnormalized, [0, 1]) image.
        raw_img_for_viz = raw_imgs.squeeze(0).numpy()
        visualization = show_cam_on_image(raw_img_for_viz, grayscale_cam, use_rgb=True)

        # BUG FIX: show_cam_on_image already returns a uint8 RGB array, so
        # the old ``(visualization * 255).astype(np.uint8)`` overflowed, and
        # the RGB2BGR conversion before Image.fromarray (which expects RGB)
        # saved channel-swapped colors.  Hand the RGB array to PIL directly.
        visualization_pil = Image.fromarray(visualization)

        # 4. Save under out_dir/<true_class>/ with truth + prediction in name.
        path_obj = Path(img_paths[0])
        class_name = class_names[true_label]

        save_dir = os.path.join(args.out_dir, class_name)
        os.makedirs(save_dir, exist_ok=True)

        pred_class_name = class_names[predicted_label]
        file_name = f'CAM_T{class_name}_P{pred_class_name}_{path_obj.name}'
        save_path = os.path.join(save_dir, file_name)

        visualization_pil.save(save_path)

        if i % 10 == 0:
            print(f"Processed {i+1}/{len(test_ds)}. Saved to: {save_path}")

    print("Grad-CAM analysis complete. Results saved to:", args.out_dir)


# ----------------------------- Main -----------------------------

if __name__ == '__main__':
    ap = argparse.ArgumentParser(description='Run Grad-CAM analysis on test data.')
    ap.add_argument('--checkpoint', type=str, required=True,
                    help='Path to the model checkpoint (e.g., outputs/swin_mps/best.pth)')
    ap.add_argument('--test-csv', type=str, required=True,
                    help='Path to the test CSV file.')
    ap.add_argument('--img-root', type=str, default='.',
                    help='Root directory for images.')
    ap.add_argument('--model', type=str, default='swin', choices=['swin','convnext'])
    ap.add_argument('--num-classes', type=int, default=8)
    ap.add_argument('--img-size', type=int, default=224)
    ap.add_argument('--out-dir', type=str, default='outputs/analysis',
                    help='Directory to save CAM visualizations.')
    ap.add_argument('--class-names', type=str, required=True,
                    help='Comma-separated list of class names (e.g., "A,B,C")')

    args = ap.parse_args()

    # Friendly dependency check.  NOTE(review): the top-of-file import of
    # pytorch_grad_cam would already have raised before reaching this point
    # if the library were missing — this guard only fires if sys.modules was
    # tampered with; consider moving the check ahead of that import.
    try:
        import pytorch_grad_cam  # noqa: F401
    except ImportError:
        print("ERROR: pytorch-grad-cam library not found. Please install it:")
        print("pip install pytorch-grad-cam")
        exit(1)

    analyze(args)