import re
from collections import defaultdict
from copy import deepcopy
from typing import Optional

import torch


def extract_base_key_from_lora_key_fixed(key: str, state_dict_keys: set) -> Optional[str]:
    """Map a LoRA ``A``/``B`` weight key to the base weight key it patches.

    Every candidate is checked against ``state_dict_keys``, so a pattern match
    never returns a key that does not actually exist in the checkpoint.
    """
    if not (key.endswith(".A.weight") or key.endswith(".B.weight")):
        return None

    # Case 1: *.heads_lora.lora_{name}.lora_{i}.A/B.weight -> *.heads.{name}.{i}.weight
    m = re.match(r"^(.*)\.heads_lora\.lora_([a-zA-Z0-9_]+)\.lora_(\d+)\.[AB]\.weight$", key)
    if m:
        prefix, name, idx = m.group(1), m.group(2), m.group(3)
        candidate = f"{prefix}.heads.{name}.{idx}.weight"
        return candidate if candidate in state_dict_keys else None

    # Case 2: .lora_to_* -> .to_* (attention projections such as to_q/to_k/to_v)
    if ".lora_to_" in key:
        candidate = (
            key.replace("lora_to_", "to_")
            .replace(".A.weight", ".weight")
            .replace(".B.weight", ".weight")
        )
        # Validate against the checkpoint like every other case, instead of
        # returning an unchecked key.
        return candidate if candidate in state_dict_keys else None

    # Case 3: *.lora_layers.lora_X.lora_Y.A/B.weight -> *.layers.X.Y.weight
    m = re.match(r"(.*)\.lora_layers\.lora_(\d+)\.lora_(\d+)\.[AB]\.weight", key)
    if m:
        prefix, x, y = m.group(1), m.group(2), m.group(3)
        candidate = f"{prefix}.layers.{x}.{y}.weight"
        return candidate if candidate in state_dict_keys else None

    # Case 4: *_lora.lora_{i}.A/B.weight -> *.{i}.weight
    m = re.match(r"(.*)_lora\.lora_(\d+)\.[AB]\.weight", key)
    if m:
        candidate = f"{m.group(1)}.{m.group(2)}.weight"
        return candidate if candidate in state_dict_keys else None

    # Case 5: *.lora_layers.lora_X.A/B.weight -> *.X.weight
    # (drops the last component of the prefix, e.g. "a.b.lora_layers.lora_2" -> "a.2";
    # the membership check below keeps this safe when the guess is wrong)
    m = re.match(r"(.*)\.lora_layers\.lora_(\d+)\.[AB]\.weight", key)
    if m:
        root = m.group(1).rsplit(".", 1)[0]
        idx = m.group(2)
        candidate = f"{root}.{idx}.weight"
        return candidate if candidate in state_dict_keys else None

    # Case 6: *_lora.A/B.weight -> *.weight
    if "_lora" in key:
        candidate = (
            key.replace("_lora", "")
            .replace(".A.weight", ".weight")
            .replace(".B.weight", ".weight")
        )
        return candidate if candidate in state_dict_keys else None

    # Case 7: *.lora_linearX.A/B.weight -> *.linearX.weight
    m = re.match(r"(.*)\.lora_linear(\d+)\.[AB]\.weight", key)
    if m:
        prefix, layer = m.group(1), m.group(2)
        candidate = f"{prefix}.linear{layer}.weight"
        return candidate if candidate in state_dict_keys else None

    # Case 8: general fallback, lora_{name}.A/B.weight -> {name}.weight
    m = re.match(r"(.*)\.lora_([a-zA-Z0-9_]+)\.[AB]\.weight", key)
    if m:
        candidate = f"{m.group(1)}.{m.group(2)}.weight"
        return candidate if candidate in state_dict_keys else None

    return None


def find_lora_matches_fixed(state_dict):
    """Group LoRA keys by base weight; keep only complete (A, B) pairs."""
    state_dict_keys = set(state_dict.keys())
    lora_map = defaultdict(dict)
    for key in state_dict:
        base_key = extract_base_key_from_lora_key_fixed(key, state_dict_keys)
        if base_key:
            kind = "A" if key.endswith(".A.weight") else "B"
            lora_map[base_key][kind] = key
    return {
        base_key: (vals["A"], vals["B"])
        for base_key, vals in lora_map.items()
        if "A" in vals and "B" in vals
    }


def merge_lora_into_state_dict(state_dict, lora_matches, alpha=1.0):
    """Fold each LoRA pair into its base weight in place: W += B @ A * (alpha / rank)."""
    for base_key, (a_key, b_key) in lora_matches.items():
        if base_key not in state_dict:
            print(f"[WARN] Base key not found: {base_key}")
            continue
        A = state_dict[a_key]  # (rank, in_features)
        B = state_dict[b_key]  # (out_features, rank)
        # B.shape[-1] is the LoRA rank, giving the standard alpha / rank scaling.
        delta = (B @ A) * (alpha / B.shape[-1])
        print(f"[INFO] Merged: {base_key} ← {a_key}, {b_key}")
        # Cast so the in-place add cannot fail on a dtype mismatch (e.g. fp16 bases).
        state_dict[base_key] += delta.to(state_dict[base_key].dtype)


def remove_lora_keys(state_dict, lora_matches):
    """Delete the now-merged LoRA A/B tensors from the state dict."""
    keys_to_remove = set()
    for _, (a_key, b_key) in lora_matches.items():
        keys_to_remove.update([a_key, b_key])
    for k in keys_to_remove:
        if k in state_dict:
            del state_dict[k]
    return state_dict


if __name__ == "__main__":
    # Placeholder path; point this at the checkpoint to merge.
    model = torch.load('/xxx.pth', map_location="cpu")
    state_dict = model['state_dict']

    lora_matches = find_lora_matches_fixed(state_dict)
    merge_lora_into_state_dict(state_dict, lora_matches)
    state_dict = remove_lora_keys(state_dict, lora_matches)

    for k, v in state_dict.items():
        print(k, v.shape)
    # The merge already mutated state_dict in place; the deepcopy is purely
    # defensive before re-attaching it to the checkpoint dict.
    model['state_dict'] = deepcopy(state_dict)
    torch.save(model, 'xxx_merged.pth')
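

# ---------------------------------------------------------------------------
# Minimal sanity check, runnable without any checkpoint. The key names below
# are synthetic (hypothetical module paths, not taken from a real model) and
# exercise only Case 2 of the matcher; it verifies that the merged weight
# equals W + B @ A * (alpha / rank). Call _sanity_check() by hand, e.g. after
# importing this file, since the __main__ block above expects a real file.
# ---------------------------------------------------------------------------
def _sanity_check(alpha=1.0):
    torch.manual_seed(0)
    rank, dim = 4, 8
    W = torch.randn(dim, dim)
    A = torch.randn(rank, dim)  # down-projection: (rank, in_features)
    B = torch.randn(dim, rank)  # up-projection:   (out_features, rank)
    sd = {
        "attn.to_q.weight": W.clone(),
        "attn.lora_to_q.A.weight": A,
        "attn.lora_to_q.B.weight": B,
    }
    matches = find_lora_matches_fixed(sd)
    assert matches == {
        "attn.to_q.weight": ("attn.lora_to_q.A.weight", "attn.lora_to_q.B.weight")
    }
    merge_lora_into_state_dict(sd, matches, alpha=alpha)
    sd = remove_lora_keys(sd, matches)
    expected = W + (B @ A) * (alpha / rank)
    assert torch.allclose(sd["attn.to_q.weight"], expected)
    assert set(sd) == {"attn.to_q.weight"}  # LoRA keys are gone
    print("[OK] sanity check passed")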