"""
RoboGen — HaptalAI Synthetic Robotics Dataset Generator
Gradio 5.9.1 / Python 3.11
Step flow:
1 Robot selection (card-style radio)
2 Task dropdown
3 Parameter sliders + failure checkboxes
4 Generate button
5 Quality results dashboard
6 Email gate + zip download
"""
from __future__ import annotations
import os
import sys
import io
import zipfile
import tempfile
import traceback
from typing import Optional, Dict, List
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import gradio as gr
import pandas as pd
from generator import (
generate_dataset,
score_dataset,
annotate_quality_scores,
TASKS_BY_ROBOT,
ROBOT_CONFIG,
FAILURE_TYPES,
)
from readme_gen import generate_readme
from airtable import log_email
# ── CSS ───────────────────────────────────────────────────────────────────────
# Resolve style.css relative to this file so the app starts from any working
# directory.
_here = os.path.dirname(os.path.abspath(__file__))
# Decode explicitly as UTF-8: without `encoding=` Python falls back to the
# locale's preferred encoding, so the same stylesheet could load differently
# (or fail) depending on the host.
with open(os.path.join(_here, "style.css"), encoding="utf-8") as _f:
    CSS = _f.read()
# ── Constants ─────────────────────────────────────────────────────────────────
# Display names for task keys; lookups elsewhere fall back to the raw key when
# a task is missing from this table.
TASK_LABELS = dict(
    pick_and_place="Pick and Place",
    push_object="Push Object",
    grasp_and_lift="Grasp and Lift",
    stacking="Stacking",
    drawer_open_close="Drawer Open / Close",
)

# Display names for failure-type keys.
FAILURE_LABELS = dict(
    grasp_slip="Grasp Slip",
    velocity_spike="Velocity Spike",
    torque_saturation="Torque Saturation",
)

# Per-robot starting values for the parameter controls. `n_eps` and `success`
# are pushed into the sliders by on_task_select.
# NOTE(review): `fmin`/`fmax` are not read by any handler visible in this
# file — confirm whether they are still needed.
DEFAULTS = {
    "SO-100": dict(n_eps=50, success=70, fmin=1.0, fmax=10.0),
    "SO-101": dict(n_eps=50, success=70, fmin=1.0, fmax=10.0),
    "Koch": dict(n_eps=30, success=75, fmin=0.5, fmax=8.0),
}
# ── HTML helpers ──────────────────────────────────────────────────────────────
def _results_html(result: Dict, robot: str, task: str) -> str:
    """Render the quality-results panel for a scored dataset.

    Args:
        result: Scoring summary mapping. This function reads the keys
            ``overall_score``, ``band``, ``n_passed``, ``n_flagged``,
            ``n_episodes``, ``mean_mismatch``, ``failure_breakdown`` and
            ``scorer_used``; a missing key raises ``KeyError``.
        robot: Selected robot name (not referenced in the visible body).
        task: Task key; looked up in ``TASK_LABELS`` with the raw key as the
            fallback label.

    Returns:
        The panel content as a single string.
    """
    # NOTE(review): the markup inside the string literals below appears to
    # have been stripped from this copy of the file. As a result several
    # locals (score, n_pass, n_flag, n_eps, band_cls, band_desc, total,
    # task_label) are computed but not interpolated in what remains, and the
    # single-quoted f-string fragments are split across lines. Restore the
    # literals from the original before changing this function.
    score = result["overall_score"]
    band = result["band"]
    n_pass = result["n_passed"]
    n_flag = result["n_flagged"]
    n_eps = result["n_episodes"]
    mismatch = result["mean_mismatch"]
    fb = result["failure_breakdown"]
    scorer = result["scorer_used"]
    # Lower-cased band name; presumably used as a CSS class suffix — confirm.
    band_cls = band.lower()
    # One-line explanation per band; unknown bands get an empty description.
    band_desc = {
        "Clean": "Trajectories are smooth and anomaly-free. Ready for policy training.",
        "Review": "Some anomalies detected. Review flagged episodes before training.",
        "Flagged": "High anomaly rate. Best used for failure analysis and augmentation.",
    }.get(band, "")
    # `or 1` keeps the total non-zero when the breakdown is empty or all
    # zeros; presumably the denominator for bar widths — confirm.
    total = sum(fb.values()) or 1
    # One fragment group per failure type, ordered by descending count.
    bars = "".join(
f'
'
f'
{FAILURE_LABELS.get(k,k)}'
f'
'
f'
{v} '
        for k, v in sorted(fb.items(), key=lambda x: -x[1])
    )
    task_label = TASK_LABELS.get(task, task)
    # An empty `bars` string is replaced by the "no failures" message.
    return f"""
{mismatch:.3f}
Mean Mismatch
Failure Type Breakdown
{bars or "No failure episodes in dataset."}
Scored by HaptalAI misalignment benchmark · scorer: {scorer}
"""
def _build_zip(df, result, robot, task, n_eps, success, fmin, fmax, failures) -> str:
    """Package the annotated dataset and a generated README into a temp zip.

    Args:
        df: Generated dataset; passed to ``annotate_quality_scores``.
        result: Scoring summary for ``df``; supplies the README's score fields.
        robot: Robot name; hyphens are removed when building the file tag.
        task: Task key; appended to the file tag.
        n_eps: Episode count reported in the README.
        success: Success rate as a percentage; divided by 100 for the README.
        fmin: Lower contact-force bound reported in the README.
        fmax: Upper contact-force bound reported in the README.
        failures: Failure-type keys reported in the README.

    Returns:
        Path of the zip file, which contains ``robogen_<tag>.parquet`` and
        ``README.md``. The caller owns the file.

    Raises:
        Exception: Anything raised while annotating, rendering the README,
            serialising to parquet or writing the archive is propagated. A
            partially written temp file is removed first.
    """
    df_out = annotate_quality_scores(df, result)
    readme = generate_readme(
        robot=robot, task=task, n_episodes=n_eps,
        success_rate=success / 100, force_min=fmin, force_max=fmax,
        failures=failures,
        score=result["overall_score"], band=result["band"],
        n_passed=result["n_passed"], n_flagged=result["n_flagged"],
        mean_mismatch=result["mean_mismatch"],
        failure_breakdown=result["failure_breakdown"],
        scorer_used=result["scorer_used"],
    )
    tag = f"{robot.replace('-','')}_{task}"

    # Serialise before creating the temp file so a parquet failure leaves
    # nothing behind on disk.
    buf = io.BytesIO()
    df_out.to_parquet(buf, index=False)

    fd, path = tempfile.mkstemp(suffix=".zip", prefix=f"robogen_{tag}_")
    # mkstemp returns an open descriptor; ZipFile reopens the file by path.
    os.close(fd)
    try:
        with zipfile.ZipFile(path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
            zf.writestr(f"robogen_{tag}.parquet", buf.getvalue())
            zf.writestr("README.md", readme.encode("utf-8"))
    except Exception:
        # Do not leak a half-written archive when the write fails.
        try:
            os.remove(path)
        except OSError:
            pass
        raise
    return path
# ── Event handlers (module level — Gradio 5 requirement) ─────────────────────
def on_robot_select(robot: str):
    """React to a robot choice by populating the task dropdown.

    Args:
        robot: Selected robot name; a falsy value resets the flow.

    Returns:
        A 5-tuple of ``(step-2 group update, task dropdown update,
        step-3 group update, step-4 group update, robot state value)``.
        Steps 3 and 4 are hidden in both branches.

    Raises:
        KeyError: If ``robot`` is not a key of ``TASKS_BY_ROBOT``.
        IndexError: If the robot has an empty task list.
    """
    if robot:
        task_keys = TASKS_BY_ROBOT[robot]
        # Dropdown entries are (label, value) pairs; unknown keys show raw.
        labelled = [(TASK_LABELS.get(key, key), key) for key in task_keys]
        step2_update = gr.update(visible=True)
        dropdown_update = gr.update(choices=labelled, value=task_keys[0])
        state_value = robot
    else:
        step2_update = gr.update(visible=False)
        dropdown_update = gr.update(choices=[], value=None)
        state_value = ""

    return (
        step2_update,
        dropdown_update,
        gr.update(visible=False),
        gr.update(visible=False),
        state_value,
    )
def on_task_select(task: str, robot: str):
    """Reveal the parameter and generate steps and seed the sliders.

    Args:
        task: Selected task key; falsy hides steps 3 and 4.
        robot: Current robot name; falsy hides steps 3 and 4.

    Returns:
        A 6-tuple of ``(step-3 group update, step-4 group update,
        episode count, success percentage, min force, max force)``.
    """
    if not task or not robot:
        return gr.update(visible=False), gr.update(visible=False), 50, 70, 1.0, 10.0

    # Robots without their own entry reuse the SO-100 defaults.
    d = DEFAULTS.get(robot, DEFAULTS["SO-100"])

    # Prefer the robot's configured force range. The DEFAULTS lookup above
    # already tolerates an unknown robot, so do the same here and fall back to
    # the DEFAULTS bounds rather than raising.
    try:
        fr = ROBOT_CONFIG[robot]["force_range"]
        force_lo, force_hi = fr[0], fr[1]
    except (KeyError, IndexError):
        force_lo, force_hi = d["fmin"], d["fmax"]

    return (
        gr.update(visible=True),
        gr.update(visible=True),
        d["n_eps"],
        d["success"],
        force_lo,
        force_hi,
    )
def on_generate(robot, task, n_eps, success_pct, fmin, fmax, failures):
    """Generate and score a dataset, then reveal the results and download steps.

    Args:
        robot: Robot name; falsy means step 1 is incomplete.
        task: Task key; falsy means step 2 is incomplete.
        n_eps: Episode count (cast to ``int``).
        success_pct: Success rate as a percentage; divided by 100 before use.
        fmin: Lower contact-force bound (cast to ``float``).
        fmax: Upper contact-force bound (cast to ``float``).
        failures: Selected failure-type keys; an empty selection falls back to
            every entry of ``FAILURE_TYPES``.

    Returns:
        A 6-tuple of ``(status markdown, step-5 group update, results panel
        string, step-6 group update, dataframe, scoring result)``. On any
        failure the two groups are hidden and the last two items are ``None``,
        which clears previously stored results.
    """
    def _failed(message):
        # Shared "nothing to show" response: hide steps 5-6, clear state.
        return (
            message,
            gr.update(visible=False), "",
            gr.update(visible=False),
            None, None,
        )

    if not robot or not task:
        return _failed("Please complete steps 1 and 2 first.")
    if not failures:
        failures = list(FAILURE_TYPES)
    try:
        force_min, force_max = float(fmin), float(fmax)
        # The two force sliders overlap, so an inverted range can be selected
        # in the UI; reject it here instead of generating from a nonsensical
        # interval.
        if force_min > force_max:
            return _failed(
                f"Min contact force ({force_min:g} N) is greater than max contact "
                f"force ({force_max:g} N). Adjust the sliders in step 3 and try again."
            )
        df = generate_dataset(
            robot=robot, task=task,
            n_episodes=int(n_eps),
            success_rate=float(success_pct) / 100,
            force_min=force_min, force_max=force_max,
            enabled_failures=list(failures),
            seed=None,
        )
        result = score_dataset(df)
        panel = _results_html(result, robot, task)
        status = (
            f"Generated {len(df):,} rows across {result['n_episodes']} episodes — "
            f"score **{result['overall_score']:.1f}/100** ({result['band']})"
        )
        return (
            status,
            gr.update(visible=True), panel,
            gr.update(visible=True),
            df, result,
        )
    except Exception:
        # UI boundary: surface the traceback in the status area rather than
        # letting the handler raise.
        return _failed(f"Generation failed:\n```\n{traceback.format_exc()}\n```")
def on_email_submit(email, robot, task, n_eps, success_pct, fmin, fmax, failures, df, result):
    """Validate the email, log it, and build the downloadable zip.

    Args:
        email: Address typed by the user.
        robot: Robot name, forwarded to the log and the README.
        task: Task key, forwarded to the log and the README.
        n_eps: Episode count (cast to ``int``).
        success_pct: Success rate as a percentage.
        fmin: Lower contact-force bound (cast to ``float``).
        fmax: Upper contact-force bound (cast to ``float``).
        failures: Selected failure-type keys; an empty selection is expanded
            to every entry of ``FAILURE_TYPES``.
        df: Generated dataframe, or ``None`` if nothing has been generated.
        result: Scoring result for ``df``, or ``None``.

    Returns:
        A 2-tuple of ``(status markdown, download file update)``. The file
        component is made visible only when the zip was built.
    """
    # Normalise first so padded input is accepted and the same value is both
    # validated and logged.
    email = (email or "").strip()
    local_part, at_sign, domain = email.partition("@")
    has_both_sides = bool(at_sign and local_part and domain)
    if not has_both_sides or any(ch.isspace() for ch in email):
        return "Please enter a valid email address.", gr.update(visible=False)
    if df is None or result is None:
        return "Generate a dataset first (Step 4).", gr.update(visible=False)

    # on_generate treats an empty selection as "all failure types"; mirror
    # that here so the README describes the dataset that was actually built.
    failures = list(failures) if failures else list(FAILURE_TYPES)

    # NOTE(review): the generation parameters are re-read from the live
    # controls, so values changed after Step 4 may not match `df` — consider
    # carrying them in state alongside the dataframe.

    # Logging is best-effort: a failure is printed and never blocks the
    # download.
    try:
        ok, msg = log_email(
            email=email, robot=robot, task=task,
            n_episodes=int(n_eps),
            quality_score=result["overall_score"],
            band=result["band"],
        )
        if not ok:
            print(f"[RoboGen] Airtable: {msg}")
    except Exception as exc:
        print(f"[RoboGen] Airtable exception: {exc}")

    try:
        path = _build_zip(
            df=df, result=result, robot=robot, task=task,
            n_eps=int(n_eps), success=float(success_pct),
            fmin=float(fmin), fmax=float(fmax),
            failures=failures,
        )
    except Exception:
        # UI boundary: show the traceback in the status area.
        return (
            f"Download preparation failed:\n```\n{traceback.format_exc()}\n```",
            gr.update(visible=False),
        )
    return "Email confirmed. Your download is ready below.", gr.update(visible=True, value=path)
# ── Build UI ──────────────────────────────────────────────────────────────────
# Six-step wizard. Steps 2-6 start hidden and are revealed by the handlers:
# step 2 by on_robot_select, steps 3-4 by on_task_select, steps 5-6 by
# on_generate.
# NOTE(review): indentation here was reconstructed from the `with` statements
# and section headers — verify the nesting (e.g. that gen_status belongs to
# step 4). The gr.HTML literals look like their markup was stripped from this
# copy; their contents are kept exactly as found.
with gr.Blocks(css=CSS, title="RoboGen") as demo:
    # Values carried between handlers: the robot name written by
    # on_robot_select, and the dataframe / scoring result written by
    # on_generate.
    robot_state = gr.State("")
    df_state = gr.State(None)
    result_state = gr.State(None)
    gr.HTML("""
""")
    # ── Step 1 ────────────────────────────────────────────────────────────────
    with gr.Group(elem_classes=["step-card"]):
        gr.HTML("""
""")
        # No robot is preselected, so the flow starts with only step 1 shown.
        robot_select = gr.Radio(
            choices=["SO-100", "Koch", "SO-101"],
            value=None,
            label="",
            elem_classes=["robot-radio"],
        )
    # ── Step 2 ────────────────────────────────────────────────────────────────
    with gr.Group(visible=False, elem_classes=["step-card"]) as step2_grp:
        gr.HTML("""
""")
        # Choices are filled in by on_robot_select.
        task_select = gr.Dropdown(choices=[], value=None, label="Task", interactive=True)
    # ── Step 3 ────────────────────────────────────────────────────────────────
    with gr.Group(visible=False, elem_classes=["step-card"]) as step3_grp:
        gr.HTML("""
""")
        # The initial slider values below are overwritten by on_task_select.
        with gr.Row():
            n_episodes_slider = gr.Slider(
                minimum=10, maximum=500, value=50, step=5,
                label="Number of Episodes",
                info="Total episodes in the dataset (10–500)",
            )
            success_slider = gr.Slider(
                minimum=0, maximum=100, value=70, step=5,
                label="Success Rate (%)",
                info="Fraction of episodes with successful trajectories",
            )
        # NOTE(review): the two force ranges overlap (min goes up to 10.0, max
        # starts at 1.0), so a min greater than max can be selected.
        with gr.Row():
            force_min_slider = gr.Slider(
                minimum=0.1, maximum=10.0, value=1.0, step=0.1,
                label="Min Contact Force (N)",
                info="Lower bound of spring-damper contact force during grasping",
            )
            force_max_slider = gr.Slider(
                minimum=1.0, maximum=20.0, value=10.0, step=0.5,
                label="Max Contact Force (N)",
                info="Upper bound of contact force — higher = firmer grip",
            )
        gr.HTML("""
Failure types to include
Grasp Slip — gripper opens mid-episode |
Velocity Spike — servo glitch (z>6.5) |
Torque Saturation — joint hits angular limit
""")
        # All three failure types are ticked by default.
        failure_check = gr.CheckboxGroup(
            choices=["grasp_slip", "velocity_spike", "torque_saturation"],
            value=["grasp_slip", "velocity_spike", "torque_saturation"],
            label="",
            elem_classes=["checkbox-group"],
        )
    # ── Step 4 ────────────────────────────────────────────────────────────────
    with gr.Group(visible=False, elem_classes=["step-card"]) as step4_grp:
        gr.HTML("""
""")
        generate_btn = gr.Button("Generate Dataset", elem_classes=["btn-generate"], size="lg")
        # Receives the status or error text returned by on_generate.
        gen_status = gr.Markdown("", elem_classes=["status-msg"])
    # ── Step 5 ────────────────────────────────────────────────────────────────
    with gr.Group(visible=False, elem_classes=["step-card"]) as step5_grp:
        gr.HTML("""
""")
        # Receives the panel string built by _results_html.
        results_html = gr.HTML("")
    # ── Step 6 ────────────────────────────────────────────────────────────────
    with gr.Group(visible=False, elem_classes=["step-card"]) as step6_grp:
        gr.HTML("""
Enter your email to unlock the download. You'll receive occasional
updates on new robot configs and dataset improvements.
""")
        with gr.Row():
            email_input = gr.Textbox(
                placeholder="you@example.com", label="Email",
                scale=4, max_lines=1,
            )
            email_btn = gr.Button("Confirm", elem_classes=["btn-primary"], scale=1)
        email_status = gr.Markdown("")
        # Hidden until on_email_submit returns a zip path.
        download_file = gr.File(label="Download robogen_dataset.zip", visible=False)
    # ── Wire events ───────────────────────────────────────────────────────────
    # Each `outputs` list must match, in order, the tuple its handler returns.
    # `api_name=False` is set on every listener; per the Gradio docs this
    # keeps the handler from being exposed as a named API endpoint.
    # NOTE(review): on_robot_select hides steps 3-4 and relies on the
    # task_select change event below to show them again. Confirm that event
    # still fires when the new robot's first task equals the task already
    # selected.
    robot_select.change(
        fn=on_robot_select,
        inputs=[robot_select],
        outputs=[step2_grp, task_select, step3_grp, step4_grp, robot_state],
        api_name=False,
    )
    task_select.change(
        fn=on_task_select,
        inputs=[task_select, robot_state],
        outputs=[step3_grp, step4_grp, n_episodes_slider, success_slider,
                 force_min_slider, force_max_slider],
        api_name=False,
    )
    generate_btn.click(
        fn=on_generate,
        inputs=[robot_state, task_select, n_episodes_slider, success_slider,
                force_min_slider, force_max_slider, failure_check],
        outputs=[gen_status, step5_grp, results_html, step6_grp, df_state, result_state],
        api_name=False,
    )
    # The generation parameters are read from the live controls again here,
    # alongside the stored dataframe and scoring result.
    email_btn.click(
        fn=on_email_submit,
        inputs=[email_input, robot_state, task_select,
                n_episodes_slider, success_slider,
                force_min_slider, force_max_slider,
                failure_check, df_state, result_state],
        outputs=[email_status, download_file],
        api_name=False,
    )
# ── Launch ────────────────────────────────────────────────────────────────────
# Queueing is enabled at import time, outside the __main__ guard, so it also
# applies when this module is imported rather than run as a script.
demo.queue()

if __name__ == "__main__":
    # Listen on all interfaces; the port comes from $PORT, defaulting to 7860.
    demo.launch(
        server_port=int(os.environ.get("PORT", 7860)),
        server_name="0.0.0.0",
    )