File size: 3,307 Bytes
4e7f56d
 
 
96c3348
 
 
 
 
4e7f56d
 
 
96c3348
 
 
4e7f56d
b43bc74
4e7f56d
96c3348
4e7f56d
 
96c3348
4e7f56d
 
 
96c3348
 
 
 
 
 
 
 
4e7f56d
 
 
 
 
 
 
 
96c3348
4e7f56d
 
 
 
 
 
96c3348
 
 
 
 
 
 
4e7f56d
 
 
 
 
 
 
 
 
 
96c3348
4e7f56d
96c3348
 
 
 
 
 
 
 
 
 
 
4e7f56d
96c3348
 
 
 
 
 
 
 
 
4e7f56d
96c3348
 
4e7f56d
 
 
 
 
 
96c3348
4e7f56d
 
 
96c3348
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os
import json
from PIL import Image
from transformers import AutoImageProcessor, BeitForImageClassification, TrainingArguments, Trainer
from datasets import load_dataset
from sklearn.metrics import accuracy_score, f1_score
import numpy as np

# -------------------------------
# Config
# -------------------------------
# Class names; train() uses each name's list position as its integer label id.
CLASSES = [
    "AMD",
    "DMO",
    "DR",
    "GLC",
    "HR",
    "Normal",
]
# Pretrained checkpoint identifier passed to from_pretrained() in train().
MODEL_NAME = "microsoft/beit-base-patch16-224"

# Destination for training outputs; the OUTPUT_DIR environment variable
# overrides the built-in default. The directory is created at import time.
OUTPUT_DIR = os.getenv("OUTPUT_DIR", "/home/user/outputs/beit-retina")
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Echo the environment values to make misconfigured runs easier to diagnose.
print("HOME dir:", os.getenv("HOME"))
print("HF cache:", os.getenv("HF_HOME"))

# -------------------------------
# Metrics
# -------------------------------
def compute_metrics(eval_pred):
    """Score a batch of predictions with accuracy and weighted F1.

    Args:
        eval_pred: Pair that unpacks into ``(logits, labels)``.
            Assumes logits are laid out as (samples, classes), since the
            predicted class is the argmax along axis 1 -- TODO confirm.

    Returns:
        Dict with the keys ``"accuracy"`` and ``"f1_weighted"``.
    """
    scores, targets = eval_pred
    # Predicted class = index of the largest score in each row.
    predicted = np.argmax(scores, axis=1)

    accuracy = accuracy_score(targets, predicted)
    weighted_f1 = f1_score(targets, predicted, average="weighted")
    return {"accuracy": accuracy, "f1_weighted": weighted_f1}

# -------------------------------
# Preprocessing function
# -------------------------------
def transform(examples, processor):
    """Convert a batch of images into ``pixel_values`` entries.

    Args:
        examples: Batch mapping whose ``"image"`` entry is a sequence of
            images. Each item may be an already-decoded image object that
            exposes ``convert`` (which is what the ``imagefolder`` loader
            yields), or anything ``PIL.Image.open`` accepts (a path or an
            open binary file object).
        processor: Callable invoked as ``processor(image, return_tensors="pt")``
            that returns a mapping whose ``"pixel_values"`` entry holds a
            batch of one.

    Returns:
        ``{"pixel_values": [...]}`` with one entry per input image, in input
        order. An empty batch yields an empty list.
    """
    pixel_values = []
    for item in examples["image"]:
        if hasattr(item, "convert"):
            # Already a decoded image: Image.open() would fail on it because
            # it is neither a path nor a readable file object.
            rgb = item.convert("RGB")
        else:
            # Path or file object: decode it, then release the file handle.
            # convert() returns an independent image, so closing is safe.
            with Image.open(item) as opened:
                rgb = opened.convert("RGB")
        encoded = processor(rgb, return_tensors="pt")
        # The processor returns a batch of one; keep only that single item.
        pixel_values.append(encoded["pixel_values"][0])
    return {"pixel_values": pixel_values}

# -------------------------------
# Training function
# -------------------------------
def train(train_dir="data/train", val_dir="data/val", epochs=5, batch_size=16):
    """Fine-tune the BEiT classifier on an image-folder dataset and save it.

    Args:
        train_dir: Directory holding the training images.
        val_dir: Directory holding the validation images.
        epochs: Number of training epochs.
        batch_size: Per-device batch size, used for both training and
            evaluation.

    Side effects:
        Writes checkpoints, the final model, the processor configuration and
        ``labels.json`` (the ``CLASSES`` list) into ``OUTPUT_DIR``, and prints
        a completion message.
    """
    # Load processor and model, wiring the class names into the model config.
    processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
    model = BeitForImageClassification.from_pretrained(
        MODEL_NAME,
        num_labels=len(CLASSES),
        id2label={i: c for i, c in enumerate(CLASSES)},
        label2id={c: i for i, c in enumerate(CLASSES)}
    )

    # Load dataset. Per-split directories must be given through `data_files`
    # as glob patterns; `data_dir` only accepts a single path string, so
    # passing a dict there fails before any data is read.
    # NOTE(review): assumes the class sub-folder names map to label ids in
    # the same order as CLASSES -- confirm against the data layout.
    data_files = {
        "train": os.path.join(train_dir, "**"),
        "validation": os.path.join(val_dir, "**"),
    }
    dataset = load_dataset("imagefolder", data_files=data_files)

    # Map transform over dataset
    dataset = dataset.map(lambda x: transform(x, processor), batched=True)

    # Ensure dataset returns PyTorch tensors and exposes only model inputs
    dataset.set_format(type="torch", columns=["pixel_values", "label"])

    # Training arguments: evaluate and checkpoint once per epoch so the best
    # checkpoint (by weighted F1) can be restored when training ends.
    # NOTE(review): recent transformers releases rename `evaluation_strategy`
    # to `eval_strategy` -- confirm against the installed version.
    args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=epochs,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="f1_weighted",
        logging_steps=50,
        report_to="none"
    )

    # Trainer
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["validation"],
        tokenizer=processor,
        compute_metrics=compute_metrics
    )

    # Train
    trainer.train()

    # Save model and processor
    model.save_pretrained(OUTPUT_DIR)
    processor.save_pretrained(OUTPUT_DIR)

    # Save labels; explicit encoding keeps the file platform-independent.
    with open(os.path.join(OUTPUT_DIR, "labels.json"), "w", encoding="utf-8") as f:
        json.dump(CLASSES, f)

    print("✅ Training complete. Model saved at:", OUTPUT_DIR)


# Script entry point: run training with the default arguments when this file
# is executed directly rather than imported.
if __name__ == "__main__":
    train()