import os
import subprocess
import sys
import argparse
from concurrent.futures import (
    ProcessPoolExecutor,
    as_completed,
)

# Make the repository root importable before loading sibling modules.
sys.path.append(
    os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
)

from zipnn_compress_file import compress_file

KB = 1024
MB = 1024 * 1024
GB = 1024 * 1024 * 1024

# ANSI escape codes for colored terminal output.
RED = "\033[91m"
YELLOW = "\033[93m"
GREEN = "\033[92m"
RESET = "\033[0m"

def check_and_install_zipnn():
    try:
        import zipnn
    except ImportError:
        print("zipnn not found. Installing...")
        subprocess.check_call(
            [sys.executable, "-m", "pip", "install", "zipnn", "--upgrade"]
        )
        import zipnn

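# The install step above is equivalent to running (sketch, assumes pip is available):
#   python -m pip install --upgrade zipnn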
|
def parse_streaming_chunk_size(streaming_chunk_size):
    if str(streaming_chunk_size).isdigit():
        final = int(streaming_chunk_size)
    else:
        # Expect a number followed by a two-character unit: KB, MB, or GB.
        size_value = int(streaming_chunk_size[:-2])
        size_unit = streaming_chunk_size[-2].lower()

        if size_unit == "k":
            final = KB * size_value
        elif size_unit == "m":
            final = MB * size_value
        elif size_unit == "g":
            final = GB * size_value
        else:
            raise ValueError(
                f"Invalid size unit: {size_unit}. Use 'k', 'm', or 'g'."
            )

    return final

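# Illustrative conversions:
#   parse_streaming_chunk_size("1024") -> 1024     (plain digits = bytes)
#   parse_streaming_chunk_size("4KB")  -> 4096
#   parse_streaming_chunk_size("1MB")  -> 1048576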
|
def replace_in_file(file_path, old: str, new: str) -> None:
    """Replace all occurrences of `old` with `new` in `file_path`, in place."""
    with open(file_path, "r") as file:
        file_data = file.read()

    file_data = file_data.replace(old, new)

    with open(file_path, "w") as file:
        file.write(file_data)

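# Illustrative use (hypothetical file name), mirroring the hf_cache fix-up
# below that repoints a Hugging Face weight-map index at the .znn names:
#   replace_in_file("model.safetensors.index.json", "safetensors", "safetensors.znn")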
|
def compress_files_with_suffix(
    suffix,
    dtype="",
    streaming_chunk_size=1048576,
    path=".",
    delete=False,
    r=False,
    force=False,
    max_processes=1,
    hf_cache=False,
    model="",
    branch="main",
):
    import zipnn  # verify the dependency is importable before any work starts

    overwrite_first = True
    file_list = []
    streaming_chunk_size = parse_streaming_chunk_size(streaming_chunk_size)
    if model:
        if not hf_cache:
            raise ValueError("Must specify --hf_cache when using --model")
        try:
            from huggingface_hub import scan_cache_dir
        except ImportError:
            raise ImportError(
                "huggingface_hub not found. Please pip install huggingface_hub."
            )
        cache = scan_cache_dir()
        repo = next(
            (repo for repo in cache.repos if repo.repo_id == model), None
        )

        if repo is not None:
            print(f"Found repo {model} in cache")

            # Resolve the branch ref to a commit hash, then point `path` at
            # the matching snapshot directory in the cache.
            commit_hash = ""
            try:
                with open(
                    os.path.join(repo.repo_path, "refs", branch), "r"
                ) as ref:
                    commit_hash = ref.read()
            except FileNotFoundError:
                raise FileNotFoundError(
                    f"Branch {branch} not found in repo {model}"
                )

            path = os.path.join(repo.repo_path, "snapshots", commit_hash)

    directories_to_search = (
        os.walk(path) if r else [(path, [], os.listdir(path))]
    )
    files_found = False
    for root, _, files in directories_to_search:
        for file_name in files:
            if file_name.endswith(suffix):
                # Look for an existing compressed file next to the source,
                # not relative to the current working directory.
                compressed_path = os.path.join(root, file_name + ".znn")
                if not force and os.path.exists(compressed_path):
                    if overwrite_first:
                        overwrite_first = False
                        user_input = (
                            input(
                                "Compressed files already exist; overwrite them all (y/n)? "
                            )
                            .strip()
                            .lower()
                        )
                        if user_input not in ("y", "yes"):
                            print("No forced overwriting.")
                        else:
                            print("Overwriting all compressed files.")
                            force = True

                    if not force:
                        user_input = (
                            input(
                                f"{compressed_path} already exists; overwrite (y/n)? "
                            )
                            .strip()
                            .lower()
                        )
                        if user_input not in ("y", "yes"):
                            print(f"Skipping {file_name}...")
                            continue
                files_found = True
                full_path = os.path.join(root, file_name)
                file_list.append(full_path)

    if file_list and hf_cache:
        try:
            from transformers.utils import (
                SAFE_WEIGHTS_INDEX_NAME,
                WEIGHTS_INDEX_NAME,
            )
        except ImportError:
            raise ImportError(
                "Transformers not found. Please pip install transformers."
            )

        # The index file in the snapshot is a symlink into the blobs
        # directory; rewrite the blob it points to so the weight map
        # references the compressed (.znn) file names.
        for index_name in (SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME):
            index_path = os.path.join(path, index_name)
            if os.path.exists(index_path):
                print(f"{YELLOW}Fixing Hugging Face model json...{RESET}")
                blob_name = os.path.join(path, os.readlink(index_path))
                replace_in_file(
                    file_path=blob_name,
                    old=suffix,
                    new=f"{suffix}.znn",
                )
                break

    with ProcessPoolExecutor(max_workers=max_processes) as executor:
        # Bounded-concurrency pattern: seed the pool with at most
        # max_processes files, then submit one new file for each future
        # that completes, so only max_processes tasks run at a time.
        future_to_file = {
            executor.submit(
                compress_file,
                file,
                dtype,
                streaming_chunk_size,
                delete,
                True,
                hf_cache,
            ): file
            for file in file_list[:max_processes]
        }
        file_list = file_list[max_processes:]
        while future_to_file:
            for future in as_completed(future_to_file):
                file = future_to_file.pop(future)

                try:
                    future.result()
                except Exception as exc:
                    print(
                        f"{RED}File {file} generated an exception: {exc}{RESET}"
                    )

                if file_list:
                    next_file = file_list.pop(0)
                    future_to_file[
                        executor.submit(
                            compress_file,
                            next_file,
                            dtype,
                            streaming_chunk_size,
                            delete,
                            True,
                            hf_cache,
                        )
                    ] = next_file

    if not files_found:
        print(f"{RED}No files with the suffix '{suffix}' found.{RESET}")
    else:
        print(f"{GREEN}All files compressed{RESET}")

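# Illustrative programmatic call (assumes .safetensors files under ./models):
#   compress_files_with_suffix("safetensors", path="./models", r=True, max_processes=4)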
|
if __name__ == "__main__":
    if len(sys.argv) < 2:
        print("Usage: python compress_files.py <suffix>")
        print("Example: python compress_files.py 'safetensors'")
        sys.exit(1)

    parser = argparse.ArgumentParser(
        description="Enter a suffix to compress, (optional) dtype, "
        "(optional) streaming chunk size, (optional) path to files."
    )
    parser.add_argument(
        "suffix",
        type=str,
        help="Specify the file suffix to compress all files with that suffix. "
        "If a single file name is provided, only that file will be compressed.",
    )
    parser.add_argument(
        "--float32",
        action="store_true",
        help="A flag that triggers float32 compression",
    )
    parser.add_argument(
        "--streaming_chunk_size",
        type=str,
        help="An optional streaming chunk size. The format is int (for size in bytes) "
        "or int+KB/MB/GB. Default is 1MB.",
    )
    parser.add_argument(
        "--path",
        type=str,
        help="Path to files to compress",
    )
    parser.add_argument(
        "--delete",
        action="store_true",
        help="A flag that triggers deletion of a single file instead of compression",
    )
    parser.add_argument(
        "-r",
        action="store_true",
        help="A flag that triggers recursive search on all subdirectories",
    )
    parser.add_argument(
        "--recursive",
        action="store_true",
        help="A flag that triggers recursive search on all subdirectories",
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="A flag that forces overwriting when compressing.",
    )
    parser.add_argument(
        "--max_processes",
        type=int,
        help="The maximum number of processes.",
    )
    parser.add_argument(
        "--hf_cache",
        action="store_true",
        help="A flag that indicates if the file is in the Hugging Face cache. "
        "Must either specify --model or --path to the model's snapshot cache.",
    )
    parser.add_argument(
        "--model",
        type=str,
        help="Only when using --hf_cache, specify the model name or path. "
        "E.g. 'ibm-granite/granite-7b-instruct'",
    )
    parser.add_argument(
        "--model_branch",
        type=str,
        default="main",
        help="Only when using --model, specify the model branch. Default is 'main'",
    )
    args = parser.parse_args()
    optional_kwargs = {}
    if args.float32:
        optional_kwargs["dtype"] = 32
    if args.streaming_chunk_size is not None:
        optional_kwargs["streaming_chunk_size"] = args.streaming_chunk_size
    if args.path is not None:
        optional_kwargs["path"] = args.path
    if args.delete:
        optional_kwargs["delete"] = args.delete
    if args.r or args.recursive:
        # Honor either spelling of the recursive flag.
        optional_kwargs["r"] = True
    if args.force:
        optional_kwargs["force"] = args.force
    if args.max_processes:
        optional_kwargs["max_processes"] = args.max_processes
    if args.hf_cache:
        optional_kwargs["hf_cache"] = args.hf_cache
    if args.model:
        optional_kwargs["model"] = args.model
    if args.model_branch:
        optional_kwargs["branch"] = args.model_branch

    check_and_install_zipnn()
    compress_files_with_suffix(args.suffix, **optional_kwargs)
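
# Example invocations (illustrative):
#   python compress_files.py safetensors
#   python compress_files.py safetensors --path ./models -r --max_processes 4
#   python compress_files.py safetensors --hf_cache --model ibm-granite/granite-7b-instruct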