| |
| from __future__ import annotations |
|
|
| from math import ceil |
| from typing import TYPE_CHECKING |
| import warnings |
|
|
| from matplotlib import ticker |
| import matplotlib.table |
| import numpy as np |
|
|
| from pandas.util._exceptions import find_stack_level |
|
|
| from pandas.core.dtypes.common import is_list_like |
| from pandas.core.dtypes.generic import ( |
| ABCDataFrame, |
| ABCIndex, |
| ABCSeries, |
| ) |
|
|
| if TYPE_CHECKING: |
| from collections.abc import ( |
| Iterable, |
| Sequence, |
| ) |
|
|
| from matplotlib.axes import Axes |
| from matplotlib.axis import Axis |
| from matplotlib.figure import Figure |
| from matplotlib.lines import Line2D |
| from matplotlib.table import Table |
|
|
| from pandas import ( |
| DataFrame, |
| Series, |
| ) |
|
|
|
|
| def do_adjust_figure(fig: Figure) -> bool: |
| """Whether fig has constrained_layout enabled.""" |
| if not hasattr(fig, "get_constrained_layout"): |
| return False |
| return not fig.get_constrained_layout() |
|
|
|
|
| def maybe_adjust_figure(fig: Figure, *args, **kwargs) -> None: |
| """Call fig.subplots_adjust unless fig has constrained_layout enabled.""" |
| if do_adjust_figure(fig): |
| fig.subplots_adjust(*args, **kwargs) |
|
|
|
|
| def format_date_labels(ax: Axes, rot) -> None: |
| |
| for label in ax.get_xticklabels(): |
| label.set_horizontalalignment("right") |
| label.set_rotation(rot) |
| fig = ax.get_figure() |
| if fig is not None: |
| |
| maybe_adjust_figure(fig, bottom=0.2) |
|
|
|
|
| def table( |
| ax, data: DataFrame | Series, rowLabels=None, colLabels=None, **kwargs |
| ) -> Table: |
| if isinstance(data, ABCSeries): |
| data = data.to_frame() |
| elif isinstance(data, ABCDataFrame): |
| pass |
| else: |
| raise ValueError("Input data must be DataFrame or Series") |
|
|
| if rowLabels is None: |
| rowLabels = data.index |
|
|
| if colLabels is None: |
| colLabels = data.columns |
|
|
| cellText = data.values |
|
|
| |
| |
| return matplotlib.table.table( |
| ax, |
| cellText=cellText, |
| rowLabels=rowLabels, |
| colLabels=colLabels, |
| **kwargs, |
| ) |
|
|
|
|
| def _get_layout( |
| nplots: int, |
| layout: tuple[int, int] | None = None, |
| layout_type: str = "box", |
| ) -> tuple[int, int]: |
| if layout is not None: |
| if not isinstance(layout, (tuple, list)) or len(layout) != 2: |
| raise ValueError("Layout must be a tuple of (rows, columns)") |
|
|
| nrows, ncols = layout |
|
|
| if nrows == -1 and ncols > 0: |
| layout = nrows, ncols = (ceil(nplots / ncols), ncols) |
| elif ncols == -1 and nrows > 0: |
| layout = nrows, ncols = (nrows, ceil(nplots / nrows)) |
| elif ncols <= 0 and nrows <= 0: |
| msg = "At least one dimension of layout must be positive" |
| raise ValueError(msg) |
|
|
| if nrows * ncols < nplots: |
| raise ValueError( |
| f"Layout of {nrows}x{ncols} must be larger than required size {nplots}" |
| ) |
|
|
| return layout |
|
|
| if layout_type == "single": |
| return (1, 1) |
| elif layout_type == "horizontal": |
| return (1, nplots) |
| elif layout_type == "vertical": |
| return (nplots, 1) |
|
|
| layouts = {1: (1, 1), 2: (1, 2), 3: (2, 2), 4: (2, 2)} |
| try: |
| return layouts[nplots] |
| except KeyError: |
| k = 1 |
| while k**2 < nplots: |
| k += 1 |
|
|
| if (k - 1) * k >= nplots: |
| return k, (k - 1) |
| else: |
| return k, k |
|
|
|
|
| |
|
|
|
|
| def create_subplots( |
| naxes: int, |
| sharex: bool = False, |
| sharey: bool = False, |
| squeeze: bool = True, |
| subplot_kw=None, |
| ax=None, |
| layout=None, |
| layout_type: str = "box", |
| **fig_kw, |
| ): |
| """ |
| Create a figure with a set of subplots already made. |
| |
| This utility wrapper makes it convenient to create common layouts of |
| subplots, including the enclosing figure object, in a single call. |
| |
| Parameters |
| ---------- |
| naxes : int |
| Number of required axes. Exceeded axes are set invisible. Default is |
| nrows * ncols. |
| |
| sharex : bool |
| If True, the X axis will be shared amongst all subplots. |
| |
| sharey : bool |
| If True, the Y axis will be shared amongst all subplots. |
| |
| squeeze : bool |
| |
| If True, extra dimensions are squeezed out from the returned axis object: |
| - if only one subplot is constructed (nrows=ncols=1), the resulting |
| single Axis object is returned as a scalar. |
| - for Nx1 or 1xN subplots, the returned object is a 1-d numpy object |
| array of Axis objects are returned as numpy 1-d arrays. |
| - for NxM subplots with N>1 and M>1 are returned as a 2d array. |
| |
| If False, no squeezing is done: the returned axis object is always |
| a 2-d array containing Axis instances, even if it ends up being 1x1. |
| |
| subplot_kw : dict |
| Dict with keywords passed to the add_subplot() call used to create each |
| subplots. |
| |
| ax : Matplotlib axis object, optional |
| |
| layout : tuple |
| Number of rows and columns of the subplot grid. |
| If not specified, calculated from naxes and layout_type |
| |
| layout_type : {'box', 'horizontal', 'vertical'}, default 'box' |
| Specify how to layout the subplot grid. |
| |
| fig_kw : Other keyword arguments to be passed to the figure() call. |
| Note that all keywords not recognized above will be |
| automatically included here. |
| |
| Returns |
| ------- |
| fig, ax : tuple |
| - fig is the Matplotlib Figure object |
| - ax can be either a single axis object or an array of axis objects if |
| more than one subplot was created. The dimensions of the resulting array |
| can be controlled with the squeeze keyword, see above. |
| |
| Examples |
| -------- |
| x = np.linspace(0, 2*np.pi, 400) |
| y = np.sin(x**2) |
| |
| # Just a figure and one subplot |
| f, ax = plt.subplots() |
| ax.plot(x, y) |
| ax.set_title('Simple plot') |
| |
| # Two subplots, unpack the output array immediately |
| f, (ax1, ax2) = plt.subplots(1, 2, sharey=True) |
| ax1.plot(x, y) |
| ax1.set_title('Sharing Y axis') |
| ax2.scatter(x, y) |
| |
| # Four polar axes |
| plt.subplots(2, 2, subplot_kw=dict(polar=True)) |
| """ |
| import matplotlib.pyplot as plt |
|
|
| if subplot_kw is None: |
| subplot_kw = {} |
|
|
| if ax is None: |
| fig = plt.figure(**fig_kw) |
| else: |
| if is_list_like(ax): |
| if squeeze: |
| ax = flatten_axes(ax) |
| if layout is not None: |
| warnings.warn( |
| "When passing multiple axes, layout keyword is ignored.", |
| UserWarning, |
| stacklevel=find_stack_level(), |
| ) |
| if sharex or sharey: |
| warnings.warn( |
| "When passing multiple axes, sharex and sharey " |
| "are ignored. These settings must be specified when creating axes.", |
| UserWarning, |
| stacklevel=find_stack_level(), |
| ) |
| if ax.size == naxes: |
| fig = ax.flat[0].get_figure() |
| return fig, ax |
| else: |
| raise ValueError( |
| f"The number of passed axes must be {naxes}, the " |
| "same as the output plot" |
| ) |
|
|
| fig = ax.get_figure() |
| |
| if naxes == 1: |
| if squeeze: |
| return fig, ax |
| else: |
| return fig, flatten_axes(ax) |
| else: |
| warnings.warn( |
| "To output multiple subplots, the figure containing " |
| "the passed axes is being cleared.", |
| UserWarning, |
| stacklevel=find_stack_level(), |
| ) |
| fig.clear() |
|
|
| nrows, ncols = _get_layout(naxes, layout=layout, layout_type=layout_type) |
| nplots = nrows * ncols |
|
|
| |
| |
| axarr = np.empty(nplots, dtype=object) |
|
|
| |
| ax0 = fig.add_subplot(nrows, ncols, 1, **subplot_kw) |
|
|
| if sharex: |
| subplot_kw["sharex"] = ax0 |
| if sharey: |
| subplot_kw["sharey"] = ax0 |
| axarr[0] = ax0 |
|
|
| |
| |
| for i in range(1, nplots): |
| kwds = subplot_kw.copy() |
| |
| |
| |
| if i >= naxes: |
| kwds["sharex"] = None |
| kwds["sharey"] = None |
| ax = fig.add_subplot(nrows, ncols, i + 1, **kwds) |
| axarr[i] = ax |
|
|
| if naxes != nplots: |
| for ax in axarr[naxes:]: |
| ax.set_visible(False) |
|
|
| handle_shared_axes(axarr, nplots, naxes, nrows, ncols, sharex, sharey) |
|
|
| if squeeze: |
| |
| |
| |
| if nplots == 1: |
| axes = axarr[0] |
| else: |
| axes = axarr.reshape(nrows, ncols).squeeze() |
| else: |
| |
| axes = axarr.reshape(nrows, ncols) |
|
|
| return fig, axes |
|
|
|
|
| def _remove_labels_from_axis(axis: Axis) -> None: |
| for t in axis.get_majorticklabels(): |
| t.set_visible(False) |
|
|
| |
| |
| if isinstance(axis.get_minor_locator(), ticker.NullLocator): |
| axis.set_minor_locator(ticker.AutoLocator()) |
| if isinstance(axis.get_minor_formatter(), ticker.NullFormatter): |
| axis.set_minor_formatter(ticker.FormatStrFormatter("")) |
| for t in axis.get_minorticklabels(): |
| t.set_visible(False) |
|
|
| axis.get_label().set_visible(False) |
|
|
|
|
| def _has_externally_shared_axis(ax1: Axes, compare_axis: str) -> bool: |
| """ |
| Return whether an axis is externally shared. |
| |
| Parameters |
| ---------- |
| ax1 : matplotlib.axes.Axes |
| Axis to query. |
| compare_axis : str |
| `"x"` or `"y"` according to whether the X-axis or Y-axis is being |
| compared. |
| |
| Returns |
| ------- |
| bool |
| `True` if the axis is externally shared. Otherwise `False`. |
| |
| Notes |
| ----- |
| If two axes with different positions are sharing an axis, they can be |
| referred to as *externally* sharing the common axis. |
| |
| If two axes sharing an axis also have the same position, they can be |
| referred to as *internally* sharing the common axis (a.k.a twinning). |
| |
| _handle_shared_axes() is only interested in axes externally sharing an |
| axis, regardless of whether either of the axes is also internally sharing |
| with a third axis. |
| """ |
| if compare_axis == "x": |
| axes = ax1.get_shared_x_axes() |
| elif compare_axis == "y": |
| axes = ax1.get_shared_y_axes() |
| else: |
| raise ValueError( |
| "_has_externally_shared_axis() needs 'x' or 'y' as a second parameter" |
| ) |
|
|
| axes_siblings = axes.get_siblings(ax1) |
|
|
| |
| ax1_points = ax1.get_position().get_points() |
|
|
| for ax2 in axes_siblings: |
| if not np.array_equal(ax1_points, ax2.get_position().get_points()): |
| return True |
|
|
| return False |
|
|
|
|
| def handle_shared_axes( |
| axarr: Iterable[Axes], |
| nplots: int, |
| naxes: int, |
| nrows: int, |
| ncols: int, |
| sharex: bool, |
| sharey: bool, |
| ) -> None: |
| if nplots > 1: |
| row_num = lambda x: x.get_subplotspec().rowspan.start |
| col_num = lambda x: x.get_subplotspec().colspan.start |
|
|
| is_first_col = lambda x: x.get_subplotspec().is_first_col() |
|
|
| if nrows > 1: |
| try: |
| |
| |
| layout = np.zeros((nrows + 1, ncols + 1), dtype=np.bool_) |
| for ax in axarr: |
| layout[row_num(ax), col_num(ax)] = ax.get_visible() |
|
|
| for ax in axarr: |
| |
| |
| |
| if not layout[row_num(ax) + 1, col_num(ax)]: |
| continue |
| if sharex or _has_externally_shared_axis(ax, "x"): |
| _remove_labels_from_axis(ax.xaxis) |
|
|
| except IndexError: |
| |
| |
| is_last_row = lambda x: x.get_subplotspec().is_last_row() |
| for ax in axarr: |
| if is_last_row(ax): |
| continue |
| if sharex or _has_externally_shared_axis(ax, "x"): |
| _remove_labels_from_axis(ax.xaxis) |
|
|
| if ncols > 1: |
| for ax in axarr: |
| |
| |
| |
| if is_first_col(ax): |
| continue |
| if sharey or _has_externally_shared_axis(ax, "y"): |
| _remove_labels_from_axis(ax.yaxis) |
|
|
|
|
| def flatten_axes(axes: Axes | Sequence[Axes]) -> np.ndarray: |
| if not is_list_like(axes): |
| return np.array([axes]) |
| elif isinstance(axes, (np.ndarray, ABCIndex)): |
| return np.asarray(axes).ravel() |
| return np.array(axes) |
|
|
|
|
| def set_ticks_props( |
| axes: Axes | Sequence[Axes], |
| xlabelsize: int | None = None, |
| xrot=None, |
| ylabelsize: int | None = None, |
| yrot=None, |
| ): |
| import matplotlib.pyplot as plt |
|
|
| for ax in flatten_axes(axes): |
| if xlabelsize is not None: |
| plt.setp(ax.get_xticklabels(), fontsize=xlabelsize) |
| if xrot is not None: |
| plt.setp(ax.get_xticklabels(), rotation=xrot) |
| if ylabelsize is not None: |
| plt.setp(ax.get_yticklabels(), fontsize=ylabelsize) |
| if yrot is not None: |
| plt.setp(ax.get_yticklabels(), rotation=yrot) |
| return axes |
|
|
|
|
| def get_all_lines(ax: Axes) -> list[Line2D]: |
| lines = ax.get_lines() |
|
|
| if hasattr(ax, "right_ax"): |
| lines += ax.right_ax.get_lines() |
|
|
| if hasattr(ax, "left_ax"): |
| lines += ax.left_ax.get_lines() |
|
|
| return lines |
|
|
|
|
| def get_xlim(lines: Iterable[Line2D]) -> tuple[float, float]: |
| left, right = np.inf, -np.inf |
| for line in lines: |
| x = line.get_xdata(orig=False) |
| left = min(np.nanmin(x), left) |
| right = max(np.nanmax(x), right) |
| return left, right |
|
|