Rifqi Hafizuddin
[KM-438-439] add retriever feature
ba550a5
raw
history blame
2.38 kB
"""Routes retrieval requests to the appropriate retriever based on source_hint."""
import asyncio
import hashlib
import json
from typing import Literal
from src.db.redis.connection import get_redis
from src.middlewares.logging import get_logger
from src.rag.base import BaseRetriever, RetrievalResult
# Module-level logger registered under the name "retrieval_router".
logger = get_logger("retrieval_router")
# Expiry passed to redis.setex for cached retrieval results, in seconds.
_CACHE_TTL = 3600  # 1 hour
# Accepted values for the source_hint argument: a single retriever
# ("document" or "schema") or "both" to query every retriever.
SourceHint = Literal["document", "schema", "both"]
class RetrievalRouter:
    """Dispatch retrieval queries to the schema and/or document retriever.

    Merged results are cached in Redis, keyed by user id, source hint,
    query hash and ``k``, and expire after ``_CACHE_TTL`` seconds.
    """

    def __init__(
        self,
        schema_retriever: BaseRetriever,
        document_retriever: BaseRetriever,
    ):
        """Register the two retrievers under the keys "schema" and "document"."""
        self._retrievers: dict[str, BaseRetriever] = {
            "schema": schema_retriever,
            "document": document_retriever,
        }

    def _route(self, source_hint: SourceHint) -> list[BaseRetriever]:
        """Return the retrievers that should serve *source_hint*.

        "schema" and "document" select the single matching retriever; any
        other value (including "both") selects every registered retriever.
        """
        if source_hint == "schema":
            return [self._retrievers["schema"]]
        if source_hint == "document":
            return [self._retrievers["document"]]
        return list(self._retrievers.values())

    async def retrieve(
        self,
        query: str,
        user_id: str,
        source_hint: SourceHint = "both",
        k: int = 10,
    ) -> list[RetrievalResult]:
        """Return up to *k* results for *query*, highest score first.

        A cached answer is returned when one exists. Otherwise the routed
        retrievers are queried concurrently, their results are merged, sorted
        by ``score`` in descending order and truncated to *k*.

        A retriever that fails is logged and skipped, so the call still
        returns whatever the remaining retrievers produced. Such partial
        results are NOT cached, so a transient failure does not get served
        from the cache for the whole TTL.

        Args:
            query: The search text.
            user_id: Passed to each retriever and included in the cache key.
            source_hint: Which retriever(s) to query.
            k: Maximum number of results to return.

        Returns:
            The merged, score-sorted results (possibly empty).
        """
        redis = await get_redis()

        # The digest only shortens the query for use inside the cache key.
        query_hash = hashlib.md5(query.encode()).hexdigest()
        cache_key = f"retrieval:{user_id}:{source_hint}:{query_hash}:{k}"

        cached = await redis.get(cache_key)
        if cached:
            logger.info("returning cached retrieval results", source_hint=source_hint)
            return [RetrievalResult(**item) for item in json.loads(cached)]

        retrievers = self._route(source_hint)
        batches = await asyncio.gather(
            *(retriever.retrieve(query, user_id, k) for retriever in retrievers),
            return_exceptions=True,
        )

        results: list[RetrievalResult] = []
        any_failed = False
        for batch in batches:
            # BaseException, not Exception: a cancelled retriever is reported
            # as asyncio.CancelledError, which does not derive from Exception
            # and would otherwise be passed to extend() as if it were a list.
            if isinstance(batch, BaseException):
                any_failed = True
                logger.error("retriever failed", error=str(batch))
                continue
            results.extend(batch)

        results.sort(key=lambda r: r.score, reverse=True)
        results = results[:k]
        logger.info("retrieved chunks", count=len(results), source_hint=source_hint)

        # Only complete answers are cached; a degraded result set is returned
        # to the caller but recomputed on the next request.
        if not any_failed:
            await redis.setex(
                cache_key,
                _CACHE_TTL,
                json.dumps([vars(r) for r in results]),
            )
        return results