datascienceharp committed on
Commit
8e55c07
·
1 Parent(s): 42ae5d8

better hybrid search

Browse files
Files changed (1) hide show
  1. app.py +54 -28
app.py CHANGED
@@ -2,7 +2,7 @@ import spaces
2
  import os
3
  from typing import List, Optional, Dict, Any
4
  import gradio as gr
5
- from qdrant_client import QdrantClient
6
  from sentence_transformers import SentenceTransformer
7
  from transformers import AutoModelForCausalLM, AutoTokenizer
8
  from peft import PeftModel
@@ -13,6 +13,8 @@ from dataclasses import dataclass
13
 
14
  warnings.filterwarnings("ignore")
15
 
 
 
16
  # Configuration with secure credential handling
17
  @dataclass
18
  class Config:
@@ -188,44 +190,68 @@ class FiftyOneAssistant:
188
  """Wrapper for GPU embedding function."""
189
  return get_embedding_gpu(text, config.embedding_model)
190
 
191
- def hybrid_search(self, query_text: str, query_vector: List[float], limit: int = 5) -> List[Dict[str, Any]]:
192
- """Perform hybrid search using global qdrant_client."""
193
- if not qdrant_client:
194
- raise ValueError("Qdrant client not initialized")
195
 
196
- # Dense vector search
197
- dense_results = qdrant_client.query_points(
198
- collection_name=config.qdrant_collection_name,
199
  query=query_vector,
200
- limit=limit * 2,
201
- with_payload=True
202
  ).points
203
 
204
- # Keyword scoring
205
- query_words = set(word.lower() for word in query_text.split() if len(word) > 3)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
 
207
- scored_results = []
208
- for result in dense_results:
209
- payload = result.payload
210
- doc_text = f"{payload.get('query', '')} {payload.get('response', '')}".lower()
211
-
212
- keyword_bonus = 0.0
213
- if query_words:
214
- doc_words = set(doc_text.split())
215
- overlap = query_words.intersection(doc_words)
216
- keyword_bonus = len(overlap) / len(query_words) * 0.2
217
-
218
- final_score = result.score + keyword_bonus
219
- scored_results.append({'score': final_score, 'payload': payload})
 
 
 
 
 
 
 
 
220
 
221
- scored_results.sort(key=lambda x: x['score'], reverse=True)
222
- return scored_results[:limit]
 
223
 
224
  def get_context(self, user_query: str, top_k: int = 3) -> str:
225
  """Get relevant context."""
226
  try:
227
  query_vector = self.get_embedding(user_query)
228
- results = self.hybrid_search(user_query, query_vector, top_k)
229
 
230
  if not results:
231
  return "No relevant content found"
 
2
  import os
3
  from typing import List, Optional, Dict, Any
4
  import gradio as gr
5
+ from qdrant_client import QdrantClient, models
6
  from sentence_transformers import SentenceTransformer
7
  from transformers import AutoModelForCausalLM, AutoTokenizer
8
  from peft import PeftModel
 
13
 
14
  warnings.filterwarnings("ignore")
15
 
16
+ QDRANT_COLLECTION_NAME = os.getenv("QDRANT_COLLECTION_NAME")
17
+
18
  # Configuration with secure credential handling
19
  @dataclass
20
  class Config:
 
190
  """Wrapper for GPU embedding function."""
191
  return get_embedding_gpu(text, config.embedding_model)
192
 
193
+ def hybrid_search(query_text: str, query_vector: List[float], limit: int = 3, text_boost: float = 0.3):
194
+ """
195
+ Hybrid search using only dense vectors + text filtering
196
+ """
197
 
198
+ # Pure vector search for semantic similarity
199
+ vector_results = qdrant_client.query_points(
200
+ collection_name=QDRANT_COLLECTION_NAME,
201
  query=query_vector,
202
+ limit=limit
 
203
  ).points
204
 
205
+ # Vector search WITH text filter for exact matches
206
+ text_filtered_results = qdrant_client.query_points(
207
+ collection_name=QDRANT_COLLECTION_NAME,
208
+ query=query_vector,
209
+ query_filter=models.Filter(
210
+ should=[
211
+ models.FieldCondition(
212
+ key="query",
213
+ match=models.MatchText(text=query_text)
214
+ ),
215
+ models.FieldCondition(
216
+ key="response",
217
+ match=models.MatchText(text=query_text)
218
+ )
219
+ ]
220
+ ),
221
+ limit=limit
222
+ ).points
223
 
224
+ # Create lookup for text matches
225
+ text_match_ids = {result.id for result in text_filtered_results}
226
+
227
+ # Process all vector results and boost those with text matches
228
+ final_results = []
229
+ seen_ids = set()
230
+
231
+ for result in vector_results:
232
+ if result.id not in seen_ids:
233
+ # Boost score if this item also has text matches
234
+ score = result.score
235
+ if result.id in text_match_ids:
236
+ score = score + (score * text_boost) # Proportional boost
237
+
238
+ final_results.append({
239
+ 'id': result.id,
240
+ 'score': score,
241
+ 'payload': result.payload,
242
+ 'has_text_match': result.id in text_match_ids
243
+ })
244
+ seen_ids.add(result.id)
245
 
246
+ # Sort by boosted scores
247
+ final_results.sort(key=lambda x: x['score'], reverse=True)
248
+ return final_results[:limit]
249
 
250
  def get_context(self, user_query: str, top_k: int = 3) -> str:
251
  """Get relevant context."""
252
  try:
253
  query_vector = self.get_embedding(user_query)
254
+ results = self.hybrid_search(user_query, query_vector)
255
 
256
  if not results:
257
  return "No relevant content found"