| |
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| from pathlib import Path |
|
|
| from huggingface_hub import snapshot_download |
|
|
|
|
def main(argv: list[str] | None = None) -> None:
    """Download every model snapshot listed in one profile of a weights manifest.

    The manifest is a JSON object with a ``default_profile`` name and a
    ``profiles`` mapping from profile name to a list of entries. Each entry
    supplies ``repo_id`` and ``local_dir``. Every entry of the selected profile
    is fetched with ``snapshot_download`` into its ``local_dir``, which is
    created first if needed.

    Args:
        argv: Arguments to parse. ``None`` (the default) means ``sys.argv[1:]``,
            so running the file as a script behaves exactly as before.

    Raises:
        SystemExit: Raised via ``argparse`` with status 2 in three cases: the
            manifest cannot be read or parsed; no profile was given and the
            manifest has no ``default_profile``; or the requested profile
            does not exist.
    """
    parser = argparse.ArgumentParser(description="Download model snapshots into this deployment repo.")
    parser.add_argument("--manifest", default="weights_manifest.json")
    parser.add_argument("--profile", default=None)
    parser.add_argument("--revision", default=None)
    parser.add_argument("--local-files-only", action="store_true")
    args = parser.parse_args(argv)

    manifest_path = Path(args.manifest)
    try:
        # Decode explicitly as UTF-8. The platform default encoding (e.g.
        # cp1252 on Windows) can mis-decode or reject a JSON manifest.
        manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
    except (OSError, ValueError) as err:  # ValueError covers JSON and Unicode decode errors
        parser.error(f"cannot load manifest {manifest_path}: {err}")
    if not isinstance(manifest, dict):
        parser.error(f"manifest {manifest_path} must contain a JSON object")

    profile_name = args.profile or manifest.get("default_profile")
    if not profile_name:
        parser.error("no --profile given and the manifest has no default_profile")
    profiles = manifest.get("profiles") or {}
    if profile_name not in profiles:
        available = ", ".join(sorted(profiles)) or "(none)"
        parser.error(f"unknown profile {profile_name!r}; available profiles: {available}")
    profile = profiles[profile_name]

    # Weight files matching these patterns are skipped (presumably the
    # non-PyTorch msgpack/h5 checkpoint formats this deployment does not use).
    ignore_patterns = ["*.msgpack", "*.h5"]
    for item in profile:
        local_dir = Path(item["local_dir"])
        local_dir.mkdir(parents=True, exist_ok=True)
        print(f"Downloading {item['repo_id']} -> {local_dir}")
        snapshot_download(
            repo_id=item["repo_id"],
            local_dir=str(local_dir),
            revision=args.revision,
            local_files_only=args.local_files_only,
            ignore_patterns=ignore_patterns,
        )
    # Report what was actually done. The destinations come from each entry's
    # local_dir, not from a fixed "models/" directory.
    print(f"Done. Fetched {len(profile)} snapshot(s) for profile {profile_name!r}.")
|
|
|
|
if __name__ == "__main__":
    # Run the downloader only when executed as a script; importing the module
    # must not trigger any downloads.
    main()
|
|