Beit-Retinal / train2.py
Habeeb Okunade
Update the training script
05c5199
import inspect
import json
import os

import torch
from datasets import load_dataset
from PIL import Image
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from transformers import (
    AutoImageProcessor,
    BeitForImageClassification,
    Trainer,
    TrainingArguments,
)
# ----------------------------
# CONFIG
# ----------------------------
# Hugging Face checkpoint that is fine-tuned below.
MODEL_NAME = "microsoft/beit-base-patch16-224"

# Both locations can be overridden through environment variables. The output
# default lives under the user's home directory; the data default is relative
# to the current working directory.
OUTPUT_DIR = os.getenv(
    "OUTPUT_DIR",
    os.path.expanduser("~/outputs/beit-retina"),
)
DATA_DIR = os.getenv("DATA_DIR", "data2")  # root folder read by load_dataset("imagefolder")

for _setting, _value in (("OUTPUT_DIR", OUTPUT_DIR), ("DATA_DIR", DATA_DIR)):
    print(f"๐Ÿ”น {_setting} set to: {_value}")

# Create the output folder up front so checkpoints and artifacts have somewhere to go.
os.makedirs(OUTPUT_DIR, exist_ok=True)
# ----------------------------
# LOAD DATASET
# ----------------------------
# Fraction of the training split held out for evaluation when DATA_DIR does
# not provide a validation split of its own.
VAL_FRACTION = 0.1

print(f"๐Ÿ”น Loading dataset from '{DATA_DIR}' folder...")
dataset = load_dataset("imagefolder", data_dir=DATA_DIR)

if "validation" not in dataset:
    # Without a validation split, dataset["validation"] (here and as the
    # Trainer's eval_dataset) would raise KeyError. Carve one out of train
    # instead; the fixed seed keeps the hold-out identical across runs.
    print(f"๐Ÿ”น No 'validation' split found; holding out {VAL_FRACTION:.0%} of train.")
    _split = dataset["train"].train_test_split(test_size=VAL_FRACTION, seed=42)
    dataset["train"] = _split["train"]
    dataset["validation"] = _split["test"]

print(f"๐Ÿ”น Dataset loaded. Columns: {dataset['train'].column_names}")
print(f"๐Ÿ”น Dataset splits: {list(dataset.keys())}")
print(f"๐Ÿ”น Number of training samples: {len(dataset['train'])}")
print(f"๐Ÿ”น Number of validation samples: {len(dataset['validation'])}")
# ----------------------------
# PREPROCESSOR
# ----------------------------
print(f"๐Ÿ”น Loading processor from {MODEL_NAME}...")
processor = AutoImageProcessor.from_pretrained(MODEL_NAME)


def transform(example, verbose=False):
    """Turn a batch of imagefolder rows into model inputs for BEiT.

    Registered with ``with_transform`` below, so it runs lazily whenever rows
    are read (i.e. for every sample, every epoch) and receives a batch: a
    mapping of column name -> list of values. Single values are also accepted.

    Args:
        example: Mapping holding an image column (``"image"`` if present,
            otherwise the first column that is not ``"label"``) and a
            ``"label"`` column.
        verbose: If True, print how each image is opened. Off by default
            because the per-image messages would flood the training log.

    Returns:
        The processor output (``return_tensors="pt"``) with the labels added
        as a tensor under the ``"labels"`` key.

    Raises:
        KeyError: If no image column is present.
        ValueError: If an image is neither a path string nor a PIL image.
    """
    if "image" in example:
        image_column = "image"
    else:
        candidates = [c for c in example.keys() if c != "label"]
        if not candidates:
            # Typically means the image column was dropped before this
            # transform ran (e.g. by Trainer's remove_unused_columns).
            raise KeyError(
                f"No image column found; available columns: {list(example.keys())}"
            )
        image_column = candidates[0]

    images = example[image_column]
    if not isinstance(images, list):
        images = [images]

    processed_images = []
    for img in images:
        if isinstance(img, str):
            if verbose:
                print(f" โ†ช Opening image from path: {img}")
            # convert() loads the pixels, so the file handle can be closed
            # as soon as the with-block exits.
            with Image.open(img) as opened:
                img = opened.convert("RGB")
        elif isinstance(img, Image.Image):
            if verbose:
                print(" โ†ช Using PIL.Image directly")
            img = img.convert("RGB")
        else:
            raise ValueError(f"Unknown type for image: {type(img)}")
        processed_images.append(img)

    inputs = processor(images=processed_images, return_tensors="pt")
    labels = example["label"]
    if not isinstance(labels, list):
        labels = [labels]
    inputs["labels"] = torch.tensor(labels)
    return inputs


print("๐Ÿ”น Applying transform to dataset...")
dataset = dataset.with_transform(transform)
print("๐Ÿ”น Transform applied successfully.")
# ----------------------------
# MODEL
# ----------------------------
# Class names of the "label" feature (for imagefolder, presumably the class
# sub-folder names), looked up once and reused below.
label_names = dataset["train"].features["label"].names
print(f"๐Ÿ”น Loading BEiT model ({MODEL_NAME}) with {len(label_names)} classes...")
model = BeitForImageClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(label_names),
    # Record the class names in the model config so the saved checkpoint maps
    # logits to readable labels instead of generic LABEL_0, LABEL_1, ...
    id2label=dict(enumerate(label_names)),
    label2id={name: idx for idx, name in enumerate(label_names)},
    # The pretrained classification head may not match num_labels; this lets
    # from_pretrained re-initialise it instead of raising on the size mismatch.
    ignore_mismatched_sizes=True,
)
print("๐Ÿ”น Model loaded successfully.")
# ----------------------------
# METRICS
# ----------------------------
def compute_metrics(eval_pred):
    """Compute accuracy and macro-averaged precision/recall/F1 for evaluation.

    Args:
        eval_pred: ``(predictions, label_ids)`` pair passed in by ``Trainer``;
            predictions are logits with the class axis last.

    Returns:
        Dict mapping metric name to a plain Python float.
    """
    logits, labels = eval_pred
    # When a model returns several prediction tensors, Trainer hands them over
    # as a tuple; the logits come first.
    if isinstance(logits, tuple):
        logits = logits[0]
    preds = logits.argmax(axis=-1)
    metrics = {
        "accuracy": float(accuracy_score(labels, preds)),
        # zero_division=0 yields the same score sklearn falls back to by
        # default (0 for a class with no predicted / no true samples) but
        # without emitting UndefinedMetricWarning on every evaluation.
        "precision": float(precision_score(labels, preds, average="macro", zero_division=0)),
        "recall": float(recall_score(labels, preds, average="macro", zero_division=0)),
        "f1": float(f1_score(labels, preds, average="macro", zero_division=0)),
    }
    print(f"๐Ÿ”น Metrics computed: {metrics}")
    return metrics
# ----------------------------
# TRAINING ARGUMENTS
# ----------------------------
# Newer transformers releases renamed `evaluation_strategy` to `eval_strategy`
# (and later dropped the old name); use whichever this install accepts.
_EVAL_STRATEGY_KW = (
    "eval_strategy"
    if "eval_strategy" in inspect.signature(TrainingArguments).parameters
    else "evaluation_strategy"
)
args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=5,
    weight_decay=0.01,
    logging_dir=os.path.join(OUTPUT_DIR, "logs"),
    logging_steps=10,
    push_to_hub=False,
    # Keep the raw image column. By default Trainer drops dataset columns that
    # are not model forward() arguments, which removes "image" before the lazy
    # `transform` can read it and makes every batch fail.
    remove_unused_columns=False,
    **{_EVAL_STRATEGY_KW: "epoch"},
)
print("๐Ÿ”น TrainingArguments configured.")
# ----------------------------
# TRAINER
# ----------------------------
# The processor is handed to Trainer so it is saved together with the model.
# Newer transformers releases take it as `processing_class` and deprecate (then
# remove) the older `tokenizer` keyword; use whichever this install supports.
_PROCESSOR_KW = (
    "processing_class"
    if "processing_class" in inspect.signature(Trainer.__init__).parameters
    else "tokenizer"
)
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    compute_metrics=compute_metrics,
    **{_PROCESSOR_KW: processor},
)
print("๐Ÿ”น Trainer created. Ready to train.")
# ----------------------------
# TRAIN
# ----------------------------
print("๐Ÿ”น Starting training...")
trainer.train()
print("๐Ÿ”น Training complete.")
# ----------------------------
# SAVE MODEL + PROCESSOR + LABELS
# ----------------------------
# Write the final weights, the preprocessing config and the ordered class-name
# list side by side in OUTPUT_DIR.
print("๐Ÿ”น Saving final model and processor...")
trainer.save_model(OUTPUT_DIR)
processor.save_pretrained(OUTPUT_DIR)

labels = dataset["train"].features["label"].names
labels_path = os.path.join(OUTPUT_DIR, "labels.json")
with open(labels_path, "w") as labels_file:
    labels_file.write(json.dumps(labels))

print(f"โœ… Model and processor saved to {OUTPUT_DIR}")