# Hugging Face Space: fer/app.py — imran-nawar (commit 8e52a03, verified)
"""
Facial Emotion Recognition System
This file can be used for inference using the pretrained model, which incorporated CLIP encoder
to extract features and then classifies facial emotions using a fully connected layer. This model
achieved about 62% accuracy on the test data.
"""
import streamlit as st
import cv2
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
from model_architecture import CLIPClassifier
from streamlit_webrtc import webrtc_streamer, VideoTransformerBase, RTCConfiguration, WebRtcMode
import av
import logging
import threading
from typing import Union
import time
# Configure logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Page configuration (must be the first Streamlit call in the script)
st.set_page_config(page_title="Facial Emotion Recognition", layout="wide")

# Constants — FER-2013 class labels; index order must match the classifier's
# output logits, and the checkpoint file must sit next to this script.
EMOTIONS = ['angry', 'disgust', 'fear', 'happy', 'sad', 'surprise', 'neutral']
PRETRAINED_MODEL = 'clip_fer.pth'

# WebRTC configuration with several public STUN servers for NAT traversal.
# NOTE(review): the original also passed ice_connection_timeout=30, which is
# not a parameter of RTCConfiguration (streamlit-webrtc/aiortc) and raises a
# TypeError at import time — removed.
RTC_CONFIGURATION = RTCConfiguration(
    {"iceServers": [
        {"urls": ["stun:stun.l.google.com:19302"]},
        {"urls": ["stun:stun1.l.google.com:19302"]},
        {"urls": ["stun:stun2.l.google.com:19302"]},
        {"urls": ["stun:stun3.l.google.com:19302"]},
        {"urls": ["stun:stun4.l.google.com:19302"]},
    ]}
)

# Global placeholders. In practice EmotionDetector loads its own cached
# resources, so these stay None; kept for backward compatibility.
model = None
face_cascade = None
model_lock = threading.Lock()
@st.cache_resource
def load_model(model_path, num_classes=7):
    """Build the CLIP-based emotion classifier and load pretrained weights.

    Cached by Streamlit so the checkpoint is read from disk only once per
    process. Weights are mapped to CPU and the model is set to eval mode.

    Args:
        model_path: Path to the ``state_dict`` checkpoint file.
        num_classes: Number of emotion classes (default 7, matching EMOTIONS).

    Returns:
        The ready-to-use ``CLIPClassifier`` instance.

    Raises:
        Exception: Re-raised after logging if loading fails.
    """
    try:
        classifier = CLIPClassifier(num_classes)
        weights = torch.load(model_path, map_location=torch.device('cpu'))
        classifier.load_state_dict(weights)
        classifier.eval()
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        raise
    return classifier
@st.cache_resource
def load_face_cascade():
    """Return OpenCV's bundled frontal-face Haar cascade, cached by Streamlit.

    Raises:
        Exception: Re-raised after logging if the cascade cannot be loaded.
    """
    try:
        cascade_path = cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
        return cv2.CascadeClassifier(cascade_path)
    except Exception as e:
        logger.error(f"Error loading face cascade: {str(e)}")
        raise
# Built once at import time instead of on every call (the original rebuilt the
# Compose pipeline per frame/face): 224x224 resize, 3-channel grayscale, and
# ImageNet normalization statistics to match the CLIP encoder's expected input.
_FACE_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


def preprocess_image(image):
    """Convert a BGR face crop into a batched tensor for the model.

    Args:
        image: BGR image as produced by OpenCV (H, W, 3 uint8 ndarray —
            presumed; TODO confirm against callers).

    Returns:
        A (1, 3, 224, 224) float tensor, or None if preprocessing failed
        (callers check for None and skip the face).
    """
    try:
        rgb = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        return _FACE_TRANSFORM(rgb).unsqueeze(0)
    except Exception as e:
        logger.error(f"Error preprocessing image: {str(e)}")
        return None
class EmotionDetector(VideoTransformerBase):
    """streamlit-webrtc video processor that annotates detected faces with
    their predicted emotion label on each incoming frame."""

    def __init__(self) -> None:
        # Both loaders are Streamlit-cached, so repeated detector instances
        # share one model and one cascade.
        self.model = load_model(PRETRAINED_MODEL)
        self.face_cascade = load_face_cascade()
        self.frame_count = 0
        self.start_time = time.time()
        logger.info("EmotionDetector initialized successfully")

    def recv(self, frame: av.VideoFrame) -> Union[av.VideoFrame, None]:
        """Detect faces, classify each, and return the annotated frame.

        On any whole-frame failure the original frame is returned unmodified;
        per-face failures are logged and that face is skipped.
        """
        self.frame_count += 1
        if self.frame_count == 1:
            logger.info("First frame received successfully")
        try:
            img = frame.to_ndarray(format="bgr24")

            # Rough throughput log once every 30 frames.
            if self.frame_count % 30 == 0:
                now = time.time()
                fps = 30 / (now - self.start_time)
                logger.info(f"FPS: {fps:.2f}")
                self.start_time = now

            # Haar cascade works on grayscale input.
            gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            detections = self.face_cascade.detectMultiScale(
                image=gray,
                scaleFactor=1.1,
                minNeighbors=5,
                minSize=(30, 30)
            )

            for (x, y, w, h) in detections:
                try:
                    crop = img[y:y + h, x:x + w]
                    tensor = preprocess_image(crop)
                    if tensor is None:
                        continue
                    with torch.no_grad():
                        logits = self.model(tensor)
                        _, idx = torch.max(logits, 1)
                    label = EMOTIONS[idx.item()]
                    cv2.rectangle(img, (x, y), (x + w, y + h), (210, 140, 70), 2)
                    cv2.putText(img, label, (x, y - 10),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.9, (50, 200, 0), 2)
                except Exception as e:
                    logger.error(f"Error processing face: {str(e)}")
                    continue
            return av.VideoFrame.from_ndarray(img, format="bgr24")
        except Exception as e:
            logger.error(f"Error in frame processing: {str(e)}")
            return frame
def main():
    """Streamlit entry point: sidebar navigation across three pages
    (Home, Live Detection, About) plus a fixed footer."""
    st.sidebar.header("Facial Emotion Recognition System")
    activities = ["Home", "Live Detection", "About"]
    choice = st.sidebar.selectbox("Select Activity", activities)

    if choice == "Live Detection":
        st.title("Live Emotion Detection")
        st.write("Click 'START' below to begin webcam emotion detection")
        status_placeholder = st.empty()
        try:
            # NOTE(review): the original also passed async_transform=True and
            # video_frame_callback=None. `async_transform` belongs to the
            # deprecated "transformer" API and conflicts with
            # video_processor_factory in current streamlit-webrtc releases,
            # so both extraneous kwargs are dropped.
            webrtc_ctx = webrtc_streamer(
                key="emotion_detection",
                mode=WebRtcMode.SENDRECV,
                rtc_configuration=RTC_CONFIGURATION,
                video_processor_factory=EmotionDetector,
                media_stream_constraints={
                    "video": {
                        "width": {"min": 480, "ideal": 640, "max": 1920},
                        "height": {"min": 360, "ideal": 480, "max": 1080},
                        "frameRate": {"ideal": 30}
                    },
                    "audio": False
                },
                async_processing=True,
            )
            # Stream status handling
            if webrtc_ctx.state.playing:
                status_placeholder.success("✅ Stream started successfully!")
                st.write("If you see a black screen:")
                st.write("1. Check if camera permissions are granted")
                st.write("2. Try refreshing the page")
                st.write("3. Try a different browser (Chrome recommended)")
            else:
                status_placeholder.warning("⚠️ Stream not started. Please check camera permissions.")
        except Exception as e:
            logger.error(f"Stream initialization error: {str(e)}")
            status_placeholder.error(f"❌ Error: {str(e)}")
            st.write("Troubleshooting steps:")
            st.write("1. Ensure camera is not being used by another application")
            st.write("2. Try refreshing the page")
            st.write("3. Check browser console for detailed errors")
    elif choice == "Home":
        st.title("Facial Emotion Recognition System")
        st.write("""
        This application uses a CLIP model fine-tuned on the FER-2013 dataset
        to perform real-time emotion recognition from a live webcam feed.
        ### Features:
        - Real-time face detection
        - Emotion classification into 7 categories
        - Live webcam processing
        To begin, select 'Live Detection' from the sidebar.
        """)
    elif choice == "About":
        st.title("About")
        st.write("""
        ### Facial Emotion Recognition System
        This application uses:
        - CLIP-based neural network for emotion recognition
        - OpenCV for face detection
        - Streamlit and streamlit-webrtc for the web interface
        Created by Imran Nawar
        """)

    # Footer pinned to the bottom of the viewport on every page.
    footer = """
    <div style="position: fixed; bottom: 0; width: 100%; background-color: #EDF3FA;
    padding: 10px; text-align: center;">
    Created by Imran Nawar
    </div>
    """
    st.markdown(footer, unsafe_allow_html=True)
if __name__ == "__main__":
main()