# app/core/inference/providers.py from __future__ import annotations """ Provider layer for multi-backend LLM chat with a production-ready cascade: GROQ → Gemini → Hugging Face Inference Router (Zephyr → Mistral) - Each provider implements a common .chat(...) interface that returns either: * str (non-stream), or * Generator[str, None, None] (streaming text chunks) - MultiProviderChat orchestrates providers in a user-configurable order (Settings.provider_order) and returns the first successful response. - Robustness: * .env + logging are loaded via app.bootstrap import side-effect * Requests session has retries and timeouts * Provider initialization gracefully skips when keys/SDKs are missing * Streaming uses SSE for HF Router; Groq uses SDK streaming; Gemini yields one chunk """ from typing import Any, Dict, Generator, Iterable, List, Optional, Union import json import logging import os import time # Ensure .env + logging configured even if imported directly import app.bootstrap # noqa: F401 import requests from requests.adapters import HTTPAdapter from urllib3.util.retry import Retry # Optional SDKs; handled gracefully if absent try: from groq import Groq except Exception: # pragma: no cover Groq = None # type: ignore try: from google import genai except Exception: # pragma: no cover genai = None # type: ignore from app.core.config import Settings logger = logging.getLogger(__name__) Message = Dict[str, str] # {"role": "system|user|assistant", "content": "..."} # ---------- Errors ---------- class ProviderError(RuntimeError): """Raised for provider-specific configuration/runtime errors.""" # ---------- Helpers ---------- def _ensure_messages(msgs: Iterable[Message]) -> List[Message]: """ Normalize incoming messages to a strict [{"role": str, "content": str}, ...] list. """ out: List[Message] = [] for m in msgs: role = m.get("role", "user") content = m.get("content", "") out.append({"role": role, "content": content}) return out def _requests_session_with_retries( total: int = 3, backoff: float = 0.3, status_forcelist: Optional[List[int]] = None, timeout: float = 60.0, ) -> requests.Session: """ Return a requests.Session configured with retries, connection pooling, and default timeouts. """ status_forcelist = status_forcelist or [408, 429, 500, 502, 503, 504] retry = Retry( total=total, read=total, connect=total, backoff_factor=backoff, status_forcelist=status_forcelist, allowed_methods=frozenset(["GET", "POST"]), raise_on_status=False, ) adapter = HTTPAdapter(max_retries=retry, pool_connections=10, pool_maxsize=10) session = requests.Session() session.mount("http://", adapter) session.mount("https://", adapter) # Store default timeout on session via a patched request method session.request = _patch_request_with_timeout(session.request, timeout) # type: ignore return session def _patch_request_with_timeout(fn, timeout: float): def wrapper(method, url, **kwargs): if "timeout" not in kwargs: kwargs["timeout"] = timeout return fn(method, url, **kwargs) return wrapper # ---------- GROQ ---------- class GroqProvider: """ Groq Chat Completions (OpenAI-compatible). Requires: - env: GROQ_API_KEY - package: groq """ name = "groq" def __init__(self, model: str): self.model = model self.api_key = os.getenv("GROQ_API_KEY") if not self.api_key: raise ProviderError("GROQ_API_KEY is not set") if Groq is None: raise ProviderError("groq SDK not installed; add 'groq' to requirements.txt and pip install.") # SDK reads key from env self.client = Groq() def chat( self, messages: Iterable[Message], temperature: float, max_new_tokens: int, stream: bool, ) -> Union[str, Generator[str, None, None]]: msgs = _ensure_messages(messages) try: completion = self.client.chat.completions.create( model=self.model, messages=msgs, temperature=float(temperature), max_tokens=int(max_new_tokens), top_p=1, stream=bool(stream), ) if stream: def gen(): for chunk in completion: try: delta = chunk.choices[0].delta part = getattr(delta, "content", None) if part: yield part except Exception: continue return gen() else: # Non-streaming: return final message content return completion.choices[0].message.content or "" except Exception as e: raise ProviderError(f"GROQ error: {e}") from e # ---------- GEMINI ---------- class GeminiProvider: """ Google Gemini via google-genai. Requires: - env: GOOGLE_API_KEY - package: google-genai Role mapping: - system → system_instruction (joined) - user → role 'user' - assistant → role 'model' """ name = "gemini" def __init__(self, model: str): self.model = model self.api_key = os.getenv("GOOGLE_API_KEY") if not self.api_key: raise ProviderError("GOOGLE_API_KEY is not set") if genai is None: raise ProviderError("google-genai SDK not installed; add 'google-genai' to requirements.txt and pip install.") self.client = genai.Client(api_key=self.api_key) @staticmethod def _split_system_and_messages(msgs: List[Message]) -> tuple[str, List[dict]]: system_parts: List[str] = [] contents: List[dict] = [] for m in msgs: role = m.get("role", "user") text = m.get("content", "") if role == "system": system_parts.append(text) else: mapped = "user" if role == "user" else "model" contents.append({"role": mapped, "parts": [{"text": text}]}) return ("\n".join(system_parts).strip(), contents) def chat( self, messages: Iterable[Message], temperature: float, max_new_tokens: int, stream: bool, ) -> Union[str, Generator[str, None, None]]: msgs = _ensure_messages(messages) system_instruction, contents = self._split_system_and_messages(msgs) try: # Some versions of google-genai expose system_instruction; if not, we prepend. kwargs: Dict[str, Any] = { "model": self.model, "contents": contents, "generation_config": { "temperature": float(temperature), "max_output_tokens": int(max_new_tokens), }, } try: resp = self.client.models.generate_content(system_instruction=system_instruction or None, **kwargs) except TypeError: # Fallback for older SDKs: inject system as first user turn if system_instruction: contents = [{"role": "user", "parts": [{"text": f"System: {system_instruction}"}]}] + contents kwargs["contents"] = contents resp = self.client.models.generate_content(**kwargs) text = getattr(resp, "text", "") or "" if stream: # Fake streaming for API parity: one chunk def gen(): yield text return gen() return text except Exception as e: raise ProviderError(f"Gemini error: {e}") from e # ---------- HF Inference Router ---------- class HfRouterProvider: """ Hugging Face Inference Router (OpenAI-like /v1/chat/completions). Tries primary -> fallback model (both can include optional provider tag, e.g., "model:featherless-ai"). Requires: - env: HF_TOKEN - package: requests """ name = "router" BASE_URL = "https://router.huggingface.co/v1/chat/completions" def __init__(self, primary_model: str, fallback_model: Optional[str], provider_tag: Optional[str]): self.primary = primary_model self.fallback = fallback_model self.provider_tag = provider_tag self.token = os.getenv("HF_TOKEN") if not self.token: raise ProviderError("HF_TOKEN is not set") self.session = _requests_session_with_retries(total=3, backoff=0.5, timeout=60.0) def _fmt_model(self, model: str) -> str: return model if not self.provider_tag else f"{model}:{self.provider_tag}" def _sse_stream(self, resp: requests.Response) -> Generator[str, None, None]: for raw in resp.iter_lines(decode_unicode=True): if not raw: continue if not raw.startswith("data:"): continue data = raw[5:].strip() if data == "[DONE]": break try: obj = json.loads(data) except Exception: continue try: delta = obj["choices"][0].get("delta", {}) content = delta.get("content") if content: yield content except Exception: continue def _call_router( self, model: str, messages: List[Message], temperature: float, max_new_tokens: int, stream: bool, ) -> Union[str, Generator[str, None, None]]: headers = { "Authorization": f"Bearer {self.token}", "Content-Type": "application/json", } payload: Dict[str, Any] = { "model": self._fmt_model(model), "messages": messages, "temperature": float(temperature), "max_tokens": int(max_new_tokens), "stream": bool(stream), } if stream: with self.session.post(self.BASE_URL, headers=headers, json=payload, stream=True) as r: if r.status_code >= 400: raise ProviderError(f"HF Router HTTP {r.status_code}: {r.text[:300]}") return self._sse_stream(r) else: r = self.session.post(self.BASE_URL, headers=headers, json=payload) if r.status_code >= 400: raise ProviderError(f"HF Router HTTP {r.status_code}: {r.text[:300]}") obj = r.json() try: return obj["choices"][0]["message"]["content"] except Exception as e: raise ProviderError(f"HF Router response parsing error: {e}") from e def chat( self, messages: Iterable[Message], temperature: float, max_new_tokens: int, stream: bool, ) -> Union[str, Generator[str, None, None]]: msgs = _ensure_messages(messages) try: return self._call_router(self.primary, msgs, temperature, max_new_tokens, stream) except Exception as e1: logger.warning("HF primary model failed (%s): %s", self.primary, e1) if self.fallback: return self._call_router(self.fallback, msgs, temperature, max_new_tokens, stream) raise # ---------- Orchestrator ---------- class MultiProviderChat: """ Tries providers in configured order. First success wins. Skips misconfigured providers (missing key or SDK). """ def __init__(self, settings: Settings): m = settings.model order = [p.strip().lower() for p in settings.provider_order] self.providers: List[Any] = [] for p in order: try: if p == "groq": self.providers.append(GroqProvider(m.groq_model)) elif p == "gemini": self.providers.append(GeminiProvider(m.gemini_model)) elif p == "router": self.providers.append(HfRouterProvider(m.name, m.fallback, m.provider)) else: logger.warning("Unknown provider '%s' in provider_order; skipping.", p) except ProviderError as e: logger.warning("Provider '%s' not available: %s (will skip)", p, e) continue if not self.providers: raise ProviderError("No providers are configured/available") self.temperature = m.temperature self.max_new_tokens = m.max_new_tokens def chat( self, messages: Iterable[Message], temperature: Optional[float] = None, max_new_tokens: Optional[int] = None, stream: bool = True, ) -> Union[str, Generator[str, None, None]]: temp = float(self.temperature if temperature is None else temperature) mx = int(self.max_new_tokens if max_new_tokens is None else max_new_tokens) last_err: Optional[Exception] = None for provider in self.providers: pname = getattr(provider, "name", provider.__class__.__name__) t0 = time.time() try: result = provider.chat(messages, temp, mx, stream) logger.info("Provider '%s' succeeded in %.2fs", pname, time.time() - t0) return result except Exception as e: logger.warning("Provider '%s' failed: %s", pname, e) last_err = e continue raise ProviderError(f"All providers failed. Last error: {last_err}") __all__ = [ "ProviderError", "GroqProvider", "GeminiProvider", "HfRouterProvider", "MultiProviderChat", ]