from __future__ import annotations from dataclasses import dataclass from datetime import datetime, timezone from decimal import Decimal from typing import Any, Dict, Literal, Optional from agent.model_metadata import fetch_endpoint_model_metadata, fetch_model_metadata DEFAULT_PRICING = {"input": 0.0, "output": 0.0} _ZERO = Decimal("0") _ONE_MILLION = Decimal("1000000") CostStatus = Literal["actual", "estimated", "included", "unknown"] CostSource = Literal[ "provider_cost_api", "provider_generation_api", "provider_models_api", "official_docs_snapshot", "user_override", "custom_contract", "none", ] @dataclass(frozen=True) class CanonicalUsage: input_tokens: int = 0 output_tokens: int = 0 cache_read_tokens: int = 0 cache_write_tokens: int = 0 reasoning_tokens: int = 0 request_count: int = 1 raw_usage: Optional[dict[str, Any]] = None @property def prompt_tokens(self) -> int: return self.input_tokens + self.cache_read_tokens + self.cache_write_tokens @property def total_tokens(self) -> int: return self.prompt_tokens + self.output_tokens @dataclass(frozen=True) class BillingRoute: provider: str model: str base_url: str = "" billing_mode: str = "unknown" @dataclass(frozen=True) class PricingEntry: input_cost_per_million: Optional[Decimal] = None output_cost_per_million: Optional[Decimal] = None cache_read_cost_per_million: Optional[Decimal] = None cache_write_cost_per_million: Optional[Decimal] = None request_cost: Optional[Decimal] = None source: CostSource = "none" source_url: Optional[str] = None pricing_version: Optional[str] = None fetched_at: Optional[datetime] = None @dataclass(frozen=True) class CostResult: amount_usd: Optional[Decimal] status: CostStatus source: CostSource label: str fetched_at: Optional[datetime] = None pricing_version: Optional[str] = None notes: tuple[str, ...] = () _UTC_NOW = lambda: datetime.now(timezone.utc) # Official docs snapshot entries. Models whose published pricing and cache # semantics are stable enough to encode exactly. _OFFICIAL_DOCS_PRICING: Dict[tuple[str, str], PricingEntry] = { ( "anthropic", "claude-opus-4-20250514", ): PricingEntry( input_cost_per_million=Decimal("15.00"), output_cost_per_million=Decimal("75.00"), cache_read_cost_per_million=Decimal("1.50"), cache_write_cost_per_million=Decimal("18.75"), source="official_docs_snapshot", source_url="https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching", pricing_version="anthropic-prompt-caching-2026-03-16", ), ( "anthropic", "claude-sonnet-4-20250514", ): PricingEntry( input_cost_per_million=Decimal("3.00"), output_cost_per_million=Decimal("15.00"), cache_read_cost_per_million=Decimal("0.30"), cache_write_cost_per_million=Decimal("3.75"), source="official_docs_snapshot", source_url="https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching", pricing_version="anthropic-prompt-caching-2026-03-16", ), # OpenAI ( "openai", "gpt-4o", ): PricingEntry( input_cost_per_million=Decimal("2.50"), output_cost_per_million=Decimal("10.00"), cache_read_cost_per_million=Decimal("1.25"), source="official_docs_snapshot", source_url="https://openai.com/api/pricing/", pricing_version="openai-pricing-2026-03-16", ), ( "openai", "gpt-4o-mini", ): PricingEntry( input_cost_per_million=Decimal("0.15"), output_cost_per_million=Decimal("0.60"), cache_read_cost_per_million=Decimal("0.075"), source="official_docs_snapshot", source_url="https://openai.com/api/pricing/", pricing_version="openai-pricing-2026-03-16", ), ( "openai", "gpt-4.1", ): PricingEntry( input_cost_per_million=Decimal("2.00"), output_cost_per_million=Decimal("8.00"), cache_read_cost_per_million=Decimal("0.50"), source="official_docs_snapshot", source_url="https://openai.com/api/pricing/", pricing_version="openai-pricing-2026-03-16", ), ( "openai", "gpt-4.1-mini", ): PricingEntry( input_cost_per_million=Decimal("0.40"), output_cost_per_million=Decimal("1.60"), cache_read_cost_per_million=Decimal("0.10"), source="official_docs_snapshot", source_url="https://openai.com/api/pricing/", pricing_version="openai-pricing-2026-03-16", ), ( "openai", "gpt-4.1-nano", ): PricingEntry( input_cost_per_million=Decimal("0.10"), output_cost_per_million=Decimal("0.40"), cache_read_cost_per_million=Decimal("0.025"), source="official_docs_snapshot", source_url="https://openai.com/api/pricing/", pricing_version="openai-pricing-2026-03-16", ), ( "openai", "o3", ): PricingEntry( input_cost_per_million=Decimal("10.00"), output_cost_per_million=Decimal("40.00"), cache_read_cost_per_million=Decimal("2.50"), source="official_docs_snapshot", source_url="https://openai.com/api/pricing/", pricing_version="openai-pricing-2026-03-16", ), ( "openai", "o3-mini", ): PricingEntry( input_cost_per_million=Decimal("1.10"), output_cost_per_million=Decimal("4.40"), cache_read_cost_per_million=Decimal("0.55"), source="official_docs_snapshot", source_url="https://openai.com/api/pricing/", pricing_version="openai-pricing-2026-03-16", ), # Anthropic older models (pre-4.6 generation) ( "anthropic", "claude-3-5-sonnet-20241022", ): PricingEntry( input_cost_per_million=Decimal("3.00"), output_cost_per_million=Decimal("15.00"), cache_read_cost_per_million=Decimal("0.30"), cache_write_cost_per_million=Decimal("3.75"), source="official_docs_snapshot", source_url="https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching", pricing_version="anthropic-pricing-2026-03-16", ), ( "anthropic", "claude-3-5-haiku-20241022", ): PricingEntry( input_cost_per_million=Decimal("0.80"), output_cost_per_million=Decimal("4.00"), cache_read_cost_per_million=Decimal("0.08"), cache_write_cost_per_million=Decimal("1.00"), source="official_docs_snapshot", source_url="https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching", pricing_version="anthropic-pricing-2026-03-16", ), ( "anthropic", "claude-3-opus-20240229", ): PricingEntry( input_cost_per_million=Decimal("15.00"), output_cost_per_million=Decimal("75.00"), cache_read_cost_per_million=Decimal("1.50"), cache_write_cost_per_million=Decimal("18.75"), source="official_docs_snapshot", source_url="https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching", pricing_version="anthropic-pricing-2026-03-16", ), ( "anthropic", "claude-3-haiku-20240307", ): PricingEntry( input_cost_per_million=Decimal("0.25"), output_cost_per_million=Decimal("1.25"), cache_read_cost_per_million=Decimal("0.03"), cache_write_cost_per_million=Decimal("0.30"), source="official_docs_snapshot", source_url="https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching", pricing_version="anthropic-pricing-2026-03-16", ), # DeepSeek ( "deepseek", "deepseek-chat", ): PricingEntry( input_cost_per_million=Decimal("0.14"), output_cost_per_million=Decimal("0.28"), source="official_docs_snapshot", source_url="https://api-docs.deepseek.com/quick_start/pricing", pricing_version="deepseek-pricing-2026-03-16", ), ( "deepseek", "deepseek-reasoner", ): PricingEntry( input_cost_per_million=Decimal("0.55"), output_cost_per_million=Decimal("2.19"), source="official_docs_snapshot", source_url="https://api-docs.deepseek.com/quick_start/pricing", pricing_version="deepseek-pricing-2026-03-16", ), # Google Gemini ( "google", "gemini-2.5-pro", ): PricingEntry( input_cost_per_million=Decimal("1.25"), output_cost_per_million=Decimal("10.00"), source="official_docs_snapshot", source_url="https://ai.google.dev/pricing", pricing_version="google-pricing-2026-03-16", ), ( "google", "gemini-2.5-flash", ): PricingEntry( input_cost_per_million=Decimal("0.15"), output_cost_per_million=Decimal("0.60"), source="official_docs_snapshot", source_url="https://ai.google.dev/pricing", pricing_version="google-pricing-2026-03-16", ), ( "google", "gemini-2.0-flash", ): PricingEntry( input_cost_per_million=Decimal("0.10"), output_cost_per_million=Decimal("0.40"), source="official_docs_snapshot", source_url="https://ai.google.dev/pricing", pricing_version="google-pricing-2026-03-16", ), } def _to_decimal(value: Any) -> Optional[Decimal]: if value is None: return None try: return Decimal(str(value)) except Exception: return None def _to_int(value: Any) -> int: try: return int(value or 0) except Exception: return 0 def resolve_billing_route( model_name: str, provider: Optional[str] = None, base_url: Optional[str] = None, ) -> BillingRoute: provider_name = (provider or "").strip().lower() base = (base_url or "").strip().lower() model = (model_name or "").strip() if not provider_name and "/" in model: inferred_provider, bare_model = model.split("/", 1) if inferred_provider in {"anthropic", "openai", "google"}: provider_name = inferred_provider model = bare_model if provider_name == "openai-codex": return BillingRoute(provider="openai-codex", model=model, base_url=base_url or "", billing_mode="subscription_included") if provider_name == "openrouter" or "openrouter.ai" in base: return BillingRoute(provider="openrouter", model=model, base_url=base_url or "", billing_mode="official_models_api") if provider_name == "anthropic": return BillingRoute(provider="anthropic", model=model.split("/")[-1], base_url=base_url or "", billing_mode="official_docs_snapshot") if provider_name == "openai": return BillingRoute(provider="openai", model=model.split("/")[-1], base_url=base_url or "", billing_mode="official_docs_snapshot") if provider_name in {"custom", "local"} or (base and "localhost" in base): return BillingRoute(provider=provider_name or "custom", model=model, base_url=base_url or "", billing_mode="unknown") return BillingRoute(provider=provider_name or "unknown", model=model.split("/")[-1] if model else "", base_url=base_url or "", billing_mode="unknown") def _lookup_official_docs_pricing(route: BillingRoute) -> Optional[PricingEntry]: return _OFFICIAL_DOCS_PRICING.get((route.provider, route.model.lower())) def _openrouter_pricing_entry(route: BillingRoute) -> Optional[PricingEntry]: return _pricing_entry_from_metadata( fetch_model_metadata(), route.model, source_url="https://openrouter.ai/docs/api/api-reference/models/get-models", pricing_version="openrouter-models-api", ) def _pricing_entry_from_metadata( metadata: Dict[str, Dict[str, Any]], model_id: str, *, source_url: str, pricing_version: str, ) -> Optional[PricingEntry]: if model_id not in metadata: return None pricing = metadata[model_id].get("pricing") or {} prompt = _to_decimal(pricing.get("prompt")) completion = _to_decimal(pricing.get("completion")) request = _to_decimal(pricing.get("request")) cache_read = _to_decimal( pricing.get("cache_read") or pricing.get("cached_prompt") or pricing.get("input_cache_read") ) cache_write = _to_decimal( pricing.get("cache_write") or pricing.get("cache_creation") or pricing.get("input_cache_write") ) if prompt is None and completion is None and request is None: return None def _per_token_to_per_million(value: Optional[Decimal]) -> Optional[Decimal]: if value is None: return None return value * _ONE_MILLION return PricingEntry( input_cost_per_million=_per_token_to_per_million(prompt), output_cost_per_million=_per_token_to_per_million(completion), cache_read_cost_per_million=_per_token_to_per_million(cache_read), cache_write_cost_per_million=_per_token_to_per_million(cache_write), request_cost=request, source="provider_models_api", source_url=source_url, pricing_version=pricing_version, fetched_at=_UTC_NOW(), ) def get_pricing_entry( model_name: str, provider: Optional[str] = None, base_url: Optional[str] = None, api_key: Optional[str] = None, ) -> Optional[PricingEntry]: route = resolve_billing_route(model_name, provider=provider, base_url=base_url) if route.billing_mode == "subscription_included": return PricingEntry( input_cost_per_million=_ZERO, output_cost_per_million=_ZERO, cache_read_cost_per_million=_ZERO, cache_write_cost_per_million=_ZERO, source="none", pricing_version="included-route", ) if route.provider == "openrouter": return _openrouter_pricing_entry(route) if route.base_url: entry = _pricing_entry_from_metadata( fetch_endpoint_model_metadata(route.base_url, api_key=api_key or ""), route.model, source_url=f"{route.base_url.rstrip('/')}/models", pricing_version="openai-compatible-models-api", ) if entry: return entry return _lookup_official_docs_pricing(route) def normalize_usage( response_usage: Any, *, provider: Optional[str] = None, api_mode: Optional[str] = None, ) -> CanonicalUsage: """Normalize raw API response usage into canonical token buckets. Handles three API shapes: - Anthropic: input_tokens/output_tokens/cache_read_input_tokens/cache_creation_input_tokens - Codex Responses: input_tokens includes cache tokens; input_tokens_details.cached_tokens separates them - OpenAI Chat Completions: prompt_tokens includes cache tokens; prompt_tokens_details.cached_tokens separates them In both Codex and OpenAI modes, input_tokens is derived by subtracting cache tokens from the total — the API contract is that input/prompt totals include cached tokens and the details object breaks them out. """ if not response_usage: return CanonicalUsage() provider_name = (provider or "").strip().lower() mode = (api_mode or "").strip().lower() if mode == "anthropic_messages" or provider_name == "anthropic": input_tokens = _to_int(getattr(response_usage, "input_tokens", 0)) output_tokens = _to_int(getattr(response_usage, "output_tokens", 0)) cache_read_tokens = _to_int(getattr(response_usage, "cache_read_input_tokens", 0)) cache_write_tokens = _to_int(getattr(response_usage, "cache_creation_input_tokens", 0)) elif mode == "codex_responses": input_total = _to_int(getattr(response_usage, "input_tokens", 0)) output_tokens = _to_int(getattr(response_usage, "output_tokens", 0)) details = getattr(response_usage, "input_tokens_details", None) cache_read_tokens = _to_int(getattr(details, "cached_tokens", 0) if details else 0) cache_write_tokens = _to_int( getattr(details, "cache_creation_tokens", 0) if details else 0 ) input_tokens = max(0, input_total - cache_read_tokens - cache_write_tokens) else: prompt_total = _to_int(getattr(response_usage, "prompt_tokens", 0)) output_tokens = _to_int(getattr(response_usage, "completion_tokens", 0)) details = getattr(response_usage, "prompt_tokens_details", None) cache_read_tokens = _to_int(getattr(details, "cached_tokens", 0) if details else 0) cache_write_tokens = _to_int( getattr(details, "cache_write_tokens", 0) if details else 0 ) input_tokens = max(0, prompt_total - cache_read_tokens - cache_write_tokens) reasoning_tokens = 0 output_details = getattr(response_usage, "output_tokens_details", None) if output_details: reasoning_tokens = _to_int(getattr(output_details, "reasoning_tokens", 0)) return CanonicalUsage( input_tokens=input_tokens, output_tokens=output_tokens, cache_read_tokens=cache_read_tokens, cache_write_tokens=cache_write_tokens, reasoning_tokens=reasoning_tokens, ) def estimate_usage_cost( model_name: str, usage: CanonicalUsage, *, provider: Optional[str] = None, base_url: Optional[str] = None, api_key: Optional[str] = None, ) -> CostResult: route = resolve_billing_route(model_name, provider=provider, base_url=base_url) if route.billing_mode == "subscription_included": return CostResult( amount_usd=_ZERO, status="included", source="none", label="included", pricing_version="included-route", ) entry = get_pricing_entry(model_name, provider=provider, base_url=base_url, api_key=api_key) if not entry: return CostResult(amount_usd=None, status="unknown", source="none", label="n/a") notes: list[str] = [] amount = _ZERO if usage.input_tokens and entry.input_cost_per_million is None: return CostResult(amount_usd=None, status="unknown", source=entry.source, label="n/a") if usage.output_tokens and entry.output_cost_per_million is None: return CostResult(amount_usd=None, status="unknown", source=entry.source, label="n/a") if usage.cache_read_tokens: if entry.cache_read_cost_per_million is None: return CostResult( amount_usd=None, status="unknown", source=entry.source, label="n/a", notes=("cache-read pricing unavailable for route",), ) if usage.cache_write_tokens: if entry.cache_write_cost_per_million is None: return CostResult( amount_usd=None, status="unknown", source=entry.source, label="n/a", notes=("cache-write pricing unavailable for route",), ) if entry.input_cost_per_million is not None: amount += Decimal(usage.input_tokens) * entry.input_cost_per_million / _ONE_MILLION if entry.output_cost_per_million is not None: amount += Decimal(usage.output_tokens) * entry.output_cost_per_million / _ONE_MILLION if entry.cache_read_cost_per_million is not None: amount += Decimal(usage.cache_read_tokens) * entry.cache_read_cost_per_million / _ONE_MILLION if entry.cache_write_cost_per_million is not None: amount += Decimal(usage.cache_write_tokens) * entry.cache_write_cost_per_million / _ONE_MILLION if entry.request_cost is not None and usage.request_count: amount += Decimal(usage.request_count) * entry.request_cost status: CostStatus = "estimated" label = f"~${amount:.2f}" if entry.source == "none" and amount == _ZERO: status = "included" label = "included" if route.provider == "openrouter": notes.append("OpenRouter cost is estimated from the models API until reconciled.") return CostResult( amount_usd=amount, status=status, source=entry.source, label=label, fetched_at=entry.fetched_at, pricing_version=entry.pricing_version, notes=tuple(notes), ) def has_known_pricing( model_name: str, provider: Optional[str] = None, base_url: Optional[str] = None, api_key: Optional[str] = None, ) -> bool: """Check whether we have pricing data for this model+route. Uses direct lookup instead of routing through the full estimation pipeline — avoids creating dummy usage objects just to check status. """ route = resolve_billing_route(model_name, provider=provider, base_url=base_url) if route.billing_mode == "subscription_included": return True entry = get_pricing_entry(model_name, provider=provider, base_url=base_url, api_key=api_key) return entry is not None def get_pricing( model_name: str, provider: Optional[str] = None, base_url: Optional[str] = None, api_key: Optional[str] = None, ) -> Dict[str, float]: """Backward-compatible thin wrapper for legacy callers. Returns only non-cache input/output fields when a pricing entry exists. Unknown routes return zeroes. """ entry = get_pricing_entry(model_name, provider=provider, base_url=base_url, api_key=api_key) if not entry: return {"input": 0.0, "output": 0.0} return { "input": float(entry.input_cost_per_million or _ZERO), "output": float(entry.output_cost_per_million or _ZERO), } def estimate_cost_usd( model: str, input_tokens: int, output_tokens: int, *, provider: Optional[str] = None, base_url: Optional[str] = None, api_key: Optional[str] = None, ) -> float: """Backward-compatible helper for legacy callers. This uses non-cached input/output only. New code should call `estimate_usage_cost()` with canonical usage buckets. """ result = estimate_usage_cost( model, CanonicalUsage(input_tokens=input_tokens, output_tokens=output_tokens), provider=provider, base_url=base_url, api_key=api_key, ) return float(result.amount_usd or _ZERO) def format_duration_compact(seconds: float) -> str: if seconds < 60: return f"{seconds:.0f}s" minutes = seconds / 60 if minutes < 60: return f"{minutes:.0f}m" hours = minutes / 60 if hours < 24: remaining_min = int(minutes % 60) return f"{int(hours)}h {remaining_min}m" if remaining_min else f"{int(hours)}h" days = hours / 24 return f"{days:.1f}d" def format_token_count_compact(value: int) -> str: abs_value = abs(int(value)) if abs_value < 1_000: return str(int(value)) sign = "-" if value < 0 else "" units = ((1_000_000_000, "B"), (1_000_000, "M"), (1_000, "K")) for threshold, suffix in units: if abs_value >= threshold: scaled = abs_value / threshold if scaled < 10: text = f"{scaled:.2f}" elif scaled < 100: text = f"{scaled:.1f}" else: text = f"{scaled:.0f}" text = text.rstrip("0").rstrip(".") return f"{sign}{text}{suffix}" return f"{value:,}"