from langchain.retrievers import ContextualCompressionRetriever
from langchain_cohere.rerank import CohereRerank
from langchain_core.documents import Document
from langchain_core.vectorstores import VectorStoreRetriever

# The reranking path over-fetches this many times the requested number of
# documents from the vector store, so the reranker has candidates to choose from.
_RERANK_CANDIDATE_FACTOR = 10

# Cohere model used to rerank the similarity-search candidates.
_RERANK_MODEL = "rerank-multilingual-v3.0"


def retrieve(embedding, q: str, retrieve_document_count: int) -> list[Document]:
    """Return the documents most similar to ``q`` from the vector store.

    Args:
        embedding: Object exposing ``get_vector_store()``; the returned store
            must provide ``as_retriever()``.
        q: Query text.
        retrieve_document_count: Number of documents to request (``k``).

    Returns:
        The documents produced by the similarity retriever.
    """
    retriever: VectorStoreRetriever = embedding.get_vector_store().as_retriever(
        search_type="similarity",
        search_kwargs={"k": retrieve_document_count},
    )
    # ``k`` is already fixed through ``search_kwargs``. The former extra
    # ``kwargs={"k": ...}`` argument was a keyword literally named "kwargs",
    # not a ``k`` override, so it is dropped. ``invoke`` replaces the
    # deprecated ``get_relevant_documents``.
    return retriever.invoke(q)


def retrieve_with_rerank(
    embedding, q: str, retrieve_document_count: int
) -> list[Document]:
    """Return documents for ``q`` after reranking similarity-search candidates.

    Args:
        embedding: Object exposing ``get_vector_store()``; the returned store
            must provide ``as_retriever()``.
        q: Query text.
        retrieve_document_count: Number of reranked documents to return.

    Returns:
        The documents produced by the reranking retriever.
    """
    compression_retriever = reranking_retriever(embedding, retrieve_document_count)
    return compression_retriever.invoke(q)


def reranking_retriever(
    embedding, retrieve_document_count: int
) -> ContextualCompressionRetriever:
    """Build a retriever that over-fetches candidates and reranks them with Cohere.

    The base retriever requests ``retrieve_document_count * 10`` similar
    documents; the Cohere reranker then keeps the best
    ``retrieve_document_count`` of them.

    Args:
        embedding: Object exposing ``get_vector_store()``; the returned store
            must provide ``as_retriever()``.
        retrieve_document_count: Number of documents the reranker should keep.

    Returns:
        A ``ContextualCompressionRetriever`` wrapping the similarity retriever
        and the Cohere reranker.
    """
    base_retriever: VectorStoreRetriever = embedding.get_vector_store().as_retriever(
        search_type="similarity",
        search_kwargs={"k": retrieve_document_count * _RERANK_CANDIDATE_FACTOR},
    )
    # Without an explicit ``top_n`` the reranker keeps only its library default
    # number of documents (3), silently ignoring the requested count.
    # NOTE(review): no credentials are passed here, so CohereRerank presumably
    # reads them from the environment -- confirm against deployment config.
    reranker = CohereRerank(model=_RERANK_MODEL, top_n=retrieve_document_count)
    return ContextualCompressionRetriever(
        base_compressor=reranker,
        base_retriever=base_retriever,
    )


# TODO: implement a ``hyde(agent, q, retrieve_document_count)`` retrieval variant.
# The earlier, never-enabled sketch built a similarity retriever over
# ``agent.embedding.get_vector_store()`` with k = retrieve_document_count * 10,
# fetched the documents for ``q`` and returned them. It referenced an undefined
# ``compression_retriever``, so it was not runnable as written.