Mirrowel committed on
Commit
d28e7c5
·
1 Parent(s): fc62d82

feat(provider): introduce OAuth credential management and custom provider handling

Browse files

- Implemented a new `CredentialManager` to discover and manage OAuth credential files from standard paths (`~/.gemini`, `~/.qwen`).
- Added a `BackgroundRefresher` to proactively refresh OAuth tokens before they expire, ensuring continuous service.
- Refactored `RotatingClient` to support both API keys and OAuth credentials for provider authentication.
- Integrated `litellm_provider_params` in `RotatingClient` to allow specific LiteLLM configurations per provider (e.g., Google Cloud project ID for Gemini CLI).
- Introduced a `has_custom_logic` flag and `acompletion` method in `ProviderInterface` to enable custom handling for providers like Gemini CLI and Qwen Code, which require specific request formats, authentication, or stream parsing not fully supported by LiteLLM's standard interface.
- Updated `proxy_app/main.py` to utilize the new OAuth credential loading, provider-specific LiteLLM parameters, and the background token refresher.
- Enhanced `error_handler.py` to classify `httpx` exceptions, improving error reporting and retry logic for network and HTTP errors.
- Added `.env.example` entries for configuring Gemini CLI project ID and Qwen/Gemini OAuth credential paths.

BREAKING CHANGE: The constructor for `RotatingClient` has been updated. It now requires an `oauth_credentials` dictionary (can be empty) and accepts an optional `litellm_provider_params` dictionary. Direct instantiations of `RotatingClient` must be updated to include these new arguments.

.env.example CHANGED
@@ -11,3 +11,15 @@ NVIDIA_NIM_API_KEY_2="YOUR_NVIDIA_NIM_API_KEY_2"
11
 
12
 # A secret key for your proxy server to authenticate requests (can be anything; used for compatibility)
13
  PROXY_API_KEY="YOUR_PROXY_API_KEY"
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
 # A secret key for your proxy server to authenticate requests (can be anything; used for compatibility)
13
  PROXY_API_KEY="YOUR_PROXY_API_KEY"
14
+
15
+
16
+ # --- OAuth Accounts ---
17
+ # The system will automatically discover standard paths if left blank.
18
+
19
+ # For Gemini CLI (uses a custom API)
20
+ GEMINI_CLI_OAUTH_1=
21
+ # Required for Gemini CLI: Your Google Cloud Project ID
22
+ GEMINI_CLI_PROJECT_ID="gen-lang-client-..."
23
+
24
+ # For Qwen Code (OpenAI Compatible)
25
+ QWEN_CODE_OAUTH_1=
src/proxy_app/main.py CHANGED
@@ -52,6 +52,8 @@ args, _ = parser.parse_known_args()
52
  sys.path.append(str(Path(__file__).resolve().parent.parent))
53
 
54
  from rotator_library import RotatingClient, PROVIDER_PLUGINS
 
 
55
  from proxy_app.request_logger import log_request_to_console
56
  from proxy_app.batch_manager import EmbeddingBatcher
57
  from proxy_app.detailed_logger import DetailedLogger
@@ -125,19 +127,28 @@ PROXY_API_KEY = os.getenv("PROXY_API_KEY")
125
  if not PROXY_API_KEY:
126
  raise ValueError("PROXY_API_KEY environment variable not set.")
127
 
128
- # Load all provider API keys from environment variables
129
  api_keys = {}
 
130
  for key, value in os.environ.items():
131
- # Exclude PROXY_API_KEY from being treated as a provider API key
132
- if (key.endswith("_API_KEY") or "_API_KEY_" in key) and key != "PROXY_API_KEY":
133
- parts = key.split("_API_KEY")
134
- provider = parts[0].lower()
 
 
 
 
 
 
 
 
135
  if provider not in api_keys:
136
  api_keys[provider] = []
137
  api_keys[provider].append(value)
138
 
139
- if not api_keys:
140
- raise ValueError("No provider API keys found in environment variables.")
141
 
142
  # Load model ignore lists from environment variables
143
  ignore_models = {}
@@ -152,8 +163,20 @@ for key, value in os.environ.items():
152
  @asynccontextmanager
153
  async def lifespan(app: FastAPI):
154
  """Manage the RotatingClient's lifecycle with the app's lifespan."""
 
 
 
 
 
155
  # The client now uses the root logger configuration
156
- client = RotatingClient(api_keys=api_keys, configure_logging=True, ignore_models=ignore_models)
 
 
 
 
 
 
 
157
  app.state.rotating_client = client
158
  os.environ["LITELLM_LOG"] = "ERROR"
159
  litellm.set_verbose = False
@@ -168,6 +191,7 @@ async def lifespan(app: FastAPI):
168
 
169
  yield
170
 
 
171
  if app.state.embedding_batcher:
172
  await app.state.embedding_batcher.stop()
173
  await client.close()
@@ -477,20 +501,6 @@ async def embeddings(
477
 
478
  response = await client.aembedding(request=request, **request_data)
479
 
480
- if ENABLE_REQUEST_LOGGING:
481
- response_summary = {
482
- "model": response.model,
483
- "object": response.object,
484
- "usage": response.usage.model_dump(),
485
- "data_count": len(response.data),
486
- "embedding_dimensions": len(response.data[0].embedding) if response.data else 0
487
- }
488
- log_request_response(
489
- request_data=body.model_dump(exclude_none=True),
490
- response_data=response_summary,
491
- is_streaming=False,
492
- log_type="embedding"
493
- )
494
  return response
495
 
496
  except HTTPException as e:
@@ -510,17 +520,6 @@ async def embeddings(
510
  raise HTTPException(status_code=502, detail=f"Bad Gateway: {str(e)}")
511
  except Exception as e:
512
  logging.error(f"Embedding request failed: {e}")
513
- if ENABLE_REQUEST_LOGGING:
514
- try:
515
- request_data = await request.json()
516
- except json.JSONDecodeError:
517
- request_data = {"error": "Could not parse request body"}
518
- log_request_response(
519
- request_data=request_data,
520
- response_data={"error": str(e)},
521
- is_streaming=False,
522
- log_type="embedding"
523
- )
524
  raise HTTPException(status_code=500, detail=str(e))
525
 
526
  @app.get("/")
 
52
  sys.path.append(str(Path(__file__).resolve().parent.parent))
53
 
54
  from rotator_library import RotatingClient, PROVIDER_PLUGINS
55
+ from rotator_library.credential_manager import CredentialManager
56
+ from rotator_library.background_refresher import BackgroundRefresher
57
  from proxy_app.request_logger import log_request_to_console
58
  from proxy_app.batch_manager import EmbeddingBatcher
59
  from proxy_app.detailed_logger import DetailedLogger
 
127
  if not PROXY_API_KEY:
128
  raise ValueError("PROXY_API_KEY environment variable not set.")
129
 
130
+ # Split API keys and OAuth config loading
131
  api_keys = {}
132
+ oauth_credentials = {}
133
  for key, value in os.environ.items():
134
+ if key == "PROXY_API_KEY":
135
+ continue
136
+
137
+ # Handles GEMINI_CLI_OAUTH_1, QWEN_CODE_OAUTH_1, etc.
138
+ if "_OAUTH_" in key:
139
+ provider = key.split("_OAUTH_")[0].lower()
140
+ if provider not in oauth_credentials:
141
+ oauth_credentials[provider] = []
142
+ oauth_credentials[provider].append(value)
143
+ # Handles GEMINI_API_KEY_1, etc.
144
+ elif "_API_KEY" in key:
145
+ provider = key.split("_API_KEY")[0].lower()
146
  if provider not in api_keys:
147
  api_keys[provider] = []
148
  api_keys[provider].append(value)
149
 
150
+ if not api_keys and not oauth_credentials:
151
+ raise ValueError("No provider API keys or OAuth credentials found in environment variables.")
152
 
153
  # Load model ignore lists from environment variables
154
  ignore_models = {}
 
163
  @asynccontextmanager
164
  async def lifespan(app: FastAPI):
165
  """Manage the RotatingClient's lifecycle with the app's lifespan."""
166
+ # [NEW] Load provider-specific params
167
+ litellm_provider_params = {
168
+ "gemini_cli": {"project_id": os.getenv("GEMINI_CLI_PROJECT_ID")}
169
+ }
170
+
171
  # The client now uses the root logger configuration
172
+ client = RotatingClient(
173
+ api_keys=api_keys,
174
+ oauth_credentials=oauth_credentials, # Pass OAuth config
175
+ configure_logging=True,
176
+ litellm_provider_params=litellm_provider_params, # [NEW]
177
+ ignore_models=ignore_models
178
+ )
179
+ client.background_refresher.start() # Start the background task
180
  app.state.rotating_client = client
181
  os.environ["LITELLM_LOG"] = "ERROR"
182
  litellm.set_verbose = False
 
191
 
192
  yield
193
 
194
+ await client.background_refresher.stop() # Stop the background task on shutdown
195
  if app.state.embedding_batcher:
196
  await app.state.embedding_batcher.stop()
197
  await client.close()
 
501
 
502
  response = await client.aembedding(request=request, **request_data)
503
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
504
  return response
505
 
506
  except HTTPException as e:
 
520
  raise HTTPException(status_code=502, detail=f"Bad Gateway: {str(e)}")
521
  except Exception as e:
522
  logging.error(f"Embedding request failed: {e}")
 
 
 
 
 
 
 
 
 
 
 
523
  raise HTTPException(status_code=500, detail=str(e))
524
 
525
  @app.get("/")
src/rotator_library/background_refresher.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/rotator_library/background_refresher.py
2
+
3
import asyncio
import logging
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    from .client import RotatingClient
9
+
10
+ lib_logger = logging.getLogger('rotator_library')
11
+
12
+ class BackgroundRefresher:
13
+ """
14
+ A background task that periodically checks and refreshes OAuth tokens
15
+ to ensure they remain valid.
16
+ """
17
+ def __init__(self, client: 'RotatingClient', interval_seconds: int = 300):
18
+ self._client = client
19
+ self._interval = interval_seconds
20
+ self._task: Optional[asyncio.Task] = None
21
+
22
+ def start(self):
23
+ """Starts the background refresh task."""
24
+ if self._task is None:
25
+ self._task = asyncio.create_task(self._run())
26
+ lib_logger.info(f"Background token refresher started. Check interval: {self._interval} seconds.")
27
+
28
+ async def stop(self):
29
+ """Stops the background refresh task."""
30
+ if self._task:
31
+ self._task.cancel()
32
+ try:
33
+ await self._task
34
+ except asyncio.CancelledError:
35
+ pass
36
+ lib_logger.info("Background token refresher stopped.")
37
+
38
+ async def _run(self):
39
+ """The main loop for the background task."""
40
+ while True:
41
+ try:
42
+ await asyncio.sleep(self._interval)
43
+ lib_logger.info("Running proactive token refresh check...")
44
+
45
+ oauth_configs = self._client.get_oauth_credentials()
46
+ for provider, paths in oauth_configs.items():
47
+ provider_plugin = self._client._get_provider_instance(f"{provider}_oauth")
48
+ if provider_plugin and hasattr(provider_plugin, 'proactively_refresh'):
49
+ for path in paths:
50
+ try:
51
+ await provider_plugin.proactively_refresh(path)
52
+ except Exception as e:
53
+ lib_logger.error(f"Error during proactive refresh for '{path}': {e}")
54
+ except asyncio.CancelledError:
55
+ break
56
+ except Exception as e:
57
+ lib_logger.error(f"Unexpected error in background refresher loop: {e}")
src/rotator_library/client.py CHANGED
@@ -24,6 +24,8 @@ from .error_handler import PreRequestCallbackError, classify_error, AllProviders
24
  from .providers import PROVIDER_PLUGINS
25
  from .request_sanitizer import sanitize_request_payload
26
  from .cooldown_manager import CooldownManager
 
 
27
 
28
  class StreamedAPIError(Exception):
29
  """Custom exception to signal an API error received over a stream."""
@@ -39,11 +41,13 @@ class RotatingClient:
39
  def __init__(
40
  self,
41
  api_keys: Dict[str, List[str]],
 
42
  max_retries: int = 2,
43
  usage_file_path: str = "key_usage.json",
44
  configure_logging: bool = True,
45
  global_timeout: int = 30,
46
  abort_on_callback_error: bool = True,
 
47
  ignore_models: Optional[Dict[str, List[str]]] = None
48
  ):
49
  os.environ["LITELLM_LOG"] = "ERROR"
@@ -63,6 +67,18 @@ class RotatingClient:
63
  if not api_keys:
64
  raise ValueError("API keys dictionary cannot be empty.")
65
  self.api_keys = api_keys
 
 
 
 
 
 
 
 
 
 
 
 
66
  self.max_retries = max_retries
67
  self.global_timeout = global_timeout
68
  self.abort_on_callback_error = abort_on_callback_error
@@ -73,6 +89,7 @@ class RotatingClient:
73
  self.http_client = httpx.AsyncClient()
74
  self.all_providers = AllProviders()
75
  self.cooldown_manager = CooldownManager()
 
76
  self.ignore_models = ignore_models or {}
77
 
78
  def _is_model_ignored(self, provider: str, model_id: str) -> bool:
@@ -191,6 +208,9 @@ class RotatingClient:
191
 
192
  return kwargs
193
 
 
 
 
194
  def _get_provider_instance(self, provider_name: str):
195
  """Lazily initializes and returns a provider instance."""
196
  if provider_name not in self._provider_instances:
@@ -338,8 +358,8 @@ class RotatingClient:
338
  raise ValueError("'model' is a required parameter.")
339
 
340
  provider = model.split('/')[0]
341
- if provider not in self.api_keys:
342
- raise ValueError(f"No API keys configured for provider: {provider}")
343
 
344
  # Establish a global deadline for the entire request lifecycle.
345
  deadline = time.time() + self.global_timeout
@@ -347,16 +367,16 @@ class RotatingClient:
347
  # Create a mutable copy of the keys and shuffle it to ensure
348
  # that the key selection is randomized, which is crucial when
349
  # multiple keys have the same usage stats.
350
- keys_for_provider = list(self.api_keys[provider])
351
- random.shuffle(keys_for_provider)
352
 
353
- tried_keys = set()
354
  last_exception = None
355
  kwargs = self._convert_model_params(**kwargs)
356
-
357
- # The main rotation loop. It continues as long as there are untried keys and the global deadline has not been exceeded.
358
- while len(tried_keys) < len(keys_for_provider) and time.time() < deadline:
359
- current_key = None
360
  key_acquired = False
361
  try:
362
  # Check for a provider-wide cooldown first.
@@ -372,129 +392,167 @@ class RotatingClient:
372
  lib_logger.warning(f"Provider {provider} is in cooldown. Waiting for {remaining_cooldown:.2f} seconds.")
373
  await asyncio.sleep(remaining_cooldown)
374
 
375
- keys_to_try = [k for k in keys_for_provider if k not in tried_keys]
376
- if not keys_to_try:
377
  break
378
 
379
- lib_logger.info(f"Acquiring key for model {model}. Tried keys: {len(tried_keys)}/{len(keys_for_provider)}")
380
- current_key = await self.usage_manager.acquire_key(
381
- available_keys=keys_to_try,
382
  model=model,
383
  deadline=deadline
384
  )
385
  key_acquired = True
386
- tried_keys.add(current_key)
387
 
388
  litellm_kwargs = self.all_providers.get_provider_kwargs(**kwargs.copy())
389
- provider_instance = self._get_provider_instance(provider)
390
- if provider_instance:
391
- if "safety_settings" in litellm_kwargs:
392
- converted_settings = provider_instance.convert_safety_settings(litellm_kwargs["safety_settings"])
393
- if converted_settings is not None:
394
- litellm_kwargs["safety_settings"] = converted_settings
395
- else:
396
- del litellm_kwargs["safety_settings"]
397
 
398
- if provider == "gemini" and provider_instance:
399
- provider_instance.handle_thinking_parameter(litellm_kwargs, model)
400
-
401
- if "gemma-3" in model and "messages" in litellm_kwargs:
402
- litellm_kwargs["messages"] = [{"role": "user", "content": m["content"]} if m.get("role") == "system" else m for m in litellm_kwargs["messages"]]
 
403
 
404
- litellm_kwargs = sanitize_request_payload(litellm_kwargs, model)
405
-
406
- for attempt in range(self.max_retries):
407
- try:
408
- lib_logger.info(f"Attempting call with key ...{current_key[-4:]} (Attempt {attempt + 1}/{self.max_retries})")
409
-
410
- if pre_request_callback:
411
- try:
412
- await pre_request_callback(request, litellm_kwargs)
413
- except Exception as e:
414
- if self.abort_on_callback_error:
415
- raise PreRequestCallbackError(f"Pre-request callback failed: {e}") from e
416
- else:
417
- lib_logger.warning(f"Pre-request callback failed but abort_on_callback_error is False. Proceeding with request. Error: {e}")
418
-
419
- response = await api_call(
420
- api_key=current_key,
421
- **litellm_kwargs,
422
- logger_fn=self._litellm_logger_callback
423
- )
424
-
425
- await self.usage_manager.record_success(current_key, model, response)
426
- await self.usage_manager.release_key(current_key, model)
427
  key_acquired = False
428
  return response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
429
 
430
- except litellm.RateLimitError as e:
431
- last_exception = e
432
- log_failure(api_key=current_key, model=model, attempt=attempt + 1, error=e, request_headers=dict(request.headers) if request else {})
433
- classified_error = classify_error(e)
434
-
435
- # Extract a clean error message for the user-facing log
436
- error_message = str(e).split('\n')[0]
437
- lib_logger.info(f"Key ...{current_key[-4:]} hit rate limit for model {model}. Reason: '{error_message}'. Rotating key.")
438
-
439
- if classified_error.status_code == 429:
440
- cooldown_duration = classified_error.retry_after or 60
441
- await self.cooldown_manager.start_cooldown(provider, cooldown_duration)
442
- lib_logger.warning(f"IP-based rate limit detected for {provider}. Starting a {cooldown_duration}-second global cooldown.")
443
-
444
- await self.usage_manager.record_failure(current_key, model, classified_error)
445
- lib_logger.warning(f"Key ...{current_key[-4:]} encountered a rate limit. Trying next key.")
446
- break # Move to the next key
447
-
448
- except (APIConnectionError, litellm.InternalServerError, litellm.ServiceUnavailableError) as e:
449
- last_exception = e
450
- log_failure(api_key=current_key, model=model, attempt=attempt + 1, error=e, request_headers=dict(request.headers) if request else {})
451
- classified_error = classify_error(e)
452
- await self.usage_manager.record_failure(current_key, model, classified_error)
453
-
454
- if attempt >= self.max_retries - 1:
 
 
 
 
 
 
 
 
 
455
  error_message = str(e).split('\n')[0]
456
- lib_logger.warning(f"Key ...{current_key[-4:]} failed after max retries for model {model} due to a server error. Reason: '{error_message}'. Rotating key.")
 
 
 
 
 
 
 
 
457
  break # Move to the next key
458
-
459
- # For temporary errors, wait before retrying with the same key.
460
- wait_time = classified_error.retry_after or (1 * (2 ** attempt)) + random.uniform(0, 1)
461
- remaining_budget = deadline - time.time()
462
-
463
- # If the required wait time exceeds the budget, don't wait; rotate to the next key immediately.
464
- if wait_time > remaining_budget:
465
- lib_logger.warning(f"Required retry wait time ({wait_time:.2f}s) exceeds remaining budget ({remaining_budget:.2f}s). Rotating key early.")
466
- break
467
 
468
- error_message = str(e).split('\n')[0]
469
- lib_logger.warning(f"Key ...{current_key[-4:]} encountered a server error for model {model}. Reason: '{error_message}'. Retrying in {wait_time:.2f}s.")
470
- await asyncio.sleep(wait_time)
471
- continue # Retry with the same key
 
472
 
473
- except Exception as e:
474
- last_exception = e
475
- log_failure(api_key=current_key, model=model, attempt=attempt + 1, error=e, request_headers=dict(request.headers) if request else {})
476
-
477
- if request and await request.is_disconnected():
478
- lib_logger.warning(f"Client disconnected. Aborting retries for key ...{current_key[-4:]}.")
479
- raise last_exception
480
-
481
- classified_error = classify_error(e)
482
- error_message = str(e).split('\n')[0]
483
- lib_logger.warning(f"Key ...{current_key[-4:]} failed with {classified_error.error_type} (Status: {classified_error.status_code}). Error: {error_message}. Rotating key.")
484
- if classified_error.status_code == 429:
485
- cooldown_duration = classified_error.retry_after or 60
486
- await self.cooldown_manager.start_cooldown(provider, cooldown_duration)
487
- lib_logger.warning(f"IP-based rate limit detected for {provider} from generic exception. Starting a {cooldown_duration}-second global cooldown.")
488
-
489
- if classified_error.error_type in ['invalid_request', 'context_window_exceeded', 'authentication']:
490
- # For these errors, we should not retry with other keys.
491
- raise last_exception
492
-
493
- await self.usage_manager.record_failure(current_key, model, classified_error)
494
- break # Try next key for other errors
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
495
  finally:
496
- if key_acquired and current_key:
497
- await self.usage_manager.release_key(current_key, model)
498
 
499
  if last_exception:
500
  # Log the final error but do not raise it, as per the new requirement.
@@ -510,19 +568,19 @@ class RotatingClient:
510
  provider = model.split('/')[0]
511
 
512
  # Create a mutable copy of the keys and shuffle it.
513
- keys_for_provider = list(self.api_keys[provider])
514
- random.shuffle(keys_for_provider)
515
 
516
  deadline = time.time() + self.global_timeout
517
- tried_keys = set()
518
  last_exception = None
519
  kwargs = self._convert_model_params(**kwargs)
520
 
521
  consecutive_quota_failures = 0
522
 
523
  try:
524
- while len(tried_keys) < len(keys_for_provider) and time.time() < deadline:
525
- current_key = None
526
  key_acquired = False
527
  try:
528
  if await self.cooldown_manager.is_cooling_down(provider):
@@ -534,21 +592,52 @@ class RotatingClient:
534
  lib_logger.warning(f"Provider {provider} is in a global cooldown. All requests to this provider will be paused for {remaining_cooldown:.2f} seconds.")
535
  await asyncio.sleep(remaining_cooldown)
536
 
537
- keys_to_try = [k for k in keys_for_provider if k not in tried_keys]
538
- if not keys_to_try:
539
- lib_logger.warning(f"All keys for provider {provider} have been tried. No more keys to rotate to.")
540
  break
541
 
542
- lib_logger.info(f"Acquiring key for model {model}. Tried keys: {len(tried_keys)}/{len(keys_for_provider)}")
543
- current_key = await self.usage_manager.acquire_key(
544
- available_keys=keys_to_try,
545
  model=model,
546
  deadline=deadline
547
  )
548
  key_acquired = True
549
- tried_keys.add(current_key)
550
 
551
  litellm_kwargs = self.all_providers.get_provider_kwargs(**kwargs.copy())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
552
  provider_instance = self._get_provider_instance(provider)
553
  if provider_instance:
554
  if "safety_settings" in litellm_kwargs:
@@ -568,7 +657,7 @@ class RotatingClient:
568
 
569
  for attempt in range(self.max_retries):
570
  try:
571
- lib_logger.info(f"Attempting stream with key ...{current_key[-4:]} (Attempt {attempt + 1}/{self.max_retries})")
572
 
573
  if pre_request_callback:
574
  try:
@@ -580,15 +669,14 @@ class RotatingClient:
580
  lib_logger.warning(f"Pre-request callback failed but abort_on_callback_error is False. Proceeding with request. Error: {e}")
581
 
582
  response = await litellm.acompletion(
583
- api_key=current_key,
584
  **litellm_kwargs,
585
  logger_fn=self._litellm_logger_callback
586
  )
587
 
588
- lib_logger.info(f"Stream connection established for key ...{current_key[-4:]}. Processing response.")
589
 
590
  key_acquired = False
591
- stream_generator = self._safe_streaming_wrapper(response, current_key, model, request)
592
 
593
  async for chunk in stream_generator:
594
  yield chunk
@@ -618,7 +706,7 @@ class RotatingClient:
618
 
619
  # Now, log the failure with the extracted raw response.
620
  log_failure(
621
- api_key=current_key,
622
  model=model,
623
  attempt=attempt + 1,
624
  error=e,
@@ -633,7 +721,7 @@ class RotatingClient:
633
 
634
  if "quota" in error_message_text.lower() or "resource_exhausted" in error_status.lower():
635
  consecutive_quota_failures += 1
636
- lib_logger.warning(f"Key ...{current_key[-4:]} hit a quota limit. This is consecutive failure #{consecutive_quota_failures} for this request.")
637
 
638
  quota_value = "N/A"
639
  quota_id = "N/A"
@@ -648,11 +736,11 @@ class RotatingClient:
648
  if quota_value != "N/A" and quota_id != "N/A":
649
  break
650
 
651
- await self.usage_manager.record_failure(current_key, model, classified_error)
652
 
653
  if consecutive_quota_failures >= 3:
654
  console_log_message = (
655
- f"Terminating stream for key ...{current_key[-4:]} due to 3rd consecutive quota error. "
656
  f"This is now considered a fatal input data error. ID: {quota_id}, Limit: {quota_value}."
657
  )
658
  client_error_message = (
@@ -668,31 +756,31 @@ class RotatingClient:
668
 
669
  else:
670
  # [MODIFIED] Do not yield to the client. Just log and break to rotate the key.
671
- lib_logger.warning(f"Quota error on key ...{current_key[-4:]} (failure {consecutive_quota_failures}/3). Rotating key silently.")
672
  break
673
 
674
  else:
675
  consecutive_quota_failures = 0
676
  # [MODIFIED] Do not yield to the client. Just log and break to rotate the key.
677
- lib_logger.warning(f"Key ...{current_key[-4:]} encountered a recoverable error ({classified_error.error_type}) during stream. Rotating key silently.")
678
 
679
  if classified_error.error_type == 'rate_limit' and classified_error.status_code == 429:
680
  cooldown_duration = classified_error.retry_after or 60
681
  await self.cooldown_manager.start_cooldown(provider, cooldown_duration)
682
  lib_logger.warning(f"IP-based rate limit detected for {provider}. Starting a {cooldown_duration}-second global cooldown.")
683
 
684
- await self.usage_manager.record_failure(current_key, model, classified_error)
685
  break
686
 
687
  except (APIConnectionError, litellm.InternalServerError, litellm.ServiceUnavailableError) as e:
688
  consecutive_quota_failures = 0
689
  last_exception = e
690
- log_failure(api_key=current_key, model=model, attempt=attempt + 1, error=e, request_headers=dict(request.headers) if request else {})
691
  classified_error = classify_error(e)
692
- await self.usage_manager.record_failure(current_key, model, classified_error)
693
 
694
  if attempt >= self.max_retries - 1:
695
- lib_logger.warning(f"Key ...{current_key[-4:]} failed after max retries for model {model} due to a server error. Rotating key silently.")
696
  # [MODIFIED] Do not yield to the client here.
697
  break
698
 
@@ -703,17 +791,17 @@ class RotatingClient:
703
  break
704
 
705
  error_message = str(e).split('\n')[0]
706
- lib_logger.warning(f"Key ...{current_key[-4:]} encountered a server error for model {model}. Reason: '{error_message}'. Retrying in {wait_time:.2f}s.")
707
  await asyncio.sleep(wait_time)
708
  continue
709
 
710
  except Exception as e:
711
  consecutive_quota_failures = 0
712
  last_exception = e
713
- log_failure(api_key=current_key, model=model, attempt=attempt + 1, error=e, request_headers=dict(request.headers) if request else {})
714
  classified_error = classify_error(e)
715
 
716
- lib_logger.warning(f"Key ...{current_key[-4:]} failed with {classified_error.error_type} (Status: {classified_error.status_code}). Error: {str(e)}. Rotating key.")
717
 
718
  if classified_error.status_code == 429:
719
  cooldown_duration = classified_error.retry_after or 60
@@ -724,12 +812,12 @@ class RotatingClient:
724
  raise last_exception
725
 
726
  # [MODIFIED] Do not yield to the client here.
727
- await self.usage_manager.record_failure(current_key, model, classified_error)
728
  break
729
 
730
  finally:
731
- if key_acquired and current_key:
732
- await self.usage_manager.release_key(current_key, model)
733
 
734
  final_error_message = "Failed to complete the streaming request: No available API keys after rotation or global timeout exceeded."
735
  if last_exception:
 
24
  from .providers import PROVIDER_PLUGINS
25
  from .request_sanitizer import sanitize_request_payload
26
  from .cooldown_manager import CooldownManager
27
+ from .credential_manager import CredentialManager
28
+ from .background_refresher import BackgroundRefresher
29
 
30
  class StreamedAPIError(Exception):
31
  """Custom exception to signal an API error received over a stream."""
 
41
  def __init__(
42
  self,
43
  api_keys: Dict[str, List[str]],
44
+ oauth_credentials: Dict[str, List[str]],
45
  max_retries: int = 2,
46
  usage_file_path: str = "key_usage.json",
47
  configure_logging: bool = True,
48
  global_timeout: int = 30,
49
  abort_on_callback_error: bool = True,
50
+ litellm_provider_params: Optional[Dict[str, Any]] = None, # [NEW]
51
  ignore_models: Optional[Dict[str, List[str]]] = None
52
  ):
53
  os.environ["LITELLM_LOG"] = "ERROR"
 
67
  if not api_keys:
68
  raise ValueError("API keys dictionary cannot be empty.")
69
  self.api_keys = api_keys
70
+ self.credential_manager = CredentialManager(oauth_credentials)
71
+ self.oauth_credentials = self.credential_manager.discover_and_prepare()
72
+ self.background_refresher = BackgroundRefresher(self)
73
+ self.oauth_providers = set(self.oauth_credentials.keys())
74
+
75
+ all_credentials = {}
76
+ for provider, keys in api_keys.items():
77
+ all_credentials.setdefault(provider, []).extend(keys)
78
+ for provider, paths in self.oauth_credentials.items():
79
+ all_credentials.setdefault(provider, []).extend(paths)
80
+ self.all_credentials = all_credentials
81
+
82
  self.max_retries = max_retries
83
  self.global_timeout = global_timeout
84
  self.abort_on_callback_error = abort_on_callback_error
 
89
  self.http_client = httpx.AsyncClient()
90
  self.all_providers = AllProviders()
91
  self.cooldown_manager = CooldownManager()
92
+ self.litellm_provider_params = litellm_provider_params or {}
93
  self.ignore_models = ignore_models or {}
94
 
95
  def _is_model_ignored(self, provider: str, model_id: str) -> bool:
 
208
 
209
  return kwargs
210
 
211
+ def get_oauth_credentials(self) -> Dict[str, List[str]]:
212
+ return self.oauth_credentials
213
+
214
  def _get_provider_instance(self, provider_name: str):
215
  """Lazily initializes and returns a provider instance."""
216
  if provider_name not in self._provider_instances:
 
358
  raise ValueError("'model' is a required parameter.")
359
 
360
  provider = model.split('/')[0]
361
+ if provider not in self.all_credentials:
362
+ raise ValueError(f"No API keys or OAuth credentials configured for provider: {provider}")
363
 
364
  # Establish a global deadline for the entire request lifecycle.
365
  deadline = time.time() + self.global_timeout
 
367
  # Create a mutable copy of the keys and shuffle it to ensure
368
  # that the key selection is randomized, which is crucial when
369
  # multiple keys have the same usage stats.
370
+ credentials_for_provider = list(self.all_credentials[provider])
371
+ random.shuffle(credentials_for_provider)
372
 
373
+ tried_creds = set()
374
  last_exception = None
375
  kwargs = self._convert_model_params(**kwargs)
376
+
377
+ # The main rotation loop. It continues as long as there are untried credentials and the global deadline has not been exceeded.
378
+ while len(tried_creds) < len(credentials_for_provider) and time.time() < deadline:
379
+ current_cred = None
380
  key_acquired = False
381
  try:
382
  # Check for a provider-wide cooldown first.
 
392
  lib_logger.warning(f"Provider {provider} is in cooldown. Waiting for {remaining_cooldown:.2f} seconds.")
393
  await asyncio.sleep(remaining_cooldown)
394
 
395
+ creds_to_try = [c for c in credentials_for_provider if c not in tried_creds]
396
+ if not creds_to_try:
397
  break
398
 
399
+ lib_logger.info(f"Acquiring key for model {model}. Tried keys: {len(tried_creds)}/{len(credentials_for_provider)}")
400
+ current_cred = await self.usage_manager.acquire_key(
401
+ available_keys=creds_to_try,
402
  model=model,
403
  deadline=deadline
404
  )
405
  key_acquired = True
406
+ tried_creds.add(current_cred)
407
 
408
  litellm_kwargs = self.all_providers.get_provider_kwargs(**kwargs.copy())
 
 
 
 
 
 
 
 
409
 
410
+ # [NEW] Merge provider-specific params
411
+ if provider in self.litellm_provider_params:
412
+ litellm_kwargs["litellm_params"] = {
413
+ **self.litellm_provider_params[provider],
414
+ **litellm_kwargs.get("litellm_params", {})
415
+ }
416
 
417
+ provider_plugin = self._get_provider_instance(provider)
418
+ if provider_plugin and provider_plugin.has_custom_logic():
419
+ lib_logger.debug(f"Provider '{provider}' has custom logic. Delegating call.")
420
+ litellm_kwargs["credential_identifier"] = current_cred
421
+
422
+ # The plugin handles the entire call, including retries on 401, etc.
423
+ # The main retry loop here is for key rotation on other errors.
424
+ response = await provider_plugin.acompletion(self.http_client, **litellm_kwargs)
425
+
426
+ # For non-streaming, success is immediate
427
+ if not kwargs.get("stream"):
428
+ await self.usage_manager.record_success(current_cred, model, response)
429
+ await self.usage_manager.release_key(current_cred, model)
 
 
 
 
 
 
 
 
 
 
430
  key_acquired = False
431
  return response
432
+ else:
433
+ # For streaming, wrap the response and return
434
+ key_acquired = False
435
+ stream_generator = self._safe_streaming_wrapper(response, current_cred, model, request)
436
+ async for chunk in stream_generator:
437
+ yield chunk
438
+ return
439
+
440
+ else: # This is the standard API Key / litellm-handled provider logic
441
+ is_oauth = provider in self.oauth_providers
442
+ if is_oauth: # Standard OAuth provider (not custom)
443
+ # ... (logic to set headers) ...
444
+ pass
445
+ else: # API Key
446
+ litellm_kwargs["api_key"] = current_cred
447
+
448
+ provider_instance = self._get_provider_instance(provider)
449
+ if provider_instance:
450
+ if "safety_settings" in litellm_kwargs:
451
+ converted_settings = provider_instance.convert_safety_settings(litellm_kwargs["safety_settings"])
452
+ if converted_settings is not None:
453
+ litellm_kwargs["safety_settings"] = converted_settings
454
+ else:
455
+ del litellm_kwargs["safety_settings"]
456
+
457
+ if provider == "gemini" and provider_instance:
458
+ provider_instance.handle_thinking_parameter(litellm_kwargs, model)
459
 
460
+ if "gemma-3" in model and "messages" in litellm_kwargs:
461
+ litellm_kwargs["messages"] = [{"role": "user", "content": m["content"]} if m.get("role") == "system" else m for m in litellm_kwargs["messages"]]
462
+
463
+ litellm_kwargs = sanitize_request_payload(litellm_kwargs, model)
464
+
465
+ for attempt in range(self.max_retries):
466
+ try:
467
+ lib_logger.info(f"Attempting call with credential ...{current_cred[-6:]} (Attempt {attempt + 1}/{self.max_retries})")
468
+
469
+ if pre_request_callback:
470
+ try:
471
+ await pre_request_callback(request, litellm_kwargs)
472
+ except Exception as e:
473
+ if self.abort_on_callback_error:
474
+ raise PreRequestCallbackError(f"Pre-request callback failed: {e}") from e
475
+ else:
476
+ lib_logger.warning(f"Pre-request callback failed but abort_on_callback_error is False. Proceeding with request. Error: {e}")
477
+
478
+ response = await api_call(
479
+ **litellm_kwargs,
480
+ logger_fn=self._litellm_logger_callback
481
+ )
482
+
483
+ await self.usage_manager.record_success(current_cred, model, response)
484
+ await self.usage_manager.release_key(current_cred, model)
485
+ key_acquired = False
486
+ return response
487
+
488
+ except litellm.RateLimitError as e:
489
+ last_exception = e
490
+ log_failure(api_key=current_cred, model=model, attempt=attempt + 1, error=e, request_headers=dict(request.headers) if request else {})
491
+ classified_error = classify_error(e)
492
+
493
+ # Extract a clean error message for the user-facing log
494
  error_message = str(e).split('\n')[0]
495
+ lib_logger.info(f"Key ...{current_cred[-6:]} hit rate limit for model {model}. Reason: '{error_message}'. Rotating key.")
496
+
497
+ if classified_error.status_code == 429:
498
+ cooldown_duration = classified_error.retry_after or 60
499
+ await self.cooldown_manager.start_cooldown(provider, cooldown_duration)
500
+ lib_logger.warning(f"IP-based rate limit detected for {provider}. Starting a {cooldown_duration}-second global cooldown.")
501
+
502
+ await self.usage_manager.record_failure(current_cred, model, classified_error)
503
+ lib_logger.warning(f"Key ...{current_cred[-6:]} encountered a rate limit. Trying next key.")
504
  break # Move to the next key
 
 
 
 
 
 
 
 
 
505
 
506
+ except (APIConnectionError, litellm.InternalServerError, litellm.ServiceUnavailableError) as e:
507
+ last_exception = e
508
+ log_failure(api_key=current_cred, model=model, attempt=attempt + 1, error=e, request_headers=dict(request.headers) if request else {})
509
+ classified_error = classify_error(e)
510
+ await self.usage_manager.record_failure(current_cred, model, classified_error)
511
 
512
+ if attempt >= self.max_retries - 1:
513
+ error_message = str(e).split('\n')[0]
514
+ lib_logger.warning(f"Key ...{current_cred[-6:]} failed after max retries for model {model} due to a server error. Reason: '{error_message}'. Rotating key.")
515
+ break # Move to the next key
516
+
517
+ # For temporary errors, wait before retrying with the same key.
518
+ wait_time = classified_error.retry_after or (1 * (2 ** attempt)) + random.uniform(0, 1)
519
+ remaining_budget = deadline - time.time()
520
+
521
+ # If the required wait time exceeds the budget, don't wait; rotate to the next key immediately.
522
+ if wait_time > remaining_budget:
523
+ lib_logger.warning(f"Required retry wait time ({wait_time:.2f}s) exceeds remaining budget ({remaining_budget:.2f}s). Rotating key early.")
524
+ break
525
+
526
+ error_message = str(e).split('\n')[0]
527
+ lib_logger.warning(f"Key ...{current_cred[-6:]} encountered a server error for model {model}. Reason: '{error_message}'. Retrying in {wait_time:.2f}s.")
528
+ await asyncio.sleep(wait_time)
529
+ continue # Retry with the same key
530
+
531
+ except Exception as e:
532
+ last_exception = e
533
+ log_failure(api_key=current_cred, model=model, attempt=attempt + 1, error=e, request_headers=dict(request.headers) if request else {})
534
+
535
+ if request and await request.is_disconnected():
536
+ lib_logger.warning(f"Client disconnected. Aborting retries for key ...{current_cred[-6:]}.")
537
+ raise last_exception
538
+
539
+ classified_error = classify_error(e)
540
+ error_message = str(e).split('\n')[0]
541
+ lib_logger.warning(f"Key ...{current_cred[-6:]} failed with {classified_error.error_type} (Status: {classified_error.status_code}). Error: {error_message}. Rotating key.")
542
+ if classified_error.status_code == 429:
543
+ cooldown_duration = classified_error.retry_after or 60
544
+ await self.cooldown_manager.start_cooldown(provider, cooldown_duration)
545
+ lib_logger.warning(f"IP-based rate limit detected for {provider} from generic exception. Starting a {cooldown_duration}-second global cooldown.")
546
+
547
+ if classified_error.error_type in ['invalid_request', 'context_window_exceeded', 'authentication']:
548
+ # For these errors, we should not retry with other keys.
549
+ raise last_exception
550
+
551
+ await self.usage_manager.record_failure(current_cred, model, classified_error)
552
+ break # Try next key for other errors
553
  finally:
554
+ if key_acquired and current_cred:
555
+ await self.usage_manager.release_key(current_cred, model)
556
 
557
  if last_exception:
558
  # Log the final error but do not raise it, as per the new requirement.
 
568
  provider = model.split('/')[0]
569
 
570
  # Create a mutable copy of the keys and shuffle it.
571
+ credentials_for_provider = list(self.all_credentials[provider])
572
+ random.shuffle(credentials_for_provider)
573
 
574
  deadline = time.time() + self.global_timeout
575
+ tried_creds = set()
576
  last_exception = None
577
  kwargs = self._convert_model_params(**kwargs)
578
 
579
  consecutive_quota_failures = 0
580
 
581
  try:
582
+ while len(tried_creds) < len(credentials_for_provider) and time.time() < deadline:
583
+ current_cred = None
584
  key_acquired = False
585
  try:
586
  if await self.cooldown_manager.is_cooling_down(provider):
 
592
  lib_logger.warning(f"Provider {provider} is in a global cooldown. All requests to this provider will be paused for {remaining_cooldown:.2f} seconds.")
593
  await asyncio.sleep(remaining_cooldown)
594
 
595
+ creds_to_try = [c for c in credentials_for_provider if c not in tried_creds]
596
+ if not creds_to_try:
597
+ lib_logger.warning(f"All credentials for provider {provider} have been tried. No more credentials to rotate to.")
598
  break
599
 
600
+ lib_logger.info(f"Acquiring credential for model {model}. Tried credentials: {len(tried_creds)}/{len(credentials_for_provider)}")
601
+ current_cred = await self.usage_manager.acquire_key(
602
+ available_keys=creds_to_try,
603
  model=model,
604
  deadline=deadline
605
  )
606
  key_acquired = True
607
+ tried_creds.add(current_cred)
608
 
609
  litellm_kwargs = self.all_providers.get_provider_kwargs(**kwargs.copy())
610
+
611
+ # [NEW] Merge provider-specific params
612
+ if provider in self.litellm_provider_params:
613
+ litellm_kwargs["litellm_params"] = {
614
+ **self.litellm_provider_params[provider],
615
+ **litellm_kwargs.get("litellm_params", {})
616
+ }
617
+
618
+ provider_plugin = self._get_provider_instance(provider)
619
+ if provider_plugin and provider_plugin.has_custom_logic():
620
+ lib_logger.debug(f"Provider '{provider}' has custom logic. Delegating call.")
621
+ litellm_kwargs["credential_identifier"] = current_cred
622
+
623
+ # The plugin handles the entire call, including retries on 401, etc.
624
+ # The main retry loop here is for key rotation on other errors.
625
+ response = await provider_plugin.acompletion(self.http_client, **litellm_kwargs)
626
+
627
+ key_acquired = False
628
+ stream_generator = self._safe_streaming_wrapper(response, current_cred, model, request)
629
+ async for chunk in stream_generator:
630
+ yield chunk
631
+ return
632
+
633
+ else: # This is the standard API Key / litellm-handled provider logic
634
+ is_oauth = provider in self.oauth_providers
635
+ if is_oauth: # Standard OAuth provider (not custom)
636
+ # ... (logic to set headers) ...
637
+ pass
638
+ else: # API Key
639
+ litellm_kwargs["api_key"] = current_cred
640
+
641
  provider_instance = self._get_provider_instance(provider)
642
  if provider_instance:
643
  if "safety_settings" in litellm_kwargs:
 
657
 
658
  for attempt in range(self.max_retries):
659
  try:
660
+ lib_logger.info(f"Attempting stream with credential ...{current_cred[-6:]} (Attempt {attempt + 1}/{self.max_retries})")
661
 
662
  if pre_request_callback:
663
  try:
 
669
  lib_logger.warning(f"Pre-request callback failed but abort_on_callback_error is False. Proceeding with request. Error: {e}")
670
 
671
  response = await litellm.acompletion(
 
672
  **litellm_kwargs,
673
  logger_fn=self._litellm_logger_callback
674
  )
675
 
676
+ lib_logger.info(f"Stream connection established for credential ...{current_cred[-6:]}. Processing response.")
677
 
678
  key_acquired = False
679
+ stream_generator = self._safe_streaming_wrapper(response, current_cred, model, request)
680
 
681
  async for chunk in stream_generator:
682
  yield chunk
 
706
 
707
  # Now, log the failure with the extracted raw response.
708
  log_failure(
709
+ api_key=current_cred,
710
  model=model,
711
  attempt=attempt + 1,
712
  error=e,
 
721
 
722
  if "quota" in error_message_text.lower() or "resource_exhausted" in error_status.lower():
723
  consecutive_quota_failures += 1
724
+ lib_logger.warning(f"Credential ...{current_cred[-6:]} hit a quota limit. This is consecutive failure #{consecutive_quota_failures} for this request.")
725
 
726
  quota_value = "N/A"
727
  quota_id = "N/A"
 
736
  if quota_value != "N/A" and quota_id != "N/A":
737
  break
738
 
739
+ await self.usage_manager.record_failure(current_cred, model, classified_error)
740
 
741
  if consecutive_quota_failures >= 3:
742
  console_log_message = (
743
+ f"Terminating stream for credential ...{current_cred[-6:]} due to 3rd consecutive quota error. "
744
  f"This is now considered a fatal input data error. ID: {quota_id}, Limit: {quota_value}."
745
  )
746
  client_error_message = (
 
756
 
757
  else:
758
  # [MODIFIED] Do not yield to the client. Just log and break to rotate the key.
759
+ lib_logger.warning(f"Quota error on credential ...{current_cred[-6:]} (failure {consecutive_quota_failures}/3). Rotating key silently.")
760
  break
761
 
762
  else:
763
  consecutive_quota_failures = 0
764
  # [MODIFIED] Do not yield to the client. Just log and break to rotate the key.
765
+ lib_logger.warning(f"Credential ...{current_cred[-6:]} encountered a recoverable error ({classified_error.error_type}) during stream. Rotating key silently.")
766
 
767
  if classified_error.error_type == 'rate_limit' and classified_error.status_code == 429:
768
  cooldown_duration = classified_error.retry_after or 60
769
  await self.cooldown_manager.start_cooldown(provider, cooldown_duration)
770
  lib_logger.warning(f"IP-based rate limit detected for {provider}. Starting a {cooldown_duration}-second global cooldown.")
771
 
772
+ await self.usage_manager.record_failure(current_cred, model, classified_error)
773
  break
774
 
775
  except (APIConnectionError, litellm.InternalServerError, litellm.ServiceUnavailableError) as e:
776
  consecutive_quota_failures = 0
777
  last_exception = e
778
+ log_failure(api_key=current_cred, model=model, attempt=attempt + 1, error=e, request_headers=dict(request.headers) if request else {})
779
  classified_error = classify_error(e)
780
+ await self.usage_manager.record_failure(current_cred, model, classified_error)
781
 
782
  if attempt >= self.max_retries - 1:
783
+ lib_logger.warning(f"Credential ...{current_cred[-6:]} failed after max retries for model {model} due to a server error. Rotating key silently.")
784
  # [MODIFIED] Do not yield to the client here.
785
  break
786
 
 
791
  break
792
 
793
  error_message = str(e).split('\n')[0]
794
+ lib_logger.warning(f"Credential ...{current_cred[-6:]} encountered a server error for model {model}. Reason: '{error_message}'. Retrying in {wait_time:.2f}s.")
795
  await asyncio.sleep(wait_time)
796
  continue
797
 
798
  except Exception as e:
799
  consecutive_quota_failures = 0
800
  last_exception = e
801
+ log_failure(api_key=current_cred, model=model, attempt=attempt + 1, error=e, request_headers=dict(request.headers) if request else {})
802
  classified_error = classify_error(e)
803
 
804
+ lib_logger.warning(f"Credential ...{current_cred[-6:]} failed with {classified_error.error_type} (Status: {classified_error.status_code}). Error: {str(e)}. Rotating key.")
805
 
806
  if classified_error.status_code == 429:
807
  cooldown_duration = classified_error.retry_after or 60
 
812
  raise last_exception
813
 
814
  # [MODIFIED] Do not yield to the client here.
815
+ await self.usage_manager.record_failure(current_cred, model, classified_error)
816
  break
817
 
818
  finally:
819
+ if key_acquired and current_cred:
820
+ await self.usage_manager.release_key(current_cred, model)
821
 
822
  final_error_message = "Failed to complete the streaming request: No available API keys after rotation or global timeout exceeded."
823
  if last_exception:
src/rotator_library/credential_manager.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import shutil
import logging
from pathlib import Path
from typing import Dict, List, Optional

lib_logger = logging.getLogger('rotator_library')

# Local working directory for credential copies; created eagerly at import
# time, relative to the process's current working directory.
OAUTH_BASE_DIR = Path.cwd() / "oauth_creds"
OAUTH_BASE_DIR.mkdir(exist_ok=True)

# Standard paths where tools like `gemini login` store credentials.
DEFAULT_OAUTH_PATHS = {
    "gemini": Path.home() / ".gemini" / "oauth_creds.json",
    "qwen": Path.home() / ".qwen" / "oauth_creds.json",
    # Add other providers like 'claude' here if they have a standard CLI path
}


class CredentialManager:
    """
    Discovers OAuth credential files from standard locations, copies them locally,
    and updates the configuration to use the local paths.
    """

    def __init__(self, oauth_config: Dict[str, List[str]]):
        # Mapping of provider name -> list of configured credential paths
        # (entries may be empty/None to request default-path discovery).
        self.oauth_config = oauth_config

    def discover_and_prepare(self) -> Dict[str, List[str]]:
        """
        Process the initial OAuth config and return a copy pointing at local files.

        For each configured entry: resolve the source file (falling back to the
        provider's default location when no path was given), copy it into
        OAUTH_BASE_DIR if a local copy does not already exist, and collect the
        absolute local path. Entries whose source file is missing or cannot be
        copied are skipped with a log message; providers that end up with no
        usable entries are omitted from the result entirely.
        """
        prepared: Dict[str, List[str]] = {}

        for provider, raw_paths in self.oauth_config.items():
            usable_paths: List[str] = []

            # Account numbering is 1-based so log messages match user expectations.
            for account_id, raw_path in enumerate(raw_paths, start=1):
                src_path = self._resolve_source_path(provider, raw_path)

                if not src_path or not src_path.exists():
                    lib_logger.warning(f"Could not find OAuth file for {provider} account #{account_id}. Skipping.")
                    continue

                local_path = OAUTH_BASE_DIR / f"{provider}_oauth_{account_id}.json"

                # Only copy once; an existing local copy is treated as authoritative.
                if not local_path.exists():
                    try:
                        shutil.copy(src_path, local_path)
                        lib_logger.info(f"Copied '{src_path}' to local credentials at '{local_path}'.")
                    except Exception as e:
                        lib_logger.error(f"Failed to copy OAuth file for {provider} account #{account_id}: {e}")
                        continue

                usable_paths.append(str(local_path.resolve()))

            if usable_paths:
                prepared[provider] = usable_paths

        return prepared

    def _resolve_source_path(self, provider: str, specified_path: Optional[str]) -> Optional[Path]:
        """Determine the source path for a credential file.

        An explicitly configured path wins (with `~` expansion); otherwise the
        provider's well-known default location is used, or None if the provider
        has no default.
        """
        if not specified_path:
            return DEFAULT_OAUTH_PATHS.get(provider)
        return Path(specified_path).expanduser()
src/rotator_library/error_handler.py CHANGED
@@ -1,5 +1,7 @@
1
  import re
 
2
  from typing import Optional, Dict, Any
 
3
 
4
  from litellm.exceptions import APIConnectionError, RateLimitError, ServiceUnavailableError, AuthenticationError, InvalidRequestError, BadRequestError, OpenAIError, InternalServerError, Timeout, ContextWindowExceededError
5
 
@@ -22,8 +24,6 @@ class ClassifiedError:
22
  def __str__(self):
23
  return f"ClassifiedError(type={self.error_type}, status={self.status_code}, retry_after={self.retry_after}, original_exc={self.original_exception})"
24
 
25
- import json
26
-
27
  def get_retry_after(error: Exception) -> Optional[int]:
28
  """
29
  Extracts the 'retry-after' duration in seconds from an exception message.
@@ -80,9 +80,24 @@ def get_retry_after(error: Exception) -> Optional[int]:
80
  def classify_error(e: Exception) -> ClassifiedError:
81
  """
82
  Classifies an exception into a structured ClassifiedError object.
 
83
  """
84
  status_code = getattr(e, 'status_code', None)
 
 
 
 
 
 
 
 
 
 
 
85
 
 
 
 
86
  if isinstance(e, PreRequestCallbackError):
87
  return ClassifiedError(
88
  error_type='pre_request_callback_error',
 
1
  import re
2
+ import json
3
  from typing import Optional, Dict, Any
4
+ import httpx
5
 
6
  from litellm.exceptions import APIConnectionError, RateLimitError, ServiceUnavailableError, AuthenticationError, InvalidRequestError, BadRequestError, OpenAIError, InternalServerError, Timeout, ContextWindowExceededError
7
 
 
24
  def __str__(self):
25
  return f"ClassifiedError(type={self.error_type}, status={self.status_code}, retry_after={self.retry_after}, original_exc={self.original_exception})"
26
 
 
 
27
  def get_retry_after(error: Exception) -> Optional[int]:
28
  """
29
  Extracts the 'retry-after' duration in seconds from an exception message.
 
80
  def classify_error(e: Exception) -> ClassifiedError:
81
  """
82
  Classifies an exception into a structured ClassifiedError object.
83
+ Now handles both litellm and httpx exceptions.
84
  """
85
  status_code = getattr(e, 'status_code', None)
86
+ if isinstance(e, httpx.HTTPStatusError): # [NEW] Handle httpx errors first
87
+ status_code = e.response.status_code
88
+ if status_code == 401:
89
+ return ClassifiedError(error_type='authentication', original_exception=e, status_code=status_code)
90
+ if status_code == 429:
91
+ retry_after = get_retry_after(e)
92
+ return ClassifiedError(error_type='rate_limit', original_exception=e, status_code=status_code, retry_after=retry_after)
93
+ if 400 <= status_code < 500:
94
+ return ClassifiedError(error_type='invalid_request', original_exception=e, status_code=status_code)
95
+ if 500 <= status_code:
96
+ return ClassifiedError(error_type='server_error', original_exception=e, status_code=status_code)
97
 
98
+ if isinstance(e, (httpx.TimeoutException, httpx.ConnectError, httpx.NetworkError)): # [NEW]
99
+ return ClassifiedError(error_type='api_connection', original_exception=e, status_code=status_code)
100
+
101
  if isinstance(e, PreRequestCallbackError):
102
  return ClassifiedError(
103
  error_type='pre_request_callback_error',
src/rotator_library/providers/__init__.py CHANGED
@@ -26,9 +26,9 @@ def _register_providers():
26
  for attribute_name in dir(module):
27
  attribute = getattr(module, attribute_name)
28
  if isinstance(attribute, type) and issubclass(attribute, ProviderInterface) and attribute is not ProviderInterface:
29
- # The provider name is derived from the module name (e.g., 'openai_provider' -> 'openai')
30
- provider_name = module_name.replace("_provider", "")
31
  # Remap 'nvidia' to 'nvidia_nim' to align with litellm's provider name
 
32
  if provider_name == "nvidia":
33
  provider_name = "nvidia_nim"
34
  PROVIDER_PLUGINS[provider_name] = attribute
 
26
  for attribute_name in dir(module):
27
  attribute = getattr(module, attribute_name)
28
  if isinstance(attribute, type) and issubclass(attribute, ProviderInterface) and attribute is not ProviderInterface:
29
+ # Derives 'gemini_cli' from 'gemini_cli_provider.py'
 
30
  # Remap 'nvidia' to 'nvidia_nim' to align with litellm's provider name
31
+ provider_name = module_name.replace("_provider", "")
32
  if provider_name == "nvidia":
33
  provider_name = "nvidia_nim"
34
  PROVIDER_PLUGINS[provider_name] = attribute
src/rotator_library/providers/gemini_auth_base.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/rotator_library/providers/gemini_auth_base.py
2
+
3
+ import json
4
+ import time
5
+ import asyncio
6
+ import logging
7
+ from pathlib import Path
8
+ from typing import Dict, Any
9
+
10
+ import httpx
11
+
12
import calendar

lib_logger = logging.getLogger('rotator_library')

# OAuth client identity used by the official gemini-cli tool. Credential files
# that do not carry their own client_id/client_secret are refreshed against
# this client.
CLIENT_ID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
CLIENT_SECRET = "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"
TOKEN_URI = "https://oauth2.googleapis.com/token"
# Treat tokens as expired this many seconds early so the refresh completes
# before the server actually starts rejecting the old token.
REFRESH_EXPIRY_BUFFER_SECONDS = 300

class GeminiAuthBase:
    """Shared OAuth token handling for Gemini-based provider plugins.

    Loads credential JSON files (gemini-cli format, or gcloud format where the
    tokens are nested under a "credential" key), caches them in memory,
    refreshes access tokens against the Google OAuth token endpoint, and
    persists refreshed tokens back to the source file. A per-path asyncio.Lock
    serializes concurrent load/refresh operations on the same credential file.
    """

    def __init__(self):
        # credential file path -> parsed credential dict (in-memory cache)
        self._credentials_cache: Dict[str, Dict[str, Any]] = {}
        # credential file path -> lock guarding load/refresh for that file
        self._refresh_locks: Dict[str, asyncio.Lock] = {}

    async def _load_credentials(self, path: str) -> Dict[str, Any]:
        """Load and cache the credential JSON stored at `path`.

        Raises:
            IOError: if the file cannot be read or parsed.
        """
        if path in self._credentials_cache:
            return self._credentials_cache[path]

        async with self._get_lock(path):
            # Re-check under the lock: another task may have populated the
            # cache while we were waiting.
            if path in self._credentials_cache:
                return self._credentials_cache[path]
            try:
                with open(path, 'r') as f:
                    creds = json.load(f)
                # Handle gcloud-style creds file which nest tokens under "credential"
                if "credential" in creds:
                    creds = creds["credential"]
                self._credentials_cache[path] = creds
                return creds
            except Exception as e:
                raise IOError(f"Failed to load Gemini OAuth credentials from '{path}': {e}")

    async def _save_credentials(self, path: str, creds: Dict[str, Any]):
        """Update the in-memory cache and best-effort persist `creds` to disk.

        A failed write is logged but not raised: the refreshed token remains
        usable from the cache for the lifetime of this process.
        """
        self._credentials_cache[path] = creds
        try:
            with open(path, 'w') as f:
                json.dump(creds, f, indent=2)
        except Exception as e:
            lib_logger.error(f"Failed to save updated Gemini OAuth credentials to '{path}': {e}")

    def _is_token_expired(self, creds: Dict[str, Any]) -> bool:
        """Return True if the access token expires within the refresh buffer."""
        expiry = creds.get("token_expiry")  # gcloud format: ISO-8601 UTC string
        if not expiry:
            # gemini-cli format: epoch milliseconds. A missing value yields 0,
            # i.e. "already expired", which forces a refresh attempt.
            expiry_timestamp = creds.get("expiry_date", 0) / 1000
        else:
            # [FIX] The timestamp is UTC (trailing "Z"), so it must be
            # converted with calendar.timegm. The previous time.mktime call
            # interpreted the parsed struct_time as *local* time, skewing the
            # computed expiry by the machine's UTC offset.
            expiry_timestamp = calendar.timegm(time.strptime(expiry, "%Y-%m-%dT%H:%M:%SZ"))

        return expiry_timestamp < time.time() + REFRESH_EXPIRY_BUFFER_SECONDS

    async def _refresh_token(self, path: str, creds: Dict[str, Any]) -> Dict[str, Any]:
        """Exchange the refresh_token for a new access token and persist it.

        Raises:
            ValueError: if the credential file has no refresh_token.
            httpx.HTTPStatusError: if the token endpoint rejects the request.
        """
        async with self._get_lock(path):
            # Another task may have refreshed while we waited on the lock; if
            # the cached copy is now fresh, use it instead of refreshing again.
            if not self._is_token_expired(self._credentials_cache.get(path, creds)):
                return self._credentials_cache.get(path, creds)

            lib_logger.info(f"Refreshing Gemini OAuth token for '{Path(path).name}'...")
            refresh_token = creds.get("refresh_token")
            if not refresh_token:
                raise ValueError("No refresh_token found in credentials file.")

            async with httpx.AsyncClient() as client:
                response = await client.post(TOKEN_URI, data={
                    # Prefer ids embedded in the file; fall back to the
                    # gemini-cli defaults above.
                    "client_id": creds.get("client_id", CLIENT_ID),
                    "client_secret": creds.get("client_secret", CLIENT_SECRET),
                    "refresh_token": refresh_token,
                    "grant_type": "refresh_token",
                })
                response.raise_for_status()
                new_token_data = response.json()

            creds["access_token"] = new_token_data["access_token"]
            expiry_timestamp = time.time() + new_token_data["expires_in"]
            # Store the expiry in both known formats so either reader works.
            creds["expiry_date"] = expiry_timestamp * 1000  # gemini-cli format
            creds["token_expiry"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(expiry_timestamp))  # gcloud format

            await self._save_credentials(path, creds)
            lib_logger.info(f"Successfully refreshed Gemini OAuth token for '{Path(path).name}'.")
            return creds

    async def get_auth_header(self, credential_path: str) -> Dict[str, str]:
        """Return an Authorization header carrying a valid access token,
        refreshing it first if it is expired or about to expire."""
        creds = await self._load_credentials(credential_path)
        if self._is_token_expired(creds):
            creds = await self._refresh_token(credential_path, creds)
        return {"Authorization": f"Bearer {creds['access_token']}"}

    async def proactively_refresh(self, credential_path: str):
        """Refresh the token ahead of use if it is close to expiring."""
        creds = await self._load_credentials(credential_path)
        if self._is_token_expired(creds):
            await self._refresh_token(credential_path, creds)

    def _get_lock(self, path: str) -> asyncio.Lock:
        """Return the per-path lock, creating it on first use.

        Safe without extra synchronization: asyncio tasks only interleave at
        await points and there is none between the check and the store.
        """
        if path not in self._refresh_locks:
            self._refresh_locks[path] = asyncio.Lock()
        return self._refresh_locks[path]
src/rotator_library/providers/gemini_cli_provider.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/rotator_library/providers/gemini_cli_provider.py
2
+
3
+ import json
4
+ import httpx
5
+ import logging
6
+ import time
7
+ from typing import List, Dict, Any, AsyncGenerator, Union, Optional
8
+ from .provider_interface import ProviderInterface
9
+ from .gemini_auth_base import GeminiAuthBase
10
+ import litellm
11
+ import os
12
+ from pathlib import Path
13
+
14
+ lib_logger = logging.getLogger('rotator_library')
15
+
16
+ CODE_ASSIST_ENDPOINT = "https://cloudcode-pa.googleapis.com/v1internal"
17
+
18
+ class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
19
+ def __init__(self):
20
+ super().__init__()
21
+ self.project_id: Optional[str] = None
22
+
23
+ async def _discover_project_id(self, litellm_params: Dict[str, Any]) -> str:
24
+ """Discovers the Google Cloud Project ID."""
25
+ if self.project_id:
26
+ return self.project_id
27
+
28
+ # 1. Prioritize explicitly configured project_id
29
+ if litellm_params.get("project_id"):
30
+ self.project_id = litellm_params["project_id"]
31
+ lib_logger.info(f"Using configured Gemini CLI project ID: {self.project_id}")
32
+ return self.project_id
33
+
34
+ # 2. Fallback: Look for .env file in the standard .gemini directory
35
+ try:
36
+ gemini_env_path = Path.home() / ".gemini" / ".env"
37
+ if gemini_env_path.exists():
38
+ with open(gemini_env_path, 'r') as f:
39
+ for line in f:
40
+ if line.startswith("GOOGLE_CLOUD_PROJECT="):
41
+ self.project_id = line.strip().split("=")[1]
42
+ lib_logger.info(f"Discovered Gemini CLI project ID from ~/.gemini/.env: {self.project_id}")
43
+ return self.project_id
44
+ except Exception as e:
45
+ lib_logger.warning(f"Could not read project ID from ~/.gemini/.env: {e}")
46
+
47
+ raise ValueError(
48
+ "Gemini CLI project ID not found. Please set `GEMINI_CLI_PROJECT_ID` in your main .env file "
49
+ "or ensure it is present in `~/.gemini/.env`."
50
+ )
51
+ def has_custom_logic(self) -> bool:
52
+ return True
53
+
54
+ def _transform_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
55
+ # As seen in Kilo examples, system prompts are injected into the first user message.
56
+ gemini_contents = []
57
+ system_prompt = ""
58
+ if messages and messages[0].get('role') == 'system':
59
+ system_prompt = messages.pop(0).get('content', '')
60
+
61
+ for msg in messages:
62
+ role = "model" if msg.get("role") == "assistant" else "user"
63
+ content = msg.get("content", "")
64
+ if system_prompt and role == "user":
65
+ content = f"{system_prompt}\n\n{content}"
66
+ system_prompt = "" # Inject only once
67
+ gemini_contents.append({"role": role, "parts": [{"text": content}]})
68
+ return gemini_contents
69
+
70
+ def _convert_chunk_to_openai(self, chunk: Dict[str, Any], model_id: str) -> dict:
71
+ response_data = chunk.get('response', chunk)
72
+ candidate = response_data.get('candidates', [{}])[0]
73
+
74
+ delta = {}
75
+ finish_reason = None
76
+
77
+ # Correctly handle reasoning vs. content based on 'thought' flag from Kilo example
78
+ if 'content' in candidate and 'parts' in candidate['content']:
79
+ part = candidate['content']['parts'][0]
80
+ if part.get('text'):
81
+ if part.get('thought') is True:
82
+ # This is a reasoning/thinking step
83
+ delta['reasoning_content'] = part['text']
84
+ else:
85
+ delta['content'] = part['text']
86
+
87
+ raw_finish_reason = candidate.get('finishReason')
88
+ if raw_finish_reason:
89
+ mapping = {'STOP': 'stop', 'MAX_TOKENS': 'length', 'SAFETY': 'content_filter'}
90
+ finish_reason = mapping.get(raw_finish_reason, 'stop')
91
+
92
+ choice = {"index": 0, "delta": delta, "finish_reason": finish_reason}
93
+
94
+ openai_chunk = {
95
+ "choices": [choice], "model": model_id, "object": "chat.completion.chunk",
96
+ "id": f"chatcmpl-geminicli-{time.time()}", "created": int(time.time())
97
+ }
98
+
99
+ if 'usageMetadata' in response_data:
100
+ usage = response_data['usageMetadata']
101
+ openai_chunk["usage"] = {
102
+ "prompt_tokens": usage.get("promptTokenCount", 0),
103
+ "completion_tokens": usage.get("candidatesTokenCount", 0),
104
+ "total_tokens": usage.get("totalTokenCount", 0),
105
+ }
106
+
107
+ return openai_chunk
108
+
109
+ async def acompletion(self, client: httpx.AsyncClient, **kwargs) -> Union[litellm.ModelResponse, AsyncGenerator[litellm.ModelResponse, None]]:
110
+ model = kwargs["model"]
111
+ credential_path = kwargs.pop("credential_identifier")
112
+ auth_header = await self.get_auth_header(credential_path)
113
+
114
+ project_id = await self._discover_project_id(kwargs.get("litellm_params", {}))
115
+
116
+ # Handle :thinking suffix from Kilo example
117
+ model_name = model.split('/')[-1]
118
+ enable_thinking = model_name.endswith(':thinking')
119
+ if enable_thinking:
120
+ model_name = model_name.replace(':thinking', '')
121
+
122
+ gen_config = {
123
+ "temperature": kwargs.get("temperature", 0.7),
124
+ "maxOutputTokens": kwargs.get("max_tokens", 8192),
125
+ }
126
+ if enable_thinking:
127
+ gen_config["thinkingConfig"] = {"thinkingBudget": -1}
128
+
129
+ request_payload = {
130
+ "model": model_name,
131
+ "project": project_id,
132
+ "request": {
133
+ "contents": self._transform_messages(kwargs.get("messages", [])),
134
+ "generationConfig": gen_config,
135
+ },
136
+ }
137
+
138
+ url = f"{CODE_ASSIST_ENDPOINT}:streamGenerateContent"
139
+
140
+ async def stream_handler():
141
+ async with client.stream("POST", url, headers=auth_header, json=request_payload, params={"alt": "sse"}, timeout=600) as response:
142
+ response.raise_for_status()
143
+ async for line in response.aiter_lines():
144
+ if line.startswith('data: '):
145
+ data_str = line[6:]
146
+ if data_str == "[DONE]": break
147
+ try:
148
+ chunk = json.loads(data_str)
149
+ openai_chunk = self._convert_chunk_to_openai(chunk, model)
150
+ yield litellm.ModelResponse(**openai_chunk)
151
+ except json.JSONDecodeError:
152
+ lib_logger.warning(f"Could not decode JSON from Gemini CLI: {line}")
153
+
154
+ if kwargs.get("stream", False):
155
+ return stream_handler()
156
+ else:
157
+ # Accumulate stream for non-streaming response
158
+ chunks = [chunk async for chunk in stream_handler()]
159
+ return litellm.utils.stream_to_completion_response(chunks)
160
+
161
+ # [NEW] Hardcoded model list based on Kilo example
162
+ HARDCODED_MODELS = [
163
+ "gemini-2.5-pro",
164
+ "gemini-2.5-flash",
165
+ "gemini-2.5-flash-lite"
166
+ ]
167
+ # Use the shared GeminiAuthBase for auth logic
168
+ # get_models is not applicable for this custom provider
169
+ async def get_models(self, api_key: str, client: httpx.AsyncClient) -> List[str]:
170
+ """Returns a hardcoded list of known compatible Gemini CLI models."""
171
+ return [f"gemini_cli/{model_id}" for model_id in HARDCODED_MODELS]
src/rotator_library/providers/provider_interface.py CHANGED
@@ -1,13 +1,14 @@
1
  from abc import ABC, abstractmethod
2
- from typing import List, Dict, Any
3
  import httpx
 
4
 
5
  class ProviderInterface(ABC):
6
  """
7
- An interface for API provider-specific functionality, primarily for discovering
8
- available models.
9
  """
10
-
11
  @abstractmethod
12
  async def get_models(self, api_key: str, client: httpx.AsyncClient) -> List[str]:
13
  """
@@ -22,7 +23,25 @@ class ProviderInterface(ABC):
22
  """
23
  pass
24
 
25
- def convert_safety_settings(self, settings: Dict[str, str]) -> List[Dict[str, Any]]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  """
27
  Converts a generic safety settings dictionary to the provider-specific format.
28
 
@@ -33,3 +52,17 @@ class ProviderInterface(ABC):
33
  A list of provider-specific safety setting objects or None.
34
  """
35
  return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from abc import ABC, abstractmethod
2
+ from typing import List, Dict, Any, Optional, AsyncGenerator, Union
3
  import httpx
4
+ import litellm
5
 
6
  class ProviderInterface(ABC):
7
  """
8
+ An interface for API provider-specific functionality, including model
9
+ discovery and custom API call handling for non-standard providers.
10
  """
11
+
12
  @abstractmethod
13
  async def get_models(self, api_key: str, client: httpx.AsyncClient) -> List[str]:
14
  """
 
23
  """
24
  pass
25
 
26
+ # [NEW] Add methods for providers that need to bypass litellm
27
+ def has_custom_logic(self) -> bool:
28
+ """
29
+ Returns True if the provider implements its own acompletion/aembedding logic,
30
+ bypassing the standard litellm call.
31
+ """
32
+ return False
33
+
34
+ async def acompletion(self, client: httpx.AsyncClient, **kwargs) -> Union[litellm.ModelResponse, AsyncGenerator[litellm.ModelResponse, None]]:
35
+ """
36
+ Handles the entire completion call for non-standard providers.
37
+ """
38
+ raise NotImplementedError(f"{self.__class__.__name__} does not implement custom acompletion.")
39
+
40
+ async def aembedding(self, client: httpx.AsyncClient, **kwargs) -> litellm.EmbeddingResponse:
41
+ """Handles the entire embedding call for non-standard providers."""
42
+ raise NotImplementedError(f"{self.__class__.__name__} does not implement custom aembedding.")
43
+
44
+ def convert_safety_settings(self, settings: Dict[str, str]) -> Optional[List[Dict[str, Any]]]:
45
  """
46
  Converts a generic safety settings dictionary to the provider-specific format.
47
 
 
52
  A list of provider-specific safety setting objects or None.
53
  """
54
  return None
55
+
56
+ # [NEW] Add new methods for OAuth providers
57
+ async def get_auth_header(self, credential_identifier: str) -> Dict[str, str]:
58
+ """
59
+ For OAuth providers, this method returns the Authorization header.
60
+ For API key providers, this can be a no-op or raise NotImplementedError.
61
+ """
62
+ raise NotImplementedError("This provider does not support OAuth.")
63
+
64
+ async def proactively_refresh(self, credential_path: str):
65
+ """
66
+ Proactively refreshes a token if it's nearing expiry.
67
+ """
68
+ pass
src/rotator_library/providers/qwen_auth_base.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/rotator_library/providers/qwen_auth_base.py
2
+
3
+ import json
4
+ import time
5
+ import asyncio
6
+ import logging
7
+ from pathlib import Path
8
+ from typing import Dict, Any, Tuple
9
+
10
+ import httpx
11
+
12
+ lib_logger = logging.getLogger('rotator_library')
13
+
14
+ CLIENT_ID = "f0304373b74a44d2b584a3fb70ca9e56"
15
+ TOKEN_ENDPOINT = "https://chat.qwen.ai/api/v1/oauth2/token"
16
+ REFRESH_EXPIRY_BUFFER_SECONDS = 300
17
+
18
+ class QwenAuthBase:
19
+ def __init__(self):
20
+ self._credentials_cache: Dict[str, Dict[str, Any]] = {}
21
+ self._refresh_locks: Dict[str, asyncio.Lock] = {}
22
+
23
+ async def _load_credentials(self, path: str) -> Dict[str, Any]:
24
+ if path in self._credentials_cache:
25
+ return self._credentials_cache[path]
26
+
27
+ async with self._get_lock(path):
28
+ if path in self._credentials_cache:
29
+ return self._credentials_cache[path]
30
+ try:
31
+ with open(path, 'r') as f:
32
+ creds = json.load(f)
33
+ self._credentials_cache[path] = creds
34
+ return creds
35
+ except Exception as e:
36
+ raise IOError(f"Failed to load Qwen OAuth credentials from '{path}': {e}")
37
+
38
+ async def _save_credentials(self, path: str, creds: Dict[str, Any]):
39
+ self._credentials_cache[path] = creds
40
+ try:
41
+ with open(path, 'w') as f:
42
+ json.dump(creds, f, indent=2)
43
+ except Exception as e:
44
+ lib_logger.error(f"Failed to save updated Qwen OAuth credentials to '{path}': {e}")
45
+
46
+ def _is_token_expired(self, creds: Dict[str, Any]) -> bool:
47
+ expiry_timestamp = creds.get("expiry_date", 0) / 1000
48
+ return expiry_timestamp < time.time() + REFRESH_EXPIRY_BUFFER_SECONDS
49
+
50
+ async def _refresh_token(self, path: str, force: bool = False) -> Dict[str, Any]:
51
+ async with self._get_lock(path):
52
+ cached_creds = self._credentials_cache.get(path)
53
+ if not force and cached_creds and not self._is_token_expired(cached_creds):
54
+ return cached_creds
55
+
56
+ creds_from_file = await self._load_credentials(path)
57
+
58
+ lib_logger.info(f"Refreshing Qwen OAuth token for '{Path(path).name}'...")
59
+ refresh_token = creds_from_file.get("refresh_token")
60
+ if not refresh_token:
61
+ raise ValueError("No refresh_token found in Qwen credentials file.")
62
+
63
+ async with httpx.AsyncClient() as client:
64
+ response = await client.post(TOKEN_ENDPOINT, data={
65
+ "grant_type": "refresh_token",
66
+ "refresh_token": refresh_token,
67
+ "client_id": CLIENT_ID,
68
+ })
69
+ response.raise_for_status()
70
+ new_token_data = response.json()
71
+
72
+ creds_from_file["access_token"] = new_token_data["access_token"]
73
+ creds_from_file["refresh_token"] = new_token_data.get("refresh_token", creds_from_file["refresh_token"])
74
+ creds_from_file["expiry_date"] = (time.time() + new_token_data["expires_in"]) * 1000
75
+
76
+ await self._save_credentials(path, creds_from_file)
77
+ lib_logger.info(f"Successfully refreshed Qwen OAuth token for '{Path(path).name}'.")
78
+ return creds_from_file
79
+
80
+ async def get_auth_header(self, credential_path: str) -> Dict[str, str]:
81
+ creds = await self._load_credentials(credential_path)
82
+ if self._is_token_expired(creds):
83
+ creds = await self._refresh_token(credential_path)
84
+ return {"Authorization": f"Bearer {creds['access_token']}"}
85
+
86
+ def get_api_details(self, credential_path: str) -> Tuple[str, str]:
87
+ creds = self._credentials_cache[credential_path]
88
+ base_url = creds.get("resource_url", "https://dashscope.aliyuncs.com/compatible-mode/v1")
89
+ if not base_url.startswith("http"):
90
+ base_url = f"https://{base_url}"
91
+ return base_url, creds["access_token"]
92
+
93
+ async def proactively_refresh(self, credential_path: str):
94
+ creds = await self._load_credentials(credential_path)
95
+ if self._is_token_expired(creds):
96
+ await self._refresh_token(credential_path)
97
+
98
+ def _get_lock(self, path: str) -> asyncio.Lock:
99
+ if path not in self._refresh_locks:
100
+ self._refresh_locks[path] = asyncio.Lock()
101
+ return self._refresh_locks[path]
src/rotator_library/providers/qwen_code_provider.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/rotator_library/providers/qwen_code_provider.py
2
+
3
+ import httpx
4
+ import logging
5
+ from typing import Union, AsyncGenerator
6
+ from .provider_interface import ProviderInterface
7
+ from .qwen_auth_base import QwenAuthBase
8
+ import litellm
9
+
10
+ lib_logger = logging.getLogger('rotator_library')
11
+
12
+ # [NEW] Hardcoded model list based on Kilo example
13
+ HARDCODED_MODELS = [
14
+ "qwen3-coder-plus",
15
+ "qwen3-coder-flash"
16
+ ]
17
+
18
+ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
19
+ def has_custom_logic(self) -> bool:
20
+ return True # We use custom logic to handle 401 retries and stream parsing
21
+
22
+ # [NEW] get_models implementation
23
+ async def get_models(self, api_key: str, client: httpx.AsyncClient) -> List[str]:
24
+ """Returns a hardcoded list of known compatible Qwen models for the OpenAI-compatible API."""
25
+ return [f"qwen_code/{model_id}" for model_id in HARDCODED_MODELS]
26
+
27
+ async def _stream_parser(self, stream: AsyncGenerator, model_id: str) -> AsyncGenerator:
28
+ """Parses the stream from litellm to handle Qwen's <think> tags."""
29
+ async for chunk in stream:
30
+ content = chunk.choices[0].delta.content
31
+ if content and ("<think>" in content or "</think>" in content):
32
+ parts = content.replace("<think>", "||THINK||").replace("</think>", "||/THINK||").split("||")
33
+ for part in parts:
34
+ if not part: continue
35
+ new_chunk = chunk.copy()
36
+ if part.startswith("THINK||"):
37
+ new_chunk.choices[0].delta.reasoning_content = part.replace("THINK||", "")
38
+ new_chunk.choices[0].delta.content = None
39
+ elif part.startswith("/THINK||"):
40
+ continue # Ignore closing tag
41
+ else:
42
+ new_chunk.choices[0].delta.content = part
43
+ new_chunk.choices[0].delta.reasoning_content = None
44
+ yield new_chunk
45
+ else:
46
+ yield chunk
47
+
48
+ async def acompletion(self, client: httpx.AsyncClient, **kwargs) -> Union[litellm.ModelResponse, AsyncGenerator[litellm.ModelResponse, None]]:
49
+ credential_path = kwargs.pop("credential_identifier")
50
+ model = kwargs["model"]
51
+
52
+ async def do_call():
53
+ api_base, access_token = self.get_api_details(credential_path)
54
+ response = await litellm.acompletion(
55
+ **kwargs, api_key=access_token, api_base=api_base
56
+ )
57
+ return response
58
+
59
+ try:
60
+ response = await do_call()
61
+ except litellm.AuthenticationError as e:
62
+ if "401" in str(e):
63
+ lib_logger.warning("Qwen Code returned 401. Forcing token refresh and retrying once.")
64
+ await self._refresh_token(credential_path, force=True)
65
+ response = await do_call()
66
+ else:
67
+ raise e
68
+
69
+ if kwargs.get("stream"):
70
+ return self._stream_parser(response, model)
71
+ return response