# hbyecoding's picture
# Upload 143 files
# b2c5353 verified
from typing import Union, Tuple, List, Optional
import numpy as np
import torch
import kornia
import cv2
import voxynth
from .utils import _as_single_val
import os
# Prevent neurite from trying to load tensorflow
os.environ['NEURITE_BACKEND'] = 'pytorch'
# -----------------------------------------------------------------------------
# Parent class
# -----------------------------------------------------------------------------
class WarpScribble:
    """
    Parent scribble class with shared functions for generating noise masks
    (useful for breaking up scribbles) and applying deformation fields
    (to warp scribbles).

    Subclasses implement `batch_scribble`; `__call__` dispatches to it for
    both batched and single-example masks.
    """

    def __init__(self,
                 warp: bool = True,
                 warp_smoothing: Union[int,Tuple[int],List[int]] = (4, 16),
                 warp_magnitude: Union[int,Tuple[int],List[int]] = (1, 6),
                 mask_smoothing: Union[int,Tuple[int],List[int]] = (4, 16),
                 ):
        """
        Args:
            warp: if True, subclasses warp their scribbles with `apply_warp`
            warp_smoothing: scalar or (low, high) pair forwarded as `warp_smoothing_range`
            warp_magnitude: scalar or (low, high) pair forwarded as `warp_magnitude_range`
            mask_smoothing: smoothing used by `noise_mask`; a tuple is sampled
                uniformly per example, any other value is passed through as-is
        """
        # A scalar (int or float) means "fixed value": expand it to a degenerate
        # [v, v] range so list(...) below does not fail on a non-iterable.
        if isinstance(warp_smoothing, (int, float)):
            warp_smoothing = [warp_smoothing, warp_smoothing]
        if isinstance(warp_magnitude, (int, float)):
            warp_magnitude = [warp_magnitude, warp_magnitude]
        # Warp settings
        self.warp = warp
        self.warp_smoothing = list(warp_smoothing)
        self.warp_magnitude = list(warp_magnitude)
        # Noise mask settings
        self.mask_smoothing = mask_smoothing

    def noise_mask(self, shape: Union[Tuple[int],List[int]] = (8,128,128), device = None):
        """
        Get a random binary mask by thresholding smoothed noise. The mask is used to break up the scribbles.

        Args:
            shape: first entry is the number of masks, last two entries are the spatial size
            device: forwarded to the noise generator
        Returns:
            integer mask of zeros and ones with one mask per entry of shape[0]
            (b x 1 x H x W, assuming the noise generator returns H x W images)
        """
        if isinstance(self.mask_smoothing, tuple):
            # Draw a fresh smoothing value for every example
            get_smoothing = lambda: np.random.uniform(*self.mask_smoothing)
        else:
            # NOTE(review): non-tuple values (including lists) are forwarded unchanged
            get_smoothing = lambda: self.mask_smoothing
        noise = torch.stack([
            voxynth.noise.perlin(shape=shape[-2:], smoothing=get_smoothing(), magnitude=1, device=device)
            for _ in range(shape[0])
        ])  # shape: b x H x W
        noise_mask = (noise > 0.0).int().unsqueeze(1)
        return noise_mask  # shape: b x 1 x H x W

    def apply_warp(self, x: torch.Tensor):
        """
        Warp a given mask x using a random deformation field.

        Returns x unchanged when it is empty or when the warped result is empty;
        otherwise returns the warped mask min-max rescaled to [0,1].
        """
        if x.sum() > 0:
            # warp scribbles using a deformation field
            deformation_field = voxynth.transform.random_transform(
                shape = x.shape[-2:],
                affine_probability = 0.0,
                warp_probability = 1.0,
                warp_integrations = 0,
                warp_smoothing_range = self.warp_smoothing,
                warp_magnitude_range = self.warp_magnitude,
                voxsize = 1,
                device = x.device,
                isdisp = False
            )
            warped = voxynth.transform.spatial_transform(x, trf = deformation_field, isdisp=False)
            if warped.sum() == 0:
                # The warp pushed everything away: fall back to the input
                return x
            lowest, highest = warped.min(), warped.max()
            if highest == lowest:
                # Constant non-zero result: rescaling would divide by zero and
                # produce NaNs, so return the warped values untouched.
                return warped
            return (warped - lowest) / (highest - lowest)
        else:
            # Don't need to warp if mask is empty
            return x

    def batch_scribble(self, mask: torch.Tensor, n_scribbles: int = 1):
        """
        Simulate scribbles for a batch of examples (mask).
        Must be implemented by subclasses.
        """
        raise NotImplementedError

    def __call__(self, mask: torch.Tensor, n_scribbles: int = 1) -> torch.Tensor:
        """
        Args:
            mask: (b,1,H,W) or (1,H,W) mask in [0,1] to sample scribbles from
            n_scribbles: forwarded to `batch_scribble`
        Returns:
            scribble_mask: (b,1,H,W) or (1,H,W) mask(s) of scribbles on [0,1]
        """
        assert len(mask.shape) in [3,4], f"mask must be b x 1 x h x w or 1 x h x w. currently {mask.shape}"
        if len(mask.shape)==3:
            # shape: 1 x h x w -> add a batch axis, then strip it from the result
            return self.batch_scribble(mask[None,...], n_scribbles=n_scribbles)[0,...]
        else:
            # shape: b x 1 x h x w
            return self.batch_scribble(mask, n_scribbles=n_scribbles)
# -----------------------------------------------------------------------------
# Line Scribbles
# -----------------------------------------------------------------------------
class LineScribble(WarpScribble):
    """
    Generates scribbles by
        1) drawing lines connecting random points on the mask
        2) warping with a random deformation field
        3) then correcting any scribbles outside the mask
        4) optionally, limiting the max area of scribbles to k pixels
    """

    def __init__(self,
                 # Warp settings
                 warp: bool = True,
                 warp_smoothing: Union[int,Tuple[int],List[int]] = (4, 16),
                 warp_magnitude: Union[int,Tuple[int],List[int]] = (1, 6),
                 mask_smoothing: Union[int,Tuple[int],List[int]] = (4, 16),
                 # Line scribble settings
                 thickness: int = 1,
                 preserve_scribble: bool = True, # if True, prevents empty scribble masks from being returned
                 max_pixels: Optional[int] = None, # per "scribble"
                 max_pixels_smooth: Optional[int] = 42,
                 # Viz
                 show: bool = False
                 ):
        super().__init__(
            warp=warp,
            warp_smoothing=warp_smoothing,
            warp_magnitude=warp_magnitude,
            mask_smoothing=mask_smoothing,
        )
        self.thickness = thickness
        self.preserve_scribble = preserve_scribble
        self.max_pixels = max_pixels
        self.max_pixels_smooth = max_pixels_smooth
        self.show = show

    def batch_scribble(self, mask: torch.Tensor, n_scribbles: int = 1) -> torch.Tensor:
        """
        Args:
            mask: (b,1,H,W) mask in [0,1] to sample scribbles from
            n_scribbles: number of line scribbles to sample initially
        Returns:
            scribble_mask: (b,1,H,W) mask(s) of scribbles in [0,1]
        """
        bs = mask.shape[0]

        # Points to sample line endpoints from; each row is (batch index, row, col)
        points = torch.nonzero(mask[:,0,...])

        def sample_lines(indices):
            # Draw n_scribbles random lines for one example; `indices` are the
            # rows of `points` that belong to that example.
            image = np.zeros((*mask.shape[-2:], 1))
            if len(indices) > 0:
                # Sample two endpoints per line
                idx = np.random.randint(low=0, high=len(indices), size=2*n_scribbles)
                endpoints = points[indices,1:][idx,0,...]
                # Flip order of coordinates to be xy, and convert to plain Python
                # ints so cv2 never has to parse numpy scalar types
                endpoints = torch.flip(endpoints, dims=(1,)).cpu().tolist()
                # Draw lines between the sample points
                for i in range(n_scribbles):
                    # presumably draws one thickness value per line -- see utils._as_single_val
                    thickness = _as_single_val(self.thickness)
                    image = cv2.line(image, tuple(endpoints[i*2]), tuple(endpoints[i*2+1]), color=1, thickness=thickness)
            return torch.from_numpy(image)  # shape: H x W x 1

        scribbles = torch.stack([
            sample_lines(torch.argwhere(points[:,0]==i)) for i in range(bs)
        ]).to(mask.device).moveaxis(-1,1).float()  # shape: b x 1 x H x W

        if self.warp:
            warped_scribbles = torch.stack([self.apply_warp(scribbles[b,...]) for b in range(bs)])  # shape: b x 1 x H x W
        else:
            warped_scribbles = scribbles

        # Remove lines outside the mask
        corrected_warped_scribbles = mask * warped_scribbles

        if self.preserve_scribble:
            # If none of the scribble falls in the mask after warping, undo warping
            idx = torch.where(torch.sum(corrected_warped_scribbles, dim=(1,2,3)) == 0)
            corrected_warped_scribbles[idx] = mask[idx] * scribbles[idx]

        if self.max_pixels is not None:
            noise = torch.stack([
                voxynth.noise.perlin(shape=mask.shape[-2:], smoothing=self.max_pixels_smooth, magnitude=1, device=mask.device)
                for _ in range(bs)
            ]).unsqueeze(1)  # shape: b x 1 x H x W
            # Shift all noise to be positive
            if noise.min() < 0:
                noise = noise - noise.min()
            # Get the top k pixels
            flat_mask = (noise * corrected_warped_scribbles).view(bs, -1)
            # topk raises if k exceeds the number of pixels, so clamp it
            k = min(self.max_pixels*n_scribbles, flat_mask.shape[1])
            vals, idx = flat_mask.topk(k=k, dim=1)
            binary_mask = torch.zeros_like(flat_mask)
            binary_mask.scatter_(dim=1, index=idx, src=torch.ones_like(flat_mask))
            corrected_warped_scribbles = binary_mask.view(*mask.shape) * corrected_warped_scribbles

        if self.show:
            # Plotting dependencies are imported lazily, only when visualizing
            import neurite as ne
            import matplotlib.pyplot as plt
            from .plot import show_scribbles

            if self.max_pixels is not None:
                binary_mask = binary_mask.reshape(*mask.shape)
                tensors = [mask, scribbles, warped_scribbles, noise, binary_mask, corrected_warped_scribbles, mask]
                titles = ["Mask", "Lines", "Warped Lines", 'Smooth Noise', 'Top k Pixels', 'Corrected Scribbles', 'Corrected Scribbles']
            else:
                tensors = [mask, scribbles, warped_scribbles, corrected_warped_scribbles, mask]
                titles = ["Mask", "Lines", "Warped Lines", 'Corrected Scribbles', 'Corrected Scribbles']

            fig,axes = ne.plot.slices(
                sum([[x[i,0,...].cpu() for x in tensors] for i in range(bs)], []),
                sum([titles for _ in range(bs)], []),
                show=False, grid=(bs,len(titles)), width=3*len(titles), do_colorbars=False
            )
            # The last panel shows the mask; overlay the final scribbles on it
            if bs > 1:
                for i in range(bs):
                    show_scribbles(corrected_warped_scribbles[i,0,...].cpu(), axes[i,-1])
            else:
                show_scribbles(corrected_warped_scribbles[0,0,...].cpu(), axes[-1])
            plt.show()

        return corrected_warped_scribbles  # b x 1 x H x W
# -----------------------------------------------------------------------------
# Medial Axis (Centerline) Scribble
# -----------------------------------------------------------------------------
class CenterlineScribble(WarpScribble):
    """
    Generates scribbles by
        1) skeletonizing the mask
        2) chopping up with a random noise mask
        3) warping with a random deformation field
        4) then correcting any scribbles that fall outside the mask
        5) optionally, limiting the max area of scribbles to k pixels
    """

    def __init__(self,
                 # Warp settings
                 warp: bool = True,
                 warp_smoothing: Union[int,Tuple[int],List[int]] = (4, 16),
                 warp_magnitude: Union[int,Tuple[int],List[int]] = (1, 6),
                 mask_smoothing: Union[int,Tuple[int],List[int]] = (4, 16),
                 # Thickness of skeleton
                 dilate_kernel_size: Optional[int] = None,
                 preserve_scribble: bool = True, # if True, prevents empty scribble masks from being returned
                 max_pixels: Optional[int] = None, # per "scribble"
                 max_pixels_smooth: int = 42,
                 # Viz
                 show : bool = False
                 ):
        super().__init__(
            warp=warp,
            warp_smoothing=warp_smoothing,
            warp_magnitude=warp_magnitude,
            mask_smoothing=mask_smoothing,
        )
        self.dilate_kernel_size = dilate_kernel_size
        self.preserve_scribble = preserve_scribble
        self.max_pixels = max_pixels
        self.max_pixels_smooth = max_pixels_smooth
        self.show = show

    def batch_scribble(self, mask: torch.Tensor, n_scribbles: Optional[int] = 1):
        """
        Simulate scribbles for a batch of examples.
        Args:
            mask: (b,1,H,W) mask in [0,1] to sample scribbles from. torch.int32
            n_scribbles: (int) only used when max_pixels is set as a multiplier for total area of the scribbles
                currently, this argument does not control the number of components in the scribble mask
        Returns:
            scribble_mask: (b,1,H,W) mask(s) of scribbles in [0,1]
        """
        assert len(mask.shape)==4, f"mask must be b x 1 x h x w. currently {mask.shape}"
        bs = mask.shape[0]

        # Channels-last copy scaled to 0/255, with the one-pixel image border zeroed
        mask_w_border = 255*mask.clone().moveaxis(1,-1)  # shape: b x H x W x 1
        mask_w_border[:,:,0,:] = 0
        mask_w_border[:,:,-1,:] = 0
        mask_w_border[:,0,:,:] = 0
        mask_w_border[:,-1,:,:] = 0

        # Skeletonize the mask
        skeleton = torch.from_numpy(
            np.stack([
                cv2.ximgproc.thinning(mask_w_border[i,...].cpu().numpy().astype(np.uint8))/255 for i in range(bs)
            ])
        ).squeeze(-1).unsqueeze(1).to(mask.device).float()  # shape: b x 1 x H x W

        # Default to the plain skeleton so `dilated_skeleton` is always bound,
        # both when no kernel size is configured and when the drawn size is <= 0.
        dilated_skeleton = skeleton
        if self.dilate_kernel_size is not None:
            # Dilate the skeleton to make it thicker
            k = _as_single_val(self.dilate_kernel_size)
            if k > 0:
                kernel = torch.ones((k,k), device=mask.device)
                dilated_skeleton = kornia.morphology.dilation(skeleton, kernel=kernel, engine='convolution')

        noise_mask = self.noise_mask(shape=mask.shape, device=mask.device)

        # Break up the skeleton
        scribbles = (dilated_skeleton * noise_mask)  # shape: b x 1 x H x W

        if self.preserve_scribble:
            # If none of the scribbles fall in the random mask, keep the whole scribble
            # NOTE(review): this restores the un-dilated skeleton, whereas ContourScribble
            # restores its dilated boundary -- confirm which is intended
            idx = torch.where(torch.sum(scribbles, dim=(1,2,3)) == 0)
            scribbles[idx] = skeleton[idx]

        if self.warp:
            warped_scribbles = torch.stack([self.apply_warp(scribbles[b,...]) for b in range(bs)])
        else:
            warped_scribbles = scribbles

        # Remove scribbles that are outside the mask
        corrected_warped_scribbles = mask * warped_scribbles  # shape: b x 1 x H x W

        if self.preserve_scribble:
            # If none of the scribble falls in the mask after warping, remove the warping
            idx = torch.where(torch.sum(corrected_warped_scribbles, dim=(1,2,3)) == 0)
            corrected_warped_scribbles[idx] = mask[idx] * scribbles[idx]

        if self.max_pixels is not None:
            noise = torch.stack([
                voxynth.noise.perlin(shape=mask.shape[-2:], smoothing=self.max_pixels_smooth, magnitude=1, device=mask.device)
                for _ in range(bs)
            ]).unsqueeze(1)  # shape: b x 1 x H x W
            # Shift all noise to be positive
            if noise.min() < 0:
                noise = noise - noise.min()
            flat_mask = (noise * corrected_warped_scribbles).view(bs, -1)
            # topk raises if k exceeds the number of pixels, so clamp it
            k = min(self.max_pixels*n_scribbles, flat_mask.shape[1])
            vals, idx = flat_mask.topk(k=k, dim=1)
            binary_mask = torch.zeros_like(flat_mask)
            binary_mask.scatter_(dim=1, index=idx, src=torch.ones_like(flat_mask))
            corrected_warped_scribbles = binary_mask.view(*mask.shape) * corrected_warped_scribbles

        if self.show:
            # Plotting dependencies are imported lazily, only when visualizing
            import neurite as ne
            from .plot import show_scribbles
            import matplotlib.pyplot as plt

            tensors = [mask, skeleton]
            titles = ["Input Mask", "Skeleton"]

            if self.dilate_kernel_size is not None:
                tensors.append(dilated_skeleton)
                titles.append('Dilated Skeleton')

            if self.max_pixels is not None:
                tensors += [noise_mask, scribbles, warped_scribbles, noise, binary_mask.reshape(*mask.shape), corrected_warped_scribbles, mask]
                titles += ["Noise Mask", 'Broken Skeleton', 'Warped Scribbles', 'Smooth Noise', 'Top k Pixels', 'Corrected Scribbles', 'Corrected Scribbles']
            else:
                tensors += [noise_mask, scribbles, warped_scribbles, corrected_warped_scribbles, mask]
                titles += ["Noise Mask", 'Broken Skeleton', 'Warped Scribbles', 'Corrected Scribbles', 'Corrected Scribbles']

            fig,axes = ne.plot.slices(
                sum([[x[i,...].squeeze().cpu() for x in tensors] for i in range(bs)], []),
                sum([ titles for _ in range(bs)], []),
                show=False, grid=(bs,len(titles)), width=3*len(titles)
            )
            # The last panel shows the mask; overlay the final scribbles on it
            if bs > 1:
                for i in range(bs):
                    show_scribbles(corrected_warped_scribbles[i,0,...].cpu(), axes[i,-1])
            else:
                show_scribbles(corrected_warped_scribbles[0,0,...].cpu(), axes[-1])
            plt.show()

        return corrected_warped_scribbles
# -----------------------------------------------------------------------------
# Contour Scribbles
# -----------------------------------------------------------------------------
class ContourScribble(WarpScribble):
    """
    Generates scribbles by
        1) blurring and thresholding the mask, then getting the contours
        2) chopping up the contour scribbles with a random noise mask
        3) warping with a random deformation field
        4) then correcting any scribbles that fall outside the mask
        5) optionally, limiting the max area of scribbles to k pixels
    """

    def __init__(self,
                 # Warp settings
                 warp: bool = True,
                 warp_smoothing: Union[int,Tuple[int],List[int]] = (4, 16),
                 warp_magnitude: Union[int,Tuple[int],List[int]] = (1, 6),
                 mask_smoothing: Union[int,Tuple[int],List[int]] = (4, 16),
                 # Blur settings
                 blur_kernel_size: int = 33,
                 blur_sigma: Union[float,Tuple[float],List[float]] = (5.0, 20.0),
                 # Other settings
                 dilate_kernel_size: Optional[Union[int, Tuple[int]]] = None,
                 preserve_scribble: bool = True, # if True, prevents empty scribble masks from being returned
                 max_pixels: Optional[int] = None, # per "scribble"
                 max_pixels_smooth: Optional[int] = 42,
                 # Viz
                 show : bool = False
                 ):
        super().__init__(
            warp=warp,
            warp_smoothing=warp_smoothing,
            warp_magnitude=warp_magnitude,
            mask_smoothing=mask_smoothing,
        )
        # Blur settings: a scalar sigma is widened into a (near-degenerate) range
        if isinstance(blur_sigma, (float, int)):
            blur_sigma = (blur_sigma, blur_sigma+1e-7)
        self.blur_fn = kornia.augmentation.RandomGaussianBlur(
            kernel_size=(blur_kernel_size, blur_kernel_size), sigma=blur_sigma, p=1.
        )
        # Line thickness
        self.dilate_kernel_size = dilate_kernel_size
        # Corrections
        self.preserve_scribble = preserve_scribble
        self.max_pixels = max_pixels
        self.max_pixels_smooth = max_pixels_smooth
        # Viz
        self.show = show

    def batch_scribble(self, mask: torch.Tensor, n_scribbles: Optional[int] = 1):
        """
        Args:
            mask: (b,1,H,W) mask in [0,1] to sample scribbles from
            n_scribbles: (int) only used when max_pixels is set as a multiplier for total area of the scribbles
                currently, this argument does not control the number of components in the scribble mask
        Returns:
            scribble_mask: (b,1,H,W) mask(s) of scribbles in [0,1]
        """
        assert len(mask.shape)==4, f"mask must be b x 1 x h x w. currently {mask.shape}"
        bs = mask.shape[0]

        # Blur the inverted mask; the maximum keeps the background at its original value
        rev_mask = (1 - mask)
        blur_mask = self.blur_fn(rev_mask)
        corrected_blur_mask = torch.reshape(torch.maximum(blur_mask, rev_mask), (bs,-1))

        # Randomly sample a threshold for each example, between the overall minimum
        # and the maximum blurred value found inside the mask
        min_bs = corrected_blur_mask.min(1)[0].cpu().numpy()
        blur_inside_mask = (torch.reshape(mask, (bs,-1)) > 0)*corrected_blur_mask
        max_bs = torch.reshape(blur_inside_mask, (bs,-1)).max(1)[0].cpu().numpy()
        thresh = torch.from_numpy(np.random.uniform(min_bs, max_bs, size=bs)).to(mask.device)

        # Apply threshold
        thresh = thresh[...,None].repeat(1,mask.shape[-2]*mask.shape[-1])
        binary_blur_mask = (corrected_blur_mask <= thresh).view(mask.shape).float()

        # Use filter to get contours
        _,boundary = kornia.filters.canny(binary_blur_mask, hysteresis=False)

        # Default to the raw contour; only replaced when a positive kernel size is drawn
        dilated_boundary = boundary
        if self.dilate_kernel_size is not None:
            # Dilate the boundary to make it thicker
            k = _as_single_val(self.dilate_kernel_size)
            if k > 0:
                kernel = torch.ones((k,k), device=boundary.device)
                dilated_boundary = kornia.morphology.dilation(boundary, kernel=kernel, engine='convolution')

        # Get noise mask to break up the contours
        noise_mask = self.noise_mask(shape=mask.shape, device=mask.device)

        # Break up the boundary contours
        scribbles = dilated_boundary * noise_mask  # shape: b x 1 x H x W

        if self.preserve_scribble:
            # If none of the scribbles fall in the noise mask, keep the whole scribble
            idx = torch.where(torch.sum(scribbles, dim=(1,2,3)) == 0)[0]
            scribbles[idx,...] = dilated_boundary[idx,...]

        if self.warp:
            warped_scribbles = torch.stack([self.apply_warp(scribbles[b,...]) for b in range(bs)])
        else:
            warped_scribbles = scribbles

        # Remove scribbles that are outside the mask
        corrected_warped_scribbles = mask * warped_scribbles

        if self.preserve_scribble:
            # If none of the scribble falls in the mask after warping, remove the warping
            idx = torch.where(torch.sum(corrected_warped_scribbles, dim=(1,2,3)) == 0)[0]
            corrected_warped_scribbles[idx,...] = mask[idx,...] * scribbles[idx,...]

        if self.max_pixels is not None:
            noise = torch.stack([
                voxynth.noise.perlin(shape=mask.shape[-2:], smoothing=self.max_pixels_smooth, magnitude=1, device=mask.device)
                for _ in range(bs)
            ]).unsqueeze(1)  # shape: b x 1 x H x W
            # Shift noise to be positive
            if noise.min() < 0:
                noise = noise - noise.min()
            flat_mask = (noise * corrected_warped_scribbles).view(bs, -1)
            # topk raises if k exceeds the number of pixels, so clamp it
            k = min(self.max_pixels*n_scribbles, flat_mask.shape[1])
            vals, idx = flat_mask.topk(k=k, dim=1)
            binary_mask = torch.zeros_like(flat_mask)
            binary_mask.scatter_(dim=1, index=idx, src=torch.ones_like(flat_mask))
            corrected_warped_scribbles = binary_mask.view(*mask.shape) * corrected_warped_scribbles

        if self.show:
            # Plotting dependencies are imported lazily, only when visualizing
            import neurite as ne
            from .plot import show_scribbles
            import matplotlib.pyplot as plt

            tensors = [mask, blur_mask.view(mask.shape), corrected_blur_mask.view(mask.shape), binary_blur_mask, boundary]
            titles = ["Input Mask", "Blurred (Rev.) Mask", 'Corrected Blurred Mask', 'Thresholded Blur. Mask', 'Contours']

            if self.dilate_kernel_size is not None:
                tensors.append(dilated_boundary)
                titles.append('Dilated Contours')

            tensors += [noise_mask, scribbles, warped_scribbles]
            titles += ['Noise Mask', 'Broken Contours', 'Warped Contours']

            if self.max_pixels is not None:
                tensors += [noise, binary_mask.reshape(*mask.shape), corrected_warped_scribbles, mask]
                titles += ['Smooth Noise', 'Top k Pixels', "Corrected Scribbles", 'Corrected Scribbles']
            else:
                tensors += [corrected_warped_scribbles, mask]
                titles += ["Corrected Scribbles", 'Corrected Scribbles']

            fig,axes = ne.plot.slices(
                sum([[x[i,0,...].cpu() for x in tensors] for i in range(bs)], []),
                sum([titles for _ in range(bs)], []),
                show=False, grid=(bs,len(titles)), width=3*len(titles)
            )
            # The last panel shows the mask; overlay the final scribbles on it
            if bs > 1:
                for i in range(bs):
                    show_scribbles(corrected_warped_scribbles[i,0,...].cpu(), axes[i,-1])
            else:
                show_scribbles(corrected_warped_scribbles[0,0,...].cpu(), axes[-1])
            plt.show()

        return corrected_warped_scribbles