Spaces:
Paused
Paused
Mirrowel committed on
Commit ·
3b4c51c
1
Parent(s): 2ed5bb6
refactor(auth): generalize credential handling and improve model matching
Browse files

Refactors the RotatingClient and provider interfaces to use the generic 'credential' term instead of 'api_key'. This ensures consistency when handling non-key credentials, such as OAuth file paths.
Specific changes include:
- Renaming internal property references from 'api_keys' to 'all_credentials'.
- Updating provider interface method signatures (`get_models`) from `api_key` to `credential`.
- Enhancing model matching (whitelist/blacklist) to correctly match patterns against both the full proxy ID and the provider's native model name (e.g., allowing wildcards to match 'gemma-7b' when the full ID is 'google/gemma-7b').
src/rotator_library/client.py
CHANGED
|
@@ -97,42 +97,62 @@ class RotatingClient:
|
|
| 97 |
def _is_model_ignored(self, provider: str, model_id: str) -> bool:
|
| 98 |
"""
|
| 99 |
Checks if a model should be ignored based on the ignore list.
|
| 100 |
-
Supports exact and partial matching.
|
| 101 |
"""
|
| 102 |
-
|
|
|
|
| 103 |
return False
|
| 104 |
|
| 105 |
-
ignore_list = self.ignore_models[
|
| 106 |
if ignore_list == ['*']:
|
| 107 |
return True
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
return True
|
| 113 |
else:
|
| 114 |
-
# Exact match
|
| 115 |
-
if model_id
|
| 116 |
return True
|
| 117 |
return False
|
| 118 |
|
| 119 |
def _is_model_whitelisted(self, provider: str, model_id: str) -> bool:
|
| 120 |
"""
|
| 121 |
Checks if a model is explicitly whitelisted.
|
| 122 |
-
Supports exact and partial matching.
|
| 123 |
"""
|
| 124 |
-
|
|
|
|
| 125 |
return False
|
| 126 |
|
| 127 |
-
whitelist = self.whitelist_models[
|
| 128 |
-
for
|
| 129 |
-
if
|
| 130 |
return True
|
| 131 |
-
|
| 132 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
return True
|
| 134 |
else:
|
| 135 |
-
|
|
|
|
| 136 |
return True
|
| 137 |
return False
|
| 138 |
|
|
@@ -918,21 +938,27 @@ class RotatingClient:
|
|
| 918 |
lib_logger.debug(f"Returning cached models for provider: {provider}")
|
| 919 |
return self._model_list_cache[provider]
|
| 920 |
|
| 921 |
-
|
| 922 |
-
if not
|
| 923 |
-
lib_logger.warning(f"No
|
| 924 |
return []
|
| 925 |
|
| 926 |
-
# Create a copy and shuffle it to randomize the starting
|
| 927 |
-
|
| 928 |
-
random.shuffle(
|
| 929 |
|
| 930 |
provider_instance = self._get_provider_instance(provider)
|
| 931 |
if provider_instance:
|
| 932 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 933 |
try:
|
| 934 |
-
|
| 935 |
-
|
|
|
|
|
|
|
| 936 |
lib_logger.info(f"Got {len(models)} models for provider: {provider}")
|
| 937 |
|
| 938 |
# Whitelist and blacklist logic
|
|
@@ -955,7 +981,8 @@ class RotatingClient:
|
|
| 955 |
return final_models
|
| 956 |
except Exception as e:
|
| 957 |
classified_error = classify_error(e)
|
| 958 |
-
|
|
|
|
| 959 |
continue # Try the next credential
|
| 960 |
|
| 961 |
lib_logger.error(f"Failed to get models for provider {provider} after trying all credentials.")
|
|
@@ -964,11 +991,13 @@ class RotatingClient:
|
|
| 964 |
async def get_all_available_models(self, grouped: bool = True) -> Union[Dict[str, List[str]], List[str]]:
|
| 965 |
"""Returns a list of all available models, either grouped by provider or as a flat list."""
|
| 966 |
lib_logger.info("Getting all available models...")
|
| 967 |
-
|
|
|
|
|
|
|
| 968 |
results = await asyncio.gather(*tasks, return_exceptions=True)
|
| 969 |
|
| 970 |
all_provider_models = {}
|
| 971 |
-
for provider, result in zip(
|
| 972 |
if isinstance(result, Exception):
|
| 973 |
lib_logger.error(f"Failed to get models for provider {provider}: {result}")
|
| 974 |
all_provider_models[provider] = []
|
|
|
|
| 97 |
def _is_model_ignored(self, provider: str, model_id: str) -> bool:
|
| 98 |
"""
|
| 99 |
Checks if a model should be ignored based on the ignore list.
|
| 100 |
+
Supports exact and partial matching for both full model IDs and model names.
|
| 101 |
"""
|
| 102 |
+
model_provider = model_id.split('/')[0]
|
| 103 |
+
if model_provider not in self.ignore_models:
|
| 104 |
return False
|
| 105 |
|
| 106 |
+
ignore_list = self.ignore_models[model_provider]
|
| 107 |
if ignore_list == ['*']:
|
| 108 |
return True
|
| 109 |
+
|
| 110 |
+
try:
|
| 111 |
+
# This is the model name as the provider sees it (e.g., "gpt-4" or "google/gemma-7b")
|
| 112 |
+
provider_model_name = model_id.split('/', 1)[1]
|
| 113 |
+
except IndexError:
|
| 114 |
+
provider_model_name = model_id
|
| 115 |
+
|
| 116 |
+
for ignored_pattern in ignore_list:
|
| 117 |
+
if ignored_pattern.endswith('*'):
|
| 118 |
+
match_pattern = ignored_pattern[:-1]
|
| 119 |
+
# Match wildcard against the provider's model name
|
| 120 |
+
if provider_model_name.startswith(match_pattern):
|
| 121 |
return True
|
| 122 |
else:
|
| 123 |
+
# Exact match against the full proxy ID OR the provider's model name
|
| 124 |
+
if model_id == ignored_pattern or provider_model_name == ignored_pattern:
|
| 125 |
return True
|
| 126 |
return False
|
| 127 |
|
| 128 |
def _is_model_whitelisted(self, provider: str, model_id: str) -> bool:
|
| 129 |
"""
|
| 130 |
Checks if a model is explicitly whitelisted.
|
| 131 |
+
Supports exact and partial matching for both full model IDs and model names.
|
| 132 |
"""
|
| 133 |
+
model_provider = model_id.split('/')[0]
|
| 134 |
+
if model_provider not in self.whitelist_models:
|
| 135 |
return False
|
| 136 |
|
| 137 |
+
whitelist = self.whitelist_models[model_provider]
|
| 138 |
+
for whitelisted_pattern in whitelist:
|
| 139 |
+
if whitelisted_pattern == '*':
|
| 140 |
return True
|
| 141 |
+
|
| 142 |
+
try:
|
| 143 |
+
# This is the model name as the provider sees it (e.g., "gpt-4" or "google/gemma-7b")
|
| 144 |
+
provider_model_name = model_id.split('/', 1)[1]
|
| 145 |
+
except IndexError:
|
| 146 |
+
provider_model_name = model_id
|
| 147 |
+
|
| 148 |
+
if whitelisted_pattern.endswith('*'):
|
| 149 |
+
match_pattern = whitelisted_pattern[:-1]
|
| 150 |
+
# Match wildcard against the provider's model name
|
| 151 |
+
if provider_model_name.startswith(match_pattern):
|
| 152 |
return True
|
| 153 |
else:
|
| 154 |
+
# Exact match against the full proxy ID OR the provider's model name
|
| 155 |
+
if model_id == whitelisted_pattern or provider_model_name == whitelisted_pattern:
|
| 156 |
return True
|
| 157 |
return False
|
| 158 |
|
|
|
|
| 938 |
lib_logger.debug(f"Returning cached models for provider: {provider}")
|
| 939 |
return self._model_list_cache[provider]
|
| 940 |
|
| 941 |
+
credentials_for_provider = self.all_credentials.get(provider)
|
| 942 |
+
if not credentials_for_provider:
|
| 943 |
+
lib_logger.warning(f"No credentials for provider: {provider}")
|
| 944 |
return []
|
| 945 |
|
| 946 |
+
# Create a copy and shuffle it to randomize the starting credential
|
| 947 |
+
shuffled_credentials = list(credentials_for_provider)
|
| 948 |
+
random.shuffle(shuffled_credentials)
|
| 949 |
|
| 950 |
provider_instance = self._get_provider_instance(provider)
|
| 951 |
if provider_instance:
|
| 952 |
+
# For providers with hardcoded models (like gemini_cli), we only need to call once.
|
| 953 |
+
# For others, we might need to try multiple keys if one is invalid.
|
| 954 |
+
# The current logic of iterating works for both, as the credential is not
|
| 955 |
+
# always used in get_models.
|
| 956 |
+
for credential in shuffled_credentials:
|
| 957 |
try:
|
| 958 |
+
# Display last 6 chars for API keys, or the filename for OAuth paths
|
| 959 |
+
cred_display = credential[-6:] if not os.path.isfile(credential) else os.path.basename(credential)
|
| 960 |
+
lib_logger.debug(f"Attempting to get models for {provider} with credential ...{cred_display}")
|
| 961 |
+
models = await provider_instance.get_models(credential, self.http_client)
|
| 962 |
lib_logger.info(f"Got {len(models)} models for provider: {provider}")
|
| 963 |
|
| 964 |
# Whitelist and blacklist logic
|
|
|
|
| 981 |
return final_models
|
| 982 |
except Exception as e:
|
| 983 |
classified_error = classify_error(e)
|
| 984 |
+
cred_display = credential[-6:] if not os.path.isfile(credential) else os.path.basename(credential)
|
| 985 |
+
lib_logger.debug(f"Failed to get models for provider {provider} with credential ...{cred_display}: {classified_error.error_type}. Trying next credential.")
|
| 986 |
continue # Try the next credential
|
| 987 |
|
| 988 |
lib_logger.error(f"Failed to get models for provider {provider} after trying all credentials.")
|
|
|
|
| 991 |
async def get_all_available_models(self, grouped: bool = True) -> Union[Dict[str, List[str]], List[str]]:
|
| 992 |
"""Returns a list of all available models, either grouped by provider or as a flat list."""
|
| 993 |
lib_logger.info("Getting all available models...")
|
| 994 |
+
|
| 995 |
+
all_providers = list(self.all_credentials.keys())
|
| 996 |
+
tasks = [self.get_available_models(provider) for provider in all_providers]
|
| 997 |
results = await asyncio.gather(*tasks, return_exceptions=True)
|
| 998 |
|
| 999 |
all_provider_models = {}
|
| 1000 |
+
for provider, result in zip(all_providers, results):
|
| 1001 |
if isinstance(result, Exception):
|
| 1002 |
lib_logger.error(f"Failed to get models for provider {provider}: {result}")
|
| 1003 |
all_provider_models[provider] = []
|
src/rotator_library/providers/gemini_cli_provider.py
CHANGED
|
@@ -15,6 +15,13 @@ lib_logger = logging.getLogger('rotator_library')
|
|
| 15 |
|
| 16 |
CODE_ASSIST_ENDPOINT = "https://cloudcode-pa.googleapis.com/v1internal"
|
| 17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
|
| 19 |
def __init__(self):
|
| 20 |
super().__init__()
|
|
@@ -197,14 +204,8 @@ class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
|
|
| 197 |
chunks = [chunk async for chunk in response_gen]
|
| 198 |
return litellm.utils.stream_to_completion_response(chunks)
|
| 199 |
|
| 200 |
-
# [NEW] Hardcoded model list based on Kilo example
|
| 201 |
-
HARDCODED_MODELS = [
|
| 202 |
-
"gemini-2.5-pro",
|
| 203 |
-
"gemini-2.5-flash",
|
| 204 |
-
"gemini-2.5-flash-lite"
|
| 205 |
-
]
|
| 206 |
# Use the shared GeminiAuthBase for auth logic
|
| 207 |
# get_models is not applicable for this custom provider
|
| 208 |
-
async def get_models(self,
|
| 209 |
"""Returns a hardcoded list of known compatible Gemini CLI models."""
|
| 210 |
return [f"gemini_cli/{model_id}" for model_id in HARDCODED_MODELS]
|
|
|
|
| 15 |
|
| 16 |
CODE_ASSIST_ENDPOINT = "https://cloudcode-pa.googleapis.com/v1internal"
|
| 17 |
|
# [NEW] Hardcoded model list based on Kilo example. Module-level so both
# the provider class and its get_models implementation can reference it.
HARDCODED_MODELS = [
    "gemini-2.5-pro",
    "gemini-2.5-flash",
    "gemini-2.5-flash-lite",
]
| 24 |
+
|
| 25 |
class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
|
| 26 |
def __init__(self):
|
| 27 |
super().__init__()
|
|
|
|
| 204 |
chunks = [chunk async for chunk in response_gen]
|
| 205 |
return litellm.utils.stream_to_completion_response(chunks)
|
| 206 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
# Use the shared GeminiAuthBase for auth logic
|
| 208 |
# get_models is not applicable for this custom provider
|
| 209 |
+
async def get_models(self, credential: str, client: httpx.AsyncClient) -> List[str]:
|
| 210 |
"""Returns a hardcoded list of known compatible Gemini CLI models."""
|
| 211 |
return [f"gemini_cli/{model_id}" for model_id in HARDCODED_MODELS]
|
src/rotator_library/providers/qwen_code_provider.py
CHANGED
|
@@ -21,7 +21,7 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
|
|
| 21 |
return True # We use custom logic to handle 401 retries and stream parsing
|
| 22 |
|
| 23 |
# [NEW] get_models implementation
|
| 24 |
-
async def get_models(self,
|
| 25 |
"""Returns a hardcoded list of known compatible Qwen models for the OpenAI-compatible API."""
|
| 26 |
return [f"qwen_code/{model_id}" for model_id in HARDCODED_MODELS]
|
| 27 |
|
|
|
|
| 21 |
return True # We use custom logic to handle 401 retries and stream parsing
|
| 22 |
|
| 23 |
# [NEW] get_models implementation
|
| 24 |
+
async def get_models(self, credential: str, client: httpx.AsyncClient) -> List[str]:
|
| 25 |
"""Returns a hardcoded list of known compatible Qwen models for the OpenAI-compatible API."""
|
| 26 |
return [f"qwen_code/{model_id}" for model_id in HARDCODED_MODELS]
|
| 27 |
|