#!/usr/bin/env python3
"""
检索系统离线评估脚本
用于测试 chunking 和检索策略的准确率。
使用 golden_dataset.json 中的标注数据作为 ground truth。
使用方法:
python evaluation/test_retrieval.py --repo https://github.com/tiangolo/fastapi
python evaluation/test_retrieval.py --repo https://github.com/tiangolo/fastapi --top-k 5
python evaluation/test_retrieval.py --repo https://github.com/tiangolo/fastapi --verbose
Author: Dexter
Date: 2026-01-28
"""
import json
import os
import sys
import asyncio
import argparse
from typing import List, Dict, Tuple
from dataclasses import dataclass, field
from datetime import datetime
# Add the project root directory to sys.path so app.* imports resolve.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from app.services.vector_service import store_manager
from app.services.github_service import get_repo_structure
@dataclass
class RetrievalTestResult:
    """Result of a single retrieval test case."""
    query: str
    expected_files: List[str]
    retrieved_files: List[str]
    hit: bool  # whether any expected file was retrieved
    recall: float  # recall: hit expected files / total expected files
    precision: float  # precision: hit expected files / number of results retrieved
    reciprocal_rank: float  # reciprocal rank: 1 / position of the first hit
    difficulty: str = ""
    category: str = ""
@dataclass
class EvaluationReport:
    """Complete evaluation report: aggregate metrics plus per-query details."""
    repo_url: str
    top_k: int
    total_queries: int
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
    # Aggregate metrics
    hit_rate: float = 0.0  # fraction of queries that hit at least one expected file
    mean_recall: float = 0.0  # average recall across evaluated queries
    mean_precision: float = 0.0  # average precision across evaluated queries
    mrr: float = 0.0  # Mean Reciprocal Rank
    # Metrics grouped by difficulty level
    by_difficulty: Dict[str, Dict] = field(default_factory=dict)
    # Detailed per-query results
    results: List[RetrievalTestResult] = field(default_factory=list)
    failed_cases: List[Dict] = field(default_factory=list)
class RetrievalEvaluator:
    """Evaluates the retrieval pipeline against the golden dataset."""

    def __init__(self, golden_dataset_path: str = "evaluation/golden_dataset.json"):
        """Load the golden dataset and announce how many cases it holds."""
        self.golden_dataset = self._load_golden_dataset(golden_dataset_path)
        print(f"📊 Loaded {len(self.golden_dataset)} test cases from golden dataset")

    def _load_golden_dataset(self, path: str) -> List[Dict]:
        """Read the golden dataset JSON file, failing fast when it is absent."""
        if os.path.exists(path):
            with open(path, encoding='utf-8') as handle:
                return json.load(handle)
        raise FileNotFoundError(f"Golden dataset not found: {path}")
    async def evaluate(
        self,
        repo_url: str,
        session_id: str = "eval_test",
        top_k: int = 5,
        verbose: bool = False
    ) -> EvaluationReport:
        """
        Run the full retrieval evaluation.

        Args:
            repo_url: URL of the repository being evaluated.
            session_id: Session ID used to look up the vector store.
            top_k: Number of files returned per retrieval.
            verbose: Whether to print per-query details.

        Returns:
            An EvaluationReport with aggregate and per-query metrics, or
            None when the vector store has not been populated yet.
        """
        print(f"\n{'='*60}")
        print(f"🔍 Retrieval Evaluation")
        print(f"{'='*60}")
        print(f"Repository: {repo_url}")
        print(f"Top-K: {top_k}")
        print(f"Test Cases: {len(self.golden_dataset)}")
        print(f"{'='*60}\n")
        # Fetch the repository file list.
        print("📂 Fetching repository structure...")
        file_list = get_repo_structure(repo_url)  # synchronous function, no await needed
        print(f" Found {len(file_list)} files")
        # Grab the vector store for this session; an empty store means the
        # repository has not been indexed yet, so there is nothing to evaluate.
        store = store_manager.get_store(session_id)
        chunk_count = store.collection.count()  # use collection.count()
        if chunk_count == 0:
            print("\n⚠️ Vector store is empty!")
            print(" Please run the agent first to index the repository.")
            print(" Example: Access http://localhost:8000 and analyze the repo first.")
            return None
        print(f" Vector store has {chunk_count} chunks")
        # Run the evaluation loop, accumulating per-query metrics as we go.
        report = EvaluationReport(
            repo_url=repo_url,
            top_k=top_k,
            total_queries=len(self.golden_dataset)
        )
        hits = 0
        recalls = []
        precisions = []
        reciprocal_ranks = []
        difficulty_stats = {}
        for i, sample in enumerate(self.golden_dataset):
            query = sample.get("query", "")
            expected_files = sample.get("expected_files", [])
            difficulty = sample.get("difficulty", "medium")
            category = sample.get("category", "general")
            # Skip malformed samples. NOTE(review): skipped and errored samples
            # still count in the hit-rate denominator below — confirm intended.
            if not query or not expected_files:
                continue
            # Execute the retrieval (hybrid search).
            try:
                results = await store.search_hybrid(query, top_k=top_k)
            except Exception as e:
                if verbose:
                    print(f" [ERR] Search failed: {e}")
                continue
            # Extract the retrieved file paths (deduplicated, order preserved).
            retrieved_files = []
            for doc in results:
                if isinstance(doc, dict):
                    file_path = doc.get("file", "")
                    if file_path and file_path not in retrieved_files:
                        retrieved_files.append(file_path)
            # Compute metrics for this query.
            expected_set = set(expected_files)
            retrieved_set = set(retrieved_files[:top_k])
            # Files that were both expected and retrieved.
            hits_set = expected_set & retrieved_set
            # Hit: did we retrieve at least one expected file?
            hit = len(hits_set) > 0
            if hit:
                hits += 1
            # Recall: hit expected files / all expected files.
            recall = len(hits_set) / len(expected_set) if expected_set else 0
            recalls.append(recall)
            # Precision: hit expected files / number of files retrieved.
            precision = len(hits_set) / min(len(retrieved_files), top_k) if retrieved_files else 0
            precisions.append(precision)
            # Reciprocal rank: 1 / position of the first hit (0.0 if no hit).
            rr = 0.0
            for rank, file in enumerate(retrieved_files[:top_k], 1):
                if file in expected_set:
                    rr = 1.0 / rank
                    break
            reciprocal_ranks.append(rr)
            # Record the per-query result.
            result = RetrievalTestResult(
                query=query,
                expected_files=expected_files,
                retrieved_files=retrieved_files[:top_k],
                hit=hit,
                recall=recall,
                precision=precision,
                reciprocal_rank=rr,
                difficulty=difficulty,
                category=category
            )
            report.results.append(result)
            # Aggregate statistics per difficulty level.
            if difficulty not in difficulty_stats:
                difficulty_stats[difficulty] = {"hits": 0, "total": 0, "recalls": [], "precisions": []}
            difficulty_stats[difficulty]["total"] += 1
            if hit:
                difficulty_stats[difficulty]["hits"] += 1
            difficulty_stats[difficulty]["recalls"].append(recall)
            difficulty_stats[difficulty]["precisions"].append(precision)
            # Record failed cases for later inspection in the report.
            if not hit:
                report.failed_cases.append({
                    "query": query,
                    "expected": expected_files,
                    "retrieved": retrieved_files[:top_k],
                    "difficulty": difficulty
                })
            # Print progress.
            if verbose:
                status = "✅" if hit else "❌"
                print(f" [{i+1:3d}] {status} Recall={recall:.2f} | {query[:50]}...")
            else:
                print(f"\r Progress: {i+1}/{len(self.golden_dataset)}", end="")
        print("\n")
        # Compute aggregate metrics over the whole dataset.
        report.hit_rate = hits / len(self.golden_dataset) if self.golden_dataset else 0
        report.mean_recall = sum(recalls) / len(recalls) if recalls else 0
        report.mean_precision = sum(precisions) / len(precisions) if precisions else 0
        report.mrr = sum(reciprocal_ranks) / len(reciprocal_ranks) if reciprocal_ranks else 0
        # Summarize per difficulty level.
        for diff, stats in difficulty_stats.items():
            report.by_difficulty[diff] = {
                "hit_rate": stats["hits"] / stats["total"] if stats["total"] else 0,
                "mean_recall": sum(stats["recalls"]) / len(stats["recalls"]) if stats["recalls"] else 0,
                "mean_precision": sum(stats["precisions"]) / len(stats["precisions"]) if stats["precisions"] else 0,
                "total": stats["total"]
            }
        return report
    def print_report(self, report: EvaluationReport):
        """Print the evaluation report to stdout.

        The exact line format is the script's user-facing output, so every
        string is emitted unchanged.
        """
        print(f"\n{'='*60}")
        print(f"📊 RETRIEVAL EVALUATION REPORT")
        print(f"{'='*60}")
        print(f"Repository: {report.repo_url}")
        print(f"Top-K: {report.top_k}")
        print(f"Total Queries: {report.total_queries}")
        print(f"Timestamp: {report.timestamp}")
        print(f"{'='*60}\n")
        print("📈 OVERALL METRICS")
        print(f" Hit Rate: {report.hit_rate:.1%}")
        print(f" Mean Recall: {report.mean_recall:.1%}")
        print(f" Mean Precision: {report.mean_precision:.1%}")
        print(f" MRR: {report.mrr:.3f}")
        print(f"\n📊 BY DIFFICULTY")
        # Sort difficulty buckets alphabetically for a stable display order.
        for diff, stats in sorted(report.by_difficulty.items()):
            print(f" {diff.upper():8s} | Hit: {stats['hit_rate']:.1%} | Recall: {stats['mean_recall']:.1%} | n={stats['total']}")
        if report.failed_cases:
            print(f"\n❌ FAILED CASES ({len(report.failed_cases)} total)")
            for case in report.failed_cases[:5]:  # show only the first 5
                print(f" Query: {case['query'][:60]}...")
                print(f" Expected: {case['expected']}")
                print(f" Got: {case['retrieved'][:3]}...")
                print()
        print(f"{'='*60}")
def save_report(self, report: EvaluationReport, output_path: str = "evaluation/retrieval_report.json"):
"""保存报告到文件"""
os.makedirs(os.path.dirname(output_path), exist_ok=True)
# 转换为可序列化格式
data = {
"repo_url": report.repo_url,
"top_k": report.top_k,
"total_queries": report.total_queries,
"timestamp": report.timestamp,
"metrics": {
"hit_rate": report.hit_rate,
"mean_recall": report.mean_recall,
"mean_precision": report.mean_precision,
"mrr": report.mrr
},
"by_difficulty": report.by_difficulty,
"failed_cases": report.failed_cases
}
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(data, f, ensure_ascii=False, indent=2)
print(f"\n💾 Report saved to: {output_path}")
async def main():
    """CLI entry point: parse arguments, run the evaluation, report results."""
    parser = argparse.ArgumentParser(description="Evaluate retrieval system using golden dataset")
    parser.add_argument("--repo", required=True, help="GitHub repository URL to evaluate")
    parser.add_argument("--top-k", type=int, default=5, help="Number of results to retrieve (default: 5)")
    parser.add_argument("--session", default="eval_test", help="Session ID for vector store")
    parser.add_argument("--verbose", "-v", action="store_true", help="Print detailed results")
    parser.add_argument("--save", action="store_true", help="Save report to file")
    args = parser.parse_args()

    runner = RetrievalEvaluator()
    outcome = await runner.evaluate(
        repo_url=args.repo,
        session_id=args.session,
        top_k=args.top_k,
        verbose=args.verbose,
    )
    # evaluate() returns None when the vector store is empty — nothing to show.
    if outcome is not None:
        runner.print_report(outcome)
        if args.save:
            runner.save_report(outcome)
# Script entry point: drive the async evaluation loop via asyncio.
if __name__ == "__main__":
    asyncio.run(main())
|