import argparse
import os
import queue
import sys
import tempfile
import threading
import time
from collections import Counter
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from urllib.parse import urlparse

import cv2
import requests
|
|
|
| sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
|
|
| from miner import Miner, BoundingBox
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Assemble the CLI parser and return the parsed benchmark options."""
    parser = argparse.ArgumentParser(
        description="High-speed object detection benchmark on a video file."
    )

    # Input / output locations.
    parser.add_argument("--repo-path", type=Path, default="",
                        help="Path to the HuggingFace/SecretVision repository (models, configs).")
    parser.add_argument("--video-path", type=str, default="test.mp4",
                        help="Path to the input video file or URL (http:// or https://).")
    parser.add_argument("--video-url", type=str, default=None,
                        help="URL to download video from (alternative to --video-path).")
    parser.add_argument("--output-video", type=Path, default="outputs-detections/annotated.mp4",
                        help="Optional path to save an annotated video with detections.")
    parser.add_argument("--output-dir", type=Path, default="outputs-detections/frames",
                        help="Optional directory to save annotated frames.")

    # Inference controls.
    parser.add_argument("--batch-size", type=int, default=None,
                        help="Batch size for YOLO inference (None = process all frames at once).")
    parser.add_argument("--stride", type=int, default=1,
                        help="Sample every Nth frame from the video.")
    parser.add_argument("--max-frames", type=int, default=None,
                        help="Maximum number of frames to process (after stride).")
    parser.add_argument("--conf-threshold", type=float, default=0.5,
                        help="Confidence threshold for detections.")
    parser.add_argument("--iou-threshold", type=float, default=0.45,
                        help="IoU threshold used by YOLO NMS.")
    parser.add_argument("--classes", type=int, nargs="+", default=None,
                        help="Optional list of class IDs to keep (default: all classes).")

    # Output toggles.
    parser.add_argument("--no-visualization", action="store_true",
                        help="Skip saving annotated frames/video to maximize throughput.")

    return parser.parse_args()
|
|
|
|
|
def draw_boxes(frame, boxes: List[BoundingBox]) -> None:
    """Draw bounding boxes on a frame."""
    if not boxes:
        return

    # BGR color per class id; any class not listed falls back to white.
    palette = {
        0: (0, 255, 255),
        1: (0, 165, 255),
        2: (0, 255, 0),
        3: (255, 0, 0),
        4: (128, 128, 128),
        5: (255, 255, 0),
        6: (255, 0, 255),
        7: (0, 128, 255),
    }

    height, width = frame.shape[:2]

    for det in boxes:
        # Clamp coordinates into the image, keeping at least a 1px extent.
        left = max(0, min(int(det.x1), width - 1))
        top = max(0, min(int(det.y1), height - 1))
        right = max(left + 1, min(int(det.x2), width))
        bottom = max(top + 1, min(int(det.y2), height))

        if right <= left or bottom <= top:
            continue

        box_color = palette.get(det.cls_id, (255, 255, 255))
        cv2.rectangle(frame, (left, top), (right, bottom), box_color, 2)
        # Label: "<class id>:<confidence>", pinned just above the box.
        cv2.putText(
            frame,
            f"{det.cls_id}:{det.conf:.2f}",
            (left, max(12, top - 6)),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.4,
            box_color,
            1,
            lineType=cv2.LINE_AA,
        )
|
|
|
|
|
def annotate_frame(frame, boxes: List[BoundingBox], frame_id: int) -> cv2.Mat:
    """Return a copy of `frame` with its boxes and a header line drawn on."""
    canvas = frame.copy()
    draw_boxes(canvas, boxes)
    header = f"Frame {frame_id} | Boxes: {len(boxes)}"
    cv2.putText(
        canvas,
        header,
        (10, 25),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.7,
        (255, 255, 255),
        2,
        lineType=cv2.LINE_AA,
    )
    return canvas
|
|
|
|
|
def download_video_from_url(url: str, temp_dir: Optional[Path] = None) -> Path:
    """Download a video from `url` into a uniquely named temporary file.

    Args:
        url: HTTP(S) URL of the video to fetch.
        temp_dir: Directory for the downloaded file; defaults to the system
            temp directory. Created (with parents) if it does not exist.

    Returns:
        Path to the downloaded file.

    Raises:
        requests.HTTPError: If the server responds with an error status.
    """
    print(f"Downloading video from {url}...")
    download_start = time.time()

    response = requests.get(url, stream=True, timeout=30)
    response.raise_for_status()

    if temp_dir is None:
        temp_dir = Path(tempfile.gettempdir())
    else:
        temp_dir.mkdir(parents=True, exist_ok=True)

    # Keep the original file name (extension included) so downstream readers
    # can pick the right demuxer; fall back to a generic name for bare URLs.
    parsed_url = urlparse(url)
    filename = os.path.basename(parsed_url.path) or "video.mp4"
    # BUG FIX: `filename` was previously dropped and a literal "(unknown)"
    # artifact was written instead, yielding an extension-less temp file.
    temp_file = temp_dir / f"temp_{int(time.time())}_{filename}"

    with open(temp_file, 'wb') as f:
        for chunk in response.iter_content(chunk_size=8192):
            if chunk:  # skip keep-alive chunks
                f.write(chunk)

    download_time = time.time() - download_start
    print(f"Download completed in {download_time:.3f}s")
    return temp_file
|
|
|
|
|
def stream_video_frames(
    video_path: Path,
    frame_queue: queue.Queue,
    stride: int = 1,
    max_frames: Optional[int] = None,
    stop_event: Optional[threading.Event] = None,
) -> Tuple[int, float]:
    """
    Decode video frames in a separate thread and put them in a queue.
    Returns: (total_frames_decoded, fps)
    """
    capture = cv2.VideoCapture(str(video_path))
    if not capture.isOpened():
        raise RuntimeError(f"Unable to open video: {video_path}")

    fps = capture.get(cv2.CAP_PROP_FPS) or 25.0
    emitted = 0   # frames actually pushed to the queue (post-stride index)
    read_idx = 0  # raw frame counter within the source video
    decode_start = time.time()

    print(f"Decoding frames from {video_path}...")
    try:
        while not (stop_event and stop_event.is_set()):
            ok, frame = capture.read()
            if not ok:
                break

            if read_idx % stride == 0:
                frame_queue.put((emitted, frame))
                emitted += 1
                if max_frames and emitted >= max_frames:
                    break
                if emitted % 100 == 0:
                    print(f"Decoded {emitted} frames...")

            read_idx += 1
    finally:
        capture.release()
        # (None, None) is the end-of-stream sentinel consumers wait for.
        frame_queue.put((None, None))

    decode_time = time.time() - decode_start
    print(f"Total frames decoded: {emitted} in {decode_time:.3f}s")
    return emitted, fps
|
|
|
|
|
def load_video_frames(
    video_path: Path, stride: int = 1, max_frames: Optional[int] = None
) -> List[cv2.Mat]:
    """Legacy function: load all frames into memory (non-streaming)."""
    capture = cv2.VideoCapture(str(video_path))
    if not capture.isOpened():
        raise RuntimeError(f"Unable to open video: {video_path}")

    kept: List[cv2.Mat] = []
    read_idx = 0  # raw frame counter within the source video

    print(f"Loading frames from {video_path}")
    while True:
        ok, frame = capture.read()
        if not ok:
            break

        # Keep only every `stride`-th source frame.
        if read_idx % stride == 0:
            kept.append(frame)
            if max_frames and len(kept) >= max_frames:
                break
            if len(kept) % 100 == 0:
                print(f"Loaded {len(kept)} frames...")

        read_idx += 1

    capture.release()
    print(f"Total frames loaded: {len(kept)}")
    return kept
|
|
|
|
|
def save_results(
    frames: List[cv2.Mat],
    detections: Dict[int, List[BoundingBox]],
    output_video: Optional[Path],
    output_dir: Optional[Path],
    fps: float,
) -> None:
    """Write annotated frames to a video file and/or individual JPEG files."""
    if output_video is None and output_dir is None:
        return

    if not frames:
        print("No frames to save.")
        return

    # All output dimensions come from the first frame.
    height, width = frames[0].shape[:2]

    writer = None
    if output_video:
        output_video.parent.mkdir(parents=True, exist_ok=True)
        writer = cv2.VideoWriter(
            str(output_video),
            cv2.VideoWriter_fourcc(*"mp4v"),
            fps,
            (width, height),
        )
        print(f"Saving annotated video to {output_video}")

    if output_dir:
        output_dir.mkdir(parents=True, exist_ok=True)
        print(f"Saving annotated frames to {output_dir}")

    for idx, raw in enumerate(frames):
        annotated = annotate_frame(raw, detections.get(idx, []), idx)

        if writer:
            writer.write(annotated)
        if output_dir:
            cv2.imwrite(str(output_dir / f"frame_{idx:06d}.jpg"), annotated)

        if (idx + 1) % 100 == 0:
            print(f"Saved {idx + 1}/{len(frames)} frames...")

    if writer:
        writer.release()
        print(f"Video saved to {output_video}")
|
|
|
|
|
def aggregate_stats(detections: Dict[int, List["BoundingBox"]]) -> Dict[str, float]:
    """Summarize detection results across all frames.

    Args:
        detections: Mapping from frame index to the boxes found in that frame.

    Returns:
        Flat stats dict with total frame/box counts, the average number of
        boxes per frame, and one "class_<id>_count" entry per class seen.
    """
    total_frames = len(detections)
    total_boxes = sum(len(boxes) for boxes in detections.values())

    # Count boxes per class id across every frame; Counter replaces the
    # previous hand-rolled dict.get(..., 0) accumulation.
    class_counts = Counter(
        box.cls_id for boxes in detections.values() for box in boxes
    )

    stats: Dict[str, float] = {
        "frames": total_frames,
        "boxes": total_boxes,
        "avg_boxes_per_frame": (
            total_boxes / total_frames if total_frames > 0 else 0.0
        ),
    }
    for cls_id, count in sorted(class_counts.items()):
        stats[f"class_{cls_id}_count"] = count

    return stats
|
|
|
|
|
def detection_worker(
    miner: Miner,
    frame_queue: queue.Queue,
    result_queue: queue.Queue,
    batch_size: int,
    conf_threshold: float,
    iou_threshold: float,
    classes: Optional[List[int]],
    stop_event: threading.Event,
) -> None:
    """
    Worker thread that processes frames for detection.
    Takes frames from frame_queue and puts results in result_queue.
    """
    pending_frames: List[cv2.Mat] = []
    pending_indices: List[int] = []

    def _infer(images):
        # Run the detector over a list of frames with the configured thresholds.
        return miner.predict_objects(
            images=images,
            batch_size=None,
            conf_threshold=conf_threshold,
            iou_threshold=iou_threshold,
            classes=classes,
            verbose=False,
        )

    while not stop_event.is_set():
        try:
            idx, img = frame_queue.get(timeout=0.5)

            if idx is None:
                # End-of-stream sentinel: flush whatever is still buffered,
                # then signal completion.
                if pending_frames:
                    result_queue.put(('batch', {
                        'indices': pending_indices,
                        'detections': _infer(pending_frames),
                        'frames': pending_frames.copy(),
                    }))
                result_queue.put(('done', None))
                return

            pending_frames.append(img)
            pending_indices.append(idx)

            if len(pending_frames) < batch_size:
                continue

            detections = _infer(pending_frames)
            boxes_found = sum(len(b) for b in detections.values())
            if boxes_found > 0:
                print(f"Detection worker: Processed batch of {len(pending_frames)} frames, "
                      f"found {boxes_found} boxes, "
                      f"detection keys: {list(detections.keys())}, "
                      f"frame indices: {pending_indices[:5]}...")

            result_queue.put(('batch', {
                'indices': pending_indices.copy(),
                'detections': detections,
                'frames': pending_frames.copy(),
            }))
            pending_frames.clear()
            pending_indices.clear()

        except queue.Empty:
            continue
        except Exception as e:
            print(f"Error in detection worker: {e}")
            result_queue.put(('error', str(e)))
            return
|
|
|
|
|
def process_video_streaming(
    miner: Miner,
    video_path: Path,
    batch_size: Optional[int],
    conf_threshold: float,
    iou_threshold: float,
    classes: Optional[List[int]],
    stride: int,
    max_frames: Optional[int],
) -> Tuple[Dict[int, List[BoundingBox]], List[cv2.Mat], float, float]:
    """
    Process video with truly parallel decode and detection.
    Decode thread and detection thread run simultaneously.
    Returns: (detections, frames, fps, total_time)
    """
    # Bounded queue: applies backpressure to the decode thread when
    # detection falls behind (at most 50 frames buffered for inference).
    frame_queue: queue.Queue = queue.Queue(maxsize=50)
    result_queue: queue.Queue = queue.Queue()
    # Unbounded copy of every sampled frame; the main thread drains it into
    # frames_dict, so all sampled frames are ultimately held in memory.
    frames_queue: queue.Queue = queue.Queue()
    stop_event = threading.Event()

    # Fall back to batches of 16 when the caller passes None (or 0).
    effective_batch = batch_size if batch_size else 16

    def decode_and_store_frames():
        # Decode the video sequentially, fanning each sampled frame out to
        # both the detection queue and the frame-archive queue.
        cap = cv2.VideoCapture(str(video_path))
        if not cap.isOpened():
            raise RuntimeError(f"Unable to open video: {video_path}")

        fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
        frame_count = 0
        source_idx = 0
        decode_start = time.time()

        print(f"Decoding frames from {video_path}...")
        try:
            while True:
                if stop_event.is_set():
                    break

                ret, frame = cap.read()
                if not ret:
                    break

                # Keep only every `stride`-th source frame.
                if source_idx % stride == 0:
                    frame_queue.put((frame_count, frame))
                    frames_queue.put((frame_count, frame))
                    frame_count += 1
                    if max_frames and frame_count >= max_frames:
                        break
                    if frame_count % 100 == 0:
                        print(f"Decoded {frame_count} frames...")

                source_idx += 1
        finally:
            cap.release()
            # (None, None) sentinels tell both consumers the stream ended.
            frame_queue.put((None, None))
            frames_queue.put((None, None))

        decode_time = time.time() - decode_start
        print(f"Total frames decoded: {frame_count} in {decode_time:.3f}s")
        # NOTE: this return value is discarded — threading.Thread targets
        # cannot hand results back to the caller.
        return frame_count, fps

    decode_thread = threading.Thread(
        target=decode_and_store_frames,
        daemon=True,
    )

    detect_thread = threading.Thread(
        target=detection_worker,
        args=(miner, frame_queue, result_queue, effective_batch,
              conf_threshold, iou_threshold, classes, stop_event),
        daemon=True,
    )

    print("\n" + "=" * 60)
    print("Running parallel decode + detection...")
    print(f"Batch size: {effective_batch}")
    print(f"Conf threshold: {conf_threshold}")
    print(f"IoU threshold: {iou_threshold}")
    if classes:
        print(f"Classes filtered: {classes}")

    total_time_start = time.time()
    decode_thread.start()
    detect_thread.start()

    # Drain the frame-archive queue while decoding runs; frames are keyed by
    # their sampled index so they can be reassembled in order later.
    frames_dict: Dict[int, cv2.Mat] = {}
    while True:
        try:
            frame_idx, frame = frames_queue.get(timeout=1.0)
            if frame_idx is None:
                break
            frames_dict[frame_idx] = frame
        except queue.Empty:
            # Decoder may have died without posting a sentinel; stop waiting.
            if not decode_thread.is_alive():
                break
            continue

    # Collect batched detection results until the worker signals completion.
    all_batches = []
    frames_processed = 0
    detection_done = False

    while not detection_done:
        try:
            result_type, result_data = result_queue.get(timeout=2.0)

            if result_type == 'batch':
                batch_boxes = sum(len(boxes) for boxes in result_data['detections'].values())
                all_batches.append(result_data)
                frames_processed += len(result_data['indices'])
                if batch_boxes > 0:
                    print(f"Collected batch: {len(result_data['indices'])} frames, {batch_boxes} boxes")
                if frames_processed % 100 == 0:
                    print(f"Processed {frames_processed} frames...")

            elif result_type == 'done':
                detection_done = True
                break

            elif result_type == 'error':
                print(f"Detection error: {result_data}")
                detection_done = True
                break

        except queue.Empty:
            # Worker thread exited without a 'done'/'error' message.
            if not detect_thread.is_alive():
                detection_done = True
                break
            continue

    # Re-key detections: the worker returns results indexed by position
    # within each batch, so map local indices back to original frame indices.
    detections: Dict[int, List[BoundingBox]] = {}

    print(f"Debug: Assembling detections from {len(all_batches)} batches...")
    for batch_idx, batch_data in enumerate(all_batches):
        batch_indices = batch_data['indices']
        batch_detections = batch_data['detections']

        if batch_idx == 0:
            print(f"Debug batch 0: {len(batch_indices)} frame indices, "
                  f"detection keys: {list(batch_detections.keys())}, "
                  f"total boxes in batch: {sum(len(boxes) for boxes in batch_detections.values())}")

        for local_idx, orig_idx in enumerate(batch_indices):
            boxes = batch_detections.get(local_idx, [])
            detections[orig_idx] = boxes
            if batch_idx == 0 and local_idx < 3 and len(boxes) > 0:
                print(f"Debug: Frame {orig_idx} (local_idx {local_idx}) has {len(boxes)} boxes")

    # Rebuild the ordered frame list from the archived dict (gaps skipped).
    if frames_dict:
        max_idx = max(frames_dict.keys())
        frames = [frames_dict[i] for i in range(max_idx + 1) if i in frames_dict]

        total_detections = sum(len(boxes) for boxes in detections.values())
        frames_with_detections = sum(1 for boxes in detections.values() if len(boxes) > 0)
        print(f"Debug: {len(frames)} frames, {len(detections)} detection entries, "
              f"{total_detections} total boxes, {frames_with_detections} frames with detections")
    else:
        frames = []

    decode_thread.join(timeout=5.0)
    detect_thread.join(timeout=10.0)
    total_time = time.time() - total_time_start

    # NOTE(review): the video is re-opened here solely to report its native
    # FPS, because the value computed inside the decode thread is lost.
    cap = cv2.VideoCapture(str(video_path))
    fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
    cap.release()

    return detections, frames, fps, total_time
|
|
|
|
|
def main() -> None:
    """CLI entry point: run the benchmark and report timing/detection stats."""
    args = parse_args()

    print("Initializing Miner...")
    init_start = time.time()
    miner = Miner(args.repo_path)
    print(f"Miner initialized in {time.time() - init_start:.2f}s")

    # --video-url takes precedence over --video-path when both are given.
    video_path = args.video_url if args.video_url else args.video_path
    temp_file = None

    if str(video_path).startswith(('http://', 'https://')):
        print("\n" + "=" * 60)
        temp_file = download_video_from_url(str(video_path))
        video_path = temp_file

    print("\n" + "=" * 60)
    try:
        detections, frames, fps, total_time = process_video_streaming(
            miner=miner,
            video_path=Path(video_path),
            batch_size=args.batch_size,
            conf_threshold=args.conf_threshold,
            iou_threshold=args.iou_threshold,
            classes=args.classes,
            stride=args.stride,
            max_frames=args.max_frames,
        )
    finally:
        # BUG FIX: remove the downloaded temp file even when processing
        # raises; previously a failure leaked the file in the temp dir.
        if temp_file and temp_file.exists():
            try:
                temp_file.unlink()
                print(f"Cleaned up temporary file: {temp_file}")
            except Exception as e:
                print(f"Warning: Could not delete temp file {temp_file}: {e}")

    total_frames = len(frames)
    fps_achieved = total_frames / total_time if total_time > 0 else 0.0
    time_per_frame = total_time / total_frames if total_frames > 0 else 0.0

    print("\n" + "=" * 60)
    print("OBJECT DETECTION PERFORMANCE")
    print("=" * 60)
    print(f"Total frames processed: {total_frames}")
    print(f"Total processing time: {total_time:.3f}s")
    print(f"Average time per frame: {time_per_frame*1000:.2f} ms")
    print(f"Throughput: {fps_achieved:.2f} FPS")

    stats = aggregate_stats(detections)
    print("\n" + "=" * 60)
    print("DETECTION STATISTICS")
    print("=" * 60)
    for key, value in stats.items():
        if isinstance(value, float):
            print(f"{key}: {value:.2f}")
        else:
            print(f"{key}: {value}")

    if not args.no_visualization and (args.output_video or args.output_dir) and frames:
        print("\n" + "=" * 60)
        print("Saving annotated outputs...")
        save_start = time.time()
        save_results(
            frames=frames,
            detections=detections,
            output_video=args.output_video,
            output_dir=args.output_dir,
            # The output video plays at the effective (strided) frame rate.
            fps=fps / args.stride,
        )
        print(f"Outputs saved in {time.time() - save_start:.2f}s")
    elif not frames:
        print("\n" + "=" * 60)
        print("No frames processed. Skipping output saving.")

    print("\n" + "=" * 60)
    print("Done!")
    print("=" * 60)
|
|
|
|
|
# Allow running this module directly as a script.
if __name__ == "__main__":
    main()
|
|
|
|
|