linuztx committed on
Commit
2cf6dd3
·
1 Parent(s): 7ca5e78

refactor: Standardize model provider configuration

Browse files
conf/model_providers.yaml CHANGED
@@ -1,55 +1,90 @@
1
  # Supported model providers for Agent Zero
2
  # ---------------------------------------
3
- # Each entry must contain:
4
- # id – identifier used in settings (lower-case, no spaces)
5
- # name – human readable name
6
- # Optional extra parameters are accepted (api_base, kwargs …)
7
- # Chat-capable and embedding-capable providers are listed separately as not every
8
- # provider exposes both kinds of models.
 
 
 
 
 
 
 
 
 
9
 
10
  chat:
11
- - id: anthropic
12
  name: Anthropic
13
- - id: deepseek
 
14
  name: DeepSeek
15
- - id: gemini
 
16
  name: Google
17
- - id: groq
 
18
  name: Groq
19
- - id: huggingface
 
20
  name: HuggingFace
21
- - id: lm_studio
 
22
  name: LM Studio
23
- - id: mistral
 
24
  name: Mistral AI
25
- - id: ollama
 
26
  name: Ollama
27
- - id: openai
 
28
  name: OpenAI
29
- - id: azure
 
30
  name: OpenAI Azure
31
- - id: openrouter
 
32
  name: OpenRouter
33
- - id: sambanova
 
 
 
 
 
34
  name: Sambanova
35
- - id: venice
 
36
  name: Venice
37
- api_base: https://api.venice.ai/api/v1
38
- - id: other
 
 
39
  name: Other OpenAI compatible
 
40
 
41
  embedding:
42
- - id: huggingface
43
  name: HuggingFace
44
- - id: mistral
45
- name: Mistral AI
46
- - id: lm_studio
47
  name: LM Studio
48
- - id: ollama
 
 
 
 
49
  name: Ollama
50
- - id: openai
 
51
  name: OpenAI
52
- - id: azure
 
53
  name: OpenAI Azure
54
- - id: other
55
- name: Other OpenAI compatible
 
 
 
1
  # Supported model providers for Agent Zero
2
  # ---------------------------------------
3
+ #
4
+ # Each provider type ("chat", "embedding") contains a mapping of provider IDs
5
+ # to their configurations.
6
+ #
7
+ # The provider ID (e.g., "anthropic") is used:
8
+ # - in the settings UI dropdowns.
9
+ # - to construct the environment variable for the API key (e.g., ANTHROPIC_API_KEY).
10
+ #
11
+ # Each provider configuration requires:
12
+ # name: Human-readable name for the UI.
13
+ # litellm_provider: The corresponding provider name in LiteLLM.
14
+ #
15
+ # Optional fields:
16
+ # kwargs: A dictionary of extra parameters to pass to LiteLLM.
17
+ # This is useful for `api_base`, `extra_headers`, etc.
18
 
19
  chat:
20
+ anthropic:
21
  name: Anthropic
22
+ litellm_provider: anthropic
23
+ deepseek:
24
  name: DeepSeek
25
+ litellm_provider: deepseek
26
+ gemini:
27
  name: Google
28
+ litellm_provider: gemini
29
+ groq:
30
  name: Groq
31
+ litellm_provider: groq
32
+ huggingface:
33
  name: HuggingFace
34
+ litellm_provider: huggingface
35
+ lm_studio:
36
  name: LM Studio
37
+ litellm_provider: lm_studio
38
+ mistral:
39
  name: Mistral AI
40
+ litellm_provider: mistral
41
+ ollama:
42
  name: Ollama
43
+ litellm_provider: ollama
44
+ openai:
45
  name: OpenAI
46
+ litellm_provider: openai
47
+ azure:
48
  name: OpenAI Azure
49
+ litellm_provider: azure
50
+ openrouter:
51
  name: OpenRouter
52
+ litellm_provider: openrouter
53
+ kwargs:
54
+ extra_headers:
55
+ "HTTP-Referer": "https://agent-zero.ai/"
56
+ "X-Title": "Agent Zero"
57
+ sambanova:
58
  name: Sambanova
59
+ litellm_provider: sambanova
60
+ venice:
61
  name: Venice
62
+ litellm_provider: openai
63
+ kwargs:
64
+ api_base: https://api.venice.ai/api/v1
65
+ other:
66
  name: Other OpenAI compatible
67
+ litellm_provider: openai
68
 
69
  embedding:
70
+ huggingface:
71
  name: HuggingFace
72
+ litellm_provider: huggingface
73
+ lm_studio:
 
74
  name: LM Studio
75
+ litellm_provider: lm_studio
76
+ mistral:
77
+ name: Mistral AI
78
+ litellm_provider: mistral
79
+ ollama:
80
  name: Ollama
81
+ litellm_provider: ollama
82
+ openai:
83
  name: OpenAI
84
+ litellm_provider: openai
85
+ azure:
86
  name: OpenAI Azure
87
+ litellm_provider: azure
88
+ other:
89
+ name: Other OpenAI compatible
90
+ litellm_provider: openai
models.py CHANGED
@@ -351,6 +351,9 @@ class LocalSentenceTransformerWrapper(Embeddings):
351
  """Local wrapper for sentence-transformers models to avoid HuggingFace API calls"""
352
 
353
  def __init__(self, provider: str, model: str, **kwargs: Any):
 
 
 
354
  # Remove the "sentence-transformers/" prefix if present
355
  if model.startswith("sentence-transformers/"):
356
  model = model[len("sentence-transformers/") :]
@@ -449,20 +452,37 @@ def _adjust_call_args(provider_name: str, model_name: str, kwargs: dict):
449
  if provider_name == "other":
450
  provider_name = "openai"
451
 
452
- # Treat unknown providers that expose a custom OpenAI-compatible endpoint
453
- # (i.e. they pass an `api_base` URL) as generic OpenAI providers so that
454
- # LiteLLM can route the call correctly. This keeps dedicated providers
455
- # such as Azure and OpenRouter unchanged.
456
- if kwargs.get("api_base") and provider_name not in (
457
- "openai",
458
- "azure",
459
- "openrouter",
460
- ):
461
- provider_name = "openai"
462
-
463
  return provider_name, model_name, kwargs
464
 
465
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
466
  def get_model(type: ModelType, provider: str, name: str, **kwargs: Any):
467
  provider_name = provider.lower()
468
  if type == ModelType.CHAT:
@@ -476,46 +496,22 @@ def get_model(type: ModelType, provider: str, name: str, **kwargs: Any):
476
  def get_chat_model(
477
  provider: str, name: str, **kwargs: Any
478
  ) -> LiteLLMChatWrapper:
479
- provider_name = provider.lower()
480
-
481
- # Merge provider-specific defaults from configuration file
482
- cfg = get_provider_config("chat", provider_name)
483
- if cfg:
484
- extra = {k: v for k, v in cfg.items() if k not in ("id", "name", "value")}
485
- for k, v in extra.items():
486
- kwargs.setdefault(k, v)
487
-
488
- model = _get_litellm_chat(LiteLLMChatWrapper, name, provider_name, **kwargs)
489
- return model
490
 
491
 
492
  def get_browser_model(
493
  provider: str, name: str, **kwargs: Any
494
  ) -> BrowserCompatibleChatWrapper:
495
- provider_name = provider.lower()
496
-
497
- cfg = get_provider_config("chat", provider_name)
498
- if cfg:
499
- extra = {k: v for k, v in cfg.items() if k not in ("id", "name", "value")}
500
- for k, v in extra.items():
501
- kwargs.setdefault(k, v)
502
-
503
- model = _get_litellm_chat(
504
- BrowserCompatibleChatWrapper, name, provider_name, **kwargs
505
- )
506
- return model
507
 
508
 
509
  def get_embedding_model(
510
  provider: str, name: str, **kwargs: Any
511
  ) -> LiteLLMEmbeddingWrapper | LocalSentenceTransformerWrapper:
512
- provider_name = provider.lower()
513
-
514
- cfg = get_provider_config("embedding", provider_name)
515
- if cfg:
516
- extra = {k: v for k, v in cfg.items() if k not in ("id", "name", "value")}
517
- for k, v in extra.items():
518
- kwargs.setdefault(k, v)
519
-
520
- model = _get_litellm_embedding(name, provider_name, **kwargs)
521
- return model
 
351
  """Local wrapper for sentence-transformers models to avoid HuggingFace API calls"""
352
 
353
  def __init__(self, provider: str, model: str, **kwargs: Any):
354
+ # Clean common user-input mistakes
355
+ model = model.strip().strip('"').strip("'")
356
+
357
  # Remove the "sentence-transformers/" prefix if present
358
  if model.startswith("sentence-transformers/"):
359
  model = model[len("sentence-transformers/") :]
 
452
  if provider_name == "other":
453
  provider_name = "openai"
454
 
 
 
 
 
 
 
 
 
 
 
 
455
  return provider_name, model_name, kwargs
456
 
457
 
458
+ def _merge_provider_defaults(
459
+ provider_type: str, original_provider: str, kwargs: dict
460
+ ) -> tuple[str, dict]:
461
+ provider_name = original_provider # default: unchanged
462
+ cfg = get_provider_config(provider_type, original_provider)
463
+ if cfg:
464
+ provider_name = cfg.get("litellm_provider", original_provider).lower()
465
+
466
+ # Extra arguments nested under `kwargs` for readability
467
+ extra_kwargs = cfg.get("kwargs") if isinstance(cfg, dict) else None # type: ignore[arg-type]
468
+ if isinstance(extra_kwargs, dict):
469
+ for k, v in extra_kwargs.items():
470
+ kwargs.setdefault(k, v)
471
+
472
+ # Copy any additional top-level fields except metadata keys
473
+ for k, v in cfg.items():
474
+ if k not in ("id", "name", "value", "litellm_provider", "kwargs"):
475
+ kwargs.setdefault(k, v)
476
+
477
+ # Inject API key based on the *original* provider id if still missing
478
+ if "api_key" not in kwargs:
479
+ key = get_api_key(original_provider)
480
+ if key and key not in ("None", "NA"):
481
+ kwargs["api_key"] = key
482
+
483
+ return provider_name, kwargs
484
+
485
+
486
  def get_model(type: ModelType, provider: str, name: str, **kwargs: Any):
487
  provider_name = provider.lower()
488
  if type == ModelType.CHAT:
 
496
  def get_chat_model(
497
  provider: str, name: str, **kwargs: Any
498
  ) -> LiteLLMChatWrapper:
499
+ orig = provider.lower()
500
+ provider_name, kwargs = _merge_provider_defaults("chat", orig, kwargs)
501
+ return _get_litellm_chat(LiteLLMChatWrapper, name, provider_name, **kwargs)
 
 
 
 
 
 
 
 
502
 
503
 
504
  def get_browser_model(
505
  provider: str, name: str, **kwargs: Any
506
  ) -> BrowserCompatibleChatWrapper:
507
+ orig = provider.lower()
508
+ provider_name, kwargs = _merge_provider_defaults("chat", orig, kwargs)
509
+ return _get_litellm_chat(BrowserCompatibleChatWrapper, name, provider_name, **kwargs)
 
 
 
 
 
 
 
 
 
510
 
511
 
512
  def get_embedding_model(
513
  provider: str, name: str, **kwargs: Any
514
  ) -> LiteLLMEmbeddingWrapper | LocalSentenceTransformerWrapper:
515
+ orig = provider.lower()
516
+ provider_name, kwargs = _merge_provider_defaults("embedding", orig, kwargs)
517
+ return _get_litellm_embedding(name, provider_name, **kwargs)
 
 
 
 
 
 
 
preload.py CHANGED
@@ -21,11 +21,11 @@ async def preload():
21
 
22
  # preload embedding model
23
  async def preload_embedding():
24
- if set["embed_model_provider"] == "HuggingFace":
25
  try:
26
  # Use the new LiteLLM-based model system
27
  emb_mod = models.get_embedding_model(
28
- "HuggingFace", set["embed_model_name"]
29
  )
30
  emb_txt = await emb_mod.aembed_query("test")
31
  return emb_txt
 
21
 
22
  # preload embedding model
23
  async def preload_embedding():
24
+ if set["embed_model_provider"].lower() == "huggingface":
25
  try:
26
  # Use the new LiteLLM-based model system
27
  emb_mod = models.get_embedding_model(
28
+ "huggingface", set["embed_model_name"]
29
  )
30
  emb_txt = await emb_mod.aembed_query("test")
31
  return emb_txt
python/helpers/providers.py CHANGED
@@ -24,20 +24,45 @@ class ProviderManager:
24
  self._load_providers()
25
 
26
  def _load_providers(self):
27
- """Loads provider configurations from the YAML file."""
28
  try:
29
  config_path = files.get_abs_path("conf/model_providers.yaml")
30
  with open(config_path, "r", encoding="utf-8") as f:
31
- self._raw = yaml.safe_load(f) or {}
32
  except (FileNotFoundError, yaml.YAMLError):
33
- self._raw = {}
34
-
35
- # Build UI option lists (value / label) from raw data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  self._options = {}
37
- for p_type, providers in (self._raw or {}).items():
38
  opts: List[FieldOption] = []
39
- for p in providers or []:
40
- pid = (p.get("id") or p.get("value") or "").upper()
41
  name = p.get("name") or p.get("label") or pid
42
  if pid:
43
  opts.append({"value": pid, "label": name})
 
24
  self._load_providers()
25
 
26
  def _load_providers(self):
27
+ """Loads provider configurations from the YAML file and normalises them."""
28
  try:
29
  config_path = files.get_abs_path("conf/model_providers.yaml")
30
  with open(config_path, "r", encoding="utf-8") as f:
31
+ raw_yaml = yaml.safe_load(f) or {}
32
  except (FileNotFoundError, yaml.YAMLError):
33
+ raw_yaml = {}
34
+
35
+ # ------------------------------------------------------------
36
+ # Normalise the YAML so that internally we always work with a
37
+ # list-of-dicts [{id, name, ...}] for each provider type. This
38
+ # keeps existing callers unchanged while allowing the new nested
39
+ # mapping format in the YAML (id -> { ... }).
40
+ # ------------------------------------------------------------
41
+ normalised: Dict[str, List[Dict[str, str]]] = {}
42
+
43
+ for p_type, providers in (raw_yaml or {}).items():
44
+ items: List[Dict[str, str]] = []
45
+
46
+ if isinstance(providers, dict):
47
+ # New format: mapping of id -> config
48
+ for pid, cfg in providers.items():
49
+ entry = {"id": pid, **(cfg or {})}
50
+ items.append(entry)
51
+ elif isinstance(providers, list):
52
+ # Legacy list format – use as-is
53
+ items.extend(providers or [])
54
+
55
+ normalised[p_type] = items
56
+
57
+ # Save raw
58
+ self._raw = normalised
59
+
60
+ # Build UI-friendly option list (value / label)
61
  self._options = {}
62
+ for p_type, providers in normalised.items():
63
  opts: List[FieldOption] = []
64
+ for p in providers:
65
+ pid = (p.get("id") or p.get("value") or "").lower()
66
  name = p.get("name") or p.get("label") or pid
67
  if pid:
68
  opts.append({"value": pid, "label": name})
python/helpers/settings.py CHANGED
@@ -500,8 +500,17 @@ def convert_out(settings: Settings) -> SettingsOutput:
500
  # api keys model section
501
  api_keys_fields: list[SettingsField] = []
502
 
503
- for provider in get_providers("chat"):
504
- api_keys_fields.append(_get_api_key_field(settings, provider["value"].lower(), provider["label"]))
 
 
 
 
 
 
 
 
 
505
 
506
  api_keys_section: SettingsSection = {
507
  "id": "api_keys",
@@ -993,7 +1002,7 @@ def _write_sensitive_settings(settings: Settings):
993
  def get_default_settings() -> Settings:
994
  return Settings(
995
  version=_get_version(),
996
- chat_model_provider="OPENROUTER",
997
  chat_model_name="openai/gpt-4.1",
998
  chat_model_api_base="",
999
  chat_model_kwargs={"temperature": "0"},
@@ -1003,7 +1012,7 @@ def get_default_settings() -> Settings:
1003
  chat_model_rl_requests=0,
1004
  chat_model_rl_input=0,
1005
  chat_model_rl_output=0,
1006
- util_model_provider="OPENROUTER",
1007
  util_model_name="openai/gpt-4.1-nano",
1008
  util_model_api_base="",
1009
  util_model_ctx_length=100000,
@@ -1012,13 +1021,13 @@ def get_default_settings() -> Settings:
1012
  util_model_rl_requests=0,
1013
  util_model_rl_input=0,
1014
  util_model_rl_output=0,
1015
- embed_model_provider="HUGGINGFACE",
1016
  embed_model_name="sentence-transformers/all-MiniLM-L6-v2",
1017
  embed_model_api_base="",
1018
  embed_model_kwargs={},
1019
  embed_model_rl_requests=0,
1020
  embed_model_rl_input=0,
1021
- browser_model_provider="OPENROUTER",
1022
  browser_model_name="openai/gpt-4.1",
1023
  browser_model_api_base="",
1024
  browser_model_vision=True,
 
500
  # api keys model section
501
  api_keys_fields: list[SettingsField] = []
502
 
503
+ # Collect unique providers from both chat and embedding sections
504
+ providers_seen: set[str] = set()
505
+ for p_type in ("chat", "embedding"):
506
+ for provider in get_providers(p_type):
507
+ pid_lower = provider["value"].lower()
508
+ if pid_lower in providers_seen:
509
+ continue
510
+ providers_seen.add(pid_lower)
511
+ api_keys_fields.append(
512
+ _get_api_key_field(settings, pid_lower, provider["label"])
513
+ )
514
 
515
  api_keys_section: SettingsSection = {
516
  "id": "api_keys",
 
1002
  def get_default_settings() -> Settings:
1003
  return Settings(
1004
  version=_get_version(),
1005
+ chat_model_provider="openrouter",
1006
  chat_model_name="openai/gpt-4.1",
1007
  chat_model_api_base="",
1008
  chat_model_kwargs={"temperature": "0"},
 
1012
  chat_model_rl_requests=0,
1013
  chat_model_rl_input=0,
1014
  chat_model_rl_output=0,
1015
+ util_model_provider="openrouter",
1016
  util_model_name="openai/gpt-4.1-nano",
1017
  util_model_api_base="",
1018
  util_model_ctx_length=100000,
 
1021
  util_model_rl_requests=0,
1022
  util_model_rl_input=0,
1023
  util_model_rl_output=0,
1024
+ embed_model_provider="huggingface",
1025
  embed_model_name="sentence-transformers/all-MiniLM-L6-v2",
1026
  embed_model_api_base="",
1027
  embed_model_kwargs={},
1028
  embed_model_rl_requests=0,
1029
  embed_model_rl_input=0,
1030
+ browser_model_provider="openrouter",
1031
  browser_model_name="openai/gpt-4.1",
1032
  browser_model_api_base="",
1033
  browser_model_vision=True,