nothingworry's picture
feat: update the encoding model
0e8c152
raw
history blame
4.85 kB
from __future__ import annotations

import math
from statistics import mean
from typing import Any, Mapping

from backend.mcp_server.common.database import search_vectors
from backend.mcp_server.common.embeddings import embed_text
from backend.mcp_server.common.logging import log_rag_search_metrics
from backend.mcp_server.common.reranker import rerank_results
from backend.mcp_server.common.tenant import TenantContext
from backend.mcp_server.common.utils import ToolValidationError, tool_handler
def _parse_limit(raw: Any) -> int:
    """Coerce the ``limit`` argument to an int clamped to [1, 25].

    Raises:
        ToolValidationError: if *raw* cannot be converted to an int. This
            includes non-finite floats: ``int(float("inf"))`` raises
            ``OverflowError``, which is mapped to the same validation error
            instead of escaping as an unexpected exception.
    """
    try:
        return max(1, min(int(raw), 25))
    except (TypeError, ValueError, OverflowError) as exc:
        raise ToolValidationError("limit must be an integer between 1 and 25") from exc


def _parse_threshold(raw: Any) -> float:
    """Coerce the ``threshold`` argument to a float clamped to [0.0, 1.0].

    Raises:
        ToolValidationError: if *raw* cannot be converted to a float, is too
            large to be a float (``OverflowError``), or is NaN. NaN is rejected
            explicitly because ``max(0.0, min(nan, 1.0))`` evaluates to 0.0,
            which would silently mean "accept every result".
    """
    try:
        value = float(raw)
    except (TypeError, ValueError, OverflowError) as exc:
        raise ToolValidationError("threshold must be a float between 0.0 and 1.0") from exc
    if math.isnan(value):
        raise ToolValidationError("threshold must be a float between 0.0 and 1.0")
    return max(0.0, min(value, 1.0))


def _chunk_score(chunk: Mapping[str, Any]) -> float:
    """Return the first score present on *chunk*, checking ``similarity``, ``score``, then ``relevance``.

    Vector-search rows carry ``similarity``; the candidates handed to the
    re-ranker carry ``score``/``relevance``. Each key is tested with
    ``is not None`` rather than truthiness, so a genuine score of 0 is used
    instead of being skipped in favour of another field. Returns 0.0 when
    none of the keys is set.
    """
    for key in ("similarity", "score", "relevance"):
        value = chunk.get(key)
        if value is not None:
            return value
    return 0.0


def _to_hit(chunk: Mapping[str, Any], score: float) -> dict[str, Any]:
    """Build one result entry; ``score`` mirrors ``relevance`` for client compatibility."""
    return {"text": chunk.get("text", ""), "relevance": score, "score": score}


@tool_handler("rag.search")
async def rag_search(context: TenantContext, payload: Mapping[str, Any]) -> dict[str, Any]:
    """
    Perform semantic search across the tenant's knowledge base.

    Pipeline: embed the query, fetch vector-search candidates for the tenant,
    re-rank them with ``rerank_results`` (falling back to the vector-search
    results when the re-ranker returns nothing), filter by ``threshold``, and
    log search metrics.

    Args:
        context: Tenant whose knowledge base is searched (``context.tenant_id``).
        payload: Tool arguments:
            - ``query``: required non-blank string.
            - ``limit``: int-like, default 10, clamped to [1, 25].
            - ``threshold``: float-like, default 0.3, clamped to [0.0, 1.0].

    Returns:
        ``{"query", "results", "metadata"}``. Each entry of ``results`` is
        ``{"text", "relevance", "score"}``, sorted by descending relevance and
        truncated to ``limit``. If no candidate meets the threshold but some
        candidates exist, the first candidate is returned on its own anyway.

    Raises:
        ToolValidationError: if ``query``, ``limit`` or ``threshold`` is invalid.
    """
    query = payload.get("query")
    if not isinstance(query, str) or not query.strip():
        raise ToolValidationError("query must be a non-empty string")

    limit_value = _parse_limit(payload.get("limit", 10))
    # Low default threshold deliberately favours recall.
    threshold_value = _parse_threshold(payload.get("threshold", 0.3))

    # NOTE(review): embed_text and search_vectors are called without `await`;
    # if they block on network/model work they will stall the event loop —
    # consider asyncio.to_thread. TODO confirm their implementation.
    embedding = embed_text(query)

    # Step 1: fetch more candidates than requested (at least 10, or 2x the
    # limit) so the cross-encoder has room to promote better matches.
    rerank_candidates_count = max(10, limit_value * 2)
    raw_results = search_vectors(context.tenant_id, embedding, limit=rerank_candidates_count) or []

    # Step 2: re-rank every fetched candidate (not just the first 10), so a
    # request with limit > 10 can still receive `limit` re-ranked results.
    reranked_results = None
    if raw_results:
        candidates = [
            {
                "text": chunk.get("text", ""),
                "relevance": chunk.get("similarity", 0.0),
                "score": chunk.get("similarity", 0.0),
            }
            for chunk in raw_results[:rerank_candidates_count]
        ]
        # Returns at most top_k results, already sorted.
        reranked = rerank_results(query, candidates, top_k=limit_value)
        if reranked:
            reranked_results = reranked

    # Step 3: prefer re-ranked results; fall back to raw vector-search order.
    results_to_filter = reranked_results if reranked_results else raw_results
    scored = [(chunk, _chunk_score(chunk)) for chunk in results_to_filter]

    # Step 4: apply the threshold.
    # NOTE(review): after re-ranking, the threshold is compared against
    # whatever score the re-ranker writes; confirm it shares the [0, 1] scale
    # of vector similarity.
    filtered = [_to_hit(chunk, score) for chunk, score in scored if score >= threshold_value]
    if filtered:
        filtered.sort(key=lambda hit: hit["relevance"], reverse=True)
        filtered = filtered[:limit_value]
    elif scored:
        # Nothing met the threshold: still return the first candidate, as it
        # might be relevant.
        top_chunk, top_score_value = scored[0]
        filtered = [_to_hit(top_chunk, top_score_value)]

    # Metrics describe the candidate set actually used (re-ranked or raw).
    scores_for_metrics = [score for _, score in scored]
    avg_score = mean(scores_for_metrics) if scores_for_metrics else None
    top_score = scores_for_metrics[0] if scores_for_metrics else None
    log_rag_search_metrics(
        tenant_id=context.tenant_id,
        query=query,
        hits_count=len(results_to_filter),
        avg_score=avg_score,
        top_score=top_score,
    )

    return {
        "query": query,
        "results": filtered,
        "metadata": {
            "limit": limit_value,
            "threshold": threshold_value,
            "hits_before_filter": len(raw_results),
            "reranked": reranked_results is not None,
        },
    }