DavidNgoue commited on
Commit
3c01684
·
verified ·
1 Parent(s): 76dc6f7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +108 -52
app.py CHANGED
@@ -15,7 +15,6 @@ class EmotionCNN(nn.Module):
15
  def __init__(self):
16
  super(EmotionCNN, self).__init__()
17
  self.conv_layers = nn.Sequential(
18
- # Premier bloc
19
  nn.Conv2d(1, 32, 3, padding=1),
20
  nn.BatchNorm2d(32),
21
  nn.ReLU(),
@@ -24,8 +23,6 @@ class EmotionCNN(nn.Module):
24
  nn.ReLU(),
25
  nn.MaxPool2d(2),
26
  nn.Dropout2d(0.25),
27
-
28
- # Deuxième bloc
29
  nn.Conv2d(32, 64, 3, padding=1),
30
  nn.BatchNorm2d(64),
31
  nn.ReLU(),
@@ -34,8 +31,6 @@ class EmotionCNN(nn.Module):
34
  nn.ReLU(),
35
  nn.MaxPool2d(2),
36
  nn.Dropout2d(0.25),
37
-
38
- # Troisième bloc
39
  nn.Conv2d(64, 128, 3, padding=1),
40
  nn.BatchNorm2d(128),
41
  nn.ReLU(),
@@ -45,7 +40,6 @@ class EmotionCNN(nn.Module):
45
  nn.MaxPool2d(2),
46
  nn.Dropout2d(0.25)
47
  )
48
-
49
  self.fc_layers = nn.Sequential(
50
  nn.Linear(128 * 6 * 6, 512),
51
  nn.ReLU(),
@@ -62,7 +56,7 @@ class EmotionCNN(nn.Module):
62
  x = self.fc_layers(x)
63
  return x
64
 
65
- # Dictionnaire des émotions et leurs messages associés
66
  emotion_dict = {
67
  0: {"name": "Colère", "message": "Respirez profondément et prenez un moment pour vous calmer."},
68
  1: {"name": "Mépris", "message": "Essayez de voir les choses d'un autre point de vue."},
@@ -129,8 +123,12 @@ st.title("🎭 Détecteur d'Émotions en Temps Réel")
129
  def load_model():
130
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
131
  model = EmotionCNN().to(device)
132
- model.load_state_dict(torch.load("cnn_emotion_model.pth", map_location=device))
133
- model.eval()
 
 
 
 
134
  return model, device
135
 
136
  # Chargement du modèle
@@ -146,15 +144,31 @@ transform = transforms.Compose([
146
 
147
  # Chargement du classificateur Haar pour la détection de visage
148
  face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
 
 
 
149
 
150
  def detect_faces(frame):
151
  gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
152
  faces = face_cascade.detectMultiScale(gray, 1.1, 4)
153
  return faces
154
 
155
- # Configuration RTC pour WebRTC (utile pour Hugging Face Spaces)
156
  RTC_CONFIGURATION = RTCConfiguration({
157
- "iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  })
159
 
160
  # Classe pour traiter les frames vidéo
@@ -169,55 +183,61 @@ class VideoProcessor(VideoProcessorBase):
169
  self.message_placeholder = st.session_state.get('message_placeholder')
170
 
171
  def recv(self, frame):
172
- img = frame.to_ndarray(format="bgr24")
173
- faces = detect_faces(img)
174
-
175
- for (x, y, w, h) in faces:
176
- cv2.rectangle(img, (x, y), (x+w, y+h), (0, 255, 0), 2)
177
- face_img = img[y:y+h, x:x+w]
178
- pil_img = Image.fromarray(cv2.cvtColor(face_img, cv2.COLOR_BGR2RGB))
179
- img_tensor = self.transform(pil_img).unsqueeze(0).to(self.device)
180
- with torch.no_grad():
181
- output = self.model(img_tensor)
182
- _, predicted = torch.max(output, 1)
183
- emotion_idx = predicted.item()
184
- emotion_name = self.emotion_dict[emotion_idx]["name"]
185
- cv2.putText(img, emotion_name, (x, y-10),
186
- cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
187
-
188
- # Mettre à jour les placeholders (utiliser st.session_state pour partager)
189
- if self.emotion_placeholder:
190
- self.emotion_placeholder.markdown(f"""
191
- <div class="emotion-box">
192
- <div class="emotion-title">{emotion_name}</div>
193
- </div>
194
- """, unsafe_allow_html=True)
195
 
196
- if self.message_placeholder:
197
- self.message_placeholder.markdown(f"""
198
- <div class="emotion-box">
199
- <div class="emotion-message">{self.emotion_dict[emotion_idx]["message"]}</div>
200
- </div>
201
- """, unsafe_allow_html=True)
202
-
203
- return av.VideoFrame.from_ndarray(img, format="bgr24")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
- # Configuration de la webcam avec streamlit-webrtc
206
  col1, col2 = st.columns([2, 1])
207
 
208
  with col1:
209
  st.markdown("### 📹 Flux Vidéo")
210
- # Lancer le flux webcam
211
- webrtc_ctx = webrtc_streamer(
212
- key="emotion-detection",
213
- rtc_configuration=RTC_CONFIGURATION,
214
- video_processor_factory=VideoProcessor,
215
- media_stream_constraints={"video": True, "audio": False},
216
- )
 
 
 
 
217
 
218
  with col2:
219
  st.markdown("### 😊 Émotion Détectée")
220
- # Utiliser session_state pour partager les placeholders avec VideoProcessor
221
  if 'emotion_placeholder' not in st.session_state:
222
  st.session_state.emotion_placeholder = st.empty()
223
  if 'message_placeholder' not in st.session_state:
@@ -226,4 +246,40 @@ with col2:
226
  emotion_placeholder = st.session_state.emotion_placeholder
227
  message_placeholder = st.session_state.message_placeholder
228
 
229
- st.info("👆 Autorisez l'accès à la webcam dans votre navigateur pour démarrer la détection d'émotions.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  def __init__(self):
16
  super(EmotionCNN, self).__init__()
17
  self.conv_layers = nn.Sequential(
 
18
  nn.Conv2d(1, 32, 3, padding=1),
19
  nn.BatchNorm2d(32),
20
  nn.ReLU(),
 
23
  nn.ReLU(),
24
  nn.MaxPool2d(2),
25
  nn.Dropout2d(0.25),
 
 
26
  nn.Conv2d(32, 64, 3, padding=1),
27
  nn.BatchNorm2d(64),
28
  nn.ReLU(),
 
31
  nn.ReLU(),
32
  nn.MaxPool2d(2),
33
  nn.Dropout2d(0.25),
 
 
34
  nn.Conv2d(64, 128, 3, padding=1),
35
  nn.BatchNorm2d(128),
36
  nn.ReLU(),
 
40
  nn.MaxPool2d(2),
41
  nn.Dropout2d(0.25)
42
  )
 
43
  self.fc_layers = nn.Sequential(
44
  nn.Linear(128 * 6 * 6, 512),
45
  nn.ReLU(),
 
56
  x = self.fc_layers(x)
57
  return x
58
 
59
+ # Dictionnaire des émotions
60
  emotion_dict = {
61
  0: {"name": "Colère", "message": "Respirez profondément et prenez un moment pour vous calmer."},
62
  1: {"name": "Mépris", "message": "Essayez de voir les choses d'un autre point de vue."},
 
123
  def load_model():
124
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
125
  model = EmotionCNN().to(device)
126
+ try:
127
+ model.load_state_dict(torch.load("cnn_emotion_model.pth", map_location=device))
128
+ model.eval()
129
+ except Exception as e:
130
+ st.error(f"Erreur lors du chargement du modèle : {str(e)}")
131
+ st.stop()
132
  return model, device
133
 
134
  # Chargement du modèle
 
144
 
145
  # Chargement du classificateur Haar pour la détection de visage
146
  face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
147
+ if face_cascade.empty():
148
+ st.error("Erreur : Impossible de charger le classificateur Haar pour la détection de visage.")
149
+ st.stop()
150
 
151
  def detect_faces(frame):
152
  gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
153
  faces = face_cascade.detectMultiScale(gray, 1.1, 4)
154
  return faces
155
 
156
+ # Configuration RTC avec plusieurs STUN et TURN
157
  RTC_CONFIGURATION = RTCConfiguration({
158
+ "iceServers": [
159
+ {"urls": "stun:stun.l.google.com:19302"},
160
+ {"urls": "stun:stun1.l.google.com:19302"},
161
+ {"urls": "stun:stun2.l.google.com:19302"},
162
+ {"urls": "stun:stun3.l.google.com:19302"},
163
+ {"urls": "stun:stun4.l.google.com:19302"},
164
+ {"urls": "stun:stun.stunprotocol.org:3478"},
165
+ # Exemple de configuration TURN (remplacez par vos propres identifiants si disponible)
166
+ {
167
+ "urls": "turn:your-turn-server.example.com:3478",
168
+ "username": "your-username",
169
+ "credential": "your-password"
170
+ }
171
+ ]
172
  })
173
 
174
  # Classe pour traiter les frames vidéo
 
183
  self.message_placeholder = st.session_state.get('message_placeholder')
184
 
185
  def recv(self, frame):
186
+ try:
187
+ img = frame.to_ndarray(format="bgr24")
188
+ faces = detect_faces(img)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
 
190
+ for (x, y, w, h) in faces:
191
+ cv2.rectangle(img, (x, y), (x+w, y+h), (0, 255, 0), 2)
192
+ face_img = img[y:y+h, x:x+w]
193
+ pil_img = Image.fromarray(cv2.cvtColor(face_img, cv2.COLOR_BGR2RGB))
194
+ img_tensor = self.transform(pil_img).unsqueeze(0).to(self.device)
195
+ with torch.no_grad():
196
+ output = self.model(img_tensor)
197
+ _, predicted = torch.max(output, 1)
198
+ emotion_idx = predicted.item()
199
+ emotion_name = self.emotion_dict[emotion_idx]["name"]
200
+ cv2.putText(img, emotion_name, (x, y-10),
201
+ cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
202
+
203
+ # Mettre à jour les placeholders
204
+ if self.emotion_placeholder:
205
+ self.emotion_placeholder.markdown(f"""
206
+ <div class="emotion-box">
207
+ <div class="emotion-title">{emotion_name}</div>
208
+ </div>
209
+ """, unsafe_allow_html=True)
210
+
211
+ if self.message_placeholder:
212
+ self.message_placeholder.markdown(f"""
213
+ <div class="emotion-box">
214
+ <div class="emotion-message">{self.emotion_dict[emotion_idx]["message"]}</div>
215
+ </div>
216
+ """, unsafe_allow_html=True)
217
+ return av.VideoFrame.from_ndarray(img, format="bgr24")
218
+ except Exception as e:
219
+ st.error(f"Erreur lors du traitement de la frame : {str(e)}")
220
+ return frame
221
 
222
+ # Configuration de l'interface
223
  col1, col2 = st.columns([2, 1])
224
 
225
  with col1:
226
  st.markdown("### 📹 Flux Vidéo")
227
+ try:
228
+ webrtc_ctx = webrtc_streamer(
229
+ key="emotion-detection",
230
+ rtc_configuration=RTC_CONFIGURATION,
231
+ video_processor_factory=VideoProcessor,
232
+ media_stream_constraints={"video": True, "audio": False},
233
+ async_processing=True
234
+ )
235
+ except Exception as e:
236
+ st.error(f"Erreur lors de l'initialisation de WebRTC : {str(e)}")
237
+ st.warning("Vérifiez votre connexion réseau ou les paramètres STUN/TURN.")
238
 
239
  with col2:
240
  st.markdown("### 😊 Émotion Détectée")
 
241
  if 'emotion_placeholder' not in st.session_state:
242
  st.session_state.emotion_placeholder = st.empty()
243
  if 'message_placeholder' not in st.session_state:
 
246
  emotion_placeholder = st.session_state.emotion_placeholder
247
  message_placeholder = st.session_state.message_placeholder
248
 
249
+ st.info("👆 Autorisez l'accès à la webcam dans votre navigateur pour démarrer la détection d'émotions.")
250
+ st.warning("Si la connexion échoue, vérifiez votre réseau ou configurez un serveur TURN pour WebRTC.")
251
+
252
+ # Option de téléchargement d'image comme solution de secours
253
+ st.markdown("### 📷 Ou téléchargez une image")
254
+ uploaded_file = st.file_uploader("Choisissez une image...", type=["jpg", "jpeg", "png"])
255
+ if uploaded_file is not None:
256
+ image = Image.open(uploaded_file)
257
+ frame = np.array(image)
258
+ faces = detect_faces(frame)
259
+
260
+ for (x, y, w, h) in faces:
261
+ cv2.rectangle(frame, (x, y), (x+w, y+h), (0, 255, 0), 2)
262
+ face_img = frame[y:y+h, x:x+w]
263
+ pil_img = Image.fromarray(cv2.cvtColor(face_img, cv2.COLOR_BGR2RGB))
264
+ img_tensor = transform(pil_img).unsqueeze(0).to(device)
265
+ with torch.no_grad():
266
+ output = model(img_tensor)
267
+ _, predicted = torch.max(output, 1)
268
+ emotion_idx = predicted.item()
269
+ emotion_name = emotion_dict[emotion_idx]["name"]
270
+ cv2.putText(frame, emotion_name, (x, y-10),
271
+ cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
272
+
273
+ emotion_placeholder.markdown(f"""
274
+ <div class="emotion-box">
275
+ <div class="emotion-title">{emotion_name}</div>
276
+ </div>
277
+ """, unsafe_allow_html=True)
278
+
279
+ message_placeholder.markdown(f"""
280
+ <div class="emotion-box">
281
+ <div class="emotion-message">{emotion_dict[emotion_idx]["message"]}</div>
282
+ </div>
283
+ """, unsafe_allow_html=True)
284
+
285
+ st.image(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))