| |
| |
|
|
| import contextlib |
| import multiprocessing |
| import queue |
| import sys |
| import traceback |
| from collections.abc import Callable, Sequence |
| from dataclasses import dataclass |
| from io import StringIO |
| from typing import Any |
|
|
# Sentinel return codes that run_in_spawned_child_process reports when the
# child never delivered a result of its own.
PROCESS_KILLED: int = -9  # child outlived its timeout and was terminated
PROCESS_NO_RESULT: int = -999  # child exited or crashed without posting a result
|
|
|
|
| |
| |
@dataclass
class CompletedProcess:
    """Outcome of running a target function in a child process.

    Field order is part of the positional constructor signature:
    ``CompletedProcess(returncode, stdout, stderr)``.
    """

    # Exit status of the child, or a sentinel such as PROCESS_KILLED /
    # PROCESS_NO_RESULT when the parent had to synthesize the result.
    returncode: int
    # Text captured from the child's sys.stdout.
    stdout: str
    # Text captured from the child's sys.stderr, or a diagnostic message
    # produced by the parent when no result was received.
    stderr: str
|
|
|
|
| class ChildProcessWrapper: |
| def __init__( |
| self, |
| result_queue: Any, |
| target: Callable[..., None], |
| args: Sequence[Any] | None, |
| kwargs: dict[str, Any] | None, |
| ) -> None: |
| self.target = target |
| self.args = () if args is None else args |
| self.kwargs = {} if kwargs is None else kwargs |
| self.result_queue = result_queue |
|
|
| def __call__(self) -> None: |
| |
| old_stdout = sys.stdout |
| old_stderr = sys.stderr |
| sys.stdout = StringIO() |
| sys.stderr = StringIO() |
|
|
| try: |
| self.target(*self.args, **self.kwargs) |
| returncode = 0 |
| except SystemExit as e: |
| returncode = e.code if isinstance(e.code, int) else 0 |
| except BaseException: |
| traceback.print_exc() |
| returncode = 1 |
| finally: |
| |
| stdout = sys.stdout.getvalue() |
| stderr = sys.stderr.getvalue() |
| sys.stdout = old_stdout |
| sys.stderr = old_stderr |
| with contextlib.suppress(Exception): |
| self.result_queue.put((returncode, stdout, stderr)) |
|
|
|
|
def run_in_spawned_child_process(
    target: Callable[..., None],
    *,
    args: Sequence[Any] | None = None,
    kwargs: dict[str, Any] | None = None,
    timeout: float | None = None,
    rethrow: bool = False,
) -> CompletedProcess:
    """Run `target` in a spawned child process, capturing stdout/stderr.

    The provided `target` must be defined at the top level of a module, and must
    be importable in the spawned child process. Lambdas, closures, or interactively
    defined functions (e.g., in Jupyter notebooks) will not work.

    Args:
        target: Function to call in the child.
        args: Positional arguments for `target`; they are sent to the spawned
            child, so they must be picklable.
        kwargs: Keyword arguments for `target`; same requirement as `args`.
        timeout: Seconds to wait for the child to finish. `None` waits
            indefinitely. A child still alive afterwards is terminated.
        rethrow: If true, raise instead of returning a nonzero result.

    Returns:
        A `CompletedProcess`. Its `returncode` is one of:

        - the code posted by the child;
        - `PROCESS_KILLED` if the child timed out;
        - `PROCESS_NO_RESULT` if the child exited without posting a result.

    Raises:
        ChildProcessError: If `rethrow=True` and the return code is nonzero.
            The message includes the captured stderr.
    """
    ctx = multiprocessing.get_context("spawn")
    result_queue = ctx.Queue()
    process = ctx.Process(target=ChildProcessWrapper(result_queue, target, args, kwargs))
    process.start()

    try:
        finished, payload = _wait_for_child(process, result_queue, timeout)
        if not finished:
            process.terminate()
            process.join()
            result = CompletedProcess(
                returncode=PROCESS_KILLED,
                stdout="",
                stderr=f"Process timed out after {timeout} seconds and was terminated.",
            )
        elif payload is None:
            result = CompletedProcess(
                returncode=PROCESS_NO_RESULT,
                stdout="",
                stderr="Process exited or crashed before returning results.",
            )
        else:
            returncode, stdout, stderr = payload
            result = CompletedProcess(
                returncode=returncode,
                stdout=stdout,
                stderr=stderr,
            )

        if rethrow and result.returncode != 0:
            raise ChildProcessError(
                f"Child process exited with code {result.returncode}.\n"
                "--- stderr-from-child-process ---\n"
                f"{result.stderr}"
                "<end-of-stderr-from-child-process>\n"
            )

        return result

    finally:
        # Deliberately best-effort cleanup: it must not mask the result or an
        # exception raised above.
        with contextlib.suppress(Exception):
            result_queue.close()
            result_queue.join_thread()
        if process.is_alive():
            process.kill()
            process.join()


# How often the parent re-checks the child while waiting for its result.
_RESULT_POLL_INTERVAL = 0.1
# How long to wait for a result still in transit after the child has exited.
_RESULT_GRACE_PERIOD = 1.0


def _wait_for_child(process: Any, result_queue: Any, timeout: float | None) -> tuple[bool, Any]:
    """Wait for `process` to exit, draining its result from `result_queue` meanwhile.

    The queue must be read *before* joining. A child that has put an item on
    a multiprocessing queue does not terminate until that item has been
    flushed into the underlying pipe. Joining first therefore deadlocks once
    the captured output exceeds the OS pipe buffer.

    Returns:
        A tuple ``(finished, payload)``:

        - ``finished`` is False if the child was still alive when `timeout`
          elapsed.
        - ``payload`` is the posted ``(returncode, stdout, stderr)`` tuple,
          or None if nothing arrived.
    """
    import time

    deadline = None if timeout is None else time.monotonic() + timeout

    def remaining() -> float | None:
        return None if deadline is None else max(0.0, deadline - time.monotonic())

    payload = None
    while payload is None:
        left = remaining()
        if left is not None and left <= 0:
            break
        wait = _RESULT_POLL_INTERVAL if left is None else min(_RESULT_POLL_INTERVAL, left)
        try:
            payload = result_queue.get(timeout=wait)
        except queue.Empty:
            if not process.is_alive():
                break
        except EOFError:
            break

    if payload is None and not process.is_alive():
        # The child exited before we saw its result; it may still be in transit.
        with contextlib.suppress(queue.Empty, EOFError):
            payload = result_queue.get(timeout=_RESULT_GRACE_PERIOD)

    process.join(remaining())
    return not process.is_alive(), payload
|
|