"""Agent loop: stream provider completions and run `shell_exec` tool calls for a session."""

from __future__ import annotations

import asyncio
import json
import time
import uuid
from typing import Any, AsyncIterator

from .config import settings
from .providers import CompletionAccumulator, ProviderError, provider_from_overrides
from .security import redact_text
from .sessions import sessions
from .shell import shell_exec_stream

# System message prepended to the history on every provider request.
SYSTEM_PROMPT = (
    "You are an autonomous shell-based AI agent. "
    "You are working inside the root directory of the current session. "
    "The visible `/` is your session root. "
    "Uploaded files are available in `/uploads`. "
    "Put files you want to return to the owner in `/outputs`. "
    "Use the `shell_exec` tool to inspect files, create files, modify files, "
    "install dependencies, execute commands, run tests, debug problems, "
    "and complete the owner’s requests. "
    "Choose the appropriate language, tools, commands, libraries, and workflow yourself. "
    "You have broad freedom in this environment. "
    "Do not invent command results. "
    "After using tools, give the owner a clear final summary "
    "and mention any output files you created."
)

# Upper bound on provider round-trips within a single run_agent() call.
MAX_AGENT_ITERATIONS = 24

# Per-session cancel flag of the run currently in progress (removed when the run ends).
ACTIVE_CANCEL_EVENTS: dict[str, asyncio.Event] = {}
# Per-session lock that serialises runs; entries are created lazily and never removed.
ACTIVE_LOCKS: dict[str, asyncio.Lock] = {}


def session_is_active(session_id: str) -> bool:
    """Return True while a run is registered for the session and no stop was requested."""
    event = ACTIVE_CANCEL_EVENTS.get(session_id)
    return event is not None and not event.is_set()


def request_generation_stop(session_id: str) -> bool:
    """Set the session's cancel flag.

    Returns False when no run is registered for the session, True otherwise
    (after setting the flag and logging ``generation_stop_requested``).
    """
    event = ACTIVE_CANCEL_EVENTS.get(session_id)
    if event is None:
        return False
    event.set()
    sessions.log(session_id, "generation_stop_requested", {})
    return True


def _lock_for(session_id: str) -> asyncio.Lock:
    """Return the session's lock, creating and caching it on first use."""
    lock = ACTIVE_LOCKS.get(session_id)
    if lock is None:
        lock = asyncio.Lock()
        ACTIVE_LOCKS[session_id] = lock
    return lock


def _assistant_message_from_acc(acc: CompletionAccumulator) -> dict[str, Any]:
    """Build the assistant history entry from an accumulated completion.

    ``reasoning`` and ``cancelled`` are only included when set; ``tool_calls``
    is dropped for cancelled completions.
    """
    message: dict[str, Any] = {"role": "assistant", "content": acc.content or ""}
    if acc.reasoning:
        message["reasoning"] = acc.reasoning
    if acc.cancelled:
        message["cancelled"] = True
    calls = acc.tool_calls
    if calls and not acc.cancelled:
        message["tool_calls"] = calls
    return message


def _parse_tool_arguments(call: dict[str, Any]) -> tuple[str | None, int | None, str | None]:
    """Validate a tool call and extract the ``shell_exec`` arguments.

    Returns ``(command, timeout, error)``. On success ``error`` is None; on
    failure ``command`` and ``timeout`` are None and ``error`` describes the
    problem. An unparsable ``timeout_seconds`` is treated as absent rather
    than as an error.
    """
    fn = call.get("function") or {}
    if fn.get("name") != "shell_exec":
        return None, None, f"Unsupported tool requested: {fn.get('name')}. The only available tool is shell_exec."

    raw_args = fn.get("arguments") or "{}"
    if isinstance(raw_args, dict):
        # Arguments were already decoded upstream; nothing to parse.
        args: Any = raw_args
    else:
        try:
            args = json.loads(raw_args)
        except (ValueError, TypeError) as exc:
            # ValueError covers JSONDecodeError; TypeError covers non-string payloads.
            return None, None, f"Malformed tool arguments: {exc}. Raw arguments: {str(raw_args)[:1000]}"

    if not isinstance(args, dict):
        # Valid JSON that is not an object (list, string, number) has no "command" key.
        return (
            None,
            None,
            f"Malformed tool arguments: expected a JSON object. Raw arguments: {str(raw_args)[:1000]}",
        )

    command = args.get("command")
    if not isinstance(command, str) or not command.strip():
        return None, None, "shell_exec requires a non-empty string command."

    timeout = args.get("timeout_seconds")
    if timeout is not None:
        try:
            timeout = int(timeout)
        except (TypeError, ValueError, OverflowError):
            timeout = None
    return command, timeout, None


async def _drain_queued_messages(session_id: str, history: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Move queued user messages into ``history`` and persist it.

    Blank messages are not appended to the history, but the returned list
    contains every drained item, including those.
    """
    queued = sessions.drain_queued_messages(session_id)
    for item in queued:
        text = str(item.get("content") or "").strip()
        if not text:
            continue
        history.append({"role": "user", "content": text, "queued_id": item.get("id")})
        sessions.log(session_id, "queued_message_consumed", {"id": item.get("id"), "text": text})
    if queued:
        sessions.save_history(session_id, history)
    return queued


async def _yield_drain_queued(session_id: str, history: list[dict[str, Any]]) -> AsyncIterator[dict[str, Any]]:
    """Drain queued messages and yield one ``queued_consumed`` event per drained item."""
    queued = await _drain_queued_messages(session_id, history)
    for item in queued:
        yield {"type": "queued_consumed", "message": item}


def _failed_result(session_id: str, command: str, stderr: str) -> dict[str, Any]:
    """Build a synthetic ``shell_exec`` result (exit code -1) for a command that produced none."""
    return {
        "command": command,
        "stdout": "",
        "stderr": stderr,
        "exit_code": -1,
        "duration_seconds": 0,
        "timed_out": False,
        "cancelled": False,
        "session_id": session_id,
        "session_name": sessions.get_metadata(session_id).get("name", session_id),
        "working_directory": "/",
    }


async def _execute_tool_calls(
    session_id: str,
    history: list[dict[str, Any]],
    tool_calls: list[dict[str, Any]],
    cancel_event: asyncio.Event,
) -> AsyncIterator[dict[str, Any]]:
    """Run tool calls in order, streaming their events and recording results in ``history``.

    Each call appends one ``tool`` message to the history, which is saved
    after every call. Invalid calls get a synthetic error result instead of
    being executed. Execution stops, with a ``generation_cancelled`` event,
    as soon as ``cancel_event`` is observed set before or after a call.
    """
    for call in tool_calls:
        if cancel_event.is_set():
            # NOTE(review): the remaining tool calls get no `tool` message here, so the
            # stored assistant message keeps unanswered tool_calls. Some chat APIs reject
            # such a history; confirm the provider layer tolerates or repairs it.
            yield {"type": "generation_cancelled", "message": "Stopped before the next tool call."}
            return

        command, timeout, error = _parse_tool_arguments(call)
        tool_call_id = call.get("id") or f"call_{uuid.uuid4().hex[:12]}"

        if error:
            result = _failed_result(session_id, command or "", error)
            yield {"type": "tool_result", "result": result}
            history.append(
                {"role": "tool", "tool_call_id": tool_call_id, "name": "shell_exec", "content": json.dumps(result)}
            )
            sessions.save_history(session_id, history)
            continue

        final_result: dict[str, Any] | None = None
        async for event in shell_exec_stream(session_id, command, timeout):
            yield event
            if event.get("type") == "tool_result":
                final_result = event["result"]
        if final_result is None:
            # The stream ended without a tool_result event; record a failure so the
            # tool call still gets an answer in the history.
            final_result = _failed_result(session_id, command, "Command ended without a result object.")

        history.append(
            {"role": "tool", "tool_call_id": tool_call_id, "name": "shell_exec", "content": json.dumps(final_result)}
        )
        sessions.save_history(session_id, history)

        async for event in _yield_drain_queued(session_id, history):
            yield event
        if cancel_event.is_set():
            yield {"type": "generation_cancelled", "message": "Stopped after tool execution."}
            return


def _provider_failure_events(session_id: str, provider_name: str, message: str) -> list[dict[str, Any]]:
    """Log a provider failure and return the events that report it to the stream consumer."""
    sessions.log(session_id, "provider_error", {"provider": provider_name, "message": message})
    return [{"type": "error", "message": message}, {"type": "refresh"}]


async def run_agent(
    session_id: str,
    user_message: str | None = None,
    overrides: dict[str, Any] | None = None,
    approved_pending: dict[str, Any] | None = None,
) -> AsyncIterator[dict[str, Any]]:
    """Run the agent loop for one session, yielding UI events as it progresses.

    Args:
        session_id: Session whose history, queue and pending state are used.
        user_message: New owner message. If a run already holds the session
            lock, the message is enqueued, a ``queued`` event is yielded and
            the generator ends.
        overrides: Provider overrides passed to ``provider_from_overrides``
            and ``stream_chat``.
        approved_pending: Previously saved pending tool calls to execute
            before resuming the loop. Takes precedence over ``user_message``.

    Raises:
        ValueError: If neither ``user_message`` nor ``approved_pending`` is given.
    """
    overrides = overrides or {}
    lock = _lock_for(session_id)
    if lock.locked() and user_message is not None:
        queued = sessions.enqueue_message(session_id, user_message)
        yield {"type": "queued", "message": queued}
        return

    async with lock:
        cancel_event = asyncio.Event()
        ACTIVE_CANCEL_EVENTS[session_id] = cancel_event
        try:
            history = sessions.load_history(session_id)
            if approved_pending:
                tool_calls = approved_pending.get("tool_calls") or []
                sessions.clear_pending(session_id)
                async for event in _execute_tool_calls(session_id, history, tool_calls, cancel_event):
                    yield event
                async for event in _yield_drain_queued(session_id, history):
                    yield event
            elif user_message is not None:
                history.append({"role": "user", "content": user_message})
                sessions.save_history(session_id, history)
                sessions.log(session_id, "user_message", {"text": user_message})
            else:
                raise ValueError("run_agent requires a user_message or approved_pending")

            for iteration in range(MAX_AGENT_ITERATIONS):
                async for event in _yield_drain_queued(session_id, history):
                    yield event
                if cancel_event.is_set():
                    yield {"type": "generation_cancelled", "message": "Generation stopped by owner."}
                    yield {"type": "refresh"}
                    return

                messages = [{"role": "system", "content": SYSTEM_PROMPT}, *history]
                acc = CompletionAccumulator()
                provider = provider_from_overrides(overrides)
                yield {
                    "type": "assistant_start",
                    "iteration": iteration + 1,
                    "provider": provider.name,
                    "provider_label": provider.label,
                }
                try:
                    async for event in provider.stream_chat(messages, overrides, acc, cancel_event):
                        yield event
                except ProviderError as exc:
                    for failure in _provider_failure_events(session_id, provider.name, redact_text(str(exc))):
                        yield failure
                    return
                except Exception as exc:
                    # Deliberate boundary: any streaming failure is reported to the consumer
                    # instead of propagating. Exceptions such as TimeoutError() stringify to
                    # "", so fall back to the class name to avoid a blank error.
                    detail = redact_text(str(exc)) or type(exc).__name__
                    for failure in _provider_failure_events(session_id, provider.name, detail):
                        yield failure
                    return

                assistant_message = _assistant_message_from_acc(acc)
                if acc.content or acc.reasoning or acc.tool_calls or acc.cancelled:
                    history.append(assistant_message)
                    sessions.save_history(session_id, history)
                # NOTE(review): the collapsed original does not show whether this log call
                # was inside the guard above; it is emitted unconditionally here - confirm.
                sessions.log(
                    session_id,
                    "assistant_message",
                    {
                        "provider": provider.name,
                        "content": acc.content,
                        "reasoning": acc.reasoning,
                        "tool_calls": acc.tool_calls,
                        "cancelled": acc.cancelled,
                    },
                )

                if acc.cancelled or cancel_event.is_set():
                    async for event in _yield_drain_queued(session_id, history):
                        yield event
                    yield {
                        "type": "assistant_done",
                        "content": acc.content,
                        "reasoning": acc.reasoning,
                        "cancelled": True,
                    }
                    yield {"type": "refresh"}
                    return

                tool_calls = acc.tool_calls
                if not tool_calls:
                    queued_after_stream = await _drain_queued_messages(session_id, history)
                    for item in queued_after_stream:
                        yield {"type": "queued_consumed", "message": item}
                    yield {"type": "assistant_done", "content": acc.content, "reasoning": acc.reasoning}
                    if queued_after_stream:
                        # Messages arrived while streaming: answer them in another iteration.
                        continue
                    yield {"type": "refresh"}
                    return

                if settings.approval_mode:
                    pending = {"created_at": time.time(), "tool_calls": tool_calls, "session_id": session_id}
                    sessions.save_pending(session_id, pending)
                    yield {"type": "approval_required", "tool_calls": tool_calls}
                    yield {"type": "refresh"}
                    return

                async for event in _execute_tool_calls(session_id, history, tool_calls, cancel_event):
                    yield event
                yield {"type": "refresh"}

            # Iteration budget exhausted: tell the owner instead of looping forever.
            message = (
                f"Stopped after {MAX_AGENT_ITERATIONS} tool-calling iterations to avoid an infinite loop. "
                "Ask me to continue if more work is needed."
            )
            history.append({"role": "assistant", "content": message})
            sessions.save_history(session_id, history)
            yield {"type": "assistant_start", "iteration": MAX_AGENT_ITERATIONS + 1}
            yield {"type": "assistant_delta", "text": message}
            yield {"type": "assistant_done", "content": message}
            yield {"type": "refresh"}
        finally:
            # Only unregister our own event, in case the entry was replaced meanwhile.
            current = ACTIVE_CANCEL_EVENTS.get(session_id)
            if current is cancel_event:
                ACTIVE_CANCEL_EVENTS.pop(session_id, None)