robogen / app.py
HaptalAI's picture
Upload folder using huggingface_hub
7c19df4 verified
Raw
History Blame Contribute Delete
17.2 kB
"""
RoboGen — HaptalAI Synthetic Robotics Dataset Generator
Gradio 5.9.1 / Python 3.11
Step flow:
1 Robot selection (card-style radio)
2 Task dropdown
3 Parameter sliders + failure checkboxes
4 Generate button
5 Quality results dashboard
6 Email gate + zip download
"""
from __future__ import annotations
import os
import sys
import io
import zipfile
import tempfile
import traceback
from typing import Optional, Dict, List
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import gradio as gr
import pandas as pd
from generator import (
generate_dataset,
score_dataset,
annotate_quality_scores,
TASKS_BY_ROBOT,
ROBOT_CONFIG,
FAILURE_TYPES,
)
from readme_gen import generate_readme
from airtable import log_email
# ── CSS ───────────────────────────────────────────────────────────────────────
# Read the stylesheet that sits next to this module. The encoding is pinned to
# UTF-8 so loading does not depend on the host's default locale encoding.
_here = os.path.dirname(os.path.abspath(__file__))
with open(os.path.join(_here, "style.css"), encoding="utf-8") as _f:
    CSS = _f.read()
# ── Constants ─────────────────────────────────────────────────────────────────
# Display names for task identifiers (used in the task dropdown and dashboard).
TASK_LABELS = dict(
    pick_and_place="Pick and Place",
    push_object="Push Object",
    grasp_and_lift="Grasp and Lift",
    stacking="Stacking",
    drawer_open_close="Drawer Open / Close",
)

# Display names for failure-type identifiers (used in the breakdown bars).
FAILURE_LABELS = dict(
    grasp_slip="Grasp Slip",
    velocity_spike="Velocity Spike",
    torque_saturation="Torque Saturation",
)

# Per-robot starting values: episode count, success rate (%), force bounds.
# Each robot gets its own dict so entries are never aliased to one another.
DEFAULTS = {
    robot_name: dict(zip(("n_eps", "success", "fmin", "fmax"), values))
    for robot_name, values in (
        ("SO-100", (50, 70, 1.0, 10.0)),
        ("SO-101", (50, 70, 1.0, 10.0)),
        ("Koch", (30, 75, 0.5, 8.0)),
    )
}
# ── HTML helpers ──────────────────────────────────────────────────────────────
def _results_html(result: Dict, robot: str, task: str) -> str:
    """Render the quality dashboard (Step 5) as an HTML fragment.

    Args:
        result: Scoring summary. The keys read here are ``overall_score``,
            ``band``, ``n_passed``, ``n_flagged``, ``n_episodes``,
            ``mean_mismatch``, ``failure_breakdown`` (mapping of failure id
            to count) and ``scorer_used``.
        robot: Robot name shown in the stats grid.
        task: Task identifier; shown via ``TASK_LABELS`` when a label exists.

    Returns:
        The HTML markup for the results panel.
    """
    overall = result["overall_score"]
    band = result["band"]
    passed = result["n_passed"]
    flagged = result["n_flagged"]
    episodes = result["n_episodes"]
    mean_mismatch = result["mean_mismatch"]
    breakdown = result["failure_breakdown"]
    scorer_name = result["scorer_used"]

    # The lower-cased band name doubles as the CSS modifier class.
    band_css = band.lower()
    if band == "Clean":
        band_text = "Trajectories are smooth and anomaly-free. Ready for policy training."
    elif band == "Review":
        band_text = "Some anomalies detected. Review flagged episodes before training."
    elif band == "Flagged":
        band_text = "High anomaly rate. Best used for failure analysis and augmentation."
    else:
        band_text = ""

    # Guard the percentage maths against an empty / all-zero breakdown.
    denominator = sum(breakdown.values()) or 1
    bar_rows = []
    # Most frequent failure type first.
    for failure_id, count in sorted(breakdown.items(), key=lambda item: -item[1]):
        label = FAILURE_LABELS.get(failure_id, failure_id)
        width = count / denominator * 100
        bar_rows.append(
            '<div class="rg-failure-bar">'
            + f'<span class="rg-failure-label">{label}</span>'
            + f'<div class="rg-bar-track"><div class="rg-bar-fill" style="width:{width:.0f}%"></div></div>'
            + f'<span class="rg-bar-count">{count}</span></div>'
        )
    failure_body = "".join(bar_rows) or "No failure episodes in dataset."

    task_name = TASK_LABELS.get(task, task)
    return f"""
<div class="rg-results">
<div class="rg-score-row">
<div class="rg-score-circle {band_css}">
<span class="rg-score-value">{overall:.0f}</span>
<span class="rg-score-denom">/ 100</span>
</div>
<div class="rg-score-info">
<div class="rg-band-badge {band_css}">{band}</div>
<div class="rg-band-desc">{band_text}</div>
</div>
</div>
<div class="rg-stat-grid">
<div class="rg-stat"><div class="rg-stat-value">{episodes}</div><div class="rg-stat-label">Total Episodes</div></div>
<div class="rg-stat"><div class="rg-stat-value" style="color:var(--green)">{passed}</div><div class="rg-stat-label">Passed</div></div>
<div class="rg-stat"><div class="rg-stat-value" style="color:var(--red)">{flagged}</div><div class="rg-stat-label">Flagged</div></div>
<div class="rg-stat"><div class="rg-stat-value">{mean_mismatch:.3f}</div><div class="rg-stat-label">Mean Mismatch</div></div>
<div class="rg-stat"><div class="rg-stat-value">{robot}</div><div class="rg-stat-label">Robot</div></div>
<div class="rg-stat"><div class="rg-stat-value" style="font-size:0.9rem">{task_name}</div><div class="rg-stat-label">Task</div></div>
</div>
<div class="rg-failure-section">
<div class="rg-failure-title">Failure Type Breakdown</div>
{failure_body}
</div>
<div class="rg-scorer-note">
Scored by HaptalAI misalignment benchmark &middot; scorer: <code>{scorer_name}</code>
</div>
</div>"""
def _build_zip(df, result, robot, task, n_eps, success, fmin, fmax, failures) -> str:
    """Package the scored dataset and a generated README into a temporary zip.

    The archive holds ``robogen_<tag>.parquet`` (``df`` after
    ``annotate_quality_scores``) and ``README.md`` (from ``generate_readme``).

    Args:
        df: Generated dataset; must support ``to_parquet`` after annotation.
        result: Scoring summary whose fields are copied into the README.
        robot: Robot name; hyphens are dropped when building the file tag.
        task: Task identifier, also part of the file tag.
        n_eps: Number of episodes reported in the README.
        success: Success rate as a percentage; passed on as a 0-1 fraction.
        fmin: Minimum contact force reported in the README.
        fmax: Maximum contact force reported in the README.
        failures: Failure types reported in the README.

    Returns:
        Path of the zip file. The file is not deleted by this function.

    Raises:
        Exception: Whatever the annotate/README/parquet/zip steps raise. If
            writing fails after the temp file was created, the partial file
            is removed before the exception propagates.
    """
    df_out = annotate_quality_scores(df, result)
    readme = generate_readme(
        robot=robot, task=task, n_episodes=n_eps,
        success_rate=success / 100, force_min=fmin, force_max=fmax,
        failures=failures,
        score=result["overall_score"], band=result["band"],
        n_passed=result["n_passed"], n_flagged=result["n_flagged"],
        mean_mismatch=result["mean_mismatch"],
        failure_breakdown=result["failure_breakdown"],
        scorer_used=result["scorer_used"],
    )
    tag = f"{robot.replace('-','')}_{task}"
    fd, path = tempfile.mkstemp(suffix=".zip", prefix=f"robogen_{tag}_")
    # Only the reserved path is needed; ZipFile reopens the file itself.
    os.close(fd)
    try:
        with zipfile.ZipFile(path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
            buf = io.BytesIO()
            df_out.to_parquet(buf, index=False)
            zf.writestr(f"robogen_{tag}.parquet", buf.getvalue())
            zf.writestr("README.md", readme.encode("utf-8"))
    except Exception:
        # Do not leave a half-written archive behind in the temp directory.
        try:
            os.remove(path)
        except OSError:
            pass  # best-effort cleanup; the original error matters more
        raise
    return path
# ── Event handlers (module level β€” Gradio 5 requirement) ─────────────────────
def on_robot_select(robot: str):
    """React to a change of the robot radio (Step 1).

    Returns five values, in the order the event is wired: Step 2 group
    update, task dropdown update, Step 3 group update, Step 4 group update,
    and the robot name to remember in state. Steps 3 and 4 are always hidden
    here. With no robot selected, Step 2 is hidden as well and the dropdown
    is emptied.
    """
    if robot:
        task_ids = TASKS_BY_ROBOT[robot]
        # Dropdown entries are (display label, value) pairs.
        choices = []
        for task_id in task_ids:
            choices.append((TASK_LABELS.get(task_id, task_id), task_id))
        step2_update = gr.update(visible=True)
        dropdown_update = gr.update(choices=choices, value=task_ids[0])
        remembered = robot
    else:
        step2_update = gr.update(visible=False)
        dropdown_update = gr.update(choices=[], value=None)
        remembered = ""
    return (
        step2_update,
        dropdown_update,
        gr.update(visible=False),
        gr.update(visible=False),
        remembered,
    )
def on_task_select(task: str, robot: str):
    """React to a change of the task dropdown (Step 2).

    Returns six values, in the order the event is wired: Step 3 group update,
    Step 4 group update, then the values for the episode-count, success-rate,
    min-force and max-force sliders.

    Episode count and success rate come from ``DEFAULTS`` (falling back to
    the SO-100 entry). The force bounds come from the robot's
    ``force_range`` in ``ROBOT_CONFIG``. If that entry is missing or
    malformed, the ``fmin``/``fmax`` values from ``DEFAULTS`` are used, so an
    unknown robot is handled the same way for every slider.
    """
    if not task or not robot:
        return gr.update(visible=False), gr.update(visible=False), 50, 70, 1.0, 10.0
    defaults = DEFAULTS.get(robot, DEFAULTS["SO-100"])
    try:
        force_range = ROBOT_CONFIG[robot]["force_range"]
        force_min, force_max = force_range[0], force_range[1]
    except (KeyError, IndexError, TypeError):
        # No usable force_range for this robot: use the static defaults.
        force_min, force_max = defaults["fmin"], defaults["fmax"]
    return (
        gr.update(visible=True),
        gr.update(visible=True),
        defaults["n_eps"],
        defaults["success"],
        force_min,
        force_max,
    )
def on_generate(robot, task, n_eps, success_pct, fmin, fmax, failures):
    """Generate and score a dataset for the current settings (Step 4).

    Args:
        robot: Selected robot name.
        task: Selected task identifier.
        n_eps: Number of episodes (converted with ``int``).
        success_pct: Success rate in percent; passed on as a 0-1 fraction.
        fmin: Minimum contact force.
        fmax: Maximum contact force.
        failures: Enabled failure types. An empty selection enables every
            entry of ``FAILURE_TYPES``.

    Returns:
        Six values, in the order the event is wired: status markdown, Step 5
        group update, results HTML, Step 6 group update, the dataset, and the
        scoring result. On any failure the two groups are hidden and the
        dataset/result slots are ``None``.
    """

    def _abort(message):
        # Hide the results and download steps and clear the cached dataset.
        return (
            message,
            gr.update(visible=False), "",
            gr.update(visible=False),
            None, None,
        )

    if not robot or not task:
        return _abort("Please complete steps 1 and 2 first.")
    if not failures:
        failures = list(FAILURE_TYPES)
    try:
        force_min = float(fmin)
        force_max = float(fmax)
        # The two sliders have overlapping ranges, so min > max is reachable.
        if force_min > force_max:
            return _abort("Min contact force must not exceed max contact force.")
        df = generate_dataset(
            robot=robot, task=task,
            n_episodes=int(n_eps),
            success_rate=float(success_pct) / 100,
            force_min=force_min, force_max=force_max,
            enabled_failures=list(failures),
            seed=None,
        )
        result = score_dataset(df)
        panel = _results_html(result, robot, task)
        status = (
            f"Generated {len(df):,} rows across {result['n_episodes']} episodes — "
            f"score **{result['overall_score']:.1f}/100** ({result['band']})"
        )
        return (
            status,
            gr.update(visible=True), panel,
            gr.update(visible=True),
            df, result,
        )
    except Exception:
        # UI boundary: report the failure in the status area instead of
        # letting the event handler crash.
        return _abort(f"Generation failed:\n```\n{traceback.format_exc()}\n```")
def on_email_submit(email, robot, task, n_eps, success_pct, fmin, fmax, failures, df, result):
    """Validate the email, log it, and build the zip download (Step 6).

    Args:
        email: Address typed by the user; surrounding whitespace is ignored.
        robot: Selected robot name.
        task: Selected task identifier.
        n_eps: Number of episodes.
        success_pct: Success rate in percent.
        fmin: Minimum contact force.
        fmax: Maximum contact force.
        failures: Enabled failure types. An empty selection is treated as
            every entry of ``FAILURE_TYPES``, matching ``on_generate``.
        df: Dataset stored by the last successful generation, or ``None``.
        result: Scoring result stored with it, or ``None``.

    Returns:
        ``(status markdown, download-file update)``. The file component is
        only made visible when the archive was built.
    """
    # NOTE(review): robot/task/parameters are read from the current widget
    # values, which may differ from those used to produce ``df`` if the user
    # changed them after generating. Confirm whether that is acceptable.
    address = (email or "").strip()
    local_part, at_sign, domain = address.partition("@")
    looks_valid = (
        bool(at_sign and local_part and domain)
        and "@" not in domain
        and not any(ch.isspace() for ch in address)
    )
    if not looks_valid:
        return "Please enter a valid email address.", gr.update(visible=False)
    if df is None or result is None:
        return "Generate a dataset first (Step 4).", gr.update(visible=False)
    if not failures:
        # on_generate enabled every failure type in this case, so the README
        # must report the same set.
        failures = list(FAILURE_TYPES)
    # Logging is best-effort: a failure here must not block the download.
    try:
        ok, msg = log_email(
            email=address, robot=robot, task=task,
            n_episodes=int(n_eps),
            quality_score=result["overall_score"],
            band=result["band"],
        )
        if not ok:
            print(f"[RoboGen] Airtable: {msg}")
    except Exception as exc:
        print(f"[RoboGen] Airtable exception: {exc}")
    try:
        path = _build_zip(
            df=df, result=result, robot=robot, task=task,
            n_eps=int(n_eps), success=float(success_pct),
            fmin=float(fmin), fmax=float(fmax),
            failures=list(failures),
        )
        return "Email confirmed. Your download is ready below.", gr.update(visible=True, value=path)
    except Exception:
        return (
            f"Download preparation failed:\n```\n{traceback.format_exc()}\n```",
            gr.update(visible=False),
        )
# ── Build UI ──────────────────────────────────────────────────────────────────
# Six-step wizard. Steps 2-6 start hidden and are revealed by the event
# handlers. The generated dataset and its scoring result are kept in gr.State
# between the "generate" and "download" steps.
# NOTE(review): group nesting below is reconstructed from the statement order;
# confirm it matches the intended layout.
with gr.Blocks(css=CSS, title="RoboGen") as demo:
    robot_state = gr.State("")
    df_state = gr.State(None)
    result_state = gr.State(None)
    gr.HTML("""
<div class="rg-header">
<div class="rg-logo">RoboGen</div>
<div class="rg-tagline">Synthetic robotics datasets, physics-accurate &amp; quality-scored</div>
<div class="rg-badge">LeRobot-format &nbsp;&middot;&nbsp; SO-100 / SO-101 / Koch &nbsp;&middot;&nbsp; HaptalAI</div>
</div>""")

    # ── Step 1: robot selection ───────────────────────────────────────────────
    with gr.Group(elem_classes=["step-card"]):
        gr.HTML("""
<div class="step-header">
<span class="step-num">1</span>
<span class="step-title">Select Robot</span>
</div>""")
        robot_select = gr.Radio(
            choices=["SO-100", "Koch", "SO-101"],
            value=None,
            label="",
            elem_classes=["robot-radio"],
        )

    # ── Step 2: task selection ────────────────────────────────────────────────
    with gr.Group(visible=False, elem_classes=["step-card"]) as step2_grp:
        gr.HTML("""
<div class="step-header">
<span class="step-num">2</span>
<span class="step-title">Select Task</span>
</div>""")
        task_select = gr.Dropdown(choices=[], value=None, label="Task", interactive=True)

    # ── Step 3: parameters ────────────────────────────────────────────────────
    with gr.Group(visible=False, elem_classes=["step-card"]) as step3_grp:
        gr.HTML("""
<div class="step-header">
<span class="step-num">3</span>
<span class="step-title">Configure Parameters</span>
</div>""")
        with gr.Row():
            n_episodes_slider = gr.Slider(
                minimum=10, maximum=500, value=50, step=5,
                label="Number of Episodes",
                info="Total episodes in the dataset (10–500)",
            )
            success_slider = gr.Slider(
                minimum=0, maximum=100, value=70, step=5,
                label="Success Rate (%)",
                info="Fraction of episodes with successful trajectories",
            )
        with gr.Row():
            force_min_slider = gr.Slider(
                minimum=0.1, maximum=10.0, value=1.0, step=0.1,
                label="Min Contact Force (N)",
                info="Lower bound of spring-damper contact force during grasping",
            )
            force_max_slider = gr.Slider(
                minimum=1.0, maximum=20.0, value=10.0, step=0.5,
                label="Max Contact Force (N)",
                info="Upper bound of contact force — higher = firmer grip",
            )
        gr.HTML("""
<div style="margin:4px 0 8px;font-size:0.82rem;color:#8892a4;">
<b>Failure types to include</b> &nbsp;
<span style="font-style:italic;">
Grasp Slip — gripper opens mid-episode &nbsp;|&nbsp;
Velocity Spike — servo glitch (z&gt;6.5) &nbsp;|&nbsp;
Torque Saturation — joint hits angular limit
</span>
</div>""")
        failure_check = gr.CheckboxGroup(
            choices=["grasp_slip", "velocity_spike", "torque_saturation"],
            value=["grasp_slip", "velocity_spike", "torque_saturation"],
            label="",
            elem_classes=["checkbox-group"],
        )

    # ── Step 4: generate ──────────────────────────────────────────────────────
    with gr.Group(visible=False, elem_classes=["step-card"]) as step4_grp:
        gr.HTML("""
<div class="step-header">
<span class="step-num">4</span>
<span class="step-title">Generate Dataset</span>
</div>""")
        generate_btn = gr.Button("Generate Dataset", elem_classes=["btn-generate"], size="lg")
        gen_status = gr.Markdown("", elem_classes=["status-msg"])

    # ── Step 5: quality results ───────────────────────────────────────────────
    with gr.Group(visible=False, elem_classes=["step-card"]) as step5_grp:
        gr.HTML("""
<div class="step-header">
<span class="step-num">5</span>
<span class="step-title">Quality Results</span>
</div>""")
        results_html = gr.HTML("")

    # ── Step 6: email gate + download ─────────────────────────────────────────
    with gr.Group(visible=False, elem_classes=["step-card"]) as step6_grp:
        gr.HTML("""
<div class="step-header">
<span class="step-num">6</span>
<span class="step-title">Download Dataset</span>
</div>
<div class="email-gate-note">
Enter your email to unlock the download. You'll receive occasional
updates on new robot configs and dataset improvements.
</div>""")
        with gr.Row():
            email_input = gr.Textbox(
                placeholder="you@example.com", label="Email",
                scale=4, max_lines=1,
            )
            email_btn = gr.Button("Confirm", elem_classes=["btn-primary"], scale=1)
        email_status = gr.Markdown("")
        download_file = gr.File(label="Download robogen_dataset.zip", visible=False)

    # ── Wire events ───────────────────────────────────────────────────────────
    # Output lists must stay in the same order as the tuples the handlers return.
    robot_select.change(
        fn=on_robot_select,
        inputs=[robot_select],
        outputs=[step2_grp, task_select, step3_grp, step4_grp, robot_state],
        api_name=False,
    )
    task_select.change(
        fn=on_task_select,
        inputs=[task_select, robot_state],
        outputs=[step3_grp, step4_grp, n_episodes_slider, success_slider,
                 force_min_slider, force_max_slider],
        api_name=False,
    )
    generate_btn.click(
        fn=on_generate,
        inputs=[robot_state, task_select, n_episodes_slider, success_slider,
                force_min_slider, force_max_slider, failure_check],
        outputs=[gen_status, step5_grp, results_html, step6_grp, df_state, result_state],
        api_name=False,
    )
    email_btn.click(
        fn=on_email_submit,
        inputs=[email_input, robot_state, task_select,
                n_episodes_slider, success_slider,
                force_min_slider, force_max_slider,
                failure_check, df_state, result_state],
        outputs=[email_status, download_file],
        api_name=False,
    )
# ── Launch ────────────────────────────────────────────────────────────────────
# Enable the request queue at import time so it is active however the module
# is loaded; the server itself is only started when run as a script.
demo.queue()

if __name__ == "__main__":
    # Listen on all interfaces; the port comes from $PORT, defaulting to 7860.
    listen_port = int(os.environ.get("PORT", 7860))
    demo.launch(server_name="0.0.0.0", server_port=listen_port)