RianLi committed on
Commit dd099c7 · verified · 1 parent: 4334bc5

Delete fine_tune.py

Files changed (1):
  1. fine_tune.py +0 -77
fine_tune.py DELETED
@@ -1,77 +0,0 @@
- import torch
- from datasets import load_dataset
- from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
- from trl import SFTTrainer
-
- # 1. Load the model and tokenizer (CPU-optimized version)
- # Use a smaller model to fit a CPU-only environment
- model_name = "microsoft/DialoGPT-small"  # smaller model, suitable for CPU training
-
- # No quantization config is needed in a CPU environment
- model = AutoModelForCausalLM.from_pretrained(
-     model_name,
-     torch_dtype=torch.float32,  # use float32 on CPU
-     low_cpu_mem_usage=True,  # reduce peak memory while loading
- )
- model.config.use_cache = False
-
- tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
- tokenizer.pad_token = tokenizer.eos_token  # set pad token
-
- # 2. Load and prepare the dataset
- def formatting_prompts_func(example):
-     output_texts = []
-     for i in range(len(example['instruction'])):
-         text = f"### Instruction:\n{example['instruction'][i]}\n\n### Input:\n{example['input'][i]}\n\n### Response:\n{example['output'][i]}"
-         output_texts.append(text)
-     return output_texts
-
- dataset = load_dataset("json", data_files="data.json", split="train")
-
- # 3. Configure LoRA parameters (adapted for DialoGPT)
- lora_config = LoraConfig(
-     r=8,  # rank
-     lora_alpha=32,
-     lora_dropout=0.1,
-     bias="none",
-     task_type="CAUSAL_LM",
-     target_modules=["c_attn", "c_proj"],  # attention modules in the DialoGPT/GPT-2 architecture
- )
-
- # 4. Create the PEFT model (CPU version)
- # No k-bit quantization preparation is needed in a CPU environment
- model = get_peft_model(model, lora_config)
-
- # 5. Configure training arguments (CPU-optimized)
- output_dir = "./dialogpt-small-lora"
- training_args = TrainingArguments(
-     output_dir=output_dir,
-     per_device_train_batch_size=1,  # smaller batch size for the CPU environment
-     gradient_accumulation_steps=8,  # more gradient accumulation to compensate for the small batch
-     learning_rate=5e-4,  # slightly higher learning rate
-     logging_steps=5,
-     max_steps=50,  # fewer training steps, for demonstration
-     save_strategy="steps",
-     save_steps=25,
-     dataloader_num_workers=0,  # set to 0 in a CPU environment
-     fp16=False,  # fp16 is not supported on CPU
-     report_to="none",  # disable wandb and other reporting integrations
- )
-
- # 6. Create the trainer and start training
- trainer = SFTTrainer(
-     model=model,
-     train_dataset=dataset,
-     args=training_args,
-     peft_config=lora_config,
-     formatting_func=formatting_prompts_func,
-     max_seq_length=512,
- )
-
- trainer.train()
-
- # 7. Save the model
- print("Saving DialoGPT LoRA adapter...")
- trainer.save_model(output_dir)
- print(f"DialoGPT LoRA adapter saved to {output_dir}")
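For reference, a minimal sketch (not part of this commit) of how the adapter saved by the deleted script could be loaded back for CPU inference. It assumes the standard peft `PeftModel.from_pretrained` API; the base model name and the `./dialogpt-small-lora` path mirror the script above, and the prompt text is purely illustrative:

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumed paths: base model and adapter directory match the deleted script
base_model_name = "microsoft/DialoGPT-small"
adapter_dir = "./dialogpt-small-lora"

# Load the frozen base model on CPU, then attach the LoRA adapter weights
base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.float32)
model = PeftModel.from_pretrained(base_model, adapter_dir)
model.eval()

tokenizer = AutoTokenizer.from_pretrained(base_model_name)
tokenizer.pad_token = tokenizer.eos_token

# Query the model using the same instruction format used during training
prompt = "### Instruction:\nSay hello.\n\n### Input:\n\n### Response:\n"
inputs = tokenizer(prompt, return_tensors="pt")
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=64, pad_token_id=tokenizer.eos_token_id)

# Decode only the newly generated tokens, skipping the echoed prompt
print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))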