# src/research/visualize.py — smashfix-v1
# (branch: v1-try-deploy, revision 0d0412d)
# Updated visualize.py
# --- DETERMINISM FIXES (MUST BE BEFORE IMPORTS) ---
# These environment variables only take effect if they are set before
# TensorFlow / MediaPipe are imported further down.
import os
import sys

# Peek at argv directly: argparse only runs after the heavy imports below.
_use_gpu = '--gpu' in sys.argv
if not _use_gpu:
    # Force CPU mode for deterministic predictions
    os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
    os.environ['MEDIAPIPE_DISABLE_GPU'] = '1'
    print("🔒 Running in CPU mode for deterministic predictions (use --gpu to enable GPU)")

# Ask TensorFlow for reproducible kernels regardless of device choice.
for _var, _val in (('TF_DETERMINISTIC_OPS', '1'),
                   ('TF_CUDNN_DETERMINISTIC', '1'),
                   ('TF_ENABLE_ONEDNN_OPTS', '0')):
    os.environ[_var] = _val
import argparse
import cv2
import yaml
import numpy as np
import mediapipe as mp
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from mpl_toolkits.mplot3d import Axes3D
from typing import Optional, Dict
import os
# Local imports
from utils import normalize_sequence, should_skip_crop, get_segment_bounds, resolve_crop_config_for_video
from features import PoseFeatureExtractor
from ksi_v2 import (
EnhancedKSI,
ShotPhaseSegmenter,
extract_sequence_features,
FEATURE_NAMES,
ShotPhase
)
def load_config():
    """Load the pipeline parameters from params.yaml in the working directory."""
    with open("params.yaml") as fh:
        return yaml.safe_load(fh)
def visualize_2d(video_path: str, crop_config: Dict, speed: float = 1.0, segment_rules: Optional[Dict] = None):
    """2D skeleton overlay visualization.

    Plays the (optionally cropped) segment of the video in an OpenCV window
    with the MediaPipe pose skeleton drawn on top. Interactive: press 'q'
    in the window to quit.

    Args:
        video_path: Path to the input video file.
        crop_config: Dict with fractional 'top'/'bottom'/'left'/'right' crop margins.
        speed: Playback speed multiplier (e.g. 0.5 for half speed).
        segment_rules: Optional per-video segment config forwarded to get_segment_bounds().
    """
    print(f"\n--- 2D Visualization: {video_path} ---")
    print("Press 'q' to quit.")
    mp_pose = mp.solutions.pose
    pose = mp_pose.Pose(static_image_mode=False, model_complexity=1, min_detection_confidence=0.5)
    mp_drawing = mp.solutions.drawing_utils
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            print(f"Error: Could not open {video_path}")
            return
        fps = cap.get(cv2.CAP_PROP_FPS) or 30.0  # fall back to 30 when FPS metadata is 0/missing
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        start_frame, tail_frames = get_segment_bounds(video_path, fps, total_frames, default_seconds=1.75, segment_cfg=segment_rules)
        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
        base_delay_ms = max(1, int(1000.0 / fps))
        # Decide whether to skip cropping for files like 'name (1).mp4'
        filename = os.path.basename(video_path)
        skip_crop = should_skip_crop(filename)
        processed = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            processed += 1
            # Apply Crop (skip if filename matches the '(N)' pattern)
            if skip_crop:
                frame_cropped = frame
            else:
                h, w = frame.shape[:2]
                start_row = int(h * crop_config['top'])
                end_row = h - int(h * crop_config['bottom'])
                start_col = int(w * crop_config['left'])
                end_col = w - int(w * crop_config['right'])
                frame_cropped = frame[start_row:end_row, start_col:end_col]
            if frame_cropped.size == 0:
                # Degenerate crop (margins consumed the whole frame) — skip it.
                continue
            # Detect
            image_rgb = cv2.cvtColor(frame_cropped, cv2.COLOR_BGR2RGB)
            results = pose.process(image_rgb)
            # Draw
            if results.pose_landmarks:
                mp_drawing.draw_landmarks(
                    frame_cropped, results.pose_landmarks, mp_pose.POSE_CONNECTIONS,
                    mp_drawing.DrawingSpec(color=(245, 117, 66), thickness=2, circle_radius=2),
                    mp_drawing.DrawingSpec(color=(245, 66, 230), thickness=2, circle_radius=2)
                )
            cv2.imshow('Badminton 2D View (Cropped)', frame_cropped)
            # Scale the inter-frame delay by playback speed (clamped to >= 1 ms).
            delay = max(1, int(base_delay_ms / max(speed, 1e-3)))
            if cv2.waitKey(delay) & 0xFF == ord('q'):
                break
            if processed >= int(tail_frames):
                break
    finally:
        # BUGFIX: release resources on every exit path — the original leaked
        # the MediaPipe Pose object when the video failed to open, and left
        # the capture/windows open on a mid-loop exception.
        cap.release()
        cv2.destroyAllWindows()
        pose.close()
def visualize_3d(video_path: str, crop_config: Dict, mp_config: Dict, speed: float = 1.0, segment_rules: Optional[Dict] = None):
    """3D skeleton animation with phase annotations.

    Extracts MediaPipe landmarks from the configured video segment,
    normalizes the sequence, segments it into shot phases (ksi_v2), and
    plays a 3D matplotlib animation colored by the current phase.

    Args:
        video_path: Path to the input video file.
        crop_config: Dict with fractional 'top'/'bottom'/'left'/'right' crop margins.
        mp_config: MediaPipe configuration passed to PoseFeatureExtractor.
        speed: Playback speed multiplier.
        segment_rules: Optional per-video segment config forwarded to get_segment_bounds().
    """
    print(f"\n--- 3D Visualization: {video_path} ---")
    print("Extracting landmarks...")
    extractor = PoseFeatureExtractor(mp_config)
    cap = cv2.VideoCapture(video_path)
    # BUGFIX: guard against unopenable videos, matching visualize_2d —
    # previously this silently fell through and reported "No landmarks found."
    if not cap.isOpened():
        print(f"Error: Could not open {video_path}")
        return
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0  # fall back to 30 when FPS metadata is 0/missing
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    start_frame, tail_frames = get_segment_bounds(video_path, fps, total_frames, default_seconds=1.75, segment_cfg=segment_rules)
    cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
    # Decide whether to skip cropping for files like 'name (1).mp4'
    filename = os.path.basename(video_path)
    skip_crop = should_skip_crop(filename)
    raw_landmarks = []
    processed = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        processed += 1
        if not skip_crop:
            h, w = frame.shape[:2]
            frame = frame[
                int(h * crop_config['top']):h - int(h * crop_config['bottom']),
                int(w * crop_config['left']):w - int(w * crop_config['right'])
            ]
        if frame.size == 0:
            # Degenerate crop — skip this frame.
            continue
        res = extractor.pose.process(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        if res.pose_landmarks:
            # 33 landmarks x (x, y, z) in normalized image coordinates.
            lm = np.array([[l.x, l.y, l.z] for l in res.pose_landmarks.landmark])
            raw_landmarks.append(lm)
        if processed >= int(tail_frames):
            break
    cap.release()
    if not raw_landmarks:
        print("No landmarks found.")
        return
    print(f"Normalizing {len(raw_landmarks)} frames...")
    normalized_seq = normalize_sequence(raw_landmarks)  # [T, 33, 3]
    # Segment phases using ksi_v2
    segmenter = ShotPhaseSegmenter(fps=fps)
    phases = segmenter.segment(normalized_seq)
    # Phase colors for visualization
    phase_colors = {
        ShotPhase.PREPARATION.value: 'gray',
        ShotPhase.LOADING.value: 'blue',
        ShotPhase.ACCELERATION.value: 'orange',
        ShotPhase.CONTACT.value: 'red',
        ShotPhase.FOLLOW_THROUGH.value: 'green'
    }

    def get_phase_for_frame(frame_idx: int) -> str:
        # Phases are half-open [start, end) frame ranges; frames past every
        # segment are attributed to the follow-through.
        for phase_name, (start, end) in phases.items():
            if start <= frame_idx < end:
                return phase_name
        return ShotPhase.FOLLOW_THROUGH.value

    # Animation Setup
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')
    pose_connections = mp.solutions.pose.POSE_CONNECTIONS

    def update(frame_idx):
        ax.clear()
        keypoints = normalized_seq[frame_idx]
        current_phase = get_phase_for_frame(frame_idx)
        phase_color = phase_colors.get(current_phase, 'gray')
        # Plot Joints (Y and Z are swapped so the skeleton stands upright;
        # the axis labels below reflect the swap)
        ax.scatter(keypoints[:, 0], keypoints[:, 2], keypoints[:, 1],
                   c=phase_color, marker='o', s=30)
        # Plot Bones
        for conn in pose_connections:
            start, end = conn
            ax.plot([keypoints[start, 0], keypoints[end, 0]],
                    [keypoints[start, 2], keypoints[end, 2]],
                    [keypoints[start, 1], keypoints[end, 1]],
                    color=phase_color, linewidth=2)
        ax.set_xlabel('X')
        ax.set_ylabel('Z')
        ax.set_zlabel('Y')
        ax.set_xlim([-1, 1])
        ax.set_ylim([-1, 1])
        ax.set_zlim([-1, 1])
        ax.set_title(f"Frame {frame_idx}/{len(normalized_seq)} | Phase: {current_phase.upper()}")
        ax.view_init(elev=20, azim=-60)

    print("Starting Animation Window...")
    print(f"Detected phases: {phases}")
    interval_ms = max(1, int((1000.0 / fps) / max(speed, 1e-3)))
    # Keep a reference to the animation so it isn't garbage-collected mid-show.
    ani = FuncAnimation(fig, update, frames=len(normalized_seq), interval=interval_ms)
    plt.show()
def visualize_ksi_analysis(user_landmarks: np.ndarray,
                           expert_landmarks: np.ndarray,
                           fps: float = 30.0,
                           weights: Optional[Dict[str, float]] = None):
    """
    Visualize KSI analysis with phase scores, velocity profiles, and confidence.
    Uses the new ksi_v2 EnhancedKSI with contact-centered windowing.

    Args:
        user_landmarks: User pose sequence (expected [T, 33, 3] — reshaped by caller).
        expert_landmarks: Expert pose sequence, same layout as user_landmarks.
        fps: Frame rate used for timing computations inside EnhancedKSI.
        weights: Optional component weights; defaults to pose/velocity/acceleration.

    Returns:
        The result object produced by EnhancedKSI.calculate().
    """
    if weights is None:
        weights = {'pose': 0.5, 'velocity': 0.3, 'acceleration': 0.2}
    # Initialize enhanced KSI calculator with contact-centered windowing
    ksi_calc = EnhancedKSI(
        fps=fps,
        contact_window_pre_frames=18,  # ±18 frames around contact
        contact_window_post_frames=18,
        bootstrap_min=50,
        bootstrap_max=200,
        ranking_margin=0.05
    )
    # Calculate KSI
    result = ksi_calc.calculate(expert_landmarks, user_landmarks, weights)
    # Create visualization: 2x3 grid of summary panels
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    fig.suptitle(f"KSI Analysis | Score: {result.ksi_total:.3f} (Weighted: {result.ksi_weighted:.3f})",
                 fontsize=14, fontweight='bold')
    # 1. Component scores bar chart
    ax1 = axes[0, 0]
    components = ['pose', 'velocity', 'acceleration', 'jerk']
    scores = [result.components.get(c, 0) for c in components]
    colors = ['steelblue', 'coral', 'seagreen', 'orchid']
    ax1.bar(components, scores, color=colors)
    ax1.set_ylabel('Score')
    ax1.set_title('Component Scores')
    ax1.set_ylim([0, 1])
    for i, v in enumerate(scores):
        # Annotate each bar with its numeric value just above the bar top.
        ax1.text(i, v + 0.02, f'{v:.2f}', ha='center', fontsize=9)
    # 2. Phase scores
    ax2 = axes[0, 1]
    if result.phase_scores:
        phase_names = list(result.phase_scores.keys())
        phase_vals = list(result.phase_scores.values())
        phase_colors = ['gray', 'blue', 'orange', 'red', 'green']
        ax2.barh(phase_names, phase_vals, color=phase_colors[:len(phase_names)])
        ax2.set_xlim([0, 1])
        ax2.set_xlabel('Score')
        ax2.set_title('Phase Scores')
    else:
        ax2.text(0.5, 0.5, 'No phase data', ha='center', va='center')
        ax2.set_title('Phase Scores')
    # 3. Velocity profile comparison
    ax3 = axes[0, 2]
    if result.velocity_analysis:
        vel_expert = result.velocity_analysis.get('velocity_profile_expert', [])
        vel_user = result.velocity_analysis.get('velocity_profile_user', [])
        if vel_expert and vel_user:
            ax3.plot(vel_expert, label='Expert', color='green', linewidth=2)
            ax3.plot(vel_user, label='User', color='blue', linewidth=2, alpha=0.7)
            # Mark peaks
            expert_peak = result.velocity_analysis.get('expert_peak_frame', 0)
            user_peak = result.velocity_analysis.get('user_peak_frame', 0)
            if expert_peak < len(vel_expert):
                ax3.axvline(expert_peak, color='green', linestyle='--', alpha=0.5, label='Expert peak')
            if user_peak < len(vel_user):
                ax3.axvline(user_peak, color='blue', linestyle='--', alpha=0.5, label='User peak')
            ax3.legend(fontsize=8)
        ax3.set_xlabel('Frame')
        ax3.set_ylabel('Velocity')
        ax3.set_title(f"Velocity Profile (Timing offset: {result.velocity_analysis.get('timing_offset_ms', 0):.0f}ms)")
    else:
        ax3.text(0.5, 0.5, 'No velocity data', ha='center', va='center')
        ax3.set_title('Velocity Profile')
    # 4. Top joint errors
    ax4 = axes[1, 0]
    if result.per_joint_errors:
        # Sort by mean error, take top 8
        sorted_errors = sorted(
            result.per_joint_errors.items(),
            key=lambda x: x[1].mean_error,
            reverse=True
        )[:8]
        joint_names = [e[0].replace('_', '\n') for e in sorted_errors]
        mean_errors = [e[1].mean_error for e in sorted_errors]
        ci_lower = [e[1].confidence_interval[0] for e in sorted_errors]
        ci_upper = [e[1].confidence_interval[1] for e in sorted_errors]
        # Asymmetric error bars: distance from the mean to each CI bound.
        yerr = [[m - l for m, l in zip(mean_errors, ci_lower)],
                [u - m for m, u in zip(mean_errors, ci_upper)]]
        ax4.barh(joint_names, mean_errors, xerr=yerr, color='salmon', capsize=3)
        ax4.set_xlabel('Mean Error (degrees/ratio)')
        ax4.set_title('Top Joint Errors (with 95% CI)')
    else:
        ax4.text(0.5, 0.5, 'No error data', ha='center', va='center')
        ax4.set_title('Joint Errors')
    # 5. Confidence summary
    ax5 = axes[1, 1]
    if result.confidence:
        conf_items = [
            ('KSI Mean', result.confidence.get('mean', 0)),
            ('KSI Std', result.confidence.get('std', 0)),
            ('CI Lower', result.confidence.get('ci_95_lower', 0)),
            ('CI Upper', result.confidence.get('ci_95_upper', 0)),
            ('Uncertainty', result.confidence.get('uncertainty_scalar', 0)),
            ('Valid Ratio', result.confidence.get('valid_frame_ratio', 0)),
        ]
        labels = [c[0] for c in conf_items]
        values = [c[1] for c in conf_items]
        ax5.barh(labels, values, color='teal')
        ax5.set_title(f"Confidence (Reliable: {result.confidence.get('reliable', False)}, "
                      f"n_bootstrap: {result.confidence.get('n_bootstrap', 'N/A')})")
    else:
        ax5.text(0.5, 0.5, 'No confidence data', ha='center', va='center')
        ax5.set_title('Confidence')
    # 6. Recommendations
    ax6 = axes[1, 2]
    ax6.axis('off')
    rec_text = "RECOMMENDATIONS:\n\n"
    for i, rec in enumerate(result.recommendations[:5], 1):
        rec_text += f"{i}. {rec}\n\n"
    ax6.text(0.05, 0.95, rec_text, transform=ax6.transAxes, fontsize=9,
             verticalalignment='top', wrap=True,
             bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.8))
    ax6.set_title('Top Recommendations')
    plt.tight_layout()
    # BUGFIX: ensure the output directory exists — savefig raises
    # FileNotFoundError when 'dvclive/' has not been created yet.
    os.makedirs('dvclive', exist_ok=True)
    plt.savefig('dvclive/ksi_analysis.png', dpi=150)
    print("Saved KSI analysis to dvclive/ksi_analysis.png")
    plt.show()
    return result
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--video", help="Path to video file")
    parser.add_argument("--mode", choices=['2d', '3d', 'both', 'ksi'], default='both')
    parser.add_argument("--speed", type=float, default=1.0, help="Playback speed (e.g., 0.5 for half speed)")
    parser.add_argument("--user_npz", help="User poses .npz for KSI analysis")
    parser.add_argument("--expert_npz", help="Expert poses .npz for KSI analysis")
    parser.add_argument("--gpu", action='store_true', help="Enable GPU acceleration (less deterministic)")
    args = parser.parse_args()
    params = load_config()
    crop_cfg = params['pose_pipeline']['crop_config']
    crop_overrides = params.get('crop_overrides', {})
    mp_cfg = params['mediapipe']
    segment_rules = params.get('segment_rules', {})
    if args.mode == 'ksi':
        if not args.user_npz or not args.expert_npz:
            print("KSI mode requires --user_npz and --expert_npz")
        else:
            user_data = np.load(args.user_npz)
            expert_data = np.load(args.expert_npz)
            # Accept either a named 'features' array or the default 'arr_0'.
            user_lm = user_data['features'] if 'features' in user_data else user_data['arr_0']
            expert_lm = expert_data['features'] if 'features' in expert_data else expert_data['arr_0']
            # Reshape flat [T, 99] landmark arrays to [T, 33, 3] if needed
            if user_lm.ndim == 2 and user_lm.shape[1] == 99:
                user_lm = user_lm.reshape(-1, 33, 3)
            if expert_lm.ndim == 2 and expert_lm.shape[1] == 99:
                expert_lm = expert_lm.reshape(-1, 33, 3)
            visualize_ksi_analysis(user_lm, expert_lm, fps=params.get('fps', 30.0),
                                   weights=params.get('ksi', {}).get('weights'))
    else:
        if not args.video:
            print("Video mode requires --video")
        else:
            # BUGFIX: resolve the crop config once, before either branch.
            # Previously eff_crop_cfg was assigned only inside the 2d branch,
            # so running with '--mode 3d' alone raised NameError.
            eff_crop_cfg = resolve_crop_config_for_video(args.video, crop_cfg, crop_overrides)
            if args.mode in ['2d', 'both']:
                visualize_2d(args.video, eff_crop_cfg, speed=args.speed, segment_rules=segment_rules)
            if args.mode in ['3d', 'both']:
                visualize_3d(args.video, eff_crop_cfg, mp_cfg, speed=args.speed, segment_rules=segment_rules)