# tennisvision / app.py
# Author: Onur Çopur
# first commit (3b90d9c)
"""
TennisVision - AI Ball Tracker
Gradio application for tennis ball detection and tracking.
"""
import gradio as gr
import cv2
import numpy as np
import os
import tempfile
from pathlib import Path
from typing import Tuple, Optional
from detector import BallDetector
from tracker import BallTracker
from utils import (
VideoReader,
VideoWriter,
export_trajectory_csv,
validate_video_file,
create_output_directory,
draw_detection,
draw_trajectory_trail,
draw_speed_label,
draw_info_panel,
create_trajectory_plot
)
def process_video(
    video_path: str,
    model_name: str,
    confidence_threshold: float,
    progress=gr.Progress()
) -> Tuple[Optional[str], Optional[str], Optional[str], str]:
    """
    Process a video to detect and track the tennis ball.

    Args:
        video_path: Path to the input video file.
        model_name: Detection model identifier (e.g. "yolov8n").
        confidence_threshold: Minimum detection confidence in [0, 1].
        progress: Gradio progress tracker (injected by the UI).

    Returns:
        Tuple of (output_video_path, csv_path, plot_path, status_message).
        Path entries are None when the corresponding artifact could not be
        produced; status_message always describes the outcome.
    """
    try:
        # Validate the input before doing any heavy work.
        is_valid, msg = validate_video_file(video_path)
        if not is_valid:
            return None, None, None, f"❌ Error: {msg}"

        progress(0, desc="Initializing models...")
        detector = BallDetector(
            model_name=model_name,
            confidence_threshold=confidence_threshold
        )

        # Read video properties up front so the writer/tracker can be configured.
        with VideoReader(video_path) as reader:
            video_props = reader.get_properties()
        fps = video_props['fps']
        frame_count = video_props['frame_count']
        width = video_props['width']
        height = video_props['height']

        # Guard against corrupt metadata: fps == 0 would raise
        # ZeroDivisionError in the tracker dt below, and frame_count == 0
        # would break the progress/percentage math.
        if fps <= 0:
            fps = 30.0
        total_frames = max(frame_count, 1)

        # Tracker timestep matches the frame rate; tolerate ~0.5 s of
        # missed detections before the track is dropped.
        tracker = BallTracker(dt=1.0 / fps, max_missing_frames=int(fps * 0.5))

        # Create output files: a uniquely-named temp video plus fixed-name
        # CSV/plot inside the shared output directory.
        output_dir = create_output_directory("output")
        temp_video = tempfile.NamedTemporaryFile(
            delete=False, suffix='.mp4', dir=output_dir
        )
        output_video_path = temp_video.name
        temp_video.close()  # close immediately; VideoWriter reopens the path
        csv_path = output_dir / "trajectory.csv"
        plot_path = output_dir / "trajectory_plot.png"

        progress(0.1, desc="Processing frames...")
        detection_count = 0
        with VideoReader(video_path) as reader, \
                VideoWriter(output_video_path, fps, width, height) as writer:
            for frame_num, frame in reader.read_frames():
                # Frame processing spans the 10%-80% range of the bar.
                progress(
                    0.1 + 0.7 * (frame_num / total_frames),
                    desc=f"Processing frame {frame_num + 1}/{frame_count}"
                )

                detections = detector.detect(frame)
                if len(detections) > 0:
                    # Detections[0] is treated as the highest-confidence hit.
                    best_detection = detections[0]
                    cx, cy = detector.get_ball_center(best_detection)
                    state = tracker.update((cx, cy))
                    detection_count += 1
                    frame = draw_detection(frame, best_detection)
                else:
                    # No detection this frame: let the tracker predict.
                    state = tracker.update(None)

                # Overlay trajectory and speed once the tracker has a track.
                if state is not None and tracker.is_initialized():
                    x, y, vx, vy = state
                    positions = tracker.get_last_n_positions(20)
                    frame = draw_trajectory_trail(frame, positions)
                    speed = tracker.get_speed(state)
                    frame = draw_speed_label(frame, (x, y), speed, fps)

                # Info panel shows frame counter and current confidence.
                conf = detections[0][4] if len(detections) > 0 else None
                frame = draw_info_panel(frame, frame_num + 1, frame_count, fps, conf)
                writer.write_frame(frame)

        progress(0.8, desc="Exporting trajectory data...")
        trajectory = tracker.get_trajectory()
        if len(trajectory) == 0:
            return None, None, None, "❌ No ball detected in video. Try lowering the confidence threshold."

        # CSV export failure is non-fatal; just drop that artifact.
        if not export_trajectory_csv(trajectory, fps, str(csv_path)):
            csv_path = None

        progress(0.9, desc="Creating trajectory plot...")
        try:
            create_trajectory_plot(trajectory, fps, str(plot_path))
        except Exception as e:
            # Plot failure is non-fatal; the video/CSV are still useful.
            print(f"Failed to create plot: {e}")
            plot_path = None

        progress(1.0, desc="Complete!")

        status = f"""✅ **Processing Complete!**

**Video Info:**
- Total Frames: {frame_count}
- Frame Rate: {fps:.1f} FPS
- Resolution: {width}x{height}

**Tracking Results:**
- Ball Detected: {detection_count} frames ({100 * detection_count / total_frames:.1f}%)
- Trajectory Points: {len(trajectory)}

**Outputs:**
- Processed video with overlays
- Trajectory CSV with {len(trajectory)} data points
- 2D trajectory plot color-coded by speed
"""
        return (
            output_video_path,
            str(csv_path) if csv_path else None,
            str(plot_path) if plot_path else None,
            status
        )

    except Exception as e:
        # Top-level boundary: surface the error in the UI and log the trace.
        error_msg = f"❌ **Error during processing:** {str(e)}"
        print(error_msg)
        import traceback
        traceback.print_exc()
        return None, None, None, error_msg
# Create Gradio interface
def create_interface():
    """
    Create and configure the Gradio interface.

    Builds a two-column Blocks layout: settings/input on the left, tabbed
    results (video, plot, downloads) on the right. Wires the Run button to
    ``process_video`` and mirrors the processed video into the download tab.

    Returns:
        The configured ``gr.Blocks`` application (not yet launched).
    """
    with gr.Blocks(
        title="TennisVision - AI Ball Tracker",
        theme=gr.themes.Soft()
    ) as app:
        gr.Markdown(
            """
            # 🎾 TennisVision - AI Ball Tracker

            Upload a tennis video to automatically detect and track the ball using
            state-of-the-art computer vision models.

            **Features:**
            - Real-time ball detection with YOLOv8
            - Smooth trajectory tracking with Kalman filter
            - Speed estimation and visualization
            - Downloadable outputs (video, CSV, plot)
            """
        )

        with gr.Row():
            # Left column: upload + model settings.
            with gr.Column(scale=1):
                gr.Markdown("### ⚙️ Input & Settings")
                video_input = gr.Video(
                    label="Upload Tennis Video",
                    sources=["upload"]
                )
                model_dropdown = gr.Dropdown(
                    choices=["yolov8n", "yolov8s", "yolov8m"],
                    value="yolov8n",
                    label="Detection Model",
                    info="yolov8n is fastest, yolov8m is most accurate"
                )
                confidence_slider = gr.Slider(
                    minimum=0.1,
                    maximum=0.9,
                    value=0.3,
                    step=0.05,
                    label="Confidence Threshold",
                    info="Lower = more detections (may include false positives)"
                )
                process_btn = gr.Button(
                    "🚀 Run Tracking",
                    variant="primary",
                    size="lg"
                )
                gr.Markdown(
                    """
                    ### 💡 Tips
                    - Use short clips (5-15 seconds) for faster processing
                    - Ensure the ball is visible and in motion
                    - Lower confidence threshold if ball is not detected
                    - YOLOv8n provides fastest inference (~30 FPS)
                    """
                )

            # Right column: status line plus tabbed outputs.
            with gr.Column(scale=2):
                gr.Markdown("### 📊 Results")
                status_output = gr.Markdown(
                    "Upload a video and click **Run Tracking** to begin."
                )
                with gr.Tabs():
                    with gr.Tab("📹 Processed Video"):
                        video_output = gr.Video(
                            label="Tracked Video",
                            show_label=False
                        )
                    with gr.Tab("📈 Trajectory Plot"):
                        plot_output = gr.Image(
                            label="2D Trajectory",
                            show_label=False
                        )
                    with gr.Tab("📥 Downloads"):
                        gr.Markdown("### Download Files")
                        csv_output = gr.File(
                            label="Trajectory Data (CSV)"
                        )
                        video_download = gr.File(
                            label="Processed Video (MP4)"
                        )

        # Event handlers: run the pipeline, then mirror the output video
        # path into the download tab's File component.
        process_btn.click(
            fn=process_video,
            inputs=[video_input, model_dropdown, confidence_slider],
            outputs=[video_output, csv_output, plot_output, status_output]
        ).then(
            fn=lambda x: x,
            inputs=[video_output],
            outputs=[video_download]
        )

        gr.Markdown(
            """
            ---
            ### 📚 About

            **TennisVision** uses YOLOv8 for ball detection and Kalman filtering
            for smooth trajectory tracking. The system estimates ball speed and
            visualizes the complete trajectory with color-coded speed indicators.

            **Model:** YOLOv8 (Ultralytics)
            **Tracking:** Kalman Filter
            **Framework:** Gradio + OpenCV

            Built for deployment on Hugging Face Spaces 🤗
            """
        )

    return app
if __name__ == "__main__":
    # Make sure the artifacts directory exists before the UI starts
    # writing processed videos into it.
    create_output_directory("output")

    # Build the interface and serve it on all interfaces at port 7860,
    # without creating a public share link.
    tennis_app = create_interface()
    tennis_app.launch(server_name="0.0.0.0", server_port=7860, share=False)