| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import torch |
| import torch.nn as nn |
| from peft import LoraConfig, inject_adapter_in_model |
| from safetensors import safe_open |
|
|
|
|
def freeze_parameters(model: nn.Module):
    """Turn off gradient tracking for ``model`` and leave it in training mode.

    Args:
        model: Module whose parameters are all set to ``requires_grad=False``.
            The module is modified in place; nothing is returned.
    """
    # Stop autograd from tracking any parameter of the module.
    model.requires_grad_(False)

    # NOTE(review): eval() is immediately followed by train(), so the module
    # ends up in training mode. Both calls are kept, in this order, in case a
    # submodule overrides them with side effects -- confirm whether eval() is
    # still needed.
    model.eval()
    model.train()
|
|
|
|
def add_lora_to_model(
    model: nn.Module,
    lora_rank=4,
    lora_alpha=4,
    lora_target_modules="q,k,v,o,ffn.0,ffn.2",
    init_lora_weights="kaiming",
    pretrained_lora_path=None,
    state_dict_converter=None,
    lora_target_modules_support=None,
):
    """Inject LoRA adapters into ``model`` and optionally load LoRA weights.

    The function returns nothing: callers keep using the object they passed
    in (PEFT's ``inject_adapter_in_model`` is documented to modify the model
    in place).

    Args:
        model: Module that receives the adapters. A ``lora_alpha`` attribute
            is also stored on it.
        lora_rank: Forwarded to ``LoraConfig`` as ``r``.
        lora_alpha: Forwarded to ``LoraConfig`` as ``lora_alpha``.
        lora_target_modules: Comma-separated module names to adapt.
        init_lora_weights: Forwarded to ``LoraConfig``; the string
            ``"kaiming"`` is translated to ``True``.
        pretrained_lora_path: Optional checkpoint path. When given, its state
            dict is loaded into the model with ``strict=False``.
        state_dict_converter: Optional callable applied to the loaded state
            dict before it is handed to ``model.load_state_dict``.
        lora_target_modules_support: Optional collection of allowed target
            module names. ``None`` (the default) disables the check.

    Raises:
        ValueError: If ``lora_target_modules_support`` is given and one of
            the requested target modules is not in it.
    """
    model.lora_alpha = lora_alpha
    if init_lora_weights == "kaiming":
        init_lora_weights = True

    target_modules = lora_target_modules.split(",")

    # Only validate when a support list was supplied; the default of None
    # previously crashed with "argument of type 'NoneType' is not iterable".
    if lora_target_modules_support is not None:
        for lora_target_module in target_modules:
            if lora_target_module not in lora_target_modules_support:
                raise ValueError(f"lora_target_module {lora_target_module} not in lora_target_modules_support")

    lora_config = LoraConfig(
        r=lora_rank,
        lora_alpha=lora_alpha,
        init_lora_weights=init_lora_weights,
        target_modules=target_modules,
    )
    model = inject_adapter_in_model(lora_config, model)

    # Upcast every trainable parameter and every LoRA parameter to fp32.
    # Single pass replacing two consecutive loops that did the same cast.
    for name, param in model.named_parameters():
        if param.requires_grad or "lora" in name:
            param.data = param.data.to(dtype=torch.float32)

    if pretrained_lora_path is not None:
        state_dict = load_state_dict(pretrained_lora_path)
        if state_dict_converter is not None:
            state_dict = state_dict_converter(state_dict)
        _missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        # With strict=False every checkpoint entry is either loaded or
        # reported as unexpected, so this is the exact number of loaded
        # entries (counting named_parameters() ignores buffers and is off
        # whenever the model has any).
        num_unexpected_keys = len(unexpected_keys)
        num_updated_keys = len(state_dict) - num_unexpected_keys
        print(
            f"{num_updated_keys} parameters are loaded from {pretrained_lora_path}. {num_unexpected_keys} parameters are unexpected."
        )
|
|
|
|
def load_state_dict(file_path, torch_dtype=None):
    """Load a state dict from disk, picking the loader by file extension.

    Args:
        file_path: Checkpoint location, as a ``str`` or an ``os.PathLike``
            such as ``pathlib.Path``. Paths ending in ``.safetensors`` go to
            ``load_state_dict_from_safetensors``; everything else goes to
            ``load_state_dict_from_bin``.
        torch_dtype: Optional dtype forwarded to the selected loader.

    Returns:
        Whatever the selected loader returns.
    """
    # str() so that pathlib.Path objects, which have no .endswith, are
    # accepted too; the original object is still what gets passed on.
    if str(file_path).endswith(".safetensors"):
        return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
    return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
|
|
|
|
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
    """Read every tensor of a safetensors file into a dict, on CPU.

    Args:
        file_path: Path of the ``.safetensors`` file to open.
        torch_dtype: When not ``None``, each tensor is converted with
            ``.to(torch_dtype)`` before being stored.

    Returns:
        dict mapping each key in the file to its tensor.
    """
    tensors = {}
    with safe_open(file_path, framework="pt", device="cpu") as handle:
        for key in handle.keys():
            tensor = handle.get_tensor(key)
            tensors[key] = tensor if torch_dtype is None else tensor.to(torch_dtype)
    return tensors
|
|
|
|
def load_state_dict_from_bin(file_path, torch_dtype=None):
    """Load a ``torch.save`` checkpoint onto CPU, optionally casting tensors.

    Args:
        file_path: Path handed to ``torch.load`` (called with
            ``weights_only=True``).
        torch_dtype: When not ``None``, every ``torch.Tensor`` value of the
            loaded mapping is converted with ``.to(torch_dtype)``; values of
            other types are left as they are.

    Returns:
        The loaded mapping (updated in place when a dtype is given).
    """
    loaded = torch.load(file_path, map_location="cpu", weights_only=True)
    if torch_dtype is None:
        return loaded

    # Only existing keys are reassigned, so the mapping keeps its size and
    # type while being iterated.
    for key, value in loaded.items():
        if isinstance(value, torch.Tensor):
            loaded[key] = value.to(torch_dtype)
    return loaded
|
|