Dougaya commited on
Commit
4d18315
·
verified ·
1 Parent(s): 4b1a7b1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +100 -0
app.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import gradio as gr
3
+ import numpy as np
4
+ import tempfile
5
+ import os
6
+ from ultralytics import YOLO
7
+ from deep_sort_realtime.deepsort_tracker import DeepSort
8
+ from collections import defaultdict
9
+
10
+ # Dictionnaire pour compter les objets détectés
11
+ class_counts = defaultdict(set)
12
+
13
+ # Charger modèle YOLOv8
14
+ model = YOLO("best.pt") # Assure-toi que ce fichier est bien dans le même dossier
15
+
16
+ # Initialiser DeepSORT
17
+ tracker = DeepSort(max_age=30)
18
+
19
+ # 📸 Détection image
20
+ def detect_on_image(image):
21
+ results = model(image)[0]
22
+ for box in results.boxes:
23
+ cls_id = int(box.cls[0])
24
+ conf = float(box.conf[0])
25
+ x1, y1, x2, y2 = map(int, box.xyxy[0])
26
+ if conf > 0.4:
27
+ label = f"{model.names[cls_id]} {conf:.2f}"
28
+ cv2.rectangle(image, (x1, y1), (x2, y2), (255, 0, 0), 2)
29
+ cv2.putText(image, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 0, 0), 2)
30
+ return image
31
+
32
+ # 🎥 Détection vidéo
33
+ def detect_and_track_video(video_path):
34
+ if not os.path.exists(video_path):
35
+ return None
36
+
37
+ cap = cv2.VideoCapture(video_path)
38
+ width = int(cap.get(3))
39
+ height = int(cap.get(4))
40
+ fps = cap.get(cv2.CAP_PROP_FPS)
41
+ temp_output = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
42
+ out = cv2.VideoWriter(temp_output.name, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
43
+ class_counts.clear()
44
+
45
+ while cap.isOpened():
46
+ ret, frame = cap.read()
47
+ if not ret:
48
+ break
49
+
50
+ results = model(frame)[0]
51
+ detections = []
52
+
53
+ for box in results.boxes:
54
+ cls_id = int(box.cls[0])
55
+ conf = float(box.conf[0])
56
+ x1, y1, x2, y2 = map(int, box.xyxy[0])
57
+ if conf > 0.4:
58
+ detections.append(([x1, y1, x2 - x1, y2 - y1], conf, model.names[cls_id]))
59
+
60
+ tracks = tracker.update_tracks(detections, frame=frame)
61
+
62
+ for track in tracks:
63
+ if not track.is_confirmed():
64
+ continue
65
+ track_id = track.track_id
66
+ l, t, r, b = map(int, track.to_ltrb())
67
+ label = track.get_det_class()
68
+ cv2.rectangle(frame, (l, t), (r, b), (0, 255, 0), 2)
69
+ cv2.putText(frame, f'{label} ID {track_id}', (l, t - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
70
+ class_counts[label].add(track_id)
71
+
72
+ out.write(frame)
73
+
74
+ cap.release()
75
+ out.release()
76
+
77
+ return temp_output.name
78
+
79
+ # Interfaces Gradio
80
+ image_interface = gr.Interface(
81
+ fn=detect_on_image,
82
+ inputs=gr.Image(type="numpy", label="Image de surveillance"),
83
+ outputs=gr.Image(type="numpy", label="Image annotée"),
84
+ title="📸 Détection sur Image",
85
+ description="Détection de bagages et objets avec YOLOv8."
86
+ )
87
+
88
+ video_interface = gr.Interface(
89
+ fn=detect_and_track_video,
90
+ inputs=gr.Video(label="Vidéo de surveillance"),
91
+ outputs=gr.Video(label="Vidéo annotée avec suivi"),
92
+ title="🎥 Suivi sur Vidéo",
93
+ description="Suivi multi-objets avec DeepSORT + YOLOv8."
94
+ )
95
+
96
+ # Interface finale
97
+ gr.TabbedInterface(
98
+ [image_interface, video_interface],
99
+ tab_names=["📷 Image", "🎥 Vidéo"]
100
+ ).launch()