trykopy / retrieval.py
Pavol Liška
v1-fix
3c35194
raw
history blame
2.04 kB
from langchain.retrievers import ContextualCompressionRetriever
from langchain_cohere.rerank import CohereRerank
from langchain_core.vectorstores import VectorStoreRetriever
def retrieve(embedding, q: str, retrieve_document_count: int):
    """Return the documents most similar to ``q`` from the embedding's vector store.

    Args:
        embedding: Object exposing ``get_vector_store()``; the returned store
            must provide ``as_retriever(...)``.
        q: The query text to search for.
        retrieve_document_count: Number of documents requested from the
            similarity search (passed as ``k``).

    Returns:
        Whatever the retriever yields for the query (a list of documents).
    """
    retriever: VectorStoreRetriever = embedding.get_vector_store().as_retriever(
        search_type="similarity",
        search_kwargs={"k": retrieve_document_count},
    )
    # The result size is already fixed through ``search_kwargs`` above, so only
    # the query is passed here. ``invoke`` is the current retriever entry point
    # (``get_relevant_documents`` is deprecated) and matches how the reranking
    # path in this module queries its retriever.
    return retriever.invoke(q)
def retrieve_with_rerank(embedding, q: str, retrieve_document_count: int):
    """Retrieve documents for ``q`` through the reranking retriever.

    Args:
        embedding: Object exposing ``get_vector_store()``; forwarded to
            ``reranking_retriever``.
        q: The query text to search for.
        retrieve_document_count: Document count forwarded to
            ``reranking_retriever``, which decides how many candidates are
            fetched and how the retriever is configured.

    Returns:
        The documents produced by the compression (reranking) retriever.
    """
    compression_retriever = reranking_retriever(embedding, retrieve_document_count)
    # Only the query is passed: result sizing is part of the retriever's own
    # configuration, and a literal ``kwargs=...`` keyword on ``invoke`` does
    # not act as a limit (it is merely forwarded down the retriever chain).
    return compression_retriever.invoke(q)
def reranking_retriever(embedding, retrieve_document_count: int):
    """Build a retriever that over-fetches by similarity and reranks with Cohere.

    Args:
        embedding: Object exposing ``get_vector_store()``; the returned store
            must provide ``as_retriever(...)``.
        retrieve_document_count: Number of documents the reranker should keep.
            Ten times as many candidates are fetched from the vector store.

    Returns:
        A ``ContextualCompressionRetriever`` wrapping the similarity retriever
        with a ``CohereRerank`` compressor.
    """
    retriever: VectorStoreRetriever = embedding.get_vector_store().as_retriever(
        search_type="similarity",
        # Fetch a 10x larger candidate pool for the reranker to choose from.
        search_kwargs={"k": retrieve_document_count * 10},
    )
    compressor = CohereRerank(
        model="rerank-multilingual-v3.0",
        # Without an explicit ``top_n`` the reranker keeps its library default
        # number of documents, regardless of how many the caller asked for.
        top_n=retrieve_document_count,
    )
    return ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=retriever,
    )
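# TODO: finish the `hyde` retrieval variant drafted below; the draft is incomplete (it uses `compression_retriever` without ever defining it).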
# def hyde(agent: Agent, q, retrieve_document_count):
# retriever: VectorStoreRetriever = agent.embedding.get_vector_store().as_retriever(
# search_type="similarity",
# search_kwargs={"k": retrieve_document_count * 10}
# )
#
# context_doc = compression_retriever.get_relevant_documents(
# query=q,
# kwargs={"k": retrieve_document_count}
# )
#
# for doc in context_doc:
# text = doc.page_content
# print(" kontext: " + text.replace('\n', ' ').replace('\r', ' '))
#
# return context_doc