File size: 10,447 Bytes
d0d2f42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
"""Provider-agnostic LLM client abstraction (Gemini & Groq via OpenAI API)."""

from __future__ import annotations

import os
from dataclasses import dataclass
import time
from typing import Any

from google import genai

from core.config import get_settings
from core.logger import get_logger


logger = get_logger(__name__)


@dataclass(frozen=True)
class UsageMetrics:
    """Immutable record of token usage, latency and estimated cost for one LLM call."""
    # Token count attributed to the request (prompt/input side).
    prompt_tokens: int
    # Token count attributed to the generated answer (completion/output side).
    completion_tokens: int
    # Provider-reported total, or prompt + completion when callers compute it.
    total_tokens: int
    # Wall-clock duration of the provider call, in milliseconds.
    latency_ms: float
    # Cost estimate derived from the configured per-1M-token rates.
    estimated_cost_usd: float


class LLMClient:
    """Simple LLM client with `chat` and usage logging."""

    def __init__(
        self,
        provider: str | None = None,
        model: str | None = None,
        temperature: float | None = None,
    ) -> None:
        """Resolve provider/model/temperature and build the provider SDK client.

        Args:
            provider: Provider name, compared case-insensitively. Falls back
                to ``settings.llm_provider`` when omitted or empty.
            model: Model identifier. When omitted, the provider-specific
                environment variable (``GEMINI_MODEL`` / ``GROQ_MODEL``) is
                consulted before the provider's default.
            temperature: Sampling temperature. ``settings.llm_temperature``
                is used when ``None``; any other value is coerced to float.

        Raises:
            ValueError: If the provider is neither ``gemini`` nor ``groq``,
                or if ``GROQ_API_KEY`` is missing/empty for the Groq provider.
        """
        settings = get_settings()
        self.provider = (provider or settings.llm_provider).lower()
        self.model = model or settings.llm_model
        self.temperature = (
            settings.llm_temperature if temperature is None else float(temperature)
        )
        self.input_cost_per_1m_tokens = settings.input_cost_per_1m_tokens
        self.output_cost_per_1m_tokens = settings.output_cost_per_1m_tokens

        if self.provider == "gemini":
            # Prefer GEMINI_MODEL env var over the generic LLM_MODEL.
            # `or` (instead of a .get default) makes an empty env var behave
            # like an unset one, matching how GROQ_API_KEY is treated below.
            self.model = (
                model or os.environ.get("GEMINI_MODEL") or settings.llm_model
            )
            # The SDK reads GEMINI_API_KEY from environment variables.
            self.client = genai.Client()
        elif self.provider == "groq":
            # Imported lazily so the openai package is only needed for Groq.
            from openai import OpenAI

            api_key = os.environ.get("GROQ_API_KEY", "")
            if not api_key:
                raise ValueError("Missing GROQ_API_KEY environment variable.")
            # Empty GROQ_MODEL / GROQ_BASE_URL values fall back to the
            # defaults rather than producing an empty model name or URL.
            self.model = (
                model or os.environ.get("GROQ_MODEL") or "llama-3.3-70b-versatile"
            )
            base_url = (
                os.environ.get("GROQ_BASE_URL") or "https://api.groq.com/openai/v1"
            )
            self.client = OpenAI(api_key=api_key, base_url=base_url)
        else:
            raise ValueError(
                f"Unsupported provider '{self.provider}'. Supported: 'gemini', 'groq'."
            )

    def chat(self, messages: list[dict[str, str]] | str, **kwargs: Any) -> dict[str, Any]:
        """Send messages to the configured model and return response + metadata.

        Groq requests are delegated to ``_chat_groq``. For any other provider
        the messages are flattened into a single prompt string and sent via
        ``self.client.models.generate_content``.

        Args:
            messages: A prompt string, or a list of ``{"role", "content"}``
                dictionaries.
            **kwargs: Generation options. A dict passed as ``config`` is
                merged over the default temperature; the remaining keyword
                arguments are merged last and therefore win. A non-dict
                ``config`` value is ignored.

        Returns:
            A dict with the stripped text under ``"response"`` and provider,
            model, temperature, token usage, latency and estimated cost under
            ``"metadata"``.

        Raises:
            TypeError: If ``messages`` has an unsupported shape.
            ValueError: If the flattened prompt is empty.
            RuntimeError: If the provider call fails or returns no text.
        """
        if self.provider == "groq":
            return self._chat_groq(messages, **kwargs)

        prompt = self._messages_to_prompt(messages)
        if not prompt:
            raise ValueError("messages cannot be empty.")

        # Build the request config outside the try block so that only the
        # provider call itself is reported as a provider failure.
        config: dict[str, Any] = {"temperature": self.temperature}
        extra_config = kwargs.pop("config", None)
        if isinstance(extra_config, dict):
            config.update(extra_config)
        config.update(kwargs)
        # The temperature actually sent may differ from self.temperature when
        # the caller overrides it; metadata must report the value used.
        effective_temperature = config["temperature"]

        started = time.perf_counter()
        try:
            response = self.client.models.generate_content(
                model=self.model,
                contents=prompt,
                config=config,
            )
        except Exception as exc:
            logger.exception(
                "LLM call failed | provider=%s | model=%s",
                self.provider,
                self.model,
            )
            raise RuntimeError("Failed to call the LLM provider.") from exc

        latency_ms = (time.perf_counter() - started) * 1000
        prompt_tokens, completion_tokens, total_tokens = self._extract_usage(response)
        estimated_cost_usd = self._estimate_cost(prompt_tokens, completion_tokens)

        metrics = UsageMetrics(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
            latency_ms=latency_ms,
            estimated_cost_usd=estimated_cost_usd,
        )
        # Usage is logged before the empty-response check so that calls which
        # produced no text are still accounted for.
        self.log_usage(metrics)

        text = (getattr(response, "text", "") or "").strip()
        if not text:
            raise RuntimeError("The LLM returned an empty response.")

        return {
            "response": text,
            "metadata": {
                "provider": self.provider,
                "model": self.model,
                "temperature": effective_temperature,
                "usage": {
                    "prompt_tokens": metrics.prompt_tokens,
                    "completion_tokens": metrics.completion_tokens,
                    "total_tokens": metrics.total_tokens,
                },
                "latency_ms": round(metrics.latency_ms, 2),
                "estimated_cost_usd": round(metrics.estimated_cost_usd, 8),
            },
        }

    def log_usage(self, metrics: UsageMetrics) -> None:
        """Write one INFO log line describing the usage of a single LLM call.

        Args:
            metrics: Token counts, latency and estimated cost to report.
        """
        line_format = (
            "llm_call | provider=%s | model=%s"
            " | prompt_tokens=%d | completion_tokens=%d | total_tokens=%d"
            " | latency_ms=%.2f | estimated_cost_usd=%.8f"
        )
        # Values are passed as lazy %-args so formatting only happens when
        # the record is actually emitted.
        values = (
            self.provider,
            self.model,
            metrics.prompt_tokens,
            metrics.completion_tokens,
            metrics.total_tokens,
            metrics.latency_ms,
            metrics.estimated_cost_usd,
        )
        logger.info(line_format, *values)

    def _estimate_cost(self, prompt_tokens: int, completion_tokens: int) -> float:
        input_cost = (prompt_tokens / 1_000_000) * self.input_cost_per_1m_tokens
        output_cost = (completion_tokens / 1_000_000) * self.output_cost_per_1m_tokens
        return input_cost + output_cost

    def _extract_usage(self, response: Any) -> tuple[int, int, int]:
        usage = getattr(response, "usage", None)
        if usage is None:
            usage = getattr(response, "usage_metadata", None)

        prompt_tokens = self._read_usage_value(
            usage,
            "prompt_tokens",
            "prompt_token_count",
            "input_tokens",
            "input_token_count",
        )
        completion_tokens = self._read_usage_value(
            usage,
            "completion_tokens",
            "candidates_token_count",
            "output_tokens",
            "output_token_count",
        )
        total_tokens = self._read_usage_value(
            usage,
            "total_tokens",
            "total_token_count",
        )

        if total_tokens == 0:
            total_tokens = prompt_tokens + completion_tokens

        return prompt_tokens, completion_tokens, total_tokens

    def _read_usage_value(self, usage: Any, *fields: str) -> int:
        if usage is None:
            return 0

        for field_name in fields:
            value = getattr(usage, field_name, None)
            if value is None and isinstance(usage, dict):
                value = usage.get(field_name)
            if value is None:
                continue
            try:
                return int(value)
            except (TypeError, ValueError):
                continue
        return 0

    def _chat_groq(self, messages: list[dict[str, str]] | str, **kwargs: Any) -> dict[str, Any]:
        """Handle chat via the Groq API (OpenAI-compatible).

        Args:
            messages: A prompt string (sent as a single user message) or a
                list of ``{"role", "content"}`` dictionaries.
            **kwargs: Generation options. A dict passed as ``config`` is
                copied and the remaining keyword arguments are merged over
                it. Only ``max_output_tokens`` / ``max_tokens`` (default
                1024) and ``temperature`` are forwarded to the API; any
                other option is not sent.

        Returns:
            A dict with the stripped text under ``"response"`` and provider,
            model, temperature, token usage, latency and estimated cost under
            ``"metadata"``.

        Raises:
            TypeError: If ``messages`` is neither a string nor a list, or a
                list item is not a dictionary.
            RuntimeError: If the provider call fails or returns no text.
        """
        if isinstance(messages, str):
            groq_messages = [{"role": "user", "content": messages}]
        elif isinstance(messages, list):
            groq_messages = []
            for item in messages:
                # Same validation and message as _messages_to_prompt, instead
                # of an AttributeError from calling .get on a non-dict.
                if not isinstance(item, dict):
                    raise TypeError("Each message must be a dictionary with role/content.")
                groq_messages.append(
                    {"role": item.get("role", "user"), "content": item.get("content", "")}
                )
        else:
            raise TypeError("messages must be either a string or a list of dicts.")

        extra_config = kwargs.pop("config", None)
        # Copy so the caller's dict is never mutated by the pops below.
        config = dict(extra_config) if isinstance(extra_config, dict) else {}
        config.update(kwargs)

        # Both keys are always removed; max_output_tokens takes precedence.
        fallback_max_tokens = config.pop("max_tokens", 1024)
        max_tokens = config.pop("max_output_tokens", fallback_max_tokens)
        temperature = config.pop("temperature", self.temperature)

        started = time.perf_counter()
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=groq_messages,
                max_tokens=max_tokens,
                temperature=temperature,
            )
        except Exception as exc:
            logger.exception(
                "LLM call failed | provider=%s | model=%s",
                self.provider,
                self.model,
            )
            raise RuntimeError("Failed to call the LLM provider.") from exc

        latency_ms = (time.perf_counter() - started) * 1000

        # Shared helper: honours a provider-reported total and tolerates a
        # missing usage object, consistent with the non-Groq path.
        prompt_tokens, completion_tokens, total_tokens = self._extract_usage(response)
        estimated_cost_usd = self._estimate_cost(prompt_tokens, completion_tokens)

        metrics = UsageMetrics(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
            latency_ms=latency_ms,
            estimated_cost_usd=estimated_cost_usd,
        )
        self.log_usage(metrics)

        # An empty choices list is treated as an empty response rather than
        # surfacing as a bare IndexError.
        choices = getattr(response, "choices", None) or []
        content = choices[0].message.content if choices else None
        text = (content or "").strip()
        if not text:
            raise RuntimeError("The LLM returned an empty response.")

        return {
            "response": text,
            "metadata": {
                "provider": self.provider,
                "model": self.model,
                # Report the temperature actually sent, which may be a
                # per-call override of self.temperature.
                "temperature": temperature,
                "usage": {
                    "prompt_tokens": metrics.prompt_tokens,
                    "completion_tokens": metrics.completion_tokens,
                    "total_tokens": metrics.total_tokens,
                },
                "latency_ms": round(metrics.latency_ms, 2),
                "estimated_cost_usd": round(metrics.estimated_cost_usd, 8),
            },
        }

    def _messages_to_prompt(self, messages: list[dict[str, str]] | str) -> str:
        if isinstance(messages, str):
            return messages.strip()

        if not isinstance(messages, list):
            raise TypeError("messages must be either a string or a list of dictionaries.")

        lines: list[str] = []
        for item in messages:
            if not isinstance(item, dict):
                raise TypeError("Each message must be a dictionary with role/content.")
            role = str(item.get("role", "user")).strip() or "user"
            content = str(item.get("content", "")).strip()
            if content:
                lines.append(f"{role}: {content}")

        return "\n".join(lines).strip()