Spaces:
Runtime error
Runtime error
File size: 4,475 Bytes
d17f5c3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
"""
DI Annotation Data Visualizer - Simple working version
"""
import gradio as gr
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from huggingface_hub import hf_hub_download
import json
# Hugging Face Hub dataset repository id that load_data() downloads from and the UI links to.
DATASET_REPO = "DynamicIntelligence/di-annotation-data"
def load_data():
    """Load trajectory data from the Hugging Face Hub.

    Candidate parquet files are tried in order: the per-episode file first,
    then the old-format single file as a fallback. The first one that
    downloads and parses successfully wins.

    Returns:
        A ``(df, error)`` tuple. On success ``df`` is the loaded DataFrame
        and ``error`` is ``None``. If every candidate fails, ``df`` is
        ``None`` and ``error`` is a message starting with
        ``"Error loading data: "`` that lists the failure for each file
        tried, so the cause of the primary failure is not hidden behind the
        fallback's error.
    """
    candidates = (
        "data/chunk-000/episode_000000.parquet",  # tried first
        "data/Test_Data_Lidar_trajectory.parquet",  # fallback: old format
    )
    failures = []
    for filename in candidates:
        # Broad catch is deliberate: this is a best-effort UI boundary and
        # any download or parse problem should surface as a message rather
        # than crash the app.
        try:
            local_path = hf_hub_download(
                repo_id=DATASET_REPO,
                filename=filename,
                repo_type="dataset",
            )
            return pd.read_parquet(local_path), None
        except Exception as exc:
            failures.append(f"{filename}: {exc}")
    return None, "Error loading data: " + "; ".join(failures)
def _xyz_from_state(state):
    """Return the first three entries of one state value, zero-padded to length 3.

    Values that are not list-like (e.g. ``None`` or a scalar) yield
    ``[0, 0, 0]``.
    """
    # Parquet list columns typically come back from pandas as numpy arrays,
    # not Python lists, so normalise anything array-like before checking.
    if hasattr(state, "tolist"):
        state = state.tolist()
    if not isinstance(state, (list, tuple)):
        return [0, 0, 0]
    head = list(state[:3])
    return head + [0] * (3 - len(head))


def create_plots(df):
    """Create the 3D trajectory and per-axis time-series figures.

    Args:
        df: DataFrame holding either an ``observation.state`` column (whose
            first three entries are plotted as x, y, z) or separate
            ``camera_x`` / ``camera_y`` / ``camera_z`` columns. An optional
            ``timestamp`` column supplies the time axis; otherwise the row
            index is used. May be ``None``.

    Returns:
        ``(fig_3d, fig_ts)`` plotly figures, or ``(None, None)`` when ``df``
        is ``None`` or no x positions are available.
    """
    if df is None:
        return None, None

    columns = df.columns
    if 'observation.state' in columns:
        # New LeRobot format: one state vector per row.
        points = [_xyz_from_state(state) for state in df['observation.state']]
        camera_x = [p[0] for p in points]
        camera_y = [p[1] for p in points]
        camera_z = [p[2] for p in points]
    else:
        # Old format: one column per axis; a missing axis becomes an empty list.
        camera_x = df['camera_x'].tolist() if 'camera_x' in columns else []
        camera_y = df['camera_y'].tolist() if 'camera_y' in columns else []
        camera_z = df['camera_z'].tolist() if 'camera_z' in columns else []

    # Both formats use the same time axis.
    timestamps = df['timestamp'].tolist() if 'timestamp' in columns else list(range(len(df)))

    if not camera_x:
        return None, None

    # 3D trajectory plot; markers are coloured by timestamp.
    fig_3d = go.Figure(data=[go.Scatter3d(
        x=camera_x, y=camera_y, z=camera_z,
        mode='lines+markers',
        marker=dict(size=2, color=timestamps, colorscale='Viridis'),
        line=dict(width=2, color='blue'),
        name='Camera Path'
    )])
    fig_3d.update_layout(
        title='Camera 3D Trajectory',
        scene=dict(
            xaxis_title='X (m)',
            yaxis_title='Y (m)',
            zaxis_title='Z (m)'
        ),
        height=500
    )

    # Time series: one stacked subplot per axis.
    fig_ts = make_subplots(rows=3, cols=1, subplot_titles=['Camera X', 'Camera Y', 'Camera Z'])
    fig_ts.add_trace(go.Scatter(x=timestamps, y=camera_x, name='X', line=dict(color='red')), row=1, col=1)
    fig_ts.add_trace(go.Scatter(x=timestamps, y=camera_y, name='Y', line=dict(color='green')), row=2, col=1)
    fig_ts.add_trace(go.Scatter(x=timestamps, y=camera_z, name='Z', line=dict(color='blue')), row=3, col=1)
    fig_ts.update_layout(height=600, title='Camera Position vs Time', showlegend=False)

    return fig_3d, fig_ts
def visualize():
    """Load the dataset and build everything the UI displays.

    Returns:
        A ``(fig_3d, fig_ts, message)`` tuple. On a load error or an empty
        dataset both figures are ``None`` and ``message`` explains why;
        otherwise ``message`` is a Markdown summary of the dataset. The
        figures may still be ``None`` if ``create_plots`` could not build
        them.
    """
    df, error = load_data()
    if error:
        return None, None, f"Error: {error}"
    if df is None or len(df) == 0:
        return None, None, "No data found"

    fig_3d, fig_ts = create_plots(df)

    # Column labels are not guaranteed to be strings, so convert before
    # joining; only show the ellipsis when columns were actually omitted.
    preview_count = 5
    shown = ', '.join(str(col) for col in df.columns[:preview_count])
    ellipsis = "..." if len(df.columns) > preview_count else ""

    stats = (
        "\n"
        "**Dataset Stats:**\n"
        f"- Total frames: {len(df)}\n"
        f"- Columns: {shown}{ellipsis}\n"
        f"- Source: {DATASET_REPO}\n"
    )
    return fig_3d, fig_ts, stats
# Build the Gradio UI: a header, a reload button, two plots side by side,
# and a Markdown area for the dataset summary.
with gr.Blocks(title="DI Annotation Data Visualizer", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# DI Annotation Data Visualizer")
    gr.Markdown(
        f"Visualizing trajectory data from [{DATASET_REPO}](https://huggingface.co/datasets/{DATASET_REPO})"
    )

    load_btn = gr.Button("Load & Visualize Data", variant="primary")

    with gr.Row():
        plot_3d = gr.Plot(label="3D Trajectory")
        plot_ts = gr.Plot(label="Time Series")

    stats_output = gr.Markdown()

    # The button click and the initial page load refresh the same components.
    result_components = [plot_3d, plot_ts, stats_output]
    load_btn.click(fn=visualize, inputs=[], outputs=result_components)
    # Auto-load on start
    demo.load(fn=visualize, inputs=[], outputs=result_components)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
|