Michael Rabinovich commited on
Commit ·
dd284f3
1
Parent(s): 5fb3ebc
fix HF Jobs polling namespace
Browse files- submit.py +22 -3
- tests/test_submit.py +70 -0
submit.py
CHANGED
|
@@ -298,6 +298,11 @@ def _shard_bucket_root(submission_id: str) -> Path:
|
|
| 298 |
return Path(SHARD_BUCKET_MOUNT) / _shard_bucket_relative_root(submission_id)
|
| 299 |
|
| 300 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 301 |
def _with_hub_retries(fn, *, what: str):
|
| 302 |
"""Run *fn* (a Hub commit) retrying transient HTTP errors with backoff.
|
| 303 |
|
|
@@ -1287,7 +1292,11 @@ def _poll_shards_until_done(
|
|
| 1287 |
for shard_id in running:
|
| 1288 |
st = shards[shard_id]
|
| 1289 |
try:
|
| 1290 |
-
info = inspect_job(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1291 |
consecutive_errors = 0
|
| 1292 |
except Exception as e: # noqa: BLE001 - retry transient API errors
|
| 1293 |
consecutive_errors += 1
|
|
@@ -1359,7 +1368,11 @@ def _poll_until_done(
|
|
| 1359 |
last_stage: str | None = None
|
| 1360 |
while True:
|
| 1361 |
try:
|
| 1362 |
-
info = inspect_job(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1363 |
consecutive_errors = 0
|
| 1364 |
except Exception as e: # noqa: BLE001 - retry transient API errors
|
| 1365 |
consecutive_errors += 1
|
|
@@ -1408,7 +1421,13 @@ def _job_failure_reason(
|
|
| 1408 |
if status_message:
|
| 1409 |
parts.append(status_message)
|
| 1410 |
try:
|
| 1411 |
-
tail = list(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1412 |
if tail:
|
| 1413 |
parts.append("logs: " + " | ".join(tail))
|
| 1414 |
except Exception as e: # noqa: BLE001 - logs are best-effort
|
|
|
|
| 298 |
return Path(SHARD_BUCKET_MOUNT) / _shard_bucket_relative_root(submission_id)
|
| 299 |
|
| 300 |
|
| 301 |
+
def _jobs_token() -> str | None:
|
| 302 |
+
"""Token used for HF Jobs control-plane calls."""
|
| 303 |
+
return os.environ.get("HF_TOKEN")
|
| 304 |
+
|
| 305 |
+
|
| 306 |
def _with_hub_retries(fn, *, what: str):
|
| 307 |
"""Run *fn* (a Hub commit) retrying transient HTTP errors with backoff.
|
| 308 |
|
|
|
|
| 1292 |
for shard_id in running:
|
| 1293 |
st = shards[shard_id]
|
| 1294 |
try:
|
| 1295 |
+
info = inspect_job(
|
| 1296 |
+
job_id=st["job_id"],
|
| 1297 |
+
namespace=EVAL_JOB_NAMESPACE,
|
| 1298 |
+
token=_jobs_token(),
|
| 1299 |
+
)
|
| 1300 |
consecutive_errors = 0
|
| 1301 |
except Exception as e: # noqa: BLE001 - retry transient API errors
|
| 1302 |
consecutive_errors += 1
|
|
|
|
| 1368 |
last_stage: str | None = None
|
| 1369 |
while True:
|
| 1370 |
try:
|
| 1371 |
+
info = inspect_job(
|
| 1372 |
+
job_id=job_id,
|
| 1373 |
+
namespace=EVAL_JOB_NAMESPACE,
|
| 1374 |
+
token=_jobs_token(),
|
| 1375 |
+
)
|
| 1376 |
consecutive_errors = 0
|
| 1377 |
except Exception as e: # noqa: BLE001 - retry transient API errors
|
| 1378 |
consecutive_errors += 1
|
|
|
|
| 1421 |
if status_message:
|
| 1422 |
parts.append(status_message)
|
| 1423 |
try:
|
| 1424 |
+
tail = list(
|
| 1425 |
+
fetch_job_logs(
|
| 1426 |
+
job_id=job_id,
|
| 1427 |
+
namespace=EVAL_JOB_NAMESPACE,
|
| 1428 |
+
token=_jobs_token(),
|
| 1429 |
+
)
|
| 1430 |
+
)[-JOB_LOG_TAIL_LINES:]
|
| 1431 |
if tail:
|
| 1432 |
parts.append("logs: " + " | ".join(tail))
|
| 1433 |
except Exception as e: # noqa: BLE001 - logs are best-effort
|
tests/test_submit.py
CHANGED
|
@@ -194,6 +194,76 @@ def test_eval_job_stages_shard_to_mounted_bucket(tmp_path: Path, monkeypatch):
|
|
| 194 |
assert (staged / "101" / "result.json").read_text(encoding="utf-8") == "{}"
|
| 195 |
|
| 196 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
def _stub_meta() -> dict:
|
| 198 |
"""Minimum meta.json shape that survives ``_load_and_validate_meta``."""
|
| 199 |
return {
|
|
|
|
| 194 |
assert (staged / "101" / "result.json").read_text(encoding="utf-8") == "{}"
|
| 195 |
|
| 196 |
|
| 197 |
+
def test_poll_until_done_uses_jobs_namespace_and_token(monkeypatch):
|
| 198 |
+
"""Polling must target the namespace where Jobs were dispatched."""
|
| 199 |
+
captured: dict = {}
|
| 200 |
+
|
| 201 |
+
def fake_inspect_job(**kwargs):
|
| 202 |
+
captured.update(kwargs)
|
| 203 |
+
return SimpleNamespace(
|
| 204 |
+
status=SimpleNamespace(stage="COMPLETED", message=None),
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
monkeypatch.setenv("HF_TOKEN", "hf_test")
|
| 208 |
+
monkeypatch.setattr(submit, "inspect_job", fake_inspect_job)
|
| 209 |
+
|
| 210 |
+
assert submit._poll_until_done("job-123", "sub-1") == ("COMPLETED", None)
|
| 211 |
+
assert captured == {
|
| 212 |
+
"job_id": "job-123",
|
| 213 |
+
"namespace": submit.EVAL_JOB_NAMESPACE,
|
| 214 |
+
"token": "hf_test",
|
| 215 |
+
}
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def test_shard_poll_uses_jobs_namespace_and_token(monkeypatch):
|
| 219 |
+
"""Sharded polling uses the same Jobs namespace/token as dispatch."""
|
| 220 |
+
captured: dict = {}
|
| 221 |
+
|
| 222 |
+
def fake_inspect_job(**kwargs):
|
| 223 |
+
captured.update(kwargs)
|
| 224 |
+
return SimpleNamespace(
|
| 225 |
+
status=SimpleNamespace(stage="COMPLETED", message=None),
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
+
monkeypatch.setenv("HF_TOKEN", "hf_test")
|
| 229 |
+
monkeypatch.setattr(submit, "inspect_job", fake_inspect_job)
|
| 230 |
+
monkeypatch.setattr(submit.time, "sleep", lambda *_: None)
|
| 231 |
+
|
| 232 |
+
failures = submit._poll_shards_until_done(
|
| 233 |
+
"sub-1",
|
| 234 |
+
"https://example.test/sub-1.zip",
|
| 235 |
+
{"shard_000": {"job_id": "job-123", "stage": None, "message": None}},
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
assert failures == []
|
| 239 |
+
assert captured == {
|
| 240 |
+
"job_id": "job-123",
|
| 241 |
+
"namespace": submit.EVAL_JOB_NAMESPACE,
|
| 242 |
+
"token": "hf_test",
|
| 243 |
+
}
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def test_job_failure_reason_fetches_logs_with_namespace_and_token(monkeypatch):
|
| 247 |
+
"""Failure diagnostics fetch logs from the same Jobs namespace."""
|
| 248 |
+
captured: dict = {}
|
| 249 |
+
|
| 250 |
+
def fake_fetch_job_logs(**kwargs):
|
| 251 |
+
captured.update(kwargs)
|
| 252 |
+
return ["line 1\n", "line 2\n"]
|
| 253 |
+
|
| 254 |
+
monkeypatch.setenv("HF_TOKEN", "hf_test")
|
| 255 |
+
monkeypatch.setattr(submit, "fetch_job_logs", fake_fetch_job_logs)
|
| 256 |
+
|
| 257 |
+
reason = submit._job_failure_reason("job-123", "ERROR", "boom")
|
| 258 |
+
|
| 259 |
+
assert "line 2" in reason
|
| 260 |
+
assert captured == {
|
| 261 |
+
"job_id": "job-123",
|
| 262 |
+
"namespace": submit.EVAL_JOB_NAMESPACE,
|
| 263 |
+
"token": "hf_test",
|
| 264 |
+
}
|
| 265 |
+
|
| 266 |
+
|
| 267 |
def _stub_meta() -> dict:
|
| 268 |
"""Minimum meta.json shape that survives ``_load_and_validate_meta``."""
|
| 269 |
return {
|