File size: 3,677 Bytes
7803d72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
#!/usr/bin/env python3
import os
import sys
import argparse
import urllib.parse
from huggingface_hub import snapshot_download

def parse_hf_url(url_or_id):
    """
    Parse a Hugging Face URL or repo ID and extract the repo ID and type.

    Accepted inputs:
      * Full URL:      https://huggingface.co/google/gemma-4-26B-A4B-it
      * Dataset URL:   https://huggingface.co/datasets/ggml-org/ci
      * Plain repo ID: google/gemma-4-26B-A4B-it
      * Typed path:    datasets/ggml-org/ci  (the path part of a URL)

    Extra URL path segments after the repository (e.g. ``/tree/main``) are
    ignored. Surrounding whitespace is stripped from the input.

    Returns:
        A ``(repo_id, repo_type)`` tuple where ``repo_type`` is one of
        ``"model"``, ``"dataset"`` or ``"space"``.

    Raises:
        ValueError: if the input is empty, the URL host is not
            huggingface.co, or no repository can be extracted from the path.
    """
    candidate = url_or_id.strip() if url_or_id else ""
    if not candidate:
        raise ValueError("Repository URL or ID is empty")

    # URL path prefix -> repo_type value used for the returned tuple.
    type_prefixes = {"datasets": "dataset", "spaces": "space"}

    # Scheme matching is case-insensitive ("HTTPS://..." is still a URL).
    if not candidate.lower().startswith(("http://", "https://")):
        segments = [s for s in candidate.split("/") if s]
        # A repo ID has at most one slash, so "<prefix>/<owner>/<name>" can
        # only be a typed path copied from a URL, never a plain model ID.
        if len(segments) == 3 and segments[0] in type_prefixes:
            return f"{segments[1]}/{segments[2]}", type_prefixes[segments[0]]
        # Anything else is passed through as a model repo ID.
        return candidate, "model"

    parsed = urllib.parse.urlparse(candidate)
    # `hostname` is lower-cased and excludes any port, unlike `netloc`.
    host = parsed.hostname or ""
    if host not in ("huggingface.co", "www.huggingface.co"):
        raise ValueError(f"URL host must be huggingface.co, got: {parsed.netloc}")

    path_parts = [p for p in parsed.path.split("/") if p]
    if not path_parts:
        raise ValueError("Hugging Face URL path is empty")

    repo_type = type_prefixes.get(path_parts[0], "model")
    if repo_type != "model":
        # Drop the "datasets"/"spaces" prefix; the rest is the repo path.
        path_parts = path_parts[1:]

    if not path_parts:
        raise ValueError("Could not extract repository ID from Hugging Face URL")
    if len(path_parts) == 1:
        # Single-segment repositories (no owner namespace), e.g. /gpt2.
        return path_parts[0], repo_type

    # Keep only "<owner>/<name>"; ignore trailing segments like /tree/main.
    return f"{path_parts[0]}/{path_parts[1]}", repo_type

def _build_arg_parser():
    """Create the argparse parser describing this script's command line."""
    cli = argparse.ArgumentParser(
        description="Download a Hugging Face model or dataset from a URL or repository ID."
    )
    cli.add_argument(
        "url_or_id",
        type=str,
        help="Hugging Face repository URL (e.g. https://huggingface.co/google/gemma-4-26B-A4B-it) or repository ID (e.g. google/gemma-4-26B-A4B-it).",
    )
    cli.add_argument(
        "--local-dir",
        type=str,
        default=None,
        help="Directory to save the downloaded model. Defaults to a folder matching the repository name in the current directory.",
    )
    # The default is read from the environment when the parser is built.
    cli.add_argument(
        "--token",
        type=str,
        default=os.environ.get("HF_TOKEN"),
        help="Hugging Face API token. Can also be set via the HF_TOKEN environment variable.",
    )
    cli.add_argument(
        "--exclude",
        type=str,
        nargs="*",
        help="Glob patterns to exclude from download (e.g., *.bin, *.pt)",
    )
    cli.add_argument(
        "--include",
        type=str,
        nargs="*",
        help="Glob patterns to include in download (e.g., *.safetensors)",
    )
    return cli


def main():
    """
    Command-line entry point: resolve the repository and download it.

    Parses the command line, turns the URL/ID argument into a repo ID and
    repo type, creates the target directory, and calls snapshot_download.
    Exits with status 1 (message on stderr) if the input cannot be parsed
    or the download raises.
    """
    opts = _build_arg_parser().parse_args()

    try:
        repo_id, repo_type = parse_hf_url(opts.url_or_id)
    except ValueError as err:
        print(f"Error parsing input URL/ID: {err}", file=sys.stderr)
        sys.exit(1)

    # Without --local-dir, use ./<last component of the repo ID>.
    target_dir = opts.local_dir
    if target_dir is None:
        target_dir = os.path.join(os.getcwd(), repo_id.rsplit("/", 1)[-1])

    print(f"Repository ID:   {repo_id}")
    print(f"Repository Type: {repo_type}")
    print(f"Target Directory: {target_dir}")

    os.makedirs(target_dir, exist_ok=True)

    # NOTE(review): newer huggingface_hub releases may deprecate
    # local_dir_use_symlinks -- confirm against the installed version.
    download_kwargs = {
        "repo_id": repo_id,
        "repo_type": repo_type,
        "local_dir": target_dir,
        "local_dir_use_symlinks": False,
        "token": opts.token,
        "ignore_patterns": opts.exclude,
        "allow_patterns": opts.include,
    }

    try:
        saved_to = snapshot_download(**download_kwargs)
        print("\nDownload completed successfully!")
        print(f"Files saved in: {saved_to}")
    except Exception as err:
        # Top-level boundary: report any download failure and exit non-zero.
        print(f"\nError downloading repository: {err}", file=sys.stderr)
        sys.exit(1)

# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()