Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| """ | |
| DI Annotation Data Visualizer | |
| Visualizes data from: DynamicIntelligence/di-annotation-data | |
| """ | |
| import gradio as gr | |
| from huggingface_hub import hf_hub_download, list_repo_files, HfApi | |
| import json | |
| import numpy as np | |
| import plotly.graph_objects as go | |
| from plotly.subplots import make_subplots | |
| from pathlib import Path | |
| import pandas as pd | |
# Hugging Face *dataset* repository that every episode is read from.
# This is deliberately the DI annotation dataset, not the older training set.
DATASET_REPO: str = "DynamicIntelligence/di-annotation-data"
def list_episodes():
    """List the episode IDs available in the dataset repository.

    An episode is a sub-directory of ``episodes/`` (episodes are loaded from
    ``episodes/<id>/trajectory.json``). Only paths with at least one component
    below the episode directory count, so files sitting directly inside
    ``episodes/`` (e.g. ``episodes/README.md``) are not mistaken for episodes.

    Returns:
        Sorted list of episode IDs. If none are found, the single-item list
        ``["No episodes found"]``. If the Hub query fails, a single-item list
        ``["Error: <message>"]``, so the failure shows up in the dropdown.
    """
    try:
        api = HfApi()
        files = api.list_repo_files(repo_id=DATASET_REPO, repo_type="dataset")
    except Exception as e:  # UI boundary: surface the failure instead of crashing
        return [f"Error: {e}"]

    episodes = set()
    for path in files:
        parts = path.split("/")
        # Require "episodes/<id>/<file...>"; a bare "episodes/<file>" is not an episode.
        if len(parts) >= 3 and parts[0] == "episodes" and parts[1]:
            episodes.add(parts[1])
    return sorted(episodes) if episodes else ["No episodes found"]
def load_episode_data(episode_id: str):
    """Download and parse ``episodes/<episode_id>/trajectory.json`` from the dataset.

    Args:
        episode_id: Name of the episode directory inside ``episodes/``.

    Returns:
        ``(data, None)`` on success, where ``data`` is the parsed JSON object.
        ``(None, error_message)`` if the download, decoding, or parsing fails,
        or if the file does not contain a JSON object.
    """
    try:
        local_path = hf_hub_download(
            repo_id=DATASET_REPO,
            filename=f"episodes/{episode_id}/trajectory.json",
            repo_type="dataset",
        )
        # JSON text is UTF-8 by specification; don't depend on the platform default.
        with open(local_path, encoding="utf-8") as f:
            data = json.load(f)
    except Exception as e:  # UI boundary: report the problem instead of raising
        return None, str(e)

    # Callers use data.get(...); reject lists/scalars here with a clear message
    # rather than letting them fail later with an AttributeError.
    if not isinstance(data, dict):
        return None, (
            "trajectory.json must contain a JSON object, "
            f"got {type(data).__name__}"
        )
    return data, None
def create_plots(data: dict):
    """Build the 3-D camera-trajectory figure and the per-axis time-series figure.

    Args:
        data: Parsed trajectory JSON. Reads ``data["camera"]["x"]``, ``["y"]``
            and ``["z"]`` (coordinate sequences), plus the optional
            ``data["timestamps"]``.

    Returns:
        ``(fig_3d, fig_ts)``. If any of x, y or z is missing or empty, both
        elements are the same placeholder figure reading "No trajectory data".
    """
    camera = data.get("camera") or {}  # tolerate "camera": null
    x = camera.get("x", [])
    y = camera.get("y", [])
    z = camera.get("z", [])

    # The start/end markers index [0] and [-1] on every axis, so all three
    # must be non-empty; checking x alone lets a missing y/z raise IndexError.
    if not (x and y and z):
        empty = go.Figure()
        empty.add_annotation(text="No trajectory data", showarrow=False, font_size=20)
        return empty, empty

    timestamps = data.get("timestamps")
    if timestamps:
        time_axis_title = "Time (s)"
    else:
        # No timestamps recorded: plot against frame index and label it honestly.
        timestamps = list(range(len(x)))
        time_axis_title = "Frame"

    # 3D trajectory with start (green) and end (red) markers.
    fig_3d = go.Figure()
    fig_3d.add_trace(go.Scatter3d(
        x=x, y=y, z=z,
        mode='lines',
        line=dict(color='blue', width=4),
        name='Camera',
    ))
    fig_3d.add_trace(go.Scatter3d(
        x=[x[0]], y=[y[0]], z=[z[0]],
        mode='markers',
        marker=dict(color='green', size=10),
        name='Start',
    ))
    fig_3d.add_trace(go.Scatter3d(
        x=[x[-1]], y=[y[-1]], z=[z[-1]],
        mode='markers',
        marker=dict(color='red', size=10),
        name='End',
    ))
    fig_3d.update_layout(
        title="Camera Trajectory (World Frame)",
        scene=dict(xaxis_title='X (m)', yaxis_title='Y (m)', zaxis_title='Z (m)'),
        height=500,
    )

    # One stacked subplot per axis, sharing the time/frame axis.
    fig_ts = make_subplots(rows=3, cols=1, shared_xaxes=True,
                           subplot_titles=['Camera X', 'Camera Y', 'Camera Z'])
    fig_ts.add_trace(go.Scatter(x=timestamps, y=x, name='X', line=dict(color='red')), row=1, col=1)
    fig_ts.add_trace(go.Scatter(x=timestamps, y=y, name='Y', line=dict(color='green')), row=2, col=1)
    fig_ts.add_trace(go.Scatter(x=timestamps, y=z, name='Z', line=dict(color='blue')), row=3, col=1)
    fig_ts.update_layout(height=500, title="Position vs Time", showlegend=True)
    fig_ts.update_xaxes(title_text=time_axis_title, row=3, col=1)
    return fig_3d, fig_ts
def _md_cell(value) -> str:
    """Render *value* so it stays inside a single Markdown table cell.

    A raw ``|`` would start a new column and a newline would end the row, so
    whitespace runs (including newlines) are collapsed and pipes are escaped.
    """
    return " ".join(str(value).split()).replace("|", "\\|")


def load_and_visualize(episode_id: str):
    """Load one episode and build everything the UI displays for it.

    Args:
        episode_id: Dropdown selection. Empty values and the sentinel strings
            ("Error...", "No episodes found") are treated as "nothing selected".

    Returns:
        ``(fig_3d, fig_ts, stats_markdown)``. On a missing selection or a
        load error, both figures are a placeholder and the Markdown explains why.
    """
    if not episode_id or episode_id.startswith("Error") or episode_id == "No episodes found":
        empty = go.Figure()
        empty.add_annotation(text="Select an episode to visualize", showarrow=False, font_size=20)
        return empty, empty, "No episode selected"

    data, error = load_episode_data(episode_id)
    if error:
        empty = go.Figure()
        empty.add_annotation(text=f"Error: {error}", showarrow=False, font_size=14)
        return empty, empty, f"**Error:** {error}"

    fig_3d, fig_ts = create_plots(data)

    # Summary table. `or {}` / `or []` tolerate keys that are present but null.
    camera = data.get("camera") or {}
    annotations = data.get("annotations") or []
    frame_range = data.get("frame_range") or {}
    rows = [
        ("Task", data.get('task', 'Unknown')),
        ("Language Instruction", data.get('language_instruction', data.get('task', 'N/A'))),
        ("Frames", data.get('num_frames', len(camera.get('x') or []))),
        ("FPS", data.get('fps', 30)),
        ("Frame Range", f"{frame_range.get('start', 0)} → {frame_range.get('end', 'N/A')}"),
        ("Annotations", f"{len(annotations)} segments"),
    ]
    table_rows = "\n".join(f"| {name} | {_md_cell(value)} |" for name, value in rows)
    stats = (
        f"\n## Episode: {_md_cell(episode_id)}\n"
        "\n"
        "| Property | Value |\n"
        "|----------|-------|\n"
        f"{table_rows}\n"
    )
    return fig_3d, fig_ts, stats
# Build the Gradio interface: an episode picker, a stats panel, and two plot tabs.
with gr.Blocks(
    title="DI Annotation Visualizer",
    theme=gr.themes.Soft(primary_hue="blue"),
) as demo:
    gr.Markdown(f"""
# DI Annotation Data Visualizer

Visualizing data from: [{DATASET_REPO}](https://huggingface.co/datasets/{DATASET_REPO})

This shows **YOUR** annotated episodes, not the old training dataset.
""")

    with gr.Row():
        # Populated once at startup; the Refresh button re-queries the Hub.
        episode_dropdown = gr.Dropdown(
            label="Select Episode",
            choices=list_episodes(),
            interactive=True,
            scale=3,
        )
        # Labels previously held mis-decoded emoji bytes ("π"); restored here.
        refresh_btn = gr.Button("🔄 Refresh", scale=1)
        load_btn = gr.Button("📂 Load & Visualize", variant="primary", scale=1)

    stats_output = gr.Markdown()

    with gr.Tabs():
        with gr.TabItem("🌐 3D Trajectory"):
            plot_3d = gr.Plot(label="3D Trajectory")
        with gr.TabItem("📈 Time Series"):
            plot_ts = gr.Plot(label="Position vs Time")

    # Events
    load_btn.click(
        fn=load_and_visualize,
        inputs=[episode_dropdown],
        outputs=[plot_3d, plot_ts, stats_output],
    )
    # Returning a Dropdown with new choices updates the existing component.
    refresh_btn.click(
        fn=lambda: gr.Dropdown(choices=list_episodes()),
        outputs=[episode_dropdown],
    )

    gr.Markdown(f"""
---
**Dataset:** [{DATASET_REPO}](https://huggingface.co/datasets/{DATASET_REPO})
""")

if __name__ == "__main__":
    # Bind on all interfaces at 7860, the port Hugging Face Spaces expects.
    demo.launch(server_name="0.0.0.0", server_port=7860)