#!/usr/bin/env python3
# Source: thyroid-training-scripts / train_thyroid.py — uploaded by Johnyquest7 ("Upload train_thyroid.py"), commit b66d95f
"""
Thyroid Ultrasound Nodule Malignancy Classification
Dataset: BTX24/thyroid-cancer-classification-ultrasound-dataset
Binary classification: benign (0) vs malignant (1)
"""
import os
import sys
from collections import Counter

import evaluate
import numpy as np
import torch
from datasets import load_dataset
from torchvision.transforms import (
    CenterCrop,
    ColorJitter,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomRotation,
    RandomVerticalFlip,
    Resize,
    ToTensor,
)
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    DefaultDataCollator,
    EarlyStoppingCallback,
    Trainer,
    TrainingArguments,
)
# ------------------------------------------------------------------
# Config
# ------------------------------------------------------------------
DATASET_NAME = "BTX24/thyroid-cancer-classification-ultrasound-dataset"
MODEL_NAME = "microsoft/swinv2-base-patch4-window8-256"  # pretrained backbone to fine-tune
OUTPUT_DIR = "./thyroid-swinv2-model"  # local checkpoint / final-model directory
HUB_MODEL_ID = "Johnyquest7/ML-Inter_thyroid"  # Hub repo the trained model is pushed to

# Class-index -> name mapping. It is the single source of truth: the reverse
# mapping and the class count are derived from it so they can never drift apart.
ID2LABEL = {0: "benign", 1: "malignant"}
LABEL2ID = {label: idx for idx, label in ID2LABEL.items()}
NUM_LABELS = len(ID2LABEL)
# ------------------------------------------------------------------
# Metrics
# ------------------------------------------------------------------
# Each metric is loaded once, at import time, with ``evaluate.load`` and then
# reused by compute_metrics on every evaluation pass. The names are unpacked
# in the same order they are loaded.
accuracy, f1, precision, recall, roc_auc = (
    evaluate.load(metric_name)
    for metric_name in ("accuracy", "f1", "precision", "recall", "roc_auc")
)
def compute_metrics(eval_pred):
    """Compute binary-classification metrics for ``Trainer`` evaluation.

    Args:
        eval_pred: ``(logits, labels)`` pair handed over by ``Trainer``.
            ``logits`` is expected to be (n_samples, 2) — TODO confirm for
            other heads; a tuple of outputs is tolerated and its first
            element is used as the logits.

    Returns:
        dict with ``accuracy``, ``f1``, ``precision``, ``recall`` (positive
        class = index 1) and ``roc_auc`` computed from the class-1 softmax
        probability. ``roc_auc`` falls back to 0.0 (with a printed warning)
        when it cannot be computed.
    """
    logits, labels = eval_pred
    # Some models return a tuple of outputs; the logits always come first.
    if isinstance(logits, tuple):
        logits = logits[0]

    preds = np.argmax(logits, axis=-1)
    # ROC-AUC needs a continuous score, not a hard label: use P(class 1).
    probs = torch.softmax(torch.as_tensor(logits), dim=-1)[:, 1].numpy()

    result = {}
    result.update(accuracy.compute(predictions=preds, references=labels))
    for metric in (f1, precision, recall):
        result.update(
            metric.compute(predictions=preds, references=labels, average="binary")
        )

    try:
        result.update(roc_auc.compute(prediction_scores=probs, references=labels))
    except Exception as err:
        # Best-effort on purpose: a metric failure should not abort a long
        # training run (e.g. ROC-AUC is undefined when ``labels`` holds a
        # single class). 0.0 rather than NaN keeps the Trainer's best-model
        # comparison on eval_roc_auc working, but the failure is reported
        # instead of being silently swallowed.
        print(f"WARNING: roc_auc could not be computed ({err!r}); reporting 0.0")
        result["roc_auc"] = 0.0
    return result
# ------------------------------------------------------------------
# Load dataset
# ------------------------------------------------------------------
print("Loading dataset...")
train_ds = load_dataset(DATASET_NAME, split="train")
test_ds = load_dataset(DATASET_NAME, split="test")

# Carve a validation set out of the train split: 15%, stratified on "label"
# so both subsets keep the class proportions, seeded for reproducibility.
# The test split is left untouched for the final evaluation.
train_val = train_ds.train_test_split(
    test_size=0.15,
    stratify_by_column="label",
    seed=42,
)
train_ds, val_ds = train_val["train"], train_val["test"]

print(f"Train: {len(train_ds)} | Val: {len(val_ds)} | Test: {len(test_ds)}")
for _split_name, _split_ds in (("Train", train_ds), ("Val", val_ds), ("Test", test_ds)):
    print(f"{_split_name} labels: {Counter(_split_ds['label'])}")
del _split_name, _split_ds
# ------------------------------------------------------------------
# Image processor & transforms
# ------------------------------------------------------------------
print(f"Loading image processor from {MODEL_NAME}...")
image_processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
# Ultrasound images are grayscale (mode 'L'); preprocess_train/preprocess_val
# convert them to RGB for SwinV2 before these transforms run.
# Normalise with the statistics the pretrained backbone's processor declares.
image_mean = image_processor.image_mean
image_std = image_processor.image_std

# A processor describes its input size either as an explicit height/width or
# as a single "shortest_edge" length.
size = (
    image_processor.size["shortest_edge"]
    if "shortest_edge" in image_processor.size
    else (image_processor.size["height"], image_processor.size["width"])
)
# Resize(int) only fixes the shorter side and keeps the aspect ratio, so
# differently-shaped images would come out at different sizes and
# DefaultDataCollator could not stack them into a batch. Center-cropping to a
# fixed (h, w) guarantees a uniform tensor shape; when ``size`` is already an
# (h, w) tuple, Resize produces exactly that shape and the crop is a no-op.
crop_size = size if isinstance(size, tuple) else (size, size)

train_transforms = Compose([
    Resize(size),
    CenterCrop(crop_size),
    # Light augmentation; these keep the image size unchanged.
    RandomRotation(degrees=10),
    RandomHorizontalFlip(p=0.5),
    RandomVerticalFlip(p=0.3),
    ColorJitter(brightness=0.2, contrast=0.2),
    ToTensor(),
    Normalize(mean=image_mean, std=image_std),
])
# Deterministic pipeline for validation / test: no augmentation.
val_transforms = Compose([
    Resize(size),
    CenterCrop(crop_size),
    ToTensor(),
    Normalize(mean=image_mean, std=image_std),
])
def preprocess_train(examples):
    """Batch transform for training: augment each image into ``pixel_values``.

    Every image in the batch is first converted to 3-channel RGB, then passed
    through ``train_transforms``. The raw ``image`` column is removed and the
    same (mutated) batch dict is returned.
    """
    rgb_images = (picture.convert("RGB") for picture in examples["image"])
    examples["pixel_values"] = list(map(train_transforms, rgb_images))
    examples.pop("image")
    return examples
def preprocess_val(examples):
    """Deterministic batch transform for validation/test (no augmentation).

    Converts each image in the batch to RGB, applies ``val_transforms`` and
    stores the tensors under ``pixel_values``; the ``image`` column is
    dropped and the batch dict is returned.
    """
    pixel_values = []
    for picture in examples["image"]:
        pixel_values.append(val_transforms(picture.convert("RGB")))
    examples["pixel_values"] = pixel_values
    examples.pop("image")
    return examples
print("Applying transforms...")
# with_transform attaches the function to be applied lazily when rows are
# accessed, so the random augmentations are drawn at fetch time rather than
# once up front. Validation and test share the deterministic pipeline.
train_ds, val_ds, test_ds = (
    train_ds.with_transform(preprocess_train),
    val_ds.with_transform(preprocess_val),
    test_ds.with_transform(preprocess_val),
)
# ------------------------------------------------------------------
# Model
# ------------------------------------------------------------------
print(f"Loading model {MODEL_NAME}...")
# Give the classifier head the project's label mapping so the saved config
# carries "benign"/"malignant" names rather than bare indices.
_label_kwargs = {
    "num_labels": NUM_LABELS,
    "id2label": ID2LABEL,
    "label2id": LABEL2ID,
}
# The new head presumably differs in size from the checkpoint's original
# classifier; ignore_mismatched_sizes lets from_pretrained re-initialise the
# mismatched weights instead of raising.
model = AutoModelForImageClassification.from_pretrained(
    MODEL_NAME,
    ignore_mismatched_sizes=True,
    **_label_kwargs,
)
# ------------------------------------------------------------------
# Training arguments
# ------------------------------------------------------------------
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    # Keep the raw "image" column: the with_transform functions read it to
    # build pixel_values on the fly.
    remove_unused_columns=False,
    # Evaluate and checkpoint once per epoch; the two strategies must match
    # for load_best_model_at_end.
    eval_strategy="epoch",
    save_strategy="epoch",
    # Cap on-disk checkpoints (each holds model + optimizer state) instead of
    # keeping one per epoch; with load_best_model_at_end the best checkpoint
    # is retained in addition to the most recent ones.
    save_total_limit=2,
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=2,  # effective train batch: 16 * 2 = 32 per device
    num_train_epochs=30,
    warmup_steps=100,
    weight_decay=0.01,
    logging_strategy="steps",
    logging_steps=10,
    logging_first_step=True,
    disable_tqdm=True,  # no progress bars; rely on the step logs
    # Restore the checkpoint with the highest validation ROC-AUC at the end.
    load_best_model_at_end=True,
    metric_for_best_model="eval_roc_auc",
    greater_is_better=True,
    push_to_hub=True,
    hub_model_id=HUB_MODEL_ID,
    report_to="trackio",
    run_name="thyroid-swinv2-binary",
    project="thyroid-malignancy",
    seed=42,
    # bf16 mixed precision requires hardware support and TrainingArguments
    # raises on setups without it; enable it only when the CUDA device
    # supports bf16, otherwise train in full precision.
    bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
    dataloader_num_workers=4,
)
# ------------------------------------------------------------------
# Trainer
# ------------------------------------------------------------------
# DefaultDataCollator stacks the per-example pixel_values tensors into a batch.
data_collator = DefaultDataCollator()
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    processing_class=image_processor,
    compute_metrics=compute_metrics,
    # Stop once metric_for_best_model (eval_roc_auc) has not improved for
    # 5 consecutive evaluations.
    callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],
)

print("Starting training...")
trainer.train()

print("Evaluating on test set...")
# With load_best_model_at_end, this scores the best checkpoint, not the last epoch.
test_results = trainer.evaluate(test_ds, metric_key_prefix="test")
print("Test results:", test_results)
# Persist the held-out metrics to OUTPUT_DIR (test_results.json) so they are
# uploaded with the model by push_to_hub instead of living only in the log.
trainer.save_metrics("test", test_results)

print("Pushing to Hub...")
trainer.push_to_hub()
print("Done!")