Spaces:
Running
Running
Update custom_utils.py
Browse files- custom_utils.py +37 -16
custom_utils.py
CHANGED
|
@@ -25,6 +25,7 @@ def rag_retrieval_naive(openai_api_key,
|
|
| 25 |
collection,
|
| 26 |
vector_index="vector_index"):
|
| 27 |
# Naive RAG: Semantic search
|
|
|
|
| 28 |
retrieval_result = vector_search_naive(
|
| 29 |
openai_api_key,
|
| 30 |
prompt,
|
|
@@ -36,9 +37,7 @@ def rag_retrieval_naive(openai_api_key,
|
|
| 36 |
if not retrieval_result:
|
| 37 |
return "No results found."
|
| 38 |
|
| 39 |
-
|
| 40 |
-
#print(retrieval_result)
|
| 41 |
-
#print("###")
|
| 42 |
|
| 43 |
return retrieval_result
|
| 44 |
|
|
@@ -52,12 +51,14 @@ def rag_retrieval_advanced(openai_api_key,
|
|
| 52 |
# Advanced RAG: Semantic search plus...
|
| 53 |
|
| 54 |
# 1a) Pre-retrieval processing: index filter (accomodates, bedrooms) plus...
|
| 55 |
-
|
| 56 |
# 1b) Post-retrieval processing: result filter (accomodates, bedrooms) plus...
|
| 57 |
-
|
| 58 |
-
# 2) Weighted average review, sorted in descending order
|
| 59 |
|
| 60 |
-
additional_stages = [
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
retrieval_result = vector_search_advanced(
|
| 63 |
openai_api_key,
|
|
@@ -73,9 +74,7 @@ def rag_retrieval_advanced(openai_api_key,
|
|
| 73 |
if not retrieval_result:
|
| 74 |
return "No results found."
|
| 75 |
|
| 76 |
-
|
| 77 |
-
#print(retrieval_result)
|
| 78 |
-
#print("###")
|
| 79 |
|
| 80 |
return retrieval_result
|
| 81 |
|
|
@@ -140,7 +139,7 @@ def vector_search_naive(openai_api_key,
|
|
| 140 |
}
|
| 141 |
}
|
| 142 |
|
| 143 |
-
pipeline = [vector_search_stage,
|
| 144 |
|
| 145 |
return invoke_search(db, collection, pipeline)
|
| 146 |
|
|
@@ -157,7 +156,7 @@ def vector_search_advanced(openai_api_key,
|
|
| 157 |
if query_embedding is None:
|
| 158 |
return "Invalid query or embedding generation failed."
|
| 159 |
|
| 160 |
-
|
| 161 |
"$vectorSearch": {
|
| 162 |
"index": vector_index,
|
| 163 |
"queryVector": query_embedding,
|
|
@@ -173,16 +172,38 @@ def vector_search_advanced(openai_api_key,
|
|
| 173 |
}
|
| 174 |
}
|
| 175 |
|
| 176 |
-
pipeline = [
|
| 177 |
|
| 178 |
return invoke_search(db, collection, pipeline)
|
| 179 |
|
| 180 |
-
def
|
| 181 |
return {
|
| 182 |
"$unset": "description_embedding"
|
| 183 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
|
| 185 |
-
def
|
| 186 |
return {
|
| 187 |
"$match": {
|
| 188 |
"accommodates": { "$eq": 2},
|
|
@@ -190,7 +211,7 @@ def get_result_filter_stage():
|
|
| 190 |
}
|
| 191 |
}
|
| 192 |
|
| 193 |
-
def
|
| 194 |
return {
|
| 195 |
"$addFields": {
|
| 196 |
"averageReview": {
|
|
|
|
| 25 |
collection,
|
| 26 |
vector_index="vector_index"):
|
| 27 |
# Naive RAG: Semantic search
|
| 28 |
+
|
| 29 |
retrieval_result = vector_search_naive(
|
| 30 |
openai_api_key,
|
| 31 |
prompt,
|
|
|
|
| 37 |
if not retrieval_result:
|
| 38 |
return "No results found."
|
| 39 |
|
| 40 |
+
print(retrieval_result)
|
|
|
|
|
|
|
| 41 |
|
| 42 |
return retrieval_result
|
| 43 |
|
|
|
|
| 51 |
# Advanced RAG: Semantic search plus...
|
| 52 |
|
| 53 |
# 1a) Pre-retrieval processing: index filter (accomodates, bedrooms) plus...
|
|
|
|
| 54 |
# 1b) Post-retrieval processing: result filter (accomodates, bedrooms) plus...
|
| 55 |
+
# 2) Weighted average review, sorted in descending order
|
|
|
|
| 56 |
|
| 57 |
+
additional_stages = [
|
| 58 |
+
get_average_review_and_review_count_stage(),
|
| 59 |
+
get_weighting_stage(),
|
| 60 |
+
get_sorting_stage()
|
| 61 |
+
]
|
| 62 |
|
| 63 |
retrieval_result = vector_search_advanced(
|
| 64 |
openai_api_key,
|
|
|
|
| 74 |
if not retrieval_result:
|
| 75 |
return "No results found."
|
| 76 |
|
| 77 |
+
print(retrieval_result)
|
|
|
|
|
|
|
| 78 |
|
| 79 |
return retrieval_result
|
| 80 |
|
|
|
|
| 139 |
}
|
| 140 |
}
|
| 141 |
|
| 142 |
+
pipeline = [vector_search_stage, get_remove_fields_stage()]
|
| 143 |
|
| 144 |
return invoke_search(db, collection, pipeline)
|
| 145 |
|
|
|
|
| 156 |
if query_embedding is None:
|
| 157 |
return "Invalid query or embedding generation failed."
|
| 158 |
|
| 159 |
+
vector_search_and_filter_stage = {
|
| 160 |
"$vectorSearch": {
|
| 161 |
"index": vector_index,
|
| 162 |
"queryVector": query_embedding,
|
|
|
|
| 172 |
}
|
| 173 |
}
|
| 174 |
|
| 175 |
+
pipeline = [vector_search_and_filter_stage, get_remove_fields_stage()] + additional_stages
|
| 176 |
|
| 177 |
return invoke_search(db, collection, pipeline)
|
| 178 |
|
| 179 |
+
def get_remove_fields_stage():
|
| 180 |
return {
|
| 181 |
"$unset": "description_embedding"
|
| 182 |
}
|
| 183 |
+
|
| 184 |
+
def get_project_fields_stage():
|
| 185 |
+
return {
|
| 186 |
+
"$project": {
|
| 187 |
+
"_id": 0,
|
| 188 |
+
"name": 1,
|
| 189 |
+
"accommodates": 1,
|
| 190 |
+
"address.street": 1,
|
| 191 |
+
"address.government_area": 1,
|
| 192 |
+
"address.market": 1,
|
| 193 |
+
"address.country": 1,
|
| 194 |
+
"address.country_code": 1,
|
| 195 |
+
"address.location.type": 1,
|
| 196 |
+
"address.location.coordinates": 1,
|
| 197 |
+
"address.location.is_location_exact": 1,
|
| 198 |
+
"summary": 1,
|
| 199 |
+
"space": 1,
|
| 200 |
+
"neighborhood_overview": 1,
|
| 201 |
+
"notes": 1,
|
| 202 |
+
"score": {"$meta": "vectorSearchScore"}
|
| 203 |
+
}
|
| 204 |
+
}
|
| 205 |
|
| 206 |
+
def get_filter_result_stage():
|
| 207 |
return {
|
| 208 |
"$match": {
|
| 209 |
"accommodates": { "$eq": 2},
|
|
|
|
| 211 |
}
|
| 212 |
}
|
| 213 |
|
| 214 |
+
def get_average_review_and_review_count_stage():
|
| 215 |
return {
|
| 216 |
"$addFields": {
|
| 217 |
"averageReview": {
|