"""Routes retrieval requests to the appropriate retriever based on source_hint.""" import asyncio import hashlib import json from typing import Literal from src.db.redis.connection import get_redis from src.middlewares.logging import get_logger from src.rag.base import BaseRetriever, RetrievalResult logger = get_logger("retrieval_router") _CACHE_TTL = 3600 # 1 hour SourceHint = Literal["document", "schema", "both"] class RetrievalRouter: def __init__( self, schema_retriever: BaseRetriever, document_retriever: BaseRetriever, ): self._retrievers: dict[str, BaseRetriever] = { "schema": schema_retriever, "document": document_retriever, } def _route(self, source_hint: SourceHint) -> list[BaseRetriever]: if source_hint == "schema": return [self._retrievers["schema"]] if source_hint == "document": return [self._retrievers["document"]] return list(self._retrievers.values()) async def retrieve( self, query: str, user_id: str, source_hint: SourceHint = "both", k: int = 10, ) -> list[RetrievalResult]: redis = await get_redis() query_hash = hashlib.md5(query.encode()).hexdigest() cache_key = f"retrieval:{user_id}:{source_hint}:{query_hash}:{k}" cached = await redis.get(cache_key) if cached: logger.info("returning cached retrieval results", source_hint=source_hint) raw = json.loads(cached) return [RetrievalResult(**r) for r in raw] retrievers = self._route(source_hint) batches = await asyncio.gather( *[r.retrieve(query, user_id, k) for r in retrievers], return_exceptions=True, ) results: list[RetrievalResult] = [] for batch in batches: if isinstance(batch, Exception): logger.error("retriever failed", error=str(batch)) continue results.extend(batch) results.sort(key=lambda r: r.score, reverse=True) results = results[:k] logger.info("retrieved chunks", count=len(results), source_hint=source_hint) await redis.setex( cache_key, _CACHE_TTL, json.dumps([vars(r) for r in results]), ) return results