lanny xu committed on
Commit
5ad083c
·
1 Parent(s): 3f73db0

delete vectara

Browse files
Files changed (4) hide show
  1. evaluate_retrieval.py +346 -0
  2. main.py +32 -4
  3. retrieval_evaluation.py +674 -0
  4. workflow_nodes.py +88 -2
evaluate_retrieval.py ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 自适应RAG系统检索效果评估脚本
3
+ 评估不同检索策略和配置的效果
4
+ """
5
+
6
+ import os
7
+ import sys
8
+ import time
9
+ import json
10
+ import argparse
11
+ from typing import List, Dict, Any, Optional
12
+ from dotenv import load_dotenv
13
+
14
+ # 加载环境变量
15
+ load_dotenv()
16
+
17
+ # 导入项目模块
18
+ from main import AdaptiveRAGSystem
19
+ from document_processor import DocumentProcessor
20
+ from retrieval_evaluation import RetrievalEvaluator, RetrievalResult, RetrievalTestSet
21
+ from langchain.schema import Document
22
+
23
+ # 导入LangChain相关模块
24
+ from langchain_community.vectorstores import FAISS, Chroma
25
+ from langchain_community.retrievers import BM25Retriever
26
+ from langchain.retrievers import EnsembleRetriever
27
+ from langchain.retrievers import ContextualCompressionRetriever
28
+ from langchain.retrievers.document_compressors import LLMChainExtractor
29
+
30
+
31
class AdaptiveRAGRetriever:
    """Adapter that exposes an ``AdaptiveRAGSystem`` through ``retrieve(query, top_k)``."""

    # Per-type config override: retriever type -> (config key, value to set).
    # Types that are not listed (such as "default") leave the config as given.
    _CONFIG_OVERRIDES = {
        "vector_only": ("retrieval_strategy", "vector"),
        "bm25_only": ("retrieval_strategy", "bm25"),
        "hybrid": ("retrieval_strategy", "hybrid"),
        "graph": ("retrieval_strategy", "graph"),
        "compression": ("use_compression", True),
        "rerank": ("use_reranking", True),
        "query_expansion": ("use_query_expansion", True),
    }

    def __init__(self, system_config: Dict[str, Any], retriever_type: str = "default"):
        """
        Store the settings and build the underlying RAG system right away.

        Args:
            system_config: Base system configuration; it is shallow-copied
                before any per-type override is applied.
            retriever_type: Name of the retrieval variant to configure.
        """
        self.system_config = system_config
        self.retriever_type = retriever_type
        self.system = None
        self._initialize_system()

    def _initialize_system(self):
        """Create ``self.system`` from the (possibly overridden) configuration.

        Any failure is printed and then re-raised to the caller.
        """
        try:
            config = self.system_config.copy()

            override = self._CONFIG_OVERRIDES.get(self.retriever_type)
            if override is not None:
                option, setting = override
                config[option] = setting

            self.system = AdaptiveRAGSystem(config)

            # Attach a document processor when the system has none of its own.
            if getattr(self.system, "document_processor", None) is None:
                self.system.document_processor = DocumentProcessor(config)

        except Exception as e:
            print(f"初始化RAG系统失败: {e}")
            raise

    def retrieve(self, query: str, top_k: int = 10) -> List[Document]:
        """
        Fetch documents for a query.

        The system's own ``retrieve`` is used when it has one; otherwise the
        call is forwarded to its document processor.

        Args:
            query: Query text.
            top_k: Maximum number of documents to return.

        Returns:
            At most ``top_k`` documents, or an empty list when retrieval fails
            (the error is printed, not raised).
        """
        try:
            backend = self.system
            if hasattr(backend, "retrieve"):
                found = backend.retrieve(query, top_k)
            elif backend.document_processor:
                found = backend.document_processor.retrieve(query, top_k)
            else:
                raise ValueError("无法找到检索方法")
            return found[:top_k]
        except Exception as e:
            print(f"检索失败: {e}")
            return []
105
+
106
+
107
def create_evaluation_dataset(data_dir: str = "data", num_queries: int = 20) -> RetrievalTestSet:
    """
    Build a retrieval test set from the ``.txt`` / ``.md`` files under ``data_dir``.

    Each non-empty file becomes one document. For up to ``num_queries``
    documents, the first sentence longer than 10 characters becomes a
    synthetic query that is judged highly relevant (2) to its own document and
    partially relevant (1) to the other documents among the first three.
    Queries, documents and judgments are written to ``eval_queries.txt``,
    ``eval_documents.txt`` and ``eval_qrels.csv`` in the working directory.

    The three files are line-oriented (one query / one document per line, ids
    are line numbers), so whitespace inside documents and queries is collapsed
    to single spaces, and judgments are keyed by the query's position in the
    query file rather than by the position of its source document.

    Args:
        data_dir: Directory searched recursively for documents.
        num_queries: Maximum number of documents to derive queries from.

    Returns:
        The test set loaded from the generated files, or the built-in sample
        test set when the directory is missing, contains no documents, or any
        step fails.
    """
    import re  # local import: only this function needs it

    def _sample_test_set() -> RetrievalTestSet:
        # Imported lazily so the sample builder is only loaded on fallback.
        from retrieval_evaluation import create_sample_test_set
        return create_sample_test_set()

    if not os.path.exists(data_dir):
        print(f"数据目录 {data_dir} 不存在,创建示例数据集")
        return _sample_test_set()

    try:
        doc_files = [
            os.path.join(root, name)
            for root, _dirs, names in os.walk(data_dir)
            for name in names
            if name.endswith((".txt", ".md"))
        ]

        if not doc_files:
            print(f"在 {data_dir} 中未找到文档文件,创建示例数据集")
            return _sample_test_set()

        documents = []
        for file_path in doc_files:
            with open(file_path, "r", encoding="utf-8") as f:
                # Collapse all whitespace so the document occupies exactly one
                # line in eval_documents.txt.
                content = " ".join(f.read().split())
            if content:
                # doc_id is the position in `documents`, which is also the
                # line number the loader will assign.
                documents.append(Document(
                    page_content=content,
                    metadata={"source": file_path, "doc_id": str(len(documents))},
                ))

        # Synthetic queries: a simplification, real evaluations should use
        # real user queries.
        queries = []
        qrels = {}
        for doc_index, doc in enumerate(documents[:max(num_queries, 0)]):
            # Split on both ASCII and CJK full stops.
            for sentence in re.split(r"[.。]", doc.page_content):
                sentence = sentence.strip()
                if len(sentence) <= 10:  # too short to be a useful query
                    continue
                # Key judgments by the query's own line number; documents that
                # yield no query must not shift the ids of later queries.
                query_id = str(len(queries))
                queries.append(sentence)
                judgments = {str(doc_index): 2}  # highly relevant
                for other in range(min(3, len(documents))):
                    if other != doc_index:
                        judgments[str(other)] = 1  # partially relevant
                qrels[query_id] = judgments
                break

        with open("eval_queries.txt", "w", encoding="utf-8") as f:
            f.writelines(query + "\n" for query in queries)

        with open("eval_documents.txt", "w", encoding="utf-8") as f:
            f.writelines(doc.page_content + "\n" for doc in documents)

        with open("eval_qrels.csv", "w", encoding="utf-8") as f:
            for query_id, doc_relevance in qrels.items():
                for doc_id, relevance in doc_relevance.items():
                    f.write(f"{query_id},{doc_id},{relevance}\n")

        print("评估数据集已创建:")
        print(f"- 查询数量: {len(queries)}")
        print(f"- 文档数量: {len(documents)}")
        print("- eval_queries.txt: 查询文件")
        print("- eval_documents.txt: 文档文件")
        print("- eval_qrels.csv: 相关性标注文件")

        return RetrievalTestSet("eval_queries.txt", "eval_documents.txt", "eval_qrels.csv")

    except Exception as e:
        print(f"创建评估数据集失败: {e}")
        print("创建示例数据集")
        return _sample_test_set()
201
+
202
+
203
def evaluate_retrievers(system_config: Dict[str, Any],
                        retriever_types: List[str],
                        test_set: RetrievalTestSet,
                        output_dir: str = "evaluation_results") -> Dict[str, Any]:
    """
    Run every requested retriever over the test set and compare the outcomes.

    Args:
        system_config: Base system configuration handed to each retriever.
        retriever_types: Retriever variants to evaluate.
        test_set: Test set providing queries and relevance judgments.
        output_dir: Directory for the report, the chart and ``metrics.json``
            (created when missing).

    Returns:
        When more than one retriever produced results: a dict with
        ``metrics``, ``metrics_data``, ``report`` and ``results``.
        Otherwise a dict holding only ``results``.
    """
    os.makedirs(output_dir, exist_ok=True)

    evaluator = RetrievalEvaluator()
    divider = "=" * 50

    # Raw retrieval results per retriever type; failed retrievers are left out.
    results_by_type = {}
    for kind in retriever_types:
        print(f"\n评估检索器: {kind}")
        print(divider)

        try:
            wrapper = AdaptiveRAGRetriever(system_config, kind)
            outcome = test_set.get_retrieval_results(wrapper)
            results_by_type[kind] = outcome
            print(f"完成 {len(outcome)} 个查询的检索")
        except Exception as e:
            print(f"评估检索器 {kind} 失败: {e}")

    # A comparison needs at least two successful retrievers.
    if len(results_by_type) <= 1:
        print("只有一个检索器成功评估,跳过比较")
        return {"results": results_by_type}

    print("\n比较检索器性能")
    print(divider)
    metrics = evaluator.compare_retrievers(results_by_type)

    report_path = os.path.join(output_dir, "retrieval_evaluation_report.md")
    report = evaluator.generate_report(metrics, report_path)

    chart_path = os.path.join(output_dir, "retrieval_evaluation_comparison.png")
    evaluator.plot_metrics_comparison(metrics, chart_path)

    # Plain-dict copy of the metric objects for JSON export.
    exported_fields = (
        "precision_at_k",
        "recall_at_k",
        "f1_at_k",
        "map_score",
        "mrr",
        "ndcg_at_k",
        "coverage",
        "diversity",
        "novelty",
        "latency",
    )
    metrics_data = {
        name: {field: getattr(metric, field) for field in exported_fields}
        for name, metric in metrics.items()
    }

    with open(os.path.join(output_dir, "metrics.json"), "w", encoding="utf-8") as f:
        json.dump(metrics_data, f, indent=2, ensure_ascii=False)

    return {
        "metrics": metrics,
        "metrics_data": metrics_data,
        "report": report,
        "results": results_by_type,
    }
293
+
294
+
295
def main():
    """Command-line entry point: parse options, load the config, run the evaluation."""
    parser = argparse.ArgumentParser(description="评估自适应RAG系统的检索效果")
    simple_options = (
        ("--config", str, "config.py", "配置文件路径"),
        ("--data_dir", str, "data", "数据目录"),
        ("--output_dir", str, "evaluation_results", "输出目录"),
        ("--num_queries", int, 20, "查询数量"),
    )
    for flag, value_type, default, help_text in simple_options:
        parser.add_argument(flag, type=value_type, default=default, help=help_text)
    parser.add_argument("--retrievers", nargs="+",
                        default=["default", "vector_only", "bm25_only", "hybrid"],
                        help="要评估的检索器类型")

    args = parser.parse_args()

    # Load the configuration; fall back to built-in defaults on any failure.
    try:
        if not args.config.endswith('.py'):
            # Anything that is not a Python file is read as JSON.
            with open(args.config, 'r', encoding='utf-8') as f:
                system_config = json.load(f)
        else:
            # Execute the Python config file and read its `config` attribute.
            import importlib.util
            spec = importlib.util.spec_from_file_location("config", args.config)
            config_module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(config_module)
            system_config = config_module.config
    except Exception as e:
        print(f"加载配置文件失败: {e}")
        print("使用默认配置")
        system_config = {
            "model_name": "gpt-3.5-turbo",
            "vector_store": "faiss",
            "retrieval_strategy": "hybrid",
            "use_reranking": False,
            "use_compression": False,
            "use_query_expansion": False,
        }

    print("创建评估数据集")
    test_set = create_evaluation_dataset(args.data_dir, args.num_queries)

    print("\n开始评估检索器")
    evaluate_retrievers(system_config, args.retrievers, test_set, args.output_dir)

    print("\n评估完成!")
    print(f"结果保存在: {args.output_dir}")
343
+
344
+
345
+ if __name__ == "__main__":
346
+ main()
main.py CHANGED
@@ -110,13 +110,14 @@ class AdaptiveRAGSystem:
110
  verbose (bool): 是否显示详细输出
111
 
112
  Returns:
113
- str: 最终答案
114
  """
115
  print(f"\n🔍 处理问题: {question}")
116
  print("=" * 50)
117
 
118
  inputs = {"question": question, "retry_count": 0} # 初始化重试计数器
119
  final_generation = None
 
120
 
121
  # 设置配置,增加递归限制
122
  config = {"recursion_limit": 50} # 增加到 50,默认是 25
@@ -128,6 +129,9 @@ class AdaptiveRAGSystem:
128
  # 可选:在每个节点打印完整状态
129
  # pprint(value, indent=2, width=80, depth=None)
130
  final_generation = value.get("generation", final_generation)
 
 
 
131
  if verbose:
132
  pprint("\n---\n")
133
 
@@ -136,7 +140,11 @@ class AdaptiveRAGSystem:
136
  print(final_generation)
137
  print("=" * 50)
138
 
139
- return final_generation
 
 
 
 
140
 
141
  def interactive_mode(self):
142
  """交互模式,允许用户持续提问"""
@@ -156,7 +164,17 @@ class AdaptiveRAGSystem:
156
  print("⚠️ 请输入一个有效的问题")
157
  continue
158
 
159
- self.query(question)
 
 
 
 
 
 
 
 
 
 
160
 
161
  except KeyboardInterrupt:
162
  print("\n👋 感谢使用,再见!")
@@ -175,7 +193,17 @@ def main():
175
  # 测试查询
176
  test_question = "AlphaCodium论文讲的是什么?"
177
  # test_question = "解释embedding嵌入的原理,最好列举实现过程的具体步骤"
178
- rag_system.query(test_question)
 
 
 
 
 
 
 
 
 
 
179
 
180
  # 启动交互模式
181
  rag_system.interactive_mode()
 
110
  verbose (bool): 是否显示详细输出
111
 
112
  Returns:
113
+ dict: 包含最终答案和评估指标的字典
114
  """
115
  print(f"\n🔍 处理问题: {question}")
116
  print("=" * 50)
117
 
118
  inputs = {"question": question, "retry_count": 0} # 初始化重试计数器
119
  final_generation = None
120
+ retrieval_metrics = None
121
 
122
  # 设置配置,增加递归限制
123
  config = {"recursion_limit": 50} # 增加到 50,默认是 25
 
129
  # 可选:在每个节点打印完整状态
130
  # pprint(value, indent=2, width=80, depth=None)
131
  final_generation = value.get("generation", final_generation)
132
+ # 保存检索评估指标
133
+ if "retrieval_metrics" in value:
134
+ retrieval_metrics = value["retrieval_metrics"]
135
  if verbose:
136
  pprint("\n---\n")
137
 
 
140
  print(final_generation)
141
  print("=" * 50)
142
 
143
+ # 返回包含答案和评估指标的字典
144
+ return {
145
+ "answer": final_generation,
146
+ "retrieval_metrics": retrieval_metrics
147
+ }
148
 
149
  def interactive_mode(self):
150
  """交互模式,允许用户持续提问"""
 
164
  print("⚠️ 请输入一个有效的问题")
165
  continue
166
 
167
+ result = self.query(question)
168
+
169
+ # 显示检索评估摘要
170
+ if result.get("retrieval_metrics"):
171
+ metrics = result["retrieval_metrics"]
172
+ print("\n📊 检索评估摘要:")
173
+ print(f" - 检索耗时: {metrics.get('latency', 0):.4f}秒")
174
+ print(f" - 检索文档数: {metrics.get('retrieved_docs_count', 0)}")
175
+ print(f" - Precision@3: {metrics.get('precision_at_3', 0):.4f}")
176
+ print(f" - Recall@3: {metrics.get('recall_at_3', 0):.4f}")
177
+ print(f" - MAP: {metrics.get('map_score', 0):.4f}")
178
 
179
  except KeyboardInterrupt:
180
  print("\n👋 感谢使用,再见!")
 
193
  # 测试查询
194
  test_question = "AlphaCodium论文讲的是什么?"
195
  # test_question = "解释embedding嵌入的原理,最好列举实现过程的具体步骤"
196
+ result = rag_system.query(test_question)
197
+
198
+ # 显示测试查询的检索评估摘要
199
+ if result.get("retrieval_metrics"):
200
+ metrics = result["retrieval_metrics"]
201
+ print("\n📊 测试查询检索评估摘要:")
202
+ print(f" - 检索耗时: {metrics.get('latency', 0):.4f}秒")
203
+ print(f" - 检索文档数: {metrics.get('retrieved_docs_count', 0)}")
204
+ print(f" - Precision@3: {metrics.get('precision_at_3', 0):.4f}")
205
+ print(f" - Recall@3: {metrics.get('recall_at_3', 0):.4f}")
206
+ print(f" - MAP: {metrics.get('map_score', 0):.4f}")
207
 
208
  # 启动交互模式
209
  rag_system.interactive_mode()
retrieval_evaluation.py ADDED
@@ -0,0 +1,674 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 检索效果评估模块
3
+ 提供多种评估指标和方法,用于评估RAG系统中检索结果的质量
4
+ """
5
+
6
+ import time
7
+ import json
8
+ import numpy as np
9
+ from typing import List, Dict, Tuple, Any, Optional, Union
10
+ from dataclasses import dataclass, asdict
11
+ from langchain.schema import Document
12
+ from sklearn.metrics import ndcg_score, precision_score, recall_score, f1_score
13
+ from sentence_transformers import SentenceTransformer, util
14
+ import matplotlib.pyplot as plt
15
+ import seaborn as sns
16
+ import pandas as pd
17
+ import torch
18
+
19
+
20
@dataclass
class RetrievalResult:
    """Outcome of running one query through a retriever.

    Attributes:
        query: The query text.
        retrieved_docs: Documents returned by the retriever, in rank order.
        relevant_docs: Documents judged relevant for the query (ground truth).
        retrieval_time: Time spent on the retrieval call, in seconds.
        scores: Optional retrieval scores; ``None`` when not provided.
    """
    query: str
    retrieved_docs: List[Document]
    relevant_docs: List[Document]
    retrieval_time: float
    scores: Optional[List[float]] = None
28
+
29
+
30
@dataclass
class EvaluationMetrics:
    """Aggregated retrieval metrics for one retriever.

    Attributes:
        precision_at_k: Mean Precision@K, keyed by K.
        recall_at_k: Mean Recall@K, keyed by K.
        f1_at_k: Mean F1@K, keyed by K.
        map_score: Mean average precision.
        mrr: Mean reciprocal rank.
        ndcg_at_k: Mean NDCG@K, keyed by K.
        coverage: Coverage score.
        diversity: Diversity score.
        novelty: Novelty score.
        latency: Mean retrieval latency.
    """
    precision_at_k: Dict[int, float]
    recall_at_k: Dict[int, float]
    f1_at_k: Dict[int, float]
    map_score: float
    mrr: float
    ndcg_at_k: Dict[int, float]
    coverage: float
    diversity: float
    novelty: float
    latency: float
43
+
44
+
45
class RetrievalEvaluator:
    """Score retrieval runs with ranking metrics and render comparison reports.

    Documents are identified by the first 50 characters of ``page_content``;
    two documents sharing that prefix are treated as the same document.
    """

    def __init__(self, embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"):
        """
        Load the embedding model used by the diversity metric.

        Args:
            embedding_model: Name of the SentenceTransformer model to load.
        """
        self.embedding_model = SentenceTransformer(embedding_model)

    @staticmethod
    def _doc_key(doc) -> str:
        """Return the identifier used to match documents (50-character content prefix)."""
        return doc.page_content[:50]

    @classmethod
    def _first_hit_flags(cls, retrieved_docs, relevant_ids) -> List[int]:
        """Return one 0/1 flag per retrieved document, in rank order.

        A flag is 1 only for the first occurrence of a relevant document, so a
        duplicate of an already-found document cannot be credited twice.
        """
        seen = set()
        flags = []
        for doc in retrieved_docs:
            key = cls._doc_key(doc)
            flags.append(1 if key in relevant_ids and key not in seen else 0)
            seen.add(key)
        return flags

    def evaluate_retrieval(self, results: List[RetrievalResult],
                           k_values: Optional[List[int]] = None) -> EvaluationMetrics:
        """
        Compute averaged ranking metrics over a list of retrieval results.

        Relevance is binary (a retrieved document either matches a relevant
        document or not). NDCG@K uses the discount ``1 / log2(rank + 1)`` and
        an ideal ranking that places ``min(K, number of relevant documents)``
        relevant documents first; a query without relevant documents scores 0.

        Args:
            results: Retrieval results, one per query.
            k_values: Cut-offs to evaluate; defaults to ``[1, 3, 5, 10]``.

        Returns:
            The averaged metrics. With an empty ``results`` list every
            per-query average is 0.
        """
        k_values = [1, 3, 5, 10] if k_values is None else list(k_values)
        num_results = len(results)

        total_precision = dict.fromkeys(k_values, 0.0)
        total_recall = dict.fromkeys(k_values, 0.0)
        total_f1 = dict.fromkeys(k_values, 0.0)
        total_ndcg = dict.fromkeys(k_values, 0.0)
        ap_sum = 0.0
        rr_sum = 0.0
        latency_sum = 0.0

        for result in results:
            latency_sum += result.retrieval_time

            relevant_ids = {self._doc_key(doc) for doc in result.relevant_docs}
            num_relevant = len(relevant_ids)
            hits = self._first_hit_flags(result.retrieved_docs, relevant_ids)

            for k in k_values:
                cutoff = max(k, 0)
                top_hits = hits[:cutoff]
                found = sum(top_hits)

                precision_k = found / k if k > 0 else 0.0
                recall_k = found / num_relevant if num_relevant else 0.0
                pr_sum = precision_k + recall_k
                f1_k = 2 * precision_k * recall_k / pr_sum if pr_sum > 0 else 0.0

                dcg = sum(hit / np.log2(rank + 2) for rank, hit in enumerate(top_hits))
                idcg = sum(1.0 / np.log2(rank + 2)
                           for rank in range(min(cutoff, num_relevant)))
                ndcg_k = float(dcg / idcg) if idcg > 0 else 0.0

                total_precision[k] += precision_k
                total_recall[k] += recall_k
                total_f1[k] += f1_k
                total_ndcg[k] += ndcg_k

            # Average precision and reciprocal rank over the full ranking.
            found_so_far = 0
            precision_sum = 0.0
            reciprocal_rank = 0.0
            for rank, hit in enumerate(hits, start=1):
                if not hit:
                    continue
                found_so_far += 1
                precision_sum += found_so_far / rank
                if found_so_far == 1:
                    reciprocal_rank = 1.0 / rank
            ap_sum += precision_sum / num_relevant if num_relevant else 0.0
            rr_sum += reciprocal_rank

        def mean(total: float) -> float:
            # Guard against an empty result list.
            return total / num_results if num_results else 0.0

        return EvaluationMetrics(
            precision_at_k={k: mean(total_precision[k]) for k in k_values},
            recall_at_k={k: mean(total_recall[k]) for k in k_values},
            f1_at_k={k: mean(total_f1[k]) for k in k_values},
            map_score=mean(ap_sum),
            mrr=mean(rr_sum),
            ndcg_at_k={k: mean(total_ndcg[k]) for k in k_values},
            coverage=self._calculate_coverage(results),
            diversity=self._calculate_diversity(results),
            novelty=self._calculate_novelty(results),
            latency=mean(latency_sum),
        )

    def _calculate_coverage(self, results: List[RetrievalResult]) -> float:
        """Return unique retrieved documents divided by unique relevant documents.

        The ratio is taken over all queries together and can exceed 1 when
        more distinct documents are retrieved than are judged relevant.
        Returns 0 when there are no relevant documents.
        """
        all_retrieved = set()
        all_relevant = set()

        for result in results:
            all_retrieved.update(self._doc_key(doc) for doc in result.retrieved_docs)
            all_relevant.update(self._doc_key(doc) for doc in result.relevant_docs)

        return len(all_retrieved) / len(all_relevant) if all_relevant else 0

    def _calculate_diversity(self, results: List[RetrievalResult]) -> float:
        """Return the mean of ``1 - average pairwise cosine similarity`` per query.

        Queries with fewer than two retrieved documents are skipped; returns 0
        when no query qualifies.
        """
        per_query_diversity = []

        for result in results:
            if len(result.retrieved_docs) < 2:
                continue

            doc_texts = [doc.page_content for doc in result.retrieved_docs]
            embeddings = self.embedding_model.encode(doc_texts, convert_to_tensor=True)
            cos_sim = util.pytorch_cos_sim(embeddings, embeddings)

            # Strict upper triangle: every unordered pair once, no self-pairs.
            rows, cols = torch.triu_indices(len(cos_sim), len(cos_sim), offset=1)
            similarities = cos_sim[rows, cols]

            per_query_diversity.append(1 - similarities.mean().item())

        if not per_query_diversity:
            return 0
        return sum(per_query_diversity) / len(per_query_diversity)

    def _calculate_novelty(self, results: List[RetrievalResult]) -> float:
        """Return unique retrieved documents divided by all retrieved documents.

        Counted across all queries; returns 0 when nothing was retrieved.
        """
        total_docs = 0
        unique_docs = set()

        for result in results:
            for doc in result.retrieved_docs:
                total_docs += 1
                unique_docs.add(self._doc_key(doc))

        return len(unique_docs) / total_docs if total_docs > 0 else 0

    def compare_retrievers(self, retriever_results: Dict[str, List[RetrievalResult]],
                           k_values: Optional[List[int]] = None) -> Dict[str, EvaluationMetrics]:
        """
        Evaluate several retrievers with the same cut-offs.

        Args:
            retriever_results: Mapping of retriever name to its retrieval results.
            k_values: Cut-offs to evaluate; defaults to ``[1, 3, 5, 10]``.

        Returns:
            Mapping of retriever name to its evaluation metrics.
        """
        metrics = {}

        for name, results in retriever_results.items():
            print(f"评估检索器: {name}")
            metrics[name] = self.evaluate_retrieval(results, k_values)

        return metrics

    def generate_report(self, metrics: Dict[str, EvaluationMetrics],
                        save_path: Optional[str] = None) -> str:
        """
        Render a Markdown report comparing retrievers.

        The report holds a comparison table, a legend, and a "best retriever"
        list for Precision/Recall/F1/NDCG at K=5, MAP and MRR. An @5 entry is
        omitted when K=5 was not evaluated for every retriever; ties go to the
        retriever listed first.

        Args:
            metrics: Mapping of retriever name to evaluation metrics.
            save_path: When given, the report is also written to this path.

        Returns:
            The report text.
        """
        report = ["# 检索效果评估报告\n"]

        # Comparison table: one row per retriever.
        per_k_columns = (
            ("Precision", "precision_at_k"),
            ("Recall", "recall_at_k"),
            ("F1", "f1_at_k"),
            ("NDCG", "ndcg_at_k"),
        )
        df_data = []
        for name, metric in metrics.items():
            row = {"检索器": name}
            for label, attr in per_k_columns:
                values = getattr(metric, attr)
                for k in sorted(values.keys()):
                    row[f"{label}@{k}"] = f"{values[k]:.4f}"
            row.update({
                "MAP": f"{metric.map_score:.4f}",
                "MRR": f"{metric.mrr:.4f}",
                "覆盖率": f"{metric.coverage:.4f}",
                "多样性": f"{metric.diversity:.4f}",
                "新颖性": f"{metric.novelty:.4f}",
                "延迟(ms)": f"{metric.latency*1000:.2f}"
            })
            df_data.append(row)

        df = pd.DataFrame(df_data)
        report.append("## 指标比较表\n")
        report.append(df.to_string(index=False))
        report.append("\n\n")

        # Legend. NOTE(review): the coverage line says "total documents" while
        # _calculate_coverage divides by relevant documents; text left as is.
        report.append("## 指标解释\n")
        report.append("- **Precision@K**: 前K个结果中相关文档的比例\n")
        report.append("- **Recall@K**: 前K个结果中相关文档占所有相关文档的比例\n")
        report.append("- **F1@K**: Precision和Recall的调和平均数\n")
        report.append("- **NDCG@K**: 归一化折扣累积增益,考虑排序位置\n")
        report.append("- **MAP**: 平均精度均值,所有查询的平均精度\n")
        report.append("- **MRR**: 平均倒数排名,第一个相关文档排名的倒数平均值\n")
        report.append("- **覆盖率**: 检索到的唯一文档数与总文档数的比例\n")
        report.append("- **多样性**: 检索结果之间的平均语义差异\n")
        report.append("- **新颖性**: 检索结果中不重复内容的比例\n")
        report.append("- **延迟**: 平均检索时间\n")

        report.append("## 最佳检索器\n")

        # The @5 values live inside the per-K dicts, not in attributes named
        # "precision_at_5" etc., so each label is paired with an accessor.
        ranked_metrics = (
            ("precision_at_5", lambda m: m.precision_at_k.get(5)),
            ("recall_at_5", lambda m: m.recall_at_k.get(5)),
            ("f1_at_5", lambda m: m.f1_at_k.get(5)),
            ("ndcg_at_5", lambda m: m.ndcg_at_k.get(5)),
            ("map_score", lambda m: m.map_score),
            ("mrr", lambda m: m.mrr),
        )
        for metric_name, value_of in ranked_metrics:
            scores = {name: value_of(metric) for name, metric in metrics.items()}
            if not scores or any(value is None for value in scores.values()):
                continue  # nothing to rank, or K=5 was not evaluated
            best_name = max(scores, key=scores.get)
            report.append(f"- **{metric_name}**: {best_name}\n")

        report_text = "".join(report)

        if save_path:
            with open(save_path, "w", encoding="utf-8") as f:
                f.write(report_text)
            print(f"报告已保存到: {save_path}")

        return report_text

    def plot_metrics_comparison(self, metrics: Dict[str, EvaluationMetrics],
                                save_path: Optional[str] = None):
        """
        Draw a 2x3 grid of charts comparing retrievers and show it.

        The panels are Precision/Recall/F1/NDCG against K (line charts),
        MAP with MRR, and coverage/diversity/novelty (bar charts). The K axis
        is taken from the first retriever's ``precision_at_k`` keys. Does
        nothing when ``metrics`` is empty.

        Args:
            metrics: Mapping of retriever name to evaluation metrics.
            save_path: When given, the figure is also saved to this path.
        """
        retriever_names = list(metrics.keys())
        if not retriever_names:
            return

        fig, axes = plt.subplots(2, 3, figsize=(18, 12))
        fig.suptitle("检索器性能比较", fontsize=16)

        k_values = sorted(metrics[retriever_names[0]].precision_at_k.keys())

        # Four line charts sharing the same layout.
        line_panels = (
            (axes[0, 0], "precision_at_k", "Precision@K", "Precision"),
            (axes[0, 1], "recall_at_k", "Recall@K", "Recall"),
            (axes[0, 2], "f1_at_k", "F1@K", "F1"),
            (axes[1, 0], "ndcg_at_k", "NDCG@K", "NDCG"),
        )
        for ax, attr, title, ylabel in line_panels:
            for name in retriever_names:
                per_k = getattr(metrics[name], attr)
                ax.plot(k_values, [per_k[k] for k in k_values], marker='o', label=name)
            ax.set_title(title)
            ax.set_xlabel("K")
            ax.set_ylabel(ylabel)
            ax.legend()
            ax.grid(True)

        # MAP and MRR as paired bars.
        ax = axes[1, 1]
        map_values = [metrics[name].map_score for name in retriever_names]
        mrr_values = [metrics[name].mrr for name in retriever_names]
        x = np.arange(len(retriever_names))
        width = 0.35
        ax.bar(x - width/2, map_values, width, label='MAP')
        ax.bar(x + width/2, mrr_values, width, label='MRR')
        ax.set_title("MAP和MRR")
        ax.set_xticks(x)
        ax.set_xticklabels(retriever_names)
        ax.legend()
        ax.grid(True)

        # Coverage, diversity and novelty as grouped bars.
        ax = axes[1, 2]
        other_metrics = ['coverage', 'diversity', 'novelty']
        x = np.arange(len(retriever_names))
        width = 0.25
        for i, metric in enumerate(other_metrics):
            values = [getattr(metrics[name], metric) for name in retriever_names]
            ax.bar(x + i*width, values, width, label=metric)
        ax.set_title("其他指标")
        ax.set_xticks(x + width)
        ax.set_xticklabels(retriever_names)
        ax.legend()
        ax.grid(True)

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"图表已保存到: {save_path}")

        plt.show()
467
+
468
+
469
class RetrievalTestSet:
    """Queries, documents and relevance judgments loaded from three text files."""

    def __init__(self, queries_file: str, documents_file: str, qrels_file: str):
        """
        Load the test set from disk.

        Args:
            queries_file: Path to the query file, one query per line.
            documents_file: Path to the document file, one document per line.
            qrels_file: Path to the judgment file with lines of the form
                ``query_id,doc_id,relevance``.
        """
        self.queries = self._load_queries(queries_file)
        self.documents = self._load_documents(documents_file)
        self.qrels = self._load_qrels(qrels_file)

    def _load_queries(self, file_path: str) -> Dict[str, str]:
        """Map each line number (as a string) to the stripped query on that line."""
        with open(file_path, 'r', encoding='utf-8') as f:
            return {str(idx): raw.strip() for idx, raw in enumerate(f)}

    def _load_documents(self, file_path: str) -> Dict[str, Document]:
        """Map each line number (as a string) to a ``Document`` built from that line."""
        with open(file_path, 'r', encoding='utf-8') as f:
            return {
                str(idx): Document(page_content=raw.strip(), metadata={"doc_id": str(idx)})
                for idx, raw in enumerate(f)
            }

    def _load_qrels(self, file_path: str) -> Dict[str, Dict[str, int]]:
        """Read judgments into ``{query_id: {doc_id: relevance}}``.

        Lines with fewer than three comma-separated fields are ignored.
        """
        judgments = {}
        with open(file_path, 'r', encoding='utf-8') as f:
            for raw in f:
                fields = raw.strip().split(',')
                if len(fields) < 3:
                    continue
                judgments.setdefault(fields[0], {})[fields[1]] = int(fields[2])
        return judgments

    def get_retrieval_results(self, retriever, top_k: int = 10) -> List[RetrievalResult]:
        """
        Run every query through a retriever and pair the output with ground truth.

        Args:
            retriever: Object offering a ``retrieve(query, top_k)`` method.
            top_k: Number of documents requested per query.

        Returns:
            One ``RetrievalResult`` per query, with the wall-clock time of the
            retrieval call and the known documents judged with relevance > 0.
        """
        collected = []

        for query_id, query_text in self.queries.items():
            began = time.time()
            fetched = retriever.retrieve(query_text, top_k)
            elapsed = time.time() - began

            # Ground truth: judged documents with positive relevance that
            # exist in the loaded document set.
            judged = self.qrels.get(query_id, {})
            gold = [
                self.documents[doc_id]
                for doc_id, grade in judged.items()
                if grade > 0 and doc_id in self.documents
            ]

            collected.append(RetrievalResult(
                query=query_text,
                retrieved_docs=fetched,
                relevant_docs=gold,
                retrieval_time=elapsed,
            ))

        return collected
549
+
550
+
551
def create_sample_test_set():
    """Write a small sample test set to the working directory and load it.

    Three files are created (overwriting any existing ones):
    ``sample_queries.txt``, ``sample_documents.txt`` and ``sample_qrels.csv``.

    Returns:
        A ``RetrievalTestSet`` built from the three files just written.
    """
    # Sample queries; the line index is the query id.
    sample_queries = [
        "什么是机器学习?",
        "深度学习和机器学习的区别是什么?",
        "如何评估机器学习模型的性能?",
        "自然语言处理有哪些应用?",
        "计算机视觉的基本任务是什么?",
    ]

    # Sample corpus; the line index is the document id.
    sample_documents = [
        "机器学习是人工智能的一个分支,它使计算机能够在没有明确编程的情况下学习和改进。",
        "深度学习是机器学习的一个子集,它使用多层神经网络来模拟人脑的工作方式。",
        "评估机器学习模型的常用指标包括准确率、精确率、召回率和F1分数。",
        "自然语言处理是计算机科学和人工智能的一个分支,专注于计算机与人类语言之间的交互。",
        "计算机视觉是人工智能的一个领域,训练计算机解释和理解视觉世界。",
        "强化学习是机器学习的一个类型,它关注软件代理应该如何在环境中采取行动以最大化累积奖励。",
        "数据预处理是机器学习流程中的重要步骤,包括数据清洗、特征选择和特征工程。",
        "过拟合是机器学习中的一个常见问题,指模型在训练数据上表现良好但在新数据上表现不佳。",
        "卷积神经网络(CNN)是一类深度神经网络,最常用于分析视觉图像。",
        "循环神经网络(RNN)是一类人工神经网络,其中节点之间的连接形成有向图沿时间序列。",
    ]

    # Graded relevance judgments: query id -> {document id: grade}.
    judgments = {
        "0": {"0": 2, "1": 1, "6": 1, "7": 1},  # what is machine learning
        "1": {"0": 1, "1": 2, "8": 1, "9": 1},  # deep learning vs machine learning
        "2": {"2": 2, "7": 1},                  # evaluating model performance
        "3": {"3": 2, "9": 1},                  # NLP applications
        "4": {"4": 2, "8": 1},                  # basic computer-vision tasks
    }

    def write_rows(path, rows):
        """Write each row followed by a newline to a UTF-8 text file."""
        with open(path, "w", encoding="utf-8") as handle:
            handle.writelines(f"{row}\n" for row in rows)

    write_rows("sample_queries.txt", sample_queries)
    write_rows("sample_documents.txt", sample_documents)
    write_rows(
        "sample_qrels.csv",
        (
            f"{query_id},{doc_id},{relevance}"
            for query_id, doc_relevance in judgments.items()
            for doc_id, relevance in doc_relevance.items()
        ),
    )

    for message in (
        "示例测试集已创建:",
        "- sample_queries.txt: 查询文件",
        "- sample_documents.txt: 文档文件",
        "- sample_qrels.csv: 相关性标注文件",
    ):
        print(message)

    return RetrievalTestSet("sample_queries.txt", "sample_documents.txt", "sample_qrels.csv")
605
+
606
+
607
class MockRetriever:
    """Simulated retriever of configurable quality, for demonstration only.

    ``name`` selects the behaviour:

    * ``"good"``   - up to ``top_k // 2`` documents containing one of the
      first two whitespace-separated query keywords, then random filler.
    * ``"medium"`` - up to ``top_k // 3`` documents containing the first
      query keyword, then random filler.
    * anything else - purely random documents.

    NOTE(review): keywords come from ``str.split()``. The bundled sample
    queries are unsegmented Chinese text, so the whole query becomes a single
    keyword that is unlikely to occur in any document; in that case all three
    variants degrade to random sampling. Consider a real tokenizer if the
    demo is meant to show a quality gap.
    """

    def __init__(self, name, documents):
        """Store the quality label and the corpus to draw results from.

        Args:
            name: Quality label (``"good"``, ``"medium"`` or other).
            documents: Iterable of objects exposing ``page_content``.
        """
        self.name = name
        self.documents = list(documents)

    def retrieve(self, query, top_k=10):
        """Return at most ``top_k`` distinct documents for ``query``."""
        import random

        tokens = query.lower().split()
        if self.name == "good":
            keywords, keyword_quota = tokens[:2], top_k // 2
        elif self.name == "medium":
            # Slicing (not indexing) keeps an empty query from raising.
            keywords, keyword_quota = tokens[:1], top_k // 3
        else:
            keywords, keyword_quota = [], 0

        matched = [
            doc for doc in self.documents
            if any(keyword in doc.page_content.lower() for keyword in keywords)
        ]
        results = matched[:max(keyword_quota, 0)]

        # Fill the remaining slots only from documents not already selected;
        # sampling from the whole corpus could return the same document twice
        # and distort rank-based metrics.
        chosen = {id(doc) for doc in results}
        pool = [doc for doc in self.documents if id(doc) not in chosen]
        missing = max(top_k - len(results), 0)
        results += random.sample(pool, min(missing, len(pool)))

        return results


if __name__ == "__main__":
    # Build the sample test set on disk and load it.
    test_set = create_sample_test_set()

    # Evaluator that turns retrieval results into metrics.
    evaluator = RetrievalEvaluator()

    # Replace these mock retrievers with the real retriever(s) under test.
    corpus = list(test_set.documents.values())
    good_retriever = MockRetriever("good", corpus)
    medium_retriever = MockRetriever("medium", corpus)
    poor_retriever = MockRetriever("poor", corpus)

    # Run every query of the test set through each retriever.
    good_results = test_set.get_retrieval_results(good_retriever)
    medium_results = test_set.get_retrieval_results(medium_retriever)
    poor_results = test_set.get_retrieval_results(poor_retriever)

    # Results keyed by the display name used in the report and the plot.
    retriever_results = {
        "好的检索器": good_results,
        "中等检索器": medium_results,
        "差的检索器": poor_results
    }

    # Compute the metrics for all retrievers.
    metrics = evaluator.compare_retrievers(retriever_results)

    # Write the Markdown report and echo it to stdout.
    report = evaluator.generate_report(metrics, "retrieval_evaluation_report.md")
    print(report)

    # Save the comparison chart.
    evaluator.plot_metrics_comparison(metrics, "retrieval_evaluation_comparison.png")
workflow_nodes.py CHANGED
@@ -3,6 +3,7 @@
3
  包含所有工作流节点函数和状态管理
4
  """
5
 
 
6
  from typing import List
7
  from typing_extensions import TypedDict
8
  try:
@@ -19,6 +20,7 @@ except ImportError:
19
 
20
  from config import LOCAL_LLM, WEB_SEARCH_RESULTS_COUNT, ENABLE_HYBRID_SEARCH, ENABLE_QUERY_EXPANSION, ENABLE_MULTIMODAL
21
  from document_processor import DocumentProcessor
 
22
  from pprint import pprint
23
 
24
 
@@ -31,11 +33,13 @@ class GraphState(TypedDict):
31
  generation: LLM生成
32
  documents: 文档列表
33
  retry_count: 重试计数器,防止无限循环
 
34
  """
35
  question: str
36
  generation: str
37
  documents: List[str]
38
  retry_count: int
 
39
 
40
 
41
  class WorkflowNodes:
@@ -46,6 +50,9 @@ class WorkflowNodes:
46
  self.retriever = retriever if retriever is not None else getattr(doc_processor, 'retriever', None)
47
  self.graders = graders
48
 
 
 
 
49
  # 设置RAG链 - 使用本地提示模板
50
  rag_prompt_template = PromptTemplate(
51
  template="""你是一个问答助手。使用以下检索到的上下文来回答问题。
@@ -77,6 +84,7 @@ class WorkflowNodes:
77
  print("---检索---")
78
  question = state["question"]
79
  retry_count = state.get("retry_count", 0)
 
80
 
81
  # 使用增强检索方法,支持混合检索、查询扩展和多模态
82
  try:
@@ -118,8 +126,19 @@ class WorkflowNodes:
118
  except Exception as fallback_e:
119
  print(f"❌ 回退检索也失败: {fallback_e}")
120
  documents = []
121
-
122
- return {"documents": documents, "question": question, "retry_count": retry_count}
 
 
 
 
 
 
 
 
 
 
 
123
 
124
  def generate(self, state):
125
  """
@@ -295,6 +314,73 @@ class WorkflowNodes:
295
  return "not supported"
296
 
297
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
298
def format_docs(docs):
    """Join the text of each document into one string for display.

    Args:
        docs: Iterable of objects exposing a ``page_content`` string.

    Returns:
        The ``page_content`` values separated by a blank line.
    """
    separator = "\n\n"
    contents = [doc.page_content for doc in docs]
    return separator.join(contents)
 
3
  包含所有工作流节点函数和状态管理
4
  """
5
 
6
+ import time
7
  from typing import List
8
  from typing_extensions import TypedDict
9
  try:
 
20
 
21
  from config import LOCAL_LLM, WEB_SEARCH_RESULTS_COUNT, ENABLE_HYBRID_SEARCH, ENABLE_QUERY_EXPANSION, ENABLE_MULTIMODAL
22
  from document_processor import DocumentProcessor
23
+ from retrieval_evaluation import RetrievalEvaluator, RetrievalResult
24
  from pprint import pprint
25
 
26
 
 
33
  generation: LLM生成
34
  documents: 文档列表
35
  retry_count: 重试计数器,防止无限循环
36
+ retrieval_metrics: 检索评估指标
37
  """
38
  question: str
39
  generation: str
40
  documents: List[str]
41
  retry_count: int
42
+ retrieval_metrics: dict # 添加检索评估指标
43
 
44
 
45
  class WorkflowNodes:
 
50
  self.retriever = retriever if retriever is not None else getattr(doc_processor, 'retriever', None)
51
  self.graders = graders
52
 
53
+ # 初始化检索评估器
54
+ self.retrieval_evaluator = RetrievalEvaluator()
55
+
56
  # 设置RAG链 - 使用本地提示模板
57
  rag_prompt_template = PromptTemplate(
58
  template="""你是一个问答助手。使用以下检索到的上下文来回答问题。
 
84
  print("---检索---")
85
  question = state["question"]
86
  retry_count = state.get("retry_count", 0)
87
+ retrieval_start_time = time.time()
88
 
89
  # 使用增强检索方法,支持混合检索、查询扩展和多模态
90
  try:
 
126
  except Exception as fallback_e:
127
  print(f"❌ 回退检索也失败: {fallback_e}")
128
  documents = []
129
+
130
+ # 计算检索时间
131
+ retrieval_time = time.time() - retrieval_start_time
132
+
133
+ # 评估检索结果
134
+ retrieval_metrics = self._evaluate_retrieval_results(question, documents, retrieval_time)
135
+
136
+ return {
137
+ "documents": documents,
138
+ "question": question,
139
+ "retry_count": retry_count,
140
+ "retrieval_metrics": retrieval_metrics
141
+ }
142
 
143
  def generate(self, state):
144
  """
 
314
  return "not supported"
315
 
316
 
317
def _evaluate_retrieval_results(self, question, documents, retrieval_time):
    """Compute and print retrieval metrics for a single query.

    NOTE(review): no ground-truth relevance is available here. The first two
    retrieved documents are simply *assumed* to be relevant, so the
    precision/recall/MAP/MRR figures only exercise the evaluation pipeline
    and must not be read as real quality measurements.

    Args:
        question: The query text.
        documents: Retrieved documents; ``None`` is treated as no documents.
        retrieval_time: Retrieval latency in seconds.

    Returns:
        dict: On success, the keys ``precision_at_{1,3,5}``,
        ``recall_at_{1,3,5}``, ``map_score``, ``mrr``, ``latency`` and
        ``retrieved_docs_count``. If evaluation fails, the keys ``error``,
        ``latency`` and ``retrieved_docs_count``.
    """
    # Normalise up front so that neither the happy path nor the fallback in
    # the except branch can fail on len(None).
    if documents is None:
        documents = []

    try:
        # Placeholder ground truth: pretend the top two hits are relevant.
        relevant_docs = list(documents[:2])

        retrieval_result = RetrievalResult(
            query=question,
            retrieved_docs=documents,
            relevant_docs=relevant_docs,
            retrieval_time=retrieval_time
        )

        k_values = [1, 3, 5]
        metrics = self.retrieval_evaluator.evaluate_retrieval([retrieval_result], k_values=k_values)

        # Flatten the metrics object into a plain dict for the graph state.
        result_metrics = {}
        for k in k_values:
            result_metrics[f"precision_at_{k}"] = metrics.precision_at_k.get(k, 0)
        for k in k_values:
            result_metrics[f"recall_at_{k}"] = metrics.recall_at_k.get(k, 0)
        result_metrics["map_score"] = metrics.map_score
        result_metrics["mrr"] = metrics.mrr
        result_metrics["latency"] = metrics.latency
        result_metrics["retrieved_docs_count"] = len(documents)

        # Human-readable summary on stdout.
        print("\n---检索评估结果---")
        print(f"检索耗时: {result_metrics['latency']:.4f}秒")
        print(f"检索文档数: {result_metrics['retrieved_docs_count']}")
        for k in k_values:
            print(f"Precision@{k}: {result_metrics[f'precision_at_{k}']:.4f}")
        for k in k_values:
            print(f"Recall@{k}: {result_metrics[f'recall_at_{k}']:.4f}")
        print(f"MAP: {result_metrics['map_score']:.4f}")
        print(f"MRR: {result_metrics['mrr']:.4f}")
        print("--------------------\n")

        return result_metrics

    except Exception as e:
        # Evaluation is best-effort: never let it break the retrieval node.
        print(f"⚠️ 检索评估失败: {e}")
        return {
            "error": str(e),
            "latency": retrieval_time,
            "retrieved_docs_count": len(documents)
        }
382
+
383
+
384
def format_docs(docs):
    """Join the text of each document into one string for display.

    Args:
        docs: Iterable of objects exposing a ``page_content`` string.

    Returns:
        The ``page_content`` values separated by a blank line.
    """
    separator = "\n\n"
    contents = [doc.page_content for doc in docs]
    return separator.join(contents)