"""Retrieval router — dispatches to DocumentRetriever for unstructured sources. Routing rules: - unstructured / document / both → DocumentRetriever (PGVector, PDF/DOCX/TXT) - structured / schema → empty list; handled by query/service.py - chat → empty list; bypasses retrieval entirely Exposes the same interface as the old src/rag/retriever.py so call sites in chat.py require no changes beyond the import path. """ import hashlib import json from dataclasses import asdict from src.db.redis.connection import get_redis from src.middlewares.logging import get_logger from src.retrieval.base import RetrievalResult from src.retrieval.document import DocumentRetriever logger = get_logger("retrieval_router") _CACHE_TTL = 3600 _CACHE_KEY_PREFIX = "retrieval" class RetrievalRouter: def __init__(self) -> None: self._retriever: DocumentRetriever | None = None def _get_retriever(self) -> DocumentRetriever: if self._retriever is None: self._retriever = DocumentRetriever() return self._retriever async def retrieve( self, query: str, user_id: str, k: int = 5, ) -> list[RetrievalResult]: redis = await get_redis() query_hash = hashlib.md5(query.encode()).hexdigest() cache_key = f"{_CACHE_KEY_PREFIX}:{user_id}:{query_hash}:{k}" cached = await redis.get(cache_key) if cached: try: raw = json.loads(cached) logger.info("returning cached retrieval results") return [RetrievalResult(**r) for r in raw] except Exception: logger.warning("corrupted retrieval cache, fetching fresh") try: results = await self._get_retriever().retrieve(query, user_id, k) except Exception as e: logger.error("retrieval failed", error=str(e)) return [] await redis.setex( cache_key, _CACHE_TTL, json.dumps([asdict(r) for r in results]), ) return results async def invalidate_cache(self, user_id: str) -> int: """Delete all cached retrieval entries for a user. Call after upload/delete.""" redis = await get_redis() pattern = f"{_CACHE_KEY_PREFIX}:{user_id}:*" keys = [key async for key in redis.scan_iter(match=pattern)] if not keys: return 0 deleted = await redis.delete(*keys) logger.info("retrieval cache invalidated", user_id=user_id, deleted=deleted) return int(deleted) retrieval_router = RetrievalRouter()