File size: 4,475 Bytes
d17f5c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
"""
DI Annotation Data Visualizer - Simple working version
"""
import gradio as gr
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from huggingface_hub import hf_hub_download
import json

DATASET_REPO = "DynamicIntelligence/di-annotation-data"

# Parquet files tried in order: the episode file under data/chunk-000/ first,
# then the older single-file trajectory export as a fallback.
DATA_FILES = (
    "data/chunk-000/episode_000000.parquet",
    "data/Test_Data_Lidar_trajectory.parquet",
)

def load_data():
    """Load trajectory data from the HuggingFace dataset repo.

    Each path in ``DATA_FILES`` is downloaded from ``DATASET_REPO`` and parsed
    with ``pd.read_parquet``; the first one that succeeds is returned.

    Returns:
        ``(df, None)`` on success, or ``(None, message)`` when every candidate
        failed. The message lists the failure for each file tried, so the
        reason the primary file could not be loaded is not masked by the
        fallback's error.
    """
    failures = []
    for filename in DATA_FILES:
        try:
            local_path = hf_hub_download(
                repo_id=DATASET_REPO,
                filename=filename,
                repo_type="dataset",
            )
            df = pd.read_parquet(local_path)
        except Exception as exc:  # UI boundary: report the failure, try the next file
            failures.append(f"{filename}: {exc}")
            continue
        return df, None
    return None, "Error loading data: " + "; ".join(failures)

def _extract_positions(df):
    """Pull camera x/y/z positions and per-frame timestamps out of *df*.

    Two layouts are supported:

    * ``observation.state`` column: the first three entries of each row's
      state are used as x, y, z (missing entries default to 0).
    * Legacy: separate ``camera_x`` / ``camera_y`` / ``camera_z`` columns.

    Timestamps come from a ``timestamp`` column when present, otherwise the
    row positions ``0..n-1`` are used.

    Returns:
        ``(camera_x, camera_y, camera_z, timestamps)`` as lists. The
        coordinate lists are empty when no position data is available.
    """
    n_rows = len(df)
    if 'timestamp' in df.columns:
        timestamps = df['timestamp'].tolist()
    else:
        timestamps = list(range(n_rows))

    if 'observation.state' in df.columns:
        # pd.read_parquet materialises list-typed columns as NumPy arrays, not
        # Python lists, so accept any list-like value. Anything else (None/NaN)
        # is treated as an empty state whose coordinates default to 0.
        states = [
            list(state) if pd.api.types.is_list_like(state) else []
            for state in df['observation.state']
        ]
        camera_x, camera_y, camera_z = [
            [state[axis] if len(state) > axis else 0 for state in states]
            for axis in range(3)
        ]
        return camera_x, camera_y, camera_z, timestamps

    if 'camera_x' not in df.columns:
        return [], [], [], timestamps

    # Zero-fill a missing y/z column so every axis has one value per row,
    # matching x and the timestamps (same padding rule as the state layout).
    camera_x = df['camera_x'].tolist()
    camera_y = df['camera_y'].tolist() if 'camera_y' in df.columns else [0] * n_rows
    camera_z = df['camera_z'].tolist() if 'camera_z' in df.columns else [0] * n_rows
    return camera_x, camera_y, camera_z, timestamps

def create_plots(df):
    """Build the 3D trajectory figure and the per-axis time-series figure.

    Args:
        df: Trajectory DataFrame as returned by ``load_data``, or ``None``.

    Returns:
        ``(fig_3d, fig_ts)`` plotly figures, or ``(None, None)`` when *df* is
        ``None`` or contains no camera position data.
    """
    if df is None:
        return None, None

    camera_x, camera_y, camera_z, timestamps = _extract_positions(df)
    if not camera_x:
        return None, None

    # 3D trajectory; markers are coloured by timestamp.
    fig_3d = go.Figure(data=[go.Scatter3d(
        x=camera_x, y=camera_y, z=camera_z,
        mode='lines+markers',
        marker=dict(size=2, color=timestamps, colorscale='Viridis'),
        line=dict(width=2, color='blue'),
        name='Camera Path'
    )])
    fig_3d.update_layout(
        title='Camera 3D Trajectory',
        scene=dict(
            xaxis_title='X (m)',
            yaxis_title='Y (m)',
            zaxis_title='Z (m)'
        ),
        height=500
    )

    # One stacked subplot per axis, all sharing the same timestamps.
    fig_ts = make_subplots(rows=3, cols=1, subplot_titles=['Camera X', 'Camera Y', 'Camera Z'])
    fig_ts.add_trace(go.Scatter(x=timestamps, y=camera_x, name='X', line=dict(color='red')), row=1, col=1)
    fig_ts.add_trace(go.Scatter(x=timestamps, y=camera_y, name='Y', line=dict(color='green')), row=2, col=1)
    fig_ts.add_trace(go.Scatter(x=timestamps, y=camera_z, name='Z', line=dict(color='blue')), row=3, col=1)
    fig_ts.update_layout(height=600, title='Camera Position vs Time', showlegend=False)

    return fig_3d, fig_ts

def visualize():
    """Load the dataset and produce everything the UI displays.

    Returns:
        ``(fig_3d, fig_ts, stats_markdown)``. On a load failure or an empty
        dataset both figures are ``None`` and the third element is a short
        Markdown message explaining why.
    """
    df, error = load_data()

    if error:
        # load_data's message already starts with "Error loading data", so it
        # is shown as-is instead of being prefixed with a second "Error:".
        return None, None, error

    if df is None or df.empty:
        return None, None, "No data found"

    fig_3d, fig_ts = create_plots(df)

    # str() guards against non-string column labels; only mark truncation
    # when there really are more than five columns.
    shown_columns = ', '.join(str(col) for col in df.columns[:5])
    if len(df.columns) > 5:
        shown_columns += '...'

    lines = [
        "**Dataset Stats:**",
        f"- Total frames: {len(df)}",
        f"- Columns: {shown_columns}",
        f"- Source: {DATASET_REPO}",
    ]
    if fig_3d is None:
        # create_plots returns no figures when it finds no camera positions.
        lines.append("- No camera position data found; nothing to plot.")
    stats = "\n" + "\n".join(lines) + "\n"

    return fig_3d, fig_ts, stats

# Gradio UI: a load button, the two figures side by side, and a Markdown stats
# panel. visualize() fills all three outputs, both on click and on page load.
with gr.Blocks(title="DI Annotation Data Visualizer", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# DI Annotation Data Visualizer")
    gr.Markdown(f"Visualizing trajectory data from [{DATASET_REPO}](https://huggingface.co/datasets/{DATASET_REPO})")

    load_btn = gr.Button("Load & Visualize Data", variant="primary")

    with gr.Row():
        plot_3d = gr.Plot(label="3D Trajectory")
        plot_ts = gr.Plot(label="Time Series")

    stats_output = gr.Markdown()

    # Order must match the (fig_3d, fig_ts, stats) tuple visualize() returns.
    _outputs = [plot_3d, plot_ts, stats_output]

    load_btn.click(visualize, inputs=[], outputs=_outputs)

    # Also run once when the page is opened, so data appears without a click.
    demo.load(fn=visualize, inputs=[], outputs=_outputs)

if __name__ == "__main__":
    # Bind on all network interfaces, port 7860.
    demo.launch(server_port=7860, server_name="0.0.0.0")