Spaces:
Running
Running
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| import matplotlib as mpl | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import PIL.Image | |
| import plotly.graph_objects as go | |
| from ..utils.viz_2d import add_text | |
| from .parser import Groups | |
class GeoPlotter:
    """Plotly figure wrapper that draws points, boxes and images on a map.

    All geographic inputs are (latitude, longitude) pairs.
    """

    def __init__(self, zoom=12, **kwargs):
        """Create an empty figure using the "open-street-map" mapbox style.

        Args:
            zoom: Initial mapbox zoom level.
            **kwargs: Extra layout options forwarded to ``Figure.update_layout``.
        """
        self.fig = go.Figure()
        self.fig.update_layout(
            mapbox_style="open-street-map",
            autosize=True,
            mapbox_zoom=zoom,
            margin={"r": 0, "t": 0, "l": 0, "b": 0},
            showlegend=True,
            **kwargs,
        )

    def points(self, latlons, color, text=None, name=None, size=5, **kwargs):
        """Add a marker trace and re-center the map on the mean of the points.

        Args:
            latlons: Array-like whose last axis holds (lat, lon).
            color: Marker color.
            text: Optional per-point text passed to ``Scattermapbox``.
            name: Legend name of the trace.
            size: Marker size.
            **kwargs: Extra ``Scattermapbox`` options.
        """
        latlons = np.asarray(latlons)
        self.fig.add_trace(
            go.Scattermapbox(
                lat=latlons[..., 0],
                lon=latlons[..., 1],
                mode="markers",
                text=text,
                marker_color=color,
                marker_size=size,
                name=name,
                **kwargs,
            )
        )
        flat = latlons.reshape(-1, 2)
        # With no points, the mean would be NaN (plus a RuntimeWarning), so
        # keep whatever center the map already has.
        if len(flat) > 0:
            center = flat.mean(0)
            self.fig.update_layout(
                mapbox_center=dict(zip(("lat", "lon"), center)),
            )

    def bbox(self, bbox, color, name=None, **kwargs):
        """Draw the outline of ``bbox`` as a closed polyline and center on it.

        Args:
            bbox: Object exposing (lat, lon) corners ``min_``, ``left_top``,
                ``max_`` and ``right_bottom``, plus a ``center``.
            color: Outline color.
            name: Legend name of the trace.
            **kwargs: Extra ``Scattermapbox`` options.
        """
        # Repeat the first corner so the outline closes.
        corners = np.stack(
            [bbox.min_, bbox.left_top, bbox.max_, bbox.right_bottom, bbox.min_]
        )
        self.fig.add_trace(
            go.Scattermapbox(
                lat=corners[:, 0],
                lon=corners[:, 1],
                mode="lines",
                marker_color=color,
                name=name,
                **kwargs,
            )
        )
        self.fig.update_layout(
            mapbox_center=dict(zip(("lat", "lon"), bbox.center)),
        )

    def raster(self, raster, bbox, below="traces", **kwargs):
        """Overlay an image on the map, stretched over the corners of ``bbox``.

        Args:
            raster: Image array. Non-integer arrays are treated as intensities
                in [0, 1] and converted to uint8; values outside that range
                are clipped.
            bbox: Object exposing (lat, lon) corners ``min_``, ``left_top``,
                ``max_`` and ``right_bottom``.
            below: Id of the mapbox layer beneath which the image is inserted.
            **kwargs: Extra options for the mapbox image layer.
        """
        if not np.issubdtype(raster.dtype, np.integer):
            # Clip before casting: an unchecked float->uint8 cast wraps around
            # (or is undefined) for values outside [0, 255].
            raster = np.clip(raster * 255, 0, 255).astype(np.uint8)
        raster = PIL.Image.fromarray(raster)
        # Reverse the corner order and swap each pair to (lon, lat), which is
        # what mapbox image layers take. NOTE(review): the resulting winding
        # depends on bbox's corner convention; confirm it matches mapbox's
        # clockwise-from-top-left expectation.
        corners = np.stack(
            [
                bbox.min_,
                bbox.left_top,
                bbox.max_,
                bbox.right_bottom,
            ]
        )[::-1, ::-1]
        layers = [*self.fig.layout.mapbox.layers]
        layers.append(
            dict(
                sourcetype="image",
                source=raster,
                coordinates=corners,
                below=below,
                **kwargs,
            )
        )
        self.fig.layout.mapbox.layers = layers
# Per-class display colors as 0-255 channel triples, keyed by class name.
# "void" is a light gray used for cells that carry no class.
map_colors = dict(
    building=(84, 155, 255),
    parking=(255, 229, 145),
    playground=(150, 133, 125),
    grass=(188, 255, 143),
    park=(0, 158, 16),
    forest=(0, 92, 9),
    water=(184, 213, 255),
    fence=(238, 0, 255),
    wall=(0, 0, 0),
    hedge=(107, 68, 48),
    kerb=(255, 234, 0),
    building_outline=(0, 0, 255),
    cycleway=(0, 251, 255),
    path=(8, 237, 0),
    road=(255, 0, 0),
    tree_row=(0, 92, 9),
    busway=(255, 128, 0),
    void=[int(255 * 0.9)] * 3,
)
class Colormap:
    """Color lookup tables that turn area/way class-index rasters into RGB.

    Row 0 of each table is the "void" color. Row ``i`` is the color of
    ``Groups.areas[i - 1]`` (resp. ``Groups.ways[i - 1]``). A raster value
    therefore indexes the table directly, with 0 selecting void.
    """

    colors_areas = np.stack([map_colors[k] for k in ["void"] + Groups.areas])
    colors_ways = np.stack([map_colors[k] for k in ["void"] + Groups.ways])

    # These methods read the class-level tables through ``cls``. Without
    # @classmethod, ``Colormap.apply(rasters)`` would bind ``rasters`` to
    # ``cls`` and fail.
    @classmethod
    def apply(cls, rasters):
        """Render stacked area/way class rasters as a float RGB image.

        Args:
            rasters: Integer array whose first axis holds at least two
                channels. ``rasters[0]`` indexes ``colors_areas`` and
                ``rasters[1]`` indexes ``colors_ways``.

        Returns:
            Float array of shape ``rasters.shape[1:] + (3,)`` with values in
            [0, 1]. Where the way channel is non-zero the way color is used;
            elsewhere the area color (including void) is used.
        """
        return (
            np.where(
                rasters[1, ..., None] > 0,
                cls.colors_ways[rasters[1]],
                cls.colors_areas[rasters[0]],
            )
            / 255.0
        )

    @classmethod
    def add_colorbar(cls):
        """Attach a vertical legend of all area and way colors to the current figure.

        The legend axis is added at figure-fraction rect
        ``[1, 0.1, 0.02, 0.8]``, i.e. starting at the figure's right edge.
        The void color is omitted.
        """
        ax2 = plt.gcf().add_axes([1, 0.1, 0.02, 0.8])
        color_list = np.r_[cls.colors_areas[1:], cls.colors_ways[1:]] / 255.0
        # Reversed so the first area class is drawn at the top of the bar.
        cmap = mpl.colors.ListedColormap(color_list[::-1])
        # One tick centered on each color band of the [0, 1] colorbar.
        ticks = np.linspace(0, 1, len(color_list), endpoint=False)
        ticks += 1 / len(color_list) / 2
        cb = mpl.colorbar.ColorbarBase(
            ax2,
            cmap=cmap,
            orientation="vertical",
            ticks=ticks,
        )
        # Labels reversed to match the reversed colormap.
        cb.set_ticklabels((Groups.areas + Groups.ways)[::-1])
        ax2.tick_params(labelsize=15)
def plot_nodes(idx, raster, fontsize=8, size=15):
    """Mark and label every non-zero cell of ``raster`` on one figure axis.

    Args:
        idx: Index into ``plt.gcf().axes`` of the axis to draw on. It is also
            forwarded to ``add_text``.
        raster: Integer array (presumably 2-D). A cell holding ``v > 0`` is
            drawn as a black dot labelled ``Groups.nodes[v - 1]``; zero cells
            are skipped.
        fontsize: Font size of the labels.
        size: Marker size passed to ``Axes.scatter``.
    """
    axis = plt.gcf().axes[idx]
    # Freeze the current view limits so the added artists do not rescale it.
    axis.autoscale(enable=False)
    hit_index = np.where(raster > 0)
    # np.where gives coordinates in array-axis order; reversing them yields
    # (x, y) = (column, row) for plotting.
    points_xy = np.stack(hit_index[::-1], -1)
    class_ids = raster[hit_index] - 1
    axis.scatter(*points_xy.T, c="k", s=size)
    for point, class_id in zip(points_xy, class_ids):
        label = Groups.nodes[class_id]
        # Offset the label by +2 on both axes so it does not cover the dot.
        add_text(
            idx,
            label,
            point + 2,
            lcolor=None,
            fs=fontsize,
            color="k",
            normalized=False,
            ha="center",
        )