| | from __future__ import annotations |
| |
|
| | import inspect |
| | import sys |
| | from collections.abc import Callable, Iterable, Mapping |
| | from contextlib import AbstractContextManager |
| | from types import TracebackType |
| | from typing import TYPE_CHECKING, Any |
| |
|
| | if sys.version_info < (3, 11): |
| | from ._exceptions import BaseExceptionGroup |
| |
|
| | if TYPE_CHECKING: |
| | _Handler = Callable[[BaseExceptionGroup[Any]], Any] |
| |
|
| |
|
| | class _Catcher: |
| | def __init__(self, handler_map: Mapping[tuple[type[BaseException], ...], _Handler]): |
| | self._handler_map = handler_map |
| |
|
| | def __enter__(self) -> None: |
| | pass |
| |
|
| | def __exit__( |
| | self, |
| | etype: type[BaseException] | None, |
| | exc: BaseException | None, |
| | tb: TracebackType | None, |
| | ) -> bool: |
| | if exc is not None: |
| | unhandled = self.handle_exception(exc) |
| | if unhandled is exc: |
| | return False |
| | elif unhandled is None: |
| | return True |
| | else: |
| | if isinstance(exc, BaseExceptionGroup): |
| | try: |
| | raise unhandled from exc.__cause__ |
| | except BaseExceptionGroup: |
| | |
| | |
| | unhandled.__context__ = exc.__cause__ |
| | raise |
| |
|
| | raise unhandled from exc |
| |
|
| | return False |
| |
|
| | def handle_exception(self, exc: BaseException) -> BaseException | None: |
| | excgroup: BaseExceptionGroup | None |
| | if isinstance(exc, BaseExceptionGroup): |
| | excgroup = exc |
| | else: |
| | excgroup = BaseExceptionGroup("", [exc]) |
| |
|
| | new_exceptions: list[BaseException] = [] |
| | for exc_types, handler in self._handler_map.items(): |
| | matched, excgroup = excgroup.split(exc_types) |
| | if matched: |
| | try: |
| | try: |
| | raise matched |
| | except BaseExceptionGroup: |
| | result = handler(matched) |
| | except BaseExceptionGroup as new_exc: |
| | if new_exc is matched: |
| | new_exceptions.append(new_exc) |
| | else: |
| | new_exceptions.extend(new_exc.exceptions) |
| | except BaseException as new_exc: |
| | new_exceptions.append(new_exc) |
| | else: |
| | if inspect.iscoroutine(result): |
| | raise TypeError( |
| | f"Error trying to handle {matched!r} with {handler!r}. " |
| | "Exception handler must be a sync function." |
| | ) from exc |
| |
|
| | if not excgroup: |
| | break |
| |
|
| | if new_exceptions: |
| | if len(new_exceptions) == 1: |
| | return new_exceptions[0] |
| |
|
| | return BaseExceptionGroup("", new_exceptions) |
| | elif ( |
| | excgroup and len(excgroup.exceptions) == 1 and excgroup.exceptions[0] is exc |
| | ): |
| | return exc |
| | else: |
| | return excgroup |
| |
|
| |
|
| | def catch( |
| | __handlers: Mapping[type[BaseException] | Iterable[type[BaseException]], _Handler], |
| | ) -> AbstractContextManager[None]: |
| | if not isinstance(__handlers, Mapping): |
| | raise TypeError("the argument must be a mapping") |
| |
|
| | handler_map: dict[ |
| | tuple[type[BaseException], ...], Callable[[BaseExceptionGroup]] |
| | ] = {} |
| | for type_or_iterable, handler in __handlers.items(): |
| | iterable: tuple[type[BaseException]] |
| | if isinstance(type_or_iterable, type) and issubclass( |
| | type_or_iterable, BaseException |
| | ): |
| | iterable = (type_or_iterable,) |
| | elif isinstance(type_or_iterable, Iterable): |
| | iterable = tuple(type_or_iterable) |
| | else: |
| | raise TypeError( |
| | "each key must be either an exception classes or an iterable thereof" |
| | ) |
| |
|
| | if not callable(handler): |
| | raise TypeError("handlers must be callable") |
| |
|
| | for exc_type in iterable: |
| | if not isinstance(exc_type, type) or not issubclass( |
| | exc_type, BaseException |
| | ): |
| | raise TypeError( |
| | "each key must be either an exception classes or an iterable " |
| | "thereof" |
| | ) |
| |
|
| | if issubclass(exc_type, BaseExceptionGroup): |
| | raise TypeError( |
| | "catching ExceptionGroup with catch() is not allowed. " |
| | "Use except instead." |
| | ) |
| |
|
| | handler_map[iterable] = handler |
| |
|
| | return _Catcher(handler_map) |
| |
|