"""
DI Annotation Data Visualizer - Simple working version
"""

import json

import gradio as gr
import pandas as pd
import plotly.graph_objects as go
from huggingface_hub import hf_hub_download
from plotly.subplots import make_subplots

DATASET_REPO = "DynamicIntelligence/di-annotation-data"

# Parquet files to try, in order: the newer chunked episode layout first,
# then the older single-file layout as a fallback.
DATASET_FILES = (
    "data/chunk-000/episode_000000.parquet",
    "data/Test_Data_Lidar_trajectory.parquet",
)


def load_data():
    """Load trajectory data from HuggingFace.

    Each entry of ``DATASET_FILES`` is tried in order and the first one that
    downloads and parses successfully is returned.

    Returns:
        A ``(df, error)`` pair. On success ``df`` is the loaded DataFrame and
        ``error`` is ``None``. If every candidate fails, ``df`` is ``None``
        and ``error`` is a message listing the failure for each file.
    """
    failures = []
    for filename in DATASET_FILES:
        # The broad catch is deliberate: this is the boundary where download
        # and parse failures are turned into a message for the UI instead of
        # crashing the app. Every failure is kept so none is hidden.
        try:
            local_path = hf_hub_download(
                repo_id=DATASET_REPO,
                filename=filename,
                repo_type="dataset",
            )
            return pd.read_parquet(local_path), None
        except Exception as exc:
            failures.append(f"{filename}: {exc}")
    return None, "Error loading data: " + "; ".join(failures)


def _state_component(state, index):
    """Return ``state[index]``, or 0 if ``state`` is not a long enough sequence.

    List-valued parquet columns typically come back from pandas as
    ``numpy.ndarray`` objects rather than Python lists, so any indexable
    list-like value is accepted here (sets and dicts are excluded because
    they cannot be indexed by position).
    """
    is_sequence = pd.api.types.is_list_like(
        state, allow_sets=False
    ) and not isinstance(state, dict)
    if is_sequence and len(state) > index:
        return state[index]
    return 0


def _extract_trajectory(df):
    """Pull camera positions and timestamps out of either supported layout.

    Returns:
        ``(camera_x, camera_y, camera_z, timestamps)`` as lists, or ``None``
        when the frame has no rows or no usable position columns.
    """
    if 'timestamp' in df.columns:
        timestamps = df['timestamp'].tolist()
    else:
        # No timestamp column: fall back to the frame index.
        timestamps = list(range(len(df)))

    if 'observation.state' in df.columns:
        # New format: the first three entries of each state are used as x, y, z.
        states = df['observation.state'].tolist()
        camera_x = [_state_component(s, 0) for s in states]
        camera_y = [_state_component(s, 1) for s in states]
        camera_z = [_state_component(s, 2) for s in states]
    elif all(col in df.columns for col in ('camera_x', 'camera_y', 'camera_z')):
        # Old format: one scalar column per axis.
        camera_x = df['camera_x'].tolist()
        camera_y = df['camera_y'].tolist()
        camera_z = df['camera_z'].tolist()
    else:
        return None

    if not camera_x:
        return None
    return camera_x, camera_y, camera_z, timestamps


def create_plots(df):
    """Create visualization plots.

    Args:
        df: DataFrame in either the ``observation.state`` layout or the
            ``camera_x`` / ``camera_y`` / ``camera_z`` layout, or ``None``.

    Returns:
        ``(fig_3d, fig_ts)``: a 3D trajectory figure and a three-row
        time-series figure. Both are ``None`` when ``df`` is ``None`` or no
        position data can be extracted.
    """
    if df is None:
        return None, None

    trajectory = _extract_trajectory(df)
    if trajectory is None:
        return None, None
    camera_x, camera_y, camera_z, timestamps = trajectory

    # 3D trajectory plot
    fig_3d = go.Figure(data=[go.Scatter3d(
        x=camera_x,
        y=camera_y,
        z=camera_z,
        mode='lines+markers',
        marker=dict(size=2, color=timestamps, colorscale='Viridis'),
        line=dict(width=2, color='blue'),
        name='Camera Path',
    )])
    fig_3d.update_layout(
        title='Camera 3D Trajectory',
        scene=dict(
            xaxis_title='X (m)',
            yaxis_title='Y (m)',
            zaxis_title='Z (m)',
        ),
        height=500,
    )

    # Time series: one subplot row per axis.
    fig_ts = make_subplots(
        rows=3, cols=1,
        subplot_titles=['Camera X', 'Camera Y', 'Camera Z'],
    )
    series = (
        (camera_x, 'X', 'red'),
        (camera_y, 'Y', 'green'),
        (camera_z, 'Z', 'blue'),
    )
    for row, (values, label, color) in enumerate(series, start=1):
        fig_ts.add_trace(
            go.Scatter(x=timestamps, y=values, name=label, line=dict(color=color)),
            row=row, col=1,
        )
    fig_ts.update_layout(height=600, title='Camera Position vs Time', showlegend=False)

    return fig_3d, fig_ts


def visualize():
    """Main visualization function.

    Returns:
        ``(fig_3d, fig_ts, markdown)`` for the two plot components and the
        stats panel. The figures are ``None`` when loading fails, the dataset
        is empty, or no position columns are present; the markdown then
        explains why.
    """
    df, error = load_data()

    if error:
        return None, None, f"Error: {error}"

    if df is None or len(df) == 0:
        return None, None, "No data found"

    fig_3d, fig_ts = create_plots(df)

    # Column labels are not guaranteed to be strings; only add the ellipsis
    # when columns were actually left out of the preview.
    preview = ', '.join(str(col) for col in df.columns[:5])
    ellipsis = '...' if len(df.columns) > 5 else ''
    note = ''
    if fig_3d is None:
        note = "- Note: no camera position columns found, nothing to plot\n"

    stats = f"""
**Dataset Stats:**
- Total frames: {len(df)}
- Columns: {preview}{ellipsis}
- Source: {DATASET_REPO}
{note}"""

    return fig_3d, fig_ts, stats


# Create Gradio interface
with gr.Blocks(title="DI Annotation Data Visualizer", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# DI Annotation Data Visualizer")
    gr.Markdown(f"Visualizing trajectory data from [{DATASET_REPO}](https://huggingface.co/datasets/{DATASET_REPO})")

    load_btn = gr.Button("Load & Visualize Data", variant="primary")

    with gr.Row():
        plot_3d = gr.Plot(label="3D Trajectory")
        plot_ts = gr.Plot(label="Time Series")

    stats_output = gr.Markdown()

    load_btn.click(
        fn=visualize,
        inputs=[],
        outputs=[plot_3d, plot_ts, stats_output],
    )

    # Auto-load on start
    demo.load(visualize, inputs=[], outputs=[plot_3d, plot_ts, stats_output])

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)