File size: 5,282 Bytes
8ee4a1b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
"""
Thyroid Grad-CAM Visualization (Fixed for SwinV2)
"""
import os, sys, math, json, random, warnings, traceback
warnings.filterwarnings("ignore")

import numpy as np
from PIL import Image
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
from datasets import load_dataset
from transformers import AutoImageProcessor, AutoModelForImageClassification

# Hugging Face Hub identifiers for the dataset and the image-classification checkpoint.
HF_USERNAME = "Johnyquest7"
DATASET_NAME = "BTX24/thyroid-cancer-classification-ultrasound-dataset"
MODEL_NAME = "/".join((HF_USERNAME, "ML-Inter_thyroid"))

# Directory the rendered Grad-CAM overlays are written to.
OUTPUT_DIR = "./gradcam_outputs"

# Reproducibility / inference settings.
SEED = 42
BATCH_SIZE = 16

# Seed each library's RNG (they keep independent state, so order is irrelevant).
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

def _predict_test_set(model, processor, test_ds, device):
    """Classify every row of *test_ds* in batches of ``BATCH_SIZE``.

    Returns:
        ``(y_true, y_pred)``: 1-D integer numpy arrays aligned with the
        dataset's row order.
    """
    all_logits, all_labels = [], []
    for start in range(0, len(test_ds), BATCH_SIZE):
        # Slicing a datasets.Dataset returns a dict of column lists, which is
        # cheaper than fetching the rows one by one.
        batch = test_ds[start:start + BATCH_SIZE]
        images = [image.convert("RGB") for image in batch["image"]]
        inputs = processor(images, return_tensors="pt")
        with torch.no_grad():
            outputs = model(pixel_values=inputs["pixel_values"].to(device))
        all_logits.append(outputs.logits.cpu().numpy())
        all_labels.extend(batch["label"])

    y_true = np.array(all_labels, dtype=int)
    if not all_logits:
        return y_true, np.array([], dtype=int)
    y_pred = np.concatenate(all_logits).argmax(axis=1)
    return y_true, y_pred


def _compute_cam(feat, grads, out_size):
    """Build a Grad-CAM heat-map from one sample's token activations and gradients.

    Args:
        feat: ``[num_tokens, channels]`` activations of the hooked layer.
        grads: gradient of the target logit w.r.t. ``feat`` (same shape).
        out_size: ``(height, width)`` to upsample the map to, so it lines up
            with the image it is overlaid on.

    Returns:
        Float numpy array of shape ``out_size`` rescaled to ``[0, 1]``.

    Raises:
        ValueError: if the tokens cannot be arranged on a square grid.
    """
    num_tokens = feat.shape[0]
    side = math.isqrt(num_tokens)
    if side == 0 or side * side != num_tokens:
        raise ValueError(f"{num_tokens} tokens cannot be reshaped into a square grid")

    # Channel weights are the token-averaged gradients; the CAM is the
    # weighted sum of channels, keeping only positive evidence.
    weights = grads.mean(dim=0)                        # [C]
    cam = F.relu(feat @ weights).reshape(side, side)   # [H, W]

    cam = cam - cam.min()
    cam = cam / (cam.max() + 1e-8)

    cam = F.interpolate(cam[None, None], size=tuple(out_size), mode="bilinear", align_corners=False)
    return cam[0, 0].cpu().numpy()


def _save_overlay(img_tensor, cam, title, fname):
    """Save *cam* alpha-blended over the processed input image to *fname*."""
    # Processor output is presumably mean/std-normalized (not in [0, 1]),
    # so min-max rescale it for display.
    img_np = img_tensor[0].detach().cpu().permute(1, 2, 0).numpy()
    img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min() + 1e-8)

    fig, ax = plt.subplots(figsize=(6, 6))
    try:
        ax.imshow(img_np)
        ax.imshow(cam, cmap="jet", alpha=0.5)
        ax.set_title(title)
        ax.axis("off")
        fig.savefig(fname, bbox_inches="tight", dpi=150)
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)


def main():
    """Write Grad-CAM overlays for a sample of held-out thyroid ultrasound images.

    Loads the processor and classifier from ``MODEL_NAME``, rebuilds the
    stratified 80/20 split of ``DATASET_NAME``, predicts the whole test split,
    then saves overlays for up to 5 correctly and 5 incorrectly classified
    samples into ``OUTPUT_DIR``. Failures on individual samples are reported
    and skipped.
    """
    print("=" * 60)
    print("Thyroid Grad-CAM Visualization")
    print("=" * 60)
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nDevice: {device}")
    processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
    model = AutoModelForImageClassification.from_pretrained(MODEL_NAME).to(device).eval()
    # Grad-CAM only needs gradients w.r.t. activations; skip parameter grads.
    model.requires_grad_(False)
    id2label = model.config.id2label

    ds = load_dataset(DATASET_NAME, split="train")
    ds = ds.shuffle(seed=SEED)
    train_test = ds.train_test_split(test_size=0.2, stratify_by_column="label", seed=SEED)
    test_ds = train_test["test"]
    print(f"Test samples: {len(test_ds)}")

    # Get predictions first
    y_true, y_pred = _predict_test_set(model, processor, test_ds, device)

    correct_idx = [i for i, (t, p) in enumerate(zip(y_true, y_pred)) if t == p]
    incorrect_idx = [i for i, (t, p) in enumerate(zip(y_true, y_pred)) if t != p]
    random.shuffle(correct_idx)
    random.shuffle(incorrect_idx)
    selected = correct_idx[:5] + incorrect_idx[:5]
    print(f"\nGenerating Grad-CAM for {len(selected)} samples ({len(correct_idx[:5])} correct, {len(incorrect_idx[:5])} incorrect)...")

    gradcam_data = {}

    def fwd_hook(module, inputs, output):
        # Store the activations and attach a tensor hook so the gradient
        # captured is exactly the gradient of this stored tensor.
        gradcam_data["feat"] = output.detach()

        def save_grad(grad):
            gradcam_data["grad"] = grad.detach()

        if output.requires_grad:
            output.register_hook(save_grad)

    # Hook the backbone's final LayerNorm. Its [batch, tokens, C] output is
    # (in HF's Swinv2 implementation) what the pooler averages before the
    # classifier. The last block's `layernorm_after` only normalizes the MLP
    # residual branch, not the full token representation.
    target_layer = model.swinv2.layernorm
    fwd_handle = target_layer.register_forward_hook(fwd_hook)
    try:
        for idx in selected:
            # Never let a previous sample's activations/gradients leak through.
            gradcam_data.clear()
            try:
                item = test_ds[idx]
                img = item["image"].convert("RGB")
                label = item["label"]
                inputs = processor(img, return_tensors="pt")
                img_tensor = inputs["pixel_values"].to(device).requires_grad_(True)
                outputs = model(pixel_values=img_tensor)
                target_class = int(y_pred[idx])
                outputs.logits[0, target_class].backward()
                if "feat" not in gradcam_data or "grad" not in gradcam_data:
                    raise RuntimeError("Grad-CAM hooks did not capture activations and gradients")

                # Upsample to the processed image's own size so the overlay aligns.
                cam = _compute_cam(gradcam_data["feat"][0], gradcam_data["grad"][0], img_tensor.shape[-2:])

                pred_name = id2label.get(target_class, str(target_class))
                true_name = id2label.get(label, str(label))
                status = "CORRECT" if y_true[idx] == y_pred[idx] else "WRONG"
                fname = os.path.join(OUTPUT_DIR, f"gradcam_{status}_sample{idx}_{pred_name}_vs_{true_name}.png")
                _save_overlay(img_tensor, cam, f"{status}: Pred={pred_name} | True={true_name}", fname)
                print(f"  Saved {fname}")
            except Exception as e:
                # Best-effort per sample: report and continue with the rest.
                print(f"  Skipped sample {idx}: {e}")
                traceback.print_exc()
    finally:
        fwd_handle.remove()

    print("\nGrad-CAM complete.")


if __name__ == "__main__":
    main()