| """ |
| The matplotlib plotter implementation for all the primitive tasks (in our case: lines, |
| dots and trails) |
| """ |
| from typing import Any, Callable, Dict, List |
|
|
| import matplotlib.pyplot as plt |
| import mpl_toolkits.mplot3d.axes3d as p3 |
|
|
| import numpy as np |
|
|
| from .core import BasePlotter, BasePlotterTask |
|
|
|
|
class Matplotlib2DPlotter(BasePlotter):
    """Matplotlib backend that renders 2D line, dot and trail tasks on one axes.

    Artists are created once per task, cached under ``task.task_name`` and
    updated in place on later draws. Tasks are dispatched on ``task.task_type``
    (``"Draw2DLines"``, ``"Draw2DDots"``, ``"Draw2DTrail"``).
    """

    _fig: plt.Figure
    _ax: plt.Axes

    # task_name -> artist returned by ``ax.plot`` (dots) or a list of them
    # (lines / trails).
    _artist_cache: Dict[str, Any]

    # task_type -> bound method creating / updating the artists of that type.
    _create_impl_callables: Dict[str, Callable]
    _update_impl_callables: Dict[str, Callable]

    def __init__(self, task: "BasePlotterTask") -> None:
        """Create the figure/axes, register the task handlers and init the base.

        Args:
            task: The plotter task handed on to ``BasePlotter.__init__``.
        """
        fig, ax = plt.subplots()
        self._fig = fig
        self._ax = ax
        self._artist_cache = {}

        self._create_impl_callables = {
            "Draw2DLines": self._lines_create_impl,
            "Draw2DDots": self._dots_create_impl,
            "Draw2DTrail": self._trail_create_impl,
        }
        self._update_impl_callables = {
            "Draw2DLines": self._lines_update_impl,
            "Draw2DDots": self._dots_update_impl,
            "Draw2DTrail": self._trail_update_impl,
        }
        self._init_lim()
        # NOTE(review): the base initializer runs last, after all state above
        # exists — presumably because it triggers the first create pass;
        # confirm against BasePlotter.
        super().__init__(task)

    @property
    def ax(self):
        """The matplotlib axes this plotter draws on."""
        return self._ax

    @property
    def fig(self):
        """The matplotlib figure owning the axes."""
        return self._fig

    def show(self):
        """Display all open figures via ``plt.show()``."""
        plt.show()

    def _min(self, x, y):
        """Return the smaller of ``x`` and ``y``, treating ``None`` as "unset"."""
        if x is None:
            return y
        if y is None:
            return x
        return min(x, y)

    def _max(self, x, y):
        """Return the larger of ``x`` and ``y``, treating ``None`` as "unset"."""
        if x is None:
            return y
        if y is None:
            return x
        return max(x, y)

    def _init_lim(self):
        """Reset the accumulated data bounds to "unset" (``None``)."""
        self._curr_x_min = None
        self._curr_y_min = None
        self._curr_x_max = None
        self._curr_y_max = None

    def _update_lim(self, xs, ys):
        """Widen the accumulated data bounds so they include ``xs`` / ``ys``.

        Empty inputs are ignored: ``np.min`` / ``np.max`` raise ``ValueError``
        on zero-size arrays, and an empty artist has no extent to contribute.
        """
        if np.size(xs) == 0 or np.size(ys) == 0:
            return
        self._curr_x_min = self._min(np.min(xs), self._curr_x_min)
        self._curr_y_min = self._min(np.min(ys), self._curr_y_min)
        self._curr_x_max = self._max(np.max(xs), self._curr_x_max)
        self._curr_y_max = self._max(np.max(ys), self._curr_y_max)

    def _set_lim(self):
        """Apply the accumulated bounds to the axes, then reset them.

        The axes limits are left untouched when no bounds were accumulated
        since the last reset.
        """
        bounds = (
            self._curr_x_min,
            self._curr_x_max,
            self._curr_y_min,
            self._curr_y_max,
        )
        if all(bound is not None for bound in bounds):
            self._ax.set_xlim(self._curr_x_min, self._curr_x_max)
            self._ax.set_ylim(self._curr_y_min, self._curr_y_max)
        self._init_lim()

    @staticmethod
    def _lines_extract_xy_impl(index, lines_task):
        """Return the ``(xs, ys)`` coordinates of line ``index`` of a lines task."""
        return lines_task[index, :, 0], lines_task[index, :, 1]

    @staticmethod
    def _trail_extract_xy_impl(index, trail_task):
        """Return ``(xs, ys)`` of trail segment ``index`` (points ``index``, ``index + 1``)."""
        return (trail_task[index : index + 2, 0], trail_task[index : index + 2, 1])

    def _lines_create_impl(self, lines_task):
        """Plot one artist per line of ``lines_task`` and cache the list."""
        color = lines_task.color
        self._artist_cache[lines_task.task_name] = [
            self._ax.plot(
                *Matplotlib2DPlotter._lines_extract_xy_impl(i, lines_task),
                color=color,
                linewidth=lines_task.line_width,
                alpha=lines_task.alpha
            )[0]
            for i in range(len(lines_task))
        ]

    def _lines_update_impl(self, lines_task):
        """Push the current line coordinates into the cached artists."""
        lines_artists = self._artist_cache[lines_task.task_name]
        for i in range(len(lines_task)):
            artist = lines_artists[i]
            xs, ys = Matplotlib2DPlotter._lines_extract_xy_impl(i, lines_task)
            artist.set_data(xs, ys)
            if lines_task.influence_lim:
                self._update_lim(xs, ys)

    def _dots_create_impl(self, dots_task):
        """Plot all dots of ``dots_task`` as one marker-only artist and cache it."""
        color = dots_task.color
        self._artist_cache[dots_task.task_name] = self._ax.plot(
            dots_task[:, 0],
            dots_task[:, 1],
            c=color,
            linestyle="",
            marker=".",
            markersize=dots_task.marker_size,
            alpha=dots_task.alpha,
        )[0]

    def _dots_update_impl(self, dots_task):
        """Push the current dot coordinates into the cached artist."""
        dots_artist = self._artist_cache[dots_task.task_name]
        dots_artist.set_data(dots_task[:, 0], dots_task[:, 1])
        if dots_task.influence_lim:
            self._update_lim(dots_task[:, 0], dots_task[:, 1])

    def _trail_create_impl(self, trail_task):
        """Plot one artist per trail segment, fading alpha along the trail.

        Segment 0 gets the task's full alpha and the fade decreases linearly
        with the segment index, reaching zero on the last segment when there
        are two or more segments.
        """
        trail_length = len(trail_task) - 1
        # A trail with a single segment would otherwise divide by zero; with
        # this guard that lone segment simply keeps the full alpha.
        fade_span = max(trail_length - 1, 1)
        self._artist_cache[trail_task.task_name] = [
            self._ax.plot(
                *Matplotlib2DPlotter._trail_extract_xy_impl(i, trail_task),
                color=trail_task.color,
                linewidth=trail_task.line_width,
                alpha=trail_task.alpha * (1.0 - i / fade_span)
            )[0]
            for i in range(trail_length)
        ]

    def _trail_update_impl(self, trail_task):
        """Push the current segment coordinates into the cached trail artists."""
        trails_artists = self._artist_cache[trail_task.task_name]
        for i in range(len(trail_task) - 1):
            artist = trails_artists[i]
            xs, ys = Matplotlib2DPlotter._trail_extract_xy_impl(i, trail_task)
            artist.set_data(xs, ys)
            if trail_task.influence_lim:
                self._update_lim(xs, ys)

    def _create_impl(self, task_list):
        """Create artists for every task in ``task_list``, then redraw.

        Raises:
            KeyError: If a task's ``task_type`` has no registered handler.
        """
        for task in task_list:
            self._create_impl_callables[task.task_type](task)
        self._draw()

    def _update_impl(self, task_list):
        """Update the artists of every task in ``task_list``, then redraw.

        Raises:
            KeyError: If a task's ``task_type`` has no registered handler, or
                the task's artists were never created.
        """
        for task in task_list:
            self._update_impl_callables[task.task_type](task)
        self._draw()

    def _set_aspect_equal_2d(self, zero_centered=True):
        """Make the x and y limits span the same range around a common center.

        Args:
            zero_centered: When true the limits are centered on the origin;
                otherwise on the midpoint of the current limits of each axis.
        """
        xlim = self._ax.get_xlim()
        ylim = self._ax.get_ylim()

        if not zero_centered:
            xmean = np.mean(xlim)
            ymean = np.mean(ylim)
        else:
            xmean = 0
            ymean = 0

        # Largest distance from the center to any current limit, on either axis.
        plot_radius = max(
            abs(lim - mean_)
            for lims, mean_ in ((xlim, xmean), (ylim, ymean))
            for lim in lims
        )

        self._ax.set_xlim([xmean - plot_radius, xmean + plot_radius])
        self._ax.set_ylim([ymean - plot_radius, ymean + plot_radius])

    def _draw(self):
        """Apply pending limits, equalize the axis ranges and refresh the canvas."""
        self._set_lim()
        self._set_aspect_equal_2d()
        self._fig.canvas.draw()
        self._fig.canvas.flush_events()
        # A very short pause lets the GUI event loop run so the window updates.
        plt.pause(0.00001)
|
|
|
|
class Matplotlib3DPlotter(BasePlotter):
    """Matplotlib backend that renders 3D line, dot and trail tasks on one axes.

    Artists are created once per task, cached under ``task.task_name`` and
    updated in place on later draws. Tasks are dispatched on ``task.task_type``
    (``"Draw3DLines"``, ``"Draw3DDots"``, ``"Draw3DTrail"``).
    """

    _fig: plt.Figure
    _ax: p3.Axes3D

    # task_name -> artist returned by ``ax.plot`` (dots) or a list of them
    # (lines / trails).
    _artist_cache: Dict[str, Any]

    # task_type -> bound method creating / updating the artists of that type.
    _create_impl_callables: Dict[str, Callable]
    _update_impl_callables: Dict[str, Callable]

    def __init__(self, task: "BasePlotterTask") -> None:
        """Create the figure and 3D axes, register the handlers and init the base.

        Args:
            task: The plotter task handed on to ``BasePlotter.__init__``.
        """
        self._fig = plt.figure()
        # ``Axes3D(fig)`` no longer attaches itself to the figure on current
        # matplotlib, which leaves the window blank. ``add_axes`` registers
        # the axes explicitly; the rectangle matches the full-figure default
        # that ``Axes3D(fig)`` used.
        self._ax = self._fig.add_axes([0.0, 0.0, 1.0, 1.0], projection="3d")
        self._artist_cache = {}

        self._create_impl_callables = {
            "Draw3DLines": self._lines_create_impl,
            "Draw3DDots": self._dots_create_impl,
            "Draw3DTrail": self._trail_create_impl,
        }
        self._update_impl_callables = {
            "Draw3DLines": self._lines_update_impl,
            "Draw3DDots": self._dots_update_impl,
            "Draw3DTrail": self._trail_update_impl,
        }
        self._init_lim()
        # NOTE(review): the base initializer runs last, after all state above
        # exists — presumably because it triggers the first create pass;
        # confirm against BasePlotter.
        super().__init__(task)

    @property
    def ax(self):
        """The matplotlib 3D axes this plotter draws on."""
        return self._ax

    @property
    def fig(self):
        """The matplotlib figure owning the axes."""
        return self._fig

    def show(self):
        """Display all open figures via ``plt.show()``."""
        plt.show()

    def _min(self, x, y):
        """Return the smaller of ``x`` and ``y``, treating ``None`` as "unset"."""
        if x is None:
            return y
        if y is None:
            return x
        return min(x, y)

    def _max(self, x, y):
        """Return the larger of ``x`` and ``y``, treating ``None`` as "unset"."""
        if x is None:
            return y
        if y is None:
            return x
        return max(x, y)

    def _init_lim(self):
        """Reset the accumulated data bounds to "unset" (``None``)."""
        self._curr_x_min = None
        self._curr_y_min = None
        self._curr_z_min = None
        self._curr_x_max = None
        self._curr_y_max = None
        self._curr_z_max = None

    def _update_lim(self, xs, ys, zs):
        """Widen the accumulated data bounds so they include ``xs`` / ``ys`` / ``zs``.

        Empty inputs are ignored: ``np.min`` / ``np.max`` raise ``ValueError``
        on zero-size arrays, and an empty artist has no extent to contribute.
        """
        if np.size(xs) == 0 or np.size(ys) == 0 or np.size(zs) == 0:
            return
        self._curr_x_min = self._min(np.min(xs), self._curr_x_min)
        self._curr_y_min = self._min(np.min(ys), self._curr_y_min)
        self._curr_z_min = self._min(np.min(zs), self._curr_z_min)
        self._curr_x_max = self._max(np.max(xs), self._curr_x_max)
        self._curr_y_max = self._max(np.max(ys), self._curr_y_max)
        self._curr_z_max = self._max(np.max(zs), self._curr_z_max)

    def _set_lim(self):
        """Apply the accumulated bounds to the axes, then reset them.

        The axes limits are left untouched when no bounds were accumulated
        since the last reset.
        """
        bounds = (
            self._curr_x_min,
            self._curr_x_max,
            self._curr_y_min,
            self._curr_y_max,
            self._curr_z_min,
            self._curr_z_max,
        )
        if all(bound is not None for bound in bounds):
            self._ax.set_xlim3d(self._curr_x_min, self._curr_x_max)
            self._ax.set_ylim3d(self._curr_y_min, self._curr_y_max)
            self._ax.set_zlim3d(self._curr_z_min, self._curr_z_max)
        self._init_lim()

    @staticmethod
    def _lines_extract_xyz_impl(index, lines_task):
        """Return the ``(xs, ys, zs)`` coordinates of line ``index`` of a lines task."""
        return lines_task[index, :, 0], lines_task[index, :, 1], lines_task[index, :, 2]

    @staticmethod
    def _trail_extract_xyz_impl(index, trail_task):
        """Return ``(xs, ys, zs)`` of trail segment ``index`` (points ``index``, ``index + 1``)."""
        return (
            trail_task[index : index + 2, 0],
            trail_task[index : index + 2, 1],
            trail_task[index : index + 2, 2],
        )

    def _lines_create_impl(self, lines_task):
        """Plot one artist per line of ``lines_task`` and cache the list."""
        color = lines_task.color
        self._artist_cache[lines_task.task_name] = [
            self._ax.plot(
                *Matplotlib3DPlotter._lines_extract_xyz_impl(i, lines_task),
                color=color,
                linewidth=lines_task.line_width,
                alpha=lines_task.alpha
            )[0]
            for i in range(len(lines_task))
        ]

    def _lines_update_impl(self, lines_task):
        """Push the current line coordinates into the cached artists."""
        lines_artists = self._artist_cache[lines_task.task_name]
        for i in range(len(lines_task)):
            artist = lines_artists[i]
            xs, ys, zs = Matplotlib3DPlotter._lines_extract_xyz_impl(i, lines_task)
            # The 3D line artist takes x/y through set_data and z separately.
            artist.set_data(xs, ys)
            artist.set_3d_properties(zs)
            if lines_task.influence_lim:
                self._update_lim(xs, ys, zs)

    def _dots_create_impl(self, dots_task):
        """Plot all dots of ``dots_task`` as one marker-only artist and cache it."""
        color = dots_task.color
        self._artist_cache[dots_task.task_name] = self._ax.plot(
            dots_task[:, 0],
            dots_task[:, 1],
            dots_task[:, 2],
            c=color,
            linestyle="",
            marker=".",
            markersize=dots_task.marker_size,
            alpha=dots_task.alpha,
        )[0]

    def _dots_update_impl(self, dots_task):
        """Push the current dot coordinates into the cached artist."""
        dots_artist = self._artist_cache[dots_task.task_name]
        dots_artist.set_data(dots_task[:, 0], dots_task[:, 1])
        dots_artist.set_3d_properties(dots_task[:, 2])
        if dots_task.influence_lim:
            self._update_lim(dots_task[:, 0], dots_task[:, 1], dots_task[:, 2])

    def _trail_create_impl(self, trail_task):
        """Plot one artist per trail segment, fading alpha along the trail.

        Segment 0 gets the task's full alpha and the fade decreases linearly
        with the segment index, reaching zero on the last segment when there
        are two or more segments.
        """
        trail_length = len(trail_task) - 1
        # A trail with a single segment would otherwise divide by zero; with
        # this guard that lone segment simply keeps the full alpha.
        fade_span = max(trail_length - 1, 1)
        self._artist_cache[trail_task.task_name] = [
            self._ax.plot(
                *Matplotlib3DPlotter._trail_extract_xyz_impl(i, trail_task),
                color=trail_task.color,
                linewidth=trail_task.line_width,
                alpha=trail_task.alpha * (1.0 - i / fade_span)
            )[0]
            for i in range(trail_length)
        ]

    def _trail_update_impl(self, trail_task):
        """Push the current segment coordinates into the cached trail artists."""
        trails_artists = self._artist_cache[trail_task.task_name]
        for i in range(len(trail_task) - 1):
            artist = trails_artists[i]
            xs, ys, zs = Matplotlib3DPlotter._trail_extract_xyz_impl(i, trail_task)
            artist.set_data(xs, ys)
            artist.set_3d_properties(zs)
            if trail_task.influence_lim:
                self._update_lim(xs, ys, zs)

    def _create_impl(self, task_list):
        """Create artists for every task in ``task_list``, then redraw.

        Raises:
            KeyError: If a task's ``task_type`` has no registered handler.
        """
        for task in task_list:
            self._create_impl_callables[task.task_type](task)
        self._draw()

    def _update_impl(self, task_list):
        """Update the artists of every task in ``task_list``, then redraw.

        Raises:
            KeyError: If a task's ``task_type`` has no registered handler, or
                the task's artists were never created.
        """
        for task in task_list:
            self._update_impl_callables[task.task_type](task)
        self._draw()

    def _set_aspect_equal_3d(self):
        """Make the x, y and z limits span the same range around their midpoints."""
        xlim = self._ax.get_xlim3d()
        ylim = self._ax.get_ylim3d()
        zlim = self._ax.get_zlim3d()

        xmean = np.mean(xlim)
        ymean = np.mean(ylim)
        zmean = np.mean(zlim)

        # Largest half-extent over the three axes; used as the common radius.
        plot_radius = max(
            abs(lim - mean_)
            for lims, mean_ in ((xlim, xmean), (ylim, ymean), (zlim, zmean))
            for lim in lims
        )

        self._ax.set_xlim3d([xmean - plot_radius, xmean + plot_radius])
        self._ax.set_ylim3d([ymean - plot_radius, ymean + plot_radius])
        self._ax.set_zlim3d([zmean - plot_radius, zmean + plot_radius])

    def _draw(self):
        """Apply pending limits, equalize the axis ranges and refresh the canvas."""
        self._set_lim()
        self._set_aspect_equal_3d()
        self._fig.canvas.draw()
        self._fig.canvas.flush_events()
        # A very short pause lets the GUI event loop run so the window updates.
        plt.pause(0.00001)
|
|