zino36 committed on
Commit
ad5a998
·
verified ·
1 Parent(s): 6a460fa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -27
app.py CHANGED
@@ -2,14 +2,14 @@ import os, subprocess, json, pathlib, time, shutil
2
  import gradio as gr
3
 
4
  # ---------- CONSTANTS (visible in App Files) ----------
5
- RUN_ROOT = "/home/user/app/runs" # where all runs live
6
- LOG_ROOT = "/home/user/app/logs" # global logs (so we don't pre-create run dirs)
7
  LAST_PTR = pathlib.Path(RUN_ROOT) / "LAST" # remembers most recent run path
8
  os.makedirs(RUN_ROOT, exist_ok=True)
9
  os.makedirs(LOG_ROOT, exist_ok=True)
10
 
11
  # ---------- ENV / HUB ----------
12
- DEFAULT_REPO_ID = os.environ.get("REPO_ID", "") # e.g. "zino36/lerobot-pusht-colab"
13
  PUSH_DEFAULT = os.environ.get("PUSH_TO_HUB", "true").lower() in {"1","true","yes"}
14
  HF_TOKEN = os.environ.get("HF_TOKEN")
15
 
@@ -43,8 +43,12 @@ def tail_file(path: str, n=200):
43
  return "".join(lines[-n:])
44
 
45
  # ---------- RUN DIR HELPERS ----------
 
 
 
 
46
  def newest_run(prefer_checkpoint: bool = True) -> str:
47
- """Return the newest pusht_* folder. If prefer_checkpoint=True, pick the newest that has checkpoints/last/."""
48
  root = pathlib.Path(RUN_ROOT)
49
  if not root.exists():
50
  return ""
@@ -56,9 +60,9 @@ def newest_run(prefer_checkpoint: bool = True) -> str:
56
  if has_checkpoint(str(r)):
57
  return str(r)
58
  return str(runs[0])
59
-
60
  def new_run_dir():
61
- """Return a unique run dir path WITHOUT creating it (so LeRobot can create it)."""
62
  base = pathlib.Path(RUN_ROOT) / f"pusht_{int(time.time())}"
63
  d = base
64
  i = 1
@@ -72,15 +76,14 @@ def current_run_dir(user_override: str | None):
72
  """
73
  Resolve which run to use:
74
  - If user typed something, accept folder name or full path.
75
- - Else try LAST pointer.
76
- - Else pick newest run with a checkpoint.
77
- - Else pick newest run (even without checkpoint).
78
  - Else return "" (none).
79
  """
80
  # A) explicit user input
81
  if user_override and user_override.strip():
82
  p = user_override.strip()
83
- # allow just "pusht_123..." as well as absolute path
84
  if not p.startswith("/"):
85
  p = str(pathlib.Path(RUN_ROOT) / p)
86
  return p
@@ -103,16 +106,10 @@ def current_run_dir(user_override: str | None):
103
  LAST_PTR.write_text(p)
104
  return p
105
 
106
- # E) nothing found
107
  return ""
108
-
109
-
110
- def has_checkpoint(run_dir: str):
111
- """We consider a checkpoint present once checkpoints/last/ exists (first save is at step 500)."""
112
- return os.path.isdir(os.path.join(run_dir, "checkpoints", "last"))
113
 
114
  def train_log_path_for_new(run_dir: str):
115
- """Write fresh-run logs to global LOG_ROOT so we don't pre-create run_dir."""
116
  name = pathlib.Path(run_dir).name
117
  return os.path.join(LOG_ROOT, f"{name}.train.log")
118
 
@@ -148,6 +145,7 @@ def start_training(steps, batch_size, push_to_hub, repo_id):
148
  return msg, run_dir, tail_file(log)
149
 
150
  def resume_training(extra_steps, push_to_hub, repo_id, run_dir_text):
 
151
  run_dir = current_run_dir(run_dir_text)
152
  if not run_dir:
153
  return "No run found on disk. Start a fresh training first (let it pass step 500 to create a checkpoint).", "", "(no log)"
@@ -155,20 +153,29 @@ def resume_training(extra_steps, push_to_hub, repo_id, run_dir_text):
155
 
156
  if not has_checkpoint(run_dir):
157
  return f"Selected run: {run_dir}\nNo checkpoint in {run_dir}/checkpoints/last/ yet — run at least 500 steps once.", run_dir, tail_file(log)
158
-
 
159
  push_flags = (f"--policy.push_to_hub=true --policy.repo_id='{repo_id.strip()}'"
160
  if push_to_hub and repo_id.strip() else
161
  "--policy.push_to_hub=false")
162
 
163
- cmd = (
 
 
164
  "lerobot-train "
165
  f"--output_dir='{run_dir}' "
166
  "--resume=true "
167
  f"--steps={extra_steps} "
168
  "--eval_freq=500 "
169
  "--save_freq=500 "
170
- f"{push_flags}"
171
  )
 
 
 
 
 
 
 
172
  rc, tail = _run(cmd, log)
173
  msg = f"Resumed run at: {run_dir}\nResume exited rc={rc}\n\n=== train.log tail ===\n{tail}"
174
  return msg, run_dir, tail_file(log)
@@ -196,12 +203,10 @@ def eval_latest(run_dir_text):
196
  )
197
  rc, tail = _run(cmd, elog)
198
 
199
- # --- optional patch: parse the printed dict and write metrics.json ---
200
- import re, ast, json, pathlib
201
  metrics_txt = "(metrics.json not found)"
202
  p = pathlib.Path(eval_out_dir) / "metrics.json"
203
-
204
- # Try to parse the last dict-like summary from the log tail
205
  m = re.findall(r"\{[^}]+pc_success[^}]+\}", tail, flags=re.S)
206
  if m:
207
  try:
@@ -217,7 +222,6 @@ def eval_latest(run_dir_text):
217
  except Exception:
218
  pass
219
  elif p.exists():
220
- # Fallback: if a previous metrics.json exists, show it
221
  try:
222
  d = json.loads(p.read_text())
223
  metrics_txt = f"Success rate: {d.get('success_rate')}\nAvg max overlap: {d.get('avg_max_overlap')}"
@@ -227,7 +231,7 @@ def eval_latest(run_dir_text):
227
 
228
  msg = f"Evaluated run at: {run_dir}\nEval exited rc={rc}\n\n=== eval.log tail ===\n{tail}"
229
  return msg, run_dir, tail_file(elog), metrics_txt
230
-
231
  # ---------- Maintenance (list / delete runs) ----------
232
  def list_runs():
233
  root = pathlib.Path(RUN_ROOT)
@@ -255,7 +259,6 @@ def delete_run_by_name(name: str):
255
  if not os.path.isdir(target):
256
  return f"Folder not found: {target}", list_runs()
257
  shutil.rmtree(target, ignore_errors=True)
258
- # clear LAST if it pointed here
259
  if LAST_PTR.exists() and LAST_PTR.read_text().strip() == target:
260
  LAST_PTR.unlink(missing_ok=True)
261
  return f"Deleted {target}", list_runs()
 
2
  import gradio as gr
3
 
4
  # ---------- CONSTANTS (visible in App Files) ----------
5
+ RUN_ROOT = "/home/user/app/runs" # where all runs live
6
+ LOG_ROOT = "/home/user/app/logs" # global logs (avoid pre-creating run dirs)
7
  LAST_PTR = pathlib.Path(RUN_ROOT) / "LAST" # remembers most recent run path
8
  os.makedirs(RUN_ROOT, exist_ok=True)
9
  os.makedirs(LOG_ROOT, exist_ok=True)
10
 
11
  # ---------- ENV / HUB ----------
12
+ DEFAULT_REPO_ID = os.environ.get("REPO_ID", "") # e.g. "zino36/lerobot-pusht-colab"
13
  PUSH_DEFAULT = os.environ.get("PUSH_TO_HUB", "true").lower() in {"1","true","yes"}
14
  HF_TOKEN = os.environ.get("HF_TOKEN")
15
 
 
43
  return "".join(lines[-n:])
44
 
45
  # ---------- RUN DIR HELPERS ----------
46
def has_checkpoint(run_dir: str) -> bool:
    """Return True when ``<run_dir>/checkpoints/last`` exists as a directory.

    A checkpoint is considered present once ``checkpoints/last/`` exists
    (the first save happens around step 500).

    Args:
        run_dir: Path of the run directory to inspect. An empty or missing
            value means "no run selected" and always yields False.

    Returns:
        True if the ``checkpoints/last`` directory exists under ``run_dir``,
        otherwise False.
    """
    # Guard against "" / None: os.path.join("", "checkpoints", "last") is a
    # relative path, so the check would silently run against the current
    # working directory instead of a real run folder.
    if not run_dir:
        return False
    return os.path.isdir(os.path.join(run_dir, "checkpoints", "last"))
49
+
50
  def newest_run(prefer_checkpoint: bool = True) -> str:
51
+ """Pick newest pusht_* folder; optionally require a checkpoint."""
52
  root = pathlib.Path(RUN_ROOT)
53
  if not root.exists():
54
  return ""
 
60
  if has_checkpoint(str(r)):
61
  return str(r)
62
  return str(runs[0])
63
+
64
  def new_run_dir():
65
+ """Return a unique run dir path WITHOUT creating it (LeRobot will create it)."""
66
  base = pathlib.Path(RUN_ROOT) / f"pusht_{int(time.time())}"
67
  d = base
68
  i = 1
 
76
  """
77
  Resolve which run to use:
78
  - If user typed something, accept folder name or full path.
79
+ - Else try LAST pointer (if valid).
80
+ - Else newest run WITH checkpoint.
81
+ - Else newest run (even without checkpoint).
82
  - Else return "" (none).
83
  """
84
  # A) explicit user input
85
  if user_override and user_override.strip():
86
  p = user_override.strip()
 
87
  if not p.startswith("/"):
88
  p = str(pathlib.Path(RUN_ROOT) / p)
89
  return p
 
106
  LAST_PTR.write_text(p)
107
  return p
108
 
 
109
  return ""
 
 
 
 
 
110
 
111
  def train_log_path_for_new(run_dir: str):
112
+ """Fresh-run logs go to global LOG_ROOT so we don't pre-create run_dir."""
113
  name = pathlib.Path(run_dir).name
114
  return os.path.join(LOG_ROOT, f"{name}.train.log")
115
 
 
145
  return msg, run_dir, tail_file(log)
146
 
147
  def resume_training(extra_steps, push_to_hub, repo_id, run_dir_text):
148
+ # Auto-pick newest run with a checkpoint when left blank.
149
  run_dir = current_run_dir(run_dir_text)
150
  if not run_dir:
151
  return "No run found on disk. Start a fresh training first (let it pass step 500 to create a checkpoint).", "", "(no log)"
 
153
 
154
  if not has_checkpoint(run_dir):
155
  return f"Selected run: {run_dir}\nNo checkpoint in {run_dir}/checkpoints/last/ yet — run at least 500 steps once.", run_dir, tail_file(log)
156
+
157
+ # Build push flags
158
  push_flags = (f"--policy.push_to_hub=true --policy.repo_id='{repo_id.strip()}'"
159
  if push_to_hub and repo_id.strip() else
160
  "--policy.push_to_hub=false")
161
 
162
+ # Prefer resuming with the exact saved config from this run
163
+ cfg_path = os.path.join(run_dir, "train_config.json")
164
+ base = (
165
  "lerobot-train "
166
  f"--output_dir='{run_dir}' "
167
  "--resume=true "
168
  f"--steps={extra_steps} "
169
  "--eval_freq=500 "
170
  "--save_freq=500 "
 
171
  )
172
+ if os.path.exists(cfg_path):
173
+ base += f"--config_path='{cfg_path}' "
174
+ else:
175
+ # Fallback to minimal schema (prevents policy parsing errors)
176
+ base += "--policy.type=diffusion --dataset.repo_id=lerobot/pusht --env.type=pusht "
177
+
178
+ cmd = base + push_flags
179
  rc, tail = _run(cmd, log)
180
  msg = f"Resumed run at: {run_dir}\nResume exited rc={rc}\n\n=== train.log tail ===\n{tail}"
181
  return msg, run_dir, tail_file(log)
 
203
  )
204
  rc, tail = _run(cmd, elog)
205
 
206
+ # --- parse printed dict and write metrics.json if missing ---
207
+ import re, ast
208
  metrics_txt = "(metrics.json not found)"
209
  p = pathlib.Path(eval_out_dir) / "metrics.json"
 
 
210
  m = re.findall(r"\{[^}]+pc_success[^}]+\}", tail, flags=re.S)
211
  if m:
212
  try:
 
222
  except Exception:
223
  pass
224
  elif p.exists():
 
225
  try:
226
  d = json.loads(p.read_text())
227
  metrics_txt = f"Success rate: {d.get('success_rate')}\nAvg max overlap: {d.get('avg_max_overlap')}"
 
231
 
232
  msg = f"Evaluated run at: {run_dir}\nEval exited rc={rc}\n\n=== eval.log tail ===\n{tail}"
233
  return msg, run_dir, tail_file(elog), metrics_txt
234
+
235
  # ---------- Maintenance (list / delete runs) ----------
236
  def list_runs():
237
  root = pathlib.Path(RUN_ROOT)
 
259
  if not os.path.isdir(target):
260
  return f"Folder not found: {target}", list_runs()
261
  shutil.rmtree(target, ignore_errors=True)
 
262
  if LAST_PTR.exists() and LAST_PTR.read_text().strip() == target:
263
  LAST_PTR.unlink(missing_ok=True)
264
  return f"Deleted {target}", list_runs()