div18 commited on
Commit
46cd5c4
·
1 Parent(s): 871c1ae
Files changed (2) hide show
  1. training/launch_train.py +418 -0
  2. training/openenv_loop.py +14 -5
training/launch_train.py ADDED
@@ -0,0 +1,418 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ launch_train.py — Launch full AntiAtropos training on Hugging Face Jobs.
4
+
5
+ Pushes model checkpoints, metrics dataset, and plots to HF Hub.
6
+ The local server is co-located for zero-latency environment interaction.
7
+ Supports automatic resume from latest Hub checkpoint.
8
+
9
+ Prerequisites:
10
+ 1. pip install "huggingface_hub>=0.25.0"
11
+ 2. huggingface-cli login (or set HF_TOKEN env var)
12
+ 3. HF Pro/Team account (required for GPU jobs)
13
+ 4. The Hub model and dataset repos are auto-created if they don't exist.
14
+ Alternatively create them manually:
15
+ hf repo create <hub-model-repo> --type model
16
+ hf repo create <hub-metrics-dataset> --type dataset
17
+
18
+ Lifecycle:
19
+ ┌─────────────────────────────────────────────┐
20
+ │ HF Job (A10G, ~4h) │
21
+ │ ┌──────────────────────────────────────┐ │
22
+ │ │ uvicorn :8000 ←──→ train.py │ │
23
+ │ │ (simulator) (GPU model) │ │
24
+ │ └──────────┬───────────────────────────┘ │
25
+ │ │ push adapter + plots │
26
+ │ ▼ │
27
+ │ HF Hub Model Repo │
28
+ │ (checkpoint-25, checkpoint-50, ...) │
29
+ │ │ push metrics.jsonl │
30
+ │ ▼ │
31
+ │ HF Hub Metrics Dataset │
32
+ └─────────────────────────────────────────────┘
33
+
34
+ Usage:
35
+ # Quick test (∼10 min):
36
+ python training/launch_train.py \
37
+ --hub-model-repo Keshav051/antiatropos-qlora \
38
+ --hub-metrics-dataset Keshav051/antiatropos-training-metrics \
39
+ --num-iterations 20 --num-episodes 4
40
+
41
+ # Full training (a10g-large ≈ $0.34/hr, ∼2h):
42
+ python training/launch_train.py \
43
+ --hub-model-repo Keshav051/antiatropos-qlora \
44
+ --hub-metrics-dataset Keshav051/antiatropos-training-metrics
45
+
46
+ # Custom flavor / longer timeout:
47
+ python training/launch_train.py \
48
+ --hub-model-repo Keshav051/antiatropos-qlora \
49
+ --hub-metrics-dataset Keshav051/antiatropos-training-metrics \
50
+ --flavor a10g-xlarge --timeout 12h \
51
+ --num-iterations 2000 --num-episodes 24
52
+
53
+ # Resume from latest Hub checkpoint:
54
+ python training/launch_train.py \
55
+ --hub-model-repo Keshav051/antiatropos-qlora \
56
+ --hub-metrics-dataset Keshav051/antiatropos-training-metrics \
57
+ --run-id exp_002
58
+
59
+ # Dry run (prints job command without launching):
60
+ python training/launch_train.py \
61
+ --hub-model-repo Keshav051/antiatropos-qlora \
62
+ --hub-metrics-dataset Keshav051/antiatropos-training-metrics \
63
+ --dry-run
64
+ """
65
+
66
+ from __future__ import annotations
67
+
68
+ import argparse
69
+ import os
70
+ import sys
71
+ from datetime import datetime
72
+ from pathlib import Path
73
+ from typing import Optional
74
+
75
+ TRAINING_DIR = Path(__file__).resolve().parent
76
+
77
+ DOCKER_IMAGE = "pytorch/pytorch:2.10.0-cuda12.6-cudnn9-devel"
78
+
79
+ DEFAULT_NUM_ITERATIONS = 500
80
+ DEFAULT_NUM_EPISODES = 12
81
+ DEFAULT_MAX_STEPS = 40
82
+ DEFAULT_EVAL_INTERVAL = 50
83
+ DEFAULT_CHECKPOINT_INTERVAL = 25
84
+ DEFAULT_PLOT_INTERVAL = 25
85
+
86
+
87
+ def build_job_command() -> str:
88
+ """Build the shell script that runs INSIDE the HF Job container.
89
+
90
+ Starts the AntiAtropos FastAPI server locally (eliminating HTTP latency)
91
+ then runs training against localhost:8000 with Hub persistence.
92
+ """
93
+ return (
94
+ "set -e\n"
95
+ "\n"
96
+ "echo '[bootstrap] Installing git...'\n"
97
+ "apt-get update -qq && apt-get install -y -qq git netcat-openbsd > /dev/null 2>&1\n"
98
+ "\n"
99
+ "echo '[bootstrap] Cloning $REPO...'\n"
100
+ "mkdir -p /workspace\n"
101
+ "git clone --depth 1 https://hf:${HF_TOKEN}@huggingface.co/$REPO /workspace/AntiAtropos\n"
102
+ "cd /workspace/AntiAtropos\n"
103
+ "\n"
104
+ "echo '[bootstrap] Installing dependencies...'\n"
105
+ "pip install --break-system-packages --no-deps torchvision -q\n"
106
+ "pip install --break-system-packages -r training/requirements.txt -q\n"
107
+ "\n"
108
+ "echo '[bootstrap] Starting local AntiAtropos server (simulated mode)...'\n"
109
+ "export ANTIATROPOS_ENV_MODE=simulated\n"
110
+ "uvicorn server.app:app --host 127.0.0.1 --port 8000 &\n"
111
+ "SERVER_PID=$!\n"
112
+ "\n"
113
+ "# Wait for server to be ready\n"
114
+ "echo '[bootstrap] Waiting for server...'\n"
115
+ "for i in $(seq 1 30); do\n"
116
+ " if curl -s http://127.0.0.1:8000/health > /dev/null 2>&1; then\n"
117
+ " echo '[bootstrap] Server ready.'\n"
118
+ " break\n"
119
+ " fi\n"
120
+ " sleep 1\n"
121
+ "done\n"
122
+ "\n"
123
+ "echo '[bootstrap] Launching training (local server, Hub persistence)...'\n"
124
+ "ANTIATROPOS_HUB_MODEL_REPO=$HUB_MODEL_REPO "
125
+ "ANTIATROPOS_HUB_METRICS_DATASET=$HUB_METRICS_DATASET "
126
+ "ANTIATROPOS_ENV_URL=http://localhost:8000 "
127
+ "python training/train.py "
128
+ "--run-id $RUN_ID "
129
+ "--num-iterations $NUM_ITERATIONS "
130
+ "--num-episodes $NUM_EPISODES "
131
+ "--max-steps $MAX_STEPS "
132
+ "--eval-interval $EVAL_INTERVAL "
133
+ "--checkpoint-interval $CHECKPOINT_INTERVAL "
134
+ "--plot-interval $PLOT_INTERVAL\n"
135
+ "TRAIN_EXIT=$?\n"
136
+ "\n"
137
+ "echo '[bootstrap] Stopping server...'\n"
138
+ "kill $SERVER_PID 2>/dev/null || true\n"
139
+ "wait $SERVER_PID 2>/dev/null || true\n"
140
+ "\n"
141
+ "exit $TRAIN_EXIT"
142
+ )
143
+
144
+
145
+ def ensure_hub_repos(
146
+ hub_model_repo: str,
147
+ hub_metrics_dataset: str,
148
+ hf_token: Optional[str],
149
+ ) -> None:
150
+ """Check if Hub repos exist; create them automatically if not."""
151
+ if not hf_token:
152
+ print(" [hub] No HF_TOKEN available, skipping repo check")
153
+ return
154
+
155
+ try:
156
+ from huggingface_hub import HfApi
157
+
158
+ api = HfApi()
159
+
160
+ for repo_id, repo_type in [
161
+ (hub_model_repo, "model"),
162
+ (hub_metrics_dataset, "dataset"),
163
+ ]:
164
+ base_url = (
165
+ "https://huggingface.co"
166
+ if repo_type == "model"
167
+ else "https://huggingface.co/datasets"
168
+ )
169
+ try:
170
+ info = api.repo_info(repo_id=repo_id, repo_type=repo_type)
171
+ print(f" [hub] Repo OK: {base_url}/{repo_id}")
172
+ except Exception:
173
+ print(f" [hub] Creating repo: {repo_id} ({repo_type})...")
174
+ api.create_repo(
175
+ repo_id=repo_id,
176
+ repo_type=repo_type,
177
+ private=True,
178
+ exist_ok=True,
179
+ )
180
+ print(f" [hub] Created: {base_url}/{repo_id}")
181
+ except Exception as e:
182
+ print(f"\n [hub] WARNING: Could not verify/create Hub repos: {e}")
183
+ print(" [hub] Create them manually:")
184
+ print(f" hf repo create {hub_model_repo} --type model")
185
+ print(f" hf repo create {hub_metrics_dataset} --type dataset")
186
+ print(f" Then visit:")
187
+ print(f" https://huggingface.co/{hub_model_repo}")
188
+ print(f" https://huggingface.co/datasets/{hub_metrics_dataset}")
189
+
190
+
191
+ def main() -> None:
192
+ parser = argparse.ArgumentParser(
193
+ description="AntiAtropos Full Training — HF Jobs with Hub persistence"
194
+ )
195
+ parser.add_argument(
196
+ "--flavor",
197
+ default="a10g-large",
198
+ help="GPU flavor (default: a10g-large). Run 'hf jobs hardware' for full list.",
199
+ )
200
+ parser.add_argument(
201
+ "--timeout",
202
+ default="4h",
203
+ help="Job timeout (default: 4h). Examples: 30m, 2h, 7200",
204
+ )
205
+ parser.add_argument(
206
+ "--repo",
207
+ default="Keshav051/AntiAtropos",
208
+ help="HF repo to clone (project source code)",
209
+ )
210
+ parser.add_argument(
211
+ "--hub-model-repo",
212
+ required=True,
213
+ help="HF Hub model repo for checkpoints, final adapter, and plots "
214
+ "(e.g. Keshav051/antiatropos-qlora)",
215
+ )
216
+ parser.add_argument(
217
+ "--hub-metrics-dataset",
218
+ required=True,
219
+ help="HF Hub dataset repo for training metrics.jsonl "
220
+ "(e.g. Keshav051/antiatropos-training-metrics)",
221
+ )
222
+ parser.add_argument(
223
+ "--run-id",
224
+ default=None,
225
+ help="Run identifier (default: train_YYYYMMDD_HHMMSS). "
226
+ "Use same ID to resume a previous run.",
227
+ )
228
+ parser.add_argument(
229
+ "--num-iterations",
230
+ type=int,
231
+ default=DEFAULT_NUM_ITERATIONS,
232
+ help=f"Training iterations (default: {DEFAULT_NUM_ITERATIONS})",
233
+ )
234
+ parser.add_argument(
235
+ "--num-episodes",
236
+ type=int,
237
+ default=DEFAULT_NUM_EPISODES,
238
+ help=f"Episodes per iteration (default: {DEFAULT_NUM_EPISODES})",
239
+ )
240
+ parser.add_argument(
241
+ "--max-steps",
242
+ type=int,
243
+ default=DEFAULT_MAX_STEPS,
244
+ help=f"Max steps per episode (default: {DEFAULT_MAX_STEPS})",
245
+ )
246
+ parser.add_argument(
247
+ "--eval-interval",
248
+ type=int,
249
+ default=DEFAULT_EVAL_INTERVAL,
250
+ help=f"Evaluate every N iterations (default: {DEFAULT_EVAL_INTERVAL})",
251
+ )
252
+ parser.add_argument(
253
+ "--checkpoint-interval",
254
+ type=int,
255
+ default=DEFAULT_CHECKPOINT_INTERVAL,
256
+ help=f"Checkpoint every N iterations (default: {DEFAULT_CHECKPOINT_INTERVAL})",
257
+ )
258
+ parser.add_argument(
259
+ "--plot-interval",
260
+ type=int,
261
+ default=DEFAULT_PLOT_INTERVAL,
262
+ help=f"Plot every N iterations (default: {DEFAULT_PLOT_INTERVAL})",
263
+ )
264
+ parser.add_argument(
265
+ "--dry-run",
266
+ action="store_true",
267
+ help="Print config and exit without launching",
268
+ )
269
+ parser.add_argument(
270
+ "--no-create-repos",
271
+ action="store_true",
272
+ help="Skip automatic Hub repo creation",
273
+ )
274
+ args = parser.parse_args()
275
+
276
+ run_id = args.run_id or f"train_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
277
+
278
+ # ---- Print summary ----
279
+ print("=" * 60)
280
+ print(" ANTIATROPOS FULL TRAINING — HF Jobs")
281
+ print("=" * 60)
282
+ print(f" Image: {DOCKER_IMAGE}")
283
+ print(f" Flavor: {args.flavor}")
284
+ print(f" Timeout: {args.timeout}")
285
+ print(f" Code repo: {args.repo}")
286
+ print(f" Hub model repo: {args.hub_model_repo}")
287
+ print(f" Hub metrics dataset: {args.hub_metrics_dataset}")
288
+ print(f" Run ID: {run_id}")
289
+ print(f" Iterations: {args.num_iterations}")
290
+ print(f" Episodes/iter: {args.num_episodes}")
291
+ print(f" Steps/episode: {args.max_steps}")
292
+ print(f" Eval interval: {args.eval_interval}")
293
+ print(f" Checkpoint interval: {args.checkpoint_interval}")
294
+ print(f" Plot interval: {args.plot_interval}")
295
+ print("=" * 60)
296
+
297
+ # Estimated time
298
+ est_hours = (
299
+ args.num_iterations
300
+ * args.num_episodes
301
+ * args.max_steps
302
+ * 0.04 # ~40ms per step with parallel episodes
303
+ / 3600
304
+ )
305
+ print(f" Est. runtime: ~{est_hours:.1f}h (at 40ms/step)")
306
+ print(f" Est. cost: ~${est_hours * 0.34:.2f} (a10g-large at $0.34/hr)")
307
+ print("=" * 60)
308
+
309
+ if args.dry_run:
310
+ print("\n[DRY RUN] Job command:")
311
+ print(build_job_command())
312
+ print("\n[DRY RUN] To launch manually inside the container:")
313
+ print(
314
+ " python training/train.py \\\n"
315
+ f" --run-id {run_id} \\\n"
316
+ f" --num-iterations {args.num_iterations} \\\n"
317
+ f" --num-episodes {args.num_episodes} \\\n"
318
+ f" --max-steps {args.max_steps} \\\n"
319
+ f" --eval-interval {args.eval_interval} \\\n"
320
+ f" --checkpoint-interval {args.checkpoint_interval} \\\n"
321
+ f" --plot-interval {args.plot_interval}"
322
+ )
323
+ return
324
+
325
+ # ---- Resolve HF_TOKEN ----
326
+ hf_token: Optional[str] = (
327
+ os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
328
+ )
329
+ if not hf_token:
330
+ token_path = os.path.expanduser("~/.cache/huggingface/token")
331
+ if os.path.isfile(token_path):
332
+ with open(token_path) as f:
333
+ hf_token = f.read().strip()
334
+
335
+ secrets_dict: dict = {}
336
+ if not hf_token:
337
+ print("\nWARNING: HF_TOKEN not found. Job will FAIL to push to Hub.")
338
+ print(" Set HF_TOKEN env var or run: huggingface-cli login")
339
+ else:
340
+ secrets_dict = {"HF_TOKEN": hf_token}
341
+
342
+ # ---- Ensure Hub repos exist ----
343
+ if not args.no_create_repos and hf_token:
344
+ ensure_hub_repos(
345
+ args.hub_model_repo, args.hub_metrics_dataset, hf_token
346
+ )
347
+
348
+ # ---- Launch via run_job ----
349
+ try:
350
+ from huggingface_hub import run_job
351
+ except ImportError:
352
+ print("\nERROR: huggingface_hub too old. Run:")
353
+ print(" pip install 'huggingface_hub>=0.25.0'")
354
+ sys.exit(1)
355
+
356
+ job_command = build_job_command().replace("\r", "")
357
+
358
+ print("\nLaunching job...")
359
+ job = run_job(
360
+ image=DOCKER_IMAGE,
361
+ command=["bash", "-c", job_command],
362
+ flavor=args.flavor,
363
+ timeout=args.timeout,
364
+ secrets=secrets_dict,
365
+ env={
366
+ "REPO": args.repo,
367
+ "RUN_ID": run_id,
368
+ "HUB_MODEL_REPO": args.hub_model_repo,
369
+ "HUB_METRICS_DATASET": args.hub_metrics_dataset,
370
+ "NUM_ITERATIONS": str(args.num_iterations),
371
+ "NUM_EPISODES": str(args.num_episodes),
372
+ "MAX_STEPS": str(args.max_steps),
373
+ "EVAL_INTERVAL": str(args.eval_interval),
374
+ "CHECKPOINT_INTERVAL": str(args.checkpoint_interval),
375
+ "PLOT_INTERVAL": str(args.plot_interval),
376
+ },
377
+ )
378
+
379
+ print(f"\nJob launched! ID: {job.id}")
380
+ print(f" Monitor: {job.url}")
381
+ print(f" Logs: hf jobs logs {job.id}")
382
+ print(f" Cancel: hf jobs cancel {job.id}")
383
+
384
+ # ---- Stream logs ----
385
+ print("\nStreaming logs (Ctrl+C to stop watching)...\n")
386
+ try:
387
+ from huggingface_hub import fetch_job_logs, inspect_job
388
+ import time
389
+
390
+ seen = 0
391
+ while True:
392
+ status: Optional[str] = None
393
+ try:
394
+ info = inspect_job(job_id=job.id)
395
+ status = info.status.stage
396
+ except Exception:
397
+ pass
398
+
399
+ try:
400
+ logs = list(fetch_job_logs(job_id=job.id))
401
+ for line in logs[seen:]:
402
+ print(line, end="" if line.endswith("\n") else "\n")
403
+ seen = len(logs)
404
+ except Exception:
405
+ pass
406
+
407
+ if status in ("COMPLETED", "ERROR", "CANCELED"):
408
+ print(f"\nJob finished with status: {status}")
409
+ break
410
+ time.sleep(5)
411
+ except KeyboardInterrupt:
412
+ print("\n\nStopped watching logs. Job still running remotely.")
413
+ print(f" Check status: hf jobs inspect {job.id}")
414
+ print(f" Resume logs: hf jobs logs {job.id}")
415
+
416
+
417
+ if __name__ == "__main__":
418
+ main()
training/openenv_loop.py CHANGED
@@ -260,15 +260,21 @@ def repair_action(action_type: str, target_node_id: str, parameter: float) -> Tu
260
 
261
 
262
  def parse_action(text: str) -> ParsedAction:
263
- """Extract action from model output text."""
 
 
 
 
264
  try:
265
  start = text.find("{")
266
- end = text.rfind("}")
267
- if start == -1 or end == -1 or end < start:
268
  return ParsedAction("NO_OP", "node-0", 0.0, text,
269
  False, "no JSON found")
270
 
271
- obj = json.loads(text[start:end + 1])
 
 
 
272
  at = str(obj.get("action_type", "")).upper()
273
  nid = str(obj.get("target_node_id", "") or "node-0")
274
  param = float(obj.get("parameter") or 0.0)
@@ -281,7 +287,10 @@ def parse_action(text: str) -> ParsedAction:
281
  False, f"invalid target_node_id: {nid}")
282
 
283
  at, nid, param, repair_note = repair_action(at, nid, param)
284
- return ParsedAction(at, nid, param, text, True, repair_note)
 
 
 
285
  except Exception as e:
286
  return ParsedAction("NO_OP", "node-0", 0.0, text, False, str(e))
287
 
 
260
 
261
 
262
  def parse_action(text: str) -> ParsedAction:
263
+ """Extract action from model output text.
264
+
265
+ Uses raw_decode so that extra content after the first JSON object
266
+ (e.g. duplicate actions, trailing text) is silently ignored.
267
+ """
268
  try:
269
  start = text.find("{")
270
+ if start == -1:
 
271
  return ParsedAction("NO_OP", "node-0", 0.0, text,
272
  False, "no JSON found")
273
 
274
+ # Decode only the first complete JSON value (ignore extra data)
275
+ decoder = json.JSONDecoder()
276
+ obj, end_pos = decoder.raw_decode(text, start)
277
+
278
  at = str(obj.get("action_type", "")).upper()
279
  nid = str(obj.get("target_node_id", "") or "node-0")
280
  param = float(obj.get("parameter") or 0.0)
 
287
  False, f"invalid target_node_id: {nid}")
288
 
289
  at, nid, param, repair_note = repair_action(at, nid, param)
290
+ extracted = text[start:end_pos]
291
+ return ParsedAction(at, nid, param, extracted, True, repair_note)
292
+ except json.JSONDecodeError as e:
293
+ return ParsedAction("NO_OP", "node-0", 0.0, text, False, str(e))
294
  except Exception as e:
295
  return ParsedAction("NO_OP", "node-0", 0.0, text, False, str(e))
296