Spaces:
Running
Running
| """ | |
| Management of token, request, and time budgets. | |
| Provides cost control at the graph level and at individual node level. | |
| """ | |
| from collections.abc import Callable | |
| from datetime import UTC, datetime | |
| from typing import Any | |
| from pydantic import BaseModel, ConfigDict | |
# Names exported by a star-import of this module.
__all__ = [
    "Budget",
    "BudgetConfig",
    "BudgetTracker",
    "NodeBudget",
]
class Budget(BaseModel):
    """Track one resource limit together with its used and reserved amounts.

    The derived values ``available``, ``remaining``, ``usage_ratio`` and
    ``is_exhausted`` are read-only properties. The methods below read them as
    plain attributes (for example ``self.available >= amount``), which only
    works when they are properties rather than ordinary methods.
    """

    # Upper bound for the resource.
    limit: float
    # Amount already consumed.
    used: float = 0.0
    # Amount set aside for operations that have not completed yet.
    reserved: float = 0.0

    @property
    def available(self) -> float:
        """Amount that can still be spent or reserved (never negative)."""
        return max(0.0, self.limit - self.used - self.reserved)

    @property
    def remaining(self) -> float:
        """Amount left when reservations are ignored: ``limit - used``, floored at 0."""
        return max(0.0, self.limit - self.used)

    @property
    def usage_ratio(self) -> float:
        """Fraction of the limit that has been consumed.

        Returns 0.0 for a non-positive limit so the division is always safe.
        """
        if self.limit <= 0:
            return 0.0
        return self.used / self.limit

    @property
    def is_exhausted(self) -> bool:
        """True when nothing is available to spend or reserve."""
        return self.available <= 0

    def can_spend(self, amount: float) -> bool:
        """Return True when ``amount`` fits into the available resource."""
        return self.available >= amount

    def spend(self, amount: float) -> bool:
        """Consume ``amount`` if it fits.

        Returns:
            True when the amount was added to ``used``; False (with no state
            change) when it does not fit into ``available``.
        """
        if not self.can_spend(amount):
            return False
        self.used += amount
        return True

    def reserve(self, amount: float) -> bool:
        """Set ``amount`` aside for a future operation.

        Returns:
            True when the amount was added to ``reserved``; False (with no
            state change) when it does not fit into ``available``.
        """
        if not self.can_spend(amount):
            return False
        self.reserved += amount
        return True

    def release_reservation(self, amount: float) -> None:
        """Give back up to ``amount`` of the reservation; ``reserved`` never drops below 0."""
        self.reserved = max(0.0, self.reserved - amount)

    def commit_reservation(self, amount: float) -> None:
        """Move up to ``amount`` from ``reserved`` into ``used``.

        The moved amount is capped at the current reservation, so committing
        more than was reserved only transfers what is actually reserved.
        """
        actual = min(amount, self.reserved)
        self.reserved -= actual
        self.used += actual

    def reset(self) -> None:
        """Set ``used`` and ``reserved`` back to zero; ``limit`` is unchanged."""
        self.used = 0.0
        self.reserved = 0.0

    def to_dict(self) -> dict[str, Any]:
        """Return the stored fields plus the derived ``available`` and ``usage_ratio``."""
        return {
            "limit": self.limit,
            "used": self.used,
            "reserved": self.reserved,
            "available": self.available,
            "usage_ratio": self.usage_ratio,
        }
class NodeBudget(BaseModel):
    """Per-node limits for tokens, requests, time, and message lengths.

    ``tokens``, ``requests`` and ``time_seconds`` are optional; ``None`` means
    that resource is not tracked for this node.
    """

    node_id: str
    tokens: Budget | None = None
    requests: Budget | None = None
    time_seconds: Budget | None = None
    max_prompt_length: int | None = None
    max_response_length: int | None = None

    def can_execute(self, estimated_tokens: int = 0) -> tuple[bool, str | None]:
        """Check whether one more step may run on this node.

        Args:
            estimated_tokens: Expected token cost of the step.

        Returns:
            ``(True, None)`` when the step may run, otherwise ``(False, reason)``.
            Only the token and request budgets are consulted; the time budget
            is recorded by ``record_usage`` but not checked here.
        """
        token_budget = self.tokens
        if token_budget is not None:
            # A zero estimate always "fits", so also test the stored fields
            # directly; otherwise a fully spent budget would still pass.
            spent = token_budget.used + token_budget.reserved >= token_budget.limit
            if spent or not token_budget.can_spend(estimated_tokens):
                return False, f"Token budget exhausted for node {self.node_id}"

        if self.requests is not None and not self.requests.can_spend(1):
            return False, f"Request budget exhausted for node {self.node_id}"

        return True, None

    def record_usage(
        self,
        tokens: int = 0,
        time_seconds: float = 0.0,
    ) -> None:
        """Record resources that a finished step actually consumed.

        The amounts are added to ``used`` unconditionally. Going through
        ``Budget.spend`` would discard any amount larger than what is left,
        so real consumption that overshoots the limit would never be counted.

        Args:
            tokens: Tokens consumed by the step.
            time_seconds: Seconds the step took.
        """
        if self.tokens is not None:
            self.tokens.used += tokens
        if self.requests is not None:
            self.requests.used += 1
        if self.time_seconds is not None:
            self.time_seconds.used += time_seconds

    def to_dict(self) -> dict[str, Any]:
        """Serialize the node budget; untracked resources are emitted as ``None``."""
        return {
            "node_id": self.node_id,
            "tokens": self.tokens.to_dict() if self.tokens is not None else None,
            "requests": self.requests.to_dict() if self.requests is not None else None,
            "time_seconds": (
                self.time_seconds.to_dict() if self.time_seconds is not None else None
            ),
            "limits": {
                "max_prompt_length": self.max_prompt_length,
                "max_response_length": self.max_response_length,
            },
        }
class BudgetConfig(BaseModel):
    """Configuration for global and per-component execution limits.

    Every limit defaults to ``None``, meaning "not configured".
    """

    # The docstring must be the first statement of the class body; placed after
    # ``model_config`` it is a discarded expression and ``__doc__`` stays None.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Limits applied across the whole run.
    total_token_limit: int | None = None
    total_request_limit: int | None = None
    total_time_limit_seconds: float | None = None

    # Limits applied to each node individually.
    node_token_limit: int | None = None
    node_request_limit: int | None = None
    node_time_limit_seconds: float | None = None

    # Maximum lengths for prompt and response text.
    max_prompt_length: int | None = None
    max_response_length: int | None = None

    # Usage ratio at or above which the warning callback is invoked.
    warn_at_usage_ratio: float = 0.8

    # Callbacks receive a resource name and the Budget concerned.
    on_budget_warning: Callable[[str, Budget], None] | None = None
    on_budget_exceeded: Callable[[str, Budget], None] | None = None
class BudgetTracker:
    """Track global and per-node budgets and warn when thresholds are approached.

    Global limits come from ``BudgetConfig.total_*``; per-node budgets are
    created lazily from ``BudgetConfig.node_*`` the first time a node is seen.
    A limit that is ``None`` or 0 is treated as unlimited.
    """

    def __init__(self, config: BudgetConfig | None = None):
        """Create the tracker, using a default ``BudgetConfig`` when none is given."""
        self.config = config if config is not None else BudgetConfig()
        unlimited = float("inf")
        self._global_tokens = Budget(limit=float(self.config.total_token_limit or unlimited))
        self._global_requests = Budget(limit=float(self.config.total_request_limit or unlimited))
        self._global_time = Budget(limit=self.config.total_time_limit_seconds or unlimited)
        self._node_budgets: dict[str, NodeBudget] = {}
        self._start_time: datetime | None = None

    def start(self) -> None:
        """Record the start time for time-budget tracking."""
        self._start_time = datetime.now(UTC)

    def get_elapsed_seconds(self) -> float:
        """Return seconds elapsed since ``start()``; 0.0 when not started."""
        if self._start_time is None:
            return 0.0
        return (datetime.now(UTC) - self._start_time).total_seconds()

    def get_node_budget(self, node_id: str) -> NodeBudget:
        """Return the budget for ``node_id``, creating it from the config on first use."""
        if node_id not in self._node_budgets:
            cfg = self.config
            # A missing or zero per-node limit means that resource is not
            # tracked for the node at all (the Budget is left as None).
            self._node_budgets[node_id] = NodeBudget(
                node_id=node_id,
                tokens=Budget(limit=float(cfg.node_token_limit))
                if cfg.node_token_limit
                else None,
                requests=Budget(limit=float(cfg.node_request_limit))
                if cfg.node_request_limit
                else None,
                time_seconds=Budget(limit=cfg.node_time_limit_seconds)
                if cfg.node_time_limit_seconds
                else None,
                max_prompt_length=cfg.max_prompt_length,
                max_response_length=cfg.max_response_length,
            )
        return self._node_budgets[node_id]

    def can_execute(
        self,
        node_id: str,
        estimated_tokens: int = 0,
    ) -> tuple[bool, str | None]:
        """Check whether a step may run under both global and node-level limits.

        Args:
            node_id: Node about to execute.
            estimated_tokens: Expected token cost of the step.

        Returns:
            ``(True, None)`` when allowed, otherwise ``(False, reason)``.
        """
        # Compare wall-clock time with the configured limit directly. Nothing
        # in this class ever spends ``_global_time``, so gating this check on
        # that budget being exhausted would leave the time limit unenforced.
        time_limit = self.config.total_time_limit_seconds
        if time_limit:
            elapsed = self.get_elapsed_seconds()
            if elapsed >= time_limit:
                return False, f"Time budget exhausted: {elapsed:.1f}s"

        tokens = self._global_tokens
        # A zero estimate always "fits", so also test the stored fields
        # directly; otherwise a fully spent budget would still pass.
        tokens_spent = tokens.used + tokens.reserved >= tokens.limit
        if tokens_spent or not tokens.can_spend(estimated_tokens):
            return (
                False,
                f"Global token budget exhausted: {self._global_tokens.used}/{self._global_tokens.limit}",
            )

        if not self._global_requests.can_spend(1):
            return (
                False,
                f"Global request budget exhausted: {self._global_requests.used}/{self._global_requests.limit}",
            )

        allowed, reason = self.get_node_budget(node_id).can_execute(estimated_tokens)
        if not allowed:
            return False, reason
        return True, None

    def record_usage(
        self,
        node_id: str,
        prompt_tokens: int = 0,
        completion_tokens: int = 0,
        latency_seconds: float = 0.0,
    ) -> None:
        """Record actual consumption for a node and update the global counters.

        Global amounts are added to ``used`` unconditionally. Going through
        ``Budget.spend`` would discard any amount larger than what is left,
        so a call that overshoots the limit would never be counted.

        Args:
            node_id: Node that executed.
            prompt_tokens: Tokens sent in the request.
            completion_tokens: Tokens received in the response.
            latency_seconds: Duration of the call, passed to the node budget.
        """
        total_tokens = prompt_tokens + completion_tokens
        self._global_tokens.used += total_tokens
        self._global_requests.used += 1

        node_budget = self.get_node_budget(node_id)
        node_budget.record_usage(tokens=total_tokens, time_seconds=latency_seconds)

        self._check_warnings()

    def truncate_prompt(self, prompt: str) -> str:
        """Cut ``prompt`` to ``max_prompt_length`` characters and append a marker.

        The prompt is returned unchanged when no limit is set or it already fits.
        """
        max_len = self.config.max_prompt_length
        if max_len and len(prompt) > max_len:
            return prompt[:max_len] + "\n[TRUNCATED]"
        return prompt

    def truncate_response(self, response: str) -> str:
        """Cut ``response`` to ``max_response_length`` characters and append a marker.

        The response is returned unchanged when no limit is set or it already fits.
        """
        max_len = self.config.max_response_length
        if max_len and len(response) > max_len:
            return response[:max_len] + "\n[TRUNCATED]"
        return response

    def _check_warnings(self) -> None:
        """Invoke ``on_budget_warning`` for each global budget at or above the warn ratio."""
        notify = self.config.on_budget_warning
        if not notify:
            return
        threshold = self.config.warn_at_usage_ratio
        if self._global_tokens.usage_ratio >= threshold:
            notify("tokens", self._global_tokens)
        if self._global_requests.usage_ratio >= threshold:
            notify("requests", self._global_requests)

    # NOTE(review): the three accessors below are plain methods as written;
    # confirm whether callers expect them to be properties.
    def global_tokens(self) -> Budget:
        """Return the global token budget."""
        return self._global_tokens

    def global_requests(self) -> Budget:
        """Return the global request budget."""
        return self._global_requests

    def global_time(self) -> Budget:
        """Return the global time budget."""
        return self._global_time

    def get_summary(self) -> dict[str, Any]:
        """Return a summary of global and per-node budget usage."""
        return {
            "global": {
                "tokens": self._global_tokens.to_dict(),
                "requests": self._global_requests.to_dict(),
                "time": self._global_time.to_dict(),
                "elapsed_seconds": self.get_elapsed_seconds(),
            },
            "nodes": {node_id: budget.to_dict() for node_id, budget in self._node_budgets.items()},
        }

    def reset(self) -> None:
        """Reset all global budgets, drop every node budget, and clear the start time."""
        self._global_tokens.reset()
        self._global_requests.reset()
        self._global_time.reset()
        self._node_budgets.clear()
        self._start_time = None