import json
import time
from typing import AsyncIterator, Literal, Optional, Protocol

import httpx
from groq import AsyncGroq
from tenacity import retry, stop_after_attempt, wait_fixed, retry_if_exception_type

from app.core.config import Settings
from app.core.exceptions import GenerationError


class TpmBucket:
    """Sliding 60-second token-consumption tracker shared across all Groq calls.

    Issue 7: when more than 12,000 estimated tokens have been consumed inside
    the current minute window, complete_with_complexity() downgrades 70B calls
    to 8B automatically. This leaves 2,400 TPM headroom and prevents hard
    failures (HTTP 429) from degrading the service under load.

    Token estimates are rough (prompt_chars / 4) but accurate enough for this
    protective purpose — the goal is load shedding, not exact accounting.
    """

    _WINDOW_SECONDS: int = 60
    _DOWNGRADE_THRESHOLD: int = 12_000

    def __init__(self) -> None:
        # Tokens recorded in the window that began at _window_start.
        self._count: int = 0
        self._window_start: float = time.monotonic()

    def _roll_window(self, now: float) -> None:
        # Start a fresh window (and zero the count) once the previous one has
        # fully elapsed; `now` is reused as the new window origin.
        if now - self._window_start >= self._WINDOW_SECONDS:
            self._count = 0
            self._window_start = now

    def add(self, estimated_tokens: int) -> None:
        """Record *estimated_tokens* of consumption in the current window."""
        self._roll_window(time.monotonic())
        self._count += estimated_tokens

    @property
    def should_downgrade(self) -> bool:
        """True while the live window's consumption exceeds the threshold."""
        elapsed = time.monotonic() - self._window_start
        if elapsed >= self._WINDOW_SECONDS:
            # Window is stale: nothing has been recorded this minute, so
            # there is no pressure. (The counters roll on the next add().)
            return False
        return self._count > self._DOWNGRADE_THRESHOLD


class LLMClient(Protocol):
    """Structural interface every LLM backend in this module satisfies."""

    async def complete(self, prompt: str, system: str, stream: bool) -> AsyncIterator[str]:
        ...

    async def classify_complexity(self, query: str) -> Literal["simple", "complex"]:
        ...

    async def complete_with_complexity(self, prompt: str, system: str, stream: bool, complexity: str) -> AsyncIterator[str]:
        ...
class GroqClient:
    """LLM client backed by the Groq API.

    Holds two model names: ``model_default`` (small, always safe) and
    ``model_large`` (used for "complex" queries). An optional shared
    TpmBucket lets complete_with_complexity() downgrade large-model calls
    under token-rate pressure (Issue 7).
    """

    def __init__(self, api_key: str, model_default: str, model_large: str, tpm_bucket: Optional[TpmBucket] = None):
        """Build the client.

        A missing or placeholder key leaves ``self.client`` as None; every
        API method raises GenerationError in that state.
        """
        if not api_key or api_key == "gsk_placeholder":
            # We might be initialized in a test context without a real key
            self.client = None
        else:
            self.client = AsyncGroq(api_key=api_key)
        self.model_default = model_default
        self.model_large = model_large
        # Shared TPM bucket — injected at startup, None in test contexts.
        self._tpm_bucket = tpm_bucket

    @retry(stop=stop_after_attempt(2), wait=wait_fixed(1.0), retry=retry_if_exception_type((httpx.RequestError, httpx.TimeoutException)))
    async def classify_complexity(self, query: str) -> Literal["simple", "complex"]:
        """Classify *query* as 'simple' or 'complex' using the default model.

        Fails safe: any API or parsing problem — including an empty/None
        completion — yields "complex" so the query is routed to the
        stronger model.

        Raises:
            GenerationError: the client was built without an API key.
        """
        if not self.client:
            raise GenerationError("GroqClient not configured with an API Key.")
        system = "You are a classifier. Read the user query. Output ONLY the word 'simple' or 'complex'. Do not explain."
        try:
            response = await self.client.chat.completions.create(
                messages=[
                    {"role": "system", "content": system},
                    {"role": "user", "content": query}
                ],
                model=self.model_default,
                temperature=0.0,
                max_tokens=10,
                timeout=3.0,
            )
            content = response.choices[0].message.content
            if not content:
                # None/empty completion previously crashed on .strip() and
                # reached the except fallback; return the same result directly.
                return "complex"
            return "complex" if "complex" in content.strip().lower() else "simple"
        except Exception:
            # Fallback to complex just to be safe if classification fails.
            return "complex"

    # NOTE(review): tenacity's plain @retry likely never re-runs an async
    # *generator* function — the wrapped call returns the generator
    # immediately and exceptions only surface during iteration. Confirm
    # whether the retry decorators on the streaming methods ever fire.
    @retry(stop=stop_after_attempt(2), wait=wait_fixed(1.0), retry=retry_if_exception_type((httpx.RequestError, httpx.TimeoutException)))
    async def complete(self, prompt: str, system: str, stream: bool) -> AsyncIterator[str]:
        """Stream a completion from the default model, yielding text chunks.

        ``stream`` is accepted for interface compatibility but the request is
        always made with stream=True.

        Raises:
            GenerationError: no API key, or the Groq call fails.
        """
        if not self.client:
            raise GenerationError("GroqClient not configured with an API Key.")
        model = self.model_default
        try:
            stream_response = await self.client.chat.completions.create(
                messages=[
                    {"role": "system", "content": system},
                    {"role": "user", "content": prompt}
                ],
                model=model,
                stream=True
            )
            async for chunk in stream_response:
                content = chunk.choices[0].delta.content
                if content:
                    yield content
        except Exception as e:
            raise GenerationError("Groq completion failed", context={"error": str(e)}) from e

    async def complete_with_complexity(self, prompt: str, system: str, stream: bool, complexity: str) -> AsyncIterator[str]:
        """Complete using the pre-classified *complexity* to pick the model.

        Issue 7: if the shared TPM bucket is above 12,000 tokens in the
        current minute window, downgrade the large model to the default one
        to prevent hard rate-limit failures.

        Raises:
            GenerationError: no API key, or the Groq call fails.
        """
        if not self.client:
            raise GenerationError("GroqClient not configured with an API Key.")

        wants_large = complexity == "complex"
        if wants_large and self._tpm_bucket is not None and self._tpm_bucket.should_downgrade:
            # Load-shed: large-model call downgraded under TPM pressure.
            model = self.model_default
        else:
            model = self.model_large if wants_large else self.model_default

        # Estimate input tokens before the call so the bucket reflects the full
        # cost even when the response is long. 4 chars ≈ 1 token (rough heuristic).
        if self._tpm_bucket is not None:
            self._tpm_bucket.add((len(prompt) + len(system)) // 4)

        try:
            response = await self.client.chat.completions.create(
                messages=[
                    {"role": "system", "content": system},
                    {"role": "user", "content": prompt}
                ],
                model=model,
                stream=stream
            )
            if stream:
                async for chunk in response:
                    content = chunk.choices[0].delta.content
                    if content:
                        if self._tpm_bucket is not None:
                            # Accumulate estimated response tokens; `or 1`
                            # charges at least one token per non-empty chunk.
                            self._tpm_bucket.add(len(content) // 4 or 1)
                        yield content
            else:
                full = response.choices[0].message.content
                if self._tpm_bucket is not None and full:
                    self._tpm_bucket.add(len(full) // 4)
                yield full
        except Exception as e:
            raise GenerationError("Groq completion failed", context={"error": str(e)}) from e


class OllamaClient:
    """LLM client backed by a local Ollama server; one model serves all calls."""

    def __init__(self, base_url: str, model: str):
        self.base_url = base_url.rstrip("/")
        self.model = model

    @retry(stop=stop_after_attempt(2), wait=wait_fixed(1.0), retry=retry_if_exception_type((httpx.RequestError, httpx.TimeoutException)))
    async def classify_complexity(self, query: str) -> Literal["simple", "complex"]:
        """Classify *query* as 'simple' or 'complex' via /api/chat.

        Fails safe: any API or parsing problem yields "complex".
        """
        system = "You are a classifier. Read the user query. Output ONLY the word 'simple' or 'complex'. Do not explain."
        try:
            async with httpx.AsyncClient() as client:
                response = await client.post(
                    f"{self.base_url}/api/chat",
                    json={
                        "model": self.model,
                        "messages": [
                            {"role": "system", "content": system},
                            {"role": "user", "content": query}
                        ],
                        "stream": False,
                        "options": {
                            "temperature": 0.0,
                            "num_predict": 10
                        }
                    },
                    timeout=3.0
                )
                response.raise_for_status()
                data = response.json()
                # `or ""` guards against an explicit null content field.
                result = (data.get("message", {}).get("content") or "").strip().lower()
                return "complex" if "complex" in result else "simple"
        except Exception:
            return "complex"

    @retry(stop=stop_after_attempt(2), wait=wait_fixed(1.0), retry=retry_if_exception_type((httpx.RequestError, httpx.TimeoutException)))
    async def complete(self, prompt: str, system: str, stream: bool) -> AsyncIterator[str]:
        """Stream a completion from Ollama, yielding text chunks.

        ``stream`` is accepted for interface compatibility; the request
        always streams (one JSON object per response line). Malformed JSON
        lines are skipped rather than aborting the stream.

        Raises:
            GenerationError: the Ollama call fails.
        """
        async with httpx.AsyncClient() as client:
            try:
                async with client.stream(
                    "POST",
                    f"{self.base_url}/api/chat",
                    json={
                        "model": self.model,
                        "messages": [
                            {"role": "system", "content": system},
                            {"role": "user", "content": prompt}
                        ],
                        "stream": True  # Force true per instruction
                    }
                ) as response:
                    response.raise_for_status()
                    async for line in response.aiter_lines():
                        if not line:
                            continue
                        try:
                            data = json.loads(line)
                        except json.JSONDecodeError:
                            continue
                        if "message" in data and "content" in data["message"]:
                            yield data["message"]["content"]
            except Exception as e:
                raise GenerationError("Ollama completion failed", context={"error": str(e)}) from e

    async def complete_with_complexity(self, prompt: str, system: str, stream: bool, complexity: str) -> AsyncIterator[str]:
        """Delegate to complete(); Ollama uses one model regardless of complexity."""
        async for token in self.complete(prompt, system, stream):
            yield token


def get_llm_client(settings: Settings, tpm_bucket: Optional[TpmBucket] = None) -> LLMClient:
    """Build the configured LLM client.

    Args:
        settings: application settings; LLM_PROVIDER selects the backend.
        tpm_bucket: optional shared token bucket (used by GroqClient only).

    Raises:
        ValueError: LLM_PROVIDER is 'ollama' but its URL/model are unset.
    """
    if settings.LLM_PROVIDER == "ollama":
        if not settings.OLLAMA_BASE_URL or not settings.OLLAMA_MODEL:
            raise ValueError("OLLAMA_BASE_URL and OLLAMA_MODEL must be explicitly set when LLM_PROVIDER is 'ollama'")
        return OllamaClient(
            base_url=settings.OLLAMA_BASE_URL,
            model=settings.OLLAMA_MODEL
        )
    # Defaults to Groq
    return GroqClient(
        api_key=settings.GROQ_API_KEY or "",
        model_default=settings.GROQ_MODEL_DEFAULT,
        model_large=settings.GROQ_MODEL_LARGE,
        tpm_bucket=tpm_bucket,
    )