Raffael-Kultyshev commited on
Commit
d17f5c3
·
verified ·
1 Parent(s): 82e088a

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +133 -0
app.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DI Annotation Data Visualizer - Simple working version
3
+ """
4
+ import gradio as gr
5
+ import pandas as pd
6
+ import plotly.graph_objects as go
7
+ from plotly.subplots import make_subplots
8
+ from huggingface_hub import hf_hub_download
9
+ import json
10
+
11
+ DATASET_REPO = "DynamicIntelligence/di-annotation-data"
12
+
13
+ def load_data():
14
+ """Load trajectory data from HuggingFace."""
15
+ try:
16
+ # Try to load parquet
17
+ local_path = hf_hub_download(
18
+ repo_id=DATASET_REPO,
19
+ filename="data/chunk-000/episode_000000.parquet",
20
+ repo_type="dataset"
21
+ )
22
+ df = pd.read_parquet(local_path)
23
+ return df, None
24
+ except Exception as e:
25
+ # Fallback to old format
26
+ try:
27
+ local_path = hf_hub_download(
28
+ repo_id=DATASET_REPO,
29
+ filename="data/Test_Data_Lidar_trajectory.parquet",
30
+ repo_type="dataset"
31
+ )
32
+ df = pd.read_parquet(local_path)
33
+ return df, None
34
+ except Exception as e2:
35
+ return None, f"Error loading data: {e2}"
36
+
37
+ def create_plots(df):
38
+ """Create visualization plots."""
39
+ if df is None:
40
+ return None, None
41
+
42
+ # Check which columns exist
43
+ has_new_format = 'observation.state' in df.columns
44
+
45
+ if has_new_format:
46
+ # New LeRobot format
47
+ obs = df['observation.state'].apply(lambda x: x if isinstance(x, list) else [0,0,0])
48
+ camera_x = [o[0] if len(o) > 0 else 0 for o in obs]
49
+ camera_y = [o[1] if len(o) > 1 else 0 for o in obs]
50
+ camera_z = [o[2] if len(o) > 2 else 0 for o in obs]
51
+ timestamps = df['timestamp'].tolist() if 'timestamp' in df.columns else list(range(len(df)))
52
+ else:
53
+ # Old format
54
+ camera_x = df['camera_x'].tolist() if 'camera_x' in df.columns else []
55
+ camera_y = df['camera_y'].tolist() if 'camera_y' in df.columns else []
56
+ camera_z = df['camera_z'].tolist() if 'camera_z' in df.columns else []
57
+ timestamps = df['timestamp'].tolist() if 'timestamp' in df.columns else list(range(len(df)))
58
+
59
+ if not camera_x:
60
+ return None, None
61
+
62
+ # 3D trajectory plot
63
+ fig_3d = go.Figure(data=[go.Scatter3d(
64
+ x=camera_x, y=camera_y, z=camera_z,
65
+ mode='lines+markers',
66
+ marker=dict(size=2, color=timestamps, colorscale='Viridis'),
67
+ line=dict(width=2, color='blue'),
68
+ name='Camera Path'
69
+ )])
70
+ fig_3d.update_layout(
71
+ title='Camera 3D Trajectory',
72
+ scene=dict(
73
+ xaxis_title='X (m)',
74
+ yaxis_title='Y (m)',
75
+ zaxis_title='Z (m)'
76
+ ),
77
+ height=500
78
+ )
79
+
80
+ # Time series
81
+ fig_ts = make_subplots(rows=3, cols=1, subplot_titles=['Camera X', 'Camera Y', 'Camera Z'])
82
+ fig_ts.add_trace(go.Scatter(x=timestamps, y=camera_x, name='X', line=dict(color='red')), row=1, col=1)
83
+ fig_ts.add_trace(go.Scatter(x=timestamps, y=camera_y, name='Y', line=dict(color='green')), row=2, col=1)
84
+ fig_ts.add_trace(go.Scatter(x=timestamps, y=camera_z, name='Z', line=dict(color='blue')), row=3, col=1)
85
+ fig_ts.update_layout(height=600, title='Camera Position vs Time', showlegend=False)
86
+
87
+ return fig_3d, fig_ts
88
+
89
+ def visualize():
90
+ """Main visualization function."""
91
+ df, error = load_data()
92
+
93
+ if error:
94
+ return None, None, f"Error: {error}"
95
+
96
+ if df is None or len(df) == 0:
97
+ return None, None, "No data found"
98
+
99
+ fig_3d, fig_ts = create_plots(df)
100
+
101
+ stats = f"""
102
+ **Dataset Stats:**
103
+ - Total frames: {len(df)}
104
+ - Columns: {', '.join(df.columns[:5])}...
105
+ - Source: {DATASET_REPO}
106
+ """
107
+
108
+ return fig_3d, fig_ts, stats
109
+
110
+ # Create Gradio interface
111
+ with gr.Blocks(title="DI Annotation Data Visualizer", theme=gr.themes.Soft()) as demo:
112
+ gr.Markdown("# DI Annotation Data Visualizer")
113
+ gr.Markdown(f"Visualizing trajectory data from [{DATASET_REPO}](https://huggingface.co/datasets/{DATASET_REPO})")
114
+
115
+ load_btn = gr.Button("Load & Visualize Data", variant="primary")
116
+
117
+ with gr.Row():
118
+ plot_3d = gr.Plot(label="3D Trajectory")
119
+ plot_ts = gr.Plot(label="Time Series")
120
+
121
+ stats_output = gr.Markdown()
122
+
123
+ load_btn.click(
124
+ fn=visualize,
125
+ inputs=[],
126
+ outputs=[plot_3d, plot_ts, stats_output]
127
+ )
128
+
129
+ # Auto-load on start
130
+ demo.load(visualize, inputs=[], outputs=[plot_3d, plot_ts, stats_output])
131
+
132
+ if __name__ == "__main__":
133
+ demo.launch(server_name="0.0.0.0", server_port=7860)