# hbyecoding's picture
# Upload 143 files
# b2c5353 verified
from typing import Union
import warnings
import numpy as np
import torch
import cv2
import skimage.measure
# -----------------------------------------------------------------------------
# Random sampling
# -----------------------------------------------------------------------------
def _as_single_val(value, high: bool = True) -> Union[int,float]:
"""
Args:
high: if True, include the upper bound for the range (for integer ranges)
"""
if isinstance(value, (int, float)):
return value
if isinstance(value, (tuple, list)):
if len(value) == 1:
return value[0]
else:
assert len(value) == 2, f"Invalid 2-tuple {value}"
if any(isinstance(i, float) for i in value):
value = (float(value[0]), float(value[1]))
if isinstance(value[0], int):
if high:
return np.random.randint(value[0], value[1]+1)
else:
return np.random.randint(*value)
else:
return np.random.uniform(*value)
def chance(x: Union[float,int,bool]) -> bool:
    """
    Randomly return True with probability x.

    Args:
        x: probability of returning True. Exactly 0 and exactly 1 short-circuit
            without drawing a random number; since np.random.rand() samples from
            [0, 1), any x below 0 always gives False and any x above 1 always gives True

    Returns:
        A built-in bool
    """
    if x == 0:
        return False
    if x == 1:
        return True
    # The comparison yields a numpy.bool_; cast so callers get a real bool
    # (identity checks such as `is True` and JSON serialisation then work)
    return bool(np.random.rand() < x)
# -----------------------------------------------------------------------------
# Connected Components
# -----------------------------------------------------------------------------
def get_components(seg: torch.Tensor, background: bool = False, show: bool = False, return_area: bool = False):
    """
    Get a map of all the connected components in a binary segmentation

    Args:
        seg: must be binary 1 x H x W or H x W, with an integer dtype
        background: if False, only get the components where seg != 0, otherwise get components for both the label and background
        show: if True, plot the segmentation and its component maps with neurite
        return_area: if True, return list of areas of each components

    Returns:
        combined_map: H x W integer array. 0 marks unlabeled pixels, the components of
            seg are numbered 1..n, and (when background=True) the components of 1-seg
            are numbered from n+1 onwards
        areas: only when return_area=True, the pixel count of each component, ordered by label

    Raises:
        TypeError: if seg does not have an integer dtype
        ValueError: if seg is not 2D after squeezing
    """
    integer_dtypes = (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
    if seg.dtype not in integer_dtypes:
        raise TypeError("seg must be integer")
    if seg.dim() == 3:
        seg = seg.squeeze()
    if seg.dim() != 2:
        raise ValueError("seg must be 2D")

    # Label a plain CPU ndarray so the result does not depend on the tensor's device
    seg_np = seg.cpu().numpy()
    comps, n = skimage.measure.label(seg_np, return_num=True)

    if background:
        # Get connected components of the background
        back_comps = skimage.measure.label(1 - seg_np)
        # Offset background labels by n so they never collide with foreground labels
        back_comps[back_comps != 0] += n
        combined_map = comps + back_comps
    else:
        combined_map = comps

    if show:
        import os
        os.environ['NEURITE_BACKEND'] = 'pytorch'
        import neurite as ne
        if background:
            ne.plot.slices([seg, comps, back_comps, combined_map],
                           titles=['Seg', 'Seg Components', '1-Seg Components', 'Combined Components'],
                           cmaps=['viridis'], do_colorbars=True)
        else:
            ne.plot.slices([seg, comps], titles=['Seg', "Components"], cmaps=['viridis'], do_colorbars=True, width=10)

    if return_area:
        # One pass over the map counts every label at once; index 0 (unlabeled) is dropped
        areas = list(np.bincount(combined_map.ravel())[1:])
        return combined_map, areas
    return combined_map
# -----------------------------------------------------------------------------
# Distance Transform
# -----------------------------------------------------------------------------
def get_combined_dt(error_region: torch.Tensor, background: bool = False):
    """
    Sum the distance transforms of the false positive and false negative regions of an error map

    Args:
        error_region (torch.Tensor): (1, height, width) on [-1,0,1] where +1 is false positive and -1 is false negative
        background: if True, also add the distance transform of the region where error_region == 0

    Returns:
        Combined distance transform, shape: (height, width)
    """
    def _distance_transform(mask: torch.Tensor):
        # Take the first channel and append a trailing axis, then cast to uint8
        # Note: distanceTransform expects a binary image
        image = mask[0, ..., None].numpy().astype(np.uint8)
        return cv2.distanceTransform(image, cv2.DIST_L2, 0)

    # Keep only the positive part (false positives) / the negative part (false negatives)
    false_pos = torch.abs(torch.clamp(error_region, min=0.0)).cpu()
    false_neg = torch.abs(torch.clamp(error_region, max=0.0)).cpu()

    combined = _distance_transform(false_pos) + _distance_transform(false_neg)
    if not background:
        return combined

    # Pixels with no error form a third region whose transform is added on top
    zero_region = (error_region == 0).float().cpu()
    combined += _distance_transform(zero_region)
    return combined
# -----------------------------------------------------------------------------
# Debugging
# -----------------------------------------------------------------------------
def warn_in_range(tensor, range_to_check=None, name='tensor'):
    """
    Check if tensor contains NaN/Inf and (optional) is in range

    Problems with the tensor's values are reported through warnings.warn;
    nothing is returned. An empty tensor has no min/max, so the range check
    is skipped for it.

    Args:
        tensor: tensor to inspect (must provide isnan, isinf, min, max and numel)
        range_to_check: optional [min, max] pair the values are expected to lie in
        name: label used for the tensor in the warning messages

    Raises:
        ValueError: if range_to_check is given but does not have exactly two entries
    """
    # stacklevel=2 attributes each warning to the caller rather than to this helper
    if tensor.isnan().any():
        warnings.warn(f'{name} contains NaN', stacklevel=2)
    if tensor.isinf().any():
        warnings.warn(f'{name} contains inf', stacklevel=2)

    if range_to_check is None:
        return
    if len(range_to_check) != 2:
        raise ValueError(f'range should be in form [min, max] {range_to_check}')
    if tensor.numel() == 0:
        # min()/max() are undefined on an empty tensor
        return

    lower, upper = range_to_check
    smallest = tensor.min()
    if smallest < lower:
        warnings.warn(f'{name} should be in {range_to_check}, found: {smallest}', stacklevel=2)
    largest = tensor.max()
    if largest > upper:
        warnings.warn(f'{name} should be in {range_to_check}, found: {largest}', stacklevel=2)