from __future__ import annotations from statistics import mean from typing import Any, Mapping from backend.mcp_server.common.database import search_vectors from backend.mcp_server.common.embeddings import embed_text from backend.mcp_server.common.logging import log_rag_search_metrics from backend.mcp_server.common.reranker import rerank_results from backend.mcp_server.common.tenant import TenantContext from backend.mcp_server.common.utils import ToolValidationError, tool_handler @tool_handler("rag.search") async def rag_search(context: TenantContext, payload: Mapping[str, Any]) -> dict[str, Any]: """ Perform semantic search across the tenant's knowledge base. """ query = payload.get("query") if not isinstance(query, str) or not query.strip(): raise ToolValidationError("query must be a non-empty string") limit = payload.get("limit", 10) try: limit_value = max(1, min(int(limit), 25)) except (TypeError, ValueError): raise ToolValidationError("limit must be an integer between 1 and 25") threshold = payload.get("threshold", 0.3) # Lower default threshold for better recall try: threshold_value = max(0.0, min(float(threshold), 1.0)) except (TypeError, ValueError): raise ToolValidationError("threshold must be a float between 0.0 and 1.0") embedding = embed_text(query) # Step 1: Get top 10 candidates from vector search for re-ranking # We fetch more candidates than requested to allow cross-encoder to find the best matches rerank_candidates_count = max(10, limit_value * 2) # Get at least 10, or 2x the requested limit raw_results = search_vectors(context.tenant_id, embedding, limit=rerank_candidates_count) # Step 2: Re-rank candidates using cross-encoder for improved accuracy # Re-rank up to top 10 candidates (or all if fewer than 10) candidates_for_rerank = raw_results[:10] # Re-rank top 10 (or all available) reranked_results = None if candidates_for_rerank: # Prepare candidates with text and initial similarity score candidates = [ { "text": chunk.get("text", ""), "relevance": chunk.get("similarity", 0.0), "score": chunk.get("similarity", 0.0), } for chunk in candidates_for_rerank ] # Re-rank using cross-encoder (returns top_k results already sorted) reranked = rerank_results(query, candidates, top_k=limit_value) if reranked: reranked_results = reranked # Step 3: Use re-ranked results if available, otherwise use original vector search results results_to_filter = reranked_results if reranked_results else raw_results # Step 4: Filter by threshold and return top results filtered = [] for chunk in results_to_filter: # Re-ranked results have "score" and "relevance", original have "similarity" similarity = chunk.get("similarity") or chunk.get("score") or chunk.get("relevance") or 0.0 if similarity >= threshold_value: filtered.append({ "text": chunk.get("text", ""), "relevance": similarity, "score": similarity # Add score field for compatibility }) # If we have results above threshold, return top results. Otherwise, return top 1 even if below threshold. if filtered: filtered = sorted(filtered, key=lambda x: x.get("relevance", 0.0), reverse=True)[:limit_value] elif results_to_filter: # Return the top result even if below threshold, as it might still be relevant top_chunk = results_to_filter[0] similarity = top_chunk.get("similarity") or top_chunk.get("score") or top_chunk.get("relevance") or 0.0 filtered = [{ "text": top_chunk.get("text", ""), "relevance": similarity, "score": similarity }] # Calculate metrics from the results we're using (re-ranked or original) hits = len(results_to_filter) scores_for_metrics = [ item.get("similarity") or item.get("score") or item.get("relevance") or 0.0 for item in results_to_filter ] avg_score = mean(scores_for_metrics) if scores_for_metrics else None top_score = scores_for_metrics[0] if scores_for_metrics else None log_rag_search_metrics( tenant_id=context.tenant_id, query=query, hits_count=hits, avg_score=avg_score, top_score=top_score, ) return { "query": query, "results": filtered, "metadata": { "limit": limit_value, "threshold": threshold_value, "hits_before_filter": len(raw_results), "reranked": reranked_results is not None, }, }