增加lora载入时去除部分模块的正则表达式写法
Browse files- app.py +1 -1
- rwkv_lora.py +23 -2
app.py
CHANGED
|
@@ -17,7 +17,7 @@ parser.add_argument('--ckpt',type=str,default="rwkv-loramerge-0426-v2-4096-epoch
|
|
| 17 |
parser.add_argument('--model_path',type=str,default=None,help="local model path")
|
| 18 |
parser.add_argument('--lora', type=str, default=None, help='lora checkpoint path')
|
| 19 |
parser.add_argument('--lora_alpha', type=float, default=0, help='lora alpha')
|
| 20 |
-
parser.add_argument('--lora_layer_filter',type=str,default=None,help='layer filter. Default merge all layer. Example: "25-31"')
|
| 21 |
args = parser.parse_args()
|
| 22 |
os.environ["RWKV_JIT_ON"] = '1'
|
| 23 |
|
|
|
|
| 17 |
parser.add_argument('--model_path',type=str,default=None,help="local model path")
|
| 18 |
parser.add_argument('--lora', type=str, default=None, help='lora checkpoint path')
|
| 19 |
parser.add_argument('--lora_alpha', type=float, default=0, help='lora alpha')
|
| 20 |
+
parser.add_argument('--lora_layer_filter',type=str,default=None,help='layer filter. Default merge all layer. Example: "0.2*25-31"')
|
| 21 |
args = parser.parse_args()
|
| 22 |
os.environ["RWKV_JIT_ON"] = '1'
|
| 23 |
|
rwkv_lora.py
CHANGED
|
@@ -7,11 +7,21 @@ import types, gc, os, time, re
|
|
| 7 |
import torch
|
| 8 |
from torch.nn import functional as F
|
| 9 |
|
|
|
|
| 10 |
def get_filter_keys_and_merge_coef(layer_filter):
|
| 11 |
if layer_filter:
|
| 12 |
layers = []
|
| 13 |
layer_coef = {}
|
|
|
|
| 14 |
for layer in layer_filter.split(' '):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
if '*' in layer:
|
| 16 |
coef,_,layer = layer.partition('*')
|
| 17 |
coef = float(coef)
|
|
@@ -20,22 +30,31 @@ def get_filter_keys_and_merge_coef(layer_filter):
|
|
| 20 |
if layer.isdecimal():
|
| 21 |
layers.append(int(layer))
|
| 22 |
layer_coef[int(layer)]=coef
|
|
|
|
| 23 |
elif '-' in layer:
|
| 24 |
start,_,end = layer.partition('-')
|
| 25 |
start,end = int(start),int(end)
|
| 26 |
layers.extend(range(start,end+1))
|
| 27 |
for l in range(start,end+1):
|
| 28 |
layer_coef[l] = coef
|
|
|
|
| 29 |
else:
|
| 30 |
raise NotImplementedError("layer_filter Not implemented:",layer_filter)
|
| 31 |
layers = sorted(set(layers))
|
| 32 |
-
layer_prefixes = tuple(f"blocks.{l}." for l in layers)
|
| 33 |
def filter_keys(keys):
|
| 34 |
new_keys = []
|
| 35 |
for key in keys:
|
|
|
|
|
|
|
| 36 |
if key.startswith("blocks."): #过滤掉blocks开头,且不在允许范围内的权重
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
| 38 |
continue
|
|
|
|
|
|
|
| 39 |
new_keys.append(key)
|
| 40 |
return new_keys
|
| 41 |
def merge_coef(key):
|
|
@@ -59,6 +78,8 @@ def lora_merge(base_model,lora,lora_alpha,device="cuda",layer_filter=None,):
|
|
| 59 |
w_lora: Dict[str, torch.Tensor] = torch.load(lora, map_location='cpu')
|
| 60 |
# pdb.set_trace() #DEBUG
|
| 61 |
for k in filter_keys(w_lora.keys()): #处理time_mixing之类的融合
|
|
|
|
|
|
|
| 62 |
w[k] = w_lora[k]
|
| 63 |
output_w: typing.OrderedDict[str, torch.Tensor] = OrderedDict()
|
| 64 |
# merge LoRA weights
|
|
|
|
| 7 |
import torch
|
| 8 |
from torch.nn import functional as F
|
| 9 |
|
| 10 |
+
# valid_filter_pattern = r"(((\d+(\.\d+)?\*)?(\d+)(-\d+)?(/\S+)?|(/\S+))(\s+|$))+"
|
| 11 |
def get_filter_keys_and_merge_coef(layer_filter):
|
| 12 |
if layer_filter:
|
| 13 |
layers = []
|
| 14 |
layer_coef = {}
|
| 15 |
+
layer_remove_patterns = {}
|
| 16 |
for layer in layer_filter.split(' '):
|
| 17 |
+
if '/' in layer: #过滤pattern,需要写成正则表达式
|
| 18 |
+
layer,_,remove_pattern = layer.partition('/')
|
| 19 |
+
remove_pattern = re.compile(remove_pattern)
|
| 20 |
+
else:
|
| 21 |
+
remove_pattern = None
|
| 22 |
+
if layer=='':
|
| 23 |
+
layer_remove_patterns['global']=remove_pattern
|
| 24 |
+
continue
|
| 25 |
if '*' in layer:
|
| 26 |
coef,_,layer = layer.partition('*')
|
| 27 |
coef = float(coef)
|
|
|
|
| 30 |
if layer.isdecimal():
|
| 31 |
layers.append(int(layer))
|
| 32 |
layer_coef[int(layer)]=coef
|
| 33 |
+
layer_remove_patterns[int(layer)]=remove_pattern
|
| 34 |
elif '-' in layer:
|
| 35 |
start,_,end = layer.partition('-')
|
| 36 |
start,end = int(start),int(end)
|
| 37 |
layers.extend(range(start,end+1))
|
| 38 |
for l in range(start,end+1):
|
| 39 |
layer_coef[l] = coef
|
| 40 |
+
layer_remove_patterns[l]=remove_pattern
|
| 41 |
else:
|
| 42 |
raise NotImplementedError("layer_filter Not implemented:",layer_filter)
|
| 43 |
layers = sorted(set(layers))
|
| 44 |
+
# layer_prefixes = tuple(f"blocks.{l}." for l in layers)
|
| 45 |
def filter_keys(keys):
|
| 46 |
new_keys = []
|
| 47 |
for key in keys:
|
| 48 |
+
if layer_remove_patterns.get("global") and layer_remove_patterns['global'].search(key):
|
| 49 |
+
continue #符合全局去除规则
|
| 50 |
if key.startswith("blocks."): #过滤掉blocks开头,且不在允许范围内的权重
|
| 51 |
+
l = int(key.split('.')[1])
|
| 52 |
+
if l not in layers: #不在允许层,过滤掉
|
| 53 |
+
continue
|
| 54 |
+
if layer_remove_patterns[l] and layer_remove_patterns[l].search(key): #符合对应层的去除规则,过滤掉
|
| 55 |
continue
|
| 56 |
+
# if not key.startswith(layer_prefixes):
|
| 57 |
+
# continue
|
| 58 |
new_keys.append(key)
|
| 59 |
return new_keys
|
| 60 |
def merge_coef(key):
|
|
|
|
| 78 |
w_lora: Dict[str, torch.Tensor] = torch.load(lora, map_location='cpu')
|
| 79 |
# pdb.set_trace() #DEBUG
|
| 80 |
for k in filter_keys(w_lora.keys()): #处理time_mixing之类的融合
|
| 81 |
+
if k in w:
|
| 82 |
+
print(f"replacing {k}")
|
| 83 |
w[k] = w_lora[k]
|
| 84 |
output_w: typing.OrderedDict[str, torch.Tensor] = OrderedDict()
|
| 85 |
# merge LoRA weights
|