# Source: Hugging Face repository "thyroid-training-scripts", file gradcam_fixed.py
# Uploaded by Johnyquest7 ("Upload gradcam_fixed.py", commit 8ee4a1b, verified)
"""
Thyroid Grad-CAM Visualization (Fixed for SwinV2)
"""
import os, sys, math, json, random, warnings, traceback
warnings.filterwarnings("ignore")
import numpy as np
from PIL import Image
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from datasets import load_dataset
from transformers import AutoImageProcessor, AutoModelForImageClassification
# Seed shared by the Python, NumPy and PyTorch RNGs below; main() also passes
# it to the dataset shuffle and the train/test split.
SEED = 42

# Hugging Face Hub identifiers for the dataset and the fine-tuned model.
HF_USERNAME = "Johnyquest7"
MODEL_NAME = HF_USERNAME + "/ML-Inter_thyroid"
DATASET_NAME = "BTX24/thyroid-cancer-classification-ultrasound-dataset"

# Inference batch size and the directory the overlay PNGs are written to.
BATCH_SIZE = 16
OUTPUT_DIR = "./gradcam_outputs"

# Seed the three RNGs in the same order as before: random, NumPy, PyTorch.
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(SEED)
del _seed_rng
def _predict_labels(model, processor, dataset, device, batch_size):
    """Run batched, gradient-free inference over *dataset*.

    Each item is read as a mapping with an ``"image"`` entry (converted to RGB)
    and a ``"label"`` entry.

    Returns:
        ``(y_true, y_pred)`` as NumPy arrays: the stored labels and the argmax
        of the model logits, in dataset order.
    """
    logits, labels = [], []
    total = len(dataset)
    for start in range(0, total, batch_size):
        batch_items = [dataset[j] for j in range(start, min(start + batch_size, total))]
        images = [item["image"].convert("RGB") for item in batch_items]
        inputs = processor(images, return_tensors="pt")
        with torch.no_grad():
            outputs = model(pixel_values=inputs["pixel_values"].to(device))
        logits.extend(outputs.logits.cpu().numpy())
        labels.extend(item["label"] for item in batch_items)
    return np.array(labels), np.argmax(np.array(logits), axis=1)


def _gradcam_heatmap(feat: torch.Tensor, grads: torch.Tensor, out_size: tuple) -> np.ndarray:
    """Build a Grad-CAM heat map from token features and their gradients.

    Args:
        feat: Hooked layer output for one sample, indexed as ``[tokens, channels]``.
        grads: Gradient of the target score w.r.t. ``feat``, same layout.
        out_size: ``(height, width)`` the map is bilinearly upsampled to.

    Returns:
        A 2-D float array of shape ``out_size`` scaled to ``[0, 1]``.

    Raises:
        ValueError: If the token count is not a perfect square, so the tokens
            cannot be laid out on a square grid.
    """
    # Channel weights are the gradients averaged over all tokens.
    weights = grads.mean(dim=0, keepdim=True)  # [1, C]
    # squeeze(-1) rather than squeeze(): keeps a 1-D result even for one token.
    cam = torch.matmul(feat, weights.t()).squeeze(-1)  # [tokens]

    num_tokens = cam.shape[0]
    side = int(round(math.sqrt(num_tokens)))
    if side * side != num_tokens:
        raise ValueError(f"Cannot reshape {num_tokens} tokens into a square grid")

    cam = F.relu(cam.reshape(side, side))
    cam = cam - cam.min()
    cam = cam / (cam.max() + 1e-8)  # epsilon guards an all-zero map

    cam = F.interpolate(cam[None, None], size=out_size, mode="bilinear", align_corners=False)
    return cam[0, 0].cpu().numpy()


def _save_overlay(img_np, cam, title, fname):
    """Draw *cam* semi-transparently over *img_np* and save the figure to *fname*."""
    fig, ax = plt.subplots(figsize=(6, 6))
    try:
        ax.imshow(img_np)
        ax.imshow(cam, cmap="jet", alpha=0.5)
        ax.set_title(title)
        ax.axis("off")
        fig.savefig(fname, bbox_inches="tight", dpi=150)
    finally:
        # Close even when drawing/saving fails so figures do not accumulate.
        plt.close(fig)


def main() -> None:
    """Generate Grad-CAM overlays for a few test samples.

    Steps:
      1. Load the processor and model named by ``MODEL_NAME`` and the dataset
         named by ``DATASET_NAME``, then take a stratified 20% test split.
      2. Predict every test sample and pick up to five correctly and up to
         five incorrectly classified samples at random.
      3. For each picked sample, back-propagate the predicted-class logit
         through hooks on the last block's ``layernorm_after`` and save a
         heat-map overlay PNG in ``OUTPUT_DIR``.

    A failure on a single sample is printed with its traceback and skipped.
    """
    print("=" * 60)
    print("Thyroid Grad-CAM Visualization")
    print("=" * 60)
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nDevice: {device}")

    processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
    model = AutoModelForImageClassification.from_pretrained(MODEL_NAME).to(device).eval()
    id2label = model.config.id2label

    ds = load_dataset(DATASET_NAME, split="train")
    ds = ds.shuffle(seed=SEED)
    train_test = ds.train_test_split(test_size=0.2, stratify_by_column="label", seed=SEED)
    test_ds = train_test["test"]
    print(f"Test samples: {len(test_ds)}")

    # Get predictions first (before any hook is registered).
    y_true, y_pred = _predict_labels(model, processor, test_ds, device, BATCH_SIZE)

    correct_idx = [i for i, (t, p) in enumerate(zip(y_true, y_pred)) if t == p]
    incorrect_idx = [i for i, (t, p) in enumerate(zip(y_true, y_pred)) if t != p]
    random.shuffle(correct_idx)
    random.shuffle(incorrect_idx)
    correct_sel, incorrect_sel = correct_idx[:5], incorrect_idx[:5]
    selected = correct_sel + incorrect_sel
    print(f"\nGenerating Grad-CAM for {len(selected)} samples ({len(correct_sel)} correct, {len(incorrect_sel)} incorrect)...")

    # Hooks stash the target layer's output and the gradient flowing into it.
    gradcam_data = {}

    def fwd_hook(_module, _inputs, output):
        gradcam_data["feat"] = output.detach()

    def bwd_hook(_module, _grad_input, grad_output):
        gradcam_data["grad"] = grad_output[0].detach()

    target_layer = model.swinv2.encoder.layers[-1].blocks[-1].layernorm_after
    fwd_handle = target_layer.register_forward_hook(fwd_hook)
    bwd_handle = target_layer.register_full_backward_hook(bwd_hook)

    try:
        for idx in selected:
            try:
                # Drop the previous sample's tensors so a hook that fails to
                # fire raises KeyError instead of silently reusing stale data.
                gradcam_data.clear()

                item = test_ds[idx]
                img = item["image"].convert("RGB")
                label = item["label"]
                inputs = processor(img, return_tensors="pt")
                img_tensor = inputs["pixel_values"].to(device).requires_grad_(True)

                model.zero_grad()
                outputs = model(pixel_values=img_tensor)
                target_class = int(y_pred[idx])
                outputs.logits[0, target_class].backward()

                # Upsample to the processed image's own resolution so the
                # overlay lines up whatever input size the processor uses.
                out_size = tuple(img_tensor.shape[-2:])
                cam = _gradcam_heatmap(gradcam_data["feat"][0], gradcam_data["grad"][0], out_size)

                # Undo the processor's normalisation only approximately:
                # min-max rescale the tensor into [0, 1] for display.
                img_np = img_tensor[0].detach().cpu().permute(1, 2, 0).numpy()
                img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min() + 1e-8)

                pred_name = id2label.get(target_class, str(target_class))
                true_name = id2label.get(label, str(label))
                status = "CORRECT" if y_true[idx] == y_pred[idx] else "WRONG"
                fname = f"{OUTPUT_DIR}/gradcam_{status}_sample{idx}_{pred_name}_vs_{true_name}.png"
                _save_overlay(img_np, cam, f"{status}: Pred={pred_name} | True={true_name}", fname)
                print(f" Saved {fname}")
            except Exception as e:
                # Best effort: report the sample and carry on with the rest.
                print(f" Skipped sample {idx}: {e}")
                traceback.print_exc()
    finally:
        # Always detach the hooks, even on an unexpected error or interrupt.
        fwd_handle.remove()
        bwd_handle.remove()

    print("\nGrad-CAM complete.")
# Run the visualisation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()