import argparse import time from pathlib import Path from typing import Dict, List, Tuple, Optional import sys import os import queue import threading import tempfile from urllib.parse import urlparse import cv2 import requests sys.path.append(os.path.dirname(os.path.abspath(__file__))) from miner import Miner, BoundingBox def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( description="High-speed object detection benchmark on a video file." ) parser.add_argument( "--repo-path", type=Path, default="", help="Path to the HuggingFace/SecretVision repository (models, configs).", ) parser.add_argument( "--video-path", type=str, default="test.mp4", help="Path to the input video file or URL (http:// or https://).", ) parser.add_argument( "--video-url", type=str, default=None, help="URL to download video from (alternative to --video-path).", ) parser.add_argument( "--output-video", type=Path, default="outputs-detections/annotated.mp4", help="Optional path to save an annotated video with detections.", ) parser.add_argument( "--output-dir", type=Path, default="outputs-detections/frames", help="Optional directory to save annotated frames.", ) parser.add_argument( "--batch-size", type=int, default=None, help="Batch size for YOLO inference (None = process all frames at once).", ) parser.add_argument( "--stride", type=int, default=1, help="Sample every Nth frame from the video.", ) parser.add_argument( "--max-frames", type=int, default=None, help="Maximum number of frames to process (after stride).", ) parser.add_argument( "--conf-threshold", type=float, default=0.5, help="Confidence threshold for detections.", ) parser.add_argument( "--iou-threshold", type=float, default=0.45, help="IoU threshold used by YOLO NMS.", ) parser.add_argument( "--classes", type=int, nargs="+", default=None, help="Optional list of class IDs to keep (default: all classes).", ) parser.add_argument( "--no-visualization", action="store_true", help="Skip saving annotated frames/video to maximize throughput.", ) return parser.parse_args() def draw_boxes(frame, boxes: List[BoundingBox]) -> None: """Draw bounding boxes on a frame.""" if not boxes: return color_map = { 0: (0, 255, 255), # ball - cyan 1: (0, 165, 255), # goalkeeper - orange 2: (0, 255, 0), # player - green 3: (255, 0, 0), # referee - blue 4: (128, 128, 128), # gray 5: (255, 255, 0), # cyan 6: (255, 0, 255), # magenta 7: (0, 128, 255), # orange } h, w = frame.shape[:2] for box in boxes: # Validate and clamp coordinates x1 = max(0, min(int(box.x1), w - 1)) y1 = max(0, min(int(box.y1), h - 1)) x2 = max(x1 + 1, min(int(box.x2), w)) y2 = max(y1 + 1, min(int(box.y2), h)) if x2 <= x1 or y2 <= y1: continue # Skip invalid boxes color = color_map.get(box.cls_id, (255, 255, 255)) cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2) label = f"{box.cls_id}:{box.conf:.2f}" cv2.putText( frame, label, (x1, max(12, y1 - 6)), cv2.FONT_HERSHEY_SIMPLEX, 0.4, color, 1, lineType=cv2.LINE_AA, ) def annotate_frame(frame, boxes: List[BoundingBox], frame_id: int) -> cv2.Mat: annotated = frame.copy() draw_boxes(annotated, boxes) info = f"Frame {frame_id} | Boxes: {len(boxes)}" cv2.putText( annotated, info, (10, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, lineType=cv2.LINE_AA, ) return annotated def download_video_from_url(url: str, temp_dir: Optional[Path] = None) -> Path: """Download video from URL to a temporary file.""" print(f"Downloading video from {url}...") download_start = time.time() response = requests.get(url, stream=True, timeout=30) response.raise_for_status() if temp_dir is None: temp_dir = Path(tempfile.gettempdir()) else: temp_dir.mkdir(parents=True, exist_ok=True) # Get filename from URL or use a temp name parsed_url = urlparse(url) filename = os.path.basename(parsed_url.path) or "video.mp4" temp_file = temp_dir / f"temp_{int(time.time())}_{filename}" with open(temp_file, 'wb') as f: for chunk in response.iter_content(chunk_size=8192): f.write(chunk) download_time = time.time() - download_start print(f"Download completed in {download_time:.3f}s") return temp_file def stream_video_frames( video_path: Path, frame_queue: queue.Queue, stride: int = 1, max_frames: Optional[int] = None, stop_event: Optional[threading.Event] = None, ) -> Tuple[int, float]: """ Decode video frames in a separate thread and put them in a queue. Returns: (total_frames_decoded, fps) """ cap = cv2.VideoCapture(str(video_path)) if not cap.isOpened(): raise RuntimeError(f"Unable to open video: {video_path}") fps = cap.get(cv2.CAP_PROP_FPS) or 25.0 frame_count = 0 source_idx = 0 decode_start = time.time() print(f"Decoding frames from {video_path}...") try: while True: if stop_event and stop_event.is_set(): break ret, frame = cap.read() if not ret: break if source_idx % stride == 0: frame_queue.put((frame_count, frame)) frame_count += 1 if max_frames and frame_count >= max_frames: break if frame_count % 100 == 0: print(f"Decoded {frame_count} frames...") source_idx += 1 finally: cap.release() frame_queue.put((None, None)) # Sentinel to signal end decode_time = time.time() - decode_start print(f"Total frames decoded: {frame_count} in {decode_time:.3f}s") return frame_count, fps def load_video_frames( video_path: Path, stride: int = 1, max_frames: Optional[int] = None ) -> List[cv2.Mat]: """Legacy function: load all frames into memory (non-streaming).""" cap = cv2.VideoCapture(str(video_path)) if not cap.isOpened(): raise RuntimeError(f"Unable to open video: {video_path}") frames: List[cv2.Mat] = [] frame_count = 0 source_idx = 0 print(f"Loading frames from {video_path}") while True: ret, frame = cap.read() if not ret: break if source_idx % stride == 0: frames.append(frame) frame_count += 1 if max_frames and frame_count >= max_frames: break if frame_count % 100 == 0: print(f"Loaded {frame_count} frames...") source_idx += 1 cap.release() print(f"Total frames loaded: {len(frames)}") return frames def save_results( frames: List[cv2.Mat], detections: Dict[int, List[BoundingBox]], output_video: Optional[Path], output_dir: Optional[Path], fps: float, ) -> None: if output_video is None and output_dir is None: return if not frames: print("No frames to save.") return height, width = frames[0].shape[:2] writer = None if output_video: output_video.parent.mkdir(parents=True, exist_ok=True) writer = cv2.VideoWriter( str(output_video), cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height), ) print(f"Saving annotated video to {output_video}") if output_dir: output_dir.mkdir(parents=True, exist_ok=True) print(f"Saving annotated frames to {output_dir}") for frame_idx, frame in enumerate(frames): boxes = detections.get(frame_idx, []) annotated = annotate_frame(frame, boxes, frame_idx) if writer: writer.write(annotated) if output_dir: frame_path = output_dir / f"frame_{frame_idx:06d}.jpg" cv2.imwrite(str(frame_path), annotated) if (frame_idx + 1) % 100 == 0: print(f"Saved {frame_idx + 1}/{len(frames)} frames...") if writer: writer.release() print(f"Video saved to {output_video}") def aggregate_stats(detections: Dict[int, List[BoundingBox]]) -> Dict[str, float]: total_frames = len(detections) total_boxes = sum(len(boxes) for boxes in detections.values()) class_counts: Dict[int, int] = {} for boxes in detections.values(): for box in boxes: class_counts[box.cls_id] = class_counts.get(box.cls_id, 0) + 1 stats: Dict[str, float] = { "frames": total_frames, "boxes": total_boxes, } stats["avg_boxes_per_frame"] = ( total_boxes / total_frames if total_frames > 0 else 0.0 ) for cls_id, count in sorted(class_counts.items()): stats[f"class_{cls_id}_count"] = count return stats def detection_worker( miner: Miner, frame_queue: queue.Queue, result_queue: queue.Queue, batch_size: int, conf_threshold: float, iou_threshold: float, classes: Optional[List[int]], stop_event: threading.Event, ) -> None: """ Worker thread that processes frames for detection. Takes frames from frame_queue and puts results in result_queue. """ frame_batch: List[cv2.Mat] = [] frame_indices: List[int] = [] while True: if stop_event.is_set(): break try: item = frame_queue.get(timeout=0.5) frame_idx, frame = item if frame_idx is None: # Sentinel - decoding finished # Process remaining frames in batch if frame_batch: batch_detections = miner.predict_objects( images=frame_batch, batch_size=None, conf_threshold=conf_threshold, iou_threshold=iou_threshold, classes=classes, verbose=False, ) result_queue.put(('batch', { 'indices': frame_indices, 'detections': batch_detections, 'frames': frame_batch.copy(), })) result_queue.put(('done', None)) break frame_batch.append(frame) frame_indices.append(frame_idx) # Process batch when full if len(frame_batch) >= batch_size: batch_detections = miner.predict_objects( images=frame_batch, batch_size=None, conf_threshold=conf_threshold, iou_threshold=iou_threshold, classes=classes, verbose=False, ) # Debug: Check what we got total_boxes_in_batch = sum(len(boxes) for boxes in batch_detections.values()) if total_boxes_in_batch > 0: print(f"Detection worker: Processed batch of {len(frame_batch)} frames, " f"found {total_boxes_in_batch} boxes, " f"detection keys: {list(batch_detections.keys())}, " f"frame indices: {frame_indices[:5]}...") result_queue.put(('batch', { 'indices': frame_indices.copy(), 'detections': batch_detections, 'frames': frame_batch.copy(), })) frame_batch.clear() frame_indices.clear() except queue.Empty: continue except Exception as e: print(f"Error in detection worker: {e}") result_queue.put(('error', str(e))) break def process_video_streaming( miner: Miner, video_path: Path, batch_size: Optional[int], conf_threshold: float, iou_threshold: float, classes: Optional[List[int]], stride: int, max_frames: Optional[int], ) -> Tuple[Dict[int, List[BoundingBox]], List[cv2.Mat], float, float]: """ Process video with truly parallel decode and detection. Decode thread and detection thread run simultaneously. Returns: (detections, frames, fps, total_time) """ frame_queue: queue.Queue = queue.Queue(maxsize=50) # Buffer for decoded frames result_queue: queue.Queue = queue.Queue() # Results from detection frames_queue: queue.Queue = queue.Queue() # Store all decoded frames separately stop_event = threading.Event() effective_batch = batch_size if batch_size else 16 # Modified decode function that also stores frames def decode_and_store_frames(): cap = cv2.VideoCapture(str(video_path)) if not cap.isOpened(): raise RuntimeError(f"Unable to open video: {video_path}") fps = cap.get(cv2.CAP_PROP_FPS) or 25.0 frame_count = 0 source_idx = 0 decode_start = time.time() print(f"Decoding frames from {video_path}...") try: while True: if stop_event.is_set(): break ret, frame = cap.read() if not ret: break if source_idx % stride == 0: frame_queue.put((frame_count, frame)) frames_queue.put((frame_count, frame)) # Store frame separately frame_count += 1 if max_frames and frame_count >= max_frames: break if frame_count % 100 == 0: print(f"Decoded {frame_count} frames...") source_idx += 1 finally: cap.release() frame_queue.put((None, None)) # Sentinel to signal end frames_queue.put((None, None)) # Sentinel for frames queue decode_time = time.time() - decode_start print(f"Total frames decoded: {frame_count} in {decode_time:.3f}s") return frame_count, fps # Start decode thread decode_thread = threading.Thread( target=decode_and_store_frames, daemon=True, ) # Start detection thread detect_thread = threading.Thread( target=detection_worker, args=(miner, frame_queue, result_queue, effective_batch, conf_threshold, iou_threshold, classes, stop_event), daemon=True, ) print("\n" + "=" * 60) print("Running parallel decode + detection...") print(f"Batch size: {effective_batch}") print(f"Conf threshold: {conf_threshold}") print(f"IoU threshold: {iou_threshold}") if classes: print(f"Classes filtered: {classes}") total_time_start = time.time() decode_thread.start() detect_thread.start() # Collect all decoded frames first (independent of detection) frames_dict: Dict[int, cv2.Mat] = {} while True: try: frame_idx, frame = frames_queue.get(timeout=1.0) if frame_idx is None: break frames_dict[frame_idx] = frame except queue.Empty: if not decode_thread.is_alive(): break continue # Collect results from detection thread all_batches = [] # Store all batch results frames_processed = 0 detection_done = False while not detection_done: try: result_type, result_data = result_queue.get(timeout=2.0) if result_type == 'batch': batch_boxes = sum(len(boxes) for boxes in result_data['detections'].values()) all_batches.append(result_data) frames_processed += len(result_data['indices']) if batch_boxes > 0: print(f"Collected batch: {len(result_data['indices'])} frames, {batch_boxes} boxes") if frames_processed % 100 == 0: print(f"Processed {frames_processed} frames...") elif result_type == 'done': detection_done = True break elif result_type == 'error': print(f"Detection error: {result_data}") detection_done = True break except queue.Empty: # Check if threads are still alive if not detect_thread.is_alive(): detection_done = True break continue # Assemble detections in correct order detections: Dict[int, List[BoundingBox]] = {} print(f"Debug: Assembling detections from {len(all_batches)} batches...") for batch_idx, batch_data in enumerate(all_batches): batch_indices = batch_data['indices'] batch_detections = batch_data['detections'] # Debug first batch if batch_idx == 0: print(f"Debug batch 0: {len(batch_indices)} frame indices, " f"detection keys: {list(batch_detections.keys())}, " f"total boxes in batch: {sum(len(boxes) for boxes in batch_detections.values())}") for local_idx, orig_idx in enumerate(batch_indices): boxes = batch_detections.get(local_idx, []) detections[orig_idx] = boxes if batch_idx == 0 and local_idx < 3 and len(boxes) > 0: print(f"Debug: Frame {orig_idx} (local_idx {local_idx}) has {len(boxes)} boxes") # Convert frames_dict to ordered list if frames_dict: max_idx = max(frames_dict.keys()) frames = [frames_dict[i] for i in range(max_idx + 1) if i in frames_dict] # Debug: Print detection statistics total_detections = sum(len(boxes) for boxes in detections.values()) frames_with_detections = sum(1 for boxes in detections.values() if len(boxes) > 0) print(f"Debug: {len(frames)} frames, {len(detections)} detection entries, " f"{total_detections} total boxes, {frames_with_detections} frames with detections") else: frames = [] # Wait for threads to finish decode_thread.join(timeout=5.0) detect_thread.join(timeout=10.0) total_time = time.time() - total_time_start # Get FPS from video metadata cap = cv2.VideoCapture(str(video_path)) fps = cap.get(cv2.CAP_PROP_FPS) or 25.0 cap.release() return detections, frames, fps, total_time def main() -> None: args = parse_args() print("Initializing Miner...") init_start = time.time() miner = Miner(args.repo_path) print(f"Miner initialized in {time.time() - init_start:.2f}s") # Handle URL or local file video_path = args.video_url if args.video_url else args.video_path temp_file = None # Check if it's a URL if str(video_path).startswith(('http://', 'https://')): print("\n" + "=" * 60) temp_file = download_video_from_url(str(video_path)) video_path = temp_file # Use streaming mode for parallel processing print("\n" + "=" * 60) process_start = time.time() detections, frames, fps, total_time = process_video_streaming( miner=miner, video_path=Path(video_path), batch_size=args.batch_size, conf_threshold=args.conf_threshold, iou_threshold=args.iou_threshold, classes=args.classes, stride=args.stride, max_frames=args.max_frames, ) # Clean up temp file if downloaded if temp_file and temp_file.exists(): try: temp_file.unlink() print(f"Cleaned up temporary file: {temp_file}") except Exception as e: print(f"Warning: Could not delete temp file {temp_file}: {e}") total_frames = len(frames) fps_achieved = total_frames / total_time if total_time > 0 else 0.0 time_per_frame = total_time / total_frames if total_frames > 0 else 0.0 print("\n" + "=" * 60) print("OBJECT DETECTION PERFORMANCE") print("=" * 60) print(f"Total frames processed: {total_frames}") print(f"Total processing time: {total_time:.3f}s") print(f"Average time per frame: {time_per_frame*1000:.2f} ms") print(f"Throughput: {fps_achieved:.2f} FPS") stats = aggregate_stats(detections) print("\n" + "=" * 60) print("DETECTION STATISTICS") print("=" * 60) for key, value in stats.items(): if isinstance(value, float): print(f"{key}: {value:.2f}") else: print(f"{key}: {value}") if not args.no_visualization and (args.output_video or args.output_dir) and frames: print("\n" + "=" * 60) print("Saving annotated outputs...") save_start = time.time() save_results( frames=frames, detections=detections, output_video=args.output_video, output_dir=args.output_dir, fps=fps / args.stride, ) print(f"Outputs saved in {time.time() - save_start:.2f}s") elif not frames: print("\n" + "=" * 60) print("No frames processed. Skipping output saving.") print("\n" + "=" * 60) print("Done!") print("=" * 60) if __name__ == "__main__": main()