#!/usr/bin/env python3
"""
Thyroid Ultrasound Nodule Malignancy Classification

Dataset: BTX24/thyroid-cancer-classification-ultrasound-dataset
Binary classification: benign (0) vs malignant (1)

Fine-tunes a SwinV2 image classifier, keeps the checkpoint with the best
validation ROC AUC (with early stopping), reports metrics on the test split
and pushes the result to the Hugging Face Hub.
"""

import functools
import os
import sys
from collections import Counter

import evaluate
import numpy as np
import torch
from datasets import load_dataset
from torchvision.transforms import (
    CenterCrop,
    ColorJitter,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomRotation,
    RandomVerticalFlip,
    Resize,
    ToTensor,
)
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    DefaultDataCollator,
    EarlyStoppingCallback,
    Trainer,
    TrainingArguments,
)

# ------------------------------------------------------------------
# Config
# ------------------------------------------------------------------
DATASET_NAME = "BTX24/thyroid-cancer-classification-ultrasound-dataset"
MODEL_NAME = "microsoft/swinv2-base-patch4-window8-256"
OUTPUT_DIR = "./thyroid-swinv2-model"
HUB_MODEL_ID = "Johnyquest7/ML-Inter_thyroid"

NUM_LABELS = 2
# NOTE(review): assumes the dataset encodes label 1 as malignant — confirm
# against the dataset's ClassLabel names.
ID2LABEL = {0: "benign", 1: "malignant"}
LABEL2ID = {"benign": 0, "malignant": 1}

SEED = 42
VAL_FRACTION = 0.15  # share of the official train split held out for validation


# ------------------------------------------------------------------
# Metrics
# ------------------------------------------------------------------
@functools.lru_cache(maxsize=None)
def _load_metric(name):
    """Load an ``evaluate`` metric on first use and reuse it afterwards.

    Loading lazily keeps module import free of network/cache side effects.
    """
    return evaluate.load(name)


def compute_metrics(eval_pred):
    """Compute binary classification metrics for ``Trainer`` evaluation.

    Args:
        eval_pred: ``(logits, labels)`` pair supplied by ``Trainer``. Logits are
            expected as ``(num_samples, NUM_LABELS)``; column 1 is treated as
            the positive ("malignant") class per ``ID2LABEL``.

    Returns:
        dict with ``accuracy``, ``f1``, ``precision``, ``recall`` (binary,
        positive class = 1) and ``roc_auc``. ``roc_auc`` is reported as 0.0
        when it is undefined so that ``metric_for_best_model`` always exists.
    """
    logits, labels = eval_pred
    # Some models return a tuple of outputs; the logits come first.
    if isinstance(logits, tuple):
        logits = logits[0]
    preds = np.argmax(logits, axis=1)
    # Probability of the positive class, used as the ROC AUC score.
    probs = torch.softmax(torch.tensor(logits), dim=1)[:, 1].numpy()

    result = {}
    result.update(_load_metric("accuracy").compute(predictions=preds, references=labels))
    for name in ("f1", "precision", "recall"):
        result.update(
            _load_metric(name).compute(predictions=preds, references=labels, average="binary")
        )
    try:
        result.update(
            _load_metric("roc_auc").compute(prediction_scores=probs, references=labels)
        )
    except ValueError as err:
        # sklearn's roc_auc_score (used by the roc_auc metric) raises ValueError
        # when only one class is present in the references; anything else is a
        # real bug and is allowed to propagate.
        print(f"WARNING: ROC AUC undefined ({err}); reporting 0.0", file=sys.stderr)
        result["roc_auc"] = 0.0
    return result


# ------------------------------------------------------------------
# Dataset
# ------------------------------------------------------------------
def load_splits():
    """Return ``(train, val, test)`` datasets.

    Validation is a stratified ``VAL_FRACTION`` slice of the official train
    split; the official test split is left untouched.
    """
    train_ds = load_dataset(DATASET_NAME, split="train")
    test_ds = load_dataset(DATASET_NAME, split="test")
    # stratify_by_column requires "label" to be a ClassLabel feature —
    # presumably true for this dataset (train_test_split raises otherwise).
    train_val = train_ds.train_test_split(
        test_size=VAL_FRACTION, stratify_by_column="label", seed=SEED
    )
    return train_val["train"], train_val["test"], test_ds


# ------------------------------------------------------------------
# Image transforms
# ------------------------------------------------------------------
def _resolve_size(image_processor):
    """Return the resize target from the processor's ``size`` config.

    An int (``shortest_edge``) or a ``(height, width)`` tuple.
    """
    size_cfg = image_processor.size
    if "shortest_edge" in size_cfg:
        return size_cfg["shortest_edge"]
    return (size_cfg["height"], size_cfg["width"])


def build_transforms(image_processor):
    """Build ``(train_transforms, val_transforms)`` matching the processor.

    ``Resize`` with an int only rescales the shorter edge, so a ``CenterCrop``
    to the same size follows it to guarantee a fixed square shape, which
    ``DefaultDataCollator`` needs to stack a batch. With an explicit
    ``(height, width)`` target the crop is a no-op.
    """
    size = _resolve_size(image_processor)
    normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
    train_transforms = Compose([
        Resize(size),
        CenterCrop(size),
        RandomRotation(degrees=10),
        RandomHorizontalFlip(p=0.5),
        RandomVerticalFlip(p=0.3),
        ColorJitter(brightness=0.2, contrast=0.2),
        ToTensor(),
        normalize,
    ])
    val_transforms = Compose([
        Resize(size),
        CenterCrop(size),
        ToTensor(),
        normalize,
    ])
    return train_transforms, val_transforms


def _apply_transforms(examples, transforms):
    """Batched ``with_transform`` hook: replace ``image`` with ``pixel_values``.

    Images are converted to 3-channel RGB because the pretrained backbone is
    normalized with RGB statistics; the ultrasound images may be grayscale.
    """
    examples["pixel_values"] = [
        transforms(img.convert("RGB")) for img in examples["image"]
    ]
    del examples["image"]
    return examples


# ------------------------------------------------------------------
# Entry point
# ------------------------------------------------------------------
def main():
    """Load data and model, fine-tune, evaluate on test, and push to the Hub."""
    print("Loading dataset...")
    train_ds, val_ds, test_ds = load_splits()
    print(f"Train: {len(train_ds)} | Val: {len(val_ds)} | Test: {len(test_ds)}")
    print(f"Train labels: {Counter(train_ds['label'])}")
    print(f"Val labels: {Counter(val_ds['label'])}")
    print(f"Test labels: {Counter(test_ds['label'])}")

    print(f"Loading image processor from {MODEL_NAME}...")
    image_processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
    train_transforms, val_transforms = build_transforms(image_processor)

    print("Applying transforms...")
    # functools.partial over a module-level function (not a closure or a
    # function reading globals set in main) keeps the transform picklable for
    # DataLoader worker processes under the "spawn" start method.
    train_fn = functools.partial(_apply_transforms, transforms=train_transforms)
    val_fn = functools.partial(_apply_transforms, transforms=val_transforms)
    train_ds = train_ds.with_transform(train_fn)
    val_ds = val_ds.with_transform(val_fn)
    test_ds = test_ds.with_transform(val_fn)

    print(f"Loading model {MODEL_NAME}...")
    model = AutoModelForImageClassification.from_pretrained(
        MODEL_NAME,
        num_labels=NUM_LABELS,
        id2label=ID2LABEL,
        label2id=LABEL2ID,
        # The checkpoint's classification head has a different label count.
        ignore_mismatched_sizes=True,
    )

    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        # Keep the raw "image" column so the with_transform hook can use it.
        remove_unused_columns=False,
        eval_strategy="epoch",
        save_strategy="epoch",
        learning_rate=2e-5,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        gradient_accumulation_steps=2,
        num_train_epochs=30,
        warmup_steps=100,
        weight_decay=0.01,
        logging_strategy="steps",
        logging_steps=10,
        logging_first_step=True,
        disable_tqdm=True,
        load_best_model_at_end=True,
        metric_for_best_model="eval_roc_auc",
        greater_is_better=True,
        push_to_hub=True,
        hub_model_id=HUB_MODEL_ID,
        report_to="trackio",
        run_name="thyroid-swinv2-binary",
        project="thyroid-malignancy",
        seed=SEED,
        bf16=True,
        dataloader_num_workers=4,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=DefaultDataCollator(),
        train_dataset=train_ds,
        eval_dataset=val_ds,
        processing_class=image_processor,
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],
    )

    print("Starting training...")
    trainer.train()

    print("Evaluating on test set...")
    test_results = trainer.evaluate(test_ds, metric_key_prefix="test")
    print("Test results:", test_results)

    print("Pushing to Hub...")
    trainer.push_to_hub()
    print("Done!")


if __name__ == "__main__":
    main()