Spaces:
Build error
Build error
| import streamlit as st | |
| from streamlit_webrtc import webrtc_streamer, WebRtcMode, RTCConfiguration | |
| import av | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| import json | |
| from huggingface_hub import hf_hub_download | |
| from collections import deque | |
| import plotly.graph_objects as go | |
| from PIL import Image | |
# Page config: title, icon and wide layout must be set before any other
# Streamlit call renders an element.
_PAGE_CONFIG = dict(
    page_title="MindSense AI | Emotion Recognition",
    page_icon="๐ง ",
    layout="wide",
)
st.set_page_config(**_PAGE_CONFIG)

# Custom CSS injected once per run (fonts, gradient title, subtitle, metric cards).
_CUSTOM_CSS = """
<style>
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600;700&display=swap');
* { font-family: 'Inter', sans-serif; }
.main {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
}
.title-gradient {
background: linear-gradient(90deg, #667eea 0%, #764ba2 50%, #f093fb 100%);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
font-size: 3rem;
font-weight: 800;
text-align: center;
margin-bottom: 10px;
}
.subtitle {
text-align: center;
color: rgba(255, 255, 255, 0.9);
font-size: 1.1rem;
margin-bottom: 30px;
}
.metric-card {
background: rgba(255, 255, 255, 0.1);
backdrop-filter: blur(20px);
border: 1px solid rgba(255, 255, 255, 0.2);
border-radius: 15px;
padding: 20px;
margin: 10px 0;
}
div[data-testid="stMetricValue"] {
font-size: 1.8rem;
font-weight: 700;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# ============================================================================
# Load Model from HuggingFace Hub
# ============================================================================
MODEL_REPO_ID = "Arko007/mindsense-emotion-model"


@st.cache_resource(show_spinner=False)
def _download_and_load(repo_id):
    """Download the TorchScript model and its JSON config from *repo_id*.

    Cached with ``st.cache_resource`` so the model is loaded once per server
    process rather than on every Streamlit rerun (the whole script re-executes
    on each widget interaction). Exceptions propagate and are not cached, so a
    failed load is retried on the next rerun.

    Returns:
        ``(model, config)`` -- the loaded module in eval mode and the parsed
        ``model_config.json`` dict.
    """
    model_path = hf_hub_download(repo_id=repo_id, filename="mindsense_emotion_model.pt")
    config_path = hf_hub_download(repo_id=repo_id, filename="model_config.json")
    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)
    model = torch.jit.load(model_path, map_location='cpu')
    model.eval()
    return model, config


def load_model(repo_id=MODEL_REPO_ID):
    """Download and load model from HF Hub.

    Args:
        repo_id: Hugging Face Hub repository holding
            ``mindsense_emotion_model.pt`` and ``model_config.json``.

    Returns:
        ``(model, config)`` on success. On any failure an error is shown in the
        page and ``(None, None)`` is returned.
    """
    with st.spinner("๐ง Loading AI model..."):
        try:
            return _download_and_load(repo_id)
        except Exception as e:
            st.error(f"โ Error loading model: {e}")
            return None, None


model, config = load_model()
if model is None:
    st.error("Failed to load model. Please check the repository.")
    st.stop()
st.success(f"โ Model loaded! Accuracy: {config.get('best_val_acc', 0):.2f}%")
# ============================================================================
# Emotion Analyzer
# ============================================================================
class EmotionAnalyzer:
    """Detect the largest face in a frame and run the emotion model on it.

    ``analyze_frame`` does not propagate errors: a missing face, unusable
    input or a model failure all yield the neutral/zero ``_default_result``.
    """

    def __init__(self, model, config, input_size=384):
        """Store the model and the preprocessing constants.

        Args:
            model: Callable taking a ``(1, 3, input_size, input_size)`` float
                tensor and returning ``(emotion_logits, stress_pred,
                valence_pred)``.
            config: Mapping with ``'classes'`` (emotion labels, in logit
                order) and per-channel ``'mean'`` / ``'std'`` (3 values each).
            input_size: Side length the face crop is resized to before
                inference (default 384, the previously hard-coded value).
        """
        self.model = model
        self.config = config
        self.emotions = config['classes']
        self.input_size = input_size
        self.mean = np.array(config['mean']).reshape(3, 1, 1)
        self.std = np.array(config['std']).reshape(3, 1, 1)
        # Torch copies of the normalisation constants, built once instead of
        # being converted again for every frame.
        self._mean_t = torch.from_numpy(self.mean).float()
        self._std_t = torch.from_numpy(self.std).float()
        self.face_cascade = cv2.CascadeClassifier(
            cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
        )

    def analyze_frame(self, frame):
        """Analyze frame for emotions.

        Args:
            frame: NumPy image in OpenCV channel order -- BGR ``(H, W, 3)``,
                BGRA ``(H, W, 4)`` or grayscale ``(H, W)`` / ``(H, W, 1)``.

        Returns:
            dict with ``dominant_emotion``, ``confidence``, ``all_emotions``
            (label -> probability), ``stress_score``, ``valence`` and
            ``face_location`` (``(x, y, w, h)`` of the largest detected face
            as Python ints, or ``None`` when no face was analysed).
        """
        try:
            bgr = self._as_bgr(frame)
            face_box = self._largest_face(bgr)
            if face_box is None:
                return self._default_result()
            x, y, w, h = face_box
            face_rgb = cv2.cvtColor(bgr[y:y + h, x:x + w], cv2.COLOR_BGR2RGB)
            face_rgb = cv2.resize(face_rgb, (self.input_size, self.input_size))
            result = self._infer(self._to_tensor(face_rgb))
        except Exception:
            # Deliberately best-effort (this runs per video frame), but record
            # the failure instead of hiding it behind the default result.
            import logging
            logging.getLogger(__name__).warning("Emotion analysis failed", exc_info=True)
            return self._default_result()
        result['face_location'] = face_box
        return result

    @staticmethod
    def _as_bgr(frame):
        """Return *frame* as a 3-channel BGR image (grayscale/BGRA are converted)."""
        if frame.ndim == 2 or frame.shape[2] == 1:
            return cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR)
        if frame.shape[2] == 4:
            return cv2.cvtColor(frame, cv2.COLOR_BGRA2BGR)
        return frame

    def _largest_face(self, bgr):
        """Return ``(x, y, w, h)`` of the largest Haar-cascade face, or ``None``."""
        gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
        faces = self.face_cascade.detectMultiScale(gray, 1.3, 5)
        if len(faces) == 0:
            return None
        x, y, w, h = max(faces, key=lambda f: f[2] * f[3])
        # detectMultiScale yields NumPy ints; plain ints are safe for cv2
        # drawing calls and for display.
        return int(x), int(y), int(w), int(h)

    def _to_tensor(self, face_rgb):
        """Scale an ``(H, W, 3)`` uint8 RGB crop to [0, 1], normalise, add batch dim."""
        tensor = torch.from_numpy(face_rgb).float().permute(2, 0, 1) / 255.0
        tensor = (tensor - self._mean_t) / self._std_t
        return tensor.unsqueeze(0)

    def _infer(self, img_tensor):
        """Run the model on a batch of one and decode its three outputs.

        Inference runs under ``torch.no_grad()``. Without it, outputs of a
        module whose parameters require grad also require grad, and
        ``.numpy()`` on them raises -- which sent every frame to the default
        result.
        """
        with torch.no_grad():
            emotion_logits, stress_pred, valence_pred = self.model(img_tensor)
        probs = F.softmax(emotion_logits, dim=1)[0].numpy()
        idx = int(np.argmax(probs))
        return {
            'dominant_emotion': self.emotions[idx],
            'confidence': float(probs[idx]),
            'all_emotions': {e: float(p) for e, p in zip(self.emotions, probs)},
            'stress_score': float(stress_pred.item()),
            'valence': float(valence_pred.item()),
        }

    def _default_result(self):
        """Neutral result with zero scores, returned when nothing was analysed."""
        return {
            'dominant_emotion': 'neutral',
            'confidence': 0.0,
            'all_emotions': {e: 0.0 for e in self.emotions},
            'stress_score': 0.0,
            'valence': 0.0,
            'face_location': None
        }
# Initialize analyzer and history buffers (once per browser session).
def _init_session_state():
    """Create missing per-session objects; existing keys are left untouched.

    Streamlit re-executes the script on every interaction, so each value is
    built lazily via a factory and only when its key is absent.
    """
    factories = {
        'analyzer': lambda: EmotionAnalyzer(model, config),
        'emotion_history': lambda: deque(maxlen=100),
        'stress_scores': lambda: deque(maxlen=100),
    }
    for key, make_default in factories.items():
        if key not in st.session_state:
            st.session_state[key] = make_default()


_init_session_state()
# ============================================================================
# UI
# ============================================================================
for _header_html in (
    '<h1 class="title-gradient">๐ง MindSense AI</h1>',
    '<p class="subtitle">Real-Time Emotion Recognition & Mental Health Assessment</p>',
):
    st.markdown(_header_html, unsafe_allow_html=True)

# Sidebar: analysis settings plus a static summary of the loaded model.
sidebar = st.sidebar
sidebar.markdown("### โ๏ธ Settings")
confidence_threshold = sidebar.slider("Confidence Threshold", 0.0, 1.0, 0.5, 0.05)
show_all_emotions = sidebar.checkbox("Show All Emotions", value=True)
sidebar.markdown("---")
sidebar.markdown("### ๐ Model Info")
params_millions = config.get('total_params', 0) / 1e6
best_accuracy = config.get('best_val_acc', 0)
sidebar.info(f"""
**Architecture:** Custom EfficientNet-CNN
**Parameters:** {params_millions:.2f}M
**Accuracy:** {best_accuracy:.2f}%
**Trained on:** FER2013 (28k images)
""")
# Main content
tab1, tab2 = st.tabs(["๐ฅ Live Webcam", "๐ค Upload Image"])

with tab1:
    col1, col2 = st.columns([2, 1])
    with col1:
        st.markdown("### Live Analysis")
        rtc_config = RTCConfiguration(
            {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
        )

        # streamlit-webrtc calls ``recv`` on a worker thread, where
        # ``st.session_state`` is not available. Resolve the shared objects
        # here, on the script thread, and let ``recv`` use them directly. The
        # deques are the same objects stored in session state, so samples the
        # worker appends are visible to the metrics/dashboard on later reruns.
        analyzer = st.session_state.analyzer
        emotion_history = st.session_state.emotion_history
        stress_scores = st.session_state.stress_scores

        class VideoProcessor:
            """Annotate webcam frames with the detected face and its emotion."""

            # Box/label colour per emotion, in OpenCV's BGR channel order.
            COLOR_MAP = {
                'happy': (0, 255, 0), 'sad': (255, 0, 0),
                'angry': (0, 0, 255), 'fear': (128, 0, 128),
                'surprise': (255, 255, 0), 'neutral': (128, 128, 128),
                'disgust': (0, 128, 128)
            }

            def __init__(self):
                self.frame_count = 0
                # Refreshed from the script thread after every rerun (see
                # below) so slider changes reach an already-running stream.
                self.confidence_threshold = confidence_threshold
                # Latest analysis: redrawn on skipped frames so the box does
                # not flicker, and read by the metrics panel.
                self.last_result = None

            def recv(self, frame):
                img = frame.to_ndarray(format="bgr24")
                self.frame_count += 1
                # Run the model only on every third frame to save CPU.
                if self.frame_count % 3 == 0:
                    result = analyzer.analyze_frame(img)
                    self.last_result = result
                    if result['face_location'] and result['confidence'] > self.confidence_threshold:
                        emotion_history.append(result['dominant_emotion'])
                        stress_scores.append(result['stress_score'])
                if self.last_result is not None and self.last_result['face_location']:
                    self._annotate(img, self.last_result)
                return av.VideoFrame.from_ndarray(img, format="bgr24")

            def _annotate(self, img, result):
                """Draw the face box and an ``EMOTION (NN%)`` label onto *img* in place."""
                x, y, w, h = (int(v) for v in result['face_location'])
                emotion = result['dominant_emotion']
                color = self.COLOR_MAP.get(emotion, (255, 255, 255))
                cv2.rectangle(img, (x, y), (x + w, y + h), color, 2)
                label = f"{emotion.upper()} ({result['confidence']:.0%})"
                cv2.putText(img, label, (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)

        webrtc_ctx = webrtc_streamer(
            key="emotion-detection",
            mode=WebRtcMode.SENDRECV,
            rtc_configuration=rtc_config,
            video_processor_factory=VideoProcessor,
            media_stream_constraints={"video": True, "audio": False},
            async_processing=True
        )
        # Push the current slider value into the live processor, if running.
        if webrtc_ctx.video_processor:
            webrtc_ctx.video_processor.confidence_threshold = confidence_threshold

    with col2:
        st.markdown("### ๐ Live Metrics")
        if len(emotion_history) > 0:
            current_emotion = emotion_history[-1]
            avg_stress = float(np.mean(list(stress_scores)[-10:]))
            emotion_emoji = {
                'happy': '๐', 'sad': '๐ข', 'angry': '๐ ',
                'fear': '๐จ', 'surprise': '๐ฎ', 'neutral': '๐',
                'disgust': '๐คข'
            }
            st.markdown(f"## {emotion_emoji.get(current_emotion, '๐')} {current_emotion.title()}")
            st.metric("Stress Level", f"{avg_stress:.1%}")
            # st.progress rejects values outside [0, 1]; nothing here bounds
            # the model's stress output, so clamp before displaying.
            st.progress(min(max(avg_stress, 0.0), 1.0))
            if show_all_emotions:
                st.markdown("#### All Emotions")
                # Use the processor's latest real analysis. Analysing a blank
                # placeholder frame never finds a face and would always show 0%.
                latest = getattr(webrtc_ctx.video_processor, 'last_result', None)
                if latest and latest['face_location']:
                    for emotion, prob in sorted(latest['all_emotions'].items(), key=lambda x: x[1], reverse=True):
                        st.text(f"{emotion.title()}: {prob:.1%}")
                else:
                    st.caption("No face in the latest analysed frame.")
        else:
            st.info("๐ Start webcam to begin")
with tab2:
    st.markdown("### Upload an Image")
    uploaded_file = st.file_uploader("Choose an image...", type=['jpg', 'jpeg', 'png'])
    if uploaded_file:
        try:
            image = Image.open(uploaded_file)
            # Grayscale ("L") and palette ("P") uploads decode to 2-D arrays,
            # which cv2.cvtColor(..., COLOR_RGB2BGR) rejects; normalise every
            # upload to 3-channel RGB before analysis.
            image_np = np.array(image.convert("RGB"))
        except OSError as exc:  # includes PIL.UnidentifiedImageError
            st.error(f"Could not read the uploaded image: {exc}")
        else:
            col1, col2 = st.columns(2)
            with col1:
                st.image(image, caption="Uploaded Image", use_column_width=True)
            with col2:
                result = st.session_state.analyzer.analyze_frame(cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR))
                st.markdown("### ๐ญ Analysis Results")
                if result['face_location'] is None:
                    # The analyzer's fallback is "neutral, 0%", which would
                    # otherwise be shown as if it were a real prediction.
                    st.warning("No face detected in this image.")
                else:
                    st.markdown(f"**Emotion:** {result['dominant_emotion'].title()}")
                    st.markdown(f"**Confidence:** {result['confidence']:.1%}")
                    st.markdown(f"**Stress:** {result['stress_score']:.1%}")
                    st.markdown(f"**Valence:** {result['valence']:.2f}")
                    if show_all_emotions:
                        st.markdown("#### Emotion Distribution")
                        for emotion, prob in sorted(result['all_emotions'].items(), key=lambda x: x[1], reverse=True):
                            st.progress(prob)
                            st.caption(f"{emotion.title()}: {prob:.1%}")
# Visualizations: shown once more than 10 confident detections have accumulated.
history = st.session_state.emotion_history
if len(history) > 10:
    from collections import Counter

    st.markdown("---")
    st.markdown("### ๐ Analysis Dashboard")
    pie_col, stress_col = st.columns(2)

    # Share of each emotion among the recorded detections.
    counts = Counter(history)
    pie_fig = go.Figure(data=[go.Pie(labels=list(counts), values=list(counts.values()), hole=0.4)])
    pie_fig.update_layout(title="Emotion Distribution", height=300)
    with pie_col:
        st.plotly_chart(pie_fig, use_container_width=True)

    # Recorded stress scores in arrival order.
    stress_fig = go.Figure()
    stress_fig.add_trace(go.Scatter(
        y=list(st.session_state.stress_scores),
        mode='lines',
        fill='tozeroy',
        line={'color': '#667eea', 'width': 2},
    ))
    stress_fig.update_layout(title="Stress Timeline", height=300, yaxis_range=[0, 1])
    with stress_col:
        st.plotly_chart(stress_fig, use_container_width=True)
# Footer: separator plus credits and the non-medical-use disclaimer.
_FOOTER_HTML = """
<div style='text-align:center; color:rgba(255,255,255,0.7);'>
<p>๐ง MindSense AI | Built with PyTorch & Streamlit</p>
<p>โ ๏ธ <strong>Disclaimer:</strong> Research tool only. Not for medical diagnosis.</p>
</div>
"""
st.markdown("---")
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)