File size: 10,966 Bytes
bbe01fe
3d134a6
 
bbe01fe
 
 
 
 
 
 
 
 
3d134a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bbe01fe
 
 
 
 
 
 
 
 
 
 
 
3d134a6
bbe01fe
 
 
 
 
 
 
 
3d134a6
 
bbe01fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3d134a6
 
 
bbe01fe
 
3d134a6
 
 
 
 
 
 
 
 
 
 
bbe01fe
 
 
 
 
 
 
3d134a6
bbe01fe
 
 
 
 
 
3d134a6
 
 
bbe01fe
 
3d134a6
 
 
 
bbe01fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3d134a6
bbe01fe
 
 
 
 
 
 
 
 
 
 
 
3d134a6
 
bbe01fe
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
import json
import time
from typing import AsyncIterator, Literal, Optional, Protocol

import httpx
from groq import AsyncGroq
from tenacity import retry, stop_after_attempt, wait_fixed, retry_if_exception_type

from app.core.config import Settings
from app.core.exceptions import GenerationError


class TpmBucket:
    """
    Sliding 60-second token-consumption tracker shared across all Groq calls.

    Issue 7: When the bucket exceeds 12,000 estimated tokens in the current
    minute window, complete_with_complexity() downgrades 70B calls to 8B
    automatically.  This leaves 2,400 TPM headroom and prevents hard failures
    (HTTP 429) from degrading the service under load.

    Token estimates are rough (prompt_chars / 4) but accurate enough for this
    protective purpose — the goal is load shedding, not exact accounting.

    The window length and downgrade threshold default to the historical
    constants but may be overridden per instance (useful for tests and for
    tuning without code changes).
    """

    # Class-level defaults, retained for backward compatibility with any
    # external reader of these attributes.
    _WINDOW_SECONDS: int = 60
    _DOWNGRADE_THRESHOLD: int = 12_000

    def __init__(
        self,
        window_seconds: Optional[int] = None,
        downgrade_threshold: Optional[int] = None,
    ) -> None:
        # Fall back to the class defaults when not overridden, so
        # TpmBucket() behaves exactly as before.
        self._window_seconds: int = (
            self._WINDOW_SECONDS if window_seconds is None else window_seconds
        )
        self._threshold: int = (
            self._DOWNGRADE_THRESHOLD if downgrade_threshold is None else downgrade_threshold
        )
        self._count: int = 0  # tokens consumed in the current window
        self._window_start: float = time.monotonic()

    def add(self, estimated_tokens: int) -> None:
        """Record *estimated_tokens*, resetting the count when the window has expired.

        monotonic() is used so wall-clock adjustments never corrupt the window.
        """
        now = time.monotonic()
        if now - self._window_start >= self._window_seconds:
            self._count = 0
            self._window_start = now
        self._count += estimated_tokens

    @property
    def should_downgrade(self) -> bool:
        """True when the current window's total strictly exceeds the threshold.

        An expired window reports False immediately; the stale count is reset
        lazily on the next add().
        """
        now = time.monotonic()
        if now - self._window_start >= self._window_seconds:
            return False
        return self._count > self._threshold


class LLMClient(Protocol):
    """Structural interface shared by GroqClient and OllamaClient.

    Any object providing these three async methods satisfies the protocol;
    get_llm_client() returns one of the concrete implementations below.
    """

    async def complete(self, prompt: str, system: str, stream: bool) -> AsyncIterator[str]:
        """Yield completion text chunks for *prompt* under *system*."""
        ...

    async def classify_complexity(self, query: str) -> Literal["simple", "complex"]:
        """Return a complexity label for *query* used for model routing."""
        ...

    async def complete_with_complexity(self, prompt: str, system: str, stream: bool, complexity: str) -> AsyncIterator[str]:
        """Like complete(), but with a pre-classified *complexity* label."""
        ...


class GroqClient:
    """Groq-backed implementation of the LLMClient protocol.

    Routes work between `model_default` (small) and `model_large` (large)
    based on a pre-classified complexity label, and cooperates with an
    optional shared TpmBucket to shed load (Issue 7) by downgrading large
    calls when the token budget for the current minute is nearly exhausted.
    """

    def __init__(self, api_key: str, model_default: str, model_large: str, tpm_bucket: Optional[TpmBucket] = None):
        if not api_key or api_key == "gsk_placeholder":
            # We might be initialized in a test context without a real key;
            # methods guard on self.client and raise GenerationError.
            self.client = None
        else:
            self.client = AsyncGroq(api_key=api_key)

        self.model_default = model_default
        self.model_large = model_large
        # Shared TPM bucket — injected at startup, None in test contexts.
        self._tpm_bucket = tpm_bucket

    @retry(
        stop=stop_after_attempt(2),
        wait=wait_fixed(1.0),
        retry=retry_if_exception_type((httpx.RequestError, httpx.TimeoutException)),
        reraise=True,
    )
    async def _create_chat(self, messages: list, model: str, stream: bool, **kwargs):
        """Perform one chat-completion API call, retried on transient transport errors.

        Fix: previously @retry decorated the public methods directly, where it
        could never fire — classify_complexity swallowed every exception in a
        broad `except`, and complete() is an async-generator function whose
        errors surface during iteration, after tenacity has already returned.
        Keeping the retried unit a plain coroutine makes the retry effective.
        """
        return await self.client.chat.completions.create(
            messages=messages,
            model=model,
            stream=stream,
            **kwargs,
        )

    async def classify_complexity(self, query: str) -> Literal["simple", "complex"]:
        """Classify *query* with the small model; fail-open to 'complex'.

        Raises GenerationError only when the client was never configured.
        Any API/transport failure (after retries) or unexpected response
        shape falls back to 'complex' so answer quality never degrades.
        """
        if not self.client:
            raise GenerationError("GroqClient not configured with an API Key.")

        system = "You are a classifier. Read the user query. Output ONLY the word 'simple' or 'complex'. Do not explain."

        try:
            response = await self._create_chat(
                [
                    {"role": "system", "content": system},
                    {"role": "user", "content": query}
                ],
                self.model_default,
                False,
                temperature=0.0,  # deterministic single-word answer
                max_tokens=10,
                timeout=3.0,
            )
            result = response.choices[0].message.content.strip().lower()
        except Exception:
            # Fallback to complex just to be safe if classification fails.
            return "complex"

        return "complex" if "complex" in result else "simple"

    async def complete(self, prompt: str, system: str, stream: bool) -> AsyncIterator[str]:
        """Yield completion chunks from the default model.

        The `stream` flag is accepted for protocol compatibility but the call
        always streams (matching the original behavior).
        Raises GenerationError wrapping any underlying failure.
        """
        if not self.client:
            raise GenerationError("GroqClient not configured with an API Key.")

        try:
            stream_response = await self._create_chat(
                [
                    {"role": "system", "content": system},
                    {"role": "user", "content": prompt}
                ],
                self.model_default,
                True,
            )
            async for chunk in stream_response:
                content = chunk.choices[0].delta.content
                if content:
                    yield content
        except Exception as e:
            raise GenerationError("Groq completion failed", context={"error": str(e)}) from e

    async def complete_with_complexity(self, prompt: str, system: str, stream: bool, complexity: str) -> AsyncIterator[str]:
        """Complete using the model implied by the pre-classified *complexity*.

        Issue 7: if the shared TPM bucket is above its threshold in the
        current minute window, downgrade large-model calls to the default
        model to prevent hard rate-limit (HTTP 429) failures.

        Yields token chunks when `stream` is True, otherwise yields the full
        response text exactly once.
        """
        if not self.client:
            raise GenerationError("GroqClient not configured with an API Key.")

        wants_large = complexity == "complex"
        if wants_large and self._tpm_bucket is not None and self._tpm_bucket.should_downgrade:
            model = self.model_default  # load-shed: degrade gracefully under pressure
        else:
            model = self.model_large if wants_large else self.model_default

        # Estimate input tokens before the call so the bucket reflects the full
        # cost even when the response is long.  4 chars ≈ 1 token (rough heuristic).
        if self._tpm_bucket is not None:
            self._tpm_bucket.add((len(prompt) + len(system)) // 4)

        try:
            response = await self._create_chat(
                [
                    {"role": "system", "content": system},
                    {"role": "user", "content": prompt}
                ],
                model,
                stream,
            )
            if stream:
                async for chunk in response:
                    content = chunk.choices[0].delta.content
                    if content:
                        # Accumulate estimated response tokens in the bucket
                        # (minimum 1 so short chunks still count).
                        if self._tpm_bucket is not None:
                            self._tpm_bucket.add(len(content) // 4 or 1)
                        yield content
            else:
                full = response.choices[0].message.content
                if self._tpm_bucket is not None and full:
                    self._tpm_bucket.add(len(full) // 4)
                yield full
        except Exception as e:
            raise GenerationError("Groq completion failed", context={"error": str(e)}) from e


class OllamaClient:
    """Ollama-backed implementation of the LLMClient protocol.

    A single local model serves every complexity tier, so
    complete_with_complexity() simply delegates to complete().
    """

    def __init__(self, base_url: str, model: str):
        # Normalize so the f-string URL joins below never double a slash.
        self.base_url = base_url.rstrip("/")
        self.model = model

    @retry(
        stop=stop_after_attempt(2),
        wait=wait_fixed(1.0),
        retry=retry_if_exception_type((httpx.RequestError, httpx.TimeoutException)),
        reraise=True,
    )
    async def _classify_once(self, query: str) -> str:
        """One classification attempt against /api/chat; returns the lowered text.

        Fix: transient httpx errors must propagate out of the retried unit for
        @retry to re-invoke it — previously the broad `except` sat *inside*
        the decorated method and swallowed them, making the retry a no-op.
        """
        system = "You are a classifier. Read the user query. Output ONLY the word 'simple' or 'complex'. Do not explain."

        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{self.base_url}/api/chat",
                json={
                    "model": self.model,
                    "messages": [
                        {"role": "system", "content": system},
                        {"role": "user", "content": query}
                    ],
                    "stream": False,
                    "options": {
                        "temperature": 0.0,  # deterministic single-word answer
                        "num_predict": 10
                    }
                },
                timeout=3.0
            )
            response.raise_for_status()
            data = response.json()
            return data.get("message", {}).get("content", "").strip().lower()

    async def classify_complexity(self, query: str) -> Literal["simple", "complex"]:
        """Classify *query*; fail-open to 'complex' on any error (after retries)."""
        try:
            result = await self._classify_once(query)
        except Exception:
            return "complex"
        return "complex" if "complex" in result else "simple"

    # NOTE(review): tenacity's @retry wraps the *call* of this async-generator
    # function; exceptions raised later, during iteration, never reach the
    # decorator, so it is effectively inert here.  Kept for parity — confirm
    # before relying on it for transport-level retries.
    @retry(stop=stop_after_attempt(2), wait=wait_fixed(1.0), retry=retry_if_exception_type((httpx.RequestError, httpx.TimeoutException)))
    async def complete(self, prompt: str, system: str, stream: bool) -> AsyncIterator[str]:
        """Stream tokens from the Ollama NDJSON chat endpoint.

        The `stream` flag is accepted for protocol compatibility but the
        request always streams.  Raises GenerationError on any failure.
        """
        async with httpx.AsyncClient() as client:
            try:
                async with client.stream(
                    "POST",
                    f"{self.base_url}/api/chat",
                    json={
                        "model": self.model,
                        "messages": [
                            {"role": "system", "content": system},
                            {"role": "user", "content": prompt}
                        ],
                        "stream": True # Force true per instruction
                    }
                ) as response:
                    response.raise_for_status()
                    # Each NDJSON line is one JSON object; malformed lines are skipped.
                    async for line in response.aiter_lines():
                        if not line:
                            continue
                        try:
                            data = json.loads(line)
                        except json.JSONDecodeError:
                            continue
                        if "message" in data and "content" in data["message"]:
                            yield data["message"]["content"]

            except Exception as e:
                raise GenerationError("Ollama completion failed", context={"error": str(e)}) from e

    async def complete_with_complexity(self, prompt: str, system: str, stream: bool, complexity: str) -> AsyncIterator[str]:
        # Ollama just uses one model in this implementation, so the
        # complexity label is intentionally ignored.
        async for token in self.complete(prompt, system, stream):
            yield token


def get_llm_client(settings: Settings, tpm_bucket: Optional[TpmBucket] = None) -> LLMClient:
    """Build the LLM client selected by ``settings.LLM_PROVIDER``.

    Any provider value other than "ollama" falls through to Groq.  The
    TPM bucket is only meaningful for Groq and is passed straight through.
    Raises ValueError when the Ollama provider is selected without its
    required settings.
    """
    provider = settings.LLM_PROVIDER

    if provider != "ollama":
        # Defaults to Groq
        return GroqClient(
            api_key=settings.GROQ_API_KEY or "",
            model_default=settings.GROQ_MODEL_DEFAULT,
            model_large=settings.GROQ_MODEL_LARGE,
            tpm_bucket=tpm_bucket,
        )

    base_url = settings.OLLAMA_BASE_URL
    model = settings.OLLAMA_MODEL
    if not base_url or not model:
        raise ValueError("OLLAMA_BASE_URL and OLLAMA_MODEL must be explicitly set when LLM_PROVIDER is 'ollama'")

    return OllamaClient(base_url=base_url, model=model)