Spaces:
Sleeping
Sleeping
PYAE1994
π Phase 3: Supabase + Redis SSE + WebSocket + enhanced smart routing (6 Gemini / 9 SambaNova / 9 GitHub keys)
"""
Enhanced LLM Router — Multi-provider with smart key rotation, cooldown, and failover.

Providers:
  - gemini       (Google Generative Language API)   — 6 keys
  - sambanova    (SambaNova OpenAI-compatible)      — 9 keys
  - github_gpt4o (GitHub Models, OpenAI-compatible) — 9 keys

Key loading (priority order):
  1. Env var GEMINI_KEY / SAMBANOVA_KEY / GITHUB_KEY (comma-separated)
  2. Hardcoded fallback pool

Rotation & healing:
  - Round-robin across keys per provider
  - Per-key failure counter with 5-min cooldown after MAX_FAILURES_BEFORE_COOLDOWN
  - Auto-heal: keys automatically re-enter the pool after cooldown expires
  - Provider failover: if all keys for a provider are exhausted, try the next provider

Task-aware routing:
  - classify_task() maps prompt → task_type
  - provider_order() picks the optimal provider order per task type
"""
| from __future__ import annotations | |
| import asyncio | |
| import json | |
| import logging | |
| import os | |
| import time | |
| from dataclasses import dataclass, field | |
| from typing import Any, AsyncIterator, Dict, List, Optional | |
| import httpx | |
| logger = logging.getLogger(__name__) | |
# ─── Constants ────────────────────────────────────────────────────────────────
# Retry / cooldown policy.
MAX_PROVIDER_RETRY = 3  # key attempts per provider before moving to the next one
MAX_FAILURES_BEFORE_COOLDOWN = 3  # failures recorded on a key before it is benched
COOLDOWN_SECONDS = 5 * 60  # regular cooldown: 5 minutes
HARD_FAIL_COOLDOWN = 10 * 60  # cooldown after an auth error (HTTP 401/403): 10 minutes

# HTTP timeouts, in seconds.
REQUEST_TIMEOUT_SECONDS = 2 * 60.0  # single request/response call
STREAM_TIMEOUT_SECONDS = 10 * 60.0  # streaming (SSE) call
# ─── Key fallback pools ───────────────────────────────────────────────────────
| _GEMINI_FALLBACK = [ | |
| "AIzaSyCyIZthDgVUUtCiMycqi42VrY6PTUNG9HQ", | |
| "AIzaSyAhXY6rLF0GvN4gn6bQSQ9VyGbD4iRX-x4", | |
| "AIzaSyD5TI9VjCL3Mc8OE3qU_sbMA0ZA727fwFc", | |
| "AIzaSyC5X0cgWyb0YLzlnBLo2ESKgQNNPgs_DHs", | |
| "AIzaSyBAmPWoqnOCG740wq1JHKuHP5g-GeQFx24", | |
| "AIzaSyCLn3OBoGwKKBZzP6lcATuF__H2jsv94cg", | |
| ] | |
| _SAMBANOVA_FALLBACK = [ | |
| "0fea3265-9949-413b-a4d4-5976f18b64e3", | |
| "2c19ca9f-6b6a-4658-a20f-4dcf1b58cdc3", | |
| "2747e1fb-62b9-4a32-a072-c4c1f767584c", | |
| "460f63de-ec38-4c91-99ee-54e6d23de589", | |
| "b99cf0d1-3798-41a4-af48-84f32e73596d", | |
| "a0e9baa0-8759-411a-a6e7-420eb5d9e419", | |
| "2359f623-debd-4ad2-af37-dd232928a04f", | |
| "936e4b3c-6373-4fc4-b6aa-d6571635266a", | |
| "f30a397b-27a9-45bf-94b9-829a0d5c6cf1", | |
| ] | |
| _GITHUB_FALLBACK = [ | |
| "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", | |
| "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", | |
| "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", | |
| "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", | |
| "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", | |
| "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", | |
| "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", | |
| "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", | |
| "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", | |
| ] | |
def _load_keys(env_var: str, fallback: list[str]) -> list[str]:
    """Return the API keys for one provider.

    The environment variable ``env_var`` is read first and interpreted as a
    comma-separated list; surrounding whitespace and empty entries are
    dropped. When it yields no keys at all, a copy of ``fallback`` is
    returned instead. Either outcome is logged at INFO level (key counts
    only, never the keys themselves).
    """
    from_env = [
        piece.strip()
        for piece in os.environ.get(env_var, "").split(",")
        if piece.strip()
    ]
    if not from_env:
        logger.info("Using %d fallback keys for %s", len(fallback), env_var)
        # Copy so callers can never mutate the module-level fallback list.
        return list(fallback)
    logger.info("Loaded %d keys for %s from env", len(from_env), env_var)
    return from_env
# ─── Provider definitions ─────────────────────────────────────────────────────
@dataclass
class ProviderConfig:
    """Static description of one upstream LLM provider.

    Declared as a dataclass so it can be constructed with the keyword
    arguments used by the provider registry; without the decorator the class
    has no ``__init__`` accepting these fields.
    """

    # Registry key / label used in results and log messages.
    name: str
    # Wire protocol: 'gemini' | 'openai'.
    kind: str
    # Endpoint URL; the Gemini URL contains a ``{model}`` placeholder.
    url: str
    # Name of the environment variable holding comma-separated API keys.
    key_env: str
    # Model used when the caller supplies no override.
    model: str
    # Keys used when the environment variable provides none.
    fallback_keys: list[str]
    # Whether the streaming path may use the provider's SSE endpoint.
    stream_supported: bool = True
# Provider registry, keyed by each config's own ``name`` so the key and the
# name can never drift apart. Insertion order: gemini, sambanova, github_gpt4o.
PROVIDERS: Dict[str, ProviderConfig] = {
    cfg.name: cfg
    for cfg in (
        # Google Generative Language API; handled via the non-streaming path.
        ProviderConfig(
            name="gemini",
            kind="gemini",
            url="https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent",
            key_env="GEMINI_KEY",
            model="gemini-2.0-flash",
            fallback_keys=_GEMINI_FALLBACK,
            stream_supported=False,
        ),
        # SambaNova, OpenAI-compatible chat completions.
        ProviderConfig(
            name="sambanova",
            kind="openai",
            url="https://api.sambanova.ai/v1/chat/completions",
            key_env="SAMBANOVA_KEY",
            model="Meta-Llama-3.3-70B-Instruct",
            fallback_keys=_SAMBANOVA_FALLBACK,
            stream_supported=True,
        ),
        # GitHub Models, OpenAI-compatible chat completions.
        ProviderConfig(
            name="github_gpt4o",
            kind="openai",
            url="https://models.inference.ai.azure.com/chat/completions",
            key_env="GITHUB_KEY",
            model="gpt-4o",
            fallback_keys=_GITHUB_FALLBACK,
            stream_supported=True,
        ),
    )
}
# Per-provider model aliases a client may select: provider name -> alias -> model id.
# Every provider has a "default" alias.
# NOTE(review): nothing in the visible part of this module reads MODEL_MAP;
# presumably it is consumed by callers elsewhere — confirm.
MODEL_MAP: Dict[str, Dict[str, str]] = {
    # Google Gemini
    "gemini": {
        "default": "gemini-2.0-flash",
        "fast": "gemini-1.5-flash",
        "pro": "gemini-1.5-pro",
        "think": "gemini-2.0-flash-thinking-exp",
    },
    # SambaNova
    "sambanova": {
        "default": "Meta-Llama-3.3-70B-Instruct",
        "large": "Meta-Llama-3.1-405B-Instruct",
        "deepseek": "DeepSeek-R1",
        "qwen": "Qwen2.5-72B-Instruct",
    },
    # GitHub Models
    "github_gpt4o": {
        "default": "gpt-4o",
        "mini": "gpt-4o-mini",
        "llama": "Meta-Llama-3.1-70B-Instruct",
        "mistral": "Mistral-large-2407",
    },
}
# ─── Key pool ─────────────────────────────────────────────────────────────────
@dataclass
class KeyState:
    """Mutable health record for a single API key.

    Declared as a dataclass so it can be built as ``KeyState(key)``; without
    the decorator the class has no ``__init__`` accepting the key.
    """

    # The secret itself; excluded from repr() so it cannot leak through logs.
    key: str = field(repr=False)
    # Failures recorded since the last success.
    fail_count: int = 0
    # Epoch seconds before which the key must not be used (0.0 = no cooldown).
    cooldown_until: float = 0.0

    def is_ready(self) -> bool:
        """Return True when the key is not (or no longer) cooling down."""
        return self.cooldown_until <= time.time()

    def to_dict(self) -> dict:
        """Return a diagnostic snapshot exposing only the key's last 6 characters."""
        # One clock reading keeps "available" and "cooldown_remaining" consistent.
        now = time.time()
        return {
            "suffix": f"...{self.key[-6:]}",
            "fail_count": self.fail_count,
            "available": self.cooldown_until <= now,
            "cooldown_remaining": max(0.0, self.cooldown_until - now),
        }
class KeyPool:
    """Round-robin key pool with failure tracking & auto-heal cooldown."""

    def __init__(self, keys: List[str]) -> None:
        """Build the pool from raw key strings, ignoring blank entries."""
        self._keys: List[KeyState] = [KeyState(k.strip()) for k in keys if k.strip()]
        self._cursor = 0

    def __bool__(self) -> bool:
        """A pool is truthy when it holds at least one key (ready or not)."""
        return len(self._keys) > 0

    def pick(self) -> Optional[KeyState]:
        """Return the next ready key in round-robin order.

        Returns None when the pool is empty or every key is cooling down.
        Keys whose cooldown has expired are picked up again automatically,
        because readiness is re-evaluated on every call.
        """
        if not self._keys:
            return None
        total = len(self._keys)
        for _ in range(total):
            candidate = self._keys[self._cursor % total]
            # Keep the cursor inside [0, total) instead of growing forever.
            self._cursor = (self._cursor + 1) % total
            if candidate.is_ready():
                return candidate
        return None

    @staticmethod
    def mark_success(ks: KeyState) -> None:
        """Reset the failure counter and clear any cooldown on ``ks``."""
        ks.fail_count = 0
        ks.cooldown_until = 0.0

    @staticmethod
    def mark_failure(ks: KeyState, status_code: int = 0) -> None:
        """Record a failed call made with ``ks``.

        HTTP 401/403 benches the key immediately for HARD_FAIL_COOLDOWN
        seconds. Any other failure benches it for COOLDOWN_SECONDS once the
        failure counter reaches MAX_FAILURES_BEFORE_COOLDOWN.
        """
        ks.fail_count += 1
        if status_code in (401, 403):
            # Auth error — retrying soon cannot help, so use the hard cooldown.
            ks.cooldown_until = time.time() + HARD_FAIL_COOLDOWN
            logger.warning("Key auth error (HTTP %d) — hard cooldown %ds", status_code, HARD_FAIL_COOLDOWN)
        elif ks.fail_count >= MAX_FAILURES_BEFORE_COOLDOWN:
            ks.cooldown_until = time.time() + COOLDOWN_SECONDS
            logger.warning("Key cooled for %ds (fail_count=%d)", COOLDOWN_SECONDS, ks.fail_count)

    def status(self) -> dict:
        """Return pool diagnostics: total keys, ready keys, per-key snapshots."""
        return {
            "total": len(self._keys),
            "available": sum(1 for k in self._keys if k.is_ready()),
            "keys": [k.to_dict() for k in self._keys],
        }
# One KeyPool per provider name, created on first use and then reused, so the
# failure/cooldown state recorded on its keys survives across requests.
_POOL_CACHE: Dict[str, KeyPool] = {}


def get_pool(provider: ProviderConfig) -> KeyPool:
    """Return the cached key pool for ``provider``, creating it on first use.

    On a cache miss the keys are loaded from the provider's environment
    variable, falling back to its ``fallback_keys``.
    """
    try:
        return _POOL_CACHE[provider.name]
    except KeyError:
        pool = KeyPool(_load_keys(provider.key_env, provider.fallback_keys))
        _POOL_CACHE[provider.name] = pool
        return pool
# ─── Task classification → provider order ─────────────────────────────────────
def classify_task(prompt: str) -> str:
    """Map a prompt to a coarse task type using keyword heuristics.

    Matching is a case-insensitive *substring* test, and the rules are tried
    in the order listed below, so the first category with any hit wins.
    Returns one of "planning", "engineering", "reasoning", "language",
    "math", or "general" (no keyword matched, or the prompt is empty/None).
    """
    text = prompt.lower() if prompt else ""
    # (task type, trigger keywords) in precedence order.
    rules = (
        ("planning", ("workflow", "automation", "pipeline", "orchestrat")),
        ("engineering", ("code", "python", "javascript", "typescript", "function",
                         "api", "build", "debug", "fix", "refactor", "test")),
        ("reasoning", ("why", "analyze", "analyse", "explain", "reason",
                       "think", "evaluate", "compare")),
        ("language", ("translate", "summarize", "summarise", "summary",
                      "rewrite", "paraphrase")),
        ("math", ("math", "calculate", "solve", "equation", "formula")),
    )
    for task_type, keywords in rules:
        for keyword in keywords:
            if keyword in text:
                return task_type
    return "general"
def provider_order(prompt: str) -> List[str]:
    """Return provider names in the order they should be tried for ``prompt``.

    The prompt is classified with classify_task(); engineering, reasoning
    and math prompts start with sambanova, planning prompts start with
    github_gpt4o, and everything else (language, general, or any
    unrecognised task type) starts with gemini. A fresh list is returned on
    every call.
    """
    task = classify_task(prompt)
    if task in ("engineering", "reasoning", "math"):
        return ["sambanova", "github_gpt4o", "gemini"]
    if task == "planning":
        return ["github_gpt4o", "sambanova", "gemini"]
    return ["gemini", "sambanova", "github_gpt4o"]
# ─── Provider callers ─────────────────────────────────────────────────────────
def _gemini_body(messages: List[Dict[str, str]], model: str, temperature: float, max_tokens: int) -> tuple[str, Dict]:
    """Translate OpenAI-style chat messages into a Gemini request.

    Returns ``(url, body)``: the gemini provider URL with ``model`` filled in
    and the JSON body. System messages are pulled out of the conversation and
    joined with newlines into ``systemInstruction``; among the remaining
    messages the role "user" is kept and every other role is sent as "model".
    """
    system_texts: List[str] = [
        msg.get("content", "") for msg in messages if msg.get("role") == "system"
    ]
    turns = [
        {
            "role": "user" if msg.get("role") == "user" else "model",
            "parts": [{"text": msg.get("content", "")}],
        }
        for msg in messages
        if msg.get("role") != "system"
    ]
    payload: Dict[str, Any] = {
        "contents": turns,
        "generationConfig": {
            "temperature": temperature,
            "maxOutputTokens": max_tokens,
        },
    }
    if system_texts:
        payload["systemInstruction"] = {"parts": [{"text": "\n".join(system_texts)}]}
    # The provider URL carries a {model} placeholder.
    return PROVIDERS["gemini"].url.format(model=model), payload
| def _extract_text(provider: ProviderConfig, data: Dict[str, Any]) -> str: | |
| if provider.kind == "gemini": | |
| try: | |
| return data["candidates"][0]["content"]["parts"][0]["text"] | |
| except (KeyError, IndexError, TypeError): | |
| return "" | |
| try: | |
| return data["choices"][0]["message"]["content"] or "" | |
| except (KeyError, IndexError, TypeError): | |
| return "" | |
async def _call_once(
    client: httpx.AsyncClient,
    provider: ProviderConfig,
    key: str,
    messages: List[Dict[str, str]],
    temperature: float = 0.4,
    max_tokens: int = 2048,
    model: Optional[str] = None,
) -> tuple[str, int]:
    """Perform one non-streaming completion request with a single key.

    Args:
        client: shared HTTP client used for the request.
        provider: provider to call; ``provider.kind`` selects the wire format.
        key: API key to authenticate with.
        messages: OpenAI-style chat messages.
        temperature: sampling temperature.
        max_tokens: upper bound on generated tokens.
        model: optional model override; defaults to ``provider.model``.

    Returns:
        ``(text, status_code)`` where ``text`` may be empty if the response
        contained no extractable content.

    Raises:
        RuntimeError: the provider answered with HTTP >= 400. The message has
            the form ``"<provider> HTTP <code>: <first 200 chars of body>"``.
        Errors raised by the HTTP client or JSON decoding propagate unchanged.
    """
    use_model = model or provider.model
    if provider.kind == "gemini":
        url, body = _gemini_body(messages, use_model, temperature, max_tokens)
        # Send the key as a header rather than a ``?key=`` query parameter so
        # the secret never becomes part of the URL (URLs end up in request
        # logs and error messages).
        headers = {
            "x-goog-api-key": key,
            "Content-Type": "application/json",
        }
    else:
        url = provider.url
        headers = {
            "Authorization": f"Bearer {key}",
            "Content-Type": "application/json",
        }
        body = {
            "model": use_model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
    r = await client.post(url, headers=headers, json=body, timeout=REQUEST_TIMEOUT_SECONDS)
    if r.status_code >= 400:
        # Callers recover the status code from this exact message shape.
        raise RuntimeError(f"{provider.name} HTTP {r.status_code}: {r.text[:200]}")
    return _extract_text(provider, r.json()), r.status_code
def _status_code_from_error(exc: Exception) -> int:
    """Return the HTTP status embedded in a '<provider> HTTP <code>: …' error, else 0."""
    _, marker, tail = str(exc).partition("HTTP ")
    if not marker:
        return 0
    try:
        return int(tail.split(":", 1)[0])
    except ValueError:
        return 0


async def complete(
    messages: List[Dict[str, str]],
    *,
    temperature: float = 0.4,
    max_tokens: int = 2048,
    preferred_provider: Optional[str] = None,
    model: Optional[str] = None,
) -> Dict[str, Any]:
    """Non-streaming completion with provider/key failover.

    Providers are tried in the order chosen by provider_order() for the
    concatenated user messages; a known ``preferred_provider`` is moved to
    the front (an unknown one is ignored). For each provider up to
    MAX_PROVIDER_RETRY keys are tried; providers without keys, or with no
    ready key, are skipped.

    Returns: {"content": str, "provider": str, "model": str}

    Raises:
        RuntimeError: ``"ALL_PROVIDERS_FAILED (<last error>)"`` when no
            provider produced a non-empty completion.
    """
    prompt_text = "\n".join(m.get("content", "") for m in messages if m.get("role") == "user")
    order = provider_order(prompt_text)
    # Honor preferred_provider if specified
    if preferred_provider and preferred_provider in PROVIDERS:
        order = [preferred_provider] + [p for p in order if p != preferred_provider]
    last_err: Optional[str] = None
    async with httpx.AsyncClient() as client:
        for provider_name in order:
            provider = PROVIDERS[provider_name]
            pool = get_pool(provider)
            if not pool:
                continue
            for _ in range(MAX_PROVIDER_RETRY):
                ks = pool.pick()
                if ks is None:
                    # Every key of this provider is cooling down.
                    break
                try:
                    text, _status = await _call_once(
                        client, provider, ks.key, messages,
                        temperature=temperature, max_tokens=max_tokens,
                        model=model,
                    )
                    if not text.strip():
                        raise RuntimeError("empty completion")
                except Exception as e:
                    # Failover boundary: any failure (HTTP error, timeout, bad
                    # payload, empty text) counts against this key and we move on.
                    last_err = f"{provider.name}: {e}"
                    logger.warning("LLM call failed — %s", last_err)
                    KeyPool.mark_failure(ks, _status_code_from_error(e))
                    continue
                KeyPool.mark_success(ks)
                return {
                    "content": text,
                    "provider": provider.name,
                    "model": model or provider.model,
                }
    # last_err is still None when no provider had a usable key at all.
    reason = last_err or "no usable API key for any provider"
    raise RuntimeError(f"ALL_PROVIDERS_FAILED ({reason})")
# ─── Streaming ────────────────────────────────────────────────────────────────
async def _stream_openai(
    client: httpx.AsyncClient,
    provider: ProviderConfig,
    key: str,
    messages: List[Dict[str, str]],
    temperature: float,
    max_tokens: int,
    model: Optional[str] = None,
) -> AsyncIterator[str]:
    """Stream content deltas from an OpenAI-compatible chat endpoint.

    Sends a ``"stream": True`` request to ``provider.url`` and yields each
    non-empty ``choices[0].delta.content`` string from the ``data:`` lines of
    the event stream, stopping at ``[DONE]``. Lines that are not ``data:``
    lines, or whose payload cannot be parsed, are skipped.

    Raises:
        RuntimeError: the provider answered with HTTP >= 400; the message is
            ``"<provider> HTTP <code>: <first 200 chars of body>"``.
    """
    use_model = model or provider.model
    headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
        "Accept": "text/event-stream",
    }
    body = {
        "model": use_model,
        "messages": messages,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "stream": True,
    }
    async with client.stream(
        "POST", provider.url, headers=headers, json=body,
        timeout=STREAM_TIMEOUT_SECONDS,
    ) as r:
        if r.status_code >= 400:
            err_text = (await r.aread()).decode("utf-8", "ignore")[:200]
            raise RuntimeError(f"{provider.name} HTTP {r.status_code}: {err_text}")
        async for line in r.aiter_lines():
            if not line.startswith("data:"):
                continue
            payload = line[5:].strip()
            if payload == "[DONE]":
                break
            # Only the parsing is guarded. The yield stays outside the try so
            # an exception raised at the yield point is not swallowed as if it
            # were a malformed chunk.
            try:
                delta = json.loads(payload)["choices"][0]["delta"].get("content")
            except (ValueError, KeyError, IndexError, TypeError, AttributeError):
                continue
            if delta:
                yield delta
async def stream_complete(
    messages: List[Dict[str, str]],
    *,
    temperature: float = 0.4,
    max_tokens: int = 2048,
    preferred_provider: Optional[str] = None,
    model: Optional[str] = None,
) -> AsyncIterator[Dict[str, Any]]:
    """Yield {'type':'delta','content':str,'provider':str} chunks, then {'type':'done', ...}.

    Provider/key selection mirrors the non-streaming path: providers are
    tried in provider_order() order (a known ``preferred_provider`` first),
    with up to MAX_PROVIDER_RETRY keys each. Providers with
    ``stream_supported`` False are called once and their whole answer is
    emitted as a single delta.

    The stream always ends with exactly one terminal event:

    * ``{'type': 'done', 'provider': str, 'model': str}`` on success;
    * ``{'type': 'error', 'error': 'STREAM_INTERRUPTED (...)'}`` when a
      provider failed *after* some content was already emitted. No failover
      happens then, because restarting the answer elsewhere would duplicate
      output the consumer has already received;
    * ``{'type': 'error', 'error': 'ALL_PROVIDERS_FAILED (...)'}`` when no
      provider produced any content.
    """
    prompt_text = "\n".join(m.get("content", "") for m in messages if m.get("role") == "user")
    order = provider_order(prompt_text)
    if preferred_provider and preferred_provider in PROVIDERS:
        order = [preferred_provider] + [p for p in order if p != preferred_provider]
    last_err: Optional[str] = None
    async with httpx.AsyncClient() as client:
        for provider_name in order:
            provider = PROVIDERS[provider_name]
            pool = get_pool(provider)
            if not pool:
                continue
            for _ in range(MAX_PROVIDER_RETRY):
                ks = pool.pick()
                if ks is None:
                    # Every key of this provider is cooling down.
                    break
                emitted = False  # has any content reached the consumer yet?
                try:
                    if provider.stream_supported:
                        async for delta in _stream_openai(
                            client, provider, ks.key, messages,
                            temperature, max_tokens, model=model
                        ):
                            emitted = True
                            yield {"type": "delta", "content": delta, "provider": provider.name}
                        if not emitted:
                            raise RuntimeError("empty stream")
                    else:
                        # Gemini fallback: non-streaming
                        text, _status = await _call_once(
                            client, provider, ks.key, messages,
                            temperature=temperature, max_tokens=max_tokens,
                            model=model,
                        )
                        if not text.strip():
                            raise RuntimeError("empty completion")
                        emitted = True
                        yield {"type": "delta", "content": text, "provider": provider.name}
                except Exception as e:
                    last_err = f"{provider.name}: {e}"
                    logger.warning("LLM stream failed — %s", last_err)
                    # Recover the status from "<provider> HTTP <code>: ..." so
                    # auth errors (401/403) trigger the hard cooldown here too.
                    status_code = 0
                    _, marker, tail = str(e).partition("HTTP ")
                    if marker:
                        try:
                            status_code = int(tail.split(":", 1)[0])
                        except ValueError:
                            status_code = 0
                    KeyPool.mark_failure(ks, status_code)
                    if emitted:
                        yield {"type": "error", "error": f"STREAM_INTERRUPTED ({last_err})"}
                        return
                    continue
                KeyPool.mark_success(ks)
                yield {
                    "type": "done",
                    "provider": provider.name,
                    "model": model or provider.model,
                }
                return
    # last_err is still None when no provider had a usable key at all.
    reason = last_err or "no usable API key for any provider"
    yield {"type": "error", "error": f"ALL_PROVIDERS_FAILED ({reason})"}
def pool_status() -> Dict[str, Any]:
    """Report key-pool health for every registered provider.

    The result maps each provider name to its default model plus the fields
    of that provider's pool status (total keys, available keys, per-key
    details). Pools that do not exist yet are created by the lookup.
    """
    return {
        name: {"model": provider.model, **get_pool(provider).status()}
        for name, provider in PROVIDERS.items()
    }