# Emotion_AI / src/streamlit_app.py
# Last commit by Arko007: "Update src/streamlit_app.py" (ab232bc, verified)
import streamlit as st
from streamlit_webrtc import webrtc_streamer, WebRtcMode, RTCConfiguration
import av
import cv2
import numpy as np
import torch
import torch.nn.functional as F
import json
from huggingface_hub import hf_hub_download
from collections import deque
import plotly.graph_objects as go
from PIL import Image
# Page config — issued before any other Streamlit command, as older Streamlit
# versions require. The icon is the 🧠 emoji (the previous literal was a
# mis-decoded UTF-8 byte sequence, not a valid emoji).
st.set_page_config(
    page_title="MindSense AI | Emotion Recognition",
    page_icon="🧠",
    layout="wide",
)
# Custom CSS: Inter font, page gradient, the `title-gradient` / `subtitle`
# classes used by the header markup, a `metric-card` class and larger
# st.metric values. Kept in a named constant so the markup is easy to find.
_CUSTOM_CSS = """
<style>
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600;700&display=swap');
* { font-family: 'Inter', sans-serif; }
.main {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
}
.title-gradient {
background: linear-gradient(90deg, #667eea 0%, #764ba2 50%, #f093fb 100%);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
font-size: 3rem;
font-weight: 800;
text-align: center;
margin-bottom: 10px;
}
.subtitle {
text-align: center;
color: rgba(255, 255, 255, 0.9);
font-size: 1.1rem;
margin-bottom: 30px;
}
.metric-card {
background: rgba(255, 255, 255, 0.1);
backdrop-filter: blur(20px);
border: 1px solid rgba(255, 255, 255, 0.2);
border-radius: 15px;
padding: 20px;
margin: 10px 0;
}
div[data-testid="stMetricValue"] {
font-size: 1.8rem;
font-weight: 700;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# ============================================================================
# Load Model from HuggingFace Hub
# ============================================================================
@st.cache_resource
def load_model():
    """Download the TorchScript model and its JSON config from the HF Hub.

    Returns:
        ``(model, config)`` on success: ``model`` is the TorchScript module
        loaded on CPU and switched to eval mode, ``config`` is the parsed
        ``model_config.json``. ``(None, None)`` if any step raises; the error
        is also shown in the UI.
    """
    repo_id = "Arko007/mindsense-emotion-model"
    with st.spinner("🧠 Loading AI model..."):
        try:
            model_path = hf_hub_download(repo_id=repo_id, filename="mindsense_emotion_model.pt")
            config_path = hf_hub_download(repo_id=repo_id, filename="model_config.json")
            # JSON is UTF-8 by spec; don't depend on the platform default encoding.
            with open(config_path, 'r', encoding='utf-8') as f:
                config = json.load(f)
            model = torch.jit.load(model_path, map_location='cpu')
            model.eval()
            return model, config
        except Exception as e:
            # Top-level boundary: report and let the caller stop the app.
            st.error(f"❌ Error loading model: {e}")
            return None, None


model, config = load_model()
if model is None:
    # st.cache_resource would otherwise memoise the (None, None) failure for the
    # lifetime of the server process, so a transient download error could never
    # be retried without a restart. Drop it so the next rerun tries again.
    load_model.clear()
    st.error("Failed to load model. Please check the repository.")
    st.stop()
st.success(f"✅ Model loaded! Accuracy: {config.get('best_val_acc', 0):.2f}%")
# ============================================================================
# Emotion Analyzer
# ============================================================================
class EmotionAnalyzer:
    """Detect the largest face in a BGR frame and run the emotion model on it.

    Args:
        model: TorchScript module called as ``model(batch)`` and returning
            ``(emotion_logits, stress_pred, valence_pred)``.
        config: Parsed ``model_config.json``. Must provide ``classes`` (labels in
            the same order as the logits) plus per-channel ``mean`` and ``std``
            (3 values each) used to normalise the input.
        input_size: Side length in pixels that the face crop is resized to
            before inference. Defaults to 384, the previously hard-coded size.
    """

    def __init__(self, model, config, input_size=384):
        self.model = model
        self.config = config
        self.emotions = config['classes']
        self.mean = np.array(config['mean']).reshape(3, 1, 1)
        self.std = np.array(config['std']).reshape(3, 1, 1)
        # Build the normalisation tensors once instead of on every frame.
        self._mean_t = torch.from_numpy(self.mean).float()
        self._std_t = torch.from_numpy(self.std).float()
        self.input_size = input_size
        self.face_cascade = cv2.CascadeClassifier(
            cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
        )

    @torch.no_grad()
    def analyze_frame(self, frame):
        """Analyse one frame and return the emotion result for its largest face.

        Args:
            frame: Image array in OpenCV BGR channel order (both callers in this
                app pass uint8 images — TODO confirm for any new caller).

        Returns:
            Dict with ``dominant_emotion``, ``confidence``, ``all_emotions``
            (label -> probability), ``stress_score``, ``valence`` and
            ``face_location`` as a ``(x, y, w, h)`` tuple of plain ints.
            If no face is found, or any step raises, the result of
            ``_default_result()`` is returned (``face_location`` is ``None``).
        """
        try:
            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            faces = self.face_cascade.detectMultiScale(gray, 1.3, 5)
            if len(faces) == 0:
                return self._default_result()
            # Analyse only the largest detection (by box area). Cast the numpy
            # ints to plain ints so callers can hand them straight to cv2 drawing.
            x, y, w, h = (int(v) for v in max(faces, key=lambda f: f[2] * f[3]))
            face_roi = frame[y:y + h, x:x + w]

            # Preprocess: BGR -> RGB, resize, scale by 1/255, normalise, add batch dim.
            face_rgb = cv2.cvtColor(face_roi, cv2.COLOR_BGR2RGB)
            face_resized = cv2.resize(face_rgb, (self.input_size, self.input_size))
            img_tensor = torch.from_numpy(face_resized).float().permute(2, 0, 1) / 255.0
            img_tensor = ((img_tensor - self._mean_t) / self._std_t).unsqueeze(0)

            # Inference
            emotion_logits, stress_pred, valence_pred = self.model(img_tensor)
            emotion_probs = F.softmax(emotion_logits, dim=1)[0].numpy()
            emotion_idx = int(np.argmax(emotion_probs))
            return {
                'dominant_emotion': self.emotions[emotion_idx],
                'confidence': float(emotion_probs[emotion_idx]),
                'all_emotions': {e: float(p) for e, p in zip(self.emotions, emotion_probs)},
                'stress_score': float(stress_pred.item()),
                'valence': float(valence_pred.item()),
                'face_location': (x, y, w, h),
            }
        except Exception:
            # Best-effort by design: a bad frame or model error must not kill the
            # video stream, so fall back to the neutral "no face" result.
            return self._default_result()

    def _default_result(self):
        """Return the neutral, zero-confidence result used when no face is analysed."""
        return {
            'dominant_emotion': 'neutral',
            'confidence': 0.0,
            'all_emotions': {e: 0.0 for e in self.emotions},
            'stress_score': 0.0,
            'valence': 0.0,
            'face_location': None,
        }
# Per-browser-session state, created only on the first run of the script:
# the analyzer plus rolling (last 100) emotion and stress histories.
_session_defaults = {
    'analyzer': lambda: EmotionAnalyzer(model, config),
    'emotion_history': lambda: deque(maxlen=100),
    'stress_scores': lambda: deque(maxlen=100),
}
for _state_key, _make_default in _session_defaults.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
# ============================================================================
# UI
# ============================================================================
st.markdown('<h1 class="title-gradient">🧠 MindSense AI</h1>', unsafe_allow_html=True)
st.markdown('<p class="subtitle">Real-Time Emotion Recognition & Mental Health Assessment</p>', unsafe_allow_html=True)

# Sidebar: user settings (read by the tabs below) and static model info.
with st.sidebar:
    st.markdown("### ⚙️ Settings")
    confidence_threshold = st.slider("Confidence Threshold", 0.0, 1.0, 0.5, 0.05)
    show_all_emotions = st.checkbox("Show All Emotions", value=True)
    st.markdown("---")
    st.markdown("### 📊 Model Info")
    # Markdown merges adjacent lines into a single paragraph (and an indented
    # triple-quoted string would render as a code block), so each entry ends
    # with an explicit two-space "  \n" line break.
    st.info(
        "**Architecture:** Custom EfficientNet-CNN  \n"
        f"**Parameters:** {config.get('total_params', 0) / 1e6:.2f}M  \n"
        f"**Accuracy:** {config.get('best_val_acc', 0):.2f}%  \n"
        "**Trained on:** FER2013 (28k images)"
    )
# Main content
tab1, tab2 = st.tabs(["🎥 Live Webcam", "📤 Upload Image"])
with tab1:
    col1, col2 = st.columns([2, 1])
    with col1:
        st.markdown("### Live Analysis")
        rtc_config = RTCConfiguration(
            {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
        )

        class VideoProcessor:
            """Draw the detected face and its emotion onto each webcam frame.

            streamlit-webrtc calls ``recv`` from its own worker thread, where
            ``st.session_state`` and widget values are not available. So every
            dependency is handed in by the factory, which is built on the
            script thread. The script thread also pushes slider changes onto
            ``confidence_threshold`` after each rerun.
            """

            # Box/label colour per emotion, in BGR order because frames are bgr24.
            COLOR_MAP = {
                'happy': (0, 255, 0), 'sad': (255, 0, 0),
                'angry': (0, 0, 255), 'fear': (128, 0, 128),
                'surprise': (255, 255, 0), 'neutral': (128, 128, 128),
                'disgust': (0, 128, 128),
            }

            def __init__(self, analyzer, emotion_history, stress_scores, confidence_threshold):
                self.analyzer = analyzer
                self.emotion_history = emotion_history
                self.stress_scores = stress_scores
                self.confidence_threshold = confidence_threshold
                # Latest analyze_frame() result; read by the script thread for "All Emotions".
                self.latest_result = None
                self.frame_count = 0

            def recv(self, frame):
                """Annotate one ``av.VideoFrame`` and return the annotated copy."""
                img = frame.to_ndarray(format="bgr24")
                self.frame_count += 1
                # Run the model on every 3rd frame only, but redraw the latest
                # result on every frame so the overlay does not flicker.
                if self.frame_count % 3 == 0:
                    result = self.analyzer.analyze_frame(img)
                    self.latest_result = result
                    if (result['face_location'] is not None
                            and result['confidence'] > self.confidence_threshold):
                        # These are the session's deques, shared with the script thread.
                        self.emotion_history.append(result['dominant_emotion'])
                        self.stress_scores.append(result['stress_score'])

                result = self.latest_result
                if result is not None and result['face_location'] is not None:
                    x, y, w, h = (int(v) for v in result['face_location'])
                    emotion = result['dominant_emotion']
                    color = self.COLOR_MAP.get(emotion, (255, 255, 255))
                    cv2.rectangle(img, (x, y), (x + w, y + h), color, 2)
                    label = f"{emotion.upper()} ({result['confidence']:.0%})"
                    # Keep the label on-screen when the face touches the top edge.
                    cv2.putText(img, label, (x, max(y - 10, 20)),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)
                return av.VideoFrame.from_ndarray(img, format="bgr24")

        # Capture the session objects here, on the script thread. The factory is
        # only invoked later by streamlit-webrtc, outside the script run.
        processor_args = (
            st.session_state.analyzer,
            st.session_state.emotion_history,
            st.session_state.stress_scores,
            confidence_threshold,
        )
        webrtc_ctx = webrtc_streamer(
            key="emotion-detection",
            mode=WebRtcMode.SENDRECV,
            rtc_configuration=rtc_config,
            video_processor_factory=lambda args=processor_args: VideoProcessor(*args),
            media_stream_constraints={"video": True, "audio": False},
            async_processing=True,
        )
        if webrtc_ctx.video_processor:
            # The running processor persists across reruns; keep its threshold
            # in sync with the sidebar slider.
            webrtc_ctx.video_processor.confidence_threshold = confidence_threshold

    with col2:
        st.markdown("### 📊 Live Metrics")
        # Snapshot the shared deques; the webcam worker may append while we read.
        history = list(st.session_state.emotion_history)
        recent_stress = list(st.session_state.stress_scores)[-10:]
        if history and recent_stress:
            current_emotion = history[-1]
            avg_stress = float(np.mean(recent_stress))
            emotion_emoji = {
                'happy': '😊', 'sad': '😢', 'angry': '😠',
                'fear': '😨', 'surprise': '😮', 'neutral': '😐',
                'disgust': '🤢',
            }
            st.markdown(f"## {emotion_emoji.get(current_emotion, '😐')} {current_emotion.title()}")
            st.metric("Stress Level", f"{avg_stress:.1%}")
            # st.progress rejects values outside [0, 1], and nothing here bounds
            # the model's stress output, so clamp before displaying.
            st.progress(min(max(avg_stress, 0.0), 1.0))
            processor = webrtc_ctx.video_processor
            latest = processor.latest_result if processor else None
            if show_all_emotions and latest is not None:
                st.markdown("#### All Emotions")
                for emotion, prob in sorted(latest['all_emotions'].items(), key=lambda kv: kv[1], reverse=True):
                    st.text(f"{emotion.title()}: {prob:.1%}")
        else:
            st.info("👋 Start webcam to begin")
with tab2:
    st.markdown("### Upload an Image")
    uploaded_file = st.file_uploader("Choose an image...", type=['jpg', 'jpeg', 'png'])
    if uploaded_file:
        # Normalise to 3-channel RGB. Uploaded PNGs are often greyscale ("L"),
        # palette ("P") or RGBA, and cv2.cvtColor(..., COLOR_RGB2BGR) fails on
        # the 2-D arrays produced by single-channel modes.
        image = Image.open(uploaded_file).convert("RGB")
        image_np = np.array(image)
        col1, col2 = st.columns(2)
        with col1:
            st.image(image, caption="Uploaded Image", use_column_width=True)
        with col2:
            result = st.session_state.analyzer.analyze_frame(cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR))
            st.markdown("### 🎭 Analysis Results")
            if result['face_location'] is None:
                # Otherwise the zero-confidence default would look like a real "Neutral".
                st.warning("No face detected — showing default values.")
            st.markdown(f"**Emotion:** {result['dominant_emotion'].title()}")
            st.markdown(f"**Confidence:** {result['confidence']:.1%}")
            st.markdown(f"**Stress:** {result['stress_score']:.1%}")
            st.markdown(f"**Valence:** {result['valence']:.2f}")
            if show_all_emotions:
                st.markdown("#### Emotion Distribution")
                for emotion, prob in sorted(result['all_emotions'].items(), key=lambda kv: kv[1], reverse=True):
                    st.progress(prob)
                    st.caption(f"{emotion.title()}: {prob:.1%}")
# Visualizations
# The histories are filled from the webcam callback, which streamlit-webrtc runs
# off the script thread, so take one list snapshot of each and use it throughout.
emotion_snapshot = list(st.session_state.emotion_history)
stress_snapshot = list(st.session_state.stress_scores)
if len(emotion_snapshot) > 10:
    st.markdown("---")
    st.markdown("### 📈 Analysis Dashboard")
    col1, col2 = st.columns(2)
    with col1:
        from collections import Counter
        emotion_counts = Counter(emotion_snapshot)
        fig = go.Figure(data=[go.Pie(
            labels=list(emotion_counts.keys()),
            values=list(emotion_counts.values()),
            hole=0.4,
        )])
        fig.update_layout(title="Emotion Distribution", height=300)
        st.plotly_chart(fig, use_container_width=True)
    with col2:
        fig = go.Figure()
        fig.add_trace(go.Scatter(
            y=stress_snapshot,
            mode='lines',
            fill='tozeroy',
            line=dict(color='#667eea', width=2),
        ))
        fig.update_layout(title="Stress Timeline", height=300, yaxis_range=[0, 1])
        st.plotly_chart(fig, use_container_width=True)
# Footer: attribution and non-medical-use disclaimer.
st.markdown("---")
st.markdown("""
<div style='text-align:center; color:rgba(255,255,255,0.7);'>
<p>🧠 MindSense AI | Built with PyTorch & Streamlit</p>
<p>⚠️ <strong>Disclaimer:</strong> Research tool only. Not for medical diagnosis.</p>
</div>
""", unsafe_allow_html=True)