# capricode-codefix / train.py
# (source: Hugging Face repo, commit b5bd3c2 "Update train.py")
import json
import torch
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
TrainingArguments,
Trainer,
DataCollatorForLanguageModeling
)
from datasets import Dataset
import os
def load_training_data(path='train_data.json'):
    """Load code-repair training examples from a JSON file.

    Args:
        path: Path of the JSON file holding a list of
            ``{"input", "output", "language"}`` records. Defaults to
            ``'train_data.json'`` to preserve the original behavior.

    Returns:
        list[dict]: The records from the file, or a tiny built-in sample
        set when the file is missing or not valid JSON, so the rest of
        the pipeline can still be exercised.
    """
    try:
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        print(f"📊 加载了 {len(data)} 条训练数据")
        return data
    except (FileNotFoundError, json.JSONDecodeError):
        # Fall back to example data instead of crashing the whole run.
        print("❌ 训练数据文件不存在,使用示例数据")
        return [
            {
                "input": "print('hello",
                "output": "print('hello')",
                "language": "python"
            },
            {
                "input": "<div class=test>",
                "output": "<div class=\"test\">",
                "language": "html"
            }
        ]
def prepare_dataset(data):
    """Turn raw records into a HuggingFace ``Dataset`` of prompt texts.

    Each record becomes one training string of the form
    ``修复以下<language>代码:\n<input>\n修复后:\n<output>``; records
    without a ``language`` key fall back to ``'code'``.
    """
    prompts = [
        f"修复以下{record.get('language', 'code')}代码:\n{record['input']}\n修复后:\n{record['output']}"
        for record in data
    ]
    return Dataset.from_dict({"text": prompts})
def train_model(output_dir="./codefix-model", epochs=3):
    """Fine-tune a small causal LM on code-repair examples.

    Args:
        output_dir: Directory for checkpoints, the final model, and the
            tokenizer. Defaults to ``"./codefix-model"`` (the original
            hard-coded path, now a single source of truth).
        epochs: Number of training epochs. Defaults to 3.

    Side effects: loads ``train_data.json`` (via ``load_training_data``),
    downloads the base model from the Hub, and writes to ``output_dir``.
    """
    print("🚀 开始训练代码修复模型...")
    data = load_training_data()
    if len(data) < 5:
        print("❌ 训练数据不足,至少需要5条数据")
        return
    model_name = "microsoft/DialoGPT-small"
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name)
        # DialoGPT ships without a pad token; reuse EOS so padded
        # batches tokenize without errors.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        dataset = prepare_dataset(data)

        def tokenize_function(examples):
            # Truncate/pad to a fixed window so batches are rectangular.
            return tokenizer(
                examples["text"],
                truncation=True,
                padding=True,
                max_length=512
            )

        tokenized_dataset = dataset.map(tokenize_function, batched=True)
        training_args = TrainingArguments(
            output_dir=output_dir,
            overwrite_output_dir=True,
            num_train_epochs=epochs,
            per_device_train_batch_size=2,
            save_steps=500,
            save_total_limit=2,
            logging_steps=100,
            prediction_loss_only=True,
            remove_unused_columns=False,
        )
        # mlm=False => plain causal-LM objective (no token masking).
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer,
            mlm=False,
        )
        trainer = Trainer(
            model=model,
            args=training_args,
            data_collator=data_collator,
            train_dataset=tokenized_dataset,
        )
        print("🔥 开始模型训练...")
        trainer.train()
        # Save tokenizer next to the model so output_dir is self-contained.
        trainer.save_model()
        tokenizer.save_pretrained(output_dir)
        print(f"✅ 模型训练完成!保存在 {output_dir} 目录")
    except Exception as e:
        # Best-effort reporting; deliberately does not re-raise so a
        # failed run exits cleanly (original behavior).
        print(f"❌ 训练失败: {e}")
def incremental_train(new_feedback_file="user_feedback.json"):
    """Collect user-confirmed fixes as incremental training data.

    Args:
        new_feedback_file: JSON file holding a list of feedback records
            of the form ``{"original", "fixed", "language", "correct"}``.

    Returns:
        list[dict] | None: ``{"input", "output", "language"}`` records
        built from feedback marked ``correct`` (possibly empty), or
        ``None`` when the file does not exist. The original version
        returned ``None`` unconditionally, so returning the data is
        backward-compatible.
    """
    if not os.path.exists(new_feedback_file):
        print("❌ 用户反馈文件不存在")
        return None
    with open(new_feedback_file, 'r', encoding='utf-8') as f:
        feedback_data = json.load(f)
    # Keep only feedback the user marked as correct. Use .get for the
    # language (matching prepare_dataset's fallback) so records missing
    # that key don't raise KeyError.
    training_data = [
        {
            "input": fb["original"],
            "output": fb["fixed"],
            "language": fb.get("language", "code"),
        }
        for fb in feedback_data
        if fb.get("correct", False)
    ]
    if len(training_data) > 0:
        print(f"🔄 基于 {len(training_data)} 条用户反馈进行增量训练")
        # Actual incremental fine-tuning is not wired up yet; for now we
        # only report that the data is ready.
        print("📝 增量训练数据已准备就绪")
    return training_data
if __name__ == "__main__":
    import sys

    # CLI entry point: `python train.py incremental` runs the
    # feedback-based path; any other invocation runs a full training pass.
    mode = sys.argv[1] if len(sys.argv) > 1 else None
    if mode == "incremental":
        incremental_train()
    else:
        train_model()