Spaces:
Sleeping
Sleeping
| """ | |
| Visualization utilities for ball tracking. | |
| This module provides functions for rendering bounding boxes, trajectories, | |
| and creating 2D trajectory plots with speed-based color coding. | |
| """ | |
| import cv2 | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import matplotlib.colors as mcolors | |
| from typing import List, Tuple, Optional | |
| from matplotlib.figure import Figure | |
def draw_detection(
    frame: np.ndarray,
    detection: Tuple[int, int, int, int, float],
    color: Tuple[int, int, int] = (0, 255, 0),
    thickness: int = 2
) -> np.ndarray:
    """
    Draw a bounding box and its confidence label on the frame.

    Drawing happens in place on ``frame`` (OpenCV drawing functions mutate
    their input); the same array is returned for convenience.

    Args:
        frame: Input frame (BGR format)
        detection: Bounding box as (x1, y1, x2, y2, confidence). Coordinates
            may be ints or floats (e.g. raw detector output); floats are
            rounded to the nearest pixel. Any elements after the fifth
            (such as a class id) are ignored.
        color: Box color in BGR format
        thickness: Line thickness

    Returns:
        Frame with drawn bounding box
    """
    x1, y1, x2, y2, conf = detection[:5]
    # OpenCV's drawing functions reject float points, so snap to whole pixels.
    x1, y1, x2, y2 = (int(round(v)) for v in (x1, y1, x2, y2))

    # Draw rectangle
    cv2.rectangle(frame, (x1, y1), (x2, y2), color, thickness)

    # Draw confidence label
    label = f"{conf:.2f}"
    label_size, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
    # Put the label just above the box, but never above the top of the frame.
    label_y = max(y1 - 10, label_size[1])
    # Filled background in the box color so the black text stays readable.
    cv2.rectangle(
        frame,
        (x1, label_y - label_size[1] - 5),
        (x1 + label_size[0], label_y + 5),
        color,
        -1
    )
    cv2.putText(
        frame,
        label,
        (x1, label_y),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.5,
        (0, 0, 0),
        1
    )
    return frame
def draw_trajectory_trail(
    frame: np.ndarray,
    positions: List[Tuple[float, float]],
    color: Tuple[int, int, int] = (0, 255, 255),
    max_points: int = 20
) -> np.ndarray:
    """
    Draw a fading trail through the most recent ball positions.

    Consecutive positions are joined by line segments. Older segments are
    darker (the trail color is scaled toward black) and thinner; the newest
    segment is drawn at full color and maximum thickness. A filled dot marks
    the most recent position. Drawing happens in place on ``frame``, which
    is also returned.

    Args:
        frame: Input frame (BGR format)
        positions: List of (x, y) positions (most recent last)
        color: Trail color in BGR format
        max_points: Maximum number of recent positions to include. Values
            <= 0 draw nothing.

    Returns:
        Frame with drawn trajectory trail
    """
    # Guard max_points <= 0 explicitly: positions[-0:] would be the WHOLE list.
    if len(positions) < 2 or max_points <= 0:
        return frame

    # Use only recent positions, snapped to integer pixels for OpenCV.
    recent = positions[-max_points:]
    points = [(int(p[0]), int(p[1])) for p in recent]

    n_segments = len(points) - 1
    for i, (pt1, pt2) in enumerate(zip(points, points[1:]), start=1):
        # Fade factor in (0, 1]; it reaches exactly 1 on the newest segment,
        # so the trail ends at full color and thickness 2.
        alpha = i / n_segments
        thickness = max(1, round(2 * alpha))
        line_color = tuple(int(c * alpha) for c in color)
        cv2.line(frame, pt1, pt2, line_color, thickness, cv2.LINE_AA)

    # Mark the current (most recent) position.
    cv2.circle(frame, points[-1], 5, color, -1, cv2.LINE_AA)
    return frame
def draw_speed_label(
    frame: np.ndarray,
    position: Tuple[float, float],
    speed: float,
    fps: float,
    color: Tuple[int, int, int] = (255, 255, 255),
    speed_scale: float = 0.01
) -> np.ndarray:
    """
    Draw a speed readout (in km/h) on a black background above the ball.

    Drawing happens in place on ``frame``, which is also returned.

    Args:
        frame: Input frame (BGR format)
        position: Ball position as (x, y)
        speed: Speed in pixels per second
        fps: Video frame rate (currently unused; kept for API compatibility)
        color: Text color in BGR format
        speed_scale: Factor converting pixels/second into km/h. The default
            of 0.01 is a rough placeholder, not a calibrated value; pass a
            factor derived from camera calibration for meaningful numbers.

    Returns:
        Frame with speed label
    """
    x, y = int(position[0]), int(position[1])

    # Convert pixel speed to approximate real-world units.
    speed_kmh = speed * speed_scale
    label = f"{speed_kmh:.1f} km/h"

    # Draw label with background
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.6
    thickness = 2
    label_size, _ = cv2.getTextSize(label, font, font_scale, thickness)

    # Center the label horizontally on the ball, 20 px above it.
    label_x = x - label_size[0] // 2
    label_y = y - 20

    # Keep the label inside the frame's left/right edges and below its top edge.
    label_x = max(0, min(label_x, frame.shape[1] - label_size[0]))
    label_y = max(label_size[1] + 5, label_y)

    # Draw background rectangle (5 px padding around the text)
    cv2.rectangle(
        frame,
        (label_x - 5, label_y - label_size[1] - 5),
        (label_x + label_size[0] + 5, label_y + 5),
        (0, 0, 0),
        -1
    )

    # Draw text
    cv2.putText(
        frame,
        label,
        (label_x, label_y),
        font,
        font_scale,
        color,
        thickness,
        cv2.LINE_AA
    )
    return frame
def draw_info_panel(
    frame: np.ndarray,
    frame_num: int,
    total_frames: int,
    fps: float,
    detection_conf: Optional[float] = None
) -> np.ndarray:
    """
    Draw a semi-transparent information panel across the top of the frame.

    The panel shows the frame counter, the elapsed time and, when given,
    the detection confidence. Unlike the other drawing helpers here, this
    one does NOT modify ``frame`` in place: the blending step produces a
    new array, so callers must use the return value.

    Args:
        frame: Input frame (BGR format)
        frame_num: Current frame number
        total_frames: Total number of frames
        fps: Video frame rate. If it is not positive (some streams report
            0), the elapsed time is shown as "--" instead of raising
            ZeroDivisionError.
        detection_conf: Detection confidence (if available)

    Returns:
        New frame with the info panel drawn on it
    """
    # Darken the top 60 px by blending a black band over a copy of the frame.
    overlay = frame.copy()
    cv2.rectangle(overlay, (0, 0), (frame.shape[1], 60), (0, 0, 0), -1)
    frame = cv2.addWeighted(overlay, 0.6, frame, 0.4, 0)

    # Draw text information
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.6
    color = (255, 255, 255)
    thickness = 2

    # Frame counter
    frame_text = f"Frame: {frame_num}/{total_frames}"
    cv2.putText(frame, frame_text, (10, 25), font, font_scale, color, thickness)

    # Time (guarded: fps may be 0 for streams with unknown frame rate)
    if fps > 0:
        time_text = f"Time: {frame_num / fps:.2f}s"
    else:
        time_text = "Time: --"
    cv2.putText(frame, time_text, (10, 50), font, font_scale, color, thickness)

    # Detection confidence (if available)
    if detection_conf is not None:
        conf_text = f"Confidence: {detection_conf:.2%}"
        cv2.putText(frame, conf_text, (250, 25), font, font_scale, color, thickness)

    return frame
def _save_figure(fig: Figure, output_path: Optional[str]) -> None:
    """Best-effort save of ``fig`` to ``output_path``; failures are printed, not raised."""
    if not output_path:
        return
    try:
        fig.savefig(output_path, dpi=150, bbox_inches='tight')
    except Exception as e:
        print(f"Error saving plot: {str(e)}")


def create_trajectory_plot(
    trajectory: List[Tuple[float, float, float, float, int]],
    fps: float,
    output_path: Optional[str] = None
) -> Figure:
    """
    Create a 2D trajectory plot color-coded by speed.

    Each segment between consecutive points is colored (jet colormap) by the
    speed at its end point. Speeds are ``hypot(vx, vy) * fps``; because the
    colorbar is labeled pixels/sec, ``vx``/``vy`` are treated as per-frame
    displacements. If every speed is zero (including when ``fps`` is 0) the
    path is drawn in plain blue and no colorbar is added. The Y axis is
    inverted to match image coordinates.

    The figure is created through pyplot and stays registered with it; call
    ``plt.close(fig)`` when done to release it.

    Args:
        trajectory: List of (x, y, vx, vy, frame_num) tuples
        fps: Video frame rate
        output_path: Path to save plot (optional). Saving is best-effort:
            errors are printed rather than raised. Honored for empty
            trajectories too.

    Returns:
        Matplotlib Figure object
    """
    if len(trajectory) == 0:
        # Create empty placeholder plot
        fig, ax = plt.subplots(figsize=(10, 8))
        ax.text(
            0.5, 0.5, "No trajectory data available",
            ha='center', va='center', fontsize=14
        )
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        _save_figure(fig, output_path)
        return fig

    # Extract coordinates and velocities
    x_coords = np.array([p[0] for p in trajectory], dtype=float)
    y_coords = np.array([p[1] for p in trajectory], dtype=float)
    vx = np.array([p[2] for p in trajectory], dtype=float)
    vy = np.array([p[3] for p in trajectory], dtype=float)

    # px/frame -> px/s. Multiplying (rather than dividing by 1/fps) avoids a
    # ZeroDivisionError when fps is 0.
    speeds = np.hypot(vx, vy) * fps

    fig, ax = plt.subplots(figsize=(12, 10))

    # Normalize speeds for color mapping
    if speeds.max() > 0:
        norm = mcolors.Normalize(vmin=speeds.min(), vmax=speeds.max())
        colormap = plt.cm.jet
    else:
        norm = None
        colormap = None

    # Plot all segments as one LineCollection instead of one artist per
    # segment, which is far faster for long trajectories.
    if len(x_coords) > 1:
        from matplotlib.collections import LineCollection

        starts = np.column_stack([x_coords[:-1], y_coords[:-1]])
        ends = np.column_stack([x_coords[1:], y_coords[1:]])
        segments = np.stack([starts, ends], axis=1)  # (n - 1, 2, 2)
        # Segment i-1 -> i is colored by the speed at point i.
        seg_colors = colormap(norm(speeds[1:])) if norm is not None else 'blue'
        ax.add_collection(
            LineCollection(segments, colors=seg_colors, linewidths=2, alpha=0.7)
        )
        ax.autoscale_view()

    # Add start and end markers
    ax.scatter(x_coords[0], y_coords[0], c='green', s=100, marker='o',
               label='Start', zorder=5, edgecolors='black', linewidths=2)
    ax.scatter(x_coords[-1], y_coords[-1], c='red', s=100, marker='X',
               label='End', zorder=5, edgecolors='black', linewidths=2)

    # Formatting
    ax.set_xlabel('X Position (pixels)', fontsize=12, fontweight='bold')
    ax.set_ylabel('Y Position (pixels)', fontsize=12, fontweight='bold')
    ax.set_title('Tennis Ball Trajectory (Color = Speed)', fontsize=14, fontweight='bold')
    ax.legend(loc='best', fontsize=10)
    ax.grid(True, alpha=0.3)
    ax.invert_yaxis()  # Invert Y-axis to match image coordinates

    # Add colorbar (attached to this figure explicitly, not pyplot's "current" one)
    if norm is not None:
        sm = plt.cm.ScalarMappable(cmap=colormap, norm=norm)
        sm.set_array([])
        fig.colorbar(sm, ax=ax, label='Speed (pixels/sec)')

    fig.tight_layout()
    _save_figure(fig, output_path)
    return fig