File size: 2,657 Bytes
6bff5d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
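"""Retrieval router — dispatches to DocumentRetriever for unstructured sources.

Routing rules:
  - unstructured / document / both → DocumentRetriever (PGVector, PDF/DOCX/TXT)
  - structured / schema            → empty list; handled by query/service.py
  - chat                           → empty list; bypasses retrieval entirely

`RetrievalRouter.retrieve` itself takes no source type: it always queries
DocumentRetriever, so the rules above are presumably applied by the caller.
Results are cached in Redis per user and query; call `invalidate_cache(user_id)`
after a document upload or delete so stale results are not served.

Exposes the same interface as the old src/rag/retriever.py, so call sites in
chat.py require no changes beyond the import path.
"""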

import hashlib
import json
from dataclasses import asdict

from src.db.redis.connection import get_redis
from src.middlewares.logging import get_logger
from src.retrieval.base import RetrievalResult
from src.retrieval.document import DocumentRetriever

logger = get_logger("retrieval_router")

# Lifetime of one cached retrieval entry, in seconds (passed to Redis SETEX).
_CACHE_TTL = 60 * 60  # one hour
# Namespace prefix shared by every retrieval-cache key written to Redis.
_CACHE_KEY_PREFIX = "retrieval"


class RetrievalRouter:
    """Serve DocumentRetriever results, cached per user in Redis.

    Cache keys have the form ``"{prefix}:{user_id}:{md5(query)}:{k}"`` and
    expire after ``_CACHE_TTL`` seconds. The cache is best-effort inside
    ``retrieve``: Redis or (de)serialisation failures are logged and the
    lookup proceeds without the cache.
    """

    # Number of keys removed per DEL command during invalidation, so a user
    # with many cached queries does not produce one unbounded DEL.
    _DELETE_BATCH_SIZE = 500

    def __init__(self) -> None:
        # Constructed lazily on first use; because _get_retriever() is called
        # inside retrieve()'s try block, construction errors are handled there.
        self._retriever: DocumentRetriever | None = None

    def _get_retriever(self) -> DocumentRetriever:
        """Return the shared DocumentRetriever, creating it on first call."""
        if self._retriever is None:
            self._retriever = DocumentRetriever()
        return self._retriever

    @staticmethod
    def _cache_key(query: str, user_id: str, k: int) -> str:
        """Build the Redis key for one (user, query, k) lookup."""
        # MD5 only shortens the query into a fixed-size key component; it is
        # not a security boundary, so flag it as such (also keeps it usable on
        # FIPS-restricted interpreters).
        query_hash = hashlib.md5(query.encode(), usedforsecurity=False).hexdigest()
        return f"{_CACHE_KEY_PREFIX}:{user_id}:{query_hash}:{k}"

    @staticmethod
    def _escape_glob(value: str) -> str:
        """Escape Redis glob metacharacters so *value* matches only itself."""
        return "".join("\\" + ch if ch in "*?[]\\" else ch for ch in value)

    async def retrieve(
        self,
        query: str,
        user_id: str,
        k: int = 5,
    ) -> list[RetrievalResult]:
        """Return retrieval results for *query*, scoped to *user_id*.

        Args:
            query: Free-text query forwarded to DocumentRetriever.
            user_id: Owner whose documents are searched; also scopes the cache.
            k: Forwarded to DocumentRetriever and included in the cache key.

        Returns:
            Cached results when a decodable entry exists; otherwise fresh
            results from DocumentRetriever (then cached). Returns ``[]``
            without caching if the retriever raises. Cache failures never fail
            the call.
        """
        cache_key = self._cache_key(query, user_id, k)

        redis = None
        cached = None
        try:
            redis = await get_redis()
            cached = await redis.get(cache_key)
        except Exception as e:
            # The cache is an optimisation; an unreachable Redis must not take
            # retrieval down with it.
            logger.warning("retrieval cache unavailable", error=str(e))

        if cached:
            try:
                raw = json.loads(cached)
                results = [RetrievalResult(**r) for r in raw]
            except Exception:
                # Malformed JSON or a schema mismatch: fall through to a fresh
                # fetch, whose SETEX below overwrites the bad entry.
                logger.warning("corrupted retrieval cache, fetching fresh")
            else:
                logger.info("returning cached retrieval results")
                return results

        try:
            results = await self._get_retriever().retrieve(query, user_id, k)
        except Exception as e:
            logger.error("retrieval failed", error=str(e))
            return []

        if redis is not None:
            try:
                payload = json.dumps([asdict(r) for r in results])
                await redis.setex(cache_key, _CACHE_TTL, payload)
            except Exception as e:
                # The results are already in hand; a failed cache write (Redis
                # error or a non-JSON-serialisable field) must not discard them.
                logger.warning("failed to cache retrieval results", error=str(e))
        return results

    async def invalidate_cache(self, user_id: str) -> int:
        """Delete all cached retrieval entries for *user_id*; return the count.

        Call after a document upload or delete. Unlike ``retrieve``, Redis
        errors propagate here so the caller learns that invalidation failed.
        """
        redis = await get_redis()
        # Escape user_id so glob characters in it cannot widen the match to
        # other users' keys.
        # NOTE(review): a user_id that is a ':'-prefix of another (e.g. "a" vs
        # "a:b") could still match the other's keys; fine if ids never contain
        # ':' (e.g. UUIDs) — confirm.
        pattern = f"{_CACHE_KEY_PREFIX}:{self._escape_glob(user_id)}:*"

        found = False
        deleted = 0
        batch: list[str | bytes] = []
        async for key in redis.scan_iter(match=pattern):
            found = True
            batch.append(key)
            if len(batch) >= self._DELETE_BATCH_SIZE:
                deleted += int(await redis.delete(*batch))
                batch = []
        if batch:
            deleted += int(await redis.delete(*batch))

        if not found:
            return 0
        logger.info("retrieval cache invalidated", user_id=user_id, deleted=deleted)
        return deleted


# Shared module-level router instance.
retrieval_router: RetrievalRouter = RetrievalRouter()