# personabot-api/app/services/llm_client.py
# (GitHub Actions deploy view residue: Deploy 85f07db / 3d134a6)
import json
import time
from typing import AsyncIterator, Literal, Optional, Protocol
import httpx
from groq import AsyncGroq
from tenacity import retry, stop_after_attempt, wait_fixed, retry_if_exception_type
from app.core.config import Settings
from app.core.exceptions import GenerationError
class TpmBucket:
    """
    Tracks estimated token consumption over a sliding 60-second window,
    shared by every Groq call.

    Issue 7: once more than 12,000 estimated tokens have been recorded in
    the current minute, complete_with_complexity() silently routes 70B
    requests to the 8B model. That keeps ~2,400 TPM of headroom and avoids
    hard rate-limit failures (HTTP 429) from degrading the service under
    load.

    Estimates are coarse (prompt_chars / 4) but good enough here — the goal
    is load shedding, not billing-grade accounting.
    """

    _WINDOW_SECONDS: int = 60
    _DOWNGRADE_THRESHOLD: int = 12_000

    def __init__(self) -> None:
        self._count: int = 0
        self._window_start: float = time.monotonic()

    def _window_expired(self, moment: float) -> bool:
        # True once the current minute window has fully elapsed.
        return moment - self._window_start >= self._WINDOW_SECONDS

    def add(self, estimated_tokens: int) -> None:
        """Record *estimated_tokens*, rolling the window forward if stale."""
        moment = time.monotonic()
        if self._window_expired(moment):
            self._count = 0
            self._window_start = moment
        self._count += estimated_tokens

    @property
    def should_downgrade(self) -> bool:
        """Whether 70B calls should currently be downgraded to 8B."""
        if self._window_expired(time.monotonic()):
            return False
        return self._count > self._DOWNGRADE_THRESHOLD
class LLMClient(Protocol):
    """Structural interface implemented by GroqClient and OllamaClient.

    The streaming methods are declared WITHOUT `async` on purpose: the
    concrete implementations are async generator functions, and the call
    type of an async generator function is `(...) -> AsyncIterator[str]` —
    calling it returns the iterator directly, with no await. Declaring them
    `async def` here would describe a coroutine that must be awaited to get
    the iterator, which the implementations do not satisfy.
    """

    def complete(self, prompt: str, system: str, stream: bool) -> AsyncIterator[str]:
        """Yield completion chunks for *prompt* under *system*."""
        ...

    async def classify_complexity(self, query: str) -> Literal["simple", "complex"]:
        """Label *query* as 'simple' or 'complex' (a plain coroutine)."""
        ...

    def complete_with_complexity(self, prompt: str, system: str, stream: bool, complexity: str) -> AsyncIterator[str]:
        """Like complete(), routed by a pre-classified *complexity* label."""
        ...
class GroqClient:
    """Groq-backed implementation of the LLMClient protocol.

    Routes 'complex' queries to the large model and everything else to the
    default model. When a shared TpmBucket is provided, large-model calls are
    transparently downgraded under token-rate pressure (Issue 7).
    """

    def __init__(self, api_key: str, model_default: str, model_large: str, tpm_bucket: Optional[TpmBucket] = None):
        if not api_key or api_key == "gsk_placeholder":
            # We might be initialized in a test context without a real key;
            # every call then raises GenerationError instead of hitting Groq.
            self.client = None
        else:
            self.client = AsyncGroq(api_key=api_key)
        self.model_default = model_default
        self.model_large = model_large
        # Shared TPM bucket — injected at startup, None in test contexts.
        self._tpm_bucket = tpm_bucket

    @retry(stop=stop_after_attempt(2), wait=wait_fixed(1.0), retry=retry_if_exception_type((httpx.RequestError, httpx.TimeoutException)))
    async def classify_complexity(self, query: str) -> Literal["simple", "complex"]:
        """Classify *query* as 'simple' or 'complex' via the default model.

        Raises:
            GenerationError: if no API key is configured.

        Any other failure (network, API, parsing) falls back to 'complex'
        so the pipeline errs toward the stronger model.
        """
        if not self.client:
            raise GenerationError("GroqClient not configured with an API Key.")
        system = "You are a classifier. Read the user query. Output ONLY the word 'simple' or 'complex'. Do not explain."
        try:
            response = await self.client.chat.completions.create(
                messages=[
                    {"role": "system", "content": system},
                    {"role": "user", "content": query}
                ],
                model=self.model_default,
                temperature=0.0,
                max_tokens=10,
                timeout=3.0,
            )
            result = response.choices[0].message.content.strip().lower()
            if "complex" in result:
                return "complex"
            return "simple"
        except Exception:
            # Fallback to complex just to be safe if classification fails.
            # NOTE(review): this broad except also swallows the httpx errors
            # the @retry decorator targets, so retries never actually fire
            # for this method — confirm whether that is intended.
            return "complex"

    @retry(stop=stop_after_attempt(2), wait=wait_fixed(1.0), retry=retry_if_exception_type((httpx.RequestError, httpx.TimeoutException)))
    async def complete(self, prompt: str, system: str, stream: bool) -> AsyncIterator[str]:
        """Stream non-empty completion deltas from the default model.

        Always streams from Groq regardless of *stream* (callers iterate
        with `async for`). Raises GenerationError on any API failure.

        NOTE(review): tenacity's @retry does not await async-generator
        functions, so this decorator is likely a no-op here — TODO confirm,
        and if so, retry around the create() call instead.
        """
        if not self.client:
            raise GenerationError("GroqClient not configured with an API Key.")
        model = self.model_default
        try:
            stream_response = await self.client.chat.completions.create(
                messages=[
                    {"role": "system", "content": system},
                    {"role": "user", "content": prompt}
                ],
                model=model,
                stream=True
            )
            async for chunk in stream_response:
                content = chunk.choices[0].delta.content
                if content:
                    yield content
        except Exception as e:
            raise GenerationError("Groq completion failed", context={"error": str(e)}) from e

    async def complete_with_complexity(self, prompt: str, system: str, stream: bool, complexity: str) -> AsyncIterator[str]:
        """Complete *prompt*, choosing the model from a pre-classified label.

        Helper that lets pipeline nodes pass the complexity they already
        classified. Issue 7: if the shared TPM bucket is above 12,000 tokens
        in the current minute window, 70B is downgraded to 8B to prevent
        hard rate-limit failures.

        Raises:
            GenerationError: if no API key is configured or the call fails.
        """
        if not self.client:
            raise GenerationError("GroqClient not configured with an API Key.")
        if complexity == "complex" and self._tpm_bucket is not None and self._tpm_bucket.should_downgrade:
            model = self.model_default
        else:
            model = self.model_large if complexity == "complex" else self.model_default
        # Estimate input tokens before the call so the bucket reflects the
        # full cost even when the response is long. 4 chars ≈ 1 token.
        if self._tpm_bucket is not None:
            self._tpm_bucket.add((len(prompt) + len(system)) // 4)
        try:
            stream_response = await self.client.chat.completions.create(
                messages=[
                    {"role": "system", "content": system},
                    {"role": "user", "content": prompt}
                ],
                model=model,
                stream=stream  # stream=True yields token chunks per instruction.
            )
            if stream:
                async for chunk in stream_response:
                    content = chunk.choices[0].delta.content
                    if content:
                        # Accumulate estimated response tokens in the bucket
                        # (minimum 1 per chunk so short deltas still count).
                        if self._tpm_bucket is not None:
                            self._tpm_bucket.add(len(content) // 4 or 1)
                        yield content
            else:
                full = stream_response.choices[0].message.content
                if self._tpm_bucket is not None and full:
                    self._tpm_bucket.add(len(full) // 4)
                yield full
        except Exception as e:
            raise GenerationError("Groq completion failed", context={"error": str(e)}) from e
class OllamaClient:
    """LLM client backed by a local Ollama server's /api/chat endpoint."""

    def __init__(self, base_url: str, model: str):
        self.base_url = base_url.rstrip("/")
        self.model = model

    @retry(stop=stop_after_attempt(2), wait=wait_fixed(1.0), retry=retry_if_exception_type((httpx.RequestError, httpx.TimeoutException)))
    async def classify_complexity(self, query: str) -> Literal["simple", "complex"]:
        """Label *query* via the model; any failure falls back to 'complex'."""
        system = "You are a classifier. Read the user query. Output ONLY the word 'simple' or 'complex'. Do not explain."
        payload = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": system},
                {"role": "user", "content": query}
            ],
            "stream": False,
            "options": {
                "temperature": 0.0,
                "num_predict": 10
            }
        }
        try:
            async with httpx.AsyncClient() as http:
                resp = await http.post(f"{self.base_url}/api/chat", json=payload, timeout=3.0)
                resp.raise_for_status()
                body = resp.json()
            verdict = body.get("message", {}).get("content", "").strip().lower()
            return "complex" if "complex" in verdict else "simple"
        except Exception:
            # Err toward the stronger path on any network/HTTP/parse failure.
            return "complex"

    @retry(stop=stop_after_attempt(2), wait=wait_fixed(1.0), retry=retry_if_exception_type((httpx.RequestError, httpx.TimeoutException)))
    async def complete(self, prompt: str, system: str, stream: bool) -> AsyncIterator[str]:
        """Stream chat content chunks from Ollama (always streams)."""
        request_body = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": system},
                {"role": "user", "content": prompt}
            ],
            "stream": True  # Force true per instruction
        }
        async with httpx.AsyncClient() as http:
            try:
                async with http.stream("POST", f"{self.base_url}/api/chat", json=request_body) as resp:
                    resp.raise_for_status()
                    async for raw_line in resp.aiter_lines():
                        if not raw_line:
                            continue
                        try:
                            parsed = json.loads(raw_line)
                        except json.JSONDecodeError:
                            # Skip malformed NDJSON lines silently.
                            continue
                        if "message" in parsed and "content" in parsed["message"]:
                            yield parsed["message"]["content"]
            except Exception as e:
                raise GenerationError("Ollama completion failed", context={"error": str(e)}) from e

    async def complete_with_complexity(self, prompt: str, system: str, stream: bool, complexity: str) -> AsyncIterator[str]:
        """Delegate to complete(); a single model serves every complexity."""
        async for piece in self.complete(prompt, system, stream):
            yield piece
def get_llm_client(settings: Settings, tpm_bucket: Optional[TpmBucket] = None) -> LLMClient:
    """Factory: build the LLM client selected by settings.LLM_PROVIDER.

    'ollama' requires OLLAMA_BASE_URL and OLLAMA_MODEL to be set; any other
    provider value falls through to Groq.
    """
    if settings.LLM_PROVIDER != "ollama":
        # Defaults to Groq
        return GroqClient(
            api_key=settings.GROQ_API_KEY or "",
            model_default=settings.GROQ_MODEL_DEFAULT,
            model_large=settings.GROQ_MODEL_LARGE,
            tpm_bucket=tpm_bucket,
        )
    if not settings.OLLAMA_BASE_URL or not settings.OLLAMA_MODEL:
        raise ValueError("OLLAMA_BASE_URL and OLLAMA_MODEL must be explicitly set when LLM_PROVIDER is 'ollama'")
    return OllamaClient(base_url=settings.OLLAMA_BASE_URL, model=settings.OLLAMA_MODEL)