import os
from typing import Any, Dict, List, Optional

from llama_index.core.embeddings import BaseEmbedding
from llama_index.embeddings.azure_openai import (
    AzureOpenAIEmbedding as LlamaAzureEmbedding,
)

from evoagentx.core.logging import logger
from .base import BaseEmbeddingWrapper, EmbeddingProvider, SUPPORTED_MODELS

# Vector size used for a model when the caller does not request one.
MODEL_DIMENSIONS: Dict[str, int] = {
    "text-embedding-ada-002": 1536,
    "text-embedding-3-small": 1536,
    "text-embedding-3-large": 3072,
}

# Models for which a caller-supplied ``dimensions`` value is honoured; for any
# other model the value is dropped with a warning.
SUPPORTED_DIMENSIONS = [
    "text-embedding-3-small",
    "text-embedding-3-large",
]


class AzureOpenAIEmbedding(BaseEmbedding):
    """Azure OpenAI embedding implementation compatible with LlamaIndex.

    The instance builds a ``LlamaAzureEmbedding`` client in ``__init__`` and
    forwards every embedding request to it after replacing newlines in the
    input with spaces.
    """

    api_key: Optional[str] = None
    model_name: str = "text-embedding-3-small"
    azure_endpoint: Optional[str] = None
    api_version: Optional[str] = None
    deployment_name: Optional[str] = None
    embed_batch_size: int = 10
    dimensions: Optional[int] = None
    kwargs: Dict[str, Any] = {}

    def __init__(
        self,
        model_name: str = "text-embedding-3-small",
        api_key: Optional[str] = None,
        azure_endpoint: Optional[str] = None,
        api_version: Optional[str] = None,
        deployment_name: Optional[str] = None,
        dimensions: Optional[int] = None,
        embed_batch_size: int = 10,
        **kwargs: Any,
    ) -> None:
        """Resolve configuration and create the underlying client.

        Args:
            model_name: Embedding model identifier; must pass
                ``EmbeddingProvider.validate_model`` for ``AZURE_OPENAI``.
            api_key: API key. Falls back to ``AZURE_EMBED_API_KEY``, then
                ``OPENAI_API_KEY``, then an empty string.
            azure_endpoint: Endpoint; falls back to ``AZURE_EMBED_ENDPOINT``.
            api_version: API version; falls back to
                ``AZURE_EMBED_API_VERSION``.
            deployment_name: Deployment; falls back to
                ``AZURE_EMBED_DEPLOYMENT`` and finally to ``model_name``.
            dimensions: Requested vector size. Ignored (with a warning) for
                models not listed in ``SUPPORTED_DIMENSIONS``; defaults to the
                ``MODEL_DIMENSIONS`` entry for models that are listed.
            embed_batch_size: Batch size forwarded to the base class and to
                the underlying client.
            **kwargs: Extra keyword arguments forwarded to
                ``LlamaAzureEmbedding``.

        Raises:
            ValueError: If ``model_name`` is not a supported Azure OpenAI
                model.
            Exception: Whatever ``LlamaAzureEmbedding`` raises while being
                constructed; it is logged and re-raised unchanged.
        """
        api_key = (
            api_key
            or os.getenv("AZURE_EMBED_API_KEY")
            or os.getenv("OPENAI_API_KEY")
            or ""
        )
        super().__init__(
            api_key=api_key,
            model_name=model_name,
            embed_batch_size=embed_batch_size,
        )

        self.model_name = model_name
        self.azure_endpoint = azure_endpoint or os.getenv("AZURE_EMBED_ENDPOINT")
        self.api_version = api_version or os.getenv("AZURE_EMBED_API_VERSION")
        self.deployment_name = (
            deployment_name or os.getenv("AZURE_EMBED_DEPLOYMENT") or model_name
        )
        self.dimensions = dimensions
        self.embed_batch_size = embed_batch_size
        self.kwargs = kwargs or {}

        if not EmbeddingProvider.validate_model(
            EmbeddingProvider.AZURE_OPENAI, model_name
        ):
            raise ValueError(
                "Unsupported Azure OpenAI model: "
                f"{model_name}. Supported models: {SUPPORTED_MODELS['azure_openai']}"
            )

        if self.dimensions is not None and model_name not in SUPPORTED_DIMENSIONS:
            logger.warning(
                "Dimensions parameter is not supported for model %s. Only %s support custom "
                "dimensions. Ignoring provided value.",
                model_name,
                SUPPORTED_DIMENSIONS,
            )
            self.dimensions = None
        elif self.dimensions is None and model_name in SUPPORTED_DIMENSIONS:
            self.dimensions = MODEL_DIMENSIONS.get(model_name)

        try:
            # Both ``azure_deployment`` and ``deployment_name`` are supplied
            # with the same value, as the original code did.
            self._embedding: LlamaAzureEmbedding = LlamaAzureEmbedding(
                model=self.model_name,
                azure_endpoint=self.azure_endpoint,
                azure_deployment=self.deployment_name,
                deployment_name=self.deployment_name,
                api_key=self.api_key,
                api_version=self.api_version,
                dimensions=self.dimensions,
                embed_batch_size=self.embed_batch_size,
                **self.kwargs,
            )
            # If no size was resolved above, adopt whatever the client reports.
            if self.dimensions is None and hasattr(self._embedding, "dimensions"):
                self.dimensions = getattr(self._embedding, "dimensions")
            logger.debug(
                "Initialized Azure OpenAI embedding: model=%s deployment=%s",
                self.model_name,
                self.deployment_name,
            )
        except Exception as exc:
            logger.error("Failed to initialize Azure OpenAI embedding: %s", exc)
            raise

    def _get_query_embedding(self, query: str) -> List[float]:
        """Return the embedding of ``query`` with newlines replaced by spaces.

        Any exception from the underlying client is logged and re-raised.
        """
        query = query.replace("\n", " ")
        try:
            return self._embedding.get_query_embedding(query)
        except Exception as exc:
            logger.error("Failed to encode query with Azure OpenAI: %s", exc)
            raise

    def _get_text_embedding(self, text: str) -> List[float]:
        """Return the embedding of ``text`` with newlines replaced by spaces.

        Any exception from the underlying client is logged and re-raised.
        """
        text = text.replace("\n", " ")
        try:
            return self._embedding.get_text_embedding(text)
        except Exception as exc:
            logger.error("Failed to encode text with Azure OpenAI: %s", exc)
            raise

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Asynchronously return the embedding of ``query``.

        Awaits the wrapped embedding's async entry point rather than running
        the synchronous request inside the coroutine. Newline handling and
        error logging mirror ``_get_query_embedding``.
        """
        query = query.replace("\n", " ")
        try:
            return await self._embedding.aget_query_embedding(query)
        except Exception as exc:
            logger.error("Failed to encode query with Azure OpenAI: %s", exc)
            raise

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Asynchronously return the embedding of ``text``.

        Awaits the wrapped embedding's async entry point. Newline handling
        and error logging mirror ``_get_text_embedding``.
        """
        text = text.replace("\n", " ")
        try:
            return await self._embedding.aget_text_embedding(text)
        except Exception as exc:
            logger.error("Failed to encode text with Azure OpenAI: %s", exc)
            raise

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Return embeddings for ``texts``, one vector per input string.

        Newlines in every input are replaced by spaces before the batch is
        handed to the underlying client. Exceptions are logged and re-raised.
        """
        cleaned = [text.replace("\n", " ") for text in texts]
        try:
            return self._embedding.get_text_embedding_batch(cleaned)
        except Exception as exc:
            logger.error("Failed to encode texts with Azure OpenAI: %s", exc)
            raise


class AzureOpenAIEmbeddingWrapper(BaseEmbeddingWrapper):
    """Wrapper for Azure OpenAI embedding models.

    Stores the configuration and builds the ``AzureOpenAIEmbedding`` instance
    on the first call to ``get_embedding_model``.
    """

    def __init__(
        self,
        model_name: str = "text-embedding-3-small",
        api_key: Optional[str] = None,
        azure_endpoint: Optional[str] = None,
        api_version: Optional[str] = None,
        deployment_name: Optional[str] = None,
        dimensions: Optional[int] = None,
        embed_batch_size: int = 10,
        **kwargs: Any,
    ) -> None:
        """Record the configuration; no embedding client is created yet.

        Args:
            model_name: Embedding model identifier.
            api_key: API key forwarded to ``AzureOpenAIEmbedding``.
            azure_endpoint: Endpoint forwarded to ``AzureOpenAIEmbedding``.
            api_version: API version forwarded to ``AzureOpenAIEmbedding``.
            deployment_name: Deployment name; defaults to ``model_name``.
            dimensions: Requested vector size. Takes precedence over the
                ``MODEL_DIMENSIONS`` default for models listed in
                ``SUPPORTED_DIMENSIONS``.
            embed_batch_size: Batch size forwarded to ``AzureOpenAIEmbedding``.
            **kwargs: Extra keyword arguments forwarded to
                ``AzureOpenAIEmbedding``.
        """
        self.model_name = model_name
        self.api_key = api_key
        self.azure_endpoint = azure_endpoint
        self.api_version = api_version
        self.deployment_name = deployment_name or model_name
        self.kwargs = kwargs or {}
        self.embed_batch_size = embed_batch_size

        # Keep the caller's raw request separate from the size we report, so
        # that a table default is never forwarded as if the caller asked for it.
        self._requested_dimensions = dimensions

        if dimensions is not None and model_name in SUPPORTED_DIMENSIONS:
            # An explicit request wins for models that accept a custom size.
            self._dimensions: Optional[int] = dimensions
        else:
            # Otherwise report the known size for the model, falling back to
            # the caller's value when the model is not in the table.
            self._dimensions = MODEL_DIMENSIONS.get(model_name) or dimensions

        self._embedding_model: Optional[AzureOpenAIEmbedding] = None

    def get_embedding_model(self) -> BaseEmbedding:
        """Return the embedding model, creating it on first use.

        Raises:
            Exception: Whatever ``AzureOpenAIEmbedding`` raises during
                construction; it is logged and re-raised unchanged.
        """
        if self._embedding_model is None:
            try:
                self._embedding_model = AzureOpenAIEmbedding(
                    model_name=self.model_name,
                    api_key=self.api_key,
                    azure_endpoint=self.azure_endpoint,
                    api_version=self.api_version,
                    deployment_name=self.deployment_name,
                    # Forward only what the caller asked for; the embedding
                    # class applies its own default when this is None.
                    dimensions=self._requested_dimensions,
                    embed_batch_size=self.embed_batch_size,
                    **self.kwargs,
                )
                logger.debug(
                    "Initialized Azure OpenAI embedding wrapper for model %s",
                    self.model_name,
                )
            except Exception as exc:
                logger.error("Failed to initialize Azure OpenAI embedding wrapper: %s", exc)
                raise
        return self._embedding_model

    @property
    def dimensions(self) -> Optional[int]:
        """Vector size of the embeddings.

        Prefers the value held by the built embedding model; falls back to
        the size resolved at construction time when the model has not been
        built yet or reports ``None``.
        """
        if self._embedding_model is not None:
            model_dimensions = self._embedding_model.dimensions
            if model_dimensions is not None:
                return model_dimensions
        return self._dimensions