"""
Thyroid Grad-CAM Visualization (Fixed for SwinV2)
"""
import os, sys, math, json, random, warnings, traceback
warnings.filterwarnings("ignore")
import numpy as np
from PIL import Image
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from datasets import load_dataset
from transformers import AutoImageProcessor, AutoModelForImageClassification
# --- Hugging Face Hub identifiers -------------------------------------------
HF_USERNAME = "Johnyquest7"
MODEL_NAME = HF_USERNAME + "/ML-Inter_thyroid"
DATASET_NAME = "BTX24/thyroid-cancer-classification-ultrasound-dataset"

# --- Run configuration ------------------------------------------------------
OUTPUT_DIR = "./gradcam_outputs"  # directory the overlay PNGs are written to
BATCH_SIZE = 16                   # images per forward pass during prediction
SEED = 42                         # shared by dataset shuffling and all RNGs

# Seed every random number generator the script relies on so that the
# sample selection is reproducible from run to run.
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(SEED)
del _seed_rng  # keep the module namespace identical to the hand-written form
def _predict_labels(model, processor, test_ds, device):
    """Classify every item of ``test_ds`` in batches of ``BATCH_SIZE``.

    Args:
        model: Image-classification model whose output exposes ``.logits``.
        processor: Callable turning a list of PIL images into a mapping that
            contains ``"pixel_values"``.
        test_ds: Indexable dataset whose items have ``"image"`` and
            ``"label"`` entries.
        device: Torch device the pixel values are moved to.

    Returns:
        Tuple ``(y_true, y_pred)`` of NumPy arrays: the dataset labels and the
        arg-max of the model logits.  Both are empty for an empty dataset.
    """
    all_logits, all_labels = [], []
    total = len(test_ds)
    for start in range(0, total, BATCH_SIZE):
        stop = min(start + BATCH_SIZE, total)
        batch_items = [test_ds[j] for j in range(start, stop)]
        images = [item["image"].convert("RGB") for item in batch_items]
        inputs = processor(images, return_tensors="pt")
        with torch.no_grad():
            outputs = model(pixel_values=inputs["pixel_values"].to(device))
        all_logits.extend(outputs.logits.cpu().numpy())
        all_labels.extend(item["label"] for item in batch_items)

    y_true = np.array(all_labels)
    if not all_logits:
        # np.argmax(..., axis=1) raises on an empty 1-D array.
        return y_true, np.empty(0, dtype=np.int64)
    return y_true, np.argmax(np.array(all_logits), axis=1)


def _gradcam_heatmap(feat, grads, out_size):
    """Combine hooked activations and gradients into a [0, 1] heat-map.

    Args:
        feat: Activations of one image, shape ``[N, C]`` (tokens x channels).
        grads: Gradients of the class score w.r.t. ``feat``, same shape.
        out_size: ``(height, width)`` the heat-map is upsampled to.

    Returns:
        NumPy array of shape ``out_size``.

    Raises:
        ValueError: If ``N`` is not a perfect square, so the tokens cannot be
            laid out on a square grid.
    """
    # Channel weights are the gradients averaged over all token positions.
    weights = grads.mean(dim=0)  # [C]
    cam = feat @ weights         # [N]

    n_tokens = cam.shape[0]
    side = int(math.sqrt(n_tokens) + 0.5)
    if side * side != n_tokens:
        raise ValueError(
            f"Cannot reshape {n_tokens} tokens into a square feature map"
        )

    cam = F.relu(cam.reshape(side, side))
    cam = cam - cam.min()
    cam = cam / (cam.max() + 1e-8)  # epsilon guards an all-zero map

    cam = F.interpolate(
        cam[None, None], size=out_size, mode="bilinear", align_corners=False
    )
    return cam[0, 0].cpu().numpy()


def _save_overlay(img_np, cam, title, fname):
    """Draw ``cam`` semi-transparently over ``img_np`` and save it to ``fname``."""
    fig = plt.figure(figsize=(6, 6))
    try:
        plt.imshow(img_np)
        plt.imshow(cam, cmap="jet", alpha=0.5)
        plt.title(title)
        plt.axis("off")
        plt.savefig(fname, bbox_inches="tight", dpi=150)
    finally:
        # Close even when drawing/saving fails so figures do not pile up.
        plt.close(fig)


def main(samples_per_group=5):
    """Generate Grad-CAM overlays for correct and incorrect test predictions.

    Loads the model and dataset named by the module constants, predicts the
    held-out 20 % split, picks up to ``samples_per_group`` correctly and
    ``samples_per_group`` incorrectly classified samples at random and writes
    one overlay PNG per sample into ``OUTPUT_DIR``.  A sample that fails is
    reported with its traceback and skipped.

    Args:
        samples_per_group: Maximum number of correct and of incorrect samples
            to visualise (non-negative).  Defaults to 5.
    """
    print("=" * 60)
    print("Thyroid Grad-CAM Visualization")
    print("=" * 60)
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nDevice: {device}")

    processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
    model = AutoModelForImageClassification.from_pretrained(MODEL_NAME).to(device).eval()
    id2label = model.config.id2label

    ds = load_dataset(DATASET_NAME, split="train")
    ds = ds.shuffle(seed=SEED)
    train_test = ds.train_test_split(test_size=0.2, stratify_by_column="label", seed=SEED)
    test_ds = train_test["test"]
    print(f"Test samples: {len(test_ds)}")

    # Get predictions first
    y_true, y_pred = _predict_labels(model, processor, test_ds, device)

    is_correct = y_true == y_pred
    correct_idx = np.flatnonzero(is_correct).tolist()
    incorrect_idx = np.flatnonzero(~is_correct).tolist()
    random.shuffle(correct_idx)
    random.shuffle(incorrect_idx)
    picked_correct = correct_idx[:samples_per_group]
    picked_incorrect = incorrect_idx[:samples_per_group]
    selected = picked_correct + picked_incorrect
    print(f"\nGenerating Grad-CAM for {len(selected)} samples ({len(picked_correct)} correct, {len(picked_incorrect)} incorrect)...")

    # Register hooks on last stage
    gradcam_data = {}

    def fwd_hook(module, inputs, output):
        gradcam_data["feat"] = output.detach()

    def bwd_hook(module, grad_input, grad_output):
        gradcam_data["grad"] = grad_output[0].detach()

    target_layer = model.swinv2.encoder.layers[-1].blocks[-1].layernorm_after
    fwd_handle = target_layer.register_forward_hook(fwd_hook)
    bwd_handle = target_layer.register_full_backward_hook(bwd_hook)

    try:
        for idx in selected:
            try:
                # Drop the previous sample's tensors so a hook that did not
                # fire is detected instead of silently reusing stale data.
                gradcam_data.clear()

                item = test_ds[idx]
                img = item["image"].convert("RGB")
                label = item["label"]
                inputs = processor(img, return_tensors="pt")
                img_tensor = inputs["pixel_values"].to(device).requires_grad_(True)

                model.zero_grad()
                outputs = model(pixel_values=img_tensor)
                target_class = int(y_pred[idx])
                score = outputs.logits[0, target_class]
                score.backward()

                if "feat" not in gradcam_data or "grad" not in gradcam_data:
                    raise RuntimeError("Grad-CAM hooks did not capture activations/gradients")

                # Upsample to the resolution of the tensor that is displayed
                # underneath; a fixed size would misalign the overlay whenever
                # the processor emits a different resolution.
                out_size = tuple(int(s) for s in img_tensor.shape[-2:])
                cam = _gradcam_heatmap(
                    gradcam_data["feat"][0],  # shape: [H*W, C]
                    gradcam_data["grad"][0],  # shape: [H*W, C]
                    out_size,
                )

                # Get image for overlay: min-max rescale the processed tensor
                # so it is displayable.
                img_np = img_tensor[0].detach().cpu().permute(1, 2, 0).numpy()
                img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min() + 1e-8)

                pred_name = id2label.get(target_class, str(target_class))
                true_name = id2label.get(label, str(label))
                status = "CORRECT" if y_true[idx] == y_pred[idx] else "WRONG"
                fname = f"{OUTPUT_DIR}/gradcam_{status}_sample{idx}_{pred_name}_vs_{true_name}.png"
                _save_overlay(
                    img_np, cam, f"{status}: Pred={pred_name} | True={true_name}", fname
                )
                print(f" Saved {fname}")
            except Exception as e:
                # Best effort: report the failure and continue with the rest.
                print(f" Skipped sample {idx}: {e}")
                traceback.print_exc()
    finally:
        # Always detach the hooks, even on KeyboardInterrupt.
        fwd_handle.remove()
        bwd_handle.remove()

    print("\nGrad-CAM complete.")
if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not when imported.
    main()
|