File size: 2,294 Bytes
5374a2d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
from typing import Any, Dict, Optional

from evoagentx.core.logging import logger

from .base import EmbeddingProvider, BaseEmbeddingWrapper
from .openai_embedding import OpenAIEmbeddingWrapper
from .azure_openai_embedding import AzureOpenAIEmbeddingWrapper
from .huggingface_embedding import HuggingFaceEmbeddingWrapper
from .ollama_embedding import OllamaEmbeddingWrapper
from .voyage import VoyageEmbeddingWrapper
# Public API of this package. Every entry must name an object that actually
# exists in this module: `from <package> import *` looks up each name and
# raises AttributeError for any that is missing. The base class imported
# above is `BaseEmbeddingWrapper` (there is no `BaseEmbedding` here).
__all__ = [
    'OpenAIEmbeddingWrapper',
    'AzureOpenAIEmbeddingWrapper',
    'HuggingFaceEmbeddingWrapper',
    'OllamaEmbeddingWrapper',
    'VoyageEmbeddingWrapper',
    'EmbeddingFactory',
    'BaseEmbeddingWrapper',
    'EmbeddingProvider',
]
class EmbeddingFactory:
    """Factory for creating embedding models based on configuration."""

    def create(
        self,
        provider: EmbeddingProvider,
        model_config: Optional[Dict[str, Any]] = None,
    ) -> BaseEmbeddingWrapper:
        """Create an embedding model based on the provider and configuration.

        Args:
            provider (EmbeddingProvider): The embedding provider (e.g., OpenAI,
                HuggingFace, Ollama). It selects which wrapper class is built.
            model_config (Dict[str, Any], optional): Keyword arguments forwarded
                to the selected wrapper's constructor. A ``"provider"`` key, if
                present, is ignored because ``provider`` already decides the
                wrapper. The caller's mapping is never modified.

        Returns:
            BaseEmbeddingWrapper: A LlamaIndex-compatible embedding model wrapper.

        Raises:
            ValueError: If ``provider`` is not one of the supported providers.
                Exceptions raised by a wrapper constructor (e.g. for an invalid
                configuration) propagate unchanged.
        """
        # Copy so the caller's config is left intact and can be reused across
        # calls. Pop "provider" with a default: an unconditional pop raised
        # KeyError for configs lacking the key, including the default None.
        config = dict(model_config or {})
        config.pop("provider", None)

        # Compare with == rather than a dict lookup so that plain values equal
        # to an EmbeddingProvider member still match (Enum members hash by
        # name, so a hash-based lookup could miss them).
        if provider == EmbeddingProvider.OPENAI:
            wrapper = OpenAIEmbeddingWrapper(**config)
        elif provider == EmbeddingProvider.AZURE_OPENAI:
            wrapper = AzureOpenAIEmbeddingWrapper(**config)
        elif provider == EmbeddingProvider.HUGGINGFACE:
            wrapper = HuggingFaceEmbeddingWrapper(**config)
        elif provider == EmbeddingProvider.OLLAMA:
            wrapper = OllamaEmbeddingWrapper(**config)
        elif provider == EmbeddingProvider.VOYAGE:
            wrapper = VoyageEmbeddingWrapper(**config)
        else:
            raise ValueError(f"Unsupported embedding provider: {provider}")

        logger.info(f"Created embedding model for provider: {provider}")
        return wrapper