ALYYAN committed on
Commit
9e440e5
·
verified ·
1 Parent(s): e547d01

Update src/cnnClassifier/pipeline/prediction.py

Browse files
src/cnnClassifier/pipeline/prediction.py CHANGED
@@ -4,12 +4,13 @@ from PIL import Image
4
  from transformers import AutoImageProcessor
5
  import cv2
6
  from huggingface_hub import hf_hub_download
7
- from mtcnn import MTCNN # For high-quality
8
  from pathlib import Path
9
  import sys
10
  import os
11
  from torchvision.transforms import Compose, Resize, ToTensor, Normalize
12
  from safetensors.torch import load_file as load_safetensors
 
13
 
14
  try:
15
  src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
@@ -17,7 +18,6 @@ try:
17
  from components.multi_task_model_trainer import MultiTaskEfficientNet
18
  from utils.common import read_yaml
19
  except ImportError:
20
- # Fallback for Hugging Face Spaces
21
  from src.cnnClassifier.components.multi_task_model_trainer import MultiTaskEfficientNet
22
  from src.cnnClassifier.utils.common import read_yaml
23
 
@@ -25,65 +25,45 @@ class PredictionPipeline:
25
  def __init__(self, repo_id: str = "ALYYAN/Facial-Age-Det"):
26
  self.device = "cpu"
27
  self.repo_id = repo_id
28
-
29
  print("--- Initializing Prediction Pipeline by downloading artifacts from Hub ---")
30
-
31
- # Define cache dir (matches Dockerfile ENV)
32
  cache_dir = os.getenv("HF_HOME", "/app/hf_cache")
33
-
34
- # Download individual files → return full path
35
- self.model_file = hf_hub_download(
36
- repo_id=self.repo_id,
37
- filename="checkpoint-26873/model.safetensors",
38
- cache_dir=cache_dir
39
- )
40
- self.params_path = hf_hub_download(
41
- repo_id=self.repo_id,
42
- filename="params.yaml",
43
- cache_dir=cache_dir
44
- )
45
- self.data_csv_path = hf_hub_download(
46
- repo_id=self.repo_id,
47
- filename="fairface_cleaned.csv",
48
- cache_dir=cache_dir
49
- )
50
-
51
  self.base_model_name = "google/efficientnet-b2"
52
  self.params = read_yaml(Path(self.params_path))
 
 
 
 
53
 
54
- self.label_maps = self._load_label_maps()
55
  self.processor = AutoImageProcessor.from_pretrained(self.base_model_name)
56
- self.transforms = Compose([
57
- Resize((self.params.IMAGE_SIZE, self.params.IMAGE_SIZE)),
58
- ToTensor(),
59
- Normalize(mean=self.processor.image_mean, std=self.processor.image_std)
60
- ])
61
  self.model = self._load_model()
62
-
63
- # Face detector
64
  haar_cascade_path = cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
65
  self.lq_face_detector = cv2.CascadeClassifier(haar_cascade_path)
66
  self.hq_face_detector = MTCNN()
67
-
68
  print(f"--- Pipeline Initialized Successfully on device: {self.device} ---")
69
 
70
-
 
 
 
 
 
 
 
 
 
 
 
 
71
  def _load_model(self):
72
- num_age = len(self.label_maps['age_id2label'])
73
- num_gender = len(self.label_maps['gender_id2label'])
74
- num_race = 7
75
-
76
  model = MultiTaskEfficientNet(self.base_model_name, num_age, num_gender, num_race)
77
-
78
  weight_file = Path(self.model_file)
79
- if not weight_file.exists():
80
- raise FileNotFoundError(f"Weights not found: {weight_file}")
81
-
82
- if weight_file.suffix == ".safetensors":
83
- state_dict = load_safetensors(weight_file, device="cpu")
84
- else:
85
- state_dict = torch.load(weight_file, map_location="cpu")
86
-
87
  model.load_state_dict(state_dict)
88
  model.to(self.device)
89
  model.eval()
@@ -106,11 +86,9 @@ class PredictionPipeline:
106
  cv2.putText(image, line, (x + 5, y_text), font, font_scale, text_color, font_thickness, cv2.LINE_AA)
107
 
108
  def predict_hq(self, image_array: np.ndarray) -> (np.ndarray, list):
109
- """High-quality prediction using MTCNN for images and videos."""
110
  annotated_image, predictions = image_array.copy(), []
111
  face_results = self.hq_face_detector.detect_faces(image_array)
112
  if not face_results: return annotated_image, predictions
113
-
114
  for face in face_results:
115
  if face['confidence'] < 0.95: continue
116
  x, y, w, h = face['box']
@@ -129,12 +107,10 @@ class PredictionPipeline:
129
  return annotated_image, predictions
130
 
131
  def predict_lq(self, image_array: np.ndarray) -> (np.ndarray, list):
132
- """Lightweight prediction using Haar Cascade for live feed."""
133
  annotated_image, predictions = image_array.copy(), []
134
  gray_image = cv2.cvtColor(image_array, cv2.COLOR_RGB2GRAY)
135
  faces = self.lq_face_detector.detectMultiScale(gray_image, scaleFactor=1.1, minNeighbors=5, minSize=(60, 60))
136
  if len(faces) == 0: return annotated_image, predictions
137
-
138
  for (x, y, w, h) in faces:
139
  face_img = image_array[y:y+h, x:x+w]
140
  if face_img.size == 0: continue
 
4
  from transformers import AutoImageProcessor
5
  import cv2
6
  from huggingface_hub import hf_hub_download
7
+ from mtcnn import MTCNN
8
  from pathlib import Path
9
  import sys
10
  import os
11
  from torchvision.transforms import Compose, Resize, ToTensor, Normalize
12
  from safetensors.torch import load_file as load_safetensors
13
+ import pandas as pd
14
 
15
  try:
16
  src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
 
18
  from components.multi_task_model_trainer import MultiTaskEfficientNet
19
  from utils.common import read_yaml
20
  except ImportError:
 
21
  from src.cnnClassifier.components.multi_task_model_trainer import MultiTaskEfficientNet
22
  from src.cnnClassifier.utils.common import read_yaml
23
 
 
25
def __init__(self, repo_id: str = "ALYYAN/Facial-Age-Det"):
    """Set up the prediction pipeline from artifacts hosted on the Hugging Face Hub.

    Downloads the model checkpoint, training params, and the cleaned FairFace
    CSV, then builds the preprocessing transforms, the multi-task model, and
    two face detectors (fast Haar cascade + higher-quality MTCNN).

    Args:
        repo_id: Hub repository that holds the checkpoint, ``params.yaml``
            and ``fairface_cleaned.csv``.
    """
    self.device = "cpu"
    self.repo_id = repo_id
    print("--- Initializing Prediction Pipeline by downloading artifacts from Hub ---")

    # Cache directory; the default mirrors the container's HF_HOME setting.
    cache_dir = os.getenv("HF_HOME", "/app/hf_cache")

    # Each download returns the full local path to the cached file.
    self.model_file = hf_hub_download(repo_id=self.repo_id, filename="checkpoint-26873/model.safetensors", cache_dir=cache_dir)
    self.params_path = hf_hub_download(repo_id=self.repo_id, filename="params.yaml", cache_dir=cache_dir)
    self.data_csv_path = hf_hub_download(repo_id=self.repo_id, filename="fairface_cleaned.csv", cache_dir=cache_dir)

    self.base_model_name = "google/efficientnet-b2"
    self.params = read_yaml(Path(self.params_path))

    # Label maps are regenerated from the dataset CSV rather than shipped as
    # a separate artifact; must run before _load_model(), which reads them.
    self.label_maps = self._load_label_maps_from_csv()

    self.processor = AutoImageProcessor.from_pretrained(self.base_model_name)
    self.transforms = Compose([
        Resize((self.params.IMAGE_SIZE, self.params.IMAGE_SIZE)),
        ToTensor(),
        Normalize(mean=self.processor.image_mean, std=self.processor.image_std),
    ])
    self.model = self._load_model()

    # Two detectors: Haar cascade for the low-latency live-feed path,
    # MTCNN for the higher-quality image/video path.
    haar_cascade_path = cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
    self.lq_face_detector = cv2.CascadeClassifier(haar_cascade_path)
    self.hq_face_detector = MTCNN()
    print(f"--- Pipeline Initialized Successfully on device: {self.device} ---")
47
 
48
+ # --- THE MISSING METHOD ---
49
+ def _load_label_maps_from_csv(self):
50
+ print(f"Generating label maps from downloaded CSV: {self.data_csv_path}")
51
+ df = pd.read_csv(self.data_csv_path)
52
+ label_maps = {}
53
+ tasks = {'age': lambda x: int(str(x).split('-')[0]), 'gender': None}
54
+ for task, sort_key in tasks.items():
55
+ labels_str = [str(label) for label in df[task].unique()]
56
+ sorted_labels = sorted(labels_str, key=sort_key)
57
+ label_maps[f'{task}_id2label'] = {str(i): label for i, label in enumerate(sorted_labels)}
58
+ return label_maps
59
+ # --- END MISSING METHOD ---
60
+
61
  def _load_model(self):
62
+ num_age, num_gender, num_race = len(self.label_maps['age_id2label']), len(self.label_maps['gender_id2label']), 7
 
 
 
63
  model = MultiTaskEfficientNet(self.base_model_name, num_age, num_gender, num_race)
 
64
  weight_file = Path(self.model_file)
65
+ if not weight_file.exists(): raise FileNotFoundError(f"Weights not found: {weight_file}")
66
+ state_dict = load_safetensors(weight_file, device="cpu") if weight_file.suffix == ".safetensors" else torch.load(weight_file, map_location="cpu")
 
 
 
 
 
 
67
  model.load_state_dict(state_dict)
68
  model.to(self.device)
69
  model.eval()
 
86
  cv2.putText(image, line, (x + 5, y_text), font, font_scale, text_color, font_thickness, cv2.LINE_AA)
87
 
88
  def predict_hq(self, image_array: np.ndarray) -> (np.ndarray, list):
 
89
  annotated_image, predictions = image_array.copy(), []
90
  face_results = self.hq_face_detector.detect_faces(image_array)
91
  if not face_results: return annotated_image, predictions
 
92
  for face in face_results:
93
  if face['confidence'] < 0.95: continue
94
  x, y, w, h = face['box']
 
107
  return annotated_image, predictions
108
 
109
  def predict_lq(self, image_array: np.ndarray) -> (np.ndarray, list):
 
110
  annotated_image, predictions = image_array.copy(), []
111
  gray_image = cv2.cvtColor(image_array, cv2.COLOR_RGB2GRAY)
112
  faces = self.lq_face_detector.detectMultiScale(gray_image, scaleFactor=1.1, minNeighbors=5, minSize=(60, 60))
113
  if len(faces) == 0: return annotated_image, predictions
 
114
  for (x, y, w, h) in faces:
115
  face_img = image_array[y:y+h, x:x+w]
116
  if face_img.size == 0: continue