# Beit-Retinal / train.py
# Author: Habeeb Okunade
# Training script (commit b43bc74)
import os
import json
from PIL import Image
from transformers import AutoImageProcessor, BeitForImageClassification, TrainingArguments, Trainer
from datasets import load_dataset
from sklearn.metrics import accuracy_score, f1_score
import numpy as np
# -------------------------------
# Config
# -------------------------------
# Retinal disease classes — alphabetical order (matches imagefolder's
# label ordering, which sorts class subfolder names).
CLASSES = ["AMD", "DMO", "DR", "GLC", "HR", "Normal"]
MODEL_NAME = "microsoft/beit-base-patch16-224"

# Resolve the output directory from the environment, falling back to a default,
# and make sure it exists before training starts.
OUTPUT_DIR = os.environ.get("OUTPUT_DIR", "/home/user/outputs/beit-retina")
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Log where HOME / the HF cache point (handy when running in containers).
print("HOME dir:", os.environ.get("HOME"))
print("HF cache:", os.environ.get("HF_HOME"))
# -------------------------------
# Metrics
# -------------------------------
def compute_metrics(eval_pred):
    """Compute accuracy and weighted F1 from a (logits, labels) eval tuple."""
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=1)
    metrics = {
        "accuracy": accuracy_score(labels, predictions),
        "f1_weighted": f1_score(labels, predictions, average="weighted"),
    }
    return metrics
# -------------------------------
# Preprocessing function
# -------------------------------
def transform(examples, processor):
    """Convert a batch of images to BEiT pixel_values tensors.

    Accepts either file paths or already-decoded PIL images in
    ``examples["image"]`` — ``load_dataset("imagefolder")`` yields PIL
    images, not paths, so handling only paths would break under the
    training pipeline in this file.

    Args:
        examples: batched dataset rows with an "image" column.
        processor: a Hugging Face image processor (e.g. AutoImageProcessor).

    Returns:
        dict with "pixel_values": list of per-image tensors.
    """
    pixel_values = []
    for item in examples["image"]:
        if isinstance(item, Image.Image):
            # Already decoded (imagefolder case) — no file handle to manage.
            pv = processor(item.convert("RGB"), return_tensors="pt")["pixel_values"][0]
        else:
            # Path case: close the file handle explicitly — the original
            # leaked one open file per image over the whole dataset.
            with Image.open(item) as img:
                pv = processor(img.convert("RGB"), return_tensors="pt")["pixel_values"][0]
        pixel_values.append(pv)
    return {"pixel_values": pixel_values}
# -------------------------------
# Training function
# -------------------------------
def train(train_dir="data/train", val_dir="data/val", epochs=5, batch_size=16):
    """Fine-tune BEiT on a retinal-image dataset and save the artifacts.

    Args:
        train_dir: training images laid out in one subfolder per class.
        val_dir: validation images laid out in one subfolder per class.
        epochs: number of training epochs.
        batch_size: per-device batch size for both train and eval.

    Side effects:
        Downloads the base model, trains, and writes the model, processor
        and ``labels.json`` into ``OUTPUT_DIR``.
    """
    # Load processor and model, wiring in the retinal label mapping.
    processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
    model = BeitForImageClassification.from_pretrained(
        MODEL_NAME,
        num_labels=len(CLASSES),
        id2label={i: c for i, c in enumerate(CLASSES)},
        label2id={c: i for i, c in enumerate(CLASSES)},
    )

    # BUG FIX: `data_dir` must be a single path string — the original passed a
    # dict, which load_dataset rejects. Build the splits via `data_files`
    # glob patterns instead; imagefolder still infers labels from the class
    # subfolder names (NOTE(review): label order must match CLASSES — it does
    # as long as class folders sort alphabetically like CLASSES).
    dataset = load_dataset(
        "imagefolder",
        data_files={
            "train": os.path.join(train_dir, "**"),
            "validation": os.path.join(val_dir, "**"),
        },
    )

    # Convert images to pixel_values and expose tensors to the Trainer.
    dataset = dataset.map(lambda x: transform(x, processor), batched=True)
    dataset.set_format(type="torch", columns=["pixel_values", "label"])

    args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=epochs,
        evaluation_strategy="epoch",
        save_strategy="epoch",  # must match evaluation_strategy for load_best_model_at_end
        load_best_model_at_end=True,
        metric_for_best_model="f1_weighted",
        logging_steps=50,
        report_to="none",
    )

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["validation"],
        tokenizer=processor,  # saved alongside each checkpoint
        compute_metrics=compute_metrics,
    )

    trainer.train()

    # Persist the best model, the processor, and the label list.
    model.save_pretrained(OUTPUT_DIR)
    processor.save_pretrained(OUTPUT_DIR)
    with open(os.path.join(OUTPUT_DIR, "labels.json"), "w") as f:
        json.dump(CLASSES, f)
    print("βœ… Training complete. Model saved at:", OUTPUT_DIR)


if __name__ == "__main__":
    train()