| from __future__ import annotations |
|
|
| import asyncio |
| import json |
| import time |
| import uuid |
| from typing import Any, AsyncIterator |
|
|
| from .config import settings |
| from .providers import CompletionAccumulator, ProviderError, provider_from_overrides |
| from .security import redact_text |
| from .sessions import sessions |
| from .shell import shell_exec_stream |
|
|
# System message that run_agent places in front of the session history on every provider request.
SYSTEM_PROMPT = """You are an autonomous shell-based AI agent. You are working inside the root directory of the current session. The visible `/` is your session root. Uploaded files are available in `/uploads`. Put files you want to return to the owner in `/outputs`. Use the `shell_exec` tool to inspect files, create files, modify files, install dependencies, execute commands, run tests, debug problems, and complete the owner’s requests. Choose the appropriate language, tools, commands, libraries, and workflow yourself. You have broad freedom in this environment. Do not invent command results. After using tools, give the owner a clear final summary and mention any output files you created."""


# Upper bound on provider round-trips (each optionally followed by tool execution) in one run_agent call.
MAX_AGENT_ITERATIONS = 24
# session_id -> cancel event of the run currently in progress for that session.
# run_agent registers the event when it starts and removes it again when it finishes.
ACTIVE_CANCEL_EVENTS: dict[str, asyncio.Event] = {}
# session_id -> lock that serialises agent runs for that session; entries are created lazily by _lock_for.
# NOTE(review): nothing visible in this module removes entries -- confirm the growth is acceptable.
ACTIVE_LOCKS: dict[str, asyncio.Lock] = {}
|
|
|
|
def session_is_active(session_id: str) -> bool:
    """Return True while a run is registered for the session and no stop has been requested."""
    cancel_event = ACTIVE_CANCEL_EVENTS.get(session_id)
    if not cancel_event:
        # No run has registered a cancel event for this session.
        return False
    return not cancel_event.is_set()
|
|
|
|
def request_generation_stop(session_id: str) -> bool:
    """Signal the session's running generation to stop.

    Returns True when a cancel event was registered for the session (it is
    set and the request is logged), False when there was nothing to stop.
    """
    cancel_event = ACTIVE_CANCEL_EVENTS.get(session_id)
    if cancel_event:
        cancel_event.set()
        sessions.log(session_id, "generation_stop_requested", {})
        return True
    return False
|
|
|
|
def _lock_for(session_id: str) -> asyncio.Lock:
    """Return the per-session lock, creating and caching it on first use."""
    existing = ACTIVE_LOCKS.get(session_id)
    if existing is not None:
        return existing
    fresh = asyncio.Lock()
    ACTIVE_LOCKS[session_id] = fresh
    return fresh
|
|
|
|
| def _assistant_message_from_acc(acc: CompletionAccumulator) -> dict[str, Any]: |
| message: dict[str, Any] = {"role": "assistant", "content": acc.content or ""} |
| if acc.reasoning: |
| message["reasoning"] = acc.reasoning |
| if acc.cancelled: |
| message["cancelled"] = True |
| calls = acc.tool_calls |
| if calls and not acc.cancelled: |
| message["tool_calls"] = calls |
| return message |
|
|
|
|
| def _parse_tool_arguments(call: dict[str, Any]) -> tuple[str | None, int | None, str | None]: |
| fn = call.get("function") or {} |
| if fn.get("name") != "shell_exec": |
| return None, None, f"Unsupported tool requested: {fn.get('name')}. The only available tool is shell_exec." |
| raw_args = fn.get("arguments") or "{}" |
| try: |
| args = json.loads(raw_args) |
| except json.JSONDecodeError as exc: |
| return None, None, f"Malformed tool arguments: {exc}. Raw arguments: {raw_args[:1000]}" |
| command = args.get("command") |
| if not isinstance(command, str) or not command.strip(): |
| return None, None, "shell_exec requires a non-empty string command." |
| timeout = args.get("timeout_seconds") |
| if timeout is not None: |
| try: |
| timeout = int(timeout) |
| except Exception: |
| timeout = None |
| return command, timeout, None |
|
|
|
|
async def _drain_queued_messages(session_id: str, history: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Move the session's queued messages into ``history`` as user turns.

    Items whose content is empty after stripping are not added to the
    history, but they are still part of the returned list. The history is
    persisted whenever the drained list is non-empty.

    Returns:
        Every item taken from the queue, in the order the queue returned them.
    """
    drained = sessions.drain_queued_messages(session_id)
    for entry in drained:
        body = str(entry.get("content") or "").strip()
        if body:
            history.append({"role": "user", "content": body, "queued_id": entry.get("id")})
            sessions.log(session_id, "queued_message_consumed", {"id": entry.get("id"), "text": body})
    if drained:
        sessions.save_history(session_id, history)
    return drained
|
|
|
|
async def _yield_drain_queued(session_id: str, history: list[dict[str, Any]]) -> AsyncIterator[dict[str, Any]]:
    """Drain queued messages into ``history`` and emit one ``queued_consumed`` event per drained item."""
    for consumed in await _drain_queued_messages(session_id, history):
        yield {"type": "queued_consumed", "message": consumed}
|
|
|
|
def _synthetic_tool_result(session_id: str, command: str, stderr: str, *, cancelled: bool = False) -> dict[str, Any]:
    """Build a failed shell_exec-shaped result for a call that produced no real result."""
    return {
        "command": command,
        "stdout": "",
        "stderr": stderr,
        "exit_code": -1,
        "duration_seconds": 0,
        "timed_out": False,
        "cancelled": cancelled,
        "session_id": session_id,
        "session_name": sessions.get_metadata(session_id).get("name", session_id),
        "working_directory": "/",
    }


def _tool_message(call: dict[str, Any], result: dict[str, Any]) -> dict[str, Any]:
    """Wrap ``result`` in a ``tool`` history message answering ``call``."""
    # Calls without an id get a generated one so the tool message is still well-formed.
    tool_call_id = call.get("id") or f"call_{uuid.uuid4().hex[:12]}"
    return {"role": "tool", "tool_call_id": tool_call_id, "name": "shell_exec", "content": json.dumps(result)}


def _record_skipped_calls(session_id: str, history: list[dict[str, Any]], skipped: list[dict[str, Any]]) -> None:
    """Answer every call in ``skipped`` with a cancelled result and persist the history once."""
    if not skipped:
        return
    for call in skipped:
        command, _timeout, _error = _parse_tool_arguments(call)
        result = _synthetic_tool_result(
            session_id,
            command or "",
            "Skipped: generation was stopped before this command ran.",
            cancelled=True,
        )
        history.append(_tool_message(call, result))
    sessions.save_history(session_id, history)


async def _execute_tool_calls(
    session_id: str,
    history: list[dict[str, Any]],
    tool_calls: list[dict[str, Any]],
    cancel_event: asyncio.Event,
) -> AsyncIterator[dict[str, Any]]:
    """Run the requested tool calls in order, streaming their events.

    For each call the arguments are validated; invalid calls produce a
    synthetic failed ``tool_result`` event. Valid calls are streamed through
    ``shell_exec_stream`` and every event is forwarded to the caller. The
    final result of each call is appended to ``history`` as a ``tool`` message
    and the history is persisted. After each executed command, queued user
    messages are drained into the history.

    Cancellation is checked before each call and after each executed command.
    When it triggers, a ``generation_cancelled`` event is yielded and the
    generator ends. Calls that were not run are recorded in the history with a
    cancelled result, so the assistant message that requested them is not left
    with unanswered tool calls (chat APIs commonly reject such histories).

    Args:
        session_id: Session whose shell and history are used.
        history: Conversation history; mutated in place.
        tool_calls: Tool-call dicts as produced by the model.
        cancel_event: Set by the owner to stop the run.
    """
    for index, call in enumerate(tool_calls):
        if cancel_event.is_set():
            _record_skipped_calls(session_id, history, tool_calls[index:])
            yield {"type": "generation_cancelled", "message": "Stopped before the next tool call."}
            return

        command, timeout, error = _parse_tool_arguments(call)
        if error:
            # Report the validation problem to the model as a failed tool result.
            result = _synthetic_tool_result(session_id, command or "", error)
            yield {"type": "tool_result", "result": result}
            history.append(_tool_message(call, result))
            sessions.save_history(session_id, history)
            continue

        final_result: dict[str, Any] | None = None
        async for event in shell_exec_stream(session_id, command, timeout):
            yield event
            if event.get("type") == "tool_result":
                final_result = event["result"]
        if final_result is None:
            # The stream ended without a tool_result event; record a failure instead.
            final_result = _synthetic_tool_result(session_id, command, "Command ended without a result object.")
        history.append(_tool_message(call, final_result))
        sessions.save_history(session_id, history)

        async for event in _yield_drain_queued(session_id, history):
            yield event
        if cancel_event.is_set():
            _record_skipped_calls(session_id, history, tool_calls[index + 1:])
            yield {"type": "generation_cancelled", "message": "Stopped after tool execution."}
            return
|
|
|
|
async def run_agent(
    session_id: str,
    user_message: str | None = None,
    overrides: dict[str, Any] | None = None,
    approved_pending: dict[str, Any] | None = None,
) -> AsyncIterator[dict[str, Any]]:
    """Drive one agent run for a session, yielding UI events as it progresses.

    A run starts either from a new ``user_message`` or from an
    ``approved_pending`` payload whose ``tool_calls`` the owner has approved.
    Runs for the same session are serialised by a per-session lock: if a run
    is already active and a ``user_message`` is given, the message is
    enqueued, a single ``queued`` event is yielded, and the generator ends.

    Inside the lock the loop repeats up to ``MAX_AGENT_ITERATIONS`` times:
    drain queued messages, stream a completion from the provider, persist the
    assistant message, then either finish (no tool calls), stop for approval
    (``settings.approval_mode``), or execute the tool calls and loop again.

    Args:
        session_id: Session to run in.
        user_message: New owner message, if any.
        overrides: Provider overrides passed to ``provider_from_overrides``
            and ``stream_chat``; treated as ``{}`` when None.
        approved_pending: Previously saved pending payload to execute.

    Yields:
        Event dicts. Those produced here have ``type`` in ``queued``,
        ``queued_consumed``, ``assistant_start``, ``assistant_delta``,
        ``assistant_done``, ``approval_required``, ``generation_cancelled``,
        ``error`` and ``refresh``; provider and shell events are forwarded
        unchanged.

    Raises:
        ValueError: If neither ``user_message`` nor a non-empty
            ``approved_pending`` is supplied.
    """
    # Validate up front so an invalid call fails without waiting for the lock
    # or registering a cancel event.
    if user_message is None and not approved_pending:
        raise ValueError("run_agent requires a user_message or approved_pending")

    overrides = overrides or {}
    lock = _lock_for(session_id)
    if lock.locked() and user_message is not None:
        # A run is already in progress: hand the message to it via the queue.
        queued = sessions.enqueue_message(session_id, user_message)
        yield {"type": "queued", "message": queued}
        return

    async with lock:
        cancel_event = asyncio.Event()
        ACTIVE_CANCEL_EVENTS[session_id] = cancel_event
        try:
            history = sessions.load_history(session_id)

            if approved_pending:
                tool_calls = approved_pending.get("tool_calls") or []
                sessions.clear_pending(session_id)
                async for event in _execute_tool_calls(session_id, history, tool_calls, cancel_event):
                    yield event
                async for event in _yield_drain_queued(session_id, history):
                    yield event
            else:
                # The up-front validation guarantees user_message is set here.
                history.append({"role": "user", "content": user_message})
                sessions.save_history(session_id, history)
                sessions.log(session_id, "user_message", {"text": user_message})

            for iteration in range(MAX_AGENT_ITERATIONS):
                async for event in _yield_drain_queued(session_id, history):
                    yield event
                if cancel_event.is_set():
                    yield {"type": "generation_cancelled", "message": "Generation stopped by owner."}
                    yield {"type": "refresh"}
                    return

                messages = [{"role": "system", "content": SYSTEM_PROMPT}, *history]
                acc = CompletionAccumulator()
                provider = provider_from_overrides(overrides)
                yield {"type": "assistant_start", "iteration": iteration + 1, "provider": provider.name, "provider_label": provider.label}
                try:
                    async for event in provider.stream_chat(messages, overrides, acc, cancel_event):
                        yield event
                except (ProviderError, Exception) as exc:
                    # Provider failures and unexpected errors are reported identically:
                    # redact, log, surface to the UI, and end the run.
                    message = redact_text(str(exc))
                    sessions.log(session_id, "provider_error", {"provider": provider.name, "message": message})
                    yield {"type": "error", "message": message}
                    yield {"type": "refresh"}
                    return

                assistant_message = _assistant_message_from_acc(acc)
                if acc.content or acc.reasoning or acc.tool_calls or acc.cancelled:
                    # Completely empty completions are not persisted.
                    history.append(assistant_message)
                    sessions.save_history(session_id, history)
                    sessions.log(
                        session_id,
                        "assistant_message",
                        {"provider": provider.name, "content": acc.content, "reasoning": acc.reasoning, "tool_calls": acc.tool_calls, "cancelled": acc.cancelled},
                    )

                if acc.cancelled or cancel_event.is_set():
                    async for event in _yield_drain_queued(session_id, history):
                        yield event
                    yield {"type": "assistant_done", "content": acc.content, "reasoning": acc.reasoning, "cancelled": True}
                    yield {"type": "refresh"}
                    return

                tool_calls = acc.tool_calls
                if not tool_calls:
                    queued_after_stream = await _drain_queued_messages(session_id, history)
                    for item in queued_after_stream:
                        yield {"type": "queued_consumed", "message": item}
                    if queued_after_stream:
                        # Messages arrived while streaming: answer them in another iteration.
                        yield {"type": "assistant_done", "content": acc.content, "reasoning": acc.reasoning}
                        continue
                    # NOTE(review): a message enqueued by a concurrent call after this drain
                    # but before the lock is released stays queued until the next run --
                    # confirm that is acceptable.
                    yield {"type": "assistant_done", "content": acc.content, "reasoning": acc.reasoning}
                    yield {"type": "refresh"}
                    return

                if settings.approval_mode:
                    # Park the tool calls until the owner approves them in a later run.
                    pending = {"created_at": time.time(), "tool_calls": tool_calls, "session_id": session_id}
                    sessions.save_pending(session_id, pending)
                    yield {"type": "approval_required", "tool_calls": tool_calls}
                    yield {"type": "refresh"}
                    return

                async for event in _execute_tool_calls(session_id, history, tool_calls, cancel_event):
                    yield event
                yield {"type": "refresh"}

            # The iteration budget is exhausted: tell the owner instead of looping forever.
            message = f"Stopped after {MAX_AGENT_ITERATIONS} tool-calling iterations to avoid an infinite loop. Ask me to continue if more work is needed."
            history.append({"role": "assistant", "content": message})
            sessions.save_history(session_id, history)
            yield {"type": "assistant_start", "iteration": MAX_AGENT_ITERATIONS + 1}
            yield {"type": "assistant_delta", "text": message}
            yield {"type": "assistant_done", "content": message}
            yield {"type": "refresh"}
        finally:
            # Only unregister our own event; a later run may already have replaced it.
            current = ACTIVE_CANCEL_EVENTS.get(session_id)
            if current is cancel_event:
                ACTIVE_CANCEL_EVENTS.pop(session_id, None)
|
|