| | import os |
| | import argparse |
| |
|
| | from huggingface_hub import snapshot_download |
| |
|
| | from mlagents_envs import logging_util |
| | from mlagents_envs.logging_util import get_logger |
| |
|
# Module-level logger named after this module, used for the download report.
logger = get_logger(__name__)
# Request the INFO level through mlagents_envs' logging helper at import time.
# NOTE(review): presumably needed so the logger.info() report is shown — confirm
# against logging_util.set_log_level.
logging_util.set_log_level(logging_util.INFO)
| |
|
| |
|
def load_from_hf(repo_id: str, local_dir: str) -> None:
    """
    Download a model from Hugging Face Hub.

    The repository is stored in a sub-directory of ``local_dir`` named after
    the repository (the part of ``repo_id`` that follows the namespace).

    :param repo_id: id of the model repository from the Hugging Face Hub,
        either ``namespace/name`` or a bare ``name``
    :param local_dir: local destination of the repository
    :raises ValueError: if ``repo_id`` is empty, contains more than one ``/``,
        or has an empty namespace or repository name
    """
    if not repo_id:
        raise ValueError("repo_id must be a non-empty string")

    # rpartition copes with ids that have no namespace ("name" yields
    # ("", "", "name")), which a two-way unpack of split("/") cannot.
    namespace, separator, repo_name = repo_id.rpartition("/")
    malformed = (
        not repo_name  # "namespace/"
        or "/" in namespace  # more than one separator
        or (separator and not namespace)  # "/name"
    )
    if malformed:
        raise ValueError(
            f"Invalid repo_id {repo_id!r}: expected 'namespace/name' or 'name'"
        )

    local_dir = os.path.join(local_dir, repo_name)

    snapshot_download(repo_id=repo_id, local_dir=local_dir)

    logger.info(f"The repository {repo_id} has been downloaded to {local_dir}")
| |
|
| |
|
def main() -> None:
    """
    Command-line entry point.

    Parses ``--repo-id`` (required) and ``--local-dir`` (defaults to ``./``)
    and passes them to ``load_from_hf``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--repo-id",
        help="Repo id of the model repository from the Hugging Face Hub",
        type=str,
        # Without a repo id there is nothing to download; let argparse report
        # the missing option instead of passing None on to load_from_hf.
        required=True,
    )
    parser.add_argument(
        "--local-dir",
        help="Local destination of the repository",
        type=str,
        default="./",
    )
    args = parser.parse_args()

    load_from_hf(args.repo_id, args.local_dir)
| |
|
| |
|
| | |
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
| |
|