"""
OpenAI embeddings client with caching and retry logic.
This module provides an optimized client for generating text embeddings
using OpenAI's API with built-in caching and automatic retry on failures.
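
Typical usage (illustrative sketch; the import path for this module is an
assumption that depends on your project layout):

    client = get_embeddings_client()
    vector = client.get_embedding("some text")
    vectors = client.get_embeddings_batch(["first", "second"])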
"""
import hashlib
import os
import time
from typing import List

from loguru import logger
from openai import OpenAI, RateLimitError

from config import get_settings
try:
from diskcache import Cache
CACHE_AVAILABLE = True
except ImportError:
CACHE_AVAILABLE = False
logger.warning("diskcache not available - embeddings caching disabled")
class EmbeddingsClient:
"""
Client for generating embeddings with caching and retry logic.
Features:
- Persistent disk cache for embeddings (using diskcache)
- Automatic retry with exponential backoff
- Batch processing support
- Text truncation for long inputs
Attributes:
client: OpenAI client instance
model: Embedding model name
settings: Application settings
cache: Disk cache for embeddings (if available)
Examples:
>>> client = EmbeddingsClient()
>>> embedding = client.get_embedding("Hello world")
>>> len(embedding)
3072
>>> # Second call uses cache
>>> embedding2 = client.get_embedding("Hello world")
>>> embedding == embedding2
True
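        >>> # Batch call (illustrative; only cache misses hit the API)
        >>> embeddings = client.get_embeddings_batch(["a", "b", "c"])
        >>> len(embeddings)
        3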
"""
def __init__(self, cache_dir: str = "./cache/embeddings"):
"""
Initialize OpenAI client with configured endpoint and cache.
Args:
cache_dir: Directory for persistent embedding cache
"""
self.settings = get_settings()
logger.info(f"Initializing EmbeddingsClient with {self.settings.llm_base_url}")
self.client = OpenAI(
api_key=self.settings.openai_api_key,
base_url=self.settings.llm_base_url
)
self.model = self.settings.embedding_model
# Initialize cache if available
if CACHE_AVAILABLE:
try:
                os.makedirs(cache_dir, exist_ok=True)
self.cache = Cache(cache_dir)
logger.info(f"✅ EmbeddingsClient initialized (model: {self.model}, cache: enabled)")
except (OSError, PermissionError) as e:
logger.warning(f"Could not create cache directory (read-only?): {e}")
self.cache = None
logger.info(f"✅ EmbeddingsClient initialized (model: {self.model}, cache: disabled)")
else:
self.cache = None
logger.info(f"✅ EmbeddingsClient initialized (model: {self.model}, cache: disabled)")
def close(self):
"""Close cache and clean up resources."""
try:
if self.cache is not None:
self.cache.close()
logger.info("EmbeddingsClient cache closed")
if hasattr(self.client, 'close'):
self.client.close()
except Exception as e:
logger.warning(f"Error closing EmbeddingsClient: {e}")
def _get_cache_key(self, text: str) -> str:
"""
Generate cache key for text.
Args:
text: Input text
Returns:
            MD5 hex digest of the model name plus the text, so identical text
            embedded under different models gets distinct cache entries
"""
return hashlib.md5(f"{self.model}:{text}".encode()).hexdigest()
def _get_embedding_uncached(self, text: str, retry_count: int = 5) -> List[float]:
"""
Generate embedding without cache (internal method).
Args:
text: Input text (already truncated)
retry_count: Number of retries on rate limit
Returns:
Embedding vector
"""
for attempt in range(retry_count):
try:
response = self.client.embeddings.create(
model=self.model,
input=text
)
embedding = response.data[0].embedding
logger.debug(f"Generated embedding (dim={len(embedding)})")
return embedding
            except RateLimitError:
                if attempt == retry_count - 1:
                    raise
                wait_time = (2 ** attempt) * 2
                logger.warning(f"Rate limited. Retrying in {wait_time}s (attempt {attempt + 1}/{retry_count})")
                time.sleep(wait_time)
raise RuntimeError(f"Failed after {retry_count} retries")
def get_embedding(self, text: str) -> List[float]:
"""
Generate embedding for text with caching.
        Uses the disk cache to avoid regenerating embeddings for identical text.
        Automatically truncates input to 8,191 characters, a character-level
        proxy for the model's 8,191-token input limit.
Args:
text: Input text to embed
Returns:
List of float values representing the embedding
Examples:
>>> embedding = client.get_embedding("Hello world")
>>> len(embedding)
3072
"""
        # Truncate long input; character count is a cheap stand-in for the
        # model's 8,191-token input limit
        text = text[:8191]
        # Compute the cache key once; serve from cache on a hit
        cache_key = self._get_cache_key(text) if self.cache is not None else None
        if cache_key is not None and cache_key in self.cache:
            logger.debug("Cache hit for embedding")
            return self.cache[cache_key]
        # Generate the embedding and store it for future calls
        embedding = self._get_embedding_uncached(text)
        if cache_key is not None:
            self.cache[cache_key] = embedding
return embedding
def _get_embeddings_batch_uncached(
self,
texts: List[str],
retry_count: int = 3
) -> List[List[float]]:
"""
Generate embeddings for batch without cache (internal method).
Args:
texts: List of texts (already truncated)
retry_count: Number of retries on rate limit
Returns:
List of embedding vectors
"""
for attempt in range(retry_count):
try:
response = self.client.embeddings.create(
model=self.model,
input=texts
)
# Sort by index to maintain order
batch_embeddings = sorted(response.data, key=lambda x: x.index)
return [e.embedding for e in batch_embeddings]
            except RateLimitError:
                if attempt == retry_count - 1:
                    raise
                wait_time = (2 ** attempt) * 2
                logger.warning(f"Rate limited. Retrying in {wait_time}s (attempt {attempt + 1}/{retry_count})")
                time.sleep(wait_time)
raise RuntimeError(f"Failed after {retry_count} retries")
def get_embeddings_batch(
self,
texts: List[str],
batch_size: int = 100
) -> List[List[float]]:
"""
Generate embeddings for multiple texts efficiently with caching.
        Processes texts in batches and uses the cache when available;
        each text is checked against the cache individually, so only
        cache misses are sent to the API.
Args:
texts: List of texts to embed
batch_size: Number of texts per API call
Returns:
List of embeddings (one per input text)
Examples:
>>> texts = ["Hello", "World", "!"]
>>> embeddings = client.get_embeddings_batch(texts)
>>> len(embeddings)
3
"""
        all_embeddings = []
        num_batches = (len(texts) + batch_size - 1) // batch_size
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]
            current_batch_num = i // batch_size + 1
logger.info(
f" - Processing batch {current_batch_num}/{num_batches} "
f"({len(batch_texts)} texts)"
)
# Check cache for each text in batch
batch_embeddings = []
texts_to_generate = []
indices_to_generate = []
if self.cache is not None:
for idx, text in enumerate(batch_texts):
                    text_truncated = text[:8191]  # same character-level truncation as get_embedding
cache_key = self._get_cache_key(text_truncated)
if cache_key in self.cache:
batch_embeddings.append((idx, self.cache[cache_key]))
else:
texts_to_generate.append(text_truncated)
indices_to_generate.append(idx)
if texts_to_generate:
logger.debug(
f"Cache: {len(batch_embeddings)} hits, "
f"{len(texts_to_generate)} misses"
)
else:
# No cache - generate all
texts_to_generate = [t[:8191] for t in batch_texts]
indices_to_generate = list(range(len(batch_texts)))
# Generate embeddings for cache misses
if texts_to_generate:
try:
generated = self._get_embeddings_batch_uncached(
texts_to_generate
)
# Store in cache and add to results
for idx, text, embedding in zip(
indices_to_generate,
texts_to_generate,
generated
):
batch_embeddings.append((idx, embedding))
if self.cache is not None:
cache_key = self._get_cache_key(text)
self.cache[cache_key] = embedding
except Exception as e:
logger.error(f" - Batch embedding failed: {e}")
raise
# Sort by original index and extract embeddings
batch_embeddings.sort(key=lambda x: x[0])
all_embeddings.extend([emb for _, emb in batch_embeddings])
# Small delay between batches to avoid rate limiting
if num_batches > 1 and current_batch_num < num_batches:
time.sleep(0.5)
logger.success(f" 🧠 Generated {len(all_embeddings)} embeddings total.")
return all_embeddings
def get_embeddings_client() -> EmbeddingsClient:
"""
Get or create embeddings client (singleton pattern).
Returns:
Shared EmbeddingsClient instance
Examples:
>>> client = get_embeddings_client()
>>> embedding = client.get_embedding("test")
"""
    # Simple function-attribute singleton: lazily create one shared instance.
    # Not thread-safe; construct the client before spawning threads if needed.
    if not hasattr(get_embeddings_client, '_instance'):
        get_embeddings_client._instance = EmbeddingsClient()
    return get_embeddings_client._instance
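

if __name__ == "__main__":
    # Minimal smoke test (illustrative only; requires valid API credentials
    # and settings from config.get_settings()).
    client = get_embeddings_client()
    vec = client.get_embedding("Hello world")
    print(f"embedding dim: {len(vec)}")
    vec2 = client.get_embedding("Hello world")  # second call should hit the cache
    print(f"cache round-trip equal: {vec == vec2}")
    client.close()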