Spaces:
Running
Running
| """Conservative cost estimates for auto-approved infrastructure actions.""" | |
| import os | |
| import re | |
| import time | |
| from dataclasses import dataclass | |
| from typing import Any | |
| import httpx | |
| OPENID_PROVIDER_URL = os.environ.get("OPENID_PROVIDER_URL", "https://huggingface.co") | |
| JOBS_HARDWARE_URL = f"{OPENID_PROVIDER_URL}/api/jobs/hardware" | |
| JOBS_PRICE_CACHE_TTL_S = 6 * 60 * 60 | |
| DEFAULT_JOB_TIMEOUT_HOURS = 0.5 | |
| DEFAULT_SANDBOX_RESERVATION_HOURS = 1.0 | |
| # Static fallback prices are intentionally conservative enough for a budget | |
| # guard. The live /api/jobs/hardware catalog wins whenever it is reachable. | |
| HF_JOBS_PRICE_USD_PER_HOUR: dict[str, float] = { | |
| "cpu-basic": 0.05, | |
| "cpu-upgrade": 0.25, | |
| "cpu-performance": 0.50, | |
| "cpu-xl": 1.00, | |
| "t4-small": 0.60, | |
| "t4-medium": 0.90, | |
| "l4x1": 1.00, | |
| "l4x4": 4.00, | |
| "l40sx1": 2.00, | |
| "l40sx4": 8.00, | |
| "l40sx8": 16.00, | |
| "a10g-small": 1.00, | |
| "a10g-large": 2.00, | |
| "a10g-largex2": 4.00, | |
| "a10g-largex4": 8.00, | |
| "a100-large": 4.00, | |
| "a100x4": 16.00, | |
| "a100x8": 32.00, | |
| "h200": 10.00, | |
| "h200x2": 20.00, | |
| "h200x4": 40.00, | |
| "h200x8": 80.00, | |
| "inf2x6": 6.00, | |
| } | |
| SPACE_PRICE_USD_PER_HOUR: dict[str, float] = { | |
| "cpu-basic": 0.0, | |
| "cpu-upgrade": 0.05, | |
| "cpu-performance": 0.50, | |
| "cpu-xl": 1.00, | |
| "t4-small": 0.60, | |
| "t4-medium": 0.90, | |
| "l4x1": 1.00, | |
| "l4x4": 4.00, | |
| "l40sx1": 2.00, | |
| "l40sx4": 8.00, | |
| "l40sx8": 16.00, | |
| "a10g-small": 1.00, | |
| "a10g-large": 2.00, | |
| "a10g-largex2": 4.00, | |
| "a10g-largex4": 8.00, | |
| "a100-large": 4.00, | |
| "a100x4": 16.00, | |
| "a100x8": 32.00, | |
| "h200": 10.00, | |
| "h200x2": 20.00, | |
| "h200x4": 40.00, | |
| "h200x8": 80.00, | |
| "inf2x6": 6.00, | |
| } | |
| _DURATION_RE = re.compile(r"^\s*(\d+(?:\.\d+)?)\s*([smhd]?)\s*$", re.IGNORECASE) | |
| _PRICE_RE = re.compile(r"(\d+(?:\.\d+)?)") | |
| _jobs_price_cache: tuple[float, dict[str, float]] | None = None | |
| class CostEstimate: | |
| """Estimated cost for a tool call. | |
| ``estimated_cost_usd=None`` means the call may be billable but we could not | |
| estimate it safely, so auto-approval should fall back to a human decision. | |
| """ | |
| estimated_cost_usd: float | None | |
| billable: bool | |
| block_reason: str | None = None | |
| label: str | None = None | |
def parse_timeout_hours(
    value: Any, *, default_hours: float = DEFAULT_JOB_TIMEOUT_HOURS
) -> float | None:
    """Parse HF timeout values into hours.

    Strings accept ``s``, ``m``, ``h``, or ``d`` suffixes (case-insensitive);
    a string without a suffix is read as seconds. Numeric values are treated
    as seconds, matching the Hub client's typed timeout parameter.

    Args:
        value: Raw timeout taken from the tool arguments.
        default_hours: Returned unchanged when ``value`` is ``None`` or ``""``.

    Returns:
        The timeout in hours, or ``None`` when ``value`` is a bool, an
        unsupported type, an unparseable string, non-positive, or does not
        convert to a finite number of hours.
    """
    import math

    if value is None or value == "":
        return default_hours
    if isinstance(value, bool):
        # bool is an int subclass; True/False is never a meaningful timeout.
        return None

    if isinstance(value, int | float):
        try:
            amount = float(value)
        except OverflowError:
            # Integers too large for a float cannot be priced.
            return None
        unit = "s"
    elif isinstance(value, str):
        match = _DURATION_RE.match(value)
        if match is None:
            return None
        amount = float(match.group(1))
        unit = match.group(2).lower() or "s"
    else:
        return None

    # Rejects nan/inf as well as zero and negative durations.
    if not math.isfinite(amount) or amount <= 0:
        return None

    if unit == "s":
        hours = amount / 3600
    elif unit == "m":
        hours = amount / 60
    elif unit == "h":
        hours = amount
    elif unit == "d":
        hours = amount * 24
    else:
        return None

    # Multiplying a very large day count can overflow to inf; an infinite
    # duration would turn into an inf (or nan, at price 0) cost downstream.
    return hours if math.isfinite(hours) else None
| def _extract_flavor(item: dict[str, Any]) -> str | None: | |
| for key in ("flavor", "name", "id", "value", "hardware", "hardware_flavor"): | |
| value = item.get(key) | |
| if isinstance(value, str) and value: | |
| return value | |
| return None | |
| def _coerce_price(value: Any) -> float | None: | |
| if isinstance(value, bool) or value is None: | |
| return None | |
| if isinstance(value, int | float): | |
| return float(value) if value >= 0 else None | |
| if isinstance(value, str): | |
| match = _PRICE_RE.search(value.replace(",", "")) | |
| if match: | |
| return float(match.group(1)) | |
| return None | |
def _extract_hourly_price(item: dict[str, Any]) -> float | None:
    """Find a price in a catalog entry.

    Known price keys on the entry itself are tried first, in order; the first
    one that ``_coerce_price`` accepts is returned. If none matches, the same
    search is repeated inside the ``pricing``, ``billing`` and ``cost``
    sub-dictionaries. Returns ``None`` when nothing usable is found.
    """
    direct_keys = (
        "price",
        "price_usd",
        "priceUsd",
        "price_per_hour",
        "pricePerHour",
        "hourly_price",
        "hourlyPrice",
        "usd_per_hour",
        "usdPerHour",
    )
    for raw in map(item.get, direct_keys):
        parsed = _coerce_price(raw)
        if parsed is not None:
            return parsed

    for container_key in ("pricing", "billing", "cost"):
        container = item.get(container_key)
        if not isinstance(container, dict):
            continue
        parsed = _extract_hourly_price(container)
        if parsed is not None:
            return parsed

    return None
def _iter_hardware_items(payload: Any):
    """Yield every dict in ``payload`` that carries a flavor name.

    Lists are walked element by element. A dict is yielded when
    ``_extract_flavor`` finds a name in it, and its ``hardware``, ``flavors``,
    ``items``, ``data`` and ``jobs`` values are then searched recursively
    whether or not the dict itself was yielded. Any other payload type yields
    nothing.
    """
    if isinstance(payload, dict):
        if _extract_flavor(payload):
            yield payload
        for container_key in ("hardware", "flavors", "items", "data", "jobs"):
            nested = payload.get(container_key)
            if nested is None:
                continue
            yield from _iter_hardware_items(nested)
        return

    if isinstance(payload, list):
        for entry in payload:
            yield from _iter_hardware_items(entry)
def _parse_jobs_price_catalog(payload: Any) -> dict[str, float]:
    """Build a ``{flavor: price}`` mapping from a hardware catalog payload.

    Entries lacking either a flavor name or a usable price are skipped. When
    a flavor appears more than once, the last price seen is kept.
    """
    pairs = (
        (_extract_flavor(entry), _extract_hourly_price(entry))
        for entry in _iter_hardware_items(payload)
    )
    return {
        flavor: price
        for flavor, price in pairs
        if flavor and price is not None
    }
async def hf_jobs_price_catalog() -> dict[str, float]:
    """Return live HF Jobs hourly prices, falling back to static prices.

    A result younger than ``JOBS_PRICE_CACHE_TTL_S`` is served from the
    module-level cache. Otherwise the hardware endpoint is queried with a
    3-second timeout; HTTP errors, non-200 responses and undecodable bodies
    all count as "no live prices". Live prices are layered over the static
    table, so flavors missing from the live catalog keep their static price.

    The returned dict is a copy; callers may mutate it freely.
    """
    global _jobs_price_cache

    started = time.monotonic()
    cached = _jobs_price_cache
    if cached and started - cached[0] < JOBS_PRICE_CACHE_TTL_S:
        return dict(cached[1])

    live: dict[str, float] = {}
    try:
        async with httpx.AsyncClient(timeout=3.0) as client:
            response = await client.get(JOBS_HARDWARE_URL)
        if response.status_code == 200:
            live = _parse_jobs_price_catalog(response.json())
    except (httpx.HTTPError, ValueError):
        live = {}

    # With no live prices this is simply a copy of the static table.
    # NOTE(review): a failed fetch is cached for the full TTL as well, so a
    # transient outage pins the static prices until the entry expires.
    merged = {**HF_JOBS_PRICE_USD_PER_HOUR, **live}
    _jobs_price_cache = (started, merged)
    return dict(merged)
async def estimate_hf_job_cost(args: dict[str, Any]) -> CostEstimate:
    """Estimate the cost of an HF job from its tool arguments.

    The hardware flavor comes from ``hardware_flavor``, ``flavor`` or
    ``hardware`` (first truthy value, default ``"cpu-basic"``) and the
    duration from ``timeout``. The estimate is the hourly price times the
    timeout in hours, rounded to four decimals.

    Returns an estimate with ``estimated_cost_usd=None`` and a
    ``block_reason`` when the timeout cannot be parsed or the flavor has no
    known price. The price catalog is only consulted once the timeout has
    parsed successfully.
    """
    requested = (
        args.get("hardware_flavor")
        or args.get("flavor")
        or args.get("hardware")
        or "cpu-basic"
    )
    flavor = str(requested)

    def _blocked(reason: str) -> CostEstimate:
        # Billable but not estimable: leave the decision to a human.
        return CostEstimate(
            estimated_cost_usd=None,
            billable=True,
            block_reason=reason,
            label=flavor,
        )

    raw_timeout = args.get("timeout")
    timeout_hours = parse_timeout_hours(raw_timeout)
    if timeout_hours is None:
        return _blocked(f"Could not parse HF job timeout: {raw_timeout!r}.")

    catalog = await hf_jobs_price_catalog()
    hourly = catalog.get(flavor)
    if hourly is None:
        return _blocked(f"No price is available for HF job hardware '{flavor}'.")

    total = round(hourly * timeout_hours, 4)
    return CostEstimate(estimated_cost_usd=total, billable=hourly > 0, label=flavor)
async def estimate_sandbox_cost(
    args: dict[str, Any], *, session: Any = None
) -> CostEstimate:
    """Estimate the cost of creating a sandbox.

    When ``session`` already has a truthy ``sandbox`` attribute the call is
    free and labelled ``"existing"``. Otherwise the ``hardware`` argument
    (default ``"cpu-basic"``) is priced from the static Spaces table for
    ``DEFAULT_SANDBOX_RESERVATION_HOURS``; unknown hardware yields an
    estimate with ``estimated_cost_usd=None`` and a ``block_reason``.
    """
    current_sandbox = getattr(session, "sandbox", None) if session is not None else None
    if current_sandbox:
        return CostEstimate(estimated_cost_usd=0.0, billable=False, label="existing")

    hardware = str(args.get("hardware") or "cpu-basic")
    if hardware not in SPACE_PRICE_USD_PER_HOUR:
        return CostEstimate(
            estimated_cost_usd=None,
            billable=True,
            block_reason=f"No price is available for sandbox hardware '{hardware}'.",
            label=hardware,
        )

    hourly = SPACE_PRICE_USD_PER_HOUR[hardware]
    reserved_cost = round(hourly * DEFAULT_SANDBOX_RESERVATION_HOURS, 4)
    return CostEstimate(
        estimated_cost_usd=reserved_cost,
        billable=hourly > 0,
        label=hardware,
    )
| async def estimate_tool_cost( | |
| tool_name: str, args: dict[str, Any], *, session: Any = None | |
| ) -> CostEstimate: | |
| if tool_name == "sandbox_create": | |
| return await estimate_sandbox_cost(args, session=session) | |
| if tool_name == "hf_jobs": | |
| return await estimate_hf_job_cost(args) | |
| return CostEstimate(estimated_cost_usd=0.0, billable=False) | |