""" The matplotlib plotter implementation for all the primitive tasks (in our case: lines and dots) """ from typing import Any, Callable, Dict, List import matplotlib.pyplot as plt import mpl_toolkits.mplot3d.axes3d as p3 import numpy as np from .core import BasePlotter, BasePlotterTask class Matplotlib2DPlotter(BasePlotter): _fig: plt.figure # plt figure _ax: plt.axis # plt axis # stores artist objects for each task (task name as the key) _artist_cache: Dict[str, Any] # callables for each task primitives _create_impl_callables: Dict[str, Callable] _update_impl_callables: Dict[str, Callable] def __init__(self, task: "BasePlotterTask") -> None: fig, ax = plt.subplots() self._fig = fig self._ax = ax self._artist_cache = {} self._create_impl_callables = { "Draw2DLines": self._lines_create_impl, "Draw2DDots": self._dots_create_impl, "Draw2DTrail": self._trail_create_impl, } self._update_impl_callables = { "Draw2DLines": self._lines_update_impl, "Draw2DDots": self._dots_update_impl, "Draw2DTrail": self._trail_update_impl, } self._init_lim() super().__init__(task) @property def ax(self): return self._ax @property def fig(self): return self._fig def show(self): plt.show() def _min(self, x, y): if x is None: return y if y is None: return x return min(x, y) def _max(self, x, y): if x is None: return y if y is None: return x return max(x, y) def _init_lim(self): self._curr_x_min = None self._curr_y_min = None self._curr_x_max = None self._curr_y_max = None def _update_lim(self, xs, ys): self._curr_x_min = self._min(np.min(xs), self._curr_x_min) self._curr_y_min = self._min(np.min(ys), self._curr_y_min) self._curr_x_max = self._max(np.max(xs), self._curr_x_max) self._curr_y_max = self._max(np.max(ys), self._curr_y_max) def _set_lim(self): if not ( self._curr_x_min is None or self._curr_x_max is None or self._curr_y_min is None or self._curr_y_max is None ): self._ax.set_xlim(self._curr_x_min, self._curr_x_max) self._ax.set_ylim(self._curr_y_min, self._curr_y_max) self._init_lim() @staticmethod def _lines_extract_xy_impl(index, lines_task): return lines_task[index, :, 0], lines_task[index, :, 1] @staticmethod def _trail_extract_xy_impl(index, trail_task): return (trail_task[index : index + 2, 0], trail_task[index : index + 2, 1]) def _lines_create_impl(self, lines_task): color = lines_task.color self._artist_cache[lines_task.task_name] = [ self._ax.plot( *Matplotlib2DPlotter._lines_extract_xy_impl(i, lines_task), color=color, linewidth=lines_task.line_width, alpha=lines_task.alpha )[0] for i in range(len(lines_task)) ] def _lines_update_impl(self, lines_task): lines_artists = self._artist_cache[lines_task.task_name] for i in range(len(lines_task)): artist = lines_artists[i] xs, ys = Matplotlib2DPlotter._lines_extract_xy_impl(i, lines_task) artist.set_data(xs, ys) if lines_task.influence_lim: self._update_lim(xs, ys) def _dots_create_impl(self, dots_task): color = dots_task.color self._artist_cache[dots_task.task_name] = self._ax.plot( dots_task[:, 0], dots_task[:, 1], c=color, linestyle="", marker=".", markersize=dots_task.marker_size, alpha=dots_task.alpha, )[0] def _dots_update_impl(self, dots_task): dots_artist = self._artist_cache[dots_task.task_name] dots_artist.set_data(dots_task[:, 0], dots_task[:, 1]) if dots_task.influence_lim: self._update_lim(dots_task[:, 0], dots_task[:, 1]) def _trail_create_impl(self, trail_task): color = trail_task.color trail_length = len(trail_task) - 1 self._artist_cache[trail_task.task_name] = [ self._ax.plot( *Matplotlib2DPlotter._trail_extract_xy_impl(i, trail_task), color=trail_task.color, linewidth=trail_task.line_width, alpha=trail_task.alpha * (1.0 - i / (trail_length - 1)) )[0] for i in range(trail_length) ] def _trail_update_impl(self, trail_task): trails_artists = self._artist_cache[trail_task.task_name] for i in range(len(trail_task) - 1): artist = trails_artists[i] xs, ys = Matplotlib2DPlotter._trail_extract_xy_impl(i, trail_task) artist.set_data(xs, ys) if trail_task.influence_lim: self._update_lim(xs, ys) def _create_impl(self, task_list): for task in task_list: self._create_impl_callables[task.task_type](task) self._draw() def _update_impl(self, task_list): for task in task_list: self._update_impl_callables[task.task_type](task) self._draw() def _set_aspect_equal_2d(self, zero_centered=True): xlim = self._ax.get_xlim() ylim = self._ax.get_ylim() if not zero_centered: xmean = np.mean(xlim) ymean = np.mean(ylim) else: xmean = 0 ymean = 0 plot_radius = max( [ abs(lim - mean_) for lims, mean_ in ((xlim, xmean), (ylim, ymean)) for lim in lims ] ) self._ax.set_xlim([xmean - plot_radius, xmean + plot_radius]) self._ax.set_ylim([ymean - plot_radius, ymean + plot_radius]) def _draw(self): self._set_lim() self._set_aspect_equal_2d() self._fig.canvas.draw() self._fig.canvas.flush_events() plt.pause(0.00001) class Matplotlib3DPlotter(BasePlotter): _fig: plt.figure # plt figure _ax: p3.Axes3D # plt 3d axis # stores artist objects for each task (task name as the key) _artist_cache: Dict[str, Any] # callables for each task primitives _create_impl_callables: Dict[str, Callable] _update_impl_callables: Dict[str, Callable] def __init__(self, task: "BasePlotterTask") -> None: self._fig = plt.figure() self._ax = p3.Axes3D(self._fig) self._artist_cache = {} self._create_impl_callables = { "Draw3DLines": self._lines_create_impl, "Draw3DDots": self._dots_create_impl, "Draw3DTrail": self._trail_create_impl, } self._update_impl_callables = { "Draw3DLines": self._lines_update_impl, "Draw3DDots": self._dots_update_impl, "Draw3DTrail": self._trail_update_impl, } self._init_lim() super().__init__(task) @property def ax(self): return self._ax @property def fig(self): return self._fig def show(self): plt.show() def _min(self, x, y): if x is None: return y if y is None: return x return min(x, y) def _max(self, x, y): if x is None: return y if y is None: return x return max(x, y) def _init_lim(self): self._curr_x_min = None self._curr_y_min = None self._curr_z_min = None self._curr_x_max = None self._curr_y_max = None self._curr_z_max = None def _update_lim(self, xs, ys, zs): self._curr_x_min = self._min(np.min(xs), self._curr_x_min) self._curr_y_min = self._min(np.min(ys), self._curr_y_min) self._curr_z_min = self._min(np.min(zs), self._curr_z_min) self._curr_x_max = self._max(np.max(xs), self._curr_x_max) self._curr_y_max = self._max(np.max(ys), self._curr_y_max) self._curr_z_max = self._max(np.max(zs), self._curr_z_max) def _set_lim(self): if not ( self._curr_x_min is None or self._curr_x_max is None or self._curr_y_min is None or self._curr_y_max is None or self._curr_z_min is None or self._curr_z_max is None ): self._ax.set_xlim3d(self._curr_x_min, self._curr_x_max) self._ax.set_ylim3d(self._curr_y_min, self._curr_y_max) self._ax.set_zlim3d(self._curr_z_min, self._curr_z_max) self._init_lim() @staticmethod def _lines_extract_xyz_impl(index, lines_task): return lines_task[index, :, 0], lines_task[index, :, 1], lines_task[index, :, 2] @staticmethod def _trail_extract_xyz_impl(index, trail_task): return ( trail_task[index : index + 2, 0], trail_task[index : index + 2, 1], trail_task[index : index + 2, 2], ) def _lines_create_impl(self, lines_task): color = lines_task.color self._artist_cache[lines_task.task_name] = [ self._ax.plot( *Matplotlib3DPlotter._lines_extract_xyz_impl(i, lines_task), color=color, linewidth=lines_task.line_width, alpha=lines_task.alpha )[0] for i in range(len(lines_task)) ] def _lines_update_impl(self, lines_task): lines_artists = self._artist_cache[lines_task.task_name] for i in range(len(lines_task)): artist = lines_artists[i] xs, ys, zs = Matplotlib3DPlotter._lines_extract_xyz_impl(i, lines_task) artist.set_data(xs, ys) artist.set_3d_properties(zs) if lines_task.influence_lim: self._update_lim(xs, ys, zs) def _dots_create_impl(self, dots_task): color = dots_task.color self._artist_cache[dots_task.task_name] = self._ax.plot( dots_task[:, 0], dots_task[:, 1], dots_task[:, 2], c=color, linestyle="", marker=".", markersize=dots_task.marker_size, alpha=dots_task.alpha, )[0] def _dots_update_impl(self, dots_task): dots_artist = self._artist_cache[dots_task.task_name] dots_artist.set_data(dots_task[:, 0], dots_task[:, 1]) dots_artist.set_3d_properties(dots_task[:, 2]) if dots_task.influence_lim: self._update_lim(dots_task[:, 0], dots_task[:, 1], dots_task[:, 2]) def _trail_create_impl(self, trail_task): color = trail_task.color trail_length = len(trail_task) - 1 self._artist_cache[trail_task.task_name] = [ self._ax.plot( *Matplotlib3DPlotter._trail_extract_xyz_impl(i, trail_task), color=trail_task.color, linewidth=trail_task.line_width, alpha=trail_task.alpha * (1.0 - i / (trail_length - 1)) )[0] for i in range(trail_length) ] def _trail_update_impl(self, trail_task): trails_artists = self._artist_cache[trail_task.task_name] for i in range(len(trail_task) - 1): artist = trails_artists[i] xs, ys, zs = Matplotlib3DPlotter._trail_extract_xyz_impl(i, trail_task) artist.set_data(xs, ys) artist.set_3d_properties(zs) if trail_task.influence_lim: self._update_lim(xs, ys, zs) def _create_impl(self, task_list): for task in task_list: self._create_impl_callables[task.task_type](task) self._draw() def _update_impl(self, task_list): for task in task_list: self._update_impl_callables[task.task_type](task) self._draw() def _set_aspect_equal_3d(self): xlim = self._ax.get_xlim3d() ylim = self._ax.get_ylim3d() zlim = self._ax.get_zlim3d() xmean = np.mean(xlim) ymean = np.mean(ylim) zmean = np.mean(zlim) plot_radius = max( [ abs(lim - mean_) for lims, mean_ in ((xlim, xmean), (ylim, ymean), (zlim, zmean)) for lim in lims ] ) self._ax.set_xlim3d([xmean - plot_radius, xmean + plot_radius]) self._ax.set_ylim3d([ymean - plot_radius, ymean + plot_radius]) self._ax.set_zlim3d([zmean - plot_radius, zmean + plot_radius]) def _draw(self): self._set_lim() self._set_aspect_equal_3d() self._fig.canvas.draw() self._fig.canvas.flush_events() plt.pause(0.00001)