openhands-backend / llm_router.py
PYAE1994
🚀 Phase 3: Supabase + Redis SSE + WebSocket + enhanced smart routing (6 Gemini / 9 SambaNova / 9 GitHub keys)
9b18003
"""
Enhanced LLM Router — multi-provider with smart key rotation, cooldown, and failover.
Providers:
- gemini (Google Generative Language API) — 6 keys
- sambanova (SambaNova, OpenAI-compatible) — 9 keys
- github_gpt4o (GitHub Models, OpenAI-compatible) — 9 keys
Key loading (priority order):
1. Env var GEMINI_KEY / SAMBANOVA_KEY / GITHUB_KEY (comma-separated)
2. Hardcoded fallback pool
Rotation & healing:
- Round-robin across keys per provider
- Per-key failure counter with 5-min cooldown after MAX_FAILURES_BEFORE_COOLDOWN
- Auto-heal: keys automatically re-enter the pool after cooldown expires
- Provider failover: if all keys for a provider exhaust, try next provider
Task-aware routing:
- classify_task() maps prompt → task_type
- provider_order() picks optimal provider order per task type
"""
from __future__ import annotations
import asyncio
import json
import logging
import os
import time
from dataclasses import dataclass, field
from typing import Any, AsyncIterator, Dict, List, Optional
import httpx
logger = logging.getLogger(__name__)
# ─── Constants ────────────────────────────────────────────────────────────────
# Attempts made against one provider before failing over to the next one.
MAX_PROVIDER_RETRY = 3
# Failures a key may accumulate (since its last success) before it cools down.
MAX_FAILURES_BEFORE_COOLDOWN = 3
# Cooldown lengths in seconds, applied by KeyPool.mark_failure.
COOLDOWN_SECONDS = 5 * 60      # repeated ordinary failures: 5 min
HARD_FAIL_COOLDOWN = 10 * 60   # HTTP 401/403 auth errors: 10 min
# httpx timeouts in seconds for one-shot and streaming requests.
REQUEST_TIMEOUT_SECONDS = 2 * 60.0
STREAM_TIMEOUT_SECONDS = 10 * 60.0
# ─── Key fallback pools ───────────────────────────────────────────────────────
# SECURITY: the hardcoded API keys that used to live here have been removed.
# Credentials committed to a repository are readable by anyone with access to
# it (including its history), so the previously listed keys must be treated
# as compromised and revoked at each provider.
# Supply keys through the GEMINI_KEY / SAMBANOVA_KEY / GITHUB_KEY environment
# variables (comma-separated) instead. With an empty fallback, a provider whose
# variable is unset gets an empty KeyPool and is skipped during failover.
_GEMINI_FALLBACK: list[str] = []
_SAMBANOVA_FALLBACK: list[str] = []
_GITHUB_FALLBACK: list[str] = []
def _load_keys(env_var: str, fallback: list[str]) -> list[str]:
    """Return the API keys configured for *env_var*, else a copy of *fallback*.

    The environment variable is parsed as a comma-separated list; whitespace
    around each entry is trimmed and empty entries are discarded. If nothing
    usable remains, a new list with the *fallback* keys is returned so the
    caller cannot mutate the module-level fallback pool.
    """
    from_env = [
        piece.strip()
        for piece in os.environ.get(env_var, "").split(",")
        if piece.strip()
    ]
    if not from_env:
        logger.info("Using %d fallback keys for %s", len(fallback), env_var)
        return list(fallback)
    logger.info("Loaded %d keys for %s from env", len(from_env), env_var)
    return from_env
# ─── Provider definitions ─────────────────────────────────────────────────────
@dataclass
class ProviderConfig:
    """Static description of one LLM provider endpoint.

    ``kind`` selects the wire format: ``'gemini'`` uses the Google Generative
    Language request/response shape, ``'openai'`` the OpenAI-compatible
    chat-completions shape. ``url`` may contain a ``{model}`` placeholder
    (the Gemini entry does). ``key_env`` names the environment variable that
    holds comma-separated keys; ``fallback_keys`` is used when it yields none.
    """

    name: str
    kind: str  # 'gemini' | 'openai'
    url: str
    key_env: str
    model: str  # used when the caller does not pass an explicit model
    fallback_keys: list[str]
    # False => stream_complete makes one non-streaming call for this provider.
    stream_supported: bool = field(default=True)
# Registry of supported providers, keyed by each config's ``name``.
PROVIDERS: Dict[str, ProviderConfig] = {
    cfg.name: cfg
    for cfg in (
        # Google Generative Language API; ``{model}`` is substituted per request.
        # No streaming path: stream_complete makes one blocking call instead.
        ProviderConfig(
            name="gemini",
            kind="gemini",
            url="https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent",
            key_env="GEMINI_KEY",
            model="gemini-2.0-flash",
            fallback_keys=_GEMINI_FALLBACK,
            stream_supported=False,
        ),
        # SambaNova, OpenAI-compatible chat completions.
        ProviderConfig(
            name="sambanova",
            kind="openai",
            url="https://api.sambanova.ai/v1/chat/completions",
            key_env="SAMBANOVA_KEY",
            model="Meta-Llama-3.3-70B-Instruct",
            fallback_keys=_SAMBANOVA_FALLBACK,
            stream_supported=True,
        ),
        # GitHub Models inference endpoint, OpenAI-compatible chat completions.
        ProviderConfig(
            name="github_gpt4o",
            kind="openai",
            url="https://models.inference.ai.azure.com/chat/completions",
            key_env="GITHUB_KEY",
            model="gpt-4o",
            fallback_keys=_GITHUB_FALLBACK,
            stream_supported=True,
        ),
    )
}
# Model overrides per provider (selectable by client)
# Outer keys mirror the PROVIDERS names; inner keys are short aliases for
# concrete model ids. NOTE(review): nothing in this module reads MODEL_MAP —
# presumably callers resolve a client-chosen alias here; confirm.
MODEL_MAP: Dict[str, Dict[str, str]] = {
    "gemini": {
        "default":  "gemini-2.0-flash",
        "fast":     "gemini-1.5-flash",
        "pro":      "gemini-1.5-pro",
        "think":    "gemini-2.0-flash-thinking-exp",
    },
    "sambanova": {
        "default":  "Meta-Llama-3.3-70B-Instruct",
        "large":    "Meta-Llama-3.1-405B-Instruct",
        "deepseek": "DeepSeek-R1",
        "qwen":     "Qwen2.5-72B-Instruct",
    },
    "github_gpt4o": {
        "default":  "gpt-4o",
        "mini":     "gpt-4o-mini",
        "llama":    "Meta-Llama-3.1-70B-Instruct",
        "mistral":  "Mistral-large-2407",
    },
}
# ─── Key pool ─────────────────────────────────────────────────────────────────
@dataclass
class KeyState:
    """Mutable health record for a single API key held by a KeyPool."""

    key: str
    fail_count: int = 0  # failures recorded since the key's last success
    cooldown_until: float = 0.0  # time.time() deadline; unusable before it

    def is_ready(self) -> bool:
        """Return True once the cooldown deadline (if any) has passed."""
        return time.time() >= self.cooldown_until

    def to_dict(self) -> dict:
        """Diagnostic snapshot exposing only the last 6 characters of the key."""
        snapshot = {"suffix": "..." + self.key[-6:], "fail_count": self.fail_count}
        snapshot["available"] = self.is_ready()
        snapshot["cooldown_remaining"] = max(0.0, self.cooldown_until - time.time())
        return snapshot
class KeyPool:
    """Round-robin key pool with failure tracking & auto-heal cooldown.

    Keys on cooldown are skipped by :meth:`pick` and become eligible again on
    their own once ``cooldown_until`` has passed; no explicit reset is needed.
    Success/failure bookkeeping goes through the static ``mark_*`` helpers, so
    callers only need the ``KeyState`` they picked.
    """

    def __init__(self, keys: List[str]) -> None:
        # Whitespace is trimmed and blank entries are dropped.
        self._keys: List[KeyState] = [KeyState(k.strip()) for k in keys if k.strip()]
        # Ever-increasing counter, reduced modulo len(self._keys) on each lookup.
        self._cursor = 0

    def __bool__(self) -> bool:
        """True when the pool holds at least one key (ready or not)."""
        return len(self._keys) > 0

    def pick(self) -> Optional[KeyState]:
        """Return the next ready key in round-robin order, or None.

        At most one full lap over the keys is scanned. None means the pool is
        empty or every key is currently cooling down.
        """
        if not self._keys:
            return None
        n = len(self._keys)
        for _ in range(n):
            ks = self._keys[self._cursor % n]
            self._cursor += 1
            if ks.is_ready():
                return ks
        return None

    @staticmethod
    def mark_success(ks: KeyState) -> None:
        """Clear the key's failure count and any cooldown."""
        ks.fail_count = 0
        ks.cooldown_until = 0.0

    @staticmethod
    def mark_failure(ks: KeyState, status_code: int = 0) -> None:
        """Record a failure for *ks* and put it on cooldown when warranted.

        HTTP 401/403 cool the key for HARD_FAIL_COOLDOWN immediately. Any other
        failure (``status_code`` 0 = unknown) cools it for COOLDOWN_SECONDS
        once ``fail_count`` reaches MAX_FAILURES_BEFORE_COOLDOWN.
        """
        ks.fail_count += 1
        if status_code in (401, 403):
            # Auth error — hard cooldown
            ks.cooldown_until = time.time() + HARD_FAIL_COOLDOWN
            logger.warning("Key auth error (HTTP %d) β€” hard cooldown %ds", status_code, HARD_FAIL_COOLDOWN)
        elif ks.fail_count >= MAX_FAILURES_BEFORE_COOLDOWN:
            # fail_count is only reset by mark_success, so a key that healed
            # without succeeding goes straight back on cooldown if it fails again.
            ks.cooldown_until = time.time() + COOLDOWN_SECONDS
            logger.warning("Key cooled for %ds (fail_count=%d)", COOLDOWN_SECONDS, ks.fail_count)

    def status(self) -> dict:
        """Summarise pool health; individual keys are reported via KeyState.to_dict."""
        return {
            "total": len(self._keys),
            "available": sum(1 for k in self._keys if k.is_ready()),
            "keys": [k.to_dict() for k in self._keys],
        }
# Cache pools so cooldown state survives across requests
_POOL_CACHE: Dict[str, KeyPool] = {}


def get_pool(provider: ProviderConfig) -> KeyPool:
    """Return the cached KeyPool for *provider*, building it on first use.

    Keys are loaded (env var first, then the fallback list) only the first
    time a provider name is seen. Later calls return that same pool object,
    so per-key failure and cooldown state carries over between requests.
    """
    pool = _POOL_CACHE.get(provider.name)
    if pool is None:  # identity check: an empty KeyPool is falsy but cached
        pool = KeyPool(_load_keys(provider.key_env, provider.fallback_keys))
        _POOL_CACHE[provider.name] = pool
    return pool
# ─── Task classification → provider order ─────────────────────────────────────
def classify_task(prompt: str) -> str:
    """Map a prompt to a coarse task type used to order providers.

    Returns one of ``planning``, ``engineering``, ``reasoning``, ``language``,
    ``math`` or ``general``. Categories are checked in the order below. The
    first one with any keyword contained in the lower-cased prompt wins, so a
    prompt mentioning both "workflow" and "python" is ``planning``.

    NOTE(review): matching is plain substring containment, so e.g. "latest"
    (contains "test") or "capital" (contains "api") classify as
    ``engineering``. This only changes which provider is tried first.
    """
    text = (prompt or "").lower()
    rules = (
        ("planning", ("workflow", "automation", "pipeline", "orchestrat")),
        ("engineering", ("code", "python", "javascript", "typescript", "function",
                         "api", "build", "debug", "fix", "refactor", "test")),
        ("reasoning", ("why", "analyze", "analyse", "explain", "reason",
                       "think", "evaluate", "compare")),
        ("language", ("translate", "summarize", "summarise", "summary",
                      "rewrite", "paraphrase")),
        ("math", ("math", "calculate", "solve", "equation", "formula")),
    )
    return next(
        (task for task, words in rules if any(w in text for w in words)),
        "general",
    )
def provider_order(prompt: str) -> List[str]:
    """Return provider names to try, best first, for the prompt's task type.

    A new list is built on every call, so callers may modify the result.
    """
    task = classify_task(prompt)
    if task == "planning":
        return ["github_gpt4o", "sambanova", "gemini"]
    if task in ("engineering", "reasoning", "math"):
        return ["sambanova", "github_gpt4o", "gemini"]
    # "language", "general" and any unrecognised task share the default order.
    return ["gemini", "sambanova", "github_gpt4o"]
# ─── Provider callers ─────────────────────────────────────────────────────────
def _gemini_body(messages: List[Dict[str, str]], model: str, temperature: float, max_tokens: int) -> tuple[str, Dict]:
    """Convert chat messages into a Gemini ``generateContent`` request.

    Returns ``(url, body)``. *url* is the Gemini provider URL with *model*
    substituted; the API key is not attached here. *body* holds the
    conversation turns plus generation settings.

    Role mapping: ``system`` messages are joined with newlines into one
    ``systemInstruction``; ``user`` stays ``user``; every other role
    (e.g. ``assistant``) is sent as ``model``.
    """
    system_texts: List[str] = []
    turns: List[Dict[str, Any]] = []
    for msg in messages:
        role = msg.get("role")
        text = msg.get("content", "")
        if role != "system":
            turns.append({"role": "user" if role == "user" else "model", "parts": [{"text": text}]})
        else:
            system_texts.append(text)
    body: Dict[str, Any] = {
        "contents": turns,
        "generationConfig": {"temperature": temperature, "maxOutputTokens": max_tokens},
    }
    if system_texts:
        body["systemInstruction"] = {"parts": [{"text": "\n".join(system_texts)}]}
    url = PROVIDERS["gemini"].url.format(model=model)
    return url, body
def _extract_text(provider: ProviderConfig, data: Dict[str, Any]) -> str:
    """Pull the assistant text out of a provider's decoded JSON response.

    Gemini: concatenates the ``text`` of every part of the first candidate,
    skipping parts flagged ``thought`` (model reasoning rather than answer
    text). Reading only ``parts[0]`` would truncate answers split across
    several parts, and would return "" when the first part is not text.
    OpenAI-compatible: returns ``choices[0].message.content`` (None -> "").

    Returns "" whenever the expected structure is missing.
    """
    if provider.kind == "gemini":
        try:
            parts = data["candidates"][0]["content"]["parts"]
        except (KeyError, IndexError, TypeError):
            return ""
        if not isinstance(parts, list):
            return ""
        return "".join(
            part["text"]
            for part in parts
            if isinstance(part, dict)
            and isinstance(part.get("text"), str)
            and not part.get("thought")
        )
    try:
        return data["choices"][0]["message"]["content"] or ""
    except (KeyError, IndexError, TypeError):
        return ""
async def _call_once(
    client: httpx.AsyncClient,
    provider: ProviderConfig,
    key: str,
    messages: List[Dict[str, str]],
    temperature: float = 0.4,
    max_tokens: int = 2048,
    model: Optional[str] = None,
) -> tuple[str, int]:
    """Make one non-streaming completion request with a single key.

    Returns ``(text, status_code)``. *text* may be empty if the response had
    no extractable content. Any status >= 400 raises
    ``RuntimeError("<provider> HTTP <code>: <first 200 chars of body>")``,
    a format callers parse to penalise the key. Transport and timeout errors
    from httpx propagate unchanged.

    The Gemini key travels in the ``x-goog-api-key`` header rather than a
    ``?key=`` query parameter. Request URLs end up in logs (httpx logs each
    request URL at INFO level), which would leak the key.
    """
    use_model = model or provider.model
    if provider.kind == "gemini":
        url, body = _gemini_body(messages, use_model, temperature, max_tokens)
        headers = {
            "x-goog-api-key": key,
            "Content-Type": "application/json",
        }
    else:
        url = provider.url
        headers = {
            "Authorization": f"Bearer {key}",
            "Content-Type": "application/json",
        }
        body = {
            "model": use_model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
    r = await client.post(url, headers=headers, json=body, timeout=REQUEST_TIMEOUT_SECONDS)
    if r.status_code >= 400:
        raise RuntimeError(f"{provider.name} HTTP {r.status_code}: {r.text[:200]}")
    return _extract_text(provider, r.json()), r.status_code
async def complete(
    messages: List[Dict[str, str]],
    *,
    temperature: float = 0.4,
    max_tokens: int = 2048,
    preferred_provider: Optional[str] = None,
    model: Optional[str] = None,
) -> Dict[str, Any]:
    """Non-streaming completion with provider/key failover.

    Providers are tried in ``provider_order`` of the newline-joined user
    messages. *preferred_provider*, if it names a known provider, is moved to
    the front. Each provider gets up to MAX_PROVIDER_RETRY attempts, each with
    the next ready key from its pool. A provider with no keys, or with every
    key cooling down, is skipped. Any exception, or a whitespace-only reply,
    counts as a key failure. The HTTP status parsed from the error message
    (0 if none) is passed to ``KeyPool.mark_failure``.

    Returns: {"content": str, "provider": str, "model": str}
    Raises: RuntimeError("ALL_PROVIDERS_FAILED (<last error>)") if nothing succeeds.
    """

    def _status_in(err: Exception) -> int:
        # _call_once errors read "<provider> HTTP <code>: <body>".
        message = str(err)
        if "HTTP " not in message:
            return 0
        try:
            return int(message.split("HTTP ")[1].split(":")[0])
        except Exception:
            return 0

    user_text = "\n".join(
        m.get("content", "") for m in messages if m.get("role") == "user"
    )
    candidates = provider_order(user_text)
    if preferred_provider and preferred_provider in PROVIDERS:
        candidates = [preferred_provider, *(p for p in candidates if p != preferred_provider)]

    last_err: Optional[str] = None
    async with httpx.AsyncClient() as client:
        for name in candidates:
            cfg = PROVIDERS[name]
            pool = get_pool(cfg)
            if not pool:
                continue
            for _attempt in range(MAX_PROVIDER_RETRY):
                ks = pool.pick()
                if ks is None:
                    break  # every key of this provider is cooling down
                try:
                    text, _status = await _call_once(
                        client, cfg, ks.key, messages,
                        temperature=temperature, max_tokens=max_tokens,
                        model=model,
                    )
                    if not text.strip():
                        raise RuntimeError("empty completion")
                    KeyPool.mark_success(ks)
                    return {
                        "content": text,
                        "provider": cfg.name,
                        "model": model or cfg.model,
                    }
                except Exception as e:
                    last_err = f"{cfg.name}: {e}"
                    logger.warning("LLM call failed β†’ %s", last_err)
                    KeyPool.mark_failure(ks, _status_in(e))
    raise RuntimeError(f"ALL_PROVIDERS_FAILED ({last_err})")
# ─── Streaming ────────────────────────────────────────────────────────────────
async def _stream_openai(
    client: httpx.AsyncClient,
    provider: ProviderConfig,
    key: str,
    messages: List[Dict[str, str]],
    temperature: float,
    max_tokens: int,
    model: Optional[str] = None,
) -> AsyncIterator[str]:
    """Stream a chat completion from an OpenAI-compatible endpoint.

    Yields each non-empty ``choices[0].delta.content`` from the SSE ``data:``
    events until ``[DONE]`` or end of stream.

    Raises:
        RuntimeError("<provider> HTTP <code>: <body prefix>") when the response
        status is >= 400.
        RuntimeError("<provider> stream error: ...") when an event carries an
        ``error`` object instead of ``choices``. Such an event is raised rather
        than skipped, so a provider-side failure is not mistaken for a
        normally finished (but truncated) stream.

    Lines that are not JSON, or that lack the expected delta shape, are
    skipped; only those parse/shape errors are tolerated.
    """
    use_model = model or provider.model
    headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
        "Accept": "text/event-stream",
    }
    body = {
        "model": use_model,
        "messages": messages,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "stream": True,
    }
    async with client.stream(
        "POST", provider.url, headers=headers, json=body,
        timeout=STREAM_TIMEOUT_SECONDS,
    ) as r:
        if r.status_code >= 400:
            err_text = (await r.aread()).decode("utf-8", "ignore")[:200]
            raise RuntimeError(f"{provider.name} HTTP {r.status_code}: {err_text}")
        async for line in r.aiter_lines():
            # Blank separators, SSE comments and non-data fields carry no text.
            if not line.startswith("data:"):
                continue
            payload = line[5:].strip()
            if payload == "[DONE]":
                break
            try:
                obj = json.loads(payload)
            except ValueError:
                logger.debug("Skipping non-JSON SSE payload from %s", provider.name)
                continue
            if isinstance(obj, dict) and obj.get("error") and not obj.get("choices"):
                raise RuntimeError(f"{provider.name} stream error: {str(obj['error'])[:200]}")
            try:
                delta = obj["choices"][0]["delta"].get("content")
            except (KeyError, IndexError, TypeError, AttributeError):
                # e.g. a chunk whose "choices" list is empty
                continue
            # Yield outside the try so exceptions thrown into the generator
            # are not swallowed as parse errors.
            if delta:
                yield delta
async def stream_complete(
    messages: List[Dict[str, str]],
    *,
    temperature: float = 0.4,
    max_tokens: int = 2048,
    preferred_provider: Optional[str] = None,
    model: Optional[str] = None,
) -> AsyncIterator[Dict[str, Any]]:
    """Yield {'type':'delta','content':str,'provider':str} chunks, then {'type':'done', ...}.

    Provider ordering and key failover follow the same rules as ``complete``.
    Providers with ``stream_supported`` are streamed via ``_stream_openai``.
    Others make one non-streaming ``_call_once`` and emit its text as a
    single delta.

    Failures are reported to ``KeyPool.mark_failure`` with the HTTP status
    parsed from the error (as ``complete`` does), so 401/403 get the hard
    cooldown here too.

    Failover only happens before any delta has reached the consumer. If an
    attempt fails after emitting partial output, retrying would splice a
    second, unrelated answer onto that text. Instead the stream ends with
    ``{'type': 'error', 'error': 'STREAM_INTERRUPTED (...)'}``. When every
    provider fails without output, the final event is
    ``{'type': 'error', 'error': 'ALL_PROVIDERS_FAILED (...)'}``.
    """

    def _http_status(err: Exception) -> int:
        # Errors from _call_once/_stream_openai read "<provider> HTTP <code>: ...";
        # 0 tells KeyPool.mark_failure that the status is unknown.
        msg = str(err)
        if "HTTP " not in msg:
            return 0
        try:
            return int(msg.split("HTTP ")[1].split(":")[0])
        except ValueError:
            return 0

    prompt_text = "\n".join(m.get("content", "") for m in messages if m.get("role") == "user")
    order = provider_order(prompt_text)
    if preferred_provider and preferred_provider in PROVIDERS:
        order = [preferred_provider] + [p for p in order if p != preferred_provider]
    last_err: Optional[str] = None
    async with httpx.AsyncClient() as client:
        for provider_name in order:
            provider = PROVIDERS[provider_name]
            pool = get_pool(provider)
            if not pool:
                continue
            for _ in range(MAX_PROVIDER_RETRY):
                ks = pool.pick()
                if ks is None:
                    break
                # Whether this attempt has already sent content to the consumer.
                emitted = False
                try:
                    if provider.stream_supported:
                        async for delta in _stream_openai(
                            client, provider, ks.key, messages,
                            temperature, max_tokens, model=model,
                        ):
                            emitted = True
                            yield {"type": "delta", "content": delta, "provider": provider.name}
                        if not emitted:
                            raise RuntimeError("empty stream")
                    else:
                        # Gemini fallback: non-streaming
                        text, _ = await _call_once(
                            client, provider, ks.key, messages,
                            temperature=temperature, max_tokens=max_tokens,
                            model=model,
                        )
                        if not text.strip():
                            raise RuntimeError("empty completion")
                        emitted = True
                        yield {"type": "delta", "content": text, "provider": provider.name}
                    KeyPool.mark_success(ks)
                    yield {
                        "type": "done",
                        "provider": provider.name,
                        "model": model or provider.model,
                    }
                    return
                except Exception as e:
                    last_err = f"{provider.name}: {e}"
                    logger.warning("LLM stream failed β†’ %s", last_err)
                    KeyPool.mark_failure(ks, _http_status(e))
                    if emitted:
                        # Partial output already reached the consumer; failing
                        # over would append a second answer to it.
                        yield {"type": "error", "error": f"STREAM_INTERRUPTED ({last_err})"}
                        return
    yield {"type": "error", "error": f"ALL_PROVIDERS_FAILED ({last_err})"}
def pool_status() -> Dict[str, Any]:
    """Return a diagnostic snapshot of every provider's key pool.

    Maps each provider name to its default ``model`` merged with
    ``KeyPool.status()`` (``total``, ``available``, ``keys``). Calling this
    builds and caches any pool that ``get_pool`` has not created yet.
    """
    return {
        name: {"model": cfg.model, **get_pool(cfg).status()}
        for name, cfg in PROVIDERS.items()
    }