import os
import sys

# Make the project root (one level up from this script) importable so that
# the utils package resolves.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

import av
import cv2
import numpy as np
import streamlit as st
import torch
from streamlit_webrtc import VideoProcessorBase, webrtc_streamer
from torchvision import transforms

from utils.model_loader import load_model
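
# Streamlit page: real-time facial expression recognition on a webcam feed.
# Frames arrive over WebRTC, faces are located with a Haar cascade, and each
# face crop is classified by the CNN loaded below.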
|
|
st.title("Facial Expression Recognition")
st.caption("Dataset: https://www.kaggle.com/datasets/jonathanoheix/face-expression-recognition-dataset")
|
|
# Inference device, trained weights, and class labels in the order the
# network's output units are indexed.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_PATH = "data/models/expression_predictor_cnn.pth"
CLASSES = ['Angry', 'Disgust', 'Scared', 'Happy', 'Neutral', 'Sad', 'Surprised']

model = load_model(MODEL_PATH, DEVICE)
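# Note: load_model is the project helper in utils/model_loader.py; it is
# assumed here to return the CNN with weights loaded, moved to DEVICE, and
# switched to eval mode. If it does not call model.eval(), dropout and
# batch-norm layers would still behave as in training during inference.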
|
|
# OpenCV's bundled Haar cascade for frontal face detection.
face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
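# detectMultiScale parameters used below: scaleFactor=1.1 shrinks the image
# by 10% per pyramid level, and minNeighbors=5 requires five overlapping
# candidate boxes before a region counts as a face; higher values give fewer
# false positives but can miss small or turned faces.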
|
|
# Turn a grayscale face crop into the 1x48x48 tensor the CNN expects,
# normalized from [0, 1] to [-1, 1].
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Grayscale(),
    transforms.Resize((48, 48)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
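# Quick shape check (assumes a uint8 input, as produced by cv2.cvtColor):
#   transform(np.zeros((64, 64), dtype=np.uint8)).shape -> torch.Size([1, 48, 48])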
|
|
# Placeholder for the live/offline indicator, filled in after the stream starts.
livestatus = st.empty()


# Per-frame processor: detect faces, classify each crop, annotate the frame.
class VideoProcessor(VideoProcessorBase):
    def recv(self, frame):
        img = frame.to_ndarray(format="bgr24")
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        faces = face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5)

        for i, (x, y, w, h) in enumerate(faces):
            # Classify the grayscale crop of each detected face.
            face_crop = gray[y:y+h, x:x+w]
            face_tensor = transform(face_crop).unsqueeze(0).to(DEVICE)

            with torch.no_grad():
                outputs = model(face_tensor)
                probs = torch.nn.functional.softmax(outputs, dim=1)[0].cpu().numpy()

            top_idx = np.argmax(probs)
            label = CLASSES[top_idx]

            # Draw the bounding box and predicted label onto the frame.
            face_id = f"Face {i+1}"
            cv2.rectangle(img, (x, y), (x+w, y+h), (0, 255, 0), 1)
            cv2.putText(img, f"{face_id}: {label}", (x, y - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)

        return av.VideoFrame.from_ndarray(img, format="bgr24")
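# Note: with async_processing=True, recv runs on a worker thread, so it must
# not call Streamlit APIs directly; all annotation happens on the frame itself.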
|
|
# Start the WebRTC stream; each video frame is routed through VideoProcessor.recv.
ctx = webrtc_streamer(
    key="emotion-detect",
    video_processor_factory=VideoProcessor,
    media_stream_constraints={"video": True, "audio": False},
    async_processing=True,
)
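# Note: when the app is served from anywhere other than localhost, the WebRTC
# handshake typically needs a STUN server, e.g. (a common public choice):
#   rtc_configuration={"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
# passed as an extra argument to webrtc_streamer.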
|
|
# Mirror the stream state in the status placeholder above the video widget.
if ctx.state.playing:
    livestatus.success("🟢 Live")
else:
    livestatus.error("🔴 Offline")