ALYYAN committed · Commit bdb70cc · 1 parent: 30672d3

Complete baseline model with EfficientFormer

.gitignore CHANGED
@@ -130,7 +130,7 @@ __pypackages__/
 # Celery stuff
 celerybeat-schedule
 celerybeat.pid
-
+artifacts/data_ingestion/*
 # SageMath parsed files
 *.sage.py
 
app.py CHANGED
@@ -0,0 +1,196 @@
+import streamlit as st
+import cv2
+import numpy as np
+from PIL import Image
+from transformers import pipeline
+from mtcnn import MTCNN
+from collections import defaultdict
+
+st.set_page_config(layout="wide", page_title="Facial Age Detection")
+
+st.title("Facial Age Detection")
+st.write("Detect age groups from images, videos, or a live webcam feed.")
+st.write("This application uses an EfficientFormer-L1 model fine-tuned on the Facial Age dataset.")
+
+# --- Helper Functions and Classes ---
+
+@st.cache_resource
+def load_model():
+    """Load the age detection model pipeline."""
+    model_path = "artifacts/model_trainer/facial_age_detector_model"
+    pipe = pipeline('image-classification', model=model_path, device=0)  # device=0 selects the first GPU
+    return pipe
+
+@st.cache_resource
+def load_face_detector():
+    """Load the MTCNN face detector."""
+    return MTCNN()
+
+def iou(boxA, boxB):
+    """Calculate Intersection over Union for boxes in MTCNN's [x, y, w, h] format."""
+    # Convert [x, y, w, h] to corner coordinates [x1, y1, x2, y2] before intersecting
+    ax1, ay1, ax2, ay2 = boxA[0], boxA[1], boxA[0] + boxA[2], boxA[1] + boxA[3]
+    bx1, by1, bx2, by2 = boxB[0], boxB[1], boxB[0] + boxB[2], boxB[1] + boxB[3]
+    xA = max(ax1, bx1)
+    yA = max(ay1, by1)
+    xB = min(ax2, bx2)
+    yB = min(ay2, by2)
+    interArea = max(0, xB - xA) * max(0, yB - yA)
+    boxAArea = (ax2 - ax1) * (ay2 - ay1)
+    boxBArea = (bx2 - bx1) * (by2 - by1)
+    return interArea / float(boxAArea + boxBArea - interArea)
+
+class EMATracker:
+    """Exponential Moving Average Tracker for smoothing predictions."""
+    def __init__(self, alpha=0.3):
+        self.alpha = alpha
+        self.tracked_objects = {}  # {track_id: {'box': [x, y, w, h], 'ema_preds': {label: score}}}
+
+    def update(self, detections, id_counter):
+        # Detections are a list of face boxes; tracking is done by IoU overlap.
+
+        # Match detections to existing tracks
+        matches = {}  # {track_id: det_idx}
+        used_det_indices = set()
+
+        # This is simple greedy matching. For more robust tracking, consider the Hungarian algorithm.
+        for track_id, data in self.tracked_objects.items():
+            best_iou = 0
+            best_det_idx = -1
+            for i, det_box in enumerate(detections):
+                if i in used_det_indices:
+                    continue
+                current_iou = iou(data['box'], det_box)
+                if current_iou > best_iou and current_iou > 0.3:  # IoU threshold
+                    best_iou = current_iou
+                    best_det_idx = i
+            if best_det_idx != -1:
+                matches[track_id] = best_det_idx
+                used_det_indices.add(best_det_idx)
+
+        # Update matched tracks
+        for track_id, det_idx in matches.items():
+            self.tracked_objects[track_id]['box'] = detections[det_idx]
+
+        # Add new tracks
+        for i, det_box in enumerate(detections):
+            if i not in used_det_indices:
+                self.tracked_objects[id_counter] = {'box': det_box, 'ema_preds': defaultdict(float)}
+                id_counter += 1
+
+        # Removing stale tracks is omitted here (optional, for long videos)
+
+        return id_counter
+
+    def apply_ema(self, track_id, new_preds):
+        """Applies EMA to the predictions for a given track."""
+        if track_id not in self.tracked_objects:
+            return None
+
+        current_ema = self.tracked_objects[track_id]['ema_preds']
+
+        # Initialize if new
+        if not current_ema:
+            for pred in new_preds:
+                current_ema[pred['label']] = pred['score']
+        else:
+            # Update existing values
+            for pred in new_preds:
+                label = pred['label']
+                current_ema[label] = (self.alpha * pred['score']) + ((1 - self.alpha) * current_ema[label])
+
+        self.tracked_objects[track_id]['ema_preds'] = current_ema
+
+        # Return the top prediction from EMA
+        if not current_ema:
+            return None
+        top_label = max(current_ema, key=current_ema.get)
+        return f"{top_label} ({current_ema[top_label]:.2f})"
+
+
+# --- Load Models ---
+try:
+    age_pipe = load_model()
+    face_detector = load_face_detector()
+except Exception as e:
+    st.error(f"Error loading models: {e}. Please ensure the model is trained and located at 'artifacts/model_trainer/facial_age_detector_model'.")
+    st.stop()
+
+
+# --- UI Sidebar ---
+st.sidebar.header("Input Options")
+app_mode = st.sidebar.selectbox("Choose the app mode", ["Image", "Video", "Live Webcam"])
+
+# --- Main App Logic ---
+
+if app_mode == "Image":
+    uploaded_file = st.sidebar.file_uploader("Upload an image...", type=["jpg", "jpeg", "png"])
+    if uploaded_file is not None:
+        image = Image.open(uploaded_file).convert("RGB")
+        img_array = np.array(image)
+
+        st.image(image, caption='Uploaded Image', use_column_width=True)
+        st.write("")
+        st.write("Detecting faces and predicting age...")
+
+        faces = face_detector.detect_faces(img_array)
+
+        if not faces:
+            st.warning("No faces detected in the image.")
+        else:
+            for face in faces:
+                x, y, w, h = face['box']
+                face_img = img_array[y:y+h, x:x+w]
+                pil_face = Image.fromarray(face_img)
+
+                # Predict age
+                age_preds = age_pipe(pil_face)
+                top_pred = age_preds[0]
+
+                # Draw on image
+                cv2.rectangle(img_array, (x, y), (x+w, y+h), (0, 255, 0), 2)
+                label = f"Age: {top_pred['label']} ({top_pred['score']:.2f})"
+                cv2.putText(img_array, label, (x, y-10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
+
+            st.image(img_array, caption='Processed Image', use_column_width=True)
+
+elif app_mode == "Live Webcam":
+    st.sidebar.info("Press 'Start Webcam' to begin and 'Stop Webcam' to end.")
+    run = st.sidebar.button('Start Webcam')
+    stop = st.sidebar.button('Stop Webcam')
+    FRAME_WINDOW = st.image([])
+
+    cap = cv2.VideoCapture(0)
+    tracker = EMATracker()
+    track_id_counter = 0
+
+    # Clicking 'Stop Webcam' triggers a Streamlit rerun, which terminates this loop.
+    while run and not stop:
+        ret, frame = cap.read()
+        if not ret:
+            st.error("Failed to capture image from webcam.")
+            break
+
+        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+        faces = face_detector.detect_faces(frame_rgb)
+
+        detection_boxes = [f['box'] for f in faces]
+        track_id_counter = tracker.update(detection_boxes, track_id_counter)
+
+        for track_id, data in tracker.tracked_objects.items():
+            x, y, w, h = data['box']
+            if w > 20 and h > 20:  # Filter out small detections
+                face_img = frame_rgb[y:y+h, x:x+w]
+                pil_face = Image.fromarray(face_img)
+
+                age_preds = age_pipe(pil_face)
+                smoothed_label = tracker.apply_ema(track_id, age_preds)
+
+                cv2.rectangle(frame, (x, y), (x+w, y+h), (0, 255, 0), 2)
+                if smoothed_label:
+                    cv2.putText(frame, smoothed_label, (x, y-10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
+
+        FRAME_WINDOW.image(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+
+    cap.release()
+    st.sidebar.info("Webcam stopped.")
+
+# Video processing would be similar to the webcam loop, but with a file uploader.
+elif app_mode == "Video":
+    st.sidebar.warning("Video processing is similar to the webcam feed but processes a file. This feature is not fully implemented in this demo but follows the same logic.")
+    # You would use cv2.VideoCapture(video_path) and loop through frames.
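
The tracker above smooths each face's scores with an exponential moving average, ema_t = alpha * score_t + (1 - alpha) * ema_(t-1), and matches faces frame to frame with a greedy IoU loop. The in-code comment suggests the Hungarian algorithm for more robust matching; a minimal sketch of that swap, assuming scipy is available (it is not in requirements.txt) and reusing the iou helper and [x, y, w, h] boxes from app.py:

    import numpy as np
    from scipy.optimize import linear_sum_assignment

    def hungarian_match(track_boxes, det_boxes, iou_fn, iou_threshold=0.3):
        """Globally optimal track-to-detection assignment (hypothetical helper)."""
        if not track_boxes or not det_boxes:
            return {}
        # Cost matrix: 1 - IoU, so maximizing overlap becomes minimizing cost
        cost = np.array([[1.0 - iou_fn(t, d) for d in det_boxes] for t in track_boxes])
        rows, cols = linear_sum_assignment(cost)
        # Keep only pairs that still clear the IoU threshold
        return {r: c for r, c in zip(rows, cols) if 1.0 - cost[r, c] > iou_threshold}

Unlike the greedy loop, linear_sum_assignment minimizes total cost over all pairs at once, so one ambiguous face cannot steal a detection that a later track matches better.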
config/config.yaml CHANGED
@@ -15,4 +15,4 @@ model_trainer:
   root_dir: artifacts/model_trainer
   trained_model_path: artifacts/model_trainer/facial_age_detector_model
   # Using EfficientFormer-L1, a much lighter model than ViT
-  model_name: "snap-research/efficientformer-l1"
+  model_name: "snap-research/efficientformer-l1-300"
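
The corrected id matches the checkpoint naming Snap Research publishes on the Hugging Face Hub. A quick sanity check that the id resolves, using the same transformers class as model_trainer.py (network access assumed):

    from transformers import EfficientFormerForImageClassification

    # Downloads the pretrained EfficientFormer-L1 checkpoint from the Hub
    model = EfficientFormerForImageClassification.from_pretrained("snap-research/efficientformer-l1-300")
    print(model.config.num_labels)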
dvc.lock ADDED
@@ -0,0 +1,53 @@
+schema: '2.0'
+stages:
+  data_ingestion:
+    cmd: python src/cnnClassifier/pipeline/stage_01_data_ingestion.py
+    deps:
+    - path: config/config.yaml
+      hash: md5
+      md5: 3cea2dfb36f0a5e40dd599dad9458ca4
+      size: 609
+    - path: src/cnnClassifier/components/data_ingestion.py
+      hash: md5
+      md5: 80b591ef3eedaf256ef85f4d196a0d43
+      size: 1591
+    - path: src/cnnClassifier/pipeline/stage_01_data_ingestion.py
+      hash: md5
+      md5: 2e1c2ad52ddc9763ff2a241576a7477c
+      size: 904
+    outs:
+    - path: artifacts/data_ingestion
+      hash: md5
+      md5: 35941f86a72fc72e64cb3195753ae21d.dir
+      size: 1758455894
+      nfiles: 19557
+  model_training:
+    cmd: python src/cnnClassifier/pipeline/stage_02_model_training.py
+    deps:
+    - path: artifacts/data_ingestion
+      hash: md5
+      md5: 35941f86a72fc72e64cb3195753ae21d.dir
+      size: 1758455894
+      nfiles: 19557
+    - path: config/config.yaml
+      hash: md5
+      md5: 3cea2dfb36f0a5e40dd599dad9458ca4
+      size: 609
+    - path: params.yaml
+      hash: md5
+      md5: ce8c137aa11f22d0901fb41485e9bfde
+      size: 239
+    - path: src/cnnClassifier/components/model_trainer.py
+      hash: md5
+      md5: bc58a9fdc35492409863b38424773ef6
+      size: 8585
+    - path: src/cnnClassifier/pipeline/stage_02_model_training.py
+      hash: md5
+      md5: 374003acf88403924718ed5982007523
+      size: 829
+    outs:
+    - path: artifacts/model_trainer
+      hash: md5
+      md5: 621f61ba7beea89c3bef7a921afdcc9d.dir
+      size: 183039001
+      nfiles: 12
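
dvc.lock is generated by `dvc repro` and pins each stage's command, dependency hashes, and outputs. The dvc.yaml that produces it is not part of this commit; a sketch consistent with the two stages recorded above would look roughly like:

    stages:
      data_ingestion:
        cmd: python src/cnnClassifier/pipeline/stage_01_data_ingestion.py
        deps:
          - config/config.yaml
          - src/cnnClassifier/components/data_ingestion.py
          - src/cnnClassifier/pipeline/stage_01_data_ingestion.py
        outs:
          - artifacts/data_ingestion
      model_training:
        cmd: python src/cnnClassifier/pipeline/stage_02_model_training.py
        deps:
          - artifacts/data_ingestion
          - config/config.yaml
          - params.yaml
          - src/cnnClassifier/components/model_trainer.py
          - src/cnnClassifier/pipeline/stage_02_model_training.py
        outs:
          - artifacts/model_trainer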
requirements.txt CHANGED
@@ -5,13 +5,14 @@ torchvision==0.16.0+cu118
 torchaudio==2.1.0
 
 # Pin NumPy to a version compatible with Torch 2.1.0
-numpy<2.0
+numpy>=1.23,<2.0
 
 # Hugging Face
-transformers
+transformers==4.36.2
+tokenizers==0.15.0
 datasets>=2.14.5
 evaluate
-accelerate>=0.27
+accelerate>=0.25
 
 # MLOps and Utilities
 mlflow
@@ -29,6 +30,7 @@ imblearn
 streamlit
 opencv-python
 mtcnn
+tensorflow==2.15.0
 
 # AWS Deployment
 boto3
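
The tensorflow==2.15.0 pin presumably accompanies the new mtcnn usage in app.py, since the mtcnn face detector runs on TensorFlow/Keras; the exact transformers/tokenizers pins and the numpy upper bound keep the PyTorch 2.1.0 stack mutually compatible.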
src/cnnClassifier/__init__.py CHANGED
@@ -0,0 +1,26 @@
+import os
+import sys
+import logging
+
+# Define the logging format
+logging_str = "[%(asctime)s: %(levelname)s: %(module)s: %(message)s]"
+
+# Define the directory for log files
+log_dir = "logs"
+log_filepath = os.path.join(log_dir, "running_logs.log")
+os.makedirs(log_dir, exist_ok=True)
+
+
+# Configure the logging
+logging.basicConfig(
+    level=logging.INFO,
+    format=logging_str,
+    handlers=[
+        logging.FileHandler(log_filepath),  # Log to a file
+        logging.StreamHandler(sys.stdout)   # Also print to the console
+    ]
+)
+
+# Create the logger object
+logger = logging.getLogger("cnnClassifierLogger")
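
Every module in the package shares this pre-configured logger; the model trainer below imports it rather than configuring logging itself:

    from cnnClassifier import logger

    logger.info("Stage started")  # appears in logs/running_logs.log and on stdout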
src/cnnClassifier/components/model_trainer.py CHANGED
@@ -2,32 +2,62 @@ import torch
 import pandas as pd
 from pathlib import Path
 from tqdm import tqdm
+from functools import partial
 from datasets import Dataset, Image, ClassLabel
 from imblearn.over_sampling import RandomOverSampler
 from transformers import (
     EfficientFormerImageProcessor,
     EfficientFormerForImageClassification,
     TrainingArguments,
-    Trainer,
-    DefaultDataCollator
+    Trainer
 )
 from torchvision.transforms import (
     Compose,
     Normalize,
     RandomRotation,
-    RandomResizedCrop,
     RandomHorizontalFlip,
     Resize,
     ToTensor
 )
 import evaluate
 from cnnClassifier.entity.config_entity import ModelTrainerConfig
+from cnnClassifier import logger
+
+# ==============================================================================
+# TOP-LEVEL FUNCTION DEFINITIONS (FOR PICKLING)
+# ==============================================================================
+
+def apply_transforms(batch, processor, transform_pipeline):
+    """Applies a given transformation pipeline to a batch of images."""
+    # Create the normalization transform with stats from the processor
+    normalize = Normalize(mean=processor.image_mean, std=processor.image_std)
+
+    # Combine the base transforms with normalization
+    full_transforms = Compose([*transform_pipeline.transforms, normalize])
+
+    # Apply to each image in the batch
+    batch["pixel_values"] = [full_transforms(img.convert("RGB")) for img in batch["image"]]
+    return batch
+
+def collate_fn(batch):
+    """A custom collate function for image classification."""
+    return {
+        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
+        'labels': torch.tensor([x['label'] for x in batch])
+    }
+
+def compute_metrics(eval_pred):
+    """Computes accuracy metric for evaluation."""
+    accuracy = evaluate.load("accuracy")
+    predictions, label_ids = eval_pred
+    predicted_labels = predictions.argmax(axis=1)
+    return accuracy.compute(predictions=predicted_labels, references=label_ids)
+
+# ==============================================================================
 
 class ModelTrainer:
     def __init__(self, config: ModelTrainerConfig):
         self.config = config
-        self.label2id = None
-        self.id2label = None
 
     def _prepare_data(self):
         logger.info("Preparing data...")
@@ -60,13 +90,11 @@ class ModelTrainer:
         file_names, labels = [], []
         data_path = Path(self.config.data_path)
         for file in tqdm(sorted(data_path.glob('*/*.*'))):
-            label = str(file).split('/')[-2]
+            label = file.parent.name
             labels.append(label_dict[label])
             file_names.append(str(file))
 
         df = pd.DataFrame.from_dict({"image": file_names, "label": labels})
-
-        # Random oversampling
         ros = RandomOverSampler(random_state=self.config.random_state)
         df_resampled, y_resampled = ros.fit_resample(df[['image']], df['label'])
         df = pd.concat([df_resampled, y_resampled], axis=1)
@@ -74,72 +102,54 @@ class ModelTrainer:
         dataset = Dataset.from_pandas(df).cast_column("image", Image())
 
         labels_list = sorted(list(set(labels)))
-        self.label2id = {label: i for i, label in enumerate(labels_list)}
-        self.id2label = {i: label for i, label in enumerate(labels_list)}
+        label2id = {label: i for i, label in enumerate(labels_list)}
+        id2label = {i: label for i, label in enumerate(labels_list)}
 
         ClassLabels = ClassLabel(num_classes=len(labels_list), names=labels_list)
         dataset = dataset.map(lambda x: {'label': ClassLabels.str2int(x['label'])}, batched=True)
         dataset = dataset.cast_column('label', ClassLabels)
 
-        return dataset.train_test_split(test_size=self.config.test_split_size, shuffle=True, stratify_by_column="label")
+        split_dataset = dataset.train_test_split(test_size=self.config.test_split_size, shuffle=True, stratify_by_column="label")
+        return split_dataset, id2label, label2id
 
     def train(self):
         device = "cuda" if torch.cuda.is_available() else "cpu"
         logger.info(f"Using device: {device}")
 
-        split_dataset = self._prepare_data()
+        split_dataset, id2label, label2id = self._prepare_data()
         train_data = split_dataset['train']
         test_data = split_dataset['test']
 
         processor = EfficientFormerImageProcessor.from_pretrained(self.config.model_name)
 
-        image_mean, image_std = processor.image_mean, processor.image_std
-        size = self.config.image_size
-
-        normalize = Normalize(mean=image_mean, std=image_std)
+        # Define base transforms (without normalization)
         _train_transforms = Compose([
-            Resize((size, size)),
+            Resize((self.config.image_size, self.config.image_size)),
             RandomRotation(15),
             RandomHorizontalFlip(0.5),
             ToTensor(),
-            normalize
         ])
         _val_transforms = Compose([
-            Resize((size, size)),
+            Resize((self.config.image_size, self.config.image_size)),
             ToTensor(),
-            normalize
         ])
 
-        def train_transforms(examples):
-            examples['pixel_values'] = [_train_transforms(image.convert("RGB")) for image in examples['image']]
-            return examples
-
-        def val_transforms(examples):
-            examples['pixel_values'] = [_val_transforms(image.convert("RGB")) for image in examples['image']]
-            return examples
-
-        train_data.set_transform(train_transforms)
-        test_data.set_transform(val_transforms)
-
-        def collate_fn(examples):
-            pixel_values = torch.stack([example["pixel_values"] for example in examples])
-            labels = torch.tensor([example['label'] for example in examples])
-            return {"pixel_values": pixel_values, "labels": labels}
+        # Use functools.partial to create specialized versions of our top-level function.
+        # This is a pickle-safe way to pass extra arguments (processor, transforms).
+        train_transform_func = partial(apply_transforms, processor=processor, transform_pipeline=_train_transforms)
+        val_transform_func = partial(apply_transforms, processor=processor, transform_pipeline=_val_transforms)
+
+        train_data.set_transform(train_transform_func)
+        test_data.set_transform(val_transform_func)
 
         model = EfficientFormerForImageClassification.from_pretrained(
             self.config.model_name,
-            num_labels=len(self.id2label),
-            id2label=self.id2label,
-            label2id=self.label2id,
-            ignore_mismatched_sizes=True # Important for transfer learning
+            num_labels=len(id2label),
+            id2label=id2label,
+            label2id=label2id,
+            ignore_mismatched_sizes=True
         ).to(device)
 
-        accuracy = evaluate.load("accuracy")
-        def compute_metrics(eval_pred):
-            predictions, label_ids = eval_pred
-            predicted_labels = predictions.argmax(axis=1)
-            return accuracy.compute(predictions=predicted_labels, references=label_ids)
-
         args = TrainingArguments(
             output_dir=self.config.root_dir,
             logging_dir=f'{self.config.root_dir}/logs',
@@ -154,6 +164,8 @@ class ModelTrainer:
             load_best_model_at_end=True,
             metric_for_best_model="accuracy",
             save_total_limit=1,
+            remove_unused_columns=False,
+            dataloader_num_workers=4,
             report_to="none"
         )
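
The "FOR PICKLING" refactor matters because dataloader_num_workers=4 makes the Trainer's DataLoader run worker processes, and on spawn-based platforms (Windows, macOS) everything a worker receives, including the dataset's transform, must be picklable. A functools.partial over a module-level function pickles cleanly, while the nested train_transforms closures it replaces do not. A standalone sketch of the difference (not repo code):

    import pickle
    from functools import partial

    def scale(x, factor):                  # module-level: picklable
        return x * factor

    doubler = partial(scale, factor=2)
    pickle.dumps(doubler)                  # OK: partial of a top-level function

    def make_local_scaler():
        factor = 2
        def scale_local(x):                # defined inside a function: not picklable
            return x * factor
        return scale_local

    try:
        pickle.dumps(make_local_scaler())
    except (pickle.PicklingError, AttributeError) as err:
        print("local closure cannot be pickled:", err)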
 
src/cnnClassifier/config/configuration.py CHANGED
@@ -39,20 +39,22 @@ class ConfigurationManager:
 
     def get_model_trainer_config(self) -> ModelTrainerConfig:
         config = self.config.model_trainer
+        data_prep_config = self.config.data_preparation
         params = self.params
         create_directories([config.root_dir])
 
         model_trainer_config = ModelTrainerConfig(
-            root_dir=Path(config.root_dir),
-            trained_model_path=Path(config.trained_model_path),
-            model_name=config.model_name,
-            image_size=params.IMAGE_SIZE,
-            learning_rate=params.LEARNING_RATE,
-            batch_size=params.BATCH_SIZE,
-            num_train_epochs=params.NUM_TRAIN_EPOCHS,
-            weight_decay=params.WEIGHT_DECAY,
-            warmup_steps=params.WARMUP_STEPS,
-            test_split_size=params.TEST_SPLIT_SIZE,
-            random_state=params.RANDOM_STATE
+            root_dir=Path(config.root_dir),
+            data_path=Path(data_prep_config.data_path),
+            trained_model_path=Path(config.trained_model_path),
+            model_name=config.model_name,
+            image_size=int(params.IMAGE_SIZE),
+            learning_rate=float(params.LEARNING_RATE),  # explicit cast: YAML may load this as a string
+            batch_size=int(params.BATCH_SIZE),
+            num_train_epochs=int(params.NUM_TRAIN_EPOCHS),
+            weight_decay=float(params.WEIGHT_DECAY),  # explicit cast, same reason
+            warmup_steps=int(params.WARMUP_STEPS),
+            test_split_size=float(params.TEST_SPLIT_SIZE),  # explicit cast, same reason
+            random_state=int(params.RANDOM_STATE)
         )
         return model_trainer_config
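
The float()/int() casts are defensive: with PyYAML's YAML 1.1 resolver, a scientific-notation value without a decimal point, such as 1e-05, loads as a string, which would later crash TrainingArguments. A quick demonstration (assuming params.yaml spells the learning rate that way):

    import yaml

    params = yaml.safe_load("LEARNING_RATE: 1e-05\nWEIGHT_DECAY: 0.01")
    print(type(params["LEARNING_RATE"]))   # <class 'str'>: YAML 1.1 wants 1.0e-05
    print(type(params["WEIGHT_DECAY"]))    # <class 'float'>
    print(float(params["LEARNING_RATE"]))  # 1e-05, as the cast above ensures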
src/cnnClassifier/entity/config_entity.py CHANGED
@@ -17,6 +17,7 @@ class DataPreparationConfig:
 @dataclass(frozen=True)
 class ModelTrainerConfig:
     root_dir: Path
+    data_path: Path
     trained_model_path: Path
     model_name: str
     image_size: int