lingbot-vla / lingbotvla /utils /lora_utils.py
bazaar-research's picture
Upload folder using huggingface_hub
fb11af9 verified
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from peft import LoraConfig, inject_adapter_in_model
from safetensors import safe_open
def freeze_parameters(model: nn.Module):
# Freeze parameters
model.requires_grad_(False)
model.eval()
model.train()
def add_lora_to_model(
model: nn.Module,
lora_rank=4,
lora_alpha=4,
lora_target_modules="q,k,v,o,ffn.0,ffn.2",
init_lora_weights="kaiming",
pretrained_lora_path=None,
state_dict_converter=None,
lora_target_modules_support=None,
):
model.lora_alpha = lora_alpha
if init_lora_weights == "kaiming":
init_lora_weights = True
lora_config = LoraConfig(
r=lora_rank,
lora_alpha=lora_alpha,
init_lora_weights=init_lora_weights,
target_modules=lora_target_modules.split(","),
)
for lora_target_module in lora_config.target_modules:
if lora_target_module not in lora_target_modules_support:
raise ValueError(f"lora_target_module {lora_target_module} not in lora_target_modules_support")
model = inject_adapter_in_model(lora_config, model)
for param in model.parameters():
if param.requires_grad:
param.data = param.to(torch.float32)
for name, param in model.named_parameters():
if "lora" in name:
param.data = param.data.to(dtype=torch.float32)
# Lora pretrained lora weights
if pretrained_lora_path is not None:
state_dict = load_state_dict(pretrained_lora_path)
if state_dict_converter is not None:
state_dict = state_dict_converter(state_dict)
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
all_keys = [i for i, _ in model.named_parameters()]
num_updated_keys = len(all_keys) - len(missing_keys)
num_unexpected_keys = len(unexpected_keys)
print(
f"{num_updated_keys} parameters are loaded from {pretrained_lora_path}. {num_unexpected_keys} parameters are unexpected."
)
def load_state_dict(file_path, torch_dtype=None):
if file_path.endswith(".safetensors"):
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
else:
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
state_dict = {}
with safe_open(file_path, framework="pt", device="cpu") as f:
for k in f.keys():
state_dict[k] = f.get_tensor(k)
if torch_dtype is not None:
state_dict[k] = state_dict[k].to(torch_dtype)
return state_dict
def load_state_dict_from_bin(file_path, torch_dtype=None):
state_dict = torch.load(file_path, map_location="cpu", weights_only=True)
if torch_dtype is not None:
for i in state_dict:
if isinstance(state_dict[i], torch.Tensor):
state_dict[i] = state_dict[i].to(torch_dtype)
return state_dict