File size: 2,657 Bytes
"""Retrieval router — dispatches to DocumentRetriever for unstructured sources.
Routing rules:
- unstructured / document / both → DocumentRetriever (PGVector, PDF/DOCX/TXT)
- structured / schema → empty list; handled by query/service.py
- chat → empty list; bypasses retrieval entirely
Exposes the same interface as the old src/rag/retriever.py so call sites in
chat.py require no changes beyond the import path.
"""
import hashlib
import json
from dataclasses import asdict
from src.db.redis.connection import get_redis
from src.middlewares.logging import get_logger
from src.retrieval.base import RetrievalResult
from src.retrieval.document import DocumentRetriever
# Redis cache settings for retrieval results.
_CACHE_KEY_PREFIX = "retrieval"  # leading segment of every cache key
_CACHE_TTL = 60 * 60  # cache entry lifetime in seconds (one hour)

logger = get_logger("retrieval_router")
class RetrievalRouter:
    """Cache-fronted entry point for document retrieval.

    Results are looked up in Redis first (keyed by user, query hash and ``k``)
    and fetched from ``DocumentRetriever`` on a miss. The cache is treated as
    best-effort in ``retrieve``: a Redis failure neither prevents a retrieval
    nor discards results that were fetched successfully.
    """

    # Redis glob metacharacters, each mapped to its backslash-escaped form so
    # a user id is always matched literally inside a SCAN pattern.
    _GLOB_ESCAPES = str.maketrans({ch: "\\" + ch for ch in "\\*?[]"})

    def __init__(self) -> None:
        # Created lazily by _get_retriever(), not at construction time.
        self._retriever: DocumentRetriever | None = None

    def _get_retriever(self) -> DocumentRetriever:
        """Return the router's DocumentRetriever, creating it on first call."""
        if self._retriever is None:
            self._retriever = DocumentRetriever()
        return self._retriever

    @staticmethod
    def _cache_key(query: str, user_id: str, k: int) -> str:
        """Build the Redis key for one (user, query, k) combination."""
        # MD5 only shortens the query into a fixed-size key segment; it is not
        # a security control, so flag it as such (keeps FIPS builds working).
        query_hash = hashlib.md5(
            query.encode(), usedforsecurity=False
        ).hexdigest()
        return f"{_CACHE_KEY_PREFIX}:{user_id}:{query_hash}:{k}"

    async def retrieve(
        self,
        query: str,
        user_id: str,
        k: int = 5,
    ) -> list[RetrievalResult]:
        """Return up to ``k`` retrieval results for ``query``, using the cache.

        Args:
            query: The search text; its hash is part of the cache key.
            user_id: Owner of the documents; part of the cache key and passed
                to the retriever.
            k: Number of results requested from the retriever.

        Returns:
            The cached results when a valid cache entry exists, otherwise the
            retriever's results. An empty list is returned (and nothing is
            cached) when the retriever raises.
        """
        cache_key = self._cache_key(query, user_id, k)

        redis = None
        cached = None
        try:
            redis = await get_redis()
            cached = await redis.get(cache_key)
        except Exception as e:
            # A cache outage must not take retrieval down with it.
            logger.warning(
                "retrieval cache unavailable, fetching fresh", error=str(e)
            )

        if cached:
            try:
                cached_results = [
                    RetrievalResult(**r) for r in json.loads(cached)
                ]
            except Exception:
                logger.warning("corrupted retrieval cache, fetching fresh")
            else:
                # Logged only once the payload has actually been rebuilt.
                logger.info("returning cached retrieval results")
                return cached_results

        try:
            results = await self._get_retriever().retrieve(query, user_id, k)
        except Exception as e:
            logger.error("retrieval failed", error=str(e))
            return []

        if redis is not None:
            try:
                await redis.setex(
                    cache_key,
                    _CACHE_TTL,
                    json.dumps([asdict(r) for r in results]),
                )
            except Exception as e:
                # Failing to store the entry must not discard fresh results.
                logger.warning(
                    "failed to cache retrieval results", error=str(e)
                )
        return results

    async def invalidate_cache(self, user_id: str) -> int:
        """Delete all cached retrieval entries for a user. Call after upload/delete.

        Args:
            user_id: The user whose cache entries are removed.

        Returns:
            The number of keys Redis reports as deleted (0 when none matched).
        """
        redis = await get_redis()
        # Escape glob metacharacters so the pattern can only match this
        # user's keys, whatever characters the id contains.
        literal_user_id = str(user_id).translate(self._GLOB_ESCAPES)
        pattern = f"{_CACHE_KEY_PREFIX}:{literal_user_id}:*"
        keys = [key async for key in redis.scan_iter(match=pattern)]
        if not keys:
            return 0
        deleted = await redis.delete(*keys)
        logger.info(
            "retrieval cache invalidated", user_id=user_id, deleted=deleted
        )
        return int(deleted)
# Module-level router instance. Constructing it does not create a
# DocumentRetriever; that is deferred until _get_retriever() is first called.
retrieval_router = RetrievalRouter()
|