zino36 commited on
Commit
3b83a57
·
verified ·
1 Parent(s): dc206dc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -24
app.py CHANGED
@@ -145,42 +145,67 @@ def start_training(steps, batch_size, push_to_hub, repo_id):
145
  return msg, run_dir, tail_file(log)
146
 
147
  def resume_training(extra_steps, push_to_hub, repo_id, run_dir_text):
148
- # Auto-pick newest run with a checkpoint when left blank.
 
 
 
 
 
149
  run_dir = current_run_dir(run_dir_text)
150
  if not run_dir:
151
- return "No run found on disk. Start a fresh training first (let it pass step 500 to create a checkpoint).", "", "(no log)"
 
 
 
 
152
  log = train_log_path(run_dir)
153
 
 
154
  if not has_checkpoint(run_dir):
155
- return f"Selected run: {run_dir}\nNo checkpoint in {run_dir}/checkpoints/last/ yet — run at least 500 steps once.", run_dir, tail_file(log)
156
-
157
- # Build push flags
158
- push_flags = (f"--policy.push_to_hub=true --policy.repo_id='{repo_id.strip()}'"
159
- if push_to_hub and repo_id.strip() else
160
- "--policy.push_to_hub=false")
 
 
 
 
 
 
161
 
162
- # Prefer resuming with the exact saved config from this run
163
  cfg_path = os.path.join(run_dir, "train_config.json")
164
- base = (
165
- "lerobot-train "
166
- f"--output_dir='{run_dir}' "
167
- "--resume=true "
168
- f"--config_path='{cfg_path}' "
169
- f"--steps={extra_steps} "
170
- "--eval_freq=500 "
171
- "--save_freq=500 "
172
- f"{push_flags}"
173
- )
174
  if os.path.exists(cfg_path):
175
- base += f"--config_path='{cfg_path}' "
 
176
  else:
177
- base += "--policy.type=diffusion --dataset.repo_id=lerobot/pusht --env.type=pusht "
178
-
179
- cmd = base + push_flags
 
 
 
 
 
 
 
 
180
  rc, tail = _run(cmd, log)
181
  msg = f"Resumed run at: {run_dir}\nResume exited rc={rc}\n\n=== train.log tail ===\n{tail}"
182
  return msg, run_dir, tail_file(log)
183
-
184
  def eval_latest(run_dir_text):
185
  run_dir = current_run_dir(run_dir_text)
186
  if not run_dir:
 
145
  return msg, run_dir, tail_file(log)
146
 
147
  def resume_training(extra_steps, push_to_hub, repo_id, run_dir_text):
148
+ """
149
+ Resume training from the newest run (with a checkpoint) if run_dir_text is blank.
150
+ Uses the exact saved train_config.json when available; otherwise falls back to
151
+ minimal CLI schema so the parser knows the policy/dataset/env types.
152
+ """
153
+ # 1) Resolve which run to resume
154
  run_dir = current_run_dir(run_dir_text)
155
  if not run_dir:
156
+ return (
157
+ "No run found on disk. Start a fresh training first (let it pass step 500 to create a checkpoint).",
158
+ "",
159
+ "(no log)",
160
+ )
161
  log = train_log_path(run_dir)
162
 
163
+ # 2) Ensure a checkpoint exists
164
  if not has_checkpoint(run_dir):
165
+ return (
166
+ f"Selected run: {run_dir}\nNo checkpoint in {run_dir}/checkpoints/last/ yet — run at least 500 steps once.",
167
+ run_dir,
168
+ tail_file(log),
169
+ )
170
+
171
+ # 3) Push flags (optional)
172
+ push_flags = (
173
+ f"--policy.push_to_hub=true --policy.repo_id='{repo_id.strip()}'"
174
+ if (push_to_hub and repo_id and repo_id.strip())
175
+ else "--policy.push_to_hub=false"
176
+ )
177
 
178
+ # 4) Prefer resuming with the saved config (LOCAL file -> use --config-file)
179
  cfg_path = os.path.join(run_dir, "train_config.json")
180
+
181
+ parts = [
182
+ "lerobot-train",
183
+ f"--output_dir='{run_dir}'",
184
+ "--resume=true",
185
+ f"--steps={extra_steps}",
186
+ "--eval_freq=500",
187
+ "--save_freq=500",
188
+ ]
189
+
190
  if os.path.exists(cfg_path):
191
+ # IMPORTANT: use --config-file for a local JSON, not --config_path
192
+ parts.insert(3, f"--config-file='{cfg_path}'")
193
  else:
194
+ # Fallback minimal schema so the CLI can parse types correctly
195
+ parts.extend([
196
+ "--policy.type=diffusion",
197
+ "--dataset.repo_id=lerobot/pusht",
198
+ "--env.type=pusht",
199
+ ])
200
+
201
+ parts.append(push_flags)
202
+ cmd = " ".join(parts)
203
+
204
+ # 5) Run and return logs
205
  rc, tail = _run(cmd, log)
206
  msg = f"Resumed run at: {run_dir}\nResume exited rc={rc}\n\n=== train.log tail ===\n{tail}"
207
  return msg, run_dir, tail_file(log)
208
+
209
  def eval_latest(run_dir_text):
210
  run_dir = current_run_dir(run_dir_text)
211
  if not run_dir: