Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Visualization script for validating play detections. | |
| This script generates visualizations of detected plays: | |
| 1. Video clips of each detected play with overlay | |
| 2. Summary statistics and comparison with ground truth (if available) | |
| 3. Timeline visualization of detected plays | |
| Usage: | |
| # Visualize results from detection | |
| python scripts/visualize_detections.py output/detected_plays_quick.json | |
| # Compare with ground truth | |
| python scripts/visualize_detections.py output/detected_plays_extended.json --ground-truth tests/test_data/ground_truth_plays.json | |
| # Generate video clips for each play | |
| python scripts/visualize_detections.py output/detected_plays_quick.json --generate-clips | |
| """ | |
| import argparse | |
| import json | |
| import logging | |
| import sys | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import List, Dict, Any, Optional | |
| import cv2 | |
| import numpy as np | |
# Configure root logging once for the whole script: timestamped, INFO and above.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# Default paths. The project root is taken to be three directory levels above
# this file (original note: "scripts/archive/ -> project root").
# NOTE(review): the usage examples in the module docstring invoke
# scripts/visualize_detections.py, which would put the root only two levels
# up - confirm where this script actually lives.
PROJECT_ROOT = Path(__file__).parent.parent.parent
DEFAULT_VIDEO_PATH = PROJECT_ROOT.joinpath("full_videos", "OSU vs Tenn 12.21.24.mkv")
OUTPUT_DIR = PROJECT_ROOT.joinpath("output")
@dataclass
class PlayComparison:
    """Comparison between one detected play and its ground truth play, if any.

    The class must be a dataclass: it is instantiated with keyword arguments,
    which a plain class carrying only annotations would reject.

    Attributes:
        detected_play: The detected play record being compared.
        ground_truth_play: The ground truth record it was matched to, or None
            when no ground truth play matched.
        time_diff_start: Detected start minus ground truth start, or None when
            there is no match.
        time_diff_end: Detected end minus ground truth end, or None when there
            is no match.
        is_match: True when a ground truth play was matched.
    """

    detected_play: Dict[str, Any]
    ground_truth_play: Optional[Dict[str, Any]]
    time_diff_start: Optional[float]
    time_diff_end: Optional[float]
    is_match: bool
def load_results(results_path: str) -> Dict[str, Any]:
    """Read the detection results JSON file and return its parsed content.

    Args:
        results_path: Path to the UTF-8 encoded JSON results file.

    Returns:
        The decoded JSON document.
    """
    raw_text = Path(results_path).read_text(encoding="utf-8")
    return json.loads(raw_text)
def load_ground_truth(ground_truth_path: str) -> Optional[List[Dict[str, Any]]]:
    """Load ground truth plays from a JSON file, if the file exists.

    Two layouts are accepted: a top-level list of plays, or a top-level dict
    holding the plays under the "plays" key.

    Args:
        ground_truth_path: Path to the ground truth JSON file.

    Returns:
        The plays, or None when the file is missing or uses another layout.
    """
    gt_file = Path(ground_truth_path)
    if not gt_file.exists():
        return None

    with gt_file.open("r", encoding="utf-8") as handle:
        payload = json.load(handle)

    if isinstance(payload, list):
        return payload
    if isinstance(payload, dict):
        # .get() yields None when "plays" is absent, i.e. the unsupported-layout result.
        return payload.get("plays")
    return None
def compare_with_ground_truth(detected_plays: List[Dict], ground_truth: List[Dict], tolerance: float = 2.0) -> List[PlayComparison]:
    """
    Pair each detected play with the closest unused ground truth play.

    Matching is greedy and follows the order of ``detected_plays``: each
    detection claims the still-unclaimed ground truth play whose start time is
    nearest, provided the gap is strictly below ``tolerance``. Ground truth
    start/end times are read from "start_time"/"end_time", falling back to
    "start"/"end", then 0.

    Args:
        detected_plays: List of detected plays
        ground_truth: List of ground truth plays
        tolerance: Time tolerance in seconds for matching

    Returns:
        List of PlayComparison objects, one per detected play, in input order
    """

    def _gt_start(entry: Dict) -> Any:
        return entry.get("start_time", entry.get("start", 0))

    outcome = []
    claimed = set()

    for det in detected_plays:
        # (start gap, index) for every ground truth play nobody has claimed yet.
        gaps = [
            (abs(_gt_start(gt) - det.get("start_time", 0)), idx)
            for idx, gt in enumerate(ground_truth)
            if idx not in claimed
        ]
        within = [pair for pair in gaps if pair[0] < tolerance]

        if not within:
            outcome.append(
                PlayComparison(detected_play=det, ground_truth_play=None, time_diff_start=None, time_diff_end=None, is_match=False)
            )
            continue

        # Tuple ordering picks the smallest gap; equal gaps resolve to the lowest index.
        _, winner_idx = min(within)
        claimed.add(winner_idx)
        winner = ground_truth[winner_idx]
        truth_start = _gt_start(winner)
        truth_end = winner.get("end_time", winner.get("end", 0))
        outcome.append(
            PlayComparison(
                detected_play=det,
                ground_truth_play=winner,
                time_diff_start=det.get("start_time", 0) - truth_start,
                time_diff_end=det.get("end_time", 0) - truth_end,
                is_match=True,
            )
        )

    return outcome
def print_summary(results: Dict[str, Any], comparisons: Optional[List[PlayComparison]] = None) -> None:
    """Log an overview of the detection results.

    Reports the video, segment bounds, frame counters, play count, duration
    statistics and per-method counts. When ``comparisons`` is non-empty, the
    ground truth match counts and start/end timing errors are logged as well.

    Args:
        results: Detection results dictionary.
        comparisons: Optional comparisons against ground truth.
    """
    heavy_rule = "=" * 60
    light_rule = "-" * 60
    detected = results.get("plays", [])

    logger.info(heavy_rule)
    logger.info("DETECTION SUMMARY")
    logger.info(heavy_rule)
    logger.info("Video: %s", results.get("video", "unknown"))

    bounds = results.get("segment", {})
    logger.info("Segment: %.1fs - %.1fs", bounds.get("start", 0), bounds.get("end", 0))

    counters = results.get("processing", {})
    for template, key in (
        ("Frames processed: %d", "total_frames"),
        ("Frames with scorebug: %d", "frames_with_scorebug"),
        ("Frames with clock: %d", "frames_with_clock"),
    ):
        logger.info(template, counters.get(key, 0))

    logger.info(light_rule)
    logger.info("Plays detected: %d", len(detected))

    if detected:
        lengths = []
        for entry in detected:
            # A stored "duration" wins; otherwise derive it from the endpoints.
            derived = entry.get("end_time", 0) - entry.get("start_time", 0)
            lengths.append(entry.get("duration", derived))
        logger.info("Duration stats: avg=%.1fs, min=%.1fs, max=%.1fs", sum(lengths) / len(lengths), min(lengths), max(lengths))

        # Tally how often each start/end detection method was used.
        start_tally = {}
        end_tally = {}
        for entry in detected:
            for tally, key in ((start_tally, "start_method"), (end_tally, "end_method")):
                label = entry.get(key, "unknown")
                tally[label] = tally.get(label, 0) + 1
        logger.info("Start methods: %s", start_tally)
        logger.info("End methods: %s", end_tally)

    if comparisons:
        logger.info(light_rule)
        logger.info("GROUND TRUTH COMPARISON")
        logger.info(light_rule)

        matched = [c for c in comparisons if c.is_match]
        logger.info("Matched plays: %d", len(matched))
        logger.info("False positives: %d", len(comparisons) - len(matched))

        if matched:
            start_errors = [abs(c.time_diff_start) for c in matched if c.time_diff_start is not None]
            end_errors = [abs(c.time_diff_end) for c in matched if c.time_diff_end is not None]
            if start_errors:
                logger.info("Start time error: avg=%.2fs, max=%.2fs", sum(start_errors) / len(start_errors), max(start_errors))
            if end_errors:
                logger.info("End time error: avg=%.2fs, max=%.2fs", sum(end_errors) / len(end_errors), max(end_errors))

    logger.info(heavy_rule)
def print_plays_table(plays: List[Dict[str, Any]]) -> None:
    """Log the detected plays as a fixed-width table, one row per play.

    Args:
        plays: Detected play records; missing fields fall back to 0 or
            "unknown", and a missing "duration" is derived from the endpoints.
    """
    rule = "-" * 80
    row_format = "%-5d %-10.1f %-10.1f %-8.1f %-12s %-12s"

    logger.info("")
    logger.info("DETECTED PLAYS")
    logger.info(rule)
    logger.info("%-5s %-10s %-10s %-8s %-12s %-12s", "#", "Start", "End", "Duration", "Start Method", "End Method")
    logger.info(rule)

    for entry in plays:
        row = (
            entry.get("play_number", 0),
            entry.get("start_time", 0),
            entry.get("end_time", 0),
            entry.get("duration", entry.get("end_time", 0) - entry.get("start_time", 0)),
            entry.get("start_method", "unknown"),
            entry.get("end_method", "unknown"),
        )
        logger.info(row_format, *row)

    logger.info(rule)
def create_timeline_image(plays: List[Dict], segment_start: float, segment_end: float, output_path: str) -> None:
    """
    Create a timeline visualization of detected plays.

    Draws a horizontal axis with a tick every 30 seconds and one green bar per
    play, then writes the image to ``output_path``.

    When the segment bounds describe no positive duration (for example when
    the results file has no "segment" entry and both bounds default to 0), the
    bounds are inferred from the plays' own start/end times. If that still
    yields no positive duration, only the axis, title and legend are drawn
    instead of failing on a division by zero.

    Args:
        plays: List of detected plays
        segment_start: Segment start time
        segment_end: Segment end time
        output_path: Path to save the image
    """
    # Image dimensions
    width = 1200
    height = 200
    margin = 50
    axis_color = (200, 200, 200)
    play_color = (0, 255, 0)

    # Create image with a dark gray background
    image = np.zeros((height, width, 3), dtype=np.uint8)
    image[:] = (40, 40, 40)

    # Timeline geometry
    timeline_y = height // 2
    timeline_start_x = margin
    timeline_end_x = width - margin
    timeline_width = timeline_end_x - timeline_start_x

    # Draw timeline axis
    cv2.line(image, (timeline_start_x, timeline_y), (timeline_end_x, timeline_y), axis_color, 2)

    segment_duration = segment_end - segment_start
    if segment_duration <= 0 and plays:
        # No usable segment bounds were supplied: fall back to the span the plays cover.
        segment_start = min(play.get("start_time", 0) for play in plays)
        segment_end = max(play.get("end_time", 0) for play in plays)
        segment_duration = segment_end - segment_start

    if segment_duration > 0:

        def time_to_x(offset: float) -> int:
            """Map an offset from the segment start to an x pixel coordinate."""
            return timeline_start_x + int(offset / segment_duration * timeline_width)

        # Draw time markers every 30 seconds
        for seconds in range(0, int(segment_duration) + 1, 30):
            x = time_to_x(seconds)
            cv2.line(image, (x, timeline_y - 5), (x, timeline_y + 5), axis_color, 1)
            mins = int((segment_start + seconds) // 60)
            secs = int((segment_start + seconds) % 60)
            time_label = "%d:%02d" % (mins, secs)
            cv2.putText(image, time_label, (x - 15, timeline_y + 25), cv2.FONT_HERSHEY_SIMPLEX, 0.4, axis_color, 1)

        # Draw plays
        for play in plays:
            start_x = time_to_x(play.get("start_time", 0) - segment_start)
            end_x = time_to_x(play.get("end_time", 0) - segment_start)
            # Play bar sits just above the axis, with its number above the bar
            cv2.rectangle(image, (start_x, timeline_y - 20), (end_x, timeline_y - 5), play_color, -1)
            play_num = play.get("play_number", 0)
            cv2.putText(image, str(play_num), (start_x, timeline_y - 25), cv2.FONT_HERSHEY_SIMPLEX, 0.4, play_color, 1)
    else:
        logger.warning("Segment has no positive duration (%s - %s); timeline drawn without markers or plays", segment_start, segment_end)

    # Add title
    cv2.putText(image, "Play Detection Timeline", (width // 2 - 100, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)

    # Add legend
    cv2.rectangle(image, (width - 150, 10), (width - 130, 25), play_color, -1)
    cv2.putText(image, "Detected Play", (width - 125, 22), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1)

    # Save image; imwrite reports failure through its return value
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    if cv2.imwrite(output_path, image):
        logger.info("Timeline saved to: %s", output_path)
    else:
        logger.error("Failed to write timeline image: %s", output_path)
def generate_play_clips_ffmpeg(results: Dict[str, Any], video_path: str, output_dir: str, padding: float = 2.0) -> Dict[str, float]:
    """
    Generate video clips for each detected play using ffmpeg (much faster than OpenCV).

    One ``play_NN.mp4`` is written per play. When more than one clip was
    written successfully, they are concatenated into ``all_plays.mp4``. Clips
    whose extraction failed are left out of the concatenation, so a single bad
    play does not break the combined video.

    Args:
        results: Detection results
        video_path: Path to source video
        output_dir: Directory to save clips
        padding: Seconds of padding before/after play

    Returns:
        Dictionary with timing information: seconds spent under the
        "clip_extraction" and "concatenation" keys (0.0 for a phase that did
        not run).
    """
    import shutil
    import subprocess
    import time

    def _stderr_text(err):
        """Return ffmpeg's stderr as text, or the exception message if stderr is empty."""
        # errors="replace": ffmpeg may echo file names that are not valid UTF-8.
        return err.stderr.decode(errors="replace") if err.stderr else str(err)

    timing = {"clip_extraction": 0.0, "concatenation": 0.0}
    plays = results.get("plays", [])
    if not plays:
        logger.warning("No plays to generate clips for")
        return timing

    # Without this check subprocess.run would raise an uncaught FileNotFoundError.
    if shutil.which("ffmpeg") is None:
        logger.error("ffmpeg executable not found on PATH; cannot generate clips")
        return timing

    # Create output directory
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    logger.info("Generating %d play clips with ffmpeg...", len(plays))

    clip_paths = []  # only clips that ffmpeg actually produced
    t_start = time.perf_counter()
    for play in plays:
        play_num = play.get("play_number", 0)
        start_time = max(0, play.get("start_time", 0) - padding)
        end_time = play.get("end_time", 0) + padding
        duration = end_time - start_time
        clip_path = output_path / ("play_%02d.mp4" % play_num)

        # -ss before -i for fast seeking, -t for duration
        cmd = [
            "ffmpeg",
            "-y",  # Overwrite output
            "-ss",
            str(start_time),
            "-i",
            video_path,
            "-t",
            str(duration),
            "-c:v",
            "libx264",  # Re-encode for compatibility
            "-preset",
            "fast",
            "-crf",
            "23",
            "-c:a",
            "aac",
            "-b:a",
            "128k",
            "-loglevel",
            "error",
            str(clip_path),
        ]
        try:
            subprocess.run(cmd, check=True, capture_output=True)
        except subprocess.CalledProcessError as e:
            logger.error(" Failed to create %s: %s", clip_path.name, _stderr_text(e))
        else:
            clip_paths.append(clip_path)
            logger.info(" Created: %s (%.1fs - %.1fs, duration: %.1fs)", clip_path.name, start_time, end_time, duration)

    timing["clip_extraction"] = time.perf_counter() - t_start
    logger.info("Clip extraction complete! (%.2fs)", timing["clip_extraction"])

    # Concatenate all successfully created clips into a single highlight video
    if len(clip_paths) > 1:
        t_start = time.perf_counter()
        concat_path = output_path / "all_plays.mp4"
        logger.info("Concatenating %d clips into %s...", len(clip_paths), concat_path.name)
        concat_list_path = output_path / "concat_list.txt"
        try:
            # The list holds bare file names; ffmpeg runs with cwd=output_path so they resolve.
            with open(concat_list_path, "w", encoding="utf-8") as f:
                for clip_path in clip_paths:
                    f.write("file '%s'\n" % clip_path.name)
            # Use ffmpeg concat demuxer
            cmd = [
                "ffmpeg",
                "-y",
                "-f",
                "concat",
                "-safe",
                "0",
                "-i",
                str(concat_list_path),
                "-c",
                "copy",  # No re-encoding needed
                "-loglevel",
                "error",
                str(concat_path),
            ]
            subprocess.run(cmd, check=True, capture_output=True, cwd=str(output_path))
            logger.info(" Created: %s", concat_path.name)
        except subprocess.CalledProcessError as e:
            logger.error(" Failed to concatenate: %s", _stderr_text(e))
        finally:
            # Always remove the temporary list, even if ffmpeg or the write failed
            concat_list_path.unlink(missing_ok=True)
        timing["concatenation"] = time.perf_counter() - t_start
        logger.info("Concatenation complete! (%.2fs)", timing["concatenation"])

    return timing
def generate_play_clips(results: Dict[str, Any], video_path: str, output_dir: str, padding: float = 2.0) -> None:
    """
    Generate video clips for each detected play (legacy OpenCV version - slow).

    Each clip covers the play plus ``padding`` seconds on either side (the
    start is clamped at 0) and carries a text overlay with the play number,
    the current time, and an "IN PLAY" marker while inside the play itself.

    Args:
        results: Detection results
        video_path: Path to source video
        output_dir: Directory to save clips
        padding: Seconds of padding before/after play
    """
    plays = results.get("plays", [])
    if not plays:
        logger.warning("No plays to generate clips for")
        return

    # Open video
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        logger.error("Could not open video: %s", video_path)
        return

    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        if fps <= 0:
            # Frame indices and overlay times are derived by dividing by fps.
            logger.error("Could not determine frame rate for video: %s", video_path)
            return
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Create output directory
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        logger.info("Generating %d play clips...", len(plays))

        for play in plays:
            play_num = play.get("play_number", 0)
            play_start = play.get("start_time", 0)
            play_end = play.get("end_time", 0)
            # Clamp at 0 (as the ffmpeg variant does) so the seek frame is never negative
            start_time = max(0, play_start - padding)
            end_time = play_end + padding

            # Create output file
            clip_path = output_path / ("play_%02d.mp4" % play_num)
            fourcc = cv2.VideoWriter_fourcc(*"mp4v")  # pylint: disable=no-member
            out = cv2.VideoWriter(str(clip_path), fourcc, fps, (frame_width, frame_height))
            try:
                # Seek to start
                start_frame = int(start_time * fps)
                end_frame = int(end_time * fps)
                cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)

                # Write frames
                current_frame = start_frame
                while current_frame < end_frame:
                    ret, frame = cap.read()
                    if not ret:
                        break
                    current_time = current_frame / fps

                    # Green overlay inside the play, gray inside the padding
                    in_play = play_start <= current_time <= play_end
                    color = (0, 255, 0) if in_play else (128, 128, 128)
                    cv2.putText(frame, "Play #%d" % play_num, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1.0, color, 2)
                    cv2.putText(frame, "Time: %.1fs" % current_time, (10, 70), cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)
                    if in_play:
                        elapsed = current_time - play_start
                        cv2.putText(frame, "IN PLAY (%.1fs)" % elapsed, (10, 110), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

                    out.write(frame)
                    current_frame += 1
            finally:
                # Release the writer even if reading or drawing raised
                out.release()
            logger.info(" Created: %s (%.1fs - %.1fs)", clip_path.name, start_time, end_time)
    finally:
        cap.release()

    logger.info("Clip generation complete!")
def print_timing_summary(results: Dict[str, Any], clip_timing: Optional[Dict[str, float]] = None) -> None:
    """Log per-section timings for detection and clip generation.

    Nothing is logged when neither ``results["timing"]`` nor ``clip_timing``
    holds any entries.

    Args:
        results: Detection results; timings are read from its "timing" key.
        clip_timing: Optional timings returned by clip generation.
    """
    detection_timing = results.get("timing", {})
    if not (detection_timing or clip_timing):
        return

    def _log_phase(heading: str, sections: Dict[str, float], total_template: str) -> None:
        # One line per section, followed by the running total for the phase.
        logger.info(heading)
        running_total = 0.0
        for label, seconds in sections.items():
            logger.info(" %s: %.2fs", label, seconds)
            running_total += seconds
        logger.info(total_template, running_total)

    banner = "=" * 60
    logger.info("")
    logger.info(banner)
    logger.info("TIMING BREAKDOWN")
    logger.info(banner)

    if detection_timing:
        _log_phase("Detection Phase:", detection_timing, " DETECTION TOTAL: %.2fs")
    if clip_timing:
        _log_phase("Clip Generation Phase:", clip_timing, " CLIP TOTAL: %.2fs")

    logger.info(banner)
def main():
    """Command-line entry point.

    Loads the results file, optionally compares it with ground truth, logs the
    summary and plays table, writes the timeline image and, on request,
    generates the per-play clips.

    Returns:
        Process exit code: 0 on success, 1 when the results file or the video
        needed for clip generation is missing.
    """
    parser = argparse.ArgumentParser(description="Visualize play detection results")
    argument_specs = (
        (("results_file",), {"help": "Path to detection results JSON file"}),
        (("--ground-truth",), {"type": str, "help": "Path to ground truth JSON file"}),
        (("--video",), {"type": str, "help": "Path to video file (for clip generation)"}),
        (("--generate-clips",), {"action": "store_true", "help": "Generate video clips for each play"}),
        (("--use-opencv",), {"action": "store_true", "help": "Use OpenCV instead of ffmpeg for clip generation (slower)"}),
        (("--padding",), {"type": float, "default": 2.0, "help": "Seconds of padding before/after each play (default: 2.0)"}),
        (("--output-dir",), {"type": str, "help": "Output directory for visualizations"}),
    )
    for flags, options in argument_specs:
        parser.add_argument(*flags, **options)
    args = parser.parse_args()

    # Load results
    results_file = Path(args.results_file)
    if not results_file.exists():
        logger.error("Results file not found: %s", results_file)
        return 1
    logger.info("Loading results from: %s", results_file)
    results = load_results(str(results_file))

    # Load ground truth if provided
    comparisons = None
    if args.ground_truth:
        truth_file = Path(args.ground_truth)
        if not truth_file.exists():
            logger.warning("Ground truth file not found: %s", truth_file)
        else:
            logger.info("Loading ground truth from: %s", truth_file)
            truth_plays = load_ground_truth(str(truth_file))
            if truth_plays:
                comparisons = compare_with_ground_truth(results.get("plays", []), truth_plays)

    # Print summary
    print_summary(results, comparisons)
    print_plays_table(results.get("plays", []))

    # Create timeline image
    output_dir = args.output_dir if args.output_dir else str(OUTPUT_DIR)
    bounds = results.get("segment", {})
    timeline_file = str(Path(output_dir) / "timeline.png")
    create_timeline_image(results.get("plays", []), bounds.get("start", 0), bounds.get("end", 0), timeline_file)

    # Without --generate-clips only the detection timings remain to be reported
    if not args.generate_clips:
        print_timing_summary(results, None)
        return 0

    source_video = args.video if args.video else str(DEFAULT_VIDEO_PATH)
    if not Path(source_video).exists():
        logger.error("Video not found: %s", source_video)
        return 1

    clips_dir = str(Path(output_dir) / "clips")
    clip_timing = None
    if args.use_opencv:
        # The OpenCV path reports no timings
        generate_play_clips(results, source_video, clips_dir, padding=args.padding)
    else:
        clip_timing = generate_play_clips_ffmpeg(results, source_video, clips_dir, padding=args.padding)

    # Print timing summary
    print_timing_summary(results, clip_timing)
    return 0


if __name__ == "__main__":
    # Same effect as sys.exit(main()): exit with main()'s return code
    raise SystemExit(main())