"""Provider-agnostic LLM client abstraction (Gemini & Groq via OpenAI API).""" from __future__ import annotations import os from dataclasses import dataclass import time from typing import Any from google import genai from core.config import get_settings from core.logger import get_logger logger = get_logger(__name__) @dataclass(frozen=True) class UsageMetrics: prompt_tokens: int completion_tokens: int total_tokens: int latency_ms: float estimated_cost_usd: float class LLMClient: """Simple LLM client with `chat` and usage logging.""" def __init__( self, provider: str | None = None, model: str | None = None, temperature: float | None = None, ) -> None: settings = get_settings() self.provider = (provider or settings.llm_provider).lower() self.model = model or settings.llm_model self.temperature = ( settings.llm_temperature if temperature is None else float(temperature) ) self.input_cost_per_1m_tokens = settings.input_cost_per_1m_tokens self.output_cost_per_1m_tokens = settings.output_cost_per_1m_tokens if self.provider == "gemini": # Prefer GEMINI_MODEL env var over the generic LLM_MODEL. self.model = model or os.environ.get("GEMINI_MODEL", settings.llm_model) # The SDK reads GEMINI_API_KEY from environment variables. self.client = genai.Client() elif self.provider == "groq": from openai import OpenAI api_key = os.environ.get("GROQ_API_KEY", "") if not api_key: raise ValueError("Missing GROQ_API_KEY environment variable.") self.model = model or os.environ.get( "GROQ_MODEL", "llama-3.3-70b-versatile" ) base_url = os.environ.get( "GROQ_BASE_URL", "https://api.groq.com/openai/v1" ) self.client = OpenAI(api_key=api_key, base_url=base_url) else: raise ValueError( f"Unsupported provider '{self.provider}'. Supported: 'gemini', 'groq'." ) def chat(self, messages: list[dict[str, str]] | str, **kwargs: Any) -> dict[str, Any]: """Send messages to the configured model and return response + metadata.""" if self.provider == "groq": return self._chat_groq(messages, **kwargs) prompt = self._messages_to_prompt(messages) if not prompt: raise ValueError("messages cannot be empty.") started = time.perf_counter() try: config = {"temperature": self.temperature} extra_config = kwargs.pop("config", None) if isinstance(extra_config, dict): config.update(extra_config) config.update(kwargs) response = self.client.models.generate_content( model=self.model, contents=prompt, config=config, ) except Exception as exc: logger.exception( "LLM call failed | provider=%s | model=%s", self.provider, self.model, ) raise RuntimeError("Failed to call the LLM provider.") from exc latency_ms = (time.perf_counter() - started) * 1000 prompt_tokens, completion_tokens, total_tokens = self._extract_usage(response) estimated_cost_usd = self._estimate_cost(prompt_tokens, completion_tokens) metrics = UsageMetrics( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens, latency_ms=latency_ms, estimated_cost_usd=estimated_cost_usd, ) self.log_usage(metrics) text = (getattr(response, "text", "") or "").strip() if not text: raise RuntimeError("The LLM returned an empty response.") return { "response": text, "metadata": { "provider": self.provider, "model": self.model, "temperature": self.temperature, "usage": { "prompt_tokens": metrics.prompt_tokens, "completion_tokens": metrics.completion_tokens, "total_tokens": metrics.total_tokens, }, "latency_ms": round(metrics.latency_ms, 2), "estimated_cost_usd": round(metrics.estimated_cost_usd, 8), }, } def log_usage(self, metrics: UsageMetrics) -> None: """Log usage metrics for every LLM call.""" logger.info( ( "llm_call | provider=%s | model=%s | prompt_tokens=%d " "| completion_tokens=%d | total_tokens=%d | latency_ms=%.2f " "| estimated_cost_usd=%.8f" ), self.provider, self.model, metrics.prompt_tokens, metrics.completion_tokens, metrics.total_tokens, metrics.latency_ms, metrics.estimated_cost_usd, ) def _estimate_cost(self, prompt_tokens: int, completion_tokens: int) -> float: input_cost = (prompt_tokens / 1_000_000) * self.input_cost_per_1m_tokens output_cost = (completion_tokens / 1_000_000) * self.output_cost_per_1m_tokens return input_cost + output_cost def _extract_usage(self, response: Any) -> tuple[int, int, int]: usage = getattr(response, "usage", None) if usage is None: usage = getattr(response, "usage_metadata", None) prompt_tokens = self._read_usage_value( usage, "prompt_tokens", "prompt_token_count", "input_tokens", "input_token_count", ) completion_tokens = self._read_usage_value( usage, "completion_tokens", "candidates_token_count", "output_tokens", "output_token_count", ) total_tokens = self._read_usage_value( usage, "total_tokens", "total_token_count", ) if total_tokens == 0: total_tokens = prompt_tokens + completion_tokens return prompt_tokens, completion_tokens, total_tokens def _read_usage_value(self, usage: Any, *fields: str) -> int: if usage is None: return 0 for field_name in fields: value = getattr(usage, field_name, None) if value is None and isinstance(usage, dict): value = usage.get(field_name) if value is None: continue try: return int(value) except (TypeError, ValueError): continue return 0 def _chat_groq(self, messages: list[dict[str, str]] | str, **kwargs: Any) -> dict[str, Any]: """Handle chat via the Groq API (OpenAI-compatible).""" if isinstance(messages, str): groq_messages = [{"role": "user", "content": messages}] elif isinstance(messages, list): groq_messages = [ {"role": m.get("role", "user"), "content": m.get("content", "")} for m in messages ] else: raise TypeError("messages must be either a string or a list of dicts.") config = kwargs.pop("config", None) or {} if isinstance(config, dict): config = dict(config) else: config = {} config.update(kwargs) max_tokens = config.pop("max_output_tokens", config.pop("max_tokens", 1024)) temperature = config.pop("temperature", self.temperature) started = time.perf_counter() try: response = self.client.chat.completions.create( model=self.model, messages=groq_messages, max_tokens=max_tokens, temperature=temperature, ) except Exception as exc: logger.exception( "LLM call failed | provider=%s | model=%s", self.provider, self.model, ) raise RuntimeError("Failed to call the LLM provider.") from exc latency_ms = (time.perf_counter() - started) * 1000 text = (response.choices[0].message.content or "").strip() prompt_tokens = getattr(response.usage, "prompt_tokens", 0) or 0 completion_tokens = getattr(response.usage, "completion_tokens", 0) or 0 total_tokens = prompt_tokens + completion_tokens estimated_cost_usd = self._estimate_cost(prompt_tokens, completion_tokens) metrics = UsageMetrics( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens, latency_ms=latency_ms, estimated_cost_usd=estimated_cost_usd, ) self.log_usage(metrics) if not text: raise RuntimeError("The LLM returned an empty response.") return { "response": text, "metadata": { "provider": self.provider, "model": self.model, "temperature": self.temperature, "usage": { "prompt_tokens": metrics.prompt_tokens, "completion_tokens": metrics.completion_tokens, "total_tokens": metrics.total_tokens, }, "latency_ms": round(metrics.latency_ms, 2), "estimated_cost_usd": round(metrics.estimated_cost_usd, 8), }, } def _messages_to_prompt(self, messages: list[dict[str, str]] | str) -> str: if isinstance(messages, str): return messages.strip() if not isinstance(messages, list): raise TypeError("messages must be either a string or a list of dictionaries.") lines: list[str] = [] for item in messages: if not isinstance(item, dict): raise TypeError("Each message must be a dictionary with role/content.") role = str(item.get("role", "user")).strip() or "user" content = str(item.get("content", "")).strip() if content: lines.append(f"{role}: {content}") return "\n".join(lines).strip()