File size: 6,764 Bytes
0231daa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
"""
Simple in-memory caching layer for embeddings.

This module provides an LRU cache for embedding results to reduce
redundant computations for identical requests.
"""

import hashlib
import json
import time
from typing import Any, Dict, List, Optional, Union
from collections import OrderedDict
from threading import Lock
from loguru import logger


class EmbeddingCache:
    """
    Thread-safe LRU cache for embedding results.

    Entries are keyed by a SHA256 digest of the request parameters and
    carry the wall-clock time at which they were stored. An entry is
    dropped lazily on lookup (or eagerly by ``cleanup_expired``) once it
    is older than ``ttl`` seconds, and the least recently used entry is
    evicted when a new key would push the cache past ``max_size``.

    Mutating operations and lookups run under a single non-reentrant
    lock. The read-only ``size``, ``hit_rate`` and ``stats`` accessors
    deliberately do not take the lock (``get`` reads ``hit_rate`` while
    holding it), so their values are best-effort snapshots.

    Attributes:
        max_size: Maximum number of entries in the cache. A value <= 0
            disables storage entirely (``set`` becomes a no-op).
        ttl: Time-to-live in seconds for cached entries
        _cache: OrderedDict storing cached entries, oldest first
        _lock: Threading lock for thread-safety
        _hits: Number of cache hits
        _misses: Number of cache misses
    """

    def __init__(self, max_size: int = 1000, ttl: int = 3600):
        """
        Initialize the embedding cache.

        Args:
            max_size: Maximum number of entries (default: 1000)
            ttl: Time-to-live in seconds (default: 3600 = 1 hour)
        """
        self.max_size = max_size
        self.ttl = ttl
        self._cache: OrderedDict[str, Dict[str, Any]] = OrderedDict()
        self._lock = Lock()
        self._hits = 0
        self._misses = 0

        logger.info(f"Initialized embedding cache (max_size={max_size}, ttl={ttl}s)")

    def _generate_key(
        self,
        texts: Union[str, List[str]],
        model_id: str,
        prompt: Optional[str] = None,
        **kwargs,
    ) -> str:
        """
        Generate a unique cache key for the request.

        A single string is normalized to a one-element list, so
        ``"a"`` and ``["a"]`` map to the same key. Keyword arguments are
        sorted by name, so their order does not affect the key.

        Args:
            texts: Single text or list of texts
            model_id: Model identifier
            prompt: Optional prompt
            **kwargs: Additional parameters (values must be JSON-serializable)

        Returns:
            SHA256 hex digest of the request parameters

        Raises:
            TypeError: If any parameter cannot be serialized to JSON
        """
        # Normalize texts to list
        if isinstance(texts, str):
            texts = [texts]

        # Create deterministic representation
        cache_dict = {
            "texts": texts,
            "model_id": model_id,
            "prompt": prompt,
            "kwargs": sorted(kwargs.items()) if kwargs else [],
        }

        # Generate hash
        cache_str = json.dumps(cache_dict, sort_keys=True)
        return hashlib.sha256(cache_str.encode()).hexdigest()

    def get(
        self,
        texts: Union[str, List[str]],
        model_id: str,
        prompt: Optional[str] = None,
        **kwargs,
    ) -> Optional[Any]:
        """
        Retrieve a cached embedding result.

        A hit marks the entry as most recently used. An expired entry is
        removed and counted as a miss. Because ``None`` signals a miss, a
        stored result of ``None`` cannot be told apart from an absent one.

        Args:
            texts: Single text or list of texts
            model_id: Model identifier
            prompt: Optional prompt
            **kwargs: Additional parameters

        Returns:
            Cached result if found and not expired, None otherwise
        """
        key = self._generate_key(texts, model_id, prompt, **kwargs)

        with self._lock:
            # Stored entries are always dicts, so None reliably means "absent".
            entry = self._cache.get(key)
            if entry is None:
                self._misses += 1
                return None

            # Check if expired
            if time.time() - entry["timestamp"] > self.ttl:
                del self._cache[key]
                self._misses += 1
                logger.debug(f"Cache entry expired: {key[:8]}...")
                return None

            # Move to end (LRU)
            self._cache.move_to_end(key)
            self._hits += 1

            logger.debug(f"Cache hit: {key[:8]}... (hit_rate={self.hit_rate:.2%})")

            return entry["result"]

    def set(
        self,
        texts: Union[str, List[str]],
        model_id: str,
        result: Any,
        prompt: Optional[str] = None,
        **kwargs,
    ) -> None:
        """
        Store an embedding result in the cache.

        Overwriting an existing key replaces its value, resets its
        timestamp and marks it as most recently used without evicting
        any other entry. Inserting a new key evicts least recently used
        entries until there is room. When ``max_size`` is <= 0 nothing
        is stored.

        Args:
            texts: Single text or list of texts
            model_id: Model identifier
            result: Embedding result to cache
            prompt: Optional prompt
            **kwargs: Additional parameters
        """
        key = self._generate_key(texts, model_id, prompt, **kwargs)

        with self._lock:
            # A non-positive capacity means caching is disabled; without this
            # guard the eviction step would run against an empty cache.
            if self.max_size <= 0:
                return

            new_entry = {"result": result, "timestamp": time.time()}

            if key in self._cache:
                # Replacing a value does not grow the cache, so no eviction
                # is needed. Plain assignment keeps the old position in an
                # OrderedDict, hence the explicit move to the MRU end.
                self._cache[key] = new_entry
                self._cache.move_to_end(key)
            else:
                # Evict oldest entries until there is room. A loop (rather
                # than a single eviction) also copes with max_size having
                # been lowered after entries were stored.
                while len(self._cache) >= self.max_size:
                    oldest_key, _ = self._cache.popitem(last=False)
                    logger.debug(f"Cache full, evicted: {oldest_key[:8]}...")

                # Store new entry
                self._cache[key] = new_entry

            logger.debug(
                f"Cache set: {key[:8]}... (size={len(self._cache)}/{self.max_size})"
            )

    def clear(self) -> None:
        """Clear all cached entries and reset the hit/miss counters."""
        with self._lock:
            count = len(self._cache)
            self._cache.clear()
            self._hits = 0
            self._misses = 0
            logger.info(f"Cleared {count} cache entries")

    def cleanup_expired(self) -> int:
        """
        Remove all expired entries from the cache.

        Returns:
            Number of entries removed
        """
        with self._lock:
            current_time = time.time()
            # Collect first, delete afterwards: the dict must not change
            # size while it is being iterated.
            expired_keys = [
                key
                for key, entry in self._cache.items()
                if current_time - entry["timestamp"] > self.ttl
            ]

            for key in expired_keys:
                del self._cache[key]

            if expired_keys:
                logger.info(f"Cleaned up {len(expired_keys)} expired cache entries")

            return len(expired_keys)

    @property
    def size(self) -> int:
        """Get current number of cached entries (including not-yet-purged expired ones)."""
        return len(self._cache)

    @property
    def hit_rate(self) -> float:
        """
        Calculate cache hit rate.

        Does not acquire the lock, so it is safe to read while the lock
        is already held (as ``get`` does).

        Returns:
            Hit rate as a float between 0 and 1 (0.0 before any lookup)
        """
        total = self._hits + self._misses
        if total == 0:
            return 0.0
        return self._hits / total

    @property
    def stats(self) -> Dict[str, Any]:
        """
        Get cache statistics.

        Returns:
            Dictionary with ``size``, ``max_size``, ``hits``, ``misses``,
            ``hit_rate`` (formatted as a percentage string) and ``ttl``
        """
        return {
            "size": self.size,
            "max_size": self.max_size,
            "hits": self._hits,
            "misses": self._misses,
            "hit_rate": f"{self.hit_rate:.2%}",
            "ttl": self.ttl,
        }

    def __repr__(self) -> str:
        """String representation of the cache."""
        return (
            f"EmbeddingCache("
            f"size={self.size}/{self.max_size}, "
            f"hits={self._hits}, "
            f"misses={self._misses}, "
            f"hit_rate={self.hit_rate:.2%})"
        )