Spaces:
Paused
feat(multi-provider): Implement dynamic API key loading and new endpoints
Refactor API key loading in `main.py` to dynamically support multiple providers from environment variables (e.g., `OPENAI_API_KEY`, `ANTHROPIC_API_KEY_1`).
Introduce a pluggable provider system in `src/rotator_library/providers` with an abstract `ProviderInterface` and concrete implementations for:
- Anthropic
- AWS Bedrock
- Cohere
- Google Gemini
- Groq
- Mistral
- OpenAI
Enhance `RotatingClient` to:
- Accept a dictionary of API keys, grouped by provider.
- Dynamically fetch available models from integrated providers with caching.
- Calculate token counts using `litellm.token_counter`.
Add new API endpoints to `main.py`:
- `GET /v1/models`: Lists all available models across configured providers.
- `GET /v1/providers`: Lists all integrated providers.
- `POST /v1/token-count`: Calculates token usage for given messages or text.
Update `UsageManager` to:
- Record approximate costs for completions.
- Utilize `litellm.ModelResponse` for more comprehensive usage tracking.
- Streamline `rotator_library/__init__.py` exports.
- src/proxy_app/main.py +50 -21
- src/rotator_library/__init__.py +2 -15
- src/rotator_library/client.py +56 -11
- src/rotator_library/providers/__init__.py +35 -0
- src/rotator_library/providers/anthropic_provider.py +26 -0
- src/rotator_library/providers/bedrock_provider.py +23 -0
- src/rotator_library/providers/cohere_provider.py +23 -0
- src/rotator_library/providers/gemini_provider.py +23 -0
- src/rotator_library/providers/groq_provider.py +23 -0
- src/rotator_library/providers/mistral_provider.py +23 -0
- src/rotator_library/providers/openai_provider.py +23 -0
- src/rotator_library/providers/provider_interface.py +21 -0
- src/rotator_library/usage_manager.py +17 -5
|
@@ -10,10 +10,10 @@ import sys
|
|
| 10 |
# Add the 'src' directory to the Python path to allow importing 'rotating_api_key_client'
|
| 11 |
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
| 12 |
|
| 13 |
-
from rotator_library import RotatingClient
|
| 14 |
|
| 15 |
# Configure logging
|
| 16 |
-
logging.basicConfig(level=logging.INFO)
|
| 17 |
|
| 18 |
# Load environment variables from .env file
|
| 19 |
load_dotenv()
|
|
@@ -23,27 +23,21 @@ PROXY_API_KEY = os.getenv("PROXY_API_KEY")
|
|
| 23 |
if not PROXY_API_KEY:
|
| 24 |
raise ValueError("PROXY_API_KEY environment variable not set.")
|
| 25 |
|
| 26 |
-
# Load all
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
i += 1
|
| 39 |
-
else:
|
| 40 |
-
break
|
| 41 |
-
|
| 42 |
-
if not gemini_keys:
|
| 43 |
-
raise ValueError("No GEMINI_API_KEY or GEMINI_API_KEY_n environment variables found.")
|
| 44 |
|
| 45 |
# Initialize the rotating client
|
| 46 |
-
rotating_client = RotatingClient(api_keys=
|
| 47 |
|
| 48 |
# --- FastAPI App Setup ---
|
| 49 |
app = FastAPI()
|
|
@@ -79,3 +73,38 @@ async def chat_completions(request: Request, _=Depends(verify_api_key)):
|
|
| 79 |
@app.get("/")
|
| 80 |
def read_root():
|
| 81 |
return {"Status": "API Key Proxy is running"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
# Add the 'src' directory to the Python path to allow importing 'rotating_api_key_client'
|
| 11 |
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
| 12 |
|
| 13 |
+
from rotator_library import RotatingClient, PROVIDER_PLUGINS
|
| 14 |
|
| 15 |
# Configure logging
|
| 16 |
+
logging.basicConfig(level=logging.INFO) #-> moved to the rotator_library
|
| 17 |
|
| 18 |
# Load environment variables from .env file
|
| 19 |
load_dotenv()
|
|
|
|
| 23 |
if not PROXY_API_KEY:
|
| 24 |
raise ValueError("PROXY_API_KEY environment variable not set.")
|
| 25 |
|
| 26 |
+
# Load all provider API keys from environment variables.
# Recognized forms: "<PROVIDER>_API_KEY" (single key) and
# "<PROVIDER>_API_KEY_<n>" (multiple keys per provider), e.g.
# OPENAI_API_KEY, ANTHROPIC_API_KEY_1.
api_keys = {}
for key, value in os.environ.items():
    # Bug fix: PROXY_API_KEY ends with "_API_KEY" but is this proxy's own
    # auth secret, not a provider credential — it must not be registered
    # as a provider named "proxy".
    if key == "PROXY_API_KEY":
        continue
    if not value:
        # Ignore empty values so a blank env var doesn't enter rotation.
        continue
    if key.endswith("_API_KEY") or "_API_KEY_" in key:
        provider = key.split("_API_KEY")[0].lower()
        api_keys.setdefault(provider, []).append(value)

if not api_keys:
    raise ValueError("No provider API keys found in environment variables.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
# Initialize the rotating client
|
| 40 |
+
rotating_client = RotatingClient(api_keys=api_keys)
|
| 41 |
|
| 42 |
# --- FastAPI App Setup ---
|
| 43 |
app = FastAPI()
|
|
|
|
| 73 |
@app.get("/")
|
| 74 |
def read_root():
|
| 75 |
return {"Status": "API Key Proxy is running"}
|
| 76 |
+
|
| 77 |
+
@app.get("/v1/models")
|
| 78 |
+
async def list_models(_=Depends(verify_api_key)):
|
| 79 |
+
"""
|
| 80 |
+
Returns a list of available models from all configured providers.
|
| 81 |
+
"""
|
| 82 |
+
models = await rotating_client.get_all_available_models()
|
| 83 |
+
return {"data": models}
|
| 84 |
+
|
| 85 |
+
@app.get("/v1/providers")
|
| 86 |
+
async def list_providers(_=Depends(verify_api_key)):
|
| 87 |
+
"""
|
| 88 |
+
Returns a list of all available providers.
|
| 89 |
+
"""
|
| 90 |
+
return {"data": list(PROVIDER_PLUGINS.keys())}
|
| 91 |
+
|
| 92 |
+
@app.post("/v1/token-count")
|
| 93 |
+
async def token_count(request: Request, _=Depends(verify_api_key)):
|
| 94 |
+
"""
|
| 95 |
+
Calculates the token count for a given list of messages and a model.
|
| 96 |
+
"""
|
| 97 |
+
try:
|
| 98 |
+
data = await request.json()
|
| 99 |
+
model = data.get("model")
|
| 100 |
+
messages = data.get("messages")
|
| 101 |
+
|
| 102 |
+
if not model or not messages:
|
| 103 |
+
raise HTTPException(status_code=400, detail="'model' and 'messages' are required.")
|
| 104 |
+
|
| 105 |
+
count = rotating_client.token_count(model=model, messages=messages)
|
| 106 |
+
return {"token_count": count}
|
| 107 |
+
|
| 108 |
+
except Exception as e:
|
| 109 |
+
logging.error(f"Token count failed: {e}")
|
| 110 |
+
raise HTTPException(status_code=500, detail=str(e))
|
|
@@ -1,17 +1,4 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Rotating API Key Client
|
| 3 |
-
"""
|
| 4 |
from .client import RotatingClient
|
| 5 |
-
from .
|
| 6 |
-
from .error_handler import is_authentication_error, is_rate_limit_error, is_server_error, is_unrecoverable_error
|
| 7 |
-
from .failure_logger import log_failure
|
| 8 |
|
| 9 |
-
__all__ = [
|
| 10 |
-
"RotatingClient",
|
| 11 |
-
"UsageManager",
|
| 12 |
-
"is_authentication_error",
|
| 13 |
-
"is_rate_limit_error",
|
| 14 |
-
"is_server_error",
|
| 15 |
-
"is_unrecoverable_error",
|
| 16 |
-
"log_failure",
|
| 17 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from .client import RotatingClient
|
| 2 |
+
from .providers import PROVIDER_PLUGINS
|
|
|
|
|
|
|
| 3 |
|
| 4 |
+
__all__ = ["RotatingClient", "PROVIDER_PLUGINS"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,29 +1,31 @@
|
|
| 1 |
import asyncio
|
| 2 |
import json
|
| 3 |
import litellm
|
|
|
|
| 4 |
import logging
|
| 5 |
from typing import List, Dict, Any, AsyncGenerator
|
| 6 |
|
| 7 |
from src.rotator_library.usage_manager import UsageManager
|
| 8 |
from src.rotator_library.failure_logger import log_failure
|
| 9 |
-
from src.rotator_library.error_handler import
|
| 10 |
-
|
| 11 |
-
is_rate_limit_error,
|
| 12 |
-
is_server_error,
|
| 13 |
-
is_unrecoverable_error,
|
| 14 |
-
)
|
| 15 |
|
| 16 |
class RotatingClient:
|
| 17 |
"""
|
| 18 |
A client that intelligently rotates and retries API keys using LiteLLM,
|
| 19 |
with support for both streaming and non-streaming responses.
|
| 20 |
"""
|
| 21 |
-
def __init__(self, api_keys: List[str], max_retries: int = 2, usage_file_path: str = "key_usage.json"):
|
|
|
|
| 22 |
if not api_keys:
|
| 23 |
-
raise ValueError("API keys
|
| 24 |
self.api_keys = api_keys
|
| 25 |
self.max_retries = max_retries
|
| 26 |
self.usage_manager = UsageManager(file_path=usage_file_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
async def _streaming_wrapper(self, stream: Any, key: str, model: str) -> AsyncGenerator[Any, None]:
|
| 29 |
"""
|
|
@@ -42,7 +44,7 @@ class RotatingClient:
|
|
| 42 |
# Safely check for usage data in the chunk
|
| 43 |
if hasattr(chunk, 'usage') and chunk.usage:
|
| 44 |
logging.info(f"Usage found in chunk for key ...{key[-4:]}: {chunk.usage}")
|
| 45 |
-
self.usage_manager.record_success(key, model, chunk
|
| 46 |
|
| 47 |
finally:
|
| 48 |
# Signal the end of the stream
|
|
@@ -61,9 +63,13 @@ class RotatingClient:
|
|
| 61 |
if not model:
|
| 62 |
raise ValueError("'model' is a required parameter.")
|
| 63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
while True: # Loop until a key succeeds or we decide to give up
|
| 65 |
current_key = self.usage_manager.get_next_smart_key(
|
| 66 |
-
available_keys=self.api_keys,
|
| 67 |
model=model
|
| 68 |
)
|
| 69 |
|
|
@@ -82,7 +88,7 @@ class RotatingClient:
|
|
| 82 |
return self._streaming_wrapper(response, current_key, model)
|
| 83 |
else:
|
| 84 |
# For non-streams, we can log usage immediately.
|
| 85 |
-
self.usage_manager.record_success(current_key, model, response
|
| 86 |
return response
|
| 87 |
|
| 88 |
except Exception as e:
|
|
@@ -108,3 +114,42 @@ class RotatingClient:
|
|
| 108 |
print(f"Key ...{current_key[-4:]} failed permanently. Rotating...")
|
| 109 |
self.usage_manager.record_rotation_error(current_key, model, e)
|
| 110 |
break
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import asyncio
|
| 2 |
import json
|
| 3 |
import litellm
|
| 4 |
+
from litellm.litellm_core_utils.token_counter import token_counter
|
| 5 |
import logging
|
| 6 |
from typing import List, Dict, Any, AsyncGenerator
|
| 7 |
|
| 8 |
from src.rotator_library.usage_manager import UsageManager
|
| 9 |
from src.rotator_library.failure_logger import log_failure
|
| 10 |
+
from src.rotator_library.error_handler import is_server_error, is_unrecoverable_error
|
| 11 |
+
from src.rotator_library.providers import PROVIDER_PLUGINS
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
class RotatingClient:
|
| 14 |
"""
|
| 15 |
A client that intelligently rotates and retries API keys using LiteLLM,
|
| 16 |
with support for both streaming and non-streaming responses.
|
| 17 |
"""
|
| 18 |
+
def __init__(self, api_keys: Dict[str, List[str]], max_retries: int = 2, usage_file_path: str = "key_usage.json"):
    """
    Initialize the client with per-provider key pools.

    Args:
        api_keys: Mapping of provider name to its list of API keys.
        max_retries: Attempts per key before rotating to the next one.
        usage_file_path: JSON file used to persist key-usage statistics.

    Raises:
        ValueError: If *api_keys* is empty.
    """
    litellm.set_verbose = False
    if not api_keys:
        raise ValueError("API keys dictionary cannot be empty.")
    self.api_keys = api_keys
    self.max_retries = max_retries
    self.usage_manager = UsageManager(file_path=usage_file_path)
    # Per-provider cache of fetched model lists, filled lazily.
    self._model_list_cache = {}
    # One plugin instance per registered provider.
    self._provider_instances = {name: plugin() for name, plugin in PROVIDER_PLUGINS.items()}
|
| 29 |
|
| 30 |
async def _streaming_wrapper(self, stream: Any, key: str, model: str) -> AsyncGenerator[Any, None]:
|
| 31 |
"""
|
|
|
|
| 44 |
# Safely check for usage data in the chunk
|
| 45 |
if hasattr(chunk, 'usage') and chunk.usage:
|
| 46 |
logging.info(f"Usage found in chunk for key ...{key[-4:]}: {chunk.usage}")
|
| 47 |
+
self.usage_manager.record_success(key, model, chunk)
|
| 48 |
|
| 49 |
finally:
|
| 50 |
# Signal the end of the stream
|
|
|
|
| 63 |
if not model:
|
| 64 |
raise ValueError("'model' is a required parameter.")
|
| 65 |
|
| 66 |
+
provider = model.split('/')[0]
|
| 67 |
+
if provider not in self.api_keys:
|
| 68 |
+
raise ValueError(f"No API keys configured for provider: {provider}")
|
| 69 |
+
|
| 70 |
while True: # Loop until a key succeeds or we decide to give up
|
| 71 |
current_key = self.usage_manager.get_next_smart_key(
|
| 72 |
+
available_keys=self.api_keys[provider],
|
| 73 |
model=model
|
| 74 |
)
|
| 75 |
|
|
|
|
| 88 |
return self._streaming_wrapper(response, current_key, model)
|
| 89 |
else:
|
| 90 |
# For non-streams, we can log usage immediately.
|
| 91 |
+
self.usage_manager.record_success(current_key, model, response)
|
| 92 |
return response
|
| 93 |
|
| 94 |
except Exception as e:
|
|
|
|
| 114 |
print(f"Key ...{current_key[-4:]} failed permanently. Rotating...")
|
| 115 |
self.usage_manager.record_rotation_error(current_key, model, e)
|
| 116 |
break
|
| 117 |
+
|
| 118 |
+
def token_count(self, model: str, text: str = None, messages: List[Dict[str, str]] = None) -> int:
    """
    Return the token count for *messages* (preferred) or *text* under *model*.

    Raises:
        ValueError: If neither 'text' nor 'messages' is provided.
    """
    if messages:
        return token_counter(model=model, messages=messages)
    if text:
        return token_counter(model=model, text=text)
    raise ValueError("Either 'text' or 'messages' must be provided.")
|
| 128 |
+
|
| 129 |
+
async def get_available_models(self, provider: str) -> List[str]:
    """
    Return the model list for *provider*, caching the result per provider.

    Returns an empty list when no key is configured for the provider or when
    no plugin implements model discovery for it.
    """
    if provider in self._model_list_cache:
        return self._model_list_cache[provider]

    # Bug fix: the previous `self.api_keys.get(provider, [None])[0]` raised
    # IndexError when the provider mapped to an empty key list.
    keys = self.api_keys.get(provider) or []
    api_key = keys[0] if keys else None
    if not api_key:
        return []

    plugin = self._provider_instances.get(provider)
    if plugin is None:
        logging.warning(f"Model list fetching not implemented for provider: {provider}")
        return []

    models = await plugin.get_models(api_key)
    self._model_list_cache[provider] = models
    return models
|
| 147 |
+
|
| 148 |
+
async def get_all_available_models(self) -> Dict[str, List[str]]:
    """
    Return available models for every configured provider, keyed by provider.

    Each value is the (possibly empty) list from get_available_models().
    """
    providers = list(self.api_keys.keys())
    # Improvement: query all providers concurrently instead of awaiting each
    # one sequentially; results keep the provider order.
    model_lists = await asyncio.gather(
        *(self.get_available_models(provider) for provider in providers)
    )
    return dict(zip(providers, model_lists))
|
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import importlib
import logging
import pkgutil
from typing import Dict, Type

from .provider_interface import ProviderInterface

logger = logging.getLogger(__name__)

# --- Provider Plugin System ---

# Maps provider name (e.g. "openai") to its ProviderInterface subclass.
PROVIDER_PLUGINS: Dict[str, Type[ProviderInterface]] = {}


def _register_providers():
    """
    Dynamically discover and import provider plugins from this directory.

    Each module containing a ProviderInterface subclass is registered under a
    name derived from its module name (e.g. 'openai_provider' -> 'openai').
    """
    for _, module_name, _ in pkgutil.iter_modules(__path__):
        # Import the module by its full dotted path.
        module = importlib.import_module(f"{__name__}.{module_name}")

        # Look for classes that inherit from ProviderInterface.
        for attribute_name in dir(module):
            attribute = getattr(module, attribute_name)
            if (
                isinstance(attribute, type)
                and issubclass(attribute, ProviderInterface)
                and attribute is not ProviderInterface
            ):
                provider_name = module_name.replace("_provider", "")
                PROVIDER_PLUGINS[provider_name] = attribute
                # Improvement: log instead of print so library consumers
                # control output verbosity via the logging config.
                logger.info("Registered provider: %s", provider_name)


# Discover and register providers when the package is imported.
_register_providers()
|
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import requests
|
| 2 |
+
import logging
|
| 3 |
+
from typing import List
|
| 4 |
+
from .provider_interface import ProviderInterface
|
| 5 |
+
|
| 6 |
+
class AnthropicProvider(ProviderInterface):
    """
    Provider implementation for the Anthropic API.
    """

    async def get_models(self, api_key: str) -> List[str]:
        """
        Fetch the available model ids from the Anthropic API.

        Returns ids prefixed with "anthropic/" for routing, or an empty list
        on any request failure.
        """
        def _fetch() -> List[str]:
            response = requests.get(
                "https://api.anthropic.com/v1/models",
                headers={
                    "x-api-key": api_key,
                    "anthropic-version": "2023-06-01",
                },
                timeout=10,  # bug fix: no timeout could hang the caller forever
            )
            response.raise_for_status()
            return [f"anthropic/{model['id']}" for model in response.json().get("data", [])]

        try:
            # Bug fix: requests is synchronous; running it directly inside an
            # async def blocks the event loop. Run it in a worker thread.
            return await asyncio.to_thread(_fetch)
        except requests.RequestException as e:
            logging.error(f"Failed to fetch Anthropic models: {e}")
            return []
|
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import List
|
| 3 |
+
from .provider_interface import ProviderInterface
|
| 4 |
+
|
| 5 |
+
class BedrockProvider(ProviderInterface):
    """
    Provider implementation for AWS Bedrock.
    """

    async def get_models(self, api_key: str) -> List[str]:
        """
        Return a hardcoded list of common Bedrock models, as there is no
        simple, unauthenticated API endpoint to list them.
        """
        # NOTE: real Bedrock discovery requires AWS credentials and boto3;
        # for a simple key-based proxy a static selection suffices. This can
        # be expanded with full AWS authentication if needed.
        logging.info("Returning hardcoded list for Bedrock. Full discovery requires AWS auth.")
        static_models = (
            "bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
            "bedrock/anthropic.claude-3-haiku-20240307-v1:0",
            "bedrock/cohere.command-r-plus-v1:0",
            "bedrock/mistral.mistral-large-2402-v1:0",
        )
        return list(static_models)
|
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import requests
|
| 2 |
+
import logging
|
| 3 |
+
from typing import List
|
| 4 |
+
from .provider_interface import ProviderInterface
|
| 5 |
+
|
| 6 |
+
class CohereProvider(ProviderInterface):
    """
    Provider implementation for the Cohere API.
    """

    async def get_models(self, api_key: str) -> List[str]:
        """
        Fetch the available model names from the Cohere API.

        Returns names prefixed with "cohere/" for routing, or an empty list
        on any request failure.
        """
        def _fetch() -> List[str]:
            response = requests.get(
                "https://api.cohere.ai/v1/models",
                headers={"Authorization": f"Bearer {api_key}"},
                timeout=10,  # bug fix: no timeout could hang the caller forever
            )
            response.raise_for_status()
            return [f"cohere/{model['name']}" for model in response.json().get("models", [])]

        try:
            # Bug fix: requests is synchronous; running it directly inside an
            # async def blocks the event loop. Run it in a worker thread.
            return await asyncio.to_thread(_fetch)
        except requests.RequestException as e:
            logging.error(f"Failed to fetch Cohere models: {e}")
            return []
|
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import requests
|
| 2 |
+
import logging
|
| 3 |
+
from typing import List
|
| 4 |
+
from .provider_interface import ProviderInterface
|
| 5 |
+
|
| 6 |
+
class GeminiProvider(ProviderInterface):
    """
    Provider implementation for the Google Gemini API.
    """

    async def get_models(self, api_key: str) -> List[str]:
        """
        Fetch the available model names from the Google Gemini API.

        Strips the "models/" prefix the API returns and re-prefixes with
        "gemini/" for routing; returns an empty list on any request failure.
        """
        def _fetch() -> List[str]:
            response = requests.get(
                "https://generativelanguage.googleapis.com/v1beta/models",
                headers={"x-goog-api-key": api_key},
                timeout=10,  # bug fix: no timeout could hang the caller forever
            )
            response.raise_for_status()
            return [
                f"gemini/{model['name'].replace('models/', '')}"
                for model in response.json().get("models", [])
            ]

        try:
            # Bug fix: requests is synchronous; running it directly inside an
            # async def blocks the event loop. Run it in a worker thread.
            return await asyncio.to_thread(_fetch)
        except requests.RequestException as e:
            logging.error(f"Failed to fetch Gemini models: {e}")
            return []
|
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import requests
|
| 2 |
+
import logging
|
| 3 |
+
from typing import List
|
| 4 |
+
from .provider_interface import ProviderInterface
|
| 5 |
+
|
| 6 |
+
class GroqProvider(ProviderInterface):
    """
    Provider implementation for the Groq API.
    """

    async def get_models(self, api_key: str) -> List[str]:
        """
        Fetch the available model ids from the Groq API.

        Returns ids prefixed with "groq/" for routing, or an empty list on
        any request failure.
        """
        def _fetch() -> List[str]:
            response = requests.get(
                "https://api.groq.com/openai/v1/models",
                headers={"Authorization": f"Bearer {api_key}"},
                timeout=10,  # bug fix: no timeout could hang the caller forever
            )
            response.raise_for_status()
            return [f"groq/{model['id']}" for model in response.json().get("data", [])]

        try:
            # Bug fix: requests is synchronous; running it directly inside an
            # async def blocks the event loop. Run it in a worker thread.
            return await asyncio.to_thread(_fetch)
        except requests.RequestException as e:
            logging.error(f"Failed to fetch Groq models: {e}")
            return []
|
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import requests
|
| 2 |
+
import logging
|
| 3 |
+
from typing import List
|
| 4 |
+
from .provider_interface import ProviderInterface
|
| 5 |
+
|
| 6 |
+
class MistralProvider(ProviderInterface):
    """
    Provider implementation for the Mistral API.
    """

    async def get_models(self, api_key: str) -> List[str]:
        """
        Fetch the available model ids from the Mistral API.

        Returns ids prefixed with "mistral/" for routing, or an empty list
        on any request failure.
        """
        def _fetch() -> List[str]:
            response = requests.get(
                "https://api.mistral.ai/v1/models",
                headers={"Authorization": f"Bearer {api_key}"},
                timeout=10,  # bug fix: no timeout could hang the caller forever
            )
            response.raise_for_status()
            return [f"mistral/{model['id']}" for model in response.json().get("data", [])]

        try:
            # Bug fix: requests is synchronous; running it directly inside an
            # async def blocks the event loop. Run it in a worker thread.
            return await asyncio.to_thread(_fetch)
        except requests.RequestException as e:
            logging.error(f"Failed to fetch Mistral models: {e}")
            return []
|
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import requests
|
| 2 |
+
import logging
|
| 3 |
+
from typing import List
|
| 4 |
+
from .provider_interface import ProviderInterface
|
| 5 |
+
|
| 6 |
+
class OpenAIProvider(ProviderInterface):
    """
    Provider implementation for the OpenAI API.
    """

    async def get_models(self, api_key: str) -> List[str]:
        """
        Fetch the available model ids from the OpenAI API.

        Returns ids prefixed with "openai/" for routing, or an empty list
        on any request failure.
        """
        def _fetch() -> List[str]:
            response = requests.get(
                "https://api.openai.com/v1/models",
                headers={"Authorization": f"Bearer {api_key}"},
                timeout=10,  # bug fix: no timeout could hang the caller forever
            )
            response.raise_for_status()
            return [f"openai/{model['id']}" for model in response.json().get("data", [])]

        try:
            # Bug fix: requests is synchronous; running it directly inside an
            # async def blocks the event loop. Run it in a worker thread.
            return await asyncio.to_thread(_fetch)
        except requests.RequestException as e:
            logging.error(f"Failed to fetch OpenAI models: {e}")
            return []
|
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
from typing import List, Any
|
| 3 |
+
|
| 4 |
+
class ProviderInterface(ABC):
    """
    Interface for API provider-specific functionality, primarily for
    discovering available models.
    """

    @abstractmethod
    async def get_models(self, api_key: str) -> List[str]:
        """
        Fetch the available model names from the provider's API.

        Args:
            api_key: The API key required for authentication.

        Returns:
            A list of model name strings.
        """
        ...
|
|
@@ -4,6 +4,7 @@ import time
|
|
| 4 |
from datetime import date, datetime
|
| 5 |
from typing import Dict, List, Optional, Any
|
| 6 |
from filelock import FileLock
|
|
|
|
| 7 |
|
| 8 |
class UsageManager:
|
| 9 |
"""
|
|
@@ -42,10 +43,11 @@ class UsageManager:
|
|
| 42 |
# Add yesterday's daily stats to global stats
|
| 43 |
global_data = data.setdefault("global", {"models": {}})
|
| 44 |
for model, stats in daily_data.get("models", {}).items():
|
| 45 |
-
global_model_stats = global_data["models"].setdefault(model, {"success_count": 0, "prompt_tokens": 0, "completion_tokens": 0})
|
| 46 |
global_model_stats["success_count"] += stats.get("success_count", 0)
|
| 47 |
global_model_stats["prompt_tokens"] += stats.get("prompt_tokens", 0)
|
| 48 |
global_model_stats["completion_tokens"] += stats.get("completion_tokens", 0)
|
|
|
|
| 49 |
|
| 50 |
# Reset daily stats
|
| 51 |
data["daily"] = {"date": today_str, "models": {}}
|
|
@@ -82,7 +84,7 @@ class UsageManager:
|
|
| 82 |
|
| 83 |
return best_key if best_key else active_keys[0]
|
| 84 |
|
| 85 |
-
def record_success(self, key: str, model: str,
|
| 86 |
key_data = self.usage_data.setdefault(key, {"daily": {"date": date.today().isoformat(), "models": {}}, "global": {"models": {}}, "cooldown_until": None})
|
| 87 |
|
| 88 |
# Ensure daily stats are for today
|
|
@@ -90,12 +92,22 @@ class UsageManager:
|
|
| 90 |
self._reset_daily_stats_if_needed() # Should be rare, but as a safeguard
|
| 91 |
key_data = self.usage_data[key]
|
| 92 |
|
| 93 |
-
daily_model_data = key_data["daily"]["models"].setdefault(model, {"success_count": 0, "prompt_tokens": 0, "completion_tokens": 0})
|
| 94 |
|
|
|
|
| 95 |
daily_model_data["success_count"] += 1
|
| 96 |
-
daily_model_data["prompt_tokens"] += usage.
|
| 97 |
-
daily_model_data["completion_tokens"] += usage.
|
| 98 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
key_data["last_used_ts"] = time.time()
|
| 100 |
self._save_usage()
|
| 101 |
|
|
|
|
| 4 |
from datetime import date, datetime
|
| 5 |
from typing import Dict, List, Optional, Any
|
| 6 |
from filelock import FileLock
|
| 7 |
+
import litellm
|
| 8 |
|
| 9 |
class UsageManager:
|
| 10 |
"""
|
|
|
|
| 43 |
# Add yesterday's daily stats to global stats
|
| 44 |
global_data = data.setdefault("global", {"models": {}})
|
| 45 |
for model, stats in daily_data.get("models", {}).items():
|
| 46 |
+
global_model_stats = global_data["models"].setdefault(model, {"success_count": 0, "prompt_tokens": 0, "completion_tokens": 0, "approx_cost": 0.0})
|
| 47 |
global_model_stats["success_count"] += stats.get("success_count", 0)
|
| 48 |
global_model_stats["prompt_tokens"] += stats.get("prompt_tokens", 0)
|
| 49 |
global_model_stats["completion_tokens"] += stats.get("completion_tokens", 0)
|
| 50 |
+
global_model_stats["approx_cost"] += stats.get("approx_cost", 0.0)
|
| 51 |
|
| 52 |
# Reset daily stats
|
| 53 |
data["daily"] = {"date": today_str, "models": {}}
|
|
|
|
| 84 |
|
| 85 |
return best_key if best_key else active_keys[0]
|
| 86 |
|
| 87 |
+
def record_success(self, key: str, model: str, completion_response: litellm.ModelResponse):
|
| 88 |
key_data = self.usage_data.setdefault(key, {"daily": {"date": date.today().isoformat(), "models": {}}, "global": {"models": {}}, "cooldown_until": None})
|
| 89 |
|
| 90 |
# Ensure daily stats are for today
|
|
|
|
| 92 |
self._reset_daily_stats_if_needed() # Should be rare, but as a safeguard
|
| 93 |
key_data = self.usage_data[key]
|
| 94 |
|
| 95 |
+
daily_model_data = key_data["daily"]["models"].setdefault(model, {"success_count": 0, "prompt_tokens": 0, "completion_tokens": 0, "approx_cost": 0.0})
|
| 96 |
|
| 97 |
+
usage = completion_response.usage
|
| 98 |
daily_model_data["success_count"] += 1
|
| 99 |
+
daily_model_data["prompt_tokens"] += usage.prompt_tokens
|
| 100 |
+
daily_model_data["completion_tokens"] += usage.completion_tokens
|
| 101 |
|
| 102 |
+
# Calculate approximate cost using LiteLLM
|
| 103 |
+
try:
|
| 104 |
+
cost = litellm.completion_cost(
|
| 105 |
+
completion_response=completion_response
|
| 106 |
+
)
|
| 107 |
+
daily_model_data["approx_cost"] += cost
|
| 108 |
+
except Exception as e:
|
| 109 |
+
print(f"Warning: Could not calculate cost for model {model}: {e}")
|
| 110 |
+
|
| 111 |
key_data["last_used_ts"] = time.time()
|
| 112 |
self._save_usage()
|
| 113 |
|