Spaces:
Paused
Paused
Mirrowel commited on
Commit ·
036cadd
1
Parent(s): e2c300f
feat(models): support array format for simplified model definitions
Browse filesEnhance flexibility for defining models loaded via environment variables (e.g., `PROVIDER_MODELS`).
- Allow simple array configuration, converting the list of names into internal model definitions using the name as the key and ID.
- Improve `get_model_id` logic to fall back to the model name if the explicit `id` field is not present in the definition dictionary.
- Update docstrings to clearly document both the simple array and advanced dictionary formats.
src/rotator_library/model_definitions.py
CHANGED
|
@@ -12,7 +12,18 @@ if not lib_logger.handlers:
|
|
| 12 |
class ModelDefinitions:
|
| 13 |
"""
|
| 14 |
Simple model definitions loader from environment variables.
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
"""
|
| 17 |
|
| 18 |
def __init__(self, config_path: Optional[str] = None):
|
|
@@ -28,11 +39,25 @@ class ModelDefinitions:
|
|
| 28 |
provider_name = env_var[:-7].lower() # Remove "_MODELS" (7 characters)
|
| 29 |
try:
|
| 30 |
models_json = json.loads(env_value)
|
|
|
|
|
|
|
| 31 |
if isinstance(models_json, dict):
|
| 32 |
self.definitions[provider_name] = models_json
|
| 33 |
lib_logger.info(
|
| 34 |
f"Loaded {len(models_json)} models for provider: {provider_name}"
|
| 35 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
except (json.JSONDecodeError, TypeError) as e:
|
| 37 |
lib_logger.warning(f"Invalid JSON in {env_var}: {e}")
|
| 38 |
|
|
@@ -53,9 +78,12 @@ class ModelDefinitions:
|
|
| 53 |
return model_def.get("options", {}) if model_def else {}
|
| 54 |
|
| 55 |
def get_model_id(self, provider_name: str, model_name: str) -> Optional[str]:
|
| 56 |
-
"""Get model ID for a specific model."""
|
| 57 |
model_def = self.get_model_definition(provider_name, model_name)
|
| 58 |
-
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
def get_all_provider_models(self, provider_name: str) -> list:
|
| 61 |
"""Get all model names with provider prefix."""
|
|
|
|
| 12 |
class ModelDefinitions:
|
| 13 |
"""
|
| 14 |
Simple model definitions loader from environment variables.
|
| 15 |
+
|
| 16 |
+
Supports two formats:
|
| 17 |
+
1. Array format (simple): PROVIDER_MODELS=["model-1", "model-2", "model-3"]
|
| 18 |
+
- Each model name is used as both name and ID
|
| 19 |
+
2. Dict format (advanced): PROVIDER_MODELS={"model-name": {"id": "model-id", "options": {...}}}
|
| 20 |
+
- The 'id' field is optional - if not provided, the model name (key) is used as the ID
|
| 21 |
+
|
| 22 |
+
Examples:
|
| 23 |
+
- IFLOW_MODELS='["glm-4.6", "qwen3-max"]' - simple array format
|
| 24 |
+
- IFLOW_MODELS='{"glm-4.6": {}}' - dict format, uses "glm-4.6" as both name and ID
|
| 25 |
+
- IFLOW_MODELS='{"custom-name": {"id": "actual-id"}}' - dict format with custom ID
|
| 26 |
+
- IFLOW_MODELS='{"model": {"id": "id", "options": {"temperature": 0.7}}}' - with options
|
| 27 |
"""
|
| 28 |
|
| 29 |
def __init__(self, config_path: Optional[str] = None):
|
|
|
|
| 39 |
provider_name = env_var[:-7].lower() # Remove "_MODELS" (7 characters)
|
| 40 |
try:
|
| 41 |
models_json = json.loads(env_value)
|
| 42 |
+
|
| 43 |
+
# Handle dict format: {"model-name": {"id": "...", "options": {...}}}
|
| 44 |
if isinstance(models_json, dict):
|
| 45 |
self.definitions[provider_name] = models_json
|
| 46 |
lib_logger.info(
|
| 47 |
f"Loaded {len(models_json)} models for provider: {provider_name}"
|
| 48 |
)
|
| 49 |
+
# Handle array format: ["model-1", "model-2", "model-3"]
|
| 50 |
+
elif isinstance(models_json, list):
|
| 51 |
+
# Convert array to dict format with empty definitions
|
| 52 |
+
models_dict = {model_name: {} for model_name in models_json if isinstance(model_name, str)}
|
| 53 |
+
self.definitions[provider_name] = models_dict
|
| 54 |
+
lib_logger.info(
|
| 55 |
+
f"Loaded {len(models_dict)} models for provider: {provider_name} (array format)"
|
| 56 |
+
)
|
| 57 |
+
else:
|
| 58 |
+
lib_logger.warning(
|
| 59 |
+
f"{env_var} must be a JSON object or array, got {type(models_json).__name__}"
|
| 60 |
+
)
|
| 61 |
except (json.JSONDecodeError, TypeError) as e:
|
| 62 |
lib_logger.warning(f"Invalid JSON in {env_var}: {e}")
|
| 63 |
|
|
|
|
| 78 |
return model_def.get("options", {}) if model_def else {}
|
| 79 |
|
| 80 |
def get_model_id(self, provider_name: str, model_name: str) -> Optional[str]:
|
| 81 |
+
"""Get model ID for a specific model. Falls back to model_name if 'id' is not specified."""
|
| 82 |
model_def = self.get_model_definition(provider_name, model_name)
|
| 83 |
+
if not model_def:
|
| 84 |
+
return None
|
| 85 |
+
# Use 'id' if provided, otherwise use the model_name as the ID
|
| 86 |
+
return model_def.get("id", model_name)
|
| 87 |
|
| 88 |
def get_all_provider_models(self, provider_name: str) -> list:
|
| 89 |
"""Get all model names with provider prefix."""
|