lewtun HF Staff OpenAI Codex commited on
Commit
60474c1
·
unverified ·
1 Parent(s): 0321690

Create Hub artifact collections lazily (#245)

Browse files

* Create Hub artifact collections lazily

Co-authored-by: OpenAI Codex <codex@openai.com>

* Cache lazy Hub collection slugs in child hooks

Co-authored-by: OpenAI Codex <codex@openai.com>

---------

Co-authored-by: OpenAI Codex <codex@openai.com>

agent/core/agent_loop.py CHANGED
@@ -27,7 +27,6 @@ from agent.core.cost_estimation import CostEstimate, estimate_tool_cost
27
  from agent.messaging.gateway import NotificationGateway
28
  from agent.core import telemetry
29
  from agent.core.doom_loop import check_for_doom_loop
30
- from agent.core.hub_artifacts import start_session_artifact_collection_task
31
  from agent.core.llm_params import _resolve_llm_params
32
  from agent.core.prompt_caching import with_prompt_caching
33
  from agent.core.session import DEFAULT_SESSION_LOG_DIR, Event, OpType, Session
@@ -2024,7 +2023,6 @@ async def submission_loop(
2024
  )
2025
  if session_holder is not None:
2026
  session_holder[0] = session
2027
- start_session_artifact_collection_task(session, token=hf_token)
2028
  logger.info("Agent loop started")
2029
 
2030
  # Retry any failed uploads from previous sessions (fire-and-forget).
 
27
  from agent.messaging.gateway import NotificationGateway
28
  from agent.core import telemetry
29
  from agent.core.doom_loop import check_for_doom_loop
 
30
  from agent.core.llm_params import _resolve_llm_params
31
  from agent.core.prompt_caching import with_prompt_caching
32
  from agent.core.session import DEFAULT_SESSION_LOG_DIR, Event, OpType, Session
 
2023
  )
2024
  if session_holder is not None:
2025
  session_holder[0] = session
 
2026
  logger.info("Agent loop started")
2027
 
2028
  # Retry any failed uploads from previous sessions (fire-and-forget).
agent/core/hub_artifacts.py CHANGED
@@ -1,6 +1,5 @@
1
  """Best-effort Hub metadata for artifacts generated by ML Intern sessions."""
2
 
3
- import asyncio
4
  import base64
5
  import logging
6
  import re
@@ -11,7 +10,7 @@ from datetime import datetime
11
  from pathlib import Path
12
  from typing import Any
13
 
14
- from huggingface_hub import HfApi, hf_hub_download
15
  from huggingface_hub.repocard import metadata_load, metadata_save
16
  from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
17
 
@@ -29,7 +28,6 @@ _UUID_SESSION_ID_RE = re.compile(
29
  _KNOWN_ARTIFACTS_ATTR = "_ml_intern_known_hub_artifacts"
30
  _REGISTERED_ARTIFACTS_ATTR = "_ml_intern_registered_hub_artifacts"
31
  _COLLECTION_SLUG_ATTR = "_ml_intern_artifact_collection_slug"
32
- _COLLECTION_TASK_ATTR = "_ml_intern_artifact_collection_task"
33
  _SESSION_ARTIFACT_SET_FALLBACK: dict[tuple[int, str], set[str]] = {}
34
  _USAGE_HEADING_RE = re.compile(
35
  r"^#{2,6}\s+(usage|how to use|using this (model|dataset)|use this (model|dataset))\b",
@@ -307,70 +305,6 @@ def _ensure_collection_slug(
307
  return slug
308
 
309
 
310
- async def ensure_session_artifact_collection(
311
- session: Any,
312
- *,
313
- token: str | bool | None = None,
314
- ) -> str | None:
315
- """Create/cache the per-session artifact collection without raising."""
316
- if session is None or not getattr(session, "session_id", None):
317
- return None
318
- token_value = token if token is not None else getattr(session, "hf_token", None)
319
- if not token_value:
320
- return None
321
-
322
- try:
323
- api = HfApi(token=token_value)
324
- return await asyncio.to_thread(
325
- _ensure_collection_slug,
326
- api,
327
- session,
328
- token=token_value,
329
- )
330
- except Exception as e:
331
- logger.warning(
332
- "ML Intern session collection creation failed for %s: %s",
333
- _safe_session_id(session),
334
- e,
335
- )
336
- return None
337
-
338
-
339
- def start_session_artifact_collection_task(
340
- session: Any,
341
- *,
342
- token: str | bool | None = None,
343
- ) -> asyncio.Task | None:
344
- """Schedule best-effort collection creation for a newly started session."""
345
- if session is None or not getattr(session, "session_id", None):
346
- return None
347
- if getattr(session, _COLLECTION_SLUG_ATTR, None):
348
- return None
349
-
350
- token_value = token if token is not None else getattr(session, "hf_token", None)
351
- if not token_value:
352
- return None
353
-
354
- existing = getattr(session, _COLLECTION_TASK_ATTR, None)
355
- if isinstance(existing, asyncio.Task) and not existing.done():
356
- return existing
357
-
358
- try:
359
- loop = asyncio.get_running_loop()
360
- except RuntimeError:
361
- return None
362
-
363
- async def _run() -> None:
364
- await ensure_session_artifact_collection(session, token=token_value)
365
-
366
- task = loop.create_task(_run())
367
- try:
368
- setattr(session, _COLLECTION_TASK_ATTR, task)
369
- except Exception:
370
- logger.debug("Could not attach ML Intern collection task to session")
371
- return task
372
-
373
-
374
  def _add_to_collection(
375
  api: Any,
376
  session: Any,
@@ -378,10 +312,10 @@ def _add_to_collection(
378
  repo_type: str,
379
  *,
380
  token: str | bool | None = None,
381
- ) -> None:
382
  slug = _ensure_collection_slug(api, session, token=token)
383
  if not slug:
384
- return
385
  api.add_collection_item(
386
  collection_slug=slug,
387
  item_id=repo_id,
@@ -393,6 +327,7 @@ def _add_to_collection(
393
  exists_ok=True,
394
  token=token,
395
  )
 
396
 
397
 
398
  def register_hub_artifact(
@@ -436,8 +371,13 @@ def register_hub_artifact(
436
  logger.debug("ML Intern repo-card update failed for %s: %s", repo_id, e)
437
 
438
  try:
439
- _add_to_collection(api, session, repo_id, repo_type, token=token_value)
440
- collection_updated = True
 
 
 
 
 
441
  except Exception as e:
442
  logger.debug("ML Intern collection update failed for %s: %s", repo_id, e)
443
 
@@ -490,6 +430,13 @@ def build_hub_artifact_sitecustomize(session: Any) -> str:
490
  re.IGNORECASE | re.MULTILINE,
491
  )
492
  front_matter_re = re.compile(r"\\A---\\s*\\n.*?\\n---\\s*\\n?", re.DOTALL)
 
 
 
 
 
 
 
493
 
494
  def _token(value=None, api=None):
495
  if isinstance(value, str) and value:
@@ -602,6 +549,15 @@ def build_hub_artifact_sitecustomize(session: Any) -> str:
602
  nonlocal collection_slug
603
  if collection_slug:
604
  return collection_slug
 
 
 
 
 
 
 
 
 
605
  collection = api.create_collection(
606
  title=collection_title,
607
  description=(
@@ -613,6 +569,13 @@ def build_hub_artifact_sitecustomize(session: Any) -> str:
613
  token=token_value,
614
  )
615
  collection_slug = getattr(collection, "slug", None)
 
 
 
 
 
 
 
616
  return collection_slug
617
 
618
  def _register(
@@ -637,6 +600,7 @@ def build_hub_artifact_sitecustomize(session: Any) -> str:
637
  try:
638
  token_value = _token(token_value)
639
  api = HfApi(token=token_value)
 
640
  try:
641
  current = _readme(api, repo_id, repo_type, token_value)
642
  updated = _augment(
@@ -652,8 +616,10 @@ def build_hub_artifact_sitecustomize(session: Any) -> str:
652
  token=token_value,
653
  commit_message="Update ML Intern artifact metadata",
654
  )
 
655
  except Exception:
656
  pass
 
657
  try:
658
  slug = _ensure_collection(api, token_value)
659
  if slug:
@@ -668,9 +634,11 @@ def build_hub_artifact_sitecustomize(session: Any) -> str:
668
  exists_ok=True,
669
  token=token_value,
670
  )
 
671
  except Exception:
672
  pass
673
- registered.add(key)
 
674
  finally:
675
  registering = False
676
 
 
1
  """Best-effort Hub metadata for artifacts generated by ML Intern sessions."""
2
 
 
3
  import base64
4
  import logging
5
  import re
 
10
  from pathlib import Path
11
  from typing import Any
12
 
13
+ from huggingface_hub import hf_hub_download
14
  from huggingface_hub.repocard import metadata_load, metadata_save
15
  from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
16
 
 
28
  _KNOWN_ARTIFACTS_ATTR = "_ml_intern_known_hub_artifacts"
29
  _REGISTERED_ARTIFACTS_ATTR = "_ml_intern_registered_hub_artifacts"
30
  _COLLECTION_SLUG_ATTR = "_ml_intern_artifact_collection_slug"
 
31
  _SESSION_ARTIFACT_SET_FALLBACK: dict[tuple[int, str], set[str]] = {}
32
  _USAGE_HEADING_RE = re.compile(
33
  r"^#{2,6}\s+(usage|how to use|using this (model|dataset)|use this (model|dataset))\b",
 
305
  return slug
306
 
307
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
308
  def _add_to_collection(
309
  api: Any,
310
  session: Any,
 
312
  repo_type: str,
313
  *,
314
  token: str | bool | None = None,
315
+ ) -> bool:
316
  slug = _ensure_collection_slug(api, session, token=token)
317
  if not slug:
318
+ return False
319
  api.add_collection_item(
320
  collection_slug=slug,
321
  item_id=repo_id,
 
327
  exists_ok=True,
328
  token=token,
329
  )
330
+ return True
331
 
332
 
333
  def register_hub_artifact(
 
371
  logger.debug("ML Intern repo-card update failed for %s: %s", repo_id, e)
372
 
373
  try:
374
+ collection_updated = _add_to_collection(
375
+ api,
376
+ session,
377
+ repo_id,
378
+ repo_type,
379
+ token=token_value,
380
+ )
381
  except Exception as e:
382
  logger.debug("ML Intern collection update failed for %s: %s", repo_id, e)
383
 
 
430
  re.IGNORECASE | re.MULTILINE,
431
  )
432
  front_matter_re = re.compile(r"\\A---\\s*\\n.*?\\n---\\s*\\n?", re.DOTALL)
433
+ collection_cache_path = (
434
+ os.environ.get("ML_INTERN_ARTIFACT_COLLECTION_CACHE")
435
+ or str(
436
+ Path(tempfile.gettempdir())
437
+ / f"ml-intern-artifacts-{{session_id}}.collection"
438
+ )
439
+ )
440
 
441
  def _token(value=None, api=None):
442
  if isinstance(value, str) and value:
 
549
  nonlocal collection_slug
550
  if collection_slug:
551
  return collection_slug
552
+ try:
553
+ cached_slug = Path(collection_cache_path).read_text(
554
+ encoding="utf-8"
555
+ ).strip()
556
+ if cached_slug:
557
+ collection_slug = cached_slug
558
+ return collection_slug
559
+ except Exception:
560
+ pass
561
  collection = api.create_collection(
562
  title=collection_title,
563
  description=(
 
569
  token=token_value,
570
  )
571
  collection_slug = getattr(collection, "slug", None)
572
+ if collection_slug:
573
+ try:
574
+ cache_path = Path(collection_cache_path)
575
+ cache_path.parent.mkdir(parents=True, exist_ok=True)
576
+ cache_path.write_text(collection_slug, encoding="utf-8")
577
+ except Exception:
578
+ pass
579
  return collection_slug
580
 
581
  def _register(
 
600
  try:
601
  token_value = _token(token_value)
602
  api = HfApi(token=token_value)
603
+ card_updated = False
604
  try:
605
  current = _readme(api, repo_id, repo_type, token_value)
606
  updated = _augment(
 
616
  token=token_value,
617
  commit_message="Update ML Intern artifact metadata",
618
  )
619
+ card_updated = True
620
  except Exception:
621
  pass
622
+ collection_updated = False
623
  try:
624
  slug = _ensure_collection(api, token_value)
625
  if slug:
 
634
  exists_ok=True,
635
  token=token_value,
636
  )
637
+ collection_updated = True
638
  except Exception:
639
  pass
640
+ if card_updated and collection_updated:
641
+ registered.add(key)
642
  finally:
643
  registering = False
644
 
backend/session_manager.py CHANGED
@@ -12,7 +12,6 @@ from typing import Any, Optional
12
 
13
  from agent.config import load_config
14
  from agent.core.agent_loop import process_submission
15
- from agent.core.hub_artifacts import start_session_artifact_collection_task
16
  from agent.core.session import Event, OpType, Session
17
  from agent.core.session_persistence import get_session_store
18
  from agent.core.tools import ToolRouter
@@ -136,7 +135,6 @@ class SessionManager:
136
  self.sessions: dict[str, AgentSession] = {}
137
  self._lock = asyncio.Lock()
138
  self.persistence_store = None
139
- self.enable_hub_artifact_collections = True
140
 
141
  async def start(self) -> None:
142
  """Start shared background resources."""
@@ -413,28 +411,6 @@ class SessionManager:
413
  session.sandbox_preload_cancel_event = None
414
  self._start_cpu_sandbox_preload(agent_session)
415
 
416
- def _start_hub_artifact_collection(self, agent_session: AgentSession) -> None:
417
- """Kick off best-effort Hub collection creation for the session."""
418
- if not getattr(self, "enable_hub_artifact_collections", False):
419
- return
420
- session = agent_session.session
421
- if not getattr(session, "session_id", None):
422
- try:
423
- session.session_id = agent_session.session_id
424
- except Exception:
425
- logger.debug("Could not attach session id for Hub artifact collection")
426
- token = agent_session.hf_token or getattr(session, "hf_token", None)
427
- if not token:
428
- return
429
- try:
430
- start_session_artifact_collection_task(session, token=token)
431
- except Exception as e:
432
- logger.debug(
433
- "Failed to schedule Hub artifact collection for %s: %s",
434
- agent_session.session_id,
435
- e,
436
- )
437
-
438
  async def _clear_persisted_sandbox_metadata(self, session_id: str) -> None:
439
  try:
440
  await self._store().update_session_fields(
@@ -591,7 +567,6 @@ class SessionManager:
591
  existing,
592
  preload_sandbox=preload_sandbox,
593
  )
594
- self._start_hub_artifact_collection(existing)
595
  return existing
596
  return None
597
 
@@ -613,7 +588,6 @@ class SessionManager:
613
  existing,
614
  preload_sandbox=preload_sandbox,
615
  )
616
- self._start_hub_artifact_collection(existing)
617
  return existing
618
  return None
619
 
@@ -700,9 +674,7 @@ class SessionManager:
700
  hf_token=hf_token,
701
  hf_username=hf_username,
702
  )
703
- self._start_hub_artifact_collection(started)
704
  return started
705
- self._start_hub_artifact_collection(agent_session)
706
  if preload_sandbox:
707
  self._start_cpu_sandbox_preload(agent_session)
708
  logger.info("Restored session %s for user %s", session_id, owner or user_id)
@@ -785,7 +757,6 @@ class SessionManager:
785
  event_queue=event_queue,
786
  tool_router=tool_router,
787
  )
788
- self._start_hub_artifact_collection(agent_session)
789
  await self.persist_session_snapshot(agent_session, runtime_state="idle")
790
  self._start_cpu_sandbox_preload(agent_session)
791
 
 
12
 
13
  from agent.config import load_config
14
  from agent.core.agent_loop import process_submission
 
15
  from agent.core.session import Event, OpType, Session
16
  from agent.core.session_persistence import get_session_store
17
  from agent.core.tools import ToolRouter
 
135
  self.sessions: dict[str, AgentSession] = {}
136
  self._lock = asyncio.Lock()
137
  self.persistence_store = None
 
138
 
139
  async def start(self) -> None:
140
  """Start shared background resources."""
 
411
  session.sandbox_preload_cancel_event = None
412
  self._start_cpu_sandbox_preload(agent_session)
413
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
414
  async def _clear_persisted_sandbox_metadata(self, session_id: str) -> None:
415
  try:
416
  await self._store().update_session_fields(
 
567
  existing,
568
  preload_sandbox=preload_sandbox,
569
  )
 
570
  return existing
571
  return None
572
 
 
588
  existing,
589
  preload_sandbox=preload_sandbox,
590
  )
 
591
  return existing
592
  return None
593
 
 
674
  hf_token=hf_token,
675
  hf_username=hf_username,
676
  )
 
677
  return started
 
678
  if preload_sandbox:
679
  self._start_cpu_sandbox_preload(agent_session)
680
  logger.info("Restored session %s for user %s", session_id, owner or user_id)
 
757
  event_queue=event_queue,
758
  tool_router=tool_router,
759
  )
 
760
  await self.persist_session_snapshot(agent_session, runtime_state="idle")
761
  self._start_cpu_sandbox_preload(agent_session)
762
 
tests/unit/test_hub_artifacts.py CHANGED
@@ -1,4 +1,3 @@
1
- import asyncio
2
  import logging
3
  from types import SimpleNamespace
4
 
@@ -11,12 +10,10 @@ from agent.core.hub_artifacts import (
11
  artifact_collection_title,
12
  augment_repo_card_content,
13
  build_hub_artifact_sitecustomize,
14
- ensure_session_artifact_collection,
15
  is_known_hub_artifact,
16
  is_sandbox_hub_repo,
17
  register_hub_artifact,
18
  remember_hub_artifact,
19
- start_session_artifact_collection_task,
20
  wrap_shell_command_with_hub_artifact_bootstrap,
21
  )
22
  from agent.tools import local_tools, sandbox_tool
@@ -207,6 +204,7 @@ def test_register_hub_artifact_retries_after_partial_failure(monkeypatch):
207
  def add_to_collection(*args, **kwargs):
208
  nonlocal collection_attempts
209
  collection_attempts += 1
 
210
 
211
  monkeypatch.setattr(
212
  hub_artifacts,
@@ -238,6 +236,7 @@ def test_register_hub_artifact_retries_after_collection_failure(monkeypatch):
238
  collection_attempts += 1
239
  if collection_attempts == 1:
240
  raise RuntimeError("temporary collection failure")
 
241
 
242
  monkeypatch.setattr(hub_artifacts, "_update_repo_card", update_repo_card)
243
  monkeypatch.setattr(
@@ -271,63 +270,6 @@ def test_session_artifact_set_falls_back_when_session_rejects_attrs(caplog):
271
  assert "using process-local fallback state" in caplog.text
272
 
273
 
274
- @pytest.mark.asyncio
275
- async def test_ensure_session_artifact_collection_uses_user_token(monkeypatch):
276
- session = _session()
277
- calls = []
278
-
279
- class FakeApi:
280
- def __init__(self, token):
281
- self.token = token
282
-
283
- def fake_ensure_collection_slug(api, seen_session, **kwargs):
284
- calls.append((api.token, seen_session, kwargs))
285
- return "alice/ml-intern-artifacts"
286
-
287
- monkeypatch.setattr(hub_artifacts, "HfApi", FakeApi)
288
- monkeypatch.setattr(
289
- hub_artifacts,
290
- "_ensure_collection_slug",
291
- fake_ensure_collection_slug,
292
- )
293
-
294
- slug = await ensure_session_artifact_collection(session, token="hf-token")
295
-
296
- assert slug == "alice/ml-intern-artifacts"
297
- assert calls == [
298
- ("hf-token", session, {"token": "hf-token"}),
299
- ]
300
-
301
-
302
- @pytest.mark.asyncio
303
- async def test_start_session_artifact_collection_task_dedupes(monkeypatch):
304
- session = _session()
305
- calls = []
306
-
307
- async def fake_ensure_session_artifact_collection(seen_session, **kwargs):
308
- calls.append((seen_session, kwargs))
309
- await asyncio.sleep(0)
310
- return "alice/ml-intern-artifacts"
311
-
312
- monkeypatch.setattr(
313
- hub_artifacts,
314
- "ensure_session_artifact_collection",
315
- fake_ensure_session_artifact_collection,
316
- )
317
-
318
- task = start_session_artifact_collection_task(session, token="hf-token")
319
- second = start_session_artifact_collection_task(session, token="hf-token")
320
-
321
- assert task is not None
322
- assert second is task
323
- await task
324
- assert calls == [(session, {"token": "hf-token"})]
325
-
326
-
327
- def test_start_session_artifact_collection_task_skips_without_token():
328
- assert start_session_artifact_collection_task(_session()) is None
329
-
330
-
331
  @pytest.mark.asyncio
332
  async def test_hf_repo_git_create_repo_registers_artifact(monkeypatch):
333
  session = _session()
@@ -535,6 +477,78 @@ def test_sitecustomize_bootstrap_reuses_existing_collection_slug():
535
  )
536
 
537
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
538
  def test_sitecustomize_skips_sandbox_space_registration(monkeypatch):
539
  import huggingface_hub as hub
540
  from huggingface_hub import HfApi
 
 
1
  import logging
2
  from types import SimpleNamespace
3
 
 
10
  artifact_collection_title,
11
  augment_repo_card_content,
12
  build_hub_artifact_sitecustomize,
 
13
  is_known_hub_artifact,
14
  is_sandbox_hub_repo,
15
  register_hub_artifact,
16
  remember_hub_artifact,
 
17
  wrap_shell_command_with_hub_artifact_bootstrap,
18
  )
19
  from agent.tools import local_tools, sandbox_tool
 
204
  def add_to_collection(*args, **kwargs):
205
  nonlocal collection_attempts
206
  collection_attempts += 1
207
+ return True
208
 
209
  monkeypatch.setattr(
210
  hub_artifacts,
 
236
  collection_attempts += 1
237
  if collection_attempts == 1:
238
  raise RuntimeError("temporary collection failure")
239
+ return True
240
 
241
  monkeypatch.setattr(hub_artifacts, "_update_repo_card", update_repo_card)
242
  monkeypatch.setattr(
 
270
  assert "using process-local fallback state" in caplog.text
271
 
272
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
  @pytest.mark.asyncio
274
  async def test_hf_repo_git_create_repo_registers_artifact(monkeypatch):
275
  session = _session()
 
477
  )
478
 
479
 
480
+ def test_sitecustomize_caches_lazy_collection_slug_across_bootstraps(
481
+ monkeypatch,
482
+ tmp_path,
483
+ ):
484
+ import huggingface_hub as hub
485
+ from huggingface_hub import HfApi
486
+
487
+ readme_path = tmp_path / "README.md"
488
+ readme_path.write_text("# Existing Model\n", encoding="utf-8")
489
+ cache_path = tmp_path / "collection-slug.txt"
490
+ collection_slug = "alice/ml-intern-artifacts-2026-05-05-session-123"
491
+ uploads = []
492
+ downloads = []
493
+ collection_creates = []
494
+ collection_items = []
495
+
496
+ def fake_upload_file(self, **kwargs):
497
+ uploads.append(kwargs)
498
+ return SimpleNamespace()
499
+
500
+ def fake_hf_hub_download(*args, **kwargs):
501
+ downloads.append((args, kwargs))
502
+ return str(readme_path)
503
+
504
+ def fake_create_collection(self, **kwargs):
505
+ collection_creates.append(kwargs)
506
+ return SimpleNamespace(slug=collection_slug)
507
+
508
+ def fake_add_collection_item(self, **kwargs):
509
+ collection_items.append(kwargs)
510
+
511
+ monkeypatch.setenv("ML_INTERN_ARTIFACT_COLLECTION_CACHE", str(cache_path))
512
+ code = build_hub_artifact_sitecustomize(_session())
513
+
514
+ def install_fresh_bootstrap():
515
+ monkeypatch.setattr(HfApi, "upload_file", fake_upload_file)
516
+ monkeypatch.setattr(HfApi, "create_collection", fake_create_collection)
517
+ monkeypatch.setattr(HfApi, "add_collection_item", fake_add_collection_item)
518
+ monkeypatch.setattr(hub, "hf_hub_download", fake_hf_hub_download)
519
+ exec(code, {})
520
+ assert HfApi.upload_file is not fake_upload_file
521
+
522
+ install_fresh_bootstrap()
523
+ HfApi(token="hf-token").upload_file(
524
+ path_or_fileobj=b"weights",
525
+ path_in_repo="model.safetensors",
526
+ repo_id="alice/model-a",
527
+ repo_type="model",
528
+ token="hf-token",
529
+ )
530
+
531
+ install_fresh_bootstrap()
532
+ HfApi(token="hf-token").upload_file(
533
+ path_or_fileobj=b"weights",
534
+ path_in_repo="model.safetensors",
535
+ repo_id="alice/model-b",
536
+ repo_type="model",
537
+ token="hf-token",
538
+ )
539
+
540
+ assert cache_path.read_text(encoding="utf-8") == collection_slug
541
+ assert len(collection_creates) == 1
542
+ assert [item["item_id"] for item in collection_items] == [
543
+ "alice/model-a",
544
+ "alice/model-b",
545
+ ]
546
+ assert [download[1]["repo_id"] for download in downloads] == [
547
+ "alice/model-a",
548
+ "alice/model-b",
549
+ ]
550
+
551
+
552
  def test_sitecustomize_skips_sandbox_space_registration(monkeypatch):
553
  import huggingface_hub as hub
554
  from huggingface_hub import HfApi
tests/unit/test_session_manager_persistence.py CHANGED
@@ -425,32 +425,9 @@ async def test_create_session_schedules_cpu_sandbox_preload():
425
 
426
  assert scheduled == [session_id]
427
  assert session_id in manager.sessions
428
- finally:
429
- stop.set()
430
- await _cancel_runtime_tasks(manager)
431
-
432
-
433
- @pytest.mark.asyncio
434
- async def test_create_session_starts_hub_artifact_collection(monkeypatch):
435
- manager = _manager_with_store(NoopSessionStore())
436
- manager.enable_hub_artifact_collections = True
437
- stop = _install_fake_runtime(manager)
438
- started: list[tuple[str, str]] = []
439
-
440
- def fake_start_session_artifact_collection_task(session, **kwargs):
441
- started.append((session.session_id, kwargs["token"]))
442
- return None
443
-
444
- monkeypatch.setattr(
445
- "session_manager.start_session_artifact_collection_task",
446
- fake_start_session_artifact_collection_task,
447
- )
448
- manager._start_cpu_sandbox_preload = lambda _: None # type: ignore[method-assign]
449
-
450
- try:
451
- session_id = await manager.create_session(user_id="owner", hf_token="token")
452
-
453
- assert started == [(session_id, "token")]
454
  finally:
455
  stop.set()
456
  await _cancel_runtime_tasks(manager)
@@ -475,37 +452,8 @@ async def test_lazy_restore_schedules_cpu_sandbox_preload():
475
  assert restored is not None
476
  assert scheduled == ["persisted-session"]
477
  assert "persisted-session" in manager.sessions
478
- finally:
479
- stop.set()
480
- await _cancel_runtime_tasks(manager)
481
-
482
-
483
- @pytest.mark.asyncio
484
- async def test_lazy_restore_starts_hub_artifact_collection(monkeypatch):
485
- manager = _manager_with_store(RestoreStore())
486
- manager.enable_hub_artifact_collections = True
487
- stop = _install_fake_runtime(manager)
488
- started: list[tuple[str, str]] = []
489
-
490
- def fake_start_session_artifact_collection_task(session, **kwargs):
491
- started.append((session.session_id, kwargs["token"]))
492
- return None
493
-
494
- monkeypatch.setattr(
495
- "session_manager.start_session_artifact_collection_task",
496
- fake_start_session_artifact_collection_task,
497
- )
498
- manager._start_cpu_sandbox_preload = lambda _: None # type: ignore[method-assign]
499
-
500
- try:
501
- restored = await manager.ensure_session_loaded(
502
- "persisted-session",
503
- user_id="owner",
504
- hf_token="token",
505
- )
506
-
507
- assert restored is not None
508
- assert started == [("persisted-session", "token")]
509
  finally:
510
  stop.set()
511
  await _cancel_runtime_tasks(manager)
 
425
 
426
  assert scheduled == [session_id]
427
  assert session_id in manager.sessions
428
+ runtime_session = manager.sessions[session_id].session
429
+ assert not hasattr(runtime_session, "_ml_intern_artifact_collection_task")
430
+ assert not hasattr(runtime_session, "_ml_intern_artifact_collection_slug")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
431
  finally:
432
  stop.set()
433
  await _cancel_runtime_tasks(manager)
 
452
  assert restored is not None
453
  assert scheduled == ["persisted-session"]
454
  assert "persisted-session" in manager.sessions
455
+ assert not hasattr(restored.session, "_ml_intern_artifact_collection_task")
456
+ assert not hasattr(restored.session, "_ml_intern_artifact_collection_slug")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
457
  finally:
458
  stop.set()
459
  await _cancel_runtime_tasks(manager)