|
|
import os |
|
|
import warnings |
|
|
from typing import List, Optional, Dict |
|
|
|
|
|
from openai import OpenAI |
|
|
from llama_index.core.embeddings import BaseEmbedding |
|
|
|
|
|
from evoagentx.core.logging import logger |
|
|
from .base import BaseEmbeddingWrapper, EmbeddingProvider, SUPPORTED_MODELS |
|
|
|
|
|
|
|
|
MODEL_DIMENSIONS = { |
|
|
"text-embedding-ada-002": 1536, |
|
|
"text-embedding-3-small": 1536, |
|
|
"text-embedding-3-large": 3072 |
|
|
} |
|
|
|
|
|
SUPPORTED_DIMENSIONS = ["text-embedding-3-small", "text-embedding-3-large",] |
|
|
|
|
|
class OpenAIEmbedding(BaseEmbedding): |
|
|
"""OpenAI embedding model compatible with LlamaIndex BaseEmbedding.""" |
|
|
|
|
|
api_key: str |
|
|
client: OpenAI = None |
|
|
base_url: str = "https://api.openai.com/v1" |
|
|
model_name: str = "text-embedding-3-small" |
|
|
embed_batch_size: int = 10 |
|
|
dimensions: Optional[int] = None |
|
|
kwargs: Optional[Dict] = {} |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
model_name: str = "text-embedding-3-small", |
|
|
api_key: str = None, |
|
|
dimensions: int = None, |
|
|
base_url: str = None, |
|
|
**kwargs |
|
|
): |
|
|
api_key = api_key or os.getenv("OPENAI_API_KEY") or "" |
|
|
super().__init__(api_key=api_key, model_name=model_name, embed_batch_size=10) |
|
|
base_url = ( |
|
|
base_url |
|
|
or os.getenv("OPENAI_API_BASE") |
|
|
or os.getenv("OPENAI_BASE_URL") |
|
|
or "https://api.openai.com/v1" |
|
|
) |
|
|
if os.environ.get("OPENAI_API_BASE"): |
|
|
warnings.warn( |
|
|
"The environment variable 'OPENAI_API_BASE' is deprecated and will be removed in the 0.1.80. " |
|
|
"Please use 'OPENAI_BASE_URL' instead.", |
|
|
DeprecationWarning, |
|
|
) |
|
|
self.base_url = base_url |
|
|
self.dimensions = dimensions |
|
|
self.kwargs = kwargs |
|
|
|
|
|
if not EmbeddingProvider.validate_model(EmbeddingProvider.OPENAI, model_name): |
|
|
raise ValueError(f"Unsupported OpenAI model: {model_name}. Supported models: {SUPPORTED_MODELS['openai']}") |
|
|
|
|
|
if dimensions is not None and model_name not in SUPPORTED_DIMENSIONS: |
|
|
logger.warning( |
|
|
f"Dimensions parameter is not supported for model {model_name}. " |
|
|
f"Only '{SUPPORTED_DIMENSIONS}' support custom dimensions. Ignoring dimensions parameter." |
|
|
) |
|
|
self.dimensions = None |
|
|
elif dimensions is None and model_name in SUPPORTED_DIMENSIONS: |
|
|
self.dimensions = dimensions or MODEL_DIMENSIONS.get(model_name) |
|
|
|
|
|
try: |
|
|
self.client = OpenAI(api_key=self.api_key, base_url=self.base_url) |
|
|
logger.debug(f"Initialized OpenAI embedding model: {model_name}") |
|
|
except Exception as e: |
|
|
logger.error(f"Failed to initialize OpenAI client: {str(e)}") |
|
|
raise |
|
|
|
|
|
def _get_query_embedding(self, query: str) -> List[float]: |
|
|
"""Get embedding for a query string.""" |
|
|
try: |
|
|
query = query.replace("\n", " ") |
|
|
response = self.client.embeddings.create( |
|
|
input=[query], |
|
|
model=self.model_name, |
|
|
dimensions=self.dimensions, |
|
|
**self.kwargs |
|
|
) |
|
|
return response.data[0].embedding |
|
|
except Exception as e: |
|
|
logger.error(f"Failed to encode query: {str(e)}") |
|
|
raise |
|
|
|
|
|
def _get_text_embedding(self, text: str) -> List[float]: |
|
|
"""Get embedding for a text string.""" |
|
|
try: |
|
|
text = text.replace("\n", " ") |
|
|
response = self.client.embeddings.create( |
|
|
input=[text], |
|
|
model=self.model_name, |
|
|
dimensions=self.dimensions, |
|
|
**self.kwargs |
|
|
) |
|
|
return response.data[0].embedding |
|
|
except Exception as e: |
|
|
logger.error(f"Failed to encode text: {str(e)}") |
|
|
raise |
|
|
|
|
|
async def _aget_query_embedding(self, query: str) -> List[float]: |
|
|
"""Asynchronous query embedding.""" |
|
|
return self._get_query_embedding(query) |
|
|
|
|
|
def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]: |
|
|
"""Get embeddings for a list of texts synchronously.""" |
|
|
try: |
|
|
texts = [text.replace("\n", " ") for text in texts] |
|
|
response = self.client.embeddings.create( |
|
|
input=texts, |
|
|
model=self.model_name, |
|
|
dimensions=self.dimensions, |
|
|
**self.kwargs |
|
|
) |
|
|
return [item.embedding for item in response.data] |
|
|
except Exception as e: |
|
|
logger.error(f"Failed to encode texts: {str(e)}") |
|
|
raise |
|
|
|
|
|
|
|
|
class OpenAIEmbeddingWrapper(BaseEmbeddingWrapper): |
|
|
"""Wrapper for OpenAI embedding models.""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
model_name: str = "text-embedding-3-small", |
|
|
api_key: str = None, |
|
|
dimensions: int = None, |
|
|
base_url: str = None, |
|
|
**kwargs |
|
|
): |
|
|
self.model_name = model_name |
|
|
self.api_key = api_key |
|
|
self._dimensions = MODEL_DIMENSIONS.get(self.model_name, None) or dimensions |
|
|
self.base_url = base_url |
|
|
self.kwargs = kwargs |
|
|
self._embedding_model = None |
|
|
self._embedding_model = self.get_embedding_model() |
|
|
|
|
|
def get_embedding_model(self) -> BaseEmbedding: |
|
|
"""Return the LlamaIndex-compatible embedding model.""" |
|
|
|
|
|
if getattr(self, "_embedding_model", None) is None: |
|
|
try: |
|
|
self._embedding_model = OpenAIEmbedding( |
|
|
model_name=self.model_name, |
|
|
api_key=self.api_key, |
|
|
dimensions=self._dimensions, |
|
|
base_url=self.base_url, |
|
|
**self.kwargs |
|
|
) |
|
|
logger.debug(f"Initialized OpenAI embedding wrapper for model: {self.model_name}") |
|
|
except Exception as e: |
|
|
logger.error(f"Failed to initialize OpenAI embedding wrapper: {str(e)}") |
|
|
raise |
|
|
return self._embedding_model |
|
|
|
|
|
@property |
|
|
def dimensions(self) -> int: |
|
|
"""Return the embedding dimensions.""" |
|
|
|
|
|
return self._dimensions |