Habeeb Okunade committed on
Commit
238cd9e
·
1 Parent(s): 0e0e505

Update Training script

Browse files
Files changed (1) hide show
  1. train2.py +28 -8
train2.py CHANGED
@@ -9,6 +9,7 @@ from transformers import (
9
  Trainer
10
  )
11
  from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
 
12
 
13
  # ----------------------------
14
  # CONFIG
@@ -17,38 +18,51 @@ MODEL_NAME = "microsoft/beit-base-patch16-224"
17
  OUTPUT_DIR = os.environ.get("OUTPUT_DIR", os.path.expanduser("~/outputs/beit-retina"))
18
  NUM_CLASSES = 6 # retina disease classes
19
 
20
- # Make sure output directory exists
21
  os.makedirs(OUTPUT_DIR, exist_ok=True)
22
 
23
  # ----------------------------
24
  # LOAD DATASET
25
  # ----------------------------
26
- # Example: Replace this with your retina dataset
27
- # You can load a Hugging Face dataset or your own image folder dataset
28
- # Dataset format: train/valid/test folders each containing subfolders by class name
29
  dataset = load_dataset("imagefolder", data_dir="data")
 
30
 
31
  # ----------------------------
32
  # PREPROCESSOR
33
  # ----------------------------
 
34
  processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
35
 
36
  def transform(example):
37
- inputs = processor(example["image"], return_tensors="pt")
 
 
 
 
 
 
 
 
 
 
38
  inputs["label"] = example["label"]
39
  return inputs
40
 
41
- # Map preprocessing
42
  dataset = dataset.with_transform(transform)
 
43
 
44
  # ----------------------------
45
  # MODEL
46
  # ----------------------------
 
47
  model = BeitForImageClassification.from_pretrained(
48
  MODEL_NAME,
49
  num_labels=NUM_CLASSES,
50
  ignore_mismatched_sizes=True
51
  )
 
52
 
53
  # ----------------------------
54
  # METRICS
@@ -56,12 +70,14 @@ model = BeitForImageClassification.from_pretrained(
56
  def compute_metrics(eval_pred):
57
  logits, labels = eval_pred
58
  preds = logits.argmax(axis=-1)
59
- return {
60
  "accuracy": accuracy_score(labels, preds),
61
  "precision": precision_score(labels, preds, average="macro"),
62
  "recall": recall_score(labels, preds, average="macro"),
63
  "f1": f1_score(labels, preds, average="macro"),
64
  }
 
 
65
 
66
  # ----------------------------
67
  # TRAINING ARGS
@@ -78,6 +94,7 @@ args = TrainingArguments(
78
  logging_dir=os.path.join(OUTPUT_DIR, "logs"),
79
  push_to_hub=False
80
  )
 
81
 
82
  # ----------------------------
83
  # TRAINER
@@ -90,19 +107,22 @@ trainer = Trainer(
90
  tokenizer=processor,
91
  compute_metrics=compute_metrics
92
  )
 
93
 
94
  # ----------------------------
95
  # TRAIN
96
  # ----------------------------
 
97
  trainer.train()
 
98
 
99
  # ----------------------------
100
  # SAVE FINAL MODEL + LABELS
101
  # ----------------------------
 
102
  trainer.save_model(OUTPUT_DIR)
103
  processor.save_pretrained(OUTPUT_DIR)
104
 
105
- # Save class labels mapping
106
  labels = dataset["train"].features["label"].names
107
  with open(os.path.join(OUTPUT_DIR, "labels.json"), "w") as f:
108
  json.dump(labels, f)
 
9
  Trainer
10
  )
11
  from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
12
+ from PIL import Image
13
 
14
  # ----------------------------
15
  # CONFIG
 
18
  OUTPUT_DIR = os.environ.get("OUTPUT_DIR", os.path.expanduser("~/outputs/beit-retina"))
19
  NUM_CLASSES = 6 # retina disease classes
20
 
21
+ print(f"🔹 OUTPUT_DIR set to: {OUTPUT_DIR}")
22
  os.makedirs(OUTPUT_DIR, exist_ok=True)
23
 
24
  # ----------------------------
25
  # LOAD DATASET
26
  # ----------------------------
27
+ print("🔹 Loading dataset from 'data/' folder...")
 
 
28
  dataset = load_dataset("imagefolder", data_dir="data")
29
+ print(f"🔹 Dataset loaded. Columns: {dataset['train'].column_names}")
30
 
31
  # ----------------------------
32
  # PREPROCESSOR
33
  # ----------------------------
34
+ print(f"🔹 Loading processor from {MODEL_NAME}...")
35
  processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
36
 
37
  def transform(example):
38
+ # Determine correct image column
39
+ image_column = "image" if "image" in example else list(example.keys())[0]
40
+ img = example[image_column]
41
+ if isinstance(img, str): # if path, open it
42
+ img = Image.open(img).convert("RGB")
43
+ elif isinstance(img, Image.Image):
44
+ img = img.convert("RGB")
45
+ else:
46
+ raise ValueError(f"Unknown type for image: {type(img)}")
47
+
48
+ inputs = processor(img, return_tensors="pt")
49
  inputs["label"] = example["label"]
50
  return inputs
51
 
52
+ print("🔹 Applying transform to dataset...")
53
  dataset = dataset.with_transform(transform)
54
+ print("🔹 Transform applied successfully.")
55
 
56
  # ----------------------------
57
  # MODEL
58
  # ----------------------------
59
+ print(f"🔹 Loading BEiT model ({MODEL_NAME}) with {NUM_CLASSES} classes...")
60
  model = BeitForImageClassification.from_pretrained(
61
  MODEL_NAME,
62
  num_labels=NUM_CLASSES,
63
  ignore_mismatched_sizes=True
64
  )
65
+ print("🔹 Model loaded successfully.")
66
 
67
  # ----------------------------
68
  # METRICS
 
70
  def compute_metrics(eval_pred):
71
  logits, labels = eval_pred
72
  preds = logits.argmax(axis=-1)
73
+ metrics = {
74
  "accuracy": accuracy_score(labels, preds),
75
  "precision": precision_score(labels, preds, average="macro"),
76
  "recall": recall_score(labels, preds, average="macro"),
77
  "f1": f1_score(labels, preds, average="macro"),
78
  }
79
+ print(f"🔹 Metrics computed: {metrics}")
80
+ return metrics
81
 
82
  # ----------------------------
83
  # TRAINING ARGS
 
94
  logging_dir=os.path.join(OUTPUT_DIR, "logs"),
95
  push_to_hub=False
96
  )
97
+ print("🔹 TrainingArguments configured.")
98
 
99
  # ----------------------------
100
  # TRAINER
 
107
  tokenizer=processor,
108
  compute_metrics=compute_metrics
109
  )
110
+ print("🔹 Trainer created. Ready to train.")
111
 
112
  # ----------------------------
113
  # TRAIN
114
  # ----------------------------
115
+ print("🔹 Starting training...")
116
  trainer.train()
117
+ print("🔹 Training complete.")
118
 
119
  # ----------------------------
120
  # SAVE FINAL MODEL + LABELS
121
  # ----------------------------
122
+ print("🔹 Saving final model and processor...")
123
  trainer.save_model(OUTPUT_DIR)
124
  processor.save_pretrained(OUTPUT_DIR)
125
 
 
126
  labels = dataset["train"].features["label"].names
127
  with open(os.path.join(OUTPUT_DIR, "labels.json"), "w") as f:
128
  json.dump(labels, f)