| import functools |
| import traceback |
| from copy import deepcopy |
|
|
| import matplotlib.pyplot as plt |
| import numpy as np |
| from matplotlib.widgets import Button |
| from omegaconf import OmegaConf |
|
|
| from ..datasets.base_dataset import collate |
| from ..models.cache_loader import CacheLoader |
| from .tools import RadioHideTool |
|
|
| |
| |
|
|
|
|
class GlobalFrame:
    """Interactive scatter plot comparing per-sample metrics of several models.

    Every model in ``results`` becomes one series of points, placed by the
    metric named in ``conf.x`` (x-axis) and ``conf.y`` (y-axis). Hovering or
    picking a point connects it to the same sample index in every series;
    picking additionally opens a ``child_frame`` window for that sample when a
    child frame class was given.
    """

    default_conf = {
        "x": "???",
        "y": "???",
        "diff": False,
        "child": {},
        "remove_outliers": False,
    }

    # Frame class instantiated for a picked sample (None disables child views).
    child_frame = None

    # Class-level defaults kept for backward compatibility; ``__init__`` gives
    # every instance its own containers so frames do not share state.
    childs = []
    lines = []
    scatters = {}

    def __init__(self, conf, results, loader, predictions, title=None, child_frame=None):
        """Build the figure and register the x/y metric selector tools.

        Args:
            conf: user configuration merged on top of ``default_conf``.
                ``conf["x"]`` / ``conf["y"]`` choose the initial axis metrics;
                a falsy value selects the first / second metric in sorted order.
            results: mapping model name -> mapping metric name -> per-sample
                values (indexable by sample index).
            loader: object whose ``dataset`` is indexed to fetch a sample for a
                child frame; ``len(loader)`` bounds left/right navigation.
            predictions: mapping model name -> path of cached predictions,
                loaded through ``CacheLoader`` when a child frame is spawned.
            title: optional window title.
            child_frame: optional frame class instantiated for a picked sample.
        """
        # Per-instance mutable state (the class attributes would otherwise be
        # shared by every GlobalFrame).
        self.childs = []
        self.lines = []
        self.scatters = {}

        self.child_frame = child_frame
        defaults = dict(self.default_conf)
        if self.child_frame is not None:
            # Take the child's defaults for the "child" section. Work on a copy:
            # assigning into ``self.default_conf`` would mutate the dict shared
            # by the class and leak into every later frame.
            defaults["child"] = self.child_frame.default_conf

        self.conf = OmegaConf.merge(defaults, conf)
        self.results = results
        self.loader = loader
        self.predictions = predictions

        # Union of metric names over all models, in sorted order.
        metrics = set()
        for per_model in results.values():
            metrics.update(per_model.keys())
        self.metrics = sorted(metrics)

        # Fall back to the first two metrics when no axis metric is configured.
        self.conf.x = conf["x"] or self.metrics[0]
        self.conf.y = conf["y"] or self.metrics[1]

        assert self.conf.x in self.metrics, f"unknown x metric: {self.conf.x}"
        assert self.conf.y in self.metrics, f"unknown y metric: {self.conf.y}"

        self.names = list(results)
        self.fig, self.axes = self.init_frame()
        if title is not None:
            self.fig.canvas.manager.set_window_title(title)

        toolmanager = self.fig.canvas.manager.toolmanager
        self.xradios = toolmanager.add_tool(
            "x",
            RadioHideTool,
            options=self.metrics,
            callback_fn=self.update_x,
            active=self.conf.x,
            keymap="x",
        )
        self.yradios = toolmanager.add_tool(
            "y",
            RadioHideTool,
            options=self.metrics,
            callback_fn=self.update_y,
            active=self.conf.y,
            keymap="y",
        )
        toolbar = self.fig.canvas.manager.toolbar
        if toolbar is not None:
            toolbar.add_tool("x", "navigation")
            toolbar.add_tool("y", "navigation")

    def init_frame(self):
        """Create the figure, the "diff_only" button and the event hooks.

        Returns:
            ``(fig, ax)`` from ``plt.subplots()``.
        """
        fig, ax = plt.subplots()
        ax.set_title("click on points")
        diffb_ax = fig.add_axes([0.01, 0.02, 0.12, 0.06])
        self.diffb = Button(diffb_ax, label="diff_only")
        self.diffb.on_clicked(self.diff_clicked)
        fig.canvas.mpl_connect("pick_event", self.on_scatter_pick)
        fig.canvas.mpl_connect("motion_notify_event", self.hover)
        return fig, ax

    def draw(self):
        """Redraw every model's series for the current x/y metrics.

        Numeric metrics get a solid mean line and a hidden dashed median line.
        For a categorical x-axis with a numeric y-axis, a step line joins the
        per-category mean of y. In diff mode, numeric metrics are plotted
        relative to the first model in ``self.names``.
        """
        self.scatters = {}
        self.axes.clear()
        # axes.clear() already discarded any highlight lines from hover/pick.
        self.lines = []
        self.axes.set_xlabel(self.conf.x)
        self.axes.set_ylabel(self.conf.y)

        first = self.results[self.names[0]]
        # A metric is treated as categorical when its values are str/bytes.
        x_cat = isinstance(first[self.conf.x][0], (bytes, str))
        y_cat = isinstance(first[self.conf.y][0], (bytes, str))

        refx = 0.0
        refy = 0.0
        if self.conf.diff:
            if not x_cat:
                refx = np.array(first[self.conf.x])
            if not y_cat:
                refy = np.array(first[self.conf.y])

        # Only set by the categorical-x / numeric-y branch below.
        xunique = None
        sort_ax = None
        x = None
        for name in self.results:
            x = np.array(self.results[name][self.conf.x])
            y = np.array(self.results[name][self.conf.y])

            # Categorical labels made only of digits are plotted as integers.
            if x_cat and np.char.isdigit(x.astype(str)).all():
                x = x.astype(int)
            if y_cat and np.char.isdigit(y.astype(str)).all():
                y = y.astype(int)

            if not x_cat:
                x = x - refx
            if not y_cat:
                y = y - refy

            (s,) = self.axes.plot(x, y, "o", markersize=3, label=name, picker=True, pickradius=5)
            self.scatters[name] = s
            color = s.get_color()

            if x_cat and not y_cat:
                # Mean y per x category, ordered by first appearance of the category.
                xunique, first_idx, xinv, xcount = np.unique(
                    x, return_inverse=True, return_counts=True, return_index=True
                )
                ysum = np.bincount(xinv, weights=y)
                sort_ax = np.argsort(first_idx)
                self.axes.step(
                    xunique[sort_ax],
                    (ysum / xcount)[sort_ax],
                    where="mid",
                    color=color,
                )

            if not x_cat:
                xavg = np.nan_to_num(x).mean()
                self.axes.axvline(xavg, c=color, zorder=1, alpha=1.0)
                # x is already relative to refx; subtracting refx again would
                # apply the diff offset twice.
                xmed = np.median(x)
                self.axes.axvline(
                    xmed,
                    c=color,
                    zorder=0,
                    alpha=0.5,
                    linestyle="dashed",
                    visible=False,
                )

            if not y_cat:
                yavg = np.nan_to_num(y).mean()
                self.axes.axhline(yavg, c=color, zorder=1, alpha=0.5)
                # y is already relative to refy (see the x case above).
                ymed = np.median(y)
                self.axes.axhline(
                    ymed,
                    c=color,
                    zorder=0,
                    alpha=0.5,
                    linestyle="dashed",
                    visible=False,
                )

        # Rotate many categorical tick labels; the category list only exists
        # when the step-line branch ran (categorical x with numeric y).
        if (
            x_cat
            and xunique is not None
            and x.dtype == object
            and xunique.shape[0] > 5
        ):
            self.axes.set_xticklabels(xunique[sort_ax], rotation=90)
        self.axes.legend()

    def on_scatter_pick(self, handle):
        """Spawn a child view for the picked point.

        Any exception is printed and terminates the program (exit status 0,
        as before).
        """
        try:
            art = handle.artist
            try:
                event = handle.mouseevent.button.value
            except AttributeError:
                # The mouse event carries no button enum (e.g. ``button`` is
                # None); nothing to spawn.
                return
            name = art.get_label()
            ind = handle.ind[0]
            self.spawn_child(name, ind, event=event)
        except Exception:
            traceback.print_exc()
            # NOTE(review): exits with status 0 even though an error occurred.
            raise SystemExit(0)

    def spawn_child(self, model_name, ind, event=None):
        """Highlight sample ``ind`` in all series and open its child frame.

        Red lines connect the ``model_name`` point of sample ``ind`` to the
        same sample in every series. If ``child_frame`` is set, the sample is
        loaded, each model's cached predictions are read, and a new child
        frame is shown and tracked in ``self.childs``.
        """
        for line in self.lines:
            line.remove()
        self.lines = []

        source = self.scatters[model_name]
        x_source = source.get_xdata()[ind]
        y_source = source.get_ydata()[ind]
        for oname in self.names:
            other = self.scatters[oname]
            (ln,) = self.axes.plot(
                [x_source, other.get_xdata()[ind]],
                [y_source, other.get_ydata()[ind]],
                "r",
            )
            self.lines.append(ln)
        self.fig.canvas.draw_idle()

        if self.child_frame is None:
            return

        data = collate([self.loader.dataset[ind]])
        preds = {
            name: CacheLoader({"path": str(pfile), "add_data_path": False})(data)
            for name, pfile in self.predictions.items()
        }
        # Per-model metric values of this one sample ("names" is excluded).
        summaries_i = {
            name: {k: v[ind] for k, v in res.items() if k != "names"}
            for name, res in self.results.items()
        }
        frame = self.child_frame(
            self.conf.child,
            deepcopy(data),
            preds,
            title=str(data["name"][0]),
            event=event,
            summaries=summaries_i,
        )
        frame.fig.canvas.mpl_connect(
            "key_press_event",
            functools.partial(self.on_childframe_key_event, frame=frame, ind=ind, event=event),
        )
        self.childs.append(frame)
        frame.fig.show()

    def hover(self, event):
        """Draw guide lines from the hovered point to the same sample in every series."""
        if event.inaxes != self.axes:
            return

        for scatter in self.scatters.values():
            hit, info = scatter.contains(event)
            if not hit:
                continue
            ind = info["ind"][0]
            xdata, ydata = scatter.get_data()
            for line in self.lines:
                line.remove()
            self.lines = []
            for oname in self.names:
                other = self.scatters[oname]
                (ln,) = self.axes.plot(
                    [xdata[ind], other.get_xdata()[ind]],
                    [ydata[ind], other.get_ydata()[ind]],
                    "black",
                    zorder=0,
                    alpha=0.5,
                )
                self.lines.append(ln)
            self.fig.canvas.draw_idle()
            # Only the first series under the cursor is highlighted.
            break

    def diff_clicked(self, args):
        """Toggle plotting relative to the first model and redraw."""
        self.conf.diff = not self.conf.diff
        self.draw()
        self.fig.canvas.draw_idle()

    def update_x(self, x):
        """Switch the x-axis metric and redraw."""
        self.conf.x = x
        self.draw()
        self.fig.canvas.draw_idle()

    def update_y(self, y):
        """Switch the y-axis metric and redraw."""
        self.conf.y = y
        self.draw()
        self.fig.canvas.draw_idle()

    @staticmethod
    def _step_index(key, ind, num_samples):
        """Return the index one step from ``ind`` for a left/right key.

        A ``shift+`` prefix is ignored, so ``shift+right`` moves forward like
        ``right``; any other key moves backward. Wraps modulo ``num_samples``.
        """
        step = 1 if key.replace("shift+", "") == "right" else -1
        return (ind + step) % num_samples

    def on_childframe_key_event(self, key_event, frame, ind, event):
        """Handle keys pressed inside a child frame.

        ``delete`` closes the child. ``left``/``right`` replace it with the
        previous/next sample; with ``shift`` held, the current child stays
        open and the neighbouring sample opens in an additional window.
        """
        key = key_event.key
        if key == "delete":
            plt.close(frame.fig)
            self.childs.remove(frame)
        elif key in ("left", "right", "shift+left", "shift+right"):
            if not key.startswith("shift+"):
                plt.close(frame.fig)
                self.childs.remove(frame)
            new_ind = self._step_index(key, ind, len(self.loader))
            self.spawn_child(self.names[0], new_ind, event=event)
|
|