ALYYAN committed
Commit 1f4e421 · Parent: eacd6a2

FEAT: Finalize code for Hugging Face deployment

Files changed (5):
  1. .gitignore +1 -0
  2. app.py +29 -33
  3. packages.txt +2 -0
  4. requirements.txt +7 -26
  5. src/cnnClassifier/pipeline/prediction.py +67 -137
.gitignore CHANGED
@@ -205,3 +205,4 @@ cython_debug/
 marimo/_static/
 marimo/_lsp/
 __marimo__/
+aws-key.pem
app.py CHANGED
@@ -9,53 +9,49 @@ import tempfile
 import time
 from streamlit_option_menu import option_menu
 
-# --- Page Config (Set once at the top) ---
+# --- Page Config ---
 st.set_page_config(page_title="Facial Analysis", page_icon="👤", layout="wide", initial_sidebar_state="expanded")
 
-# --- Backend Loading (Robust and Unchanged) ---
+# --- Path Setup & Model Loading ---
 try:
+    # This works for local development
     src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'src'))
     if src_path not in sys.path: sys.path.append(src_path)
    from cnnClassifier.pipeline.prediction import PredictionPipeline
 except ImportError:
-    st.error("FATAL: Prediction pipeline not found. Check project structure.")
-    st.stop()
+    # This is a fallback for Hugging Face Spaces
+    from src.cnnClassifier.pipeline.prediction import PredictionPipeline
+
+# --- TF Config (for MTCNN in Image/Video modes) ---
 try:
     gpus = tf.config.list_physical_devices('GPU')
     if gpus:
         for gpu in gpus: tf.config.experimental.set_memory_growth(gpu, True)
 except Exception: pass
+
 @st.cache_resource
 def load_pipeline():
     return PredictionPipeline()
+
 pipeline = load_pipeline()
 
-# --- Session State for Webcam Control ---
 if 'webcam_running' not in st.session_state: st.session_state.webcam_running = False
 def start_webcam(): st.session_state.webcam_running = True
 def stop_webcam(): st.session_state.webcam_running = False
 
-# --- Sidebar UI (Clean and Themed) ---
+# --- UI ---
 with st.sidebar:
     st.markdown("## ⚙️ Controls")
-    app_mode = option_menu(
-        menu_title=None,
-        options=["Image", "Video", "Live Feed"],
-        icons=["image", "film", "camera-video"],
-        menu_icon="cast",
-        default_index=0,
-    )
-    st.divider()
-    st.info("This app uses a multi-task EfficientNet model to predict age and gender.")
-
-# --- Main Page Content ---
-st.title(f"👤 Facial Demographics Analysis")
-st.markdown(f"### Mode: {app_mode}")
-st.divider()
+    app_mode = option_menu(None, ["Image", "Video", "Live Feed"],
+                           icons=['image', 'film', 'camera-video'], menu_icon="cast", default_index=0)
 
 if not pipeline:
     st.error("AI Pipeline failed to load. Please check the terminal for errors.")
 else:
+    st.title("👤 Facial Demographics Analysis")
+    st.header(f"Mode: {app_mode}")
+    st.divider()
+
     if app_mode == "Image":
         uploaded_file = st.file_uploader("Upload an image for analysis", type=["jpg", "jpeg", "png"])
         if uploaded_file:
@@ -63,8 +59,9 @@ else:
             col1, col2 = st.columns(2)
             with col1: st.image(image, caption='Original Image', use_column_width=True)
             with col2:
-                with st.spinner('🔬 Analyzing...'):
-                    annotated_image, predictions = pipeline.predict_image(np.array(image))
+                with st.spinner('🔬 Analyzing with high-quality detector...'):
+                    # --- THE FIX: Call the HQ method ---
+                    annotated_image, predictions = pipeline.predict_hq(np.array(image))
                 st.image(annotated_image, caption='Processed Image', use_column_width=True)
             if predictions:
                 with st.expander("View Details", expanded=True):
@@ -79,18 +76,17 @@ else:
             tfile.write(uploaded_file.read())
             cap = cv2.VideoCapture(tfile.name)
             frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-            st.info(f"Video has {frame_count} frames.")
-            if st.button("Start Video Processing", type="primary", use_container_width=True):
+            st.info(f"Video has {frame_count} frames. This will be slow but high-quality.")
+            if st.button("Process Video", type="primary", use_container_width=True):
                 progress_bar = st.progress(0, text="Initializing...")
                 out_tfile = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
                 h, w = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)), int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                 out = cv2.VideoWriter(out_tfile.name, cv2.VideoWriter_fourcc(*'mp4v'), cap.get(cv2.CAP_PROP_FPS), (w, h))
-                def frame_generator():
-                    for _ in range(frame_count):
-                        ret, frame = cap.read()
-                        if not ret: break
-                        yield cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-                for i, annotated_frame_rgb in enumerate(pipeline.process_video_stream(frame_generator())):
+                for i in range(frame_count):
+                    ret, frame = cap.read()
+                    if not ret: break
+                    # --- THE FIX: Call the HQ method ---
+                    annotated_frame_rgb, _ = pipeline.predict_hq(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                     out.write(cv2.cvtColor(annotated_frame_rgb, cv2.COLOR_RGB2BGR))
                     progress_bar.progress((i + 1) / frame_count, text=f"Processing Frame {i+1}/{frame_count}")
                 cap.release(), out.release()
@@ -100,15 +96,14 @@ else:
                 st.download_button("Download Processed Video", f, "output.mp4", "video/mp4", use_container_width=True)
 
     elif app_mode == "Live Feed":
+        st.info("Live feed uses a lightweight face detector for higher FPS.")
         col1, col2 = st.columns(2)
         with col1: st.button("Start Feed", on_click=start_webcam, use_container_width=True, type="primary")
         with col2: st.button("Stop Feed", on_click=stop_webcam, use_container_width=True)
-
         _, center_col, _ = st.columns([1, 2, 1])
         with center_col:
             FRAME_WINDOW = st.image([])
             fps_display = st.empty()
-
         if st.session_state.webcam_running:
             cap = cv2.VideoCapture(0)
             while st.session_state.webcam_running:
@@ -116,7 +111,8 @@ else:
                 ret, frame = cap.read()
                 if not ret: break
                 frame = cv2.flip(frame, 1)
-                annotated_frame = pipeline.process_live_frame(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+                # --- THE FIX: Call the LQ method ---
+                annotated_frame, _ = pipeline.predict_lq(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                 FRAME_WINDOW.image(annotated_frame, channels="RGB")
                 fps = 1.0 / (time.time() - start_time) if (time.time() - start_time) > 0 else 0
                 fps_display.markdown(f"<p style='text-align: center;'><b>FPS: {fps:.2f}</b></p>", unsafe_allow_html=True)
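
For context, a minimal sketch of how the two new entry points are called outside Streamlit (a sketch assuming the import fallback in the diff above works; "sample.jpg" is a hypothetical test file):

    import cv2
    import numpy as np
    from PIL import Image

    # Import path matches the Hugging Face Spaces fallback in the diff above.
    from src.cnnClassifier.pipeline.prediction import PredictionPipeline

    pipeline = PredictionPipeline()

    # Offline path: MTCNN-backed, slower but more accurate.
    image = np.array(Image.open("sample.jpg").convert("RGB"))  # hypothetical input file
    annotated, preds = pipeline.predict_hq(image)
    print(preds)  # e.g. [{'age': '20-29', 'gender': 'Male', 'box': (x, y, w, h)}]

    # Live path: Haar-cascade-backed, lower accuracy but much faster.
    cap = cv2.VideoCapture(0)
    ret, frame = cap.read()
    if ret:
        annotated, _ = pipeline.predict_lq(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    cap.release()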
packages.txt ADDED
@@ -0,0 +1,2 @@
+libgl1-mesa-glx
+libglib2.0-0
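
These are Debian system libraries that OpenCV needs at import time (libGL from libgl1-mesa-glx, GLib from libglib2.0-0); Spaces installs packages.txt entries via apt at build time. With opencv-python-headless the libGL dependency is usually unnecessary, but listing it is harmless.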
requirements.txt CHANGED
@@ -1,36 +1,17 @@
-# For PyTorch with CUDA 11.8 - MUST be installed with the extra index URL
---extra-index-url https://download.pytorch.org/whl/cu118
-torch==2.1.0+cu118
-torchvision==0.16.0+cu118
+torch==2.1.0
+torchvision==0.16.0
 torchaudio==2.1.0
-
-# Pin NumPy to a version compatible with Torch 2.1.0
-numpy>=1.23,<2.0
-
-# Hugging Face
+numpy<2.0
 transformers==4.36.2
 tokenizers==0.15.0
-datasets>=2.14.5
-evaluate
-accelerate>=0.25
-
-# MLOps and Utilities
-mlflow
-dvc[s3]  # Assuming you might use S3 with DVC for AWS
+safetensors
 python-box
 PyYAML
-ensure
 pandas
 scikit-learn
 Pillow
-tqdm
-imblearn
-seaborn
-# Frontend and Real-time Processing
 streamlit
-opencv-python
-mtcnn
-tensorflow==2.15.0
 streamlit-option-menu
-# AWS Deployment
-boto3
+opencv-python-headless
+mtcnn
+tensorflow==2.15.0
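
A quick sanity check, as a sketch assuming `pip install -r requirements.txt` has already run, that the trimmed pin set imports together and matches the versions pinned above:

    # Smoke test for the pinned environment (versions taken from requirements.txt above).
    import cv2
    import numpy
    import tensorflow
    import torch
    import transformers

    assert torch.__version__.startswith("2.1.0"), torch.__version__
    assert transformers.__version__ == "4.36.2", transformers.__version__
    assert tensorflow.__version__.startswith("2.15"), tensorflow.__version__
    assert int(numpy.__version__.split(".")[0]) < 2, numpy.__version__  # torch 2.1 predates NumPy 2

    # opencv-python-headless exposes the same cv2 API without GUI bindings,
    # which is all a server-side Space can use anyway.
    print("cv2:", cv2.__version__)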
src/cnnClassifier/pipeline/prediction.py CHANGED
@@ -1,110 +1,58 @@
 import torch
-import pandas as pd
 import numpy as np
 from PIL import Image
 from transformers import AutoImageProcessor
 import cv2
-from mtcnn import MTCNN
+from mtcnn import MTCNN  # For high-quality
 from pathlib import Path
 import sys
 import os
 from torchvision.transforms import Compose, Resize, ToTensor, Normalize
 from safetensors.torch import load_file as load_safetensors
-from collections import OrderedDict
-from scipy.spatial import distance as dist
 
 try:
     src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
     if src_path not in sys.path: sys.path.append(src_path)
     from components.multi_task_model_trainer import MultiTaskEfficientNet
     from utils.common import read_yaml
-except ImportError as e:
-    print(f"Could not import custom modules: {e}.")
-    sys.exit(1)
-
-class CentroidTracker:
-    def __init__(self, max_disappeared=20):
-        self.next_object_id = 0
-        self.objects = OrderedDict()
-        self.disappeared = OrderedDict()
-        self.max_disappeared = max_disappeared
-
-    def register(self, centroid, box):
-        self.objects[self.next_object_id] = {'centroid': centroid, 'box': box, 'labels': {}, 'ema_preds': {}}
-        self.disappeared[self.next_object_id] = 0
-        self.next_object_id += 1
-
-    def deregister(self, object_id):
-        del self.objects[object_id]
-        del self.disappeared[object_id]
-
-    def update(self, boxes):
-        if len(boxes) == 0:
-            for object_id in list(self.disappeared.keys()):
-                self.disappeared[object_id] += 1
-                if self.disappeared[object_id] > self.max_disappeared:
-                    self.deregister(object_id)
-            return self.objects
-
-        input_centroids = np.array([(x + w // 2, y + h // 2) for (x, y, w, h) in boxes])
-
-        if len(self.objects) == 0:
-            for i in range(len(input_centroids)):
-                self.register(input_centroids[i], boxes[i])
-        else:
-            object_ids = list(self.objects.keys())
-            object_centroids = np.array([v['centroid'] for v in self.objects.values()])
-            D = dist.cdist(object_centroids, input_centroids)
-            rows = D.min(axis=1).argsort()
-            cols = D.argmin(axis=1)[rows]
-            used_rows, used_cols = set(), set()
-            for row, col in zip(rows, cols):
-                if row in used_rows or col in used_cols: continue
-                object_id = object_ids[row]
-                self.objects[object_id]['centroid'] = input_centroids[col]
-                self.objects[object_id]['box'] = boxes[col]
-                self.disappeared[object_id] = 0
-                used_rows.add(row)
-                used_cols.add(col)
-
-            unused_rows = set(range(D.shape[0])).difference(used_rows)
-            unused_cols = set(range(D.shape[1])).difference(used_cols)
-
-            if D.shape[0] >= D.shape[1]:
-                for row in unused_rows:
-                    object_id = object_ids[row]
-                    self.disappeared[object_id] += 1
-                    if self.disappeared[object_id] > self.max_disappeared:
-                        self.deregister(object_id)
-            else:
-                for col in unused_cols:
-                    self.register(input_centroids[col], boxes[col])
-        return self.objects
+except ImportError:
+    # Fallback for Hugging Face Spaces
+    from src.cnnClassifier.components.multi_task_model_trainer import MultiTaskEfficientNet
+    from src.cnnClassifier.utils.common import read_yaml
 
 class PredictionPipeline:
-    def __init__(self, model_path: str = "artifacts/multi_task_model_trainer/checkpoint-26873"):
-        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+    def __init__(self, model_path: str = "model/checkpoint-26873"):
+        self.device = "cpu"  # Force CPU for deployment
         self.model_path = Path(model_path)
         self.base_model_name = "google/efficientnet-b2"
-        params = read_yaml(Path("params.yaml"))
+        self.params = read_yaml(Path("model/params.yaml"))
+
+        self.label_maps = {
+            'age_id2label': {'0': '0-2', '1': '3-9', '2': '10-19', '3': '20-29', '4': '30-39', '5': '40-49', '6': '50-59', '7': '60-69', '8': 'more than 70'},
+            'gender_id2label': {'0': 'Male', '1': 'Female'}
+        }
+
+        print("--- Initializing Prediction Pipeline ---")
         self.processor = AutoImageProcessor.from_pretrained(self.base_model_name)
-        self.transforms = Compose([Resize((params.IMAGE_SIZE, params.IMAGE_SIZE)), ToTensor(), Normalize(mean=self.processor.image_mean, std=self.processor.image_std)])
-        self.label_maps = self._load_label_maps()
+        self.transforms = Compose([Resize((self.params.IMAGE_SIZE, self.params.IMAGE_SIZE)), ToTensor(), Normalize(mean=self.processor.image_mean, std=self.processor.image_std)])
        self.model = self._load_model()
-        self.face_detector = MTCNN()
-        self.tracker = CentroidTracker()
-        print(f"--- Pipeline Initialized on device: {self.device} ---")
-
-    def _load_label_maps(self):
-        maps = {'age_id2label': {'0': '0-2', '1': '3-9', '2': '10-19', '3': '20-29', '4': '30-39', '5': '40-49', '6': '50-59', '7': '60-69', '8': 'more than 70'},
-                'gender_id2label': {'0': 'Male', '1': 'Female'}}
-        return maps
+
+        # --- THE FIX: LOAD BOTH DETECTORS ---
+        # High-quality detector for offline tasks
+        self.hq_face_detector = MTCNN()
+        # Lightweight detector for live feed
+        haar_cascade_path = cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
+        self.lq_face_detector = cv2.CascadeClassifier(haar_cascade_path)
+        # --- END FIX ---
+
+        print(f"--- Pipeline Initialized Successfully on device: {self.device} ---")
 
     def _load_model(self):
         num_age, num_gender, num_race = len(self.label_maps['age_id2label']), len(self.label_maps['gender_id2label']), 7
         model = MultiTaskEfficientNet(self.base_model_name, num_age, num_gender, num_race)
         weight_file = self.model_path / 'model.safetensors'
         if not weight_file.exists(): weight_file = self.model_path / 'pytorch_model.bin'
+        if not weight_file.exists(): raise FileNotFoundError(f"Weights not found in {self.model_path}")
         state_dict = load_safetensors(weight_file, device="cpu") if weight_file.suffix == ".safetensors" else torch.load(weight_file, map_location="cpu")
         model.load_state_dict(state_dict)
         model.to(self.device)
@@ -126,66 +74,48 @@ class PredictionPipeline:
         for i, line in enumerate(text_lines):
             y_text = y - total_height + (i * line_height) + 18
             cv2.putText(image, line, (x + 5, y_text), font, font_scale, text_color, font_thickness, cv2.LINE_AA)
-
-    def _predict_for_box(self, frame, box):
-        x, y, w, h = [int(c) for c in box]
-        face_img = frame[max(0,y):min(frame.shape[0],y+h), max(0,x):min(frame.shape[1],x+w)]
-        if face_img.size == 0: return None
-        pixel_values = self.transforms(Image.fromarray(face_img)).unsqueeze(0).to(self.device)
-        with torch.no_grad(): outputs = self.model(pixel_values=pixel_values)
-        return outputs
-
-    def predict_image(self, image_array):
+
+    def predict_hq(self, image_array: np.ndarray) -> (np.ndarray, list):
+        """High-quality prediction using MTCNN for images and videos."""
         annotated_image, predictions = image_array.copy(), []
-        face_results = self.face_detector.detect_faces(image_array)
+        face_results = self.hq_face_detector.detect_faces(image_array)
         if not face_results: return annotated_image, predictions
+
         for face in face_results:
-            if face['confidence'] < 0.9: continue
-            box = face['box']
-            outputs = self._predict_for_box(annotated_image, box)
-            if outputs:
-                age_label = self.label_maps['age_id2label'][str(outputs['age_logits'].argmax(1).item())]
-                gender_label = self.label_maps['gender_id2label'][str(outputs['gender_logits'].argmax(1).item())]
-                prediction_labels = {"age": age_label, "gender": gender_label}
-                predictions.append({**prediction_labels, 'box': box})
-                self._draw_predictions(annotated_image, box, prediction_labels)
+            if face['confidence'] < 0.95: continue
+            x, y, w, h = face['box']
+            face_img = image_array[max(0,y):min(image_array.shape[0],y+h), max(0,x):min(image_array.shape[1],x+w)]
+            if face_img.size == 0: continue
+            pil_face = Image.fromarray(face_img)
+            pixel_values = self.transforms(pil_face).unsqueeze(0).to(self.device)
+            with torch.no_grad(): outputs = self.model(pixel_values=pixel_values)
+            pred_id_age = str(outputs['age_logits'].argmax(1).item())
+            pred_id_gender = str(outputs['gender_logits'].argmax(1).item())
+            age_label = self.label_maps['age_id2label'].get(pred_id_age, "N/A")
+            gender_label = self.label_maps['gender_id2label'].get(pred_id_gender, "N/A")
+            prediction_labels = {"age": age_label, "gender": gender_label}
+            predictions.append({**prediction_labels, 'box': (x, y, w, h)})
+            self._draw_predictions(annotated_image, (x, y, w, h), prediction_labels)
         return annotated_image, predictions
 
-    def process_video_stream(self, frame_generator):
-        self.tracker = CentroidTracker()
-        for frame in frame_generator:
-            face_results = self.face_detector.detect_faces(frame)
-            boxes = [tuple(face['box']) for face in face_results if face['confidence'] > 0.9]
-            tracked_objects = self.tracker.update(boxes)
-
-            for obj_id, data in tracked_objects.items():
-                # Predict only for new tracks or tracks that have just been re-found
-                if 'labels' not in data or self.tracker.disappeared[obj_id] == 0:
-                    outputs = self._predict_for_box(frame, data['box'])
-                    if outputs:
-                        alpha = 0.3
-                        current_probs = {
-                            'age': outputs['age_logits'].softmax(1).cpu().numpy()[0],
-                            'gender': outputs['gender_logits'].softmax(1).cpu().numpy()[0]
-                        }
-                        # Apply EMA smoothing
-                        if not data.get('ema_preds'): data['ema_preds'] = current_probs
-                        else:
-                            for task in ['age', 'gender']:
-                                data['ema_preds'][task] = alpha * current_probs[task] + (1 - alpha) * data['ema_preds'][task]
-
-                # Always update the label from the latest smoothed probabilities
-                if data.get('ema_preds'):
-                    age_label = self.label_maps['age_id2label'][str(np.argmax(data['ema_preds']['age']))]
-                    gender_label = self.label_maps['gender_id2label'][str(np.argmax(data['ema_preds']['gender']))]
-                    data['labels'] = {"age": age_label, "gender": gender_label}
-
-            annotated_frame = frame.copy()
-            for obj_id, data in tracked_objects.items():
-                if 'labels' in data:
-                    self._draw_predictions(annotated_frame, data['box'], data['labels'])
-            yield annotated_frame
-
-    def process_live_frame(self, frame):
-        annotated_frame, _ = self.predict_image(frame)
-        return annotated_frame
+    def predict_lq(self, image_array: np.ndarray) -> (np.ndarray, list):
+        """Lightweight prediction using Haar Cascade for live feed."""
+        annotated_image, predictions = image_array.copy(), []
+        gray_image = cv2.cvtColor(image_array, cv2.COLOR_RGB2GRAY)
+        faces = self.lq_face_detector.detectMultiScale(gray_image, scaleFactor=1.1, minNeighbors=5, minSize=(60, 60))
+        if len(faces) == 0: return annotated_image, predictions
+
+        for (x, y, w, h) in faces:
+            face_img = image_array[y:y+h, x:x+w]
+            if face_img.size == 0: continue
+            pil_face = Image.fromarray(face_img)
+            pixel_values = self.transforms(pil_face).unsqueeze(0).to(self.device)
+            with torch.no_grad(): outputs = self.model(pixel_values=pixel_values)
+            pred_id_age = str(outputs['age_logits'].argmax(1).item())
+            pred_id_gender = str(outputs['gender_logits'].argmax(1).item())
+            age_label = self.label_maps['age_id2label'].get(pred_id_age, "N/A")
+            gender_label = self.label_maps['gender_id2label'].get(pred_id_gender, "N/A")
+            prediction_labels = {"age": age_label, "gender": gender_label}
+            predictions.append({**prediction_labels, 'box': (x, y, w, h)})
+            self._draw_predictions(annotated_image, (x, y, w, h), prediction_labels)
        return annotated_image, predictions
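
To make the HQ/LQ trade-off concrete, a rough timing sketch (a synthetic frame stands in for webcam input; absolute numbers depend entirely on the host):

    import time

    import numpy as np

    from src.cnnClassifier.pipeline.prediction import PredictionPipeline

    pipeline = PredictionPipeline()
    # Synthetic 640x480 RGB frame standing in for a webcam capture.
    frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)

    for name, fn in [("predict_hq (MTCNN)", pipeline.predict_hq),
                     ("predict_lq (Haar)", pipeline.predict_lq)]:
        start = time.perf_counter()
        for _ in range(10):
            fn(frame)
        per_frame_ms = (time.perf_counter() - start) / 10 * 1000
        print(f"{name}: {per_frame_ms:.1f} ms/frame")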