# File path: evaluation/data_router.py
"""
Data routing engine - manages SFT data.
Routes samples to different datasets based on their evaluation results.
"""
import json
import os
from typing import Dict, List, Any
from evaluation.models import EvaluationResult, DataQualityTier
from evaluation.utils import smart_truncate, SFTLengthConfig
class DataRoutingEngine:
"""评估驱动的数据路由引擎"""
# SFT 训练提示词
SFT_INSTRUCTION = (
"你是一个专业的GitHub代码仓库分析助手。根据提供的代码上下文,"
"准确回答用户关于代码实现、架构设计、功能逻辑等问题。"
"回答时应该:1) 直接引用相关代码 2) 解释代码的工作原理 3) 如有必要,提供代码示例。"
)
def __init__(self, output_dir: str = "evaluation/sft_data"):
self.output_dir = output_dir
os.makedirs(output_dir, exist_ok=True)
self.positive_samples_file = os.path.join(output_dir, "positive_samples.jsonl")
self.negative_samples_file = os.path.join(output_dir, "negative_samples.jsonl")
self.dpo_pairs_file = os.path.join(output_dir, "dpo_pairs.jsonl")
self.eval_results_file = os.path.join(output_dir, "eval_results.jsonl")
def route_sample(self, eval_result: EvaluationResult) -> str:
"""路由单个样本,返回数据质量等级"""
if eval_result.overall_score == 0.0:
eval_result.compute_overall_score()
self.route_data(eval_result)
return eval_result.data_quality_tier.value
def route_data(self, eval_result: EvaluationResult) -> None:
"""
根据评估结果路由数据
路由规则:
- score > 0.9 → Gold → positive_samples.jsonl
- score > 0.6 → Silver → positive_samples.jsonl
- score > 0.4 → Bronze → negative_samples.jsonl
- score <= 0.4 → Rejected (不应到达此处,在 auto_eval 中已过滤)
注意: eval_results.jsonl 记录所有通过验证的样本,用于分析和审计
"""
        # Record every evaluation result (full audit log)
self._append_jsonl(self.eval_results_file, eval_result.to_dict())
        # Route to the appropriate SFT data file based on quality tier
if eval_result.overall_score > 0.9:
            # Gold: high-quality positive sample
sft_sample = self._build_sft_sample(eval_result)
self._append_jsonl(self.positive_samples_file, sft_sample)
elif eval_result.overall_score > 0.6:
            # Silver: usable positive sample
sft_sample = self._build_sft_sample(eval_result)
self._append_jsonl(self.positive_samples_file, sft_sample)
elif eval_result.overall_score > 0.4:
            # Bronze: negative sample, usable for DPO or manual correction
sft_sample = self._build_sft_sample(eval_result, negative=True)
self._append_jsonl(self.negative_samples_file, sft_sample)
        # <= 0.4: written to no SFT file (already rejected in auto_eval)
def _build_sft_sample(self, eval_result: EvaluationResult, negative: bool = False) -> Dict:
"""
构建 SFT 训练样本
长度限制(基于 SFTLengthConfig):
- Context: 最大 2500 字符 (~800 tokens)
- Answer: 最大 3000 字符 (~1000 tokens)
- 总计: ~2000 tokens,适合 4096 max_length 训练
"""
        # Without generation metrics there is nothing to build a sample from
        if eval_result.generation_metrics is None:
            return {}
cfg = SFTLengthConfig
        # 1. Truncate the query
query = eval_result.query
if len(query) > cfg.MAX_QUERY_CHARS:
query = query[:cfg.MAX_QUERY_CHARS] + "..."
        # 2. Smart-truncate the context (keep the first 70% and the last 30%)
context = eval_result.generation_metrics.retrieved_context
context = smart_truncate(context, cfg.MAX_CONTEXT_CHARS, keep_ratio=0.7)
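        # For example, assuming smart_truncate splits on keep_ratio as described
        # above: with MAX_CONTEXT_CHARS = 2500, roughly 1750 chars are kept from
        # the head of the context and 750 from its tail.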
        # 3. Truncate the answer (keep the head; conclusions usually come first).
        #    The Chinese suffix below ("answer too long, truncated") is kept
        #    verbatim because it is part of the training data.
        answer = eval_result.generation_metrics.generated_answer
        if len(answer) > cfg.MAX_ANSWER_CHARS:
            answer = answer[:cfg.MAX_ANSWER_CHARS] + "\n\n... [回答过长,已截断]"
        # 4. Build the input and check the total length
        #    ("[用户问题]" = user question, "[代码上下文]" = code context)
input_text = f"[用户问题]\n{query}\n\n[代码上下文]\n{context}"
        # If the total length is still over budget, shrink the context further
total_len = len(self.SFT_INSTRUCTION) + len(input_text) + len(answer)
if total_len > cfg.MAX_TOTAL_CHARS:
excess = total_len - cfg.MAX_TOTAL_CHARS
            new_context_len = max(500, len(context) - excess)  # keep at least 500 chars
context = smart_truncate(
eval_result.generation_metrics.retrieved_context,
new_context_len,
keep_ratio=0.7
)
input_text = f"[用户问题]\n{query}\n\n[代码上下文]\n{context}"
return {
"instruction": self.SFT_INSTRUCTION,
"input": input_text,
"output": answer,
"metadata": {
"query": eval_result.query[:200], # metadata 中也截断,节省空间
"repo_url": eval_result.repo_url,
"language": eval_result.language,
"session_id": eval_result.session_id,
"timestamp": eval_result.timestamp.isoformat(),
"quality_tier": eval_result.data_quality_tier.value,
"overall_score": eval_result.overall_score,
"faithfulness": eval_result.generation_metrics.faithfulness,
"answer_relevance": eval_result.generation_metrics.answer_relevance,
"answer_completeness": eval_result.generation_metrics.answer_completeness,
"code_correctness": eval_result.generation_metrics.code_correctness,
"is_negative": negative,
"sft_ready": eval_result.sft_ready,
                # Record original lengths for later analysis
"original_context_len": len(eval_result.generation_metrics.retrieved_context),
"original_answer_len": len(eval_result.generation_metrics.generated_answer),
"truncated": len(eval_result.generation_metrics.retrieved_context) > cfg.MAX_CONTEXT_CHARS
or len(eval_result.generation_metrics.generated_answer) > cfg.MAX_ANSWER_CHARS,
}
}
def _append_jsonl(self, filepath: str, data: Dict) -> None:
"""追加数据到 JSONL 文件"""
with open(filepath, 'a', encoding='utf-8') as f:
f.write(json.dumps(data, ensure_ascii=False) + '\n')
def get_statistics(self) -> Dict[str, int]:
"""获取当前数据统计"""
stats = {}
for name, filepath in [
("positive", self.positive_samples_file),
("negative", self.negative_samples_file),
("dpo_pairs", self.dpo_pairs_file),
]:
if os.path.exists(filepath):
with open(filepath, 'r', encoding='utf-8') as f:
stats[name] = sum(1 for _ in f)
else:
stats[name] = 0
return stats
def get_distribution(self) -> Dict[str, int]:
"""获取评估结果的质量分布"""
distribution = {"gold": 0, "silver": 0, "bronze": 0, "rejected": 0, "corrected": 0}
if not os.path.exists(self.eval_results_file):
return distribution
try:
with open(self.eval_results_file, 'r', encoding='utf-8') as f:
for line in f:
try:
result = json.loads(line)
                        tier = result.get("data_quality_tier", "bronze")  # missing tiers default to bronze
if tier in distribution:
distribution[tier] += 1
except json.JSONDecodeError:
continue
except Exception as e:
print(f"⚠️ Error reading eval results: {e}")
return distribution
def get_bad_samples(self, limit: int = 10) -> List[Dict[str, Any]]:
"""获取低质量样本用于人工审核"""
bad_samples = []
if not os.path.exists(self.eval_results_file):
return bad_samples
try:
with open(self.eval_results_file, 'r', encoding='utf-8') as f:
for line in f:
try:
result = json.loads(line)
if result.get("overall_score", 0) < 0.5:
sample = {
"query": result.get("query", ""),
"score": result.get("overall_score", 0),
"issue": result.get("error_message", "Low quality"),
"quality_tier": result.get("data_quality_tier", "rejected"),
"timestamp": result.get("timestamp", "")
}
if result.get("generation"):
gen = result["generation"]
sample.update({
"faithfulness": gen.get("faithfulness", 0),
"answer_relevance": gen.get("answer_relevance", 0),
"answer_completeness": gen.get("answer_completeness", 0),
})
                            bad_samples.append(sample)
                            # Stop after the first `limit` matches in file order;
                            # the sort below then orders them by ascending score
                            if len(bad_samples) >= limit:
                                break
except json.JSONDecodeError:
continue
except Exception as e:
print(f"⚠️ Error reading bad samples: {e}")
return sorted(bad_samples, key=lambda x: x["score"])[:limit]
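

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module's API). It
# assumes an EvaluationResult instance `eval_result` is produced elsewhere by
# the auto_eval pipeline (see evaluation/models.py); only the statistics calls
# below run without one.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    engine = DataRoutingEngine(output_dir="evaluation/sft_data")

    # tier = engine.route_sample(eval_result)  # e.g. "gold", "silver", "bronze"
    # print(f"Sample routed to tier: {tier}")

    print("Dataset counts:   ", engine.get_statistics())
    print("Tier distribution:", engine.get_distribution())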