|
|
import html
import json
import logging
import os
import sys
from functools import lru_cache
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import gradio as gr
import pandas as pd
import plotly.graph_objects as go
import plotly.io as pio
from huggingface_hub import snapshot_download, HfApi
|
|
|
|
|
# Send INFO-and-above log records to stdout, tagged with timestamp, logger name and level.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)

# Hub dataset repository to visualize; the DATASET_ID environment variable
# overrides the built-in default.
DEFAULT_DATASET_ID = os.getenv(
    "DATASET_ID", "DynamicIntelligence/humanoid-robots-training-dataset"
)
# Directory handed to snapshot_download() as ``local_dir``.
LOCAL_DATASET_DIR = Path("dataset_cache")
# Hub access token read from the environment; None when HF_TOKEN is unset.
HF_TOKEN = os.getenv("HF_TOKEN")
|
|
|
|
|
# Column-prefix alias -> legend label shown for that joint's trace in each plot.
JOINT_ALIASES = {
    "wrist": "Wrist",
    "thumb_tip": "Thumb Tip",
    "index_mcp": "Index MCP",
    "index_tip": "Index Tip",
}

# Column-prefix alias -> value matched against the parquet ``joint_name`` column.
JOINT_NAME_MAP = {
    "wrist": "WRIST",
    "thumb_tip": "THUMB_TIP",
    "index_mcp": "INDEX_FINGER_MCP",
    "index_tip": "INDEX_FINGER_TIP",
}

# Parquet metric column -> y-axis title used for the corresponding plot.
METRIC_LABELS = {
    "x_cm": "X (cm)",
    "y_cm": "Y (cm)",
    "z_cm": "Z (cm)",
    "yaw_deg": "Yaw (°)",
    "pitch_deg": "Pitch (°)",
    "roll_deg": "Roll (°)",
}

# Plot layout: each inner list becomes one gr.Row of plots in the interface.
PLOT_GRID = [
    ["x_cm", "y_cm", "z_cm"],
    ["yaw_deg", "pitch_deg", "roll_deg"],
]

# PLOT_GRID flattened row by row; the episode callback emits its figures in
# this order so they line up with the plot components created from PLOT_GRID.
PLOT_ORDER = [metric for row in PLOT_GRID for metric in row]
|
|
|
|
|
# Stylesheet passed to gr.Blocks(css=...). The selectors target the
# ``elem_classes`` assigned to components in build_interface() (side-panel,
# episode-list, main-panel, video-card, video-panel, download-button,
# plots-wrap, plot-html) and the class names used in the hand-written HTML
# snippets (stats-card, episodes-title, instruction-*, video-title).
CUSTOM_CSS = """
:root, .gradio-container, body {
    background-color: #050a18 !important;
    color: #f8fafc !important;
    font-family: 'Inter', 'Segoe UI', system-ui, sans-serif;
}
.side-panel {
    background: #0f172a;
    padding: 20px;
    border-radius: 18px;
    border: 1px solid #1f2b47;
    min-height: 100%;
}
.stats-card ul {
    list-style: none;
    padding: 0;
    margin: 0;
    font-size: 0.92rem;
}
.stats-card li {
    margin-bottom: 10px;
    color: #e2e8f0;
}
.stats-card span {
    display: inline-block;
    margin-right: 6px;
    color: #7dd3fc;
}
.episodes-title {
    margin: 18px 0 8px;
    font-size: 0.78rem;
    text-transform: uppercase;
    letter-spacing: 0.14em;
    color: #94a3b8;
}
.episode-list .gr-form {
    padding: 0;
}
.episode-list .gr-form > div {
    gap: 0;
}
.episode-list input[type="radio"] {
    display: none;
}
.episode-list label {
    background: transparent !important;
    border: none !important;
    color: #cbd5f5 !important;
    padding: 3px 0 !important;
    justify-content: flex-start;
    font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace;
    font-size: 0.9rem;
    text-decoration: underline;
}
.episode-list label:hover {
    color: #67e8f9 !important;
    cursor: pointer;
}
.episode-list input[type="radio"]:checked + label {
    color: #facc15 !important;
    font-weight: 700;
    margin-left: -2px;
}
.main-panel {
    padding-top: 8px;
}
.instruction-card {
    background: #0f172a;
    padding: 18px 20px;
    border-radius: 18px;
    border: 1px solid #1f2b47;
}
.instruction-label {
    font-size: 0.75rem;
    letter-spacing: 0.12em;
    text-transform: uppercase;
    color: #94a3b8;
    margin-bottom: 10px;
}
.instruction-text {
    font-size: 1.1rem;
    line-height: 1.5;
}
.video-card {
    background: #0f172a;
    border: 1px solid #1f2b47;
    border-radius: 18px;
    padding: 18px 20px;
    margin-top: 18px;
}
.video-title {
    font-size: 0.78rem;
    text-transform: uppercase;
    letter-spacing: 0.18em;
    color: #94a3b8;
    margin-bottom: 8px;
}
.video-panel video {
    border-radius: 12px;
    border: 1px solid #1f2b47;
    background: #030712;
}
.download-button button {
    border-radius: 999px;
    border: 1px solid #334155;
    background: #1e293b;
    color: #f8fafc;
    font-size: 0.85rem;
    padding: 8px 24px;
}
.download-button button:hover {
    border-color: #67e8f9;
    color: #67e8f9;
}
.plots-wrap {
    margin-top: 18px;
}
.plots-wrap .gr-row {
    gap: 16px;
}
.plot-html {
    background: #111a2c;
    border-radius: 12px;
    padding: 10px;
    border: 1px solid #1f2b47;
    min-height: 320px;
}
.plot-html iframe {
    width: 100%;
    height: 300px;
    border: none;
}
"""
|
|
|
|
|
|
|
|
@lru_cache(maxsize=1)
def get_dataset_revision(repo_id: str) -> Optional[str]:
    """Look up the commit SHA currently reported for a Hub dataset repository.

    Args:
        repo_id: Dataset repository identifier on the Hugging Face Hub.

    Returns:
        The ``sha`` attribute of the repository info, or ``None`` when the
        lookup raised any exception (the failure is logged as a warning).
        Either outcome is memoized by ``lru_cache``.
    """
    try:
        client = HfApi(token=HF_TOKEN)
        details = client.repo_info(repo_id=repo_id, repo_type="dataset")
        sha = details.sha
    except Exception as exc:
        # Best effort: callers fall back to an unpinned revision on failure.
        logger.warning(f"Could not fetch dataset revision for {repo_id}: {exc}")
        sha = None
    return sha
|
|
|
|
|
|
|
|
@lru_cache(maxsize=2)
def get_dataset_root(repo_id: str, revision: Optional[str]) -> Path:
    """Download a snapshot of the dataset and return its local root directory.

    Args:
        repo_id: Dataset repository identifier on the Hugging Face Hub.
        revision: Revision to download, or ``None`` to let
            ``snapshot_download`` pick its default.

    Returns:
        The path returned by ``snapshot_download``, wrapped in ``Path``. The
        result is memoized per ``(repo_id, revision)`` pair.
    """
    download_options = dict(
        repo_id=repo_id,
        repo_type="dataset",
        local_dir=LOCAL_DATASET_DIR,
        local_dir_use_symlinks=False,
        revision=revision,
        token=HF_TOKEN,
    )
    return Path(snapshot_download(**download_options))
|
|
|
|
|
|
|
|
@lru_cache(maxsize=2)
def load_info(repo_id: str, revision: Optional[str]) -> Dict:
    """Read and parse ``meta/info.json`` from the downloaded dataset snapshot.

    Args:
        repo_id: Dataset repository identifier on the Hugging Face Hub.
        revision: Revision forwarded to ``get_dataset_root``.

    Returns:
        The decoded JSON document. The same object is returned on repeated
        calls because of ``lru_cache``.
    """
    info_file = get_dataset_root(repo_id, revision) / "meta" / "info.json"
    with open(info_file, "r", encoding="utf-8") as handle:
        metadata = json.load(handle)
    return metadata
|
|
|
|
|
|
|
|
def resolve_path(root: Path, template: str, episode_chunk: int, episode_index: int) -> Path:
    """Expand a metadata path template into a concrete path under ``root``.

    Args:
        root: Directory the resulting relative path is joined onto.
        template: Format string with ``{episode_chunk}`` / ``{episode_index}``
            placeholders. A dict is also accepted, in which case its ``"rgb"``
            entry is used as the format string.
        episode_chunk: Value substituted for ``episode_chunk``.
        episode_index: Value substituted for ``episode_index``.

    Returns:
        ``root`` joined with the formatted template.

    Raises:
        ValueError: If ``template`` is a dict without an ``"rgb"`` entry.
    """
    chosen = template
    if isinstance(template, dict):
        chosen = template.get("rgb")
        if chosen is None:
            raise ValueError("RGB template missing from metadata")
    relative = chosen.format(episode_chunk=episode_chunk, episode_index=episode_index)
    return root / relative
|
|
|
|
|
|
|
|
@lru_cache(maxsize=64)
def load_episode(repo_id: str, episode_index: int, revision: Optional[str]) -> Dict:
    """Load the trajectory table, video path and instruction for one episode.

    Args:
        repo_id: Dataset repository identifier on the Hugging Face Hub.
        episode_index: Value matched against ``episode_index`` in the
            ``episodes`` list of the dataset metadata.
        revision: Revision forwarded to the snapshot/metadata loaders.

    Returns:
        A dict with keys ``timestamps``, ``state_df``, ``rgb_path`` and
        ``instruction``. Results are memoized, so callers share the same
        objects across calls.

    Raises:
        ValueError: If no metadata entry matches ``episode_index``.
        FileNotFoundError: If the resolved parquet file does not exist.
    """
    info = load_info(repo_id, revision)
    root = get_dataset_root(repo_id, revision)

    # First metadata entry whose index matches wins.
    episode_meta = None
    for candidate in info["episodes"]:
        if candidate["episode_index"] == episode_index:
            episode_meta = candidate
            break
    if not episode_meta:
        raise ValueError(f"Episode {episode_index} not found in metadata")

    chunk = episode_meta["episode_chunk"]
    parquet_path = resolve_path(root, info["data_path"], chunk, episode_index)
    if not parquet_path.exists():
        raise FileNotFoundError(f"Parquet file not found: {parquet_path}")

    df = pd.read_parquet(parquet_path)
    timestamps, state_df = build_state_dataframe(df)

    rgb_path = resolve_path(root, info["video_path"], chunk, episode_index)

    # Instruction precedence: episode metadata, then the first non-null value
    # in the parquet column, then the dataset-level task (with a fixed default).
    instruction = episode_meta.get("language_instruction")
    if not instruction:
        has_column = "language_instruction" in df.columns
        if has_column and not df["language_instruction"].isna().all():
            instruction = df["language_instruction"].dropna().iloc[0]
        else:
            instruction = info.get("task", "Tape roll to bowl")

    return {
        "timestamps": timestamps,
        "state_df": state_df,
        "rgb_path": rgb_path,
        "instruction": instruction,
    }
|
|
|
|
|
|
|
|
def build_state_dataframe(df: pd.DataFrame) -> Tuple[List[float], pd.DataFrame]:
    """Pivot long-format joint rows into one wide row per frame.

    The input is expected to hold one row per (frame, joint) with the columns
    ``frame_idx``, ``timestamp_s`` and ``joint_name`` plus any of the metric
    columns listed in ``METRIC_LABELS``.

    Args:
        df: Episode table as read from the parquet file.

    Returns:
        A ``(timestamps, state_df)`` tuple. ``timestamps`` holds one
        ``timestamp_s`` value per distinct ``frame_idx`` in ascending frame
        order. ``state_df`` has one row per frame in the same order (index
        reset to 0..n-1) and one float column ``"{alias}_{metric}"`` for each
        joint alias in ``JOINT_NAME_MAP`` and each metric present in ``df``;
        frames with no row for a joint hold NaN.

    Raises:
        ValueError: If ``frame_idx``, ``timestamp_s`` or ``joint_name`` is
            missing from ``df``.
    """
    if "frame_idx" not in df.columns or "timestamp_s" not in df.columns:
        raise ValueError("Episode parquet missing frame timing information.")
    if "joint_name" not in df.columns:
        # Fail with the same exception type as the timing check rather than a
        # bare KeyError from the column lookup below.
        raise ValueError("Episode parquet missing joint_name column.")

    frame_times = (
        df[["frame_idx", "timestamp_s"]]
        .drop_duplicates("frame_idx")
        .set_index("frame_idx")
        .sort_index()
    )
    frame_indices = frame_times.index.to_list()

    state_df = pd.DataFrame(index=frame_indices)
    for alias, joint_name in JOINT_NAME_MAP.items():
        joint_df = (
            df[df["joint_name"] == joint_name]
            # Keep the first row per frame, mirroring how frame_times is
            # de-duplicated; reindex() raises on duplicate index labels.
            .drop_duplicates("frame_idx")
            .set_index("frame_idx")
            .sort_index()
            .reindex(frame_indices)
        )
        for metric in METRIC_LABELS:
            if metric in joint_df.columns:
                state_df[f"{alias}_{metric}"] = joint_df[metric].astype(float)

    state_df.reset_index(drop=True, inplace=True)
    timestamps = frame_times["timestamp_s"].to_list()
    return timestamps, state_df
|
|
|
|
|
|
|
|
def build_plot_fig(data: Dict, metric: str) -> go.Figure:
    """Build a line chart of one metric over time with one trace per joint.

    Args:
        data: Episode payload providing ``timestamps`` and ``state_df``.
        metric: Key of ``METRIC_LABELS``; selects the ``"{alias}_{metric}"``
            columns to plot and the y-axis title.

    Returns:
        A Plotly figure. Joints whose column is absent from ``state_df`` are
        skipped.
    """
    times = data["timestamps"]
    frame = data["state_df"]

    figure = go.Figure()
    for alias, label in JOINT_ALIASES.items():
        column = f"{alias}_{metric}"
        if column in frame.columns:
            trace = go.Scatter(x=times, y=frame[column], mode="lines", name=label)
            figure.add_trace(trace)

    figure.update_layout(
        margin=dict(l=20, r=20, t=30, b=20),
        height=250,
        template="plotly_dark",
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
        xaxis_title="Time (s)",
        yaxis_title=METRIC_LABELS[metric],
    )

    # Both axes share the same faint grid styling.
    grid_style = dict(showgrid=True, gridwidth=0.5, gridcolor="rgba(255,255,255,0.1)")
    figure.update_xaxes(**grid_style)
    figure.update_yaxes(**grid_style)
    return figure
|
|
|
|
|
|
|
|
def build_plot_html(data: Dict, metric: str) -> str:
    """Render the plot for ``metric`` as an HTML fragment.

    The fragment is not a full HTML document and references plotly.js via the
    ``"cdn"`` option of ``plotly.io.to_html``.
    """
    return pio.to_html(
        build_plot_fig(data, metric),
        include_plotlyjs="cdn",
        full_html=False,
    )
|
|
|
|
|
|
|
|
def format_episode_label(idx: int) -> str:
    """Return the radio-button label for an episode, zero-padded to two digits."""
    return "Episode {:02d}".format(idx)
|
|
|
|
|
|
|
|
def parse_episode_label(label: str) -> int:
    """Recover the episode index from a label made by ``format_episode_label``.

    Every occurrence of the word ``"Episode"`` is removed and the remainder is
    parsed as an integer; ``int`` raises ``ValueError`` if that fails.
    """
    number_text = label.replace("Episode", "")
    return int(number_text.strip())
|
|
|
|
|
|
|
|
def format_instruction_html(text: str) -> str:
    """Wrap an instruction string in the "instruction-card" HTML markup.

    The text is HTML-escaped before being embedded, so markup characters in
    the instruction are displayed literally.
    """
    escaped = html.escape(text)
    pieces = [
        '<div class="instruction-card">',
        '<p class="instruction-label">Language Instruction</p>',
        f'<p class="instruction-text">{escaped}</p>',
        "</div>",
    ]
    return "".join(pieces)
|
|
|
|
|
|
|
|
def build_interface():
    """Assemble the Gradio Blocks app for browsing dataset episodes.

    The dataset metadata and the first episode (lowest ``episode_index``) are
    loaded eagerly so the page renders with content. Selecting another episode
    in the radio list refreshes the instruction card, the video, the download
    button and all plots.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).

    Raises:
        RuntimeError: If the dataset metadata lists no episodes.
    """
    revision = get_dataset_revision(DEFAULT_DATASET_ID)
    info = load_info(DEFAULT_DATASET_ID, revision)
    episode_indices = sorted(ep["episode_index"] for ep in info["episodes"])
    if not episode_indices:
        raise RuntimeError("No episodes found in dataset metadata.")

    default_idx = episode_indices[0]
    default_label = format_episode_label(default_idx)
    default_data = load_episode(DEFAULT_DATASET_ID, default_idx, revision)
    default_video = str(default_data["rgb_path"])
    default_instruction = default_data["instruction"]
    default_figs = {metric: build_plot_html(default_data, metric) for metric in METRIC_LABELS}

    total_frames = sum(ep.get("num_frames", 0) for ep in info["episodes"])
    fps = info.get("fps", 30.0)
    stats_html = f"""
    <div class="stats-card">
        <ul>
            <li><span>Number of samples/frames:</span> {total_frames:,}</li>
            <li><span>Number of episodes:</span> {len(episode_indices)}</li>
            <li><span>Frames per second:</span> {fps:.1f}</li>
        </ul>
    </div>
    """

    theme = gr.themes.Soft(
        primary_hue="cyan", secondary_hue="blue", neutral_hue="slate"
    ).set(
        body_background_fill="#0c1424",
        body_text_color="#f8fafc",
        block_background_fill="#111a2c",
        block_title_text_color="#f8fafc",
        input_background_fill="#151f33",
        border_color_primary="#1f2b47",
        shadow_drop="none",
    )

    # NOTE(review): the nesting of the layout below was reconstructed from a
    # source whose indentation was lost; confirm that the download button sits
    # inside the video card and the plots inside the main panel.
    with gr.Blocks(theme=theme, css=CUSTOM_CSS) as demo:
        gr.Markdown("# Humanoid Robots Hand Pose Viewer")
        gr.Markdown(
            "Visualize RGB + 6DoF hand trajectories for all Moving_Mini tasks "
            "(humanoid-robots-training-dataset)."
        )

        with gr.Row(equal_height=True):
            with gr.Column(scale=1, min_width=260, elem_classes=["side-panel"]):
                gr.HTML(stats_html)
                gr.HTML('<div class="episodes-title">Episodes</div>')
                episode_radio = gr.Radio(
                    choices=[format_episode_label(i) for i in episode_indices],
                    value=default_label,
                    label="Episodes",
                    elem_classes=["episode-list"],
                )
            with gr.Column(scale=2, min_width=640, elem_classes=["main-panel"]):
                instruction_box = gr.HTML(
                    format_instruction_html(default_instruction),
                    label="Language Instruction",
                )
                with gr.Column(elem_classes=["video-card"]):
                    gr.HTML('<div class="video-title">RGB</div>')
                    video = gr.Video(
                        height=360,
                        value=default_video,
                        elem_classes=["video-panel"],
                        show_label=False,
                        show_download_button=False,
                    )
                    download_button = gr.DownloadButton(
                        label="Download",
                        value=default_video,
                        elem_classes=["download-button"],
                    )

                plot_outputs = []
                gr.Markdown("### Joint trajectories", elem_classes=["plots-title"])
                with gr.Column(elem_classes=["plots-wrap"]):
                    # Components are created in PLOT_GRID order, which is the
                    # same order PLOT_ORDER yields in the callback below.
                    for row in PLOT_GRID:
                        with gr.Row():
                            for metric in row:
                                plot = gr.HTML(value=default_figs[metric], elem_classes=["plot-html"])
                                plot_outputs.append(plot)

        outputs = [instruction_box, video, download_button] + plot_outputs

        def load_episode_payload(label: str):
            """Return fresh values for every output component for ``label``."""
            idx = parse_episode_label(label)
            data = load_episode(DEFAULT_DATASET_ID, idx, revision)
            video_path = str(data["rgb_path"])
            figs = [build_plot_html(data, metric) for metric in PLOT_ORDER]
            return [
                format_instruction_html(data["instruction"]),
                video_path,
                # Per-component ``.update()`` classmethods are gone in
                # Gradio 4+; ``gr.update`` is the supported way to patch the
                # button's value.
                gr.update(value=video_path),
                *figs,
            ]

        episode_radio.change(fn=load_episode_payload, inputs=episode_radio, outputs=outputs)

    return demo
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Build the viewer, enable request queueing and launch it with ``show_api=False``."""
    app = build_interface()
    queued_app = app.queue()
    queued_app.launch(show_api=False)


if __name__ == "__main__":
    main()
|
|
|