"""Thyroid Ultrasound Evaluation + Grad-CAM (Pure PyTorch, no Trainer).

Loads a fine-tuned image classifier from the Hugging Face Hub, evaluates it
on a held-out 20% split of the dataset, writes the metrics to
``OUTPUT_DIR/test_metrics.json`` and saves Grad-CAM overlays for a few
correctly and incorrectly classified samples.
"""
import json
import math
import os
import random
import sys
import traceback
import warnings
from collections import Counter

# Installed before the third-party imports so that their import-time warnings
# are silenced too.
warnings.filterwarnings("ignore")

import numpy as np
from PIL import Image
import matplotlib

# The backend must be selected before pyplot is imported.
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from datasets import load_dataset
from transformers import AutoImageProcessor, AutoModelForImageClassification
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    roc_auc_score,
    confusion_matrix,
)

HF_USERNAME = "Johnyquest7"
DATASET_NAME = "BTX24/thyroid-cancer-classification-ultrasound-dataset"
MODEL_NAME = f"{HF_USERNAME}/ML-Inter_thyroid"
OUTPUT_DIR = "./eval_outputs"
SEED = 42
BATCH_SIZE = 16
# Number of correct and of incorrect predictions visualised with Grad-CAM.
GRADCAM_SAMPLES_PER_GROUP = 5

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


def _to_rgb(image):
    """Return *image* in RGB mode, converting only when necessary."""
    return image if image.mode == "RGB" else image.convert("RGB")


def _run_inference(model, processor, dataset, device):
    """Run the model over *dataset* in batches of ``BATCH_SIZE``.

    Returns:
        A ``(y_true, y_logits)`` pair of NumPy arrays: the ``"label"`` value
        of every item and the corresponding rows of ``outputs.logits``.
    """
    all_logits, all_labels = [], []
    total = len(dataset)
    n_batches = (total + BATCH_SIZE - 1) // BATCH_SIZE
    for batch_no, start in enumerate(range(0, total, BATCH_SIZE)):
        stop = min(start + BATCH_SIZE, total)
        batch_items = [dataset[j] for j in range(start, stop)]
        images = [_to_rgb(item["image"]) for item in batch_items]
        inputs = processor(images, return_tensors="pt")
        pixel_values = inputs["pixel_values"].to(device)
        with torch.no_grad():
            outputs = model(pixel_values=pixel_values)
        all_logits.extend(outputs.logits.cpu().numpy())
        all_labels.extend(item["label"] for item in batch_items)
        if batch_no % 5 == 0:
            print(f" Batch {batch_no + 1}/{n_batches}")
    return np.array(all_labels), np.array(all_logits)


def _safe_roc_auc(y_true, probs):
    """Return the ROC AUC of the class-1 probability, or ``None`` if undefined.

    The score is never computed from the class-0 column: that would report
    the AUC of the opposite class under the same metric name.
    """
    if probs.ndim != 2 or probs.shape[1] < 2:
        print(" ROC AUC undefined: model does not expose a class-1 probability")
        return None
    try:
        auc = float(roc_auc_score(y_true, probs[:, 1]))
    except ValueError as exc:
        # Raised e.g. when y_true contains a single class.
        print(f" ROC AUC undefined: {exc}")
        return None
    # NaN is not valid JSON, so treat it as "undefined" as well.
    return None if math.isnan(auc) else auc


def _compute_metrics(y_true, y_logits):
    """Compute the test metrics from ground-truth labels and raw logits.

    Sensitivity is the recall of label 1 and specificity the recall of
    label 0 (the script reports label 0 as Benign and label 1 as Malignant).

    Returns:
        ``(metrics, y_pred)`` where *metrics* is a JSON-serialisable dict and
        *y_pred* is the arg-max class per sample.
    """
    y_pred = np.argmax(y_logits, axis=1)
    probs = F.softmax(torch.from_numpy(y_logits), dim=1).numpy()

    acc = accuracy_score(y_true, y_pred)
    prec, rec, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average="weighted", zero_division=0
    )
    auc = _safe_roc_auc(y_true, probs)

    # Pin the label order so the matrix is always 2x2, even when one class
    # never occurs in y_true / y_pred.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    positives = cm[1, 1] + cm[1, 0]
    negatives = cm[0, 0] + cm[0, 1]
    sens = cm[1, 1] / positives if positives > 0 else 0
    spec = cm[0, 0] / negatives if negatives > 0 else 0

    metrics = {
        "test_accuracy": float(acc),
        "test_weighted_f1": float(f1),
        "test_weighted_precision": float(prec),
        "test_weighted_recall": float(rec),
        "test_roc_auc": auc,
        "test_sensitivity": float(sens),
        "test_specificity": float(spec),
        "test_confusion_matrix": cm.tolist(),
    }
    return metrics, y_pred


def _gradcam_heatmap(feat, grads, out_size):
    """Turn one sample's activations and gradients into a normalised heatmap.

    Args:
        feat: Activations of the target layer for a single sample (batch axis
            already removed). Either ``(tokens, channels)`` or a
            channels-last ``(height, width, channels)`` map.
        grads: Gradient of the class score w.r.t. *feat*; same shape.
        out_size: ``(height, width)`` the heatmap is resized to.

    Returns:
        A 2-D NumPy array of shape *out_size* with values in ``[0, 1]``.

    Raises:
        ValueError: If *feat* has an unsupported rank, or a token sequence
            cannot be arranged as a square grid.
    """
    if feat.dim() == 2:
        n_tokens = feat.shape[0]
        side = math.isqrt(n_tokens)
        if side * side != n_tokens:
            raise ValueError(
                f"cannot arrange {n_tokens} tokens as a square grid"
            )
        # Grad-CAM channel weights: gradient averaged over all positions.
        weights = grads.mean(dim=0)
        # NOTE(review): assumes tokens are laid out row-major over a square
        # grid, as the original reshape did — confirm for other backbones.
        cam = torch.matmul(feat, weights).reshape(side, side)
    elif feat.dim() == 3:
        weights = grads.mean(dim=(0, 1))
        cam = (feat * weights).sum(dim=-1)
    else:
        raise ValueError(
            f"unsupported activation shape {tuple(feat.shape)} for Grad-CAM"
        )

    cam = F.relu(cam)
    cam = cam - cam.min()
    cam = cam / (cam.max() + 1e-8)
    cam = F.interpolate(
        cam[None, None], size=out_size, mode="bilinear", align_corners=False
    )
    return cam[0, 0].cpu().numpy()


def _save_overlay(img_np, cam, pred_name, true_name, fname):
    """Save *cam* blended over *img_np* to *fname*, always closing the figure."""
    fig = plt.figure(figsize=(6, 6))
    try:
        plt.imshow(img_np)
        plt.imshow(cam, cmap="jet", alpha=0.5)
        plt.title(f"Pred: {pred_name} | True: {true_name}")
        plt.axis("off")
        plt.savefig(fname, bbox_inches="tight", dpi=150)
    finally:
        plt.close(fig)


def _generate_gradcams(model, processor, dataset, device, y_pred, selected, id2label):
    """Write one Grad-CAM overlay per index in *selected*.

    Hooks on the last block's ``layernorm_after`` capture activations and
    gradients. A failing sample is reported and skipped; the hooks are
    removed even if an unexpected error escapes the loop.
    """
    gradcam_data = {}

    def fwd_hook(module, module_input, output):
        gradcam_data["feat"] = output.detach()

    def bwd_hook(module, grad_input, grad_output):
        gradcam_data["grad"] = grad_output[0].detach()

    target_layer = model.swinv2.encoder.layers[-1].blocks[-1].layernorm_after
    fwd_handle = target_layer.register_forward_hook(fwd_hook)
    bwd_handle = target_layer.register_full_backward_hook(bwd_hook)
    try:
        for idx in selected:
            try:
                # Never reuse tensors captured for a previous sample.
                gradcam_data.clear()
                item = dataset[idx]
                img = item["image"].convert("RGB")
                label = item["label"]
                inputs = processor(img, return_tensors="pt")
                img_tensor = inputs["pixel_values"].to(device).requires_grad_(True)

                model.zero_grad()
                outputs = model(pixel_values=img_tensor)
                target_class = int(y_pred[idx])
                score = outputs.logits[0, target_class]
                score.backward()

                cam = _gradcam_heatmap(
                    gradcam_data["feat"][0],
                    gradcam_data["grad"][0],
                    # Match the model input so the overlay lines up with it.
                    tuple(img_tensor.shape[-2:]),
                )

                img_np = img_tensor.squeeze().detach().cpu().permute(1, 2, 0).numpy()
                img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min() + 1e-8)

                pred_name = id2label.get(target_class, str(target_class))
                true_name = id2label.get(label, str(label))
                fname = f"{OUTPUT_DIR}/gradcam_sample_{idx}_pred{pred_name}_true{true_name}.png"
                _save_overlay(img_np, cam, pred_name, true_name, fname)
                print(f" Saved {fname}")
            except Exception as e:
                # Best effort: one bad sample must not abort the whole run.
                print(f" Skipped sample {idx}: {e}")
                traceback.print_exc()
    finally:
        fwd_handle.remove()
        bwd_handle.remove()


def main():
    """Evaluate ``MODEL_NAME`` on a test split of ``DATASET_NAME``.

    Prints progress and metrics, writes ``test_metrics.json`` and the
    Grad-CAM PNG files into ``OUTPUT_DIR``.
    """
    print("=" * 60)
    print("Thyroid Ultrasound Evaluation + Grad-CAM")
    print("=" * 60)

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nDevice: {device}")

    print(f"Loading model: {MODEL_NAME}")
    processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
    model = AutoModelForImageClassification.from_pretrained(MODEL_NAME).to(device).eval()
    print(f"Model loaded: {sum(p.numel() for p in model.parameters())/1e6:.1f}M params")

    print(f"\nLoading dataset: {DATASET_NAME}")
    ds = load_dataset(DATASET_NAME, split="train")
    ds = ds.shuffle(seed=SEED)
    train_test = ds.train_test_split(test_size=0.2, stratify_by_column="label", seed=SEED)
    test_ds = train_test["test"]
    # Count from the label column only; iterating whole rows would decode
    # every image once per class.
    label_counts = Counter(test_ds["label"])
    print(f"Test samples: {len(test_ds)} (Benign: {label_counts[0]}, Malignant: {label_counts[1]})")

    id2label = model.config.id2label

    print("\nRunning inference...")
    y_true, y_logits = _run_inference(model, processor, test_ds, device)
    final, y_pred = _compute_metrics(y_true, y_logits)

    print(f"\n{'='*60}")
    print("FINAL TEST METRICS")
    print(f"{'='*60}")
    for k, v in final.items():
        print(f" {k}: {v}")

    with open(f"{OUTPUT_DIR}/test_metrics.json", "w") as f:
        json.dump(final, f, indent=2)
    print(f"\nSaved to {OUTPUT_DIR}/test_metrics.json")

    # Grad-CAM on a random handful of hits and misses.
    is_correct = y_true == y_pred
    correct_idx = np.flatnonzero(is_correct).tolist()
    incorrect_idx = np.flatnonzero(~is_correct).tolist()
    random.shuffle(correct_idx)
    random.shuffle(incorrect_idx)
    chosen_correct = correct_idx[:GRADCAM_SAMPLES_PER_GROUP]
    chosen_incorrect = incorrect_idx[:GRADCAM_SAMPLES_PER_GROUP]
    selected = chosen_correct + chosen_incorrect
    print(f"\nGenerating Grad-CAM for {len(selected)} samples ({len(chosen_correct)} correct, {len(chosen_incorrect)} incorrect)...")

    _generate_gradcams(model, processor, test_ds, device, y_pred, selected, id2label)

    print("\nEvaluation complete.")


if __name__ == "__main__":
    main()