Spaces:
Running
Running
| # 文件路径: evaluation/data_router.py | |
| """ | |
| 数据路由引擎 - 负责 SFT 数据管理和路由 | |
| 根据评估结果将样本路由到不同的数据集 | |
| """ | |
| import json | |
| import os | |
| from typing import Dict, List, Any | |
| from evaluation.models import EvaluationResult, DataQualityTier | |
| from evaluation.utils import smart_truncate, SFTLengthConfig | |
class DataRoutingEngine:
    """Evaluation-driven data routing engine.

    Every evaluated sample is appended to an audit log
    (``eval_results.jsonl``); depending on its overall score it is also
    converted into an SFT training sample and appended to the positive or
    the negative sample file. All files live under ``output_dir`` and use the
    JSON Lines format (one JSON object per line).
    """

    # System instruction attached to every SFT training sample.
    SFT_INSTRUCTION = (
        "你是一个专业的GitHub代码仓库分析助手。根据提供的代码上下文,"
        "准确回答用户关于代码实现、架构设计、功能逻辑等问题。"
        "回答时应该:1) 直接引用相关代码 2) 解释代码的工作原理 3) 如有必要,提供代码示例。"
    )

    def __init__(self, output_dir: str = "evaluation/sft_data"):
        """Create the output directory (if missing) and derive the file paths.

        Args:
            output_dir: Directory that holds all JSONL files of this engine.
        """
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)
        self.positive_samples_file = os.path.join(output_dir, "positive_samples.jsonl")
        self.negative_samples_file = os.path.join(output_dir, "negative_samples.jsonl")
        self.dpo_pairs_file = os.path.join(output_dir, "dpo_pairs.jsonl")
        self.eval_results_file = os.path.join(output_dir, "eval_results.jsonl")

    def route_sample(self, eval_result: "EvaluationResult") -> str:
        """Route a single sample and return its data quality tier value.

        An overall score of exactly ``0.0`` is treated as "not computed yet",
        in which case ``compute_overall_score()`` is invoked before routing.
        """
        if eval_result.overall_score == 0.0:
            eval_result.compute_overall_score()
        self.route_data(eval_result)
        return eval_result.data_quality_tier.value

    def route_data(self, eval_result: "EvaluationResult") -> None:
        """Write the evaluation result to the files selected by its score.

        Routing rules (based on ``overall_score``):
            - score > 0.9  -> Gold   -> positive_samples.jsonl
            - score > 0.6  -> Silver -> positive_samples.jsonl
            - score > 0.4  -> Bronze -> negative_samples.jsonl
            - score <= 0.4 -> Rejected, written to no SFT file (per the
              original author, such samples are filtered earlier in
              auto_eval and should not normally reach this method)

        ``eval_results.jsonl`` always receives the full result, so it serves
        as the audit log for analysis.

        A sample for which no SFT record can be built (no generation metrics)
        is only logged; no empty ``{}`` line is written to the SFT files.
        """
        # Full audit log: record every evaluation result.
        self._append_jsonl(self.eval_results_file, eval_result.to_dict())

        score = eval_result.overall_score
        if score > 0.6:
            # Gold and Silver share the same destination: positive samples.
            target_file, is_negative = self.positive_samples_file, False
        elif score > 0.4:
            # Bronze: negative sample, usable for DPO or manual correction.
            target_file, is_negative = self.negative_samples_file, True
        else:
            return

        sft_sample = self._build_sft_sample(eval_result, negative=is_negative)
        if not sft_sample:
            # Nothing could be built; do not pollute the training data with
            # an empty JSON object.
            return
        self._append_jsonl(target_file, sft_sample)

    def _build_sft_sample(self, eval_result: "EvaluationResult", negative: bool = False) -> Dict:
        """Build one SFT training sample with length-limited fields.

        The query, context and answer are shortened according to the limits
        defined on ``SFTLengthConfig``. If instruction + input + answer still
        exceed ``MAX_TOTAL_CHARS``, the context is compressed further, but
        never below 500 characters.

        Args:
            eval_result: Evaluation result to convert.
            negative: Stored as ``metadata["is_negative"]``.

        Returns:
            A dict with ``instruction``, ``input``, ``output`` and
            ``metadata`` keys, or an empty dict when the result carries no
            generation metrics.
        """
        metrics = eval_result.generation_metrics
        if metrics is None:
            return {}

        cfg = SFTLengthConfig
        raw_context = metrics.retrieved_context
        raw_answer = metrics.generated_answer

        # 1. Truncate the query.
        query = eval_result.query
        if len(query) > cfg.MAX_QUERY_CHARS:
            query = query[:cfg.MAX_QUERY_CHARS] + "..."

        # 2. Truncate the context via smart_truncate (per the original
        #    author's note it keeps the head 70% and the tail 30%).
        context_limit = cfg.MAX_CONTEXT_CHARS
        context = smart_truncate(raw_context, context_limit, keep_ratio=0.7)

        # 3. Truncate the answer, keeping its beginning.
        answer = raw_answer
        if len(answer) > cfg.MAX_ANSWER_CHARS:
            answer = answer[:cfg.MAX_ANSWER_CHARS] + "\n\n... [回答过长,已截断]"

        # 4. Assemble the input and enforce the total budget.
        input_text = f"[用户问题]\n{query}\n\n[代码上下文]\n{context}"
        total_len = len(self.SFT_INSTRUCTION) + len(input_text) + len(answer)
        if total_len > cfg.MAX_TOTAL_CHARS:
            excess = total_len - cfg.MAX_TOTAL_CHARS
            # Keep at least 500 characters of context.
            context_limit = max(500, len(context) - excess)
            # Re-truncate from the raw context so head/tail are chosen once.
            context = smart_truncate(raw_context, context_limit, keep_ratio=0.7)
            input_text = f"[用户问题]\n{query}\n\n[代码上下文]\n{context}"

        # Use the context limit that was actually applied, so a context that
        # was shortened only by the total-budget pass is reported as
        # truncated too.
        was_truncated = (
            len(raw_context) > context_limit
            or len(raw_answer) > cfg.MAX_ANSWER_CHARS
        )

        return {
            "instruction": self.SFT_INSTRUCTION,
            "input": input_text,
            "output": answer,
            "metadata": {
                # The query is shortened in the metadata as well to save space.
                "query": eval_result.query[:200],
                "repo_url": eval_result.repo_url,
                "language": eval_result.language,
                "session_id": eval_result.session_id,
                "timestamp": eval_result.timestamp.isoformat(),
                "quality_tier": eval_result.data_quality_tier.value,
                "overall_score": eval_result.overall_score,
                "faithfulness": metrics.faithfulness,
                "answer_relevance": metrics.answer_relevance,
                "answer_completeness": metrics.answer_completeness,
                "code_correctness": metrics.code_correctness,
                "is_negative": negative,
                "sft_ready": eval_result.sft_ready,
                # Original lengths are kept for later analysis.
                "original_context_len": len(raw_context),
                "original_answer_len": len(raw_answer),
                "truncated": was_truncated,
            },
        }

    def _append_jsonl(self, filepath: str, data: Dict) -> None:
        """Append ``data`` as one JSON line to ``filepath`` (UTF-8, non-ASCII kept)."""
        with open(filepath, 'a', encoding='utf-8') as f:
            f.write(json.dumps(data, ensure_ascii=False) + '\n')

    def _iter_eval_records(self):
        """Yield each JSON object stored in the audit log.

        Lines that are not valid JSON, or whose JSON value is not an object,
        are skipped. A missing log file yields nothing. Other I/O or decoding
        errors propagate to the caller.
        """
        try:
            log_file = open(self.eval_results_file, 'r', encoding='utf-8')
        except FileNotFoundError:
            return
        with log_file:
            for line in log_file:
                try:
                    record = json.loads(line)
                except json.JSONDecodeError:
                    continue
                if isinstance(record, dict):
                    yield record

    def get_statistics(self) -> Dict[str, int]:
        """Return the number of lines in the positive, negative and DPO files.

        A file that does not exist counts as 0.
        """
        stats = {}
        for name, filepath in [
            ("positive", self.positive_samples_file),
            ("negative", self.negative_samples_file),
            ("dpo_pairs", self.dpo_pairs_file),
        ]:
            try:
                with open(filepath, 'r', encoding='utf-8') as f:
                    stats[name] = sum(1 for _ in f)
            except FileNotFoundError:
                stats[name] = 0
        return stats

    def get_distribution(self) -> Dict[str, int]:
        """Count audit-log records per quality tier.

        Records without a ``data_quality_tier`` key count as ``"bronze"``;
        records with an unknown tier are ignored. On a read error the counts
        gathered so far are returned.
        """
        distribution = {"gold": 0, "silver": 0, "bronze": 0, "rejected": 0, "corrected": 0}
        try:
            for record in self._iter_eval_records():
                tier = record.get("data_quality_tier", "bronze")
                # Guard against unhashable / non-string tiers in the log.
                if isinstance(tier, str) and tier in distribution:
                    distribution[tier] += 1
        except (OSError, UnicodeError) as e:
            print(f"⚠️ Error reading eval results: {e}")
        return distribution

    def get_bad_samples(self, limit: int = 10) -> List[Dict[str, Any]]:
        """Return the lowest-scored samples for manual review.

        The whole audit log is scanned and the ``limit`` records with the
        lowest ``overall_score`` below 0.5 are returned in ascending score
        order, so the worst samples are surfaced rather than merely the
        first ones found in the file. Records whose score is not a number
        are skipped.

        Args:
            limit: Maximum number of samples to return; values <= 0 yield
                an empty list.

        Returns:
            A list of summary dicts (query, score, issue, quality tier,
            timestamp and, when present, selected generation metrics).
        """
        if limit <= 0:
            return []

        bad_samples = []
        try:
            for record in self._iter_eval_records():
                score = record.get("overall_score", 0)
                if not isinstance(score, (int, float)) or not score < 0.5:
                    continue
                sample = {
                    "query": record.get("query", ""),
                    "score": score,
                    "issue": record.get("error_message", "Low quality"),
                    "quality_tier": record.get("data_quality_tier", "rejected"),
                    "timestamp": record.get("timestamp", ""),
                }
                gen = record.get("generation")
                if isinstance(gen, dict) and gen:
                    sample.update({
                        "faithfulness": gen.get("faithfulness", 0),
                        "answer_relevance": gen.get("answer_relevance", 0),
                        "answer_completeness": gen.get("answer_completeness", 0),
                    })
                bad_samples.append(sample)
        except (OSError, UnicodeError) as e:
            # Best effort: keep whatever was collected before the failure.
            print(f"⚠️ Error reading bad samples: {e}")

        # sorted() is stable, so equal scores keep their log order.
        return sorted(bad_samples, key=lambda x: x["score"])[:limit]