| from __future__ import annotations |
|
|
import asyncio
import codecs
import os
import shutil
import signal
import time
from collections import deque
from pathlib import Path
from typing import Any, AsyncIterator
|
|
| from .config import settings |
| from .security import redact_text |
| from .sessions import SessionPaths, sessions |
|
|
# Live subprocess per session id. shell_exec_stream registers it while the
# command runs, and stop_running_process looks it up to signal it.
RUNNING_PROCESSES: dict[str, asyncio.subprocess.Process] = {}
# Session ids whose running command was asked to stop via stop_running_process.
# shell_exec_stream reads this to mark its result as cancelled, and clears the
# entry when the command finishes.
COMMAND_CANCEL_REQUESTS: set[str] = set()
|
|
|
|
def truncate_preserving_tail(text: str, limit: int) -> tuple[str, bool]:
    """Shorten *text* to about *limit* characters, keeping both its start and end.

    Returns ``(text, False)`` unchanged when ``len(text) <= limit``. Otherwise it
    returns ``(shortened, True)``. ``shortened`` is built from three parts:

    * the first ``limit // 3`` characters;
    * a marker line naming *limit*;
    * as many trailing characters as fit after reserving 120 characters for the
      marker.

    For very small limits the tail budget is zero, so only the head and the
    marker are kept. In that case the result can exceed *limit* by up to the
    marker's length.
    """
    if len(text) <= limit:
        return text, False
    head = max(0, limit // 3)
    # 120 characters are reserved for the marker line inserted between the parts.
    tail = max(0, limit - head - 120)
    marker = f"\n\n[... output truncated to {limit} characters ...]\n\n"
    # Slice with an explicit start index. ``text[-tail:]`` with ``tail == 0`` is
    # ``text[0:]``, the WHOLE string, which made the result longer than the input.
    return text[:head] + marker + text[len(text) - tail:], True
|
|
|
|
| def _existing_bind_args() -> list[str]: |
| binds: list[str] = [] |
| candidates = [ |
| "/bin", |
| "/usr", |
| "/lib", |
| "/lib64", |
| "/opt", |
| "/etc/resolv.conf", |
| "/etc/hosts", |
| "/etc/passwd", |
| "/etc/group", |
| "/etc/ssl", |
| "/dev/null", |
| "/dev/urandom", |
| "/dev/random", |
| "/proc", |
| "/tmp", |
| ] |
| for candidate in candidates: |
| if Path(candidate).exists(): |
| binds.extend(["-b", candidate]) |
| return binds |
|
|
|
|
def build_shell_args(paths: SessionPaths, command: str) -> tuple[list[str], str]:
    """Return ``(argv, cwd)`` for running *command* through ``/bin/bash -lc``.

    If ``proot`` is found on PATH:

    * the shell is wrapped in proot with ``paths.root`` as the guest root
      (``-r``), starting in ``/`` (``-w``);
    * the host paths from ``_existing_bind_args()`` are bound in;
    * the returned cwd is ``"/"``.

    Otherwise bash is launched directly, with ``str(paths.root)`` as its cwd.
    """
    shell_argv = ["/bin/bash", "-lc", command]
    proot = shutil.which("proot")
    if not proot:
        # NOTE(review): this fallback has no filesystem jail. The command runs on
        # the host with the session root only as its working directory.
        return shell_argv, str(paths.root)
    jail_argv = [proot, "-r", str(paths.root), "-w", "/", *_existing_bind_args()]
    return jail_argv + shell_argv, "/"
|
|
|
|
async def stop_running_process(session_id: str) -> bool:
    """Stop the shell command currently running for *session_id*.

    The stop proceeds in three steps:

    1. Record a cancel request in ``COMMAND_CANCEL_REQUESTS``, which
       shell_exec_stream reads to flag its result as cancelled.
    2. Send SIGTERM to the command's process group.
    3. Wait up to 5 seconds for the process to exit, then send SIGKILL.

    Returns False when no live process is registered for the session, or when it
    exited before it could be signalled. Returns True once a stop signal was sent.
    """
    proc = RUNNING_PROCESSES.get(session_id)
    if not proc or proc.returncode is not None:
        return False
    COMMAND_CANCEL_REQUESTS.add(session_id)
    try:
        if proc.pid:
            os.killpg(proc.pid, signal.SIGTERM)
    except ProcessLookupError:
        # The command finished between the returncode check and the signal.
        # Nothing was stopped, so withdraw the request; otherwise the naturally
        # finished command would be reported as "stopped by owner".
        COMMAND_CANCEL_REQUESTS.discard(session_id)
        return False
    except Exception:
        try:
            proc.terminate()
        except ProcessLookupError:
            COMMAND_CANCEL_REQUESTS.discard(session_id)
            return False
    try:
        await asyncio.wait_for(proc.wait(), timeout=5)
    except asyncio.TimeoutError:
        try:
            if proc.pid:
                os.killpg(proc.pid, signal.SIGKILL)
        except ProcessLookupError:
            pass  # The group exited right after the grace period ran out.
        except Exception:
            try:
                proc.kill()
            except ProcessLookupError:
                pass
    return True
|
|
|
|
# How often the streaming loop flushes new output and checks the deadline.
_POLL_INTERVAL_SECONDS = 0.25
# Grace period between SIGTERM and SIGKILL, and for pipes to close after a kill.
_TERMINATE_GRACE_SECONDS = 5
# Size of each read from the child's stdout/stderr pipes.
_READ_CHUNK_BYTES = 4096


def _signal_process_group(proc: asyncio.subprocess.Process, sig: int) -> None:
    """Send *sig* to the process group led by *proc*, falling back to *proc* alone.

    shell_exec_stream starts the child with ``start_new_session=True``. The
    child's pid is therefore also its process-group id, and signalling the group
    reaches the processes the command spawned. A group that has already vanished
    is ignored.
    """
    try:
        os.killpg(proc.pid, sig)
        return
    except ProcessLookupError:
        return  # Nothing left in the group.
    except OSError:
        pass  # e.g. EPERM: fall back to signalling the direct child below.
    if proc.returncode is None:
        try:
            proc.send_signal(sig)
        except ProcessLookupError:
            pass


async def _terminate_process_group(proc: asyncio.subprocess.Process) -> None:
    """SIGTERM *proc*'s group, then SIGKILL it if the leader outlives the grace period."""
    _signal_process_group(proc, signal.SIGTERM)
    try:
        await asyncio.wait_for(proc.wait(), timeout=_TERMINATE_GRACE_SECONDS)
    except asyncio.TimeoutError:
        _signal_process_group(proc, signal.SIGKILL)


class _CapturedStream:
    """Bounded capture of one child pipe, decoded as UTF-8 and passed through redact_text.

    The capture keeps two pieces of output:

    * the first ``limit`` characters, which are what gets streamed live as deltas;
    * a rolling window of at least the last ``limit`` characters.

    Together they let the final result show both ends of an oversized output.
    """

    def __init__(self, limit: int) -> None:
        self.limit = limit
        self.total_chars = 0
        self._head: list[str] = []
        self._head_chars = 0
        self._streamed_parts = 0  # Number of head parts already emitted as deltas.
        self._tail: deque[str] = deque()
        self._tail_chars = 0

    def add(self, text: str) -> None:
        """Append already-redacted *text* to the capture."""
        self.total_chars += len(text)
        room = self.limit - self._head_chars
        if room > 0 and text:
            kept = text[:room]
            self._head.append(kept)
            self._head_chars += len(kept)
        self._tail.append(text)
        self._tail_chars += len(text)
        # Drop whole chunks from the left only while at least ``limit`` chars remain.
        while self._tail and self._tail_chars - len(self._tail[0]) >= self.limit:
            self._tail_chars -= len(self._tail.popleft())

    def take_delta(self) -> str:
        """Return head text captured since the previous call (possibly empty)."""
        delta = "".join(self._head[self._streamed_parts:])
        self._streamed_parts = len(self._head)
        return delta

    def final_text(self) -> str:
        """Return text suitable for ``truncate_preserving_tail(..., limit)``.

        When the output fits, this is the full output. Otherwise it is the head
        (exactly the first ``limit`` chars) followed by the tail window (at least
        the last ``limit`` chars). Both are at least as long as the head/tail
        budgets truncate_preserving_tail uses, so truncating this text gives the
        same result as truncating the full output would.
        """
        head = "".join(self._head)
        if self.total_chars <= self.limit:
            return head
        return head + "".join(self._tail)

    async def consume(self, stream: asyncio.StreamReader | None) -> None:
        """Read *stream* until EOF, decoding incrementally and redacting each piece."""
        if stream is None:
            return
        # An incremental decoder keeps a multi-byte UTF-8 sequence that straddles
        # two reads intact, instead of turning each half into U+FFFD.
        decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
        while True:
            chunk = await stream.read(_READ_CHUNK_BYTES)
            text = decoder.decode(chunk, final=not chunk)
            if text:
                # NOTE(review): redaction runs per decoded piece, so a secret split
                # across two reads may not be matched by redact_text.
                self.add(redact_text(text))
            if not chunk:
                return


def _output_delta_events(
    out_capture: _CapturedStream, err_capture: _CapturedStream
) -> list[dict[str, Any]]:
    """Build ``tool_output_delta`` events (stdout first) for output not yet streamed."""
    events: list[dict[str, Any]] = []
    for stream_name, capture in (("stdout", out_capture), ("stderr", err_capture)):
        text = capture.take_delta()
        if text:
            events.append({"type": "tool_output_delta", "stream": stream_name, "text": text})
    return events


async def shell_exec_stream(
    session_id: str,
    command: str,
    timeout_seconds: int | None = None,
) -> AsyncIterator[dict[str, Any]]:
    """Run *command* in the session's shell and stream progress events.

    Events are yielded in this order:

    * one ``{"type": "tool_start", ...}`` event;
    * zero or more ``{"type": "tool_output_delta", "stream": "stdout"|"stderr",
      "text": ...}`` events with output (passed through redact_text) as it
      arrives. Only the first ``settings.max_output_chars`` characters of each
      stream are streamed;
    * one ``{"type": "tool_result", "result": {...}}`` event. It carries the final
      output (head and tail kept via truncate_preserving_tail when oversized), the
      exit code (-1 if unknown or the spawn failed), the duration, and the
      ``timed_out`` / ``cancelled`` flags.

    *timeout_seconds* defaults to ``settings.max_tool_runtime_seconds`` and is
    clamped to ``[1, settings.max_tool_runtime_seconds]``. It bounds the whole run,
    including waiting for the output pipes to close. On timeout the command's
    process group gets SIGTERM, then SIGKILL after a grace period.

    While the command runs, it is registered in ``RUNNING_PROCESSES`` so that
    stop_running_process can stop it. If the consumer stops iterating early, a
    still-running command is killed.
    """
    paths = sessions.ensure(session_id)
    metadata = sessions.get_metadata(session_id)
    session_name = metadata.get("name", session_id)
    timeout = timeout_seconds or settings.max_tool_runtime_seconds
    timeout = max(1, min(int(timeout), settings.max_tool_runtime_seconds))
    args, cwd = build_shell_args(paths, command)
    # NOTE(review): the child inherits this server's entire environment (including
    # any credentials in it); only the variables below are overridden.
    env = os.environ.copy()
    env.update(
        {
            "SESSION_ID": session_id,
            "SESSION_NAME": str(session_name),
            "SESSION_ROOT": "/",
            "UPLOADS_DIR": "/uploads",
            "OUTPUTS_DIR": "/outputs",
            "WORKDIR": "/work",
            "HOME": "/work",
            "PYTHONUNBUFFERED": "1",
        }
    )

    start = time.monotonic()
    # NOTE(review): the command text is logged and echoed without redact_text,
    # unlike its output. Confirm that is acceptable.
    sessions.log(session_id, "shell_start", {"command": command, "timeout_seconds": timeout, "working_directory": "/"})
    yield {
        "type": "tool_start",
        "command": command,
        "timeout_seconds": timeout,
        "session_id": session_id,
        "session_name": session_name,
        "working_directory": "/",
    }

    max_chars = settings.max_output_chars
    out_capture = _CapturedStream(max_chars)
    err_capture = _CapturedStream(max_chars)
    timed_out = False
    cancelled = False
    exit_code: int | None = None

    try:
        proc = await asyncio.create_subprocess_exec(
            *args,
            cwd=cwd,
            env=env,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
            # New session => new process group, so the whole tree can be signalled.
            start_new_session=True,
        )
    except Exception as exc:
        duration = time.monotonic() - start
        error = redact_text(str(exc))
        result = {
            "command": command,
            "stdout": "",
            "stderr": error,
            "exit_code": -1,
            "duration_seconds": round(duration, 3),
            "timed_out": False,
            "session_id": session_id,
            "session_name": session_name,
            "working_directory": "/",
        }
        sessions.log(session_id, "shell_error", result)
        yield {"type": "tool_result", "result": result}
        return
    RUNNING_PROCESSES[session_id] = proc

    stdout_task = asyncio.create_task(out_capture.consume(proc.stdout))
    stderr_task = asyncio.create_task(err_capture.consume(proc.stderr))
    wait_task = asyncio.create_task(proc.wait())
    readers = (stdout_task, stderr_task)
    tasks = (wait_task, *readers)
    deadline = start + timeout

    try:
        while True:
            await asyncio.wait(tasks, timeout=_POLL_INTERVAL_SECONDS)
            for event in _output_delta_events(out_capture, err_capture):
                yield event
            # Finish only once the shell has exited AND both pipes reached EOF.
            # A backgrounded child can keep the pipes open after bash exits, so
            # the deadline keeps applying while we wait for EOF.
            if all(task.done() for task in tasks):
                break
            if time.monotonic() > deadline:
                timed_out = True
                await _terminate_process_group(proc)
                break
        # Give the readers a bounded window to hit EOF after a kill. If a pipe is
        # still held open, SIGKILL the group and stop reading. NOTE: this assumes
        # the holder is a member of the command's process group.
        _, lingering = await asyncio.wait(readers, timeout=_TERMINATE_GRACE_SECONDS)
        if lingering:
            _signal_process_group(proc, signal.SIGKILL)
            for task in lingering:
                task.cancel()
        # Collect reader outcomes. A failed reader just ends that stream's capture early.
        await asyncio.gather(*readers, return_exceptions=True)
        for event in _output_delta_events(out_capture, err_capture):
            yield event
        exit_code = proc.returncode
        cancelled = session_id in COMMAND_CANCEL_REQUESTS
        # Notices go through the capture so the tail window keeps them even when
        # stderr already exceeded the limit.
        if cancelled:
            err_capture.add("\n[command stopped by owner]")
        if timed_out:
            err_capture.add(f"\n[command timed out after {timeout} seconds]")
    finally:
        # Only unregister our own process; a newer run may have replaced the entry.
        if RUNNING_PROCESSES.get(session_id) is proc:
            RUNNING_PROCESSES.pop(session_id, None)
        COMMAND_CANCEL_REQUESTS.discard(session_id)
        # Consumer went away (or an error escaped) while the command still runs:
        # once unregistered nobody else can stop it, so kill it here.
        if proc.returncode is None:
            _signal_process_group(proc, signal.SIGKILL)
        for task in tasks:
            if not task.done():
                task.cancel()

    duration = time.monotonic() - start
    stdout_final, stdout_truncated = truncate_preserving_tail(out_capture.final_text(), max_chars)
    stderr_final, stderr_truncated = truncate_preserving_tail(err_capture.final_text(), max_chars)
    result = {
        "command": command,
        "stdout": stdout_final,
        "stderr": stderr_final,
        "exit_code": exit_code if exit_code is not None else -1,
        "duration_seconds": round(duration, 3),
        "timed_out": timed_out,
        "cancelled": cancelled,
        "stdout_truncated": stdout_truncated or out_capture.total_chars > max_chars,
        "stderr_truncated": stderr_truncated or err_capture.total_chars > max_chars,
        "session_id": session_id,
        "session_name": session_name,
        "working_directory": "/",
    }
    sessions.log(session_id, "shell_result", result)
    yield {"type": "tool_result", "result": result}
|
|