import os
import json

import numpy as np
from PIL import Image
from sklearn.metrics import accuracy_score, f1_score
from datasets import load_dataset
from transformers import (
    AutoImageProcessor,
    BeitForImageClassification,
    Trainer,
    TrainingArguments,
)

# -------------------------------
# Config
# -------------------------------
# Class names must match the alphabetical subfolder order that the
# "imagefolder" loader uses when inferring labels.
CLASSES = ["AMD", "DMO", "DR", "GLC", "HR", "Normal"]
MODEL_NAME = "microsoft/beit-base-patch16-224"

# Output directory (from env or default)
OUTPUT_DIR = os.environ.get("OUTPUT_DIR", "/home/user/outputs/beit-retina")
os.makedirs(OUTPUT_DIR, exist_ok=True)

print("HOME dir:", os.environ.get("HOME"))
print("HF cache:", os.environ.get("HF_HOME"))

# -------------------------------
# Metrics
# -------------------------------
def compute_metrics(eval_pred):
    """Compute accuracy and weighted F1 from Trainer eval predictions."""
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=1)
    return {
        "accuracy": accuracy_score(labels, preds),
        "f1_weighted": f1_score(labels, preds, average="weighted"),
    }

# -------------------------------
# Preprocessing function
# -------------------------------
def transform(examples, processor):
    """Convert decoded PIL images into model-ready pixel_values tensors.

    The "imagefolder" loader yields PIL images (not file paths) in the
    "image" column, so they only need an RGB conversion before the
    processor resizes and normalizes them.
    """
    images = [img.convert("RGB") for img in examples["image"]]
    pixel_values = processor(images=images, return_tensors="pt")["pixel_values"]
    return {"pixel_values": pixel_values}

# -------------------------------
# Training function
# -------------------------------
def train(train_dir="data/train", val_dir="data/val", epochs=5, batch_size=16):
    # Load processor and model with a fresh classification head
    processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
    model = BeitForImageClassification.from_pretrained(
        MODEL_NAME,
        num_labels=len(CLASSES),
        id2label={i: c for i, c in enumerate(CLASSES)},
        label2id={c: i for i, c in enumerate(CLASSES)},
    )

    # Load dataset. "imagefolder" takes per-split paths via data_files
    # (data_dir accepts a single string, not a dict) and infers labels
    # from the class subfolder names.
    dataset = load_dataset(
        "imagefolder",
        data_files={
            "train": os.path.join(train_dir, "**"),
            "validation": os.path.join(val_dir, "**"),
        },
    )

    # Map the preprocessing over both splits
    dataset = dataset.map(lambda x: transform(x, processor), batched=True)

    # Ensure dataset returns PyTorch tensors
    dataset.set_format(type="torch", columns=["pixel_values", "label"])

    # Training arguments. Eval and save strategies must match for
    # load_best_model_at_end to take effect.
    args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=epochs,
        evaluation_strategy="epoch",  # renamed eval_strategy in newer transformers
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="f1_weighted",
        logging_steps=50,
        report_to="none",
    )

    # Trainer. Passing the processor ensures it is saved with checkpoints.
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["validation"],
        tokenizer=processor,  # newer transformers versions: processing_class=processor
        compute_metrics=compute_metrics,
    )

    # Train
    trainer.train()

    # Save the best model and processor
    model.save_pretrained(OUTPUT_DIR)
    processor.save_pretrained(OUTPUT_DIR)

    # Save labels
    with open(os.path.join(OUTPUT_DIR, "labels.json"), "w") as f:
        json.dump(CLASSES, f)

    print("✅ Training complete. Model saved at:", OUTPUT_DIR)

if __name__ == "__main__":
    train()
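
# -------------------------------
# Inference example (sketch)
# -------------------------------
# A minimal usage sketch, not part of the training flow above: it assumes
# train() has already written a model and processor to OUTPUT_DIR, and
# "retina.jpg" in the example call is a hypothetical placeholder path.
import torch

def predict(image_path, model_dir=OUTPUT_DIR):
    """Classify a single retina image with the fine-tuned model."""
    processor = AutoImageProcessor.from_pretrained(model_dir)
    model = BeitForImageClassification.from_pretrained(model_dir)
    model.eval()
    image = Image.open(image_path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    pred_id = logits.argmax(dim=-1).item()
    # id2label was saved with the model config during training
    return model.config.id2label[pred_id]

# Example call (hypothetical image path):
# print(predict("retina.jpg"))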