Mirrowel committed on
Commit
79e83ae
·
1 Parent(s): 3962356

feat(core): enable configuration of maximum concurrent requests per key

Browse files

This introduces functionality to allow multiple concurrent requests to utilize the same API key, which is necessary when a provider's capacity allows for parallel usage (e.g., modern OpenAI tiers).

The `UsageManager` is updated to track concurrent request counts per model per key, moving from a simple busy/idle state to a counter.

- New environment variables (`MAX_CONCURRENT_REQUESTS_PER_KEY_<PROVIDER>`) define the maximum concurrency limit for keys of a specific provider.
- The default limit is 1, maintaining the previous behavior (no concurrency).
- Updates provider endpoint resolution to support loading custom API bases via environment variables (e.g., `<PROVIDER>_API_BASE`) if the provider is not hardcoded.

.env.example CHANGED
@@ -139,6 +139,21 @@ IGNORE_MODELS_OPENAI=""
139
  WHITELIST_MODELS_GEMINI=""
140
  WHITELIST_MODELS_OPENAI=""
141
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
  # ------------------------------------------------------------------------------
144
  # | [ADVANCED] Proxy Configuration |
 
139
  WHITELIST_MODELS_GEMINI=""
140
  WHITELIST_MODELS_OPENAI=""
141
 
142
+ # --- Maximum Concurrent Requests Per Key ---
143
+ # Controls how many concurrent requests for the SAME model can use the SAME key.
144
+ # This is useful for providers that can handle concurrent requests without rate limiting.
145
+ # Default is 1 (no concurrency, current behavior).
146
+ #
147
+ # Format: MAX_CONCURRENT_REQUESTS_PER_KEY_<PROVIDER_NAME>=<number>
148
+ #
149
+ # Example:
150
+ # MAX_CONCURRENT_REQUESTS_PER_KEY_OPENAI=3 # Allow 3 concurrent requests per OpenAI key
151
+ # MAX_CONCURRENT_REQUESTS_PER_KEY_GEMINI=1 # Allow only 1 request per Gemini key (default)
152
+ #
153
+ MAX_CONCURRENT_REQUESTS_PER_KEY_OPENAI=1
154
+ MAX_CONCURRENT_REQUESTS_PER_KEY_GEMINI=1
155
+ MAX_CONCURRENT_REQUESTS_PER_KEY_ANTHROPIC=1
156
+
157
 
158
  # ------------------------------------------------------------------------------
159
  # | [ADVANCED] Proxy Configuration |
src/proxy_app/main.py CHANGED
@@ -163,6 +163,21 @@ for key, value in os.environ.items():
163
  whitelist_models[provider] = models_to_whitelist
164
  logging.debug(f"Loaded whitelist for provider '{provider}': {models_to_whitelist}")
165
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
  # --- Lifespan Management ---
167
  @asynccontextmanager
168
  async def lifespan(app: FastAPI):
@@ -282,7 +297,8 @@ async def lifespan(app: FastAPI):
282
  litellm_provider_params=litellm_provider_params,
283
  ignore_models=ignore_models,
284
  whitelist_models=whitelist_models,
285
- enable_request_logging=ENABLE_REQUEST_LOGGING
 
286
  )
287
  client.background_refresher.start() # Start the background task
288
  app.state.rotating_client = client
 
163
  whitelist_models[provider] = models_to_whitelist
164
  logging.debug(f"Loaded whitelist for provider '{provider}': {models_to_whitelist}")
165
 
166
+ # Load max concurrent requests per key from environment variables
167
+ max_concurrent_requests_per_key = {}
168
+ for key, value in os.environ.items():
169
+ if key.startswith("MAX_CONCURRENT_REQUESTS_PER_KEY_"):
170
+ provider = key.replace("MAX_CONCURRENT_REQUESTS_PER_KEY_", "").lower()
171
+ try:
172
+ max_concurrent = int(value)
173
+ if max_concurrent < 1:
174
+ logging.warning(f"Invalid max_concurrent value for provider '{provider}': {value}. Must be >= 1. Using default (1).")
175
+ max_concurrent = 1
176
+ max_concurrent_requests_per_key[provider] = max_concurrent
177
+ logging.debug(f"Loaded max concurrent requests for provider '{provider}': {max_concurrent}")
178
+ except ValueError:
179
+ logging.warning(f"Invalid max_concurrent value for provider '{provider}': {value}. Using default (1).")
180
+
181
  # --- Lifespan Management ---
182
  @asynccontextmanager
183
  async def lifespan(app: FastAPI):
 
297
  litellm_provider_params=litellm_provider_params,
298
  ignore_models=ignore_models,
299
  whitelist_models=whitelist_models,
300
+ enable_request_logging=ENABLE_REQUEST_LOGGING,
301
+ max_concurrent_requests_per_key=max_concurrent_requests_per_key
302
  )
303
  client.background_refresher.start() # Start the background task
304
  app.state.rotating_client = client
src/proxy_app/provider_urls.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from typing import Optional
2
 
3
  # A comprehensive map of provider names to their base URLs.
@@ -31,10 +32,17 @@ PROVIDER_URL_MAP = {
31
  def get_provider_endpoint(provider: str, model_name: str, incoming_path: str) -> Optional[str]:
32
  """
33
  Constructs the full provider endpoint URL based on the provider and incoming request path.
 
34
  """
 
35
  base_url = PROVIDER_URL_MAP.get(provider)
 
 
36
  if not base_url:
37
- return None
 
 
 
38
 
39
  # Determine the specific action from the incoming path (e.g., 'chat/completions')
40
  action = incoming_path.split('/v1/', 1)[-1] if '/v1/' in incoming_path else incoming_path
 
1
+ import os
2
  from typing import Optional
3
 
4
  # A comprehensive map of provider names to their base URLs.
 
32
  def get_provider_endpoint(provider: str, model_name: str, incoming_path: str) -> Optional[str]:
33
  """
34
  Constructs the full provider endpoint URL based on the provider and incoming request path.
35
+ Supports both hardcoded providers and custom OpenAI-compatible providers via environment variables.
36
  """
37
+ # First, check the hardcoded map
38
  base_url = PROVIDER_URL_MAP.get(provider)
39
+
40
+ # If not found, check for custom provider via environment variable
41
  if not base_url:
42
+ api_base_env = f"{provider.upper()}_API_BASE"
43
+ base_url = os.getenv(api_base_env)
44
+ if not base_url:
45
+ return None
46
 
47
  # Determine the specific action from the incoming path (e.g., 'chat/completions')
48
  action = incoming_path.split('/v1/', 1)[-1] if '/v1/' in incoming_path else incoming_path
src/rotator_library/client.py CHANGED
@@ -61,6 +61,7 @@ class RotatingClient:
61
  ignore_models: Optional[Dict[str, List[str]]] = None,
62
  whitelist_models: Optional[Dict[str, List[str]]] = None,
63
  enable_request_logging: bool = False,
 
64
  ):
65
  os.environ["LITELLM_LOG"] = "ERROR"
66
  litellm.set_verbose = False
@@ -118,6 +119,14 @@ class RotatingClient:
118
  self.whitelist_models = whitelist_models or {}
119
  self.enable_request_logging = enable_request_logging
120
 
 
 
 
 
 
 
 
 
121
  def _is_model_ignored(self, provider: str, model_id: str) -> bool:
122
  """
123
  Checks if a model should be ignored based on the ignore list.
@@ -576,8 +585,10 @@ class RotatingClient:
576
  lib_logger.info(
577
  f"Acquiring key for model {model}. Tried keys: {len(tried_creds)}/{len(credentials_for_provider)}"
578
  )
 
579
  current_cred = await self.usage_manager.acquire_key(
580
- available_keys=creds_to_try, model=model, deadline=deadline
 
581
  )
582
  key_acquired = True
583
  tried_creds.add(current_cred)
@@ -918,8 +929,10 @@ class RotatingClient:
918
  lib_logger.info(
919
  f"Acquiring credential for model {model}. Tried credentials: {len(tried_creds)}/{len(credentials_for_provider)}"
920
  )
 
921
  current_cred = await self.usage_manager.acquire_key(
922
- available_keys=creds_to_try, model=model, deadline=deadline
 
923
  )
924
  key_acquired = True
925
  tried_creds.add(current_cred)
 
61
  ignore_models: Optional[Dict[str, List[str]]] = None,
62
  whitelist_models: Optional[Dict[str, List[str]]] = None,
63
  enable_request_logging: bool = False,
64
+ max_concurrent_requests_per_key: Optional[Dict[str, int]] = None,
65
  ):
66
  os.environ["LITELLM_LOG"] = "ERROR"
67
  litellm.set_verbose = False
 
119
  self.whitelist_models = whitelist_models or {}
120
  self.enable_request_logging = enable_request_logging
121
 
122
+ # Store and validate max concurrent requests per key
123
+ self.max_concurrent_requests_per_key = max_concurrent_requests_per_key or {}
124
+ # Validate all values are >= 1
125
+ for provider, max_val in self.max_concurrent_requests_per_key.items():
126
+ if max_val < 1:
127
+ lib_logger.warning(f"Invalid max_concurrent for '{provider}': {max_val}. Setting to 1.")
128
+ self.max_concurrent_requests_per_key[provider] = 1
129
+
130
  def _is_model_ignored(self, provider: str, model_id: str) -> bool:
131
  """
132
  Checks if a model should be ignored based on the ignore list.
 
585
  lib_logger.info(
586
  f"Acquiring key for model {model}. Tried keys: {len(tried_creds)}/{len(credentials_for_provider)}"
587
  )
588
+ max_concurrent = self.max_concurrent_requests_per_key.get(provider, 1)
589
  current_cred = await self.usage_manager.acquire_key(
590
+ available_keys=creds_to_try, model=model, deadline=deadline,
591
+ max_concurrent=max_concurrent
592
  )
593
  key_acquired = True
594
  tried_creds.add(current_cred)
 
929
  lib_logger.info(
930
  f"Acquiring credential for model {model}. Tried credentials: {len(tried_creds)}/{len(credentials_for_provider)}"
931
  )
932
+ max_concurrent = self.max_concurrent_requests_per_key.get(provider, 1)
933
  current_cred = await self.usage_manager.acquire_key(
934
+ available_keys=creds_to_try, model=model, deadline=deadline,
935
+ max_concurrent=max_concurrent
936
  )
937
  key_acquired = True
938
  tried_creds.add(current_cred)
src/rotator_library/usage_manager.py CHANGED
@@ -157,11 +157,12 @@ class UsageManager:
157
  self.key_states[key] = {
158
  "lock": asyncio.Lock(),
159
  "condition": asyncio.Condition(),
160
- "models_in_use": set(),
161
  }
162
 
163
  async def acquire_key(
164
- self, available_keys: List[str], model: str, deadline: float
 
165
  ) -> str:
166
  """
167
  Acquires the best available key using a tiered, model-aware locking strategy,
@@ -198,8 +199,8 @@ class UsageManager:
198
  # Tier 1: Completely idle keys (preferred).
199
  if not key_state["models_in_use"]:
200
  tier1_keys.append((key, usage_count))
201
- # Tier 2: Keys busy with other models, but free for this one.
202
- elif model not in key_state["models_in_use"]:
203
  tier2_keys.append((key, usage_count))
204
 
205
  tier1_keys.sort(key=lambda x: x[1])
@@ -210,7 +211,7 @@ class UsageManager:
210
  state = self.key_states[key]
211
  async with state["lock"]:
212
  if not state["models_in_use"]:
213
- state["models_in_use"].add(model)
214
  lib_logger.info(
215
  f"Acquired Tier 1 key ...{key[-6:]} for model {model}"
216
  )
@@ -220,10 +221,12 @@ class UsageManager:
220
  for key, _ in tier2_keys:
221
  state = self.key_states[key]
222
  async with state["lock"]:
223
- if model not in state["models_in_use"]:
224
- state["models_in_use"].add(model)
 
225
  lib_logger.info(
226
- f"Acquired Tier 2 key ...{key[-6:]} for model {model}"
 
227
  )
228
  return key
229
 
@@ -271,8 +274,14 @@ class UsageManager:
271
  state = self.key_states[key]
272
  async with state["lock"]:
273
  if model in state["models_in_use"]:
274
- state["models_in_use"].remove(model)
275
- lib_logger.info(f"Released credential ...{key[-6:]} from model {model}")
 
 
 
 
 
 
276
  else:
277
  lib_logger.warning(
278
  f"Attempted to release credential ...{key[-6:]} for model {model}, but it was not in use."
 
157
  self.key_states[key] = {
158
  "lock": asyncio.Lock(),
159
  "condition": asyncio.Condition(),
160
+ "models_in_use": {}, # Dict[model_name, concurrent_count]
161
  }
162
 
163
  async def acquire_key(
164
+ self, available_keys: List[str], model: str, deadline: float,
165
+ max_concurrent: int = 1
166
  ) -> str:
167
  """
168
  Acquires the best available key using a tiered, model-aware locking strategy,
 
199
  # Tier 1: Completely idle keys (preferred).
200
  if not key_state["models_in_use"]:
201
  tier1_keys.append((key, usage_count))
202
+ # Tier 2: Keys that can accept more concurrent requests for this model.
203
+ elif key_state["models_in_use"].get(model, 0) < max_concurrent:
204
  tier2_keys.append((key, usage_count))
205
 
206
  tier1_keys.sort(key=lambda x: x[1])
 
211
  state = self.key_states[key]
212
  async with state["lock"]:
213
  if not state["models_in_use"]:
214
+ state["models_in_use"][model] = 1
215
  lib_logger.info(
216
  f"Acquired Tier 1 key ...{key[-6:]} for model {model}"
217
  )
 
221
  for key, _ in tier2_keys:
222
  state = self.key_states[key]
223
  async with state["lock"]:
224
+ current_count = state["models_in_use"].get(model, 0)
225
+ if current_count < max_concurrent:
226
+ state["models_in_use"][model] = current_count + 1
227
  lib_logger.info(
228
+ f"Acquired Tier 2 key ...{key[-6:]} for model {model} "
229
+ f"(concurrent: {state['models_in_use'][model]}/{max_concurrent})"
230
  )
231
  return key
232
 
 
274
  state = self.key_states[key]
275
  async with state["lock"]:
276
  if model in state["models_in_use"]:
277
+ state["models_in_use"][model] -= 1
278
+ remaining = state["models_in_use"][model]
279
+ if remaining <= 0:
280
+ del state["models_in_use"][model] # Clean up when count reaches 0
281
+ lib_logger.info(
282
+ f"Released credential ...{key[-6:]} from model {model} "
283
+ f"(remaining concurrent: {max(0, remaining)})"
284
+ )
285
  else:
286
  lib_logger.warning(
287
  f"Attempted to release credential ...{key[-6:]} for model {model}, but it was not in use."