Mirrowel committed on
Commit
e2c300f
·
1 Parent(s): dea215f

refactor(providers): enhance model discovery and deduplication logic

Browse files

Restructures model loading logic across `gemini_cli`, `iflow`, and `qwen_code` providers to handle prioritized model sources more reliably.

- Introduced `env_var_ids` tracking to record model IDs defined via environment variables.
- Environment variable models are now guaranteed to be included without internal deduplication.
- Hardcoded and dynamic models are only added if their base model ID does not conflict with an ID already defined via environment variables or already added from an earlier source (hardcoded models are processed before dynamic ones).
- Added a helper function `extract_model_id` to standardize model ID parsing across different API response formats.
- Corrected the Qwen Code provider to strip the internal provider prefix from the model name (e.g., "qwen_code/qwen3-coder-plus" -> "qwen3-coder-plus") before sending it in the chat completion API payload.

src/rotator_library/providers/gemini_cli_provider.py CHANGED
@@ -722,25 +722,53 @@ class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
722
  async def get_models(self, credential: str, client: httpx.AsyncClient) -> List[str]:
723
  """
724
  Returns a merged list of Gemini CLI models from three sources:
725
- 1. Environment variable models (via GEMINI_CLI_MODELS)
726
- 2. Hardcoded models (fallback list)
727
- 3. Dynamic discovery from Gemini API (if supported)
 
 
 
728
  """
729
  models = []
730
-
731
- # Source 1: Load environment variable models
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
732
  static_models = self.model_definitions.get_all_provider_models("gemini_cli")
733
  if static_models:
734
- models.extend(static_models)
 
 
 
 
 
 
 
 
 
 
735
  lib_logger.info(f"Loaded {len(static_models)} static models for gemini_cli from environment variables")
736
 
737
- # Source 2: Add hardcoded models (avoiding duplicates)
738
- existing_ids = [m.split("/")[-1] for m in models]
739
  for model_id in HARDCODED_MODELS:
740
- if model_id not in existing_ids:
741
  models.append(f"gemini_cli/{model_id}")
 
742
 
743
- # Source 3: Try dynamic discovery from Gemini API
744
  try:
745
  # Get access token for API calls
746
  access_token = await self.get_access_token(credential)
@@ -756,27 +784,17 @@ class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
756
  )
757
  response.raise_for_status()
758
 
759
- # Parse dynamic models and avoid duplicates
760
- existing_ids = [m.split("/")[-1] for m in models]
761
  dynamic_data = response.json()
762
-
763
  # Handle various response formats
764
  model_list = dynamic_data.get("models", dynamic_data.get("data", []))
765
 
766
  dynamic_count = 0
767
  for model in model_list:
768
- # Extract model ID (may be in 'name' or 'id' field)
769
- model_id = None
770
- if isinstance(model, dict):
771
- model_id = model.get("name", model.get("id"))
772
- # Gemini models often have format "models/gemini-pro", extract just the model name
773
- if model_id and "/" in model_id:
774
- model_id = model_id.split("/")[-1]
775
- else:
776
- model_id = model
777
-
778
- if model_id and model_id not in existing_ids and model_id.startswith("gemini"):
779
  models.append(f"gemini_cli/{model_id}")
 
780
  dynamic_count += 1
781
 
782
  if dynamic_count > 0:
 
722
  async def get_models(self, credential: str, client: httpx.AsyncClient) -> List[str]:
723
  """
724
  Returns a merged list of Gemini CLI models from three sources:
725
+ 1. Environment variable models (via GEMINI_CLI_MODELS) - ALWAYS included, take priority
726
+ 2. Hardcoded models (fallback list) - added only if ID not in env vars
727
+ 3. Dynamic discovery from Gemini API (if supported) - added only if ID not in env vars
728
+
729
+ Environment variable models always win and are never deduplicated, even if they
730
+ share the same ID (to support different configs like temperature, etc.)
731
  """
732
  models = []
733
+ env_var_ids = set() # Track IDs from env vars to prevent hardcoded/dynamic duplicates
734
+
735
+ def extract_model_id(item) -> str:
736
+ """Extract model ID from various formats (dict, string with/without provider prefix)."""
737
+ if isinstance(item, dict):
738
+ # Dict format: extract 'name' or 'id' field
739
+ model_id = item.get("name") or item.get("id", "")
740
+ # Gemini models often have format "models/gemini-pro", extract just the model name
741
+ if model_id and "/" in model_id:
742
+ model_id = model_id.split("/")[-1]
743
+ return model_id
744
+ elif isinstance(item, str):
745
+ # String format: extract ID from "provider/id" or "models/id" or just "id"
746
+ return item.split("/")[-1] if "/" in item else item
747
+ return str(item)
748
+
749
+ # Source 1: Load environment variable models (ALWAYS include ALL of them)
750
  static_models = self.model_definitions.get_all_provider_models("gemini_cli")
751
  if static_models:
752
+ for model in static_models:
753
+ # Extract model name from "gemini_cli/ModelName" format
754
+ model_name = model.split("/")[-1] if "/" in model else model
755
+ # Get the actual model ID from definitions (which may differ from the name)
756
+ model_id = self.model_definitions.get_model_id("gemini_cli", model_name)
757
+
758
+ # ALWAYS add env var models (no deduplication)
759
+ models.append(model)
760
+ # Track the ID to prevent hardcoded/dynamic duplicates
761
+ if model_id:
762
+ env_var_ids.add(model_id)
763
  lib_logger.info(f"Loaded {len(static_models)} static models for gemini_cli from environment variables")
764
 
765
+ # Source 2: Add hardcoded models (only if ID not already in env vars)
 
766
  for model_id in HARDCODED_MODELS:
767
+ if model_id not in env_var_ids:
768
  models.append(f"gemini_cli/{model_id}")
769
+ env_var_ids.add(model_id)
770
 
771
+ # Source 3: Try dynamic discovery from Gemini API (only if ID not already in env vars)
772
  try:
773
  # Get access token for API calls
774
  access_token = await self.get_access_token(credential)
 
784
  )
785
  response.raise_for_status()
786
 
 
 
787
  dynamic_data = response.json()
 
788
  # Handle various response formats
789
  model_list = dynamic_data.get("models", dynamic_data.get("data", []))
790
 
791
  dynamic_count = 0
792
  for model in model_list:
793
+ model_id = extract_model_id(model)
794
+ # Only include Gemini models that aren't already in env vars
795
+ if model_id and model_id not in env_var_ids and model_id.startswith("gemini"):
 
 
 
 
 
 
 
 
796
  models.append(f"gemini_cli/{model_id}")
797
+ env_var_ids.add(model_id)
798
  dynamic_count += 1
799
 
800
  if dynamic_count > 0:
src/rotator_library/providers/iflow_provider.py CHANGED
@@ -57,27 +57,51 @@ class IFlowProvider(IFlowAuthBase, ProviderInterface):
57
  async def get_models(self, credential: str, client: httpx.AsyncClient) -> List[str]:
58
  """
59
  Returns a merged list of iFlow models from three sources:
60
- 1. Environment variable models (via IFLOW_MODELS)
61
- 2. Hardcoded models (fallback list)
62
- 3. Dynamic discovery from iFlow API (if supported)
 
 
 
63
 
64
  Validates OAuth credentials if applicable.
65
  """
66
  models = []
67
-
68
- # Source 1: Load environment variable models
 
 
 
 
 
 
 
 
 
 
 
69
  static_models = self.model_definitions.get_all_provider_models("iflow")
70
  if static_models:
71
- models.extend(static_models)
 
 
 
 
 
 
 
 
 
 
72
  lib_logger.info(f"Loaded {len(static_models)} static models for iflow from environment variables")
73
 
74
- # Source 2: Add hardcoded models (avoiding duplicates)
75
- existing_ids = [m.split("/")[-1] for m in models]
76
  for model_id in HARDCODED_MODELS:
77
- if model_id not in existing_ids:
78
  models.append(f"iflow/{model_id}")
 
79
 
80
- # Source 3: Try dynamic discovery from iFlow API
81
  try:
82
  # Validate OAuth credentials and get API details
83
  if os.path.isfile(credential):
@@ -92,18 +116,16 @@ class IFlowProvider(IFlowAuthBase, ProviderInterface):
92
  )
93
  response.raise_for_status()
94
 
95
- # Parse dynamic models and avoid duplicates
96
- existing_ids = [m.split("/")[-1] for m in models]
97
  dynamic_data = response.json()
98
-
99
  # Handle both {data: [...]} and direct [...] formats
100
  model_list = dynamic_data.get("data", dynamic_data) if isinstance(dynamic_data, dict) else dynamic_data
101
 
102
  dynamic_count = 0
103
  for model in model_list:
104
- model_id = model.get("id") if isinstance(model, dict) else model
105
- if model_id and model_id not in existing_ids:
106
  models.append(f"iflow/{model_id}")
 
107
  dynamic_count += 1
108
 
109
  if dynamic_count > 0:
 
57
  async def get_models(self, credential: str, client: httpx.AsyncClient) -> List[str]:
58
  """
59
  Returns a merged list of iFlow models from three sources:
60
+ 1. Environment variable models (via IFLOW_MODELS) - ALWAYS included, take priority
61
+ 2. Hardcoded models (fallback list) - added only if ID not in env vars
62
+ 3. Dynamic discovery from iFlow API (if supported) - added only if ID not in env vars
63
+
64
+ Environment variable models always win and are never deduplicated, even if they
65
+ share the same ID (to support different configs like temperature, etc.)
66
 
67
  Validates OAuth credentials if applicable.
68
  """
69
  models = []
70
+ env_var_ids = set() # Track IDs from env vars to prevent hardcoded/dynamic duplicates
71
+
72
+ def extract_model_id(item) -> str:
73
+ """Extract model ID from various formats (dict, string with/without provider prefix)."""
74
+ if isinstance(item, dict):
75
+ # Dict format: extract 'id' or 'name' field
76
+ return item.get("id") or item.get("name", "")
77
+ elif isinstance(item, str):
78
+ # String format: extract ID from "provider/id" or just "id"
79
+ return item.split("/")[-1] if "/" in item else item
80
+ return str(item)
81
+
82
+ # Source 1: Load environment variable models (ALWAYS include ALL of them)
83
  static_models = self.model_definitions.get_all_provider_models("iflow")
84
  if static_models:
85
+ for model in static_models:
86
+ # Extract model name from "iflow/ModelName" format
87
+ model_name = model.split("/")[-1] if "/" in model else model
88
+ # Get the actual model ID from definitions (which may differ from the name)
89
+ model_id = self.model_definitions.get_model_id("iflow", model_name)
90
+
91
+ # ALWAYS add env var models (no deduplication)
92
+ models.append(model)
93
+ # Track the ID to prevent hardcoded/dynamic duplicates
94
+ if model_id:
95
+ env_var_ids.add(model_id)
96
  lib_logger.info(f"Loaded {len(static_models)} static models for iflow from environment variables")
97
 
98
+ # Source 2: Add hardcoded models (only if ID not already in env vars)
 
99
  for model_id in HARDCODED_MODELS:
100
+ if model_id not in env_var_ids:
101
  models.append(f"iflow/{model_id}")
102
+ env_var_ids.add(model_id)
103
 
104
+ # Source 3: Try dynamic discovery from iFlow API (only if ID not already in env vars)
105
  try:
106
  # Validate OAuth credentials and get API details
107
  if os.path.isfile(credential):
 
116
  )
117
  response.raise_for_status()
118
 
 
 
119
  dynamic_data = response.json()
 
120
  # Handle both {data: [...]} and direct [...] formats
121
  model_list = dynamic_data.get("data", dynamic_data) if isinstance(dynamic_data, dict) else dynamic_data
122
 
123
  dynamic_count = 0
124
  for model in model_list:
125
+ model_id = extract_model_id(model)
126
+ if model_id and model_id not in env_var_ids:
127
  models.append(f"iflow/{model_id}")
128
+ env_var_ids.add(model_id)
129
  dynamic_count += 1
130
 
131
  if dynamic_count > 0:
src/rotator_library/providers/qwen_code_provider.py CHANGED
@@ -40,27 +40,51 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
40
  async def get_models(self, credential: str, client: httpx.AsyncClient) -> List[str]:
41
  """
42
  Returns a merged list of Qwen Code models from three sources:
43
- 1. Environment variable models (via QWEN_CODE_MODELS)
44
- 2. Hardcoded models (fallback list)
45
- 3. Dynamic discovery from Qwen API (if supported)
 
 
 
46
 
47
  Validates OAuth credentials if applicable.
48
  """
49
  models = []
50
-
51
- # Source 1: Load environment variable models
 
 
 
 
 
 
 
 
 
 
 
52
  static_models = self.model_definitions.get_all_provider_models("qwen_code")
53
  if static_models:
54
- models.extend(static_models)
 
 
 
 
 
 
 
 
 
 
55
  lib_logger.info(f"Loaded {len(static_models)} static models for qwen_code from environment variables")
56
 
57
- # Source 2: Add hardcoded models (avoiding duplicates)
58
- existing_ids = [m.split("/")[-1] for m in models]
59
  for model_id in HARDCODED_MODELS:
60
- if model_id not in existing_ids:
61
  models.append(f"qwen_code/{model_id}")
 
62
 
63
- # Source 3: Try dynamic discovery from Qwen Code API
64
  try:
65
  # Validate OAuth credentials and get API details
66
  if os.path.isfile(credential):
@@ -75,18 +99,16 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
75
  )
76
  response.raise_for_status()
77
 
78
- # Parse dynamic models and avoid duplicates
79
- existing_ids = [m.split("/")[-1] for m in models]
80
  dynamic_data = response.json()
81
-
82
  # Handle both {data: [...]} and direct [...] formats
83
  model_list = dynamic_data.get("data", dynamic_data) if isinstance(dynamic_data, dict) else dynamic_data
84
 
85
  dynamic_count = 0
86
  for model in model_list:
87
- model_id = model.get("id") if isinstance(model, dict) else model
88
- if model_id and model_id not in existing_ids:
89
  models.append(f"qwen_code/{model_id}")
 
90
  dynamic_count += 1
91
 
92
  if dynamic_count > 0:
@@ -342,8 +364,12 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
342
  """Prepares and makes the actual API call."""
343
  api_base, access_token = await self.get_api_details(credential_path)
344
 
 
 
 
 
345
  # Build clean payload with only supported parameters
346
- payload = self._build_request_payload(**kwargs)
347
 
348
  headers = {
349
  "Authorization": f"Bearer {access_token}",
 
40
  async def get_models(self, credential: str, client: httpx.AsyncClient) -> List[str]:
41
  """
42
  Returns a merged list of Qwen Code models from three sources:
43
+ 1. Environment variable models (via QWEN_CODE_MODELS) - ALWAYS included, take priority
44
+ 2. Hardcoded models (fallback list) - added only if ID not in env vars
45
+ 3. Dynamic discovery from Qwen API (if supported) - added only if ID not in env vars
46
+
47
+ Environment variable models always win and are never deduplicated, even if they
48
+ share the same ID (to support different configs like temperature, etc.)
49
 
50
  Validates OAuth credentials if applicable.
51
  """
52
  models = []
53
+ env_var_ids = set() # Track IDs from env vars to prevent hardcoded/dynamic duplicates
54
+
55
+ def extract_model_id(item) -> str:
56
+ """Extract model ID from various formats (dict, string with/without provider prefix)."""
57
+ if isinstance(item, dict):
58
+ # Dict format: extract 'id' or 'name' field
59
+ return item.get("id") or item.get("name", "")
60
+ elif isinstance(item, str):
61
+ # String format: extract ID from "provider/id" or just "id"
62
+ return item.split("/")[-1] if "/" in item else item
63
+ return str(item)
64
+
65
+ # Source 1: Load environment variable models (ALWAYS include ALL of them)
66
  static_models = self.model_definitions.get_all_provider_models("qwen_code")
67
  if static_models:
68
+ for model in static_models:
69
+ # Extract model name from "qwen_code/ModelName" format
70
+ model_name = model.split("/")[-1] if "/" in model else model
71
+ # Get the actual model ID from definitions (which may differ from the name)
72
+ model_id = self.model_definitions.get_model_id("qwen_code", model_name)
73
+
74
+ # ALWAYS add env var models (no deduplication)
75
+ models.append(model)
76
+ # Track the ID to prevent hardcoded/dynamic duplicates
77
+ if model_id:
78
+ env_var_ids.add(model_id)
79
  lib_logger.info(f"Loaded {len(static_models)} static models for qwen_code from environment variables")
80
 
81
+ # Source 2: Add hardcoded models (only if ID not already in env vars)
 
82
  for model_id in HARDCODED_MODELS:
83
+ if model_id not in env_var_ids:
84
  models.append(f"qwen_code/{model_id}")
85
+ env_var_ids.add(model_id)
86
 
87
+ # Source 3: Try dynamic discovery from Qwen Code API (only if ID not already in env vars)
88
  try:
89
  # Validate OAuth credentials and get API details
90
  if os.path.isfile(credential):
 
99
  )
100
  response.raise_for_status()
101
 
 
 
102
  dynamic_data = response.json()
 
103
  # Handle both {data: [...]} and direct [...] formats
104
  model_list = dynamic_data.get("data", dynamic_data) if isinstance(dynamic_data, dict) else dynamic_data
105
 
106
  dynamic_count = 0
107
  for model in model_list:
108
+ model_id = extract_model_id(model)
109
+ if model_id and model_id not in env_var_ids:
110
  models.append(f"qwen_code/{model_id}")
111
+ env_var_ids.add(model_id)
112
  dynamic_count += 1
113
 
114
  if dynamic_count > 0:
 
364
  """Prepares and makes the actual API call."""
365
  api_base, access_token = await self.get_api_details(credential_path)
366
 
367
+ # Strip provider prefix from model name (e.g., "qwen_code/qwen3-coder-plus" -> "qwen3-coder-plus")
368
+ model_name = model.split('/')[-1]
369
+ kwargs_with_stripped_model = {**kwargs, 'model': model_name}
370
+
371
  # Build clean payload with only supported parameters
372
+ payload = self._build_request_payload(**kwargs_with_stripped_model)
373
 
374
  headers = {
375
  "Authorization": f"Bearer {access_token}",