File size: 1,831 Bytes
33d0c55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
"""Persistent query cache.

Stores (query_text, retrieved_chunks, timestamp) and indexes the queries
by their BGE embedding. New queries similar to past successful queries
return the cached chunks immediately (Tier 1 of the retrieval pipeline).

This is the device-side counterpart of the future central-server PageRank
curation layer: as users issue queries, successful (query, chunks) pairs
accumulate locally, and can later be uploaded for collective curation.
"""

from typing import List, Optional, Tuple
import time

import numpy as np

from rag import BGERetriever


class QueryCache:
    """In-memory store of past queries and the chunks retrieved for them.

    Each cached query is embedded with the supplied retriever, and the
    embeddings are kept as rows of one matrix. A lookup embeds the incoming
    query, scores it against every cached row with a dot product, and returns
    the best-scoring entry if the score reaches ``sim_threshold``.

    NOTE(review): the dot product is only a cosine similarity if the
    retriever returns unit-normalised embeddings -- confirm against the
    retriever implementation.
    """

    def __init__(self, retriever: "BGERetriever", sim_threshold: float = 0.85):
        """Create an empty cache.

        Args:
            retriever: Object providing ``_dim()`` (embedding width) and
                ``_encode(texts, is_query=...)`` (embedding matrix, one row
                per text).
            sim_threshold: Minimum similarity score for ``lookup`` to report
                a hit.
        """
        self.retriever = retriever
        self.sim_threshold = sim_threshold
        # entries[i] is (query, chunks, timestamp) and corresponds to row i
        # of q_embs; the two containers must always have the same length.
        self.entries: List[Tuple[str, List[str], float]] = []
        self.q_embs: np.ndarray = np.zeros((0, retriever._dim()), dtype=np.float32)

    def __len__(self) -> int:
        """Return the number of cached queries."""
        return len(self.entries)

    def add(self, query: str, chunks: List[str]) -> None:
        """Cache ``chunks`` as the retrieval result for ``query``.

        The chunk list is copied, so later changes to the caller's list do
        not affect the cache.
        """
        emb = self.retriever._encode([query], is_query=True)
        if len(self.q_embs) == 0:
            stacked = emb
        else:
            stacked = np.vstack([self.q_embs, emb])
        # Commit to both containers only after encoding and stacking have
        # succeeded, so a failure cannot leave entries and q_embs misaligned.
        self.entries.append((query, list(chunks), time.time()))
        self.q_embs = stacked

    def lookup(self, query: str) -> Optional[Tuple[List[str], float, str]]:
        """If a sufficiently-similar past query exists, return (chunks, sim, matched_query).

        Returns ``None`` when the cache is empty or when the best score is
        below ``sim_threshold``. The returned chunk list is a copy, so the
        caller may modify it without altering the cached entry.
        """
        if not self.entries:
            return None
        emb = self.retriever._encode([query], is_query=True)[0]
        sims = self.q_embs @ emb
        best = int(sims.argmax())
        if sims[best] >= self.sim_threshold:
            matched_query, chunks, _ts = self.entries[best]
            # Hand out a copy: returning the stored list would let callers
            # mutate the cache by accident.
            return list(chunks), float(sims[best]), matched_query
        return None