diff --git "a/agent/core/agent_loop.py" "b/agent/core/agent_loop.py" new file mode 100644--- /dev/null +++ "b/agent/core/agent_loop.py" @@ -0,0 +1,2645 @@ +"""loop +Main agent implementation with integrated tool system and MCP support +""" + +import asyncio +import json +import logging +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +from litellm import ( + ChatCompletionMessageToolCall, + Message, + acompletion, +) +from litellm.exceptions import ContextWindowExceededError + +from agent.config import Config +from agent.core.approval_policy import ( + is_scheduled_operation, + normalize_tool_operation, +) +from agent.core.cost_estimation import CostEstimate, estimate_tool_cost +from agent.messaging.gateway import NotificationGateway +from agent.core import telemetry +from agent.core.doom_loop import check_for_doom_loop +from agent.core.hf_access import ( + HF_BILLING_URL, + HF_PRO_SUBSCRIBE_URL, + is_inference_billing_error, +) +from agent.core.llm_params import _resolve_llm_params +from agent.core.prompt_caching import ( + router_session_id_for, + with_prompt_cache_params, + with_prompt_caching, +) +from agent.core.session import DEFAULT_SESSION_LOG_DIR, Event, OpType, Session +from agent.core.tools import ToolRouter +from agent.core.usage_thresholds import ( + USAGE_THRESHOLD_TOOL_NAME, + is_usage_threshold_pending, + next_usage_warning_threshold, +) +from agent.core.yolo_budget import ( + BudgetDecision, + check_session_budget, + is_yolo_budget_pending, + maybe_pause_yolo_after_spend, + release_budget_reservation, + reserve_session_budget, + yolo_budget_can_resume, + yolo_budget_pending_to_tool, +) +from agent.tools.jobs_tool import CPU_FLAVORS +from agent.tools.sandbox_tool import ( + DEFAULT_CPU_SANDBOX_HARDWARE, + start_cpu_sandbox_preload, + teardown_session_sandbox, +) + +logger = logging.getLogger(__name__) + +ToolCall = ChatCompletionMessageToolCall + +_MALFORMED_TOOL_PREFIX = "ERROR: Tool call to '" +_MALFORMED_TOOL_SUFFIX = "' had malformed JSON arguments" +_NO_TOOL_INCOMPLETE_PLAN_RETRY_LIMIT = 2 + + +def _unfinished_plan_items(session: Session) -> list[dict[str, str]]: + plan = getattr(session, "current_plan", None) or [] + unfinished: list[dict[str, str]] = [] + for item in plan: + if not isinstance(item, dict): + continue + status = item.get("status") + if status in {"pending", "in_progress"}: + unfinished.append(item) + return unfinished + + +def _format_plan_items_for_guard(items: list[dict[str, str]], limit: int = 4) -> str: + formatted = [] + for item in items[:limit]: + item_id = item.get("id") or "?" + content = item.get("content") or "(unnamed task)" + status = item.get("status") or "unknown" + formatted.append(f"{item_id}. {content} [{status}]") + if len(items) > limit: + formatted.append(f"... and {len(items) - limit} more") + return "; ".join(formatted) + + +def _no_tool_incomplete_plan_prompt(items: list[dict[str, str]]) -> str: + summary = _format_plan_items_for_guard(items) + return ( + "[SYSTEM: CONTINUATION GUARD] Your previous response ended without any " + "tool calls, but the task is not complete. The current plan still has " + f"unfinished items: {summary}. Do not return control to the user yet. " + "Continue from the next unfinished item and make at least one tool call " + "now. If you genuinely cannot continue, first use tools to inspect the " + "state or verify the blocker." + ) + + +def _malformed_tool_name(message: Message) -> str | None: + """Return the tool name for malformed-json tool-result messages.""" + if getattr(message, "role", None) != "tool": + return None + content = getattr(message, "content", None) + if not isinstance(content, str): + return None + if not content.startswith(_MALFORMED_TOOL_PREFIX): + return None + end = content.find(_MALFORMED_TOOL_SUFFIX, len(_MALFORMED_TOOL_PREFIX)) + if end == -1: + return None + return content[len(_MALFORMED_TOOL_PREFIX) : end] + + +def _detect_repeated_malformed( + items: list[Message], + threshold: int = 2, +) -> str | None: + """Return the repeated malformed tool name if the tail contains a streak. + + Walk backward over the current conversation tail. A streak counts only + consecutive malformed tool-result messages for the same tool; any other + tool result breaks it. + """ + if threshold <= 0: + return None + + streak_tool: str | None = None + streak = 0 + + for item in reversed(items): + if getattr(item, "role", None) != "tool": + continue + + malformed_tool = _malformed_tool_name(item) + if malformed_tool is None: + break + + if streak_tool is None: + streak_tool = malformed_tool + streak = 1 + elif malformed_tool == streak_tool: + streak += 1 + else: + break + + if streak >= threshold: + return streak_tool + + return None + + +def _coerce_float(value: Any) -> float: + if isinstance(value, bool) or value is None: + return 0.0 + try: + return float(value) + except (TypeError, ValueError): + return 0.0 + + +def _usage_output_message(pending: dict[str, Any]) -> str: + current = _coerce_float(pending.get("current_spend_usd")) + next_threshold = _coerce_float(pending.get("next_threshold_usd")) + return ( + f"Current-session usage warning acknowledged at ${current:.2f}. " + f"The next warning is at ${next_threshold:.2f}." + ) + + +async def _maybe_pause_for_usage_threshold( + session: Session, + *, + continuation: str, + final_response: str | None = None, +) -> bool: + checker = getattr(session, "usage_threshold_checker", None) + if checker is None or session.pending_approval: + return False + payload: dict[str, Any] = { + "continuation": continuation, + "force_check": continuation == "complete_turn", + "history_size": len(session.context_manager.items), + } + if final_response is not None: + payload["final_response"] = final_response + try: + return bool(await checker(payload)) + except Exception as e: + logger.debug("Usage threshold check failed: %s", e) + return False + + +def _validate_tool_args(tool_args: dict) -> tuple[bool, str | None]: + """ + Validate tool arguments structure. + + Returns: + (is_valid, error_message) + """ + args = tool_args.get("args", {}) + # Sometimes LLM passes args as string instead of dict + if isinstance(args, str): + return ( + False, + f"Tool call error: 'args' must be a JSON object, not a string. You passed: {repr(args)}", + ) + if not isinstance(args, dict) and args is not None: + return ( + False, + f"Tool call error: 'args' must be a JSON object. You passed type: {type(args).__name__}", + ) + return True, None + + +_IMMEDIATE_HF_JOB_RUNS = {"run", "uv"} + + +@dataclass(frozen=True) +class ApprovalDecision: + requires_approval: bool + auto_approved: bool = False + auto_approval_blocked: bool = False + block_reason: str | None = None + estimated_cost_usd: float | None = None + remaining_cap_usd: float | None = None + billable: bool = False + + +def _operation(tool_args: dict) -> str: + return normalize_tool_operation(tool_args.get("operation")) + + +def _is_immediate_hf_job_run(tool_name: str, tool_args: dict) -> bool: + return tool_name == "hf_jobs" and _operation(tool_args) in _IMMEDIATE_HF_JOB_RUNS + + +def _is_scheduled_hf_job_run(tool_name: str, tool_args: dict) -> bool: + return tool_name == "hf_jobs" and is_scheduled_operation(_operation(tool_args)) + + +def _is_budgeted_auto_approval_target(tool_name: str, tool_args: dict) -> bool: + return tool_name == "sandbox_create" or _is_immediate_hf_job_run( + tool_name, tool_args + ) + + +def _base_needs_approval( + tool_name: str, tool_args: dict, config: Config | None = None +) -> bool: + """Check if a tool call requires approval before YOLO policy is applied.""" + + # If args are malformed, skip approval (validation error will be shown later) + args_valid, _ = _validate_tool_args(tool_args) + if not args_valid: + return False + + if tool_name == "sandbox_create": + hardware = tool_args.get("hardware") or DEFAULT_CPU_SANDBOX_HARDWARE + return hardware != DEFAULT_CPU_SANDBOX_HARDWARE + + if tool_name == "hf_jobs": + operation = _operation(tool_args) + if is_scheduled_operation(operation): + return True + if operation not in _IMMEDIATE_HF_JOB_RUNS: + return False + + # Check if this is a CPU-only job + # hardware_flavor is at top level of tool_args, not nested in args + hardware_flavor = ( + tool_args.get("hardware_flavor") + or tool_args.get("flavor") + or tool_args.get("hardware") + or "cpu-basic" + ) + is_cpu_job = hardware_flavor in CPU_FLAVORS + + if is_cpu_job: + if config and not config.confirm_cpu_jobs: + return False + return True + + return True + + # Check for file upload operations (hf_private_repos or other tools) + if tool_name == "hf_private_repos": + operation = tool_args.get("operation", "") + if operation == "upload_file": + if config and config.auto_file_upload: + return False + return True + # Other operations (create_repo, etc.) always require approval + if operation in ["create_repo"]: + return True + + # hf_repo_files: upload (can overwrite) and delete require approval + if tool_name == "hf_repo_files": + operation = tool_args.get("operation", "") + if operation in ["upload", "delete"]: + return True + + # hf_repo_git: destructive operations require approval + if tool_name == "hf_repo_git": + operation = tool_args.get("operation", "") + if operation in [ + "delete_branch", + "delete_tag", + "merge_pr", + "create_repo", + "update_repo", + ]: + return True + + return False + + +def _session_auto_approval_enabled(session: Session | None) -> bool: + return bool(session and getattr(session, "auto_approval_enabled", False)) + + +def _effective_yolo_enabled(session: Session | None, config: Config | None) -> bool: + return bool( + (config and config.yolo_mode) or _session_auto_approval_enabled(session) + ) + + +async def _approval_decision( + tool_name: str, + tool_args: dict, + session: Session, + *, + reserved_spend_usd: float = 0.0, +) -> ApprovalDecision: + """Return the approval decision for one parsed tool call.""" + config = session.config + base_requires_approval = _base_needs_approval(tool_name, tool_args, config) + + # Scheduled jobs are recurring/unbounded enough that YOLO never bypasses + # the human confirmation, including legacy config.yolo_mode. + if _is_scheduled_hf_job_run(tool_name, tool_args): + reason = "Scheduled HF jobs always require manual approval." + if _session_auto_approval_enabled(session): + reason = "Scheduled HF jobs require disabling YOLO because their recurring cost is unbounded." + return ApprovalDecision( + requires_approval=True, + auto_approval_blocked=_effective_yolo_enabled(session, config), + block_reason=reason, + ) + + yolo_enabled = _effective_yolo_enabled(session, config) + budgeted_target = _is_budgeted_auto_approval_target(tool_name, tool_args) + + # Cost caps are a session-scoped web policy. Legacy config.yolo_mode + # remains uncapped for CLI/headless, except for scheduled jobs above. + session_yolo_enabled = _session_auto_approval_enabled(session) + if yolo_enabled and budgeted_target and session_yolo_enabled: + estimate = await estimate_tool_cost(tool_name, tool_args, session=session) + budget = check_session_budget( + session, + estimate, + reserved_spend_usd=reserved_spend_usd, + ) + if not budget.allowed: + return ApprovalDecision( + requires_approval=True, + auto_approval_blocked=True, + block_reason=budget.block_reason, + estimated_cost_usd=budget.estimated_cost_usd, + remaining_cap_usd=budget.remaining_cap_usd, + billable=estimate.billable, + ) + if base_requires_approval: + return ApprovalDecision( + requires_approval=False, + auto_approved=True, + estimated_cost_usd=budget.estimated_cost_usd, + remaining_cap_usd=budget.remaining_cap_usd, + billable=estimate.billable, + ) + return ApprovalDecision( + requires_approval=False, + estimated_cost_usd=budget.estimated_cost_usd, + remaining_cap_usd=budget.remaining_cap_usd, + billable=estimate.billable, + ) + + if base_requires_approval and yolo_enabled: + return ApprovalDecision(requires_approval=False, auto_approved=True) + + return ApprovalDecision(requires_approval=base_requires_approval) + + +def _record_estimated_spend( + session: Session, + decision: ApprovalDecision, + *, + reservation_id: str | None = None, +) -> BudgetDecision: + if not decision.billable or decision.estimated_cost_usd is None: + return BudgetDecision(allowed=True, billable=False) + return reserve_session_budget( + session, + CostEstimate( + estimated_cost_usd=decision.estimated_cost_usd, + billable=True, + ), + spend_kind="tool", + reservation_id=reservation_id, + ) + + +async def _record_manual_approved_spend_if_needed( + session: Session, + tool_name: str, + tool_args: dict, + *, + tool_call_id: str | None = None, +) -> BudgetDecision: + if not _session_auto_approval_enabled(session): + return BudgetDecision(allowed=True) + if _is_scheduled_hf_job_run(tool_name, tool_args): + return BudgetDecision( + allowed=False, + billable=True, + block_reason=( + "Scheduled HF jobs require disabling YOLO because their recurring " + "cost is unbounded." + ), + ) + if not _is_budgeted_auto_approval_target(tool_name, tool_args): + return BudgetDecision(allowed=True) + estimate = await estimate_tool_cost(tool_name, tool_args, session=session) + return reserve_session_budget( + session, + estimate, + spend_kind=tool_name, + reservation_id=tool_call_id, + ) + + +async def _check_manual_approved_budget( + session: Session, + tool_name: str, + tool_args: dict, + *, + reserved_spend_usd: float = 0.0, +) -> BudgetDecision: + if not _session_auto_approval_enabled(session): + return BudgetDecision(allowed=True) + if _is_scheduled_hf_job_run(tool_name, tool_args): + return BudgetDecision( + allowed=False, + billable=True, + block_reason=( + "Scheduled HF jobs require disabling YOLO because their recurring " + "cost is unbounded." + ), + ) + if not _is_budgeted_auto_approval_target(tool_name, tool_args): + return BudgetDecision(allowed=True) + estimate = await estimate_tool_cost(tool_name, tool_args, session=session) + return check_session_budget( + session, + estimate, + reserved_spend_usd=reserved_spend_usd, + ) + + +# -- LLM retry constants -------------------------------------------------- +_MAX_LLM_RETRIES = 3 +_LLM_RETRY_DELAYS = [5, 15, 30] # seconds between retries +_LLM_RATE_LIMIT_RETRY_DELAYS = [30, 60] + + +def _is_rate_limit_error(error: Exception) -> bool: + """Return True for rate-limit / quota-bucket style provider errors.""" + err_str = str(error).lower() + rate_limit_patterns = [ + "429", + "rate limit", + "rate_limit", + "too many requests", + "too many tokens", + "request limit", + "throttl", + ] + return any(pattern in err_str for pattern in rate_limit_patterns) + + +def _is_context_overflow_error(error: Exception) -> bool: + """Return True when the prompt exceeded the model's context window.""" + if isinstance(error, ContextWindowExceededError): + return True + + err_str = str(error).lower() + overflow_patterns = [ + "context window exceeded", + "maximum context length", + "max context length", + "prompt is too long", + "context length exceeded", + "too many input tokens", + "input is too long", + ] + return any(pattern in err_str for pattern in overflow_patterns) + + +def _retry_delay_for(error: Exception, attempt_index: int) -> int | None: + """Return the delay for this retry attempt, or None if it should not retry.""" + if _is_rate_limit_error(error): + schedule = _LLM_RATE_LIMIT_RETRY_DELAYS + elif _is_transient_error(error): + schedule = _LLM_RETRY_DELAYS + else: + return None + + if attempt_index >= len(schedule): + return None + return schedule[attempt_index] + + +def _is_transient_error(error: Exception) -> bool: + """Return True for errors that are likely transient and worth retrying.""" + err_str = str(error).lower() + transient_patterns = [ + "timeout", + "timed out", + "503", + "service unavailable", + "502", + "bad gateway", + "500", + "internal server error", + "overloaded", + "capacity", + "connection reset", + "connection refused", + "connection error", + "eof", + "broken pipe", + ] + return _is_rate_limit_error(error) or any( + pattern in err_str for pattern in transient_patterns + ) + + +def _is_effort_config_error(error: Exception) -> bool: + """Catch the two 400s the effort probe also handles — thinking + unsupported for this model, or the specific effort level invalid. + + This is our safety net for the case where ``/effort`` was changed + mid-conversation (which clears the probe cache) and the new level + doesn't work for the current model. We heal the cache and retry once. + """ + from agent.core.effort_probe import _is_invalid_effort, _is_thinking_unsupported + + return _is_thinking_unsupported(error) or _is_invalid_effort(error) + + +async def _heal_effort_and_rebuild_params( + session: Session, + error: Exception, + llm_params: dict, +) -> dict: + """Update the session's effort cache based on ``error`` and return new + llm_params. Called only when ``_is_effort_config_error(error)`` is True. + + Two branches: + • thinking-unsupported → cache ``None`` for this model, next call + strips thinking entirely + • invalid-effort → re-run the full cascade probe; the result lands + in the cache + """ + from agent.core.effort_probe import ( + ProbeInconclusive, + _is_thinking_unsupported, + probe_effort, + ) + + model = session.config.model_name + if _is_thinking_unsupported(error): + session.model_effective_effort[model] = None + logger.info("healed: %s doesn't support thinking — stripped", model) + else: + try: + outcome = await probe_effort( + model, + session.config.reasoning_effort, + session.hf_token, + session=session, + ) + session.model_effective_effort[model] = outcome.effective_effort + logger.info( + "healed: %s effort cascade → %s", + model, + outcome.effective_effort, + ) + except ProbeInconclusive: + # Transient during healing — strip thinking for safety, next + # call will either succeed or surface the real error. + session.model_effective_effort[model] = None + logger.info("healed: %s probe inconclusive — stripped", model) + + return _resolve_llm_params( + model, + session.hf_token, + reasoning_effort=session.effective_effort_for(model), + ) + + +def _inference_credit_error_message(user_plan: str | None = None) -> str: + plan = (user_plan or "unknown").lower() + if plan == "pro": + return ( + "Hugging Face Inference Providers credits are exhausted for this " + "account.\n\n" + f"Add credits to continue: {HF_BILLING_URL}" + ) + if plan == "free": + return ( + "Your monthly Hugging Face Inference Providers credits are exhausted.\n\n" + f"Subscribe to HF PRO for more monthly usage: {HF_PRO_SUBSCRIBE_URL}\n" + f"Or add pay-as-you-go credits: {HF_BILLING_URL}" + ) + return ( + "Hugging Face Inference Providers credits appear to be exhausted for " + "this account.\n\n" + f"Add pay-as-you-go credits: {HF_BILLING_URL}\n" + f"If this is a free account, HF PRO adds more monthly usage: {HF_PRO_SUBSCRIBE_URL}" + ) + + +def _friendly_error_message( + error: Exception, + *, + user_plan: str | None = None, +) -> str | None: + """Return a user-friendly message for known error types, or None to fall back to traceback.""" + err_str = str(error).lower() + + if ( + "authentication" in err_str + or "unauthorized" in err_str + or "invalid x-api-key" in err_str + ): + return ( + "Authentication failed - your Hugging Face token is missing or invalid.\n\n" + "To fix this, set HF_TOKEN=hf_... or run `hf auth login`.\n\n" + "You can also add it to a .env file in the project root.\n" + "To switch models, use the /model command." + ) + + if is_inference_billing_error(error): + return _inference_credit_error_message(user_plan) + + if "not supported by provider" in err_str or "no provider supports" in err_str: + return ( + "The model isn't served by the provider you pinned.\n\n" + "Drop the ':' suffix to let the HF router auto-pick a " + "provider, or use '/model' (no arg) to see which providers host " + "which models." + ) + + if "model_not_found" in err_str or ( + "model" in err_str and ("not found" in err_str or "does not exist" in err_str) + ): + return ( + "Model not found. Use '/model' to list suggestions, or paste an " + "HF model id like 'MiniMaxAI/MiniMax-M2.7'. Availability is shown " + "when you switch." + ) + + return None + + +async def _compact_and_notify(session: Session) -> None: + """Run compaction and send event if context was reduced. + + Catches ``CompactionFailedError`` and ends the session cleanly instead + of letting the caller retry. Pre-2026-05-04 the caller looped on + ContextWindowExceededError → compact → re-trigger, burning hosted + inference budget while the session never reached the upload path. + """ + from agent.context_manager.manager import CompactionFailedError + + cm = session.context_manager + old_usage = cm.running_context_usage + logger.debug( + "Compaction check: usage=%d, max=%d, threshold=%d, needs_compact=%s", + old_usage, + cm.model_max_tokens, + cm.compaction_threshold, + cm.needs_compaction, + ) + try: + await cm.compact( + model_name=session.config.model_name, + tool_specs=session.tool_router.get_tool_specs_for_llm(), + hf_token=session.hf_token, + session=session, + ) + except CompactionFailedError as e: + logger.error( + "Compaction failed for session %s: %s — terminating session", + session.session_id, + e, + ) + # Persist the failure event so the dataset has a record of WHY this + # session ended (and the cost it incurred up to that point) even if + # save_and_upload_detached has issues downstream. + await session.send_event( + Event( + event_type="session_terminated", + data={ + "reason": "compaction_failed", + "context_usage": cm.running_context_usage, + "context_threshold": cm.compaction_threshold, + "error": str(e)[:300], + "user_message": ( + "Your conversation has grown too large to continue. " + "The work you've done is saved — start a new session to keep going." + ), + }, + ) + ) + # Stop the agent loop; the finally in _run_session will fire + # cleanup_sandbox + save_trajectory so the dataset captures + # everything that did happen. + session.is_running = False + return + + new_usage = cm.running_context_usage + if new_usage != old_usage: + logger.warning( + "Context compacted: %d -> %d tokens (max=%d, %d messages)", + old_usage, + new_usage, + cm.model_max_tokens, + len(cm.items), + ) + await session.send_event( + Event( + event_type="compacted", + data={"old_tokens": old_usage, "new_tokens": new_usage}, + ) + ) + + +async def _cleanup_on_cancel(session: Session) -> None: + """Kill sandbox processes and cancel HF jobs when the user interrupts.""" + # Kill active sandbox processes + sandbox = getattr(session, "sandbox", None) + if sandbox: + try: + await asyncio.to_thread(sandbox.kill_all) + logger.info("Killed sandbox processes on cancel") + except Exception as e: + logger.warning("Failed to kill sandbox processes: %s", e) + + # Cancel running HF jobs + job_ids = list(session._running_job_ids) + if job_ids: + from huggingface_hub import HfApi + + api = HfApi(token=session.hf_token) + for job_id in job_ids: + try: + await asyncio.to_thread(api.cancel_job, job_id=job_id) + logger.info("Cancelled HF job %s on interrupt", job_id) + except Exception as e: + logger.warning("Failed to cancel HF job %s: %s", job_id, e) + session._running_job_ids.clear() + + +@dataclass +class LLMResult: + """Result from an LLM call (streaming or non-streaming).""" + + content: str | None + tool_calls_acc: dict[int, dict] + token_count: int + finish_reason: str | None + usage: dict = field(default_factory=dict) + + +def _is_invalid_thinking_signature_error(exc: Exception) -> bool: + """Return True when a provider rejected replayed thinking metadata.""" + text = str(exc) + return ( + "Invalid `signature` in `thinking` block" in text + or "Invalid signature in thinking block" in text + ) + + +def _strip_thinking_state_from_messages(messages: list[Any]) -> int: + """Remove replayed thinking metadata from assistant history messages.""" + stripped = 0 + + for message in messages: + role = ( + message.get("role") + if isinstance(message, dict) + else getattr(message, "role", None) + ) + if role != "assistant": + continue + + if isinstance(message, dict): + if message.pop("thinking_blocks", None) is not None: + stripped += 1 + if message.pop("reasoning_content", None) is not None: + stripped += 1 + provider_fields = message.get("provider_specific_fields") + content = message.get("content") + else: + if getattr(message, "thinking_blocks", None) is not None: + message.thinking_blocks = None + stripped += 1 + if getattr(message, "reasoning_content", None) is not None: + message.reasoning_content = None + stripped += 1 + provider_fields = getattr(message, "provider_specific_fields", None) + content = getattr(message, "content", None) + + if isinstance(provider_fields, dict): + cleaned_fields = dict(provider_fields) + if cleaned_fields.pop("thinking_blocks", None) is not None: + stripped += 1 + if cleaned_fields.pop("reasoning_content", None) is not None: + stripped += 1 + if cleaned_fields != provider_fields: + if isinstance(message, dict): + message["provider_specific_fields"] = cleaned_fields + else: + message.provider_specific_fields = cleaned_fields + + if isinstance(content, list): + cleaned_content = [ + block + for block in content + if not ( + isinstance(block, dict) + and block.get("type") in {"thinking", "redacted_thinking"} + ) + ] + if len(cleaned_content) != len(content): + stripped += len(content) - len(cleaned_content) + if isinstance(message, dict): + message["content"] = cleaned_content + else: + message.content = cleaned_content + + return stripped + + +async def _maybe_heal_invalid_thinking_signature( + session: Session, + messages: list[Any], + exc: Exception, + *, + already_healed: bool, +) -> bool: + if already_healed or not _is_invalid_thinking_signature_error(exc): + return False + + stripped = _strip_thinking_state_from_messages(messages) + if not stripped: + return False + + await session.send_event( + Event( + event_type="tool_log", + data={ + "tool": "system", + "log": ( + "The inference provider rejected stale thinking signatures; retrying " + "without replayed thinking metadata." + ), + }, + ) + ) + return True + + +def _assistant_message_from_result( + llm_result: LLMResult, + *, + tool_calls: list[ToolCall] | None = None, +) -> Message: + """Build an assistant history message for HF Router-compatible replay.""" + kwargs: dict[str, Any] = { + "role": "assistant", + "content": llm_result.content, + } + if tool_calls is not None: + kwargs["tool_calls"] = tool_calls + return Message(**kwargs) + + +async def _call_llm_streaming( + session: Session, messages, tools, llm_params +) -> LLMResult: + """Call the LLM with streaming, emitting assistant_chunk events.""" + response = None + _healed_effort = False # one-shot safety net per call + _healed_thinking_signature = False + t_start = time.monotonic() + for _llm_attempt in range(_MAX_LLM_RETRIES): + try: + request_llm_params = with_prompt_cache_params( + llm_params, + session_id=router_session_id_for(session), + ) + cached_messages, cached_tools = with_prompt_caching( + messages, tools, request_llm_params + ) + response = await acompletion( + messages=cached_messages, + tools=cached_tools, + tool_choice="auto", + stream=True, + stream_options={"include_usage": True}, + timeout=600, + **request_llm_params, + ) + break + except ContextWindowExceededError: + raise + except Exception as e: + if _is_context_overflow_error(e): + raise ContextWindowExceededError(str(e)) from e + if not _healed_effort and _is_effort_config_error(e): + _healed_effort = True + llm_params = await _heal_effort_and_rebuild_params( + session, e, llm_params + ) + await session.send_event( + Event( + event_type="tool_log", + data={ + "tool": "system", + "log": "Reasoning effort not supported for this model — adjusting and retrying.", + }, + ) + ) + continue + if await _maybe_heal_invalid_thinking_signature( + session, + messages, + e, + already_healed=_healed_thinking_signature, + ): + _healed_thinking_signature = True + continue + _delay = _retry_delay_for(e, _llm_attempt) + if _llm_attempt < _MAX_LLM_RETRIES - 1 and _delay is not None: + logger.warning( + "Transient LLM error (attempt %d/%d): %s — retrying in %ds", + _llm_attempt + 1, + _MAX_LLM_RETRIES, + e, + _delay, + ) + await session.send_event( + Event( + event_type="tool_log", + data={ + "tool": "system", + "log": f"LLM connection error, retrying in {_delay}s...", + }, + ) + ) + await asyncio.sleep(_delay) + continue + raise + + full_content = "" + tool_calls_acc: dict[int, dict] = {} + token_count = 0 + finish_reason = None + final_usage_chunk = None + + async for chunk in response: + if session.is_cancelled: + tool_calls_acc.clear() + break + + choice = chunk.choices[0] if chunk.choices else None + if not choice: + if hasattr(chunk, "usage") and chunk.usage: + token_count = chunk.usage.total_tokens + final_usage_chunk = chunk + continue + + delta = choice.delta + if choice.finish_reason: + finish_reason = choice.finish_reason + + if delta.content: + full_content += delta.content + await session.send_event( + Event(event_type="assistant_chunk", data={"content": delta.content}) + ) + + if delta.tool_calls: + for tc_delta in delta.tool_calls: + idx = tc_delta.index + if idx not in tool_calls_acc: + tool_calls_acc[idx] = { + "id": "", + "type": "function", + "function": {"name": "", "arguments": ""}, + } + if tc_delta.id: + tool_calls_acc[idx]["id"] = tc_delta.id + if tc_delta.function: + if tc_delta.function.name: + tool_calls_acc[idx]["function"]["name"] += ( + tc_delta.function.name + ) + if tc_delta.function.arguments: + tool_calls_acc[idx]["function"]["arguments"] += ( + tc_delta.function.arguments + ) + + if hasattr(chunk, "usage") and chunk.usage: + token_count = chunk.usage.total_tokens + final_usage_chunk = chunk + + usage = await telemetry.record_llm_call( + session, + model=llm_params.get("model", session.config.model_name), + response=final_usage_chunk, + latency_ms=int((time.monotonic() - t_start) * 1000), + finish_reason=finish_reason, + ) + return LLMResult( + content=full_content or None, + tool_calls_acc=tool_calls_acc, + token_count=token_count, + finish_reason=finish_reason, + usage=usage, + ) + + +async def _call_llm_non_streaming( + session: Session, messages, tools, llm_params +) -> LLMResult: + """Call the LLM without streaming, emit assistant_message at the end.""" + response = None + _healed_effort = False + _healed_thinking_signature = False + t_start = time.monotonic() + for _llm_attempt in range(_MAX_LLM_RETRIES): + try: + request_llm_params = with_prompt_cache_params( + llm_params, + session_id=router_session_id_for(session), + ) + cached_messages, cached_tools = with_prompt_caching( + messages, tools, request_llm_params + ) + response = await acompletion( + messages=cached_messages, + tools=cached_tools, + tool_choice="auto", + stream=False, + timeout=600, + **request_llm_params, + ) + break + except ContextWindowExceededError: + raise + except Exception as e: + if _is_context_overflow_error(e): + raise ContextWindowExceededError(str(e)) from e + if not _healed_effort and _is_effort_config_error(e): + _healed_effort = True + llm_params = await _heal_effort_and_rebuild_params( + session, e, llm_params + ) + await session.send_event( + Event( + event_type="tool_log", + data={ + "tool": "system", + "log": "Reasoning effort not supported for this model — adjusting and retrying.", + }, + ) + ) + continue + if await _maybe_heal_invalid_thinking_signature( + session, + messages, + e, + already_healed=_healed_thinking_signature, + ): + _healed_thinking_signature = True + continue + _delay = _retry_delay_for(e, _llm_attempt) + if _llm_attempt < _MAX_LLM_RETRIES - 1 and _delay is not None: + logger.warning( + "Transient LLM error (attempt %d/%d): %s — retrying in %ds", + _llm_attempt + 1, + _MAX_LLM_RETRIES, + e, + _delay, + ) + await session.send_event( + Event( + event_type="tool_log", + data={ + "tool": "system", + "log": f"LLM connection error, retrying in {_delay}s...", + }, + ) + ) + await asyncio.sleep(_delay) + continue + raise + + choice = response.choices[0] + message = choice.message + content = message.content or None + finish_reason = choice.finish_reason + token_count = response.usage.total_tokens if response.usage else 0 + + # Build tool_calls_acc in the same format as streaming + tool_calls_acc: dict[int, dict] = {} + if message.tool_calls: + for idx, tc in enumerate(message.tool_calls): + tool_calls_acc[idx] = { + "id": tc.id, + "type": "function", + "function": { + "name": tc.function.name, + "arguments": tc.function.arguments, + }, + } + + # Emit the full message as a single event + if content: + await session.send_event( + Event(event_type="assistant_message", data={"content": content}) + ) + + usage = await telemetry.record_llm_call( + session, + model=llm_params.get("model", session.config.model_name), + response=response, + latency_ms=int((time.monotonic() - t_start) * 1000), + finish_reason=finish_reason, + ) + + return LLMResult( + content=content, + tool_calls_acc=tool_calls_acc, + token_count=token_count, + finish_reason=finish_reason, + usage=usage, + ) + + +class Handlers: + """Handler functions for each operation type""" + + @staticmethod + async def _abandon_pending_approval(session: Session) -> None: + """Cancel pending approval tools when the user continues the conversation. + + Injects rejection tool-result messages into the LLM context (so the + history stays valid) and notifies the frontend that those tools were + abandoned. + """ + if is_usage_threshold_pending( + session.pending_approval + ) or is_yolo_budget_pending(session.pending_approval): + pending = session.pending_approval + tool_call_id = str(pending.get("tool_call_id") or "") + tool_name = str(pending.get("kind") or USAGE_THRESHOLD_TOOL_NAME) + session.pending_approval = None + if tool_call_id: + await session.send_event( + Event( + event_type="tool_state_change", + data={ + "tool_call_id": tool_call_id, + "tool": tool_name, + "state": "abandoned", + }, + ) + ) + if pending.get("continuation") == "complete_turn": + final_response = pending.get("final_response") + await session.send_event( + Event( + event_type="turn_complete", + data={ + "history_size": int( + pending.get("history_size") + or len(session.context_manager.items) + ), + "final_response": final_response + if isinstance(final_response, str) + else None, + }, + ) + ) + session.increment_turn() + await session.auto_save_if_needed() + return + + tool_calls = session.pending_approval.get("tool_calls", []) + for tc in tool_calls: + tool_name = tc.function.name + abandon_msg = ( + "Task abandoned — user continued the conversation without approving." + ) + + # Keep LLM context valid: every tool_call needs a tool result + tool_msg = Message( + role="tool", + content=abandon_msg, + tool_call_id=tc.id, + name=tool_name, + ) + session.context_manager.add_message(tool_msg) + + await session.send_event( + Event( + event_type="tool_state_change", + data={ + "tool_call_id": tc.id, + "tool": tool_name, + "state": "abandoned", + }, + ) + ) + + session.pending_approval = None + logger.info("Abandoned %d pending approval tool(s)", len(tool_calls)) + + @staticmethod + async def run_agent( + session: Session, + text: str, + ) -> str | None: + """ + Handle user input (like user_input_or_turn in codex.rs:1291) + Returns the final assistant response content, if any. + """ + # Clear any stale cancellation flag from a previous run + session.reset_cancel() + + # If there's a pending approval and the user sent a new message, + # abandon the pending tools so the LLM context stays valid. + if text and session.pending_approval: + await Handlers._abandon_pending_approval(session) + + # Add user message to history only if there's actual content + if text: + user_msg = Message(role="user", content=text) + session.context_manager.add_message(user_msg) + + # Send event that we're processing + await session.send_event( + Event(event_type="processing", data={"message": "Processing user input"}) + ) + + # Agentic loop - continue until model doesn't call tools or max iterations is reached + iteration = 0 + final_response = None + errored = False + max_iterations = session.config.max_iterations + no_tool_incomplete_plan_retries = 0 + + while max_iterations == -1 or iteration < max_iterations: + # ── Cancellation check: before LLM call ── + if session.is_cancelled: + break + if session.pending_approval: + return final_response + + # Compact before calling the LLM if context is near the limit. + # When _compact_and_notify catches CompactionFailedError it sets + # session.is_running = False; we MUST exit the loop here, otherwise + # the LLM call below fires with an over-threshold context, hits + # ContextWindowExceededError, and we end up looping again on the + # except path — exactly the bug this PR is supposed to fix. + await _compact_and_notify(session) + if not session.is_running: + break + if session.pending_approval: + return final_response + + if await _maybe_pause_for_usage_threshold( + session, + continuation="continue_agent", + ): + return final_response + + # Doom-loop detection: break out of repeated tool call patterns + doom_prompt = check_for_doom_loop(session.context_manager.items) + if doom_prompt: + session.context_manager.add_message( + Message(role="user", content=doom_prompt) + ) + + malformed_tool = _detect_repeated_malformed(session.context_manager.items) + if malformed_tool: + recovery_prompt = ( + "[SYSTEM: Repeated malformed tool arguments detected for " + f"'{malformed_tool}'. Stop retrying the same tool call shape. " + "Use a different strategy that produces smaller, valid JSON. " + "For large file writes, prefer bash with a heredoc or split the " + "edit into multiple smaller tool calls.]" + ) + session.context_manager.add_message( + Message(role="user", content=recovery_prompt) + ) + await session.send_event( + Event( + event_type="tool_log", + data={ + "tool": "system", + "log": ( + "Repeated malformed tool arguments detected — " + f"forcing a different strategy for {malformed_tool}" + ), + }, + ) + ) + + messages = session.context_manager.get_messages() + tools = session.tool_router.get_tool_specs_for_llm() + try: + # ── Call the LLM (streaming or non-streaming) ── + # Pull the per-model probed effort from the session cache when + # available; fall back to the raw preference for models we + # haven't probed yet (e.g. research sub-model). + llm_params = _resolve_llm_params( + session.config.model_name, + session.hf_token, + reasoning_effort=session.effective_effort_for( + session.config.model_name + ), + ) + if session.stream: + llm_result = await _call_llm_streaming( + session, messages, tools, llm_params + ) + else: + llm_result = await _call_llm_non_streaming( + session, messages, tools, llm_params + ) + llm_observed_cost_usd = llm_result.usage.get("cost_usd") + + content = llm_result.content + tool_calls_acc = llm_result.tool_calls_acc + token_count = llm_result.token_count + finish_reason = llm_result.finish_reason + + # If output was truncated, all tool call args are garbage. + # Inject a system hint so the LLM retries with smaller content. + if finish_reason == "length" and tool_calls_acc: + dropped_names = [ + tc["function"]["name"] + for tc in tool_calls_acc.values() + if tc["function"]["name"] + ] + logger.warning( + "Output truncated (finish_reason=length) — dropping tool calls: %s", + dropped_names, + ) + tool_calls_acc.clear() + + # Tell the agent what happened so it can retry differently + truncation_hint = ( + "Your previous response was truncated because the output hit the " + "token limit. The following tool calls were lost: " + f"{dropped_names}. " + "IMPORTANT: Do NOT retry with the same large content. Instead:\n" + " • For 'write': use bash with cat<<'HEREDOC' to write the file, " + "or split into several smaller edit calls.\n" + " • For other tools: reduce the size of your arguments or use bash." + ) + if content: + assistant_msg = _assistant_message_from_result(llm_result) + session.context_manager.add_message(assistant_msg, token_count) + session.context_manager.add_message( + Message(role="user", content=f"[SYSTEM: {truncation_hint}]") + ) + if session.stream: + await session.send_event( + Event(event_type="assistant_stream_end", data={}) + ) + await session.send_event( + Event( + event_type="tool_log", + data={ + "tool": "system", + "log": f"Output truncated — retrying with smaller content ({dropped_names})", + }, + ) + ) + iteration += 1 + continue # retry this iteration + + # Build tool_calls list from accumulated deltas + tool_calls: list[ToolCall] = [] + for idx in sorted(tool_calls_acc.keys()): + tc_data = tool_calls_acc[idx] + tool_calls.append( + ToolCall( + id=tc_data["id"], + type="function", + function={ + "name": tc_data["function"]["name"], + "arguments": tc_data["function"]["arguments"], + }, + ) + ) + + # Signal end of streaming to the frontend + if session.stream: + await session.send_event( + Event(event_type="assistant_stream_end", data={}) + ) + + # If no tool calls, add assistant message and we're done + if not tool_calls: + unfinished_plan = _unfinished_plan_items(session) + if ( + unfinished_plan + and no_tool_incomplete_plan_retries + < _NO_TOOL_INCOMPLETE_PLAN_RETRY_LIMIT + ): + if await maybe_pause_yolo_after_spend( + session, + spend_kind="llm_call", + observed_cost_usd=llm_observed_cost_usd, + ): + return final_response + logger.info( + "No tool calls with unfinished plan; retrying agent turn " + "(attempt %d/%d)", + no_tool_incomplete_plan_retries + 1, + _NO_TOOL_INCOMPLETE_PLAN_RETRY_LIMIT, + ) + if content: + assistant_msg = _assistant_message_from_result(llm_result) + session.context_manager.add_message( + assistant_msg, token_count + ) + session.context_manager.add_message( + Message( + role="user", + content=_no_tool_incomplete_plan_prompt( + unfinished_plan + ), + ) + ) + no_tool_incomplete_plan_retries += 1 + await session.send_event( + Event( + event_type="tool_log", + data={ + "tool": "system", + "log": ( + "Plan still has unfinished items after a " + "text-only response — retrying instead of " + "returning to the prompt." + ), + }, + ) + ) + iteration += 1 + continue + + logger.debug( + "Agent loop ending: no tool calls. " + "finish_reason=%s, token_count=%d, " + "usage=%d, model_max_tokens=%d, " + "iteration=%d/%d, " + "response_text=%s", + finish_reason, + token_count, + session.context_manager.running_context_usage, + session.context_manager.model_max_tokens, + iteration, + max_iterations, + (content or "")[:500], + ) + if content: + assistant_msg = _assistant_message_from_result(llm_result) + session.context_manager.add_message(assistant_msg, token_count) + final_response = content + if await maybe_pause_yolo_after_spend( + session, + spend_kind="llm_call", + observed_cost_usd=llm_observed_cost_usd, + continuation="complete_turn", + final_response=final_response + if isinstance(final_response, str) + else None, + ): + return final_response + break + + no_tool_incomplete_plan_retries = 0 + + if await maybe_pause_yolo_after_spend( + session, + spend_kind="llm_call", + observed_cost_usd=llm_observed_cost_usd, + ): + return final_response + + # Validate tool call args (one json.loads per call, once) + # and split into good vs bad + good_tools: list[tuple[ToolCall, str, dict]] = [] + bad_tools: list[ToolCall] = [] + for tc in tool_calls: + try: + args = json.loads(tc.function.arguments) + good_tools.append((tc, tc.function.name, args)) + except (json.JSONDecodeError, TypeError, ValueError): + logger.warning( + "Malformed arguments for tool_call %s (%s) — skipping", + tc.id, + tc.function.name, + ) + tc.function.arguments = "{}" + bad_tools.append(tc) + + # Add assistant message with all tool calls to context + assistant_msg = _assistant_message_from_result( + llm_result, + tool_calls=tool_calls, + ) + session.context_manager.add_message(assistant_msg, token_count) + + # Add error results for bad tool calls so the LLM + # knows what happened and can retry differently + for tc in bad_tools: + error_msg = ( + f"ERROR: Tool call to '{tc.function.name}' had malformed JSON " + f"arguments and was NOT executed. Retry with smaller content — " + f"for 'write', split into multiple smaller writes using 'edit'." + ) + session.context_manager.add_message( + Message( + role="tool", + content=error_msg, + tool_call_id=tc.id, + name=tc.function.name, + ) + ) + await session.send_event( + Event( + event_type="tool_call", + data={ + "tool": tc.function.name, + "arguments": {}, + "tool_call_id": tc.id, + }, + ) + ) + await session.send_event( + Event( + event_type="tool_output", + data={ + "tool": tc.function.name, + "tool_call_id": tc.id, + "output": error_msg, + "success": False, + }, + ) + ) + + # ── Cancellation check: before tool execution ── + if session.is_cancelled: + break + + # Separate good tools into approval-required vs auto-execute. + # Track reserved spend while classifying a batch so two + # auto-approved jobs in one model response cannot jointly + # exceed the remaining session cap. + approval_required_tools: list[ + tuple[ToolCall, str, dict, ApprovalDecision] + ] = [] + non_approval_tools: list[ + tuple[ToolCall, str, dict, ApprovalDecision] + ] = [] + reserved_auto_spend_usd = 0.0 + for tc, tool_name, tool_args in good_tools: + decision = await _approval_decision( + tool_name, + tool_args, + session, + reserved_spend_usd=reserved_auto_spend_usd, + ) + if decision.requires_approval: + approval_required_tools.append( + (tc, tool_name, tool_args, decision) + ) + else: + non_approval_tools.append((tc, tool_name, tool_args, decision)) + if ( + decision.auto_approved + and decision.billable + and decision.estimated_cost_usd is not None + ): + reserved_auto_spend_usd += decision.estimated_cost_usd + + # Execute non-approval tools (in parallel when possible) + if non_approval_tools: + # 1. Validate args upfront + parsed_tools: list[ + tuple[ToolCall, str, dict, ApprovalDecision, bool, str] + ] = [] + for tc, tool_name, tool_args, decision in non_approval_tools: + args_valid, error_msg = _validate_tool_args(tool_args) + parsed_tools.append( + (tc, tool_name, tool_args, decision, args_valid, error_msg) + ) + + # 2. Send all tool_call events upfront (so frontend shows them all) + for ( + tc, + tool_name, + tool_args, + _decision, + args_valid, + _, + ) in parsed_tools: + if args_valid: + await session.send_event( + Event( + event_type="tool_call", + data={ + "tool": tool_name, + "arguments": tool_args, + "tool_call_id": tc.id, + }, + ) + ) + + # 3. Execute all valid tools in parallel, cancellable + async def _exec_tool( + tc: ToolCall, + name: str, + args: dict, + decision: ApprovalDecision, + valid: bool, + err: str, + ) -> tuple[ToolCall, str, dict, str, bool]: + if not valid: + return (tc, name, args, err, False) + if decision.billable: + budget = _record_estimated_spend( + session, + decision, + reservation_id=tc.id, + ) + if not budget.allowed: + return ( + tc, + name, + args, + budget.block_reason + or "YOLO budget blocked this tool call.", + False, + ) + out, ok = await session.tool_router.call_tool( + name, args, session=session, tool_call_id=tc.id + ) + if not ok and decision.billable: + release_budget_reservation(session, tc.id) + return (tc, name, args, out, ok) + + gather_task = asyncio.ensure_future( + asyncio.gather( + *[ + _exec_tool(tc, name, args, decision, valid, err) + for tc, name, args, decision, valid, err in parsed_tools + ] + ) + ) + cancel_task = asyncio.ensure_future(session._cancelled.wait()) + + done, _ = await asyncio.wait( + [gather_task, cancel_task], + return_when=asyncio.FIRST_COMPLETED, + ) + + if cancel_task in done: + gather_task.cancel() + try: + await gather_task + except asyncio.CancelledError: + pass + # Notify frontend that in-flight tools were cancelled + for tc, name, _args, _decision, valid, _ in parsed_tools: + if valid: + await session.send_event( + Event( + event_type="tool_state_change", + data={ + "tool_call_id": tc.id, + "tool": name, + "state": "cancelled", + }, + ) + ) + await _cleanup_on_cancel(session) + break + + cancel_task.cancel() + results = gather_task.result() + + # 4. Record results and send outputs (order preserved) + for tc, tool_name, tool_args, output, success in results: + tool_msg = Message( + role="tool", + content=output, + tool_call_id=tc.id, + name=tool_name, + ) + session.context_manager.add_message(tool_msg) + + await session.send_event( + Event( + event_type="tool_output", + data={ + "tool": tool_name, + "tool_call_id": tc.id, + "output": output, + "success": success, + }, + ) + ) + + # If there are tools requiring approval, ask for batch approval + if approval_required_tools: + # Prepare batch approval data + tools_data = [] + blocked_payloads = [] + for tc, tool_name, tool_args, decision in approval_required_tools: + # Resolve sandbox file paths for hf_jobs scripts so the + # frontend can display & edit the actual file content. + if tool_name == "hf_jobs" and isinstance( + tool_args.get("script"), str + ): + from agent.tools.sandbox_tool import resolve_sandbox_script + + sandbox = getattr(session, "sandbox", None) + resolved, _ = await resolve_sandbox_script( + sandbox, tool_args["script"] + ) + if resolved: + tool_args = {**tool_args, "script": resolved} + + tool_payload = { + "tool": tool_name, + "arguments": tool_args, + "tool_call_id": tc.id, + } + if decision.auto_approval_blocked: + tool_payload.update( + { + "auto_approval_blocked": True, + "block_reason": decision.block_reason, + "estimated_cost_usd": decision.estimated_cost_usd, + "remaining_cap_usd": decision.remaining_cap_usd, + } + ) + blocked_payloads.append(tool_payload) + tools_data.append(tool_payload) + + event_data = {"tools": tools_data, "count": len(tools_data)} + if blocked_payloads: + first = blocked_payloads[0] + event_data.update( + { + "auto_approval_blocked": True, + "block_reason": first.get("block_reason"), + "estimated_cost_usd": first.get("estimated_cost_usd"), + "remaining_cap_usd": first.get("remaining_cap_usd"), + } + ) + await session.send_event( + Event( + event_type="approval_required", + data=event_data, + ) + ) + + # Store all approval-requiring tools (ToolCall objects for execution) + session.pending_approval = { + "tool_calls": [tc for tc, _, _, _ in approval_required_tools], + } + + # Return early - wait for EXEC_APPROVAL operation + return None + + iteration += 1 + + except ContextWindowExceededError: + # Force compact and retry this iteration. + cm = session.context_manager + logger.warning( + "ContextWindowExceededError at iteration %d — forcing compaction " + "(usage=%d, model_max_tokens=%d, messages=%d)", + iteration, + cm.running_context_usage, + cm.model_max_tokens, + len(cm.items), + ) + cm.running_context_usage = cm.model_max_tokens + 1 + await _compact_and_notify(session) + # Same guard as the top of the loop: if compaction couldn't + # bring us under threshold, _compact_and_notify has already + # emitted session_terminated and set is_running=False. Continue + # would just re-call the LLM with the same too-big context. + if not session.is_running: + break + continue + + except Exception as e: + import traceback + + error_msg = _friendly_error_message( + e, + user_plan=getattr(session, "user_plan", None), + ) + if error_msg is None: + error_msg = str(e) + "\n" + traceback.format_exc() + + await session.send_event( + Event( + event_type="error", + data={"error": error_msg}, + ) + ) + errored = True + break + + if session.is_cancelled: + await _cleanup_on_cancel(session) + await session.send_event(Event(event_type="interrupted")) + elif not errored: + if await _maybe_pause_for_usage_threshold( + session, + continuation="complete_turn", + final_response=final_response + if isinstance(final_response, str) + else None, + ): + return final_response + await session.send_event( + Event( + event_type="turn_complete", + data={ + "history_size": len(session.context_manager.items), + "final_response": final_response + if isinstance(final_response, str) + else None, + }, + ) + ) + + # Increment turn counter and check for auto-save + session.increment_turn() + await session.auto_save_if_needed() + + return final_response + + @staticmethod + async def undo(session: Session) -> None: + """Remove the last complete turn and notify the frontend.""" + removed = session.context_manager.undo_last_turn() + if not removed: + logger.warning("Undo: no user message found to remove") + await session.send_event(Event(event_type="undo_complete")) + + @staticmethod + async def new_conversation(session: Session, *, clear_screen: bool = False) -> None: + """Start a fresh conversation inside the active runtime.""" + try: + result = session.start_new_conversation() + except Exception as e: + await session.send_event( + Event(event_type="error", data={"error": f"New chat failed: {e}"}) + ) + return + result["clear_screen"] = clear_screen + await session.send_event(Event(event_type="new_complete", data=result)) + + @staticmethod + async def resume(session: Session, path: str) -> None: + """Reload context from a saved session log into the active session.""" + from agent.core.session_resume import restore_session_from_log + + try: + result = restore_session_from_log(session, Path(path)) + except Exception as e: + await session.send_event( + Event(event_type="error", data={"error": f"Resume failed: {e}"}) + ) + return + await session.send_event(Event(event_type="resume_complete", data=result)) + + @staticmethod + async def _exec_usage_threshold_approval( + session: Session, approvals: list[dict] + ) -> None: + pending = ( + session.pending_approval + if isinstance(session.pending_approval, dict) + else {} + ) + tool_call_id = str(pending.get("tool_call_id") or "") + approval = next( + (item for item in approvals if item.get("tool_call_id") == tool_call_id), + {"approved": False}, + ) + approved = bool(approval.get("approved")) + + session.pending_approval = None + if not tool_call_id: + await session.send_event( + Event( + event_type="error", + data={"error": "Usage approval is missing its approval id"}, + ) + ) + return + + if not approved: + feedback = str(approval.get("feedback") or "Stopped by user").strip() + await session.send_event( + Event( + event_type="tool_state_change", + data={ + "tool_call_id": tool_call_id, + "tool": USAGE_THRESHOLD_TOOL_NAME, + "state": "rejected", + }, + ) + ) + await session.send_event( + Event( + event_type="tool_output", + data={ + "tool": USAGE_THRESHOLD_TOOL_NAME, + "tool_call_id": tool_call_id, + "output": feedback, + "success": False, + }, + ) + ) + await session.send_event(Event(event_type="interrupted")) + session.increment_turn() + await session.auto_save_if_needed() + return + + current_spend = _coerce_float(pending.get("current_spend_usd")) + acknowledged_threshold = _coerce_float(pending.get("threshold_usd")) + next_threshold = next_usage_warning_threshold( + current_spend, + acknowledged_threshold, + ) + session.usage_warning_next_threshold_usd = next_threshold + pending["next_threshold_usd"] = next_threshold + + await session.send_event( + Event( + event_type="tool_state_change", + data={ + "tool_call_id": tool_call_id, + "tool": USAGE_THRESHOLD_TOOL_NAME, + "state": "approved", + }, + ) + ) + await session.send_event( + Event( + event_type="tool_output", + data={ + "tool": USAGE_THRESHOLD_TOOL_NAME, + "tool_call_id": tool_call_id, + "output": _usage_output_message(pending), + "success": True, + }, + ) + ) + + if pending.get("continuation") == "complete_turn": + final_response = pending.get("final_response") + await session.send_event( + Event( + event_type="turn_complete", + data={ + "history_size": int( + pending.get("history_size") + or len(session.context_manager.items) + ), + "final_response": final_response + if isinstance(final_response, str) + else None, + }, + ) + ) + session.increment_turn() + await session.auto_save_if_needed() + return + + await Handlers.run_agent(session, "") + + @staticmethod + async def _exec_yolo_budget_approval( + session: Session, approvals: list[dict] + ) -> None: + pending = ( + session.pending_approval + if isinstance(session.pending_approval, dict) + else {} + ) + tool_call_id = str(pending.get("tool_call_id") or "") + approval = next( + (item for item in approvals if item.get("tool_call_id") == tool_call_id), + {"approved": False}, + ) + approved = bool(approval.get("approved")) + + if not tool_call_id: + session.pending_approval = None + await session.send_event( + Event( + event_type="error", + data={"error": "YOLO budget approval is missing its approval id"}, + ) + ) + return + + if not approved: + session.pending_approval = None + feedback = str(approval.get("feedback") or "Stopped by user").strip() + await session.send_event( + Event( + event_type="tool_state_change", + data={ + "tool_call_id": tool_call_id, + "tool": "yolo_budget", + "state": "rejected", + }, + ) + ) + await session.send_event( + Event( + event_type="tool_output", + data={ + "tool": "yolo_budget", + "tool_call_id": tool_call_id, + "output": feedback, + "success": False, + }, + ) + ) + await session.send_event(Event(event_type="interrupted")) + session.increment_turn() + await session.auto_save_if_needed() + return + + can_resume, reason = yolo_budget_can_resume(session, pending) + if not can_resume: + pending["reason"] = reason + pending["current_spend_usd"] = round( + float( + getattr(session, "auto_approval_estimated_spend_usd", 0.0) or 0.0 + ), + 6, + ) + pending["remaining_cap_usd"] = ( + None + if getattr(session, "auto_approval_cost_cap_usd", None) is None + else session.auto_approval_remaining_usd + ) + tool = yolo_budget_pending_to_tool(pending) + await session.send_event( + Event( + event_type="approval_required", + data={ + "tools": [tool], + "count": 1, + "yolo_budget": True, + "auto_approval_blocked": True, + "block_reason": reason, + "estimated_cost_usd": pending.get("estimated_next_usd"), + "remaining_cap_usd": pending.get("remaining_cap_usd"), + }, + ) + ) + return + + session.pending_approval = None + await session.send_event( + Event( + event_type="tool_state_change", + data={ + "tool_call_id": tool_call_id, + "tool": "yolo_budget", + "state": "approved", + }, + ) + ) + await session.send_event( + Event( + event_type="tool_output", + data={ + "tool": "yolo_budget", + "tool_call_id": tool_call_id, + "output": "YOLO budget check acknowledged.", + "success": True, + }, + ) + ) + + if pending.get("continuation") == "complete_turn": + final_response = pending.get("final_response") + await session.send_event( + Event( + event_type="turn_complete", + data={ + "history_size": int( + pending.get("history_size") + or len(session.context_manager.items) + ), + "final_response": final_response + if isinstance(final_response, str) + else None, + }, + ) + ) + session.increment_turn() + await session.auto_save_if_needed() + return + + await Handlers.run_agent(session, "") + + @staticmethod + async def exec_approval(session: Session, approvals: list[dict]) -> None: + """Handle batch job execution approval""" + if not session.pending_approval: + await session.send_event( + Event( + event_type="error", + data={"error": "No pending approval to process"}, + ) + ) + return + + if is_usage_threshold_pending(session.pending_approval): + await Handlers._exec_usage_threshold_approval(session, approvals) + return + if is_yolo_budget_pending(session.pending_approval): + await Handlers._exec_yolo_budget_approval(session, approvals) + return + + tool_calls = session.pending_approval.get("tool_calls", []) + if not tool_calls: + await session.send_event( + Event( + event_type="error", + data={"error": "No pending tool calls found"}, + ) + ) + return + + # Create a map of tool_call_id -> approval decision + approval_map = {a["tool_call_id"]: a for a in approvals} + for a in approvals: + if a.get("edited_script"): + logger.info( + f"Received edited script for tool_call {a['tool_call_id']} ({len(a['edited_script'])} chars)" + ) + + # Separate approved and rejected tool calls + approved_tasks = [] + rejected_tasks = [] + + for tc in tool_calls: + tool_name = tc.function.name + try: + tool_args = json.loads(tc.function.arguments) + except (json.JSONDecodeError, TypeError) as e: + # Malformed arguments — treat as failed, notify agent + logger.warning(f"Malformed tool arguments for {tool_name}: {e}") + tool_msg = Message( + role="tool", + content=f"Malformed arguments: {e}", + tool_call_id=tc.id, + name=tool_name, + ) + session.context_manager.add_message(tool_msg) + await session.send_event( + Event( + event_type="tool_output", + data={ + "tool": tool_name, + "tool_call_id": tc.id, + "output": f"Malformed arguments: {e}", + "success": False, + }, + ) + ) + continue + + approval_decision = approval_map.get(tc.id, {"approved": False}) + + if approval_decision.get("approved", False): + edited_script = approval_decision.get("edited_script") + was_edited = False + if edited_script and "script" in tool_args: + tool_args["script"] = edited_script + was_edited = True + logger.info(f"Using user-edited script for {tool_name} ({tc.id})") + selected_namespace = approval_decision.get("namespace") + if selected_namespace and tool_name == "hf_jobs": + tool_args["namespace"] = selected_namespace + approved_tasks.append((tc, tool_name, tool_args, was_edited)) + else: + rejected_tasks.append((tc, tool_name, approval_decision)) + + reserved_manual_spend_usd = 0.0 + blocked_manual_budget: tuple[ToolCall, str, BudgetDecision] | None = None + for tc, tool_name, tool_args, _was_edited in approved_tasks: + budget = await _check_manual_approved_budget( + session, + tool_name, + tool_args, + reserved_spend_usd=reserved_manual_spend_usd, + ) + if not budget.allowed: + blocked_manual_budget = (tc, tool_name, budget) + break + if budget.billable and budget.estimated_cost_usd is not None: + reserved_manual_spend_usd += budget.estimated_cost_usd + + if blocked_manual_budget is not None: + blocked_tc, _blocked_tool, blocked_budget = blocked_manual_budget + tools_data = [] + for tc in tool_calls: + try: + args = json.loads(tc.function.arguments) + except (json.JSONDecodeError, AttributeError, TypeError): + args = {} + payload = { + "tool": getattr(tc.function, "name", None), + "arguments": args, + "tool_call_id": tc.id, + } + if tc.id == blocked_tc.id: + payload.update( + { + "auto_approval_blocked": True, + "block_reason": blocked_budget.block_reason, + "estimated_cost_usd": blocked_budget.estimated_cost_usd, + "remaining_cap_usd": blocked_budget.remaining_cap_usd, + } + ) + tools_data.append(payload) + await session.send_event( + Event( + event_type="approval_required", + data={ + "tools": tools_data, + "count": len(tools_data), + "auto_approval_blocked": True, + "block_reason": blocked_budget.block_reason, + "estimated_cost_usd": blocked_budget.estimated_cost_usd, + "remaining_cap_usd": blocked_budget.remaining_cap_usd, + }, + ) + ) + return + + # Clear pending approval immediately so a page refresh during + # execution won't re-show the approval dialog. + session.pending_approval = None + + # Notify frontend of approval decisions immediately (before execution) + for tc, tool_name, tool_args, _was_edited in approved_tasks: + await session.send_event( + Event( + event_type="tool_state_change", + data={ + "tool_call_id": tc.id, + "tool": tool_name, + "state": "approved", + }, + ) + ) + for tc, tool_name, approval_decision in rejected_tasks: + await session.send_event( + Event( + event_type="tool_state_change", + data={ + "tool_call_id": tc.id, + "tool": tool_name, + "state": "rejected", + }, + ) + ) + + # Execute all approved tools concurrently + async def execute_tool(tc, tool_name, tool_args, was_edited): + """Execute a single tool and return its result. + + The TraceLog already exists on the frontend (created by + approval_required), so we send tool_state_change instead of + tool_call to avoid creating a duplicate. + """ + await session.send_event( + Event( + event_type="tool_state_change", + data={ + "tool_call_id": tc.id, + "tool": tool_name, + "state": "running", + }, + ) + ) + + budget = await _record_manual_approved_spend_if_needed( + session, + tool_name, + tool_args, + tool_call_id=tc.id, + ) + if not budget.allowed: + return ( + tc, + tool_name, + budget.block_reason or "YOLO budget blocked this tool call.", + False, + was_edited, + ) + + output, success = await session.tool_router.call_tool( + tool_name, tool_args, session=session, tool_call_id=tc.id + ) + if not success and budget.reservation: + release_budget_reservation(session, budget.reservation.reservation_id) + + return (tc, tool_name, output, success, was_edited) + + # Execute all approved tools concurrently (cancellable) + if approved_tasks: + gather_task = asyncio.ensure_future( + asyncio.gather( + *[ + execute_tool(tc, tool_name, tool_args, was_edited) + for tc, tool_name, tool_args, was_edited in approved_tasks + ], + return_exceptions=True, + ) + ) + cancel_task = asyncio.ensure_future(session._cancelled.wait()) + + done, _ = await asyncio.wait( + [gather_task, cancel_task], + return_when=asyncio.FIRST_COMPLETED, + ) + + if cancel_task in done: + gather_task.cancel() + try: + await gather_task + except asyncio.CancelledError: + pass + # Notify frontend that approved tools were cancelled + for tc, tool_name, _args, _was_edited in approved_tasks: + await session.send_event( + Event( + event_type="tool_state_change", + data={ + "tool_call_id": tc.id, + "tool": tool_name, + "state": "cancelled", + }, + ) + ) + await _cleanup_on_cancel(session) + await session.send_event(Event(event_type="interrupted")) + session.increment_turn() + await session.auto_save_if_needed() + return + + cancel_task.cancel() + results = gather_task.result() + + # Process results and add to context + for result in results: + if isinstance(result, Exception): + # Handle execution error + logger.error(f"Tool execution error: {result}") + continue + + tc, tool_name, output, success, was_edited = result + + if was_edited: + output = f"[Note: The user edited the script before execution. The output below reflects the user-modified version, not your original script.]\n\n{output}" + + # Add tool result to context + tool_msg = Message( + role="tool", + content=output, + tool_call_id=tc.id, + name=tool_name, + ) + session.context_manager.add_message(tool_msg) + + await session.send_event( + Event( + event_type="tool_output", + data={ + "tool": tool_name, + "tool_call_id": tc.id, + "output": output, + "success": success, + }, + ) + ) + + # Process rejected tools + for tc, tool_name, approval_decision in rejected_tasks: + rejection_msg = "Job execution cancelled by user" + user_feedback = approval_decision.get("feedback") + if user_feedback: + # Ensure feedback is a string and sanitize any problematic characters + feedback_str = str(user_feedback).strip() + # Remove any control characters that might break JSON parsing + feedback_str = "".join( + char for char in feedback_str if ord(char) >= 32 or char in "\n\t" + ) + rejection_msg += f". User feedback: {feedback_str}" + + # Ensure rejection_msg is a clean string + rejection_msg = str(rejection_msg).strip() + + tool_msg = Message( + role="tool", + content=rejection_msg, + tool_call_id=tc.id, + name=tool_name, + ) + session.context_manager.add_message(tool_msg) + + await session.send_event( + Event( + event_type="tool_output", + data={ + "tool": tool_name, + "tool_call_id": tc.id, + "output": rejection_msg, + "success": False, + }, + ) + ) + + # Continue agent loop with empty input to process the tool results + await Handlers.run_agent(session, "") + + @staticmethod + async def shutdown(session: Session) -> bool: + """Handle shutdown (like shutdown in codex.rs:1329)""" + # Save session trajectory if enabled (fire-and-forget, returns immediately) + if session.config.save_sessions: + logger.info("Saving session...") + repo_id = session.config.session_dataset_repo + _ = session.save_and_upload_detached(repo_id) + + session.is_running = False + if not getattr(session, "local_mode", False): + await teardown_session_sandbox(session) + await session.send_event(Event(event_type="shutdown")) + return True + + +async def process_submission(session: Session, submission) -> bool: + """ + Process a single submission and return whether to continue running. + + Returns: + bool: True to continue, False to shutdown + """ + op = submission.operation + logger.debug("Received operation: %s", op.op_type.value) + + if op.op_type == OpType.USER_INPUT: + text = op.data.get("text", "") if op.data else "" + await Handlers.run_agent(session, text) + return True + + if op.op_type == OpType.COMPACT: + await _compact_and_notify(session) + return True + + if op.op_type == OpType.UNDO: + await Handlers.undo(session) + return True + + if op.op_type == OpType.NEW: + clear_screen = bool((op.data or {}).get("clear_screen")) + await Handlers.new_conversation(session, clear_screen=clear_screen) + return True + + if op.op_type == OpType.RESUME: + path = op.data.get("path") if op.data else None + if path: + await Handlers.resume(session, path) + else: + await session.send_event( + Event(event_type="error", data={"error": "Resume requires a path"}) + ) + return True + + if op.op_type == OpType.EXEC_APPROVAL: + approvals = op.data.get("approvals", []) if op.data else [] + await Handlers.exec_approval(session, approvals) + return True + + if op.op_type == OpType.SHUTDOWN: + return not await Handlers.shutdown(session) + + logger.warning(f"Unknown operation: {op.op_type}") + return True + + +async def submission_loop( + submission_queue: asyncio.Queue, + event_queue: asyncio.Queue, + config: Config, + tool_router: ToolRouter | None = None, + session_holder: list | None = None, + hf_token: str | None = None, + user_id: str | None = None, + local_mode: bool = False, + stream: bool = True, + notification_gateway: NotificationGateway | None = None, + notification_destinations: list[str] | None = None, + defer_turn_complete_notification: bool = False, + user_plan: str | None = None, +) -> None: + """ + Main agent loop - processes submissions and dispatches to handlers. + This is the core of the agent (like submission_loop in codex.rs:1259-1340) + """ + + # Create session with tool router + session = Session( + event_queue, + config=config, + tool_router=tool_router, + hf_token=hf_token, + user_id=user_id, + user_plan=user_plan, + local_mode=local_mode, + stream=stream, + notification_gateway=notification_gateway, + notification_destinations=notification_destinations, + defer_turn_complete_notification=defer_turn_complete_notification, + ) + if session_holder is not None: + session_holder[0] = session + if not local_mode: + start_cpu_sandbox_preload(session) + logger.info("Agent loop started") + + # Retry any failed uploads from previous sessions (fire-and-forget). + # Includes the personal trace repo when enabled so a session that failed + # to publish to the user's HF dataset gets a fresh attempt on next run. + if config and config.save_sessions: + Session.retry_failed_uploads_detached( + directory=str(DEFAULT_SESSION_LOG_DIR), + repo_id=config.session_dataset_repo, + personal_repo_id=session._personal_trace_repo_id(), + ) + + try: + # Main processing loop + async with tool_router: + # Emit ready event after initialization + await session.send_event( + Event( + event_type="ready", + data={ + "message": "Agent initialized", + "tool_count": len(tool_router.tools), + }, + ) + ) + + while session.is_running: + submission = await submission_queue.get() + + try: + should_continue = await process_submission(session, submission) + if not should_continue: + break + except asyncio.CancelledError: + logger.warning("Agent loop cancelled") + break + except Exception as e: + logger.error(f"Error in agent loop: {e}") + await session.send_event( + Event(event_type="error", data={"error": str(e)}) + ) + + logger.info("Agent loop exited") + + finally: + # Emergency save if session saving is enabled and shutdown wasn't called properly + if session.config.save_sessions and session.is_running: + logger.info("Emergency save: preserving session before exit...") + try: + local_path = session.save_and_upload_detached( + session.config.session_dataset_repo + ) + if local_path: + logger.info("Emergency save successful, upload in progress") + except Exception as e: + logger.error(f"Emergency save failed: {e}")