# Copyright 2025 Bytedance Ltd. and/or its affiliates # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import torch import torch.nn as nn from peft import LoraConfig, inject_adapter_in_model from safetensors import safe_open def freeze_parameters(model: nn.Module): # Freeze parameters model.requires_grad_(False) model.eval() model.train() def add_lora_to_model( model: nn.Module, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", pretrained_lora_path=None, state_dict_converter=None, lora_target_modules_support=None, ): model.lora_alpha = lora_alpha if init_lora_weights == "kaiming": init_lora_weights = True lora_config = LoraConfig( r=lora_rank, lora_alpha=lora_alpha, init_lora_weights=init_lora_weights, target_modules=lora_target_modules.split(","), ) for lora_target_module in lora_config.target_modules: if lora_target_module not in lora_target_modules_support: raise ValueError(f"lora_target_module {lora_target_module} not in lora_target_modules_support") model = inject_adapter_in_model(lora_config, model) for param in model.parameters(): if param.requires_grad: param.data = param.to(torch.float32) for name, param in model.named_parameters(): if "lora" in name: param.data = param.data.to(dtype=torch.float32) # Lora pretrained lora weights if pretrained_lora_path is not None: state_dict = load_state_dict(pretrained_lora_path) if state_dict_converter is not None: state_dict = state_dict_converter(state_dict) missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) all_keys = [i for i, _ in model.named_parameters()] num_updated_keys = len(all_keys) - len(missing_keys) num_unexpected_keys = len(unexpected_keys) print( f"{num_updated_keys} parameters are loaded from {pretrained_lora_path}. {num_unexpected_keys} parameters are unexpected." ) def load_state_dict(file_path, torch_dtype=None): if file_path.endswith(".safetensors"): return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype) else: return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype) def load_state_dict_from_safetensors(file_path, torch_dtype=None): state_dict = {} with safe_open(file_path, framework="pt", device="cpu") as f: for k in f.keys(): state_dict[k] = f.get_tensor(k) if torch_dtype is not None: state_dict[k] = state_dict[k].to(torch_dtype) return state_dict def load_state_dict_from_bin(file_path, torch_dtype=None): state_dict = torch.load(file_path, map_location="cpu", weights_only=True) if torch_dtype is not None: for i in state_dict: if isinstance(state_dict[i], torch.Tensor): state_dict[i] = state_dict[i].to(torch_dtype) return state_dict