# app.py — LeRobot PushT Trainer Space (last updated in commit 4df282e)
import json
import os
import pathlib
import shlex
import shutil
import subprocess
import time

import gradio as gr
# ---------- CONSTANTS (visible in App Files) ----------
RUN_ROOT = "/home/user/app/runs" # where all runs live
LOG_ROOT = "/home/user/app/logs" # global logs (avoid pre-creating run dirs)
LAST_PTR = pathlib.Path(RUN_ROOT) / "LAST" # remembers most recent run path
# Created eagerly at import time so later helpers can assume they exist.
os.makedirs(RUN_ROOT, exist_ok=True)
os.makedirs(LOG_ROOT, exist_ok=True)
# ---------- ENV / HUB ----------
# Optional defaults from Space secrets / environment variables.
DEFAULT_REPO_ID = os.environ.get("REPO_ID", "") # e.g. "zino36/lerobot-pusht-colab"
PUSH_DEFAULT = os.environ.get("PUSH_TO_HUB", "true").lower() in {"1","true","yes"}
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    # Best-effort Hub login at import time; a failure is logged but does not
    # stop the app (pushing to the Hub will simply fail later).
    try:
        from huggingface_hub import login
        login(token=HF_TOKEN)
    except Exception as e:
        print("HF login failed:", e)
# ---------- LOG HELPERS ----------
def _run(cmd: str, logfile: str):
os.makedirs(os.path.dirname(logfile), exist_ok=True)
with open(logfile, "a", buffering=1) as f:
f.write("\n---- CMD ----\n" + cmd + "\n--------------\n")
p = subprocess.Popen(cmd, shell=True,
stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
text=True, bufsize=1)
lines = []
for line in p.stdout:
f.write(line)
lines.append(line)
p.wait()
return p.returncode, "".join(lines[-200:])
def tail_file(path: str, n=200):
    """Return the last *n* lines of *path*, or a placeholder if it is absent."""
    if not os.path.exists(path):
        return "(no log yet)"
    with open(path, "r", errors="ignore") as fh:
        all_lines = fh.read().splitlines(keepends=True)
    return "".join(all_lines[-n:])
# ---------- RUN DIR HELPERS ----------
def has_checkpoint(run_dir: str):
    """Checkpoint considered present once checkpoints/last/ exists (first save ~ step 500)."""
    last_dir = pathlib.Path(run_dir) / "checkpoints" / "last"
    return last_dir.is_dir()
def newest_run(prefer_checkpoint: bool = True) -> str:
    """Return the most recently modified pusht_* run dir, or "" if none.

    With *prefer_checkpoint*, a run that already has a checkpoint wins over
    a strictly newer run that does not; otherwise the newest run is taken.
    """
    root = pathlib.Path(RUN_ROOT)
    if not root.exists():
        return ""
    candidates = list(root.glob("pusht_*"))
    if not candidates:
        return ""
    candidates.sort(key=lambda d: d.stat().st_mtime, reverse=True)
    if prefer_checkpoint:
        checkpointed = next(
            (c for c in candidates if has_checkpoint(str(c))), None
        )
        if checkpointed is not None:
            return str(checkpointed)
    return str(candidates[0])
def new_run_dir():
    """Return a unique run dir path WITHOUT creating it (LeRobot will create it).

    The chosen path is also remembered in the LAST pointer file.
    """
    stem = pathlib.Path(RUN_ROOT) / f"pusht_{int(time.time())}"
    candidate = stem
    suffix = 1
    # Disambiguate if a run was already created in the same second.
    while candidate.exists():
        candidate = pathlib.Path(f"{stem}_{suffix}")
        suffix += 1
    LAST_PTR.write_text(str(candidate))
    return str(candidate)
def current_run_dir(user_override: str | None):
    """Resolve which run directory to operate on.

    Priority:
      A) explicit user input (bare folder name or absolute path),
      B) the LAST pointer file, if it names an existing directory,
      C) the newest run that already has a checkpoint,
      D) the newest run even without a checkpoint,
      E) "" when nothing is found.
    """
    # A) explicit user input wins outright
    override = (user_override or "").strip()
    if override:
        if override.startswith("/"):
            return override
        return str(pathlib.Path(RUN_ROOT) / override)
    # B) LAST pointer, if present and still valid on disk
    if LAST_PTR.exists():
        remembered = LAST_PTR.read_text().strip()
        if remembered and os.path.isdir(remembered):
            return remembered
    # C) then D): newest run, checkpointed first; refresh the pointer
    for prefer in (True, False):
        found = newest_run(prefer_checkpoint=prefer)
        if found:
            LAST_PTR.write_text(found)
            return found
    # E) nothing on disk
    return ""
def train_log_path_for_new(run_dir: str):
    """Fresh-run logs go to global LOG_ROOT so we don't pre-create run_dir."""
    return os.path.join(LOG_ROOT, f"{pathlib.Path(run_dir).name}.train.log")
def train_log_path(run_dir: str):
    """Path of the per-run training log inside *run_dir*."""
    return str(pathlib.Path(run_dir) / "logs" / "train.log")
def eval_log_path(run_dir: str):
    """Path of the per-run evaluation log inside *run_dir*."""
    return str(pathlib.Path(run_dir) / "logs" / "eval.log")
# ---------- ACTIONS ----------
def start_training(steps, batch_size, push_to_hub, repo_id):
    """Start a fresh lerobot-train run and block until it exits.

    Args:
        steps: total training steps for the fresh run.
        batch_size: training batch size.
        push_to_hub: whether to push the policy to the Hub.
        repo_id: Hub model repo id (may be None or blank).

    Returns:
        (status message, run_dir, train-log tail) for the three UI outputs.
    """
    run_dir = new_run_dir()
    log = train_log_path_for_new(run_dir)
    # Guard against repo_id being None (resume_training already does this);
    # shlex.quote because the repo id comes from a UI textbox (untrusted).
    repo = (repo_id or "").strip()
    if push_to_hub and repo:
        push_flags = f"--policy.push_to_hub=true --policy.repo_id={shlex.quote(repo)}"
    else:
        push_flags = "--policy.push_to_hub=false"
    cmd = (
        "lerobot-train "
        f"--output_dir={shlex.quote(run_dir)} "
        "--policy.type=diffusion "
        "--dataset.repo_id=lerobot/pusht "
        "--env.type=pusht "
        f"--batch_size={batch_size} "
        f"--steps={steps} "
        "--eval_freq=500 "
        "--save_freq=500 "
        f"{push_flags}"
    )
    rc, tail = _run(cmd, log)
    msg = f"Started fresh run at: {run_dir}\nTrain exited rc={rc}\n\n=== train.log tail ===\n{tail}"
    return msg, run_dir, tail_file(log)
def resume_training(extra_steps, push_to_hub, repo_id, run_dir_text):
    """Resume training from the newest run (with a checkpoint) if run_dir_text is blank.

    Uses the exact saved train_config.json when available; otherwise falls
    back to a minimal CLI schema so the parser knows the policy/dataset/env
    types.

    Returns:
        (status message, run_dir, train-log tail) for the three UI outputs.
    """
    # 1) Resolve which run to resume
    run_dir = current_run_dir(run_dir_text)
    if not run_dir:
        return (
            "No run found on disk. Start a fresh training first (let it pass step 500 to create a checkpoint).",
            "",
            "(no log)",
        )
    log = train_log_path(run_dir)
    # 2) Ensure a checkpoint exists
    if not has_checkpoint(run_dir):
        return (
            f"Selected run: {run_dir}\nNo checkpoint in {run_dir}/checkpoints/last/ yet — run at least 500 steps once.",
            run_dir,
            tail_file(log),
        )
    # 3) Push flags (optional); shlex.quote because repo_id is UI input
    repo = (repo_id or "").strip()
    push_flags = (
        f"--policy.push_to_hub=true --policy.repo_id={shlex.quote(repo)}"
        if (push_to_hub and repo)
        else "--policy.push_to_hub=false"
    )
    # 4) Prefer resuming with the saved config (LOCAL file -> use --config-file)
    cfg_path = os.path.join(run_dir, "train_config.json")
    parts = [
        "lerobot-train",
        f"--output_dir={shlex.quote(run_dir)}",
        "--resume=true",
        f"--steps={extra_steps}",
        "--eval_freq=500",
        "--save_freq=500",
    ]
    if os.path.exists(cfg_path):
        # Exact saved config reproduces the original run's full schema.
        parts.insert(3, f"--config-file={shlex.quote(cfg_path)}")
    else:
        # Fallback minimal schema so the CLI can parse types correctly
        parts.extend([
            "--policy.type=diffusion",
            "--dataset.repo_id=lerobot/pusht",
            "--env.type=pusht",
        ])
    parts.append(push_flags)
    cmd = " ".join(parts)
    # 5) Run and return logs
    rc, tail = _run(cmd, log)
    msg = f"Resumed run at: {run_dir}\nResume exited rc={rc}\n\n=== train.log tail ===\n{tail}"
    return msg, run_dir, tail_file(log)
def eval_latest(run_dir_text):
    """Evaluate the latest checkpoint of the selected run with lerobot-eval.

    Returns (status message, run_dir, eval-log tail, parsed-metrics text)
    for the four UI outputs.
    """
    run_dir = current_run_dir(run_dir_text)
    if not run_dir:
        return "No run found yet. Start a fresh training first.", "", "(no log)", "(no metrics)"
    elog = eval_log_path(run_dir)
    if not has_checkpoint(run_dir):
        return f"No checkpoint in {run_dir}/checkpoints/last/ to evaluate.", run_dir, tail_file(elog), "(no metrics)"
    # Ready-to-load policy lives under checkpoints/last/pretrained_model.
    ckpt = os.path.join(run_dir, "checkpoints", "last", "pretrained_model")
    eval_out_dir = os.path.join(run_dir, "eval_latest")
    os.makedirs(eval_out_dir, exist_ok=True)
    cmd = (
        "lerobot-eval "
        f"--policy.path='{ckpt}' "
        "--env.type=pusht "
        "--eval.n_episodes=100 "
        "--eval.batch_size=50 "
        f"--output_dir='{eval_out_dir}'"
    )
    rc, tail = _run(cmd, elog)
    # --- parse printed dict and write metrics.json if missing ---
    import re, ast  # local import: only needed on this code path
    metrics_txt = "(metrics.json not found)"
    p = pathlib.Path(eval_out_dir) / "metrics.json"
    # Grab the last python-dict literal containing 'pc_success' from the
    # captured output tail (presumably lerobot-eval's printed summary —
    # format to be confirmed against the lerobot version in use).
    m = re.findall(r"\{[^}]+pc_success[^}]+\}", tail, flags=re.S)
    if m:
        try:
            d = ast.literal_eval(m[-1])  # safe parse of the dict literal
            out = {
                "success_rate": d.get("pc_success"),
                "avg_max_overlap": d.get("avg_max_reward"),
                "avg_sum_reward": d.get("avg_sum_reward"),
                "eval_s": d.get("eval_s"),
            }
            p.write_text(json.dumps(out, indent=2))
            metrics_txt = f"Success rate: {out['success_rate']}\nAvg max overlap: {out['avg_max_overlap']}"
        except Exception:
            # Best-effort: fall through and leave the placeholder text.
            pass
    elif p.exists():
        # No dict in the tail, but a previous eval left metrics.json behind.
        try:
            d = json.loads(p.read_text())
            metrics_txt = f"Success rate: {d.get('success_rate')}\nAvg max overlap: {d.get('avg_max_overlap')}"
        except Exception:
            metrics_txt = "(could not parse metrics.json)"
    # --- end patch ---
    msg = f"Evaluated run at: {run_dir}\nEval exited rc={rc}\n\n=== eval.log tail ===\n{tail}"
    return msg, run_dir, tail_file(elog), metrics_txt
# ---------- Maintenance (list / delete runs) ----------
def list_runs():
    """List pusht_* run folders with human-readable size and checkpoint flag.

    Returns a tab-separated text table for display in the UI.
    """
    root = pathlib.Path(RUN_ROOT)
    if not root.exists():
        return "(no runs)"
    rows = []
    for d in sorted(root.glob("pusht_*")):
        try:
            # shlex.quote the path: it is interpolated into a bash command
            # line (run names normally match pusht_*, but quote defensively).
            size = subprocess.check_output(
                ["bash", "-lc", f"du -sh {shlex.quote(str(d))} | cut -f1"], text=True
            ).strip()
        except Exception:
            size = "?"  # best-effort: du can fail on a vanishing dir
        ck = "✓" if has_checkpoint(str(d)) else "—"
        rows.append(f"{d.name}\t{size}\tcheckpoint:{ck}")
    return "name\tsize\tcheckpoint\n" + "\n".join(rows) if rows else "(no runs)"
def delete_run_by_name(name: str):
    """Delete one run folder inside RUN_ROOT, addressed by its basename.

    Returns (status message, refreshed runs listing).
    """
    folder = os.path.basename((name or "").strip())
    if not folder:
        return "Type a folder like 'pusht_1234567890'.", list_runs()
    target = os.path.join(RUN_ROOT, folder)
    # basename() already strips path components; this is a second guard.
    if not target.startswith(RUN_ROOT + "/"):
        return "Refusing to delete outside runs/.", list_runs()
    if not os.path.isdir(target):
        return f"Folder not found: {target}", list_runs()
    shutil.rmtree(target, ignore_errors=True)
    # Drop the LAST pointer if it referenced the run we just removed.
    if LAST_PTR.exists() and LAST_PTR.read_text().strip() == target:
        LAST_PTR.unlink(missing_ok=True)
    return f"Deleted {target}", list_runs()
def delete_all_runs():
    """Remove every pusht_* run folder and clear the LAST pointer.

    Returns (status message, refreshed runs listing).
    """
    if not os.path.isdir(RUN_ROOT):
        return "(runs/ missing)", list_runs()
    for entry in os.listdir(RUN_ROOT):
        full_path = os.path.join(RUN_ROOT, entry)
        if os.path.isdir(full_path) and entry.startswith("pusht_"):
            shutil.rmtree(full_path, ignore_errors=True)
    LAST_PTR.unlink(missing_ok=True)
    return "Deleted all pusht_* runs.", list_runs()
# ---------- UI ----------
# ---------- UI ----------
# Gradio layout + event wiring; all handlers are the module-level functions above.
with gr.Blocks(title="LeRobot PushT Trainer (Space)") as demo:
    gr.Markdown("# 🤖 LeRobot PushT Trainer\nTrain / Resume / Evaluate. Files persist under `/home/user/app/runs/` (see App Files).")
    # Hub push options shared by fresh and resumed training.
    with gr.Row():
        repo_id = gr.Textbox(label="Hugging Face Model Repo (optional)", value=DEFAULT_REPO_ID, placeholder="username/repo-name")
        push_to_hub = gr.Checkbox(label="Push checkpoints to Hub", value=PUSH_DEFAULT)
    with gr.Row():
        steps = gr.Slider(200, 20000, value=2000, step=100, label="Training steps (fresh run)")
        batch = gr.Slider(4, 64, value=16, step=2, label="Batch size")
    start_btn = gr.Button("🚀 Start Fresh Training")
    start_out = gr.Textbox(label="Start Output")
    run_dir_view = gr.Textbox(label="Current run directory (auto-filled after start)")
    train_log = gr.Textbox(label="train.log (tail)", lines=20)
    gr.Markdown("### Resume / Evaluate a Specific Run")
    run_dir_text = gr.Textbox(label="Run directory (leave blank to use the latest)")
    with gr.Row():
        extra_steps = gr.Slider(200, 20000, value=2000, step=100, label="Steps to add on resume")
    resume_btn = gr.Button("▶️ Resume from Last Checkpoint")
    resume_out = gr.Textbox(label="Resume Output")
    resume_log = gr.Textbox(label="train.log (tail)", lines=20)
    gr.Markdown("### Evaluate Latest Checkpoint of Selected Run")
    eval_btn = gr.Button("📈 Evaluate Latest")
    eval_out = gr.Textbox(label="Eval Output")
    eval_log = gr.Textbox(label="eval.log (tail)", lines=20)
    metrics_box = gr.Textbox(label="Parsed metrics (if metrics.json exists)")
    gr.Markdown("### Runs on disk")
    list_btn = gr.Button("📂 List runs folder")
    list_out = gr.Textbox(label="runs/ listing", lines=12)
    gr.Markdown("### Maintenance")
    del_name = gr.Textbox(label="Run folder name to delete (e.g., pusht_1699999999)")
    del_one_btn = gr.Button("🗑️ Delete this run")
    del_all_btn = gr.Button("🧹 Delete ALL pusht_* runs")
    # Wiring
    start_btn.click(start_training, inputs=[steps, batch, push_to_hub, repo_id], outputs=[start_out, run_dir_view, train_log])
    resume_btn.click(resume_training, inputs=[extra_steps, push_to_hub, repo_id, run_dir_text], outputs=[resume_out, run_dir_view, resume_log])
    eval_btn.click(eval_latest, inputs=[run_dir_text], outputs=[eval_out, run_dir_view, eval_log, metrics_box])
    list_btn.click(list_runs, outputs=list_out)
    # NOTE(review): both delete handlers return (status_msg, listing) but map
    # both values onto the same list_out component, so the status message is
    # overwritten by the listing — likely wants a separate status Textbox.
    del_one_btn.click(delete_run_by_name, inputs=del_name, outputs=[list_out, list_out])
    del_all_btn.click(delete_all_runs, outputs=[list_out, list_out])
if __name__ == "__main__":
    demo.launch()