| from __future__ import annotations |
|
|
| import collections |
| from collections.abc import Callable |
| import functools |
| import sys |
| import threading |
| import traceback |
| from typing import NamedTuple |
| from typing import TYPE_CHECKING |
| import warnings |
|
|
| from _pytest.config import Config |
| from _pytest.nodes import Item |
| from _pytest.stash import StashKey |
| from _pytest.tracemalloc import tracemalloc_message |
| import pytest |
|
|
|
|
| if TYPE_CHECKING: |
| pass |
|
|
| if sys.version_info < (3, 11): |
| from exceptiongroup import ExceptionGroup |
|
|
|
|
class ThreadExceptionMeta(NamedTuple):
    """Record describing one exception captured by the thread excepthook."""

    # Full warning text: summary line, formatted traceback, tracemalloc note.
    msg: str
    # Shorter text without the traceback; used as the warning's args when
    # ``exc_value`` is attached as the cause instead.
    cause_msg: str
    # The exception object taken from the excepthook arguments, if any.
    exc_value: BaseException | None
|
|
|
|
# Stash slot for the deque of pending thread-exception records.  Each entry is
# either a ThreadExceptionMeta or a BaseException raised while building one.
thread_exceptions: StashKey[
    collections.deque[ThreadExceptionMeta | BaseException]
] = StashKey()
|
|
|
|
def collect_thread_exception(config: Config) -> None:
    """Drain the queued thread-exception records and report them.

    Records are popped from the deque stored under ``thread_exceptions`` in
    ``config.stash`` until it is empty.  Each ``ThreadExceptionMeta`` is
    emitted as a ``PytestUnhandledThreadExceptionWarning``; a queued
    ``BaseException`` is wrapped in a ``RuntimeError`` with it as the cause.

    Raises:
        The single gathered error when exactly one was collected, or an
        ``ExceptionGroup`` containing all of them when there are several.
    """
    pending = config.stash[thread_exceptions]
    collected: list[pytest.PytestUnhandledThreadExceptionWarning | RuntimeError] = []
    entry = None
    wrapped = None
    try:
        while True:
            try:
                entry = pending.pop()
            except IndexError:
                # The deque is empty: nothing left to report.
                break

            if isinstance(entry, BaseException):
                # The queued item is an exception rather than a metadata
                # record; surface it wrapped in a RuntimeError.
                wrapped = RuntimeError("Failed to process thread exception")
                wrapped.__cause__ = entry
                collected.append(wrapped)
            else:
                text = entry.msg
                try:
                    warnings.warn(pytest.PytestUnhandledThreadExceptionWarning(text))
                except pytest.PytestUnhandledThreadExceptionWarning as raised:
                    # Reached only when the warning is raised as an exception
                    # (i.e. the active warning filters escalate it to an error).
                    original = entry.exc_value
                    if original is not None:
                        # Swap in the traceback-free message and chain the
                        # original exception as the cause.
                        raised.args = (entry.cause_msg,)
                        raised.__cause__ = original
                    collected.append(raised)

        if not collected:
            return
        if len(collected) == 1:
            raise collected[0]
        raise ExceptionGroup("multiple thread exception warnings", collected)
    finally:
        # Drop the local references to the exception objects before leaving.
        del collected, entry, wrapped
|
|
|
|
def cleanup(
    *, config: Config, prev_hook: Callable[[threading.ExceptHookArgs], object]
) -> None:
    """Undo what ``pytest_configure`` installed.

    Steps run in this order, each one regardless of whether the previous
    one raised:

    1. report any thread exceptions still queued,
    2. put ``prev_hook`` back as ``threading.excepthook``,
    3. remove the ``thread_exceptions`` entry from ``config.stash``.
    """
    try:
        collect_thread_exception(config)
    finally:
        try:
            threading.excepthook = prev_hook
        finally:
            del config.stash[thread_exceptions]
|
|
|
|
def thread_exception_hook(
    args: threading.ExceptHookArgs,
    /,
    *,
    append: Callable[[ThreadExceptionMeta | BaseException], object],
) -> None:
    """Record an unhandled thread exception via ``append``.

    The message strings are built immediately, inside the hook, and stored in
    a ``ThreadExceptionMeta`` together with ``args.exc_value``.  If building
    or appending that record fails, the failure itself is appended and then
    re-raised.
    """
    try:
        thread = args.thread
        thread_name = thread.name if thread is not None else "<unknown>"
        summary = f"Exception in thread {thread_name}"

        formatted = traceback.format_exception(
            args.exc_type, args.exc_value, args.exc_traceback
        )
        traceback_message = "\n\n" + "".join(formatted)
        tracemalloc_tb = "\n" + tracemalloc_message(thread)

        record = ThreadExceptionMeta(
            msg=summary + traceback_message + tracemalloc_tb,
            cause_msg=summary + tracemalloc_tb,
            exc_value=args.exc_value,
        )
        append(record)
    except BaseException as failure:
        # Queue the failure so the collector reports it, then let it
        # propagate out of the hook as well.
        append(failure)
        raise
|
|
|
|
def pytest_configure(config: Config) -> None:
    """Install the thread excepthook and register its teardown.

    A fresh deque is stored in ``config.stash`` under ``thread_exceptions``,
    ``cleanup`` is registered (bound to the hook that was active before) via
    ``config.add_cleanup``, and ``threading.excepthook`` is replaced by
    ``thread_exception_hook`` bound to that deque's ``append``.
    """
    previous = threading.excepthook
    pending: collections.deque[ThreadExceptionMeta | BaseException] = (
        collections.deque()
    )
    config.stash[thread_exceptions] = pending

    restore = functools.partial(cleanup, config=config, prev_hook=previous)
    config.add_cleanup(restore)

    hook = functools.partial(thread_exception_hook, append=pending.append)
    threading.excepthook = hook
|
|
|
|
@pytest.hookimpl(trylast=True)
def pytest_runtest_setup(item: Item) -> None:
    """Report queued thread exceptions during the setup phase of ``item``."""
    config = item.config
    collect_thread_exception(config)
|
|
|
|
@pytest.hookimpl(trylast=True)
def pytest_runtest_call(item: Item) -> None:
    """Report queued thread exceptions during the call phase of ``item``."""
    config = item.config
    collect_thread_exception(config)
|
|
|
|
@pytest.hookimpl(trylast=True)
def pytest_runtest_teardown(item: Item) -> None:
    """Report queued thread exceptions during the teardown phase of ``item``."""
    config = item.config
    collect_thread_exception(config)
|
|