import sys
import os
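# Put the project root on sys.path so `utils.model_loader` resolves when this
# page is run from inside a subdirectory.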
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import streamlit as st
from streamlit_webrtc import webrtc_streamer, VideoProcessorBase
import av
import cv2
import torch
import numpy as np
from torchvision import transforms
from utils.model_loader import load_model
st.title("Facial Expression Recognition")
st.caption("dataset: https://www.kaggle.com/datasets/jonathanoheix/face-expression-recognition-dataset")
# Load model
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_PATH = "data/models/expression_predictor_cnn.pth"
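# Display labels, assumed to follow the index order the model was trained with
# (the dataset's alphabetical folder order: angry, disgust, fear, happy,
# neutral, sad, surprise).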
CLASSES = ['Angry', 'Disgust', 'Scared', 'Happy', 'Neutral', 'Sad', 'Surprised']
model = load_model(MODEL_PATH, DEVICE)
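# `load_model` (utils/model_loader.py, not shown here) is assumed to return the
# trained CNN in eval mode on the requested device, roughly:
#     model = ExpressionCNN(num_classes=len(CLASSES))  # hypothetical class name
#     model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
#     model.eval()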
# Face detection
face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
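# cv2.data.haarcascades points at the cascade XML files bundled with the
# opencv-python wheel, so no separate download is needed.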
# Transform for inference
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Grayscale(),
    transforms.Resize((48, 48)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
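# NOTE: grayscale 48x48 normalized with mean/std 0.5 is assumed to mirror the
# training preprocessing; inference must match it exactly for sane predictions.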
# Placeholder for the live/offline status indicator
livestatus = st.empty()
# Video processor: detects faces and classifies each one on every frame
class VideoProcessor(VideoProcessorBase):
    def recv(self, frame):
        img = frame.to_ndarray(format="bgr24")
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        # scaleFactor/minNeighbors trade detection speed against false positives
        faces = face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5)
        for i, (x, y, w, h) in enumerate(faces):
            # Crop the detected face and run it through the CNN
            face_crop = gray[y:y+h, x:x+w]
            face_tensor = transform(face_crop).unsqueeze(0).to(DEVICE)
            with torch.no_grad():
                outputs = model(face_tensor)
                probs = torch.nn.functional.softmax(outputs, dim=1)[0].cpu().numpy()
            top_idx = np.argmax(probs)
            label = CLASSES[top_idx]
            # Draw face box + label on the video frame
            face_id = f"Face {i+1}"
            cv2.rectangle(img, (x, y), (x+w, y+h), (0, 255, 0), 1)
            cv2.putText(img, f"{face_id}: {label}", (x, y - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
        return av.VideoFrame.from_ndarray(img, format="bgr24")
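# recv() must return an av.VideoFrame; handing back the annotated BGR image is
# what sends the drawn boxes and labels to the browser.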
ctx = webrtc_streamer(
    key="emotion-detect",
    video_processor_factory=VideoProcessor,
    media_stream_constraints={"video": True, "audio": False},
    async_processing=True,
)
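# recv() runs on a worker thread managed by streamlit-webrtc;
# async_processing=True keeps frame handling off the Streamlit script run, so
# a slow model lowers the frame rate instead of freezing the page.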
if ctx.state.playing:
    livestatus.success("🟢 Live")
else:
    livestatus.error("🔴 Offline")