Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| SFT 数据清洗与导出脚本 | |
| 功能: | |
| 1. 从 eval_results.jsonl 读取原始评估数据 | |
| 2. 应用严格的质量过滤规则 | |
| 3. 转换为标准 SFT 训练格式 | |
| 4. 导出为可直接用于训练的数据集 | |
| Author: Dexter | |
| Date: 2026-01-28 | |
| """ | |
| import json | |
| import os | |
| from datetime import datetime | |
| from typing import Dict, List, Tuple | |
| from pathlib import Path | |
| from evaluation.utils import is_chatty_query, has_code_indicators | |
| # ============================================================================ | |
| # 配置 | |
| # ============================================================================ | |
class CleaningConfig:
    """Thresholds and output settings for SFT data cleaning.

    All attributes are class-level constants; ``clean_and_export`` reads
    them from an instance, and the CLI overrides MIN_OVERALL_SCORE.
    """
    # --- Quality score thresholds (samples below any gate are rejected) ---
    MIN_OVERALL_SCORE = 0.7  # minimum combined evaluation score
    MIN_FAITHFULNESS = 0.6  # minimum faithfulness score
    MIN_ANSWER_RELEVANCE = 0.6  # minimum answer_relevance score
    # --- Length thresholds (character counts) ---
    MIN_QUERY_LENGTH = 10  # shortest acceptable query
    MIN_ANSWER_LENGTH = 100  # shortest acceptable answer
    MIN_CONTEXT_LENGTH = 50  # shortest acceptable retrieved context
    MAX_CONTEXT_LENGTH = 4000  # longest context (anything longer is truncated)
    # --- Hard requirements ---
    REQUIRE_REPO_URL = True  # sample must carry a repository URL
    REQUIRE_CODE_IN_CONTEXT = True  # retrieved context must contain code indicators
    # --- Output ---
    OUTPUT_DIR = "evaluation/sft_data/cleaned"  # directory for exported files
| # ============================================================================ | |
| # 数据清洗逻辑 | |
| # ============================================================================ | |
def validate_sample(sample: Dict, config: CleaningConfig) -> Tuple[bool, str]:
    """Run a single sample through the quality gate.

    Checks are applied in a fixed order (fields, repo URL, scores,
    lengths, chattiness, code presence); the first failure determines
    the rejection reason.

    Returns:
        (is_valid, rejection_reason) — reason is "passed" when valid.
    """
    # Required top-level fields must be present and non-empty.
    if not sample.get("query"):
        return False, "missing_query"
    if not sample.get("generation"):
        return False, "missing_generation"
    generation = sample["generation"]

    # Repository URL, when the config demands one.
    if config.REQUIRE_REPO_URL and not sample.get("repo_url"):
        return False, "missing_repo_url"

    # Score gates: overall, then faithfulness, then relevance.
    score = sample.get("overall_score", 0)
    if score < config.MIN_OVERALL_SCORE:
        return False, f"low_score:{score:.2f}"
    faith = generation.get("faithfulness", 0)
    if faith < config.MIN_FAITHFULNESS:
        return False, f"low_faithfulness:{faith:.2f}"
    relevance = generation.get("answer_relevance", 0)
    if relevance < config.MIN_ANSWER_RELEVANCE:
        return False, f"low_relevance:{relevance:.2f}"

    # Length gates for query, answer, and retrieved context.
    query_text = sample.get("query", "")
    if len(query_text) < config.MIN_QUERY_LENGTH:
        return False, f"short_query:{len(query_text)}"
    answer_text = generation.get("generated_answer", "")
    if len(answer_text) < config.MIN_ANSWER_LENGTH:
        return False, f"short_answer:{len(answer_text)}"
    context_text = generation.get("retrieved_context", "")
    if len(context_text) < config.MIN_CONTEXT_LENGTH:
        return False, f"short_context:{len(context_text)}"

    # Chit-chat queries are not useful SFT material.
    if is_chatty_query(query_text):
        return False, "chatty_query"

    # Context must actually look like code when the config requires it.
    if config.REQUIRE_CODE_IN_CONTEXT and not has_code_indicators(context_text):
        return False, "no_code_in_context"

    return True, "passed"
def transform_to_sft_format(sample: Dict, config: CleaningConfig) -> Dict:
    """Turn a validated evaluation record into a standard SFT record.

    Produces the instruction/input/output trio used for training plus a
    ``metadata`` dict preserving provenance and the original quality
    scores.
    """
    generation = sample["generation"]

    # Truncate over-long retrieved context, marking the cut explicitly.
    context = generation.get("retrieved_context", "")
    if len(context) > config.MAX_CONTEXT_LENGTH:
        context = context[:config.MAX_CONTEXT_LENGTH] + "\n... [truncated]"

    # Provenance and quality scores carried alongside the training fields.
    metadata = {
        "query": sample["query"],
        "repo_url": sample.get("repo_url", ""),
        "language": sample.get("language", "en"),
        "session_id": sample.get("session_id", ""),
        "timestamp": sample.get("timestamp", ""),
        "quality_tier": sample.get("data_quality_tier", ""),
        "overall_score": sample.get("overall_score", 0),
        "faithfulness": generation.get("faithfulness", 0),
        "answer_relevance": generation.get("answer_relevance", 0),
        "answer_completeness": generation.get("answer_completeness", 0),
        "code_correctness": generation.get("code_correctness", 0),
    }

    return {
        # === Core training fields ===
        "instruction": "你是一个专业的GitHub代码仓库分析助手。根据提供的代码上下文,准确回答用户关于代码实现、架构设计、功能逻辑等问题。回答时应该:1) 直接引用相关代码 2) 解释代码的工作原理 3) 如有必要,提供代码示例。",
        "input": f"[用户问题]\n{sample['query']}\n\n[代码上下文]\n{context}",
        "output": generation.get("generated_answer", ""),
        # === Metadata ===
        "metadata": metadata,
    }
def clean_and_export(
    input_file: str = "evaluation/sft_data/eval_results.jsonl",
    config: CleaningConfig = None
) -> Dict:
    """Clean raw evaluation records and export them as SFT training data.

    Reads ``input_file`` line by line (JSONL), validates every record via
    ``validate_sample``, converts the survivors with
    ``transform_to_sft_format``, and writes two timestamped files under
    ``config.OUTPUT_DIR``: the cleaned training set and the rejected
    records (kept for offline analysis).

    Args:
        input_file: path to the raw eval_results JSONL file.
        config: cleaning thresholds; defaults to a fresh CleaningConfig.

    Returns:
        Statistics dict with read/passed/rejected counts, the rejection
        reason histogram, and the gold/silver/bronze distribution.
    """
    config = config or CleaningConfig()

    # Make sure the output directory exists.
    output_dir = Path(config.OUTPUT_DIR)
    output_dir.mkdir(parents=True, exist_ok=True)

    stats = {
        "total_read": 0,
        "passed": 0,
        "rejected": 0,
        "rejection_reasons": {},
        "quality_distribution": {"gold": 0, "silver": 0, "bronze": 0}
    }

    # FIX: take the timestamp ONCE so both output files share the same
    # suffix — the original called datetime.now() twice, and calls that
    # straddle a second boundary produced mismatched file names.
    run_ts = datetime.now().strftime('%Y%m%d_%H%M%S')
    output_file = output_dir / f"sft_train_{run_ts}.jsonl"
    rejected_file = output_dir / f"rejected_{run_ts}.jsonl"

    print("=" * 60)
    print("🧹 SFT 数据清洗与导出")
    print("=" * 60)
    print(f"输入文件: {input_file}")
    print(f"输出目录: {output_dir}")
    print(f"质量阈值: score >= {config.MIN_OVERALL_SCORE}")
    print()

    if not os.path.exists(input_file):
        print(f"❌ 输入文件不存在: {input_file}")
        return stats

    passed_samples = []
    rejected_samples = []

    # Read, validate, and transform every record.
    with open(input_file, 'r', encoding='utf-8') as f:
        for line_num, line in enumerate(f, 1):
            # Keep the try body minimal: only json.loads can raise here.
            try:
                sample = json.loads(line)
            except json.JSONDecodeError as e:
                # Skip malformed lines but keep processing the rest.
                print(f" ⚠️ 第 {line_num} 行 JSON 解析错误: {e}")
                continue
            stats["total_read"] += 1

            is_valid, reason = validate_sample(sample, config)
            if is_valid:
                passed_samples.append(transform_to_sft_format(sample, config))
                stats["passed"] += 1
                # Bucket passing samples by overall score.
                score = sample.get("overall_score", 0)
                if score > 0.9:
                    stats["quality_distribution"]["gold"] += 1
                elif score > 0.7:
                    stats["quality_distribution"]["silver"] += 1
                else:
                    stats["quality_distribution"]["bronze"] += 1
            else:
                # Keep a compact trace of each rejection for analysis.
                rejected_samples.append({
                    "reason": reason,
                    "query": sample.get("query", "")[:50],
                    "score": sample.get("overall_score", 0)
                })
                stats["rejected"] += 1
                stats["rejection_reasons"][reason] = stats["rejection_reasons"].get(reason, 0) + 1

    # Write the cleaned training samples.
    if passed_samples:
        with open(output_file, 'w', encoding='utf-8') as f:
            for sample in passed_samples:
                f.write(json.dumps(sample, ensure_ascii=False) + '\n')
        print(f"✅ 已导出 {len(passed_samples)} 条高质量样本到: {output_file}")

    # Write the rejected samples (for offline analysis).
    if rejected_samples:
        with open(rejected_file, 'w', encoding='utf-8') as f:
            for sample in rejected_samples:
                f.write(json.dumps(sample, ensure_ascii=False) + '\n')
        print(f"📝 已记录 {len(rejected_samples)} 条被拒绝样本到: {rejected_file}")

    # Summary report.
    print()
    print("=" * 60)
    print("📊 统计报告")
    print("=" * 60)
    print(f"总读取: {stats['total_read']}")
    print(f"通过: {stats['passed']} ({stats['passed']/max(stats['total_read'],1)*100:.1f}%)")
    print(f"拒绝: {stats['rejected']} ({stats['rejected']/max(stats['total_read'],1)*100:.1f}%)")
    print()
    print("质量分布:")
    print(f" 🥇 Gold (>0.9): {stats['quality_distribution']['gold']}")
    print(f" 🥈 Silver (>0.7): {stats['quality_distribution']['silver']}")
    print(f" 🥉 Bronze (>0.5): {stats['quality_distribution']['bronze']}")
    print()
    if stats["rejection_reasons"]:
        print("拒绝原因分布:")
        for reason, count in sorted(stats["rejection_reasons"].items(), key=lambda x: -x[1]):
            print(f" - {reason}: {count}")
    print()
    print("=" * 60)
    return stats
def export_for_training(
    input_file: str,
    output_file: str,
    format_type: str = "alpaca"
) -> int:
    """Re-export cleaned SFT records in a specific training format.

    Args:
        input_file: cleaned JSONL file (records carrying
            instruction/input/output keys).
        output_file: destination path; a ``.json`` suffix produces one
            pretty-printed JSON array, anything else produces JSONL.
        format_type: "alpaca", "sharegpt", or "messages"; any other
            value passes records through unchanged.

    Returns:
        Number of exported samples.
    """
    def _alpaca(rec):
        # Alpaca triplet (LLaMA-Factory and similar trainers).
        return {
            "instruction": rec["instruction"],
            "input": rec["input"],
            "output": rec["output"]
        }

    def _sharegpt(rec):
        # ShareGPT conversation list.
        return {
            "conversations": [
                {"from": "system", "value": rec["instruction"]},
                {"from": "human", "value": rec["input"]},
                {"from": "gpt", "value": rec["output"]}
            ]
        }

    def _messages(rec):
        # OpenAI chat-completions messages.
        return {
            "messages": [
                {"role": "system", "content": rec["instruction"]},
                {"role": "user", "content": rec["input"]},
                {"role": "assistant", "content": rec["output"]}
            ]
        }

    # Dispatch table; unknown format types fall through unchanged.
    converters = {"alpaca": _alpaca, "sharegpt": _sharegpt, "messages": _messages}
    convert = converters.get(format_type, lambda rec: rec)

    with open(input_file, 'r', encoding='utf-8') as src:
        samples = [convert(json.loads(line)) for line in src]

    with open(output_file, 'w', encoding='utf-8') as dst:
        if output_file.endswith('.json'):
            json.dump(samples, dst, ensure_ascii=False, indent=2)
        else:
            dst.writelines(json.dumps(s, ensure_ascii=False) + '\n' for s in samples)

    print(f"✅ 已导出 {len(samples)} 条样本为 {format_type} 格式: {output_file}")
    return len(samples)
| # ============================================================================ | |
| # 主函数 | |
| # ============================================================================ | |
if __name__ == "__main__":
    import argparse

    arg_parser = argparse.ArgumentParser(description="SFT 数据清洗与导出工具")
    arg_parser.add_argument("--input", "-i", default="evaluation/sft_data/eval_results.jsonl",
                            help="输入文件路径")
    arg_parser.add_argument("--min-score", "-s", type=float, default=0.7,
                            help="最低质量分数 (默认: 0.7)")
    arg_parser.add_argument("--format", "-f", choices=["alpaca", "sharegpt", "messages"],
                            default="alpaca", help="导出格式 (默认: alpaca)")
    arg_parser.add_argument("--export", "-e", action="store_true",
                            help="同时导出为训练格式")
    cli_args = arg_parser.parse_args()

    # Apply CLI overrides to the default cleaning configuration.
    cleaning_config = CleaningConfig()
    cleaning_config.MIN_OVERALL_SCORE = cli_args.min_score

    # Run the cleaning pass.
    run_stats = clean_and_export(cli_args.input, cleaning_config)

    # Optionally re-export the freshest cleaned file in a training format.
    if cli_args.export and run_stats["passed"] > 0:
        cleaned_dir = Path(cleaning_config.OUTPUT_DIR)
        candidates = sorted(cleaned_dir.glob("sft_train_*.jsonl"), reverse=True)
        if candidates:
            newest = candidates[0]
            target = cleaned_dir / f"train_{cli_args.format}.jsonl"
            export_for_training(str(newest), str(target), cli_args.format)