Mirrowel commited on
Commit
26c6a6e
·
1 Parent(s): fd71e0a

feat: Exclude PROXY_API_KEY from provider API key loading and enhance model retrieval options

Browse files
src/proxy_app/main.py CHANGED
@@ -26,7 +26,8 @@ if not PROXY_API_KEY:
26
  # Load all provider API keys from environment variables
27
  api_keys = {}
28
  for key, value in os.environ.items():
29
- if key.endswith("_API_KEY") or "_API_KEY_" in key:
 
30
  parts = key.split("_API_KEY")
31
  provider = parts[0].lower()
32
  if provider not in api_keys:
@@ -75,19 +76,20 @@ def read_root():
75
  return {"Status": "API Key Proxy is running"}
76
 
77
  @app.get("/v1/models")
78
- async def list_models(_=Depends(verify_api_key)):
79
  """
80
  Returns a list of available models from all configured providers.
 
81
  """
82
- models = await rotating_client.get_all_available_models()
83
- return {"data": models}
84
 
85
  @app.get("/v1/providers")
86
  async def list_providers(_=Depends(verify_api_key)):
87
  """
88
  Returns a list of all available providers.
89
  """
90
- return {"data": list(PROVIDER_PLUGINS.keys())}
91
 
92
  @app.post("/v1/token-count")
93
  async def token_count(request: Request, _=Depends(verify_api_key)):
 
26
  # Load all provider API keys from environment variables
27
  api_keys = {}
28
  for key, value in os.environ.items():
29
+ # Exclude PROXY_API_KEY from being treated as a provider API key
30
+ if (key.endswith("_API_KEY") or "_API_KEY_" in key) and key != "PROXY_API_KEY":
31
  parts = key.split("_API_KEY")
32
  provider = parts[0].lower()
33
  if provider not in api_keys:
 
76
  return {"Status": "API Key Proxy is running"}
77
 
78
  @app.get("/v1/models")
79
+ async def list_models(grouped: bool = False, _=Depends(verify_api_key)):
80
  """
81
  Returns a list of available models from all configured providers.
82
+ Optionally returns them as a flat list if grouped=False.
83
  """
84
+ models = await rotating_client.get_all_available_models(grouped=grouped)
85
+ return models
86
 
87
  @app.get("/v1/providers")
88
  async def list_providers(_=Depends(verify_api_key)):
89
  """
90
  Returns a list of all available providers.
91
  """
92
+ return list(PROVIDER_PLUGINS.keys())
93
 
94
  @app.post("/v1/token-count")
95
  async def token_count(request: Request, _=Depends(verify_api_key)):
src/rotator_library/client.py CHANGED
@@ -145,11 +145,19 @@ class RotatingClient:
145
  logging.warning(f"Model list fetching not implemented for provider: {provider}")
146
  return []
147
 
148
- async def get_all_available_models(self) -> Dict[str, List[str]]:
149
  """
150
- Returns a dictionary of all available models, grouped by provider.
151
  """
152
  all_provider_models = {}
153
  for provider in self.api_keys.keys():
154
  all_provider_models[provider] = await self.get_available_models(provider)
155
- return all_provider_models
 
 
 
 
 
 
 
 
 
145
  logging.warning(f"Model list fetching not implemented for provider: {provider}")
146
  return []
147
 
148
+ async def get_all_available_models(self, grouped: bool = True) -> Any:
149
  """
150
+ Returns a list of all available models, either grouped by provider or as a flat list.
151
  """
152
  all_provider_models = {}
153
  for provider in self.api_keys.keys():
154
  all_provider_models[provider] = await self.get_available_models(provider)
155
+
156
+ if grouped:
157
+ return all_provider_models
158
+ else:
159
+ flat_models = []
160
+ for provider, models in all_provider_models.items():
161
+ for model in models:
162
+ flat_models.append(f"{model}")
163
+ return flat_models