linuztx commited on
Commit
7ca5e78
·
1 Parent(s): 8a7cc1d

refactor: Extract model provider config to YAML file

Browse files
conf/model_providers.yaml ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Supported model providers for Agent Zero
# ---------------------------------------
# Each entry must contain:
#   id   – identifier used in settings (lower-case, no spaces)
#   name – human readable name
# Optional extra parameters are accepted (api_base, kwargs …)
# Chat-capable and embedding-capable providers are listed separately as not every
# provider exposes both kinds of models.

chat:
  - id: anthropic
    name: Anthropic
  - id: deepseek
    name: DeepSeek
  - id: gemini
    name: Google
  - id: groq
    name: Groq
  - id: huggingface
    name: HuggingFace
  - id: lm_studio
    name: LM Studio
  - id: mistral
    name: Mistral AI
  - id: ollama
    name: Ollama
  - id: openai
    name: OpenAI
  - id: azure
    name: OpenAI Azure
  - id: openrouter
    name: OpenRouter
  - id: sambanova
    name: Sambanova
  - id: venice
    name: Venice
    api_base: https://api.venice.ai/api/v1
  - id: other
    name: Other OpenAI compatible

embedding:
  - id: huggingface
    name: HuggingFace
  - id: mistral
    name: Mistral AI
  - id: lm_studio
    name: LM Studio
  - id: ollama
    name: Ollama
  - id: openai
    name: OpenAI
  - id: azure
    name: OpenAI Azure
  - id: other
    name: Other OpenAI compatible
models.py CHANGED
@@ -19,6 +19,7 @@ import litellm
19
 
20
  from python.helpers import dotenv
21
  from python.helpers.dotenv import load_dotenv
 
22
  from python.helpers.rate_limiter import RateLimiter
23
  from python.helpers.tokens import approximate_tokens
24
 
@@ -448,6 +449,17 @@ def _adjust_call_args(provider_name: str, model_name: str, kwargs: dict):
448
  if provider_name == "other":
449
  provider_name = "openai"
450
 
 
 
 
 
 
 
 
 
 
 
 
451
  return provider_name, model_name, kwargs
452
 
453
 
@@ -465,6 +477,14 @@ def get_chat_model(
465
  provider: str, name: str, **kwargs: Any
466
  ) -> LiteLLMChatWrapper:
467
  provider_name = provider.lower()
 
 
 
 
 
 
 
 
468
  model = _get_litellm_chat(LiteLLMChatWrapper, name, provider_name, **kwargs)
469
  return model
470
 
@@ -473,6 +493,13 @@ def get_browser_model(
473
  provider: str, name: str, **kwargs: Any
474
  ) -> BrowserCompatibleChatWrapper:
475
  provider_name = provider.lower()
 
 
 
 
 
 
 
476
  model = _get_litellm_chat(
477
  BrowserCompatibleChatWrapper, name, provider_name, **kwargs
478
  )
@@ -483,5 +510,12 @@ def get_embedding_model(
483
  provider: str, name: str, **kwargs: Any
484
  ) -> LiteLLMEmbeddingWrapper | LocalSentenceTransformerWrapper:
485
  provider_name = provider.lower()
 
 
 
 
 
 
 
486
  model = _get_litellm_embedding(name, provider_name, **kwargs)
487
  return model
 
19
 
20
  from python.helpers import dotenv
21
  from python.helpers.dotenv import load_dotenv
22
+ from python.helpers.providers import get_provider_config
23
  from python.helpers.rate_limiter import RateLimiter
24
  from python.helpers.tokens import approximate_tokens
25
 
 
449
  if provider_name == "other":
450
  provider_name = "openai"
451
 
452
+ # Treat unknown providers that expose a custom OpenAI-compatible endpoint
453
+ # (i.e. they pass an `api_base` URL) as generic OpenAI providers so that
454
+ # LiteLLM can route the call correctly. This keeps dedicated providers
455
+ # such as Azure and OpenRouter unchanged.
456
+ if kwargs.get("api_base") and provider_name not in (
457
+ "openai",
458
+ "azure",
459
+ "openrouter",
460
+ ):
461
+ provider_name = "openai"
462
+
463
  return provider_name, model_name, kwargs
464
 
465
 
 
477
  provider: str, name: str, **kwargs: Any
478
  ) -> LiteLLMChatWrapper:
479
  provider_name = provider.lower()
480
+
481
+ # Merge provider-specific defaults from configuration file
482
+ cfg = get_provider_config("chat", provider_name)
483
+ if cfg:
484
+ extra = {k: v for k, v in cfg.items() if k not in ("id", "name", "value")}
485
+ for k, v in extra.items():
486
+ kwargs.setdefault(k, v)
487
+
488
  model = _get_litellm_chat(LiteLLMChatWrapper, name, provider_name, **kwargs)
489
  return model
490
 
 
493
  provider: str, name: str, **kwargs: Any
494
  ) -> BrowserCompatibleChatWrapper:
495
  provider_name = provider.lower()
496
+
497
+ cfg = get_provider_config("chat", provider_name)
498
+ if cfg:
499
+ extra = {k: v for k, v in cfg.items() if k not in ("id", "name", "value")}
500
+ for k, v in extra.items():
501
+ kwargs.setdefault(k, v)
502
+
503
  model = _get_litellm_chat(
504
  BrowserCompatibleChatWrapper, name, provider_name, **kwargs
505
  )
 
510
  provider: str, name: str, **kwargs: Any
511
  ) -> LiteLLMEmbeddingWrapper | LocalSentenceTransformerWrapper:
512
  provider_name = provider.lower()
513
+
514
+ cfg = get_provider_config("embedding", provider_name)
515
+ if cfg:
516
+ extra = {k: v for k, v in cfg.items() if k not in ("id", "name", "value")}
517
+ for k, v in extra.items():
518
+ kwargs.setdefault(k, v)
519
+
520
  model = _get_litellm_embedding(name, provider_name, **kwargs)
521
  return model
python/helpers/providers.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import yaml
2
+ from python.helpers import files
3
+ from typing import List, Dict, Optional, TypedDict
4
+
5
+
6
# Shape of one select-box option handed to the settings UI
class FieldOption(TypedDict):
    """A single UI option: the stored value plus its display label."""

    value: str  # option value stored in settings
    label: str  # human readable text shown to the user
10
+
11
class ProviderManager:
    """Loads model-provider metadata from conf/model_providers.yaml.

    The YAML root is expected to be a mapping from a provider type
    (e.g. "chat", "embedding") to a list of provider dictionaries.
    Two views of the data are kept:

      _raw     - the provider dictionaries exactly as parsed
      _options - UI-friendly {"value", "label"} pairs derived from _raw

    Use ``get_instance()`` to obtain the shared instance.
    """

    _instance = None  # shared singleton, created lazily by get_instance()
    _raw: Optional[Dict[str, List[Dict[str, str]]]] = None  # full provider data
    _options: Optional[Dict[str, List[FieldOption]]] = None  # UI-friendly list

    @classmethod
    def get_instance(cls):
        """Return the shared ProviderManager, creating it on first use."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def __init__(self):
        # Only parse the file when no cached data is visible yet.
        if self._raw is None or self._options is None:
            self._load_providers()

    def _load_providers(self):
        """Loads provider configurations from the YAML file.

        Any failure (missing or unreadable file, malformed YAML, or an
        unexpected top-level type such as a list) degrades to an empty
        configuration instead of raising, so the application still starts.
        """
        try:
            config_path = files.get_abs_path("conf/model_providers.yaml")
            with open(config_path, "r", encoding="utf-8") as f:
                data = yaml.safe_load(f)
        except (OSError, yaml.YAMLError):
            # OSError covers FileNotFoundError, PermissionError, etc.
            data = None

        # safe_load may return any YAML value (None, list, scalar);
        # only a mapping is usable here — anything else means "no config".
        self._raw = data if isinstance(data, dict) else {}

        # Build UI option lists (value / label) from raw data
        self._options = {}
        for p_type, providers in (self._raw or {}).items():
            opts: List[FieldOption] = []
            for p in providers or []:
                # Skip malformed entries (e.g. a bare string instead of a map)
                if not isinstance(p, dict):
                    continue
                pid = (p.get("id") or p.get("value") or "").upper()
                name = p.get("name") or p.get("label") or pid
                if pid:
                    opts.append({"value": pid, "label": name})
            self._options[p_type] = opts

    def get_providers(self, provider_type: str) -> List[FieldOption]:
        """Returns a list of providers for a given type (e.g., 'chat', 'embedding')."""
        return self._options.get(provider_type, []) if self._options else []

    def get_raw_providers(self, provider_type: str) -> List[Dict[str, str]]:
        """Return raw provider dictionaries for advanced use-cases."""
        return self._raw.get(provider_type, []) if self._raw else []

    def get_provider_config(self, provider_type: str, provider_id: str) -> Optional[Dict[str, str]]:
        """Return the metadata dict for a single provider id (case-insensitive)."""
        provider_id_low = provider_id.lower()
        for p in self.get_raw_providers(provider_type):
            if (p.get("id") or p.get("value", "")).lower() == provider_id_low:
                return p
        return None
62
+
63
+
64
def get_providers(provider_type: str) -> List[FieldOption]:
    """Module-level shortcut: UI options for providers of *provider_type*."""
    manager = ProviderManager.get_instance()
    return manager.get_providers(provider_type)
67
+
68
+
69
def get_raw_providers(provider_type: str) -> List[Dict[str, str]]:
    """Module-level shortcut: full metadata dicts for *provider_type* providers."""
    manager = ProviderManager.get_instance()
    return manager.get_raw_providers(provider_type)
72
+
73
+
74
def get_provider_config(provider_type: str, provider_id: str) -> Optional[Dict[str, str]]:
    """Module-level shortcut: metadata for one provider id, or None if unknown."""
    manager = ProviderManager.get_instance()
    return manager.get_provider_config(provider_type, provider_id)
python/helpers/settings.py CHANGED
@@ -4,12 +4,13 @@ import json
4
  import os
5
  import re
6
  import subprocess
7
- from typing import Any, Literal, TypedDict
8
 
9
  import models
10
  from python.helpers import runtime, whisper, defer, git
11
  from . import files, dotenv
12
  from python.helpers.print_style import PrintStyle
 
13
 
14
 
15
  class Settings(TypedDict):
@@ -121,22 +122,6 @@ PASSWORD_PLACEHOLDER = "****PSWD****"
121
  SETTINGS_FILE = files.get_abs_path("tmp/settings.json")
122
  _settings: Settings | None = None
123
 
124
- # TODO: this is temporary, will be replaced by a proper solution
125
- PROVIDERS: list[FieldOption] = [
126
- {"value": "ANTHROPIC", "label": "Anthropic"},
127
- {"value": "DEEPSEEK", "label": "DeepSeek"},
128
- {"value": "GEMINI", "label": "Google"},
129
- {"value": "GROQ", "label": "Groq"},
130
- {"value": "HUGGINGFACE", "label": "HuggingFace"},
131
- {"value": "LM_STUDIO", "label": "LM Studio"},
132
- {"value": "MISTRAL", "label": "Mistral AI"},
133
- {"value": "OLLAMA", "label": "Ollama"},
134
- {"value": "OPENAI", "label": "OpenAI"},
135
- {"value": "AZURE", "label": "OpenAI Azure"},
136
- {"value": "OPENROUTER", "label": "OpenRouter"},
137
- {"value": "SAMBANOVA", "label": "Sambanova"},
138
- {"value": "OTHER", "label": "Other OpenAI compatible"},
139
- ]
140
 
141
 
142
  def convert_out(settings: Settings) -> SettingsOutput:
@@ -151,7 +136,7 @@ def convert_out(settings: Settings) -> SettingsOutput:
151
  "description": "Select provider for main chat model used by Agent Zero",
152
  "type": "select",
153
  "value": settings["chat_model_provider"],
154
- "options": PROVIDERS,
155
  }
156
  )
157
  chat_model_fields.append(
@@ -264,7 +249,7 @@ def convert_out(settings: Settings) -> SettingsOutput:
264
  "description": "Select provider for utility model used by the framework",
265
  "type": "select",
266
  "value": settings["util_model_provider"],
267
- "options": PROVIDERS,
268
  }
269
  )
270
  util_model_fields.append(
@@ -344,7 +329,7 @@ def convert_out(settings: Settings) -> SettingsOutput:
344
  "description": "Select provider for embedding model used by the framework",
345
  "type": "select",
346
  "value": settings["embed_model_provider"],
347
- "options": PROVIDERS,
348
  }
349
  )
350
  embed_model_fields.append(
@@ -414,7 +399,7 @@ def convert_out(settings: Settings) -> SettingsOutput:
414
  "description": "Select provider for web browser model used by <a href='https://github.com/browser-use/browser-use' target='_blank'>browser-use</a> framework",
415
  "type": "select",
416
  "value": settings["browser_model_provider"],
417
- "options": PROVIDERS,
418
  }
419
  )
420
  browser_model_fields.append(
@@ -515,7 +500,7 @@ def convert_out(settings: Settings) -> SettingsOutput:
515
  # api keys model section
516
  api_keys_fields: list[SettingsField] = []
517
 
518
- for provider in PROVIDERS:
519
  api_keys_fields.append(_get_api_key_field(settings, provider["value"].lower(), provider["label"]))
520
 
521
  api_keys_section: SettingsSection = {
 
4
  import os
5
  import re
6
  import subprocess
7
+ from typing import Any, Literal, TypedDict, cast
8
 
9
  import models
10
  from python.helpers import runtime, whisper, defer, git
11
  from . import files, dotenv
12
  from python.helpers.print_style import PrintStyle
13
+ from python.helpers.providers import get_providers
14
 
15
 
16
  class Settings(TypedDict):
 
122
  SETTINGS_FILE = files.get_abs_path("tmp/settings.json")
123
  _settings: Settings | None = None
124
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
 
127
  def convert_out(settings: Settings) -> SettingsOutput:
 
136
  "description": "Select provider for main chat model used by Agent Zero",
137
  "type": "select",
138
  "value": settings["chat_model_provider"],
139
+ "options": cast(list[FieldOption], get_providers("chat")),
140
  }
141
  )
142
  chat_model_fields.append(
 
249
  "description": "Select provider for utility model used by the framework",
250
  "type": "select",
251
  "value": settings["util_model_provider"],
252
+ "options": cast(list[FieldOption], get_providers("chat")),
253
  }
254
  )
255
  util_model_fields.append(
 
329
  "description": "Select provider for embedding model used by the framework",
330
  "type": "select",
331
  "value": settings["embed_model_provider"],
332
+ "options": cast(list[FieldOption], get_providers("embedding")),
333
  }
334
  )
335
  embed_model_fields.append(
 
399
  "description": "Select provider for web browser model used by <a href='https://github.com/browser-use/browser-use' target='_blank'>browser-use</a> framework",
400
  "type": "select",
401
  "value": settings["browser_model_provider"],
402
+ "options": cast(list[FieldOption], get_providers("chat")),
403
  }
404
  )
405
  browser_model_fields.append(
 
500
  # api keys model section
501
  api_keys_fields: list[SettingsField] = []
502
 
503
+ for provider in get_providers("chat"):
504
  api_keys_fields.append(_get_api_key_field(settings, provider["value"].lower(), provider["label"]))
505
 
506
  api_keys_section: SettingsSection = {