Spaces:
Paused
Paused
Mirrowel commited on
Commit ·
e2c300f
1
Parent(s): dea215f
refactor(providers): enhance model discovery and deduplication logic
Browse filesRestructures model loading logic across `gemini_cli`, `iflow`, and `qwen_code` providers to handle prioritized model sources more reliably.
- Introduced `env_var_ids` tracking to record model IDs defined via environment variables.
- Environment variable models are now guaranteed to be included without internal deduplication.
- Dynamic and hardcoded models are only added if their base model ID does not conflict with an ID already defined via environment variables.
- Added a helper function `extract_model_id` to standardize model ID parsing across different API response formats.
- Corrects the Qwen Code provider to strip the internal provider prefix before sending the model name in the chat completion API payload.
src/rotator_library/providers/gemini_cli_provider.py
CHANGED
|
@@ -722,25 +722,53 @@ class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
|
|
| 722 |
async def get_models(self, credential: str, client: httpx.AsyncClient) -> List[str]:
|
| 723 |
"""
|
| 724 |
Returns a merged list of Gemini CLI models from three sources:
|
| 725 |
-
1. Environment variable models (via GEMINI_CLI_MODELS)
|
| 726 |
-
2. Hardcoded models (fallback list)
|
| 727 |
-
3. Dynamic discovery from Gemini API (if supported)
|
|
|
|
|
|
|
|
|
|
| 728 |
"""
|
| 729 |
models = []
|
| 730 |
-
|
| 731 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 732 |
static_models = self.model_definitions.get_all_provider_models("gemini_cli")
|
| 733 |
if static_models:
|
| 734 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 735 |
lib_logger.info(f"Loaded {len(static_models)} static models for gemini_cli from environment variables")
|
| 736 |
|
| 737 |
-
# Source 2: Add hardcoded models (
|
| 738 |
-
existing_ids = [m.split("/")[-1] for m in models]
|
| 739 |
for model_id in HARDCODED_MODELS:
|
| 740 |
-
if model_id not in
|
| 741 |
models.append(f"gemini_cli/{model_id}")
|
|
|
|
| 742 |
|
| 743 |
-
# Source 3: Try dynamic discovery from Gemini API
|
| 744 |
try:
|
| 745 |
# Get access token for API calls
|
| 746 |
access_token = await self.get_access_token(credential)
|
|
@@ -756,27 +784,17 @@ class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
|
|
| 756 |
)
|
| 757 |
response.raise_for_status()
|
| 758 |
|
| 759 |
-
# Parse dynamic models and avoid duplicates
|
| 760 |
-
existing_ids = [m.split("/")[-1] for m in models]
|
| 761 |
dynamic_data = response.json()
|
| 762 |
-
|
| 763 |
# Handle various response formats
|
| 764 |
model_list = dynamic_data.get("models", dynamic_data.get("data", []))
|
| 765 |
|
| 766 |
dynamic_count = 0
|
| 767 |
for model in model_list:
|
| 768 |
-
|
| 769 |
-
|
| 770 |
-
if
|
| 771 |
-
model_id = model.get("name", model.get("id"))
|
| 772 |
-
# Gemini models often have format "models/gemini-pro", extract just the model name
|
| 773 |
-
if model_id and "/" in model_id:
|
| 774 |
-
model_id = model_id.split("/")[-1]
|
| 775 |
-
else:
|
| 776 |
-
model_id = model
|
| 777 |
-
|
| 778 |
-
if model_id and model_id not in existing_ids and model_id.startswith("gemini"):
|
| 779 |
models.append(f"gemini_cli/{model_id}")
|
|
|
|
| 780 |
dynamic_count += 1
|
| 781 |
|
| 782 |
if dynamic_count > 0:
|
|
|
|
| 722 |
async def get_models(self, credential: str, client: httpx.AsyncClient) -> List[str]:
|
| 723 |
"""
|
| 724 |
Returns a merged list of Gemini CLI models from three sources:
|
| 725 |
+
1. Environment variable models (via GEMINI_CLI_MODELS) - ALWAYS included, take priority
|
| 726 |
+
2. Hardcoded models (fallback list) - added only if ID not in env vars
|
| 727 |
+
3. Dynamic discovery from Gemini API (if supported) - added only if ID not in env vars
|
| 728 |
+
|
| 729 |
+
Environment variable models always win and are never deduplicated, even if they
|
| 730 |
+
share the same ID (to support different configs like temperature, etc.)
|
| 731 |
"""
|
| 732 |
models = []
|
| 733 |
+
env_var_ids = set() # Track IDs from env vars to prevent hardcoded/dynamic duplicates
|
| 734 |
+
|
| 735 |
+
def extract_model_id(item) -> str:
|
| 736 |
+
"""Extract model ID from various formats (dict, string with/without provider prefix)."""
|
| 737 |
+
if isinstance(item, dict):
|
| 738 |
+
# Dict format: extract 'name' or 'id' field
|
| 739 |
+
model_id = item.get("name") or item.get("id", "")
|
| 740 |
+
# Gemini models often have format "models/gemini-pro", extract just the model name
|
| 741 |
+
if model_id and "/" in model_id:
|
| 742 |
+
model_id = model_id.split("/")[-1]
|
| 743 |
+
return model_id
|
| 744 |
+
elif isinstance(item, str):
|
| 745 |
+
# String format: extract ID from "provider/id" or "models/id" or just "id"
|
| 746 |
+
return item.split("/")[-1] if "/" in item else item
|
| 747 |
+
return str(item)
|
| 748 |
+
|
| 749 |
+
# Source 1: Load environment variable models (ALWAYS include ALL of them)
|
| 750 |
static_models = self.model_definitions.get_all_provider_models("gemini_cli")
|
| 751 |
if static_models:
|
| 752 |
+
for model in static_models:
|
| 753 |
+
# Extract model name from "gemini_cli/ModelName" format
|
| 754 |
+
model_name = model.split("/")[-1] if "/" in model else model
|
| 755 |
+
# Get the actual model ID from definitions (which may differ from the name)
|
| 756 |
+
model_id = self.model_definitions.get_model_id("gemini_cli", model_name)
|
| 757 |
+
|
| 758 |
+
# ALWAYS add env var models (no deduplication)
|
| 759 |
+
models.append(model)
|
| 760 |
+
# Track the ID to prevent hardcoded/dynamic duplicates
|
| 761 |
+
if model_id:
|
| 762 |
+
env_var_ids.add(model_id)
|
| 763 |
lib_logger.info(f"Loaded {len(static_models)} static models for gemini_cli from environment variables")
|
| 764 |
|
| 765 |
+
# Source 2: Add hardcoded models (only if ID not already in env vars)
|
|
|
|
| 766 |
for model_id in HARDCODED_MODELS:
|
| 767 |
+
if model_id not in env_var_ids:
|
| 768 |
models.append(f"gemini_cli/{model_id}")
|
| 769 |
+
env_var_ids.add(model_id)
|
| 770 |
|
| 771 |
+
# Source 3: Try dynamic discovery from Gemini API (only if ID not already in env vars)
|
| 772 |
try:
|
| 773 |
# Get access token for API calls
|
| 774 |
access_token = await self.get_access_token(credential)
|
|
|
|
| 784 |
)
|
| 785 |
response.raise_for_status()
|
| 786 |
|
|
|
|
|
|
|
| 787 |
dynamic_data = response.json()
|
|
|
|
| 788 |
# Handle various response formats
|
| 789 |
model_list = dynamic_data.get("models", dynamic_data.get("data", []))
|
| 790 |
|
| 791 |
dynamic_count = 0
|
| 792 |
for model in model_list:
|
| 793 |
+
model_id = extract_model_id(model)
|
| 794 |
+
# Only include Gemini models that aren't already in env vars
|
| 795 |
+
if model_id and model_id not in env_var_ids and model_id.startswith("gemini"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 796 |
models.append(f"gemini_cli/{model_id}")
|
| 797 |
+
env_var_ids.add(model_id)
|
| 798 |
dynamic_count += 1
|
| 799 |
|
| 800 |
if dynamic_count > 0:
|
src/rotator_library/providers/iflow_provider.py
CHANGED
|
@@ -57,27 +57,51 @@ class IFlowProvider(IFlowAuthBase, ProviderInterface):
|
|
| 57 |
async def get_models(self, credential: str, client: httpx.AsyncClient) -> List[str]:
|
| 58 |
"""
|
| 59 |
Returns a merged list of iFlow models from three sources:
|
| 60 |
-
1. Environment variable models (via IFLOW_MODELS)
|
| 61 |
-
2. Hardcoded models (fallback list)
|
| 62 |
-
3. Dynamic discovery from iFlow API (if supported)
|
|
|
|
|
|
|
|
|
|
| 63 |
|
| 64 |
Validates OAuth credentials if applicable.
|
| 65 |
"""
|
| 66 |
models = []
|
| 67 |
-
|
| 68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
static_models = self.model_definitions.get_all_provider_models("iflow")
|
| 70 |
if static_models:
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
lib_logger.info(f"Loaded {len(static_models)} static models for iflow from environment variables")
|
| 73 |
|
| 74 |
-
# Source 2: Add hardcoded models (
|
| 75 |
-
existing_ids = [m.split("/")[-1] for m in models]
|
| 76 |
for model_id in HARDCODED_MODELS:
|
| 77 |
-
if model_id not in
|
| 78 |
models.append(f"iflow/{model_id}")
|
|
|
|
| 79 |
|
| 80 |
-
# Source 3: Try dynamic discovery from iFlow API
|
| 81 |
try:
|
| 82 |
# Validate OAuth credentials and get API details
|
| 83 |
if os.path.isfile(credential):
|
|
@@ -92,18 +116,16 @@ class IFlowProvider(IFlowAuthBase, ProviderInterface):
|
|
| 92 |
)
|
| 93 |
response.raise_for_status()
|
| 94 |
|
| 95 |
-
# Parse dynamic models and avoid duplicates
|
| 96 |
-
existing_ids = [m.split("/")[-1] for m in models]
|
| 97 |
dynamic_data = response.json()
|
| 98 |
-
|
| 99 |
# Handle both {data: [...]} and direct [...] formats
|
| 100 |
model_list = dynamic_data.get("data", dynamic_data) if isinstance(dynamic_data, dict) else dynamic_data
|
| 101 |
|
| 102 |
dynamic_count = 0
|
| 103 |
for model in model_list:
|
| 104 |
-
model_id =
|
| 105 |
-
if model_id and model_id not in
|
| 106 |
models.append(f"iflow/{model_id}")
|
|
|
|
| 107 |
dynamic_count += 1
|
| 108 |
|
| 109 |
if dynamic_count > 0:
|
|
|
|
| 57 |
async def get_models(self, credential: str, client: httpx.AsyncClient) -> List[str]:
|
| 58 |
"""
|
| 59 |
Returns a merged list of iFlow models from three sources:
|
| 60 |
+
1. Environment variable models (via IFLOW_MODELS) - ALWAYS included, take priority
|
| 61 |
+
2. Hardcoded models (fallback list) - added only if ID not in env vars
|
| 62 |
+
3. Dynamic discovery from iFlow API (if supported) - added only if ID not in env vars
|
| 63 |
+
|
| 64 |
+
Environment variable models always win and are never deduplicated, even if they
|
| 65 |
+
share the same ID (to support different configs like temperature, etc.)
|
| 66 |
|
| 67 |
Validates OAuth credentials if applicable.
|
| 68 |
"""
|
| 69 |
models = []
|
| 70 |
+
env_var_ids = set() # Track IDs from env vars to prevent hardcoded/dynamic duplicates
|
| 71 |
+
|
| 72 |
+
def extract_model_id(item) -> str:
|
| 73 |
+
"""Extract model ID from various formats (dict, string with/without provider prefix)."""
|
| 74 |
+
if isinstance(item, dict):
|
| 75 |
+
# Dict format: extract 'id' or 'name' field
|
| 76 |
+
return item.get("id") or item.get("name", "")
|
| 77 |
+
elif isinstance(item, str):
|
| 78 |
+
# String format: extract ID from "provider/id" or just "id"
|
| 79 |
+
return item.split("/")[-1] if "/" in item else item
|
| 80 |
+
return str(item)
|
| 81 |
+
|
| 82 |
+
# Source 1: Load environment variable models (ALWAYS include ALL of them)
|
| 83 |
static_models = self.model_definitions.get_all_provider_models("iflow")
|
| 84 |
if static_models:
|
| 85 |
+
for model in static_models:
|
| 86 |
+
# Extract model name from "iflow/ModelName" format
|
| 87 |
+
model_name = model.split("/")[-1] if "/" in model else model
|
| 88 |
+
# Get the actual model ID from definitions (which may differ from the name)
|
| 89 |
+
model_id = self.model_definitions.get_model_id("iflow", model_name)
|
| 90 |
+
|
| 91 |
+
# ALWAYS add env var models (no deduplication)
|
| 92 |
+
models.append(model)
|
| 93 |
+
# Track the ID to prevent hardcoded/dynamic duplicates
|
| 94 |
+
if model_id:
|
| 95 |
+
env_var_ids.add(model_id)
|
| 96 |
lib_logger.info(f"Loaded {len(static_models)} static models for iflow from environment variables")
|
| 97 |
|
| 98 |
+
# Source 2: Add hardcoded models (only if ID not already in env vars)
|
|
|
|
| 99 |
for model_id in HARDCODED_MODELS:
|
| 100 |
+
if model_id not in env_var_ids:
|
| 101 |
models.append(f"iflow/{model_id}")
|
| 102 |
+
env_var_ids.add(model_id)
|
| 103 |
|
| 104 |
+
# Source 3: Try dynamic discovery from iFlow API (only if ID not already in env vars)
|
| 105 |
try:
|
| 106 |
# Validate OAuth credentials and get API details
|
| 107 |
if os.path.isfile(credential):
|
|
|
|
| 116 |
)
|
| 117 |
response.raise_for_status()
|
| 118 |
|
|
|
|
|
|
|
| 119 |
dynamic_data = response.json()
|
|
|
|
| 120 |
# Handle both {data: [...]} and direct [...] formats
|
| 121 |
model_list = dynamic_data.get("data", dynamic_data) if isinstance(dynamic_data, dict) else dynamic_data
|
| 122 |
|
| 123 |
dynamic_count = 0
|
| 124 |
for model in model_list:
|
| 125 |
+
model_id = extract_model_id(model)
|
| 126 |
+
if model_id and model_id not in env_var_ids:
|
| 127 |
models.append(f"iflow/{model_id}")
|
| 128 |
+
env_var_ids.add(model_id)
|
| 129 |
dynamic_count += 1
|
| 130 |
|
| 131 |
if dynamic_count > 0:
|
src/rotator_library/providers/qwen_code_provider.py
CHANGED
|
@@ -40,27 +40,51 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
|
|
| 40 |
async def get_models(self, credential: str, client: httpx.AsyncClient) -> List[str]:
|
| 41 |
"""
|
| 42 |
Returns a merged list of Qwen Code models from three sources:
|
| 43 |
-
1. Environment variable models (via QWEN_CODE_MODELS)
|
| 44 |
-
2. Hardcoded models (fallback list)
|
| 45 |
-
3. Dynamic discovery from Qwen API (if supported)
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
Validates OAuth credentials if applicable.
|
| 48 |
"""
|
| 49 |
models = []
|
| 50 |
-
|
| 51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
static_models = self.model_definitions.get_all_provider_models("qwen_code")
|
| 53 |
if static_models:
|
| 54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
lib_logger.info(f"Loaded {len(static_models)} static models for qwen_code from environment variables")
|
| 56 |
|
| 57 |
-
# Source 2: Add hardcoded models (
|
| 58 |
-
existing_ids = [m.split("/")[-1] for m in models]
|
| 59 |
for model_id in HARDCODED_MODELS:
|
| 60 |
-
if model_id not in
|
| 61 |
models.append(f"qwen_code/{model_id}")
|
|
|
|
| 62 |
|
| 63 |
-
# Source 3: Try dynamic discovery from Qwen Code API
|
| 64 |
try:
|
| 65 |
# Validate OAuth credentials and get API details
|
| 66 |
if os.path.isfile(credential):
|
|
@@ -75,18 +99,16 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
|
|
| 75 |
)
|
| 76 |
response.raise_for_status()
|
| 77 |
|
| 78 |
-
# Parse dynamic models and avoid duplicates
|
| 79 |
-
existing_ids = [m.split("/")[-1] for m in models]
|
| 80 |
dynamic_data = response.json()
|
| 81 |
-
|
| 82 |
# Handle both {data: [...]} and direct [...] formats
|
| 83 |
model_list = dynamic_data.get("data", dynamic_data) if isinstance(dynamic_data, dict) else dynamic_data
|
| 84 |
|
| 85 |
dynamic_count = 0
|
| 86 |
for model in model_list:
|
| 87 |
-
model_id =
|
| 88 |
-
if model_id and model_id not in
|
| 89 |
models.append(f"qwen_code/{model_id}")
|
|
|
|
| 90 |
dynamic_count += 1
|
| 91 |
|
| 92 |
if dynamic_count > 0:
|
|
@@ -342,8 +364,12 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
|
|
| 342 |
"""Prepares and makes the actual API call."""
|
| 343 |
api_base, access_token = await self.get_api_details(credential_path)
|
| 344 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 345 |
# Build clean payload with only supported parameters
|
| 346 |
-
payload = self._build_request_payload(**
|
| 347 |
|
| 348 |
headers = {
|
| 349 |
"Authorization": f"Bearer {access_token}",
|
|
|
|
| 40 |
async def get_models(self, credential: str, client: httpx.AsyncClient) -> List[str]:
|
| 41 |
"""
|
| 42 |
Returns a merged list of Qwen Code models from three sources:
|
| 43 |
+
1. Environment variable models (via QWEN_CODE_MODELS) - ALWAYS included, take priority
|
| 44 |
+
2. Hardcoded models (fallback list) - added only if ID not in env vars
|
| 45 |
+
3. Dynamic discovery from Qwen API (if supported) - added only if ID not in env vars
|
| 46 |
+
|
| 47 |
+
Environment variable models always win and are never deduplicated, even if they
|
| 48 |
+
share the same ID (to support different configs like temperature, etc.)
|
| 49 |
|
| 50 |
Validates OAuth credentials if applicable.
|
| 51 |
"""
|
| 52 |
models = []
|
| 53 |
+
env_var_ids = set() # Track IDs from env vars to prevent hardcoded/dynamic duplicates
|
| 54 |
+
|
| 55 |
+
def extract_model_id(item) -> str:
|
| 56 |
+
"""Extract model ID from various formats (dict, string with/without provider prefix)."""
|
| 57 |
+
if isinstance(item, dict):
|
| 58 |
+
# Dict format: extract 'id' or 'name' field
|
| 59 |
+
return item.get("id") or item.get("name", "")
|
| 60 |
+
elif isinstance(item, str):
|
| 61 |
+
# String format: extract ID from "provider/id" or just "id"
|
| 62 |
+
return item.split("/")[-1] if "/" in item else item
|
| 63 |
+
return str(item)
|
| 64 |
+
|
| 65 |
+
# Source 1: Load environment variable models (ALWAYS include ALL of them)
|
| 66 |
static_models = self.model_definitions.get_all_provider_models("qwen_code")
|
| 67 |
if static_models:
|
| 68 |
+
for model in static_models:
|
| 69 |
+
# Extract model name from "qwen_code/ModelName" format
|
| 70 |
+
model_name = model.split("/")[-1] if "/" in model else model
|
| 71 |
+
# Get the actual model ID from definitions (which may differ from the name)
|
| 72 |
+
model_id = self.model_definitions.get_model_id("qwen_code", model_name)
|
| 73 |
+
|
| 74 |
+
# ALWAYS add env var models (no deduplication)
|
| 75 |
+
models.append(model)
|
| 76 |
+
# Track the ID to prevent hardcoded/dynamic duplicates
|
| 77 |
+
if model_id:
|
| 78 |
+
env_var_ids.add(model_id)
|
| 79 |
lib_logger.info(f"Loaded {len(static_models)} static models for qwen_code from environment variables")
|
| 80 |
|
| 81 |
+
# Source 2: Add hardcoded models (only if ID not already in env vars)
|
|
|
|
| 82 |
for model_id in HARDCODED_MODELS:
|
| 83 |
+
if model_id not in env_var_ids:
|
| 84 |
models.append(f"qwen_code/{model_id}")
|
| 85 |
+
env_var_ids.add(model_id)
|
| 86 |
|
| 87 |
+
# Source 3: Try dynamic discovery from Qwen Code API (only if ID not already in env vars)
|
| 88 |
try:
|
| 89 |
# Validate OAuth credentials and get API details
|
| 90 |
if os.path.isfile(credential):
|
|
|
|
| 99 |
)
|
| 100 |
response.raise_for_status()
|
| 101 |
|
|
|
|
|
|
|
| 102 |
dynamic_data = response.json()
|
|
|
|
| 103 |
# Handle both {data: [...]} and direct [...] formats
|
| 104 |
model_list = dynamic_data.get("data", dynamic_data) if isinstance(dynamic_data, dict) else dynamic_data
|
| 105 |
|
| 106 |
dynamic_count = 0
|
| 107 |
for model in model_list:
|
| 108 |
+
model_id = extract_model_id(model)
|
| 109 |
+
if model_id and model_id not in env_var_ids:
|
| 110 |
models.append(f"qwen_code/{model_id}")
|
| 111 |
+
env_var_ids.add(model_id)
|
| 112 |
dynamic_count += 1
|
| 113 |
|
| 114 |
if dynamic_count > 0:
|
|
|
|
| 364 |
"""Prepares and makes the actual API call."""
|
| 365 |
api_base, access_token = await self.get_api_details(credential_path)
|
| 366 |
|
| 367 |
+
# Strip provider prefix from model name (e.g., "qwen_code/qwen3-coder-plus" -> "qwen3-coder-plus")
|
| 368 |
+
model_name = model.split('/')[-1]
|
| 369 |
+
kwargs_with_stripped_model = {**kwargs, 'model': model_name}
|
| 370 |
+
|
| 371 |
# Build clean payload with only supported parameters
|
| 372 |
+
payload = self._build_request_payload(**kwargs_with_stripped_model)
|
| 373 |
|
| 374 |
headers = {
|
| 375 |
"Authorization": f"Bearer {access_token}",
|