Arko007 committed on
Commit
ab232bc
ยท
verified ยท
1 Parent(s): 31ed371

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +340 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,342 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
1
  import streamlit as st
2
+ from streamlit_webrtc import webrtc_streamer, WebRtcMode, RTCConfiguration
3
+ import av
4
+ import cv2
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import json
9
+ from huggingface_hub import hf_hub_download
10
+ from collections import deque
11
+ import plotly.graph_objects as go
12
+ from PIL import Image
13
+
14
# ---- Page configuration -------------------------------------------------
st.set_page_config(
    page_title="MindSense AI | Emotion Recognition",
    page_icon="🧠",
    layout="wide"
)

# ---- Custom CSS ---------------------------------------------------------
# Glassmorphism cards + gradient title. Kept in one named constant so the
# style sheet can be tweaked without touching any layout code below.
_CUSTOM_CSS = """
<style>
    @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600;700&display=swap');

    * { font-family: 'Inter', sans-serif; }

    .main {
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    }

    .title-gradient {
        background: linear-gradient(90deg, #667eea 0%, #764ba2 50%, #f093fb 100%);
        -webkit-background-clip: text;
        -webkit-text-fill-color: transparent;
        font-size: 3rem;
        font-weight: 800;
        text-align: center;
        margin-bottom: 10px;
    }

    .subtitle {
        text-align: center;
        color: rgba(255, 255, 255, 0.9);
        font-size: 1.1rem;
        margin-bottom: 30px;
    }

    .metric-card {
        background: rgba(255, 255, 255, 0.1);
        backdrop-filter: blur(20px);
        border: 1px solid rgba(255, 255, 255, 0.2);
        border-radius: 15px;
        padding: 20px;
        margin: 10px 0;
    }

    div[data-testid="stMetricValue"] {
        font-size: 1.8rem;
        font-weight: 700;
    }
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
64
+
65
# ============================================================================
# Load Model from HuggingFace Hub
# ============================================================================

@st.cache_resource
def load_model():
    """Download the TorchScript model and its config from the HF Hub.

    Returns a ``(model, config)`` pair, or ``(None, None)`` when anything
    fails; the error is surfaced in the Streamlit UI rather than raised.
    Cached by ``st.cache_resource`` so the download happens once per process.
    """
    repo_id = "Arko007/mindsense-emotion-model"

    with st.spinner("🧠 Loading AI model..."):
        try:
            weights_file = hf_hub_download(repo_id=repo_id, filename="mindsense_emotion_model.pt")
            config_file = hf_hub_download(repo_id=repo_id, filename="model_config.json")

            with open(config_file, 'r') as fh:
                model_config = json.load(fh)

            # TorchScript archive: runnable without the original Python class.
            scripted = torch.jit.load(weights_file, map_location='cpu')
            scripted.eval()
            return scripted, model_config

        except Exception as e:
            st.error(f"❌ Error loading model: {e}")
            return None, None
+
91
# Materialize the (cached) model once per script run; abort early if the
# download/load failed so everything below can assume a live model + config.
model, config = load_model()

if model is None:
    st.error("Failed to load model. Please check the repository.")
    st.stop()  # halts this Streamlit rerun; nothing below executes

# 'best_val_acc' comes from model_config.json; defaults to 0 if absent.
st.success(f"✅ Model loaded! Accuracy: {config.get('best_val_acc', 0):.2f}%")
+
99
# ============================================================================
# Emotion Analyzer
# ============================================================================

class EmotionAnalyzer:
    """Detect the largest face in a BGR frame and score it with the model.

    Wraps a TorchScript model that returns
    ``(emotion_logits, stress_pred, valence_pred)`` for a normalized
    384x384 RGB face crop.
    """

    def __init__(self, model, config):
        self.model = model
        self.config = config
        self.emotions = config['classes']  # label order matches model output order
        self.mean = np.array(config['mean']).reshape(3, 1, 1)
        self.std = np.array(config['std']).reshape(3, 1, 1)
        # Hoisted out of analyze_frame: convert the normalization constants to
        # tensors once instead of on every frame.
        self._mean_t = torch.from_numpy(self.mean).float()
        self._std_t = torch.from_numpy(self.std).float()
        self.face_cascade = cv2.CascadeClassifier(
            cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
        )

    @torch.no_grad()
    def analyze_frame(self, frame):
        """Analyze a BGR frame for emotions.

        Returns a dict with keys 'dominant_emotion', 'confidence',
        'all_emotions', 'stress_score', 'valence', 'face_location';
        a neutral default is returned when no face is found or on failure.
        """
        try:
            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            faces = self.face_cascade.detectMultiScale(gray, 1.3, 5)

            if len(faces) == 0:
                return self._default_result()

            # Only the largest detected face (by box area) is analyzed.
            x, y, w, h = max(faces, key=lambda f: f[2] * f[3])
            face_roi = frame[y:y+h, x:x+w]
            if face_roi.size == 0:
                # Degenerate box (e.g. at a frame edge) — nothing to analyze.
                return self._default_result()

            # Preprocess: BGR -> RGB, resize to model input, scale + normalize.
            face_rgb = cv2.cvtColor(face_roi, cv2.COLOR_BGR2RGB)
            face_resized = cv2.resize(face_rgb, (384, 384))

            img_tensor = torch.from_numpy(face_resized).float().permute(2, 0, 1) / 255.0
            img_tensor = ((img_tensor - self._mean_t) / self._std_t).unsqueeze(0)

            # Inference
            emotion_logits, stress_pred, valence_pred = self.model(img_tensor)

            emotion_probs = F.softmax(emotion_logits, dim=1)[0].numpy()
            emotion_idx = int(np.argmax(emotion_probs))

            return {
                'dominant_emotion': self.emotions[emotion_idx],
                'confidence': float(emotion_probs[emotion_idx]),
                'all_emotions': {e: float(p) for e, p in zip(self.emotions, emotion_probs)},
                'stress_score': float(stress_pred.item()),
                'valence': float(valence_pred.item()),
                'face_location': (x, y, w, h)
            }

        except Exception:
            # Best-effort: any OpenCV/model failure degrades to a neutral
            # result rather than crashing the video callback.
            return self._default_result()

    def _default_result(self):
        """Neutral placeholder when no face is found or analysis fails."""
        return {
            'dominant_emotion': 'neutral',
            'confidence': 0.0,
            'all_emotions': {e: 0.0 for e in self.emotions},
            'stress_score': 0.0,
            'valence': 0.0,
            'face_location': None
        }
162
+
163
# One-time session bootstrap: the analyzer and the rolling histories must
# survive Streamlit reruns, so they live in st.session_state. Factories are
# lazy so the analyzer is only constructed when actually missing.
for _key, _factory in (
    ('analyzer', lambda: EmotionAnalyzer(model, config)),
    ('emotion_history', lambda: deque(maxlen=100)),
    ('stress_scores', lambda: deque(maxlen=100)),
):
    if _key not in st.session_state:
        st.session_state[_key] = _factory()
170
+
171
# ============================================================================
# UI
# ============================================================================

st.markdown('<h1 class="title-gradient">🧠 MindSense AI</h1>', unsafe_allow_html=True)
st.markdown('<p class="subtitle">Real-Time Emotion Recognition & Mental Health Assessment</p>', unsafe_allow_html=True)

# Sidebar
with st.sidebar:
    st.markdown("### ⚙️ Settings")
    # Minimum model confidence before a frame's emotion is recorded in the
    # rolling history (used inside the webcam video callback below).
    confidence_threshold = st.slider("Confidence Threshold", 0.0, 1.0, 0.5, 0.05)
    show_all_emotions = st.checkbox("Show All Emotions", value=True)

    st.markdown("---")
    st.markdown("### 📊 Model Info")
    # Values come from model_config.json; .get defaults keep this safe if a
    # field is missing.
    st.info(f"""
    **Architecture:** Custom EfficientNet-CNN

    **Parameters:** {config.get('total_params', 0) / 1e6:.2f}M

    **Accuracy:** {config.get('best_val_acc', 0):.2f}%

    **Trained on:** FER2013 (28k images)
    """)

# Main content
tab1, tab2 = st.tabs(["🎥 Live Webcam", "📤 Upload Image"])
198
+
199
with tab1:
    col1, col2 = st.columns([2, 1])

    with col1:
        st.markdown("### Live Analysis")

        # Public STUN server so WebRTC can negotiate a peer connection when
        # the app is hosted behind NAT.
        rtc_config = RTCConfiguration(
            {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
        )

        class VideoProcessor:
            # Per-connection frame callback used by streamlit_webrtc.
            def __init__(self):
                self.frame_count = 0

            def recv(self, frame):
                img = frame.to_ndarray(format="bgr24")
                self.frame_count += 1

                # Only analyze every 3rd frame to keep the stream responsive.
                if self.frame_count % 3 == 0:
                    result = st.session_state.analyzer.analyze_frame(img)

                    if result['face_location']:
                        x, y, w, h = result['face_location']
                        emotion = result['dominant_emotion']
                        confidence = result['confidence']

                        # BGR overlay colors per emotion.
                        color_map = {
                            'happy': (0, 255, 0), 'sad': (255, 0, 0),
                            'angry': (0, 0, 255), 'fear': (128, 0, 128),
                            'surprise': (255, 255, 0), 'neutral': (128, 128, 128),
                            'disgust': (0, 128, 128)
                        }
                        color = color_map.get(emotion, (255, 255, 255))

                        cv2.rectangle(img, (x, y), (x+w, y+h), color, 2)
                        label = f"{emotion.upper()} ({confidence:.0%})"
                        cv2.putText(img, label, (x, y-10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)

                        # NOTE(review): recv() runs on a streamlit_webrtc worker
                        # thread; accessing st.session_state and the captured
                        # confidence_threshold from here may not behave reliably
                        # off the main script thread — confirm against the
                        # streamlit_webrtc documentation.
                        if confidence > confidence_threshold:
                            st.session_state.emotion_history.append(emotion)
                            st.session_state.stress_scores.append(result['stress_score'])

                return av.VideoFrame.from_ndarray(img, format="bgr24")

        webrtc_ctx = webrtc_streamer(
            key="emotion-detection",
            mode=WebRtcMode.SENDRECV,
            rtc_configuration=rtc_config,
            video_processor_factory=VideoProcessor,
            media_stream_constraints={"video": True, "audio": False},
            async_processing=True
        )

    with col2:
        st.markdown("### 📊 Live Metrics")

        if len(st.session_state.emotion_history) > 0:
            current_emotion = st.session_state.emotion_history[-1]
            # Rolling stress estimate over the last 10 analyzed frames.
            avg_stress = np.mean(list(st.session_state.stress_scores)[-10:])

            emotion_emoji = {
                'happy': '😊', 'sad': '😢', 'angry': '😠',
                'fear': '😨', 'surprise': '😮', 'neutral': '😐',
                'disgust': '🤢'
            }

            st.markdown(f"## {emotion_emoji.get(current_emotion, '😐')} {current_emotion.title()}")
            st.metric("Stress Level", f"{avg_stress:.1%}")
            # NOTE(review): st.progress expects a value in [0, 1]; this assumes
            # the model's stress output is already in that range — confirm.
            st.progress(avg_stress)

            if show_all_emotions:
                st.markdown("#### All Emotions")
                # NOTE(review): this analyzes an all-black dummy frame, so the
                # probabilities shown reflect a blank image, not the live feed —
                # consider caching the last real analysis result instead.
                result = st.session_state.analyzer.analyze_frame(np.zeros((100, 100, 3), dtype=np.uint8))
                for emotion, prob in sorted(result['all_emotions'].items(), key=lambda x: x[1], reverse=True):
                    st.text(f"{emotion.title()}: {prob:.1%}")
        else:
            st.info("👋 Start webcam to begin")
276
+
277
with tab2:
    st.markdown("### Upload an Image")
    uploaded_file = st.file_uploader("Choose an image...", type=['jpg', 'jpeg', 'png'])

    if uploaded_file:
        image = Image.open(uploaded_file)
        # Fix: PNG uploads can be RGBA (or palette/grayscale). cv2.cvtColor
        # with COLOR_RGB2BGR requires exactly 3 channels, so normalize the
        # PIL image to RGB before converting to a numpy array.
        image_np = np.array(image.convert("RGB"))

        col1, col2 = st.columns(2)

        with col1:
            # use_container_width replaces the deprecated use_column_width.
            st.image(image, caption="Uploaded Image", use_container_width=True)

        with col2:
            # The analyzer expects BGR (OpenCV convention).
            result = st.session_state.analyzer.analyze_frame(cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR))

            st.markdown("### 🎭 Analysis Results")
            st.markdown(f"**Emotion:** {result['dominant_emotion'].title()}")
            st.markdown(f"**Confidence:** {result['confidence']:.1%}")
            st.markdown(f"**Stress:** {result['stress_score']:.1%}")
            st.markdown(f"**Valence:** {result['valence']:.2f}")

            if show_all_emotions:
                st.markdown("#### Emotion Distribution")
                # One progress bar per emotion, highest probability first.
                for emotion, prob in sorted(result['all_emotions'].items(), key=lambda x: x[1], reverse=True):
                    st.progress(prob)
                    st.caption(f"{emotion.title()}: {prob:.1%}")
304
+
305
# Visualizations
# The dashboard renders only once enough history has accumulated (>10 samples).
if len(st.session_state.emotion_history) > 10:
    st.markdown("---")
    st.markdown("### 📈 Analysis Dashboard")

    col1, col2 = st.columns(2)

    with col1:
        from collections import Counter  # local import: only used for this chart
        emotion_counts = Counter(st.session_state.emotion_history)

        # Donut chart of how often each emotion was observed.
        fig = go.Figure(data=[go.Pie(
            labels=list(emotion_counts.keys()),
            values=list(emotion_counts.values()),
            hole=0.4
        )])
        fig.update_layout(title="Emotion Distribution", height=300)
        st.plotly_chart(fig, use_container_width=True)

    with col2:
        # Stress scores over time (deque keeps at most the last 100 samples).
        fig = go.Figure()
        fig.add_trace(go.Scatter(
            y=list(st.session_state.stress_scores),
            mode='lines',
            fill='tozeroy',
            line=dict(color='#667eea', width=2)
        ))
        fig.update_layout(title="Stress Timeline", height=300, yaxis_range=[0, 1])
        st.plotly_chart(fig, use_container_width=True)
334
 
335
# Footer
st.markdown("---")
# Static HTML footer, hoisted into a constant for readability.
_FOOTER_HTML = """
<div style='text-align:center; color:rgba(255,255,255,0.7);'>
    <p>🧠 MindSense AI | Built with PyTorch & Streamlit</p>
    <p>⚠️ <strong>Disclaimer:</strong> Research tool only. Not for medical diagnosis.</p>
</div>
"""
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)