# Emotion prediction helper: detects faces with MTCNN and classifies each face
# crop with a locally stored Hugging Face image-classification model.
from transformers import AutoImageProcessor, AutoModelForImageClassification
import torch
import numpy as np
import cv2
from mtcnn import MTCNN
from collections import deque, Counter
from PIL import Image

# Location handed to `from_pretrained` for both the image processor and the
# classification model in `HFPredictor.__init__`.
# NOTE(review): presumably a directory relative to the working directory --
# confirm against how the application is launched.
LOCAL_MODEL_PATH = "sota_model"

class HFPredictor:
    """Detect faces in a frame and label each one with a predicted class.

    Faces are located with MTCNN, each face crop is classified with the
    image-classification model loaded from ``LOCAL_MODEL_PATH``, and the
    frame is annotated with a bounding box and a ``"<label> (<pct>%)"``
    caption.

    Attributes:
        classes: Class labels ordered by numeric class id, so that
            ``classes[i]`` names the probability at index ``i``.
        confidence_threshold: A prediction enters the smoothing window only
            when its probability is strictly greater than this value.
        recent_predictions: Bounded window of recent class indices used for
            majority-vote smoothing.
        stable_prediction: Initialised to ``"---"``; never modified by this
            class (kept for callers that may read it).
    """

    def __init__(self, smoothing_window=10, confidence_threshold=0.3):
        """Load the processor, the classifier and the face detector.

        Args:
            smoothing_window: Maximum number of recent class indices kept
                for majority-vote smoothing.
            confidence_threshold: Minimum (exclusive) probability a
                prediction needs to be added to the smoothing window.
        """
        print(f"[PREDICTOR INFO] Loading model from local path: {LOCAL_MODEL_PATH}...")
        self.processor = AutoImageProcessor.from_pretrained(LOCAL_MODEL_PATH)
        self.model = AutoModelForImageClassification.from_pretrained(LOCAL_MODEL_PATH)
        self.face_detector = MTCNN()
        # Order labels by numeric id instead of relying on the mapping's
        # insertion order, so classes[i] always lines up with logit i.
        id2label = self.model.config.id2label
        self.classes = [
            label
            for _, label in sorted(id2label.items(), key=lambda item: int(item[0]))
        ]
        self.confidence_threshold = confidence_threshold
        self.recent_predictions = deque(maxlen=smoothing_window)
        self.stable_prediction = "---"
        print("[PREDICTOR INFO] Predictor initialized successfully.")

    @staticmethod
    def _clip_box(box):
        """Clip a detector ``(x, y, width, height)`` box to non-negative origin.

        The far corner is computed from the raw box *before* the origin is
        clamped, so a box that starts outside the frame is clipped rather
        than shifted.

        Returns:
            ``(x1, y1, x2, y2)``, or ``None`` when nothing of the box remains
            at non-negative coordinates.
        """
        x, y, width, height = box
        x2, y2 = x + width, y + height
        x1, y1 = max(0, x), max(0, y)
        if x2 <= x1 or y2 <= y1:
            return None
        return x1, y1, x2, y2

    def _classify(self, face_roi):
        """Return the softmax probabilities (1-D array) for one face crop."""
        pil_image = Image.fromarray(face_roi)
        inputs = self.processor(images=pil_image, return_tensors="pt")
        with torch.no_grad():
            logits = self.model(**inputs).logits
        probs = torch.nn.functional.softmax(logits, dim=-1)
        return probs[0].numpy()

    def _select_label(self, pred_index, confidence, use_smoothing=None):
        """Choose the label to display for one prediction.

        Args:
            pred_index: Index of the most probable class.
            confidence: Probability of that class.
            use_smoothing: ``True`` to majority-vote over the recent window,
                ``False`` to return the direct prediction, ``None`` to keep
                the legacy heuristic (smooth only if the window already
                holds entries).

        Returns:
            The label string taken from ``self.classes``.
        """
        if use_smoothing is None:
            use_smoothing = len(self.recent_predictions) > 0
        if not use_smoothing:
            return self.classes[pred_index]

        if confidence > self.confidence_threshold:
            self.recent_predictions.append(pred_index)
        if not self.recent_predictions:
            # Nothing confident has been seen yet: fall back to the direct
            # prediction instead of indexing an empty most_common() result.
            return self.classes[pred_index]
        most_common_pred = Counter(self.recent_predictions).most_common(1)[0][0]
        return self.classes[most_common_pred]

    @staticmethod
    def _draw_annotation(canvas, corners, text):
        """Draw the face rectangle and a filled caption box onto ``canvas``."""
        x1, y1, x2, y2 = corners
        green = (0, 255, 0)
        black = (0, 0, 0)
        font = cv2.FONT_HERSHEY_SIMPLEX
        (text_width, text_height), baseline = cv2.getTextSize(text, font, 0.8, 2)

        cv2.rectangle(
            canvas,
            (x1, y1 - text_height - baseline - 10),
            (x1 + text_width + 10, y1),
            green,
            cv2.FILLED,
        )
        cv2.putText(canvas, text, (x1 + 5, y1 - 5), font, 0.8, black, 2)
        cv2.rectangle(canvas, (x1, y1), (x2, y2), green, 3)

    def process_frame(self, frame, use_smoothing=None):
        """Annotate every detected face in ``frame`` with its predicted label.

        Args:
            frame: Image array indexed as ``frame[row, col]``; ``None`` is
                passed straight back.
                NOTE(review): the face detector and the PIL conversion
                presumably expect RGB channel order -- confirm what callers
                pass (OpenCV captures are BGR).
            use_smoothing: ``True`` for live streams (majority vote over the
                recent-prediction window), ``False`` for still images, or
                ``None`` (default) for the legacy behaviour, which smooths
                only when the window already contains entries. Because this
                method is the only code in the class that fills the window,
                live callers should pass ``True``.
                NOTE: the window is shared by all faces in the frame.

        Returns:
            ``(annotated_frame, probabilities)`` where ``annotated_frame`` is
            an annotated copy of ``frame`` and ``probabilities`` maps each
            class label to its probability for the last face processed
            (empty when no face was classified).
        """
        if frame is None:
            return frame, {}

        annotated_frame = frame.copy()
        all_probabilities = {}

        for face in self.face_detector.detect_faces(frame):
            corners = self._clip_box(face['box'])
            if corners is None:
                continue
            x1, y1, x2, y2 = corners
            face_roi = frame[y1:y2, x1:x2]
            if face_roi.size == 0:
                continue

            predictions = self._classify(face_roi)
            pred_index = int(np.argmax(predictions))
            confidence = predictions[pred_index]

            display_emotion = self._select_label(pred_index, confidence, use_smoothing)
            text = f"{display_emotion} ({confidence*100:.1f}%)"
            self._draw_annotation(annotated_frame, corners, text)

            all_probabilities = {
                label: float(prob) for label, prob in zip(self.classes, predictions)
            }

        return annotated_frame, all_probabilities