Deploy 2026-05-01
Browse files
Co-authored-by: OpenAI Codex <codex@openai.com>
- agent/core/agent_loop.py +244 -26
- agent/core/approval_policy.py +11 -0
- agent/core/cost_estimation.py +278 -0
- agent/core/session.py +37 -0
- agent/core/session_persistence.py +12 -0
- agent/main.py +28 -6
- agent/prompts/system_prompt_v3.yaml +1 -1
- agent/tools/docs_tools.py +1 -1
- backend/models.py +19 -0
- backend/routes/agent.py +21 -0
- backend/session_manager.py +94 -0
- frontend/src/components/Chat/ToolCallGroup.tsx +20 -1
- frontend/src/components/Layout/AppLayout.tsx +2 -0
- frontend/src/components/YoloControl.tsx +155 -0
- frontend/src/hooks/useAgentChat.ts +27 -2
- frontend/src/lib/sse-chat-transport.ts +13 -1
- frontend/src/store/agentStore.ts +33 -0
- frontend/src/store/sessionStore.ts +45 -0
- frontend/src/types/agent.ts +4 -0
- frontend/src/types/events.ts +4 -0
- tests/unit/test_agent_model_gating.py +45 -0
- tests/unit/test_auto_approval_policy.py +185 -0
- tests/unit/test_cost_estimation.py +58 -0
- tests/unit/test_session_manager_persistence.py +73 -0
agent/core/agent_loop.py
CHANGED
|
@@ -19,6 +19,11 @@ from litellm import (
|
|
| 19 |
from litellm.exceptions import ContextWindowExceededError
|
| 20 |
|
| 21 |
from agent.config import Config
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
from agent.messaging.gateway import NotificationGateway
|
| 23 |
from agent.core import telemetry
|
| 24 |
from agent.core.doom_loop import check_for_doom_loop
|
|
@@ -110,13 +115,39 @@ def _validate_tool_args(tool_args: dict) -> tuple[bool, str | None]:
|
|
| 110 |
return True, None
|
| 111 |
|
| 112 |
|
| 113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
tool_name: str, tool_args: dict, config: Config | None = None
|
| 115 |
) -> bool:
|
| 116 |
-
"""Check if a tool call requires
|
| 117 |
-
# Yolo mode: skip all approvals
|
| 118 |
-
if config and config.yolo_mode:
|
| 119 |
-
return False
|
| 120 |
|
| 121 |
# If args are malformed, skip approval (validation error will be shown later)
|
| 122 |
args_valid, _ = _validate_tool_args(tool_args)
|
|
@@ -127,8 +158,10 @@ def _needs_approval(
|
|
| 127 |
return True
|
| 128 |
|
| 129 |
if tool_name == "hf_jobs":
|
| 130 |
-
operation =
|
| 131 |
-
if operation
|
|
|
|
|
|
|
| 132 |
return False
|
| 133 |
|
| 134 |
# Check if this is a CPU-only job
|
|
@@ -180,6 +213,143 @@ def _needs_approval(
|
|
| 180 |
return False
|
| 181 |
|
| 182 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
# -- LLM retry constants --------------------------------------------------
|
| 184 |
_MAX_LLM_RETRIES = 3
|
| 185 |
_LLM_RETRY_DELAYS = [5, 15, 30] # seconds between retries
|
|
@@ -1063,29 +1233,49 @@ class Handlers:
|
|
| 1063 |
if session.is_cancelled:
|
| 1064 |
break
|
| 1065 |
|
| 1066 |
-
# Separate good tools into approval-required vs auto-execute
|
| 1067 |
-
|
| 1068 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1069 |
for tc, tool_name, tool_args in good_tools:
|
| 1070 |
-
|
| 1071 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1072 |
else:
|
| 1073 |
-
non_approval_tools.append((tc, tool_name, tool_args))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1074 |
|
| 1075 |
# Execute non-approval tools (in parallel when possible)
|
| 1076 |
if non_approval_tools:
|
| 1077 |
# 1. Validate args upfront
|
| 1078 |
parsed_tools: list[
|
| 1079 |
-
tuple[ToolCall, str, dict, bool, str]
|
| 1080 |
] = []
|
| 1081 |
-
for tc, tool_name, tool_args in non_approval_tools:
|
| 1082 |
args_valid, error_msg = _validate_tool_args(tool_args)
|
| 1083 |
parsed_tools.append(
|
| 1084 |
-
(tc, tool_name, tool_args, args_valid, error_msg)
|
| 1085 |
)
|
| 1086 |
|
| 1087 |
# 2. Send all tool_call events upfront (so frontend shows them all)
|
| 1088 |
-
for tc, tool_name, tool_args, args_valid, _ in parsed_tools:
|
| 1089 |
if args_valid:
|
| 1090 |
await session.send_event(
|
| 1091 |
Event(
|
|
@@ -1103,11 +1293,14 @@ class Handlers:
|
|
| 1103 |
tc: ToolCall,
|
| 1104 |
name: str,
|
| 1105 |
args: dict,
|
|
|
|
| 1106 |
valid: bool,
|
| 1107 |
err: str,
|
| 1108 |
) -> tuple[ToolCall, str, dict, str, bool]:
|
| 1109 |
if not valid:
|
| 1110 |
return (tc, name, args, err, False)
|
|
|
|
|
|
|
| 1111 |
out, ok = await session.tool_router.call_tool(
|
| 1112 |
name, args, session=session, tool_call_id=tc.id
|
| 1113 |
)
|
|
@@ -1115,8 +1308,8 @@ class Handlers:
|
|
| 1115 |
|
| 1116 |
gather_task = asyncio.ensure_future(asyncio.gather(
|
| 1117 |
*[
|
| 1118 |
-
_exec_tool(tc, name, args, valid, err)
|
| 1119 |
-
for tc, name, args, valid, err in parsed_tools
|
| 1120 |
]
|
| 1121 |
))
|
| 1122 |
cancel_task = asyncio.ensure_future(session._cancelled.wait())
|
|
@@ -1133,7 +1326,7 @@ class Handlers:
|
|
| 1133 |
except asyncio.CancelledError:
|
| 1134 |
pass
|
| 1135 |
# Notify frontend that in-flight tools were cancelled
|
| 1136 |
-
for tc, name, _args, valid, _ in parsed_tools:
|
| 1137 |
if valid:
|
| 1138 |
await session.send_event(Event(
|
| 1139 |
event_type="tool_state_change",
|
|
@@ -1171,7 +1364,8 @@ class Handlers:
|
|
| 1171 |
if approval_required_tools:
|
| 1172 |
# Prepare batch approval data
|
| 1173 |
tools_data = []
|
| 1174 |
-
|
|
|
|
| 1175 |
# Resolve sandbox file paths for hf_jobs scripts so the
|
| 1176 |
# frontend can display & edit the actual file content.
|
| 1177 |
if tool_name == "hf_jobs" and isinstance(tool_args.get("script"), str):
|
|
@@ -1181,20 +1375,42 @@ class Handlers:
|
|
| 1181 |
if resolved:
|
| 1182 |
tool_args = {**tool_args, "script": resolved}
|
| 1183 |
|
| 1184 |
-
|
| 1185 |
"tool": tool_name,
|
| 1186 |
"arguments": tool_args,
|
| 1187 |
"tool_call_id": tc.id,
|
| 1188 |
-
}
|
| 1189 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1190 |
await session.send_event(Event(
|
| 1191 |
event_type="approval_required",
|
| 1192 |
-
data=
|
| 1193 |
))
|
| 1194 |
|
| 1195 |
# Store all approval-requiring tools (ToolCall objects for execution)
|
| 1196 |
session.pending_approval = {
|
| 1197 |
-
"tool_calls": [tc for tc, _, _ in approval_required_tools],
|
| 1198 |
}
|
| 1199 |
|
| 1200 |
# Return early - wait for EXEC_APPROVAL operation
|
|
@@ -1384,6 +1600,8 @@ class Handlers:
|
|
| 1384 |
)
|
| 1385 |
)
|
| 1386 |
|
|
|
|
|
|
|
| 1387 |
output, success = await session.tool_router.call_tool(
|
| 1388 |
tool_name, tool_args, session=session, tool_call_id=tc.id
|
| 1389 |
)
|
|
|
|
| 19 |
from litellm.exceptions import ContextWindowExceededError
|
| 20 |
|
| 21 |
from agent.config import Config
|
| 22 |
+
from agent.core.approval_policy import (
|
| 23 |
+
is_scheduled_operation,
|
| 24 |
+
normalize_tool_operation,
|
| 25 |
+
)
|
| 26 |
+
from agent.core.cost_estimation import CostEstimate, estimate_tool_cost
|
| 27 |
from agent.messaging.gateway import NotificationGateway
|
| 28 |
from agent.core import telemetry
|
| 29 |
from agent.core.doom_loop import check_for_doom_loop
|
|
|
|
| 115 |
return True, None
|
| 116 |
|
| 117 |
|
| 118 |
+
# Normalized ``hf_jobs`` operation names that start a job right away. The
# approval helpers in this module compare ``_operation(tool_args)`` against
# this set.
_IMMEDIATE_HF_JOB_RUNS = {"run", "uv"}
|
| 119 |
+
|
| 120 |
+
@dataclass(frozen=True)
class ApprovalDecision:
    """Immutable outcome of the approval policy for one tool call."""

    # True when the call must wait for a manual approval.
    requires_approval: bool
    # True when the call would normally need approval but YOLO /
    # session auto-approval waived it.
    auto_approved: bool = False
    # True when auto-approval was active but could not be applied to this call.
    auto_approval_blocked: bool = False
    # Human-readable reason manual approval is required, when one is set.
    block_reason: str | None = None
    # Estimated cost in USD, when a cost estimate was computed for the call.
    estimated_cost_usd: float | None = None
    # Session cap left (USD) after recorded and reserved spend; None when no
    # cap applies or no budget check ran.
    remaining_cap_usd: float | None = None
    # Mirrors the cost estimator's ``billable`` flag for the call.
    billable: bool = False
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def _operation(tool_args: dict) -> str:
    """Return the normalized ``operation`` argument of a tool call."""
    raw_operation = tool_args.get("operation")
    return normalize_tool_operation(raw_operation)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def _is_immediate_hf_job_run(tool_name: str, tool_args: dict) -> bool:
    """Return True for an ``hf_jobs`` call whose operation runs immediately."""
    if tool_name != "hf_jobs":
        return False
    return _operation(tool_args) in _IMMEDIATE_HF_JOB_RUNS
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def _is_scheduled_hf_job_run(tool_name: str, tool_args: dict) -> bool:
    """Return True for an ``hf_jobs`` call whose operation is a scheduled one."""
    if tool_name != "hf_jobs":
        return False
    return is_scheduled_operation(_operation(tool_args))
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def _is_budgeted_auto_approval_target(tool_name: str, tool_args: dict) -> bool:
    """Return True for calls whose cost is checked against the session cap.

    That is ``sandbox_create`` and immediate ``hf_jobs`` runs.
    """
    if tool_name == "sandbox_create":
        return True
    return _is_immediate_hf_job_run(tool_name, tool_args)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def _base_needs_approval(
|
| 148 |
tool_name: str, tool_args: dict, config: Config | None = None
|
| 149 |
) -> bool:
|
| 150 |
+
"""Check if a tool call requires approval before YOLO policy is applied."""
|
|
|
|
|
|
|
|
|
|
| 151 |
|
| 152 |
# If args are malformed, skip approval (validation error will be shown later)
|
| 153 |
args_valid, _ = _validate_tool_args(tool_args)
|
|
|
|
| 158 |
return True
|
| 159 |
|
| 160 |
if tool_name == "hf_jobs":
|
| 161 |
+
operation = _operation(tool_args)
|
| 162 |
+
if is_scheduled_operation(operation):
|
| 163 |
+
return True
|
| 164 |
+
if operation not in _IMMEDIATE_HF_JOB_RUNS:
|
| 165 |
return False
|
| 166 |
|
| 167 |
# Check if this is a CPU-only job
|
|
|
|
| 213 |
return False
|
| 214 |
|
| 215 |
|
| 216 |
+
def _needs_approval(
    tool_name: str, tool_args: dict, config: Config | None = None
) -> bool:
    """Legacy sync approval predicate used by tests and CLI display helpers.

    Scheduled HF jobs always need approval. Otherwise ``config.yolo_mode``
    waives approval, and without it the base policy decides.
    """
    # Checked before YOLO so that yolo_mode can never waive a scheduled job.
    if _is_scheduled_hf_job_run(tool_name, tool_args):
        return True

    yolo_waives_approval = bool(config and config.yolo_mode)
    if yolo_waives_approval:
        return False
    return _base_needs_approval(tool_name, tool_args, config)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def _session_auto_approval_enabled(session: Session | None) -> bool:
    """Return True when the session exists and has auto-approval switched on."""
    if not session:
        return False
    # getattr with a default tolerates session objects lacking the attribute.
    return bool(getattr(session, "auto_approval_enabled", False))
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def _effective_yolo_enabled(session: Session | None, config: Config | None) -> bool:
    """Return True when either ``config.yolo_mode`` or session auto-approval is on."""
    if config and config.yolo_mode:
        return True
    return _session_auto_approval_enabled(session)
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def _remaining_budget_after_reservations(
    session: Session | None, reserved_spend_usd: float
) -> float | None:
    """Return the USD left under the session cost cap, or None if uncapped.

    The result is ``cap - recorded spend - reserved_spend_usd``, clamped at
    zero and rounded to four decimals.
    """
    if not session:
        return None
    raw_cap = getattr(session, "auto_approval_cost_cap_usd", None)
    if raw_cap is None:
        return None

    cap_usd = float(raw_cap or 0.0)
    already_spent_usd = float(
        getattr(session, "auto_approval_estimated_spend_usd", 0.0) or 0.0
    )
    headroom_usd = cap_usd - already_spent_usd - reserved_spend_usd
    return round(max(0.0, headroom_usd), 4)
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def _budget_block_reason(
    estimate: CostEstimate,
    *,
    remaining_cap_usd: float | None,
) -> str | None:
    """Return why auto-approval must be blocked for *estimate*, or None.

    A missing cost estimate always blocks. A known cost blocks only when a
    cap exists and the cost is strictly greater than what remains.
    """
    if estimate.estimated_cost_usd is None:
        fallback_reason = "Could not estimate the cost safely."
        return estimate.block_reason or fallback_reason

    exceeds_cap = (
        remaining_cap_usd is not None
        and estimate.estimated_cost_usd > remaining_cap_usd
    )
    if not exceeds_cap:
        return None
    return (
        f"Estimated cost ${estimate.estimated_cost_usd:.2f} exceeds "
        f"remaining YOLO cap ${remaining_cap_usd:.2f}."
    )
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
async def _approval_decision(
    tool_name: str,
    tool_args: dict,
    session: Session,
    *,
    reserved_spend_usd: float = 0.0,
) -> ApprovalDecision:
    """Return the approval decision for one parsed tool call.

    Args:
        tool_name: Name of the tool being called.
        tool_args: Parsed arguments of the call.
        session: Session supplying ``config`` and the auto-approval state.
        reserved_spend_usd: Spend already earmarked by earlier calls in the
            same batch; subtracted from the remaining cap.
    """
    config = session.config
    base_requires_approval = _base_needs_approval(tool_name, tool_args, config)

    # Scheduled jobs are recurring/unbounded enough that YOLO never bypasses
    # the human confirmation, including legacy config.yolo_mode.
    if _is_scheduled_hf_job_run(tool_name, tool_args):
        return ApprovalDecision(
            requires_approval=True,
            auto_approval_blocked=_effective_yolo_enabled(session, config),
            block_reason="Scheduled HF jobs always require manual approval.",
        )

    yolo_enabled = _effective_yolo_enabled(session, config)
    budgeted_target = _is_budgeted_auto_approval_target(tool_name, tool_args)

    # Cost caps are a session-scoped web policy. Legacy config.yolo_mode
    # remains uncapped for CLI/headless, except for scheduled jobs above.
    session_yolo_enabled = _session_auto_approval_enabled(session)
    if yolo_enabled and budgeted_target and session_yolo_enabled:
        estimate = await estimate_tool_cost(tool_name, tool_args, session=session)
        remaining = _remaining_budget_after_reservations(session, reserved_spend_usd)
        cost_fields = {
            "estimated_cost_usd": estimate.estimated_cost_usd,
            "remaining_cap_usd": remaining,
            "billable": estimate.billable,
        }

        reason = _budget_block_reason(estimate, remaining_cap_usd=remaining)
        if reason:
            return ApprovalDecision(
                requires_approval=True,
                auto_approval_blocked=True,
                block_reason=reason,
                **cost_fields,
            )
        # Within budget: the call runs without approval. It counts as
        # "auto approved" only if the base policy would have asked a human.
        # NOTE(review): when the base policy did not require approval, the
        # decision is billable but not auto_approved. The batch classifier
        # visible in this diff reserves spend only for auto_approved
        # decisions -- confirm such calls are meant to bypass reservation.
        return ApprovalDecision(
            requires_approval=False,
            auto_approved=bool(base_requires_approval),
            **cost_fields,
        )

    if base_requires_approval and yolo_enabled:
        return ApprovalDecision(requires_approval=False, auto_approved=True)

    return ApprovalDecision(requires_approval=base_requires_approval)
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
def _record_estimated_spend(session: Session, decision: ApprovalDecision) -> None:
    """Add the decision's estimated cost to the session's running spend.

    Does nothing for non-billable decisions or decisions without an estimate.
    """
    if not decision.billable or decision.estimated_cost_usd is None:
        return

    # Fallback for session objects that lack the dedicated accumulator
    # method: update the attribute directly.
    if not hasattr(session, "add_auto_approval_estimated_spend"):
        previous_usd = float(
            getattr(session, "auto_approval_estimated_spend_usd", 0.0) or 0.0
        )
        session.auto_approval_estimated_spend_usd = round(
            previous_usd + float(decision.estimated_cost_usd), 4
        )
        return

    session.add_auto_approval_estimated_spend(decision.estimated_cost_usd)
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
async def _record_manual_approved_spend_if_needed(
    session: Session,
    tool_name: str,
    tool_args: dict,
) -> None:
    """Record estimated spend for a budgeted call.

    Only acts when session auto-approval is enabled and the call is a
    budgeted target; otherwise returns without estimating anything.
    """
    tracked = _session_auto_approval_enabled(
        session
    ) and _is_budgeted_auto_approval_target(tool_name, tool_args)
    if not tracked:
        return

    estimate = await estimate_tool_cost(tool_name, tool_args, session=session)
    # Wrap the estimate in a decision so the shared recorder can be reused.
    spend_record = ApprovalDecision(
        requires_approval=False,
        billable=estimate.billable,
        estimated_cost_usd=estimate.estimated_cost_usd,
    )
    _record_estimated_spend(session, spend_record)
|
| 351 |
+
|
| 352 |
+
|
| 353 |
# -- LLM retry constants --------------------------------------------------
|
| 354 |
_MAX_LLM_RETRIES = 3
|
| 355 |
_LLM_RETRY_DELAYS = [5, 15, 30] # seconds between retries
|
|
|
|
| 1233 |
if session.is_cancelled:
|
| 1234 |
break
|
| 1235 |
|
| 1236 |
+
# Separate good tools into approval-required vs auto-execute.
|
| 1237 |
+
# Track reserved spend while classifying a batch so two
|
| 1238 |
+
# auto-approved jobs in one model response cannot jointly
|
| 1239 |
+
# exceed the remaining session cap.
|
| 1240 |
+
approval_required_tools: list[
|
| 1241 |
+
tuple[ToolCall, str, dict, ApprovalDecision]
|
| 1242 |
+
] = []
|
| 1243 |
+
non_approval_tools: list[
|
| 1244 |
+
tuple[ToolCall, str, dict, ApprovalDecision]
|
| 1245 |
+
] = []
|
| 1246 |
+
reserved_auto_spend_usd = 0.0
|
| 1247 |
for tc, tool_name, tool_args in good_tools:
|
| 1248 |
+
decision = await _approval_decision(
|
| 1249 |
+
tool_name,
|
| 1250 |
+
tool_args,
|
| 1251 |
+
session,
|
| 1252 |
+
reserved_spend_usd=reserved_auto_spend_usd,
|
| 1253 |
+
)
|
| 1254 |
+
if decision.requires_approval:
|
| 1255 |
+
approval_required_tools.append((tc, tool_name, tool_args, decision))
|
| 1256 |
else:
|
| 1257 |
+
non_approval_tools.append((tc, tool_name, tool_args, decision))
|
| 1258 |
+
if (
|
| 1259 |
+
decision.auto_approved
|
| 1260 |
+
and decision.billable
|
| 1261 |
+
and decision.estimated_cost_usd is not None
|
| 1262 |
+
):
|
| 1263 |
+
reserved_auto_spend_usd += decision.estimated_cost_usd
|
| 1264 |
|
| 1265 |
# Execute non-approval tools (in parallel when possible)
|
| 1266 |
if non_approval_tools:
|
| 1267 |
# 1. Validate args upfront
|
| 1268 |
parsed_tools: list[
|
| 1269 |
+
tuple[ToolCall, str, dict, ApprovalDecision, bool, str]
|
| 1270 |
] = []
|
| 1271 |
+
for tc, tool_name, tool_args, decision in non_approval_tools:
|
| 1272 |
args_valid, error_msg = _validate_tool_args(tool_args)
|
| 1273 |
parsed_tools.append(
|
| 1274 |
+
(tc, tool_name, tool_args, decision, args_valid, error_msg)
|
| 1275 |
)
|
| 1276 |
|
| 1277 |
# 2. Send all tool_call events upfront (so frontend shows them all)
|
| 1278 |
+
for tc, tool_name, tool_args, _decision, args_valid, _ in parsed_tools:
|
| 1279 |
if args_valid:
|
| 1280 |
await session.send_event(
|
| 1281 |
Event(
|
|
|
|
| 1293 |
tc: ToolCall,
|
| 1294 |
name: str,
|
| 1295 |
args: dict,
|
| 1296 |
+
decision: ApprovalDecision,
|
| 1297 |
valid: bool,
|
| 1298 |
err: str,
|
| 1299 |
) -> tuple[ToolCall, str, dict, str, bool]:
|
| 1300 |
if not valid:
|
| 1301 |
return (tc, name, args, err, False)
|
| 1302 |
+
if decision.billable:
|
| 1303 |
+
_record_estimated_spend(session, decision)
|
| 1304 |
out, ok = await session.tool_router.call_tool(
|
| 1305 |
name, args, session=session, tool_call_id=tc.id
|
| 1306 |
)
|
|
|
|
| 1308 |
|
| 1309 |
gather_task = asyncio.ensure_future(asyncio.gather(
|
| 1310 |
*[
|
| 1311 |
+
_exec_tool(tc, name, args, decision, valid, err)
|
| 1312 |
+
for tc, name, args, decision, valid, err in parsed_tools
|
| 1313 |
]
|
| 1314 |
))
|
| 1315 |
cancel_task = asyncio.ensure_future(session._cancelled.wait())
|
|
|
|
| 1326 |
except asyncio.CancelledError:
|
| 1327 |
pass
|
| 1328 |
# Notify frontend that in-flight tools were cancelled
|
| 1329 |
+
for tc, name, _args, _decision, valid, _ in parsed_tools:
|
| 1330 |
if valid:
|
| 1331 |
await session.send_event(Event(
|
| 1332 |
event_type="tool_state_change",
|
|
|
|
| 1364 |
if approval_required_tools:
|
| 1365 |
# Prepare batch approval data
|
| 1366 |
tools_data = []
|
| 1367 |
+
blocked_payloads = []
|
| 1368 |
+
for tc, tool_name, tool_args, decision in approval_required_tools:
|
| 1369 |
# Resolve sandbox file paths for hf_jobs scripts so the
|
| 1370 |
# frontend can display & edit the actual file content.
|
| 1371 |
if tool_name == "hf_jobs" and isinstance(tool_args.get("script"), str):
|
|
|
|
| 1375 |
if resolved:
|
| 1376 |
tool_args = {**tool_args, "script": resolved}
|
| 1377 |
|
| 1378 |
+
tool_payload = {
|
| 1379 |
"tool": tool_name,
|
| 1380 |
"arguments": tool_args,
|
| 1381 |
"tool_call_id": tc.id,
|
| 1382 |
+
}
|
| 1383 |
+
if decision.auto_approval_blocked:
|
| 1384 |
+
tool_payload.update(
|
| 1385 |
+
{
|
| 1386 |
+
"auto_approval_blocked": True,
|
| 1387 |
+
"block_reason": decision.block_reason,
|
| 1388 |
+
"estimated_cost_usd": decision.estimated_cost_usd,
|
| 1389 |
+
"remaining_cap_usd": decision.remaining_cap_usd,
|
| 1390 |
+
}
|
| 1391 |
+
)
|
| 1392 |
+
blocked_payloads.append(tool_payload)
|
| 1393 |
+
tools_data.append(tool_payload)
|
| 1394 |
+
|
| 1395 |
+
event_data = {"tools": tools_data, "count": len(tools_data)}
|
| 1396 |
+
if blocked_payloads:
|
| 1397 |
+
first = blocked_payloads[0]
|
| 1398 |
+
event_data.update(
|
| 1399 |
+
{
|
| 1400 |
+
"auto_approval_blocked": True,
|
| 1401 |
+
"block_reason": first.get("block_reason"),
|
| 1402 |
+
"estimated_cost_usd": first.get("estimated_cost_usd"),
|
| 1403 |
+
"remaining_cap_usd": first.get("remaining_cap_usd"),
|
| 1404 |
+
}
|
| 1405 |
+
)
|
| 1406 |
await session.send_event(Event(
|
| 1407 |
event_type="approval_required",
|
| 1408 |
+
data=event_data,
|
| 1409 |
))
|
| 1410 |
|
| 1411 |
# Store all approval-requiring tools (ToolCall objects for execution)
|
| 1412 |
session.pending_approval = {
|
| 1413 |
+
"tool_calls": [tc for tc, _, _, _ in approval_required_tools],
|
| 1414 |
}
|
| 1415 |
|
| 1416 |
# Return early - wait for EXEC_APPROVAL operation
|
|
|
|
| 1600 |
)
|
| 1601 |
)
|
| 1602 |
|
| 1603 |
+
await _record_manual_approved_spend_if_needed(session, tool_name, tool_args)
|
| 1604 |
+
|
| 1605 |
output, success = await session.tool_router.call_tool(
|
| 1606 |
tool_name, tool_args, session=session, tool_call_id=tc.id
|
| 1607 |
)
|
agent/core/approval_policy.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Shared predicates for approval-gated tool operations."""
|
| 2 |
+
|
| 3 |
+
from typing import Any
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def normalize_tool_operation(operation: Any) -> str:
    """Return *operation* as a stripped, lower-cased string.

    Falsy values (``None``, ``""``, ``0``) normalize to the empty string.
    """
    text = str(operation) if operation else ""
    return text.strip().lower()
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def is_scheduled_operation(operation: Any) -> bool:
    """Return True when the normalized operation starts with ``"scheduled "``.

    The trailing space is part of the prefix, so a bare ``"scheduled"`` does
    not match.
    """
    normalized = normalize_tool_operation(operation)
    return normalized.startswith("scheduled ")
|
agent/core/cost_estimation.py
ADDED
|
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Conservative cost estimates for auto-approved infrastructure actions."""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import re
|
| 5 |
+
import time
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Any
|
| 8 |
+
|
| 9 |
+
import httpx
|
| 10 |
+
|
| 11 |
+
# Base URL taken from the OPENID_PROVIDER_URL environment variable,
# defaulting to https://huggingface.co.
OPENID_PROVIDER_URL = os.environ.get("OPENID_PROVIDER_URL", "https://huggingface.co")
# Endpoint queried for the live HF Jobs hardware price catalog.
JOBS_HARDWARE_URL = f"{OPENID_PROVIDER_URL}/api/jobs/hardware"
# Lifetime of the in-process price cache, in seconds (six hours).
JOBS_PRICE_CACHE_TTL_S = 6 * 60 * 60

# Job duration (hours) assumed when a call supplies no timeout.
DEFAULT_JOB_TIMEOUT_HOURS = 0.5
# Hours charged in the estimate for creating a new sandbox.
DEFAULT_SANDBOX_RESERVATION_HOURS = 1.0
|
| 17 |
+
|
| 18 |
+
# Static fallback prices are intentionally conservative enough for a budget
# guard. The live /api/jobs/hardware catalog wins whenever it is reachable.
# Keys are hardware flavor names; values are USD per hour.
HF_JOBS_PRICE_USD_PER_HOUR: dict[str, float] = {
    "cpu-basic": 0.05,
    "cpu-upgrade": 0.25,
    "cpu-performance": 0.50,
    "cpu-xl": 1.00,
    "t4-small": 0.60,
    "t4-medium": 0.90,
    "l4x1": 1.00,
    "l4x4": 4.00,
    "l40sx1": 2.00,
    "l40sx4": 8.00,
    "l40sx8": 16.00,
    "a10g-small": 1.00,
    "a10g-large": 2.00,
    "a10g-largex2": 4.00,
    "a10g-largex4": 8.00,
    "a100-large": 4.00,
    "a100x4": 16.00,
    "a100x8": 32.00,
    "h200": 10.00,
    "h200x2": 20.00,
    "h200x4": 40.00,
    "h200x8": 80.00,
    "inf2x6": 6.00,
}
|
| 45 |
+
|
| 46 |
+
# Static hourly prices (USD) keyed by hardware name, used by
# estimate_sandbox_cost(); that function does no live lookup. A price of 0.0
# makes the resulting estimate non-billable.
SPACE_PRICE_USD_PER_HOUR: dict[str, float] = {
    "cpu-basic": 0.0,
    "cpu-upgrade": 0.05,
    "cpu-performance": 0.50,
    "cpu-xl": 1.00,
    "t4-small": 0.60,
    "t4-medium": 0.90,
    "l4x1": 1.00,
    "l4x4": 4.00,
    "l40sx1": 2.00,
    "l40sx4": 8.00,
    "l40sx8": 16.00,
    "a10g-small": 1.00,
    "a10g-large": 2.00,
    "a10g-largex2": 4.00,
    "a10g-largex4": 8.00,
    "a100-large": 4.00,
    "a100x4": 16.00,
    "a100x8": 32.00,
    "h200": 10.00,
    "h200x2": 20.00,
    "h200x4": 40.00,
    "h200x8": 80.00,
    "inf2x6": 6.00,
}
|
| 71 |
+
|
| 72 |
+
# Whole-string duration: an unsigned decimal number followed by an optional
# s/m/h/d unit letter (case-insensitive), with surrounding whitespace allowed.
_DURATION_RE = re.compile(r"^\s*(\d+(?:\.\d+)?)\s*([smhd]?)\s*$", re.IGNORECASE)
# First unsigned decimal number found anywhere in a string.
_PRICE_RE = re.compile(r"(\d+(?:\.\d+)?)")
# (time.monotonic()-based timestamp, prices) written by
# hf_jobs_price_catalog(); None until the first load.
_jobs_price_cache: tuple[float, dict[str, float]] | None = None
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
@dataclass(frozen=True)
class CostEstimate:
    """Estimated cost for a tool call.

    ``estimated_cost_usd=None`` means the call may be billable but we could not
    estimate it safely, so auto-approval should fall back to a human decision.
    """

    # Estimated cost in USD, or None when no safe estimate exists.
    estimated_cost_usd: float | None
    # Whether the call is treated as costing money; the estimators in this
    # module set it from ``price > 0`` or to True when the price is unknown.
    billable: bool
    # Explanation of why no estimate could be produced, when applicable.
    block_reason: str | None = None
    # Short tag for the estimate; the estimators here pass the hardware name
    # (or "existing" for an already-created sandbox).
    label: str | None = None
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def parse_timeout_hours(value: Any, *, default_hours: float = DEFAULT_JOB_TIMEOUT_HOURS) -> float | None:
    """Parse HF timeout values into hours.

    Strings accept ``s``, ``m``, ``h``, or ``d`` suffixes. Numeric values are
    treated as seconds, matching the Hub client's typed timeout parameter.

    Args:
        value: Raw timeout from the tool arguments.
        default_hours: Returned when *value* is ``None`` or ``""``.

    Returns:
        The timeout in hours, or ``None`` when the value cannot be interpreted
        as a positive, finite duration (booleans, unsupported types,
        unparsable strings, zero/negative, NaN or infinite amounts).
    """
    import math  # function-scope: the module does not import math at top level

    if value is None or value == "":
        return default_hours
    # bool is a subclass of int; True/False are never meaningful timeouts.
    if isinstance(value, bool):
        return None

    if isinstance(value, int | float):
        try:
            amount = float(value)
        except OverflowError:
            # Integers too large for a float cannot be a usable timeout.
            return None
        unit = "s"
    elif isinstance(value, str):
        match = _DURATION_RE.match(value)
        if not match:
            return None
        # float() of a very long digit string yields inf; rejected below.
        amount = float(match.group(1))
        unit = match.group(2).lower() or "s"
    else:
        return None

    # An infinite or NaN duration would make the cost estimate meaningless,
    # so report "cannot estimate" and let a human decide.
    if not math.isfinite(amount) or amount <= 0:
        return None

    if unit == "s":
        return amount / 3600
    if unit == "m":
        return amount / 60
    if unit == "h":
        return amount
    if unit == "d":
        return amount * 24
    return None
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def _extract_flavor(item: dict[str, Any]) -> str | None:
|
| 126 |
+
for key in ("flavor", "name", "id", "value", "hardware", "hardware_flavor"):
|
| 127 |
+
value = item.get(key)
|
| 128 |
+
if isinstance(value, str) and value:
|
| 129 |
+
return value
|
| 130 |
+
return None
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def _coerce_price(value: Any) -> float | None:
    """Convert a raw price value to a non-negative, finite float.

    Numbers are accepted directly. Strings are searched (after removing
    commas) for their first decimal number. Returns ``None`` for booleans,
    ``None``, unsupported types, strings without a number, and any negative
    or non-finite price.
    """
    import math  # function-scope: the module does not import math at top level

    if isinstance(value, bool) or value is None:
        return None

    if isinstance(value, int | float):
        try:
            price = float(value)
        except OverflowError:
            return None
        return price if math.isfinite(price) and price >= 0 else None

    if isinstance(value, str):
        text = value.replace(",", "")
        match = _PRICE_RE.search(text)
        if match is None:
            return None
        # _PRICE_RE only captures unsigned digits, so detect a leading minus
        # (optionally separated by spaces or "$") and reject it, consistent
        # with the numeric branch that refuses negative prices.
        prefix = text[: match.start()].rstrip(" \t$")
        if prefix.endswith("-"):
            return None
        price = float(match.group(1))
        return price if math.isfinite(price) else None

    return None
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def _extract_hourly_price(item: dict[str, Any]) -> float | None:
    """Find an hourly price in *item*, searching nested pricing sections too.

    Direct price keys are tried first, in order. If none yields a price, the
    ``pricing``/``billing``/``cost`` sub-dicts are searched recursively.
    """
    direct_keys = (
        "price",
        "price_usd",
        "priceUsd",
        "price_per_hour",
        "pricePerHour",
        "hourly_price",
        "hourlyPrice",
        "usd_per_hour",
        "usdPerHour",
    )
    for name in direct_keys:
        if (found := _coerce_price(item.get(name))) is not None:
            return found

    for name in ("pricing", "billing", "cost"):
        section = item.get(name)
        if not isinstance(section, dict):
            continue
        if (found := _extract_hourly_price(section)) is not None:
            return found

    return None
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def _iter_hardware_items(payload: Any):
    """Yield every dict in *payload* that carries a flavor identifier.

    Lists are walked element by element. A dict is yielded when
    ``_extract_flavor`` finds a flavor in it, and its known container keys
    are then searched as well. Any other payload type yields nothing.
    """
    if isinstance(payload, list):
        for entry in payload:
            yield from _iter_hardware_items(entry)
        return

    if not isinstance(payload, dict):
        return

    if _extract_flavor(payload):
        yield payload
    for container_key in ("hardware", "flavors", "items", "data", "jobs"):
        nested = payload.get(container_key)
        if nested is not None:
            yield from _iter_hardware_items(nested)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def _parse_jobs_price_catalog(payload: Any) -> dict[str, float]:
    """Build a ``flavor -> hourly price`` mapping from a catalog payload.

    Items lacking either a flavor or a price are skipped. When a flavor
    appears more than once, the last occurrence wins.
    """
    catalog: dict[str, float] = {}
    for entry in _iter_hardware_items(payload):
        name = _extract_flavor(entry)
        hourly = _extract_hourly_price(entry)
        if not name or hourly is None:
            continue
        catalog[name] = hourly
    return catalog
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
async def hf_jobs_price_catalog() -> dict[str, float]:
    """Return live HF Jobs hourly prices, falling back to static prices.

    The result is cached in ``_jobs_price_cache``. A successful live fetch is
    reused for ``JOBS_PRICE_CACHE_TTL_S`` seconds. When the fetch fails or
    yields no prices, the static table is returned and cached only briefly,
    so a transient outage does not suppress live prices for the full TTL.

    Returns:
        A fresh dict mapping hardware flavor to USD per hour. Live prices
        override the static table entry by entry.
    """
    global _jobs_price_cache
    now = time.monotonic()
    if _jobs_price_cache and now - _jobs_price_cache[0] < JOBS_PRICE_CACHE_TTL_S:
        return dict(_jobs_price_cache[1])

    live_prices: dict[str, float] = {}
    try:
        async with httpx.AsyncClient(timeout=3.0) as client:
            response = await client.get(JOBS_HARDWARE_URL)
            if response.status_code == 200:
                live_prices = _parse_jobs_price_catalog(response.json())
    except (httpx.HTTPError, ValueError):
        # Network failure or a non-JSON body: use the static table.
        live_prices = {}

    # Static prices are the base layer; any live price overrides them.
    prices = {**HF_JOBS_PRICE_USD_PER_HOUR, **live_prices}

    if live_prices:
        cached_at = now
    else:
        # Back-date the timestamp so the fallback expires after a short retry
        # window instead of the full TTL. The cache tuple keeps its shape.
        retry_after_s = min(300.0, float(JOBS_PRICE_CACHE_TTL_S))
        cached_at = now - (JOBS_PRICE_CACHE_TTL_S - retry_after_s)

    _jobs_price_cache = (cached_at, prices)
    return dict(prices)
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
async def estimate_hf_job_cost(args: dict[str, Any]) -> CostEstimate:
    """Estimate the cost of an HF job from its tool-call arguments.

    The hardware flavor is read from ``hardware_flavor``, ``flavor`` or
    ``hardware`` (first truthy value, default ``"cpu-basic"``). The cost is
    the hourly price of that flavor times the parsed ``timeout`` in hours,
    rounded to four decimals. If the timeout cannot be parsed or no price is
    known for the flavor, the estimate has no cost, is marked billable and
    carries a ``block_reason``.
    """
    requested = (
        args.get("hardware_flavor") or args.get("flavor") or args.get("hardware")
    )
    flavor = str(requested or "cpu-basic")

    def _blocked(reason: str) -> CostEstimate:
        # Unknown cost: treated as billable so it cannot slip under a cap.
        return CostEstimate(
            estimated_cost_usd=None,
            billable=True,
            block_reason=reason,
            label=flavor,
        )

    timeout_hours = parse_timeout_hours(args.get("timeout"))
    if timeout_hours is None:
        return _blocked(f"Could not parse HF job timeout: {args.get('timeout')!r}.")

    catalog = await hf_jobs_price_catalog()
    hourly = catalog.get(flavor)
    if hourly is None:
        return _blocked(f"No price is available for HF job hardware '{flavor}'.")

    total = round(hourly * timeout_hours, 4)
    return CostEstimate(
        estimated_cost_usd=total,
        billable=hourly > 0,
        label=flavor,
    )
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
async def estimate_sandbox_cost(args: dict[str, Any], *, session: Any = None) -> CostEstimate:
    """Estimate the cost of creating a sandbox.

    When ``session`` already exposes a truthy ``sandbox`` attribute the
    estimate is free and labelled ``"existing"``. Otherwise the hourly price
    of ``args["hardware"]`` (default ``"cpu-basic"``) is looked up in
    ``SPACE_PRICE_USD_PER_HOUR`` and multiplied by
    ``DEFAULT_SANDBOX_RESERVATION_HOURS``. Hardware without a known price
    produces a billable estimate with a ``block_reason`` and no cost.
    """
    already_running = session is not None and getattr(session, "sandbox", None)
    if already_running:
        return CostEstimate(estimated_cost_usd=0.0, billable=False, label="existing")

    hardware = str(args.get("hardware") or "cpu-basic")
    hourly = SPACE_PRICE_USD_PER_HOUR.get(hardware)
    if hourly is not None:
        reserved_cost = round(hourly * DEFAULT_SANDBOX_RESERVATION_HOURS, 4)
        return CostEstimate(
            estimated_cost_usd=reserved_cost,
            billable=hourly > 0,
            label=hardware,
        )

    return CostEstimate(
        estimated_cost_usd=None,
        billable=True,
        block_reason=f"No price is available for sandbox hardware '{hardware}'.",
        label=hardware,
    )
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
async def estimate_tool_cost(
    tool_name: str, args: dict[str, Any], *, session: Any = None
) -> CostEstimate:
    """Dispatch cost estimation by tool name.

    ``sandbox_create`` and ``hf_jobs`` are routed to their dedicated
    estimators; every other tool is reported as free and non-billable.
    """
    if tool_name == "hf_jobs":
        return await estimate_hf_job_cost(args)
    if tool_name != "sandbox_create":
        return CostEstimate(estimated_cost_usd=0.0, billable=False)
    return await estimate_sandbox_cost(args, session=session)
|
agent/core/session.py
CHANGED
|
@@ -120,6 +120,9 @@ class Session:
|
|
| 120 |
self.notification_gateway = notification_gateway
|
| 121 |
self.notification_destinations = list(notification_destinations or [])
|
| 122 |
self.defer_turn_complete_notification = defer_turn_complete_notification
|
|
|
|
|
|
|
|
|
|
| 123 |
|
| 124 |
# Session trajectory logging
|
| 125 |
self.logged_events: list[dict] = []
|
|
@@ -313,6 +316,40 @@ class Session:
|
|
| 313 |
self.config.model_name = model_name
|
| 314 |
self.context_manager.model_max_tokens = _get_max_tokens_safe(model_name)
|
| 315 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 316 |
def effective_effort_for(self, model_name: str) -> str | None:
|
| 317 |
"""Resolve the effort level to actually send for ``model_name``.
|
| 318 |
|
|
|
|
| 120 |
self.notification_gateway = notification_gateway
|
| 121 |
self.notification_destinations = list(notification_destinations or [])
|
| 122 |
self.defer_turn_complete_notification = defer_turn_complete_notification
|
| 123 |
+
self.auto_approval_enabled: bool = False
|
| 124 |
+
self.auto_approval_cost_cap_usd: float | None = None
|
| 125 |
+
self.auto_approval_estimated_spend_usd: float = 0.0
|
| 126 |
|
| 127 |
# Session trajectory logging
|
| 128 |
self.logged_events: list[dict] = []
|
|
|
|
| 316 |
self.config.model_name = model_name
|
| 317 |
self.context_manager.model_max_tokens = _get_max_tokens_safe(model_name)
|
| 318 |
|
| 319 |
+
def set_auto_approval_policy(
|
| 320 |
+
self, *, enabled: bool, cost_cap_usd: float | None
|
| 321 |
+
) -> None:
|
| 322 |
+
self.auto_approval_enabled = bool(enabled)
|
| 323 |
+
self.auto_approval_cost_cap_usd = cost_cap_usd
|
| 324 |
+
|
| 325 |
+
def add_auto_approval_estimated_spend(self, amount_usd: float | None) -> None:
|
| 326 |
+
if amount_usd is None or amount_usd <= 0:
|
| 327 |
+
return
|
| 328 |
+
self.auto_approval_estimated_spend_usd = round(
|
| 329 |
+
self.auto_approval_estimated_spend_usd + float(amount_usd), 4
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
@property
|
| 333 |
+
def auto_approval_remaining_usd(self) -> float | None:
|
| 334 |
+
if self.auto_approval_cost_cap_usd is None:
|
| 335 |
+
return None
|
| 336 |
+
return round(
|
| 337 |
+
max(
|
| 338 |
+
0.0,
|
| 339 |
+
self.auto_approval_cost_cap_usd
|
| 340 |
+
- self.auto_approval_estimated_spend_usd,
|
| 341 |
+
),
|
| 342 |
+
4,
|
| 343 |
+
)
|
| 344 |
+
|
| 345 |
+
def auto_approval_policy_summary(self) -> dict[str, Any]:
    """Return the auto-approval state as a plain dict.

    Keys, in order: ``enabled``, ``cost_cap_usd``, ``estimated_spend_usd``
    (rounded to four decimals) and ``remaining_usd``.
    """
    summary: dict[str, Any] = {}
    summary["enabled"] = self.auto_approval_enabled
    summary["cost_cap_usd"] = self.auto_approval_cost_cap_usd
    summary["estimated_spend_usd"] = round(self.auto_approval_estimated_spend_usd, 4)
    summary["remaining_usd"] = self.auto_approval_remaining_usd
    return summary
|
| 352 |
+
|
| 353 |
def effective_effort_for(self, model_name: str) -> str | None:
|
| 354 |
"""Resolve the effort level to actually send for ``model_name``.
|
| 355 |
|
agent/core/session_persistence.py
CHANGED
|
@@ -176,6 +176,9 @@ class MongoSessionStore(NoopSessionStore):
|
|
| 176 |
pending_approval: list[dict[str, Any]] | None = None,
|
| 177 |
claude_counted: bool = False,
|
| 178 |
notification_destinations: list[str] | None = None,
|
|
|
|
|
|
|
|
|
|
| 179 |
) -> None:
|
| 180 |
if not self._ready():
|
| 181 |
return
|
|
@@ -204,6 +207,9 @@ class MongoSessionStore(NoopSessionStore):
|
|
| 204 |
"pending_approval": pending_approval or [],
|
| 205 |
"claude_counted": claude_counted,
|
| 206 |
"notification_destinations": notification_destinations or [],
|
|
|
|
|
|
|
|
|
|
| 207 |
},
|
| 208 |
},
|
| 209 |
upsert=True,
|
|
@@ -224,6 +230,9 @@ class MongoSessionStore(NoopSessionStore):
|
|
| 224 |
claude_counted: bool = False,
|
| 225 |
created_at: datetime | None = None,
|
| 226 |
notification_destinations: list[str] | None = None,
|
|
|
|
|
|
|
|
|
|
| 227 |
) -> None:
|
| 228 |
if not self._ready():
|
| 229 |
return
|
|
@@ -241,6 +250,9 @@ class MongoSessionStore(NoopSessionStore):
|
|
| 241 |
pending_approval=pending_approval,
|
| 242 |
claude_counted=claude_counted,
|
| 243 |
notification_destinations=notification_destinations,
|
|
|
|
|
|
|
|
|
|
| 244 |
)
|
| 245 |
ops: list[Any] = []
|
| 246 |
for idx, raw in enumerate(messages):
|
|
|
|
| 176 |
pending_approval: list[dict[str, Any]] | None = None,
|
| 177 |
claude_counted: bool = False,
|
| 178 |
notification_destinations: list[str] | None = None,
|
| 179 |
+
auto_approval_enabled: bool = False,
|
| 180 |
+
auto_approval_cost_cap_usd: float | None = None,
|
| 181 |
+
auto_approval_estimated_spend_usd: float = 0.0,
|
| 182 |
) -> None:
|
| 183 |
if not self._ready():
|
| 184 |
return
|
|
|
|
| 207 |
"pending_approval": pending_approval or [],
|
| 208 |
"claude_counted": claude_counted,
|
| 209 |
"notification_destinations": notification_destinations or [],
|
| 210 |
+
"auto_approval_enabled": auto_approval_enabled,
|
| 211 |
+
"auto_approval_cost_cap_usd": auto_approval_cost_cap_usd,
|
| 212 |
+
"auto_approval_estimated_spend_usd": auto_approval_estimated_spend_usd,
|
| 213 |
},
|
| 214 |
},
|
| 215 |
upsert=True,
|
|
|
|
| 230 |
claude_counted: bool = False,
|
| 231 |
created_at: datetime | None = None,
|
| 232 |
notification_destinations: list[str] | None = None,
|
| 233 |
+
auto_approval_enabled: bool = False,
|
| 234 |
+
auto_approval_cost_cap_usd: float | None = None,
|
| 235 |
+
auto_approval_estimated_spend_usd: float = 0.0,
|
| 236 |
) -> None:
|
| 237 |
if not self._ready():
|
| 238 |
return
|
|
|
|
| 250 |
pending_approval=pending_approval,
|
| 251 |
claude_counted=claude_counted,
|
| 252 |
notification_destinations=notification_destinations,
|
| 253 |
+
auto_approval_enabled=auto_approval_enabled,
|
| 254 |
+
auto_approval_cost_cap_usd=auto_approval_cost_cap_usd,
|
| 255 |
+
auto_approval_estimated_spend_usd=auto_approval_estimated_spend_usd,
|
| 256 |
)
|
| 257 |
ops: list[Any] = []
|
| 258 |
for idx, raw in enumerate(messages):
|
agent/main.py
CHANGED
|
@@ -21,6 +21,7 @@ import litellm
|
|
| 21 |
from prompt_toolkit import PromptSession
|
| 22 |
|
| 23 |
from agent.config import load_config
|
|
|
|
| 24 |
from agent.core.agent_loop import submission_loop
|
| 25 |
from agent.core import model_switcher
|
| 26 |
from agent.core.hf_tokens import resolve_hf_token
|
|
@@ -55,6 +56,20 @@ litellm.suppress_debug_info = True
|
|
| 55 |
CLI_CONFIG_PATH = Path(__file__).parent.parent / "configs" / "cli_agent_config.json"
|
| 56 |
|
| 57 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
def _configure_runtime_logging() -> None:
|
| 59 |
"""Keep third-party warning spam from punching through the interactive UI."""
|
| 60 |
import logging
|
|
@@ -375,8 +390,11 @@ async def event_listener(
|
|
| 375 |
tools_data = event.data.get("tools", []) if event.data else []
|
| 376 |
count = event.data.get("count", 0) if event.data else 0
|
| 377 |
|
| 378 |
-
# If yolo mode is active, auto-approve everything
|
| 379 |
-
|
|
|
|
|
|
|
|
|
|
| 380 |
approvals = [
|
| 381 |
{
|
| 382 |
"tool_call_id": t.get("tool_call_id", ""),
|
|
@@ -1293,14 +1311,18 @@ async def headless_main(
|
|
| 1293 |
else:
|
| 1294 |
print_tool_log(tool, log)
|
| 1295 |
elif event.event_type == "approval_required":
|
| 1296 |
-
# Auto-approve
|
| 1297 |
-
#
|
| 1298 |
tools_data = event.data.get("tools", []) if event.data else []
|
| 1299 |
approvals = [
|
| 1300 |
{
|
| 1301 |
"tool_call_id": t.get("tool_call_id", ""),
|
| 1302 |
-
"approved":
|
| 1303 |
-
"feedback":
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1304 |
}
|
| 1305 |
for t in tools_data
|
| 1306 |
]
|
|
|
|
| 21 |
from prompt_toolkit import PromptSession
|
| 22 |
|
| 23 |
from agent.config import load_config
|
| 24 |
+
from agent.core.approval_policy import is_scheduled_operation
|
| 25 |
from agent.core.agent_loop import submission_loop
|
| 26 |
from agent.core import model_switcher
|
| 27 |
from agent.core.hf_tokens import resolve_hf_token
|
|
|
|
| 56 |
CLI_CONFIG_PATH = Path(__file__).parent.parent / "configs" / "cli_agent_config.json"
|
| 57 |
|
| 58 |
|
| 59 |
+
def _is_scheduled_hf_job_tool(tool_info: dict[str, Any]) -> bool:
|
| 60 |
+
if tool_info.get("tool") != "hf_jobs":
|
| 61 |
+
return False
|
| 62 |
+
arguments = tool_info.get("arguments") or {}
|
| 63 |
+
if isinstance(arguments, str):
|
| 64 |
+
try:
|
| 65 |
+
arguments = json.loads(arguments)
|
| 66 |
+
except json.JSONDecodeError:
|
| 67 |
+
return False
|
| 68 |
+
if not isinstance(arguments, dict):
|
| 69 |
+
return False
|
| 70 |
+
return is_scheduled_operation(arguments.get("operation"))
|
| 71 |
+
|
| 72 |
+
|
| 73 |
def _configure_runtime_logging() -> None:
|
| 74 |
"""Keep third-party warning spam from punching through the interactive UI."""
|
| 75 |
import logging
|
|
|
|
| 390 |
tools_data = event.data.get("tools", []) if event.data else []
|
| 391 |
count = event.data.get("count", 0) if event.data else 0
|
| 392 |
|
| 393 |
+
# If yolo mode is active, auto-approve everything except
|
| 394 |
+
# scheduled HF jobs, whose recurring cost stays manual.
|
| 395 |
+
if config and config.yolo_mode and not any(
|
| 396 |
+
_is_scheduled_hf_job_tool(t) for t in tools_data
|
| 397 |
+
):
|
| 398 |
approvals = [
|
| 399 |
{
|
| 400 |
"tool_call_id": t.get("tool_call_id", ""),
|
|
|
|
| 1311 |
else:
|
| 1312 |
print_tool_log(tool, log)
|
| 1313 |
elif event.event_type == "approval_required":
|
| 1314 |
+
# Auto-approve in headless mode, except scheduled HF jobs. Those
|
| 1315 |
+
# are rejected because their recurring cost needs manual approval.
|
| 1316 |
tools_data = event.data.get("tools", []) if event.data else []
|
| 1317 |
approvals = [
|
| 1318 |
{
|
| 1319 |
"tool_call_id": t.get("tool_call_id", ""),
|
| 1320 |
+
"approved": not _is_scheduled_hf_job_tool(t),
|
| 1321 |
+
"feedback": (
|
| 1322 |
+
"Scheduled HF jobs require manual approval."
|
| 1323 |
+
if _is_scheduled_hf_job_tool(t)
|
| 1324 |
+
else None
|
| 1325 |
+
),
|
| 1326 |
}
|
| 1327 |
for t in tools_data
|
| 1328 |
]
|
agent/prompts/system_prompt_v3.yaml
CHANGED
|
@@ -42,7 +42,7 @@ system_prompt: |
|
|
| 42 |
|
| 43 |
SILENT DATASET SUBSTITUTION: When a requested dataset fails to load, you will silently switch to a different one without telling the user. Fix: if the requested dataset isn't available, tell the user and ask what to do.
|
| 44 |
|
| 45 |
-
|
| 46 |
|
| 47 |
SCOPE-CHANGING FIXES: Avoid at all costs! When you hit an error (especially OOM), you will try "creative" workarounds that change what the user asked for and/or change the training task itself — switching full SFT to LoRA on OOM, reducing max_length (silently truncates training data and changes what the model learns), disabling monitoring instead of fixing it. Do not do this. Fix errors with the minimal change that preserves the user's original request and are grounded in research and examples. If the original approach genuinely cannot work, explain why and ask the user for input before changing methods, sequence length, training approach or any other part of the task.
|
| 48 |
|
|
|
|
| 42 |
|
| 43 |
SILENT DATASET SUBSTITUTION: When a requested dataset fails to load, you will silently switch to a different one without telling the user. Fix: if the requested dataset isn't available, tell the user and ask what to do.
|
| 44 |
|
| 45 |
+
PREFER HUB KERNELS OVER COMPILING ATTENTION: Do NOT pip install 'flash-attn' to enable flash_attention_2: building from source can take many minutes to hours and often fails on the job's CUDA/PyTorch combo. Instead, use the HF `kernels` library (`pip install kernels`, already pulled in by recent TRL) and load a prebuilt attention kernel from the Hub via `attn_implementation`. Examples: `AutoModelForCausalLM.from_pretrained(..., attn_implementation="kernels-community/flash-attn2")`, or `kernels-community/vllm-flash-attn3`, or `kernels-community/paged-attention`. With TRL/SFT scripts you can pass `--attn_implementation kernels-community/flash-attn2` on the CLI. Search additional kernels at https://huggingface.co/models?other=kernel. Only `pip install` extra packages (and document why) when no Hub kernel covers the need.
|
| 46 |
|
| 47 |
SCOPE-CHANGING FIXES: Avoid at all costs! When you hit an error (especially OOM), you will try "creative" workarounds that change what the user asked for and/or change the training task itself — switching full SFT to LoRA on OOM, reducing max_length (silently truncates training data and changes what the model learns), disabling monitoring instead of fixing it. Do not do this. Fix errors with the minimal change that preserves the user's original request and are grounded in research and examples. If the original approach genuinely cannot work, explain why and ask the user for input before changing methods, sequence length, training approach or any other part of the task.
|
| 48 |
|
agent/tools/docs_tools.py
CHANGED
|
@@ -932,7 +932,7 @@ EXPLORE_HF_DOCS_TOOL_SPEC = {
|
|
| 932 |
"• argilla — Data annotation, feedback, and human-in-the-loop workflows.\n"
|
| 933 |
"• distilabel — Synthetic data generation and distillation pipelines.\n"
|
| 934 |
"• microsoft-azure — Azure deployment and integration guides.\n"
|
| 935 |
-
"• kernels —
|
| 936 |
"• google-cloud — GCP deployment and serving workflows.\n"
|
| 937 |
),
|
| 938 |
},
|
|
|
|
| 932 |
"• argilla — Data annotation, feedback, and human-in-the-loop workflows.\n"
|
| 933 |
"• distilabel — Synthetic data generation and distillation pipelines.\n"
|
| 934 |
"• microsoft-azure — Azure deployment and integration guides.\n"
|
| 935 |
+
"• kernels — Load prebuilt compute kernels (E.g. flash-attn2) from the Hub via `attn_implementation`; avoids compiling flash-attn from source.\n"
|
| 936 |
"• google-cloud — GCP deployment and serving workflows.\n"
|
| 937 |
),
|
| 938 |
},
|
backend/models.py
CHANGED
|
@@ -76,6 +76,15 @@ class PendingApprovalTool(BaseModel):
|
|
| 76 |
arguments: dict[str, Any] = {}
|
| 77 |
|
| 78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
class SessionInfo(BaseModel):
|
| 80 |
"""Session metadata."""
|
| 81 |
|
|
@@ -89,6 +98,9 @@ class SessionInfo(BaseModel):
|
|
| 89 |
model: str | None = None
|
| 90 |
title: str | None = None
|
| 91 |
notification_destinations: list[str] = Field(default_factory=list)
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
|
| 94 |
class SessionNotificationsRequest(BaseModel):
|
|
@@ -97,6 +109,13 @@ class SessionNotificationsRequest(BaseModel):
|
|
| 97 |
destinations: list[str]
|
| 98 |
|
| 99 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
class HealthResponse(BaseModel):
|
| 101 |
"""Health check response."""
|
| 102 |
|
|
|
|
| 76 |
arguments: dict[str, Any] = {}
|
| 77 |
|
| 78 |
|
| 79 |
+
class SessionAutoApprovalInfo(BaseModel):
    """Per-session auto-approval budget state.

    The field names match the keys of the ``auto_approval`` dict that the
    session manager attaches to session info.
    """

    # Whether auto-approval is switched on for the session.
    enabled: bool = False
    # Cost cap in USD; None when no cap is set.
    cost_cap_usd: float | None = None
    # Estimated spend accumulated so far, in USD.
    estimated_spend_usd: float = 0.0
    # USD left under the cap; None when there is no cap.
    remaining_usd: float | None = None
|
| 86 |
+
|
| 87 |
+
|
| 88 |
class SessionInfo(BaseModel):
|
| 89 |
"""Session metadata."""
|
| 90 |
|
|
|
|
| 98 |
model: str | None = None
|
| 99 |
title: str | None = None
|
| 100 |
notification_destinations: list[str] = Field(default_factory=list)
|
| 101 |
+
auto_approval: SessionAutoApprovalInfo = Field(
|
| 102 |
+
default_factory=SessionAutoApprovalInfo
|
| 103 |
+
)
|
| 104 |
|
| 105 |
|
| 106 |
class SessionNotificationsRequest(BaseModel):
|
|
|
|
| 109 |
destinations: list[str]
|
| 110 |
|
| 111 |
|
| 112 |
+
class SessionYoloRequest(BaseModel):
    """Update a session's auto-approval policy."""

    # Required: turn auto-approval on or off.
    enabled: bool
    # Optional cap in USD; must be >= 0 when given. The route checks
    # ``model_fields_set`` to tell an omitted cap from an explicit null.
    cost_cap_usd: float | None = Field(default=None, ge=0)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
class HealthResponse(BaseModel):
|
| 120 |
"""Health check response."""
|
| 121 |
|
backend/routes/agent.py
CHANGED
|
@@ -26,6 +26,7 @@ from models import (
|
|
| 26 |
SessionInfo,
|
| 27 |
SessionNotificationsRequest,
|
| 28 |
SessionResponse,
|
|
|
|
| 29 |
SubmitRequest,
|
| 30 |
TruncateRequest,
|
| 31 |
)
|
|
@@ -498,6 +499,26 @@ async def set_session_notifications(
|
|
| 498 |
}
|
| 499 |
|
| 500 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 501 |
@router.get("/user/quota")
|
| 502 |
async def get_user_quota(user: dict = Depends(get_current_user)) -> dict:
|
| 503 |
"""Return the user's plan tier and today's premium-model quota state."""
|
|
|
|
| 26 |
SessionInfo,
|
| 27 |
SessionNotificationsRequest,
|
| 28 |
SessionResponse,
|
| 29 |
+
SessionYoloRequest,
|
| 30 |
SubmitRequest,
|
| 31 |
TruncateRequest,
|
| 32 |
)
|
|
|
|
| 499 |
}
|
| 500 |
|
| 501 |
|
| 502 |
+
@router.patch("/session/{session_id}/yolo")
async def set_session_yolo(
    session_id: str,
    body: SessionYoloRequest,
    user: dict = Depends(get_current_user),
) -> dict:
    """Update the session-scoped auto-approval policy.

    Verifies that ``user`` may access the session, applies the requested
    policy through the session manager and returns the resulting summary
    merged with ``session_id``.

    Raises:
        HTTPException: 400 when the session manager rejects the update
            with a ``ValueError``.
    """
    await _check_session_access(session_id, user)
    try:
        summary = await session_manager.update_session_auto_approval(
            session_id,
            enabled=body.enabled,
            cost_cap_usd=body.cost_cap_usd,
            # Lets the manager tell "cap omitted" from "cap explicitly null".
            cap_provided="cost_cap_usd" in body.model_fields_set,
        )
    except ValueError as e:
        # Chain the cause so the original error stays attached in tracebacks.
        raise HTTPException(status_code=400, detail=str(e)) from e
    return {"session_id": session_id, **summary}
|
| 520 |
+
|
| 521 |
+
|
| 522 |
@router.get("/user/quota")
|
| 523 |
async def get_user_quota(user: dict = Depends(get_current_user)) -> dict:
|
| 524 |
"""Return the user's plan tier and today's premium-model quota state."""
|
backend/session_manager.py
CHANGED
|
@@ -116,6 +116,7 @@ class SessionCapacityError(Exception):
|
|
| 116 |
# and per-request overhead.
|
| 117 |
MAX_SESSIONS: int = 200
|
| 118 |
MAX_SESSIONS_PER_USER: int = 10
|
|
|
|
| 119 |
|
| 120 |
|
| 121 |
class SessionManager:
|
|
@@ -297,6 +298,20 @@ class SessionManager:
|
|
| 297 |
return "ended"
|
| 298 |
return "idle"
|
| 299 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 300 |
async def _start_agent_session(
|
| 301 |
self,
|
| 302 |
*,
|
|
@@ -370,6 +385,20 @@ class SessionManager:
|
|
| 370 |
notification_destinations=list(
|
| 371 |
agent_session.session.notification_destinations
|
| 372 |
),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 373 |
)
|
| 374 |
except Exception as e:
|
| 375 |
logger.warning(
|
|
@@ -451,6 +480,14 @@ class SessionManager:
|
|
| 451 |
|
| 452 |
self._restore_pending_approval(session, meta.get("pending_approval") or [])
|
| 453 |
session.turn_count = int(meta.get("turn_count") or 0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 454 |
|
| 455 |
created_at = meta.get("created_at")
|
| 456 |
if not isinstance(created_at, datetime):
|
|
@@ -883,6 +920,43 @@ class SessionManager:
|
|
| 883 |
await self.persist_session_snapshot(agent_session, runtime_state="idle")
|
| 884 |
return True
|
| 885 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 886 |
def get_session_owner(self, session_id: str) -> str | None:
|
| 887 |
"""Get the user_id that owns a session, or None if session doesn't exist."""
|
| 888 |
agent_session = self.sessions.get(session_id)
|
|
@@ -925,6 +999,7 @@ class SessionManager:
|
|
| 925 |
"notification_destinations": list(
|
| 926 |
agent_session.session.notification_destinations
|
| 927 |
),
|
|
|
|
| 928 |
}
|
| 929 |
|
| 930 |
def set_notification_destinations(
|
|
@@ -991,6 +1066,25 @@ class SessionManager:
|
|
| 991 |
"model": row.get("model"),
|
| 992 |
"title": row.get("title"),
|
| 993 |
"notification_destinations": row.get("notification_destinations") or [],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 994 |
}
|
| 995 |
)
|
| 996 |
return results
|
|
|
|
| 116 |
# and per-request overhead.
|
| 117 |
MAX_SESSIONS: int = 200
|
| 118 |
MAX_SESSIONS_PER_USER: int = 10
|
| 119 |
+
DEFAULT_YOLO_COST_CAP_USD: float = 5.0
|
| 120 |
|
| 121 |
|
| 122 |
class SessionManager:
|
|
|
|
| 298 |
return "ended"
|
| 299 |
return "idle"
|
| 300 |
|
| 301 |
+
@staticmethod
|
| 302 |
+
def _auto_approval_summary(session: Session) -> dict[str, Any]:
|
| 303 |
+
if hasattr(session, "auto_approval_policy_summary"):
|
| 304 |
+
return session.auto_approval_policy_summary()
|
| 305 |
+
cap = getattr(session, "auto_approval_cost_cap_usd", None)
|
| 306 |
+
estimated = float(getattr(session, "auto_approval_estimated_spend_usd", 0.0) or 0.0)
|
| 307 |
+
remaining = None if cap is None else round(max(0.0, float(cap) - estimated), 4)
|
| 308 |
+
return {
|
| 309 |
+
"enabled": bool(getattr(session, "auto_approval_enabled", False)),
|
| 310 |
+
"cost_cap_usd": cap,
|
| 311 |
+
"estimated_spend_usd": round(estimated, 4),
|
| 312 |
+
"remaining_usd": remaining,
|
| 313 |
+
}
|
| 314 |
+
|
| 315 |
async def _start_agent_session(
|
| 316 |
self,
|
| 317 |
*,
|
|
|
|
| 385 |
notification_destinations=list(
|
| 386 |
agent_session.session.notification_destinations
|
| 387 |
),
|
| 388 |
+
auto_approval_enabled=bool(
|
| 389 |
+
getattr(agent_session.session, "auto_approval_enabled", False)
|
| 390 |
+
),
|
| 391 |
+
auto_approval_cost_cap_usd=getattr(
|
| 392 |
+
agent_session.session, "auto_approval_cost_cap_usd", None
|
| 393 |
+
),
|
| 394 |
+
auto_approval_estimated_spend_usd=float(
|
| 395 |
+
getattr(
|
| 396 |
+
agent_session.session,
|
| 397 |
+
"auto_approval_estimated_spend_usd",
|
| 398 |
+
0.0,
|
| 399 |
+
)
|
| 400 |
+
or 0.0
|
| 401 |
+
),
|
| 402 |
)
|
| 403 |
except Exception as e:
|
| 404 |
logger.warning(
|
|
|
|
| 480 |
|
| 481 |
self._restore_pending_approval(session, meta.get("pending_approval") or [])
|
| 482 |
session.turn_count = int(meta.get("turn_count") or 0)
|
| 483 |
+
session.auto_approval_enabled = bool(meta.get("auto_approval_enabled", False))
|
| 484 |
+
raw_cap = meta.get("auto_approval_cost_cap_usd")
|
| 485 |
+
session.auto_approval_cost_cap_usd = (
|
| 486 |
+
float(raw_cap) if isinstance(raw_cap, int | float) else None
|
| 487 |
+
)
|
| 488 |
+
session.auto_approval_estimated_spend_usd = float(
|
| 489 |
+
meta.get("auto_approval_estimated_spend_usd") or 0.0
|
| 490 |
+
)
|
| 491 |
|
| 492 |
created_at = meta.get("created_at")
|
| 493 |
if not isinstance(created_at, datetime):
|
|
|
|
| 920 |
await self.persist_session_snapshot(agent_session, runtime_state="idle")
|
| 921 |
return True
|
| 922 |
|
| 923 |
+
async def update_session_auto_approval(
|
| 924 |
+
self,
|
| 925 |
+
session_id: str,
|
| 926 |
+
*,
|
| 927 |
+
enabled: bool,
|
| 928 |
+
cost_cap_usd: float | None,
|
| 929 |
+
cap_provided: bool = False,
|
| 930 |
+
) -> dict[str, Any]:
|
| 931 |
+
agent_session = self.sessions.get(session_id)
|
| 932 |
+
if not agent_session or not agent_session.is_active:
|
| 933 |
+
raise ValueError("Session not found or inactive")
|
| 934 |
+
|
| 935 |
+
session = agent_session.session
|
| 936 |
+
if enabled:
|
| 937 |
+
if not cap_provided and cost_cap_usd is None:
|
| 938 |
+
cost_cap_usd = getattr(
|
| 939 |
+
session, "auto_approval_cost_cap_usd", None
|
| 940 |
+
)
|
| 941 |
+
if cost_cap_usd is None:
|
| 942 |
+
cost_cap_usd = DEFAULT_YOLO_COST_CAP_USD
|
| 943 |
+
elif cost_cap_usd is None:
|
| 944 |
+
cost_cap_usd = DEFAULT_YOLO_COST_CAP_USD
|
| 945 |
+
else:
|
| 946 |
+
if not cap_provided:
|
| 947 |
+
cost_cap_usd = getattr(session, "auto_approval_cost_cap_usd", None)
|
| 948 |
+
|
| 949 |
+
if hasattr(session, "set_auto_approval_policy"):
|
| 950 |
+
session.set_auto_approval_policy(
|
| 951 |
+
enabled=enabled,
|
| 952 |
+
cost_cap_usd=cost_cap_usd,
|
| 953 |
+
)
|
| 954 |
+
else:
|
| 955 |
+
session.auto_approval_enabled = bool(enabled)
|
| 956 |
+
session.auto_approval_cost_cap_usd = cost_cap_usd
|
| 957 |
+
await self.persist_session_snapshot(agent_session)
|
| 958 |
+
return self._auto_approval_summary(session)
|
| 959 |
+
|
| 960 |
def get_session_owner(self, session_id: str) -> str | None:
|
| 961 |
"""Get the user_id that owns a session, or None if session doesn't exist."""
|
| 962 |
agent_session = self.sessions.get(session_id)
|
|
|
|
| 999 |
"notification_destinations": list(
|
| 1000 |
agent_session.session.notification_destinations
|
| 1001 |
),
|
| 1002 |
+
"auto_approval": self._auto_approval_summary(agent_session.session),
|
| 1003 |
}
|
| 1004 |
|
| 1005 |
def set_notification_destinations(
|
|
|
|
| 1066 |
"model": row.get("model"),
|
| 1067 |
"title": row.get("title"),
|
| 1068 |
"notification_destinations": row.get("notification_destinations") or [],
|
| 1069 |
+
"auto_approval": {
|
| 1070 |
+
"enabled": bool(row.get("auto_approval_enabled", False)),
|
| 1071 |
+
"cost_cap_usd": row.get("auto_approval_cost_cap_usd"),
|
| 1072 |
+
"estimated_spend_usd": float(
|
| 1073 |
+
row.get("auto_approval_estimated_spend_usd") or 0.0
|
| 1074 |
+
),
|
| 1075 |
+
"remaining_usd": (
|
| 1076 |
+
None
|
| 1077 |
+
if row.get("auto_approval_cost_cap_usd") is None
|
| 1078 |
+
else round(
|
| 1079 |
+
max(
|
| 1080 |
+
0.0,
|
| 1081 |
+
float(row.get("auto_approval_cost_cap_usd") or 0.0)
|
| 1082 |
+
- float(row.get("auto_approval_estimated_spend_usd") or 0.0),
|
| 1083 |
+
),
|
| 1084 |
+
4,
|
| 1085 |
+
)
|
| 1086 |
+
),
|
| 1087 |
+
},
|
| 1088 |
}
|
| 1089 |
)
|
| 1090 |
return results
|
frontend/src/components/Chat/ToolCallGroup.tsx
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
import { useCallback, useEffect, useMemo, useRef, useState } from 'react';
|
| 2 |
-
import { Box, Stack, Typography, Chip, Button, TextField, IconButton, Link, CircularProgress } from '@mui/material';
|
| 3 |
import CheckCircleOutlineIcon from '@mui/icons-material/CheckCircleOutline';
|
| 4 |
import ErrorOutlineIcon from '@mui/icons-material/ErrorOutline';
|
| 5 |
import OpenInNewIcon from '@mui/icons-material/OpenInNew';
|
|
@@ -502,6 +502,7 @@ function InlineApproval({
|
|
| 502 |
}) {
|
| 503 |
const [feedback, setFeedback] = useState('');
|
| 504 |
const args = input as Record<string, unknown> | undefined;
|
|
|
|
| 505 |
const { setPanel, getEditedScript } = useAgentStore();
|
| 506 |
const { setRightPanelOpen, setLeftSidebarOpen } = useLayoutStore();
|
| 507 |
const hasEditedScript = !!getEditedScript(toolCallId);
|
|
@@ -521,6 +522,24 @@ function InlineApproval({
|
|
| 521 |
|
| 522 |
return (
|
| 523 |
<Box sx={{ px: 1.5, py: 1.5, borderTop: '1px solid var(--tool-border)' }}>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 524 |
{toolName === 'sandbox_create' && args && (() => {
|
| 525 |
const hw = String(args.hardware || 'cpu-basic');
|
| 526 |
const cost = costLabel(hw);
|
|
|
|
| 1 |
import { useCallback, useEffect, useMemo, useRef, useState } from 'react';
|
| 2 |
+
import { Alert, Box, Stack, Typography, Chip, Button, TextField, IconButton, Link, CircularProgress } from '@mui/material';
|
| 3 |
import CheckCircleOutlineIcon from '@mui/icons-material/CheckCircleOutline';
|
| 4 |
import ErrorOutlineIcon from '@mui/icons-material/ErrorOutline';
|
| 5 |
import OpenInNewIcon from '@mui/icons-material/OpenInNew';
|
|
|
|
| 502 |
}) {
|
| 503 |
const [feedback, setFeedback] = useState('');
|
| 504 |
const args = input as Record<string, unknown> | undefined;
|
| 505 |
+
const autoApproval = useAgentStore((state) => state.budgetBlocks[toolCallId]);
|
| 506 |
const { setPanel, getEditedScript } = useAgentStore();
|
| 507 |
const { setRightPanelOpen, setLeftSidebarOpen } = useLayoutStore();
|
| 508 |
const hasEditedScript = !!getEditedScript(toolCallId);
|
|
|
|
| 522 |
|
| 523 |
return (
|
| 524 |
<Box sx={{ px: 1.5, py: 1.5, borderTop: '1px solid var(--tool-border)' }}>
|
| 525 |
+
{autoApproval && (
|
| 526 |
+
<Alert
|
| 527 |
+
severity="warning"
|
| 528 |
+
sx={{
|
| 529 |
+
mb: 1.5,
|
| 530 |
+
py: 0.5,
|
| 531 |
+
bgcolor: 'rgba(245,158,11,0.08)',
|
| 532 |
+
border: '1px solid rgba(245,158,11,0.18)',
|
| 533 |
+
color: 'var(--text)',
|
| 534 |
+
'& .MuiAlert-icon': { color: 'var(--accent-yellow)' },
|
| 535 |
+
}}
|
| 536 |
+
>
|
| 537 |
+
<Typography variant="body2" sx={{ fontSize: '0.72rem' }}>
|
| 538 |
+
YOLO paused: {autoApproval.reason || 'manual approval required.'}
|
| 539 |
+
</Typography>
|
| 540 |
+
</Alert>
|
| 541 |
+
)}
|
| 542 |
+
|
| 543 |
{toolName === 'sandbox_create' && args && (() => {
|
| 544 |
const hw = String(args.hardware || 'cpu-basic');
|
| 545 |
const cost = costLabel(hw);
|
frontend/src/components/Layout/AppLayout.tsx
CHANGED
|
@@ -24,6 +24,7 @@ import SessionSidebar from '@/components/SessionSidebar/SessionSidebar';
|
|
| 24 |
import SessionChat from '@/components/SessionChat';
|
| 25 |
import CodePanel from '@/components/CodePanel/CodePanel';
|
| 26 |
import WelcomeScreen from '@/components/WelcomeScreen/WelcomeScreen';
|
|
|
|
| 27 |
import { apiFetch } from '@/utils/api';
|
| 28 |
|
| 29 |
const DRAWER_WIDTH = 260;
|
|
@@ -252,6 +253,7 @@ export default function AppLayout() {
|
|
| 252 |
</Box>
|
| 253 |
|
| 254 |
<Box sx={{ display: 'flex', alignItems: 'center', gap: 0.5 }}>
|
|
|
|
| 255 |
<IconButton
|
| 256 |
onClick={toggleTheme}
|
| 257 |
size="small"
|
|
|
|
| 24 |
import SessionChat from '@/components/SessionChat';
|
| 25 |
import CodePanel from '@/components/CodePanel/CodePanel';
|
| 26 |
import WelcomeScreen from '@/components/WelcomeScreen/WelcomeScreen';
|
| 27 |
+
import YoloControl from '@/components/YoloControl';
|
| 28 |
import { apiFetch } from '@/utils/api';
|
| 29 |
|
| 30 |
const DRAWER_WIDTH = 260;
|
|
|
|
| 253 |
</Box>
|
| 254 |
|
| 255 |
<Box sx={{ display: 'flex', alignItems: 'center', gap: 0.5 }}>
|
| 256 |
+
<YoloControl />
|
| 257 |
<IconButton
|
| 258 |
onClick={toggleTheme}
|
| 259 |
size="small"
|
frontend/src/components/YoloControl.tsx
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { useEffect, useMemo, useState } from 'react';
|
| 2 |
+
import {
|
| 3 |
+
Button,
|
| 4 |
+
Dialog,
|
| 5 |
+
DialogActions,
|
| 6 |
+
DialogContent,
|
| 7 |
+
DialogTitle,
|
| 8 |
+
TextField,
|
| 9 |
+
Tooltip,
|
| 10 |
+
Typography,
|
| 11 |
+
} from '@mui/material';
|
| 12 |
+
import BoltOutlinedIcon from '@mui/icons-material/BoltOutlined';
|
| 13 |
+
import { useSessionStore } from '@/store/sessionStore';
|
| 14 |
+
import { apiFetch } from '@/utils/api';
|
| 15 |
+
|
| 16 |
+
// Fallback session spend cap in USD: seeds the cap input and is sent when
// YOLO is enabled for a session that has no stored cap yet.
const DEFAULT_CAP_USD = 5;
|
| 17 |
+
|
| 18 |
+
/**
 * Format a USD amount for compact display.
 *
 * - `null` / `undefined` render as `'uncapped'`.
 * - Amounts of 100 or more are rounded to whole dollars.
 * - Smaller amounts keep two decimals, except that a trailing `.00` is dropped.
 */
function money(value: number | null | undefined): string {
  if (value === undefined || value === null) {
    return 'uncapped';
  }
  const digits = value >= 100
    ? value.toFixed(0)
    : value.toFixed(2).replace(/\.00$/, '');
  return `$${digits}`;
}
|
| 23 |
+
|
| 24 |
+
/**
 * Toolbar control for the active session's "YOLO" auto-approval policy.
 *
 * Renders a toggle button (showing the remaining budget while enabled) and a
 * dialog for editing the session's USD spend cap. Policy changes are sent via
 * `PATCH /api/session/{id}/yolo`; the JSON response is written back into the
 * session store through `updateSessionYolo`.
 */
export default function YoloControl() {
  const { sessions, activeSessionId, updateSessionYolo } = useSessionStore();
  const activeSession = useMemo(
    () => sessions.find((s) => s.id === activeSessionId) || null,
    [sessions, activeSessionId],
  );
  const [dialogOpen, setDialogOpen] = useState(false);
  const [capInput, setCapInput] = useState(String(DEFAULT_CAP_USD));
  const [busy, setBusy] = useState(false);
  const [error, setError] = useState<string | null>(null);

  const enabled = Boolean(activeSession?.autoApprovalEnabled);
  const disabled = !activeSessionId || activeSession?.expired || busy;
  const remaining = activeSession?.autoApprovalRemainingUsd ?? null;
  const cap = activeSession?.autoApprovalCostCapUsd ?? null;

  // Re-seed the text field whenever the active session or its stored cap changes.
  useEffect(() => {
    if (!activeSession) return;
    setCapInput(String(activeSession.autoApprovalCostCapUsd ?? DEFAULT_CAP_USD));
  }, [activeSession?.id, activeSession?.autoApprovalCostCapUsd]); // eslint-disable-line react-hooks/exhaustive-deps

  /**
   * Send the new policy to the backend and mirror the response into the store.
   * `cost_cap_usd` is only included when `nextCap` is supplied.
   * Returns the response payload, or `null` when there is no active session
   * or the request failed (in which case `error` is set).
   */
  async function patchPolicy(nextEnabled: boolean, nextCap?: number) {
    if (!activeSessionId) return null;
    setBusy(true);
    setError(null);
    try {
      const body: Record<string, unknown> = { enabled: nextEnabled };
      if (nextCap !== undefined) body.cost_cap_usd = nextCap;
      const response = await apiFetch(`/api/session/${activeSessionId}/yolo`, {
        method: 'PATCH',
        body: JSON.stringify(body),
      });
      if (!response.ok) {
        throw new Error(await response.text());
      }
      const data = await response.json();
      updateSessionYolo(activeSessionId, data);
      return data;
    } catch {
      setError('Could not update YOLO settings.');
      return null;
    } finally {
      setBusy(false);
    }
  }

  /** Disable when currently enabled; otherwise enable and open the budget dialog. */
  const handleToggle = async () => {
    if (disabled) return;
    if (enabled) {
      await patchPolicy(false);
      return;
    }
    const nextCap = cap ?? DEFAULT_CAP_USD;
    const updated = await patchPolicy(true, nextCap);
    if (updated) {
      setCapInput(String(updated.cost_cap_usd ?? nextCap));
      setDialogOpen(true);
    }
  };

  /** Validate the cap field and persist it; the dialog closes only on success. */
  const handleSaveCap = async () => {
    const trimmed = capInput.trim();
    const parsed = Number(trimmed);
    // Number('') is 0, and a number input reports '' for non-numeric text, so
    // a blank field must be rejected explicitly or it would save a $0 cap.
    if (trimmed === '' || !Number.isFinite(parsed) || parsed < 0) {
      setError('Enter a non-negative dollar amount.');
      return;
    }
    const updated = await patchPolicy(true, parsed);
    if (updated) setDialogOpen(false);
  };

  return (
    <>
      <Tooltip title={enabled ? 'Disable session YOLO auto-approval' : 'Enable session YOLO auto-approval'}>
        <span>
          <Button
            size="small"
            variant={enabled ? 'contained' : 'outlined'}
            disabled={disabled}
            onClick={handleToggle}
            startIcon={<BoltOutlinedIcon sx={{ fontSize: 16 }} />}
            sx={{
              minWidth: { xs: 74, md: 116 },
              height: 32,
              px: { xs: 1, md: 1.25 },
              borderRadius: '8px',
              textTransform: 'none',
              fontSize: '0.72rem',
              whiteSpace: 'nowrap',
              bgcolor: enabled ? 'var(--accent-yellow)' : 'transparent',
              color: enabled ? '#111' : 'text.secondary',
              borderColor: enabled ? 'var(--accent-yellow)' : 'divider',
              '&:hover': {
                bgcolor: enabled ? 'var(--accent-yellow)' : 'action.hover',
                borderColor: 'var(--accent-yellow)',
              },
            }}
          >
            {enabled ? `YOLO ${money(remaining)}` : 'YOLO'}
          </Button>
        </span>
      </Tooltip>

      <Dialog open={dialogOpen} onClose={() => setDialogOpen(false)} maxWidth="xs" fullWidth>
        <DialogTitle sx={{ pb: 1 }}>YOLO Budget</DialogTitle>
        <DialogContent sx={{ display: 'flex', flexDirection: 'column', gap: 1.5, pt: 1 }}>
          <Typography variant="body2" color="text.secondary">
            Auto-approval is active for this session. Scheduled HF jobs still require approval.
          </Typography>
          <TextField
            autoFocus
            label="Session cap (USD)"
            type="number"
            size="small"
            value={capInput}
            onChange={(e) => setCapInput(e.target.value)}
            inputProps={{ min: 0, step: 0.5 }}
            error={Boolean(error)}
            helperText={error || `Estimated spend: ${money(activeSession?.autoApprovalEstimatedSpendUsd ?? 0)} of ${money(cap)}`}
          />
        </DialogContent>
        <DialogActions>
          <Button onClick={() => setDialogOpen(false)} sx={{ textTransform: 'none' }}>
            Close
          </Button>
          <Button onClick={handleSaveCap} disabled={busy} variant="contained" sx={{ textTransform: 'none' }}>
            Save
          </Button>
        </DialogActions>
      </Dialog>
    </>
  );
}
|
frontend/src/hooks/useAgentChat.ts
CHANGED
|
@@ -36,7 +36,7 @@ export function useAgentChat({ sessionId, isActive, onReady, onError, onSessionD
|
|
| 36 |
const isActiveRef = useRef(isActive);
|
| 37 |
isActiveRef.current = isActive;
|
| 38 |
|
| 39 |
-
const { setNeedsAttention } = useSessionStore();
|
| 40 |
|
| 41 |
// Helper: update this session's state (mirrors to globals if active)
|
| 42 |
const updateSession = useAgentStore.getState().updateSession;
|
|
@@ -186,6 +186,20 @@ export function useAgentChat({ sessionId, isActive, onReady, onError, onSessionD
|
|
| 186 |
if (!tools.length) return;
|
| 187 |
setNeedsAttention(sessionId, true);
|
| 188 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
updateSession(sessionId, { activityStatus: { type: 'waiting-approval' } });
|
| 190 |
|
| 191 |
// Build panel data for this session's pending approval
|
|
@@ -480,6 +494,9 @@ export function useAgentChat({ sessionId, isActive, onReady, onError, onSessionD
|
|
| 480 |
);
|
| 481 |
if (pendingIds.size > 0) setNeedsAttention(sessionId, true);
|
| 482 |
}
|
|
|
|
|
|
|
|
|
|
| 483 |
return { data, pendingIds, info };
|
| 484 |
}
|
| 485 |
return { data, pendingIds, info: null };
|
|
@@ -562,7 +579,15 @@ export function useAgentChat({ sessionId, isActive, onReady, onError, onSessionD
|
|
| 562 |
return true;
|
| 563 |
} else if (et === 'approval_required') {
|
| 564 |
sideChannel.onApprovalRequired(
|
| 565 |
-
(event.data?.tools || []) as Array<{
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 566 |
);
|
| 567 |
stopReconnect();
|
| 568 |
const result = await hydrateMessages();
|
|
|
|
| 36 |
const isActiveRef = useRef(isActive);
|
| 37 |
isActiveRef.current = isActive;
|
| 38 |
|
| 39 |
+
const { setNeedsAttention, updateSessionYolo } = useSessionStore();
|
| 40 |
|
| 41 |
// Helper: update this session's state (mirrors to globals if active)
|
| 42 |
const updateSession = useAgentStore.getState().updateSession;
|
|
|
|
| 186 |
if (!tools.length) return;
|
| 187 |
setNeedsAttention(sessionId, true);
|
| 188 |
|
| 189 |
+
const store = useAgentStore.getState();
|
| 190 |
+
for (const tool of tools) {
|
| 191 |
+
store.setToolBudgetBlock(
|
| 192 |
+
tool.tool_call_id,
|
| 193 |
+
tool.auto_approval_blocked
|
| 194 |
+
? {
|
| 195 |
+
reason: tool.block_reason ?? null,
|
| 196 |
+
estimatedCostUsd: tool.estimated_cost_usd ?? null,
|
| 197 |
+
remainingCapUsd: tool.remaining_cap_usd ?? null,
|
| 198 |
+
}
|
| 199 |
+
: null,
|
| 200 |
+
);
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
updateSession(sessionId, { activityStatus: { type: 'waiting-approval' } });
|
| 204 |
|
| 205 |
// Build panel data for this session's pending approval
|
|
|
|
| 494 |
);
|
| 495 |
if (pendingIds.size > 0) setNeedsAttention(sessionId, true);
|
| 496 |
}
|
| 497 |
+
if (info.auto_approval) {
|
| 498 |
+
updateSessionYolo(sessionId, info.auto_approval);
|
| 499 |
+
}
|
| 500 |
return { data, pendingIds, info };
|
| 501 |
}
|
| 502 |
return { data, pendingIds, info: null };
|
|
|
|
| 579 |
return true;
|
| 580 |
} else if (et === 'approval_required') {
|
| 581 |
sideChannel.onApprovalRequired(
|
| 582 |
+
(event.data?.tools || []) as Array<{
|
| 583 |
+
tool: string;
|
| 584 |
+
arguments: Record<string, unknown>;
|
| 585 |
+
tool_call_id: string;
|
| 586 |
+
auto_approval_blocked?: boolean;
|
| 587 |
+
block_reason?: string | null;
|
| 588 |
+
estimated_cost_usd?: number | null;
|
| 589 |
+
remaining_cap_usd?: number | null;
|
| 590 |
+
}>,
|
| 591 |
);
|
| 592 |
stopReconnect();
|
| 593 |
const result = await hydrateMessages();
|
frontend/src/lib/sse-chat-transport.ts
CHANGED
|
@@ -26,7 +26,15 @@ export interface SideChannelCallbacks {
|
|
| 26 |
onToolLog: (tool: string, log: string, agentId?: string, label?: string) => void;
|
| 27 |
onConnectionChange: (connected: boolean) => void;
|
| 28 |
onSessionDead: (sessionId: string) => void;
|
| 29 |
-
onApprovalRequired: (tools: Array<{
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
onToolCallPanel: (tool: string, args: Record<string, unknown>) => void;
|
| 31 |
onToolOutputPanel: (tool: string, toolCallId: string, output: string, success: boolean) => void;
|
| 32 |
onStreaming: () => void;
|
|
@@ -236,6 +244,10 @@ function createEventToChunkStream(sideChannel: SideChannelCallbacks): TransformS
|
|
| 236 |
tool: string;
|
| 237 |
arguments: Record<string, unknown>;
|
| 238 |
tool_call_id: string;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
}>;
|
| 240 |
if (!tools) break;
|
| 241 |
|
|
|
|
| 26 |
onToolLog: (tool: string, log: string, agentId?: string, label?: string) => void;
|
| 27 |
onConnectionChange: (connected: boolean) => void;
|
| 28 |
onSessionDead: (sessionId: string) => void;
|
| 29 |
+
onApprovalRequired: (tools: Array<{
|
| 30 |
+
tool: string;
|
| 31 |
+
arguments: Record<string, unknown>;
|
| 32 |
+
tool_call_id: string;
|
| 33 |
+
auto_approval_blocked?: boolean;
|
| 34 |
+
block_reason?: string | null;
|
| 35 |
+
estimated_cost_usd?: number | null;
|
| 36 |
+
remaining_cap_usd?: number | null;
|
| 37 |
+
}>) => void;
|
| 38 |
onToolCallPanel: (tool: string, args: Record<string, unknown>) => void;
|
| 39 |
onToolOutputPanel: (tool: string, toolCallId: string, output: string, success: boolean) => void;
|
| 40 |
onStreaming: () => void;
|
|
|
|
| 244 |
tool: string;
|
| 245 |
arguments: Record<string, unknown>;
|
| 246 |
tool_call_id: string;
|
| 247 |
+
auto_approval_blocked?: boolean;
|
| 248 |
+
block_reason?: string | null;
|
| 249 |
+
estimated_cost_usd?: number | null;
|
| 250 |
+
remaining_cap_usd?: number | null;
|
| 251 |
}>;
|
| 252 |
if (!tools) break;
|
| 253 |
|
frontend/src/store/agentStore.ts
CHANGED
|
@@ -50,6 +50,12 @@ export interface JobsUpgradeState {
|
|
| 50 |
namespace?: string | null;
|
| 51 |
}
|
| 52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
export type ActivityStatus =
|
| 54 |
| { type: 'idle' }
|
| 55 |
| { type: 'thinking' }
|
|
@@ -145,6 +151,9 @@ interface AgentStore {
|
|
| 145 |
// Tool rejected states (tool_call_id -> true if rejected by user) - persisted across renders
|
| 146 |
rejectedTools: Record<string, boolean>;
|
| 147 |
|
|
|
|
|
|
|
|
|
|
| 148 |
// ── Per-session actions ─────────────────────────────────────────────
|
| 149 |
|
| 150 |
/** Update a session's state. If it's the active session, also update flat state. */
|
|
@@ -196,6 +205,9 @@ interface AgentStore {
|
|
| 196 |
|
| 197 |
setToolRejected: (toolCallId: string, isRejected: boolean) => void;
|
| 198 |
getToolRejected: (toolCallId: string) => boolean | undefined;
|
|
|
|
|
|
|
|
|
|
| 199 |
}
|
| 200 |
|
| 201 |
/**
|
|
@@ -300,6 +312,7 @@ export const useAgentStore = create<AgentStore>()((set, get) => ({
|
|
| 300 |
trackioDashboards: loadTrackioDashboards(),
|
| 301 |
toolErrors: loadToolErrors(),
|
| 302 |
rejectedTools: loadRejectedTools(),
|
|
|
|
| 303 |
|
| 304 |
// ── Per-session state management ──────────────────────────────────
|
| 305 |
|
|
@@ -529,4 +542,24 @@ export const useAgentStore = create<AgentStore>()((set, get) => ({
|
|
| 529 |
},
|
| 530 |
|
| 531 |
getToolRejected: (toolCallId) => get().rejectedTools[toolCallId],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 532 |
}));
|
|
|
|
| 50 |
namespace?: string | null;
|
| 51 |
}
|
| 52 |
|
| 53 |
+
/**
 * Display metadata stored per tool call in the store's `budgetBlocks` map
 * (keyed by tool_call_id): an optional reason string plus optional
 * estimated-cost and remaining-cap amounts in USD.
 */
export interface ToolBudgetBlockState {
|
| 54 |
+
reason?: string | null;
|
| 55 |
+
estimatedCostUsd?: number | null;
|
| 56 |
+
remainingCapUsd?: number | null;
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
export type ActivityStatus =
|
| 60 |
| { type: 'idle' }
|
| 61 |
| { type: 'thinking' }
|
|
|
|
| 151 |
// Tool rejected states (tool_call_id -> true if rejected by user) - persisted across renders
|
| 152 |
rejectedTools: Record<string, boolean>;
|
| 153 |
|
| 154 |
+
// Tool budget-block metadata (tool_call_id -> display metadata) - transient UI state
|
| 155 |
+
budgetBlocks: Record<string, ToolBudgetBlockState>;
|
| 156 |
+
|
| 157 |
// ── Per-session actions ─────────────────────────────────────────────
|
| 158 |
|
| 159 |
/** Update a session's state. If it's the active session, also update flat state. */
|
|
|
|
| 205 |
|
| 206 |
setToolRejected: (toolCallId: string, isRejected: boolean) => void;
|
| 207 |
getToolRejected: (toolCallId: string) => boolean | undefined;
|
| 208 |
+
|
| 209 |
+
setToolBudgetBlock: (toolCallId: string, block: ToolBudgetBlockState | null) => void;
|
| 210 |
+
getToolBudgetBlock: (toolCallId: string) => ToolBudgetBlockState | undefined;
|
| 211 |
}
|
| 212 |
|
| 213 |
/**
|
|
|
|
| 312 |
trackioDashboards: loadTrackioDashboards(),
|
| 313 |
toolErrors: loadToolErrors(),
|
| 314 |
rejectedTools: loadRejectedTools(),
|
| 315 |
+
budgetBlocks: {},
|
| 316 |
|
| 317 |
// ── Per-session state management ──────────────────────────────────
|
| 318 |
|
|
|
|
| 542 |
},
|
| 543 |
|
| 544 |
getToolRejected: (toolCallId) => get().rejectedTools[toolCallId],
|
| 545 |
+
|
| 546 |
+
// ── Tool Budget Blocks ───────────────────────────────────────────────
|
| 547 |
+
|
| 548 |
+
// Store (or, when `block` is falsy, remove) the budget-block entry for one
// tool call. Always produces a fresh `budgetBlocks` object.
setToolBudgetBlock: (toolCallId, block) => {
  set((state) => {
    if (block) {
      return {
        budgetBlocks: { ...state.budgetBlocks, [toolCallId]: block },
      };
    }
    const remainingBlocks = { ...state.budgetBlocks };
    delete remainingBlocks[toolCallId];
    return { budgetBlocks: remainingBlocks };
  });
},
|
| 563 |
+
|
| 564 |
+
// Look up the budget-block entry for a tool call; undefined when none is stored.
getToolBudgetBlock: (toolCallId) => get().budgetBlocks[toolCallId],
|
| 565 |
}));
|
frontend/src/store/sessionStore.ts
CHANGED
|
@@ -27,7 +27,19 @@ interface SessionStore {
|
|
| 27 |
created_at: string;
|
| 28 |
is_active?: boolean;
|
| 29 |
pending_approval?: unknown[] | null;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
}>) => void;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
/** Atomically swap a session's id in the list + both localStorage caches.
|
| 32 |
* Used when we rehydrate an expired session into a freshly-created backend
|
| 33 |
* session — preserves title, timestamps, and messages. */
|
|
@@ -47,6 +59,10 @@ export const useSessionStore = create<SessionStore>()(
|
|
| 47 |
createdAt: new Date().toISOString(),
|
| 48 |
isActive: true,
|
| 49 |
needsAttention: false,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
};
|
| 51 |
set((state) => ({
|
| 52 |
sessions: [...state.sessions, newSession],
|
|
@@ -93,12 +109,21 @@ export const useSessionStore = create<SessionStore>()(
|
|
| 93 |
if (!id) continue;
|
| 94 |
const existing = byId.get(id);
|
| 95 |
if (existing) {
|
|
|
|
| 96 |
const updated = {
|
| 97 |
...existing,
|
| 98 |
title: server.title || existing.title,
|
| 99 |
isActive: server.is_active ?? existing.isActive,
|
| 100 |
needsAttention: Boolean(server.pending_approval?.length) || existing.needsAttention,
|
| 101 |
expired: false,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
};
|
| 103 |
const idx = merged.findIndex((s) => s.id === id);
|
| 104 |
if (idx >= 0) merged[idx] = updated;
|
|
@@ -112,6 +137,10 @@ export const useSessionStore = create<SessionStore>()(
|
|
| 112 |
isActive: server.is_active ?? true,
|
| 113 |
needsAttention: Boolean(server.pending_approval?.length),
|
| 114 |
expired: false,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
};
|
| 116 |
merged.push(newSession);
|
| 117 |
byId.set(id, newSession);
|
|
@@ -123,6 +152,22 @@ export const useSessionStore = create<SessionStore>()(
|
|
| 123 |
});
|
| 124 |
},
|
| 125 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
renameSession: (oldId: string, newId: string) => {
|
| 127 |
if (oldId === newId) return;
|
| 128 |
moveMessages(oldId, newId);
|
|
|
|
| 27 |
created_at: string;
|
| 28 |
is_active?: boolean;
|
| 29 |
pending_approval?: unknown[] | null;
|
| 30 |
+
auto_approval?: {
|
| 31 |
+
enabled?: boolean;
|
| 32 |
+
cost_cap_usd?: number | null;
|
| 33 |
+
estimated_spend_usd?: number;
|
| 34 |
+
remaining_usd?: number | null;
|
| 35 |
+
} | null;
|
| 36 |
}>) => void;
|
| 37 |
+
updateSessionYolo: (id: string, policy: {
|
| 38 |
+
enabled: boolean;
|
| 39 |
+
cost_cap_usd?: number | null;
|
| 40 |
+
estimated_spend_usd?: number;
|
| 41 |
+
remaining_usd?: number | null;
|
| 42 |
+
}) => void;
|
| 43 |
/** Atomically swap a session's id in the list + both localStorage caches.
|
| 44 |
* Used when we rehydrate an expired session into a freshly-created backend
|
| 45 |
* session — preserves title, timestamps, and messages. */
|
|
|
|
| 59 |
createdAt: new Date().toISOString(),
|
| 60 |
isActive: true,
|
| 61 |
needsAttention: false,
|
| 62 |
+
autoApprovalEnabled: false,
|
| 63 |
+
autoApprovalCostCapUsd: null,
|
| 64 |
+
autoApprovalEstimatedSpendUsd: 0,
|
| 65 |
+
autoApprovalRemainingUsd: null,
|
| 66 |
};
|
| 67 |
set((state) => ({
|
| 68 |
sessions: [...state.sessions, newSession],
|
|
|
|
| 109 |
if (!id) continue;
|
| 110 |
const existing = byId.get(id);
|
| 111 |
if (existing) {
|
| 112 |
+
const auto = server.auto_approval;
|
| 113 |
const updated = {
|
| 114 |
...existing,
|
| 115 |
title: server.title || existing.title,
|
| 116 |
isActive: server.is_active ?? existing.isActive,
|
| 117 |
needsAttention: Boolean(server.pending_approval?.length) || existing.needsAttention,
|
| 118 |
expired: false,
|
| 119 |
+
...(auto
|
| 120 |
+
? {
|
| 121 |
+
autoApprovalEnabled: Boolean(auto.enabled),
|
| 122 |
+
autoApprovalCostCapUsd: auto.cost_cap_usd ?? null,
|
| 123 |
+
autoApprovalEstimatedSpendUsd: auto.estimated_spend_usd ?? 0,
|
| 124 |
+
autoApprovalRemainingUsd: auto.remaining_usd ?? null,
|
| 125 |
+
}
|
| 126 |
+
: {}),
|
| 127 |
};
|
| 128 |
const idx = merged.findIndex((s) => s.id === id);
|
| 129 |
if (idx >= 0) merged[idx] = updated;
|
|
|
|
| 137 |
isActive: server.is_active ?? true,
|
| 138 |
needsAttention: Boolean(server.pending_approval?.length),
|
| 139 |
expired: false,
|
| 140 |
+
autoApprovalEnabled: Boolean(server.auto_approval?.enabled),
|
| 141 |
+
autoApprovalCostCapUsd: server.auto_approval?.cost_cap_usd ?? null,
|
| 142 |
+
autoApprovalEstimatedSpendUsd: server.auto_approval?.estimated_spend_usd ?? 0,
|
| 143 |
+
autoApprovalRemainingUsd: server.auto_approval?.remaining_usd ?? null,
|
| 144 |
};
|
| 145 |
merged.push(newSession);
|
| 146 |
byId.set(id, newSession);
|
|
|
|
| 152 |
});
|
| 153 |
},
|
| 154 |
|
| 155 |
+
// Copy a server-side auto-approval policy onto the matching session.
// Missing cap/remaining values become null; missing spend becomes 0.
// Sessions with a different id are returned untouched.
updateSessionYolo: (id, policy) => {
  const yoloFields = {
    autoApprovalEnabled: policy.enabled,
    autoApprovalCostCapUsd: policy.cost_cap_usd ?? null,
    autoApprovalEstimatedSpendUsd: policy.estimated_spend_usd ?? 0,
    autoApprovalRemainingUsd: policy.remaining_usd ?? null,
  };
  set((state) => ({
    sessions: state.sessions.map((session) => {
      if (session.id !== id) return session;
      return { ...session, ...yoloFields };
    }),
  }));
},
|
| 170 |
+
|
| 171 |
renameSession: (oldId: string, newId: string) => {
|
| 172 |
if (oldId === newId) return;
|
| 173 |
moveMessages(oldId, newId);
|
frontend/src/types/agent.ts
CHANGED
|
@@ -21,6 +21,10 @@ export interface SessionMeta {
|
|
| 21 |
* disables input until the user chooses to restore-with-summary or
|
| 22 |
* start fresh. */
|
| 23 |
expired?: boolean;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
}
|
| 25 |
|
| 26 |
export interface ToolApproval {
|
|
|
|
| 21 |
* disables input until the user chooses to restore-with-summary or
|
| 22 |
* start fresh. */
|
| 23 |
expired?: boolean;
|
| 24 |
+
autoApprovalEnabled?: boolean;
|
| 25 |
+
autoApprovalCostCapUsd?: number | null;
|
| 26 |
+
autoApprovalEstimatedSpendUsd?: number;
|
| 27 |
+
autoApprovalRemainingUsd?: number | null;
|
| 28 |
}
|
| 29 |
|
| 30 |
export interface ToolApproval {
|
frontend/src/types/events.ts
CHANGED
|
@@ -68,6 +68,10 @@ export interface ApprovalToolItem {
|
|
| 68 |
tool: string;
|
| 69 |
arguments: Record<string, unknown>;
|
| 70 |
tool_call_id: string;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
}
|
| 72 |
|
| 73 |
export interface TurnCompleteEventData {
|
|
|
|
| 68 |
tool: string;
|
| 69 |
arguments: Record<string, unknown>;
|
| 70 |
tool_call_id: string;
|
| 71 |
+
auto_approval_blocked?: boolean;
|
| 72 |
+
block_reason?: string | null;
|
| 73 |
+
estimated_cost_usd?: number | null;
|
| 74 |
+
remaining_cap_usd?: number | null;
|
| 75 |
}
|
| 76 |
|
| 77 |
export interface TurnCompleteEventData {
|
tests/unit/test_agent_model_gating.py
CHANGED
|
@@ -127,3 +127,48 @@ async def test_user_quota_response_uses_premium_fields_only(monkeypatch):
|
|
| 127 |
"premium_daily_cap": 5,
|
| 128 |
"premium_remaining": 3,
|
| 129 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
"premium_daily_cap": 5,
|
| 128 |
"premium_remaining": 3,
|
| 129 |
}
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
@pytest.mark.asyncio
async def test_set_session_yolo_calls_manager_with_cap_presence(monkeypatch):
    """set_session_yolo forwards enabled, the cap, and cap_provided=True to the manager."""
    recorded_calls = []

    async def fake_check_session_access(session_id, user, request=None):
        assert session_id == "s1"
        assert user["user_id"] == "u1"
        return object()

    async def fake_update_session_auto_approval(session_id, **kwargs):
        recorded_calls.append((session_id, kwargs))
        return {
            "enabled": kwargs["enabled"],
            "cost_cap_usd": 7.5,
            "estimated_spend_usd": 0.0,
            "remaining_usd": 7.5,
        }

    monkeypatch.setattr(agent, "_check_session_access", fake_check_session_access)
    monkeypatch.setattr(
        agent.session_manager,
        "update_session_auto_approval",
        fake_update_session_auto_approval,
    )

    yolo_request = agent.SessionYoloRequest(enabled=True, cost_cap_usd=7.5)
    response = await agent.set_session_yolo("s1", yolo_request, {"user_id": "u1"})

    assert response["enabled"] is True
    assert response["remaining_usd"] == 7.5

    expected_kwargs = {
        "enabled": True,
        "cost_cap_usd": 7.5,
        "cap_provided": True,
    }
    assert recorded_calls == [("s1", expected_kwargs)]
|
tests/unit/test_auto_approval_policy.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from types import SimpleNamespace
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
|
| 5 |
+
from agent.config import Config
|
| 6 |
+
from agent.core import agent_loop
|
| 7 |
+
from agent.core.cost_estimation import CostEstimate
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def _config(**overrides):
    """Build a validated Config from test defaults, with keyword overrides applied on top."""
    defaults = {
        "model_name": "moonshotai/Kimi-K2.6",
        "confirm_cpu_jobs": True,
        "auto_file_upload": False,
        "yolo_mode": False,
    }
    defaults.update(overrides)
    return Config.model_validate(defaults)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _session(*, cap=5.0, spent=0.0, enabled=True):
    """Return a minimal stand-in session exposing only the auto-approval attributes set here."""
    session_attrs = {
        "config": _config(),
        "auto_approval_enabled": enabled,
        "auto_approval_cost_cap_usd": cap,
        "auto_approval_estimated_spend_usd": spent,
        "sandbox": None,
    }
    return SimpleNamespace(**session_attrs)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@pytest.mark.asyncio
async def test_session_yolo_auto_approves_non_costed_approval_tool():
    """With session YOLO enabled, an hf_repo_files upload is auto-approved."""
    upload_args = {"operation": "upload", "path": "README.md"}
    session = _session()

    decision = await agent_loop._approval_decision("hf_repo_files", upload_args, session)

    assert decision.requires_approval is False
    assert decision.auto_approved is True
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "operation",
    # Each scheduled operation appears once; the earlier list repeated
    # "scheduled run", which re-ran an identical case without adding coverage.
    ["scheduled run", "scheduled uv"],
)
async def test_scheduled_hf_jobs_always_require_manual_approval(operation):
    """Scheduled hf_jobs operations need manual approval even with both YOLO switches on."""
    session = _session()
    # Turn on config-level YOLO in addition to the session-level flag from _session().
    session.config.yolo_mode = True

    decision = await agent_loop._approval_decision(
        "hf_jobs",
        {"operation": operation, "script": "print(1)"},
        session,
    )

    assert decision.requires_approval is True
    assert decision.auto_approval_blocked is True
    assert "Scheduled HF jobs" in decision.block_reason
    assert agent_loop._needs_approval("hf_jobs", {"operation": operation}, session.config)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
@pytest.mark.asyncio
async def test_immediate_hf_job_under_cap_auto_runs(monkeypatch):
    """A $2 job with $1 of a $5 cap already spent is auto-approved."""

    async def stub_estimate(*args, **kwargs):
        return CostEstimate(estimated_cost_usd=2.0, billable=True)

    monkeypatch.setattr(agent_loop, "estimate_tool_cost", stub_estimate)

    job_args = {"operation": "run", "hardware_flavor": "a10g-large", "timeout": "1h"}
    session = _session(cap=5.0, spent=1.0)
    decision = await agent_loop._approval_decision("hf_jobs", job_args, session)

    assert decision.requires_approval is False
    assert decision.auto_approved is True
    assert decision.estimated_cost_usd == 2.0
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
@pytest.mark.asyncio
async def test_immediate_hf_job_over_cap_falls_back_to_approval(monkeypatch):
    """A $2 job with only $1 left under the cap is routed to manual approval."""

    async def stub_estimate(*args, **kwargs):
        return CostEstimate(estimated_cost_usd=2.0, billable=True)

    monkeypatch.setattr(agent_loop, "estimate_tool_cost", stub_estimate)

    job_args = {"operation": "run", "hardware_flavor": "a10g-large", "timeout": "1h"}
    session = _session(cap=5.0, spent=4.0)
    decision = await agent_loop._approval_decision("hf_jobs", job_args, session)

    assert decision.requires_approval is True
    assert decision.auto_approval_blocked is True
    assert "exceeds" in decision.block_reason
    assert decision.remaining_cap_usd == 1.0
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
@pytest.mark.asyncio
async def test_unknown_cost_falls_back_to_approval(monkeypatch):
    """A billable tool whose estimate has no price is not auto-approved."""
    unpriced = {
        "estimated_cost_usd": None,
        "billable": True,
        "block_reason": "No price is available.",
    }

    async def stub_estimate(*args, **kwargs):
        return CostEstimate(**unpriced)

    monkeypatch.setattr(agent_loop, "estimate_tool_cost", stub_estimate)

    session = _session()
    decision = await agent_loop._approval_decision(
        "sandbox_create", {"hardware": "mystery-gpu"}, session
    )

    assert decision.requires_approval is True
    assert decision.auto_approval_blocked is True
    assert decision.estimated_cost_usd is None
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
@pytest.mark.asyncio
|
| 124 |
+
async def test_batch_reservation_blocks_second_over_budget_job(monkeypatch):
|
| 125 |
+
async def fake_estimate(*args, **kwargs):
|
| 126 |
+
return CostEstimate(estimated_cost_usd=3.0, billable=True)
|
| 127 |
+
|
| 128 |
+
monkeypatch.setattr(agent_loop, "estimate_tool_cost", fake_estimate)
|
| 129 |
+
session = _session(cap=5.0, spent=0.0)
|
| 130 |
+
|
| 131 |
+
first = await agent_loop._approval_decision(
|
| 132 |
+
"hf_jobs",
|
| 133 |
+
{"operation": "run", "hardware_flavor": "a10g-large"},
|
| 134 |
+
session,
|
| 135 |
+
reserved_spend_usd=0.0,
|
| 136 |
+
)
|
| 137 |
+
second = await agent_loop._approval_decision(
|
| 138 |
+
"hf_jobs",
|
| 139 |
+
{"operation": "run", "hardware_flavor": "a10g-large"},
|
| 140 |
+
session,
|
| 141 |
+
reserved_spend_usd=first.estimated_cost_usd or 0.0,
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
assert first.requires_approval is False
|
| 145 |
+
assert second.requires_approval is True
|
| 146 |
+
assert second.remaining_cap_usd == 2.0
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
@pytest.mark.asyncio
|
| 150 |
+
async def test_manual_approval_does_not_record_spend_when_session_yolo_disabled(monkeypatch):
|
| 151 |
+
called = False
|
| 152 |
+
|
| 153 |
+
async def fake_estimate(*args, **kwargs):
|
| 154 |
+
nonlocal called
|
| 155 |
+
called = True
|
| 156 |
+
return CostEstimate(estimated_cost_usd=2.0, billable=True)
|
| 157 |
+
|
| 158 |
+
monkeypatch.setattr(agent_loop, "estimate_tool_cost", fake_estimate)
|
| 159 |
+
session = _session(enabled=False, cap=5.0, spent=0.0)
|
| 160 |
+
|
| 161 |
+
await agent_loop._record_manual_approved_spend_if_needed(
|
| 162 |
+
session,
|
| 163 |
+
"sandbox_create",
|
| 164 |
+
{"hardware": "a10g-large"},
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
assert called is False
|
| 168 |
+
assert session.auto_approval_estimated_spend_usd == 0.0
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
@pytest.mark.asyncio
|
| 172 |
+
async def test_manual_approval_records_spend_when_session_yolo_enabled(monkeypatch):
|
| 173 |
+
async def fake_estimate(*args, **kwargs):
|
| 174 |
+
return CostEstimate(estimated_cost_usd=1.25, billable=True)
|
| 175 |
+
|
| 176 |
+
monkeypatch.setattr(agent_loop, "estimate_tool_cost", fake_estimate)
|
| 177 |
+
session = _session(enabled=True, cap=5.0, spent=0.5)
|
| 178 |
+
|
| 179 |
+
await agent_loop._record_manual_approved_spend_if_needed(
|
| 180 |
+
session,
|
| 181 |
+
"sandbox_create",
|
| 182 |
+
{"hardware": "a10g-large"},
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
assert session.auto_approval_estimated_spend_usd == 1.75
|
tests/unit/test_cost_estimation.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from types import SimpleNamespace
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
|
| 5 |
+
from agent.core import cost_estimation
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def test_parse_timeout_hours_common_units():
|
| 9 |
+
assert cost_estimation.parse_timeout_hours(None) == 0.5
|
| 10 |
+
assert cost_estimation.parse_timeout_hours("30m") == 0.5
|
| 11 |
+
assert cost_estimation.parse_timeout_hours("3h") == 3
|
| 12 |
+
assert cost_estimation.parse_timeout_hours(3600) == 1
|
| 13 |
+
assert cost_estimation.parse_timeout_hours("not-a-duration") is None
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@pytest.mark.asyncio
|
| 17 |
+
async def test_estimate_hf_job_cost_uses_catalog_price(monkeypatch):
|
| 18 |
+
async def fake_catalog():
|
| 19 |
+
return {"a100-large": 4.0}
|
| 20 |
+
|
| 21 |
+
monkeypatch.setattr(cost_estimation, "hf_jobs_price_catalog", fake_catalog)
|
| 22 |
+
|
| 23 |
+
estimate = await cost_estimation.estimate_hf_job_cost(
|
| 24 |
+
{"hardware_flavor": "a100-large", "timeout": "8h"}
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
assert estimate.estimated_cost_usd == 32.0
|
| 28 |
+
assert estimate.billable is True
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@pytest.mark.asyncio
|
| 32 |
+
async def test_estimate_hf_job_cost_blocks_unknown_price(monkeypatch):
|
| 33 |
+
async def fake_catalog():
|
| 34 |
+
return {}
|
| 35 |
+
|
| 36 |
+
monkeypatch.setattr(cost_estimation, "hf_jobs_price_catalog", fake_catalog)
|
| 37 |
+
|
| 38 |
+
estimate = await cost_estimation.estimate_hf_job_cost(
|
| 39 |
+
{"hardware_flavor": "mystery-gpu", "timeout": "30m"}
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
assert estimate.estimated_cost_usd is None
|
| 43 |
+
assert estimate.billable is True
|
| 44 |
+
assert "No price" in estimate.block_reason
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@pytest.mark.asyncio
|
| 48 |
+
async def test_estimate_sandbox_cost_is_zero_for_existing_or_cpu_basic():
|
| 49 |
+
existing = await cost_estimation.estimate_sandbox_cost(
|
| 50 |
+
{"hardware": "a100-large"},
|
| 51 |
+
session=SimpleNamespace(sandbox=object()),
|
| 52 |
+
)
|
| 53 |
+
cpu = await cost_estimation.estimate_sandbox_cost({"hardware": "cpu-basic"})
|
| 54 |
+
|
| 55 |
+
assert existing.estimated_cost_usd == 0.0
|
| 56 |
+
assert existing.billable is False
|
| 57 |
+
assert cpu.estimated_cost_usd == 0.0
|
| 58 |
+
assert cpu.billable is False
|
tests/unit/test_session_manager_persistence.py
CHANGED
|
@@ -27,6 +27,23 @@ class FakeRuntimeSession:
|
|
| 27 |
self.turn_count = 0
|
| 28 |
self.config = SimpleNamespace(model_name=model)
|
| 29 |
self.notification_destinations = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
|
| 32 |
class RestoreStore(NoopSessionStore):
|
|
@@ -85,6 +102,24 @@ def _runtime_agent_session(
|
|
| 85 |
)
|
| 86 |
|
| 87 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
def _install_fake_runtime(manager: SessionManager) -> asyncio.Event:
|
| 89 |
stop = asyncio.Event()
|
| 90 |
manager.run_calls = 0 # type: ignore[attr-defined]
|
|
@@ -204,6 +239,34 @@ async def test_lazy_restore_preserves_pending_approval_tool_calls():
|
|
| 204 |
await _cancel_runtime_tasks(manager)
|
| 205 |
|
| 206 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
@pytest.mark.asyncio
|
| 208 |
async def test_list_sessions_dev_uses_store_dev_visibility():
|
| 209 |
class ListStore(NoopSessionStore):
|
|
@@ -221,6 +284,9 @@ async def test_list_sessions_dev_uses_store_dev_visibility():
|
|
| 221 |
"user_id": "alice",
|
| 222 |
"model": "m",
|
| 223 |
"created_at": datetime.now(UTC),
|
|
|
|
|
|
|
|
|
|
| 224 |
},
|
| 225 |
{
|
| 226 |
"session_id": "s2",
|
|
@@ -238,3 +304,10 @@ async def test_list_sessions_dev_uses_store_dev_visibility():
|
|
| 238 |
|
| 239 |
assert store.seen_user_id == "dev"
|
| 240 |
assert {session["session_id"] for session in sessions} == {"s1", "s2"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
self.turn_count = 0
|
| 28 |
self.config = SimpleNamespace(model_name=model)
|
| 29 |
self.notification_destinations = []
|
| 30 |
+
self.auto_approval_enabled = False
|
| 31 |
+
self.auto_approval_cost_cap_usd = None
|
| 32 |
+
self.auto_approval_estimated_spend_usd = 0.0
|
| 33 |
+
|
| 34 |
+
def auto_approval_policy_summary(self):
|
| 35 |
+
cap = self.auto_approval_cost_cap_usd
|
| 36 |
+
remaining = None if cap is None else max(0, cap - self.auto_approval_estimated_spend_usd)
|
| 37 |
+
return {
|
| 38 |
+
"enabled": self.auto_approval_enabled,
|
| 39 |
+
"cost_cap_usd": cap,
|
| 40 |
+
"estimated_spend_usd": self.auto_approval_estimated_spend_usd,
|
| 41 |
+
"remaining_usd": remaining,
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
def set_auto_approval_policy(self, *, enabled, cost_cap_usd):
|
| 45 |
+
self.auto_approval_enabled = enabled
|
| 46 |
+
self.auto_approval_cost_cap_usd = cost_cap_usd
|
| 47 |
|
| 48 |
|
| 49 |
class RestoreStore(NoopSessionStore):
|
|
|
|
| 102 |
)
|
| 103 |
|
| 104 |
|
| 105 |
+
@pytest.mark.asyncio
|
| 106 |
+
async def test_update_session_auto_approval_defaults_to_five_dollars():
|
| 107 |
+
manager = _manager_with_store(NoopSessionStore())
|
| 108 |
+
existing = _runtime_agent_session("s1", user_id="owner")
|
| 109 |
+
manager.sessions["s1"] = existing
|
| 110 |
+
|
| 111 |
+
summary = await manager.update_session_auto_approval(
|
| 112 |
+
"s1",
|
| 113 |
+
enabled=True,
|
| 114 |
+
cost_cap_usd=None,
|
| 115 |
+
cap_provided=False,
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
assert summary["enabled"] is True
|
| 119 |
+
assert summary["cost_cap_usd"] == 5.0
|
| 120 |
+
assert summary["remaining_usd"] == 5.0
|
| 121 |
+
|
| 122 |
+
|
| 123 |
def _install_fake_runtime(manager: SessionManager) -> asyncio.Event:
|
| 124 |
stop = asyncio.Event()
|
| 125 |
manager.run_calls = 0 # type: ignore[attr-defined]
|
|
|
|
| 239 |
await _cancel_runtime_tasks(manager)
|
| 240 |
|
| 241 |
|
| 242 |
+
@pytest.mark.asyncio
|
| 243 |
+
async def test_lazy_restore_preserves_auto_approval_policy():
|
| 244 |
+
store = RestoreStore(
|
| 245 |
+
metadata={
|
| 246 |
+
"session_id": "yolo-session",
|
| 247 |
+
"user_id": "owner",
|
| 248 |
+
"model": "test-model",
|
| 249 |
+
"auto_approval_enabled": True,
|
| 250 |
+
"auto_approval_cost_cap_usd": 5.0,
|
| 251 |
+
"auto_approval_estimated_spend_usd": 1.25,
|
| 252 |
+
}
|
| 253 |
+
)
|
| 254 |
+
manager = _manager_with_store(store)
|
| 255 |
+
stop = _install_fake_runtime(manager)
|
| 256 |
+
|
| 257 |
+
try:
|
| 258 |
+
restored = await manager.ensure_session_loaded("yolo-session", user_id="owner")
|
| 259 |
+
|
| 260 |
+
assert restored is not None
|
| 261 |
+
assert restored.session.auto_approval_enabled is True
|
| 262 |
+
assert restored.session.auto_approval_cost_cap_usd == 5.0
|
| 263 |
+
assert restored.session.auto_approval_estimated_spend_usd == 1.25
|
| 264 |
+
assert restored.session.auto_approval_policy_summary()["remaining_usd"] == 3.75
|
| 265 |
+
finally:
|
| 266 |
+
stop.set()
|
| 267 |
+
await _cancel_runtime_tasks(manager)
|
| 268 |
+
|
| 269 |
+
|
| 270 |
@pytest.mark.asyncio
|
| 271 |
async def test_list_sessions_dev_uses_store_dev_visibility():
|
| 272 |
class ListStore(NoopSessionStore):
|
|
|
|
| 284 |
"user_id": "alice",
|
| 285 |
"model": "m",
|
| 286 |
"created_at": datetime.now(UTC),
|
| 287 |
+
"auto_approval_enabled": True,
|
| 288 |
+
"auto_approval_cost_cap_usd": 5.0,
|
| 289 |
+
"auto_approval_estimated_spend_usd": 2.0,
|
| 290 |
},
|
| 291 |
{
|
| 292 |
"session_id": "s2",
|
|
|
|
| 304 |
|
| 305 |
assert store.seen_user_id == "dev"
|
| 306 |
assert {session["session_id"] for session in sessions} == {"s1", "s2"}
|
| 307 |
+
yolo = next(session for session in sessions if session["session_id"] == "s1")
|
| 308 |
+
assert yolo["auto_approval"] == {
|
| 309 |
+
"enabled": True,
|
| 310 |
+
"cost_cap_usd": 5.0,
|
| 311 |
+
"estimated_spend_usd": 2.0,
|
| 312 |
+
"remaining_usd": 3.0,
|
| 313 |
+
}
|