# R2SE_model / data_yaml / combine_lora_weight.py
# Uploaded by unknownuser6666 via huggingface_hub (commit 663494c, verified)
import numpy as np
import torch
from collections import defaultdict
import re
def extract_base_key_from_lora_key_fixed(key: str, state_dict_keys: set) -> "str | None":
    """Map a LoRA ``A``/``B`` weight key to the base weight key it augments.

    Tries a series of known LoRA naming patterns in priority order and
    returns the corresponding base key only if that key actually exists in
    ``state_dict_keys``.

    Args:
        key: A state-dict key, possibly a LoRA weight
            (``....A.weight`` / ``....B.weight``).
        state_dict_keys: Full set of keys in the state dict, used to
            validate that a candidate base key really exists.

    Returns:
        The matching base key, or ``None`` when ``key`` is not a LoRA
        weight key or no existing base weight could be resolved.
    """
    if not (key.endswith(".A.weight") or key.endswith(".B.weight")):
        return None
    # Case 0: ...heads_lora.lora_{name}.lora_{idx}.A/B.weight → ...heads.{name}.{idx}.weight
    m = re.match(r"^(.*)\.heads_lora\.lora_([a-zA-Z0-9_]+)\.lora_(\d+)\.[AB]\.weight$", key)
    if m:
        prefix, name, idx = m.group(1), m.group(2), m.group(3)
        candidate = f"{prefix}.heads.{name}.{idx}.weight"
        return candidate if candidate in state_dict_keys else None
    # Case 1: .lora_to_* → .to_*
    if ".lora_to_" in key:
        candidate = key.replace("lora_to_", "to_").replace(".A.weight", ".weight").replace(".B.weight", ".weight")
        # BUGFIX: this branch previously returned `candidate` unconditionally,
        # unlike every other case. A non-existent base key then made
        # merge_lora_into_state_dict() skip the merge with a warning while
        # remove_lora_keys() still deleted the A/B pair — silently dropping
        # LoRA weights. Validate existence like all other cases.
        return candidate if candidate in state_dict_keys else None
    # Case 2: ...lora_layers.lora_X.lora_Y.A/B.weight → ...layers.X.Y.weight
    m = re.match(r"(.*)\.lora_layers\.lora_(\d+)\.lora_(\d+)\.[AB]\.weight", key)
    if m:
        prefix, x, y = m.group(1), m.group(2), m.group(3)
        candidate = f"{prefix}.layers.{x}.{y}.weight"
        return candidate if candidate in state_dict_keys else None
    # Case 3: *_lora.lora_{i}.A/B.weight → *.{i}.weight
    m = re.match(r"(.*)_lora\.lora_(\d+)\.[AB]\.weight", key)
    if m:
        candidate = f"{m.group(1)}.{m.group(2)}.weight"
        return candidate if candidate in state_dict_keys else None
    # Case 4: *.lora_layers.lora_X.A/B.weight → *.X.weight
    # (the last component of the prefix is dropped alongside "lora_layers")
    m = re.match(r"(.*)\.lora_layers\.lora_(\d+)\.[AB]\.weight", key)
    if m:
        root = m.group(1).rsplit(".", 1)[0]
        idx = m.group(2)
        candidate = f"{root}.{idx}.weight"
        return candidate if candidate in state_dict_keys else None
    # Case 5: *_lora.A/B.weight → *.weight
    # NOTE: str.replace removes every "_lora" occurrence, not just the last;
    # acceptable for the naming schemes handled here.
    if "_lora" in key:
        candidate = key.replace("_lora", "").replace(".A.weight", ".weight").replace(".B.weight", ".weight")
        return candidate if candidate in state_dict_keys else None
    # Case 7: *.lora_linearX.A/B.weight → *.linearX.weight
    m = re.match(r"(.*)\.lora_linear(\d+)\.[AB]\.weight", key)
    if m:
        prefix, layer = m.group(1), m.group(2)
        candidate = f"{prefix}.linear{layer}.weight"
        return candidate if candidate in state_dict_keys else None
    # Case 8: general lora_{name}.A/B.weight → {name}.weight
    m = re.match(r"(.*)\.lora_([a-zA-Z0-9_]+)\.[AB]\.weight", key)
    if m:
        candidate = f"{m.group(1)}.{m.group(2)}.weight"
        return candidate if candidate in state_dict_keys else None
    return None
def find_lora_matches_fixed(state_dict):
    """Pair up LoRA ``A``/``B`` keys with the base weight each one modifies.

    Args:
        state_dict: Model state dict (only its keys are inspected).

    Returns:
        Dict mapping each resolvable base key to an ``(A_key, B_key)``
        tuple; incomplete pairs (missing either half) are dropped.
    """
    all_keys = set(state_dict)
    pairs = defaultdict(dict)
    for lora_key in state_dict:
        base = extract_base_key_from_lora_key_fixed(lora_key, all_keys)
        if not base:
            continue
        side = "B" if lora_key.endswith(".B.weight") else "A"
        pairs[base][side] = lora_key
    # Keep only complete A/B pairs.
    return {
        base: (sides["A"], sides["B"])
        for base, sides in pairs.items()
        if "A" in sides and "B" in sides
    }
def merge_lora_into_state_dict(state_dict, lora_matches, alpha=1.0):
    """Fold each matched LoRA pair into its base weight, in place.

    The applied update is ``(B @ A) * alpha / rank`` where ``rank`` is the
    last dimension of ``B`` (the LoRA rank). Base keys absent from the
    state dict are skipped with a warning.

    Args:
        state_dict: Mutable mapping of parameter name → tensor.
        lora_matches: Mapping base_key → (A_key, B_key), as produced by
            ``find_lora_matches_fixed``.
        alpha: LoRA scaling factor.
    """
    for base_key, (a_key, b_key) in lora_matches.items():
        if base_key not in state_dict:
            print(f"[WARN] Base key not found: {base_key}")
            continue
        a_mat = state_dict[a_key]
        b_mat = state_dict[b_key]
        rank = b_mat.shape[-1]
        update = torch.matmul(b_mat, a_mat) * (alpha / rank)
        print(f"[INFO] Merged: {base_key} ← {a_key}, {b_key}")
        state_dict[base_key] += update
def remove_lora_keys(state_dict, lora_matches):
    """Delete every merged LoRA ``A``/``B`` entry from the state dict.

    Mutates ``state_dict`` in place and returns it for convenience.
    """
    doomed = {k for pair in lora_matches.values() for k in pair}
    for lora_key in doomed:
        # pop with default: key may legitimately be absent already.
        state_dict.pop(lora_key, None)
    return state_dict
if __name__=='__main__':
    from copy import deepcopy
    # Load a full checkpoint. The path is a placeholder ('/xxx.pth') — edit
    # it before running; presumably a torch checkpoint dict with a
    # 'state_dict' entry (TODO confirm against the producing training code).
    model = torch.load('/xxx.pth')
    state_dict = model['state_dict']
    # Pipeline: find LoRA A/B pairs, fold them into their base weights,
    # then strip the now-redundant LoRA entries so the saved checkpoint
    # contains plain merged weights only.
    lora_matches = find_lora_matches_fixed(state_dict)
    merge_lora_into_state_dict(state_dict, lora_matches)
    state_dict = remove_lora_keys(state_dict, lora_matches)
    # Print the surviving keys/shapes as a quick sanity check.
    for k, v in state_dict.items():
        print(k, v.shape)
    # NOTE(review): deepcopy looks unnecessary here (state_dict was mutated
    # in place and is re-assigned into the same model dict) — kept as-is.
    model['state_dict'] = deepcopy(state_dict)
    torch.save(model, 'xxx_merged.pth')