"""
OpenAI embeddings client with caching and retry logic.
This module provides an optimized client for generating text embeddings
using OpenAI's API with built-in caching and automatic retry on failures.
"""
import hashlib
import os
import time
from typing import List

from openai import OpenAI, RateLimitError
from loguru import logger

from config import get_settings

try:
    from diskcache import Cache

    CACHE_AVAILABLE = True
except ImportError:
    CACHE_AVAILABLE = False
    logger.warning("diskcache not available - embeddings caching disabled")


class EmbeddingsClient:
    """
    Client for generating embeddings with caching and retry logic.

    Features:
    - Persistent disk cache for embeddings (using diskcache)
    - Automatic retry with exponential backoff
    - Batch processing support
    - Text truncation for long inputs

    Attributes:
        client: OpenAI client instance
        model: Embedding model name
        settings: Application settings
        cache: Disk cache for embeddings (if available)

    Examples:
        >>> client = EmbeddingsClient()
        >>> embedding = client.get_embedding("Hello world")
        >>> len(embedding)
        3072
        >>> # Second call uses cache
        >>> embedding2 = client.get_embedding("Hello world")
        >>> embedding == embedding2
        True
    """

    def __init__(self, cache_dir: str = "./cache/embeddings"):
        """
        Initialize OpenAI client with configured endpoint and cache.

        Args:
            cache_dir: Directory for the persistent embedding cache
        """
        self.settings = get_settings()
        logger.info(f"Initializing EmbeddingsClient with {self.settings.llm_base_url}")
        self.client = OpenAI(
            api_key=self.settings.openai_api_key,
            base_url=self.settings.llm_base_url
        )
        self.model = self.settings.embedding_model

        # Enable the disk cache only if diskcache is installed and the cache
        # directory is writable (creation fails on read-only filesystems).
        self.cache = None
        if CACHE_AVAILABLE:
            try:
                os.makedirs(cache_dir, exist_ok=True)
                self.cache = Cache(cache_dir)
            except (OSError, PermissionError) as e:
                logger.warning(f"Could not create cache directory (read-only?): {e}")

        cache_state = "enabled" if self.cache is not None else "disabled"
        logger.info(f"✅ EmbeddingsClient initialized (model: {self.model}, cache: {cache_state})")

    def close(self):
        """Close cache and clean up resources."""
        try:
            if self.cache is not None:
                self.cache.close()
                logger.info("EmbeddingsClient cache closed")
            if hasattr(self.client, 'close'):
                self.client.close()
        except Exception as e:
            logger.warning(f"Error closing EmbeddingsClient: {e}")

    def _get_cache_key(self, text: str) -> str:
        """
        Generate cache key for text.

        The key covers both the model name and the text, so switching
        embedding models never returns stale vectors from the cache.

        Args:
            text: Input text

        Returns:
            MD5 hash of "model:text" as cache key
        """
        # MD5 is fine here: the hash is a cache key, not a security boundary.
        return hashlib.md5(f"{self.model}:{text}".encode()).hexdigest()

    def _get_embedding_uncached(self, text: str, retry_count: int = 5) -> List[float]:
        """
        Generate embedding without cache (internal method).

        Args:
            text: Input text (already truncated)
            retry_count: Number of retries on rate limit

        Returns:
            Embedding vector
        """
        for attempt in range(retry_count):
            try:
                response = self.client.embeddings.create(
                    model=self.model,
                    input=text
                )
                embedding = response.data[0].embedding
                logger.debug(f"Generated embedding (dim={len(embedding)})")
                return embedding
            except RateLimitError:
                if attempt == retry_count - 1:
                    raise
                # Exponential backoff: 2s, 4s, 8s, ... between attempts.
                wait_time = (2 ** attempt) * 2
                logger.warning(f"Rate limited. Retrying in {wait_time}s (attempt {attempt + 1}/{retry_count})")
                time.sleep(wait_time)
        raise RuntimeError(f"Failed after {retry_count} retries")

    def get_embedding(self, text: str) -> List[float]:
        """
        Generate embedding for text with caching.

        Uses the disk cache to avoid regenerating embeddings for the same
        text. Automatically truncates long texts to 8191 characters.

        Args:
            text: Input text to embed

        Returns:
            List of float values representing the embedding

        Examples:
            >>> embedding = client.get_embedding("Hello world")
            >>> len(embedding)
            3072
        """
        # Truncate by character count as a rough proxy for the model's
        # 8191-token input limit.
        text = text[:8191]

        # Check cache
        cache_key = self._get_cache_key(text) if self.cache is not None else None
        if cache_key is not None and cache_key in self.cache:
            logger.debug("Cache hit for embedding")
            return self.cache[cache_key]

        # Generate embedding and store it for next time
        embedding = self._get_embedding_uncached(text)
        if cache_key is not None:
            self.cache[cache_key] = embedding
        return embedding

    def _get_embeddings_batch_uncached(
        self,
        texts: List[str],
        retry_count: int = 3
    ) -> List[List[float]]:
        """
        Generate embeddings for a batch without cache (internal method).

        Args:
            texts: List of texts (already truncated)
            retry_count: Number of retries on rate limit

        Returns:
            List of embedding vectors
        """
        for attempt in range(retry_count):
            try:
                response = self.client.embeddings.create(
                    model=self.model,
                    input=texts
                )
                # Sort by index to maintain input order
                batch_embeddings = sorted(response.data, key=lambda x: x.index)
                return [e.embedding for e in batch_embeddings]
            except RateLimitError:
                if attempt == retry_count - 1:
                    raise
                wait_time = (2 ** attempt) * 2
                logger.warning(f"Rate limited. Retrying in {wait_time}s (attempt {attempt + 1}/{retry_count})")
                time.sleep(wait_time)
        raise RuntimeError(f"Failed after {retry_count} retries")

    def get_embeddings_batch(
        self,
        texts: List[str],
        batch_size: int = 100
    ) -> List[List[float]]:
        """
        Generate embeddings for multiple texts efficiently with caching.

        Processes texts in batches and uses the cache when available.
        Each text is checked against the cache individually.

        Args:
            texts: List of texts to embed
            batch_size: Number of texts per API call

        Returns:
            List of embeddings (one per input text)

        Examples:
            >>> texts = ["Hello", "World", "!"]
            >>> embeddings = client.get_embeddings_batch(texts)
            >>> len(embeddings)
            3
        """
        all_embeddings = []
        num_batches = (len(texts) + batch_size - 1) // batch_size

        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]
            current_batch_num = i // batch_size + 1
            logger.info(
                f" - Processing batch {current_batch_num}/{num_batches} "
                f"({len(batch_texts)} texts)"
            )

            # Check the cache for each text in the batch; collect misses
            batch_embeddings = []
            texts_to_generate = []
            indices_to_generate = []
            if self.cache is not None:
                for idx, text in enumerate(batch_texts):
                    text_truncated = text[:8191]
                    cache_key = self._get_cache_key(text_truncated)
                    if cache_key in self.cache:
                        batch_embeddings.append((idx, self.cache[cache_key]))
                    else:
                        texts_to_generate.append(text_truncated)
                        indices_to_generate.append(idx)
                if texts_to_generate:
                    logger.debug(
                        f"Cache: {len(batch_embeddings)} hits, "
                        f"{len(texts_to_generate)} misses"
                    )
            else:
                # No cache - generate all
                texts_to_generate = [t[:8191] for t in batch_texts]
                indices_to_generate = list(range(len(batch_texts)))

            # Generate embeddings for cache misses
            if texts_to_generate:
                try:
                    generated = self._get_embeddings_batch_uncached(
                        texts_to_generate
                    )
                    # Store in cache and add to results
                    for idx, text, embedding in zip(
                        indices_to_generate,
                        texts_to_generate,
                        generated
                    ):
                        batch_embeddings.append((idx, embedding))
                        if self.cache is not None:
                            cache_key = self._get_cache_key(text)
                            self.cache[cache_key] = embedding
                except Exception as e:
                    logger.error(f" - Batch embedding failed: {e}")
                    raise

            # Restore original input order before extracting embeddings
            batch_embeddings.sort(key=lambda x: x[0])
            all_embeddings.extend([emb for _, emb in batch_embeddings])

            # Small delay between batches to avoid rate limiting
            if num_batches > 1 and current_batch_num < num_batches:
                time.sleep(0.5)

        logger.success(f" 🧠 Generated {len(all_embeddings)} embeddings total.")
        return all_embeddings

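
# Note: the accessor below lazily creates one shared client per process. The
# check-then-set is not guarded by a lock, so if your application constructs
# clients from multiple threads at startup, wrap the first call in a lock.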
def get_embeddings_client() -> EmbeddingsClient:
    """
    Get or create embeddings client (singleton pattern).

    Returns:
        Shared EmbeddingsClient instance

    Examples:
        >>> client = get_embeddings_client()
        >>> embedding = client.get_embedding("test")
    """
    if not hasattr(get_embeddings_client, '_instance'):
        get_embeddings_client._instance = EmbeddingsClient()
    return get_embeddings_client._instance
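

# Minimal smoke-test sketch (assumption: a valid API key, base URL, and
# embedding model are available via get_settings(); the sample texts and the
# printed fields are illustrative only, not part of the module's API).
if __name__ == "__main__":
    client = get_embeddings_client()
    try:
        # Single embedding: second identical call should hit the disk cache.
        vec = client.get_embedding("Hello world")
        print(f"single embedding: dim={len(vec)}")

        # Batch path: one API call for all uncached texts, order preserved.
        vecs = client.get_embeddings_batch(["Hello", "World", "!"])
        print(f"batch embeddings: {len(vecs)} vectors")
    finally:
        client.close()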