usiddiquee786 commited on
Commit
aebe355
·
verified ·
1 Parent(s): 77103ef

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +278 -103
app.py CHANGED
@@ -6,6 +6,9 @@ import shutil
6
  from pathlib import Path
7
  import sys
8
  import importlib.util
 
 
 
9
 
10
  # Ensure models directory exists
11
  MODELS_DIR = Path("models")
@@ -19,12 +22,14 @@ def ensure_dependencies():
19
  """Ensure all required dependencies are installed."""
20
  required_packages = [
21
  "ultralytics",
22
- "boxmot"
 
 
23
  ]
24
 
25
  for package in required_packages:
26
  try:
27
- importlib.import_module(package)
28
  print(f"✅ {package} is installed")
29
  except ImportError:
30
  print(f"⚠️ {package} is not installed, attempting to install...")
@@ -46,8 +51,219 @@ def apply_patches():
46
  else:
47
  print("⚠️ tracker_patch.py not found, skipping patches")
48
 
49
- def run_tracking(video_file, yolo_model, reid_model, tracking_method, class_ids, conf_threshold):
50
- """Run object tracking on the uploaded video."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  try:
52
  # Create temporary workspace
53
  with tempfile.TemporaryDirectory() as temp_dir:
@@ -55,107 +271,45 @@ def run_tracking(video_file, yolo_model, reid_model, tracking_method, class_ids,
55
  input_path = os.path.join(temp_dir, "input_video.mp4")
56
  shutil.copy(video_file, input_path)
57
 
58
- # Prepare output directory
59
- output_dir = os.path.join(temp_dir, "output")
60
- os.makedirs(output_dir, exist_ok=True)
61
-
62
- # Build command
63
- cmd = [
64
- "python", "tracking/track.py",
65
- "--yolo-model", str(MODELS_DIR / yolo_model),
66
- "--reid-model", str(MODELS_DIR / reid_model),
67
- "--tracking-method", tracking_method,
68
- "--source", input_path,
69
- "--conf", str(conf_threshold),
70
- "--save",
71
- "--project", output_dir,
72
- "--name", "track",
73
- "--exist-ok"
74
- ]
75
 
76
- # Add class filtering if specific classes are provided
77
- if class_ids and class_ids.strip():
78
- # Parse the comma-separated class IDs
79
- try:
80
- # Split by comma and convert to integers to validate
81
- class_list = [int(c.strip()) for c in class_ids.split(",") if c.strip()]
82
- # Add each class ID as a separate argument
83
- if class_list:
84
- cmd.append("--classes")
85
- cmd.extend(str(c) for c in class_list)
86
- except ValueError:
87
- return None, "Invalid class IDs. Please enter comma-separated numbers (e.g., '0,1,2')."
88
 
89
- # Special handling for OcSort
90
- if tracking_method == "ocsort":
91
- cmd.append("--per-class")
92
 
93
- # Execute tracking with error handling
94
- print(f"Executing command: {' '.join(cmd)}")
95
- process = subprocess.run(
96
- cmd,
97
- capture_output=True,
98
- text=True
 
 
 
 
 
99
  )
100
 
101
- # Check for errors in output
102
- if process.returncode != 0:
103
- error_message = process.stderr or process.stdout
104
- print(f"Process failed with return code {process.returncode}")
105
- print(f"Error: {error_message}")
106
- return None, f"Error in tracking process: {error_message}"
107
-
108
- print(f"Process completed with return code {process.returncode}")
109
-
110
- # Find output video
111
- output_files = []
112
- for root, _, files in os.walk(output_dir):
113
- for file in files:
114
- if file.lower().endswith((".mp4", ".avi", ".mov")):
115
- output_files.append(os.path.join(root, file))
116
-
117
- print(f"Found output files: {output_files}")
118
-
119
- if not output_files:
120
- print("No output video files found")
121
- return None, "No output video was generated. Check if tracking was successful."
122
-
123
- output_file = output_files[0]
124
- print(f"Selected output file: {output_file}")
125
-
126
- # Verify file exists and has size
127
- if os.path.exists(output_file):
128
- file_size = os.path.getsize(output_file)
129
- print(f"Output file exists with size: {file_size} bytes")
130
-
131
- if file_size == 0:
132
- return None, "Output video was generated but has zero size."
133
-
134
  # Copy to permanent location with unique name
135
- permanent_path = os.path.join(OUTPUT_DIR, f"output_{os.path.basename(video_file)}")
136
- shutil.copy(output_file, permanent_path)
137
  print(f"Copied output to permanent location: {permanent_path}")
138
 
139
- # Ensure the file is in MP4 format for better compatibility with Gradio
140
- if not permanent_path.lower().endswith('.mp4'):
141
- mp4_path = os.path.splitext(permanent_path)[0] + '.mp4'
142
- try:
143
- print(f"Converting to MP4 format: {mp4_path}")
144
- subprocess.run([
145
- 'ffmpeg', '-i', permanent_path,
146
- '-c:v', 'libx264', '-preset', 'fast',
147
- '-c:a', 'aac', mp4_path
148
- ], check=True, capture_output=True)
149
- os.remove(permanent_path) # Remove the original file
150
- permanent_path = mp4_path
151
- except Exception as e:
152
- print(f"Failed to convert to MP4: {str(e)}")
153
- # Continue with original file if conversion fails
154
 
155
- return permanent_path, "Processing completed successfully!"
156
  else:
157
- print(f"Output file not found at {output_file}")
158
- return None, "Output file was referenced but doesn't exist on disk."
159
 
160
  except Exception as e:
161
  import traceback
@@ -163,13 +317,15 @@ def run_tracking(video_file, yolo_model, reid_model, tracking_method, class_ids,
163
  return None, f"Error: {str(e)}"
164
 
165
  # Define the Gradio interface
166
- def process_video(video_path, yolo_model, reid_model, tracking_method, class_ids, conf_threshold):
 
167
  # Validate inputs
168
  if not video_path:
169
  return None, "Please upload a video file"
170
 
171
  print(f"Processing video: {video_path}")
172
  print(f"Parameters: model={yolo_model}, reid={reid_model}, tracker={tracking_method}, classes={class_ids}, conf={conf_threshold}")
 
173
 
174
  output_path, status = run_tracking(
175
  video_path,
@@ -177,7 +333,9 @@ def process_video(video_path, yolo_model, reid_model, tracking_method, class_ids
177
  reid_model,
178
  tracking_method,
179
  class_ids,
180
- conf_threshold
 
 
181
  )
182
 
183
  if output_path:
@@ -193,15 +351,16 @@ def process_video(video_path, yolo_model, reid_model, tracking_method, class_ids
193
  yolo_models = ["yolov8n.pt", "yolov8s.pt", "yolov8m.pt"]
194
  reid_models = ["osnet_x0_25_msmt17.pt"]
195
  tracking_methods = ["bytetrack", "botsort", "ocsort", "strongsort"]
 
196
 
197
  # Ensure dependencies and apply patches at startup
198
  ensure_dependencies()
199
  apply_patches()
200
 
201
  # Create the Gradio interface
202
- with gr.Blocks(title="YOLO Object Tracking") as app:
203
- gr.Markdown("# 🚀 YOLO Object Tracking")
204
- gr.Markdown("Upload a video file to detect and track objects. Processing may take a few minutes depending on video length.")
205
 
206
  # Add class reference information
207
  with gr.Accordion("YOLO Class Reference", open=False):
@@ -261,15 +420,31 @@ with gr.Blocks(title="YOLO Object Tracking") as app:
261
  label="Confidence Threshold"
262
  )
263
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264
  process_btn = gr.Button("Process Video", variant="primary")
265
 
266
  with gr.Column(scale=1):
267
- output_video = gr.Video(label="Output Video with Tracking")
268
- status_text = gr.Textbox(label="Status", value="Ready to process video")
269
 
270
  process_btn.click(
271
  fn=process_video,
272
- inputs=[input_video, yolo_model, reid_model, tracking_method, class_ids, conf_threshold],
 
273
  outputs=[output_video, status_text]
274
  )
275
 
 
6
  from pathlib import Path
7
  import sys
8
  import importlib.util
9
+ import cv2
10
+ import numpy as np
11
+ from ultralytics.utils.plotting import Annotator, colors
12
 
13
  # Ensure models directory exists
14
  MODELS_DIR = Path("models")
 
22
  """Ensure all required dependencies are installed."""
23
  required_packages = [
24
  "ultralytics",
25
+ "boxmot",
26
+ "opencv-python",
27
+ "numpy"
28
  ]
29
 
30
  for package in required_packages:
31
  try:
32
+ importlib.import_module(package.replace('-', '_'))
33
  print(f"✅ {package} is installed")
34
  except ImportError:
35
  print(f"⚠️ {package} is not installed, attempting to install...")
 
51
  else:
52
  print("⚠️ tracker_patch.py not found, skipping patches")
53
 
54
+ class LineCounter:
55
+ """Count objects crossing a line"""
56
+ def __init__(self, line_position=0.5, line_orientation='horizontal'):
57
+ """
58
+ Initialize a line counter
59
+
60
+ Args:
61
+ line_position: float between 0 and 1, position of line (default: middle)
62
+ line_orientation: 'horizontal' or 'vertical'
63
+ """
64
+ self.line_position = line_position
65
+ self.line_orientation = line_orientation
66
+ self.counts = {} # Track counts by class
67
+ self.crossed_ids = set() # Track IDs that have crossed the line
68
+ self.prev_positions = {} # Store previous positions of tracked objects
69
+
70
+ def update(self, bboxes, identities, clss, frame_shape):
71
+ """
72
+ Update counter with new detections
73
+
74
+ Args:
75
+ bboxes: list of bounding boxes [x1, y1, x2, y2]
76
+ identities: list of track IDs
77
+ clss: list of class IDs
78
+ frame_shape: tuple of (height, width)
79
+
80
+ Returns:
81
+ count_info: dict of counts by class
82
+ """
83
+ if not len(bboxes):
84
+ return self.counts, []
85
+
86
+ height, width = frame_shape[:2]
87
+
88
+ # Calculate line position in pixels
89
+ if self.line_orientation == 'horizontal':
90
+ line_pos = int(height * self.line_position)
91
+ start_point = (0, line_pos)
92
+ end_point = (width, line_pos)
93
+ else: # vertical
94
+ line_pos = int(width * self.line_position)
95
+ start_point = (line_pos, 0)
96
+ end_point = (line_pos, height)
97
+
98
+ # Store line info for drawing
99
+ line_info = {
100
+ 'start': start_point,
101
+ 'end': end_point,
102
+ 'orientation': self.line_orientation
103
+ }
104
+
105
+ # Process each detection
106
+ for i, (bbox, track_id, cls) in enumerate(zip(bboxes, identities, clss)):
107
+ x1, y1, x2, y2 = bbox
108
+
109
+ # Use center point of bbox to determine position
110
+ if self.line_orientation == 'horizontal':
111
+ center_pos = (y1 + y2) / 2
112
+ crossed_now = False
113
+
114
+ # Check if crossed the line
115
+ if track_id in self.prev_positions:
116
+ prev_pos = self.prev_positions[track_id]
117
+ # Crossed from above to below
118
+ if prev_pos < line_pos and center_pos >= line_pos:
119
+ crossed_now = True
120
+ # Crossed from below to above
121
+ elif prev_pos >= line_pos and center_pos < line_pos:
122
+ crossed_now = True
123
+
124
+ # Store current position for next frame
125
+ self.prev_positions[track_id] = center_pos
126
+
127
+ else: # vertical line
128
+ center_pos = (x1 + x2) / 2
129
+ crossed_now = False
130
+
131
+ # Check if crossed the line
132
+ if track_id in self.prev_positions:
133
+ prev_pos = self.prev_positions[track_id]
134
+ # Crossed from left to right
135
+ if prev_pos < line_pos and center_pos >= line_pos:
136
+ crossed_now = True
137
+ # Crossed from right to left
138
+ elif prev_pos >= line_pos and center_pos < line_pos:
139
+ crossed_now = True
140
+
141
+ # Store current position for next frame
142
+ self.prev_positions[track_id] = center_pos
143
+
144
+ # Count if crossed and not counted before
145
+ if crossed_now and track_id not in self.crossed_ids:
146
+ self.crossed_ids.add(track_id)
147
+ cls_id = int(cls)
148
+ self.counts[cls_id] = self.counts.get(cls_id, 0) + 1
149
+
150
+ return self.counts, line_info
151
+
152
+ def reset(self):
153
+ """Reset all counts and tracking info"""
154
+ self.counts = {}
155
+ self.crossed_ids = set()
156
+ self.prev_positions = {}
157
+
158
+ def process_video_with_counter(input_path, output_path, model_path, reid_model, tracking_method,
159
+ selected_classes, conf_threshold, line_position, line_orientation):
160
+ """Process video with the line counter"""
161
+ # Import here to avoid import errors if dependencies are missing
162
+ from ultralytics import YOLO
163
+
164
+ # Load the model
165
+ model = YOLO(model_path)
166
+
167
+ # Prepare classes filter
168
+ classes = None
169
+ if selected_classes and selected_classes.strip():
170
+ try:
171
+ classes = [int(c.strip()) for c in selected_classes.split(",") if c.strip()]
172
+ except ValueError:
173
+ print("Invalid class IDs, using all classes")
174
+
175
+ # Initialize video capture and get video info
176
+ cap = cv2.VideoCapture(input_path)
177
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
178
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
179
+ fps = cap.get(cv2.CAP_PROP_FPS)
180
+
181
+ # Initialize video writer
182
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
183
+ out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
184
+
185
+ # Initialize line counter
186
+ counter = LineCounter(line_position=line_position, line_orientation=line_orientation)
187
+
188
+ # Track with YOLO
189
+ results = model.track(
190
+ source=input_path,
191
+ conf=conf_threshold,
192
+ classes=classes,
193
+ tracker=tracking_method,
194
+ save=False,
195
+ stream=True,
196
+ verbose=False
197
+ )
198
+
199
+ # Process each frame
200
+ for i, result in enumerate(results):
201
+ frame = result.orig_img
202
+
203
+ # Skip if no detections or tracking info
204
+ if result.boxes.id is None:
205
+ annotator = Annotator(frame)
206
+ # Draw the line
207
+ if line_orientation == 'horizontal':
208
+ line_y = int(height * line_position)
209
+ annotator.line((0, line_y), (width, line_y), color=(0, 255, 0), thickness=2)
210
+ else:
211
+ line_x = int(width * line_position)
212
+ annotator.line((line_x, 0), (line_x, height), color=(0, 255, 0), thickness=2)
213
+
214
+ # Add count text
215
+ count_text = "Count: 0"
216
+ annotator.text((20, 40), count_text, color=(0, 0, 255), thickness=2)
217
+
218
+ out.write(frame)
219
+ continue
220
+
221
+ # Get boxes, track IDs and classes
222
+ boxes = result.boxes.xyxy.cpu().numpy()
223
+ track_ids = result.boxes.id.cpu().numpy().astype(int)
224
+ classes = result.boxes.cls.cpu().numpy().astype(int)
225
+
226
+ # Update counter
227
+ counts, line_info = counter.update(boxes, track_ids, classes, frame.shape)
228
+
229
+ # Start drawing on frame
230
+ annotator = Annotator(frame)
231
+
232
+ # Draw tracking results
233
+ for box, track_id, cls in zip(boxes, track_ids, classes):
234
+ x1, y1, x2, y2 = map(int, box)
235
+ id = int(track_id)
236
+ color = colors(id % 10)
237
+ # Draw bbox
238
+ annotator.box_label([x1, y1, x2, y2], f'{id} {model.names[int(cls)]}', color=color)
239
+
240
+ # Draw the counting line
241
+ line_start, line_end = line_info['start'], line_info['end']
242
+ annotator.line(line_start, line_end, color=(0, 255, 0), thickness=2)
243
+
244
+ # Draw counts for each class
245
+ y_offset = 40
246
+ for cls_id, count in counts.items():
247
+ cls_name = model.names.get(cls_id, f"Class {cls_id}")
248
+ count_text = f"{cls_name}: {count}"
249
+ annotator.text((20, y_offset), count_text, color=(0, 0, 255), thickness=2)
250
+ y_offset += 30
251
+
252
+ # Write the frame
253
+ out.write(frame)
254
+
255
+ # Print progress
256
+ if i % 100 == 0:
257
+ print(f"Processed {i} frames")
258
+
259
+ # Release resources
260
+ cap.release()
261
+ out.release()
262
+ return counts
263
+
264
+ def run_tracking(video_file, yolo_model, reid_model, tracking_method, class_ids, conf_threshold,
265
+ line_position, line_orientation):
266
+ """Run object tracking with line counting on the uploaded video."""
267
  try:
268
  # Create temporary workspace
269
  with tempfile.TemporaryDirectory() as temp_dir:
 
271
  input_path = os.path.join(temp_dir, "input_video.mp4")
272
  shutil.copy(video_file, input_path)
273
 
274
+ # Prepare output file
275
+ output_path = os.path.join(temp_dir, "output_video.mp4")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
276
 
277
+ # Get full model path
278
+ model_path = str(MODELS_DIR / yolo_model)
 
 
 
 
 
 
 
 
 
 
279
 
280
+ print(f"Processing video with counter. Model: {model_path}, Line: {line_orientation} at {line_position}")
 
 
281
 
282
+ # Process the video with our counter function
283
+ counts = process_video_with_counter(
284
+ input_path=input_path,
285
+ output_path=output_path,
286
+ model_path=model_path,
287
+ reid_model=reid_model,
288
+ tracking_method=tracking_method,
289
+ selected_classes=class_ids,
290
+ conf_threshold=conf_threshold,
291
+ line_position=line_position,
292
+ line_orientation=line_orientation
293
  )
294
 
295
+ # Check if output file exists and has size
296
+ if os.path.exists(output_path) and os.path.getsize(output_path) > 0:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
297
  # Copy to permanent location with unique name
298
+ permanent_path = os.path.join(OUTPUT_DIR, f"counted_{os.path.basename(video_file)}")
299
+ shutil.copy(output_path, permanent_path)
300
  print(f"Copied output to permanent location: {permanent_path}")
301
 
302
+ # Format counts for display
303
+ count_message = "Objects counted:\n"
304
+ if counts:
305
+ for cls_id, count in counts.items():
306
+ count_message += f"Class {cls_id}: {count}\n"
307
+ else:
308
+ count_message += "No objects crossed the line"
 
 
 
 
 
 
 
 
309
 
310
+ return permanent_path, count_message
311
  else:
312
+ return None, "Error: Output video was not generated properly."
 
313
 
314
  except Exception as e:
315
  import traceback
 
317
  return None, f"Error: {str(e)}"
318
 
319
  # Define the Gradio interface
320
+ def process_video(video_path, yolo_model, reid_model, tracking_method, class_ids, conf_threshold,
321
+ line_position, line_orientation):
322
  # Validate inputs
323
  if not video_path:
324
  return None, "Please upload a video file"
325
 
326
  print(f"Processing video: {video_path}")
327
  print(f"Parameters: model={yolo_model}, reid={reid_model}, tracker={tracking_method}, classes={class_ids}, conf={conf_threshold}")
328
+ print(f"Line counter: {line_orientation} at position {line_position}")
329
 
330
  output_path, status = run_tracking(
331
  video_path,
 
333
  reid_model,
334
  tracking_method,
335
  class_ids,
336
+ conf_threshold,
337
+ line_position,
338
+ line_orientation
339
  )
340
 
341
  if output_path:
 
351
  yolo_models = ["yolov8n.pt", "yolov8s.pt", "yolov8m.pt"]
352
  reid_models = ["osnet_x0_25_msmt17.pt"]
353
  tracking_methods = ["bytetrack", "botsort", "ocsort", "strongsort"]
354
+ line_orientations = ["horizontal", "vertical"]
355
 
356
  # Ensure dependencies and apply patches at startup
357
  ensure_dependencies()
358
  apply_patches()
359
 
360
  # Create the Gradio interface
361
+ with gr.Blocks(title="YOLO Object Tracking with Line Counter") as app:
362
+ gr.Markdown("# 🚀 YOLO Object Tracking with Line Counter")
363
+ gr.Markdown("Upload a video file to detect, track and count objects crossing a line. Processing may take a few minutes depending on video length.")
364
 
365
  # Add class reference information
366
  with gr.Accordion("YOLO Class Reference", open=False):
 
420
  label="Confidence Threshold"
421
  )
422
 
423
+ # Line counter settings
424
+ gr.Markdown("### Line Counter Settings")
425
+ line_orientation = gr.Dropdown(
426
+ choices=line_orientations,
427
+ value="horizontal",
428
+ label="Line Orientation"
429
+ )
430
+ line_position = gr.Slider(
431
+ minimum=0.1,
432
+ maximum=0.9,
433
+ value=0.5,
434
+ step=0.05,
435
+ label="Line Position (0.1 = top/left, 0.9 = bottom/right)"
436
+ )
437
+
438
  process_btn = gr.Button("Process Video", variant="primary")
439
 
440
  with gr.Column(scale=1):
441
+ output_video = gr.Video(label="Output Video with Tracking and Counting")
442
+ status_text = gr.Textbox(label="Count Results", value="Ready to process video")
443
 
444
  process_btn.click(
445
  fn=process_video,
446
+ inputs=[input_video, yolo_model, reid_model, tracking_method, class_ids, conf_threshold,
447
+ line_position, line_orientation],
448
  outputs=[output_video, status_text]
449
  )
450