"""Gradio viewer for a Hugging Face hand-pose dataset.

Downloads a dataset snapshot, reads per-episode parquet files, and renders the
RGB video, the language instruction and per-joint X/Y/Z/yaw/pitch/roll plots.
"""

import html
import json
import logging
import os
import sys
from functools import lru_cache
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import gradio as gr
import pandas as pd
import plotly.graph_objects as go
import plotly.io as pio
from huggingface_hub import HfApi, snapshot_download

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)

DEFAULT_DATASET_ID = os.getenv(
    "DATASET_ID", "raffaelkultyshev/humanoid-robots-training-dataset"
)
LOCAL_DATASET_DIR = Path("dataset_cache")
HF_TOKEN = os.getenv("HF_TOKEN")

# Short alias -> legend label shown in the plots.
JOINT_ALIASES = {
    "wrist": "Wrist",
    "thumb_tip": "Thumb Tip",
    "index_mcp": "Index MCP",
    "index_tip": "Index Tip",
}

# Short alias -> value matched against the parquet "joint_name" column.
JOINT_NAME_MAP = {
    "wrist": "WRIST",
    "thumb_tip": "THUMB_TIP",
    "index_mcp": "INDEX_FINGER_MCP",
    "index_tip": "INDEX_FINGER_TIP",
}

# Parquet metric column -> axis title.
METRIC_LABELS = {
    "x_cm": "X (cm)",
    "y_cm": "Y (cm)",
    "z_cm": "Z (cm)",
    "yaw_deg": "Yaw (°)",
    "pitch_deg": "Pitch (°)",
    "roll_deg": "Roll (°)",
}

# Layout of the plot grid; PLOT_ORDER is the same metrics flattened row by row.
PLOT_GRID = [
    ["x_cm", "y_cm", "z_cm"],
    ["yaw_deg", "pitch_deg", "roll_deg"],
]
PLOT_ORDER = [metric for row in PLOT_GRID for metric in row]

CUSTOM_CSS = """
:root, .gradio-container, body {
    background-color: #050a18 !important;
    color: #f8fafc !important;
    font-family: 'Inter', 'Segoe UI', system-ui, sans-serif;
}
.side-panel {
    background: #0f172a;
    padding: 20px;
    border-radius: 18px;
    border: 1px solid #1f2b47;
    min-height: 100%;
}
.stats-card ul {
    list-style: none;
    padding: 0;
    margin: 0;
    font-size: 0.92rem;
}
.stats-card li {
    margin-bottom: 10px;
    color: #e2e8f0;
}
.stats-card span {
    display: inline-block;
    margin-right: 6px;
    color: #7dd3fc;
}
.episodes-title {
    margin: 18px 0 8px;
    font-size: 0.78rem;
    text-transform: uppercase;
    letter-spacing: 0.14em;
    color: #94a3b8;
}
.episode-list .gr-form {
    padding: 0;
}
.episode-list .gr-form > div {
    gap: 0;
}
.episode-list input[type="radio"] {
    display: none;
}
.episode-list label {
    background: transparent !important;
    border: none !important;
    color: #cbd5f5 !important;
    padding: 3px 0 !important;
    justify-content: flex-start;
    font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace;
    font-size: 0.9rem;
    text-decoration: underline;
}
.episode-list label:hover {
    color: #67e8f9 !important;
    cursor: pointer;
}
.episode-list input[type="radio"]:checked + label {
    color: #facc15 !important;
    font-weight: 700;
    margin-left: -2px;
}
.main-panel {
    padding-top: 8px;
}
.instruction-card {
    background: #0f172a;
    padding: 18px 20px;
    border-radius: 18px;
    border: 1px solid #1f2b47;
}
.instruction-label {
    font-size: 0.75rem;
    letter-spacing: 0.12em;
    text-transform: uppercase;
    color: #94a3b8;
    margin-bottom: 10px;
}
.instruction-text {
    font-size: 1.1rem;
    line-height: 1.5;
}
.video-card {
    background: #0f172a;
    border: 1px solid #1f2b47;
    border-radius: 18px;
    padding: 18px 20px;
    margin-top: 18px;
}
.video-title {
    font-size: 0.78rem;
    text-transform: uppercase;
    letter-spacing: 0.18em;
    color: #94a3b8;
    margin-bottom: 8px;
}
.video-panel video {
    border-radius: 12px;
    border: 1px solid #1f2b47;
    background: #030712;
}
.download-button button {
    border-radius: 999px;
    border: 1px solid #334155;
    background: #1e293b;
    color: #f8fafc;
    font-size: 0.85rem;
    padding: 8px 24px;
}
.download-button button:hover {
    border-color: #67e8f9;
    color: #67e8f9;
}
.plots-wrap {
    margin-top: 18px;
}
.plots-wrap .gr-row {
    gap: 16px;
}
.plot-html {
    background: #111a2c;
    border-radius: 12px;
    padding: 10px;
    border: 1px solid #1f2b47;
    min-height: 320px;
}
.plot-html iframe {
    width: 100%;
    height: 300px;
    border: none;
}
"""


@lru_cache(maxsize=1)
def get_dataset_revision(repo_id: str) -> Optional[str]:
    """Return the commit sha of the dataset repo, or ``None`` if the lookup fails.

    Any exception from the Hub call is logged as a warning and swallowed, so
    callers fall back to an unpinned (``revision=None``) download.
    """
    try:
        info = HfApi(token=HF_TOKEN).repo_info(repo_id=repo_id, repo_type="dataset")
        return info.sha
    except Exception as exc:  # best-effort: the viewer still works without a pin
        logger.warning("Could not fetch dataset revision for %s: %s", repo_id, exc)
        return None


@lru_cache(maxsize=2)
def get_dataset_root(repo_id: str, revision: Optional[str]) -> Path:
    """Download (or reuse) the dataset snapshot and return its local root path."""
    local_path = snapshot_download(
        repo_id=repo_id,
        repo_type="dataset",
        local_dir=LOCAL_DATASET_DIR,
        local_dir_use_symlinks=False,
        revision=revision,
        token=HF_TOKEN,
    )
    return Path(local_path)


@lru_cache(maxsize=2)
def load_info(repo_id: str, revision: Optional[str]) -> Dict:
    """Load and return the parsed ``meta/info.json`` of the dataset snapshot."""
    root = get_dataset_root(repo_id, revision)
    info_path = root / "meta" / "info.json"
    with open(info_path, "r", encoding="utf-8") as f:
        return json.load(f)


def resolve_path(
    root: Path,
    template: Union[str, Dict[str, str]],
    episode_chunk: int,
    episode_index: int,
) -> Path:
    """Expand a path template from the metadata into a path under ``root``.

    Args:
        root: Local dataset root.
        template: A format string using ``{episode_chunk}`` / ``{episode_index}``,
            or a dict of such strings, in which case its ``"rgb"`` entry is used.
        episode_chunk: Value substituted for ``episode_chunk``.
        episode_index: Value substituted for ``episode_index``.

    Raises:
        ValueError: If ``template`` is a dict without an ``"rgb"`` entry.
    """
    if isinstance(template, dict):
        rgb_template = template.get("rgb")
        if rgb_template is None:
            raise ValueError("RGB template missing from metadata")
        template = rgb_template
    return root / template.format(
        episode_chunk=episode_chunk, episode_index=episode_index
    )


@lru_cache(maxsize=64)
def load_episode(repo_id: str, episode_index: int, revision: Optional[str]) -> Dict:
    """Load one episode's trajectories, video path and instruction text.

    Returns:
        A dict with keys ``"timestamps"``, ``"state_df"``, ``"rgb_path"`` and
        ``"instruction"``.

    Raises:
        ValueError: If the episode is not listed in the metadata.
        FileNotFoundError: If the episode parquet file does not exist locally.
    """
    info = load_info(repo_id, revision)
    root = get_dataset_root(repo_id, revision)

    episode_meta = next(
        (ep for ep in info["episodes"] if ep["episode_index"] == episode_index),
        None,
    )
    if not episode_meta:
        raise ValueError(f"Episode {episode_index} not found in metadata")

    chunk = episode_meta["episode_chunk"]
    parquet_path = resolve_path(root, info["data_path"], chunk, episode_index)
    if not parquet_path.exists():
        raise FileNotFoundError(f"Parquet file not found: {parquet_path}")

    df = pd.read_parquet(parquet_path)
    timestamps, state_df = build_state_dataframe(df)
    rgb_path = resolve_path(root, info["video_path"], chunk, episode_index)

    # Instruction fallback chain: episode metadata -> first non-null value in
    # the parquet column -> dataset-level "task" -> hard-coded default.
    instruction = episode_meta.get("language_instruction")
    if not instruction:
        parquet_values = (
            df["language_instruction"].dropna()
            if "language_instruction" in df.columns
            else None
        )
        if parquet_values is not None and not parquet_values.empty:
            instruction = parquet_values.iloc[0]
        else:
            instruction = info.get("task", "Tape roll to bowl")

    return {
        "timestamps": timestamps,
        "state_df": state_df,
        "rgb_path": rgb_path,
        "instruction": instruction,
    }


def build_state_dataframe(df: pd.DataFrame) -> Tuple[List[float], pd.DataFrame]:
    """Pivot long-format joint rows into one wide row per frame.

    Args:
        df: Episode table with ``frame_idx``, ``timestamp_s`` and ``joint_name``
            columns plus any of the metric columns named in ``METRIC_LABELS``.

    Returns:
        ``(timestamps, state_df)`` where ``timestamps`` lists ``timestamp_s`` per
        unique frame in ascending ``frame_idx`` order, and ``state_df`` has one
        row per such frame and one ``"{alias}_{metric}"`` float column for every
        joint/metric pair present in ``df``.

    Raises:
        ValueError: If the timing columns or the ``joint_name`` column are missing.
    """
    if "frame_idx" not in df.columns or "timestamp_s" not in df.columns:
        raise ValueError("Episode parquet missing frame timing information.")
    if "joint_name" not in df.columns:
        raise ValueError("Episode parquet missing joint_name column.")

    frame_times = (
        df[["frame_idx", "timestamp_s"]]
        .drop_duplicates("frame_idx")
        .set_index("frame_idx")
        .sort_index()
    )
    frame_indices = frame_times.index.to_list()

    state_df = pd.DataFrame(index=frame_indices)
    for alias, joint_name in JOINT_NAME_MAP.items():
        joint_df = (
            df[df["joint_name"] == joint_name]
            # reindex() refuses duplicate labels; keep the first row per frame,
            # the same policy frame_times uses above.
            .drop_duplicates("frame_idx")
            .set_index("frame_idx")
            .sort_index()
            .reindex(frame_indices)
        )
        for metric in METRIC_LABELS:
            if metric in joint_df.columns:
                state_df[f"{alias}_{metric}"] = joint_df[metric].astype(float)

    state_df.reset_index(drop=True, inplace=True)
    timestamps = frame_times["timestamp_s"].to_list()
    return timestamps, state_df


def build_plot_fig(data: Dict, metric: str) -> go.Figure:
    """Build a line chart of ``metric`` over time with one trace per joint.

    Joints whose ``"{alias}_{metric}"`` column is absent from
    ``data["state_df"]`` are skipped.
    """
    timestamps = data["timestamps"]
    state_df = data["state_df"]

    fig = go.Figure()
    for alias, label in JOINT_ALIASES.items():
        col_name = f"{alias}_{metric}"
        if col_name not in state_df.columns:
            continue
        fig.add_trace(
            go.Scatter(
                x=timestamps,
                y=state_df[col_name],
                mode="lines",
                name=label,
            )
        )

    fig.update_layout(
        margin=dict(l=20, r=20, t=30, b=20),
        height=250,
        template="plotly_dark",
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
        xaxis_title="Time (s)",
        yaxis_title=METRIC_LABELS[metric],
    )
    fig.update_xaxes(showgrid=True, gridwidth=0.5, gridcolor="rgba(255,255,255,0.1)")
    fig.update_yaxes(showgrid=True, gridwidth=0.5, gridcolor="rgba(255,255,255,0.1)")
    return fig


def build_plot_html(data: Dict, metric: str) -> str:
    """Render the plot for ``metric`` as an HTML fragment (plotly.js from CDN)."""
    fig = build_plot_fig(data, metric)
    return pio.to_html(fig, include_plotlyjs="cdn", full_html=False)


def format_episode_label(idx: int) -> str:
    """Return the radio label for an episode index, e.g. ``"Episode 03"``."""
    return f"Episode {idx:02d}"


def parse_episode_label(label: str) -> int:
    """Inverse of ``format_episode_label``: extract the integer episode index."""
    return int(label.replace("Episode", "").strip())


def format_instruction_html(text: str) -> str:
    """Return the instruction card markup with ``text`` HTML-escaped."""
    # str() guards against non-string values read from the parquet column.
    safe_text = html.escape(str(text))
    # NOTE(review): the tags were missing from the reviewed source; they are
    # reconstructed from the .instruction-* rules in CUSTOM_CSS — confirm.
    return (
        '<div class="instruction-card">'
        '<div class="instruction-label">Language Instruction</div>'
        f'<div class="instruction-text">{safe_text}</div>'
        "</div>"
    )


def build_interface():
    """Build the Gradio Blocks app for the default dataset and return it.

    Raises:
        RuntimeError: If the dataset metadata lists no episodes.
    """
    revision = get_dataset_revision(DEFAULT_DATASET_ID)
    info = load_info(DEFAULT_DATASET_ID, revision)
    episode_indices = sorted(ep["episode_index"] for ep in info["episodes"])
    if not episode_indices:
        raise RuntimeError("No episodes found in dataset metadata.")

    default_idx = episode_indices[0]
    default_label = format_episode_label(default_idx)
    default_data = load_episode(DEFAULT_DATASET_ID, default_idx, revision)
    default_video = str(default_data["rgb_path"])
    default_instruction = default_data["instruction"]
    default_figs = {
        metric: build_plot_html(default_data, metric) for metric in METRIC_LABELS
    }

    total_frames = sum(ep.get("num_frames", 0) for ep in info["episodes"])
    fps = info.get("fps", 30.0)
    # NOTE(review): the stats markup was missing from the reviewed source; this
    # is reconstructed from the .stats-card CSS rules and the values computed
    # above — confirm the labels and wording.
    stats_html = f"""
<div class="stats-card">
  <ul>
    <li><span>Number of episodes:</span>{len(episode_indices)}</li>
    <li><span>Total frames:</span>{total_frames}</li>
    <li><span>FPS:</span>{fps}</li>
  </ul>
</div>
"""

    theme = gr.themes.Soft(
        primary_hue="cyan", secondary_hue="blue", neutral_hue="slate"
    ).set(
        body_background_fill="#0c1424",
        body_text_color="#f8fafc",
        block_background_fill="#111a2c",
        block_title_text_color="#f8fafc",
        input_background_fill="#151f33",
        border_color_primary="#1f2b47",
        shadow_drop="none",
    )

    # NOTE(review): the nesting below is inferred from the component order and
    # CSS class names because the original indentation was lost — confirm.
    with gr.Blocks(theme=theme, css=CUSTOM_CSS) as demo:
        gr.Markdown("# Humanoid Robots Hand Pose Viewer")
        gr.Markdown(
            "Visualize RGB + 6DoF hand trajectories for all Moving_Mini tasks "
            "(humanoid-robots-training-dataset)."
        )

        with gr.Row(equal_height=True):
            with gr.Column(scale=1, min_width=260, elem_classes=["side-panel"]):
                gr.HTML(stats_html)
                gr.HTML('<div class="episodes-title">Episodes</div>')
                episode_radio = gr.Radio(
                    choices=[format_episode_label(i) for i in episode_indices],
                    value=default_label,
                    label="Episodes",
                    elem_classes=["episode-list"],
                )

            with gr.Column(scale=2, min_width=640, elem_classes=["main-panel"]):
                instruction_box = gr.HTML(
                    format_instruction_html(default_instruction),
                    label="Language Instruction",
                )

                with gr.Column(elem_classes=["video-card"]):
                    gr.HTML('<div class="video-title">RGB</div>')
                    video = gr.Video(
                        height=360,
                        value=default_video,
                        elem_classes=["video-panel"],
                        show_label=False,
                        show_download_button=False,
                    )
                    download_button = gr.DownloadButton(
                        label="Download",
                        value=default_video,
                        elem_classes=["download-button"],
                    )

                plot_outputs = []
                gr.Markdown("### Joint trajectories", elem_classes=["plots-title"])
                with gr.Column(elem_classes=["plots-wrap"]):
                    for row in PLOT_GRID:
                        with gr.Row():
                            for metric in row:
                                plot = gr.HTML(
                                    value=default_figs[metric],
                                    elem_classes=["plot-html"],
                                )
                                plot_outputs.append(plot)

        # Order must match the list returned by load_episode_payload.
        outputs = [instruction_box, video, download_button] + plot_outputs

        def load_episode_payload(label: str):
            """Return new values for ``outputs`` for the selected episode label."""
            idx = parse_episode_label(label)
            data = load_episode(DEFAULT_DATASET_ID, idx, revision)
            video_path = str(data["rgb_path"])
            figs = [build_plot_html(data, metric) for metric in PLOT_ORDER]
            return [
                format_instruction_html(data["instruction"]),
                video_path,
                # Per-component ``.update()`` classmethods were removed in
                # Gradio 4; ``gr.update`` is the supported equivalent.
                gr.update(value=video_path),
                *figs,
            ]

        episode_radio.change(
            fn=load_episode_payload, inputs=episode_radio, outputs=outputs
        )

    return demo


def main():
    """Build the interface and launch it with request queuing enabled."""
    demo = build_interface()
    demo.queue().launch(show_api=False)


if __name__ == "__main__":
    main()