import torch import torch.nn as nn import torch.nn.functional as F from torchvision import transforms import cv2 import streamlit as st from PIL import Image import numpy as np import time from streamlit_webrtc import webrtc_streamer, VideoProcessorBase, RTCConfiguration import av # Définition du modèle CNN class EmotionCNN(nn.Module): def __init__(self): super(EmotionCNN, self).__init__() self.conv_layers = nn.Sequential( nn.Conv2d(1, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(), nn.Conv2d(32, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2), nn.Dropout2d(0.25), nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(), nn.Conv2d(64, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2), nn.Dropout2d(0.25), nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(), nn.Conv2d(128, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(), nn.MaxPool2d(2), nn.Dropout2d(0.25) ) self.fc_layers = nn.Sequential( nn.Linear(128 * 6 * 6, 512), nn.ReLU(), nn.Dropout(0.5), nn.Linear(512, 256), nn.ReLU(), nn.Dropout(0.5), nn.Linear(256, 8) ) def forward(self, x): x = self.conv_layers(x) x = x.view(x.size(0), -1) x = self.fc_layers(x) return x # Dictionnaire des émotions emotion_dict = { 0: {"name": "Colère", "message": "Respirez profondément et prenez un moment pour vous calmer."}, 1: {"name": "Mépris", "message": "Essayez de voir les choses d'un autre point de vue."}, 2: {"name": "Dégoût", "message": "Concentrez-vous sur les aspects positifs de la situation."}, 3: {"name": "Peur", "message": "Vous êtes en sécurité, prenez votre temps pour vous apaiser."}, 4: {"name": "Bonheur", "message": "Votre sourire illumine la pièce ! Continuez ainsi !"}, 5: {"name": "Neutre", "message": "Vous semblez calme et posé."}, 6: {"name": "Tristesse", "message": "Chaque jour est une nouvelle opportunité. Gardez espoir !"}, 7: {"name": "Surprise", "message": "La vie est pleine de surprises positives !"} } # Configuration de la page Streamlit st.set_page_config(page_title="Détecteur d'Émotions", layout="wide") # Styles CSS personnalisés st.markdown(""" """, unsafe_allow_html=True) # Titre de l'application st.title("🎭 Détecteur d'Émotions en Temps Réel") # Initialisation du modèle @st.cache_resource def load_model(): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = EmotionCNN().to(device) try: model.load_state_dict(torch.load("cnn_emotion_model.pth", map_location=device)) model.eval() except Exception as e: st.error(f"Erreur lors du chargement du modèle : {str(e)}") st.stop() return model, device # Chargement du modèle model, device = load_model() # Transformations pour l'image transform = transforms.Compose([ transforms.Grayscale(num_output_channels=1), transforms.Resize((48, 48)), transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.5]) ]) # Chargement du classificateur Haar pour la détection de visage face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml') if face_cascade.empty(): st.error("Erreur : Impossible de charger le classificateur Haar pour la détection de visage.") st.stop() def detect_faces(frame): gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) faces = face_cascade.detectMultiScale(gray, 1.1, 4) return faces # Configuration RTC avec plusieurs STUN et TURN RTC_CONFIGURATION = RTCConfiguration({ "iceServers": [ {"urls": "stun:stun.l.google.com:19302"}, {"urls": "stun:stun1.l.google.com:19302"}, {"urls": "stun:stun2.l.google.com:19302"}, {"urls": "stun:stun3.l.google.com:19302"}, {"urls": "stun:stun4.l.google.com:19302"}, {"urls": "stun:stun.stunprotocol.org:3478"}, # Exemple de configuration TURN (remplacez par vos propres identifiants si disponible) { "urls": "turn:your-turn-server.example.com:3478", "username": "your-username", "credential": "your-password" } ] }) # Classe pour traiter les frames vidéo class VideoProcessor(VideoProcessorBase): def __init__(self): self.model = model self.device = device self.transform = transform self.face_cascade = face_cascade self.emotion_dict = emotion_dict self.emotion_placeholder = st.session_state.get('emotion_placeholder') self.message_placeholder = st.session_state.get('message_placeholder') def recv(self, frame): try: img = frame.to_ndarray(format="bgr24") faces = detect_faces(img) for (x, y, w, h) in faces: cv2.rectangle(img, (x, y), (x+w, y+h), (0, 255, 0), 2) face_img = img[y:y+h, x:x+w] pil_img = Image.fromarray(cv2.cvtColor(face_img, cv2.COLOR_BGR2RGB)) img_tensor = self.transform(pil_img).unsqueeze(0).to(self.device) with torch.no_grad(): output = self.model(img_tensor) _, predicted = torch.max(output, 1) emotion_idx = predicted.item() emotion_name = self.emotion_dict[emotion_idx]["name"] cv2.putText(img, emotion_name, (x, y-10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2) # Mettre à jour les placeholders if self.emotion_placeholder: self.emotion_placeholder.markdown(f"""
{emotion_name}
""", unsafe_allow_html=True) if self.message_placeholder: self.message_placeholder.markdown(f"""
{self.emotion_dict[emotion_idx]["message"]}
""", unsafe_allow_html=True) return av.VideoFrame.from_ndarray(img, format="bgr24") except Exception as e: st.error(f"Erreur lors du traitement de la frame : {str(e)}") return frame # Configuration de l'interface col1, col2 = st.columns([2, 1]) with col1: st.markdown("### 📹 Flux Vidéo") try: webrtc_ctx = webrtc_streamer( key="emotion-detection", rtc_configuration=RTC_CONFIGURATION, video_processor_factory=VideoProcessor, media_stream_constraints={"video": True, "audio": False}, async_processing=True ) except Exception as e: st.error(f"Erreur lors de l'initialisation de WebRTC : {str(e)}") st.warning("Vérifiez votre connexion réseau ou les paramètres STUN/TURN.") with col2: st.markdown("### 😊 Émotion Détectée") if 'emotion_placeholder' not in st.session_state: st.session_state.emotion_placeholder = st.empty() if 'message_placeholder' not in st.session_state: st.session_state.message_placeholder = st.empty() emotion_placeholder = st.session_state.emotion_placeholder message_placeholder = st.session_state.message_placeholder st.info("👆 Autorisez l'accès à la webcam dans votre navigateur pour démarrer la détection d'émotions.") st.warning("Si la connexion échoue, vérifiez votre réseau ou configurez un serveur TURN pour WebRTC.") # Option de téléchargement d'image comme solution de secours st.markdown("### 📷 Ou téléchargez une image") uploaded_file = st.file_uploader("Choisissez une image...", type=["jpg", "jpeg", "png"]) if uploaded_file is not None: image = Image.open(uploaded_file) frame = np.array(image) faces = detect_faces(frame) for (x, y, w, h) in faces: cv2.rectangle(frame, (x, y), (x+w, y+h), (0, 255, 0), 2) face_img = frame[y:y+h, x:x+w] pil_img = Image.fromarray(cv2.cvtColor(face_img, cv2.COLOR_BGR2RGB)) img_tensor = transform(pil_img).unsqueeze(0).to(device) with torch.no_grad(): output = model(img_tensor) _, predicted = torch.max(output, 1) emotion_idx = predicted.item() emotion_name = emotion_dict[emotion_idx]["name"] cv2.putText(frame, emotion_name, (x, y-10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2) emotion_placeholder.markdown(f"""
{emotion_name}
""", unsafe_allow_html=True) message_placeholder.markdown(f"""
{emotion_dict[emotion_idx]["message"]}
""", unsafe_allow_html=True) st.image(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))