from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from peft import LoraConfig, get_peft_model, TaskType
from datasets import load_dataset
import torch
import os


def main():
    """Fine-tune Qwen2-1.5B-Instruct with LoRA on the medical QA dataset.

    Loads the base model and tokenizer, wraps the model with a PEFT/LoRA
    adapter, tokenizes the dataset, trains with periodic checkpointing,
    resumes from the newest ``checkpoint-*`` directory if one exists, and
    saves the final adapter (plus tokenizer) to ``./output``.
    """
    # Base model repository on the Hugging Face Hub.
    model_name = "dushuai112233/Qwen2-1.5B-Instruct"

    # Load tokenizer and model. Trainer handles device placement itself,
    # so no manual .to(device) is needed here.
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)

    # Attach a LoRA adapter: only the low-rank matrices on the listed
    # attention/MLP projections are trained; the base weights stay frozen.
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"],
        inference_mode=False,
        r=8,
        lora_alpha=32,
        lora_dropout=0.1,
    )
    model = get_peft_model(model, peft_config)

    # Load the dataset splits.
    ds = load_dataset("dushuai112233/medical")
    train_dataset = ds["train"]
    val_dataset = ds["validation"]

    def tokenize_function(examples):
        """Tokenize the 'question' field for causal-LM training.

        Labels mirror input_ids except at padding positions, which are
        set to -100 so the cross-entropy loss ignores them (HF convention).
        """
        encodings = tokenizer(
            examples["question"],
            padding="max_length",
            truncation=True,
            max_length=128,
        )
        pad_id = tokenizer.pad_token_id
        # NOTE(review): assumes the tokenizer defines a pad token — Qwen2's
        # does; verify if the base model is swapped.
        encodings["labels"] = [
            [(tok if tok != pad_id else -100) for tok in seq]
            for seq in encodings["input_ids"]
        ]
        return encodings

    train_dataset = train_dataset.map(tokenize_function, batched=True)
    val_dataset = val_dataset.map(tokenize_function, batched=True)

    # Training configuration: evaluate each epoch, checkpoint every 100
    # steps, keep at most 2 checkpoints on disk.
    training_args = TrainingArguments(
        output_dir="./output",
        evaluation_strategy="epoch",
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        logging_dir="./logs",
        logging_steps=10,
        save_steps=100,           # save a checkpoint every 100 steps
        save_total_limit=2,       # keep at most 2 checkpoints
        num_train_epochs=10,
        load_best_model_at_end=False,  # do not reload the best model at the end
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=tokenizer,
    )

    # Resume from the most recent checkpoint, if any. Only consider
    # "checkpoint-*" subdirectories: ./output may also contain the final
    # saved adapter files from a previous run, and passing one of those
    # to resume_from_checkpoint would crash.
    checkpoint = None
    output_dir = "./output"
    if os.path.isdir(output_dir):
        checkpoints = [
            os.path.join(output_dir, name)
            for name in os.listdir(output_dir)
            if name.startswith("checkpoint-")
            and os.path.isdir(os.path.join(output_dir, name))
        ]
        if checkpoints:
            checkpoint = max(checkpoints, key=os.path.getmtime)
            print(f"Resuming training from checkpoint: {checkpoint}")

    # Train (resume_from_checkpoint=None starts from scratch).
    trainer.train(resume_from_checkpoint=checkpoint)

    # Save the final LoRA adapter and the tokenizer so ./output is a
    # self-contained, reloadable artifact.
    model.save_pretrained("./output")
    tokenizer.save_pretrained("./output")


if __name__ == '__main__':
    main()