|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Type, Union |
|
|
|
|
|
import cv2 |
|
|
import numpy as np |
|
|
import torch |
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from matplotlib.backends.backend_agg import FigureCanvasAgg |
|
|
|
|
|
|
|
|
def tensor2ndarray(value: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
    """Return ``value`` as a numpy array when it is a ``torch.Tensor``.

    Tensors are detached from the autograd graph and moved to the CPU before
    conversion. Any other object is returned unchanged.

    Args:
        value (np.ndarray, torch.Tensor): value.

    Returns:
        Any: the converted array, or ``value`` itself if it is not a tensor.
    """
    if not isinstance(value, torch.Tensor):
        return value
    return value.detach().cpu().numpy()
|
|
|
|
|
|
|
|
def value2list(value: Any, valid_type: Union[Type, Tuple[Type, ...]],
               expand_dim: int) -> List[Any]:
    """Broadcast a single ``valid_type`` value into a list of ``expand_dim``.

    If ``value`` is an instance of ``valid_type`` it is repeated
    ``expand_dim`` times. Otherwise it is returned as-is.

    Args:
        value (Any): value.
        valid_type (Union[Type, Tuple[Type, ...]]): valid type.
        expand_dim (int): number of copies to produce.

    Returns:
        List[Any]: the expanded list, or ``value`` unchanged.
    """
    return [value] * expand_dim if isinstance(value, valid_type) else value
|
|
|
|
|
|
|
|
def check_type(name: str, value: Any,
               valid_type: Union[Type, Tuple[Type, ...]]) -> None:
    """Check whether the type of value is in ``valid_type``.

    Args:
        name (str): value name, used in the error message.
        value (Any): value.
        valid_type (Type, Tuple[Type, ...]): expected type.

    Raises:
        TypeError: If ``value`` is not an instance of ``valid_type``.
    """
    if not isinstance(value, valid_type):
        # The fragments previously joined with a double space
        # ("... should be X  but got Y").
        raise TypeError(f'`{name}` should be {valid_type} '
                        f'but got {type(value)}')
|
|
|
|
|
|
|
|
def check_length(name: str, value: Any, valid_length: int) -> None:
    """Ensure a list ``value`` has at least ``valid_length`` elements.

    Values that are not lists are accepted without any check.

    Args:
        name (str): value name, used in the error message.
        value (Any): value.
        valid_length (int): minimum required length.

    Raises:
        AssertionError: If ``value`` is a list shorter than ``valid_length``.
    """
    if not isinstance(value, list):
        return
    actual = len(value)
    if actual < valid_length:
        raise AssertionError(
            f'The length of {name} must equal with or '
            f'greater than {valid_length}, but got {actual}')
|
|
|
|
|
|
|
|
def check_type_and_length(name: str, value: Any,
                          valid_type: Union[Type, Tuple[Type, ...]],
                          valid_length: int) -> None:
    """Check whether the type of value is in ``valid_type``. If type of the
    ``value`` is list, check whether its length is equal with or greater than
    ``valid_length``.

    Args:
        name (str): value name, used in error messages.
        value (Any): value.
        valid_type (Type, Tuple[Type, ...]): expected type(s).
        valid_length (int): minimum length required when ``value`` is a
            list.

    Raises:
        TypeError: If ``value`` is not an instance of ``valid_type``
            (raised by ``check_type``).
        AssertionError: If ``value`` is a list shorter than
            ``valid_length`` (raised by ``check_length``).
    """
    check_type(name, value, valid_type)
    check_length(name, value, valid_length)
|
|
|
|
|
|
|
|
def color_val_matplotlib(
    colors: Union[str, tuple, List[Union[str, tuple]]]
) -> Union[str, tuple, List[Union[str, tuple]]]:
    """Convert various input in RGB order to normalized RGB matplotlib color
    tuples.

    Strings are returned unchanged. Tuples of three channel values in
    ``[0, 255]`` are scaled into ``[0, 1]``. Lists are converted element by
    element.

    Args:
        colors (Union[str, tuple, List[Union[str, tuple]]]): Color inputs.

    Returns:
        Union[str, tuple, List[Union[str, tuple]]]: The input string, a
        tuple of 3 normalized floats indicating RGB channels, or a list of
        those.

    Raises:
        AssertionError: If a tuple does not have exactly 3 channels or a
            channel lies outside ``[0, 255]``.
        TypeError: If ``colors`` (or a list element) is not a str, tuple or
            list.
    """
    if isinstance(colors, str):
        return colors
    if isinstance(colors, tuple):
        # Raise explicitly instead of using ``assert`` so the validation
        # survives ``python -O``. AssertionError is kept so existing callers
        # that catch it keep working.
        if len(colors) != 3:
            raise AssertionError(
                f'A color tuple must have 3 channels, but got {len(colors)}')
        for channel in colors:
            if not 0 <= channel <= 255:
                raise AssertionError(
                    f'Color channels must lie in [0, 255], but got {channel}')
        return tuple(channel / 255 for channel in colors)
    if isinstance(colors, list):
        return [color_val_matplotlib(color) for color in colors]
    raise TypeError(f'Invalid type for color: {type(colors)}')
|
|
|
|
|
|
|
|
def color_str2rgb(color: str) -> tuple:
    """Convert a Matplotlib color string to an RGB tuple in ``[0, 255]``.

    Any alpha channel is silently dropped by ``matplotlib.colors.to_rgb``.

    Args:
        color (str): Matplotlib color.

    Returns:
        tuple: RGB color as three ints.
    """
    # Import the submodule explicitly. ``import matplotlib`` on its own only
    # exposes ``matplotlib.colors`` if the package happens to import it.
    import matplotlib.colors as mcolors
    rgb_color = mcolors.to_rgb(color)
    # ``int()`` truncates toward zero (e.g. 0.5 * 255 -> 127). This is kept
    # so existing colors keep mapping to the same integers.
    return tuple(int(channel * 255) for channel in rgb_color)
|
|
|
|
|
|
|
|
def convert_overlay_heatmap(feat_map: Union[np.ndarray, torch.Tensor],
                            img: Optional[np.ndarray] = None,
                            alpha: float = 0.5) -> np.ndarray:
    """Convert feat_map to heatmap and overlay on image, if image is not None.

    Args:
        feat_map (np.ndarray, torch.Tensor): The feat_map to convert, of
            shape (H, W) or (C, H, W) with C in {1, 3}, where H is the image
            height and W is the image width.
        img (np.ndarray, optional): The origin image. The format should be
            RGB. It is blended with ``cv2.addWeighted``, so it must match the
            heatmap's shape (H, W, 3) and dtype (uint8). Defaults to None.
        alpha (float): The transparency of featmap. Defaults to 0.5.

    Returns:
        np.ndarray: uint8 RGB heatmap, optionally blended with ``img``.

    Raises:
        AssertionError: If ``feat_map`` does not have a supported shape.
    """
    # Raise explicitly instead of using ``assert`` so the check survives
    # ``python -O``. The exception type is unchanged for existing callers.
    if not (feat_map.ndim == 2 or
            (feat_map.ndim == 3 and feat_map.shape[0] in [1, 3])):
        raise AssertionError(
            'feat_map must have shape (H, W) or (C, H, W) with C in '
            f'{{1, 3}}, but got {tuple(feat_map.shape)}')
    if isinstance(feat_map, torch.Tensor):
        feat_map = feat_map.detach().cpu()
        # numpy has no bfloat16, so ``.numpy()`` would raise on it.
        if feat_map.dtype == torch.bfloat16:
            feat_map = feat_map.float()
        feat_map = feat_map.numpy()
    # cv2.normalize only accepts a subset of numpy dtypes (not bool,
    # float16, int64, ...). Promote anything else to float32.
    if feat_map.dtype not in (np.uint8, np.int8, np.uint16, np.int16,
                              np.int32, np.float32, np.float64):
        feat_map = feat_map.astype(np.float32)

    if feat_map.ndim == 3:
        # (C, H, W) -> (H, W, C), the layout OpenCV expects.
        feat_map = feat_map.transpose(1, 2, 0)

    # Min-max scale into [0, 255]. With dst=None OpenCV allocates the output,
    # so there is no need to pre-allocate a throwaway float64 buffer.
    norm_img = cv2.normalize(feat_map, None, 0, 255, cv2.NORM_MINMAX)
    norm_img = np.asarray(norm_img, dtype=np.uint8)
    heat_img = cv2.applyColorMap(norm_img, cv2.COLORMAP_JET)
    # applyColorMap produces BGR; convert to RGB to match ``img``.
    heat_img = cv2.cvtColor(heat_img, cv2.COLOR_BGR2RGB)
    if img is not None:
        heat_img = cv2.addWeighted(img, 1 - alpha, heat_img, alpha, 0)
    return heat_img
|
|
|
|
|
|
|
|
def wait_continue(figure, timeout: float = 0, continue_key: str = ' ') -> int:
    """Show the image and wait for the user's input.

    This implementation refers to
    https://github.com/matplotlib/matplotlib/blob/v3.5.x/lib/matplotlib/_blocking_input.py

    Args:
        figure: The figure to show. Its ``canvas`` receives the key-press and
            close events. Presumably a ``matplotlib.figure.Figure``; confirm
            against callers.
        timeout (float): If positive, continue after ``timeout`` seconds.
            Defaults to 0.
        continue_key (str): The key for users to continue. Defaults to
            the space key.

    Returns:
        int: If zero, means time out or the user pressed ``continue_key``,
        and if one, means the user closed the show figure.
    """
    import matplotlib.pyplot as plt
    from matplotlib.backend_bases import CloseEvent
    is_inline = 'inline' in plt.get_backend()
    if is_inline:
        # With an inline backend there is no interactive window, so waiting
        # for input or a timeout is pointless. Return immediately.
        return 0

    if figure.canvas.manager:
        # Only show the figure when its canvas has a manager.
        figure.show()

    while True:
        # Each iteration waits for one key press, close event, or timeout.
        # A key other than ``continue_key`` re-enters the loop and waits
        # again.
        event = None

        def handler(ev):
            # Record the event and wake up the blocking event loop.
            nonlocal event
            # Once a CloseEvent has been recorded, keep it. A later key press
            # in the same wait must not overwrite it.
            event = ev if not isinstance(event, CloseEvent) else event
            figure.canvas.stop_event_loop()

        cids = [
            figure.canvas.mpl_connect(name, handler)
            for name in ('key_press_event', 'close_event')
        ]

        try:
            figure.canvas.start_event_loop(timeout)
        finally:
            # Always disconnect the callbacks, even if the loop raised.
            for cid in cids:
                figure.canvas.mpl_disconnect(cid)

        if isinstance(event, CloseEvent):
            return 1  # The figure was closed.
        elif event is None or event.key == continue_key:
            return 0  # Timed out, or the continue key was pressed.
|
|
|
|
|
|
|
|
def img_from_canvas(canvas: 'FigureCanvasAgg') -> np.ndarray:
    """Read the rendered pixels of an Agg canvas as an RGB image.

    Args:
        canvas (FigureCanvasAgg): The canvas to get image.

    Returns:
        np.ndarray: uint8 image of shape (height, width, 3) in RGB order.
    """
    raw, (w, h) = canvas.print_to_buffer()
    rgba = np.frombuffer(raw, dtype='uint8').reshape(h, w, 4)
    # Drop the alpha plane. astype() returns a fresh copy, so the result is
    # writable even though np.frombuffer over bytes is read-only.
    return rgba[:, :, :3].astype('uint8')
|
|
|