ImportError: This modeling file requires the following packages that were not found in your environment: einops. Run `pip install einops`

#17
Files changed (2) hide show
  1. app.py +43 -1
  2. requirements.txt +3 -0
app.py CHANGED
@@ -102,7 +102,43 @@ def format_source(doc):
102
  page_label = doc.metadata["pagpage_labele"]
103
  total_page = doc.metadata["total_page"]
104
  return f"{source.split('/')[-1]} page({page_label/total_page})"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  # setup chatbot
107
  from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
108
  from langchain.chat_models import init_chat_model
@@ -151,16 +187,22 @@ def predict(message, history, request: gr.Request):
151
 
152
 
153
  # Retrieve relevant documents for the current message
 
154
  relevant_docs = vectorstore.similarity_search(message,k=20) # retriever
155
 
 
 
 
 
 
156
  # Build context from retrieved documents
 
157
  context = "\nExtracted documents:\n" + "\n".join([
158
  f"Content document {i+1}: {doc.page_content}\n\n---"
159
  for i, doc in enumerate(relevant_docs)
160
  ])
161
 
162
 
163
-
164
  # RAG tool
165
  RAG_PROMPT_TEMPLATE="""You will be asked information related to Rémi Cazelles's specific projects, work and education.
166
  Using the information contained in the context, provide a structured answer to the question.
 
102
  page_label = doc.metadata["pagpage_labele"]
103
  total_page = doc.metadata["total_page"]
104
  return f"{source.split('/')[-1]} page({page_label/total_page})"
105
+
106
+
107
+ # reranker
108
+ from sentence_transformers import CrossEncoder
109
+ import numpy as np
110
+ import torch
111
+
112
+ class ProductionReranker:
113
+ def __init__(self, model_name="jinaai/jina-reranker-v2-base-multilingual"):
114
+ self.model = CrossEncoder(
115
+ model_name,
116
+ max_length=512,
117
+ device='cuda' if torch.cuda.is_available() else 'cpu',
118
+ trust_remote_code=True
119
+ )
120
 
121
+ def rerank(self, query, documents, k=5):
122
+ # Extract text
123
+ doc_texts = [
124
+ doc.page_content if hasattr(doc, 'page_content') else str(doc)
125
+ for doc in documents
126
+ ]
127
+
128
+ # Score in batches for efficiency
129
+ pairs = [[query, doc] for doc in doc_texts]
130
+ scores = self.model.predict(pairs, batch_size=32)
131
+
132
+ # Get top-k
133
+ top_indices = np.argsort(scores)[::-1][:k]
134
+
135
+ # Return with scores
136
+ reranked = [(documents[i], float(scores[i])) for i in top_indices]
137
+ return [doc for doc, score in reranked]
138
+
139
+
140
+
141
+
142
  # setup chatbot
143
  from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
144
  from langchain.chat_models import init_chat_model
 
187
 
188
 
189
  # Retrieve relevant documents for the current message
190
+ print("retreive docs ...")
191
  relevant_docs = vectorstore.similarity_search(message,k=20) # retriever
192
 
193
+ # reank docs
194
+ print("reranking ...")
195
+ RERANKER = ProductionReranker()
196
+ relevant_docs = RERANKER.rerank(message, relevant_docs, k=10)
197
+
198
  # Build context from retrieved documents
199
+ print("build context ...")
200
  context = "\nExtracted documents:\n" + "\n".join([
201
  f"Content document {i+1}: {doc.page_content}\n\n---"
202
  for i, doc in enumerate(relevant_docs)
203
  ])
204
 
205
 
 
206
  # RAG tool
207
  RAG_PROMPT_TEMPLATE="""You will be asked information related to Rémi Cazelles's specific projects, work and education.
208
  Using the information contained in the context, provide a structured answer to the question.
requirements.txt CHANGED
@@ -5,6 +5,8 @@ torchaudio
5
 
6
  sentence-transformers
7
  faiss-cpu
 
 
8
 
9
  langchain-core==0.3.21
10
  langchain==0.3.8
@@ -12,6 +14,7 @@ langchain-community==0.3.8
12
  langchain-openai==0.2.9
13
  langchain-huggingface==0.1.0
14
 
 
15
  gradio
16
 
17
  python-dotenv
 
5
 
6
  sentence-transformers
7
  faiss-cpu
8
+ sentence-transformers>=2.5.0
9
+ einops
10
 
11
  langchain-core==0.3.21
12
  langchain==0.3.8
 
14
  langchain-openai==0.2.9
15
  langchain-huggingface==0.1.0
16
 
17
+
18
  gradio
19
 
20
  python-dotenv