| from __future__ import annotations |
|
|
| from contextlib import ( |
| contextmanager, |
| nullcontext, |
| ) |
| import inspect |
| import re |
| import sys |
| from typing import ( |
| TYPE_CHECKING, |
| Literal, |
| cast, |
| ) |
| import warnings |
|
|
| from pandas.compat import PY311 |
|
|
| if TYPE_CHECKING: |
| from collections.abc import ( |
| Generator, |
| Sequence, |
| ) |
|
|
|
|
@contextmanager
def assert_produces_warning(
    expected_warning: type[Warning] | bool | tuple[type[Warning], ...] | None = Warning,
    filter_level: Literal[
        "error", "ignore", "always", "default", "module", "once"
    ] = "always",
    check_stacklevel: bool = True,
    raise_on_extra_warnings: bool = True,
    match: str | None = None,
) -> Generator[list[warnings.WarningMessage], None, None]:
    """
    Context manager asserting which warnings a block of code emits.

    Warnings raised inside the ``with`` block are recorded through
    ``warnings.catch_warnings(record=True)``. When the block exits (normally or
    through an exception) the recorded warnings are checked: the expected
    warning(s) must have been seen and, optionally, nothing else may have been
    emitted. The list of recorded warnings is yielded to the ``with`` block.

    Parameters
    ----------
    expected_warning : {Warning, False, tuple[Warning, ...], None}, default Warning
        The warning class that must be emitted. ``Warning`` is the base class
        of all warnings, so the default accepts any warning. Pass a tuple of
        classes to accept any one of several types. Pass ``False`` or ``None``
        to skip the "expected warning" check, so that (with
        ``raise_on_extra_warnings=True``) any warning is reported as unexpected.
    filter_level : str, default "always"
        Action installed with ``warnings.simplefilter`` while recording.
        Valid values are:

        * "error" - turns matching warnings into exceptions
        * "ignore" - discard the warning
        * "always" - always emit a warning
        * "default" - print the warning the first time it is generated
          from each location
        * "module" - print the warning the first time it is generated
          from each module
        * "once" - print the warning the first time it is generated

    check_stacklevel : bool, default True
        If True, additionally assert that every warning of the expected type
        is attributed to the file containing the ``with`` block, i.e. that it
        was issued with a ``stacklevel`` pointing at the caller.
    raise_on_extra_warnings : bool, default True
        Whether warnings that are not of type `expected_warning` should
        cause the test to fail.
    match : str, optional
        Regular expression searched for (``re.search``) in the messages of the
        warnings of the expected type; at least one message must match.

    Yields
    ------
    list of warnings.WarningMessage
        The live list of warnings recorded so far.

    Raises
    ------
    AssertionError
        If the expected warning was not seen, no message matched `match`,
        the stacklevel check failed, or unexpected warnings were emitted.

    Examples
    --------
    >>> import warnings
    >>> with assert_produces_warning():
    ...     warnings.warn(UserWarning())
    ...
    >>> with assert_produces_warning(False):
    ...     warnings.warn(RuntimeWarning())
    ...
    Traceback (most recent call last):
        ...
    AssertionError: Caused unexpected warning(s): [('RuntimeWarning', ...)]
    >>> with assert_produces_warning(UserWarning):
    ...     warnings.warn(RuntimeWarning())
    Traceback (most recent call last):
        ...
    AssertionError: Did not see expected warning of class 'UserWarning'

    .. warning:: This is *not* thread-safe.
    """
    __tracebackhide__ = True

    with warnings.catch_warnings(record=True) as recorded:
        warnings.simplefilter(filter_level)
        try:
            yield recorded
        finally:
            # The checks live in ``finally`` so they also run when the managed
            # block raises.
            if expected_warning:
                # Keep this a direct call from the generator body: the
                # stacklevel check locates the caller by walking a fixed
                # number of frames up the stack.
                _assert_caught_expected_warning(
                    caught_warnings=recorded,
                    expected_warning=cast(type[Warning], expected_warning),
                    match=match,
                    check_stacklevel=check_stacklevel,
                )
            if raise_on_extra_warnings:
                _assert_caught_no_extra_warnings(
                    caught_warnings=recorded,
                    expected_warning=expected_warning,
                )
|
|
|
|
def maybe_produces_warning(warning: type[Warning], condition: bool, **kwargs):
    """
    Return a context manager that checks for ``warning`` only if ``condition``.

    Parameters
    ----------
    warning : type[Warning]
        Warning class forwarded to ``assert_produces_warning``.
    condition : bool
        When truthy the warning is asserted; otherwise nothing is checked.
    **kwargs
        Extra keyword arguments forwarded to ``assert_produces_warning``;
        ignored when ``condition`` is falsy.

    Returns
    -------
    context manager
        ``assert_produces_warning(warning, **kwargs)`` when ``condition`` is
        truthy, else ``contextlib.nullcontext()``.
    """
    if not condition:
        return nullcontext()
    return assert_produces_warning(warning, **kwargs)
|
|
|
|
def _assert_caught_expected_warning(
    *,
    caught_warnings: Sequence[warnings.WarningMessage],
    expected_warning: type[Warning] | tuple[type[Warning], ...],
    match: str | None,
    check_stacklevel: bool,
) -> None:
    """
    Assert that there was the expected warning among the caught warnings.

    Parameters
    ----------
    caught_warnings : Sequence[warnings.WarningMessage]
        Warnings recorded by ``warnings.catch_warnings(record=True)``.
    expected_warning : type[Warning] or tuple[type[Warning], ...]
        Class, or tuple of classes, of which at least one caught warning must
        be an instance (tested with ``issubclass`` on the warning category).
    match : str or None
        If not None, a regular expression that must be found (``re.search``)
        in the message of at least one warning of the expected type.
    check_stacklevel : bool
        If True, every warning of the expected type is also checked with
        ``_assert_raised_with_correct_stacklevel``.

    Raises
    ------
    AssertionError
        If no warning of the expected type was caught, if `match` was given
        and no expected warning's message matched it, or if the stacklevel
        check fails.
    """
    saw_warning = False
    matched_message = False
    unmatched_messages = []

    for actual_warning in caught_warnings:
        if not issubclass(actual_warning.category, expected_warning):
            continue
        saw_warning = True

        if check_stacklevel:
            # Must stay a direct call from this function: the helper finds the
            # user's frame by walking a fixed number of frames up the stack.
            _assert_raised_with_correct_stacklevel(actual_warning)

        if match is not None:
            if re.search(match, str(actual_warning.message)):
                matched_message = True
            else:
                unmatched_messages.append(actual_warning.message)

    # A tuple of classes has no ``__name__``; format each member so building
    # the failure message cannot itself raise AttributeError. For a single
    # class the text is identical to ``repr(expected_warning.__name__)``.
    if isinstance(expected_warning, tuple):
        expected_names = " or ".join(repr(cls.__name__) for cls in expected_warning)
    else:
        expected_names = repr(expected_warning.__name__)

    if not saw_warning:
        raise AssertionError(
            f"Did not see expected warning of class {expected_names}"
        )

    if match is not None and not matched_message:
        raise AssertionError(
            f"Did not see warning {expected_names} "
            f"matching '{match}'. The emitted warning messages are "
            f"{unmatched_messages}"
        )
|
|
|
|
def _assert_caught_no_extra_warnings(
    *,
    caught_warnings: Sequence[warnings.WarningMessage],
    expected_warning: type[Warning] | bool | tuple[type[Warning], ...] | None,
) -> None:
    """
    Assert that no extra warnings apart from the expected ones are caught.

    Parameters
    ----------
    caught_warnings : Sequence[warnings.WarningMessage]
        Warnings recorded while the checked code ran.
    expected_warning : type[Warning], tuple of type[Warning], bool or None
        The warning type(s) that are allowed; a falsy value means every caught
        warning counts as unexpected.

    Raises
    ------
    AssertionError
        If any unexpected warning remains after the exemptions below. The
        message lists ``(category name, message, filename, lineno)`` tuples.
    """
    unexpected = []

    for caught in caught_warnings:
        if not _is_unexpected_warning(caught, expected_warning):
            continue

        category = caught.category

        # Exempted ResourceWarnings: unclosed SSL sockets, and any
        # ResourceWarning at all once a matplotlib module has been imported.
        # NOTE(review): presumably these come from third-party dependencies
        # and would otherwise make test runs flaky - confirm.
        if category == ResourceWarning and (
            "unclosed <ssl.SSLSocket" in str(caught.message)
            or any("matplotlib" in mod for mod in sys.modules)
        ):
            continue

        # EncodingWarning is ignored here whenever PY311 is set.
        # NOTE(review): presumably it is enforced elsewhere for this
        # project's own code - confirm.
        if PY311 and category == EncodingWarning:
            continue

        unexpected.append(
            (category.__name__, caught.message, caught.filename, caught.lineno)
        )

    if unexpected:
        raise AssertionError(f"Caused unexpected warning(s): {repr(unexpected)}")
|
|
|
|
def _is_unexpected_warning(
    actual_warning: warnings.WarningMessage,
    expected_warning: type[Warning] | bool | tuple[type[Warning], ...] | None,
) -> bool:
    """
    Check if the actual warning issued is unexpected.

    Parameters
    ----------
    actual_warning : warnings.WarningMessage
        A recorded warning.
    expected_warning : type[Warning], tuple of type[Warning], bool or None
        The allowed warning type(s); a falsy value means nothing is expected.

    Returns
    -------
    bool
        True when nothing is expected, or when the category of
        `actual_warning` is not a subclass of `expected_warning`.
    """
    nothing_expected = not expected_warning
    if actual_warning and nothing_expected:
        return True

    allowed = cast(type[Warning], expected_warning)
    is_allowed = issubclass(actual_warning.category, allowed)
    return not is_allowed
|
|
|
|
def _assert_raised_with_correct_stacklevel(
    actual_warning: warnings.WarningMessage,
) -> None:
    """
    Assert that a warning is attributed to the file of the calling test code.

    The reference file is taken from the frame four levels above this
    function and compared with the filename recorded on the warning.

    Parameters
    ----------
    actual_warning : warnings.WarningMessage
        The recorded warning whose ``filename`` is checked.

    Raises
    ------
    AssertionError
        If the warning's filename differs from the file of the frame four
        levels up.
    """
    # NOTE(review): the depth of 4 assumes the call chain
    # _assert_caught_expected_warning -> assert_produces_warning (generator)
    # -> contextlib's ``__exit__`` -> code owning the ``with`` block.
    # Adding or removing a frame in that chain breaks this check.
    frame = inspect.currentframe()
    for _ in range(4):
        frame = frame.f_back
    try:
        caller_filename = inspect.getfile(frame)
    finally:
        # Drop the frame reference explicitly; keeping frames alive can create
        # reference cycles (see the note in the ``inspect`` documentation).
        del frame

    # Raise explicitly instead of using ``assert`` so the check is not
    # stripped when Python runs with ``-O``.
    if actual_warning.filename != caller_filename:
        msg = (
            "Warning not set with correct stacklevel. "
            f"File where warning is raised: {actual_warning.filename} != "
            f"{caller_filename}. Warning message: {actual_warning.message}"
        )
        raise AssertionError(msg)
|
|