"""
OpenAI embeddings client with caching and retry logic.

This module provides an optimized client for generating text embeddings
using OpenAI's API with built-in caching and automatic retry on failures.
"""

import hashlib
import os
import time
from typing import List

from openai import OpenAI, RateLimitError
from loguru import logger

from config import get_settings

try:
    from diskcache import Cache

    CACHE_AVAILABLE = True
except ImportError:
    CACHE_AVAILABLE = False
    logger.warning("diskcache not available - embeddings caching disabled")


class EmbeddingsClient:
    """
    Client for generating embeddings with caching and retry logic.

    Features:
        - Persistent disk cache for embeddings (using diskcache)
        - Automatic retry with exponential backoff
        - Batch processing support
        - Text truncation for long inputs

    Attributes:
        client: OpenAI client instance
        model: Embedding model name
        settings: Application settings
        cache: Disk cache for embeddings (if available)

    Examples:
        >>> client = EmbeddingsClient()
        >>> embedding = client.get_embedding("Hello world")
        >>> len(embedding)  # dimension depends on the configured model
        3072

        >>> # Second call uses cache
        >>> embedding2 = client.get_embedding("Hello world")
        >>> embedding == embedding2
        True
    """

    def __init__(self, cache_dir: str = "./cache/embeddings"):
        """
        Initialize the OpenAI client with the configured endpoint and cache.

        Args:
            cache_dir: Directory for the persistent embedding cache
        """
        self.settings = get_settings()

        logger.info(f"Initializing EmbeddingsClient with {self.settings.llm_base_url}")
        self.client = OpenAI(
            api_key=self.settings.openai_api_key,
            base_url=self.settings.llm_base_url,
        )

        self.model = self.settings.embedding_model

        # Set up the persistent disk cache; fall back to no caching if the
        # directory cannot be created (e.g. on a read-only filesystem).
        self.cache = None
        if CACHE_AVAILABLE:
            try:
                os.makedirs(cache_dir, exist_ok=True)
                self.cache = Cache(cache_dir)
            except (OSError, PermissionError) as e:
                logger.warning(f"Could not create cache directory (read-only?): {e}")

        cache_state = "enabled" if self.cache is not None else "disabled"
        logger.info(f"✅ EmbeddingsClient initialized (model: {self.model}, cache: {cache_state})")

    def close(self):
        """Close cache and clean up resources."""
        try:
            if self.cache is not None:
                self.cache.close()
                logger.info("EmbeddingsClient cache closed")
            if hasattr(self.client, "close"):
                self.client.close()
        except Exception as e:
            logger.warning(f"Error closing EmbeddingsClient: {e}")

    def _get_cache_key(self, text: str) -> str:
        """
        Generate a cache key for the given text.

        Args:
            text: Input text

        Returns:
            MD5 hash of the model name and text, used as the cache key
        """
        # MD5 is fine here: the hash is a cache key, not a security boundary.
        # Including the model name keeps embeddings from different models
        # from colliding in the same cache.
        return hashlib.md5(f"{self.model}:{text}".encode()).hexdigest()

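    # Illustration (hypothetical values): with model "text-embedding-3-large",
    # _get_cache_key("Hello") hashes the string "text-embedding-3-large:Hello"
    # and returns a 32-character hex digest, so identical (model, text) pairs
    # always map to the same cache slot.
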
    def _get_embedding_uncached(self, text: str, retry_count: int = 5) -> List[float]:
        """
        Generate an embedding without caching (internal method).

        Args:
            text: Input text (already truncated)
            retry_count: Maximum number of attempts on rate-limit errors

        Returns:
            Embedding vector
        """
        for attempt in range(retry_count):
            try:
                response = self.client.embeddings.create(
                    model=self.model,
                    input=text,
                )
                embedding = response.data[0].embedding
                logger.debug(f"Generated embedding (dim={len(embedding)})")
                return embedding
            except RateLimitError:
                if attempt == retry_count - 1:
                    raise
                # Exponential backoff: 2s, 4s, 8s, ... between attempts.
                wait_time = (2 ** attempt) * 2
                logger.warning(
                    f"Rate limited. Retrying in {wait_time}s "
                    f"(attempt {attempt + 1}/{retry_count})"
                )
                time.sleep(wait_time)
        raise RuntimeError(f"Failed after {retry_count} retries")

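    # Note on the retry policy above: with the default retry_count=5 the
    # waits are (2 ** attempt) * 2 seconds, i.e. 2s, 4s, 8s and 16s before
    # attempts 2-5, for at most 30s of sleeping before the final
    # RateLimitError is re-raised.
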
    def get_embedding(self, text: str) -> List[float]:
        """
        Generate an embedding for text, with caching.

        Uses the disk cache to avoid regenerating embeddings for the same
        text. Automatically truncates long inputs to 8191 characters.

        Args:
            text: Input text to embed

        Returns:
            List of float values representing the embedding

        Examples:
            >>> embedding = client.get_embedding("Hello world")
            >>> len(embedding)
            3072
        """
        # Character-level truncation (a conservative proxy for the API's
        # token limit on embedding inputs).
        text = text[:8191]

        # Return the cached embedding if this text has been seen before.
        if self.cache is not None:
            cache_key = self._get_cache_key(text)
            if cache_key in self.cache:
                logger.debug("Cache hit for embedding")
                return self.cache[cache_key]

        embedding = self._get_embedding_uncached(text)

        # Store the fresh embedding for future calls.
        if self.cache is not None:
            cache_key = self._get_cache_key(text)
            self.cache[cache_key] = embedding

        return embedding

    def _get_embeddings_batch_uncached(
        self,
        texts: List[str],
        retry_count: int = 3
    ) -> List[List[float]]:
        """
        Generate embeddings for a batch without caching (internal method).

        Args:
            texts: List of texts (already truncated)
            retry_count: Maximum number of attempts on rate-limit errors

        Returns:
            List of embedding vectors, in the same order as the input texts
        """
        for attempt in range(retry_count):
            try:
                response = self.client.embeddings.create(
                    model=self.model,
                    input=texts,
                )
                # Sort by index so the output order matches the input order,
                # even if the API returns items out of order.
                batch_embeddings = sorted(response.data, key=lambda x: x.index)
                return [e.embedding for e in batch_embeddings]
            except RateLimitError:
                if attempt == retry_count - 1:
                    raise
                wait_time = (2 ** attempt) * 2
                logger.warning(
                    f"Rate limited. Retrying in {wait_time}s "
                    f"(attempt {attempt + 1}/{retry_count})"
                )
                time.sleep(wait_time)
        raise RuntimeError(f"Failed after {retry_count} retries")

    def get_embeddings_batch(
        self,
        texts: List[str],
        batch_size: int = 100
    ) -> List[List[float]]:
        """
        Generate embeddings for multiple texts efficiently, with caching.

        Processes texts in batches and uses the cache when available;
        each text is checked against the cache individually.

        Args:
            texts: List of texts to embed
            batch_size: Number of texts per API call

        Returns:
            List of embeddings (one per input text)

        Examples:
            >>> texts = ["Hello", "World", "!"]
            >>> embeddings = client.get_embeddings_batch(texts)
            >>> len(embeddings)
            3
        """
        all_embeddings = []
        num_batches = (len(texts) + batch_size - 1) // batch_size

        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]
            current_batch_num = i // batch_size + 1

            logger.info(
                f" - Processing batch {current_batch_num}/{num_batches} "
                f"({len(batch_texts)} texts)"
            )

            # Split the batch into cached embeddings and texts that still
            # need an API call, remembering each text's position.
            batch_embeddings = []
            texts_to_generate = []
            indices_to_generate = []

            if self.cache is not None:
                for idx, text in enumerate(batch_texts):
                    text_truncated = text[:8191]
                    cache_key = self._get_cache_key(text_truncated)

                    if cache_key in self.cache:
                        batch_embeddings.append((idx, self.cache[cache_key]))
                    else:
                        texts_to_generate.append(text_truncated)
                        indices_to_generate.append(idx)

                if texts_to_generate:
                    logger.debug(
                        f"Cache: {len(batch_embeddings)} hits, "
                        f"{len(texts_to_generate)} misses"
                    )
            else:
                # No cache available: generate embeddings for every text.
                texts_to_generate = [t[:8191] for t in batch_texts]
                indices_to_generate = list(range(len(batch_texts)))

            if texts_to_generate:
                try:
                    generated = self._get_embeddings_batch_uncached(
                        texts_to_generate
                    )

                    # Reattach positions and populate the cache.
                    for idx, text, embedding in zip(
                        indices_to_generate,
                        texts_to_generate,
                        generated
                    ):
                        batch_embeddings.append((idx, embedding))

                        if self.cache is not None:
                            cache_key = self._get_cache_key(text)
                            self.cache[cache_key] = embedding

                except Exception as e:
                    logger.error(f" - Batch embedding failed: {e}")
                    raise

            # Restore the original order within the batch.
            batch_embeddings.sort(key=lambda x: x[0])
            all_embeddings.extend([emb for _, emb in batch_embeddings])

            # Small delay between batches to stay under rate limits.
            if num_batches > 1 and current_batch_num < num_batches:
                time.sleep(0.5)

        logger.success(f" 🧠 Generated {len(all_embeddings)} embeddings total.")
        return all_embeddings


def get_embeddings_client() -> EmbeddingsClient:
    """
    Get or create the embeddings client (singleton pattern).

    Returns:
        Shared EmbeddingsClient instance

    Examples:
        >>> client = get_embeddings_client()
        >>> embedding = client.get_embedding("test")
    """
    # The shared instance is stored as an attribute on the function itself,
    # so repeated calls return the same client (and the same cache handle).
    if not hasattr(get_embeddings_client, '_instance'):
        get_embeddings_client._instance = EmbeddingsClient()
    return get_embeddings_client._instance
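

if __name__ == "__main__":
    # Minimal usage sketch (not part of the library API). It assumes that
    # `config.get_settings()` returns valid values for `openai_api_key`,
    # `llm_base_url`, and `embedding_model`, and that the endpoint is
    # reachable; adapt as needed for your environment.
    client = get_embeddings_client()
    try:
        vec = client.get_embedding("Hello world")
        logger.info(f"Single embedding dim: {len(vec)}")

        vecs = client.get_embeddings_batch(["alpha", "beta", "gamma"])
        logger.info(f"Batch returned {len(vecs)} embeddings")

        # The singleton accessor hands back the same instance every time.
        assert get_embeddings_client() is client
    finally:
        client.close()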