Spaces:
Paused
feat(provider): introduce OAuth credential management and custom provider handling
Browse files- Implemented a new `CredentialManager` to discover and manage OAuth credential files from standard paths (`~/.gemini`, `~/.qwen`).
- Added a `BackgroundRefresher` to proactively refresh OAuth tokens before they expire, ensuring continuous service.
- Refactored `RotatingClient` to support both API keys and OAuth credentials for provider authentication.
- Integrated `litellm_provider_params` in `RotatingClient` to allow specific LiteLLM configurations per provider (e.g., Google Cloud project ID for Gemini CLI).
- Introduced a `has_custom_logic` flag and `acompletion` method in `ProviderInterface` to enable custom handling for providers like Gemini CLI and Qwen Code, which require specific request formats, authentication, or stream parsing not fully supported by LiteLLM's standard interface.
- Updated `proxy_app/main.py` to utilize the new OAuth credential loading, provider-specific LiteLLM parameters, and the background token refresher.
- Enhanced `error_handler.py` to classify `httpx` exceptions, improving error reporting and retry logic for network and HTTP errors.
- Added `.env.example` entries for configuring Gemini CLI project ID and Qwen/Gemini OAuth credential paths.
BREAKING CHANGE: The constructor for `RotatingClient` has been updated. It now requires an `oauth_credentials` dictionary (can be empty) and accepts an optional `litellm_provider_params` dictionary. Direct instantiations of `RotatingClient` must be updated to include these new arguments.
- .env.example +12 -0
- src/proxy_app/main.py +32 -33
- src/rotator_library/background_refresher.py +57 -0
- src/rotator_library/client.py +234 -146
- src/rotator_library/credential_manager.py +70 -0
- src/rotator_library/error_handler.py +17 -2
- src/rotator_library/providers/__init__.py +2 -2
- src/rotator_library/providers/gemini_auth_base.py +102 -0
- src/rotator_library/providers/gemini_cli_provider.py +171 -0
- src/rotator_library/providers/provider_interface.py +38 -5
- src/rotator_library/providers/qwen_auth_base.py +101 -0
- src/rotator_library/providers/qwen_code_provider.py +71 -0
|
@@ -11,3 +11,15 @@ NVIDIA_NIM_API_KEY_2="YOUR_NVIDIA_NIM_API_KEY_2"
|
|
| 11 |
|
| 12 |
# A secret key for your proxy server to authenticate requests (can be anything; used for compatibility)
|
| 13 |
PROXY_API_KEY="YOUR_PROXY_API_KEY"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
# A secret key for your proxy server to authenticate requests (can be anything; used for compatibility)
|
| 13 |
PROXY_API_KEY="YOUR_PROXY_API_KEY"
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# --- OAuth Accounts ---
|
| 17 |
+
# The system will automatically discover standard paths if left blank.
|
| 18 |
+
|
| 19 |
+
# For Gemini CLI (uses a custom API)
|
| 20 |
+
GEMINI_CLI_OAUTH_1=
|
| 21 |
+
# Required for Gemini CLI: Your Google Cloud Project ID
|
| 22 |
+
GEMINI_CLI_PROJECT_ID="gen-lang-client-..."
|
| 23 |
+
|
| 24 |
+
# For Qwen Code (OpenAI Compatible)
|
| 25 |
+
QWEN_CODE_OAUTH_1=
|
|
@@ -52,6 +52,8 @@ args, _ = parser.parse_known_args()
|
|
| 52 |
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
| 53 |
|
| 54 |
from rotator_library import RotatingClient, PROVIDER_PLUGINS
|
|
|
|
|
|
|
| 55 |
from proxy_app.request_logger import log_request_to_console
|
| 56 |
from proxy_app.batch_manager import EmbeddingBatcher
|
| 57 |
from proxy_app.detailed_logger import DetailedLogger
|
|
@@ -125,19 +127,28 @@ PROXY_API_KEY = os.getenv("PROXY_API_KEY")
|
|
| 125 |
if not PROXY_API_KEY:
|
| 126 |
raise ValueError("PROXY_API_KEY environment variable not set.")
|
| 127 |
|
| 128 |
-
#
|
| 129 |
api_keys = {}
|
|
|
|
| 130 |
for key, value in os.environ.items():
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
if provider not in api_keys:
|
| 136 |
api_keys[provider] = []
|
| 137 |
api_keys[provider].append(value)
|
| 138 |
|
| 139 |
-
if not api_keys:
|
| 140 |
-
raise ValueError("No provider API keys found in environment variables.")
|
| 141 |
|
| 142 |
# Load model ignore lists from environment variables
|
| 143 |
ignore_models = {}
|
|
@@ -152,8 +163,20 @@ for key, value in os.environ.items():
|
|
| 152 |
@asynccontextmanager
|
| 153 |
async def lifespan(app: FastAPI):
|
| 154 |
"""Manage the RotatingClient's lifecycle with the app's lifespan."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
# The client now uses the root logger configuration
|
| 156 |
-
client = RotatingClient(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
app.state.rotating_client = client
|
| 158 |
os.environ["LITELLM_LOG"] = "ERROR"
|
| 159 |
litellm.set_verbose = False
|
|
@@ -168,6 +191,7 @@ async def lifespan(app: FastAPI):
|
|
| 168 |
|
| 169 |
yield
|
| 170 |
|
|
|
|
| 171 |
if app.state.embedding_batcher:
|
| 172 |
await app.state.embedding_batcher.stop()
|
| 173 |
await client.close()
|
|
@@ -477,20 +501,6 @@ async def embeddings(
|
|
| 477 |
|
| 478 |
response = await client.aembedding(request=request, **request_data)
|
| 479 |
|
| 480 |
-
if ENABLE_REQUEST_LOGGING:
|
| 481 |
-
response_summary = {
|
| 482 |
-
"model": response.model,
|
| 483 |
-
"object": response.object,
|
| 484 |
-
"usage": response.usage.model_dump(),
|
| 485 |
-
"data_count": len(response.data),
|
| 486 |
-
"embedding_dimensions": len(response.data[0].embedding) if response.data else 0
|
| 487 |
-
}
|
| 488 |
-
log_request_response(
|
| 489 |
-
request_data=body.model_dump(exclude_none=True),
|
| 490 |
-
response_data=response_summary,
|
| 491 |
-
is_streaming=False,
|
| 492 |
-
log_type="embedding"
|
| 493 |
-
)
|
| 494 |
return response
|
| 495 |
|
| 496 |
except HTTPException as e:
|
|
@@ -510,17 +520,6 @@ async def embeddings(
|
|
| 510 |
raise HTTPException(status_code=502, detail=f"Bad Gateway: {str(e)}")
|
| 511 |
except Exception as e:
|
| 512 |
logging.error(f"Embedding request failed: {e}")
|
| 513 |
-
if ENABLE_REQUEST_LOGGING:
|
| 514 |
-
try:
|
| 515 |
-
request_data = await request.json()
|
| 516 |
-
except json.JSONDecodeError:
|
| 517 |
-
request_data = {"error": "Could not parse request body"}
|
| 518 |
-
log_request_response(
|
| 519 |
-
request_data=request_data,
|
| 520 |
-
response_data={"error": str(e)},
|
| 521 |
-
is_streaming=False,
|
| 522 |
-
log_type="embedding"
|
| 523 |
-
)
|
| 524 |
raise HTTPException(status_code=500, detail=str(e))
|
| 525 |
|
| 526 |
@app.get("/")
|
|
|
|
| 52 |
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
| 53 |
|
| 54 |
from rotator_library import RotatingClient, PROVIDER_PLUGINS
|
| 55 |
+
from rotator_library.credential_manager import CredentialManager
|
| 56 |
+
from rotator_library.background_refresher import BackgroundRefresher
|
| 57 |
from proxy_app.request_logger import log_request_to_console
|
| 58 |
from proxy_app.batch_manager import EmbeddingBatcher
|
| 59 |
from proxy_app.detailed_logger import DetailedLogger
|
|
|
|
| 127 |
if not PROXY_API_KEY:
|
| 128 |
raise ValueError("PROXY_API_KEY environment variable not set.")
|
| 129 |
|
| 130 |
+
# Split API keys and OAuth config loading
|
| 131 |
api_keys = {}
|
| 132 |
+
oauth_credentials = {}
|
| 133 |
for key, value in os.environ.items():
|
| 134 |
+
if key == "PROXY_API_KEY":
|
| 135 |
+
continue
|
| 136 |
+
|
| 137 |
+
# Handles GEMINI_CLI_OAUTH_1, QWEN_CODE_OAUTH_1, etc.
|
| 138 |
+
if "_OAUTH_" in key:
|
| 139 |
+
provider = key.split("_OAUTH_")[0].lower()
|
| 140 |
+
if provider not in oauth_credentials:
|
| 141 |
+
oauth_credentials[provider] = []
|
| 142 |
+
oauth_credentials[provider].append(value)
|
| 143 |
+
# Handles GEMINI_API_KEY_1, etc.
|
| 144 |
+
elif "_API_KEY" in key:
|
| 145 |
+
provider = key.split("_API_KEY")[0].lower()
|
| 146 |
if provider not in api_keys:
|
| 147 |
api_keys[provider] = []
|
| 148 |
api_keys[provider].append(value)
|
| 149 |
|
| 150 |
+
if not api_keys and not oauth_credentials:
|
| 151 |
+
raise ValueError("No provider API keys or OAuth credentials found in environment variables.")
|
| 152 |
|
| 153 |
# Load model ignore lists from environment variables
|
| 154 |
ignore_models = {}
|
|
|
|
| 163 |
@asynccontextmanager
|
| 164 |
async def lifespan(app: FastAPI):
|
| 165 |
"""Manage the RotatingClient's lifecycle with the app's lifespan."""
|
| 166 |
+
# [NEW] Load provider-specific params
|
| 167 |
+
litellm_provider_params = {
|
| 168 |
+
"gemini_cli": {"project_id": os.getenv("GEMINI_CLI_PROJECT_ID")}
|
| 169 |
+
}
|
| 170 |
+
|
| 171 |
# The client now uses the root logger configuration
|
| 172 |
+
client = RotatingClient(
|
| 173 |
+
api_keys=api_keys,
|
| 174 |
+
oauth_credentials=oauth_credentials, # Pass OAuth config
|
| 175 |
+
configure_logging=True,
|
| 176 |
+
litellm_provider_params=litellm_provider_params, # [NEW]
|
| 177 |
+
ignore_models=ignore_models
|
| 178 |
+
)
|
| 179 |
+
client.background_refresher.start() # Start the background task
|
| 180 |
app.state.rotating_client = client
|
| 181 |
os.environ["LITELLM_LOG"] = "ERROR"
|
| 182 |
litellm.set_verbose = False
|
|
|
|
| 191 |
|
| 192 |
yield
|
| 193 |
|
| 194 |
+
await client.background_refresher.stop() # Stop the background task on shutdown
|
| 195 |
if app.state.embedding_batcher:
|
| 196 |
await app.state.embedding_batcher.stop()
|
| 197 |
await client.close()
|
|
|
|
| 501 |
|
| 502 |
response = await client.aembedding(request=request, **request_data)
|
| 503 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 504 |
return response
|
| 505 |
|
| 506 |
except HTTPException as e:
|
|
|
|
| 520 |
raise HTTPException(status_code=502, detail=f"Bad Gateway: {str(e)}")
|
| 521 |
except Exception as e:
|
| 522 |
logging.error(f"Embedding request failed: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 523 |
raise HTTPException(status_code=500, detail=str(e))
|
| 524 |
|
| 525 |
@app.get("/")
|
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# src/rotator_library/background_refresher.py

import asyncio
import logging
from typing import TYPE_CHECKING, Optional  # Optional was used below but never imported

if TYPE_CHECKING:
    from .client import RotatingClient

lib_logger = logging.getLogger('rotator_library')


class BackgroundRefresher:
    """
    A background task that periodically checks and refreshes OAuth tokens
    to ensure they remain valid.
    """

    def __init__(self, client: 'RotatingClient', interval_seconds: int = 300):
        # Owning client; used to look up OAuth credential paths and provider plugins.
        self._client = client
        # Seconds between proactive refresh checks.
        self._interval = interval_seconds
        # Handle to the running refresh loop; None while stopped.
        self._task: Optional[asyncio.Task] = None

    def start(self):
        """Starts the background refresh task. Idempotent while a task is running."""
        if self._task is None:
            self._task = asyncio.create_task(self._run())
            lib_logger.info(f"Background token refresher started. Check interval: {self._interval} seconds.")

    async def stop(self):
        """Stops the background refresh task and waits for it to exit."""
        if self._task:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass
            # Reset so start() can be called again; without this the refresher
            # could never be restarted after a stop.
            self._task = None
            lib_logger.info("Background token refresher stopped.")

    async def _run(self):
        """The main loop for the background task.

        Sleeps for the configured interval, then asks each OAuth provider
        plugin that supports it to proactively refresh every credential path.
        Individual refresh failures are logged and do not stop the loop.
        """
        while True:
            try:
                await asyncio.sleep(self._interval)
                lib_logger.info("Running proactive token refresh check...")

                oauth_configs = self._client.get_oauth_credentials()
                for provider, paths in oauth_configs.items():
                    # Plugins are registered under the "<provider>_oauth" name.
                    provider_plugin = self._client._get_provider_instance(f"{provider}_oauth")
                    if provider_plugin and hasattr(provider_plugin, 'proactively_refresh'):
                        for path in paths:
                            try:
                                await provider_plugin.proactively_refresh(path)
                            except Exception as e:
                                lib_logger.error(f"Error during proactive refresh for '{path}': {e}")
            except asyncio.CancelledError:
                # Normal shutdown path triggered by stop().
                break
            except Exception as e:
                # Never let an unexpected error kill the refresh loop.
                lib_logger.error(f"Unexpected error in background refresher loop: {e}")
|
|
@@ -24,6 +24,8 @@ from .error_handler import PreRequestCallbackError, classify_error, AllProviders
|
|
| 24 |
from .providers import PROVIDER_PLUGINS
|
| 25 |
from .request_sanitizer import sanitize_request_payload
|
| 26 |
from .cooldown_manager import CooldownManager
|
|
|
|
|
|
|
| 27 |
|
| 28 |
class StreamedAPIError(Exception):
|
| 29 |
"""Custom exception to signal an API error received over a stream."""
|
|
@@ -39,11 +41,13 @@ class RotatingClient:
|
|
| 39 |
def __init__(
|
| 40 |
self,
|
| 41 |
api_keys: Dict[str, List[str]],
|
|
|
|
| 42 |
max_retries: int = 2,
|
| 43 |
usage_file_path: str = "key_usage.json",
|
| 44 |
configure_logging: bool = True,
|
| 45 |
global_timeout: int = 30,
|
| 46 |
abort_on_callback_error: bool = True,
|
|
|
|
| 47 |
ignore_models: Optional[Dict[str, List[str]]] = None
|
| 48 |
):
|
| 49 |
os.environ["LITELLM_LOG"] = "ERROR"
|
|
@@ -63,6 +67,18 @@ class RotatingClient:
|
|
| 63 |
if not api_keys:
|
| 64 |
raise ValueError("API keys dictionary cannot be empty.")
|
| 65 |
self.api_keys = api_keys
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
self.max_retries = max_retries
|
| 67 |
self.global_timeout = global_timeout
|
| 68 |
self.abort_on_callback_error = abort_on_callback_error
|
|
@@ -73,6 +89,7 @@ class RotatingClient:
|
|
| 73 |
self.http_client = httpx.AsyncClient()
|
| 74 |
self.all_providers = AllProviders()
|
| 75 |
self.cooldown_manager = CooldownManager()
|
|
|
|
| 76 |
self.ignore_models = ignore_models or {}
|
| 77 |
|
| 78 |
def _is_model_ignored(self, provider: str, model_id: str) -> bool:
|
|
@@ -191,6 +208,9 @@ class RotatingClient:
|
|
| 191 |
|
| 192 |
return kwargs
|
| 193 |
|
|
|
|
|
|
|
|
|
|
| 194 |
def _get_provider_instance(self, provider_name: str):
|
| 195 |
"""Lazily initializes and returns a provider instance."""
|
| 196 |
if provider_name not in self._provider_instances:
|
|
@@ -338,8 +358,8 @@ class RotatingClient:
|
|
| 338 |
raise ValueError("'model' is a required parameter.")
|
| 339 |
|
| 340 |
provider = model.split('/')[0]
|
| 341 |
-
if provider not in self.
|
| 342 |
-
raise ValueError(f"No API keys configured for provider: {provider}")
|
| 343 |
|
| 344 |
# Establish a global deadline for the entire request lifecycle.
|
| 345 |
deadline = time.time() + self.global_timeout
|
|
@@ -347,16 +367,16 @@ class RotatingClient:
|
|
| 347 |
# Create a mutable copy of the keys and shuffle it to ensure
|
| 348 |
# that the key selection is randomized, which is crucial when
|
| 349 |
# multiple keys have the same usage stats.
|
| 350 |
-
|
| 351 |
-
random.shuffle(
|
| 352 |
|
| 353 |
-
|
| 354 |
last_exception = None
|
| 355 |
kwargs = self._convert_model_params(**kwargs)
|
| 356 |
-
|
| 357 |
-
# The main rotation loop. It continues as long as there are untried
|
| 358 |
-
while len(
|
| 359 |
-
|
| 360 |
key_acquired = False
|
| 361 |
try:
|
| 362 |
# Check for a provider-wide cooldown first.
|
|
@@ -372,129 +392,167 @@ class RotatingClient:
|
|
| 372 |
lib_logger.warning(f"Provider {provider} is in cooldown. Waiting for {remaining_cooldown:.2f} seconds.")
|
| 373 |
await asyncio.sleep(remaining_cooldown)
|
| 374 |
|
| 375 |
-
|
| 376 |
-
if not
|
| 377 |
break
|
| 378 |
|
| 379 |
-
lib_logger.info(f"Acquiring key for model {model}. Tried keys: {len(
|
| 380 |
-
|
| 381 |
-
available_keys=
|
| 382 |
model=model,
|
| 383 |
deadline=deadline
|
| 384 |
)
|
| 385 |
key_acquired = True
|
| 386 |
-
|
| 387 |
|
| 388 |
litellm_kwargs = self.all_providers.get_provider_kwargs(**kwargs.copy())
|
| 389 |
-
provider_instance = self._get_provider_instance(provider)
|
| 390 |
-
if provider_instance:
|
| 391 |
-
if "safety_settings" in litellm_kwargs:
|
| 392 |
-
converted_settings = provider_instance.convert_safety_settings(litellm_kwargs["safety_settings"])
|
| 393 |
-
if converted_settings is not None:
|
| 394 |
-
litellm_kwargs["safety_settings"] = converted_settings
|
| 395 |
-
else:
|
| 396 |
-
del litellm_kwargs["safety_settings"]
|
| 397 |
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
|
|
|
| 403 |
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
lib_logger.warning(f"Pre-request callback failed but abort_on_callback_error is False. Proceeding with request. Error: {e}")
|
| 418 |
-
|
| 419 |
-
response = await api_call(
|
| 420 |
-
api_key=current_key,
|
| 421 |
-
**litellm_kwargs,
|
| 422 |
-
logger_fn=self._litellm_logger_callback
|
| 423 |
-
)
|
| 424 |
-
|
| 425 |
-
await self.usage_manager.record_success(current_key, model, response)
|
| 426 |
-
await self.usage_manager.release_key(current_key, model)
|
| 427 |
key_acquired = False
|
| 428 |
return response
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 429 |
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 455 |
error_message = str(e).split('\n')[0]
|
| 456 |
-
lib_logger.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 457 |
break # Move to the next key
|
| 458 |
-
|
| 459 |
-
# For temporary errors, wait before retrying with the same key.
|
| 460 |
-
wait_time = classified_error.retry_after or (1 * (2 ** attempt)) + random.uniform(0, 1)
|
| 461 |
-
remaining_budget = deadline - time.time()
|
| 462 |
-
|
| 463 |
-
# If the required wait time exceeds the budget, don't wait; rotate to the next key immediately.
|
| 464 |
-
if wait_time > remaining_budget:
|
| 465 |
-
lib_logger.warning(f"Required retry wait time ({wait_time:.2f}s) exceeds remaining budget ({remaining_budget:.2f}s). Rotating key early.")
|
| 466 |
-
break
|
| 467 |
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
|
|
|
| 472 |
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
#
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 495 |
finally:
|
| 496 |
-
if key_acquired and
|
| 497 |
-
await self.usage_manager.release_key(
|
| 498 |
|
| 499 |
if last_exception:
|
| 500 |
# Log the final error but do not raise it, as per the new requirement.
|
|
@@ -510,19 +568,19 @@ class RotatingClient:
|
|
| 510 |
provider = model.split('/')[0]
|
| 511 |
|
| 512 |
# Create a mutable copy of the keys and shuffle it.
|
| 513 |
-
|
| 514 |
-
random.shuffle(
|
| 515 |
|
| 516 |
deadline = time.time() + self.global_timeout
|
| 517 |
-
|
| 518 |
last_exception = None
|
| 519 |
kwargs = self._convert_model_params(**kwargs)
|
| 520 |
|
| 521 |
consecutive_quota_failures = 0
|
| 522 |
|
| 523 |
try:
|
| 524 |
-
while len(
|
| 525 |
-
|
| 526 |
key_acquired = False
|
| 527 |
try:
|
| 528 |
if await self.cooldown_manager.is_cooling_down(provider):
|
|
@@ -534,21 +592,52 @@ class RotatingClient:
|
|
| 534 |
lib_logger.warning(f"Provider {provider} is in a global cooldown. All requests to this provider will be paused for {remaining_cooldown:.2f} seconds.")
|
| 535 |
await asyncio.sleep(remaining_cooldown)
|
| 536 |
|
| 537 |
-
|
| 538 |
-
if not
|
| 539 |
-
lib_logger.warning(f"All
|
| 540 |
break
|
| 541 |
|
| 542 |
-
lib_logger.info(f"Acquiring
|
| 543 |
-
|
| 544 |
-
available_keys=
|
| 545 |
model=model,
|
| 546 |
deadline=deadline
|
| 547 |
)
|
| 548 |
key_acquired = True
|
| 549 |
-
|
| 550 |
|
| 551 |
litellm_kwargs = self.all_providers.get_provider_kwargs(**kwargs.copy())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 552 |
provider_instance = self._get_provider_instance(provider)
|
| 553 |
if provider_instance:
|
| 554 |
if "safety_settings" in litellm_kwargs:
|
|
@@ -568,7 +657,7 @@ class RotatingClient:
|
|
| 568 |
|
| 569 |
for attempt in range(self.max_retries):
|
| 570 |
try:
|
| 571 |
-
lib_logger.info(f"Attempting stream with
|
| 572 |
|
| 573 |
if pre_request_callback:
|
| 574 |
try:
|
|
@@ -580,15 +669,14 @@ class RotatingClient:
|
|
| 580 |
lib_logger.warning(f"Pre-request callback failed but abort_on_callback_error is False. Proceeding with request. Error: {e}")
|
| 581 |
|
| 582 |
response = await litellm.acompletion(
|
| 583 |
-
api_key=current_key,
|
| 584 |
**litellm_kwargs,
|
| 585 |
logger_fn=self._litellm_logger_callback
|
| 586 |
)
|
| 587 |
|
| 588 |
-
lib_logger.info(f"Stream connection established for
|
| 589 |
|
| 590 |
key_acquired = False
|
| 591 |
-
stream_generator = self._safe_streaming_wrapper(response,
|
| 592 |
|
| 593 |
async for chunk in stream_generator:
|
| 594 |
yield chunk
|
|
@@ -618,7 +706,7 @@ class RotatingClient:
|
|
| 618 |
|
| 619 |
# Now, log the failure with the extracted raw response.
|
| 620 |
log_failure(
|
| 621 |
-
api_key=
|
| 622 |
model=model,
|
| 623 |
attempt=attempt + 1,
|
| 624 |
error=e,
|
|
@@ -633,7 +721,7 @@ class RotatingClient:
|
|
| 633 |
|
| 634 |
if "quota" in error_message_text.lower() or "resource_exhausted" in error_status.lower():
|
| 635 |
consecutive_quota_failures += 1
|
| 636 |
-
lib_logger.warning(f"
|
| 637 |
|
| 638 |
quota_value = "N/A"
|
| 639 |
quota_id = "N/A"
|
|
@@ -648,11 +736,11 @@ class RotatingClient:
|
|
| 648 |
if quota_value != "N/A" and quota_id != "N/A":
|
| 649 |
break
|
| 650 |
|
| 651 |
-
await self.usage_manager.record_failure(
|
| 652 |
|
| 653 |
if consecutive_quota_failures >= 3:
|
| 654 |
console_log_message = (
|
| 655 |
-
f"Terminating stream for
|
| 656 |
f"This is now considered a fatal input data error. ID: {quota_id}, Limit: {quota_value}."
|
| 657 |
)
|
| 658 |
client_error_message = (
|
|
@@ -668,31 +756,31 @@ class RotatingClient:
|
|
| 668 |
|
| 669 |
else:
|
| 670 |
# [MODIFIED] Do not yield to the client. Just log and break to rotate the key.
|
| 671 |
-
lib_logger.warning(f"Quota error on
|
| 672 |
break
|
| 673 |
|
| 674 |
else:
|
| 675 |
consecutive_quota_failures = 0
|
| 676 |
# [MODIFIED] Do not yield to the client. Just log and break to rotate the key.
|
| 677 |
-
lib_logger.warning(f"
|
| 678 |
|
| 679 |
if classified_error.error_type == 'rate_limit' and classified_error.status_code == 429:
|
| 680 |
cooldown_duration = classified_error.retry_after or 60
|
| 681 |
await self.cooldown_manager.start_cooldown(provider, cooldown_duration)
|
| 682 |
lib_logger.warning(f"IP-based rate limit detected for {provider}. Starting a {cooldown_duration}-second global cooldown.")
|
| 683 |
|
| 684 |
-
await self.usage_manager.record_failure(
|
| 685 |
break
|
| 686 |
|
| 687 |
except (APIConnectionError, litellm.InternalServerError, litellm.ServiceUnavailableError) as e:
|
| 688 |
consecutive_quota_failures = 0
|
| 689 |
last_exception = e
|
| 690 |
-
log_failure(api_key=
|
| 691 |
classified_error = classify_error(e)
|
| 692 |
-
await self.usage_manager.record_failure(
|
| 693 |
|
| 694 |
if attempt >= self.max_retries - 1:
|
| 695 |
-
lib_logger.warning(f"
|
| 696 |
# [MODIFIED] Do not yield to the client here.
|
| 697 |
break
|
| 698 |
|
|
@@ -703,17 +791,17 @@ class RotatingClient:
|
|
| 703 |
break
|
| 704 |
|
| 705 |
error_message = str(e).split('\n')[0]
|
| 706 |
-
lib_logger.warning(f"
|
| 707 |
await asyncio.sleep(wait_time)
|
| 708 |
continue
|
| 709 |
|
| 710 |
except Exception as e:
|
| 711 |
consecutive_quota_failures = 0
|
| 712 |
last_exception = e
|
| 713 |
-
log_failure(api_key=
|
| 714 |
classified_error = classify_error(e)
|
| 715 |
|
| 716 |
-
lib_logger.warning(f"
|
| 717 |
|
| 718 |
if classified_error.status_code == 429:
|
| 719 |
cooldown_duration = classified_error.retry_after or 60
|
|
@@ -724,12 +812,12 @@ class RotatingClient:
|
|
| 724 |
raise last_exception
|
| 725 |
|
| 726 |
# [MODIFIED] Do not yield to the client here.
|
| 727 |
-
await self.usage_manager.record_failure(
|
| 728 |
break
|
| 729 |
|
| 730 |
finally:
|
| 731 |
-
if key_acquired and
|
| 732 |
-
await self.usage_manager.release_key(
|
| 733 |
|
| 734 |
final_error_message = "Failed to complete the streaming request: No available API keys after rotation or global timeout exceeded."
|
| 735 |
if last_exception:
|
|
|
|
| 24 |
from .providers import PROVIDER_PLUGINS
|
| 25 |
from .request_sanitizer import sanitize_request_payload
|
| 26 |
from .cooldown_manager import CooldownManager
|
| 27 |
+
from .credential_manager import CredentialManager
|
| 28 |
+
from .background_refresher import BackgroundRefresher
|
| 29 |
|
| 30 |
class StreamedAPIError(Exception):
|
| 31 |
"""Custom exception to signal an API error received over a stream."""
|
|
|
|
| 41 |
def __init__(
|
| 42 |
self,
|
| 43 |
api_keys: Dict[str, List[str]],
|
| 44 |
+
oauth_credentials: Dict[str, List[str]],
|
| 45 |
max_retries: int = 2,
|
| 46 |
usage_file_path: str = "key_usage.json",
|
| 47 |
configure_logging: bool = True,
|
| 48 |
global_timeout: int = 30,
|
| 49 |
abort_on_callback_error: bool = True,
|
| 50 |
+
litellm_provider_params: Optional[Dict[str, Any]] = None, # [NEW]
|
| 51 |
ignore_models: Optional[Dict[str, List[str]]] = None
|
| 52 |
):
|
| 53 |
os.environ["LITELLM_LOG"] = "ERROR"
|
|
|
|
| 67 |
if not api_keys:
|
| 68 |
raise ValueError("API keys dictionary cannot be empty.")
|
| 69 |
self.api_keys = api_keys
|
| 70 |
+
self.credential_manager = CredentialManager(oauth_credentials)
|
| 71 |
+
self.oauth_credentials = self.credential_manager.discover_and_prepare()
|
| 72 |
+
self.background_refresher = BackgroundRefresher(self)
|
| 73 |
+
self.oauth_providers = set(self.oauth_credentials.keys())
|
| 74 |
+
|
| 75 |
+
all_credentials = {}
|
| 76 |
+
for provider, keys in api_keys.items():
|
| 77 |
+
all_credentials.setdefault(provider, []).extend(keys)
|
| 78 |
+
for provider, paths in self.oauth_credentials.items():
|
| 79 |
+
all_credentials.setdefault(provider, []).extend(paths)
|
| 80 |
+
self.all_credentials = all_credentials
|
| 81 |
+
|
| 82 |
self.max_retries = max_retries
|
| 83 |
self.global_timeout = global_timeout
|
| 84 |
self.abort_on_callback_error = abort_on_callback_error
|
|
|
|
| 89 |
self.http_client = httpx.AsyncClient()
|
| 90 |
self.all_providers = AllProviders()
|
| 91 |
self.cooldown_manager = CooldownManager()
|
| 92 |
+
self.litellm_provider_params = litellm_provider_params or {}
|
| 93 |
self.ignore_models = ignore_models or {}
|
| 94 |
|
| 95 |
def _is_model_ignored(self, provider: str, model_id: str) -> bool:
|
|
|
|
| 208 |
|
| 209 |
return kwargs
|
| 210 |
|
| 211 |
+
def get_oauth_credentials(self) -> Dict[str, List[str]]:
|
| 212 |
+
return self.oauth_credentials
|
| 213 |
+
|
| 214 |
def _get_provider_instance(self, provider_name: str):
|
| 215 |
"""Lazily initializes and returns a provider instance."""
|
| 216 |
if provider_name not in self._provider_instances:
|
|
|
|
| 358 |
raise ValueError("'model' is a required parameter.")
|
| 359 |
|
| 360 |
provider = model.split('/')[0]
|
| 361 |
+
if provider not in self.all_credentials:
|
| 362 |
+
raise ValueError(f"No API keys or OAuth credentials configured for provider: {provider}")
|
| 363 |
|
| 364 |
# Establish a global deadline for the entire request lifecycle.
|
| 365 |
deadline = time.time() + self.global_timeout
|
|
|
|
| 367 |
# Create a mutable copy of the keys and shuffle it to ensure
|
| 368 |
# that the key selection is randomized, which is crucial when
|
| 369 |
# multiple keys have the same usage stats.
|
| 370 |
+
credentials_for_provider = list(self.all_credentials[provider])
|
| 371 |
+
random.shuffle(credentials_for_provider)
|
| 372 |
|
| 373 |
+
tried_creds = set()
|
| 374 |
last_exception = None
|
| 375 |
kwargs = self._convert_model_params(**kwargs)
|
| 376 |
+
|
| 377 |
+
# The main rotation loop. It continues as long as there are untried credentials and the global deadline has not been exceeded.
|
| 378 |
+
while len(tried_creds) < len(credentials_for_provider) and time.time() < deadline:
|
| 379 |
+
current_cred = None
|
| 380 |
key_acquired = False
|
| 381 |
try:
|
| 382 |
# Check for a provider-wide cooldown first.
|
|
|
|
| 392 |
lib_logger.warning(f"Provider {provider} is in cooldown. Waiting for {remaining_cooldown:.2f} seconds.")
|
| 393 |
await asyncio.sleep(remaining_cooldown)
|
| 394 |
|
| 395 |
+
creds_to_try = [c for c in credentials_for_provider if c not in tried_creds]
|
| 396 |
+
if not creds_to_try:
|
| 397 |
break
|
| 398 |
|
| 399 |
+
lib_logger.info(f"Acquiring key for model {model}. Tried keys: {len(tried_creds)}/{len(credentials_for_provider)}")
|
| 400 |
+
current_cred = await self.usage_manager.acquire_key(
|
| 401 |
+
available_keys=creds_to_try,
|
| 402 |
model=model,
|
| 403 |
deadline=deadline
|
| 404 |
)
|
| 405 |
key_acquired = True
|
| 406 |
+
tried_creds.add(current_cred)
|
| 407 |
|
| 408 |
litellm_kwargs = self.all_providers.get_provider_kwargs(**kwargs.copy())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 409 |
|
| 410 |
+
# [NEW] Merge provider-specific params
|
| 411 |
+
if provider in self.litellm_provider_params:
|
| 412 |
+
litellm_kwargs["litellm_params"] = {
|
| 413 |
+
**self.litellm_provider_params[provider],
|
| 414 |
+
**litellm_kwargs.get("litellm_params", {})
|
| 415 |
+
}
|
| 416 |
|
| 417 |
+
provider_plugin = self._get_provider_instance(provider)
|
| 418 |
+
if provider_plugin and provider_plugin.has_custom_logic():
|
| 419 |
+
lib_logger.debug(f"Provider '{provider}' has custom logic. Delegating call.")
|
| 420 |
+
litellm_kwargs["credential_identifier"] = current_cred
|
| 421 |
+
|
| 422 |
+
# The plugin handles the entire call, including retries on 401, etc.
|
| 423 |
+
# The main retry loop here is for key rotation on other errors.
|
| 424 |
+
response = await provider_plugin.acompletion(self.http_client, **litellm_kwargs)
|
| 425 |
+
|
| 426 |
+
# For non-streaming, success is immediate
|
| 427 |
+
if not kwargs.get("stream"):
|
| 428 |
+
await self.usage_manager.record_success(current_cred, model, response)
|
| 429 |
+
await self.usage_manager.release_key(current_cred, model)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 430 |
key_acquired = False
|
| 431 |
return response
|
| 432 |
+
else:
|
| 433 |
+
# For streaming, wrap the response and return
|
| 434 |
+
key_acquired = False
|
| 435 |
+
stream_generator = self._safe_streaming_wrapper(response, current_cred, model, request)
|
| 436 |
+
async for chunk in stream_generator:
|
| 437 |
+
yield chunk
|
| 438 |
+
return
|
| 439 |
+
|
| 440 |
+
else: # This is the standard API Key / litellm-handled provider logic
|
| 441 |
+
is_oauth = provider in self.oauth_providers
|
| 442 |
+
if is_oauth: # Standard OAuth provider (not custom)
|
| 443 |
+
# ... (logic to set headers) ...
|
| 444 |
+
pass
|
| 445 |
+
else: # API Key
|
| 446 |
+
litellm_kwargs["api_key"] = current_cred
|
| 447 |
+
|
| 448 |
+
provider_instance = self._get_provider_instance(provider)
|
| 449 |
+
if provider_instance:
|
| 450 |
+
if "safety_settings" in litellm_kwargs:
|
| 451 |
+
converted_settings = provider_instance.convert_safety_settings(litellm_kwargs["safety_settings"])
|
| 452 |
+
if converted_settings is not None:
|
| 453 |
+
litellm_kwargs["safety_settings"] = converted_settings
|
| 454 |
+
else:
|
| 455 |
+
del litellm_kwargs["safety_settings"]
|
| 456 |
+
|
| 457 |
+
if provider == "gemini" and provider_instance:
|
| 458 |
+
provider_instance.handle_thinking_parameter(litellm_kwargs, model)
|
| 459 |
|
| 460 |
+
if "gemma-3" in model and "messages" in litellm_kwargs:
|
| 461 |
+
litellm_kwargs["messages"] = [{"role": "user", "content": m["content"]} if m.get("role") == "system" else m for m in litellm_kwargs["messages"]]
|
| 462 |
+
|
| 463 |
+
litellm_kwargs = sanitize_request_payload(litellm_kwargs, model)
|
| 464 |
+
|
| 465 |
+
for attempt in range(self.max_retries):
|
| 466 |
+
try:
|
| 467 |
+
lib_logger.info(f"Attempting call with credential ...{current_cred[-6:]} (Attempt {attempt + 1}/{self.max_retries})")
|
| 468 |
+
|
| 469 |
+
if pre_request_callback:
|
| 470 |
+
try:
|
| 471 |
+
await pre_request_callback(request, litellm_kwargs)
|
| 472 |
+
except Exception as e:
|
| 473 |
+
if self.abort_on_callback_error:
|
| 474 |
+
raise PreRequestCallbackError(f"Pre-request callback failed: {e}") from e
|
| 475 |
+
else:
|
| 476 |
+
lib_logger.warning(f"Pre-request callback failed but abort_on_callback_error is False. Proceeding with request. Error: {e}")
|
| 477 |
+
|
| 478 |
+
response = await api_call(
|
| 479 |
+
**litellm_kwargs,
|
| 480 |
+
logger_fn=self._litellm_logger_callback
|
| 481 |
+
)
|
| 482 |
+
|
| 483 |
+
await self.usage_manager.record_success(current_cred, model, response)
|
| 484 |
+
await self.usage_manager.release_key(current_cred, model)
|
| 485 |
+
key_acquired = False
|
| 486 |
+
return response
|
| 487 |
+
|
| 488 |
+
except litellm.RateLimitError as e:
|
| 489 |
+
last_exception = e
|
| 490 |
+
log_failure(api_key=current_cred, model=model, attempt=attempt + 1, error=e, request_headers=dict(request.headers) if request else {})
|
| 491 |
+
classified_error = classify_error(e)
|
| 492 |
+
|
| 493 |
+
# Extract a clean error message for the user-facing log
|
| 494 |
error_message = str(e).split('\n')[0]
|
| 495 |
+
lib_logger.info(f"Key ...{current_cred[-6:]} hit rate limit for model {model}. Reason: '{error_message}'. Rotating key.")
|
| 496 |
+
|
| 497 |
+
if classified_error.status_code == 429:
|
| 498 |
+
cooldown_duration = classified_error.retry_after or 60
|
| 499 |
+
await self.cooldown_manager.start_cooldown(provider, cooldown_duration)
|
| 500 |
+
lib_logger.warning(f"IP-based rate limit detected for {provider}. Starting a {cooldown_duration}-second global cooldown.")
|
| 501 |
+
|
| 502 |
+
await self.usage_manager.record_failure(current_cred, model, classified_error)
|
| 503 |
+
lib_logger.warning(f"Key ...{current_cred[-6:]} encountered a rate limit. Trying next key.")
|
| 504 |
break # Move to the next key
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 505 |
|
| 506 |
+
except (APIConnectionError, litellm.InternalServerError, litellm.ServiceUnavailableError) as e:
|
| 507 |
+
last_exception = e
|
| 508 |
+
log_failure(api_key=current_cred, model=model, attempt=attempt + 1, error=e, request_headers=dict(request.headers) if request else {})
|
| 509 |
+
classified_error = classify_error(e)
|
| 510 |
+
await self.usage_manager.record_failure(current_cred, model, classified_error)
|
| 511 |
|
| 512 |
+
if attempt >= self.max_retries - 1:
|
| 513 |
+
error_message = str(e).split('\n')[0]
|
| 514 |
+
lib_logger.warning(f"Key ...{current_cred[-6:]} failed after max retries for model {model} due to a server error. Reason: '{error_message}'. Rotating key.")
|
| 515 |
+
break # Move to the next key
|
| 516 |
+
|
| 517 |
+
# For temporary errors, wait before retrying with the same key.
|
| 518 |
+
wait_time = classified_error.retry_after or (1 * (2 ** attempt)) + random.uniform(0, 1)
|
| 519 |
+
remaining_budget = deadline - time.time()
|
| 520 |
+
|
| 521 |
+
# If the required wait time exceeds the budget, don't wait; rotate to the next key immediately.
|
| 522 |
+
if wait_time > remaining_budget:
|
| 523 |
+
lib_logger.warning(f"Required retry wait time ({wait_time:.2f}s) exceeds remaining budget ({remaining_budget:.2f}s). Rotating key early.")
|
| 524 |
+
break
|
| 525 |
+
|
| 526 |
+
error_message = str(e).split('\n')[0]
|
| 527 |
+
lib_logger.warning(f"Key ...{current_cred[-6:]} encountered a server error for model {model}. Reason: '{error_message}'. Retrying in {wait_time:.2f}s.")
|
| 528 |
+
await asyncio.sleep(wait_time)
|
| 529 |
+
continue # Retry with the same key
|
| 530 |
+
|
| 531 |
+
except Exception as e:
|
| 532 |
+
last_exception = e
|
| 533 |
+
log_failure(api_key=current_cred, model=model, attempt=attempt + 1, error=e, request_headers=dict(request.headers) if request else {})
|
| 534 |
+
|
| 535 |
+
if request and await request.is_disconnected():
|
| 536 |
+
lib_logger.warning(f"Client disconnected. Aborting retries for key ...{current_cred[-6:]}.")
|
| 537 |
+
raise last_exception
|
| 538 |
+
|
| 539 |
+
classified_error = classify_error(e)
|
| 540 |
+
error_message = str(e).split('\n')[0]
|
| 541 |
+
lib_logger.warning(f"Key ...{current_cred[-6:]} failed with {classified_error.error_type} (Status: {classified_error.status_code}). Error: {error_message}. Rotating key.")
|
| 542 |
+
if classified_error.status_code == 429:
|
| 543 |
+
cooldown_duration = classified_error.retry_after or 60
|
| 544 |
+
await self.cooldown_manager.start_cooldown(provider, cooldown_duration)
|
| 545 |
+
lib_logger.warning(f"IP-based rate limit detected for {provider} from generic exception. Starting a {cooldown_duration}-second global cooldown.")
|
| 546 |
+
|
| 547 |
+
if classified_error.error_type in ['invalid_request', 'context_window_exceeded', 'authentication']:
|
| 548 |
+
# For these errors, we should not retry with other keys.
|
| 549 |
+
raise last_exception
|
| 550 |
+
|
| 551 |
+
await self.usage_manager.record_failure(current_cred, model, classified_error)
|
| 552 |
+
break # Try next key for other errors
|
| 553 |
finally:
|
| 554 |
+
if key_acquired and current_cred:
|
| 555 |
+
await self.usage_manager.release_key(current_cred, model)
|
| 556 |
|
| 557 |
if last_exception:
|
| 558 |
# Log the final error but do not raise it, as per the new requirement.
|
|
|
|
| 568 |
provider = model.split('/')[0]
|
| 569 |
|
| 570 |
# Create a mutable copy of the keys and shuffle it.
|
| 571 |
+
credentials_for_provider = list(self.all_credentials[provider])
|
| 572 |
+
random.shuffle(credentials_for_provider)
|
| 573 |
|
| 574 |
deadline = time.time() + self.global_timeout
|
| 575 |
+
tried_creds = set()
|
| 576 |
last_exception = None
|
| 577 |
kwargs = self._convert_model_params(**kwargs)
|
| 578 |
|
| 579 |
consecutive_quota_failures = 0
|
| 580 |
|
| 581 |
try:
|
| 582 |
+
while len(tried_creds) < len(credentials_for_provider) and time.time() < deadline:
|
| 583 |
+
current_cred = None
|
| 584 |
key_acquired = False
|
| 585 |
try:
|
| 586 |
if await self.cooldown_manager.is_cooling_down(provider):
|
|
|
|
| 592 |
lib_logger.warning(f"Provider {provider} is in a global cooldown. All requests to this provider will be paused for {remaining_cooldown:.2f} seconds.")
|
| 593 |
await asyncio.sleep(remaining_cooldown)
|
| 594 |
|
| 595 |
+
creds_to_try = [c for c in credentials_for_provider if c not in tried_creds]
|
| 596 |
+
if not creds_to_try:
|
| 597 |
+
lib_logger.warning(f"All credentials for provider {provider} have been tried. No more credentials to rotate to.")
|
| 598 |
break
|
| 599 |
|
| 600 |
+
lib_logger.info(f"Acquiring credential for model {model}. Tried credentials: {len(tried_creds)}/{len(credentials_for_provider)}")
|
| 601 |
+
current_cred = await self.usage_manager.acquire_key(
|
| 602 |
+
available_keys=creds_to_try,
|
| 603 |
model=model,
|
| 604 |
deadline=deadline
|
| 605 |
)
|
| 606 |
key_acquired = True
|
| 607 |
+
tried_creds.add(current_cred)
|
| 608 |
|
| 609 |
litellm_kwargs = self.all_providers.get_provider_kwargs(**kwargs.copy())
|
| 610 |
+
|
| 611 |
+
# [NEW] Merge provider-specific params
|
| 612 |
+
if provider in self.litellm_provider_params:
|
| 613 |
+
litellm_kwargs["litellm_params"] = {
|
| 614 |
+
**self.litellm_provider_params[provider],
|
| 615 |
+
**litellm_kwargs.get("litellm_params", {})
|
| 616 |
+
}
|
| 617 |
+
|
| 618 |
+
provider_plugin = self._get_provider_instance(provider)
|
| 619 |
+
if provider_plugin and provider_plugin.has_custom_logic():
|
| 620 |
+
lib_logger.debug(f"Provider '{provider}' has custom logic. Delegating call.")
|
| 621 |
+
litellm_kwargs["credential_identifier"] = current_cred
|
| 622 |
+
|
| 623 |
+
# The plugin handles the entire call, including retries on 401, etc.
|
| 624 |
+
# The main retry loop here is for key rotation on other errors.
|
| 625 |
+
response = await provider_plugin.acompletion(self.http_client, **litellm_kwargs)
|
| 626 |
+
|
| 627 |
+
key_acquired = False
|
| 628 |
+
stream_generator = self._safe_streaming_wrapper(response, current_cred, model, request)
|
| 629 |
+
async for chunk in stream_generator:
|
| 630 |
+
yield chunk
|
| 631 |
+
return
|
| 632 |
+
|
| 633 |
+
else: # This is the standard API Key / litellm-handled provider logic
|
| 634 |
+
is_oauth = provider in self.oauth_providers
|
| 635 |
+
if is_oauth: # Standard OAuth provider (not custom)
|
| 636 |
+
# ... (logic to set headers) ...
|
| 637 |
+
pass
|
| 638 |
+
else: # API Key
|
| 639 |
+
litellm_kwargs["api_key"] = current_cred
|
| 640 |
+
|
| 641 |
provider_instance = self._get_provider_instance(provider)
|
| 642 |
if provider_instance:
|
| 643 |
if "safety_settings" in litellm_kwargs:
|
|
|
|
| 657 |
|
| 658 |
for attempt in range(self.max_retries):
|
| 659 |
try:
|
| 660 |
+
lib_logger.info(f"Attempting stream with credential ...{current_cred[-6:]} (Attempt {attempt + 1}/{self.max_retries})")
|
| 661 |
|
| 662 |
if pre_request_callback:
|
| 663 |
try:
|
|
|
|
| 669 |
lib_logger.warning(f"Pre-request callback failed but abort_on_callback_error is False. Proceeding with request. Error: {e}")
|
| 670 |
|
| 671 |
response = await litellm.acompletion(
|
|
|
|
| 672 |
**litellm_kwargs,
|
| 673 |
logger_fn=self._litellm_logger_callback
|
| 674 |
)
|
| 675 |
|
| 676 |
+
lib_logger.info(f"Stream connection established for credential ...{current_cred[-6:]}. Processing response.")
|
| 677 |
|
| 678 |
key_acquired = False
|
| 679 |
+
stream_generator = self._safe_streaming_wrapper(response, current_cred, model, request)
|
| 680 |
|
| 681 |
async for chunk in stream_generator:
|
| 682 |
yield chunk
|
|
|
|
| 706 |
|
| 707 |
# Now, log the failure with the extracted raw response.
|
| 708 |
log_failure(
|
| 709 |
+
api_key=current_cred,
|
| 710 |
model=model,
|
| 711 |
attempt=attempt + 1,
|
| 712 |
error=e,
|
|
|
|
| 721 |
|
| 722 |
if "quota" in error_message_text.lower() or "resource_exhausted" in error_status.lower():
|
| 723 |
consecutive_quota_failures += 1
|
| 724 |
+
lib_logger.warning(f"Credential ...{current_cred[-6:]} hit a quota limit. This is consecutive failure #{consecutive_quota_failures} for this request.")
|
| 725 |
|
| 726 |
quota_value = "N/A"
|
| 727 |
quota_id = "N/A"
|
|
|
|
| 736 |
if quota_value != "N/A" and quota_id != "N/A":
|
| 737 |
break
|
| 738 |
|
| 739 |
+
await self.usage_manager.record_failure(current_cred, model, classified_error)
|
| 740 |
|
| 741 |
if consecutive_quota_failures >= 3:
|
| 742 |
console_log_message = (
|
| 743 |
+
f"Terminating stream for credential ...{current_cred[-6:]} due to 3rd consecutive quota error. "
|
| 744 |
f"This is now considered a fatal input data error. ID: {quota_id}, Limit: {quota_value}."
|
| 745 |
)
|
| 746 |
client_error_message = (
|
|
|
|
| 756 |
|
| 757 |
else:
|
| 758 |
# [MODIFIED] Do not yield to the client. Just log and break to rotate the key.
|
| 759 |
+
lib_logger.warning(f"Quota error on credential ...{current_cred[-6:]} (failure {consecutive_quota_failures}/3). Rotating key silently.")
|
| 760 |
break
|
| 761 |
|
| 762 |
else:
|
| 763 |
consecutive_quota_failures = 0
|
| 764 |
# [MODIFIED] Do not yield to the client. Just log and break to rotate the key.
|
| 765 |
+
lib_logger.warning(f"Credential ...{current_cred[-6:]} encountered a recoverable error ({classified_error.error_type}) during stream. Rotating key silently.")
|
| 766 |
|
| 767 |
if classified_error.error_type == 'rate_limit' and classified_error.status_code == 429:
|
| 768 |
cooldown_duration = classified_error.retry_after or 60
|
| 769 |
await self.cooldown_manager.start_cooldown(provider, cooldown_duration)
|
| 770 |
lib_logger.warning(f"IP-based rate limit detected for {provider}. Starting a {cooldown_duration}-second global cooldown.")
|
| 771 |
|
| 772 |
+
await self.usage_manager.record_failure(current_cred, model, classified_error)
|
| 773 |
break
|
| 774 |
|
| 775 |
except (APIConnectionError, litellm.InternalServerError, litellm.ServiceUnavailableError) as e:
|
| 776 |
consecutive_quota_failures = 0
|
| 777 |
last_exception = e
|
| 778 |
+
log_failure(api_key=current_cred, model=model, attempt=attempt + 1, error=e, request_headers=dict(request.headers) if request else {})
|
| 779 |
classified_error = classify_error(e)
|
| 780 |
+
await self.usage_manager.record_failure(current_cred, model, classified_error)
|
| 781 |
|
| 782 |
if attempt >= self.max_retries - 1:
|
| 783 |
+
lib_logger.warning(f"Credential ...{current_cred[-6:]} failed after max retries for model {model} due to a server error. Rotating key silently.")
|
| 784 |
# [MODIFIED] Do not yield to the client here.
|
| 785 |
break
|
| 786 |
|
|
|
|
| 791 |
break
|
| 792 |
|
| 793 |
error_message = str(e).split('\n')[0]
|
| 794 |
+
lib_logger.warning(f"Credential ...{current_cred[-6:]} encountered a server error for model {model}. Reason: '{error_message}'. Retrying in {wait_time:.2f}s.")
|
| 795 |
await asyncio.sleep(wait_time)
|
| 796 |
continue
|
| 797 |
|
| 798 |
except Exception as e:
|
| 799 |
consecutive_quota_failures = 0
|
| 800 |
last_exception = e
|
| 801 |
+
log_failure(api_key=current_cred, model=model, attempt=attempt + 1, error=e, request_headers=dict(request.headers) if request else {})
|
| 802 |
classified_error = classify_error(e)
|
| 803 |
|
| 804 |
+
lib_logger.warning(f"Credential ...{current_cred[-6:]} failed with {classified_error.error_type} (Status: {classified_error.status_code}). Error: {str(e)}. Rotating key.")
|
| 805 |
|
| 806 |
if classified_error.status_code == 429:
|
| 807 |
cooldown_duration = classified_error.retry_after or 60
|
|
|
|
| 812 |
raise last_exception
|
| 813 |
|
| 814 |
# [MODIFIED] Do not yield to the client here.
|
| 815 |
+
await self.usage_manager.record_failure(current_cred, model, classified_error)
|
| 816 |
break
|
| 817 |
|
| 818 |
finally:
|
| 819 |
+
if key_acquired and current_cred:
|
| 820 |
+
await self.usage_manager.release_key(current_cred, model)
|
| 821 |
|
| 822 |
final_error_message = "Failed to complete the streaming request: No available API keys after rotation or global timeout exceeded."
|
| 823 |
if last_exception:
|
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import shutil
|
| 3 |
+
import logging
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Dict, List, Optional
|
| 6 |
+
|
| 7 |
+
lib_logger = logging.getLogger('rotator_library')
|
| 8 |
+
|
| 9 |
+
# Local directory where discovered credential files are copied. Created eagerly
# at import time so later copy operations can assume it exists.
OAUTH_BASE_DIR = Path.cwd() / "oauth_creds"
OAUTH_BASE_DIR.mkdir(exist_ok=True)

# Standard paths where tools like `gemini login` store credentials.
# Used as fallbacks when a provider's configured path is empty.
DEFAULT_OAUTH_PATHS = {
    "gemini": Path.home() / ".gemini" / "oauth_creds.json",
    "qwen": Path.home() / ".qwen" / "oauth_creds.json",
    # Add other providers like 'claude' here if they have a standard CLI path
}
|
| 18 |
+
|
| 19 |
+
class CredentialManager:
    """
    Discovers OAuth credential files from standard locations, copies them locally,
    and updates the configuration to use the local paths.
    """

    def __init__(self, oauth_config: Dict[str, List[str]]):
        # Mapping of provider name -> list of credential file paths
        # (entries may be empty strings, meaning "use the default location").
        self.oauth_config = oauth_config

    def discover_and_prepare(self) -> Dict[str, List[str]]:
        """
        Process the initial OAuth config and return a config of local paths.

        For each configured entry, resolve the source file (falling back to the
        provider's default location when the entry is empty), copy it into
        OAUTH_BASE_DIR if a local copy does not already exist, and collect the
        resolved local path. Providers with no usable files are omitted.
        """
        prepared: Dict[str, List[str]] = {}
        for provider, configured_paths in self.oauth_config.items():
            local_paths: List[str] = []
            for account_id, raw_path in enumerate(configured_paths, start=1):
                source = self._resolve_source_path(provider, raw_path)
                if source is None or not source.exists():
                    lib_logger.warning(f"Could not find OAuth file for {provider} account #{account_id}. Skipping.")
                    continue

                destination = OAUTH_BASE_DIR / f"{provider}_oauth_{account_id}.json"
                if not destination.exists():
                    try:
                        shutil.copy(source, destination)
                        lib_logger.info(f"Copied '{source}' to local credentials at '{destination}'.")
                    except Exception as e:
                        lib_logger.error(f"Failed to copy OAuth file for {provider} account #{account_id}: {e}")
                        continue

                local_paths.append(str(destination.resolve()))

            if local_paths:
                prepared[provider] = local_paths

        return prepared

    def _resolve_source_path(self, provider: str, specified_path: Optional[str]) -> Optional[Path]:
        """Determines the source path for a credential file."""
        # A non-empty explicit path always wins; otherwise fall back to the
        # provider's well-known default location (None when unknown).
        if specified_path:
            return Path(specified_path).expanduser()
        return DEFAULT_OAUTH_PATHS.get(provider)
|
|
@@ -1,5 +1,7 @@
|
|
| 1 |
import re
|
|
|
|
| 2 |
from typing import Optional, Dict, Any
|
|
|
|
| 3 |
|
| 4 |
from litellm.exceptions import APIConnectionError, RateLimitError, ServiceUnavailableError, AuthenticationError, InvalidRequestError, BadRequestError, OpenAIError, InternalServerError, Timeout, ContextWindowExceededError
|
| 5 |
|
|
@@ -22,8 +24,6 @@ class ClassifiedError:
|
|
| 22 |
def __str__(self):
|
| 23 |
return f"ClassifiedError(type={self.error_type}, status={self.status_code}, retry_after={self.retry_after}, original_exc={self.original_exception})"
|
| 24 |
|
| 25 |
-
import json
|
| 26 |
-
|
| 27 |
def get_retry_after(error: Exception) -> Optional[int]:
|
| 28 |
"""
|
| 29 |
Extracts the 'retry-after' duration in seconds from an exception message.
|
|
@@ -80,9 +80,24 @@ def get_retry_after(error: Exception) -> Optional[int]:
|
|
| 80 |
def classify_error(e: Exception) -> ClassifiedError:
|
| 81 |
"""
|
| 82 |
Classifies an exception into a structured ClassifiedError object.
|
|
|
|
| 83 |
"""
|
| 84 |
status_code = getattr(e, 'status_code', None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
|
|
|
|
|
|
|
|
|
|
| 86 |
if isinstance(e, PreRequestCallbackError):
|
| 87 |
return ClassifiedError(
|
| 88 |
error_type='pre_request_callback_error',
|
|
|
|
| 1 |
import re
|
| 2 |
+
import json
|
| 3 |
from typing import Optional, Dict, Any
|
| 4 |
+
import httpx
|
| 5 |
|
| 6 |
from litellm.exceptions import APIConnectionError, RateLimitError, ServiceUnavailableError, AuthenticationError, InvalidRequestError, BadRequestError, OpenAIError, InternalServerError, Timeout, ContextWindowExceededError
|
| 7 |
|
|
|
|
| 24 |
def __str__(self):
|
| 25 |
return f"ClassifiedError(type={self.error_type}, status={self.status_code}, retry_after={self.retry_after}, original_exc={self.original_exception})"
|
| 26 |
|
|
|
|
|
|
|
| 27 |
def get_retry_after(error: Exception) -> Optional[int]:
|
| 28 |
"""
|
| 29 |
Extracts the 'retry-after' duration in seconds from an exception message.
|
|
|
|
| 80 |
def classify_error(e: Exception) -> ClassifiedError:
|
| 81 |
"""
|
| 82 |
Classifies an exception into a structured ClassifiedError object.
|
| 83 |
+
Now handles both litellm and httpx exceptions.
|
| 84 |
"""
|
| 85 |
status_code = getattr(e, 'status_code', None)
|
| 86 |
+
if isinstance(e, httpx.HTTPStatusError): # [NEW] Handle httpx errors first
|
| 87 |
+
status_code = e.response.status_code
|
| 88 |
+
if status_code == 401:
|
| 89 |
+
return ClassifiedError(error_type='authentication', original_exception=e, status_code=status_code)
|
| 90 |
+
if status_code == 429:
|
| 91 |
+
retry_after = get_retry_after(e)
|
| 92 |
+
return ClassifiedError(error_type='rate_limit', original_exception=e, status_code=status_code, retry_after=retry_after)
|
| 93 |
+
if 400 <= status_code < 500:
|
| 94 |
+
return ClassifiedError(error_type='invalid_request', original_exception=e, status_code=status_code)
|
| 95 |
+
if 500 <= status_code:
|
| 96 |
+
return ClassifiedError(error_type='server_error', original_exception=e, status_code=status_code)
|
| 97 |
|
| 98 |
+
if isinstance(e, (httpx.TimeoutException, httpx.ConnectError, httpx.NetworkError)): # [NEW]
|
| 99 |
+
return ClassifiedError(error_type='api_connection', original_exception=e, status_code=status_code)
|
| 100 |
+
|
| 101 |
if isinstance(e, PreRequestCallbackError):
|
| 102 |
return ClassifiedError(
|
| 103 |
error_type='pre_request_callback_error',
|
|
@@ -26,9 +26,9 @@ def _register_providers():
|
|
| 26 |
for attribute_name in dir(module):
|
| 27 |
attribute = getattr(module, attribute_name)
|
| 28 |
if isinstance(attribute, type) and issubclass(attribute, ProviderInterface) and attribute is not ProviderInterface:
|
| 29 |
-
#
|
| 30 |
-
provider_name = module_name.replace("_provider", "")
|
| 31 |
# Remap 'nvidia' to 'nvidia_nim' to align with litellm's provider name
|
|
|
|
| 32 |
if provider_name == "nvidia":
|
| 33 |
provider_name = "nvidia_nim"
|
| 34 |
PROVIDER_PLUGINS[provider_name] = attribute
|
|
|
|
| 26 |
for attribute_name in dir(module):
|
| 27 |
attribute = getattr(module, attribute_name)
|
| 28 |
if isinstance(attribute, type) and issubclass(attribute, ProviderInterface) and attribute is not ProviderInterface:
|
| 29 |
+
# Derives 'gemini_cli' from 'gemini_cli_provider.py'
|
|
|
|
| 30 |
# Remap 'nvidia' to 'nvidia_nim' to align with litellm's provider name
|
| 31 |
+
provider_name = module_name.replace("_provider", "")
|
| 32 |
if provider_name == "nvidia":
|
| 33 |
provider_name = "nvidia_nim"
|
| 34 |
PROVIDER_PLUGINS[provider_name] = attribute
|
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# src/rotator_library/providers/gemini_auth_base.py
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import time
|
| 5 |
+
import asyncio
|
| 6 |
+
import logging
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Dict, Any
|
| 9 |
+
|
| 10 |
+
import httpx
|
| 11 |
+
|
| 12 |
+
lib_logger = logging.getLogger('rotator_library')
|
| 13 |
+
|
| 14 |
+
# Fallback OAuth client credentials used when the credential file does not
# carry its own client_id/client_secret (see _refresh_token).
# NOTE(review): these appear to be the public "installed app" client values
# shipped with the gemini-cli tool — confirm before rotating/removing.
CLIENT_ID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
CLIENT_SECRET = "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"
# Google OAuth 2.0 token endpoint used for refresh-token grants.
TOKEN_URI = "https://oauth2.googleapis.com/token"
# Treat tokens as expired this many seconds before their actual expiry,
# so a refresh happens before the token stops working mid-request.
REFRESH_EXPIRY_BUFFER_SECONDS = 300
|
| 18 |
+
|
| 19 |
+
class GeminiAuthBase:
    """
    Shared OAuth plumbing for Gemini-style providers.

    Loads credential JSON files (gemini-cli or gcloud format), detects token
    expiry, refreshes access tokens against Google's token endpoint, and
    caches everything in memory with per-file asyncio locks.
    """

    def __init__(self):
        # Cache of parsed credential dicts, keyed by credential file path.
        self._credentials_cache: Dict[str, Dict[str, Any]] = {}
        # Per-path locks so concurrent tasks don't race on load/refresh.
        self._refresh_locks: Dict[str, asyncio.Lock] = {}

    async def _load_credentials(self, path: str) -> Dict[str, Any]:
        """Load and cache the credential JSON at `path`.

        Raises:
            IOError: if the file cannot be read or parsed.
        """
        if path in self._credentials_cache:
            return self._credentials_cache[path]

        async with self._get_lock(path):
            # Double-check after acquiring the lock: another task may have
            # populated the cache while we were waiting.
            if path in self._credentials_cache:
                return self._credentials_cache[path]
            try:
                with open(path, 'r') as f:
                    creds = json.load(f)
                # Handle gcloud-style creds file which nest tokens under "credential"
                if "credential" in creds:
                    creds = creds["credential"]
                self._credentials_cache[path] = creds
                return creds
            except Exception as e:
                raise IOError(f"Failed to load Gemini OAuth credentials from '{path}': {e}")

    async def _save_credentials(self, path: str, creds: Dict[str, Any]):
        """Update the in-memory cache and best-effort persist `creds` to disk."""
        self._credentials_cache[path] = creds
        try:
            with open(path, 'w') as f:
                json.dump(creds, f, indent=2)
        except Exception as e:
            # Persisting is best-effort: the refreshed token remains usable
            # from the in-memory cache even if the write fails.
            lib_logger.error(f"Failed to save updated Gemini OAuth credentials to '{path}': {e}")

    def _is_token_expired(self, creds: Dict[str, Any]) -> bool:
        """Return True if the token expires within REFRESH_EXPIRY_BUFFER_SECONDS.

        Supports both the gemini-cli format (`expiry_date`, epoch milliseconds)
        and the gcloud format (`token_expiry`, ISO-8601 UTC string ending 'Z').
        """
        import calendar  # stdlib; local import keeps this fix self-contained

        expiry = creds.get("token_expiry")  # gcloud format
        if not expiry:  # gemini-cli format
            expiry_timestamp = creds.get("expiry_date", 0) / 1000
        else:
            # BUGFIX: the timestring is UTC ('...Z'), so it must be converted
            # with calendar.timegm. The previous time.mktime call interpreted
            # the struct_time as *local* time, skewing expiry detection by the
            # local UTC offset (tokens could be refreshed far too late/early).
            expiry_timestamp = calendar.timegm(time.strptime(expiry, "%Y-%m-%dT%H:%M:%SZ"))

        return expiry_timestamp < time.time() + REFRESH_EXPIRY_BUFFER_SECONDS

    async def _refresh_token(self, path: str, creds: Dict[str, Any]) -> Dict[str, Any]:
        """Exchange the refresh_token for a new access token and persist it.

        Raises:
            ValueError: if the credentials have no refresh_token.
            httpx.HTTPStatusError: if the token endpoint rejects the request.
        """
        async with self._get_lock(path):
            # Another task may have refreshed while we waited for the lock.
            if not self._is_token_expired(self._credentials_cache.get(path, creds)):
                return self._credentials_cache.get(path, creds)

            lib_logger.info(f"Refreshing Gemini OAuth token for '{Path(path).name}'...")
            refresh_token = creds.get("refresh_token")
            if not refresh_token:
                raise ValueError("No refresh_token found in credentials file.")

            async with httpx.AsyncClient() as client:
                response = await client.post(TOKEN_URI, data={
                    # Credential files may carry their own client; fall back to
                    # the module-level defaults otherwise.
                    "client_id": creds.get("client_id", CLIENT_ID),
                    "client_secret": creds.get("client_secret", CLIENT_SECRET),
                    "refresh_token": refresh_token,
                    "grant_type": "refresh_token",
                })
                response.raise_for_status()
                new_token_data = response.json()

            creds["access_token"] = new_token_data["access_token"]
            expiry_timestamp = time.time() + new_token_data["expires_in"]
            # Write the new expiry in both supported formats so the file stays
            # readable by either tool family.
            creds["expiry_date"] = expiry_timestamp * 1000  # gemini-cli format
            creds["token_expiry"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(expiry_timestamp))  # gcloud format

            await self._save_credentials(path, creds)
            lib_logger.info(f"Successfully refreshed Gemini OAuth token for '{Path(path).name}'.")
            return creds

    async def get_auth_header(self, credential_path: str) -> Dict[str, str]:
        """Return a Bearer Authorization header, refreshing the token if needed."""
        creds = await self._load_credentials(credential_path)
        if self._is_token_expired(creds):
            creds = await self._refresh_token(credential_path, creds)
        return {"Authorization": f"Bearer {creds['access_token']}"}

    async def proactively_refresh(self, credential_path: str):
        """Refresh the token ahead of demand (used by the background refresher)."""
        creds = await self._load_credentials(credential_path)
        if self._is_token_expired(creds):
            await self._refresh_token(credential_path, creds)

    def _get_lock(self, path: str) -> asyncio.Lock:
        """Return the (lazily created) lock guarding `path`."""
        if path not in self._refresh_locks:
            self._refresh_locks[path] = asyncio.Lock()
        return self._refresh_locks[path]
|
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# src/rotator_library/providers/gemini_cli_provider.py
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import httpx
|
| 5 |
+
import logging
|
| 6 |
+
import time
|
| 7 |
+
from typing import List, Dict, Any, AsyncGenerator, Union, Optional
|
| 8 |
+
from .provider_interface import ProviderInterface
|
| 9 |
+
from .gemini_auth_base import GeminiAuthBase
|
| 10 |
+
import litellm
|
| 11 |
+
import os
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
|
| 14 |
+
lib_logger = logging.getLogger('rotator_library')
|
| 15 |
+
|
| 16 |
+
CODE_ASSIST_ENDPOINT = "https://cloudcode-pa.googleapis.com/v1internal"
|
| 17 |
+
|
| 18 |
+
class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
|
| 19 |
+
    def __init__(self):
        super().__init__()
        # Cached Google Cloud project ID; resolved lazily by _discover_project_id.
        self.project_id: Optional[str] = None
|
| 22 |
+
|
| 23 |
+
async def _discover_project_id(self, litellm_params: Dict[str, Any]) -> str:
|
| 24 |
+
"""Discovers the Google Cloud Project ID."""
|
| 25 |
+
if self.project_id:
|
| 26 |
+
return self.project_id
|
| 27 |
+
|
| 28 |
+
# 1. Prioritize explicitly configured project_id
|
| 29 |
+
if litellm_params.get("project_id"):
|
| 30 |
+
self.project_id = litellm_params["project_id"]
|
| 31 |
+
lib_logger.info(f"Using configured Gemini CLI project ID: {self.project_id}")
|
| 32 |
+
return self.project_id
|
| 33 |
+
|
| 34 |
+
# 2. Fallback: Look for .env file in the standard .gemini directory
|
| 35 |
+
try:
|
| 36 |
+
gemini_env_path = Path.home() / ".gemini" / ".env"
|
| 37 |
+
if gemini_env_path.exists():
|
| 38 |
+
with open(gemini_env_path, 'r') as f:
|
| 39 |
+
for line in f:
|
| 40 |
+
if line.startswith("GOOGLE_CLOUD_PROJECT="):
|
| 41 |
+
self.project_id = line.strip().split("=")[1]
|
| 42 |
+
lib_logger.info(f"Discovered Gemini CLI project ID from ~/.gemini/.env: {self.project_id}")
|
| 43 |
+
return self.project_id
|
| 44 |
+
except Exception as e:
|
| 45 |
+
lib_logger.warning(f"Could not read project ID from ~/.gemini/.env: {e}")
|
| 46 |
+
|
| 47 |
+
raise ValueError(
|
| 48 |
+
"Gemini CLI project ID not found. Please set `GEMINI_CLI_PROJECT_ID` in your main .env file "
|
| 49 |
+
"or ensure it is present in `~/.gemini/.env`."
|
| 50 |
+
)
|
| 51 |
+
def has_custom_logic(self) -> bool:
|
| 52 |
+
return True
|
| 53 |
+
|
| 54 |
+
def _transform_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
| 55 |
+
# As seen in Kilo examples, system prompts are injected into the first user message.
|
| 56 |
+
gemini_contents = []
|
| 57 |
+
system_prompt = ""
|
| 58 |
+
if messages and messages[0].get('role') == 'system':
|
| 59 |
+
system_prompt = messages.pop(0).get('content', '')
|
| 60 |
+
|
| 61 |
+
for msg in messages:
|
| 62 |
+
role = "model" if msg.get("role") == "assistant" else "user"
|
| 63 |
+
content = msg.get("content", "")
|
| 64 |
+
if system_prompt and role == "user":
|
| 65 |
+
content = f"{system_prompt}\n\n{content}"
|
| 66 |
+
system_prompt = "" # Inject only once
|
| 67 |
+
gemini_contents.append({"role": role, "parts": [{"text": content}]})
|
| 68 |
+
return gemini_contents
|
| 69 |
+
|
| 70 |
+
def _convert_chunk_to_openai(self, chunk: Dict[str, Any], model_id: str) -> dict:
|
| 71 |
+
response_data = chunk.get('response', chunk)
|
| 72 |
+
candidate = response_data.get('candidates', [{}])[0]
|
| 73 |
+
|
| 74 |
+
delta = {}
|
| 75 |
+
finish_reason = None
|
| 76 |
+
|
| 77 |
+
# Correctly handle reasoning vs. content based on 'thought' flag from Kilo example
|
| 78 |
+
if 'content' in candidate and 'parts' in candidate['content']:
|
| 79 |
+
part = candidate['content']['parts'][0]
|
| 80 |
+
if part.get('text'):
|
| 81 |
+
if part.get('thought') is True:
|
| 82 |
+
# This is a reasoning/thinking step
|
| 83 |
+
delta['reasoning_content'] = part['text']
|
| 84 |
+
else:
|
| 85 |
+
delta['content'] = part['text']
|
| 86 |
+
|
| 87 |
+
raw_finish_reason = candidate.get('finishReason')
|
| 88 |
+
if raw_finish_reason:
|
| 89 |
+
mapping = {'STOP': 'stop', 'MAX_TOKENS': 'length', 'SAFETY': 'content_filter'}
|
| 90 |
+
finish_reason = mapping.get(raw_finish_reason, 'stop')
|
| 91 |
+
|
| 92 |
+
choice = {"index": 0, "delta": delta, "finish_reason": finish_reason}
|
| 93 |
+
|
| 94 |
+
openai_chunk = {
|
| 95 |
+
"choices": [choice], "model": model_id, "object": "chat.completion.chunk",
|
| 96 |
+
"id": f"chatcmpl-geminicli-{time.time()}", "created": int(time.time())
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
if 'usageMetadata' in response_data:
|
| 100 |
+
usage = response_data['usageMetadata']
|
| 101 |
+
openai_chunk["usage"] = {
|
| 102 |
+
"prompt_tokens": usage.get("promptTokenCount", 0),
|
| 103 |
+
"completion_tokens": usage.get("candidatesTokenCount", 0),
|
| 104 |
+
"total_tokens": usage.get("totalTokenCount", 0),
|
| 105 |
+
}
|
| 106 |
+
|
| 107 |
+
return openai_chunk
|
| 108 |
+
|
| 109 |
+
async def acompletion(self, client: httpx.AsyncClient, **kwargs) -> Union[litellm.ModelResponse, AsyncGenerator[litellm.ModelResponse, None]]:
|
| 110 |
+
model = kwargs["model"]
|
| 111 |
+
credential_path = kwargs.pop("credential_identifier")
|
| 112 |
+
auth_header = await self.get_auth_header(credential_path)
|
| 113 |
+
|
| 114 |
+
project_id = await self._discover_project_id(kwargs.get("litellm_params", {}))
|
| 115 |
+
|
| 116 |
+
# Handle :thinking suffix from Kilo example
|
| 117 |
+
model_name = model.split('/')[-1]
|
| 118 |
+
enable_thinking = model_name.endswith(':thinking')
|
| 119 |
+
if enable_thinking:
|
| 120 |
+
model_name = model_name.replace(':thinking', '')
|
| 121 |
+
|
| 122 |
+
gen_config = {
|
| 123 |
+
"temperature": kwargs.get("temperature", 0.7),
|
| 124 |
+
"maxOutputTokens": kwargs.get("max_tokens", 8192),
|
| 125 |
+
}
|
| 126 |
+
if enable_thinking:
|
| 127 |
+
gen_config["thinkingConfig"] = {"thinkingBudget": -1}
|
| 128 |
+
|
| 129 |
+
request_payload = {
|
| 130 |
+
"model": model_name,
|
| 131 |
+
"project": project_id,
|
| 132 |
+
"request": {
|
| 133 |
+
"contents": self._transform_messages(kwargs.get("messages", [])),
|
| 134 |
+
"generationConfig": gen_config,
|
| 135 |
+
},
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
url = f"{CODE_ASSIST_ENDPOINT}:streamGenerateContent"
|
| 139 |
+
|
| 140 |
+
async def stream_handler():
|
| 141 |
+
async with client.stream("POST", url, headers=auth_header, json=request_payload, params={"alt": "sse"}, timeout=600) as response:
|
| 142 |
+
response.raise_for_status()
|
| 143 |
+
async for line in response.aiter_lines():
|
| 144 |
+
if line.startswith('data: '):
|
| 145 |
+
data_str = line[6:]
|
| 146 |
+
if data_str == "[DONE]": break
|
| 147 |
+
try:
|
| 148 |
+
chunk = json.loads(data_str)
|
| 149 |
+
openai_chunk = self._convert_chunk_to_openai(chunk, model)
|
| 150 |
+
yield litellm.ModelResponse(**openai_chunk)
|
| 151 |
+
except json.JSONDecodeError:
|
| 152 |
+
lib_logger.warning(f"Could not decode JSON from Gemini CLI: {line}")
|
| 153 |
+
|
| 154 |
+
if kwargs.get("stream", False):
|
| 155 |
+
return stream_handler()
|
| 156 |
+
else:
|
| 157 |
+
# Accumulate stream for non-streaming response
|
| 158 |
+
chunks = [chunk async for chunk in stream_handler()]
|
| 159 |
+
return litellm.utils.stream_to_completion_response(chunks)
|
| 160 |
+
|
| 161 |
+
# [NEW] Hardcoded model list based on Kilo example
|
| 162 |
+
HARDCODED_MODELS = [
|
| 163 |
+
"gemini-2.5-pro",
|
| 164 |
+
"gemini-2.5-flash",
|
| 165 |
+
"gemini-2.5-flash-lite"
|
| 166 |
+
]
|
| 167 |
+
# Use the shared GeminiAuthBase for auth logic
|
| 168 |
+
# get_models is not applicable for this custom provider
|
| 169 |
+
async def get_models(self, api_key: str, client: httpx.AsyncClient) -> List[str]:
|
| 170 |
+
"""Returns a hardcoded list of known compatible Gemini CLI models."""
|
| 171 |
+
return [f"gemini_cli/{model_id}" for model_id in HARDCODED_MODELS]
|
|
@@ -1,13 +1,14 @@
|
|
| 1 |
from abc import ABC, abstractmethod
|
| 2 |
-
from typing import List, Dict, Any
|
| 3 |
import httpx
|
|
|
|
| 4 |
|
| 5 |
class ProviderInterface(ABC):
|
| 6 |
"""
|
| 7 |
-
An interface for API provider-specific functionality,
|
| 8 |
-
|
| 9 |
"""
|
| 10 |
-
|
| 11 |
@abstractmethod
|
| 12 |
async def get_models(self, api_key: str, client: httpx.AsyncClient) -> List[str]:
|
| 13 |
"""
|
|
@@ -22,7 +23,25 @@ class ProviderInterface(ABC):
|
|
| 22 |
"""
|
| 23 |
pass
|
| 24 |
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
"""
|
| 27 |
Converts a generic safety settings dictionary to the provider-specific format.
|
| 28 |
|
|
@@ -33,3 +52,17 @@ class ProviderInterface(ABC):
|
|
| 33 |
A list of provider-specific safety setting objects or None.
|
| 34 |
"""
|
| 35 |
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from abc import ABC, abstractmethod
|
| 2 |
+
from typing import List, Dict, Any, Optional, AsyncGenerator, Union
|
| 3 |
import httpx
|
| 4 |
+
import litellm
|
| 5 |
|
| 6 |
class ProviderInterface(ABC):
|
| 7 |
"""
|
| 8 |
+
An interface for API provider-specific functionality, including model
|
| 9 |
+
discovery and custom API call handling for non-standard providers.
|
| 10 |
"""
|
| 11 |
+
|
| 12 |
@abstractmethod
|
| 13 |
async def get_models(self, api_key: str, client: httpx.AsyncClient) -> List[str]:
|
| 14 |
"""
|
|
|
|
| 23 |
"""
|
| 24 |
pass
|
| 25 |
|
| 26 |
+
# [NEW] Add methods for providers that need to bypass litellm
|
| 27 |
+
def has_custom_logic(self) -> bool:
|
| 28 |
+
"""
|
| 29 |
+
Returns True if the provider implements its own acompletion/aembedding logic,
|
| 30 |
+
bypassing the standard litellm call.
|
| 31 |
+
"""
|
| 32 |
+
return False
|
| 33 |
+
|
| 34 |
+
async def acompletion(self, client: httpx.AsyncClient, **kwargs) -> Union[litellm.ModelResponse, AsyncGenerator[litellm.ModelResponse, None]]:
|
| 35 |
+
"""
|
| 36 |
+
Handles the entire completion call for non-standard providers.
|
| 37 |
+
"""
|
| 38 |
+
raise NotImplementedError(f"{self.__class__.__name__} does not implement custom acompletion.")
|
| 39 |
+
|
| 40 |
+
async def aembedding(self, client: httpx.AsyncClient, **kwargs) -> litellm.EmbeddingResponse:
|
| 41 |
+
"""Handles the entire embedding call for non-standard providers."""
|
| 42 |
+
raise NotImplementedError(f"{self.__class__.__name__} does not implement custom aembedding.")
|
| 43 |
+
|
| 44 |
+
def convert_safety_settings(self, settings: Dict[str, str]) -> Optional[List[Dict[str, Any]]]:
|
| 45 |
"""
|
| 46 |
Converts a generic safety settings dictionary to the provider-specific format.
|
| 47 |
|
|
|
|
| 52 |
A list of provider-specific safety setting objects or None.
|
| 53 |
"""
|
| 54 |
return None
|
| 55 |
+
|
| 56 |
+
# [NEW] Add new methods for OAuth providers
|
| 57 |
+
async def get_auth_header(self, credential_identifier: str) -> Dict[str, str]:
|
| 58 |
+
"""
|
| 59 |
+
For OAuth providers, this method returns the Authorization header.
|
| 60 |
+
For API key providers, this can be a no-op or raise NotImplementedError.
|
| 61 |
+
"""
|
| 62 |
+
raise NotImplementedError("This provider does not support OAuth.")
|
| 63 |
+
|
| 64 |
+
async def proactively_refresh(self, credential_path: str):
|
| 65 |
+
"""
|
| 66 |
+
Proactively refreshes a token if it's nearing expiry.
|
| 67 |
+
"""
|
| 68 |
+
pass
|
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# src/rotator_library/providers/qwen_auth_base.py
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import time
|
| 5 |
+
import asyncio
|
| 6 |
+
import logging
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Dict, Any, Tuple
|
| 9 |
+
|
| 10 |
+
import httpx
|
| 11 |
+
|
| 12 |
+
lib_logger = logging.getLogger('rotator_library')
|
| 13 |
+
|
| 14 |
+
CLIENT_ID = "f0304373b74a44d2b584a3fb70ca9e56"
|
| 15 |
+
TOKEN_ENDPOINT = "https://chat.qwen.ai/api/v1/oauth2/token"
|
| 16 |
+
REFRESH_EXPIRY_BUFFER_SECONDS = 300
|
| 17 |
+
|
| 18 |
+
class QwenAuthBase:
    """Shared OAuth token management for Qwen credential files.

    Caches parsed credential files in memory, refreshes access tokens via the
    Qwen OAuth2 endpoint when they near expiry, and serializes loads/refreshes
    per credential file with an asyncio.Lock.
    """

    def __init__(self):
        # path -> parsed credential dict (in-memory mirror of the JSON file)
        self._credentials_cache: Dict[str, Dict[str, Any]] = {}
        # path -> lock serializing load/refresh for that file
        self._refresh_locks: Dict[str, asyncio.Lock] = {}

    def _get_lock(self, path: str) -> asyncio.Lock:
        """Return (creating on first use) the per-path refresh lock."""
        if path not in self._refresh_locks:
            self._refresh_locks[path] = asyncio.Lock()
        return self._refresh_locks[path]

    def _read_credentials_file(self, path: str) -> Dict[str, Any]:
        """Read and parse the credential JSON file. No locking, no caching."""
        try:
            with open(path, 'r') as f:
                return json.load(f)
        except Exception as e:
            raise IOError(f"Failed to load Qwen OAuth credentials from '{path}': {e}")

    async def _load_credentials(self, path: str) -> Dict[str, Any]:
        """Return cached credentials for *path*, loading from disk on first use."""
        if path in self._credentials_cache:
            return self._credentials_cache[path]

        async with self._get_lock(path):
            # Double-check after acquiring the lock: another task may have
            # loaded the file while we waited.
            if path in self._credentials_cache:
                return self._credentials_cache[path]
            creds = self._read_credentials_file(path)
            self._credentials_cache[path] = creds
            return creds

    async def _save_credentials(self, path: str, creds: Dict[str, Any]):
        """Update the cache and best-effort persist *creds* back to disk."""
        self._credentials_cache[path] = creds
        try:
            with open(path, 'w') as f:
                json.dump(creds, f, indent=2)
        except Exception as e:
            # Persisting is best-effort: the refreshed token remains usable
            # from the in-memory cache even if the write fails.
            lib_logger.error(f"Failed to save updated Qwen OAuth credentials to '{path}': {e}")

    def _is_token_expired(self, creds: Dict[str, Any]) -> bool:
        """True if the token expires within REFRESH_EXPIRY_BUFFER_SECONDS.

        `expiry_date` is stored in milliseconds since the epoch.
        """
        expiry_timestamp = creds.get("expiry_date", 0) / 1000
        return expiry_timestamp < time.time() + REFRESH_EXPIRY_BUFFER_SECONDS

    async def _refresh_token(self, path: str, force: bool = False) -> Dict[str, Any]:
        """Refresh the access token for *path* and return the updated credentials.

        Serialized per path; a caller that arrives after a concurrent refresh
        already succeeded returns the cached result instead of refreshing
        again (unless *force* is set).
        """
        async with self._get_lock(path):
            cached_creds = self._credentials_cache.get(path)
            if not force and cached_creds and not self._is_token_expired(cached_creds):
                return cached_creds

            # Read the file directly here: calling _load_credentials() would
            # try to re-acquire the same per-path lock we already hold, and
            # asyncio.Lock is NOT reentrant — a cache miss would deadlock.
            creds_from_file = cached_creds or self._read_credentials_file(path)
            self._credentials_cache[path] = creds_from_file

            lib_logger.info(f"Refreshing Qwen OAuth token for '{Path(path).name}'...")
            refresh_token = creds_from_file.get("refresh_token")
            if not refresh_token:
                raise ValueError("No refresh_token found in Qwen credentials file.")

            async with httpx.AsyncClient() as client:
                response = await client.post(TOKEN_ENDPOINT, data={
                    "grant_type": "refresh_token",
                    "refresh_token": refresh_token,
                    "client_id": CLIENT_ID,
                })
                response.raise_for_status()
                new_token_data = response.json()

            creds_from_file["access_token"] = new_token_data["access_token"]
            creds_from_file["refresh_token"] = new_token_data.get("refresh_token", creds_from_file["refresh_token"])
            # Store expiry in milliseconds, matching the on-disk format.
            creds_from_file["expiry_date"] = (time.time() + new_token_data["expires_in"]) * 1000

            await self._save_credentials(path, creds_from_file)
            lib_logger.info(f"Successfully refreshed Qwen OAuth token for '{Path(path).name}'.")
            return creds_from_file

    async def get_auth_header(self, credential_path: str) -> Dict[str, str]:
        """Return a Bearer Authorization header, refreshing the token if stale."""
        creds = await self._load_credentials(credential_path)
        if self._is_token_expired(creds):
            creds = await self._refresh_token(credential_path)
        return {"Authorization": f"Bearer {creds['access_token']}"}

    def get_api_details(self, credential_path: str) -> Tuple[str, str]:
        """Return (base_url, access_token) for already-loaded credentials.

        Raises KeyError if the path has not been loaded via _load_credentials
        or get_auth_header first.
        """
        creds = self._credentials_cache[credential_path]
        base_url = creds.get("resource_url", "https://dashscope.aliyuncs.com/compatible-mode/v1")
        if not base_url.startswith("http"):
            base_url = f"https://{base_url}"
        return base_url, creds["access_token"]

    async def proactively_refresh(self, credential_path: str):
        """Refresh the token ahead of time if it is within the expiry buffer."""
        creds = await self._load_credentials(credential_path)
        if self._is_token_expired(creds):
            await self._refresh_token(credential_path)
|
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# src/rotator_library/providers/qwen_code_provider.py

import copy
import logging
from typing import AsyncGenerator, List, Union

import httpx
import litellm

from .provider_interface import ProviderInterface
from .qwen_auth_base import QwenAuthBase
| 9 |
+
|
| 10 |
+
lib_logger = logging.getLogger('rotator_library')
|
| 11 |
+
|
| 12 |
+
# [NEW] Hardcoded model list based on Kilo example
|
| 13 |
+
HARDCODED_MODELS = [
|
| 14 |
+
"qwen3-coder-plus",
|
| 15 |
+
"qwen3-coder-flash"
|
| 16 |
+
]
|
| 17 |
+
|
| 18 |
+
class QwenCodeProvider(QwenAuthBase, ProviderInterface):
    """Custom provider for Qwen Code's OpenAI-compatible API.

    Authenticates with OAuth credentials (via QwenAuthBase), retries once
    with a forced token refresh on a 401, and rewrites streamed
    ``<think>`` segments into ``reasoning_content``.
    """

    def has_custom_logic(self) -> bool:
        return True  # We use custom logic to handle 401 retries and stream parsing

    async def get_models(self, api_key: str, client: httpx.AsyncClient) -> List[str]:
        """Returns a hardcoded list of known compatible Qwen models for the OpenAI-compatible API."""
        return [f"qwen_code/{model_id}" for model_id in HARDCODED_MODELS]

    async def _stream_parser(self, stream: AsyncGenerator, model_id: str) -> AsyncGenerator:
        """Parses the stream from litellm to handle Qwen's <think> tags.

        Text between <think> and </think> is re-emitted as
        ``reasoning_content``; everything else stays in ``content``. The open
        and close tags may arrive in different chunks, so the in-think state
        persists across chunks.
        NOTE(review): content containing a literal '||' would split
        incorrectly — confirm this cannot occur in practice.
        """
        in_think = False
        async for chunk in stream:
            content = chunk.choices[0].delta.content
            if content and ("<think>" in content or "</think>" in content):
                # After split("||") the tag markers are standalone parts
                # ("THINK" / "/THINK"), never prefixes of the text parts.
                parts = content.replace("<think>", "||THINK||").replace("</think>", "||/THINK||").split("||")
                for part in parts:
                    if not part:
                        continue
                    if part == "THINK":
                        in_think = True
                        continue
                    if part == "/THINK":
                        in_think = False
                        continue
                    # Deep copy: a shallow copy would share the delta object
                    # across emitted chunks, so mutating it for a later part
                    # would corrupt chunks already yielded downstream.
                    new_chunk = copy.deepcopy(chunk)
                    if in_think:
                        new_chunk.choices[0].delta.reasoning_content = part
                        new_chunk.choices[0].delta.content = None
                    else:
                        new_chunk.choices[0].delta.content = part
                        new_chunk.choices[0].delta.reasoning_content = None
                    yield new_chunk
            elif in_think and content:
                # A whole chunk of thinking text between tags that arrived in
                # earlier/later chunks.
                new_chunk = copy.deepcopy(chunk)
                new_chunk.choices[0].delta.reasoning_content = content
                new_chunk.choices[0].delta.content = None
                yield new_chunk
            else:
                yield chunk

    async def acompletion(self, client: httpx.AsyncClient, **kwargs) -> Union[litellm.ModelResponse, AsyncGenerator[litellm.ModelResponse, None]]:
        """Performs the completion via litellm against the Qwen endpoint.

        On a 401 AuthenticationError the OAuth token is force-refreshed and
        the call is retried exactly once.
        """
        credential_path = kwargs.pop("credential_identifier")
        model = kwargs["model"]

        async def do_call():
            # Re-read api_base/access_token each attempt: both may change
            # after a token refresh.
            api_base, access_token = self.get_api_details(credential_path)
            return await litellm.acompletion(
                **kwargs, api_key=access_token, api_base=api_base
            )

        try:
            response = await do_call()
        except litellm.AuthenticationError as e:
            if "401" in str(e):
                lib_logger.warning("Qwen Code returned 401. Forcing token refresh and retrying once.")
                await self._refresh_token(credential_path, force=True)
                response = await do_call()
            else:
                raise e

        if kwargs.get("stream"):
            return self._stream_parser(response, model)
        return response
|