| import os |
| from pathlib import Path |
| from typing import Dict, List, Optional, Union |
|
|
| from huggingface_hub import HfApi |
| from huggingface_hub.constants import DEFAULT_REVISION, HUGGINGFACE_HUB_CACHE, REPO_TYPES |
| from huggingface_hub.file_download import REGEX_COMMIT_HASH, hf_hub_download, repo_folder_name |
| from huggingface_hub.utils import filter_repo_objects, validate_hf_hub_args |
| from joblib import Parallel, delayed |
|
|
|
|
@validate_hf_hub_args
def snapshot_download(
    repo_id: str,
    *,
    revision: Optional[str] = None,
    repo_type: Optional[str] = None,
    cache_dir: Union[str, Path, None] = None,
    library_name: Optional[str] = None,
    library_version: Optional[str] = None,
    user_agent: Optional[Union[Dict, str]] = None,
    proxies: Optional[Dict] = None,
    etag_timeout: Optional[float] = 10,
    resume_download: Optional[bool] = False,
    use_auth_token: Optional[Union[bool, str]] = None,
    local_files_only: Optional[bool] = False,
    allow_regex: Optional[Union[List[str], str]] = None,
    ignore_regex: Optional[Union[List[str], str]] = None,
    allow_patterns: Optional[Union[List[str], str]] = None,
    ignore_patterns: Optional[Union[List[str], str]] = None,
    max_workers: int = 10,
) -> str:
    """Download all files of a repo.

    Downloads a whole snapshot of a repo's files at the specified revision. This
    is useful when you want all files from a repo, because you don't know which
    ones you will need a priori. All files are nested inside a folder in order
    to keep their actual filename relative to that folder.

    An alternative would be to just clone a repo but this would require that the
    user always has git and git-lfs installed, and properly configured.

    Args:
        repo_id (`str`):
            A user or an organization name and a repo name separated by a `/`.
        revision (`str`, *optional*):
            An optional Git revision id which can be a branch name, a tag, or a
            commit hash. Defaults to `DEFAULT_REVISION` when `None`.
        repo_type (`str`, *optional*):
            Set to `"dataset"` or `"space"` if downloading from a dataset or
            space, `None` or `"model"` if downloading from a model. Default is
            `None`.
        cache_dir (`str`, `Path`, *optional*):
            Path to the folder where cached files are stored. Defaults to
            `HUGGINGFACE_HUB_CACHE` when `None`.
        library_name (`str`, *optional*):
            The name of the library to which the object corresponds.
        library_version (`str`, *optional*):
            The version of the library.
        user_agent (`str`, `dict`, *optional*):
            The user-agent info in the form of a dictionary or a string.
        proxies (`dict`, *optional*):
            Dictionary mapping protocol to the URL of the proxy passed to
            `requests.request`.
        etag_timeout (`float`, *optional*, defaults to `10`):
            When fetching ETag, how many seconds to wait for the server to send
            data before giving up which is passed to `requests.request`.
        resume_download (`bool`, *optional*, defaults to `False`):
            If `True`, resume a previously interrupted download.
        use_auth_token (`str`, `bool`, *optional*):
            A token to be used for the download.
                - If `True`, the token is read from the HuggingFace config
                  folder.
                - If a string, it's used as the authentication token.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, avoid downloading the file and return the path to the
            local cached file if it exists.
        allow_regex (`List[str]` or `str`, *optional*):
            Alias for `allow_patterns`. When not `None`, it overrides
            `allow_patterns`.
        ignore_regex (`List[str]` or `str`, *optional*):
            Alias for `ignore_patterns`. When not `None`, it overrides
            `ignore_patterns`.
        allow_patterns (`List[str]` or `str`, *optional*):
            If provided, only files matching at least one pattern are downloaded.
        ignore_patterns (`List[str]` or `str`, *optional*):
            If provided, files matching any of the patterns are not downloaded.
        max_workers (`int`, *optional*, defaults to `10`):
            Number of threads used to download the files concurrently. It is
            forwarded as `n_jobs` to `joblib.Parallel`.

    Returns:
        Local folder path (string) of repo snapshot

    <Tip>

    Raises the following errors:

    - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
      if `use_auth_token=True` and the token cannot be found.
    - [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) if
      ETag cannot be determined.
    - [`FileNotFoundError`](https://docs.python.org/3/library/exceptions.html#FileNotFoundError)
      if `local_files_only=True`, `revision` is not a commit hash and no ref
      file for that revision exists in the cache.
    - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
      if some parameter value is invalid, or if `local_files_only=True` and
      the snapshot folder for the resolved commit is not in the cache.

    </Tip>
    """
    if cache_dir is None:
        cache_dir = HUGGINGFACE_HUB_CACHE
    if revision is None:
        revision = DEFAULT_REVISION
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    if repo_type is None:
        repo_type = "model"
    if repo_type not in REPO_TYPES:
        raise ValueError(f"Invalid repo type: {repo_type}. Accepted repo types are: {str(REPO_TYPES)}")

    # Root folder of this repo inside the cache; "refs" and "snapshots" live under it.
    storage_folder = os.path.join(cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type))

    # The *_regex arguments are aliases: when given, they win over *_patterns.
    if allow_regex is not None:
        allow_patterns = allow_regex
    if ignore_regex is not None:
        ignore_patterns = ignore_regex

    # Offline mode: resolve the revision to a commit hash using only the cache
    # and return the matching snapshot folder if it is already on disk.
    if local_files_only:
        if REGEX_COMMIT_HASH.match(revision):
            commit_hash = revision
        else:
            # A branch or tag name: the commit hash is stored in the ref file.
            ref_path = os.path.join(storage_folder, "refs", revision)
            with open(ref_path) as f:
                # Strip so a stray trailing newline/whitespace in the ref file
                # does not end up inside the snapshot path.
                commit_hash = f.read().strip()

        snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash)

        if os.path.exists(snapshot_folder):
            return snapshot_folder

        raise ValueError(
            "Cannot find an appropriate cached snapshot folder for the specified"
            " revision on the local disk and outgoing traffic has been disabled. To"
            " enable repo look-ups and downloads online, set 'local_files_only' to"
            " False."
        )

    # Online mode: ask the Hub for the file listing and the resolved commit hash.
    _api = HfApi()
    repo_info = _api.repo_info(
        repo_id=repo_id,
        repo_type=repo_type,
        revision=revision,
        use_auth_token=use_auth_token,
    )
    filtered_repo_files = list(
        filter_repo_objects(
            items=[f.rfilename for f in repo_info.siblings],
            allow_patterns=allow_patterns,
            ignore_patterns=ignore_patterns,
        )
    )
    commit_hash = repo_info.sha
    snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash)

    # If the requested revision differs from the resolved commit hash, record
    # the revision -> commit mapping in a ref file so that a later
    # `local_files_only=True` call can resolve the same revision offline.
    if revision != commit_hash:
        ref_path = os.path.join(storage_folder, "refs", revision)
        # The revision may contain "/" and therefore map to nested folders.
        os.makedirs(os.path.dirname(ref_path), exist_ok=True)
        with open(ref_path, "w") as f:
            f.write(commit_hash)

    # Download every selected file, pinned to the resolved commit hash so all
    # files come from the same snapshot. Downloads run in a thread pool.
    Parallel(n_jobs=max_workers, backend="threading")(
        delayed(hf_hub_download)(
            repo_id,
            filename=repo_file,
            repo_type=repo_type,
            revision=commit_hash,
            cache_dir=cache_dir,
            library_name=library_name,
            library_version=library_version,
            user_agent=user_agent,
            proxies=proxies,
            etag_timeout=etag_timeout,
            resume_download=resume_download,
            use_auth_token=use_auth_token,
        )
        for repo_file in filtered_repo_files
    )

    return snapshot_folder
|
|