File size: 5,083 Bytes
eec9162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172

import time
import threading
import numpy as np
from dataclasses import dataclass, field
from collections import defaultdict
from typing import Optional



@dataclass
class CacheEntry:
    """
    One cached (query, result) pair plus the metadata needed to find it again.

    Attributes:
        query:            the query text as originally submitted
        embedding:        L2-normalized embedding of the query (numpy array)
        result:           the answer/document that was returned for the query
        dominant_cluster: index of the GMM cluster the query was assigned to
        timestamp:        unix time when the entry was cached; defaults to the
                          moment the entry object is constructed
    """
    query: str
    embedding: np.ndarray
    result: str
    dominant_cluster: int
    # default_factory runs per instance, so each entry gets its own creation time
    # (a plain `= time.time()` default would be frozen at class-definition time).
    timestamp: float = field(default_factory=time.time)



class SemanticCache:
    """
    A cluster-partitioned semantic cache.

    Instead of storing all entries in one flat list,
    entries are grouped by their dominant GMM cluster.

    Lookup flow:
    1. Embed the incoming query (done by the caller)
    2. Find its dominant cluster (via GMM, done by the caller)
    3. Only scan entries in THAT cluster
    4. Return the best match if similarity >= threshold

    Lookup cost is linear in the size of the query's cluster, i.e. roughly
    O(n/k) when entries are spread evenly across clusters, where:
    - n = total cache entries
    - k = number of clusters (10 in our case)

    Every public method is guarded by one threading.Lock. That lock is NOT
    reentrant, so code holding it must not call another lock-taking method
    (use the `_count_entries_unlocked` helper instead).
    """

    def __init__(self, similarity_threshold: float = 0.85):
        """
        Args:
            similarity_threshold: minimum cosine similarity to count as a cache hit.
                                  Explored in cache/threshold_analysis.py
        """
        self.threshold = similarity_threshold

        # Cluster-partitioned storage
        # Key: cluster_id (int)
        # Value: list of CacheEntry objects
        self._store: dict[int, list[CacheEntry]] = defaultdict(list)

        # Stats counters
        self._hit_count = 0
        self._miss_count = 0

        # Thread safety
        self._lock = threading.Lock()

    def lookup(
        self,
        query_embedding: np.ndarray,
        dominant_cluster: int
    ) -> tuple[Optional[CacheEntry], float]:
        """
        Search for a semantically similar cached query.

        Only entries stored under `dominant_cluster` are compared. Every call
        is counted as exactly one hit or one miss in `stats`.

        Args:
            query_embedding:  L2-normalized embedding of incoming query
            dominant_cluster: GMM cluster index for this query

        Returns:
            (best_matching_entry, similarity_score) on a hit.
            Otherwise (None, best_score_seen), where best_score_seen is the
            highest similarity found in the cluster (possibly negative), or
            0.0 when the cluster holds no entries.
        """
        with self._lock:
            # .get() rather than indexing: indexing the defaultdict would insert
            # an empty list for every cluster ever queried, which would then
            # show up as a phantom 0-size cluster in get_cluster_sizes().
            candidates = self._store.get(dominant_cluster)

            if not candidates:
                self._miss_count += 1
                return None, 0.0

            best_entry = None
            best_score = 0.0
            for entry in candidates:
                # Cosine similarity = dot product of L2-normalized vectors
                # We normalized at embedding time, so this is exact cosine sim
                score = float(np.dot(query_embedding, entry.embedding))

                # Seed from the first candidate instead of a 0.0 floor, so that
                # negative similarities are reported faithfully and best_entry
                # is never None below (even when threshold <= 0). Strict '>'
                # keeps the earliest entry on ties.
                if best_entry is None or score > best_score:
                    best_score = score
                    best_entry = entry

            if best_score >= self.threshold:
                self._hit_count += 1
                return best_entry, best_score

            self._miss_count += 1
            return None, best_score

    def store(self, entry: CacheEntry) -> None:
        """
        Add a new entry to the cache under its dominant cluster.

        Args:
            entry: CacheEntry to store
        """
        with self._lock:
            self._store[entry.dominant_cluster].append(entry)

    def flush(self) -> None:
        """
        Wipe all cache entries and reset all stats counters.
        Called by DELETE /cache endpoint.
        """
        with self._lock:
            self._store.clear()
            self._hit_count = 0
            self._miss_count = 0

    def _count_entries_unlocked(self) -> int:
        """Return the total number of stored entries; caller must hold self._lock."""
        return sum(len(bucket) for bucket in self._store.values())

    @property
    def stats(self) -> dict:
        """
        Returns current cache statistics.
        Called by GET /cache/stats endpoint.

        hit_rate is hits / (hits + misses), rounded to 3 decimals, or 0.0
        before any lookup has been made.
        """
        with self._lock:
            total_entries = self._count_entries_unlocked()
            total_queries = self._hit_count + self._miss_count
            hit_rate = (
                round(self._hit_count / total_queries, 3)
                if total_queries > 0 else 0.0
            )
            return {
                "total_entries": total_entries,
                "hit_count":     self._hit_count,
                "miss_count":    self._miss_count,
                "hit_rate":      hit_rate
            }

    @property
    def total_entries(self) -> int:
        """Total number of cached entries across all clusters."""
        with self._lock:
            return self._count_entries_unlocked()

    def get_cluster_sizes(self) -> dict[int, int]:
        """Returns how many entries are in each cluster."""
        with self._lock:
            return {k: len(v) for k, v in self._store.items()}

    def __repr__(self):
        # Take one consistent snapshot under the lock. The unlocked helper is
        # used because self._lock is non-reentrant (calling the total_entries
        # property here would deadlock).
        with self._lock:
            entries = self._count_entries_unlocked()
            hits = self._hit_count
            misses = self._miss_count
        return (
            f"SemanticCache("
            f"threshold={self.threshold}, "
            f"entries={entries}, "
            f"hits={hits}, "
            f"misses={misses})"
        )