""" Management of token, request, and time budgets. Provides cost control at the graph level and at individual node level. """ from collections.abc import Callable from datetime import UTC, datetime from typing import Any from pydantic import BaseModel, ConfigDict __all__ = [ "Budget", "BudgetConfig", "BudgetTracker", "NodeBudget", ] class Budget(BaseModel): """Tracks a resource limit accounting for used and reserved amounts.""" limit: float used: float = 0.0 reserved: float = 0.0 @property def available(self) -> float: """Remaining available resource (excluding reserved amount).""" return max(0.0, self.limit - self.used - self.reserved) @property def remaining(self) -> float: """Remaining resource ignoring the reserve (limit - used).""" return max(0.0, self.limit - self.used) @property def usage_ratio(self) -> float: """Fraction of the resource limit that has been consumed.""" if self.limit <= 0: return 0.0 return self.used / self.limit @property def is_exhausted(self) -> bool: """True if no available resource remains.""" return self.available <= 0 def can_spend(self, amount: float) -> bool: """Check whether the available resource is sufficient for the given amount.""" return self.available >= amount def spend(self, amount: float) -> bool: """Consume the resource if available; return True on success.""" if not self.can_spend(amount): return False self.used += amount return True def reserve(self, amount: float) -> bool: """Reserve resource for a future operation.""" if self.available < amount: return False self.reserved += amount return True def release_reservation(self, amount: float) -> None: """Release a portion of the reservation.""" self.reserved = max(0.0, self.reserved - amount) def commit_reservation(self, amount: float) -> None: """Move up to the given amount from reservation into usage.""" actual = min(amount, self.reserved) self.reserved -= actual self.used += actual def reset(self) -> None: """Reset used and reserved amounts to zero.""" self.used = 0.0 self.reserved = 0.0 def to_dict(self) -> dict[str, Any]: """Serialize the budget to a dictionary.""" return { "limit": self.limit, "used": self.used, "reserved": self.reserved, "available": self.available, "usage_ratio": self.usage_ratio, } class NodeBudget(BaseModel): """Per-node limits for tokens, requests, time, and message lengths.""" node_id: str tokens: Budget | None = None requests: Budget | None = None time_seconds: Budget | None = None max_prompt_length: int | None = None max_response_length: int | None = None def can_execute(self, estimated_tokens: int = 0) -> tuple[bool, str | None]: """Check whether a step can be executed given the estimated token count.""" if self.tokens and not self.tokens.can_spend(estimated_tokens): return False, f"Token budget exhausted for node {self.node_id}" if self.requests and not self.requests.can_spend(1): return False, f"Request budget exhausted for node {self.node_id}" return True, None def record_usage( self, tokens: int = 0, time_seconds: float = 0.0, ) -> None: """Record actual resource consumption for the node.""" if self.tokens: self.tokens.spend(tokens) if self.requests: self.requests.spend(1) if self.time_seconds: self.time_seconds.spend(time_seconds) def to_dict(self) -> dict[str, Any]: """Serialize the node budget to a dictionary.""" return { "node_id": self.node_id, "tokens": self.tokens.to_dict() if self.tokens else None, "requests": self.requests.to_dict() if self.requests else None, "time_seconds": self.time_seconds.to_dict() if self.time_seconds else None, "limits": { "max_prompt_length": self.max_prompt_length, "max_response_length": self.max_response_length, }, } class BudgetConfig(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) """Configuration for global and per-component execution limits.""" total_token_limit: int | None = None total_request_limit: int | None = None total_time_limit_seconds: float | None = None node_token_limit: int | None = None node_request_limit: int | None = None node_time_limit_seconds: float | None = None max_prompt_length: int | None = None max_response_length: int | None = None warn_at_usage_ratio: float = 0.8 on_budget_warning: Callable[[str, Budget], None] | None = None on_budget_exceeded: Callable[[str, Budget], None] | None = None class BudgetTracker: """Tracks global and per-node budgets and issues warnings when thresholds are approached.""" def __init__(self, config: BudgetConfig | None = None): self.config = config or BudgetConfig() self._global_tokens = Budget(limit=float(self.config.total_token_limit or float("inf"))) self._global_requests = Budget(limit=float(self.config.total_request_limit or float("inf"))) self._global_time = Budget(limit=self.config.total_time_limit_seconds or float("inf")) self._node_budgets: dict[str, NodeBudget] = {} self._start_time: datetime | None = None def start(self) -> None: """Record the start time for time-budget tracking.""" self._start_time = datetime.now(UTC) def get_elapsed_seconds(self) -> float: """Return elapsed seconds since start() was called.""" if self._start_time is None: return 0.0 return (datetime.now(UTC) - self._start_time).total_seconds() def get_node_budget(self, node_id: str) -> NodeBudget: """Return (or create) the budget for the given node.""" if node_id not in self._node_budgets: self._node_budgets[node_id] = NodeBudget( node_id=node_id, tokens=Budget(limit=float(self.config.node_token_limit or float("inf"))) if self.config.node_token_limit else None, requests=Budget(limit=float(self.config.node_request_limit or float("inf"))) if self.config.node_request_limit else None, time_seconds=Budget(limit=self.config.node_time_limit_seconds or float("inf")) if self.config.node_time_limit_seconds else None, max_prompt_length=self.config.max_prompt_length, max_response_length=self.config.max_response_length, ) return self._node_budgets[node_id] def can_execute( self, node_id: str, estimated_tokens: int = 0, ) -> tuple[bool, str | None]: """Check whether a step can be executed considering both global and node-level limits.""" if self._global_time.is_exhausted: elapsed = self.get_elapsed_seconds() time_limit = self.config.total_time_limit_seconds if time_limit is not None and elapsed >= time_limit: return False, f"Time budget exhausted: {elapsed:.1f}s" if not self._global_tokens.can_spend(estimated_tokens): return ( False, f"Global token budget exhausted: {self._global_tokens.used}/{self._global_tokens.limit}", ) if not self._global_requests.can_spend(1): return ( False, f"Global request budget exhausted: {self._global_requests.used}/{self._global_requests.limit}", ) node_budget = self.get_node_budget(node_id) can, reason = node_budget.can_execute(estimated_tokens) if not can: return False, reason return True, None def record_usage( self, node_id: str, prompt_tokens: int = 0, completion_tokens: int = 0, latency_seconds: float = 0.0, ) -> None: """Record actual consumption for a node and update global counters.""" total_tokens = prompt_tokens + completion_tokens self._global_tokens.spend(total_tokens) self._global_requests.spend(1) node_budget = self.get_node_budget(node_id) node_budget.record_usage(tokens=total_tokens, time_seconds=latency_seconds) self._check_warnings() def truncate_prompt(self, prompt: str) -> str: """Truncate prompt to the configured limit and append a truncation marker.""" if self.config.max_prompt_length and len(prompt) > self.config.max_prompt_length: return prompt[: self.config.max_prompt_length] + "\n[TRUNCATED]" return prompt def truncate_response(self, response: str) -> str: """Truncate response to the configured limit and append a truncation marker.""" if self.config.max_response_length and len(response) > self.config.max_response_length: return response[: self.config.max_response_length] + "\n[TRUNCATED]" return response def _check_warnings(self) -> None: """Invoke warning callbacks if the warn_at_usage_ratio threshold has been reached.""" if self.config.on_budget_warning: if self._global_tokens.usage_ratio >= self.config.warn_at_usage_ratio: self.config.on_budget_warning("tokens", self._global_tokens) if self._global_requests.usage_ratio >= self.config.warn_at_usage_ratio: self.config.on_budget_warning("requests", self._global_requests) @property def global_tokens(self) -> Budget: return self._global_tokens @property def global_requests(self) -> Budget: return self._global_requests @property def global_time(self) -> Budget: return self._global_time def get_summary(self) -> dict[str, Any]: """Return a summary of global and per-node budget usage.""" return { "global": { "tokens": self._global_tokens.to_dict(), "requests": self._global_requests.to_dict(), "time": self._global_time.to_dict(), "elapsed_seconds": self.get_elapsed_seconds(), }, "nodes": {node_id: budget.to_dict() for node_id, budget in self._node_budgets.items()}, } def reset(self) -> None: """Reset all budgets and the start time.""" self._global_tokens.reset() self._global_requests.reset() self._global_time.reset() self._node_budgets.clear() self._start_time = None