litwell commited on
Commit
5d224a4
·
verified ·
1 Parent(s): 918e37b

Upload models/src/merge_lora_weights.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. models/src/merge_lora_weights.py +22 -0
models/src/merge_lora_weights.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from utils import get_model_name_from_path, load_pretrained_model
3
+
4
+ def merge_lora(args):
5
+ model_name = get_model_name_from_path(args.model_path)
6
+ processor, model = load_pretrained_model(model_path=args.model_path, model_base=args.model_base,
7
+ model_name=model_name, device_map='cpu')
8
+
9
+ model.save_pretrained(args.save_model_path, safe_serialization=args.safe_serialization)
10
+ processor.save_pretrained(args.save_model_path)
11
+
12
+
13
+ if __name__ == "__main__":
14
+ parser = argparse.ArgumentParser()
15
+ parser.add_argument("--model-path", type=str, required=True)
16
+ parser.add_argument("--model-base", type=str, required=True)
17
+ parser.add_argument("--save-model-path", type=str, required=True)
18
+ parser.add_argument("--safe-serialization", action='store_true')
19
+
20
+ args = parser.parse_args()
21
+
22
+ merge_lora(args)