Mirrowel committed on
Commit
3b4c51c
·
1 Parent(s): 2ed5bb6

refactor(auth): generalize credential handling and improve model matching

Browse files

Refactors the RotatingClient and provider interfaces to use the generic 'credential' term instead of 'api_key'. This ensures consistency when handling non-key credentials, such as OAuth file paths.

Specific changes include:
- Renaming internal property references from 'api_keys' to 'all_credentials'.
- Updating provider interface method signatures (`get_models`) from `api_key` to `credential`.
- Enhancing model matching (whitelist/blacklist) to correctly match patterns against both the full proxy ID and the provider's native model name (e.g., allowing wildcards to match 'gemma-7b' when the full ID is 'google/gemma-7b').

src/rotator_library/client.py CHANGED
@@ -97,42 +97,62 @@ class RotatingClient:
97
  def _is_model_ignored(self, provider: str, model_id: str) -> bool:
98
  """
99
  Checks if a model should be ignored based on the ignore list.
100
- Supports exact and partial matching.
101
  """
102
- if provider not in self.ignore_models:
 
103
  return False
104
 
105
- ignore_list = self.ignore_models[provider]
106
  if ignore_list == ['*']:
107
  return True
108
- for ignored_model in ignore_list:
109
- if ignored_model.endswith('*'):
110
- # Partial match
111
- if ignored_model[:-1] in model_id:
 
 
 
 
 
 
 
 
112
  return True
113
  else:
114
- # Exact match (ignoring provider prefix)
115
- if model_id.endswith(ignored_model):
116
  return True
117
  return False
118
 
119
  def _is_model_whitelisted(self, provider: str, model_id: str) -> bool:
120
  """
121
  Checks if a model is explicitly whitelisted.
122
- Supports exact and partial matching.
123
  """
124
- if provider not in self.whitelist_models:
 
125
  return False
126
 
127
- whitelist = self.whitelist_models[provider]
128
- for whitelisted_model in whitelist:
129
- if whitelisted_model == '*':
130
  return True
131
- if whitelisted_model.endswith('*'):
132
- if whitelisted_model[:-1] in model_id:
 
 
 
 
 
 
 
 
 
133
  return True
134
  else:
135
- if model_id.endswith(whitelisted_model):
 
136
  return True
137
  return False
138
 
@@ -918,21 +938,27 @@ class RotatingClient:
918
  lib_logger.debug(f"Returning cached models for provider: {provider}")
919
  return self._model_list_cache[provider]
920
 
921
- keys_for_provider = self.api_keys.get(provider)
922
- if not keys_for_provider:
923
- lib_logger.warning(f"No API key for provider: {provider}")
924
  return []
925
 
926
- # Create a copy and shuffle it to randomize the starting key
927
- shuffled_keys = list(keys_for_provider)
928
- random.shuffle(shuffled_keys)
929
 
930
  provider_instance = self._get_provider_instance(provider)
931
  if provider_instance:
932
- for api_key in shuffled_keys:
 
 
 
 
933
  try:
934
- lib_logger.debug(f"Attempting to get models for {provider} with credential ...{api_key[-6:]}")
935
- models = await provider_instance.get_models(api_key, self.http_client)
 
 
936
  lib_logger.info(f"Got {len(models)} models for provider: {provider}")
937
 
938
  # Whitelist and blacklist logic
@@ -955,7 +981,8 @@ class RotatingClient:
955
  return final_models
956
  except Exception as e:
957
  classified_error = classify_error(e)
958
- lib_logger.debug(f"Failed to get models for provider {provider} with credential ...{api_key[-6:]}: {classified_error.error_type}. Trying next credential.")
 
959
  continue # Try the next credential
960
 
961
  lib_logger.error(f"Failed to get models for provider {provider} after trying all credentials.")
@@ -964,11 +991,13 @@ class RotatingClient:
964
  async def get_all_available_models(self, grouped: bool = True) -> Union[Dict[str, List[str]], List[str]]:
965
  """Returns a list of all available models, either grouped by provider or as a flat list."""
966
  lib_logger.info("Getting all available models...")
967
- tasks = [self.get_available_models(provider) for provider in self.api_keys.keys()]
 
 
968
  results = await asyncio.gather(*tasks, return_exceptions=True)
969
 
970
  all_provider_models = {}
971
- for provider, result in zip(self.api_keys.keys(), results):
972
  if isinstance(result, Exception):
973
  lib_logger.error(f"Failed to get models for provider {provider}: {result}")
974
  all_provider_models[provider] = []
 
97
  def _is_model_ignored(self, provider: str, model_id: str) -> bool:
98
  """
99
  Checks if a model should be ignored based on the ignore list.
100
+ Supports exact and partial matching for both full model IDs and model names.
101
  """
102
+ model_provider = model_id.split('/')[0]
103
+ if model_provider not in self.ignore_models:
104
  return False
105
 
106
+ ignore_list = self.ignore_models[model_provider]
107
  if ignore_list == ['*']:
108
  return True
109
+
110
+ try:
111
+ # This is the model name as the provider sees it (e.g., "gpt-4" or "google/gemma-7b")
112
+ provider_model_name = model_id.split('/', 1)[1]
113
+ except IndexError:
114
+ provider_model_name = model_id
115
+
116
+ for ignored_pattern in ignore_list:
117
+ if ignored_pattern.endswith('*'):
118
+ match_pattern = ignored_pattern[:-1]
119
+ # Match wildcard against the provider's model name
120
+ if provider_model_name.startswith(match_pattern):
121
  return True
122
  else:
123
+ # Exact match against the full proxy ID OR the provider's model name
124
+ if model_id == ignored_pattern or provider_model_name == ignored_pattern:
125
  return True
126
  return False
127
 
128
  def _is_model_whitelisted(self, provider: str, model_id: str) -> bool:
129
  """
130
  Checks if a model is explicitly whitelisted.
131
+ Supports exact and partial matching for both full model IDs and model names.
132
  """
133
+ model_provider = model_id.split('/')[0]
134
+ if model_provider not in self.whitelist_models:
135
  return False
136
 
137
+ whitelist = self.whitelist_models[model_provider]
138
+ for whitelisted_pattern in whitelist:
139
+ if whitelisted_pattern == '*':
140
  return True
141
+
142
+ try:
143
+ # This is the model name as the provider sees it (e.g., "gpt-4" or "google/gemma-7b")
144
+ provider_model_name = model_id.split('/', 1)[1]
145
+ except IndexError:
146
+ provider_model_name = model_id
147
+
148
+ if whitelisted_pattern.endswith('*'):
149
+ match_pattern = whitelisted_pattern[:-1]
150
+ # Match wildcard against the provider's model name
151
+ if provider_model_name.startswith(match_pattern):
152
  return True
153
  else:
154
+ # Exact match against the full proxy ID OR the provider's model name
155
+ if model_id == whitelisted_pattern or provider_model_name == whitelisted_pattern:
156
  return True
157
  return False
158
 
 
938
  lib_logger.debug(f"Returning cached models for provider: {provider}")
939
  return self._model_list_cache[provider]
940
 
941
+ credentials_for_provider = self.all_credentials.get(provider)
942
+ if not credentials_for_provider:
943
+ lib_logger.warning(f"No credentials for provider: {provider}")
944
  return []
945
 
946
+ # Create a copy and shuffle it to randomize the starting credential
947
+ shuffled_credentials = list(credentials_for_provider)
948
+ random.shuffle(shuffled_credentials)
949
 
950
  provider_instance = self._get_provider_instance(provider)
951
  if provider_instance:
952
+ # For providers with hardcoded models (like gemini_cli), we only need to call once.
953
+ # For others, we might need to try multiple keys if one is invalid.
954
+ # The current logic of iterating works for both, as the credential is not
955
+ # always used in get_models.
956
+ for credential in shuffled_credentials:
957
  try:
958
+ # Display last 6 chars for API keys, or the filename for OAuth paths
959
+ cred_display = credential[-6:] if not os.path.isfile(credential) else os.path.basename(credential)
960
+ lib_logger.debug(f"Attempting to get models for {provider} with credential ...{cred_display}")
961
+ models = await provider_instance.get_models(credential, self.http_client)
962
  lib_logger.info(f"Got {len(models)} models for provider: {provider}")
963
 
964
  # Whitelist and blacklist logic
 
981
  return final_models
982
  except Exception as e:
983
  classified_error = classify_error(e)
984
+ cred_display = credential[-6:] if not os.path.isfile(credential) else os.path.basename(credential)
985
+ lib_logger.debug(f"Failed to get models for provider {provider} with credential ...{cred_display}: {classified_error.error_type}. Trying next credential.")
986
  continue # Try the next credential
987
 
988
  lib_logger.error(f"Failed to get models for provider {provider} after trying all credentials.")
 
991
  async def get_all_available_models(self, grouped: bool = True) -> Union[Dict[str, List[str]], List[str]]:
992
  """Returns a list of all available models, either grouped by provider or as a flat list."""
993
  lib_logger.info("Getting all available models...")
994
+
995
+ all_providers = list(self.all_credentials.keys())
996
+ tasks = [self.get_available_models(provider) for provider in all_providers]
997
  results = await asyncio.gather(*tasks, return_exceptions=True)
998
 
999
  all_provider_models = {}
1000
+ for provider, result in zip(all_providers, results):
1001
  if isinstance(result, Exception):
1002
  lib_logger.error(f"Failed to get models for provider {provider}: {result}")
1003
  all_provider_models[provider] = []
src/rotator_library/providers/gemini_cli_provider.py CHANGED
@@ -15,6 +15,13 @@ lib_logger = logging.getLogger('rotator_library')
15
 
16
  CODE_ASSIST_ENDPOINT = "https://cloudcode-pa.googleapis.com/v1internal"
17
 
 
 
 
 
 
 
 
18
  class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
19
  def __init__(self):
20
  super().__init__()
@@ -197,14 +204,8 @@ class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
197
  chunks = [chunk async for chunk in response_gen]
198
  return litellm.utils.stream_to_completion_response(chunks)
199
 
200
- # [NEW] Hardcoded model list based on Kilo example
201
- HARDCODED_MODELS = [
202
- "gemini-2.5-pro",
203
- "gemini-2.5-flash",
204
- "gemini-2.5-flash-lite"
205
- ]
206
  # Use the shared GeminiAuthBase for auth logic
207
  # get_models is not applicable for this custom provider
208
- async def get_models(self, api_key: str, client: httpx.AsyncClient) -> List[str]:
209
  """Returns a hardcoded list of known compatible Gemini CLI models."""
210
  return [f"gemini_cli/{model_id}" for model_id in HARDCODED_MODELS]
 
15
 
16
  CODE_ASSIST_ENDPOINT = "https://cloudcode-pa.googleapis.com/v1internal"
17
 
18
+ # [NEW] Hardcoded model list based on Kilo example
19
+ HARDCODED_MODELS = [
20
+ "gemini-2.5-pro",
21
+ "gemini-2.5-flash",
22
+ "gemini-2.5-flash-lite"
23
+ ]
24
+
25
  class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
26
  def __init__(self):
27
  super().__init__()
 
204
  chunks = [chunk async for chunk in response_gen]
205
  return litellm.utils.stream_to_completion_response(chunks)
206
 
 
 
 
 
 
 
207
  # Use the shared GeminiAuthBase for auth logic
208
  # get_models is not applicable for this custom provider
209
+ async def get_models(self, credential: str, client: httpx.AsyncClient) -> List[str]:
210
  """Returns a hardcoded list of known compatible Gemini CLI models."""
211
  return [f"gemini_cli/{model_id}" for model_id in HARDCODED_MODELS]
src/rotator_library/providers/qwen_code_provider.py CHANGED
@@ -21,7 +21,7 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
21
  return True # We use custom logic to handle 401 retries and stream parsing
22
 
23
  # [NEW] get_models implementation
24
- async def get_models(self, api_key: str, client: httpx.AsyncClient) -> List[str]:
25
  """Returns a hardcoded list of known compatible Qwen models for the OpenAI-compatible API."""
26
  return [f"qwen_code/{model_id}" for model_id in HARDCODED_MODELS]
27
 
 
21
  return True # We use custom logic to handle 401 retries and stream parsing
22
 
23
  # [NEW] get_models implementation
24
+ async def get_models(self, credential: str, client: httpx.AsyncClient) -> List[str]:
25
  """Returns a hardcoded list of known compatible Qwen models for the OpenAI-compatible API."""
26
  return [f"qwen_code/{model_id}" for model_id in HARDCODED_MODELS]
27