from transformers import AutoImageProcessor, AutoModelForImageClassification
import torch
import numpy as np
import cv2
from mtcnn import MTCNN
from collections import deque, Counter
from PIL import Image
LOCAL_MODEL_PATH = "sota_model"
class HFPredictor:
    def __init__(self, smoothing_window=10, confidence_threshold=0.3):
        print(f"[PREDICTOR INFO] Loading model from local path: {LOCAL_MODEL_PATH}...")
        self.processor = AutoImageProcessor.from_pretrained(LOCAL_MODEL_PATH)
        self.model = AutoModelForImageClassification.from_pretrained(LOCAL_MODEL_PATH)
        self.face_detector = MTCNN()
        self.classes = list(self.model.config.id2label.values())
        self.confidence_threshold = confidence_threshold
        # Sliding window of recent class indices, used for temporal smoothing.
        self.recent_predictions = deque(maxlen=smoothing_window)
        self.stable_prediction = "---"
        print("[PREDICTOR INFO] Predictor initialized successfully.")
    def process_frame(self, frame, smooth=False):
        """
        Process a single BGR frame. This one function is used for ALL
        predictions (live, image, and video) to ensure consistency.
        Pass smooth=True for live streams so the displayed label is
        smoothed over recent frames; leave it False for static images
        and video files to show the direct per-frame prediction.
        """
        if frame is None:
            return frame, {}
        annotated_frame = frame.copy()
        all_probabilities = {}
        # Both MTCNN and the HF image processor expect RGB input, while
        # OpenCV frames are BGR, so convert once up front.
        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        faces = self.face_detector.detect_faces(rgb_frame)
        for face in faces:
            x, y, width, height = face['box']
            x, y = max(0, x), max(0, y)  # MTCNN can return negative coordinates
            face_roi = rgb_frame[y:y+height, x:x+width]
            if face_roi.size > 0:
                pil_image = Image.fromarray(face_roi)
                inputs = self.processor(images=pil_image, return_tensors="pt")
                with torch.no_grad():
                    logits = self.model(**inputs).logits
                probs = torch.nn.functional.softmax(logits, dim=-1)
                predictions = probs[0].numpy()
                pred_index = int(np.argmax(predictions))
                confidence = predictions[pred_index]
                # Decide which label to draw. Live streams want temporally
                # smoothed labels; static images and video files want the
                # direct per-frame prediction, so the caller selects the
                # behaviour explicitly via the smooth flag.
                if smooth:
                    # Only confident predictions enter the smoothing window.
                    if confidence > self.confidence_threshold:
                        self.recent_predictions.append(pred_index)
                    if self.recent_predictions:
                        most_common_pred = Counter(self.recent_predictions).most_common(1)[0][0]
                        display_emotion = self.classes[most_common_pred]
                    else:
                        # Nothing confident seen yet: fall back to the direct prediction.
                        display_emotion = self.classes[pred_index]
                else:
                    display_emotion = self.classes[pred_index]
                text = f"{display_emotion} ({confidence*100:.1f}%)"
                GREEN = (0, 255, 0)
                BLACK = (0, 0, 0)
                FONT = cv2.FONT_HERSHEY_SIMPLEX
                # Draw a filled label background above the face, then the
                # label text and the bounding box itself.
                (text_width, text_height), baseline = cv2.getTextSize(text, FONT, 0.8, 2)
                cv2.rectangle(annotated_frame, (x, y - text_height - baseline - 10), (x + text_width + 10, y), GREEN, cv2.FILLED)
                cv2.putText(annotated_frame, text, (x + 5, y - 5), FONT, 0.8, BLACK, 2)
                cv2.rectangle(annotated_frame, (x, y), (x + width, y + height), GREEN, 3)
                # Note: with multiple faces, this keeps only the last face's scores.
                all_probabilities = {self.classes[i]: float(predictions[i]) for i in range(len(self.classes))}
        return annotated_frame, all_probabilities
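

# --- Example usage: a minimal sketch, assuming a webcam at index 0 and a  ---
# --- compatible model in "sota_model"; "example.jpg" is a hypothetical file ---
if __name__ == "__main__":
    predictor = HFPredictor()

    # Static image: direct, unsmoothed prediction.
    image = cv2.imread("example.jpg")
    if image is not None:
        annotated, probs = predictor.process_frame(image)
        print(probs)

    # Live stream: smooth=True smooths the displayed label across frames;
    # clear the history first so a previous session cannot leak in.
    predictor.reset_smoothing()
    cap = cv2.VideoCapture(0)
    while cap.isOpened():
        ok, frame = cap.read()
        if not ok:
            break
        annotated, _ = predictor.process_frame(frame, smooth=True)
        cv2.imshow("Emotion", annotated)
        if cv2.waitKey(1) & 0xFF == ord("q"):
            break
    cap.release()
    cv2.destroyAllWindows()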