"""Thyroid Grad-CAM Visualization (Fixed for SwinV2).

Loads the fine-tuned classifier named by ``MODEL_NAME``, evaluates it on a
stratified 20% split of ``DATASET_NAME``, and writes Grad-CAM overlays for up
to ``SAMPLES_PER_GROUP`` correctly and ``SAMPLES_PER_GROUP`` incorrectly
classified samples into ``OUTPUT_DIR``.
"""
import os
import sys
import math
import json
import random
import warnings
import traceback

# Installed before the third-party imports below, as in the original ordering.
warnings.filterwarnings("ignore")

import numpy as np
from PIL import Image
import matplotlib

matplotlib.use("Agg")  # backend must be selected before pyplot is imported
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from datasets import load_dataset
from transformers import AutoImageProcessor, AutoModelForImageClassification

HF_USERNAME = "Johnyquest7"
DATASET_NAME = "BTX24/thyroid-cancer-classification-ultrasound-dataset"
MODEL_NAME = f"{HF_USERNAME}/ML-Inter_thyroid"
OUTPUT_DIR = "./gradcam_outputs"
SEED = 42
BATCH_SIZE = 16
# Maximum number of correct / incorrect predictions to visualize (each).
SAMPLES_PER_GROUP = 5

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


def _safe_name(name):
    """Return *name* with characters that are unsafe in a filename replaced by ``_``."""
    return "".join(ch if ch.isalnum() or ch in "._-" else "_" for ch in str(name))


def _predict(model, processor, dataset, device):
    """Run batched inference over *dataset* and return ``(y_true, y_pred)`` arrays.

    Each dataset item is read through its ``"image"`` and ``"label"`` keys.
    Returns two empty integer arrays when the dataset has no items, instead of
    failing inside ``np.argmax``.
    """
    all_logits, all_labels = [], []
    total = len(dataset)
    for start in range(0, total, BATCH_SIZE):
        stop = min(start + BATCH_SIZE, total)
        batch_items = [dataset[j] for j in range(start, stop)]
        images = [item["image"].convert("RGB") for item in batch_items]
        inputs = processor(images, return_tensors="pt")
        with torch.no_grad():
            outputs = model(pixel_values=inputs["pixel_values"].to(device))
        all_logits.extend(outputs.logits.cpu().numpy())
        all_labels.extend(item["label"] for item in batch_items)

    if not all_logits:
        empty = np.array([], dtype=int)
        return empty, empty.copy()
    return np.array(all_labels), np.argmax(np.array(all_logits), axis=1)


def _compute_cam(feat, grads, out_size):
    """Build a normalized Grad-CAM heatmap from one sample's hooked tensors.

    Args:
        feat: 2-D activation tensor, expected to be ``[tokens, channels]``.
        grads: gradient of the target score w.r.t. ``feat``, same shape.
        out_size: ``(height, width)`` the heatmap is upsampled to.

    Returns:
        A 2-D ``numpy`` array of shape ``out_size`` with values in ``[0, 1]``.

    Raises:
        ValueError: if the tensors are not 2-D with matching shapes, or the
            token count is not a perfect square (a square grid is assumed).
    """
    if feat.dim() != 2 or feat.shape != grads.shape:
        raise ValueError(
            f"expected matching [tokens, channels] tensors, got "
            f"{tuple(feat.shape)} and {tuple(grads.shape)}"
        )

    # Channel weights are the gradients averaged over all token positions.
    weights = grads.mean(dim=0, keepdim=True)  # [1, C]
    cam = torch.matmul(feat, weights.t()).squeeze(-1)  # [tokens]

    n_tokens = cam.shape[0]
    side = math.isqrt(n_tokens)
    if side * side != n_tokens:
        raise ValueError(f"cannot reshape {n_tokens} tokens into a square grid")
    cam = cam.reshape(side, side)

    # Keep positive evidence only, then rescale to [0, 1].
    cam = F.relu(cam)
    cam = cam - cam.min()
    cam = cam / (cam.max() + 1e-8)

    # Upsample to the resolution of the image it will be drawn over.
    cam = F.interpolate(
        cam[None, None], size=tuple(out_size), mode="bilinear", align_corners=False
    )
    return cam[0, 0].cpu().numpy()


def _save_overlay(img_np, cam, title, fname):
    """Draw *cam* over *img_np* and save it to *fname*; the figure is always closed."""
    fig = plt.figure(figsize=(6, 6))
    try:
        plt.imshow(img_np)
        plt.imshow(cam, cmap="jet", alpha=0.5)
        plt.title(title)
        plt.axis("off")
        plt.savefig(fname, bbox_inches="tight", dpi=150)
    finally:
        # Close even when plotting or saving fails so figures do not pile up.
        plt.close(fig)


def main():
    """Evaluate the model on the test split and save Grad-CAM overlays.

    Per-sample failures are reported with a traceback and skipped, so one bad
    sample does not abort the run.
    """
    print("=" * 60)
    print("Thyroid Grad-CAM Visualization")
    print("=" * 60)

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nDevice: {device}")

    processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
    model = AutoModelForImageClassification.from_pretrained(MODEL_NAME).to(device).eval()
    id2label = model.config.id2label

    ds = load_dataset(DATASET_NAME, split="train")
    ds = ds.shuffle(seed=SEED)
    train_test = ds.train_test_split(test_size=0.2, stratify_by_column="label", seed=SEED)
    test_ds = train_test["test"]
    print(f"Test samples: {len(test_ds)}")

    # Get predictions first
    y_true, y_pred = _predict(model, processor, test_ds, device)

    matches = y_true == y_pred
    correct_idx = np.flatnonzero(matches).tolist()
    incorrect_idx = np.flatnonzero(~matches).tolist()
    random.shuffle(correct_idx)
    random.shuffle(incorrect_idx)
    picked_correct = correct_idx[:SAMPLES_PER_GROUP]
    picked_incorrect = incorrect_idx[:SAMPLES_PER_GROUP]
    selected = picked_correct + picked_incorrect
    print(
        f"\nGenerating Grad-CAM for {len(selected)} samples "
        f"({len(picked_correct)} correct, {len(picked_incorrect)} incorrect)..."
    )

    # Hooks on the last block of the last stage capture activations and gradients.
    gradcam_data = {}

    def fwd_hook(module, input, output):
        gradcam_data["feat"] = output.detach()

    def bwd_hook(module, grad_input, grad_output):
        gradcam_data["grad"] = grad_output[0].detach()

    target_layer = model.swinv2.encoder.layers[-1].blocks[-1].layernorm_after
    fwd_handle = target_layer.register_forward_hook(fwd_hook)
    bwd_handle = target_layer.register_full_backward_hook(bwd_hook)

    try:
        for idx in selected:
            try:
                item = test_ds[idx]
                img = item["image"].convert("RGB")
                label = item["label"]

                inputs = processor(img, return_tensors="pt")
                img_tensor = inputs["pixel_values"].to(device).requires_grad_(True)

                # Drop the previous sample's tensors so a hook that fails to
                # fire cannot silently reuse stale data.
                gradcam_data.clear()
                model.zero_grad()
                outputs = model(pixel_values=img_tensor)
                target_class = int(y_pred[idx])
                score = outputs.logits[0, target_class]
                score.backward()

                if "feat" not in gradcam_data or "grad" not in gradcam_data:
                    raise RuntimeError("Grad-CAM hooks did not capture features/gradients")

                # Size the heatmap to the processed image rather than a fixed
                # 256x256, so the overlay lines up for any processor size.
                cam = _compute_cam(
                    gradcam_data["feat"][0],
                    gradcam_data["grad"][0],
                    img_tensor.shape[-2:],
                )

                # Get image for overlay: processed tensor rescaled to [0, 1].
                img_np = img_tensor[0].detach().cpu().permute(1, 2, 0).numpy()
                img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min() + 1e-8)

                pred_name = id2label.get(target_class, str(target_class))
                true_name = id2label.get(label, str(label))
                status = "CORRECT" if y_true[idx] == y_pred[idx] else "WRONG"
                title = f"{status}: Pred={pred_name} | True={true_name}"
                fname = (
                    f"{OUTPUT_DIR}/gradcam_{status}_sample{idx}_"
                    f"{_safe_name(pred_name)}_vs_{_safe_name(true_name)}.png"
                )
                _save_overlay(img_np, cam, title, fname)
                print(f"  Saved {fname}")
            except Exception as e:
                # Best-effort per sample: report and continue with the next one.
                print(f"  Skipped sample {idx}: {e}")
                traceback.print_exc()
    finally:
        # Always detach the hooks, even on KeyboardInterrupt.
        fwd_handle.remove()
        bwd_handle.remove()

    print("\nGrad-CAM complete.")


if __name__ == "__main__":
    main()