"""Usage aggregation for app-attributed ML Intern spend.""" import asyncio import logging from datetime import UTC, datetime, timedelta from typing import Any from zoneinfo import ZoneInfo, ZoneInfoNotFoundError import httpx from agent.core.usage_metrics import summarize_sandbox_lifecycle USAGE_EVENT_TYPES = ( "llm_call", "hf_job_complete", "sandbox_create", "sandbox_destroy", ) logger = logging.getLogger(__name__) HF_BILLING_USAGE_V2_URL = "https://huggingface.co/api/settings/billing/usage-v2" HF_BILLING_USAGE_BY_INFERENCE_SESSION_URL = ( "https://huggingface.co/api/settings/billing/usage-by-inference-session" ) HF_BILLING_URL = "https://huggingface.co/settings/billing" HF_INFERENCE_PROVIDERS_PRICING_URL = ( "https://huggingface.co/docs/inference-providers/en/pricing" ) HF_JOBS_PRICING_URL = "https://huggingface.co/docs/hub/jobs-pricing" def _utc(dt: datetime) -> datetime: if dt.tzinfo is None: return dt.replace(tzinfo=UTC) return dt.astimezone(UTC) def _iso(dt: datetime | None) -> str | None: if dt is None: return None return _utc(dt).isoformat().replace("+00:00", "Z") def _coerce_float(value: Any) -> float: if isinstance(value, bool) or value is None: return 0.0 try: return float(value) except (TypeError, ValueError): return 0.0 def _coerce_int(value: Any) -> int: if isinstance(value, bool) or value is None: return 0 try: return int(value) except (TypeError, ValueError): return 0 def _nano_usd_to_usd(value: Any) -> float: return _coerce_float(value) / 1_000_000_000 def _micro_usd_to_usd(value: Any) -> float: return _coerce_float(value) / 1_000_000 def _cents_to_usd(value: Any) -> float: return _coerce_float(value) / 100 def _coerce_timezone(timezone_name: str | None) -> ZoneInfo | None: if not timezone_name: return None try: return ZoneInfo(timezone_name) except (ZoneInfoNotFoundError, ValueError): return None def _normalize_event_timestamp( dt: datetime, *, timezone_name: str | None = None, ) -> datetime: if dt.tzinfo is not None: return _utc(dt) timezone = _coerce_timezone(timezone_name) if timezone is not None: return dt.replace(tzinfo=timezone).astimezone(UTC) return dt.astimezone(UTC) def _parse_timestamp( value: Any, *, timezone_name: str | None = None ) -> datetime | None: if isinstance(value, datetime): return _normalize_event_timestamp(value, timezone_name=timezone_name) if not isinstance(value, str) or not value: return None try: return _normalize_event_timestamp( datetime.fromisoformat(value.replace("Z", "+00:00")), timezone_name=timezone_name, ) except ValueError: return None def event_created_at( event: dict[str, Any], *, timezone_name: str | None = None, ) -> datetime | None: return _parse_timestamp( event.get("created_at") or event.get("timestamp"), timezone_name=timezone_name, ) def resolve_usage_windows( timezone_name: str | None, *, now: datetime | None = None, ) -> dict[str, datetime | str]: """Return UTC month window for a browser timezone.""" try: tz = ZoneInfo(timezone_name or "UTC") except (ZoneInfoNotFoundError, ValueError): tz = ZoneInfo("UTC") now_utc = _utc(now or datetime.now(UTC)) local_now = now_utc.astimezone(tz) month_local = local_now.replace(day=1, hour=0, minute=0, second=0, microsecond=0) return { "timezone": tz.key, "now_utc": now_utc, "month_start_utc": month_local.astimezone(UTC), } def _empty_bucket( *, session_id: str | None = None, ) -> dict[str, Any]: return { "session_id": session_id, "total_usd": 0.0, "inference_usd": 0.0, "hf_jobs_estimated_usd": 0.0, "sandbox_estimated_usd": 0.0, "llm_calls": 0, "hf_jobs_count": 0, "sandbox_count": 0, "prompt_tokens": 0, "completion_tokens": 0, "cache_read_tokens": 0, "cache_creation_tokens": 0, "total_tokens": 0, "hf_jobs_billable_seconds_estimate": 0, "sandbox_billable_seconds_estimate": 0, } def _empty_hf_account_bucket( *, window_start: datetime | None = None, window_end: datetime | None = None, timezone: str | None = None, ) -> dict[str, Any]: return { "window_start": _iso(window_start), "window_end": _iso(window_end), "timezone": timezone, "total_usd": 0.0, "inference_providers_usd": 0.0, "hf_jobs_usd": 0.0, "inference_provider_requests": 0, "hf_jobs_minutes": 0.0, } def aggregate_usage_events( events: list[dict[str, Any]], *, session_id: str | None = None, ) -> dict[str, Any]: bucket = _empty_bucket(session_id=session_id) for event in events: event_type = event.get("event_type") data = event.get("data") or {} if event_type == "llm_call": bucket["llm_calls"] += 1 bucket["inference_usd"] += _coerce_float(data.get("cost_usd")) prompt_tokens = _coerce_int(data.get("prompt_tokens")) completion_tokens = _coerce_int(data.get("completion_tokens")) cache_read_tokens = _coerce_int(data.get("cache_read_tokens")) cache_creation_tokens = _coerce_int(data.get("cache_creation_tokens")) total_tokens = _coerce_int(data.get("total_tokens")) or ( prompt_tokens + completion_tokens + cache_read_tokens + cache_creation_tokens ) bucket["prompt_tokens"] += prompt_tokens bucket["completion_tokens"] += completion_tokens bucket["cache_read_tokens"] += cache_read_tokens bucket["cache_creation_tokens"] += cache_creation_tokens bucket["total_tokens"] += total_tokens elif event_type == "hf_job_complete": bucket["hf_jobs_count"] += 1 bucket["hf_jobs_estimated_usd"] += _coerce_float( data.get("estimated_cost_usd") ) bucket["hf_jobs_billable_seconds_estimate"] += _coerce_int( data.get("billable_seconds_estimate") or data.get("wall_time_s") ) elif event_type == "sandbox_destroy": # Sandbox costs are paired and added after the main pass so the # create event can provide hardware pricing metadata. continue _aggregate_sandbox_usage(events, bucket) bucket["inference_usd"] = round(bucket["inference_usd"], 6) bucket["hf_jobs_estimated_usd"] = round(bucket["hf_jobs_estimated_usd"], 6) bucket["sandbox_estimated_usd"] = round(bucket["sandbox_estimated_usd"], 6) bucket["total_usd"] = round( ( bucket["inference_usd"] + bucket["hf_jobs_estimated_usd"] + bucket["sandbox_estimated_usd"] ), 6, ) return bucket def _aggregate_sandbox_usage( events: list[dict[str, Any]], bucket: dict[str, Any], ) -> None: lifecycle_events = [ (index, event) for index, event in enumerate(events) if event.get("event_type") in {"sandbox_create", "sandbox_destroy"} ] sandbox = summarize_sandbox_lifecycle(lifecycle_events) bucket["sandbox_count"] += sandbox["matched_pairs"] bucket["sandbox_billable_seconds_estimate"] += sandbox["billable_seconds_estimate"] bucket["sandbox_estimated_usd"] += sandbox["estimated_usd"] def _account_bucket_from_billing_usage( payload: dict[str, Any] | None, *, window_start: datetime, window_end: datetime, timezone: str, ) -> dict[str, Any]: bucket = _empty_hf_account_bucket( window_start=window_start, window_end=window_end, timezone=timezone, ) usage = payload.get("usage") if isinstance(payload, dict) else {} if not isinstance(usage, dict): return bucket inference = usage.get("inferenceProviders") if not isinstance(inference, dict): inference = {} jobs = usage.get("jobs") if not isinstance(jobs, dict): jobs = {} bucket["inference_providers_usd"] = round( _nano_usd_to_usd(inference.get("usedNanoUsd")), 6, ) bucket["hf_jobs_usd"] = round(_micro_usd_to_usd(jobs.get("usedMicroUsd")), 6) bucket["inference_provider_requests"] = _coerce_int(inference.get("numRequests")) bucket["hf_jobs_minutes"] = round(_coerce_float(jobs.get("totalMinutes")), 3) bucket["total_usd"] = round( bucket["inference_providers_usd"] + bucket["hf_jobs_usd"], 6, ) return bucket def _session_bucket_from_inference_session_usage( payload: dict[str, Any] | None, *, session_id: str, window_start: datetime, window_end: datetime, timezone: str, ) -> dict[str, Any]: bucket = _empty_hf_account_bucket( window_start=window_start, window_end=window_end, timezone=timezone, ) periods = payload.get("periods") if isinstance(payload, dict) else [] if not isinstance(periods, list): return bucket cost_cents = 0.0 request_count = 0 for period in periods: if not isinstance(period, dict): continue sessions = period.get("sessions") if not isinstance(sessions, list): continue for session in sessions: if not isinstance(session, dict) or session.get("id") != session_id: continue cost_cents += _coerce_float(session.get("costCents")) request_count += _coerce_int(session.get("requestCount")) bucket["inference_providers_usd"] = round(_cents_to_usd(cost_cents), 6) bucket["inference_provider_requests"] = request_count bucket["total_usd"] = bucket["inference_providers_usd"] return bucket def _inference_credits_from_billing_usage( payload: dict[str, Any] | None, ) -> dict[str, Any] | None: usage = payload.get("usage") if isinstance(payload, dict) else {} if not isinstance(usage, dict): return None inference = usage.get("inferenceProviders") if not isinstance(inference, dict): return None included_usd = _nano_usd_to_usd(inference.get("includedNanoUsd")) used_usd = _nano_usd_to_usd(inference.get("usedNanoUsd")) limit_usd = _nano_usd_to_usd(inference.get("limitNanoUsd")) return { "included_usd": round(included_usd, 6), "used_usd": round(used_usd, 6), "remaining_included_usd": round(max(0.0, included_usd - used_usd), 6), "limit_usd": round(limit_usd, 6), "remaining_limit_usd": round(max(0.0, limit_usd - used_usd), 6), "num_requests": _coerce_int(inference.get("numRequests")), "period_start": inference.get("periodStart"), "period_end": inference.get("periodEnd"), } async def _fetch_hf_billing_usage_v2( hf_token: str, *, start: datetime, end: datetime, ) -> dict[str, Any] | None: start_ts = max(1, int(_utc(start).timestamp())) end_ts = max(start_ts + 1, int(_utc(end).timestamp())) try: async with httpx.AsyncClient(timeout=10.0) as client: response = await client.get( HF_BILLING_USAGE_V2_URL, params={"startDate": start_ts, "endDate": end_ts}, headers={"Authorization": f"Bearer {hf_token}"}, ) if response.status_code != 200: logger.debug( "HF billing usage-v2 failed: status=%s body=%s", response.status_code, response.text[:200], ) return None payload = response.json() return payload if isinstance(payload, dict) else None except (httpx.HTTPError, ValueError) as e: logger.debug("HF billing usage-v2 failed: %s", e) return None async def _fetch_hf_inference_session_usage( hf_token: str, *, start: datetime, end: datetime, ) -> dict[str, Any] | None: start_ts = _iso(start) end_ts = _iso(max(_utc(end), _utc(start) + timedelta(seconds=1))) try: async with httpx.AsyncClient(timeout=10.0) as client: response = await client.get( HF_BILLING_USAGE_BY_INFERENCE_SESSION_URL, params={"startDate": start_ts, "endDate": end_ts}, headers={"Authorization": f"Bearer {hf_token}"}, ) if response.status_code != 200: logger.debug( "HF inference session usage failed: status=%s body=%s", response.status_code, response.text[:200], ) return None payload = response.json() return payload if isinstance(payload, dict) else None except (httpx.HTTPError, ValueError) as e: logger.debug("HF inference session usage failed: %s", e) return None def _session_usage_window_started_at( manager: Any, session_id: str | None ) -> datetime | None: if not session_id: return None agent_session = getattr(manager, "sessions", {}).get(session_id) usage_window_started_at = getattr(agent_session, "usage_window_started_at", None) if isinstance(usage_window_started_at, datetime): return _utc(usage_window_started_at) created_at = getattr(agent_session, "created_at", None) if isinstance(created_at, datetime): return _utc(created_at) return None def _session_inference_billing_session_id( manager: Any, session_id: str | None ) -> str | None: if not session_id: return None agent_session = getattr(manager, "sessions", {}).get(session_id) billing_session_id = getattr(agent_session, "inference_billing_session_id", None) if isinstance(billing_session_id, str) and billing_session_id: return billing_session_id runtime_session = getattr(agent_session, "session", None) billing_session_id = getattr(runtime_session, "inference_billing_session_id", None) if isinstance(billing_session_id, str) and billing_session_id: return billing_session_id return None async def _load_persisted_session_usage_window_metadata( manager: Any, session_id: str | None, ) -> tuple[datetime | None, str | None]: if not session_id: return None, None store = manager._store() if not getattr(store, "enabled", False) or not hasattr(store, "load_session"): return None, None loaded = await store.load_session(session_id) metadata = loaded.get("metadata") if isinstance(loaded, dict) else None started_at = None billing_session_id = None if isinstance(metadata, dict): started_at = metadata.get("usage_window_started_at") or metadata.get( "created_at" ) raw_billing_session_id = metadata.get("inference_billing_session_id") if isinstance(raw_billing_session_id, str) and raw_billing_session_id: billing_session_id = raw_billing_session_id if isinstance(started_at, datetime): return _utc(started_at), billing_session_id parsed = _parse_timestamp(started_at) return (_utc(parsed) if parsed is not None else None), billing_session_id async def _build_hf_account_usage( manager: Any, *, hf_token: str | None, session_id: str | None, timezone: str, now_utc: datetime, month_start: datetime, ) -> dict[str, Any]: account_usage: dict[str, Any] = { "source": "hf_billing", "available": False, "current_session": None, "month": None, "inference_providers_credits": None, } if not hf_token: account_usage["error"] = "missing_hf_token" return account_usage session_start = _session_usage_window_started_at(manager, session_id) billing_session_id = _session_inference_billing_session_id(manager, session_id) if session_start is None or billing_session_id is None: ( persisted_start, persisted_billing_session_id, ) = await _load_persisted_session_usage_window_metadata(manager, session_id) if session_start is None: session_start = persisted_start if billing_session_id is None: billing_session_id = persisted_billing_session_id window_tasks: dict[str, tuple[datetime, asyncio.Task[dict[str, Any] | None]]] = { "month": ( month_start, asyncio.create_task( _fetch_hf_billing_usage_v2(hf_token, start=month_start, end=now_utc) ), ), } if billing_session_id is not None and session_start is not None: window_tasks["current_session"] = ( session_start, asyncio.create_task( _fetch_hf_inference_session_usage( hf_token, start=session_start, end=now_utc, ) ), ) payloads: dict[str, dict[str, Any] | None] = {} for name, (_, task) in window_tasks.items(): payloads[name] = await task any_payload = any(isinstance(payload, dict) for payload in payloads.values()) account_usage["available"] = any_payload if not any_payload: account_usage["error"] = "billing_usage_unavailable" return account_usage for name, (start, _) in window_tasks.items(): payload = payloads.get(name) if payload is None: continue if name == "current_session" and billing_session_id is not None: account_usage[name] = _session_bucket_from_inference_session_usage( payload, session_id=billing_session_id, window_start=start, window_end=now_utc, timezone=timezone, ) else: account_usage[name] = _account_bucket_from_billing_usage( payload, window_start=start, window_end=now_utc, timezone=timezone, ) account_usage["inference_providers_credits"] = ( _inference_credits_from_billing_usage(payloads.get("month")) ) return account_usage async def build_hf_billing_snapshot( manager: Any, *, hf_token: str | None, session_id: str | None, timezone_name: str | None = None, now: datetime | None = None, ) -> dict[str, Any]: """Return a dataset-safe HF billing rollup for the session window. This intentionally omits monthly account totals and credit-limit details. The snapshot is an account-window delta, not per-call attribution. """ windows = resolve_usage_windows(timezone_name, now=now) timezone = str(windows["timezone"]) now_utc = windows["now_utc"] snapshot: dict[str, Any] = { "billing_scope": "account_window_delta", "hf_billing": { "source": "hf_billing_usage_v2", "available": False, "error": None, "current_session": None, }, } hf_billing = snapshot["hf_billing"] if not hf_token: hf_billing["error"] = "missing_hf_token" return snapshot if not session_id: hf_billing["error"] = "missing_session_id" return snapshot session_start = _session_usage_window_started_at(manager, session_id) if session_start is None: session_start, _ = await _load_persisted_session_usage_window_metadata( manager, session_id, ) if session_start is None: hf_billing["error"] = "missing_session_window" return snapshot payload = await _fetch_hf_billing_usage_v2( hf_token, start=session_start, end=now_utc, ) if not isinstance(payload, dict): hf_billing["error"] = "billing_usage_unavailable" return snapshot hf_billing["available"] = True hf_billing["current_session"] = _account_bucket_from_billing_usage( payload, window_start=session_start, window_end=now_utc, timezone=timezone, ) return snapshot def _event_in_window( event: dict[str, Any], *, start: datetime | None = None, end: datetime | None = None, timezone_name: str | None = None, ) -> bool: if start is None and end is None: return True created_at = event_created_at(event, timezone_name=timezone_name) if created_at is None: return False if start is not None and created_at < _utc(start): return False if end is not None and created_at >= _utc(end): return False return True def _events_from_runtime_session(agent_session: Any) -> list[dict[str, Any]]: events: list[dict[str, Any]] = [] for raw in getattr(agent_session.session, "logged_events", []) or []: if raw.get("event_type") not in USAGE_EVENT_TYPES: continue events.append( { "session_id": agent_session.session_id, "event_type": raw.get("event_type"), "data": raw.get("data") or {}, "timestamp": raw.get("timestamp"), } ) return events def _runtime_sessions_for_user(manager: Any, user_id: str) -> list[Any]: sessions = list(getattr(manager, "sessions", {}).values()) if user_id == "dev": return sessions return [session for session in sessions if session.user_id == user_id] async def _load_usage_events( manager: Any, *, user_id: str, session_id: str | None = None, start: datetime | None = None, end: datetime | None = None, timezone_name: str | None = None, ) -> list[dict[str, Any]]: store = manager._store() if getattr(store, "enabled", False): return await store.load_usage_events( user_id, session_id=session_id, start=start, end=end, ) events: list[dict[str, Any]] = [] for agent_session in _runtime_sessions_for_user(manager, user_id): if session_id is not None and agent_session.session_id != session_id: continue for event in _events_from_runtime_session(agent_session): if _event_in_window( event, start=start, end=end, timezone_name=timezone_name, ): events.append(event) return events async def build_usage_response( manager: Any, *, user_id: str, hf_token: str | None = None, session_id: str | None = None, timezone_name: str | None = None, now: datetime | None = None, ) -> dict[str, Any]: windows = resolve_usage_windows(timezone_name, now=now) timezone = str(windows["timezone"]) now_utc = windows["now_utc"] month_start = windows["month_start_utc"] session_events: list[dict[str, Any]] = [] if session_id: session_start = _session_usage_window_started_at(manager, session_id) if session_start is None: session_start, _ = await _load_persisted_session_usage_window_metadata( manager, session_id, ) session_events = await _load_usage_events( manager, user_id=user_id, session_id=session_id, start=session_start, ) hf_account = await _build_hf_account_usage( manager, hf_token=hf_token, session_id=session_id, timezone=timezone, now_utc=now_utc, month_start=month_start, ) return { "source": "app_telemetry", "currency": "USD", "generated_at": _iso(now_utc), "timezone": timezone, "session": ( aggregate_usage_events(session_events, session_id=session_id) if session_id else None ), "hf_account": hf_account, "links": { "hf_billing": HF_BILLING_URL, "inference_providers_pricing": HF_INFERENCE_PROVIDERS_PRICING_URL, "jobs_pricing": HF_JOBS_PRICING_URL, }, }