Michael Rabinovich committed on
Commit
dd284f3
·
1 Parent(s): 5fb3ebc

fix HF Jobs polling namespace

Browse files
Files changed (2) hide show
  1. submit.py +22 -3
  2. tests/test_submit.py +70 -0
submit.py CHANGED
@@ -298,6 +298,11 @@ def _shard_bucket_root(submission_id: str) -> Path:
298
  return Path(SHARD_BUCKET_MOUNT) / _shard_bucket_relative_root(submission_id)
299
 
300
 
 
 
 
 
 
301
  def _with_hub_retries(fn, *, what: str):
302
  """Run *fn* (a Hub commit) retrying transient HTTP errors with backoff.
303
 
@@ -1287,7 +1292,11 @@ def _poll_shards_until_done(
1287
  for shard_id in running:
1288
  st = shards[shard_id]
1289
  try:
1290
- info = inspect_job(job_id=st["job_id"])
 
 
 
 
1291
  consecutive_errors = 0
1292
  except Exception as e: # noqa: BLE001 - retry transient API errors
1293
  consecutive_errors += 1
@@ -1359,7 +1368,11 @@ def _poll_until_done(
1359
  last_stage: str | None = None
1360
  while True:
1361
  try:
1362
- info = inspect_job(job_id=job_id)
 
 
 
 
1363
  consecutive_errors = 0
1364
  except Exception as e: # noqa: BLE001 - retry transient API errors
1365
  consecutive_errors += 1
@@ -1408,7 +1421,13 @@ def _job_failure_reason(
1408
  if status_message:
1409
  parts.append(status_message)
1410
  try:
1411
- tail = list(fetch_job_logs(job_id=job_id))[-JOB_LOG_TAIL_LINES:]
 
 
 
 
 
 
1412
  if tail:
1413
  parts.append("logs: " + " | ".join(tail))
1414
  except Exception as e: # noqa: BLE001 - logs are best-effort
 
298
  return Path(SHARD_BUCKET_MOUNT) / _shard_bucket_relative_root(submission_id)
299
 
300
 
301
+ def _jobs_token() -> str | None:
302
+ """Token used for HF Jobs control-plane calls."""
303
+ return os.environ.get("HF_TOKEN")
304
+
305
+
306
  def _with_hub_retries(fn, *, what: str):
307
  """Run *fn* (a Hub commit) retrying transient HTTP errors with backoff.
308
 
 
1292
  for shard_id in running:
1293
  st = shards[shard_id]
1294
  try:
1295
+ info = inspect_job(
1296
+ job_id=st["job_id"],
1297
+ namespace=EVAL_JOB_NAMESPACE,
1298
+ token=_jobs_token(),
1299
+ )
1300
  consecutive_errors = 0
1301
  except Exception as e: # noqa: BLE001 - retry transient API errors
1302
  consecutive_errors += 1
 
1368
  last_stage: str | None = None
1369
  while True:
1370
  try:
1371
+ info = inspect_job(
1372
+ job_id=job_id,
1373
+ namespace=EVAL_JOB_NAMESPACE,
1374
+ token=_jobs_token(),
1375
+ )
1376
  consecutive_errors = 0
1377
  except Exception as e: # noqa: BLE001 - retry transient API errors
1378
  consecutive_errors += 1
 
1421
  if status_message:
1422
  parts.append(status_message)
1423
  try:
1424
+ tail = list(
1425
+ fetch_job_logs(
1426
+ job_id=job_id,
1427
+ namespace=EVAL_JOB_NAMESPACE,
1428
+ token=_jobs_token(),
1429
+ )
1430
+ )[-JOB_LOG_TAIL_LINES:]
1431
  if tail:
1432
  parts.append("logs: " + " | ".join(tail))
1433
  except Exception as e: # noqa: BLE001 - logs are best-effort
tests/test_submit.py CHANGED
@@ -194,6 +194,76 @@ def test_eval_job_stages_shard_to_mounted_bucket(tmp_path: Path, monkeypatch):
194
  assert (staged / "101" / "result.json").read_text(encoding="utf-8") == "{}"
195
 
196
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  def _stub_meta() -> dict:
198
  """Minimum meta.json shape that survives ``_load_and_validate_meta``."""
199
  return {
 
194
  assert (staged / "101" / "result.json").read_text(encoding="utf-8") == "{}"
195
 
196
 
197
+ def test_poll_until_done_uses_jobs_namespace_and_token(monkeypatch):
198
+ """Polling must target the namespace where Jobs were dispatched."""
199
+ captured: dict = {}
200
+
201
+ def fake_inspect_job(**kwargs):
202
+ captured.update(kwargs)
203
+ return SimpleNamespace(
204
+ status=SimpleNamespace(stage="COMPLETED", message=None),
205
+ )
206
+
207
+ monkeypatch.setenv("HF_TOKEN", "hf_test")
208
+ monkeypatch.setattr(submit, "inspect_job", fake_inspect_job)
209
+
210
+ assert submit._poll_until_done("job-123", "sub-1") == ("COMPLETED", None)
211
+ assert captured == {
212
+ "job_id": "job-123",
213
+ "namespace": submit.EVAL_JOB_NAMESPACE,
214
+ "token": "hf_test",
215
+ }
216
+
217
+
218
+ def test_shard_poll_uses_jobs_namespace_and_token(monkeypatch):
219
+ """Sharded polling uses the same Jobs namespace/token as dispatch."""
220
+ captured: dict = {}
221
+
222
+ def fake_inspect_job(**kwargs):
223
+ captured.update(kwargs)
224
+ return SimpleNamespace(
225
+ status=SimpleNamespace(stage="COMPLETED", message=None),
226
+ )
227
+
228
+ monkeypatch.setenv("HF_TOKEN", "hf_test")
229
+ monkeypatch.setattr(submit, "inspect_job", fake_inspect_job)
230
+ monkeypatch.setattr(submit.time, "sleep", lambda *_: None)
231
+
232
+ failures = submit._poll_shards_until_done(
233
+ "sub-1",
234
+ "https://example.test/sub-1.zip",
235
+ {"shard_000": {"job_id": "job-123", "stage": None, "message": None}},
236
+ )
237
+
238
+ assert failures == []
239
+ assert captured == {
240
+ "job_id": "job-123",
241
+ "namespace": submit.EVAL_JOB_NAMESPACE,
242
+ "token": "hf_test",
243
+ }
244
+
245
+
246
+ def test_job_failure_reason_fetches_logs_with_namespace_and_token(monkeypatch):
247
+ """Failure diagnostics fetch logs from the same Jobs namespace."""
248
+ captured: dict = {}
249
+
250
+ def fake_fetch_job_logs(**kwargs):
251
+ captured.update(kwargs)
252
+ return ["line 1\n", "line 2\n"]
253
+
254
+ monkeypatch.setenv("HF_TOKEN", "hf_test")
255
+ monkeypatch.setattr(submit, "fetch_job_logs", fake_fetch_job_logs)
256
+
257
+ reason = submit._job_failure_reason("job-123", "ERROR", "boom")
258
+
259
+ assert "line 2" in reason
260
+ assert captured == {
261
+ "job_id": "job-123",
262
+ "namespace": submit.EVAL_JOB_NAMESPACE,
263
+ "token": "hf_test",
264
+ }
265
+
266
+
267
  def _stub_meta() -> dict:
268
  """Minimum meta.json shape that survives ``_load_and_validate_meta``."""
269
  return {