# Hugging Face file-viewer header (metadata, not part of the program):
# uploaded by Raffael-Kultyshev; commit d17f5c3 "Upload app.py with huggingface_hub" (verified); 4.48 kB
"""
DI Annotation Data Visualizer - Simple working version
"""
import gradio as gr
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from huggingface_hub import hf_hub_download
import json
# Hugging Face dataset repo id ("owner/name") that the data is downloaded from
# and that the UI header links to.
DATASET_REPO: str = "DynamicIntelligence/di-annotation-data"
# Repo-relative parquet files tried in order: the LeRobot-style episode layout
# first, then the older single-file trajectory export.
_DATA_FILE_CANDIDATES = (
    "data/chunk-000/episode_000000.parquet",
    "data/Test_Data_Lidar_trajectory.parquet",
)


def load_data(candidates=_DATA_FILE_CANDIDATES):
    """Load trajectory data from the Hugging Face dataset repo.

    Each path in *candidates* is downloaded from ``DATASET_REPO`` with
    ``hf_hub_download`` and read with ``pd.read_parquet``; the first path
    that succeeds wins.

    Args:
        candidates: Repo-relative parquet paths to try, in order.

    Returns:
        ``(df, None)`` on success, or ``(None, message)`` when every
        candidate failed. The message lists the error for EACH path, so a
        failure of the preferred file is not hidden by the fallback's error.
    """
    failures = []
    for filename in candidates:
        try:
            local_path = hf_hub_download(
                repo_id=DATASET_REPO,
                filename=filename,
                repo_type="dataset",
            )
            df = pd.read_parquet(local_path)
        except Exception as exc:  # UI boundary: report the problem, don't crash the app
            failures.append(f"{filename}: {exc}")
            continue
        return df, None
    detail = "; ".join(failures) if failures else "no candidate files given"
    return None, f"Error loading data: {detail}"
def _xyz_from_state(state):
    """Return the first three entries of one ``observation.state`` value as a tuple.

    pandas usually returns parquet list columns as ``numpy.ndarray`` rather
    than ``list``, so any list-like value is accepted. Missing entries, and
    values that are not list-like at all (e.g. ``None``), become 0.
    """
    values = list(state) if pd.api.types.is_list_like(state) else []
    return tuple(values[i] if i < len(values) else 0 for i in range(3))


def _extract_trajectory(df):
    """Return ``(camera_x, camera_y, camera_z, timestamps)`` lists from *df*.

    Handles two layouts:
    - LeRobot: an ``observation.state`` column whose first three entries are
      treated as camera x/y/z.
    - Older: separate ``camera_x``/``camera_y``/``camera_z`` columns. A missing
      column yields an empty list.

    If there is no ``timestamp`` column, the frame index is used instead.
    """
    columns = df.columns
    if 'observation.state' in columns:
        points = [_xyz_from_state(state) for state in df['observation.state']]
        camera_x = [p[0] for p in points]
        camera_y = [p[1] for p in points]
        camera_z = [p[2] for p in points]
    else:
        camera_x = df['camera_x'].tolist() if 'camera_x' in columns else []
        camera_y = df['camera_y'].tolist() if 'camera_y' in columns else []
        camera_z = df['camera_z'].tolist() if 'camera_z' in columns else []
    timestamps = df['timestamp'].tolist() if 'timestamp' in columns else list(range(len(df)))
    return camera_x, camera_y, camera_z, timestamps


def create_plots(df):
    """Build the 3D camera-path figure and the per-axis time-series figure.

    Args:
        df: Trajectory DataFrame from ``load_data()``, or None.

    Returns:
        ``(fig_3d, fig_ts)`` plotly figures, or ``(None, None)`` in two cases:
        *df* is None, or any of the x/y/z series is empty. Plotting
        mismatched-length series would give a broken figure.
    """
    if df is None:
        return None, None

    camera_x, camera_y, camera_z, timestamps = _extract_trajectory(df)
    if not (camera_x and camera_y and camera_z):
        return None, None

    # 3D trajectory; markers are coloured by timestamp
    fig_3d = go.Figure(data=[go.Scatter3d(
        x=camera_x, y=camera_y, z=camera_z,
        mode='lines+markers',
        marker=dict(size=2, color=timestamps, colorscale='Viridis'),
        line=dict(width=2, color='blue'),
        name='Camera Path',
    )])
    fig_3d.update_layout(
        title='Camera 3D Trajectory',
        scene=dict(
            xaxis_title='X (m)',
            yaxis_title='Y (m)',
            zaxis_title='Z (m)',
        ),
        height=500,
    )

    # One stacked subplot per axis, sharing the timestamp x-axis values
    fig_ts = make_subplots(rows=3, cols=1, subplot_titles=['Camera X', 'Camera Y', 'Camera Z'])
    fig_ts.add_trace(go.Scatter(x=timestamps, y=camera_x, name='X', line=dict(color='red')), row=1, col=1)
    fig_ts.add_trace(go.Scatter(x=timestamps, y=camera_y, name='Y', line=dict(color='green')), row=2, col=1)
    fig_ts.add_trace(go.Scatter(x=timestamps, y=camera_z, name='Z', line=dict(color='blue')), row=3, col=1)
    fig_ts.update_layout(height=600, title='Camera Position vs Time', showlegend=False)
    return fig_3d, fig_ts
def visualize():
    """Load the dataset and produce everything the UI displays.

    Returns:
        ``(fig_3d, fig_ts, stats_markdown)``. If loading fails or the frame
        is empty, both figures are None and the third element is the
        error or explanation text.
    """
    df, error = load_data()
    if error:
        # load_data()'s message already reads "Error loading data: ...", so it
        # is shown as-is instead of being prefixed with a second "Error:".
        return None, None, error
    if df is None or len(df) == 0:
        return None, None, "No data found"

    fig_3d, fig_ts = create_plots(df)

    # Column labels are not guaranteed to be strings, so str() them before
    # join(). Only append "..." when the list was actually truncated.
    column_names = [str(c) for c in df.columns]
    shown_columns = ', '.join(column_names[:5])
    if len(column_names) > 5:
        shown_columns += '...'

    stats_lines = [
        "",
        "**Dataset Stats:**",
        f"- Total frames: {len(df)}",
        f"- Columns: {shown_columns}",
        f"- Source: {DATASET_REPO}",
    ]
    if fig_3d is None:
        stats_lines.append("- Note: no plottable camera trajectory found in this data")
    stats_lines.append("")
    return fig_3d, fig_ts, "\n".join(stats_lines)
# ---- Gradio UI ---------------------------------------------------------------
# Layout, top to bottom:
#   1. title and a link to the dataset
#   2. a load button
#   3. the two plots side by side
#   4. the stats panel
# The same visualize() callback runs on button click and again when the page loads.
with gr.Blocks(title="DI Annotation Data Visualizer", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# DI Annotation Data Visualizer")
    gr.Markdown(f"Visualizing trajectory data from [{DATASET_REPO}](https://huggingface.co/datasets/{DATASET_REPO})")
    load_btn = gr.Button("Load & Visualize Data", variant="primary")
    with gr.Row():
        plot_3d = gr.Plot(label="3D Trajectory")
        plot_ts = gr.Plot(label="Time Series")
    stats_output = gr.Markdown()

    # visualize() returns (3D figure, time-series figure, stats markdown) in this order.
    _visualize_outputs = (plot_3d, plot_ts, stats_output)
    load_btn.click(visualize, inputs=[], outputs=list(_visualize_outputs))
    demo.load(visualize, inputs=[], outputs=list(_visualize_outputs))


if __name__ == "__main__":
    # Bind all network interfaces on port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)