VibecoderMcSwaggins commited on
Commit
32fc7aa
·
unverified ·
1 Parent(s): f295ef3

refactor(factory): Implement Provider Registry Pattern (Priority 6) (#128)

Browse files

Priority 6: Provider Registry Pattern

- Implemented Strategy Pattern in src/clients/registry.py
- Decoupled client creation from factory
- Added HuggingFaceProvider and OpenAIProvider

✅ All checks passed

src/clients/factory.py CHANGED
@@ -4,13 +4,19 @@ from typing import Any
4
 
5
  import structlog
6
  from agent_framework import BaseChatClient
7
- from agent_framework.openai import OpenAIChatClient
8
 
9
- from src.clients.huggingface import HuggingFaceChatClient
 
10
  from src.utils.config import settings
11
 
12
  logger = structlog.get_logger()
13
 
 
 
 
 
 
 
14
 
15
  def get_chat_client(
16
  provider: str | None = None,
@@ -21,56 +27,30 @@ def get_chat_client(
21
  """
22
  Factory for creating chat clients.
23
 
24
- Auto-detection priority:
 
 
25
  1. Explicit provider parameter
26
- 2. API key prefix detection (sk- → OpenAI, sk-ant- → Anthropic)
27
- 3. OpenAI key from env (Best Function Calling)
28
- 4. Gemini key from env (Best Context/Cost)
29
- 5. HuggingFace (Free Fallback)
30
 
31
  Args:
32
- provider: Force specific provider ("openai", "gemini", "huggingface")
33
- api_key: Override API key for the provider (auto-detects provider from prefix)
34
  model_id: Override default model ID
35
  **kwargs: Additional arguments for the client
36
 
37
  Returns:
38
- Configured BaseChatClient instance (Namespace Neutral)
39
 
40
  Raises:
41
- ValueError: If an unsupported provider is explicitly requested
42
- NotImplementedError: If Gemini or Anthropic is requested (not yet implemented)
43
  """
44
- # Normalize provider to lowercase for case-insensitive matching
45
- normalized = provider.lower() if provider is not None else None
46
-
47
- # FIX: Auto-detect provider from API key prefix when not explicitly set
48
- # This enables BYOK (Bring Your Own Key) from Gradio without explicit provider
49
- # Order matters: "sk-ant-" must be checked before "sk-" (both start with "sk-")
50
- if normalized is None and api_key:
51
- if api_key.startswith("sk-"):
52
- normalized = "openai"
53
- # HF tokens start with "hf_" - no auto-detection needed (falls through to default)
54
-
55
- # Validate explicit provider requests early
56
- valid_providers = (None, "openai", "huggingface")
57
- if normalized not in valid_providers:
58
- raise ValueError(f"Unsupported provider: {provider!r}")
59
-
60
- # 1. OpenAI (Standard / Paid Tier)
61
- if normalized == "openai" or (normalized is None and settings.has_openai_key):
62
- logger.info("Using OpenAI Chat Client")
63
- return OpenAIChatClient(
64
- model_id=model_id or settings.openai_model,
65
- api_key=api_key or settings.openai_api_key,
66
- **kwargs,
67
- )
68
-
69
- # 4. HuggingFace (Free Fallback)
70
- # This is the default if no other keys are present
71
- logger.info("Using HuggingFace Chat Client (Free Tier)")
72
- return HuggingFaceChatClient(
73
- model_id=model_id or settings.huggingface_model,
74
- api_key=api_key or settings.hf_token,
75
  **kwargs,
76
  )
 
4
 
5
  import structlog
6
  from agent_framework import BaseChatClient
 
7
 
8
+ from src.clients.providers import HuggingFaceProvider, OpenAIProvider
9
+ from src.clients.registry import ProviderRegistry
10
  from src.utils.config import settings
11
 
12
  logger = structlog.get_logger()
13
 
14
+ # Register strategies in order of priority
15
+ # 1. OpenAI (Specific key or Env)
16
+ ProviderRegistry.register(OpenAIProvider())
17
+ # 2. HuggingFace (Free Fallback)
18
+ ProviderRegistry.register(HuggingFaceProvider())
19
+
20
 
21
  def get_chat_client(
22
  provider: str | None = None,
 
27
  """
28
  Factory for creating chat clients.
29
 
30
+ Delegates to ProviderRegistry for strategy selection.
31
+
32
+ Auto-detection priority (via Registry):
33
  1. Explicit provider parameter
34
+ 2. API key prefix detection (sk- → OpenAI)
35
+ 3. OpenAI key from env
36
+ 4. HuggingFace (Free Fallback)
 
37
 
38
  Args:
39
+ provider: Force specific provider ("openai", "huggingface")
40
+ api_key: Override API key
41
  model_id: Override default model ID
42
  **kwargs: Additional arguments for the client
43
 
44
  Returns:
45
+ Configured BaseChatClient instance
46
 
47
  Raises:
48
+ ValueError: If an unsupported provider is requested
 
49
  """
50
+ return ProviderRegistry.get_client(
51
+ settings=settings,
52
+ provider=provider,
53
+ api_key=api_key,
54
+ model_id=model_id,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  **kwargs,
56
  )
src/clients/providers.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LLM Client Provider Strategies.
3
+ """
4
+
5
+ from typing import Any
6
+
7
+ from agent_framework import BaseChatClient
8
+ from agent_framework.openai import OpenAIChatClient
9
+
10
+ from src.clients.huggingface import HuggingFaceChatClient
11
+ from src.clients.registry import ClientProvider
12
+ from src.utils.config import Settings
13
+
14
+
15
+ class OpenAIProvider(ClientProvider):
16
+ """Strategy for OpenAI client creation."""
17
+
18
+ @property
19
+ def name(self) -> str:
20
+ return "OpenAI"
21
+
22
+ def can_handle(
23
+ self, provider_name: str | None, api_key: str | None, settings: Settings
24
+ ) -> bool:
25
+ # 1. Explicit request
26
+ if provider_name == "openai":
27
+ return True
28
+
29
+ # 2. BYOK Detection (sk-...)
30
+ if provider_name is None and api_key and api_key.startswith("sk-"):
31
+ return True
32
+
33
+ # 3. Env Fallback (if no explicit provider)
34
+ if provider_name is None and settings.has_openai_key:
35
+ return True
36
+
37
+ return False
38
+
39
+ def create(
40
+ self,
41
+ settings: Settings,
42
+ api_key: str | None = None,
43
+ model_id: str | None = None,
44
+ **kwargs: Any,
45
+ ) -> BaseChatClient:
46
+ return OpenAIChatClient(
47
+ model_id=model_id or settings.openai_model,
48
+ api_key=api_key or settings.openai_api_key,
49
+ **kwargs,
50
+ )
51
+
52
+
53
+ class HuggingFaceProvider(ClientProvider):
54
+ """Strategy for HuggingFace client creation (Free Tier Fallback)."""
55
+
56
+ @property
57
+ def name(self) -> str:
58
+ return "HuggingFace"
59
+
60
+ def can_handle(
61
+ self, provider_name: str | None, api_key: str | None, settings: Settings
62
+ ) -> bool:
63
+ # 1. Explicit request
64
+ if provider_name == "huggingface":
65
+ return True
66
+
67
+ # 2. Fallback (Default) - only if NO specific provider requested
68
+ if provider_name is None:
69
+ return True
70
+
71
+ return False
72
+
73
+ def create(
74
+ self,
75
+ settings: Settings,
76
+ api_key: str | None = None,
77
+ model_id: str | None = None,
78
+ **kwargs: Any,
79
+ ) -> BaseChatClient:
80
+ return HuggingFaceChatClient(
81
+ model_id=model_id or settings.huggingface_model,
82
+ api_key=api_key or settings.hf_token,
83
+ **kwargs,
84
+ )
src/clients/registry.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Client Provider Registry for unified provider selection.
3
+
4
+ Implements the Strategy Pattern to decouple client creation from the factory.
5
+ """
6
+
7
+ from typing import Any, ClassVar, Protocol
8
+
9
+ import structlog
10
+ from agent_framework import BaseChatClient
11
+
12
+ from src.utils.config import Settings
13
+
14
+ logger = structlog.get_logger()
15
+
16
+
17
+ class ClientProvider(Protocol):
18
+ """Protocol for LLM client providers."""
19
+
20
+ @property
21
+ def name(self) -> str:
22
+ """Provider name (e.g., 'openai', 'huggingface')."""
23
+ ...
24
+
25
+ def can_handle(
26
+ self, provider_name: str | None, api_key: str | None, settings: Settings
27
+ ) -> bool:
28
+ """Determine if this provider should handle the request."""
29
+ ...
30
+
31
+ def create(
32
+ self,
33
+ settings: Settings,
34
+ api_key: str | None = None,
35
+ model_id: str | None = None,
36
+ **kwargs: Any,
37
+ ) -> BaseChatClient:
38
+ """Create the client instance."""
39
+ ...
40
+
41
+
42
+ class ProviderRegistry:
43
+ """Registry for managing available LLM providers."""
44
+
45
+ _providers: ClassVar[list[ClientProvider]] = []
46
+
47
+ @classmethod
48
+ def register(cls, provider: ClientProvider) -> None:
49
+ """Register a new provider strategy."""
50
+ cls._providers.append(provider)
51
+
52
+ @classmethod
53
+ def clear(cls) -> None:
54
+ """Clear all registered providers (useful for testing)."""
55
+ cls._providers.clear()
56
+
57
+ @classmethod
58
+ def get_client(
59
+ cls,
60
+ settings: Settings,
61
+ provider: str | None = None,
62
+ api_key: str | None = None,
63
+ model_id: str | None = None,
64
+ **kwargs: Any,
65
+ ) -> BaseChatClient:
66
+ """
67
+ Find and execute the appropriate provider strategy.
68
+
69
+ Args:
70
+ settings: Application settings
71
+ provider: Explicit provider name
72
+ api_key: Optional API key
73
+ model_id: Optional model ID
74
+ **kwargs: Additional arguments for the client
75
+
76
+ Returns:
77
+ Configured BaseChatClient
78
+
79
+ Raises:
80
+ ValueError: If no provider can handle the request
81
+ """
82
+ # Normalize provider name
83
+ normalized_provider = provider.lower() if provider else None
84
+
85
+ for p in cls._providers:
86
+ if p.can_handle(normalized_provider, api_key, settings):
87
+ logger.info(f"Using {p.name} Chat Client")
88
+ return p.create(settings, api_key, model_id, **kwargs)
89
+
90
+ raise ValueError(f"No suitable provider found for provider={provider}")
tests/unit/clients/test_chat_client_factory.py CHANGED
@@ -79,7 +79,7 @@ class TestChatClientFactory:
79
 
80
  from src.clients.factory import get_chat_client
81
 
82
- with pytest.raises(ValueError, match="Unsupported provider"):
83
  get_chat_client(provider="invalid_provider")
84
 
85
  def test_byok_auto_detects_openai_from_key_prefix(self) -> None:
 
79
 
80
  from src.clients.factory import get_chat_client
81
 
82
+ with pytest.raises(ValueError, match="No suitable provider found"):
83
  get_chat_client(provider="invalid_provider")
84
 
85
  def test_byok_auto_detects_openai_from_key_prefix(self) -> None: