| | import torch |
| | from src.data import NAG |
| | from src.transforms import Transform |
| | from src.utils.geometry import rodrigues_rotation_matrix |
| |
|
| |
|
# Public augmentation transforms exported by this module
__all__ = [
    'CenterPosition',
    'RandomTiltAndRotate',
    'RandomAnisotropicScale',
    'RandomAxisFlip',
]
| |
|
| |
|
class CenterPosition(Transform):
    """Translate all levels of a NAG so that the level-0 points are
    centered on the origin.

    A single offset, the mean of the level-0 positions, is subtracted
    from the positions of every level, so the relative placement of the
    levels with respect to each other is unchanged.
    """

    _IN_TYPE = NAG
    _OUT_TYPE = NAG

    def _process(self, nag):
        # Centroid of the finest level, shared by all levels
        centroid = nag[0].pos.mean(dim=0)
        num_levels = nag.num_levels
        for level in range(num_levels):
            nag[level].pos -= centroid
        return nag
| |
|
| |
|
class RandomTiltAndRotate(Transform):
    """Rotate the NAG around a random axis, with a random angle. The
    axis is picked following a gaussian jitter around the z axis. The
    angle is picked following a uniform distribution within a specified
    range. The same rotation is applied to all levels of the NAG.

    If the nodes have a `normal` or `mean_normal` attribute, we also
    rotate those accordingly. Rotated normals whose z component is
    negative are negated, so that all normals point upwards.

    If the nodes have an `edge_attr` attribute, it must hold exactly 7
    features (as generated by `_minimalistic_horizontal_edge_features`),
    of which the first 3 are rotated as a 3D vector. The other edge
    features are left untouched.

    Warning: any other absolute orientation-related attributes beside
    `pos`, `normal`, `mean_normal` and the first 3 `edge_attr` features
    may be broken by this transform.

    :param phi: float (degrees)
        The random axis will have random angle wrt the z axis. This
        random angle corresponds to adding some random xy offset to z.
        This offset is sampled from a 2D gaussian distribution of
        standard deviation `sigma` computed so that a `3 * sigma` xy
        offset corresponds to a `phi` angle (small-angle approximation:
        `sigma = phi_in_radians / 3`). If `phi=0`, the rotation axis is
        exactly the z axis, but the random rotation is still applied
    :param theta: float (degrees)
        The random rotation angle will be uniformly picked within
        [-abs(theta), abs(theta)]
    """

    _IN_TYPE = NAG
    _OUT_TYPE = NAG

    def __init__(self, phi=5, theta=180):
        assert isinstance(phi, (int, float))
        assert isinstance(theta, (int, float))
        self.phi = float(abs(phi))
        self.theta = float(abs(theta))

    def _process(self, nag):
        device = nag.device

        # Random unit rotation axis, jittered around +z
        axis = self._sample_axis(device)

        # Random angle, uniform in [-theta, theta). NOTE(review): passed
        # as-is (degrees, per the class docstring) to
        # `rodrigues_rotation_matrix`; confirm that helper expects degrees
        theta = torch.rand(1, device=device) * 2 * self.theta - self.theta

        R = rodrigues_rotation_matrix(axis, theta)

        # Apply the same rotation to every level. This must also happen
        # when `phi == 0`: the axis is then exactly z, but the random
        # `theta` rotation around it is still wanted
        for i_level in range(nag.num_levels):
            nag[i_level].pos = nag[i_level].pos @ R.T

            # Rotate normal-like attributes, if any
            for k in ['normal', 'mean_normal']:
                if getattr(nag[i_level], k, None) is not None:
                    nag[i_level][k] = self._rotate_normal(nag[i_level][k], R)

            # Only the first 3 edge features are treated as a 3D vector
            # to be rotated; the others are left untouched
            if getattr(nag[i_level], 'edge_attr', None) is not None:
                edge_attr = nag[i_level].edge_attr
                assert edge_attr.shape[1] == 7, \
                    "Expected exactly 7 features in `edge_attr`, generated " \
                    "with `_minimalistic_horizontal_edge_features`"
                dtype = edge_attr.dtype
                edge_attr[:, :3] = (edge_attr[:, :3].float() @ R.T).to(dtype)
                nag[i_level].edge_attr = edge_attr

        return nag

    def _sample_axis(self, device):
        """Sample a random unit rotation axis jittered around +z.

        The xy components are drawn from N(0, sigma^2 I) with z fixed to
        1, and the resulting vector is normalized. When `phi` is 0, the
        z axis itself is returned.
        """
        sigma = self.phi / 180. * torch.pi / 3
        if sigma <= 0:
            axis = torch.zeros(3, device=device, dtype=torch.float)
            axis[2] = 1
            return axis

        # `sigma` is a standard deviation, so it scales a standard normal
        # sample. Using `sigma * I` as a covariance matrix would instead
        # give a standard deviation of `sqrt(sigma)`
        axis_xy = torch.randn(2, device=device) * sigma
        axis_z = torch.ones(1, device=device)
        axis = torch.cat((axis_xy, axis_z))
        return axis / axis.norm()

    @staticmethod
    def _rotate_normal(normal, R):
        """Rotate `normal` by `R` (computed in float32, original dtype
        restored), then negate rows with a negative z component."""
        dtype = normal.dtype
        normal = (normal.float() @ R.T).to(dtype)
        normal[normal[:, 2] < 0] *= -1
        return normal
| |
|
| |
|
class RandomAnisotropicScale(Transform):
    r"""Scales node positions by a randomly sampled factor ``s1, s2, s3``
    within a given interval, *e.g.*, resulting in the following
    transformation matrix

    .. math::
        \left[
        \begin{array}{ccc}
            s1 & 0 & 0 \\
            0 & s2 & 0 \\
            0 & 0 & s3 \\
        \end{array}
        \right]

    for three-dimensional positions. Each factor ``si`` is sampled
    independently and uniformly in ``[1 - delta_i, 1 + delta_i)``. The
    same factors are applied to all levels of the NAG.

    If the nodes have a `normal` or `mean_normal` attribute, we also
    reorient those accordingly (normals transform with the inverse of
    the diagonal scaling matrix), while preserving their unit-norm.

    If the nodes have an `edge_attr` attribute, it must hold exactly 7
    features: the first 3 are scaled per-axis like positions, and the
    remaining 4 are scaled by the root-mean-square of the 3 factors.

    Warning: any other absolute orientation-related attributes beside
    `pos`, `normal`, `mean_normal` and `edge_attr` may be broken by this
    transform. A `delta` of 1 or more may produce null or negative
    factors, which breaks the normal reorientation.

    Credit: https://github.com/torch-points3d/torch-points3d

    :param delta: float or List(float)
        Scaling factors will be uniformly sampled in
        [1 - delta, 1 + delta). A 3-element list may be passed to scale
        X, Y and Z differently.
    """

    _IN_TYPE = NAG
    _OUT_TYPE = NAG

    def __init__(self, delta=0.2):
        assert isinstance(delta, (float, int)) or isinstance(delta, (tuple, list))
        if isinstance(delta, (float, int)):
            delta = [float(delta)] * 3
        assert len(delta) == 3
        # Force a float dtype so integer deltas (e.g. `[0, 0, 1]`) still
        # produce float scaling factors. Shape (1, 3)
        self.delta = torch.tensor(delta, dtype=torch.float).abs().view(1, -1)

    def _process(self, nag):
        # One independent random factor per axis: `scale` has shape
        # (1, 3), entry i lying in [1 - delta_i, 1 + delta_i)
        noise = torch.rand(self.delta.shape) * 2 * self.delta - self.delta
        scale = (1 + noise).to(nag.device)

        # Scalar summary of the scaling, for the non-vector edge
        # features. The RMS of the factors is 1 for the identity and `s`
        # for an isotropic scaling by `s` (the plain L2 norm would give
        # `sqrt(3) * s`, distorting features even when delta == 0)
        iso_scale = scale.norm() / 3 ** 0.5

        for i_level in range(nag.num_levels):
            nag[i_level].pos = nag[i_level].pos * scale

            # Reorient normal-like attributes, if any
            for k in ['normal', 'mean_normal']:
                if getattr(nag[i_level], k, None) is not None:
                    nag[i_level][k] = self._scale_normal(nag[i_level][k], scale)

            # The first 3 edge features are scaled per-axis like
            # positions. NOTE(review): the 4 remaining ones are assumed
            # to be distance-like magnitudes; confirm against
            # `_minimalistic_horizontal_edge_features`
            if getattr(nag[i_level], 'edge_attr', None) is not None:
                edge_attr = nag[i_level].edge_attr
                assert edge_attr.shape[1] == 7, \
                    "Expected exactly 7 features in `edge_attr`, generated " \
                    "with `_minimalistic_horizontal_edge_features`"
                edge_attr[:, :3] *= scale
                edge_attr[:, 3:] *= iso_scale
                nag[i_level].edge_attr = edge_attr

        return nag

    @staticmethod
    def _scale_normal(normal, scale):
        """Reorient unit normals after positions were scaled by `scale`,
        and return unit-norm vectors in the input dtype."""
        # Normals transform with the inverse transpose of the linear map
        # applied to positions. For a diagonal scaling, that means
        # dividing by the per-axis factors rather than multiplying
        return torch.nn.functional.normalize(normal / scale, dim=1).to(normal.dtype)
| |
|
| |
|
class RandomAxisFlip(Transform):
    """Flip the node positions wrt one of the XYZ axes, with a specified
    probability. This transform is not very modular because it is
    intended to be composed with `RandomTiltAndRotate` for richer
    geometric augmentations. The same flip is applied to all levels.

    If the nodes have a `normal` or `mean_normal` attribute, we also
    flip those accordingly. Flipped normals whose z component is
    negative are then negated, so that all normals point upwards.

    If the nodes have an `edge_attr` attribute, it must hold exactly 7
    features, of which the first 3 are flipped as a 3D vector.

    Warning: any other absolute orientation-related attributes beside
    `pos`, `normal`, `mean_normal` and the first 3 `edge_attr` features
    may be broken by this transform.

    :param axis: int
        Index of the coordinate to negate (0 for X, 1 for Y, 2 for Z)
    :param p: float
        Probability of flip
    """

    _IN_TYPE = NAG
    _OUT_TYPE = NAG

    def __init__(self, axis=0, p=0.5):
        assert isinstance(axis, int)
        # Integer probabilities such as `p=0` or `p=1` are valid too
        assert isinstance(p, (int, float))
        self.axis = axis
        self.p = float(p)

    def _process(self, nag):
        # Flip with probability `p`: skip when the draw exceeds `p`
        if torch.rand(1, device=nag.device) > self.p:
            return nag

        axis = self.axis
        for i_level in range(nag.num_levels):
            nag[i_level].pos[:, axis] *= -1

            # Flip normal-like attributes, if any
            for k in ['normal', 'mean_normal']:
                if getattr(nag[i_level], k, None) is not None:
                    nag[i_level][k] = self._flip_normal(nag[i_level][k], axis)

            # Only the first 3 edge features are treated as a 3D vector
            # to be flipped; the others are left untouched
            if getattr(nag[i_level], 'edge_attr', None) is not None:
                edge_attr = nag[i_level].edge_attr
                assert edge_attr.shape[1] == 7, \
                    "Expected exactly 7 features in `edge_attr`, generated " \
                    "with `_minimalistic_horizontal_edge_features`"
                edge_attr[:, :3][:, axis] *= -1
                nag[i_level].edge_attr = edge_attr

        return nag

    @staticmethod
    def _flip_normal(normal, axis):
        """Negate `normal[:, axis]` in place, then negate rows whose z
        component is negative so that all normals point upwards."""
        normal[:, axis] *= -1
        normal[normal[:, 2] < 0] *= -1
        return normal
| |
|