Spaces:
Running
Running
File size: 5,083 Bytes
eec9162 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
import time
import threading
import numpy as np
from dataclasses import dataclass, field
from collections import defaultdict
from typing import Optional
@dataclass
class CacheEntry:
    """
    Represents a single cached query-result pair.

    Fields:
        - query: original query string
        - embedding: L2-normalized query embedding (numpy array)
        - result: the answer/document returned for this query
        - dominant_cluster: which GMM cluster this query belongs to
        - timestamp: when this entry was cached (unix time); defaults to
          ``time.time()`` at the moment the entry object is created

    Equality is implemented explicitly. The ``__eq__`` that ``@dataclass``
    would generate compares fields with ``==`` and takes the truth value of
    the result, which raises ``ValueError`` for multi-element numpy arrays.
    Because ``__eq__`` is defined and the dataclass is not frozen, instances
    remain unhashable.
    """

    query: str
    embedding: np.ndarray
    result: str
    dominant_cluster: int
    timestamp: float = field(default_factory=time.time)

    def __eq__(self, other: object) -> bool:
        """
        Compare two entries field by field.

        Args:
            other: object to compare against

        Returns:
            True when ``other`` is of the identical class and every field
            matches (embeddings are compared element-wise with
            ``np.array_equal``); ``NotImplemented`` for any other type so
            Python can fall back to its default comparison.
        """
        if other.__class__ is not self.__class__:
            return NotImplemented
        # Cheap scalar fields first; the array comparison runs last.
        return bool(
            self.query == other.query
            and self.result == other.result
            and self.dominant_cluster == other.dominant_cluster
            and self.timestamp == other.timestamp
            and np.array_equal(self.embedding, other.embedding)
        )
class SemanticCache:
    """
    A cluster-partitioned semantic cache.

    Instead of storing all entries in one flat list,
    entries are grouped by their dominant GMM cluster.

    Lookup flow:
        1. Embed the incoming query
        2. Find its dominant cluster (via GMM)
        3. Only scan entries in THAT cluster
        4. Return the best match if similarity >= threshold

    Steps 1 and 2 happen in the caller: ``lookup`` receives the embedding
    and the cluster index as arguments.

    This means lookup cost is O(n/k) where:
        - n = total cache entries
        - k = number of clusters (10 in our case)

    Every public method takes an internal ``threading.Lock`` while it reads
    or mutates the store and the hit/miss counters.
    """

    def __init__(self, similarity_threshold: float = 0.85):
        """
        Args:
            similarity_threshold: minimum cosine similarity to count as a
                cache hit. Explored in cache/threshold_analysis.py
        """
        self.threshold = similarity_threshold

        # Cluster-partitioned storage
        # Key: cluster_id (int)
        # Value: list of CacheEntry objects
        self._store: dict[int, list["CacheEntry"]] = defaultdict(list)

        # Stats counters
        self._hit_count = 0
        self._miss_count = 0

        # Thread safety
        self._lock = threading.Lock()

    def _entry_count_locked(self) -> int:
        """Return the number of stored entries; caller must hold the lock."""
        return sum(len(bucket) for bucket in self._store.values())

    def lookup(
        self,
        query_embedding: np.ndarray,
        dominant_cluster: int,
    ) -> tuple[Optional["CacheEntry"], float]:
        """
        Search for a semantically similar cached query.

        Only the bucket of ``dominant_cluster`` is scanned. Each call
        increments either the hit counter or the miss counter.

        Args:
            query_embedding: L2-normalized embedding of incoming query
            dominant_cluster: GMM cluster index for this query

        Returns:
            (best_matching_entry, similarity_score) on a hit.
            If no match found: (None, best_score_seen), where the score is
            0.0 when the cluster holds no entries or no comparable score
            was produced.
        """
        with self._lock:
            # Use .get() instead of indexing: indexing a defaultdict inserts
            # an empty bucket for every cluster that is merely looked up,
            # which would then show up in get_cluster_sizes().
            candidates = self._store.get(dominant_cluster)
            if not candidates:
                self._miss_count += 1
                return None, 0.0

            best_entry = None
            # Start below any real similarity so that zero or negative
            # scores are still tracked as "best seen".
            best_score = float("-inf")
            for entry in candidates:
                # Cosine similarity = dot product of L2-normalized vectors
                # We normalized at embedding time, so this is exact cosine sim
                score = float(np.dot(query_embedding, entry.embedding))
                if score > best_score:
                    best_score = score
                    best_entry = entry

            # A hit always carries a real entry; a threshold <= 0 must not
            # turn "nothing selected" into a counted hit.
            if best_entry is not None and best_score >= self.threshold:
                self._hit_count += 1
                return best_entry, best_score

            self._miss_count += 1
            # best_entry is None only if every score was NaN; report 0.0
            # rather than leaking the -inf sentinel to callers.
            return None, (best_score if best_entry is not None else 0.0)

    def store(self, entry: "CacheEntry") -> None:
        """
        Add a new entry to the cache under its dominant cluster.

        Args:
            entry: CacheEntry to store
        """
        with self._lock:
            self._store[entry.dominant_cluster].append(entry)

    def flush(self) -> None:
        """
        Wipe all cache entries and reset all stats counters.
        Called by DELETE /cache endpoint.
        """
        with self._lock:
            self._store.clear()
            self._hit_count = 0
            self._miss_count = 0

    @property
    def stats(self) -> dict:
        """
        Returns current cache statistics.
        Called by GET /cache/stats endpoint.

        Returns:
            dict with keys "total_entries", "hit_count", "miss_count" and
            "hit_rate" (hits / (hits + misses), rounded to 3 decimals;
            0.0 before any lookup has been made).
        """
        with self._lock:
            total_entries = self._entry_count_locked()
            total_queries = self._hit_count + self._miss_count
            hit_rate = (
                round(self._hit_count / total_queries, 3)
                if total_queries > 0 else 0.0
            )
            return {
                "total_entries": total_entries,
                "hit_count": self._hit_count,
                "miss_count": self._miss_count,
                "hit_rate": hit_rate,
            }

    @property
    def total_entries(self) -> int:
        """Total number of entries across all clusters."""
        with self._lock:
            return self._entry_count_locked()

    def get_cluster_sizes(self) -> dict[int, int]:
        """Returns how many entries are in each cluster."""
        with self._lock:
            return {k: len(v) for k, v in self._store.items()}

    def __repr__(self) -> str:
        # Take one locked snapshot so entries/hits/misses are mutually
        # consistent instead of being read at three different moments.
        snapshot = self.stats
        return (
            f"SemanticCache("
            f"threshold={self.threshold}, "
            f"entries={snapshot['total_entries']}, "
            f"hits={snapshot['hit_count']}, "
            f"misses={snapshot['miss_count']})"
        )