Bopalv committed on
Commit
d5ee486
·
verified ·
1 Parent(s): 6917d84

Upload DPO-Training/merge_lora.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. DPO-Training/merge_lora.py +89 -0
DPO-Training/merge_lora.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Merge LoRA adapters with base model after DPO training.
4
+
5
+ Usage:
6
+ python merge_lora.py --lora_path ./qwen3-0.6b-dpo --output_path ./qwen3-0.6b-dpo-merged
7
+ """
8
+
9
+ import argparse
10
+ import torch
11
+ from transformers import AutoModelForCausalLM, AutoTokenizer
12
+ from peft import PeftModel
13
+ import os
14
+
15
+
16
def main():
    """Merge a LoRA adapter into its base model and write the merged weights.

    Command-line options (all optional, read from ``sys.argv``):
        --base_model   Base model name or path.
        --lora_path    Path to the LoRA adapters.
        --output_path  Directory that receives the merged model and tokenizer.
        --device       Value forwarded as ``device_map`` when loading the base
                       model.
        --push_to_hub  Repo name; when set, the merged model and tokenizer are
                       also pushed to the HuggingFace Hub.

    The tokenizer is loaded from ``--base_model`` (not from the adapter
    directory). The base model is loaded in float16 when CUDA is available and
    float32 otherwise.
    """
    cli = argparse.ArgumentParser(description="Merge LoRA adapters with base model")
    # (flag, default, help) triples, registered in this order so the --help
    # listing is unchanged.
    option_table = (
        ("--base_model", "Qwen/Qwen3-0.6B", "Base model name or path"),
        ("--lora_path", "./qwen3-0.6b-dpo", "Path to LoRA adapters"),
        ("--output_path", "./qwen3-0.6b-dpo-merged", "Output path for merged model"),
        ("--device", "auto", "Device to use (auto/cuda/cpu)"),
        ("--push_to_hub", None, "Push to HuggingFace Hub (repo name)"),
    )
    for flag, fallback, description in option_table:
        cli.add_argument(flag, default=fallback, help=description)
    opts = cli.parse_args()

    rule = "=" * 60

    # Summary banner of what is about to be merged.
    print(rule)
    print("Merging LoRA Adapters")
    print(rule)
    print(f"Base model: {opts.base_model}")
    print(f"LoRA path: {opts.lora_path}")
    print(f"Output: {opts.output_path}")
    print(rule)

    # Tokenizer comes from the base checkpoint.
    print("\n📥 Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(opts.base_model, trust_remote_code=True)

    # Base weights: half precision only when a CUDA device is available.
    print("📥 Loading base model...")
    if torch.cuda.is_available():
        weight_dtype = torch.float16
    else:
        weight_dtype = torch.float32
    base_model = AutoModelForCausalLM.from_pretrained(
        opts.base_model,
        torch_dtype=weight_dtype,
        device_map=opts.device,
        trust_remote_code=True,
    )

    # Attach the adapter weights on top of the base model.
    print("📥 Loading LoRA adapters...")
    adapted = PeftModel.from_pretrained(base_model, opts.lora_path)

    # Fold the adapter into the base weights and drop the PEFT wrapper.
    print("🔧 Merging adapters...")
    merged = adapted.merge_and_unload()

    # Persist model and tokenizer side by side in the output directory.
    print(f"💾 Saving merged model to {opts.output_path}...")
    os.makedirs(opts.output_path, exist_ok=True)
    merged.save_pretrained(opts.output_path)
    tokenizer.save_pretrained(opts.output_path)

    print("\n" + rule)
    print("✅ Merge Complete!")
    print(rule)
    print(f"Merged model saved to: {opts.output_path}")

    # Optional upload; skipped when no repo name was given.
    hub_repo = opts.push_to_hub
    if hub_repo:
        print(f"\n📤 Pushing to HuggingFace Hub: {hub_repo}")
        merged.push_to_hub(hub_repo)
        tokenizer.push_to_hub(hub_repo)
        print(f"✅ Pushed to: https://huggingface.co/{hub_repo}")

    print("\n🎉 Done!")
86
+
87
+
88
# Run the merge only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()