"""MindSense AI — Streamlit app for real-time facial emotion recognition.

Downloads a TorchScript emotion model from the HuggingFace Hub and offers
two modes: live webcam analysis (streamlit-webrtc) and single-image upload.
The model returns (emotion_logits, stress, valence) for a detected face.

Research/demo tool only — not a medical device.
"""

import json
from collections import Counter, deque

import av
import cv2
import numpy as np
import plotly.graph_objects as go
import streamlit as st
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from PIL import Image
from streamlit_webrtc import RTCConfiguration, WebRtcMode, webrtc_streamer

# Page config
st.set_page_config(
    page_title="MindSense AI | Emotion Recognition",
    page_icon="🧠",
    layout="wide",
)

# Custom CSS (stylesheet body intentionally empty in this revision)
st.markdown(""" """, unsafe_allow_html=True)

# ============================================================================
# Load Model from HuggingFace Hub
# ============================================================================
@st.cache_resource
def load_model():
    """Download and load the TorchScript model + its JSON config from HF Hub.

    Returns:
        (model, config) on success, (None, None) on any failure (the error is
        surfaced to the UI via st.error; this is a top-level boundary, so the
        broad except is deliberate).
    """
    repo_id = "Arko007/mindsense-emotion-model"
    with st.spinner("🧠 Loading AI model..."):
        try:
            model_path = hf_hub_download(repo_id=repo_id, filename="mindsense_emotion_model.pt")
            config_path = hf_hub_download(repo_id=repo_id, filename="model_config.json")
            with open(config_path, 'r') as f:
                config = json.load(f)
            model = torch.jit.load(model_path, map_location='cpu')
            model.eval()
            return model, config
        except Exception as e:
            st.error(f"❌ Error loading model: {e}")
            return None, None


model, config = load_model()

if model is None:
    st.error("Failed to load model. Please check the repository.")
    st.stop()

st.success(f"✅ Model loaded! Accuracy: {config.get('best_val_acc', 0):.2f}%")


# ============================================================================
# Emotion Analyzer
# ============================================================================
class EmotionAnalyzer:
    """Detect the largest face in a BGR frame and run emotion inference on it."""

    def __init__(self, model, config):
        self.model = model
        self.config = config
        self.emotions = config['classes']
        # Hoisted: normalization constants are built as tensors once here.
        # (The original rebuilt numpy arrays and converted them to tensors on
        # every single frame.)
        self.mean = torch.tensor(config['mean'], dtype=torch.float32).reshape(3, 1, 1)
        self.std = torch.tensor(config['std'], dtype=torch.float32).reshape(3, 1, 1)
        self.face_cascade = cv2.CascadeClassifier(
            cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
        )

    @torch.no_grad()
    def analyze_frame(self, frame):
        """Analyze a BGR uint8 frame for emotions.

        Returns a dict with keys: dominant_emotion, confidence, all_emotions,
        stress_score, valence, face_location. Falls back to a neutral default
        result when no face is found or anything fails (best-effort by design —
        a dropped frame must not crash the video pipeline).
        """
        try:
            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            faces = self.face_cascade.detectMultiScale(gray, 1.3, 5)
            if len(faces) == 0:
                return self._default_result()

            # Use the largest detected face (by bounding-box area).
            x, y, w, h = max(faces, key=lambda f: f[2] * f[3])
            face_roi = frame[y:y+h, x:x+w]

            # Preprocess: BGR -> RGB, resize to model input, scale, normalize.
            face_rgb = cv2.cvtColor(face_roi, cv2.COLOR_BGR2RGB)
            face_resized = cv2.resize(face_rgb, (384, 384))
            img_tensor = torch.from_numpy(face_resized).float().permute(2, 0, 1) / 255.0
            img_tensor = ((img_tensor - self.mean) / self.std).unsqueeze(0)

            # Inference: model returns (emotion logits, stress, valence).
            emotion_logits, stress_pred, valence_pred = self.model(img_tensor)
            emotion_probs = F.softmax(emotion_logits, dim=1)[0].numpy()
            emotion_idx = int(np.argmax(emotion_probs))

            return {
                'dominant_emotion': self.emotions[emotion_idx],
                'confidence': float(emotion_probs[emotion_idx]),
                'all_emotions': {e: float(p) for e, p in zip(self.emotions, emotion_probs)},
                'stress_score': float(stress_pred.item()),
                'valence': float(valence_pred.item()),
                'face_location': (x, y, w, h),
            }
        except Exception:
            # Best-effort: any failure degrades to the neutral default.
            return self._default_result()

    def _default_result(self):
        """Neutral zero-confidence result used when no face is detected."""
        return {
            'dominant_emotion': 'neutral',
            'confidence': 0.0,
            'all_emotions': {e: 0.0 for e in self.emotions},
            'stress_score': 0.0,
            'valence': 0.0,
            'face_location': None,
        }


# Initialize analyzer and rolling history buffers (bounded deques so the
# session never grows without limit).
if 'analyzer' not in st.session_state:
    st.session_state.analyzer = EmotionAnalyzer(model, config)
if 'emotion_history' not in st.session_state:
    st.session_state.emotion_history = deque(maxlen=100)
if 'stress_scores' not in st.session_state:
    st.session_state.stress_scores = deque(maxlen=100)

# ============================================================================
# UI
# ============================================================================
st.markdown(' Real-Time Emotion Recognition & Mental Health Assessment ',
            unsafe_allow_html=True)

# Sidebar
with st.sidebar:
    st.markdown("### ⚙️ Settings")
    confidence_threshold = st.slider("Confidence Threshold", 0.0, 1.0, 0.5, 0.05)
    show_all_emotions = st.checkbox("Show All Emotions", value=True)
    st.markdown("---")
    st.markdown("### 📊 Model Info")
    st.info(f""" **Architecture:** Custom EfficientNet-CNN **Parameters:** {config.get('total_params', 0) / 1e6:.2f}M **Accuracy:** {config.get('best_val_acc', 0):.2f}% **Trained on:** FER2013 (28k images) """)

# Main content
tab1, tab2 = st.tabs(["🎥 Live Webcam", "📤 Upload Image"])

with tab1:
    col1, col2 = st.columns([2, 1])

    with col1:
        st.markdown("### Live Analysis")
        rtc_config = RTCConfiguration(
            {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
        )

        class VideoProcessor:
            """Per-frame callback: annotate the frame and record detections."""

            def __init__(self):
                self.frame_count = 0

            def recv(self, frame):
                img = frame.to_ndarray(format="bgr24")
                self.frame_count += 1

                # Analyze only every 3rd frame to keep the stream responsive.
                if self.frame_count % 3 == 0:
                    # NOTE(review): recv() runs on a worker thread; accessing
                    # st.session_state from here is a documented limitation of
                    # streamlit-webrtc and may silently not update the UI
                    # session — confirm against the deployed streamlit-webrtc
                    # version (a thread-safe queue is the recommended pattern).
                    result = st.session_state.analyzer.analyze_frame(img)

                    if result['face_location']:
                        x, y, w, h = result['face_location']
                        emotion = result['dominant_emotion']
                        confidence = result['confidence']

                        color_map = {
                            'happy': (0, 255, 0),
                            'sad': (255, 0, 0),
                            'angry': (0, 0, 255),
                            'fear': (128, 0, 128),
                            'surprise': (255, 255, 0),
                            'neutral': (128, 128, 128),
                            'disgust': (0, 128, 128),
                        }
                        color = color_map.get(emotion, (255, 255, 255))

                        cv2.rectangle(img, (x, y), (x+w, y+h), color, 2)
                        label = f"{emotion.upper()} ({confidence:.0%})"
                        cv2.putText(img, label, (x, y-10),
                                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)

                        # Only record confident detections in the history.
                        if confidence > confidence_threshold:
                            st.session_state.emotion_history.append(emotion)
                            st.session_state.stress_scores.append(result['stress_score'])

                return av.VideoFrame.from_ndarray(img, format="bgr24")

        webrtc_ctx = webrtc_streamer(
            key="emotion-detection",
            mode=WebRtcMode.SENDRECV,
            rtc_configuration=rtc_config,
            video_processor_factory=VideoProcessor,
            media_stream_constraints={"video": True, "audio": False},
            async_processing=True,
        )

    with col2:
        st.markdown("### 📊 Live Metrics")

        if len(st.session_state.emotion_history) > 0:
            current_emotion = st.session_state.emotion_history[-1]
            avg_stress = float(np.mean(list(st.session_state.stress_scores)[-10:]))

            emotion_emoji = {
                'happy': '😊', 'sad': '😢', 'angry': '😠', 'fear': '😨',
                'surprise': '😮', 'neutral': '😐', 'disgust': '🤢',
            }
            st.markdown(f"## {emotion_emoji.get(current_emotion, '😐')} {current_emotion.title()}")
            st.metric("Stress Level", f"{avg_stress:.1%}")
            # FIX: st.progress requires a value in [0, 1]; clamp in case the
            # model's stress head drifts outside that range.
            st.progress(min(max(avg_stress, 0.0), 1.0))

            if show_all_emotions:
                st.markdown("#### All Emotions")
                # FIX: the original analyzed an all-black dummy frame here
                # (np.zeros), which never contains a face and therefore always
                # displayed 0.0% for every emotion. Show the observed
                # frequency of recent confident detections instead.
                recent = list(st.session_state.emotion_history)[-30:]
                counts = Counter(recent)
                total = sum(counts.values())
                for emotion, cnt in counts.most_common():
                    st.text(f"{emotion.title()}: {cnt / total:.1%}")
        else:
            st.info("👋 Start webcam to begin")

with tab2:
    st.markdown("### Upload an Image")
    uploaded_file = st.file_uploader("Choose an image...", type=['jpg', 'jpeg', 'png'])

    if uploaded_file:
        # FIX: force 3-channel RGB — PNG uploads may be RGBA or grayscale,
        # which would crash cv2.cvtColor(..., COLOR_RGB2BGR) below.
        image = Image.open(uploaded_file).convert('RGB')
        image_np = np.array(image)

        col1, col2 = st.columns(2)
        with col1:
            st.image(image, caption="Uploaded Image", use_column_width=True)
        with col2:
            result = st.session_state.analyzer.analyze_frame(
                cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
            )
            st.markdown("### 🎭 Analysis Results")
            st.markdown(f"**Emotion:** {result['dominant_emotion'].title()}")
            st.markdown(f"**Confidence:** {result['confidence']:.1%}")
            st.markdown(f"**Stress:** {result['stress_score']:.1%}")
            st.markdown(f"**Valence:** {result['valence']:.2f}")

            if show_all_emotions:
                st.markdown("#### Emotion Distribution")
                for emotion, prob in sorted(result['all_emotions'].items(),
                                            key=lambda x: x[1], reverse=True):
                    st.progress(prob)
                    st.caption(f"{emotion.title()}: {prob:.1%}")

# Visualizations — only once enough history has accumulated.
if len(st.session_state.emotion_history) > 10:
    st.markdown("---")
    st.markdown("### 📈 Analysis Dashboard")
    col1, col2 = st.columns(2)

    with col1:
        emotion_counts = Counter(st.session_state.emotion_history)
        fig = go.Figure(data=[go.Pie(
            labels=list(emotion_counts.keys()),
            values=list(emotion_counts.values()),
            hole=0.4,
        )])
        fig.update_layout(title="Emotion Distribution", height=300)
        st.plotly_chart(fig, use_container_width=True)

    with col2:
        fig = go.Figure()
        fig.add_trace(go.Scatter(
            y=list(st.session_state.stress_scores),
            mode='lines',
            fill='tozeroy',
            line=dict(color='#667eea', width=2),
        ))
        fig.update_layout(title="Stress Timeline", height=300, yaxis_range=[0, 1])
        st.plotly_chart(fig, use_container_width=True)

# Footer
st.markdown("---")
st.markdown("""🧠 MindSense AI | Built with PyTorch & Streamlit
⚠️ Disclaimer: Research tool only. Not for medical diagnosis.""")