"""Fine-tune BEiT (microsoft/beit-base-patch16-224) on an `imagefolder` dataset.

Expects DATA_DIR to contain `train/` and `validation/` split folders laid out
for `datasets.load_dataset("imagefolder", ...)`. Saves the fine-tuned model,
the image processor, and the class-label list under OUTPUT_DIR.
"""

import json
import os

import torch
from datasets import load_dataset
from PIL import Image
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from transformers import (
    AutoImageProcessor,
    BeitForImageClassification,
    Trainer,
    TrainingArguments,
)

# ----------------------------
# CONFIG
# ----------------------------
MODEL_NAME = "microsoft/beit-base-patch16-224"
OUTPUT_DIR = os.environ.get("OUTPUT_DIR", os.path.expanduser("~/outputs/beit-retina"))
DATA_DIR = os.environ.get("DATA_DIR", "data2")  # dynamic dataset path

print(f"🔹 OUTPUT_DIR set to: {OUTPUT_DIR}")
print(f"🔹 DATA_DIR set to: {DATA_DIR}")
os.makedirs(OUTPUT_DIR, exist_ok=True)

# ----------------------------
# LOAD DATASET
# ----------------------------
print(f"🔹 Loading dataset from '{DATA_DIR}' folder...")
dataset = load_dataset("imagefolder", data_dir=DATA_DIR)
print(f"🔹 Dataset loaded. Columns: {dataset['train'].column_names}")
print(f"🔹 Dataset splits: {list(dataset.keys())}")
print(f"🔹 Number of training samples: {len(dataset['train'])}")
# NOTE(review): raises KeyError if the folder layout yields no 'validation'
# split — confirm DATA_DIR contains train/ and validation/ subfolders.
print(f"🔹 Number of validation samples: {len(dataset['validation'])}")

# ----------------------------
# PREPROCESSOR
# ----------------------------
print(f"🔹 Loading processor from {MODEL_NAME}...")
processor = AutoImageProcessor.from_pretrained(MODEL_NAME)


def transform(example):
    """Batch transform for `Dataset.with_transform`.

    Opens/normalizes each image to RGB, runs the BEiT image processor, and
    attaches the labels as a tensor under ``inputs["labels"]``.

    Args:
        example: A batch dict from `datasets`; image values may be file paths
            or PIL images, and may or may not already be wrapped in a list.

    Returns:
        A `BatchFeature` with ``pixel_values`` and ``labels`` tensors.

    Raises:
        ValueError: If an image entry is neither a path string nor a PIL image.
    """
    # Detect the image column ("imagefolder" normally calls it "image";
    # otherwise fall back to the first non-label column).
    image_column = (
        "image" if "image" in example else [c for c in example.keys() if c != "label"][0]
    )
    images = example[image_column]
    if not isinstance(images, list):
        images = [images]

    # Hot path: runs for every image, every batch, every epoch — keep it
    # free of per-image logging (the original printed here, flooding the log
    # and slowing data loading).
    processed_images = []
    for img in images:
        if isinstance(img, str):
            img = Image.open(img).convert("RGB")
        elif isinstance(img, Image.Image):
            img = img.convert("RGB")
        else:
            raise ValueError(f"Unknown type for image: {type(img)}")
        processed_images.append(img)

    inputs = processor(images=processed_images, return_tensors="pt")

    labels = example["label"]
    if not isinstance(labels, list):
        labels = [labels]
    inputs["labels"] = torch.tensor(labels)
    return inputs


print("🔹 Applying transform to dataset...")
dataset = dataset.with_transform(transform)
print("🔹 Transform applied successfully.")

# ----------------------------
# MODEL
# ----------------------------
print(f"🔹 Loading BEiT model ({MODEL_NAME}) with {len(dataset['train'].features['label'].names)} classes...")
# ignore_mismatched_sizes: the pretrained 1000-class ImageNet head is replaced
# by a freshly initialized head sized to this dataset's label set.
model = BeitForImageClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(dataset["train"].features["label"].names),
    ignore_mismatched_sizes=True,
)
print("🔹 Model loaded successfully.")


# ----------------------------
# METRICS
# ----------------------------
def compute_metrics(eval_pred):
    """Compute accuracy and macro-averaged precision/recall/F1 for the Trainer.

    Args:
        eval_pred: An ``(logits, labels)`` pair of numpy arrays from `Trainer`.

    Returns:
        Dict with ``accuracy``, ``precision``, ``recall``, ``f1``.
    """
    logits, labels = eval_pred
    preds = logits.argmax(axis=-1)
    metrics = {
        "accuracy": accuracy_score(labels, preds),
        # zero_division=0: when a class gets no predicted samples, score it 0
        # explicitly instead of emitting sklearn's UndefinedMetricWarning.
        "precision": precision_score(labels, preds, average="macro", zero_division=0),
        "recall": recall_score(labels, preds, average="macro", zero_division=0),
        "f1": f1_score(labels, preds, average="macro"),
    }
    print(f"🔹 Metrics computed: {metrics}")
    return metrics


# ----------------------------
# TRAINING ARGUMENTS
# ----------------------------
args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    # NOTE(review): renamed to `eval_strategy` in recent transformers
    # releases — confirm against the pinned transformers version.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=5,
    weight_decay=0.01,
    logging_dir=os.path.join(OUTPUT_DIR, "logs"),
    logging_steps=10,
    push_to_hub=False,
)
print("🔹 TrainingArguments configured.")

# ----------------------------
# TRAINER
# ----------------------------
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    # NOTE(review): `tokenizer=` is deprecated in favor of `processing_class=`
    # in newer transformers — confirm against the pinned version.
    tokenizer=processor,
    compute_metrics=compute_metrics,
)
print("🔹 Trainer created. Ready to train.")

# ----------------------------
# TRAIN
# ----------------------------
print("🔹 Starting training...")
trainer.train()
print("🔹 Training complete.")

# ----------------------------
# SAVE MODEL + PROCESSOR + LABELS
# ----------------------------
print("🔹 Saving final model and processor...")
trainer.save_model(OUTPUT_DIR)
processor.save_pretrained(OUTPUT_DIR)

# Persist the label-name ordering so inference code can map class indices back
# to human-readable names.
labels = dataset["train"].features["label"].names
with open(os.path.join(OUTPUT_DIR, "labels.json"), "w") as f:
    json.dump(labels, f)

print(f"✅ Model and processor saved to {OUTPUT_DIR}")