HPC-Quantize / download_model.py
CompressedGemma's picture
Upload 5 files
7803d72 verified
Raw
History Blame Contribute Delete
3.68 kB
#!/usr/bin/env python3
import os
import sys
import argparse
import urllib.parse
from huggingface_hub import snapshot_download
def parse_hf_url(url_or_id):
    """
    Parse a Hugging Face URL or repo ID into a ``(repo_id, repo_type)`` pair.

    Example URL: https://huggingface.co/google/gemma-4-26B-A4B-it
    Example Dataset URL: https://huggingface.co/datasets/ggml-org/ci

    Args:
        url_or_id: Either a full ``http(s)://huggingface.co/...`` URL or a
            bare repository ID such as ``user/repo``.

    Returns:
        A tuple ``(repo_id, repo_type)`` where ``repo_type`` is ``"model"``,
        ``"dataset"`` or ``"space"``. A bare repository ID is returned with
        surrounding whitespace removed and is always typed as ``"model"``.
        For URLs, only the first two path components after an optional
        ``datasets/`` or ``spaces/`` prefix are used; anything after them
        (e.g. ``/tree/main``) is ignored.

    Raises:
        ValueError: If the input is empty, the URL host is not
            huggingface.co, or no repository ID can be extracted from the
            URL path.
    """
    candidate = url_or_id.strip()
    if not candidate:
        raise ValueError("Repository URL or ID is empty")

    # URL schemes are case-insensitive, so compare against a lowered copy.
    if not candidate.lower().startswith(("http://", "https://")):
        # Not a URL: treat the input as a repo ID (e.g. user/repo).
        return candidate, "model"

    parsed = urllib.parse.urlparse(candidate)
    # ``hostname`` is lower-cased and excludes any port or userinfo, unlike
    # ``netloc``; comparing ``netloc`` directly would reject valid URLs such
    # as https://HuggingFace.co/... or https://huggingface.co:443/...
    host = parsed.hostname or ""
    if host not in ("huggingface.co", "www.huggingface.co"):
        raise ValueError(f"URL host must be huggingface.co, got: {parsed.netloc}")

    path_parts = [part for part in parsed.path.split("/") if part]
    if not path_parts:
        raise ValueError("Hugging Face URL path is empty")

    # A leading "datasets" or "spaces" path segment selects the repo type;
    # any other first segment is already part of a model repo ID.
    type_by_prefix = {"datasets": "dataset", "spaces": "space"}
    repo_type = type_by_prefix.get(path_parts[0], "model")
    if repo_type != "model":
        path_parts = path_parts[1:]

    if not path_parts:
        raise ValueError("Could not extract repository ID from Hugging Face URL")
    if len(path_parts) == 1:
        # Single-segment repo ID with no "user/" namespace.
        return path_parts[0], repo_type

    return f"{path_parts[0]}/{path_parts[1]}", repo_type
def main():
    """Command-line entry point: resolve a repository and download a snapshot.

    Parses the positional ``url_or_id`` argument and the optional
    ``--local-dir``, ``--token``, ``--exclude`` and ``--include`` flags,
    resolves the repository ID and type with ``parse_hf_url``, creates the
    target directory, and calls ``snapshot_download``.

    Exits with status 1 if the input cannot be parsed or the download raises.
    """
    cli = argparse.ArgumentParser(
        description="Download a Hugging Face model or dataset from a URL or repository ID."
    )
    cli.add_argument(
        "url_or_id",
        type=str,
        help="Hugging Face repository URL (e.g. https://huggingface.co/google/gemma-4-26B-A4B-it) or repository ID (e.g. google/gemma-4-26B-A4B-it).",
    )
    cli.add_argument(
        "--local-dir",
        type=str,
        default=None,
        help="Directory to save the downloaded model. Defaults to a folder matching the repository name in the current directory.",
    )
    cli.add_argument(
        "--token",
        type=str,
        # The environment variable is read once, when the parser is built.
        default=os.environ.get("HF_TOKEN"),
        help="Hugging Face API token. Can also be set via the HF_TOKEN environment variable.",
    )
    # The two pattern filters share the same shape; only flag and help differ.
    pattern_flags = (
        ("--exclude", "Glob patterns to exclude from download (e.g., *.bin, *.pt)"),
        ("--include", "Glob patterns to include in download (e.g., *.safetensors)"),
    )
    for flag, help_text in pattern_flags:
        cli.add_argument(flag, type=str, nargs="*", help=help_text)

    args = cli.parse_args()

    try:
        repo_id, repo_type = parse_hf_url(args.url_or_id)
    except ValueError as err:
        print(f"Error parsing input URL/ID: {err}", file=sys.stderr)
        sys.exit(1)

    # Without an explicit --local-dir, use ./<last segment of the repo ID>.
    target_dir = args.local_dir
    if target_dir is None:
        target_dir = os.path.join(os.getcwd(), repo_id.split("/")[-1])

    print(
        f"Repository ID: {repo_id}",
        f"Repository Type: {repo_type}",
        f"Target Directory: {target_dir}",
        sep="\n",
    )

    os.makedirs(target_dir, exist_ok=True)

    # NOTE(review): local_dir_use_symlinks may be deprecated in newer
    # huggingface_hub releases - confirm against the installed version.
    download_options = {
        "repo_id": repo_id,
        "repo_type": repo_type,
        "local_dir": target_dir,
        "local_dir_use_symlinks": False,
        "token": args.token,
        "ignore_patterns": args.exclude,
        "allow_patterns": args.include,
    }
    try:
        saved_path = snapshot_download(**download_options)
        print("\nDownload completed successfully!")
        print(f"Files saved in: {saved_path}")
    except Exception as err:
        # Top-level boundary: report any download failure and exit non-zero.
        print(f"\nError downloading repository: {err}", file=sys.stderr)
        sys.exit(1)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()