ModuMLTECH commited on
Commit
320dc3c
·
verified ·
1 Parent(s): f33ab54

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +250 -149
app.py CHANGED
@@ -9,13 +9,6 @@ import threading
9
  from PIL import Image
10
  import torch
11
 
12
- # ---- Contexte Streamlit pour threads (safe fallback) ----
13
- try:
14
- from streamlit.runtime.scriptrunner import add_script_run_ctx
15
- except Exception:
16
- def add_script_run_ctx(t):
17
- return t
18
-
19
  # --- FONCTIONS UTILES ---
20
  def draw_text_with_background(
21
  image,
@@ -56,10 +49,10 @@ class YOLOVideoProcessor:
56
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
57
 
58
  # Paramètres d'optimisation
59
- self.frame_skip = 2 # Traiter une image sur N
60
- self.downsample_factor = 0.5 # Réduire la taille des images
61
  self.img_size = 640
62
- self.conf_threshold = 0.5 # plus strict par défaut
63
 
64
  # Modèle
65
  self.model = YOLO(model_path)
@@ -78,19 +71,24 @@ class YOLOVideoProcessor:
78
  self.last_processed_frame = None
79
  self.current_frame = 0
80
 
81
- # Filtres anti-roues (valeurs relatives à l'image d'affichage)
82
- self.min_w_ratio = 0.04 # largeur >= 4% de la largeur du frame
83
- self.min_h_ratio = 0.05 # hauteur >= 5% de la hauteur du frame
84
- self.min_area_ratio = 0.0025 # aire >= 0.25% de l'aire du frame
 
 
 
 
85
 
86
  @staticmethod
87
- def is_in_region(point, poly):
88
  poly_np = np.array(poly, dtype=np.int32)
89
- return cv2.pointPolygonTest(poly_np, point, False) >= 0
90
 
91
  def reset_counts(self):
92
  self.unique_region1_ids.clear()
93
  self.unique_region2_ids.clear()
 
94
 
95
  def _pick_fourcc(self, output_path):
96
  ext = os.path.splitext(output_path)[1].lower()
@@ -98,26 +96,117 @@ class YOLOVideoProcessor:
98
  return cv2.VideoWriter_fourcc(*"mp4v")
99
  return cv2.VideoWriter_fourcc(*"XVID")
100
 
101
- # ---------- TRAITEMENT VIDEO (thread principal) ----------
102
- def process_video(self, video_path, output_path, progress_bar=None, status_placeholder=None):
103
- """Traite une vidéo enregistrée avec optimisations (aucun st.* ici)."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  cap = cv2.VideoCapture(video_path)
105
  if not cap.isOpened():
106
- if status_placeholder:
107
- status_placeholder.error("⚠️ Impossible d'ouvrir la vidéo.")
108
  return
109
 
110
  frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
111
  frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
112
  fps = cap.get(cv2.CAP_PROP_FPS)
113
  if not fps or fps <= 1e-3:
114
- fps = 30.0 # défaut
115
 
116
  fourcc = self._pick_fourcc(output_path)
117
  out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))
118
  if not out.isOpened():
119
- if status_placeholder:
120
- status_placeholder.error("⚠️ Impossible d'ouvrir la vidéo de sortie (codec).")
121
  cap.release()
122
  return
123
 
@@ -131,11 +220,14 @@ class YOLOVideoProcessor:
131
  if not success:
132
  break
133
 
 
134
  if progress_bar is not None and total_frames > 0:
135
- progress_bar.progress(min(1.0, processed_frames / float(total_frames)))
 
136
 
 
137
  if frame_count % self.frame_skip == 0:
138
- processed_frame = self.process_frame(frame)
139
  self.last_processed_frame = processed_frame
140
  else:
141
  processed_frame = self.last_processed_frame if self.last_processed_frame is not None else frame
@@ -143,6 +235,7 @@ class YOLOVideoProcessor:
143
  if processed_frame is None:
144
  processed_frame = frame
145
 
 
146
  if processed_frame.shape[1] != frame_width or processed_frame.shape[0] != frame_height:
147
  processed_frame = cv2.resize(processed_frame, (frame_width, frame_height), interpolation=cv2.INTER_AREA)
148
 
@@ -154,101 +247,122 @@ class YOLOVideoProcessor:
154
  out.release()
155
  cv2.destroyAllWindows()
156
 
157
- if processed_frames == 0 and status_placeholder:
158
- status_placeholder.error("⚠️ Aucune image n'a été écrite dans la vidéo de sortie !")
159
 
160
  return len(self.unique_region1_ids), len(self.unique_region2_ids)
161
 
162
- # ---------- TRAITEMENT PAR IMAGE ----------
163
- def process_frame(self, frame):
164
- """Traite une image individuelle avec YOLO + tracking, optimisé et filtré anti-roues."""
165
  if frame is None:
166
  return None
167
 
168
- # Downscale pour accélérer
169
- orig_h, orig_w = frame.shape[:2]
170
- resized_w, resized_h = orig_w, orig_h
171
  if self.downsample_factor < 1.0:
172
- resized_w = max(1, int(orig_w * self.downsample_factor))
173
- resized_h = max(1, int(orig_h * self.downsample_factor))
174
- resized_frame = cv2.resize(frame, (resized_w, resized_h), interpolation=cv2.INTER_AREA)
175
  else:
176
  resized_frame = frame
177
 
178
- # Détection + tracking (avec classes & iou)
179
  with torch.no_grad():
180
  results = self.model.track(
181
  resized_frame,
182
  persist=True,
183
  tracker=self.tracker_config,
184
  conf=self.conf_threshold,
185
- iou=0.5,
186
  imgsz=self.img_size,
187
  device=self.device,
188
- classes=[2, 3, 5, 7], # voitures, motos, bus, camions (COCO)
 
189
  )
190
 
191
- display = frame.copy()
192
- H, W = display.shape[:2]
193
-
194
- # Polylines
195
- cv2.polylines(display, [np.array(self.poly1, np.int32)], isClosed=True, color=(0, 255, 0), thickness=2)
196
- cv2.polylines(display, [np.array(self.poly2, np.int32)], isClosed=True, color=(255, 0, 0), thickness=2)
197
 
198
- # Échelle vers taille originale
199
- sx = orig_w / float(resized_w)
200
- sy = orig_h / float(resized_h)
201
 
202
- # Seuils anti-roues (relatifs à l'image d'affichage)
203
- min_w = int(self.min_w_ratio * W)
204
- min_h = int(self.min_h_ratio * H)
205
- min_area = int(self.min_area_ratio * W * H)
206
 
207
  if results and len(results) > 0 and getattr(results[0], "boxes", None) is not None:
208
  try:
209
  boxes = results[0].boxes.xywh.cpu().numpy()
210
  ids_tensor = results[0].boxes.id
211
- track_ids = ([None] * len(boxes)) if ids_tensor is None else ids_tensor.int().cpu().tolist()
212
-
213
- for (x, y, w, h), tid in zip(boxes, track_ids):
214
- # Rescale
215
- cx = int(x * sx)
216
- cy = int(y * sy)
217
- ww = int(w * sx)
218
- hh = int(h * sy)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
 
220
- # --- FILTRE TAILLE MIN (anti-roues) ---
221
- if ww < min_w or hh < min_h or (ww * hh) < min_area:
 
222
  continue
223
-
224
- # Point de comptage : bas de la boîte (bottom-center)
225
- bottom_center = (cx, cy + hh // 2)
226
-
227
- if tid is not None:
228
- if self.is_in_region(bottom_center, self.poly1):
229
- self.unique_region1_ids.add(tid)
230
- if self.is_in_region(bottom_center, self.poly2):
231
- self.unique_region2_ids.add(tid)
232
-
233
- # Dessin bbox
234
- tl = (max(0, cx - ww // 2), max(0, cy - hh // 2))
235
- br = (min(W - 1, cx + ww // 2), min(H - 1, cy + hh // 2))
236
- cv2.rectangle(display, tl, br, (0, 255, 0), 2)
237
-
 
 
 
 
 
 
 
 
 
 
238
  except Exception as e:
239
- draw_text_with_background(display, f"Tracking error: {e}", (10, 60), bg_color=(80, 0, 0))
240
 
241
- draw_text_with_background(display, f"Total Sens 1: {len(self.unique_region1_ids)}", (10, H - 50))
242
- draw_text_with_background(display, f"Total Sens 2: {len(self.unique_region2_ids)}", (W - 300, H - 50))
243
- return display
244
 
245
- # ---------- CAPTURE WEBCAM (thread secondaire, aucun st.*) ----------
246
- def process_webcam(self, camera_id=0, display_placeholder=None, count_placeholders=None, status_placeholder=None):
247
- """Traite la vidéo en temps réel depuis une webcam (aucun appel direct à streamlit dans ce thread)."""
 
248
  cap = cv2.VideoCapture(camera_id)
249
  if not cap.isOpened():
250
- if status_placeholder:
251
- status_placeholder.error("⚠️ Impossible d'ouvrir la webcam.")
252
  return
253
 
254
  try:
@@ -263,50 +377,43 @@ class YOLOVideoProcessor:
263
  frame_count = 0
264
  last_ts = time.time()
265
 
266
- # Afficher une première image (pour signaler la connexion)
267
- ok, first = cap.read()
268
- if ok and display_placeholder:
269
- try:
270
- rgb0 = cv2.cvtColor(first, cv2.COLOR_BGR2RGB)
271
- except Exception:
272
- rgb0 = first
273
- display_placeholder.image(Image.fromarray(rgb0), channels="RGB", use_column_width=True, caption="Webcam connectée")
274
-
275
  while not self.stop_processing:
276
  success, frame = cap.read()
277
  if not success:
278
- if status_placeholder:
279
- status_placeholder.error("⚠️ Erreur lors de la lecture du flux vidéo.")
280
  break
281
 
282
  if frame_count % self.frame_skip == 0:
283
- processed = self.process_frame(frame)
284
- self.last_processed_frame = processed
285
  now = time.time()
286
- fps = 1.0 / max(1e-6, (now - last_ts))
 
287
  last_ts = now
288
- if processed is not None:
289
- draw_text_with_background(processed, f"FPS: {fps:.1f}", (10, 30))
290
  else:
291
- processed = self.last_processed_frame if self.last_processed_frame is not None else frame
292
 
293
- if processed is not None and display_placeholder:
294
  try:
295
- rgb = cv2.cvtColor(processed, cv2.COLOR_BGR2RGB)
296
  except Exception:
297
- rgb = processed
298
- display_placeholder.image(Image.fromarray(rgb), channels="RGB", use_column_width=True)
 
 
 
299
 
300
- if count_placeholders and len(count_placeholders) >= 2:
301
- count_placeholders[0].metric("Véhicules Sens 1 (Vert)", len(self.unique_region1_ids))
302
- count_placeholders[1].metric("Véhicules Sens 2 (Rouge)", len(self.unique_region2_ids))
303
 
304
  frame_count += 1
305
  time.sleep(0.01)
306
 
307
  cap.release()
308
- if status_placeholder:
309
- status_placeholder.success("✅ Flux vidéo arrêté.")
310
 
311
 
312
  # --- INTERFACE STREAMLIT ---
@@ -319,9 +426,9 @@ def main():
319
 
320
  st.title("🚗 Détection et comptage de Véhicules sur l'Autoroute de l'Avenir")
321
 
 
322
  st.session_state.setdefault("webcam_active", False)
323
  st.session_state.setdefault("processor", None)
324
- st.session_state.setdefault("processing_thread", None)
325
 
326
  # Modèle
327
  model_path = "best.pt"
@@ -345,15 +452,20 @@ def main():
345
 
346
  st.subheader("📍 Polygone 1 (vert)")
347
  poly1_input = st.text_area("Entrez 4 points (x,y) séparés par des espaces", "900,350 1150,350 700,630 200,630")
 
348
  st.subheader("📍 Polygone 2 (rouge)")
349
  poly2_input = st.text_area("Entrez 4 points (x,y) séparés par des espaces", "1200,350 1400,350 1150,630 743,630")
350
 
351
- tracker_method = st.selectbox("Méthode de tracking", ["bot", "byte"], index=1) # byte par défaut
352
 
353
  st.subheader("🚀 Paramètres d'optimisation")
354
- frame_skip = st.slider("Skip de frames (plus élevé = plus rapide)", 1, 5, 2)
355
- downsample = st.slider("Facteur d'échelle (plus petit = plus rapide)", 0.3, 1.0, 0.5, 0.1)
356
- conf_threshold = st.slider("Seuil de confiance", 0.1, 0.9, 0.5, 0.05) # 0.5 par défaut
 
 
 
 
357
 
358
  def parse_polygon(input_text):
359
  try:
@@ -371,9 +483,7 @@ def main():
371
 
372
  # Onglet 1: Analyse vidéo
373
  with tab1:
374
- uploaded_file = st.file_uploader("📂 Upload une vidéo", type=["mp4", "avi", "mov", "mkv"])
375
- status_vid = st.empty()
376
-
377
  if uploaded_file is not None:
378
  temp_dir = tempfile.mkdtemp()
379
  ext = os.path.splitext(uploaded_file.name)[1].lower() or ".mp4"
@@ -392,38 +502,37 @@ def main():
392
  processor.frame_skip = frame_skip
393
  processor.downsample_factor = downsample
394
  processor.conf_threshold = conf_threshold
 
 
395
 
396
  start_time = time.time()
397
- counts = processor.process_video(
398
- input_video_path, output_video_path,
399
- progress_bar=progress_bar,
400
- status_placeholder=status_vid
401
- )
402
  end_time = time.time()
403
  if counts:
404
- c1, c2 = counts
405
- st.success(f"✅ Terminé en {end_time - start_time:.2f} s")
 
406
  col_result1, col_result2 = st.columns(2)
407
- col_result1.metric("Véhicules Sens 1 (Vert)", c1)
408
- col_result2.metric("Véhicules Sens 2 (Rouge)", c2)
409
 
410
  st.subheader("Vidéo traitée")
411
  st.video(output_video_path)
 
412
  with open(output_video_path, "rb") as file:
413
  st.download_button(
414
  label="⬇️ Télécharger la vidéo",
415
  data=file,
416
  file_name=f"video_traitee{ext}",
417
- mime=f"video/{ext.strip('.') or 'mp4'}",
418
  )
419
  else:
420
- st.error("❌ Les polygones doivent contenir **exactement 4 points**.")
421
 
422
  # Onglet 2: Webcam
423
  with tab2:
424
  st.header("Détection en Temps Réel avec Webcam")
425
 
426
- # Découverte simple des caméras locales
427
  camera_options = {"Webcam par défaut": 0}
428
  for i in range(1, 5):
429
  try:
@@ -438,17 +547,16 @@ def main():
438
  camera_id = camera_options[selected_camera]
439
 
440
  video_placeholder = st.empty()
441
- status_cam = st.empty()
442
  col1, col2 = st.columns(2)
443
  count_placeholders = [col1.empty(), col2.empty()]
444
 
445
- st.info("ℹ️ Optimisations: redimensionnement, skip de frames, CUDA si disponible.")
446
 
447
  col_start, col_stop = st.columns(2)
448
 
449
  if col_start.button("▶️ Démarrer la détection en direct"):
450
  if not valid_polygons:
451
- st.error("❌ Les polygones doivent contenir **exactement 4 points**.")
452
  elif st.session_state.webcam_active:
453
  st.warning("⚠️ La webcam est déjà active !")
454
  else:
@@ -456,31 +564,24 @@ def main():
456
  processor.frame_skip = frame_skip
457
  processor.downsample_factor = downsample
458
  processor.conf_threshold = conf_threshold
 
 
459
 
460
  st.session_state.processor = processor
461
  st.session_state.webcam_active = True
462
 
463
- t = threading.Thread(
464
  target=st.session_state.processor.process_webcam,
465
- args=(camera_id, video_placeholder, count_placeholders, status_cam),
466
  daemon=True,
467
- )
468
- add_script_run_ctx(t) # <— attache le contexte Streamlit
469
- t.start()
470
- st.session_state.processing_thread = t
471
 
472
  if col_stop.button("⏹️ Arrêter la détection"):
473
  if st.session_state.webcam_active and st.session_state.processor:
474
  st.session_state.processor.stop_processing = True
475
  st.session_state.webcam_active = False
476
- # attendre la fin proprement
477
- t = st.session_state.get("processing_thread")
478
- if t:
479
- t.join(timeout=2.0)
480
- st.session_state.processing_thread = None
481
- time.sleep(0.3)
482
  video_placeholder.empty()
483
- status_cam.info("Arrêt demandé.")
484
  else:
485
  st.warning("⚠️ Aucune détection en cours !")
486
 
 
9
  from PIL import Image
10
  import torch
11
 
 
 
 
 
 
 
 
12
  # --- FONCTIONS UTILES ---
13
  def draw_text_with_background(
14
  image,
 
49
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
50
 
51
  # Paramètres d'optimisation
52
+ self.frame_skip = 2
53
+ self.downsample_factor = 0.5
54
  self.img_size = 640
55
+ self.conf_threshold = 0.35
56
 
57
  # Modèle
58
  self.model = YOLO(model_path)
 
71
  self.last_processed_frame = None
72
  self.current_frame = 0
73
 
74
+ # Paramètres anti-duplicata pour camions longs
75
+ self.iou_threshold = 0.3 # Seuil IoU pour fusionner les détections proches
76
+ self.min_box_area = 500 # Surface minimale pour être considéré comme véhicule
77
+ self.max_aspect_ratio = 5.0 # Ratio hauteur/largeur max pour éviter détections étirées
78
+
79
+ # Historique des détections pour filtrage temporel
80
+ self.detection_history = {} # {track_id: {'boxes': [], 'frames': []}}
81
+ self.history_length = 5 # Nombre de frames à garder en mémoire
82
 
83
  @staticmethod
84
+ def is_in_region(center, poly):
85
  poly_np = np.array(poly, dtype=np.int32)
86
+ return cv2.pointPolygonTest(poly_np, center, False) >= 0
87
 
88
  def reset_counts(self):
89
  self.unique_region1_ids.clear()
90
  self.unique_region2_ids.clear()
91
+ self.detection_history.clear()
92
 
93
  def _pick_fourcc(self, output_path):
94
  ext = os.path.splitext(output_path)[1].lower()
 
96
  return cv2.VideoWriter_fourcc(*"mp4v")
97
  return cv2.VideoWriter_fourcc(*"XVID")
98
 
99
+ def calculate_iou(self, box1, box2):
100
+ """Calcule l'IoU (Intersection over Union) entre deux boîtes"""
101
+ x1_min, y1_min, x1_max, y1_max = box1
102
+ x2_min, y2_min, x2_max, y2_max = box2
103
+
104
+ # Intersection
105
+ inter_x_min = max(x1_min, x2_min)
106
+ inter_y_min = max(y1_min, y2_min)
107
+ inter_x_max = min(x1_max, x2_max)
108
+ inter_y_max = min(y1_max, y2_max)
109
+
110
+ inter_area = max(0, inter_x_max - inter_x_min) * max(0, inter_y_max - inter_y_min)
111
+
112
+ # Union
113
+ box1_area = (x1_max - x1_min) * (y1_max - y1_min)
114
+ box2_area = (x2_max - x2_min) * (y2_max - y2_min)
115
+ union_area = box1_area + box2_area - inter_area
116
+
117
+ if union_area == 0:
118
+ return 0
119
+
120
+ return inter_area / union_area
121
+
122
+ def filter_overlapping_detections(self, boxes_coords, track_ids, confidences):
123
+ """Filtre les détections qui se chevauchent (ex: plusieurs détections sur un camion)"""
124
+ if len(boxes_coords) == 0:
125
+ return [], [], []
126
+
127
+ # Créer une liste de détections avec leurs indices
128
+ detections = []
129
+ for i, (box, tid, conf) in enumerate(zip(boxes_coords, track_ids, confidences)):
130
+ x_min, y_min, x_max, y_max = box
131
+ area = (x_max - x_min) * (y_max - y_min)
132
+ aspect_ratio = (y_max - y_min) / max(1, x_max - x_min)
133
+
134
+ # Filtrer les détections trop petites ou avec un aspect ratio bizarre
135
+ if area < self.min_box_area or aspect_ratio > self.max_aspect_ratio:
136
+ continue
137
+
138
+ detections.append({
139
+ 'index': i,
140
+ 'box': box,
141
+ 'track_id': tid,
142
+ 'conf': conf,
143
+ 'area': area
144
+ })
145
+
146
+ # Trier par confiance décroissante
147
+ detections.sort(key=lambda x: x['conf'], reverse=True)
148
+
149
+ # Non-Maximum Suppression manuel
150
+ keep_indices = []
151
+ while len(detections) > 0:
152
+ # Garder la détection avec la plus haute confiance
153
+ best = detections.pop(0)
154
+ keep_indices.append(best['index'])
155
+
156
+ # Supprimer les détections qui se chevauchent trop avec la meilleure
157
+ filtered_detections = []
158
+ for det in detections:
159
+ iou = self.calculate_iou(best['box'], det['box'])
160
+ if iou < self.iou_threshold: # Garder si IoU faible (pas de chevauchement)
161
+ filtered_detections.append(det)
162
+
163
+ detections = filtered_detections
164
+
165
+ # Retourner les détections filtrées
166
+ filtered_boxes = [boxes_coords[i] for i in keep_indices]
167
+ filtered_ids = [track_ids[i] for i in keep_indices]
168
+ filtered_confs = [confidences[i] for i in keep_indices]
169
+
170
+ return filtered_boxes, filtered_ids, filtered_confs
171
+
172
+ def update_detection_history(self, track_id, box, frame_num):
173
+ """Met à jour l'historique des détections pour un véhicule"""
174
+ if track_id not in self.detection_history:
175
+ self.detection_history[track_id] = {'boxes': [], 'frames': []}
176
+
177
+ self.detection_history[track_id]['boxes'].append(box)
178
+ self.detection_history[track_id]['frames'].append(frame_num)
179
+
180
+ # Garder seulement les N dernières frames
181
+ if len(self.detection_history[track_id]['boxes']) > self.history_length:
182
+ self.detection_history[track_id]['boxes'].pop(0)
183
+ self.detection_history[track_id]['frames'].pop(0)
184
+
185
+ def is_stable_detection(self, track_id):
186
+ """Vérifie si une détection est stable (pas un faux positif temporaire)"""
187
+ if track_id not in self.detection_history:
188
+ return False
189
+
190
+ # Considérer stable si détecté sur au moins 3 frames
191
+ return len(self.detection_history[track_id]['boxes']) >= 3
192
+
193
+ def process_video(self, video_path, output_path, progress_bar=None):
194
+ """Traite une vidéo enregistrée avec optimisations"""
195
  cap = cv2.VideoCapture(video_path)
196
  if not cap.isOpened():
197
+ st.error("⚠️ Erreur : Impossible d'ouvrir la vidéo.")
 
198
  return
199
 
200
  frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
201
  frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
202
  fps = cap.get(cv2.CAP_PROP_FPS)
203
  if not fps or fps <= 1e-3:
204
+ fps = 30.0
205
 
206
  fourcc = self._pick_fourcc(output_path)
207
  out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))
208
  if not out.isOpened():
209
+ st.error("⚠️ Erreur : Impossible d'ouvrir la vidéo de sortie (codec).")
 
210
  cap.release()
211
  return
212
 
 
220
  if not success:
221
  break
222
 
223
+ # Progression
224
  if progress_bar is not None and total_frames > 0:
225
+ progress = min(1.0, processed_frames / float(total_frames))
226
+ progress_bar.progress(progress)
227
 
228
+ # Skip de frames
229
  if frame_count % self.frame_skip == 0:
230
+ processed_frame = self.process_frame(frame, frame_count)
231
  self.last_processed_frame = processed_frame
232
  else:
233
  processed_frame = self.last_processed_frame if self.last_processed_frame is not None else frame
 
235
  if processed_frame is None:
236
  processed_frame = frame
237
 
238
+ # S'assurer de la taille attendue
239
  if processed_frame.shape[1] != frame_width or processed_frame.shape[0] != frame_height:
240
  processed_frame = cv2.resize(processed_frame, (frame_width, frame_height), interpolation=cv2.INTER_AREA)
241
 
 
247
  out.release()
248
  cv2.destroyAllWindows()
249
 
250
+ if processed_frames == 0:
251
+ st.error("⚠️ Aucune image n'a été écrite dans la vidéo de sortie !")
252
 
253
  return len(self.unique_region1_ids), len(self.unique_region2_ids)
254
 
255
+ def process_frame(self, frame, frame_num=0):
256
+ """Traite une image individuelle avec YOLO et le tracking, avec filtrage anti-duplicata"""
 
257
  if frame is None:
258
  return None
259
 
260
+ # Redimensionner l'image pour accélérer le traitement
261
+ orig_height, orig_width = frame.shape[:2]
262
+ resized_width, resized_height = orig_width, orig_height
263
  if self.downsample_factor < 1.0:
264
+ resized_width = max(1, int(orig_width * self.downsample_factor))
265
+ resized_height = max(1, int(orig_height * self.downsample_factor))
266
+ resized_frame = cv2.resize(frame, (resized_width, resized_height), interpolation=cv2.INTER_AREA)
267
  else:
268
  resized_frame = frame
269
 
270
+ # Détection + tracking
271
  with torch.no_grad():
272
  results = self.model.track(
273
  resized_frame,
274
  persist=True,
275
  tracker=self.tracker_config,
276
  conf=self.conf_threshold,
 
277
  imgsz=self.img_size,
278
  device=self.device,
279
+ classes=[2, 5, 7], # COCO: 2=car, 5=bus, 7=truck (évite autres objets)
280
+ verbose=False
281
  )
282
 
283
+ display_frame = frame.copy()
284
+ frame_height, frame_width = display_frame.shape[:2]
 
 
 
 
285
 
286
+ # Dessiner les polygones
287
+ cv2.polylines(display_frame, [np.array(self.poly1, np.int32)], isClosed=True, color=(0, 255, 0), thickness=2)
288
+ cv2.polylines(display_frame, [np.array(self.poly2, np.int32)], isClosed=True, color=(255, 0, 0), thickness=2)
289
 
290
+ # Échelle pour remonter aux coords originales
291
+ scale_x = orig_width / float(resized_width)
292
+ scale_y = orig_height / float(resized_height)
 
293
 
294
  if results and len(results) > 0 and getattr(results[0], "boxes", None) is not None:
295
  try:
296
  boxes = results[0].boxes.xywh.cpu().numpy()
297
  ids_tensor = results[0].boxes.id
298
+ confs = results[0].boxes.conf.cpu().numpy()
299
+
300
+ if ids_tensor is None:
301
+ track_ids = [None] * len(boxes)
302
+ else:
303
+ track_ids = ids_tensor.int().cpu().tolist()
304
+
305
+ # Convertir les boîtes en format [x_min, y_min, x_max, y_max]
306
+ boxes_coords = []
307
+ for x, y, w, h in boxes:
308
+ center_x = int(x * scale_x)
309
+ center_y = int(y * scale_y)
310
+ width = int(w * scale_x)
311
+ height = int(h * scale_y)
312
+ x_min = max(0, center_x - width // 2)
313
+ y_min = max(0, center_y - height // 2)
314
+ x_max = min(frame_width - 1, center_x + width // 2)
315
+ y_max = min(frame_height - 1, center_y + height // 2)
316
+ boxes_coords.append([x_min, y_min, x_max, y_max])
317
+
318
+ # Filtrer les détections qui se chevauchent
319
+ filtered_boxes, filtered_ids, filtered_confs = self.filter_overlapping_detections(
320
+ boxes_coords, track_ids, confs
321
+ )
322
 
323
+ # Traiter les détections filtrées
324
+ for box, track_id, conf in zip(filtered_boxes, filtered_ids, filtered_confs):
325
+ if track_id is None:
326
  continue
327
+
328
+ x_min, y_min, x_max, y_max = box
329
+ center_x = (x_min + x_max) // 2
330
+ center_y = (y_min + y_max) // 2
331
+ center_point = (center_x, center_y)
332
+
333
+ # Mettre à jour l'historique
334
+ self.update_detection_history(track_id, box, frame_num)
335
+
336
+ # Compter seulement les détections stables
337
+ if self.is_stable_detection(track_id):
338
+ if self.is_in_region(center_point, self.poly1):
339
+ self.unique_region1_ids.add(track_id)
340
+ if self.is_in_region(center_point, self.poly2):
341
+ self.unique_region2_ids.add(track_id)
342
+
343
+ # Dessiner la boîte (vert si stable, jaune sinon)
344
+ color = (0, 255, 0) if self.is_stable_detection(track_id) else (0, 255, 255)
345
+ cv2.rectangle(display_frame, (x_min, y_min), (x_max, y_max), color, 2)
346
+
347
+ # Afficher l'ID et la confiance
348
+ label = f"ID:{track_id} {conf:.2f}"
349
+ cv2.putText(display_frame, label, (x_min, y_min - 10),
350
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
351
+
352
  except Exception as e:
353
+ draw_text_with_background(display_frame, f"Tracking error: {e}", (10, 60), bg_color=(80, 0, 0))
354
 
355
+ # Affichage du comptage
356
+ # draw_text_with_background(display_frame, f"Total Sens 1: {len(self.unique_region1_ids)}", (10, frame_height - 50))
357
+ draw_text_with_background(display_frame, f"Total comptes: {len(self.unique_region2_ids)}", (frame_width - 300, frame_height - 50))
358
 
359
+ return display_frame
360
+
361
+ def process_webcam(self, camera_id=0, display_placeholder=None, count_placeholders=None):
362
+ """Traite la vidéo en temps réel depuis une webcam"""
363
  cap = cv2.VideoCapture(camera_id)
364
  if not cap.isOpened():
365
+ st.error("⚠️ Erreur : Impossible d'ouvrir la webcam.")
 
366
  return
367
 
368
  try:
 
377
  frame_count = 0
378
  last_ts = time.time()
379
 
 
 
 
 
 
 
 
 
 
380
  while not self.stop_processing:
381
  success, frame = cap.read()
382
  if not success:
383
+ st.error("⚠️ Erreur lors de la lecture du flux vidéo.")
 
384
  break
385
 
386
  if frame_count % self.frame_skip == 0:
387
+ processed_frame = self.process_frame(frame, frame_count)
388
+ self.last_processed_frame = processed_frame
389
  now = time.time()
390
+ dt = max(1e-6, now - last_ts)
391
+ fps = 1.0 / dt
392
  last_ts = now
393
+ if processed_frame is not None:
394
+ draw_text_with_background(processed_frame, f"FPS: {fps:.1f}", (10, 30))
395
  else:
396
+ processed_frame = self.last_processed_frame if self.last_processed_frame is not None else frame
397
 
398
+ if processed_frame is not None:
399
  try:
400
+ processed_frame_rgb = cv2.cvtColor(processed_frame, cv2.COLOR_BGR2RGB)
401
  except Exception:
402
+ processed_frame_rgb = processed_frame
403
+ img = Image.fromarray(processed_frame_rgb)
404
+
405
+ if display_placeholder:
406
+ display_placeholder.image(img, channels="RGB", use_column_width=True)
407
 
408
+ if count_placeholders and len(count_placeholders) >= 2:
409
+ count_placeholders[0].metric("Véhicules Sens 1 (Vert)", len(self.unique_region1_ids))
410
+ count_placeholders[1].metric("Véhicules Sens 2 (Rouge)", len(self.unique_region2_ids))
411
 
412
  frame_count += 1
413
  time.sleep(0.01)
414
 
415
  cap.release()
416
+ st.success("✅ Flux vidéo arrêté.")
 
417
 
418
 
419
  # --- INTERFACE STREAMLIT ---
 
426
 
427
  st.title("🚗 Détection et comptage de Véhicules sur l'Autoroute de l'Avenir")
428
 
429
+ # Session state
430
  st.session_state.setdefault("webcam_active", False)
431
  st.session_state.setdefault("processor", None)
 
432
 
433
  # Modèle
434
  model_path = "best.pt"
 
452
 
453
  st.subheader("📍 Polygone 1 (vert)")
454
  poly1_input = st.text_area("Entrez 4 points (x,y) séparés par des espaces", "900,350 1150,350 700,630 200,630")
455
+
456
  st.subheader("📍 Polygone 2 (rouge)")
457
  poly2_input = st.text_area("Entrez 4 points (x,y) séparés par des espaces", "1200,350 1400,350 1150,630 743,630")
458
 
459
+ tracker_method = st.selectbox("Méthode de tracking", ["bot", "byte"], index=0)
460
 
461
  st.subheader("🚀 Paramètres d'optimisation")
462
+ frame_skip = st.slider("Skip de frames", 1, 5, 2)
463
+ downsample = st.slider("Facteur d'échelle", 0.3, 1.0, 0.5, 0.1)
464
+ conf_threshold = st.slider("Seuil de confiance", 0.1, 0.9, 0.35, 0.05)
465
+
466
+ st.subheader("🔧 Anti-duplicata")
467
+ iou_thresh = st.slider("Seuil IoU (fusion détections)", 0.1, 0.9, 0.3, 0.05)
468
+ min_area = st.slider("Surface minimale (pixels²)", 100, 2000, 500, 100)
469
 
470
  def parse_polygon(input_text):
471
  try:
 
483
 
484
  # Onglet 1: Analyse vidéo
485
  with tab1:
486
+ uploaded_file = st.file_uploader("📂 Upload une vidéo", type=["mp4", "avi", "mkv", "mov"])
 
 
487
  if uploaded_file is not None:
488
  temp_dir = tempfile.mkdtemp()
489
  ext = os.path.splitext(uploaded_file.name)[1].lower() or ".mp4"
 
502
  processor.frame_skip = frame_skip
503
  processor.downsample_factor = downsample
504
  processor.conf_threshold = conf_threshold
505
+ processor.iou_threshold = iou_thresh
506
+ processor.min_box_area = min_area
507
 
508
  start_time = time.time()
509
+ counts = processor.process_video(input_video_path, output_video_path, progress_bar=progress_bar)
 
 
 
 
510
  end_time = time.time()
511
  if counts:
512
+ count1, count2 = counts
513
+ st.success(f"✅ Traitement terminé en {end_time - start_time:.2f} s")
514
+
515
  col_result1, col_result2 = st.columns(2)
516
+ col_result1.metric("Véhicules Sens 1 (Vert)", count1)
517
+ col_result2.metric("Véhicules Sens 2 (Rouge)", count2)
518
 
519
  st.subheader("Vidéo traitée")
520
  st.video(output_video_path)
521
+
522
  with open(output_video_path, "rb") as file:
523
  st.download_button(
524
  label="⬇️ Télécharger la vidéo",
525
  data=file,
526
  file_name=f"video_traitee{ext}",
527
+ mime=f"video/{ext.strip('.')}",
528
  )
529
  else:
530
+ st.error("❌ Les coordonnées des polygones doivent contenir **exactement 4 points**.")
531
 
532
  # Onglet 2: Webcam
533
  with tab2:
534
  st.header("Détection en Temps Réel avec Webcam")
535
 
 
536
  camera_options = {"Webcam par défaut": 0}
537
  for i in range(1, 5):
538
  try:
 
547
  camera_id = camera_options[selected_camera]
548
 
549
  video_placeholder = st.empty()
 
550
  col1, col2 = st.columns(2)
551
  count_placeholders = [col1.empty(), col2.empty()]
552
 
553
+ st.info("ℹ️ Optimisations: redimensionnement, skip de frames, filtrage anti-duplicata, CUDA si disponible.")
554
 
555
  col_start, col_stop = st.columns(2)
556
 
557
  if col_start.button("▶️ Démarrer la détection en direct"):
558
  if not valid_polygons:
559
+ st.error("❌ Les coordonnées des polygones doivent contenir **exactement 4 points**.")
560
  elif st.session_state.webcam_active:
561
  st.warning("⚠️ La webcam est déjà active !")
562
  else:
 
564
  processor.frame_skip = frame_skip
565
  processor.downsample_factor = downsample
566
  processor.conf_threshold = conf_threshold
567
+ processor.iou_threshold = iou_thresh
568
+ processor.min_box_area = min_area
569
 
570
  st.session_state.processor = processor
571
  st.session_state.webcam_active = True
572
 
573
+ threading.Thread(
574
  target=st.session_state.processor.process_webcam,
575
+ args=(camera_id, video_placeholder, count_placeholders),
576
  daemon=True,
577
+ ).start()
 
 
 
578
 
579
  if col_stop.button("⏹️ Arrêter la détection"):
580
  if st.session_state.webcam_active and st.session_state.processor:
581
  st.session_state.processor.stop_processing = True
582
  st.session_state.webcam_active = False
583
+ time.sleep(0.5)
 
 
 
 
 
584
  video_placeholder.empty()
 
585
  else:
586
  st.warning("⚠️ Aucune détection en cours !")
587