| import inspect |
| import sys |
| import warnings |
|
|
| import matplotlib.pyplot as plt |
| import numpy as np |
| import torch |
| from matplotlib.backend_tools import ToolToggleBase |
| from matplotlib.widgets import Button, RadioButtons |
|
|
| from siclib.geometry.camera import SimpleRadial as Camera |
| from siclib.geometry.gravity import Gravity |
| from siclib.geometry.perspective_fields import ( |
| get_latitude_field, |
| get_perspective_field, |
| get_up_field, |
| ) |
| from siclib.models.utils.metrics import latitude_error, up_error |
| from siclib.utils.conversions import rad2deg |
| from siclib.visualization.viz2d import ( |
| add_text, |
| plot_confidences, |
| plot_heatmaps, |
| plot_horizon_lines, |
| plot_latitudes, |
| plot_vector_fields, |
| ) |
|
|
| |
| |
|
|
# Switch matplotlib to the "toolmanager" toolbar so the ToolToggleBase tools
# below can be registered. Any warning raised while setting the rcParam is
# silenced (presumably matplotlib's "tool classes are experimental" notice);
# the filter is scoped to this assignment only.
with warnings.catch_warnings():
    warnings.filterwarnings("ignore")
    plt.rcParams["toolbar"] = "toolmanager"
|
|
|
|
class RadioHideTool(ToolToggleBase):
    """Toolbar toggle that shows a radio-button selector for ``options``.

    When toggled on, a ``RadioButtons`` widget listing ``options`` is placed in
    a new axes on the right edge of the figure. Clicking an option calls
    ``callback_fn(option)``; if the selector is visible it is torn down before
    and rebuilt after the callback, so the newly chosen option is shown active.

    Parameters
    ----------
    options : list, optional
        Labels shown in the selector. Defaults to an empty list.
    active : optional
        Option initially selected. ``None`` selects the first option.
    callback_fn : callable, optional
        Called with the clicked option label.
    keymap : str
        Key bound to this tool; overrides the class-level ``default_keymap``.

    Raises
    ------
    ValueError
        If ``active`` is given but is not one of ``options``.
    """

    default_keymap = "R"
    description = "Show by gid"
    default_toggled = False
    radio_group = "default"

    def __init__(
        self, *args, options=None, active=None, callback_fn=None, keymap="R", **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.f = 1.0  # width factor; kept for compatibility, not applied to the figure
        # ``None`` sentinel instead of a shared mutable ``[]`` default.
        self.options = options if options is not None else []
        self.callback_fn = callback_fn
        # ``is not None`` so a falsy option (e.g. "" or 0) can be the active one.
        self.active = self.options.index(active) if active is not None else 0
        self.default_keymap = keymap
        # Created by ``build_radios``; ``None`` while the selector is hidden.
        self.radios_ax = None
        self.radios = None
        self.enabled = self.default_toggled

    def build_radios(self):
        """Create the selector axes and its RadioButtons widget.

        The axes occupies x in [0.8, 1.0], y in [0.4, 0.9] in figure
        coordinates and is drawn above other axes (``zorder=1``).
        """
        w = 0.2
        self.radios_ax = self.figure.add_axes([1.0 - w, 0.4, w, 0.5], zorder=1)
        self.radios = RadioButtons(self.radios_ax, self.options, active=self.active)
        self.radios.on_clicked(self.on_radio_clicked)

    def enable(self, *args):
        """Show the selector and request a redraw."""
        self.build_radios()
        self.figure.canvas.draw_idle()
        self.enabled = True

    def disable(self, *args):
        """Hide the selector (if shown) and request a redraw.

        Safe to call when the selector was never built or is already hidden.
        """
        if self.radios_ax is not None:
            self.radios_ax.remove()
            self.radios_ax = None
        self.radios = None
        self.figure.canvas.draw_idle()
        self.enabled = False

    def on_radio_clicked(self, value):
        """Record the clicked option, notify the callback, refresh the selector."""
        self.active = self.options.index(value)
        was_enabled = self.enabled
        # Tear the widget down around the callback so it is rebuilt with the
        # new ``active`` index once the callback has updated the figure.
        if was_enabled:
            self.disable()
        if self.callback_fn is not None:
            self.callback_fn(value)
        if was_enabled:
            self.enable()
|
|
|
|
class ToggleTool(ToolToggleBase):
    """Toolbar toggle that forwards its on/off state to a callback.

    Parameters
    ----------
    callback_fn : callable, optional
        Called with ``True`` when the tool is toggled on and ``False`` when it
        is toggled off. When ``None``, toggling only updates ``self.enabled``.
    keymap : str
        Key bound to this tool; overrides the class-level ``default_keymap``.
    """

    default_keymap = "t"
    description = "Show by gid"

    def __init__(self, *args, callback_fn=None, keymap="t", **kwargs):
        super().__init__(*args, **kwargs)
        self.f = 1.0  # unused here; kept so existing attribute access keeps working
        self.callback_fn = callback_fn
        self.default_keymap = keymap
        self.enabled = self.default_toggled

    def _set_state(self, state):
        """Record the new toggle state and notify the callback, if any."""
        self.enabled = state
        # Guard: the default ``callback_fn=None`` used to raise TypeError here.
        if self.callback_fn is not None:
            self.callback_fn(state)

    def enable(self, *args):
        """Toggle on: notify the callback with ``True``."""
        self._set_state(True)

    def disable(self, *args):
        """Toggle off: notify the callback with ``False``."""
        self._set_state(False)
|
|
|
|
def add_whitespace_left(fig, factor):
    """Widen ``fig`` by ``factor`` times its width and put the extra space on the left.

    The left subplot margin becomes ``(factor + left) / (1 + factor)``: the old
    left margin plus the added width, as a fraction of the new width. No redraw
    is requested.
    """
    width, height = fig.get_size_inches()
    old_left = fig.subplotpars.left
    scale = 1 + factor
    new_size = [width * scale, height]
    fig.set_size_inches(new_size)
    fig.subplots_adjust(left=(factor + old_left) / scale)
|
|
|
|
def add_whitespace_bottom(fig, factor):
    """Grow ``fig`` by ``factor`` times its height and put the extra space at the bottom.

    The bottom subplot margin becomes ``(factor + bottom) / (1 + factor)``: the
    old bottom margin plus the added height, as a fraction of the new height.
    An idle redraw of the canvas is requested afterwards.
    """
    width, height = fig.get_size_inches()
    old_bottom = fig.subplotpars.bottom
    scale = 1 + factor
    new_size = [width, height * scale]
    fig.set_size_inches(new_size)
    fig.subplots_adjust(bottom=(factor + old_bottom) / scale)
    fig.canvas.draw_idle()
|
|
|
|
class ImagePlot:
    """Plain image view.

    Adds no overlay, so constructing it draws nothing and stores no state.
    """

    plot_name = "image"
    required_keys = ["image"]

    def __init__(self, fig, axes, data, preds):
        """Accept the common plot-constructor arguments and ignore them."""
|
|
|
|
class HorizonLinePlot:
    """Horizon-line overlay: ground truth in red, prediction in yellow.

    The ground-truth horizon from ``data["camera"][0]`` / ``data["gravity"][0]``
    is drawn on every column ``axes[0][idx]`` (one column per prediction). A
    prediction's horizon is added only when it provides both ``camera`` and
    ``gravity``.
    """

    plot_name = "horizon_line"
    required_keys = ["camera", "gravity"]

    def __init__(self, fig, axes, data, preds):
        if not preds:
            return

        # The ground truth is the same for every column: detach/move it once
        # instead of once per prediction.
        gt_cam = data["camera"][0].detach().cpu()
        gt_gravity = data["gravity"][0].detach().cpu()

        for idx, pred in enumerate(preds.values()):
            axis = axes[0][idx]
            plot_horizon_lines([gt_cam], [gt_gravity], line_colors="r", ax=[axis])

            if "camera" in pred and "gravity" in pred:
                pred_cam = Camera(pred["camera"][0].detach().cpu())
                pred_gravity = Gravity(pred["gravity"][0].detach().cpu())
                plot_horizon_lines(
                    [pred_cam], [pred_gravity], line_colors="yellow", ax=[axis]
                )
|
|
|
|
class LatitudePlot:
    """Latitude-field overlay with a "Toggle GT" button.

    Each prediction in ``preds`` gets the column ``axes[0][idx]``. By default
    the predicted ``latitude_field`` is drawn; the button switches every column
    to the ground-truth ``data["latitude_field"]`` and back.
    """

    plot_name = "latitude"
    required_keys = ["latitude_field"]

    def __init__(self, fig, axes, data, preds):
        self.fig, self.axes = fig, axes
        self.data, self.preds = data, preds
        self.gt_mode = False
        self.artists = []
        self.text_objects = []

        # Toggle button in the bottom-left corner (figure coordinates).
        self.ax_button = self.fig.add_axes([0.01, 0.02, 0.2, 0.06])
        self.button = Button(self.ax_button, "Toggle GT")
        self.button.on_clicked(self.toggle_display)

        self.update_plot()

    def toggle_display(self, event):
        """Button callback: flip between GT and prediction, then redraw."""
        self.gt_mode = not self.gt_mode
        self.update_plot()

    def _remove_overlays(self):
        """Remove every overlay artist and label this plot added, then forget them."""
        for artist in [*self.artists, *self.text_objects]:
            artist.remove()
        self.artists, self.text_objects = [], []

    def update_plot(self):
        """Redraw latitude overlays and mode labels for the current mode."""
        self._remove_overlays()

        for idx, pred in enumerate(self.preds.values()):
            if self.gt_mode:
                latitude, label = self.data["latitude_field"][0][0], "\nGT"
            elif "latitude_field" in pred:
                latitude, label = pred["latitude_field"][0][0], "\nPrediction"
            else:
                # No predicted latitude for this model: leave its column empty.
                continue

            self.artists += plot_latitudes([latitude], axes=[self.axes[0][idx]])
            # presumably labels axis ``idx`` — see viz2d.add_text
            self.text_objects.append(add_text(idx, label))

        self.fig.canvas.draw()

    def clear(self):
        """Detach and remove the toggle button, then remove all overlays."""
        self.button.disconnect_events()
        self.ax_button.remove()
        self._remove_overlays()
|
|
|
|
class LatitudeErrorPlot:
    """Heatmap of the latitude error of each prediction.

    For every prediction containing ``latitude_field``, the error against
    ``data["latitude_field"]`` (from ``latitude_error``; the ``[0]`` element is
    shown) is drawn on ``axes[0][idx]`` with a turbo colormap and a colorbar.
    When ``latitude_confidence`` is also present it is log-normalised to
    [0, 1] and passed to ``plot_heatmaps`` as ``a`` (presumably per-pixel
    opacity — see ``plot_heatmaps``).
    """

    plot_name = "latitude_error"
    required_keys = ["latitude_field"]

    def __init__(self, fig, axes, data, preds):
        self.artists = []
        # The ground truth does not depend on the prediction; fetch it once.
        gt = data["latitude_field"].detach().cpu()
        for idx, pred in enumerate(preds.values()):
            if "latitude_field" not in pred:
                continue

            lat = pred["latitude_field"].detach().cpu()
            error = latitude_error(lat, gt)[0].numpy()

            if "latitude_confidence" in pred:
                confidence = pred["latitude_confidence"].detach().cpu().numpy()
                confidence = self._normalize_confidence(confidence)
                arts = plot_heatmaps(
                    [error], cmap="turbo", axes=[axes[0][idx]], colorbar=True, a=confidence
                )
            else:
                arts = plot_heatmaps([error], cmap="turbo", axes=[axes[0][idx]], colorbar=True)
            self.artists += arts

    @staticmethod
    def _normalize_confidence(confidence):
        """Map confidences to [0, 1] on a log10 scale floored at 1e-5.

        ``log10(confidence)`` is clipped below at -5 and shifted so the floor
        maps to 0 and the largest value maps to 1. If every value is at or
        below the floor, all zeros are returned instead of the NaNs a 0/0
        division would produce.
        """
        # log10(0) = -inf is expected and clipped away; silence its warning.
        with np.errstate(divide="ignore"):
            log_conf = np.log10(confidence).clip(-5)
        shifted = log_conf + 5
        span = shifted.max()
        if span <= 0:
            return np.zeros_like(shifted)
        return shifted / span

    def clear(self):
        """Remove every heatmap artist and its colorbar."""
        for x in self.artists:
            x.remove()
            x.colorbar.remove()

        self.artists = []
|
|
|
|
class LatitudeConfidencePlot:
    """Confidence map of each prediction's latitude estimate.

    Predictions that carry ``latitude_confidence`` have its ``[0]`` element
    drawn with ``plot_confidences`` on ``axes[0][idx]``; others are skipped.
    """

    plot_name = "latitude_confidence"
    required_keys = []

    def __init__(self, fig, axes, data, preds):
        self.artists = []
        for idx, pred in enumerate(preds.values()):
            if "latitude_confidence" not in pred:
                continue
            self.artists += plot_confidences(
                [pred["latitude_confidence"][0]], axes=[axes[0][idx]]
            )

    def clear(self):
        """Remove each confidence artist together with its colorbar."""
        for artist in self.artists:
            artist.remove()
            artist.colorbar.remove()
        self.artists = []
|
|
|
|
class UpPlot:
    """Up-vector-field overlay with a "Toggle GT" button.

    Each prediction in ``preds`` gets the column ``axes[0][idx]``. By default
    the predicted ``up_field`` is drawn as a vector field; the button switches
    every column to the ground-truth ``data["up_field"]`` and back.
    """

    plot_name = "up"
    required_keys = ["up_field"]

    def __init__(self, fig, axes, data, preds):
        self.fig, self.axes = fig, axes
        self.data, self.preds = data, preds
        self.gt_mode = False
        self.artists = []
        self.text_objects = []

        # Toggle button in the bottom-left corner (figure coordinates).
        self.ax_button = self.fig.add_axes([0.01, 0.02, 0.2, 0.06])
        self.button = Button(self.ax_button, "Toggle GT")
        self.button.on_clicked(self.toggle_display)

        self.update_plot()

    def toggle_display(self, event):
        """Button callback: flip between GT and prediction, then redraw."""
        self.gt_mode = not self.gt_mode
        self.update_plot()

    def _remove_overlays(self):
        """Remove every overlay artist and label this plot added, then forget them."""
        for artist in [*self.artists, *self.text_objects]:
            artist.remove()
        self.artists, self.text_objects = [], []

    def _field_for(self, pred):
        """Return ``(up_field, label)`` to draw for ``pred``, or ``None`` to skip it."""
        if self.gt_mode:
            return self.data["up_field"][0], "\nGT"
        if "up_field" in pred:
            return pred["up_field"][0], "\nPrediction"
        return None

    def update_plot(self):
        """Redraw up-field overlays and mode labels for the current mode."""
        self._remove_overlays()

        for idx, pred in enumerate(self.preds.values()):
            selected = self._field_for(pred)
            if selected is None:
                continue
            up, label = selected
            self.artists += plot_vector_fields([up], axes=[self.axes[0][idx]])
            # presumably labels axis ``idx`` — see viz2d.add_text
            self.text_objects.append(add_text(idx, label))

        self.fig.canvas.draw()

    def clear(self):
        """Detach and remove the toggle button, then remove all overlays."""
        self.button.disconnect_events()
        self.ax_button.remove()
        self._remove_overlays()
|
|
|
|
class UpErrorPlot:
    """Heatmap of the up-vector error of each prediction.

    For every prediction containing ``up_field``, the error against
    ``data["up_field"]`` (from ``up_error``; the ``[0]`` element is shown) is
    drawn on ``axes[0][idx]`` with a turbo colormap and a colorbar. When
    ``up_confidence`` is also present it is log-normalised to [0, 1] and
    passed to ``plot_heatmaps`` as ``a`` (presumably per-pixel opacity — see
    ``plot_heatmaps``).
    """

    plot_name = "up_error"
    required_keys = ["up_field"]

    def __init__(self, fig, axes, data, preds):
        self.artists = []
        # The ground truth does not depend on the prediction; fetch it once.
        gt = data["up_field"].detach().cpu()
        for idx, pred in enumerate(preds.values()):
            if "up_field" not in pred:
                continue

            up = pred["up_field"].detach().cpu()
            error = up_error(up, gt)[0].numpy()

            if "up_confidence" in pred:
                confidence = pred["up_confidence"].detach().cpu().numpy()
                confidence = self._normalize_confidence(confidence)
                arts = plot_heatmaps(
                    [error], cmap="turbo", axes=[axes[0][idx]], colorbar=True, a=confidence
                )
            else:
                arts = plot_heatmaps([error], cmap="turbo", axes=[axes[0][idx]], colorbar=True)
            self.artists += arts

    @staticmethod
    def _normalize_confidence(confidence):
        """Map confidences to [0, 1] on a log10 scale floored at 1e-5.

        ``log10(confidence)`` is clipped below at -5 and shifted so the floor
        maps to 0 and the largest value maps to 1. If every value is at or
        below the floor, all zeros are returned instead of the NaNs a 0/0
        division would produce.
        """
        # log10(0) = -inf is expected and clipped away; silence its warning.
        with np.errstate(divide="ignore"):
            log_conf = np.log10(confidence).clip(-5)
        shifted = log_conf + 5
        span = shifted.max()
        if span <= 0:
            return np.zeros_like(shifted)
        return shifted / span

    def clear(self):
        """Remove every heatmap artist and its colorbar."""
        for x in self.artists:
            x.remove()
            x.colorbar.remove()

        self.artists = []
|
|
|
|
class UpConfidencePlot:
    """Confidence map of each prediction's up-vector estimate.

    Predictions that carry ``up_confidence`` have its ``[0]`` element drawn
    with ``plot_confidences`` on ``axes[0][idx]``; others are skipped.
    """

    plot_name = "up_confidence"
    required_keys = []

    def __init__(self, fig, axes, data, preds):
        self.artists = []
        for idx, pred in enumerate(preds.values()):
            if "up_confidence" not in pred:
                continue
            self.artists += plot_confidences([pred["up_confidence"][0]], axes=[axes[0][idx]])

    def clear(self):
        """Remove each confidence artist together with its colorbar."""
        for artist in self.artists:
            artist.remove()
            artist.colorbar.remove()
        self.artists = []
|
|
|
|
class PerspectiveField:
    """Perspective-field overlay (latitude contours + up vectors) with a GT toggle.

    For each prediction, a camera and gravity are wrapped in ``Camera`` /
    ``Gravity`` and converted with ``get_perspective_field`` into up and
    latitude fields. These are drawn on ``axes[0][idx]``. The "Toggle GT"
    button switches every column between the predicted ``camera``/``gravity``
    and the ground-truth ``data["camera"]``/``data["gravity"]``.
    """

    plot_name = "perspective_field"
    required_keys = ["camera", "gravity"]

    def __init__(self, fig, axes, data, preds):
        self.artists = []
        self.gt_mode = False
        self.text_objects = []

        self.fig = fig
        self.axes = axes
        self.data = data
        self.preds = preds

        # Toggle button in the bottom-left corner (figure coordinates).
        self.ax_button = self.fig.add_axes([0.01, 0.02, 0.2, 0.06])
        self.button = Button(self.ax_button, "Toggle GT")
        self.button.on_clicked(self.toggle_display)

        self.update_plot()

    def toggle_display(self, event):
        """Button callback: flip between GT and prediction, then redraw."""
        self.gt_mode = not self.gt_mode
        self.update_plot()

    def update_plot(self):
        """Redraw perspective-field overlays and mode labels for the current mode."""
        for x in self.artists:
            x.remove()
        for text in self.text_objects:
            text.remove()

        self.artists = []
        self.text_objects = []

        for idx, name in enumerate(self.preds):
            pred = self.preds[name]

            if self.gt_mode:
                camera = self.data["camera"]
                gravity = self.data["gravity"]
                text = "\nGT"
            else:
                # A model that does not predict both camera and gravity leaves
                # its column empty instead of raising KeyError.
                if "camera" not in pred or "gravity" not in pred:
                    continue
                camera = pred["camera"]
                gravity = pred["gravity"]
                text = "\nPrediction"
            camera = Camera(camera)
            gravity = Gravity(gravity)

            up, latitude = get_perspective_field(camera, gravity)

            self.artists += plot_latitudes([latitude[0, 0]], axes=[self.axes[0][idx]])
            self.artists += plot_vector_fields([up[0]], axes=[self.axes[0][idx]])

            # presumably labels axis ``idx`` — see viz2d.add_text
            self.text_objects.append(add_text(idx, text))

        self.fig.canvas.draw()

    def clear(self):
        """Detach and remove the toggle button, then remove all overlays."""
        self.button.disconnect_events()
        self.ax_button.remove()

        for x in self.artists:
            x.remove()
        for text in self.text_objects:
            text.remove()

        self.artists = []
        self.text_objects = []
|
|
|
|
# Registry of plot classes defined in this module, keyed by their
# ``plot_name``. ``inspect.getmembers`` returns members sorted by name, so the
# registry's insertion order is alphabetical by class name.
__plot_dict__ = dict(
    (plot_cls.plot_name, plot_cls)
    for _member_name, plot_cls in inspect.getmembers(sys.modules[__name__], inspect.isclass)
    if hasattr(plot_cls, "plot_name")
)
|
|