"""
Provider layer for multi-backend LLM chat with a production-ready cascade:

    Groq → Gemini → Hugging Face Inference Router (Zephyr → Mistral)

- Each provider implements a common .chat(...) interface that returns either:
    * str (non-stream), or
    * Generator[str, None, None] (streaming text chunks)

- MultiProviderChat orchestrates providers in a user-configurable order
  (Settings.provider_order) and returns the first successful response.

- Robustness:
    * .env + logging are loaded via the app.bootstrap import side effect
    * the requests session has retries and default timeouts
    * provider initialization gracefully skips when keys/SDKs are missing
    * streaming uses SSE for the HF Router; Groq uses SDK streaming; Gemini
      yields one chunk
"""

from __future__ import annotations

from typing import Any, Dict, Generator, Iterable, List, Optional, Union
import json
import logging
import os
import time

import app.bootstrap  # noqa: F401  (side effect: loads .env and configures logging)

import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

# Optional SDKs: soft imports so a missing package only disables that provider.
try:
    from groq import Groq
except Exception:
    Groq = None

try:
    from google import genai
except Exception:
    genai = None

from app.core.config import Settings

logger = logging.getLogger(__name__)

Message = Dict[str, str]
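
# A conversation is a list of Message dicts in the OpenAI chat style, e.g.:
#
#     [{"role": "system", "content": "You are terse."},
#      {"role": "user", "content": "Ping?"}]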


class ProviderError(RuntimeError):
    """Raised for provider-specific configuration/runtime errors."""


def _ensure_messages(msgs: Iterable[Message]) -> List[Message]:
    """Normalize incoming messages to a strict [{"role": str, "content": str}, ...] list."""
    out: List[Message] = []
    for m in msgs:
        role = m.get("role", "user")
        content = m.get("content", "")
        out.append({"role": role, "content": content})
    return out


def _patch_request_with_timeout(fn, timeout: float):
    """Wrap Session.request so every call gets a default timeout unless the caller sets one."""

    def wrapper(method, url, **kwargs):
        if "timeout" not in kwargs:
            kwargs["timeout"] = timeout
        return fn(method, url, **kwargs)

    return wrapper


def _requests_session_with_retries(
    total: int = 3,
    backoff: float = 0.3,
    status_forcelist: Optional[List[int]] = None,
    timeout: float = 60.0,
) -> requests.Session:
    """Return a requests.Session configured with retries, connection pooling, and default timeouts."""
    status_forcelist = status_forcelist or [408, 429, 500, 502, 503, 504]
    retry = Retry(
        total=total,
        read=total,
        connect=total,
        backoff_factor=backoff,
        status_forcelist=status_forcelist,
        allowed_methods=frozenset(["GET", "POST"]),
        raise_on_status=False,
    )
    adapter = HTTPAdapter(max_retries=retry, pool_connections=10, pool_maxsize=10)
    session = requests.Session()
    session.mount("http://", adapter)
    session.mount("https://", adapter)

    # Apply the default timeout to every request made through this session.
    session.request = _patch_request_with_timeout(session.request, timeout)
    return session
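

# A sketch of how the default-timeout patch behaves (illustrative URL; an
# explicit timeout from the caller always wins over the session default):
#
#     session = _requests_session_with_retries(timeout=10.0)
#     session.get("https://example.com")              # sent with timeout=10.0
#     session.get("https://example.com", timeout=2)   # caller's value is kept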


class GroqProvider:
    """
    Groq Chat Completions (OpenAI-compatible).
    Requires:
      - env: GROQ_API_KEY
      - package: groq
    """

    name = "groq"

    def __init__(self, model: str):
        self.model = model
        self.api_key = os.getenv("GROQ_API_KEY")
        if not self.api_key:
            raise ProviderError("GROQ_API_KEY is not set")
        if Groq is None:
            raise ProviderError("groq SDK not installed; add 'groq' to requirements.txt and pip install.")
        self.client = Groq(api_key=self.api_key)

    def chat(
        self,
        messages: Iterable[Message],
        temperature: float,
        max_new_tokens: int,
        stream: bool,
    ) -> Union[str, Generator[str, None, None]]:
        msgs = _ensure_messages(messages)
        try:
            completion = self.client.chat.completions.create(
                model=self.model,
                messages=msgs,
                temperature=float(temperature),
                max_tokens=int(max_new_tokens),
                top_p=1,
                stream=bool(stream),
            )
            if stream:

                def gen():
                    for chunk in completion:
                        try:
                            delta = chunk.choices[0].delta
                            part = getattr(delta, "content", None)
                            if part:
                                yield part
                        except Exception:
                            continue

                return gen()
            else:
                # Non-streaming: a single completion object.
                return completion.choices[0].message.content or ""
        except Exception as e:
            raise ProviderError(f"GROQ error: {e}") from e


class GeminiProvider:
    """
    Google Gemini via google-genai.
    Requires:
      - env: GOOGLE_API_KEY
      - package: google-genai

    Role mapping:
      - system → system_instruction (joined)
      - user → role 'user'
      - assistant → role 'model'
    """

    name = "gemini"

    def __init__(self, model: str):
        self.model = model
        self.api_key = os.getenv("GOOGLE_API_KEY")
        if not self.api_key:
            raise ProviderError("GOOGLE_API_KEY is not set")
        if genai is None:
            raise ProviderError("google-genai SDK not installed; add 'google-genai' to requirements.txt and pip install.")
        self.client = genai.Client(api_key=self.api_key)

    @staticmethod
    def _split_system_and_messages(msgs: List[Message]) -> tuple[str, List[dict]]:
        system_parts: List[str] = []
        contents: List[dict] = []
        for m in msgs:
            role = m.get("role", "user")
            text = m.get("content", "")
            if role == "system":
                system_parts.append(text)
            else:
                mapped = "user" if role == "user" else "model"
                contents.append({"role": mapped, "parts": [{"text": text}]})
        return ("\n".join(system_parts).strip(), contents)
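
    # Worked example of the split above (illustrative input):
    #
    #     [{"role": "system", "content": "Be brief."},
    #      {"role": "user", "content": "Hi"},
    #      {"role": "assistant", "content": "Hello!"}]
    # ->
    #     ("Be brief.",
    #      [{"role": "user", "parts": [{"text": "Hi"}]},
    #       {"role": "model", "parts": [{"text": "Hello!"}]}])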

    def chat(
        self,
        messages: Iterable[Message],
        temperature: float,
        max_new_tokens: int,
        stream: bool,
    ) -> Union[str, Generator[str, None, None]]:
        msgs = _ensure_messages(messages)
        system_instruction, contents = self._split_system_and_messages(msgs)
        try:
            # google-genai takes generation settings (and the system instruction)
            # through the single `config` argument.
            config: Dict[str, Any] = {
                "temperature": float(temperature),
                "max_output_tokens": int(max_new_tokens),
            }
            if system_instruction:
                config["system_instruction"] = system_instruction
            try:
                resp = self.client.models.generate_content(
                    model=self.model, contents=contents, config=config
                )
            except (TypeError, ValueError):
                # Fallback for SDK variants that reject system_instruction in the
                # config: fold the system prompt into the first user turn.
                if system_instruction:
                    contents = [{"role": "user", "parts": [{"text": f"System: {system_instruction}"}]}] + contents
                config.pop("system_instruction", None)
                resp = self.client.models.generate_content(
                    model=self.model, contents=contents, config=config
                )

            text = getattr(resp, "text", "") or ""

            if stream:
                # The call above is non-streaming; emulate streaming with one chunk.
                def gen():
                    yield text

                return gen()
            return text
        except Exception as e:
            raise ProviderError(f"Gemini error: {e}") from e


class HfRouterProvider:
    """
    Hugging Face Inference Router (OpenAI-like /v1/chat/completions).
    Tries primary -> fallback model (both can include an optional provider tag,
    e.g., "model:featherless-ai").

    Requires:
      - env: HF_TOKEN
      - package: requests
    """

    name = "router"
    BASE_URL = "https://router.huggingface.co/v1/chat/completions"

    def __init__(self, primary_model: str, fallback_model: Optional[str], provider_tag: Optional[str]):
        self.primary = primary_model
        self.fallback = fallback_model
        self.provider_tag = provider_tag
        self.token = os.getenv("HF_TOKEN")
        if not self.token:
            raise ProviderError("HF_TOKEN is not set")
        self.session = _requests_session_with_retries(total=3, backoff=0.5, timeout=60.0)

    def _fmt_model(self, model: str) -> str:
        return model if not self.provider_tag else f"{model}:{self.provider_tag}"
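
    # e.g. with provider_tag="featherless-ai", a (hypothetical) model id
    # "org/model-7b" is sent as "org/model-7b:featherless-ai".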

    def _sse_stream(self, resp: requests.Response) -> Generator[str, None, None]:
        for raw in resp.iter_lines(decode_unicode=True):
            if not raw:
                continue
            if not raw.startswith("data:"):
                continue
            data = raw[5:].strip()
            if data == "[DONE]":
                break
            try:
                obj = json.loads(data)
            except Exception:
                continue
            try:
                delta = obj["choices"][0].get("delta", {})
                content = delta.get("content")
                if content:
                    yield content
            except Exception:
                continue
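
    # The frames _sse_stream() consumes are OpenAI-style SSE deltas, e.g.:
    #
    #     data: {"choices":[{"delta":{"content":"Hel"}}]}
    #     data: {"choices":[{"delta":{"content":"lo"}}]}
    #     data: [DONE]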

    def _call_router(
        self,
        model: str,
        messages: List[Message],
        temperature: float,
        max_new_tokens: int,
        stream: bool,
    ) -> Union[str, Generator[str, None, None]]:
        headers = {
            "Authorization": f"Bearer {self.token}",
            "Content-Type": "application/json",
        }
        payload: Dict[str, Any] = {
            "model": self._fmt_model(model),
            "messages": messages,
            "temperature": float(temperature),
            "max_tokens": int(max_new_tokens),
            "stream": bool(stream),
        }
        if stream:
            # Do not wrap this in a `with` block: the response must stay open
            # for the lifetime of the generator that consumes the SSE stream.
            r = self.session.post(self.BASE_URL, headers=headers, json=payload, stream=True)
            if r.status_code >= 400:
                body = r.text[:300]
                r.close()
                raise ProviderError(f"HF Router HTTP {r.status_code}: {body}")

            def gen():
                try:
                    yield from self._sse_stream(r)
                finally:
                    r.close()

            return gen()
        else:
            r = self.session.post(self.BASE_URL, headers=headers, json=payload)
            if r.status_code >= 400:
                raise ProviderError(f"HF Router HTTP {r.status_code}: {r.text[:300]}")
            obj = r.json()
            try:
                return obj["choices"][0]["message"]["content"]
            except Exception as e:
                raise ProviderError(f"HF Router response parsing error: {e}") from e

    def chat(
        self,
        messages: Iterable[Message],
        temperature: float,
        max_new_tokens: int,
        stream: bool,
    ) -> Union[str, Generator[str, None, None]]:
        msgs = _ensure_messages(messages)
        try:
            return self._call_router(self.primary, msgs, temperature, max_new_tokens, stream)
        except Exception as e1:
            logger.warning("HF primary model failed (%s): %s", self.primary, e1)
            if self.fallback:
                return self._call_router(self.fallback, msgs, temperature, max_new_tokens, stream)
            raise


class MultiProviderChat:
    """
    Tries providers in configured order. First success wins.
    Skips misconfigured providers (missing key or SDK).
    """

    def __init__(self, settings: Settings):
        m = settings.model
        order = [p.strip().lower() for p in settings.provider_order]
        self.providers: List[Any] = []

        for p in order:
            try:
                if p == "groq":
                    self.providers.append(GroqProvider(m.groq_model))
                elif p == "gemini":
                    self.providers.append(GeminiProvider(m.gemini_model))
                elif p == "router":
                    self.providers.append(HfRouterProvider(m.name, m.fallback, m.provider))
                else:
                    logger.warning("Unknown provider '%s' in provider_order; skipping.", p)
            except ProviderError as e:
                logger.warning("Provider '%s' not available: %s (will skip)", p, e)
                continue

        if not self.providers:
            raise ProviderError("No providers are configured/available")

        self.temperature = m.temperature
        self.max_new_tokens = m.max_new_tokens

    def chat(
        self,
        messages: Iterable[Message],
        temperature: Optional[float] = None,
        max_new_tokens: Optional[int] = None,
        stream: bool = True,
    ) -> Union[str, Generator[str, None, None]]:
        temp = float(self.temperature if temperature is None else temperature)
        mx = int(self.max_new_tokens if max_new_tokens is None else max_new_tokens)
        # Normalize once so a one-shot iterator isn't exhausted by a failing
        # provider before the next one sees the messages.
        msgs = _ensure_messages(messages)
        last_err: Optional[Exception] = None

        for provider in self.providers:
            pname = getattr(provider, "name", provider.__class__.__name__)
            t0 = time.time()
            try:
                result = provider.chat(msgs, temp, mx, stream)
                logger.info("Provider '%s' succeeded in %.2fs", pname, time.time() - t0)
                return result
            except Exception as e:
                logger.warning("Provider '%s' failed: %s", pname, e)
                last_err = e
                continue

        raise ProviderError(f"All providers failed. Last error: {last_err}")


__all__ = [
    "ProviderError",
    "GroqProvider",
    "GeminiProvider",
    "HfRouterProvider",
    "MultiProviderChat",
]
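

# Minimal smoke test (a sketch, not part of the library surface): assumes at
# least one provider key is set in the environment and that Settings() loads a
# valid configuration.
if __name__ == "__main__":  # pragma: no cover
    demo = MultiProviderChat(Settings())
    print(demo.chat([{"role": "user", "content": "Say hi in one word."}], stream=False))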