|
|
"`Image` provides support to convert, transform and show images" |
|
|
from ..torch_core import * |
|
|
from ..basic_data import * |
|
|
from ..layers import MSELossFlat |
|
|
from io import BytesIO |
|
|
import PIL |
|
|
|
|
|
# Names exported by a star-import of this module: the `PIL` package itself, the image item classes,
# the transform wrappers and the helper functions listed below.
__all__ = ['PIL', 'Image', 'ImageBBox', 'ImageSegment', 'ImagePoints', 'FlowField', 'RandTransform', 'TfmAffine', 'TfmCoord',
           'TfmCrop', 'TfmLighting', 'TfmPixel', 'Transform', 'bb2hw', 'image2np', 'open_image', 'open_mask', 'tis2hw',
           'pil2tensor', 'scale_flow', 'show_image', 'CoordFunc', 'TfmList', 'open_mask_rle', 'rle_encode',
           'rle_decode', 'ResizeMethod', 'plot_flat', 'plot_multi', 'show_multi', 'show_all']
|
|
|
|
|
# Resize strategies accepted by `Image.apply_tfms`. Assuming `IntEnum` is `enum.IntEnum`, the functional
# API numbers members from 1 (CROP=1, PAD=2, SQUISH=3, NO=4); `apply_tfms` relies on that ordering when it
# tests `resize_method <= 2`.
ResizeMethod = IntEnum('ResizeMethod', 'CROP PAD SQUISH NO')
|
|
def pil2tensor(image:Union[NPImage,NPArray],dtype:np.dtype)->TensorImage:
    "Turn a PIL-style (rows, cols[, channels]) `image` array into a channels-first torch tensor of `dtype`."
    arr = np.asarray(image)
    # A 2-d (single channel) input gets a trailing channel axis so the layout is always 3-d.
    if arr.ndim == 2: arr = arr[:, :, None]
    # (rows, cols, channels) -> (channels, rows, cols) in one axis permutation.
    chw = arr.transpose(2, 0, 1)
    return torch.from_numpy(chw.astype(dtype, copy=False))
|
|
|
|
|
def image2np(image:Tensor)->np.ndarray:
    "Move a channels-first `image` tensor to the CPU and return it as a channels-last numpy array."
    hwc = image.cpu().permute(1,2,0).numpy()
    # A single-channel image is returned 2-d (the channel axis is dropped).
    if hwc.shape[2] == 1: return hwc[...,0]
    return hwc
|
|
|
|
|
def bb2hw(a:Collection[int])->np.ndarray:
    "Convert corner coordinates `a` = (a0,a1,a2,a3) to `[a1, a0, a3-a1, a2-a0]` (origin then extents, as `_draw_rect` consumes)."
    first_lo, second_lo, first_hi, second_hi = a[0], a[1], a[2], a[3]
    # Output swaps the two axes of the origin and stores extents instead of the far corner.
    return np.array([second_lo, first_lo, second_hi-second_lo, first_hi-first_lo])
|
|
|
|
|
def tis2hw(size:Union[int,TensorImageSize]) -> Tuple[int,int]:
    "Return the (height,width) pair described by `size`, which is an `int` or a `TensorImageSize`."
    if type(size) is str:
        raise RuntimeError("Expected size to be an int or a tuple, got a string.")
    # An int is used as-is; a sequence contributes its last two entries.
    hw = size if isinstance(size, int) else size[-2:]
    return listify(hw, 2)
|
|
|
|
|
def _draw_outline(o:Patch, lw:int):
    "Give artist `o` a black stroke of width `lw` drawn underneath its normal rendering."
    black_stroke = patheffects.Stroke(linewidth=lw, foreground='black')
    o.set_path_effects([black_stroke, patheffects.Normal()])
|
|
|
|
|
def _draw_rect(ax:plt.Axes, b:Collection[int], color:str='white', text=None, text_size=14):
    "Draw the unfilled, outlined rectangle `b` (origin in `b[:2]`, extents in `b[-2:]`) on `ax`, with optional `text`."
    rect = patches.Rectangle(b[:2], *b[-2:], fill=False, edgecolor=color, lw=2)
    _draw_outline(ax.add_patch(rect), 4)
    if text is None: return
    # The label is anchored at the rectangle's origin and gets a thinner outline.
    label = ax.text(*b[:2], text, verticalalignment='top', color=color, fontsize=text_size, weight='bold')
    _draw_outline(label, 1)
|
|
|
|
|
def _get_default_args(func:Callable):
    "Return a dict mapping every parameter of `func` that declares a default to that default value."
    defaults_by_name = {}
    for name, param in inspect.signature(func).parameters.items():
        if param.default is not inspect.Parameter.empty:
            defaults_by_name[name] = param.default
    return defaults_by_name
|
|
|
|
|
@dataclass
class FlowField():
    "Wrap together some coords `flow` with a `size`."
    # Pair unpacked elsewhere in this file as `h,w = c.size` (see `_affine_mult`).
    size:Tuple[int,int]
    # Coordinate tensor whose last dimension is 2 in the code visible here: an (N,H,W,2) grid built by
    # `_affine_grid`, or an (n_points,2) list of points for `ImagePoints`.
    flow:Tensor
|
|
|
|
|
# Type alias for the callables handed to `Image.coord`: they receive a `FlowField` plus extra positional and
# keyword arguments. NOTE(review): the alias declares a `LogitTensorImage` return, yet `Image.coord` assigns
# the result to `flow` — confirm the intended return type.
CoordFunc = Callable[[FlowField, ArgStar, KWArgs], LogitTensorImage]
|
|
|
|
|
class Image(ItemBase):
    "Support applying transforms to image data in `px`."
    # Transforms are queued lazily: `lighting` stores logit-space pixels in `_logit_px`, `affine` accumulates a
    # matrix in `_affine_mat`, and `coord`/`resize` store a sampling grid in `_flow`. `refresh` (called by the
    # `px` getter) applies whatever is pending and clears it.
    def __init__(self, px:Tensor):
        self._px = px
        self._logit_px=None
        self._flow=None
        self._affine_mat=None
        self.sample_kwargs = {}

    def set_sample(self, **kwargs)->'Image':
        "Set parameters that control how we `grid_sample` the image after transforms are applied."
        self.sample_kwargs = kwargs
        return self

    def clone(self)->'Image':
        "Mimic the behavior of torch.clone for `Image` objects."
        # Reading `self.px` first applies any pending transforms, so the clone starts with nothing queued.
        return self.__class__(self.px.clone())

    # The three properties below read `_px` directly, i.e. without triggering `refresh`.
    @property
    def shape(self)->Tuple[int,int,int]: return self._px.shape
    @property
    def size(self)->Tuple[int,int]: return self.shape[-2:]
    @property
    def device(self)->torch.device: return self._px.device

    def __repr__(self): return f'{self.__class__.__name__} {tuple(self.shape)}'
    # Rich-display hooks; both delegate to `_repr_image_format`.
    def _repr_png_(self): return self._repr_image_format('png')
    def _repr_jpeg_(self): return self._repr_image_format('jpeg')

    def _repr_image_format(self, format_str):
        "Return the image encoded as `format_str` bytes, rendered through `plt.imsave` into an in-memory buffer."
        with BytesIO() as str_buffer:
            plt.imsave(str_buffer, image2np(self.px), format=format_str)
            return str_buffer.getvalue()

    def apply_tfms(self, tfms:TfmList, do_resolve:bool=True, xtra:Optional[Dict[Callable,dict]]=None,
                   size:Optional[Union[int,TensorImageSize]]=None, resize_method:ResizeMethod=None,
                   mult:int=None, padding_mode:str='reflection', mode:str='bilinear', remove_out:bool=True,
                   is_x:bool=True, x_frames:int=1, y_frames:int=1)->TensorImage:
        "Apply all `tfms` to the `Image`, if `do_resolve` picks value for random args."
        # Nothing requested: hand back this very object (not a clone).
        if not (tfms or xtra or size): return self

        # An int `size` combined with more than one frame becomes the pair (size, size*num_frames);
        # `is_x` selects whether `x_frames` or `y_frames` supplies the frame count.
        if size is not None and isinstance(size, int):
            num_frames = x_frames if is_x else y_frames
            if num_frames > 1:
                size = (size, size*num_frames)

        tfms = listify(tfms)
        xtra = ifnone(xtra, {})
        # Default resize method: SQUISH for a list-like size, CROP otherwise.
        default_rsz = ResizeMethod.SQUISH if (size is not None and is_listy(size)) else ResizeMethod.CROP
        resize_method = ifnone(resize_method, default_rsz)
        # `<= 2` selects CROP and PAD (see the numbering of `ResizeMethod`).
        # NOTE(review): `_maybe_add_crop_pad` is not defined in the class body visible here — confirm where it comes from.
        if resize_method <= 2 and size is not None: tfms = self._maybe_add_crop_pad(tfms)
        tfms = sorted(tfms, key=lambda o: o.tfm.order)
        if do_resolve: _resolve_tfms(tfms)
        # All transforms are applied to a clone; `self` is left untouched from here on.
        x = self.clone()
        x.set_sample(padding_mode=padding_mode, mode=mode, remove_out=remove_out)
        if size is not None:
            crop_target = _get_crop_target(size, mult=mult)
            if resize_method in (ResizeMethod.CROP,ResizeMethod.PAD):
                # Aspect-preserving resize computed by `_get_resize_target`; the crop/pad transform finishes the job.
                target = _get_resize_target(x, crop_target, do_crop=(resize_method==ResizeMethod.CROP))
                x.resize(target)
            elif resize_method==ResizeMethod.SQUISH: x.resize((x.shape[0],) + crop_target)
        else: size = x.size
        size_tfms = [o for o in tfms if isinstance(o.tfm,TfmCrop)]
        for tfm in tfms:
            # Per-transform extra keyword arguments take precedence over the generic handling.
            if tfm.tfm in xtra: x = tfm(x, **xtra[tfm.tfm])
            elif tfm in size_tfms:
                # Crop transforms only run for CROP/PAD; for other resize methods they are skipped.
                if resize_method in (ResizeMethod.CROP,ResizeMethod.PAD):
                    x = tfm(x, size=_get_crop_target(size,mult=mult), padding_mode=padding_mode)
            else: x = tfm(x)
        return x.refresh()

    def refresh(self)->'Image':
        "Apply any logit, flow, or affine transfers that have been sent to the `Image`."
        if self._logit_px is not None:
            # In-place sigmoid turns the pending logit pixels back into the pixel buffer.
            self._px = self._logit_px.sigmoid_()
            self._logit_px = None
        if self._affine_mat is not None or self._flow is not None:
            # The `flow` property folds the pending affine matrix into the grid (and clears the matrix).
            self._px = _grid_sample(self._px, self.flow, **self.sample_kwargs)
            self.sample_kwargs = {}
            self._flow = None
        return self

    def save(self, fn:PathOrStr):
        "Save the image to `fn`."
        # Scales by 255 and casts to uint8 — assumes pixel values lie in [0,1]; TODO confirm for all callers.
        x = image2np(self.data*255).astype(np.uint8)
        PIL.Image.fromarray(x).save(fn)

    @property
    def px(self)->TensorImage:
        "Get the tensor pixel buffer."
        self.refresh()
        return self._px
    @px.setter
    def px(self,v:TensorImage)->None:
        "Set the pixel buffer to `v`."
        self._px=v

    @property
    def flow(self)->FlowField:
        "Access the flow-field grid after applying queued affine transforms."
        if self._flow is None:
            self._flow = _affine_grid(self.shape)
        if self._affine_mat is not None:
            self._flow = _affine_mult(self._flow,self._affine_mat)
            self._affine_mat = None
        return self._flow

    @flow.setter
    def flow(self,v:FlowField): self._flow=v

    def lighting(self, func:LightingFunc, *args:Any, **kwargs:Any)->'Image':
        "Equivalent to `image = sigmoid(func(logit(image)))`."
        # Only the logit buffer is updated here; the sigmoid is applied later by `refresh`.
        self.logit_px = func(self.logit_px, *args, **kwargs)
        return self

    def pixel(self, func:PixelFunc, *args, **kwargs)->'Image':
        "Equivalent to `image.px = func(image.px)`."
        # Reading `self.px` applies every pending transform before `func` runs.
        self.px = func(self.px, *args, **kwargs)
        return self

    def coord(self, func:CoordFunc, *args, **kwargs)->'Image':
        "Equivalent to `image.flow = func(image.flow, image.size)`."
        self.flow = func(self.flow, *args, **kwargs)
        return self

    def affine(self, func:AffineFunc, *args, **kwargs)->'Image':
        "Equivalent to `image.affine_mat = image.affine_mat @ func()`."
        m = tensor(func(*args, **kwargs)).to(self.device)
        self.affine_mat = self.affine_mat @ m
        return self

    def resize(self, size:Union[int,TensorImageSize])->'Image':
        "Resize the image to `size`, size can be a single int."
        # Resizing replaces the flow grid, so it must happen before any grid is queued.
        assert self._flow is None
        if isinstance(size, int): size=(self.shape[0], size, size)
        if tuple(size)==tuple(self.shape): return self
        self.flow = _affine_grid(size)
        return self

    @property
    def affine_mat(self)->AffineMatrix:
        "Get the affine matrix that will be applied by `refresh`."
        # Lazily starts from a 3x3 identity on the image's device.
        if self._affine_mat is None:
            self._affine_mat = torch.eye(3).to(self.device)
        return self._affine_mat
    @affine_mat.setter
    def affine_mat(self,v)->None: self._affine_mat=v

    @property
    def logit_px(self)->LogitTensorImage:
        "Get logit(image.px)."
        if self._logit_px is None: self._logit_px = logit_(self.px)
        return self._logit_px
    @logit_px.setter
    def logit_px(self,v:LogitTensorImage)->None: self._logit_px=v

    @property
    def data(self)->TensorImage:
        "Return this images pixels as a tensor."
        return self.px

    def show(self, ax:plt.Axes=None, figsize:tuple=(3,3), title:Optional[str]=None, hide_axis:bool=True,
              cmap:str=None, y:Any=None, **kwargs):
        "Show image on `ax` with `title`, using `cmap` if single-channel, overlaid with optional `y`"
        cmap = ifnone(cmap, defaults.cmap)
        ax = show_image(self, ax=ax, hide_axis=hide_axis, cmap=cmap, figsize=figsize)
        # `kwargs` are forwarded to the overlay's `show`, not to the image itself.
        if y is not None: y.show(ax=ax, **kwargs)
        if title is not None: ax.set_title(title)
|
|
|
|
|
class ImageSegment(Image):
    "An `Image` holding a segmentation mask: lighting is ignored and resampling uses nearest-neighbour."
    def lighting(self, func:LightingFunc, *args:Any, **kwargs:Any)->'Image':
        "Ignore `func`; masks are returned unchanged by lighting transforms."
        return self

    def refresh(self):
        "Force `mode='nearest'` for grid sampling, then apply pending transforms as `Image.refresh` does."
        self.sample_kwargs.update(mode='nearest')
        return super().refresh()

    @property
    def data(self)->TensorImage:
        "Return the mask pixels converted with `.long()`."
        pixels = self.px
        return pixels.long()

    def show(self, ax:plt.Axes=None, figsize:tuple=(3,3), title:Optional[str]=None, hide_axis:bool=True,
        cmap:str='tab20', alpha:float=0.5, **kwargs):
        "Draw the mask on `ax` (semi-transparent, nearest interpolation) and optionally set `title`."
        ax = show_image(
            self, ax=ax, hide_axis=hide_axis, cmap=cmap, figsize=figsize,
            interpolation='nearest', alpha=alpha, vmin=0, **kwargs,
        )
        if title:
            ax.set_title(title)

    def reconstruct(self, t:Tensor):
        "Wrap tensor `t` in a new `ImageSegment`."
        return ImageSegment(t)
|
|
|
|
|
class ImagePoints(Image):
    "An `Image`-like item whose content is a `FlowField` of points rather than pixels."
    def __init__(self, flow:FlowField, scale:bool=True, y_first:bool=True):
        "Store `flow`, optionally rescaling it with `scale_flow` and flipping each point's two coordinates."
        if scale:
            flow = scale_flow(flow)
        if y_first:
            flow.flow = flow.flow.flip(1)
        self._flow = flow
        self._affine_mat = None
        # Coordinate functions queued by `coord`; they are applied by the `flow` getter.
        self.flow_func = []
        self.sample_kwargs = {}
        # Set once a transform has touched the points; `data` uses it to decide whether to filter them.
        self.transformed = False
        self.loss_func = MSELossFlat()

    def clone(self):
        "Return a new object of the same class wrapping a copy of the current points."
        copied = FlowField(self.size, self.flow.flow.clone())
        return self.__class__(copied, scale=False, y_first=False)

    @property
    def shape(self)->Tuple[int,int,int]:
        return (1, *self._flow.size)

    @property
    def size(self)->Tuple[int,int]:
        return self._flow.size

    @size.setter
    def size(self, sz:int):
        self._flow.size = sz

    @property
    def device(self)->torch.device:
        return self._flow.flow.device

    def __repr__(self):
        return f'{self.__class__.__name__} {tuple(self.size)}'

    def _repr_image_format(self, format_str):
        # Points have no image representation.
        return None

    @property
    def flow(self)->FlowField:
        "Return the points after applying the queued affine matrix and coordinate functions."
        if self._affine_mat is not None:
            self._flow = _affine_inv_mult(self._flow, self._affine_mat)
            self._affine_mat = None
            self.transformed = True
        if self.flow_func:
            # Queued functions run last-queued first.
            for step in self.flow_func[::-1]:
                self._flow = step(self._flow)
            self.transformed = True
            self.flow_func = []
        return self._flow

    @flow.setter
    def flow(self,v:FlowField):
        self._flow = v

    def coord(self, func:CoordFunc, *args, **kwargs)->'ImagePoints':
        "Queue `func` (bound to `args`/`kwargs`) in `self.flow_func`; an `invert` kwarg is forced to `True`."
        if 'invert' not in kwargs:
            warn(f"{func.__name__} isn't implemented for {self.__class__}.")
        else:
            kwargs['invert'] = True
        self.flow_func.append(partial(func, *args, **kwargs))
        return self

    def lighting(self, func:LightingFunc, *args:Any, **kwargs:Any)->'ImagePoints':
        # Lighting transforms do nothing to points.
        return self

    def pixel(self, func:PixelFunc, *args, **kwargs)->'ImagePoints':
        "Call `func` on this object itself and return its result, flagged as transformed."
        result = func(self, *args, **kwargs)
        result.transformed = True
        return result

    def refresh(self) -> 'ImagePoints':
        # Nothing to resample: pending work is applied by the `flow` getter instead.
        return self

    def resize(self, size:Union[int,TensorImageSize]) -> 'ImagePoints':
        "Record the new size on the flow; `size` is an int or a (channels, rows, cols) style sequence."
        new_size = (size, size) if isinstance(size, int) else size[1:]
        self._flow.size = new_size
        return self

    @property
    def data(self)->Tensor:
        "Return the points tensor with each point's two coordinates flipped back."
        # Reading `self.flow` first applies pending transforms and updates `self.transformed`.
        current = self.flow
        if self.transformed:
            if self.sample_kwargs.get('remove_out', True):
                current = _remove_points_out(current)
            self.transformed = False
        return current.flow.flip(1)

    def show(self, ax:plt.Axes=None, figsize:tuple=(3,3), title:Optional[str]=None, hide_axis:bool=True, **kwargs):
        "Scatter-plot the points on `ax`, creating a figure of `figsize` when `ax` is not given."
        if ax is None:
            _, ax = plt.subplots(figsize=figsize)
        unit_field = FlowField(self.size, self.data)
        pnt = scale_flow(unit_field, to_unit=False).flow.flip(1)
        params = {'s': 10, 'marker': '.', 'c': 'r', **kwargs}
        ax.scatter(pnt[:, 0], pnt[:, 1], **params)
        if hide_axis:
            ax.axis('off')
        if title:
            ax.set_title(title)
|
|
|
|
|
class ImageBBox(ImagePoints):
    "Support applying transforms to a `flow` of bounding boxes."
    def __init__(self, flow:FlowField, scale:bool=True, y_first:bool=True, labels:Collection=None,
                 classes:dict=None, pad_idx:int=0):
        "Store the corner points in `flow`; raw `labels` are wrapped as `Category(l, classes[l])`."
        super().__init__(flow, scale, y_first)
        self.pad_idx = pad_idx
        # Labels that are not already `Category` objects are converted, looking each one up in `classes`.
        if labels is not None and len(labels)>0 and not isinstance(labels[0],Category):
            labels = array([Category(l,classes[l]) for l in labels])
        self.labels = labels

    def clone(self) -> 'ImageBBox':
        "Mimic the behavior of torch.clone for `ImageBBox` objects."
        flow = FlowField(self.size, self.flow.flow.clone())
        return self.__class__(flow, scale=False, y_first=False, labels=self.labels, pad_idx=self.pad_idx)

    @classmethod
    def create(cls, h:int, w:int, bboxes:Collection[Collection[int]], labels:Collection=None, classes:dict=None,
               pad_idx:int=0, scale:bool=True)->'ImageBBox':
        "Create an ImageBBox object from `bboxes` (one row of four coordinates per box) for an `h` x `w` image."
        # An object-dtype array (e.g. an array of lists) is rebuilt as a regular numeric array.
        # The builtin `object` is used because the `np.object` alias no longer exists in NumPy >= 1.24.
        if isinstance(bboxes, np.ndarray) and bboxes.dtype == object: bboxes = np.array([bb for bb in bboxes])
        bboxes = tensor(bboxes).float()
        # Expand each box (b0,b1,b2,b3) into its four corners: (b0,b1), (b0,b3), (b2,b1), (b2,b3).
        tr_corners = torch.cat([bboxes[:,0][:,None], bboxes[:,3][:,None]], 1)
        bl_corners = bboxes[:,1:3].flip(1)
        bboxes = torch.cat([bboxes[:,:2], tr_corners, bl_corners, bboxes[:,2:]], 1)
        flow = FlowField((h,w), bboxes.view(-1,2))
        return cls(flow, labels=labels, classes=classes, pad_idx=pad_idx, y_first=True, scale=scale)

    def _compute_boxes(self) -> Tuple[LongTensor, LongTensor]:
        "Return the axis-aligned boxes enclosing each group of four (clamped) corners, plus the matching labels."
        bboxes = self.flow.flow.flip(1).view(-1, 4, 2).contiguous().clamp(min=-1, max=1)
        mins, maxes = bboxes.min(dim=1)[0], bboxes.max(dim=1)[0]
        bboxes = torch.cat([mins, maxes], 1)
        # Keep only boxes that still have a strictly positive extent on both axes after clamping.
        mask = (bboxes[:,2]-bboxes[:,0] > 0) * (bboxes[:,3]-bboxes[:,1] > 0)
        if len(mask) == 0: return tensor([self.pad_idx] * 4), tensor([self.pad_idx])
        res = bboxes[mask]
        if self.labels is None: return res,None
        return res, self.labels[to_np(mask).astype(bool)]

    @property
    def data(self)->Union[FloatTensor, Tuple[FloatTensor,LongTensor]]:
        "Return the boxes alone, or `(boxes, label_data)` when labels are present."
        bboxes,lbls = self._compute_boxes()
        lbls = np.array([o.data for o in lbls]) if lbls is not None else None
        return bboxes if lbls is None else (bboxes, lbls)

    def show(self, y:Image=None, ax:plt.Axes=None, figsize:tuple=(3,3), title:Optional[str]=None, hide_axis:bool=True,
        color:str='white', **kwargs):
        "Show the `ImageBBox` on `ax`."
        if ax is None: _,ax = plt.subplots(figsize=figsize)
        bboxes, lbls = self._compute_boxes()
        h,w = self.flow.size
        # Map coordinates from [-1,1] to pixel units, in place.
        bboxes.add_(1).mul_(torch.tensor([h/2, w/2, h/2, w/2]))
        for i, bbox in enumerate(bboxes):
            text = str(lbls[i]) if lbls is not None else None
            _draw_rect(ax, bb2hw(bbox), text=text, color=color)
|
|
|
|
|
def open_image(fn:PathOrStr, div:bool=True, convert_mode:str='RGB', cls:type=Image,
        after_open:Callable=None)->Image:
    "Open the image file `fn`, convert it to `convert_mode` and wrap the resulting float tensor in `cls`."
    # `UserWarning`s raised while opening/converting are silenced.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", UserWarning)
        pil_img = PIL.Image.open(fn).convert(convert_mode)
    # Optional hook applied to the PIL image before it becomes a tensor.
    if after_open:
        pil_img = after_open(pil_img)
    pixels = pil2tensor(pil_img, np.float32)
    if div:
        pixels.div_(255)
    return cls(pixels)
|
|
|
|
|
def open_mask(fn:PathOrStr, div=False, convert_mode='L', after_open:Callable=None)->ImageSegment:
    "Open the mask file `fn` as an `ImageSegment` via `open_image`; pixel values are divided by 255 only if `div`."
    return open_image(
        fn,
        div=div,
        convert_mode=convert_mode,
        cls=ImageSegment,
        after_open=after_open,
    )
|
|
|
|
|
def open_mask_rle(mask_rle:str, shape:Tuple[int, int])->ImageSegment:
    "Return an `ImageSegment` created from the run-length encoded string `mask_rle` with size `shape`."
    decoded = rle_decode(str(mask_rle), shape).astype(np.uint8)
    mask = FloatTensor(decoded)
    # View as (shape[1], shape[0], 1), then move the trailing axis to the front.
    mask = mask.view(shape[1], shape[0], -1)
    return ImageSegment(mask.permute(2,0,1))
|
|
|
|
|
def rle_encode(img:NPArrayMask)->str:
    "Encode the flattened mask `img` as a space-separated string of 1-based `start length` pairs."
    # Zero padding on both ends guarantees every run has a detectable start and end.
    padded = np.concatenate([[0], img.flatten() , [0]])
    # 1-based positions where the value changes: run starts at even slots, run ends at odd slots.
    edges = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Replace each run end by the run length.
    edges[1::2] -= edges[::2]
    return ' '.join(map(str, edges))
|
|
|
|
|
def rle_decode(mask_rle:str, shape:Tuple[int,int])->NPArrayMask:
    "Decode the `start length` pairs in `mask_rle` (1-based starts) into a 0/1 array reshaped to `shape`."
    tokens = mask_rle.split()
    # Even tokens are starts (converted to 0-based), odd tokens are run lengths.
    starts = np.asarray(tokens[0::2], dtype=int)
    lengths = np.asarray(tokens[1::2], dtype=int)
    starts -= 1
    stops = starts + lengths
    flat = np.zeros(shape[0]*shape[1], dtype=np.uint)
    for begin, stop in zip(starts, stops):
        flat[begin:stop] = 1
    return flat.reshape(shape)
|
|
|
|
|
def show_image(img:Image, ax:plt.Axes=None, figsize:tuple=(3,3), hide_axis:bool=True, cmap:str='binary',
        alpha:float=None, **kwargs)->plt.Axes:
    "Draw `img.data` on `ax` with `imshow` (creating a figure of `figsize` if `ax` is None) and return the axes."
    if ax is None:
        _, ax = plt.subplots(figsize=figsize)
    pixels = image2np(img.data)
    ax.imshow(pixels, cmap=cmap, alpha=alpha, **kwargs)
    if hide_axis:
        ax.axis('off')
    return ax
|
|
|
|
|
def scale_flow(flow, to_unit=True):
    "Rescale `flow.flow` in place: to the -1/1 range if `to_unit`, otherwise back to `flow.size` units."
    # Half of each of the two size entries, with a leading axis for broadcasting over the points.
    half = tensor([flow.size[0]/2,flow.size[1]/2])[None]
    flow.flow = (flow.flow/half - 1) if to_unit else (flow.flow + 1)*half
    return flow
|
|
|
|
|
def _remove_points_out(flow:FlowField):
    "Drop from `flow.flow` every point with a coordinate outside [-1, 1]; `flow` is modified and returned."
    pts = flow.flow
    first, second = pts[:,0], pts[:,1]
    pad_mask = (first >= -1) * (first <= 1) * (second >= -1) * (second <= 1)
    flow.flow = pts[pad_mask]
    return flow
|
|
|
|
|
class Transform():
    "Wrap a transform function `func`: record its annotations and defaults, and expose it as an `Image` method."
    _wrap=None
    order=0
    def __init__(self, func:Callable, order:Optional[int]=None):
        "Wrap `func`, optionally overriding the class-level `order`, and attach it to the `Image` class."
        if order is not None:
            self.order = order
        self.func = func
        # The stored function is renamed without its first character.
        func.__name__ = func.__name__[1:]
        functools.update_wrapper(self, self.func)
        func.__annotations__['return'] = Image
        # `params` keeps the annotations (used by `RandTransform.resolve`), `def_args` the declared defaults.
        self.params = copy(func.__annotations__)
        self.def_args = _get_default_args(func)
        # Make the transform callable as `image.<name>(...)`.
        setattr(Image, func.__name__,
                lambda x, *args, **kwargs: self.calc(x, *args, **kwargs))

    def __call__(self, *args:Any, p:float=1., is_random:bool=True, use_on_y:bool=True, **kwargs:Any)->Image:
        "With positional `args` apply the transform now; otherwise build a `RandTransform` run with probability `p`."
        if not args:
            return RandTransform(self, kwargs=kwargs, is_random=is_random, use_on_y=use_on_y, p=p)
        return self.calc(*args, **kwargs)

    def calc(self, x:Image, *args:Any, **kwargs:Any)->Image:
        "Apply the function to `x`, going through the `x.<_wrap>` method when `_wrap` is set."
        if not self._wrap:
            return self.func(x, *args, **kwargs)
        dispatcher = getattr(x, self._wrap)
        return dispatcher(self.func, *args, **kwargs)

    @property
    def name(self)->str:
        return self.__class__.__name__

    def __repr__(self)->str:
        return f'{self.name} ({self.func.__name__})'
|
|
|
|
|
@dataclass
class RandTransform():
    "Pair a `Transform` with its call `kwargs`; `resolve` fixes the argument values and whether it runs."
    tfm:Transform
    kwargs:dict
    p:float=1.0
    resolved:dict = field(default_factory=dict)
    do_run:bool = True
    is_random:bool = True
    use_on_y:bool = True

    def __post_init__(self):
        # Copy the wrapped transform's metadata onto this instance.
        functools.update_wrapper(self, self.tfm)

    def resolve(self)->None:
        "Choose the argument values for subsequent calls and draw whether the transform runs."
        tfm = self.tfm
        if not self.is_random:
            # Deterministic mode: declared defaults overridden by the explicit kwargs; `do_run` is untouched.
            fixed = dict(tfm.def_args)
            fixed.update(self.kwargs)
            self.resolved = fixed
            return

        resolved = self.resolved = {}
        # 1. Explicit kwargs: a parameter with an annotation has that annotation called on the value(s);
        #    any other parameter keeps its value as given.
        for name, value in self.kwargs.items():
            if name in tfm.params:
                resolved[name] = tfm.params[name](*listify(value))
            else:
                resolved[name] = value
        # 2. Declared defaults fill whatever is still missing.
        for name, default in tfm.def_args.items():
            resolved.setdefault(name, default)
        # 3. Remaining annotated parameters are obtained by calling the annotation with no arguments.
        for name, sampler in tfm.params.items():
            if name == 'return' or name in resolved:
                continue
            resolved[name] = sampler()

        self.do_run = rand_bool(self.p)

    @property
    def order(self)->int:
        return self.tfm.order

    def __call__(self, x:Image, *args, **kwargs)->Image:
        "Apply the transform to `x` with the resolved arguments (overridable by `kwargs`) unless `do_run` is false."
        if not self.do_run:
            return x
        merged = {**self.resolved, **kwargs}
        return self.tfm(x, *args, **merged)
|
|
|
|
|
def _resolve_tfms(tfms:TfmList):
    "Call `resolve()` on each transform in `tfms` (a single transform or a collection)."
    for rand_tfm in listify(tfms):
        rand_tfm.resolve()
|
|
|
|
|
def _grid_sample(x:TensorImage, coords:FlowField, mode:str='bilinear', padding_mode:str='reflection', remove_out:bool=True)->TensorImage:
    "Resample pixels in `coords` from `x` by `mode`, with `padding_mode` in ('reflection','border','zeros')."
    # NOTE: `remove_out` is accepted (it arrives through `Image.sample_kwargs`) but is not used in this function.
    # Round trip through a channels-first contiguous copy; presumably a memory-layout optimisation for
    # `grid_sample` — TODO confirm.
    coords = coords.flow.permute(0, 3, 1, 2).contiguous().permute(0, 2, 3, 1)
    if mode=='bilinear':
        mn,mx = coords.min(),coords.max()
        # `z` is 2 divided by the span of the sampled coordinates (2 being the width of the full [-1,1] range).
        # NOTE(review): a grid whose coordinates are all equal makes the span 0 and this division fail.
        z = 1/(mx-mn).item()*2
        # `d` is half of the smaller of the two ratios between the input's and the grid's spatial sizes.
        d = min(x.shape[1]/coords.shape[1], x.shape[2]/coords.shape[2])/2
        # When `d` exceeds both 1 and `z`, `x` is first shrunk by a factor `1/d` with area interpolation.
        if d>1 and d>z: x = F.interpolate(x[None], scale_factor=1/d, mode='area')[0]
    # `x[None]` adds a leading batch axis for `grid_sample`; `[0]` removes it again.
    return F.grid_sample(x[None], coords, mode=mode, padding_mode=padding_mode)[0]
|
|
|
|
|
def _affine_grid(size:TensorImageSize)->FlowField:
    "Build an identity sampling grid over [-1,1] for an image of `size` (channels, rows, cols)."
    size = ((1,)+size)
    N, C, H, W = size
    grid = FloatTensor(N, H, W, 2)
    # Channel 0 varies along the columns, channel 1 along the rows; a length-1 axis sits at -1.
    col_coords = torch.linspace(-1, 1, W) if W > 1 else tensor([-1])
    grid[..., 0] = torch.ger(torch.ones(H), col_coords).expand_as(grid[..., 0])
    row_coords = torch.linspace(-1, 1, H) if H > 1 else tensor([-1])
    grid[..., 1] = torch.ger(row_coords, torch.ones(W)).expand_as(grid[..., 1])
    return FlowField(size[2:], grid)
|
|
|
|
|
def _affine_mult(c:FlowField,m:AffineMatrix)->FlowField:
    "Apply the affine matrix `m` to the coordinates of `c` in place, correcting for a non-square `c.size`."
    if m is None: return c
    orig_shape = c.flow.size()
    rows, cols = c.size
    # The off-diagonal terms of `m` are rescaled in place by the aspect ratio.
    m[0,1] *= rows/cols
    m[1,0] *= cols/rows
    c.flow = c.flow.view(-1,2)
    linear, shift = m[:2,:2], m[:2,2]
    # shift + points @ linear^T, restored to the original grid shape.
    c.flow = torch.addmm(shift, c.flow, linear.t()).view(orig_shape)
    return c
|
|
|
|
|
def _affine_inv_mult(c, m):
    "Apply the inverse of the affine transform `m` to the coordinates of `c` in place and return `c`."
    orig_shape = c.flow.size()
    rows, cols = c.size
    # The off-diagonal terms of `m` are rescaled in place by the aspect ratio.
    m[0,1] *= rows/cols
    m[1,0] *= cols/rows
    c.flow = c.flow.view(-1,2)
    inv_linear = torch.inverse(m[:2,:2].t())
    # Undo the translation first, then the linear part.
    c.flow = torch.mm(c.flow - m[:2,2], inv_linear).view(orig_shape)
    return c
|
|
|
|
|
class TfmAffine(Transform):
    "Transform wrapper whose function is dispatched through the image's `affine` method."
    order = 5
    _wrap = 'affine'
|
|
class TfmPixel(Transform):
    "Transform wrapper whose function is dispatched through the image's `pixel` method."
    order = 10
    _wrap = 'pixel'
|
|
class TfmCoord(Transform):
    "Transform wrapper whose function is dispatched through the image's `coord` method."
    order = 4
    _wrap = 'coord'
|
|
class TfmCrop(TfmPixel):
    "Decorator for crop tfm funcs."
    # Highest `order` among the wrappers in this file, so crops sort last in `Image.apply_tfms`,
    # which also singles out `TfmCrop` instances to pass them `size` and `padding_mode`.
    order=99
|
|
class TfmLighting(Transform):
    "Transform wrapper whose function is dispatched through the image's `lighting` method."
    order = 8
    _wrap = 'lighting'
|
|
|
|
|
def _round_multiple(x:int, mult:int=None)->int:
    "Round `x` to the nearest multiple of `mult` (halves round up); return `x` unchanged when `mult` is None."
    if mult is None: return x
    return int(x/mult + 0.5) * mult
|
|
|
|
|
def _get_crop_target(target_px:Union[int,TensorImageSize], mult:int=None)->Tuple[int,int]:
    "Return the (rows, cols) of `target_px`, each rounded to the nearest multiple of `mult`."
    rows, cols = tis2hw(target_px)
    rounded_rows = _round_multiple(rows, mult)
    rounded_cols = _round_multiple(cols, mult)
    return rounded_rows, rounded_cols
|
|
|
|
|
def _get_resize_target(img, crop_target, do_crop=False)->TensorImageSize:
    "Return the aspect-preserving (channels, rows, cols) for `img` relative to `crop_target`: covering it if `do_crop`, fitting inside otherwise."
    if crop_target is None: return None
    ch, rows, cols = img.shape
    target_rows, target_cols = crop_target
    # Cropping divides by the smaller ratio (result covers the target); padding by the larger one.
    pick = min if do_crop else max
    ratio = pick(rows/target_rows, cols/target_cols)
    return ch, int(round(rows/ratio)), int(round(cols/ratio))
|
|
|
|
|
def plot_flat(r, c, figsize):
    "Shortcut for `enumerate(subplots.flatten())`"
    # `squeeze=False` makes `subplots` always return a 2-d array of axes, so flattening also works for a
    # 1x1 grid (where the default would return a bare `Axes` without `.flatten()`).
    _, axes = plt.subplots(r, c, figsize=figsize, squeeze=False)
    return enumerate(axes.flatten())
|
|
|
|
|
def plot_multi(func:Callable[[int,int,plt.Axes],None], r:int=1, c:int=1, figsize:Tuple=(12,6)):
    "Call `func` for every combination of `r,c` on a subplot"
    # `squeeze=False` keeps `axes` 2-d even when `r` or `c` is 1 (including the defaults), so the
    # `axes[i,j]` lookup below is always valid.
    _, axes = plt.subplots(r, c, figsize=figsize, squeeze=False)
    for i in range(r):
        for j in range(c): func(i,j,axes[i,j])
|
|
|
|
|
def show_multi(func:Callable[[int,int],Image], r:int=1, c:int=1, figsize:Tuple=(9,9)):
    "Show `func(i,j)` on the subplot at row `i`, column `j` of an `r` x `c` grid."
    def _show_cell(i, j, ax):
        return func(i,j).show(ax)
    plot_multi(_show_cell, r, c, figsize=figsize)
|
|
|
|
|
def show_all(imgs:Collection[Image], r:int=1, c:Optional[int]=None, figsize=(12,6)):
    "Show all `imgs` using `r` rows"
    imgs = listify(imgs)
    # Ceiling division: a trailing, partially filled column is kept instead of silently dropping images.
    if c is None: c = -(-len(imgs)//r)
    for i,ax in plot_flat(r,c,figsize):
        # The grid may hold more cells than there are images; leftover cells are blanked rather than indexed.
        if i < len(imgs): imgs[i].show(ax)
        else: ax.axis('off')
|
|
|