File size: 2,561 Bytes
32fc7aa e85ccf5 32fc7aa e85ccf5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
"""
Client Provider Registry for unified provider selection.
Implements the Strategy Pattern to decouple client creation from the factory.
"""
from typing import Any, ClassVar, Protocol
import structlog
from agent_framework import BaseChatClient
from src.utils.config import Settings
from src.utils.exceptions import ConfigurationError
# Module-level structlog logger; ProviderRegistry.get_client uses it to
# report which provider strategy was selected.
logger = structlog.get_logger()
class ClientProvider(Protocol):
    """Structural interface for a strategy that builds one kind of LLM chat client.

    Any object exposing ``name``, ``can_handle`` and ``create`` with these
    signatures satisfies the protocol; no inheritance is required.
    """

    @property
    def name(self) -> str:
        """Short identifier of the provider, such as 'openai' or 'huggingface'."""
        ...

    def can_handle(
        self,
        provider_name: str | None,
        api_key: str | None,
        settings: Settings,
    ) -> bool:
        """Decide whether this strategy should serve the current request.

        Args:
            provider_name: Requested provider name, or None when the caller
                did not name one. ``ProviderRegistry.get_client`` lower-cases
                it before calling.
            api_key: Optional API key supplied by the caller.
            settings: Application settings.

        Returns:
            True if this strategy should build the client.
        """
        ...

    def create(
        self,
        settings: Settings,
        api_key: str | None = None,
        model_id: str | None = None,
        **kwargs: Any,
    ) -> BaseChatClient:
        """Construct and return a configured chat client.

        Args:
            settings: Application settings.
            api_key: Optional API key override.
            model_id: Optional model identifier override.
            **kwargs: Extra client-specific options.

        Returns:
            The chat client instance.
        """
        ...
class ProviderRegistry:
    """Ordered registry of ``ClientProvider`` strategies.

    Providers are consulted in registration order; the first one whose
    ``can_handle`` returns True builds the client. The provider list is a
    class attribute, so registrations are global to the process and shared
    with any subclass that does not override ``_providers``.
    """

    # Class-level (shared) list; position in the list is selection priority.
    _providers: ClassVar[list[ClientProvider]] = []

    @classmethod
    def register(cls, provider: ClientProvider) -> None:
        """Append a provider strategy to the registry.

        Providers registered earlier take precedence in ``get_client``.

        Args:
            provider: Strategy object satisfying ``ClientProvider``.
        """
        cls._providers.append(provider)

    @classmethod
    def clear(cls) -> None:
        """Remove all registered providers (useful for testing)."""
        cls._providers.clear()

    @classmethod
    def get_client(
        cls,
        settings: Settings,
        provider: str | None = None,
        api_key: str | None = None,
        model_id: str | None = None,
        **kwargs: Any,
    ) -> BaseChatClient:
        """
        Find and execute the appropriate provider strategy.

        The first registered provider whose ``can_handle`` accepts the request
        builds the client.

        Args:
            settings: Application settings
            provider: Explicit provider name. It is lower-cased before being
                passed to ``can_handle``. An empty string or None means
                "not specified".
            api_key: Optional API key
            model_id: Optional model ID
            **kwargs: Additional arguments for the client

        Returns:
            Configured BaseChatClient

        Raises:
            ConfigurationError: If no registered provider can handle the request
        """
        # Normalize provider name so strategies can compare case-insensitively.
        normalized_provider = provider.lower() if provider else None
        for candidate in cls._providers:
            if candidate.can_handle(normalized_provider, api_key, settings):
                logger.info(f"Using {candidate.name} Chat Client")
                # Pass overrides by keyword so strategies depend on the
                # parameter names in the ClientProvider.create contract, not
                # on their positional order.
                return candidate.create(
                    settings, api_key=api_key, model_id=model_id, **kwargs
                )
        raise ConfigurationError(f"No suitable provider found for provider={provider}")
|