# AI-agent / app/agent.py
# Author: riddhiman
# Commit 850576b (verified): "Replace all files: ai-shell-workstation-provider-upgrade"
from __future__ import annotations
import asyncio
import json
import time
import uuid
from typing import Any, AsyncIterator
from .config import settings
from .providers import CompletionAccumulator, ProviderError, provider_from_overrides
from .security import redact_text
from .sessions import sessions
from .shell import shell_exec_stream
# System message placed before the session history on every provider request
# made by run_agent. It names `shell_exec` as the agent's tool and describes the
# session filesystem layout (`/`, `/uploads`, `/outputs`).
SYSTEM_PROMPT = """You are an autonomous shell-based AI agent. You are working inside the root directory of the current session. The visible `/` is your session root. Uploaded files are available in `/uploads`. Put files you want to return to the owner in `/outputs`. Use the `shell_exec` tool to inspect files, create files, modify files, install dependencies, execute commands, run tests, debug problems, and complete the owner’s requests. Choose the appropriate language, tools, commands, libraries, and workflow yourself. You have broad freedom in this environment. Do not invent command results. After using tools, give the owner a clear final summary and mention any output files you created."""
# Maximum number of provider round-trips in a single run_agent call; bounds a
# model that keeps requesting tools so the loop cannot run forever.
MAX_AGENT_ITERATIONS = 24
# session_id -> cancel flag of the generation currently running for that session.
# Set by request_generation_stop; removed by run_agent when its run finishes.
ACTIVE_CANCEL_EVENTS: dict[str, asyncio.Event] = {}
# session_id -> lock serialising run_agent calls for that session. Entries are
# created lazily by _lock_for; NOTE(review): nothing visible here ever removes them.
ACTIVE_LOCKS: dict[str, asyncio.Lock] = {}
def session_is_active(session_id: str) -> bool:
    """Tell whether *session_id* has a generation running that was not asked to stop.

    Returns ``True`` only when a cancel event is registered for the session
    and that event has not been set yet; ``False`` otherwise.
    """
    cancel_flag = ACTIVE_CANCEL_EVENTS.get(session_id)
    if cancel_flag is None:
        return False
    return not cancel_flag.is_set()
def request_generation_stop(session_id: str) -> bool:
    """Ask the in-flight generation for *session_id* to stop.

    Sets the session's registered cancel event and writes a
    ``generation_stop_requested`` log entry.

    Returns:
        ``True`` if a cancel event was registered (and is now set), ``False``
        when nothing is running for the session (nothing is logged then).
    """
    cancel_flag = ACTIVE_CANCEL_EVENTS.get(session_id)
    if cancel_flag is not None:
        cancel_flag.set()
        sessions.log(session_id, "generation_stop_requested", {})
        return True
    return False
def _lock_for(session_id: str) -> asyncio.Lock:
    """Return the per-session lock, creating and registering it on first use."""
    existing = ACTIVE_LOCKS.get(session_id)
    if existing is not None:
        return existing
    created = asyncio.Lock()
    ACTIVE_LOCKS[session_id] = created
    return created
def _assistant_message_from_acc(acc: CompletionAccumulator) -> dict[str, Any]:
message: dict[str, Any] = {"role": "assistant", "content": acc.content or ""}
if acc.reasoning:
message["reasoning"] = acc.reasoning
if acc.cancelled:
message["cancelled"] = True
calls = acc.tool_calls
if calls and not acc.cancelled:
message["tool_calls"] = calls
return message
def _parse_tool_arguments(call: dict[str, Any]) -> tuple[str | None, int | None, str | None]:
fn = call.get("function") or {}
if fn.get("name") != "shell_exec":
return None, None, f"Unsupported tool requested: {fn.get('name')}. The only available tool is shell_exec."
raw_args = fn.get("arguments") or "{}"
try:
args = json.loads(raw_args)
except json.JSONDecodeError as exc:
return None, None, f"Malformed tool arguments: {exc}. Raw arguments: {raw_args[:1000]}"
command = args.get("command")
if not isinstance(command, str) or not command.strip():
return None, None, "shell_exec requires a non-empty string command."
timeout = args.get("timeout_seconds")
if timeout is not None:
try:
timeout = int(timeout)
except Exception:
timeout = None
return command, timeout, None
async def _drain_queued_messages(session_id: str, history: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Fold messages the owner queued for *session_id* into *history*.

    Items come from ``sessions.drain_queued_messages``. Each item whose
    ``content`` is non-blank after stripping is appended to *history* as a
    user turn tagged with ``queued_id`` and logged as
    ``queued_message_consumed``; blank items are skipped. History is saved
    whenever at least one item was drained.

    Returns:
        Every drained item, including skipped blank ones.
    """
    drained = sessions.drain_queued_messages(session_id)
    for entry in drained:
        body = str(entry.get("content") or "").strip()
        if not body:
            continue
        entry_id = entry.get("id")
        history.append({"role": "user", "content": body, "queued_id": entry_id})
        sessions.log(session_id, "queued_message_consumed", {"id": entry_id, "text": body})
    if drained:
        sessions.save_history(session_id, history)
    return drained
async def _yield_drain_queued(session_id: str, history: list[dict[str, Any]]) -> AsyncIterator[dict[str, Any]]:
    """Drain queued owner messages into *history*, emitting one ``queued_consumed`` event per drained item."""
    for consumed in await _drain_queued_messages(session_id, history):
        yield {"type": "queued_consumed", "message": consumed}
def _tool_failure_result(session_id: str, command: str, stderr: str, cancelled: bool = False) -> dict[str, Any]:
    """Build a result dict, in the same shape stored for shell_exec calls, for a call that produced no real output."""
    return {
        "command": command,
        "stdout": "",
        "stderr": stderr,
        "exit_code": -1,
        "duration_seconds": 0,
        "timed_out": False,
        "cancelled": cancelled,
        "session_id": session_id,
        "session_name": sessions.get_metadata(session_id).get("name", session_id),
        "working_directory": "/",
    }


def _tool_call_id(call: dict[str, Any]) -> str:
    """Return the provider-assigned id of *call*, or a fresh random one when it has none."""
    return call.get("id") or f"call_{uuid.uuid4().hex[:12]}"


def _append_tool_message(history: list[dict[str, Any]], tool_call_id: str, result: dict[str, Any]) -> None:
    """Append the ``role: "tool"`` history entry answering *tool_call_id* with *result* as JSON."""
    history.append({"role": "tool", "tool_call_id": tool_call_id, "name": "shell_exec", "content": json.dumps(result)})


async def _execute_tool_calls(
    session_id: str,
    history: list[dict[str, Any]],
    tool_calls: list[dict[str, Any]],
    cancel_event: asyncio.Event,
) -> AsyncIterator[dict[str, Any]]:
    """Execute the assistant's requested tool calls in order, streaming their events.

    Every event from ``shell_exec_stream`` is re-yielded. Invalid calls (per
    ``_parse_tool_arguments``) yield a synthetic ``tool_result`` with the
    validation error instead. Each call's result is appended to *history* as a
    ``role: "tool"`` message and saved.

    If *cancel_event* is set before a call starts or after one finishes, the
    remaining calls are not run. They still get a ``cancelled`` placeholder
    tool message in history, recorded there only and not streamed. A
    ``generation_cancelled`` event is yielded last.

    Queued owner messages are drained once, after every call in this batch
    has a tool message.
    """
    cancel_message: str | None = None
    unexecuted: list[dict[str, Any]] = []
    for position, call in enumerate(tool_calls):
        if cancel_event.is_set():
            cancel_message = "Stopped before the next tool call."
            unexecuted = tool_calls[position:]
            break
        tool_call_id = _tool_call_id(call)
        command, timeout, error = _parse_tool_arguments(call)
        if error:
            result = _tool_failure_result(session_id, command or "", error)
            yield {"type": "tool_result", "result": result}
        else:
            streamed: dict[str, Any] | None = None
            async for event in shell_exec_stream(session_id, command, timeout):
                yield event
                if event.get("type") == "tool_result":
                    streamed = event["result"]
            if streamed is None:
                streamed = _tool_failure_result(session_id, command, "Command ended without a result object.")
            result = streamed
        _append_tool_message(history, tool_call_id, result)
        sessions.save_history(session_id, history)
        if cancel_event.is_set():
            cancel_message = "Stopped after tool execution."
            unexecuted = tool_calls[position + 1:]
            break
    if unexecuted:
        # Chat-completions style APIs reject a history in which an assistant
        # message's tool_calls are not all answered by tool messages, which
        # would break the session's next request. Answer skipped calls explicitly.
        for call in unexecuted:
            skipped_command = _parse_tool_arguments(call)[0] or ""
            _append_tool_message(
                history,
                _tool_call_id(call),
                _tool_failure_result(
                    session_id,
                    skipped_command,
                    "Skipped: generation was stopped before this tool call ran.",
                    cancelled=True,
                ),
            )
        sessions.save_history(session_id, history)
    # Drain only now: appending user turns between the tool messages of one
    # assistant turn would separate those messages from their tool_calls.
    async for event in _yield_drain_queued(session_id, history):
        yield event
    if cancel_message is not None:
        yield {"type": "generation_cancelled", "message": cancel_message}
async def run_agent(
    session_id: str,
    user_message: str | None = None,
    overrides: dict[str, Any] | None = None,
    approved_pending: dict[str, Any] | None = None,
) -> AsyncIterator[dict[str, Any]]:
    """Run the agent loop for *session_id*, streaming UI events as dicts.

    If the session's lock is already held and ``user_message`` is given, the
    message is queued via ``sessions.enqueue_message`` and a single
    ``{"type": "queued"}`` event is yielded. Otherwise the call waits for the
    lock and registers a fresh cancel event for the session. Then either:

    * ``approved_pending`` (truthy): the stored pending entry is cleared and
      ``approved_pending["tool_calls"]`` are executed, or
    * ``user_message``: it is appended to history as a user turn.

    Each loop iteration (at most ``MAX_AGENT_ITERATIONS``):

    1. drains queued owner messages;
    2. stops if cancellation was requested;
    3. streams one provider reply and saves it to history.

    The run ends when:

    * the reply has no tool calls and no new queued messages arrived;
    * approval mode is on (calls saved as pending, ``approval_required``
      yielded);
    * the reply was cancelled;
    * the provider raised (an ``error`` event is yielded);
    * the iteration cap is hit (a fixed assistant message is emitted).

    Raises:
        ValueError: if neither ``user_message`` nor ``approved_pending`` is given.
    """
    overrides = overrides or {}
    lock = _lock_for(session_id)
    if lock.locked() and user_message is not None:
        # Another run owns this session; its loop drains the queue between steps.
        queued = sessions.enqueue_message(session_id, user_message)
        yield {"type": "queued", "message": queued}
        return
    async with lock:
        cancel_event = asyncio.Event()
        ACTIVE_CANCEL_EVENTS[session_id] = cancel_event
        # Becomes True once tool execution has emitted "generation_cancelled",
        # so the loop-top check does not announce the same stop a second time.
        cancel_announced = False
        try:
            history = sessions.load_history(session_id)
            if approved_pending:
                tool_calls = approved_pending.get("tool_calls") or []
                sessions.clear_pending(session_id)
                async for event in _execute_tool_calls(session_id, history, tool_calls, cancel_event):
                    if event.get("type") == "generation_cancelled":
                        cancel_announced = True
                    yield event
                async for event in _yield_drain_queued(session_id, history):
                    yield event
            elif user_message is not None:
                history.append({"role": "user", "content": user_message})
                sessions.save_history(session_id, history)
                sessions.log(session_id, "user_message", {"text": user_message})
            else:
                raise ValueError("run_agent requires a user_message or approved_pending")
            for iteration in range(MAX_AGENT_ITERATIONS):
                async for event in _yield_drain_queued(session_id, history):
                    yield event
                if cancel_event.is_set():
                    if not cancel_announced:
                        yield {"type": "generation_cancelled", "message": "Generation stopped by owner."}
                    yield {"type": "refresh"}
                    return
                messages = [{"role": "system", "content": SYSTEM_PROMPT}, *history]
                acc = CompletionAccumulator()
                provider = provider_from_overrides(overrides)
                yield {"type": "assistant_start", "iteration": iteration + 1, "provider": provider.name, "provider_label": provider.label}
                try:
                    async for event in provider.stream_chat(messages, overrides, acc, cancel_event):
                        yield event
                except Exception as exc:
                    # ProviderError and unexpected failures are reported identically:
                    # redacted message, logged, surfaced to the UI, run ends.
                    message = redact_text(str(exc))
                    sessions.log(session_id, "provider_error", {"provider": provider.name, "message": message})
                    yield {"type": "error", "message": message}
                    yield {"type": "refresh"}
                    return
                assistant_message = _assistant_message_from_acc(acc)
                if acc.content or acc.reasoning or acc.tool_calls or acc.cancelled:
                    history.append(assistant_message)
                    sessions.save_history(session_id, history)
                sessions.log(
                    session_id,
                    "assistant_message",
                    {"provider": provider.name, "content": acc.content, "reasoning": acc.reasoning, "tool_calls": acc.tool_calls, "cancelled": acc.cancelled},
                )
                if acc.cancelled or cancel_event.is_set():
                    async for event in _yield_drain_queued(session_id, history):
                        yield event
                    yield {"type": "assistant_done", "content": acc.content, "reasoning": acc.reasoning, "cancelled": True}
                    yield {"type": "refresh"}
                    return
                tool_calls = acc.tool_calls
                if not tool_calls:
                    queued_after_stream = await _drain_queued_messages(session_id, history)
                    for item in queued_after_stream:
                        yield {"type": "queued_consumed", "message": item}
                    if queued_after_stream:
                        # The owner sent more while this reply streamed: answer it too.
                        yield {"type": "assistant_done", "content": acc.content, "reasoning": acc.reasoning}
                        continue
                    yield {"type": "assistant_done", "content": acc.content, "reasoning": acc.reasoning}
                    yield {"type": "refresh"}
                    return
                if settings.approval_mode:
                    pending = {"created_at": time.time(), "tool_calls": tool_calls, "session_id": session_id}
                    sessions.save_pending(session_id, pending)
                    yield {"type": "approval_required", "tool_calls": tool_calls}
                    yield {"type": "refresh"}
                    return
                async for event in _execute_tool_calls(session_id, history, tool_calls, cancel_event):
                    if event.get("type") == "generation_cancelled":
                        cancel_announced = True
                    yield event
                yield {"type": "refresh"}
            message = f"Stopped after {MAX_AGENT_ITERATIONS} tool-calling iterations to avoid an infinite loop. Ask me to continue if more work is needed."
            history.append({"role": "assistant", "content": message})
            sessions.save_history(session_id, history)
            yield {"type": "assistant_start", "iteration": MAX_AGENT_ITERATIONS + 1}
            yield {"type": "assistant_delta", "text": message}
            yield {"type": "assistant_done", "content": message}
            yield {"type": "refresh"}
        finally:
            # Only unregister our own event; a newer run may have replaced it.
            current = ACTIVE_CANCEL_EVENTS.get(session_id)
            if current is cancel_event:
                ACTIVE_CANCEL_EVENTS.pop(session_id, None)