Mirrowel committed
Commit b1631e5 · 1 Parent(s): b9d3ae7

feat(api): ✨ adopt OpenAI schema and blacklist patterns for model discovery


Expose `/v1/models` in the canonical OpenAI list shape while letting admins hide entire model families via wildcard patterns.

- Map provider/model responses to `ModelCard`/`ModelList` DTOs that match upstream
- Parse `IGNORE_MODELS_<provider>` env vars at startup to drop matching models (e.g. `gpt-3.5*,claude-*`)
- Strip provider prefixes from IDs for a clean, client-friendly catalog
- Remove `grouped` option; the endpoint now always returns the flattened spec

BREAKING CHANGE: Legacy `{provider: {models: [...]}}` envelope and the `grouped` query parameter are gone. Update clients to expect `{"object":"list","data":[...]}` with bare model IDs.
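For clients migrating off the legacy envelope, a minimal sketch of reading the new shape; the base URL, API key, and `Authorization` header below are assumptions about a local deployment, not part of this commit:

```python
import httpx

# Assumed local deployment details; substitute your own proxy URL and key.
BASE_URL = "http://localhost:8000"
API_KEY = "your-proxy-key"

resp = httpx.get(
    f"{BASE_URL}/v1/models",
    headers={"Authorization": f"Bearer {API_KEY}"},  # auth header scheme is an assumption
)
resp.raise_for_status()
payload = resp.json()

# New response: {"object": "list", "data": [{"id": "...", "object": "model", ...}]}
model_ids = [card["id"] for card in payload["data"]]
print(model_ids)
```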

src/proxy_app/main.py CHANGED
@@ -13,8 +13,9 @@ import colorlog
 from pathlib import Path
 import sys
 import json
+import time
 from typing import AsyncGenerator, Any, List, Optional, Union
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 import argparse
 import litellm
 
@@ -27,6 +28,18 @@ class EmbeddingRequest(BaseModel):
     dimensions: Optional[int] = None
     user: Optional[str] = None
 
+
+# --- Pydantic Models for Model Endpoints ---
+class ModelCard(BaseModel):
+    id: str
+    object: str = "model"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    owned_by: str = "Mirro-Proxy"
+
+class ModelList(BaseModel):
+    object: str = "list"
+    data: List[ModelCard]
+
 # --- Argument Parsing ---
 parser = argparse.ArgumentParser(description="API Key Proxy Server")
 parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind the server to.")
@@ -125,12 +138,21 @@ for key, value in os.environ.items():
 if not api_keys:
     raise ValueError("No provider API keys found in environment variables.")
 
+# Load model ignore lists from environment variables
+ignore_models = {}
+for key, value in os.environ.items():
+    if key.startswith("IGNORE_MODELS_"):
+        provider = key.replace("IGNORE_MODELS_", "").lower()
+        models_to_ignore = [model.strip() for model in value.split(',')]
+        ignore_models[provider] = models_to_ignore
+        logging.debug(f"Loaded ignore list for provider '{provider}': {models_to_ignore}")
+
 # --- Lifespan Management ---
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     """Manage the RotatingClient's lifecycle with the app's lifespan."""
     # The client now uses the root logger configuration
-    client = RotatingClient(api_keys=api_keys, configure_logging=True)
+    client = RotatingClient(api_keys=api_keys, configure_logging=True, ignore_models=ignore_models)
     app.state.rotating_client = client
     os.environ["LITELLM_LOG"] = "ERROR"
     litellm.set_verbose = False
@@ -504,18 +526,18 @@ async def embeddings(
 def read_root():
     return {"Status": "API Key Proxy is running"}
 
-@app.get("/v1/models")
+@app.get("/v1/models", response_model=ModelList)
 async def list_models(
-    grouped: bool = False,
     client: RotatingClient = Depends(get_rotating_client),
    _=Depends(verify_api_key)
 ):
     """
-    Returns a list of available models from all configured providers.
-    Optionally returns them as a flat list if grouped=False.
+    Returns a list of available models in the OpenAI-compatible format.
     """
-    models = await client.get_all_available_models(grouped=grouped)
-    return models
+    model_ids = await client.get_all_available_models(grouped=False)
+    model_cards = [ModelCard(id=model_id) for model_id in model_ids]
+    return ModelList(data=model_cards)
+
 
 @app.get("/v1/providers")
 async def list_providers(_=Depends(verify_api_key)):
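The ignore lists are read once at startup by the loop above; a hedged sketch of what two illustrative variables (names and patterns invented for the example) would produce:

```python
import os

# Illustrative env vars; any provider suffix the proxy recognises can be used.
os.environ["IGNORE_MODELS_OPENAI"] = "gpt-3.5*, babbage-002"
os.environ["IGNORE_MODELS_ANTHROPIC"] = "claude-2*"

ignore_models = {}
for key, value in os.environ.items():
    if key.startswith("IGNORE_MODELS_"):
        provider = key.replace("IGNORE_MODELS_", "").lower()
        ignore_models[provider] = [model.strip() for model in value.split(",")]

# ignore_models == {"openai": ["gpt-3.5*", "babbage-002"], "anthropic": ["claude-2*"]}
```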
src/rotator_library/client.py CHANGED
@@ -36,7 +36,16 @@ class RotatingClient:
     A client that intelligently rotates and retries API keys using LiteLLM,
     with support for both streaming and non-streaming responses.
     """
-    def __init__(self, api_keys: Dict[str, List[str]], max_retries: int = 2, usage_file_path: str = "key_usage.json", configure_logging: bool = True, global_timeout: int = 30, abort_on_callback_error: bool = True):
+    def __init__(
+        self,
+        api_keys: Dict[str, List[str]],
+        max_retries: int = 2,
+        usage_file_path: str = "key_usage.json",
+        configure_logging: bool = True,
+        global_timeout: int = 30,
+        abort_on_callback_error: bool = True,
+        ignore_models: Optional[Dict[str, List[str]]] = None
+    ):
         os.environ["LITELLM_LOG"] = "ERROR"
         litellm.set_verbose = False
         litellm.drop_params = True
@@ -64,6 +73,27 @@ class RotatingClient:
         self.http_client = httpx.AsyncClient()
         self.all_providers = AllProviders()
         self.cooldown_manager = CooldownManager()
+        self.ignore_models = ignore_models or {}
+
+    def _is_model_ignored(self, provider: str, model_id: str) -> bool:
+        """
+        Checks if a model should be ignored based on the ignore list.
+        Supports exact and partial matching.
+        """
+        if provider not in self.ignore_models:
+            return False
+
+        ignore_list = self.ignore_models[provider]
+        for ignored_model in ignore_list:
+            if ignored_model.endswith('*'):
+                # Partial match
+                if ignored_model[:-1] in model_id:
+                    return True
+            else:
+                # Exact match (ignoring provider prefix)
+                if model_id.endswith(ignored_model):
+                    return True
+        return False
 
     def _sanitize_litellm_log(self, log_data: dict) -> dict:
         """
@@ -800,8 +830,14 @@ class RotatingClient:
                 lib_logger.debug(f"Attempting to get models for {provider} with key ...{api_key[-4:]}")
                 models = await provider_instance.get_models(api_key, self.http_client)
                 lib_logger.info(f"Got {len(models)} models for provider: {provider}")
-                self._model_list_cache[provider] = models
-                return models
+
+                # Filter models based on the ignore list
+                filtered_models = [m for m in models if not self._is_model_ignored(provider, m)]
+                if len(filtered_models) != len(models):
+                    lib_logger.info(f"Filtered out {len(models) - len(filtered_models)} models for provider {provider}.")
+
+                self._model_list_cache[provider] = filtered_models
+                return filtered_models
             except Exception as e:
                 classified_error = classify_error(e)
                 lib_logger.debug(f"Failed to get models for provider {provider} with key ...{api_key[-4:]}: {classified_error.error_type}. Trying next key.")
@@ -829,7 +865,6 @@ class RotatingClient:
             return all_provider_models
         else:
            flat_models = []
-            for provider, models in all_provider_models.items():
-                for model in models:
-                    flat_models.append(f"{provider}/{model}")
-            return flat_models
+            for models in all_provider_models.values():
+                flat_models.extend(models)
+            return flat_models
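To make the blacklist semantics concrete, here is a standalone restatement of the `_is_model_ignored` rules against made-up model IDs; it mirrors the logic in the diff rather than importing the library:

```python
from typing import List

def is_ignored(ignore_list: List[str], model_id: str) -> bool:
    """Mirror of _is_model_ignored's matching rules, for illustration only."""
    for ignored_model in ignore_list:
        if ignored_model.endswith("*"):
            # Wildcard entry: drop any ID that contains the prefix.
            if ignored_model[:-1] in model_id:
                return True
        elif model_id.endswith(ignored_model):
            # Plain entry: match the tail of the ID, so a "provider/" prefix is ignored.
            return True
    return False

patterns = ["gpt-3.5*", "babbage-002"]
assert is_ignored(patterns, "openai/gpt-3.5-turbo")    # wildcard hit
assert is_ignored(patterns, "openai/babbage-002")      # suffix hit despite provider prefix
assert not is_ignored(patterns, "openai/gpt-4o")       # kept
```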