Spaces:
Sleeping
Sleeping
File size: 3,307 Bytes
4e7f56d 96c3348 4e7f56d 96c3348 4e7f56d b43bc74 4e7f56d 96c3348 4e7f56d 96c3348 4e7f56d 96c3348 4e7f56d 96c3348 4e7f56d 96c3348 4e7f56d 96c3348 4e7f56d 96c3348 4e7f56d 96c3348 4e7f56d 96c3348 4e7f56d 96c3348 4e7f56d 96c3348 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
import os
import json
from PIL import Image
from transformers import AutoImageProcessor, BeitForImageClassification, TrainingArguments, Trainer
from datasets import load_dataset
from sklearn.metrics import accuracy_score, f1_score
import numpy as np
# -------------------------------
# Config
# -------------------------------
# Class names; list position is the integer label id used for the model head and labels.json.
CLASSES = ["AMD", "DMO", "DR", "GLC", "HR", "Normal"]
# Hugging Face Hub checkpoint that fine-tuning starts from.
MODEL_NAME = "microsoft/beit-base-patch16-224"

# Destination for checkpoints, the final model, the processor and labels.json.
# The OUTPUT_DIR environment variable wins; the directory is created at import time.
_DEFAULT_OUTPUT_DIR = "/home/user/outputs/beit-retina"
OUTPUT_DIR = os.environ.get("OUTPUT_DIR", _DEFAULT_OUTPUT_DIR)
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Echo the environment locations that decide where home-relative files and the HF cache live.
for _caption, _env_var in (("HOME dir:", "HOME"), ("HF cache:", "HF_HOME")):
    print(_caption, os.environ.get(_env_var))
del _caption, _env_var
# -------------------------------
# Metrics
# -------------------------------
def compute_metrics(eval_pred):
    """Compute accuracy and support-weighted F1 for one Trainer evaluation pass.

    Args:
        eval_pred: ``(predictions, label_ids)`` pair, e.g. a ``transformers.EvalPrediction``.
            ``predictions`` is the logits array with classes on the last axis, or a
            tuple whose first element is that array.

    Returns:
        dict with ``"accuracy"`` and ``"f1_weighted"`` scores.
    """
    logits, labels = eval_pred
    # When the model returns extra outputs (hidden states, attentions), Trainer hands
    # predictions over as a tuple with the logits first.
    if isinstance(logits, tuple):
        logits = logits[0]
    preds = np.argmax(logits, axis=-1)
    return {
        "accuracy": accuracy_score(labels, preds),
        # zero_division=0 yields the same score as the default but without the warning
        # raised when some class is never predicted (common in early epochs).
        "f1_weighted": f1_score(labels, preds, average="weighted", zero_division=0),
    }
# -------------------------------
# Preprocessing function
# -------------------------------
def _to_rgb(item):
    """Return an RGB PIL image from a PIL image, a file path, or a ``{"bytes", "path"}`` dict."""
    import io

    if isinstance(item, Image.Image):
        return item.convert("RGB")
    if isinstance(item, dict):
        # datasets' Image(decode=False) feature yields {"bytes": ..., "path": ...}.
        item = io.BytesIO(item["bytes"]) if item.get("bytes") else item["path"]
    # convert() loads the pixel data, so the file handle can be closed immediately.
    with Image.open(item) as img:
        return img.convert("RGB")


def transform(examples, processor):
    """Convert a batch of images to ``pixel_values`` tensors.

    Args:
        examples: batch dict whose ``"image"`` entry holds decoded PIL images (what
            ``load_dataset("imagefolder")`` yields), file paths, or undecoded
            ``{"bytes", "path"}`` dicts.
        processor: image processor called on one RGB image at a time.

    Returns:
        ``{"pixel_values": [tensor, ...]}`` with one tensor per input image.
    """
    pixel_values = [
        # [0] drops the batch dimension of the single-image processor output.
        processor(_to_rgb(item), return_tensors="pt")["pixel_values"][0]
        for item in examples["image"]
    ]
    return {"pixel_values": pixel_values}
# -------------------------------
# Training function
# -------------------------------
def _supported_kwarg(callable_obj, new_name, old_name):
    """Return ``new_name`` if ``callable_obj`` accepts that keyword, else ``old_name``."""
    import inspect

    return new_name if new_name in inspect.signature(callable_obj).parameters else old_name


def _load_model():
    """Load the pretrained BEiT checkpoint with a fresh ``len(CLASSES)``-way head."""
    return BeitForImageClassification.from_pretrained(
        MODEL_NAME,
        num_labels=len(CLASSES),
        id2label=dict(enumerate(CLASSES)),
        label2id={name: idx for idx, name in enumerate(CLASSES)},
        # The checkpoint ships a 1000-way ImageNet classifier; replacing it with a
        # differently sized head raises a size-mismatch error without this flag.
        ignore_mismatched_sizes=True,
    )


def _load_datasets(train_dir, val_dir, processor):
    """Load the train/validation image folders as torch-formatted ``pixel_values``/``label``."""
    # ``data_dir`` only accepts a single path; per-split folders go through ``data_files``.
    # imagefolder infers each label from the image's parent directory name.
    dataset = load_dataset(
        "imagefolder",
        data_files={
            "train": os.path.join(train_dir, "**"),
            "validation": os.path.join(val_dir, "**"),
        },
    )
    # The inferred ClassLabel ids must line up with the model's id2label mapping.
    found = dataset["train"].features["label"].names
    if found != CLASSES:
        raise ValueError(f"Dataset class folders {found} do not match CLASSES {CLASSES}")
    dataset = dataset.map(lambda batch: transform(batch, processor), batched=True)
    dataset.set_format(type="torch", columns=["pixel_values", "label"])
    return dataset


def _build_training_args(epochs, batch_size):
    """Build TrainingArguments that evaluate/save every epoch and keep the best F1 model."""
    # transformers renamed ``evaluation_strategy`` to ``eval_strategy`` (old name later removed).
    strategy_key = _supported_kwarg(TrainingArguments, "eval_strategy", "evaluation_strategy")
    return TrainingArguments(
        output_dir=OUTPUT_DIR,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=epochs,
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="f1_weighted",
        logging_steps=50,
        report_to="none",
        **{strategy_key: "epoch"},
    )


def train(train_dir="data/train", val_dir="data/val", epochs=5, batch_size=16):
    """Fine-tune BEiT on the image folders and save model, processor and labels to OUTPUT_DIR.

    Args:
        train_dir: folder of training images, one sub-directory per class in CLASSES.
        val_dir: folder of validation images with the same layout.
        epochs: number of training epochs.
        batch_size: per-device batch size for training and evaluation.

    Raises:
        ValueError: if the class folders found in the data do not match CLASSES.
    """
    processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
    model = _load_model()
    dataset = _load_datasets(train_dir, val_dir, processor)

    # Newer transformers take the processor as ``processing_class``; older ones as ``tokenizer``.
    processor_key = _supported_kwarg(Trainer, "processing_class", "tokenizer")
    trainer = Trainer(
        model=model,
        args=_build_training_args(epochs, batch_size),
        train_dataset=dataset["train"],
        eval_dataset=dataset["validation"],
        compute_metrics=compute_metrics,
        **{processor_key: processor},
    )
    trainer.train()

    # With load_best_model_at_end=True, trainer.model holds the best checkpoint's weights.
    trainer.save_model(OUTPUT_DIR)
    processor.save_pretrained(OUTPUT_DIR)
    with open(os.path.join(OUTPUT_DIR, "labels.json"), "w", encoding="utf-8") as f:
        json.dump(CLASSES, f)
    print("✅ Training complete. Model saved at:", OUTPUT_DIR)
if __name__ == "__main__":
    import argparse

    # SUPPRESS keeps omitted flags out of the namespace, so train()'s own defaults apply
    # and running the script without arguments behaves exactly as before.
    parser = argparse.ArgumentParser(
        description="Fine-tune BEiT on a retina image-folder dataset.",
        argument_default=argparse.SUPPRESS,
    )
    parser.add_argument("--train-dir", help="directory of training images")
    parser.add_argument("--val-dir", help="directory of validation images")
    parser.add_argument("--epochs", type=int, help="number of training epochs")
    parser.add_argument("--batch-size", type=int, help="per-device train/eval batch size")
    train(**vars(parser.parse_args()))
|