ALYYAN committed on
Commit
55b3ce5
·
verified ·
1 Parent(s): f0293fd

Update src/EmotionRecognition/pipeline/hf_predictor.py

Browse files
src/EmotionRecognition/pipeline/hf_predictor.py CHANGED
@@ -20,26 +20,19 @@ class HFPredictor:
20
  self.stable_prediction = "---"
21
  print("[PREDICTOR INFO] Predictor initialized successfully.")
22
 
23
-
24
- def process_frame(self, frame):
25
  """
26
- Processes a single frame: flips it for a mirror effect, detects faces,
27
- predicts emotions, and draws professional annotations.
28
  """
29
- if frame is None: return frame, {}
30
-
31
- # --- MIRROR FIX: Flip the frame FIRST! ---
32
- # This ensures detection and drawing happen in the same coordinate space the user sees.
33
- frame = cv2.flip(frame, 1)
34
- annotated_frame = frame.copy()
35
- # --- END FIX ---
36
 
37
- all_probabilities = {}
38
  faces = self.face_detector.detect_faces(frame)
39
 
40
  for face in faces:
41
  x, y, width, height = face['box']
42
- x, y = max(0, x), max(0, y)
43
  face_roi = frame[y:y+height, x:x+width]
44
 
45
  if face_roi.size > 0:
@@ -51,27 +44,66 @@ class HFPredictor:
51
  probs = torch.nn.functional.softmax(logits, dim=-1)
52
  predictions = probs[0].numpy()
53
  pred_index = np.argmax(predictions)
54
-
55
- # Use temporal smoothing for the displayed label
56
  confidence = predictions[pred_index]
 
57
  if confidence > self.confidence_threshold:
58
  self.recent_predictions.append(pred_index)
59
- if self.recent_predictions:
60
- most_common_pred = Counter(self.recent_predictions).most_common(1)[0][0]
61
- self.stable_prediction = self.classes[most_common_pred]
62
 
63
- # --- PROFESSIONAL DRAWING LOGIC ---
64
- GREEN = (0, 255, 0)
65
- BLACK = (0, 0, 0)
66
- FONT = cv2.FONT_HERSHEY_SIMPLEX
67
- text = f"{self.stable_prediction} ({confidence*100:.1f}%)"
68
-
69
- (text_width, text_height), baseline = cv2.getTextSize(text, FONT, 0.8, 2)
70
-
71
- cv2.rectangle(annotated_frame, (x, y - text_height - baseline - 10), (x + text_width + 10, y), GREEN, cv2.FILLED)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  cv2.putText(annotated_frame, text, (x + 5, y - 5), FONT, 0.8, BLACK, 2)
73
  cv2.rectangle(annotated_frame, (x, y), (x+width, y+height), GREEN, 3)
74
-
75
- all_probabilities = {self.classes[i]: float(predictions[i]) for i in range(len(self.classes))}
76
-
77
- return annotated_frame, all_probabilities
 
20
  self.stable_prediction = "---"
21
  print("[PREDICTOR INFO] Predictor initialized successfully.")
22
 
23
+ def get_probabilities(self, frame):
 
24
  """
25
+ A lightweight function that takes a frame, runs inference,
26
+ updates the stable prediction, and returns ONLY the probability dictionary.
27
  """
28
+ if frame is None:
29
+ return {}
 
 
 
 
 
30
 
31
+ probabilities = {}
32
  faces = self.face_detector.detect_faces(frame)
33
 
34
  for face in faces:
35
  x, y, width, height = face['box']
 
36
  face_roi = frame[y:y+height, x:x+width]
37
 
38
  if face_roi.size > 0:
 
44
  probs = torch.nn.functional.softmax(logits, dim=-1)
45
  predictions = probs[0].numpy()
46
  pred_index = np.argmax(predictions)
 
 
47
  confidence = predictions[pred_index]
48
+
49
  if confidence > self.confidence_threshold:
50
  self.recent_predictions.append(pred_index)
 
 
 
51
 
52
+ probabilities = {self.classes[i]: float(predictions[i]) for i in range(len(self.classes))}
53
+
54
+ return probabilities
55
+
56
+ def annotate_frame(self, frame):
57
+ """
58
+ Takes a frame, detects faces, and returns the fully annotated version
59
+ using the latest stable prediction.
60
+ """
61
+ if frame is None: return None
62
+
63
+ annotated_frame = frame.copy()
64
+ faces = self.face_detector.detect_faces(frame)
65
+
66
+ # We use the 'stable_prediction' which is updated by the high-fps get_probabilities call
67
+ # This ensures the box text is smooth and consistent.
68
+ for face in faces:
69
+ x, y, width, height = face['box']
70
+ GREEN = (0, 255, 0)
71
+ BLACK = (0, 0, 0)
72
+ FONT = cv2.FONT_HERSHEY_SIMPLEX
73
+ text = self.stable_prediction # Use the smoothed prediction
74
+
75
+ (text_width, text_height), baseline = cv2.getTextSize(text, FONT, 0.8, 2)
76
+ cv2.rectangle(annotated_frame, (x, y - text_height - baseline - 10), (x + text_width + 10, y), GREEN, cv2.FILLED)
77
+ cv2.putText(annotated_frame, text, (x + 5, y - 5), FONT, 0.8, BLACK, 2)
78
+ cv2.rectangle(annotated_frame, (x, y), (x+width, y+height), GREEN, 3)
79
+
80
+ return annotated_frame
81
+
82
+ def process_frame_for_upload(self, frame):
83
+ """A simple, all-in-one function for static images and videos."""
84
+ if frame is None: return None, {}
85
+ annotated_frame = frame.copy()
86
+ probabilities = {}
87
+ faces = self.face_detector.detect_faces(frame)
88
+ for face in faces:
89
+ x, y, width, height = face['box']
90
+ face_roi = frame[y:y+height, x:x+width]
91
+ if face_roi.size > 0:
92
+ pil_image = Image.fromarray(face_roi)
93
+ inputs = self.processor(images=pil_image, return_tensors="pt")
94
+ with torch.no_grad():
95
+ logits = self.model(**inputs).logits
96
+ probs = torch.nn.functional.softmax(logits, dim=-1)
97
+ predictions = probs[0].numpy()
98
+ pred_index = np.argmax(predictions)
99
+ emotion = self.classes[pred_index]
100
+ confidence = predictions[pred_index]
101
+ text = f"{emotion} ({confidence*100:.1f}%)"
102
+ # (Drawing logic is duplicated here for simplicity)
103
+ GREEN = (0, 255, 0); BLACK = (0, 0, 0); FONT = cv2.FONT_HERSHEY_SIMPLEX
104
+ (tw, th), bl = cv2.getTextSize(text, FONT, 0.8, 2)
105
+ cv2.rectangle(annotated_frame, (x, y-th-bl-10), (x+tw+10, y), GREEN, cv2.FILLED)
106
  cv2.putText(annotated_frame, text, (x + 5, y - 5), FONT, 0.8, BLACK, 2)
107
  cv2.rectangle(annotated_frame, (x, y), (x+width, y+height), GREEN, 3)
108
+ probabilities = {self.classes[i]: float(predictions[i]) for i in range(len(self.classes))}
109
+ return annotated_frame, probabilities