# app/core/inference/client.py
"""
Unified chat client module.

- Exposes a production-ready MultiProvider cascade client (GROQ → Gemini → HF Router)
  via ChatClient / chat(...).
- Keeps the legacy RouterRequestsClient for direct access to the HF Router-compatible
  /v1/chat/completions endpoint, preserving backward compatibility.

This module assumes:
  - app/bootstrap.py exists and loads configs/.env + sets up logging.
  - app/core/config.py provides Settings (with provider_order, etc.).
  - app/core/inference/providers.py implements the MultiProviderChat orchestrator.
"""
from __future__ import annotations

import os
import json
import time
import logging
from typing import Dict, List, Optional, Iterator, Tuple, Iterable, Union, Generator

# Ensure .env & logging before we load settings/providers
import app.bootstrap  # noqa: F401

import requests

from app.core.config import Settings
from app.core.inference.providers import MultiProviderChat

logger = logging.getLogger(__name__)

# -----------------------------
# Multi-provider cascade client
# -----------------------------

Message = Dict[str, str]

class ChatClient:
    """
    Unified chat client that executes the configured provider cascade.
    Providers are tried in order (settings.provider_order). First success wins.
    """
    def __init__(self, settings: Settings | None = None):
        self._settings = settings or Settings.load()
        self._chain = MultiProviderChat(self._settings)

    def chat(
        self,
        messages: Iterable[Message],
        temperature: Optional[float] = None,
        max_new_tokens: Optional[int] = None,
        stream: Optional[bool] = None,
    ) -> Union[str, Generator[str, None, None]]:
        """
        Execute a chat completion using the provider cascade.

        Args:
            messages: Iterable of {"role": "system|user|assistant", "content": "..."}
            temperature: Optional override for sampling temperature.
            max_new_tokens: Optional override for max tokens.
            stream: If None, uses settings.chat_stream. If True, returns a generator of text chunks.

        Returns:
            str (non-stream) or generator[str] (stream)
        """
        use_stream = self._settings.chat_stream if stream is None else bool(stream)
        return self._chain.chat(
            messages,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            stream=use_stream,
        )

# Backward-compatible helpers
_default_client: ChatClient | None = None

def _get_default() -> ChatClient:
    global _default_client
    if _default_client is None:
        _default_client = ChatClient()
    return _default_client

def chat(
    messages: Iterable[Message],
    temperature: Optional[float] = None,
    max_new_tokens: Optional[int] = None,
    stream: Optional[bool] = None,
) -> Union[str, Generator[str, None, None]]:
    """
    Convenience function using a process-wide default ChatClient.
    """
    return _get_default().chat(messages, temperature=temperature, max_new_tokens=max_new_tokens, stream=stream)

def get_client(settings: Settings | None = None) -> ChatClient:
    """
    Factory for an explicit ChatClient bound to provided settings.
    """
    return ChatClient(settings)
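

# Usage sketch for the cascade client (illustrative only; never executed at import
# time). It assumes provider credentials and settings.provider_order are already
# configured via configs/.env; the helper name below is hypothetical, not public API.
def _example_cascade_usage() -> None:  # pragma: no cover - documentation example
    client = get_client()  # or use the module-level chat(...) for the default client
    # Non-streaming call: returns the full completion as a single string.
    reply = client.chat(
        [
            {"role": "system", "content": "You are terse."},
            {"role": "user", "content": "Say hello."},
        ],
        temperature=0.2,
        max_new_tokens=64,
        stream=False,
    )
    print(reply)
    # Streaming call: returns a generator that yields text chunks as they arrive.
    for chunk in client.chat([{"role": "user", "content": "Stream a haiku."}], stream=True):
        print(chunk, end="", flush=True)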


# ------------------------------------------------------
# Legacy HF Router client (kept for backward compatibility)
# ------------------------------------------------------

ROUTER_URL = "https://router.huggingface.co/v1/chat/completions"

def _require_token() -> str:
    tok = os.getenv("HF_TOKEN")
    if not tok:
        raise ValueError("HF_TOKEN is not set. Put it in .env or export it before starting.")
    return tok

def _model_with_provider(model: str, provider: Optional[str]) -> str:
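    # Append ":provider" to the model id when a provider is requested and the id
    # does not already carry one, so the Router can pin a specific backend provider.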
    if provider and ":" not in model:
        return f"{model}:{provider}"
    return model

def _mk_messages(system_prompt: Optional[str], user_text: str) -> List[Dict[str, str]]:
    msgs: List[Dict[str, str]] = []
    if system_prompt:
        msgs.append({"role": "system", "content": system_prompt})
    msgs.append({"role": "user", "content": user_text})
    return msgs

def _timeout_tuple(connect: float = 10.0, read: float = 60.0) -> Tuple[float, float]:
    return (connect, read)

class RouterRequestsClient:
    """
    Simple requests-only client for HF Router Chat Completions.
    Supports non-streaming (returns str) and streaming (yields token strings).

    NOTE: New code should prefer ChatClient above. This class is preserved for any
    legacy call sites that rely on direct HF Router access.
    """
    def __init__(
        self,
        model: str,
        fallback: Optional[str] = None,
        provider: Optional[str] = None,
        max_retries: int = 2,
        connect_timeout: float = 10.0,
        read_timeout: float = 60.0
    ):
        self.model = model
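        # A fallback identical to the primary model is treated as "no fallback".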
        self.fallback = fallback if fallback != model else None
        self.provider = provider
        self.headers = {"Authorization": f"Bearer {_require_token()}"}
        self.max_retries = max(0, int(max_retries))
        self.timeout = _timeout_tuple(connect_timeout, read_timeout)

    # -------- Non-stream (single text) --------
    def chat_nonstream(
        self,
        system_prompt: Optional[str],
        user_text: str,
        max_tokens: int,
        temperature: float,
        stop: Optional[List[str]] = None,
        frequency_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
    ) -> str:
        payload = {
            "model": _model_with_provider(self.model, self.provider),
            "messages": _mk_messages(system_prompt, user_text),
            "temperature": float(max(0.0, temperature)),
            "max_tokens": int(max_tokens),
            "stream": False,
        }
        if stop:
            payload["stop"] = stop
        if frequency_penalty is not None:
            payload["frequency_penalty"] = float(frequency_penalty)
        if presence_penalty is not None:
            payload["presence_penalty"] = float(presence_penalty)

        text, ok = self._try_once(payload)
        if ok:
            return text

        # fallback (if configured)
        if self.fallback:
            payload["model"] = _model_with_provider(self.fallback, self.provider)
            text, ok = self._try_once(payload)
            if ok:
                return text

        raise RuntimeError(f"Chat non-stream failed: model={self.model} fallback={self.fallback}")

    def _try_once(self, payload: dict) -> Tuple[str, bool]:
        last_err: Optional[Exception] = None
        for attempt in range(self.max_retries + 1):
            try:
                r = requests.post(ROUTER_URL, headers=self.headers, json=payload, timeout=self.timeout)
                if r.status_code >= 400:
                    logger.error("Router error %s: %s", r.status_code, r.text)
                    last_err = RuntimeError(f"{r.status_code}: {r.text}")
                    if attempt < self.max_retries:
                        # Back off briefly before retrying; skip the delay after the final attempt.
                        time.sleep(min(1.5 * (attempt + 1), 3.0))
                    continue
                data = r.json()
                return data["choices"][0]["message"]["content"], True
            except Exception as e:
                logger.error("Router request failure: %s", e)
                last_err = e
                if attempt < self.max_retries:
                    time.sleep(min(1.5 * (attempt + 1), 3.0))
        if last_err:
            logger.error("Router exhausted retries: %s", last_err)
        return "", False

    # -------- Streaming (yield token deltas) --------
    def chat_stream(
        self,
        system_prompt: Optional[str],
        user_text: str,
        max_tokens: int,
        temperature: float,
        stop: Optional[List[str]] = None,
        frequency_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
    ) -> Iterator[str]:
        payload = {
            "model": _model_with_provider(self.model, self.provider),
            "messages": _mk_messages(system_prompt, user_text),
            "temperature": float(max(0.0, temperature)),
            "max_tokens": int(max_tokens),
            "stream": True,
        }
        if stop:
            payload["stop"] = stop
        if frequency_penalty is not None:
            payload["frequency_penalty"] = float(frequency_penalty)
        if presence_penalty is not None:
            payload["presence_penalty"] = float(presence_penalty)

        # primary
        ok = False
        for token in self._stream_once(payload):
            ok = True
            yield token
        if ok:
            return
        # fallback stream if primary produced nothing (or died immediately)
        if self.fallback:
            payload["model"] = _model_with_provider(self.fallback, self.provider)
            for token in self._stream_once(payload):
                yield token

    def _stream_once(self, payload: dict) -> Iterator[str]:
        try:
            with requests.post(ROUTER_URL, headers=self.headers, json=payload, stream=True, timeout=self.timeout) as r:
                if r.status_code >= 400:
                    logger.error("Router stream error %s: %s", r.status_code, r.text)
                    return
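                # The Router streams Server-Sent Events: each payload arrives as a
                # line of the form "data: {json chunk}"; "data: [DONE]" ends the stream.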
                for line in r.iter_lines(decode_unicode=True):
                    if not line:
                        continue
                    if not line.startswith("data:"):
                        continue
                    data = line[len("data:"):].strip()
                    if data == "[DONE]":
                        return
                    try:
                        obj = json.loads(data)
                        choices = obj.get("choices") or []
                        if not choices:
                            # Some providers emit keep-alive/usage chunks with no choices; skip them.
                            continue
                        delta = choices[0].get("delta", {}).get("content", "")
                        if delta:
                            yield delta
                    except Exception as e:
                        logger.warning("Stream JSON parse issue: %s | line=%r", e, line)
                        continue
        except Exception as e:
            logger.error("Stream request failure: %s", e)
            return

    # -------- Planning (non-stream) --------
    def plan_nonstream(self, system_prompt: str, user_text: str,
                       max_tokens: int, temperature: float) -> str:
        return self.chat_nonstream(system_prompt, user_text, max_tokens, temperature)
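

# Legacy usage sketch (illustrative only; never executed at import time). It assumes
# HF_TOKEN is set and uses an example model id; the helper name is hypothetical and
# not exported. New code should prefer ChatClient above.
def _example_router_usage() -> None:  # pragma: no cover - documentation example
    router = RouterRequestsClient(
        model="meta-llama/Llama-3.1-8B-Instruct",  # example model id; substitute your own
        fallback=None,
        provider=None,
        max_retries=1,
    )
    # Non-streaming: one request, full answer returned as a string.
    answer = router.chat_nonstream(
        system_prompt="You are concise.",
        user_text="Summarize what this client does in one sentence.",
        max_tokens=128,
        temperature=0.3,
    )
    print(answer)
    # Streaming: token deltas are yielded as the Router sends SSE chunks.
    for token in router.chat_stream(None, "Stream a short greeting.", max_tokens=64, temperature=0.7):
        print(token, end="", flush=True)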


__all__ = [
    "ChatClient",
    "chat",
    "get_client",
    "RouterRequestsClient",
]