"""Routes retrieval requests to the appropriate retriever based on source_hint."""

import asyncio
import hashlib
import json
from typing import Literal

from src.db.redis.connection import get_redis
from src.middlewares.logging import get_logger
from src.rag.base import BaseRetriever, RetrievalResult

logger = get_logger("retrieval_router")

_CACHE_TTL = 3600  # 1 hour
SourceHint = Literal["document", "schema", "both"]


class RetrievalRouter:
    """Fans a query out to one or both retrievers and merges the ranked results."""

    def __init__(
        self,
        schema_retriever: BaseRetriever,
        document_retriever: BaseRetriever,
    ):
        self._retrievers: dict[str, BaseRetriever] = {
            "schema": schema_retriever,
            "document": document_retriever,
        }

    def _route(self, source_hint: SourceHint) -> list[BaseRetriever]:
        """Map a source hint to the retrievers that should serve the query."""
        if source_hint == "schema":
            return [self._retrievers["schema"]]
        if source_hint == "document":
            return [self._retrievers["document"]]
        # "both": fan out to every registered retriever.
        return list(self._retrievers.values())

    async def retrieve(
        self,
        query: str,
        user_id: str,
        source_hint: SourceHint = "both",
        k: int = 10,
    ) -> list[RetrievalResult]:
        """Return the top-k results for `query`, serving from cache when possible."""
        redis = await get_redis()
        # md5 is used only to build a compact, deterministic cache key,
        # not for any security purpose.
        query_hash = hashlib.md5(query.encode()).hexdigest()
        cache_key = f"retrieval:{user_id}:{source_hint}:{query_hash}:{k}"

        cached = await redis.get(cache_key)
        if cached:
            logger.info("returning cached retrieval results", source_hint=source_hint)
            raw = json.loads(cached)
            return [RetrievalResult(**r) for r in raw]

        retrievers = self._route(source_hint)
        # Query all routed retrievers concurrently; return_exceptions=True
        # keeps one failing retriever from discarding the others' results.
        batches = await asyncio.gather(
            *[r.retrieve(query, user_id, k) for r in retrievers],
            return_exceptions=True,
        )

        results: list[RetrievalResult] = []
        for batch in batches:
            if isinstance(batch, Exception):
                # Degrade gracefully: log the failure and merge what succeeded.
                logger.error("retriever failed", error=str(batch))
                continue
            results.extend(batch)

        # Merge across sources by score and keep the overall top-k.
        results.sort(key=lambda r: r.score, reverse=True)
        results = results[:k]

        logger.info("retrieved chunks", count=len(results), source_hint=source_hint)
        # Note: serialization assumes RetrievalResult instances expose a plain
        # __dict__ of JSON-serializable fields (true for a simple, non-slots
        # dataclass); a slotted or pydantic model would need its own dump method.
        await redis.setex(
            cache_key,
            _CACHE_TTL,
            json.dumps([vars(r) for r in results]),
        )
        return results
        return results
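
# Usage sketch (illustrative): wiring the router with two hypothetical
# BaseRetriever implementations. `SchemaRetriever` and `DocumentRetriever`
# are assumed names, not classes defined in this module.
#
#     router = RetrievalRouter(
#         schema_retriever=SchemaRetriever(),
#         document_retriever=DocumentRetriever(),
#     )
#     results = await router.retrieve(
#         "invoices from March", user_id="u42", source_hint="both", k=5
#     )

if __name__ == "__main__":
    # Demonstrates the cache-key scheme used by `retrieve`: the same
    # (user_id, source_hint, query, k) tuple always yields the same key,
    # so repeated identical requests hit the cached entry.
    demo_query = "invoices from March"
    demo_hash = hashlib.md5(demo_query.encode()).hexdigest()
    print(f"retrieval:u42:both:{demo_hash}:5")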