from __future__ import annotations import asyncio import os import shutil import signal import time from pathlib import Path from typing import Any, AsyncIterator from .config import settings from .security import redact_text from .sessions import SessionPaths, sessions RUNNING_PROCESSES: dict[str, asyncio.subprocess.Process] = {} COMMAND_CANCEL_REQUESTS: set[str] = set() def truncate_preserving_tail(text: str, limit: int) -> tuple[str, bool]: if len(text) <= limit: return text, False head = max(0, limit // 3) tail = max(0, limit - head - 120) return text[:head] + f"\n\n[... output truncated to {limit} characters ...]\n\n" + text[-tail:], True def _existing_bind_args() -> list[str]: binds: list[str] = [] candidates = [ "/bin", "/usr", "/lib", "/lib64", "/opt", "/etc/resolv.conf", "/etc/hosts", "/etc/passwd", "/etc/group", "/etc/ssl", "/dev/null", "/dev/urandom", "/dev/random", "/proc", "/tmp", ] for candidate in candidates: if Path(candidate).exists(): binds.extend(["-b", candidate]) return binds def build_shell_args(paths: SessionPaths, command: str) -> tuple[list[str], str]: proot = shutil.which("proot") if proot: args = [ proot, "-r", str(paths.root), "-w", "/", *_existing_bind_args(), "/bin/bash", "-lc", command, ] return args, "/" return ["/bin/bash", "-lc", command], str(paths.root) async def stop_running_process(session_id: str) -> bool: proc = RUNNING_PROCESSES.get(session_id) if not proc or proc.returncode is not None: return False COMMAND_CANCEL_REQUESTS.add(session_id) try: if proc.pid: os.killpg(proc.pid, signal.SIGTERM) except ProcessLookupError: return False except Exception: proc.terminate() try: await asyncio.wait_for(proc.wait(), timeout=5) except asyncio.TimeoutError: try: if proc.pid: os.killpg(proc.pid, signal.SIGKILL) except Exception: proc.kill() return True async def shell_exec_stream( session_id: str, command: str, timeout_seconds: int | None = None, ) -> AsyncIterator[dict[str, Any]]: paths = sessions.ensure(session_id) metadata = sessions.get_metadata(session_id) timeout = timeout_seconds or settings.max_tool_runtime_seconds timeout = max(1, min(int(timeout), settings.max_tool_runtime_seconds)) args, cwd = build_shell_args(paths, command) env = os.environ.copy() env.update( { "SESSION_ID": session_id, "SESSION_NAME": str(metadata.get("name", session_id)), "SESSION_ROOT": "/", "UPLOADS_DIR": "/uploads", "OUTPUTS_DIR": "/outputs", "WORKDIR": "/work", "HOME": "/work", "PYTHONUNBUFFERED": "1", } ) start = time.monotonic() sessions.log(session_id, "shell_start", {"command": command, "timeout_seconds": timeout, "working_directory": "/"}) yield { "type": "tool_start", "command": command, "timeout_seconds": timeout, "session_id": session_id, "session_name": metadata.get("name", session_id), "working_directory": "/", } stdout_parts: list[str] = [] stderr_parts: list[str] = [] stdout_chars = 0 stderr_chars = 0 max_chars = settings.max_output_chars timed_out = False cancelled = False exit_code: int | None = None try: proc = await asyncio.create_subprocess_exec( *args, cwd=cwd, env=env, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, start_new_session=True, ) RUNNING_PROCESSES[session_id] = proc except Exception as exc: duration = time.monotonic() - start error = redact_text(str(exc)) result = { "command": command, "stdout": "", "stderr": error, "exit_code": -1, "duration_seconds": round(duration, 3), "timed_out": False, "session_id": session_id, "session_name": metadata.get("name", session_id), "working_directory": "/", } sessions.log(session_id, "shell_error", result) yield {"type": "tool_result", "result": result} return async def read_stream(name: str, stream: asyncio.StreamReader | None) -> None: nonlocal stdout_chars, stderr_chars if stream is None: return while True: chunk = await stream.read(4096) if not chunk: break text = redact_text(chunk.decode("utf-8", errors="replace")) if name == "stdout": if stdout_chars < max_chars: remaining = max_chars - stdout_chars kept = text[:remaining] stdout_parts.append(kept) stdout_chars += len(kept) else: if stderr_chars < max_chars: remaining = max_chars - stderr_chars kept = text[:remaining] stderr_parts.append(kept) stderr_chars += len(kept) # The reader tasks accumulate output for the model-visible result. A lightweight # poller below emits new output to the browser while the command runs. stdout_task = asyncio.create_task(read_stream("stdout", proc.stdout)) stderr_task = asyncio.create_task(read_stream("stderr", proc.stderr)) last_stdout_len = 0 last_stderr_len = 0 try: wait_task = asyncio.create_task(proc.wait()) while True: done, _ = await asyncio.wait({wait_task}, timeout=0.25) stdout_text = "".join(stdout_parts) stderr_text = "".join(stderr_parts) if len(stdout_text) > last_stdout_len: delta = stdout_text[last_stdout_len:] last_stdout_len = len(stdout_text) yield {"type": "tool_output_delta", "stream": "stdout", "text": delta} if len(stderr_text) > last_stderr_len: delta = stderr_text[last_stderr_len:] last_stderr_len = len(stderr_text) yield {"type": "tool_output_delta", "stream": "stderr", "text": delta} if wait_task in done: break if time.monotonic() - start > timeout: timed_out = True try: if proc.pid: os.killpg(proc.pid, signal.SIGTERM) except Exception: proc.terminate() try: await asyncio.wait_for(proc.wait(), timeout=5) except asyncio.TimeoutError: try: if proc.pid: os.killpg(proc.pid, signal.SIGKILL) except Exception: proc.kill() break await asyncio.gather(stdout_task, stderr_task, return_exceptions=True) stdout_text = "".join(stdout_parts) stderr_text = "".join(stderr_parts) if len(stdout_text) > last_stdout_len: yield {"type": "tool_output_delta", "stream": "stdout", "text": stdout_text[last_stdout_len:]} if len(stderr_text) > last_stderr_len: yield {"type": "tool_output_delta", "stream": "stderr", "text": stderr_text[last_stderr_len:]} exit_code = proc.returncode cancelled = session_id in COMMAND_CANCEL_REQUESTS if cancelled and (stderr_chars < max_chars): cancel_msg = "\n[command stopped by owner]" stderr_parts.append(cancel_msg) if timed_out and (stderr_chars < max_chars): timeout_msg = f"\n[command timed out after {timeout} seconds]" stderr_parts.append(timeout_msg) finally: RUNNING_PROCESSES.pop(session_id, None) COMMAND_CANCEL_REQUESTS.discard(session_id) for task in (stdout_task, stderr_task): if not task.done(): task.cancel() duration = time.monotonic() - start stdout_final, stdout_truncated = truncate_preserving_tail("".join(stdout_parts), max_chars) stderr_final, stderr_truncated = truncate_preserving_tail("".join(stderr_parts), max_chars) result = { "command": command, "stdout": stdout_final, "stderr": stderr_final, "exit_code": exit_code if exit_code is not None else -1, "duration_seconds": round(duration, 3), "timed_out": timed_out, "cancelled": cancelled, "stdout_truncated": stdout_truncated or stdout_chars >= max_chars, "stderr_truncated": stderr_truncated or stderr_chars >= max_chars, "session_id": session_id, "session_name": metadata.get("name", session_id), "working_directory": "/", } sessions.log(session_id, "shell_result", result) yield {"type": "tool_result", "result": result}