Spaces:
Sleeping
Sleeping
Zeggai Abdellah
commited on
Commit
Β·
9e56f1b
1
Parent(s):
cfc25dc
update the seaction tools
Browse files- prepare_env.py +35 -8
prepare_env.py
CHANGED
|
@@ -128,9 +128,20 @@ def create_vectorstore_from_json(json_path: str, collection_name: str, embedding
|
|
| 128 |
print(f"β
Vector store created with collection: {collection_name}")
|
| 129 |
return vectorstore, documents
|
| 130 |
|
| 131 |
-
def create_retriever(vectorstore, docs, llm):
|
| 132 |
-
"""Create ensemble retriever with vector and BM25 search
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
print("π Creating ensemble retriever...")
|
|
|
|
| 134 |
# PromptTemplate for Vaccine Assistant MultiQuery Retriever
|
| 135 |
VACCINE_MULTIQUERY_PROMPT = PromptTemplate(
|
| 136 |
input_variables=["question"],
|
|
@@ -150,17 +161,28 @@ def create_retriever(vectorstore, docs, llm):
|
|
| 150 |
|
| 151 |
Provide only the alternative questions, one per line."""
|
| 152 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
# Vector retriever
|
| 154 |
vector_retriever = vectorstore.as_retriever(
|
| 155 |
search_type="similarity",
|
| 156 |
-
search_kwargs={"k":
|
| 157 |
)
|
| 158 |
-
print("β
Vector retriever created (k=
|
| 159 |
|
| 160 |
# BM25 retriever
|
| 161 |
bm25_retriever = BM25Retriever.from_documents(docs)
|
| 162 |
-
bm25_retriever.k =
|
| 163 |
-
print("β
BM25 retriever created (k=
|
| 164 |
|
| 165 |
# Ensemble retriever
|
| 166 |
ensemble_retriever = EnsembleRetriever(
|
|
@@ -169,7 +191,12 @@ def create_retriever(vectorstore, docs, llm):
|
|
| 169 |
)
|
| 170 |
print("β
Ensemble retriever created (weights: 0.5, 0.5)")
|
| 171 |
|
| 172 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
expanding_retriever = MultiQueryRetriever.from_llm(
|
| 174 |
retriever=ensemble_retriever,
|
| 175 |
llm=llm,
|
|
@@ -266,7 +293,7 @@ def create_section_tools(embedding_function, llm):
|
|
| 266 |
if os.path.exists(path):
|
| 267 |
print(f"π Creating retriever for section {section} from {path}")
|
| 268 |
vstore, docs = create_vectorstore_from_json(path, f"Guide_2023_{section}", embedding_function)
|
| 269 |
-
section_retrievers[section] = create_retriever(vstore, docs, llm)
|
| 270 |
print(f"β
Successfully created retriever for section {section}")
|
| 271 |
else:
|
| 272 |
print(f"β οΈ Warning: File not found for section {section}: {path}")
|
|
|
|
| 128 |
print(f"β
Vector store created with collection: {collection_name}")
|
| 129 |
return vectorstore, documents
|
| 130 |
|
| 131 |
+
def create_retriever(vectorstore, docs, llm, get_all: bool = False):
|
| 132 |
+
"""Create ensemble retriever with vector and BM25 search
|
| 133 |
+
|
| 134 |
+
Args:
|
| 135 |
+
vectorstore: The vector store for similarity search
|
| 136 |
+
docs: Documents for BM25 retriever
|
| 137 |
+
llm: Language model for multi-query generation
|
| 138 |
+
get_all: If True, configure retriever to return all documents
|
| 139 |
+
|
| 140 |
+
Returns:
|
| 141 |
+
Configured retriever (MultiQueryRetriever or EnsembleRetriever)
|
| 142 |
+
"""
|
| 143 |
print("π Creating ensemble retriever...")
|
| 144 |
+
|
| 145 |
# PromptTemplate for Vaccine Assistant MultiQuery Retriever
|
| 146 |
VACCINE_MULTIQUERY_PROMPT = PromptTemplate(
|
| 147 |
input_variables=["question"],
|
|
|
|
| 161 |
|
| 162 |
Provide only the alternative questions, one per line."""
|
| 163 |
)
|
| 164 |
+
|
| 165 |
+
# Determine k values based on get_all parameter
|
| 166 |
+
if get_all:
|
| 167 |
+
vector_k = len(docs) # Get all documents
|
| 168 |
+
bm25_k = len(docs) # Get all documents
|
| 169 |
+
print(f"π GET_ALL mode: Setting k={len(docs)} (total documents)")
|
| 170 |
+
else:
|
| 171 |
+
vector_k = 6
|
| 172 |
+
bm25_k = 3
|
| 173 |
+
print(f"π― FILTERED mode: Vector k={vector_k}, BM25 k={bm25_k}")
|
| 174 |
+
|
| 175 |
# Vector retriever
|
| 176 |
vector_retriever = vectorstore.as_retriever(
|
| 177 |
search_type="similarity",
|
| 178 |
+
search_kwargs={"k": vector_k}
|
| 179 |
)
|
| 180 |
+
print(f"β
Vector retriever created (k={vector_k})")
|
| 181 |
|
| 182 |
# BM25 retriever
|
| 183 |
bm25_retriever = BM25Retriever.from_documents(docs)
|
| 184 |
+
bm25_retriever.k = bm25_k
|
| 185 |
+
print(f"β
BM25 retriever created (k={bm25_k})")
|
| 186 |
|
| 187 |
# Ensemble retriever
|
| 188 |
ensemble_retriever = EnsembleRetriever(
|
|
|
|
| 191 |
)
|
| 192 |
print("β
Ensemble retriever created (weights: 0.5, 0.5)")
|
| 193 |
|
| 194 |
+
# If get_all is True, return ensemble retriever directly to avoid query processing overhead
|
| 195 |
+
if get_all:
|
| 196 |
+
print("π Returning ensemble retriever (bypassing MultiQuery for get_all mode)")
|
| 197 |
+
return ensemble_retriever
|
| 198 |
+
|
| 199 |
+
# Multi-query expanding retriever (only for filtered mode)
|
| 200 |
expanding_retriever = MultiQueryRetriever.from_llm(
|
| 201 |
retriever=ensemble_retriever,
|
| 202 |
llm=llm,
|
|
|
|
| 293 |
if os.path.exists(path):
|
| 294 |
print(f"π Creating retriever for section {section} from {path}")
|
| 295 |
vstore, docs = create_vectorstore_from_json(path, f"Guide_2023_{section}", embedding_function)
|
| 296 |
+
section_retrievers[section] = create_retriever(vstore, docs, llm, get_all=True)
|
| 297 |
print(f"β
Successfully created retriever for section {section}")
|
| 298 |
else:
|
| 299 |
print(f"β οΈ Warning: File not found for section {section}: {path}")
|