# fine_tune_improved.py
import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer
# 1. Load the model and tokenizer
# Qwen2.5-0.5B is small enough to fine-tune on CPU
model_name = "Qwen/Qwen2.5-0.5B"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)
model.config.use_cache = False  # the KV cache only helps generation; disable it for training
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# Qwen tokenizers may not define a pad token; reuse EOS as padding
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# 2. Load and prepare the dataset
dataset = load_dataset("json", data_files="data.json", split="train")
# Convert each record into a single prompt string, appending the EOS token
def convert_to_text(examples):
    texts = []
    for i in range(len(examples['instruction'])):
        text = (
            f"### Instruction:\n{examples['instruction'][i]}\n\n"
            f"### Input:\n{examples['input'][i]}\n\n"
            f"### Response:\n{examples['output'][i]}{tokenizer.eos_token}"
        )
        texts.append(text)
    return {"text": texts}
dataset = dataset.map(convert_to_text, batched=True, remove_columns=dataset.column_names)
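# Illustrative only: each record in data.json is expected to carry the three
# keys referenced above, e.g. (hypothetical content):
#   {"instruction": "Summarize the text.", "input": "LoRA freezes the base weights...", "output": "LoRA trains small low-rank adapters."}
# Quick sanity check that the formatting produced what we expect:
print(dataset[0]["text"][:300])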
# 3. Configure LoRA
lora_config = LoraConfig(
    r=16,  # increased rank
    lora_alpha=32,
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM",
    # Qwen's attention and MLP projection modules
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)
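# Standard LoRA arithmetic (a note added here, not from the original script):
# adapter updates are scaled by lora_alpha / r = 32 / 16 = 2.0 before being
# added to the frozen base weights.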
# 4. Wrap the base model with the LoRA adapters (PEFT)
model = get_peft_model(model, lora_config)
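# Optional: report how many parameters LoRA actually trains
# (print_trainable_parameters is part of peft's PeftModel API).
model.print_trainable_parameters()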
# 5. Configure the training arguments
output_dir = "./qwen2.5-0.5b-lora"
training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=2,  # slightly larger batch size
    gradient_accumulation_steps=4,  # fewer accumulation steps
    learning_rate=2e-4,  # adjusted learning rate
    logging_steps=10,
    max_steps=200,  # more training steps
    save_strategy="steps",
    save_steps=50,
    dataloader_num_workers=0,
    fp16=False,  # full precision for CPU training
    report_to=[],
    remove_unused_columns=False,
    warmup_steps=20,  # learning-rate warmup
    weight_decay=0.01,  # weight-decay regularization
)
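# Effective batch size per optimizer step: 2 (per device) * 4 (accumulation) = 8 sequences.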
# 6. Create the SFTTrainer
# Note: tokenizer= and dataset_text_field= assume an older trl release (the one
# implied by passing TrainingArguments here); newer trl versions expect
# processing_class= and an SFTConfig instead. peft_config is omitted because
# the model was already wrapped by get_peft_model above.
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    args=training_args,
    tokenizer=tokenizer,
    dataset_text_field="text",
)
# 7. Train
print("Starting training of Qwen2.5-0.5B...")
trainer.train()
# 8. Save the model
print("Saving the Qwen2.5-0.5B LoRA adapter...")
trainer.save_model(output_dir)
print(f"Qwen2.5-0.5B LoRA adapter saved to {output_dir}")