| """Routes retrieval requests to the appropriate retriever based on source_hint.""" |
|
|
| import asyncio |
| import hashlib |
| import json |
| from typing import Literal |
|
|
| from src.db.redis.connection import get_redis |
| from src.middlewares.logging import get_logger |
| from src.rag.base import BaseRetriever, RetrievalResult |
|
|
# Module-level logger, registered under the name "retrieval_router".
logger = get_logger("retrieval_router")


# Expiry handed to Redis SETEX when retrieval results are cached
# (SETEX takes seconds, so 3600 is one hour).
_CACHE_TTL = 3600
# Values accepted for the ``source_hint`` routing argument.
SourceHint = Literal["document", "schema", "both"]
|
|
|
|
class RetrievalRouter:
    """Dispatch a retrieval query to the schema and/or document retriever.

    Batches returned by the selected retrievers are merged, ordered by
    descending ``score``, cut to ``k`` items and cached in Redis for
    ``_CACHE_TTL`` seconds. The cache key is built from the user id, the
    source hint, a digest of the query text and ``k``.
    """

    def __init__(
        self,
        schema_retriever: BaseRetriever,
        document_retriever: BaseRetriever,
    ):
        """Register the retriever that serves each concrete source."""
        # Insertion order is significant: a "both" request queries schema
        # first, and the stable sort in ``retrieve`` keeps schema results
        # ahead of document results when scores tie.
        self._retrievers: dict[str, BaseRetriever] = {
            "schema": schema_retriever,
            "document": document_retriever,
        }

    def _route(self, source_hint: SourceHint) -> list[BaseRetriever]:
        """Return the retrievers to query for ``source_hint``.

        ``"schema"`` and ``"document"`` select the single matching retriever;
        any other value (i.e. ``"both"``) selects every registered retriever.
        """
        if source_hint in ("schema", "document"):
            return [self._retrievers[source_hint]]
        return list(self._retrievers.values())

    @staticmethod
    def _cache_key(query: str, user_id: str, source_hint: str, k: int) -> str:
        """Build the Redis key under which a request's results are cached."""
        # MD5 only shortens the query into a fixed-width key component; it is
        # not a security control, which ``usedforsecurity=False`` makes
        # explicit (and keeps working on FIPS-restricted builds).
        digest = hashlib.md5(query.encode(), usedforsecurity=False).hexdigest()
        return f"retrieval:{user_id}:{source_hint}:{digest}:{k}"

    @staticmethod
    def _decode_cached(payload) -> "list[RetrievalResult] | None":
        """Rebuild cached results, or return ``None`` if the payload is unusable."""
        try:
            return [RetrievalResult(**row) for row in json.loads(payload)]
        except (ValueError, TypeError) as exc:
            # Malformed JSON or fields that no longer match RetrievalResult:
            # treat as a cache miss so the entry gets refreshed instead of
            # failing every request until the TTL expires.
            logger.warning(
                "discarding unreadable cached retrieval results", error=str(exc)
            )
            return None

    async def retrieve(
        self,
        query: str,
        user_id: str,
        source_hint: SourceHint = "both",
        k: int = 10,
    ) -> list[RetrievalResult]:
        """Return up to ``k`` results for ``query``, best score first.

        Args:
            query: Query text forwarded to each selected retriever.
            user_id: Forwarded to the retrievers and used to scope the cache key.
            source_hint: Which retriever(s) to consult; see ``_route``.
            k: Maximum number of results. Values ``<= 0`` yield ``[]``
                without touching Redis or the retrievers.

        Returns:
            The merged results sorted by descending ``score`` and truncated
            to ``k``. If a retriever fails, its error is logged and the
            results of the remaining retrievers are returned.

        Raises:
            Whatever ``get_redis`` or the Redis ``get``/``setex`` calls raise;
            cache connectivity errors are not suppressed here.
        """
        if k <= 0:
            # Nothing can be returned; without this guard a negative ``k``
            # would make ``results[:k]`` drop the last |k| items instead.
            return []

        redis = await get_redis()
        cache_key = self._cache_key(query, user_id, source_hint, k)

        cached = await redis.get(cache_key)
        if cached is not None:
            hit = self._decode_cached(cached)
            if hit is not None:
                logger.info("returning cached retrieval results", source_hint=source_hint)
                return hit

        retrievers = self._route(source_hint)
        batches = await asyncio.gather(
            *(r.retrieve(query, user_id, k) for r in retrievers),
            return_exceptions=True,
        )

        results: list[RetrievalResult] = []
        any_failed = False
        for retriever, batch in zip(retrievers, batches):
            # BaseException rather than Exception: a cancelled child comes
            # back as asyncio.CancelledError, which is not an Exception.
            if isinstance(batch, BaseException):
                any_failed = True
                logger.error(
                    "retriever failed",
                    retriever=type(retriever).__name__,
                    error=str(batch),
                )
                continue
            results.extend(batch)

        results.sort(key=lambda r: r.score, reverse=True)
        results = results[:k]

        logger.info("retrieved chunks", count=len(results), source_hint=source_hint)
        if any_failed:
            # Do not cache a degraded answer: a transient retriever outage
            # would otherwise be served from cache for the whole TTL.
            return results

        # NOTE(review): assumes RetrievalResult keeps its fields in __dict__
        # and that they are JSON-serialisable — confirm against src.rag.base.
        await redis.setex(
            cache_key,
            _CACHE_TTL,
            json.dumps([vars(r) for r in results]),
        )
        return results
|
|