#!/usr/bin/env python3
"""Download a Hugging Face model, dataset, or Space from a URL or repository ID."""

import argparse
import os
import sys
import urllib.parse
from typing import List, Optional, Tuple

try:
    from huggingface_hub import snapshot_download
except ImportError:  # Reported by main() with a readable message instead of a traceback.
    snapshot_download = None

# Hosts accepted in URLs; hf.co is Hugging Face's short domain.
_HF_HOSTS = ("huggingface.co", "www.huggingface.co", "hf.co")

# Leading path segment -> repo_type passed to snapshot_download. Any other first
# segment is treated as the namespace of a model repository.
_REPO_TYPE_PREFIXES = {"datasets": "dataset", "spaces": "space"}


def parse_hf_url(url_or_id: str) -> Tuple[str, str]:
    """Extract ``(repo_id, repo_type)`` from a Hugging Face URL or repository ID.

    Accepted forms::

        google/gemma-4-26B-A4B-it                          -> model
        https://huggingface.co/google/gemma-4-26B-A4B-it   -> model
        https://huggingface.co/datasets/ggml-org/ci        -> dataset
        https://hf.co/spaces/user/app/tree/main            -> space
        datasets/ggml-org/ci                               -> dataset

    In a URL, path segments after ``namespace/name`` (e.g. ``/tree/main``) are
    ignored. A single remaining segment (e.g. ``gpt2``) is returned as the ID.

    Args:
        url_or_id: URL or repository ID; surrounding whitespace is ignored.

    Returns:
        The repository ID and one of ``"model"``, ``"dataset"`` or ``"space"``.

    Raises:
        ValueError: if the input is empty, a URL's host is not a Hugging Face
            host, no repository ID can be extracted, or a bare (non-URL) ID has
            more than ``namespace/name`` after an optional type prefix.
    """
    text = url_or_id.strip()
    if not text:
        raise ValueError("Repository URL/ID is empty")

    is_url = text.lower().startswith(("http://", "https://"))
    if is_url:
        parsed = urllib.parse.urlparse(text)
        # .hostname is lower-cased and excludes any port or user info, unlike .netloc.
        host = parsed.hostname or ""
        if host not in _HF_HOSTS:
            raise ValueError(
                f"URL host must be huggingface.co or hf.co, got: {parsed.netloc}"
            )
        path = parsed.path
    else:
        path = text

    path_parts = [p for p in path.split("/") if p]
    if not path_parts:
        raise ValueError("Hugging Face URL path is empty")

    repo_type = _REPO_TYPE_PREFIXES.get(path_parts[0], "model")
    if path_parts[0] in _REPO_TYPE_PREFIXES:
        path_parts = path_parts[1:]

    if not path_parts:
        raise ValueError("Could not extract repository ID from Hugging Face URL")
    if not is_url and len(path_parts) > 2:
        # A bare ID is at most namespace/name; extra segments are a typo, not a
        # URL sub-path (like /tree/main) that is safe to drop.
        raise ValueError(f"Invalid repository ID: {text!r}")

    return "/".join(path_parts[:2]), repo_type


def _normalize_patterns(patterns: Optional[List[str]]) -> Optional[List[str]]:
    """Return *patterns* as a list, or None when no pattern was supplied.

    argparse's ``nargs="*"`` yields ``[]`` for a bare ``--include``/``--exclude``
    flag. An empty ``allow_patterns`` list matches no files in huggingface_hub's
    pattern filter, so "no patterns" is mapped to "no filter" instead.
    """
    return list(patterns) if patterns else None


def main(argv: Optional[List[str]] = None) -> None:
    """Parse command-line arguments and download the requested repository.

    Args:
        argv: Arguments to parse; ``None`` means ``sys.argv[1:]``.

    Exits with status 1 on an invalid URL/ID, a missing ``huggingface_hub``
    installation, an uncreatable target directory, or a failed download, and
    with status 130 when interrupted with Ctrl+C.
    """
    parser = argparse.ArgumentParser(
        description="Download a Hugging Face model or dataset from a URL or repository ID."
    )
    parser.add_argument(
        "url_or_id",
        type=str,
        help=(
            "Hugging Face repository URL (e.g. https://huggingface.co/google/gemma-4-26B-A4B-it) "
            "or repository ID (e.g. google/gemma-4-26B-A4B-it)."
        ),
    )
    parser.add_argument(
        "--local-dir",
        type=str,
        default=None,
        help=(
            "Directory to save the downloaded repository. Defaults to a folder matching "
            "the repository name in the current directory."
        ),
    )
    parser.add_argument(
        "--token",
        type=str,
        default=os.environ.get("HF_TOKEN"),
        help="Hugging Face API token. Can also be set via the HF_TOKEN environment variable.",
    )
    parser.add_argument(
        "--exclude",
        type=str,
        nargs="*",
        help="Glob patterns to exclude from download (e.g., *.bin, *.pt)",
    )
    parser.add_argument(
        "--include",
        type=str,
        nargs="*",
        help="Glob patterns to include in download (e.g., *.safetensors)",
    )
    args = parser.parse_args(argv)

    try:
        repo_id, repo_type = parse_hf_url(args.url_or_id)
    except ValueError as e:
        print(f"Error parsing input URL/ID: {e}", file=sys.stderr)
        sys.exit(1)

    if snapshot_download is None:
        print(
            "Error: the 'huggingface_hub' package is not installed "
            "(install it with: pip install huggingface_hub).",
            file=sys.stderr,
        )
        sys.exit(1)

    local_dir = args.local_dir
    if local_dir is None:
        local_dir = os.path.join(os.getcwd(), repo_id.split("/")[-1])

    print(f"Repository ID: {repo_id}")
    print(f"Repository Type: {repo_type}")
    print(f"Target Directory: {local_dir}")

    try:
        os.makedirs(local_dir, exist_ok=True)
    except OSError as e:
        print(f"Error creating target directory: {e}", file=sys.stderr)
        sys.exit(1)

    try:
        # local_dir_use_symlinks is no longer passed: recent huggingface_hub
        # releases deprecate/ignore it and always write real files into local_dir.
        downloaded_path = snapshot_download(
            repo_id=repo_id,
            repo_type=repo_type,
            local_dir=local_dir,
            token=args.token,
            ignore_patterns=_normalize_patterns(args.exclude),
            allow_patterns=_normalize_patterns(args.include),
        )
    except KeyboardInterrupt:
        print("\nDownload interrupted.", file=sys.stderr)
        sys.exit(130)
    except Exception as e:  # Top-level CLI boundary: report and exit non-zero.
        print(f"\nError downloading repository: {e}", file=sys.stderr)
        sys.exit(1)

    print("\nDownload completed successfully!")
    print(f"Files saved in: {downloaded_path}")


if __name__ == "__main__":
    main()