"""Test keypoint prediction on a video file with speed measurement.

Loads frames from a video, runs the Miner keypoint predictor over them,
reports timing and detection statistics, and optionally writes annotated
frames and/or an annotated output video.
"""

import argparse
import os
import sys
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import cv2
import numpy as np

# Make the sibling `miner` module importable when running this file directly.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from miner import Miner


def parse_args() -> argparse.Namespace:
    """Parse command-line arguments for the keypoint-prediction test."""
    parser = argparse.ArgumentParser(
        description="Test keypoint prediction on video file with maximum speed optimization."
    )
    parser.add_argument(
        "--repo-path",
        type=Path,
        default="",
        help="Path to the HuggingFace/SecretVision repository (models, configs).",
    )
    parser.add_argument(
        "--video-path",
        type=Path,
        default="test.mp4",
        help="Path to the input video file.",
    )
    parser.add_argument(
        "--output-video",
        type=Path,
        default="outputs-keypoints/annotated.mp4",
        help="Optional path to save an annotated video with keypoints.",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default="outputs-keypoints/frames",
        help="Optional directory to save annotated frames.",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=None,
        help="Batch size for keypoint prediction (None = auto, processes all frames at once for max speed).",
    )
    parser.add_argument(
        "--stride",
        type=int,
        default=1,
        help="Sample every Nth frame from the video (1 = all frames).",
    )
    parser.add_argument(
        "--max-frames",
        type=int,
        default=None,
        help="Maximum number of frames to process (after stride).",
    )
    parser.add_argument(
        "--n-keypoints",
        type=int,
        default=32,
        help="Number of keypoints expected per frame.",
    )
    parser.add_argument(
        "--conf-threshold",
        type=float,
        default=0.5,
        help="Confidence threshold for regular keypoints.",
    )
    parser.add_argument(
        "--corner-conf-threshold",
        type=float,
        default=0.3,
        help="Confidence threshold for corner keypoints.",
    )
    parser.add_argument(
        "--no-visualization",
        action="store_true",
        help="Skip visualization to maximize speed.",
    )
    return parser.parse_args()


def draw_keypoints(
    frame: np.ndarray,
    keypoints: List[Tuple[int, int]],
    color: Tuple[int, int, int] = (0, 255, 255),
) -> None:
    """Draw keypoints on ``frame`` in place.

    Each keypoint is rendered as a filled colored dot with a black outline.
    Keypoints at (0, 0) are treated as "not detected" sentinels and skipped.

    Args:
        frame: BGR image to draw on (modified in place).
        keypoints: (x, y) pixel coordinates; (0, 0) entries are skipped.
        color: BGR fill color for the keypoint dots.
    """
    for x, y in keypoints:
        # (0, 0) is the sentinel for a missing/undetected keypoint.
        if x == 0 and y == 0:
            continue
        cv2.circle(frame, (x, y), radius=3, color=color, thickness=-1)
        cv2.circle(frame, (x, y), radius=5, color=(0, 0, 0), thickness=1)


def annotate_frame(
    frame: np.ndarray,
    keypoints: List[Tuple[int, int]],
    frame_id: int,
) -> np.ndarray:
    """Return a copy of ``frame`` annotated with keypoints and a header line.

    Args:
        frame: BGR source image (not modified).
        keypoints: (x, y) keypoint coordinates; (0, 0) entries count as invalid.
        frame_id: Frame index shown in the overlay text.

    Returns:
        A new annotated BGR image.
    """
    annotated = frame.copy()
    draw_keypoints(annotated, keypoints)
    # Count valid keypoints (anything not at the (0, 0) sentinel).
    valid_count = sum(1 for kp in keypoints if kp[0] != 0 or kp[1] != 0)
    # Draw frame info
    info_text = f"Frame {frame_id} | Keypoints: {valid_count}/{len(keypoints)}"
    cv2.putText(
        annotated,
        info_text,
        (10, 30),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.7,
        (255, 255, 255),
        2,
        lineType=cv2.LINE_AA,
    )
    return annotated


def load_video_frames(
    video_path: Path,
    stride: int = 1,
    max_frames: Optional[int] = None,
) -> List[np.ndarray]:
    """Load frames from a video file.

    Args:
        video_path: Path to the video to read.
        stride: Keep every Nth source frame (1 = keep all).
        max_frames: Stop after this many kept frames (None/0 = no limit).

    Returns:
        List of BGR frames in source order.

    Raises:
        RuntimeError: If the video cannot be opened.
    """
    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        raise RuntimeError(f"Unable to open video: {video_path}")
    frames: List[np.ndarray] = []
    frame_count = 0
    source_frame_idx = 0
    print(f"Loading frames from video: {video_path}")
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # Stride sampling: only keep every `stride`-th source frame.
            if source_frame_idx % stride != 0:
                source_frame_idx += 1
                continue
            frames.append(frame)
            frame_count += 1
            source_frame_idx += 1
            if max_frames and frame_count >= max_frames:
                break
            if frame_count % 100 == 0:
                print(f"Loaded {frame_count} frames...")
    finally:
        # Release the capture even if reading raises.
        cap.release()
    print(f"Total frames loaded: {len(frames)}")
    return frames


def save_results(
    frames: List[np.ndarray],
    keypoints_dict: Dict[int, List[Tuple[int, int]]],
    output_video: Optional[Path] = None,
    output_dir: Optional[Path] = None,
    fps: float = 25.0,
    width: Optional[int] = None,
    height: Optional[int] = None,
) -> None:
    """Save annotated frames and/or an annotated video.

    Args:
        frames: BGR frames to annotate.
        keypoints_dict: Mapping of frame index -> keypoint list; frames with
            no entry are annotated with an empty keypoint list.
        output_video: If given, path of the mp4 video to write.
        output_dir: If given, directory to write per-frame JPEGs into.
        fps: Frame rate for the output video.
        width: Output frame width; inferred from the first frame if None.
        height: Output frame height; inferred from the first frame if None.
    """
    if output_video is None and output_dir is None:
        return
    # Guard: with no frames there is nothing to write (and no shape to infer).
    if not frames:
        return
    if width is None or height is None:
        height, width = frames[0].shape[:2]
    writer = None
    if output_video:
        output_video.parent.mkdir(parents=True, exist_ok=True)
        writer = cv2.VideoWriter(
            str(output_video),
            cv2.VideoWriter_fourcc(*"mp4v"),
            fps,
            (width, height),
        )
        print(f"Saving annotated video to: {output_video}")
    if output_dir:
        output_dir.mkdir(parents=True, exist_ok=True)
        print(f"Saving annotated frames to: {output_dir}")
    for frame_idx, frame in enumerate(frames):
        keypoints = keypoints_dict.get(frame_idx, [])
        annotated = annotate_frame(frame, keypoints, frame_idx)
        if writer:
            writer.write(annotated)
        if output_dir:
            frame_path = output_dir / f"frame_{frame_idx:06d}.jpg"
            cv2.imwrite(str(frame_path), annotated)
        if (frame_idx + 1) % 100 == 0:
            print(f"Saved {frame_idx + 1}/{len(frames)} frames...")
    if writer:
        writer.release()
        print(f"Video saved: {output_video}")


def calculate_statistics(
    keypoints_dict: Dict[int, List[Tuple[int, int]]],
) -> Dict[str, float]:
    """Calculate keypoint detection statistics.

    A keypoint is counted as valid when it is not at the (0, 0) sentinel.

    Args:
        keypoints_dict: Mapping of frame index -> keypoint list.

    Returns:
        Dict with total_frames, avg/max/min valid keypoints,
        frames_with_keypoints, and keypoint_detection_rate.
    """
    total_frames = len(keypoints_dict)
    if total_frames == 0:
        # Keep the same key schema as the non-empty path so callers can rely
        # on every key being present (including keypoint_detection_rate).
        return {
            "total_frames": 0,
            "avg_valid_keypoints": 0.0,
            "max_valid_keypoints": 0,
            "min_valid_keypoints": 0,
            "frames_with_keypoints": 0,
            "keypoint_detection_rate": 0.0,
        }
    valid_counts = []
    frames_with_keypoints = 0
    for keypoints in keypoints_dict.values():
        valid_count = sum(1 for kp in keypoints if kp[0] != 0 or kp[1] != 0)
        valid_counts.append(valid_count)
        if valid_count > 0:
            frames_with_keypoints += 1
    return {
        "total_frames": total_frames,
        "avg_valid_keypoints": sum(valid_counts) / len(valid_counts) if valid_counts else 0.0,
        "max_valid_keypoints": max(valid_counts) if valid_counts else 0,
        "min_valid_keypoints": min(valid_counts) if valid_counts else 0,
        "frames_with_keypoints": frames_with_keypoints,
        "keypoint_detection_rate": frames_with_keypoints / total_frames if total_frames > 0 else 0.0,
    }


def main() -> None:
    """Run the full load -> predict -> report -> save pipeline."""
    args = parse_args()

    # Initialize miner
    print("Initializing Miner...")
    init_start = time.time()
    miner = Miner(args.repo_path)
    init_time = time.time() - init_start
    print(f"Miner initialized in {init_time:.2f} seconds")

    # Load video frames
    print("\n" + "=" * 60)
    print("Loading video frames...")
    load_start = time.time()
    frames = load_video_frames(args.video_path, args.stride, args.max_frames)
    load_time = time.time() - load_start
    print(f"Frames loaded in {load_time:.2f} seconds")
    if len(frames) == 0:
        print("No frames loaded. Exiting.")
        return

    # Get video properties for output
    height, width = frames[0].shape[:2]
    cap = cv2.VideoCapture(str(args.video_path))
    # cap.get may return 0 for streams without FPS metadata; fall back to 25.
    fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
    cap.release()

    # Predict keypoints
    print("\n" + "=" * 60)
    print("Predicting keypoints...")
    print(f"Total frames: {len(frames)}")
    print(f"Batch size: {args.batch_size if args.batch_size else 'auto (all frames)'}")
    print(f"Confidence threshold: {args.conf_threshold}")
    print(f"Corner confidence threshold: {args.corner_conf_threshold}")
    predict_start = time.time()
    keypoints_dict = miner.predict_keypoints(
        images=frames,
        n_keypoints=args.n_keypoints,
        batch_size=args.batch_size,
        conf_threshold=args.conf_threshold,
        corner_conf_threshold=args.corner_conf_threshold,
        verbose=True,
    )
    predict_time = time.time() - predict_start

    # Calculate performance metrics
    total_frames = len(frames)
    fps_achieved = total_frames / predict_time if predict_time > 0 else 0
    time_per_frame = predict_time / total_frames if total_frames > 0 else 0

    # Print performance results
    print("\n" + "=" * 60)
    print("KEYPOINT PREDICTION PERFORMANCE")
    print("=" * 60)
    print(f"Total frames processed: {total_frames}")
    print(f"Total prediction time: {predict_time:.3f} seconds")
    print(f"Average time per frame: {time_per_frame*1000:.2f} ms")
    print(f"Throughput: {fps_achieved:.2f} FPS")
    print(f"Batch processing: {'Yes' if args.batch_size else 'No (single batch)'}")

    # Calculate and print statistics
    stats = calculate_statistics(keypoints_dict)
    print("\n" + "=" * 60)
    print("KEYPOINT DETECTION STATISTICS")
    print("=" * 60)
    for key, value in stats.items():
        if isinstance(value, float):
            print(f"{key}: {value:.2f}")
        else:
            print(f"{key}: {value}")

    # Save results if requested
    if not args.no_visualization and (args.output_video or args.output_dir):
        print("\n" + "=" * 60)
        print("Saving results...")
        save_start = time.time()
        save_results(
            frames=frames,
            keypoints_dict=keypoints_dict,
            output_video=args.output_video,
            output_dir=args.output_dir,
            # Strided sampling drops frames, so scale the output FPS down to
            # keep the annotated video's duration matching the source.
            fps=fps / args.stride,
            width=width,
            height=height,
        )
        save_time = time.time() - save_start
        print(f"Results saved in {save_time:.2f} seconds")

    print("\n" + "=" * 60)
    print("Done!")
    print("=" * 60)


if __name__ == "__main__":
    main()