| """Retrieval router — dispatches to DocumentRetriever for unstructured sources. |
| |
| Routing rules: |
| - unstructured / document / both → DocumentRetriever (PGVector, PDF/DOCX/TXT) |
| - structured / schema → empty list; handled by query/service.py |
| - chat → empty list; bypasses retrieval entirely |
| |
| Exposes the same interface as the old src/rag/retriever.py so call sites in |
| chat.py require no changes beyond the import path. |
| """ |
|
|
| import hashlib |
| import json |
| from dataclasses import asdict |
|
|
| from src.db.redis.connection import get_redis |
| from src.middlewares.logging import get_logger |
| from src.retrieval.base import RetrievalResult |
| from src.retrieval.document import DocumentRetriever |
|
|
logger = get_logger("retrieval_router")

# Lifetime of one cached retrieval entry; passed to Redis SETEX, so seconds (1 hour).
_CACHE_TTL = 60 * 60
# Namespace prefix shared by every retrieval cache key (see RetrievalRouter).
_CACHE_KEY_PREFIX = "retrieval"
|
|
|
|
class RetrievalRouter:
    """Send retrieval requests to DocumentRetriever, fronted by a Redis cache.

    Every call to ``retrieve`` is served by ``DocumentRetriever``; this class
    takes no source-type argument. Results are cached under
    ``retrieval:{user_id}:{md5(query)}:{k}`` for ``_CACHE_TTL`` seconds.

    The cache is a best-effort optimization. If Redis is unreachable, or a
    cache read/write fails, the failure is logged and retrieval proceeds
    without the cache.
    """

    def __init__(self) -> None:
        # Built lazily on first use so constructing the router does no work.
        self._retriever: DocumentRetriever | None = None

    def _get_retriever(self) -> DocumentRetriever:
        """Return the DocumentRetriever, creating it on the first call."""
        if self._retriever is None:
            self._retriever = DocumentRetriever()
        return self._retriever

    @staticmethod
    def _cache_key(query: str, user_id: str, k: int) -> str:
        """Build the Redis key for one (user_id, query, k) lookup."""
        # MD5 only shortens the query into a fixed-length key component; it
        # is not a security control. Declaring that keeps it usable on FIPS builds.
        query_hash = hashlib.md5(query.encode(), usedforsecurity=False).hexdigest()
        return f"{_CACHE_KEY_PREFIX}:{user_id}:{query_hash}:{k}"

    @staticmethod
    def _escape_glob(text: str) -> str:
        """Backslash-escape Redis glob metacharacters so *text* matches literally."""
        return "".join("\\" + ch if ch in "\\*?[]" else ch for ch in text)

    async def retrieve(
        self,
        query: str,
        user_id: str,
        k: int = 5,
    ) -> list[RetrievalResult]:
        """Retrieve results for *query* on behalf of *user_id*.

        Args:
            query: Free-text query, passed through to DocumentRetriever.
            user_id: Scopes both the retriever call and the cache key.
            k: Passed through to DocumentRetriever; also part of the cache key.

        Returns:
            The cached results if a valid cache entry exists, otherwise the
            retriever's results. Returns an empty list if the retriever raises.
        """
        cache_key = self._cache_key(query, user_id, k)

        # A Redis outage must not turn into a retrieval outage: on any cache
        # error, fall through to a live retrieval and skip the cache write.
        redis = None
        cached = None
        try:
            redis = await get_redis()
            cached = await redis.get(cache_key)
        except Exception as e:
            logger.warning("retrieval cache read failed", error=str(e))
            redis = None

        if cached:
            try:
                results = [RetrievalResult(**r) for r in json.loads(cached)]
            except Exception:
                # Undecodable JSON, or a payload that no longer matches the
                # RetrievalResult fields. Ignore it; the fresh result written
                # below overwrites the key.
                logger.warning("corrupted retrieval cache, fetching fresh")
            else:
                logger.info("returning cached retrieval results")
                return results

        try:
            results = await self._get_retriever().retrieve(query, user_id, k)
        except Exception as e:
            # Best-effort: a failed retrieval yields no results rather than an
            # error, and that empty result is deliberately not cached.
            logger.error("retrieval failed", error=str(e))
            return []

        if redis is not None:
            try:
                payload = json.dumps([asdict(r) for r in results])
                await redis.setex(cache_key, _CACHE_TTL, payload)
            except Exception as e:
                # A serialization or Redis error must not discard results that
                # were already retrieved successfully.
                logger.warning("retrieval cache write failed", error=str(e))
        return results

    async def invalidate_cache(self, user_id: str) -> int:
        """Delete all cached retrieval entries for *user_id*.

        Call after a document upload or delete so stale results are not served.

        Returns:
            The number of keys Redis reports as deleted (0 if none matched).
        """
        redis = await get_redis()
        # Escape glob metacharacters so a user_id such as "*" is matched
        # literally instead of sweeping up other users' entries.
        pattern = f"{_CACHE_KEY_PREFIX}:{self._escape_glob(user_id)}:*"
        keys = [key async for key in redis.scan_iter(match=pattern)]
        if not keys:
            return 0
        deleted = await redis.delete(*keys)
        logger.info("retrieval cache invalidated", user_id=user_id, deleted=deleted)
        return int(deleted)
|
|
|
|
# Shared module-level router instance. Construction is cheap: __init__ only
# sets the retriever slot to None, and DocumentRetriever is built on first use.
retrieval_router = RetrievalRouter()
|
|