fahmiaziz98 committed
Commit d9f3e5d · 1 parent: f1d93c7

[DELETED]: embedding cache

src/api/dependencies.py CHANGED
@@ -6,16 +6,13 @@ route handlers, ensuring consistent access to shared resources.
 """
 
 from typing import Optional
-from fastapi import Depends, HTTPException, status
+from fastapi import HTTPException, status
 
-from src.config.settings import Settings, get_settings
 from src.core.manager import ModelManager
-from src.core.cache import EmbeddingCache
 
 
 # Global instances (initialized at startup)
 _model_manager: Optional[ModelManager] = None
-_embedding_cache: Optional[EmbeddingCache] = None
 
 
 def set_model_manager(manager: ModelManager) -> None:
@@ -31,19 +28,6 @@ def set_model_manager(manager: ModelManager) -> None:
     _model_manager = manager
 
 
-def set_embedding_cache(cache: EmbeddingCache) -> None:
-    """
-    Set the global embedding cache instance.
-
-    Called during application startup if caching is enabled.
-
-    Args:
-        cache: EmbeddingCache instance
-    """
-    global _embedding_cache
-    _embedding_cache = cache
-
-
 def get_model_manager() -> ModelManager:
     """
     Get the model manager instance.
@@ -72,29 +56,3 @@ def get_model_manager() -> ModelManager:
         )
     return _model_manager
 
-
-def get_embedding_cache() -> Optional[EmbeddingCache]:
-    """
-    Get the embedding cache instance (if enabled).
-
-    Returns:
-        EmbeddingCache instance or None if caching is disabled
-    """
-    return _embedding_cache
-
-
-def get_cache_if_enabled(
-    settings: Settings = Depends(get_settings),
-) -> Optional[EmbeddingCache]:
-    """
-    Get cache only if caching is enabled in settings.
-
-    Args:
-        settings: Application settings
-
-    Returns:
-        EmbeddingCache instance if enabled, None otherwise
-    """
-    if settings.ENABLE_CACHE:
-        return _embedding_cache
-    return None
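Note: after this commit, dependencies.py wires only the model manager. A minimal sketch of the remaining wiring, assuming a standard FastAPI lifespan hook (the hook, the ModelManager() call, and the /healthz route are illustrative assumptions; only set_model_manager and get_model_manager come from the file above):

# Sketch only: startup wiring after the cache removal. The lifespan hook and
# the ModelManager() constructor arguments are assumptions, not shown in this
# diff; previously startup would also have called set_embedding_cache(...).
from contextlib import asynccontextmanager

from fastapi import Depends, FastAPI

from src.api.dependencies import get_model_manager, set_model_manager
from src.core.manager import ModelManager


@asynccontextmanager
async def lifespan(app: FastAPI):
    set_model_manager(ModelManager())
    yield


app = FastAPI(lifespan=lifespan)


@app.get("/healthz")  # hypothetical route, for illustration only
async def healthz(manager: ModelManager = Depends(get_model_manager)):
    # get_model_manager raises an HTTPException if startup never set a manager
    return {"status": "ok"}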
 
src/api/routers/embedding.py CHANGED
@@ -17,14 +17,13 @@ from src.models.schemas import (
     SparseEmbedding,
 )
 from src.core.manager import ModelManager
-from src.core.cache import EmbeddingCache
 from src.core.exceptions import (
     ModelNotFoundError,
     ModelNotLoadedError,
     EmbeddingGenerationError,
     ValidationError,
 )
-from src.api.dependencies import get_model_manager, get_cache_if_enabled
+from src.api.dependencies import get_model_manager
 from src.utils.validators import extract_embedding_kwargs, validate_texts
 from src.config.settings import get_settings
 
@@ -41,7 +40,6 @@ router = APIRouter(prefix="/embeddings", tags=["embeddings"])
 async def create_embeddings_document(
     request: EmbedRequest,
     manager: ModelManager = Depends(get_model_manager),
-    cache: EmbeddingCache = Depends(get_cache_if_enabled),
     settings=Depends(get_settings),
 ):
     """
@@ -53,7 +51,6 @@ async def create_embeddings_document(
     Args:
         request: BatchEmbedRequest with texts, model_id, and optional parameters
         manager: Model manager dependency
-        cache: Cache dependency (if enabled)
         settings: Application settings
 
     Returns:
@@ -73,19 +70,6 @@ async def create_embeddings_document(
     # Extract kwargs
     kwargs = extract_embedding_kwargs(request)
 
-    # Check cache first (batch requests typically not cached due to size)
-    # But we can cache if batch is small
-    if cache is not None and len(request.texts) <= 10:
-        cache_key = str(sorted(request.texts))  # Simple key for small batches
-        cached_result = cache.get(
-            texts=cache_key,
-            model_id=request.model_id,
-            prompt=request.prompt,
-            **kwargs,
-        )
-        if cached_result is not None:
-            logger.debug(f"Cache hit for batch (size={len(request.texts)})")
-            return cached_result
 
     # Get model
     model = manager.get_model(request.model_id)
@@ -133,17 +117,6 @@ async def create_embeddings_document(
         processing_time=processing_time,
     )
 
-    # Cache small batches
-    if cache is not None and len(request.texts) <= 10:
-        cache_key = str(sorted(request.texts))
-        cache.set(
-            texts=cache_key,
-            model_id=request.model_id,
-            result=response,
-            prompt=request.prompt,
-            **kwargs,
-        )
-
     logger.info(
         f"Generated {len(request.texts)} embeddings "
         f"in {processing_time:.3f}s ({len(request.texts) / processing_time:.1f} texts/s)"
@@ -174,7 +147,6 @@ async def create_embeddings_document(
 async def create_query_embedding(
     request: EmbedRequest,
     manager: ModelManager = Depends(get_model_manager),
-    cache: EmbeddingCache = Depends(get_cache_if_enabled),
 ):
     """
     Generate a single/batch query embedding.
@@ -185,7 +157,6 @@ async def create_query_embedding(
     Args:
         request: EmbedRequest with text, model_id, and optional parameters
         manager: Model manager dependency
-        cache: Cache dependency (if enabled)
         settings: Application settings
 
     Returns:
@@ -201,20 +172,6 @@ async def create_query_embedding(
     # Extract kwargs
     kwargs = extract_embedding_kwargs(request)
 
-    # Check cache (with 'query' prefix in key)
-    cache_key_kwargs = {"endpoint": "query", **kwargs}
-
-    if cache is not None:
-        cached_result = cache.get(
-            texts=request.texts,
-            model_id=request.model_id,
-            prompt=request.prompt,
-            **cache_key_kwargs,
-        )
-        if cached_result is not None:
-            logger.debug(f"Cache hit for query model {request.model_id}")
-            return cached_result
-
     # Get model
     model = manager.get_model(request.model_id)
     config = manager.model_configs[request.model_id]
@@ -261,16 +218,6 @@ async def create_query_embedding(
         processing_time=processing_time,
     )
 
-    # Cache small batches
-    if cache is not None and len(request.texts) <= 10:
-        cache_key = str(sorted(request.texts))
-        cache.set(
-            texts=cache_key,
-            model_id=request.model_id,
-            result=response,
-            prompt=request.prompt,
-            **kwargs,
-        )
 
     logger.info(
         f"Generated {len(request.texts)} embeddings "
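Note: with the cache dependency removed, identical requests are always recomputed by the handlers above. A hypothetical client call against these endpoints (the concrete route paths are not visible in this diff; "/embeddings/query" is a guess from the router prefix and handler name, the payload fields follow the EmbedRequest usage above, and httpx is used purely for illustration):

# Hypothetical client usage; path and model_id are assumptions, not taken
# from this commit.
import httpx

payload = {"texts": ["what is an embedding?"], "model_id": "my-model"}

with httpx.Client(base_url="http://localhost:8000") as client:
    resp = client.post("/embeddings/query", json=payload)
    resp.raise_for_status()
    # Identical inputs now always recompute; there are no cache hits.
    print(resp.json())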
 
src/core/cache.py DELETED
@@ -1,237 +0,0 @@
-"""
-Simple in-memory caching layer for embeddings.
-
-This module provides an LRU cache for embedding results to reduce
-redundant computations for identical requests.
-"""
-
-import hashlib
-import json
-import time
-from typing import Any, Dict, List, Optional, Union
-from collections import OrderedDict
-from threading import Lock
-from loguru import logger
-
-
-class EmbeddingCache:
-    """
-    Thread-safe LRU cache for embedding results.
-
-    This cache stores embedding results with a TTL (time-to-live) and
-    implements LRU eviction when the cache is full.
-
-    Attributes:
-        max_size: Maximum number of entries in the cache
-        ttl: Time-to-live in seconds for cached entries
-        _cache: OrderedDict storing cached entries
-        _lock: Threading lock for thread-safety
-        _hits: Number of cache hits
-        _misses: Number of cache misses
-    """
-
-    def __init__(self, max_size: int = 1000, ttl: int = 3600):
-        """
-        Initialize the embedding cache.
-
-        Args:
-            max_size: Maximum number of entries (default: 1000)
-            ttl: Time-to-live in seconds (default: 3600 = 1 hour)
-        """
-        self.max_size = max_size
-        self.ttl = ttl
-        self._cache: OrderedDict[str, Dict[str, Any]] = OrderedDict()
-        self._lock = Lock()
-        self._hits = 0
-        self._misses = 0
-
-        logger.info(f"Initialized embedding cache (max_size={max_size}, ttl={ttl}s)")
-
-    def _generate_key(
-        self,
-        texts: Union[str, List[str]],
-        model_id: str,
-        prompt: Optional[str] = None,
-        **kwargs,
-    ) -> str:
-        """
-        Generate a unique cache key for the request.
-
-        Args:
-            texts: Single text or list of texts
-            model_id: Model identifier
-            prompt: Optional prompt
-            **kwargs: Additional parameters
-
-        Returns:
-            SHA256 hash of the request parameters
-        """
-        # Normalize texts to list
-        if isinstance(texts, str):
-            texts = [texts]
-
-        # Create deterministic representation
-        cache_dict = {
-            "texts": texts,
-            "model_id": model_id,
-            "prompt": prompt,
-            "kwargs": sorted(kwargs.items()) if kwargs else [],
-        }
-
-        # Generate hash
-        cache_str = json.dumps(cache_dict, sort_keys=True)
-        return hashlib.sha256(cache_str.encode()).hexdigest()
-
-    def get(
-        self,
-        texts: Union[str, List[str]],
-        model_id: str,
-        prompt: Optional[str] = None,
-        **kwargs,
-    ) -> Optional[Any]:
-        """
-        Retrieve a cached embedding result.
-
-        Args:
-            texts: Single text or list of texts
-            model_id: Model identifier
-            prompt: Optional prompt
-            **kwargs: Additional parameters
-
-        Returns:
-            Cached result if found and not expired, None otherwise
-        """
-        key = self._generate_key(texts, model_id, prompt, **kwargs)
-
-        with self._lock:
-            if key not in self._cache:
-                self._misses += 1
-                return None
-
-            entry = self._cache[key]
-
-            # Check if expired
-            if time.time() - entry["timestamp"] > self.ttl:
-                del self._cache[key]
-                self._misses += 1
-                logger.debug(f"Cache entry expired: {key[:8]}...")
-                return None
-
-            # Move to end (LRU)
-            self._cache.move_to_end(key)
-            self._hits += 1
-
-            logger.debug(f"Cache hit: {key[:8]}... (hit_rate={self.hit_rate:.2%})")
-
-            return entry["result"]
-
-    def set(
-        self,
-        texts: Union[str, List[str]],
-        model_id: str,
-        result: Any,
-        prompt: Optional[str] = None,
-        **kwargs,
-    ) -> None:
-        """
-        Store an embedding result in the cache.
-
-        Args:
-            texts: Single text or list of texts
-            model_id: Model identifier
-            result: Embedding result to cache
-            prompt: Optional prompt
-            **kwargs: Additional parameters
-        """
-        key = self._generate_key(texts, model_id, prompt, **kwargs)
-
-        with self._lock:
-            # Evict oldest entry if cache is full
-            if len(self._cache) >= self.max_size:
-                oldest_key = next(iter(self._cache))
-                del self._cache[oldest_key]
-                logger.debug(f"Cache full, evicted: {oldest_key[:8]}...")
-
-            # Store new entry
-            self._cache[key] = {"result": result, "timestamp": time.time()}
-
-            logger.debug(
-                f"Cache set: {key[:8]}... (size={len(self._cache)}/{self.max_size})"
-            )
-
-    def clear(self) -> None:
-        """Clear all cached entries."""
-        with self._lock:
-            count = len(self._cache)
-            self._cache.clear()
-            self._hits = 0
-            self._misses = 0
-            logger.info(f"Cleared {count} cache entries")
-
-    def cleanup_expired(self) -> int:
-        """
-        Remove all expired entries from the cache.
-
-        Returns:
-            Number of entries removed
-        """
-        with self._lock:
-            current_time = time.time()
-            expired_keys = [
-                key
-                for key, entry in self._cache.items()
-                if current_time - entry["timestamp"] > self.ttl
-            ]
-
-            for key in expired_keys:
-                del self._cache[key]
-
-            if expired_keys:
-                logger.info(f"Cleaned up {len(expired_keys)} expired cache entries")
-
-            return len(expired_keys)
-
-    @property
-    def size(self) -> int:
-        """Get current number of cached entries."""
-        return len(self._cache)
-
-    @property
-    def hit_rate(self) -> float:
-        """
-        Calculate cache hit rate.
-
-        Returns:
-            Hit rate as a float between 0 and 1
-        """
-        total = self._hits + self._misses
-        if total == 0:
-            return 0.0
-        return self._hits / total
-
-    @property
-    def stats(self) -> Dict[str, Any]:
-        """
-        Get cache statistics.
-
-        Returns:
-            Dictionary with cache statistics
-        """
-        return {
-            "size": self.size,
-            "max_size": self.max_size,
-            "hits": self._hits,
-            "misses": self._misses,
-            "hit_rate": f"{self.hit_rate:.2%}",
-            "ttl": self.ttl,
-        }
-
-    def __repr__(self) -> str:
-        """String representation of the cache."""
-        return (
-            f"EmbeddingCache("
-            f"size={self.size}/{self.max_size}, "
-            f"hits={self._hits}, "
-            f"misses={self._misses}, "
-            f"hit_rate={self.hit_rate:.2%})"
-        )
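Note: the deleted EmbeddingCache was self-contained, so its behavior can be reconstructed from the file above. A minimal sketch exercising it as it existed before this commit (the old import path is used; embedding values are placeholders):

# Exercises the EmbeddingCache removed by this commit.
from src.core.cache import EmbeddingCache

cache = EmbeddingCache(max_size=100, ttl=60)

# First lookup misses: the SHA256 key derived from
# (texts, model_id, prompt, kwargs) has never been stored.
assert cache.get(texts=["hello"], model_id="demo") is None

# Store a result, then the identical request hits.
cache.set(texts=["hello"], model_id="demo", result=[0.1, 0.2, 0.3])
assert cache.get(texts=["hello"], model_id="demo") == [0.1, 0.2, 0.3]

# Any change to the key material (model, prompt, extra kwargs) misses again.
assert cache.get(texts=["hello"], model_id="other") is None

print(cache.stats)  # e.g. {'size': 1, 'max_size': 100, 'hits': 1, 'misses': 2, ...}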