Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
| """Best-effort Hub metadata for artifacts generated by ML Intern sessions.""" | |
| import base64 | |
| import logging | |
| import re | |
| import shlex | |
| import tempfile | |
| import textwrap | |
| from datetime import datetime | |
| from pathlib import Path | |
| from typing import Any | |
| from huggingface_hub import hf_hub_download | |
| from huggingface_hub.repocard import metadata_load, metadata_save | |
| from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError | |
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Tag appended to the ``tags`` list of every repo card this module augments.
ML_INTERN_TAG = "ml-intern"
# Repo types that register_hub_artifact() accepts; any other type is ignored.
SUPPORTED_REPO_TYPES = {"model", "dataset", "space"}
# HTML comment emitted at the top of the provenance section. Its presence in a
# card is how the code detects that the section was already appended.
PROVENANCE_MARKER = "<!-- ml-intern-provenance -->"

# Collection titles are built as "<prefix>-<YYYY-MM-DD>-<session id fragment>".
_COLLECTION_TITLE_PREFIX = "ml-intern-artifacts"
# Length budget used when truncating non-UUID session ids for the title.
# NOTE(review): presumably mirrors a Hub-side title length limit; confirm.
_COLLECTION_TITLE_MAX_LENGTH = 59
# Matches a canonical 8-4-4-4-12 hexadecimal UUID string (case-insensitive).
_UUID_SESSION_ID_RE = re.compile(
    r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-"
    r"[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$"
)

# Attribute names under which per-session state is stored on the session object.
_KNOWN_ARTIFACTS_ATTR = "_ml_intern_known_hub_artifacts"
_REGISTERED_ARTIFACTS_ATTR = "_ml_intern_registered_hub_artifacts"
_COLLECTION_SLUG_ATTR = "_ml_intern_artifact_collection_slug"
# Process-local fallback keyed by (id(session), attribute name); used only when
# the state set cannot be attached to the session object itself.
_SESSION_ARTIFACT_SET_FALLBACK: dict[tuple[int, str], set[str]] = {}

# Matches a level 2-6 markdown heading that introduces a usage section.
_USAGE_HEADING_RE = re.compile(
    r"^#{2,6}\s+(usage|how to use|using this (model|dataset)|use this (model|dataset))\b",
    re.IGNORECASE | re.MULTILINE,
)
# Matches a leading front-matter block delimited by "---" lines.
_FRONT_MATTER_RE = re.compile(r"\A---\s*\n.*?\n---\s*\n?", re.DOTALL)
def _safe_session_id(session: Any) -> str:
    """Return the session id restricted to ``[A-Za-z0-9._-]`` characters.

    A missing, empty, or fully-sanitized-away id yields ``"unknown-session"``.
    """
    placeholder = "unknown-session"
    session_id = str(getattr(session, "session_id", "") or placeholder)
    # Each run of disallowed characters collapses to one dash; edge dashes go.
    cleaned = re.sub(r"[^A-Za-z0-9._-]+", "-", session_id).strip("-")
    if not cleaned:
        return placeholder
    return cleaned
def session_artifact_date(session: Any) -> str:
    """Return the YYYY-MM-DD partition date for a session.

    The date is taken from ``session.session_start_time`` when that value
    parses as an ISO-8601 timestamp (a trailing ``Z`` is accepted). When the
    attribute is missing, empty, or unparseable, today's date in UTC is used.

    Args:
        session: Any object; only its optional ``session_start_time``
            attribute is read.

    Returns:
        The date formatted as ``YYYY-MM-DD``.
    """
    # Local import: the module header only brings in ``datetime`` itself.
    from datetime import timezone

    raw = getattr(session, "session_start_time", None)
    if raw:
        try:
            # Rewrite a "Z" suffix as an explicit offset before parsing.
            parsed = datetime.fromisoformat(str(raw).replace("Z", "+00:00"))
        except ValueError:
            logger.debug("Could not parse session_start_time=%r", raw)
        else:
            return parsed.strftime("%Y-%m-%d")
    # datetime.utcnow() is deprecated since Python 3.12; use an aware UTC "now".
    return datetime.now(timezone.utc).strftime("%Y-%m-%d")
def _collection_session_id_fragment(session: Any) -> str:
    """Return the session-id part of the collection title.

    UUID-shaped ids are shortened to their first eight characters. Other ids
    are kept whole when they fit the title length budget, otherwise truncated
    to it with trailing ``-``, ``.`` and ``_`` trimmed.
    """
    session_id = _safe_session_id(session)
    if _UUID_SESSION_ID_RE.match(session_id) is not None:
        return session_id[:8]

    title_stem = f"{_COLLECTION_TITLE_PREFIX}-{session_artifact_date(session)}-"
    # Always leave room for at least one character of the id.
    budget = max(1, _COLLECTION_TITLE_MAX_LENGTH - len(title_stem))
    if len(session_id) <= budget:
        return session_id

    truncated = session_id[:budget]
    trimmed = truncated.rstrip("-._")
    # If trimming removed everything, fall back to the raw truncation.
    return trimmed if trimmed else truncated
def artifact_collection_title(session: Any) -> str:
    """Return the collection title: ``<prefix>-<date>-<session id fragment>``."""
    date_part = session_artifact_date(session)
    id_part = _collection_session_id_fragment(session)
    return f"{_COLLECTION_TITLE_PREFIX}-{date_part}-{id_part}"
| def _artifact_key(repo_id: str, repo_type: str | None) -> str: | |
| return f"{repo_type or 'model'}:{repo_id}" | |
def _sandbox_space_name_pattern() -> str:
    """Return the regex source string of the sandbox Space name pattern."""
    # Imported at call time rather than at module import.
    # NOTE(review): presumably to avoid an import cycle with the tool module; confirm.
    from agent.tools import sandbox_tool

    return sandbox_tool.SANDBOX_SPACE_NAME_RE.pattern
| def is_sandbox_hub_repo(repo_id: str | None, repo_type: str | None) -> bool: | |
| """Return True for ML Intern's ephemeral sandbox Space repos.""" | |
| if (repo_type or "model") != "space" or not repo_id: | |
| return False | |
| repo_name = str(repo_id).rsplit("/", 1)[-1] | |
| return bool(re.fullmatch(_sandbox_space_name_pattern(), repo_name)) | |
def _session_artifact_set(session: Any, attr: str) -> set[str]:
    """Return the mutable set stored on ``session`` under ``attr``.

    A fresh set is created and attached when the attribute is missing or is not
    a set. If the session rejects the assignment, a process-local set keyed by
    ``(id(session), attr)`` is returned instead and a warning is logged.
    """
    existing = getattr(session, attr, None)
    if isinstance(existing, set):
        return existing

    fresh: set[str] = set()
    try:
        setattr(session, attr, fresh)
    except Exception:
        logger.warning(
            "Could not attach %s to session; using process-local fallback state",
            attr,
        )
        fallback_key = (id(session), attr)
        return _SESSION_ARTIFACT_SET_FALLBACK.setdefault(fallback_key, set())
    else:
        return fresh
| def remember_hub_artifact(session: Any, repo_id: str, repo_type: str | None) -> None: | |
| if session is None or not repo_id: | |
| return | |
| _session_artifact_set(session, _KNOWN_ARTIFACTS_ATTR).add( | |
| _artifact_key(repo_id, repo_type) | |
| ) | |
| def is_known_hub_artifact(session: Any, repo_id: str, repo_type: str | None) -> bool: | |
| if session is None or not repo_id: | |
| return False | |
| return _artifact_key(repo_id, repo_type) in _session_artifact_set( | |
| session, _KNOWN_ARTIFACTS_ATTR | |
| ) | |
def _merge_tags(metadata: dict[str, Any], tag: str = ML_INTERN_TAG) -> dict[str, Any]:
    """Return a shallow copy of ``metadata`` whose ``tags`` list contains ``tag``.

    An existing ``tags`` value is normalized to a list: None becomes an empty
    list, a string becomes a one-item list, list items are converted with
    ``str()``, and any other value becomes a one-item list of its ``str()``.
    The input mapping is not modified.
    """
    result = dict(metadata)
    existing = result.get("tags")

    if existing is None:
        tag_list: list[str] = []
    elif isinstance(existing, str):
        tag_list = [existing]
    elif isinstance(existing, list):
        tag_list = [str(entry) for entry in existing]
    else:
        tag_list = [str(existing)]

    if tag not in tag_list:
        tag_list.append(tag)
    result["tags"] = tag_list
    return result
def _metadata_from_content(content: str) -> dict[str, Any]:
    """Parse the metadata block of README ``content``; return ``{}`` if none.

    ``metadata_load`` takes a file path, so the text is written to a README.md
    inside a temporary directory first.
    """
    with tempfile.TemporaryDirectory() as scratch_dir:
        card_path = Path(scratch_dir, "README.md")
        card_path.write_text(content, encoding="utf-8")
        loaded = metadata_load(card_path)
    return loaded if loaded else {}
def _content_with_metadata(content: str, metadata: dict[str, Any]) -> str:
    """Return ``content`` after ``metadata_save`` has written ``metadata`` into it.

    ``metadata_save`` works on a file path, so the round trip goes through a
    README.md inside a temporary directory.
    """
    with tempfile.TemporaryDirectory() as scratch_dir:
        card_path = Path(scratch_dir, "README.md")
        card_path.write_text(content, encoding="utf-8")
        metadata_save(card_path, metadata)
        # Read back before the temporary directory is removed.
        rendered = card_path.read_text(encoding="utf-8")
    return rendered
def _body_without_metadata(content: str) -> str:
    """Return ``content`` minus its leading front-matter block, stripped."""
    without_front_matter = _FRONT_MATTER_RE.sub("", content, count=1)
    return without_front_matter.strip()
def _append_section(content: str, section: str) -> str:
    """Append ``section`` to ``content``, separated by one blank line.

    Trailing whitespace of ``content`` and surrounding whitespace of
    ``section`` are dropped; the result always ends with a single newline.
    """
    addition = section.strip() + "\n"
    existing = content.rstrip()
    if not existing:
        return addition
    return existing + "\n\n" + addition
def _provenance_section(repo_type: str) -> str:
    """Return the markdown provenance section for a repo card.

    The section starts with ``PROVENANCE_MARKER`` so later runs can detect it.
    Blank lines follow the heading and the paragraph, so the text matches the
    provenance block emitted by the standalone sitecustomize hooks built in
    ``build_hub_artifact_sitecustomize``.

    Args:
        repo_type: ``"model"`` or ``"dataset"`` select that word for the
            sentence; any other value is rendered as ``"Hub"``.

    Returns:
        The section text, ending with a newline.
    """
    label = {"model": "model", "dataset": "dataset"}.get(repo_type, "Hub")
    lines = [
        PROVENANCE_MARKER,
        "## Generated by ML Intern",
        "",
        f"This {label} repository was generated by [ML Intern](https://github.com/huggingface/ml-intern), an agent for machine learning research and development on the Hugging Face Hub.",
        "",
        "- Try ML Intern: https://smolagents-ml-intern.hf.space",
        "- Source code: https://github.com/huggingface/ml-intern",
    ]
    return "\n".join(lines) + "\n"
def _usage_section(repo_id: str, repo_type: str) -> str:
    """Return a markdown "Usage" section with a minimal loading snippet.

    Blank lines follow the heading, the import line, and the closing code
    fence, matching the layout of the usage block emitted by the standalone
    sitecustomize hooks built in ``build_hub_artifact_sitecustomize``.

    Args:
        repo_id: Hub repo id inserted into the snippet inside double quotes.
        repo_type: ``"dataset"`` selects the ``datasets`` snippet; every other
            value gets the ``transformers`` causal-LM snippet.

    Returns:
        The section text, ending with a newline.
    """
    if repo_type == "dataset":
        lines = [
            "## Usage",
            "",
            "```python",
            "from datasets import load_dataset",
            "",
            f'dataset = load_dataset("{repo_id}")',
            "```",
        ]
    else:
        lines = [
            "## Usage",
            "",
            "```python",
            "from transformers import AutoModelForCausalLM, AutoTokenizer",
            "",
            f'model_id = "{repo_id}"',
            "tokenizer = AutoTokenizer.from_pretrained(model_id)",
            "model = AutoModelForCausalLM.from_pretrained(model_id)",
            "```",
            "",
            "For non-causal architectures, replace `AutoModelForCausalLM` with the appropriate `AutoModel` class.",
        ]
    return "\n".join(lines) + "\n"
def augment_repo_card_content(
    content: str | None,
    repo_id: str,
    repo_type: str = "model",
    *,
    extra_metadata: dict[str, Any] | None = None,
) -> str:
    """Return README content with ML Intern metadata and provenance added.

    Args:
        content: Current README text; None is treated as an empty string.
        repo_id: Hub repo id, used for the fallback title and usage snippet.
        repo_type: Repo type; a falsy value is treated as ``"model"``.
        extra_metadata: Default metadata entries. Keys already present in the
            card's own metadata are not overridden.

    Returns:
        The README text with the ML Intern tag merged into its metadata and,
        for model and dataset repos, the provenance section appended when the
        provenance marker is not already present.
    """
    repo_type = repo_type or "model"
    content = content or ""
    metadata = _metadata_from_content(content)
    if extra_metadata:
        # Later keys win in the merge, so existing card metadata takes priority.
        metadata = {**extra_metadata, **metadata}
    metadata = _merge_tags(metadata)
    updated = _content_with_metadata(content, metadata)
    if not _body_without_metadata(updated):
        # A card with no body gets the repo id as its title.
        updated = _append_section(updated, f"# {repo_id}")
    if repo_type in {"model", "dataset"} and PROVENANCE_MARKER not in updated:
        updated = _append_section(updated, _provenance_section(repo_type))
        # NOTE(review): this check is assumed to sit inside the provenance
        # branch (so Space cards never receive a usage snippet); confirm
        # against the upstream file.
        # The usage heading is searched in the original content, not in ``updated``.
        if not _USAGE_HEADING_RE.search(content):
            updated = _append_section(updated, _usage_section(repo_id, repo_type))
    return updated
def _read_remote_readme(
    api: Any,
    repo_id: str,
    repo_type: str,
    *,
    token: str | bool | None = None,
) -> str:
    """Download README.md from the Hub repo and return its text.

    Returns an empty string when the repo or the README file is not found.
    ``api`` is consulted only for its ``token`` attribute, and only when
    ``token`` is None.
    """
    effective_token = getattr(api, "token", None) if token is None else token
    try:
        local_path = hf_hub_download(
            repo_id=repo_id,
            filename="README.md",
            repo_type=repo_type,
            token=effective_token,
        )
    except (EntryNotFoundError, RepositoryNotFoundError):
        return ""
    return Path(local_path).read_text(encoding="utf-8")
def _update_repo_card(
    api: Any,
    repo_id: str,
    repo_type: str,
    *,
    token: str | bool | None = None,
    extra_metadata: dict[str, Any] | None = None,
) -> None:
    """Fetch the repo's README, augment it, and upload it if it changed."""
    existing_card = _read_remote_readme(api, repo_id, repo_type, token=token)
    new_card = augment_repo_card_content(
        existing_card,
        repo_id,
        repo_type,
        extra_metadata=extra_metadata,
    )
    # Skip the commit entirely when augmentation produced identical text.
    if new_card != existing_card:
        api.upload_file(
            path_or_fileobj=new_card.encode("utf-8"),
            path_in_repo="README.md",
            repo_id=repo_id,
            repo_type=repo_type,
            token=token,
            commit_message="Update ML Intern artifact metadata",
        )
def _ensure_collection_slug(
    api: Any,
    session: Any,
    *,
    token: str | bool | None = None,
) -> str | None:
    """Return the slug of the session's artifact collection, creating it if needed.

    A slug already stored on the session is returned without calling the API.
    Otherwise a private collection is created (``exists_ok=True``) and the
    resulting slug, when truthy, is stored on the session for later calls.
    """
    cached_slug = getattr(session, _COLLECTION_SLUG_ATTR, None)
    if cached_slug:
        return cached_slug

    title = artifact_collection_title(session)
    description = (
        f"Artifacts generated by ML Intern session {_safe_session_id(session)} "
        f"on {session_artifact_date(session)}."
    )
    created = api.create_collection(
        title=title,
        description=description,
        private=True,
        exists_ok=True,
        token=token,
    )
    new_slug = getattr(created, "slug", None)
    if not new_slug:
        return new_slug
    setattr(session, _COLLECTION_SLUG_ATTR, new_slug)
    return new_slug
def _add_to_collection(
    api: Any,
    session: Any,
    repo_id: str,
    repo_type: str,
    *,
    token: str | bool | None = None,
) -> bool:
    """Add the repo to the session's artifact collection.

    Returns False when no collection slug could be obtained, True after the
    item has been added (``exists_ok=True``).
    """
    slug = _ensure_collection_slug(api, session, token=token)
    if slug:
        note = (
            f"Generated by ML Intern session {_safe_session_id(session)} "
            f"on {session_artifact_date(session)}."
        )
        api.add_collection_item(
            collection_slug=slug,
            item_id=repo_id,
            item_type=repo_type,
            note=note,
            exists_ok=True,
            token=token,
        )
        return True
    return False
| def register_hub_artifact( | |
| api: Any, | |
| repo_id: str, | |
| repo_type: str = "model", | |
| *, | |
| session: Any = None, | |
| token: str | bool | None = None, | |
| extra_metadata: dict[str, Any] | None = None, | |
| force: bool = False, | |
| ) -> bool: | |
| """Tag, card, and collection-register a Hub artifact without raising.""" | |
| if session is None or not repo_id: | |
| return False | |
| repo_type = repo_type or "model" | |
| if repo_type not in SUPPORTED_REPO_TYPES: | |
| return False | |
| if is_sandbox_hub_repo(repo_id, repo_type): | |
| return False | |
| key = _artifact_key(repo_id, repo_type) | |
| remember_hub_artifact(session, repo_id, repo_type) | |
| registered = _session_artifact_set(session, _REGISTERED_ARTIFACTS_ATTR) | |
| if key in registered and not force: | |
| return True | |
| token_value = token if token is not None else getattr(api, "token", None) | |
| card_updated = False | |
| collection_updated = False | |
| try: | |
| _update_repo_card( | |
| api, | |
| repo_id, | |
| repo_type, | |
| token=token_value, | |
| extra_metadata=extra_metadata, | |
| ) | |
| card_updated = True | |
| except Exception as e: | |
| logger.debug("ML Intern repo-card update failed for %s: %s", repo_id, e) | |
| try: | |
| collection_updated = _add_to_collection( | |
| api, | |
| session, | |
| repo_id, | |
| repo_type, | |
| token=token_value, | |
| ) | |
| except Exception as e: | |
| logger.debug("ML Intern collection update failed for %s: %s", repo_id, e) | |
| if card_updated and collection_updated: | |
| registered.add(key) | |
| return True | |
| return False | |
def build_hub_artifact_sitecustomize(session: Any) -> str:
    """Build standalone sitecustomize.py code for HF Jobs Python processes.

    The generated script defines and immediately calls
    ``_install_ml_intern_artifact_hooks``, which replaces
    ``HfApi.create_repo`` and ``HfApi.upload_file`` (plus ``upload_folder``
    and ``create_commit`` when ``HfApi`` has them) and the same-named
    ``huggingface_hub`` module-level functions with wrappers. Each wrapper
    calls the original and then makes a best-effort attempt to update the repo
    card and add the repo to the session's collection. The script re-implements
    the card helpers of this module inline, and every failure inside it is
    swallowed.

    Args:
        session: Session object; its sanitized id, artifact date, collection
            title and (if already known) collection slug are embedded in the
            generated code.

    Returns:
        The script text ending with a single newline, or an empty string when
        ``session`` is None or has no truthy ``session_id``.
    """
    if session is None or not getattr(session, "session_id", None):
        return ""
    session_id = _safe_session_id(session)
    session_date = session_artifact_date(session)
    collection_title = artifact_collection_title(session)
    collection_slug = getattr(session, _COLLECTION_SLUG_ATTR, None)
    # Template notes (kept outside the string so the generated text is unchanged):
    # - ``{name!r}`` fields embed the values computed above as Python literals.
    # - Doubled braces and doubled backslashes are f-string escapes; they come
    #   out as single braces/backslashes in the generated source.
    # - textwrap.dedent removes the template's common leading indentation.
    return (
        textwrap.dedent(
            f"""
            # Auto-generated by ML Intern. Best-effort Hub artifact metadata only.
            def _install_ml_intern_artifact_hooks():
                import os
                import re
                import tempfile
                from pathlib import Path
                try:
                    import huggingface_hub as _hub
                    from huggingface_hub import HfApi, hf_hub_download
                    from huggingface_hub.repocard import metadata_load, metadata_save
                    from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
                except Exception:
                    return
                session_id = {session_id!r}
                session_date = {session_date!r}
                collection_title = {collection_title!r}
                tag = {ML_INTERN_TAG!r}
                marker = {PROVENANCE_MARKER!r}
                supported = {sorted(SUPPORTED_REPO_TYPES)!r}
                sandbox_space_re = re.compile({_sandbox_space_name_pattern()!r})
                registering = False
                collection_slug = {collection_slug!r}
                registered = set()
                usage_re = re.compile(
                    r"^#{{2,6}}\\s+(usage|how to use|using this (model|dataset)|use this (model|dataset))\\b",
                    re.IGNORECASE | re.MULTILINE,
                )
                front_matter_re = re.compile(r"\\A---\\s*\\n.*?\\n---\\s*\\n?", re.DOTALL)
                collection_cache_path = (
                    os.environ.get("ML_INTERN_ARTIFACT_COLLECTION_CACHE")
                    or str(
                        Path(tempfile.gettempdir())
                        / f"ml-intern-artifacts-{{session_id}}.collection"
                    )
                )
                def _token(value=None, api=None):
                    if isinstance(value, str) and value:
                        return value
                    api_token = getattr(api, "token", None)
                    if isinstance(api_token, str) and api_token:
                        return api_token
                    return (
                        os.environ.get("HF_TOKEN")
                        or os.environ.get("HUGGINGFACE_HUB_TOKEN")
                        or None
                    )
                def _merge_tags(metadata):
                    metadata = dict(metadata or {{}})
                    raw_tags = metadata.get("tags")
                    if raw_tags is None:
                        tags = []
                    elif isinstance(raw_tags, str):
                        tags = [raw_tags]
                    elif isinstance(raw_tags, list):
                        tags = [str(item) for item in raw_tags]
                    else:
                        tags = [str(raw_tags)]
                    if tag not in tags:
                        tags.append(tag)
                    metadata["tags"] = tags
                    return metadata
                def _metadata_from_content(content):
                    with tempfile.TemporaryDirectory() as tmp_dir:
                        path = Path(tmp_dir) / "README.md"
                        path.write_text(content or "", encoding="utf-8")
                        return metadata_load(path) or {{}}
                def _content_with_metadata(content, metadata):
                    with tempfile.TemporaryDirectory() as tmp_dir:
                        path = Path(tmp_dir) / "README.md"
                        path.write_text(content or "", encoding="utf-8")
                        metadata_save(path, metadata)
                        return path.read_text(encoding="utf-8")
                def _body_without_metadata(content):
                    return front_matter_re.sub("", content or "", count=1).strip()
                def _append_section(content, section):
                    base = (content or "").rstrip()
                    if base:
                        return base + "\\n\\n" + section.strip() + "\\n"
                    return section.strip() + "\\n"
                def _provenance(repo_type):
                    label = {{"model": "model", "dataset": "dataset"}}.get(
                        repo_type, "Hub"
                    )
                    return (
                        marker
                        + "\\n## Generated by ML Intern\\n\\n"
                        + f"This {{label}} repository was generated by [ML Intern](https://github.com/huggingface/ml-intern), an agent for machine learning research and development on the Hugging Face Hub.\\n\\n"
                        + "- Try ML Intern: https://smolagents-ml-intern.hf.space\\n"
                        + "- Source code: https://github.com/huggingface/ml-intern\\n"
                    )
                def _usage(repo_id, repo_type):
                    if repo_type == "dataset":
                        return (
                            "## Usage\\n\\n"
                            "```python\\n"
                            "from datasets import load_dataset\\n\\n"
                            f"dataset = load_dataset({{repo_id!r}})\\n"
                            "```\\n"
                        )
                    return (
                        "## Usage\\n\\n"
                        "```python\\n"
                        "from transformers import AutoModelForCausalLM, AutoTokenizer\\n\\n"
                        f"model_id = {{repo_id!r}}\\n"
                        "tokenizer = AutoTokenizer.from_pretrained(model_id)\\n"
                        "model = AutoModelForCausalLM.from_pretrained(model_id)\\n"
                        "```\\n\\n"
                        "For non-causal architectures, replace `AutoModelForCausalLM` with the appropriate `AutoModel` class.\\n"
                    )
                def _augment(content, repo_id, repo_type, extra_metadata=None):
                    metadata = _metadata_from_content(content or "")
                    if extra_metadata:
                        metadata = {{**extra_metadata, **metadata}}
                    updated = _content_with_metadata(content or "", _merge_tags(metadata))
                    if not _body_without_metadata(updated):
                        updated = _append_section(updated, f"# {{repo_id}}")
                    if repo_type in {{"model", "dataset"}} and marker not in updated:
                        updated = _append_section(updated, _provenance(repo_type))
                        if not usage_re.search(content or ""):
                            updated = _append_section(updated, _usage(repo_id, repo_type))
                    return updated
                def _readme(api, repo_id, repo_type, token_value):
                    try:
                        path = hf_hub_download(
                            repo_id=repo_id,
                            filename="README.md",
                            repo_type=repo_type,
                            token=token_value,
                        )
                    except (EntryNotFoundError, RepositoryNotFoundError):
                        return ""
                    return Path(path).read_text(encoding="utf-8")
                def _ensure_collection(api, token_value):
                    nonlocal collection_slug
                    if collection_slug:
                        return collection_slug
                    try:
                        cached_slug = Path(collection_cache_path).read_text(
                            encoding="utf-8"
                        ).strip()
                        if cached_slug:
                            collection_slug = cached_slug
                            return collection_slug
                    except Exception:
                        pass
                    collection = api.create_collection(
                        title=collection_title,
                        description=(
                            f"Artifacts generated by ML Intern session {{session_id}} "
                            f"on {{session_date}}."
                        ),
                        private=True,
                        exists_ok=True,
                        token=token_value,
                    )
                    collection_slug = getattr(collection, "slug", None)
                    if collection_slug:
                        try:
                            cache_path = Path(collection_cache_path)
                            cache_path.parent.mkdir(parents=True, exist_ok=True)
                            cache_path.write_text(collection_slug, encoding="utf-8")
                        except Exception:
                            pass
                    return collection_slug
                def _register(
                    repo_id,
                    repo_type="model",
                    token_value=None,
                    extra_metadata=None,
                    force=False,
                ):
                    nonlocal registering
                    if registering or not repo_id:
                        return
                    repo_type = repo_type or "model"
                    if repo_type not in supported:
                        return
                    if _is_sandbox_repo(repo_id, repo_type):
                        return
                    key = f"{{repo_type}}:{{repo_id}}"
                    if key in registered and not force:
                        return
                    registering = True
                    try:
                        token_value = _token(token_value)
                        api = HfApi(token=token_value)
                        card_updated = False
                        try:
                            current = _readme(api, repo_id, repo_type, token_value)
                            updated = _augment(
                                current, repo_id, repo_type, extra_metadata=extra_metadata
                            )
                            if updated != current:
                                _original_upload_file(
                                    api,
                                    path_or_fileobj=updated.encode("utf-8"),
                                    path_in_repo="README.md",
                                    repo_id=repo_id,
                                    repo_type=repo_type,
                                    token=token_value,
                                    commit_message="Update ML Intern artifact metadata",
                                )
                            card_updated = True
                        except Exception:
                            pass
                        collection_updated = False
                        try:
                            slug = _ensure_collection(api, token_value)
                            if slug:
                                api.add_collection_item(
                                    collection_slug=slug,
                                    item_id=repo_id,
                                    item_type=repo_type,
                                    note=(
                                        f"Generated by ML Intern session {{session_id}} "
                                        f"on {{session_date}}."
                                    ),
                                    exists_ok=True,
                                    token=token_value,
                                )
                                collection_updated = True
                        except Exception:
                            pass
                        if card_updated and collection_updated:
                            registered.add(key)
                    finally:
                        registering = False
                _original_create_repo = HfApi.create_repo
                _original_upload_file = HfApi.upload_file
                _original_upload_folder = getattr(HfApi, "upload_folder", None)
                _original_create_commit = getattr(HfApi, "create_commit", None)
                def _repo_id(args, kwargs):
                    return kwargs.get("repo_id") or (args[0] if args else None)
                def _repo_type(kwargs):
                    return kwargs.get("repo_type") or "model"
                def _is_sandbox_repo(repo_id, repo_type):
                    if (repo_type or "model") != "space" or not repo_id:
                        return False
                    repo_name = str(repo_id).rsplit("/", 1)[-1]
                    return bool(sandbox_space_re.fullmatch(repo_name))
                def _patched_create_repo(self, *args, **kwargs):
                    result = _original_create_repo(self, *args, **kwargs)
                    repo_id = _repo_id(args, kwargs)
                    repo_type = _repo_type(kwargs)
                    extra = None
                    if repo_type == "space" and kwargs.get("space_sdk"):
                        extra = {{"sdk": kwargs.get("space_sdk")}}
                    _register(repo_id, repo_type, _token(kwargs.get("token"), self), extra)
                    return result
                def _patched_upload_file(self, *args, **kwargs):
                    result = _original_upload_file(self, *args, **kwargs)
                    if not kwargs.get("create_pr"):
                        force = kwargs.get("path_in_repo") == "README.md"
                        _register(
                            kwargs.get("repo_id"),
                            _repo_type(kwargs),
                            _token(kwargs.get("token"), self),
                            force=force,
                        )
                    return result
                def _patched_upload_folder(self, *args, **kwargs):
                    result = _original_upload_folder(self, *args, **kwargs)
                    if not kwargs.get("create_pr"):
                        _register(
                            kwargs.get("repo_id"),
                            _repo_type(kwargs),
                            _token(kwargs.get("token"), self),
                            force=True,
                        )
                    return result
                def _patched_create_commit(self, *args, **kwargs):
                    result = _original_create_commit(self, *args, **kwargs)
                    if not kwargs.get("create_pr"):
                        _register(
                            _repo_id(args, kwargs),
                            _repo_type(kwargs),
                            _token(kwargs.get("token"), self),
                            force=True,
                        )
                    return result
                HfApi.create_repo = _patched_create_repo
                HfApi.upload_file = _patched_upload_file
                if _original_upload_folder is not None:
                    HfApi.upload_folder = _patched_upload_folder
                if _original_create_commit is not None:
                    HfApi.create_commit = _patched_create_commit
                def _patch_module_func(name, method_name):
                    original = getattr(_hub, name, None)
                    if original is None:
                        return
                    method = getattr(HfApi, method_name)
                    def _patched(*args, **kwargs):
                        api = HfApi(token=_token(kwargs.get("token")))
                        return method(api, *args, **kwargs)
                    setattr(_hub, name, _patched)
                _patch_module_func("create_repo", "create_repo")
                _patch_module_func("upload_file", "upload_file")
                if _original_upload_folder is not None:
                    _patch_module_func("upload_folder", "upload_folder")
                if _original_create_commit is not None:
                    _patch_module_func("create_commit", "create_commit")
            try:
                _install_ml_intern_artifact_hooks()
            except Exception:
                pass
            """
        ).strip()
        + "\n"
    )
def wrap_shell_command_with_hub_artifact_bootstrap(
    command: str,
    session: Any,
) -> str:
    """Prefix a shell command so child Python processes load Hub hooks.

    The generated sitecustomize script is base64-encoded into the command line,
    decoded into a fresh ``mktemp -d`` directory, and that directory is
    prepended to ``PYTHONPATH``. The bootstrap and ``command`` are joined with
    ``;``, so ``command`` runs even when a bootstrap step fails. ``command`` is
    returned unchanged when it is empty or no script is generated.
    """
    script = build_hub_artifact_sitecustomize(session)
    if not (script and command):
        return command

    payload = base64.b64encode(script.encode("utf-8")).decode("ascii")
    bootstrap_parts = [
        '_ml_intern_artifacts_dir="$(mktemp -d 2>/dev/null)" ',
        f"&& printf %s {shlex.quote(payload)} | base64 -d ",
        '> "$_ml_intern_artifacts_dir/sitecustomize.py" ',
        '&& export PYTHONPATH="$_ml_intern_artifacts_dir${PYTHONPATH:+:$PYTHONPATH}"',
    ]
    bootstrap = "".join(bootstrap_parts)
    return f"{bootstrap}; {command}"