# hr-eval-api-v2 / scripts/merge_lora.py
# KarenYYH
# Initial commit - HR Evaluation API v2
# c8b1f17
"""
LoRA模型合并脚本
将LoRA适配器合并到基础模型中
"""
import argparse
import gc
from pathlib import Path
from typing import List, Optional

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
def _release_memory(use_mps: bool) -> None:
    """Collect garbage and hand cached accelerator memory back to the allocator.

    Used between stages so that models dropped with ``del`` stop occupying
    GPU/MPS memory before the next model is loaded.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif use_mps and hasattr(torch, "mps") and hasattr(torch.mps, "empty_cache"):
        # torch.mps.empty_cache only exists in newer torch releases.
        torch.mps.empty_cache()


def merge_lora(
    base_model_path: str,
    lora_path: str,
    output_dir: str
) -> Path:
    """
    Merge a LoRA adapter into its base model and save the merged model.

    Loads the base model and tokenizer, attaches the adapter with PEFT, folds
    the adapter weights into the base weights with ``merge_and_unload()``,
    writes the merged model (safetensors) and the tokenizer to ``output_dir``,
    and finally reloads the saved model once as a best-effort sanity check.

    Args:
        base_model_path: Base model path or Hugging Face model name.
        lora_path: Path to the LoRA adapter.
        output_dir: Output directory; created (with parents) if missing.

    Returns:
        ``output_dir`` as a ``Path``.
    """
    print("="*50)
    print("合并LoRA模型")
    print("="*50)

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Device detection: prefer MPS (Apple Silicon), then CUDA, then CPU.
    use_mps = torch.backends.mps.is_available()
    device = "mps" if use_mps else ("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\n使用设备: {device}")
    torch_dtype = torch.bfloat16 if use_mps else torch.float16

    # Load the base model. device_map is not used on MPS; the model is moved
    # there explicitly instead.
    print(f"\n加载基础模型: {base_model_path}")
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_path,
        torch_dtype=torch_dtype,
        device_map=None if use_mps else "auto",
        trust_remote_code=True
    )
    if use_mps:
        base_model = base_model.to("mps")

    # The tokenizer is taken from the base model and saved alongside the
    # merged weights.
    print("加载tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(
        base_model_path,
        trust_remote_code=True
    )

    # Attach the LoRA adapter.
    print(f"加载LoRA适配器: {lora_path}")
    model = PeftModel.from_pretrained(
        base_model,
        lora_path
    )

    # Fold the adapter into the base weights and strip the PEFT wrapper.
    print("\n合并LoRA权重...")
    merged_model = model.merge_and_unload()

    # Save the merged model and the tokenizer.
    print(f"保存合并后的模型到: {output_dir}")
    merged_model.save_pretrained(
        output_dir,
        safe_serialization=True
    )
    tokenizer.save_pretrained(output_dir)

    # Summary information.
    print("\n✓ 模型合并完成!")
    print("\n合并后模型信息:")
    print(f" 路径: {output_dir}")
    size_gb = sum(
        f.stat().st_size for f in output_dir.rglob('*.safetensors') if f.is_file()
    ) / 1024**3
    print(f" 模型大小: {size_gb:.2f} GB")

    # Release the in-memory models before reloading from disk. Otherwise the
    # validation load below would need room for a second full copy and can
    # run out of GPU/MPS memory.
    del model, merged_model, base_model
    _release_memory(use_mps)

    # Best-effort validation: check that the saved model can be loaded.
    print("\n验证合并后的模型...")
    try:
        test_model = AutoModelForCausalLM.from_pretrained(
            output_dir,
            device_map=None if use_mps else "auto",
            torch_dtype=torch_dtype,
            trust_remote_code=True
        )
    except Exception as e:  # deliberately broad: report, never abort the merge
        print(f"⚠️ 模型验证失败: {e}")
    else:
        print("✓ 模型验证成功,可以正常加载")
        del test_model
    _release_memory(use_mps)

    return output_dir
def test_merged_model(
    model_path: str,
    test_prompts: Optional[List[str]] = None,
    max_new_tokens: int = 100,
):
    """Smoke-test a merged model by generating replies to a few chat prompts.

    Device and dtype selection mirrors ``merge_lora`` for MPS (bf16, moved
    explicitly) and CUDA (fp16, ``device_map="auto"``). On CPU it uses fp32,
    because half-precision inference there may be unsupported or very slow
    depending on the torch version.

    Args:
        model_path: Directory (or hub name) of the merged model.
        test_prompts: User prompts to try. Defaults to three built-in HR
            questions when ``None``; an empty list runs no generations.
        max_new_tokens: Generation length cap per prompt (default 100).
    """
    print("\n" + "="*50)
    print("测试合并后的模型")
    print("="*50)

    use_mps = torch.backends.mps.is_available()
    if use_mps:
        torch_dtype = torch.bfloat16
    elif torch.cuda.is_available():
        torch_dtype = torch.float16
    else:
        torch_dtype = torch.float32

    # Load the model and tokenizer.
    print(f"\n加载模型: {model_path}")
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map=None if use_mps else "auto",
        torch_dtype=torch_dtype,
        trust_remote_code=True
    )
    if use_mps:
        model = model.to("mps")
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        trust_remote_code=True
    )

    # Default test prompts.
    if test_prompts is None:
        test_prompts = [
            "我想申请培训",
            "请问年假怎么计算?",
            "我想查询社保缴纳情况"
        ]

    # Many causal-LM tokenizers define no pad token; fall back to EOS so that
    # generate() does not have to guess (and warn) on every call.
    pad_token_id = (
        tokenizer.pad_token_id
        if tokenizer.pad_token_id is not None
        else tokenizer.eos_token_id
    )

    print(f"\n测试提示数: {len(test_prompts)}")
    for idx, prompt in enumerate(test_prompts, 1):
        print(f"\n--- 测试 {idx} ---")
        print(f"输入: {prompt}")

        messages = [
            {"role": "system", "content": "你是一个专业的HR助手,请礼貌、准确地回答员工问题。"},
            {"role": "user", "content": prompt}
        ]
        # Render the chat template as text, then tokenize onto the model's device.
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        inputs = tokenizer(text, return_tensors="pt").to(model.device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=pad_token_id,
            )

        # Decode only the newly generated tokens (skip the echoed prompt).
        response = tokenizer.decode(
            outputs[0][inputs['input_ids'].shape[1]:],
            skip_special_tokens=True,
        )
        print(f"输出: {response}")

    print("\n✓ 测试完成")
def main():
    """Command-line entry point.

    Parses arguments, merges the LoRA adapter into the base model, optionally
    runs a generation smoke test, and prints a snippet showing how to load
    the result.
    """
    parser = argparse.ArgumentParser(description='合并LoRA模型')
    parser.add_argument(
        '--base_model',
        type=str,
        required=True,
        help='基础模型路径或名称'
    )
    parser.add_argument(
        '--lora_path',
        type=str,
        required=True,
        help='LoRA适配器路径'
    )
    parser.add_argument(
        '--output_dir',
        type=str,
        required=True,
        help='输出目录'
    )
    parser.add_argument(
        '--test',
        action='store_true',
        help='是否测试合并后的模型'
    )
    args = parser.parse_args()

    # Merge the model.
    output_path = merge_lora(
        base_model_path=args.base_model,
        lora_path=args.lora_path,
        output_dir=args.output_dir
    )

    # Optionally smoke-test the merged model.
    if args.test:
        test_merged_model(str(output_path))

    print("\n✓ 完成!")
    print("\n使用方法:")
    # The snippet uses both classes, so it must import both; previously
    # AutoTokenizer was missing and a copy-pasted snippet raised NameError.
    print(" from transformers import AutoModelForCausalLM, AutoTokenizer")
    print(f" model = AutoModelForCausalLM.from_pretrained('{args.output_dir}')")
    print(f" tokenizer = AutoTokenizer.from_pretrained('{args.output_dir}')")


if __name__ == '__main__':
    main()