"""Streamlit app that detects, tracks and counts vehicles with a YOLO model.

Two modes are offered: offline analysis of an uploaded video, and live
detection from a webcam. Vehicles are counted once per track id when the
centre of their bounding box lies inside one of two user-defined polygons.
"""

import os
import shutil
import tempfile
import threading
import time
from collections import deque

import cv2
import numpy as np
import streamlit as st
import torch
from PIL import Image
from ultralytics import YOLO

try:
    # Needed so a background thread can issue Streamlit commands.
    from streamlit.runtime.scriptrunner import add_script_run_ctx
except ImportError:  # older Streamlit layouts do not expose this module path
    add_script_run_ctx = None


# MIME types for the upload formats accepted by the UI. Building the type as
# "video/<ext>" yields invalid values for avi/mkv/mov.
_VIDEO_MIME_TYPES = {
    ".mp4": "video/mp4",
    ".avi": "video/x-msvideo",
    ".mkv": "video/x-matroska",
    ".mov": "video/quicktime",
}


# --- UTILITY FUNCTIONS ---
def draw_text_with_background(
    image,
    text,
    position,
    font=cv2.FONT_HERSHEY_SIMPLEX,
    font_scale=1,
    font_thickness=2,
    text_color=(255, 255, 255),
    bg_color=(0, 0, 0),
    padding=5,
):
    """Draw ``text`` on an OpenCV image over a filled background rectangle.

    The rectangle corners are clamped to the image bounds so a position close
    to (or beyond) an edge does not produce out-of-range coordinates.

    Args:
        image: Image array modified in place.
        text: String to render.
        position: ``(x, y)`` of the text baseline origin.
        font: OpenCV font face.
        font_scale: OpenCV font scale.
        font_thickness: Stroke thickness of the text.
        text_color: Colour tuple for the text.
        bg_color: Colour tuple for the background rectangle.
        padding: Margin, in pixels, around the text inside the rectangle.
    """
    (text_width, text_height), _ = cv2.getTextSize(text, font, font_scale, font_thickness)
    x, y = position
    max_x = image.shape[1] - 1
    max_y = image.shape[0] - 1

    tl_x = max(0, x)
    tl_y = max(0, y - text_height - padding)
    br_x = min(max_x, x + text_width + padding * 2)
    br_y = min(max_y, y + padding)

    cv2.rectangle(image, (tl_x, tl_y), (br_x, br_y), bg_color, -1)
    cv2.putText(
        image,
        text,
        (tl_x + padding, min(y, max_y)),
        font,
        font_scale,
        text_color,
        font_thickness,
        cv2.LINE_AA,
    )


def parse_polygon(input_text):
    """Parse ``"x,y x,y ..."`` (space or ``;`` separated) into integer points.

    Args:
        input_text: Raw text typed by the user.

    Returns:
        A list of ``(x, y)`` integer tuples, or an empty list if any token is
        malformed.
    """
    try:
        points = []
        for token in input_text.replace(";", " ").split():
            x, y = token.split(",")
            points.append((int(x), int(y)))
        return points
    except ValueError:
        # Raised both by a wrong number of comma-separated parts and by int().
        return []


def video_mime_type(ext):
    """Return the MIME type for a video file extension such as ``".mp4"``.

    Unknown extensions fall back to ``video/<ext without dot>``.
    """
    return _VIDEO_MIME_TYPES.get(ext.lower(), f"video/{ext.strip('.')}")


# --- YOLO PROCESSOR ---
class YOLOVideoProcessor:
    """Run YOLO detection + tracking on frames and count vehicles per region."""

    def __init__(self, model_path, poly1, poly2, tracker_method="bot"):
        """Load the model and initialise counters and tuning parameters.

        Args:
            model_path: Path or name of the YOLO weights passed to ``YOLO()``.
            poly1: Points of the first counting polygon.
            poly2: Points of the second counting polygon.
            tracker_method: ``"bot"`` selects ``botsort.yaml``; any other value
                selects ``bytetrack.yaml``.
        """
        # Device
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Speed/accuracy trade-offs
        self.frame_skip = 2
        self.downsample_factor = 0.5
        self.img_size = 640
        self.conf_threshold = 0.35

        # Model
        self.model = YOLO(model_path)
        self.model.to(self.device)

        # Tracking
        self.tracker_method = tracker_method
        self.tracker_config = (
            "botsort.yaml" if self.tracker_method.lower() == "bot" else "bytetrack.yaml"
        )

        # State
        self.unique_region1_ids = set()
        self.unique_region2_ids = set()
        self.poly1 = poly1
        self.poly2 = poly2
        self.stop_processing = False
        self.last_processed_frame = None
        self.current_frame = 0  # not read by this class; kept for compatibility

        # Anti-duplicate parameters (e.g. several boxes on one long truck)
        self.iou_threshold = 0.3  # boxes overlapping at or above this IoU are merged
        self.min_box_area = 500  # minimum box area, in pixels, to be kept
        self.max_aspect_ratio = 5.0  # maximum height/width ratio to be kept

        # Per-track history used for temporal filtering:
        # {track_id: {'boxes': deque, 'frames': deque}}
        self.detection_history = {}
        self.history_length = 5  # number of observations kept per track
        self.min_stable_frames = 3  # observations required before counting
        # Tracks unseen for more than this many frames are dropped from the
        # history so it cannot grow without bound on long webcam sessions.
        self.max_track_age = 150

    @staticmethod
    def is_in_region(center, poly):
        """Return True if ``center`` lies inside or on the edge of ``poly``."""
        poly_np = np.array(poly, dtype=np.int32)
        return cv2.pointPolygonTest(poly_np, center, False) >= 0

    def reset_counts(self):
        """Clear both region counters and the per-track detection history."""
        self.unique_region1_ids.clear()
        self.unique_region2_ids.clear()
        self.detection_history.clear()

    def _pick_fourcc(self, output_path):
        """Choose a codec from the output extension: mp4v for .mp4, else XVID.

        NOTE(review): browsers generally cannot play mp4v/XVID streams, so the
        in-page preview of the processed video may stay blank even though the
        downloaded file is valid - confirm on the target deployment.
        """
        ext = os.path.splitext(output_path)[1].lower()
        if ext == ".mp4":
            return cv2.VideoWriter_fourcc(*"mp4v")
        return cv2.VideoWriter_fourcc(*"XVID")

    def calculate_iou(self, box1, box2):
        """Return the Intersection over Union of two ``[x_min, y_min, x_max, y_max]`` boxes.

        Returns 0 when the union area is not positive (degenerate boxes).
        """
        x1_min, y1_min, x1_max, y1_max = box1
        x2_min, y2_min, x2_max, y2_max = box2

        # Intersection
        inter_w = max(0, min(x1_max, x2_max) - max(x1_min, x2_min))
        inter_h = max(0, min(y1_max, y2_max) - max(y1_min, y2_min))
        inter_area = inter_w * inter_h

        # Union
        box1_area = (x1_max - x1_min) * (y1_max - y1_min)
        box2_area = (x2_max - x2_min) * (y2_max - y2_min)
        union_area = box1_area + box2_area - inter_area

        if union_area <= 0:
            return 0
        return inter_area / union_area

    def filter_overlapping_detections(self, boxes_coords, track_ids, confidences):
        """Drop implausible boxes, then apply greedy non-maximum suppression.

        Boxes smaller than ``min_box_area`` or with a height/width ratio above
        ``max_aspect_ratio`` are discarded. Remaining boxes are visited by
        decreasing confidence and kept only if their IoU with every box
        already kept is below ``iou_threshold``.

        Args:
            boxes_coords: Sequence of ``[x_min, y_min, x_max, y_max]`` boxes.
            track_ids: Track id for each box (parallel to ``boxes_coords``).
            confidences: Confidence for each box (parallel to ``boxes_coords``).

        Returns:
            Three parallel lists ``(boxes, track_ids, confidences)`` of the
            kept detections, ordered by decreasing confidence.
        """
        if len(boxes_coords) == 0:
            return [], [], []

        candidates = []
        for index, (box, conf) in enumerate(zip(boxes_coords, confidences)):
            x_min, y_min, x_max, y_max = box
            area = (x_max - x_min) * (y_max - y_min)
            aspect_ratio = (y_max - y_min) / max(1, x_max - x_min)
            if area < self.min_box_area or aspect_ratio > self.max_aspect_ratio:
                continue
            candidates.append((index, box, conf))

        # Stable sort: ties keep their original order.
        candidates.sort(key=lambda item: item[2], reverse=True)

        kept_boxes = []
        keep_indices = []
        for index, box, _ in candidates:
            if all(self.calculate_iou(kept, box) < self.iou_threshold for kept in kept_boxes):
                kept_boxes.append(box)
                keep_indices.append(index)

        filtered_boxes = [boxes_coords[i] for i in keep_indices]
        filtered_ids = [track_ids[i] for i in keep_indices]
        filtered_confs = [confidences[i] for i in keep_indices]
        return filtered_boxes, filtered_ids, filtered_confs

    def update_detection_history(self, track_id, box, frame_num):
        """Record one observation of ``track_id``, keeping the last ``history_length``."""
        history = self.detection_history.get(track_id)
        if history is None:
            maxlen = max(1, int(self.history_length))
            history = {"boxes": deque(maxlen=maxlen), "frames": deque(maxlen=maxlen)}
            self.detection_history[track_id] = history
        # deque(maxlen=...) discards the oldest entry automatically.
        history["boxes"].append(box)
        history["frames"].append(frame_num)

    def is_stable_detection(self, track_id):
        """Return True once ``track_id`` has at least ``min_stable_frames`` observations."""
        history = self.detection_history.get(track_id)
        if history is None:
            return False
        return len(history["boxes"]) >= self.min_stable_frames

    def _prune_stale_tracks(self, frame_num):
        """Forget tracks whose last observation is older than ``max_track_age`` frames."""
        stale_ids = [
            track_id
            for track_id, history in self.detection_history.items()
            if history["frames"] and frame_num - history["frames"][-1] > self.max_track_age
        ]
        for track_id in stale_ids:
            del self.detection_history[track_id]

    def process_video(self, video_path, output_path, progress_bar=None):
        """Process a recorded video and write the annotated result.

        Only one frame out of ``frame_skip`` is run through the model; the
        frames in between reuse the last annotated frame.

        Args:
            video_path: Path of the input video.
            output_path: Path of the annotated video to write.
            progress_bar: Optional object with a ``progress(float)`` method.

        Returns:
            ``(count_region1, count_region2)`` on success, or ``None`` if the
            input or output could not be opened or no frame was written (an
            error is shown through Streamlit in those cases).
        """
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            cap.release()
            st.error("⚠️ Erreur : Impossible d'ouvrir la vidéo.")
            return None

        out = None
        frames_written = 0
        try:
            frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fps = cap.get(cv2.CAP_PROP_FPS)
            # Written so that 0, None and NaN all fall back to 30 fps.
            if not (fps and fps > 1e-3):
                fps = 30.0

            fourcc = self._pick_fourcc(output_path)
            out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))
            if not out.isOpened():
                st.error("⚠️ Erreur : Impossible d'ouvrir la vidéo de sortie (codec).")
                return None

            self.reset_counts()
            self.last_processed_frame = None
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            skip = max(1, int(self.frame_skip))  # guards the modulo against 0

            while True:
                success, frame = cap.read()
                if not success:
                    break

                if frames_written % skip == 0:
                    processed_frame = self.process_frame(frame, frames_written)
                    self.last_processed_frame = processed_frame
                else:
                    processed_frame = self.last_processed_frame
                if processed_frame is None:
                    processed_frame = frame

                # The writer only accepts frames of the size it was opened with.
                if (
                    processed_frame.shape[1] != frame_width
                    or processed_frame.shape[0] != frame_height
                ):
                    processed_frame = cv2.resize(
                        processed_frame,
                        (frame_width, frame_height),
                        interpolation=cv2.INTER_AREA,
                    )

                out.write(processed_frame)
                frames_written += 1

                # Updated after the frame is counted so the bar can reach 100%.
                if progress_bar is not None and total_frames > 0:
                    progress_bar.progress(min(1.0, frames_written / float(total_frames)))
        finally:
            # Released even if the model or writer raises mid-loop. No
            # cv2.destroyAllWindows(): no window is created here, and the call
            # fails on headless OpenCV builds.
            cap.release()
            if out is not None:
                out.release()

        if frames_written == 0:
            st.error("⚠️ Aucune image n'a été écrite dans la vidéo de sortie !")
            return None

        if progress_bar is not None:
            progress_bar.progress(1.0)
        return len(self.unique_region1_ids), len(self.unique_region2_ids)

    def process_frame(self, frame, frame_num=0):
        """Detect, track and count vehicles on one frame.

        The frame is optionally downscaled before inference; boxes are scaled
        back to the original resolution, filtered for duplicates, and a track
        is counted for a region once it is stable and its box centre lies in
        that region's polygon.

        Args:
            frame: BGR image, or ``None``.
            frame_num: Index of the frame, used for the detection history.

        Returns:
            An annotated copy of ``frame``, or ``None`` if ``frame`` is ``None``.
        """
        if frame is None:
            return None

        # Downscale to speed up inference.
        orig_height, orig_width = frame.shape[:2]
        resized_width, resized_height = orig_width, orig_height
        if self.downsample_factor < 1.0:
            resized_width = max(1, int(orig_width * self.downsample_factor))
            resized_height = max(1, int(orig_height * self.downsample_factor))
            resized_frame = cv2.resize(
                frame, (resized_width, resized_height), interpolation=cv2.INTER_AREA
            )
        else:
            resized_frame = frame

        # Detection + tracking
        with torch.no_grad():
            results = self.model.track(
                resized_frame,
                persist=True,
                tracker=self.tracker_config,
                conf=self.conf_threshold,
                imgsz=self.img_size,
                device=self.device,
                # COCO ids: 2=car, 5=bus, 7=truck.
                # NOTE(review): confirm these ids match the class list of the
                # custom weights when a non-COCO model is loaded.
                classes=[2, 5, 7],
                verbose=False,
            )

        display_frame = frame.copy()
        frame_height, frame_width = display_frame.shape[:2]

        # Draw both counting polygons.
        cv2.polylines(
            display_frame, [np.array(self.poly1, np.int32)],
            isClosed=True, color=(0, 255, 0), thickness=2,
        )
        cv2.polylines(
            display_frame, [np.array(self.poly2, np.int32)],
            isClosed=True, color=(255, 0, 0), thickness=2,
        )

        # Factors mapping inference coordinates back to the original frame.
        scale_x = orig_width / float(resized_width)
        scale_y = orig_height / float(resized_height)

        if results and len(results) > 0 and getattr(results[0], "boxes", None) is not None:
            # Deliberate best-effort: a failure here is rendered on the frame
            # instead of aborting the whole video/stream.
            try:
                detections = results[0].boxes
                boxes = detections.xywh.cpu().numpy()
                ids_tensor = detections.id
                confs = detections.conf.cpu().numpy()

                if ids_tensor is None:
                    track_ids = [None] * len(boxes)
                else:
                    track_ids = ids_tensor.int().cpu().tolist()

                # Convert centre/size boxes to clamped corner boxes.
                boxes_coords = []
                for x, y, w, h in boxes:
                    center_x = int(x * scale_x)
                    center_y = int(y * scale_y)
                    half_w = int(w * scale_x) // 2
                    half_h = int(h * scale_y) // 2
                    boxes_coords.append([
                        max(0, center_x - half_w),
                        max(0, center_y - half_h),
                        min(frame_width - 1, center_x + half_w),
                        min(frame_height - 1, center_y + half_h),
                    ])

                filtered_boxes, filtered_ids, filtered_confs = (
                    self.filter_overlapping_detections(boxes_coords, track_ids, confs)
                )

                for box, track_id, conf in zip(filtered_boxes, filtered_ids, filtered_confs):
                    if track_id is None:
                        continue

                    x_min, y_min, x_max, y_max = box
                    center_point = ((x_min + x_max) // 2, (y_min + y_max) // 2)

                    self.update_detection_history(track_id, box, frame_num)
                    stable = self.is_stable_detection(track_id)

                    # Only stable tracks are counted.
                    if stable:
                        if self.is_in_region(center_point, self.poly1):
                            self.unique_region1_ids.add(track_id)
                        if self.is_in_region(center_point, self.poly2):
                            self.unique_region2_ids.add(track_id)

                    # Green when stable, yellow otherwise.
                    color = (0, 255, 0) if stable else (0, 255, 255)
                    cv2.rectangle(display_frame, (x_min, y_min), (x_max, y_max), color, 2)

                    label = f"ID:{track_id} {conf:.2f}"
                    # Clamp so the label stays visible for boxes at the top edge.
                    label_y = max(15, y_min - 10)
                    cv2.putText(
                        display_frame, label, (x_min, label_y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2,
                    )

                self._prune_stale_tracks(frame_num)
            except Exception as e:
                draw_text_with_background(
                    display_frame, f"Tracking error: {e}", (10, 60), bg_color=(80, 0, 0)
                )

        # Only the region-2 total is drawn on the frame.
        draw_text_with_background(
            display_frame,
            f"Total: {len(self.unique_region2_ids)}",
            (frame_width - 300, frame_height - 50),
        )

        return display_frame

    def process_webcam(self, camera_id=0, display_placeholder=None, count_placeholders=None):
        """Process a live camera stream until ``stop_processing`` becomes True.

        Args:
            camera_id: Index passed to ``cv2.VideoCapture``.
            display_placeholder: Optional Streamlit placeholder receiving frames.
            count_placeholders: Optional sequence of at least two Streamlit
                placeholders receiving the two region counts.
        """
        cap = cv2.VideoCapture(camera_id)
        if not cap.isOpened():
            cap.release()
            st.error("⚠️ Erreur : Impossible d'ouvrir la webcam.")
            return

        try:
            # Best-effort capture settings; a camera may refuse them.
            try:
                cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
                cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
                cap.set(cv2.CAP_PROP_FPS, 30)
            except cv2.error:
                pass

            self.reset_counts()
            self.stop_processing = False
            self.last_processed_frame = None
            frame_count = 0
            last_ts = time.time()
            skip = max(1, int(self.frame_skip))  # guards the modulo against 0

            while not self.stop_processing:
                success, frame = cap.read()
                if not success:
                    st.error("⚠️ Erreur lors de la lecture du flux vidéo.")
                    break

                if frame_count % skip == 0:
                    processed_frame = self.process_frame(frame, frame_count)
                    self.last_processed_frame = processed_frame

                    # FPS is measured between two processed frames.
                    now = time.time()
                    fps = 1.0 / max(1e-6, now - last_ts)
                    last_ts = now
                    if processed_frame is not None:
                        draw_text_with_background(processed_frame, f"FPS: {fps:.1f}", (10, 30))
                else:
                    processed_frame = (
                        self.last_processed_frame
                        if self.last_processed_frame is not None
                        else frame
                    )

                if processed_frame is not None:
                    try:
                        processed_frame_rgb = cv2.cvtColor(processed_frame, cv2.COLOR_BGR2RGB)
                    except cv2.error:
                        processed_frame_rgb = processed_frame
                    img = Image.fromarray(processed_frame_rgb)

                    if display_placeholder:
                        display_placeholder.image(img, channels="RGB", use_column_width=True)

                    if count_placeholders and len(count_placeholders) >= 2:
                        count_placeholders[0].metric(
                            "Véhicules Sens 1 (Vert)", len(self.unique_region1_ids)
                        )
                        count_placeholders[1].metric(
                            "Véhicules Sens 2 (Rouge)", len(self.unique_region2_ids)
                        )

                frame_count += 1
                time.sleep(0.01)
        finally:
            # Always free the camera, even if a Streamlit call raises.
            cap.release()

        st.success("✅ Flux vidéo arrêté.")


# --- STREAMLIT INTERFACE ---
def _resolve_model_path():
    """Return the weights to load: local best.pt, else Hub download, else yolov8n.pt."""
    model_path = "best.pt"
    if os.path.exists(model_path):
        return model_path

    with st.spinner("📥 Chargement du modèle YOLO..."):
        # Top-level boundary: import, network and auth failures all fall back
        # to the public model, so the broad except is intentional.
        try:
            from huggingface_hub import hf_hub_download

            model_path = hf_hub_download(
                repo_id="ModuMLTECH/Trafic_congestion", filename="best.pt"
            )
            st.success("✅ Modèle chargé depuis Hugging Face Hub.")
        except Exception as e:
            st.error(f"❌ Erreur lors du chargement du modèle: {e}")
            st.warning("⚠️ Utilisation du modèle YOLO public à la place (yolov8n.pt).")
            model_path = "yolov8n.pt"
    return model_path


def _build_processor(model_path, poly1, poly2, tracker_method, settings):
    """Create a processor and apply the sidebar tuning values in ``settings``."""
    processor = YOLOVideoProcessor(model_path, poly1, poly2, tracker_method)
    processor.frame_skip = settings["frame_skip"]
    processor.downsample_factor = settings["downsample"]
    processor.conf_threshold = settings["conf_threshold"]
    processor.iou_threshold = settings["iou_threshold"]
    processor.min_box_area = settings["min_box_area"]
    return processor


def _probe_cameras(max_index=4):
    """Return ``{label: index}`` for the default camera plus any index 1..max_index that opens."""
    cameras = {"Webcam par défaut": 0}
    for index in range(1, max_index + 1):
        try:
            cap = cv2.VideoCapture(index)
        except cv2.error:
            continue
        try:
            if cap.isOpened():
                cameras[f"Caméra {index}"] = index
        finally:
            cap.release()
    return cameras


def _stage_upload(uploaded_file):
    """Write the uploaded video to a temp dir once per file and return its paths.

    Streamlit reruns the script on every interaction; caching the staged file
    in the session state avoids creating a new temp directory on each rerun.

    Returns:
        ``(input_path, output_path, ext)``.
    """
    ext = os.path.splitext(uploaded_file.name)[1].lower() or ".mp4"
    signature = (
        uploaded_file.name,
        getattr(uploaded_file, "size", None),
        getattr(uploaded_file, "file_id", None),
    )

    cached = st.session_state.get("upload_cache")
    if cached is None or cached["signature"] != signature:
        if cached is not None:
            shutil.rmtree(cached["dir"], ignore_errors=True)
        temp_dir = tempfile.mkdtemp()
        input_path = os.path.join(temp_dir, f"input_video{ext}")
        with open(input_path, "wb") as f:
            f.write(uploaded_file.getbuffer())
        cached = {
            "signature": signature,
            "dir": temp_dir,
            "input": input_path,
            "output": os.path.join(temp_dir, f"output_video{ext}"),
        }
        st.session_state["upload_cache"] = cached
    return cached["input"], cached["output"], ext


def main():
    """Build the Streamlit UI and dispatch to video or webcam processing."""
    st.set_page_config(
        page_title="Détecteur de Véhicules",
        page_icon="🚗",
        layout="wide",
    )

    st.title("🚗 Détection et comptage de Véhicules sur l'Autoroute de l'Avenir")

    # Session state
    st.session_state.setdefault("webcam_active", False)
    st.session_state.setdefault("processor", None)
    st.session_state.setdefault("webcam_thread", None)

    model_path = _resolve_model_path()

    # Tabs
    tab1, tab2 = st.tabs(["📹 Analyse de Vidéo", "🎥 Détection en Temps Réel"])

    # Sidebar
    with st.sidebar:
        st.header("🔹 Paramètres")
        st.subheader("📍 Polygone 1 (vert)")
        # Explicit keys: both text areas share the same label.
        poly1_input = st.text_area(
            "Entrez 4 points (x,y) séparés par des espaces",
            "0,0 0,0 0,0 0,0",
            key="poly1_input",
        )
        st.subheader("📍 Polygone 2 (rouge)")
        poly2_input = st.text_area(
            "Entrez 4 points (x,y) séparés par des espaces",
            "500,150 700,150 1100,530 630,530",
            key="poly2_input",
        )
        tracker_method = st.selectbox("Méthode de tracking", ["bot", "byte"], index=0)

        st.subheader("🚀 Paramètres d'optimisation")
        frame_skip = st.slider("Skip de frames", 1, 5, 2)
        downsample = st.slider("Facteur d'échelle", 0.3, 1.0, 0.5, 0.1)
        conf_threshold = st.slider("Seuil de confiance", 0.1, 0.9, 0.35, 0.05)

        st.subheader("🔧 Anti-duplicata")
        iou_thresh = st.slider("Seuil IoU (fusion détections)", 0.1, 0.9, 0.3, 0.05)
        min_area = st.slider("Surface minimale (pixels²)", 100, 2000, 500, 100)

    settings = {
        "frame_skip": frame_skip,
        "downsample": downsample,
        "conf_threshold": conf_threshold,
        "iou_threshold": iou_thresh,
        "min_box_area": min_area,
    }

    poly1 = parse_polygon(poly1_input)
    poly2 = parse_polygon(poly2_input)
    valid_polygons = len(poly1) == 4 and len(poly2) == 4

    # Tab 1: uploaded video analysis
    with tab1:
        uploaded_file = st.file_uploader(
            "📂 Upload une vidéo", type=["mp4", "avi", "mkv", "mov"]
        )

        if uploaded_file is not None:
            input_video_path, output_video_path, ext = _stage_upload(uploaded_file)
            st.video(input_video_path)

            if st.button("▶️ Lancer la détection"):
                if valid_polygons:
                    progress_bar = st.progress(0)
                    processor = _build_processor(
                        model_path, poly1, poly2, tracker_method, settings
                    )

                    start_time = time.time()
                    counts = processor.process_video(
                        input_video_path, output_video_path, progress_bar=progress_bar
                    )
                    end_time = time.time()

                    if counts:
                        count1, count2 = counts
                        st.success(f"✅ Traitement terminé en {end_time - start_time:.2f} s")

                        col_result1, col_result2 = st.columns(2)
                        col_result1.metric("Véhicules Sens 1 (Vert)", count1)
                        col_result2.metric("Véhicules Sens 2 (Rouge)", count2)

                        st.subheader("Vidéo traitée")
                        st.video(output_video_path)

                        with open(output_video_path, "rb") as file:
                            st.download_button(
                                label="⬇️ Télécharger la vidéo",
                                data=file,
                                file_name=f"video_traitee{ext}",
                                mime=video_mime_type(ext),
                            )
                else:
                    st.error(
                        "❌ Les coordonnées des polygones doivent contenir **exactement 4 points**."
                    )

    # Tab 2: webcam
    with tab2:
        st.header("Détection en Temps Réel avec Webcam")

        # Probe once per session: reopening devices on every rerun is slow and
        # can disturb a capture that is already running.
        if "camera_options" not in st.session_state:
            st.session_state["camera_options"] = _probe_cameras()
        camera_options = st.session_state["camera_options"]

        selected_camera = st.selectbox("Sélectionnez la source vidéo", list(camera_options.keys()))
        camera_id = camera_options[selected_camera]

        video_placeholder = st.empty()
        col1, col2 = st.columns(2)
        count_placeholders = [col1.empty(), col2.empty()]

        st.info(
            "ℹ️ Optimisations: redimensionnement, skip de frames, filtrage anti-duplicata, CUDA si disponible."
        )

        # The flag alone can go stale if the worker stopped by itself (e.g. a
        # camera read error), so also check that the thread is still alive.
        worker = st.session_state.webcam_thread
        if st.session_state.webcam_active and (worker is None or not worker.is_alive()):
            st.session_state.webcam_active = False

        col_start, col_stop = st.columns(2)

        if col_start.button("▶️ Démarrer la détection en direct"):
            if not valid_polygons:
                st.error(
                    "❌ Les coordonnées des polygones doivent contenir **exactement 4 points**."
                )
            elif st.session_state.webcam_active:
                st.warning("⚠️ La webcam est déjà active !")
            else:
                processor = _build_processor(
                    model_path, poly1, poly2, tracker_method, settings
                )
                st.session_state.processor = processor
                st.session_state.webcam_active = True

                worker = threading.Thread(
                    target=processor.process_webcam,
                    args=(camera_id, video_placeholder, count_placeholders),
                    daemon=True,
                )
                # Streamlit commands issued from a thread need the script run
                # context attached to that thread.
                # NOTE(review): writing to the page from a background thread
                # is only partially supported by Streamlit - confirm the
                # behaviour on the deployed version.
                if add_script_run_ctx is not None:
                    add_script_run_ctx(worker)
                st.session_state.webcam_thread = worker
                worker.start()

        if col_stop.button("⏹️ Arrêter la détection"):
            if st.session_state.webcam_active and st.session_state.processor:
                st.session_state.processor.stop_processing = True
                st.session_state.webcam_active = False
                worker = st.session_state.webcam_thread
                if worker is not None:
                    # Give the loop time to notice the flag and free the camera.
                    worker.join(timeout=1.0)
                else:
                    time.sleep(0.5)
                st.session_state.webcam_thread = None
                video_placeholder.empty()
            else:
                st.warning("⚠️ Aucune détection en cours !")


if __name__ == "__main__":
    main()