File size: 4,242 Bytes
69e1a8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import contextlib
import multiprocessing
import queue  # for Empty
import sys
import time
import traceback
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from io import StringIO
from typing import Any

# Sentinel return codes reported by run_in_spawned_child_process when the
# child did not deliver a real exit status of its own.
PROCESS_KILLED: int = -9  # the child outlived its timeout and was terminated
PROCESS_NO_RESULT: int = -999  # the child ended without sending back a result


# Similar to https://docs.python.org/3/library/subprocess.html#subprocess.CompletedProcess
# (args, check_returncode() are intentionally not supported here.)
@dataclass
class CompletedProcess:
    returncode: int
    stdout: str
    stderr: str


class ChildProcessWrapper:
    def __init__(
        self,
        result_queue: Any,
        target: Callable[..., None],
        args: Sequence[Any] | None,
        kwargs: dict[str, Any] | None,
    ) -> None:
        self.target = target
        self.args = () if args is None else args
        self.kwargs = {} if kwargs is None else kwargs
        self.result_queue = result_queue

    def __call__(self) -> None:
        # Capture stdout/stderr
        old_stdout = sys.stdout
        old_stderr = sys.stderr
        sys.stdout = StringIO()
        sys.stderr = StringIO()

        try:
            self.target(*self.args, **self.kwargs)
            returncode = 0
        except SystemExit as e:  # Handle sys.exit()
            returncode = e.code if isinstance(e.code, int) else 0
        except BaseException:
            traceback.print_exc()
            returncode = 1
        finally:
            # Collect outputs and restore streams
            stdout = sys.stdout.getvalue()
            stderr = sys.stderr.getvalue()
            sys.stdout = old_stdout
            sys.stderr = old_stderr
            with contextlib.suppress(Exception):
                self.result_queue.put((returncode, stdout, stderr))


def _terminate_process(process: Any, grace: float = 5.0) -> None:
    """Stop ``process`` with ``terminate()``, escalating to ``kill()`` after ``grace`` seconds.

    A child can ignore or delay ``terminate()``, so an unbounded ``join()``
    right after it could block forever.
    """
    process.terminate()
    process.join(grace)
    if process.is_alive():
        process.kill()
        process.join()


def _receive_child_result(
    process: Any,
    result_queue: Any,
    deadline: float | None,
    poll_interval: float = 0.1,
) -> Any:
    """Wait for the child's ``(returncode, stdout, stderr)`` tuple on ``result_queue``.

    The queue is drained *while* the child runs, not after ``join()``. A child
    that has put a large item on a ``multiprocessing.Queue`` cannot exit until
    the parent reads it. Joining first therefore deadlocks (or is misreported
    as a timeout) once the captured output exceeds the pipe buffer.

    ``deadline`` is a ``time.monotonic()`` value, or ``None`` to wait forever.
    Returns the tuple the child sent, or ``None`` if the child exited without
    sending one or the deadline passed first.
    """
    while True:
        if deadline is None:
            wait = poll_interval
        else:
            wait = min(poll_interval, max(deadline - time.monotonic(), 0.0))
        try:
            return result_queue.get(timeout=wait)
        except queue.Empty:
            pass
        except EOFError:
            return None

        if not process.is_alive():
            # A child flushes its queue buffer before exiting, so anything it
            # sent is already in the pipe; allow a moment for it to be readable.
            try:
                return result_queue.get(timeout=1.0)
            except (queue.Empty, EOFError):
                return None

        if deadline is not None and time.monotonic() >= deadline:
            return None


def run_in_spawned_child_process(
    target: Callable[..., None],
    *,
    args: Sequence[Any] | None = None,
    kwargs: dict[str, Any] | None = None,
    timeout: float | None = None,
    rethrow: bool = False,
) -> CompletedProcess:
    """Run `target` in a spawned child process, capturing stdout/stderr.

    The provided `target` must be defined at the top level of a module, and must
    be importable in the spawned child process. Lambdas, closures, or interactively
    defined functions (e.g., in Jupyter notebooks) will not work. Only output
    written through Python's ``sys.stdout``/``sys.stderr`` is captured.

    Args:
        target: Callable invoked in the child as ``target(*args, **kwargs)``.
        args: Positional arguments for ``target`` (``None`` means none).
        kwargs: Keyword arguments for ``target`` (``None`` means none).
        timeout: Seconds to allow the child to finish; ``None`` waits forever.
            On expiry the child is terminated and the result's ``returncode``
            is ``PROCESS_KILLED``.
        rethrow: If true, raise instead of returning a nonzero result.

    Returns:
        A CompletedProcess with the child's return code and captured output.
        If the child ends without reporting a result (e.g. it crashed), the
        ``returncode`` is ``PROCESS_NO_RESULT``.

    Raises:
        ChildProcessError: If `rethrow=True` and the return code is nonzero;
            the message includes the captured stderr.
    """
    ctx = multiprocessing.get_context("spawn")
    result_queue = ctx.Queue()
    process = ctx.Process(target=ChildProcessWrapper(result_queue, target, args, kwargs))
    process.start()

    try:
        deadline = None if timeout is None else time.monotonic() + timeout
        payload = _receive_child_result(process, result_queue, deadline)
        if payload is not None:
            # The target has finished. Let the child interpreter shut down within
            # the remaining budget, but always allow a short grace period.
            process.join(None if deadline is None else max(deadline - time.monotonic(), 1.0))

        if process.is_alive():
            _terminate_process(process)
            result = CompletedProcess(
                returncode=PROCESS_KILLED,
                stdout="",
                stderr=f"Process timed out after {timeout} seconds and was terminated.",
            )
        elif payload is None:
            result = CompletedProcess(
                returncode=PROCESS_NO_RESULT,
                stdout="",
                stderr="Process exited or crashed before returning results.",
            )
        else:
            returncode, stdout, stderr = payload
            result = CompletedProcess(
                returncode=returncode,
                stdout=stdout,
                stderr=stderr,
            )

        if rethrow and result.returncode != 0:
            raise ChildProcessError(
                f"Child process exited with code {result.returncode}.\n"
                "--- stderr-from-child-process ---\n"
                f"{result.stderr}"
                "<end-of-stderr-from-child-process>\n"
            )

        return result

    finally:
        # Best-effort cleanup; must not mask the result or an exception above.
        with contextlib.suppress(Exception):
            result_queue.close()
            result_queue.join_thread()
        if process.is_alive():
            process.kill()
            process.join()