MANIT_Chat / server /utils /RetrieveQuery.py
WizardCoder2007's picture
update
99b1b18
Raw
History Blame Contribute Delete
1.7 kB
from .BM25_to_Dict import convert_bm25_to_dict
def QueryRetriever(queries, vector_retriever, keyword_retriever, chunks_dict,
                   *, top_n=12, vector_top_k=10):
    """Merge vector and BM25 results for several queries via reciprocal rank fusion.

    For every query the vector retriever and the keyword (BM25) retriever are
    both consulted. Each chunk then accumulates ``1 / (k + rank)`` (rank is
    1-based) from:

    * every per-query BM25 ranking in which it appears, and
    * a single global vector ranking, built by keeping each chunk's best
      ``similarity_score`` across all queries and sorting descending.

    Args:
        queries: Iterable of query strings.
        vector_retriever: Object exposing ``retrieve(query, top_k=...)`` that
            returns dicts carrying ``['metadata']['chunk_id']`` and
            ``['similarity_score']``.
        keyword_retriever: Object exposing ``invoke(query)`` and an attribute
            ``k``. Its results are passed through ``convert_bm25_to_dict`` and
            are expected to carry an ``'id'`` key afterwards.
        chunks_dict: Container of known chunk ids; only used for membership
            tests on the fused ranking.
        top_n: How many of the highest-scoring chunk ids to consider
            (default 12, the value that used to be hard-coded).
        vector_top_k: ``top_k`` forwarded to ``vector_retriever.retrieve``
            (default 10, the value that used to be hard-coded).

    Returns:
        List of document dicts ordered by descending fused score. Ids missing
        from ``chunks_dict`` are dropped *after* the ``top_n`` cut, so the
        list may contain fewer than ``top_n`` entries. For a chunk seen by both
        retrievers, the returned dict is whichever representation was stored
        last (the BM25 one within a given query).
    """
    # NOTE(review): `k` is reused as the RRF smoothing constant. If
    # `keyword_retriever` is a LangChain BM25Retriever, `k` is its result
    # count rather than a dedicated RRF constant -- confirm this is intended.
    k = keyword_retriever.k

    rrf_scores = {}        # chunk id -> accumulated reciprocal-rank score
    best_vector_hits = {}  # chunk id -> {"doc": ..., "score": best similarity}
    final_documents = {}   # chunk id -> document dict that will be returned

    for query in queries:
        vector_results = vector_retriever.retrieve(query, top_k=vector_top_k)
        bm25_results = convert_bm25_to_dict(keyword_retriever.invoke(query))

        # Vector hits are only collected here; they are ranked once, globally,
        # after all queries have been processed.
        for doc in vector_results:
            doc_id = doc['metadata']['chunk_id']
            final_documents[doc_id] = doc
            curr_score = doc['similarity_score']
            previous = best_vector_hits.get(doc_id)
            if previous is None or curr_score > previous['score']:
                best_vector_hits[doc_id] = {"doc": doc, "score": curr_score}

        # BM25 hits contribute immediately, using their rank within this query.
        for rank, doc in enumerate(bm25_results, start=1):
            chunk_id = doc['id']
            final_documents[chunk_id] = doc
            rrf_scores[chunk_id] = rrf_scores.get(chunk_id, 0.0) + 1 / (k + rank)

    ranked_vector_hits = sorted(
        best_vector_hits.values(), key=lambda hit: hit['score'], reverse=True
    )
    for rank, hit in enumerate(ranked_vector_hits, start=1):
        chunk_id = hit['doc']['metadata']['chunk_id']
        rrf_scores[chunk_id] = rrf_scores.get(chunk_id, 0.0) + 1 / (k + rank)

    # Highest fused score first; the sort is stable, so ties keep insertion order.
    ranked = sorted(rrf_scores.items(), key=lambda item: item[1], reverse=True)

    # Keep the top `top_n` ids, then discard any that are not known chunks.
    return [
        final_documents[chunk_id]
        for chunk_id, _ in ranked[:top_n]
        if chunk_id in chunks_dict
    ]