| |
| """ |
| Thyroid Ultrasound Nodule Malignancy Classification |
| Dataset: BTX24/thyroid-cancer-classification-ultrasound-dataset |
| Binary classification: benign (0) vs malignant (1) |
| """ |
|
|
| import os |
| import sys |
| import numpy as np |
| from collections import Counter |
|
|
| from datasets import load_dataset |
| from transformers import ( |
| AutoImageProcessor, |
| AutoModelForImageClassification, |
| TrainingArguments, |
| Trainer, |
| DefaultDataCollator, |
| EarlyStoppingCallback, |
| ) |
| import evaluate |
| import torch |
| from torchvision.transforms import ( |
| Compose, Resize, RandomRotation, RandomHorizontalFlip, |
| RandomVerticalFlip, ColorJitter, ToTensor, Normalize |
| ) |
|
|
| |
| |
| |
# Hub identifiers for the dataset, the pretrained checkpoint and the upload
# target, plus the local directory used for training output.
DATASET_NAME = "BTX24/thyroid-cancer-classification-ultrasound-dataset"
MODEL_NAME = "microsoft/swinv2-base-patch4-window8-256"
OUTPUT_DIR = "./thyroid-swinv2-model"
HUB_MODEL_ID = "Johnyquest7/ML-Inter_thyroid"

# Class index <-> class name mappings for the binary task; the position of a
# name in the tuple is its class index.
ID2LABEL = dict(enumerate(("benign", "malignant")))
LABEL2ID = {label: index for index, label in ID2LABEL.items()}
NUM_LABELS = len(ID2LABEL)
|
|
| |
| |
| |
# Metric objects consumed by compute_metrics; loaded once, in this order.
accuracy, f1, precision, recall, roc_auc = (
    evaluate.load(metric_name)
    for metric_name in ("accuracy", "f1", "precision", "recall", "roc_auc")
)
|
|
def compute_metrics(eval_pred):
    """Compute the binary classification metrics for one evaluation pass.

    Args:
        eval_pred: Pair ``(logits, labels)``. ``logits`` is indexed as
            ``[sample, class]``: argmax and softmax run over axis 1, and
            column 1 is used as the positive-class score.

    Returns:
        dict: The key/value pairs produced by the accuracy, F1, precision
        and recall metrics, plus ``"roc_auc"``. ``"roc_auc"`` falls back to
        0.0 whenever the score is undefined.
    """
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=1)
    # Softmax probability of class 1, used as the ranking score for ROC AUC.
    probs = torch.softmax(torch.tensor(logits), dim=1)[:, 1].numpy()

    result = {}
    result.update(accuracy.compute(predictions=preds, references=labels))
    for metric in (f1, precision, recall):
        result.update(
            metric.compute(predictions=preds, references=labels, average="binary")
        )

    try:
        result.update(roc_auc.compute(prediction_scores=probs, references=labels))
    except ValueError as err:
        # Only the "score cannot be computed" failure is tolerated (e.g. a
        # single class in `labels` - see scikit-learn's roc_auc_score docs).
        # Anything else is a real bug and must surface instead of being
        # silently turned into a 0.0 score.
        print(f"Warning: ROC AUC undefined ({err}); reporting 0.0", file=sys.stderr)
        result["roc_auc"] = 0.0

    # Best-model selection compares this value across evaluations, so a
    # missing or NaN score is normalised to the same 0.0 fallback.
    auc = result.get("roc_auc")
    if auc is None or np.isnan(auc):
        result["roc_auc"] = 0.0
    return result
|
|
| |
| |
| |
print("Loading dataset...")
train_ds = load_dataset(DATASET_NAME, split="train")
test_ds = load_dataset(DATASET_NAME, split="test")

# Hold out 15% of the original training split, stratified on the "label"
# column, as the validation set. The fixed seed makes the split repeatable.
train_val = train_ds.train_test_split(
    seed=42,
    stratify_by_column="label",
    test_size=0.15,
)
train_ds, val_ds = train_val["train"], train_val["test"]

# Report split sizes and per-class counts.
print(f"Train: {len(train_ds)} | Val: {len(val_ds)} | Test: {len(test_ds)}")
print(f"Train labels: {Counter(train_ds['label'])}")
print(f"Val labels: {Counter(val_ds['label'])}")
print(f"Test labels: {Counter(test_ds['label'])}")
|
|
| |
| |
| |
print(f"Loading image processor from {MODEL_NAME}...")
image_processor = AutoImageProcessor.from_pretrained(MODEL_NAME)

# Reuse the normalisation statistics stored with the pretrained processor.
image_mean = image_processor.image_mean
image_std = image_processor.image_std

# Always resolve the target size to an explicit (height, width) pair.
# torchvision's Resize(int) only matches the *shorter* edge and keeps the
# aspect ratio, so images with different aspect ratios would come out with
# different shapes. A square target gives every image the same shape.
if "shortest_edge" in image_processor.size:
    size = (image_processor.size["shortest_edge"],) * 2
else:
    size = (image_processor.size["height"], image_processor.size["width"])

# Training pipeline: resize, light geometric and photometric augmentation,
# then tensor conversion and normalisation.
train_transforms = Compose([
    Resize(size),
    RandomRotation(degrees=10),
    RandomHorizontalFlip(p=0.5),
    RandomVerticalFlip(p=0.3),
    ColorJitter(brightness=0.2, contrast=0.2),
    ToTensor(),
    Normalize(mean=image_mean, std=image_std),
])

# Evaluation pipeline: same resize and normalisation, no random augmentation.
val_transforms = Compose([
    Resize(size),
    ToTensor(),
    Normalize(mean=image_mean, std=image_std),
])
|
|
def preprocess_train(examples):
    """Convert a batch of images to augmented training tensors.

    Each entry of ``examples["image"]`` is converted to RGB and passed
    through ``train_transforms``. The results are stored under
    ``"pixel_values"`` and the ``"image"`` entry is removed.

    Args:
        examples: Batch mapping that contains an ``"image"`` entry; it is
            modified in place.

    Returns:
        The same ``examples`` mapping.
    """
    tensors = []
    for picture in examples["image"]:
        tensors.append(train_transforms(picture.convert("RGB")))
    examples["pixel_values"] = tensors
    del examples["image"]
    return examples
|
|
def preprocess_val(examples):
    """Convert a batch of images to evaluation tensors.

    Each entry of ``examples["image"]`` is converted to RGB and passed
    through ``val_transforms``. The results are stored under
    ``"pixel_values"`` and the ``"image"`` entry is removed.

    Args:
        examples: Batch mapping that contains an ``"image"`` entry; it is
            modified in place.

    Returns:
        The same ``examples`` mapping.
    """
    tensors = []
    for picture in examples["image"]:
        tensors.append(val_transforms(picture.convert("RGB")))
    examples["pixel_values"] = tensors
    del examples["image"]
    return examples
|
|
print("Applying transforms...")
# Only the training split gets the augmenting pipeline; validation and test
# share the evaluation pipeline.
train_ds = train_ds.with_transform(preprocess_train)
val_ds, test_ds = (
    split.with_transform(preprocess_val) for split in (val_ds, test_ds)
)
|
|
| |
| |
| |
print(f"Loading model {MODEL_NAME}...")
# Load the pretrained checkpoint with a two-class head and the label mappings
# defined for this task.
# NOTE(review): ignore_mismatched_sizes is presumably needed because the
# checkpoint's own classification head has a different number of classes -
# confirm against the checkpoint config.
model = AutoModelForImageClassification.from_pretrained(
    MODEL_NAME,
    label2id=LABEL2ID,
    id2label=ID2LABEL,
    num_labels=NUM_LABELS,
    ignore_mismatched_sizes=True,
)
|
|
| |
| |
| |
training_args = TrainingArguments(
    # --- Output and Hub upload ---
    output_dir=OUTPUT_DIR,
    push_to_hub=True,
    hub_model_id=HUB_MODEL_ID,
    # --- Optimisation ---
    num_train_epochs=30,
    learning_rate=2e-5,
    weight_decay=0.01,
    warmup_steps=100,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=2,
    bf16=True,
    seed=42,
    # --- Evaluation, checkpointing and best-model selection ---
    # Evaluate and save once per epoch; the checkpoint with the highest
    # eval_roc_auc is the one restored when training ends.
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_roc_auc",
    greater_is_better=True,
    # --- Logging ---
    logging_strategy="steps",
    logging_steps=10,
    logging_first_step=True,
    disable_tqdm=True,
    report_to="trackio",
    run_name="thyroid-swinv2-binary",
    project="thyroid-malignancy",
    # --- Data loading ---
    # The preprocess functions read the "image" column, so the Trainer must
    # not drop columns that the model's forward() does not accept.
    remove_unused_columns=False,
    dataloader_num_workers=4,
)
|
|
| |
| |
| |
data_collator = DefaultDataCollator()
# Early stopping configured with a patience of 5.
early_stopping = EarlyStoppingCallback(early_stopping_patience=5)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    data_collator=data_collator,
    processing_class=image_processor,
    compute_metrics=compute_metrics,
    callbacks=[early_stopping],
)

print("Starting training...")
trainer.train()

# Final evaluation on the held-out test split; metric keys get a "test" prefix.
print("Evaluating on test set...")
test_results = trainer.evaluate(eval_dataset=test_ds, metric_key_prefix="test")
print("Test results:", test_results)

print("Pushing to Hub...")
trainer.push_to_hub()

print("Done!")
|
|