Mirrowel commited on
Commit
036cadd
·
1 Parent(s): e2c300f

feat(models): support array format for simplified model definitions

Browse files

Enhance flexibility for defining models loaded via environment variables (e.g., `PROVIDER_MODELS`).

- Allow simple array configuration, converting the list of names into internal model definitions using the name as the key and ID.
- Improve `get_model_id` logic to fall back to the model name if the explicit `id` field is not present in the definition dictionary.
- Update docstrings to clearly document both the simple array and advanced dictionary formats.

src/rotator_library/model_definitions.py CHANGED
@@ -12,7 +12,18 @@ if not lib_logger.handlers:
12
  class ModelDefinitions:
13
  """
14
  Simple model definitions loader from environment variables.
15
- Format: PROVIDER_MODELS={"model1": {"id": "id1"}, "model2": {"id": "id2", "options": {"reasoning_effort": "high"}}}
 
 
 
 
 
 
 
 
 
 
 
16
  """
17
 
18
  def __init__(self, config_path: Optional[str] = None):
@@ -28,11 +39,25 @@ class ModelDefinitions:
28
  provider_name = env_var[:-7].lower() # Remove "_MODELS" (7 characters)
29
  try:
30
  models_json = json.loads(env_value)
 
 
31
  if isinstance(models_json, dict):
32
  self.definitions[provider_name] = models_json
33
  lib_logger.info(
34
  f"Loaded {len(models_json)} models for provider: {provider_name}"
35
  )
 
 
 
 
 
 
 
 
 
 
 
 
36
  except (json.JSONDecodeError, TypeError) as e:
37
  lib_logger.warning(f"Invalid JSON in {env_var}: {e}")
38
 
@@ -53,9 +78,12 @@ class ModelDefinitions:
53
  return model_def.get("options", {}) if model_def else {}
54
 
55
  def get_model_id(self, provider_name: str, model_name: str) -> Optional[str]:
56
- """Get model ID for a specific model."""
57
  model_def = self.get_model_definition(provider_name, model_name)
58
- return model_def.get("id") if model_def else None
 
 
 
59
 
60
  def get_all_provider_models(self, provider_name: str) -> list:
61
  """Get all model names with provider prefix."""
 
12
  class ModelDefinitions:
13
  """
14
  Simple model definitions loader from environment variables.
15
+
16
+ Supports two formats:
17
+ 1. Array format (simple): PROVIDER_MODELS=["model-1", "model-2", "model-3"]
18
+ - Each model name is used as both name and ID
19
+ 2. Dict format (advanced): PROVIDER_MODELS={"model-name": {"id": "model-id", "options": {...}}}
20
+ - The 'id' field is optional - if not provided, the model name (key) is used as the ID
21
+
22
+ Examples:
23
+ - IFLOW_MODELS='["glm-4.6", "qwen3-max"]' - simple array format
24
+ - IFLOW_MODELS='{"glm-4.6": {}}' - dict format, uses "glm-4.6" as both name and ID
25
+ - IFLOW_MODELS='{"custom-name": {"id": "actual-id"}}' - dict format with custom ID
26
+ - IFLOW_MODELS='{"model": {"id": "id", "options": {"temperature": 0.7}}}' - with options
27
  """
28
 
29
  def __init__(self, config_path: Optional[str] = None):
 
39
  provider_name = env_var[:-7].lower() # Remove "_MODELS" (7 characters)
40
  try:
41
  models_json = json.loads(env_value)
42
+
43
+ # Handle dict format: {"model-name": {"id": "...", "options": {...}}}
44
  if isinstance(models_json, dict):
45
  self.definitions[provider_name] = models_json
46
  lib_logger.info(
47
  f"Loaded {len(models_json)} models for provider: {provider_name}"
48
  )
49
+ # Handle array format: ["model-1", "model-2", "model-3"]
50
+ elif isinstance(models_json, list):
51
+ # Convert array to dict format with empty definitions
52
+ models_dict = {model_name: {} for model_name in models_json if isinstance(model_name, str)}
53
+ self.definitions[provider_name] = models_dict
54
+ lib_logger.info(
55
+ f"Loaded {len(models_dict)} models for provider: {provider_name} (array format)"
56
+ )
57
+ else:
58
+ lib_logger.warning(
59
+ f"{env_var} must be a JSON object or array, got {type(models_json).__name__}"
60
+ )
61
  except (json.JSONDecodeError, TypeError) as e:
62
  lib_logger.warning(f"Invalid JSON in {env_var}: {e}")
63
 
 
78
  return model_def.get("options", {}) if model_def else {}
79
 
80
  def get_model_id(self, provider_name: str, model_name: str) -> Optional[str]:
81
+ """Get model ID for a specific model. Falls back to model_name if 'id' is not specified."""
82
  model_def = self.get_model_definition(provider_name, model_name)
83
+ if not model_def:
84
+ return None
85
+ # Use 'id' if provided, otherwise use the model_name as the ID
86
+ return model_def.get("id", model_name)
87
 
88
  def get_all_provider_models(self, provider_name: str) -> list:
89
  """Get all model names with provider prefix."""