| """ |
| Thyroid Ultrasound Evaluation + Grad-CAM Visualization |
| Evaluates model on test set and generates attention visualizations. |
| """ |
| import os, sys, io, math, json, random, warnings, base64, traceback |
| warnings.filterwarnings("ignore") |
|
|
| import numpy as np |
| from PIL import Image |
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
|
|
| import torch |
| import torch.nn.functional as F |
| from datasets import load_dataset |
| from transformers import ( |
| AutoImageProcessor, AutoModelForImageClassification, |
| Trainer, TrainingArguments, DefaultDataCollator |
| ) |
| from sklearn.metrics import ( |
| accuracy_score, precision_recall_fscore_support, roc_auc_score, confusion_matrix |
| ) |
|
|
# Disable Trackio experiment tracking (empty IDs -> logging becomes a no-op).
os.environ["TRACKIO_SPACE_ID"] = ""
os.environ["TRACKIO_PROJECT"] = ""


# --- Configuration ---
HF_USERNAME = "Johnyquest7"  # Hugging Face account hosting the fine-tuned checkpoint
DATASET_NAME = "BTX24/thyroid-cancer-classification-ultrasound-dataset"
MODEL_NAME = f"{HF_USERNAME}/ML-Inter_thyroid"  # image-classification model to evaluate
OUTPUT_DIR = "./eval_outputs"  # metrics JSON and Grad-CAM PNGs are written here
SEED = 42  # shared seed for the shuffle, the train/test split, and sample selection
MAX_SAMPLES_GRADCAM = 20  # hard cap on the number of Grad-CAM images generated


# Seed all RNGs up front so the split and the Grad-CAM sample picks are reproducible.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
|
|
def main():
    """Evaluate the thyroid ultrasound classifier and write Grad-CAM overlays.

    Pipeline:
      1. Load the fine-tuned model + processor from the Hub (exit(1) on failure).
      2. Rebuild the 80/20 stratified test split from the dataset's train split.
      3. Run a Trainer evaluation (for eval_loss) and a manual forward pass
         to collect logits for sklearn metrics.
      4. Save accuracy / weighted P-R-F1 / ROC-AUC / confusion matrix to
         ``OUTPUT_DIR/test_metrics.json``.
      5. Generate Grad-CAM heatmap overlays for up to 5 correct and 5
         incorrect predictions (capped at MAX_SAMPLES_GRADCAM).

    No arguments; all settings come from module-level constants.
    """
    print("=" * 60)
    print("Thyroid Ultrasound Model Evaluation + Grad-CAM")
    print("=" * 60)
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nDevice: {device}")
    print(f"Loading model: {MODEL_NAME}")

    try:
        processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
        model = AutoModelForImageClassification.from_pretrained(MODEL_NAME).to(device).eval()
    except Exception as e:
        print(f"Model loading failed: {e}")
        sys.exit(1)

    print(f"\nLoading dataset: {DATASET_NAME}")
    ds = load_dataset(DATASET_NAME, split="train")
    ds = ds.shuffle(seed=SEED)
    # Recreate the same stratified 80/20 split (seeded) so the held-out set
    # matches what the model was NOT trained on.
    train_test = ds.train_test_split(test_size=0.2, stratify_by_column="label", seed=SEED)
    test_ds = train_test["test"]
    print(f"Test samples: {len(test_ds)}")

    id2label = model.config.id2label

    def transform(examples):
        """On-the-fly preprocessing for set_transform.

        set_transform REPLACES the example columns with this function's
        output, so the label column must be passed through explicitly —
        otherwise ``batch["label"]`` fails downstream and the Trainer
        cannot compute eval_loss.
        """
        images = [img.convert("RGB") if img.mode != "RGB" else img for img in examples["image"]]
        enc = processor(images, return_tensors="pt")
        enc["label"] = examples["label"]  # DefaultDataCollator renames this to "labels"
        return enc

    test_ds.set_transform(transform)

    # --- Trainer evaluation (provides eval_loss) ---
    print("\nRunning evaluation...")
    args = TrainingArguments(
        output_dir="/tmp/eval", per_device_eval_batch_size=16,
        remove_unused_columns=False, disable_tqdm=True,
        logging_strategy="steps", logging_first_step=True,
        report_to=[]
    )
    trainer = Trainer(model=model, args=args, data_collator=DefaultDataCollator(),
                      eval_dataset=test_ds)
    metrics = trainer.evaluate()
    print(f"\nRaw metrics: {metrics}")

    # --- Manual forward pass to collect logits + labels for sklearn metrics ---
    all_logits, all_labels = [], []
    batch_size = 16
    for i in range(0, len(test_ds), batch_size):
        batch = test_ds[i:i + batch_size]
        pixel_values = batch["pixel_values"]
        # Slicing a transformed dataset normally yields an already-batched
        # tensor; stack defensively if a list of per-image tensors comes back.
        if not isinstance(pixel_values, torch.Tensor):
            pixel_values = torch.stack(list(pixel_values))
        with torch.no_grad():
            outputs = model(pixel_values=pixel_values.to(device))
        all_logits.extend(outputs.logits.cpu().numpy())
        all_labels.extend(batch["label"])

    y_true = np.array(all_labels)
    y_logits = np.array(all_logits)
    y_pred = np.argmax(y_logits, axis=1)
    probs = F.softmax(torch.from_numpy(y_logits), dim=1).numpy()

    acc = accuracy_score(y_true, y_pred)
    prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="weighted")
    # ROC-AUC: binary uses the positive-class probability column; multi-class
    # uses one-vs-rest over the full probability matrix.
    try:
        if probs.shape[1] == 2:
            auc = roc_auc_score(y_true, probs[:, 1])
        else:
            auc = roc_auc_score(y_true, probs, multi_class="ovr")
    except ValueError as e:
        # e.g. only one class present in y_true
        print(f"ROC-AUC could not be computed: {e}")
        auc = float("nan")
    cm = confusion_matrix(y_true, y_pred)

    final = {
        "test_accuracy": float(acc),
        "test_weighted_f1": float(f1),
        "test_weighted_precision": float(prec),
        "test_weighted_recall": float(rec),
        "test_roc_auc": float(auc),
        "test_confusion_matrix": cm.tolist(),
        "eval_loss": float(metrics.get("eval_loss", 0)),
    }
    print(f"\n{'='*60}")
    print("FINAL TEST METRICS")
    print(f"{'='*60}")
    for k, v in final.items():
        print(f"  {k}: {v}")
    # Use a context manager so the file handle is closed deterministically.
    with open(f"{OUTPUT_DIR}/test_metrics.json", "w") as fh:
        json.dump(final, fh, indent=2)
    print(f"\nSaved to {OUTPUT_DIR}/test_metrics.json")

    # --- Pick up to 5 correct and 5 incorrect predictions for Grad-CAM ---
    correct_idx = [i for i in range(len(y_true)) if y_true[i] == y_pred[i]]
    incorrect_idx = [i for i in range(len(y_true)) if y_true[i] != y_pred[i]]
    random.shuffle(correct_idx)
    random.shuffle(incorrect_idx)
    selected = correct_idx[:5] + incorrect_idx[:5]
    print(f"\nGenerating Grad-CAM for {len(selected)} samples ({len(correct_idx[:5])} correct, {len(incorrect_idx[:5])} incorrect)...")

    # --- Grad-CAM hooks: capture activations and gradients of the target layer ---
    gradcam_data = {}

    def fwd_hook(module, input, output):
        gradcam_data["feat"] = output.detach()

    def bwd_hook(module, grad_input, grad_output):
        gradcam_data["grad"] = grad_output[0].detach()

    # NOTE(review): hardcoded for a Swin V2 backbone (last block's post-norm);
    # this must be adjusted if the checkpoint architecture changes.
    target_layer = model.swinv2.encoder.layers[-1].blocks[-1].layernorm_after
    fwd_handle = target_layer.register_forward_hook(fwd_hook)
    bwd_handle = target_layer.register_full_backward_hook(bwd_hook)

    try:
        for idx in selected[:MAX_SAMPLES_GRADCAM]:
            try:
                sample = test_ds[idx]
                label = sample["label"]
                pixels = sample["pixel_values"]
                if pixels.dim() == 3:  # add batch dim only if it is absent
                    pixels = pixels.unsqueeze(0)
                img_tensor = pixels.to(device).requires_grad_(True)
                model.zero_grad()
                outputs = model(pixel_values=img_tensor)
                # Attribute toward the PREDICTED class, matching the title below.
                target_class = int(y_pred[idx])
                score = outputs.logits[0, target_class]
                score.backward()

                feat = gradcam_data["feat"][0]
                grads = gradcam_data["grad"][0]
                if feat.dim() == 3:
                    # (H, W, C)-style feature map: channel-average the grads.
                    weights = grads.mean(dim=0, keepdim=True)
                    cam = torch.matmul(feat, weights.t()).squeeze()
                    H = W = int(math.sqrt(cam.shape[0]))
                    cam = cam.reshape(H, W)
                else:
                    # Token sequence (N, C): weight tokens by mean gradient.
                    weights = grads.mean(dim=(0, 1), keepdim=True)
                    cam = (feat * weights).sum(dim=-1).squeeze()

                # Standard Grad-CAM post-processing: ReLU, min-max normalize,
                # upsample to a fixed display resolution.
                cam = F.relu(cam)
                cam = cam - cam.min()
                cam = cam / (cam.max() + 1e-8)
                cam = F.interpolate(cam.unsqueeze(0).unsqueeze(0), size=(256, 256),
                                    mode="bilinear", align_corners=False)
                cam = cam.squeeze().cpu().numpy()

                # Undo normalization only for display: rescale image to [0, 1].
                img_np = img_tensor.squeeze().detach().cpu().permute(1, 2, 0).numpy()
                img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min() + 1e-8)
                plt.figure(figsize=(6, 6))
                plt.imshow(img_np)
                plt.imshow(cam, cmap="jet", alpha=0.5)
                plt.title(f"Pred: {id2label[target_class]} | True: {id2label[label]}")
                plt.axis("off")
                fname = f"{OUTPUT_DIR}/gradcam_sample_{idx}_pred{id2label[target_class]}_true{id2label[label]}.png"
                plt.savefig(fname, bbox_inches="tight", dpi=150)
                plt.close()
                print(f"  Saved {fname}")
            except Exception as e:
                # Best-effort per sample: report and continue with the rest.
                print(f"  Skipped sample {idx}: {e}")
                traceback.print_exc()
    finally:
        # Always detach the hooks, even if the loop raises, so the model
        # is left clean for any further use.
        fwd_handle.remove()
        bwd_handle.remove()

    print("\nEvaluation complete.")
    print(f"Results saved to {OUTPUT_DIR}/")
|
|
# Script entry point: run the full evaluation + Grad-CAM pipeline.
if __name__ == "__main__":
    main()
|
|