File size: 6,236 Bytes
2d24e61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ccb16eb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
#!/usr/bin/env python3
"""
DI Annotation Data Visualizer
Visualizes data from: DynamicIntelligence/di-annotation-data
"""

import gradio as gr
from huggingface_hub import hf_hub_download, list_repo_files, HfApi
import json
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from pathlib import Path
import pandas as pd

# Hugging Face dataset repo holding the DI annotation episodes (not the old training dataset)
DATASET_REPO = "DynamicIntelligence/di-annotation-data"

def list_episodes():
    """List the episode ids available in the dataset repository.

    An episode is any sub-directory of ``episodes/`` that holds at least one
    file, i.e. any repo path shaped like ``episodes/<episode_id>/<file>``.
    Files stored directly under ``episodes/`` (for example
    ``episodes/README.md``) are not episodes and are ignored.

    Returns:
        A sorted list of episode ids. If no episode is found the one-item
        placeholder ``["No episodes found"]`` is returned; if the Hub request
        fails, ``["Error: <message>"]``. Both placeholders are recognised by
        ``load_and_visualize`` and keep the dropdown populated.
    """
    try:
        api = HfApi()
        files = list(api.list_repo_files(repo_id=DATASET_REPO, repo_type="dataset"))
    except Exception as e:
        # Best effort: surface the failure in the dropdown instead of crashing the UI
        return [f"Error: {str(e)}"]

    episodes = set()
    for path in files:
        parts = path.split("/")
        # Need "episodes", a non-empty directory name, and at least one file
        # below it; this rejects loose files sitting directly in episodes/.
        if len(parts) >= 3 and parts[0] == "episodes" and parts[1]:
            episodes.add(parts[1])

    return sorted(episodes) if episodes else ["No episodes found"]


def load_episode_data(episode_id: str):
    """Download and parse the trajectory file of one episode.

    Fetches ``episodes/<episode_id>/trajectory.json`` from ``DATASET_REPO``
    through ``hf_hub_download`` and parses it as JSON.

    Args:
        episode_id: Name of the episode directory under ``episodes/``.

    Returns:
        ``(data, None)`` on success, where ``data`` is the parsed JSON value,
        or ``(None, message)`` on any failure. ``message`` is never empty, so
        callers may test it for truthiness.
    """
    try:
        local_path = hf_hub_download(
            repo_id=DATASET_REPO,
            filename=f"episodes/{episode_id}/trajectory.json",
            repo_type="dataset"
        )

        # JSON is UTF-8; do not depend on the platform's locale encoding
        with open(local_path, encoding="utf-8") as f:
            data = json.load(f)
    except Exception as e:
        # Some exceptions stringify to ""; fall back to the class name so the
        # failure is not mistaken for success by a truthiness check.
        return None, str(e) or type(e).__name__

    return data, None


def create_plots(data: dict):
    """Create the 3D trajectory figure and the per-axis time-series figure.

    Args:
        data: Parsed trajectory document. Reads the coordinate sequences
            ``data["camera"]["x"]``, ``["y"]`` and ``["z"]`` and the optional
            ``data["timestamps"]`` sequence.

    Returns:
        ``(fig_3d, fig_ts)``. ``fig_3d`` shows the camera path with its start
        (green) and end (red) markers; ``fig_ts`` has one subplot per axis.
        When any of the three coordinate sequences is missing or empty, both
        entries are the same placeholder figure reading "No trajectory data".
    """
    # "or" also covers keys that are present but null in the JSON
    camera = data.get("camera") or {}
    x = camera.get("x") or []
    y = camera.get("y") or []
    z = camera.get("z") or []

    # Use only the samples that exist on all three axes so the start/end
    # markers below index a consistent point and never raise IndexError.
    n = min(len(x), len(y), len(z))
    if n == 0:
        empty = go.Figure()
        empty.add_annotation(text="No trajectory data", showarrow=False, font_size=20)
        return empty, empty
    x, y, z = x[:n], y[:n], z[:n]

    timestamps = data.get("timestamps")
    has_timestamps = bool(timestamps)
    if has_timestamps:
        timestamps = timestamps[:n]
    else:
        # No recorded times: plot against the sample index instead
        timestamps = list(range(n))

    # 3D trajectory
    fig_3d = go.Figure()
    fig_3d.add_trace(go.Scatter3d(
        x=x, y=y, z=z,
        mode='lines',
        line=dict(color='blue', width=4),
        name='Camera'
    ))
    fig_3d.add_trace(go.Scatter3d(
        x=[x[0]], y=[y[0]], z=[z[0]],
        mode='markers',
        marker=dict(color='green', size=10),
        name='Start'
    ))
    fig_3d.add_trace(go.Scatter3d(
        x=[x[-1]], y=[y[-1]], z=[z[-1]],
        mode='markers',
        marker=dict(color='red', size=10),
        name='End'
    ))
    fig_3d.update_layout(
        title="Camera Trajectory (World Frame)",
        scene=dict(xaxis_title='X (m)', yaxis_title='Y (m)', zaxis_title='Z (m)'),
        height=500
    )

    # Time series: one stacked subplot per coordinate, sharing the x axis
    fig_ts = make_subplots(rows=3, cols=1, shared_xaxes=True,
                          subplot_titles=['Camera X', 'Camera Y', 'Camera Z'])

    fig_ts.add_trace(go.Scatter(x=timestamps, y=x, name='X', line=dict(color='red')), row=1, col=1)
    fig_ts.add_trace(go.Scatter(x=timestamps, y=y, name='Y', line=dict(color='green')), row=2, col=1)
    fig_ts.add_trace(go.Scatter(x=timestamps, y=z, name='Z', line=dict(color='blue')), row=3, col=1)

    fig_ts.update_layout(height=500, title="Position vs Time", showlegend=True)
    # Only claim seconds when real timestamps were supplied
    x_title = "Time (s)" if has_timestamps else "Frame"
    fig_ts.update_xaxes(title_text=x_title, row=3, col=1)

    return fig_3d, fig_ts


def _md_cell(value):
    """Render *value* as text that is safe inside a Markdown table cell."""
    return str(value).replace("|", "\\|").replace("\n", " ")


def load_and_visualize(episode_id: str):
    """Load one episode and build everything the UI displays for it.

    Args:
        episode_id: Value selected in the episode dropdown. Empty values and
            the placeholders produced by ``list_episodes`` ("No episodes
            found", "Error: ...") are treated as "nothing selected".

    Returns:
        ``(fig_3d, fig_ts, stats)``: the two Plotly figures from
        ``create_plots`` and a Markdown summary table. When nothing is
        selected or loading fails, both figures are the same placeholder
        figure carrying the message and ``stats`` describes the problem.
    """

    if not episode_id or episode_id.startswith("Error") or episode_id == "No episodes found":
        empty = go.Figure()
        empty.add_annotation(text="Select an episode to visualize", showarrow=False, font_size=20)
        return empty, empty, "No episode selected"

    data, error = load_episode_data(episode_id)

    # Check for failure explicitly: an empty error string or a non-object JSON
    # document must not fall through to the plotting code.
    if error is not None or not isinstance(data, dict):
        error = error or "trajectory.json did not contain a JSON object"
        empty = go.Figure()
        empty.add_annotation(text=f"Error: {error}", showarrow=False, font_size=14)
        return empty, empty, f"**Error:** {error}"

    fig_3d, fig_ts = create_plots(data)

    # Stats ("or" also covers keys that are present but null in the JSON)
    camera = data.get("camera") or {}
    annotations = data.get("annotations") or []
    frame_range = data.get("frame_range") or {}

    task = data.get("task", "Unknown")
    instruction = data.get("language_instruction", data.get("task", "N/A"))
    num_frames = data.get("num_frames", len(camera.get("x") or []))
    fps = data.get("fps", 30)
    range_start = frame_range.get("start", 0)
    range_end = frame_range.get("end", "N/A")

    stats = f"""
## Episode: {episode_id}

| Property | Value |
|----------|-------|
| Task | {_md_cell(task)} |
| Language Instruction | {_md_cell(instruction)} |
| Frames | {_md_cell(num_frames)} |
| FPS | {_md_cell(fps)} |
| Frame Range | {_md_cell(range_start)} → {_md_cell(range_end)} |
| Annotations | {len(annotations)} segments |
"""

    return fig_3d, fig_ts, stats


# Build Gradio interface.
# Layout: header, a row with the episode picker and its two buttons, the
# Markdown stats panel, then one tab per figure.
with gr.Blocks(
    title="DI Annotation Visualizer",
    theme=gr.themes.Soft(primary_hue="blue")
) as demo:
    gr.Markdown(f"""
# DI Annotation Data Visualizer

Visualizing data from: [{DATASET_REPO}](https://huggingface.co/datasets/{DATASET_REPO})

This shows **YOUR** annotated episodes, not the old training dataset.
    """)
    
    with gr.Row():
        # list_episodes() runs once here, while the UI is being built; the
        # Refresh button re-runs it later.
        episode_dropdown = gr.Dropdown(
            label="Select Episode",
            choices=list_episodes(),
            interactive=True,
            scale=3
        )
        refresh_btn = gr.Button("🔄 Refresh", scale=1)
        load_btn = gr.Button("📊 Load & Visualize", variant="primary", scale=1)
    
    stats_output = gr.Markdown()
    
    with gr.Tabs():
        with gr.TabItem("🌐 3D Trajectory"):
            plot_3d = gr.Plot(label="3D Trajectory")
        
        with gr.TabItem("📈 Time Series"):
            plot_ts = gr.Plot(label="Position vs Time")
    
    # Events
    # Load: fetch the selected episode and fill both plots and the stats panel
    load_btn.click(
        fn=load_and_visualize,
        inputs=[episode_dropdown],
        outputs=[plot_3d, plot_ts, stats_output]
    )
    
    # Refresh: re-query the repository and replace the dropdown's choices
    refresh_btn.click(
        fn=lambda: gr.Dropdown(choices=list_episodes()),
        outputs=[episode_dropdown]
    )
    
    gr.Markdown(f"""
---
**Dataset:** [{DATASET_REPO}](https://huggingface.co/datasets/{DATASET_REPO})
    """)

if __name__ == "__main__":
    # Listen on all interfaces, port 7860
    demo.launch(server_name="0.0.0.0", server_port=7860)