import argparse
import inspect
import os
import sys
import urllib.parse

from huggingface_hub import snapshot_download


def parse_hf_url(url_or_id: str):
    """Extract the repository ID and repository type from a URL or repo ID.

    Examples:
        https://huggingface.co/google/gemma-4-26B-A4B-it -> ("google/gemma-4-26B-A4B-it", "model")
        https://huggingface.co/datasets/ggml-org/ci      -> ("ggml-org/ci", "dataset")
        google/gemma-4-26B-A4B-it                        -> ("google/gemma-4-26B-A4B-it", "model")

    Args:
        url_or_id: A huggingface.co URL or a bare repository ID. Surrounding
            whitespace is ignored.

    Returns:
        A ``(repo_id, repo_type)`` tuple where ``repo_type`` is one of
        ``"model"``, ``"dataset"`` or ``"space"``. Bare repository IDs are
        always reported as ``"model"``.

    Raises:
        ValueError: If the input is empty, the URL host is not huggingface.co,
            or no repository ID can be extracted from the URL path.
    """
    url_or_id = url_or_id.strip()
    if not url_or_id:
        raise ValueError("Repository ID or URL is empty")

    # Anything that is not an http(s) URL is treated as a plain repository ID.
    # URL schemes are case-insensitive, so compare on a lowered copy.
    if not url_or_id.lower().startswith(("http://", "https://")):
        return url_or_id, "model"

    parsed = urllib.parse.urlparse(url_or_id)
    # ``hostname`` is lower-cased and has any port stripped, unlike ``netloc``,
    # so "HuggingFace.co" and "huggingface.co:443" are both accepted.
    if parsed.hostname not in ("huggingface.co", "www.huggingface.co"):
        raise ValueError(f"URL host must be huggingface.co, got: {parsed.netloc}")

    path_parts = [p for p in parsed.path.split("/") if p]
    if not path_parts:
        raise ValueError("Hugging Face URL path is empty")

    # A leading "datasets" or "spaces" segment selects the repository type and
    # is not part of the repository ID itself.
    prefix_to_type = {"datasets": "dataset", "spaces": "space"}
    repo_type = prefix_to_type.get(path_parts[0], "model")
    if repo_type != "model":
        path_parts = path_parts[1:]

    if not path_parts:
        raise ValueError("Could not extract repository ID from Hugging Face URL")

    # Keep at most "<namespace>/<name>"; a single segment is returned as-is
    # and trailing segments such as "tree/main" are dropped.
    return "/".join(path_parts[:2]), repo_type


def _build_parser():
    """Build the command-line argument parser for the download script."""
    parser = argparse.ArgumentParser(
        description="Download a Hugging Face model or dataset from a URL or repository ID."
    )
    parser.add_argument(
        "url_or_id",
        type=str,
        help="Hugging Face repository URL (e.g. https://huggingface.co/google/gemma-4-26B-A4B-it) or repository ID (e.g. google/gemma-4-26B-A4B-it).",
    )
    parser.add_argument(
        "--local-dir",
        type=str,
        default=None,
        help="Directory to save the downloaded model. Defaults to a folder matching the repository name in the current directory.",
    )
    parser.add_argument(
        "--token",
        type=str,
        default=os.environ.get("HF_TOKEN"),
        help="Hugging Face API token. Can also be set via the HF_TOKEN environment variable.",
    )
    parser.add_argument(
        "--exclude",
        type=str,
        nargs="*",
        help="Glob patterns to exclude from download (e.g., *.bin, *.pt)",
    )
    parser.add_argument(
        "--include",
        type=str,
        nargs="*",
        help="Glob patterns to include in download (e.g., *.safetensors)",
    )
    return parser


def main():
    """Parse the command line and download the requested repository snapshot.

    Resolves the positional URL/ID via ``parse_hf_url``, picks the target
    directory (``--local-dir`` or ``<cwd>/<repo name>``), creates it, and
    calls ``snapshot_download``. Exits with status 1 when the input cannot be
    parsed or the download raises.
    """
    args = _build_parser().parse_args()

    try:
        repo_id, repo_type = parse_hf_url(args.url_or_id)
    except ValueError as e:
        print(f"Error parsing input URL/ID: {e}", file=sys.stderr)
        sys.exit(1)

    # Default target: a folder named after the last segment of the repo ID.
    local_dir = args.local_dir
    if local_dir is None:
        local_dir = os.path.join(os.getcwd(), repo_id.split("/")[-1])

    print(f"Repository ID: {repo_id}")
    print(f"Repository Type: {repo_type}")
    print(f"Target Directory: {local_dir}")

    os.makedirs(local_dir, exist_ok=True)

    download_kwargs = {
        "repo_id": repo_id,
        "repo_type": repo_type,
        "local_dir": local_dir,
        # An empty HF_TOKEN / --token value means "no token", not an empty
        # credential string.
        "token": args.token or None,
        # A flag given without any patterns yields []; treat that as
        # "no filter" rather than forwarding an empty pattern list.
        "ignore_patterns": args.exclude or None,
        "allow_patterns": args.include or None,
    }
    # Recent huggingface_hub releases deprecate ``local_dir_use_symlinks``.
    # NOTE(review): newer releases may drop it entirely - confirm. Only pass
    # it when the installed version still declares the parameter, so the
    # call keeps working either way.
    if "local_dir_use_symlinks" in inspect.signature(snapshot_download).parameters:
        download_kwargs["local_dir_use_symlinks"] = False

    try:
        downloaded_path = snapshot_download(**download_kwargs)
    except Exception as e:  # top-level boundary: report and exit non-zero
        print(f"\nError downloading repository: {e}", file=sys.stderr)
        sys.exit(1)
    else:
        print("\nDownload completed successfully!")
        print(f"Files saved in: {downloaded_path}")


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()