# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import contextlib
import multiprocessing
import queue  # for Empty
import sys
import time
import traceback
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from io import StringIO
from typing import Any
# Return code reported when the child was still running after the timeout and
# had to be terminated by the parent.
# NOTE(review): -9 presumably mirrors subprocess's negative-signal convention
# (SIGKILL) -- confirm.
PROCESS_KILLED = -9
# Return code reported when the child exited without delivering a result
# through the result queue.
PROCESS_NO_RESULT = -999
# Modeled on subprocess.CompletedProcess, see
# https://docs.python.org/3/library/subprocess.html#subprocess.CompletedProcess
# (args, check_returncode() are intentionally not supported here.)
@dataclass
class CompletedProcess:
    """Outcome of one child-process run: a return code plus captured text output."""

    # Return code of the run.
    returncode: int
    # Text captured from the child's standard output.
    stdout: str
    # Text captured from the child's standard error.
    stderr: str
class ChildProcessWrapper:
def __init__(
self,
result_queue: Any,
target: Callable[..., None],
args: Sequence[Any] | None,
kwargs: dict[str, Any] | None,
) -> None:
self.target = target
self.args = () if args is None else args
self.kwargs = {} if kwargs is None else kwargs
self.result_queue = result_queue
def __call__(self) -> None:
# Capture stdout/stderr
old_stdout = sys.stdout
old_stderr = sys.stderr
sys.stdout = StringIO()
sys.stderr = StringIO()
try:
self.target(*self.args, **self.kwargs)
returncode = 0
except SystemExit as e: # Handle sys.exit()
returncode = e.code if isinstance(e.code, int) else 0
except BaseException:
traceback.print_exc()
returncode = 1
finally:
# Collect outputs and restore streams
stdout = sys.stdout.getvalue()
stderr = sys.stderr.getvalue()
sys.stdout = old_stdout
sys.stderr = old_stderr
with contextlib.suppress(Exception):
self.result_queue.put((returncode, stdout, stderr))
def _poll_for_child_result(
    process: Any,
    result_queue: Any,
    deadline: float | None,
) -> tuple[int, str, str] | None:
    """Wait for the child's result tuple while the child is alive and the deadline allows."""
    poll_seconds = 0.05  # only bounds how quickly a dead child is noticed
    while process.is_alive():
        wait = poll_seconds
        if deadline is not None:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                return None
            wait = min(wait, remaining)
        try:
            return result_queue.get(timeout=wait)
        except queue.Empty:
            continue
        except EOFError:
            return None
    return None


def run_in_spawned_child_process(
    target: Callable[..., None],
    *,
    args: Sequence[Any] | None = None,
    kwargs: dict[str, Any] | None = None,
    timeout: float | None = None,
    rethrow: bool = False,
) -> CompletedProcess:
    """Run `target` in a spawned child process, capturing stdout/stderr.

    The provided `target` must be defined at the top level of a module, and must
    be importable in the spawned child process. Lambdas, closures, or interactively
    defined functions (e.g., in Jupyter notebooks) will not work.

    Args:
        target: Callable executed in the child process.
        args: Positional arguments for `target` (default: none).
        kwargs: Keyword arguments for `target` (default: none).
        timeout: Maximum number of seconds the child may run; `None` waits
            indefinitely. A child still alive after the timeout is terminated.
        rethrow: If true, a nonzero return code raises instead of being returned.

    Returns:
        A `CompletedProcess` with the child's return code and captured output.
        The return code is `PROCESS_KILLED` if the child timed out and
        `PROCESS_NO_RESULT` if it exited without delivering a result.

    Raises:
        ChildProcessError: If `rethrow=True` and the child process exits with a
            nonzero code; the message contains the captured stderr.
    """
    ctx = multiprocessing.get_context("spawn")
    result_queue = ctx.Queue()
    process = ctx.Process(target=ChildProcessWrapper(result_queue, target, args, kwargs))
    process.start()
    deadline = None if timeout is None else time.monotonic() + timeout

    try:
        # Drain the queue BEFORE joining: a child that has put a large payload on
        # the queue cannot exit until the parent reads it, so joining first can
        # deadlock (or be misreported as a timeout).
        payload = _poll_for_child_result(process, result_queue, deadline)

        # The timeout bounds the child's total lifetime, not just result delivery.
        remaining = None if deadline is None else max(0.0, deadline - time.monotonic())
        process.join(remaining)

        if process.is_alive():
            process.terminate()
            process.join()
            result = CompletedProcess(
                returncode=PROCESS_KILLED,
                stdout="",
                stderr=f"Process timed out after {timeout} seconds and was terminated.",
            )
        else:
            if payload is None:
                # The child may have posted its result just before exiting.
                try:
                    payload = result_queue.get(timeout=1.0)
                except (queue.Empty, EOFError):
                    payload = None
            if payload is None:
                result = CompletedProcess(
                    returncode=PROCESS_NO_RESULT,
                    stdout="",
                    stderr="Process exited or crashed before returning results.",
                )
            else:
                returncode, stdout, stderr = payload
                result = CompletedProcess(
                    returncode=returncode,
                    stdout=stdout,
                    stderr=stderr,
                )

        if rethrow and result.returncode != 0:
            raise ChildProcessError(
                f"Child process exited with code {result.returncode}.\n"
                "--- stderr-from-child-process ---\n"
                f"{result.stderr}"
                "<end-of-stderr-from-child-process>\n"
            )

        return result

    finally:
        # Best-effort cleanup: queue teardown problems must not mask the result
        # or a ChildProcessError raised above.
        with contextlib.suppress(Exception):
            result_queue.close()
            result_queue.join_thread()
        if process.is_alive():
            process.kill()
            process.join()
|