| import jax.numpy as jnp |
| import jax.random as jrn |
| from jax import vmap |
|
|
| from neural_fdm.generators.generator import PointGenerator |
|
|
| |
| |
| |
|
|
class TubePointGenerator(PointGenerator):
    """
    Common parent of the generators that output points evaluated on a wiggled tube.

    The class adds no behavior of its own on top of `PointGenerator`.
    """
|
|
|
|
class EllipticalTubePointGenerator(TubePointGenerator):
    """
    A generator that outputs points evaluated on a wiggled elliptical tube.

    The tube is a stack of `num_levels` ellipses evenly spaced between the
    heights 0 and `height`. A subset of `num_rings` levels works as compression
    rings; only these levels receive random radii and angles when wiggling.

    Parameters
    ----------
    height: `float`
        The height of the tube.
    radius: `float`
        The reference radius of the tube.
    num_sides: `int`
        The number of sides per ellipse.
    num_levels: `int`
        The number of levels along the height of the tube.
    num_rings: `int`
        The number of levels that will work as compression rings.
        The first and last levels are fully supported. Must be at least 3.
    minval: `jax.Array`
        The minimum values of the space of random transformations.
        Shape `(3,)`: entries 0 and 1 bound the factors that scale `radius`
        along the two ellipse axes, entry 2 bounds the ring angle.
    maxval: `jax.Array`
        The maximum values of the space of random transformations.
        Same layout as `minval`.
    """
    def __init__(
            self,
            height,
            radius,
            num_sides,
            num_levels,
            num_rings,
            minval,
            maxval):

        # sanity checks
        assert num_rings >= 3, "Must include at least 1 ring in the middle!"
        self._check_array_shapes(num_rings, minval, maxval)

        # tube dimensions
        self.height = height
        self.radius = radius

        # tube discretization
        self.num_sides = num_sides
        self.num_levels = num_levels
        self.num_rings = num_rings

        # bounds of the random transformation space
        self.minval = minval
        self.maxval = maxval

        # compression rings (order matters: the ravel indices need the levels)
        self.levels_rings_comp = self._levels_rings_compression()
        self.indices_rings_comp_ravel = self._indices_rings_compression_ravel()
        self.indices_rings_comp_interior_ravel = self._indices_rings_compression_interior_ravel()

        # tension rings are all the levels that are not compression rings
        self.levels_rings_tension = self._levels_rings_tension()

        # array shapes
        self.shape_tube = (num_levels, num_sides, 3)
        self.shape_rings = (num_rings, num_sides, 3)

    def __call__(self, key, wiggle=True):
        """
        Generate points.

        Parameters
        ----------
        key: `jax.random.PRNGKey`
            The random key.
        wiggle: `bool`, optional
            Whether to wiggle the points at random.

        Returns
        -------
        points: `jax.Array`
            The points on the tube, flattened into a 1D array.
        """
        points = self.points_on_tube(key, wiggle)

        return jnp.ravel(points)

    def _levels_rings_tension(self):
        """
        Compute the integer indices of the levels that work as tension rings.

        Returns
        -------
        indices: `jax.Array`
            The indices.
        """
        # Membership is tested against a set of Python ints instead of the
        # array itself, which would compare element by element per level.
        levels_comp = set(self.levels_rings_comp.tolist())
        indices = [i for i in range(self.num_levels) if i not in levels_comp]
        indices = jnp.array(indices, dtype=jnp.int64)

        assert indices.size == self.num_levels - self.num_rings

        return indices

    def _levels_rings_compression(self):
        """
        Compute the integer indices of the levels that work as compression rings.

        Returns
        -------
        indices: `jax.Array`
            The indices. They always include the first and the last level.
        """
        step = int(self.num_levels / (self.num_rings - 1))

        # first level + interior levels every `step` + last level
        indices = [0] + list(range(step, self.num_levels - 1, step)) + [self.num_levels - 1]
        indices = jnp.array(indices, dtype=jnp.int64)

        # Not every (num_levels, num_rings) pair yields exactly num_rings levels.
        assert indices.size == self.num_rings

        return indices

    def _ravel_indices_of_levels(self, levels):
        """
        Compute the integer indices of the vertices on the given levels.

        Parameters
        ----------
        levels: `list` of `int`
            The indices of the levels.

        Returns
        -------
        indices: `jax.Array`
            The vertex indices, assuming `num_sides` consecutive vertices per level.
        """
        indices = []
        for level in levels:
            start = level * self.num_sides
            indices.extend(range(start, start + self.num_sides))

        return jnp.array(indices, dtype=jnp.int64)

    def _indices_rings_compression_ravel(self):
        """
        Compute the integer indices of the vertices in the compression rings.

        Returns
        -------
        indices: `jax.Array`
            The indices.
        """
        return self._ravel_indices_of_levels(self.levels_rings_comp.tolist())

    def _indices_rings_compression_interior_ravel(self):
        """
        Compute the integer indices of the vertices in the unsupported compression rings.

        Returns
        -------
        indices: `jax.Array`
            The indices. The first and the last compression rings are excluded.
        """
        return self._ravel_indices_of_levels(self.levels_rings_comp.tolist()[1:-1])

    def wiggle(self, key):
        """
        Sample random radii and angles from a uniform distribution.

        Parameters
        ----------
        key: `jax.random.PRNGKey`
            The random key.

        Returns
        -------
        transform: tuple of `jax.Array`
            The transformation factors for the radii and angles.

        Notes
        -----
        The key is split so that the radii and the angles are drawn from
        independent random streams. Feeding the same key to both samplers
        would make the two draws statistically dependent.
        """
        key_radii, key_angle = jrn.split(key)

        return self.wiggle_radii(key_radii), self.wiggle_angle(key_angle)

    def wiggle_radii(self, key):
        """
        Sample random radii from a uniform distribution.

        Parameters
        ----------
        key: `jax.random.PRNGKey`
            The random key.

        Returns
        -------
        radii: `jax.Array`
            The random radii factors, two per compression ring.
        """
        shape = (self.num_rings, 2)
        minval = self.minval[:2]
        maxval = self.maxval[:2]

        return jrn.uniform(key, shape=shape, minval=minval, maxval=maxval)

    def wiggle_angle(self, key):
        """
        Sample random angles from a uniform distribution.

        Parameters
        ----------
        key: `jax.random.PRNGKey`
            The random key.

        Returns
        -------
        angles: `jax.Array`
            The random angles, one per compression ring.
        """
        shape = (self.num_rings,)
        minval = self.minval[2]
        maxval = self.maxval[2]

        return jrn.uniform(key, shape=shape, minval=minval, maxval=maxval)

    def _elliptical_tube_points(self, transform=None):
        """
        Assemble the points of the tube, optionally transforming the compression rings.

        Parameters
        ----------
        transform: tuple of `jax.Array` or `None`, optional
            The radii factors and angles of the compression rings.
            If `None`, the tube is left unwiggled.

        Returns
        -------
        points: `jax.Array`
            The points on the tube, as returned by `points_on_ellipses`.
        """
        heights = jnp.linspace(0.0, self.height, self.num_levels)
        radii = jnp.ones(shape=(self.num_levels, 2)) * self.radius
        # NOTE(review): levels that are not wiggled keep an angle of 1.0
        # (degrees), not 0.0 - confirm this default is intended.
        angles = jnp.ones(shape=(self.num_levels,))

        if transform is not None:
            wiggle_radii, wiggle_angle = transform
            # the sampled radii are factors relative to the reference radius
            wiggle_radii = wiggle_radii * self.radius
            radii = radii.at[self.levels_rings_comp, :].set(wiggle_radii)
            angles = angles.at[self.levels_rings_comp].set(wiggle_angle)

        return points_on_ellipses(
            radii[:, 0],
            radii[:, 1],
            heights,
            self.num_sides,
            angles,
        )

    def evaluate_points(self, transform):
        """
        Generate wiggled points.

        Parameters
        ----------
        transform: tuple of `jax.Array`
            The random radii and angles.

        Returns
        -------
        points: `jax.Array`
            The points, flattened into a 1D array.
        """
        wiggle_radii, wiggle_angle = transform
        points = self._elliptical_tube_points((wiggle_radii, wiggle_angle))

        return jnp.ravel(points)

    def points_on_tube(self, key=None, wiggle=False):
        """
        Evaluate wiggled points on the tube.

        Parameters
        ----------
        key: `jax.random.PRNGKey`
            The random key. Only used if `wiggle` is true.
        wiggle: `bool`, optional
            Whether to wiggle the points at random.

        Returns
        -------
        points: `jax.Array`
            The points on the tube.
        """
        transform = self.wiggle(key) if wiggle else None

        return self._elliptical_tube_points(transform)

    def _check_array_shapes(self, num_rings, minval, maxval):
        """
        Verify that input shapes are consistent.

        Parameters
        ----------
        num_rings: `int`
            The number of rings. Currently not used by the check.
        minval: `jax.Array`
            The minimum values of the space of random transformations.
        maxval: `jax.Array`
            The maximum values of the space of random transformations.
        """
        shape = (3, )
        minval_shape = minval.shape
        maxval_shape = maxval.shape

        assert minval_shape == shape, f"{minval_shape} vs. {shape}"
        assert maxval_shape == shape, f"{maxval_shape} vs. {shape}"
|
|
|
|
class CircularTubePointGenerator(EllipticalTubePointGenerator):
    """
    A generator that outputs points evaluated on a wiggled circular tube.

    Each compression ring is described by a single radius, so the random radii
    are a 1D array with one factor per ring, bounded by the first entries of
    `minval` and `maxval`.
    """
    def wiggle_radii(self, key):
        """
        Sample random radii from a uniform distribution.

        Parameters
        ----------
        key: `jax.random.PRNGKey`
            The random key.

        Returns
        -------
        radii: `jax.Array`
            The random radii factors, one per compression ring.
        """
        shape = (self.num_rings,)
        minval = self.minval[0]
        maxval = self.maxval[0]

        return jrn.uniform(key, shape=shape, minval=minval, maxval=maxval)

    def _circular_tube_points(self, transform=None):
        """
        Assemble the points of the tube, optionally transforming the compression rings.

        Parameters
        ----------
        transform: tuple of `jax.Array` or `None`, optional
            The radii factors (one per ring) and angles of the compression rings.
            If `None`, the tube is left unwiggled.

        Returns
        -------
        points: `jax.Array`
            The points on the tube, as returned by `points_on_ellipses`.
        """
        heights = jnp.linspace(0.0, self.height, self.num_levels)
        radii = jnp.ones(shape=(self.num_levels,)) * self.radius
        angles = jnp.ones(shape=(self.num_levels,))

        if transform is not None:
            wiggle_radii, wiggle_angle = transform
            # the sampled radii are factors relative to the reference radius
            wiggle_radii = wiggle_radii * self.radius
            radii = radii.at[self.levels_rings_comp].set(wiggle_radii)
            angles = angles.at[self.levels_rings_comp].set(wiggle_angle)

        # a circle is an ellipse whose two radii are equal
        return points_on_ellipses(
            radii,
            radii,
            heights,
            self.num_sides,
            angles,
        )

    def evaluate_points(self, transform):
        """
        Generate wiggled points.

        Parameters
        ----------
        transform: tuple of `jax.Array`
            The random radii and angles, as returned by `wiggle`.

        Returns
        -------
        points: `jax.Array`
            The points, flattened into a 1D array.

        Notes
        -----
        The parent implementation stores two radii per level and cannot
        broadcast the one-radius-per-ring array sampled by this class.
        Radii of any other rank are forwarded to the parent implementation.
        """
        wiggle_radii, wiggle_angle = transform
        if jnp.ndim(wiggle_radii) != 1:
            return super().evaluate_points(transform)

        points = self._circular_tube_points((wiggle_radii, wiggle_angle))

        return jnp.ravel(points)

    def points_on_tube(self, key=None, wiggle=False):
        """
        Evaluate wiggled points on the tube.

        Parameters
        ----------
        key: `jax.random.PRNGKey`
            The random key. Only used if `wiggle` is true.
        wiggle: `bool`, optional
            Whether to wiggle the points at random.

        Returns
        -------
        points: `jax.Array`
            The points on the tube.
        """
        transform = self.wiggle(key) if wiggle else None

        return self._circular_tube_points(transform)
|
|
|
|
| |
| |
| |
|
|
def points_on_ellipse_xy(radius_1, radius_2, num_sides, angle=0.0):
    """
    Sample the vertices of a rotated ellipse lying on the XY plane.

    Parameters
    ----------
    radius_1: `float`
        The radius of the ellipse along the X axis.
    radius_2: `float`
        The radius of the ellipse along the Y axis.
    num_sides: `int`
        The number of sides of the ellipse.
    angle: `float`, optional
        The angle of the ellipse in degrees relative to the X axis.

    Returns
    -------
    points: `jax.Array`
        The points, one row of XY coordinates per side.

    Notes
    -----
    The first and last points are not equal.
    """
    # Parametrize one full turn with num_sides + 1 samples as a column vector.
    fractions = jnp.linspace(0.0, 1.0, num_sides + 1)
    turn = jnp.reshape(2 * jnp.pi * fractions, (-1, 1))

    # Axis-aligned ellipse; the closing sample repeats the first one, so drop it.
    columns = (radius_1 * jnp.cos(turn), radius_2 * jnp.sin(turn))
    unrotated = jnp.concatenate(columns, axis=1)[:-1]

    # Counter-clockwise rotation about the origin.
    theta = jnp.radians(angle)
    cos_t = jnp.cos(theta)
    sin_t = jnp.sin(theta)
    rotation_matrix = jnp.array([[cos_t, -sin_t], [sin_t, cos_t]])

    # Points are stored as rows, hence the transposed matrix on the right.
    return unrotated @ rotation_matrix.T
|
|
|
|
def points_on_ellipse(radius_1, radius_2, height, num_sides, angle=0.0):
    """
    Sample the vertices of a horizontal ellipse placed at a given height.

    Parameters
    ----------
    radius_1: `float`
        The radius of the ellipse along the X axis.
    radius_2: `float`
        The radius of the ellipse along the Y axis.
    height: `float`
        The height of the ellipse.
    num_sides: `int`
        The number of sides of the ellipse.
    angle: `float`, optional
        The angle of the ellipse in degrees relative to the X axis.

    Returns
    -------
    points: `jax.Array`
        The points, one row of XYZ coordinates per side.

    Notes
    -----
    The first and last points are not equal.
    """
    # Every vertex shares the same Z coordinate.
    z_column = jnp.ones((num_sides, 1)) * height
    xy_columns = points_on_ellipse_xy(radius_1, radius_2, num_sides, angle)

    return jnp.concatenate((xy_columns, z_column), axis=1)
|
|
|
|
def points_on_ellipses(radius_1, radius_2, heights, num_sides, angles):
    """
    Sample points on a sequence of ellipses distributed over an array of heights.

    Parameters
    ----------
    radius_1: `jax.Array`
        The radii of the ellipses along the X axis.
    radius_2: `jax.Array`
        The radii of the ellipses along the Y axis.
    heights: `jax.Array`
        The heights of the ellipses.
    num_sides: `int`
        The number of sides of the ellipses.
    angles: `jax.Array`
        The angles of the ellipses in degrees relative to the X axis.

    Returns
    -------
    points: `jax.Array`
        The points on the ellipses, stacked along a new leading axis.

    Notes
    -----
    The first and last points per ellipse are not equal.
    """
    def ellipse_at_level(r1, r2, z, rotation):
        # num_sides is shared by all the ellipses, so it is closed over
        # instead of being mapped.
        return points_on_ellipse(r1, r2, z, num_sides, rotation)

    # Map over the leading axis of every per-ellipse array.
    return vmap(ellipse_at_level)(radius_1, radius_2, heights, angles)
|
|
|
|
| |
| |
| |
|
|
|
|