lewtun HF Staff OpenAI Codex committed on
Commit
a8e0e2c
·
2 Parent(s): 1c6871277324b8

Deploy 2026-05-01

Browse files

Co-authored-by: OpenAI Codex <codex@openai.com>

agent/core/agent_loop.py CHANGED
@@ -19,6 +19,11 @@ from litellm import (
19
  from litellm.exceptions import ContextWindowExceededError
20
 
21
  from agent.config import Config
 
 
 
 
 
22
  from agent.messaging.gateway import NotificationGateway
23
  from agent.core import telemetry
24
  from agent.core.doom_loop import check_for_doom_loop
@@ -110,13 +115,39 @@ def _validate_tool_args(tool_args: dict) -> tuple[bool, str | None]:
110
  return True, None
111
 
112
 
113
- def _needs_approval(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  tool_name: str, tool_args: dict, config: Config | None = None
115
  ) -> bool:
116
- """Check if a tool call requires user approval before execution."""
117
- # Yolo mode: skip all approvals
118
- if config and config.yolo_mode:
119
- return False
120
 
121
  # If args are malformed, skip approval (validation error will be shown later)
122
  args_valid, _ = _validate_tool_args(tool_args)
@@ -127,8 +158,10 @@ def _needs_approval(
127
  return True
128
 
129
  if tool_name == "hf_jobs":
130
- operation = tool_args.get("operation", "")
131
- if operation not in ["run", "uv", "scheduled run", "scheduled uv"]:
 
 
132
  return False
133
 
134
  # Check if this is a CPU-only job
@@ -180,6 +213,143 @@ def _needs_approval(
180
  return False
181
 
182
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  # -- LLM retry constants --------------------------------------------------
184
  _MAX_LLM_RETRIES = 3
185
  _LLM_RETRY_DELAYS = [5, 15, 30] # seconds between retries
@@ -1063,29 +1233,49 @@ class Handlers:
1063
  if session.is_cancelled:
1064
  break
1065
 
1066
- # Separate good tools into approval-required vs auto-execute
1067
- approval_required_tools: list[tuple[ToolCall, str, dict]] = []
1068
- non_approval_tools: list[tuple[ToolCall, str, dict]] = []
 
 
 
 
 
 
 
 
1069
  for tc, tool_name, tool_args in good_tools:
1070
- if _needs_approval(tool_name, tool_args, session.config):
1071
- approval_required_tools.append((tc, tool_name, tool_args))
 
 
 
 
 
 
1072
  else:
1073
- non_approval_tools.append((tc, tool_name, tool_args))
 
 
 
 
 
 
1074
 
1075
  # Execute non-approval tools (in parallel when possible)
1076
  if non_approval_tools:
1077
  # 1. Validate args upfront
1078
  parsed_tools: list[
1079
- tuple[ToolCall, str, dict, bool, str]
1080
  ] = []
1081
- for tc, tool_name, tool_args in non_approval_tools:
1082
  args_valid, error_msg = _validate_tool_args(tool_args)
1083
  parsed_tools.append(
1084
- (tc, tool_name, tool_args, args_valid, error_msg)
1085
  )
1086
 
1087
  # 2. Send all tool_call events upfront (so frontend shows them all)
1088
- for tc, tool_name, tool_args, args_valid, _ in parsed_tools:
1089
  if args_valid:
1090
  await session.send_event(
1091
  Event(
@@ -1103,11 +1293,14 @@ class Handlers:
1103
  tc: ToolCall,
1104
  name: str,
1105
  args: dict,
 
1106
  valid: bool,
1107
  err: str,
1108
  ) -> tuple[ToolCall, str, dict, str, bool]:
1109
  if not valid:
1110
  return (tc, name, args, err, False)
 
 
1111
  out, ok = await session.tool_router.call_tool(
1112
  name, args, session=session, tool_call_id=tc.id
1113
  )
@@ -1115,8 +1308,8 @@ class Handlers:
1115
 
1116
  gather_task = asyncio.ensure_future(asyncio.gather(
1117
  *[
1118
- _exec_tool(tc, name, args, valid, err)
1119
- for tc, name, args, valid, err in parsed_tools
1120
  ]
1121
  ))
1122
  cancel_task = asyncio.ensure_future(session._cancelled.wait())
@@ -1133,7 +1326,7 @@ class Handlers:
1133
  except asyncio.CancelledError:
1134
  pass
1135
  # Notify frontend that in-flight tools were cancelled
1136
- for tc, name, _args, valid, _ in parsed_tools:
1137
  if valid:
1138
  await session.send_event(Event(
1139
  event_type="tool_state_change",
@@ -1171,7 +1364,8 @@ class Handlers:
1171
  if approval_required_tools:
1172
  # Prepare batch approval data
1173
  tools_data = []
1174
- for tc, tool_name, tool_args in approval_required_tools:
 
1175
  # Resolve sandbox file paths for hf_jobs scripts so the
1176
  # frontend can display & edit the actual file content.
1177
  if tool_name == "hf_jobs" and isinstance(tool_args.get("script"), str):
@@ -1181,20 +1375,42 @@ class Handlers:
1181
  if resolved:
1182
  tool_args = {**tool_args, "script": resolved}
1183
 
1184
- tools_data.append({
1185
  "tool": tool_name,
1186
  "arguments": tool_args,
1187
  "tool_call_id": tc.id,
1188
- })
1189
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1190
  await session.send_event(Event(
1191
  event_type="approval_required",
1192
- data={"tools": tools_data, "count": len(tools_data)},
1193
  ))
1194
 
1195
  # Store all approval-requiring tools (ToolCall objects for execution)
1196
  session.pending_approval = {
1197
- "tool_calls": [tc for tc, _, _ in approval_required_tools],
1198
  }
1199
 
1200
  # Return early - wait for EXEC_APPROVAL operation
@@ -1384,6 +1600,8 @@ class Handlers:
1384
  )
1385
  )
1386
 
 
 
1387
  output, success = await session.tool_router.call_tool(
1388
  tool_name, tool_args, session=session, tool_call_id=tc.id
1389
  )
 
19
  from litellm.exceptions import ContextWindowExceededError
20
 
21
  from agent.config import Config
22
+ from agent.core.approval_policy import (
23
+ is_scheduled_operation,
24
+ normalize_tool_operation,
25
+ )
26
+ from agent.core.cost_estimation import CostEstimate, estimate_tool_cost
27
  from agent.messaging.gateway import NotificationGateway
28
  from agent.core import telemetry
29
  from agent.core.doom_loop import check_for_doom_loop
 
115
  return True, None
116
 
117
 
118
+ _IMMEDIATE_HF_JOB_RUNS = {"run", "uv"}
119
+
120
+ @dataclass(frozen=True)
121
+ class ApprovalDecision:
122
+ requires_approval: bool
123
+ auto_approved: bool = False
124
+ auto_approval_blocked: bool = False
125
+ block_reason: str | None = None
126
+ estimated_cost_usd: float | None = None
127
+ remaining_cap_usd: float | None = None
128
+ billable: bool = False
129
+
130
+
131
def _operation(tool_args: dict) -> str:
    """Return the normalized ``operation`` argument of a tool call."""
    raw_operation = tool_args.get("operation")
    return normalize_tool_operation(raw_operation)
133
+
134
+
135
def _is_immediate_hf_job_run(tool_name: str, tool_args: dict) -> bool:
    """Tell whether the call is an ``hf_jobs`` operation listed as an immediate run."""
    if tool_name != "hf_jobs":
        return False
    return _operation(tool_args) in _IMMEDIATE_HF_JOB_RUNS
137
+
138
+
139
def _is_scheduled_hf_job_run(tool_name: str, tool_args: dict) -> bool:
    """Tell whether the call is an ``hf_jobs`` operation of the scheduled kind."""
    if tool_name != "hf_jobs":
        return False
    return is_scheduled_operation(_operation(tool_args))
141
+
142
+
143
def _is_budgeted_auto_approval_target(tool_name: str, tool_args: dict) -> bool:
    """Tell whether the call is subject to the auto-approval cost budget.

    Budgeted targets are ``sandbox_create`` and immediate ``hf_jobs`` runs.
    """
    if tool_name == "sandbox_create":
        return True
    return _is_immediate_hf_job_run(tool_name, tool_args)
145
+
146
+
147
+ def _base_needs_approval(
148
  tool_name: str, tool_args: dict, config: Config | None = None
149
  ) -> bool:
150
+ """Check if a tool call requires approval before YOLO policy is applied."""
 
 
 
151
 
152
  # If args are malformed, skip approval (validation error will be shown later)
153
  args_valid, _ = _validate_tool_args(tool_args)
 
158
  return True
159
 
160
  if tool_name == "hf_jobs":
161
+ operation = _operation(tool_args)
162
+ if is_scheduled_operation(operation):
163
+ return True
164
+ if operation not in _IMMEDIATE_HF_JOB_RUNS:
165
  return False
166
 
167
  # Check if this is a CPU-only job
 
213
  return False
214
 
215
 
216
def _needs_approval(
    tool_name: str, tool_args: dict, config: Config | None = None
) -> bool:
    """Legacy sync approval predicate used by tests and CLI display helpers.

    Scheduled ``hf_jobs`` operations always need approval, even under
    ``config.yolo_mode``. Otherwise YOLO mode waives approval, and without it
    the base policy decides.
    """
    if not _is_scheduled_hf_job_run(tool_name, tool_args):
        if config and config.yolo_mode:
            return False
        return _base_needs_approval(tool_name, tool_args, config)
    return True
225
+
226
+
227
def _session_auto_approval_enabled(session: Session | None) -> bool:
    """Return True when the session exists and has auto-approval switched on."""
    if not session:
        return False
    # getattr with a default tolerates session objects without the attribute.
    return bool(getattr(session, "auto_approval_enabled", False))
229
+
230
+
231
def _effective_yolo_enabled(session: Session | None, config: Config | None) -> bool:
    """Return True when either ``config.yolo_mode`` or session auto-approval is on."""
    if config and config.yolo_mode:
        return True
    return _session_auto_approval_enabled(session)
233
+
234
+
235
def _remaining_budget_after_reservations(
    session: Session | None, reserved_spend_usd: float
) -> float | None:
    """Compute the budget left under the session cap, net of batch reservations.

    Returns:
        ``None`` when there is no session or no cap configured; otherwise the
        cap minus recorded spend minus ``reserved_spend_usd``, floored at 0.0
        and rounded to 4 decimals.
    """
    configured_cap = (
        getattr(session, "auto_approval_cost_cap_usd", None) if session else None
    )
    if configured_cap is None:
        return None
    cap = float(configured_cap or 0.0)
    already_spent = float(
        getattr(session, "auto_approval_estimated_spend_usd", 0.0) or 0.0
    )
    headroom = cap - already_spent - reserved_spend_usd
    return round(max(0.0, headroom), 4)
243
+
244
+
245
def _budget_block_reason(
    estimate: CostEstimate,
    *,
    remaining_cap_usd: float | None,
) -> str | None:
    """Explain why an estimate may not be auto-approved, or return ``None``.

    A missing estimate always blocks. A present estimate blocks only when a cap
    is in force (``remaining_cap_usd`` is not ``None``) and the estimate exceeds it.
    """
    cost = estimate.estimated_cost_usd
    if cost is None:
        return estimate.block_reason or "Could not estimate the cost safely."

    over_cap = remaining_cap_usd is not None and cost > remaining_cap_usd
    if not over_cap:
        return None
    return (
        f"Estimated cost ${cost:.2f} exceeds "
        f"remaining YOLO cap ${remaining_cap_usd:.2f}."
    )
258
+
259
+
260
async def _approval_decision(
    tool_name: str,
    tool_args: dict,
    session: Session,
    *,
    reserved_spend_usd: float = 0.0,
) -> ApprovalDecision:
    """Return the approval decision for one parsed tool call.

    Policy, in order:

    1. Scheduled ``hf_jobs`` operations always require manual approval.
    2. With session auto-approval on and a budgeted target, the call is
       cost-estimated. It goes back to the human when no estimate is available
       or the estimate exceeds the remaining cap.
    3. Otherwise YOLO waives whatever approval the base policy asked for.

    Args:
        tool_name: Name of the tool being called.
        tool_args: Parsed arguments of the call.
        session: Session whose config and auto-approval state are consulted.
        reserved_spend_usd: Spend already set aside for other calls of the
            same batch; it is subtracted from the remaining cap.
    """
    config = session.config
    needs_base_approval = _base_needs_approval(tool_name, tool_args, config)

    # Scheduled jobs are recurring/unbounded enough that YOLO never bypasses
    # the human confirmation, including legacy config.yolo_mode.
    if _is_scheduled_hf_job_run(tool_name, tool_args):
        yolo_was_on = _effective_yolo_enabled(session, config)
        return ApprovalDecision(
            requires_approval=True,
            auto_approval_blocked=yolo_was_on,
            block_reason="Scheduled HF jobs always require manual approval.",
        )

    yolo_on = _effective_yolo_enabled(session, config)
    is_budgeted = _is_budgeted_auto_approval_target(tool_name, tool_args)

    # Cost caps are a session-scoped web policy. Legacy config.yolo_mode
    # remains uncapped for CLI/headless, except for scheduled jobs above.
    session_yolo_on = _session_auto_approval_enabled(session)

    if not (yolo_on and is_budgeted and session_yolo_on):
        # Uncapped path: YOLO waives a required approval, nothing else changes.
        if needs_base_approval and yolo_on:
            return ApprovalDecision(requires_approval=False, auto_approved=True)
        return ApprovalDecision(requires_approval=needs_base_approval)

    estimate = await estimate_tool_cost(tool_name, tool_args, session=session)
    remaining = _remaining_budget_after_reservations(session, reserved_spend_usd)
    cost_fields = {
        "estimated_cost_usd": estimate.estimated_cost_usd,
        "remaining_cap_usd": remaining,
        "billable": estimate.billable,
    }

    reason = _budget_block_reason(estimate, remaining_cap_usd=remaining)
    if reason:
        return ApprovalDecision(
            requires_approval=True,
            auto_approval_blocked=True,
            block_reason=reason,
            **cost_fields,
        )

    # Within budget: flagged as auto-approved only when the base policy
    # would have asked for approval in the first place.
    return ApprovalDecision(
        requires_approval=False,
        auto_approved=bool(needs_base_approval),
        **cost_fields,
    )
318
+
319
+
320
def _record_estimated_spend(session: Session, decision: ApprovalDecision) -> None:
    """Add a decision's estimated cost to the session's running spend.

    Nothing is recorded for non-billable decisions or decisions without an
    estimate. Sessions lacking ``add_auto_approval_estimated_spend`` get the
    attribute updated directly.
    """
    cost = decision.estimated_cost_usd
    if cost is None or not decision.billable:
        return

    if not hasattr(session, "add_auto_approval_estimated_spend"):
        previous = float(
            getattr(session, "auto_approval_estimated_spend_usd", 0.0) or 0.0
        )
        session.auto_approval_estimated_spend_usd = round(previous + float(cost), 4)
        return

    session.add_auto_approval_estimated_spend(cost)
331
+
332
+
333
async def _record_manual_approved_spend_if_needed(
    session: Session,
    tool_name: str,
    tool_args: dict,
) -> None:
    """Record the estimated spend of a manually approved budgeted tool call.

    Does nothing unless session auto-approval is enabled and the call is a
    budgeted target.
    """
    applies = _session_auto_approval_enabled(
        session
    ) and _is_budgeted_auto_approval_target(tool_name, tool_args)
    if not applies:
        return

    estimate = await estimate_tool_cost(tool_name, tool_args, session=session)
    spend = ApprovalDecision(
        requires_approval=False,
        billable=estimate.billable,
        estimated_cost_usd=estimate.estimated_cost_usd,
    )
    _record_estimated_spend(session, spend)
351
+
352
+
353
  # -- LLM retry constants --------------------------------------------------
354
  _MAX_LLM_RETRIES = 3
355
  _LLM_RETRY_DELAYS = [5, 15, 30] # seconds between retries
 
1233
  if session.is_cancelled:
1234
  break
1235
 
1236
+ # Separate good tools into approval-required vs auto-execute.
1237
+ # Track reserved spend while classifying a batch so two
1238
+ # auto-approved jobs in one model response cannot jointly
1239
+ # exceed the remaining session cap.
1240
+ approval_required_tools: list[
1241
+ tuple[ToolCall, str, dict, ApprovalDecision]
1242
+ ] = []
1243
+ non_approval_tools: list[
1244
+ tuple[ToolCall, str, dict, ApprovalDecision]
1245
+ ] = []
1246
+ reserved_auto_spend_usd = 0.0
1247
  for tc, tool_name, tool_args in good_tools:
1248
+ decision = await _approval_decision(
1249
+ tool_name,
1250
+ tool_args,
1251
+ session,
1252
+ reserved_spend_usd=reserved_auto_spend_usd,
1253
+ )
1254
+ if decision.requires_approval:
1255
+ approval_required_tools.append((tc, tool_name, tool_args, decision))
1256
  else:
1257
+ non_approval_tools.append((tc, tool_name, tool_args, decision))
1258
+ if (
1259
+ decision.auto_approved
1260
+ and decision.billable
1261
+ and decision.estimated_cost_usd is not None
1262
+ ):
1263
+ reserved_auto_spend_usd += decision.estimated_cost_usd
1264
 
1265
  # Execute non-approval tools (in parallel when possible)
1266
  if non_approval_tools:
1267
  # 1. Validate args upfront
1268
  parsed_tools: list[
1269
+ tuple[ToolCall, str, dict, ApprovalDecision, bool, str]
1270
  ] = []
1271
+ for tc, tool_name, tool_args, decision in non_approval_tools:
1272
  args_valid, error_msg = _validate_tool_args(tool_args)
1273
  parsed_tools.append(
1274
+ (tc, tool_name, tool_args, decision, args_valid, error_msg)
1275
  )
1276
 
1277
  # 2. Send all tool_call events upfront (so frontend shows them all)
1278
+ for tc, tool_name, tool_args, _decision, args_valid, _ in parsed_tools:
1279
  if args_valid:
1280
  await session.send_event(
1281
  Event(
 
1293
  tc: ToolCall,
1294
  name: str,
1295
  args: dict,
1296
+ decision: ApprovalDecision,
1297
  valid: bool,
1298
  err: str,
1299
  ) -> tuple[ToolCall, str, dict, str, bool]:
1300
  if not valid:
1301
  return (tc, name, args, err, False)
1302
+ if decision.billable:
1303
+ _record_estimated_spend(session, decision)
1304
  out, ok = await session.tool_router.call_tool(
1305
  name, args, session=session, tool_call_id=tc.id
1306
  )
 
1308
 
1309
  gather_task = asyncio.ensure_future(asyncio.gather(
1310
  *[
1311
+ _exec_tool(tc, name, args, decision, valid, err)
1312
+ for tc, name, args, decision, valid, err in parsed_tools
1313
  ]
1314
  ))
1315
  cancel_task = asyncio.ensure_future(session._cancelled.wait())
 
1326
  except asyncio.CancelledError:
1327
  pass
1328
  # Notify frontend that in-flight tools were cancelled
1329
+ for tc, name, _args, _decision, valid, _ in parsed_tools:
1330
  if valid:
1331
  await session.send_event(Event(
1332
  event_type="tool_state_change",
 
1364
  if approval_required_tools:
1365
  # Prepare batch approval data
1366
  tools_data = []
1367
+ blocked_payloads = []
1368
+ for tc, tool_name, tool_args, decision in approval_required_tools:
1369
  # Resolve sandbox file paths for hf_jobs scripts so the
1370
  # frontend can display & edit the actual file content.
1371
  if tool_name == "hf_jobs" and isinstance(tool_args.get("script"), str):
 
1375
  if resolved:
1376
  tool_args = {**tool_args, "script": resolved}
1377
 
1378
+ tool_payload = {
1379
  "tool": tool_name,
1380
  "arguments": tool_args,
1381
  "tool_call_id": tc.id,
1382
+ }
1383
+ if decision.auto_approval_blocked:
1384
+ tool_payload.update(
1385
+ {
1386
+ "auto_approval_blocked": True,
1387
+ "block_reason": decision.block_reason,
1388
+ "estimated_cost_usd": decision.estimated_cost_usd,
1389
+ "remaining_cap_usd": decision.remaining_cap_usd,
1390
+ }
1391
+ )
1392
+ blocked_payloads.append(tool_payload)
1393
+ tools_data.append(tool_payload)
1394
+
1395
+ event_data = {"tools": tools_data, "count": len(tools_data)}
1396
+ if blocked_payloads:
1397
+ first = blocked_payloads[0]
1398
+ event_data.update(
1399
+ {
1400
+ "auto_approval_blocked": True,
1401
+ "block_reason": first.get("block_reason"),
1402
+ "estimated_cost_usd": first.get("estimated_cost_usd"),
1403
+ "remaining_cap_usd": first.get("remaining_cap_usd"),
1404
+ }
1405
+ )
1406
  await session.send_event(Event(
1407
  event_type="approval_required",
1408
+ data=event_data,
1409
  ))
1410
 
1411
  # Store all approval-requiring tools (ToolCall objects for execution)
1412
  session.pending_approval = {
1413
+ "tool_calls": [tc for tc, _, _, _ in approval_required_tools],
1414
  }
1415
 
1416
  # Return early - wait for EXEC_APPROVAL operation
 
1600
  )
1601
  )
1602
 
1603
+ await _record_manual_approved_spend_if_needed(session, tool_name, tool_args)
1604
+
1605
  output, success = await session.tool_router.call_tool(
1606
  tool_name, tool_args, session=session, tool_call_id=tc.id
1607
  )
agent/core/approval_policy.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Shared predicates for approval-gated tool operations."""
2
+
3
+ from typing import Any
4
+
5
+
6
def normalize_tool_operation(operation: Any) -> str:
    """Return ``operation`` as a trimmed, lower-cased string.

    Falsy values (``None``, ``""``, ``0`` ...) normalize to the empty string.
    """
    text = str(operation) if operation else ""
    return text.strip().lower()
8
+
9
+
10
def is_scheduled_operation(operation: Any) -> bool:
    """Tell whether the normalized operation name begins with ``"scheduled "``."""
    normalized = normalize_tool_operation(operation)
    return normalized.startswith("scheduled ")
agent/core/cost_estimation.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Conservative cost estimates for auto-approved infrastructure actions."""
2
+
3
+ import os
4
+ import re
5
+ import time
6
+ from dataclasses import dataclass
7
+ from typing import Any
8
+
9
+ import httpx
10
+
11
+ OPENID_PROVIDER_URL = os.environ.get("OPENID_PROVIDER_URL", "https://huggingface.co")
12
+ JOBS_HARDWARE_URL = f"{OPENID_PROVIDER_URL}/api/jobs/hardware"
13
+ JOBS_PRICE_CACHE_TTL_S = 6 * 60 * 60
14
+
15
+ DEFAULT_JOB_TIMEOUT_HOURS = 0.5
16
+ DEFAULT_SANDBOX_RESERVATION_HOURS = 1.0
17
+
18
+ # Static fallback prices are intentionally conservative enough for a budget
19
+ # guard. The live /api/jobs/hardware catalog wins whenever it is reachable.
20
+ HF_JOBS_PRICE_USD_PER_HOUR: dict[str, float] = {
21
+ "cpu-basic": 0.05,
22
+ "cpu-upgrade": 0.25,
23
+ "cpu-performance": 0.50,
24
+ "cpu-xl": 1.00,
25
+ "t4-small": 0.60,
26
+ "t4-medium": 0.90,
27
+ "l4x1": 1.00,
28
+ "l4x4": 4.00,
29
+ "l40sx1": 2.00,
30
+ "l40sx4": 8.00,
31
+ "l40sx8": 16.00,
32
+ "a10g-small": 1.00,
33
+ "a10g-large": 2.00,
34
+ "a10g-largex2": 4.00,
35
+ "a10g-largex4": 8.00,
36
+ "a100-large": 4.00,
37
+ "a100x4": 16.00,
38
+ "a100x8": 32.00,
39
+ "h200": 10.00,
40
+ "h200x2": 20.00,
41
+ "h200x4": 40.00,
42
+ "h200x8": 80.00,
43
+ "inf2x6": 6.00,
44
+ }
45
+
46
+ SPACE_PRICE_USD_PER_HOUR: dict[str, float] = {
47
+ "cpu-basic": 0.0,
48
+ "cpu-upgrade": 0.05,
49
+ "cpu-performance": 0.50,
50
+ "cpu-xl": 1.00,
51
+ "t4-small": 0.60,
52
+ "t4-medium": 0.90,
53
+ "l4x1": 1.00,
54
+ "l4x4": 4.00,
55
+ "l40sx1": 2.00,
56
+ "l40sx4": 8.00,
57
+ "l40sx8": 16.00,
58
+ "a10g-small": 1.00,
59
+ "a10g-large": 2.00,
60
+ "a10g-largex2": 4.00,
61
+ "a10g-largex4": 8.00,
62
+ "a100-large": 4.00,
63
+ "a100x4": 16.00,
64
+ "a100x8": 32.00,
65
+ "h200": 10.00,
66
+ "h200x2": 20.00,
67
+ "h200x4": 40.00,
68
+ "h200x8": 80.00,
69
+ "inf2x6": 6.00,
70
+ }
71
+
72
+ _DURATION_RE = re.compile(r"^\s*(\d+(?:\.\d+)?)\s*([smhd]?)\s*$", re.IGNORECASE)
73
+ _PRICE_RE = re.compile(r"(\d+(?:\.\d+)?)")
74
+ _jobs_price_cache: tuple[float, dict[str, float]] | None = None
75
+
76
+
77
+ @dataclass(frozen=True)
78
+ class CostEstimate:
79
+ """Estimated cost for a tool call.
80
+
81
+ ``estimated_cost_usd=None`` means the call may be billable but we could not
82
+ estimate it safely, so auto-approval should fall back to a human decision.
83
+ """
84
+
85
+ estimated_cost_usd: float | None
86
+ billable: bool
87
+ block_reason: str | None = None
88
+ label: str | None = None
89
+
90
+
91
+ def parse_timeout_hours(value: Any, *, default_hours: float = DEFAULT_JOB_TIMEOUT_HOURS) -> float | None:
92
+ """Parse HF timeout values into hours.
93
+
94
+ Strings accept ``s``, ``m``, ``h``, or ``d`` suffixes. Numeric values are
95
+ treated as seconds, matching the Hub client's typed timeout parameter.
96
+ """
97
+ if value is None or value == "":
98
+ return default_hours
99
+ if isinstance(value, bool):
100
+ return None
101
+ if isinstance(value, int | float):
102
+ seconds = float(value)
103
+ return seconds / 3600 if seconds > 0 else None
104
+ if not isinstance(value, str):
105
+ return None
106
+
107
+ match = _DURATION_RE.match(value)
108
+ if not match:
109
+ return None
110
+ amount = float(match.group(1))
111
+ unit = match.group(2).lower() or "s"
112
+ if amount <= 0:
113
+ return None
114
+ if unit == "s":
115
+ return amount / 3600
116
+ if unit == "m":
117
+ return amount / 60
118
+ if unit == "h":
119
+ return amount
120
+ if unit == "d":
121
+ return amount * 24
122
+ return None
123
+
124
+
125
+ def _extract_flavor(item: dict[str, Any]) -> str | None:
126
+ for key in ("flavor", "name", "id", "value", "hardware", "hardware_flavor"):
127
+ value = item.get(key)
128
+ if isinstance(value, str) and value:
129
+ return value
130
+ return None
131
+
132
+
133
+ def _coerce_price(value: Any) -> float | None:
134
+ if isinstance(value, bool) or value is None:
135
+ return None
136
+ if isinstance(value, int | float):
137
+ return float(value) if value >= 0 else None
138
+ if isinstance(value, str):
139
+ match = _PRICE_RE.search(value.replace(",", ""))
140
+ if match:
141
+ return float(match.group(1))
142
+ return None
143
+
144
+
145
def _extract_hourly_price(item: dict[str, Any]) -> float | None:
    """Look up a price in ``item``, trying flat keys first and nested dicts second.

    Returns the first value that coerces to a price, or ``None`` when neither
    the known price keys nor the ``pricing``/``billing``/``cost`` sub-dicts
    yield one.
    """
    direct_keys = (
        "price",
        "price_usd",
        "priceUsd",
        "price_per_hour",
        "pricePerHour",
        "hourly_price",
        "hourlyPrice",
        "usd_per_hour",
        "usdPerHour",
    )
    # Lazy generator: stops coercing as soon as one key yields a price.
    direct_prices = (_coerce_price(item.get(key)) for key in direct_keys)
    direct_hit = next((p for p in direct_prices if p is not None), None)
    if direct_hit is not None:
        return direct_hit

    for container_key in ("pricing", "billing", "cost"):
        container = item.get(container_key)
        if not isinstance(container, dict):
            continue
        nested_hit = _extract_hourly_price(container)
        if nested_hit is not None:
            return nested_hit
    return None
167
+
168
+
169
def _iter_hardware_items(payload: Any):
    """Yield every dict in ``payload`` that carries a flavor identifier.

    Lists are walked element by element. Dicts are yielded when they expose a
    flavor and are then searched under their known container keys. Any other
    value yields nothing.
    """
    if isinstance(payload, dict):
        if _extract_flavor(payload):
            yield payload
        for container_key in ("hardware", "flavors", "items", "data", "jobs"):
            nested = payload.get(container_key)
            if nested is None:
                continue
            yield from _iter_hardware_items(nested)
        return

    if isinstance(payload, list):
        for entry in payload:
            yield from _iter_hardware_items(entry)
180
+
181
+
182
def _parse_jobs_price_catalog(payload: Any) -> dict[str, float]:
    """Build a flavor -> price mapping from a hardware catalog payload.

    Items without a flavor or without a usable price are left out; when a
    flavor occurs more than once the last occurrence wins.
    """
    pairs = (
        (_extract_flavor(item), _extract_hourly_price(item))
        for item in _iter_hardware_items(payload)
    )
    return {
        flavor: price
        for flavor, price in pairs
        if flavor and price is not None
    }
190
+
191
+
192
async def hf_jobs_price_catalog() -> dict[str, float]:
    """Return live HF Jobs hourly prices, falling back to static prices.

    Live prices are fetched from ``JOBS_HARDWARE_URL`` and overlaid on the
    static table, so flavors missing from the live catalog keep their static
    price. A successful result is cached for ``JOBS_PRICE_CACHE_TTL_S``. When
    the fetch fails or yields nothing, the static table is cached only
    briefly so that a transient outage does not pin static prices for the
    whole TTL.

    Returns:
        A fresh ``flavor -> USD per hour`` dict that callers may mutate freely.
    """
    global _jobs_price_cache
    now = time.monotonic()
    if (
        _jobs_price_cache is not None
        and now - _jobs_price_cache[0] < JOBS_PRICE_CACHE_TTL_S
    ):
        return dict(_jobs_price_cache[1])

    live_prices: dict[str, float] = {}
    try:
        async with httpx.AsyncClient(timeout=3.0) as client:
            response = await client.get(JOBS_HARDWARE_URL)
            if response.status_code == 200:
                live_prices = _parse_jobs_price_catalog(response.json())
    except (httpx.HTTPError, ValueError):
        # Best effort: network/HTTP failures and malformed JSON fall back to
        # the static table below.
        live_prices = {}

    prices = {**HF_JOBS_PRICE_USD_PER_HOUR, **live_prices}

    if live_prices:
        cached_at = now
    else:
        # Backdate the entry so it expires after a short retry window rather
        # than the full TTL; the live catalog is retried soon afterwards.
        fallback_retry_s = min(300.0, float(JOBS_PRICE_CACHE_TTL_S))
        cached_at = now - JOBS_PRICE_CACHE_TTL_S + fallback_retry_s

    _jobs_price_cache = (cached_at, prices)
    return dict(prices)
215
+
216
+
217
async def estimate_hf_job_cost(args: dict[str, Any]) -> CostEstimate:
    """Estimate the cost of an HF job as hourly price times timeout.

    The hardware flavor is read from ``hardware_flavor``, ``flavor`` or
    ``hardware`` (first truthy one) and defaults to ``"cpu-basic"``. An
    unparseable timeout or an unknown flavor produces an estimate without a
    cost, carrying the reason.
    """
    requested = (
        args.get("hardware_flavor")
        or args.get("flavor")
        or args.get("hardware")
        or "cpu-basic"
    )
    flavor = str(requested)

    raw_timeout = args.get("timeout")
    timeout_hours = parse_timeout_hours(raw_timeout)
    if timeout_hours is None:
        return CostEstimate(
            estimated_cost_usd=None,
            billable=True,
            block_reason=f"Could not parse HF job timeout: {raw_timeout!r}.",
            label=flavor,
        )

    catalog = await hf_jobs_price_catalog()
    hourly_price = catalog.get(flavor)
    if hourly_price is None:
        return CostEstimate(
            estimated_cost_usd=None,
            billable=True,
            block_reason=f"No price is available for HF job hardware '{flavor}'.",
            label=flavor,
        )

    total = round(hourly_price * timeout_hours, 4)
    return CostEstimate(
        estimated_cost_usd=total,
        billable=hourly_price > 0,
        label=flavor,
    )
248
+
249
+
250
async def estimate_sandbox_cost(args: dict[str, Any], *, session: Any = None) -> CostEstimate:
    """Estimate the cost of creating a sandbox for the default reservation time.

    A session that already has a truthy ``sandbox`` attribute costs nothing
    more. Otherwise the hardware (default ``"cpu-basic"``) is priced from the
    static Space price table; unknown hardware yields an estimate without a cost.
    """
    has_sandbox = session is not None and getattr(session, "sandbox", None)
    if has_sandbox:
        return CostEstimate(estimated_cost_usd=0.0, billable=False, label="existing")

    hardware = str(args.get("hardware") or "cpu-basic")
    hourly_price = SPACE_PRICE_USD_PER_HOUR.get(hardware)
    if hourly_price is None:
        return CostEstimate(
            estimated_cost_usd=None,
            billable=True,
            block_reason=f"No price is available for sandbox hardware '{hardware}'.",
            label=hardware,
        )

    reserved_cost = round(hourly_price * DEFAULT_SANDBOX_RESERVATION_HOURS, 4)
    return CostEstimate(
        estimated_cost_usd=reserved_cost,
        billable=hourly_price > 0,
        label=hardware,
    )
269
+
270
+
271
+ async def estimate_tool_cost(
272
+ tool_name: str, args: dict[str, Any], *, session: Any = None
273
+ ) -> CostEstimate:
274
+ if tool_name == "sandbox_create":
275
+ return await estimate_sandbox_cost(args, session=session)
276
+ if tool_name == "hf_jobs":
277
+ return await estimate_hf_job_cost(args)
278
+ return CostEstimate(estimated_cost_usd=0.0, billable=False)
agent/core/session.py CHANGED
@@ -120,6 +120,9 @@ class Session:
120
  self.notification_gateway = notification_gateway
121
  self.notification_destinations = list(notification_destinations or [])
122
  self.defer_turn_complete_notification = defer_turn_complete_notification
 
 
 
123
 
124
  # Session trajectory logging
125
  self.logged_events: list[dict] = []
@@ -313,6 +316,40 @@ class Session:
313
  self.config.model_name = model_name
314
  self.context_manager.model_max_tokens = _get_max_tokens_safe(model_name)
315
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
316
  def effective_effort_for(self, model_name: str) -> str | None:
317
  """Resolve the effort level to actually send for ``model_name``.
318
 
 
120
  self.notification_gateway = notification_gateway
121
  self.notification_destinations = list(notification_destinations or [])
122
  self.defer_turn_complete_notification = defer_turn_complete_notification
123
+ self.auto_approval_enabled: bool = False
124
+ self.auto_approval_cost_cap_usd: float | None = None
125
+ self.auto_approval_estimated_spend_usd: float = 0.0
126
 
127
  # Session trajectory logging
128
  self.logged_events: list[dict] = []
 
316
  self.config.model_name = model_name
317
  self.context_manager.model_max_tokens = _get_max_tokens_safe(model_name)
318
 
319
+ def set_auto_approval_policy(
320
+ self, *, enabled: bool, cost_cap_usd: float | None
321
+ ) -> None:
322
+ self.auto_approval_enabled = bool(enabled)
323
+ self.auto_approval_cost_cap_usd = cost_cap_usd
324
+
325
+ def add_auto_approval_estimated_spend(self, amount_usd: float | None) -> None:
326
+ if amount_usd is None or amount_usd <= 0:
327
+ return
328
+ self.auto_approval_estimated_spend_usd = round(
329
+ self.auto_approval_estimated_spend_usd + float(amount_usd), 4
330
+ )
331
+
332
+ @property
333
+ def auto_approval_remaining_usd(self) -> float | None:
334
+ if self.auto_approval_cost_cap_usd is None:
335
+ return None
336
+ return round(
337
+ max(
338
+ 0.0,
339
+ self.auto_approval_cost_cap_usd
340
+ - self.auto_approval_estimated_spend_usd,
341
+ ),
342
+ 4,
343
+ )
344
+
345
def auto_approval_policy_summary(self) -> dict[str, Any]:
    """Return the auto-approval policy and budget state as a plain dict."""
    spend = round(self.auto_approval_estimated_spend_usd, 4)
    summary: dict[str, Any] = {}
    summary["enabled"] = self.auto_approval_enabled
    summary["cost_cap_usd"] = self.auto_approval_cost_cap_usd
    summary["estimated_spend_usd"] = spend
    summary["remaining_usd"] = self.auto_approval_remaining_usd
    return summary
352
+
353
  def effective_effort_for(self, model_name: str) -> str | None:
354
  """Resolve the effort level to actually send for ``model_name``.
355
 
agent/core/session_persistence.py CHANGED
@@ -176,6 +176,9 @@ class MongoSessionStore(NoopSessionStore):
176
  pending_approval: list[dict[str, Any]] | None = None,
177
  claude_counted: bool = False,
178
  notification_destinations: list[str] | None = None,
 
 
 
179
  ) -> None:
180
  if not self._ready():
181
  return
@@ -204,6 +207,9 @@ class MongoSessionStore(NoopSessionStore):
204
  "pending_approval": pending_approval or [],
205
  "claude_counted": claude_counted,
206
  "notification_destinations": notification_destinations or [],
 
 
 
207
  },
208
  },
209
  upsert=True,
@@ -224,6 +230,9 @@ class MongoSessionStore(NoopSessionStore):
224
  claude_counted: bool = False,
225
  created_at: datetime | None = None,
226
  notification_destinations: list[str] | None = None,
 
 
 
227
  ) -> None:
228
  if not self._ready():
229
  return
@@ -241,6 +250,9 @@ class MongoSessionStore(NoopSessionStore):
241
  pending_approval=pending_approval,
242
  claude_counted=claude_counted,
243
  notification_destinations=notification_destinations,
 
 
 
244
  )
245
  ops: list[Any] = []
246
  for idx, raw in enumerate(messages):
 
176
  pending_approval: list[dict[str, Any]] | None = None,
177
  claude_counted: bool = False,
178
  notification_destinations: list[str] | None = None,
179
+ auto_approval_enabled: bool = False,
180
+ auto_approval_cost_cap_usd: float | None = None,
181
+ auto_approval_estimated_spend_usd: float = 0.0,
182
  ) -> None:
183
  if not self._ready():
184
  return
 
207
  "pending_approval": pending_approval or [],
208
  "claude_counted": claude_counted,
209
  "notification_destinations": notification_destinations or [],
210
+ "auto_approval_enabled": auto_approval_enabled,
211
+ "auto_approval_cost_cap_usd": auto_approval_cost_cap_usd,
212
+ "auto_approval_estimated_spend_usd": auto_approval_estimated_spend_usd,
213
  },
214
  },
215
  upsert=True,
 
230
  claude_counted: bool = False,
231
  created_at: datetime | None = None,
232
  notification_destinations: list[str] | None = None,
233
+ auto_approval_enabled: bool = False,
234
+ auto_approval_cost_cap_usd: float | None = None,
235
+ auto_approval_estimated_spend_usd: float = 0.0,
236
  ) -> None:
237
  if not self._ready():
238
  return
 
250
  pending_approval=pending_approval,
251
  claude_counted=claude_counted,
252
  notification_destinations=notification_destinations,
253
+ auto_approval_enabled=auto_approval_enabled,
254
+ auto_approval_cost_cap_usd=auto_approval_cost_cap_usd,
255
+ auto_approval_estimated_spend_usd=auto_approval_estimated_spend_usd,
256
  )
257
  ops: list[Any] = []
258
  for idx, raw in enumerate(messages):
agent/main.py CHANGED
@@ -21,6 +21,7 @@ import litellm
21
  from prompt_toolkit import PromptSession
22
 
23
  from agent.config import load_config
 
24
  from agent.core.agent_loop import submission_loop
25
  from agent.core import model_switcher
26
  from agent.core.hf_tokens import resolve_hf_token
@@ -55,6 +56,20 @@ litellm.suppress_debug_info = True
55
  CLI_CONFIG_PATH = Path(__file__).parent.parent / "configs" / "cli_agent_config.json"
56
 
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  def _configure_runtime_logging() -> None:
59
  """Keep third-party warning spam from punching through the interactive UI."""
60
  import logging
@@ -375,8 +390,11 @@ async def event_listener(
375
  tools_data = event.data.get("tools", []) if event.data else []
376
  count = event.data.get("count", 0) if event.data else 0
377
 
378
- # If yolo mode is active, auto-approve everything
379
- if config and config.yolo_mode:
 
 
 
380
  approvals = [
381
  {
382
  "tool_call_id": t.get("tool_call_id", ""),
@@ -1293,14 +1311,18 @@ async def headless_main(
1293
  else:
1294
  print_tool_log(tool, log)
1295
  elif event.event_type == "approval_required":
1296
- # Auto-approve everything in headless mode (safety net if yolo_mode
1297
- # didn't prevent the approval event for some reason)
1298
  tools_data = event.data.get("tools", []) if event.data else []
1299
  approvals = [
1300
  {
1301
  "tool_call_id": t.get("tool_call_id", ""),
1302
- "approved": True,
1303
- "feedback": None,
 
 
 
 
1304
  }
1305
  for t in tools_data
1306
  ]
 
21
  from prompt_toolkit import PromptSession
22
 
23
  from agent.config import load_config
24
+ from agent.core.approval_policy import is_scheduled_operation
25
  from agent.core.agent_loop import submission_loop
26
  from agent.core import model_switcher
27
  from agent.core.hf_tokens import resolve_hf_token
 
56
  CLI_CONFIG_PATH = Path(__file__).parent.parent / "configs" / "cli_agent_config.json"
57
 
58
 
59
+ def _is_scheduled_hf_job_tool(tool_info: dict[str, Any]) -> bool:
60
+ if tool_info.get("tool") != "hf_jobs":
61
+ return False
62
+ arguments = tool_info.get("arguments") or {}
63
+ if isinstance(arguments, str):
64
+ try:
65
+ arguments = json.loads(arguments)
66
+ except json.JSONDecodeError:
67
+ return False
68
+ if not isinstance(arguments, dict):
69
+ return False
70
+ return is_scheduled_operation(arguments.get("operation"))
71
+
72
+
73
  def _configure_runtime_logging() -> None:
74
  """Keep third-party warning spam from punching through the interactive UI."""
75
  import logging
 
390
  tools_data = event.data.get("tools", []) if event.data else []
391
  count = event.data.get("count", 0) if event.data else 0
392
 
393
+ # If yolo mode is active, auto-approve everything except
394
+ # scheduled HF jobs, whose recurring cost stays manual.
395
+ if config and config.yolo_mode and not any(
396
+ _is_scheduled_hf_job_tool(t) for t in tools_data
397
+ ):
398
  approvals = [
399
  {
400
  "tool_call_id": t.get("tool_call_id", ""),
 
1311
  else:
1312
  print_tool_log(tool, log)
1313
  elif event.event_type == "approval_required":
1314
+ # Auto-approve in headless mode, except scheduled HF jobs. Those
1315
+ # are rejected because their recurring cost needs manual approval.
1316
  tools_data = event.data.get("tools", []) if event.data else []
1317
  approvals = [
1318
  {
1319
  "tool_call_id": t.get("tool_call_id", ""),
1320
+ "approved": not _is_scheduled_hf_job_tool(t),
1321
+ "feedback": (
1322
+ "Scheduled HF jobs require manual approval."
1323
+ if _is_scheduled_hf_job_tool(t)
1324
+ else None
1325
+ ),
1326
  }
1327
  for t in tools_data
1328
  ]
agent/prompts/system_prompt_v3.yaml CHANGED
@@ -42,7 +42,7 @@ system_prompt: |
42
 
43
  SILENT DATASET SUBSTITUTION: When a requested dataset fails to load, you will silently switch to a different one without telling the user. Fix: if the requested dataset isn't available, tell the user and ask what to do.
44
 
45
- HARDCODED UNAVAILABLE PACKAGES: You will forget to install necessary packages like 'flash-attn' for flash_attention_2 or other packages that aren't automatically installed in the job environment. Fix: install necessary packages before running the job.
46
 
47
  SCOPE-CHANGING FIXES: Avoid at all costs! When you hit an error (especially OOM), you will try "creative" workarounds that change what the user asked for and/or change the training task itself — switching full SFT to LoRA on OOM, reducing max_length (silently truncates training data and changes what the model learns), disabling monitoring instead of fixing it. Do not do this. Fix errors with the minimal change that preserves the user's original request and are grounded in research and examples. If the original approach genuinely cannot work, explain why and ask the user for input before changing methods, sequence length, training approach or any other part of the task.
48
 
 
42
 
43
  SILENT DATASET SUBSTITUTION: When a requested dataset fails to load, you will silently switch to a different one without telling the user. Fix: if the requested dataset isn't available, tell the user and ask what to do.
44
 
45
+ PREFER HUB KERNELS OVER COMPILING ATTENTION: Do NOT pip install 'flash-attn' to enable flash_attention_2 — building from source can take many minutes to hours and often fails on the job's CUDA/PyTorch combo. Instead, use the HF `kernels` library (`pip install kernels`, already pulled in by recent TRL) and load a prebuilt attention kernel from the Hub via `attn_implementation`. Examples: `AutoModelForCausalLM.from_pretrained(..., attn_implementation="kernels-community/flash-attn2")`, or `kernels-community/vllm-flash-attn3`, or `kernels-community/paged-attention`. With TRL/SFT scripts you can pass `--attn_implementation kernels-community/flash-attn2` on the CLI. Search for additional kernels at https://huggingface.co/models?other=kernel. Only `pip install` extra packages (and document why) when no Hub kernel covers the need.
46
 
47
  SCOPE-CHANGING FIXES: Avoid at all costs! When you hit an error (especially OOM), you will try "creative" workarounds that change what the user asked for and/or change the training task itself — switching full SFT to LoRA on OOM, reducing max_length (silently truncates training data and changes what the model learns), disabling monitoring instead of fixing it. Do not do this. Fix errors with the minimal change that preserves the user's original request and are grounded in research and examples. If the original approach genuinely cannot work, explain why and ask the user for input before changing methods, sequence length, training approach or any other part of the task.
48
 
agent/tools/docs_tools.py CHANGED
@@ -932,7 +932,7 @@ EXPLORE_HF_DOCS_TOOL_SPEC = {
932
  "• argilla — Data annotation, feedback, and human-in-the-loop workflows.\n"
933
  "• distilabel — Synthetic data generation and distillation pipelines.\n"
934
  "• microsoft-azure — Azure deployment and integration guides.\n"
935
- "• kernels — Lightweight execution environments and notebook-style workflows.\n"
936
  "• google-cloud — GCP deployment and serving workflows.\n"
937
  ),
938
  },
 
932
  "• argilla — Data annotation, feedback, and human-in-the-loop workflows.\n"
933
  "• distilabel — Synthetic data generation and distillation pipelines.\n"
934
  "• microsoft-azure — Azure deployment and integration guides.\n"
935
+ "• kernels — Load prebuilt compute kernels (E.g. flash-attn2) from the Hub via `attn_implementation`; avoids compiling flash-attn from source.\n"
936
  "• google-cloud — GCP deployment and serving workflows.\n"
937
  ),
938
  },
backend/models.py CHANGED
@@ -76,6 +76,15 @@ class PendingApprovalTool(BaseModel):
76
  arguments: dict[str, Any] = {}
77
 
78
 
 
 
 
 
 
 
 
 
 
79
  class SessionInfo(BaseModel):
80
  """Session metadata."""
81
 
@@ -89,6 +98,9 @@ class SessionInfo(BaseModel):
89
  model: str | None = None
90
  title: str | None = None
91
  notification_destinations: list[str] = Field(default_factory=list)
 
 
 
92
 
93
 
94
  class SessionNotificationsRequest(BaseModel):
@@ -97,6 +109,13 @@ class SessionNotificationsRequest(BaseModel):
97
  destinations: list[str]
98
 
99
 
 
 
 
 
 
 
 
100
  class HealthResponse(BaseModel):
101
  """Health check response."""
102
 
 
76
  arguments: dict[str, Any] = {}
77
 
78
 
79
class SessionAutoApprovalInfo(BaseModel):
    """Per-session auto-approval budget state."""

    # Whether auto-approval is switched on for the session.
    enabled: bool = False
    # Spending cap in USD; None when no cap is stored.
    cost_cap_usd: float | None = None
    # Estimated amount spent so far, in USD.
    estimated_spend_usd: float = 0.0
    # Budget left under the cap in USD; None when there is no cap.
    remaining_usd: float | None = None
86
+
87
+
88
  class SessionInfo(BaseModel):
89
  """Session metadata."""
90
 
 
98
  model: str | None = None
99
  title: str | None = None
100
  notification_destinations: list[str] = Field(default_factory=list)
101
+ auto_approval: SessionAutoApprovalInfo = Field(
102
+ default_factory=SessionAutoApprovalInfo
103
+ )
104
 
105
 
106
  class SessionNotificationsRequest(BaseModel):
 
109
  destinations: list[str]
110
 
111
 
112
class SessionYoloRequest(BaseModel):
    """Update a session's auto-approval policy."""

    # Desired on/off state of auto-approval for the session.
    enabled: bool
    # Optional spending cap in USD. Must be a finite, non-negative number:
    # `ge=0` alone would still admit Infinity, which would propagate into the
    # remaining-budget arithmetic and serialize as non-standard JSON.
    cost_cap_usd: float | None = Field(default=None, ge=0, allow_inf_nan=False)
117
+
118
+
119
  class HealthResponse(BaseModel):
120
  """Health check response."""
121
 
backend/routes/agent.py CHANGED
@@ -26,6 +26,7 @@ from models import (
26
  SessionInfo,
27
  SessionNotificationsRequest,
28
  SessionResponse,
 
29
  SubmitRequest,
30
  TruncateRequest,
31
  )
@@ -498,6 +499,26 @@ async def set_session_notifications(
498
  }
499
 
500
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
501
  @router.get("/user/quota")
502
  async def get_user_quota(user: dict = Depends(get_current_user)) -> dict:
503
  """Return the user's plan tier and today's premium-model quota state."""
 
26
  SessionInfo,
27
  SessionNotificationsRequest,
28
  SessionResponse,
29
+ SessionYoloRequest,
30
  SubmitRequest,
31
  TruncateRequest,
32
  )
 
499
  }
500
 
501
 
502
@router.patch("/session/{session_id}/yolo")
async def set_session_yolo(
    session_id: str,
    body: SessionYoloRequest,
    user: dict = Depends(get_current_user),
) -> dict:
    """Update the session-scoped auto-approval policy.

    Args:
        session_id: Identifier of the session to update.
        body: Requested enabled flag and optional cost cap.
        user: Authenticated caller, resolved by ``get_current_user``.

    Returns:
        The session id merged with the policy summary returned by the
        session manager.

    Raises:
        HTTPException: 400 when the session manager rejects the update with
            a ``ValueError``. ``_check_session_access`` runs first and
            presumably raises on its own when access is denied.
    """
    await _check_session_access(session_id, user)
    try:
        summary = await session_manager.update_session_auto_approval(
            session_id,
            enabled=body.enabled,
            cost_cap_usd=body.cost_cap_usd,
            # Lets the manager tell an omitted cap apart from an explicit null.
            cap_provided="cost_cap_usd" in body.model_fields_set,
        )
    except ValueError as e:
        # Chain the cause so the original traceback is preserved in logs.
        raise HTTPException(status_code=400, detail=str(e)) from e
    return {"session_id": session_id, **summary}
520
+
521
+
522
  @router.get("/user/quota")
523
  async def get_user_quota(user: dict = Depends(get_current_user)) -> dict:
524
  """Return the user's plan tier and today's premium-model quota state."""
backend/session_manager.py CHANGED
@@ -116,6 +116,7 @@ class SessionCapacityError(Exception):
116
  # and per-request overhead.
117
  MAX_SESSIONS: int = 200
118
  MAX_SESSIONS_PER_USER: int = 10
 
119
 
120
 
121
  class SessionManager:
@@ -297,6 +298,20 @@ class SessionManager:
297
  return "ended"
298
  return "idle"
299
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
300
  async def _start_agent_session(
301
  self,
302
  *,
@@ -370,6 +385,20 @@ class SessionManager:
370
  notification_destinations=list(
371
  agent_session.session.notification_destinations
372
  ),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
373
  )
374
  except Exception as e:
375
  logger.warning(
@@ -451,6 +480,14 @@ class SessionManager:
451
 
452
  self._restore_pending_approval(session, meta.get("pending_approval") or [])
453
  session.turn_count = int(meta.get("turn_count") or 0)
 
 
 
 
 
 
 
 
454
 
455
  created_at = meta.get("created_at")
456
  if not isinstance(created_at, datetime):
@@ -883,6 +920,43 @@ class SessionManager:
883
  await self.persist_session_snapshot(agent_session, runtime_state="idle")
884
  return True
885
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
886
  def get_session_owner(self, session_id: str) -> str | None:
887
  """Get the user_id that owns a session, or None if session doesn't exist."""
888
  agent_session = self.sessions.get(session_id)
@@ -925,6 +999,7 @@ class SessionManager:
925
  "notification_destinations": list(
926
  agent_session.session.notification_destinations
927
  ),
 
928
  }
929
 
930
  def set_notification_destinations(
@@ -991,6 +1066,25 @@ class SessionManager:
991
  "model": row.get("model"),
992
  "title": row.get("title"),
993
  "notification_destinations": row.get("notification_destinations") or [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
994
  }
995
  )
996
  return results
 
116
  # and per-request overhead.
117
  MAX_SESSIONS: int = 200
118
  MAX_SESSIONS_PER_USER: int = 10
119
+ DEFAULT_YOLO_COST_CAP_USD: float = 5.0
120
 
121
 
122
  class SessionManager:
 
298
  return "ended"
299
  return "idle"
300
 
301
+ @staticmethod
302
+ def _auto_approval_summary(session: Session) -> dict[str, Any]:
303
+ if hasattr(session, "auto_approval_policy_summary"):
304
+ return session.auto_approval_policy_summary()
305
+ cap = getattr(session, "auto_approval_cost_cap_usd", None)
306
+ estimated = float(getattr(session, "auto_approval_estimated_spend_usd", 0.0) or 0.0)
307
+ remaining = None if cap is None else round(max(0.0, float(cap) - estimated), 4)
308
+ return {
309
+ "enabled": bool(getattr(session, "auto_approval_enabled", False)),
310
+ "cost_cap_usd": cap,
311
+ "estimated_spend_usd": round(estimated, 4),
312
+ "remaining_usd": remaining,
313
+ }
314
+
315
  async def _start_agent_session(
316
  self,
317
  *,
 
385
  notification_destinations=list(
386
  agent_session.session.notification_destinations
387
  ),
388
+ auto_approval_enabled=bool(
389
+ getattr(agent_session.session, "auto_approval_enabled", False)
390
+ ),
391
+ auto_approval_cost_cap_usd=getattr(
392
+ agent_session.session, "auto_approval_cost_cap_usd", None
393
+ ),
394
+ auto_approval_estimated_spend_usd=float(
395
+ getattr(
396
+ agent_session.session,
397
+ "auto_approval_estimated_spend_usd",
398
+ 0.0,
399
+ )
400
+ or 0.0
401
+ ),
402
  )
403
  except Exception as e:
404
  logger.warning(
 
480
 
481
  self._restore_pending_approval(session, meta.get("pending_approval") or [])
482
  session.turn_count = int(meta.get("turn_count") or 0)
483
+ session.auto_approval_enabled = bool(meta.get("auto_approval_enabled", False))
484
+ raw_cap = meta.get("auto_approval_cost_cap_usd")
485
+ session.auto_approval_cost_cap_usd = (
486
+ float(raw_cap) if isinstance(raw_cap, int | float) else None
487
+ )
488
+ session.auto_approval_estimated_spend_usd = float(
489
+ meta.get("auto_approval_estimated_spend_usd") or 0.0
490
+ )
491
 
492
  created_at = meta.get("created_at")
493
  if not isinstance(created_at, datetime):
 
920
  await self.persist_session_snapshot(agent_session, runtime_state="idle")
921
  return True
922
 
923
+ async def update_session_auto_approval(
924
+ self,
925
+ session_id: str,
926
+ *,
927
+ enabled: bool,
928
+ cost_cap_usd: float | None,
929
+ cap_provided: bool = False,
930
+ ) -> dict[str, Any]:
931
+ agent_session = self.sessions.get(session_id)
932
+ if not agent_session or not agent_session.is_active:
933
+ raise ValueError("Session not found or inactive")
934
+
935
+ session = agent_session.session
936
+ if enabled:
937
+ if not cap_provided and cost_cap_usd is None:
938
+ cost_cap_usd = getattr(
939
+ session, "auto_approval_cost_cap_usd", None
940
+ )
941
+ if cost_cap_usd is None:
942
+ cost_cap_usd = DEFAULT_YOLO_COST_CAP_USD
943
+ elif cost_cap_usd is None:
944
+ cost_cap_usd = DEFAULT_YOLO_COST_CAP_USD
945
+ else:
946
+ if not cap_provided:
947
+ cost_cap_usd = getattr(session, "auto_approval_cost_cap_usd", None)
948
+
949
+ if hasattr(session, "set_auto_approval_policy"):
950
+ session.set_auto_approval_policy(
951
+ enabled=enabled,
952
+ cost_cap_usd=cost_cap_usd,
953
+ )
954
+ else:
955
+ session.auto_approval_enabled = bool(enabled)
956
+ session.auto_approval_cost_cap_usd = cost_cap_usd
957
+ await self.persist_session_snapshot(agent_session)
958
+ return self._auto_approval_summary(session)
959
+
960
  def get_session_owner(self, session_id: str) -> str | None:
961
  """Get the user_id that owns a session, or None if session doesn't exist."""
962
  agent_session = self.sessions.get(session_id)
 
999
  "notification_destinations": list(
1000
  agent_session.session.notification_destinations
1001
  ),
1002
+ "auto_approval": self._auto_approval_summary(agent_session.session),
1003
  }
1004
 
1005
  def set_notification_destinations(
 
1066
  "model": row.get("model"),
1067
  "title": row.get("title"),
1068
  "notification_destinations": row.get("notification_destinations") or [],
1069
+ "auto_approval": {
1070
+ "enabled": bool(row.get("auto_approval_enabled", False)),
1071
+ "cost_cap_usd": row.get("auto_approval_cost_cap_usd"),
1072
+ "estimated_spend_usd": float(
1073
+ row.get("auto_approval_estimated_spend_usd") or 0.0
1074
+ ),
1075
+ "remaining_usd": (
1076
+ None
1077
+ if row.get("auto_approval_cost_cap_usd") is None
1078
+ else round(
1079
+ max(
1080
+ 0.0,
1081
+ float(row.get("auto_approval_cost_cap_usd") or 0.0)
1082
+ - float(row.get("auto_approval_estimated_spend_usd") or 0.0),
1083
+ ),
1084
+ 4,
1085
+ )
1086
+ ),
1087
+ },
1088
  }
1089
  )
1090
  return results
frontend/src/components/Chat/ToolCallGroup.tsx CHANGED
@@ -1,5 +1,5 @@
1
  import { useCallback, useEffect, useMemo, useRef, useState } from 'react';
2
- import { Box, Stack, Typography, Chip, Button, TextField, IconButton, Link, CircularProgress } from '@mui/material';
3
  import CheckCircleOutlineIcon from '@mui/icons-material/CheckCircleOutline';
4
  import ErrorOutlineIcon from '@mui/icons-material/ErrorOutline';
5
  import OpenInNewIcon from '@mui/icons-material/OpenInNew';
@@ -502,6 +502,7 @@ function InlineApproval({
502
  }) {
503
  const [feedback, setFeedback] = useState('');
504
  const args = input as Record<string, unknown> | undefined;
 
505
  const { setPanel, getEditedScript } = useAgentStore();
506
  const { setRightPanelOpen, setLeftSidebarOpen } = useLayoutStore();
507
  const hasEditedScript = !!getEditedScript(toolCallId);
@@ -521,6 +522,24 @@ function InlineApproval({
521
 
522
  return (
523
  <Box sx={{ px: 1.5, py: 1.5, borderTop: '1px solid var(--tool-border)' }}>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
524
  {toolName === 'sandbox_create' && args && (() => {
525
  const hw = String(args.hardware || 'cpu-basic');
526
  const cost = costLabel(hw);
 
1
  import { useCallback, useEffect, useMemo, useRef, useState } from 'react';
2
+ import { Alert, Box, Stack, Typography, Chip, Button, TextField, IconButton, Link, CircularProgress } from '@mui/material';
3
  import CheckCircleOutlineIcon from '@mui/icons-material/CheckCircleOutline';
4
  import ErrorOutlineIcon from '@mui/icons-material/ErrorOutline';
5
  import OpenInNewIcon from '@mui/icons-material/OpenInNew';
 
502
  }) {
503
  const [feedback, setFeedback] = useState('');
504
  const args = input as Record<string, unknown> | undefined;
505
+ const autoApproval = useAgentStore((state) => state.budgetBlocks[toolCallId]);
506
  const { setPanel, getEditedScript } = useAgentStore();
507
  const { setRightPanelOpen, setLeftSidebarOpen } = useLayoutStore();
508
  const hasEditedScript = !!getEditedScript(toolCallId);
 
522
 
523
  return (
524
  <Box sx={{ px: 1.5, py: 1.5, borderTop: '1px solid var(--tool-border)' }}>
525
+ {autoApproval && (
526
+ <Alert
527
+ severity="warning"
528
+ sx={{
529
+ mb: 1.5,
530
+ py: 0.5,
531
+ bgcolor: 'rgba(245,158,11,0.08)',
532
+ border: '1px solid rgba(245,158,11,0.18)',
533
+ color: 'var(--text)',
534
+ '& .MuiAlert-icon': { color: 'var(--accent-yellow)' },
535
+ }}
536
+ >
537
+ <Typography variant="body2" sx={{ fontSize: '0.72rem' }}>
538
+ YOLO paused: {autoApproval.reason || 'manual approval required.'}
539
+ </Typography>
540
+ </Alert>
541
+ )}
542
+
543
  {toolName === 'sandbox_create' && args && (() => {
544
  const hw = String(args.hardware || 'cpu-basic');
545
  const cost = costLabel(hw);
frontend/src/components/Layout/AppLayout.tsx CHANGED
@@ -24,6 +24,7 @@ import SessionSidebar from '@/components/SessionSidebar/SessionSidebar';
24
  import SessionChat from '@/components/SessionChat';
25
  import CodePanel from '@/components/CodePanel/CodePanel';
26
  import WelcomeScreen from '@/components/WelcomeScreen/WelcomeScreen';
 
27
  import { apiFetch } from '@/utils/api';
28
 
29
  const DRAWER_WIDTH = 260;
@@ -252,6 +253,7 @@ export default function AppLayout() {
252
  </Box>
253
 
254
  <Box sx={{ display: 'flex', alignItems: 'center', gap: 0.5 }}>
 
255
  <IconButton
256
  onClick={toggleTheme}
257
  size="small"
 
24
  import SessionChat from '@/components/SessionChat';
25
  import CodePanel from '@/components/CodePanel/CodePanel';
26
  import WelcomeScreen from '@/components/WelcomeScreen/WelcomeScreen';
27
+ import YoloControl from '@/components/YoloControl';
28
  import { apiFetch } from '@/utils/api';
29
 
30
  const DRAWER_WIDTH = 260;
 
253
  </Box>
254
 
255
  <Box sx={{ display: 'flex', alignItems: 'center', gap: 0.5 }}>
256
+ <YoloControl />
257
  <IconButton
258
  onClick={toggleTheme}
259
  size="small"
frontend/src/components/YoloControl.tsx ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { useEffect, useMemo, useState } from 'react';
2
+ import {
3
+ Button,
4
+ Dialog,
5
+ DialogActions,
6
+ DialogContent,
7
+ DialogTitle,
8
+ TextField,
9
+ Tooltip,
10
+ Typography,
11
+ } from '@mui/material';
12
+ import BoltOutlinedIcon from '@mui/icons-material/BoltOutlined';
13
+ import { useSessionStore } from '@/store/sessionStore';
14
+ import { apiFetch } from '@/utils/api';
15
+
16
+ const DEFAULT_CAP_USD = 5;
17
+
18
/**
 * Format a USD amount for the compact YOLO button and dialog labels.
 *
 * A missing amount is shown as "uncapped". Amounts of 100 or more are
 * rounded to whole dollars; smaller amounts keep two decimals unless the
 * cents are zero.
 */
function money(value: number | null | undefined): string {
  if (value == null) return 'uncapped';
  const text = value >= 100
    ? value.toFixed(0)
    : value.toFixed(2).replace(/\.00$/, '');
  return `$${text}`;
}
23
+
24
/**
 * Header control for the active session's YOLO (auto-approval) policy.
 *
 * The button toggles auto-approval through PATCH /api/session/{id}/yolo.
 * Enabling it also opens a dialog where the session's USD cap can be edited.
 * The control is disabled when there is no active session, the session is
 * expired, or a request is in flight.
 */
export default function YoloControl() {
  const { sessions, activeSessionId, updateSessionYolo } = useSessionStore();
  const activeSession = useMemo(
    () => sessions.find((s) => s.id === activeSessionId) || null,
    [sessions, activeSessionId],
  );
  const [dialogOpen, setDialogOpen] = useState(false);
  const [capInput, setCapInput] = useState(String(DEFAULT_CAP_USD));
  const [busy, setBusy] = useState(false);
  const [error, setError] = useState<string | null>(null);

  const enabled = Boolean(activeSession?.autoApprovalEnabled);
  const disabled = !activeSessionId || activeSession?.expired || busy;
  const remaining = activeSession?.autoApprovalRemainingUsd ?? null;
  const cap = activeSession?.autoApprovalCostCapUsd ?? null;

  // Re-seed the cap field whenever the active session or its stored cap changes.
  useEffect(() => {
    if (!activeSession) return;
    setCapInput(String(activeSession.autoApprovalCostCapUsd ?? DEFAULT_CAP_USD));
  }, [activeSession?.id, activeSession?.autoApprovalCostCapUsd]); // eslint-disable-line react-hooks/exhaustive-deps

  // Close the dialog and drop any validation/request error so it does not
  // reappear the next time the dialog opens.
  const closeDialog = () => {
    setDialogOpen(false);
    setError(null);
  };

  /**
   * Send the policy to the backend and mirror the response into the session
   * store. Returns the response payload, or null on failure (the error
   * message is stored in `error`). The cap is only sent when `nextCap` is
   * given, so the backend can distinguish "unchanged" from an explicit value.
   */
  async function patchPolicy(nextEnabled: boolean, nextCap?: number) {
    if (!activeSessionId) return null;
    setBusy(true);
    setError(null);
    try {
      const body: Record<string, unknown> = { enabled: nextEnabled };
      if (nextCap !== undefined) body.cost_cap_usd = nextCap;
      const response = await apiFetch(`/api/session/${activeSessionId}/yolo`, {
        method: 'PATCH',
        body: JSON.stringify(body),
      });
      if (!response.ok) {
        throw new Error(await response.text());
      }
      const data = await response.json();
      updateSessionYolo(activeSessionId, data);
      return data;
    } catch {
      setError('Could not update YOLO settings.');
      return null;
    } finally {
      setBusy(false);
    }
  }

  const handleToggle = async () => {
    if (disabled) return;
    if (enabled) {
      await patchPolicy(false);
      return;
    }
    const nextCap = cap ?? DEFAULT_CAP_USD;
    const updated = await patchPolicy(true, nextCap);
    if (updated) {
      setCapInput(String(updated.cost_cap_usd ?? nextCap));
      setDialogOpen(true);
    }
  };

  const handleSaveCap = async () => {
    const raw = capInput.trim();
    const parsed = Number(raw);
    // Number('') is 0, so a blank field must be rejected explicitly or it
    // would silently save a $0 cap.
    if (raw === '' || !Number.isFinite(parsed) || parsed < 0) {
      setError('Enter a non-negative dollar amount.');
      return;
    }
    const updated = await patchPolicy(true, parsed);
    if (updated) setDialogOpen(false);
  };

  // With the dialog closed there is no field to show an error in, so surface
  // a failed toggle through the tooltip instead.
  const tooltipTitle = !dialogOpen && error
    ? error
    : enabled
      ? 'Disable session YOLO auto-approval'
      : 'Enable session YOLO auto-approval';

  return (
    <>
      <Tooltip title={tooltipTitle}>
        <span>
          <Button
            size="small"
            variant={enabled ? 'contained' : 'outlined'}
            disabled={disabled}
            onClick={handleToggle}
            startIcon={<BoltOutlinedIcon sx={{ fontSize: 16 }} />}
            sx={{
              minWidth: { xs: 74, md: 116 },
              height: 32,
              px: { xs: 1, md: 1.25 },
              borderRadius: '8px',
              textTransform: 'none',
              fontSize: '0.72rem',
              whiteSpace: 'nowrap',
              bgcolor: enabled ? 'var(--accent-yellow)' : 'transparent',
              color: enabled ? '#111' : 'text.secondary',
              borderColor: enabled ? 'var(--accent-yellow)' : 'divider',
              '&:hover': {
                bgcolor: enabled ? 'var(--accent-yellow)' : 'action.hover',
                borderColor: 'var(--accent-yellow)',
              },
            }}
          >
            {enabled ? `YOLO ${money(remaining)}` : 'YOLO'}
          </Button>
        </span>
      </Tooltip>

      <Dialog open={dialogOpen} onClose={closeDialog} maxWidth="xs" fullWidth>
        <DialogTitle sx={{ pb: 1 }}>YOLO Budget</DialogTitle>
        <DialogContent sx={{ display: 'flex', flexDirection: 'column', gap: 1.5, pt: 1 }}>
          <Typography variant="body2" color="text.secondary">
            Auto-approval is active for this session. Scheduled HF jobs still require approval.
          </Typography>
          <TextField
            autoFocus
            label="Session cap (USD)"
            type="number"
            size="small"
            value={capInput}
            onChange={(e) => setCapInput(e.target.value)}
            inputProps={{ min: 0, step: 0.5 }}
            error={Boolean(error)}
            helperText={error || `Estimated spend: ${money(activeSession?.autoApprovalEstimatedSpendUsd ?? 0)} of ${money(cap)}`}
          />
        </DialogContent>
        <DialogActions>
          <Button onClick={closeDialog} sx={{ textTransform: 'none' }}>
            Close
          </Button>
          <Button onClick={handleSaveCap} disabled={busy} variant="contained" sx={{ textTransform: 'none' }}>
            Save
          </Button>
        </DialogActions>
      </Dialog>
    </>
  );
}
frontend/src/hooks/useAgentChat.ts CHANGED
@@ -36,7 +36,7 @@ export function useAgentChat({ sessionId, isActive, onReady, onError, onSessionD
36
  const isActiveRef = useRef(isActive);
37
  isActiveRef.current = isActive;
38
 
39
- const { setNeedsAttention } = useSessionStore();
40
 
41
  // Helper: update this session's state (mirrors to globals if active)
42
  const updateSession = useAgentStore.getState().updateSession;
@@ -186,6 +186,20 @@ export function useAgentChat({ sessionId, isActive, onReady, onError, onSessionD
186
  if (!tools.length) return;
187
  setNeedsAttention(sessionId, true);
188
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
  updateSession(sessionId, { activityStatus: { type: 'waiting-approval' } });
190
 
191
  // Build panel data for this session's pending approval
@@ -480,6 +494,9 @@ export function useAgentChat({ sessionId, isActive, onReady, onError, onSessionD
480
  );
481
  if (pendingIds.size > 0) setNeedsAttention(sessionId, true);
482
  }
 
 
 
483
  return { data, pendingIds, info };
484
  }
485
  return { data, pendingIds, info: null };
@@ -562,7 +579,15 @@ export function useAgentChat({ sessionId, isActive, onReady, onError, onSessionD
562
  return true;
563
  } else if (et === 'approval_required') {
564
  sideChannel.onApprovalRequired(
565
- (event.data?.tools || []) as Array<{ tool: string; arguments: Record<string, unknown>; tool_call_id: string }>,
 
 
 
 
 
 
 
 
566
  );
567
  stopReconnect();
568
  const result = await hydrateMessages();
 
36
  const isActiveRef = useRef(isActive);
37
  isActiveRef.current = isActive;
38
 
39
+ const { setNeedsAttention, updateSessionYolo } = useSessionStore();
40
 
41
  // Helper: update this session's state (mirrors to globals if active)
42
  const updateSession = useAgentStore.getState().updateSession;
 
186
  if (!tools.length) return;
187
  setNeedsAttention(sessionId, true);
188
 
189
+ const store = useAgentStore.getState();
190
+ for (const tool of tools) {
191
+ store.setToolBudgetBlock(
192
+ tool.tool_call_id,
193
+ tool.auto_approval_blocked
194
+ ? {
195
+ reason: tool.block_reason ?? null,
196
+ estimatedCostUsd: tool.estimated_cost_usd ?? null,
197
+ remainingCapUsd: tool.remaining_cap_usd ?? null,
198
+ }
199
+ : null,
200
+ );
201
+ }
202
+
203
  updateSession(sessionId, { activityStatus: { type: 'waiting-approval' } });
204
 
205
  // Build panel data for this session's pending approval
 
494
  );
495
  if (pendingIds.size > 0) setNeedsAttention(sessionId, true);
496
  }
497
+ if (info.auto_approval) {
498
+ updateSessionYolo(sessionId, info.auto_approval);
499
+ }
500
  return { data, pendingIds, info };
501
  }
502
  return { data, pendingIds, info: null };
 
579
  return true;
580
  } else if (et === 'approval_required') {
581
  sideChannel.onApprovalRequired(
582
+ (event.data?.tools || []) as Array<{
583
+ tool: string;
584
+ arguments: Record<string, unknown>;
585
+ tool_call_id: string;
586
+ auto_approval_blocked?: boolean;
587
+ block_reason?: string | null;
588
+ estimated_cost_usd?: number | null;
589
+ remaining_cap_usd?: number | null;
590
+ }>,
591
  );
592
  stopReconnect();
593
  const result = await hydrateMessages();
frontend/src/lib/sse-chat-transport.ts CHANGED
@@ -26,7 +26,15 @@ export interface SideChannelCallbacks {
26
  onToolLog: (tool: string, log: string, agentId?: string, label?: string) => void;
27
  onConnectionChange: (connected: boolean) => void;
28
  onSessionDead: (sessionId: string) => void;
29
- onApprovalRequired: (tools: Array<{ tool: string; arguments: Record<string, unknown>; tool_call_id: string }>) => void;
 
 
 
 
 
 
 
 
30
  onToolCallPanel: (tool: string, args: Record<string, unknown>) => void;
31
  onToolOutputPanel: (tool: string, toolCallId: string, output: string, success: boolean) => void;
32
  onStreaming: () => void;
@@ -236,6 +244,10 @@ function createEventToChunkStream(sideChannel: SideChannelCallbacks): TransformS
236
  tool: string;
237
  arguments: Record<string, unknown>;
238
  tool_call_id: string;
 
 
 
 
239
  }>;
240
  if (!tools) break;
241
 
 
26
  onToolLog: (tool: string, log: string, agentId?: string, label?: string) => void;
27
  onConnectionChange: (connected: boolean) => void;
28
  onSessionDead: (sessionId: string) => void;
29
+ onApprovalRequired: (tools: Array<{
30
+ tool: string;
31
+ arguments: Record<string, unknown>;
32
+ tool_call_id: string;
33
+ auto_approval_blocked?: boolean;
34
+ block_reason?: string | null;
35
+ estimated_cost_usd?: number | null;
36
+ remaining_cap_usd?: number | null;
37
+ }>) => void;
38
  onToolCallPanel: (tool: string, args: Record<string, unknown>) => void;
39
  onToolOutputPanel: (tool: string, toolCallId: string, output: string, success: boolean) => void;
40
  onStreaming: () => void;
 
244
  tool: string;
245
  arguments: Record<string, unknown>;
246
  tool_call_id: string;
247
+ auto_approval_blocked?: boolean;
248
+ block_reason?: string | null;
249
+ estimated_cost_usd?: number | null;
250
+ remaining_cap_usd?: number | null;
251
  }>;
252
  if (!tools) break;
253
 
frontend/src/store/agentStore.ts CHANGED
@@ -50,6 +50,12 @@ export interface JobsUpgradeState {
50
  namespace?: string | null;
51
  }
52
 
 
 
 
 
 
 
53
  export type ActivityStatus =
54
  | { type: 'idle' }
55
  | { type: 'thinking' }
@@ -145,6 +151,9 @@ interface AgentStore {
145
  // Tool rejected states (tool_call_id -> true if rejected by user) - persisted across renders
146
  rejectedTools: Record<string, boolean>;
147
 
 
 
 
148
  // ── Per-session actions ─────────────────────────────────────────────
149
 
150
  /** Update a session's state. If it's the active session, also update flat state. */
@@ -196,6 +205,9 @@ interface AgentStore {
196
 
197
  setToolRejected: (toolCallId: string, isRejected: boolean) => void;
198
  getToolRejected: (toolCallId: string) => boolean | undefined;
 
 
 
199
  }
200
 
201
  /**
@@ -300,6 +312,7 @@ export const useAgentStore = create<AgentStore>()((set, get) => ({
300
  trackioDashboards: loadTrackioDashboards(),
301
  toolErrors: loadToolErrors(),
302
  rejectedTools: loadRejectedTools(),
 
303
 
304
  // ── Per-session state management ──────────────────────────────────
305
 
@@ -529,4 +542,24 @@ export const useAgentStore = create<AgentStore>()((set, get) => ({
529
  },
530
 
531
  getToolRejected: (toolCallId) => get().rejectedTools[toolCallId],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
532
  }));
 
50
  namespace?: string | null;
51
  }
52
 
53
/**
 * Display metadata attached to a tool call (keyed by tool_call_id in the
 * agent store's `budgetBlocks` map). Every field is optional and nullable.
 *
 * NOTE(review): field names mirror the `block_reason`, `estimated_cost_usd`
 * and `remaining_cap_usd` properties of the approval event payload — confirm
 * the mapping at the call site that populates this state.
 */
export interface ToolBudgetBlockState {
  /** Human-readable explanation of the block, when one is available. */
  reason?: string | null;
  /** Estimated cost of the tool call in USD, when known. */
  estimatedCostUsd?: number | null;
  /** Remaining cap in USD, when a cap applies. */
  remainingCapUsd?: number | null;
}
58
+
59
  export type ActivityStatus =
60
  | { type: 'idle' }
61
  | { type: 'thinking' }
 
151
  // Tool rejected states (tool_call_id -> true if rejected by user) - persisted across renders
152
  rejectedTools: Record<string, boolean>;
153
 
154
+ // Tool budget-block metadata (tool_call_id -> display metadata) - transient UI state
155
+ budgetBlocks: Record<string, ToolBudgetBlockState>;
156
+
157
  // ── Per-session actions ─────────────────────────────────────────────
158
 
159
  /** Update a session's state. If it's the active session, also update flat state. */
 
205
 
206
  setToolRejected: (toolCallId: string, isRejected: boolean) => void;
207
  getToolRejected: (toolCallId: string) => boolean | undefined;
208
+
209
+ setToolBudgetBlock: (toolCallId: string, block: ToolBudgetBlockState | null) => void;
210
+ getToolBudgetBlock: (toolCallId: string) => ToolBudgetBlockState | undefined;
211
  }
212
 
213
  /**
 
312
  trackioDashboards: loadTrackioDashboards(),
313
  toolErrors: loadToolErrors(),
314
  rejectedTools: loadRejectedTools(),
315
+ budgetBlocks: {},
316
 
317
  // ── Per-session state management ──────────────────────────────────
318
 
 
542
  },
543
 
544
  getToolRejected: (toolCallId) => get().rejectedTools[toolCallId],
545
+
546
+ // ── Tool Budget Blocks ───────────────────────────────────────────────
547
+
548
+ setToolBudgetBlock: (toolCallId, block) => {
549
+ set((state) => {
550
+ if (!block) {
551
+ const next = { ...state.budgetBlocks };
552
+ delete next[toolCallId];
553
+ return { budgetBlocks: next };
554
+ }
555
+ return {
556
+ budgetBlocks: {
557
+ ...state.budgetBlocks,
558
+ [toolCallId]: block,
559
+ },
560
+ };
561
+ });
562
+ },
563
+
564
+ getToolBudgetBlock: (toolCallId) => get().budgetBlocks[toolCallId],
565
  }));
frontend/src/store/sessionStore.ts CHANGED
@@ -27,7 +27,19 @@ interface SessionStore {
27
  created_at: string;
28
  is_active?: boolean;
29
  pending_approval?: unknown[] | null;
 
 
 
 
 
 
30
  }>) => void;
 
 
 
 
 
 
31
  /** Atomically swap a session's id in the list + both localStorage caches.
32
  * Used when we rehydrate an expired session into a freshly-created backend
33
  * session — preserves title, timestamps, and messages. */
@@ -47,6 +59,10 @@ export const useSessionStore = create<SessionStore>()(
47
  createdAt: new Date().toISOString(),
48
  isActive: true,
49
  needsAttention: false,
 
 
 
 
50
  };
51
  set((state) => ({
52
  sessions: [...state.sessions, newSession],
@@ -93,12 +109,21 @@ export const useSessionStore = create<SessionStore>()(
93
  if (!id) continue;
94
  const existing = byId.get(id);
95
  if (existing) {
 
96
  const updated = {
97
  ...existing,
98
  title: server.title || existing.title,
99
  isActive: server.is_active ?? existing.isActive,
100
  needsAttention: Boolean(server.pending_approval?.length) || existing.needsAttention,
101
  expired: false,
 
 
 
 
 
 
 
 
102
  };
103
  const idx = merged.findIndex((s) => s.id === id);
104
  if (idx >= 0) merged[idx] = updated;
@@ -112,6 +137,10 @@ export const useSessionStore = create<SessionStore>()(
112
  isActive: server.is_active ?? true,
113
  needsAttention: Boolean(server.pending_approval?.length),
114
  expired: false,
 
 
 
 
115
  };
116
  merged.push(newSession);
117
  byId.set(id, newSession);
@@ -123,6 +152,22 @@ export const useSessionStore = create<SessionStore>()(
123
  });
124
  },
125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  renameSession: (oldId: string, newId: string) => {
127
  if (oldId === newId) return;
128
  moveMessages(oldId, newId);
 
27
  created_at: string;
28
  is_active?: boolean;
29
  pending_approval?: unknown[] | null;
30
+ auto_approval?: {
31
+ enabled?: boolean;
32
+ cost_cap_usd?: number | null;
33
+ estimated_spend_usd?: number;
34
+ remaining_usd?: number | null;
35
+ } | null;
36
  }>) => void;
37
+ updateSessionYolo: (id: string, policy: {
38
+ enabled: boolean;
39
+ cost_cap_usd?: number | null;
40
+ estimated_spend_usd?: number;
41
+ remaining_usd?: number | null;
42
+ }) => void;
43
  /** Atomically swap a session's id in the list + both localStorage caches.
44
  * Used when we rehydrate an expired session into a freshly-created backend
45
  * session — preserves title, timestamps, and messages. */
 
59
  createdAt: new Date().toISOString(),
60
  isActive: true,
61
  needsAttention: false,
62
+ autoApprovalEnabled: false,
63
+ autoApprovalCostCapUsd: null,
64
+ autoApprovalEstimatedSpendUsd: 0,
65
+ autoApprovalRemainingUsd: null,
66
  };
67
  set((state) => ({
68
  sessions: [...state.sessions, newSession],
 
109
  if (!id) continue;
110
  const existing = byId.get(id);
111
  if (existing) {
112
+ const auto = server.auto_approval;
113
  const updated = {
114
  ...existing,
115
  title: server.title || existing.title,
116
  isActive: server.is_active ?? existing.isActive,
117
  needsAttention: Boolean(server.pending_approval?.length) || existing.needsAttention,
118
  expired: false,
119
+ ...(auto
120
+ ? {
121
+ autoApprovalEnabled: Boolean(auto.enabled),
122
+ autoApprovalCostCapUsd: auto.cost_cap_usd ?? null,
123
+ autoApprovalEstimatedSpendUsd: auto.estimated_spend_usd ?? 0,
124
+ autoApprovalRemainingUsd: auto.remaining_usd ?? null,
125
+ }
126
+ : {}),
127
  };
128
  const idx = merged.findIndex((s) => s.id === id);
129
  if (idx >= 0) merged[idx] = updated;
 
137
  isActive: server.is_active ?? true,
138
  needsAttention: Boolean(server.pending_approval?.length),
139
  expired: false,
140
+ autoApprovalEnabled: Boolean(server.auto_approval?.enabled),
141
+ autoApprovalCostCapUsd: server.auto_approval?.cost_cap_usd ?? null,
142
+ autoApprovalEstimatedSpendUsd: server.auto_approval?.estimated_spend_usd ?? 0,
143
+ autoApprovalRemainingUsd: server.auto_approval?.remaining_usd ?? null,
144
  };
145
  merged.push(newSession);
146
  byId.set(id, newSession);
 
152
  });
153
  },
154
 
155
+ updateSessionYolo: (id, policy) => {
156
+ set((state) => ({
157
+ sessions: state.sessions.map((s) =>
158
+ s.id === id
159
+ ? {
160
+ ...s,
161
+ autoApprovalEnabled: policy.enabled,
162
+ autoApprovalCostCapUsd: policy.cost_cap_usd ?? null,
163
+ autoApprovalEstimatedSpendUsd: policy.estimated_spend_usd ?? 0,
164
+ autoApprovalRemainingUsd: policy.remaining_usd ?? null,
165
+ }
166
+ : s,
167
+ ),
168
+ }));
169
+ },
170
+
171
  renameSession: (oldId: string, newId: string) => {
172
  if (oldId === newId) return;
173
  moveMessages(oldId, newId);
frontend/src/types/agent.ts CHANGED
@@ -21,6 +21,10 @@ export interface SessionMeta {
21
  * disables input until the user chooses to restore-with-summary or
22
  * start fresh. */
23
  expired?: boolean;
 
 
 
 
24
  }
25
 
26
  export interface ToolApproval {
 
21
  * disables input until the user chooses to restore-with-summary or
22
  * start fresh. */
23
  expired?: boolean;
24
+ autoApprovalEnabled?: boolean;
25
+ autoApprovalCostCapUsd?: number | null;
26
+ autoApprovalEstimatedSpendUsd?: number;
27
+ autoApprovalRemainingUsd?: number | null;
28
  }
29
 
30
  export interface ToolApproval {
frontend/src/types/events.ts CHANGED
@@ -68,6 +68,10 @@ export interface ApprovalToolItem {
68
  tool: string;
69
  arguments: Record<string, unknown>;
70
  tool_call_id: string;
 
 
 
 
71
  }
72
 
73
  export interface TurnCompleteEventData {
 
68
  tool: string;
69
  arguments: Record<string, unknown>;
70
  tool_call_id: string;
71
+ auto_approval_blocked?: boolean;
72
+ block_reason?: string | null;
73
+ estimated_cost_usd?: number | null;
74
+ remaining_cap_usd?: number | null;
75
  }
76
 
77
  export interface TurnCompleteEventData {
tests/unit/test_agent_model_gating.py CHANGED
@@ -127,3 +127,48 @@ async def test_user_quota_response_uses_premium_fields_only(monkeypatch):
127
  "premium_daily_cap": 5,
128
  "premium_remaining": 3,
129
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  "premium_daily_cap": 5,
128
  "premium_remaining": 3,
129
  }
130
+
131
+
132
@pytest.mark.asyncio
async def test_set_session_yolo_calls_manager_with_cap_presence(monkeypatch):
    """set_session_yolo forwards enabled, cap and cap_provided to the manager."""
    manager_calls = []

    async def fake_check_session_access(session_id, user, request=None):
        # The route must check access for the same session and user it was given.
        assert session_id == "s1"
        assert user["user_id"] == "u1"
        return object()

    async def fake_update_session_auto_approval(session_id, **kwargs):
        manager_calls.append((session_id, kwargs))
        return {
            "enabled": kwargs["enabled"],
            "cost_cap_usd": 7.5,
            "estimated_spend_usd": 0.0,
            "remaining_usd": 7.5,
        }

    monkeypatch.setattr(agent, "_check_session_access", fake_check_session_access)
    monkeypatch.setattr(
        agent.session_manager,
        "update_session_auto_approval",
        fake_update_session_auto_approval,
    )

    body = agent.SessionYoloRequest(enabled=True, cost_cap_usd=7.5)
    response = await agent.set_session_yolo("s1", body, {"user_id": "u1"})

    assert response["enabled"] is True
    assert response["remaining_usd"] == 7.5

    expected_kwargs = {
        "enabled": True,
        "cost_cap_usd": 7.5,
        "cap_provided": True,
    }
    assert manager_calls == [("s1", expected_kwargs)]
tests/unit/test_auto_approval_policy.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from types import SimpleNamespace
2
+
3
+ import pytest
4
+
5
+ from agent.config import Config
6
+ from agent.core import agent_loop
7
+ from agent.core.cost_estimation import CostEstimate
8
+
9
+
10
def _config(**overrides):
    """Build a validated Config from test defaults, with *overrides* winning."""
    defaults = {
        "model_name": "moonshotai/Kimi-K2.6",
        "confirm_cpu_jobs": True,
        "auto_file_upload": False,
        "yolo_mode": False,
    }
    return Config.model_validate({**defaults, **overrides})
19
+
20
+
21
def _session(*, cap=5.0, spent=0.0, enabled=True):
    """Return a minimal stand-in session carrying only auto-approval state.

    The object has a default test config, the auto-approval flag, cap and
    estimated spend, and no sandbox.
    """
    attributes = {
        "config": _config(),
        "auto_approval_enabled": enabled,
        "auto_approval_cost_cap_usd": cap,
        "auto_approval_estimated_spend_usd": spent,
        "sandbox": None,
    }
    return SimpleNamespace(**attributes)
29
+
30
+
31
@pytest.mark.asyncio
async def test_session_yolo_auto_approves_non_costed_approval_tool():
    """With session auto-approval on, an hf_repo_files upload is auto-approved."""
    tool_args = {"operation": "upload", "path": "README.md"}

    outcome = await agent_loop._approval_decision(
        "hf_repo_files", tool_args, _session()
    )

    assert outcome.requires_approval is False
    assert outcome.auto_approved is True
41
+
42
+
43
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "operation",
    # NOTE(review): the list previously repeated "scheduled run" as a third
    # entry, which only re-ran an identical case. If a differently spelled
    # variant (e.g. other casing or spacing) was intended, add it explicitly.
    ["scheduled run", "scheduled uv"],
)
async def test_scheduled_hf_jobs_always_require_manual_approval(operation):
    """Scheduled hf_jobs operations need manual approval even in yolo mode.

    Both the session-level auto-approval (enabled by ``_session()``) and the
    config-level ``yolo_mode`` flag are switched on, and the decision must
    still be blocked with a reason mentioning scheduled jobs.
    """
    session = _session()
    session.config.yolo_mode = True

    decision = await agent_loop._approval_decision(
        "hf_jobs",
        {"operation": operation, "script": "print(1)"},
        session,
    )

    assert decision.requires_approval is True
    assert decision.auto_approval_blocked is True
    assert "Scheduled HF jobs" in decision.block_reason
    # The synchronous helper must agree with the async decision.
    assert agent_loop._needs_approval(
        "hf_jobs", {"operation": operation}, session.config
    )
62
+
63
+
64
@pytest.mark.asyncio
async def test_immediate_hf_job_under_cap_auto_runs(monkeypatch):
    """A $2 job with $1 of a $5 cap already spent is auto-approved."""

    async def stub_estimate(*_args, **_kwargs):
        return CostEstimate(estimated_cost_usd=2.0, billable=True)

    monkeypatch.setattr(agent_loop, "estimate_tool_cost", stub_estimate)

    job_args = {"operation": "run", "hardware_flavor": "a10g-large", "timeout": "1h"}
    outcome = await agent_loop._approval_decision(
        "hf_jobs", job_args, _session(cap=5.0, spent=1.0)
    )

    assert outcome.requires_approval is False
    assert outcome.auto_approved is True
    assert outcome.estimated_cost_usd == 2.0
80
+
81
+
82
@pytest.mark.asyncio
async def test_immediate_hf_job_over_cap_falls_back_to_approval(monkeypatch):
    """A $2 job with only $1 of the cap left must fall back to manual approval."""

    async def stub_estimate(*_args, **_kwargs):
        return CostEstimate(estimated_cost_usd=2.0, billable=True)

    monkeypatch.setattr(agent_loop, "estimate_tool_cost", stub_estimate)

    job_args = {"operation": "run", "hardware_flavor": "a10g-large", "timeout": "1h"}
    outcome = await agent_loop._approval_decision(
        "hf_jobs", job_args, _session(cap=5.0, spent=4.0)
    )

    assert outcome.requires_approval is True
    assert outcome.auto_approval_blocked is True
    assert "exceeds" in outcome.block_reason
    assert outcome.remaining_cap_usd == 1.0
99
+
100
+
101
@pytest.mark.asyncio
async def test_unknown_cost_falls_back_to_approval(monkeypatch):
    """A billable tool with no price estimate is blocked from auto-approval."""
    unpriced = CostEstimate(
        estimated_cost_usd=None,
        billable=True,
        block_reason="No price is available.",
    )

    async def stub_estimate(*_args, **_kwargs):
        return unpriced

    monkeypatch.setattr(agent_loop, "estimate_tool_cost", stub_estimate)

    outcome = await agent_loop._approval_decision(
        "sandbox_create", {"hardware": "mystery-gpu"}, _session()
    )

    assert outcome.requires_approval is True
    assert outcome.auto_approval_blocked is True
    assert outcome.estimated_cost_usd is None
121
+
122
+
123
@pytest.mark.asyncio
async def test_batch_reservation_blocks_second_over_budget_job(monkeypatch):
    """Spend reserved by an earlier call in a batch counts against the cap.

    Two $3 jobs against a $5 cap: the first is auto-approved; the second,
    evaluated with the first job's cost passed as reserved spend, needs
    approval and reports $2 remaining.
    """

    async def stub_estimate(*_args, **_kwargs):
        return CostEstimate(estimated_cost_usd=3.0, billable=True)

    monkeypatch.setattr(agent_loop, "estimate_tool_cost", stub_estimate)
    session = _session(cap=5.0, spent=0.0)

    first = await agent_loop._approval_decision(
        "hf_jobs",
        {"operation": "run", "hardware_flavor": "a10g-large"},
        session,
        reserved_spend_usd=0.0,
    )
    already_reserved = first.estimated_cost_usd or 0.0
    second = await agent_loop._approval_decision(
        "hf_jobs",
        {"operation": "run", "hardware_flavor": "a10g-large"},
        session,
        reserved_spend_usd=already_reserved,
    )

    assert first.requires_approval is False
    assert second.requires_approval is True
    assert second.remaining_cap_usd == 2.0
147
+
148
+
149
@pytest.mark.asyncio
async def test_manual_approval_does_not_record_spend_when_session_yolo_disabled(monkeypatch):
    """With session auto-approval off, no estimate is requested and spend stays 0."""
    estimator_called = False

    async def stub_estimate(*_args, **_kwargs):
        nonlocal estimator_called
        estimator_called = True
        return CostEstimate(estimated_cost_usd=2.0, billable=True)

    monkeypatch.setattr(agent_loop, "estimate_tool_cost", stub_estimate)
    session = _session(enabled=False, cap=5.0, spent=0.0)

    await agent_loop._record_manual_approved_spend_if_needed(
        session, "sandbox_create", {"hardware": "a10g-large"}
    )

    assert estimator_called is False
    assert session.auto_approval_estimated_spend_usd == 0.0
169
+
170
+
171
@pytest.mark.asyncio
async def test_manual_approval_records_spend_when_session_yolo_enabled(monkeypatch):
    """With session auto-approval on, the estimate is added to recorded spend."""

    async def stub_estimate(*_args, **_kwargs):
        return CostEstimate(estimated_cost_usd=1.25, billable=True)

    monkeypatch.setattr(agent_loop, "estimate_tool_cost", stub_estimate)
    session = _session(enabled=True, cap=5.0, spent=0.5)

    await agent_loop._record_manual_approved_spend_if_needed(
        session, "sandbox_create", {"hardware": "a10g-large"}
    )

    # 0.5 already spent + 1.25 estimated.
    assert session.auto_approval_estimated_spend_usd == 1.75
tests/unit/test_cost_estimation.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from types import SimpleNamespace
2
+
3
+ import pytest
4
+
5
+ from agent.core import cost_estimation
6
+
7
+
8
def test_parse_timeout_hours_common_units():
    """parse_timeout_hours handles None, minute/hour suffixes and bare numbers.

    An unparseable string must yield None rather than a number.
    """
    parse = cost_estimation.parse_timeout_hours
    expected_hours = [
        (None, 0.5),
        ("30m", 0.5),
        ("3h", 3),
        (3600, 1),
    ]
    for raw_timeout, hours in expected_hours:
        assert parse(raw_timeout) == hours
    assert parse("not-a-duration") is None
14
+
15
+
16
@pytest.mark.asyncio
async def test_estimate_hf_job_cost_uses_catalog_price(monkeypatch):
    """A catalog price of 4.0 and an 8h timeout yield a 32.0 USD estimate."""

    async def stub_catalog():
        return {"a100-large": 4.0}

    monkeypatch.setattr(cost_estimation, "hf_jobs_price_catalog", stub_catalog)

    job_args = {"hardware_flavor": "a100-large", "timeout": "8h"}
    result = await cost_estimation.estimate_hf_job_cost(job_args)

    assert result.estimated_cost_usd == 32.0
    assert result.billable is True
29
+
30
+
31
@pytest.mark.asyncio
async def test_estimate_hf_job_cost_blocks_unknown_price(monkeypatch):
    """A flavor missing from the catalog gives no cost and a 'No price' reason."""

    async def stub_catalog():
        return {}

    monkeypatch.setattr(cost_estimation, "hf_jobs_price_catalog", stub_catalog)

    job_args = {"hardware_flavor": "mystery-gpu", "timeout": "30m"}
    result = await cost_estimation.estimate_hf_job_cost(job_args)

    assert result.estimated_cost_usd is None
    assert result.billable is True
    assert "No price" in result.block_reason
45
+
46
+
47
@pytest.mark.asyncio
async def test_estimate_sandbox_cost_is_zero_for_existing_or_cpu_basic():
    """Sandbox cost is zero and non-billable in two cases.

    Covered: a session that already has a sandbox, and a request for
    ``cpu-basic`` hardware.
    """
    session_with_sandbox = SimpleNamespace(sandbox=object())

    existing = await cost_estimation.estimate_sandbox_cost(
        {"hardware": "a100-large"},
        session=session_with_sandbox,
    )
    cpu = await cost_estimation.estimate_sandbox_cost({"hardware": "cpu-basic"})

    for estimate in (existing, cpu):
        assert estimate.estimated_cost_usd == 0.0
        assert estimate.billable is False
tests/unit/test_session_manager_persistence.py CHANGED
@@ -27,6 +27,23 @@ class FakeRuntimeSession:
27
  self.turn_count = 0
28
  self.config = SimpleNamespace(model_name=model)
29
  self.notification_destinations = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
 
32
  class RestoreStore(NoopSessionStore):
@@ -85,6 +102,24 @@ def _runtime_agent_session(
85
  )
86
 
87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  def _install_fake_runtime(manager: SessionManager) -> asyncio.Event:
89
  stop = asyncio.Event()
90
  manager.run_calls = 0 # type: ignore[attr-defined]
@@ -204,6 +239,34 @@ async def test_lazy_restore_preserves_pending_approval_tool_calls():
204
  await _cancel_runtime_tasks(manager)
205
 
206
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
  @pytest.mark.asyncio
208
  async def test_list_sessions_dev_uses_store_dev_visibility():
209
  class ListStore(NoopSessionStore):
@@ -221,6 +284,9 @@ async def test_list_sessions_dev_uses_store_dev_visibility():
221
  "user_id": "alice",
222
  "model": "m",
223
  "created_at": datetime.now(UTC),
 
 
 
224
  },
225
  {
226
  "session_id": "s2",
@@ -238,3 +304,10 @@ async def test_list_sessions_dev_uses_store_dev_visibility():
238
 
239
  assert store.seen_user_id == "dev"
240
  assert {session["session_id"] for session in sessions} == {"s1", "s2"}
 
 
 
 
 
 
 
 
27
  self.turn_count = 0
28
  self.config = SimpleNamespace(model_name=model)
29
  self.notification_destinations = []
30
+ self.auto_approval_enabled = False
31
+ self.auto_approval_cost_cap_usd = None
32
+ self.auto_approval_estimated_spend_usd = 0.0
33
+
34
def auto_approval_policy_summary(self):
    """Return the auto-approval policy as a plain dict.

    ``remaining_usd`` is None when no cap is set; otherwise it is the cap
    minus the estimated spend, floored at zero.
    """
    limit = self.auto_approval_cost_cap_usd
    if limit is None:
        left = None
    else:
        left = max(0, limit - self.auto_approval_estimated_spend_usd)
    return {
        "enabled": self.auto_approval_enabled,
        "cost_cap_usd": limit,
        "estimated_spend_usd": self.auto_approval_estimated_spend_usd,
        "remaining_usd": left,
    }
43
+
44
def set_auto_approval_policy(self, *, enabled, cost_cap_usd):
    """Store the auto-approval flag and cost cap on this fake session."""
    self.auto_approval_enabled, self.auto_approval_cost_cap_usd = (
        enabled,
        cost_cap_usd,
    )
47
 
48
 
49
  class RestoreStore(NoopSessionStore):
 
102
  )
103
 
104
 
105
@pytest.mark.asyncio
async def test_update_session_auto_approval_defaults_to_five_dollars():
    """Enabling auto-approval without a provided cap yields a 5.0 USD cap."""
    manager = _manager_with_store(NoopSessionStore())
    manager.sessions["s1"] = _runtime_agent_session("s1", user_id="owner")

    policy = await manager.update_session_auto_approval(
        "s1",
        enabled=True,
        cost_cap_usd=None,
        cap_provided=False,
    )

    assert policy["enabled"] is True
    assert policy["cost_cap_usd"] == 5.0
    assert policy["remaining_usd"] == 5.0
121
+
122
+
123
  def _install_fake_runtime(manager: SessionManager) -> asyncio.Event:
124
  stop = asyncio.Event()
125
  manager.run_calls = 0 # type: ignore[attr-defined]
 
239
  await _cancel_runtime_tasks(manager)
240
 
241
 
242
@pytest.mark.asyncio
async def test_lazy_restore_preserves_auto_approval_policy():
    """A session restored from stored metadata keeps its auto-approval state.

    The enabled flag, cap and estimated spend must survive the restore, and
    the summary must report cap minus spend as remaining.
    """
    stored_metadata = {
        "session_id": "yolo-session",
        "user_id": "owner",
        "model": "test-model",
        "auto_approval_enabled": True,
        "auto_approval_cost_cap_usd": 5.0,
        "auto_approval_estimated_spend_usd": 1.25,
    }
    manager = _manager_with_store(RestoreStore(metadata=stored_metadata))
    stop = _install_fake_runtime(manager)

    try:
        restored = await manager.ensure_session_loaded("yolo-session", user_id="owner")

        assert restored is not None
        live = restored.session
        assert live.auto_approval_enabled is True
        assert live.auto_approval_cost_cap_usd == 5.0
        assert live.auto_approval_estimated_spend_usd == 1.25
        assert live.auto_approval_policy_summary()["remaining_usd"] == 3.75
    finally:
        # Always stop the fake runtime, even when an assertion fails.
        stop.set()
        await _cancel_runtime_tasks(manager)
268
+
269
+
270
  @pytest.mark.asyncio
271
  async def test_list_sessions_dev_uses_store_dev_visibility():
272
  class ListStore(NoopSessionStore):
 
284
  "user_id": "alice",
285
  "model": "m",
286
  "created_at": datetime.now(UTC),
287
+ "auto_approval_enabled": True,
288
+ "auto_approval_cost_cap_usd": 5.0,
289
+ "auto_approval_estimated_spend_usd": 2.0,
290
  },
291
  {
292
  "session_id": "s2",
 
304
 
305
  assert store.seen_user_id == "dev"
306
  assert {session["session_id"] for session in sessions} == {"s1", "s2"}
307
+ yolo = next(session for session in sessions if session["session_id"] == "s1")
308
+ assert yolo["auto_approval"] == {
309
+ "enabled": True,
310
+ "cost_cap_usd": 5.0,
311
+ "estimated_spend_usd": 2.0,
312
+ "remaining_usd": 3.0,
313
+ }