File size: 12,456 Bytes
1ea875f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
#!/usr/bin/env python3
"""
检索系统离线评估脚本

用于测试 chunking 和检索策略的准确率。
使用 golden_dataset.json 中的标注数据作为 ground truth。

使用方法:
    python evaluation/test_retrieval.py --repo https://github.com/tiangolo/fastapi
    python evaluation/test_retrieval.py --repo https://github.com/tiangolo/fastapi --top-k 5
    python evaluation/test_retrieval.py --repo https://github.com/tiangolo/fastapi --verbose

Author: Dexter
Date: 2026-01-28
"""

import argparse
import asyncio
import json
import os
import sys
from dataclasses import dataclass, field
from datetime import datetime
from typing import Dict, List, Optional, Tuple

# 添加项目根目录到 path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from app.services.vector_service import store_manager
from app.services.github_service import get_repo_structure


@dataclass
class RetrievalTestResult:
    """Result of a single golden-dataset test case."""
    query: str
    expected_files: List[str]
    retrieved_files: List[str]
    hit: bool                      # whether at least one expected file was retrieved
    recall: float                  # recall: matched expected files / total expected files
    precision: float               # precision: matched expected files / number of retrieved results
    reciprocal_rank: float         # reciprocal rank: 1 / position of the first match
    difficulty: str = ""
    category: str = ""


@dataclass
class EvaluationReport:
    """Complete evaluation report: run parameters, aggregate metrics and per-query details."""
    repo_url: str
    top_k: int
    total_queries: int
    # Defaults to the local wall-clock time at construction, in ISO 8601 format.
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
    
    # Aggregate metrics
    hit_rate: float = 0.0          # hit rate: share of queries with at least one match
    mean_recall: float = 0.0       # mean recall
    mean_precision: float = 0.0    # mean precision
    mrr: float = 0.0               # Mean Reciprocal Rank
    
    # Metrics grouped by difficulty label
    by_difficulty: Dict[str, Dict] = field(default_factory=dict)
    
    # Detailed results
    results: List[RetrievalTestResult] = field(default_factory=list)
    failed_cases: List[Dict] = field(default_factory=list)


class RetrievalEvaluator:
    """Score retrieval results against a golden dataset of labelled queries.

    Each dataset entry is read as a dict with the keys ``query``,
    ``expected_files``, ``difficulty`` and ``category``. Entries without a
    query or without expected files are excluded from the metrics.
    """

    def __init__(self, golden_dataset_path: str = "evaluation/golden_dataset.json"):
        """Load the golden dataset from *golden_dataset_path*.

        Raises:
            FileNotFoundError: If the dataset file does not exist.
        """
        self.golden_dataset = self._load_golden_dataset(golden_dataset_path)
        print(f"📊 Loaded {len(self.golden_dataset)} test cases from golden dataset")

    def _load_golden_dataset(self, path: str) -> List[Dict]:
        """Read and return the JSON content of the golden dataset at *path*.

        Raises:
            FileNotFoundError: If *path* does not exist.
        """
        if not os.path.exists(path):
            raise FileNotFoundError(f"Golden dataset not found: {path}")

        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)

    @staticmethod
    def _unique_files(results) -> List[str]:
        """Return the distinct non-empty ``"file"`` values of dict results, in rank order."""
        seen = set()
        files = []
        for doc in results:
            # Non-dict results carry no file path and are ignored.
            if not isinstance(doc, dict):
                continue
            path = doc.get("file", "")
            if path and path not in seen:
                seen.add(path)
                files.append(path)
        return files

    @staticmethod
    def _score_query(
        expected_files: List[str],
        retrieved_files: List[str],
        top_k: int
    ) -> Tuple[bool, float, float, float]:
        """Compute ``(hit, recall, precision, reciprocal_rank)`` for one query.

        Only the first *top_k* entries of *retrieved_files* are considered.
        Every metric is 0 when nothing is considered, so ``top_k <= 0`` does
        not divide by zero.
        """
        expected = set(expected_files)
        considered = retrieved_files[:top_k] if top_k > 0 else []
        matched = expected.intersection(considered)

        hit = bool(matched)
        recall = len(matched) / len(expected) if expected else 0.0
        precision = len(matched) / len(considered) if considered else 0.0
        # Reciprocal rank uses the 1-based position of the first expected file.
        reciprocal_rank = next(
            (1.0 / rank for rank, path in enumerate(considered, 1) if path in expected),
            0.0,
        )
        return hit, recall, precision, reciprocal_rank

    @staticmethod
    def _mean(values: List[float]) -> float:
        """Return the arithmetic mean of *values*, or 0.0 for an empty list."""
        return sum(values) / len(values) if values else 0.0

    async def evaluate(
        self,
        repo_url: str,
        session_id: str = "eval_test",
        top_k: int = 5,
        verbose: bool = False
    ) -> "Optional[EvaluationReport]":
        """
        Run the full retrieval evaluation.

        Args:
            repo_url: URL of the repository to evaluate.
            session_id: Session ID used to look up the vector store.
            top_k: Number of files requested per search.
            verbose: Whether to print one line per test case.

        Returns:
            The populated report, or ``None`` when the vector store for
            *session_id* is empty.

        All aggregate metrics are averaged over the queries that were actually
        evaluated. Incomplete test cases and failed searches are excluded from
        every metric and summarised on stdout.
        """
        print(f"\n{'='*60}")
        print(f"🔍 Retrieval Evaluation")
        print(f"{'='*60}")
        print(f"Repository: {repo_url}")
        print(f"Top-K: {top_k}")
        print(f"Test Cases: {len(self.golden_dataset)}")
        print(f"{'='*60}\n")

        # Fetch the repository file list (synchronous call, no await needed).
        print("📂 Fetching repository structure...")
        file_list = get_repo_structure(repo_url)
        print(f"   Found {len(file_list)} files")

        # Look up the vector store; nothing can be evaluated while it is empty.
        store = store_manager.get_store(session_id)
        chunk_count = store.collection.count()
        if chunk_count == 0:
            print("\n⚠️  Vector store is empty!")
            print("   Please run the agent first to index the repository.")
            print("   Example: Access http://localhost:8000 and analyze the repo first.")
            return None
        print(f"   Vector store has {chunk_count} chunks")

        total_cases = len(self.golden_dataset)
        report = EvaluationReport(
            repo_url=repo_url,
            top_k=top_k,
            total_queries=total_cases
        )

        hits = 0
        recalls = []
        precisions = []
        reciprocal_ranks = []
        difficulty_stats: Dict[str, Dict] = {}
        skipped_incomplete = 0
        search_failures = 0

        for i, sample in enumerate(self.golden_dataset):
            query = sample.get("query", "")
            expected_files = sample.get("expected_files", [])
            difficulty = sample.get("difficulty", "medium")
            category = sample.get("category", "general")

            if not query or not expected_files:
                skipped_incomplete += 1
                continue

            # Run the search (hybrid search).
            try:
                results = await store.search_hybrid(query, top_k=top_k)
            except Exception as e:
                # Best effort: one failing query must not abort the run, but
                # the failure is counted so it is reported after the loop.
                search_failures += 1
                if verbose:
                    print(f"  [ERR] Search failed: {e}")
                continue

            retrieved_files = self._unique_files(results)
            top_files = retrieved_files[:top_k]
            hit, recall, precision, rr = self._score_query(
                expected_files, retrieved_files, top_k
            )

            if hit:
                hits += 1
            recalls.append(recall)
            precisions.append(precision)
            reciprocal_ranks.append(rr)

            report.results.append(RetrievalTestResult(
                query=query,
                expected_files=expected_files,
                retrieved_files=top_files,
                hit=hit,
                recall=recall,
                precision=precision,
                reciprocal_rank=rr,
                difficulty=difficulty,
                category=category
            ))

            # Per-difficulty accumulation.
            stats = difficulty_stats.setdefault(
                difficulty,
                {"hits": 0, "total": 0, "recalls": [], "precisions": []},
            )
            stats["total"] += 1
            if hit:
                stats["hits"] += 1
            stats["recalls"].append(recall)
            stats["precisions"].append(precision)

            # Keep the misses for the failed-case section of the report.
            if not hit:
                report.failed_cases.append({
                    "query": query,
                    "expected": expected_files,
                    "retrieved": top_files,
                    "difficulty": difficulty
                })

            # Progress output.
            if verbose:
                status = "✅" if hit else "❌"
                print(f"  [{i+1:3d}] {status} Recall={recall:.2f} | {query[:50]}...")
            else:
                print(f"\r  Progress: {i+1}/{total_cases}", end="")

        print("\n")

        if skipped_incomplete or search_failures:
            print(
                f"⚠️  Excluded from metrics: {skipped_incomplete} incomplete test case(s), "
                f"{search_failures} failed search(es)"
            )

        # Aggregate metrics. The hit rate uses the same denominator (evaluated
        # queries) as recall, precision, MRR and the per-difficulty figures.
        evaluated = len(report.results)
        report.hit_rate = hits / evaluated if evaluated else 0.0
        report.mean_recall = self._mean(recalls)
        report.mean_precision = self._mean(precisions)
        report.mrr = self._mean(reciprocal_ranks)

        # Per-difficulty summary.
        for diff, stats in difficulty_stats.items():
            report.by_difficulty[diff] = {
                "hit_rate": stats["hits"] / stats["total"] if stats["total"] else 0,
                "mean_recall": self._mean(stats["recalls"]),
                "mean_precision": self._mean(stats["precisions"]),
                "total": stats["total"]
            }

        return report

    def print_report(self, report: "EvaluationReport"):
        """Print the report to stdout: header, overall metrics, per-difficulty
        breakdown and up to five failed cases."""
        print(f"\n{'='*60}")
        print(f"📊 RETRIEVAL EVALUATION REPORT")
        print(f"{'='*60}")
        print(f"Repository: {report.repo_url}")
        print(f"Top-K: {report.top_k}")
        print(f"Total Queries: {report.total_queries}")
        print(f"Timestamp: {report.timestamp}")
        print(f"{'='*60}\n")

        print("📈 OVERALL METRICS")
        print(f"   Hit Rate:       {report.hit_rate:.1%}")
        print(f"   Mean Recall:    {report.mean_recall:.1%}")
        print(f"   Mean Precision: {report.mean_precision:.1%}")
        print(f"   MRR:            {report.mrr:.3f}")

        print(f"\n📊 BY DIFFICULTY")
        for diff, stats in sorted(report.by_difficulty.items()):
            print(f"   {diff.upper():8s} | Hit: {stats['hit_rate']:.1%} | Recall: {stats['mean_recall']:.1%} | n={stats['total']}")

        if report.failed_cases:
            print(f"\n❌ FAILED CASES ({len(report.failed_cases)} total)")
            for case in report.failed_cases[:5]:  # show only the first five
                print(f"   Query: {case['query'][:60]}...")
                print(f"   Expected: {case['expected']}")
                print(f"   Got: {case['retrieved'][:3]}...")
                print()

        print(f"{'='*60}")

    def save_report(self, report: "EvaluationReport", output_path: str = "evaluation/retrieval_report.json"):
        """Write the report's parameters, metrics, per-difficulty breakdown and
        failed cases to *output_path* as UTF-8 JSON.

        Missing parent directories are created. A bare filename (no directory
        component) is written to the current working directory.
        """
        # os.makedirs("") raises, so only create a directory when there is one.
        parent_dir = os.path.dirname(output_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        # Convert to a JSON-serialisable structure.
        data = {
            "repo_url": report.repo_url,
            "top_k": report.top_k,
            "total_queries": report.total_queries,
            "timestamp": report.timestamp,
            "metrics": {
                "hit_rate": report.hit_rate,
                "mean_recall": report.mean_recall,
                "mean_precision": report.mean_precision,
                "mrr": report.mrr
            },
            "by_difficulty": report.by_difficulty,
            "failed_cases": report.failed_cases
        }

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

        print(f"\n💾 Report saved to: {output_path}")


async def main():
    """Parse the command line, run the evaluation, then print and optionally save the report."""
    cli = argparse.ArgumentParser(description="Evaluate retrieval system using golden dataset")
    cli.add_argument("--repo", required=True, help="GitHub repository URL to evaluate")
    cli.add_argument("--top-k", type=int, default=5, help="Number of results to retrieve (default: 5)")
    cli.add_argument("--session", default="eval_test", help="Session ID for vector store")
    cli.add_argument("--verbose", "-v", action="store_true", help="Print detailed results")
    cli.add_argument("--save", action="store_true", help="Save report to file")
    options = cli.parse_args()

    runner = RetrievalEvaluator()
    outcome = await runner.evaluate(
        repo_url=options.repo,
        session_id=options.session,
        top_k=options.top_k,
        verbose=options.verbose,
    )

    # A falsy result means there is nothing to report.
    if not outcome:
        return

    runner.print_report(outcome)
    if options.save:
        runner.save_report(outcome)


# Script entry point: run the async CLI when the file is executed directly.
if __name__ == "__main__":
    asyncio.run(main())