# expression-recognition/src/streamlit_app.py
# Author: allantacuelwvsu
# Last commit: "added evaluations" (69535bd)
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import streamlit as st
from streamlit_webrtc import webrtc_streamer, VideoProcessorBase
import av
import cv2
import torch
import numpy as np
from torchvision import transforms
from utils.model_loader import load_model
st.title("Facial Expression Recognition")
st.caption("dataset: https://www.kaggle.com/datasets/jonathanoheix/face-expression-recognition-dataset")

# Load model
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_PATH = "data/models/expression_predictor_cnn.pth"
# Model output index -> display label.
# NOTE(review): this order must match the class order used at training time
# (presumably the dataset's alphabetical folder order) — confirm.
CLASSES = ['Angry', 'Disgust', 'Scared', 'Happy', 'Neutral', 'Sad', 'Surprised']


@st.cache_resource
def _load_expression_model(model_path, _device):
    """Load the expression model once and reuse it across Streamlit reruns.

    Streamlit re-executes this script on every rerun, so an uncached
    ``load_model`` call would re-read the weights each time.

    Args:
        model_path: path to the saved model weights.
        _device: torch device to load onto. The leading underscore tells
            Streamlit not to hash this argument when building the cache key.

    Returns:
        Whatever ``utils.model_loader.load_model`` returns for these arguments.
    """
    return load_model(model_path, _device)


model = _load_expression_model(MODEL_PATH, DEVICE)

# Face detection: OpenCV's bundled frontal-face Haar cascade.
face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
if face_cascade.empty():
    # A cascade that failed to load would otherwise raise on every video frame.
    st.error("Could not load the Haar cascade for face detection.")
    st.stop()

# Transform for inference: face crop -> 1-channel 48x48 tensor with values in [-1, 1]
# (ToTensor scales to [0, 1]; Normalize(0.5, 0.5) then maps to [-1, 1]).
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Grayscale(),
    transforms.Resize((48, 48)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Placeholder for the live/offline indicator; it is filled in after the streamer is created.
livestatus = st.empty()
class VideoProcessor(VideoProcessorBase):
    """Frame callback for ``webrtc_streamer`` that labels facial expressions.

    Each frame is converted to grayscale and faces are located with the
    module-level Haar cascade. Each face crop is classified by ``model``, and
    a green box plus a "Face N: <label>" caption is drawn onto the frame.
    """

    def recv(self, frame):
        """Annotate one incoming video frame.

        Args:
            frame: the incoming video frame (converted via ``to_ndarray``).

        Returns:
            A new ``av.VideoFrame`` in bgr24 format, with every detected face
            outlined and captioned with its predicted expression.
        """
        img = frame.to_ndarray(format="bgr24")
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        faces = face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5)

        for i, box in enumerate(faces):
            # Detections come back as numpy integers; use plain ints for the OpenCV drawing calls.
            x, y, w, h = (int(v) for v in box)
            face_crop = gray[y:y + h, x:x + w]
            face_tensor = transform(face_crop).unsqueeze(0).to(DEVICE)

            with torch.no_grad():
                outputs = model(face_tensor)
            probs = torch.nn.functional.softmax(outputs, dim=1)[0].cpu().numpy()
            label = CLASSES[int(np.argmax(probs))]

            # Draw face + label on video
            face_id = f"Face {i+1}"
            cv2.rectangle(img, (x, y), (x + w, y + h), (0, 255, 0), 1)
            # Keep the caption's baseline inside the frame when the face touches the top edge.
            text_y = max(y - 10, 10)
            cv2.putText(img, f"{face_id}: {label}", (x, text_y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)

        return av.VideoFrame.from_ndarray(img, format="bgr24")
# Start the browser <-> server WebRTC stream; each video frame is passed to
# VideoProcessor.recv and the annotated frame is sent back. Audio is not requested.
# NOTE(review): async_processing=True is understood to run recv off the main
# thread so slow inference does not stall the stream — confirm against the
# pinned streamlit-webrtc version.
ctx = webrtc_streamer(
    key="emotion-detect",
    video_processor_factory=VideoProcessor,
    media_stream_constraints={"video": True, "audio": False},
    async_processing=True,
)

# Show the stream state in the `livestatus` placeholder; this is evaluated each
# time the script runs.
if ctx.state.playing:
    livestatus.success("🟢 Live")
else:
    livestatus.error("🔴 Offline")