# Hugging Face Space app: LeRobot PushT Trainer (Space status when captured: Sleeping)
import ast
import json
import os
import pathlib
import re
import shlex
import shutil
import subprocess
import time
from collections import deque

import gradio as gr
# ---------- CONSTANTS (visible in App Files) ----------
# Every training run lives in its own folder below RUN_ROOT.
RUN_ROOT = "/home/user/app/runs"
# Global log folder; used for fresh runs so their run dir is not pre-created.
LOG_ROOT = "/home/user/app/logs"
# Text file that remembers the path of the most recent run.
LAST_PTR = pathlib.Path(RUN_ROOT) / "LAST"

# Make sure both roots exist before any handler touches them.
pathlib.Path(RUN_ROOT).mkdir(parents=True, exist_ok=True)
pathlib.Path(LOG_ROOT).mkdir(parents=True, exist_ok=True)
# ---------- ENV / HUB ----------
# Default model repo offered in the UI, e.g. "zino36/lerobot-pusht-colab".
DEFAULT_REPO_ID = os.getenv("REPO_ID", "")
# Pushing is on unless PUSH_TO_HUB is set to something other than 1/true/yes.
PUSH_DEFAULT = os.getenv("PUSH_TO_HUB", "true").lower() in ("1", "true", "yes")
HF_TOKEN = os.getenv("HF_TOKEN")

# Best-effort Hub login at import time: a missing package or a rejected
# token is reported on stdout but must not prevent the app from starting.
if HF_TOKEN:
    try:
        from huggingface_hub import login

        login(token=HF_TOKEN)
    except Exception as exc:
        print("HF login failed:", exc)
| # ---------- LOG HELPERS ---------- | |
| def _run(cmd: str, logfile: str): | |
| os.makedirs(os.path.dirname(logfile), exist_ok=True) | |
| with open(logfile, "a", buffering=1) as f: | |
| f.write("\n---- CMD ----\n" + cmd + "\n--------------\n") | |
| p = subprocess.Popen(cmd, shell=True, | |
| stdout=subprocess.PIPE, stderr=subprocess.STDOUT, | |
| text=True, bufsize=1) | |
| lines = [] | |
| for line in p.stdout: | |
| f.write(line) | |
| lines.append(line) | |
| p.wait() | |
| return p.returncode, "".join(lines[-200:]) | |
def tail_file(path: str, n=200):
    """Return the last *n* lines of the text file at *path* as one string.

    Args:
        path: File to read; undecodable bytes are ignored.
        n: Number of trailing lines wanted. Zero or a negative value yields
            an empty string.

    Returns:
        The joined trailing lines, or ``"(no log yet)"`` when the file does
        not exist.
    """
    try:
        with open(path, "r", errors="ignore") as f:
            if n <= 0:
                # A plain lines[-0:] slice would return the whole file.
                return ""
            # deque(maxlen=n) keeps only the tail while streaming the file.
            return "".join(deque(f, maxlen=n))
    except FileNotFoundError:
        return "(no log yet)"
# ---------- RUN DIR HELPERS ----------
def has_checkpoint(run_dir: str):
    """Tell whether *run_dir* already contains a saved checkpoint.

    A run counts as checkpointed once the ``checkpoints/last`` directory
    exists inside it (the training commands in this file save every 500
    steps).
    """
    last_ckpt_dir = os.path.join(run_dir, "checkpoints", "last")
    return os.path.isdir(last_ckpt_dir)
def newest_run(prefer_checkpoint: bool = True) -> str:
    """Return the most recently modified ``pusht_*`` entry under RUN_ROOT.

    With *prefer_checkpoint* set, the newest entry that has a checkpoint
    wins; if none has one, the newest entry overall is returned anyway.
    Returns ``""`` when RUN_ROOT is missing or holds no matching entry.
    """
    runs_root = pathlib.Path(RUN_ROOT)
    if not runs_root.exists():
        return ""

    # Newest first, ordered by modification time.
    candidates = list(runs_root.glob("pusht_*"))
    candidates.sort(key=lambda entry: entry.stat().st_mtime, reverse=True)
    if not candidates:
        return ""

    if prefer_checkpoint:
        checkpointed = next(
            (entry for entry in candidates if has_checkpoint(str(entry))), None
        )
        if checkpointed is not None:
            return str(checkpointed)
    return str(candidates[0])
def new_run_dir():
    """Pick an unused run directory path and remember it in LAST_PTR.

    The path is ``RUN_ROOT/pusht_<unix-seconds>``; if that already exists,
    ``_1``, ``_2``, ... is appended until a free name is found. The
    directory itself is NOT created here (LeRobot will create it).
    """
    base = pathlib.Path(RUN_ROOT) / f"pusht_{int(time.time())}"
    candidate = base
    suffix = 0
    while candidate.exists():
        suffix += 1
        candidate = pathlib.Path(f"{base}_{suffix}")

    chosen = str(candidate)
    LAST_PTR.write_text(chosen)
    return chosen
def current_run_dir(user_override: str | None):
    """Work out which run directory an action should operate on.

    Resolution order:
      1. Non-blank user input: an absolute path is used as-is, anything
         else is treated as a folder name under RUN_ROOT.
      2. The path stored in LAST_PTR, if it names an existing directory.
      3. The newest run, preferring one with a checkpoint.
      4. The newest run regardless of checkpoints.
      5. ``""`` when nothing was found.

    Results of steps 3 and 4 are written back to LAST_PTR.
    """
    # 1) explicit user input
    typed = (user_override or "").strip()
    if typed:
        if typed.startswith("/"):
            return typed
        return str(pathlib.Path(RUN_ROOT) / typed)

    # 2) remembered run, only if it still exists on disk
    if LAST_PTR.exists():
        remembered = LAST_PTR.read_text().strip()
        if remembered and os.path.isdir(remembered):
            return remembered

    # 3) + 4) newest run, first preferring a checkpoint, then any run
    for prefer_ckpt in (True, False):
        found = newest_run(prefer_checkpoint=prefer_ckpt)
        if found:
            LAST_PTR.write_text(found)
            return found

    return ""
def train_log_path_for_new(run_dir: str):
    """Return the training log path for a fresh run.

    The log is placed in LOG_ROOT and named after the run folder, so the
    run directory does not have to be created just to hold the log.
    """
    run_name = pathlib.Path(run_dir).name
    log_name = f"{run_name}.train.log"
    return os.path.join(LOG_ROOT, log_name)
def train_log_path(run_dir: str):
    """Return ``<run_dir>/logs/train.log``."""
    logs_dir = os.path.join(run_dir, "logs")
    return os.path.join(logs_dir, "train.log")
def eval_log_path(run_dir: str):
    """Return ``<run_dir>/logs/eval.log``."""
    logs_dir = os.path.join(run_dir, "logs")
    return os.path.join(logs_dir, "eval.log")
# ---------- ACTIONS ----------
def start_training(steps, batch_size, push_to_hub, repo_id):
    """Start a fresh diffusion-policy training run on ``lerobot/pusht``.

    Blocks until ``lerobot-train`` exits.

    Args:
        steps: Number of training steps (coerced to ``int``).
        batch_size: Training batch size (coerced to ``int``).
        push_to_hub: Whether to push the policy to the Hub.
        repo_id: Hub repo to push to; pushing is disabled when blank or None.

    Returns:
        ``(message, run_dir, log_tail)`` for the three UI outputs.
    """
    run_dir = new_run_dir()
    log = train_log_path_for_new(run_dir)

    args = [
        "lerobot-train",
        f"--output_dir={run_dir}",
        "--policy.type=diffusion",
        "--dataset.repo_id=lerobot/pusht",
        "--env.type=pusht",
        # int() normalizes slider floats and rejects non-numeric input.
        f"--batch_size={int(batch_size)}",
        f"--steps={int(steps)}",
        "--eval_freq=500",
        "--save_freq=500",
    ]
    repo = (repo_id or "").strip()
    if push_to_hub and repo:
        args += ["--policy.push_to_hub=true", f"--policy.repo_id={repo}"]
    else:
        args.append("--policy.push_to_hub=false")

    # repo_id comes from a free-text box and the command goes through a
    # shell, so every argument is shell-quoted instead of hand-wrapped in
    # single quotes.
    cmd = shlex.join(args)

    rc, tail = _run(cmd, log)
    msg = f"Started fresh run at: {run_dir}\nTrain exited rc={rc}\n\n=== train.log tail ===\n{tail}"
    return msg, run_dir, tail_file(log)
def resume_training(extra_steps, push_to_hub, repo_id, run_dir_text):
    """Resume training of an existing run from its last checkpoint.

    The run is resolved with ``current_run_dir(run_dir_text)``, so a blank
    *run_dir_text* falls back to the remembered / newest run. When a saved
    ``train_config.json`` is found it is passed via ``--config_path``;
    otherwise a minimal policy/dataset/env schema is passed on the command
    line so the parser knows the types.

    Blocks until ``lerobot-train`` exits.

    Args:
        extra_steps: Value passed as ``--steps`` (coerced to ``int``).
            NOTE(review): LeRobot appears to treat ``--steps`` as the total
            step count rather than steps to add - confirm against the
            installed version.
        push_to_hub: Whether to push the policy to the Hub.
        repo_id: Hub repo to push to; pushing is disabled when blank or None.
        run_dir_text: Optional run folder name or absolute path.

    Returns:
        ``(message, run_dir, log_tail)`` for the three UI outputs.
    """
    # 1) Resolve which run to resume
    run_dir = current_run_dir(run_dir_text)
    if not run_dir:
        return (
            "No run found on disk. Start a fresh training first (let it pass step 500 to create a checkpoint).",
            "",
            "(no log)",
        )
    log = train_log_path(run_dir)

    # 2) Ensure a checkpoint exists
    if not has_checkpoint(run_dir):
        return (
            f"Selected run: {run_dir}\nNo checkpoint in {run_dir}/checkpoints/last/ yet — run at least 500 steps once.",
            run_dir,
            tail_file(log),
        )

    args = ["lerobot-train", f"--output_dir={run_dir}", "--resume=true"]

    # 3) Prefer the saved config. LeRobot's documented resume command points
    #    --config_path at checkpoints/last/pretrained_model/train_config.json;
    #    a copy directly in the run dir is accepted as a second choice.
    config_candidates = (
        os.path.join(run_dir, "checkpoints", "last", "pretrained_model", "train_config.json"),
        os.path.join(run_dir, "train_config.json"),
    )
    cfg_path = next((c for c in config_candidates if os.path.isfile(c)), None)
    if cfg_path:
        args.append(f"--config_path={cfg_path}")
    else:
        # Fallback minimal schema so the CLI can parse types correctly
        args += [
            "--policy.type=diffusion",
            "--dataset.repo_id=lerobot/pusht",
            "--env.type=pusht",
        ]

    args += [
        f"--steps={int(extra_steps)}",
        "--eval_freq=500",
        "--save_freq=500",
    ]

    # 4) Push flags (optional)
    repo = (repo_id or "").strip()
    if push_to_hub and repo:
        args += ["--policy.push_to_hub=true", f"--policy.repo_id={repo}"]
    else:
        args.append("--policy.push_to_hub=false")

    # run_dir and repo_id can both come from free-text boxes and the command
    # goes through a shell, so every argument is shell-quoted.
    cmd = shlex.join(args)

    # 5) Run and return logs
    rc, tail = _run(cmd, log)
    msg = f"Resumed run at: {run_dir}\nResume exited rc={rc}\n\n=== train.log tail ===\n{tail}"
    return msg, run_dir, tail_file(log)
def _parse_eval_metrics(text: str):
    """Extract the last printed ``{...pc_success...}`` dict from *text*.

    Returns a dict with the keys ``success_rate``, ``avg_max_overlap``,
    ``avg_sum_reward`` and ``eval_s``, or ``None`` when no such dict is
    present or it cannot be parsed as a Python literal.
    """
    matches = re.findall(r"\{[^}]+pc_success[^}]+\}", text, flags=re.S)
    if not matches:
        return None
    try:
        parsed = ast.literal_eval(matches[-1])
    except (ValueError, SyntaxError, TypeError):
        return None
    if not isinstance(parsed, dict):
        return None
    return {
        "success_rate": parsed.get("pc_success"),
        "avg_max_overlap": parsed.get("avg_max_reward"),
        "avg_sum_reward": parsed.get("avg_sum_reward"),
        "eval_s": parsed.get("eval_s"),
    }


def eval_latest(run_dir_text):
    """Evaluate the last checkpoint of a run with ``lerobot-eval``.

    The run is resolved with ``current_run_dir(run_dir_text)``. Output goes
    to ``<run_dir>/eval_latest``. Metrics are parsed from the tail of the
    eval output and written to ``metrics.json`` there; if nothing can be
    parsed, an existing ``metrics.json`` is read instead.

    Blocks until ``lerobot-eval`` exits.

    Returns:
        ``(message, run_dir, eval_log_tail, metrics_text)`` for the four UI
        outputs.
    """
    run_dir = current_run_dir(run_dir_text)
    if not run_dir:
        return "No run found yet. Start a fresh training first.", "", "(no log)", "(no metrics)"
    elog = eval_log_path(run_dir)
    if not has_checkpoint(run_dir):
        return f"No checkpoint in {run_dir}/checkpoints/last/ to evaluate.", run_dir, tail_file(elog), "(no metrics)"

    ckpt = os.path.join(run_dir, "checkpoints", "last", "pretrained_model")
    eval_out_dir = os.path.join(run_dir, "eval_latest")
    os.makedirs(eval_out_dir, exist_ok=True)

    # run_dir can come from a free-text box and the command goes through a
    # shell, so every argument is shell-quoted.
    cmd = shlex.join([
        "lerobot-eval",
        f"--policy.path={ckpt}",
        "--env.type=pusht",
        "--eval.n_episodes=100",
        "--eval.batch_size=50",
        f"--output_dir={eval_out_dir}",
    ])
    rc, tail = _run(cmd, elog)

    metrics_txt = "(metrics.json not found)"
    metrics_file = pathlib.Path(eval_out_dir) / "metrics.json"
    parsed = _parse_eval_metrics(tail)
    if parsed is not None:
        metrics_txt = f"Success rate: {parsed['success_rate']}\nAvg max overlap: {parsed['avg_max_overlap']}"
        try:
            metrics_file.write_text(json.dumps(parsed, indent=2))
        except OSError as exc:
            # Still show the parsed numbers; only persisting them failed.
            metrics_txt += f"\n(could not write metrics.json: {exc})"
    elif metrics_file.exists():
        # Nothing usable in the output: fall back to a previously saved file.
        try:
            saved = json.loads(metrics_file.read_text())
            if not isinstance(saved, dict):
                raise ValueError("metrics.json is not a JSON object")
            metrics_txt = f"Success rate: {saved.get('success_rate')}\nAvg max overlap: {saved.get('avg_max_overlap')}"
        except (OSError, ValueError):
            metrics_txt = "(could not parse metrics.json)"

    msg = f"Evaluated run at: {run_dir}\nEval exited rc={rc}\n\n=== eval.log tail ===\n{tail}"
    return msg, run_dir, tail_file(elog), metrics_txt
# ---------- Maintenance (list / delete runs) ----------
def list_runs():
    """Return a tab-separated listing of the ``pusht_*`` entries in RUN_ROOT.

    Each row holds the entry name, its disk usage as reported by
    ``du -sh`` (``?`` when that is unavailable) and whether it has a
    checkpoint. Returns ``"(no runs)"`` when RUN_ROOT is missing or empty.
    """
    root = pathlib.Path(RUN_ROOT)
    if not root.exists():
        return "(no runs)"

    rows = []
    for d in sorted(root.glob("pusht_*")):
        try:
            # Call du directly with an argument list so that unusual
            # characters in a folder name are never interpreted by a shell.
            proc = subprocess.run(
                ["du", "-sh", str(d)], stdout=subprocess.PIPE, text=True
            )
            # du prints "<size>\t<path>"; keep whatever size it managed to
            # report even if it exited non-zero.
            size = proc.stdout.split("\t", 1)[0].strip() or "?"
        except OSError:
            size = "?"
        # The two markers must differ, otherwise the column says nothing.
        ck = "✅" if has_checkpoint(str(d)) else "❌"
        rows.append(f"{d.name}\t{size}\tcheckpoint:{ck}")

    if not rows:
        return "(no runs)"
    return "name\tsize\tcheckpoint\n" + "\n".join(rows)
def delete_run_by_name(name: str):
    """Delete one run folder that sits directly inside RUN_ROOT.

    Only the final path component of *name* is used. The request is refused
    unless the resolved target is an immediate child of RUN_ROOT, which
    rules out ``.``, ``..`` and symlinks pointing elsewhere.

    Returns:
        ``(status_message, listing)`` where *listing* is ``list_runs()``.
    """
    # Drop trailing slashes first: basename("pusht_1/") would be empty.
    name = os.path.basename((name or "").strip().rstrip("/"))
    if not name:
        return "Type a folder like 'pusht_1234567890'.", list_runs()

    target = os.path.join(RUN_ROOT, name)
    # A plain startswith(RUN_ROOT + "/") test accepts "runs/..", which would
    # remove the parent of runs/. Compare resolved paths instead.
    runs_root = os.path.realpath(RUN_ROOT)
    if name in {".", ".."} or os.path.dirname(os.path.realpath(target)) != runs_root:
        return "Refusing to delete outside runs/.", list_runs()
    if not os.path.isdir(target):
        return f"Folder not found: {target}", list_runs()

    shutil.rmtree(target, ignore_errors=True)
    if os.path.exists(target):
        # ignore_errors hides failures; do not claim success if it is still there.
        return f"Could not fully delete {target}", list_runs()

    # Forget the pointer when it referenced the run that was just removed.
    if LAST_PTR.exists() and LAST_PTR.read_text().strip() == target:
        LAST_PTR.unlink(missing_ok=True)
    return f"Deleted {target}", list_runs()
def delete_all_runs():
    """Remove every ``pusht_*`` directory in RUN_ROOT and clear LAST_PTR.

    Returns:
        ``(status_message, listing)`` where *listing* is ``list_runs()``.
    """
    if not os.path.isdir(RUN_ROOT):
        return "(runs/ missing)", list_runs()

    for entry in os.listdir(RUN_ROOT):
        if not entry.startswith("pusht_"):
            continue
        run_path = os.path.join(RUN_ROOT, entry)
        if os.path.isdir(run_path):
            shutil.rmtree(run_path, ignore_errors=True)

    # No run is left to point at.
    LAST_PTR.unlink(missing_ok=True)
    return "Deleted all pusht_* runs.", list_runs()
# ---------- UI ----------
# Single-page Gradio app: train, resume, evaluate, list and delete runs.
with gr.Blocks(title="LeRobot PushT Trainer (Space)") as demo:
    gr.Markdown(
        "# 🤖 LeRobot PushT Trainer\n"
        "Train / Resume / Evaluate. Files persist under `/home/user/app/runs/` (see App Files)."
    )

    with gr.Row():
        repo_id = gr.Textbox(
            label="Hugging Face Model Repo (optional)",
            value=DEFAULT_REPO_ID,
            placeholder="username/repo-name",
        )
        push_to_hub = gr.Checkbox(label="Push checkpoints to Hub", value=PUSH_DEFAULT)
    with gr.Row():
        steps = gr.Slider(200, 20000, value=2000, step=100, label="Training steps (fresh run)")
        batch = gr.Slider(4, 64, value=16, step=2, label="Batch size")
    start_btn = gr.Button("🚀 Start Fresh Training")
    start_out = gr.Textbox(label="Start Output")
    run_dir_view = gr.Textbox(label="Current run directory (auto-filled after start)")
    train_log = gr.Textbox(label="train.log (tail)", lines=20)

    gr.Markdown("### Resume / Evaluate a Specific Run")
    run_dir_text = gr.Textbox(label="Run directory (leave blank to use the latest)")
    with gr.Row():
        extra_steps = gr.Slider(200, 20000, value=2000, step=100, label="Steps to add on resume")
    resume_btn = gr.Button("▶️ Resume from Last Checkpoint")
    resume_out = gr.Textbox(label="Resume Output")
    resume_log = gr.Textbox(label="train.log (tail)", lines=20)

    gr.Markdown("### Evaluate Latest Checkpoint of Selected Run")
    eval_btn = gr.Button("📊 Evaluate Latest")
    eval_out = gr.Textbox(label="Eval Output")
    eval_log = gr.Textbox(label="eval.log (tail)", lines=20)
    metrics_box = gr.Textbox(label="Parsed metrics (if metrics.json exists)")

    gr.Markdown("### Runs on disk")
    list_btn = gr.Button("📂 List runs folder")
    list_out = gr.Textbox(label="runs/ listing", lines=12)

    gr.Markdown("### Maintenance")
    del_name = gr.Textbox(label="Run folder name to delete (e.g., pusht_1699999999)")
    del_one_btn = gr.Button("🗑️ Delete this run")
    del_all_btn = gr.Button("🧹 Delete ALL pusht_* runs")
    # Dedicated box for the delete status: the handlers return
    # (message, listing) and routing both values to list_out would let the
    # listing overwrite the message.
    del_out = gr.Textbox(label="Delete Output")

    # Wiring
    start_btn.click(
        start_training,
        inputs=[steps, batch, push_to_hub, repo_id],
        outputs=[start_out, run_dir_view, train_log],
    )
    resume_btn.click(
        resume_training,
        inputs=[extra_steps, push_to_hub, repo_id, run_dir_text],
        outputs=[resume_out, run_dir_view, resume_log],
    )
    eval_btn.click(
        eval_latest,
        inputs=[run_dir_text],
        outputs=[eval_out, run_dir_view, eval_log, metrics_box],
    )
    list_btn.click(list_runs, outputs=list_out)
    del_one_btn.click(delete_run_by_name, inputs=del_name, outputs=[del_out, list_out])
    del_all_btn.click(delete_all_runs, outputs=[del_out, list_out])

if __name__ == "__main__":
    demo.launch()