Spaces:
Sleeping
Sleeping
File size: 4,660 Bytes
0e0e505 238cd9e 0e0e505 05c5199 0e0e505 238cd9e 05c5199 0e0e505 05c5199 238cd9e 785b8f1 0e0e505 238cd9e 0e0e505 05c5199 f119f72 785b8f1 05c5199 785b8f1 05c5199 785b8f1 0e0e505 238cd9e 0e0e505 238cd9e 0e0e505 05c5199 0e0e505 05c5199 0e0e505 238cd9e 0e0e505 238cd9e 0e0e505 238cd9e 0e0e505 785b8f1 0e0e505 05c5199 0e0e505 238cd9e 0e0e505 238cd9e 0e0e505 238cd9e 0e0e505 238cd9e 0e0e505 785b8f1 0e0e505 238cd9e 0e0e505 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
import os
import json
import torch
from datasets import load_dataset
from transformers import (
AutoImageProcessor,
BeitForImageClassification,
TrainingArguments,
Trainer
)
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from PIL import Image
# ----------------------------
# CONFIG
# ----------------------------
# Pretrained checkpoint name; reused below for the image processor and the model.
MODEL_NAME = "microsoft/beit-base-patch16-224"

# OUTPUT_DIR and DATA_DIR can each be overridden by an environment variable
# of the same name; otherwise the defaults below are used.
_default_output_dir = os.path.expanduser("~/outputs/beit-retina")
OUTPUT_DIR = os.getenv("OUTPUT_DIR", _default_output_dir)
DATA_DIR = os.getenv("DATA_DIR", "data2")  # dynamic dataset path

print(f"๐น OUTPUT_DIR set to: {OUTPUT_DIR}")
print(f"๐น DATA_DIR set to: {DATA_DIR}")

# Create the output directory up front; an already-existing one is fine.
os.makedirs(OUTPUT_DIR, exist_ok=True)
# ----------------------------
# LOAD DATASET
# ----------------------------
# The "imagefolder" builder reads the images found under DATA_DIR.
print(f"๐น Loading dataset from '{DATA_DIR}' folder...")
dataset = load_dataset("imagefolder", data_dir=DATA_DIR)

# Report what was loaded. Both a 'train' and a 'validation' split are
# accessed here, so a dataset lacking either raises KeyError at this point.
_train_split = dataset["train"]
print(f"๐น Dataset loaded. Columns: {_train_split.column_names}")
_split_names = list(dataset.keys())
print(f"๐น Dataset splits: {_split_names}")
print(f"๐น Number of training samples: {len(_train_split)}")
_validation_split = dataset["validation"]
print(f"๐น Number of validation samples: {len(_validation_split)}")

# ----------------------------
# PREPROCESSOR
# ----------------------------
# The image processor is loaded from the same checkpoint as the model.
print(f"๐น Loading processor from {MODEL_NAME}...")
processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
def transform(example):
    """Turn raw dataset rows into processor outputs plus a ``labels`` tensor.

    Args:
        example: Mapping with a ``"label"`` entry and an image entry. The
            image entry is looked up under ``"image"`` when present,
            otherwise under the first key that is not ``"label"``. Both
            entries may hold a single value or a list of values; single
            values are wrapped in a list. Each image is either a file path
            (``str``) or a ``PIL.Image.Image``.

    Returns:
        The object returned by ``processor(images=..., return_tensors="pt")``
        with an added ``"labels"`` entry holding ``torch.tensor(labels)``.

    Raises:
        ValueError: If an image is neither a ``str`` nor a ``PIL.Image.Image``.
        IndexError: If ``example`` has no ``"image"`` key and no other
            non-label key to fall back on.
    """
    # Detect image column
    if "image" in example:
        image_column = "image"
    else:
        image_column = [c for c in example.keys() if c != "label"][0]

    images = example[image_column]
    if not isinstance(images, list):
        images = [images]

    processed_images = []
    for img in images:
        if isinstance(img, str):
            print(f" โช Opening image from path: {img}")
            # Image.open keeps the underlying file open; convert() returns an
            # independent in-memory image, so the file can be closed right
            # away instead of waiting for garbage collection.
            with Image.open(img) as opened:
                img = opened.convert("RGB")
        elif isinstance(img, Image.Image):
            print(" โช Using PIL.Image directly")
            img = img.convert("RGB")
        else:
            raise ValueError(f"Unknown type for image: {type(img)}")
        processed_images.append(img)

    inputs = processor(images=processed_images, return_tensors="pt")

    labels = example["label"]
    if not isinstance(labels, list):
        labels = [labels]
    inputs["labels"] = torch.tensor(labels)
    return inputs
# Register `transform` on the dataset; rows are converted when accessed.
print("๐น Applying transform to dataset...")
dataset = dataset.with_transform(transform)
print("๐น Transform applied successfully.")

# ----------------------------
# MODEL
# ----------------------------
# The number of output classes comes from the class names of the training
# split's "label" feature.
_class_names = dataset["train"].features["label"].names
_num_classes = len(_class_names)
print(f"๐น Loading BEiT model ({MODEL_NAME}) with {_num_classes} classes...")
model = BeitForImageClassification.from_pretrained(
    MODEL_NAME,
    num_labels=_num_classes,
    # NOTE(review): presumably needed because the checkpoint's classifier
    # head has a different number of outputs - confirm against the docs.
    ignore_mismatched_sizes=True,
)
print("๐น Model loaded successfully.")
# ----------------------------
# METRICS
# ----------------------------
def compute_metrics(eval_pred):
    """Compute accuracy and macro-averaged precision, recall and F1.

    Args:
        eval_pred: A pair that unpacks into ``(logits, labels)``. The
            predicted class is the argmax of ``logits`` over its last axis.

    Returns:
        A dict with the keys ``"accuracy"``, ``"precision"``, ``"recall"``
        and ``"f1"``. The dict is also printed before being returned.
    """
    logits, labels = eval_pred
    predicted = logits.argmax(axis=-1)

    metrics = {"accuracy": accuracy_score(labels, predicted)}
    # The remaining scores all use macro averaging; insertion order matches
    # the reported order (precision, recall, f1).
    macro_scorers = (
        ("precision", precision_score),
        ("recall", recall_score),
        ("f1", f1_score),
    )
    for metric_name, scorer in macro_scorers:
        metrics[metric_name] = scorer(labels, predicted, average="macro")

    print(f"๐น Metrics computed: {metrics}")
    return metrics
# ----------------------------
# TRAINING ARGUMENTS
# ----------------------------
# Evaluate and checkpoint once per epoch; log every 10 steps to
# <OUTPUT_DIR>/logs; do not push anything to the Hub.
# NOTE(review): newer transformers releases rename `evaluation_strategy`
# to `eval_strategy` - confirm against the installed version.
args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=5,
    weight_decay=0.01,
    logging_dir=os.path.join(OUTPUT_DIR, "logs"),
    logging_steps=10,
    # The dataset is converted on the fly by a transform that reads the raw
    # image column. By default the Trainer removes dataset columns that are
    # not model inputs, which would drop that column before the transform
    # runs and make it fail. Keep all columns.
    remove_unused_columns=False,
    push_to_hub=False,
)
print("๐น TrainingArguments configured.")
# ----------------------------
# TRAINER
# ----------------------------
# Gather the Trainer inputs in one place: the model, its training arguments,
# the train/validation splits, the image processor (passed through the
# `tokenizer` parameter) and the metrics callback.
_trainer_kwargs = {
    "model": model,
    "args": args,
    "train_dataset": dataset["train"],
    "eval_dataset": dataset["validation"],
    "tokenizer": processor,
    "compute_metrics": compute_metrics,
}
trainer = Trainer(**_trainer_kwargs)
print("๐น Trainer created. Ready to train.")
# ----------------------------
# TRAIN
# ----------------------------
print("๐น Starting training...")
trainer.train()
print("๐น Training complete.")

# ----------------------------
# SAVE MODEL + PROCESSOR + LABELS
# ----------------------------
# Write the model, the image processor and the class names to OUTPUT_DIR.
print("๐น Saving final model and processor...")
trainer.save_model(OUTPUT_DIR)
processor.save_pretrained(OUTPUT_DIR)

# Class names are stored as a JSON list in labels.json alongside the model.
labels = dataset["train"].features["label"].names
with open(os.path.join(OUTPUT_DIR, "labels.json"), "w") as f:
    json.dump(labels, f)

# This message was split over two physical lines inside a single-quoted
# f-string, which is a syntax error. The break came from the mis-decoded
# bytes of the leading check mark, so the literal is rejoined on one line
# with the check mark restored.
print(f"✅ Model and processor saved to {OUTPUT_DIR}")
|