Spaces:
Sleeping
Sleeping
| """Token counting, cost estimation, latency measurement, and budget checking. | |
| Reusable across classes 3, 5, 14-16 of the AI Engineer Bootcamp. | |
| Supports Gemini and Groq providers. | |
| """ | |
| from __future__ import annotations | |
| import os | |
| from dataclasses import dataclass, field | |
| from time import perf_counter | |
| from typing import Any | |
| from core.config import get_settings | |
| from core.logger import get_logger | |
| logger = get_logger(__name__) | |
| # --------------------------------------------------------------------------- | |
| # Data types | |
| # --------------------------------------------------------------------------- | |
@dataclass
class Pricing:
    """Token prices consumed by ``estimate_cost``.

    Both rates are expressed per 1,000 tokens: ``estimate_cost`` divides the
    token counts by 1000 before multiplying by these values.
    """

    # Price for 1,000 input (prompt) tokens.
    input_per_1k: float
    # Price for 1,000 output (completion) tokens.
    output_per_1k: float
    # Currency label for the two rates above.
    currency: str = "USD"
@dataclass
class LatencyResult:
    """Timing and token metrics collected for a single LLM call."""

    # Seconds from request start to the first non-empty streamed chunk;
    # None for non-streaming calls or when no text chunk arrived.
    ttft_s: float | None
    # Seconds elapsed for the whole call.
    total_s: float
    # Output tokens per second; None when the elapsed time is not positive.
    tps: float | None
    # Prompt token count (reported by the provider or counted locally).
    input_tokens: int | None
    # Completion token count (reported by the provider or counted locally).
    output_tokens: int | None
    # Generated text with surrounding whitespace stripped.
    output_text: str
    # Free-form extras; the measuring helpers store "cost" and "provider".
    # default_factory gives every instance its own dict.
    meta: dict = field(default_factory=dict)
class BudgetExceededError(RuntimeError):
    """Signals that a request's estimated cost is above the configured limit."""
| # --------------------------------------------------------------------------- | |
| # Provider helpers | |
| # --------------------------------------------------------------------------- | |
def _get_gemini_client():
    """Build a Gemini client and look up the configured model name.

    The SDK is imported inside the function, so importing this module does
    not require the Gemini package.

    Returns:
        A ``(genai.Client, model_name)`` tuple; the model name comes from
        ``get_settings().llm_model``.
    """
    from google import genai

    return genai.Client(), get_settings().llm_model
def _get_groq_client():
    """Build a Groq client and pick the model name from the environment.

    Returns:
        A ``(Groq client, model_name)`` tuple. The model name is read from
        ``GROQ_MODEL`` and falls back to ``"llama-3.3-70b-versatile"``.

    Raises:
        ValueError: if ``GROQ_API_KEY`` is unset or empty.
    """
    from groq import Groq

    key = os.environ.get("GROQ_API_KEY", "")
    if key:
        groq_client = Groq(api_key=key)
        return groq_client, os.environ.get("GROQ_MODEL", "llama-3.3-70b-versatile")
    raise ValueError("Missing GROQ_API_KEY environment variable.")
def _resolve_provider(provider: str | None) -> str:
    """Return the lower-cased provider name.

    A falsy *provider* (``None`` or ``""``) is replaced by
    ``get_settings().llm_provider``; settings are not read otherwise.
    """
    chosen = provider if provider else get_settings().llm_provider
    return chosen.lower()
| # --------------------------------------------------------------------------- | |
| # count_tokens | |
| # --------------------------------------------------------------------------- | |
def count_tokens(text: str, provider: str | None = None) -> int:
    """Count the tokens in *text* for the given provider.

    Args:
        text: The text to measure.
        provider: ``"gemini"`` or ``"groq"`` (case-insensitive). When falsy,
            the provider is resolved by ``_resolve_provider``.

    Returns:
        The token count. Gemini uses the client's ``count_tokens`` call.
        Groq has no such endpoint, so the text is encoded with tiktoken's
        ``cl100k_base`` encoding when tiktoken is importable; otherwise the
        count is approximated as one token per four characters.

    Raises:
        RuntimeError: if the Gemini result exposes no usable token count.
        ValueError: if the provider is not supported.
    """
    provider = _resolve_provider(provider)

    if provider == "gemini":
        client, model = _get_gemini_client()
        result = client.models.count_tokens(model=model, contents=text)
        if hasattr(result, "total_tokens"):
            return int(result.total_tokens)
        try:
            return int(result)
        except (TypeError, ValueError) as exc:
            # Keep the conversion failure attached as the cause for debugging.
            raise RuntimeError(
                f"Cannot extract token count from {result!r}"
            ) from exc

    if provider == "groq":
        # Groq has no count_tokens endpoint.
        # Use tiktoken if available; otherwise approximate 4 chars/token.
        try:
            import tiktoken

            enc = tiktoken.get_encoding("cl100k_base")
            return len(enc.encode(text))
        except ImportError:
            # Non-empty text counts as at least one token; empty text is zero,
            # which matches what the tiktoken path returns for "".
            return max(1, len(text) // 4) if text else 0

    raise ValueError(f"Unsupported provider: {provider}")
| # --------------------------------------------------------------------------- | |
| # estimate_cost | |
| # --------------------------------------------------------------------------- | |
def estimate_cost(
    input_tokens: int,
    output_tokens: int,
    pricing: Pricing | None = None,
) -> float:
    """Compute the cost of a call from its token counts.

    Args:
        input_tokens: Number of prompt tokens.
        output_tokens: Number of completion tokens.
        pricing: Per-1,000-token rates; ``None`` means no pricing is known.

    Returns:
        The sum of the input and output costs, or ``0.0`` without pricing.
    """
    if pricing is not None:
        prompt_cost = (input_tokens / 1000) * pricing.input_per_1k
        completion_cost = (output_tokens / 1000) * pricing.output_per_1k
        return prompt_cost + completion_cost
    return 0.0
| # --------------------------------------------------------------------------- | |
| # measure_latency | |
| # --------------------------------------------------------------------------- | |
def measure_latency(
    prompt: str,
    *,
    stream: bool = False,
    generation_config: dict | None = None,
    pricing: Pricing | None = None,
    provider: str | None = None,
) -> LatencyResult:
    """Time one LLM call and return its metrics.

    The work is delegated to the provider-specific helper; with
    ``stream=True`` the result also carries the time to first token.

    Raises:
        ValueError: if the resolved provider is neither Gemini nor Groq.
    """
    provider = _resolve_provider(provider)
    if provider not in ("gemini", "groq"):
        raise ValueError(f"Unsupported provider: {provider}")

    measure = _measure_gemini if provider == "gemini" else _measure_groq
    return measure(
        prompt,
        stream=stream,
        generation_config=generation_config,
        pricing=pricing,
    )
| # -- Gemini ---------------------------------------------------------------- | |
def _measure_gemini(
    prompt: str,
    *,
    stream: bool,
    generation_config: dict | None,
    pricing: Pricing | None,
) -> LatencyResult:
    """Run one Gemini generation and measure it.

    Args:
        prompt: Text sent as the request contents.
        stream: When true, use the streaming endpoint and record the time to
            the first non-empty text chunk.
        generation_config: Passed to the SDK as ``config`` (copied first).
        pricing: Rates used to fill ``meta["cost"]``; may be ``None``.

    Returns:
        A ``LatencyResult``. Token counts come from the response's usage
        metadata when present and non-zero; otherwise they are obtained with
        ``count_tokens``. In streaming mode ``tps`` is measured from the
        first text chunk to the end of the stream.

    Raises:
        RuntimeError: if starting the stream raises ``AttributeError``.
    """
    client, model = _get_gemini_client()
    config = dict(generation_config or {})

    if not stream:
        t0 = perf_counter()
        response = client.models.generate_content(
            model=model, contents=prompt, config=config,
        )
        total_s = perf_counter() - t0

        output_text = (getattr(response, "text", "") or "").strip()
        input_tokens, output_tokens = _gemini_tokens(response)
        # Missing or zero usage figures are replaced by an explicit count.
        if not input_tokens:
            input_tokens = count_tokens(prompt, provider="gemini")
        if not output_tokens:
            output_tokens = (
                count_tokens(output_text, provider="gemini") if output_text else 0
            )

        tps = output_tokens / total_s if total_s > 0 else None
        cost = estimate_cost(input_tokens, output_tokens, pricing)
        return LatencyResult(
            ttft_s=None, total_s=total_s, tps=tps,
            input_tokens=input_tokens, output_tokens=output_tokens,
            output_text=output_text,
            meta={"cost": cost, "provider": "gemini"},
        )

    # --- streaming ---
    t0 = perf_counter()
    t_first: float | None = None
    pieces: list[str] = []
    try:
        stream_iter = client.models.generate_content_stream(
            model=model, contents=prompt, config=config,
        )
    except AttributeError as exc:
        # Chain the original error so the traceback shows what was missing.
        raise RuntimeError(
            "Streaming no soportado por el SDK actual de Gemini."
        ) from exc

    usage_meta = None
    for chunk in stream_iter:
        chunk_text = getattr(chunk, "text", "") or ""
        if chunk_text:
            if t_first is None:
                t_first = perf_counter()
            pieces.append(chunk_text)
        um = getattr(chunk, "usage_metadata", None)
        if um:
            # Keep the most recent usage metadata seen on the stream.
            usage_meta = um
    t_end = perf_counter()

    total_s = t_end - t0
    ttft_s = (t_first - t0) if t_first is not None else None
    output_text = "".join(pieces).strip()

    input_tokens = None
    output_tokens = None
    if usage_meta:
        input_tokens = _safe_int(getattr(usage_meta, "prompt_token_count", None))
        output_tokens = _safe_int(getattr(usage_meta, "candidates_token_count", None))
    if not input_tokens:
        input_tokens = count_tokens(prompt, provider="gemini")
    if not output_tokens:
        output_tokens = (
            count_tokens(output_text, provider="gemini") if output_text else 0
        )

    tps = None
    if t_first is not None and (t_end - t_first) > 0:
        tps = output_tokens / (t_end - t_first)

    cost = estimate_cost(input_tokens, output_tokens, pricing)
    return LatencyResult(
        ttft_s=ttft_s, total_s=total_s, tps=tps,
        input_tokens=input_tokens, output_tokens=output_tokens,
        output_text=output_text,
        meta={"cost": cost, "provider": "gemini"},
    )
def _gemini_tokens(response: Any) -> tuple[int | None, int | None]:
    """Read ``(prompt, candidates)`` token counts from a Gemini response.

    Returns ``(None, None)`` when the response has no ``usage_metadata``;
    each individual count is ``None`` when it is missing or not convertible
    to ``int``.
    """
    usage = getattr(response, "usage_metadata", None)
    if usage is not None:
        prompt_count = _safe_int(getattr(usage, "prompt_token_count", None))
        reply_count = _safe_int(getattr(usage, "candidates_token_count", None))
        return prompt_count, reply_count
    return None, None
| # -- Groq ----------------------------------------------------------------- | |
def _measure_groq(
    prompt: str,
    *,
    stream: bool,
    generation_config: dict | None,
    pricing: Pricing | None,
) -> LatencyResult:
    """Run one Groq chat completion and measure it.

    Args:
        prompt: Sent as a single ``user`` message.
        stream: When true, stream the completion and record the time to the
            first non-empty content delta.
        generation_config: Only ``temperature`` and ``max_output_tokens``
            (or ``max_tokens``; default 1024) are forwarded; any other key
            is ignored.
        pricing: Rates used to fill ``meta["cost"]``; may be ``None``.

    Returns:
        A ``LatencyResult``. Token counts come from the provider's usage data
        when present and non-zero; otherwise they are obtained with
        ``count_tokens``. In streaming mode ``tps`` is measured from the
        first content delta to the end of the stream.
    """
    client, model = _get_groq_client()
    gen = dict(generation_config or {})

    # Map Gemini-style keys to Groq/OpenAI keys.
    temperature = gen.pop("temperature", None)
    max_tokens = gen.pop("max_output_tokens", gen.pop("max_tokens", 1024))

    messages = [{"role": "user", "content": prompt}]
    kwargs: dict[str, Any] = {
        "model": model, "messages": messages, "max_tokens": max_tokens,
    }
    if temperature is not None:
        kwargs["temperature"] = temperature

    if not stream:
        t0 = perf_counter()
        response = client.chat.completions.create(**kwargs)
        total_s = perf_counter() - t0

        output_text = (response.choices[0].message.content or "").strip()
        input_tokens = getattr(response.usage, "prompt_tokens", None) if response.usage else None
        output_tokens = getattr(response.usage, "completion_tokens", None) if response.usage else None
        # Missing or zero usage figures are replaced by a local count.
        if not input_tokens:
            input_tokens = count_tokens(prompt, provider="groq")
        if not output_tokens:
            output_tokens = count_tokens(output_text, provider="groq") if output_text else 0

        tps = output_tokens / total_s if total_s > 0 else None
        cost = estimate_cost(input_tokens, output_tokens, pricing)
        return LatencyResult(
            ttft_s=None, total_s=total_s, tps=tps,
            input_tokens=input_tokens, output_tokens=output_tokens,
            output_text=output_text,
            meta={"cost": cost, "provider": "groq"},
        )

    # --- streaming ---
    kwargs["stream"] = True
    t0 = perf_counter()
    t_first: float | None = None
    pieces: list[str] = []
    usage_data = None

    stream_resp = client.chat.completions.create(**kwargs)
    for chunk in stream_resp:
        # Handle text first so the first-token timestamp is taken as early
        # as possible; chunks without choices may carry usage only.
        if chunk.choices:
            chunk_text = getattr(chunk.choices[0].delta, "content", "") or ""
            if chunk_text:
                if t_first is None:
                    t_first = perf_counter()
                pieces.append(chunk_text)

        usage = getattr(chunk, "usage", None)
        if not usage:
            # Groq can also attach streaming usage under ``x_groq.usage``.
            usage = getattr(getattr(chunk, "x_groq", None), "usage", None)
        if usage:
            usage_data = usage
    t_end = perf_counter()

    total_s = t_end - t0
    ttft_s = (t_first - t0) if t_first is not None else None
    output_text = "".join(pieces).strip()

    input_tokens = getattr(usage_data, "prompt_tokens", None) if usage_data else None
    output_tokens = getattr(usage_data, "completion_tokens", None) if usage_data else None
    if not input_tokens:
        input_tokens = count_tokens(prompt, provider="groq")
    if not output_tokens:
        output_tokens = count_tokens(output_text, provider="groq") if output_text else 0

    tps = None
    if t_first is not None and (t_end - t_first) > 0:
        tps = output_tokens / (t_end - t_first)

    cost = estimate_cost(input_tokens, output_tokens, pricing)
    return LatencyResult(
        ttft_s=ttft_s, total_s=total_s, tps=tps,
        input_tokens=input_tokens, output_tokens=output_tokens,
        output_text=output_text,
        meta={"cost": cost, "provider": "groq"},
    )
| # --------------------------------------------------------------------------- | |
| # BudgetChecker | |
| # --------------------------------------------------------------------------- | |
class BudgetChecker:
    """Checks the estimated cost of a request against a spending limit."""

    def __init__(
        self,
        max_cost_usd: float,
        pricing: Pricing,
        *,
        strict: bool = True,
    ) -> None:
        """Store the limit, the pricing used for estimates, and the mode.

        Args:
            max_cost_usd: Highest cost accepted for a single check.
            pricing: Rates handed to ``estimate_cost``.
            strict: When true, ``check`` raises on an over-budget request
                instead of reporting it in the returned dict.
        """
        self.strict = strict
        self.pricing = pricing
        self.max_cost_usd = max_cost_usd

    def check(self, input_tokens: int, output_tokens: int) -> dict:
        """Estimate the cost of the given token counts and compare it.

        Returns:
            A dict with ``"ok"`` (whether the cost is within the limit),
            ``"cost"`` and ``"max"``.

        Raises:
            BudgetExceededError: in strict mode, when the cost is above the
                limit.
        """
        cost = estimate_cost(input_tokens, output_tokens, self.pricing)
        within_budget = not (cost > self.max_cost_usd)
        if self.strict and not within_budget:
            raise BudgetExceededError(
                f"Presupuesto excedido: ${cost:.6f} > límite ${self.max_cost_usd:.6f}"
            )
        return {"ok": within_budget, "cost": cost, "max": self.max_cost_usd}
| # --------------------------------------------------------------------------- | |
| # Streaming generators (for UI real-time display) | |
| # --------------------------------------------------------------------------- | |
def stream_chunks(
    prompt: str,
    *,
    generation_config: dict | None = None,
    provider: str | None = None,
    _metrics_out: dict | None = None,
):
    """Generate the text chunks of a streaming LLM call as they arrive.

    When *_metrics_out* is a dict, the provider-specific generator fills it
    **in place** with timing and token figures after the stream has been
    fully consumed. A UI (e.g. Streamlit) can therefore render chunks live
    and read TTFT / total_s / tps afterwards.

    Raises:
        ValueError: on first iteration, if the provider is not supported.
    """
    provider = _resolve_provider(provider)
    if provider == "gemini":
        chunks = _stream_chunks_gemini(prompt, generation_config, _metrics_out)
    elif provider == "groq":
        chunks = _stream_chunks_groq(prompt, generation_config, _metrics_out)
    else:
        raise ValueError(f"Unsupported provider: {provider}")
    yield from chunks
def _stream_chunks_gemini(prompt, generation_config, metrics_out):
    """Yield non-empty text chunks from a Gemini streaming call.

    Args:
        prompt: Text sent as the request contents.
        generation_config: Passed to the SDK as ``config`` (copied first).
        metrics_out: Optional dict updated in place, once the stream is
            exhausted, with ``ttft_s``, ``total_s``, ``tps``,
            ``input_tokens`` and ``output_tokens``. Nothing is written if
            the consumer stops iterating before the stream ends.
    """
    client, model = _get_gemini_client()
    config = dict(generation_config or {})

    t0 = perf_counter()
    t_first = None
    pieces = []
    usage_meta = None

    stream_iter = client.models.generate_content_stream(
        model=model, contents=prompt, config=config,
    )
    for chunk in stream_iter:
        chunk_text = getattr(chunk, "text", "") or ""
        if chunk_text:
            if t_first is None:
                t_first = perf_counter()
            pieces.append(chunk_text)
            yield chunk_text
        um = getattr(chunk, "usage_metadata", None)
        if um:
            # Keep the most recent usage metadata seen on the stream.
            usage_meta = um
    t_end = perf_counter()

    if metrics_out is None:
        return

    in_tok = _safe_int(getattr(usage_meta, "prompt_token_count", None)) if usage_meta else None
    out_tok = _safe_int(getattr(usage_meta, "candidates_token_count", None)) if usage_meta else None
    # Missing or zero usage figures are replaced by an explicit count.
    if not in_tok:
        in_tok = count_tokens(prompt, provider="gemini")
    if not out_tok:
        stripped = "".join(pieces).strip()
        out_tok = count_tokens(stripped, provider="gemini") if stripped else 0

    tps = None
    if t_first is not None and (t_end - t_first) > 0:
        tps = out_tok / (t_end - t_first)

    metrics_out.update(
        # Compare against None: a timestamp is a float and must not be
        # judged by truthiness.
        ttft_s=(t_first - t0) if t_first is not None else None,
        total_s=t_end - t0,
        tps=tps,
        input_tokens=in_tok,
        output_tokens=out_tok,
    )
def _stream_chunks_groq(prompt, generation_config, metrics_out):
    """Yield non-empty content deltas from a Groq streaming completion.

    Args:
        prompt: Sent as a single ``user`` message.
        generation_config: Only ``temperature`` and ``max_output_tokens``
            (or ``max_tokens``; default 1024) are forwarded; any other key
            is ignored.
        metrics_out: Optional dict updated in place, once the stream is
            exhausted, with ``ttft_s``, ``total_s``, ``tps``,
            ``input_tokens`` and ``output_tokens``. Nothing is written if
            the consumer stops iterating before the stream ends.
    """
    client, model = _get_groq_client()
    gen = dict(generation_config or {})

    # Map Gemini-style keys to Groq/OpenAI keys.
    temperature = gen.pop("temperature", None)
    max_tokens = gen.pop("max_output_tokens", gen.pop("max_tokens", 1024))

    messages = [{"role": "user", "content": prompt}]
    kwargs: dict[str, Any] = {
        "model": model, "messages": messages,
        "max_tokens": max_tokens, "stream": True,
    }
    if temperature is not None:
        kwargs["temperature"] = temperature

    t0 = perf_counter()
    t_first = None
    pieces = []
    usage_data = None

    stream_resp = client.chat.completions.create(**kwargs)
    for chunk in stream_resp:
        # Chunks without choices may carry usage only, so text handling is
        # conditional while the usage lookup below runs for every chunk.
        if chunk.choices:
            chunk_text = getattr(chunk.choices[0].delta, "content", "") or ""
            if chunk_text:
                if t_first is None:
                    t_first = perf_counter()
                pieces.append(chunk_text)
                yield chunk_text

        usage = getattr(chunk, "usage", None)
        if not usage:
            # Groq can also attach streaming usage under ``x_groq.usage``.
            usage = getattr(getattr(chunk, "x_groq", None), "usage", None)
        if usage:
            usage_data = usage
    t_end = perf_counter()

    if metrics_out is None:
        return

    in_tok = getattr(usage_data, "prompt_tokens", None) if usage_data else None
    out_tok = getattr(usage_data, "completion_tokens", None) if usage_data else None
    # Missing or zero usage figures are replaced by a local count.
    if not in_tok:
        in_tok = count_tokens(prompt, provider="groq")
    if not out_tok:
        stripped = "".join(pieces).strip()
        out_tok = count_tokens(stripped, provider="groq") if stripped else 0

    tps = None
    if t_first is not None and (t_end - t_first) > 0:
        tps = out_tok / (t_end - t_first)

    metrics_out.update(
        # Compare against None: a timestamp is a float and must not be
        # judged by truthiness.
        ttft_s=(t_first - t0) if t_first is not None else None,
        total_s=t_end - t0,
        tps=tps,
        input_tokens=in_tok,
        output_tokens=out_tok,
    )
| # --------------------------------------------------------------------------- | |
| # Helpers | |
| # --------------------------------------------------------------------------- | |
def _safe_int(value: Any) -> int | None:
    """Convert *value* to ``int``, or return ``None`` when that is not possible.

    ``None`` is passed through; a ``TypeError`` or ``ValueError`` raised by
    the conversion also yields ``None``.
    """
    try:
        converted = int(value) if value is not None else None
    except (TypeError, ValueError):
        converted = None
    return converted