nothingworry committed on
Commit
9d50a01
·
1 Parent(s): 4c04529

Improve RAG

Browse files
backend/api/mcp_clients/mcp_client.py CHANGED
@@ -10,9 +10,31 @@ class MCPClient:
10
  client: httpx.AsyncClient = field(default_factory=lambda: httpx.AsyncClient(timeout=30))
11
 
12
 
13
- async def call_rag(self, tenant_id: str, query: str):
14
- r = await self.client.post(f"{self.rag_url}/search", json={"tenant_id":tenant_id,"query":query})
15
- return r.json()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
 
18
  async def call_web(self, tenant_id: str, query: str):
 
10
  client: httpx.AsyncClient = field(default_factory=lambda: httpx.AsyncClient(timeout=30))
11
 
12
 
13
+ async def call_rag(self, tenant_id: str, query: str, threshold: float = 0.3):
14
+ """
15
+ Calls the RAG search endpoint and returns the unwrapped results.
16
+ The MCP server wraps responses in a 'data' field, so we extract it.
17
+
18
+ Uses a lower threshold (0.3) by default to ensure we find relevant results
19
+ even if semantic similarity is moderate.
20
+ """
21
+ r = await self.client.post(
22
+ f"{self.rag_url}/search",
23
+ json={
24
+ "tenant_id": tenant_id,
25
+ "query": query,
26
+ "threshold": threshold # Lower threshold for better recall
27
+ }
28
+ )
29
+ if r.status_code != 200:
30
+ return {"results": [], "error": f"HTTP {r.status_code}"}
31
+ data = r.json()
32
+ # MCP server wraps response in a 'data' field
33
+ # Extract the actual result data
34
+ if isinstance(data, dict) and "data" in data:
35
+ return data["data"]
36
+ # If not wrapped, return as-is (backward compatibility)
37
+ return data
38
 
39
 
40
  async def call_web(self, tenant_id: str, query: str):
backend/api/mcp_clients/rag_client.py CHANGED
@@ -19,6 +19,7 @@ class RAGClient:
19
  async def search(self, query: str, tenant_id: str):
20
  """
21
  Sends the query to the RAG server and returns document chunks.
 
22
  """
23
 
24
  try:
@@ -35,7 +36,16 @@ class RAGClient:
35
  return []
36
 
37
  data = response.json()
38
- return data.get("results", [])
 
 
 
 
 
 
 
 
 
39
 
40
  except Exception as e:
41
  print("RAG Client Error:", e)
 
19
  async def search(self, query: str, tenant_id: str):
20
  """
21
  Sends the query to the RAG server and returns document chunks.
22
+ Unwraps MCP server responses automatically.
23
  """
24
 
25
  try:
 
36
  return []
37
 
38
  data = response.json()
39
+
40
+ if isinstance(data, dict) and data.get("status") == "error":
41
+ print("RAG Client Error:", data.get("message"))
42
+ return []
43
+
44
+ if isinstance(data, dict) and "data" in data:
45
+ payload = data["data"]
46
+ return payload.get("results", []) if isinstance(payload, dict) else payload
47
+
48
+ return data.get("results", []) if isinstance(data, dict) else data
49
 
50
  except Exception as e:
51
  print("RAG Client Error:", e)
backend/api/services/tool_selector.py CHANGED
@@ -37,7 +37,8 @@ class ToolSelector:
37
  # RAG patterns: internal knowledge, company-specific, documentation
38
  rag_patterns = [
39
  r"company", r"internal", r"documentation", r"our ", r"your ",
40
- r"knowledge base", r"private", r"internal docs", r"corporate"
 
41
  ]
42
  if rag_has_data or rag_score >= 0.55 or any(re.search(p, msg) for p in rag_patterns):
43
  needs_rag = True
 
37
  # RAG patterns: internal knowledge, company-specific, documentation
38
  rag_patterns = [
39
  r"company", r"internal", r"documentation", r"our ", r"your ",
40
+ r"knowledge base", r"private", r"internal docs", r"corporate",
41
+ r"admin", r"administrator", r"who is", r"what is" # Add admin and fact lookup patterns
42
  ]
43
  if rag_has_data or rag_score >= 0.55 or any(re.search(p, msg) for p in rag_patterns):
44
  needs_rag = True
backend/mcp_server/common/database.py CHANGED
@@ -155,11 +155,11 @@ def search_vectors(tenant_id: str, vector: list, limit: int = 5) -> List[Dict[st
155
  print("DB SEARCH ERROR: tenant_id is empty")
156
  return []
157
 
158
- tenant_id = tenant_id.strip()
159
  conn = get_connection()
160
  cur = conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
161
 
162
- # Query with explicit tenant_id filtering
163
  cur.execute(
164
  """
165
  SELECT
@@ -167,11 +167,11 @@ def search_vectors(tenant_id: str, vector: list, limit: int = 5) -> List[Dict[st
167
  tenant_id,
168
  1 - (embedding <=> %s::vector(384)) AS similarity
169
  FROM documents
170
- WHERE tenant_id = %s
171
  ORDER BY embedding <=> %s::vector(384)
172
  LIMIT %s;
173
  """,
174
- (vector, tenant_id, vector, limit),
175
  )
176
 
177
  rows = cur.fetchall()
@@ -180,9 +180,9 @@ def search_vectors(tenant_id: str, vector: list, limit: int = 5) -> List[Dict[st
180
  results: List[Dict[str, Any]] = []
181
  for row in rows:
182
  row_tenant_id = row.get("tenant_id", "")
183
- if row_tenant_id != tenant_id:
184
  print(
185
- f"WARNING: Found document with tenant_id '{row_tenant_id}' when searching for '{tenant_id}' - skipping"
186
  )
187
  continue
188
 
@@ -211,58 +211,35 @@ def list_all_documents(
211
  ) -> Dict[str, Any]:
212
  """
213
  List all documents for a tenant with pagination.
214
- Handles tenant_id normalization to match documents stored with different formatting.
215
  """
216
  try:
217
- # Normalize tenant_id to ensure consistency
218
  tenant_id_normalized = tenant_id.strip()
219
-
220
  conn = get_connection()
221
  cur = conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
222
 
223
- # Get all unique tenant_ids that match when normalized
224
- cur.execute("SELECT DISTINCT tenant_id FROM documents;")
225
- all_tenant_ids = [row[0] for row in cur.fetchall()]
226
-
227
- # Find tenant_ids that match when normalized
228
- matching_tenant_ids = []
229
- for stored_tenant_id in all_tenant_ids:
230
- if stored_tenant_id and stored_tenant_id.strip() == tenant_id_normalized:
231
- matching_tenant_ids.append(stored_tenant_id)
232
-
233
- if not matching_tenant_ids:
234
- # No matching tenant_ids found
235
- cur.close()
236
- conn.close()
237
- return {"documents": [], "total": 0, "limit": limit, "offset": offset}
238
-
239
- # Build query to match any of the normalized tenant_ids
240
- placeholders = ','.join(['%s'] * len(matching_tenant_ids))
241
  cur.execute(
242
- f"""
243
  SELECT
244
  id,
245
  chunk_text,
246
  created_at
247
  FROM documents
248
- WHERE tenant_id IN ({placeholders})
249
  ORDER BY created_at DESC
250
  LIMIT %s OFFSET %s;
251
  """,
252
- tuple(matching_tenant_ids) + (limit, offset),
253
  )
254
-
255
  rows = cur.fetchall()
256
 
257
- # Get total count for all matching tenant_ids
258
- placeholders = ','.join(['%s'] * len(matching_tenant_ids))
259
  cur.execute(
260
- f"""
261
  SELECT COUNT(*) as total
262
  FROM documents
263
- WHERE tenant_id IN ({placeholders});
264
  """,
265
- tuple(matching_tenant_ids),
266
  )
267
  total_row = cur.fetchone()
268
  total = total_row["total"] if total_row else 0
@@ -299,56 +276,24 @@ def delete_document(tenant_id: str, document_id: int) -> bool:
299
  Returns True if document was deleted, False otherwise.
300
  """
301
  try:
302
- # Normalize tenant_id to ensure consistency
303
- tenant_id = tenant_id.strip()
304
-
305
  conn = get_connection()
306
  cur = conn.cursor()
307
 
308
- # First, verify the document exists
309
  cur.execute(
310
  """
311
- SELECT id, tenant_id FROM documents
312
- WHERE id = %s;
313
  """,
314
- (document_id,),
315
  )
316
- doc_row = cur.fetchone()
317
-
318
- if doc_row is None:
319
- print(f"DB DELETE: Document {document_id} does not exist")
320
- cur.close()
321
- conn.close()
322
- return False
323
-
324
- doc_tenant_id = doc_row[1] if len(doc_row) > 1 else None
325
- # Normalize both tenant_ids for comparison (handle existing data with whitespace)
326
- doc_tenant_id_normalized = doc_tenant_id.strip() if doc_tenant_id else None
327
- tenant_id_normalized = tenant_id.strip()
328
-
329
- # Try to delete with normalized comparison - if normalized match, use stored value for actual delete
330
- if doc_tenant_id_normalized == tenant_id_normalized:
331
- # Tenant IDs match after normalization - proceed with delete using stored tenant_id
332
- cur.execute(
333
- """
334
- DELETE FROM documents
335
- WHERE id = %s AND tenant_id = %s;
336
- """,
337
- (document_id, doc_tenant_id),
338
- )
339
- deleted = cur.rowcount > 0
340
- else:
341
- # Tenant IDs don't match - log the mismatch
342
- print(f"DB DELETE: Document {document_id} belongs to tenant '{doc_tenant_id}' (normalized: '{doc_tenant_id_normalized}'), not '{tenant_id}' (normalized: '{tenant_id_normalized}')")
343
- print(f"DB DELETE: Tenant ID lengths - stored: {len(doc_tenant_id) if doc_tenant_id else 0}, requested: {len(tenant_id)}")
344
- print(f"DB DELETE: Tenant ID repr - stored: {repr(doc_tenant_id)}, requested: {repr(tenant_id)}")
345
- deleted = False
346
-
347
  if deleted:
348
- print(f"DB DELETE: Successfully deleted document {document_id} for tenant '{tenant_id}'")
349
  else:
350
- print(f"DB DELETE: Failed to delete document {document_id} for tenant '{tenant_id}' (rowcount: {cur.rowcount})")
351
-
352
  conn.commit()
353
  cur.close()
354
  conn.close()
@@ -369,47 +314,21 @@ def delete_all_documents(tenant_id: str) -> int:
369
  Handles tenant_id normalization to match documents stored with different formatting.
370
  """
371
  try:
372
- # Normalize tenant_id
373
- tenant_id = tenant_id.strip()
374
-
375
  conn = get_connection()
376
  cur = conn.cursor()
377
 
378
- # First, get all unique tenant_ids that match when normalized
379
  cur.execute(
380
  """
381
- SELECT DISTINCT tenant_id FROM documents;
382
- """
 
 
383
  )
384
- all_tenant_ids = [row[0] for row in cur.fetchall()]
385
-
386
- # Find tenant_ids that match when normalized
387
- matching_tenant_ids = []
388
- tenant_id_normalized = tenant_id.strip()
389
- for stored_tenant_id in all_tenant_ids:
390
- if stored_tenant_id and stored_tenant_id.strip() == tenant_id_normalized:
391
- matching_tenant_ids.append(stored_tenant_id)
392
-
393
- if not matching_tenant_ids:
394
- print(f"DB DELETE ALL: No documents found for tenant '{tenant_id}' (normalized: '{tenant_id_normalized}')")
395
- cur.close()
396
- conn.close()
397
- return 0
398
-
399
- # Delete documents matching any of the normalized tenant_ids
400
- deleted_count = 0
401
- for matching_tenant_id in matching_tenant_ids:
402
- cur.execute(
403
- """
404
- DELETE FROM documents
405
- WHERE tenant_id = %s;
406
- """,
407
- (matching_tenant_id,),
408
- )
409
- deleted_count += cur.rowcount
410
-
411
- print(f"DB DELETE ALL: Deleted {deleted_count} document(s) for tenant '{tenant_id}' (matched {len(matching_tenant_ids)} tenant_id variant(s))")
412
-
413
  conn.commit()
414
  cur.close()
415
  conn.close()
 
155
  print("DB SEARCH ERROR: tenant_id is empty")
156
  return []
157
 
158
+ tenant_id_normalized = tenant_id.strip()
159
  conn = get_connection()
160
  cur = conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
161
 
162
+ # Query with normalized tenant_id filtering
163
  cur.execute(
164
  """
165
  SELECT
 
167
  tenant_id,
168
  1 - (embedding <=> %s::vector(384)) AS similarity
169
  FROM documents
170
+ WHERE TRIM(tenant_id) = %s
171
  ORDER BY embedding <=> %s::vector(384)
172
  LIMIT %s;
173
  """,
174
+ (vector, tenant_id_normalized, vector, limit),
175
  )
176
 
177
  rows = cur.fetchall()
 
180
  results: List[Dict[str, Any]] = []
181
  for row in rows:
182
  row_tenant_id = row.get("tenant_id", "")
183
+ if row_tenant_id and row_tenant_id.strip() != tenant_id_normalized:
184
  print(
185
+ f"WARNING: Found document with tenant_id '{row_tenant_id}' when searching for '{tenant_id_normalized}' - skipping"
186
  )
187
  continue
188
 
 
211
  ) -> Dict[str, Any]:
212
  """
213
  List all documents for a tenant with pagination.
214
+ tenant_id comparison is normalized via TRIM() to handle historical data.
215
  """
216
  try:
 
217
  tenant_id_normalized = tenant_id.strip()
 
218
  conn = get_connection()
219
  cur = conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
220
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
  cur.execute(
222
+ """
223
  SELECT
224
  id,
225
  chunk_text,
226
  created_at
227
  FROM documents
228
+ WHERE TRIM(tenant_id) = %s
229
  ORDER BY created_at DESC
230
  LIMIT %s OFFSET %s;
231
  """,
232
+ (tenant_id_normalized, limit, offset),
233
  )
 
234
  rows = cur.fetchall()
235
 
 
 
236
  cur.execute(
237
+ """
238
  SELECT COUNT(*) as total
239
  FROM documents
240
+ WHERE TRIM(tenant_id) = %s;
241
  """,
242
+ (tenant_id_normalized,),
243
  )
244
  total_row = cur.fetchone()
245
  total = total_row["total"] if total_row else 0
 
276
  Returns True if document was deleted, False otherwise.
277
  """
278
  try:
279
+ tenant_id_normalized = tenant_id.strip()
 
 
280
  conn = get_connection()
281
  cur = conn.cursor()
282
 
 
283
  cur.execute(
284
  """
285
+ DELETE FROM documents
286
+ WHERE id = %s AND TRIM(tenant_id) = %s;
287
  """,
288
+ (document_id, tenant_id_normalized),
289
  )
290
+
291
+ deleted = cur.rowcount > 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
292
  if deleted:
293
+ print(f"DB DELETE: Deleted document {document_id} for tenant '{tenant_id_normalized}'")
294
  else:
295
+ print(f"DB DELETE: Document {document_id} not found for tenant '{tenant_id_normalized}'")
296
+
297
  conn.commit()
298
  cur.close()
299
  conn.close()
 
314
  Handles tenant_id normalization to match documents stored with different formatting.
315
  """
316
  try:
317
+ tenant_id_normalized = tenant_id.strip()
 
 
318
  conn = get_connection()
319
  cur = conn.cursor()
320
 
 
321
  cur.execute(
322
  """
323
+ DELETE FROM documents
324
+ WHERE TRIM(tenant_id) = %s;
325
+ """,
326
+ (tenant_id_normalized,),
327
  )
328
+
329
+ deleted_count = cur.rowcount
330
+ print(f"DB DELETE ALL: Deleted {deleted_count} document(s) for tenant '{tenant_id_normalized}'")
331
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
332
  conn.commit()
333
  cur.close()
334
  conn.close()
backend/mcp_server/rag/search.py CHANGED
@@ -1,7 +1,7 @@
1
  from __future__ import annotations
2
 
3
  from statistics import mean
4
- from typing import Mapping
5
 
6
  from backend.mcp_server.common.database import search_vectors
7
  from backend.mcp_server.common.embeddings import embed_text
@@ -26,7 +26,7 @@ async def rag_search(context: TenantContext, payload: Mapping[str, Any]) -> dict
26
  except (TypeError, ValueError):
27
  raise ToolValidationError("limit must be an integer between 1 and 25")
28
 
29
- threshold = payload.get("threshold", 0.55)
30
  try:
31
  threshold_value = max(0.0, min(float(threshold), 1.0))
32
  except (TypeError, ValueError):
@@ -34,11 +34,27 @@ async def rag_search(context: TenantContext, payload: Mapping[str, Any]) -> dict
34
 
35
  embedding = embed_text(query)
36
  raw_results = search_vectors(context.tenant_id, embedding, limit=limit_value)
37
- filtered = [
38
- {"text": chunk.get("text", ""), "relevance": chunk.get("similarity", 0.0)}
39
- for chunk in raw_results
40
- if chunk.get("similarity", 0.0) >= threshold_value
41
- ][:3]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  hits = len(raw_results)
44
  avg_score = mean([item.get("similarity", 0.0) for item in raw_results]) if raw_results else None
 
1
  from __future__ import annotations
2
 
3
  from statistics import mean
4
+ from typing import Any, Mapping
5
 
6
  from backend.mcp_server.common.database import search_vectors
7
  from backend.mcp_server.common.embeddings import embed_text
 
26
  except (TypeError, ValueError):
27
  raise ToolValidationError("limit must be an integer between 1 and 25")
28
 
29
+ threshold = payload.get("threshold", 0.3) # Lower default threshold for better recall
30
  try:
31
  threshold_value = max(0.0, min(float(threshold), 1.0))
32
  except (TypeError, ValueError):
 
34
 
35
  embedding = embed_text(query)
36
  raw_results = search_vectors(context.tenant_id, embedding, limit=limit_value)
37
+ # Return top results even if slightly below threshold, but prioritize high-scoring ones
38
+ filtered = []
39
+ for chunk in raw_results:
40
+ similarity = chunk.get("similarity", 0.0)
41
+ if similarity >= threshold_value:
42
+ filtered.append({
43
+ "text": chunk.get("text", ""),
44
+ "relevance": similarity,
45
+ "score": similarity # Add score field for compatibility
46
+ })
47
+ # If we have results above threshold, return top 3. Otherwise, return top 1 even if below threshold.
48
+ if filtered:
49
+ filtered = sorted(filtered, key=lambda x: x.get("relevance", 0.0), reverse=True)[:3]
50
+ elif raw_results:
51
+ # Return the top result even if below threshold, as it might still be relevant
52
+ top_chunk = raw_results[0]
53
+ filtered = [{
54
+ "text": top_chunk.get("text", ""),
55
+ "relevance": top_chunk.get("similarity", 0.0),
56
+ "score": top_chunk.get("similarity", 0.0)
57
+ }]
58
 
59
  hits = len(raw_results)
60
  avg_score = mean([item.get("similarity", 0.0) for item in raw_results]) if raw_results else None