| |
|
|
| from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Type, Union |
|
|
| import cv2 |
| import numpy as np |
| import torch |
|
|
| if TYPE_CHECKING: |
| from matplotlib.backends.backend_agg import FigureCanvasAgg |
|
|
|
|
def tensor2ndarray(value: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
    """Convert a ``torch.Tensor`` to ``np.ndarray``; pass other values through.

    Args:
        value (np.ndarray, torch.Tensor): The value to convert.

    Returns:
        np.ndarray: For a tensor, the result of
        ``value.detach().cpu().numpy()``. Any other input is returned
        unchanged (the very same object).
    """
    if not isinstance(value, torch.Tensor):
        return value
    # Detach from the autograd graph and move to host memory before
    # handing the data to numpy.
    return value.detach().cpu().numpy()
|
|
|
|
def value2list(value: Any, valid_type: Union[Type, Tuple[Type, ...]],
               expand_dim: int) -> List[Any]:
    """Repeat a single value of ``valid_type`` into a list of ``expand_dim``
    items.

    Args:
        value (Any): The value to expand.
        valid_type (Union[Type, Tuple[Type, ...]]): Type, or tuple of types,
            that marks ``value`` as a single item to be repeated.
        expand_dim (int): Number of times ``value`` is repeated.

    Returns:
        List[Any]: ``[value] * expand_dim`` when ``value`` is an instance of
        ``valid_type``; otherwise ``value`` itself, unchanged.
    """
    return [value] * expand_dim if isinstance(value, valid_type) else value
|
|
|
|
def check_type(name: str, value: Any,
               valid_type: Union[Type, Tuple[Type, ...]]) -> None:
    """Ensure ``value`` is an instance of ``valid_type``.

    Args:
        name (str): Name of the value, used in the error message.
        value (Any): The value to check.
        valid_type (Type, Tuple[Type, ...]): Expected type, or tuple of
            accepted types.

    Raises:
        TypeError: If ``value`` is not an instance of ``valid_type``.
    """
    if isinstance(value, valid_type):
        return
    message = (f'`{name}` should be {valid_type} '
               f' but got {type(value)}')
    raise TypeError(message)
|
|
|
|
def check_length(name: str, value: Any, valid_length: int) -> None:
    """Ensure a list holds at least ``valid_length`` items.

    Only ``list`` instances are checked; every other type passes silently.

    Args:
        name (str): Name of the value, used in the error message.
        value (Any): The value to check.
        valid_length (int): Minimum number of items a list must hold.

    Raises:
        AssertionError: If ``value`` is a list shorter than
            ``valid_length``.
    """
    too_short = isinstance(value, list) and len(value) < valid_length
    if too_short:
        raise AssertionError(
            f'The length of {name} must equal with or '
            f'greater than {valid_length}, but got {len(value)}')
|
|
|
|
def check_type_and_length(name: str, value: Any,
                          valid_type: Union[Type, Tuple[Type, ...]],
                          valid_length: int) -> None:
    """Validate the type of ``value`` and, for lists, its minimum length.

    The type is checked first via ``check_type``; the length is then checked
    via ``check_length``, which only inspects ``list`` instances.

    Args:
        name (str): Name of the value, used in error messages.
        value (Any): The value to check.
        valid_type (Type, Tuple[Type, ...]): Expected type, or tuple of
            accepted types.
        valid_length (int): Minimum number of items when ``value`` is a
            list.

    Raises:
        TypeError: If ``value`` is not an instance of ``valid_type``.
        AssertionError: If ``value`` is a list shorter than
            ``valid_length``.
    """
    check_type(name=name, value=value, valid_type=valid_type)
    check_length(name=name, value=value, valid_length=valid_length)
|
|
|
|
def color_val_matplotlib(
    colors: Union[str, tuple, List[Union[str, tuple]]]
) -> Union[str, tuple, List[Union[str, tuple]]]:
    """Convert 0-255 RGB colors to the normalized form used by matplotlib.

    Args:
        colors (Union[str, tuple, List[Union[str, tuple]]]): A color string
            (returned untouched), a 3-tuple of channel values in
            ``[0, 255]`` in RGB order, or a list of such strings/tuples.

    Returns:
        Union[str, tuple, List[Union[str, tuple]]]: The string unchanged, a
        tuple of three floats (each channel divided by 255), or a list with
        every element converted by the same rules.

    Raises:
        AssertionError: If a tuple does not have exactly three channels or
            a channel lies outside ``[0, 255]``.
        TypeError: If ``colors`` is not a str, tuple or list.
    """
    if isinstance(colors, str):
        return colors

    if isinstance(colors, tuple):
        assert len(colors) == 3
        assert all(0 <= channel <= 255 for channel in colors)
        return tuple(channel / 255 for channel in colors)

    if isinstance(colors, list):
        # Each element is converted independently, so strings and tuples
        # may be mixed inside one list.
        return [color_val_matplotlib(item) for item in colors]

    raise TypeError(f'Invalid type for color: {type(colors)}')
|
|
|
|
def color_str2rgb(color: str) -> tuple:
    """Convert a Matplotlib color string to an RGB tuple in the 0-255 range.

    Any alpha component of ``color`` is silently dropped.

    Args:
        color (str): Matplotlib color.

    Returns:
        tuple: RGB color as three ints. Each normalized channel is scaled by
        255 and truncated with ``int``.
    """
    # Import the submodule explicitly: a bare ``import matplotlib`` does not
    # by itself guarantee that ``matplotlib.colors`` has been loaded.
    import matplotlib.colors

    normalized_rgb = matplotlib.colors.to_rgb(color)
    return tuple(int(channel * 255) for channel in normalized_rgb)
|
|
|
|
def convert_overlay_heatmap(feat_map: Union[np.ndarray, torch.Tensor],
                            img: Optional[np.ndarray] = None,
                            alpha: float = 0.5) -> np.ndarray:
    """Render a feature map as a JET heatmap, optionally blended onto an image.

    Args:
        feat_map (np.ndarray, torch.Tensor): The feature map to convert.
            Accepted shapes are (H, W), or (C, H, W) with C equal to 1 or
            3; the channel-first form is transposed to (H, W, C).
        img (np.ndarray, optional): The origin image. The format
            should be RGB. Defaults to None, in which case the bare
            heatmap is returned.
        alpha (float): Blend weight of the heatmap; the image is weighted
            by ``1 - alpha``. Defaults to 0.5.

    Returns:
        np.ndarray: The RGB heatmap, blended with ``img`` when one is given.
    """
    assert feat_map.ndim == 2 or (feat_map.ndim == 3
                                  and feat_map.shape[0] in [1, 3])
    if isinstance(feat_map, torch.Tensor):
        feat_map = feat_map.detach().cpu().numpy()

    if feat_map.ndim == 3:
        # Channel-first (C, H, W) -> channel-last (H, W, C).
        feat_map = feat_map.transpose(1, 2, 0)

    # Min-max rescale to [0, 255] so the map can be used as uint8 input
    # for the color map.
    scaled = cv2.normalize(feat_map, np.zeros(feat_map.shape), 0, 255,
                           cv2.NORM_MINMAX)
    scaled = np.asarray(scaled, dtype=np.uint8)

    # The color-mapped result is converted from BGR to RGB to match the
    # RGB ``img``.
    heatmap = cv2.cvtColor(
        cv2.applyColorMap(scaled, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)

    if img is None:
        return heatmap
    return cv2.addWeighted(img, 1 - alpha, heatmap, alpha, 0)
|
|
|
|
def wait_continue(figure, timeout: float = 0, continue_key: str = ' ') -> int:
    """Show the image and wait for the user's input.

    This implementation refers to
    https://github.com/matplotlib/matplotlib/blob/v3.5.x/lib/matplotlib/_blocking_input.py

    Args:
        figure: The matplotlib figure to show and listen on.
        timeout (float): If positive, continue after ``timeout`` seconds.
            Defaults to 0.
        continue_key (str): The key for users to continue. Defaults to
            the space key.

    Returns:
        int: If zero, means time out or the user pressed ``continue_key``,
        and if one, means the user closed the show figure. Zero is also
        returned immediately when an inline backend is in use.
    """
    import matplotlib.pyplot as plt
    from matplotlib.backend_bases import CloseEvent

    # With an inline backend there is nothing to wait for.
    if 'inline' in plt.get_backend():
        return 0

    if figure.canvas.manager:
        # Make sure the figure is actually displayed before waiting.
        figure.show()

    watched_events = ('key_press_event', 'close_event')

    while True:
        # Reset on every round; filled in by ``handler`` below.
        event = None

        def handler(ev):
            nonlocal event
            # Never overwrite an already-recorded CloseEvent with a later
            # event, so that a close is not missed.
            if not isinstance(event, CloseEvent):
                event = ev
            figure.canvas.stop_event_loop()

        cids = [
            figure.canvas.mpl_connect(event_name, handler)
            for event_name in watched_events
        ]

        try:
            figure.canvas.start_event_loop(timeout)
        finally:
            # Always disconnect the callbacks, even if the event loop
            # raised.
            for cid in cids:
                figure.canvas.mpl_disconnect(cid)

        if isinstance(event, CloseEvent):
            return 1
        if event is None or event.key == continue_key:
            # No event at all (the loop ended on its own) or the continue
            # key was pressed. Any other key starts another round.
            return 0
|
|
|
|
def img_from_canvas(canvas: 'FigureCanvasAgg') -> np.ndarray:
    """Get RGB image from ``FigureCanvasAgg``.

    Args:
        canvas (FigureCanvasAgg): The canvas to get image.

    Returns:
        np.ndarray: The image in RGB as a ``uint8`` array of shape
        (height, width, 3); the alpha channel of the RGBA buffer is
        discarded.
    """
    raw_bytes, (width, height) = canvas.print_to_buffer()
    rgba = np.frombuffer(raw_bytes, dtype='uint8').reshape(height, width, 4)
    # Keep the first three channels (RGB); ``astype`` returns a fresh copy
    # rather than a view on the canvas buffer.
    return rgba[:, :, :3].astype('uint8')
|
|