Habeeb Okunade committed on
Commit
785b8f1
·
1 Parent(s): 238cd9e

Update Training script

Browse files
Files changed (1) hide show
  1. train2.py +31 -13
train2.py CHANGED
@@ -27,6 +27,9 @@ os.makedirs(OUTPUT_DIR, exist_ok=True)
27
  print("🔹 Loading dataset from 'data/' folder...")
28
  dataset = load_dataset("imagefolder", data_dir="data")
29
  print(f"🔹 Dataset loaded. Columns: {dataset['train'].column_names}")
 
 
 
30
 
31
  # ----------------------------
32
  # PREPROCESSOR
@@ -35,18 +38,33 @@ print(f"๐Ÿ”น Loading processor from {MODEL_NAME}...")
35
  processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
36
 
37
  def transform(example):
38
- # Determine correct image column
39
  image_column = "image" if "image" in example else list(example.keys())[0]
40
- img = example[image_column]
41
- if isinstance(img, str): # if path, open it
42
- img = Image.open(img).convert("RGB")
43
- elif isinstance(img, Image.Image):
44
- img = img.convert("RGB")
45
- else:
46
- raise ValueError(f"Unknown type for image: {type(img)}")
47
-
48
- inputs = processor(img, return_tensors="pt")
49
- inputs["label"] = example["label"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  return inputs
51
 
52
  print("🔹 Applying transform to dataset...")
@@ -80,7 +98,7 @@ def compute_metrics(eval_pred):
80
  return metrics
81
 
82
  # ----------------------------
83
- # TRAINING ARGS
84
  # ----------------------------
85
  args = TrainingArguments(
86
  output_dir=OUTPUT_DIR,
@@ -117,7 +135,7 @@ trainer.train()
117
  print("🔹 Training complete.")
118
 
119
  # ----------------------------
120
- # SAVE FINAL MODEL + LABELS
121
  # ----------------------------
122
  print("🔹 Saving final model and processor...")
123
  trainer.save_model(OUTPUT_DIR)
 
27
  print("🔹 Loading dataset from 'data/' folder...")
28
  dataset = load_dataset("imagefolder", data_dir="data")
29
  print(f"🔹 Dataset loaded. Columns: {dataset['train'].column_names}")
30
+ print(f"🔹 Dataset splits: {list(dataset.keys())}")
31
+ print(f"🔹 Number of training samples: {len(dataset['train'])}")
32
+ print(f"🔹 Number of validation samples: {len(dataset['validation'])}")
33
 
34
  # ----------------------------
35
  # PREPROCESSOR
 
38
  processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
39
 
40
def transform(example):
    """Convert one example (or a column-batch) into model-ready inputs.

    Locates the image column, normalizes every entry to an RGB PIL
    image (opening it first when it is a filesystem path), runs the
    images through the pretrained image processor, and attaches the
    corresponding labels as a tensor under the "labels" key.
    """
    # Prefer the conventional "image" column; otherwise fall back to
    # whatever column happens to come first.
    column = "image" if "image" in example else list(example.keys())[0]
    raw_images = example[column]

    # Single example and batched access share one code path: wrap a
    # lone value in a list.
    if not isinstance(raw_images, list):
        raw_images = [raw_images]

    rgb_images = []
    for img in raw_images:
        if isinstance(img, str):
            # A path on disk — load it and force 3-channel RGB.
            rgb_images.append(Image.open(img).convert("RGB"))
        elif isinstance(img, Image.Image):
            rgb_images.append(img.convert("RGB"))
        else:
            raise ValueError(f"Unknown type for image: {type(img)}")

    # Let the processor handle resizing/normalization and tensor
    # conversion in one call.
    inputs = processor(images=rgb_images, return_tensors="pt")

    # Labels get the same list normalization, then become a tensor so
    # the Trainer can consume them directly.
    labels = example["label"]
    if not isinstance(labels, list):
        labels = [labels]
    inputs["labels"] = torch.tensor(labels)

    return inputs
69
 
70
  print("🔹 Applying transform to dataset...")
 
98
  return metrics
99
 
100
  # ----------------------------
101
+ # TRAINING ARGUMENTS
102
  # ----------------------------
103
  args = TrainingArguments(
104
  output_dir=OUTPUT_DIR,
 
135
  print("🔹 Training complete.")
136
 
137
  # ----------------------------
138
+ # SAVE MODEL + PROCESSOR + LABELS
139
  # ----------------------------
140
  print("🔹 Saving final model and processor...")
141
  trainer.save_model(OUTPUT_DIR)