| import argparse |
| import time |
| from collections import defaultdict |
| from pathlib import Path |
| from typing import List, Tuple, Dict |
|
|
| import cv2 |
| import numpy as np |
|
|
| from miner3 import Miner, TVFrameResult, BoundingBox |
| from keypoint_evaluation import ( |
| evaluate_keypoints_for_frame, |
| evaluate_keypoints_for_frame_opencv_cuda, |
| evaluate_keypoints_batch_gpu, |
| load_template_from_file, |
| project_image_using_keypoints, |
| extract_masks_for_ground_and_lines, |
| extract_mask_of_ground_lines_in_image, |
| extract_masks_for_ground_and_lines_no_validation, |
| ) |
|
|
|
|
def _positive_int(value: str) -> int:
    """argparse ``type=`` converter that accepts only integers >= 1."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"expected an integer, got {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"expected a positive integer, got {number}")
    return number


def parse_args() -> argparse.Namespace:
    """Parse command-line options for running Miner.predict_batch on a video.

    ``--batch-size`` and ``--stride`` must be positive integers: a stride of 0 would
    make the frame-sampling modulo in ``main`` divide by zero, and non-positive values
    would yield a meaningless (zero or negative) output-video FPS.

    ``--template-image`` defaults to ``None`` so that ``main`` falls back to
    ``<repo-path>/football_pitch_template.png`` as the help text describes. With the
    default ``--repo-path`` ("" -> current directory) this resolves to the same
    relative path the previous hard-coded default used.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description="Run Miner.predict_batch on a video and visualize results."
    )
    parser.add_argument(
        "--repo-path",
        type=Path,
        default="",
        help="Path to the HuggingFace/SecretVision repository (models, configs).",
    )
    parser.add_argument(
        "--video-path",
        type=Path,
        default="2025_06_28_e40fec95_39d4f90f11cd419b89c620a6442d37_1414c99f.mp4",
        help="Path to the input video file.",
    )
    parser.add_argument(
        "--output-video",
        type=Path,
        default='outputs/annotated.mp4',
        help="Optional path to save an annotated video.",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default='outputs/frames',
        help="Optional directory to dump annotated frames.",
    )
    parser.add_argument(
        "--batch-size",
        type=_positive_int,
        default=64,
        help="Number of frames per predict_batch call.",
    )
    parser.add_argument(
        "--stride",
        type=_positive_int,
        default=1,
        help="Sample every Nth frame from the video.",
    )
    parser.add_argument(
        "--max-frames",
        type=int,
        default=None,
        help="Maximum number of frames to process (after stride).",
    )
    parser.add_argument(
        "--visualize-keypoints",
        type=Path,
        default="outputs/keypoints_visualizations",
        help="Optional directory to save keypoint evaluation visualizations (warped template + original template for all frames).",
    )
    parser.add_argument(
        "--n-keypoints",
        type=int,
        default=32,
        help="Number of keypoints Miner should return per frame.",
    )
    parser.add_argument(
        "--template-image",
        type=Path,
        # None lets main() fall back to <repo-path>/football_pitch_template.png.
        default=None,
        help="Path to football pitch template image (default: football_pitch_template.png in --repo-path).",
    )
    return parser.parse_args()
|
|
|
|
def draw_keypoints(frame: np.ndarray, keypoints: List[Tuple[int, int]]) -> None:
    """Draw each predicted keypoint onto ``frame`` in place as a small filled dot.

    A keypoint equal to (0, 0) is treated as "not detected" and skipped.

    Coordinates are rounded to plain ``int`` before drawing. OpenCV drawing functions
    reject float coordinate tuples. Keypoints elsewhere in this script are compared
    against ``0.0``, so float predictions are plausible.

    Args:
        frame: Image to draw on (modified in place).
        keypoints: Sequence of (x, y) pixel coordinates.
    """
    for x, y in keypoints:
        if x == 0 and y == 0:
            continue
        center = (int(round(x)), int(round(y)))
        cv2.circle(frame, center, radius=2, color=(0, 255, 255), thickness=-1)
|
|
|
|
# BGR colour used for each detection class id; unknown ids fall back to white.
_BOX_COLORS_BY_CLASS = {
    0: (0, 255, 255),
    1: (0, 165, 255),
    2: (0, 255, 0),
    3: (255, 0, 0),
    4: (128, 128, 128),
    5: (255, 255, 0),
    6: (255, 0, 255),
    7: (0, 128, 255),
}
_DEFAULT_BOX_COLOR = (255, 255, 255)


def draw_boxes(frame: np.ndarray, boxes: List[BoundingBox]) -> None:
    """Draw every box outline plus a ``cls_id:conf`` label onto ``frame`` in place.

    The label sits 5 px above the box's top edge, clamped so its baseline is at y >= 10.
    """
    for box in boxes:
        color = _BOX_COLORS_BY_CLASS.get(box.cls_id, _DEFAULT_BOX_COLOR)
        top_left = (box.x1, box.y1)
        bottom_right = (box.x2, box.y2)
        cv2.rectangle(frame, top_left, bottom_right, color, 2)

        caption = f"{box.cls_id}:{box.conf:.2f}"
        caption_origin = (box.x1, max(10, box.y1 - 5))
        cv2.putText(
            frame,
            caption,
            caption_origin,
            cv2.FONT_HERSHEY_SIMPLEX,
            0.4,
            color,
            1,
            lineType=cv2.LINE_AA,
        )
|
|
|
|
def annotate_frame(frame: np.ndarray, result: TVFrameResult) -> np.ndarray:
    """Return a copy of ``frame`` with the result's boxes, keypoints and frame id drawn on it.

    The input frame is left untouched; all drawing happens on a copy.
    """
    canvas = frame.copy()
    draw_boxes(canvas, result.boxes)
    draw_keypoints(canvas, result.keypoints)

    caption = f"Frame {result.frame_id}"
    white = (255, 255, 255)
    cv2.putText(
        canvas,
        caption,
        (10, 20),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.6,
        white,
        2,
        lineType=cv2.LINE_AA,
    )
    return canvas
|
|
|
|
def ensure_output_dir(path: Path) -> None:
    """Create ``path`` and any missing parents; do nothing when ``path`` is ``None``.

    Safe to call repeatedly: an already-existing directory is not an error.
    """
    if path is None:
        return
    path.mkdir(parents=True, exist_ok=True)
|
|
|
|
def aggregate_stats(results: List[TVFrameResult]) -> Dict[str, float]:
    """Summarise detections over all processed frames.

    Returns:
        A dict with these keys, in this order:
        - ``frames``: number of results.
        - ``boxes``: total number of boxes.
        - ``class_<id>_count``: one entry per class id seen, ascending by id.
        - ``team_<id>_count``: entries for class ids 6 and 7 only, ascending by id.
    """
    team_class_ids = (6, 7)
    per_class = defaultdict(int)
    for box in (b for res in results for b in res.boxes):
        per_class[box.cls_id] += 1

    ordered = sorted(per_class.items())
    stats = {"frames": len(results), "boxes": sum(per_class.values())}
    stats.update((f"class_{cid}_count", n) for cid, n in ordered)
    stats.update((f"team_{cid}_count", n) for cid, n in ordered if cid in team_class_ids)
    return stats
|
|
|
|
# Layout constants for the per-keypoint coordinate legend at the bottom of the image.
_KP_LEGEND_LINES_PER_COLUMN = 10
_KP_LEGEND_LINE_HEIGHT = 18
_KP_LEGEND_COLUMN_WIDTH = 150


def _project_template_and_masks(
    frame: np.ndarray,
    frame_keypoints: List[Tuple[int, int]],
    template_image: np.ndarray,
    template_keypoints: List[Tuple[int, int]],
    frame_id: int,
) -> tuple:
    """Warp the template into frame space and extract expected/predicted line masks.

    Returns:
        ``(warped_template, mask_lines_expected, mask_lines_predicted, is_twisted)``.
        - ``warped_template`` is ``None`` when projection raises.
        - A mask that could not be computed is an all-zero frame-sized ``uint8`` array.
        - ``is_twisted`` is True when the projection error message contains "twisted"
          (case-insensitive).
    """
    height, width = frame.shape[:2]

    def empty_mask() -> np.ndarray:
        return np.zeros((height, width), dtype=np.uint8)

    try:
        warped_template = project_image_using_keypoints(
            image=template_image,
            source_keypoints=template_keypoints,
            destination_keypoints=frame_keypoints,
            destination_width=width,
            destination_height=height,
        )
    except Exception as e:
        # Projection failures are classified as "twisted" purely by message text;
        # the case-insensitive match also covers "Projection twisted".
        print(f"Warning: Could not warp template for frame {frame_id}: {e}")
        return None, empty_mask(), empty_mask(), "twisted" in str(e).lower()

    try:
        mask_ground, mask_lines_expected = extract_masks_for_ground_and_lines(
            image=warped_template
        )
        mask_lines_predicted = extract_mask_of_ground_lines_in_image(
            image=frame, ground_mask=mask_ground
        )
    except Exception as e:
        # Keep producing the visualization, but report why the line overlays are
        # empty instead of failing silently.
        print(f"Warning: Could not extract line masks for frame {frame_id}: {e}")
        return warped_template, empty_mask(), empty_mask(), False

    return warped_template, mask_lines_expected, mask_lines_predicted, False


def _overlay_mask(
    image: np.ndarray,
    mask: np.ndarray,
    channel: int,
    image_weight: float,
    mask_weight: float,
) -> np.ndarray:
    """Blend a single-colour rendering of ``mask`` over ``image``.

    The mask is drawn in BGR ``channel`` via ``cv2.addWeighted(image, image_weight,
    colored, mask_weight, 0)``. Any non-zero mask pixel is rendered at full intensity.

    NOTE(review): the mask helpers presumably return 0/1 (or bool) masks; TODO confirm.
    Thresholding with ``> 0`` is identical for those. It also avoids uint8
    wrap-around (255 * 255 -> 1) if a helper ever returns 0/255 masks.
    """
    colored = np.zeros_like(image)
    colored[:, :, channel] = np.where(mask > 0, 255, 0)
    return cv2.addWeighted(image, image_weight, colored, mask_weight, 0)


def visualize_keypoint_evaluation(
    frame: np.ndarray,
    frame_keypoints: List[Tuple[int, int]],
    template_image: np.ndarray,
    template_keypoints: List[Tuple[int, int]],
    score: float,
    output_path: Path,
    frame_id: int,
) -> np.ndarray:
    """Render a side-by-side keypoint evaluation image and save it to ``output_path``.

    Layout, left to right:
    1. The frame with predicted lines in green.
    2. The template warped into the frame using the keypoints, with expected lines in
       blue and predicted lines in green. This panel shows "Warping Failed" if the
       projection fails.
    3. The original template, resized to the frame height.

    Below the panels are captions and a bottom-aligned legend listing each non-(0, 0)
    keypoint, 10 per column.

    Args:
        frame: Original frame image.
        frame_keypoints: Keypoints predicted for the frame; (0, 0) means "missing".
        template_image: Original template image.
        template_keypoints: Template keypoints matching ``frame_keypoints`` by index.
        score: Evaluation score shown in the warped-panel caption.
        output_path: Where to write the image (parent dirs are created). A write
            failure is reported, not raised.
        frame_id: Frame ID used in captions and warnings.

    Returns:
        The visualization image.
    """
    warped_template, mask_lines_expected, mask_lines_predicted, is_twisted = (
        _project_template_and_masks(
            frame, frame_keypoints, template_image, template_keypoints, frame_id
        )
    )

    h, w = frame.shape[:2]
    # Scale the template to the frame height, preserving its aspect ratio.
    template_resized = cv2.resize(
        template_image,
        (int(template_image.shape[1] * h / template_image.shape[0]), h),
    )
    template_h, template_w = template_resized.shape[:2]

    spacing = 10
    warped_x = w + spacing
    template_x = warped_x + w + spacing
    vis_width = template_x + template_w + 20

    # Detected (non-(0, 0)) keypoints, keeping their original index for 1-based labels.
    detected = [
        (i, x, y) for i, (x, y) in enumerate(frame_keypoints) if not (x == 0 and y == 0)
    ]

    # Reserve room below the panels for captions and the keypoint legend.
    legend_height = 55 + min(_KP_LEGEND_LINES_PER_COLUMN, len(detected)) * _KP_LEGEND_LINE_HEIGHT
    vis_height = max(h, template_h) + max(60, legend_height)
    visualization = np.full((vis_height, vis_width, 3), 255, dtype=np.uint8)

    # Panel 1: frame with predicted lines (green).
    visualization[:h, :w] = _overlay_mask(frame, mask_lines_predicted, 1, 0.7, 0.3)

    # Panel 2: warped template with expected (blue) then predicted (green) lines.
    if warped_template is not None:
        warped_panel = _overlay_mask(warped_template, mask_lines_expected, 0, 0.7, 0.3)
        warped_panel = _overlay_mask(warped_panel, mask_lines_predicted, 1, 0.8, 0.2)
    else:
        warped_panel = np.zeros((h, w, 3), dtype=np.uint8)
        cv2.putText(
            warped_panel, "Warping Failed", (w // 4, h // 2),
            cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2
        )
    visualization[:h, warped_x:warped_x + w] = warped_panel

    # Panel 3: the original (resized) template.
    visualization[:template_h, template_x:template_x + template_w] = template_resized

    # Keypoint markers in frame coordinates (i.e. over panel 1), clamped to the canvas.
    # OpenCV drawing requires plain int coordinates.
    for i, x, y in detected:
        draw_x = int(max(0, min(x, vis_width - 1)))
        draw_y = int(max(0, min(y, vis_height - 1)))
        cv2.circle(visualization, (draw_x, draw_y), 5, (0, 255, 0), -1)
        cv2.putText(
            visualization, str(i + 1), (draw_x + 8, draw_y - 8),
            cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 255, 0), 1
        )

    # Panel captions.
    cv2.putText(
        visualization, "Original Frame (Green=Predicted Lines)", (10, h + 20),
        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 2
    )
    warped_label = f"Warped Template (Blue=Expected, Green=Predicted, Score: {score:.3f})"
    if is_twisted:
        warped_label += " [TWISTED]"
    cv2.putText(
        visualization, warped_label, (warped_x, h + 20),
        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255) if is_twisted else (0, 0, 0), 2
    )
    cv2.putText(
        visualization, "Original Template", (template_x, template_h + 20),
        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 2
    )
    cv2.putText(
        visualization, f"Frame {frame_id}", (10, 30),
        cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 0, 0), 2
    )

    # Keypoint legend: columns of up to 10 lines, each column bottom-aligned.
    keypoint_lines = [f"KP{i+1}: ({x},{y})" for i, x, y in detected]
    start_y_bottom = vis_height - 10
    column_starts = range(0, len(keypoint_lines), _KP_LEGEND_LINES_PER_COLUMN)
    for col_idx, first in enumerate(column_starts):
        column_lines = keypoint_lines[first:first + _KP_LEGEND_LINES_PER_COLUMN]
        x_pos = 10 + col_idx * _KP_LEGEND_COLUMN_WIDTH
        for line_idx, kp_line in enumerate(column_lines):
            y_pos = start_y_bottom - (len(column_lines) - line_idx - 1) * _KP_LEGEND_LINE_HEIGHT
            cv2.putText(
                visualization, kp_line, (x_pos, y_pos),
                cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 0), 1
            )

    output_path.parent.mkdir(parents=True, exist_ok=True)
    # cv2.imwrite reports failure (e.g. unsupported extension) via its return value.
    if not cv2.imwrite(str(output_path), visualization):
        print(f"Warning: Could not write keypoint visualization to {output_path}")

    return visualization
|
|
|
|
def evaluate_keypoints_batch(
    results: List[TVFrameResult],
    original_frames: Dict[int, np.ndarray],
    template_image: np.ndarray,
    template_keypoints: List[Tuple[int, int]],
    visualization_output_dir: Path = None,
) -> Dict[str, float]:
    """Score each frame's predicted keypoints against the template, one frame at a time.

    Handling per result:
    - Skipped entirely if its ``frame_id`` is not in ``original_frames``.
    - Scored 0.0 and not counted as valid if it has fewer than four non-(0, 0)
      keypoints ("_invalid").
    - Scored 0.0 and not counted as valid if its keypoint count differs from the
      template's ("_mismatch").
    - Scored 0.0 and not counted as valid if ``evaluate_keypoints_for_frame`` raises
      ("_error").

    Every scored frame contributes exactly one entry to the statistics.

    Args:
        results: TVFrameResult objects carrying ``frame_id`` and ``keypoints``.
        original_frames: Mapping from frame_id to the original frame image.
        template_image: Template image used for evaluation.
        template_keypoints: Template keypoints.
        visualization_output_dir: If set, a visualization is saved per scored frame.
            Visualization is best-effort: failures are logged and do not affect scores.

    Returns:
        Always includes ``keypoint_avg_score``, ``keypoint_valid_frames`` and
        ``keypoint_total_frames``. When at least one frame was scored, it also includes
        min/max scores and the counts of frames scoring above 0.5 and 0.7.
    """
    frame_scores: List[float] = []
    valid_frames = 0

    def save_visualization(frame_image, frame_keypoints, frame_id, score, suffix):
        # Diagnostic output only: never let a drawing/IO error abort the run or be
        # mistaken for an evaluation failure (which used to double-count the frame).
        if not visualization_output_dir:
            return
        vis_path = visualization_output_dir / f"frame_{frame_id:06d}_score_{score:.3f}{suffix}.jpg"
        try:
            visualize_keypoint_evaluation(
                frame=frame_image,
                frame_keypoints=frame_keypoints,
                template_image=template_image,
                template_keypoints=template_keypoints,
                score=score,
                output_path=vis_path,
                frame_id=frame_id,
            )
        except Exception as exc:
            print(f"Error visualizing keypoints for frame {frame_id}: {exc}")

    for result in results:
        frame_id = result.frame_id
        if frame_id not in original_frames:
            continue

        frame_image = original_frames[frame_id]
        frame_keypoints = result.keypoints

        # A homography needs at least 4 detected (non-(0, 0)) correspondences.
        num_detected = sum(1 for kp in frame_keypoints if kp[0] != 0.0 or kp[1] != 0.0)
        if num_detected < 4:
            frame_scores.append(0.0)
            save_visualization(frame_image, frame_keypoints, frame_id, 0.0, "_invalid")
            continue

        if len(frame_keypoints) != len(template_keypoints):
            frame_scores.append(0.0)
            save_visualization(frame_image, frame_keypoints, frame_id, 0.0, "_mismatch")
            continue

        # Only the scoring call is guarded, so a later failure cannot re-score the frame.
        try:
            score = evaluate_keypoints_for_frame(
                template_keypoints=template_keypoints,
                frame_keypoints=frame_keypoints,
                frame=frame_image,
                floor_markings_template=template_image.copy(),
            )
        except Exception as e:
            print(f"Error evaluating keypoints for frame {frame_id}: {e}")
            frame_scores.append(0.0)
            save_visualization(frame_image, frame_keypoints, frame_id, 0.0, "_error")
            continue

        print(f'Frame {frame_id} score: {score}')
        frame_scores.append(score)
        valid_frames += 1
        save_visualization(frame_image, frame_keypoints, frame_id, score, "")

    if not frame_scores:
        return {
            "keypoint_avg_score": 0.0,
            "keypoint_valid_frames": 0,
            "keypoint_total_frames": len(results),
        }

    return {
        "keypoint_avg_score": sum(frame_scores) / len(frame_scores),
        "keypoint_max_score": max(frame_scores),
        "keypoint_min_score": min(frame_scores),
        "keypoint_valid_frames": valid_frames,
        "keypoint_total_frames": len(results),
        "keypoint_frames_above_0.5": sum(1 for s in frame_scores if s > 0.5),
        "keypoint_frames_above_0.7": sum(1 for s in frame_scores if s > 0.7),
    }
|
|
|
|
def _score_frame_sequential(
    template_keypoints: List[Tuple[int, int]],
    frame_keypoints: List[Tuple[int, int]],
    frame_image: np.ndarray,
    template_image: np.ndarray,
) -> float:
    """Score one frame with the OpenCV-CUDA evaluator, returning 0.0 if it raises."""
    try:
        return evaluate_keypoints_for_frame_opencv_cuda(
            template_keypoints=template_keypoints,
            frame_keypoints=frame_keypoints,
            frame=frame_image,
            floor_markings_template=template_image.copy(),
        )
    except Exception as e2:
        print(f"Error evaluating keypoints: {e2}")
        return 0.0


def evaluate_keypoints_batch_fast(
    results: List[TVFrameResult],
    original_frames: Dict[int, np.ndarray],
    template_image: np.ndarray,
    template_keypoints: List[Tuple[int, int]],
    batch_size: int = 32,
) -> Dict[str, float]:
    """Evaluate keypoint accuracy for many frames using batched GPU scoring.

    Eligibility and batching:
    - Frames missing from ``original_frames`` are skipped.
    - Frames with fewer than four non-(0, 0) keypoints are skipped.
    - Frames whose keypoint count differs from the template's are skipped.
    - The remaining frames are scored in chunks of ``batch_size`` via
      ``evaluate_keypoints_batch_gpu``.
    - If a chunk raises, or returns a different number of scores than frames, that
      chunk is re-scored frame by frame with the OpenCV-CUDA evaluator.

    Unlike ``evaluate_keypoints_batch``, the statistics here are computed only over
    frames with a strictly positive score. Skipped, failed and zero-score frames are
    all excluded from ``keypoint_valid_frames`` and from the averages.

    Args:
        results: List of TVFrameResult objects.
        original_frames: Dictionary mapping frame_id to frame image.
        template_image: Template image for evaluation.
        template_keypoints: Template keypoints.
        batch_size: Number of frames per GPU batch (default: 32). Must be >= 1.

    Returns:
        Dictionary with keypoint evaluation statistics. Min/max and threshold keys are
        only present when at least one frame scored above 0.

    Raises:
        ValueError: If ``batch_size`` < 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    empty_stats = {
        "keypoint_avg_score": 0.0,
        "keypoint_valid_frames": 0,
        "keypoint_total_frames": len(results),
    }

    # Collect scoreable frames, remembering each one's position in `results`.
    eligible_indices: List[int] = []
    eligible_keypoints = []
    eligible_frames = []
    for idx, result in enumerate(results):
        frame_id = result.frame_id
        if frame_id not in original_frames:
            continue
        frame_keypoints = result.keypoints
        num_detected = sum(1 for kp in frame_keypoints if kp[0] != 0.0 or kp[1] != 0.0)
        if num_detected < 4 or len(frame_keypoints) != len(template_keypoints):
            continue
        eligible_indices.append(idx)
        eligible_keypoints.append(frame_keypoints)
        eligible_frames.append(original_frames[frame_id])

    if not eligible_frames:
        return dict(empty_stats)

    frame_scores = [0.0] * len(results)
    num_batches = (len(eligible_frames) + batch_size - 1) // batch_size
    for batch_idx, start in enumerate(range(0, len(eligible_frames), batch_size)):
        stop = start + batch_size
        batch_frames = eligible_frames[start:stop]
        batch_keypoints = eligible_keypoints[start:stop]
        batch_indices = eligible_indices[start:stop]

        try:
            scores_batch = list(
                evaluate_keypoints_batch_gpu(
                    template_keypoints=template_keypoints,
                    frame_keypoints_list=batch_keypoints,
                    frames=batch_frames,
                    floor_markings_template=template_image,
                    device="cuda",
                )
            )
            # A short/long result would silently mis-assign scores in the zip below.
            if len(scores_batch) != len(batch_indices):
                raise ValueError(
                    f"expected {len(batch_indices)} scores, got {len(scores_batch)}"
                )
        except Exception as e:
            print(f"Error in batch GPU evaluation (batch {batch_idx + 1}/{num_batches}): {e}, falling back to sequential for this batch")
            scores_batch = [
                _score_frame_sequential(template_keypoints, kps, img, template_image)
                for kps, img in zip(batch_keypoints, batch_frames)
            ]

        for result_idx, score in zip(batch_indices, scores_batch):
            frame_scores[result_idx] = score

    valid_scores = [s for s in frame_scores if s > 0.0]
    if not valid_scores:
        return dict(empty_stats)

    return {
        "keypoint_avg_score": sum(valid_scores) / len(valid_scores),
        "keypoint_max_score": max(valid_scores),
        "keypoint_min_score": min(valid_scores),
        "keypoint_valid_frames": len(valid_scores),
        "keypoint_total_frames": len(results),
        "keypoint_frames_above_0.5": sum(1 for s in valid_scores if s > 0.5),
        "keypoint_frames_above_0.7": sum(1 for s in valid_scores if s > 0.7),
    }
|
|
|
|
def process_batches(
    miner: Miner,
    frames: List[np.ndarray],
    frame_ids: List[int],
    n_keypoints: int,
) -> List[TVFrameResult]:
    """Run ``miner.predict_batch`` on one batch of frames and log its latency/throughput.

    Args:
        miner: Loaded Miner instance.
        frames: Frames in this batch.
        frame_ids: Source-video index of each frame; must be the same length as
            ``frames``.
        n_keypoints: Number of keypoints to request per frame.

    Returns:
        Whatever ``predict_batch`` returns. An empty batch returns ``[]`` without
        calling the miner.

    Raises:
        ValueError: If ``frames`` and ``frame_ids`` differ in length.
    """
    if len(frames) != len(frame_ids):
        raise ValueError(
            f"frames and frame_ids length mismatch: {len(frames)} != {len(frame_ids)}"
        )
    if not frames:
        return []

    # NOTE(review): only the first id is passed as `offset`. If Miner numbers results
    # as offset + position, ids will not match source indices when --stride > 1;
    # confirm against miner3.Miner.predict_batch.
    # perf_counter is monotonic, so the measured latency can't be skewed by clock changes.
    start = time.perf_counter()
    results = miner.predict_batch(frames, offset=frame_ids[0], n_keypoints=n_keypoints)
    elapsed = time.perf_counter() - start
    print(
        f"[Batch frames {frame_ids[0]}-{frame_ids[-1]}] "
        f"predict_batch latency: {elapsed:.2f}s "
        f"({len(frames) / (elapsed + 1e-6):.2f} FPS)"
    )
    return results
|
|
|
|
def _write_annotated_batch(
    batch_results: List[TVFrameResult],
    frames: List[np.ndarray],
    writer,
    output_dir: Path,
) -> None:
    """Annotate each frame with its result and send it to the video writer and/or frame dir.

    ``writer`` is a ``cv2.VideoWriter`` or ``None``. A falsy ``output_dir`` skips the
    per-frame JPEG dump. Results and frames are paired positionally.
    """
    for res, original in zip(batch_results, frames):
        annotated = annotate_frame(original, res)
        if writer is not None:
            writer.write(annotated)
        if output_dir:
            frame_path = output_dir / f"frame_{res.frame_id:06d}.jpg"
            cv2.imwrite(str(frame_path), annotated)


def main() -> None:
    """Run the miner over a video, save annotated output, and report detection/keypoint stats.

    The video capture and writer are released even if processing raises.
    """
    args = parse_args()
    miner = Miner(args.repo_path)

    cap = cv2.VideoCapture(str(args.video_path))
    if not cap.isOpened():
        raise RuntimeError(f"Unable to open video: {args.video_path}")

    writer = None
    processed_results: List[TVFrameResult] = []
    original_frames: Dict[int, np.ndarray] = {}
    try:
        ensure_output_dir(args.output_dir)

        fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Fall back to the template shipped in the repo when no explicit path is given.
        if args.template_image:
            template_image_path = args.template_image
        else:
            template_image_path = args.repo_path / "football_pitch_template.png"

        if not template_image_path.exists():
            raise ValueError(
                f"Template image not found: {template_image_path}. "
                f"Please provide --template-image path or place football_pitch_template.png in repo path."
            )

        print(f"Loading template from {template_image_path}")
        template_image, template_keypoints = load_template_from_file(str(template_image_path))
        print(f"Loaded template with {len(template_keypoints)} keypoints")

        if args.output_video:
            args.output_video.parent.mkdir(parents=True, exist_ok=True)
            writer = cv2.VideoWriter(
                str(args.output_video),
                cv2.VideoWriter_fourcc(*"mp4v"),
                fps / args.stride,
                (width, height),
            )
            # An unopened writer silently drops every frame; surface that instead.
            if not writer.isOpened():
                print(f"Warning: Could not open video writer for {args.output_video}; annotated video will not be saved.")
                writer.release()
                writer = None

        frames_buffer: List[np.ndarray] = []
        frame_ids_buffer: List[int] = []
        processed_frames = 0
        source_frame_idx = 0

        start_time = time.perf_counter()
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if source_frame_idx % args.stride != 0:
                source_frame_idx += 1
                continue

            frames_buffer.append(frame)
            frame_ids_buffer.append(source_frame_idx)
            # Kept for the keypoint evaluation that runs after the whole video is read.
            original_frames[source_frame_idx] = frame.copy()
            processed_frames += 1
            source_frame_idx += 1

            # max_frames of None or 0 means "no limit".
            if args.max_frames and processed_frames >= args.max_frames:
                break
            if len(frames_buffer) < args.batch_size:
                continue

            batch_results = process_batches(
                miner, frames_buffer, frame_ids_buffer, args.n_keypoints
            )
            processed_results.extend(batch_results)
            _write_annotated_batch(batch_results, frames_buffer, writer, args.output_dir)
            frames_buffer.clear()
            frame_ids_buffer.clear()

        # Flush the final partial batch (also reached when --max-frames stops early).
        if frames_buffer:
            batch_results = process_batches(
                miner, frames_buffer, frame_ids_buffer, args.n_keypoints
            )
            processed_results.extend(batch_results)
            _write_annotated_batch(batch_results, frames_buffer, writer, args.output_dir)
    finally:
        cap.release()
        if writer is not None:
            writer.release()

    stats = aggregate_stats(processed_results)

    end_time = time.perf_counter()
    print(f"Total time taken: {end_time - start_time:.2f} seconds")

    time_start = time.perf_counter()
    print("\n===== Evaluating Keypoints =====")
    keypoint_stats = evaluate_keypoints_batch(
        processed_results,
        original_frames,
        template_image,
        template_keypoints,
        visualization_output_dir=args.visualize_keypoints,
    )
    time_end = time.perf_counter()
    print(f"Keypoint evaluation time: {time_end - time_start:.2f} seconds")

    print("\n===== Summary =====")
    for key, value in stats.items():
        print(f"{key}: {value}")
    if stats["frames"]:
        avg_boxes = stats["boxes"] / stats["frames"]
        print(f"Average boxes per frame: {avg_boxes:.2f}")

    print("\n===== Keypoint Evaluation =====")
    for key, value in keypoint_stats.items():
        print(f"{key}: {value}")
    if keypoint_stats["keypoint_total_frames"] > 0:
        valid_ratio = keypoint_stats["keypoint_valid_frames"] / keypoint_stats["keypoint_total_frames"]
        print(f"Keypoint evaluation success rate: {valid_ratio:.2%}")

    print("Done.")


if __name__ == "__main__":
    main()
|
|
|
|
|
|