| |
| """ |
build_pool() → ordered (model_str, api_key) list, primary first then fallbacks
call_llm()   → tries each slot until one succeeds, raises RuntimeError if all fail
| """ |
|
|
| import os |
| import time |
| import logging |
|
|
| import litellm |
| from litellm.exceptions import ( |
| RateLimitError, |
| ServiceUnavailableError, |
| APIConnectionError, |
| ) |
|
|
| from models import FALLBACK_CHAIN, MAX_NEW_TOKENS |
|
|
# Module-level logger; call_llm() reports every slot attempt (success or
# failure) through it.
log: logging.Logger = logging.getLogger(__name__)


# Set once at import time, before any litellm.completion() call made here.
# NOTE(review): presumably silences litellm's extra debug/help printouts on
# errors — confirm against the litellm version in use.
litellm.suppress_debug_info = True
|
|
|
|
def build_pool(
    primary_model: str,
    primary_keys: list[str],
) -> list[tuple[str, str]]:
    """
    Return an ordered list of (model_str, resolved_api_key) pairs.

    Layout:
      1. primary_model paired with each non-blank key from primary_keys
      2. every other entry in FALLBACK_CHAIN paired with its non-blank keys
         (entries whose model equals primary_model are skipped)

    Exact duplicate (model_str, api_key) pairs are dropped, keeping the first
    occurrence, so the same key is never tried twice for the same model.

    Args:
        primary_model: model string tried first.
        primary_keys: NAMES of environment variables holding API keys for
            primary_model (not the key values themselves). Values are
            whitespace-stripped; unset or blank variables are ignored.

    Returns:
        The ordered, de-duplicated pool.

    Raises:
        RuntimeError: if no consulted environment variable holds a non-blank
            value. The message lists every variable that was checked.
    """
    # Collect (model, env_var) candidates in priority order before reading env.
    candidates: list[tuple[str, str]] = [
        (primary_model, env_var) for env_var in primary_keys
    ]
    for model_str, key_env_vars in FALLBACK_CHAIN:
        if model_str == primary_model:
            # The primary model is only paired with primary_keys.
            continue
        candidates.extend((model_str, env_var) for env_var in key_env_vars)

    pool: list[tuple[str, str]] = []
    seen: set[tuple[str, str]] = set()
    for model_str, env_var in candidates:
        key = os.environ.get(env_var, "").strip()
        if not key or (model_str, key) in seen:
            continue
        seen.add((model_str, key))
        pool.append((model_str, key))

    if not pool:
        # Name the variables actually consulted instead of a fixed list that
        # can drift from FALLBACK_CHAIN.
        checked = ", ".join(dict.fromkeys(env for _, env in candidates))
        raise RuntimeError(
            "No API keys found in environment. "
            f"Set at least one of: {checked or '(no key variables configured)'}."
        )

    return pool
|
|
|
|
| |
# Transient provider failures: the loop pauses briefly before the next slot.
_RETRIABLE = (RateLimitError, ServiceUnavailableError, APIConnectionError)


def call_llm(
    pool: list[tuple[str, str]],
    messages: list[dict],
    max_tokens: int = MAX_NEW_TOKENS,
    *,
    num_retries: int = 1,
    timeout: float = 120,
    retry_delay: float = 0.5,
) -> str:
    """
    Send *messages* to each (model_str, api_key) slot in *pool* until one succeeds.

    Slots are tried in order. Any exception raised for a slot moves on to the
    next slot. After a transient error (one of _RETRIABLE) the loop first
    sleeps *retry_delay* seconds, unless that was the last slot. A response
    whose first choice carries no text (``message.content is None``) also
    counts as a failed slot, so callers always get a ``str``.

    Args:
        pool: ordered (model string, API key) pairs, e.g. from build_pool().
        messages: chat messages passed unchanged to litellm.completion().
        max_tokens: completion token cap for each request.
        num_retries: retries litellm itself performs within a single slot.
        timeout: per-request timeout in seconds, passed to litellm.
        retry_delay: seconds to wait after a transient error before moving on.

    Returns:
        The text content of the first successful response.

    Raises:
        RuntimeError: if every slot fails or *pool* is empty. The last
            underlying error is chained as ``__cause__``.
    """
    last_exc: Exception = RuntimeError("Pool was empty")
    total = len(pool)

    for idx, (model_str, api_key) in enumerate(pool, start=1):
        next_step = "trying next slot" if idx < total else "no slots left"
        try:
            resp = litellm.completion(
                model=model_str,
                messages=messages,
                max_tokens=max_tokens,
                api_key=api_key,
                num_retries=num_retries,
                timeout=timeout,
            )
            content = resp.choices[0].message.content
        except _RETRIABLE as exc:
            log.warning(
                "Retriable %s on %s (slot %d/%d): %s - %s",
                type(exc).__name__,
                model_str,
                idx,
                total,
                exc,
                next_step,
            )
            last_exc = exc
            if idx < total:
                time.sleep(retry_delay)
            continue
        except Exception as exc:
            # Non-transient failures (auth, bad request, malformed response)
            # may still succeed on a different key/model, so keep going.
            log.warning(
                "Hard error (%s) on %s (slot %d/%d): %s - %s",
                type(exc).__name__,
                model_str,
                idx,
                total,
                exc,
                next_step,
            )
            last_exc = exc
            continue

        if content is None:
            # message.content is nullable in chat-completion responses;
            # returning None would break the -> str contract.
            log.warning(
                "Empty response from %s (slot %d/%d) - %s",
                model_str,
                idx,
                total,
                next_step,
            )
            last_exc = RuntimeError(f"{model_str} returned no message content")
            continue

        log.warning("SUCCESS on %s (slot %d/%d)", model_str, idx, total)
        return content

    raise RuntimeError(
        f"All {total} provider slots exhausted. Last error: {last_exc}"
    ) from last_exc
|
|