File size: 13,219 Bytes
5ad083c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94a7032
 
 
 
3141e61
 
 
5ad083c
 
 
 
94a7032
 
 
 
 
 
 
 
 
 
 
 
 
 
5ad083c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
"""
自适应RAG系统检索效果评估脚本
评估不同检索策略和配置的效果
"""

import os
import sys
import time
import json
import argparse
from typing import List, Dict, Any, Optional
from dotenv import load_dotenv

# 加载环境变量
load_dotenv()

# 导入项目模块
from main import AdaptiveRAGSystem
from document_processor import DocumentProcessor
from retrieval_evaluation import RetrievalEvaluator, RetrievalResult, RetrievalTestSet
try:
    from langchain_core.documents import Document
except ImportError:
    try:
        from langchain_core.documents import Document
    except ImportError:
        from langchain.schema import Document

# 导入LangChain相关模块
from langchain_community.vectorstores import FAISS, Chroma
from langchain_community.retrievers import BM25Retriever
try:
    from langchain.retrievers import EnsembleRetriever
    from langchain.retrievers import ContextualCompressionRetriever
    from langchain.retrievers.document_compressors import LLMChainExtractor
except ImportError:
    try:
        from langchain_core.retrievers import EnsembleRetriever
        from langchain_core.retrievers import ContextualCompressionRetriever
        from langchain.retrievers.document_compressors import LLMChainExtractor
    except ImportError:
        print("Warning: Could not import advanced retriever components. Some features may be limited.")
        EnsembleRetriever = None
        ContextualCompressionRetriever = None
        LLMChainExtractor = None


class AdaptiveRAGRetriever:
    """自适应RAG系统检索器包装器"""
    
    def __init__(self, system_config: Dict[str, Any], retriever_type: str = "default"):
        """
        初始化检索器
        
        Args:
            system_config: 系统配置
            retriever_type: 检索器类型
        """
        self.system_config = system_config
        self.retriever_type = retriever_type
        self.system = None
        self._initialize_system()
    
    def _initialize_system(self):
        """初始化RAG系统"""
        try:
            # 根据检索器类型调整配置
            config = self.system_config.copy()
            
            if self.retriever_type == "vector_only":
                config["retrieval_strategy"] = "vector"
            elif self.retriever_type == "bm25_only":
                config["retrieval_strategy"] = "bm25"
            elif self.retriever_type == "hybrid":
                config["retrieval_strategy"] = "hybrid"
            elif self.retriever_type == "graph":
                config["retrieval_strategy"] = "graph"
            elif self.retriever_type == "compression":
                config["use_compression"] = True
            elif self.retriever_type == "rerank":
                config["use_reranking"] = True
            elif self.retriever_type == "query_expansion":
                config["use_query_expansion"] = True
            
            # 创建系统实例
            self.system = AdaptiveRAGSystem(config)
            
            # 初始化文档处理器(如果需要)
            if not hasattr(self.system, 'document_processor') or self.system.document_processor is None:
                self.system.document_processor = DocumentProcessor(config)
            
        except Exception as e:
            print(f"初始化RAG系统失败: {e}")
            raise
    
    def retrieve(self, query: str, top_k: int = 10) -> List[Document]:
        """
        检索文档
        
        Args:
            query: 查询文本
            top_k: 返回的文档数量
            
        Returns:
            检索到的文档列表
        """
        try:
            # 使用系统的检索方法
            if hasattr(self.system, 'retrieve'):
                docs = self.system.retrieve(query, top_k)
            else:
                # 如果没有直接的retrieve方法,尝试通过文档处理器检索
                if self.system.document_processor:
                    docs = self.system.document_processor.retrieve(query, top_k)
                else:
                    raise ValueError("无法找到检索方法")
            
            return docs[:top_k]
        except Exception as e:
            print(f"检索失败: {e}")
            return []


def create_evaluation_dataset(data_dir: str = "data", num_queries: int = 20) -> RetrievalTestSet:
    """
    从项目数据创建评估数据集
    
    Args:
        data_dir: 数据目录
        num_queries: 查询数量
        
    Returns:
        检索测试集
    """
    # 检查数据目录
    if not os.path.exists(data_dir):
        print(f"数据目录 {data_dir} 不存在,创建示例数据集")
        from retrieval_evaluation import create_sample_test_set
        return create_sample_test_set()
    
    # 尝试从现有数据创建测试集
    try:
        # 加载文档
        documents = []
        doc_files = []
        
        # 查找所有文本文件
        for root, dirs, files in os.walk(data_dir):
            for file in files:
                if file.endswith('.txt') or file.endswith('.md'):
                    doc_files.append(os.path.join(root, file))
        
        # 如果没有找到文档文件,创建示例数据集
        if not doc_files:
            print(f"在 {data_dir} 中未找到文档文件,创建示例数据集")
            from retrieval_evaluation import create_sample_test_set
            return create_sample_test_set()
        
        # 读取文档内容
        for i, file_path in enumerate(doc_files):
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read().strip()
                if content:
                    documents.append(Document(page_content=content, metadata={"source": file_path, "doc_id": str(i)}))
        
        # 生成查询(这里简化处理,实际应用中应该使用真实查询)
        queries = []
        qrels = {}
        
        # 从文档中提取关键句子作为查询
        for i in range(min(num_queries, len(documents))):
            doc = documents[i]
            sentences = doc.page_content.split('.')
            if sentences:
                # 取第一个非空句子作为查询
                for sentence in sentences:
                    sentence = sentence.strip()
                    if sentence and len(sentence) > 10:  # 确保查询有足够长度
                        queries.append(sentence)
                        # 假设查询与当前文档相关
                        qrels[str(i)] = {str(i): 2}  # 高度相关
                        # 可能与其他文档也相关
                        for j in range(min(3, len(documents))):
                            if j != i:
                                qrels[str(i)][str(j)] = 1  # 部分相关
                        break
        
        # 保存查询文件
        with open("eval_queries.txt", "w", encoding="utf-8") as f:
            for query in queries:
                f.write(query + "\n")
        
        # 保存文档文件
        with open("eval_documents.txt", "w", encoding="utf-8") as f:
            for doc in documents:
                f.write(doc.page_content + "\n")
        
        # 保存相关性标注文件
        with open("eval_qrels.csv", "w", encoding="utf-8") as f:
            for query_id, doc_relevance in qrels.items():
                for doc_id, relevance in doc_relevance.items():
                    f.write(f"{query_id},{doc_id},{relevance}\n")
        
        print(f"评估数据集已创建:")
        print(f"- 查询数量: {len(queries)}")
        print(f"- 文档数量: {len(documents)}")
        print(f"- eval_queries.txt: 查询文件")
        print(f"- eval_documents.txt: 文档文件")
        print(f"- eval_qrels.csv: 相关性标注文件")
        
        return RetrievalTestSet("eval_queries.txt", "eval_documents.txt", "eval_qrels.csv")
    
    except Exception as e:
        print(f"创建评估数据集失败: {e}")
        print("创建示例数据集")
        from retrieval_evaluation import create_sample_test_set
        return create_sample_test_set()


def evaluate_retrievers(system_config: Dict[str, Any], 
                       retriever_types: List[str],
                       test_set: RetrievalTestSet,
                       output_dir: str = "evaluation_results") -> Dict[str, Any]:
    """
    评估多个检索器
    
    Args:
        system_config: 系统配置
        retriever_types: 检索器类型列表
        test_set: 测试集
        output_dir: 输出目录
        
    Returns:
        评估结果
    """
    # 创建输出目录
    os.makedirs(output_dir, exist_ok=True)
    
    # 初始化评估器
    evaluator = RetrievalEvaluator()
    
    # 存储所有检索结果
    all_results = {}
    
    # 评估每个检索器
    for retriever_type in retriever_types:
        print(f"\n评估检索器: {retriever_type}")
        print("=" * 50)
        
        try:
            # 创建检索器
            retriever = AdaptiveRAGRetriever(system_config, retriever_type)
            
            # 获取检索结果
            results = test_set.get_retrieval_results(retriever)
            all_results[retriever_type] = results
            
            print(f"完成 {len(results)} 个查询的检索")
            
        except Exception as e:
            print(f"评估检索器 {retriever_type} 失败: {e}")
            continue
    
    # 比较检索器
    if len(all_results) > 1:
        print("\n比较检索器性能")
        print("=" * 50)
        metrics = evaluator.compare_retrievers(all_results)
        
        # 生成报告
        report = evaluator.generate_report(
            metrics, 
            os.path.join(output_dir, "retrieval_evaluation_report.md")
        )
        
        # 绘制比较图
        evaluator.plot_metrics_comparison(
            metrics, 
            os.path.join(output_dir, "retrieval_evaluation_comparison.png")
        )
        
        # 保存详细指标
        metrics_data = {}
        for name, metric in metrics.items():
            metrics_data[name] = {
                "precision_at_k": metric.precision_at_k,
                "recall_at_k": metric.recall_at_k,
                "f1_at_k": metric.f1_at_k,
                "map_score": metric.map_score,
                "mrr": metric.mrr,
                "ndcg_at_k": metric.ndcg_at_k,
                "coverage": metric.coverage,
                "diversity": metric.diversity,
                "novelty": metric.novelty,
                "latency": metric.latency
            }
        
        with open(os.path.join(output_dir, "metrics.json"), "w", encoding="utf-8") as f:
            json.dump(metrics_data, f, indent=2, ensure_ascii=False)
        
        return {
            "metrics": metrics,
            "metrics_data": metrics_data,
            "report": report,
            "results": all_results
        }
    else:
        print("只有一个检索器成功评估,跳过比较")
        return {"results": all_results}


def main():
    """主函数"""
    parser = argparse.ArgumentParser(description="评估自适应RAG系统的检索效果")
    parser.add_argument("--config", type=str, default="config.py", help="配置文件路径")
    parser.add_argument("--data_dir", type=str, default="data", help="数据目录")
    parser.add_argument("--output_dir", type=str, default="evaluation_results", help="输出目录")
    parser.add_argument("--num_queries", type=int, default=20, help="查询数量")
    parser.add_argument("--retrievers", nargs="+", 
                       default=["default", "vector_only", "bm25_only", "hybrid"],
                       help="要评估的检索器类型")
    
    args = parser.parse_args()
    
    # 加载配置
    try:
        if args.config.endswith('.py'):
            # 动态导入Python配置文件
            import importlib.util
            spec = importlib.util.spec_from_file_location("config", args.config)
            config_module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(config_module)
            system_config = config_module.config
        else:
            # 加载JSON配置文件
            with open(args.config, 'r', encoding='utf-8') as f:
                system_config = json.load(f)
    except Exception as e:
        print(f"加载配置文件失败: {e}")
        print("使用默认配置")
        system_config = {
            "model_name": "gpt-3.5-turbo",
            "vector_store": "faiss",
            "retrieval_strategy": "hybrid",
            "use_reranking": False,
            "use_compression": False,
            "use_query_expansion": False
        }
    
    # 创建评估数据集
    print("创建评估数据集")
    test_set = create_evaluation_dataset(args.data_dir, args.num_queries)
    
    # 评估检索器
    print("\n开始评估检索器")
    results = evaluate_retrievers(system_config, args.retrievers, test_set, args.output_dir)
    
    print("\n评估完成!")
    print(f"结果保存在: {args.output_dir}")


if __name__ == "__main__":
    main()