lewtun HF Staff OpenAI Codex commited on
Commit
4fc6e96
·
unverified ·
1 Parent(s): 754345f

Add ML Intern Hub artifact metadata (#225)

Browse files

* Add ML Intern Hub artifact metadata

Co-authored-by: OpenAI Codex <codex@openai.com>

* Extend Hub artifact hooks to sandbox bash

Co-authored-by: OpenAI Codex <codex@openai.com>

* Create Hub artifact collections at session start

Co-authored-by: OpenAI Codex <codex@openai.com>

* Address Hub artifact PR review comments

Co-authored-by: OpenAI Codex <codex@openai.com>

* Respect Hub collection title limits

Co-authored-by: OpenAI Codex <codex@openai.com>

* Shorten collection session UUID fragments

Co-authored-by: OpenAI Codex <codex@openai.com>

---------

Co-authored-by: OpenAI Codex <codex@openai.com>

agent/core/agent_loop.py CHANGED
@@ -26,6 +26,7 @@ from agent.core.cost_estimation import CostEstimate, estimate_tool_cost
26
  from agent.messaging.gateway import NotificationGateway
27
  from agent.core import telemetry
28
  from agent.core.doom_loop import check_for_doom_loop
 
29
  from agent.core.llm_params import _resolve_llm_params
30
  from agent.core.prompt_caching import with_prompt_caching
31
  from agent.core.session import Event, OpType, Session
@@ -1998,6 +1999,7 @@ async def submission_loop(
1998
  )
1999
  if session_holder is not None:
2000
  session_holder[0] = session
 
2001
  logger.info("Agent loop started")
2002
 
2003
  # Retry any failed uploads from previous sessions (fire-and-forget).
 
26
  from agent.messaging.gateway import NotificationGateway
27
  from agent.core import telemetry
28
  from agent.core.doom_loop import check_for_doom_loop
29
+ from agent.core.hub_artifacts import start_session_artifact_collection_task
30
  from agent.core.llm_params import _resolve_llm_params
31
  from agent.core.prompt_caching import with_prompt_caching
32
  from agent.core.session import Event, OpType, Session
 
1999
  )
2000
  if session_holder is not None:
2001
  session_holder[0] = session
2002
+ start_session_artifact_collection_task(session, token=hf_token)
2003
  logger.info("Agent loop started")
2004
 
2005
  # Retry any failed uploads from previous sessions (fire-and-forget).
agent/core/hub_artifacts.py ADDED
@@ -0,0 +1,765 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Best-effort Hub metadata for artifacts generated by ML Intern sessions."""
2
+
3
+ import asyncio
4
+ import base64
5
+ import logging
6
+ import re
7
+ import shlex
8
+ import tempfile
9
+ import textwrap
10
+ from datetime import datetime
11
+ from pathlib import Path
12
+ from typing import Any
13
+
14
+ from huggingface_hub import HfApi, hf_hub_download
15
+ from huggingface_hub.repocard import metadata_load, metadata_save
16
+ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ ML_INTERN_TAG = "ml-intern"
21
+ SUPPORTED_REPO_TYPES = {"model", "dataset", "space"}
22
+ PROVENANCE_MARKER = "<!-- ml-intern-provenance -->"
23
+ _COLLECTION_TITLE_PREFIX = "ml-intern-artifacts"
24
+ _COLLECTION_TITLE_MAX_LENGTH = 59
25
+ _UUID_SESSION_ID_RE = re.compile(
26
+ r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-"
27
+ r"[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$"
28
+ )
29
+ _KNOWN_ARTIFACTS_ATTR = "_ml_intern_known_hub_artifacts"
30
+ _REGISTERED_ARTIFACTS_ATTR = "_ml_intern_registered_hub_artifacts"
31
+ _COLLECTION_SLUG_ATTR = "_ml_intern_artifact_collection_slug"
32
+ _COLLECTION_TASK_ATTR = "_ml_intern_artifact_collection_task"
33
+ _SESSION_ARTIFACT_SET_FALLBACK: dict[tuple[int, str], set[str]] = {}
34
+ _USAGE_HEADING_RE = re.compile(
35
+ r"^#{2,6}\s+(usage|how to use|using this (model|dataset)|use this (model|dataset))\b",
36
+ re.IGNORECASE | re.MULTILINE,
37
+ )
38
+ _FRONT_MATTER_RE = re.compile(r"\A---\s*\n.*?\n---\s*\n?", re.DOTALL)
39
+
40
+
41
+ def _safe_session_id(session: Any) -> str:
42
+ raw = str(getattr(session, "session_id", "") or "unknown-session")
43
+ safe = re.sub(r"[^A-Za-z0-9._-]+", "-", raw).strip("-")
44
+ return safe or "unknown-session"
45
+
46
+
47
+ def session_artifact_date(session: Any) -> str:
48
+ """Return the YYYY-MM-DD partition date for a session."""
49
+ raw = getattr(session, "session_start_time", None)
50
+ if raw:
51
+ try:
52
+ return datetime.fromisoformat(str(raw).replace("Z", "+00:00")).strftime(
53
+ "%Y-%m-%d"
54
+ )
55
+ except ValueError:
56
+ logger.debug("Could not parse session_start_time=%r", raw)
57
+ return datetime.utcnow().strftime("%Y-%m-%d")
58
+
59
+
60
+ def _collection_session_id_fragment(session: Any) -> str:
61
+ safe_id = _safe_session_id(session)
62
+ if _UUID_SESSION_ID_RE.match(safe_id):
63
+ return safe_id[:8]
64
+ stem = f"{_COLLECTION_TITLE_PREFIX}-{session_artifact_date(session)}-"
65
+ max_id_length = max(1, _COLLECTION_TITLE_MAX_LENGTH - len(stem))
66
+ if len(safe_id) <= max_id_length:
67
+ return safe_id
68
+ return safe_id[:max_id_length].rstrip("-._") or safe_id[:max_id_length]
69
+
70
+
71
+ def artifact_collection_title(session: Any) -> str:
72
+ return (
73
+ f"{_COLLECTION_TITLE_PREFIX}-{session_artifact_date(session)}-"
74
+ f"{_collection_session_id_fragment(session)}"
75
+ )
76
+
77
+
78
+ def _artifact_key(repo_id: str, repo_type: str | None) -> str:
79
+ return f"{repo_type or 'model'}:{repo_id}"
80
+
81
+
82
+ def _session_artifact_set(session: Any, attr: str) -> set[str]:
83
+ current = getattr(session, attr, None)
84
+ if isinstance(current, set):
85
+ return current
86
+ current = set()
87
+ try:
88
+ setattr(session, attr, current)
89
+ except Exception:
90
+ logger.warning(
91
+ "Could not attach %s to session; using process-local fallback state",
92
+ attr,
93
+ )
94
+ return _SESSION_ARTIFACT_SET_FALLBACK.setdefault((id(session), attr), set())
95
+ return current
96
+
97
+
98
+ def remember_hub_artifact(session: Any, repo_id: str, repo_type: str | None) -> None:
99
+ if session is None or not repo_id:
100
+ return
101
+ _session_artifact_set(session, _KNOWN_ARTIFACTS_ATTR).add(
102
+ _artifact_key(repo_id, repo_type)
103
+ )
104
+
105
+
106
+ def is_known_hub_artifact(session: Any, repo_id: str, repo_type: str | None) -> bool:
107
+ if session is None or not repo_id:
108
+ return False
109
+ return _artifact_key(repo_id, repo_type) in _session_artifact_set(
110
+ session, _KNOWN_ARTIFACTS_ATTR
111
+ )
112
+
113
+
114
+ def _merge_tags(metadata: dict[str, Any], tag: str = ML_INTERN_TAG) -> dict[str, Any]:
115
+ merged = dict(metadata)
116
+ raw_tags = merged.get("tags")
117
+ if raw_tags is None:
118
+ tags: list[str] = []
119
+ elif isinstance(raw_tags, str):
120
+ tags = [raw_tags]
121
+ elif isinstance(raw_tags, list):
122
+ tags = [str(item) for item in raw_tags]
123
+ else:
124
+ tags = [str(raw_tags)]
125
+
126
+ if tag not in tags:
127
+ tags.append(tag)
128
+ merged["tags"] = tags
129
+ return merged
130
+
131
+
132
+ def _metadata_from_content(content: str) -> dict[str, Any]:
133
+ with tempfile.TemporaryDirectory() as tmp_dir:
134
+ path = Path(tmp_dir) / "README.md"
135
+ path.write_text(content, encoding="utf-8")
136
+ return metadata_load(path) or {}
137
+
138
+
139
+ def _content_with_metadata(content: str, metadata: dict[str, Any]) -> str:
140
+ with tempfile.TemporaryDirectory() as tmp_dir:
141
+ path = Path(tmp_dir) / "README.md"
142
+ path.write_text(content, encoding="utf-8")
143
+ metadata_save(path, metadata)
144
+ return path.read_text(encoding="utf-8")
145
+
146
+
147
+ def _body_without_metadata(content: str) -> str:
148
+ return _FRONT_MATTER_RE.sub("", content, count=1).strip()
149
+
150
+
151
+ def _append_section(content: str, section: str) -> str:
152
+ base = content.rstrip()
153
+ if base:
154
+ return f"{base}\n\n{section.strip()}\n"
155
+ return f"{section.strip()}\n"
156
+
157
+
158
+ def _provenance_section(repo_type: str) -> str:
159
+ label = {"model": "model", "dataset": "dataset"}.get(repo_type, "Hub")
160
+ return f"""{PROVENANCE_MARKER}
161
+ ## Generated by ML Intern
162
+
163
+ This {label} repository was generated by [ML Intern](https://github.com/huggingface/ml-intern), an agent for machine learning research and development on the Hugging Face Hub.
164
+
165
+ - Try ML Intern: https://smolagents-ml-intern.hf.space
166
+ - Source code: https://github.com/huggingface/ml-intern
167
+ """
168
+
169
+
170
+ def _usage_section(repo_id: str, repo_type: str) -> str:
171
+ if repo_type == "dataset":
172
+ return f"""## Usage
173
+
174
+ ```python
175
+ from datasets import load_dataset
176
+
177
+ dataset = load_dataset("{repo_id}")
178
+ ```
179
+ """
180
+
181
+ return f"""## Usage
182
+
183
+ ```python
184
+ from transformers import AutoModelForCausalLM, AutoTokenizer
185
+
186
+ model_id = "{repo_id}"
187
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
188
+ model = AutoModelForCausalLM.from_pretrained(model_id)
189
+ ```
190
+
191
+ For non-causal architectures, replace `AutoModelForCausalLM` with the appropriate `AutoModel` class.
192
+ """
193
+
194
+
195
+ def augment_repo_card_content(
196
+ content: str | None,
197
+ repo_id: str,
198
+ repo_type: str = "model",
199
+ *,
200
+ extra_metadata: dict[str, Any] | None = None,
201
+ ) -> str:
202
+ """Return README content with ML Intern metadata and provenance added."""
203
+ repo_type = repo_type or "model"
204
+ content = content or ""
205
+ metadata = _metadata_from_content(content)
206
+ if extra_metadata:
207
+ metadata = {**extra_metadata, **metadata}
208
+ metadata = _merge_tags(metadata)
209
+ updated = _content_with_metadata(content, metadata)
210
+
211
+ if not _body_without_metadata(updated):
212
+ updated = _append_section(updated, f"# {repo_id}")
213
+
214
+ if repo_type in {"model", "dataset"} and PROVENANCE_MARKER not in updated:
215
+ updated = _append_section(updated, _provenance_section(repo_type))
216
+ if not _USAGE_HEADING_RE.search(content):
217
+ updated = _append_section(updated, _usage_section(repo_id, repo_type))
218
+
219
+ return updated
220
+
221
+
222
+ def _read_remote_readme(
223
+ api: Any,
224
+ repo_id: str,
225
+ repo_type: str,
226
+ *,
227
+ token: str | bool | None = None,
228
+ ) -> str:
229
+ token_value = token if token is not None else getattr(api, "token", None)
230
+ try:
231
+ readme_path = hf_hub_download(
232
+ repo_id=repo_id,
233
+ filename="README.md",
234
+ repo_type=repo_type,
235
+ token=token_value,
236
+ )
237
+ except (EntryNotFoundError, RepositoryNotFoundError):
238
+ return ""
239
+ return Path(readme_path).read_text(encoding="utf-8")
240
+
241
+
242
+ def _update_repo_card(
243
+ api: Any,
244
+ repo_id: str,
245
+ repo_type: str,
246
+ *,
247
+ token: str | bool | None = None,
248
+ extra_metadata: dict[str, Any] | None = None,
249
+ ) -> None:
250
+ current = _read_remote_readme(api, repo_id, repo_type, token=token)
251
+ updated = augment_repo_card_content(
252
+ current,
253
+ repo_id,
254
+ repo_type,
255
+ extra_metadata=extra_metadata,
256
+ )
257
+ if updated == current:
258
+ return
259
+ api.upload_file(
260
+ path_or_fileobj=updated.encode("utf-8"),
261
+ path_in_repo="README.md",
262
+ repo_id=repo_id,
263
+ repo_type=repo_type,
264
+ token=token,
265
+ commit_message="Update ML Intern artifact metadata",
266
+ )
267
+
268
+
269
+ def _ensure_collection_slug(
270
+ api: Any,
271
+ session: Any,
272
+ *,
273
+ token: str | bool | None = None,
274
+ ) -> str | None:
275
+ slug = getattr(session, _COLLECTION_SLUG_ATTR, None)
276
+ if slug:
277
+ return slug
278
+
279
+ title = artifact_collection_title(session)
280
+ collection = api.create_collection(
281
+ title=title,
282
+ description=(
283
+ f"Artifacts generated by ML Intern session {_safe_session_id(session)} "
284
+ f"on {session_artifact_date(session)}."
285
+ ),
286
+ private=True,
287
+ exists_ok=True,
288
+ token=token,
289
+ )
290
+ slug = getattr(collection, "slug", None)
291
+ if slug:
292
+ setattr(session, _COLLECTION_SLUG_ATTR, slug)
293
+ return slug
294
+
295
+
296
+ async def ensure_session_artifact_collection(
297
+ session: Any,
298
+ *,
299
+ token: str | bool | None = None,
300
+ ) -> str | None:
301
+ """Create/cache the per-session artifact collection without raising."""
302
+ if session is None or not getattr(session, "session_id", None):
303
+ return None
304
+ token_value = token if token is not None else getattr(session, "hf_token", None)
305
+ if not token_value:
306
+ return None
307
+
308
+ try:
309
+ api = HfApi(token=token_value)
310
+ return await asyncio.to_thread(
311
+ _ensure_collection_slug,
312
+ api,
313
+ session,
314
+ token=token_value,
315
+ )
316
+ except Exception as e:
317
+ logger.warning(
318
+ "ML Intern session collection creation failed for %s: %s",
319
+ _safe_session_id(session),
320
+ e,
321
+ )
322
+ return None
323
+
324
+
325
+ def start_session_artifact_collection_task(
326
+ session: Any,
327
+ *,
328
+ token: str | bool | None = None,
329
+ ) -> asyncio.Task | None:
330
+ """Schedule best-effort collection creation for a newly started session."""
331
+ if session is None or not getattr(session, "session_id", None):
332
+ return None
333
+ if getattr(session, _COLLECTION_SLUG_ATTR, None):
334
+ return None
335
+
336
+ token_value = token if token is not None else getattr(session, "hf_token", None)
337
+ if not token_value:
338
+ return None
339
+
340
+ existing = getattr(session, _COLLECTION_TASK_ATTR, None)
341
+ if isinstance(existing, asyncio.Task) and not existing.done():
342
+ return existing
343
+
344
+ try:
345
+ loop = asyncio.get_running_loop()
346
+ except RuntimeError:
347
+ return None
348
+
349
+ async def _run() -> None:
350
+ await ensure_session_artifact_collection(session, token=token_value)
351
+
352
+ task = loop.create_task(_run())
353
+ try:
354
+ setattr(session, _COLLECTION_TASK_ATTR, task)
355
+ except Exception:
356
+ logger.debug("Could not attach ML Intern collection task to session")
357
+ return task
358
+
359
+
360
+ def _add_to_collection(
361
+ api: Any,
362
+ session: Any,
363
+ repo_id: str,
364
+ repo_type: str,
365
+ *,
366
+ token: str | bool | None = None,
367
+ ) -> None:
368
+ slug = _ensure_collection_slug(api, session, token=token)
369
+ if not slug:
370
+ return
371
+ api.add_collection_item(
372
+ collection_slug=slug,
373
+ item_id=repo_id,
374
+ item_type=repo_type,
375
+ note=(
376
+ f"Generated by ML Intern session {_safe_session_id(session)} "
377
+ f"on {session_artifact_date(session)}."
378
+ ),
379
+ exists_ok=True,
380
+ token=token,
381
+ )
382
+
383
+
384
+ def register_hub_artifact(
385
+ api: Any,
386
+ repo_id: str,
387
+ repo_type: str = "model",
388
+ *,
389
+ session: Any = None,
390
+ token: str | bool | None = None,
391
+ extra_metadata: dict[str, Any] | None = None,
392
+ force: bool = False,
393
+ ) -> bool:
394
+ """Tag, card, and collection-register a Hub artifact without raising."""
395
+ if session is None or not repo_id:
396
+ return False
397
+ repo_type = repo_type or "model"
398
+ if repo_type not in SUPPORTED_REPO_TYPES:
399
+ return False
400
+
401
+ key = _artifact_key(repo_id, repo_type)
402
+ remember_hub_artifact(session, repo_id, repo_type)
403
+ registered = _session_artifact_set(session, _REGISTERED_ARTIFACTS_ATTR)
404
+ if key in registered and not force:
405
+ return True
406
+
407
+ token_value = token if token is not None else getattr(api, "token", None)
408
+ card_updated = False
409
+ collection_updated = False
410
+ try:
411
+ _update_repo_card(
412
+ api,
413
+ repo_id,
414
+ repo_type,
415
+ token=token_value,
416
+ extra_metadata=extra_metadata,
417
+ )
418
+ card_updated = True
419
+ except Exception as e:
420
+ logger.debug("ML Intern repo-card update failed for %s: %s", repo_id, e)
421
+
422
+ try:
423
+ _add_to_collection(api, session, repo_id, repo_type, token=token_value)
424
+ collection_updated = True
425
+ except Exception as e:
426
+ logger.debug("ML Intern collection update failed for %s: %s", repo_id, e)
427
+
428
+ if card_updated and collection_updated:
429
+ registered.add(key)
430
+ return True
431
+ return False
432
+
433
+
434
+ def build_hub_artifact_sitecustomize(session: Any) -> str:
435
+ """Build standalone sitecustomize.py code for HF Jobs Python processes."""
436
+ if session is None or not getattr(session, "session_id", None):
437
+ return ""
438
+
439
+ session_id = _safe_session_id(session)
440
+ session_date = session_artifact_date(session)
441
+ collection_title = artifact_collection_title(session)
442
+ collection_slug = getattr(session, _COLLECTION_SLUG_ATTR, None)
443
+
444
+ return (
445
+ textwrap.dedent(
446
+ f"""
447
+ # Auto-generated by ML Intern. Best-effort Hub artifact metadata only.
448
+ def _install_ml_intern_artifact_hooks():
449
+ import os
450
+ import re
451
+ import tempfile
452
+ from pathlib import Path
453
+
454
+ try:
455
+ import huggingface_hub as _hub
456
+ from huggingface_hub import HfApi, hf_hub_download
457
+ from huggingface_hub.repocard import metadata_load, metadata_save
458
+ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
459
+ except Exception:
460
+ return
461
+
462
+ session_id = {session_id!r}
463
+ session_date = {session_date!r}
464
+ collection_title = {collection_title!r}
465
+ tag = {ML_INTERN_TAG!r}
466
+ marker = {PROVENANCE_MARKER!r}
467
+ supported = {sorted(SUPPORTED_REPO_TYPES)!r}
468
+ registering = False
469
+ collection_slug = {collection_slug!r}
470
+ registered = set()
471
+ usage_re = re.compile(
472
+ r"^#{{2,6}}\\s+(usage|how to use|using this (model|dataset)|use this (model|dataset))\\b",
473
+ re.IGNORECASE | re.MULTILINE,
474
+ )
475
+ front_matter_re = re.compile(r"\\A---\\s*\\n.*?\\n---\\s*\\n?", re.DOTALL)
476
+
477
+ def _token(value=None, api=None):
478
+ if isinstance(value, str) and value:
479
+ return value
480
+ api_token = getattr(api, "token", None)
481
+ if isinstance(api_token, str) and api_token:
482
+ return api_token
483
+ return (
484
+ os.environ.get("HF_TOKEN")
485
+ or os.environ.get("HUGGINGFACE_HUB_TOKEN")
486
+ or None
487
+ )
488
+
489
+ def _merge_tags(metadata):
490
+ metadata = dict(metadata or {{}})
491
+ raw_tags = metadata.get("tags")
492
+ if raw_tags is None:
493
+ tags = []
494
+ elif isinstance(raw_tags, str):
495
+ tags = [raw_tags]
496
+ elif isinstance(raw_tags, list):
497
+ tags = [str(item) for item in raw_tags]
498
+ else:
499
+ tags = [str(raw_tags)]
500
+ if tag not in tags:
501
+ tags.append(tag)
502
+ metadata["tags"] = tags
503
+ return metadata
504
+
505
+ def _metadata_from_content(content):
506
+ with tempfile.TemporaryDirectory() as tmp_dir:
507
+ path = Path(tmp_dir) / "README.md"
508
+ path.write_text(content or "", encoding="utf-8")
509
+ return metadata_load(path) or {{}}
510
+
511
+ def _content_with_metadata(content, metadata):
512
+ with tempfile.TemporaryDirectory() as tmp_dir:
513
+ path = Path(tmp_dir) / "README.md"
514
+ path.write_text(content or "", encoding="utf-8")
515
+ metadata_save(path, metadata)
516
+ return path.read_text(encoding="utf-8")
517
+
518
+ def _body_without_metadata(content):
519
+ return front_matter_re.sub("", content or "", count=1).strip()
520
+
521
+ def _append_section(content, section):
522
+ base = (content or "").rstrip()
523
+ if base:
524
+ return base + "\\n\\n" + section.strip() + "\\n"
525
+ return section.strip() + "\\n"
526
+
527
+ def _provenance(repo_type):
528
+ label = {{"model": "model", "dataset": "dataset"}}.get(
529
+ repo_type, "Hub"
530
+ )
531
+ return (
532
+ marker
533
+ + "\\n## Generated by ML Intern\\n\\n"
534
+ + f"This {{label}} repository was generated by [ML Intern](https://github.com/huggingface/ml-intern), an agent for machine learning research and development on the Hugging Face Hub.\\n\\n"
535
+ + "- Try ML Intern: https://smolagents-ml-intern.hf.space\\n"
536
+ + "- Source code: https://github.com/huggingface/ml-intern\\n"
537
+ )
538
+
539
+ def _usage(repo_id, repo_type):
540
+ if repo_type == "dataset":
541
+ return (
542
+ "## Usage\\n\\n"
543
+ "```python\\n"
544
+ "from datasets import load_dataset\\n\\n"
545
+ f"dataset = load_dataset({{repo_id!r}})\\n"
546
+ "```\\n"
547
+ )
548
+ return (
549
+ "## Usage\\n\\n"
550
+ "```python\\n"
551
+ "from transformers import AutoModelForCausalLM, AutoTokenizer\\n\\n"
552
+ f"model_id = {{repo_id!r}}\\n"
553
+ "tokenizer = AutoTokenizer.from_pretrained(model_id)\\n"
554
+ "model = AutoModelForCausalLM.from_pretrained(model_id)\\n"
555
+ "```\\n\\n"
556
+ "For non-causal architectures, replace `AutoModelForCausalLM` with the appropriate `AutoModel` class.\\n"
557
+ )
558
+
559
+ def _augment(content, repo_id, repo_type, extra_metadata=None):
560
+ metadata = _metadata_from_content(content or "")
561
+ if extra_metadata:
562
+ metadata = {{**extra_metadata, **metadata}}
563
+ updated = _content_with_metadata(content or "", _merge_tags(metadata))
564
+ if not _body_without_metadata(updated):
565
+ updated = _append_section(updated, f"# {{repo_id}}")
566
+ if repo_type in {{"model", "dataset"}} and marker not in updated:
567
+ updated = _append_section(updated, _provenance(repo_type))
568
+ if not usage_re.search(content or ""):
569
+ updated = _append_section(updated, _usage(repo_id, repo_type))
570
+ return updated
571
+
572
+ def _readme(api, repo_id, repo_type, token_value):
573
+ try:
574
+ path = hf_hub_download(
575
+ repo_id=repo_id,
576
+ filename="README.md",
577
+ repo_type=repo_type,
578
+ token=token_value,
579
+ )
580
+ except (EntryNotFoundError, RepositoryNotFoundError):
581
+ return ""
582
+ return Path(path).read_text(encoding="utf-8")
583
+
584
+ def _ensure_collection(api, token_value):
585
+ nonlocal collection_slug
586
+ if collection_slug:
587
+ return collection_slug
588
+ collection = api.create_collection(
589
+ title=collection_title,
590
+ description=(
591
+ f"Artifacts generated by ML Intern session {{session_id}} "
592
+ f"on {{session_date}}."
593
+ ),
594
+ private=True,
595
+ exists_ok=True,
596
+ token=token_value,
597
+ )
598
+ collection_slug = getattr(collection, "slug", None)
599
+ return collection_slug
600
+
601
+ def _register(
602
+ repo_id,
603
+ repo_type="model",
604
+ token_value=None,
605
+ extra_metadata=None,
606
+ force=False,
607
+ ):
608
+ nonlocal registering
609
+ if registering or not repo_id:
610
+ return
611
+ repo_type = repo_type or "model"
612
+ if repo_type not in supported:
613
+ return
614
+ key = f"{{repo_type}}:{{repo_id}}"
615
+ if key in registered and not force:
616
+ return
617
+ registering = True
618
+ try:
619
+ token_value = _token(token_value)
620
+ api = HfApi(token=token_value)
621
+ try:
622
+ current = _readme(api, repo_id, repo_type, token_value)
623
+ updated = _augment(
624
+ current, repo_id, repo_type, extra_metadata=extra_metadata
625
+ )
626
+ if updated != current:
627
+ _original_upload_file(
628
+ api,
629
+ path_or_fileobj=updated.encode("utf-8"),
630
+ path_in_repo="README.md",
631
+ repo_id=repo_id,
632
+ repo_type=repo_type,
633
+ token=token_value,
634
+ commit_message="Update ML Intern artifact metadata",
635
+ )
636
+ except Exception:
637
+ pass
638
+ try:
639
+ slug = _ensure_collection(api, token_value)
640
+ if slug:
641
+ api.add_collection_item(
642
+ collection_slug=slug,
643
+ item_id=repo_id,
644
+ item_type=repo_type,
645
+ note=(
646
+ f"Generated by ML Intern session {{session_id}} "
647
+ f"on {{session_date}}."
648
+ ),
649
+ exists_ok=True,
650
+ token=token_value,
651
+ )
652
+ except Exception:
653
+ pass
654
+ registered.add(key)
655
+ finally:
656
+ registering = False
657
+
658
+ _original_create_repo = HfApi.create_repo
659
+ _original_upload_file = HfApi.upload_file
660
+ _original_upload_folder = getattr(HfApi, "upload_folder", None)
661
+ _original_create_commit = getattr(HfApi, "create_commit", None)
662
+
663
+ def _repo_id(args, kwargs):
664
+ return kwargs.get("repo_id") or (args[0] if args else None)
665
+
666
+ def _repo_type(kwargs):
667
+ return kwargs.get("repo_type") or "model"
668
+
669
+ def _patched_create_repo(self, *args, **kwargs):
670
+ result = _original_create_repo(self, *args, **kwargs)
671
+ repo_id = _repo_id(args, kwargs)
672
+ repo_type = _repo_type(kwargs)
673
+ extra = None
674
+ if repo_type == "space" and kwargs.get("space_sdk"):
675
+ extra = {{"sdk": kwargs.get("space_sdk")}}
676
+ _register(repo_id, repo_type, _token(kwargs.get("token"), self), extra)
677
+ return result
678
+
679
+ def _patched_upload_file(self, *args, **kwargs):
680
+ result = _original_upload_file(self, *args, **kwargs)
681
+ if not kwargs.get("create_pr"):
682
+ force = kwargs.get("path_in_repo") == "README.md"
683
+ _register(
684
+ kwargs.get("repo_id"),
685
+ _repo_type(kwargs),
686
+ _token(kwargs.get("token"), self),
687
+ force=force,
688
+ )
689
+ return result
690
+
691
+ def _patched_upload_folder(self, *args, **kwargs):
692
+ result = _original_upload_folder(self, *args, **kwargs)
693
+ if not kwargs.get("create_pr"):
694
+ _register(
695
+ kwargs.get("repo_id"),
696
+ _repo_type(kwargs),
697
+ _token(kwargs.get("token"), self),
698
+ force=True,
699
+ )
700
+ return result
701
+
702
+ def _patched_create_commit(self, *args, **kwargs):
703
+ result = _original_create_commit(self, *args, **kwargs)
704
+ if not kwargs.get("create_pr"):
705
+ _register(
706
+ _repo_id(args, kwargs),
707
+ _repo_type(kwargs),
708
+ _token(kwargs.get("token"), self),
709
+ force=True,
710
+ )
711
+ return result
712
+
713
+ HfApi.create_repo = _patched_create_repo
714
+ HfApi.upload_file = _patched_upload_file
715
+ if _original_upload_folder is not None:
716
+ HfApi.upload_folder = _patched_upload_folder
717
+ if _original_create_commit is not None:
718
+ HfApi.create_commit = _patched_create_commit
719
+
720
+ def _patch_module_func(name, method_name):
721
+ original = getattr(_hub, name, None)
722
+ if original is None:
723
+ return
724
+ method = getattr(HfApi, method_name)
725
+
726
+ def _patched(*args, **kwargs):
727
+ api = HfApi(token=_token(kwargs.get("token")))
728
+ return method(api, *args, **kwargs)
729
+
730
+ setattr(_hub, name, _patched)
731
+
732
+ _patch_module_func("create_repo", "create_repo")
733
+ _patch_module_func("upload_file", "upload_file")
734
+ if _original_upload_folder is not None:
735
+ _patch_module_func("upload_folder", "upload_folder")
736
+ if _original_create_commit is not None:
737
+ _patch_module_func("create_commit", "create_commit")
738
+
739
+ try:
740
+ _install_ml_intern_artifact_hooks()
741
+ except Exception:
742
+ pass
743
+ """
744
+ ).strip()
745
+ + "\n"
746
+ )
747
+
748
+
749
+ def wrap_shell_command_with_hub_artifact_bootstrap(
750
+ command: str,
751
+ session: Any,
752
+ ) -> str:
753
+ """Prefix a shell command so child Python processes load Hub hooks."""
754
+ sitecustomize = build_hub_artifact_sitecustomize(session)
755
+ if not sitecustomize or not command:
756
+ return command
757
+
758
+ encoded = base64.b64encode(sitecustomize.encode("utf-8")).decode("ascii")
759
+ bootstrap = (
760
+ '_ml_intern_artifacts_dir="$(mktemp -d 2>/dev/null)" '
761
+ f"&& printf %s {shlex.quote(encoded)} | base64 -d "
762
+ '> "$_ml_intern_artifacts_dir/sitecustomize.py" '
763
+ '&& export PYTHONPATH="$_ml_intern_artifacts_dir${PYTHONPATH:+:$PYTHONPATH}"'
764
+ )
765
+ return f"{bootstrap}; {command}"
agent/tools/hf_repo_files_tool.py CHANGED
@@ -10,6 +10,7 @@ from typing import Any, Dict, Literal, Optional
10
  from huggingface_hub import HfApi, hf_hub_download
11
  from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
12
 
 
13
  from agent.tools.types import ToolResult
14
 
15
  OperationType = Literal["list", "read", "upload", "delete"]
@@ -39,8 +40,9 @@ def _format_size(size_bytes: int) -> str:
39
  class HfRepoFilesTool:
40
  """Tool for file operations on HF repos."""
41
 
42
- def __init__(self, hf_token: Optional[str] = None):
43
  self.api = HfApi(token=hf_token)
 
44
 
45
  async def execute(self, args: Dict[str, Any]) -> ToolResult:
46
  """Execute the specified operation."""
@@ -214,6 +216,16 @@ class HfRepoFilesTool:
214
  create_pr=create_pr,
215
  )
216
 
 
 
 
 
 
 
 
 
 
 
217
  url = _build_repo_url(repo_id, repo_type)
218
  if create_pr and hasattr(result, "pr_url"):
219
  response = f"**Uploaded as PR**\n{result.pr_url}"
@@ -343,7 +355,7 @@ async def hf_repo_files_handler(
343
  """Handler for agent tool router."""
344
  try:
345
  hf_token = session.hf_token if session else None
346
- tool = HfRepoFilesTool(hf_token=hf_token)
347
  result = await tool.execute(arguments)
348
  return result["formatted"], not result.get("isError", False)
349
  except Exception as e:
 
10
  from huggingface_hub import HfApi, hf_hub_download
11
  from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
12
 
13
+ from agent.core.hub_artifacts import is_known_hub_artifact, register_hub_artifact
14
  from agent.tools.types import ToolResult
15
 
16
  OperationType = Literal["list", "read", "upload", "delete"]
 
40
  class HfRepoFilesTool:
41
  """Tool for file operations on HF repos."""
42
 
43
+ def __init__(self, hf_token: Optional[str] = None, session: Any = None):
44
  self.api = HfApi(token=hf_token)
45
+ self.session = session
46
 
47
  async def execute(self, args: Dict[str, Any]) -> ToolResult:
48
  """Execute the specified operation."""
 
216
  create_pr=create_pr,
217
  )
218
 
219
+ if not create_pr and is_known_hub_artifact(self.session, repo_id, repo_type):
220
+ await _async_call(
221
+ register_hub_artifact,
222
+ self.api,
223
+ repo_id,
224
+ repo_type,
225
+ session=self.session,
226
+ force=path == "README.md",
227
+ )
228
+
229
  url = _build_repo_url(repo_id, repo_type)
230
  if create_pr and hasattr(result, "pr_url"):
231
  response = f"**Uploaded as PR**\n{result.pr_url}"
 
355
  """Handler for agent tool router."""
356
  try:
357
  hf_token = session.hf_token if session else None
358
+ tool = HfRepoFilesTool(hf_token=hf_token, session=session)
359
  result = await tool.execute(arguments)
360
  return result["formatted"], not result.get("isError", False)
361
  except Exception as e:
agent/tools/hf_repo_git_tool.py CHANGED
@@ -10,6 +10,7 @@ from typing import Any, Dict, Literal, Optional
10
  from huggingface_hub import HfApi
11
  from huggingface_hub.utils import RepositoryNotFoundError
12
 
 
13
  from agent.tools.types import ToolResult
14
 
15
  OperationType = Literal[
@@ -45,8 +46,9 @@ def _build_repo_url(repo_id: str, repo_type: str = "model") -> str:
45
  class HfRepoGitTool:
46
  """Tool for git-like operations on HF repos."""
47
 
48
- def __init__(self, hf_token: Optional[str] = None):
49
  self.api = HfApi(token=hf_token)
 
50
 
51
  async def execute(self, args: Dict[str, Any]) -> ToolResult:
52
  """Execute the specified operation."""
@@ -552,6 +554,17 @@ class HfRepoGitTool:
552
  kwargs["space_sdk"] = space_sdk
553
 
554
  result = await _async_call(self.api.create_repo, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
555
 
556
  return {
557
  "formatted": f"**Repository created:** {repo_id}\n**Private:** {private}\n{result}",
@@ -747,7 +760,7 @@ async def hf_repo_git_handler(
747
  """Handler for agent tool router."""
748
  try:
749
  hf_token = session.hf_token if session else None
750
- tool = HfRepoGitTool(hf_token=hf_token)
751
  result = await tool.execute(arguments)
752
  return result["formatted"], not result.get("isError", False)
753
  except Exception as e:
 
10
  from huggingface_hub import HfApi
11
  from huggingface_hub.utils import RepositoryNotFoundError
12
 
13
+ from agent.core.hub_artifacts import register_hub_artifact
14
  from agent.tools.types import ToolResult
15
 
16
  OperationType = Literal[
 
46
  class HfRepoGitTool:
47
  """Tool for git-like operations on HF repos."""
48
 
49
+ def __init__(self, hf_token: Optional[str] = None, session: Any = None):
50
  self.api = HfApi(token=hf_token)
51
+ self.session = session
52
 
53
  async def execute(self, args: Dict[str, Any]) -> ToolResult:
54
  """Execute the specified operation."""
 
554
  kwargs["space_sdk"] = space_sdk
555
 
556
  result = await _async_call(self.api.create_repo, **kwargs)
557
+ extra_metadata = None
558
+ if repo_type == "space" and space_sdk:
559
+ extra_metadata = {"sdk": space_sdk}
560
+ await _async_call(
561
+ register_hub_artifact,
562
+ self.api,
563
+ repo_id,
564
+ repo_type,
565
+ session=self.session,
566
+ extra_metadata=extra_metadata,
567
+ )
568
 
569
  return {
570
  "formatted": f"**Repository created:** {repo_id}\n**Private:** {private}\n{result}",
 
760
  """Handler for agent tool router."""
761
  try:
762
  hf_token = session.hf_token if session else None
763
+ tool = HfRepoGitTool(hf_token=hf_token, session=session)
764
  result = await tool.execute(arguments)
765
  return result["formatted"], not result.get("isError", False)
766
  except Exception as e:
agent/tools/jobs_tool.py CHANGED
@@ -9,6 +9,7 @@ import base64
9
  import http.client
10
  import logging
11
  import re
 
12
  from typing import Any, Awaitable, Callable, Dict, Literal, Optional
13
 
14
  import httpx
@@ -20,6 +21,7 @@ from agent.core.hf_access import (
20
  is_billing_error,
21
  resolve_jobs_namespace,
22
  )
 
23
  from agent.core.session import Event
24
  from agent.tools.trackio_seed import ensure_trackio_dashboard
25
  from agent.tools.types import ToolResult
@@ -237,6 +239,26 @@ def _resolve_uv_command(
237
  return _build_uv_command(script, with_deps, python, script_args)
238
 
239
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
  async def _async_call(func, *args, **kwargs):
241
  """Wrap synchronous HfApi calls for async context"""
242
  return await asyncio.to_thread(func, *args, **kwargs)
@@ -560,6 +582,8 @@ class HfJobsTool:
560
  image = args.get("image", "python:3.12")
561
  job_type = "Docker"
562
 
 
 
563
  # Run the job
564
  flavor = args.get("hardware_flavor", "cpu-basic")
565
  timeout_str = args.get("timeout", "30m")
@@ -912,6 +936,8 @@ To verify, call this tool with `{{"operation": "inspect", "job_id": "{job_id}"}}
912
  image = args.get("image", "python:3.12")
913
  job_type = "Docker"
914
 
 
 
915
  # Create scheduled job
916
  scheduled_job = await _async_call(
917
  self.api.create_scheduled_job,
 
9
  import http.client
10
  import logging
11
  import re
12
+ import shlex
13
  from typing import Any, Awaitable, Callable, Dict, Literal, Optional
14
 
15
  import httpx
 
21
  is_billing_error,
22
  resolve_jobs_namespace,
23
  )
24
+ from agent.core.hub_artifacts import build_hub_artifact_sitecustomize
25
  from agent.core.session import Event
26
  from agent.tools.trackio_seed import ensure_trackio_dashboard
27
  from agent.tools.types import ToolResult
 
239
  return _build_uv_command(script, with_deps, python, script_args)
240
 
241
 
242
+ def _wrap_command_with_artifact_bootstrap(
243
+ command: list[str], session: Any = None
244
+ ) -> list[str]:
245
+ """Install sitecustomize hooks before the user command runs in HF Jobs."""
246
+ sitecustomize = build_hub_artifact_sitecustomize(session)
247
+ if not sitecustomize:
248
+ return command
249
+
250
+ encoded = base64.b64encode(sitecustomize.encode("utf-8")).decode("ascii")
251
+ original_command = shlex.join(command)
252
+ shell = (
253
+ 'set -e; _ml_intern_artifacts_dir="$(mktemp -d)"; '
254
+ f"printf %s {shlex.quote(encoded)} | base64 -d "
255
+ '> "$_ml_intern_artifacts_dir/sitecustomize.py"; '
256
+ 'export PYTHONPATH="$_ml_intern_artifacts_dir${PYTHONPATH:+:$PYTHONPATH}"; '
257
+ f"exec {original_command}"
258
+ )
259
+ return ["/bin/sh", "-lc", shell]
260
+
261
+
262
  async def _async_call(func, *args, **kwargs):
263
  """Wrap synchronous HfApi calls for async context"""
264
  return await asyncio.to_thread(func, *args, **kwargs)
 
582
  image = args.get("image", "python:3.12")
583
  job_type = "Docker"
584
 
585
+ command = _wrap_command_with_artifact_bootstrap(command, self.session)
586
+
587
  # Run the job
588
  flavor = args.get("hardware_flavor", "cpu-basic")
589
  timeout_str = args.get("timeout", "30m")
 
936
  image = args.get("image", "python:3.12")
937
  job_type = "Docker"
938
 
939
+ command = _wrap_command_with_artifact_bootstrap(command, self.session)
940
+
941
  # Create scheduled job
942
  scheduled_job = await _async_call(
943
  self.api.create_scheduled_job,
agent/tools/local_tools.py CHANGED
@@ -15,6 +15,8 @@ import tempfile
15
  from pathlib import Path
16
  from typing import Any
17
 
 
 
18
 
19
  MAX_OUTPUT_CHARS = 25_000
20
  MAX_LINE_LENGTH = 4000
@@ -98,10 +100,13 @@ def _truncate_output(
98
  # ── Handlers ────────────────────────────────────────────────────────────
99
 
100
 
101
- async def _bash_handler(args: dict[str, Any], **_kw) -> tuple[str, bool]:
 
 
102
  command = args.get("command", "")
103
  if not command:
104
  return "No command provided.", False
 
105
  work_dir = args.get("work_dir", ".")
106
  timeout = min(args.get("timeout") or DEFAULT_TIMEOUT, MAX_TIMEOUT)
107
  try:
 
15
  from pathlib import Path
16
  from typing import Any
17
 
18
+ from agent.core.hub_artifacts import wrap_shell_command_with_hub_artifact_bootstrap
19
+
20
 
21
  MAX_OUTPUT_CHARS = 25_000
22
  MAX_LINE_LENGTH = 4000
 
100
  # ── Handlers ────────────────────────────────────────────────────────────
101
 
102
 
103
+ async def _bash_handler(
104
+ args: dict[str, Any], session: Any = None, **_kw
105
+ ) -> tuple[str, bool]:
106
  command = args.get("command", "")
107
  if not command:
108
  return "No command provided.", False
109
+ command = wrap_shell_command_with_hub_artifact_bootstrap(command, session)
110
  work_dir = args.get("work_dir", ".")
111
  timeout = min(args.get("timeout") or DEFAULT_TIMEOUT, MAX_TIMEOUT)
112
  try:
agent/tools/sandbox_tool.py CHANGED
@@ -21,6 +21,7 @@ from typing import Any
21
 
22
  from huggingface_hub import HfApi, SpaceHardware
23
 
 
24
  from agent.core.session import Event
25
  from agent.tools.sandbox_client import Sandbox
26
  from agent.tools.trackio_seed import ensure_trackio_dashboard
@@ -729,6 +730,14 @@ def _make_tool_handler(sandbox_tool_name: str):
729
  return "Sandbox is still starting. Please retry shortly.", False
730
 
731
  try:
 
 
 
 
 
 
 
 
732
  result = await asyncio.to_thread(sb.call_tool, sandbox_tool_name, args)
733
  if result.success:
734
  output = result.output or "(no output)"
 
21
 
22
  from huggingface_hub import HfApi, SpaceHardware
23
 
24
+ from agent.core.hub_artifacts import wrap_shell_command_with_hub_artifact_bootstrap
25
  from agent.core.session import Event
26
  from agent.tools.sandbox_client import Sandbox
27
  from agent.tools.trackio_seed import ensure_trackio_dashboard
 
730
  return "Sandbox is still starting. Please retry shortly.", False
731
 
732
  try:
733
+ if sandbox_tool_name == "bash" and args.get("command"):
734
+ args = {
735
+ **args,
736
+ "command": wrap_shell_command_with_hub_artifact_bootstrap(
737
+ args["command"],
738
+ session,
739
+ ),
740
+ }
741
  result = await asyncio.to_thread(sb.call_tool, sandbox_tool_name, args)
742
  if result.success:
743
  output = result.output or "(no output)"
backend/session_manager.py CHANGED
@@ -12,10 +12,11 @@ from typing import Any, Optional
12
 
13
  from agent.config import load_config
14
  from agent.core.agent_loop import process_submission
15
- from agent.messaging.gateway import NotificationGateway
16
  from agent.core.session import Event, OpType, Session
17
  from agent.core.session_persistence import get_session_store
18
  from agent.core.tools import ToolRouter
 
19
 
20
  # Get project root (parent of backend directory)
21
  PROJECT_ROOT = Path(__file__).parent.parent
@@ -135,6 +136,7 @@ class SessionManager:
135
  self.sessions: dict[str, AgentSession] = {}
136
  self._lock = asyncio.Lock()
137
  self.persistence_store = None
 
138
 
139
  async def start(self) -> None:
140
  """Start shared background resources."""
@@ -411,6 +413,28 @@ class SessionManager:
411
  session.sandbox_preload_cancel_event = None
412
  self._start_cpu_sandbox_preload(agent_session)
413
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
414
  async def _clear_persisted_sandbox_metadata(self, session_id: str) -> None:
415
  try:
416
  await self._store().update_session_fields(
@@ -567,6 +591,7 @@ class SessionManager:
567
  existing,
568
  preload_sandbox=preload_sandbox,
569
  )
 
570
  return existing
571
  return None
572
 
@@ -588,6 +613,7 @@ class SessionManager:
588
  existing,
589
  preload_sandbox=preload_sandbox,
590
  )
 
591
  return existing
592
  return None
593
 
@@ -674,7 +700,9 @@ class SessionManager:
674
  hf_token=hf_token,
675
  hf_username=hf_username,
676
  )
 
677
  return started
 
678
  if preload_sandbox:
679
  self._start_cpu_sandbox_preload(agent_session)
680
  logger.info("Restored session %s for user %s", session_id, owner or user_id)
@@ -757,6 +785,7 @@ class SessionManager:
757
  event_queue=event_queue,
758
  tool_router=tool_router,
759
  )
 
760
  await self.persist_session_snapshot(agent_session, runtime_state="idle")
761
  self._start_cpu_sandbox_preload(agent_session)
762
 
 
12
 
13
  from agent.config import load_config
14
  from agent.core.agent_loop import process_submission
15
+ from agent.core.hub_artifacts import start_session_artifact_collection_task
16
  from agent.core.session import Event, OpType, Session
17
  from agent.core.session_persistence import get_session_store
18
  from agent.core.tools import ToolRouter
19
+ from agent.messaging.gateway import NotificationGateway
20
 
21
  # Get project root (parent of backend directory)
22
  PROJECT_ROOT = Path(__file__).parent.parent
 
136
  self.sessions: dict[str, AgentSession] = {}
137
  self._lock = asyncio.Lock()
138
  self.persistence_store = None
139
+ self.enable_hub_artifact_collections = True
140
 
141
  async def start(self) -> None:
142
  """Start shared background resources."""
 
413
  session.sandbox_preload_cancel_event = None
414
  self._start_cpu_sandbox_preload(agent_session)
415
 
416
+ def _start_hub_artifact_collection(self, agent_session: AgentSession) -> None:
417
+ """Kick off best-effort Hub collection creation for the session."""
418
+ if not getattr(self, "enable_hub_artifact_collections", False):
419
+ return
420
+ session = agent_session.session
421
+ if not getattr(session, "session_id", None):
422
+ try:
423
+ session.session_id = agent_session.session_id
424
+ except Exception:
425
+ logger.debug("Could not attach session id for Hub artifact collection")
426
+ token = agent_session.hf_token or getattr(session, "hf_token", None)
427
+ if not token:
428
+ return
429
+ try:
430
+ start_session_artifact_collection_task(session, token=token)
431
+ except Exception as e:
432
+ logger.debug(
433
+ "Failed to schedule Hub artifact collection for %s: %s",
434
+ agent_session.session_id,
435
+ e,
436
+ )
437
+
438
  async def _clear_persisted_sandbox_metadata(self, session_id: str) -> None:
439
  try:
440
  await self._store().update_session_fields(
 
591
  existing,
592
  preload_sandbox=preload_sandbox,
593
  )
594
+ self._start_hub_artifact_collection(existing)
595
  return existing
596
  return None
597
 
 
613
  existing,
614
  preload_sandbox=preload_sandbox,
615
  )
616
+ self._start_hub_artifact_collection(existing)
617
  return existing
618
  return None
619
 
 
700
  hf_token=hf_token,
701
  hf_username=hf_username,
702
  )
703
+ self._start_hub_artifact_collection(started)
704
  return started
705
+ self._start_hub_artifact_collection(agent_session)
706
  if preload_sandbox:
707
  self._start_cpu_sandbox_preload(agent_session)
708
  logger.info("Restored session %s for user %s", session_id, owner or user_id)
 
785
  event_queue=event_queue,
786
  tool_router=tool_router,
787
  )
788
+ self._start_hub_artifact_collection(agent_session)
789
  await self.persist_session_snapshot(agent_session, runtime_state="idle")
790
  self._start_cpu_sandbox_preload(agent_session)
791
 
tests/unit/test_hub_artifacts.py ADDED
@@ -0,0 +1,505 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import logging
3
+ from types import SimpleNamespace
4
+
5
+ import pytest
6
+
7
+ from agent.core import hub_artifacts
8
+ from agent.core.hub_artifacts import (
9
+ ML_INTERN_TAG,
10
+ PROVENANCE_MARKER,
11
+ artifact_collection_title,
12
+ augment_repo_card_content,
13
+ build_hub_artifact_sitecustomize,
14
+ ensure_session_artifact_collection,
15
+ is_known_hub_artifact,
16
+ register_hub_artifact,
17
+ remember_hub_artifact,
18
+ start_session_artifact_collection_task,
19
+ wrap_shell_command_with_hub_artifact_bootstrap,
20
+ )
21
+ from agent.tools import local_tools, sandbox_tool
22
+ from agent.tools.hf_repo_files_tool import HfRepoFilesTool
23
+ from agent.tools.hf_repo_git_tool import HfRepoGitTool
24
+ from agent.tools.jobs_tool import _wrap_command_with_artifact_bootstrap
25
+
26
+
27
+ def _session() -> SimpleNamespace:
28
+ return SimpleNamespace(
29
+ session_id="session-123",
30
+ session_start_time="2026-05-05T10:20:30",
31
+ )
32
+
33
+
34
+ def test_artifact_collection_title_uses_session_date_and_id():
35
+ assert (
36
+ artifact_collection_title(_session())
37
+ == "ml-intern-artifacts-2026-05-05-session-123"
38
+ )
39
+
40
+
41
+ def test_artifact_collection_title_uses_short_uuid_fragment():
42
+ session = SimpleNamespace(
43
+ session_id="fadcbc77-3439-4c2b-bc52-50d7f6353af3",
44
+ session_start_time="2026-05-05T10:20:30",
45
+ )
46
+
47
+ title = artifact_collection_title(session)
48
+
49
+ assert title == "ml-intern-artifacts-2026-05-05-fadcbc77"
50
+ assert len(title) < 60
51
+
52
+
53
+ def test_artifact_collection_title_still_truncates_long_non_uuid_ids():
54
+ session = SimpleNamespace(
55
+ session_id="custom-session-id-that-is-longer-than-the-hub-title-limit",
56
+ session_start_time="2026-05-05T10:20:30",
57
+ )
58
+
59
+ title = artifact_collection_title(session)
60
+
61
+ assert title.startswith("ml-intern-artifacts-2026-05-05-custom-session-id")
62
+ assert len(title) < 60
63
+
64
+
65
+ def test_model_card_merges_tags_and_appends_provenance_and_usage():
66
+ content = """---
67
+ license: apache-2.0
68
+ tags:
69
+ - text-generation
70
+ ---
71
+ # Existing Model
72
+
73
+ Existing details stay here.
74
+ """
75
+
76
+ updated = augment_repo_card_content(content, "alice/model", "model")
77
+ second_pass = augment_repo_card_content(updated, "alice/model", "model")
78
+
79
+ assert "license: apache-2.0" in updated
80
+ assert "- text-generation" in updated
81
+ assert f"- {ML_INTERN_TAG}" in updated
82
+ assert "# Existing Model" in updated
83
+ assert "Existing details stay here." in updated
84
+ assert PROVENANCE_MARKER in updated
85
+ assert "AutoModelForCausalLM" in updated
86
+ assert second_pass.count(PROVENANCE_MARKER) == 1
87
+ assert second_pass.count("AutoModelForCausalLM") == updated.count(
88
+ "AutoModelForCausalLM"
89
+ )
90
+
91
+
92
+ def test_dataset_card_adds_load_dataset_usage():
93
+ updated = augment_repo_card_content("", "alice/dataset", "dataset")
94
+
95
+ assert f"- {ML_INTERN_TAG}" in updated
96
+ assert "# alice/dataset" in updated
97
+ assert "from datasets import load_dataset" in updated
98
+ assert 'load_dataset("alice/dataset")' in updated
99
+
100
+
101
+ def test_existing_usage_section_is_preserved_without_duplicate_usage():
102
+ content = """# Existing Dataset
103
+
104
+ ## Usage
105
+
106
+ Use the custom loader in this repository.
107
+ """
108
+
109
+ updated = augment_repo_card_content(content, "alice/dataset", "dataset")
110
+
111
+ assert "Use the custom loader in this repository." in updated
112
+ assert "from datasets import load_dataset" not in updated
113
+ assert PROVENANCE_MARKER in updated
114
+
115
+
116
+ def test_space_card_gets_metadata_without_provenance_body():
117
+ updated = augment_repo_card_content("# Existing Space\n", "alice/space", "space")
118
+
119
+ assert f"- {ML_INTERN_TAG}" in updated
120
+ assert "# Existing Space" in updated
121
+ assert PROVENANCE_MARKER not in updated
122
+
123
+
124
+ def test_register_hub_artifact_creates_private_collection_and_adds_item_once(
125
+ monkeypatch,
126
+ ):
127
+ session = _session()
128
+
129
+ class FakeApi:
130
+ token = "hf-token"
131
+
132
+ def __init__(self):
133
+ self.created_collections = []
134
+ self.collection_items = []
135
+ self.uploads = []
136
+
137
+ def create_collection(self, **kwargs):
138
+ self.created_collections.append(kwargs)
139
+ return SimpleNamespace(slug="alice/ml-intern-artifacts")
140
+
141
+ def add_collection_item(self, **kwargs):
142
+ self.collection_items.append(kwargs)
143
+
144
+ def upload_file(self, **kwargs):
145
+ self.uploads.append(kwargs)
146
+
147
+ api = FakeApi()
148
+ monkeypatch.setattr(hub_artifacts, "_read_remote_readme", lambda *_, **__: "")
149
+
150
+ assert register_hub_artifact(api, "alice/model", "model", session=session)
151
+ assert register_hub_artifact(api, "alice/model", "model", session=session)
152
+
153
+ assert is_known_hub_artifact(session, "alice/model", "model")
154
+ assert len(api.created_collections) == 1
155
+ assert api.created_collections[0]["title"] == artifact_collection_title(session)
156
+ assert api.created_collections[0]["private"] is True
157
+ assert len(api.collection_items) == 1
158
+ assert api.collection_items[0]["item_id"] == "alice/model"
159
+ assert api.collection_items[0]["item_type"] == "model"
160
+ assert api.collection_items[0]["exists_ok"] is True
161
+ assert len(api.uploads) == 1
162
+ assert b"ml-intern" in api.uploads[0]["path_or_fileobj"]
163
+
164
+
165
+ def test_register_hub_artifact_retries_after_partial_failure(monkeypatch):
166
+ session = _session()
167
+ api = SimpleNamespace(token="hf-token")
168
+ card_attempts = 0
169
+ collection_attempts = 0
170
+
171
+ def flaky_update_repo_card(*args, **kwargs):
172
+ nonlocal card_attempts
173
+ card_attempts += 1
174
+ if card_attempts == 1:
175
+ raise RuntimeError("temporary card failure")
176
+
177
+ def add_to_collection(*args, **kwargs):
178
+ nonlocal collection_attempts
179
+ collection_attempts += 1
180
+
181
+ monkeypatch.setattr(
182
+ hub_artifacts,
183
+ "_update_repo_card",
184
+ flaky_update_repo_card,
185
+ )
186
+ monkeypatch.setattr(hub_artifacts, "_add_to_collection", add_to_collection)
187
+
188
+ assert not register_hub_artifact(api, "alice/model", "model", session=session)
189
+ assert register_hub_artifact(api, "alice/model", "model", session=session)
190
+ assert register_hub_artifact(api, "alice/model", "model", session=session)
191
+
192
+ assert card_attempts == 2
193
+ assert collection_attempts == 2
194
+
195
+
196
+ def test_register_hub_artifact_retries_after_collection_failure(monkeypatch):
197
+ session = _session()
198
+ api = SimpleNamespace(token="hf-token")
199
+ card_attempts = 0
200
+ collection_attempts = 0
201
+
202
+ def update_repo_card(*args, **kwargs):
203
+ nonlocal card_attempts
204
+ card_attempts += 1
205
+
206
+ def flaky_add_to_collection(*args, **kwargs):
207
+ nonlocal collection_attempts
208
+ collection_attempts += 1
209
+ if collection_attempts == 1:
210
+ raise RuntimeError("temporary collection failure")
211
+
212
+ monkeypatch.setattr(hub_artifacts, "_update_repo_card", update_repo_card)
213
+ monkeypatch.setattr(
214
+ hub_artifacts,
215
+ "_add_to_collection",
216
+ flaky_add_to_collection,
217
+ )
218
+
219
+ assert not register_hub_artifact(api, "alice/model", "model", session=session)
220
+ assert register_hub_artifact(api, "alice/model", "model", session=session)
221
+ assert register_hub_artifact(api, "alice/model", "model", session=session)
222
+
223
+ assert card_attempts == 2
224
+ assert collection_attempts == 2
225
+
226
+
227
+ def test_session_artifact_set_falls_back_when_session_rejects_attrs(caplog):
228
+ class SlottedSession:
229
+ __slots__ = ("session_id", "session_start_time")
230
+
231
+ def __init__(self):
232
+ self.session_id = "session-123"
233
+ self.session_start_time = "2026-05-05T10:20:30"
234
+
235
+ session = SlottedSession()
236
+
237
+ with caplog.at_level(logging.WARNING):
238
+ remember_hub_artifact(session, "alice/model", "model")
239
+
240
+ assert is_known_hub_artifact(session, "alice/model", "model")
241
+ assert "using process-local fallback state" in caplog.text
242
+
243
+
244
+ @pytest.mark.asyncio
245
+ async def test_ensure_session_artifact_collection_uses_user_token(monkeypatch):
246
+ session = _session()
247
+ calls = []
248
+
249
+ class FakeApi:
250
+ def __init__(self, token):
251
+ self.token = token
252
+
253
+ def fake_ensure_collection_slug(api, seen_session, **kwargs):
254
+ calls.append((api.token, seen_session, kwargs))
255
+ return "alice/ml-intern-artifacts"
256
+
257
+ monkeypatch.setattr(hub_artifacts, "HfApi", FakeApi)
258
+ monkeypatch.setattr(
259
+ hub_artifacts,
260
+ "_ensure_collection_slug",
261
+ fake_ensure_collection_slug,
262
+ )
263
+
264
+ slug = await ensure_session_artifact_collection(session, token="hf-token")
265
+
266
+ assert slug == "alice/ml-intern-artifacts"
267
+ assert calls == [
268
+ ("hf-token", session, {"token": "hf-token"}),
269
+ ]
270
+
271
+
272
+ @pytest.mark.asyncio
273
+ async def test_start_session_artifact_collection_task_dedupes(monkeypatch):
274
+ session = _session()
275
+ calls = []
276
+
277
+ async def fake_ensure_session_artifact_collection(seen_session, **kwargs):
278
+ calls.append((seen_session, kwargs))
279
+ await asyncio.sleep(0)
280
+ return "alice/ml-intern-artifacts"
281
+
282
+ monkeypatch.setattr(
283
+ hub_artifacts,
284
+ "ensure_session_artifact_collection",
285
+ fake_ensure_session_artifact_collection,
286
+ )
287
+
288
+ task = start_session_artifact_collection_task(session, token="hf-token")
289
+ second = start_session_artifact_collection_task(session, token="hf-token")
290
+
291
+ assert task is not None
292
+ assert second is task
293
+ await task
294
+ assert calls == [(session, {"token": "hf-token"})]
295
+
296
+
297
+ def test_start_session_artifact_collection_task_skips_without_token():
298
+ assert start_session_artifact_collection_task(_session()) is None
299
+
300
+
301
+ @pytest.mark.asyncio
302
+ async def test_hf_repo_git_create_repo_registers_artifact(monkeypatch):
303
+ session = _session()
304
+ calls = []
305
+
306
+ class FakeApi:
307
+ token = "hf-token"
308
+
309
+ def create_repo(self, **kwargs):
310
+ self.create_kwargs = kwargs
311
+ return "https://huggingface.co/spaces/alice/demo"
312
+
313
+ def fake_register(api, repo_id, repo_type, **kwargs):
314
+ calls.append((api, repo_id, repo_type, kwargs))
315
+ return True
316
+
317
+ monkeypatch.setattr(
318
+ "agent.tools.hf_repo_git_tool.register_hub_artifact",
319
+ fake_register,
320
+ )
321
+ tool = HfRepoGitTool(hf_token="hf-token", session=session)
322
+ tool.api = FakeApi()
323
+
324
+ result = await tool._create_repo(
325
+ {
326
+ "repo_id": "alice/demo",
327
+ "repo_type": "space",
328
+ "space_sdk": "gradio",
329
+ "private": True,
330
+ }
331
+ )
332
+
333
+ assert result["totalResults"] == 1
334
+ assert calls == [
335
+ (
336
+ tool.api,
337
+ "alice/demo",
338
+ "space",
339
+ {"session": session, "extra_metadata": {"sdk": "gradio"}},
340
+ )
341
+ ]
342
+
343
+
344
+ @pytest.mark.asyncio
345
+ async def test_hf_repo_files_upload_registers_known_artifact_with_force(monkeypatch):
346
+ session = _session()
347
+ calls = []
348
+ uploads = []
349
+
350
+ class FakeApi:
351
+ token = "hf-token"
352
+
353
+ def upload_file(self, **kwargs):
354
+ uploads.append(kwargs)
355
+ return SimpleNamespace()
356
+
357
+ def fake_register(api, repo_id, repo_type, **kwargs):
358
+ calls.append((api, repo_id, repo_type, kwargs))
359
+ return True
360
+
361
+ monkeypatch.setattr(
362
+ "agent.tools.hf_repo_files_tool.register_hub_artifact",
363
+ fake_register,
364
+ )
365
+ remember_hub_artifact(session, "alice/model", "model")
366
+
367
+ tool = HfRepoFilesTool(hf_token="hf-token", session=session)
368
+ tool.api = FakeApi()
369
+
370
+ result = await tool._upload(
371
+ {
372
+ "repo_id": "alice/model",
373
+ "repo_type": "model",
374
+ "path": "weights.bin",
375
+ "content": b"weights",
376
+ }
377
+ )
378
+ readme_result = await tool._upload(
379
+ {
380
+ "repo_id": "alice/model",
381
+ "repo_type": "model",
382
+ "path": "README.md",
383
+ "content": "# Model",
384
+ }
385
+ )
386
+
387
+ assert result["totalResults"] == 1
388
+ assert readme_result["totalResults"] == 1
389
+ assert [upload["path_in_repo"] for upload in uploads] == [
390
+ "weights.bin",
391
+ "README.md",
392
+ ]
393
+ assert calls == [
394
+ (
395
+ tool.api,
396
+ "alice/model",
397
+ "model",
398
+ {"session": session, "force": False},
399
+ ),
400
+ (
401
+ tool.api,
402
+ "alice/model",
403
+ "model",
404
+ {"session": session, "force": True},
405
+ ),
406
+ ]
407
+
408
+
409
+ def test_hf_jobs_artifact_bootstrap_wraps_command_without_changing_exec_target():
410
+ command = ["uv", "run", "train.py"]
411
+ wrapped = _wrap_command_with_artifact_bootstrap(command, _session())
412
+
413
+ assert wrapped[0:2] == ["/bin/sh", "-lc"]
414
+ assert "sitecustomize.py" in wrapped[2]
415
+ assert "PYTHONPATH" in wrapped[2]
416
+ assert "exec uv run train.py" in wrapped[2]
417
+ assert _wrap_command_with_artifact_bootstrap(command, None) == command
418
+
419
+
420
+ def test_shell_bootstrap_wraps_capybara_push_to_hub_pattern():
421
+ command = (
422
+ "pip install -q datasets huggingface_hub && python -c "
423
+ "\"subset.push_to_hub('lewtun/Capybara-100', private=False)\""
424
+ )
425
+
426
+ wrapped = wrap_shell_command_with_hub_artifact_bootstrap(command, _session())
427
+
428
+ assert "sitecustomize.py" in wrapped
429
+ assert "PYTHONPATH" in wrapped
430
+ assert command in wrapped
431
+ assert wrap_shell_command_with_hub_artifact_bootstrap(command, None) == command
432
+ assert (
433
+ wrap_shell_command_with_hub_artifact_bootstrap(
434
+ command,
435
+ SimpleNamespace(session_start_time="2026-05-05T10:20:30"),
436
+ )
437
+ == command
438
+ )
439
+
440
+
441
+ @pytest.mark.asyncio
442
+ async def test_sandbox_bash_wraps_command_for_session_artifact_hooks():
443
+ calls = []
444
+
445
+ class FakeSandbox:
446
+ def call_tool(self, name, args):
447
+ calls.append((name, args))
448
+ return SimpleNamespace(success=True, output="ok", error="")
449
+
450
+ session = _session()
451
+ session.sandbox = FakeSandbox()
452
+
453
+ handler = sandbox_tool._make_tool_handler("bash")
454
+ output, ok = await handler({"command": "python make_dataset.py"}, session=session)
455
+
456
+ assert ok is True
457
+ assert output == "ok"
458
+ assert calls[0][0] == "bash"
459
+ assert "sitecustomize.py" in calls[0][1]["command"]
460
+ assert "python make_dataset.py" in calls[0][1]["command"]
461
+
462
+
463
+ @pytest.mark.asyncio
464
+ async def test_local_bash_wraps_command_for_session_artifact_hooks(monkeypatch):
465
+ seen = {}
466
+
467
+ def fake_run(command, **kwargs):
468
+ seen["command"] = command
469
+ seen["kwargs"] = kwargs
470
+ return SimpleNamespace(stdout="ok", stderr="", returncode=0)
471
+
472
+ monkeypatch.setattr(local_tools.subprocess, "run", fake_run)
473
+
474
+ output, ok = await local_tools._bash_handler(
475
+ {"command": "python make_dataset.py"},
476
+ session=_session(),
477
+ )
478
+
479
+ assert ok is True
480
+ assert output == "ok"
481
+ assert "sitecustomize.py" in seen["command"]
482
+ assert "python make_dataset.py" in seen["command"]
483
+
484
+
485
+ def test_sitecustomize_bootstrap_is_valid_python():
486
+ code = build_hub_artifact_sitecustomize(_session())
487
+
488
+ compile(code, "sitecustomize.py", "exec")
489
+ assert "ml-intern-artifacts-2026-05-05-session-123" in code
490
+
491
+
492
+ def test_sitecustomize_bootstrap_reuses_existing_collection_slug():
493
+ session = _session()
494
+ setattr(
495
+ session,
496
+ hub_artifacts._COLLECTION_SLUG_ATTR,
497
+ "alice/ml-intern-artifacts-2026-05-05-session-123",
498
+ )
499
+
500
+ code = build_hub_artifact_sitecustomize(session)
501
+
502
+ compile(code, "sitecustomize.py", "exec")
503
+ assert (
504
+ "collection_slug = 'alice/ml-intern-artifacts-2026-05-05-session-123'" in code
505
+ )
tests/unit/test_session_manager_persistence.py CHANGED
@@ -430,6 +430,32 @@ async def test_create_session_schedules_cpu_sandbox_preload():
430
  await _cancel_runtime_tasks(manager)
431
 
432
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
433
  @pytest.mark.asyncio
434
  async def test_lazy_restore_schedules_cpu_sandbox_preload():
435
  manager = _manager_with_store(RestoreStore())
@@ -454,6 +480,37 @@ async def test_lazy_restore_schedules_cpu_sandbox_preload():
454
  await _cancel_runtime_tasks(manager)
455
 
456
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
457
  @pytest.mark.asyncio
458
  async def test_lazy_restore_deletes_persisted_sandbox_before_preload(monkeypatch):
459
  deleted: list[tuple[str, str, str]] = []
 
430
  await _cancel_runtime_tasks(manager)
431
 
432
 
433
+ @pytest.mark.asyncio
434
+ async def test_create_session_starts_hub_artifact_collection(monkeypatch):
435
+ manager = _manager_with_store(NoopSessionStore())
436
+ manager.enable_hub_artifact_collections = True
437
+ stop = _install_fake_runtime(manager)
438
+ started: list[tuple[str, str]] = []
439
+
440
+ def fake_start_session_artifact_collection_task(session, **kwargs):
441
+ started.append((session.session_id, kwargs["token"]))
442
+ return None
443
+
444
+ monkeypatch.setattr(
445
+ "session_manager.start_session_artifact_collection_task",
446
+ fake_start_session_artifact_collection_task,
447
+ )
448
+ manager._start_cpu_sandbox_preload = lambda _: None # type: ignore[method-assign]
449
+
450
+ try:
451
+ session_id = await manager.create_session(user_id="owner", hf_token="token")
452
+
453
+ assert started == [(session_id, "token")]
454
+ finally:
455
+ stop.set()
456
+ await _cancel_runtime_tasks(manager)
457
+
458
+
459
  @pytest.mark.asyncio
460
  async def test_lazy_restore_schedules_cpu_sandbox_preload():
461
  manager = _manager_with_store(RestoreStore())
 
480
  await _cancel_runtime_tasks(manager)
481
 
482
 
483
+ @pytest.mark.asyncio
484
+ async def test_lazy_restore_starts_hub_artifact_collection(monkeypatch):
485
+ manager = _manager_with_store(RestoreStore())
486
+ manager.enable_hub_artifact_collections = True
487
+ stop = _install_fake_runtime(manager)
488
+ started: list[tuple[str, str]] = []
489
+
490
+ def fake_start_session_artifact_collection_task(session, **kwargs):
491
+ started.append((session.session_id, kwargs["token"]))
492
+ return None
493
+
494
+ monkeypatch.setattr(
495
+ "session_manager.start_session_artifact_collection_task",
496
+ fake_start_session_artifact_collection_task,
497
+ )
498
+ manager._start_cpu_sandbox_preload = lambda _: None # type: ignore[method-assign]
499
+
500
+ try:
501
+ restored = await manager.ensure_session_loaded(
502
+ "persisted-session",
503
+ user_id="owner",
504
+ hf_token="token",
505
+ )
506
+
507
+ assert restored is not None
508
+ assert started == [("persisted-session", "token")]
509
+ finally:
510
+ stop.set()
511
+ await _cancel_runtime_tasks(manager)
512
+
513
+
514
  @pytest.mark.asyncio
515
  async def test_lazy_restore_deletes_persisted_sandbox_before_preload(monkeypatch):
516
  deleted: list[tuple[str, str, str]] = []