"""Fine-tune a small causal language model on code-repair examples.

Run with no arguments to train from ``train_data.json``; run with the
single argument ``incremental`` to collect approved user feedback from
``user_feedback.json`` instead.
"""

import json
import os
import sys

import torch
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)


def load_training_data():
    """Load training samples from ``train_data.json``.

    Returns:
        The parsed JSON content of ``train_data.json``. When the file does
        not exist, a small built-in list of sample dicts with the keys
        ``input``, ``output`` and ``language`` is returned instead.
    """
    try:
        with open('train_data.json', 'r', encoding='utf-8') as f:
            data = json.load(f)
        print(f"📊 加载了 {len(data)} 条训练数据")
        return data
    except FileNotFoundError:
        print("❌ 训练数据文件不存在,使用示例数据")
        # Fallback samples.
        # NOTE(review): only two samples are provided here, while
        # train_model() refuses to run with fewer than five, so this
        # fallback cannot lead to an actual training run as written.
        # NOTE(review): the HTML sample's input/output contain nothing but
        # a newline; the original markup may have been lost - confirm.
        return [
            {
                "input": "print('hello",
                "output": "print('hello')",
                "language": "python"
            },
            {
                "input": "\n",
                "output": "\n",
                "language": "html"
            }
        ]


def prepare_dataset(data):
    """Turn raw samples into a single-column ``datasets.Dataset``.

    Args:
        data: Iterable of dicts with ``input`` and ``output`` keys and an
            optional ``language`` key (defaults to ``'code'``).

    Returns:
        A ``Dataset`` with one ``text`` column holding the formatted
        prompt-plus-answer string for each sample.
    """
    texts = [
        f"修复以下{item.get('language', 'code')}代码:\n{item['input']}\n修复后:\n{item['output']}"
        for item in data
    ]
    return Dataset.from_dict({"text": texts})


def train_model():
    """Fine-tune ``microsoft/DialoGPT-small`` on the loaded samples.

    Loads the training data, tokenizes it, trains for three epochs and
    saves both model and tokenizer to ``./codefix-model``. Returns early
    (without training) when fewer than five samples are available. Any
    exception raised during setup or training is reported on stdout and
    not re-raised.
    """
    print("🚀 开始训练代码修复模型...")

    # Load the data and make sure there is enough of it.
    data = load_training_data()
    if len(data) < 5:
        print("❌ 训练数据不足,至少需要5条数据")
        return

    # Base model and tokenizer.
    model_name = "microsoft/DialoGPT-small"

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name)

        # Padding needs a pad token; fall back to EOS when none is set.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        dataset = prepare_dataset(data)

        def tokenize_function(examples):
            return tokenizer(
                examples["text"],
                truncation=True,
                padding=True,
                max_length=512
            )

        # Drop the raw string column after tokenizing: with
        # remove_unused_columns=False below, the Trainer would otherwise
        # hand the "text" strings to the collator, which cannot convert
        # them to tensors.
        tokenized_dataset = dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=dataset.column_names,
        )

        training_args = TrainingArguments(
            output_dir="./codefix-model",
            overwrite_output_dir=True,
            num_train_epochs=3,
            per_device_train_batch_size=2,
            save_steps=500,
            save_total_limit=2,
            logging_steps=100,
            prediction_loss_only=True,
            remove_unused_columns=False,
        )

        data_collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer,
            mlm=False,  # causal LM objective, no masked-token training
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            data_collator=data_collator,
            train_dataset=tokenized_dataset,
        )

        print("🔥 开始模型训练...")
        trainer.train()

        # Persist the model and the tokenizer side by side.
        trainer.save_model()
        tokenizer.save_pretrained("./codefix-model")

        print("✅ 模型训练完成!保存在 ./codefix-model 目录")

    except Exception as e:
        # Script-level boundary: report the failure instead of crashing.
        print(f"❌ 训练失败: {e}")


def incremental_train(new_feedback_file="user_feedback.json"):
    """Collect approved user feedback as incremental training samples.

    Only feedback entries whose ``correct`` flag is truthy are kept. No
    training is performed yet; the function just reports how many samples
    were prepared.

    Args:
        new_feedback_file: Path of the JSON file holding a list of feedback
            dicts with ``original`` and ``fixed`` keys and optional
            ``correct`` and ``language`` keys.
    """
    if not os.path.exists(new_feedback_file):
        print("❌ 用户反馈文件不存在")
        return

    with open(new_feedback_file, 'r', encoding='utf-8') as f:
        feedback_data = json.load(f)

    # Keep only feedback the user confirmed as correct. The language is
    # optional, matching the default used by prepare_dataset().
    training_data = [
        {
            "input": feedback["original"],
            "output": feedback["fixed"],
            "language": feedback.get("language", "code"),
        }
        for feedback in feedback_data
        if feedback.get("correct", False)
    ]

    if training_data:
        print(f"🔄 基于 {len(training_data)} 条用户反馈进行增量训练")
        # The actual training call could go here; for now the prepared
        # samples are only reported.
        print("📝 增量训练数据已准备就绪")


if __name__ == "__main__":
    # "incremental" as the first CLI argument selects feedback-based mode.
    if len(sys.argv) > 1 and sys.argv[1] == "incremental":
        incremental_train()
    else:
        train_model()