pangxiang commited on
Commit
b5bd3c2
·
verified ·
1 Parent(s): 82e64ab

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +140 -13
train.py CHANGED
@@ -1,27 +1,154 @@
1
  import json
2
  import torch
3
- from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
 
 
 
 
 
 
 
 
4
 
5
  def load_training_data():
6
- with open('train_data.json', 'r', encoding='utf-8') as f:
7
- return json.load(f)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  def train_model():
 
 
 
10
  # 加载数据
11
  data = load_training_data()
12
 
13
- # 初始化tokenizer和模型(使用小模型)
14
- tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
15
- model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
- # 数据处理和训练逻辑
18
- # ... 这里添加你的训练代码
19
 
20
- # 保存模型
21
- model.save_pretrained("./trained_model")
22
- tokenizer.save_pretrained("./trained_model")
 
 
 
 
 
 
23
 
24
- print("模型训练完成!")
 
 
 
 
25
 
26
  if __name__ == "__main__":
27
- train_model()
 
 
 
 
 
 
 
1
  import json
2
  import torch
3
+ from transformers import (
4
+ AutoTokenizer,
5
+ AutoModelForCausalLM,
6
+ TrainingArguments,
7
+ Trainer,
8
+ DataCollatorForLanguageModeling
9
+ )
10
+ from datasets import Dataset
11
+ import os
12
 
13
  def load_training_data():
14
+ """加载训练数据"""
15
+ try:
16
+ with open('train_data.json', 'r', encoding='utf-8') as f:
17
+ data = json.load(f)
18
+ print(f"📊 加载了 {len(data)} 条训练数据")
19
+ return data
20
+ except FileNotFoundError:
21
+ print("❌ 训练数据文件不存在,使用示例数据")
22
+ # 返回一些示例数据
23
+ return [
24
+ {
25
+ "input": "print('hello",
26
+ "output": "print('hello')",
27
+ "language": "python"
28
+ },
29
+ {
30
+ "input": "<div class=test>",
31
+ "output": "<div class=\"test\">",
32
+ "language": "html"
33
+ }
34
+ ]
35
+
36
+ def prepare_dataset(data):
37
+ """准备训练数据集"""
38
+ texts = []
39
+
40
+ for item in data:
41
+ # 创建训练文本格式
42
+ prompt = f"修复以下{item.get('language', 'code')}代码:\n{item['input']}\n修复后:\n{item['output']}"
43
+ texts.append(prompt)
44
+
45
+ return Dataset.from_dict({"text": texts})
46
 
47
  def train_model():
48
+ """训练模型"""
49
+ print("🚀 开始训练代码修复模型...")
50
+
51
  # 加载数据
52
  data = load_training_data()
53
 
54
+ if len(data) < 5:
55
+ print("❌ 训练数据不足,至少需要5条数据")
56
+ return
57
+
58
+ # 初始化模型和分词器
59
+ model_name = "microsoft/DialoGPT-small"
60
+
61
+ try:
62
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
63
+ model = AutoModelForCausalLM.from_pretrained(model_name)
64
+
65
+ # 添加pad token如果不存在
66
+ if tokenizer.pad_token is None:
67
+ tokenizer.pad_token = tokenizer.eos_token
68
+
69
+ # 准备数据集
70
+ dataset = prepare_dataset(data)
71
+
72
+ def tokenize_function(examples):
73
+ return tokenizer(
74
+ examples["text"],
75
+ truncation=True,
76
+ padding=True,
77
+ max_length=512
78
+ )
79
+
80
+ tokenized_dataset = dataset.map(tokenize_function, batched=True)
81
+
82
+ # 训练参数
83
+ training_args = TrainingArguments(
84
+ output_dir="./codefix-model",
85
+ overwrite_output_dir=True,
86
+ num_train_epochs=3,
87
+ per_device_train_batch_size=2,
88
+ save_steps=500,
89
+ save_total_limit=2,
90
+ logging_steps=100,
91
+ prediction_loss_only=True,
92
+ remove_unused_columns=False,
93
+ )
94
+
95
+ # 数据收集器
96
+ data_collator = DataCollatorForLanguageModeling(
97
+ tokenizer=tokenizer,
98
+ mlm=False, # 不使用掩码语言模型
99
+ )
100
+
101
+ # 训练器
102
+ trainer = Trainer(
103
+ model=model,
104
+ args=training_args,
105
+ data_collator=data_collator,
106
+ train_dataset=tokenized_dataset,
107
+ )
108
+
109
+ # 开始训练
110
+ print("🔥 开始模型训练...")
111
+ trainer.train()
112
+
113
+ # 保存模型
114
+ trainer.save_model()
115
+ tokenizer.save_pretrained("./codefix-model")
116
+
117
+ print("✅ 模型训练完成!保存在 ./codefix-model 目录")
118
+
119
+ except Exception as e:
120
+ print(f"❌ 训练失败: {e}")
121
+
122
+ def incremental_train(new_feedback_file="user_feedback.json"):
123
+ """增量训练 - 基于用户反馈"""
124
+ if not os.path.exists(new_feedback_file):
125
+ print("❌ 用户反馈文件不存在")
126
+ return
127
 
128
+ with open(new_feedback_file, 'r', encoding='utf-8') as f:
129
+ feedback_data = json.load(f)
130
 
131
+ # 只使用正确的反馈作为训练数据
132
+ training_data = []
133
+ for feedback in feedback_data:
134
+ if feedback.get("correct", False):
135
+ training_data.append({
136
+ "input": feedback["original"],
137
+ "output": feedback["fixed"],
138
+ "language": feedback["language"]
139
+ })
140
 
141
+ if len(training_data) > 0:
142
+ print(f"🔄 基于 {len(training_data)} 条用户反馈进行增量训练")
143
+ # 这里可以调用训练函数进行增量训练
144
+ # 为了简化,暂时只���录
145
+ print("📝 增量训练数据已准备就绪")
146
 
147
  if __name__ == "__main__":
148
+ # 检查是否进行增量训练
149
+ import sys
150
+ if len(sys.argv) > 1 and sys.argv[1] == "incremental":
151
+ incremental_train()
152
+ else:
153
+ train_model()
154
+