""" OpenAI embeddings client with caching and retry logic. This module provides an optimized client for generating text embeddings using OpenAI's API with built-in caching and automatic retry on failures. """ import os import hashlib from typing import List from functools import wraps from openai import OpenAI, RateLimitError from config import get_settings from loguru import logger try: from diskcache import Cache CACHE_AVAILABLE = True except ImportError: CACHE_AVAILABLE = False logger.warning("diskcache not available - embeddings caching disabled") class EmbeddingsClient: """ Client for generating embeddings with caching and retry logic. Features: - Persistent disk cache for embeddings (using diskcache) - Automatic retry with exponential backoff - Batch processing support - Text truncation for long inputs Attributes: client: OpenAI client instance model: Embedding model name settings: Application settings cache: Disk cache for embeddings (if available) Examples: >>> client = EmbeddingsClient() >>> embedding = client.get_embedding("Hello world") >>> len(embedding) 3072 >>> # Second call uses cache >>> embedding2 = client.get_embedding("Hello world") >>> embedding == embedding2 True """ def __init__(self, cache_dir: str = "./cache/embeddings"): """ Initialize OpenAI client with configured endpoint and cache. Args: cache_dir: Directory for persistent embedding cache """ self.settings = get_settings() logger.info(f"Initializing EmbeddingsClient with {self.settings.llm_base_url}") self.client = OpenAI( api_key=self.settings.openai_api_key, base_url=self.settings.llm_base_url ) self.model = self.settings.embedding_model # Initialize cache if available if CACHE_AVAILABLE: try: import os os.makedirs(cache_dir, exist_ok=True) self.cache = Cache(cache_dir) logger.info(f"✅ EmbeddingsClient initialized (model: {self.model}, cache: enabled)") except (OSError, PermissionError) as e: logger.warning(f"Could not create cache directory (read-only?): {e}") self.cache = None logger.info(f"✅ EmbeddingsClient initialized (model: {self.model}, cache: disabled)") else: self.cache = None logger.info(f"✅ EmbeddingsClient initialized (model: {self.model}, cache: disabled)") def close(self): """Close cache and clean up resources.""" try: if self.cache is not None: self.cache.close() logger.info("EmbeddingsClient cache closed") if hasattr(self.client, 'close'): self.client.close() except Exception as e: logger.warning(f"Error closing EmbeddingsClient: {e}") def _get_cache_key(self, text: str) -> str: """ Generate cache key for text. Args: text: Input text Returns: MD5 hash of text as cache key """ return hashlib.md5(f"{self.model}:{text}".encode()).hexdigest() def _get_embedding_uncached(self, text: str, retry_count: int = 5) -> List[float]: """ Generate embedding without cache (internal method). Args: text: Input text (already truncated) retry_count: Number of retries on rate limit Returns: Embedding vector """ for attempt in range(retry_count): try: response = self.client.embeddings.create( model=self.model, input=text ) embedding = response.data[0].embedding logger.debug(f"Generated embedding (dim={len(embedding)})") return embedding except RateLimitError as e: if attempt == retry_count - 1: raise wait_time = (2 ** attempt) * 2 logger.warning(f"Rate limited. Retrying in {wait_time}s (attempt {attempt + 1}/{retry_count})") import time time.sleep(wait_time) raise RuntimeError(f"Failed after {retry_count} retries") def get_embedding(self, text: str) -> List[float]: """ Generate embedding for text with caching. 
    def get_embedding(self, text: str) -> List[float]:
        """
        Generate embedding for text with caching.

        Uses disk cache to avoid regenerating embeddings for the same text.
        Automatically truncates long texts to 8191 characters.

        Args:
            text: Input text to embed

        Returns:
            List of float values representing the embedding

        Examples:
            >>> embedding = client.get_embedding("Hello world")
            >>> len(embedding)
            3072
        """
        # Truncate overly long input (8191 characters, a rough character-level
        # proxy for the embedding model's token limit)
        text = text[:8191]

        # Check cache
        if self.cache is not None:
            cache_key = self._get_cache_key(text)
            if cache_key in self.cache:
                logger.debug("Cache hit for embedding")
                return self.cache[cache_key]

        # Generate embedding
        embedding = self._get_embedding_uncached(text)

        # Store in cache
        if self.cache is not None:
            cache_key = self._get_cache_key(text)
            self.cache[cache_key] = embedding

        return embedding

    def _get_embeddings_batch_uncached(
        self,
        texts: List[str],
        retry_count: int = 3
    ) -> List[List[float]]:
        """
        Generate embeddings for batch without cache (internal method).

        Args:
            texts: List of texts (already truncated)
            retry_count: Number of retries on rate limit

        Returns:
            List of embedding vectors
        """
        for attempt in range(retry_count):
            try:
                response = self.client.embeddings.create(
                    model=self.model,
                    input=texts
                )
                # Sort by index to maintain input order
                batch_embeddings = sorted(response.data, key=lambda x: x.index)
                return [e.embedding for e in batch_embeddings]
            except RateLimitError:
                if attempt == retry_count - 1:
                    raise
                wait_time = (2 ** attempt) * 2
                logger.warning(
                    f"Rate limited. Retrying in {wait_time}s "
                    f"(attempt {attempt + 1}/{retry_count})"
                )
                time.sleep(wait_time)

        raise RuntimeError(f"Failed after {retry_count} retries")
    def get_embeddings_batch(
        self,
        texts: List[str],
        batch_size: int = 100
    ) -> List[List[float]]:
        """
        Generate embeddings for multiple texts efficiently with caching.

        Processes texts in batches and uses the cache when available.
        Each text is checked against the cache individually.

        Args:
            texts: List of texts to embed
            batch_size: Number of texts per API call

        Returns:
            List of embeddings (one per input text)

        Examples:
            >>> texts = ["Hello", "World", "!"]
            >>> embeddings = client.get_embeddings_batch(texts)
            >>> len(embeddings)
            3
        """
        all_embeddings = []
        num_batches = (len(texts) + batch_size - 1) // batch_size

        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]
            current_batch_num = i // batch_size + 1
            logger.info(
                f" - Processing batch {current_batch_num}/{num_batches} "
                f"({len(batch_texts)} texts)"
            )

            # Check cache for each text in batch
            batch_embeddings = []
            texts_to_generate = []
            indices_to_generate = []

            if self.cache is not None:
                for idx, text in enumerate(batch_texts):
                    text_truncated = text[:8191]
                    cache_key = self._get_cache_key(text_truncated)
                    if cache_key in self.cache:
                        batch_embeddings.append((idx, self.cache[cache_key]))
                    else:
                        texts_to_generate.append(text_truncated)
                        indices_to_generate.append(idx)

                if texts_to_generate:
                    logger.debug(
                        f"Cache: {len(batch_embeddings)} hits, "
                        f"{len(texts_to_generate)} misses"
                    )
            else:
                # No cache - generate all
                texts_to_generate = [t[:8191] for t in batch_texts]
                indices_to_generate = list(range(len(batch_texts)))

            # Generate embeddings for cache misses
            if texts_to_generate:
                try:
                    generated = self._get_embeddings_batch_uncached(texts_to_generate)

                    # Store in cache and add to results
                    for idx, text, embedding in zip(
                        indices_to_generate, texts_to_generate, generated
                    ):
                        batch_embeddings.append((idx, embedding))
                        if self.cache is not None:
                            cache_key = self._get_cache_key(text)
                            self.cache[cache_key] = embedding
                except Exception as e:
                    logger.error(f" - Batch embedding failed: {e}")
                    raise

            # Sort by original index and extract embeddings
            batch_embeddings.sort(key=lambda x: x[0])
            all_embeddings.extend([emb for _, emb in batch_embeddings])

            # Small delay between batches to avoid rate limiting
            if num_batches > 1 and current_batch_num < num_batches:
                time.sleep(0.5)

        logger.success(f" 🧠 Generated {len(all_embeddings)} embeddings total.")
        return all_embeddings


def get_embeddings_client() -> EmbeddingsClient:
    """
    Get or create embeddings client (singleton pattern).

    Returns:
        Shared EmbeddingsClient instance

    Examples:
        >>> client = get_embeddings_client()
        >>> embedding = client.get_embedding("test")
    """
    if not hasattr(get_embeddings_client, '_instance'):
        get_embeddings_client._instance = EmbeddingsClient()
    return get_embeddings_client._instance
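

# Illustrative usage sketch, not part of the module's public API. It assumes
# config.get_settings() supplies openai_api_key, llm_base_url, and
# embedding_model, and that the configured embeddings endpoint is reachable.
if __name__ == "__main__":
    client = get_embeddings_client()
    try:
        # Single embedding: a repeated call for the same text is served from
        # the disk cache instead of hitting the API again.
        vec = client.get_embedding("Hello world")
        print(f"Embedding dimension: {len(vec)}")

        # Batch embedding: cached texts are skipped, misses go out in one call.
        vectors = client.get_embeddings_batch(["alpha", "beta", "gamma"])
        print(f"Generated {len(vectors)} embeddings")
    finally:
        # Release the diskcache handle and the underlying HTTP client.
        client.close()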