from __future__ import annotations

from statistics import mean
from typing import Any, Mapping

from backend.mcp_server.common.database import search_vectors
from backend.mcp_server.common.embeddings import embed_text
from backend.mcp_server.common.logging import log_rag_search_metrics
from backend.mcp_server.common.tenant import TenantContext
from backend.mcp_server.common.utils import ToolValidationError, tool_handler


@tool_handler("rag.search")
async def rag_search(context: TenantContext, payload: Mapping[str, Any]) -> dict[str, Any]:
    """Perform semantic search across the tenant's knowledge base.

    Args:
        context: Tenant context; only ``context.tenant_id`` is read here.
        payload: Tool arguments:
            - ``query`` (required): non-empty string to search for.
            - ``limit`` (optional, default 10): number of candidates requested
              from the vector store; coerced with ``int()`` and clamped to
              the range 1..25.
            - ``threshold`` (optional, default 0.55): minimum similarity a
              candidate must reach to be returned; coerced with ``float()``
              and clamped to the range 0.0..1.0.

    Returns:
        A dict with the original ``query``, the ``results`` whose similarity
        is at least the threshold (each ``{"text": ..., "relevance": ...}``),
        and ``metadata`` echoing the effective ``limit`` and ``threshold``
        plus ``hits_before_filter`` (number of candidates before filtering).

    Raises:
        ToolValidationError: if ``query`` is missing or blank, or if
            ``limit`` / ``threshold`` cannot be converted to a finite number.
    """
    query = payload.get("query")
    if not isinstance(query, str) or not query.strip():
        raise ToolValidationError("query must be a non-empty string")

    limit = payload.get("limit", 10)
    try:
        # Out-of-range values are clamped rather than rejected. OverflowError
        # covers int(float("inf")), e.g. from a JSON literal like 1e309.
        limit_value = max(1, min(int(limit), 25))
    except (TypeError, ValueError, OverflowError) as exc:
        raise ToolValidationError("limit must be an integer between 1 and 25") from exc

    threshold = payload.get("threshold", 0.55)
    try:
        # OverflowError covers float() of an int too large to represent.
        threshold_value = max(0.0, min(float(threshold), 1.0))
    except (TypeError, ValueError, OverflowError) as exc:
        raise ToolValidationError("threshold must be a float between 0.0 and 1.0") from exc

    # NOTE(review): embed_text / search_vectors are called without ``await``,
    # so they run synchronously on the event-loop thread. If either performs
    # blocking I/O, consider offloading (e.g. asyncio.to_thread) -- confirm.
    embedding = embed_text(query)
    raw_results = search_vectors(context.tenant_id, embedding, limit=limit_value)

    # A missing similarity counts as 0.0 everywhere below, so the filter and
    # the logged metrics agree on each candidate's score.
    scores = [chunk.get("similarity", 0.0) for chunk in raw_results]

    # The vector store already caps candidates at ``limit_value``; here we
    # only drop those below the threshold, so ``limit`` controls result size.
    filtered = [
        {"text": chunk.get("text", ""), "relevance": score}
        for chunk, score in zip(raw_results, scores)
        if score >= threshold_value
    ]

    hits = len(raw_results)
    avg_score = mean(scores) if scores else None
    # max() rather than the first element, so the metric does not depend on
    # the store returning results sorted by similarity.
    top_score = max(scores) if scores else None

    log_rag_search_metrics(
        tenant_id=context.tenant_id,
        query=query,
        hits_count=hits,
        avg_score=avg_score,
        top_score=top_score,
    )

    return {
        "query": query,
        "results": filtered,
        "metadata": {
            "limit": limit_value,
            "threshold": threshold_value,
            "hits_before_filter": hits,
        },
    }