Spaces:
Running
Running
Update src/EmotionRecognition/pipeline/hf_predictor.py
Browse files
src/EmotionRecognition/pipeline/hf_predictor.py
CHANGED
|
@@ -20,26 +20,19 @@ class HFPredictor:
|
|
| 20 |
self.stable_prediction = "---"
|
| 21 |
print("[PREDICTOR INFO] Predictor initialized successfully.")
|
| 22 |
|
| 23 |
-
|
| 24 |
-
def process_frame(self, frame):
|
| 25 |
"""
|
| 26 |
-
|
| 27 |
-
|
| 28 |
"""
|
| 29 |
-
if frame is None:
|
| 30 |
-
|
| 31 |
-
# --- MIRROR FIX: Flip the frame FIRST! ---
|
| 32 |
-
# This ensures detection and drawing happen in the same coordinate space the user sees.
|
| 33 |
-
frame = cv2.flip(frame, 1)
|
| 34 |
-
annotated_frame = frame.copy()
|
| 35 |
-
# --- END FIX ---
|
| 36 |
|
| 37 |
-
|
| 38 |
faces = self.face_detector.detect_faces(frame)
|
| 39 |
|
| 40 |
for face in faces:
|
| 41 |
x, y, width, height = face['box']
|
| 42 |
-
x, y = max(0, x), max(0, y)
|
| 43 |
face_roi = frame[y:y+height, x:x+width]
|
| 44 |
|
| 45 |
if face_roi.size > 0:
|
|
@@ -51,27 +44,66 @@ class HFPredictor:
|
|
| 51 |
probs = torch.nn.functional.softmax(logits, dim=-1)
|
| 52 |
predictions = probs[0].numpy()
|
| 53 |
pred_index = np.argmax(predictions)
|
| 54 |
-
|
| 55 |
-
# Use temporal smoothing for the displayed label
|
| 56 |
confidence = predictions[pred_index]
|
|
|
|
| 57 |
if confidence > self.confidence_threshold:
|
| 58 |
self.recent_predictions.append(pred_index)
|
| 59 |
-
if self.recent_predictions:
|
| 60 |
-
most_common_pred = Counter(self.recent_predictions).most_common(1)[0][0]
|
| 61 |
-
self.stable_prediction = self.classes[most_common_pred]
|
| 62 |
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
cv2.putText(annotated_frame, text, (x + 5, y - 5), FONT, 0.8, BLACK, 2)
|
| 73 |
cv2.rectangle(annotated_frame, (x, y), (x+width, y+height), GREEN, 3)
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
return annotated_frame, all_probabilities
|
|
|
|
| 20 |
self.stable_prediction = "---"
|
| 21 |
print("[PREDICTOR INFO] Predictor initialized successfully.")
|
| 22 |
|
| 23 |
+
def get_probabilities(self, frame):
|
|
|
|
| 24 |
"""
|
| 25 |
+
A lightweight function that takes a frame, runs inference,
|
| 26 |
+
updates the stable prediction, and returns ONLY the probability dictionary.
|
| 27 |
"""
|
| 28 |
+
if frame is None:
|
| 29 |
+
return {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
+
probabilities = {}
|
| 32 |
faces = self.face_detector.detect_faces(frame)
|
| 33 |
|
| 34 |
for face in faces:
|
| 35 |
x, y, width, height = face['box']
|
|
|
|
| 36 |
face_roi = frame[y:y+height, x:x+width]
|
| 37 |
|
| 38 |
if face_roi.size > 0:
|
|
|
|
| 44 |
probs = torch.nn.functional.softmax(logits, dim=-1)
|
| 45 |
predictions = probs[0].numpy()
|
| 46 |
pred_index = np.argmax(predictions)
|
|
|
|
|
|
|
| 47 |
confidence = predictions[pred_index]
|
| 48 |
+
|
| 49 |
if confidence > self.confidence_threshold:
|
| 50 |
self.recent_predictions.append(pred_index)
|
|
|
|
|
|
|
|
|
|
| 51 |
|
| 52 |
+
probabilities = {self.classes[i]: float(predictions[i]) for i in range(len(self.classes))}
|
| 53 |
+
|
| 54 |
+
return probabilities
|
| 55 |
+
|
| 56 |
+
def annotate_frame(self, frame):
    """Return a copy of *frame* with a box and the current stable
    prediction label drawn over every detected face.

    The label text comes from ``self.stable_prediction``, which is
    refreshed elsewhere by the probability pass, so the on-screen text
    stays smooth even when inference runs at a lower rate than drawing.
    """
    if frame is None:
        return None

    canvas = frame.copy()
    detections = self.face_detector.detect_faces(frame)

    # Drawing parameters and the label are loop-invariant; bind them once.
    green, black = (0, 255, 0), (0, 0, 0)
    font = cv2.FONT_HERSHEY_SIMPLEX
    label = self.stable_prediction  # the smoothed prediction, not a per-frame one

    for detection in detections:
        left, top, box_w, box_h = detection['box']
        (label_w, label_h), base = cv2.getTextSize(label, font, 0.8, 2)
        # Filled backdrop sits just above the face box, then the text, then the box.
        cv2.rectangle(canvas, (left, top - label_h - base - 10),
                      (left + label_w + 10, top), green, cv2.FILLED)
        cv2.putText(canvas, label, (left + 5, top - 5), font, 0.8, black, 2)
        cv2.rectangle(canvas, (left, top), (left + box_w, top + box_h), green, 3)

    return canvas
|
| 81 |
+
|
| 82 |
+
def process_frame_for_upload(self, frame):
    """All-in-one pass for static images and uploaded videos: detect faces,
    classify each one, and draw an annotated box with the top label.

    Args:
        frame: image array (indexable as frame[y:y+h, x:x+w]) or None.

    Returns:
        Tuple ``(annotated_frame, probabilities)`` where ``probabilities``
        maps each class name to its float softmax score for the last face
        processed. Returns ``(None, {})`` for a None frame and
        ``(frame.copy(), {})`` when no usable face is found.
    """
    if frame is None:
        return None, {}
    annotated_frame = frame.copy()
    probabilities = {}
    faces = self.face_detector.detect_faces(frame)
    for face in faces:
        x, y, width, height = face['box']
        # The detector can report a slightly negative box origin; clamp it
        # so the slice below does not wrap around to the far edge of the
        # image and yield a wrong (or empty) ROI.
        x, y = max(0, x), max(0, y)
        face_roi = frame[y:y+height, x:x+width]
        if face_roi.size > 0:
            # NOTE(review): Image.fromarray treats the array as RGB; if this
            # frame comes straight from OpenCV it is BGR — confirm upstream.
            pil_image = Image.fromarray(face_roi)
            inputs = self.processor(images=pil_image, return_tensors="pt")
            with torch.no_grad():
                logits = self.model(**inputs).logits
            probs = torch.nn.functional.softmax(logits, dim=-1)
            predictions = probs[0].numpy()
            pred_index = np.argmax(predictions)
            emotion = self.classes[pred_index]
            confidence = predictions[pred_index]
            text = f"{emotion} ({confidence*100:.1f}%)"
            # (Drawing logic is duplicated here for simplicity)
            GREEN = (0, 255, 0); BLACK = (0, 0, 0); FONT = cv2.FONT_HERSHEY_SIMPLEX
            (tw, th), bl = cv2.getTextSize(text, FONT, 0.8, 2)
            cv2.rectangle(annotated_frame, (x, y-th-bl-10), (x+tw+10, y), GREEN, cv2.FILLED)
            cv2.putText(annotated_frame, text, (x + 5, y - 5), FONT, 0.8, BLACK, 2)
            cv2.rectangle(annotated_frame, (x, y), (x+width, y+height), GREEN, 3)
            # Overwritten per face: only the last face's distribution is kept.
            probabilities = {self.classes[i]: float(predictions[i]) for i in range(len(self.classes))}
    return annotated_frame, probabilities
|
|
|
|
|
|