# src/retrieval/embeddings.py
"""Embedding generation using OpenAI."""
import logging
from typing import List, Optional
from functools import lru_cache
from openai import OpenAI
from src.core.config import get_settings
logger = logging.getLogger(__name__)


class EmbeddingGenerator:
    """Generate embeddings using OpenAI."""

    def __init__(self, client: Optional[OpenAI] = None):
        """Initialize the embedding generator.

        Args:
            client: Optional pre-built OpenAI client. If omitted, one is
                constructed from the application settings.
        """
        self.settings = get_settings()
        self.client = client or OpenAI(**self.settings.get_openai_client_kwargs())
        self.model = self.settings.openai_embedding_model
        # Simple unbounded, per-instance cache keyed by the exact input text.
        self._cache: Dict[str, List[float]] = {}

    def generate_embedding(self, text: str, use_cache: bool = True) -> List[float]:
        """Generate an embedding for a single text.

        Args:
            text: Input text to embed.
            use_cache: Whether to read from and write to the in-memory cache.

        Returns:
            Embedding vector as a list of floats.
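
        Example (illustrative only; calling it requires a valid OpenAI API
        key resolvable through the project settings):
            >>> generator = EmbeddingGenerator()
            >>> vector = generator.generate_embedding("hello world")
            >>> isinstance(vector[0], float)
            True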
"""
if use_cache and text in self._cache:
return self._cache[text]
try:
response = self.client.embeddings.create(
model=self.model,
input=text,
)
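            # The API returns one embedding per input; a single string input
            # yields exactly one item in response.data.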
embedding = response.data[0].embedding
if use_cache:
self._cache[text] = embedding
return embedding
except Exception as e:
logger.error(f"Error generating embedding: {e}")
raise

    def generate_embeddings_batch(
        self, texts: List[str], use_cache: bool = True
    ) -> List[List[float]]:
        """Generate embeddings for multiple texts in a single API call.

        Args:
            texts: List of input texts to embed.
            use_cache: Whether to read from and write to the in-memory cache.

        Returns:
            List of embedding vectors, in the same order as ``texts``.
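
        Example (illustrative only; requires a valid OpenAI API key):
            >>> generator = EmbeddingGenerator()
            >>> vectors = generator.generate_embeddings_batch(["alpha", "beta"])
            >>> len(vectors)
            2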
"""
        # Check the cache first, remembering each uncached text's original index.
        cached_embeddings: Dict[int, List[float]] = {}
        texts_to_embed: List[str] = []
        indices: List[int] = []
for i, text in enumerate(texts):
if use_cache and text in self._cache:
cached_embeddings[i] = self._cache[text]
else:
texts_to_embed.append(text)
indices.append(i)
if not texts_to_embed:
# All embeddings were cached
return [cached_embeddings[i] for i in range(len(texts))]
# Generate embeddings for uncached texts
embeddings = []
try:
            # The embeddings endpoint accepts a list of inputs, so all
            # uncached texts can be embedded in one request.
response = self.client.embeddings.create(
model=self.model,
input=texts_to_embed,
)
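            # response.data preserves input order, so each embedding can be
            # mapped back to its position in the original texts list.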
            new_embeddings = {
                indices[i]: item.embedding for i, item in enumerate(response.data)
            }
# Update cache
if use_cache:
for idx, text in zip(indices, texts_to_embed):
self._cache[text] = new_embeddings[idx]
# Combine cached and new embeddings
for i in range(len(texts)):
if i in cached_embeddings:
embeddings.append(cached_embeddings[i])
else:
embeddings.append(new_embeddings[i])
return embeddings
except Exception as e:
logger.error(f"Error generating batch embeddings: {e}")
raise

    def clear_cache(self) -> None:
"""Clear the embedding cache."""
self._cache.clear()

    def get_cache_size(self) -> int:
"""Get the number of cached embeddings."""
return len(self._cache)


# Module-level singleton instance.
_embedding_generator: Optional[EmbeddingGenerator] = None


def get_embedding_generator() -> EmbeddingGenerator:
"""Get or create the global embedding generator instance."""
global _embedding_generator
if _embedding_generator is None:
_embedding_generator = EmbeddingGenerator()
return _embedding_generator
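

if __name__ == "__main__":
    # Illustrative smoke test only, not part of the module's public API.
    # Assumes a valid OpenAI API key is available through src.core.config
    # settings; embedding dimensionality depends on the configured model.
    generator = get_embedding_generator()

    vector = generator.generate_embedding("retrieval-augmented generation")
    print(f"Single embedding: {len(vector)} dimensions")

    vectors = generator.generate_embeddings_batch(
        ["first document", "second document", "first document"]
    )
    print(f"Batch of {len(vectors)} embeddings; cache size: {generator.get_cache_size()}")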