icarus112 commited on
Commit
861dd6c
·
verified ·
1 Parent(s): d39539e

Upload entrypoint.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. entrypoint.py +291 -0
entrypoint.py ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ from __future__ import annotations
3
+
4
+ import json
5
+ import os
6
+ import subprocess
7
+ import sys
8
+ import time
9
+ from http.server import BaseHTTPRequestHandler, HTTPServer
10
+ from pathlib import Path
11
+ from threading import Thread
12
+
13
+
14
+ # =============================================================================
15
+ # EARLY CUDA FABRIC MANAGER KICK (before ANY CUDA-touching imports)
16
+ # =============================================================================
17
+ # On H200 hosts, cudaGetDeviceCount can return Error 802 "system not yet
18
+ # initialized" on first use, because nvidia-fabricmanager on the host
19
+ # synchronizes with the container's first driver call. Once any NVML/CUDA
20
+ # call succeeds once (even just nvidia-smi), the fabric is up for the rest
21
+ # of the container lifetime.
22
+ #
23
+ # Our previous approach (wait in a subprocess before training) didn't work
24
+ # because the "initialization failed" state persisted across calls in the
25
+ # same container. The real fix: kick the driver exactly once with
26
+ # nvidia-smi, which is what successfully-working baseline containers do
27
+ # implicitly via their first torch.cuda call.
28
+ #
29
+ # Must happen BEFORE `import torch` (because any import that eagerly calls
30
+ # cudaGetDeviceCount will cache the Error 802 state).
31
+ def _early_cuda_kick() -> None:
32
+ deadline = time.time() + 120.0
33
+ attempt = 0
34
+ while time.time() < deadline:
35
+ attempt += 1
36
+ r = subprocess.run(['nvidia-smi'], capture_output=True, text=True, timeout=30)
37
+ if r.returncode == 0 and 'H200' in (r.stdout or '') or 'H100' in (r.stdout or '') \
38
+ or 'A100' in (r.stdout or '') or r.returncode == 0:
39
+ print(f'[boot] nvidia-smi OK on attempt {attempt}', flush=True)
40
+ break
41
+ print(f'[boot] nvidia-smi attempt {attempt} rc={r.returncode} stderr={(r.stderr or "")[:120]}',
42
+ flush=True)
43
+ time.sleep(2)
44
+ # After nvidia-smi, probe torch in a subprocess so any latent error state
45
+ # doesn't leak into the main process's CUDA context.
46
+ probe = 'import torch; import sys; sys.exit(0 if torch.cuda.is_available() else 1)'
47
+ torch_deadline = time.time() + 120.0
48
+ t_attempt = 0
49
+ while time.time() < torch_deadline:
50
+ t_attempt += 1
51
+ r = subprocess.run([sys.executable, '-c', probe], capture_output=True, text=True, timeout=60)
52
+ if r.returncode == 0:
53
+ print(f'[boot] torch.cuda.is_available() = True after {t_attempt} probe(s)', flush=True)
54
+ return
55
+ if t_attempt == 1:
56
+ print(f'[boot] torch cuda probe {t_attempt}: {(r.stderr or "")[:200]}', flush=True)
57
+ time.sleep(2)
58
+ print('[boot] WARNING: torch.cuda never became ready — training will likely fail', flush=True)
59
+
60
+
61
+ _early_cuda_kick()
62
+
63
+ # Hydrate triton compilation cache from HF Hub before any triton/mamba_ssm import.
64
+ # triton_cache_setup.py is copied next to this file by the job bash command.
65
+ try:
66
+ import triton_cache_setup as _tcs
67
+ _tcs.setup()
68
+ except ImportError:
69
+ print('[boot] triton_cache_setup not found; skipping cache hydrate', flush=True)
70
+
71
+ from huggingface_hub import HfApi # noqa: E402 (import after cuda kick)
72
+
73
+ REPO_ROOT = Path('/workspace/feather')
74
+ CACHE_ROOT = Path.home() / '.cache' / 'autoresearch'
75
+ LOG_FILE = REPO_ROOT / 'run_domain_expanded.log'
76
+ JOB_ID = os.environ.get('JOB_ID', 'local-job')
77
+ OUTPUT_REPO = os.environ.get('HF_REPO_ID', 'icarus112/feather-pretrain-checkpoints')
78
+ TOKEN = os.environ.get('HF_TOKEN')
79
+ RUNTIME_MODE = os.environ.get('FEATHER_RUNTIME_MODE', 'space')
80
+ APP_PORT = int(os.environ.get('PORT', '7860'))
81
+
82
+
83
+ class _HealthHandler(BaseHTTPRequestHandler):
84
+ def do_GET(self):
85
+ if self.path in ('/', '/health', '/healthz', '/ready'):
86
+ payload = {
87
+ 'status': 'ok',
88
+ 'mode': RUNTIME_MODE,
89
+ 'job_id': JOB_ID,
90
+ }
91
+ body = json.dumps(payload).encode('utf-8')
92
+ self.send_response(200)
93
+ self.send_header('Content-Type', 'application/json')
94
+ self.send_header('Content-Length', str(len(body)))
95
+ self.end_headers()
96
+ self.wfile.write(body)
97
+ return
98
+ self.send_response(404)
99
+ self.end_headers()
100
+
101
+ def log_message(self, format, *args):
102
+ return
103
+
104
+
105
+ def _start_health_server() -> HTTPServer:
106
+ server = HTTPServer(('0.0.0.0', APP_PORT), _HealthHandler)
107
+ thread = Thread(target=server.serve_forever, daemon=True)
108
+ thread.start()
109
+ print(f'[space] health server listening on 0.0.0.0:{APP_PORT}', flush=True)
110
+ return server
111
+
112
+
113
+ def upload_artifact(api: HfApi, path: Path, dest: str) -> None:
114
+ if not path.exists():
115
+ print(f'[upload] skip missing {path}', flush=True)
116
+ return
117
+ api.upload_file(
118
+ path_or_fileobj=str(path),
119
+ path_in_repo=dest,
120
+ repo_id=OUTPUT_REPO,
121
+ repo_type='model',
122
+ )
123
+ print(f'[upload] uploaded {path} -> {OUTPUT_REPO}/{dest}', flush=True)
124
+
125
+
126
+ def _wait_for_cuda_ready(timeout_s: int = 120) -> None:
127
+ """Block until CUDA is fully initialized or timeout.
128
+
129
+ On H200 hosts with NVSwitch/fabric manager, nvidia driver setup can race
130
+ with container start. cudaGetDeviceCount can return CUDA_ERROR_SYSTEM_NOT_READY
131
+ (error 802) for the first few seconds, and any import that triggers
132
+ @triton.autotune (e.g. mamba_ssm, torch amp utilities) blows up with
133
+ "0 active drivers" if it happens during that window.
134
+
135
+ We pre-init CUDA in a throwaway Python subprocess (so any error state does
136
+ not leak into the main training process) and retry until torch.cuda
137
+ reports ready.
138
+ """
139
+ import time as _t
140
+ probe = (
141
+ "import torch; "
142
+ "import sys; "
143
+ "avail = torch.cuda.is_available(); "
144
+ "count = torch.cuda.device_count() if avail else 0; "
145
+ "sys.exit(0 if (avail and count > 0) else 1)"
146
+ )
147
+ deadline = _t.time() + timeout_s
148
+ attempt = 0
149
+ while _t.time() < deadline:
150
+ attempt += 1
151
+ r = subprocess.run(['python', '-c', probe], capture_output=True, text=True)
152
+ if r.returncode == 0:
153
+ print(f'[job] CUDA ready after {attempt} probe(s)', flush=True)
154
+ return
155
+ if attempt == 1:
156
+ print(f'[job] CUDA not ready yet (will retry up to {timeout_s}s): {r.stderr.strip()[:200]}', flush=True)
157
+ _t.sleep(2)
158
+ print(f'[job] CUDA still not ready after {timeout_s}s — continuing anyway (training will likely fail)', flush=True)
159
+
160
+
161
+ def _truthy_env(name: str, default: str = '0') -> bool:
162
+ return os.environ.get(name, default).strip().lower() in {'1', 'true', 'yes', 'on'}
163
+
164
+
165
+ def _check_training_artifacts_ready() -> tuple[bool, bool]:
166
+ """Return whether metrics and final checkpoints are visible to the job wrapper."""
167
+ metrics_seen = False
168
+ if LOG_FILE.exists():
169
+ try:
170
+ tail = LOG_FILE.read_text(errors='replace')[-20000:]
171
+ metrics_seen = '[METRICS_JSON]' in tail or '[METRICS] wrote' in tail
172
+ except OSError:
173
+ metrics_seen = False
174
+ checkpoints_ready = (CACHE_ROOT / 'latest.pt').exists() and (CACHE_ROOT / 'pretrain_final.pt').exists()
175
+ return metrics_seen, checkpoints_ready
176
+
177
+
178
+ def _run_training_subprocess(cmd: list[str]) -> int:
179
+ """Run training, optionally stopping after metrics/checkpoints for clean upload.
180
+
181
+ Full-corpus streaming can leave dataset downloader worker threads alive during
182
+ Python finalization after useful metrics/checkpoints have already been written.
183
+ On HF Jobs this may keep the job RUNNING or flip it to ERROR before the
184
+ entrypoint uploads artifacts. The watcher preserves the completed canary by
185
+ terminating the train subprocess once the metrics/checkpoint contract is met.
186
+ """
187
+ if not _truthy_env('FEATHER_HF_EXIT_AFTER_METRICS', '1'):
188
+ return subprocess.run(cmd, check=False).returncode
189
+
190
+ proc = subprocess.Popen(cmd)
191
+ metrics_seen = False
192
+ checkpoints_ready = False
193
+ while proc.poll() is None:
194
+ metrics_seen, checkpoints_ready = _check_training_artifacts_ready()
195
+ if metrics_seen and checkpoints_ready:
196
+ print('[job] metrics/checkpoints observed; terminating training subprocess for clean artifact upload', flush=True)
197
+ proc.terminate()
198
+ try:
199
+ proc.wait(timeout=30)
200
+ except subprocess.TimeoutExpired:
201
+ print('[job] training subprocess did not terminate cleanly; killing it', flush=True)
202
+ proc.kill()
203
+ proc.wait(timeout=30)
204
+ return 0
205
+ time.sleep(5)
206
+
207
+ metrics_seen, checkpoints_ready = _check_training_artifacts_ready()
208
+ if proc.returncode != 0 and metrics_seen and checkpoints_ready:
209
+ print(
210
+ f'[job] training subprocess exited rc={proc.returncode} after writing metrics/checkpoints; treating canary as successful for upload',
211
+ flush=True,
212
+ )
213
+ return 0
214
+ return int(proc.returncode or 0)
215
+
216
+
217
+ def run_job_mode() -> int:
218
+ os.chdir(REPO_ROOT)
219
+
220
+ # Dynamic live patch from GitHub to bypass Space build errors
221
+ GIT_REF = os.environ.get('FEATHER_GIT_REF')
222
+ if GIT_REF:
223
+ print(f'[bootstrap] dynamic sync to {GIT_REF}...', flush=True)
224
+ subprocess.run(['git', 'fetch', 'origin'], cwd=REPO_ROOT, check=False)
225
+ subprocess.run(['git', 'checkout', GIT_REF], cwd=REPO_ROOT, check=False)
226
+
227
+ os.environ.setdefault('HYDRA_TIME_BUDGET', '43200')
228
+ os.environ.setdefault('HYDRA_TARGET_SHARDS', '2048')
229
+ os.environ.setdefault('HYDRA_DOWNLOAD_WORKERS', '16')
230
+ os.environ.setdefault('HYDRA_CKPT_INTERVAL', '1000')
231
+ os.environ.setdefault('HYDRA_RESUME_CKPT', str(CACHE_ROOT / 'latest.pt'))
232
+
233
+ # CUDA readiness was kicked at module import via _early_cuda_kick. Keep
234
+ # the wait as a second safety net — no-op if CUDA already ready.
235
+ _wait_for_cuda_ready()
236
+
237
+ cmd = [
238
+ 'bash',
239
+ './scripts/run_domain_expanded_pretrain.sh',
240
+ '--target-shards', os.environ['HYDRA_TARGET_SHARDS'],
241
+ '--download-workers', os.environ['HYDRA_DOWNLOAD_WORKERS'],
242
+ ]
243
+ print('[job] starting Feather domain-expanded pretrain', flush=True)
244
+ print(f'[job] command={cmd}', flush=True)
245
+ proc_returncode = _run_training_subprocess(cmd)
246
+
247
+ # Push triton compilation cache back to HF Hub for next run.
248
+ try:
249
+ import triton_cache_setup as _tcs
250
+ _tcs.teardown()
251
+ except Exception as _tcs_err:
252
+ print(f'[triton_cache] teardown error (non-fatal): {_tcs_err}', flush=True)
253
+
254
+ if TOKEN:
255
+ api = HfApi(token=TOKEN)
256
+ try:
257
+ api.create_repo(repo_id=OUTPUT_REPO, repo_type='model', private=True, exist_ok=True)
258
+ except Exception as e:
259
+ print(f'[upload] create_repo warning: {type(e).__name__}: {e}', flush=True)
260
+ prefix = f'jobs/{JOB_ID}'
261
+ try:
262
+ upload_artifact(api, LOG_FILE, f'{prefix}/run_domain_expanded.log')
263
+ upload_artifact(api, CACHE_ROOT / 'latest.pt', f'{prefix}/latest.pt')
264
+ upload_artifact(api, CACHE_ROOT / 'pretrain_final.pt', f'{prefix}/pretrain_final.pt')
265
+ except Exception as e:
266
+ print(f'[upload] upload warning: {type(e).__name__}: {e}', flush=True)
267
+ else:
268
+ print('[upload] HF_TOKEN not set; skipping artifact upload', flush=True)
269
+
270
+ return proc_returncode
271
+
272
+
273
+ def run_space_mode() -> int:
274
+ server = _start_health_server()
275
+ print('[space] Feather runtime image ready', flush=True)
276
+ try:
277
+ while True:
278
+ time.sleep(3600)
279
+ finally:
280
+ server.shutdown()
281
+ server.server_close()
282
+
283
+
284
+ def main() -> int:
285
+ if RUNTIME_MODE == 'job':
286
+ return run_job_mode()
287
+ return run_space_mode()
288
+
289
+
290
+ if __name__ == '__main__':
291
+ raise SystemExit(main())