"""
Unified chat client module.

- Exposes a production-ready MultiProvider cascade client (GROQ → Gemini → HF Router)
  via ChatClient / chat(...).
- Keeps the legacy RouterRequestsClient for direct access to the HF Router-compatible
  /v1/chat/completions endpoint, preserving backward compatibility.

This file assumes:
- app/bootstrap.py exists and loads configs/.env + sets up logging.
- app/core/config.py provides Settings (with provider_order, etc.).
- app/core/inference/providers.py implements the MultiProviderChat orchestrator.
"""
from __future__ import annotations

import os
import json
import time
import logging
from typing import Dict, List, Optional, Iterator, Tuple, Iterable, Union, Generator

# Imported for its side effects: loads configs/.env and configures logging.
import app.bootstrap  # noqa: F401

import requests

from app.core.config import Settings
from app.core.inference.providers import MultiProviderChat

logger = logging.getLogger(__name__)

# A single chat message in the OpenAI-compatible format.
Message = Dict[str, str]


class ChatClient:
    """
    Unified chat client that executes the configured provider cascade.

    Providers are tried in order (settings.provider_order); the first success wins.
    """

    def __init__(self, settings: Settings | None = None):
        self._settings = settings or Settings.load()
        self._chain = MultiProviderChat(self._settings)

    def chat(
        self,
        messages: Iterable[Message],
        temperature: Optional[float] = None,
        max_new_tokens: Optional[int] = None,
        stream: Optional[bool] = None,
    ) -> Union[str, Generator[str, None, None]]:
        """
        Execute a chat completion using the provider cascade.

        Args:
            messages: Iterable of {"role": "system|user|assistant", "content": "..."}.
            temperature: Optional override for the sampling temperature.
            max_new_tokens: Optional override for the maximum number of new tokens.
            stream: If None, uses settings.chat_stream. If True, returns a
                generator of text chunks instead of a single string.

        Returns:
            str (non-stream) or a generator of str chunks (stream).
        """
        use_stream = self._settings.chat_stream if stream is None else bool(stream)
        return self._chain.chat(
            messages,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            stream=use_stream,
        )
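

# Example (sketch): streaming with an explicit client. Assumes provider
# credentials are available via configs/.env, as bootstrap expects.
#
#   client = ChatClient()
#   for chunk in client.chat([{"role": "user", "content": "Hi"}], stream=True):
#       print(chunk, end="", flush=True)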


_default_client: ChatClient | None = None


def _get_default() -> ChatClient:
    global _default_client
    if _default_client is None:
        _default_client = ChatClient()
    return _default_client


def chat(
    messages: Iterable[Message],
    temperature: Optional[float] = None,
    max_new_tokens: Optional[int] = None,
    stream: Optional[bool] = None,
) -> Union[str, Generator[str, None, None]]:
    """
    Convenience function using a process-wide default ChatClient.
    """
    return _get_default().chat(
        messages,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        stream=stream,
    )
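
# Example (sketch): one-shot, non-streaming call through the shared default client.
#   reply = chat([{"role": "user", "content": "ping"}], stream=False)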


def get_client(settings: Settings | None = None) -> ChatClient:
    """
    Factory for an explicit ChatClient bound to the provided settings.
    """
    return ChatClient(settings)
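
# Example (sketch): bind a client to explicit settings, e.g. in tests.
#   client = get_client(Settings.load())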


# ---------------------------------------------------------------------------
# Legacy: direct HF Router access (kept for backward compatibility).
# ---------------------------------------------------------------------------

ROUTER_URL = "https://router.huggingface.co/v1/chat/completions"


def _require_token() -> str:
    tok = os.getenv("HF_TOKEN")
    if not tok:
        raise ValueError("HF_TOKEN is not set. Put it in .env or export it before starting.")
    return tok


def _model_with_provider(model: str, provider: Optional[str]) -> str:
    # Append ":provider" only when the model id does not already pin one.
    if provider and ":" not in model:
        return f"{model}:{provider}"
    return model
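
# E.g. _model_with_provider("org/model", "nebius") -> "org/model:nebius" (names
# illustrative), while a model id that already contains ":" passes through unchanged.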


def _mk_messages(system_prompt: Optional[str], user_text: str) -> List[Dict[str, str]]:
    msgs: List[Dict[str, str]] = []
    if system_prompt:
        msgs.append({"role": "system", "content": system_prompt})
    msgs.append({"role": "user", "content": user_text})
    return msgs
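
# E.g. _mk_messages("Be brief.", "Hi") ->
#   [{"role": "system", "content": "Be brief."}, {"role": "user", "content": "Hi"}]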


def _timeout_tuple(connect: float = 10.0, read: float = 60.0) -> Tuple[float, float]:
    # requests interprets a (connect, read) tuple as separate connect and read timeouts.
    return (connect, read)


class RouterRequestsClient:
    """
    Simple requests-only client for the HF Router Chat Completions API.

    Supports non-streaming (returns str) and streaming (yields token strings).

    NOTE: New code should prefer ChatClient above. This class is preserved for
    legacy call sites that rely on direct HF Router access.
    """

    def __init__(
        self,
        model: str,
        fallback: Optional[str] = None,
        provider: Optional[str] = None,
        max_retries: int = 2,
        connect_timeout: float = 10.0,
        read_timeout: float = 60.0,
    ):
        self.model = model
        # A fallback identical to the primary model is ignored.
        self.fallback = fallback if fallback != model else None
        self.provider = provider
        self.headers = {"Authorization": f"Bearer {_require_token()}"}
        self.max_retries = max(0, int(max_retries))
        self.timeout = _timeout_tuple(connect_timeout, read_timeout)
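
    # Example (sketch; model and provider names are illustrative):
    #   client = RouterRequestsClient(model="org/model", provider="nebius")
    #   text = client.chat_nonstream(None, "Say hi", max_tokens=64, temperature=0.2)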

    def chat_nonstream(
        self,
        system_prompt: Optional[str],
        user_text: str,
        max_tokens: int,
        temperature: float,
        stop: Optional[List[str]] = None,
        frequency_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
    ) -> str:
        payload = {
            "model": _model_with_provider(self.model, self.provider),
            "messages": _mk_messages(system_prompt, user_text),
            "temperature": float(max(0.0, temperature)),
            "max_tokens": int(max_tokens),
            "stream": False,
        }
        if stop:
            payload["stop"] = stop
        if frequency_penalty is not None:
            payload["frequency_penalty"] = float(frequency_penalty)
        if presence_penalty is not None:
            payload["presence_penalty"] = float(presence_penalty)

        text, ok = self._try_once(payload)
        if ok:
            return text

        # Retry with the fallback model, if one is configured.
        if self.fallback:
            payload["model"] = _model_with_provider(self.fallback, self.provider)
            text, ok = self._try_once(payload)
            if ok:
                return text

        raise RuntimeError(f"Chat non-stream failed: model={self.model} fallback={self.fallback}")

    def _try_once(self, payload: dict) -> Tuple[str, bool]:
        last_err: Optional[Exception] = None
        for attempt in range(self.max_retries + 1):
            try:
                r = requests.post(ROUTER_URL, headers=self.headers, json=payload, timeout=self.timeout)
                if r.status_code >= 400:
                    logger.error("Router error %s: %s", r.status_code, r.text)
                    last_err = RuntimeError(f"{r.status_code}: {r.text}")
                else:
                    data = r.json()
                    return data["choices"][0]["message"]["content"], True
            except Exception as e:
                logger.error("Router request failure: %s", e)
                last_err = e
            # Back off before retrying; no point sleeping after the final attempt.
            if attempt < self.max_retries:
                time.sleep(min(1.5 * (attempt + 1), 3.0))
        if last_err:
            logger.error("Router exhausted retries: %s", last_err)
        return "", False

    def chat_stream(
        self,
        system_prompt: Optional[str],
        user_text: str,
        max_tokens: int,
        temperature: float,
        stop: Optional[List[str]] = None,
        frequency_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
    ) -> Iterator[str]:
        payload = {
            "model": _model_with_provider(self.model, self.provider),
            "messages": _mk_messages(system_prompt, user_text),
            "temperature": float(max(0.0, temperature)),
            "max_tokens": int(max_tokens),
            "stream": True,
        }
        if stop:
            payload["stop"] = stop
        if frequency_penalty is not None:
            payload["frequency_penalty"] = float(frequency_penalty)
        if presence_penalty is not None:
            payload["presence_penalty"] = float(presence_penalty)

        # Fall back only if the primary stream produced no tokens at all;
        # a stream that failed midway is not retried.
        ok = False
        for token in self._stream_once(payload):
            ok = True
            yield token
        if ok:
            return

        if self.fallback:
            payload["model"] = _model_with_provider(self.fallback, self.provider)
            for token in self._stream_once(payload):
                yield token

    def _stream_once(self, payload: dict) -> Iterator[str]:
        try:
            with requests.post(ROUTER_URL, headers=self.headers, json=payload, stream=True, timeout=self.timeout) as r:
                if r.status_code >= 400:
                    logger.error("Router stream error %s: %s", r.status_code, r.text)
                    return
                # Parse OpenAI-style server-sent events line by line.
                for line in r.iter_lines(decode_unicode=True):
                    if not line:
                        continue
                    if not line.startswith("data:"):
                        continue
                    data = line[len("data:"):].strip()
                    if data == "[DONE]":
                        return
                    try:
                        obj = json.loads(data)
                        delta = obj["choices"][0]["delta"].get("content", "")
                        if delta:
                            yield delta
                    except Exception as e:
                        logger.warning("Stream JSON parse issue: %s | line=%r", e, line)
                        continue
        except Exception as e:
            logger.error("Stream request failure: %s", e)
            return
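
    # The parser above consumes OpenAI-style SSE lines such as:
    #   data: {"choices": [{"delta": {"content": "Hel"}}], ...}
    #   data: [DONE]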

    def plan_nonstream(self, system_prompt: str, user_text: str,
                       max_tokens: int, temperature: float) -> str:
        # Thin wrapper over chat_nonstream, kept for backward compatibility.
        return self.chat_nonstream(system_prompt, user_text, max_tokens, temperature)


__all__ = [
    "ChatClient",
    "chat",
    "get_client",
    "RouterRequestsClient",
]
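

if __name__ == "__main__":
    # Minimal smoke test (sketch): requires provider credentials in configs/.env.
    print(chat([{"role": "user", "content": "ping"}], stream=False))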