| """Best-effort Hub metadata for artifacts generated by ML Intern sessions.""" |
|
|
| import base64 |
| import logging |
| import re |
| import shlex |
| import tempfile |
| import textwrap |
| from datetime import datetime |
| from pathlib import Path |
| from typing import Any |
|
|
| from huggingface_hub import hf_hub_download |
| from huggingface_hub.repocard import metadata_load, metadata_save |
| from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError |
|
|
# Module-level logger used by every best-effort code path in this file.
logger: logging.Logger = logging.getLogger(__name__)

# Tag merged into the ``tags`` front-matter entry of each augmented repo card.
ML_INTERN_TAG: str = "ml-intern"
# Repo types accepted by ``register_hub_artifact``.
SUPPORTED_REPO_TYPES: set[str] = {"model", "dataset", "space"}
# HTML comment whose presence marks a README that already carries provenance.
PROVENANCE_MARKER: str = "<!-- ml-intern-provenance -->"

# Collection titles look like ``<prefix>-<YYYY-MM-DD>-<session fragment>``.
_COLLECTION_TITLE_PREFIX: str = "ml-intern-artifacts"
# NOTE(review): presumably the Hub's collection-title length limit -- confirm.
_COLLECTION_TITLE_MAX_LENGTH: int = 59
# Canonical 8-4-4-4-12 hexadecimal UUID, matched against sanitized session ids.
_UUID_SESSION_ID_RE: re.Pattern[str] = re.compile(
    r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-"
    r"[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$"
)

# Names of the attributes attached to session objects to hold per-session state.
_KNOWN_ARTIFACTS_ATTR: str = "_ml_intern_known_hub_artifacts"
_REGISTERED_ARTIFACTS_ATTR: str = "_ml_intern_registered_hub_artifacts"
_COLLECTION_SLUG_ATTR: str = "_ml_intern_artifact_collection_slug"
# Process-local state for sessions that refuse new attributes, keyed by
# ``(id(session), attribute name)``.
_SESSION_ARTIFACT_SET_FALLBACK: dict[tuple[int, str], set[str]] = {}

# Level 2-6 Markdown heading that introduces a usage section.
_USAGE_HEADING_RE: re.Pattern[str] = re.compile(
    r"^#{2,6}\s+(usage|how to use|using this (model|dataset)|use this (model|dataset))\b",
    re.MULTILINE | re.IGNORECASE,
)
# Leading ``---`` delimited YAML front matter at the very start of a README.
_FRONT_MATTER_RE: re.Pattern[str] = re.compile(
    r"\A---\s*\n.*?\n---\s*\n?", re.DOTALL
)
|
|
|
|
def _safe_session_id(session: Any) -> str:
    """Return the session's id reduced to ``[A-Za-z0-9._-]`` characters.

    Each run of other characters becomes a single ``-`` and leading/trailing
    dashes are dropped. A missing, falsy, or fully-stripped id yields
    ``"unknown-session"``.
    """
    placeholder = "unknown-session"
    session_id = getattr(session, "session_id", "")
    text = str(session_id or placeholder)
    cleaned = re.sub(r"[^A-Za-z0-9._-]+", "-", text).strip("-")
    return cleaned if cleaned else placeholder
|
|
|
|
def session_artifact_date(session: Any) -> str:
    """Return the YYYY-MM-DD partition date for a session.

    The date is taken from ``session.session_start_time`` when that value
    parses as an ISO-8601 timestamp (a trailing ``Z`` is accepted). The
    timestamp is formatted as written, without timezone conversion. If the
    attribute is missing, falsy, or unparseable, today's UTC date is used.
    """
    # Local import: the file's import block only brings in ``datetime``.
    from datetime import timezone

    raw = getattr(session, "session_start_time", None)
    if raw:
        try:
            # ``fromisoformat`` on older Pythons rejects the "Z" suffix.
            return datetime.fromisoformat(str(raw).replace("Z", "+00:00")).strftime(
                "%Y-%m-%d"
            )
        except ValueError:
            logger.debug("Could not parse session_start_time=%r", raw)
    # ``datetime.utcnow()`` is deprecated since Python 3.12; an aware "now"
    # in UTC formats to the same calendar date.
    return datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
|
|
|
def _collection_session_id_fragment(session: Any) -> str:
    """Return the session-id portion of the collection title.

    UUID-shaped ids are shortened to their first eight characters. Any other
    id is clipped so that prefix, date, and id together stay within
    ``_COLLECTION_TITLE_MAX_LENGTH`` (always keeping at least one character).
    """
    session_id = _safe_session_id(session)
    if _UUID_SESSION_ID_RE.match(session_id) is not None:
        return session_id[:8]

    title_stem = f"{_COLLECTION_TITLE_PREFIX}-{session_artifact_date(session)}-"
    budget = max(1, _COLLECTION_TITLE_MAX_LENGTH - len(title_stem))
    if len(session_id) <= budget:
        return session_id

    clipped = session_id[:budget]
    # Avoid ending on a separator, unless that would leave nothing at all.
    return clipped.rstrip("-._") or clipped
|
|
|
|
def artifact_collection_title(session: Any) -> str:
    """Return the ``<prefix>-<date>-<session fragment>`` collection title."""
    pieces = (
        _COLLECTION_TITLE_PREFIX,
        session_artifact_date(session),
        _collection_session_id_fragment(session),
    )
    return "-".join(pieces)
|
|
|
|
| def _artifact_key(repo_id: str, repo_type: str | None) -> str: |
| return f"{repo_type or 'model'}:{repo_id}" |
|
|
|
|
def _sandbox_space_name_pattern() -> str:
    """Return the regex source that matches sandbox Space repo names."""
    # Imported at call time rather than at module import.
    # NOTE(review): presumably to avoid an import cycle with the tools package -- confirm.
    from agent.tools.sandbox_tool import SANDBOX_SPACE_NAME_RE as sandbox_name_re

    return sandbox_name_re.pattern
|
|
|
|
| def is_sandbox_hub_repo(repo_id: str | None, repo_type: str | None) -> bool: |
| """Return True for ML Intern's ephemeral sandbox Space repos.""" |
| if (repo_type or "model") != "space" or not repo_id: |
| return False |
| repo_name = str(repo_id).rsplit("/", 1)[-1] |
| return bool(re.fullmatch(_sandbox_space_name_pattern(), repo_name)) |
|
|
|
|
def _session_artifact_set(session: Any, attr: str) -> set[str]:
    """Return the mutable set stored on ``session`` under ``attr``.

    If the attribute is missing or is not a ``set``, a fresh set is attached
    and returned. When the session rejects new attributes, a process-local
    set keyed by ``(id(session), attr)`` is used instead, and the fallback is
    logged only the first time it is created for that key.
    """
    existing = getattr(session, attr, None)
    if isinstance(existing, set):
        return existing

    fresh: set[str] = set()
    try:
        setattr(session, attr, fresh)
    except Exception:
        fallback_key = (id(session), attr)
        # Warn once per key: without this guard every call re-logs, because
        # the attribute lookup above can never succeed for such a session.
        if fallback_key not in _SESSION_ARTIFACT_SET_FALLBACK:
            logger.warning(
                "Could not attach %s to session; using process-local fallback state",
                attr,
            )
        return _SESSION_ARTIFACT_SET_FALLBACK.setdefault(fallback_key, fresh)
    return fresh
|
|
|
|
| def remember_hub_artifact(session: Any, repo_id: str, repo_type: str | None) -> None: |
| if session is None or not repo_id: |
| return |
| _session_artifact_set(session, _KNOWN_ARTIFACTS_ATTR).add( |
| _artifact_key(repo_id, repo_type) |
| ) |
|
|
|
|
| def is_known_hub_artifact(session: Any, repo_id: str, repo_type: str | None) -> bool: |
| if session is None or not repo_id: |
| return False |
| return _artifact_key(repo_id, repo_type) in _session_artifact_set( |
| session, _KNOWN_ARTIFACTS_ATTR |
| ) |
|
|
|
|
def _merge_tags(metadata: dict[str, Any], tag: str = ML_INTERN_TAG) -> dict[str, Any]:
    """Return a shallow copy of ``metadata`` whose ``tags`` list contains ``tag``.

    An absent ``tags`` entry becomes an empty list, a single string becomes a
    one-element list, list items are stringified, and any other value is
    stringified as a single tag. The input mapping is not modified.
    """
    result = dict(metadata)
    existing = result.get("tags")

    if existing is None:
        collected: list[str] = []
    elif isinstance(existing, list):
        collected = [str(entry) for entry in existing]
    else:
        # A lone string is kept as-is; anything else is stringified whole.
        collected = [existing if isinstance(existing, str) else str(existing)]

    if tag not in collected:
        collected.append(tag)

    result["tags"] = collected
    return result
|
|
|
|
def _metadata_from_content(content: str) -> dict[str, Any]:
    """Parse the front-matter metadata of README text; ``{}`` if there is none.

    ``metadata_load`` works on files, so the text is round-tripped through a
    temporary ``README.md``.
    """
    with tempfile.TemporaryDirectory() as scratch_dir:
        readme = Path(scratch_dir, "README.md")
        readme.write_text(content, encoding="utf-8")
        loaded = metadata_load(readme)
    return loaded if loaded else {}
|
|
|
|
def _content_with_metadata(content: str, metadata: dict[str, Any]) -> str:
    """Return ``content`` after ``metadata_save`` has written ``metadata`` into it.

    ``metadata_save`` works on files, so the text is round-tripped through a
    temporary ``README.md``.
    """
    with tempfile.TemporaryDirectory() as scratch_dir:
        readme = Path(scratch_dir, "README.md")
        readme.write_text(content, encoding="utf-8")
        metadata_save(readme, metadata)
        rewritten = readme.read_text(encoding="utf-8")
    return rewritten
|
|
|
|
def _body_without_metadata(content: str) -> str:
    """Return ``content`` minus its leading front-matter block, whitespace-trimmed."""
    body = _FRONT_MATTER_RE.sub("", content, count=1)
    return body.strip()
|
|
|
|
def _append_section(content: str, section: str) -> str:
    """Append ``section`` to ``content``, separated by exactly one blank line.

    The section is stripped and newline-terminated. If ``content`` is empty
    or whitespace-only, the section alone is returned.
    """
    addition = section.strip() + "\n"
    existing = content.rstrip()
    return f"{existing}\n\n{addition}" if existing else addition
|
|
|
|
def _provenance_section(repo_type: str) -> str:
    """Return the Markdown provenance section, led by ``PROVENANCE_MARKER``.

    The repository is described as a "model" or "dataset" repository; any
    other repo type is described as a "Hub" repository.
    """
    label = repo_type if repo_type in {"model", "dataset"} else "Hub"
    lines = [
        f"{PROVENANCE_MARKER}",
        "## Generated by ML Intern",
        "",
        f"This {label} repository was generated by [ML Intern](https://github.com/huggingface/ml-intern), an agent for machine learning research and development on the Hugging Face Hub.",
        "",
        "- Try ML Intern: https://smolagents-ml-intern.hf.space",
        "- Source code: https://github.com/huggingface/ml-intern",
    ]
    return "\n".join(lines) + "\n"
|
|
|
|
def _usage_section(repo_id: str, repo_type: str) -> str:
    """Return a Markdown ``## Usage`` section with a loading snippet.

    Datasets get a ``datasets.load_dataset`` example. Every other repo type
    gets a ``transformers`` causal-LM example followed by a note about
    non-causal architectures.
    """
    if repo_type == "dataset":
        snippet = [
            "from datasets import load_dataset",
            "",
            f'dataset = load_dataset("{repo_id}")',
        ]
        trailer: list[str] = []
    else:
        snippet = [
            "from transformers import AutoModelForCausalLM, AutoTokenizer",
            "",
            f'model_id = "{repo_id}"',
            "tokenizer = AutoTokenizer.from_pretrained(model_id)",
            "model = AutoModelForCausalLM.from_pretrained(model_id)",
        ]
        trailer = [
            "",
            "For non-causal architectures, replace `AutoModelForCausalLM` with the appropriate `AutoModel` class.",
        ]

    lines = ["## Usage", "", "```python", *snippet, "```", *trailer]
    return "\n".join(lines) + "\n"
|
|
|
|
def augment_repo_card_content(
    content: str | None,
    repo_id: str,
    repo_type: str = "model",
    *,
    extra_metadata: dict[str, Any] | None = None,
) -> str:
    """Return README content with ML Intern metadata and provenance added.

    Steps, in order:

    1. Merge ``extra_metadata`` under the card's own front matter (existing
       keys win) and make sure the ML Intern tag is present.
    2. If the card has no body, add a ``# <repo_id>`` title.
    3. For model and dataset cards that lack the provenance marker, append
       the provenance section, plus a usage section when the original text
       has no usage-style heading.
    """
    kind = repo_type or "model"
    original = content or ""

    front_matter = _metadata_from_content(original)
    if extra_metadata:
        # Values already on the card take precedence over the extras.
        front_matter = {**extra_metadata, **front_matter}
    card = _content_with_metadata(original, _merge_tags(front_matter))

    if not _body_without_metadata(card):
        card = _append_section(card, f"# {repo_id}")

    wants_provenance = kind in {"model", "dataset"} and PROVENANCE_MARKER not in card
    if not wants_provenance:
        return card

    card = _append_section(card, _provenance_section(kind))
    # The heading check runs on the caller's text, not on the augmented card.
    if _USAGE_HEADING_RE.search(original) is None:
        card = _append_section(card, _usage_section(repo_id, kind))
    return card
|
|
|
|
def _read_remote_readme(
    api: Any,
    repo_id: str,
    repo_type: str,
    *,
    token: str | bool | None = None,
) -> str:
    """Download and return the repo's ``README.md`` text.

    Returns ``""`` when the repo or the file does not exist. When ``token``
    is ``None`` the token stored on ``api`` (if any) is used.
    """
    effective_token = getattr(api, "token", None) if token is None else token
    try:
        local_copy = hf_hub_download(
            repo_id=repo_id,
            filename="README.md",
            repo_type=repo_type,
            token=effective_token,
        )
    except (EntryNotFoundError, RepositoryNotFoundError):
        return ""
    else:
        return Path(local_copy).read_text(encoding="utf-8")
|
|
|
|
def _update_repo_card(
    api: Any,
    repo_id: str,
    repo_type: str,
    *,
    token: str | bool | None = None,
    extra_metadata: dict[str, Any] | None = None,
) -> None:
    """Upload an augmented ``README.md`` when augmentation changes the card.

    Nothing is uploaded if the augmented text equals the remote text.
    """
    existing = _read_remote_readme(api, repo_id, repo_type, token=token)
    desired = augment_repo_card_content(
        existing,
        repo_id,
        repo_type,
        extra_metadata=extra_metadata,
    )
    if desired != existing:
        api.upload_file(
            path_or_fileobj=desired.encode("utf-8"),
            path_in_repo="README.md",
            repo_id=repo_id,
            repo_type=repo_type,
            token=token,
            commit_message="Update ML Intern artifact metadata",
        )
|
|
|
|
def _ensure_collection_slug(
    api: Any,
    session: Any,
    *,
    token: str | bool | None = None,
) -> str | None:
    """Return the slug of the session's artifact collection, creating it if needed.

    A slug already cached on ``session`` is returned without calling the Hub.
    Otherwise a private collection is created (or fetched, via
    ``exists_ok=True``) and its slug is cached on the session on a
    best-effort basis.

    Returns:
        The collection slug, or ``None`` if the created collection exposes no
        ``slug``.

    Raises:
        Whatever ``api.create_collection`` raises; callers treat that as a
        failed registration.
    """
    slug = getattr(session, _COLLECTION_SLUG_ATTR, None)
    if slug:
        return slug

    title = artifact_collection_title(session)
    collection = api.create_collection(
        title=title,
        description=(
            f"Artifacts generated by ML Intern session {_safe_session_id(session)} "
            f"on {session_artifact_date(session)}."
        ),
        private=True,
        exists_ok=True,
        token=token,
    )
    slug = getattr(collection, "slug", None)
    if slug:
        try:
            setattr(session, _COLLECTION_SLUG_ATTR, slug)
        except Exception:
            # Caching is only an optimisation. A session that rejects new
            # attributes must not turn a successful collection lookup into
            # a failed registration.
            logger.debug(
                "Could not cache collection slug on session; "
                "it will be looked up again on the next call"
            )
    return slug
|
|
|
|
def _add_to_collection(
    api: Any,
    session: Any,
    repo_id: str,
    repo_type: str,
    *,
    token: str | bool | None = None,
) -> bool:
    """Add the repo to the session's collection.

    Returns False when no collection slug is available, True after the item
    has been added (``exists_ok=True``, so repeat calls are accepted).
    """
    collection_slug = _ensure_collection_slug(api, session, token=token)
    if collection_slug:
        item_note = (
            f"Generated by ML Intern session {_safe_session_id(session)} "
            f"on {session_artifact_date(session)}."
        )
        api.add_collection_item(
            collection_slug=collection_slug,
            item_id=repo_id,
            item_type=repo_type,
            note=item_note,
            exists_ok=True,
            token=token,
        )
        return True
    return False
|
|
|
|
| def register_hub_artifact( |
| api: Any, |
| repo_id: str, |
| repo_type: str = "model", |
| *, |
| session: Any = None, |
| token: str | bool | None = None, |
| extra_metadata: dict[str, Any] | None = None, |
| force: bool = False, |
| ) -> bool: |
| """Tag, card, and collection-register a Hub artifact without raising.""" |
| if session is None or not repo_id: |
| return False |
| repo_type = repo_type or "model" |
| if repo_type not in SUPPORTED_REPO_TYPES: |
| return False |
| if is_sandbox_hub_repo(repo_id, repo_type): |
| return False |
|
|
| key = _artifact_key(repo_id, repo_type) |
| remember_hub_artifact(session, repo_id, repo_type) |
| registered = _session_artifact_set(session, _REGISTERED_ARTIFACTS_ATTR) |
| if key in registered and not force: |
| return True |
|
|
| token_value = token if token is not None else getattr(api, "token", None) |
| card_updated = False |
| collection_updated = False |
| try: |
| _update_repo_card( |
| api, |
| repo_id, |
| repo_type, |
| token=token_value, |
| extra_metadata=extra_metadata, |
| ) |
| card_updated = True |
| except Exception as e: |
| logger.debug("ML Intern repo-card update failed for %s: %s", repo_id, e) |
|
|
| try: |
| collection_updated = _add_to_collection( |
| api, |
| session, |
| repo_id, |
| repo_type, |
| token=token_value, |
| ) |
| except Exception as e: |
| logger.debug("ML Intern collection update failed for %s: %s", repo_id, e) |
|
|
| if card_updated and collection_updated: |
| registered.add(key) |
| return True |
| return False |
|
|
|
|
# Builds the source text of a ``sitecustomize.py`` module.
#
# Returns "" when ``session`` is None or has no truthy ``session_id``.
# Otherwise returns a newline-terminated string of Python source. That source
# defines ``_install_ml_intern_artifact_hooks()`` and calls it inside a
# try/except that swallows every exception. The installer:
#   * returns early if ``huggingface_hub`` cannot be imported;
#   * replaces ``HfApi.create_repo`` and ``HfApi.upload_file`` and, when they
#     exist, ``HfApi.upload_folder`` and ``HfApi.create_commit`` with wrappers
#     that call the original method and then ``_register(...)``;
#   * replaces the same-named module-level ``huggingface_hub`` functions with
#     wrappers that route through a fresh ``HfApi`` instance;
#   * in ``_register``, rewrites the repo's README.md with tag, provenance and
#     usage sections, and adds the repo to the session's collection, ignoring
#     errors from either step;
#   * caches the collection slug in a file whose path comes from the
#     ``ML_INTERN_ARTIFACT_COLLECTION_CACHE`` environment variable or, failing
#     that, from the temp directory.
| def build_hub_artifact_sitecustomize(session: Any) -> str: |
| """Build standalone sitecustomize.py code for HF Jobs Python processes.""" |
| if session is None or not getattr(session, "session_id", None): |
| return "" |
|
|
    # These values are baked into the generated source below as ``repr()``
    # literals (the ``!r`` conversions), so the generated module needs nothing
    # from this file at run time.
| session_id = _safe_session_id(session) |
| session_date = session_artifact_date(session) |
| collection_title = artifact_collection_title(session) |
| collection_slug = getattr(session, _COLLECTION_SLUG_ATTR, None) |
|
|
    # The template is an f-string: ``{{`` / ``}}`` yield literal braces and
    # ``\\`` yields a single backslash in the generated source. Single-brace
    # fields are substituted here, at build time. The result is dedented,
    # stripped, and given exactly one trailing newline.
| return ( |
| textwrap.dedent( |
| f""" |
| # Auto-generated by ML Intern. Best-effort Hub artifact metadata only. |
| def _install_ml_intern_artifact_hooks(): |
| import os |
| import re |
| import tempfile |
| from pathlib import Path |
|
| try: |
| import huggingface_hub as _hub |
| from huggingface_hub import HfApi, hf_hub_download |
| from huggingface_hub.repocard import metadata_load, metadata_save |
| from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError |
| except Exception: |
| return |
|
| session_id = {session_id!r} |
| session_date = {session_date!r} |
| collection_title = {collection_title!r} |
| tag = {ML_INTERN_TAG!r} |
| marker = {PROVENANCE_MARKER!r} |
| supported = {sorted(SUPPORTED_REPO_TYPES)!r} |
| sandbox_space_re = re.compile({_sandbox_space_name_pattern()!r}) |
| registering = False |
| collection_slug = {collection_slug!r} |
| registered = set() |
| usage_re = re.compile( |
| r"^#{{2,6}}\\s+(usage|how to use|using this (model|dataset)|use this (model|dataset))\\b", |
| re.IGNORECASE | re.MULTILINE, |
| ) |
| front_matter_re = re.compile(r"\\A---\\s*\\n.*?\\n---\\s*\\n?", re.DOTALL) |
| collection_cache_path = ( |
| os.environ.get("ML_INTERN_ARTIFACT_COLLECTION_CACHE") |
| or str( |
| Path(tempfile.gettempdir()) |
| / f"ml-intern-artifacts-{{session_id}}.collection" |
| ) |
| ) |
|
| def _token(value=None, api=None): |
| if isinstance(value, str) and value: |
| return value |
| api_token = getattr(api, "token", None) |
| if isinstance(api_token, str) and api_token: |
| return api_token |
| return ( |
| os.environ.get("HF_TOKEN") |
| or os.environ.get("HUGGINGFACE_HUB_TOKEN") |
| or None |
| ) |
|
| def _merge_tags(metadata): |
| metadata = dict(metadata or {{}}) |
| raw_tags = metadata.get("tags") |
| if raw_tags is None: |
| tags = [] |
| elif isinstance(raw_tags, str): |
| tags = [raw_tags] |
| elif isinstance(raw_tags, list): |
| tags = [str(item) for item in raw_tags] |
| else: |
| tags = [str(raw_tags)] |
| if tag not in tags: |
| tags.append(tag) |
| metadata["tags"] = tags |
| return metadata |
|
| def _metadata_from_content(content): |
| with tempfile.TemporaryDirectory() as tmp_dir: |
| path = Path(tmp_dir) / "README.md" |
| path.write_text(content or "", encoding="utf-8") |
| return metadata_load(path) or {{}} |
|
| def _content_with_metadata(content, metadata): |
| with tempfile.TemporaryDirectory() as tmp_dir: |
| path = Path(tmp_dir) / "README.md" |
| path.write_text(content or "", encoding="utf-8") |
| metadata_save(path, metadata) |
| return path.read_text(encoding="utf-8") |
|
| def _body_without_metadata(content): |
| return front_matter_re.sub("", content or "", count=1).strip() |
|
| def _append_section(content, section): |
| base = (content or "").rstrip() |
| if base: |
| return base + "\\n\\n" + section.strip() + "\\n" |
| return section.strip() + "\\n" |
|
| def _provenance(repo_type): |
| label = {{"model": "model", "dataset": "dataset"}}.get( |
| repo_type, "Hub" |
| ) |
| return ( |
| marker |
| + "\\n## Generated by ML Intern\\n\\n" |
| + f"This {{label}} repository was generated by [ML Intern](https://github.com/huggingface/ml-intern), an agent for machine learning research and development on the Hugging Face Hub.\\n\\n" |
| + "- Try ML Intern: https://smolagents-ml-intern.hf.space\\n" |
| + "- Source code: https://github.com/huggingface/ml-intern\\n" |
| ) |
|
| def _usage(repo_id, repo_type): |
| if repo_type == "dataset": |
| return ( |
| "## Usage\\n\\n" |
| "```python\\n" |
| "from datasets import load_dataset\\n\\n" |
| f"dataset = load_dataset({{repo_id!r}})\\n" |
| "```\\n" |
| ) |
| return ( |
| "## Usage\\n\\n" |
| "```python\\n" |
| "from transformers import AutoModelForCausalLM, AutoTokenizer\\n\\n" |
| f"model_id = {{repo_id!r}}\\n" |
| "tokenizer = AutoTokenizer.from_pretrained(model_id)\\n" |
| "model = AutoModelForCausalLM.from_pretrained(model_id)\\n" |
| "```\\n\\n" |
| "For non-causal architectures, replace `AutoModelForCausalLM` with the appropriate `AutoModel` class.\\n" |
| ) |
|
| def _augment(content, repo_id, repo_type, extra_metadata=None): |
| metadata = _metadata_from_content(content or "") |
| if extra_metadata: |
| metadata = {{**extra_metadata, **metadata}} |
| updated = _content_with_metadata(content or "", _merge_tags(metadata)) |
| if not _body_without_metadata(updated): |
| updated = _append_section(updated, f"# {{repo_id}}") |
| if repo_type in {{"model", "dataset"}} and marker not in updated: |
| updated = _append_section(updated, _provenance(repo_type)) |
| if not usage_re.search(content or ""): |
| updated = _append_section(updated, _usage(repo_id, repo_type)) |
| return updated |
|
| def _readme(api, repo_id, repo_type, token_value): |
| try: |
| path = hf_hub_download( |
| repo_id=repo_id, |
| filename="README.md", |
| repo_type=repo_type, |
| token=token_value, |
| ) |
| except (EntryNotFoundError, RepositoryNotFoundError): |
| return "" |
| return Path(path).read_text(encoding="utf-8") |
|
| def _ensure_collection(api, token_value): |
| nonlocal collection_slug |
| if collection_slug: |
| return collection_slug |
| try: |
| cached_slug = Path(collection_cache_path).read_text( |
| encoding="utf-8" |
| ).strip() |
| if cached_slug: |
| collection_slug = cached_slug |
| return collection_slug |
| except Exception: |
| pass |
| collection = api.create_collection( |
| title=collection_title, |
| description=( |
| f"Artifacts generated by ML Intern session {{session_id}} " |
| f"on {{session_date}}." |
| ), |
| private=True, |
| exists_ok=True, |
| token=token_value, |
| ) |
| collection_slug = getattr(collection, "slug", None) |
| if collection_slug: |
| try: |
| cache_path = Path(collection_cache_path) |
| cache_path.parent.mkdir(parents=True, exist_ok=True) |
| cache_path.write_text(collection_slug, encoding="utf-8") |
| except Exception: |
| pass |
| return collection_slug |
|
| def _register( |
| repo_id, |
| repo_type="model", |
| token_value=None, |
| extra_metadata=None, |
| force=False, |
| ): |
| nonlocal registering |
| if registering or not repo_id: |
| return |
| repo_type = repo_type or "model" |
| if repo_type not in supported: |
| return |
| if _is_sandbox_repo(repo_id, repo_type): |
| return |
| key = f"{{repo_type}}:{{repo_id}}" |
| if key in registered and not force: |
| return |
| registering = True |
| try: |
| token_value = _token(token_value) |
| api = HfApi(token=token_value) |
| card_updated = False |
| try: |
| current = _readme(api, repo_id, repo_type, token_value) |
| updated = _augment( |
| current, repo_id, repo_type, extra_metadata=extra_metadata |
| ) |
| if updated != current: |
| _original_upload_file( |
| api, |
| path_or_fileobj=updated.encode("utf-8"), |
| path_in_repo="README.md", |
| repo_id=repo_id, |
| repo_type=repo_type, |
| token=token_value, |
| commit_message="Update ML Intern artifact metadata", |
| ) |
| card_updated = True |
| except Exception: |
| pass |
| collection_updated = False |
| try: |
| slug = _ensure_collection(api, token_value) |
| if slug: |
| api.add_collection_item( |
| collection_slug=slug, |
| item_id=repo_id, |
| item_type=repo_type, |
| note=( |
| f"Generated by ML Intern session {{session_id}} " |
| f"on {{session_date}}." |
| ), |
| exists_ok=True, |
| token=token_value, |
| ) |
| collection_updated = True |
| except Exception: |
| pass |
| if card_updated and collection_updated: |
| registered.add(key) |
| finally: |
| registering = False |
|
| _original_create_repo = HfApi.create_repo |
| _original_upload_file = HfApi.upload_file |
| _original_upload_folder = getattr(HfApi, "upload_folder", None) |
| _original_create_commit = getattr(HfApi, "create_commit", None) |
|
| def _repo_id(args, kwargs): |
| return kwargs.get("repo_id") or (args[0] if args else None) |
|
| def _repo_type(kwargs): |
| return kwargs.get("repo_type") or "model" |
|
| def _is_sandbox_repo(repo_id, repo_type): |
| if (repo_type or "model") != "space" or not repo_id: |
| return False |
| repo_name = str(repo_id).rsplit("/", 1)[-1] |
| return bool(sandbox_space_re.fullmatch(repo_name)) |
|
| def _patched_create_repo(self, *args, **kwargs): |
| result = _original_create_repo(self, *args, **kwargs) |
| repo_id = _repo_id(args, kwargs) |
| repo_type = _repo_type(kwargs) |
| extra = None |
| if repo_type == "space" and kwargs.get("space_sdk"): |
| extra = {{"sdk": kwargs.get("space_sdk")}} |
| _register(repo_id, repo_type, _token(kwargs.get("token"), self), extra) |
| return result |
|
| def _patched_upload_file(self, *args, **kwargs): |
| result = _original_upload_file(self, *args, **kwargs) |
| if not kwargs.get("create_pr"): |
| force = kwargs.get("path_in_repo") == "README.md" |
| _register( |
| kwargs.get("repo_id"), |
| _repo_type(kwargs), |
| _token(kwargs.get("token"), self), |
| force=force, |
| ) |
| return result |
|
| def _patched_upload_folder(self, *args, **kwargs): |
| result = _original_upload_folder(self, *args, **kwargs) |
| if not kwargs.get("create_pr"): |
| _register( |
| kwargs.get("repo_id"), |
| _repo_type(kwargs), |
| _token(kwargs.get("token"), self), |
| force=True, |
| ) |
| return result |
|
| def _patched_create_commit(self, *args, **kwargs): |
| result = _original_create_commit(self, *args, **kwargs) |
| if not kwargs.get("create_pr"): |
| _register( |
| _repo_id(args, kwargs), |
| _repo_type(kwargs), |
| _token(kwargs.get("token"), self), |
| force=True, |
| ) |
| return result |
|
| HfApi.create_repo = _patched_create_repo |
| HfApi.upload_file = _patched_upload_file |
| if _original_upload_folder is not None: |
| HfApi.upload_folder = _patched_upload_folder |
| if _original_create_commit is not None: |
| HfApi.create_commit = _patched_create_commit |
|
| def _patch_module_func(name, method_name): |
| original = getattr(_hub, name, None) |
| if original is None: |
| return |
| method = getattr(HfApi, method_name) |
|
| def _patched(*args, **kwargs): |
| api = HfApi(token=_token(kwargs.get("token"))) |
| return method(api, *args, **kwargs) |
|
| setattr(_hub, name, _patched) |
|
| _patch_module_func("create_repo", "create_repo") |
| _patch_module_func("upload_file", "upload_file") |
| if _original_upload_folder is not None: |
| _patch_module_func("upload_folder", "upload_folder") |
| if _original_create_commit is not None: |
| _patch_module_func("create_commit", "create_commit") |
|
| try: |
| _install_ml_intern_artifact_hooks() |
| except Exception: |
| pass |
| """ |
| ).strip() |
| + "\n" |
| ) |
|
|
|
|
def wrap_shell_command_with_hub_artifact_bootstrap(
    command: str,
    session: Any,
) -> str:
    """Prefix a shell command so child Python processes load Hub hooks.

    The prefix creates a temp directory, decodes the base64-embedded
    ``sitecustomize.py`` into it, and prepends that directory to
    ``PYTHONPATH``. The setup steps are chained with ``&&`` and then joined
    to ``command`` with ``;``, so ``command`` runs even if a setup step fails.
    ``command`` is returned unchanged when it is empty or no hook source can
    be built for ``session``.
    """
    hook_source = build_hub_artifact_sitecustomize(session)
    if not (hook_source and command):
        return command

    payload = base64.b64encode(hook_source.encode("utf-8")).decode("ascii")
    setup_steps = [
        '_ml_intern_artifacts_dir="$(mktemp -d 2>/dev/null)"',
        f"printf %s {shlex.quote(payload)} | base64 -d "
        '> "$_ml_intern_artifacts_dir/sitecustomize.py"',
        'export PYTHONPATH="$_ml_intern_artifacts_dir${PYTHONPATH:+:$PYTHONPATH}"',
    ]
    return " && ".join(setup_steps) + f"; {command}"
|
|