from __future__ import annotations

from statistics import mean
from typing import Any, Mapping

from backend.mcp_server.common.database import search_vectors
from backend.mcp_server.common.embeddings import embed_text
from backend.mcp_server.common.logging import log_rag_search_metrics
from backend.mcp_server.common.reranker import rerank_results
from backend.mcp_server.common.tenant import TenantContext
from backend.mcp_server.common.utils import ToolValidationError, tool_handler
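
# Assumed helper contracts (inferred from how they are used below, not
# verified against the common modules): search_vectors returns a list of
# dicts with at least "text" and "similarity" keys, sorted by similarity;
# rerank_results(query, candidates, top_k=...) returns candidate dicts
# sorted by cross-encoder score, carrying "score" and "relevance".
# tool_handler is presumably how this function gets registered as an MCP
# tool; that wiring is not shown here.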


def _chunk_score(chunk: Mapping[str, Any]) -> float:
    """Best-effort relevance for a chunk: raw vector hits carry "similarity",
    re-ranked chunks carry "score" and "relevance"; anything missing or falsy
    falls through to 0.0."""
    return chunk.get("similarity") or chunk.get("score") or chunk.get("relevance") or 0.0


async def rag_search(context: TenantContext, payload: Mapping[str, Any]) -> dict[str, Any]:
    """Perform semantic search across the tenant's knowledge base."""
    query = payload.get("query")
    if not isinstance(query, str) or not query.strip():
        raise ToolValidationError("query must be a non-empty string")
    limit = payload.get("limit", 10)
    try:
        limit_value = max(1, min(int(limit), 25))  # clamp to the range [1, 25]
    except (TypeError, ValueError) as exc:
        raise ToolValidationError("limit must be an integer between 1 and 25") from exc

    threshold = payload.get("threshold", 0.3)  # lower default threshold for better recall
    try:
        threshold_value = max(0.0, min(float(threshold), 1.0))  # clamp to [0.0, 1.0]
    except (TypeError, ValueError) as exc:
        raise ToolValidationError("threshold must be a float between 0.0 and 1.0") from exc

    embedding = embed_text(query)

    # Step 1: fetch candidates from vector search for re-ranking. We fetch
    # more than requested (at least 10, or 2x the requested limit) so the
    # cross-encoder has a wider pool to choose from.
    rerank_candidates_count = max(10, limit_value * 2)
    raw_results = search_vectors(context.tenant_id, embedding, limit=rerank_candidates_count)

    # Step 2: re-rank all fetched candidates with the cross-encoder for
    # improved accuracy, so that top_k=limit_value can always be satisfied.
    candidates_for_rerank = raw_results[:rerank_candidates_count]

    reranked_results = None
    if candidates_for_rerank:
        # Carry the vector-search similarity through as the initial score.
        candidates = [
            {
                "text": chunk.get("text", ""),
                "relevance": chunk.get("similarity", 0.0),
                "score": chunk.get("similarity", 0.0),
            }
            for chunk in candidates_for_rerank
        ]
        # rerank_results returns the top_k results, already sorted.
        reranked = rerank_results(query, candidates, top_k=limit_value)
        if reranked:
            reranked_results = reranked

    # Step 3: prefer the re-ranked results; fall back to the raw vector hits.
    results_to_filter = reranked_results if reranked_results else raw_results

    # Step 4: filter by threshold and return the top results.
    filtered = []
    for chunk in results_to_filter:
        similarity = _chunk_score(chunk)
        if similarity >= threshold_value:
            filtered.append({
                "text": chunk.get("text", ""),
                "relevance": similarity,
                "score": similarity,  # duplicate of "relevance", kept for compatibility
            })

    # Return the top results above the threshold; if nothing clears it, return
    # the single best chunk anyway, since it may still be relevant.
    if filtered:
        filtered = sorted(filtered, key=lambda x: x.get("relevance", 0.0), reverse=True)[:limit_value]
    elif results_to_filter:
        top_chunk = results_to_filter[0]
        similarity = _chunk_score(top_chunk)
        filtered = [{
            "text": top_chunk.get("text", ""),
            "relevance": similarity,
            "score": similarity,
        }]

    # Compute metrics over whichever result set was used (re-ranked or raw).
    hits = len(results_to_filter)
    scores_for_metrics = [_chunk_score(item) for item in results_to_filter]
    avg_score = mean(scores_for_metrics) if scores_for_metrics else None
    top_score = scores_for_metrics[0] if scores_for_metrics else None
    log_rag_search_metrics(
        tenant_id=context.tenant_id,
        query=query,
        hits_count=hits,
        avg_score=avg_score,
        top_score=top_score,
    )

    return {
        "query": query,
        "results": filtered,
        "metadata": {
            "limit": limit_value,
            "threshold": threshold_value,
            "hits_before_filter": len(raw_results),
            "reranked": reranked_results is not None,
        },
    }
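

if __name__ == "__main__":
    # Usage sketch (illustrative only): how a caller might invoke rag_search
    # directly. The TenantContext constructor arguments are an assumption;
    # build the context however the real server does.
    import asyncio

    async def _demo() -> None:
        context = TenantContext(tenant_id="tenant-123")  # hypothetical args
        response = await rag_search(context, {"query": "refund policy", "limit": 5})
        for item in response["results"]:
            print(f"{item['relevance']:.3f}  {item['text'][:80]}")

    asyncio.run(_demo())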