Spaces:
Sleeping
Sleeping
| """ | |
| TennisVision - AI Ball Tracker | |
| Gradio application for tennis ball detection and tracking. | |
| """ | |
| import gradio as gr | |
| import cv2 | |
| import numpy as np | |
| import os | |
| import tempfile | |
| from pathlib import Path | |
| from typing import Tuple, Optional | |
| from detector import BallDetector | |
| from tracker import BallTracker | |
| from utils import ( | |
| VideoReader, | |
| VideoWriter, | |
| export_trajectory_csv, | |
| validate_video_file, | |
| create_output_directory, | |
| draw_detection, | |
| draw_trajectory_trail, | |
| draw_speed_label, | |
| draw_info_panel, | |
| create_trajectory_plot | |
| ) | |
def _annotate_frame(
    frame,
    frame_number: int,
    frame_count: int,
    fps: float,
    detector: BallDetector,
    tracker: BallTracker,
    trail_length: int = 20,
):
    """Detect the ball in one frame, update the tracker and draw all overlays.

    Args:
        frame: Frame as yielded by ``VideoReader.read_frames`` (presumably an
            OpenCV image array -- TODO confirm).
        frame_number: 1-based index of this frame, shown in the info panel.
        frame_count: Frame count reported by the container metadata.
        fps: Video frame rate, used for the speed label and info panel.
        detector: Ball detector run on this frame.
        tracker: Tracker fed the best detection, or ``None`` on a miss.
        trail_length: Number of recent tracked positions drawn as a trail.

    Returns:
        ``(annotated_frame, detected)`` where ``detected`` is True when the
        detector returned at least one candidate for this frame.
    """
    detections = detector.detect(frame)
    # len() rather than truthiness so this also works if detect() returns a
    # NumPy array, whose truth value is ambiguous -- TODO confirm its type.
    detected = len(detections) > 0

    if detected:
        # detections[0] is treated as the highest-confidence candidate;
        # assumes detect() sorts its results by confidence.
        best_detection = detections[0]
        cx, cy = detector.get_ball_center(best_detection)
        state = tracker.update((cx, cy))
        frame = draw_detection(frame, best_detection)
        confidence = best_detection[4]
    else:
        # No measurement: let the tracker coast on its prediction.
        state = tracker.update(None)
        confidence = None

    # Trail and speed are only meaningful once the tracker has a state.
    if state is not None and tracker.is_initialized():
        x, y, _vx, _vy = state
        frame = draw_trajectory_trail(frame, tracker.get_last_n_positions(trail_length))
        speed = tracker.get_speed(state)
        frame = draw_speed_label(frame, (x, y), speed, fps)

    frame = draw_info_panel(frame, frame_number, frame_count, fps, confidence)
    return frame, detected


def _build_status_message(
    frames_processed: int,
    fps: float,
    width: int,
    height: int,
    detection_count: int,
    trajectory_points: int,
    has_csv: bool,
    has_plot: bool,
) -> str:
    """Render the Markdown summary shown after a successful run.

    Sections are separated by blank lines so each bold heading starts a new
    Markdown paragraph instead of being folded into the preceding list item.
    The "Outputs" section lists only the files that were actually produced.
    """
    detection_pct = (
        100 * detection_count / frames_processed if frames_processed > 0 else 0.0
    )

    outputs = ["- Processed video with overlays"]
    if has_csv:
        outputs.append(f"- Trajectory CSV with {trajectory_points} data points")
    if has_plot:
        outputs.append("- 2D trajectory plot color-coded by speed")

    sections = [
        "โ **Processing Complete!**",
        "\n".join([
            "**Video Info:**",
            f"- Total Frames: {frames_processed}",
            f"- Frame Rate: {fps:.1f} FPS",
            f"- Resolution: {width}x{height}",
        ]),
        "\n".join([
            "**Tracking Results:**",
            f"- Ball Detected: {detection_count} frames ({detection_pct:.1f}%)",
            f"- Trajectory Points: {trajectory_points}",
        ]),
        "\n".join(["**Outputs:**", *outputs]),
    ]
    return "\n\n".join(sections) + "\n"


def process_video(
    video_path: str,
    model_name: str,
    confidence_threshold: float,
    progress=gr.Progress()
) -> Tuple[Optional[str], Optional[str], Optional[str], str]:
    """
    Process a video to track the tennis ball.

    Every call writes its outputs to a fresh sub-directory of ``output/``,
    so concurrent requests cannot overwrite each other's CSV or plot. That
    sub-directory is removed again if the run fails or finds no ball.

    Args:
        video_path: Path to input video file
        model_name: Detection model identifier
        confidence_threshold: Minimum detection confidence
        progress: Gradio progress tracker

    Returns:
        Tuple of (output_video_path, csv_path, plot_path, status_message).
        On failure the three paths are None and status_message explains why.
        csv_path / plot_path are individually None if that export failed.
    """
    import shutil
    import traceback

    run_dir: Optional[Path] = None
    try:
        # Validate input video
        is_valid, msg = validate_video_file(video_path)
        if not is_valid:
            return None, None, None, f"โ Error: {msg}"

        progress(0, desc="Initializing models...")
        detector = BallDetector(
            model_name=model_name,
            confidence_threshold=confidence_threshold
        )

        # Read video properties
        with VideoReader(video_path) as reader:
            video_props = reader.get_properties()
        fps = video_props['fps']
        frame_count = video_props['frame_count']
        width = video_props['width']
        height = video_props['height']

        # fps feeds 1/fps (tracker time step) and the speed label, so a
        # missing/zero rate must be rejected explicitly.
        if not fps or fps <= 0:
            raise ValueError("Could not determine the video frame rate")

        tracker = BallTracker(dt=1.0 / fps, max_missing_frames=int(fps * 0.5))

        # One private directory per run: fixed file names in a shared folder
        # would let simultaneous users clobber each other's results.
        output_root = create_output_directory("output")
        run_dir = Path(tempfile.mkdtemp(prefix="run_", dir=output_root))
        output_video_path = str(run_dir / "tracked_video.mp4")
        csv_path: Optional[Path] = run_dir / "trajectory.csv"
        plot_path: Optional[Path] = run_dir / "trajectory_plot.png"

        progress(0.1, desc="Processing frames...")
        detection_count = 0
        frames_processed = 0
        with VideoReader(video_path) as reader, \
                VideoWriter(output_video_path, fps, width, height) as writer:
            for frame_num, frame in reader.read_frames():
                if frame_count > 0:
                    # Container frame counts are estimates; clamp so the bar
                    # never overshoots the 0.1-0.8 band reserved for this loop.
                    fraction = min(frame_num / frame_count, 1.0)
                    progress(
                        0.1 + 0.7 * fraction,
                        desc=f"Processing frame {frame_num + 1}/{frame_count}"
                    )
                else:
                    # Unknown length (metadata reported 0 or less).
                    progress(0.1, desc=f"Processing frame {frame_num + 1}")

                frame, detected = _annotate_frame(
                    frame, frame_num + 1, frame_count, fps, detector, tracker
                )
                if detected:
                    detection_count += 1
                frames_processed += 1
                writer.write_frame(frame)

        # Export trajectory data
        progress(0.8, desc="Exporting trajectory data...")
        trajectory = tracker.get_trajectory()
        if len(trajectory) == 0:
            # Nothing is returned to the user, so don't leave the video behind.
            shutil.rmtree(run_dir, ignore_errors=True)
            return None, None, None, "โ No ball detected in video. Try lowering the confidence threshold."

        if not export_trajectory_csv(trajectory, fps, str(csv_path)):
            csv_path = None

        progress(0.9, desc="Creating trajectory plot...")
        try:
            create_trajectory_plot(trajectory, fps, str(plot_path))
        except Exception as e:
            # The plot is optional: report it and keep the other outputs.
            print(f"Failed to create plot: {e}")
            plot_path = None

        progress(1.0, desc="Complete!")
        status = _build_status_message(
            frames_processed=frames_processed,
            fps=fps,
            width=width,
            height=height,
            detection_count=detection_count,
            trajectory_points=len(trajectory),
            has_csv=csv_path is not None,
            has_plot=plot_path is not None,
        )
        return (
            output_video_path,
            str(csv_path) if csv_path else None,
            str(plot_path) if plot_path else None,
            status
        )

    except Exception as e:
        # Top-level UI boundary: report the error instead of crashing the app.
        error_msg = f"โ **Error during processing:** {str(e)}"
        print(error_msg)
        traceback.print_exc()
        if run_dir is not None:
            shutil.rmtree(run_dir, ignore_errors=True)
        return None, None, None, error_msg
| # Create Gradio interface | |
def create_interface():
    """Build the TennisVision Gradio UI and wire it to ``process_video``.

    The layout is an intro banner, then a row with two columns (inputs and
    settings on the left, result tabs on the right), then an "About" footer.
    Clicking the run button calls ``process_video``. A follow-up step then
    copies the rendered video's value into the MP4 download widget.

    Returns:
        The configured, not yet launched, ``gr.Blocks`` application.
    """
    header_text = """
# ๐พ TennisVision - AI Ball Tracker
Upload a tennis video to automatically detect and track the ball using
state-of-the-art computer vision models.
**Features:**
- Real-time ball detection with YOLOv8
- Smooth trajectory tracking with Kalman filter
- Speed estimation and visualization
- Downloadable outputs (video, CSV, plot)
"""
    tips_text = """
### ๐ก Tips
- Use short clips (5-15 seconds) for faster processing
- Ensure the ball is visible and in motion
- Lower confidence threshold if ball is not detected
- YOLOv8n provides fastest inference (~30 FPS)
"""
    about_text = """
---
### ๐ About
**TennisVision** uses YOLOv8 for ball detection and Kalman filtering
for smooth trajectory tracking. The system estimates ball speed and
visualizes the complete trajectory with color-coded speed indicators.
**Model:** YOLOv8 (Ultralytics)
**Tracking:** Kalman Filter
**Framework:** Gradio + OpenCV
Built for deployment on Hugging Face Spaces ๐ค
"""

    with gr.Blocks(title="TennisVision - AI Ball Tracker", theme=gr.themes.Soft()) as app:
        gr.Markdown(header_text)

        with gr.Row():
            # Left column: upload widget and detection settings.
            with gr.Column(scale=1):
                gr.Markdown("### โ๏ธ Input & Settings")
                video_input = gr.Video(label="Upload Tennis Video", sources=["upload"])
                model_dropdown = gr.Dropdown(
                    label="Detection Model",
                    choices=["yolov8n", "yolov8s", "yolov8m"],
                    value="yolov8n",
                    info="yolov8n is fastest, yolov8m is most accurate",
                )
                confidence_slider = gr.Slider(
                    label="Confidence Threshold",
                    minimum=0.1,
                    maximum=0.9,
                    step=0.05,
                    value=0.3,
                    info="Lower = more detections (may include false positives)",
                )
                process_btn = gr.Button("๐ Run Tracking", variant="primary", size="lg")
                gr.Markdown(tips_text)

            # Right column: status text plus tabbed results.
            with gr.Column(scale=2):
                gr.Markdown("### ๐ Results")
                status_output = gr.Markdown("Upload a video and click **Run Tracking** to begin.")
                with gr.Tabs():
                    with gr.Tab("๐น Processed Video"):
                        video_output = gr.Video(label="Tracked Video", show_label=False)
                    with gr.Tab("๐ Trajectory Plot"):
                        plot_output = gr.Image(label="2D Trajectory", show_label=False)
                    with gr.Tab("๐ฅ Downloads"):
                        gr.Markdown("### Download Files")
                        csv_output = gr.File(label="Trajectory Data (CSV)")
                        video_download = gr.File(label="Processed Video (MP4)")

        # Run the pipeline, then mirror the rendered video into the download
        # slot. The identity lambda is kept as a lambda on purpose.
        tracking_event = process_btn.click(
            fn=process_video,
            inputs=[video_input, model_dropdown, confidence_slider],
            outputs=[video_output, csv_output, plot_output, status_output],
        )
        tracking_event.then(fn=lambda x: x, inputs=[video_output], outputs=[video_download])

        gr.Markdown(about_text)

    return app
if __name__ == "__main__":
    # Ensure the folder that process_video writes into exists up front.
    create_output_directory("output")

    # Listen on every interface at port 7860, without a public share link.
    launch_options = {
        "share": False,
        "server_name": "0.0.0.0",
        "server_port": 7860,
    }
    app = create_interface()
    app.launch(**launch_options)