Complete baseline model with EfficientFormer

Files changed:
- .gitignore  +1 -1
- app.py  +196 -0
- config/config.yaml  +1 -1
- dvc.lock  +53 -0
- requirements.txt  +5 -3
- src/cnnClassifier/__init__.py  +26 -0
- src/cnnClassifier/components/model_trainer.py  +57 -45
- src/cnnClassifier/config/configuration.py  +13 -11
- src/cnnClassifier/entity/config_entity.py  +1 -0
.gitignore CHANGED

@@ -130,7 +130,7 @@ __pypackages__/
 # Celery stuff
 celerybeat-schedule
 celerybeat.pid
-
+artifacts/data_ingestion/*
 # SageMath parsed files
 *.sage.py
app.py ADDED

@@ -0,0 +1,196 @@
+import streamlit as st
+import cv2
+import numpy as np
+import torch
+from PIL import Image
+from transformers import pipeline
+from mtcnn import MTCNN
+from collections import defaultdict
+
+st.set_page_config(layout="wide", page_title="Facial Age Detection")
+
+st.title("Facial Age Detection")
+st.write("Detect age groups from images, videos, or a live webcam feed.")
+st.write("This application uses an EfficientFormer-L1 model fine-tuned on the Facial Age dataset.")
+
+# --- Helper Functions and Classes ---
+
+@st.cache_resource
+def load_model():
+    """Load the age detection model pipeline."""
+    model_path = "artifacts/model_trainer/facial_age_detector_model"
+    device = 0 if torch.cuda.is_available() else -1  # GPU if available, else CPU
+    pipe = pipeline('image-classification', model=model_path, device=device)
+    return pipe
+
+@st.cache_resource
+def load_face_detector():
+    """Load the MTCNN face detector."""
+    return MTCNN()
+
+def iou(boxA, boxB):
+    """Calculate Intersection over Union for [x, y, w, h] boxes (MTCNN format)."""
+    # Convert [x, y, w, h] to corner coordinates [x1, y1, x2, y2]
+    ax1, ay1, ax2, ay2 = boxA[0], boxA[1], boxA[0] + boxA[2], boxA[1] + boxA[3]
+    bx1, by1, bx2, by2 = boxB[0], boxB[1], boxB[0] + boxB[2], boxB[1] + boxB[3]
+    xA = max(ax1, bx1)
+    yA = max(ay1, by1)
+    xB = min(ax2, bx2)
+    yB = min(ay2, by2)
+    interArea = max(0, xB - xA) * max(0, yB - yA)
+    boxAArea = (ax2 - ax1) * (ay2 - ay1)
+    boxBArea = (bx2 - bx1) * (by2 - by1)
+    return interArea / float(boxAArea + boxBArea - interArea)
+
+class EMATracker:
+    """Exponential Moving Average tracker for smoothing per-face predictions."""
+    def __init__(self, alpha=0.3):
+        self.alpha = alpha
+        self.tracked_objects = {}  # {track_id: {'box': [x, y, w, h], 'ema_preds': {label: score}}}
+
+    def update(self, detections, id_counter):
+        """Match new face boxes to existing tracks by IoU.
+        This is simple greedy matching; for more robust tracking, consider the Hungarian algorithm."""
+        matches = {}  # {track_id: det_idx}
+        used_det_indices = set()
+
+        for track_id, data in self.tracked_objects.items():
+            best_iou = 0
+            best_det_idx = -1
+            for i, det_box in enumerate(detections):
+                if i in used_det_indices:
+                    continue
+                current_iou = iou(data['box'], det_box)
+                if current_iou > best_iou and current_iou > 0.3:  # IoU threshold
+                    best_iou = current_iou
+                    best_det_idx = i
+            if best_det_idx != -1:
+                matches[track_id] = best_det_idx
+                used_det_indices.add(best_det_idx)
+
+        # Update matched tracks with their new boxes
+        for track_id, det_idx in matches.items():
+            self.tracked_objects[track_id]['box'] = detections[det_idx]
+
+        # Start new tracks for unmatched detections
+        for i, det_box in enumerate(detections):
+            if i not in used_det_indices:
+                self.tracked_objects[id_counter] = {'box': det_box, 'ema_preds': defaultdict(float)}
+                id_counter += 1
+
+        # Removing stale tracks is omitted here (optional, useful for long videos)
+        return id_counter
+
+    def apply_ema(self, track_id, new_preds):
+        """Apply EMA to the predictions for a given track and return the top label."""
+        if track_id not in self.tracked_objects:
+            return None
+
+        current_ema = self.tracked_objects[track_id]['ema_preds']
+
+        if not current_ema:
+            # First observation: initialize with the raw scores
+            for pred in new_preds:
+                current_ema[pred['label']] = pred['score']
+        else:
+            # EMA update: new = alpha * score + (1 - alpha) * old
+            for pred in new_preds:
+                label = pred['label']
+                current_ema[label] = (self.alpha * pred['score']) + ((1 - self.alpha) * current_ema[label])
+
+        self.tracked_objects[track_id]['ema_preds'] = current_ema
+
+        # Return the top prediction from EMA
+        if not current_ema:
+            return None
+        top_label = max(current_ema, key=current_ema.get)
+        return f"{top_label} ({current_ema[top_label]:.2f})"
+
+# --- Load Models ---
+try:
+    age_pipe = load_model()
+    face_detector = load_face_detector()
+except Exception as e:
+    st.error(f"Error loading models: {e}. Please ensure the model is trained and located at 'artifacts/model_trainer/facial_age_detector_model'.")
+    st.stop()
+
+# --- UI Sidebar ---
+st.sidebar.header("Input Options")
+app_mode = st.sidebar.selectbox("Choose the app mode", ["Image", "Video", "Live Webcam"])
+
+# --- Main App Logic ---
+if app_mode == "Image":
+    uploaded_file = st.sidebar.file_uploader("Upload an image...", type=["jpg", "jpeg", "png"])
+    if uploaded_file is not None:
+        image = Image.open(uploaded_file).convert("RGB")
+        img_array = np.array(image)
+
+        st.image(image, caption='Uploaded Image.', use_column_width=True)
+        st.write("")
+        st.write("Detecting faces and predicting age...")
+
+        faces = face_detector.detect_faces(img_array)
+
+        if not faces:
+            st.warning("No faces detected in the image.")
+        else:
+            for face in faces:
+                x, y, w, h = face['box']
+                x, y = max(0, x), max(0, y)  # MTCNN can return negative coordinates
+                face_img = img_array[y:y+h, x:x+w]
+                pil_face = Image.fromarray(face_img)
+
+                # Predict age
+                age_preds = age_pipe(pil_face)
+                top_pred = age_preds[0]
+
+                # Draw on image
+                cv2.rectangle(img_array, (x, y), (x+w, y+h), (0, 255, 0), 2)
+                label = f"Age: {top_pred['label']} ({top_pred['score']:.2f})"
+                cv2.putText(img_array, label, (x, y-10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
+
+            st.image(img_array, caption='Processed Image.', use_column_width=True)
+
+elif app_mode == "Live Webcam":
+    st.sidebar.info("Use the buttons below to start and stop the feed.")
+    run = st.sidebar.button('Start Webcam')
+    stop = st.sidebar.button('Stop Webcam')
+    FRAME_WINDOW = st.image([])
+
+    cap = cv2.VideoCapture(0)
+    tracker = EMATracker()
+    track_id_counter = 0
+
+    while run and not stop:
+        ret, frame = cap.read()
+        if not ret:
+            st.error("Failed to capture image from webcam.")
+            break
+
+        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+        faces = face_detector.detect_faces(frame_rgb)
+
+        detection_boxes = [f['box'] for f in faces]
+        track_id_counter = tracker.update(detection_boxes, track_id_counter)
+
+        for track_id, data in tracker.tracked_objects.items():
+            x, y, w, h = data['box']
+            x, y = max(0, x), max(0, y)  # MTCNN can return negative coordinates
+            if w > 20 and h > 20:  # Filter out tiny detections
+                face_img = frame_rgb[y:y+h, x:x+w]
+                pil_face = Image.fromarray(face_img)
+
+                age_preds = age_pipe(pil_face)
+                smoothed_label = tracker.apply_ema(track_id, age_preds)
+
+                cv2.rectangle(frame, (x, y), (x+w, y+h), (0, 255, 0), 2)
+                if smoothed_label:
+                    cv2.putText(frame, smoothed_label, (x, y-10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
+
+        FRAME_WINDOW.image(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+
+    cap.release()
+    st.sidebar.info("Webcam stopped.")
+
+# Placeholder for video processing, which mirrors the webcam loop but reads from a file.
+elif app_mode == "Video":
+    st.sidebar.warning("Video processing is similar to the webcam feed but processes a file. This feature is not fully implemented in this demo but follows the same logic.")
+    # You would use cv2.VideoCapture(video_path) and loop through frames.
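The Video branch above is left as a stub, with the commit's own comment pointing at cv2.VideoCapture. A minimal sketch of how that branch could mirror the webcam loop, reusing the age_pipe, face_detector, and EMATracker already defined in app.py; the uploader widget and temporary-file handling here are illustrative assumptions, not part of the commit:

    import tempfile

    # Hypothetical video branch, following the same logic as the webcam loop
    uploaded_video = st.sidebar.file_uploader("Upload a video...", type=["mp4", "avi", "mov"])
    if uploaded_video is not None:
        # cv2.VideoCapture needs a real file path, so spool the upload to disk first
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp:
            tmp.write(uploaded_video.read())
            video_path = tmp.name

        cap = cv2.VideoCapture(video_path)
        tracker = EMATracker()
        track_id_counter = 0
        frame_window = st.image([])

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break  # end of video
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            boxes = [f['box'] for f in face_detector.detect_faces(frame_rgb)]
            track_id_counter = tracker.update(boxes, track_id_counter)
            for track_id, data in tracker.tracked_objects.items():
                x, y, w, h = data['box']
                x, y = max(0, x), max(0, y)
                if w > 20 and h > 20:
                    preds = age_pipe(Image.fromarray(frame_rgb[y:y+h, x:x+w]))
                    label = tracker.apply_ema(track_id, preds)
                    cv2.rectangle(frame, (x, y), (x+w, y+h), (0, 255, 0), 2)
                    if label:
                        cv2.putText(frame, label, (x, y-10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
            frame_window.image(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

        cap.release()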
config/config.yaml CHANGED

@@ -15,4 +15,4 @@ model_trainer:
   root_dir: artifacts/model_trainer
   trained_model_path: artifacts/model_trainer/facial_age_detector_model
   # Using EfficientFormer-L1, a much lighter model than ViT
-  model_name: "snap-research/efficientformer-l1"
+  model_name: "snap-research/efficientformer-l1-300"
dvc.lock ADDED

@@ -0,0 +1,53 @@
+schema: '2.0'
+stages:
+  data_ingestion:
+    cmd: python src/cnnClassifier/pipeline/stage_01_data_ingestion.py
+    deps:
+    - path: config/config.yaml
+      hash: md5
+      md5: 3cea2dfb36f0a5e40dd599dad9458ca4
+      size: 609
+    - path: src/cnnClassifier/components/data_ingestion.py
+      hash: md5
+      md5: 80b591ef3eedaf256ef85f4d196a0d43
+      size: 1591
+    - path: src/cnnClassifier/pipeline/stage_01_data_ingestion.py
+      hash: md5
+      md5: 2e1c2ad52ddc9763ff2a241576a7477c
+      size: 904
+    outs:
+    - path: artifacts/data_ingestion
+      hash: md5
+      md5: 35941f86a72fc72e64cb3195753ae21d.dir
+      size: 1758455894
+      nfiles: 19557
+  model_training:
+    cmd: python src/cnnClassifier/pipeline/stage_02_model_training.py
+    deps:
+    - path: artifacts/data_ingestion
+      hash: md5
+      md5: 35941f86a72fc72e64cb3195753ae21d.dir
+      size: 1758455894
+      nfiles: 19557
+    - path: config/config.yaml
+      hash: md5
+      md5: 3cea2dfb36f0a5e40dd599dad9458ca4
+      size: 609
+    - path: params.yaml
+      hash: md5
+      md5: ce8c137aa11f22d0901fb41485e9bfde
+      size: 239
+    - path: src/cnnClassifier/components/model_trainer.py
+      hash: md5
+      md5: bc58a9fdc35492409863b38424773ef6
+      size: 8585
+    - path: src/cnnClassifier/pipeline/stage_02_model_training.py
+      hash: md5
+      md5: 374003acf88403924718ed5982007523
+      size: 829
+    outs:
+    - path: artifacts/model_trainer
+      hash: md5
+      md5: 621f61ba7beea89c3bef7a921afdcc9d.dir
+      size: 183039001
+      nfiles: 12
requirements.txt CHANGED

@@ -5,13 +5,14 @@ torchvision==0.16.0+cu118
 torchaudio==2.1.0

 # Pin NumPy to a version compatible with Torch 2.1.0
-numpy
+numpy>=1.23,<2.0

 # Hugging Face
-transformers
+transformers==4.36.2
+tokenizers==0.15.0
 datasets>=2.14.5
 evaluate
-accelerate>=0.
+accelerate>=0.25

 # MLOps and Utilities
 mlflow

@@ -29,6 +30,7 @@ imblearn
 streamlit
 opencv-python
 mtcnn
+tensorflow==2.15.0

 # AWS Deployment
 boto3
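These pins untangle a mixed Torch/TensorFlow environment: the mtcnn package is TensorFlow-backed (hence tensorflow==2.15.0 alongside Torch 2.1.0), and NumPy is held below 2.0 for Torch 2.1.0 compatibility. A quick import sanity check, offered as a sketch:

    # Verify the pinned stack imports side by side (run once after `pip install -r requirements.txt`)
    import numpy
    import torch
    import transformers
    import tensorflow

    print(numpy.__version__, torch.__version__, transformers.__version__, tensorflow.__version__)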
src/cnnClassifier/__init__.py ADDED

@@ -0,0 +1,26 @@
+import os
+import sys
+import logging
+
+# Define the logging format
+logging_str = "[%(asctime)s: %(levelname)s: %(module)s: %(message)s]"
+
+# Define the directory for log files
+log_dir = "logs"
+log_filepath = os.path.join(log_dir, "running_logs.log")
+os.makedirs(log_dir, exist_ok=True)
+
+# Configure the logging
+logging.basicConfig(
+    level=logging.INFO,
+    format=logging_str,
+    handlers=[
+        logging.FileHandler(log_filepath),   # Log to a file
+        logging.StreamHandler(sys.stdout)    # Also print to the console
+    ]
+)
+
+# Create the logger object
+logger = logging.getLogger("cnnClassifierLogger")
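Every module in the package shares this logger by importing it from the package root, as model_trainer.py does in the next diff. A minimal usage sketch:

    from cnnClassifier import logger

    logger.info("Data ingestion started")   # written to logs/running_logs.log and echoed to stdout
    try:
        raise ValueError("example failure")
    except ValueError:
        logger.exception("Stage failed")    # logs the message plus the full traceback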
src/cnnClassifier/components/model_trainer.py CHANGED

@@ -2,32 +2,62 @@ import torch
 import pandas as pd
 from pathlib import Path
 from tqdm import tqdm
+from functools import partial
 from datasets import Dataset, Image, ClassLabel
 from imblearn.over_sampling import RandomOverSampler
 from transformers import (
     EfficientFormerImageProcessor,
     EfficientFormerForImageClassification,
     TrainingArguments,
-    Trainer
-    DefaultDataCollator
+    Trainer
 )
 from torchvision.transforms import (
     Compose,
     Normalize,
     RandomRotation,
-    RandomResizedCrop,
     RandomHorizontalFlip,
     Resize,
     ToTensor
 )
 import evaluate
 from cnnClassifier.entity.config_entity import ModelTrainerConfig
+from cnnClassifier import logger
+
+# ==============================================================================
+# TOP-LEVEL FUNCTION DEFINITIONS (FOR PICKLING)
+# ==============================================================================
+
+def apply_transforms(batch, processor, transform_pipeline):
+    """Applies a given transformation pipeline to a batch of images."""
+    # Create the normalization transform with stats from the processor
+    normalize = Normalize(mean=processor.image_mean, std=processor.image_std)
+
+    # Combine the base transforms with normalization
+    full_transforms = Compose([*transform_pipeline.transforms, normalize])
+
+    # Apply to each image in the batch
+    batch["pixel_values"] = [full_transforms(img.convert("RGB")) for img in batch["image"]]
+    return batch
+
+def collate_fn(batch):
+    """A custom collate function for image classification."""
+    return {
+        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
+        'labels': torch.tensor([x['label'] for x in batch])
+    }
+
+def compute_metrics(eval_pred):
+    """Computes accuracy metric for evaluation."""
+    accuracy = evaluate.load("accuracy")
+    predictions, label_ids = eval_pred
+    predicted_labels = predictions.argmax(axis=1)
+    return accuracy.compute(predictions=predicted_labels, references=label_ids)
+
+# ==============================================================================

 class ModelTrainer:
     def __init__(self, config: ModelTrainerConfig):
         self.config = config
-        self.label2id = None
-        self.id2label = None

     def _prepare_data(self):
         logger.info("Preparing data...")

@@ -60,13 +90,11 @@ class ModelTrainer:
         file_names, labels = [], []
         data_path = Path(self.config.data_path)
         for file in tqdm(sorted(data_path.glob('*/*.*'))):
-            label =
+            label = file.parent.name
             labels.append(label_dict[label])
             file_names.append(str(file))

         df = pd.DataFrame.from_dict({"image": file_names, "label": labels})
-
-        # Random oversampling
         ros = RandomOverSampler(random_state=self.config.random_state)
         df_resampled, y_resampled = ros.fit_resample(df[['image']], df['label'])
         df = pd.concat([df_resampled, y_resampled], axis=1)

@@ -74,72 +102,54 @@ class ModelTrainer:
         dataset = Dataset.from_pandas(df).cast_column("image", Image())

         labels_list = sorted(list(set(labels)))
-
-
+        label2id = {label: i for i, label in enumerate(labels_list)}
+        id2label = {i: label for i, label in enumerate(labels_list)}

         ClassLabels = ClassLabel(num_classes=len(labels_list), names=labels_list)
         dataset = dataset.map(lambda x: {'label': ClassLabels.str2int(x['label'])}, batched=True)
         dataset = dataset.cast_column('label', ClassLabels)

-
+        split_dataset = dataset.train_test_split(test_size=self.config.test_split_size, shuffle=True, stratify_by_column="label")
+        return split_dataset, id2label, label2id

     def train(self):
         device = "cuda" if torch.cuda.is_available() else "cpu"
         logger.info(f"Using device: {device}")

-        split_dataset = self._prepare_data()
+        split_dataset, id2label, label2id = self._prepare_data()
         train_data = split_dataset['train']
         test_data = split_dataset['test']

         processor = EfficientFormerImageProcessor.from_pretrained(self.config.model_name)

-
-        size = self.config.image_size
-
-        normalize = Normalize(mean=image_mean, std=image_std)
+        # Define base transforms (without normalization)
         _train_transforms = Compose([
-            Resize((
+            Resize((self.config.image_size, self.config.image_size)),
             RandomRotation(15),
             RandomHorizontalFlip(0.5),
             ToTensor(),
-            normalize
         ])
         _val_transforms = Compose([
-            Resize((
+            Resize((self.config.image_size, self.config.image_size)),
             ToTensor(),
-            normalize
         ])

-
-
-
-
-        def val_transforms(examples):
-            examples['pixel_values'] = [_val_transforms(image.convert("RGB")) for image in examples['image']]
-            return examples
-
-        train_data.set_transform(train_transforms)
-        test_data.set_transform(val_transforms)
-
-        def collate_fn(examples):
-            pixel_values = torch.stack([example["pixel_values"] for example in examples])
-            labels = torch.tensor([example['label'] for example in examples])
-            return {"pixel_values": pixel_values, "labels": labels}
+        # Use functools.partial to create specialized versions of our top-level function.
+        # This is a pickle-safe way to pass extra arguments (processor, transforms).
+        train_transform_func = partial(apply_transforms, processor=processor, transform_pipeline=_train_transforms)
+        val_transform_func = partial(apply_transforms, processor=processor, transform_pipeline=_val_transforms)

+        train_data.set_transform(train_transform_func)
+        test_data.set_transform(val_transform_func)
+
         model = EfficientFormerForImageClassification.from_pretrained(
             self.config.model_name,
-            num_labels=len(
-            id2label=
-            label2id=
-            ignore_mismatched_sizes=True
+            num_labels=len(id2label),
+            id2label=id2label,
+            label2id=label2id,
+            ignore_mismatched_sizes=True
         ).to(device)

-        accuracy = evaluate.load("accuracy")
-        def compute_metrics(eval_pred):
-            predictions, label_ids = eval_pred
-            predicted_labels = predictions.argmax(axis=1)
-            return accuracy.compute(predictions=predicted_labels, references=label_ids)
-
         args = TrainingArguments(
             output_dir=self.config.root_dir,
             logging_dir=f'{self.config.root_dir}/logs',

@@ -154,6 +164,8 @@ class ModelTrainer:
             load_best_model_at_end=True,
             metric_for_best_model="accuracy",
             save_total_limit=1,
+            remove_unused_columns=False,
+            dataloader_num_workers=4,
             report_to="none"
         )

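Two of the new TrainingArguments deserve a note: remove_unused_columns=False keeps the raw image column available to the set_transform callables, and dataloader_num_workers=4 is what forces the pickling-safe refactor above, since DataLoader worker processes must pickle the transform callable and locally defined closures cannot be pickled. A standalone sketch of that failure mode (names here are illustrative, not from the commit):

    import pickle
    from functools import partial

    def scale(batch, factor):
        """Top-level function: picklable by reference to its module path."""
        return [x * factor for x in batch]

    good = partial(scale, factor=2)        # partial of a top-level function pickles fine
    pickle.dumps(good)

    def make_closure(factor):
        def scale_closure(batch):          # defined inside a function...
            return [x * factor for x in batch]
        return scale_closure

    bad = make_closure(2)
    try:
        pickle.dumps(bad)                  # ...so pickling fails in worker processes
    except (AttributeError, pickle.PicklingError) as e:
        print(f"closure is not picklable: {e}")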
src/cnnClassifier/config/configuration.py CHANGED

@@ -39,20 +39,22 @@ class ConfigurationManager:

     def get_model_trainer_config(self) -> ModelTrainerConfig:
         config = self.config.model_trainer
+        data_prep_config = self.config.data_preparation
         params = self.params
         create_directories([config.root_dir])

         model_trainer_config = ModelTrainerConfig(
-
-
-
-
-
-
-
-
-
-
-
+            root_dir=Path(config.root_dir),
+            data_path=Path(data_prep_config.data_path),
+            trained_model_path=Path(config.trained_model_path),
+            model_name=config.model_name,
+            image_size=int(params.IMAGE_SIZE),
+            learning_rate=float(params.LEARNING_RATE),  # <<< CORRECTED
+            batch_size=int(params.BATCH_SIZE),
+            num_train_epochs=int(params.NUM_TRAIN_EPOCHS),
+            weight_decay=float(params.WEIGHT_DECAY),  # <<< CORRECTED
+            warmup_steps=int(params.WARMUP_STEPS),
+            test_split_size=float(params.TEST_SPLIT_SIZE),  # <<< CORRECTED
+            random_state=int(params.RANDOM_STATE)
         )
         return model_trainer_config
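The int/float casts marked CORRECTED guard against YAML scalars arriving as strings; PyYAML, for one, parses 5e-5 (no decimal point) as a string rather than a float. Assuming params is an attribute-style mapping such as python-box's ConfigBox, which templates like this commonly use (an assumption; the loader is not shown in the diff), the casts behave like this:

    from box import ConfigBox  # assumed helper; pip install python-box

    # Simulate params.yaml values as a YAML loader might hand them over
    params = ConfigBox({"LEARNING_RATE": "5e-5", "BATCH_SIZE": 32})

    learning_rate = float(params.LEARNING_RATE)  # "5e-5" -> 5e-05, avoids a type error later in TrainingArguments
    batch_size = int(params.BATCH_SIZE)          # already an int; the cast just makes the type explicit
    print(learning_rate, batch_size)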
src/cnnClassifier/entity/config_entity.py CHANGED

@@ -17,6 +17,7 @@ class DataPreparationConfig:
 @dataclass(frozen=True)
 class ModelTrainerConfig:
     root_dir: Path
+    data_path: Path
     trained_model_path: Path
     model_name: str
     image_size: int
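Because ModelTrainerConfig is declared with @dataclass(frozen=True), a config object cannot be mutated after the ConfigurationManager builds it, so accidental reassignment fails fast. A small illustration with placeholder values:

    from dataclasses import dataclass, FrozenInstanceError
    from pathlib import Path

    @dataclass(frozen=True)
    class DemoConfig:            # stand-in mirroring ModelTrainerConfig's style
        root_dir: Path
        data_path: Path

    cfg = DemoConfig(root_dir=Path("artifacts/model_trainer"),
                     data_path=Path("artifacts/data_ingestion"))
    try:
        cfg.data_path = Path("elsewhere")     # frozen=True turns assignment into an error
    except FrozenInstanceError as err:
        print(f"immutable config: {err}")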