# thyroid-training-scripts / evaluate_simple.py
# Uploaded by Johnyquest7 (commit f069771, verified)
"""
Thyroid Ultrasound Evaluation + Grad-CAM (Pure PyTorch, no Trainer)
"""
import os, sys, math, json, random, warnings, traceback
warnings.filterwarnings("ignore")
import numpy as np
from PIL import Image
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from datasets import load_dataset
from transformers import AutoImageProcessor, AutoModelForImageClassification
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score, confusion_matrix
# --- Configuration -----------------------------------------------------------
HF_USERNAME = "Johnyquest7"
# Public ultrasound dataset with binary labels (0 = benign, 1 = malignant,
# per the label checks used below — TODO confirm against the dataset card).
DATASET_NAME = "BTX24/thyroid-cancer-classification-ultrasound-dataset"
# Fine-tuned image-classification checkpoint on the Hugging Face Hub.
MODEL_NAME = f"{HF_USERNAME}/ML-Inter_thyroid"
OUTPUT_DIR = "./eval_outputs"
SEED = 42
BATCH_SIZE = 16
# Seed every RNG used below: `random` drives Grad-CAM sample selection,
# numpy/torch keep the split and any stochastic ops reproducible.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
def _run_inference(model, processor, test_ds, device):
    """Run a batched forward pass over *test_ds*.

    Returns:
        (y_true, y_logits): numpy arrays of gold labels and raw model logits,
        aligned row-for-row with the dataset order.
    """
    all_logits, all_labels = [], []
    print("\nRunning inference...")
    n_batches = (len(test_ds) + BATCH_SIZE - 1) // BATCH_SIZE
    for i in range(0, len(test_ds), BATCH_SIZE):
        batch_items = [test_ds[j] for j in range(i, min(i + BATCH_SIZE, len(test_ds)))]
        # PIL's convert("RGB") is a no-op on images already in RGB mode,
        # so no explicit mode check is needed.
        images = [item["image"].convert("RGB") for item in batch_items]
        inputs = processor(images, return_tensors="pt")
        pixel_values = inputs["pixel_values"].to(device)
        with torch.no_grad():
            outputs = model(pixel_values=pixel_values)
        all_logits.extend(outputs.logits.cpu().numpy())
        all_labels.extend([item["label"] for item in batch_items])
        if (i // BATCH_SIZE) % 5 == 0:
            print(f" Batch {i//BATCH_SIZE + 1}/{n_batches}")
    return np.array(all_labels), np.array(all_logits)


def _compute_metrics(y_true, y_logits):
    """Derive classification metrics from gold labels and raw logits.

    Returns:
        (final, y_pred): a JSON-serializable metrics dict and the argmax
        predictions (needed later for Grad-CAM sample selection).
    """
    y_pred = np.argmax(y_logits, axis=1)
    probs = F.softmax(torch.from_numpy(y_logits), dim=1).numpy()
    acc = accuracy_score(y_true, y_pred)
    prec, rec, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average="weighted", zero_division=0
    )
    try:
        auc = roc_auc_score(y_true, probs[:, 1])
    except Exception:
        # Fallback kept from the original script: if scoring against the
        # positive-class column fails, score against column 0 instead.
        auc = roc_auc_score(y_true, probs[:, 0])
    # Pin labels=[0, 1] so cm is always 2x2 even if one class is absent from
    # y_true/y_pred; otherwise the cm[1,1] indexing below raises IndexError.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    sens = cm[1, 1] / (cm[1, 1] + cm[1, 0]) if (cm[1, 1] + cm[1, 0]) > 0 else 0
    spec = cm[0, 0] / (cm[0, 0] + cm[0, 1]) if (cm[0, 0] + cm[0, 1]) > 0 else 0
    final = {
        "test_accuracy": float(acc),
        "test_weighted_f1": float(f1),
        "test_weighted_precision": float(prec),
        "test_weighted_recall": float(rec),
        "test_roc_auc": float(auc),
        "test_sensitivity": float(sens),
        "test_specificity": float(spec),
        "test_confusion_matrix": cm.tolist(),
    }
    return final, y_pred


def _generate_gradcams(model, processor, test_ds, y_true, y_pred, id2label, device):
    """Save Grad-CAM overlays for up to 5 correct and 5 incorrect predictions.

    Hooks the post-attention LayerNorm of the final Swin-V2 block to capture
    activations (forward) and gradients (backward), then writes jet-colormap
    overlays into OUTPUT_DIR. Per-sample failures are logged and skipped.
    """
    correct_idx = [i for i in range(len(y_true)) if y_true[i] == y_pred[i]]
    incorrect_idx = [i for i in range(len(y_true)) if y_true[i] != y_pred[i]]
    random.shuffle(correct_idx)
    random.shuffle(incorrect_idx)
    selected = correct_idx[:5] + incorrect_idx[:5]
    print(f"\nGenerating Grad-CAM for {len(selected)} samples ({len(correct_idx[:5])} correct, {len(incorrect_idx[:5])} incorrect)...")
    gradcam_data = {}

    def fwd_hook(module, input, output):
        gradcam_data["feat"] = output.detach()

    def bwd_hook(module, grad_input, grad_output):
        gradcam_data["grad"] = grad_output[0].detach()

    target_layer = model.swinv2.encoder.layers[-1].blocks[-1].layernorm_after
    fwd_handle = target_layer.register_forward_hook(fwd_hook)
    bwd_handle = target_layer.register_full_backward_hook(bwd_hook)
    try:
        for idx in selected:
            try:
                item = test_ds[idx]
                img = item["image"].convert("RGB")
                label = item["label"]
                inputs = processor(img, return_tensors="pt")
                img_tensor = inputs["pixel_values"].to(device).requires_grad_(True)
                model.zero_grad()
                outputs = model(pixel_values=img_tensor)
                # Backprop the score of the *predicted* class so the CAM
                # explains what the model actually decided.
                target_class = int(y_pred[idx])
                score = outputs.logits[0, target_class]
                score.backward()
                feat = gradcam_data["feat"][0]
                grads = gradcam_data["grad"][0]
                if feat.dim() == 3:
                    weights = grads.mean(dim=0, keepdim=True)
                    cam = torch.matmul(feat, weights.t()).squeeze()
                    # Token sequence assumed to be a square patch grid —
                    # TODO confirm for non-square inputs.
                    H = W = int(math.sqrt(cam.shape[0]))
                    cam = cam.reshape(H, W)
                else:
                    weights = grads.mean(dim=(0, 1), keepdim=True)
                    cam = (feat * weights).sum(dim=-1).squeeze()
                # ReLU then min-max normalize into [0, 1] for display.
                cam = F.relu(cam)
                cam = cam - cam.min()
                cam = cam / (cam.max() + 1e-8)
                cam = F.interpolate(cam.unsqueeze(0).unsqueeze(0), size=(256, 256), mode="bilinear", align_corners=False)
                cam = cam.squeeze().cpu().numpy()
                img_np = img_tensor.squeeze().detach().cpu().permute(1, 2, 0).numpy()
                img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min() + 1e-8)
                plt.figure(figsize=(6, 6))
                plt.imshow(img_np)
                plt.imshow(cam, cmap="jet", alpha=0.5)
                pred_name = id2label.get(target_class, str(target_class))
                true_name = id2label.get(label, str(label))
                plt.title(f"Pred: {pred_name} | True: {true_name}")
                plt.axis("off")
                fname = f"{OUTPUT_DIR}/gradcam_sample_{idx}_pred{pred_name}_true{true_name}.png"
                plt.savefig(fname, bbox_inches="tight", dpi=150)
                plt.close()
                print(f" Saved {fname}")
            except Exception as e:
                print(f" Skipped sample {idx}: {e}")
                traceback.print_exc()
    finally:
        # Always detach the hooks, even if the loop above raised — leaked
        # hooks would silently corrupt any later forward/backward passes.
        fwd_handle.remove()
        bwd_handle.remove()


def main():
    """End-to-end evaluation of the thyroid-ultrasound classifier.

    Loads the fine-tuned model and dataset from the Hugging Face Hub, scores
    a stratified 20% held-out split, writes metrics to
    OUTPUT_DIR/test_metrics.json, and renders Grad-CAM overlays for a sample
    of correct and incorrect predictions.
    """
    print("=" * 60)
    print("Thyroid Ultrasound Evaluation + Grad-CAM")
    print("=" * 60)
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nDevice: {device}")
    print(f"Loading model: {MODEL_NAME}")
    processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
    model = AutoModelForImageClassification.from_pretrained(MODEL_NAME).to(device).eval()
    print(f"Model loaded: {sum(p.numel() for p in model.parameters())/1e6:.1f}M params")
    print(f"\nLoading dataset: {DATASET_NAME}")
    ds = load_dataset(DATASET_NAME, split="train")
    ds = ds.shuffle(seed=SEED)
    train_test = ds.train_test_split(test_size=0.2, stratify_by_column="label", seed=SEED)
    test_ds = train_test["test"]
    # Count labels via the column accessor — iterating rows would decode
    # every image twice just to read the label.
    labels_col = test_ds["label"]
    print(f"Test samples: {len(test_ds)} (Benign: {labels_col.count(0)}, Malignant: {labels_col.count(1)})")
    id2label = model.config.id2label

    y_true, y_logits = _run_inference(model, processor, test_ds, device)
    final, y_pred = _compute_metrics(y_true, y_logits)

    print(f"\n{'='*60}")
    print("FINAL TEST METRICS")
    print(f"{'='*60}")
    for k, v in final.items():
        print(f" {k}: {v}")
    with open(f"{OUTPUT_DIR}/test_metrics.json", "w") as f:
        json.dump(final, f, indent=2)
    print(f"\nSaved to {OUTPUT_DIR}/test_metrics.json")

    _generate_gradcams(model, processor, test_ds, y_true, y_pred, id2label, device)
    print("\nEvaluation complete.")
# Script entry point: only run the evaluation when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()