# Space status (captured from the hosting page): Runtime error
| """ | |
| DI Annotation Data Visualizer - Simple working version | |
| """ | |
| import gradio as gr | |
| import pandas as pd | |
| import plotly.graph_objects as go | |
| from plotly.subplots import make_subplots | |
| from huggingface_hub import hf_hub_download | |
| import json | |
# Hugging Face Hub dataset repository id. Passed as ``repo_id`` when
# downloading the parquet files and used to build the dataset link shown
# in the page header.
DATASET_REPO = "DynamicIntelligence/di-annotation-data"
def load_data():
    """Download and parse the trajectory parquet file from the Hub.

    Candidate files are tried in order: the per-episode chunked layout
    first, then the older single-file layout. The first candidate that
    both downloads and parses is returned.

    Returns:
        tuple: ``(df, None)`` on success, or ``(None, message)`` when every
        candidate failed. ``message`` starts with ``"Error loading data: "``
        and lists the failure for each candidate, so a problem with the
        primary file is not hidden behind the fallback's error.
    """
    candidates = (
        "data/chunk-000/episode_000000.parquet",    # new per-episode layout
        "data/Test_Data_Lidar_trajectory.parquet",  # old single-file layout
    )
    failures = []
    for filename in candidates:
        try:
            local_path = hf_hub_download(
                repo_id=DATASET_REPO,
                filename=filename,
                repo_type="dataset",
            )
            return pd.read_parquet(local_path), None
        except Exception as exc:
            # UI boundary: any failure (network, missing file, bad parquet)
            # is recorded and the next candidate is tried instead of crashing.
            failures.append(f"{filename}: {exc}")
    return None, "Error loading data: " + "; ".join(failures)
def _state_to_xyz(state):
    """Coerce one ``observation.state`` cell to a three-element ``[x, y, z]`` list.

    Lists, tuples and one-dimensional array-like objects (anything exposing
    ``tolist`` and ``ndim == 1``) are accepted; missing components are
    padded with 0. Any other value yields ``[0, 0, 0]``.
    """
    if isinstance(state, (list, tuple)):
        values = list(state)
    elif hasattr(state, "tolist") and getattr(state, "ndim", 0) == 1:
        # Parquet list columns commonly come back from pandas as numpy
        # arrays rather than Python lists, so a plain list check is not
        # enough to recognise them.
        values = state.tolist()
    else:
        values = []
    values = values[:3]
    return values + [0] * (3 - len(values))


def create_plots(df):
    """Build the 3D trajectory figure and the per-axis time-series figure.

    Two column layouts are supported:

    * new format: an ``observation.state`` column whose cells hold the
      position as a sequence (first three entries are used as x, y, z);
    * old format: separate ``camera_x`` / ``camera_y`` / ``camera_z`` columns.

    If there is no ``timestamp`` column the row index (0..n-1) is used as
    the time axis.

    Args:
        df: DataFrame with trajectory data, or ``None``.

    Returns:
        tuple: ``(fig_3d, fig_ts)``, or ``(None, None)`` when ``df`` is
        ``None``, has no position columns, or has no rows.
    """
    if df is None:
        return None, None

    n_rows = len(df)
    if 'timestamp' in df.columns:
        timestamps = df['timestamp'].tolist()
    else:
        timestamps = list(range(n_rows))

    if 'observation.state' in df.columns:
        # New LeRobot-style format
        positions = [_state_to_xyz(cell) for cell in df['observation.state']]
        camera_x = [p[0] for p in positions]
        camera_y = [p[1] for p in positions]
        camera_z = [p[2] for p in positions]
    else:
        # Old format
        if 'camera_x' not in df.columns:
            return None, None
        camera_x = df['camera_x'].tolist()
        # A missing axis is filled with zeros so all three series keep the
        # same length (mirrors the zero padding used for the new format).
        camera_y = df['camera_y'].tolist() if 'camera_y' in df.columns else [0] * n_rows
        camera_z = df['camera_z'].tolist() if 'camera_z' in df.columns else [0] * n_rows

    if not camera_x:
        return None, None

    # 3D trajectory plot; markers are coloured by time.
    fig_3d = go.Figure(data=[go.Scatter3d(
        x=camera_x, y=camera_y, z=camera_z,
        mode='lines+markers',
        marker=dict(size=2, color=timestamps, colorscale='Viridis'),
        line=dict(width=2, color='blue'),
        name='Camera Path'
    )])
    fig_3d.update_layout(
        title='Camera 3D Trajectory',
        scene=dict(
            xaxis_title='X (m)',
            yaxis_title='Y (m)',
            zaxis_title='Z (m)'
        ),
        height=500
    )

    # Time series: one stacked subplot per axis.
    fig_ts = make_subplots(rows=3, cols=1, subplot_titles=['Camera X', 'Camera Y', 'Camera Z'])
    series = (
        (camera_x, 'X', 'red'),
        (camera_y, 'Y', 'green'),
        (camera_z, 'Z', 'blue'),
    )
    for row, (values, label, colour) in enumerate(series, start=1):
        fig_ts.add_trace(
            go.Scatter(x=timestamps, y=values, name=label, line=dict(color=colour)),
            row=row, col=1,
        )
    fig_ts.update_layout(height=600, title='Camera Position vs Time', showlegend=False)

    return fig_3d, fig_ts
def visualize():
    """Load the dataset and produce the values for the three UI outputs.

    Returns:
        tuple: ``(fig_3d, fig_ts, markdown)``. Both figures are ``None``
        when loading failed, the dataset is empty, or no position columns
        could be plotted; ``markdown`` then carries the error text or the
        stats with an explanatory note.
    """
    df, error = load_data()
    if error:
        return None, None, f"Error: {error}"
    if df is None or len(df) == 0:
        return None, None, "No data found"

    fig_3d, fig_ts = create_plots(df)

    # Column labels are not guaranteed to be strings, so stringify before
    # joining; only show the ellipsis when columns were actually omitted.
    shown = [str(col) for col in df.columns[:5]]
    ellipsis = "..." if len(df.columns) > len(shown) else ""
    stats = (
        "\n**Dataset Stats:**\n"
        f"- Total frames: {len(df)}\n"
        f"- Columns: {', '.join(shown)}{ellipsis}\n"
        f"- Source: {DATASET_REPO}\n"
    )
    if fig_3d is None:
        # Explain the empty plot panels instead of leaving them blank.
        stats += "- Note: no camera position columns found; nothing to plot\n"
    return fig_3d, fig_ts, stats
# Build the Gradio UI: header, a load button, two plots side by side, and a
# Markdown panel for stats or error text.
with gr.Blocks(title="DI Annotation Data Visualizer", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# DI Annotation Data Visualizer")
    gr.Markdown(f"Visualizing trajectory data from [{DATASET_REPO}](https://huggingface.co/datasets/{DATASET_REPO})")

    load_btn = gr.Button("Load & Visualize Data", variant="primary")

    with gr.Row():
        plot_3d = gr.Plot(label="3D Trajectory")
        plot_ts = gr.Plot(label="Time Series")

    stats_output = gr.Markdown()

    # The button and the page-load event both run visualize() and write to
    # the same three components, in the order visualize() returns them.
    load_btn.click(fn=visualize, inputs=[], outputs=[plot_3d, plot_ts, stats_output])

    # Auto-load on start
    demo.load(fn=visualize, inputs=[], outputs=[plot_3d, plot_ts, stats_output])


if __name__ == "__main__":
    # Listen on all interfaces on port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)