"""
A10 visualizer.py
=================

3D skeleton visualization utilities for Issue #42 / Sprint 10.

Designed to fit the current A10 sprint codebase where:
- PoseNet/MoveNet input is 13 joints x 2 coordinates = 26 features
- Kinect / one-step output is 13 joints x 3 coordinates = 39 features
- Joint order must match data_loader.KINECT_JOINTS

This module supports:
- Static 3D skeleton plots
- Side-by-side and overlay comparison plots
- Per-joint error coloring
- Multiple camera angles (front / side / top)
- Joint trajectory trails
- GIF / MP4 export with matplotlib animations
- Interactive HTML viewer with play/pause, frame slider, and speed buttons
- Saving prediction bundles so the visualizer can be called from training code

Typical use:

    from visualizer import (
        save_prediction_bundle,
        create_evaluation_visuals,
    )

    bundle_dir = save_prediction_bundle(
        output_dir='A10/prediction_runs/demo',
        predicted_xyz=pred_xyz,
        ground_truth_xyz=true_xyz,
        sequence_name='A1_kinect.csv',
        metadata={'model': 'Dense_shallow_adam'}
    )
    create_evaluation_visuals(bundle_dir)
"""
from __future__ import annotations

import json
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import cm
from matplotlib.animation import FFMpegWriter, FuncAnimation, PillowWriter
from matplotlib.colors import Normalize

# Plotly is optional but very useful for the interactive viewer.
try:
    import plotly.graph_objects as go
    PLOTLY_AVAILABLE = True
except Exception:
    # Any failure while importing Plotly simply disables the HTML viewer;
    # everything else in this module works with matplotlib alone.
    PLOTLY_AVAILABLE = False


# -----------------------------------------------------------------------------
# Skeleton definition
# -----------------------------------------------------------------------------

# Joint order of the 13-joint skeleton. Flat 39-value rows are read by
# reshape_xyz joint by joint in this order: [x0, y0, z0, x1, y1, z1, ...].
# Per the module notes this must match data_loader.KINECT_JOINTS.
JOINTS = [
    'head',
    'left_shoulder', 'left_elbow',
    'right_shoulder', 'right_elbow',
    'left_hand', 'right_hand',
    'left_hip', 'right_hip',
    'left_knee', 'right_knee',
    'left_foot', 'right_foot',
]

JOINT_INDEX = {name: idx for idx, name in enumerate(JOINTS)}

# Bone graph for the 13 Kinect joints used in the A10 codebase.
BONES = [
    ('head', 'left_shoulder'),
    ('head', 'right_shoulder'),
    ('left_shoulder', 'right_shoulder'),
    ('left_shoulder', 'left_elbow'),
    ('left_elbow', 'left_hand'),
    ('right_shoulder', 'right_elbow'),
    ('right_elbow', 'right_hand'),
    ('left_shoulder', 'left_hip'),
    ('right_shoulder', 'right_hip'),
    ('left_hip', 'right_hip'),
    ('left_hip', 'left_knee'),
    ('left_knee', 'left_foot'),
    ('right_hip', 'right_knee'),
    ('right_knee', 'right_foot'),
]

# Camera presets (elevation / azimuth in degrees) for matplotlib 3D axes.
VIEW_PRESETS = {
    'front': dict(elev=15, azim=-90),
    'side': dict(elev=15, azim=0),
    'top': dict(elev=90, azim=-90),
    'iso': dict(elev=20, azim=-55),
}


# -----------------------------------------------------------------------------
# Helpers
# -----------------------------------------------------------------------------

def _as_path(pathlike: Union[str, Path]) -> Path:
    """Return ``pathlike`` as a ``Path`` (no copy if it already is one)."""
    return pathlike if isinstance(pathlike, Path) else Path(pathlike)


def _ensure_dir(pathlike: Union[str, Path]) -> Path:
    """Create the directory (and parents) if missing and return it as a Path."""
    path = _as_path(pathlike)
    path.mkdir(parents=True, exist_ok=True)
    return path


def reshape_xyz(data: np.ndarray) -> np.ndarray:
    """
    Convert xyz data into shape (n_frames, 13, 3).

    Accepts:
    - (n_frames, 39)
    - (39,)
    - (n_frames, 13, 3)
    - (13, 3)

    Values are cast to float32. Flat 39-value rows are interpreted joint by
    joint in JOINTS order: [x0, y0, z0, x1, y1, z1, ...].

    Raises:
        ValueError: if the input has any other shape.
    """
    arr = np.asarray(data, dtype=np.float32)
    if arr.ndim == 1:
        if arr.shape[0] != 39:
            raise ValueError(f'1D xyz input must have 39 values, got {arr.shape[0]}')
        arr = arr.reshape(1, 13, 3)
    elif arr.ndim == 2:
        if arr.shape == (13, 3):
            arr = arr.reshape(1, 13, 3)
        elif arr.shape[1] == 39:
            arr = arr.reshape(arr.shape[0], 13, 3)
        else:
            raise ValueError(f'2D xyz input must be (n,39) or (13,3); got {arr.shape}')
    elif arr.ndim == 3:
        if arr.shape[1:] != (13, 3):
            raise ValueError(f'3D xyz input must be (n,13,3); got {arr.shape}')
    else:
        raise ValueError(f'Unsupported xyz input shape: {arr.shape}')
    return arr


def _aligned_frame_count(pred: np.ndarray, gt: Optional[np.ndarray]) -> int:
    """Number of frames shared by ``pred`` and ``gt`` (``len(pred)`` if no GT)."""
    return len(pred) if gt is None else min(len(pred), len(gt))


def _clip_frame_index(frame_idx: int, n_frames: int) -> int:
    """Clamp ``frame_idx`` into ``[0, n_frames - 1]``; reject empty sequences."""
    if n_frames <= 0:
        raise ValueError('No frames available to plot')
    return int(np.clip(frame_idx, 0, n_frames - 1))


def compute_joint_errors(pred_xyz: np.ndarray, gt_xyz: np.ndarray) -> np.ndarray:
    """
    Euclidean error per joint, shape (n_frames, 13).

    Both inputs go through reshape_xyz; if their frame counts differ, only the
    common leading frames are compared.
    """
    pred = reshape_xyz(pred_xyz)
    gt = reshape_xyz(gt_xyz)
    n = min(len(pred), len(gt))
    return np.linalg.norm(pred[:n] - gt[:n], axis=2)


def compute_frame_errors(pred_xyz: np.ndarray, gt_xyz: np.ndarray) -> np.ndarray:
    """Mean Euclidean joint error per frame, shape (n_frames,)."""
    return compute_joint_errors(pred_xyz, gt_xyz).mean(axis=1)


def infer_axis_limits(
    *arrays: np.ndarray,
    pad_ratio: float = 0.08,
) -> Tuple[Tuple[float, float], Tuple[float, float], Tuple[float, float]]:
    """
    Compute cubic (equal-span) x/y/z limits enclosing every point of ``arrays``.

    Each array is passed through reshape_xyz. The bounding box is padded by
    ``pad_ratio`` of its span per axis, then expanded to a cube around its
    center so all three axes share one scale.

    Returns:
        ``((xmin, xmax), (ymin, ymax), (zmin, zmax))``.
    """
    stacked = np.concatenate([reshape_xyz(a).reshape(-1, 3) for a in arrays], axis=0)
    mins = stacked.min(axis=0)
    maxs = stacked.max(axis=0)
    spans = np.maximum(maxs - mins, 1e-6)
    pads = spans * pad_ratio
    mins -= pads
    maxs += pads

    # Make a cubic box so the skeleton does not look distorted.
    center = (mins + maxs) / 2.0
    radius = max((maxs - mins).max() / 2.0, 1e-4)
    return (
        (center[0] - radius, center[0] + radius),
        (center[1] - radius, center[1] + radius),
        (center[2] - radius, center[2] + radius),
    )


def _bone_segments(points: np.ndarray) -> Iterable[Tuple[np.ndarray, np.ndarray]]:
    """Yield ``(start, end)`` joint coordinates for every bone in BONES."""
    for j1, j2 in BONES:
        yield points[JOINT_INDEX[j1]], points[JOINT_INDEX[j2]]


def save_prediction_bundle(
    output_dir: Union[str, Path],
    predicted_xyz: np.ndarray,
    ground_truth_xyz: Optional[np.ndarray] = None,
    sequence_name: Optional[str] = None,
    metadata: Optional[Dict] = None,
    posenet_xy: Optional[np.ndarray] = None,
) -> Path:
    """
    Save model outputs in a simple, reusable format for the visualizer.

    Writes into ``output_dir`` (created if missing):
    - ``predicted_xyz.npy``     (n_frames, 13, 3) float32
    - ``ground_truth_xyz.npy``  same shape, only if ground truth is given
    - ``posenet_xy.npy``        stored as given (float32), only if provided
    - ``metadata.json``         user metadata plus sequence_name, n_frames,
                                joints and bones

    When ground truth is given, prediction and ground truth are both truncated
    to their common length *before* saving, so the two arrays and
    ``metadata['n_frames']`` always agree.

    Returns:
        The bundle directory as a ``Path``.
    """
    out_dir = _ensure_dir(output_dir)
    pred = reshape_xyz(predicted_xyz)
    gt = reshape_xyz(ground_truth_xyz) if ground_truth_xyz is not None else None

    if gt is not None:
        n = min(len(pred), len(gt))
        pred, gt = pred[:n], gt[:n]

    np.save(out_dir / 'predicted_xyz.npy', pred)
    if gt is not None:
        np.save(out_dir / 'ground_truth_xyz.npy', gt)
    if posenet_xy is not None:
        np.save(out_dir / 'posenet_xy.npy', np.asarray(posenet_xy, dtype=np.float32))

    meta = dict(metadata or {})
    meta['sequence_name'] = sequence_name
    meta['n_frames'] = int(len(pred))
    meta['joints'] = JOINTS
    meta['bones'] = BONES
    with open(out_dir / 'metadata.json', 'w', encoding='utf-8') as f:
        json.dump(meta, f, indent=2)
    return out_dir


def load_prediction_bundle(bundle_dir: Union[str, Path]) -> Dict[str, object]:
    """
    Load a bundle written by save_prediction_bundle.

    Returns a dict with keys:
    - ``'predicted_xyz'``     (n_frames, 13, 3) array
    - ``'ground_truth_xyz'``  (n_frames, 13, 3) array or None if absent
    - ``'posenet_xy'``        array as stored, or None if absent
    - ``'metadata'``          parsed metadata.json, or {} if absent
    - ``'bundle_dir'``        the bundle directory as a Path
    """
    bundle = _as_path(bundle_dir)
    pred = np.load(bundle / 'predicted_xyz.npy')
    gt_path = bundle / 'ground_truth_xyz.npy'
    xy_path = bundle / 'posenet_xy.npy'
    meta_path = bundle / 'metadata.json'

    gt = np.load(gt_path) if gt_path.exists() else None
    posenet_xy = np.load(xy_path) if xy_path.exists() else None
    metadata = {}
    if meta_path.exists():
        with open(meta_path, 'r', encoding='utf-8') as f:
            metadata = json.load(f)

    return {
        'predicted_xyz': reshape_xyz(pred),
        'ground_truth_xyz': reshape_xyz(gt) if gt is not None else None,
        'posenet_xy': posenet_xy,
        'metadata': metadata,
        'bundle_dir': bundle,
    }


# -----------------------------------------------------------------------------
# Drawing
# -----------------------------------------------------------------------------

def _draw_skeleton(
    ax,
    points: np.ndarray,
    title: Optional[str] = None,
    joint_errors: Optional[np.ndarray] = None,
    cmap_name: str = 'turbo',
    error_norm: Optional[Normalize] = None,
    bone_color: str = 'black',
    marker_size: int = 36,
    alpha: float = 1.0,
    show_labels: bool = False,
    trails: Optional[np.ndarray] = None,
):
    """
    Draw one skeleton (bones, joints, optional trails/labels) onto a 3D axes.

    Args:
        ax: a matplotlib 3D axes.
        points: (13, 3) joint coordinates in JOINTS order.
        title: optional axes title.
        joint_errors: optional per-joint values used to color the joints via
            ``cmap_name``; without it every joint is drawn 'tab:blue'.
        error_norm: normalization for ``joint_errors``; defaults to the
            min/max of ``joint_errors`` itself.
        trails: optional (n_trail_frames, 13, 3) history; each joint's path is
            drawn as a faint line when it has more than one frame.
    """
    points = np.asarray(points, dtype=np.float32)

    if joint_errors is None:
        joint_colors = ['tab:blue'] * len(points)
    else:
        # pyplot.get_cmap is used because matplotlib.cm.get_cmap was removed
        # in Matplotlib 3.9.
        cmap = plt.get_cmap(cmap_name)
        if error_norm is None:
            error_norm = Normalize(
                vmin=float(np.min(joint_errors)),
                vmax=float(np.max(joint_errors)) + 1e-9,
            )
        joint_colors = [cmap(error_norm(v)) for v in joint_errors]

    # Trails first so bones and joints are drawn on top of them.
    if trails is not None and len(trails) > 1:
        for j in range(points.shape[0]):
            trail = trails[:, j, :]
            ax.plot(trail[:, 0], trail[:, 1], trail[:, 2], alpha=0.25, linewidth=1.0)

    for p1, p2 in _bone_segments(points):
        ax.plot(
            [p1[0], p2[0]], [p1[1], p2[1]], [p1[2], p2[2]],
            color=bone_color, linewidth=2, alpha=alpha,
        )

    ax.scatter(points[:, 0], points[:, 1], points[:, 2],
               c=joint_colors, s=marker_size, alpha=alpha)

    if show_labels:
        for idx, name in enumerate(JOINTS):
            x, y, z = points[idx]
            ax.text(x, y, z, name, fontsize=8)

    if title:
        ax.set_title(title)


def _format_axes(ax, axis_limits, view_name='iso'):
    """
    Apply limits, axis labels, a 1:1:1 box aspect and a camera preset.

    Unknown ``view_name`` values fall back to the 'iso' preset.
    """
    (xlim, ylim, zlim) = axis_limits
    ax.set_xlim(*xlim)
    ax.set_ylim(*ylim)
    ax.set_zlim(*zlim)
    # Axes3D defaults to a 4:4:3 box; with the cubic limits from
    # infer_axis_limits that would still squash Z, so force equal scaling.
    ax.set_box_aspect((1, 1, 1))
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    view = VIEW_PRESETS.get(view_name, VIEW_PRESETS['iso'])
    ax.view_init(elev=view['elev'], azim=view['azim'])


def _save_figure(fig, save_path: Optional[Union[str, Path]]) -> None:
    """Save ``fig`` at 160 dpi to ``save_path`` (creating parents) if given."""
    if save_path is None:
        return
    path = _as_path(save_path)
    path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(path, dpi=160, bbox_inches='tight')


def plot_frame_comparison(
    predicted_xyz: np.ndarray,
    ground_truth_xyz: Optional[np.ndarray] = None,
    frame_idx: int = 0,
    save_path: Optional[Union[str, Path]] = None,
    show_labels: bool = False,
    overlay: bool = True,
) -> plt.Figure:
    """
    Create a report-friendly static figure for a single frame.

    Layout with ground truth (2x2):
    - GT skeleton
    - Pred skeleton
    - Overlay (or prediction only when ``overlay=False``)
    - Error heatmap (joints colored by Euclidean error)

    Without ground truth a single prediction panel is drawn. ``frame_idx`` is
    clamped to the frames shared by prediction and ground truth.

    Raises:
        ValueError: if there are no frames to plot.
    """
    pred = reshape_xyz(predicted_xyz)
    gt = reshape_xyz(ground_truth_xyz) if ground_truth_xyz is not None else None
    frame_idx = _clip_frame_index(frame_idx, _aligned_frame_count(pred, gt))
    pred_f = pred[frame_idx]
    gt_f = gt[frame_idx] if gt is not None else None

    if gt_f is None:
        fig = plt.figure(figsize=(6, 6))
        ax = fig.add_subplot(111, projection='3d')
        _draw_skeleton(ax, pred_f, title=f'Predicted skeleton — frame {frame_idx}',
                       show_labels=show_labels)
        _format_axes(ax, infer_axis_limits(pred_f), 'iso')
    else:
        joint_errors = np.linalg.norm(pred_f - gt_f, axis=1)
        error_norm = Normalize(vmin=0.0, vmax=max(float(joint_errors.max()), 1e-6))
        axis_limits = infer_axis_limits(pred_f, gt_f)

        fig = plt.figure(figsize=(14, 10))
        ax_gt, ax_pred, ax_mid, ax_err = [
            fig.add_subplot(2, 2, i + 1, projection='3d') for i in range(4)
        ]

        _draw_skeleton(ax_gt, gt_f, title='Ground truth', bone_color='tab:green',
                       show_labels=show_labels)
        _draw_skeleton(ax_pred, pred_f, title='Prediction', bone_color='tab:blue',
                       show_labels=show_labels)
        if overlay:
            _draw_skeleton(ax_mid, gt_f, title='Overlay', bone_color='tab:green', alpha=0.65)
            _draw_skeleton(ax_mid, pred_f, bone_color='tab:blue', alpha=0.65)
        else:
            _draw_skeleton(ax_mid, pred_f, title='Prediction')
        _draw_skeleton(
            ax_err,
            pred_f,
            title=f'Error heatmap (mean={joint_errors.mean():.4f})',
            joint_errors=joint_errors,
            error_norm=error_norm,
            bone_color='gray',
            show_labels=show_labels,
        )

        for ax, view in zip([ax_gt, ax_pred, ax_mid, ax_err], ['iso', 'iso', 'front', 'iso']):
            _format_axes(ax, axis_limits, view)
        fig.suptitle(f'3D Skeleton comparison — frame {frame_idx}', fontsize=14)

    fig.tight_layout()
    _save_figure(fig, save_path)
    return fig


def plot_multiview_frame(
    predicted_xyz: np.ndarray,
    ground_truth_xyz: Optional[np.ndarray] = None,
    frame_idx: int = 0,
    save_path: Optional[Union[str, Path]] = None,
    trails: int = 0,
) -> plt.Figure:
    """
    Plot one frame from four cameras: front, side, top and an iso overlay.

    Joints are colored by per-joint error when ground truth is available; the
    iso panel also overlays the ground-truth skeleton in green. ``trails > 0``
    draws the prediction's previous ``trails`` frames as joint trails.
    ``frame_idx`` is clamped to the frames shared by prediction and GT.

    Raises:
        ValueError: if there are no frames to plot.
    """
    pred = reshape_xyz(predicted_xyz)
    gt = reshape_xyz(ground_truth_xyz) if ground_truth_xyz is not None else None
    frame_idx = _clip_frame_index(frame_idx, _aligned_frame_count(pred, gt))
    pred_f = pred[frame_idx]
    gt_f = gt[frame_idx] if gt is not None else None

    trail_arr = pred[max(0, frame_idx - trails):frame_idx + 1] if trails > 0 else None
    axis_limits = infer_axis_limits(pred, gt) if gt is not None else infer_axis_limits(pred)

    if gt_f is not None:
        joint_errors = np.linalg.norm(pred_f - gt_f, axis=1)
        error_norm = Normalize(vmin=0.0, vmax=max(float(joint_errors.max()), 1e-6))
    else:
        joint_errors = None
        error_norm = None

    fig = plt.figure(figsize=(16, 4.5))
    axes = [fig.add_subplot(1, 4, i + 1, projection='3d') for i in range(4)]
    titles = ['Front', 'Side', 'Top', 'Overlay / iso']
    views = ['front', 'side', 'top', 'iso']

    for ax, title, view in zip(axes, titles, views):
        if gt_f is not None and view == 'iso':
            _draw_skeleton(ax, gt_f, bone_color='tab:green', alpha=0.6)
            _draw_skeleton(ax, pred_f, joint_errors=joint_errors, error_norm=error_norm,
                           bone_color='gray', alpha=0.85, trails=trail_arr)
        else:
            _draw_skeleton(ax, pred_f, joint_errors=joint_errors, error_norm=error_norm,
                           bone_color='gray', trails=trail_arr)
        ax.set_title(title)
        _format_axes(ax, axis_limits, view)

    fig.suptitle(f'Multiview 3D skeleton — frame {frame_idx}', fontsize=14)
    fig.tight_layout()
    _save_figure(fig, save_path)
    return fig


# -----------------------------------------------------------------------------
# Animation export
# -----------------------------------------------------------------------------

def animate_skeletons_matplotlib(
    predicted_xyz: np.ndarray,
    ground_truth_xyz: Optional[np.ndarray] = None,
    save_path: Union[str, Path] = 'animation.gif',
    fps: int = 15,
    dpi: int = 120,
    show_labels: bool = False,
    trail_length: int = 10,
    view_name: str = 'iso',
):
    """
    Save a GIF or MP4 animation of the skeleton sequence.

    With ground truth, a 3-panel layout is used: GT, prediction (joints colored
    by error) and an overlay with trails. Without it, only the prediction panel
    is shown. Sequences are truncated to their common length.

    ``.gif`` uses PillowWriter; ``.mp4`` / ``.m4v`` use FFMpegWriter (requires
    an ffmpeg executable). The figure is always closed, even if saving fails.

    Returns:
        The output path as a ``Path``.

    Raises:
        ValueError: for an unsupported file suffix or an empty sequence.
    """
    save_path = _as_path(save_path)
    suffix = save_path.suffix.lower()
    # Validate before doing any work so a bad path does not leave a figure open.
    if suffix not in {'.gif', '.mp4', '.m4v'}:
        raise ValueError('save_path must end with .gif or .mp4')

    pred = reshape_xyz(predicted_xyz)
    gt = reshape_xyz(ground_truth_xyz) if ground_truth_xyz is not None else None
    n_frames = _aligned_frame_count(pred, gt)
    if n_frames == 0:
        raise ValueError('No frames available to animate')
    pred = pred[:n_frames]
    if gt is not None:
        gt = gt[:n_frames]
        all_joint_errors = np.linalg.norm(pred - gt, axis=2)
        # One normalization for the whole clip so colors are comparable across frames.
        error_norm = Normalize(vmin=0.0, vmax=max(float(all_joint_errors.max()), 1e-6))
    else:
        all_joint_errors = None
        error_norm = None

    axis_limits = infer_axis_limits(pred, gt) if gt is not None else infer_axis_limits(pred)

    fig = plt.figure(figsize=(15, 5))
    try:
        axes = [fig.add_subplot(1, 3, i + 1, projection='3d') for i in range(3)]

        def update(frame_idx):
            for ax in axes:
                ax.cla()

            pred_f = pred[frame_idx]
            trail = pred[max(0, frame_idx - trail_length):frame_idx + 1] if trail_length > 0 else None

            if gt is not None:
                gt_f = gt[frame_idx]
                joint_errors = all_joint_errors[frame_idx]
                _draw_skeleton(axes[0], gt_f, title='Ground truth', bone_color='tab:green',
                               show_labels=show_labels)
                _draw_skeleton(axes[1], pred_f, title='Prediction', joint_errors=joint_errors,
                               error_norm=error_norm, bone_color='gray', show_labels=show_labels)
                _draw_skeleton(axes[2], gt_f, title=f'Overlay — frame {frame_idx}',
                               bone_color='tab:green', alpha=0.5)
                _draw_skeleton(axes[2], pred_f, joint_errors=joint_errors, error_norm=error_norm,
                               bone_color='gray', alpha=0.9, trails=trail)
            else:
                _draw_skeleton(axes[0], pred_f, title=f'Prediction — frame {frame_idx}',
                               bone_color='tab:blue', show_labels=show_labels, trails=trail)
                axes[1].set_visible(False)
                axes[2].set_visible(False)

            for ax in axes:
                if ax.get_visible():
                    _format_axes(ax, axis_limits, view_name)
            fig.suptitle(f'3D skeleton animation — frame {frame_idx + 1}/{n_frames}', fontsize=14)
            return axes

        anim = FuncAnimation(fig, update, frames=n_frames,
                             interval=int(1000 / max(fps, 1)), blit=False)

        save_path.parent.mkdir(parents=True, exist_ok=True)
        writer = PillowWriter(fps=fps) if suffix == '.gif' else FFMpegWriter(fps=fps)
        anim.save(save_path, writer=writer, dpi=dpi)
    finally:
        plt.close(fig)
    return save_path


# -----------------------------------------------------------------------------
# Plotly interactive viewer
# -----------------------------------------------------------------------------

def _scatter3d_points(points, name, color, size=5, text=None):
    """Plotly marker trace for the joints (with text labels when ``text`` is given)."""
    return go.Scatter3d(
        x=points[:, 0],
        y=points[:, 1],
        z=points[:, 2],
        mode='markers+text' if text is not None else 'markers',
        marker=dict(size=size, color=color),
        text=text,
        textposition='top center',
        name=name,
    )


def _scatter3d_bones(points, name, color, width=5):
    """Plotly line trace for all bones; ``None`` entries break the polyline between bones."""
    xs, ys, zs = [], [], []
    for p1, p2 in _bone_segments(points):
        xs.extend([p1[0], p2[0], None])
        ys.extend([p1[1], p2[1], None])
        zs.extend([p1[2], p2[2], None])
    return go.Scatter3d(x=xs, y=ys, z=zs, mode='lines',
                        line=dict(color=color, width=width), name=name)


def export_interactive_html(
    predicted_xyz: np.ndarray,
    ground_truth_xyz: Optional[np.ndarray] = None,
    html_path: Union[str, Path] = 'viewer.html',
    show_labels: bool = False,
):
    """
    Export an interactive Plotly viewer with:
    - play / pause
    - frame stepping via slider
    - speed buttons (slow / normal / fast)
    - optional GT overlay toggle via legend

    The HTML loads plotly.js from the CDN, so viewing it needs network access.

    Returns:
        The HTML path as a ``Path``.

    Raises:
        RuntimeError: if Plotly is not installed.
        ValueError: if there are no frames to export.
    """
    if not PLOTLY_AVAILABLE:
        raise RuntimeError('Plotly is not installed. Run: pip install plotly')

    pred = reshape_xyz(predicted_xyz)
    gt = reshape_xyz(ground_truth_xyz) if ground_truth_xyz is not None else None
    n_frames = _aligned_frame_count(pred, gt)
    if n_frames == 0:
        raise ValueError('No frames available to export')
    pred = pred[:n_frames]
    if gt is not None:
        gt = gt[:n_frames]
        err_mean = np.linalg.norm(pred - gt, axis=2).mean(axis=1)
    else:
        err_mean = None

    (xlim, ylim, zlim) = infer_axis_limits(pred, gt) if gt is not None else infer_axis_limits(pred)
    text = JOINTS if show_labels else None

    def frame_data(i):
        """Traces for frame ``i``: prediction, plus ground truth when available."""
        pred_f = pred[i]
        traces = [
            _scatter3d_bones(pred_f, 'Prediction bones', 'royalblue'),
            _scatter3d_points(pred_f, 'Prediction joints', 'royalblue', text=text),
        ]
        if gt is not None:
            gt_f = gt[i]
            traces += [
                _scatter3d_bones(gt_f, 'Ground truth bones', 'green'),
                _scatter3d_points(gt_f, 'Ground truth joints', 'green', text=text),
            ]
        return traces

    def frame_title(i):
        if err_mean is None:
            return f'Frame {i}'
        return f'Frame {i} | mean error={err_mean[i]:.4f}'

    def play_args(duration_ms):
        # Zero transition so the per-frame duration alone sets playback speed.
        return [None, {'frame': {'duration': duration_ms, 'redraw': True},
                       'transition': {'duration': 0},
                       'fromcurrent': True}]

    frames = [
        go.Frame(data=frame_data(i), name=str(i), layout=go.Layout(title_text=frame_title(i)))
        for i in range(n_frames)
    ]
    fig = go.Figure(data=frame_data(0), frames=frames)

    fig.update_layout(
        title='Interactive 3D skeleton viewer',
        scene=dict(
            xaxis=dict(range=list(xlim), title='X'),
            yaxis=dict(range=list(ylim), title='Y'),
            zaxis=dict(range=list(zlim), title='Z'),
            aspectmode='cube',
            camera=dict(eye=dict(x=1.3, y=1.3, z=0.8)),
        ),
        updatemenus=[
            dict(
                type='buttons',
                direction='left',
                x=0.0,
                y=1.15,
                buttons=[
                    dict(label='Play', method='animate', args=play_args(80)),
                    dict(label='Pause', method='animate',
                         args=[[None], {'frame': {'duration': 0, 'redraw': False},
                                        'transition': {'duration': 0},
                                        'mode': 'immediate'}]),
                    dict(label='Slow', method='animate', args=play_args(180)),
                    dict(label='Normal', method='animate', args=play_args(80)),
                    dict(label='Fast', method='animate', args=play_args(30)),
                ],
            )
        ],
        sliders=[{
            'pad': {'b': 10, 't': 35},
            'len': 0.95,
            'x': 0.03,
            'y': 0.0,
            'steps': [
                {
                    'args': [[str(i)], {'frame': {'duration': 0, 'redraw': True},
                                        'mode': 'immediate'}],
                    'label': str(i),
                    'method': 'animate',
                }
                for i in range(n_frames)
            ],
        }],
        showlegend=True,
    )

    html_path = _as_path(html_path)
    html_path.parent.mkdir(parents=True, exist_ok=True)
    fig.write_html(str(html_path), include_plotlyjs='cdn')
    return html_path


# -----------------------------------------------------------------------------
# High-level workflow helpers
# -----------------------------------------------------------------------------

def create_evaluation_visuals(
    bundle_dir: Union[str, Path],
    frame_indices: Optional[Sequence[int]] = None,
    export_gif: bool = True,
    export_mp4: bool = False,
    export_html: bool = True,
    fps: int = 15,
    trail_length: int = 10,
) -> Dict[str, List[str]]:
    """
    Generate all standard outputs for a saved prediction bundle into:
    - bundle_dir/skeleton_plots/   comparison and multiview PNGs per frame
    - bundle_dir/animations/       GIF / MP4 / interactive HTML

    ``frame_indices`` defaults to the first, middle and last frame; indices
    are clamped to the available frames. A ``visualization_summary.json`` is
    written into ``bundle_dir``.

    Only figures created here are closed, so the caller's open figures are
    left alone. When ``export_html`` is requested but Plotly is not installed,
    the HTML viewer is skipped with a RuntimeWarning instead of aborting.

    Returns:
        Dict with lists of output paths under 'plots', 'animations' and
        'interactive'.

    Raises:
        ValueError: if the bundle contains no frames.
    """
    bundle_root = _as_path(bundle_dir)
    bundle = load_prediction_bundle(bundle_root)
    pred = bundle['predicted_xyz']
    gt = bundle['ground_truth_xyz']
    n_frames = _aligned_frame_count(pred, gt)
    if n_frames == 0:
        raise ValueError(f'Bundle {bundle_root} contains no frames')

    plot_dir = _ensure_dir(bundle_root / 'skeleton_plots')
    anim_dir = _ensure_dir(bundle_root / 'animations')
    outputs = {'plots': [], 'animations': [], 'interactive': []}

    if frame_indices is None:
        frame_indices = sorted({0, n_frames // 2, n_frames - 1})

    for requested_idx in frame_indices:
        frame_idx = int(np.clip(requested_idx, 0, n_frames - 1))
        static_path = plot_dir / f'comparison_frame_{frame_idx:04d}.png'
        multiview_path = plot_dir / f'multiview_frame_{frame_idx:04d}.png'

        fig = plot_frame_comparison(pred, gt, frame_idx=frame_idx, save_path=static_path)
        plt.close(fig)
        fig = plot_multiview_frame(pred, gt, frame_idx=frame_idx,
                                   save_path=multiview_path, trails=trail_length)
        plt.close(fig)
        outputs['plots'] += [str(static_path), str(multiview_path)]

    if export_gif:
        gif_path = anim_dir / 'comparison_animation.gif'
        animate_skeletons_matplotlib(pred, gt, gif_path, fps=fps, trail_length=trail_length)
        outputs['animations'].append(str(gif_path))

    if export_mp4:
        mp4_path = anim_dir / 'comparison_animation.mp4'
        animate_skeletons_matplotlib(pred, gt, mp4_path, fps=fps, trail_length=trail_length)
        outputs['animations'].append(str(mp4_path))

    if export_html:
        if PLOTLY_AVAILABLE:
            html_path = anim_dir / 'interactive_viewer.html'
            export_interactive_html(pred, gt, html_path)
            outputs['interactive'].append(str(html_path))
        else:
            warnings.warn(
                'Plotly is not installed; skipping the interactive HTML viewer. '
                'Run: pip install plotly',
                RuntimeWarning,
                stacklevel=2,
            )

    summary = {
        'bundle_dir': str(bundle_dir),
        'n_frames': int(n_frames),
        'has_ground_truth': gt is not None,
        'outputs': outputs,
    }
    with open(bundle_root / 'visualization_summary.json', 'w', encoding='utf-8') as f:
        json.dump(summary, f, indent=2)

    return outputs


def save_prediction_bundle_from_model(
    model,
    X_input: np.ndarray,
    y_true_xyz: Optional[np.ndarray],
    output_dir: Union[str, Path],
    output_scaler=None,
    sequence_name: Optional[str] = None,
    metadata: Optional[Dict] = None,
):
    """
    Convenience helper for training code.

    - calls ``model.predict(X_input, verbose=0)`` (presumably a Keras-style
      model — confirm against callers)
    - optionally applies ``output_scaler.inverse_transform``
    - saves a prediction bundle via save_prediction_bundle

    Note: when ``output_scaler`` is given it is applied to BOTH the predictions
    and ``y_true_xyz``, so ``y_true_xyz`` must be passed in the same scaled
    space as the model outputs.

    Returns:
        The bundle directory as a ``Path``.
    """
    pred = model.predict(X_input, verbose=0)
    if output_scaler is not None:
        pred = output_scaler.inverse_transform(pred)
        if y_true_xyz is not None:
            y_true_xyz = output_scaler.inverse_transform(y_true_xyz)

    return save_prediction_bundle(
        output_dir=output_dir,
        predicted_xyz=pred,
        ground_truth_xyz=y_true_xyz,
        sequence_name=sequence_name,
        metadata=metadata,
    )


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='A10 3D skeleton visualizer')
    parser.add_argument('--bundle_dir', type=str,
                        help='Folder containing predicted_xyz.npy and optional ground_truth_xyz.npy')
    parser.add_argument('--pred_npy', type=str, help='Path to predicted xyz .npy file')
    parser.add_argument('--gt_npy', type=str, default=None, help='Path to ground-truth xyz .npy file')
    parser.add_argument('--out_dir', type=str, default='visualizer_outputs',
                        help='Output directory when using --pred_npy/--gt_npy')
    parser.add_argument('--fps', type=int, default=15)
    parser.add_argument('--no_html', action='store_true')
    parser.add_argument('--mp4', action='store_true')
    args = parser.parse_args()

    if args.bundle_dir:
        create_evaluation_visuals(
            bundle_dir=args.bundle_dir,
            export_gif=True,
            export_mp4=args.mp4,
            export_html=not args.no_html,
            fps=args.fps,
        )
        print(f'Visualization outputs created in {args.bundle_dir}')
    elif args.pred_npy:
        pred = np.load(args.pred_npy)
        gt = np.load(args.gt_npy) if args.gt_npy else None
        bundle = save_prediction_bundle(args.out_dir, pred, gt)
        create_evaluation_visuals(
            bundle_dir=bundle,
            export_gif=True,
            export_mp4=args.mp4,
            export_html=not args.no_html,
            fps=args.fps,
        )
        print(f'Visualization outputs created in {bundle}')
    else:
        parser.error('Provide either --bundle_dir or --pred_npy')