|
|
import numpy as np |
|
|
import torch |
|
|
from collections import defaultdict |
|
|
import re |
|
|
|
|
|
def extract_base_key_from_lora_key_fixed(key: str, state_dict_keys: set) -> str:
    """Map a LoRA factor key (``*.A.weight`` / ``*.B.weight``) to the base weight key it adapts.

    Tries a series of naming conventions, most specific first, and returns the
    first candidate base key that actually exists in ``state_dict_keys``.

    Args:
        key: A parameter name from the state dict.
        state_dict_keys: Full set of state-dict keys, used to validate each
            candidate base key.

    Returns:
        The matching base key, or ``None`` if ``key`` is not a LoRA A/B weight
        or no existing base key matches.
    """
    # Only LoRA factor weights can be mapped to a base weight.
    if not (key.endswith(".A.weight") or key.endswith(".B.weight")):
        return None

    # <prefix>.heads_lora.lora_<name>.lora_<idx>.[AB].weight
    #   -> <prefix>.heads.<name>.<idx>.weight
    m = re.match(r"^(.*)\.heads_lora\.lora_([a-zA-Z0-9_]+)\.lora_(\d+)\.[AB]\.weight$", key)
    if m:
        prefix, name, idx = m.group(1), m.group(2), m.group(3)
        candidate = f"{prefix}.heads.{name}.{idx}.weight"
        return candidate if candidate in state_dict_keys else None

    # <prefix>.lora_to_<x>.[AB].weight -> <prefix>.to_<x>.weight
    # BUG FIX: this branch previously returned its candidate without checking
    # that it exists, unlike every other branch; now validated consistently.
    if ".lora_to_" in key:
        candidate = key.replace("lora_to_", "to_").replace(".A.weight", ".weight").replace(".B.weight", ".weight")
        return candidate if candidate in state_dict_keys else None

    # <prefix>.lora_layers.lora_<x>.lora_<y>.[AB].weight -> <prefix>.layers.<x>.<y>.weight
    m = re.match(r"(.*)\.lora_layers\.lora_(\d+)\.lora_(\d+)\.[AB]\.weight", key)
    if m:
        prefix, x, y = m.group(1), m.group(2), m.group(3)
        candidate = f"{prefix}.layers.{x}.{y}.weight"
        return candidate if candidate in state_dict_keys else None

    # <something>_lora.lora_<idx>.[AB].weight -> <something>.<idx>.weight
    m = re.match(r"(.*)_lora\.lora_(\d+)\.[AB]\.weight", key)
    if m:
        candidate = f"{m.group(1)}.{m.group(2)}.weight"
        return candidate if candidate in state_dict_keys else None

    # <root>.<leaf>.lora_layers.lora_<idx>.[AB].weight -> <root>.<idx>.weight
    # (the component immediately before ".lora_layers" is dropped)
    m = re.match(r"(.*)\.lora_layers\.lora_(\d+)\.[AB]\.weight", key)
    if m:
        root = m.group(1).rsplit(".", 1)[0]
        idx = m.group(2)
        candidate = f"{root}.{idx}.weight"
        return candidate if candidate in state_dict_keys else None

    # Generic suffix convention: strip a literal "_lora" anywhere in the key.
    if "_lora" in key:
        candidate = key.replace("_lora", "").replace(".A.weight", ".weight").replace(".B.weight", ".weight")
        return candidate if candidate in state_dict_keys else None

    # <prefix>.lora_linear<N>.[AB].weight -> <prefix>.linear<N>.weight
    m = re.match(r"(.*)\.lora_linear(\d+)\.[AB]\.weight", key)
    if m:
        prefix, layer = m.group(1), m.group(2)
        candidate = f"{prefix}.linear{layer}.weight"
        return candidate if candidate in state_dict_keys else None

    # Fallback: <prefix>.lora_<name>.[AB].weight -> <prefix>.<name>.weight
    m = re.match(r"(.*)\.lora_([a-zA-Z0-9_]+)\.[AB]\.weight", key)
    if m:
        candidate = f"{m.group(1)}.{m.group(2)}.weight"
        return candidate if candidate in state_dict_keys else None

    # No known convention matched.
    return None
|
|
|
|
|
|
|
|
def find_lora_matches_fixed(state_dict):
    """Pair up LoRA A/B factor keys with the base weight they target.

    Returns:
        Dict mapping ``base_key -> (A_key, B_key)``, containing only base
        weights for which both factors were found.
    """
    all_keys = set(state_dict)
    pairs = defaultdict(dict)

    for lora_key in state_dict:
        target = extract_base_key_from_lora_key_fixed(lora_key, all_keys)
        if not target:
            continue
        # Which factor this key holds is encoded in its suffix.
        side = "A" if lora_key.endswith(".A.weight") else "B"
        pairs[target][side] = lora_key

    # Keep only complete (A, B) pairs.
    matched = {}
    for target, sides in pairs.items():
        if "A" in sides and "B" in sides:
            matched[target] = (sides["A"], sides["B"])
    return matched
|
|
|
|
|
|
|
|
def merge_lora_into_state_dict(state_dict, lora_matches, alpha=1.0):
    """Fold LoRA updates into their base weights, in place.

    For each matched pair, adds ``(B @ A) * (alpha / rank)`` to the base
    weight. The rank is read from ``B.shape[-1]``, i.e. B is assumed to be
    ``(out_features, rank)`` and A ``(rank, in_features)`` — TODO confirm
    against the training code that produced the checkpoint.

    Args:
        state_dict: Mapping of parameter name -> tensor; mutated in place.
        lora_matches: Mapping of ``base_key -> (A_key, B_key)`` as produced
            by ``find_lora_matches_fixed``.
        alpha: LoRA scaling factor.
    """
    for base_key, (a_key, b_key) in lora_matches.items():
        if base_key not in state_dict:
            # Matches are derived from the same state dict, so this should not
            # happen; warn and skip rather than crash.
            print(f"[WARN] Base key not found: {base_key}")
            continue

        A = state_dict[a_key]
        B = state_dict[b_key]
        delta = (B @ A) * (alpha / B.shape[-1])

        # BUG FIX: the log message contained a mojibake character ("β",
        # evidently a garbled arrow); replaced with a plain "<-".
        print(f"[INFO] Merged: {base_key} <- {a_key}, {b_key}")
        state_dict[base_key] += delta
|
|
|
|
|
|
|
|
def remove_lora_keys(state_dict, lora_matches):
    """Delete every matched LoRA A/B factor key from the state dict.

    Args:
        state_dict: Mapping of parameter name -> tensor; mutated in place.
        lora_matches: Mapping of ``base_key -> (A_key, B_key)``.

    Returns:
        The same ``state_dict`` object, for chaining.
    """
    # Collect all factor keys first, then drop them (tolerating absentees).
    doomed = {factor_key for pair in lora_matches.values() for factor_key in pair}
    for name in doomed:
        state_dict.pop(name, None)
    return state_dict
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    from copy import deepcopy

    # Load the checkpoint and merge every detected LoRA pair into its base
    # weight, then strip the now-redundant LoRA factor keys.
    checkpoint = torch.load('/xxx.pth')
    weights = checkpoint['state_dict']

    matches = find_lora_matches_fixed(weights)
    merge_lora_into_state_dict(weights, matches)
    weights = remove_lora_keys(weights, matches)

    # Dump the final parameter list for a quick visual sanity check.
    for name, tensor in weights.items():
        print(name, tensor.shape)
    checkpoint['state_dict'] = deepcopy(weights)

    torch.save(checkpoint, 'xxx_merged.pth')