Raffael-Kultyshev commited on
Commit
2d24e61
Β·
verified Β·
1 Parent(s): 44f87c2

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +201 -0
app.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ DI Annotation Data Visualizer
4
+ Visualizes data from: DynamicIntelligence/di-annotation-data
5
+ """
6
+
7
+ import gradio as gr
8
+ from huggingface_hub import hf_hub_download, list_repo_files, HfApi
9
+ import json
10
+ import numpy as np
11
+ import plotly.graph_objects as go
12
+ from plotly.subplots import make_subplots
13
+ from pathlib import Path
14
+ import pandas as pd
15
+
16
+ # YOUR dataset - not the old one
17
+ DATASET_REPO = "DynamicIntelligence/di-annotation-data"
18
+
19
+ def list_episodes():
20
+ """List all episodes in the dataset."""
21
+ try:
22
+ api = HfApi()
23
+ files = list(api.list_repo_files(repo_id=DATASET_REPO, repo_type="dataset"))
24
+
25
+ # Find episodes
26
+ episodes = set()
27
+ for f in files:
28
+ if f.startswith("episodes/") and f.endswith("/trajectory.json"):
29
+ parts = f.split("/")
30
+ if len(parts) >= 2:
31
+ episodes.add(parts[1])
32
+ elif f.startswith("episodes/") and "/" in f:
33
+ parts = f.split("/")
34
+ if len(parts) >= 2 and parts[1]:
35
+ episodes.add(parts[1])
36
+
37
+ return sorted(list(episodes)) if episodes else ["No episodes found"]
38
+ except Exception as e:
39
+ return [f"Error: {str(e)}"]
40
+
41
+
42
+ def load_episode_data(episode_id: str):
43
+ """Load trajectory data for an episode."""
44
+ try:
45
+ # Download trajectory.json
46
+ local_path = hf_hub_download(
47
+ repo_id=DATASET_REPO,
48
+ filename=f"episodes/{episode_id}/trajectory.json",
49
+ repo_type="dataset"
50
+ )
51
+
52
+ with open(local_path) as f:
53
+ data = json.load(f)
54
+
55
+ return data, None
56
+ except Exception as e:
57
+ return None, str(e)
58
+
59
+
60
+ def create_plots(data: dict):
61
+ """Create trajectory plots."""
62
+
63
+ camera = data.get("camera", {})
64
+ x = camera.get("x", [])
65
+ y = camera.get("y", [])
66
+ z = camera.get("z", [])
67
+ timestamps = data.get("timestamps", list(range(len(x))))
68
+
69
+ if not x:
70
+ empty = go.Figure()
71
+ empty.add_annotation(text="No trajectory data", showarrow=False, font_size=20)
72
+ return empty, empty
73
+
74
+ # 3D trajectory
75
+ fig_3d = go.Figure()
76
+ fig_3d.add_trace(go.Scatter3d(
77
+ x=x, y=y, z=z,
78
+ mode='lines',
79
+ line=dict(color='blue', width=4),
80
+ name='Camera'
81
+ ))
82
+ fig_3d.add_trace(go.Scatter3d(
83
+ x=[x[0]], y=[y[0]], z=[z[0]],
84
+ mode='markers',
85
+ marker=dict(color='green', size=10),
86
+ name='Start'
87
+ ))
88
+ fig_3d.add_trace(go.Scatter3d(
89
+ x=[x[-1]], y=[y[-1]], z=[z[-1]],
90
+ mode='markers',
91
+ marker=dict(color='red', size=10),
92
+ name='End'
93
+ ))
94
+ fig_3d.update_layout(
95
+ title="Camera Trajectory (World Frame)",
96
+ scene=dict(xaxis_title='X (m)', yaxis_title='Y (m)', zaxis_title='Z (m)'),
97
+ height=500
98
+ )
99
+
100
+ # Time series
101
+ fig_ts = make_subplots(rows=3, cols=1, shared_xaxes=True,
102
+ subplot_titles=['Camera X', 'Camera Y', 'Camera Z'])
103
+
104
+ fig_ts.add_trace(go.Scatter(x=timestamps, y=x, name='X', line=dict(color='red')), row=1, col=1)
105
+ fig_ts.add_trace(go.Scatter(x=timestamps, y=y, name='Y', line=dict(color='green')), row=2, col=1)
106
+ fig_ts.add_trace(go.Scatter(x=timestamps, y=z, name='Z', line=dict(color='blue')), row=3, col=1)
107
+
108
+ fig_ts.update_layout(height=500, title="Position vs Time", showlegend=True)
109
+ fig_ts.update_xaxes(title_text="Time (s)", row=3, col=1)
110
+
111
+ return fig_3d, fig_ts
112
+
113
+
114
+ def load_and_visualize(episode_id: str):
115
+ """Main function to load and visualize episode."""
116
+
117
+ if not episode_id or episode_id.startswith("Error") or episode_id == "No episodes found":
118
+ empty = go.Figure()
119
+ empty.add_annotation(text="Select an episode to visualize", showarrow=False, font_size=20)
120
+ return empty, empty, "No episode selected"
121
+
122
+ data, error = load_episode_data(episode_id)
123
+
124
+ if error:
125
+ empty = go.Figure()
126
+ empty.add_annotation(text=f"Error: {error}", showarrow=False, font_size=14)
127
+ return empty, empty, f"**Error:** {error}"
128
+
129
+ fig_3d, fig_ts = create_plots(data)
130
+
131
+ # Stats
132
+ camera = data.get("camera", {})
133
+ annotations = data.get("annotations", [])
134
+
135
+ stats = f"""
136
+ ## Episode: {episode_id}
137
+
138
+ | Property | Value |
139
+ |----------|-------|
140
+ | Task | {data.get('task', 'Unknown')} |
141
+ | Language Instruction | {data.get('language_instruction', data.get('task', 'N/A'))} |
142
+ | Frames | {data.get('num_frames', len(camera.get('x', [])))} |
143
+ | FPS | {data.get('fps', 30)} |
144
+ | Frame Range | {data.get('frame_range', {}).get('start', 0)} β†’ {data.get('frame_range', {}).get('end', 'N/A')} |
145
+ | Annotations | {len(annotations)} segments |
146
+ """
147
+
148
+ return fig_3d, fig_ts, stats
149
+
150
+
151
+ # Build Gradio interface
152
+ with gr.Blocks(
153
+ title="DI Annotation Visualizer",
154
+ theme=gr.themes.Soft(primary_hue="blue")
155
+ ) as demo:
156
+ gr.Markdown(f"""
157
+ # DI Annotation Data Visualizer
158
+
159
+ Visualizing data from: [{DATASET_REPO}](https://huggingface.co/datasets/{DATASET_REPO})
160
+
161
+ This shows **YOUR** annotated episodes, not the old training dataset.
162
+ """)
163
+
164
+ with gr.Row():
165
+ episode_dropdown = gr.Dropdown(
166
+ label="Select Episode",
167
+ choices=list_episodes(),
168
+ interactive=True,
169
+ scale=3
170
+ )
171
+ refresh_btn = gr.Button("πŸ”„ Refresh", scale=1)
172
+ load_btn = gr.Button("πŸ“Š Load & Visualize", variant="primary", scale=1)
173
+
174
+ stats_output = gr.Markdown()
175
+
176
+ with gr.Tabs():
177
+ with gr.TabItem("🌐 3D Trajectory"):
178
+ plot_3d = gr.Plot(label="3D Trajectory")
179
+
180
+ with gr.TabItem("πŸ“ˆ Time Series"):
181
+ plot_ts = gr.Plot(label="Position vs Time")
182
+
183
+ # Events
184
+ load_btn.click(
185
+ fn=load_and_visualize,
186
+ inputs=[episode_dropdown],
187
+ outputs=[plot_3d, plot_ts, stats_output]
188
+ )
189
+
190
+ refresh_btn.click(
191
+ fn=lambda: gr.Dropdown(choices=list_episodes()),
192
+ outputs=[episode_dropdown]
193
+ )
194
+
195
+ gr.Markdown(f"""
196
+ ---
197
+ **Dataset:** [{DATASET_REPO}](https://huggingface.co/datasets/{DATASET_REPO})
198
+ """)
199
+
200
+ if __name__ == "__main__":
201
+ demo.launch()