Mirrowel committed on
Commit
21dcb11
·
1 Parent(s): 7a5872b

feat(multi-provider): Implement dynamic API key loading and new endpoints

Browse files

Refactor API key loading in `main.py` to dynamically support multiple providers from environment variables (e.g., `OPENAI_API_KEY`, `ANTHROPIC_API_KEY_1`).

Introduce a pluggable provider system in `src/rotator_library/providers` with an abstract `ProviderInterface` and concrete implementations for:
- Anthropic
- AWS Bedrock
- Cohere
- Google Gemini
- Groq
- Mistral
- OpenAI

Enhance `RotatingClient` to:
- Accept a dictionary of API keys, grouped by provider.
- Dynamically fetch available models from integrated providers with caching.
- Calculate token counts using `litellm.token_counter`.

Add new API endpoints to `main.py`:
- `GET /v1/models`: Lists all available models across configured providers.
- `GET /v1/providers`: Lists all integrated providers.
- `POST /v1/token-count`: Calculates token usage for given messages or text.

Update `UsageManager` to:
- Record approximate costs for completions.
- Utilize `litellm.ModelResponse` for more comprehensive usage tracking.
- Streamline `rotator_library/__init__.py` exports.

src/proxy_app/main.py CHANGED
@@ -10,10 +10,10 @@ import sys
10
  # Add the 'src' directory to the Python path to allow importing 'rotating_api_key_client'
11
  sys.path.append(str(Path(__file__).resolve().parent.parent))
12
 
13
- from rotator_library import RotatingClient
14
 
15
  # Configure logging
16
- logging.basicConfig(level=logging.INFO)
17
 
18
  # Load environment variables from .env file
19
  load_dotenv()
@@ -23,27 +23,21 @@ PROXY_API_KEY = os.getenv("PROXY_API_KEY")
23
  if not PROXY_API_KEY:
24
  raise ValueError("PROXY_API_KEY environment variable not set.")
25
 
26
- # Load all Gemini keys from environment variables
27
- gemini_keys = []
28
- i = 1
29
- while True:
30
- # Start with GEMINI_API_KEY_1, then GEMINI_API_KEY_2, etc.
31
- key = os.getenv(f"GEMINI_API_KEY_{i}")
32
- if not key and i == 1:
33
- # Fallback for a single key named just GEMINI_API_KEY
34
- key = os.getenv("GEMINI_API_KEY")
35
-
36
- if key:
37
- gemini_keys.append(key)
38
- i += 1
39
- else:
40
- break
41
-
42
- if not gemini_keys:
43
- raise ValueError("No GEMINI_API_KEY or GEMINI_API_KEY_n environment variables found.")
44
 
45
  # Initialize the rotating client
46
- rotating_client = RotatingClient(api_keys=gemini_keys)
47
 
48
  # --- FastAPI App Setup ---
49
  app = FastAPI()
@@ -79,3 +73,38 @@ async def chat_completions(request: Request, _=Depends(verify_api_key)):
79
  @app.get("/")
80
  def read_root():
81
  return {"Status": "API Key Proxy is running"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  # Add the 'src' directory to the Python path to allow importing 'rotating_api_key_client'
11
  sys.path.append(str(Path(__file__).resolve().parent.parent))
12
 
13
+ from rotator_library import RotatingClient, PROVIDER_PLUGINS
14
 
15
  # Configure logging
16
+ logging.basicConfig(level=logging.INFO) #-> moved to the rotator_library
17
 
18
  # Load environment variables from .env file
19
  load_dotenv()
 
23
  if not PROXY_API_KEY:
24
  raise ValueError("PROXY_API_KEY environment variable not set.")
25
 
26
# Load all provider API keys from environment variables.
# Recognized names: PROVIDER_API_KEY (single key) or PROVIDER_API_KEY_<n>
# (multiple keys for the same provider).
api_keys = {}
for key, value in os.environ.items():
    if key.endswith("_API_KEY") or "_API_KEY_" in key:
        provider = key.split("_API_KEY")[0].lower()
        # PROXY_API_KEY authenticates clients to this proxy itself; it is not
        # an upstream provider credential, so it must not be registered as a
        # "proxy" provider (the naive suffix match would otherwise pick it up).
        if provider == "proxy":
            continue
        # Skip empty values so a blank env var doesn't become a bogus key.
        if not value:
            continue
        api_keys.setdefault(provider, []).append(value)

if not api_keys:
    raise ValueError("No provider API keys found in environment variables.")
 
 
 
 
 
 
38
 
39
  # Initialize the rotating client
40
+ rotating_client = RotatingClient(api_keys=api_keys)
41
 
42
  # --- FastAPI App Setup ---
43
  app = FastAPI()
 
73
  @app.get("/")
74
  def read_root():
75
  return {"Status": "API Key Proxy is running"}
76
+
77
@app.get("/v1/models")
async def list_models(_=Depends(verify_api_key)):
    """
    Returns a list of available models from all configured providers.
    """
    available = await rotating_client.get_all_available_models()
    return {"data": available}
84
+
85
@app.get("/v1/providers")
async def list_providers(_=Depends(verify_api_key)):
    """
    Returns a list of all available providers.
    """
    provider_names = list(PROVIDER_PLUGINS.keys())
    return {"data": provider_names}
91
+
92
@app.post("/v1/token-count")
async def token_count(request: Request, _=Depends(verify_api_key)):
    """
    Calculates the token count for a given list of messages and a model.

    Expects a JSON body with "model" (str) and "messages" (list); responds
    with {"token_count": <int>}. Returns 400 on missing fields and 500 on
    any other failure.
    """
    try:
        data = await request.json()
        model = data.get("model")
        messages = data.get("messages")

        if not model or not messages:
            raise HTTPException(status_code=400, detail="'model' and 'messages' are required.")

        count = rotating_client.token_count(model=model, messages=messages)
        return {"token_count": count}

    except HTTPException:
        # Re-raise intentional HTTP errors unchanged. Without this clause the
        # 400 above was caught by the generic handler below and converted
        # into a 500 with a misleading detail message.
        raise
    except Exception as e:
        logging.error(f"Token count failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
src/rotator_library/__init__.py CHANGED
@@ -1,17 +1,4 @@
1
- """
2
- Rotating API Key Client
3
- """
4
  from .client import RotatingClient
5
- from .usage_manager import UsageManager
6
- from .error_handler import is_authentication_error, is_rate_limit_error, is_server_error, is_unrecoverable_error
7
- from .failure_logger import log_failure
8
 
9
- __all__ = [
10
- "RotatingClient",
11
- "UsageManager",
12
- "is_authentication_error",
13
- "is_rate_limit_error",
14
- "is_server_error",
15
- "is_unrecoverable_error",
16
- "log_failure",
17
- ]
 
 
 
 
1
  from .client import RotatingClient
2
+ from .providers import PROVIDER_PLUGINS
 
 
3
 
4
+ __all__ = ["RotatingClient", "PROVIDER_PLUGINS"]
 
 
 
 
 
 
 
 
src/rotator_library/client.py CHANGED
@@ -1,29 +1,31 @@
1
  import asyncio
2
  import json
3
  import litellm
 
4
  import logging
5
  from typing import List, Dict, Any, AsyncGenerator
6
 
7
  from src.rotator_library.usage_manager import UsageManager
8
  from src.rotator_library.failure_logger import log_failure
9
- from src.rotator_library.error_handler import (
10
- is_authentication_error,
11
- is_rate_limit_error,
12
- is_server_error,
13
- is_unrecoverable_error,
14
- )
15
 
16
  class RotatingClient:
17
  """
18
  A client that intelligently rotates and retries API keys using LiteLLM,
19
  with support for both streaming and non-streaming responses.
20
  """
21
- def __init__(self, api_keys: List[str], max_retries: int = 2, usage_file_path: str = "key_usage.json"):
 
22
  if not api_keys:
23
- raise ValueError("API keys list cannot be empty.")
24
  self.api_keys = api_keys
25
  self.max_retries = max_retries
26
  self.usage_manager = UsageManager(file_path=usage_file_path)
 
 
 
 
27
 
28
  async def _streaming_wrapper(self, stream: Any, key: str, model: str) -> AsyncGenerator[Any, None]:
29
  """
@@ -42,7 +44,7 @@ class RotatingClient:
42
  # Safely check for usage data in the chunk
43
  if hasattr(chunk, 'usage') and chunk.usage:
44
  logging.info(f"Usage found in chunk for key ...{key[-4:]}: {chunk.usage}")
45
- self.usage_manager.record_success(key, model, chunk.usage)
46
 
47
  finally:
48
  # Signal the end of the stream
@@ -61,9 +63,13 @@ class RotatingClient:
61
  if not model:
62
  raise ValueError("'model' is a required parameter.")
63
 
 
 
 
 
64
  while True: # Loop until a key succeeds or we decide to give up
65
  current_key = self.usage_manager.get_next_smart_key(
66
- available_keys=self.api_keys,
67
  model=model
68
  )
69
 
@@ -82,7 +88,7 @@ class RotatingClient:
82
  return self._streaming_wrapper(response, current_key, model)
83
  else:
84
  # For non-streams, we can log usage immediately.
85
- self.usage_manager.record_success(current_key, model, response.usage)
86
  return response
87
 
88
  except Exception as e:
@@ -108,3 +114,42 @@ class RotatingClient:
108
  print(f"Key ...{current_key[-4:]} failed permanently. Rotating...")
109
  self.usage_manager.record_rotation_error(current_key, model, e)
110
  break
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import asyncio
2
  import json
3
  import litellm
4
+ from litellm.litellm_core_utils.token_counter import token_counter
5
  import logging
6
  from typing import List, Dict, Any, AsyncGenerator
7
 
8
  from src.rotator_library.usage_manager import UsageManager
9
  from src.rotator_library.failure_logger import log_failure
10
+ from src.rotator_library.error_handler import is_server_error, is_unrecoverable_error
11
+ from src.rotator_library.providers import PROVIDER_PLUGINS
 
 
 
 
12
 
13
  class RotatingClient:
14
  """
15
  A client that intelligently rotates and retries API keys using LiteLLM,
16
  with support for both streaming and non-streaming responses.
17
  """
18
    def __init__(self, api_keys: Dict[str, List[str]], max_retries: int = 2, usage_file_path: str = "key_usage.json"):
        """
        Initializes the rotating client.

        Args:
            api_keys: Mapping of provider name (e.g. "openai") to the list of
                API keys available for that provider.
            max_retries: Maximum retry attempts per key before rotating.
            usage_file_path: Path of the JSON file used by UsageManager to
                persist per-key usage statistics.

        Raises:
            ValueError: If the api_keys dictionary is empty.
        """
        # Silence LiteLLM's verbose debug output for library use.
        litellm.set_verbose = False
        if not api_keys:
            raise ValueError("API keys dictionary cannot be empty.")
        self.api_keys = api_keys
        self.max_retries = max_retries
        self.usage_manager = UsageManager(file_path=usage_file_path)
        # Cache of provider name -> model list, filled lazily by
        # get_available_models().
        self._model_list_cache = {}
        # One instance of every discovered provider plugin, keyed by name.
        self._provider_instances = {
            name: plugin() for name, plugin in PROVIDER_PLUGINS.items()
        }
29
 
30
  async def _streaming_wrapper(self, stream: Any, key: str, model: str) -> AsyncGenerator[Any, None]:
31
  """
 
44
  # Safely check for usage data in the chunk
45
  if hasattr(chunk, 'usage') and chunk.usage:
46
  logging.info(f"Usage found in chunk for key ...{key[-4:]}: {chunk.usage}")
47
+ self.usage_manager.record_success(key, model, chunk)
48
 
49
  finally:
50
  # Signal the end of the stream
 
63
  if not model:
64
  raise ValueError("'model' is a required parameter.")
65
 
66
+ provider = model.split('/')[0]
67
+ if provider not in self.api_keys:
68
+ raise ValueError(f"No API keys configured for provider: {provider}")
69
+
70
  while True: # Loop until a key succeeds or we decide to give up
71
  current_key = self.usage_manager.get_next_smart_key(
72
+ available_keys=self.api_keys[provider],
73
  model=model
74
  )
75
 
 
88
  return self._streaming_wrapper(response, current_key, model)
89
  else:
90
  # For non-streams, we can log usage immediately.
91
+ self.usage_manager.record_success(current_key, model, response)
92
  return response
93
 
94
  except Exception as e:
 
114
  print(f"Key ...{current_key[-4:]} failed permanently. Rotating...")
115
  self.usage_manager.record_rotation_error(current_key, model, e)
116
  break
117
+
118
+ def token_count(self, model: str, text: str = None, messages: List[Dict[str, str]] = None) -> int:
119
+ """
120
+ Calculates the number of tokens for a given text or list of messages.
121
+ """
122
+ if messages:
123
+ return token_counter(model=model, messages=messages)
124
+ elif text:
125
+ return token_counter(model=model, text=text)
126
+ else:
127
+ raise ValueError("Either 'text' or 'messages' must be provided.")
128
+
129
+ async def get_available_models(self, provider: str) -> List[str]:
130
+ """
131
+ Returns a list of available models for a specific provider, with caching.
132
+ """
133
+ if provider in self._model_list_cache:
134
+ return self._model_list_cache[provider]
135
+
136
+ api_key = self.api_keys.get(provider, [None])[0]
137
+ if not api_key:
138
+ return []
139
+
140
+ if provider in self._provider_instances:
141
+ models = await self._provider_instances[provider].get_models(api_key)
142
+ self._model_list_cache[provider] = models
143
+ return models
144
+ else:
145
+ logging.warning(f"Model list fetching not implemented for provider: {provider}")
146
+ return []
147
+
148
+ async def get_all_available_models(self) -> Dict[str, List[str]]:
149
+ """
150
+ Returns a dictionary of all available models, grouped by provider.
151
+ """
152
+ all_provider_models = {}
153
+ for provider in self.api_keys.keys():
154
+ all_provider_models[provider] = await self.get_available_models(provider)
155
+ return all_provider_models
src/rotator_library/providers/__init__.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import importlib
import logging
import pkgutil
from typing import Dict, Type

from .provider_interface import ProviderInterface

logger = logging.getLogger(__name__)

# --- Provider Plugin System ---

# Dictionary to hold discovered provider classes, mapping provider name to class
PROVIDER_PLUGINS: Dict[str, Type[ProviderInterface]] = {}


def _register_providers() -> None:
    """
    Dynamically discovers and imports provider plugins from this directory.

    Every module in this package is imported; any class it exposes that
    subclasses ProviderInterface is registered under a name derived from the
    module name (e.g. 'openai_provider' -> 'openai').
    """
    package_path = __path__
    package_name = __name__

    for _, module_name, _ in pkgutil.iter_modules(package_path):
        # Import the module by its full dotted path.
        module = importlib.import_module(f"{package_name}.{module_name}")

        # Look for a class that inherits from ProviderInterface.
        for attribute_name in dir(module):
            attribute = getattr(module, attribute_name)
            if (
                isinstance(attribute, type)
                and issubclass(attribute, ProviderInterface)
                and attribute is not ProviderInterface
            ):
                provider_name = module_name.replace("_provider", "")
                PROVIDER_PLUGINS[provider_name] = attribute
                # Use logging rather than print so importing this library does
                # not write to stdout of the embedding application.
                logger.info("Registered provider: %s", provider_name)


# Discover and register providers when the package is imported
_register_providers()
src/rotator_library/providers/anthropic_provider.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import requests
import logging
from typing import List
from .provider_interface import ProviderInterface


class AnthropicProvider(ProviderInterface):
    """
    Provider implementation for the Anthropic API.
    """

    async def get_models(self, api_key: str) -> List[str]:
        """
        Fetches the list of available models from the Anthropic API.

        Returns an empty list on any network or HTTP failure.
        """
        # NOTE(review): requests is synchronous, so this blocks the event loop
        # inside an async def — consider an async HTTP client if this runs on
        # a hot path.
        try:
            response = requests.get(
                "https://api.anthropic.com/v1/models",
                headers={
                    "x-api-key": api_key,
                    "anthropic-version": "2023-06-01"
                },
                timeout=10,  # without a timeout requests can hang forever
            )
            response.raise_for_status()
            return [f"anthropic/{model['id']}" for model in response.json().get("data", [])]
        except requests.RequestException as e:
            # Timeout is a RequestException subclass, so it is handled here too.
            logging.error(f"Failed to fetch Anthropic models: {e}")
            return []
src/rotator_library/providers/bedrock_provider.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import logging
from typing import List
from .provider_interface import ProviderInterface


class BedrockProvider(ProviderInterface):
    """
    Provider implementation for AWS Bedrock.
    """

    # Listing Bedrock models normally requires AWS credentials and boto3; for
    # a simple key-based proxy we expose a static list of common models. This
    # can be expanded with full AWS authentication if needed.
    _COMMON_MODELS = [
        "bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
        "bedrock/anthropic.claude-3-haiku-20240307-v1:0",
        "bedrock/cohere.command-r-plus-v1:0",
        "bedrock/mistral.mistral-large-2402-v1:0",
    ]

    async def get_models(self, api_key: str) -> List[str]:
        """
        Returns a hardcoded list of common Bedrock models, as there is no
        simple, unauthenticated API endpoint to list them.
        """
        logging.info("Returning hardcoded list for Bedrock. Full discovery requires AWS auth.")
        # Return a copy so callers cannot mutate the shared class constant.
        return list(self._COMMON_MODELS)
src/rotator_library/providers/cohere_provider.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import requests
import logging
from typing import List
from .provider_interface import ProviderInterface


class CohereProvider(ProviderInterface):
    """
    Provider implementation for the Cohere API.
    """

    async def get_models(self, api_key: str) -> List[str]:
        """
        Fetches the list of available models from the Cohere API.

        Returns an empty list on any network or HTTP failure.
        """
        # NOTE(review): synchronous requests call inside an async def blocks
        # the event loop — acceptable for occasional discovery, not hot paths.
        try:
            response = requests.get(
                "https://api.cohere.ai/v1/models",
                headers={"Authorization": f"Bearer {api_key}"},
                timeout=10,  # without a timeout requests can hang forever
            )
            response.raise_for_status()
            return [f"cohere/{model['name']}" for model in response.json().get("models", [])]
        except requests.RequestException as e:
            logging.error(f"Failed to fetch Cohere models: {e}")
            return []
src/rotator_library/providers/gemini_provider.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import requests
import logging
from typing import List
from .provider_interface import ProviderInterface


class GeminiProvider(ProviderInterface):
    """
    Provider implementation for the Google Gemini API.
    """

    async def get_models(self, api_key: str) -> List[str]:
        """
        Fetches the list of available models from the Google Gemini API.

        Model ids arrive as "models/<name>"; the prefix is stripped and each
        is returned as "gemini/<name>". Returns an empty list on failure.
        """
        # NOTE(review): synchronous requests call inside an async def blocks
        # the event loop — acceptable for occasional discovery, not hot paths.
        try:
            response = requests.get(
                "https://generativelanguage.googleapis.com/v1beta/models",
                headers={"x-goog-api-key": api_key},
                timeout=10,  # without a timeout requests can hang forever
            )
            response.raise_for_status()
            return [f"gemini/{model['name'].replace('models/', '')}" for model in response.json().get("models", [])]
        except requests.RequestException as e:
            logging.error(f"Failed to fetch Gemini models: {e}")
            return []
src/rotator_library/providers/groq_provider.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import requests
import logging
from typing import List
from .provider_interface import ProviderInterface


class GroqProvider(ProviderInterface):
    """
    Provider implementation for the Groq API.
    """

    async def get_models(self, api_key: str) -> List[str]:
        """
        Fetches the list of available models from the Groq API.

        Returns an empty list on any network or HTTP failure.
        """
        # NOTE(review): synchronous requests call inside an async def blocks
        # the event loop — acceptable for occasional discovery, not hot paths.
        try:
            response = requests.get(
                "https://api.groq.com/openai/v1/models",
                headers={"Authorization": f"Bearer {api_key}"},
                timeout=10,  # without a timeout requests can hang forever
            )
            response.raise_for_status()
            return [f"groq/{model['id']}" for model in response.json().get("data", [])]
        except requests.RequestException as e:
            logging.error(f"Failed to fetch Groq models: {e}")
            return []
src/rotator_library/providers/mistral_provider.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import requests
import logging
from typing import List
from .provider_interface import ProviderInterface


class MistralProvider(ProviderInterface):
    """
    Provider implementation for the Mistral API.
    """

    async def get_models(self, api_key: str) -> List[str]:
        """
        Fetches the list of available models from the Mistral API.

        Returns an empty list on any network or HTTP failure.
        """
        # NOTE(review): synchronous requests call inside an async def blocks
        # the event loop — acceptable for occasional discovery, not hot paths.
        try:
            response = requests.get(
                "https://api.mistral.ai/v1/models",
                headers={"Authorization": f"Bearer {api_key}"},
                timeout=10,  # without a timeout requests can hang forever
            )
            response.raise_for_status()
            return [f"mistral/{model['id']}" for model in response.json().get("data", [])]
        except requests.RequestException as e:
            logging.error(f"Failed to fetch Mistral models: {e}")
            return []
src/rotator_library/providers/openai_provider.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import requests
import logging
from typing import List
from .provider_interface import ProviderInterface


class OpenAIProvider(ProviderInterface):
    """
    Provider implementation for the OpenAI API.
    """

    async def get_models(self, api_key: str) -> List[str]:
        """
        Fetches the list of available models from the OpenAI API.

        Returns an empty list on any network or HTTP failure.
        """
        # NOTE(review): synchronous requests call inside an async def blocks
        # the event loop — acceptable for occasional discovery, not hot paths.
        try:
            response = requests.get(
                "https://api.openai.com/v1/models",
                headers={"Authorization": f"Bearer {api_key}"},
                timeout=10,  # without a timeout requests can hang forever
            )
            response.raise_for_status()
            return [f"openai/{model['id']}" for model in response.json().get("data", [])]
        except requests.RequestException as e:
            logging.error(f"Failed to fetch OpenAI models: {e}")
            return []
src/rotator_library/providers/provider_interface.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from abc import ABC, abstractmethod
from typing import List, Any


class ProviderInterface(ABC):
    """
    An interface for API provider-specific functionality, primarily for discovering
    available models.

    Concrete subclasses are auto-discovered and registered by the package's
    plugin loader, keyed by a name derived from their module name.
    """

    @abstractmethod
    async def get_models(self, api_key: str) -> List[str]:
        """
        Fetches the list of available model names from the provider's API.

        Args:
            api_key: The API key required for authentication.

        Returns:
            A list of model name strings.
            # NOTE(review): existing implementations return an empty list on
            # failure rather than raising — confirm this is the intended
            # contract before relying on it.
        """
        pass
src/rotator_library/usage_manager.py CHANGED
@@ -4,6 +4,7 @@ import time
4
  from datetime import date, datetime
5
  from typing import Dict, List, Optional, Any
6
  from filelock import FileLock
 
7
 
8
  class UsageManager:
9
  """
@@ -42,10 +43,11 @@ class UsageManager:
42
  # Add yesterday's daily stats to global stats
43
  global_data = data.setdefault("global", {"models": {}})
44
  for model, stats in daily_data.get("models", {}).items():
45
- global_model_stats = global_data["models"].setdefault(model, {"success_count": 0, "prompt_tokens": 0, "completion_tokens": 0})
46
  global_model_stats["success_count"] += stats.get("success_count", 0)
47
  global_model_stats["prompt_tokens"] += stats.get("prompt_tokens", 0)
48
  global_model_stats["completion_tokens"] += stats.get("completion_tokens", 0)
 
49
 
50
  # Reset daily stats
51
  data["daily"] = {"date": today_str, "models": {}}
@@ -82,7 +84,7 @@ class UsageManager:
82
 
83
  return best_key if best_key else active_keys[0]
84
 
85
- def record_success(self, key: str, model: str, usage: Dict):
86
  key_data = self.usage_data.setdefault(key, {"daily": {"date": date.today().isoformat(), "models": {}}, "global": {"models": {}}, "cooldown_until": None})
87
 
88
  # Ensure daily stats are for today
@@ -90,12 +92,22 @@ class UsageManager:
90
  self._reset_daily_stats_if_needed() # Should be rare, but as a safeguard
91
  key_data = self.usage_data[key]
92
 
93
- daily_model_data = key_data["daily"]["models"].setdefault(model, {"success_count": 0, "prompt_tokens": 0, "completion_tokens": 0})
94
 
 
95
  daily_model_data["success_count"] += 1
96
- daily_model_data["prompt_tokens"] += usage.get("prompt_tokens", 0)
97
- daily_model_data["completion_tokens"] += usage.get("completion_tokens", 0)
98
 
 
 
 
 
 
 
 
 
 
99
  key_data["last_used_ts"] = time.time()
100
  self._save_usage()
101
 
 
4
  from datetime import date, datetime
5
  from typing import Dict, List, Optional, Any
6
  from filelock import FileLock
7
+ import litellm
8
 
9
  class UsageManager:
10
  """
 
43
  # Add yesterday's daily stats to global stats
44
  global_data = data.setdefault("global", {"models": {}})
45
  for model, stats in daily_data.get("models", {}).items():
46
+ global_model_stats = global_data["models"].setdefault(model, {"success_count": 0, "prompt_tokens": 0, "completion_tokens": 0, "approx_cost": 0.0})
47
  global_model_stats["success_count"] += stats.get("success_count", 0)
48
  global_model_stats["prompt_tokens"] += stats.get("prompt_tokens", 0)
49
  global_model_stats["completion_tokens"] += stats.get("completion_tokens", 0)
50
+ global_model_stats["approx_cost"] += stats.get("approx_cost", 0.0)
51
 
52
  # Reset daily stats
53
  data["daily"] = {"date": today_str, "models": {}}
 
84
 
85
  return best_key if best_key else active_keys[0]
86
 
87
+ def record_success(self, key: str, model: str, completion_response: litellm.ModelResponse):
88
  key_data = self.usage_data.setdefault(key, {"daily": {"date": date.today().isoformat(), "models": {}}, "global": {"models": {}}, "cooldown_until": None})
89
 
90
  # Ensure daily stats are for today
 
92
  self._reset_daily_stats_if_needed() # Should be rare, but as a safeguard
93
  key_data = self.usage_data[key]
94
 
95
+ daily_model_data = key_data["daily"]["models"].setdefault(model, {"success_count": 0, "prompt_tokens": 0, "completion_tokens": 0, "approx_cost": 0.0})
96
 
97
+ usage = completion_response.usage
98
  daily_model_data["success_count"] += 1
99
+ daily_model_data["prompt_tokens"] += usage.prompt_tokens
100
+ daily_model_data["completion_tokens"] += usage.completion_tokens
101
 
102
+ # Calculate approximate cost using LiteLLM
103
+ try:
104
+ cost = litellm.completion_cost(
105
+ completion_response=completion_response
106
+ )
107
+ daily_model_data["approx_cost"] += cost
108
+ except Exception as e:
109
+ print(f"Warning: Could not calculate cost for model {model}: {e}")
110
+
111
  key_data["last_used_ts"] = time.time()
112
  self._save_usage()
113