Spaces:
Sleeping
Sleeping
"""
LoRA model merge script.

Merges a trained LoRA adapter into its base model and saves the result,
so the merged model can be loaded as a plain `transformers` checkpoint.
"""
import argparse
from pathlib import Path
from typing import Optional

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
def merge_lora(
    base_model_path: str,
    lora_path: str,
    output_dir: str
):
    """Merge a LoRA adapter into its base model and save the merged weights.

    Args:
        base_model_path: Path or hub name of the base model.
        lora_path: Path to the LoRA adapter directory.
        output_dir: Directory where the merged model is written (created if missing).

    Returns:
        Path: the output directory containing the merged model and tokenizer.
    """
    print("="*50)
    print("合并LoRA模型")
    print("="*50)

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Device preference: Apple MPS > CUDA > CPU.
    use_mps = torch.backends.mps.is_available()
    device = "mps" if use_mps else ("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\n使用设备: {device}")

    # bfloat16 on MPS, float16 elsewhere (float16 has known issues on MPS).
    torch_dtype = torch.bfloat16 if use_mps else torch.float16

    # Load the base model. device_map="auto" is not usable on MPS,
    # so in that case load on CPU and move the model manually.
    print(f"\n加载基础模型: {base_model_path}")
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_path,
        torch_dtype=torch_dtype,
        device_map=None if use_mps else "auto",
        trust_remote_code=True
    )
    if use_mps:
        base_model = base_model.to("mps")

    # Tokenizer is taken from the base model and re-saved alongside the merge.
    print("加载tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(
        base_model_path,
        trust_remote_code=True
    )

    # Attach the LoRA adapter on top of the base weights.
    print(f"加载LoRA适配器: {lora_path}")
    model = PeftModel.from_pretrained(
        base_model,
        lora_path
    )

    # Fold the adapter deltas into the base weights and drop the PEFT wrapper.
    print("\n合并LoRA权重...")
    merged_model = model.merge_and_unload()

    # Persist merged weights (safetensors) plus the tokenizer files.
    print(f"保存合并后的模型到: {output_dir}")
    merged_model.save_pretrained(
        output_dir,
        safe_serialization=True
    )
    tokenizer.save_pretrained(output_dir)

    # Report the on-disk size of the saved safetensors shards.
    print("\n✓ 模型合并完成!")
    print("\n合并后模型信息:")
    print(f" 路径: {output_dir}")
    size_gb = sum(
        f.stat().st_size for f in output_dir.rglob('*.safetensors') if f.is_file()
    ) / 1024**3
    print(f" 模型大小: {size_gb:.2f} GB")

    # Sanity check: reload the saved artifacts to prove they are loadable.
    print("\n验证合并后的模型...")
    try:
        reloaded = AutoModelForCausalLM.from_pretrained(
            output_dir,
            device_map=None if use_mps else "auto",
            torch_dtype=torch_dtype,
            trust_remote_code=True
        )
        # Only the successful load matters; release the copy immediately
        # instead of keeping a multi-GB model alive until function exit.
        del reloaded
        print("✓ 模型验证成功,可以正常加载")
    except Exception as e:
        # Best-effort check: report but don't fail the merge itself.
        print(f"⚠️ 模型验证失败: {e}")

    return output_dir
def test_merged_model(model_path: str, test_prompts: Optional[list] = None):
    """Smoke-test a merged model by generating replies to a few prompts.

    Args:
        model_path: Directory of the merged model.
        test_prompts: Prompts to try; defaults to three built-in HR questions.
    """
    print("\n" + "="*50)
    print("测试合并后的模型")
    print("="*50)

    # Load the merged checkpoint and its tokenizer.
    # (removed redundant function-local `import torch`; it is imported at module level)
    print(f"\n加载模型: {model_path}")
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map="auto",
        torch_dtype=torch.float16,
        trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        trust_remote_code=True
    )

    # Default prompts cover typical HR-assistant requests.
    if test_prompts is None:
        test_prompts = [
            "我想申请培训",
            "请问年假怎么计算?",
            "我想查询社保缴纳情况"
        ]
    print(f"\n测试提示数: {len(test_prompts)}")

    for idx, prompt in enumerate(test_prompts, 1):
        print(f"\n--- 测试 {idx} ---")
        print(f"输入: {prompt}")

        # Build a chat transcript and render it with the model's chat template.
        messages = [
            {"role": "system", "content": "你是一个专业的HR助手,请礼貌、准确地回答员工问题。"},
            {"role": "user", "content": prompt}
        ]
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        inputs = tokenizer(text, return_tensors="pt").to(model.device)

        # Sampled generation; inference only, so disable gradient tracking.
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=100,
                temperature=0.7,
                top_p=0.9,
                do_sample=True
            )

        # Decode only the newly generated tokens, skipping the echoed prompt.
        response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
        print(f"输出: {response}")

    print("\n✓ 测试完成")
def main():
    """CLI entry point: parse arguments, merge the model, optionally test it."""
    parser = argparse.ArgumentParser(description='合并LoRA模型')
    parser.add_argument(
        '--base_model',
        type=str,
        required=True,
        help='基础模型路径或名称'
    )
    parser.add_argument(
        '--lora_path',
        type=str,
        required=True,
        help='LoRA适配器路径'
    )
    parser.add_argument(
        '--output_dir',
        type=str,
        required=True,
        help='输出目录'
    )
    parser.add_argument(
        '--test',
        action='store_true',
        help='是否测试合并后的模型'
    )
    args = parser.parse_args()

    # Merge the adapter into the base model.
    output_path = merge_lora(
        base_model_path=args.base_model,
        lora_path=args.lora_path,
        output_dir=args.output_dir
    )

    # Optional generation smoke test against the merged checkpoint.
    if args.test:
        test_merged_model(str(output_path))

    # Print a ready-to-copy usage snippet for the merged model.
    print("\n✓ 完成!")
    print("\n使用方法:")
    print(" from transformers import AutoModelForCausalLM")
    print(f" model = AutoModelForCausalLM.from_pretrained('{args.output_dir}')")
    print(f" tokenizer = AutoTokenizer.from_pretrained('{args.output_dir}')")
# Script entry point: run the merge CLI when executed directly.
if __name__ == '__main__':
    main()