Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
Add ML Intern Hub artifact metadata (#225)
Browse files* Add ML Intern Hub artifact metadata
Co-authored-by: OpenAI Codex <codex@openai.com>
* Extend Hub artifact hooks to sandbox bash
Co-authored-by: OpenAI Codex <codex@openai.com>
* Create Hub artifact collections at session start
Co-authored-by: OpenAI Codex <codex@openai.com>
* Address Hub artifact PR review comments
Co-authored-by: OpenAI Codex <codex@openai.com>
* Respect Hub collection title limits
Co-authored-by: OpenAI Codex <codex@openai.com>
* Shorten collection session UUID fragments
Co-authored-by: OpenAI Codex <codex@openai.com>
---------
Co-authored-by: OpenAI Codex <codex@openai.com>
- agent/core/agent_loop.py +2 -0
- agent/core/hub_artifacts.py +765 -0
- agent/tools/hf_repo_files_tool.py +14 -2
- agent/tools/hf_repo_git_tool.py +15 -2
- agent/tools/jobs_tool.py +26 -0
- agent/tools/local_tools.py +6 -1
- agent/tools/sandbox_tool.py +9 -0
- backend/session_manager.py +30 -1
- tests/unit/test_hub_artifacts.py +505 -0
- tests/unit/test_session_manager_persistence.py +57 -0
agent/core/agent_loop.py
CHANGED
|
@@ -26,6 +26,7 @@ from agent.core.cost_estimation import CostEstimate, estimate_tool_cost
|
|
| 26 |
from agent.messaging.gateway import NotificationGateway
|
| 27 |
from agent.core import telemetry
|
| 28 |
from agent.core.doom_loop import check_for_doom_loop
|
|
|
|
| 29 |
from agent.core.llm_params import _resolve_llm_params
|
| 30 |
from agent.core.prompt_caching import with_prompt_caching
|
| 31 |
from agent.core.session import Event, OpType, Session
|
|
@@ -1998,6 +1999,7 @@ async def submission_loop(
|
|
| 1998 |
)
|
| 1999 |
if session_holder is not None:
|
| 2000 |
session_holder[0] = session
|
|
|
|
| 2001 |
logger.info("Agent loop started")
|
| 2002 |
|
| 2003 |
# Retry any failed uploads from previous sessions (fire-and-forget).
|
|
|
|
| 26 |
from agent.messaging.gateway import NotificationGateway
|
| 27 |
from agent.core import telemetry
|
| 28 |
from agent.core.doom_loop import check_for_doom_loop
|
| 29 |
+
from agent.core.hub_artifacts import start_session_artifact_collection_task
|
| 30 |
from agent.core.llm_params import _resolve_llm_params
|
| 31 |
from agent.core.prompt_caching import with_prompt_caching
|
| 32 |
from agent.core.session import Event, OpType, Session
|
|
|
|
| 1999 |
)
|
| 2000 |
if session_holder is not None:
|
| 2001 |
session_holder[0] = session
|
| 2002 |
+
start_session_artifact_collection_task(session, token=hf_token)
|
| 2003 |
logger.info("Agent loop started")
|
| 2004 |
|
| 2005 |
# Retry any failed uploads from previous sessions (fire-and-forget).
|
agent/core/hub_artifacts.py
ADDED
|
@@ -0,0 +1,765 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Best-effort Hub metadata for artifacts generated by ML Intern sessions."""
|
| 2 |
+
|
| 3 |
+
import asyncio
|
| 4 |
+
import base64
|
| 5 |
+
import logging
|
| 6 |
+
import re
|
| 7 |
+
import shlex
|
| 8 |
+
import tempfile
|
| 9 |
+
import textwrap
|
| 10 |
+
from datetime import datetime
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from typing import Any
|
| 13 |
+
|
| 14 |
+
from huggingface_hub import HfApi, hf_hub_download
|
| 15 |
+
from huggingface_hub.repocard import metadata_load, metadata_save
|
| 16 |
+
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
ML_INTERN_TAG = "ml-intern"
|
| 21 |
+
SUPPORTED_REPO_TYPES = {"model", "dataset", "space"}
|
| 22 |
+
PROVENANCE_MARKER = "<!-- ml-intern-provenance -->"
|
| 23 |
+
_COLLECTION_TITLE_PREFIX = "ml-intern-artifacts"
|
| 24 |
+
_COLLECTION_TITLE_MAX_LENGTH = 59
|
| 25 |
+
_UUID_SESSION_ID_RE = re.compile(
|
| 26 |
+
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-"
|
| 27 |
+
r"[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$"
|
| 28 |
+
)
|
| 29 |
+
_KNOWN_ARTIFACTS_ATTR = "_ml_intern_known_hub_artifacts"
|
| 30 |
+
_REGISTERED_ARTIFACTS_ATTR = "_ml_intern_registered_hub_artifacts"
|
| 31 |
+
_COLLECTION_SLUG_ATTR = "_ml_intern_artifact_collection_slug"
|
| 32 |
+
_COLLECTION_TASK_ATTR = "_ml_intern_artifact_collection_task"
|
| 33 |
+
_SESSION_ARTIFACT_SET_FALLBACK: dict[tuple[int, str], set[str]] = {}
|
| 34 |
+
_USAGE_HEADING_RE = re.compile(
|
| 35 |
+
r"^#{2,6}\s+(usage|how to use|using this (model|dataset)|use this (model|dataset))\b",
|
| 36 |
+
re.IGNORECASE | re.MULTILINE,
|
| 37 |
+
)
|
| 38 |
+
_FRONT_MATTER_RE = re.compile(r"\A---\s*\n.*?\n---\s*\n?", re.DOTALL)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _safe_session_id(session: Any) -> str:
|
| 42 |
+
raw = str(getattr(session, "session_id", "") or "unknown-session")
|
| 43 |
+
safe = re.sub(r"[^A-Za-z0-9._-]+", "-", raw).strip("-")
|
| 44 |
+
return safe or "unknown-session"
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def session_artifact_date(session: Any) -> str:
|
| 48 |
+
"""Return the YYYY-MM-DD partition date for a session."""
|
| 49 |
+
raw = getattr(session, "session_start_time", None)
|
| 50 |
+
if raw:
|
| 51 |
+
try:
|
| 52 |
+
return datetime.fromisoformat(str(raw).replace("Z", "+00:00")).strftime(
|
| 53 |
+
"%Y-%m-%d"
|
| 54 |
+
)
|
| 55 |
+
except ValueError:
|
| 56 |
+
logger.debug("Could not parse session_start_time=%r", raw)
|
| 57 |
+
return datetime.utcnow().strftime("%Y-%m-%d")
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def _collection_session_id_fragment(session: Any) -> str:
|
| 61 |
+
safe_id = _safe_session_id(session)
|
| 62 |
+
if _UUID_SESSION_ID_RE.match(safe_id):
|
| 63 |
+
return safe_id[:8]
|
| 64 |
+
stem = f"{_COLLECTION_TITLE_PREFIX}-{session_artifact_date(session)}-"
|
| 65 |
+
max_id_length = max(1, _COLLECTION_TITLE_MAX_LENGTH - len(stem))
|
| 66 |
+
if len(safe_id) <= max_id_length:
|
| 67 |
+
return safe_id
|
| 68 |
+
return safe_id[:max_id_length].rstrip("-._") or safe_id[:max_id_length]
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def artifact_collection_title(session: Any) -> str:
|
| 72 |
+
return (
|
| 73 |
+
f"{_COLLECTION_TITLE_PREFIX}-{session_artifact_date(session)}-"
|
| 74 |
+
f"{_collection_session_id_fragment(session)}"
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def _artifact_key(repo_id: str, repo_type: str | None) -> str:
|
| 79 |
+
return f"{repo_type or 'model'}:{repo_id}"
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def _session_artifact_set(session: Any, attr: str) -> set[str]:
|
| 83 |
+
current = getattr(session, attr, None)
|
| 84 |
+
if isinstance(current, set):
|
| 85 |
+
return current
|
| 86 |
+
current = set()
|
| 87 |
+
try:
|
| 88 |
+
setattr(session, attr, current)
|
| 89 |
+
except Exception:
|
| 90 |
+
logger.warning(
|
| 91 |
+
"Could not attach %s to session; using process-local fallback state",
|
| 92 |
+
attr,
|
| 93 |
+
)
|
| 94 |
+
return _SESSION_ARTIFACT_SET_FALLBACK.setdefault((id(session), attr), set())
|
| 95 |
+
return current
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def remember_hub_artifact(session: Any, repo_id: str, repo_type: str | None) -> None:
|
| 99 |
+
if session is None or not repo_id:
|
| 100 |
+
return
|
| 101 |
+
_session_artifact_set(session, _KNOWN_ARTIFACTS_ATTR).add(
|
| 102 |
+
_artifact_key(repo_id, repo_type)
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def is_known_hub_artifact(session: Any, repo_id: str, repo_type: str | None) -> bool:
|
| 107 |
+
if session is None or not repo_id:
|
| 108 |
+
return False
|
| 109 |
+
return _artifact_key(repo_id, repo_type) in _session_artifact_set(
|
| 110 |
+
session, _KNOWN_ARTIFACTS_ATTR
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def _merge_tags(metadata: dict[str, Any], tag: str = ML_INTERN_TAG) -> dict[str, Any]:
|
| 115 |
+
merged = dict(metadata)
|
| 116 |
+
raw_tags = merged.get("tags")
|
| 117 |
+
if raw_tags is None:
|
| 118 |
+
tags: list[str] = []
|
| 119 |
+
elif isinstance(raw_tags, str):
|
| 120 |
+
tags = [raw_tags]
|
| 121 |
+
elif isinstance(raw_tags, list):
|
| 122 |
+
tags = [str(item) for item in raw_tags]
|
| 123 |
+
else:
|
| 124 |
+
tags = [str(raw_tags)]
|
| 125 |
+
|
| 126 |
+
if tag not in tags:
|
| 127 |
+
tags.append(tag)
|
| 128 |
+
merged["tags"] = tags
|
| 129 |
+
return merged
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def _metadata_from_content(content: str) -> dict[str, Any]:
|
| 133 |
+
with tempfile.TemporaryDirectory() as tmp_dir:
|
| 134 |
+
path = Path(tmp_dir) / "README.md"
|
| 135 |
+
path.write_text(content, encoding="utf-8")
|
| 136 |
+
return metadata_load(path) or {}
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def _content_with_metadata(content: str, metadata: dict[str, Any]) -> str:
|
| 140 |
+
with tempfile.TemporaryDirectory() as tmp_dir:
|
| 141 |
+
path = Path(tmp_dir) / "README.md"
|
| 142 |
+
path.write_text(content, encoding="utf-8")
|
| 143 |
+
metadata_save(path, metadata)
|
| 144 |
+
return path.read_text(encoding="utf-8")
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def _body_without_metadata(content: str) -> str:
|
| 148 |
+
return _FRONT_MATTER_RE.sub("", content, count=1).strip()
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def _append_section(content: str, section: str) -> str:
|
| 152 |
+
base = content.rstrip()
|
| 153 |
+
if base:
|
| 154 |
+
return f"{base}\n\n{section.strip()}\n"
|
| 155 |
+
return f"{section.strip()}\n"
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def _provenance_section(repo_type: str) -> str:
|
| 159 |
+
label = {"model": "model", "dataset": "dataset"}.get(repo_type, "Hub")
|
| 160 |
+
return f"""{PROVENANCE_MARKER}
|
| 161 |
+
## Generated by ML Intern
|
| 162 |
+
|
| 163 |
+
This {label} repository was generated by [ML Intern](https://github.com/huggingface/ml-intern), an agent for machine learning research and development on the Hugging Face Hub.
|
| 164 |
+
|
| 165 |
+
- Try ML Intern: https://smolagents-ml-intern.hf.space
|
| 166 |
+
- Source code: https://github.com/huggingface/ml-intern
|
| 167 |
+
"""
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def _usage_section(repo_id: str, repo_type: str) -> str:
|
| 171 |
+
if repo_type == "dataset":
|
| 172 |
+
return f"""## Usage
|
| 173 |
+
|
| 174 |
+
```python
|
| 175 |
+
from datasets import load_dataset
|
| 176 |
+
|
| 177 |
+
dataset = load_dataset("{repo_id}")
|
| 178 |
+
```
|
| 179 |
+
"""
|
| 180 |
+
|
| 181 |
+
return f"""## Usage
|
| 182 |
+
|
| 183 |
+
```python
|
| 184 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 185 |
+
|
| 186 |
+
model_id = "{repo_id}"
|
| 187 |
+
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
| 188 |
+
model = AutoModelForCausalLM.from_pretrained(model_id)
|
| 189 |
+
```
|
| 190 |
+
|
| 191 |
+
For non-causal architectures, replace `AutoModelForCausalLM` with the appropriate `AutoModel` class.
|
| 192 |
+
"""
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def augment_repo_card_content(
|
| 196 |
+
content: str | None,
|
| 197 |
+
repo_id: str,
|
| 198 |
+
repo_type: str = "model",
|
| 199 |
+
*,
|
| 200 |
+
extra_metadata: dict[str, Any] | None = None,
|
| 201 |
+
) -> str:
|
| 202 |
+
"""Return README content with ML Intern metadata and provenance added."""
|
| 203 |
+
repo_type = repo_type or "model"
|
| 204 |
+
content = content or ""
|
| 205 |
+
metadata = _metadata_from_content(content)
|
| 206 |
+
if extra_metadata:
|
| 207 |
+
metadata = {**extra_metadata, **metadata}
|
| 208 |
+
metadata = _merge_tags(metadata)
|
| 209 |
+
updated = _content_with_metadata(content, metadata)
|
| 210 |
+
|
| 211 |
+
if not _body_without_metadata(updated):
|
| 212 |
+
updated = _append_section(updated, f"# {repo_id}")
|
| 213 |
+
|
| 214 |
+
if repo_type in {"model", "dataset"} and PROVENANCE_MARKER not in updated:
|
| 215 |
+
updated = _append_section(updated, _provenance_section(repo_type))
|
| 216 |
+
if not _USAGE_HEADING_RE.search(content):
|
| 217 |
+
updated = _append_section(updated, _usage_section(repo_id, repo_type))
|
| 218 |
+
|
| 219 |
+
return updated
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def _read_remote_readme(
|
| 223 |
+
api: Any,
|
| 224 |
+
repo_id: str,
|
| 225 |
+
repo_type: str,
|
| 226 |
+
*,
|
| 227 |
+
token: str | bool | None = None,
|
| 228 |
+
) -> str:
|
| 229 |
+
token_value = token if token is not None else getattr(api, "token", None)
|
| 230 |
+
try:
|
| 231 |
+
readme_path = hf_hub_download(
|
| 232 |
+
repo_id=repo_id,
|
| 233 |
+
filename="README.md",
|
| 234 |
+
repo_type=repo_type,
|
| 235 |
+
token=token_value,
|
| 236 |
+
)
|
| 237 |
+
except (EntryNotFoundError, RepositoryNotFoundError):
|
| 238 |
+
return ""
|
| 239 |
+
return Path(readme_path).read_text(encoding="utf-8")
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def _update_repo_card(
|
| 243 |
+
api: Any,
|
| 244 |
+
repo_id: str,
|
| 245 |
+
repo_type: str,
|
| 246 |
+
*,
|
| 247 |
+
token: str | bool | None = None,
|
| 248 |
+
extra_metadata: dict[str, Any] | None = None,
|
| 249 |
+
) -> None:
|
| 250 |
+
current = _read_remote_readme(api, repo_id, repo_type, token=token)
|
| 251 |
+
updated = augment_repo_card_content(
|
| 252 |
+
current,
|
| 253 |
+
repo_id,
|
| 254 |
+
repo_type,
|
| 255 |
+
extra_metadata=extra_metadata,
|
| 256 |
+
)
|
| 257 |
+
if updated == current:
|
| 258 |
+
return
|
| 259 |
+
api.upload_file(
|
| 260 |
+
path_or_fileobj=updated.encode("utf-8"),
|
| 261 |
+
path_in_repo="README.md",
|
| 262 |
+
repo_id=repo_id,
|
| 263 |
+
repo_type=repo_type,
|
| 264 |
+
token=token,
|
| 265 |
+
commit_message="Update ML Intern artifact metadata",
|
| 266 |
+
)
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def _ensure_collection_slug(
|
| 270 |
+
api: Any,
|
| 271 |
+
session: Any,
|
| 272 |
+
*,
|
| 273 |
+
token: str | bool | None = None,
|
| 274 |
+
) -> str | None:
|
| 275 |
+
slug = getattr(session, _COLLECTION_SLUG_ATTR, None)
|
| 276 |
+
if slug:
|
| 277 |
+
return slug
|
| 278 |
+
|
| 279 |
+
title = artifact_collection_title(session)
|
| 280 |
+
collection = api.create_collection(
|
| 281 |
+
title=title,
|
| 282 |
+
description=(
|
| 283 |
+
f"Artifacts generated by ML Intern session {_safe_session_id(session)} "
|
| 284 |
+
f"on {session_artifact_date(session)}."
|
| 285 |
+
),
|
| 286 |
+
private=True,
|
| 287 |
+
exists_ok=True,
|
| 288 |
+
token=token,
|
| 289 |
+
)
|
| 290 |
+
slug = getattr(collection, "slug", None)
|
| 291 |
+
if slug:
|
| 292 |
+
setattr(session, _COLLECTION_SLUG_ATTR, slug)
|
| 293 |
+
return slug
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
async def ensure_session_artifact_collection(
|
| 297 |
+
session: Any,
|
| 298 |
+
*,
|
| 299 |
+
token: str | bool | None = None,
|
| 300 |
+
) -> str | None:
|
| 301 |
+
"""Create/cache the per-session artifact collection without raising."""
|
| 302 |
+
if session is None or not getattr(session, "session_id", None):
|
| 303 |
+
return None
|
| 304 |
+
token_value = token if token is not None else getattr(session, "hf_token", None)
|
| 305 |
+
if not token_value:
|
| 306 |
+
return None
|
| 307 |
+
|
| 308 |
+
try:
|
| 309 |
+
api = HfApi(token=token_value)
|
| 310 |
+
return await asyncio.to_thread(
|
| 311 |
+
_ensure_collection_slug,
|
| 312 |
+
api,
|
| 313 |
+
session,
|
| 314 |
+
token=token_value,
|
| 315 |
+
)
|
| 316 |
+
except Exception as e:
|
| 317 |
+
logger.warning(
|
| 318 |
+
"ML Intern session collection creation failed for %s: %s",
|
| 319 |
+
_safe_session_id(session),
|
| 320 |
+
e,
|
| 321 |
+
)
|
| 322 |
+
return None
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
def start_session_artifact_collection_task(
|
| 326 |
+
session: Any,
|
| 327 |
+
*,
|
| 328 |
+
token: str | bool | None = None,
|
| 329 |
+
) -> asyncio.Task | None:
|
| 330 |
+
"""Schedule best-effort collection creation for a newly started session."""
|
| 331 |
+
if session is None or not getattr(session, "session_id", None):
|
| 332 |
+
return None
|
| 333 |
+
if getattr(session, _COLLECTION_SLUG_ATTR, None):
|
| 334 |
+
return None
|
| 335 |
+
|
| 336 |
+
token_value = token if token is not None else getattr(session, "hf_token", None)
|
| 337 |
+
if not token_value:
|
| 338 |
+
return None
|
| 339 |
+
|
| 340 |
+
existing = getattr(session, _COLLECTION_TASK_ATTR, None)
|
| 341 |
+
if isinstance(existing, asyncio.Task) and not existing.done():
|
| 342 |
+
return existing
|
| 343 |
+
|
| 344 |
+
try:
|
| 345 |
+
loop = asyncio.get_running_loop()
|
| 346 |
+
except RuntimeError:
|
| 347 |
+
return None
|
| 348 |
+
|
| 349 |
+
async def _run() -> None:
|
| 350 |
+
await ensure_session_artifact_collection(session, token=token_value)
|
| 351 |
+
|
| 352 |
+
task = loop.create_task(_run())
|
| 353 |
+
try:
|
| 354 |
+
setattr(session, _COLLECTION_TASK_ATTR, task)
|
| 355 |
+
except Exception:
|
| 356 |
+
logger.debug("Could not attach ML Intern collection task to session")
|
| 357 |
+
return task
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
def _add_to_collection(
|
| 361 |
+
api: Any,
|
| 362 |
+
session: Any,
|
| 363 |
+
repo_id: str,
|
| 364 |
+
repo_type: str,
|
| 365 |
+
*,
|
| 366 |
+
token: str | bool | None = None,
|
| 367 |
+
) -> None:
|
| 368 |
+
slug = _ensure_collection_slug(api, session, token=token)
|
| 369 |
+
if not slug:
|
| 370 |
+
return
|
| 371 |
+
api.add_collection_item(
|
| 372 |
+
collection_slug=slug,
|
| 373 |
+
item_id=repo_id,
|
| 374 |
+
item_type=repo_type,
|
| 375 |
+
note=(
|
| 376 |
+
f"Generated by ML Intern session {_safe_session_id(session)} "
|
| 377 |
+
f"on {session_artifact_date(session)}."
|
| 378 |
+
),
|
| 379 |
+
exists_ok=True,
|
| 380 |
+
token=token,
|
| 381 |
+
)
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
def register_hub_artifact(
|
| 385 |
+
api: Any,
|
| 386 |
+
repo_id: str,
|
| 387 |
+
repo_type: str = "model",
|
| 388 |
+
*,
|
| 389 |
+
session: Any = None,
|
| 390 |
+
token: str | bool | None = None,
|
| 391 |
+
extra_metadata: dict[str, Any] | None = None,
|
| 392 |
+
force: bool = False,
|
| 393 |
+
) -> bool:
|
| 394 |
+
"""Tag, card, and collection-register a Hub artifact without raising."""
|
| 395 |
+
if session is None or not repo_id:
|
| 396 |
+
return False
|
| 397 |
+
repo_type = repo_type or "model"
|
| 398 |
+
if repo_type not in SUPPORTED_REPO_TYPES:
|
| 399 |
+
return False
|
| 400 |
+
|
| 401 |
+
key = _artifact_key(repo_id, repo_type)
|
| 402 |
+
remember_hub_artifact(session, repo_id, repo_type)
|
| 403 |
+
registered = _session_artifact_set(session, _REGISTERED_ARTIFACTS_ATTR)
|
| 404 |
+
if key in registered and not force:
|
| 405 |
+
return True
|
| 406 |
+
|
| 407 |
+
token_value = token if token is not None else getattr(api, "token", None)
|
| 408 |
+
card_updated = False
|
| 409 |
+
collection_updated = False
|
| 410 |
+
try:
|
| 411 |
+
_update_repo_card(
|
| 412 |
+
api,
|
| 413 |
+
repo_id,
|
| 414 |
+
repo_type,
|
| 415 |
+
token=token_value,
|
| 416 |
+
extra_metadata=extra_metadata,
|
| 417 |
+
)
|
| 418 |
+
card_updated = True
|
| 419 |
+
except Exception as e:
|
| 420 |
+
logger.debug("ML Intern repo-card update failed for %s: %s", repo_id, e)
|
| 421 |
+
|
| 422 |
+
try:
|
| 423 |
+
_add_to_collection(api, session, repo_id, repo_type, token=token_value)
|
| 424 |
+
collection_updated = True
|
| 425 |
+
except Exception as e:
|
| 426 |
+
logger.debug("ML Intern collection update failed for %s: %s", repo_id, e)
|
| 427 |
+
|
| 428 |
+
if card_updated and collection_updated:
|
| 429 |
+
registered.add(key)
|
| 430 |
+
return True
|
| 431 |
+
return False
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
def build_hub_artifact_sitecustomize(session: Any) -> str:
|
| 435 |
+
"""Build standalone sitecustomize.py code for HF Jobs Python processes."""
|
| 436 |
+
if session is None or not getattr(session, "session_id", None):
|
| 437 |
+
return ""
|
| 438 |
+
|
| 439 |
+
session_id = _safe_session_id(session)
|
| 440 |
+
session_date = session_artifact_date(session)
|
| 441 |
+
collection_title = artifact_collection_title(session)
|
| 442 |
+
collection_slug = getattr(session, _COLLECTION_SLUG_ATTR, None)
|
| 443 |
+
|
| 444 |
+
return (
|
| 445 |
+
textwrap.dedent(
|
| 446 |
+
f"""
|
| 447 |
+
# Auto-generated by ML Intern. Best-effort Hub artifact metadata only.
|
| 448 |
+
def _install_ml_intern_artifact_hooks():
|
| 449 |
+
import os
|
| 450 |
+
import re
|
| 451 |
+
import tempfile
|
| 452 |
+
from pathlib import Path
|
| 453 |
+
|
| 454 |
+
try:
|
| 455 |
+
import huggingface_hub as _hub
|
| 456 |
+
from huggingface_hub import HfApi, hf_hub_download
|
| 457 |
+
from huggingface_hub.repocard import metadata_load, metadata_save
|
| 458 |
+
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
|
| 459 |
+
except Exception:
|
| 460 |
+
return
|
| 461 |
+
|
| 462 |
+
session_id = {session_id!r}
|
| 463 |
+
session_date = {session_date!r}
|
| 464 |
+
collection_title = {collection_title!r}
|
| 465 |
+
tag = {ML_INTERN_TAG!r}
|
| 466 |
+
marker = {PROVENANCE_MARKER!r}
|
| 467 |
+
supported = {sorted(SUPPORTED_REPO_TYPES)!r}
|
| 468 |
+
registering = False
|
| 469 |
+
collection_slug = {collection_slug!r}
|
| 470 |
+
registered = set()
|
| 471 |
+
usage_re = re.compile(
|
| 472 |
+
r"^#{{2,6}}\\s+(usage|how to use|using this (model|dataset)|use this (model|dataset))\\b",
|
| 473 |
+
re.IGNORECASE | re.MULTILINE,
|
| 474 |
+
)
|
| 475 |
+
front_matter_re = re.compile(r"\\A---\\s*\\n.*?\\n---\\s*\\n?", re.DOTALL)
|
| 476 |
+
|
| 477 |
+
def _token(value=None, api=None):
|
| 478 |
+
if isinstance(value, str) and value:
|
| 479 |
+
return value
|
| 480 |
+
api_token = getattr(api, "token", None)
|
| 481 |
+
if isinstance(api_token, str) and api_token:
|
| 482 |
+
return api_token
|
| 483 |
+
return (
|
| 484 |
+
os.environ.get("HF_TOKEN")
|
| 485 |
+
or os.environ.get("HUGGINGFACE_HUB_TOKEN")
|
| 486 |
+
or None
|
| 487 |
+
)
|
| 488 |
+
|
| 489 |
+
def _merge_tags(metadata):
|
| 490 |
+
metadata = dict(metadata or {{}})
|
| 491 |
+
raw_tags = metadata.get("tags")
|
| 492 |
+
if raw_tags is None:
|
| 493 |
+
tags = []
|
| 494 |
+
elif isinstance(raw_tags, str):
|
| 495 |
+
tags = [raw_tags]
|
| 496 |
+
elif isinstance(raw_tags, list):
|
| 497 |
+
tags = [str(item) for item in raw_tags]
|
| 498 |
+
else:
|
| 499 |
+
tags = [str(raw_tags)]
|
| 500 |
+
if tag not in tags:
|
| 501 |
+
tags.append(tag)
|
| 502 |
+
metadata["tags"] = tags
|
| 503 |
+
return metadata
|
| 504 |
+
|
| 505 |
+
def _metadata_from_content(content):
|
| 506 |
+
with tempfile.TemporaryDirectory() as tmp_dir:
|
| 507 |
+
path = Path(tmp_dir) / "README.md"
|
| 508 |
+
path.write_text(content or "", encoding="utf-8")
|
| 509 |
+
return metadata_load(path) or {{}}
|
| 510 |
+
|
| 511 |
+
def _content_with_metadata(content, metadata):
|
| 512 |
+
with tempfile.TemporaryDirectory() as tmp_dir:
|
| 513 |
+
path = Path(tmp_dir) / "README.md"
|
| 514 |
+
path.write_text(content or "", encoding="utf-8")
|
| 515 |
+
metadata_save(path, metadata)
|
| 516 |
+
return path.read_text(encoding="utf-8")
|
| 517 |
+
|
| 518 |
+
def _body_without_metadata(content):
|
| 519 |
+
return front_matter_re.sub("", content or "", count=1).strip()
|
| 520 |
+
|
| 521 |
+
def _append_section(content, section):
|
| 522 |
+
base = (content or "").rstrip()
|
| 523 |
+
if base:
|
| 524 |
+
return base + "\\n\\n" + section.strip() + "\\n"
|
| 525 |
+
return section.strip() + "\\n"
|
| 526 |
+
|
| 527 |
+
def _provenance(repo_type):
|
| 528 |
+
label = {{"model": "model", "dataset": "dataset"}}.get(
|
| 529 |
+
repo_type, "Hub"
|
| 530 |
+
)
|
| 531 |
+
return (
|
| 532 |
+
marker
|
| 533 |
+
+ "\\n## Generated by ML Intern\\n\\n"
|
| 534 |
+
+ f"This {{label}} repository was generated by [ML Intern](https://github.com/huggingface/ml-intern), an agent for machine learning research and development on the Hugging Face Hub.\\n\\n"
|
| 535 |
+
+ "- Try ML Intern: https://smolagents-ml-intern.hf.space\\n"
|
| 536 |
+
+ "- Source code: https://github.com/huggingface/ml-intern\\n"
|
| 537 |
+
)
|
| 538 |
+
|
| 539 |
+
def _usage(repo_id, repo_type):
|
| 540 |
+
if repo_type == "dataset":
|
| 541 |
+
return (
|
| 542 |
+
"## Usage\\n\\n"
|
| 543 |
+
"```python\\n"
|
| 544 |
+
"from datasets import load_dataset\\n\\n"
|
| 545 |
+
f"dataset = load_dataset({{repo_id!r}})\\n"
|
| 546 |
+
"```\\n"
|
| 547 |
+
)
|
| 548 |
+
return (
|
| 549 |
+
"## Usage\\n\\n"
|
| 550 |
+
"```python\\n"
|
| 551 |
+
"from transformers import AutoModelForCausalLM, AutoTokenizer\\n\\n"
|
| 552 |
+
f"model_id = {{repo_id!r}}\\n"
|
| 553 |
+
"tokenizer = AutoTokenizer.from_pretrained(model_id)\\n"
|
| 554 |
+
"model = AutoModelForCausalLM.from_pretrained(model_id)\\n"
|
| 555 |
+
"```\\n\\n"
|
| 556 |
+
"For non-causal architectures, replace `AutoModelForCausalLM` with the appropriate `AutoModel` class.\\n"
|
| 557 |
+
)
|
| 558 |
+
|
| 559 |
+
def _augment(content, repo_id, repo_type, extra_metadata=None):
|
| 560 |
+
metadata = _metadata_from_content(content or "")
|
| 561 |
+
if extra_metadata:
|
| 562 |
+
metadata = {{**extra_metadata, **metadata}}
|
| 563 |
+
updated = _content_with_metadata(content or "", _merge_tags(metadata))
|
| 564 |
+
if not _body_without_metadata(updated):
|
| 565 |
+
updated = _append_section(updated, f"# {{repo_id}}")
|
| 566 |
+
if repo_type in {{"model", "dataset"}} and marker not in updated:
|
| 567 |
+
updated = _append_section(updated, _provenance(repo_type))
|
| 568 |
+
if not usage_re.search(content or ""):
|
| 569 |
+
updated = _append_section(updated, _usage(repo_id, repo_type))
|
| 570 |
+
return updated
|
| 571 |
+
|
| 572 |
+
def _readme(api, repo_id, repo_type, token_value):
|
| 573 |
+
try:
|
| 574 |
+
path = hf_hub_download(
|
| 575 |
+
repo_id=repo_id,
|
| 576 |
+
filename="README.md",
|
| 577 |
+
repo_type=repo_type,
|
| 578 |
+
token=token_value,
|
| 579 |
+
)
|
| 580 |
+
except (EntryNotFoundError, RepositoryNotFoundError):
|
| 581 |
+
return ""
|
| 582 |
+
return Path(path).read_text(encoding="utf-8")
|
| 583 |
+
|
| 584 |
+
def _ensure_collection(api, token_value):
|
| 585 |
+
nonlocal collection_slug
|
| 586 |
+
if collection_slug:
|
| 587 |
+
return collection_slug
|
| 588 |
+
collection = api.create_collection(
|
| 589 |
+
title=collection_title,
|
| 590 |
+
description=(
|
| 591 |
+
f"Artifacts generated by ML Intern session {{session_id}} "
|
| 592 |
+
f"on {{session_date}}."
|
| 593 |
+
),
|
| 594 |
+
private=True,
|
| 595 |
+
exists_ok=True,
|
| 596 |
+
token=token_value,
|
| 597 |
+
)
|
| 598 |
+
collection_slug = getattr(collection, "slug", None)
|
| 599 |
+
return collection_slug
|
| 600 |
+
|
| 601 |
+
def _register(
|
| 602 |
+
repo_id,
|
| 603 |
+
repo_type="model",
|
| 604 |
+
token_value=None,
|
| 605 |
+
extra_metadata=None,
|
| 606 |
+
force=False,
|
| 607 |
+
):
|
| 608 |
+
nonlocal registering
|
| 609 |
+
if registering or not repo_id:
|
| 610 |
+
return
|
| 611 |
+
repo_type = repo_type or "model"
|
| 612 |
+
if repo_type not in supported:
|
| 613 |
+
return
|
| 614 |
+
key = f"{{repo_type}}:{{repo_id}}"
|
| 615 |
+
if key in registered and not force:
|
| 616 |
+
return
|
| 617 |
+
registering = True
|
| 618 |
+
try:
|
| 619 |
+
token_value = _token(token_value)
|
| 620 |
+
api = HfApi(token=token_value)
|
| 621 |
+
try:
|
| 622 |
+
current = _readme(api, repo_id, repo_type, token_value)
|
| 623 |
+
updated = _augment(
|
| 624 |
+
current, repo_id, repo_type, extra_metadata=extra_metadata
|
| 625 |
+
)
|
| 626 |
+
if updated != current:
|
| 627 |
+
_original_upload_file(
|
| 628 |
+
api,
|
| 629 |
+
path_or_fileobj=updated.encode("utf-8"),
|
| 630 |
+
path_in_repo="README.md",
|
| 631 |
+
repo_id=repo_id,
|
| 632 |
+
repo_type=repo_type,
|
| 633 |
+
token=token_value,
|
| 634 |
+
commit_message="Update ML Intern artifact metadata",
|
| 635 |
+
)
|
| 636 |
+
except Exception:
|
| 637 |
+
pass
|
| 638 |
+
try:
|
| 639 |
+
slug = _ensure_collection(api, token_value)
|
| 640 |
+
if slug:
|
| 641 |
+
api.add_collection_item(
|
| 642 |
+
collection_slug=slug,
|
| 643 |
+
item_id=repo_id,
|
| 644 |
+
item_type=repo_type,
|
| 645 |
+
note=(
|
| 646 |
+
f"Generated by ML Intern session {{session_id}} "
|
| 647 |
+
f"on {{session_date}}."
|
| 648 |
+
),
|
| 649 |
+
exists_ok=True,
|
| 650 |
+
token=token_value,
|
| 651 |
+
)
|
| 652 |
+
except Exception:
|
| 653 |
+
pass
|
| 654 |
+
registered.add(key)
|
| 655 |
+
finally:
|
| 656 |
+
registering = False
|
| 657 |
+
|
| 658 |
+
_original_create_repo = HfApi.create_repo
|
| 659 |
+
_original_upload_file = HfApi.upload_file
|
| 660 |
+
_original_upload_folder = getattr(HfApi, "upload_folder", None)
|
| 661 |
+
_original_create_commit = getattr(HfApi, "create_commit", None)
|
| 662 |
+
|
| 663 |
+
def _repo_id(args, kwargs):
|
| 664 |
+
return kwargs.get("repo_id") or (args[0] if args else None)
|
| 665 |
+
|
| 666 |
+
def _repo_type(kwargs):
|
| 667 |
+
return kwargs.get("repo_type") or "model"
|
| 668 |
+
|
| 669 |
+
def _patched_create_repo(self, *args, **kwargs):
|
| 670 |
+
result = _original_create_repo(self, *args, **kwargs)
|
| 671 |
+
repo_id = _repo_id(args, kwargs)
|
| 672 |
+
repo_type = _repo_type(kwargs)
|
| 673 |
+
extra = None
|
| 674 |
+
if repo_type == "space" and kwargs.get("space_sdk"):
|
| 675 |
+
extra = {{"sdk": kwargs.get("space_sdk")}}
|
| 676 |
+
_register(repo_id, repo_type, _token(kwargs.get("token"), self), extra)
|
| 677 |
+
return result
|
| 678 |
+
|
| 679 |
+
def _patched_upload_file(self, *args, **kwargs):
|
| 680 |
+
result = _original_upload_file(self, *args, **kwargs)
|
| 681 |
+
if not kwargs.get("create_pr"):
|
| 682 |
+
force = kwargs.get("path_in_repo") == "README.md"
|
| 683 |
+
_register(
|
| 684 |
+
kwargs.get("repo_id"),
|
| 685 |
+
_repo_type(kwargs),
|
| 686 |
+
_token(kwargs.get("token"), self),
|
| 687 |
+
force=force,
|
| 688 |
+
)
|
| 689 |
+
return result
|
| 690 |
+
|
| 691 |
+
def _patched_upload_folder(self, *args, **kwargs):
|
| 692 |
+
result = _original_upload_folder(self, *args, **kwargs)
|
| 693 |
+
if not kwargs.get("create_pr"):
|
| 694 |
+
_register(
|
| 695 |
+
kwargs.get("repo_id"),
|
| 696 |
+
_repo_type(kwargs),
|
| 697 |
+
_token(kwargs.get("token"), self),
|
| 698 |
+
force=True,
|
| 699 |
+
)
|
| 700 |
+
return result
|
| 701 |
+
|
| 702 |
+
def _patched_create_commit(self, *args, **kwargs):
|
| 703 |
+
result = _original_create_commit(self, *args, **kwargs)
|
| 704 |
+
if not kwargs.get("create_pr"):
|
| 705 |
+
_register(
|
| 706 |
+
_repo_id(args, kwargs),
|
| 707 |
+
_repo_type(kwargs),
|
| 708 |
+
_token(kwargs.get("token"), self),
|
| 709 |
+
force=True,
|
| 710 |
+
)
|
| 711 |
+
return result
|
| 712 |
+
|
| 713 |
+
HfApi.create_repo = _patched_create_repo
|
| 714 |
+
HfApi.upload_file = _patched_upload_file
|
| 715 |
+
if _original_upload_folder is not None:
|
| 716 |
+
HfApi.upload_folder = _patched_upload_folder
|
| 717 |
+
if _original_create_commit is not None:
|
| 718 |
+
HfApi.create_commit = _patched_create_commit
|
| 719 |
+
|
| 720 |
+
def _patch_module_func(name, method_name):
|
| 721 |
+
original = getattr(_hub, name, None)
|
| 722 |
+
if original is None:
|
| 723 |
+
return
|
| 724 |
+
method = getattr(HfApi, method_name)
|
| 725 |
+
|
| 726 |
+
def _patched(*args, **kwargs):
|
| 727 |
+
api = HfApi(token=_token(kwargs.get("token")))
|
| 728 |
+
return method(api, *args, **kwargs)
|
| 729 |
+
|
| 730 |
+
setattr(_hub, name, _patched)
|
| 731 |
+
|
| 732 |
+
_patch_module_func("create_repo", "create_repo")
|
| 733 |
+
_patch_module_func("upload_file", "upload_file")
|
| 734 |
+
if _original_upload_folder is not None:
|
| 735 |
+
_patch_module_func("upload_folder", "upload_folder")
|
| 736 |
+
if _original_create_commit is not None:
|
| 737 |
+
_patch_module_func("create_commit", "create_commit")
|
| 738 |
+
|
| 739 |
+
try:
|
| 740 |
+
_install_ml_intern_artifact_hooks()
|
| 741 |
+
except Exception:
|
| 742 |
+
pass
|
| 743 |
+
"""
|
| 744 |
+
).strip()
|
| 745 |
+
+ "\n"
|
| 746 |
+
)
|
| 747 |
+
|
| 748 |
+
|
| 749 |
+
def wrap_shell_command_with_hub_artifact_bootstrap(
|
| 750 |
+
command: str,
|
| 751 |
+
session: Any,
|
| 752 |
+
) -> str:
|
| 753 |
+
"""Prefix a shell command so child Python processes load Hub hooks."""
|
| 754 |
+
sitecustomize = build_hub_artifact_sitecustomize(session)
|
| 755 |
+
if not sitecustomize or not command:
|
| 756 |
+
return command
|
| 757 |
+
|
| 758 |
+
encoded = base64.b64encode(sitecustomize.encode("utf-8")).decode("ascii")
|
| 759 |
+
bootstrap = (
|
| 760 |
+
'_ml_intern_artifacts_dir="$(mktemp -d 2>/dev/null)" '
|
| 761 |
+
f"&& printf %s {shlex.quote(encoded)} | base64 -d "
|
| 762 |
+
'> "$_ml_intern_artifacts_dir/sitecustomize.py" '
|
| 763 |
+
'&& export PYTHONPATH="$_ml_intern_artifacts_dir${PYTHONPATH:+:$PYTHONPATH}"'
|
| 764 |
+
)
|
| 765 |
+
return f"{bootstrap}; {command}"
|
agent/tools/hf_repo_files_tool.py
CHANGED
|
@@ -10,6 +10,7 @@ from typing import Any, Dict, Literal, Optional
|
|
| 10 |
from huggingface_hub import HfApi, hf_hub_download
|
| 11 |
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
|
| 12 |
|
|
|
|
| 13 |
from agent.tools.types import ToolResult
|
| 14 |
|
| 15 |
OperationType = Literal["list", "read", "upload", "delete"]
|
|
@@ -39,8 +40,9 @@ def _format_size(size_bytes: int) -> str:
|
|
| 39 |
class HfRepoFilesTool:
|
| 40 |
"""Tool for file operations on HF repos."""
|
| 41 |
|
| 42 |
-
def __init__(self, hf_token: Optional[str] = None):
|
| 43 |
self.api = HfApi(token=hf_token)
|
|
|
|
| 44 |
|
| 45 |
async def execute(self, args: Dict[str, Any]) -> ToolResult:
|
| 46 |
"""Execute the specified operation."""
|
|
@@ -214,6 +216,16 @@ class HfRepoFilesTool:
|
|
| 214 |
create_pr=create_pr,
|
| 215 |
)
|
| 216 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
url = _build_repo_url(repo_id, repo_type)
|
| 218 |
if create_pr and hasattr(result, "pr_url"):
|
| 219 |
response = f"**Uploaded as PR**\n{result.pr_url}"
|
|
@@ -343,7 +355,7 @@ async def hf_repo_files_handler(
|
|
| 343 |
"""Handler for agent tool router."""
|
| 344 |
try:
|
| 345 |
hf_token = session.hf_token if session else None
|
| 346 |
-
tool = HfRepoFilesTool(hf_token=hf_token)
|
| 347 |
result = await tool.execute(arguments)
|
| 348 |
return result["formatted"], not result.get("isError", False)
|
| 349 |
except Exception as e:
|
|
|
|
| 10 |
from huggingface_hub import HfApi, hf_hub_download
|
| 11 |
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
|
| 12 |
|
| 13 |
+
from agent.core.hub_artifacts import is_known_hub_artifact, register_hub_artifact
|
| 14 |
from agent.tools.types import ToolResult
|
| 15 |
|
| 16 |
OperationType = Literal["list", "read", "upload", "delete"]
|
|
|
|
| 40 |
class HfRepoFilesTool:
|
| 41 |
"""Tool for file operations on HF repos."""
|
| 42 |
|
| 43 |
+
def __init__(self, hf_token: Optional[str] = None, session: Any = None):
|
| 44 |
self.api = HfApi(token=hf_token)
|
| 45 |
+
self.session = session
|
| 46 |
|
| 47 |
async def execute(self, args: Dict[str, Any]) -> ToolResult:
|
| 48 |
"""Execute the specified operation."""
|
|
|
|
| 216 |
create_pr=create_pr,
|
| 217 |
)
|
| 218 |
|
| 219 |
+
if not create_pr and is_known_hub_artifact(self.session, repo_id, repo_type):
|
| 220 |
+
await _async_call(
|
| 221 |
+
register_hub_artifact,
|
| 222 |
+
self.api,
|
| 223 |
+
repo_id,
|
| 224 |
+
repo_type,
|
| 225 |
+
session=self.session,
|
| 226 |
+
force=path == "README.md",
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
url = _build_repo_url(repo_id, repo_type)
|
| 230 |
if create_pr and hasattr(result, "pr_url"):
|
| 231 |
response = f"**Uploaded as PR**\n{result.pr_url}"
|
|
|
|
| 355 |
"""Handler for agent tool router."""
|
| 356 |
try:
|
| 357 |
hf_token = session.hf_token if session else None
|
| 358 |
+
tool = HfRepoFilesTool(hf_token=hf_token, session=session)
|
| 359 |
result = await tool.execute(arguments)
|
| 360 |
return result["formatted"], not result.get("isError", False)
|
| 361 |
except Exception as e:
|
agent/tools/hf_repo_git_tool.py
CHANGED
|
@@ -10,6 +10,7 @@ from typing import Any, Dict, Literal, Optional
|
|
| 10 |
from huggingface_hub import HfApi
|
| 11 |
from huggingface_hub.utils import RepositoryNotFoundError
|
| 12 |
|
|
|
|
| 13 |
from agent.tools.types import ToolResult
|
| 14 |
|
| 15 |
OperationType = Literal[
|
|
@@ -45,8 +46,9 @@ def _build_repo_url(repo_id: str, repo_type: str = "model") -> str:
|
|
| 45 |
class HfRepoGitTool:
|
| 46 |
"""Tool for git-like operations on HF repos."""
|
| 47 |
|
| 48 |
-
def __init__(self, hf_token: Optional[str] = None):
|
| 49 |
self.api = HfApi(token=hf_token)
|
|
|
|
| 50 |
|
| 51 |
async def execute(self, args: Dict[str, Any]) -> ToolResult:
|
| 52 |
"""Execute the specified operation."""
|
|
@@ -552,6 +554,17 @@ class HfRepoGitTool:
|
|
| 552 |
kwargs["space_sdk"] = space_sdk
|
| 553 |
|
| 554 |
result = await _async_call(self.api.create_repo, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 555 |
|
| 556 |
return {
|
| 557 |
"formatted": f"**Repository created:** {repo_id}\n**Private:** {private}\n{result}",
|
|
@@ -747,7 +760,7 @@ async def hf_repo_git_handler(
|
|
| 747 |
"""Handler for agent tool router."""
|
| 748 |
try:
|
| 749 |
hf_token = session.hf_token if session else None
|
| 750 |
-
tool = HfRepoGitTool(hf_token=hf_token)
|
| 751 |
result = await tool.execute(arguments)
|
| 752 |
return result["formatted"], not result.get("isError", False)
|
| 753 |
except Exception as e:
|
|
|
|
| 10 |
from huggingface_hub import HfApi
|
| 11 |
from huggingface_hub.utils import RepositoryNotFoundError
|
| 12 |
|
| 13 |
+
from agent.core.hub_artifacts import register_hub_artifact
|
| 14 |
from agent.tools.types import ToolResult
|
| 15 |
|
| 16 |
OperationType = Literal[
|
|
|
|
| 46 |
class HfRepoGitTool:
|
| 47 |
"""Tool for git-like operations on HF repos."""
|
| 48 |
|
| 49 |
+
def __init__(self, hf_token: Optional[str] = None, session: Any = None):
|
| 50 |
self.api = HfApi(token=hf_token)
|
| 51 |
+
self.session = session
|
| 52 |
|
| 53 |
async def execute(self, args: Dict[str, Any]) -> ToolResult:
|
| 54 |
"""Execute the specified operation."""
|
|
|
|
| 554 |
kwargs["space_sdk"] = space_sdk
|
| 555 |
|
| 556 |
result = await _async_call(self.api.create_repo, **kwargs)
|
| 557 |
+
extra_metadata = None
|
| 558 |
+
if repo_type == "space" and space_sdk:
|
| 559 |
+
extra_metadata = {"sdk": space_sdk}
|
| 560 |
+
await _async_call(
|
| 561 |
+
register_hub_artifact,
|
| 562 |
+
self.api,
|
| 563 |
+
repo_id,
|
| 564 |
+
repo_type,
|
| 565 |
+
session=self.session,
|
| 566 |
+
extra_metadata=extra_metadata,
|
| 567 |
+
)
|
| 568 |
|
| 569 |
return {
|
| 570 |
"formatted": f"**Repository created:** {repo_id}\n**Private:** {private}\n{result}",
|
|
|
|
| 760 |
"""Handler for agent tool router."""
|
| 761 |
try:
|
| 762 |
hf_token = session.hf_token if session else None
|
| 763 |
+
tool = HfRepoGitTool(hf_token=hf_token, session=session)
|
| 764 |
result = await tool.execute(arguments)
|
| 765 |
return result["formatted"], not result.get("isError", False)
|
| 766 |
except Exception as e:
|
agent/tools/jobs_tool.py
CHANGED
|
@@ -9,6 +9,7 @@ import base64
|
|
| 9 |
import http.client
|
| 10 |
import logging
|
| 11 |
import re
|
|
|
|
| 12 |
from typing import Any, Awaitable, Callable, Dict, Literal, Optional
|
| 13 |
|
| 14 |
import httpx
|
|
@@ -20,6 +21,7 @@ from agent.core.hf_access import (
|
|
| 20 |
is_billing_error,
|
| 21 |
resolve_jobs_namespace,
|
| 22 |
)
|
|
|
|
| 23 |
from agent.core.session import Event
|
| 24 |
from agent.tools.trackio_seed import ensure_trackio_dashboard
|
| 25 |
from agent.tools.types import ToolResult
|
|
@@ -237,6 +239,26 @@ def _resolve_uv_command(
|
|
| 237 |
return _build_uv_command(script, with_deps, python, script_args)
|
| 238 |
|
| 239 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
async def _async_call(func, *args, **kwargs):
|
| 241 |
"""Wrap synchronous HfApi calls for async context"""
|
| 242 |
return await asyncio.to_thread(func, *args, **kwargs)
|
|
@@ -560,6 +582,8 @@ class HfJobsTool:
|
|
| 560 |
image = args.get("image", "python:3.12")
|
| 561 |
job_type = "Docker"
|
| 562 |
|
|
|
|
|
|
|
| 563 |
# Run the job
|
| 564 |
flavor = args.get("hardware_flavor", "cpu-basic")
|
| 565 |
timeout_str = args.get("timeout", "30m")
|
|
@@ -912,6 +936,8 @@ To verify, call this tool with `{{"operation": "inspect", "job_id": "{job_id}"}}
|
|
| 912 |
image = args.get("image", "python:3.12")
|
| 913 |
job_type = "Docker"
|
| 914 |
|
|
|
|
|
|
|
| 915 |
# Create scheduled job
|
| 916 |
scheduled_job = await _async_call(
|
| 917 |
self.api.create_scheduled_job,
|
|
|
|
| 9 |
import http.client
|
| 10 |
import logging
|
| 11 |
import re
|
| 12 |
+
import shlex
|
| 13 |
from typing import Any, Awaitable, Callable, Dict, Literal, Optional
|
| 14 |
|
| 15 |
import httpx
|
|
|
|
| 21 |
is_billing_error,
|
| 22 |
resolve_jobs_namespace,
|
| 23 |
)
|
| 24 |
+
from agent.core.hub_artifacts import build_hub_artifact_sitecustomize
|
| 25 |
from agent.core.session import Event
|
| 26 |
from agent.tools.trackio_seed import ensure_trackio_dashboard
|
| 27 |
from agent.tools.types import ToolResult
|
|
|
|
| 239 |
return _build_uv_command(script, with_deps, python, script_args)
|
| 240 |
|
| 241 |
|
| 242 |
+
def _wrap_command_with_artifact_bootstrap(
|
| 243 |
+
command: list[str], session: Any = None
|
| 244 |
+
) -> list[str]:
|
| 245 |
+
"""Install sitecustomize hooks before the user command runs in HF Jobs."""
|
| 246 |
+
sitecustomize = build_hub_artifact_sitecustomize(session)
|
| 247 |
+
if not sitecustomize:
|
| 248 |
+
return command
|
| 249 |
+
|
| 250 |
+
encoded = base64.b64encode(sitecustomize.encode("utf-8")).decode("ascii")
|
| 251 |
+
original_command = shlex.join(command)
|
| 252 |
+
shell = (
|
| 253 |
+
'set -e; _ml_intern_artifacts_dir="$(mktemp -d)"; '
|
| 254 |
+
f"printf %s {shlex.quote(encoded)} | base64 -d "
|
| 255 |
+
'> "$_ml_intern_artifacts_dir/sitecustomize.py"; '
|
| 256 |
+
'export PYTHONPATH="$_ml_intern_artifacts_dir${PYTHONPATH:+:$PYTHONPATH}"; '
|
| 257 |
+
f"exec {original_command}"
|
| 258 |
+
)
|
| 259 |
+
return ["/bin/sh", "-lc", shell]
|
| 260 |
+
|
| 261 |
+
|
| 262 |
async def _async_call(func, *args, **kwargs):
|
| 263 |
"""Wrap synchronous HfApi calls for async context"""
|
| 264 |
return await asyncio.to_thread(func, *args, **kwargs)
|
|
|
|
| 582 |
image = args.get("image", "python:3.12")
|
| 583 |
job_type = "Docker"
|
| 584 |
|
| 585 |
+
command = _wrap_command_with_artifact_bootstrap(command, self.session)
|
| 586 |
+
|
| 587 |
# Run the job
|
| 588 |
flavor = args.get("hardware_flavor", "cpu-basic")
|
| 589 |
timeout_str = args.get("timeout", "30m")
|
|
|
|
| 936 |
image = args.get("image", "python:3.12")
|
| 937 |
job_type = "Docker"
|
| 938 |
|
| 939 |
+
command = _wrap_command_with_artifact_bootstrap(command, self.session)
|
| 940 |
+
|
| 941 |
# Create scheduled job
|
| 942 |
scheduled_job = await _async_call(
|
| 943 |
self.api.create_scheduled_job,
|
agent/tools/local_tools.py
CHANGED
|
@@ -15,6 +15,8 @@ import tempfile
|
|
| 15 |
from pathlib import Path
|
| 16 |
from typing import Any
|
| 17 |
|
|
|
|
|
|
|
| 18 |
|
| 19 |
MAX_OUTPUT_CHARS = 25_000
|
| 20 |
MAX_LINE_LENGTH = 4000
|
|
@@ -98,10 +100,13 @@ def _truncate_output(
|
|
| 98 |
# ── Handlers ────────────────────────────────────────────────────────────
|
| 99 |
|
| 100 |
|
| 101 |
-
async def _bash_handler(
|
|
|
|
|
|
|
| 102 |
command = args.get("command", "")
|
| 103 |
if not command:
|
| 104 |
return "No command provided.", False
|
|
|
|
| 105 |
work_dir = args.get("work_dir", ".")
|
| 106 |
timeout = min(args.get("timeout") or DEFAULT_TIMEOUT, MAX_TIMEOUT)
|
| 107 |
try:
|
|
|
|
| 15 |
from pathlib import Path
|
| 16 |
from typing import Any
|
| 17 |
|
| 18 |
+
from agent.core.hub_artifacts import wrap_shell_command_with_hub_artifact_bootstrap
|
| 19 |
+
|
| 20 |
|
| 21 |
MAX_OUTPUT_CHARS = 25_000
|
| 22 |
MAX_LINE_LENGTH = 4000
|
|
|
|
| 100 |
# ── Handlers ────────────────────────────────────────────────────────────
|
| 101 |
|
| 102 |
|
| 103 |
+
async def _bash_handler(
|
| 104 |
+
args: dict[str, Any], session: Any = None, **_kw
|
| 105 |
+
) -> tuple[str, bool]:
|
| 106 |
command = args.get("command", "")
|
| 107 |
if not command:
|
| 108 |
return "No command provided.", False
|
| 109 |
+
command = wrap_shell_command_with_hub_artifact_bootstrap(command, session)
|
| 110 |
work_dir = args.get("work_dir", ".")
|
| 111 |
timeout = min(args.get("timeout") or DEFAULT_TIMEOUT, MAX_TIMEOUT)
|
| 112 |
try:
|
agent/tools/sandbox_tool.py
CHANGED
|
@@ -21,6 +21,7 @@ from typing import Any
|
|
| 21 |
|
| 22 |
from huggingface_hub import HfApi, SpaceHardware
|
| 23 |
|
|
|
|
| 24 |
from agent.core.session import Event
|
| 25 |
from agent.tools.sandbox_client import Sandbox
|
| 26 |
from agent.tools.trackio_seed import ensure_trackio_dashboard
|
|
@@ -729,6 +730,14 @@ def _make_tool_handler(sandbox_tool_name: str):
|
|
| 729 |
return "Sandbox is still starting. Please retry shortly.", False
|
| 730 |
|
| 731 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 732 |
result = await asyncio.to_thread(sb.call_tool, sandbox_tool_name, args)
|
| 733 |
if result.success:
|
| 734 |
output = result.output or "(no output)"
|
|
|
|
| 21 |
|
| 22 |
from huggingface_hub import HfApi, SpaceHardware
|
| 23 |
|
| 24 |
+
from agent.core.hub_artifacts import wrap_shell_command_with_hub_artifact_bootstrap
|
| 25 |
from agent.core.session import Event
|
| 26 |
from agent.tools.sandbox_client import Sandbox
|
| 27 |
from agent.tools.trackio_seed import ensure_trackio_dashboard
|
|
|
|
| 730 |
return "Sandbox is still starting. Please retry shortly.", False
|
| 731 |
|
| 732 |
try:
|
| 733 |
+
if sandbox_tool_name == "bash" and args.get("command"):
|
| 734 |
+
args = {
|
| 735 |
+
**args,
|
| 736 |
+
"command": wrap_shell_command_with_hub_artifact_bootstrap(
|
| 737 |
+
args["command"],
|
| 738 |
+
session,
|
| 739 |
+
),
|
| 740 |
+
}
|
| 741 |
result = await asyncio.to_thread(sb.call_tool, sandbox_tool_name, args)
|
| 742 |
if result.success:
|
| 743 |
output = result.output or "(no output)"
|
backend/session_manager.py
CHANGED
|
@@ -12,10 +12,11 @@ from typing import Any, Optional
|
|
| 12 |
|
| 13 |
from agent.config import load_config
|
| 14 |
from agent.core.agent_loop import process_submission
|
| 15 |
-
from agent.
|
| 16 |
from agent.core.session import Event, OpType, Session
|
| 17 |
from agent.core.session_persistence import get_session_store
|
| 18 |
from agent.core.tools import ToolRouter
|
|
|
|
| 19 |
|
| 20 |
# Get project root (parent of backend directory)
|
| 21 |
PROJECT_ROOT = Path(__file__).parent.parent
|
|
@@ -135,6 +136,7 @@ class SessionManager:
|
|
| 135 |
self.sessions: dict[str, AgentSession] = {}
|
| 136 |
self._lock = asyncio.Lock()
|
| 137 |
self.persistence_store = None
|
|
|
|
| 138 |
|
| 139 |
async def start(self) -> None:
|
| 140 |
"""Start shared background resources."""
|
|
@@ -411,6 +413,28 @@ class SessionManager:
|
|
| 411 |
session.sandbox_preload_cancel_event = None
|
| 412 |
self._start_cpu_sandbox_preload(agent_session)
|
| 413 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 414 |
async def _clear_persisted_sandbox_metadata(self, session_id: str) -> None:
|
| 415 |
try:
|
| 416 |
await self._store().update_session_fields(
|
|
@@ -567,6 +591,7 @@ class SessionManager:
|
|
| 567 |
existing,
|
| 568 |
preload_sandbox=preload_sandbox,
|
| 569 |
)
|
|
|
|
| 570 |
return existing
|
| 571 |
return None
|
| 572 |
|
|
@@ -588,6 +613,7 @@ class SessionManager:
|
|
| 588 |
existing,
|
| 589 |
preload_sandbox=preload_sandbox,
|
| 590 |
)
|
|
|
|
| 591 |
return existing
|
| 592 |
return None
|
| 593 |
|
|
@@ -674,7 +700,9 @@ class SessionManager:
|
|
| 674 |
hf_token=hf_token,
|
| 675 |
hf_username=hf_username,
|
| 676 |
)
|
|
|
|
| 677 |
return started
|
|
|
|
| 678 |
if preload_sandbox:
|
| 679 |
self._start_cpu_sandbox_preload(agent_session)
|
| 680 |
logger.info("Restored session %s for user %s", session_id, owner or user_id)
|
|
@@ -757,6 +785,7 @@ class SessionManager:
|
|
| 757 |
event_queue=event_queue,
|
| 758 |
tool_router=tool_router,
|
| 759 |
)
|
|
|
|
| 760 |
await self.persist_session_snapshot(agent_session, runtime_state="idle")
|
| 761 |
self._start_cpu_sandbox_preload(agent_session)
|
| 762 |
|
|
|
|
| 12 |
|
| 13 |
from agent.config import load_config
|
| 14 |
from agent.core.agent_loop import process_submission
|
| 15 |
+
from agent.core.hub_artifacts import start_session_artifact_collection_task
|
| 16 |
from agent.core.session import Event, OpType, Session
|
| 17 |
from agent.core.session_persistence import get_session_store
|
| 18 |
from agent.core.tools import ToolRouter
|
| 19 |
+
from agent.messaging.gateway import NotificationGateway
|
| 20 |
|
| 21 |
# Get project root (parent of backend directory)
|
| 22 |
PROJECT_ROOT = Path(__file__).parent.parent
|
|
|
|
| 136 |
self.sessions: dict[str, AgentSession] = {}
|
| 137 |
self._lock = asyncio.Lock()
|
| 138 |
self.persistence_store = None
|
| 139 |
+
self.enable_hub_artifact_collections = True
|
| 140 |
|
| 141 |
async def start(self) -> None:
|
| 142 |
"""Start shared background resources."""
|
|
|
|
| 413 |
session.sandbox_preload_cancel_event = None
|
| 414 |
self._start_cpu_sandbox_preload(agent_session)
|
| 415 |
|
| 416 |
+
def _start_hub_artifact_collection(self, agent_session: AgentSession) -> None:
|
| 417 |
+
"""Kick off best-effort Hub collection creation for the session."""
|
| 418 |
+
if not getattr(self, "enable_hub_artifact_collections", False):
|
| 419 |
+
return
|
| 420 |
+
session = agent_session.session
|
| 421 |
+
if not getattr(session, "session_id", None):
|
| 422 |
+
try:
|
| 423 |
+
session.session_id = agent_session.session_id
|
| 424 |
+
except Exception:
|
| 425 |
+
logger.debug("Could not attach session id for Hub artifact collection")
|
| 426 |
+
token = agent_session.hf_token or getattr(session, "hf_token", None)
|
| 427 |
+
if not token:
|
| 428 |
+
return
|
| 429 |
+
try:
|
| 430 |
+
start_session_artifact_collection_task(session, token=token)
|
| 431 |
+
except Exception as e:
|
| 432 |
+
logger.debug(
|
| 433 |
+
"Failed to schedule Hub artifact collection for %s: %s",
|
| 434 |
+
agent_session.session_id,
|
| 435 |
+
e,
|
| 436 |
+
)
|
| 437 |
+
|
| 438 |
async def _clear_persisted_sandbox_metadata(self, session_id: str) -> None:
|
| 439 |
try:
|
| 440 |
await self._store().update_session_fields(
|
|
|
|
| 591 |
existing,
|
| 592 |
preload_sandbox=preload_sandbox,
|
| 593 |
)
|
| 594 |
+
self._start_hub_artifact_collection(existing)
|
| 595 |
return existing
|
| 596 |
return None
|
| 597 |
|
|
|
|
| 613 |
existing,
|
| 614 |
preload_sandbox=preload_sandbox,
|
| 615 |
)
|
| 616 |
+
self._start_hub_artifact_collection(existing)
|
| 617 |
return existing
|
| 618 |
return None
|
| 619 |
|
|
|
|
| 700 |
hf_token=hf_token,
|
| 701 |
hf_username=hf_username,
|
| 702 |
)
|
| 703 |
+
self._start_hub_artifact_collection(started)
|
| 704 |
return started
|
| 705 |
+
self._start_hub_artifact_collection(agent_session)
|
| 706 |
if preload_sandbox:
|
| 707 |
self._start_cpu_sandbox_preload(agent_session)
|
| 708 |
logger.info("Restored session %s for user %s", session_id, owner or user_id)
|
|
|
|
| 785 |
event_queue=event_queue,
|
| 786 |
tool_router=tool_router,
|
| 787 |
)
|
| 788 |
+
self._start_hub_artifact_collection(agent_session)
|
| 789 |
await self.persist_session_snapshot(agent_session, runtime_state="idle")
|
| 790 |
self._start_cpu_sandbox_preload(agent_session)
|
| 791 |
|
tests/unit/test_hub_artifacts.py
ADDED
|
@@ -0,0 +1,505 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import logging
|
| 3 |
+
from types import SimpleNamespace
|
| 4 |
+
|
| 5 |
+
import pytest
|
| 6 |
+
|
| 7 |
+
from agent.core import hub_artifacts
|
| 8 |
+
from agent.core.hub_artifacts import (
|
| 9 |
+
ML_INTERN_TAG,
|
| 10 |
+
PROVENANCE_MARKER,
|
| 11 |
+
artifact_collection_title,
|
| 12 |
+
augment_repo_card_content,
|
| 13 |
+
build_hub_artifact_sitecustomize,
|
| 14 |
+
ensure_session_artifact_collection,
|
| 15 |
+
is_known_hub_artifact,
|
| 16 |
+
register_hub_artifact,
|
| 17 |
+
remember_hub_artifact,
|
| 18 |
+
start_session_artifact_collection_task,
|
| 19 |
+
wrap_shell_command_with_hub_artifact_bootstrap,
|
| 20 |
+
)
|
| 21 |
+
from agent.tools import local_tools, sandbox_tool
|
| 22 |
+
from agent.tools.hf_repo_files_tool import HfRepoFilesTool
|
| 23 |
+
from agent.tools.hf_repo_git_tool import HfRepoGitTool
|
| 24 |
+
from agent.tools.jobs_tool import _wrap_command_with_artifact_bootstrap
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _session() -> SimpleNamespace:
|
| 28 |
+
return SimpleNamespace(
|
| 29 |
+
session_id="session-123",
|
| 30 |
+
session_start_time="2026-05-05T10:20:30",
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def test_artifact_collection_title_uses_session_date_and_id():
|
| 35 |
+
assert (
|
| 36 |
+
artifact_collection_title(_session())
|
| 37 |
+
== "ml-intern-artifacts-2026-05-05-session-123"
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def test_artifact_collection_title_uses_short_uuid_fragment():
|
| 42 |
+
session = SimpleNamespace(
|
| 43 |
+
session_id="fadcbc77-3439-4c2b-bc52-50d7f6353af3",
|
| 44 |
+
session_start_time="2026-05-05T10:20:30",
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
title = artifact_collection_title(session)
|
| 48 |
+
|
| 49 |
+
assert title == "ml-intern-artifacts-2026-05-05-fadcbc77"
|
| 50 |
+
assert len(title) < 60
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def test_artifact_collection_title_still_truncates_long_non_uuid_ids():
|
| 54 |
+
session = SimpleNamespace(
|
| 55 |
+
session_id="custom-session-id-that-is-longer-than-the-hub-title-limit",
|
| 56 |
+
session_start_time="2026-05-05T10:20:30",
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
title = artifact_collection_title(session)
|
| 60 |
+
|
| 61 |
+
assert title.startswith("ml-intern-artifacts-2026-05-05-custom-session-id")
|
| 62 |
+
assert len(title) < 60
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def test_model_card_merges_tags_and_appends_provenance_and_usage():
|
| 66 |
+
content = """---
|
| 67 |
+
license: apache-2.0
|
| 68 |
+
tags:
|
| 69 |
+
- text-generation
|
| 70 |
+
---
|
| 71 |
+
# Existing Model
|
| 72 |
+
|
| 73 |
+
Existing details stay here.
|
| 74 |
+
"""
|
| 75 |
+
|
| 76 |
+
updated = augment_repo_card_content(content, "alice/model", "model")
|
| 77 |
+
second_pass = augment_repo_card_content(updated, "alice/model", "model")
|
| 78 |
+
|
| 79 |
+
assert "license: apache-2.0" in updated
|
| 80 |
+
assert "- text-generation" in updated
|
| 81 |
+
assert f"- {ML_INTERN_TAG}" in updated
|
| 82 |
+
assert "# Existing Model" in updated
|
| 83 |
+
assert "Existing details stay here." in updated
|
| 84 |
+
assert PROVENANCE_MARKER in updated
|
| 85 |
+
assert "AutoModelForCausalLM" in updated
|
| 86 |
+
assert second_pass.count(PROVENANCE_MARKER) == 1
|
| 87 |
+
assert second_pass.count("AutoModelForCausalLM") == updated.count(
|
| 88 |
+
"AutoModelForCausalLM"
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def test_dataset_card_adds_load_dataset_usage():
|
| 93 |
+
updated = augment_repo_card_content("", "alice/dataset", "dataset")
|
| 94 |
+
|
| 95 |
+
assert f"- {ML_INTERN_TAG}" in updated
|
| 96 |
+
assert "# alice/dataset" in updated
|
| 97 |
+
assert "from datasets import load_dataset" in updated
|
| 98 |
+
assert 'load_dataset("alice/dataset")' in updated
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def test_existing_usage_section_is_preserved_without_duplicate_usage():
|
| 102 |
+
content = """# Existing Dataset
|
| 103 |
+
|
| 104 |
+
## Usage
|
| 105 |
+
|
| 106 |
+
Use the custom loader in this repository.
|
| 107 |
+
"""
|
| 108 |
+
|
| 109 |
+
updated = augment_repo_card_content(content, "alice/dataset", "dataset")
|
| 110 |
+
|
| 111 |
+
assert "Use the custom loader in this repository." in updated
|
| 112 |
+
assert "from datasets import load_dataset" not in updated
|
| 113 |
+
assert PROVENANCE_MARKER in updated
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def test_space_card_gets_metadata_without_provenance_body():
|
| 117 |
+
updated = augment_repo_card_content("# Existing Space\n", "alice/space", "space")
|
| 118 |
+
|
| 119 |
+
assert f"- {ML_INTERN_TAG}" in updated
|
| 120 |
+
assert "# Existing Space" in updated
|
| 121 |
+
assert PROVENANCE_MARKER not in updated
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def test_register_hub_artifact_creates_private_collection_and_adds_item_once(
|
| 125 |
+
monkeypatch,
|
| 126 |
+
):
|
| 127 |
+
session = _session()
|
| 128 |
+
|
| 129 |
+
class FakeApi:
|
| 130 |
+
token = "hf-token"
|
| 131 |
+
|
| 132 |
+
def __init__(self):
|
| 133 |
+
self.created_collections = []
|
| 134 |
+
self.collection_items = []
|
| 135 |
+
self.uploads = []
|
| 136 |
+
|
| 137 |
+
def create_collection(self, **kwargs):
|
| 138 |
+
self.created_collections.append(kwargs)
|
| 139 |
+
return SimpleNamespace(slug="alice/ml-intern-artifacts")
|
| 140 |
+
|
| 141 |
+
def add_collection_item(self, **kwargs):
|
| 142 |
+
self.collection_items.append(kwargs)
|
| 143 |
+
|
| 144 |
+
def upload_file(self, **kwargs):
|
| 145 |
+
self.uploads.append(kwargs)
|
| 146 |
+
|
| 147 |
+
api = FakeApi()
|
| 148 |
+
monkeypatch.setattr(hub_artifacts, "_read_remote_readme", lambda *_, **__: "")
|
| 149 |
+
|
| 150 |
+
assert register_hub_artifact(api, "alice/model", "model", session=session)
|
| 151 |
+
assert register_hub_artifact(api, "alice/model", "model", session=session)
|
| 152 |
+
|
| 153 |
+
assert is_known_hub_artifact(session, "alice/model", "model")
|
| 154 |
+
assert len(api.created_collections) == 1
|
| 155 |
+
assert api.created_collections[0]["title"] == artifact_collection_title(session)
|
| 156 |
+
assert api.created_collections[0]["private"] is True
|
| 157 |
+
assert len(api.collection_items) == 1
|
| 158 |
+
assert api.collection_items[0]["item_id"] == "alice/model"
|
| 159 |
+
assert api.collection_items[0]["item_type"] == "model"
|
| 160 |
+
assert api.collection_items[0]["exists_ok"] is True
|
| 161 |
+
assert len(api.uploads) == 1
|
| 162 |
+
assert b"ml-intern" in api.uploads[0]["path_or_fileobj"]
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def test_register_hub_artifact_retries_after_partial_failure(monkeypatch):
|
| 166 |
+
session = _session()
|
| 167 |
+
api = SimpleNamespace(token="hf-token")
|
| 168 |
+
card_attempts = 0
|
| 169 |
+
collection_attempts = 0
|
| 170 |
+
|
| 171 |
+
def flaky_update_repo_card(*args, **kwargs):
|
| 172 |
+
nonlocal card_attempts
|
| 173 |
+
card_attempts += 1
|
| 174 |
+
if card_attempts == 1:
|
| 175 |
+
raise RuntimeError("temporary card failure")
|
| 176 |
+
|
| 177 |
+
def add_to_collection(*args, **kwargs):
|
| 178 |
+
nonlocal collection_attempts
|
| 179 |
+
collection_attempts += 1
|
| 180 |
+
|
| 181 |
+
monkeypatch.setattr(
|
| 182 |
+
hub_artifacts,
|
| 183 |
+
"_update_repo_card",
|
| 184 |
+
flaky_update_repo_card,
|
| 185 |
+
)
|
| 186 |
+
monkeypatch.setattr(hub_artifacts, "_add_to_collection", add_to_collection)
|
| 187 |
+
|
| 188 |
+
assert not register_hub_artifact(api, "alice/model", "model", session=session)
|
| 189 |
+
assert register_hub_artifact(api, "alice/model", "model", session=session)
|
| 190 |
+
assert register_hub_artifact(api, "alice/model", "model", session=session)
|
| 191 |
+
|
| 192 |
+
assert card_attempts == 2
|
| 193 |
+
assert collection_attempts == 2
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def test_register_hub_artifact_retries_after_collection_failure(monkeypatch):
|
| 197 |
+
session = _session()
|
| 198 |
+
api = SimpleNamespace(token="hf-token")
|
| 199 |
+
card_attempts = 0
|
| 200 |
+
collection_attempts = 0
|
| 201 |
+
|
| 202 |
+
def update_repo_card(*args, **kwargs):
|
| 203 |
+
nonlocal card_attempts
|
| 204 |
+
card_attempts += 1
|
| 205 |
+
|
| 206 |
+
def flaky_add_to_collection(*args, **kwargs):
|
| 207 |
+
nonlocal collection_attempts
|
| 208 |
+
collection_attempts += 1
|
| 209 |
+
if collection_attempts == 1:
|
| 210 |
+
raise RuntimeError("temporary collection failure")
|
| 211 |
+
|
| 212 |
+
monkeypatch.setattr(hub_artifacts, "_update_repo_card", update_repo_card)
|
| 213 |
+
monkeypatch.setattr(
|
| 214 |
+
hub_artifacts,
|
| 215 |
+
"_add_to_collection",
|
| 216 |
+
flaky_add_to_collection,
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
assert not register_hub_artifact(api, "alice/model", "model", session=session)
|
| 220 |
+
assert register_hub_artifact(api, "alice/model", "model", session=session)
|
| 221 |
+
assert register_hub_artifact(api, "alice/model", "model", session=session)
|
| 222 |
+
|
| 223 |
+
assert card_attempts == 2
|
| 224 |
+
assert collection_attempts == 2
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def test_session_artifact_set_falls_back_when_session_rejects_attrs(caplog):
|
| 228 |
+
class SlottedSession:
|
| 229 |
+
__slots__ = ("session_id", "session_start_time")
|
| 230 |
+
|
| 231 |
+
def __init__(self):
|
| 232 |
+
self.session_id = "session-123"
|
| 233 |
+
self.session_start_time = "2026-05-05T10:20:30"
|
| 234 |
+
|
| 235 |
+
session = SlottedSession()
|
| 236 |
+
|
| 237 |
+
with caplog.at_level(logging.WARNING):
|
| 238 |
+
remember_hub_artifact(session, "alice/model", "model")
|
| 239 |
+
|
| 240 |
+
assert is_known_hub_artifact(session, "alice/model", "model")
|
| 241 |
+
assert "using process-local fallback state" in caplog.text
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
@pytest.mark.asyncio
|
| 245 |
+
async def test_ensure_session_artifact_collection_uses_user_token(monkeypatch):
|
| 246 |
+
session = _session()
|
| 247 |
+
calls = []
|
| 248 |
+
|
| 249 |
+
class FakeApi:
|
| 250 |
+
def __init__(self, token):
|
| 251 |
+
self.token = token
|
| 252 |
+
|
| 253 |
+
def fake_ensure_collection_slug(api, seen_session, **kwargs):
|
| 254 |
+
calls.append((api.token, seen_session, kwargs))
|
| 255 |
+
return "alice/ml-intern-artifacts"
|
| 256 |
+
|
| 257 |
+
monkeypatch.setattr(hub_artifacts, "HfApi", FakeApi)
|
| 258 |
+
monkeypatch.setattr(
|
| 259 |
+
hub_artifacts,
|
| 260 |
+
"_ensure_collection_slug",
|
| 261 |
+
fake_ensure_collection_slug,
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
slug = await ensure_session_artifact_collection(session, token="hf-token")
|
| 265 |
+
|
| 266 |
+
assert slug == "alice/ml-intern-artifacts"
|
| 267 |
+
assert calls == [
|
| 268 |
+
("hf-token", session, {"token": "hf-token"}),
|
| 269 |
+
]
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
@pytest.mark.asyncio
|
| 273 |
+
async def test_start_session_artifact_collection_task_dedupes(monkeypatch):
|
| 274 |
+
session = _session()
|
| 275 |
+
calls = []
|
| 276 |
+
|
| 277 |
+
async def fake_ensure_session_artifact_collection(seen_session, **kwargs):
|
| 278 |
+
calls.append((seen_session, kwargs))
|
| 279 |
+
await asyncio.sleep(0)
|
| 280 |
+
return "alice/ml-intern-artifacts"
|
| 281 |
+
|
| 282 |
+
monkeypatch.setattr(
|
| 283 |
+
hub_artifacts,
|
| 284 |
+
"ensure_session_artifact_collection",
|
| 285 |
+
fake_ensure_session_artifact_collection,
|
| 286 |
+
)
|
| 287 |
+
|
| 288 |
+
task = start_session_artifact_collection_task(session, token="hf-token")
|
| 289 |
+
second = start_session_artifact_collection_task(session, token="hf-token")
|
| 290 |
+
|
| 291 |
+
assert task is not None
|
| 292 |
+
assert second is task
|
| 293 |
+
await task
|
| 294 |
+
assert calls == [(session, {"token": "hf-token"})]
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
def test_start_session_artifact_collection_task_skips_without_token():
|
| 298 |
+
assert start_session_artifact_collection_task(_session()) is None
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
@pytest.mark.asyncio
|
| 302 |
+
async def test_hf_repo_git_create_repo_registers_artifact(monkeypatch):
|
| 303 |
+
session = _session()
|
| 304 |
+
calls = []
|
| 305 |
+
|
| 306 |
+
class FakeApi:
|
| 307 |
+
token = "hf-token"
|
| 308 |
+
|
| 309 |
+
def create_repo(self, **kwargs):
|
| 310 |
+
self.create_kwargs = kwargs
|
| 311 |
+
return "https://huggingface.co/spaces/alice/demo"
|
| 312 |
+
|
| 313 |
+
def fake_register(api, repo_id, repo_type, **kwargs):
|
| 314 |
+
calls.append((api, repo_id, repo_type, kwargs))
|
| 315 |
+
return True
|
| 316 |
+
|
| 317 |
+
monkeypatch.setattr(
|
| 318 |
+
"agent.tools.hf_repo_git_tool.register_hub_artifact",
|
| 319 |
+
fake_register,
|
| 320 |
+
)
|
| 321 |
+
tool = HfRepoGitTool(hf_token="hf-token", session=session)
|
| 322 |
+
tool.api = FakeApi()
|
| 323 |
+
|
| 324 |
+
result = await tool._create_repo(
|
| 325 |
+
{
|
| 326 |
+
"repo_id": "alice/demo",
|
| 327 |
+
"repo_type": "space",
|
| 328 |
+
"space_sdk": "gradio",
|
| 329 |
+
"private": True,
|
| 330 |
+
}
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
assert result["totalResults"] == 1
|
| 334 |
+
assert calls == [
|
| 335 |
+
(
|
| 336 |
+
tool.api,
|
| 337 |
+
"alice/demo",
|
| 338 |
+
"space",
|
| 339 |
+
{"session": session, "extra_metadata": {"sdk": "gradio"}},
|
| 340 |
+
)
|
| 341 |
+
]
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
@pytest.mark.asyncio
|
| 345 |
+
async def test_hf_repo_files_upload_registers_known_artifact_with_force(monkeypatch):
|
| 346 |
+
session = _session()
|
| 347 |
+
calls = []
|
| 348 |
+
uploads = []
|
| 349 |
+
|
| 350 |
+
class FakeApi:
|
| 351 |
+
token = "hf-token"
|
| 352 |
+
|
| 353 |
+
def upload_file(self, **kwargs):
|
| 354 |
+
uploads.append(kwargs)
|
| 355 |
+
return SimpleNamespace()
|
| 356 |
+
|
| 357 |
+
def fake_register(api, repo_id, repo_type, **kwargs):
|
| 358 |
+
calls.append((api, repo_id, repo_type, kwargs))
|
| 359 |
+
return True
|
| 360 |
+
|
| 361 |
+
monkeypatch.setattr(
|
| 362 |
+
"agent.tools.hf_repo_files_tool.register_hub_artifact",
|
| 363 |
+
fake_register,
|
| 364 |
+
)
|
| 365 |
+
remember_hub_artifact(session, "alice/model", "model")
|
| 366 |
+
|
| 367 |
+
tool = HfRepoFilesTool(hf_token="hf-token", session=session)
|
| 368 |
+
tool.api = FakeApi()
|
| 369 |
+
|
| 370 |
+
result = await tool._upload(
|
| 371 |
+
{
|
| 372 |
+
"repo_id": "alice/model",
|
| 373 |
+
"repo_type": "model",
|
| 374 |
+
"path": "weights.bin",
|
| 375 |
+
"content": b"weights",
|
| 376 |
+
}
|
| 377 |
+
)
|
| 378 |
+
readme_result = await tool._upload(
|
| 379 |
+
{
|
| 380 |
+
"repo_id": "alice/model",
|
| 381 |
+
"repo_type": "model",
|
| 382 |
+
"path": "README.md",
|
| 383 |
+
"content": "# Model",
|
| 384 |
+
}
|
| 385 |
+
)
|
| 386 |
+
|
| 387 |
+
assert result["totalResults"] == 1
|
| 388 |
+
assert readme_result["totalResults"] == 1
|
| 389 |
+
assert [upload["path_in_repo"] for upload in uploads] == [
|
| 390 |
+
"weights.bin",
|
| 391 |
+
"README.md",
|
| 392 |
+
]
|
| 393 |
+
assert calls == [
|
| 394 |
+
(
|
| 395 |
+
tool.api,
|
| 396 |
+
"alice/model",
|
| 397 |
+
"model",
|
| 398 |
+
{"session": session, "force": False},
|
| 399 |
+
),
|
| 400 |
+
(
|
| 401 |
+
tool.api,
|
| 402 |
+
"alice/model",
|
| 403 |
+
"model",
|
| 404 |
+
{"session": session, "force": True},
|
| 405 |
+
),
|
| 406 |
+
]
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
def test_hf_jobs_artifact_bootstrap_wraps_command_without_changing_exec_target():
|
| 410 |
+
command = ["uv", "run", "train.py"]
|
| 411 |
+
wrapped = _wrap_command_with_artifact_bootstrap(command, _session())
|
| 412 |
+
|
| 413 |
+
assert wrapped[0:2] == ["/bin/sh", "-lc"]
|
| 414 |
+
assert "sitecustomize.py" in wrapped[2]
|
| 415 |
+
assert "PYTHONPATH" in wrapped[2]
|
| 416 |
+
assert "exec uv run train.py" in wrapped[2]
|
| 417 |
+
assert _wrap_command_with_artifact_bootstrap(command, None) == command
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
def test_shell_bootstrap_wraps_capybara_push_to_hub_pattern():
|
| 421 |
+
command = (
|
| 422 |
+
"pip install -q datasets huggingface_hub && python -c "
|
| 423 |
+
"\"subset.push_to_hub('lewtun/Capybara-100', private=False)\""
|
| 424 |
+
)
|
| 425 |
+
|
| 426 |
+
wrapped = wrap_shell_command_with_hub_artifact_bootstrap(command, _session())
|
| 427 |
+
|
| 428 |
+
assert "sitecustomize.py" in wrapped
|
| 429 |
+
assert "PYTHONPATH" in wrapped
|
| 430 |
+
assert command in wrapped
|
| 431 |
+
assert wrap_shell_command_with_hub_artifact_bootstrap(command, None) == command
|
| 432 |
+
assert (
|
| 433 |
+
wrap_shell_command_with_hub_artifact_bootstrap(
|
| 434 |
+
command,
|
| 435 |
+
SimpleNamespace(session_start_time="2026-05-05T10:20:30"),
|
| 436 |
+
)
|
| 437 |
+
== command
|
| 438 |
+
)
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
@pytest.mark.asyncio
|
| 442 |
+
async def test_sandbox_bash_wraps_command_for_session_artifact_hooks():
|
| 443 |
+
calls = []
|
| 444 |
+
|
| 445 |
+
class FakeSandbox:
|
| 446 |
+
def call_tool(self, name, args):
|
| 447 |
+
calls.append((name, args))
|
| 448 |
+
return SimpleNamespace(success=True, output="ok", error="")
|
| 449 |
+
|
| 450 |
+
session = _session()
|
| 451 |
+
session.sandbox = FakeSandbox()
|
| 452 |
+
|
| 453 |
+
handler = sandbox_tool._make_tool_handler("bash")
|
| 454 |
+
output, ok = await handler({"command": "python make_dataset.py"}, session=session)
|
| 455 |
+
|
| 456 |
+
assert ok is True
|
| 457 |
+
assert output == "ok"
|
| 458 |
+
assert calls[0][0] == "bash"
|
| 459 |
+
assert "sitecustomize.py" in calls[0][1]["command"]
|
| 460 |
+
assert "python make_dataset.py" in calls[0][1]["command"]
|
| 461 |
+
|
| 462 |
+
|
| 463 |
+
@pytest.mark.asyncio
|
| 464 |
+
async def test_local_bash_wraps_command_for_session_artifact_hooks(monkeypatch):
|
| 465 |
+
seen = {}
|
| 466 |
+
|
| 467 |
+
def fake_run(command, **kwargs):
|
| 468 |
+
seen["command"] = command
|
| 469 |
+
seen["kwargs"] = kwargs
|
| 470 |
+
return SimpleNamespace(stdout="ok", stderr="", returncode=0)
|
| 471 |
+
|
| 472 |
+
monkeypatch.setattr(local_tools.subprocess, "run", fake_run)
|
| 473 |
+
|
| 474 |
+
output, ok = await local_tools._bash_handler(
|
| 475 |
+
{"command": "python make_dataset.py"},
|
| 476 |
+
session=_session(),
|
| 477 |
+
)
|
| 478 |
+
|
| 479 |
+
assert ok is True
|
| 480 |
+
assert output == "ok"
|
| 481 |
+
assert "sitecustomize.py" in seen["command"]
|
| 482 |
+
assert "python make_dataset.py" in seen["command"]
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
def test_sitecustomize_bootstrap_is_valid_python():
|
| 486 |
+
code = build_hub_artifact_sitecustomize(_session())
|
| 487 |
+
|
| 488 |
+
compile(code, "sitecustomize.py", "exec")
|
| 489 |
+
assert "ml-intern-artifacts-2026-05-05-session-123" in code
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
def test_sitecustomize_bootstrap_reuses_existing_collection_slug():
|
| 493 |
+
session = _session()
|
| 494 |
+
setattr(
|
| 495 |
+
session,
|
| 496 |
+
hub_artifacts._COLLECTION_SLUG_ATTR,
|
| 497 |
+
"alice/ml-intern-artifacts-2026-05-05-session-123",
|
| 498 |
+
)
|
| 499 |
+
|
| 500 |
+
code = build_hub_artifact_sitecustomize(session)
|
| 501 |
+
|
| 502 |
+
compile(code, "sitecustomize.py", "exec")
|
| 503 |
+
assert (
|
| 504 |
+
"collection_slug = 'alice/ml-intern-artifacts-2026-05-05-session-123'" in code
|
| 505 |
+
)
|
tests/unit/test_session_manager_persistence.py
CHANGED
|
@@ -430,6 +430,32 @@ async def test_create_session_schedules_cpu_sandbox_preload():
|
|
| 430 |
await _cancel_runtime_tasks(manager)
|
| 431 |
|
| 432 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 433 |
@pytest.mark.asyncio
|
| 434 |
async def test_lazy_restore_schedules_cpu_sandbox_preload():
|
| 435 |
manager = _manager_with_store(RestoreStore())
|
|
@@ -454,6 +480,37 @@ async def test_lazy_restore_schedules_cpu_sandbox_preload():
|
|
| 454 |
await _cancel_runtime_tasks(manager)
|
| 455 |
|
| 456 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 457 |
@pytest.mark.asyncio
|
| 458 |
async def test_lazy_restore_deletes_persisted_sandbox_before_preload(monkeypatch):
|
| 459 |
deleted: list[tuple[str, str, str]] = []
|
|
|
|
| 430 |
await _cancel_runtime_tasks(manager)
|
| 431 |
|
| 432 |
|
| 433 |
+
@pytest.mark.asyncio
|
| 434 |
+
async def test_create_session_starts_hub_artifact_collection(monkeypatch):
|
| 435 |
+
manager = _manager_with_store(NoopSessionStore())
|
| 436 |
+
manager.enable_hub_artifact_collections = True
|
| 437 |
+
stop = _install_fake_runtime(manager)
|
| 438 |
+
started: list[tuple[str, str]] = []
|
| 439 |
+
|
| 440 |
+
def fake_start_session_artifact_collection_task(session, **kwargs):
|
| 441 |
+
started.append((session.session_id, kwargs["token"]))
|
| 442 |
+
return None
|
| 443 |
+
|
| 444 |
+
monkeypatch.setattr(
|
| 445 |
+
"session_manager.start_session_artifact_collection_task",
|
| 446 |
+
fake_start_session_artifact_collection_task,
|
| 447 |
+
)
|
| 448 |
+
manager._start_cpu_sandbox_preload = lambda _: None # type: ignore[method-assign]
|
| 449 |
+
|
| 450 |
+
try:
|
| 451 |
+
session_id = await manager.create_session(user_id="owner", hf_token="token")
|
| 452 |
+
|
| 453 |
+
assert started == [(session_id, "token")]
|
| 454 |
+
finally:
|
| 455 |
+
stop.set()
|
| 456 |
+
await _cancel_runtime_tasks(manager)
|
| 457 |
+
|
| 458 |
+
|
| 459 |
@pytest.mark.asyncio
|
| 460 |
async def test_lazy_restore_schedules_cpu_sandbox_preload():
|
| 461 |
manager = _manager_with_store(RestoreStore())
|
|
|
|
| 480 |
await _cancel_runtime_tasks(manager)
|
| 481 |
|
| 482 |
|
| 483 |
+
@pytest.mark.asyncio
|
| 484 |
+
async def test_lazy_restore_starts_hub_artifact_collection(monkeypatch):
|
| 485 |
+
manager = _manager_with_store(RestoreStore())
|
| 486 |
+
manager.enable_hub_artifact_collections = True
|
| 487 |
+
stop = _install_fake_runtime(manager)
|
| 488 |
+
started: list[tuple[str, str]] = []
|
| 489 |
+
|
| 490 |
+
def fake_start_session_artifact_collection_task(session, **kwargs):
|
| 491 |
+
started.append((session.session_id, kwargs["token"]))
|
| 492 |
+
return None
|
| 493 |
+
|
| 494 |
+
monkeypatch.setattr(
|
| 495 |
+
"session_manager.start_session_artifact_collection_task",
|
| 496 |
+
fake_start_session_artifact_collection_task,
|
| 497 |
+
)
|
| 498 |
+
manager._start_cpu_sandbox_preload = lambda _: None # type: ignore[method-assign]
|
| 499 |
+
|
| 500 |
+
try:
|
| 501 |
+
restored = await manager.ensure_session_loaded(
|
| 502 |
+
"persisted-session",
|
| 503 |
+
user_id="owner",
|
| 504 |
+
hf_token="token",
|
| 505 |
+
)
|
| 506 |
+
|
| 507 |
+
assert restored is not None
|
| 508 |
+
assert started == [("persisted-session", "token")]
|
| 509 |
+
finally:
|
| 510 |
+
stop.set()
|
| 511 |
+
await _cancel_runtime_tasks(manager)
|
| 512 |
+
|
| 513 |
+
|
| 514 |
@pytest.mark.asyncio
|
| 515 |
async def test_lazy_restore_deletes_persisted_sandbox_before_preload(monkeypatch):
|
| 516 |
deleted: list[tuple[str, str, str]] = []
|