Spaces:
Running
Running
| """ | |
| MLSTRUCT-FP - DB - IMAGE - RECT BINARY | |
| Image of the surroundings of a rect. | |
| """ | |
| __all__ = ['RectBinaryImage', 'restore_plot_backend'] | |
| from MLStructFP.db.image._base import BaseImage, TYPE_IMAGE | |
| from MLStructFP.utils import make_dirs | |
| from MLStructFP._types import TYPE_CHECKING, Tuple, Dict, NumberType | |
| from PIL import Image | |
| import io | |
| import math | |
| import matplotlib | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import os | |
| import time | |
| if TYPE_CHECKING: | |
| from MLStructFP.db._c_rect import Rect | |
| from MLStructFP.db._floor import Floor | |
| from MLStructFP.utils import GeomPoint2D | |
| from matplotlib.figure import Figure | |
| # Internal plot settings | |
| INITIAL_BACKEND: str = matplotlib.get_backend() | |
| MAX_STORED_FLOORS: int = 2 | |
| PLOT_BACKEND: str = 'agg' | |
| def restore_plot_backend() -> None: | |
| """ | |
| Restores plot backend. | |
| """ | |
| if plt.get_backend() == PLOT_BACKEND: | |
| plt.switch_backend(INITIAL_BACKEND) | |
| class RectBinaryImage(BaseImage): | |
| """ | |
| Rect binary image. | |
| This class creates a segmentation mask iterating all rects from a given floor; | |
| obviously, this can be further extended to create other maps, for example, for | |
| wall joints (points), or can be concatenated to create multiple mask tensor for | |
| a given crop (like wall segments + joints). It is up to your imagination and | |
| technical abilities =). I'd say that this is the most important class of all. | |
| """ | |
| _crop_px: int | |
| _initial_backend: str | |
| _initialized: bool | |
| _plot: Dict[str, Tuple['Figure', 'plt.Axes']] | |
| def __init__( | |
| self, | |
| path: str = '', | |
| save_images: bool = False, | |
| image_size_px: int = 64 | |
| ) -> None: | |
| """ | |
| Constructor. | |
| :param path: Image path | |
| :param save_images: Save images on path | |
| :param image_size_px: Image size (width/height), bigger images are expensive; double the width, quad the size | |
| """ | |
| BaseImage.__init__(self, path, save_images, image_size_px) | |
| self._crop_px = int(math.ceil(self._image_size / 32)) # Must be greater or equal than zero | |
| self._initialized = False | |
| self._plot = {} # Store matplotlib figures | |
| def init(self) -> 'RectBinaryImage': | |
| """ | |
| Initialize exporter. | |
| :return: Self | |
| """ | |
| if plt.get_backend() != PLOT_BACKEND: | |
| plt.switch_backend(PLOT_BACKEND) | |
| self._initialized = True | |
| self.close(restore_plot=False) | |
| self._initialized = True | |
| return self | |
| def _get_floor_plot(self, floor: 'Floor', store: bool) -> Tuple['Figure', 'plt.Axes']: | |
| """ | |
| Get basic figure of wall rects. | |
| :param floor: Source floor | |
| :param store: Store cache of the floor | |
| :return: Figure of the floor | |
| """ | |
| floor_id = str(floor.id) | |
| if floor_id in self._plot.keys(): | |
| return self._plot[floor_id] | |
| fig: 'Figure' = plt.figure(frameon=False) # Don't configure dpi | |
| plt.style.use('default') # Don't modify this either | |
| ax: 'plt.Axes' = fig.add_axes([0, 0, 1, 1]) | |
| ax.axis('off') | |
| ax.set_aspect(aspect='equal') | |
| ax.grid(False) # Don't enable as this may destroy the figures | |
| for r in floor.rect: | |
| r.plot_matplotlib( | |
| ax=ax, | |
| color='#000000' | |
| ) | |
| # Save | |
| if store: | |
| if len(self._plot) >= MAX_STORED_FLOORS: | |
| key_iter = iter(self._plot.keys()) | |
| k1: str = next(key_iter) | |
| f1, _ = self._plot[k1] | |
| plt.close(f1) # Close the figure | |
| del self._plot[k1] # Remove | |
| self._plot[floor_id] = (fig, ax) | |
| return fig, ax | |
| def make_rect(self, rect: 'Rect', crop_length: NumberType = 5) -> Tuple[int, 'np.ndarray']: | |
| """ | |
| Generate image for the perimeter of a given rectangle. | |
| :param rect: Rectangle | |
| :param crop_length: Size of crop from center of the rect to any edge in meters | |
| :return: Returns the image index and matrix | |
| """ | |
| cr: 'GeomPoint2D' = rect.get_mass_center() | |
| return self.make_region( | |
| xmin=cr.x - crop_length, | |
| xmax=cr.x + crop_length, | |
| ymin=cr.y - crop_length, | |
| ymax=cr.y + crop_length, | |
| floor=rect.floor | |
| ) | |
| # noinspection PyMethodMayBeStatic | |
| def _convert_image_color(self, im: 'Image.Image') -> 'Image.Image': | |
| """ | |
| Convert image. | |
| :param im: Image | |
| :return: Converted image | |
| """ | |
| return im.convert('P', palette=Image.Palette.ADAPTIVE) | |
| # noinspection PyMethodMayBeStatic | |
| def _post_process(self, im: 'Image.Image') -> 'Image.Image': | |
| """ | |
| Post process image. | |
| :param im: Image | |
| :return: Post-processed image | |
| """ | |
| return im | |
| def _make(self, xmin: NumberType, xmax: NumberType, ymin: NumberType, ymax: NumberType, | |
| floor: 'Floor') -> 'Image.Image': | |
| """ | |
| Generate image from a given coordinate (x, y). | |
| :param xmin: Minimum x-axis (m) | |
| :param xmax: Maximum x-axis (m) | |
| :param ymin: Minimum y-axis (m) | |
| :param ymax: Maximum y-axis (m) | |
| :param floor: Floor object | |
| :return: Returns the image | |
| """ | |
| fig, ax = self._get_floor_plot(floor, store=True) | |
| # Set the current figure | |
| # noinspection PyUnresolvedReferences | |
| plt.figure(fig.number) | |
| # Extent axes | |
| ax.set_xlim(min(xmin, xmax), max(xmin, xmax)) | |
| ax.set_ylim(min(ymin, ymax), max(ymin, ymax)) | |
| # Convert from matplotlib | |
| ram = io.BytesIO() | |
| plt.savefig(ram, format='png', dpi=100, bbox_inches='tight', transparent=False) | |
| ram.seek(0) | |
| im = Image.open(ram).convert('RGB') | |
| ram.close() | |
| return im | |
| def make_region(self, xmin: NumberType, xmax: NumberType, ymin: NumberType, ymax: NumberType, | |
| floor: 'Floor') -> Tuple[int, 'np.ndarray']: | |
| """ | |
| Generate image for a given region. | |
| :param xmin: Minimum x-axis (m) | |
| :param xmax: Maximum x-axis (m) | |
| :param ymin: Minimum y-axis (m) | |
| :param ymax: Maximum y-axis (m) | |
| :param floor: Floor object | |
| :return: Returns the image index and matrix | |
| """ | |
| if not self._initialized: | |
| raise RuntimeError('Exporter not initialized, use .init()') | |
| t0 = time.time() | |
| # Make crop | |
| im: 'Image.Image' = self._make(xmin, xmax, ymin, ymax, floor) | |
| # Convert color | |
| im2: 'Image.Image' = self._convert_image_color(im) | |
| # Resize | |
| s_resize = self._image_size + 2 * self._crop_px | |
| im3: 'Image.Image' = im2.resize((s_resize, s_resize)) | |
| # Crop | |
| s_crop = self._image_size + self._crop_px | |
| im4: 'Image.Image' = self._post_process(im3.crop((self._crop_px, self._crop_px, s_crop, s_crop))) | |
| # Save to file | |
| figname: str = f'{floor.id}-x-{xmin:.2f}-{xmax:.2f}-y-{ymin:.2f}-{ymax:.2f}' | |
| if self._save_images: | |
| assert self._path != '', 'Path cannot be empty' | |
| filesave = os.path.join(self._path, figname + '.png') | |
| im4.save(make_dirs(filesave), format='PNG') | |
| # noinspection PyTypeChecker | |
| array = np.array(im4, dtype=TYPE_IMAGE) | |
| # Save to array | |
| if self.save: | |
| self._images.append(array) | |
| self._names.append(figname) | |
| # Close data | |
| im.close() | |
| im2.close() | |
| im3.close() | |
| im4.close() | |
| # Remove data | |
| del im, im2, im3, im4 | |
| # Returns the image index on the library array | |
| self._last_make_region_time = time.time() - t0 | |
| return len(self._images) - 1, array | |
| def close(self, restore_plot: bool = True) -> None: | |
| """ | |
| Close and delete all generated figures. | |
| This function also restores plot engine. | |
| :param restore_plot: Restores plotting engine | |
| """ | |
| if not self._initialized: | |
| raise RuntimeError('Exporter not initialized, it cannot be closed') | |
| # Close figures | |
| for f in self._plot.keys(): | |
| fig, _ = self._plot[f] | |
| plt.close(fig) | |
| # Remove | |
| self._plot.clear() | |
| self._images.clear() | |
| self._names.clear() | |
| # Restore plot | |
| if restore_plot: | |
| restore_plot_backend() | |
| self._initialized = False | |