Commit
·
8bd8d2f
1
Parent(s):
bca0945
Filter litellm models to chat mode only
Browse files
- Add _litellm_chat_prices_cache for filtered models
- Rename get_litellm_prices() to get_litellm_prices_raw()
- Create wrapper get_litellm_prices() that filters mode='chat'
- All model lookups now return only chat models
app.py
CHANGED
|
@@ -50,6 +50,7 @@ def _log_unhandled(exc_type, exc_value, exc_traceback):
|
|
| 50 |
sys.excepthook = _log_unhandled
|
| 51 |
|
| 52 |
_litellm_prices_cache = None
|
|
|
|
| 53 |
_trajectories_cache = {}
|
| 54 |
_calculated_tokens_cache = {}
|
| 55 |
_trajectory_steps_cache = {}
|
|
@@ -391,13 +392,8 @@ def load_all_trajectory_steps(folder: str) -> dict[str, list[dict]]:
|
|
| 391 |
return result
|
| 392 |
|
| 393 |
|
| 394 |
-
def
|
| 395 |
-
"""Get
|
| 396 |
-
prices = get_litellm_prices()
|
| 397 |
-
return sorted(prices.keys())
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
def get_litellm_prices() -> dict:
|
| 401 |
global _litellm_prices_cache
|
| 402 |
if _litellm_prices_cache is not None:
|
| 403 |
return _litellm_prices_cache
|
|
@@ -421,6 +417,26 @@ def get_litellm_prices() -> dict:
|
|
| 421 |
return _litellm_prices_cache
|
| 422 |
|
| 423 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 424 |
def normalize_model_name(name: str) -> str:
|
| 425 |
"""Normalize model name for comparison: lowercase, remove separators"""
|
| 426 |
return re.sub(r'[-_./]', '', name.lower())
|
|
|
|
| 50 |
# Route uncaught exceptions through the logging handler defined above.
sys.excepthook = _log_unhandled

# Module-level caches, populated lazily on first use.
_litellm_prices_cache = None        # raw litellm price table (all modes, unfiltered)
_litellm_chat_prices_cache = None   # price table filtered to entries with mode == "chat"
_trajectories_cache = {}            # keyed caches — presumably by folder/path; TODO confirm against loaders
_calculated_tokens_cache = {}
_trajectory_steps_cache = {}
|
|
|
|
| 392 |
return result
|
| 393 |
|
| 394 |
|
| 395 |
+
def get_litellm_prices_raw() -> dict:
|
| 396 |
+
"""Get raw litellm prices (all modes, unfiltered)"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 397 |
global _litellm_prices_cache
|
| 398 |
if _litellm_prices_cache is not None:
|
| 399 |
return _litellm_prices_cache
|
|
|
|
| 417 |
return _litellm_prices_cache
|
| 418 |
|
| 419 |
|
| 420 |
+
def get_litellm_prices() -> dict:
    """Get litellm prices filtered to chat models only"""
    global _litellm_chat_prices_cache
    # Build the filtered table once; afterwards serve the cached copy.
    if _litellm_chat_prices_cache is None:
        all_prices = get_litellm_prices_raw()
        chat_only = {}
        for model, info in all_prices.items():
            # Skip non-dict entries (e.g. metadata keys) and non-chat modes.
            if isinstance(info, dict) and info.get("mode") == "chat":
                chat_only[model] = info
        _litellm_chat_prices_cache = chat_only
    return _litellm_chat_prices_cache
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
def get_litellm_model_list() -> list[str]:
    """Get list of chat model names from litellm prices"""
    # sorted() over a dict iterates its keys, so no explicit .keys() needed.
    return sorted(get_litellm_prices())
|
| 438 |
+
|
| 439 |
+
|
| 440 |
def normalize_model_name(name: str) -> str:
    """Normalize model name for comparison: lowercase, remove separators"""
    # Delete the separator characters -, _, . and / in one C-level pass.
    return name.lower().translate(str.maketrans("", "", "-_./"))
|