# Hugging Face Spaces scrape residue — Space status was "Sleeping".
import json
import os

import torch
from datasets import load_dataset
from PIL import Image
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from transformers import (
    AutoImageProcessor,
    BeitForImageClassification,
    Trainer,
    TrainingArguments,
)
# ----------------------------
# CONFIG
# ----------------------------
MODEL_NAME = "microsoft/beit-base-patch16-224"  # pretrained BEiT checkpoint to fine-tune

# Both paths are overridable via environment variables so the script can run
# locally or in a deployed Space without code changes.
OUTPUT_DIR = os.environ.get("OUTPUT_DIR", os.path.expanduser("~/outputs/beit-retina"))
DATA_DIR = os.environ.get("DATA_DIR", "data2")  # dynamic dataset path
print(f"🔹 OUTPUT_DIR set to: {OUTPUT_DIR}")
print(f"🔹 DATA_DIR set to: {DATA_DIR}")
os.makedirs(OUTPUT_DIR, exist_ok=True)
# ----------------------------
# LOAD DATASET
# ----------------------------
# Expects an imagefolder layout: DATA_DIR/{train,validation}/<class_name>/<images>.
print(f"🔹 Loading dataset from '{DATA_DIR}' folder...")
dataset = load_dataset("imagefolder", data_dir=DATA_DIR)
print(f"🔹 Dataset loaded. Columns: {dataset['train'].column_names}")
print(f"🔹 Dataset splits: {list(dataset.keys())}")
print(f"🔹 Number of training samples: {len(dataset['train'])}")
# NOTE(review): assumes a 'validation' split exists on disk; `imagefolder` only
# creates one when a validation/ directory is present — confirm dataset layout.
print(f"🔹 Number of validation samples: {len(dataset['validation'])}")
# ----------------------------
# PREPROCESSOR
# ----------------------------
print(f"🔹 Loading processor from {MODEL_NAME}...")
processor = AutoImageProcessor.from_pretrained(MODEL_NAME)


def transform(example):
    """Convert images to RGB, run the BEiT image processor, and attach labels.

    Handles both a single example and a batched dict (as supplied by
    ``Dataset.with_transform``). Each image may be a file path or a PIL image.

    Args:
        example: dict with an image column (preferably "image") and a "label"
            column; values may be single items or lists.

    Returns:
        dict with "pixel_values" (from the processor) and a "labels" tensor.

    Raises:
        ValueError: if an image entry is neither a path string nor a PIL image.
    """
    # Detect the image column: prefer "image", otherwise the first non-label column.
    image_column = "image" if "image" in example else [c for c in example.keys() if c != "label"][0]
    images = example[image_column]
    if not isinstance(images, list):
        images = [images]

    processed_images = []
    for img in images:
        if isinstance(img, str):
            # Path on disk — decode and normalize to RGB.
            img = Image.open(img).convert("RGB")
        elif isinstance(img, Image.Image):
            img = img.convert("RGB")
        else:
            raise ValueError(f"Unknown type for image: {type(img)}")
        processed_images.append(img)
    # (Per-image debug prints removed: this runs for every batch of every
    # epoch, so printing per sample floods stdout and slows data loading.)

    inputs = processor(images=processed_images, return_tensors="pt")
    labels = example["label"]
    if not isinstance(labels, list):
        labels = [labels]
    inputs["labels"] = torch.tensor(labels)
    return inputs


print("🔹 Applying transform to dataset...")
dataset = dataset.with_transform(transform)
print("🔹 Transform applied successfully.")
# ----------------------------
# MODEL
# ----------------------------
# Derive the class list once instead of recomputing it for the log line and
# the model constructor.
class_names = dataset["train"].features["label"].names
print(f"🔹 Loading BEiT model ({MODEL_NAME}) with {len(class_names)} classes...")
model = BeitForImageClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(class_names),
    # The pretrained head has 1000 ImageNet classes; allow reinitializing it
    # with the dataset's class count.
    ignore_mismatched_sizes=True,
)
print("🔹 Model loaded successfully.")
# ----------------------------
# METRICS
# ----------------------------
def compute_metrics(eval_pred):
    """Compute accuracy and macro-averaged precision/recall/F1 for the Trainer.

    Args:
        eval_pred: ``(logits, labels)`` pair as passed by ``Trainer.evaluate``.

    Returns:
        dict mapping metric name to its float value.
    """
    logits, labels = eval_pred
    preds = logits.argmax(axis=-1)
    metrics = {
        "accuracy": accuracy_score(labels, preds),
        # zero_division=0 yields the same value (0.0) for classes with no
        # predicted samples, but suppresses sklearn's UndefinedMetricWarning.
        "precision": precision_score(labels, preds, average="macro", zero_division=0),
        "recall": recall_score(labels, preds, average="macro", zero_division=0),
        "f1": f1_score(labels, preds, average="macro", zero_division=0),
    }
    print(f"🔹 Metrics computed: {metrics}")
    return metrics
# ----------------------------
# TRAINING ARGUMENTS
# ----------------------------
args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    # NOTE(review): renamed to `eval_strategy` in transformers >= 4.46 —
    # confirm the pinned transformers version still accepts this name.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=5,
    weight_decay=0.01,
    logging_dir=os.path.join(OUTPUT_DIR, "logs"),
    logging_steps=10,
    push_to_hub=False,
)
print("🔹 TrainingArguments configured.")
# ----------------------------
# TRAINER
# ----------------------------
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    # Passing the image processor here makes the Trainer save it with each
    # checkpoint. NOTE(review): `tokenizer=` is deprecated in newer
    # transformers in favor of `processing_class=` — confirm pinned version.
    tokenizer=processor,
    compute_metrics=compute_metrics,
)
print("🔹 Trainer created. Ready to train.")
# ----------------------------
# TRAIN
# ----------------------------
print("🔹 Starting training...")
trainer.train()
print("🔹 Training complete.")
# ----------------------------
# SAVE MODEL + PROCESSOR + LABELS
# ----------------------------
print("🔹 Saving final model and processor...")
trainer.save_model(OUTPUT_DIR)
processor.save_pretrained(OUTPUT_DIR)
# Persist the class-name list so inference code can map logits back to labels.
labels = dataset["train"].features["label"].names
with open(os.path.join(OUTPUT_DIR, "labels.json"), "w") as f:
    json.dump(labels, f)
print(f"✅ Model and processor saved to {OUTPUT_DIR}")