Raffael-Kultyshev's picture
Upload app.py with huggingface_hub
ccb16eb verified
raw
history blame
6.24 kB
#!/usr/bin/env python3
"""
DI Annotation Data Visualizer
Visualizes data from: DynamicIntelligence/di-annotation-data
"""
import gradio as gr
from huggingface_hub import hf_hub_download, list_repo_files, HfApi
import json
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from pathlib import Path
import pandas as pd
# Hugging Face dataset repository (accessed with repo_type="dataset") holding
# the annotated episodes this app visualizes -- not the older training dataset.
DATASET_REPO = "DynamicIntelligence/di-annotation-data"
def list_episodes():
    """Return the sorted episode IDs found in the dataset repository.

    An episode is any directory directly under ``episodes/`` that contains at
    least one file, i.e. a repo path of the form ``episodes/<episode_id>/...``.
    Files sitting directly in ``episodes/`` (for example a README) are not
    episodes and are skipped.

    Returns:
        The sorted list of episode IDs. If none are found, the single
        placeholder ``"No episodes found"``. If listing the repository fails,
        a single ``"Error: <message>"`` entry, so the failure is visible in
        the dropdown instead of crashing the app.
    """
    try:
        files = list(HfApi().list_repo_files(repo_id=DATASET_REPO, repo_type="dataset"))
    except Exception as e:  # network / auth / missing-repo failures
        return [f"Error: {e}"]

    episodes = set()
    for path in files:
        parts = path.split("/")
        # Require at least episodes/<id>/<file> and a non-empty <id>; a plain
        # "episodes/<file>" path is a loose file, not an episode directory.
        if len(parts) >= 3 and parts[0] == "episodes" and parts[1]:
            episodes.add(parts[1])
    return sorted(episodes) if episodes else ["No episodes found"]
def load_episode_data(episode_id: str):
    """Download and parse ``episodes/<episode_id>/trajectory.json``.

    Args:
        episode_id: Episode directory name inside the dataset repository.

    Returns:
        ``(data, None)`` on success, where ``data`` is the parsed JSON object.
        ``(None, message)`` if the download or parse fails, or if the file's
        top-level value is not a JSON object. Errors are returned rather than
        raised so the UI can display them. ``message`` is never empty, so
        callers can rely on ``if error:``.
    """
    try:
        local_path = hf_hub_download(
            repo_id=DATASET_REPO,
            filename=f"episodes/{episode_id}/trajectory.json",
            repo_type="dataset",
        )
        # JSON text is UTF-8; don't depend on the platform's locale encoding.
        with open(local_path, encoding="utf-8") as f:
            data = json.load(f)
    except Exception as e:
        # Some exceptions stringify to "", which would look like "no error".
        return None, str(e) or type(e).__name__

    # Consumers read fields with data.get(...); reject other top-level JSON
    # types (list, number, ...) here with a readable message.
    if not isinstance(data, dict):
        return None, f"trajectory.json must contain a JSON object, got {type(data).__name__}"
    return data, None
def create_plots(data: dict):
    """Build the 3D camera-trajectory figure and the per-axis time-series figure.

    Args:
        data: Parsed trajectory.json. Reads ``camera`` (a mapping whose
            ``x``/``y``/``z`` entries are parallel coordinate lists) and,
            optionally, ``timestamps``.

    Returns:
        ``(fig_3d, fig_ts)``. If there is no usable position data, the same
        placeholder figure is returned for both.
    """
    # Treat keys that are present but null the same as missing keys.
    camera = data.get("camera") or {}
    x = camera.get("x") or []
    y = camera.get("y") or []
    z = camera.get("z") or []

    # Only samples with all three coordinates can be plotted; a missing or
    # shorter axis would otherwise raise IndexError at the Start/End markers.
    n = min(len(x), len(y), len(z))
    if n == 0:
        empty = go.Figure()
        empty.add_annotation(text="No trajectory data", showarrow=False, font_size=20)
        return empty, empty
    x, y, z = x[:n], y[:n], z[:n]

    timestamps = data.get("timestamps")
    if timestamps:
        timestamps = timestamps[:n]
        time_axis_title = "Time (s)"
    else:
        # Without timestamps the horizontal axis is a frame index, not seconds.
        timestamps = list(range(n))
        time_axis_title = "Frame"

    # 3D trajectory with start/end markers
    fig_3d = go.Figure()
    fig_3d.add_trace(go.Scatter3d(
        x=x, y=y, z=z,
        mode='lines',
        line=dict(color='blue', width=4),
        name='Camera',
    ))
    fig_3d.add_trace(go.Scatter3d(
        x=[x[0]], y=[y[0]], z=[z[0]],
        mode='markers',
        marker=dict(color='green', size=10),
        name='Start',
    ))
    fig_3d.add_trace(go.Scatter3d(
        x=[x[-1]], y=[y[-1]], z=[z[-1]],
        mode='markers',
        marker=dict(color='red', size=10),
        name='End',
    ))
    fig_3d.update_layout(
        title="Camera Trajectory (World Frame)",
        scene=dict(xaxis_title='X (m)', yaxis_title='Y (m)', zaxis_title='Z (m)'),
        height=500,
    )

    # One stacked subplot per axis, sharing the time/frame axis
    fig_ts = make_subplots(rows=3, cols=1, shared_xaxes=True,
                           subplot_titles=['Camera X', 'Camera Y', 'Camera Z'])
    series = ((x, 'X', 'red'), (y, 'Y', 'green'), (z, 'Z', 'blue'))
    for row, (values, label, color) in enumerate(series, start=1):
        fig_ts.add_trace(
            go.Scatter(x=timestamps, y=values, name=label, line=dict(color=color)),
            row=row, col=1,
        )
    fig_ts.update_layout(height=500, title="Position vs Time", showlegend=True)
    fig_ts.update_xaxes(title_text=time_axis_title, row=3, col=1)
    return fig_3d, fig_ts
def load_and_visualize(episode_id: str):
    """Load one episode and build the outputs for the plot and stats components.

    Args:
        episode_id: Value selected in the episode dropdown. Empty values and the
            dropdown's placeholder entries ("Error: ...", "No episodes found")
            are treated as "nothing selected".

    Returns:
        ``(fig_3d, fig_ts, stats_markdown)``.
    """
    def placeholder(text, font_size):
        # A blank figure carrying a centered message.
        fig = go.Figure()
        fig.add_annotation(text=text, showarrow=False, font_size=font_size)
        return fig

    def cell(value):
        # Keep each value inside its table cell: "|" would start a new column
        # and a newline would end the markdown table row.
        return str(value).replace("|", "\\|").replace("\n", " ")

    if not episode_id or episode_id.startswith("Error") or episode_id == "No episodes found":
        empty = placeholder("Select an episode to visualize", 20)
        return empty, empty, "No episode selected"

    data, error = load_episode_data(episode_id)
    # Also guard against a failure reported with an empty message, which would
    # otherwise hand None (or a non-dict) to create_plots.
    if error or not isinstance(data, dict):
        error = error or "Episode data could not be loaded"
        empty = placeholder(f"Error: {error}", 14)
        return empty, empty, f"**Error:** {error}"

    fig_3d, fig_ts = create_plots(data)

    # Fields may be present but null in the JSON; treat null like missing.
    camera = data.get("camera") or {}
    annotations = data.get("annotations") or []
    frame_range = data.get("frame_range") or {}
    task = data.get('task', 'Unknown')
    instruction = data.get('language_instruction', data.get('task', 'N/A'))
    num_frames = data.get('num_frames', len(camera.get('x') or []))
    fps = data.get('fps', 30)

    stats = f"""
## Episode: {episode_id}

| Property | Value |
|----------|-------|
| Task | {cell(task)} |
| Language Instruction | {cell(instruction)} |
| Frames | {cell(num_frames)} |
| FPS | {cell(fps)} |
| Frame Range | {cell(frame_range.get('start', 0))} → {cell(frame_range.get('end', 'N/A'))} |
| Annotations | {len(annotations)} segments |
"""
    return fig_3d, fig_ts, stats
# Build the Gradio interface. The episode list is fetched once when this module
# runs to populate the dropdown; the Refresh button fetches it again.
with gr.Blocks(
    title="DI Annotation Visualizer",
    theme=gr.themes.Soft(primary_hue="blue"),
) as demo:
    gr.Markdown(f"""
# DI Annotation Data Visualizer

Visualizing data from: [{DATASET_REPO}](https://huggingface.co/datasets/{DATASET_REPO})

This shows **YOUR** annotated episodes, not the old training dataset.
""")
    with gr.Row():
        episode_dropdown = gr.Dropdown(
            label="Select Episode",
            choices=list_episodes(),
            interactive=True,
            scale=3,
        )
        refresh_btn = gr.Button("🔄 Refresh", scale=1)
        load_btn = gr.Button("📊 Load & Visualize", variant="primary", scale=1)
    stats_output = gr.Markdown()
    with gr.Tabs():
        with gr.TabItem("🌐 3D Trajectory"):
            plot_3d = gr.Plot(label="3D Trajectory")
        with gr.TabItem("📈 Time Series"):
            plot_ts = gr.Plot(label="Position vs Time")

    # Events
    load_btn.click(
        fn=load_and_visualize,
        inputs=[episode_dropdown],
        outputs=[plot_3d, plot_ts, stats_output],
    )
    # NOTE(review): relies on returning a component instance to update the
    # existing dropdown's choices (Gradio 4+ behavior) -- confirm the version.
    refresh_btn.click(
        fn=lambda: gr.Dropdown(choices=list_episodes()),
        outputs=[episode_dropdown],
    )

    gr.Markdown(f"""
---
**Dataset:** [{DATASET_REPO}](https://huggingface.co/datasets/{DATASET_REPO})
""")
if __name__ == "__main__":
    # Serve on all network interfaces at port 7860.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_options)