RCaz commited on
Commit
49414cf
·
1 Parent(s): 7ca98bc

reranker from langchain

Browse files
Files changed (2) hide show
  1. app.py +12 -9
  2. requirements.txt +1 -1
app.py CHANGED
@@ -4,7 +4,8 @@
4
  from dotenv import load_dotenv
5
  import os
6
  load_dotenv()
7
-
 
8
 
9
  from langchain.chat_models import init_chat_model
10
 
@@ -102,7 +103,7 @@ def format_source(doc):
102
  page_label = doc.metadata["pagpage_labele"]
103
  total_page = doc.metadata["total_page"]
104
  return f"{source.split('/')[-1]} page({page_label/total_page})"
105
-
106
  # setup chatbot
107
  from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
108
  from langchain.chat_models import init_chat_model
@@ -153,14 +154,16 @@ def predict(message, history, request: gr.Request):
153
  # Retrieve relevant documents for the current message
154
  relevant_docs = vectorstore.similarity_search(message,k=20) # retriever
155
 
156
- # reranker
157
- from ragatouille import RAGPretrainedModel
158
-
159
- RERANKER = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0")
 
 
 
 
 
160
 
161
- relevant_docs = RERANKER.rerank(message, relevant_docs, k=10)
162
- relevant_docs = [doc["content"] for doc in relevant_docs]
163
-
164
  # Build context from retrieved documents
165
  context = "\nExtracted documents:\n" + "\n".join([
166
  f"Content document {i+1}: {doc.page_content}\n\n---"
 
4
  from dotenv import load_dotenv
5
  import os
6
  load_dotenv()
7
+ from langchain.retrievers import ContextualCompressionRetriever
8
+ from langchain_community.document_compressors import ColbertReranker
9
 
10
  from langchain.chat_models import init_chat_model
11
 
 
103
  page_label = doc.metadata["pagpage_labele"]
104
  total_page = doc.metadata["total_page"]
105
  return f"{source.split('/')[-1]} page({page_label/total_page})"
106
+
107
  # setup chatbot
108
  from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
109
  from langchain.chat_models import init_chat_model
 
154
  # Retrieve relevant documents for the current message
155
  relevant_docs = vectorstore.similarity_search(message,k=20) # retriever
156
 
157
+ # reank docs
158
+ reranker = ColbertReranker(
159
+ model_name="colbert-ir/colbertv2.0",
160
+ top_n=10
161
+ )
162
+ relevant_docs = reranker.compress_documents(
163
+ documents=relevant_docs,
164
+ query=message
165
+ )
166
 
 
 
 
167
  # Build context from retrieved documents
168
  context = "\nExtracted documents:\n" + "\n".join([
169
  f"Content document {i+1}: {doc.page_content}\n\n---"
requirements.txt CHANGED
@@ -11,7 +11,7 @@ langchain==0.3.8
11
  langchain-community==0.3.8
12
  langchain-openai==0.2.9
13
  langchain-huggingface==0.1.0
14
- RAGatouille
15
 
16
  gradio
17
 
 
11
  langchain-community==0.3.8
12
  langchain-openai==0.2.9
13
  langchain-huggingface==0.1.0
14
+
15
 
16
  gradio
17