Raffael-Kultyshev's picture
Fix Python 3.13 compatibility - use simpler gradio version
0d8a870
#!/usr/bin/env python3
"""
DI Trajectory Visualizer - Simple version without audio dependencies
"""
import json
import numpy as np
from pathlib import Path
from dataclasses import dataclass
from typing import List
# Use basic imports only
try:
import gradio as gr
GRADIO_AVAILABLE = True
except ImportError:
GRADIO_AVAILABLE = False
import plotly.graph_objects as go
from plotly.subplots import make_subplots
@dataclass
class TrajectoryData:
    """Bundle of per-frame trajectory channels for one recording.

    Every field is annotated as a NumPy array. The plotting helpers in this
    file draw each channel against ``timestamps``, so the arrays are expected
    to share one length (the dataclass itself does not enforce this).
    """

    # Time axis used as the x-coordinate of every time-series subplot.
    timestamps: np.ndarray

    # Camera position, one array per axis.
    camera_x: np.ndarray
    camera_y: np.ndarray
    camera_z: np.ndarray

    # Left-hand position, one array per axis.
    left_hand_x: np.ndarray
    left_hand_y: np.ndarray
    left_hand_z: np.ndarray

    # Right-hand position, one array per axis.
    right_hand_x: np.ndarray
    right_hand_y: np.ndarray
    right_hand_z: np.ndarray

    # Left-hand orientation angles.
    left_hand_roll: np.ndarray
    left_hand_pitch: np.ndarray
    left_hand_yaw: np.ndarray

    # Right-hand orientation angles.
    right_hand_roll: np.ndarray
    right_hand_pitch: np.ndarray
    right_hand_yaw: np.ndarray
def create_sample_data() -> TrajectoryData:
    """Generate a synthetic trajectory used when no recording is loaded.

    Returns:
        A ``TrajectoryData`` holding 300 samples spread evenly over the
        interval [0, 10], with smooth sinusoidal camera, hand-position and
        hand-orientation channels.
    """
    frame_count = 300  # 10 seconds at 30fps
    t = np.linspace(0, 10, frame_count)

    # The camera path is computed first because both hands are offsets of it.
    cam_x = np.sin(t * 0.5) * 0.5
    cam_y = np.cos(t * 0.3) * 0.3
    cam_z = t * 0.1

    channels = {
        "timestamps": t,
        "camera_x": cam_x,
        "camera_y": cam_y,
        "camera_z": cam_z,
        # Left hand: oscillates around the camera.
        "left_hand_x": cam_x + np.sin(t * 2) * 0.3 + 0.2,
        "left_hand_y": cam_y + np.cos(t * 2) * 0.2 - 0.3,
        "left_hand_z": cam_z + np.sin(t) * 0.1,
        # Right hand: same motion shifted by pi in phase.
        "right_hand_x": cam_x + np.sin(t * 2 + np.pi) * 0.3 - 0.2,
        "right_hand_y": cam_y + np.cos(t * 2 + np.pi) * 0.2 - 0.3,
        "right_hand_z": cam_z + np.cos(t) * 0.1,
        # Orientations: the right hand lags the left by a phase of 1.
        "left_hand_roll": np.sin(t * 3) * 30,
        "left_hand_pitch": np.cos(t * 2) * 20,
        "left_hand_yaw": np.sin(t * 1.5) * 45,
        "right_hand_roll": np.sin(t * 3 + 1) * 30,
        "right_hand_pitch": np.cos(t * 2 + 1) * 20,
        "right_hand_yaw": np.sin(t * 1.5 + 1) * 45,
    }
    return TrajectoryData(**channels)
def _load_first_json(candidates, default):
    """Return the parsed JSON of the first existing path in *candidates*, else *default*."""
    for candidate in candidates:
        if candidate.exists():
            with open(candidate, encoding="utf-8") as fh:
                return json.load(fh)
    return default


def _xyz_columns(frames, extract, n):
    """Gather one 3-component vector per frame into three length-*n* float arrays.

    ``extract(frame)`` returns the vector for a frame, or ``None`` when the
    frame has none. Missing vectors, short vectors and ``None`` components
    are left at 0, and frames beyond ``len(frames)`` are zero-padded.
    """
    columns = np.zeros((3, n), dtype=float)
    for i, frame in enumerate(frames):
        vec = extract(frame) or ()
        for axis in range(3):
            if len(vec) > axis and vec[axis] is not None:
                columns[axis, i] = vec[axis]
    return columns[0], columns[1], columns[2]


def load_from_json(episode_path: str) -> TrajectoryData:
    """Load trajectory data from the JSON files of an episode directory.

    Files read (each is optional; a missing file yields all-zero channels):

    * ``metadata.json`` or ``extracted/metadata.json`` - ``fps`` plus a
      ``frames`` (or ``poses``) list carrying the camera position, either
      under ``camera_pose.position`` or directly under ``position``.
    * ``hands_3d.json`` - ``frames`` with ``left_hand`` / ``right_hand``
      ``position`` vectors.
    * ``end_effector.json`` - ``frames`` with ``left_hand`` / ``right_hand``
      ``orientation`` vectors.

    All returned channels share one length: the longest of the three frame
    lists (at least 1). Shorter sources are zero-padded so every channel can
    be plotted against ``timestamps``.

    Args:
        episode_path: Directory containing the JSON files.

    Returns:
        A ``TrajectoryData`` with equally sized float arrays.

    Raises:
        json.JSONDecodeError: If a file that exists is not valid JSON.
    """
    root = Path(episode_path)

    metadata = _load_first_json(
        [root / "metadata.json", root / "extracted/metadata.json"],
        {"poses": []},
    )
    hands_3d = _load_first_json([root / "hands_3d.json"], {"frames": []})
    end_effector = _load_first_json([root / "end_effector.json"], {"frames": []})

    frames = metadata.get('frames', metadata.get('poses', [])) or []
    hand_frames = hands_3d.get('frames', []) or []
    ee_frames = end_effector.get('frames', []) or []

    # Use one common length so no channel is shorter or longer than the
    # time axis, whichever file happens to contain the most frames.
    n = max(len(frames), len(hand_frames), len(ee_frames), 1)

    # A missing, null or zero fps would break the division below.
    fps = metadata.get('fps') or 30
    timestamps = np.arange(n) / fps

    def camera_position(frame):
        # Prefer the nested camera_pose.position, else a top-level position.
        pose = frame.get('camera_pose') or {}
        return pose.get('position', frame.get('position'))

    def hand_channels(source_frames, hand, field):
        # ``or {}`` tolerates frames where the hand entry is null.
        return _xyz_columns(
            source_frames,
            lambda frame: (frame.get(hand) or {}).get(field),
            n,
        )

    camera_x, camera_y, camera_z = _xyz_columns(frames, camera_position, n)
    left_x, left_y, left_z = hand_channels(hand_frames, 'left_hand', 'position')
    right_x, right_y, right_z = hand_channels(hand_frames, 'right_hand', 'position')
    left_roll, left_pitch, left_yaw = hand_channels(ee_frames, 'left_hand', 'orientation')
    right_roll, right_pitch, right_yaw = hand_channels(ee_frames, 'right_hand', 'orientation')

    return TrajectoryData(
        timestamps=timestamps,
        camera_x=camera_x,
        camera_y=camera_y,
        camera_z=camera_z,
        left_hand_x=left_x,
        left_hand_y=left_y,
        left_hand_z=left_z,
        right_hand_x=right_x,
        right_hand_y=right_y,
        right_hand_z=right_z,
        left_hand_roll=left_roll,
        left_hand_pitch=left_pitch,
        left_hand_yaw=left_yaw,
        right_hand_roll=right_roll,
        right_hand_pitch=right_pitch,
        right_hand_yaw=right_yaw,
    )
def create_time_series_plot(data: TrajectoryData) -> go.Figure:
    """Draw the 15 trajectory channels on a 5x3 grid of line plots.

    Each grid row holds one group of three channels (camera, left-hand
    position, right-hand position, left-hand orientation, right-hand
    orientation) and uses a single colour for the whole row.

    Args:
        data: Trajectory whose channels are plotted against ``data.timestamps``.

    Returns:
        The assembled Plotly figure.
    """
    titles = [
        'Camera X (m)', 'Camera Y (m)', 'Camera Z (m)',
        'Left Hand X', 'Left Hand Y', 'Left Hand Z',
        'Right Hand X', 'Right Hand Y', 'Right Hand Z',
        'Left Roll', 'Left Pitch', 'Left Yaw',
        'Right Roll', 'Right Pitch', 'Right Yaw',
    ]
    fig = make_subplots(
        rows=5,
        cols=3,
        subplot_titles=titles,
        vertical_spacing=0.06,
        horizontal_spacing=0.04,
    )

    t = data.timestamps

    # One (colour, channels) entry per grid row, listed top to bottom.
    grid = [
        ('#2563eb', (data.camera_x, data.camera_y, data.camera_z)),
        ('#dc2626', (data.left_hand_x, data.left_hand_y, data.left_hand_z)),
        ('#16a34a', (data.right_hand_x, data.right_hand_y, data.right_hand_z)),
        ('#ea580c', (data.left_hand_roll, data.left_hand_pitch, data.left_hand_yaw)),
        ('#9333ea', (data.right_hand_roll, data.right_hand_pitch, data.right_hand_yaw)),
    ]
    for row_no, (color, channels) in enumerate(grid, start=1):
        for col_no, series in enumerate(channels, start=1):
            trace = go.Scatter(
                x=t,
                y=series,
                line=dict(color=color, width=1),
                showlegend=False,
            )
            fig.add_trace(trace, row=row_no, col=col_no)

    fig.update_layout(height=1000, showlegend=False, title_text="57 Data Streams Visualization")
    fig.update_xaxes(title_text="Time (sec)")
    return fig
def create_3d_plot(data: TrajectoryData) -> go.Figure:
    """Render the camera and both hand paths as lines in one 3D scene.

    Args:
        data: Trajectory supplying the x/y/z position channels.

    Returns:
        A Plotly figure with three ``Scatter3d`` line traces, added in the
        order Camera, Left Hand, Right Hand.
    """
    # (legend name, colour, x, y, z) for each path, in drawing order.
    paths = (
        ('Camera', '#2563eb', data.camera_x, data.camera_y, data.camera_z),
        ('Left Hand', '#dc2626', data.left_hand_x, data.left_hand_y, data.left_hand_z),
        ('Right Hand', '#16a34a', data.right_hand_x, data.right_hand_y, data.right_hand_z),
    )

    fig = go.Figure()
    for label, color, xs, ys, zs in paths:
        fig.add_trace(go.Scatter3d(
            x=xs, y=ys, z=zs,
            mode='lines', name=label, line=dict(color=color, width=4)
        ))

    scene = dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z', aspectmode='data')
    fig.update_layout(title='3D Trajectory (World Frame)', scene=scene, height=600)
    return fig
def visualize(source: str):
    """Build both figures and the statistics panel for the chosen source.

    Args:
        source: ``"Sample Data"`` for the generated demo trajectory, otherwise
            an episode directory handed to ``load_from_json``.

    Returns:
        A 3-tuple ``(time_series_figure, figure_3d, stats_markdown)``.

    Any exception raised while loading is reported in the statistics text
    and the sample trajectory is shown instead, so the UI always gets a plot.
    """
    data = None  # stays None whenever the sample trajectory must be used
    if source == "Sample Data":
        info = "Using generated sample data (10 seconds, 300 frames)"
    else:
        try:
            data = load_from_json(source)
        except Exception as e:
            info = f"Error loading data: {e}. Using sample data."
        else:
            info = f"Loaded from: {source}"

    if data is None:
        data = create_sample_data()

    figures = (create_time_series_plot(data), create_3d_plot(data))

    stats = f"""
## Data Statistics
| Metric | Value |
|--------|-------|
| Duration | {data.timestamps[-1]:.2f} sec |
| Frames | {len(data.timestamps)} |
| Visualized Streams | 15 |
| Total Streams | 57 (including 42 joint positions) |
**Info:** {info}
"""
    return (*figures, stats)
# Gradio Interface
if not GRADIO_AVAILABLE:
    # Without gradio there is no UI to build; tell the user how to get it.
    print("Gradio not available. Install with: pip install gradio")
else:
    with gr.Blocks(title="DI Trajectory Visualizer") as demo:
        # Page header and a short description of the data streams.
        gr.Markdown("""
# Dynamic Intelligence - Trajectory Visualizer
Visualize 57 data streams from humanoid robot training data.
### Data Streams
- **Camera**: X, Y, Z position (world frame)
- **Left Hand**: X, Y, Z position + Roll, Pitch, Yaw
- **Right Hand**: X, Y, Z position + Roll, Pitch, Yaw
- **Joints**: 21 keypoints × 2 hands × XYZ (stored, not visualized)
""")

        # Controls: data-source selector next to the trigger button.
        with gr.Row():
            source_input = gr.Dropdown(choices=["Sample Data"], value="Sample Data", label="Data Source")
            load_btn = gr.Button("Visualize", variant="primary")

        stats_output = gr.Markdown()

        # One tab per figure returned by visualize().
        with gr.Tabs():
            with gr.TabItem("Time Series (15 plots)"):
                time_plot = gr.Plot()
            with gr.TabItem("3D View"):
                plot_3d = gr.Plot()

        # Button click re-renders everything for the selected source.
        load_btn.click(fn=visualize, inputs=[source_input], outputs=[time_plot, plot_3d, stats_output])

        # Auto-load sample data
        demo.load(fn=lambda: visualize("Sample Data"), outputs=[time_plot, plot_3d, stats_output])

    demo.launch()