# Author: ishaq101
# Commit 6bff5d9 — feat/Catalog Retrieval System (#1)
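"""Retrieval router — dispatches to DocumentRetriever for unstructured sources.

Routing rules:
- unstructured / document / both → DocumentRetriever (PGVector, PDF/DOCX/TXT)
- structured / schema → empty list; handled by query/service.py
- chat → empty list; bypasses retrieval entirely

Note: ``RetrievalRouter.retrieve`` takes no source-type argument and always
queries DocumentRetriever, so the structured/schema and chat cases must be
short-circuited before it is called. Results are cached in Redis per
(user_id, query, k); call ``invalidate_cache`` after a user's uploads or
deletions.

Exposes the same interface as the old src/rag/retriever.py, so call sites in
chat.py require no changes beyond the import path.
"""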
import hashlib
import json
from dataclasses import asdict
from src.db.redis.connection import get_redis
from src.middlewares.logging import get_logger
from src.retrieval.base import RetrievalResult
from src.retrieval.document import DocumentRetriever
# Redis cache settings for retrieval results.
_CACHE_KEY_PREFIX = "retrieval"  # namespace prefix for every cached entry
_CACHE_TTL = 60 * 60  # expiry passed to SETEX, in seconds (one hour)

logger = get_logger("retrieval_router")
class RetrievalRouter:
    """Cache-fronted entry point for document retrieval.

    Every ``retrieve`` call is served by a lazily created ``DocumentRetriever``.
    Results are cached in Redis per (user_id, query, k) for ``_CACHE_TTL``
    seconds. The cache is best-effort inside ``retrieve``: Redis failures
    are logged and retrieval proceeds without the cache. ``invalidate_cache``
    still raises on Redis errors so callers know stale entries may remain.
    """

    def __init__(self) -> None:
        # Created on first use so building the module-level singleton does
        # not construct a DocumentRetriever at import time.
        self._retriever: DocumentRetriever | None = None

    def _get_retriever(self) -> DocumentRetriever:
        """Return the shared DocumentRetriever, creating it on first call."""
        if self._retriever is None:
            self._retriever = DocumentRetriever()
        return self._retriever

    @staticmethod
    def _escape_glob(value: str) -> str:
        """Escape Redis glob metacharacters so *value* matches only itself."""
        return "".join("\\" + ch if ch in "*?[]\\" else ch for ch in value)

    @staticmethod
    def _cache_key(query: str, user_id: str, k: int) -> str:
        """Build the Redis key for one (user_id, query, k) retrieval."""
        # MD5 only shortens the query into a fixed-size key segment; it is not
        # a security boundary. usedforsecurity=False keeps it usable on FIPS
        # builds and yields the same digest, so existing keys stay valid.
        query_hash = hashlib.md5(query.encode(), usedforsecurity=False).hexdigest()
        return f"{_CACHE_KEY_PREFIX}:{user_id}:{query_hash}:{k}"

    @staticmethod
    async def _connect_or_none():
        """Return the Redis client, or None (logged) if it cannot be obtained."""
        try:
            return await get_redis()
        except Exception as e:
            logger.warning("retrieval cache unavailable", error=str(e))
            return None

    @staticmethod
    async def _read_cache(redis, cache_key: str) -> list[RetrievalResult] | None:
        """Return cached results for *cache_key*, or None on miss or any error."""
        try:
            cached = await redis.get(cache_key)
        except Exception as e:
            logger.warning("retrieval cache read failed", error=str(e))
            return None
        if not cached:
            return None
        try:
            raw = json.loads(cached)
            results = [RetrievalResult(**r) for r in raw]
        except Exception:
            # Bad JSON or a payload that no longer fits RetrievalResult; the
            # fresh results written afterwards overwrite this entry.
            logger.warning("corrupted retrieval cache, fetching fresh")
            return None
        logger.info("returning cached retrieval results")
        return results

    @staticmethod
    async def _write_cache(redis, cache_key: str, results: list[RetrievalResult]) -> None:
        """Store *results* under *cache_key*; failures are logged, not raised."""
        try:
            # Serialization stays inside the try: a non-JSON-serializable field
            # must not discard results that were already retrieved successfully.
            payload = json.dumps([asdict(r) for r in results])
            await redis.setex(cache_key, _CACHE_TTL, payload)
        except Exception as e:
            logger.warning("retrieval cache write failed", error=str(e))

    async def retrieve(
        self,
        query: str,
        user_id: str,
        k: int = 5,
    ) -> list[RetrievalResult]:
        """Return results for *query* scoped to *user_id*, using the Redis cache.

        Args:
            query: Free-text query forwarded to DocumentRetriever.retrieve.
            user_id: Owner scope; forwarded to the retriever and part of the
                cache key.
            k: Forwarded to DocumentRetriever.retrieve (presumably the maximum
                number of results — confirm against DocumentRetriever).

        Returns:
            The cached or freshly retrieved results, or an empty list if the
            underlying retriever raises. Failed retrievals are not cached.
        """
        cache_key = self._cache_key(query, user_id, k)
        redis = await self._connect_or_none()

        if redis is not None:
            cached = await self._read_cache(redis, cache_key)
            if cached is not None:
                return cached

        try:
            results = await self._get_retriever().retrieve(query, user_id, k)
        except Exception as e:
            logger.error("retrieval failed", error=str(e))
            return []

        if redis is not None:
            await self._write_cache(redis, cache_key, results)
        return results

    async def invalidate_cache(self, user_id: str) -> int:
        """Delete all cached retrieval entries for a user. Call after upload/delete.

        Returns:
            The number of keys Redis reports as deleted (0 if none matched).

        Raises:
            Whatever the Redis client raises; invalidation is not best-effort,
            because a silent failure would leave stale results for up to
            ``_CACHE_TTL`` seconds.
        """
        redis = await get_redis()
        # Escape glob metacharacters so a user_id such as "*" cannot match
        # (and delete) other users' keys. A user_id containing ':' can still
        # prefix-match another user's keys, which only over-invalidates.
        pattern = f"{_CACHE_KEY_PREFIX}:{self._escape_glob(user_id)}:*"
        keys = [key async for key in redis.scan_iter(match=pattern)]
        if not keys:
            return 0
        deleted = await redis.delete(*keys)
        logger.info("retrieval cache invalidated", user_id=user_id, deleted=deleted)
        return int(deleted)
# Module-level singleton: every importer of this module shares this instance.
retrieval_router: RetrievalRouter = RetrievalRouter()