from typing import Dict, Optional, Tuple, Union

import torch

from kornia.geometry.bbox import bbox_generator3d
from kornia.utils import _extract_device_dtype

from ..utils import _adapted_uniform, _joint_range_check


def random_rotation_generator3d(
    batch_size: int,
    degrees: torch.Tensor,
    same_on_batch: bool = False,
    device: torch.device = torch.device('cpu'),
    dtype: torch.dtype = torch.float32,
) -> Dict[str, torch.Tensor]:
    r"""Get parameters for ``rotate`` for a random rotate transform.

    Args:
        batch_size (int): the tensor batch size.
        degrees (torch.Tensor): Ranges of degrees (3, 2) for yaw, pitch and roll.
        same_on_batch (bool): apply the same transformation across the batch. Default: False.
        device (torch.device): the device on which the random numbers will be generated. Default: cpu.
        dtype (torch.dtype): the data type of the generated random numbers. Default: float32.

    Returns:
        params Dict[str, torch.Tensor]: parameters to be passed for transformation.
            - yaw (torch.Tensor): element-wise rotation yaws with a shape of (B,).
            - pitch (torch.Tensor): element-wise rotation pitches with a shape of (B,).
            - roll (torch.Tensor): element-wise rotation rolls with a shape of (B,).

    Raises:
        AssertionError: if ``degrees`` is not shaped (3, 2).
    """
    if degrees.shape != torch.Size([3, 2]):
        raise AssertionError(f"'degrees' must be the shape of (3, 2). Got {degrees.shape}.")
    # Sample on the requested device/dtype, then move results back to the input's device/dtype.
    _device, _dtype = _extract_device_dtype([degrees])
    degrees = degrees.to(device=device, dtype=dtype)
    yaw = _adapted_uniform((batch_size,), degrees[0][0], degrees[0][1], same_on_batch)
    pitch = _adapted_uniform((batch_size,), degrees[1][0], degrees[1][1], same_on_batch)
    roll = _adapted_uniform((batch_size,), degrees[2][0], degrees[2][1], same_on_batch)

    return dict(
        yaw=yaw.to(device=_device, dtype=_dtype),
        pitch=pitch.to(device=_device, dtype=_dtype),
        roll=roll.to(device=_device, dtype=_dtype),
    )


def random_affine_generator3d(
    batch_size: int,
    depth: int,
    height: int,
    width: int,
    degrees: torch.Tensor,
    translate: Optional[torch.Tensor] = None,
    scale: Optional[torch.Tensor] = None,
    shears: Optional[torch.Tensor] = None,
    same_on_batch: bool = False,
    device: torch.device = torch.device('cpu'),
    dtype: torch.dtype = torch.float32,
) -> Dict[str, torch.Tensor]:
    r"""Get parameters for ```3d affine``` transformation random affine transform.

    Args:
        batch_size (int): the tensor batch size.
        depth (int) : depth of the image.
        height (int) : height of the image.
        width (int): width of the image.
        degrees (torch.Tensor): Ranges of degrees with shape (3, 2) for yaw, pitch and roll.
        translate (torch.Tensor, optional): maximum absolute fraction with shape (3,) for horizontal, vertical
            and depthical translations (dx,dy,dz). Will not translate by default.
        scale (torch.Tensor, optional): scaling factor intervals with shape (3, 2), one row (a, b) per axis;
            each scale is randomly sampled from the range a <= scale <= b. Will keep original scale by default.
        shears (torch.Tensor, optional): Range of degrees to select from. Shaped as (6, 2) for 6 facet
            (xy, xz, yx, yz, zx, zy). The shear to the i-th facet is sampled from the range
            (shears[i, 0], shears[i, 1]). Will not shear by default.
        same_on_batch (bool): apply the same transformation across the batch. Default: False
        device (torch.device): the device on which the random numbers will be generated. Default: cpu.
        dtype (torch.dtype): the data type of the generated random numbers. Default: float32.

    Returns:
        params Dict[str, torch.Tensor]: parameters to be passed for transformation.
            - translations (torch.Tensor): element-wise translations with a shape of (B, 3).
            - center (torch.Tensor): element-wise center with a shape of (B, 3).
            - scale (torch.Tensor): element-wise scales with a shape of (B, 3).
            - angles (torch.Tensor): element-wise yaw/pitch/roll rotation angles with a shape of (B, 3).
            - sxy (torch.Tensor): element-wise x-y-facet shears with a shape of (B,).
            - sxz (torch.Tensor): element-wise x-z-facet shears with a shape of (B,).
            - syx (torch.Tensor): element-wise y-x-facet shears with a shape of (B,).
            - syz (torch.Tensor): element-wise y-z-facet shears with a shape of (B,).
            - szx (torch.Tensor): element-wise z-x-facet shears with a shape of (B,).
            - szy (torch.Tensor): element-wise z-y-facet shears with a shape of (B,).

    Raises:
        AssertionError: if the image sizes are not positive ints, or a range tensor has the wrong shape.

    Note:
        The generated random numbers are not reproducible across different devices and dtypes.
    """
    # ``type(...) is int`` (rather than isinstance) deliberately rejects bools.
    if not (
        type(depth) is int and depth > 0 and type(height) is int and height > 0 and type(width) is int and width > 0
    ):
        raise AssertionError(f"'depth', 'height' and 'width' must be integers. Got {depth}, {height}, {width}.")
    _device, _dtype = _extract_device_dtype([degrees, translate, scale, shears])

    if degrees.shape != torch.Size([3, 2]):
        raise AssertionError(f"'degrees' must be the shape of (3, 2). Got {degrees.shape}.")
    degrees = degrees.to(device=device, dtype=dtype)
    yaw = _adapted_uniform((batch_size,), degrees[0][0], degrees[0][1], same_on_batch)
    pitch = _adapted_uniform((batch_size,), degrees[1][0], degrees[1][1], same_on_batch)
    roll = _adapted_uniform((batch_size,), degrees[2][0], degrees[2][1], same_on_batch)
    angles = torch.stack([yaw, pitch, roll], dim=1)

    # compute tensor ranges
    if scale is not None:
        if scale.shape != torch.Size([3, 2]):
            raise AssertionError(f"'scale' must be the shape of (3, 2). Got {scale.shape}.")
        scale = scale.to(device=device, dtype=dtype)
        scale = torch.stack(
            [
                _adapted_uniform((batch_size,), scale[0, 0], scale[0, 1], same_on_batch),
                _adapted_uniform((batch_size,), scale[1, 0], scale[1, 1], same_on_batch),
                _adapted_uniform((batch_size,), scale[2, 0], scale[2, 1], same_on_batch),
            ],
            dim=1,
        )
    else:
        scale = torch.ones(batch_size, device=device, dtype=dtype).reshape(batch_size, 1).repeat(1, 3)

    if translate is not None:
        if translate.shape != torch.Size([3]):
            raise AssertionError(f"'translate' must be the shape of (3). Got {translate.shape}.")
        translate = translate.to(device=device, dtype=dtype)
        # Fractions are relative to the image extent along each axis.
        max_dx: torch.Tensor = translate[0] * width
        max_dy: torch.Tensor = translate[1] * height
        max_dz: torch.Tensor = translate[2] * depth
        # translations should be in x,y,z
        translations = torch.stack(
            [
                _adapted_uniform((batch_size,), -max_dx, max_dx, same_on_batch),
                _adapted_uniform((batch_size,), -max_dy, max_dy, same_on_batch),
                _adapted_uniform((batch_size,), -max_dz, max_dz, same_on_batch),
            ],
            dim=1,
        )
    else:
        translations = torch.zeros((batch_size, 3), device=device, dtype=dtype)

    # center should be in x,y,z
    center: torch.Tensor = torch.tensor([width, height, depth], device=device, dtype=dtype).view(1, 3) / 2.0 - 0.5
    center = center.expand(batch_size, -1)

    if shears is not None:
        if shears.shape != torch.Size([6, 2]):
            raise AssertionError(f"'shears' must be the shape of (6, 2). Got {shears.shape}.")
        shears = shears.to(device=device, dtype=dtype)
        sxy = _adapted_uniform((batch_size,), shears[0, 0], shears[0, 1], same_on_batch)
        sxz = _adapted_uniform((batch_size,), shears[1, 0], shears[1, 1], same_on_batch)
        syx = _adapted_uniform((batch_size,), shears[2, 0], shears[2, 1], same_on_batch)
        syz = _adapted_uniform((batch_size,), shears[3, 0], shears[3, 1], same_on_batch)
        szx = _adapted_uniform((batch_size,), shears[4, 0], shears[4, 1], same_on_batch)
        szy = _adapted_uniform((batch_size,), shears[5, 0], shears[5, 1], same_on_batch)
    else:
        # One independent zero tensor per facet so the returned entries never alias each other.
        sxy, sxz, syx, syz, szx, szy = (torch.zeros(batch_size, device=device, dtype=dtype) for _ in range(6))

    return dict(
        translations=translations.to(device=_device, dtype=_dtype),
        center=center.to(device=_device, dtype=_dtype),
        scale=scale.to(device=_device, dtype=_dtype),
        angles=angles.to(device=_device, dtype=_dtype),
        sxy=sxy.to(device=_device, dtype=_dtype),
        sxz=sxz.to(device=_device, dtype=_dtype),
        syx=syx.to(device=_device, dtype=_dtype),
        syz=syz.to(device=_device, dtype=_dtype),
        szx=szx.to(device=_device, dtype=_dtype),
        szy=szy.to(device=_device, dtype=_dtype),
    )


def random_motion_blur_generator3d(
    batch_size: int,
    kernel_size: Union[int, Tuple[int, int]],
    angle: torch.Tensor,
    direction: torch.Tensor,
    same_on_batch: bool = False,
    device: torch.device = torch.device('cpu'),
    dtype: torch.dtype = torch.float32,
) -> Dict[str, torch.Tensor]:
    r"""Get parameters for motion blur.

    Args:
        batch_size (int): the tensor batch size.
        kernel_size (int or (int, int)): motion kernel size (odd and positive) or range.
        angle (torch.Tensor): yaw, pitch and roll range of the motion blur in degrees :math:`(3, 2)`.
        direction (torch.Tensor): forward/backward direction of the motion blur.
            Lower values towards -1.0 will point the motion blur towards the back (with
            angle provided via angle), while higher values towards 1.0 will point the motion
            blur forward. A value of 0.0 leads to a uniformly (but still angled) motion blur.
        same_on_batch (bool): apply the same transformation across the batch. Default: False.
        device (torch.device): the device on which the random numbers will be generated. Default: cpu.
        dtype (torch.dtype): the data type of the generated random numbers. Default: float32.

    Returns:
        params Dict[str, torch.Tensor]: parameters to be passed for transformation.
            - ksize_factor (torch.Tensor): element-wise integer kernel sizes with a shape of (B,).
              When ``kernel_size`` is a range, one size is shared across the whole batch.
            - angle_factor (torch.Tensor): element-wise yaw/pitch/roll angles with a shape of (B, 3).
            - direction_factor (torch.Tensor): element-wise motion directions with a shape of (B,).

    Raises:
        AssertionError: if ``kernel_size`` or ``angle`` is invalid.
        TypeError: if ``kernel_size`` is neither an int nor a tuple.

    Note:
        The generated random numbers are not reproducible across different devices and dtypes.
    """
    _device, _dtype = _extract_device_dtype([angle, direction])
    _joint_range_check(direction, 'direction', (-1, 1))
    if isinstance(kernel_size, int):
        if not (kernel_size >= 3 and kernel_size % 2 == 1):
            raise AssertionError(f"`kernel_size` must be odd and greater than 3. Got {kernel_size}.")
        ksize_factor = torch.tensor([kernel_size] * batch_size, device=device, dtype=dtype).int()
    elif isinstance(kernel_size, tuple):
        if not (len(kernel_size) == 2 and kernel_size[0] >= 3 and kernel_size[0] <= kernel_size[1]):
            raise AssertionError(f"`kernel_size` must be greater than 3. Got range {kernel_size}.")
        # kernel_size is fixed across the batch; sample a half-size and map it to an odd size.
        ksize_factor = (
            _adapted_uniform((batch_size,), kernel_size[0] // 2, kernel_size[1] // 2, same_on_batch=True).int() * 2
            + 1
        )
    else:
        raise TypeError(f"Unsupported type: {type(kernel_size)}")

    if angle.shape != torch.Size([3, 2]):
        raise AssertionError(f"'angle' must be the shape of (3, 2). Got {angle.shape}.")
    angle = angle.to(device=device, dtype=dtype)
    yaw = _adapted_uniform((batch_size,), angle[0][0], angle[0][1], same_on_batch)
    pitch = _adapted_uniform((batch_size,), angle[1][0], angle[1][1], same_on_batch)
    roll = _adapted_uniform((batch_size,), angle[2][0], angle[2][1], same_on_batch)
    angle_factor = torch.stack([yaw, pitch, roll], dim=1)

    direction = direction.to(device=device, dtype=dtype)
    direction_factor = _adapted_uniform((batch_size,), direction[0], direction[1], same_on_batch)

    return dict(
        ksize_factor=ksize_factor.to(device=_device),
        angle_factor=angle_factor.to(device=_device, dtype=_dtype),
        direction_factor=direction_factor.to(device=_device, dtype=_dtype),
    )


def center_crop_generator3d(
    batch_size: int,
    depth: int,
    height: int,
    width: int,
    size: Tuple[int, int, int],
    device: torch.device = torch.device('cpu'),
) -> Dict[str, torch.Tensor]:
    r"""Get parameters for ```center_crop3d``` transformation for center crop transform.

    Args:
        batch_size (int): the tensor batch size.
        depth (int) : depth of the image.
        height (int) : height of the image.
        width (int): width of the image.
        size (tuple): Desired output size of the crop, like (d, h, w).
        device (torch.device): the device on which the returned tensors are created. Default: cpu.

    Returns:
        params Dict[str, torch.Tensor]: parameters to be passed for transformation.
            - src (torch.Tensor): cropping bounding boxes with a shape of (B, 8, 3), dtype long.
            - dst (torch.Tensor): output bounding boxes with a shape (B, 8, 3), dtype long.

    Raises:
        ValueError: if ``size`` is not a tuple/list of length 3.
        AssertionError: if the image sizes are not positive ints or the crop exceeds the image.

    Note:
        No random number will be generated.
    """
    if not (isinstance(size, (tuple, list)) and len(size) == 3):
        raise ValueError(f"Input size must be a tuple/list of length 3. Got {size}")
    if not (
        type(depth) is int and depth > 0 and type(height) is int and height > 0 and type(width) is int and width > 0
    ):
        raise AssertionError(f"'depth', 'height' and 'width' must be integers. Got {depth}, {height}, {width}.")
    if not (depth >= size[0] and height >= size[1] and width >= size[2]):
        raise AssertionError(f"Crop size must be smaller than input size. Got ({depth}, {height}, {width}) and {size}.")

    if batch_size == 0:
        # Same device and integer dtype as the non-empty result below.
        return dict(
            src=torch.zeros([0, 8, 3], device=device, dtype=torch.long),
            dst=torch.zeros([0, 8, 3], device=device, dtype=torch.long),
        )

    # unpack input sizes
    dst_d, dst_h, dst_w = size
    src_d, src_h, src_w = (depth, height, width)

    # compute start/end offsets; these may be fractional and are converted by the
    # dtype=torch.long tensor construction below
    dst_d_half = dst_d / 2
    dst_h_half = dst_h / 2
    dst_w_half = dst_w / 2
    src_d_half = src_d / 2
    src_h_half = src_h / 2
    src_w_half = src_w / 2

    start_x = src_w_half - dst_w_half
    start_y = src_h_half - dst_h_half
    start_z = src_d_half - dst_d_half

    end_x = start_x + dst_w - 1
    end_y = start_y + dst_h - 1
    end_z = start_z + dst_d - 1

    # [x, y, z] origin
    # top-left-front, top-right-front, bottom-right-front, bottom-left-front
    # top-left-back, top-right-back, bottom-right-back, bottom-left-back
    points_src: torch.Tensor = torch.tensor(
        [
            [
                [start_x, start_y, start_z],
                [end_x, start_y, start_z],
                [end_x, end_y, start_z],
                [start_x, end_y, start_z],
                [start_x, start_y, end_z],
                [end_x, start_y, end_z],
                [end_x, end_y, end_z],
                [start_x, end_y, end_z],
            ]
        ],
        device=device,
        dtype=torch.long,
    ).expand(batch_size, -1, -1)

    # [x, y, z] destination
    # top-left-front, top-right-front, bottom-right-front, bottom-left-front
    # top-left-back, top-right-back, bottom-right-back, bottom-left-back
    points_dst: torch.Tensor = torch.tensor(
        [
            [
                [0, 0, 0],
                [dst_w - 1, 0, 0],
                [dst_w - 1, dst_h - 1, 0],
                [0, dst_h - 1, 0],
                [0, 0, dst_d - 1],
                [dst_w - 1, 0, dst_d - 1],
                [dst_w - 1, dst_h - 1, dst_d - 1],
                [0, dst_h - 1, dst_d - 1],
            ]
        ],
        device=device,
        dtype=torch.long,
    ).expand(batch_size, -1, -1)
    return dict(src=points_src, dst=points_dst)


def random_crop_generator3d(
    batch_size: int,
    input_size: Tuple[int, int, int],
    size: Union[Tuple[int, int, int], torch.Tensor],
    resize_to: Optional[Tuple[int, int, int]] = None,
    same_on_batch: bool = False,
    device: torch.device = torch.device('cpu'),
    dtype: torch.dtype = torch.float32,
) -> Dict[str, torch.Tensor]:
    r"""Get parameters for ```crop``` transformation for crop transform.

    Args:
        batch_size (int): the tensor batch size.
        input_size (tuple): Input image shape, like (d, h, w).
        size (tuple or torch.Tensor): Desired size of the crop operation, like (d, h, w).
            If tensor, it must be (B, 3).
        resize_to (tuple): Desired output size of the crop, like (d, h, w). If None, no resize will be performed.
        same_on_batch (bool): apply the same transformation across the batch. Default: False.
        device (torch.device): the device on which the random numbers will be generated. Default: cpu.
        dtype (torch.dtype): the data type of the generated random numbers. Default: float32.

    Returns:
        params Dict[str, torch.Tensor]: parameters to be passed for transformation.
            - src (torch.Tensor): cropping bounding boxes with a shape of (B, 8, 3).
            - dst (torch.Tensor): output bounding boxes with a shape (B, 8, 3).

    Raises:
        AssertionError: if ``size``, ``input_size`` or ``resize_to`` is malformed.
        ValueError: if the crop is larger than the input along any axis.

    Note:
        The generated random numbers are not reproducible across different devices and dtypes.
    """
    _device, _dtype = _extract_device_dtype([size if isinstance(size, torch.Tensor) else None])
    if not isinstance(size, torch.Tensor):
        size = torch.tensor(size, device=device, dtype=dtype).repeat(batch_size, 1)
    else:
        size = size.to(device=device, dtype=dtype)
    if size.shape != torch.Size([batch_size, 3]):
        raise AssertionError(
            "If `size` is a tensor, it must be shaped as (B, 3). "
            f"Got {size.shape} while expecting {torch.Size([batch_size, 3])}."
        )
    if not (
        len(input_size) == 3
        and isinstance(input_size[0], (int,))
        and isinstance(input_size[1], (int,))
        and isinstance(input_size[2], (int,))
        and input_size[0] > 0
        and input_size[1] > 0
        and input_size[2] > 0
    ):
        raise AssertionError(f"`input_size` must be a tuple of 3 positive integers. Got {input_size}.")

    # Number of valid start positions along each axis (size is (d, h, w) per row).
    x_diff = input_size[2] - size[:, 2] + 1
    y_diff = input_size[1] - size[:, 1] + 1
    z_diff = input_size[0] - size[:, 0] + 1

    # A diff below 1 means the crop is larger than the input along that axis.
    if (x_diff < 1).any() or (y_diff < 1).any() or (z_diff < 1).any():
        raise ValueError(f"input_size {str(input_size)} cannot be smaller than crop size {str(size)} in any dimension.")

    if batch_size == 0:
        return dict(
            src=torch.zeros([0, 8, 3], device=_device, dtype=_dtype),
            dst=torch.zeros([0, 8, 3], device=_device, dtype=_dtype),
        )

    if same_on_batch:
        # If same_on_batch, select the first then repeat.
        x_start = _adapted_uniform((batch_size,), 0, x_diff[0], same_on_batch).floor()
        y_start = _adapted_uniform((batch_size,), 0, y_diff[0], same_on_batch).floor()
        z_start = _adapted_uniform((batch_size,), 0, z_diff[0], same_on_batch).floor()
    else:
        # One sample drawn against the per-row upper bounds; flattened by .view(-1) below.
        x_start = _adapted_uniform((1,), 0, x_diff, same_on_batch).floor()
        y_start = _adapted_uniform((1,), 0, y_diff, same_on_batch).floor()
        z_start = _adapted_uniform((1,), 0, z_diff, same_on_batch).floor()

    crop_src = bbox_generator3d(
        x_start.to(device=_device, dtype=_dtype).view(-1),
        y_start.to(device=_device, dtype=_dtype).view(-1),
        z_start.to(device=_device, dtype=_dtype).view(-1),
        size[:, 2].to(device=_device, dtype=_dtype) - 1,
        size[:, 1].to(device=_device, dtype=_dtype) - 1,
        size[:, 0].to(device=_device, dtype=_dtype) - 1,
    )

    if resize_to is None:
        crop_dst = bbox_generator3d(
            torch.tensor([0] * batch_size, device=_device, dtype=_dtype),
            torch.tensor([0] * batch_size, device=_device, dtype=_dtype),
            torch.tensor([0] * batch_size, device=_device, dtype=_dtype),
            size[:, 2].to(device=_device, dtype=_dtype) - 1,
            size[:, 1].to(device=_device, dtype=_dtype) - 1,
            size[:, 0].to(device=_device, dtype=_dtype) - 1,
        )
    else:
        if not (
            len(resize_to) == 3
            and isinstance(resize_to[0], (int,))
            and isinstance(resize_to[1], (int,))
            and isinstance(resize_to[2], (int,))
            and resize_to[0] > 0
            and resize_to[1] > 0
            and resize_to[2] > 0
        ):
            raise AssertionError(f"`resize_to` must be a tuple of 3 positive integers. Got {resize_to}.")
        crop_dst = torch.tensor(
            [
                [
                    [0, 0, 0],
                    [resize_to[-1] - 1, 0, 0],
                    [resize_to[-1] - 1, resize_to[-2] - 1, 0],
                    [0, resize_to[-2] - 1, 0],
                    [0, 0, resize_to[-3] - 1],
                    [resize_to[-1] - 1, 0, resize_to[-3] - 1],
                    [resize_to[-1] - 1, resize_to[-2] - 1, resize_to[-3] - 1],
                    [0, resize_to[-2] - 1, resize_to[-3] - 1],
                ]
            ],
            device=_device,
            dtype=_dtype,
        ).repeat(batch_size, 1, 1)

    return dict(src=crop_src.to(device=_device), dst=crop_dst.to(device=_device))


def random_perspective_generator3d(
    batch_size: int,
    depth: int,
    height: int,
    width: int,
    distortion_scale: torch.Tensor,
    same_on_batch: bool = False,
    device: torch.device = torch.device('cpu'),
    dtype: torch.dtype = torch.float32,
) -> Dict[str, torch.Tensor]:
    r"""Get parameters for ``perspective`` for a random perspective transform.

    Args:
        batch_size (int): the tensor batch size.
        depth (int) : depth of the image.
        height (int) : height of the image.
        width (int): width of the image.
        distortion_scale (torch.Tensor): it controls the degree of distortion and ranges from 0 to 1.
        same_on_batch (bool): apply the same transformation across the batch. Default: False.
        device (torch.device): the device on which the random numbers will be generated. Default: cpu.
        dtype (torch.dtype): the data type of the generated random numbers. Default: float32.

    Returns:
        params Dict[str, torch.Tensor]: parameters to be passed for transformation.
            - start_points (torch.Tensor): perspective source bounding boxes with a shape of (B, 8, 3).
            - end_points (torch.Tensor): perspective target bounding boxes with a shape (B, 8, 3).

    Raises:
        AssertionError: if ``distortion_scale`` is not a scalar in [0, 1].

    Note:
        The generated random numbers are not reproducible across different devices and dtypes.
    """
    if not (distortion_scale.dim() == 0 and 0 <= distortion_scale <= 1):
        raise AssertionError(f"'distortion_scale' must be a scalar within [0, 1]. Got {distortion_scale}")
    _device, _dtype = _extract_device_dtype([distortion_scale])
    distortion_scale = distortion_scale.to(device=device, dtype=dtype)

    start_points: torch.Tensor = torch.tensor(
        [
            [
                [0.0, 0, 0],
                [width - 1, 0, 0],
                [width - 1, height - 1, 0],
                [0, height - 1, 0],
                [0.0, 0, depth - 1],
                [width - 1, 0, depth - 1],
                [width - 1, height - 1, depth - 1],
                [0, height - 1, depth - 1],
            ]
        ],
        device=device,
        dtype=dtype,
    ).expand(batch_size, -1, -1)

    # generate random offset not larger than half of the image
    fx = distortion_scale * width / 2
    fy = distortion_scale * height / 2
    fz = distortion_scale * depth / 2

    factor = torch.stack([fx, fy, fz], dim=0).view(-1, 1, 3)

    rand_val: torch.Tensor = _adapted_uniform(
        start_points.shape,
        torch.tensor(0, device=device, dtype=dtype),
        torch.tensor(1, device=device, dtype=dtype),
        same_on_batch,
    )

    # Sign pattern pushes each corner toward the interior of the volume.
    pts_norm = torch.tensor(
        [[[1, 1, 1], [-1, 1, 1], [-1, -1, 1], [1, -1, 1], [1, 1, -1], [-1, 1, -1], [-1, -1, -1], [1, -1, -1]]],
        device=device,
        dtype=dtype,
    )
    end_points = start_points + factor * rand_val * pts_norm

    return dict(
        start_points=start_points.to(device=_device, dtype=_dtype),
        end_points=end_points.to(device=_device, dtype=_dtype),
    )