tennisvision / utils /visualization.py
Onur Çopur
first commit
3b90d9c
"""
Visualization utilities for ball tracking.
This module provides functions for rendering bounding boxes, trajectories,
and creating 2D trajectory plots with speed-based color coding.
"""
import cv2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from typing import List, Tuple, Optional
from matplotlib.figure import Figure
def draw_detection(
    frame: np.ndarray,
    detection: Tuple[int, int, int, int, float],
    color: Tuple[int, int, int] = (0, 255, 0),
    thickness: int = 2
) -> np.ndarray:
    """
    Draw a bounding box and its confidence label for one detection.

    The drawing calls operate directly on ``frame``; the same array is
    returned for convenience.

    Args:
        frame: Input frame (BGR format)
        detection: Bounding box as (x1, y1, x2, y2, confidence). Coordinates
            are truncated to integers before drawing, so float-valued boxes
            are accepted as well.
        color: Box color in BGR format
        thickness: Line thickness

    Returns:
        Frame with drawn bounding box
    """
    x1, y1, x2, y2, conf = detection
    # OpenCV drawing functions expect integer pixel coordinates. Cast here,
    # the same way the other helpers in this module cast positions.
    x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)

    # Bounding box outline
    cv2.rectangle(frame, (x1, y1), (x2, y2), color, thickness)

    # Confidence label, placed just above the box and kept below the top edge
    label = f"{conf:.2f}"
    (label_w, label_h), _ = cv2.getTextSize(
        label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
    )
    label_y = max(y1 - 10, label_h)

    # Filled background in the box color so the black text stays readable
    cv2.rectangle(
        frame,
        (x1, label_y - label_h - 5),
        (x1 + label_w, label_y + 5),
        color,
        -1
    )
    cv2.putText(
        frame,
        label,
        (x1, label_y),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.5,
        (0, 0, 0),
        1
    )
    return frame
def draw_trajectory_trail(
    frame: np.ndarray,
    positions: List[Tuple[float, float]],
    color: Tuple[int, int, int] = (0, 255, 255),
    max_points: int = 20
) -> np.ndarray:
    """
    Draw a trail showing recent ball positions.

    Consecutive positions are joined by line segments whose color gets
    brighter towards the most recent position, and a filled circle marks
    the latest position. Drawing happens directly on ``frame``.

    Args:
        frame: Input frame (BGR format)
        positions: List of (x, y) positions (most recent last)
        color: Trail color in BGR format
        max_points: Maximum number of points to show. Values <= 0 draw
            nothing.

    Returns:
        Frame with drawn trajectory trail (returned unchanged when there are
        fewer than two positions or ``max_points`` is not positive)
    """
    # Guard max_points explicitly: positions[-0:] would select the WHOLE
    # list, and a negative value would slice from the front instead.
    if max_points <= 0 or len(positions) < 2:
        return frame

    # Use only recent positions
    recent = positions[-max_points:]
    count = len(recent)

    # Draw lines connecting positions with a fading effect
    for i, (prev, curr) in enumerate(zip(recent, recent[1:]), start=1):
        # Fraction in (0, 1): small for old segments, larger for new ones
        alpha = i / count
        pt1 = (int(prev[0]), int(prev[1]))
        pt2 = (int(curr[0]), int(curr[1]))
        # NOTE: alpha is always < 1, so int(2 * alpha) is 0 or 1 and the
        # resulting thickness is 1 for every segment.
        thickness = max(1, int(2 * alpha))
        # Fade by scaling the color towards black (not a true alpha blend)
        line_color = tuple(int(c * alpha) for c in color)
        cv2.line(frame, pt1, pt2, line_color, thickness, cv2.LINE_AA)

    # Mark the current (most recent) position; recent is non-empty here
    curr_pos = (int(recent[-1][0]), int(recent[-1][1]))
    cv2.circle(frame, curr_pos, 5, color, -1, cv2.LINE_AA)
    return frame
def draw_speed_label(
    frame: np.ndarray,
    position: Tuple[float, float],
    speed: float,
    fps: float,
    color: Tuple[int, int, int] = (255, 255, 255),
    *,
    kmh_per_pixel_speed: float = 0.01
) -> np.ndarray:
    """
    Draw speed information near the ball position.

    The label is centered horizontally on the ball, placed above it, and
    drawn on a filled black background directly on ``frame``.

    Args:
        frame: Input frame (BGR format)
        position: Ball position as (x, y)
        speed: Speed in pixels per second
        fps: Video frame rate (currently not used by this function; kept
            for interface compatibility)
        color: Text color in BGR format
        kmh_per_pixel_speed: Factor multiplied with ``speed`` to obtain the
            displayed km/h value. The default of 0.01 is only a rough
            approximation; pass a calibrated factor for meaningful numbers.

    Returns:
        Frame with speed label
    """
    x, y = int(position[0]), int(position[1])

    # Convert pixel speed to approximate real-world units
    # (a proper conversion requires camera calibration)
    speed_kmh = speed * kmh_per_pixel_speed
    label = f"{speed_kmh:.1f} km/h"

    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.6
    thickness = 2
    (text_w, text_h), _ = cv2.getTextSize(label, font, font_scale, thickness)

    # Center the label horizontally on the ball and place it above
    label_x = x - text_w // 2
    label_y = y - 20

    # Keep the label inside the frame horizontally and below the top edge
    label_x = max(0, min(label_x, frame.shape[1] - text_w))
    label_y = max(text_h + 5, label_y)

    # Background rectangle with 5 px padding around the text
    cv2.rectangle(
        frame,
        (label_x - 5, label_y - text_h - 5),
        (label_x + text_w + 5, label_y + 5),
        (0, 0, 0),
        -1
    )

    # Text on top of the background
    cv2.putText(
        frame,
        label,
        (label_x, label_y),
        font,
        font_scale,
        color,
        thickness,
        cv2.LINE_AA
    )
    return frame
def draw_info_panel(
    frame: np.ndarray,
    frame_num: int,
    total_frames: int,
    fps: float,
    detection_conf: Optional[float] = None
) -> np.ndarray:
    """
    Draw an information panel at the top of the frame.

    A 60-pixel-high dark band is blended over the top of the frame, and the
    frame counter, elapsed time and (optionally) detection confidence are
    written onto it. The result of the blend is what gets returned, so
    callers should use the return value.

    Args:
        frame: Input frame (BGR format)
        frame_num: Current frame number
        total_frames: Total number of frames
        fps: Video frame rate. When it is not positive, the elapsed time
            cannot be computed and "N/A" is shown instead.
        detection_conf: Detection confidence (if available)

    Returns:
        Frame with info panel
    """
    # Create semi-transparent overlay: black band on a copy, then blend
    overlay = frame.copy()
    cv2.rectangle(overlay, (0, 0), (frame.shape[1], 60), (0, 0, 0), -1)
    frame = cv2.addWeighted(overlay, 0.6, frame, 0.4, 0)

    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.6
    color = (255, 255, 255)
    thickness = 2

    # Frame counter
    frame_text = f"Frame: {frame_num}/{total_frames}"
    cv2.putText(frame, frame_text, (10, 25), font, font_scale, color, thickness)

    # Elapsed time; guard against a zero/invalid frame rate instead of
    # raising ZeroDivisionError in the middle of rendering
    if fps > 0:
        time_text = f"Time: {frame_num / fps:.2f}s"
    else:
        time_text = "Time: N/A"
    cv2.putText(frame, time_text, (10, 50), font, font_scale, color, thickness)

    # Detection confidence (if available)
    if detection_conf is not None:
        conf_text = f"Confidence: {detection_conf:.2%}"
        cv2.putText(frame, conf_text, (250, 25), font, font_scale, color, thickness)
    return frame
def create_trajectory_plot(
    trajectory: List[Tuple[float, float, float, float, int]],
    fps: float,
    output_path: Optional[str] = None
) -> Figure:
    """
    Create a 2D trajectory plot color-coded by speed.

    Each pair of consecutive points is drawn as one line segment colored by
    the speed at the segment's end point. The start and end of the
    trajectory are marked, and the Y axis is inverted to match image
    coordinates.

    Args:
        trajectory: List of (x, y, vx, vy, frame_num) tuples. Speed is
            computed as sqrt(vx^2 + vy^2) * fps (presumably vx/vy are in
            pixels per frame; verify against the tracker).
        fps: Video frame rate
        output_path: Path to save plot (optional). Saving is best-effort:
            a failure is printed and the figure is still returned. Nothing
            is saved when the trajectory is empty.

    Returns:
        Matplotlib Figure object
    """
    if len(trajectory) == 0:
        # Placeholder figure so callers always get a Figure back
        fig, ax = plt.subplots(figsize=(10, 8))
        ax.text(
            0.5, 0.5, "No trajectory data available",
            ha='center', va='center', fontsize=14
        )
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        return fig

    # Extract coordinates
    x_coords = [p[0] for p in trajectory]
    y_coords = [p[1] for p in trajectory]

    # Speed magnitude scaled by the frame rate. Multiplying by fps is
    # equivalent to dividing by the frame period (1 / fps) but does not
    # raise ZeroDivisionError when fps is 0.
    speeds = [np.sqrt(p[2] ** 2 + p[3] ** 2) * fps for p in trajectory]

    fig, ax = plt.subplots(figsize=(12, 10))

    # Color mapping is only set up when there is a non-zero speed
    min_speed, max_speed = min(speeds), max(speeds)
    if max_speed > 0:
        norm = mcolors.Normalize(vmin=min_speed, vmax=max_speed)
        colormap = plt.cm.jet
    else:
        norm = None
        colormap = None

    # One segment per consecutive pair, colored by the end point's speed
    segments = zip(
        x_coords, x_coords[1:], y_coords, y_coords[1:], speeds[1:]
    )
    for x_prev, x_curr, y_prev, y_curr, speed in segments:
        segment_color = colormap(norm(speed)) if norm is not None else 'blue'
        ax.plot(
            [x_prev, x_curr],
            [y_prev, y_curr],
            color=segment_color,
            linewidth=2,
            alpha=0.7
        )

    # Start and end markers
    ax.scatter(x_coords[0], y_coords[0], c='green', s=100, marker='o',
               label='Start', zorder=5, edgecolors='black', linewidths=2)
    ax.scatter(x_coords[-1], y_coords[-1], c='red', s=100, marker='X',
               label='End', zorder=5, edgecolors='black', linewidths=2)

    # Formatting
    ax.set_xlabel('X Position (pixels)', fontsize=12, fontweight='bold')
    ax.set_ylabel('Y Position (pixels)', fontsize=12, fontweight='bold')
    ax.set_title('Tennis Ball Trajectory (Color = Speed)', fontsize=14, fontweight='bold')
    ax.legend(loc='best', fontsize=10)
    ax.grid(True, alpha=0.3)
    ax.invert_yaxis()  # Invert Y-axis to match image coordinates

    # Colorbar, attached explicitly to this figure
    if norm is not None:
        sm = plt.cm.ScalarMappable(cmap=colormap, norm=norm)
        sm.set_array([])
        fig.colorbar(sm, ax=ax, label='Speed (pixels/sec)')

    fig.tight_layout()

    # Save through the figure object so the right figure is written even if
    # another pyplot figure has become "current" in the meantime
    if output_path:
        try:
            fig.savefig(output_path, dpi=150, bbox_inches='tight')
        except Exception as e:
            # Best-effort: report the failure but still return the figure
            print(f"Error saving plot: {str(e)}")
    return fig