File size: 9,718 Bytes
7644eac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
"""
Semantic caching layer using Redis for fast query result retrieval.

Semantic caching stores query embeddings and their results, allowing instant
responses for similar queries without re-running the expensive RAG pipeline.
"""
import os
import json
import hashlib
from typing import Optional, Any, Dict
import numpy as np

# Optional imports
try:
    import redis
    REDIS_AVAILABLE = True
except ImportError:
    REDIS_AVAILABLE = False
    print("⚠️  Redis not installed. Install with: pip install redis")

try:
    from openai import OpenAI
    OPENAI_AVAILABLE = True
except ImportError:
    OPENAI_AVAILABLE = False

try:
    from src.utils.config import EMBEDDING_MODEL
except Exception:
    EMBEDDING_MODEL = "text-embedding-3-small"

class SemanticCache:
    """
    Redis-based semantic cache for RAG queries.

    Instead of exact string matching, this cache uses embedding similarity
    to find semantically similar queries and return cached results. Each
    entry lives under a ``semantic_cache:<md5(query)>`` key and stores a
    JSON blob with the original query text, its embedding vector, and the
    cached result; entries expire after ``ttl`` seconds.
    """

    # Key-namespace pattern shared by get()/set()/clear()/get_stats() so the
    # cache can be enumerated and cleared without touching foreign keys.
    _KEY_PATTERN = "semantic_cache:*"

    def __init__(
        self,
        redis_url: Optional[str] = None,
        redis_host: str = "localhost",
        redis_port: int = 6379,
        redis_password: Optional[str] = None,
        redis_db: int = 0,
        ttl: int = 3600,
        similarity_threshold: float = 0.95,
        openai_api_key: Optional[str] = None
    ):
        """
        Initialize semantic cache.

        Args:
            redis_url: Full Redis URL (e.g. redis://... or rediss://... for SSL)
            redis_host: Redis server host (used if redis_url not provided)
            redis_port: Redis server port (used if redis_url not provided)
            redis_password: Redis password (used if redis_url not provided)
            redis_db: Redis database number (used if redis_url not provided)
            ttl: Time-to-live for cached entries (seconds)
            similarity_threshold: Minimum cosine similarity for cache hit
            openai_api_key: OpenAI API key for embeddings (falls back to the
                OPENAI_API_KEY environment variable)
        """
        self.ttl = ttl
        self.similarity_threshold = similarity_threshold
        self.redis_client = None
        self.openai_client = None
        # Caching is only enabled once BOTH Redis and OpenAI are reachable.
        self.enabled = False

        # --- Redis ---
        if REDIS_AVAILABLE:
            try:
                # Prefer REDIS_URL if provided and valid (for Upstash, Render, etc.)
                if redis_url and redis_url.strip() and redis_url.startswith(('redis://', 'rediss://', 'unix://')):
                    url_kwargs: Dict[str, Any] = {"decode_responses": False}
                    # ssl_cert_reqs is only a valid kwarg for SSL connections;
                    # passing it with a plain redis:// or unix:// URL makes
                    # redis-py raise TypeError and would silently disable
                    # caching. Gate it on the rediss:// scheme.
                    if redis_url.startswith('rediss://'):
                        # Accept self-signed certs for Upstash
                        url_kwargs["ssl_cert_reqs"] = None
                    self.redis_client = redis.from_url(redis_url, **url_kwargs)
                    print("✅ Connecting to Redis via URL...")
                else:
                    self.redis_client = redis.Redis(
                        host=redis_host,
                        port=redis_port,
                        password=redis_password if redis_password else None,
                        db=redis_db,
                        decode_responses=False
                    )
                    print(f"✅ Connecting to Redis at {redis_host}:{redis_port}...")

                # Test connection eagerly so failures surface here, not on first get().
                self.redis_client.ping()
                self.enabled = True
                print("✅ Semantic cache initialized successfully")
            except Exception as e:
                print(f"⚠️  Redis connection failed: {e}. Caching disabled.")
                self.redis_client = None
        else:
            print("⚠️  Redis not available. Caching disabled.")

        # --- OpenAI (embeddings) ---
        if OPENAI_AVAILABLE:
            api_key = openai_api_key or os.getenv("OPENAI_API_KEY")
            if api_key:
                self.openai_client = OpenAI(api_key=api_key)

        if not self.openai_client:
            print("⚠️  OpenAI not available. Semantic caching disabled.")
            self.enabled = False

    def _get_embedding(self, text: str) -> Optional[np.ndarray]:
        """
        Get embedding vector for text.

        Args:
            text: Text to embed

        Returns:
            Embedding vector or None if failed (no client, or API error)
        """
        if not self.openai_client:
            return None

        try:
            response = self.openai_client.embeddings.create(
                model=EMBEDDING_MODEL,
                input=text
            )
            embedding = np.array(response.data[0].embedding)
            return embedding
        except Exception as e:
            # Embedding failures are non-fatal: callers treat None as a miss.
            print(f"⚠️  Embedding generation failed: {e}")
            return None

    def _cosine_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float:
        """
        Calculate cosine similarity between two vectors.

        Args:
            vec1: First vector
            vec2: Second vector

        Returns:
            Cosine similarity (0.0 when either vector has zero norm)
        """
        dot_product = np.dot(vec1, vec2)
        norm1 = np.linalg.norm(vec1)
        norm2 = np.linalg.norm(vec2)

        # Guard against division by zero for degenerate (all-zero) vectors.
        if norm1 == 0 or norm2 == 0:
            return 0.0

        return float(dot_product / (norm1 * norm2))

    def get(self, query: str, threshold: Optional[float] = None) -> Optional[Any]:
        """
        Get cached result for a query.

        Performs a linear scan over all cached entries, comparing embedding
        similarity; returns the best match at or above the threshold.

        Args:
            query: Search query text
            threshold: Optional override for similarity threshold

        Returns:
            Cached result or None if not found
        """
        if not self.enabled or not self.redis_client:
            return None

        current_threshold = threshold if threshold is not None else self.similarity_threshold

        try:
            # Get query embedding
            query_embedding = self._get_embedding(query)
            if query_embedding is None:
                return None

            best_similarity = 0.0
            best_result = None

            # Iterate using SCAN to avoid blocking Redis on large keyspaces
            for key in self.redis_client.scan_iter(match=self._KEY_PATTERN):
                try:
                    cached_data = self.redis_client.get(key)
                    if not cached_data:
                        # Key expired between SCAN and GET; skip it.
                        continue

                    cached_dict = json.loads(cached_data.decode("utf-8"))
                    cached_embedding = np.array(cached_dict["embedding"])

                    similarity = self._cosine_similarity(query_embedding, cached_embedding)

                    if similarity > best_similarity and similarity >= current_threshold:
                        best_similarity = similarity
                        best_result = cached_dict["result"]

                except Exception:
                    # A single corrupt entry must not abort the whole lookup.
                    continue

            if best_result is not None:
                print(f"💰 Cache hit! Similarity: {best_similarity:.3f}")
                return best_result

            return None

        except Exception as e:
            print(f"⚠️  Cache get failed: {e}")
            return None

    def set(self, query: str, result: Any) -> bool:
        """
        Cache a query result.

        Args:
            query: Search query
            result: Result to cache (must be JSON-serializable)

        Returns:
            Success status
        """
        if not self.enabled or not self.redis_client:
            return False

        try:
            # Get query embedding
            query_embedding = self._get_embedding(query)
            if query_embedding is None:
                return False

            # Create cache key (md5 used purely for key derivation, not security)
            query_hash = hashlib.md5(query.encode()).hexdigest()
            cache_key = f"semantic_cache:{query_hash}"

            # Prepare cache data
            cache_data = {
                'query': query,
                'embedding': query_embedding.tolist(),
                'result': result
            }

            # Store in Redis with TTL
            self.redis_client.setex(
                cache_key,
                self.ttl,
                json.dumps(cache_data).encode('utf-8')
            )

            print(f"💾 Cached query: '{query}' (TTL: {self.ttl}s)")
            return True

        except Exception as e:
            print(f"⚠️  Cache set failed: {e}")
            return False

    def clear(self) -> int:
        """
        Clear all cached entries.

        Uses SCAN with batched deletes (instead of the blocking KEYS command)
        so large keyspaces don't stall the Redis server — consistent with get().

        Returns:
            Number of entries cleared
        """
        if not self.enabled or not self.redis_client:
            return 0

        try:
            count = 0
            batch = []
            for key in self.redis_client.scan_iter(match=self._KEY_PATTERN):
                batch.append(key)
                # Delete in chunks to bound the size of each DEL command.
                if len(batch) >= 500:
                    count += self.redis_client.delete(*batch)
                    batch = []
            if batch:
                count += self.redis_client.delete(*batch)

            if count:
                print(f"🗑️  Cleared {count} cache entries")
            return count
        except Exception as e:
            print(f"⚠️  Cache clear failed: {e}")
            return 0

    def get_stats(self) -> Dict[str, Any]:
        """
        Get cache statistics.

        Returns:
            Dictionary with cache stats; {"enabled": False} when caching is
            off, optionally with an "error" key if Redis could not be queried.
        """
        if not self.enabled or not self.redis_client:
            return {"enabled": False}

        try:
            # Count via SCAN rather than the blocking KEYS command.
            entries = sum(1 for _ in self.redis_client.scan_iter(match=self._KEY_PATTERN))
            return {
                "enabled": True,
                "entries": entries,
                "ttl": self.ttl,
                "similarity_threshold": self.similarity_threshold
            }
        except Exception as e:
            return {"enabled": False, "error": str(e)}