# DEPENDENCIES
import numpy as np
from typing import List
from typing import Callable
from typing import Optional
from numpy.typing import NDArray
from config.settings import get_settings
from config.logging_config import get_logger
from utils.error_handler import handle_errors
from utils.cache_manager import EmbeddingCache as BaseEmbeddingCache


# Setup Settings and Logging
settings = get_settings()
logger   = get_logger(__name__)


class EmbeddingCache:
    """
    Embedding cache with numpy array support and statistics: Wraps the base cache with embedding-specific features
    """
    def __init__(self, max_size: Optional[int] = None, ttl: Optional[int] = None):
        """
        Initialize embedding cache
        
        Arguments:
        ----------
            max_size { int } : Maximum cache size (defaults to settings.CACHE_MAX_SIZE)
            
            ttl      { int } : Time to live in seconds (defaults to settings.CACHE_TTL)
        """
        self.logger               = logger
        self.max_size             = max_size or settings.CACHE_MAX_SIZE
        self.ttl                  = ttl or settings.CACHE_TTL
        
        # Initialize base cache
        self.base_cache           = BaseEmbeddingCache(max_size = self.max_size,
                                                       ttl      = self.ttl,
                                                      )
        
        # Enhanced statistics
        self.hits                 = 0
        self.misses               = 0
        self.embeddings_generated = 0
        
        self.logger.info(f"Initialized EmbeddingCache: max_size={self.max_size}, ttl={self.ttl}")
    

    def get_embedding(self, text: str) -> Optional[NDArray]:
        """
        Get embedding from cache
        
        Arguments:
        ----------
            text { str }   : Input text
        
        Returns:
        --------
               { NDArray } : Cached embedding or None
        """
        cached = self.base_cache.get_embedding(text)
        
        if cached is not None:
            self.hits += 1
            
            # Convert list back to numpy array
            return np.array(cached)
        
        else:
            self.misses += 1
            return None
    

    def set_embedding(self, text: str, embedding: NDArray):
        """
        Store embedding in cache
        
        Arguments:
        ----------
            text      { str }     : Input text

            embedding { NDArray } : Embedding vector
        """
        # Convert numpy array to list for serialization
        embedding_list             = embedding.tolist()
        
        self.base_cache.set_embedding(text, embedding_list)

        self.embeddings_generated += 1
    

    def batch_get_embeddings(self, texts: List[str]) -> tuple[List[Optional[NDArray]], List[str]]:
        """
        Get multiple embeddings from cache
        
        Arguments:
        ----------
            texts { list } : List of texts
        
        Returns:
        --------
             { tuple }     : Tuple of (cached_embeddings, missing_texts); cached_embeddings contains None for each cache miss
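        
        Example:
        --------
            Illustrative behaviour (not a doctest): if only "a" is already cached,
            batch_get_embeddings(["a", "b"]) returns ([<embedding of "a">, None], ["b"]).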
        """
        cached_embeddings = list()
        missing_texts     = list()
        
        for text in texts:
            embedding = self.get_embedding(text)
            
            if embedding is not None:
                cached_embeddings.append(embedding)
            
            else:
                missing_texts.append(text)
                cached_embeddings.append(None)
        
        return cached_embeddings, missing_texts
    

    def batch_set_embeddings(self, texts: List[str], embeddings: List[NDArray]):
        """
        Store multiple embeddings in cache
        
        Arguments:
        ----------
            texts      { list } : List of texts

            embeddings { list } : List of embedding vectors
        """
        if (len(texts) != len(embeddings)):
            raise ValueError("Texts and embeddings must have same length")
        
        for text, embedding in zip(texts, embeddings):
            self.set_embedding(text, embedding)
    

    def get_cached_embeddings(self, texts: List[str], embed_function: Callable, batch_size: Optional[int] = None) -> List[NDArray]:
        """
        Smart embedding getter: uses cache for existing, generates for missing
        
        Arguments:
        ----------
            texts          { list }     : List of texts

            embed_function { callable } : Function to generate embeddings for missing texts
            
            batch_size     { int }      : Batch size for generation
        
        Returns:
        --------
                       { list }         : List of embeddings
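        
        Example:
        --------
            Illustrative sketch only: `fake_embed` is a hypothetical stand-in for a
            real embedding function that returns one numpy vector per input text.
            
            >>> def fake_embed(texts, batch_size = None):
            ...     return [np.zeros(3) for _ in texts]
            >>> cache   = EmbeddingCache()
            >>> vectors = cache.get_cached_embeddings(["hello", "world"], fake_embed)   # generated via fake_embed
            >>> vectors = cache.get_cached_embeddings(["hello", "world"], fake_embed)   # served from cache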
        """
        # Get cached embeddings
        cached_embeddings, missing_texts = self.batch_get_embeddings(texts = texts)
        
        if not missing_texts:
            self.logger.debug(f"All {len(texts)} embeddings found in cache")
            return cached_embeddings
        
        # Generate missing embeddings
        self.logger.info(f"Generating {len(missing_texts)} embeddings ({(len(missing_texts)/len(texts))*100:.1f}% cache miss)")
        
        missing_embeddings               = embed_function(missing_texts, batch_size = batch_size)
        
        # Store new embeddings in cache
        self.batch_set_embeddings(missing_texts, missing_embeddings)
        
        # Combine results
        result_embeddings = list()
        missing_idx       = 0
        
        for emb in cached_embeddings:
            if emb is not None:
                result_embeddings.append(emb)
            
            else:
                result_embeddings.append(missing_embeddings[missing_idx])
                missing_idx += 1
        
        return result_embeddings
    

    def clear(self):
        """
        Clear entire cache
        """
        self.base_cache.clear()
        self.hits                  = 0
        self.misses                = 0
        self.embeddings_generated  = 0
        
        self.logger.info("Cleared embedding cache")
    

    def get_stats(self) -> dict:
        """
        Get cache statistics
        
        Returns:
        --------
            { dict }    : Statistics dictionary
        """
        base_stats     = self.base_cache.get_stats()
        
        total_requests = self.hits + self.misses
        hit_rate       = (self.hits / total_requests * 100) if (total_requests > 0) else 0
        
        stats = {**base_stats,
                 "hits"                  : self.hits,
                 "misses"                : self.misses,
                 "hit_rate_percentage"   : hit_rate,
                 "embeddings_generated"  : self.embeddings_generated,
                 "cache_size"            : self.base_cache.cache.size(),
                 "max_size"              : self.max_size,
                }
        
        return stats
    

    def save_to_file(self, file_path: str) -> bool:
        """
        Save cache to file
        
        Arguments:
        ----------
            file_path { str } : Path to save file
        
        Returns:
        --------
               { bool }       : True if successful
        """
        return self.base_cache.save_to_file(file_path)
    

    def load_from_file(self, file_path: str) -> bool:
        """
        Load cache from file
        
        Arguments:
        ----------
            file_path { str } : Path to load file
        
        Returns:
        --------
               { bool }       : True if successful
        """
        return self.base_cache.load_from_file(file_path)
    

    def warm_cache(self, texts: List[str], embed_function: Callable, batch_size: Optional[int] = None):
        """
        Pre-populate cache with embeddings
        
        Arguments:
        ----------
            texts          { list }     : List of texts to warm cache with

            embed_function { callable } : Embedding generation function
            
            batch_size     { int }      : Batch size
        """
        # Check which texts are not in cache
        _, missing_texts = self.batch_get_embeddings(texts = texts)
        
        if not missing_texts:
            self.logger.info("Cache already warm for all texts")
            return
        
        self.logger.info(f"Warming cache with {len(missing_texts)} embeddings")
        
        # Generate and cache embeddings
        embeddings = embed_function(missing_texts, batch_size = batch_size)

        self.batch_set_embeddings(missing_texts, embeddings)
        
        self.logger.info(f"Cache warming complete: added {len(missing_texts)} embeddings")


# Global embedding cache instance
_embedding_cache: Optional[EmbeddingCache] = None


def get_embedding_cache() -> EmbeddingCache:
    """
    Get global embedding cache instance
    
    Returns:
    --------
        { EmbeddingCache } : EmbeddingCache instance
    """
    global _embedding_cache

    if _embedding_cache is None:
        _embedding_cache = EmbeddingCache()
    
    return _embedding_cache


def cache_embeddings(texts: List[str], embeddings: List[NDArray]):
    """
    Convenience function to cache embeddings
    
    Arguments:
    ----------
        texts      { list } : List of texts

        embeddings { list } : List of embeddings
    """
    cache = get_embedding_cache()

    cache.batch_set_embeddings(texts, embeddings)


def get_cached_embeddings(texts: List[str], embed_function: Callable, **kwargs) -> List[NDArray]:
    """
    Convenience function to get cached embeddings
    
    Arguments:
    ----------
        texts          { list }     : List of texts

        embed_function { callable } : Embedding function
        
        **kwargs                    : Additional keyword arguments forwarded to EmbeddingCache.get_cached_embeddings (e.g. batch_size)
    
    Returns:
    --------
                 { list }          : List of embeddings
    """
    cache = get_embedding_cache()

    return cache.get_cached_embeddings(texts, embed_function, **kwargs)
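

# -----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the public API):
# `_demo_embed` is a hypothetical stand-in for a real embedding function and
# simply returns random vectors. Running this module directly exercises the
# cache twice and prints the statistics exposed by get_stats().
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    def _demo_embed(texts: List[str], batch_size: Optional[int] = None) -> List[NDArray]:
        # Hypothetical embedding function: one random 8-dimensional vector per text
        return [np.random.rand(8) for _ in texts]

    sample_texts = ["first document", "second document", "first document"]

    # First pass generates and caches the embeddings
    first_pass   = get_cached_embeddings(sample_texts, _demo_embed)

    # Second pass should be served entirely from the cache
    second_pass  = get_cached_embeddings(sample_texts, _demo_embed)

    print(f"Embeddings returned: {len(first_pass)} / {len(second_pass)}")
    print(get_embedding_cache().get_stats())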