Spaces:
Sleeping
Sleeping
| from typing import Union | |
| import warnings | |
| import numpy as np | |
| import torch | |
| import cv2 | |
| import skimage.measure | |
| # ----------------------------------------------------------------------------- | |
| # Random sampling | |
| # ----------------------------------------------------------------------------- | |
| def _as_single_val(value, high: bool = True) -> Union[int,float]: | |
| """ | |
| Args: | |
| high: if True, include the upper bound for the range (for integer ranges) | |
| """ | |
| if isinstance(value, (int, float)): | |
| return value | |
| if isinstance(value, (tuple, list)): | |
| if len(value) == 1: | |
| return value[0] | |
| else: | |
| assert len(value) == 2, f"Invalid 2-tuple {value}" | |
| if any(isinstance(i, float) for i in value): | |
| value = (float(value[0]), float(value[1])) | |
| if isinstance(value[0], int): | |
| if high: | |
| return np.random.randint(value[0], value[1]+1) | |
| else: | |
| return np.random.randint(*value) | |
| else: | |
| return np.random.uniform(*value) | |
| def chance(x: Union[float,int,bool]) -> bool: | |
| """ | |
| Args: | |
| x: probability of returning True | |
| """ | |
| if x == 0: | |
| return False | |
| elif x == 1: | |
| return True | |
| else: | |
| return np.random.rand() < x | |
| # ----------------------------------------------------------------------------- | |
| # Connected Components | |
| # ----------------------------------------------------------------------------- | |
| def get_components(seg: torch.Tensor, background: bool = False, show: bool = False, return_area: bool = False): | |
| """ | |
| Get a map of all the components in an image | |
| Args: | |
| seg: must be binary 1 x H x W or H x W | |
| background: if False, only get the components where seg != 0, otherwise get components for both the label and background | |
| show: if True, plot | |
| return_area: if True, return list of areas of each components | |
| """ | |
| assert seg.dtype == torch.int8 or seg.dtype == torch.int32, "seg must be integer" | |
| if len(seg.shape)==3: | |
| seg = seg.squeeze() | |
| assert len(seg.shape)==2, "seg must be 2D" | |
| comps, n = skimage.measure.label(seg, return_num=True) | |
| if background: | |
| # Get connected components of the background | |
| back_comps = skimage.measure.label(1-seg) | |
| back_comps[ np.nonzero(back_comps) ] = back_comps[ np.nonzero(back_comps) ] + n | |
| combined_map = comps + back_comps | |
| else: | |
| combined_map = comps | |
| if show: | |
| import os | |
| os.environ['NEURITE_BACKEND'] = 'pytorch' | |
| import neurite as ne | |
| if background: | |
| ne.plot.slices([seg, comps, back_comps, combined_map], | |
| titles=['Seg', 'Seg Components', '1-Seg Components', 'Combined Components'], | |
| cmaps=['viridis'], do_colorbars=True) | |
| else: | |
| ne.plot.slices([seg, comps], titles=['Seg', "Components"], cmaps=['viridis'], do_colorbars=True, width=10) | |
| if return_area: | |
| n_components = combined_map.max() | |
| areas = [(combined_map==i).sum() for i in range(1, n_components+1)] | |
| return combined_map, areas | |
| else: | |
| return combined_map | |
| # ----------------------------------------------------------------------------- | |
| # Distance Transform | |
| # ----------------------------------------------------------------------------- | |
| def get_combined_dt(error_region: torch.Tensor, background: bool = False): | |
| """ | |
| Get a combined distance transform of the false positives and false negatives | |
| Args: | |
| error_region (torch.Tensor): (1, height, width) on [-1,0,1] where +1 is false positive and -1 is false negative | |
| """ | |
| fp_mask = torch.abs(torch.clamp(error_region, min=0.0)).cpu() | |
| fn_mask = torch.abs(torch.clamp(error_region, max=0.0)).cpu() | |
| # Note: distanceTransform expects a binary image | |
| fp_mask_dt = cv2.distanceTransform(fp_mask[0,...,None].numpy().astype(np.uint8), cv2.DIST_L2, 0) | |
| fn_mask_dt = cv2.distanceTransform(fn_mask[0,...,None].numpy().astype(np.uint8), cv2.DIST_L2, 0) | |
| mask_dt = fp_mask_dt + fn_mask_dt # shape: (height, width) | |
| if background: | |
| background_mask = (error_region==0).float().cpu() | |
| background_mask_dt = cv2.distanceTransform(background_mask[0,...,None].numpy().astype(np.uint8), cv2.DIST_L2, 0) | |
| mask_dt += background_mask_dt | |
| return mask_dt | |
| # ----------------------------------------------------------------------------- | |
| # Debugging | |
| # ----------------------------------------------------------------------------- | |
| def warn_in_range(tensor, range_to_check=None, name='tensor'): | |
| """ | |
| Check if tensor contains NaN/Inf and (optional) is in range | |
| """ | |
| if tensor.isnan().any(): | |
| warnings.warn(f'{name} contains NaN') | |
| if tensor.isinf().any(): | |
| warnings.warn(f'{name} contains inf') | |
| if range_to_check is not None: | |
| assert len(range_to_check) == 2, f'range should be in form [min, max] {range_to_check}' | |
| if tensor.min() < range_to_check[0]: | |
| warnings.warn(f'{name} should be in {range_to_check}, found: {tensor.min()}') | |
| if tensor.max() > range_to_check[1]: | |
| warnings.warn(f'{name} should be in {range_to_check}, found: {tensor.max()}') |