Kshitijk20 commited on
Commit
4bcb337
·
1 Parent(s): 7114101

reranker added

Browse files
app/retrieval/reranker.py CHANGED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional
2
+ from langchain_core.documents import Document
3
+ from langchain.prompts import PromptTemplate
4
+ from langchain_core.output_parsers import StrOutputParser
5
+
6
+ class Reranker:
7
+ """Reranker class to rerank retrieved documents based on relevance to the query."""
8
+ def __init__(self,llm, retrieved_docs:List[Document],query:str) ->List[Document]:
9
+ self.llm = llm
10
+ self.retrieved_docs = retrieved_docs
11
+ self.query = query
12
+
13
+ def rerank_documents(self) -> List[Document]:
14
+ """
15
+ Rerank the retrieved documents based on their relevance to the query.
16
+
17
+ Args:
18
+ retrieved_docs: List of Document objects retrieved from the retriever.
19
+ query: The original user query string.
20
+
21
+ Returns:
22
+ List of Document objects sorted by relevance to the query.
23
+ """
24
+ # Create a prompt template for scoring
25
+ # Prompt Template
26
+ prompt_template = PromptTemplate.from_template("""
27
+ You are a helpful assistant. Your task is to rank the following documents from most to least relevant to the user's question.
28
+
29
+ User Question: "{question}"
30
+
31
+ Documents:
32
+ {documents}
33
+
34
+ Instructions:
35
+ - Think about the relevance of each document to the user's question.
36
+ - Return a list of document indices in ranked order, starting from the most relevant.
37
+
38
+ Output format: comma-separated document indices (e.g., 2,1,3,0,...)
39
+ """)
40
+
41
+ chain=prompt_template | self.llm | StrOutputParser()
42
+ doc_texts = [f"{i+1}. {doc.page_content}" for i,doc in enumerate(self.retrieved_docs)]
43
+ response = chain.invoke({
44
+ "question": self.query,
45
+ "documents": "\n".join(doc_texts)
46
+ })
47
+ ranked_indices = [int(i) for i in response.split(",")]
48
+ return [self.retrieved_docs[i-1] for i in ranked_indices]
app/retrieval/retriever.py CHANGED
@@ -1,5 +1,6 @@
1
 
2
  from langchain.retrievers import EnsembleRetriever
 
3
 
4
  class Retriever:
5
  def __init__(self, pinecone_index, query = None, metadata = None, namespace=None, vectore_store = None,sparse_retriever = None, llm = None):
@@ -66,4 +67,7 @@ class Retriever:
66
  results = self.hybrid_retriever.invoke(self.query)
67
  for doc in results:
68
  print(f"printing Doc content : {doc.page_content}")
 
 
 
69
  return results
 
1
 
2
  from langchain.retrievers import EnsembleRetriever
3
+ from app.retrieval.reranker import Reranker
4
 
5
  class Retriever:
6
  def __init__(self, pinecone_index, query = None, metadata = None, namespace=None, vectore_store = None,sparse_retriever = None, llm = None):
 
67
  results = self.hybrid_retriever.invoke(self.query)
68
  for doc in results:
69
  print(f"printing Doc content : {doc.page_content}")
70
+ if self.llm:
71
+ reranker = Reranker(self.llm, results, self.query)
72
+ results = reranker.rerank_documents()
73
  return results