# NOTE(review): the lines below were Hugging Face Space page residue
# ("Spaces: / Sleeping") left over from scraping — kept as a comment so
# the file remains valid Python.
import json
import os

import numpy as np
from datasets import DatasetDict, load_dataset
from PIL import Image
from sklearn.metrics import accuracy_score, f1_score
from transformers import (
    AutoImageProcessor,
    BeitForImageClassification,
    Trainer,
    TrainingArguments,
)
# -------------------------------
# Config
# -------------------------------
# Retinal-disease class names; list order defines the integer label ids.
CLASSES = ["AMD", "DMO", "DR", "GLC", "HR", "Normal"]
MODEL_NAME = "microsoft/beit-base-patch16-224"

# Output directory (overridable via the OUTPUT_DIR environment variable).
OUTPUT_DIR = os.environ.get("OUTPUT_DIR", "/home/user/outputs/beit-retina")
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Log the home dir and HF cache location so path/caching issues are easy to debug.
print("HOME dir:", os.environ.get("HOME"))
print("HF cache:", os.environ.get("HF_HOME"))
def compute_metrics(eval_pred):
    """Compute evaluation metrics for the HF Trainer.

    Args:
        eval_pred: ``(logits, labels)`` pair supplied by the Trainer.

    Returns:
        Dict with overall accuracy and class-frequency-weighted F1.
    """
    logits, labels = eval_pred
    predictions = logits.argmax(axis=1)
    return {
        "accuracy": accuracy_score(labels, predictions),
        "f1_weighted": f1_score(labels, predictions, average="weighted"),
    }
| # ------------------------------- | |
| # Preprocessing function | |
| # ------------------------------- | |
def transform(examples, processor):
    """Convert a batch's images into model-ready ``pixel_values``.

    The ``imagefolder`` loader yields already-decoded PIL.Image objects in
    the ``"image"`` column, while other callers may supply file paths.
    Both are accepted here (calling ``Image.open`` on a PIL image — as the
    previous version did unconditionally — raises).

    Args:
        examples: Batch dict with an ``"image"`` column (PIL images or paths).
        processor: HF image processor; called per image and expected to
            return ``{"pixel_values": tensor}``.

    Returns:
        Dict with a ``"pixel_values"`` list, one tensor per input image.
    """
    pixel_values = []
    for item in examples["image"]:
        # Open from disk only when given a path; otherwise use the image as-is.
        image = Image.open(item) if isinstance(item, (str, os.PathLike)) else item
        if hasattr(image, "convert"):
            image = image.convert("RGB")
        pixel_values.append(processor(image, return_tensors="pt")["pixel_values"][0])
    return {"pixel_values": pixel_values}
# -------------------------------
# Training function
# -------------------------------
def train(train_dir="data/train", val_dir="data/val", epochs=5, batch_size=16):
    """Fine-tune BEiT on a retinal image-classification dataset.

    Args:
        train_dir: Directory of training images laid out for the
            ``imagefolder`` loader (one subdirectory per class).
        val_dir: Directory of validation images, same layout.
        epochs: Number of training epochs.
        batch_size: Per-device batch size for both train and eval.

    Side effects:
        Writes checkpoints, the final model/processor, and ``labels.json``
        to ``OUTPUT_DIR``.
    """
    # Load processor and model with a classification head sized to CLASSES.
    processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
    model = BeitForImageClassification.from_pretrained(
        MODEL_NAME,
        num_labels=len(CLASSES),
        id2label={i: c for i, c in enumerate(CLASSES)},
        label2id={c: i for i, c in enumerate(CLASSES)},
    )

    # BUG FIX: load_dataset's `data_dir` expects a single path, not a dict
    # of splits. Load each directory as its own dataset (imagefolder exposes
    # a single "train" split per directory) and combine them.
    dataset = DatasetDict({
        "train": load_dataset("imagefolder", data_dir=train_dir, split="train"),
        "validation": load_dataset("imagefolder", data_dir=val_dir, split="train"),
    })

    # Preprocess; drop the raw "image" column once pixel_values exist so the
    # decoded PIL objects are not carried through training.
    dataset = dataset.map(
        lambda batch: transform(batch, processor),
        batched=True,
        remove_columns=["image"],
    )
    # Return PyTorch tensors for exactly the columns the model consumes
    # (the Trainer renames "label" -> "labels" automatically).
    dataset.set_format(type="torch", columns=["pixel_values", "label"])

    args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=epochs,
        # NOTE(review): renamed to `eval_strategy` in transformers >= 4.46;
        # keep `evaluation_strategy` while pinned to an older release.
        evaluation_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="f1_weighted",
        logging_steps=50,
        report_to="none",
    )

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["validation"],
        tokenizer=processor,
        compute_metrics=compute_metrics,
    )

    trainer.train()

    # Persist the best model, its processor, and the label list for inference.
    model.save_pretrained(OUTPUT_DIR)
    processor.save_pretrained(OUTPUT_DIR)
    with open(os.path.join(OUTPUT_DIR, "labels.json"), "w") as f:
        json.dump(CLASSES, f)
    print("β Training complete. Model saved at:", OUTPUT_DIR)
if __name__ == "__main__":
    # Script entry point: train with the default data directories.
    train()