nothingworry's picture
improve RAG
9d50a01
raw
history blame
2.99 kB
from __future__ import annotations
from statistics import mean
from typing import Any, Mapping
from backend.mcp_server.common.database import search_vectors
from backend.mcp_server.common.embeddings import embed_text
from backend.mcp_server.common.logging import log_rag_search_metrics
from backend.mcp_server.common.tenant import TenantContext
from backend.mcp_server.common.utils import ToolValidationError, tool_handler
# Upper bound on how many chunks may be fetched from the vector store per search.
_MAX_LIMIT = 25
# Maximum number of above-threshold results returned to the caller.
_MAX_RESULTS = 3


def _similarity(chunk: Mapping[str, Any]) -> float:
    """Return a chunk's similarity score, treating a missing or null score as 0.0."""
    value = chunk.get("similarity")
    # A present-but-None score would otherwise break the threshold comparison and mean().
    return 0.0 if value is None else value


def _to_result(chunk: Mapping[str, Any]) -> dict[str, Any]:
    """Shape a raw search hit into the result entry returned to callers."""
    similarity = _similarity(chunk)
    return {
        "text": chunk.get("text", ""),
        "relevance": similarity,
        "score": similarity,  # Duplicate of "relevance", kept for compatibility.
    }


def _parse_limit(raw: Any) -> int:
    """Coerce the requested limit to an int clamped to [1, _MAX_LIMIT].

    Raises:
        ToolValidationError: if ``raw`` cannot be converted to an int.
    """
    try:
        requested = int(raw)
    except (TypeError, ValueError, OverflowError) as exc:
        # OverflowError covers non-finite floats, e.g. int(float("inf")).
        raise ToolValidationError("limit must be an integer between 1 and 25") from exc
    return max(1, min(requested, _MAX_LIMIT))


def _parse_threshold(raw: Any) -> float:
    """Coerce the requested threshold to a float clamped to [0.0, 1.0].

    Raises:
        ToolValidationError: if ``raw`` cannot be converted to a float.
    """
    try:
        requested = float(raw)
    except (TypeError, ValueError, OverflowError) as exc:
        # OverflowError covers ints too large for a float, e.g. float(10**400).
        raise ToolValidationError("threshold must be a float between 0.0 and 1.0") from exc
    return max(0.0, min(requested, 1.0))


@tool_handler("rag.search")
async def rag_search(context: TenantContext, payload: Mapping[str, Any]) -> dict[str, Any]:
    """
    Perform semantic search across the tenant's knowledge base.

    Payload keys:
        query: non-empty search string (required).
        limit: number of chunks fetched from the vector store, clamped to
            [1, 25]; defaults to 10.
        threshold: minimum similarity for a chunk to count as a match,
            clamped to [0.0, 1.0]; defaults to 0.3.

    Returns:
        A dict with the original ``query``, a ``results`` list and ``metadata``
        echoing the effective limit and threshold plus the number of chunks
        fetched before filtering (``hits_before_filter``). ``results`` holds up
        to three entries with ``text``, ``relevance`` and ``score``, ordered by
        descending similarity. If no chunk meets the threshold but some were
        fetched, it holds only the single highest-scoring chunk.

    Raises:
        ToolValidationError: if ``query`` is missing or blank, or if ``limit``
            or ``threshold`` cannot be converted to a number.
    """
    query = payload.get("query")
    if not isinstance(query, str) or not query.strip():
        raise ToolValidationError("query must be a non-empty string")
    limit_value = _parse_limit(payload.get("limit", 10))
    # Low default threshold favours recall.
    threshold_value = _parse_threshold(payload.get("threshold", 0.3))

    embedding = embed_text(query)
    raw_results = search_vectors(context.tenant_id, embedding, limit=limit_value)

    matches = [_to_result(chunk) for chunk in raw_results if _similarity(chunk) >= threshold_value]
    if matches:
        # Do not rely on the store's ordering; rank explicitly before truncating.
        results = sorted(matches, key=lambda item: item["relevance"], reverse=True)[:_MAX_RESULTS]
    elif raw_results:
        # Nothing met the threshold: still return the single best chunk, as it may be relevant.
        results = [_to_result(max(raw_results, key=_similarity))]
    else:
        results = []

    scores = [_similarity(chunk) for chunk in raw_results]
    log_rag_search_metrics(
        tenant_id=context.tenant_id,
        query=query,
        hits_count=len(raw_results),
        avg_score=mean(scores) if scores else None,
        top_score=max(scores) if scores else None,
    )
    return {
        "query": query,
        "results": results,
        "metadata": {
            "limit": limit_value,
            "threshold": threshold_value,
            "hits_before_filter": len(raw_results),
        },
    }