# visualize.py — 2D/3D skeleton visualization and KSI analysis for badminton shot videos.
# Updated visualize.py
# --- DETERMINISM FIXES (MUST BE BEFORE IMPORTS) ---
# These environment variables only take effect if they are set before
# TensorFlow / MediaPipe are first imported, so this section must stay
# at the very top of the file.
import os
import sys

# Peek at argv before argparse runs: the env vars below must be set
# before any heavy import, which happens long before parse_args().
_use_gpu = '--gpu' in sys.argv
if not _use_gpu:
    # Force CPU mode for deterministic predictions
    os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
    os.environ['MEDIAPIPE_DISABLE_GPU'] = '1'
    print("🔒 Running in CPU mode for deterministic predictions (use --gpu to enable GPU)")
# Deterministic-op settings applied regardless of device choice.
os.environ['TF_DETERMINISTIC_OPS'] = '1'
os.environ['TF_CUDNN_DETERMINISTIC'] = '1'
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'

# Standard library
import argparse
from typing import Optional, Dict

# Third-party
import cv2
import yaml
import numpy as np
import mediapipe as mp
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from mpl_toolkits.mplot3d import Axes3D

# Local imports
from utils import normalize_sequence, should_skip_crop, get_segment_bounds, resolve_crop_config_for_video
from features import PoseFeatureExtractor
from ksi_v2 import (
    EnhancedKSI,
    ShotPhaseSegmenter,
    extract_sequence_features,
    FEATURE_NAMES,
    ShotPhase
)
def load_config() -> Dict:
    """Load and return the pipeline configuration from ``params.yaml``.

    Returns:
        The parsed YAML document (typically a dict of pipeline settings).

    Raises:
        FileNotFoundError: if ``params.yaml`` is not in the working directory.
        yaml.YAMLError: if the file is not valid YAML.
    """
    # Explicit UTF-8 avoids platform-dependent default encodings (e.g. cp1252
    # on Windows) mangling non-ASCII characters in the config.
    with open("params.yaml", encoding="utf-8") as f:
        return yaml.safe_load(f)
def visualize_2d(video_path: str, crop_config: Dict, speed: float = 1.0, segment_rules: Optional[Dict] = None):
    """2D skeleton overlay visualization.

    Plays the (optionally cropped) segment of `video_path` in an OpenCV window
    with MediaPipe pose landmarks drawn on each frame. Blocks until the segment
    ends or the user presses 'q'.

    Args:
        video_path: Path to the input video file.
        crop_config: Dict with fractional margins 'top', 'bottom', 'left',
            'right' (each in [0, 1]) used to crop every frame.
        speed: Playback speed multiplier (e.g. 0.5 = half speed).
        segment_rules: Optional per-video segment config forwarded to
            get_segment_bounds().
    """
    print(f"\n--- 2D Visualization: {video_path} ---")
    print("Press 'q' to quit.")
    mp_pose = mp.solutions.pose
    pose = mp_pose.Pose(static_image_mode=False, model_complexity=1, min_detection_confidence=0.5)
    mp_drawing = mp.solutions.drawing_utils
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Error: Could not open {video_path}")
        return
    # FPS can be 0/None for some containers; fall back to 30.
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    start_frame, tail_frames = get_segment_bounds(video_path, fps, total_frames, default_seconds=1.75, segment_cfg=segment_rules)
    # Seek to the start of the segment before reading.
    cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
    # Per-frame display delay at 1x speed (waitKey needs >= 1 ms).
    base_delay_ms = max(1, int(1000.0 / fps))
    # Decide whether to skip cropping for files like 'name (1).mp4'
    filename = os.path.basename(video_path)
    skip_crop = should_skip_crop(filename)
    processed = 0
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        processed += 1
        # Apply Crop (skip if filename matches the '(N)' pattern)
        if skip_crop:
            frame_cropped = frame
        else:
            h, w = frame.shape[:2]
            start_row = int(h * crop_config['top'])
            end_row = h - int(h * crop_config['bottom'])
            start_col = int(w * crop_config['left'])
            end_col = w - int(w * crop_config['right'])
            frame_cropped = frame[start_row:end_row, start_col:end_col]
        # A degenerate crop (margins >= frame size) yields an empty array; skip it.
        if frame_cropped.size == 0:
            continue
        # Detect (MediaPipe expects RGB; OpenCV decodes BGR)
        image_rgb = cv2.cvtColor(frame_cropped, cv2.COLOR_BGR2RGB)
        results = pose.process(image_rgb)
        # Draw landmarks in place on the cropped frame.
        if results.pose_landmarks:
            mp_drawing.draw_landmarks(
                frame_cropped, results.pose_landmarks, mp_pose.POSE_CONNECTIONS,
                mp_drawing.DrawingSpec(color=(245, 117, 66), thickness=2, circle_radius=2),
                mp_drawing.DrawingSpec(color=(245, 66, 230), thickness=2, circle_radius=2)
            )
        cv2.imshow('Badminton 2D View (Cropped)', frame_cropped)
        # Scale the delay by playback speed; clamp divisor to avoid div-by-zero.
        delay = max(1, int(base_delay_ms / max(speed, 1e-3)))
        if cv2.waitKey(delay) & 0xFF == ord('q'):
            break
        # Stop once the segment's frame budget is exhausted.
        if processed >= int(tail_frames):
            break
    cap.release()
    cv2.destroyAllWindows()
    pose.close()
def visualize_3d(video_path: str, crop_config: Dict, mp_config: Dict, speed: float = 1.0, segment_rules: Optional[Dict] = None):
    """3D skeleton animation with phase annotations.

    Extracts MediaPipe 3D landmarks from the video segment, normalizes the
    sequence, segments it into shot phases via ksi_v2, and plays an
    interactive matplotlib 3D animation colored by phase.

    Args:
        video_path: Path to the input video file.
        crop_config: Dict with fractional margins 'top', 'bottom', 'left',
            'right' used to crop frames before landmark extraction.
        mp_config: MediaPipe settings forwarded to PoseFeatureExtractor.
        speed: Playback speed multiplier.
        segment_rules: Optional per-video segment config forwarded to
            get_segment_bounds().
    """
    print(f"\n--- 3D Visualization: {video_path} ---")
    print("Extracting landmarks...")
    extractor = PoseFeatureExtractor(mp_config)
    cap = cv2.VideoCapture(video_path)
    # FPS can be 0/None for some containers; fall back to 30.
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    start_frame, tail_frames = get_segment_bounds(video_path, fps, total_frames, default_seconds=1.75, segment_cfg=segment_rules)
    cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
    # Decide whether to skip cropping for files like 'name (1).mp4'
    filename = os.path.basename(video_path)
    skip_crop = should_skip_crop(filename)
    raw_landmarks = []
    processed = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        processed += 1
        if not skip_crop:
            h, w = frame.shape[:2]
            frame = frame[
                int(h * crop_config['top']):h - int(h * crop_config['bottom']),
                int(w * crop_config['left']):w - int(w * crop_config['right'])
            ]
        # Skip frames where the crop produced an empty array.
        if frame.size == 0:
            continue
        # MediaPipe expects RGB input.
        res = extractor.pose.process(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        if res.pose_landmarks:
            # 33 landmarks x (x, y, z), normalized MediaPipe coordinates.
            lm = np.array([[l.x, l.y, l.z] for l in res.pose_landmarks.landmark])
            raw_landmarks.append(lm)
        if processed >= int(tail_frames):
            break
    cap.release()
    if not raw_landmarks:
        print("No landmarks found.")
        return
    print(f"Normalizing {len(raw_landmarks)} frames...")
    normalized_seq = normalize_sequence(raw_landmarks)  # [T, 33, 3]
    # Segment phases using ksi_v2
    segmenter = ShotPhaseSegmenter(fps=fps)
    # `phases` maps phase name -> (start_frame, end_frame) — see get_phase_for_frame below.
    phases = segmenter.segment(normalized_seq)
    # Phase colors for visualization
    phase_colors = {
        ShotPhase.PREPARATION.value: 'gray',
        ShotPhase.LOADING.value: 'blue',
        ShotPhase.ACCELERATION.value: 'orange',
        ShotPhase.CONTACT.value: 'red',
        ShotPhase.FOLLOW_THROUGH.value: 'green'
    }
    def get_phase_for_frame(frame_idx: int) -> str:
        # Linear scan over the few phase intervals; [start, end) semantics.
        for phase_name, (start, end) in phases.items():
            if start <= frame_idx < end:
                return phase_name
        # Frames past the last interval are treated as follow-through.
        return ShotPhase.FOLLOW_THROUGH.value
    # Animation Setup
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')
    pose_connections = mp.solutions.pose.POSE_CONNECTIONS
    def update(frame_idx):
        # Redraw the whole scene each frame (simple but adequate at these sizes).
        ax.clear()
        keypoints = normalized_seq[frame_idx]
        current_phase = get_phase_for_frame(frame_idx)
        phase_color = phase_colors.get(current_phase, 'gray')
        # Plot Joints — note axes are remapped (data Z on plot Y, data Y on plot Z)
        # so the skeleton stands upright in the 3D view.
        ax.scatter(keypoints[:, 0], keypoints[:, 2], keypoints[:, 1],
                   c=phase_color, marker='o', s=30)
        # Plot Bones
        for conn in pose_connections:
            start, end = conn
            ax.plot([keypoints[start, 0], keypoints[end, 0]],
                    [keypoints[start, 2], keypoints[end, 2]],
                    [keypoints[start, 1], keypoints[end, 1]],
                    color=phase_color, linewidth=2)
        ax.set_xlabel('X')
        ax.set_ylabel('Z')
        ax.set_zlabel('Y')
        ax.set_xlim([-1, 1])
        ax.set_ylim([-1, 1])
        ax.set_zlim([-1, 1])
        ax.set_title(f"Frame {frame_idx}/{len(normalized_seq)} | Phase: {current_phase.upper()}")
        ax.view_init(elev=20, azim=-60)
    print("Starting Animation Window...")
    print(f"Detected phases: {phases}")
    interval_ms = max(1, int((1000.0 / fps) / max(speed, 1e-3)))
    # Keep a reference to the animation; FuncAnimation is garbage-collected
    # (and stops animating) if the return value is discarded.
    ani = FuncAnimation(fig, update, frames=len(normalized_seq), interval=interval_ms)
    plt.show()
def visualize_ksi_analysis(user_landmarks: np.ndarray,
                           expert_landmarks: np.ndarray,
                           fps: float = 30.0,
                           weights: Optional[Dict[str, float]] = None):
    """
    Visualize KSI analysis with phase scores, velocity profiles, and confidence.
    Uses the new ksi_v2 EnhancedKSI with contact-centered windowing.

    Args:
        user_landmarks: User pose sequence, shape [T, 33, 3].
        expert_landmarks: Expert pose sequence, shape [T, 33, 3].
        fps: Frame rate of both sequences.
        weights: Optional component weights; defaults to
            {'pose': 0.5, 'velocity': 0.3, 'acceleration': 0.2}.

    Returns:
        The EnhancedKSI result object (also rendered to
        'dvclive/ksi_analysis.png' and shown interactively).
    """
    if weights is None:
        weights = {'pose': 0.5, 'velocity': 0.3, 'acceleration': 0.2}
    # Initialize enhanced KSI calculator with contact-centered windowing
    ksi_calc = EnhancedKSI(
        fps=fps,
        contact_window_pre_frames=18,   # ±18 frames around contact
        contact_window_post_frames=18,
        bootstrap_min=50,
        bootstrap_max=200,
        ranking_margin=0.05
    )
    # Calculate KSI
    result = ksi_calc.calculate(expert_landmarks, user_landmarks, weights)
    # Create visualization: 2x3 grid of summary panels.
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    fig.suptitle(f"KSI Analysis | Score: {result.ksi_total:.3f} (Weighted: {result.ksi_weighted:.3f})",
                 fontsize=14, fontweight='bold')
    # 1. Component scores bar chart
    ax1 = axes[0, 0]
    components = ['pose', 'velocity', 'acceleration', 'jerk']
    scores = [result.components.get(c, 0) for c in components]
    colors = ['steelblue', 'coral', 'seagreen', 'orchid']
    ax1.bar(components, scores, color=colors)
    ax1.set_ylabel('Score')
    ax1.set_title('Component Scores')
    ax1.set_ylim([0, 1])
    for i, v in enumerate(scores):
        ax1.text(i, v + 0.02, f'{v:.2f}', ha='center', fontsize=9)
    # 2. Phase scores
    ax2 = axes[0, 1]
    if result.phase_scores:
        phase_names = list(result.phase_scores.keys())
        phase_vals = list(result.phase_scores.values())
        phase_colors = ['gray', 'blue', 'orange', 'red', 'green']
        ax2.barh(phase_names, phase_vals, color=phase_colors[:len(phase_names)])
        ax2.set_xlim([0, 1])
        ax2.set_xlabel('Score')
        ax2.set_title('Phase Scores')
    else:
        ax2.text(0.5, 0.5, 'No phase data', ha='center', va='center')
        ax2.set_title('Phase Scores')
    # 3. Velocity profile comparison
    ax3 = axes[0, 2]
    if result.velocity_analysis:
        # NOTE(review): assumes the profiles are list-like; an empty list/None is
        # treated as "no data" — confirm ksi_v2 never returns bare numpy arrays here.
        vel_expert = result.velocity_analysis.get('velocity_profile_expert', [])
        vel_user = result.velocity_analysis.get('velocity_profile_user', [])
        if vel_expert and vel_user:
            ax3.plot(vel_expert, label='Expert', color='green', linewidth=2)
            ax3.plot(vel_user, label='User', color='blue', linewidth=2, alpha=0.7)
            # Mark peaks
            expert_peak = result.velocity_analysis.get('expert_peak_frame', 0)
            user_peak = result.velocity_analysis.get('user_peak_frame', 0)
            if expert_peak < len(vel_expert):
                ax3.axvline(expert_peak, color='green', linestyle='--', alpha=0.5, label='Expert peak')
            if user_peak < len(vel_user):
                ax3.axvline(user_peak, color='blue', linestyle='--', alpha=0.5, label='User peak')
            ax3.legend(fontsize=8)
            ax3.set_xlabel('Frame')
            ax3.set_ylabel('Velocity')
            ax3.set_title(f"Velocity Profile (Timing offset: {result.velocity_analysis.get('timing_offset_ms', 0):.0f}ms)")
    else:
        ax3.text(0.5, 0.5, 'No velocity data', ha='center', va='center')
        ax3.set_title('Velocity Profile')
    # 4. Top joint errors
    ax4 = axes[1, 0]
    if result.per_joint_errors:
        # Sort by mean error, take top 8
        sorted_errors = sorted(
            result.per_joint_errors.items(),
            key=lambda x: x[1].mean_error,
            reverse=True
        )[:8]
        joint_names = [e[0].replace('_', '\n') for e in sorted_errors]
        mean_errors = [e[1].mean_error for e in sorted_errors]
        ci_lower = [e[1].confidence_interval[0] for e in sorted_errors]
        ci_upper = [e[1].confidence_interval[1] for e in sorted_errors]
        # Asymmetric error bars: distance from the mean down to the CI lower
        # bound and up to the CI upper bound.
        yerr = [[m - lo for m, lo in zip(mean_errors, ci_lower)],
                [hi - m for m, hi in zip(mean_errors, ci_upper)]]
        ax4.barh(joint_names, mean_errors, xerr=yerr, color='salmon', capsize=3)
        ax4.set_xlabel('Mean Error (degrees/ratio)')
        ax4.set_title('Top Joint Errors (with 95% CI)')
    else:
        ax4.text(0.5, 0.5, 'No error data', ha='center', va='center')
        ax4.set_title('Joint Errors')
    # 5. Confidence summary
    ax5 = axes[1, 1]
    if result.confidence:
        conf_items = [
            ('KSI Mean', result.confidence.get('mean', 0)),
            ('KSI Std', result.confidence.get('std', 0)),
            ('CI Lower', result.confidence.get('ci_95_lower', 0)),
            ('CI Upper', result.confidence.get('ci_95_upper', 0)),
            ('Uncertainty', result.confidence.get('uncertainty_scalar', 0)),
            ('Valid Ratio', result.confidence.get('valid_frame_ratio', 0)),
        ]
        labels = [c[0] for c in conf_items]
        values = [c[1] for c in conf_items]
        ax5.barh(labels, values, color='teal')
        ax5.set_title(f"Confidence (Reliable: {result.confidence.get('reliable', False)}, "
                      f"n_bootstrap: {result.confidence.get('n_bootstrap', 'N/A')})")
    else:
        ax5.text(0.5, 0.5, 'No confidence data', ha='center', va='center')
        ax5.set_title('Confidence')
    # 6. Recommendations
    ax6 = axes[1, 2]
    ax6.axis('off')
    rec_text = "RECOMMENDATIONS:\n\n"
    for i, rec in enumerate(result.recommendations[:5], 1):
        rec_text += f"{i}. {rec}\n\n"
    ax6.text(0.05, 0.95, rec_text, transform=ax6.transAxes, fontsize=9,
             verticalalignment='top', wrap=True,
             bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.8))
    ax6.set_title('Top Recommendations')
    plt.tight_layout()
    # BUGFIX: savefig raises FileNotFoundError if the output directory is
    # missing — create it first (idempotent).
    os.makedirs('dvclive', exist_ok=True)
    plt.savefig('dvclive/ksi_analysis.png', dpi=150)
    print("Saved KSI analysis to dvclive/ksi_analysis.png")
    plt.show()
    return result
| if __name__ == "__main__": | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument("--video", help="Path to video file") | |
| parser.add_argument("--mode", choices=['2d', '3d', 'both', 'ksi'], default='both') | |
| parser.add_argument("--speed", type=float, default=1.0, help="Playback speed (e.g., 0.5 for half speed)") | |
| parser.add_argument("--user_npz", help="User poses .npz for KSI analysis") | |
| parser.add_argument("--expert_npz", help="Expert poses .npz for KSI analysis") | |
| parser.add_argument("--gpu", action='store_true', help="Enable GPU acceleration (less deterministic)") | |
| args = parser.parse_args() | |
| params = load_config() | |
| crop_cfg = params['pose_pipeline']['crop_config'] | |
| crop_overrides = params.get('crop_overrides', {}) | |
| mp_cfg = params['mediapipe'] | |
| segment_rules = params.get('segment_rules', {}) | |
| if args.mode == 'ksi': | |
| if not args.user_npz or not args.expert_npz: | |
| print("KSI mode requires --user_npz and --expert_npz") | |
| else: | |
| user_data = np.load(args.user_npz) | |
| expert_data = np.load(args.expert_npz) | |
| user_lm = user_data['features'] if 'features' in user_data else user_data['arr_0'] | |
| expert_lm = expert_data['features'] if 'features' in expert_data else expert_data['arr_0'] | |
| # Reshape to [T, 33, 3] if needed | |
| if user_lm.ndim == 2 and user_lm.shape[1] == 99: | |
| user_lm = user_lm.reshape(-1, 33, 3) | |
| if expert_lm.ndim == 2 and expert_lm.shape[1] == 99: | |
| expert_lm = expert_lm.reshape(-1, 33, 3) | |
| visualize_ksi_analysis(user_lm, expert_lm, fps=params.get('fps', 30.0), | |
| weights=params.get('ksi', {}).get('weights')) | |
| else: | |
| if not args.video: | |
| print("Video mode requires --video") | |
| else: | |
| if args.mode in ['2d', 'both']: | |
| eff_crop_cfg = resolve_crop_config_for_video(args.video, crop_cfg, crop_overrides) | |
| visualize_2d(args.video, eff_crop_cfg, speed=args.speed, segment_rules=segment_rules) | |
| if args.mode in ['3d', 'both']: | |
| visualize_3d(args.video, eff_crop_cfg, mp_cfg, speed=args.speed, segment_rules=segment_rules) |