"""Streamlit app that labels facial expressions on a live WebRTC video stream.

Each incoming frame is converted to grayscale, faces are located with an
OpenCV Haar cascade, every face crop is classified by a CNN loaded from
``MODEL_PATH``, and the frame is returned with a box and label per face.
"""

import sys
import os

# Add the parent of this file's directory to the import path so that
# ``utils.model_loader`` resolves (presumably the project root).
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

import streamlit as st
from streamlit_webrtc import webrtc_streamer, VideoProcessorBase
import av
import cv2
import torch
import numpy as np
from torchvision import transforms
from utils.model_loader import load_model

st.title("Facial Expression Recognition")
st.caption("dataset: https://www.kaggle.com/datasets/jonathanoheix/face-expression-recognition-dataset")

# Load model
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_PATH = "data/models/expression_predictor_cnn.pth"
# Index i of the model output is reported as CLASSES[i].
CLASSES = ['Angry', 'Disgust', 'Scared', 'Happy', 'Neutral', 'Sad', 'Surprised']

# NOTE(review): assumes load_model returns the network already in eval mode
# and placed on DEVICE -- confirm in utils/model_loader.py.
model = load_model(MODEL_PATH, DEVICE)

# Face detection
face_cascade = cv2.CascadeClassifier(
    cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
)

# Transform for inference: 48x48 single-channel tensor, normalised with
# mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Grayscale(),
    transforms.Resize((48, 48)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Overlay styling shared by the box and the label.
_BOX_COLOR = (0, 255, 0)
_LABEL_FONT = cv2.FONT_HERSHEY_SIMPLEX
_LABEL_SCALE = 0.5
# Smallest text baseline that keeps the label inside the top of the frame.
_MIN_LABEL_Y = 15

# Video processor
livestatus = st.empty()


def _classify_face(face_crop: np.ndarray) -> str:
    """Return the CLASSES entry with the highest probability for one face crop."""
    face_tensor = transform(face_crop).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        outputs = model(face_tensor)
        probs = torch.nn.functional.softmax(outputs, dim=1)[0].cpu().numpy()
    return CLASSES[int(np.argmax(probs))]


class VideoProcessor(VideoProcessorBase):
    """Frame callback that annotates every detected face with its expression."""

    def recv(self, frame: av.VideoFrame) -> av.VideoFrame:
        """Detect faces in ``frame``, classify each, and return the annotated frame.

        Args:
            frame: Incoming video frame from the WebRTC stream.

        Returns:
            A new BGR frame with a rectangle and a "Face N: <label>" caption
            drawn for every face found by the Haar cascade.
        """
        img = frame.to_ndarray(format="bgr24")
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        faces = face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5)

        for i, (x, y, w, h) in enumerate(faces):
            label = _classify_face(gray[y:y+h, x:x+w])

            # Draw face + label on video
            face_id = f"Face {i+1}"
            cv2.rectangle(img, (x, y), (x+w, y+h), _BOX_COLOR, 1)
            # Clamp the baseline so the caption is not drawn above the frame
            # when the face touches the top edge.
            label_y = max(y - 10, _MIN_LABEL_Y)
            cv2.putText(
                img,
                f"{face_id}: {label}",
                (x, label_y),
                _LABEL_FONT,
                _LABEL_SCALE,
                _BOX_COLOR,
                1,
            )

        return av.VideoFrame.from_ndarray(img, format="bgr24")


ctx = webrtc_streamer(
    key="emotion-detect",
    video_processor_factory=VideoProcessor,
    media_stream_constraints={"video": True, "audio": False},
    async_processing=True,
)

# Reflect the stream state in the placeholder created above the player.
if ctx.state.playing:
    livestatus.success("🟢 Live")
else:
    livestatus.error("🔴 Offline")