from __future__ import annotations

from statistics import mean
from typing import Any, Mapping

from backend.mcp_server.common.database import search_vectors
from backend.mcp_server.common.embeddings import embed_text
from backend.mcp_server.common.logging import log_rag_search_metrics
from backend.mcp_server.common.tenant import TenantContext
from backend.mcp_server.common.utils import ToolValidationError, tool_handler


@tool_handler("rag.search")
async def rag_search(context: TenantContext, payload: Mapping[str, Any]) -> dict[str, Any]:
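    """
    Perform semantic search across the tenant's knowledge base.

    Payload keys:
        query: Non-empty search string (required).
        limit: Maximum number of vectors to fetch; clamped to 1..25 (default 10).
        threshold: Minimum similarity for a chunk to be returned; clamped to
            0.0..1.0 (default 0.55).

    Returns a dict with the original query, the filtered results (each with
    "text" and "relevance"), and metadata describing the effective limit,
    threshold, and the number of hits before filtering.
    """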

    query = payload.get("query")
    if not isinstance(query, str) or not query.strip():
        raise ToolValidationError("query must be a non-empty string")

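    limit = payload.get("limit", 10)
    try:
        # Out-of-range values are clamped to 1..25 rather than rejected.
        limit_value = max(1, min(int(limit), 25))
    except (TypeError, ValueError, OverflowError) as exc:
        # OverflowError covers inputs such as float("inf").
        raise ToolValidationError("limit must be an integer between 1 and 25") from exc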

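    threshold = payload.get("threshold", 0.55)
    try:
        # Out-of-range values are clamped to 0.0..1.0 rather than rejected.
        threshold_value = max(0.0, min(float(threshold), 1.0))
    except (TypeError, ValueError, OverflowError) as exc:
        # OverflowError covers integers too large to convert to float.
        raise ToolValidationError("threshold must be a float between 0.0 and 1.0") from exc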

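    # Embed the query and fetch up to `limit_value` candidate chunks for this tenant.
    embedding = embed_text(query)
    raw_results = search_vectors(context.tenant_id, embedding, limit=limit_value)

    # Keep chunks at or above the threshold, then return at most the first 3,
    # regardless of the requested limit.
    filtered = [
        {"text": chunk.get("text", ""), "relevance": chunk.get("similarity", 0.0)}
        for chunk in raw_results
        if chunk.get("similarity", 0.0) >= threshold_value
    ][:3]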

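    # Metrics describe the unfiltered result set.
    hits = len(raw_results)
    avg_score = mean(item.get("similarity", 0.0) for item in raw_results) if raw_results else None
    # Taking the first result as the top score assumes search_vectors returns
    # results ordered by descending similarity.
    top_score = raw_results[0].get("similarity") if raw_results else None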

    log_rag_search_metrics(
        tenant_id=context.tenant_id,
        query=query,
        hits_count=hits,
        avg_score=avg_score,
        top_score=top_score,
    )

    return {
        "query": query,
        "results": filtered,
        "metadata": {
            "limit": limit_value,
            "threshold": threshold_value,
            "hits_before_filter": hits,
        },
    }