File size: 11,029 Bytes
3193174
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
"""
Management of token, request, and time budgets.

Provides cost control at the graph level and at individual node level.
"""

from collections.abc import Callable
from datetime import UTC, datetime
from typing import Any

from pydantic import BaseModel, ConfigDict

# Public API of this module: the names exported by ``from <module> import *``.
__all__ = [
    "Budget",
    "BudgetConfig",
    "BudgetTracker",
    "NodeBudget",
]


class Budget(BaseModel):
    """A single capped resource with separate consumed and reserved counters.

    ``limit`` is the ceiling, ``used`` is what has been consumed so far and
    ``reserved`` is what has been set aside for pending work. The derived
    figures (``available``, ``remaining``, ``usage_ratio``, ``is_exhausted``)
    are read-only properties computed from these three fields.
    """

    limit: float
    used: float = 0.0
    reserved: float = 0.0

    @property
    def available(self) -> float:
        """Amount still free to spend or reserve; reservations count against it. Never negative."""
        headroom = self.limit - self.used - self.reserved
        return headroom if headroom > 0.0 else 0.0

    @property
    def remaining(self) -> float:
        """Amount not yet consumed (``limit - used``), ignoring reservations. Never negative."""
        unspent = self.limit - self.used
        return unspent if unspent > 0.0 else 0.0

    @property
    def usage_ratio(self) -> float:
        """``used / limit``, or 0.0 when the limit is zero or negative."""
        return 0.0 if self.limit <= 0 else self.used / self.limit

    @property
    def is_exhausted(self) -> bool:
        """Whether nothing is left to spend or reserve."""
        return self.available <= 0

    def can_spend(self, amount: float) -> bool:
        """Return True if ``amount`` fits within :attr:`available`."""
        return amount <= self.available

    def spend(self, amount: float) -> bool:
        """Add ``amount`` to ``used`` if it fits and report whether it was accepted.

        A rejected amount leaves the budget untouched.
        """
        if self.can_spend(amount):
            self.used += amount
            return True
        return False

    def reserve(self, amount: float) -> bool:
        """Set ``amount`` aside for a pending operation; return False if it does not fit."""
        if amount > self.available:
            return False
        self.reserved += amount
        return True

    def release_reservation(self, amount: float) -> None:
        """Give back up to ``amount`` of the reservation without consuming it."""
        leftover = self.reserved - amount
        self.reserved = leftover if leftover > 0.0 else 0.0

    def commit_reservation(self, amount: float) -> None:
        """Convert up to ``amount`` of the reservation into consumed usage.

        Any part of ``amount`` beyond the current reservation is ignored.
        """
        moved = self.reserved if self.reserved < amount else amount
        self.reserved -= moved
        self.used += moved

    def reset(self) -> None:
        """Clear the consumed and reserved amounts; the limit is kept."""
        self.used = 0.0
        self.reserved = 0.0

    def to_dict(self) -> dict[str, Any]:
        """Return the raw fields plus the derived ``available`` and ``usage_ratio``."""
        return dict(
            limit=self.limit,
            used=self.used,
            reserved=self.reserved,
            available=self.available,
            usage_ratio=self.usage_ratio,
        )


class NodeBudget(BaseModel):
    """Per-node limits for tokens, requests, time, and message lengths.

    Each resource budget is optional; ``None`` means the node is not limited
    on that resource. ``max_prompt_length`` / ``max_response_length`` are
    stored and reported by :meth:`to_dict`; this class does not enforce them.
    """

    node_id: str
    tokens: Budget | None = None
    requests: Budget | None = None
    time_seconds: Budget | None = None
    max_prompt_length: int | None = None
    max_response_length: int | None = None

    def can_execute(self, estimated_tokens: int = 0) -> tuple[bool, str | None]:
        """Check whether another step may run on this node.

        Args:
            estimated_tokens: Expected token cost of the step.

        Returns:
            ``(True, None)`` if the step fits every configured budget,
            otherwise ``(False, reason)`` for the first budget that blocks it.
        """
        # ``can_spend(0)`` is True even when nothing is left, so an exhausted
        # budget must be checked explicitly for the default zero estimate.
        if self.tokens is not None and (
            self.tokens.is_exhausted or not self.tokens.can_spend(estimated_tokens)
        ):
            return False, f"Token budget exhausted for node {self.node_id}"

        if self.requests is not None and not self.requests.can_spend(1):
            return False, f"Request budget exhausted for node {self.node_id}"

        # Time cost is only known after a step runs, so the node is refused
        # once the recorded time has used up its budget.
        if self.time_seconds is not None and self.time_seconds.is_exhausted:
            return False, f"Time budget exhausted for node {self.node_id}"

        return True, None

    def record_usage(
        self,
        tokens: int = 0,
        time_seconds: float = 0.0,
    ) -> None:
        """Record resource consumption that has already happened on this node.

        Usage is added unconditionally. ``Budget.spend`` refuses amounts larger
        than what is available and would silently drop an overshoot, leaving
        the budget looking un-exhausted even though the work was done.
        Each call counts as one request.
        """
        if self.tokens is not None:
            self.tokens.used += tokens
        if self.requests is not None:
            self.requests.used += 1
        if self.time_seconds is not None:
            self.time_seconds.used += time_seconds

    def to_dict(self) -> dict[str, Any]:
        """Serialize the node budget to a dictionary (``None`` for unset budgets)."""
        return {
            "node_id": self.node_id,
            "tokens": self.tokens.to_dict() if self.tokens is not None else None,
            "requests": self.requests.to_dict() if self.requests is not None else None,
            "time_seconds": self.time_seconds.to_dict() if self.time_seconds is not None else None,
            "limits": {
                "max_prompt_length": self.max_prompt_length,
                "max_response_length": self.max_response_length,
            },
        }


class BudgetConfig(BaseModel):
    """Configuration for global and per-component execution limits.

    A limit left as ``None`` imposes no cap on that resource.
    """

    # The class docstring must come before this assignment; a string placed
    # after it is a no-op expression statement and never becomes ``__doc__``.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Limits shared by all nodes (global).
    total_token_limit: int | None = None
    total_request_limit: int | None = None
    total_time_limit_seconds: float | None = None

    # Limits applied to each node individually.
    node_token_limit: int | None = None
    node_request_limit: int | None = None
    node_time_limit_seconds: float | None = None

    # Maximum prompt / response lengths, compared against len() of the text.
    max_prompt_length: int | None = None
    max_response_length: int | None = None

    # Usage ratio (used / limit) at or above which ``on_budget_warning`` is invoked.
    warn_at_usage_ratio: float = 0.8

    # Optional hooks; each receives a resource name and the affected Budget.
    on_budget_warning: Callable[[str, Budget], None] | None = None
    on_budget_exceeded: Callable[[str, Budget], None] | None = None


class BudgetTracker:
    """Tracks global and per-node budgets and issues warnings when thresholds are approached.

    Token and request usage is accumulated both in global budgets shared by
    all nodes and in per-node budgets created lazily from the config. The
    global time budget is derived from wall-clock time elapsed since
    :meth:`start` was called.
    """

    def __init__(self, config: BudgetConfig | None = None):
        self.config = config or BudgetConfig()

        # An unset (or zero) limit becomes an infinite budget.
        self._global_tokens = Budget(limit=float(self.config.total_token_limit or float("inf")))
        self._global_requests = Budget(limit=float(self.config.total_request_limit or float("inf")))
        self._global_time = Budget(limit=self.config.total_time_limit_seconds or float("inf"))

        self._node_budgets: dict[str, NodeBudget] = {}
        self._start_time: datetime | None = None

    def start(self) -> None:
        """Record the start time for time-budget tracking."""
        self._start_time = datetime.now(UTC)

    def get_elapsed_seconds(self) -> float:
        """Return elapsed seconds since start() was called, or 0.0 if it was not called."""
        if self._start_time is None:
            return 0.0
        return (datetime.now(UTC) - self._start_time).total_seconds()

    def _sync_time_budget(self) -> float:
        """Copy the elapsed wall-clock time into the global time budget and return it.

        Nothing spends from the time budget directly, so without this its
        ``used`` would stay at 0 and the global time limit could never trigger.
        """
        elapsed = self.get_elapsed_seconds()
        self._global_time.used = elapsed
        return elapsed

    def get_node_budget(self, node_id: str) -> NodeBudget:
        """Return the budget for ``node_id``, creating it from the config on first use.

        A per-node resource whose configured limit is unset (or zero) gets no
        budget (``None``) and is therefore not limited at node level.
        """
        budget = self._node_budgets.get(node_id)
        if budget is None:
            cfg = self.config
            budget = NodeBudget(
                node_id=node_id,
                tokens=Budget(limit=float(cfg.node_token_limit)) if cfg.node_token_limit else None,
                requests=(
                    Budget(limit=float(cfg.node_request_limit)) if cfg.node_request_limit else None
                ),
                time_seconds=(
                    Budget(limit=cfg.node_time_limit_seconds) if cfg.node_time_limit_seconds else None
                ),
                max_prompt_length=cfg.max_prompt_length,
                max_response_length=cfg.max_response_length,
            )
            self._node_budgets[node_id] = budget
        return budget

    def can_execute(
        self,
        node_id: str,
        estimated_tokens: int = 0,
    ) -> tuple[bool, str | None]:
        """Check whether a step can be executed considering both global and node-level limits.

        Args:
            node_id: Node about to run; its budget is created if missing.
            estimated_tokens: Expected token cost of the step.

        Returns:
            ``(True, None)`` if every check passes, otherwise ``(False, reason)``
            for the first limit that blocks the step.
        """
        elapsed = self._sync_time_budget()
        if self._global_time.is_exhausted:
            return False, f"Time budget exhausted: {elapsed:.1f}s"

        token_budget = self._global_tokens
        # ``can_spend(0)`` is True even on an exhausted budget, so exhaustion
        # is checked explicitly for the default zero estimate.
        if token_budget.is_exhausted or not token_budget.can_spend(estimated_tokens):
            return (
                False,
                f"Global token budget exhausted: {token_budget.used}/{token_budget.limit}",
            )

        request_budget = self._global_requests
        if not request_budget.can_spend(1):
            return (
                False,
                f"Global request budget exhausted: {request_budget.used}/{request_budget.limit}",
            )

        return self.get_node_budget(node_id).can_execute(estimated_tokens)

    def record_usage(
        self,
        node_id: str,
        prompt_tokens: int = 0,
        completion_tokens: int = 0,
        latency_seconds: float = 0.0,
    ) -> None:
        """Record actual consumption for a node and update global counters.

        Usage is added to the global budgets unconditionally: the work has
        already happened, and ``Budget.spend`` would silently drop an amount
        larger than what is left, hiding an overshoot. Warning callbacks are
        evaluated afterwards.
        """
        total_tokens = prompt_tokens + completion_tokens

        self._global_tokens.used += total_tokens
        self._global_requests.used += 1
        self._sync_time_budget()

        node_budget = self.get_node_budget(node_id)
        node_budget.record_usage(tokens=total_tokens, time_seconds=latency_seconds)

        self._check_warnings()

    def truncate_prompt(self, prompt: str) -> str:
        """Truncate prompt to the configured limit and append a truncation marker.

        The result may exceed the limit by the length of the marker.
        """
        if self.config.max_prompt_length and len(prompt) > self.config.max_prompt_length:
            return prompt[: self.config.max_prompt_length] + "\n[TRUNCATED]"
        return prompt

    def truncate_response(self, response: str) -> str:
        """Truncate response to the configured limit and append a truncation marker.

        The result may exceed the limit by the length of the marker.
        """
        if self.config.max_response_length and len(response) > self.config.max_response_length:
            return response[: self.config.max_response_length] + "\n[TRUNCATED]"
        return response

    def _check_warnings(self) -> None:
        """Invoke the configured callbacks for the global token and request budgets.

        ``on_budget_warning`` fires while usage is at or above
        ``warn_at_usage_ratio``; ``on_budget_exceeded`` fires while usage has
        reached the limit. Both fire on every call while their condition holds.
        """
        warn_cb = self.config.on_budget_warning
        exceeded_cb = self.config.on_budget_exceeded
        for name, budget in (("tokens", self._global_tokens), ("requests", self._global_requests)):
            ratio = budget.usage_ratio
            if warn_cb and ratio >= self.config.warn_at_usage_ratio:
                warn_cb(name, budget)
            if exceeded_cb and ratio >= 1.0:
                exceeded_cb(name, budget)

    @property
    def global_tokens(self) -> Budget:
        """Global token budget shared by all nodes."""
        return self._global_tokens

    @property
    def global_requests(self) -> Budget:
        """Global request budget shared by all nodes."""
        return self._global_requests

    @property
    def global_time(self) -> Budget:
        """Global time budget, refreshed with the current elapsed time."""
        self._sync_time_budget()
        return self._global_time

    def get_summary(self) -> dict[str, Any]:
        """Return a summary of global and per-node budget usage."""
        elapsed = self._sync_time_budget()
        return {
            "global": {
                "tokens": self._global_tokens.to_dict(),
                "requests": self._global_requests.to_dict(),
                "time": self._global_time.to_dict(),
                "elapsed_seconds": elapsed,
            },
            "nodes": {node_id: budget.to_dict() for node_id, budget in self._node_budgets.items()},
        }

    def reset(self) -> None:
        """Reset all budgets, forget per-node budgets, and clear the start time."""
        self._global_tokens.reset()
        self._global_requests.reset()
        self._global_time.reset()
        self._node_budgets.clear()
        self._start_time = None