import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer

# 1. Load the model and tokenizer.
# Qwen2.5-0.5B is used because its small size makes CPU training feasible.
model_name = "Qwen/Qwen2.5-0.5B"
# NOTE(review): trust_remote_code=True allows execution of code shipped with
# the checkpoint; confirm it is actually required for this model.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)
# The generation KV cache is not needed while training.
model.config.use_cache = False

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# Fall back to the EOS token for padding when the tokenizer defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# 2. Load and prepare the dataset.
dataset = load_dataset("json", data_files="data.json", split="train")


def convert_to_text(examples):
    """Render a batch of instruction records into single training strings.

    Args:
        examples: Batched mapping with parallel ``instruction``, ``input`` and
            ``output`` columns (lists of equal length), as supplied by
            ``Dataset.map(..., batched=True)``.

    Returns:
        A dict with one ``text`` column; each entry is the prompt template
        filled with the record's fields and terminated by the tokenizer's
        EOS token.
    """
    # Hoist the loop-invariant attribute lookup.
    eos_token = tokenizer.eos_token
    texts = [
        f"### Instruction:\n{instruction}\n\n"
        f"### Input:\n{input_text}\n\n"
        f"### Response:\n{output}{eos_token}"
        for instruction, input_text, output in zip(
            examples["instruction"], examples["input"], examples["output"]
        )
    ]
    return {"text": texts}


# Replace the original columns with the single rendered "text" column.
dataset = dataset.map(
    convert_to_text,
    batched=True,
    remove_columns=dataset.column_names,
)

# 3. Configure LoRA.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM",
    # Attention and MLP projection modules of the Qwen architecture.
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
)

# 4. Wrap the base model with the LoRA adapters (done exactly once, here).
model = get_peft_model(model, lora_config)

# 5. Configure training.
output_dir = "./qwen2.5-0.5b-lora"
training_args = TrainingArguments(
    output_dir=output_dir,
    # Effective batch size = 2 * 4 = 8 sequences per optimizer step.
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    logging_steps=10,
    max_steps=200,
    save_strategy="steps",
    save_steps=50,
    dataloader_num_workers=0,
    fp16=False,  # CPU training runs in full precision.
    report_to=[],  # Disable external experiment trackers.
    # remove_unused_columns is deliberately left at its default (True):
    # keeping the raw "text" string column in the batches can break the
    # default language-modeling collator.
    warmup_steps=20,
    weight_decay=0.01,
)

# 6. Create the SFTTrainer.
# The model is already a PeftModel, so peft_config is NOT passed again;
# supplying both would configure LoRA twice.
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    args=training_args,
)

# 7. Start training.
print("开始训练Qwen2.5-0.5B模型...")
trainer.train()

# 8. Save the trained LoRA adapter.
print("保存Qwen2.5-0.5B LoRA适配器...")
trainer.save_model(output_dir)
print(f"Qwen2.5-0.5B LoRA适配器已保存到 {output_dir}")