# visiontest / test_predict_objects_video.py
# Uploaded by tarto2 via huggingface_hub (commit e4189f9, verified).
import argparse
import time
from pathlib import Path
from typing import Dict, List, Tuple, Optional
import sys
import os
import queue
import threading
import tempfile
from urllib.parse import urlparse
import cv2
import requests
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from miner import Miner, BoundingBox
def parse_args() -> argparse.Namespace:
    """Parse and validate the benchmark's command-line arguments.

    Returns:
        The parsed namespace. Exits via ``parser.error`` (status 2) when
        ``--stride`` is smaller than 1, which would otherwise crash later with
        a division by zero inside the decode thread.
    """
    parser = argparse.ArgumentParser(
        description="High-speed object detection benchmark on a video file."
    )
    parser.add_argument(
        "--repo-path",
        type=Path,
        default="",
        help="Path to the HuggingFace/SecretVision repository (models, configs).",
    )
    parser.add_argument(
        "--video-path",
        type=str,
        default="test.mp4",
        help="Path to the input video file or URL (http:// or https://).",
    )
    parser.add_argument(
        "--video-url",
        type=str,
        default=None,
        help="URL to download video from (alternative to --video-path).",
    )
    parser.add_argument(
        "--output-video",
        type=Path,
        default="outputs-detections/annotated.mp4",
        help="Path to save an annotated video with detections "
        "(skipped with --no-visualization).",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default="outputs-detections/frames",
        help="Directory to save annotated frames (skipped with --no-visualization).",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=None,
        help="Frames per detection batch (default: 16 when omitted).",
    )
    parser.add_argument(
        "--stride",
        type=int,
        default=1,
        help="Sample every Nth frame from the video (must be >= 1).",
    )
    parser.add_argument(
        "--max-frames",
        type=int,
        default=None,
        help="Maximum number of frames to process (after stride).",
    )
    parser.add_argument(
        "--conf-threshold",
        type=float,
        default=0.5,
        help="Confidence threshold for detections.",
    )
    parser.add_argument(
        "--iou-threshold",
        type=float,
        default=0.45,
        help="IoU threshold used by YOLO NMS.",
    )
    parser.add_argument(
        "--classes",
        type=int,
        nargs="+",
        default=None,
        help="Optional list of class IDs to keep (default: all classes).",
    )
    parser.add_argument(
        "--no-visualization",
        action="store_true",
        help="Skip saving annotated frames/video to maximize throughput.",
    )
    args = parser.parse_args()
    # The decoder computes ``source_idx % stride`` and main divides fps by it.
    if args.stride < 1:
        parser.error("--stride must be >= 1")
    return args
def draw_boxes(frame, boxes: List[BoundingBox]) -> None:
    """Draw each box's rectangle and ``"cls_id:conf"`` label onto ``frame`` in place.

    Args:
        frame: Image array modified in place; only ``frame.shape[:2]``
            (height, width) is read from it.
        boxes: Detections exposing ``x1``, ``y1``, ``x2``, ``y2``, ``cls_id``
            and ``conf``.

    Corner coordinates are clamped into the visible frame. A box left with no
    width or height after clamping (entirely outside the frame, or with
    inverted corners) is skipped instead of being drawn as a 1-pixel sliver.
    """
    if not boxes:
        return
    # Colors are BGR tuples (OpenCV channel order); unknown class ids use white.
    color_map = {
        0: (0, 255, 255),  # ball - yellow
        1: (0, 165, 255),  # goalkeeper - orange
        2: (0, 255, 0),  # player - green
        3: (255, 0, 0),  # referee - blue
        4: (128, 128, 128),  # gray
        5: (255, 255, 0),  # cyan
        6: (255, 0, 255),  # magenta
        7: (0, 128, 255),  # orange
    }
    h, w = frame.shape[:2]
    for box in boxes:
        # Clamp every corner to the last visible pixel so edges stay on-screen.
        x1 = max(0, min(int(box.x1), w - 1))
        y1 = max(0, min(int(box.y1), h - 1))
        x2 = max(0, min(int(box.x2), w - 1))
        y2 = max(0, min(int(box.y2), h - 1))
        if x2 <= x1 or y2 <= y1:
            continue  # degenerate after clamping: nothing visible to draw
        color = color_map.get(box.cls_id, (255, 255, 255))
        cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
        label = f"{box.cls_id}:{box.conf:.2f}"
        cv2.putText(
            frame,
            label,
            (x1, max(12, y1 - 6)),  # keep the label inside the top edge
            cv2.FONT_HERSHEY_SIMPLEX,
            0.4,
            color,
            1,
            lineType=cv2.LINE_AA,
        )
def annotate_frame(frame, boxes: List[BoundingBox], frame_id: int) -> cv2.Mat:
    """Return an annotated copy of ``frame``; the input frame is not modified.

    The copy gets every box drawn by ``draw_boxes`` plus a white
    ``"Frame <id> | Boxes: <n>"`` caption in the top-left corner.
    """
    canvas = frame.copy()
    draw_boxes(canvas, boxes)
    caption = f"Frame {frame_id} | Boxes: {len(boxes)}"
    caption_origin = (10, 25)
    white = (255, 255, 255)
    cv2.putText(
        canvas, caption, caption_origin, cv2.FONT_HERSHEY_SIMPLEX,
        0.7, white, 2, lineType=cv2.LINE_AA,
    )
    return canvas
def download_video_from_url(url: str, temp_dir: Optional[Path] = None) -> Path:
    """Download a video over HTTP(S) into a uniquely named local file.

    Args:
        url: ``http://`` or ``https://`` URL of the video.
        temp_dir: Directory for the downloaded file; created if missing.
            Defaults to the system temporary directory.

    Returns:
        Path of the downloaded file. The caller is responsible for deleting it.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection or timeout failures.
    """
    print(f"Downloading video from {url}...")
    download_start = time.time()
    # The context manager releases the streamed connection even on errors.
    with requests.get(url, stream=True, timeout=30) as response:
        response.raise_for_status()
        if temp_dir is None:
            temp_dir = Path(tempfile.gettempdir())
        else:
            temp_dir.mkdir(parents=True, exist_ok=True)
        # Keep the URL basename as the suffix so the container extension survives.
        filename = os.path.basename(urlparse(url).path) or "video.mp4"
        # mkstemp guarantees a unique name; a seconds-resolution timestamp alone
        # could collide between two downloads started in the same second.
        fd, temp_name = tempfile.mkstemp(prefix="temp_", suffix=f"_{filename}", dir=temp_dir)
        temp_file = Path(temp_name)
        try:
            with os.fdopen(fd, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:  # skip keep-alive chunks
                        f.write(chunk)
        except BaseException:
            # Do not leave a truncated video behind.
            temp_file.unlink(missing_ok=True)
            raise
    download_time = time.time() - download_start
    print(f"Download completed in {download_time:.3f}s")
    return temp_file
def stream_video_frames(
    video_path: Path,
    frame_queue: queue.Queue,
    stride: int = 1,
    max_frames: Optional[int] = None,
    stop_event: Optional[threading.Event] = None,
) -> Tuple[int, float]:
    """Decode video frames and put them on ``frame_queue`` (meant as a thread target).

    Every ``stride``-th frame is enqueued as ``(sampled_index, frame)``. A
    ``(None, None)`` sentinel is enqueued last so the consumer can stop. This
    also happens when the video cannot be opened or decoding fails midway.

    Args:
        video_path: Video file (or any source ``cv2.VideoCapture`` accepts).
        frame_queue: Destination queue; may be bounded.
        stride: Keep one frame out of every ``stride`` source frames.
        max_frames: Stop after this many enqueued frames (falsy = no limit).
        stop_event: When set, decoding stops. Puts also stop waiting, so a full
            queue whose consumer has quit cannot block this thread forever.

    Returns:
        ``(frames_enqueued, fps)``; fps falls back to 25.0 when unknown.

    Raises:
        RuntimeError: If the video cannot be opened (after enqueuing the sentinel).
    """
    def put(item) -> bool:
        # Without a stop event there is nothing to poll, so block normally.
        if stop_event is None:
            frame_queue.put(item)
            return True
        while not stop_event.is_set():
            try:
                frame_queue.put(item, timeout=0.5)
                return True
            except queue.Full:
                continue
        return False

    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        cap.release()
        put((None, None))  # let a waiting consumer finish instead of hanging
        raise RuntimeError(f"Unable to open video: {video_path}")
    fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
    frame_count = 0
    source_idx = 0
    decode_start = time.time()
    print(f"Decoding frames from {video_path}...")
    try:
        while not (stop_event is not None and stop_event.is_set()):
            ret, frame = cap.read()
            if not ret:
                break
            if source_idx % stride == 0:
                if not put((frame_count, frame)):
                    break  # stop requested while the queue was full
                frame_count += 1
                if max_frames and frame_count >= max_frames:
                    break
                if frame_count % 100 == 0:
                    print(f"Decoded {frame_count} frames...")
            source_idx += 1
    finally:
        cap.release()
        put((None, None))  # Sentinel to signal end
    decode_time = time.time() - decode_start
    print(f"Total frames decoded: {frame_count} in {decode_time:.3f}s")
    return frame_count, fps
def load_video_frames(
    video_path: Path, stride: int = 1, max_frames: Optional[int] = None
) -> List[cv2.Mat]:
    """Legacy function: load every ``stride``-th frame into memory (non-streaming).

    Args:
        video_path: Video file (or any source ``cv2.VideoCapture`` accepts).
        stride: Keep one frame out of every ``stride`` source frames; must be >= 1.
        max_frames: Stop after this many kept frames (falsy = no limit).

    Returns:
        The kept frames in decode order.

    Raises:
        ValueError: If ``stride`` is smaller than 1.
        RuntimeError: If the video cannot be opened.
    """
    if stride < 1:
        raise ValueError(f"stride must be >= 1, got {stride}")
    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        cap.release()
        raise RuntimeError(f"Unable to open video: {video_path}")
    frames: List[cv2.Mat] = []
    source_idx = 0
    print(f"Loading frames from {video_path}")
    # try/finally so the capture is released even if decoding or appending
    # raises (e.g. MemoryError on a long video).
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if source_idx % stride == 0:
                frames.append(frame)
                if max_frames and len(frames) >= max_frames:
                    break
                if len(frames) % 100 == 0:
                    print(f"Loaded {len(frames)} frames...")
            source_idx += 1
    finally:
        cap.release()
    print(f"Total frames loaded: {len(frames)}")
    return frames
def save_results(
    frames: List[cv2.Mat],
    detections: Dict[int, List[BoundingBox]],
    output_video: Optional[Path],
    output_dir: Optional[Path],
    fps: float,
) -> None:
    """Write annotated frames to a video file and/or a directory of JPEGs.

    Args:
        frames: Frames in sampled order; index ``i`` is annotated with
            ``detections.get(i, [])``.
        detections: Boxes keyed by sampled frame index.
        output_video: Destination ``.mp4`` (mp4v codec), or ``None`` to skip.
            The video is sized from ``frames[0]``.
        output_dir: Directory for ``frame_XXXXXX.jpg`` files, or ``None`` to skip.
        fps: Frame rate written to the video container.

    If the video writer cannot be opened (unsupported codec, bad path,
    invalid fps), a warning is printed and only the frame images are written.
    """
    if output_video is None and output_dir is None:
        return
    if not frames:
        print("No frames to save.")
        return
    height, width = frames[0].shape[:2]
    writer = None
    if output_video:
        output_video.parent.mkdir(parents=True, exist_ok=True)
        writer = cv2.VideoWriter(
            str(output_video),
            cv2.VideoWriter_fourcc(*"mp4v"),
            fps,
            (width, height),
        )
        # OpenCV does not raise on failure; writes to an unopened writer are
        # silently dropped, so check explicitly.
        if not writer.isOpened():
            print(f"Warning: could not open video writer for {output_video}; skipping video output")
            writer.release()
            writer = None
        else:
            print(f"Saving annotated video to {output_video}")
    if output_dir:
        output_dir.mkdir(parents=True, exist_ok=True)
        print(f"Saving annotated frames to {output_dir}")
    if writer is None and not output_dir:
        return  # nothing left to write
    try:
        for frame_idx, frame in enumerate(frames):
            boxes = detections.get(frame_idx, [])
            annotated = annotate_frame(frame, boxes, frame_idx)
            if writer is not None:
                writer.write(annotated)
            if output_dir:
                frame_path = output_dir / f"frame_{frame_idx:06d}.jpg"
                cv2.imwrite(str(frame_path), annotated)
            if (frame_idx + 1) % 100 == 0:
                print(f"Saved {frame_idx + 1}/{len(frames)} frames...")
    finally:
        # Always finalize the container, even if annotation/writing failed.
        if writer is not None:
            writer.release()
    if writer is not None:
        print(f"Video saved to {output_video}")
def aggregate_stats(detections: Dict[int, List[BoundingBox]]) -> Dict[str, float]:
    """Summarize detections into totals and per-class counts.

    Args:
        detections: Boxes keyed by frame index; each box exposes ``cls_id``.

    Returns:
        An ordered dict with ``frames``, ``boxes``, ``avg_boxes_per_frame``
        (0.0 when there are no frames), then one ``class_<id>_count`` entry
        per class id seen, in ascending id order.
    """
    n_frames = len(detections)
    n_boxes = 0
    per_class: Dict[int, int] = {}
    # Single pass: accumulate the box total and the class histogram together.
    for frame_boxes in detections.values():
        n_boxes += len(frame_boxes)
        for det in frame_boxes:
            per_class[det.cls_id] = per_class.get(det.cls_id, 0) + 1
    summary: Dict[str, float] = {"frames": n_frames, "boxes": n_boxes}
    summary["avg_boxes_per_frame"] = n_boxes / n_frames if n_frames else 0.0
    summary.update(
        (f"class_{cid}_count", cnt) for cid, cnt in sorted(per_class.items())
    )
    return summary
def detection_worker(
    miner: Miner,
    frame_queue: queue.Queue,
    result_queue: queue.Queue,
    batch_size: int,
    conf_threshold: float,
    iou_threshold: float,
    classes: Optional[List[int]],
    stop_event: threading.Event,
) -> None:
    """Thread target: batch frames from ``frame_queue`` through ``miner.predict_objects``.

    Consumes ``(frame_idx, frame)`` items until the ``(None, None)`` sentinel.
    Detection runs whenever ``batch_size`` frames are buffered, and once more
    for the leftover partial batch. The worker publishes to ``result_queue``:

    * ``('batch', {'indices', 'detections', 'frames'})`` for each processed batch,
    * ``('done', None)`` after the sentinel, or
    * ``('error', message)`` if anything raises. ``stop_event`` is then set so
      the producer stops feeding a worker that is gone.

    Returns without ``'done'`` when ``stop_event`` is set by someone else.
    """
    import traceback  # local: only used on the error path

    frame_batch: List[cv2.Mat] = []
    frame_indices: List[int] = []

    def run_batch() -> None:
        """Detect on the buffered frames, publish the batch, then clear the buffers."""
        batch_detections = miner.predict_objects(
            images=frame_batch,
            batch_size=None,  # hand the whole buffer over as-is
            conf_threshold=conf_threshold,
            iou_threshold=iou_threshold,
            classes=classes,
            verbose=False,
        )
        total_boxes_in_batch = sum(len(boxes) for boxes in batch_detections.values())
        if total_boxes_in_batch > 0:
            print(f"Detection worker: Processed batch of {len(frame_batch)} frames, "
                  f"found {total_boxes_in_batch} boxes, "
                  f"detection keys: {list(batch_detections.keys())}, "
                  f"frame indices: {frame_indices[:5]}...")
        # Copies are required: the buffers are cleared and reused right after.
        result_queue.put(('batch', {
            'indices': list(frame_indices),
            'detections': batch_detections,
            'frames': list(frame_batch),
        }))
        frame_batch.clear()
        frame_indices.clear()

    try:
        while not stop_event.is_set():
            try:
                # The timeout lets the loop re-check stop_event periodically.
                frame_idx, frame = frame_queue.get(timeout=0.5)
            except queue.Empty:
                continue
            if frame_idx is None:  # Sentinel - decoding finished
                if frame_batch:
                    run_batch()
                result_queue.put(('done', None))
                return
            frame_batch.append(frame)
            frame_indices.append(frame_idx)
            if len(frame_batch) >= batch_size:
                run_batch()
    except Exception as e:
        print(f"Error in detection worker: {e}")
        traceback.print_exc()
        result_queue.put(('error', str(e)))
        # Tell the producer to stop; nobody will drain frame_queue anymore.
        stop_event.set()
def _put_unless_stopped(
    target: queue.Queue,
    item,
    stop_event: threading.Event,
    poll_interval: float = 0.5,
) -> bool:
    """Enqueue ``item`` on ``target``, giving up once ``stop_event`` is set.

    A plain blocking ``put`` on a bounded queue never returns if the consumer
    has already exited, which would leave the producer thread hung forever.

    Returns:
        True if the item was enqueued, False if ``stop_event`` fired first.
    """
    while not stop_event.is_set():
        try:
            target.put(item, timeout=poll_interval)
            return True
        except queue.Full:
            continue
    return False


def _decode_frames_to_queue(
    cap,
    video_path: Path,
    frame_queue: queue.Queue,
    stride: int,
    max_frames: Optional[int],
    stop_event: threading.Event,
) -> None:
    """Thread target: push every ``stride``-th frame of an opened ``cap`` onto ``frame_queue``.

    Items are ``(sampled_index, frame)``. A ``(None, None)`` sentinel follows
    the last frame unless ``stop_event`` was set. ``cap`` is always released.
    """
    frame_count = 0
    source_idx = 0
    decode_start = time.time()
    print(f"Decoding frames from {video_path}...")
    try:
        while not stop_event.is_set():
            ret, frame = cap.read()
            if not ret:
                break
            if source_idx % stride == 0:
                if not _put_unless_stopped(frame_queue, (frame_count, frame), stop_event):
                    break  # consumer is gone
                frame_count += 1
                if max_frames and frame_count >= max_frames:
                    break
                if frame_count % 100 == 0:
                    print(f"Decoded {frame_count} frames...")
            source_idx += 1
    finally:
        cap.release()
        _put_unless_stopped(frame_queue, (None, None), stop_event)  # end-of-stream sentinel
    decode_time = time.time() - decode_start
    print(f"Total frames decoded: {frame_count} in {decode_time:.3f}s")


def _assemble_results(
    all_batches: List[dict],
) -> Tuple[Dict[int, List[BoundingBox]], List[cv2.Mat]]:
    """Merge worker batch payloads into ``(detections by frame index, frames in index order)``."""
    detections: Dict[int, List[BoundingBox]] = {}
    frames_by_idx: Dict[int, cv2.Mat] = {}
    for batch in all_batches:
        batch_detections = batch['detections']
        for local_idx, (orig_idx, frame) in enumerate(zip(batch['indices'], batch['frames'])):
            # predict_objects results are looked up by position within the batch.
            detections[orig_idx] = batch_detections.get(local_idx, [])
            frames_by_idx[orig_idx] = frame
    frames = [frames_by_idx[i] for i in sorted(frames_by_idx)]
    return detections, frames


def process_video_streaming(
    miner: Miner,
    video_path: Path,
    batch_size: Optional[int],
    conf_threshold: float,
    iou_threshold: float,
    classes: Optional[List[int]],
    stride: int,
    max_frames: Optional[int],
) -> Tuple[Dict[int, List[BoundingBox]], List[cv2.Mat], float, float]:
    """Run video decoding and object detection concurrently.

    A decode thread feeds sampled frames into a bounded queue. A
    ``detection_worker`` thread batches them through ``miner.predict_objects``
    and publishes results, which are collected here on the calling thread.
    If detection fails, the collected batches up to that point are returned.

    Args:
        miner: Detector exposing ``predict_objects``.
        video_path: Video file (or any source ``cv2.VideoCapture`` accepts).
        batch_size: Frames per detection batch; falsy means 16.
        conf_threshold: Detection confidence threshold.
        iou_threshold: NMS IoU threshold.
        classes: Class ids to keep, or ``None`` for all.
        stride: Keep one frame out of every ``stride`` source frames (>= 1).
        max_frames: Stop after this many sampled frames (falsy = no limit).

    Returns:
        ``(detections, frames, fps, total_time)``:
        - ``detections`` maps sampled frame index to its boxes.
        - ``frames`` holds the detected frames in index order.
        - ``fps`` is the source frame rate (25.0 if unknown).
        - ``total_time`` is wall-clock seconds for the whole pipeline.

    Raises:
        ValueError: If ``stride`` is smaller than 1.
        RuntimeError: If the video cannot be opened.
    """
    if stride < 1:
        raise ValueError(f"stride must be >= 1, got {stride}")
    frame_queue: queue.Queue = queue.Queue(maxsize=50)  # bounded buffer of decoded frames
    result_queue: queue.Queue = queue.Queue()  # results from detection
    stop_event = threading.Event()
    effective_batch = batch_size if batch_size else 16

    total_time_start = time.time()
    # Open on the calling thread so an unreadable source fails fast. Opening
    # inside the thread used to die without a sentinel and hang both the
    # worker and this function forever.
    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        cap.release()
        raise RuntimeError(f"Unable to open video: {video_path}")
    fps = cap.get(cv2.CAP_PROP_FPS) or 25.0

    decode_thread = threading.Thread(
        target=_decode_frames_to_queue,
        args=(cap, video_path, frame_queue, stride, max_frames, stop_event),
        daemon=True,
    )
    detect_thread = threading.Thread(
        target=detection_worker,
        args=(miner, frame_queue, result_queue, effective_batch,
              conf_threshold, iou_threshold, classes, stop_event),
        daemon=True,
    )
    print("\n" + "=" * 60)
    print("Running parallel decode + detection...")
    print(f"Batch size: {effective_batch}")
    print(f"Conf threshold: {conf_threshold}")
    print(f"IoU threshold: {iou_threshold}")
    if classes:
        print(f"Classes filtered: {classes}")
    decode_thread.start()
    detect_thread.start()

    all_batches: List[dict] = []
    frames_processed = 0
    try:
        while True:
            try:
                result_type, result_data = result_queue.get(timeout=2.0)
            except queue.Empty:
                if detect_thread.is_alive():
                    continue
                break  # worker exited without reporting 'done' or 'error'
            if result_type == 'batch':
                batch_len = len(result_data['indices'])
                batch_boxes = sum(len(boxes) for boxes in result_data['detections'].values())
                all_batches.append(result_data)
                previous = frames_processed
                frames_processed += batch_len
                if batch_boxes > 0:
                    print(f"Collected batch: {batch_len} frames, {batch_boxes} boxes")
                # Report on every 100-frame boundary crossed; batch sizes rarely
                # divide 100, so an exact-modulo check would almost never fire.
                if frames_processed // 100 > previous // 100:
                    print(f"Processed {frames_processed} frames...")
            elif result_type == 'done':
                break
            elif result_type == 'error':
                print(f"Detection error: {result_data}")
                break
    finally:
        # Release the decoder if it is still waiting to feed a finished worker.
        stop_event.set()

    detections, frames = _assemble_results(all_batches)
    if frames:
        total_detections = sum(len(boxes) for boxes in detections.values())
        frames_with_detections = sum(1 for boxes in detections.values() if boxes)
        print(f"Debug: {len(frames)} frames, {len(detections)} detection entries, "
              f"{total_detections} total boxes, {frames_with_detections} frames with detections")

    decode_thread.join(timeout=5.0)
    detect_thread.join(timeout=10.0)
    total_time = time.time() - total_time_start
    return detections, frames, fps, total_time
def main() -> None:
    """Run the video object-detection benchmark configured on the command line.

    Steps:
    - Build a ``Miner`` from ``--repo-path``.
    - Download the video first when the source is an http(s) URL.
    - Run the parallel decode/detect pipeline.
    - Print throughput and detection statistics.
    - Unless ``--no-visualization`` is given, write annotated outputs.

    A downloaded temporary file is deleted even if processing raises.
    """
    args = parse_args()
    print("Initializing Miner...")
    init_start = time.time()
    miner = Miner(args.repo_path)
    print(f"Miner initialized in {time.time() - init_start:.2f}s")

    # --video-url takes precedence; --video-path may itself be a URL.
    video_path = args.video_url if args.video_url else args.video_path
    temp_file: Optional[Path] = None
    if str(video_path).startswith(('http://', 'https://')):
        print("\n" + "=" * 60)
        temp_file = download_video_from_url(str(video_path))
        video_path = temp_file

    print("\n" + "=" * 60)
    try:
        detections, frames, fps, total_time = process_video_streaming(
            miner=miner,
            video_path=Path(video_path),
            batch_size=args.batch_size,
            conf_threshold=args.conf_threshold,
            iou_threshold=args.iou_threshold,
            classes=args.classes,
            stride=args.stride,
            max_frames=args.max_frames,
        )
    finally:
        # Clean up the downloaded copy whether or not processing succeeded.
        if temp_file is not None and temp_file.exists():
            try:
                temp_file.unlink()
                print(f"Cleaned up temporary file: {temp_file}")
            except OSError as e:
                print(f"Warning: Could not delete temp file {temp_file}: {e}")

    total_frames = len(frames)
    fps_achieved = total_frames / total_time if total_time > 0 else 0.0
    time_per_frame = total_time / total_frames if total_frames > 0 else 0.0
    print("\n" + "=" * 60)
    print("OBJECT DETECTION PERFORMANCE")
    print("=" * 60)
    print(f"Total frames processed: {total_frames}")
    print(f"Total processing time: {total_time:.3f}s")
    print(f"Average time per frame: {time_per_frame*1000:.2f} ms")
    print(f"Throughput: {fps_achieved:.2f} FPS")

    stats = aggregate_stats(detections)
    print("\n" + "=" * 60)
    print("DETECTION STATISTICS")
    print("=" * 60)
    for key, value in stats.items():
        if isinstance(value, float):
            print(f"{key}: {value:.2f}")
        else:
            print(f"{key}: {value}")

    if not args.no_visualization and (args.output_video or args.output_dir) and frames:
        print("\n" + "=" * 60)
        print("Saving annotated outputs...")
        save_start = time.time()
        save_results(
            frames=frames,
            detections=detections,
            output_video=args.output_video,
            output_dir=args.output_dir,
            # Sampled frames are stride apart in the source, so play them back
            # at the reduced rate to keep real-time duration.
            fps=fps / args.stride,
        )
        print(f"Outputs saved in {time.time() - save_start:.2f}s")
    elif not frames:
        print("\n" + "=" * 60)
        print("No frames processed. Skipping output saving.")
    print("\n" + "=" * 60)
    print("Done!")
    print("=" * 60)


if __name__ == "__main__":
    main()