Zeggai Abdellah committed on
Commit
9e56f1b
·
1 Parent(s): cfc25dc

Update the section tools

Browse files
Files changed (1) hide show
  1. prepare_env.py +35 -8
prepare_env.py CHANGED
@@ -128,9 +128,20 @@ def create_vectorstore_from_json(json_path: str, collection_name: str, embedding
128
  print(f"βœ… Vector store created with collection: {collection_name}")
129
  return vectorstore, documents
130
 
131
- def create_retriever(vectorstore, docs, llm):
132
- """Create ensemble retriever with vector and BM25 search"""
 
 
 
 
 
 
 
 
 
 
133
  print("πŸ” Creating ensemble retriever...")
 
134
  # PromptTemplate for Vaccine Assistant MultiQuery Retriever
135
  VACCINE_MULTIQUERY_PROMPT = PromptTemplate(
136
  input_variables=["question"],
@@ -150,17 +161,28 @@ def create_retriever(vectorstore, docs, llm):
150
 
151
  Provide only the alternative questions, one per line."""
152
  )
 
 
 
 
 
 
 
 
 
 
 
153
  # Vector retriever
154
  vector_retriever = vectorstore.as_retriever(
155
  search_type="similarity",
156
- search_kwargs={"k": 6}
157
  )
158
- print("βœ… Vector retriever created (k=6)")
159
 
160
  # BM25 retriever
161
  bm25_retriever = BM25Retriever.from_documents(docs)
162
- bm25_retriever.k = 3
163
- print("βœ… BM25 retriever created (k=2)")
164
 
165
  # Ensemble retriever
166
  ensemble_retriever = EnsembleRetriever(
@@ -169,7 +191,12 @@ def create_retriever(vectorstore, docs, llm):
169
  )
170
  print("βœ… Ensemble retriever created (weights: 0.5, 0.5)")
171
 
172
- # Multi-query expanding retriever
 
 
 
 
 
173
  expanding_retriever = MultiQueryRetriever.from_llm(
174
  retriever=ensemble_retriever,
175
  llm=llm,
@@ -266,7 +293,7 @@ def create_section_tools(embedding_function, llm):
266
  if os.path.exists(path):
267
  print(f"πŸ“ Creating retriever for section {section} from {path}")
268
  vstore, docs = create_vectorstore_from_json(path, f"Guide_2023_{section}", embedding_function)
269
- section_retrievers[section] = create_retriever(vstore, docs, llm)
270
  print(f"βœ… Successfully created retriever for section {section}")
271
  else:
272
  print(f"⚠️ Warning: File not found for section {section}: {path}")
 
128
  print(f"βœ… Vector store created with collection: {collection_name}")
129
  return vectorstore, documents
130
 
131
+ def create_retriever(vectorstore, docs, llm, get_all: bool = False):
132
+ """Create ensemble retriever with vector and BM25 search
133
+
134
+ Args:
135
+ vectorstore: The vector store for similarity search
136
+ docs: Documents for BM25 retriever
137
+ llm: Language model for multi-query generation
138
+ get_all: If True, configure retriever to return all documents
139
+
140
+ Returns:
141
+ Configured retriever (MultiQueryRetriever or EnsembleRetriever)
142
+ """
143
  print("πŸ” Creating ensemble retriever...")
144
+
145
  # PromptTemplate for Vaccine Assistant MultiQuery Retriever
146
  VACCINE_MULTIQUERY_PROMPT = PromptTemplate(
147
  input_variables=["question"],
 
161
 
162
  Provide only the alternative questions, one per line."""
163
  )
164
+
165
+ # Determine k values based on get_all parameter
166
+ if get_all:
167
+ vector_k = len(docs) # Get all documents
168
+ bm25_k = len(docs) # Get all documents
169
+ print(f"πŸ“„ GET_ALL mode: Setting k={len(docs)} (total documents)")
170
+ else:
171
+ vector_k = 6
172
+ bm25_k = 3
173
+ print(f"🎯 FILTERED mode: Vector k={vector_k}, BM25 k={bm25_k}")
174
+
175
  # Vector retriever
176
  vector_retriever = vectorstore.as_retriever(
177
  search_type="similarity",
178
+ search_kwargs={"k": vector_k}
179
  )
180
+ print(f"βœ… Vector retriever created (k={vector_k})")
181
 
182
  # BM25 retriever
183
  bm25_retriever = BM25Retriever.from_documents(docs)
184
+ bm25_retriever.k = bm25_k
185
+ print(f"βœ… BM25 retriever created (k={bm25_k})")
186
 
187
  # Ensemble retriever
188
  ensemble_retriever = EnsembleRetriever(
 
191
  )
192
  print("βœ… Ensemble retriever created (weights: 0.5, 0.5)")
193
 
194
+ # If get_all is True, return ensemble retriever directly to avoid query processing overhead
195
+ if get_all:
196
+ print("πŸ“‹ Returning ensemble retriever (bypassing MultiQuery for get_all mode)")
197
+ return ensemble_retriever
198
+
199
+ # Multi-query expanding retriever (only for filtered mode)
200
  expanding_retriever = MultiQueryRetriever.from_llm(
201
  retriever=ensemble_retriever,
202
  llm=llm,
 
293
  if os.path.exists(path):
294
  print(f"πŸ“ Creating retriever for section {section} from {path}")
295
  vstore, docs = create_vectorstore_from_json(path, f"Guide_2023_{section}", embedding_function)
296
+ section_retrievers[section] = create_retriever(vstore, docs, llm, get_all=True)
297
  print(f"βœ… Successfully created retriever for section {section}")
298
  else:
299
  print(f"⚠️ Warning: File not found for section {section}: {path}")