# Streamlit app — vehicle detection and counting (Hugging Face Space)
| import streamlit as st | |
| import cv2 | |
| import tempfile | |
| import os | |
| import time | |
| import numpy as np | |
| from ultralytics import YOLO | |
| import threading | |
| from PIL import Image | |
| import torch | |
# --- UTILITY FUNCTIONS ---
def draw_text_with_background(
    image,
    text,
    position,
    font=cv2.FONT_HERSHEY_SIMPLEX,
    font_scale=1,
    font_thickness=2,
    text_color=(255, 255, 255),
    bg_color=(0, 0, 0),
    padding=5,
):
    """Draw *text* onto an OpenCV image over a filled background rectangle.

    The rectangle corners are clamped to the image bounds so that drawing
    near an edge never produces out-of-range coordinates. The image is
    modified in place; nothing is returned.
    """
    (txt_w, txt_h), _ = cv2.getTextSize(text, font, font_scale, font_thickness)
    anchor_x, anchor_y = position
    img_h, img_w = image.shape[0], image.shape[1]
    # Background box corners, clamped inside the image.
    top_left = (max(0, anchor_x), max(0, anchor_y - txt_h - padding))
    bottom_right = (
        min(img_w - 1, anchor_x + txt_w + padding * 2),
        min(img_h - 1, anchor_y + padding),
    )
    cv2.rectangle(image, top_left, bottom_right, bg_color, -1)
    # Text baseline sits just inside the padded box.
    baseline = (top_left[0] + padding, min(anchor_y, img_h - 1))
    cv2.putText(
        image,
        text,
        baseline,
        font,
        font_scale,
        text_color,
        font_thickness,
        cv2.LINE_AA,
    )
# --- OPTIMIZED YOLO PROCESSOR CLASS ---
class YOLOVideoProcessor:
    """YOLO-based vehicle detector/tracker with per-region counting and
    anti-duplicate filtering, used for both recorded video and live webcam."""

    def __init__(self, model_path, poly1, poly2, tracker_method="bot"):
        # Prefer the GPU when one is available.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Speed/accuracy knobs (overridden later by the sidebar sliders).
        self.frame_skip = 2
        self.downsample_factor = 0.5
        self.img_size = 640
        self.conf_threshold = 0.35
        # Load the model and move it to the selected device.
        self.model = YOLO(model_path)
        self.model.to(self.device)
        # Tracker selection: "bot" -> BoT-SORT, anything else -> ByteTrack.
        self.tracker_method = tracker_method
        self.tracker_config = (
            "botsort.yaml" if self.tracker_method.lower() == "bot" else "bytetrack.yaml"
        )
        # Counting state: unique track IDs seen inside each polygon.
        self.unique_region1_ids = set()
        self.unique_region2_ids = set()
        self.poly1 = poly1
        self.poly2 = poly2
        self.stop_processing = False
        self.last_processed_frame = None
        self.current_frame = 0
        # Anti-duplicate parameters (long trucks often yield several boxes).
        self.iou_threshold = 0.3     # IoU above which overlapping detections merge
        self.min_box_area = 500      # minimum area (px^2) to count as a vehicle
        self.max_aspect_ratio = 5.0  # reject overly tall/thin (stretched) boxes
        # Temporal filtering: rolling detection history per track.
        self.detection_history = {}  # {track_id: {'boxes': [...], 'frames': [...]}}
        self.history_length = 5      # number of frames kept per track
| def is_in_region(center, poly): | |
| poly_np = np.array(poly, dtype=np.int32) | |
| return cv2.pointPolygonTest(poly_np, center, False) >= 0 | |
| def reset_counts(self): | |
| self.unique_region1_ids.clear() | |
| self.unique_region2_ids.clear() | |
| self.detection_history.clear() | |
| def _pick_fourcc(self, output_path): | |
| ext = os.path.splitext(output_path)[1].lower() | |
| if ext == ".mp4": | |
| return cv2.VideoWriter_fourcc(*"mp4v") | |
| return cv2.VideoWriter_fourcc(*"XVID") | |
| def calculate_iou(self, box1, box2): | |
| """Calcule l'IoU (Intersection over Union) entre deux boîtes""" | |
| x1_min, y1_min, x1_max, y1_max = box1 | |
| x2_min, y2_min, x2_max, y2_max = box2 | |
| # Intersection | |
| inter_x_min = max(x1_min, x2_min) | |
| inter_y_min = max(y1_min, y2_min) | |
| inter_x_max = min(x1_max, x2_max) | |
| inter_y_max = min(y1_max, y2_max) | |
| inter_area = max(0, inter_x_max - inter_x_min) * max(0, inter_y_max - inter_y_min) | |
| # Union | |
| box1_area = (x1_max - x1_min) * (y1_max - y1_min) | |
| box2_area = (x2_max - x2_min) * (y2_max - y2_min) | |
| union_area = box1_area + box2_area - inter_area | |
| if union_area == 0: | |
| return 0 | |
| return inter_area / union_area | |
| def filter_overlapping_detections(self, boxes_coords, track_ids, confidences): | |
| """Filtre les détections qui se chevauchent (ex: plusieurs détections sur un camion)""" | |
| if len(boxes_coords) == 0: | |
| return [], [], [] | |
| # Créer une liste de détections avec leurs indices | |
| detections = [] | |
| for i, (box, tid, conf) in enumerate(zip(boxes_coords, track_ids, confidences)): | |
| x_min, y_min, x_max, y_max = box | |
| area = (x_max - x_min) * (y_max - y_min) | |
| aspect_ratio = (y_max - y_min) / max(1, x_max - x_min) | |
| # Filtrer les détections trop petites ou avec un aspect ratio bizarre | |
| if area < self.min_box_area or aspect_ratio > self.max_aspect_ratio: | |
| continue | |
| detections.append({ | |
| 'index': i, | |
| 'box': box, | |
| 'track_id': tid, | |
| 'conf': conf, | |
| 'area': area | |
| }) | |
| # Trier par confiance décroissante | |
| detections.sort(key=lambda x: x['conf'], reverse=True) | |
| # Non-Maximum Suppression manuel | |
| keep_indices = [] | |
| while len(detections) > 0: | |
| # Garder la détection avec la plus haute confiance | |
| best = detections.pop(0) | |
| keep_indices.append(best['index']) | |
| # Supprimer les détections qui se chevauchent trop avec la meilleure | |
| filtered_detections = [] | |
| for det in detections: | |
| iou = self.calculate_iou(best['box'], det['box']) | |
| if iou < self.iou_threshold: # Garder si IoU faible (pas de chevauchement) | |
| filtered_detections.append(det) | |
| detections = filtered_detections | |
| # Retourner les détections filtrées | |
| filtered_boxes = [boxes_coords[i] for i in keep_indices] | |
| filtered_ids = [track_ids[i] for i in keep_indices] | |
| filtered_confs = [confidences[i] for i in keep_indices] | |
| return filtered_boxes, filtered_ids, filtered_confs | |
| def update_detection_history(self, track_id, box, frame_num): | |
| """Met à jour l'historique des détections pour un véhicule""" | |
| if track_id not in self.detection_history: | |
| self.detection_history[track_id] = {'boxes': [], 'frames': []} | |
| self.detection_history[track_id]['boxes'].append(box) | |
| self.detection_history[track_id]['frames'].append(frame_num) | |
| # Garder seulement les N dernières frames | |
| if len(self.detection_history[track_id]['boxes']) > self.history_length: | |
| self.detection_history[track_id]['boxes'].pop(0) | |
| self.detection_history[track_id]['frames'].pop(0) | |
| def is_stable_detection(self, track_id): | |
| """Vérifie si une détection est stable (pas un faux positif temporaire)""" | |
| if track_id not in self.detection_history: | |
| return False | |
| # Considérer stable si détecté sur au moins 3 frames | |
| return len(self.detection_history[track_id]['boxes']) >= 3 | |
| def process_video(self, video_path, output_path, progress_bar=None): | |
| """Traite une vidéo enregistrée avec optimisations""" | |
| cap = cv2.VideoCapture(video_path) | |
| if not cap.isOpened(): | |
| st.error("⚠️ Erreur : Impossible d'ouvrir la vidéo.") | |
| return | |
| frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) | |
| frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) | |
| fps = cap.get(cv2.CAP_PROP_FPS) | |
| if not fps or fps <= 1e-3: | |
| fps = 30.0 | |
| fourcc = self._pick_fourcc(output_path) | |
| out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height)) | |
| if not out.isOpened(): | |
| st.error("⚠️ Erreur : Impossible d'ouvrir la vidéo de sortie (codec).") | |
| cap.release() | |
| return | |
| self.reset_counts() | |
| processed_frames = 0 | |
| total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) | |
| frame_count = 0 | |
| while cap.isOpened(): | |
| success, frame = cap.read() | |
| if not success: | |
| break | |
| # Progression | |
| if progress_bar is not None and total_frames > 0: | |
| progress = min(1.0, processed_frames / float(total_frames)) | |
| progress_bar.progress(progress) | |
| # Skip de frames | |
| if frame_count % self.frame_skip == 0: | |
| processed_frame = self.process_frame(frame, frame_count) | |
| self.last_processed_frame = processed_frame | |
| else: | |
| processed_frame = self.last_processed_frame if self.last_processed_frame is not None else frame | |
| if processed_frame is None: | |
| processed_frame = frame | |
| # S'assurer de la taille attendue | |
| if processed_frame.shape[1] != frame_width or processed_frame.shape[0] != frame_height: | |
| processed_frame = cv2.resize(processed_frame, (frame_width, frame_height), interpolation=cv2.INTER_AREA) | |
| out.write(processed_frame) | |
| processed_frames += 1 | |
| frame_count += 1 | |
| cap.release() | |
| out.release() | |
| cv2.destroyAllWindows() | |
| if processed_frames == 0: | |
| st.error("⚠️ Aucune image n'a été écrite dans la vidéo de sortie !") | |
| return len(self.unique_region1_ids), len(self.unique_region2_ids) | |
    def process_frame(self, frame, frame_num=0):
        """Run YOLO detection + tracking on one frame and return an annotated
        copy (polygons, boxes, IDs, running count), with anti-duplicate
        filtering. Returns None for a None input frame.
        """
        if frame is None:
            return None
        # Downscale the input to speed up inference; detections are mapped
        # back to original-frame coordinates below.
        orig_height, orig_width = frame.shape[:2]
        resized_width, resized_height = orig_width, orig_height
        if self.downsample_factor < 1.0:
            resized_width = max(1, int(orig_width * self.downsample_factor))
            resized_height = max(1, int(orig_height * self.downsample_factor))
            resized_frame = cv2.resize(frame, (resized_width, resized_height), interpolation=cv2.INTER_AREA)
        else:
            resized_frame = frame
        # Detection + tracking (inference only, so gradients are disabled).
        with torch.no_grad():
            results = self.model.track(
                resized_frame,
                persist=True,
                tracker=self.tracker_config,
                conf=self.conf_threshold,
                imgsz=self.img_size,
                device=self.device,
                classes=[2, 5, 7],  # COCO: 2=car, 5=bus, 7=truck (skips other objects)
                verbose=False
            )
        display_frame = frame.copy()
        frame_height, frame_width = display_frame.shape[:2]
        # Draw the two counting regions (green = region 1, blue = region 2 in BGR).
        cv2.polylines(display_frame, [np.array(self.poly1, np.int32)], isClosed=True, color=(0, 255, 0), thickness=2)
        cv2.polylines(display_frame, [np.array(self.poly2, np.int32)], isClosed=True, color=(255, 0, 0), thickness=2)
        # Scale factors mapping resized-frame coords back to the original frame.
        scale_x = orig_width / float(resized_width)
        scale_y = orig_height / float(resized_height)
        if results and len(results) > 0 and getattr(results[0], "boxes", None) is not None:
            try:
                boxes = results[0].boxes.xywh.cpu().numpy()
                ids_tensor = results[0].boxes.id
                confs = results[0].boxes.conf.cpu().numpy()
                if ids_tensor is None:
                    # Tracker produced no IDs for this frame; such detections
                    # are drawn below only if they survive filtering.
                    track_ids = [None] * len(boxes)
                else:
                    track_ids = ids_tensor.int().cpu().tolist()
                # Convert center-format boxes to clamped [x_min, y_min, x_max, y_max]
                # in original-frame coordinates.
                boxes_coords = []
                for x, y, w, h in boxes:
                    center_x = int(x * scale_x)
                    center_y = int(y * scale_y)
                    width = int(w * scale_x)
                    height = int(h * scale_y)
                    x_min = max(0, center_x - width // 2)
                    y_min = max(0, center_y - height // 2)
                    x_max = min(frame_width - 1, center_x + width // 2)
                    y_max = min(frame_height - 1, center_y + height // 2)
                    boxes_coords.append([x_min, y_min, x_max, y_max])
                # Collapse overlapping duplicates (e.g. several boxes on one truck).
                filtered_boxes, filtered_ids, filtered_confs = self.filter_overlapping_detections(
                    boxes_coords, track_ids, confs
                )
                # Process the surviving detections.
                for box, track_id, conf in zip(filtered_boxes, filtered_ids, filtered_confs):
                    if track_id is None:
                        continue
                    x_min, y_min, x_max, y_max = box
                    center_x = (x_min + x_max) // 2
                    center_y = (y_min + y_max) // 2
                    center_point = (center_x, center_y)
                    # Update the per-track rolling history.
                    self.update_detection_history(track_id, box, frame_num)
                    # Only count tracks that have persisted long enough.
                    if self.is_stable_detection(track_id):
                        if self.is_in_region(center_point, self.poly1):
                            self.unique_region1_ids.add(track_id)
                        if self.is_in_region(center_point, self.poly2):
                            self.unique_region2_ids.add(track_id)
                    # Box color: green when stable, yellow otherwise.
                    color = (0, 255, 0) if self.is_stable_detection(track_id) else (0, 255, 255)
                    cv2.rectangle(display_frame, (x_min, y_min), (x_max, y_max), color, 2)
                    # Label with track ID and confidence.
                    label = f"ID:{track_id} {conf:.2f}"
                    cv2.putText(display_frame, label, (x_min, y_min - 10),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
            except Exception as e:
                # Keep rendering the frame even if tracking/parsing fails.
                draw_text_with_background(display_frame, f"Tracking error: {e}", (10, 60), bg_color=(80, 0, 0))
        # Counting overlay
        # draw_text_with_background(display_frame, f"Total Sens 1: {len(self.unique_region1_ids)}", (10, frame_height - 50))
        draw_text_with_background(display_frame, f"Total: {len(self.unique_region2_ids)}", (frame_width - 300, frame_height - 50))
        return display_frame
    def process_webcam(self, camera_id=0, display_placeholder=None, count_placeholders=None):
        """Process a live webcam stream until ``stop_processing`` is set,
        pushing annotated frames and counts into the given Streamlit
        placeholders.

        NOTE(review): this method is launched from a daemon thread in
        ``main``; Streamlit calls made off the main thread may require a
        script-run context on newer Streamlit versions — confirm against the
        deployed Streamlit version.
        """
        cap = cv2.VideoCapture(camera_id)
        if not cap.isOpened():
            st.error("⚠️ Erreur : Impossible d'ouvrir la webcam.")
            return
        try:
            # Best-effort capture settings; some backends ignore these.
            cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
            cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
            cap.set(cv2.CAP_PROP_FPS, 30)
        except Exception:
            pass
        self.reset_counts()
        self.stop_processing = False
        frame_count = 0
        last_ts = time.time()
        while not self.stop_processing:
            success, frame = cap.read()
            if not success:
                st.error("⚠️ Erreur lors de la lecture du flux vidéo.")
                break
            # Run YOLO only every `frame_skip`-th frame.
            if frame_count % self.frame_skip == 0:
                processed_frame = self.process_frame(frame, frame_count)
                self.last_processed_frame = processed_frame
                # Instantaneous FPS across processed frames.
                now = time.time()
                dt = max(1e-6, now - last_ts)
                fps = 1.0 / dt
                last_ts = now
                if processed_frame is not None:
                    draw_text_with_background(processed_frame, f"FPS: {fps:.1f}", (10, 30))
            else:
                # Reuse the last annotated frame on skipped iterations.
                processed_frame = self.last_processed_frame if self.last_processed_frame is not None else frame
            if processed_frame is not None:
                try:
                    # OpenCV frames are BGR; Streamlit/PIL expect RGB.
                    processed_frame_rgb = cv2.cvtColor(processed_frame, cv2.COLOR_BGR2RGB)
                except Exception:
                    processed_frame_rgb = processed_frame
                img = Image.fromarray(processed_frame_rgb)
                if display_placeholder:
                    display_placeholder.image(img, channels="RGB", use_column_width=True)
                if count_placeholders and len(count_placeholders) >= 2:
                    count_placeholders[0].metric("Véhicules Sens 1 (Vert)", len(self.unique_region1_ids))
                    count_placeholders[1].metric("Véhicules Sens 2 (Rouge)", len(self.unique_region2_ids))
            frame_count += 1
            # Small yield so the UI stays responsive.
            time.sleep(0.01)
        cap.release()
        st.success("✅ Flux vidéo arrêté.")
# --- STREAMLIT UI ---
def main():
    """Streamlit entry point: page setup, sidebar parameters, and two tabs
    (recorded-video analysis and live webcam detection)."""
    st.set_page_config(
        page_title="Détecteur de Véhicules",
        page_icon="🚗",
        layout="wide"
    )
    st.title("🚗 Détection et comptage de Véhicules sur l'Autoroute de l'Avenir")
    # Session state defaults
    st.session_state.setdefault("webcam_active", False)
    st.session_state.setdefault("processor", None)
    # Model: use a local checkpoint when present, otherwise fetch from the Hub,
    # falling back to the public yolov8n.pt on failure.
    model_path = "best.pt"
    if not os.path.exists(model_path):
        with st.spinner("📥 Chargement du modèle YOLO..."):
            try:
                from huggingface_hub import hf_hub_download
                model_path = hf_hub_download(repo_id="ModuMLTECH/Trafic_congestion", filename="best.pt")
                st.success("✅ Modèle chargé depuis Hugging Face Hub.")
            except Exception as e:
                st.error(f"❌ Erreur lors du chargement du modèle: {e}")
                st.warning("⚠️ Utilisation du modèle YOLO public à la place (yolov8n.pt).")
                model_path = "yolov8n.pt"
    # Tabs
    tab1, tab2 = st.tabs(["📹 Analyse de Vidéo", "🎥 Détection en Temps Réel"])
    # Sidebar: polygons, tracker choice, and tuning sliders.
    with st.sidebar:
        st.header("🔹 Paramètres")
        st.subheader("📍 Polygone 1 (vert)")
        poly1_input = st.text_area("Entrez 4 points (x,y) séparés par des espaces", "0,0 0,0 0,0 0,0")
        st.subheader("📍 Polygone 2 (rouge)")
        poly2_input = st.text_area("Entrez 4 points (x,y) séparés par des espaces", "500,150 700,150 1100,530 630,530")
        tracker_method = st.selectbox("Méthode de tracking", ["bot", "byte"], index=0)
        st.subheader("🚀 Paramètres d'optimisation")
        frame_skip = st.slider("Skip de frames", 1, 5, 2)
        downsample = st.slider("Facteur d'échelle", 0.3, 1.0, 0.5, 0.1)
        conf_threshold = st.slider("Seuil de confiance", 0.1, 0.9, 0.35, 0.05)
        st.subheader("🔧 Anti-duplicata")
        iou_thresh = st.slider("Seuil IoU (fusion détections)", 0.1, 0.9, 0.3, 0.05)
        min_area = st.slider("Surface minimale (pixels²)", 100, 2000, 500, 100)
    def parse_polygon(input_text):
        # Parse "x1,y1 x2,y2 ..." (';' also accepted as separator) into a list
        # of (x, y) int tuples; returns [] on any malformed input.
        try:
            pts = []
            for token in input_text.replace(";", " ").split():
                x, y = token.split(",")
                pts.append((int(x), int(y)))
            return pts
        except Exception:
            return []
    poly1 = parse_polygon(poly1_input)
    poly2 = parse_polygon(poly2_input)
    # Both regions must be quadrilaterals.
    valid_polygons = len(poly1) == 4 and len(poly2) == 4
    # Tab 1: recorded-video analysis
    with tab1:
        uploaded_file = st.file_uploader("📂 Upload une vidéo", type=["mp4", "avi", "mkv", "mov"])
        if uploaded_file is not None:
            temp_dir = tempfile.mkdtemp()
            ext = os.path.splitext(uploaded_file.name)[1].lower() or ".mp4"
            input_video_path = os.path.join(temp_dir, f"input_video{ext}")
            output_video_path = os.path.join(temp_dir, f"output_video{ext}")
            # Persist the upload to disk so OpenCV can open it.
            with open(input_video_path, "wb") as f:
                f.write(uploaded_file.getbuffer())
            st.video(input_video_path)
            if st.button("▶️ Lancer la détection"):
                if valid_polygons:
                    progress_bar = st.progress(0)
                    processor = YOLOVideoProcessor(model_path, poly1, poly2, tracker_method)
                    # Apply the sidebar tuning knobs.
                    processor.frame_skip = frame_skip
                    processor.downsample_factor = downsample
                    processor.conf_threshold = conf_threshold
                    processor.iou_threshold = iou_thresh
                    processor.min_box_area = min_area
                    start_time = time.time()
                    counts = processor.process_video(input_video_path, output_video_path, progress_bar=progress_bar)
                    end_time = time.time()
                    if counts:
                        count1, count2 = counts
                        st.success(f"✅ Traitement terminé en {end_time - start_time:.2f} s")
                        col_result1, col_result2 = st.columns(2)
                        col_result1.metric("Véhicules Sens 1 (Vert)", count1)
                        col_result2.metric("Véhicules Sens 2 (Rouge)", count2)
                        st.subheader("Vidéo traitée")
                        st.video(output_video_path)
                        with open(output_video_path, "rb") as file:
                            # NOTE(review): f"video/{ext.strip('.')}" yields e.g.
                            # "video/mkv", which is not a registered MIME type —
                            # verify browsers handle it for all accepted extensions.
                            st.download_button(
                                label="⬇️ Télécharger la vidéo",
                                data=file,
                                file_name=f"video_traitee{ext}",
                                mime=f"video/{ext.strip('.')}",
                            )
                else:
                    st.error("❌ Les coordonnées des polygones doivent contenir **exactement 4 points**.")
    # Tab 2: live webcam
    with tab2:
        st.header("Détection en Temps Réel avec Webcam")
        camera_options = {"Webcam par défaut": 0}
        # Probe a few extra camera indices to offer as sources.
        for i in range(1, 5):
            try:
                cap = cv2.VideoCapture(i)
                if cap.isOpened():
                    camera_options[f"Caméra {i}"] = i
                    cap.release()
            except Exception:
                pass
        selected_camera = st.selectbox("Sélectionnez la source vidéo", list(camera_options.keys()))
        camera_id = camera_options[selected_camera]
        video_placeholder = st.empty()
        col1, col2 = st.columns(2)
        count_placeholders = [col1.empty(), col2.empty()]
        st.info("ℹ️ Optimisations: redimensionnement, skip de frames, filtrage anti-duplicata, CUDA si disponible.")
        col_start, col_stop = st.columns(2)
        if col_start.button("▶️ Démarrer la détection en direct"):
            if not valid_polygons:
                st.error("❌ Les coordonnées des polygones doivent contenir **exactement 4 points**.")
            elif st.session_state.webcam_active:
                st.warning("⚠️ La webcam est déjà active !")
            else:
                processor = YOLOVideoProcessor(model_path, poly1, poly2, tracker_method)
                processor.frame_skip = frame_skip
                processor.downsample_factor = downsample
                processor.conf_threshold = conf_threshold
                processor.iou_threshold = iou_thresh
                processor.min_box_area = min_area
                st.session_state.processor = processor
                st.session_state.webcam_active = True
                # NOTE(review): Streamlit calls from this daemon thread may
                # require add_script_run_ctx on newer Streamlit versions — confirm.
                threading.Thread(
                    target=st.session_state.processor.process_webcam,
                    args=(camera_id, video_placeholder, count_placeholders),
                    daemon=True,
                ).start()
        if col_stop.button("⏹️ Arrêter la détection"):
            if st.session_state.webcam_active and st.session_state.processor:
                # Signal the worker loop to exit, then clear the display.
                st.session_state.processor.stop_processing = True
                st.session_state.webcam_active = False
                time.sleep(0.5)
                video_placeholder.empty()
            else:
                st.warning("⚠️ Aucune détection en cours !")
# Standard script entry point.
if __name__ == "__main__":
    main()