File size: 2,294 Bytes
5374a2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from typing import Dict, Any

from evoagentx.core.logging import logger
from .base import EmbeddingProvider, BaseEmbeddingWrapper
from .openai_embedding import OpenAIEmbeddingWrapper
from .azure_openai_embedding import AzureOpenAIEmbeddingWrapper
from .huggingface_embedding import HuggingFaceEmbeddingWrapper
from .ollama_embedding import OllamaEmbeddingWrapper
from .voyage import VoyageEmbeddingWrapper

# Public API of this package. Every entry must be a name actually bound in this
# module; otherwise ``from <package> import *`` fails with AttributeError.
__all__ = [
    'OpenAIEmbeddingWrapper',
    'AzureOpenAIEmbeddingWrapper',
    'HuggingFaceEmbeddingWrapper',
    'OllamaEmbeddingWrapper',
    'VoyageEmbeddingWrapper',
    'EmbeddingFactory',
    'BaseEmbeddingWrapper',
    'EmbeddingProvider'
]

class EmbeddingFactory:
    """Factory for creating embedding models based on configuration."""

    def create(
        self,
        provider: EmbeddingProvider,
        model_config: Dict[str, Any] = None
    ) -> BaseEmbeddingWrapper:
        """Create an embedding model based on the provider and configuration.

        Args:
            provider (EmbeddingProvider): The embedding provider (e.g., OpenAI, HuggingFace, Ollama).
            model_config (Dict[str, Any], optional): Keyword arguments forwarded to the
                selected wrapper's constructor. A ``"provider"`` key, if present, is
                ignored because ``provider`` already selects the wrapper. The
                caller's dict is not modified.

        Returns:
            BaseEmbeddingWrapper: A LlamaIndex-compatible embedding model wrapper.

        Raises:
            ValueError: If ``provider`` is not one of the supported providers.
                The wrapper constructors may raise their own errors for invalid
                configuration values.
        """
        # Work on a shallow copy so the caller's dict is left untouched, and drop
        # "provider" only if it is present: an unconditional pop raised KeyError
        # when the key was missing, including for the default model_config=None.
        config = dict(model_config or {})
        config.pop("provider", None)

        # An if/elif chain (rather than a dict lookup) keeps plain `==` comparison
        # semantics for `provider`.
        if provider == EmbeddingProvider.OPENAI:
            wrapper = OpenAIEmbeddingWrapper(**config)
        elif provider == EmbeddingProvider.AZURE_OPENAI:
            wrapper = AzureOpenAIEmbeddingWrapper(**config)
        elif provider == EmbeddingProvider.HUGGINGFACE:
            wrapper = HuggingFaceEmbeddingWrapper(**config)
        elif provider == EmbeddingProvider.OLLAMA:
            wrapper = OllamaEmbeddingWrapper(**config)
        elif provider == EmbeddingProvider.VOYAGE:
            wrapper = VoyageEmbeddingWrapper(**config)
        else:
            raise ValueError(f"Unsupported embedding provider: {provider}")

        logger.info(f"Created embedding model for provider: {provider}")
        return wrapper