ModuMLTECH's picture
Update app.py
468dc21 verified
import streamlit as st
import cv2
import time
import numpy as np
from ultralytics import YOLO
from collections import defaultdict
import tempfile
import os
from pathlib import Path
import sys
import torch
# Fonction pour charger le modèle
def load_model(model_path):
try:
return YOLO(model_path, task="detect")
except Exception as e:
st.error(f"Erreur lors du chargement du modèle: {str(e)}")
return None
st.set_page_config(page_title="Détection et Suivi de densité de Trafic sur l'Autoroute de l'Avenir", page_icon="🚗", layout="wide")
# Couleurs pour les IDs de suivi
def get_color_for_id(track_id):
np.random.seed(track_id)
return tuple(np.random.randint(0, 255, size=3).tolist())
# Fonction pour afficher du texte avec fond
def draw_text_with_background(image, text, position, font=cv2.FONT_HERSHEY_SIMPLEX,
font_scale=1, font_thickness=2, text_color=(255, 255, 255), bg_color=(0, 0, 0), padding=5):
text_size = cv2.getTextSize(text, font, font_scale, font_thickness)[0]
text_width, text_height = text_size
x, y = position
top_left = (x, y - text_height - padding)
bottom_right = (x + text_width + padding * 2, y + padding)
cv2.rectangle(image, top_left, bottom_right, bg_color, -1)
cv2.putText(image, text, (x + padding, y), font, font_scale, text_color, font_thickness, cv2.LINE_AA)
class VideoProcessor:
def __init__(self, model_path, video_path, output_path, poly1, poly2,
device=None, conf_threshold=0.25, traffic_threshold=2):
# Détection automatique du périphérique optimal
self.device = device if device else ('cuda' if torch.cuda.is_available() else 'cpu')
st.sidebar.info(f"Utilisation du périphérique: {self.device}")
# Charger le modèle
self.model = load_model(model_path)
if self.model is None:
st.error("Échec du chargement du modèle.")
st.stop()
self.conf_threshold = conf_threshold
self.track_history = defaultdict(list)
self.video_path = video_path
self.output_path = output_path
self.max_track_history = 15 # Limite de l'historique pour réduire la mémoire
self.traffic_threshold = traffic_threshold
# Utiliser les polygones fournis
self.poly1_np = np.array(poly1, dtype=np.int32)
self.poly2_np = np.array(poly2, dtype=np.int32)
def is_in_region(self, center, poly_np):
return cv2.pointPolygonTest(poly_np, center, False) >= 0
def process_video(self, stframe, progress_bar):
cap = cv2.VideoCapture(self.video_path)
if not cap.isOpened():
st.error(f"Impossible d'ouvrir la vidéo à {self.video_path}")
return None
# Métadonnées de la vidéo
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
input_fps = cap.get(cv2.CAP_PROP_FPS)
# Configuration de l'écriture vidéo
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(self.output_path, fourcc, input_fps, (frame_width, frame_height))
# Intervalle d'affichage pour Streamlit
display_interval = max(1, int(input_fps / 10))
# Variables pour le traitement
prev_frame_time = time.time()
frame_count = 0
process_count = 0
# Configuration du tracker
tracker = "botsort.yaml"
# Préparer les variables de texte
static_texts = {
'fps': 'FPS: {}',
'total': 'Total Vehicules: {}',
'region1': 'Trafic Sens 1: {}',
'region1_status': 'Situation Trafic: {}',
'region2': 'Trafic Sens 2: {}',
'region2_status': 'Situation Trafic: {}'
}
# Préparation des polygones
region1_poly = self.poly1_np.reshape((-1, 1, 2))
region2_poly = self.poly2_np.reshape((-1, 1, 2))
# Boucle principale
while cap.isOpened():
success, frame = cap.read()
if not success:
break
frame_count += 1
progress_bar.progress(min(frame_count / total_frames, 1.0))
process_count += 1
# Effectuer la détection
try:
results = self.model.track(
source=frame,
persist=True,
tracker=tracker,
classes=[0], # Classe pour les véhicules
conf=self.conf_threshold,
verbose=False
)
except Exception as e:
st.error(f"Erreur lors du traitement de l'image: {str(e)}")
continue
# Initialiser les tableaux
boxes = []
track_ids = []
confidences = []
if results and len(results) > 0 and hasattr(results[0], 'boxes') and len(results[0].boxes) > 0:
try:
boxes = results[0].boxes.xywh.cpu().numpy()
confidences = results[0].boxes.conf.cpu().numpy()
if hasattr(results[0].boxes, 'id'):
track_ids = results[0].boxes.id.int().cpu().tolist()
else:
track_ids = list(range(len(boxes)))
except Exception as e:
st.error(f"Erreur lors de l'extraction des résultats: {str(e)}")
# Créer une copie pour l'annotation
annotated_frame = frame.copy()
# Dessiner les régions
cv2.polylines(annotated_frame, [region1_poly], isClosed=True, color=(255, 0, 0), thickness=2)
cv2.polylines(annotated_frame, [region2_poly], isClosed=True, color=(0, 255, 0), thickness=2)
# Initialiser les ensembles pour les régions
current_region1_ids = set()
current_region2_ids = set()
# Traiter les détections
if boxes.size > 0 and len(track_ids) > 0:
for box, track_id, confidence in zip(boxes, track_ids, confidences):
# Extraire les coordonnées
x, y, w, h = box
# Obtenir la couleur pour cet ID
color = get_color_for_id(track_id)
# Dessiner le rectangle
cv2.rectangle(
annotated_frame,
(int(x - w / 2), int(y - h / 2)),
(int(x + w / 2), int(y + h / 2)),
color,
2
)
# Ajouter du texte
draw_text_with_background(
annotated_frame,
f'ID: {track_id} ({confidence:.2f})',
(int(x - w / 2), int(y - h / 2) - 10),
bg_color=color
)
# Point central et historique de suivi
center_point = (int(x), int(y))
cv2.circle(annotated_frame, center_point, radius=4, color=color, thickness=-1)
# Mettre à jour l'historique
self.track_history[track_id].append((float(x), float(y)))
if len(self.track_history[track_id]) > self.max_track_history:
self.track_history[track_id].pop(0)
# Dessiner le chemin
if len(self.track_history[track_id]) > 1:
points = np.array(self.track_history[track_id]).astype(np.int32).reshape((-1, 1, 2))
cv2.polylines(annotated_frame, [points], isClosed=False, color=color, thickness=2)
# Vérifier si dans une région
if self.is_in_region(center_point, self.poly1_np):
current_region1_ids.add(track_id)
if self.is_in_region(center_point, self.poly2_np):
current_region2_ids.add(track_id)
# Calculer les statistiques
all_cars = len(track_ids)
region1_cars = len(current_region1_ids)
region2_cars = len(current_region2_ids)
# Statut du trafic
region1_status = "Dense" if region1_cars > self.traffic_threshold else "Fluide"
region2_status = "Dense" if region2_cars > self.traffic_threshold else "Fluide"
# Calculer le FPS
new_frame_time = time.time()
fps = 1 / (new_frame_time - prev_frame_time) if prev_frame_time > 0 else 30
prev_frame_time = new_frame_time
# Afficher les informations
draw_text_with_background(annotated_frame, static_texts['fps'].format(int(fps)), (7, 30), bg_color=(0, 0, 0), text_color=(0, 255, 0))
draw_text_with_background(annotated_frame, static_texts['total'].format(all_cars), (7, 65), bg_color=(0, 0, 0), text_color=(0, 255, 0))
draw_text_with_background(annotated_frame, static_texts['region1'].format(region1_cars), (7, 105), bg_color=(255, 0, 0), text_color=(255, 255, 255))
region1_text_color = (0, 0, 255) if region1_status == "Dense" else (255, 255, 255)
draw_text_with_background(annotated_frame, static_texts['region1_status'].format(region1_status), (7, 139), bg_color=(255, 0, 0), text_color=region1_text_color)
draw_text_with_background(annotated_frame, static_texts['region2'].format(region2_cars), (880, 105), bg_color=(0, 255, 0), text_color=(255, 255, 255))
region2_text_color = (0, 0, 255) if region2_status == "Dense" else (255, 255, 255)
draw_text_with_background(annotated_frame, static_texts['region2_status'].format(region2_status), (880, 139), bg_color=(0, 255, 0), text_color=region2_text_color)
# Écrire la frame
out.write(annotated_frame)
# Mettre à jour l'affichage Streamlit
if process_count % display_interval == 0:
stframe.image(annotated_frame, channels="BGR", use_container_width=True)
# Libérer les ressources
cap.release()
out.release()
cv2.destroyAllWindows()
# Nettoyer la mémoire
self.track_history.clear()
return self.output_path
def main():
st.title("🚗 Détection et Suivi de densité de Trafic sur l'Autoroute de l'Avenir")
st.write("Cette application détecte les véhicules, et analyse la densité du trafic dans les deux sens d'un tronçon de l'Autoroute de l'Avenir.")
# Sidebar pour les options
st.sidebar.header("Configuration")
# Détection du matériel
if torch.cuda.is_available():
device_options = ["cuda", "cpu"]
device_default = "cuda"
st.sidebar.success("🚀 GPU détecté! Traitement accéléré disponible.")
else:
device_options = ["cpu"]
device_default = "cpu"
st.sidebar.warning("⚠️ Pas de GPU détecté. Performances limitées au CPU.")
device = st.sidebar.selectbox("Périphérique de calcul", device_options, index=device_options.index(device_default))
# Paramètres de détection
confidence = st.sidebar.slider("Seuil de confiance", 0.1, 1.0, 0.25, 0.05,
help="Un seuil plus élevé détecte moins d'objets mais avec plus de précision")
# Configuration des régions d'intérêt (ROI)
st.sidebar.subheader("Régions d'intérêt")
# Valeurs par défaut pour les polygones
# default_poly1 = [(465, 350), (609, 350), (520, 630), (3, 630)]
# default_poly2 = [(678, 350), (815, 350), (1203, 630), (743, 630)]
default_poly1 = [(900, 350), (1150, 350), (650, 630), (200, 630)]
default_poly2 = [(1200, 350), (1400, 350), (1150, 630), (743, 630)]
# Configuration des polygones
st.sidebar.markdown("#### Région 1 (Sens 1)")
poly1 = []
for i in range(4):
col1, col2 = st.sidebar.columns(2)
with col1:
x = st.number_input(f"Point {i+1} X", value=default_poly1[i][0], min_value=0, step=1, key=f"p1_x_{i}")
with col2:
y = st.number_input(f"Point {i+1} Y", value=default_poly1[i][1], min_value=0, step=1, key=f"p1_y_{i}")
poly1.append((x, y))
st.sidebar.markdown("#### Région 2 (Sens 2)")
poly2 = []
for i in range(4):
col1, col2 = st.sidebar.columns(2)
with col1:
x = st.number_input(f"Point {i+1} X", value=default_poly2[i][0], min_value=0, step=1, key=f"p2_x_{i}")
with col2:
y = st.number_input(f"Point {i+1} Y", value=default_poly2[i][1], min_value=0, step=1, key=f"p2_y_{i}")
poly2.append((x, y))
# Seuil pour trafic dense
traffic_threshold = st.sidebar.slider("Seuil de trafic dense", 1, 10, 2,
help="Nombre de véhicules à partir duquel le trafic est considéré comme dense")
# Upload du modèle
st.sidebar.subheader("Modèle de détection")
model_file = st.sidebar.file_uploader("Uploader votre modèle YOLO (best.pt)", type=["pt"])
# Upload de la vidéo
st.subheader("Vidéo à analyser")
video_file = st.file_uploader("Uploader une vidéo", type=["mp4", "avi", "mov", "mkv"])
# Afficher un aperçu des ROI sur la vidéo
if video_file:
try:
# Lire la première frame pour l'aperçu
with tempfile.NamedTemporaryFile(delete=False, suffix=Path(video_file.name).suffix) as tmp_video:
tmp_video.write(video_file.getvalue())
video_path = tmp_video.name
cap = cv2.VideoCapture(video_path)
ret, frame = cap.read()
if ret:
# Dessiner les polygones
preview = frame.copy()
cv2.polylines(preview, [np.array(poly1)], isClosed=True, color=(255, 0, 0), thickness=2)
cv2.polylines(preview, [np.array(poly2)], isClosed=True, color=(0, 255, 0), thickness=2)
# Ajouter des labels
cv2.putText(preview, "Zone 1", (poly1[0][0], poly1[0][1] - 10),
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)
cv2.putText(preview, "Zone 2", (poly2[0][0], poly2[0][1] - 10),
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
# Afficher l'aperçu
st.subheader("Aperçu des zones de détection")
st.image(preview, channels="BGR", caption="Aperçu des régions d'intérêt configurées")
cap.release()
os.unlink(video_path)
else:
st.warning("Impossible de générer un aperçu des zones de détection.")
except Exception as e:
st.warning(f"Erreur lors de la génération de l'aperçu: {str(e)}")
if video_file and model_file:
# Sauvegarde temporaire des fichiers
with tempfile.NamedTemporaryFile(delete=False, suffix='.pt') as tmp_model:
tmp_model.write(model_file.getvalue())
model_path = tmp_model.name
with tempfile.NamedTemporaryFile(delete=False, suffix=Path(video_file.name).suffix) as tmp_video:
tmp_video.write(video_file.getvalue())
video_path = tmp_video.name
output_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
st.success("Fichiers chargés avec succès. Prêt à traiter la vidéo.")
# Création des colonnes pour l'affichage
col1, col2 = st.columns(2)
with col1:
st.write("Vidéo originale:")
st_video_placeholder = st.empty()
st_video_placeholder.video(video_file)
with col2:
st.write("Vidéo traitée:")
stframe = st.empty()
progress_bar = st.progress(0)
# Bouton pour lancer le traitement
if st.button("Traiter la vidéo"):
start_time = time.time()
try:
with st.spinner("Traitement en cours..."):
# Traitement de la vidéo
video_processor = VideoProcessor(
model_path=model_path,
video_path=video_path,
output_path=output_path,
device=device,
conf_threshold=confidence,
poly1=poly1,
poly2=poly2,
traffic_threshold=traffic_threshold
)
output_video = video_processor.process_video(stframe, progress_bar)
if output_video:
# Temps de traitement
total_time = time.time() - start_time
st.success(f"✅ Traitement terminé en {total_time:.2f} secondes!")
# Téléchargement de la vidéo
try:
with open(output_video, "rb") as file:
video_bytes = file.read()
st.download_button(
label="Télécharger la vidéo traitée",
data=video_bytes,
file_name="video_traitee.mp4",
mime="video/mp4"
)
# Afficher la vidéo résultante
st.video(video_bytes)
except Exception as e:
st.error(f"Erreur lors de la préparation du téléchargement: {str(e)}")
else:
st.error("Échec du traitement de la vidéo.")
except Exception as e:
st.error(f"Une erreur s'est produite: {str(e)}")
st.error(f"Détails: {sys.exc_info()}")
# Nettoyer les fichiers temporaires
try:
if model_path and os.path.exists(model_path):
os.unlink(model_path)
if video_path and os.path.exists(video_path):
os.unlink(video_path)
if output_path and os.path.exists(output_path):
os.unlink(output_path)
except Exception as e:
st.warning(f"Problème lors du nettoyage des fichiers temporaires: {str(e)}")
else:
st.info("Veuillez uploader un modèle YOLO (best.pt) et une vidéo pour commencer.")
st.sidebar.markdown("---")
st.sidebar.markdown("### Comment utiliser cette application")
st.sidebar.markdown("""
1. Choisissez un périphérique de calcul (GPU recommandé)
2. Configurez les régions d'intérêt si nécessaire
3. Uploadez votre modèle YOLO (best.pt)
4. Uploadez une vidéo à analyser
5. Cliquez sur "Traiter la vidéo"
6. Visualisez les résultats et téléchargez la vidéo traitée
""")
if __name__ == "__main__":
main()