""" Enhanced LLM Router — Multi-provider with smart key rotation, cooldown, and failover. Providers: - gemini (Google Generative Language API) — 6 keys - sambanova (SambaNova OpenAI-compatible) — 9 keys - github_gpt (GitHub Models, OpenAI-compatible) — 9 keys Key loading (priority order): 1. Env var GEMINI_KEY / SAMBANOVA_KEY / GITHUB_KEY (comma-separated) 2. Hardcoded fallback pool Rotation & healing: - Round-robin across keys per provider - Per-key failure counter with 5-min cooldown after MAX_FAILURES_BEFORE_COOLDOWN - Auto-heal: keys automatically re-enter the pool after cooldown expires - Provider failover: if all keys for a provider exhaust, try next provider Task-aware routing: - classify_task() maps prompt → task_type - provider_order() picks optimal provider order per task type """ from __future__ import annotations import asyncio import json import logging import os import time from dataclasses import dataclass, field from typing import Any, AsyncIterator, Dict, List, Optional import httpx logger = logging.getLogger(__name__) # ─── Constants ──────────────────────────────────────────────────────────────── MAX_PROVIDER_RETRY = 3 MAX_FAILURES_BEFORE_COOLDOWN = 3 COOLDOWN_SECONDS = 300 # 5 min cooldown HARD_FAIL_COOLDOWN = 600 # 10 min for auth errors REQUEST_TIMEOUT_SECONDS = 120.0 STREAM_TIMEOUT_SECONDS = 600.0 # ─── Key fallback pools ─────────────────────────────────────────────────────── _GEMINI_FALLBACK = [ "AIzaSyCyIZthDgVUUtCiMycqi42VrY6PTUNG9HQ", "AIzaSyAhXY6rLF0GvN4gn6bQSQ9VyGbD4iRX-x4", "AIzaSyD5TI9VjCL3Mc8OE3qU_sbMA0ZA727fwFc", "AIzaSyC5X0cgWyb0YLzlnBLo2ESKgQNNPgs_DHs", "AIzaSyBAmPWoqnOCG740wq1JHKuHP5g-GeQFx24", "AIzaSyCLn3OBoGwKKBZzP6lcATuF__H2jsv94cg", ] _SAMBANOVA_FALLBACK = [ "0fea3265-9949-413b-a4d4-5976f18b64e3", "2c19ca9f-6b6a-4658-a20f-4dcf1b58cdc3", "2747e1fb-62b9-4a32-a072-c4c1f767584c", "460f63de-ec38-4c91-99ee-54e6d23de589", "b99cf0d1-3798-41a4-af48-84f32e73596d", "a0e9baa0-8759-411a-a6e7-420eb5d9e419", "2359f623-debd-4ad2-af37-dd232928a04f", "936e4b3c-6373-4fc4-b6aa-d6571635266a", "f30a397b-27a9-45bf-94b9-829a0d5c6cf1", ] _GITHUB_FALLBACK = [ "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", "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", "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", "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", "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", "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", "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", "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", "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", ] def _load_keys(env_var: str, fallback: list[str]) -> list[str]: raw = os.environ.get(env_var, "").strip() if raw: keys = [k.strip() for k in raw.split(",") if k.strip()] if keys: logger.info("Loaded %d keys for %s from env", len(keys), env_var) return keys logger.info("Using %d fallback keys for %s", len(fallback), env_var) return list(fallback) # ─── Provider definitions ───────────────────────────────────────────────────── @dataclass class ProviderConfig: name: str kind: str # 'gemini' | 'openai' url: str key_env: str model: str fallback_keys: list[str] stream_supported: bool = True PROVIDERS: Dict[str, ProviderConfig] = { "gemini": ProviderConfig( name="gemini", kind="gemini", url="https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent", key_env="GEMINI_KEY", model="gemini-2.0-flash", fallback_keys=_GEMINI_FALLBACK, stream_supported=False, ), "sambanova": ProviderConfig( name="sambanova", kind="openai", url="https://api.sambanova.ai/v1/chat/completions", key_env="SAMBANOVA_KEY", model="Meta-Llama-3.3-70B-Instruct", fallback_keys=_SAMBANOVA_FALLBACK, stream_supported=True, ), "github_gpt4o": ProviderConfig( name="github_gpt4o", kind="openai", url="https://models.inference.ai.azure.com/chat/completions", key_env="GITHUB_KEY", model="gpt-4o", fallback_keys=_GITHUB_FALLBACK, stream_supported=True, ), } # Model overrides per provider (selectable by client) MODEL_MAP: Dict[str, Dict[str, str]] = { "gemini": { "default": "gemini-2.0-flash", "fast": "gemini-1.5-flash", "pro": "gemini-1.5-pro", "think": "gemini-2.0-flash-thinking-exp", }, "sambanova": { "default": "Meta-Llama-3.3-70B-Instruct", "large": "Meta-Llama-3.1-405B-Instruct", "deepseek": "DeepSeek-R1", "qwen": "Qwen2.5-72B-Instruct", }, "github_gpt4o": { "default": "gpt-4o", "mini": "gpt-4o-mini", "llama": "Meta-Llama-3.1-70B-Instruct", "mistral": "Mistral-large-2407", }, } # ─── Key pool ───────────────────────────────────────────────────────────────── @dataclass class KeyState: key: str fail_count: int = 0 cooldown_until: float = 0.0 def is_ready(self) -> bool: return self.cooldown_until <= time.time() def to_dict(self) -> dict: return { "suffix": f"...{self.key[-6:]}", "fail_count": self.fail_count, "available": self.is_ready(), "cooldown_remaining": max(0.0, self.cooldown_until - time.time()), } class KeyPool: """Round-robin key pool with failure tracking & auto-heal cooldown.""" def __init__(self, keys: List[str]) -> None: self._keys: List[KeyState] = [KeyState(k.strip()) for k in keys if k.strip()] self._cursor = 0 def __bool__(self) -> bool: return len(self._keys) > 0 def pick(self) -> Optional[KeyState]: if not self._keys: return None n = len(self._keys) for _ in range(n): ks = self._keys[self._cursor % n] self._cursor += 1 if ks.is_ready(): return ks return None @staticmethod def mark_success(ks: KeyState) -> None: ks.fail_count = 0 ks.cooldown_until = 0.0 @staticmethod def mark_failure(ks: KeyState, status_code: int = 0) -> None: ks.fail_count += 1 if status_code in (401, 403): # Auth error — hard cooldown ks.cooldown_until = time.time() + HARD_FAIL_COOLDOWN logger.warning("Key auth error (HTTP %d) — hard cooldown %ds", status_code, HARD_FAIL_COOLDOWN) elif ks.fail_count >= MAX_FAILURES_BEFORE_COOLDOWN: ks.cooldown_until = time.time() + COOLDOWN_SECONDS logger.warning("Key cooled for %ds (fail_count=%d)", COOLDOWN_SECONDS, ks.fail_count) def status(self) -> dict: now = time.time() return { "total": len(self._keys), "available": sum(1 for k in self._keys if k.is_ready()), "keys": [k.to_dict() for k in self._keys], } # Cache pools so cooldown state survives across requests _POOL_CACHE: Dict[str, KeyPool] = {} def get_pool(provider: ProviderConfig) -> KeyPool: if provider.name not in _POOL_CACHE: keys = _load_keys(provider.key_env, provider.fallback_keys) _POOL_CACHE[provider.name] = KeyPool(keys) return _POOL_CACHE[provider.name] # ─── Task classification → provider order ───────────────────────────────────── def classify_task(prompt: str) -> str: p = (prompt or "").lower() if any(w in p for w in ("workflow", "automation", "pipeline", "orchestrat")): return "planning" if any(w in p for w in ("code", "python", "javascript", "typescript", "function", "api", "build", "debug", "fix", "refactor", "test")): return "engineering" if any(w in p for w in ("why", "analyze", "analyse", "explain", "reason", "think", "evaluate", "compare")): return "reasoning" if any(w in p for w in ("translate", "summarize", "summarise", "summary", "rewrite", "paraphrase")): return "language" if any(w in p for w in ("math", "calculate", "solve", "equation", "formula")): return "math" return "general" def provider_order(prompt: str) -> List[str]: task = classify_task(prompt) orders = { "engineering": ["sambanova", "github_gpt4o", "gemini"], "reasoning": ["sambanova", "github_gpt4o", "gemini"], "planning": ["github_gpt4o", "sambanova", "gemini"], "math": ["sambanova", "github_gpt4o", "gemini"], "language": ["gemini", "sambanova", "github_gpt4o"], "general": ["gemini", "sambanova", "github_gpt4o"], } return orders.get(task, ["gemini", "sambanova", "github_gpt4o"]) # ─── Provider callers ───────────────────────────────────────────────────────── def _gemini_body(messages: List[Dict[str, str]], model: str, temperature: float, max_tokens: int) -> tuple[str, Dict]: contents = [] system_parts: List[str] = [] for m in messages: role = m.get("role") content = m.get("content", "") if role == "system": system_parts.append(content) continue gem_role = "user" if role == "user" else "model" contents.append({"role": gem_role, "parts": [{"text": content}]}) body: Dict[str, Any] = { "contents": contents, "generationConfig": { "temperature": temperature, "maxOutputTokens": max_tokens, }, } if system_parts: body["systemInstruction"] = {"parts": [{"text": "\n".join(system_parts)}]} # Build URL with model url = PROVIDERS["gemini"].url.format(model=model) return url, body def _extract_text(provider: ProviderConfig, data: Dict[str, Any]) -> str: if provider.kind == "gemini": try: return data["candidates"][0]["content"]["parts"][0]["text"] except (KeyError, IndexError, TypeError): return "" try: return data["choices"][0]["message"]["content"] or "" except (KeyError, IndexError, TypeError): return "" async def _call_once( client: httpx.AsyncClient, provider: ProviderConfig, key: str, messages: List[Dict[str, str]], temperature: float = 0.4, max_tokens: int = 2048, model: Optional[str] = None, ) -> tuple[str, int]: """Returns (text, status_code).""" use_model = model or provider.model if provider.kind == "gemini": url, body = _gemini_body(messages, use_model, temperature, max_tokens) url = f"{url}?key={key}" r = await client.post(url, json=body, timeout=REQUEST_TIMEOUT_SECONDS) else: headers = { "Authorization": f"Bearer {key}", "Content-Type": "application/json", } body = { "model": use_model, "messages": messages, "temperature": temperature, "max_tokens": max_tokens, } r = await client.post( provider.url, headers=headers, json=body, timeout=REQUEST_TIMEOUT_SECONDS ) if r.status_code >= 400: raise RuntimeError(f"{provider.name} HTTP {r.status_code}: {r.text[:200]}") return _extract_text(provider, r.json()), r.status_code async def complete( messages: List[Dict[str, str]], *, temperature: float = 0.4, max_tokens: int = 2048, preferred_provider: Optional[str] = None, model: Optional[str] = None, ) -> Dict[str, Any]: """Non-streaming completion with provider/key failover. Returns: {"content": str, "provider": str, "model": str} """ prompt_text = "\n".join(m.get("content", "") for m in messages if m.get("role") == "user") order = provider_order(prompt_text) # Honor preferred_provider if specified if preferred_provider and preferred_provider in PROVIDERS: order = [preferred_provider] + [p for p in order if p != preferred_provider] last_err: Optional[str] = None async with httpx.AsyncClient() as client: for provider_name in order: provider = PROVIDERS[provider_name] pool = get_pool(provider) if not pool: continue for _ in range(MAX_PROVIDER_RETRY): ks = pool.pick() if ks is None: break try: text, status = await _call_once( client, provider, ks.key, messages, temperature=temperature, max_tokens=max_tokens, model=model, ) if not text.strip(): raise RuntimeError("empty completion") KeyPool.mark_success(ks) used_model = model or provider.model return { "content": text, "provider": provider.name, "model": used_model, } except Exception as e: last_err = f"{provider.name}: {e}" status_code = 0 if "HTTP " in str(e): try: status_code = int(str(e).split("HTTP ")[1].split(":")[0]) except Exception: pass logger.warning("LLM call failed → %s", last_err) KeyPool.mark_failure(ks, status_code) raise RuntimeError(f"ALL_PROVIDERS_FAILED ({last_err})") # ─── Streaming ──────────────────────────────────────────────────────────────── async def _stream_openai( client: httpx.AsyncClient, provider: ProviderConfig, key: str, messages: List[Dict[str, str]], temperature: float, max_tokens: int, model: Optional[str] = None, ) -> AsyncIterator[str]: use_model = model or provider.model headers = { "Authorization": f"Bearer {key}", "Content-Type": "application/json", "Accept": "text/event-stream", } body = { "model": use_model, "messages": messages, "temperature": temperature, "max_tokens": max_tokens, "stream": True, } async with client.stream( "POST", provider.url, headers=headers, json=body, timeout=STREAM_TIMEOUT_SECONDS, ) as r: if r.status_code >= 400: err_text = (await r.aread()).decode("utf-8", "ignore")[:200] raise RuntimeError(f"{provider.name} HTTP {r.status_code}: {err_text}") async for line in r.aiter_lines(): if not line or not line.startswith("data:"): continue payload = line[5:].strip() if payload == "[DONE]": break try: obj = json.loads(payload) delta = obj["choices"][0]["delta"].get("content") if delta: yield delta except Exception: continue async def stream_complete( messages: List[Dict[str, str]], *, temperature: float = 0.4, max_tokens: int = 2048, preferred_provider: Optional[str] = None, model: Optional[str] = None, ) -> AsyncIterator[Dict[str, Any]]: """Yield {'type':'delta','content':str,'provider':str} chunks, then {'type':'done', ...}.""" prompt_text = "\n".join(m.get("content", "") for m in messages if m.get("role") == "user") order = provider_order(prompt_text) if preferred_provider and preferred_provider in PROVIDERS: order = [preferred_provider] + [p for p in order if p != preferred_provider] last_err: Optional[str] = None async with httpx.AsyncClient() as client: for provider_name in order: provider = PROVIDERS[provider_name] pool = get_pool(provider) if not pool: continue for _ in range(MAX_PROVIDER_RETRY): ks = pool.pick() if ks is None: break try: if provider.stream_supported: got_any = False async for delta in _stream_openai( client, provider, ks.key, messages, temperature, max_tokens, model=model ): got_any = True yield {"type": "delta", "content": delta, "provider": provider.name} if not got_any: raise RuntimeError("empty stream") else: # Gemini fallback: non-streaming text, _ = await _call_once( client, provider, ks.key, messages, temperature=temperature, max_tokens=max_tokens, model=model, ) if not text.strip(): raise RuntimeError("empty completion") yield {"type": "delta", "content": text, "provider": provider.name} KeyPool.mark_success(ks) yield { "type": "done", "provider": provider.name, "model": model or provider.model, } return except Exception as e: last_err = f"{provider.name}: {e}" logger.warning("LLM stream failed → %s", last_err) KeyPool.mark_failure(ks) yield {"type": "error", "error": f"ALL_PROVIDERS_FAILED ({last_err})"} def pool_status() -> Dict[str, Any]: """Diagnostic info about each provider's key pool.""" out: Dict[str, Any] = {} for name, provider in PROVIDERS.items(): pool = get_pool(provider) out[name] = { "model": provider.model, **pool.status(), } return out