Spaces:
Paused
Paused
Mirrowel committed on
Commit ·
79e83ae
1
Parent(s): 3962356
feat(core): enable configuration of maximum concurrent requests per key
Browse files
This introduces functionality to allow multiple concurrent requests to utilize the same API key, which is necessary when a provider's capacity allows for parallel usage (e.g., modern OpenAI tiers).
The `UsageManager` is updated to track concurrent request counts per model per key, moving from a simple busy/idle state to a counter.
- New environment variables (`MAX_CONCURRENT_REQUESTS_PER_KEY_<PROVIDER>`) define the maximum concurrency limit for keys of a specific provider.
- The default limit is 1, maintaining the previous behavior (no concurrency).
- Updates provider endpoint resolution to support loading custom API bases via environment variables (e.g., `CUSTOM_API_BASE`) if the provider is not hardcoded.
- .env.example +15 -0
- src/proxy_app/main.py +17 -1
- src/proxy_app/provider_urls.py +9 -1
- src/rotator_library/client.py +15 -2
- src/rotator_library/usage_manager.py +19 -10
.env.example
CHANGED
|
@@ -139,6 +139,21 @@ IGNORE_MODELS_OPENAI=""
|
|
| 139 |
WHITELIST_MODELS_GEMINI=""
|
| 140 |
WHITELIST_MODELS_OPENAI=""
|
| 141 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
|
| 143 |
# ------------------------------------------------------------------------------
|
| 144 |
# | [ADVANCED] Proxy Configuration |
|
|
|
|
| 139 |
WHITELIST_MODELS_GEMINI=""
|
| 140 |
WHITELIST_MODELS_OPENAI=""
|
| 141 |
|
| 142 |
+
# --- Maximum Concurrent Requests Per Key ---
|
| 143 |
+
# Controls how many concurrent requests for the SAME model can use the SAME key.
|
| 144 |
+
# This is useful for providers that can handle concurrent requests without rate limiting.
|
| 145 |
+
# Default is 1 (no concurrency, current behavior).
|
| 146 |
+
#
|
| 147 |
+
# Format: MAX_CONCURRENT_REQUESTS_PER_KEY_<PROVIDER_NAME>=<number>
|
| 148 |
+
#
|
| 149 |
+
# Example:
|
| 150 |
+
# MAX_CONCURRENT_REQUESTS_PER_KEY_OPENAI=3 # Allow 3 concurrent requests per OpenAI key
|
| 151 |
+
# MAX_CONCURRENT_REQUESTS_PER_KEY_GEMINI=1 # Allow only 1 request per Gemini key (default)
|
| 152 |
+
#
|
| 153 |
+
MAX_CONCURRENT_REQUESTS_PER_KEY_OPENAI=1
|
| 154 |
+
MAX_CONCURRENT_REQUESTS_PER_KEY_GEMINI=1
|
| 155 |
+
MAX_CONCURRENT_REQUESTS_PER_KEY_ANTHROPIC=1
|
| 156 |
+
|
| 157 |
|
| 158 |
# ------------------------------------------------------------------------------
|
| 159 |
# | [ADVANCED] Proxy Configuration |
|
src/proxy_app/main.py
CHANGED
|
@@ -163,6 +163,21 @@ for key, value in os.environ.items():
|
|
| 163 |
whitelist_models[provider] = models_to_whitelist
|
| 164 |
logging.debug(f"Loaded whitelist for provider '{provider}': {models_to_whitelist}")
|
| 165 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
# --- Lifespan Management ---
|
| 167 |
@asynccontextmanager
|
| 168 |
async def lifespan(app: FastAPI):
|
|
@@ -282,7 +297,8 @@ async def lifespan(app: FastAPI):
|
|
| 282 |
litellm_provider_params=litellm_provider_params,
|
| 283 |
ignore_models=ignore_models,
|
| 284 |
whitelist_models=whitelist_models,
|
| 285 |
-
enable_request_logging=ENABLE_REQUEST_LOGGING
|
|
|
|
| 286 |
)
|
| 287 |
client.background_refresher.start() # Start the background task
|
| 288 |
app.state.rotating_client = client
|
|
|
|
| 163 |
whitelist_models[provider] = models_to_whitelist
|
| 164 |
logging.debug(f"Loaded whitelist for provider '{provider}': {models_to_whitelist}")
|
| 165 |
|
| 166 |
+
# Load max concurrent requests per key from environment variables
|
| 167 |
+
max_concurrent_requests_per_key = {}
|
| 168 |
+
for key, value in os.environ.items():
|
| 169 |
+
if key.startswith("MAX_CONCURRENT_REQUESTS_PER_KEY_"):
|
| 170 |
+
provider = key.replace("MAX_CONCURRENT_REQUESTS_PER_KEY_", "").lower()
|
| 171 |
+
try:
|
| 172 |
+
max_concurrent = int(value)
|
| 173 |
+
if max_concurrent < 1:
|
| 174 |
+
logging.warning(f"Invalid max_concurrent value for provider '{provider}': {value}. Must be >= 1. Using default (1).")
|
| 175 |
+
max_concurrent = 1
|
| 176 |
+
max_concurrent_requests_per_key[provider] = max_concurrent
|
| 177 |
+
logging.debug(f"Loaded max concurrent requests for provider '{provider}': {max_concurrent}")
|
| 178 |
+
except ValueError:
|
| 179 |
+
logging.warning(f"Invalid max_concurrent value for provider '{provider}': {value}. Using default (1).")
|
| 180 |
+
|
| 181 |
# --- Lifespan Management ---
|
| 182 |
@asynccontextmanager
|
| 183 |
async def lifespan(app: FastAPI):
|
|
|
|
| 297 |
litellm_provider_params=litellm_provider_params,
|
| 298 |
ignore_models=ignore_models,
|
| 299 |
whitelist_models=whitelist_models,
|
| 300 |
+
enable_request_logging=ENABLE_REQUEST_LOGGING,
|
| 301 |
+
max_concurrent_requests_per_key=max_concurrent_requests_per_key
|
| 302 |
)
|
| 303 |
client.background_refresher.start() # Start the background task
|
| 304 |
app.state.rotating_client = client
|
src/proxy_app/provider_urls.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
from typing import Optional
|
| 2 |
|
| 3 |
# A comprehensive map of provider names to their base URLs.
|
|
@@ -31,10 +32,17 @@ PROVIDER_URL_MAP = {
|
|
| 31 |
def get_provider_endpoint(provider: str, model_name: str, incoming_path: str) -> Optional[str]:
|
| 32 |
"""
|
| 33 |
Constructs the full provider endpoint URL based on the provider and incoming request path.
|
|
|
|
| 34 |
"""
|
|
|
|
| 35 |
base_url = PROVIDER_URL_MAP.get(provider)
|
|
|
|
|
|
|
| 36 |
if not base_url:
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
# Determine the specific action from the incoming path (e.g., 'chat/completions')
|
| 40 |
action = incoming_path.split('/v1/', 1)[-1] if '/v1/' in incoming_path else incoming_path
|
|
|
|
| 1 |
+
import os
|
| 2 |
from typing import Optional
|
| 3 |
|
| 4 |
# A comprehensive map of provider names to their base URLs.
|
|
|
|
| 32 |
def get_provider_endpoint(provider: str, model_name: str, incoming_path: str) -> Optional[str]:
|
| 33 |
"""
|
| 34 |
Constructs the full provider endpoint URL based on the provider and incoming request path.
|
| 35 |
+
Supports both hardcoded providers and custom OpenAI-compatible providers via environment variables.
|
| 36 |
"""
|
| 37 |
+
# First, check the hardcoded map
|
| 38 |
base_url = PROVIDER_URL_MAP.get(provider)
|
| 39 |
+
|
| 40 |
+
# If not found, check for custom provider via environment variable
|
| 41 |
if not base_url:
|
| 42 |
+
api_base_env = f"{provider.upper()}_API_BASE"
|
| 43 |
+
base_url = os.getenv(api_base_env)
|
| 44 |
+
if not base_url:
|
| 45 |
+
return None
|
| 46 |
|
| 47 |
# Determine the specific action from the incoming path (e.g., 'chat/completions')
|
| 48 |
action = incoming_path.split('/v1/', 1)[-1] if '/v1/' in incoming_path else incoming_path
|
src/rotator_library/client.py
CHANGED
|
@@ -61,6 +61,7 @@ class RotatingClient:
|
|
| 61 |
ignore_models: Optional[Dict[str, List[str]]] = None,
|
| 62 |
whitelist_models: Optional[Dict[str, List[str]]] = None,
|
| 63 |
enable_request_logging: bool = False,
|
|
|
|
| 64 |
):
|
| 65 |
os.environ["LITELLM_LOG"] = "ERROR"
|
| 66 |
litellm.set_verbose = False
|
|
@@ -118,6 +119,14 @@ class RotatingClient:
|
|
| 118 |
self.whitelist_models = whitelist_models or {}
|
| 119 |
self.enable_request_logging = enable_request_logging
|
| 120 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
def _is_model_ignored(self, provider: str, model_id: str) -> bool:
|
| 122 |
"""
|
| 123 |
Checks if a model should be ignored based on the ignore list.
|
|
@@ -576,8 +585,10 @@ class RotatingClient:
|
|
| 576 |
lib_logger.info(
|
| 577 |
f"Acquiring key for model {model}. Tried keys: {len(tried_creds)}/{len(credentials_for_provider)}"
|
| 578 |
)
|
|
|
|
| 579 |
current_cred = await self.usage_manager.acquire_key(
|
| 580 |
-
available_keys=creds_to_try, model=model, deadline=deadline
|
|
|
|
| 581 |
)
|
| 582 |
key_acquired = True
|
| 583 |
tried_creds.add(current_cred)
|
|
@@ -918,8 +929,10 @@ class RotatingClient:
|
|
| 918 |
lib_logger.info(
|
| 919 |
f"Acquiring credential for model {model}. Tried credentials: {len(tried_creds)}/{len(credentials_for_provider)}"
|
| 920 |
)
|
|
|
|
| 921 |
current_cred = await self.usage_manager.acquire_key(
|
| 922 |
-
available_keys=creds_to_try, model=model, deadline=deadline
|
|
|
|
| 923 |
)
|
| 924 |
key_acquired = True
|
| 925 |
tried_creds.add(current_cred)
|
|
|
|
| 61 |
ignore_models: Optional[Dict[str, List[str]]] = None,
|
| 62 |
whitelist_models: Optional[Dict[str, List[str]]] = None,
|
| 63 |
enable_request_logging: bool = False,
|
| 64 |
+
max_concurrent_requests_per_key: Optional[Dict[str, int]] = None,
|
| 65 |
):
|
| 66 |
os.environ["LITELLM_LOG"] = "ERROR"
|
| 67 |
litellm.set_verbose = False
|
|
|
|
| 119 |
self.whitelist_models = whitelist_models or {}
|
| 120 |
self.enable_request_logging = enable_request_logging
|
| 121 |
|
| 122 |
+
# Store and validate max concurrent requests per key
|
| 123 |
+
self.max_concurrent_requests_per_key = max_concurrent_requests_per_key or {}
|
| 124 |
+
# Validate all values are >= 1
|
| 125 |
+
for provider, max_val in self.max_concurrent_requests_per_key.items():
|
| 126 |
+
if max_val < 1:
|
| 127 |
+
lib_logger.warning(f"Invalid max_concurrent for '{provider}': {max_val}. Setting to 1.")
|
| 128 |
+
self.max_concurrent_requests_per_key[provider] = 1
|
| 129 |
+
|
| 130 |
def _is_model_ignored(self, provider: str, model_id: str) -> bool:
|
| 131 |
"""
|
| 132 |
Checks if a model should be ignored based on the ignore list.
|
|
|
|
| 585 |
lib_logger.info(
|
| 586 |
f"Acquiring key for model {model}. Tried keys: {len(tried_creds)}/{len(credentials_for_provider)}"
|
| 587 |
)
|
| 588 |
+
max_concurrent = self.max_concurrent_requests_per_key.get(provider, 1)
|
| 589 |
current_cred = await self.usage_manager.acquire_key(
|
| 590 |
+
available_keys=creds_to_try, model=model, deadline=deadline,
|
| 591 |
+
max_concurrent=max_concurrent
|
| 592 |
)
|
| 593 |
key_acquired = True
|
| 594 |
tried_creds.add(current_cred)
|
|
|
|
| 929 |
lib_logger.info(
|
| 930 |
f"Acquiring credential for model {model}. Tried credentials: {len(tried_creds)}/{len(credentials_for_provider)}"
|
| 931 |
)
|
| 932 |
+
max_concurrent = self.max_concurrent_requests_per_key.get(provider, 1)
|
| 933 |
current_cred = await self.usage_manager.acquire_key(
|
| 934 |
+
available_keys=creds_to_try, model=model, deadline=deadline,
|
| 935 |
+
max_concurrent=max_concurrent
|
| 936 |
)
|
| 937 |
key_acquired = True
|
| 938 |
tried_creds.add(current_cred)
|
src/rotator_library/usage_manager.py
CHANGED
|
@@ -157,11 +157,12 @@ class UsageManager:
|
|
| 157 |
self.key_states[key] = {
|
| 158 |
"lock": asyncio.Lock(),
|
| 159 |
"condition": asyncio.Condition(),
|
| 160 |
-
"models_in_use":
|
| 161 |
}
|
| 162 |
|
| 163 |
async def acquire_key(
|
| 164 |
-
self, available_keys: List[str], model: str, deadline: float
|
|
|
|
| 165 |
) -> str:
|
| 166 |
"""
|
| 167 |
Acquires the best available key using a tiered, model-aware locking strategy,
|
|
@@ -198,8 +199,8 @@ class UsageManager:
|
|
| 198 |
# Tier 1: Completely idle keys (preferred).
|
| 199 |
if not key_state["models_in_use"]:
|
| 200 |
tier1_keys.append((key, usage_count))
|
| 201 |
-
# Tier 2: Keys
|
| 202 |
-
elif
|
| 203 |
tier2_keys.append((key, usage_count))
|
| 204 |
|
| 205 |
tier1_keys.sort(key=lambda x: x[1])
|
|
@@ -210,7 +211,7 @@ class UsageManager:
|
|
| 210 |
state = self.key_states[key]
|
| 211 |
async with state["lock"]:
|
| 212 |
if not state["models_in_use"]:
|
| 213 |
-
state["models_in_use"]
|
| 214 |
lib_logger.info(
|
| 215 |
f"Acquired Tier 1 key ...{key[-6:]} for model {model}"
|
| 216 |
)
|
|
@@ -220,10 +221,12 @@ class UsageManager:
|
|
| 220 |
for key, _ in tier2_keys:
|
| 221 |
state = self.key_states[key]
|
| 222 |
async with state["lock"]:
|
| 223 |
-
|
| 224 |
-
|
|
|
|
| 225 |
lib_logger.info(
|
| 226 |
-
f"Acquired Tier 2 key ...{key[-6:]} for model {model}"
|
|
|
|
| 227 |
)
|
| 228 |
return key
|
| 229 |
|
|
@@ -271,8 +274,14 @@ class UsageManager:
|
|
| 271 |
state = self.key_states[key]
|
| 272 |
async with state["lock"]:
|
| 273 |
if model in state["models_in_use"]:
|
| 274 |
-
state["models_in_use"]
|
| 275 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 276 |
else:
|
| 277 |
lib_logger.warning(
|
| 278 |
f"Attempted to release credential ...{key[-6:]} for model {model}, but it was not in use."
|
|
|
|
| 157 |
self.key_states[key] = {
|
| 158 |
"lock": asyncio.Lock(),
|
| 159 |
"condition": asyncio.Condition(),
|
| 160 |
+
"models_in_use": {}, # Dict[model_name, concurrent_count]
|
| 161 |
}
|
| 162 |
|
| 163 |
async def acquire_key(
|
| 164 |
+
self, available_keys: List[str], model: str, deadline: float,
|
| 165 |
+
max_concurrent: int = 1
|
| 166 |
) -> str:
|
| 167 |
"""
|
| 168 |
Acquires the best available key using a tiered, model-aware locking strategy,
|
|
|
|
| 199 |
# Tier 1: Completely idle keys (preferred).
|
| 200 |
if not key_state["models_in_use"]:
|
| 201 |
tier1_keys.append((key, usage_count))
|
| 202 |
+
# Tier 2: Keys that can accept more concurrent requests for this model.
|
| 203 |
+
elif key_state["models_in_use"].get(model, 0) < max_concurrent:
|
| 204 |
tier2_keys.append((key, usage_count))
|
| 205 |
|
| 206 |
tier1_keys.sort(key=lambda x: x[1])
|
|
|
|
| 211 |
state = self.key_states[key]
|
| 212 |
async with state["lock"]:
|
| 213 |
if not state["models_in_use"]:
|
| 214 |
+
state["models_in_use"][model] = 1
|
| 215 |
lib_logger.info(
|
| 216 |
f"Acquired Tier 1 key ...{key[-6:]} for model {model}"
|
| 217 |
)
|
|
|
|
| 221 |
for key, _ in tier2_keys:
|
| 222 |
state = self.key_states[key]
|
| 223 |
async with state["lock"]:
|
| 224 |
+
current_count = state["models_in_use"].get(model, 0)
|
| 225 |
+
if current_count < max_concurrent:
|
| 226 |
+
state["models_in_use"][model] = current_count + 1
|
| 227 |
lib_logger.info(
|
| 228 |
+
f"Acquired Tier 2 key ...{key[-6:]} for model {model} "
|
| 229 |
+
f"(concurrent: {state['models_in_use'][model]}/{max_concurrent})"
|
| 230 |
)
|
| 231 |
return key
|
| 232 |
|
|
|
|
| 274 |
state = self.key_states[key]
|
| 275 |
async with state["lock"]:
|
| 276 |
if model in state["models_in_use"]:
|
| 277 |
+
state["models_in_use"][model] -= 1
|
| 278 |
+
remaining = state["models_in_use"][model]
|
| 279 |
+
if remaining <= 0:
|
| 280 |
+
del state["models_in_use"][model] # Clean up when count reaches 0
|
| 281 |
+
lib_logger.info(
|
| 282 |
+
f"Released credential ...{key[-6:]} from model {model} "
|
| 283 |
+
f"(remaining concurrent: {max(0, remaining)})"
|
| 284 |
+
)
|
| 285 |
else:
|
| 286 |
lib_logger.warning(
|
| 287 |
f"Attempted to release credential ...{key[-6:]} for model {model}, but it was not in use."
|