from transformers import AutoImageProcessor, AutoModelForImageClassification
import torch
import numpy as np
import cv2
from mtcnn import MTCNN
from collections import deque, Counter
from PIL import Image

LOCAL_MODEL_PATH = "sota_model"


class HFPredictor:
    def __init__(self, smoothing_window=10, confidence_threshold=0.3):
        print(f"[PREDICTOR INFO] Loading model from local path: {LOCAL_MODEL_PATH}...")
        self.processor = AutoImageProcessor.from_pretrained(LOCAL_MODEL_PATH)
        self.model = AutoModelForImageClassification.from_pretrained(LOCAL_MODEL_PATH)
        self.face_detector = MTCNN()
        self.classes = list(self.model.config.id2label.values())
        self.confidence_threshold = confidence_threshold
        self.recent_predictions = deque(maxlen=smoothing_window)
        self.stable_prediction = "---"
        print("[PREDICTOR INFO] Predictor initialized successfully.")

    def process_frame(self, frame):
        """
        Process a single frame. This method handles all prediction paths
        (live, image, and video) so that results stay consistent.
        """
        if frame is None:
            return frame, {}
        annotated_frame = frame.copy()
        all_probabilities = {}
        faces = self.face_detector.detect_faces(frame)
        for face in faces:
            x, y, width, height = face['box']
            # MTCNN can return negative box coordinates; clamp them to the frame.
            x, y = max(0, x), max(0, y)
            face_roi = frame[y:y+height, x:x+width]
            if face_roi.size > 0:
                pil_image = Image.fromarray(face_roi)
                inputs = self.processor(images=pil_image, return_tensors="pt")
                with torch.no_grad():
                    logits = self.model(**inputs).logits
                probs = torch.nn.functional.softmax(logits, dim=-1)
                predictions = probs[0].numpy()
                pred_index = np.argmax(predictions)
                confidence = predictions[pred_index]
                # Choose which label to draw on the bounding box. Live feeds
                # use temporally smoothed predictions; static images and video
                # frames use the direct per-frame prediction. A non-empty deque
                # signals a live context.
                if len(self.recent_predictions) > 0:
                    # Live stream: smooth over the recent-prediction window.
                    if confidence > self.confidence_threshold:
                        self.recent_predictions.append(pred_index)
                    most_common_pred = Counter(self.recent_predictions).most_common(1)[0][0]
                    display_emotion = self.classes[most_common_pred]
                else:
                    # Static image/video frame: use the direct prediction.
                    # This branch leaves the deque untouched, so no reset is
                    # needed before the next live session.
                    display_emotion = self.classes[pred_index]
                text = f"{display_emotion} ({confidence*100:.1f}%)"
                GREEN = (0, 255, 0)
                BLACK = (0, 0, 0)
                FONT = cv2.FONT_HERSHEY_SIMPLEX
                (text_width, text_height), baseline = cv2.getTextSize(text, FONT, 0.8, 2)
                cv2.rectangle(annotated_frame, (x, y - text_height - baseline - 10), (x + text_width + 10, y), GREEN, cv2.FILLED)
                cv2.putText(annotated_frame, text, (x + 5, y - 5), FONT, 0.8, BLACK, 2)
                cv2.rectangle(annotated_frame, (x, y), (x + width, y + height), GREEN, 3)
                # Note: if several faces are detected, only the probabilities
                # of the last face processed are returned.
                all_probabilities = {self.classes[i]: float(predictions[i]) for i in range(len(self.classes))}
        return annotated_frame, all_probabilities
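

# Example usage: a minimal sketch, not part of the original module. It assumes
# the "sota_model" directory is present locally and uses a hypothetical
# "test.jpg" path. process_frame does no color conversion itself, and both
# MTCNN and PIL treat 3-channel arrays as RGB, while cv2.imread returns BGR;
# under that assumption the frame is converted to RGB before prediction and
# back to BGR for display.
if __name__ == "__main__":
    predictor = HFPredictor()
    bgr = cv2.imread("test.jpg")  # hypothetical input image
    if bgr is None:
        raise SystemExit("Could not read test.jpg")
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    annotated, probabilities = predictor.process_frame(rgb)
    print(probabilities)
    cv2.imshow("prediction", cv2.cvtColor(annotated, cv2.COLOR_RGB2BGR))
    cv2.waitKey(0)
    cv2.destroyAllWindows()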