| """ |
| Module consolidating common testing functions for checking plotting. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from typing import TYPE_CHECKING |
|
|
| import numpy as np |
|
|
| from pandas.core.dtypes.api import is_list_like |
|
|
| import pandas as pd |
| from pandas import Series |
| import pandas._testing as tm |
|
|
| if TYPE_CHECKING: |
| from collections.abc import Sequence |
|
|
| from matplotlib.axes import Axes |
|
|
|
|
| def _check_legend_labels(axes, labels=None, visible=True): |
| """ |
| Check each axes has expected legend labels |
| |
| Parameters |
| ---------- |
| axes : matplotlib Axes object, or its list-like |
| labels : list-like |
| expected legend labels |
| visible : bool |
| expected legend visibility. labels are checked only when visible is |
| True |
| """ |
| if visible and (labels is None): |
| raise ValueError("labels must be specified when visible is True") |
| axes = _flatten_visible(axes) |
| for ax in axes: |
| if visible: |
| assert ax.get_legend() is not None |
| _check_text_labels(ax.get_legend().get_texts(), labels) |
| else: |
| assert ax.get_legend() is None |
|
|
|
|
| def _check_legend_marker(ax, expected_markers=None, visible=True): |
| """ |
| Check ax has expected legend markers |
| |
| Parameters |
| ---------- |
| ax : matplotlib Axes object |
| expected_markers : list-like |
| expected legend markers |
| visible : bool |
| expected legend visibility. labels are checked only when visible is |
| True |
| """ |
| if visible and (expected_markers is None): |
| raise ValueError("Markers must be specified when visible is True") |
| if visible: |
| handles, _ = ax.get_legend_handles_labels() |
| markers = [handle.get_marker() for handle in handles] |
| assert markers == expected_markers |
| else: |
| assert ax.get_legend() is None |
|
|
|
|
| def _check_data(xp, rs): |
| """ |
| Check each axes has identical lines |
| |
| Parameters |
| ---------- |
| xp : matplotlib Axes object |
| rs : matplotlib Axes object |
| """ |
| import matplotlib.pyplot as plt |
|
|
| xp_lines = xp.get_lines() |
| rs_lines = rs.get_lines() |
|
|
| assert len(xp_lines) == len(rs_lines) |
| for xpl, rsl in zip(xp_lines, rs_lines): |
| xpdata = xpl.get_xydata() |
| rsdata = rsl.get_xydata() |
| tm.assert_almost_equal(xpdata, rsdata) |
|
|
| plt.close("all") |
|
|
|
|
| def _check_visible(collections, visible=True): |
| """ |
| Check each artist is visible or not |
| |
| Parameters |
| ---------- |
| collections : matplotlib Artist or its list-like |
| target Artist or its list or collection |
| visible : bool |
| expected visibility |
| """ |
| from matplotlib.collections import Collection |
|
|
| if not isinstance(collections, Collection) and not is_list_like(collections): |
| collections = [collections] |
|
|
| for patch in collections: |
| assert patch.get_visible() == visible |
|
|
|
|
| def _check_patches_all_filled(axes: Axes | Sequence[Axes], filled: bool = True) -> None: |
| """ |
| Check for each artist whether it is filled or not |
| |
| Parameters |
| ---------- |
| axes : matplotlib Axes object, or its list-like |
| filled : bool |
| expected filling |
| """ |
|
|
| axes = _flatten_visible(axes) |
| for ax in axes: |
| for patch in ax.patches: |
| assert patch.fill == filled |
|
|
|
|
| def _get_colors_mapped(series, colors): |
| unique = series.unique() |
| |
| |
| mapped = dict(zip(unique, colors)) |
| return [mapped[v] for v in series.values] |
|
|
|
|
| def _check_colors(collections, linecolors=None, facecolors=None, mapping=None): |
| """ |
| Check each artist has expected line colors and face colors |
| |
| Parameters |
| ---------- |
| collections : list-like |
| list or collection of target artist |
| linecolors : list-like which has the same length as collections |
| list of expected line colors |
| facecolors : list-like which has the same length as collections |
| list of expected face colors |
| mapping : Series |
| Series used for color grouping key |
| used for andrew_curves, parallel_coordinates, radviz test |
| """ |
| from matplotlib import colors |
| from matplotlib.collections import ( |
| Collection, |
| LineCollection, |
| PolyCollection, |
| ) |
| from matplotlib.lines import Line2D |
|
|
| conv = colors.ColorConverter |
| if linecolors is not None: |
| if mapping is not None: |
| linecolors = _get_colors_mapped(mapping, linecolors) |
| linecolors = linecolors[: len(collections)] |
|
|
| assert len(collections) == len(linecolors) |
| for patch, color in zip(collections, linecolors): |
| if isinstance(patch, Line2D): |
| result = patch.get_color() |
| |
| result = conv.to_rgba(result) |
| elif isinstance(patch, (PolyCollection, LineCollection)): |
| result = tuple(patch.get_edgecolor()[0]) |
| else: |
| result = patch.get_edgecolor() |
|
|
| expected = conv.to_rgba(color) |
| assert result == expected |
|
|
| if facecolors is not None: |
| if mapping is not None: |
| facecolors = _get_colors_mapped(mapping, facecolors) |
| facecolors = facecolors[: len(collections)] |
|
|
| assert len(collections) == len(facecolors) |
| for patch, color in zip(collections, facecolors): |
| if isinstance(patch, Collection): |
| |
| result = patch.get_facecolor()[0] |
| else: |
| result = patch.get_facecolor() |
|
|
| if isinstance(result, np.ndarray): |
| result = tuple(result) |
|
|
| expected = conv.to_rgba(color) |
| assert result == expected |
|
|
|
|
| def _check_text_labels(texts, expected): |
| """ |
| Check each text has expected labels |
| |
| Parameters |
| ---------- |
| texts : matplotlib Text object, or its list-like |
| target text, or its list |
| expected : str or list-like which has the same length as texts |
| expected text label, or its list |
| """ |
| if not is_list_like(texts): |
| assert texts.get_text() == expected |
| else: |
| labels = [t.get_text() for t in texts] |
| assert len(labels) == len(expected) |
| for label, e in zip(labels, expected): |
| assert label == e |
|
|
|
|
| def _check_ticks_props(axes, xlabelsize=None, xrot=None, ylabelsize=None, yrot=None): |
| """ |
| Check each axes has expected tick properties |
| |
| Parameters |
| ---------- |
| axes : matplotlib Axes object, or its list-like |
| xlabelsize : number |
| expected xticks font size |
| xrot : number |
| expected xticks rotation |
| ylabelsize : number |
| expected yticks font size |
| yrot : number |
| expected yticks rotation |
| """ |
| from matplotlib.ticker import NullFormatter |
|
|
| axes = _flatten_visible(axes) |
| for ax in axes: |
| if xlabelsize is not None or xrot is not None: |
| if isinstance(ax.xaxis.get_minor_formatter(), NullFormatter): |
| |
| |
| labels = ax.get_xticklabels() |
| else: |
| labels = ax.get_xticklabels() + ax.get_xticklabels(minor=True) |
|
|
| for label in labels: |
| if xlabelsize is not None: |
| tm.assert_almost_equal(label.get_fontsize(), xlabelsize) |
| if xrot is not None: |
| tm.assert_almost_equal(label.get_rotation(), xrot) |
|
|
| if ylabelsize is not None or yrot is not None: |
| if isinstance(ax.yaxis.get_minor_formatter(), NullFormatter): |
| labels = ax.get_yticklabels() |
| else: |
| labels = ax.get_yticklabels() + ax.get_yticklabels(minor=True) |
|
|
| for label in labels: |
| if ylabelsize is not None: |
| tm.assert_almost_equal(label.get_fontsize(), ylabelsize) |
| if yrot is not None: |
| tm.assert_almost_equal(label.get_rotation(), yrot) |
|
|
|
|
| def _check_ax_scales(axes, xaxis="linear", yaxis="linear"): |
| """ |
| Check each axes has expected scales |
| |
| Parameters |
| ---------- |
| axes : matplotlib Axes object, or its list-like |
| xaxis : {'linear', 'log'} |
| expected xaxis scale |
| yaxis : {'linear', 'log'} |
| expected yaxis scale |
| """ |
| axes = _flatten_visible(axes) |
| for ax in axes: |
| assert ax.xaxis.get_scale() == xaxis |
| assert ax.yaxis.get_scale() == yaxis |
|
|
|
|
| def _check_axes_shape(axes, axes_num=None, layout=None, figsize=None): |
| """ |
| Check expected number of axes is drawn in expected layout |
| |
| Parameters |
| ---------- |
| axes : matplotlib Axes object, or its list-like |
| axes_num : number |
| expected number of axes. Unnecessary axes should be set to |
| invisible. |
| layout : tuple |
| expected layout, (expected number of rows , columns) |
| figsize : tuple |
| expected figsize. default is matplotlib default |
| """ |
| from pandas.plotting._matplotlib.tools import flatten_axes |
|
|
| if figsize is None: |
| figsize = (6.4, 4.8) |
| visible_axes = _flatten_visible(axes) |
|
|
| if axes_num is not None: |
| assert len(visible_axes) == axes_num |
| for ax in visible_axes: |
| |
| assert len(ax.get_children()) > 0 |
|
|
| if layout is not None: |
| x_set = set() |
| y_set = set() |
| for ax in flatten_axes(axes): |
| |
| points = ax.get_position().get_points() |
| x_set.add(points[0][0]) |
| y_set.add(points[0][1]) |
| result = (len(y_set), len(x_set)) |
| assert result == layout |
|
|
| tm.assert_numpy_array_equal( |
| visible_axes[0].figure.get_size_inches(), |
| np.array(figsize, dtype=np.float64), |
| ) |
|
|
|
|
| def _flatten_visible(axes: Axes | Sequence[Axes]) -> Sequence[Axes]: |
| """ |
| Flatten axes, and filter only visible |
| |
| Parameters |
| ---------- |
| axes : matplotlib Axes object, or its list-like |
| |
| """ |
| from pandas.plotting._matplotlib.tools import flatten_axes |
|
|
| axes_ndarray = flatten_axes(axes) |
| axes = [ax for ax in axes_ndarray if ax.get_visible()] |
| return axes |
|
|
|
|
| def _check_has_errorbars(axes, xerr=0, yerr=0): |
| """ |
| Check axes has expected number of errorbars |
| |
| Parameters |
| ---------- |
| axes : matplotlib Axes object, or its list-like |
| xerr : number |
| expected number of x errorbar |
| yerr : number |
| expected number of y errorbar |
| """ |
| axes = _flatten_visible(axes) |
| for ax in axes: |
| containers = ax.containers |
| xerr_count = 0 |
| yerr_count = 0 |
| for c in containers: |
| has_xerr = getattr(c, "has_xerr", False) |
| has_yerr = getattr(c, "has_yerr", False) |
| if has_xerr: |
| xerr_count += 1 |
| if has_yerr: |
| yerr_count += 1 |
| assert xerr == xerr_count |
| assert yerr == yerr_count |
|
|
|
|
| def _check_box_return_type( |
| returned, return_type, expected_keys=None, check_ax_title=True |
| ): |
| """ |
| Check box returned type is correct |
| |
| Parameters |
| ---------- |
| returned : object to be tested, returned from boxplot |
| return_type : str |
| return_type passed to boxplot |
| expected_keys : list-like, optional |
| group labels in subplot case. If not passed, |
| the function checks assuming boxplot uses single ax |
| check_ax_title : bool |
| Whether to check the ax.title is the same as expected_key |
| Intended to be checked by calling from ``boxplot``. |
| Normal ``plot`` doesn't attach ``ax.title``, it must be disabled. |
| """ |
| from matplotlib.axes import Axes |
|
|
| types = {"dict": dict, "axes": Axes, "both": tuple} |
| if expected_keys is None: |
| |
| if return_type is None: |
| return_type = "dict" |
|
|
| assert isinstance(returned, types[return_type]) |
| if return_type == "both": |
| assert isinstance(returned.ax, Axes) |
| assert isinstance(returned.lines, dict) |
| else: |
| |
| if return_type is None: |
| for r in _flatten_visible(returned): |
| assert isinstance(r, Axes) |
| return |
|
|
| assert isinstance(returned, Series) |
|
|
| assert sorted(returned.keys()) == sorted(expected_keys) |
| for key, value in returned.items(): |
| assert isinstance(value, types[return_type]) |
| |
| if return_type == "axes": |
| if check_ax_title: |
| assert value.get_title() == key |
| elif return_type == "both": |
| if check_ax_title: |
| assert value.ax.get_title() == key |
| assert isinstance(value.ax, Axes) |
| assert isinstance(value.lines, dict) |
| elif return_type == "dict": |
| line = value["medians"][0] |
| axes = line.axes |
| if check_ax_title: |
| assert axes.get_title() == key |
| else: |
| raise AssertionError |
|
|
|
|
| def _check_grid_settings(obj, kinds, kws={}): |
| |
|
|
| import matplotlib as mpl |
|
|
| def is_grid_on(): |
| xticks = mpl.pyplot.gca().xaxis.get_major_ticks() |
| yticks = mpl.pyplot.gca().yaxis.get_major_ticks() |
| xoff = all(not g.gridline.get_visible() for g in xticks) |
| yoff = all(not g.gridline.get_visible() for g in yticks) |
|
|
| return not (xoff and yoff) |
|
|
| spndx = 1 |
| for kind in kinds: |
| mpl.pyplot.subplot(1, 4 * len(kinds), spndx) |
| spndx += 1 |
| mpl.rc("axes", grid=False) |
| obj.plot(kind=kind, **kws) |
| assert not is_grid_on() |
| mpl.pyplot.clf() |
|
|
| mpl.pyplot.subplot(1, 4 * len(kinds), spndx) |
| spndx += 1 |
| mpl.rc("axes", grid=True) |
| obj.plot(kind=kind, grid=False, **kws) |
| assert not is_grid_on() |
| mpl.pyplot.clf() |
|
|
| if kind not in ["pie", "hexbin", "scatter"]: |
| mpl.pyplot.subplot(1, 4 * len(kinds), spndx) |
| spndx += 1 |
| mpl.rc("axes", grid=True) |
| obj.plot(kind=kind, **kws) |
| assert is_grid_on() |
| mpl.pyplot.clf() |
|
|
| mpl.pyplot.subplot(1, 4 * len(kinds), spndx) |
| spndx += 1 |
| mpl.rc("axes", grid=False) |
| obj.plot(kind=kind, grid=True, **kws) |
| assert is_grid_on() |
| mpl.pyplot.clf() |
|
|
|
|
| def _unpack_cycler(rcParams, field="color"): |
| """ |
| Auxiliary function for correctly unpacking cycler after MPL >= 1.5 |
| """ |
| return [v[field] for v in rcParams["axes.prop_cycle"]] |
|
|
|
|
| def get_x_axis(ax): |
| return ax._shared_axes["x"] |
|
|
|
|
| def get_y_axis(ax): |
| return ax._shared_axes["y"] |
|
|
|
|
| def _check_plot_works(f, default_axes=False, **kwargs): |
| """ |
| Create plot and ensure that plot return object is valid. |
| |
| Parameters |
| ---------- |
| f : func |
| Plotting function. |
| default_axes : bool, optional |
| If False (default): |
| - If `ax` not in `kwargs`, then create subplot(211) and plot there |
| - Create new subplot(212) and plot there as well |
| - Mind special corner case for bootstrap_plot (see `_gen_two_subplots`) |
| If True: |
| - Simply run plotting function with kwargs provided |
| - All required axes instances will be created automatically |
| - It is recommended to use it when the plotting function |
| creates multiple axes itself. It helps avoid warnings like |
| 'UserWarning: To output multiple subplots, |
| the figure containing the passed axes is being cleared' |
| **kwargs |
| Keyword arguments passed to the plotting function. |
| |
| Returns |
| ------- |
| Plot object returned by the last plotting. |
| """ |
| import matplotlib.pyplot as plt |
|
|
| if default_axes: |
| gen_plots = _gen_default_plot |
| else: |
| gen_plots = _gen_two_subplots |
|
|
| ret = None |
| try: |
| fig = kwargs.get("figure", plt.gcf()) |
| plt.clf() |
|
|
| for ret in gen_plots(f, fig, **kwargs): |
| tm.assert_is_valid_plot_return_object(ret) |
|
|
| finally: |
| plt.close(fig) |
|
|
| return ret |
|
|
|
|
| def _gen_default_plot(f, fig, **kwargs): |
| """ |
| Create plot in a default way. |
| """ |
| yield f(**kwargs) |
|
|
|
|
| def _gen_two_subplots(f, fig, **kwargs): |
| """ |
| Create plot on two subplots forcefully created. |
| """ |
| if "ax" not in kwargs: |
| fig.add_subplot(211) |
| yield f(**kwargs) |
|
|
| if f is pd.plotting.bootstrap_plot: |
| assert "ax" not in kwargs |
| else: |
| kwargs["ax"] = fig.add_subplot(212) |
| yield f(**kwargs) |
|
|