# Space file header (scraped web chrome, not code) — uploaded by hynky,
# commit 0dba2e6 (verified): "Publish LeRobot eval viewer Space".
#!/usr/bin/env python
"""Browse LeRobot eval artifacts stored in a Hugging Face Bucket."""
from __future__ import annotations
import argparse
import json
import os
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import gradio as gr
from gradio_rerun import Rerun
from huggingface_hub import download_bucket_files, list_bucket_tree
# Bucket browsed when neither --bucket nor LEROBOT_EVAL_BUCKET is provided.
DEFAULT_BUCKET = "macrodata/lerobot-evals"
# Local download cache root; overridable via LEROBOT_EVAL_VIEWER_CACHE.
DEFAULT_CACHE_DIR = Path(os.environ.get("LEROBOT_EVAL_VIEWER_CACHE", "~/.cache/lerobot/eval_viewer")).expanduser()
# Expected bucket layout: runs/<run_id>/<run_time>/manifest.json marks a run...
RUN_MANIFEST_RE = re.compile(r"^runs/(?P<run_id>[^/]+)/(?P<run_time>[^/]+)/manifest\.json$")
# ...runs/<run_id>/<run_time>/evals/<eval_type>/eval_info.json marks an eval within a run...
EVAL_INFO_RE = re.compile(
    r"^runs/(?P<run_id>[^/]+)/(?P<run_time>[^/]+)/evals/(?P<eval_type>[^/]+)/eval_info\.json$"
)
# ...and evals/<eval_type>/episodes/<episode_id>/metadata.json marks an episode.
EPISODE_METADATA_RE = re.compile(
    r"^runs/(?P<run_id>[^/]+)/(?P<run_time>[^/]+)/evals/(?P<eval_type>[^/]+)/episodes/"
    r"(?P<episode_id>[^/]+)/metadata\.json$"
)
@dataclass(frozen=True)
class EvalIndex:
    """Immutable index of the eval artifacts discovered under ``runs/`` in a bucket."""

    # Every file path listed under the "runs/" prefix.
    files: set[str]
    # "run_id/run_time" keys, sorted newest-first (lexicographic, reversed).
    runs: list[str]
    # run key -> sorted "run_id/run_time/eval_type" keys.
    evals_by_run: dict[str, list[str]]
    # eval key -> sorted episode ids.
    episodes_by_eval: dict[str, list[str]]
def _file_path(item: Any) -> str | None:
    """Return a tree entry's path as a string, or None for non-file entries."""
    entry_type = getattr(item, "type", None)
    if entry_type is not None and entry_type != "file":
        # Directories (or any other typed entry) are not downloadable artifacts.
        return None
    raw_path = getattr(item, "path", None)
    if not raw_path:
        return None
    return str(raw_path)
def _run_key(run_id: str, run_time: str) -> str:
    """Build the dropdown key identifying a run: "<run_id>/<run_time>"."""
    return "/".join((run_id, run_time))
def _eval_key(run_id: str, run_time: str, eval_type: str) -> str:
    """Build the dropdown key identifying an eval within a run."""
    return "/".join((run_id, run_time, eval_type))
def _split_eval_key(eval_key: str) -> tuple[str, str, str]:
    """Inverse of `_eval_key`: split "run_id/run_time/eval_type" into its parts.

    Any extra "/" characters stay inside the eval_type component.
    """
    run_id, remainder = eval_key.split("/", 1)
    run_time, eval_type = remainder.split("/", 1)
    return run_id, run_time, eval_type
def _base_path(run_id: str, run_time: str) -> str:
    """Remote prefix of one run's artifacts inside the bucket."""
    return "/".join(("runs", run_id, run_time))
def _eval_path(run_id: str, run_time: str, eval_type: str) -> str:
    """Remote prefix of one eval's artifacts (run base path inlined)."""
    return f"runs/{run_id}/{run_time}/evals/{eval_type}"
def _local_path(cache_dir: Path, bucket_id: str, remote_path: str) -> Path:
    """Map a bucket-relative file path to its on-disk cache location."""
    if "/" in bucket_id:
        namespace, _, bucket_name = bucket_id.partition("/")
    else:
        # Bare bucket names are filed under the "me" namespace.
        namespace, bucket_name = "me", bucket_id
    return cache_dir / namespace / bucket_name / remote_path
def _download(bucket_id: str, remote_path: str, cache_dir: Path) -> Path | None:
    """Fetch one bucket file into the local cache.

    Returns the cached local path, or None when the remote file does not exist.
    """
    target = _local_path(cache_dir, bucket_id, remote_path)
    if target.exists():
        # Cache hit: never re-download an artifact we already have.
        return target
    target.parent.mkdir(parents=True, exist_ok=True)
    download_bucket_files(
        bucket_id,
        files=[(remote_path, target)],
        raise_on_missing_files=False,
    )
    if target.exists():
        return target
    return None
def _read_text(bucket_id: str, remote_path: str, cache_dir: Path, max_chars: int | None = None) -> str:
    """Download and read a text file from the bucket; "" when it is missing.

    When `max_chars` is set, only the LAST `max_chars` characters are kept —
    the tail is the interesting part of a log file.
    """
    local = _download(bucket_id, remote_path, cache_dir)
    if local is None:
        return ""
    content = local.read_text(encoding="utf-8", errors="replace")
    if max_chars is None or len(content) <= max_chars:
        return content
    return content[-max_chars:]
def _read_json(bucket_id: str, remote_path: str, cache_dir: Path) -> dict[str, Any]:
    """Download and parse a JSON file from the bucket.

    Returns {} when the file is missing, empty, corrupt, or does not contain a
    JSON object, so UI callbacks never crash on a bad or partially-uploaded
    remote artifact (previously a malformed file raised JSONDecodeError, and a
    non-object payload violated the declared dict return type).
    """
    text = _read_text(bucket_id, remote_path, cache_dir)
    if not text:
        return {}
    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        # Truncated/corrupt artifact: show nothing rather than error out.
        return {}
    return data if isinstance(data, dict) else {}
def _build_index(bucket_id: str) -> EvalIndex:
    """Scan the bucket's "runs/" tree and build the navigation index."""
    file_paths: list[str] = []
    for entry in list_bucket_tree(bucket_id, prefix="runs", recursive=True):
        entry_path = _file_path(entry)
        if entry_path:
            file_paths.append(entry_path)

    run_keys: set[str] = set()
    run_evals: dict[str, set[str]] = {}
    eval_episodes: dict[str, set[str]] = {}

    def record_eval(match: re.Match[str]) -> str:
        # Register the run and eval implied by an eval-level file; return the eval key.
        run_key = _run_key(match["run_id"], match["run_time"])
        eval_key = _eval_key(match["run_id"], match["run_time"], match["eval_type"])
        run_keys.add(run_key)
        run_evals.setdefault(run_key, set()).add(eval_key)
        return eval_key

    for file_path in file_paths:
        manifest = RUN_MANIFEST_RE.match(file_path)
        if manifest is not None:
            run_key = _run_key(manifest["run_id"], manifest["run_time"])
            run_keys.add(run_key)
            run_evals.setdefault(run_key, set())
            continue
        info = EVAL_INFO_RE.match(file_path)
        if info is not None:
            eval_episodes.setdefault(record_eval(info), set())
            continue
        episode = EPISODE_METADATA_RE.match(file_path)
        if episode is not None:
            eval_episodes.setdefault(record_eval(episode), set()).add(episode["episode_id"])

    return EvalIndex(
        files=set(file_paths),
        runs=sorted(run_keys, reverse=True),
        evals_by_run={key: sorted(evals) for key, evals in run_evals.items()},
        episodes_by_eval={key: sorted(eps) for key, eps in eval_episodes.items()},
    )
def _summarize_eval(info: dict[str, Any]) -> dict[str, Any]:
    """Extract the headline metrics from an eval_info payload.

    Looks under "overall" first, then "aggregated"; returns {} when neither
    holds a dict. Only whitelisted metric keys that are present are kept.
    """
    summary_keys = ("pc_success", "avg_sum_reward", "avg_max_reward", "n_episodes", "eval_s", "eval_ep_s")
    overall = info.get("overall") or info.get("aggregated") or {}
    if not isinstance(overall, dict):
        return {}
    summary: dict[str, Any] = {}
    for key in summary_keys:
        if key in overall:
            summary[key] = overall[key]
    return summary
def _trace_table(bucket_id: str, remote_path: str, cache_dir: Path, limit: int = 2000) -> tuple[list[str], list[list[Any]]]:
    """Parse an episode trace.jsonl into (headers, rows) for a gr.Dataframe.

    Reads at most `limit` lines. Blank, malformed, and non-object lines are
    skipped so a truncated or in-progress trace still renders (previously one
    bad line raised JSONDecodeError and blanked the whole episode view).
    Column set = fixed preferred scalars + any vector keys present in the
    first parsed row; missing keys render as None cells.
    """
    text = _read_text(bucket_id, remote_path, cache_dir)
    if not text:
        return [], []
    rows: list[dict[str, Any]] = []
    for line in text.splitlines()[:limit]:
        if not line.strip():
            continue
        try:
            parsed = json.loads(line)
        except json.JSONDecodeError:
            # Tolerate a truncated final line from an in-progress upload.
            continue
        if isinstance(parsed, dict):
            rows.append(parsed)
    if not rows:
        return [], []
    preferred = ["frame_index", "timestamp", "reward", "next.success", "done"]
    vector_keys = [key for key in ("action", "observation.state") if key in rows[0]]
    headers = preferred + vector_keys
    table = [[_table_cell(row.get(key)) for key in headers] for row in rows]
    return headers, table
def _table_cell(value: Any) -> Any:
    """Render one trace value for a dataframe cell.

    Containers become JSON strings so every cell is displayable; scalars pass
    through unchanged.
    """
    return json.dumps(value) if isinstance(value, (dict, list)) else value
def _choices(values: list[str], value: str | None = None) -> gr.Dropdown:
    """Dropdown update: replace choices, keep `value` if still valid, else pick the first."""
    selected = value
    if selected not in values:
        selected = values[0] if values else None
    return gr.update(choices=values, value=selected)
def _trace_update(headers: list[str] | None = None, rows: list[list[Any]] | None = None) -> gr.Dataframe:
    """Dataframe update that resizes the column count to match `headers`."""
    header_list = headers if headers else []
    row_list = rows if rows else []
    return gr.update(headers=header_list, value=row_list, col_count=(len(header_list), "dynamic"))
def build_app(default_bucket: str, cache_dir: Path) -> gr.Blocks:
    """Construct the Gradio Blocks UI for browsing LeRobot eval artifacts.

    Layout: bucket textbox + refresh button, cascading run -> eval -> episode
    dropdowns, JSON panels (metrics, manifest, episode metadata), and tabs for
    the trace table, the Rerun `.rrd` viewer, the eval command, and the logs.

    Args:
        default_bucket: Initial value for the bucket textbox.
        cache_dir: Local directory used to cache downloaded bucket files.

    Returns:
        The assembled (not yet launched) `gr.Blocks` application.
    """
    # Monospace metric panels; give the Rerun viewer a minimum height.
    css = """
    .metric-panel textarea {font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace;}
    .rerun-panel {min-height: 720px;}
    """

    def refresh(bucket_id: str):
        # Re-index the bucket and reset every downstream widget/panel.
        index = _build_index(bucket_id)
        empty_trace = _trace_update()
        return index, _choices(index.runs), _choices([]), _choices([]), {}, "", "", None, empty_trace

    def select_run(bucket_id: str, run_key: str | None, index: EvalIndex):
        # Populate the eval dropdown, manifest panel, and run URI for one run.
        if not run_key:
            return _choices([]), {}, ""
        run_id, run_time = run_key.split("/", 1)
        manifest_path = f"{_base_path(run_id, run_time)}/manifest.json"
        manifest = _read_json(bucket_id, manifest_path, cache_dir)
        evals = index.evals_by_run.get(run_key, [])
        return _choices(evals), manifest, f"hf://buckets/{bucket_id}/{_base_path(run_id, run_time)}"

    def select_eval(bucket_id: str, eval_key: str | None, index: EvalIndex):
        # Load the metrics summary, episode list, command, and tail of the logs.
        if not eval_key:
            return {}, _choices([]), "", ""
        run_id, run_time, eval_type = _split_eval_key(eval_key)
        eval_base = _eval_path(run_id, run_time, eval_type)
        info = _read_json(bucket_id, f"{eval_base}/eval_info.json", cache_dir)
        command = _read_text(bucket_id, f"{eval_base}/command.txt", cache_dir)
        logs = _read_text(bucket_id, f"{eval_base}/logs.txt", cache_dir, max_chars=60_000)
        episodes = index.episodes_by_eval.get(eval_key, [])
        return _summarize_eval(info), _choices(episodes), command, logs

    def select_episode(bucket_id: str, eval_key: str | None, episode_id: str | None, index: EvalIndex):
        # Load per-episode metadata, the trace table, and the .rrd recording.
        if not eval_key or not episode_id:
            return {}, _trace_update(), None
        run_id, run_time, eval_type = _split_eval_key(eval_key)
        eval_base = _eval_path(run_id, run_time, eval_type)
        episode_base = f"{eval_base}/episodes/{episode_id}"
        metadata = _read_json(bucket_id, f"{episode_base}/metadata.json", cache_dir)
        headers, rows = _trace_table(bucket_id, f"{episode_base}/trace.jsonl", cache_dir)
        rrd_path = _download(bucket_id, f"{episode_base}/episode.rrd", cache_dir)
        trace_update = _trace_update(headers, rows)
        # The Rerun component takes a local file path (served via launch allowed_paths).
        return metadata, trace_update, str(rrd_path) if rrd_path else None

    with gr.Blocks(title="LeRobot Eval Viewer", css=css) as app:
        # Holds the last-built EvalIndex so callbacks don't re-list the bucket.
        index_state = gr.State(EvalIndex(files=set(), runs=[], evals_by_run={}, episodes_by_eval={}))
        gr.Markdown("# LeRobot Eval Viewer")
        gr.Markdown("Browse Hugging Face Bucket eval artifacts, inspect traces, and open episode `.rrd` files in Rerun.")
        with gr.Row():
            bucket = gr.Textbox(value=default_bucket, label="HF Bucket", scale=2)
            refresh_button = gr.Button("Refresh", variant="primary", scale=0)
        with gr.Row():
            run_dropdown = gr.Dropdown(label="Run", choices=[], interactive=True)
            eval_dropdown = gr.Dropdown(label="Eval", choices=[], interactive=True)
            episode_dropdown = gr.Dropdown(label="Episode", choices=[], interactive=True)
        run_uri = gr.Textbox(label="Run URI", interactive=False)
        with gr.Row():
            metrics = gr.JSON(label="Metrics", elem_classes=["metric-panel"])
            manifest = gr.JSON(label="Manifest", elem_classes=["metric-panel"])
            episode_metadata = gr.JSON(label="Episode Metadata", elem_classes=["metric-panel"])
        with gr.Tab("Trace"):
            trace = gr.Dataframe(
                label="Trace",
                headers=[],
                value=[],
                col_count=(0, "dynamic"),
                wrap=True,
                interactive=False,
            )
        with gr.Tab("Rerun"):
            rerun = Rerun(
                label="Rerun Episode",
                streaming=True,
                elem_classes=["rerun-panel"],
                panel_states={
                    "blueprint": "collapsed",
                    "selection": "collapsed",
                    "time": "expanded",
                },
            )
        with gr.Tab("Command"):
            command = gr.Code(label="command.txt", language="shell")
        with gr.Tab("Logs"):
            logs = gr.Code(label="logs.txt", language="shell", lines=24)
        # The same refresh handler fires on button click, on Enter in the bucket
        # textbox, and once at page load; all three use the same output list.
        refresh_button.click(
            refresh,
            inputs=[bucket],
            outputs=[
                index_state,
                run_dropdown,
                eval_dropdown,
                episode_dropdown,
                metrics,
                command,
                logs,
                rerun,
                trace,
            ],
        )
        bucket.submit(
            refresh,
            inputs=[bucket],
            outputs=[
                index_state,
                run_dropdown,
                eval_dropdown,
                episode_dropdown,
                metrics,
                command,
                logs,
                rerun,
                trace,
            ],
        )
        run_dropdown.change(
            select_run,
            inputs=[bucket, run_dropdown, index_state],
            outputs=[eval_dropdown, manifest, run_uri],
        )
        eval_dropdown.change(
            select_eval,
            inputs=[bucket, eval_dropdown, index_state],
            outputs=[metrics, episode_dropdown, command, logs],
        )
        episode_dropdown.change(
            select_episode,
            inputs=[bucket, eval_dropdown, episode_dropdown, index_state],
            outputs=[episode_metadata, trace, rerun],
        )
        app.load(
            refresh,
            inputs=[bucket],
            outputs=[
                index_state,
                run_dropdown,
                eval_dropdown,
                episode_dropdown,
                metrics,
                command,
                logs,
                rerun,
                trace,
            ],
        )
    return app
def parse_args() -> argparse.Namespace:
    """Parse CLI arguments; environment variables supply Space-friendly defaults."""
    parser = argparse.ArgumentParser(description="Launch a Gradio viewer for LeRobot eval bucket artifacts.")
    parser.add_argument("--bucket", default=os.environ.get("LEROBOT_EVAL_BUCKET", DEFAULT_BUCKET))
    parser.add_argument("--cache-dir", type=Path, default=DEFAULT_CACHE_DIR)
    # Inside a HF Space (SPACE_ID set) bind to all interfaces; localhost otherwise.
    fallback_host = "0.0.0.0" if os.environ.get("SPACE_ID") else "127.0.0.1"
    parser.add_argument("--host", default=os.environ.get("GRADIO_SERVER_NAME", fallback_host))
    parser.add_argument("--port", type=int, default=int(os.environ.get("GRADIO_SERVER_PORT", "7860")))
    parser.add_argument("--share", action="store_true", help="Create a public Gradio share URL.")
    return parser.parse_args()
def main() -> None:
    """Entry point: prepare the cache directory, build the viewer, and serve it."""
    args = parse_args()
    cache_dir: Path = args.cache_dir
    cache_dir.mkdir(parents=True, exist_ok=True)
    viewer = build_app(default_bucket=args.bucket, cache_dir=cache_dir)
    # allowed_paths lets Gradio serve cached .rrd files to the Rerun component.
    viewer.launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share,
        ssr_mode=False,
        allowed_paths=[str(cache_dir.resolve())],
    )


if __name__ == "__main__":
    main()