Spaces:
Runtime error
Runtime error
File size: 6,236 Bytes
2d24e61 ccb16eb | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 | #!/usr/bin/env python3
"""
DI Annotation Data Visualizer
Visualizes data from: DynamicIntelligence/di-annotation-data
"""
import gradio as gr
from huggingface_hub import hf_hub_download, list_repo_files, HfApi
import json
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from pathlib import Path
import pandas as pd
# Hugging Face *dataset* repository whose annotated episodes this app
# visualizes (the user's own annotation data, not the older training dataset).
DATASET_REPO: str = "DynamicIntelligence/di-annotation-data"
def list_episodes():
    """Return the sorted episode IDs found under ``episodes/`` in DATASET_REPO.

    An episode ID is the first directory component below ``episodes/``.
    Only paths of the form ``episodes/<id>/<...>`` count, so a loose file
    placed directly in ``episodes/`` (e.g. ``episodes/README.md``) is not
    mistaken for an episode.

    Returns:
        A sorted list of episode IDs. The result feeds the episode dropdown
        directly, so failures are encoded as sentinel choices:
        ``["No episodes found"]`` when nothing matches, and
        ``["Error: <message>"]`` when listing the repository fails.
    """
    try:
        files = HfApi().list_repo_files(repo_id=DATASET_REPO, repo_type="dataset")
    except Exception as e:  # UI boundary: surface the failure as a dropdown entry
        return [f"Error: {e}"]

    episodes = set()
    for path in files:
        parts = path.split("/")
        # Require at least episodes/<id>/<file>; a top-level file such as
        # "episodes/README.md" has only two parts and is skipped.
        if len(parts) >= 3 and parts[0] == "episodes" and parts[1]:
            episodes.add(parts[1])
    return sorted(episodes) if episodes else ["No episodes found"]
def load_episode_data(episode_id: str):
    """Download and parse ``episodes/<episode_id>/trajectory.json`` from DATASET_REPO.

    Args:
        episode_id: Episode directory name under ``episodes/``.

    Returns:
        ``(data, None)`` on success, where ``data`` is the parsed JSON object
        (a dict). Returns ``(None, message)`` on any failure, with a non-empty
        ``message``. Errors are returned rather than raised so the caller can
        render them in the UI.
    """
    try:
        local_path = hf_hub_download(
            repo_id=DATASET_REPO,
            filename=f"episodes/{episode_id}/trajectory.json",
            repo_type="dataset",
        )
        # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
        with open(local_path, encoding="utf-8") as f:
            data = json.load(f)
    except Exception as e:  # download, I/O and JSON errors are all reported to the UI
        # Some exceptions stringify to ""; fall back to the type name so a
        # caller testing `if error:` never mistakes a failure for success.
        return None, str(e) or type(e).__name__

    # Downstream code calls data.get(...), so anything but an object is an error.
    if not isinstance(data, dict):
        return None, f"trajectory.json must contain a JSON object, got {type(data).__name__}"
    return data, None
def create_plots(data: dict):
    """Build the 3D camera-trajectory figure and the per-axis time-series figure.

    Args:
        data: Parsed ``trajectory.json`` content. Reads ``data["camera"]``
            (expected to be a mapping with ``"x"``, ``"y"`` and ``"z"``
            sequences) and the optional ``data["timestamps"]``.

    Returns:
        ``(fig_3d, fig_ts)``. When any of the x/y/z series is missing or empty,
        both elements are the same placeholder figure annotated
        "No trajectory data".
    """
    # `or` also covers explicit JSON nulls, which .get(key, default) would pass through.
    camera = data.get("camera") or {}
    x = camera.get("x") or []
    y = camera.get("y") or []
    z = camera.get("z") or []

    # The start/end markers index all three axes, so a partially populated
    # camera dict must take the placeholder path instead of raising IndexError.
    if not (x and y and z):
        empty = go.Figure()
        empty.add_annotation(text="No trajectory data", showarrow=False, font_size=20)
        return empty, empty

    timestamps = data.get("timestamps")
    if timestamps:
        time_label = "Time (s)"
    else:
        # Without timestamps the horizontal axis is just the sample index,
        # so it must not be labelled as seconds.
        timestamps = list(range(len(x)))
        time_label = "Frame"

    # 3D trajectory: the path itself plus start/end markers.
    fig_3d = go.Figure()
    fig_3d.add_trace(go.Scatter3d(
        x=x, y=y, z=z,
        mode="lines",
        line=dict(color="blue", width=4),
        name="Camera",
    ))
    for idx, color, label in ((0, "green", "Start"), (-1, "red", "End")):
        fig_3d.add_trace(go.Scatter3d(
            x=[x[idx]], y=[y[idx]], z=[z[idx]],
            mode="markers",
            marker=dict(color=color, size=10),
            name=label,
        ))
    fig_3d.update_layout(
        title="Camera Trajectory (World Frame)",
        scene=dict(xaxis_title="X (m)", yaxis_title="Y (m)", zaxis_title="Z (m)"),
        height=500,
    )

    # Time series: one stacked subplot per axis, sharing the time axis.
    fig_ts = make_subplots(rows=3, cols=1, shared_xaxes=True,
                           subplot_titles=["Camera X", "Camera Y", "Camera Z"])
    series = ((x, "X", "red"), (y, "Y", "green"), (z, "Z", "blue"))
    for row, (values, label, color) in enumerate(series, start=1):
        fig_ts.add_trace(
            go.Scatter(x=timestamps, y=values, name=label, line=dict(color=color)),
            row=row, col=1,
        )
    fig_ts.update_layout(height=500, title="Position vs Time", showlegend=True)
    fig_ts.update_xaxes(title_text=time_label, row=3, col=1)
    return fig_3d, fig_ts
def load_and_visualize(episode_id: str):
    """Load one episode and produce the figures and summary shown in the UI.

    Args:
        episode_id: Value selected in the episode dropdown. An empty/None value,
            the ``"No episodes found"`` sentinel, or an ``"Error..."`` entry
            produced by ``list_episodes`` is treated as "nothing selected".

    Returns:
        ``(fig_3d, fig_ts, stats_markdown)``, matching the three outputs wired to
        the Load button. On a load failure both figures are placeholders and
        the Markdown carries the error message.
    """
    if not episode_id or episode_id.startswith("Error") or episode_id == "No episodes found":
        empty = go.Figure()
        empty.add_annotation(text="Select an episode to visualize", showarrow=False, font_size=20)
        return empty, empty, "No episode selected"

    data, error = load_episode_data(episode_id)
    # An empty error string or a non-object payload is still a failure;
    # create_plots() below needs a dict and would otherwise raise.
    if error is not None or not isinstance(data, dict):
        message = error or "Could not load trajectory data"
        empty = go.Figure()
        empty.add_annotation(text=f"Error: {message}", showarrow=False, font_size=14)
        return empty, empty, f"**Error:** {message}"

    fig_3d, fig_ts = create_plots(data)

    # JSON nulls arrive as None; `or` substitutes an empty container so the
    # .get()/len() calls below cannot raise.
    camera = data.get("camera") or {}
    annotations = data.get("annotations") or []
    frame_range = data.get("frame_range") or {}

    def cell(value) -> str:
        # Keep free-text fields from breaking the Markdown table: a literal
        # "|" would open a new column and a newline would end the row.
        return str(value).replace("|", "\\|").replace("\n", " ")

    rows = [
        ("Task", data.get("task", "Unknown")),
        ("Language Instruction", data.get("language_instruction", data.get("task", "N/A"))),
        ("Frames", data.get("num_frames", len(camera.get("x") or []))),
        ("FPS", data.get("fps", 30)),  # 30 is a display fallback, not read from the data
        ("Frame Range", f"{frame_range.get('start', 0)} → {frame_range.get('end', 'N/A')}"),
        ("Annotations", f"{len(annotations)} segments"),
    ]
    lines = [f"## Episode: {episode_id}", "", "| Property | Value |", "|----------|-------|"]
    lines.extend(f"| {name} | {cell(value)} |" for name, value in rows)
    stats = "\n".join(lines) + "\n"
    return fig_3d, fig_ts, stats
# Build the Gradio interface.
with gr.Blocks(
    title="DI Annotation Visualizer",
    theme=gr.themes.Soft(primary_hue="blue"),
) as demo:
    gr.Markdown(f"""
# DI Annotation Data Visualizer
Visualizing data from: [{DATASET_REPO}](https://huggingface.co/datasets/{DATASET_REPO})
This shows **YOUR** annotated episodes, not the old training dataset.
""")

    with gr.Row():
        episode_dropdown = gr.Dropdown(
            label="Select Episode",
            # Fetched once while the UI is built; refreshed again on every page
            # load and on Refresh (see the events below).
            choices=list_episodes(),
            interactive=True,
            scale=3,
        )
        refresh_btn = gr.Button("🔄 Refresh", scale=1)
        load_btn = gr.Button("📂 Load & Visualize", variant="primary", scale=1)

    stats_output = gr.Markdown()

    with gr.Tabs():
        with gr.TabItem("📊 3D Trajectory"):
            plot_3d = gr.Plot(label="3D Trajectory")
        with gr.TabItem("📈 Time Series"):
            plot_ts = gr.Plot(label="Position vs Time")

    def refresh_episode_choices():
        """Re-query the dataset and return the dropdown with updated choices."""
        return gr.Dropdown(choices=list_episodes())

    # Events
    load_btn.click(
        fn=load_and_visualize,
        inputs=[episode_dropdown],
        outputs=[plot_3d, plot_ts, stats_output],
    )
    refresh_btn.click(
        fn=refresh_episode_choices,
        outputs=[episode_dropdown],
    )
    # The build-time choices go stale in a long-running Space; refresh them
    # whenever a browser loads the page so new episodes appear automatically.
    demo.load(
        fn=refresh_episode_choices,
        outputs=[episode_dropdown],
    )

    gr.Markdown(f"""
---
**Dataset:** [{DATASET_REPO}](https://huggingface.co/datasets/{DATASET_REPO})
""")
if __name__ == "__main__":
    # Listen on every interface, on Gradio's conventional port.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_options)
|