import json
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from datasets import Dataset
import os
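
# Requires the Hugging Face stack: pip install torch transformers datasets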


def load_training_data():
    """Load the training data."""
    try:
        with open('train_data.json', 'r', encoding='utf-8') as f:
            data = json.load(f)
        print(f"📊 Loaded {len(data)} training examples")
        return data
    except FileNotFoundError:
        print("❌ Training data file not found, falling back to sample data")
        # Return a few sample records
        return [
            {
                "input": "print('hello",
                "output": "print('hello')",
                "language": "python"
            },
            {
                "input": "<div class=test>",
                "output": "<div class=\"test\">",
                "language": "html"
            }
        ]
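
# train_data.json is expected to hold a JSON array of records shaped like
# the samples above: {"input": <broken code>, "output": <fixed code>,
# "language": <language tag>}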


def prepare_dataset(data):
    """Build the training dataset."""
    texts = []
    for item in data:
        # Format each pair as a single prompt/completion training text
        prompt = f"Fix the following {item.get('language', 'code')} code:\n{item['input']}\nFixed:\n{item['output']}"
        texts.append(prompt)
    return Dataset.from_dict({"text": texts})
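
# One formatted training text, built from the first sample record, looks like:
#   Fix the following python code:
#   print('hello
#   Fixed:
#   print('hello')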


def train_model():
    """Train the code-repair model."""
    print("🚀 Starting code-repair model training...")

    # Load the data
    data = load_training_data()
    if len(data) < 5:
        print("❌ Not enough training data; at least 5 examples are required")
        return

    # Initialize the model and tokenizer
    model_name = "microsoft/DialoGPT-small"
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name)

        # Add a pad token if one is missing (DialoGPT ships without one)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Prepare the dataset
        dataset = prepare_dataset(data)

        def tokenize_function(examples):
            return tokenizer(
                examples["text"],
                truncation=True,
                padding=True,
                max_length=512
            )

        # Drop the raw "text" column so the collator only receives token fields
        tokenized_dataset = dataset.map(
            tokenize_function, batched=True, remove_columns=["text"]
        )
        # Training arguments
        training_args = TrainingArguments(
            output_dir="./codefix-model",
            overwrite_output_dir=True,
            num_train_epochs=3,
            per_device_train_batch_size=2,
            save_steps=500,
            save_total_limit=2,
            logging_steps=100,
            prediction_loss_only=True,
            remove_unused_columns=False,
        )

        # Data collator
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer,
            mlm=False,  # causal LM objective, not masked language modeling
        )
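        # With mlm=False the collator clones input_ids into labels and sets
        # padded positions to -100; the model applies the next-token shift
        # internally, so no manual label shifting is needed.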

        # Trainer
        trainer = Trainer(
            model=model,
            args=training_args,
            data_collator=data_collator,
            train_dataset=tokenized_dataset,
        )

        # Start training
        print("🔥 Training the model...")
        trainer.train()

        # Save the model and tokenizer
        trainer.save_model()
        tokenizer.save_pretrained("./codefix-model")
        print("✅ Training complete! Model saved to ./codefix-model")
    except Exception as e:
        print(f"❌ Training failed: {e}")


def incremental_train(new_feedback_file="user_feedback.json"):
    """Incremental training based on user feedback."""
    if not os.path.exists(new_feedback_file):
        print("❌ User feedback file not found")
        return

    with open(new_feedback_file, 'r', encoding='utf-8') as f:
        feedback_data = json.load(f)

    # Use only feedback marked as correct for training
    training_data = []
    for feedback in feedback_data:
        if feedback.get("correct", False):
            training_data.append({
                "input": feedback["original"],
                "output": feedback["fixed"],
                "language": feedback["language"]
            })

    if len(training_data) > 0:
        print(f"🔄 Running incremental training on {len(training_data)} feedback records")
        # A training pass could be launched from here;
        # for simplicity, only log for now
        print("📝 Incremental training data is ready")


if __name__ == "__main__":
    # Decide whether to run incremental training
    import sys
    if len(sys.argv) > 1 and sys.argv[1] == "incremental":
        incremental_train()
    else:
        train_model()
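
# Usage (script name is illustrative):
#   python train.py              -> run the full training pass
#   python train.py incremental  -> collect validated feedback for retraining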