AI-agent / app /shell.py
riddhiman's picture
Replace all files: ai-shell-workstation-provider-upgrade
850576b verified
from __future__ import annotations

import asyncio
import codecs
import os
import shutil
import signal
import time
from pathlib import Path
from typing import Any, AsyncIterator

from .config import settings
from .security import redact_text
from .sessions import SessionPaths, sessions
# Subprocess of the shell command currently running for each session, keyed by
# session id (registered when a command is spawned).
RUNNING_PROCESSES: dict[str, asyncio.subprocess.Process] = dict()
# Session ids whose in-flight command the owner has asked to stop.
COMMAND_CANCEL_REQUESTS: set[str] = set()
def truncate_preserving_tail(text: str, limit: int) -> tuple[str, bool]:
    """Shorten *text* to roughly *limit* characters, keeping its start and end.

    A third of the budget goes to the head of the text. The remainder, minus a
    fixed 120-character allowance for the marker, goes to the tail. A
    ``[... output truncated to N characters ...]`` marker separates the two.

    Returns:
        ``(text, False)`` when ``len(text) <= limit``; otherwise
        ``(head + marker + tail, True)``. For very small limits the tail is
        empty and the marker alone may exceed *limit*.
    """
    if len(text) <= limit:
        return text, False
    head = max(0, limit // 3)
    tail = max(0, limit - head - 120)
    marker = f"\n\n[... output truncated to {limit} characters ...]\n\n"
    # Slice from an explicit start index: ``text[-0:]`` is the WHOLE string,
    # so a zero-length tail must not be written as ``text[-tail:]``.
    return text[:head] + marker + text[len(text) - tail:], True
def _existing_bind_args() -> list[str]:
    """Return ``-b <path>`` bind flags for the host paths that exist here.

    Each host path below that exists contributes the pair ``"-b", path``, in
    the listed order. Missing paths are skipped silently, so the result is
    always a flat list of even length.
    """
    host_paths = (
        "/bin",
        "/usr",
        "/lib",
        "/lib64",
        "/opt",
        "/etc/resolv.conf",
        "/etc/hosts",
        "/etc/passwd",
        "/etc/group",
        "/etc/ssl",
        "/dev/null",
        "/dev/urandom",
        "/dev/random",
        "/proc",
        "/tmp",
    )
    present = (host_path for host_path in host_paths if Path(host_path).exists())
    return [token for host_path in present for token in ("-b", host_path)]
def build_shell_args(paths: SessionPaths, command: str) -> tuple[list[str], str]:
    """Build the argv and working directory used to run *command* for a session.

    If ``proot`` is on PATH, ``/bin/bash -lc command`` runs inside proot. proot
    gets ``paths.root`` as the guest root (``-r``), ``/`` as the guest working
    directory (``-w``), and the host binds from ``_existing_bind_args()``. The
    returned working directory is then ``"/"``.

    Otherwise bash runs directly on the host with ``paths.root`` as its working
    directory. NOTE(review): that fallback provides no filesystem isolation.
    """
    bash_argv = ["/bin/bash", "-lc", command]
    proot = shutil.which("proot")
    if not proot:
        return bash_argv, str(paths.root)
    sandbox_prefix = [proot, "-r", str(paths.root), "-w", "/", *_existing_bind_args()]
    return sandbox_prefix + bash_argv, "/"
async def stop_running_process(session_id: str) -> bool:
    """Stop the shell command currently registered for *session_id*.

    Sends SIGTERM to the command's process group, or to the leader alone if the
    group cannot be signalled. It then waits up to five seconds for the leader
    to exit and escalates to SIGKILL if it has not. The session id is placed in
    ``COMMAND_CANCEL_REQUESTS`` so that a running ``shell_exec_stream`` for the
    session reports the command as cancelled.

    Returns:
        ``True`` once a stop signal has been delivered. ``False`` when no
        command is registered, it has already exited, or its process vanished
        before it could be signalled.
    """
    proc = RUNNING_PROCESSES.get(session_id)
    if proc is None or proc.returncode is not None:
        return False

    def deliver(sig: int) -> bool:
        """Signal the process group, else the leader; ``False`` if already gone."""
        if proc.pid:  # never killpg(0, ...): that would signal our own group
            try:
                os.killpg(proc.pid, sig)
                return True
            except ProcessLookupError:
                return False
            except OSError:
                pass  # e.g. a permission error: fall back to the leader below
        try:
            proc.send_signal(sig)
        except ProcessLookupError:
            return False
        return True

    already_requested = session_id in COMMAND_CANCEL_REQUESTS
    COMMAND_CANCEL_REQUESTS.add(session_id)
    if not deliver(signal.SIGTERM):
        # Nothing was stopped. Withdraw our request so a natural exit is not
        # labelled "stopped by owner", but keep any request an earlier call made.
        if not already_requested:
            COMMAND_CANCEL_REQUESTS.discard(session_id)
        return False
    try:
        await asyncio.wait_for(proc.wait(), timeout=5)
    except asyncio.TimeoutError:
        # Ignore the result: the leader may have exited right after the timeout.
        deliver(signal.SIGKILL)
    return True
def _signal_process_group(proc: asyncio.subprocess.Process, sig: int) -> None:
    """Best-effort delivery of *sig* to the process group led by *proc*.

    ``shell_exec_stream`` spawns commands with ``start_new_session=True``, so the
    group id equals ``proc.pid``. Signalling the group therefore also reaches
    background jobs the shell left behind. If the group cannot be signalled for
    a reason other than it being gone (e.g. a permission error), this falls back
    to signalling the leader alone. An already-vanished process is not an error.
    """
    if not proc.pid:
        # Never call killpg(0, ...): that would signal this server's own group.
        return
    try:
        os.killpg(proc.pid, sig)
    except ProcessLookupError:
        return
    except OSError:
        try:
            proc.send_signal(sig)
        except ProcessLookupError:
            pass


async def _terminate_process_group(
    proc: asyncio.subprocess.Process,
    watched: set[asyncio.Task[Any]],
    grace_seconds: float = 5,
) -> None:
    """Ask the command's process group to exit, forcing it if needed.

    Sends SIGTERM and waits up to *grace_seconds* for every task in *watched*
    (the exit waiter and the pipe readers) to finish. If any task is still
    pending, it sends SIGKILL to the group. This covers, for example, a
    background job that ignores SIGTERM while holding the output pipes open.
    """
    _signal_process_group(proc, signal.SIGTERM)
    _, still_pending = await asyncio.wait(watched, timeout=grace_seconds)
    if still_pending:
        _signal_process_group(proc, signal.SIGKILL)


async def shell_exec_stream(
    session_id: str,
    command: str,
    timeout_seconds: int | None = None,
) -> AsyncIterator[dict[str, Any]]:
    """Run *command* for a session and stream its progress as events.

    Yields, in order:

    * one ``{"type": "tool_start", ...}`` event;
    * zero or more ``{"type": "tool_output_delta", "stream": "stdout" | "stderr",
      "text": ...}`` events carrying redacted output as it arrives;
    * one ``{"type": "tool_result", "result": {...}}`` event with the final
      redacted, size-capped output, the exit code and status flags. If the
      process cannot be spawned, this event has exit code -1 and the error text
      in ``stderr``.

    ``timeout_seconds`` is clamped to ``[1, settings.max_tool_runtime_seconds]``;
    a falsy value means the maximum. The timeout covers the shell itself and
    any background job that keeps stdout/stderr open after the shell exits.
    When it expires, or when the consumer stops iterating early, the command's
    process group is killed, so nothing is left running unsupervised.
    """
    paths = sessions.ensure(session_id)
    metadata = sessions.get_metadata(session_id)
    timeout = timeout_seconds or settings.max_tool_runtime_seconds
    timeout = max(1, min(int(timeout), settings.max_tool_runtime_seconds))
    args, cwd = build_shell_args(paths, command)
    env = os.environ.copy()
    env.update(
        {
            "SESSION_ID": session_id,
            "SESSION_NAME": str(metadata.get("name", session_id)),
            "SESSION_ROOT": "/",
            "UPLOADS_DIR": "/uploads",
            "OUTPUTS_DIR": "/outputs",
            "WORKDIR": "/work",
            "HOME": "/work",
            "PYTHONUNBUFFERED": "1",
        }
    )
    start = time.monotonic()
    sessions.log(session_id, "shell_start", {"command": command, "timeout_seconds": timeout, "working_directory": "/"})
    yield {
        "type": "tool_start",
        "command": command,
        "timeout_seconds": timeout,
        "session_id": session_id,
        "session_name": metadata.get("name", session_id),
        "working_directory": "/",
    }
    stdout_parts: list[str] = []
    stderr_parts: list[str] = []
    stdout_chars = 0
    stderr_chars = 0
    max_chars = settings.max_output_chars
    timed_out = False
    cancelled = False
    exit_code: int | None = None
    try:
        proc = await asyncio.create_subprocess_exec(
            *args,
            cwd=cwd,
            env=env,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
            # New session => own process group, so the whole command tree can be
            # signalled with killpg(proc.pid, ...).
            start_new_session=True,
        )
        RUNNING_PROCESSES[session_id] = proc
    except Exception as exc:
        duration = time.monotonic() - start
        error = redact_text(str(exc))
        result = {
            "command": command,
            "stdout": "",
            "stderr": error,
            "exit_code": -1,
            "duration_seconds": round(duration, 3),
            "timed_out": False,
            "session_id": session_id,
            "session_name": metadata.get("name", session_id),
            "working_directory": "/",
        }
        sessions.log(session_id, "shell_error", result)
        yield {"type": "tool_result", "result": result}
        return

    async def read_stream(name: str, stream: asyncio.StreamReader | None) -> None:
        """Accumulate redacted text from *stream*, keeping at most max_chars of it."""
        nonlocal stdout_chars, stderr_chars
        if stream is None:
            return
        # An incremental decoder keeps a multi-byte UTF-8 sequence that straddles
        # two reads intact, instead of turning each half into U+FFFD.
        decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
        while True:
            chunk = await stream.read(4096)
            text = decoder.decode(chunk, final=not chunk)
            if text:
                text = redact_text(text)
                if name == "stdout":
                    if stdout_chars < max_chars:
                        kept = text[: max_chars - stdout_chars]
                        stdout_parts.append(kept)
                        stdout_chars += len(kept)
                elif stderr_chars < max_chars:
                    kept = text[: max_chars - stderr_chars]
                    stderr_parts.append(kept)
                    stderr_chars += len(kept)
            if not chunk:
                break

    emitted = {"stdout": 0, "stderr": 0}

    def pending_deltas() -> list[dict[str, Any]]:
        """Return delta events for output captured since the previous call."""
        events: list[dict[str, Any]] = []
        for stream_name, parts in (("stdout", stdout_parts), ("stderr", stderr_parts)):
            text = "".join(parts)
            if len(text) > emitted[stream_name]:
                events.append(
                    {"type": "tool_output_delta", "stream": stream_name, "text": text[emitted[stream_name]:]}
                )
                emitted[stream_name] = len(text)
        return events

    # The reader tasks accumulate output for the model-visible result. The polling
    # loop below forwards newly captured output to the browser while it runs.
    stdout_task = asyncio.create_task(read_stream("stdout", proc.stdout))
    stderr_task = asyncio.create_task(read_stream("stderr", proc.stderr))
    readers = {stdout_task, stderr_task}
    wait_task = asyncio.create_task(proc.wait())
    watched = {wait_task, *readers}
    try:
        # Keep polling until the shell has exited AND both pipes hit EOF. A
        # background job can hold the pipes open after the shell exits; the
        # timeout bounds that drain too, instead of blocking forever.
        while True:
            await asyncio.wait(watched, timeout=0.25)
            for event in pending_deltas():
                yield event
            if all(task.done() for task in watched):
                break
            if time.monotonic() - start > timeout:
                timed_out = True
                await _terminate_process_group(proc, watched)
                break
        try:
            # Collect reader results (and exceptions). This is bounded in case a
            # process outside the group still holds a pipe after the kill.
            await asyncio.wait_for(asyncio.gather(*readers, return_exceptions=True), timeout=5)
        except asyncio.TimeoutError:
            pass  # wait_for cancelled the readers; keep what they captured
        for event in pending_deltas():
            yield event
        exit_code = proc.returncode
        cancelled = session_id in COMMAND_CANCEL_REQUESTS
        if cancelled and (stderr_chars < max_chars):
            cancel_msg = "\n[command stopped by owner]"
            stderr_parts.append(cancel_msg)
        if timed_out and (stderr_chars < max_chars):
            timeout_msg = f"\n[command timed out after {timeout} seconds]"
            stderr_parts.append(timeout_msg)
    finally:
        # Only unregister our own process; a newer command may have replaced it.
        if RUNNING_PROCESSES.get(session_id) is proc:
            del RUNNING_PROCESSES[session_id]
        COMMAND_CANCEL_REQUESTS.discard(session_id)
        if proc.returncode is None or not all(task.done() for task in readers):
            # The consumer stopped iterating (or cancelled us) mid-run, or a
            # drain was cut short. Do not leave the command running unsupervised
            # and unregistered.
            _signal_process_group(proc, signal.SIGKILL)
        for task in watched:
            if not task.done():
                task.cancel()
    duration = time.monotonic() - start
    stdout_final, stdout_truncated = truncate_preserving_tail("".join(stdout_parts), max_chars)
    stderr_final, stderr_truncated = truncate_preserving_tail("".join(stderr_parts), max_chars)
    result = {
        "command": command,
        "stdout": stdout_final,
        "stderr": stderr_final,
        "exit_code": exit_code if exit_code is not None else -1,
        "duration_seconds": round(duration, 3),
        "timed_out": timed_out,
        "cancelled": cancelled,
        "stdout_truncated": stdout_truncated or stdout_chars >= max_chars,
        "stderr_truncated": stderr_truncated or stderr_chars >= max_chars,
        "session_id": session_id,
        "session_name": metadata.get("name", session_id),
        "working_directory": "/",
    }
    sessions.log(session_id, "shell_result", result)
    yield {"type": "tool_result", "result": result}