IgorSlinko commited on
Commit
8bd8d2f
·
1 Parent(s): bca0945

Filter litellm models to chat mode only

Browse files

- Add _litellm_chat_prices_cache for filtered models
- Rename get_litellm_prices() to get_litellm_prices_raw()
- Create wrapper get_litellm_prices() that filters mode='chat'
- All model lookups now return only chat models

Files changed (1) hide show
  1. app.py +23 -7
app.py CHANGED
@@ -50,6 +50,7 @@ def _log_unhandled(exc_type, exc_value, exc_traceback):
50
  sys.excepthook = _log_unhandled
51
 
52
  _litellm_prices_cache = None
 
53
  _trajectories_cache = {}
54
  _calculated_tokens_cache = {}
55
  _trajectory_steps_cache = {}
@@ -391,13 +392,8 @@ def load_all_trajectory_steps(folder: str) -> dict[str, list[dict]]:
391
  return result
392
 
393
 
394
- def get_litellm_model_list() -> list[str]:
395
- """Get list of model names from litellm prices"""
396
- prices = get_litellm_prices()
397
- return sorted(prices.keys())
398
-
399
-
400
- def get_litellm_prices() -> dict:
401
  global _litellm_prices_cache
402
  if _litellm_prices_cache is not None:
403
  return _litellm_prices_cache
@@ -421,6 +417,26 @@ def get_litellm_prices() -> dict:
421
  return _litellm_prices_cache
422
 
423
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
424
  def normalize_model_name(name: str) -> str:
425
  """Normalize model name for comparison: lowercase, remove separators"""
426
  return re.sub(r'[-_./]', '', name.lower())
 
50
  sys.excepthook = _log_unhandled
51
 
52
  _litellm_prices_cache = None
53
+ _litellm_chat_prices_cache = None
54
  _trajectories_cache = {}
55
  _calculated_tokens_cache = {}
56
  _trajectory_steps_cache = {}
 
392
  return result
393
 
394
 
395
+ def get_litellm_prices_raw() -> dict:
396
+ """Get raw litellm prices (all modes, unfiltered)"""
 
 
 
 
 
397
  global _litellm_prices_cache
398
  if _litellm_prices_cache is not None:
399
  return _litellm_prices_cache
 
417
  return _litellm_prices_cache
418
 
419
 
420
+ def get_litellm_prices() -> dict:
421
+ """Get litellm prices filtered to chat models only"""
422
+ global _litellm_chat_prices_cache
423
+ if _litellm_chat_prices_cache is not None:
424
+ return _litellm_chat_prices_cache
425
+
426
+ raw_prices = get_litellm_prices_raw()
427
+ _litellm_chat_prices_cache = {
428
+ k: v for k, v in raw_prices.items()
429
+ if isinstance(v, dict) and v.get("mode") == "chat"
430
+ }
431
+ return _litellm_chat_prices_cache
432
+
433
+
434
+ def get_litellm_model_list() -> list[str]:
435
+ """Get list of chat model names from litellm prices"""
436
+ prices = get_litellm_prices()
437
+ return sorted(prices.keys())
438
+
439
+
440
  def normalize_model_name(name: str) -> str:
441
  """Normalize model name for comparison: lowercase, remove separators"""
442
  return re.sub(r'[-_./]', '', name.lower())