# app.py — Real-time facial emotion detection (Streamlit + WebRTC + PyTorch CNN)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import cv2
import streamlit as st
from PIL import Image
import numpy as np
import time
from streamlit_webrtc import webrtc_streamer, VideoProcessorBase, RTCConfiguration
import av
class EmotionCNN(nn.Module):
    """CNN classifying 48x48 grayscale face crops into 8 emotion classes.

    The three convolutional stages share one shape (conv-BN-ReLU twice,
    then pool + dropout), so they are generated in a loop. The resulting
    ``nn.Sequential`` module order — and therefore the state-dict keys
    (``conv_layers.0.weight`` …) — is identical to the checkpoint's.
    """

    def __init__(self):
        super().__init__()
        stages = []
        in_ch = 1  # single grayscale input channel
        for out_ch in (32, 64, 128):
            stages += [
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
                nn.Conv2d(out_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
                nn.MaxPool2d(2),
                nn.Dropout2d(0.25),
            ]
            in_ch = out_ch
        self.conv_layers = nn.Sequential(*stages)
        # Three 2x poolings shrink a 48x48 input to 6x6, hence 128*6*6.
        self.fc_layers = nn.Sequential(
            nn.Linear(128 * 6 * 6, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, 8),
        )

    def forward(self, x):
        """Return raw (unsoftmaxed) class logits of shape (batch, 8)."""
        features = self.conv_layers(x)
        flat = features.flatten(1)
        return self.fc_layers(flat)
# Emotion table: index -> (label, coaching message). Indices line up with
# the model's 8 output classes.
_EMOTIONS = (
    ("Colère", "Respirez profondément et prenez un moment pour vous calmer."),
    ("Mépris", "Essayez de voir les choses d'un autre point de vue."),
    ("Dégoût", "Concentrez-vous sur les aspects positifs de la situation."),
    ("Peur", "Vous êtes en sécurité, prenez votre temps pour vous apaiser."),
    ("Bonheur", "Votre sourire illumine la pièce ! Continuez ainsi !"),
    ("Neutre", "Vous semblez calme et posé."),
    ("Tristesse", "Chaque jour est une nouvelle opportunité. Gardez espoir !"),
    ("Surprise", "La vie est pleine de surprises positives !"),
)
emotion_dict = {
    idx: {"name": label, "message": message}
    for idx, (label, message) in enumerate(_EMOTIONS)
}
# --- Streamlit page setup ---
# NOTE: set_page_config must be the first Streamlit command in the script.
st.set_page_config(page_title="Détecteur d'Émotions", layout="wide")
# Custom CSS injected once per run. The .emotion-box / .emotion-title /
# .emotion-message classes are referenced by the HTML snippets rendered
# in VideoProcessor.recv and the image-upload fallback below.
st.markdown("""
<style>
.main {
background-color: #f5f5f5;
}
.stButton>button {
background-color: #4CAF50;
color: white;
padding: 15px 32px;
text-align: center;
text-decoration: none;
display: inline-block;
font-size: 16px;
margin: 4px 2px;
cursor: pointer;
border-radius: 12px;
border: none;
transition-duration: 0.4s;
}
.stButton>button:hover {
background-color: #45a049;
}
.emotion-box {
padding: 20px;
border-radius: 10px;
background-color: white;
box-shadow: 0 4px 8px rgba(0,0,0,0.1);
margin: 10px 0;
}
.emotion-title {
color: #333;
font-size: 24px;
font-weight: bold;
margin-bottom: 10px;
}
.emotion-message {
color: #666;
font-size: 18px;
line-height: 1.5;
}
</style>
""", unsafe_allow_html=True)
# Application title
st.title("🎭 Détecteur d'Émotions en Temps Réel")
# --- Model initialization (cached by Streamlit across reruns/sessions) ---
@st.cache_resource
def load_model():
    """Build EmotionCNN, load the checkpoint, and return (model, device).

    Picks CUDA when available, otherwise CPU. On any loading failure the
    Streamlit script is halted with an error message, so a successful
    return always carries a usable eval-mode model.
    """
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = EmotionCNN().to(dev)
    try:
        state = torch.load("cnn_emotion_model.pth", map_location=dev)
        net.load_state_dict(state)
        net.eval()
    except Exception as e:
        st.error(f"Erreur lors du chargement du modèle : {str(e)}")
        st.stop()
    return net, dev

# Load once at import time; `model` and `device` are used module-wide.
model, device = load_model()
# Preprocessing pipeline — presumably mirrors the training setup (TODO confirm):
# grayscale -> resize to 48x48 -> tensor in [0,1] -> normalize to [-1,1].
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((48, 48)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])
# Haar cascade for frontal-face detection, shipped with the OpenCV install.
face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
# CascadeClassifier does not raise on a bad path; it just loads empty.
if face_cascade.empty():
    st.error("Erreur : Impossible de charger le classificateur Haar pour la détection de visage.")
    st.stop()
def detect_faces(frame):
    """Detect faces in an image and return their bounding boxes.

    Parameters
    ----------
    frame : numpy.ndarray
        Image array. Accepts 3-channel BGR (the webcam path), 4-channel
        BGRA/RGBA (e.g. PNG uploads with alpha) or single-channel
        grayscale; all are reduced to grayscale before detection.

    Returns
    -------
    Sequence of (x, y, w, h) rectangles; empty when no face is found.
    """
    if frame.ndim == 2:
        # Already single-channel (e.g. a grayscale upload).
        gray = frame
    elif frame.shape[2] == 4:
        # COLOR_BGR2GRAY rejects 4-channel input; strip alpha explicitly.
        gray = cv2.cvtColor(frame, cv2.COLOR_BGRA2GRAY)
    else:
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    # scaleFactor=1.1, minNeighbors=4 — unchanged detection tuning.
    return face_cascade.detectMultiScale(gray, 1.1, 4)
# WebRTC ICE configuration. Several public STUN servers improve the odds of
# NAT traversal. A TURN relay is still required on restrictive networks
# (symmetric NAT, strict firewalls) — the previous placeholder TURN entry
# pointed at a non-existent host with dummy credentials, which only added
# failed ICE candidates and connection delay, so it is disabled by default.
RTC_CONFIGURATION = RTCConfiguration({
    "iceServers": [
        {"urls": "stun:stun.l.google.com:19302"},
        {"urls": "stun:stun1.l.google.com:19302"},
        {"urls": "stun:stun2.l.google.com:19302"},
        {"urls": "stun:stun3.l.google.com:19302"},
        {"urls": "stun:stun4.l.google.com:19302"},
        {"urls": "stun:stun.stunprotocol.org:3478"},
        # To enable TURN, add an entry with REAL credentials, e.g.:
        # {
        #     "urls": "turn:your-turn-server.example.com:3478",
        #     "username": "your-username",
        #     "credential": "your-password",
        # },
    ]
})
# Per-connection video frame processor for streamlit-webrtc.
class VideoProcessor(VideoProcessorBase):
    # NOTE(review): streamlit-webrtc runs recv() on a background worker
    # thread. Calling Streamlit APIs (placeholder.markdown, st.error) from
    # that thread may be ignored or warn without a ScriptRunContext —
    # confirm the placeholder updates actually render; the documented
    # pattern is to hand results to the main script via a thread-safe
    # attribute or queue instead.
    def __init__(self):
        # Snapshot module-level objects so recv() has stable references.
        self.model = model
        self.device = device
        self.transform = transform
        self.face_cascade = face_cascade
        self.emotion_dict = emotion_dict
        # Placeholders are created in the main script and shared here via
        # session_state; .get() returns None if they were not created yet.
        self.emotion_placeholder = st.session_state.get('emotion_placeholder')
        self.message_placeholder = st.session_state.get('message_placeholder')

    def recv(self, frame):
        """Annotate one webcam frame: box each face and label its emotion."""
        try:
            img = frame.to_ndarray(format="bgr24")
            faces = detect_faces(img)
            for (x, y, w, h) in faces:
                cv2.rectangle(img, (x, y), (x+w, y+h), (0, 255, 0), 2)
                # Crop the face, convert BGR -> RGB for PIL, then run the
                # same preprocessing as training before inference.
                face_img = img[y:y+h, x:x+w]
                pil_img = Image.fromarray(cv2.cvtColor(face_img, cv2.COLOR_BGR2RGB))
                img_tensor = self.transform(pil_img).unsqueeze(0).to(self.device)
                with torch.no_grad():
                    output = self.model(img_tensor)
                    _, predicted = torch.max(output, 1)
                    emotion_idx = predicted.item()
                emotion_name = self.emotion_dict[emotion_idx]["name"]
                cv2.putText(img, emotion_name, (x, y-10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
                # Update the placeholders (see thread-safety note above).
                if self.emotion_placeholder:
                    self.emotion_placeholder.markdown(f"""
                        <div class="emotion-box">
                            <div class="emotion-title">{emotion_name}</div>
                        </div>
                    """, unsafe_allow_html=True)
                if self.message_placeholder:
                    self.message_placeholder.markdown(f"""
                        <div class="emotion-box">
                            <div class="emotion-message">{self.emotion_dict[emotion_idx]["message"]}</div>
                        </div>
                    """, unsafe_allow_html=True)
            return av.VideoFrame.from_ndarray(img, format="bgr24")
        except Exception as e:
            # Best-effort: on any failure, report and pass the frame through
            # unmodified so the video stream keeps running.
            st.error(f"Erreur lors du traitement de la frame : {str(e)}")
            return frame
# --- Interface layout: video feed (left) and detected emotion (right) ---
col1, col2 = st.columns([2, 1])
with col1:
    st.markdown("### 📹 Flux Vidéo")
    try:
        webrtc_ctx = webrtc_streamer(
            key="emotion-detection",
            rtc_configuration=RTC_CONFIGURATION,
            video_processor_factory=VideoProcessor,
            media_stream_constraints={"video": True, "audio": False},
            async_processing=True
        )
    except Exception as e:
        st.error(f"Erreur lors de l'initialisation de WebRTC : {str(e)}")
        st.warning("Vérifiez votre connexion réseau ou les paramètres STUN/TURN.")
with col2:
    st.markdown("### 😊 Émotion Détectée")
    # Placeholders are stored in session_state so VideoProcessor instances
    # (created by streamlit-webrtc) can look them up by key.
    if 'emotion_placeholder' not in st.session_state:
        st.session_state.emotion_placeholder = st.empty()
    if 'message_placeholder' not in st.session_state:
        st.session_state.message_placeholder = st.empty()
    # Module-level aliases, also used by the image-upload fallback below.
    emotion_placeholder = st.session_state.emotion_placeholder
    message_placeholder = st.session_state.message_placeholder
    st.info("👆 Autorisez l'accès à la webcam dans votre navigateur pour démarrer la détection d'émotions.")
    st.warning("Si la connexion échoue, vérifiez votre réseau ou configurez un serveur TURN pour WebRTC.")
# --- Image-upload fallback (when WebRTC is unavailable) ---
st.markdown("### 📷 Ou téléchargez une image")
uploaded_file = st.file_uploader("Choisissez une image...", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    # PIL yields RGB (convert() also normalizes RGBA/palette/grayscale,
    # which previously crashed cvtColor). The rest of the pipeline —
    # Haar detection, cv2 drawing, the final BGR2RGB for display — assumes
    # BGR, so convert once here; before, the RGB array was treated as BGR,
    # which swapped colors in the displayed result.
    image = Image.open(uploaded_file).convert("RGB")
    frame = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    faces = detect_faces(frame)
    for (x, y, w, h) in faces:
        cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 2)
        # Crop face, go back to RGB for PIL, preprocess, and classify.
        face_img = frame[y:y + h, x:x + w]
        pil_img = Image.fromarray(cv2.cvtColor(face_img, cv2.COLOR_BGR2RGB))
        img_tensor = transform(pil_img).unsqueeze(0).to(device)
        with torch.no_grad():
            output = model(img_tensor)
            _, predicted = torch.max(output, 1)
            emotion_idx = predicted.item()
        emotion_name = emotion_dict[emotion_idx]["name"]
        cv2.putText(frame, emotion_name, (x, y - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
        emotion_placeholder.markdown(f"""
            <div class="emotion-box">
                <div class="emotion-title">{emotion_name}</div>
            </div>
        """, unsafe_allow_html=True)
        message_placeholder.markdown(f"""
            <div class="emotion-box">
                <div class="emotion-message">{emotion_dict[emotion_idx]["message"]}</div>
            </div>
        """, unsafe_allow_html=True)
    # frame is genuine BGR now, so this conversion displays correct colors.
    st.image(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))