| """ | |
| Streamlit Dashboard for Emotion Recognition System. | |
| """ | |
| import io | |
| import sys | |
| from pathlib import Path | |
| import streamlit as st | |
| import numpy as np | |
| import pandas as pd | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| from PIL import Image | |
| # Add project root to path | |
| sys.path.insert(0, str(Path(__file__).parent.parent)) | |
| from src.config import EMOTION_CLASSES, MODELS_DIR | |
| from src.inference.predictor import EmotionPredictor | |
# Page configuration -- must be the first Streamlit command executed.
st.set_page_config(
    page_title="Emotion Recognition Dashboard",
    page_icon="😊",
    layout="wide",
    initial_sidebar_state="expanded"
)
# Custom CSS injected once per run: gradient page header, emotion result
# card, confidence color classes (high/medium/low), and wider tab spacing.
st.markdown("""
<style>
    .main-header {
        font-size: 2.5rem;
        font-weight: bold;
        background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
        -webkit-background-clip: text;
        -webkit-text-fill-color: transparent;
        text-align: center;
        margin-bottom: 1rem;
    }
    .emotion-card {
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        padding: 1.5rem;
        border-radius: 1rem;
        color: white;
        text-align: center;
        box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
    }
    .confidence-high {
        color: #10b981;
        font-weight: bold;
    }
    .confidence-medium {
        color: #f59e0b;
        font-weight: bold;
    }
    .confidence-low {
        color: #ef4444;
        font-weight: bold;
    }
    .stTabs [data-baseweb="tab-list"] {
        gap: 2rem;
    }
    .stTabs [data-baseweb="tab"] {
        height: 50px;
        padding-left: 20px;
        padding-right: 20px;
    }
</style>
""", unsafe_allow_html=True)
# Emotion emoji mapping: one emoji per supported emotion class, shown in
# chart labels, the result card, and the batch-results table.
EMOTION_EMOJIS = dict(
    angry="😠",
    disgusted="🤢",
    fearful="😨",
    happy="😊",
    neutral="😐",
    sad="😢",
    surprised="😲",
)
# Color palette for emotions: one hex color per emotion class, used to
# color chart bars and pie slices consistently across the dashboard.
EMOTION_COLORS = dict(
    angry="#ef4444",
    disgusted="#84cc16",
    fearful="#a855f7",
    happy="#22c55e",
    neutral="#6b7280",
    sad="#3b82f6",
    surprised="#f59e0b",
)
@st.cache_resource
def load_predictor(model_name: str):
    """Load and cache the emotion predictor for ``model_name``.

    Fix: the docstring promised caching but none existed. Without
    ``st.cache_resource`` Streamlit re-instantiates the predictor and
    reloads the model weights on every script rerun (i.e. on every widget
    interaction), which is very slow. The decorator keeps one loaded
    predictor per distinct ``model_name``.

    Args:
        model_name: Key of the model to load (e.g. "custom_cnn").

    Returns:
        A loaded EmotionPredictor, or None if loading failed.
        Note: a None (failed load) result is cached too; the sidebar only
        offers models reported as available, so this path should be rare.
    """
    predictor = EmotionPredictor(model_name)
    if predictor.load():
        return predictor
    return None
def get_intensity_class(intensity: str) -> str:
    """Map an intensity label ("high"/"medium"/"low") to its CSS class."""
    return "confidence-" + intensity
def create_probability_chart(probabilities: dict) -> go.Figure:
    """Build a horizontal bar chart of per-emotion probabilities.

    Args:
        probabilities: Mapping of emotion name -> probability (0..1).

    Returns:
        A Plotly figure with one labeled, colored bar per emotion.
    """
    labels = []
    bar_colors = []
    for emotion in probabilities:
        labels.append(f"{EMOTION_EMOJIS.get(emotion, '')} {emotion.capitalize()}")
        bar_colors.append(EMOTION_COLORS.get(emotion, "#6b7280"))
    probs = list(probabilities.values())

    figure = go.Figure(go.Bar(
        x=probs,
        y=labels,
        orientation='h',
        marker_color=bar_colors,
        text=[f"{p:.1%}" for p in probs],
        textposition='outside',
    ))
    figure.update_layout(
        title="Emotion Probabilities",
        xaxis_title="Probability",
        yaxis_title="Emotion",
        height=350,
        margin=dict(l=20, r=20, t=40, b=20),
        # Extend the axis past 1.0 so "outside" text labels are not clipped.
        xaxis=dict(range=[0, 1.1]),
    )
    return figure
def create_emotion_distribution_pie(counts: dict) -> go.Figure:
    """Build a donut chart of how many images fell into each emotion.

    Args:
        counts: Mapping of emotion name -> number of occurrences.

    Returns:
        A Plotly pie (donut) figure; zero-count emotions are omitted.
    """
    # Filter once so labels, values, and colors stay aligned.
    present = [(emotion, count) for emotion, count in counts.items() if count > 0]
    labels = [f"{EMOTION_EMOJIS.get(e, '')} {e.capitalize()}" for e, _ in present]
    slice_colors = [EMOTION_COLORS.get(e, "#6b7280") for e, _ in present]

    figure = go.Figure(go.Pie(
        labels=labels,
        values=[count for _, count in present],
        marker_colors=slice_colors,
        hole=0.4,
        textinfo='percent+label',
    ))
    figure.update_layout(
        title="Emotion Distribution",
        height=400,
        margin=dict(l=20, r=20, t=40, b=20),
    )
    return figure
def main():
    """Render the full dashboard: sidebar settings plus three tabs
    (single-image analysis, batch processing, model performance)."""
    # Header
    st.markdown('<h1 class="main-header">🎭 Emotion Recognition Dashboard</h1>', unsafe_allow_html=True)
    st.markdown("---")

    # Sidebar: model choice, face-detection toggle, threshold, model status.
    with st.sidebar:
        st.image("https://img.icons8.com/clouds/200/brain.png", width=100)
        st.title("⚙️ Settings")

        # Model selection -- offer only models whose weights exist on disk.
        available_models = EmotionPredictor.get_available_models()
        model_options = [name for name, available in available_models.items() if available]

        if not model_options:
            st.error("No trained models found! Please train a model first.")
            st.info("Run: `python scripts/train_models.py`")
            model_name = None
        else:
            model_name = st.selectbox(
                "🤖 Select Model",
                model_options,
                # Map internal model keys to human-readable display names.
                format_func=lambda x: {
                    "custom_cnn": "Custom CNN",
                    "mobilenet": "MobileNetV2",
                    "vgg19": "VGG-19"
                }.get(x, x)
            )

        # Face detection toggle
        detect_face = st.toggle("👤 Enable Face Detection", value=True)

        # Confidence threshold
        # NOTE(review): this slider's value is never read below -- presumably
        # intended to filter low-confidence predictions; confirm and wire up.
        confidence_threshold = st.slider(
            "📊 Confidence Threshold",
            min_value=0.0,
            max_value=1.0,
            value=0.5,
            step=0.05
        )

        st.markdown("---")

        # Model info: availability checklist for every known model.
        st.subheader("📋 Model Status")
        for name, available in available_models.items():
            icon = "✅" if available else "❌"
            display_name = {
                "custom_cnn": "Custom CNN",
                "mobilenet": "MobileNetV2",
                "vgg19": "VGG-19"
            }.get(name, name)
            st.write(f"{icon} {display_name}")

    # Main content -- bail out early if no trained model is available.
    if model_name is None:
        st.warning("Please train a model before using the dashboard.")
        return

    # Load predictor for the selected model.
    predictor = load_predictor(model_name)
    if predictor is None:
        st.error(f"Failed to load model: {model_name}")
        return

    # Tabs
    tab1, tab2, tab3 = st.tabs(["📷 Single Image", "📁 Batch Processing", "📊 Model Performance"])

    # Tab 1: Single Image Analysis
    with tab1:
        st.subheader("Upload an Image for Emotion Analysis")
        col1, col2 = st.columns([1, 1])

        with col1:
            uploaded_file = st.file_uploader(
                "Choose an image...",
                type=["jpg", "jpeg", "png", "bmp"],
                key="single_upload"
            )
            if uploaded_file is not None:
                image = Image.open(uploaded_file)
                st.image(image, caption="Uploaded Image", width="stretch")

        with col2:
            if uploaded_file is not None:
                with st.spinner("Analyzing emotion..."):
                    # Convert to numpy array (RGB) as the predictor expects.
                    image_array = np.array(image.convert("RGB"))
                    # Predict
                    result = predictor.predict(image_array, detect_face=detect_face)

                if "error" in result:
                    st.error(f"❌ {result['error']}")
                else:
                    # Display result
                    emotion = result["emotion"]
                    confidence = result["confidence"]
                    intensity = result["intensity"]

                    # Emotion card
                    st.markdown(f"""
                    <div class="emotion-card">
                        <h1 style="font-size: 4rem; margin: 0;">{EMOTION_EMOJIS.get(emotion, '🎭')}</h1>
                        <h2 style="margin: 0.5rem 0;">{emotion.upper()}</h2>
                        <p style="font-size: 1.2rem;">Confidence: {confidence:.1%}</p>
                        <p>Intensity: {intensity.capitalize()}</p>
                    </div>
                    """, unsafe_allow_html=True)

                    # Probability chart
                    if "all_probabilities" in result:
                        fig = create_probability_chart(result["all_probabilities"])
                        st.plotly_chart(fig, use_container_width=True)

                    # Face detection info
                    # NOTE(review): assumes successful results always carry
                    # the "face_detected" key -- confirm against predictor.
                    if result["face_detected"]:
                        st.success("✅ Face detected successfully")
                    else:
                        st.warning("⚠️ No face detected - using full image")

    # Tab 2: Batch Processing
    with tab2:
        st.subheader("Upload Multiple Images for Batch Analysis")
        uploaded_files = st.file_uploader(
            "Choose images...",
            type=["jpg", "jpeg", "png", "bmp"],
            accept_multiple_files=True,
            key="batch_upload"
        )

        if uploaded_files:
            st.write(f"📁 {len(uploaded_files)} files selected")

            if st.button("🚀 Analyze All", type="primary"):
                progress_bar = st.progress(0)
                status_text = st.empty()
                results = []
                images = []

                for i, file in enumerate(uploaded_files):
                    status_text.text(f"Processing image {i+1}/{len(uploaded_files)}...")
                    progress_bar.progress((i + 1) / len(uploaded_files))
                    try:
                        image = Image.open(file)
                        images.append(image)
                        image_array = np.array(image.convert("RGB"))
                        result = predictor.predict(image_array, detect_face=detect_face)
                        result["filename"] = file.name
                        results.append(result)
                    except Exception as e:
                        # Record the failure but keep processing remaining files.
                        # NOTE(review): if Image.open itself raises, `images` and
                        # `results` fall out of step, skewing the gallery zip below.
                        results.append({"error": str(e), "filename": file.name})

                status_text.text("✅ Analysis complete!")

                # Display results
                col1, col2 = st.columns([1, 1])

                with col1:
                    # Summary statistics
                    successful = [r for r in results if "error" not in r]
                    if successful:
                        # Tally predicted emotions for the distribution chart.
                        emotion_counts = {}
                        for r in successful:
                            emotion = r["emotion"]
                            emotion_counts[emotion] = emotion_counts.get(emotion, 0) + 1
                        # Pie chart
                        fig = create_emotion_distribution_pie(emotion_counts)
                        st.plotly_chart(fig, use_container_width=True)
                    st.metric("Total Images", len(results))
                    st.metric("Successful", len(successful))
                    st.metric("Failed", len(results) - len(successful))

                with col2:
                    # Results table -- one row per file, error rows marked.
                    table_data = []
                    for r in results:
                        if "error" in r:
                            table_data.append({
                                "File": r.get("filename", "Unknown"),
                                "Emotion": "❌ Error",
                                "Confidence": "-",
                                "Intensity": "-"
                            })
                        else:
                            table_data.append({
                                "File": r.get("filename", "Unknown"),
                                "Emotion": f"{EMOTION_EMOJIS.get(r['emotion'], '')} {r['emotion'].capitalize()}",
                                "Confidence": f"{r['confidence']:.1%}",
                                "Intensity": r["intensity"].capitalize()
                            })
                    df = pd.DataFrame(table_data)
                    st.dataframe(df, use_container_width=True, height=400)

                    # Download button
                    csv = df.to_csv(index=False)
                    st.download_button(
                        "📥 Download Results (CSV)",
                        csv,
                        "emotion_results.csv",
                        "text/csv"
                    )

                # Image gallery with predictions, four thumbnails per row.
                st.subheader("📷 Analyzed Images")
                cols = st.columns(4)
                for i, (img, result) in enumerate(zip(images, results)):
                    with cols[i % 4]:
                        if "error" not in result:
                            emoji = EMOTION_EMOJIS.get(result["emotion"], "")
                            st.image(img, caption=f"{emoji} {result['emotion']}", width="stretch")
                        else:
                            st.image(img, caption="❌ Error", width="stretch")

    # Tab 3: Model Performance
    with tab3:
        st.subheader("📊 Model Performance Metrics")

        # Check for saved metrics written by the training pipeline.
        metrics_path = MODELS_DIR / f"{model_name}.meta.json"
        history_path = MODELS_DIR / f"{model_name}.history.json"

        if metrics_path.exists():
            import json
            with open(metrics_path, 'r') as f:
                metadata = json.load(f)

            col1, col2, col3 = st.columns(3)
            with col1:
                st.metric(
                    "Best Validation Accuracy",
                    f"{metadata.get('best_val_accuracy', 0):.1%}"
                )
            with col2:
                st.metric(
                    "Training Duration",
                    # Metadata stores seconds; displayed in minutes.
                    f"{metadata.get('training_duration_seconds', 0)/60:.1f} min"
                )
            with col3:
                st.metric(
                    "Epochs Completed",
                    metadata.get('epochs_completed', 0)
                )

            if history_path.exists():
                with open(history_path, 'r') as f:
                    history = json.load(f)

                # Training curves (accuracy per epoch, train vs validation).
                fig = go.Figure()
                epochs = list(range(1, len(history['accuracy']) + 1))
                fig.add_trace(go.Scatter(
                    x=epochs, y=history['accuracy'],
                    mode='lines', name='Training Accuracy',
                    line=dict(color='#3b82f6')
                ))
                fig.add_trace(go.Scatter(
                    x=epochs, y=history['val_accuracy'],
                    mode='lines', name='Validation Accuracy',
                    line=dict(color='#ef4444')
                ))
                fig.update_layout(
                    title="Training History",
                    xaxis_title="Epoch",
                    yaxis_title="Accuracy",
                    height=400
                )
                st.plotly_chart(fig, use_container_width=True)

                # Loss curves (loss per epoch, train vs validation).
                fig2 = go.Figure()
                fig2.add_trace(go.Scatter(
                    x=epochs, y=history['loss'],
                    mode='lines', name='Training Loss',
                    line=dict(color='#3b82f6')
                ))
                fig2.add_trace(go.Scatter(
                    x=epochs, y=history['val_loss'],
                    mode='lines', name='Validation Loss',
                    line=dict(color='#ef4444')
                ))
                fig2.update_layout(
                    title="Loss History",
                    xaxis_title="Epoch",
                    yaxis_title="Loss",
                    height=400
                )
                st.plotly_chart(fig2, use_container_width=True)
        else:
            st.info("No training metrics found for this model. Train the model to see performance data.")
            # Show placeholder table of expected results until metrics exist.
            st.markdown("""
            ### Expected Metrics After Training
            | Model | Expected Accuracy | Training Time |
            |-------|------------------|---------------|
            | Custom CNN | 60-68% | ~30 min |
            | MobileNetV2 | 65-72% | ~45 min |
            | VGG-19 | 68-75% | ~60 min |
            """)
# Entry point when executed directly (e.g. `streamlit run` this file).
if __name__ == "__main__":
    main()