File size: 2,205 Bytes
23a5cce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
"""
Embedding Provider Factory - Dynamic Provider Creation
"""

from typing import Dict, Any, Optional
from warbler_cda.embeddings.base_provider import EmbeddingProvider
from warbler_cda.embeddings.local_provider import LocalEmbeddingProvider
from warbler_cda.embeddings.openai_provider import OpenAIEmbeddingProvider
from warbler_cda.embeddings.sentence_transformer_provider import SentenceTransformerEmbeddingProvider


class EmbeddingProviderFactory:
    """Factory that builds embedding providers from a short type name.

    The mapping from type name to provider class lives in ``PROVIDERS``;
    every classmethod below resolves names through that mapping.
    """

    # Registry of provider type name -> provider class.
    PROVIDERS = {
        "local": LocalEmbeddingProvider,
        "openai": OpenAIEmbeddingProvider,
        "sentence_transformer": SentenceTransformerEmbeddingProvider,
    }

    @classmethod
    def create_provider(cls, provider_type: str, config: Optional[Dict[str, Any]] = None) -> EmbeddingProvider:
        """Create an embedding provider of the specified type.

        Args:
            provider_type: Key into ``PROVIDERS`` selecting the provider class.
            config: Optional configuration, passed unchanged as the single
                positional argument to the provider class.

        Returns:
            The object produced by calling the selected provider class.

        Raises:
            ValueError: If ``provider_type`` is not a key of ``PROVIDERS``.
                The message lists the available type names.
        """
        if provider_type not in cls.PROVIDERS:
            available = list(cls.PROVIDERS.keys())
            raise ValueError(f"Unknown provider type '{provider_type}'. Available: {available}")

        provider_class = cls.PROVIDERS[provider_type]
        return provider_class(config)

    @classmethod
    def get_default_provider(cls, config: Optional[Dict[str, Any]] = None) -> EmbeddingProvider:
        """Get the default embedding provider (SentenceTransformer with fallback to local).

        The "sentence_transformer" provider is tried first. If creating it
        raises ``ImportError``, a warning is logged and the "local" provider
        is created with the same ``config`` instead. Any other exception
        propagates to the caller.

        Args:
            config: Optional configuration forwarded to ``create_provider``.

        Returns:
            The "sentence_transformer" provider, or the "local" provider when
            the former fails with ``ImportError``.
        """
        try:
            return cls.create_provider("sentence_transformer", config)
        except ImportError:
            # Report through logging rather than print() so the host
            # application decides where (and whether) the warning is shown;
            # stdout stays clean for programs that use it for real output.
            import logging

            logging.getLogger(__name__).warning(
                "Warning: SentenceTransformer not available, falling back to LocalEmbeddingProvider"
            )
            return cls.create_provider("local", config)

    @classmethod
    def list_available_providers(cls) -> list[str]:
        """List all available provider types.

        Returns:
            A new list holding the keys of ``PROVIDERS``.
        """
        return list(cls.PROVIDERS.keys())

    @classmethod
    def create_from_config(cls, full_config: Dict[str, Any]) -> EmbeddingProvider:
        """Create provider from configuration dict.

        Args:
            full_config: Mapping read for two keys: ``"provider"`` (the
                provider type name, defaulting to ``"local"`` when absent)
                and ``"config"`` (the provider configuration, defaulting to
                an empty dict when absent).

        Returns:
            The provider built by ``create_provider`` from those two values.

        Raises:
            ValueError: If the resolved provider type is not registered
                (raised by ``create_provider``).
        """
        provider_type = full_config.get("provider", "local")
        provider_config = full_config.get("config", {})

        return cls.create_provider(provider_type, provider_config)