davidtran999 committed on
Commit
c6046cb
·
verified ·
1 Parent(s): 750be9b

Upload backend/hue_portal/core/search_ml.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. backend/hue_portal/core/search_ml.py +67 -15
backend/hue_portal/core/search_ml.py CHANGED
@@ -28,22 +28,40 @@ def expand_query_with_synonyms(query: str) -> List[str]:
28
  expanded = [query_normalized]
29
 
30
  try:
31
- # Get all synonyms
32
- synonyms = Synonym.objects.all()
 
 
 
 
33
  for synonym in synonyms:
 
 
 
34
  keyword = normalize_text(synonym.keyword)
35
  alias = normalize_text(synonym.alias)
36
 
37
  # If query contains keyword, add alias
38
- if keyword in query_normalized:
39
- expanded.append(query_normalized.replace(keyword, alias))
 
 
 
 
 
 
40
  # If query contains alias, add keyword
41
- if alias in query_normalized:
42
- expanded.append(query_normalized.replace(alias, keyword))
 
 
 
 
 
43
  except Exception:
44
  pass # If Synonym table doesn't exist yet
45
 
46
- return list(set(expanded)) # Remove duplicates
47
 
48
 
49
  def create_search_vector(text_fields: List[str]) -> str:
@@ -177,16 +195,41 @@ def search_with_ml(
177
  # Attempt PostgreSQL BM25 ranking first when available
178
  if connection.vendor == "postgresql" and hasattr(queryset.model, "tsv_body"):
179
  try:
180
- expanded_queries = expand_query_with_synonyms(query)
181
- combined_query = None
182
- for q_variant in expanded_queries:
183
- variant_query = SearchQuery(q_variant, config="simple")
184
- combined_query = variant_query if combined_query is None else combined_query | variant_query
 
 
 
 
 
 
 
 
185
 
186
- if combined_query is not None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
  ranked_qs = (
188
  queryset
189
- .annotate(rank=SearchRank(F("tsv_body"), combined_query))
190
  .filter(rank__gt=0)
191
  .order_by("-rank")
192
  )
@@ -195,6 +238,8 @@ def search_with_ml(
195
  for obj in results:
196
  obj._ml_score = getattr(obj, "rank", 0.0)
197
  return results
 
 
198
  except Exception:
199
  # Fall through to ML-based search if any error occurs (e.g. missing extensions)
200
  pass
@@ -213,7 +258,14 @@ def search_with_ml(
213
 
214
  # Calculate similarity scores
215
  try:
216
- scored_indices = calculate_similarity_scores(query, documents, top_k=top_k)
 
 
 
 
 
 
 
217
 
218
  # Filter by minimum score and get object IDs
219
  valid_indices = [idx for idx, score in scored_indices if score >= min_score]
 
28
  expanded = [query_normalized]
29
 
30
  try:
31
+ # Limit to prevent too many expansions
32
+ max_expansions = 10
33
+ expansion_count = 0
34
+
35
+ # Get all synonyms (limit to prevent too many DB queries)
36
+ synonyms = Synonym.objects.all()[:100] # Limit to 100 synonyms
37
  for synonym in synonyms:
38
+ if expansion_count >= max_expansions:
39
+ break
40
+
41
  keyword = normalize_text(synonym.keyword)
42
  alias = normalize_text(synonym.alias)
43
 
44
  # If query contains keyword, add alias
45
+ if keyword and keyword in query_normalized:
46
+ new_query = query_normalized.replace(keyword, alias)
47
+ if new_query not in expanded:
48
+ expanded.append(new_query)
49
+ expansion_count += 1
50
+ if expansion_count >= max_expansions:
51
+ break
52
+
53
  # If query contains alias, add keyword
54
+ if alias and alias in query_normalized:
55
+ new_query = query_normalized.replace(alias, keyword)
56
+ if new_query not in expanded:
57
+ expanded.append(new_query)
58
+ expansion_count += 1
59
+ if expansion_count >= max_expansions:
60
+ break
61
  except Exception:
62
  pass # If Synonym table doesn't exist yet
63
 
64
+ return list(set(expanded))[:10] # Remove duplicates and limit to 10 variants
65
 
66
 
67
  def create_search_vector(text_fields: List[str]) -> str:
 
195
  # Attempt PostgreSQL BM25 ranking first when available
196
  if connection.vendor == "postgresql" and hasattr(queryset.model, "tsv_body"):
197
  try:
198
+ import sys
199
+ # Increase recursion limit for query expansion
200
+ old_limit = sys.getrecursionlimit()
201
+ try:
202
+ sys.setrecursionlimit(3000) # Increase limit for query expansion
203
+ expanded_queries = expand_query_with_synonyms(query)
204
+ # Limit expanded queries to prevent too many variants
205
+ expanded_queries = expanded_queries[:5] # Max 5 variants
206
+
207
+ combined_query = None
208
+ for q_variant in expanded_queries:
209
+ variant_query = SearchQuery(q_variant, config="simple")
210
+ combined_query = variant_query if combined_query is None else combined_query | variant_query
211
 
212
+ if combined_query is not None:
213
+ ranked_qs = (
214
+ queryset
215
+ .annotate(rank=SearchRank(F("tsv_body"), combined_query))
216
+ .filter(rank__gt=0)
217
+ .order_by("-rank")
218
+ )
219
+ results = list(ranked_qs[:top_k])
220
+ if results:
221
+ for obj in results:
222
+ obj._ml_score = getattr(obj, "rank", 0.0)
223
+ return results
224
+ finally:
225
+ sys.setrecursionlimit(old_limit) # Restore original limit
226
+ except RecursionError as e:
227
+ # Fallback: use original query without expansion
228
+ try:
229
+ variant_query = SearchQuery(query, config="simple")
230
  ranked_qs = (
231
  queryset
232
+ .annotate(rank=SearchRank(F("tsv_body"), variant_query))
233
  .filter(rank__gt=0)
234
  .order_by("-rank")
235
  )
 
238
  for obj in results:
239
  obj._ml_score = getattr(obj, "rank", 0.0)
240
  return results
241
+ except Exception:
242
+ pass
243
  except Exception:
244
  # Fall through to ML-based search if any error occurs (e.g. missing extensions)
245
  pass
 
258
 
259
  # Calculate similarity scores
260
  try:
261
+ import sys
262
+ # Increase recursion limit for TF-IDF calculation
263
+ old_limit = sys.getrecursionlimit()
264
+ try:
265
+ sys.setrecursionlimit(3000) # Increase limit for TF-IDF
266
+ scored_indices = calculate_similarity_scores(query, documents, top_k=top_k)
267
+ finally:
268
+ sys.setrecursionlimit(old_limit) # Restore original limit
269
 
270
  # Filter by minimum score and get object IDs
271
  valid_indices = [idx for idx, score in scored_indices if score >= min_score]