File size: 7,992 Bytes
eb731f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
"""Telecom RAG - Semantic Query Cache

Implements semantic caching for query deduplication:
- Uses embedding similarity to detect similar queries
- Cache hits return previous answers without re-running the pipeline
- Reduces latency and LLM costs

Per architecture doc Section 7.2: Similarity threshold 0.95
"""

from typing import Dict, Any, Optional, List, Tuple
import numpy as np
from datetime import datetime, timedelta



import json
import numpy as np
import pickle
from datetime import datetime, timedelta
from typing import Dict, Any, Optional, List, Tuple
import os

class SemanticCache:
    """
    Semantic cache with Redis persistence.
    - Vectors: Kept in memory (numpy) for fast cosine similarity search.
    - Payloads: Stored in Redis (hash) for persistence and sharing.
    - Persistence: On startup, vectors are loaded from Redis into memory.
    """
    
    def __init__(
        self, 
        similarity_threshold: float = 0.95,
        max_cache_size: int = 1000,
        ttl_hours: int = 24,
        redis_url: str = None
    ):
        """
        Initialize semantic cache with Redis.

        Args:
            similarity_threshold: Minimum similarity for cache hit
            redis_url: Connection string for Redis
        """
        from .config import ENABLE_REDIS

        self.similarity_threshold = similarity_threshold
        self.ttl = timedelta(hours=ttl_hours)
        self.redis = None
        self.local_cache = []

        # In-memory vector index: List of (embedding, redis_key)
        self.vector_index: List[Tuple[np.ndarray, str]] = []

        if ENABLE_REDIS:
            # Get Redis URL from env or use provided value
            redis_url = redis_url or os.getenv("REDIS_URL", "redis://localhost:6379/0")

            try:
                import redis
                self.redis = redis.from_url(
                    redis_url,
                    decode_responses=False,
                    socket_connect_timeout=1,
                    socket_timeout=1
                )  # Bytes for vectors
                self.redis.ping()
                print("✅ Connected to Redis cache")
                self._load_index_from_redis()
            except Exception as e:
                print(f"⚠️ Redis unavailable: {e}")
                print("   Using in-memory only cache (will be lost on restart)")
                self.redis = None
        else:
            print("ℹ️ Redis disabled via config (ENABLE_REDIS=False)") 

    def _load_index_from_redis(self):
        """Load all cached vectors from Redis into memory."""
        if not self.redis:
            return
            
        try:
            # Keys pattern: "cache:vector:*"
            keys = self.redis.keys("cache:vector:*")
            count = 0
            for key in keys:
                # Key is bytes, decode to str
                key_str = key.decode("utf-8")
                # Get vector (bytes -> numpy)
                vector_bytes = self.redis.get(key)
                if vector_bytes:
                    vector = pickle.loads(vector_bytes)
                    # Extract ID from key: cache:vector:<uuid>
                    cache_id = key_str.split(":")[-1]
                    self.vector_index.append((vector, cache_id))
                    count += 1
            print(f"⚡ Loaded {count} vectors from Redis cache")
        except Exception as e:
            print(f"⚠️ Failed to sync with Redis: {e}")

    def _cosine_similarity(self, a: np.ndarray, b: np.ndarray) -> float:
        norm_a = np.linalg.norm(a)
        norm_b = np.linalg.norm(b)
        if norm_a == 0 or norm_b == 0:
            return 0.0
        return float(np.dot(a, b) / (norm_a * norm_b))
    
    def get(self, query_embedding: List[float]) -> Optional[Dict[str, Any]]:
        """Get cached response if similar query exists."""
        query_emb = np.array(query_embedding)

        # 1. Search in-memory vector index
        best_sim = -1.0
        best_id = None
        best_idx = None

        for idx, (cached_emb, cache_id) in enumerate(self.vector_index):
            sim = self._cosine_similarity(query_emb, cached_emb)
            if sim > best_sim:
                best_sim = sim
                best_id = cache_id
                best_idx = idx

        # 2. Check threshold
        if best_sim >= self.similarity_threshold and best_id is not None:
            # 3. Retrieve payload from Redis or local
            if self.redis:
                payload_json = self.redis.get(f"cache:payload:{best_id}")
                if payload_json:
                    return json.loads(payload_json)
            else:
                # Local fallback using tracked index
                if best_idx is not None and best_idx < len(self.local_cache):
                    return self.local_cache[best_idx]['payload']

        return None

    def set(self, query_embedding: List[float], query: str, response: Dict[str, Any]):
        """Cache a response."""
        import uuid
        cache_id = str(uuid.uuid4())
        
        # Helper to serialize datetime/numpy in JSON
        def json_serial(obj):
            if isinstance(obj, (datetime, datetime.date)):
                return obj.isoformat()
            if isinstance(obj, np.ndarray):
                return obj.tolist()
            return str(obj)

        vector_np = np.array(query_embedding)
        
        if self.redis:
            try:
                # 1. Store Vector (for reload)
                self.redis.setex(
                    f"cache:vector:{cache_id}",
                    self.ttl,
                    pickle.dumps(vector_np)
                )
                
                # 2. Store Payload
                self.redis.setex(
                    f"cache:payload:{cache_id}", 
                    self.ttl,
                    json.dumps(response, default=json_serial)
                )
                
                # 3. Add to local index
                self.vector_index.append((vector_np, cache_id))
                
            except Exception as e:
                print(f"⚠️ Redis cache set failed: {e}")
        else:
            # Fallback
            if len(self.local_cache) >= 1000:
                self.local_cache.pop(0)
                self.vector_index.pop(0)
            
            self.local_cache.append({'payload': response})
            self.vector_index.append((vector_np, cache_id)) # ID doesn't matter much here
    
    def get_stats(self) -> Dict[str, Any]:
        """Get cache statistics."""
        return {
            "cached_queries": len(self.vector_index),
            "backend": "redis" if self.redis else "in-memory",
            "similarity_threshold": self.similarity_threshold,
        }

    def clear(self):
        if self.redis:
            self.redis.flushdb()
        self.vector_index = []
        self.local_cache = []


# Global instance
_cache_instance: Optional[SemanticCache] = None


def get_cache() -> SemanticCache:
    """Get or create global cache instance."""
    global _cache_instance
    if _cache_instance is None:
        _cache_instance = SemanticCache()
    return _cache_instance


if __name__ == "__main__":
    # Test cache
    cache = SemanticCache(similarity_threshold=0.9)
    
    # Simulate embeddings (random vectors)
    embedding1 = np.random.rand(384).tolist()
    embedding2 = np.random.rand(384).tolist()
    embedding1_similar = (np.array(embedding1) * 0.99 + np.random.rand(384) * 0.01).tolist()
    
    # Test cache miss
    result = cache.get(embedding1)
    print(f"Cache miss: {result}")
    
    # Add to cache
    cache.set(embedding1, "What is HARQ?", {"answer": "HARQ is..."})
    
    # Test cache hit with similar embedding
    result = cache.get(embedding1_similar)
    print(f"Cache hit: {result}")
    
    # Test cache miss with different embedding
    result = cache.get(embedding2)
    print(f"Cache miss different: {result}")
    
    print(f"\nStats: {cache.get_stats()}")