""" 自适应RAG系统检索效果评估脚本 评估不同检索策略和配置的效果 """ import os import sys import time import json import argparse from typing import List, Dict, Any, Optional from dotenv import load_dotenv # 加载环境变量 load_dotenv() # 导入项目模块 from main import AdaptiveRAGSystem from document_processor import DocumentProcessor from retrieval_evaluation import RetrievalEvaluator, RetrievalResult, RetrievalTestSet try: from langchain_core.documents import Document except ImportError: try: from langchain_core.documents import Document except ImportError: from langchain.schema import Document # 导入LangChain相关模块 from langchain_community.vectorstores import FAISS, Chroma from langchain_community.retrievers import BM25Retriever try: from langchain.retrievers import EnsembleRetriever from langchain.retrievers import ContextualCompressionRetriever from langchain.retrievers.document_compressors import LLMChainExtractor except ImportError: try: from langchain_core.retrievers import EnsembleRetriever from langchain_core.retrievers import ContextualCompressionRetriever from langchain.retrievers.document_compressors import LLMChainExtractor except ImportError: print("Warning: Could not import advanced retriever components. 
Some features may be limited.") EnsembleRetriever = None ContextualCompressionRetriever = None LLMChainExtractor = None class AdaptiveRAGRetriever: """自适应RAG系统检索器包装器""" def __init__(self, system_config: Dict[str, Any], retriever_type: str = "default"): """ 初始化检索器 Args: system_config: 系统配置 retriever_type: 检索器类型 """ self.system_config = system_config self.retriever_type = retriever_type self.system = None self._initialize_system() def _initialize_system(self): """初始化RAG系统""" try: # 根据检索器类型调整配置 config = self.system_config.copy() if self.retriever_type == "vector_only": config["retrieval_strategy"] = "vector" elif self.retriever_type == "bm25_only": config["retrieval_strategy"] = "bm25" elif self.retriever_type == "hybrid": config["retrieval_strategy"] = "hybrid" elif self.retriever_type == "graph": config["retrieval_strategy"] = "graph" elif self.retriever_type == "compression": config["use_compression"] = True elif self.retriever_type == "rerank": config["use_reranking"] = True elif self.retriever_type == "query_expansion": config["use_query_expansion"] = True # 创建系统实例 self.system = AdaptiveRAGSystem(config) # 初始化文档处理器(如果需要) if not hasattr(self.system, 'document_processor') or self.system.document_processor is None: self.system.document_processor = DocumentProcessor(config) except Exception as e: print(f"初始化RAG系统失败: {e}") raise def retrieve(self, query: str, top_k: int = 10) -> List[Document]: """ 检索文档 Args: query: 查询文本 top_k: 返回的文档数量 Returns: 检索到的文档列表 """ try: # 使用系统的检索方法 if hasattr(self.system, 'retrieve'): docs = self.system.retrieve(query, top_k) else: # 如果没有直接的retrieve方法,尝试通过文档处理器检索 if self.system.document_processor: docs = self.system.document_processor.retrieve(query, top_k) else: raise ValueError("无法找到检索方法") return docs[:top_k] except Exception as e: print(f"检索失败: {e}") return [] def create_evaluation_dataset(data_dir: str = "data", num_queries: int = 20) -> RetrievalTestSet: """ 从项目数据创建评估数据集 Args: data_dir: 数据目录 num_queries: 查询数量 Returns: 检索测试集 """ # 检查数据目录 
if not os.path.exists(data_dir): print(f"数据目录 {data_dir} 不存在,创建示例数据集") from retrieval_evaluation import create_sample_test_set return create_sample_test_set() # 尝试从现有数据创建测试集 try: # 加载文档 documents = [] doc_files = [] # 查找所有文本文件 for root, dirs, files in os.walk(data_dir): for file in files: if file.endswith('.txt') or file.endswith('.md'): doc_files.append(os.path.join(root, file)) # 如果没有找到文档文件,创建示例数据集 if not doc_files: print(f"在 {data_dir} 中未找到文档文件,创建示例数据集") from retrieval_evaluation import create_sample_test_set return create_sample_test_set() # 读取文档内容 for i, file_path in enumerate(doc_files): with open(file_path, 'r', encoding='utf-8') as f: content = f.read().strip() if content: documents.append(Document(page_content=content, metadata={"source": file_path, "doc_id": str(i)})) # 生成查询(这里简化处理,实际应用中应该使用真实查询) queries = [] qrels = {} # 从文档中提取关键句子作为查询 for i in range(min(num_queries, len(documents))): doc = documents[i] sentences = doc.page_content.split('.') if sentences: # 取第一个非空句子作为查询 for sentence in sentences: sentence = sentence.strip() if sentence and len(sentence) > 10: # 确保查询有足够长度 queries.append(sentence) # 假设查询与当前文档相关 qrels[str(i)] = {str(i): 2} # 高度相关 # 可能与其他文档也相关 for j in range(min(3, len(documents))): if j != i: qrels[str(i)][str(j)] = 1 # 部分相关 break # 保存查询文件 with open("eval_queries.txt", "w", encoding="utf-8") as f: for query in queries: f.write(query + "\n") # 保存文档文件 with open("eval_documents.txt", "w", encoding="utf-8") as f: for doc in documents: f.write(doc.page_content + "\n") # 保存相关性标注文件 with open("eval_qrels.csv", "w", encoding="utf-8") as f: for query_id, doc_relevance in qrels.items(): for doc_id, relevance in doc_relevance.items(): f.write(f"{query_id},{doc_id},{relevance}\n") print(f"评估数据集已创建:") print(f"- 查询数量: {len(queries)}") print(f"- 文档数量: {len(documents)}") print(f"- eval_queries.txt: 查询文件") print(f"- eval_documents.txt: 文档文件") print(f"- eval_qrels.csv: 相关性标注文件") return RetrievalTestSet("eval_queries.txt", "eval_documents.txt", 
"eval_qrels.csv") except Exception as e: print(f"创建评估数据集失败: {e}") print("创建示例数据集") from retrieval_evaluation import create_sample_test_set return create_sample_test_set() def evaluate_retrievers(system_config: Dict[str, Any], retriever_types: List[str], test_set: RetrievalTestSet, output_dir: str = "evaluation_results") -> Dict[str, Any]: """ 评估多个检索器 Args: system_config: 系统配置 retriever_types: 检索器类型列表 test_set: 测试集 output_dir: 输出目录 Returns: 评估结果 """ # 创建输出目录 os.makedirs(output_dir, exist_ok=True) # 初始化评估器 evaluator = RetrievalEvaluator() # 存储所有检索结果 all_results = {} # 评估每个检索器 for retriever_type in retriever_types: print(f"\n评估检索器: {retriever_type}") print("=" * 50) try: # 创建检索器 retriever = AdaptiveRAGRetriever(system_config, retriever_type) # 获取检索结果 results = test_set.get_retrieval_results(retriever) all_results[retriever_type] = results print(f"完成 {len(results)} 个查询的检索") except Exception as e: print(f"评估检索器 {retriever_type} 失败: {e}") continue # 比较检索器 if len(all_results) > 1: print("\n比较检索器性能") print("=" * 50) metrics = evaluator.compare_retrievers(all_results) # 生成报告 report = evaluator.generate_report( metrics, os.path.join(output_dir, "retrieval_evaluation_report.md") ) # 绘制比较图 evaluator.plot_metrics_comparison( metrics, os.path.join(output_dir, "retrieval_evaluation_comparison.png") ) # 保存详细指标 metrics_data = {} for name, metric in metrics.items(): metrics_data[name] = { "precision_at_k": metric.precision_at_k, "recall_at_k": metric.recall_at_k, "f1_at_k": metric.f1_at_k, "map_score": metric.map_score, "mrr": metric.mrr, "ndcg_at_k": metric.ndcg_at_k, "coverage": metric.coverage, "diversity": metric.diversity, "novelty": metric.novelty, "latency": metric.latency } with open(os.path.join(output_dir, "metrics.json"), "w", encoding="utf-8") as f: json.dump(metrics_data, f, indent=2, ensure_ascii=False) return { "metrics": metrics, "metrics_data": metrics_data, "report": report, "results": all_results } else: print("只有一个检索器成功评估,跳过比较") return {"results": all_results} 
def _load_system_config(config_path: str) -> Dict[str, Any]:
    """Load the system config from a ``.py`` or JSON file.

    A ``.py`` file is executed as a module and its ``config`` attribute is
    used; any other path is parsed as JSON. On any failure the error is
    printed and a built-in default configuration is returned.
    """
    try:
        if not config_path.endswith('.py'):
            with open(config_path, 'r', encoding='utf-8') as handle:
                return json.load(handle)

        # Import the Python config file dynamically from its path.
        import importlib.util
        module_spec = importlib.util.spec_from_file_location("config", config_path)
        module = importlib.util.module_from_spec(module_spec)
        module_spec.loader.exec_module(module)
        return module.config
    except Exception as e:
        print(f"加载配置文件失败: {e}")
        print("使用默认配置")
        return {
            "model_name": "gpt-3.5-turbo",
            "vector_store": "faiss",
            "retrieval_strategy": "hybrid",
            "use_reranking": False,
            "use_compression": False,
            "use_query_expansion": False
        }


def main():
    """Command-line entry point: build the test set and evaluate the retrievers."""
    cli = argparse.ArgumentParser(description="评估自适应RAG系统的检索效果")
    cli.add_argument("--config", type=str, default="config.py", help="配置文件路径")
    cli.add_argument("--data_dir", type=str, default="data", help="数据目录")
    cli.add_argument("--output_dir", type=str, default="evaluation_results", help="输出目录")
    cli.add_argument("--num_queries", type=int, default=20, help="查询数量")
    cli.add_argument(
        "--retrievers",
        nargs="+",
        default=["default", "vector_only", "bm25_only", "hybrid"],
        help="要评估的检索器类型",
    )
    options = cli.parse_args()

    system_config = _load_system_config(options.config)

    # Build the evaluation data set.
    print("创建评估数据集")
    test_set = create_evaluation_dataset(options.data_dir, options.num_queries)

    # Evaluate every requested retriever; the artefacts land in output_dir.
    print("\n开始评估检索器")
    evaluate_retrievers(system_config, options.retrievers, test_set, options.output_dir)

    print("\n评估完成!")
    print(f"结果保存在: {options.output_dir}")


if __name__ == "__main__":
    main()