Spaces:
Paused
Paused
Mirrowel committed on
Commit ·
b1631e5
1
Parent(s): b9d3ae7
feat(api): ✨ adopt OpenAI schema and blacklist patterns for model discovery
Browse files
Expose `/v1/models` in the canonical OpenAI list shape while letting admins hide entire model families via wildcard patterns.
- Map provider/model responses to `ModelCard`/`ModelList` DTOs that match upstream
- Parse `IGNORE_MODELS_<provider>` env vars to drop models at runtime (`gpt-3.5*,claude-*`)
- Strip provider prefixes from IDs for a clean, client-friendly catalog
- Remove `grouped` option; the endpoint now always returns the flattened spec
BREAKING CHANGE: Legacy `{provider: {models: [...]}}` envelope and the `grouped` query parameter are gone. Update clients to expect `{"object":"list","data":[...]}` with bare model IDs.
- src/proxy_app/main.py +30 -8
- src/rotator_library/client.py +42 -7
src/proxy_app/main.py
CHANGED
|
@@ -13,8 +13,9 @@ import colorlog
|
|
| 13 |
from pathlib import Path
|
| 14 |
import sys
|
| 15 |
import json
|
|
|
|
| 16 |
from typing import AsyncGenerator, Any, List, Optional, Union
|
| 17 |
-
from pydantic import BaseModel
|
| 18 |
import argparse
|
| 19 |
import litellm
|
| 20 |
|
|
@@ -27,6 +28,18 @@ class EmbeddingRequest(BaseModel):
|
|
| 27 |
dimensions: Optional[int] = None
|
| 28 |
user: Optional[str] = None
|
| 29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
# --- Argument Parsing ---
|
| 31 |
parser = argparse.ArgumentParser(description="API Key Proxy Server")
|
| 32 |
parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind the server to.")
|
|
@@ -125,12 +138,21 @@ for key, value in os.environ.items():
|
|
| 125 |
if not api_keys:
|
| 126 |
raise ValueError("No provider API keys found in environment variables.")
|
| 127 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
# --- Lifespan Management ---
|
| 129 |
@asynccontextmanager
|
| 130 |
async def lifespan(app: FastAPI):
|
| 131 |
"""Manage the RotatingClient's lifecycle with the app's lifespan."""
|
| 132 |
# The client now uses the root logger configuration
|
| 133 |
-
client = RotatingClient(api_keys=api_keys, configure_logging=True)
|
| 134 |
app.state.rotating_client = client
|
| 135 |
os.environ["LITELLM_LOG"] = "ERROR"
|
| 136 |
litellm.set_verbose = False
|
|
@@ -504,18 +526,18 @@ async def embeddings(
|
|
| 504 |
def read_root():
|
| 505 |
return {"Status": "API Key Proxy is running"}
|
| 506 |
|
| 507 |
-
@app.get("/v1/models")
|
| 508 |
async def list_models(
|
| 509 |
-
grouped: bool = False,
|
| 510 |
client: RotatingClient = Depends(get_rotating_client),
|
| 511 |
_=Depends(verify_api_key)
|
| 512 |
):
|
| 513 |
"""
|
| 514 |
-
Returns a list of available models
|
| 515 |
-
Optionally returns them as a flat list if grouped=False.
|
| 516 |
"""
|
| 517 |
-
|
| 518 |
-
|
|
|
|
|
|
|
| 519 |
|
| 520 |
@app.get("/v1/providers")
|
| 521 |
async def list_providers(_=Depends(verify_api_key)):
|
|
|
|
| 13 |
from pathlib import Path
|
| 14 |
import sys
|
| 15 |
import json
|
| 16 |
+
import time
|
| 17 |
from typing import AsyncGenerator, Any, List, Optional, Union
|
| 18 |
+
from pydantic import BaseModel, Field
|
| 19 |
import argparse
|
| 20 |
import litellm
|
| 21 |
|
|
|
|
| 28 |
dimensions: Optional[int] = None
|
| 29 |
user: Optional[str] = None
|
| 30 |
|
| 31 |
+
|
| 32 |
+
# --- Pydantic Models for Model Endpoints ---
class ModelCard(BaseModel):
    """A single model entry in the OpenAI-compatible `/v1/models` response."""
    id: str  # bare model identifier exposed to clients
    object: str = "model"  # fixed value per the OpenAI schema
    # Unix timestamp, stamped when the card is built — not the model's
    # actual release date (the proxy has no access to that).
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "Mirro-Proxy"  # the proxy presents itself as the owner
|
| 38 |
+
|
| 39 |
+
class ModelList(BaseModel):
    """OpenAI-compatible list envelope: `{"object": "list", "data": [...]}`."""
    object: str = "list"  # fixed value per the OpenAI schema
    data: List[ModelCard]
|
| 42 |
+
|
| 43 |
# --- Argument Parsing ---
|
| 44 |
parser = argparse.ArgumentParser(description="API Key Proxy Server")
|
| 45 |
parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind the server to.")
|
|
|
|
| 138 |
if not api_keys:
|
| 139 |
raise ValueError("No provider API keys found in environment variables.")
|
| 140 |
|
| 141 |
+
# Load model ignore lists from environment variables
|
| 142 |
+
ignore_models = {}
|
| 143 |
+
for key, value in os.environ.items():
|
| 144 |
+
if key.startswith("IGNORE_MODELS_"):
|
| 145 |
+
provider = key.replace("IGNORE_MODELS_", "").lower()
|
| 146 |
+
models_to_ignore = [model.strip() for model in value.split(',')]
|
| 147 |
+
ignore_models[provider] = models_to_ignore
|
| 148 |
+
logging.debug(f"Loaded ignore list for provider '{provider}': {models_to_ignore}")
|
| 149 |
+
|
| 150 |
# --- Lifespan Management ---
|
| 151 |
@asynccontextmanager
|
| 152 |
async def lifespan(app: FastAPI):
|
| 153 |
"""Manage the RotatingClient's lifecycle with the app's lifespan."""
|
| 154 |
# The client now uses the root logger configuration
|
| 155 |
+
client = RotatingClient(api_keys=api_keys, configure_logging=True, ignore_models=ignore_models)
|
| 156 |
app.state.rotating_client = client
|
| 157 |
os.environ["LITELLM_LOG"] = "ERROR"
|
| 158 |
litellm.set_verbose = False
|
|
|
|
| 526 |
def read_root():
|
| 527 |
return {"Status": "API Key Proxy is running"}
|
| 528 |
|
| 529 |
+
@app.get("/v1/models", response_model=ModelList)
async def list_models(
    client: RotatingClient = Depends(get_rotating_client),
    _=Depends(verify_api_key)
):
    """
    List every model the proxy can serve, as an OpenAI-style catalog.

    Fetches the flattened (ungrouped) id list from the rotating client
    and wraps each id in a `ModelCard` inside a `ModelList` envelope.
    """
    available_ids = await client.get_all_available_models(grouped=False)
    return ModelList(data=[ModelCard(id=ident) for ident in available_ids])
|
| 540 |
+
|
| 541 |
|
| 542 |
@app.get("/v1/providers")
|
| 543 |
async def list_providers(_=Depends(verify_api_key)):
|
src/rotator_library/client.py
CHANGED
|
@@ -36,7 +36,16 @@ class RotatingClient:
|
|
| 36 |
A client that intelligently rotates and retries API keys using LiteLLM,
|
| 37 |
with support for both streaming and non-streaming responses.
|
| 38 |
"""
|
| 39 |
-
def __init__(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
os.environ["LITELLM_LOG"] = "ERROR"
|
| 41 |
litellm.set_verbose = False
|
| 42 |
litellm.drop_params = True
|
|
@@ -64,6 +73,27 @@ class RotatingClient:
|
|
| 64 |
self.http_client = httpx.AsyncClient()
|
| 65 |
self.all_providers = AllProviders()
|
| 66 |
self.cooldown_manager = CooldownManager()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
def _sanitize_litellm_log(self, log_data: dict) -> dict:
|
| 69 |
"""
|
|
@@ -800,8 +830,14 @@ class RotatingClient:
|
|
| 800 |
lib_logger.debug(f"Attempting to get models for {provider} with key ...{api_key[-4:]}")
|
| 801 |
models = await provider_instance.get_models(api_key, self.http_client)
|
| 802 |
lib_logger.info(f"Got {len(models)} models for provider: {provider}")
|
| 803 |
-
|
| 804 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 805 |
except Exception as e:
|
| 806 |
classified_error = classify_error(e)
|
| 807 |
lib_logger.debug(f"Failed to get models for provider {provider} with key ...{api_key[-4:]}: {classified_error.error_type}. Trying next key.")
|
|
@@ -829,7 +865,6 @@ class RotatingClient:
|
|
| 829 |
return all_provider_models
|
| 830 |
else:
|
| 831 |
flat_models = []
|
| 832 |
-
for
|
| 833 |
-
|
| 834 |
-
|
| 835 |
-
return flat_models
|
|
|
|
| 36 |
A client that intelligently rotates and retries API keys using LiteLLM,
|
| 37 |
with support for both streaming and non-streaming responses.
|
| 38 |
"""
|
| 39 |
+
def __init__(
|
| 40 |
+
self,
|
| 41 |
+
api_keys: Dict[str, List[str]],
|
| 42 |
+
max_retries: int = 2,
|
| 43 |
+
usage_file_path: str = "key_usage.json",
|
| 44 |
+
configure_logging: bool = True,
|
| 45 |
+
global_timeout: int = 30,
|
| 46 |
+
abort_on_callback_error: bool = True,
|
| 47 |
+
ignore_models: Optional[Dict[str, List[str]]] = None
|
| 48 |
+
):
|
| 49 |
os.environ["LITELLM_LOG"] = "ERROR"
|
| 50 |
litellm.set_verbose = False
|
| 51 |
litellm.drop_params = True
|
|
|
|
| 73 |
self.http_client = httpx.AsyncClient()
|
| 74 |
self.all_providers = AllProviders()
|
| 75 |
self.cooldown_manager = CooldownManager()
|
| 76 |
+
self.ignore_models = ignore_models or {}
|
| 77 |
+
|
| 78 |
+
def _is_model_ignored(self, provider: str, model_id: str) -> bool:
|
| 79 |
+
"""
|
| 80 |
+
Checks if a model should be ignored based on the ignore list.
|
| 81 |
+
Supports exact and partial matching.
|
| 82 |
+
"""
|
| 83 |
+
if provider not in self.ignore_models:
|
| 84 |
+
return False
|
| 85 |
+
|
| 86 |
+
ignore_list = self.ignore_models[provider]
|
| 87 |
+
for ignored_model in ignore_list:
|
| 88 |
+
if ignored_model.endswith('*'):
|
| 89 |
+
# Partial match
|
| 90 |
+
if ignored_model[:-1] in model_id:
|
| 91 |
+
return True
|
| 92 |
+
else:
|
| 93 |
+
# Exact match (ignoring provider prefix)
|
| 94 |
+
if model_id.endswith(ignored_model):
|
| 95 |
+
return True
|
| 96 |
+
return False
|
| 97 |
|
| 98 |
def _sanitize_litellm_log(self, log_data: dict) -> dict:
|
| 99 |
"""
|
|
|
|
| 830 |
lib_logger.debug(f"Attempting to get models for {provider} with key ...{api_key[-4:]}")
|
| 831 |
models = await provider_instance.get_models(api_key, self.http_client)
|
| 832 |
lib_logger.info(f"Got {len(models)} models for provider: {provider}")
|
| 833 |
+
|
| 834 |
+
# Filter models based on the ignore list
|
| 835 |
+
filtered_models = [m for m in models if not self._is_model_ignored(provider, m)]
|
| 836 |
+
if len(filtered_models) != len(models):
|
| 837 |
+
lib_logger.info(f"Filtered out {len(models) - len(filtered_models)} models for provider {provider}.")
|
| 838 |
+
|
| 839 |
+
self._model_list_cache[provider] = filtered_models
|
| 840 |
+
return filtered_models
|
| 841 |
except Exception as e:
|
| 842 |
classified_error = classify_error(e)
|
| 843 |
lib_logger.debug(f"Failed to get models for provider {provider} with key ...{api_key[-4:]}: {classified_error.error_type}. Trying next key.")
|
|
|
|
| 865 |
return all_provider_models
|
| 866 |
else:
|
| 867 |
flat_models = []
|
| 868 |
+
for models in all_provider_models.values():
|
| 869 |
+
flat_models.extend(models)
|
| 870 |
+
return flat_models
|
|
|