Spaces:
Running
Running
| import json | |
| import time | |
| from typing import AsyncIterator, Literal, Optional, Protocol | |
| import httpx | |
| from groq import AsyncGroq | |
| from tenacity import retry, stop_after_attempt, wait_fixed, retry_if_exception_type | |
| from app.core.config import Settings | |
| from app.core.exceptions import GenerationError | |
class TpmBucket:
    """
    Sliding 60-second token-consumption tracker shared across all Groq calls.

    Issue 7: When the bucket exceeds 12,000 estimated tokens in the current
    minute window, complete_with_complexity() downgrades 70B calls to 8B
    automatically. This leaves 2,400 TPM headroom and prevents hard failures
    (HTTP 429) from degrading the service under load.

    Token estimates are rough (prompt_chars / 4) but accurate enough for this
    protective purpose — the goal is load shedding, not exact accounting.
    """

    _WINDOW_SECONDS: int = 60
    _DOWNGRADE_THRESHOLD: int = 12_000

    def __init__(self) -> None:
        # Running token estimate for the current window, and when it opened.
        self._count: int = 0
        self._window_start: float = time.monotonic()

    def _window_expired(self, now: float) -> bool:
        # True once the current 60-second accounting window has elapsed.
        return now - self._window_start >= self._WINDOW_SECONDS

    def add(self, estimated_tokens: int) -> None:
        """Record *estimated_tokens* of usage, rolling the window if stale."""
        now = time.monotonic()
        if self._window_expired(now):
            # Open a fresh window; prior usage no longer counts against it.
            self._count = 0
            self._window_start = now
        self._count += estimated_tokens

    def should_downgrade(self) -> bool:
        """Return True when current-window usage exceeds the downgrade threshold."""
        if self._window_expired(time.monotonic()):
            # Window is stale, so whatever was counted no longer applies.
            return False
        return self._count > self._DOWNGRADE_THRESHOLD
class LLMClient(Protocol):
    """
    Structural interface satisfied by GroqClient and OllamaClient.

    Note on signatures: implementations of complete() and
    complete_with_complexity() are *async generators* (their bodies contain
    ``yield``). Calling an async generator function returns an AsyncIterator
    immediately, without awaiting — so the protocol declares those two as
    plain ``def`` returning ``AsyncIterator[str]``. Declaring them
    ``async def -> AsyncIterator[str]`` would describe a coroutine that
    resolves to an iterator, a shape no implementation actually has, and
    type checkers would reject the concrete clients.
    """

    def complete(self, prompt: str, system: str, stream: bool) -> AsyncIterator[str]:
        """Stream completion tokens for *prompt* using the default model."""
        ...

    async def classify_complexity(self, query: str) -> Literal["simple", "complex"]:
        """Classify *query* for model routing."""
        ...

    def complete_with_complexity(self, prompt: str, system: str, stream: bool, complexity: str) -> AsyncIterator[str]:
        """Stream completion tokens, choosing the model from a pre-classified *complexity*."""
        ...
class GroqClient:
    """
    Groq-backed LLM client with complexity-based model routing.

    'complex' queries go to the large model, everything else to the default
    model. Issue 7 load shedding: when the shared TpmBucket reports the
    current minute window is over threshold, complex calls are downgraded to
    the default model instead of risking hard HTTP 429 failures.
    """

    def __init__(self, api_key: str, model_default: str, model_large: str, tpm_bucket: Optional[TpmBucket] = None):
        """
        Args:
            api_key: Groq API key. Empty or the 'gsk_placeholder' sentinel
                leaves the client unconfigured (test contexts).
            model_default: fast/small model, also used for classification.
            model_large: large model used for 'complex' completions.
            tpm_bucket: shared TPM tracker; None disables load shedding.
        """
        if not api_key or api_key == "gsk_placeholder":
            # We might be initialized in a test context without a real key
            self.client = None
        else:
            self.client = AsyncGroq(api_key=api_key)
        self.model_default = model_default
        self.model_large = model_large
        # Shared TPM bucket — injected at startup, None in test contexts.
        self._tpm_bucket = tpm_bucket

    async def classify_complexity(self, query: str) -> Literal["simple", "complex"]:
        """Classify *query* as 'simple' or 'complex' using the default model.

        Raises:
            GenerationError: when the client has no API key configured.
        """
        if not self.client:
            raise GenerationError("GroqClient not configured with an API Key.")
        system = "You are a classifier. Read the user query. Output ONLY the word 'simple' or 'complex'. Do not explain."
        try:
            response = await self.client.chat.completions.create(
                messages=[
                    {"role": "system", "content": system},
                    {"role": "user", "content": query}
                ],
                model=self.model_default,
                temperature=0.0,
                max_tokens=10,
                timeout=3.0,
            )
            result = response.choices[0].message.content.strip().lower()
            if "complex" in result:
                return "complex"
            return "simple"
        except Exception:
            # Fallback to complex just to be safe if classification fails on parsing
            return "complex"

    async def complete(self, prompt: str, system: str, stream: bool) -> AsyncIterator[str]:
        """Yield completion tokens for *prompt* from the default model.

        Always streams from the API regardless of the *stream* flag.

        Raises:
            GenerationError: when unconfigured or the API call fails.
        """
        if not self.client:
            raise GenerationError("GroqClient not configured with an API Key.")
        model = self.model_default
        try:
            stream_response = await self.client.chat.completions.create(
                messages=[
                    {"role": "system", "content": system},
                    {"role": "user", "content": prompt}
                ],
                model=model,
                stream=True
            )
            async for chunk in stream_response:
                content = chunk.choices[0].delta.content
                if content:
                    yield content
        except Exception as e:
            raise GenerationError("Groq completion failed", context={"error": str(e)}) from e

    async def complete_with_complexity(self, prompt: str, system: str, stream: bool, complexity: str) -> AsyncIterator[str]:
        """Yield a completion, picking the model from pre-classified *complexity*.

        Streams token chunks when *stream* is True; otherwise yields the full
        response text once.

        Issue 7: if the shared TPM bucket is above 12,000 tokens in the current
        minute window, downgrade 70B to 8B to prevent hard rate-limit failures.

        Raises:
            GenerationError: when unconfigured or the API call fails.
        """
        if not self.client:
            raise GenerationError("GroqClient not configured with an API Key.")
        # BUG FIX: should_downgrade was referenced without calling it. A bound
        # method object is always truthy, so every 'complex' request was being
        # silently downgraded to the default model regardless of actual TPM
        # load — the large model was effectively unreachable via this path.
        if complexity == "complex" and self._tpm_bucket is not None and self._tpm_bucket.should_downgrade():
            model = self.model_default
        else:
            model = self.model_large if complexity == "complex" else self.model_default
        # Estimate input tokens before the call so the bucket reflects the full
        # cost even when the response is long. 4 chars ≈ 1 token (rough heuristic).
        if self._tpm_bucket is not None:
            self._tpm_bucket.add((len(prompt) + len(system)) // 4)
        try:
            stream_response = await self.client.chat.completions.create(
                messages=[
                    {"role": "system", "content": system},
                    {"role": "user", "content": prompt}
                ],
                model=model,
                stream=stream
            )
            if stream:
                async for chunk in stream_response:
                    content = chunk.choices[0].delta.content
                    if content:
                        # Accumulate estimated response tokens in the bucket.
                        if self._tpm_bucket is not None:
                            self._tpm_bucket.add(len(content) // 4 or 1)
                        yield content
            else:
                # stream=False returns a plain completion object, not an iterator.
                full = stream_response.choices[0].message.content
                if self._tpm_bucket is not None and full:
                    self._tpm_bucket.add(len(full) // 4)
                yield full
        except Exception as e:
            raise GenerationError("Groq completion failed", context={"error": str(e)}) from e
class OllamaClient:
    """Ollama-backed LLM client; a single local model serves every request."""

    def __init__(self, base_url: str, model: str):
        # Normalize the base URL so path joins below never double a slash.
        self.base_url = base_url.rstrip("/")
        self.model = model

    async def classify_complexity(self, query: str) -> Literal["simple", "complex"]:
        """Ask the model to label *query*; any failure defaults to 'complex'."""
        system = "You are a classifier. Read the user query. Output ONLY the word 'simple' or 'complex'. Do not explain."
        payload = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": system},
                {"role": "user", "content": query}
            ],
            "stream": False,
            "options": {
                "temperature": 0.0,
                "num_predict": 10
            }
        }
        try:
            async with httpx.AsyncClient() as client:
                resp = await client.post(
                    f"{self.base_url}/api/chat",
                    json=payload,
                    timeout=3.0
                )
                resp.raise_for_status()
                verdict = resp.json().get("message", {}).get("content", "").strip().lower()
        except Exception:
            # Unclassifiable queries are treated as complex to be safe.
            return "complex"
        return "complex" if "complex" in verdict else "simple"

    async def complete(self, prompt: str, system: str, stream: bool) -> AsyncIterator[str]:
        """Yield response tokens from the Ollama streaming chat endpoint."""
        payload = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": system},
                {"role": "user", "content": prompt}
            ],
            "stream": True  # Force true per instruction
        }
        async with httpx.AsyncClient() as client:
            try:
                async with client.stream(
                    "POST",
                    f"{self.base_url}/api/chat",
                    json=payload
                ) as response:
                    response.raise_for_status()
                    async for raw_line in response.aiter_lines():
                        if not raw_line:
                            continue
                        try:
                            parsed = json.loads(raw_line)
                        except json.JSONDecodeError:
                            # Skip malformed/partial lines rather than abort the stream.
                            continue
                        if "message" in parsed and "content" in parsed["message"]:
                            yield parsed["message"]["content"]
            except Exception as e:
                raise GenerationError("Ollama completion failed", context={"error": str(e)}) from e

    async def complete_with_complexity(self, prompt: str, system: str, stream: bool, complexity: str) -> AsyncIterator[str]:
        # Ollama just uses one model in this implementation
        async for piece in self.complete(prompt, system, stream):
            yield piece
def get_llm_client(settings: Settings, tpm_bucket: Optional[TpmBucket] = None) -> LLMClient:
    """
    Factory: build the LLM client selected by ``settings.LLM_PROVIDER``.

    Args:
        settings: application settings carrying provider/model configuration.
        tpm_bucket: shared TPM tracker forwarded to the Groq client only.

    Raises:
        ValueError: when the provider is 'ollama' but its settings are incomplete.
    """
    if settings.LLM_PROVIDER == "ollama":
        if not settings.OLLAMA_BASE_URL or not settings.OLLAMA_MODEL:
            raise ValueError("OLLAMA_BASE_URL and OLLAMA_MODEL must be explicitly set when LLM_PROVIDER is 'ollama'")
        return OllamaClient(
            base_url=settings.OLLAMA_BASE_URL,
            model=settings.OLLAMA_MODEL
        )
    # Defaults to Groq
    return GroqClient(
        api_key=settings.GROQ_API_KEY or "",
        model_default=settings.GROQ_MODEL_DEFAULT,
        model_large=settings.GROQ_MODEL_LARGE,
        tpm_bucket=tpm_bucket,
    )