""" 检索优化模块 """ import logging from typing import List, Dict, Any from langchain_community.vectorstores import FAISS from langchain_community.retrievers import BM25Retriever from langchain_core.documents import Document logger = logging.getLogger(__name__) class RetrievalOptimizationModule: """检索优化模块,混合检索,通过多个检索器检索的结果来进行优化 """ def __init__(self, vectorstore: FAISS, chunks: List[Document]): """初始化检索优化模块 :param vectorstore: 构建好了的向量索引 :type vectorstore: FAISS :param chunks: 文档块列表 :type chunks: List[Document] """ self.vectorstore = vectorstore self.chunks = chunks self.setup_retrievers() def setup_retrievers(self): """构建向量检索器和BM25检索器 """ logger.info("开始构建检索器...") #向量检索器 self.vector_retriever = self.vectorstore.as_retriever( search_type = "similarity", search_kwargs = {"k" : 5} ) #BM25检索器 #这里每次进行排序的时候都需要用到chunks,所以为了避免每次都切分文档加载chunks,直接保存到本地之后从本地加载 self.bm25_retirever = BM25Retriever.from_documents( documents = self.chunks, k = 5 ) logger.info("检索器构建完成") def filtered_hybrid_search(self, query: str, top_k : int = 3, filters : dict = None) -> List[Document]: """使用RRF将两种检索器结果融合,混合检索 :param query: 查询的文本 :type query: str :param top_k: 相似文档个数, defaults to 3 :type top_k: int, optional :param filters: 过滤条件 :type filters: dict :return: 相似文档列表 :rtype: List[Document] """ #这里需要塞入filters,直接再setup_retrievers塞入的话每次询问都需要重新构建检索库 if filters: #如果本次有过滤条件,就临时赛进检索器的配置里面 self.vector_retriever.search_kwargs['filter'] = filters else: #没有过滤条件,要把上一次的残留清空 if "filter" in self.vector_retriever.search_kwargs: del self.vector_retriever.search_kwargs['filter'] #分别获取向量检索器和BM25检索器的结果 vector_docs = self.vector_retriever.invoke(query) raw_bm25_docs = self.bm25_retirever.invoke(query) #先进行应用元数据过滤 bm25_docs = [] if filters: for doc in raw_bm25_docs: match = True #query:["category" : ["各种种类"]]这种格式是需要大模型进行提取的 for key, value in filters.items(): if key in doc.metadata: if isinstance(value, list): if doc.metadata[key] not in value: match = False break else: if doc.metadata[key] != value: match = False break else: match = False break if match: bm25_docs.append(doc) logger.info(f"过滤完成,一共得到{len(vector_docs) + len(bm25_docs)}个文档") else: bm25_docs = raw_bm25_docs #使用RRF重排 reranked_docs = self._rrf_rerank(vector_docs, bm25_docs) return reranked_docs[:top_k] def _rrf_rerank(self, vector_docs: List[Document], bm25_docs: List[Document], k :int = 60) -> List[Document]: """构建RRF算法结构并实现重排 :param vector_docs: 向量检索结果 :type vector_docs: List[Document] :param bm25_docs: BM25检索结果 :type bm25_docs: List[Document] :param k: RRF参数,用于平滑排名, defaults to 60 :type k: int, optional :return: 重排之后的document列表 :rtype: List[Document] """ doc_scores = {} doc_objects = {} #计算向量检索结果的RRF分数 for rank, doc in enumerate(vector_docs): #将文本内容转化为hash值作为文档的id doc_id = hash(doc.page_content) doc_objects[doc_id] = doc #RRF公式:1/(k + rank) rrf_score = 1.0 / (k + rank + 1)#这里没有排名为0的,但是索引有 #这里用累加而不是覆盖,应为下面的BM25计算分数的时候也会用这个字典进行加分 doc_scores[doc_id] = doc_scores.get(doc_id, 0) + rrf_score logger.debug(f"向量检索 - 文档{rank + 1}: RRF分数 = {rrf_score:.4f}") #计算BM25检索结果的RRF分数 for rank, doc in enumerate(bm25_docs): #用hash值生成id doc_id = hash(doc.page_content) doc_objects[doc_id] = doc rrf_score = 1.0 / (k + rank + 1) doc_scores[doc_id] = doc_scores.get(doc_id, 0) + rrf_score logger.debug(f"BM25检索 - 文档{rank + 1}: RRF分数 = {rrf_score:.4f}") #按最终的RRF分数进行排序 sorted_docs = sorted(doc_scores.items(), key = lambda x : x[1], reverse=True)#这里用items将字典打包成元组 #构建最终结果 reranked_docs = [] for doc_id, final_score in sorted_docs: if doc_id in doc_objects: doc = doc_objects[doc_id] #将RRF分数添加到文档的元数据中 doc.metadata["rrf_score"] = final_score reranked_docs.append(doc) logger.debug(f"最终排序 - 文档:{doc.page_content[:50]}... 最终分数为:{final_score:.4f}") logger.info(f"RRF重排结束:向量索引{len(vector_docs)}个文档,BM25检索{len(bm25_docs)}个文档,合并后{len(reranked_docs)}个文档") return reranked_docs