bstraehle committed on
Commit
bb2e1f2
·
verified ·
1 Parent(s): c90c1d1

Update custom_utils.py

Browse files
Files changed (1) hide show
  1. custom_utils.py +37 -16
custom_utils.py CHANGED
@@ -25,6 +25,7 @@ def rag_retrieval_naive(openai_api_key,
25
  collection,
26
  vector_index="vector_index"):
27
  # Naive RAG: Semantic search
 
28
  retrieval_result = vector_search_naive(
29
  openai_api_key,
30
  prompt,
@@ -36,9 +37,7 @@ def rag_retrieval_naive(openai_api_key,
36
  if not retrieval_result:
37
  return "No results found."
38
 
39
- #print("###")
40
- #print(retrieval_result)
41
- #print("###")
42
 
43
  return retrieval_result
44
 
@@ -52,12 +51,14 @@ def rag_retrieval_advanced(openai_api_key,
52
  # Advanced RAG: Semantic search plus...
53
 
54
  # 1a) Pre-retrieval processing: index filter (accommodates, bedrooms) plus...
55
-
56
  # 1b) Post-retrieval processing: result filter (accommodates, bedrooms) plus...
57
-
58
- # 2) Weighted average review, sorted in descending order
59
 
60
- additional_stages = [get_average_review_stage(), get_weighting_stage(), get_sorting_stage()]
 
 
 
 
61
 
62
  retrieval_result = vector_search_advanced(
63
  openai_api_key,
@@ -73,9 +74,7 @@ def rag_retrieval_advanced(openai_api_key,
73
  if not retrieval_result:
74
  return "No results found."
75
 
76
- #print("###")
77
- #print(retrieval_result)
78
- #print("###")
79
 
80
  return retrieval_result
81
 
@@ -140,7 +139,7 @@ def vector_search_naive(openai_api_key,
140
  }
141
  }
142
 
143
- pipeline = [vector_search_stage, get_remove_embedding_stage()]
144
 
145
  return invoke_search(db, collection, pipeline)
146
 
@@ -157,7 +156,7 @@ def vector_search_advanced(openai_api_key,
157
  if query_embedding is None:
158
  return "Invalid query or embedding generation failed."
159
 
160
- vector_search_filter_stage = {
161
  "$vectorSearch": {
162
  "index": vector_index,
163
  "queryVector": query_embedding,
@@ -173,16 +172,38 @@ def vector_search_advanced(openai_api_key,
173
  }
174
  }
175
 
176
- pipeline = [vector_search_filter_stage, get_remove_embedding_stage()] + additional_stages
177
 
178
  return invoke_search(db, collection, pipeline)
179
 
180
- def get_remove_embedding_stage():
181
  return {
182
  "$unset": "description_embedding"
183
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
 
185
- def get_result_filter_stage():
186
  return {
187
  "$match": {
188
  "accommodates": { "$eq": 2},
@@ -190,7 +211,7 @@ def get_result_filter_stage():
190
  }
191
  }
192
 
193
- def get_average_review_stage():
194
  return {
195
  "$addFields": {
196
  "averageReview": {
 
25
  collection,
26
  vector_index="vector_index"):
27
  # Naive RAG: Semantic search
28
+
29
  retrieval_result = vector_search_naive(
30
  openai_api_key,
31
  prompt,
 
37
  if not retrieval_result:
38
  return "No results found."
39
 
40
+ print(retrieval_result)
 
 
41
 
42
  return retrieval_result
43
 
 
51
  # Advanced RAG: Semantic search plus...
52
 
53
  # 1a) Pre-retrieval processing: index filter (accommodates, bedrooms) plus...
 
54
  # 1b) Post-retrieval processing: result filter (accommodates, bedrooms) plus...
55
+ # 2) Weighted average review, sorted in descending order
 
56
 
57
+ additional_stages = [
58
+ get_average_review_and_review_count_stage(),
59
+ get_weighting_stage(),
60
+ get_sorting_stage()
61
+ ]
62
 
63
  retrieval_result = vector_search_advanced(
64
  openai_api_key,
 
74
  if not retrieval_result:
75
  return "No results found."
76
 
77
+ print(retrieval_result)
 
 
78
 
79
  return retrieval_result
80
 
 
139
  }
140
  }
141
 
142
+ pipeline = [vector_search_stage, get_remove_fields_stage()]
143
 
144
  return invoke_search(db, collection, pipeline)
145
 
 
156
  if query_embedding is None:
157
  return "Invalid query or embedding generation failed."
158
 
159
+ vector_search_and_filter_stage = {
160
  "$vectorSearch": {
161
  "index": vector_index,
162
  "queryVector": query_embedding,
 
172
  }
173
  }
174
 
175
+ pipeline = [vector_search_and_filter_stage, get_remove_fields_stage()] + additional_stages
176
 
177
  return invoke_search(db, collection, pipeline)
178
 
179
+ def get_remove_fields_stage():
180
  return {
181
  "$unset": "description_embedding"
182
  }
183
+
184
+ def get_project_fields_stage():
185
+ return {
186
+ "$project": {
187
+ "_id": 0,
188
+ "name": 1,
189
+ "accommodates": 1,
190
+ "address.street": 1,
191
+ "address.government_area": 1,
192
+ "address.market": 1,
193
+ "address.country": 1,
194
+ "address.country_code": 1,
195
+ "address.location.type": 1,
196
+ "address.location.coordinates": 1,
197
+ "address.location.is_location_exact": 1,
198
+ "summary": 1,
199
+ "space": 1,
200
+ "neighborhood_overview": 1,
201
+ "notes": 1,
202
+ "score": {"$meta": "vectorSearchScore"}
203
+ }
204
+ }
205
 
206
+ def get_filter_result_stage():
207
  return {
208
  "$match": {
209
  "accommodates": { "$eq": 2},
 
211
  }
212
  }
213
 
214
+ def get_average_review_and_review_count_stage():
215
  return {
216
  "$addFields": {
217
  "averageReview": {