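"""
Apply the delta weights on top of a base model.

Usage (sketch; the script file name and argument values are placeholders,
only the flags are defined by this script):
python3 apply_delta.py --base-model-path <base_model_dir> \
    --target-model-path <output_dir> \
    --delta-path <delta_dir_or_hub_repo_id> [--low-cpu-mem]
"""
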
import argparse
import gc
import glob
import json
import os
import shutil
import tempfile

from huggingface_hub import snapshot_download
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig

# 1 GiB in bytes; used as the unit for the shard size below.
GB = 1 << 30


def split_files(model_path, tmp_path, split_size):
    # If model_path is not a local directory, treat it as a Hugging Face Hub
    # repo id and download the checkpoint.
    if not os.path.exists(model_path):
        model_path = snapshot_download(repo_id=model_path)
    if not os.path.exists(tmp_path):
        os.makedirs(tmp_path)

    file_pattern = os.path.join(model_path, "pytorch_model-*.bin")
    files = glob.glob(file_pattern)

    part = 0
    try:
        for file_path in tqdm(files):
            state_dict = torch.load(file_path)
            new_state_dict = {}

            current_size = 0
            for name, param in state_dict.items():
                param_size = param.numel() * param.element_size()

                # Flush the current shard to disk before it would exceed
                # split_size.
                if current_size + param_size > split_size:
                    new_file_name = f"pytorch_model-{part}.bin"
                    new_file_path = os.path.join(tmp_path, new_file_name)
                    torch.save(new_state_dict, new_file_path)
                    current_size = 0
                    new_state_dict = None
                    gc.collect()
                    new_state_dict = {}
                    part += 1

                new_state_dict[name] = param
                current_size += param_size

            # Save the remainder of this checkpoint as the last shard.
            new_file_name = f"pytorch_model-{part}.bin"
            new_file_path = os.path.join(tmp_path, new_file_name)
            torch.save(new_state_dict, new_file_path)
            new_state_dict = None
            gc.collect()
            part += 1
    except Exception as e:
        print(f"An error occurred during split_files: {e}")
        shutil.rmtree(tmp_path)
        raise


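# apply_delta_low_cpu_mem trades disk for memory: both checkpoints are first
# re-sharded on disk into pieces no larger than split_size, and the merge then
# holds only one delta shard and one base shard in RAM at a time (see the
# --low-cpu-mem help text below).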
def apply_delta_low_cpu_mem(base_model_path, target_model_path, delta_path):
    delta_tokenizer = AutoTokenizer.from_pretrained(delta_path, use_fast=False)
    delta_config = AutoConfig.from_pretrained(delta_path)

    if os.path.exists(target_model_path):
        shutil.rmtree(target_model_path)
    os.makedirs(target_model_path)

    split_size = 4 * GB

    with tempfile.TemporaryDirectory() as tmp_base_path, tempfile.TemporaryDirectory() as tmp_delta_path:
        print(f"Split files for the base model to {tmp_base_path}")
        split_files(base_model_path, tmp_base_path, split_size)
        print(f"Split files for the delta weights to {tmp_delta_path}")
        split_files(delta_path, tmp_delta_path, split_size)

        base_pattern = os.path.join(tmp_base_path, "pytorch_model-*.bin")
        base_files = glob.glob(base_pattern)
        base_state_dict = torch.load(base_files[0])
        delta_pattern = os.path.join(tmp_delta_path, "pytorch_model-*.bin")
        delta_files = glob.glob(delta_pattern)

        print("Applying the delta")
        weight_map = {}
        total_size = 0

        for i, delta_file in tqdm(enumerate(delta_files), total=len(delta_files)):
            state_dict = torch.load(delta_file)
            file_name = f"pytorch_model-{i}.bin"
            for name, param in state_dict.items():
                # If the parameter is not in the base shard currently in
                # memory, scan the base shards until one contains it. This
                # assumes every delta parameter exists in some base shard.
                if name not in base_state_dict:
                    for base_file in base_files:
                        base_state_dict = torch.load(base_file)
                        gc.collect()
                        if name in base_state_dict:
                            break
                if state_dict[name].shape == base_state_dict[name].shape:
                    state_dict[name] += base_state_dict[name]
                else:
                    print(
                        f"Skipping {name}: shape mismatch between delta "
                        f"{tuple(state_dict[name].shape)} and base "
                        f"{tuple(base_state_dict[name].shape)}"
                    )
                weight_map[name] = file_name
                total_size += param.numel() * param.element_size()
            gc.collect()
            torch.save(state_dict, os.path.join(target_model_path, file_name))

        with open(
            os.path.join(target_model_path, "pytorch_model.bin.index.json"), "w"
        ) as f:
            json.dump(
                {"weight_map": weight_map, "metadata": {"total_size": total_size}}, f
            )

    print(f"Saving the target model to {target_model_path}")
    delta_tokenizer.save_pretrained(target_model_path)
    delta_config.save_pretrained(target_model_path)


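# apply_delta is the simple in-memory path: it loads the full delta and base
# models at once, so peak RAM is roughly the combined size of both fp16
# checkpoints.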
def apply_delta(base_model_path, target_model_path, delta_path):
    print(f"Loading the delta weights from {delta_path}")
    delta_tokenizer = AutoTokenizer.from_pretrained(delta_path, use_fast=False)
    delta = AutoModelForCausalLM.from_pretrained(
        delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
    )

    print(f"Loading the base model from {base_model_path}")
    base = AutoModelForCausalLM.from_pretrained(
        base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
    )

    print("Applying the delta")
    # Build the base state dict once instead of rebuilding it on every
    # iteration.
    base_state_dict = base.state_dict()
    for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
        assert name in base_state_dict
        if param.shape == base_state_dict[name].shape:
            param.data += base_state_dict[name]
        else:
            print(
                f"Skipping {name}: shape mismatch between delta "
                f"{tuple(param.shape)} and base "
                f"{tuple(base_state_dict[name].shape)}"
            )

    print(f"Saving the target model to {target_model_path}")
    delta.save_pretrained(target_model_path)
    delta_tokenizer.save_pretrained(target_model_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--base-model-path", type=str, required=True)
    parser.add_argument("--target-model-path", type=str, required=True)
    parser.add_argument("--delta-path", type=str, required=True)
    parser.add_argument(
        "--low-cpu-mem",
        action="store_true",
        help="Lower the CPU memory usage. This will split large files and use "
        "disk as swap to reduce the memory usage below 10GB.",
    )
    args = parser.parse_args()

    if args.low_cpu_mem:
        apply_delta_low_cpu_mem(
            args.base_model_path, args.target_model_path, args.delta_path
        )
    else:
        apply_delta(args.base_model_path, args.target_model_path, args.delta_path)