File size: 13,259 Bytes
69bccb6
7d39621
 
69bccb6
ad5a998
 
69bccb6
7d39621
69bccb6
7d39621
69bccb6
ad5a998
7d39621
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad5a998
 
 
 
6a460fa
ad5a998
6a460fa
 
 
 
 
 
 
 
 
 
 
ad5a998
7d39621
ad5a998
69bccb6
 
 
 
 
 
7d39621
 
 
 
6a460fa
 
 
ad5a998
 
 
6a460fa
 
 
7d39621
6a460fa
 
 
 
 
 
7d39621
6a460fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69bccb6
7d39621
69bccb6
ad5a998
69bccb6
 
 
7d39621
 
 
 
 
 
 
 
 
69bccb6
7d39621
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b83a57
 
 
 
 
 
7d39621
 
3b83a57
 
 
 
 
7d39621
 
3b83a57
7d39621
3b83a57
 
 
 
 
 
 
 
 
 
 
 
7d39621
3b83a57
ad5a998
3b83a57
 
 
 
 
 
 
 
 
 
ad5a998
3b83a57
ad5a998
3b83a57
 
 
 
 
 
 
 
 
 
 
7d39621
 
 
3b83a57
7d39621
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
add8b3a
ad5a998
 
645433b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
add8b3a
7d39621
add8b3a
 
7d39621
 
add8b3a
 
7d39621
 
ad5a998
69bccb6
7d39621
 
 
 
 
 
69bccb6
 
 
 
 
 
7d39621
 
 
 
69bccb6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d39621
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69bccb6
 
 
 
 
 
7d39621
 
 
 
69bccb6
 
7d39621
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
import os, subprocess, json, pathlib, time, shutil
import gradio as gr

# ---------- CONSTANTS (visible in App Files) ----------
# Every training run lives in its own sub-folder of RUN_ROOT.
RUN_ROOT = "/home/user/app/runs"
# Logs for fresh runs are written here so that logging never pre-creates a run dir.
LOG_ROOT = "/home/user/app/logs"
# Small text file whose content is the path of the most recent run.
LAST_PTR = pathlib.Path(RUN_ROOT, "LAST")
os.makedirs(RUN_ROOT, exist_ok=True)
os.makedirs(LOG_ROOT, exist_ok=True)

# ---------- ENV / HUB ----------
# Defaults for the UI, taken from the environment (e.g. REPO_ID="zino36/lerobot-pusht-colab").
DEFAULT_REPO_ID = os.getenv("REPO_ID", "")
# PUSH_TO_HUB is treated as enabled for "1", "true" or "yes" (case-insensitive); default is on.
PUSH_DEFAULT = os.getenv("PUSH_TO_HUB", "true").lower() in ("1", "true", "yes")
HF_TOKEN = os.getenv("HF_TOKEN")

if HF_TOKEN:
    # Best-effort login: a failure is reported on stdout but never stops the app.
    try:
        from huggingface_hub import login
        login(token=HF_TOKEN)
    except Exception as exc:
        print("HF login failed:", exc)

# ---------- LOG HELPERS ----------
def _run(cmd: str, logfile: str):
    os.makedirs(os.path.dirname(logfile), exist_ok=True)
    with open(logfile, "a", buffering=1) as f:
        f.write("\n---- CMD ----\n" + cmd + "\n--------------\n")
        p = subprocess.Popen(cmd, shell=True,
                             stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                             text=True, bufsize=1)
        lines = []
        for line in p.stdout:
            f.write(line)
            lines.append(line)
        p.wait()
        return p.returncode, "".join(lines[-200:])

def tail_file(path: str, n: int = 200) -> str:
    """Return the last ``n`` lines of a text file as one string.

    Args:
        path: File to read. Undecodable bytes are ignored.
        n: Number of trailing lines to keep; values <= 0 yield an empty string.

    Returns:
        The joined tail, or the placeholder ``"(no log yet)"`` when the file
        does not exist.
    """
    from collections import deque

    try:
        with open(path, "r", errors="ignore") as f:
            # deque(maxlen=...) streams the file and retains only the tail,
            # so large logs are never held in memory in full.
            return "".join(deque(f, maxlen=max(n, 0)))
    except FileNotFoundError:
        # Opening directly (instead of exists-then-open) also covers the case
        # where the log disappears between the check and the read.
        return "(no log yet)"

# ---------- RUN DIR HELPERS ----------
def has_checkpoint(run_dir: str):
    """Tell whether ``run_dir`` already holds a ``checkpoints/last/`` directory.

    The training commands in this file save every 500 steps, so the directory
    is expected to show up around step 500 of a run.
    """
    last_ckpt = os.path.join(run_dir, "checkpoints", "last")
    return os.path.isdir(last_ckpt)

def newest_run(prefer_checkpoint: bool = True) -> str:
    """Return the most recently modified ``pusht_*`` entry under RUN_ROOT.

    With ``prefer_checkpoint`` set, the newest run that has a checkpoint wins;
    if no run has one, the newest run overall is returned anyway. An empty
    string means RUN_ROOT is missing or contains no matching entry.
    """
    runs_root = pathlib.Path(RUN_ROOT)
    if not runs_root.exists():
        return ""

    # Newest first, ordered by modification time.
    candidates = list(runs_root.glob("pusht_*"))
    candidates.sort(key=lambda entry: entry.stat().st_mtime, reverse=True)
    if not candidates:
        return ""

    if prefer_checkpoint:
        with_ckpt = next((c for c in candidates if has_checkpoint(str(c))), None)
        if with_ckpt is not None:
            return str(with_ckpt)
    return str(candidates[0])

def new_run_dir():
    """Pick an unused run directory path and remember it in the LAST pointer.

    The directory itself is NOT created here (LeRobot will create it). The
    name is ``pusht_<unix time>``; if that path already exists, ``_1``,
    ``_2``, ... is appended until a free path is found.
    """
    base = pathlib.Path(RUN_ROOT) / f"pusht_{int(time.time())}"
    candidate = base
    suffix = 0
    while candidate.exists():
        suffix += 1
        candidate = pathlib.Path(f"{base}_{suffix}")
    chosen = str(candidate)
    LAST_PTR.write_text(chosen)
    return chosen

def current_run_dir(user_override: str | None):
    """
    Work out which run directory an action should operate on.

    Resolution order:
    1. Text typed by the user: an absolute path is taken as-is, anything else
       is treated as a folder name under RUN_ROOT. Existence is not checked.
    2. The path stored in the LAST pointer, if it names an existing directory.
    3. The newest run, preferring one with a checkpoint, then any run; the
       result is stored in the LAST pointer.
    Returns "" when nothing is found.
    """
    typed = (user_override or "").strip()
    if typed:
        if typed.startswith("/"):
            return typed
        return str(pathlib.Path(RUN_ROOT) / typed)

    if LAST_PTR.exists():
        remembered = LAST_PTR.read_text().strip()
        if remembered and os.path.isdir(remembered):
            return remembered

    # First pass prefers runs with a checkpoint, second pass accepts any run.
    for want_checkpoint in (True, False):
        found = newest_run(prefer_checkpoint=want_checkpoint)
        if found:
            LAST_PTR.write_text(found)
            return found

    return ""

def train_log_path_for_new(run_dir: str):
    """Log path for a fresh run: ``<LOG_ROOT>/<run folder name>.train.log``.

    Living under LOG_ROOT means writing the log does not create ``run_dir``.
    """
    log_name = pathlib.Path(run_dir).name + ".train.log"
    return os.path.join(LOG_ROOT, log_name)

def train_log_path(run_dir: str):
    """Path of the training log kept inside a run: ``<run_dir>/logs/train.log``."""
    logs_dir = os.path.join(run_dir, "logs")
    return os.path.join(logs_dir, "train.log")

def eval_log_path(run_dir: str):
    """Path of the evaluation log kept inside a run: ``<run_dir>/logs/eval.log``."""
    logs_dir = os.path.join(run_dir, "logs")
    return os.path.join(logs_dir, "eval.log")

# ---------- ACTIONS ----------
def start_training(steps, batch_size, push_to_hub, repo_id):
    """Launch a fresh ``lerobot-train`` run and block until it exits.

    Args:
        steps: Number of training steps (coerced to int).
        batch_size: Training batch size (coerced to int).
        push_to_hub: Whether to push to the Hub; only honoured when
            ``repo_id`` is non-empty.
        repo_id: Hub repository id typed by the user; may be empty or None.

    Returns:
        ``(message, run_dir, log_tail)`` for the three UI outputs.

    Raises:
        ValueError, TypeError: if ``steps`` or ``batch_size`` cannot be
            converted to int.
    """
    import shlex

    # Validate before new_run_dir() so bad input does not move the LAST pointer.
    steps = int(steps)
    batch_size = int(batch_size)
    repo = (repo_id or "").strip()

    run_dir = new_run_dir()
    log = train_log_path_for_new(run_dir)

    args = [
        "lerobot-train",
        f"--output_dir={run_dir}",
        "--policy.type=diffusion",
        "--dataset.repo_id=lerobot/pusht",
        "--env.type=pusht",
        f"--batch_size={batch_size}",
        f"--steps={steps}",
        "--eval_freq=500",
        "--save_freq=500",
    ]
    if push_to_hub and repo:
        args += ["--policy.push_to_hub=true", f"--policy.repo_id={repo}"]
    else:
        args.append("--policy.push_to_hub=false")

    # repo_id comes straight from a text box and the command goes through a
    # shell, so every argument is quoted rather than wrapped in '...' by hand.
    cmd = shlex.join(args)

    rc, tail = _run(cmd, log)
    msg = f"Started fresh run at: {run_dir}\nTrain exited rc={rc}\n\n=== train.log tail ===\n{tail}"
    return msg, run_dir, tail_file(log)

def resume_training(extra_steps, push_to_hub, repo_id, run_dir_text):
    """
    Resume training of an existing run and block until the command exits.

    The run is resolved with current_run_dir(run_dir_text); a blank value
    selects the remembered / newest run. If ``<run_dir>/train_config.json``
    exists it is passed via ``--config-file``; otherwise the policy, dataset
    and env types are given on the command line so the parser knows them.

    Args:
        extra_steps: Value passed as ``--steps`` (coerced to int).
        push_to_hub: Whether to push to the Hub; only honoured when
            ``repo_id`` is non-empty.
        repo_id: Hub repository id typed by the user; may be empty or None.
        run_dir_text: Optional run folder name or absolute path.

    Returns:
        ``(message, run_dir, log_tail)`` for the three UI outputs.

    Raises:
        ValueError, TypeError: if ``extra_steps`` cannot be converted to int.
    """
    import shlex

    extra_steps = int(extra_steps)
    repo = (repo_id or "").strip()

    # 1) Resolve which run to resume
    run_dir = current_run_dir(run_dir_text)
    if not run_dir:
        return (
            "No run found on disk. Start a fresh training first (let it pass step 500 to create a checkpoint).",
            "",
            "(no log)",
        )
    log = train_log_path(run_dir)

    # 2) Ensure a checkpoint exists
    if not has_checkpoint(run_dir):
        return (
            f"Selected run: {run_dir}\nNo checkpoint in {run_dir}/checkpoints/last/ yet — run at least 500 steps once.",
            run_dir,
            tail_file(log),
        )

    # 3) Build the argument list in its final order
    # NOTE(review): the UI calls this value "steps to add", but it is passed
    # unchanged as --steps; confirm whether lerobot-train treats --steps as a
    # total step count when resuming.
    args = [
        "lerobot-train",
        f"--output_dir={run_dir}",
        "--resume=true",
    ]

    # NOTE(review): confirm the flag name and the location of
    # train_config.json against the installed LeRobot version.
    cfg_path = os.path.join(run_dir, "train_config.json")
    use_saved_cfg = os.path.exists(cfg_path)
    if use_saved_cfg:
        args.append(f"--config-file={cfg_path}")

    args += [
        f"--steps={extra_steps}",
        "--eval_freq=500",
        "--save_freq=500",
    ]

    if not use_saved_cfg:
        # Fallback minimal schema so the CLI can parse types correctly
        args += [
            "--policy.type=diffusion",
            "--dataset.repo_id=lerobot/pusht",
            "--env.type=pusht",
        ]

    if push_to_hub and repo:
        args += ["--policy.push_to_hub=true", f"--policy.repo_id={repo}"]
    else:
        args.append("--policy.push_to_hub=false")

    # run_dir and repo_id are user-controlled text and the command goes
    # through a shell, so every argument is quoted.
    cmd = shlex.join(args)

    # 4) Run and return logs
    rc, tail = _run(cmd, log)
    msg = f"Resumed run at: {run_dir}\nResume exited rc={rc}\n\n=== train.log tail ===\n{tail}"
    return msg, run_dir, tail_file(log)
    
def eval_latest(run_dir_text):
    """Evaluate the ``checkpoints/last`` policy of a run with ``lerobot-eval``.

    The run is resolved with current_run_dir(run_dir_text). After the command
    finishes, the last ``{... pc_success ...}`` dict printed in the output
    tail is parsed and saved as ``<run_dir>/eval_latest/metrics.json``. If no
    such dict can be parsed, an existing metrics.json is read instead.

    Args:
        run_dir_text: Optional run folder name or absolute path.

    Returns:
        ``(message, run_dir, eval_log_tail, metrics_text)`` for the UI outputs.
    """
    import ast
    import re
    import shlex

    run_dir = current_run_dir(run_dir_text)
    if not run_dir:
        return "No run found yet. Start a fresh training first.", "", "(no log)", "(no metrics)"
    elog = eval_log_path(run_dir)

    if not has_checkpoint(run_dir):
        return f"No checkpoint in {run_dir}/checkpoints/last/ to evaluate.", run_dir, tail_file(elog), "(no metrics)"

    ckpt = os.path.join(run_dir, "checkpoints", "last", "pretrained_model")
    eval_out_dir = os.path.join(run_dir, "eval_latest")
    os.makedirs(eval_out_dir, exist_ok=True)

    # run_dir may come from a text box and the command goes through a shell,
    # so every argument is quoted.
    cmd = shlex.join([
        "lerobot-eval",
        f"--policy.path={ckpt}",
        "--env.type=pusht",
        "--eval.n_episodes=100",
        "--eval.batch_size=50",
        f"--output_dir={eval_out_dir}",
    ])
    rc, tail = _run(cmd, elog)

    # --- parse the printed dict; fall back to metrics.json on disk ---
    metrics_path = pathlib.Path(eval_out_dir) / "metrics.json"
    metrics_txt = "(metrics.json not found)"

    parsed = None
    matches = re.findall(r"\{[^}]+pc_success[^}]+\}", tail, flags=re.S)
    if matches:
        try:
            printed = ast.literal_eval(matches[-1])
        except (ValueError, SyntaxError, TypeError):
            # Not a plain Python literal (e.g. contains nan or wrapped numbers).
            printed = None
        if isinstance(printed, dict):
            parsed = {
                "success_rate": printed.get("pc_success"),
                "avg_max_overlap": printed.get("avg_max_reward"),
                "avg_sum_reward": printed.get("avg_sum_reward"),
                "eval_s": printed.get("eval_s"),
            }

    if parsed is not None:
        try:
            metrics_path.write_text(json.dumps(parsed, indent=2))
        except (OSError, TypeError, ValueError) as exc:
            # Persisting is best effort; the parsed numbers are still shown.
            print("Could not write metrics.json:", exc)
        metrics_txt = f"Success rate: {parsed['success_rate']}\nAvg max overlap: {parsed['avg_max_overlap']}"
    elif metrics_path.exists():
        # NOTE(review): this file may stem from an earlier evaluation of the same run.
        try:
            stored = json.loads(metrics_path.read_text())
            metrics_txt = f"Success rate: {stored.get('success_rate')}\nAvg max overlap: {stored.get('avg_max_overlap')}"
        except (OSError, ValueError, AttributeError):
            metrics_txt = "(could not parse metrics.json)"

    msg = f"Evaluated run at: {run_dir}\nEval exited rc={rc}\n\n=== eval.log tail ===\n{tail}"
    return msg, run_dir, tail_file(elog), metrics_txt

# ---------- Maintenance (list / delete runs) ----------
def list_runs():
    """Return a tab-separated listing of the ``pusht_*`` entries under RUN_ROOT.

    Each row holds the entry name, its disk usage as reported by ``du -sh``
    ("?" when that cannot be determined) and a checkpoint marker. Returns
    "(no runs)" when RUN_ROOT is missing or has no matching entry.
    """
    root = pathlib.Path(RUN_ROOT)
    if not root.exists():
        return "(no runs)"

    rows = []
    for d in sorted(root.glob("pusht_*")):
        try:
            # Argument list instead of a shell string: names with spaces or
            # shell metacharacters are passed through untouched, and a failing
            # du is no longer masked by the exit status of a pipeline.
            du_out = subprocess.check_output(["du", "-sh", str(d)], text=True)
            size = du_out.split("\t", 1)[0].strip() or "?"
        except (subprocess.CalledProcessError, OSError):
            size = "?"
        ck = "✓" if has_checkpoint(str(d)) else "—"
        rows.append(f"{d.name}\t{size}\tcheckpoint:{ck}")

    if not rows:
        return "(no runs)"
    return "name\tsize\tcheckpoint\n" + "\n".join(rows)

def delete_run_by_name(name: str):
    """Delete one run folder located directly under RUN_ROOT.

    Only the final path component of ``name`` is used. The request is refused
    unless the resolved target is a direct child of RUN_ROOT, which also
    rejects "." and ".." (these would otherwise remove the runs root or its
    parent). The LAST pointer is removed when it referenced the deleted run.

    Args:
        name: Folder name typed by the user, e.g. ``pusht_1234567890``.

    Returns:
        ``(status message, refreshed listing from list_runs())``.
    """
    name = os.path.basename((name or "").strip())
    if not name:
        return "Type a folder like 'pusht_1234567890'.", list_runs()

    target = os.path.join(RUN_ROOT, name)
    runs_root = os.path.realpath(RUN_ROOT)
    # A plain prefix test accepts "<RUN_ROOT>/.." and "<RUN_ROOT>/.";
    # compare the resolved parent directory instead.
    is_direct_child = os.path.dirname(os.path.realpath(target)) == runs_root
    if name in (".", "..") or not is_direct_child:
        return "Refusing to delete outside runs/.", list_runs()
    if not os.path.isdir(target):
        return f"Folder not found: {target}", list_runs()

    shutil.rmtree(target, ignore_errors=True)
    if os.path.exists(target):
        # rmtree swallowed an error; do not claim success.
        return f"Could not fully delete {target}", list_runs()

    if LAST_PTR.exists() and LAST_PTR.read_text().strip() == target:
        LAST_PTR.unlink(missing_ok=True)
    return f"Deleted {target}", list_runs()

def delete_all_runs():
    """Remove every ``pusht_*`` directory under RUN_ROOT and drop the LAST pointer.

    Deletion errors are ignored. Returns ``(status message, refreshed listing)``.
    """
    if not os.path.isdir(RUN_ROOT):
        return "(runs/ missing)", list_runs()

    run_paths = [
        os.path.join(RUN_ROOT, entry)
        for entry in os.listdir(RUN_ROOT)
        if entry.startswith("pusht_")
    ]
    for run_path in run_paths:
        # Only directories are removed; stray files with the prefix are kept.
        if os.path.isdir(run_path):
            shutil.rmtree(run_path, ignore_errors=True)

    LAST_PTR.unlink(missing_ok=True)
    return "Deleted all pusht_* runs.", list_runs()

# ---------- UI ----------
# Gradio UI: widgets first, then the event wiring that connects them to the
# action functions defined above.
with gr.Blocks(title="LeRobot PushT Trainer (Space)") as demo:
    gr.Markdown("# 🤖 LeRobot PushT Trainer\nTrain / Resume / Evaluate. Files persist under `/home/user/app/runs/` (see App Files).")

    with gr.Row():
        repo_id = gr.Textbox(label="Hugging Face Model Repo (optional)", value=DEFAULT_REPO_ID, placeholder="username/repo-name")
        push_to_hub = gr.Checkbox(label="Push checkpoints to Hub", value=PUSH_DEFAULT)

    with gr.Row():
        steps = gr.Slider(200, 20000, value=2000, step=100, label="Training steps (fresh run)")
        batch = gr.Slider(4, 64, value=16, step=2, label="Batch size")

    start_btn = gr.Button("🚀 Start Fresh Training")
    start_out = gr.Textbox(label="Start Output")
    run_dir_view = gr.Textbox(label="Current run directory (auto-filled after start)")
    train_log = gr.Textbox(label="train.log (tail)", lines=20)

    gr.Markdown("### Resume / Evaluate a Specific Run")
    run_dir_text = gr.Textbox(label="Run directory (leave blank to use the latest)")
    with gr.Row():
        extra_steps = gr.Slider(200, 20000, value=2000, step=100, label="Steps to add on resume")
        resume_btn = gr.Button("▶️ Resume from Last Checkpoint")
    resume_out = gr.Textbox(label="Resume Output")
    resume_log = gr.Textbox(label="train.log (tail)", lines=20)

    gr.Markdown("### Evaluate Latest Checkpoint of Selected Run")
    eval_btn = gr.Button("📈 Evaluate Latest")
    eval_out = gr.Textbox(label="Eval Output")
    eval_log = gr.Textbox(label="eval.log (tail)", lines=20)
    metrics_box = gr.Textbox(label="Parsed metrics (if metrics.json exists)")

    gr.Markdown("### Runs on disk")
    list_btn = gr.Button("📂 List runs folder")
    list_out = gr.Textbox(label="runs/ listing", lines=12)

    gr.Markdown("### Maintenance")
    del_name = gr.Textbox(label="Run folder name to delete (e.g., pusht_1699999999)")
    del_one_btn = gr.Button("🗑️ Delete this run")
    del_all_btn = gr.Button("🧹 Delete ALL pusht_* runs")
    # The delete handlers return (status message, listing); the message needs
    # its own box, otherwise it is never shown.
    del_status = gr.Textbox(label="Maintenance status")

    # Wiring
    start_btn.click(start_training, inputs=[steps, batch, push_to_hub, repo_id], outputs=[start_out, run_dir_view, train_log])
    resume_btn.click(resume_training, inputs=[extra_steps, push_to_hub, repo_id, run_dir_text], outputs=[resume_out, run_dir_view, resume_log])
    eval_btn.click(eval_latest, inputs=[run_dir_text], outputs=[eval_out, run_dir_view, eval_log, metrics_box])
    list_btn.click(list_runs, outputs=list_out)
    del_one_btn.click(delete_run_by_name, inputs=del_name, outputs=[del_status, list_out])
    del_all_btn.click(delete_all_runs, outputs=[del_status, list_out])

if __name__ == "__main__":
    demo.launch()