# Efradeca's picture
# Upload folder using huggingface_hub
# fc7d689 verified
import jax.numpy as jnp
import jax.random as jrn
from jax import vmap
from neural_fdm.generators.generator import PointGenerator
# ===============================================================================
# Generators
# ===============================================================================
class TubePointGenerator(PointGenerator):
    """
    Base class for generators that output points evaluated on a wiggled tube.
    """
class EllipticalTubePointGenerator(TubePointGenerator):
    """
    A generator that outputs points evaluated on a wiggled elliptical tube.

    Parameters
    ----------
    height: `float`
        The height of the tube.
    radius: `float`
        The reference radius of the tube.
    num_sides: `int`
        The number of sides per ellipse.
    num_levels: `int`
        The number of levels along the height of the tube.
    num_rings: `int`
        The number of levels that will work as compression rings. The first and last levels are fully supported.
    minval: `jax.Array`
        The minimum values of the space of random transformations, with shape `(3,)`:
        two radius scale factors followed by a rotation angle in degrees.
    maxval: `jax.Array`
        The maximum values of the space of random transformations, with shape `(3,)`.
    """
    def __init__(
            self,
            height,
            radius,
            num_sides,
            num_levels,
            num_rings,
            minval,
            maxval):

        # sanity checks
        assert num_rings >= 3, "Must include at least 1 ring in the middle!"
        self._check_array_shapes(num_rings, minval, maxval)

        self.height = height
        self.radius = radius
        self.num_sides = num_sides
        self.num_levels = num_levels
        self.num_rings = num_rings
        self.minval = minval
        self.maxval = maxval

        # precompute level and vertex indices; the tension levels are derived
        # from the compression levels, so the compression levels come first
        self.levels_rings_comp = self._levels_rings_compression()
        self.indices_rings_comp_ravel = self._indices_rings_compression_ravel()
        self.indices_rings_comp_interior_ravel = self._indices_rings_compression_interior_ravel()
        self.levels_rings_tension = self._levels_rings_tension()

        self.shape_tube = (num_levels, num_sides, 3)
        self.shape_rings = (num_rings, num_sides, 3)

    def __call__(self, key, wiggle=True):
        """
        Generate points.

        Parameters
        ----------
        key: `jax.random.PRNGKey`
            The random key.
        wiggle: `bool`, optional
            Whether to wiggle the points at random.

        Returns
        -------
        points: `jax.Array`
            The points on the tube, flattened into a 1D array.
        """
        points = self.points_on_tube(key, wiggle)

        return jnp.ravel(points)

    def _levels_rings_tension(self):
        """
        Compute the integer indices of the levels that work as tension rings.

        These are all the levels that are not compression rings.

        Returns
        -------
        indices: `jax.Array`
            The indices.
        """
        # a set of Python ints gives O(1) membership tests, instead of an
        # element-wise scan over a JAX array for every level
        levels_comp = set(self.levels_rings_comp.tolist())
        indices = [i for i in range(self.num_levels) if i not in levels_comp]
        indices = jnp.array(indices, dtype=jnp.int64)

        assert indices.size == self.num_levels - self.num_rings, (
            f"Expected {self.num_levels - self.num_rings} tension levels, got {indices.size}"
        )

        return indices

    def _levels_rings_compression(self):
        """
        Compute the integer indices of the levels that work as compression rings.

        The first and last levels are always included; the interior rings are
        placed every `num_levels // (num_rings - 1)` levels.

        Returns
        -------
        indices: `jax.Array`
            The indices.

        Raises
        ------
        ValueError
            If there are too few levels to place the requested rings.
        """
        step = int(self.num_levels / (self.num_rings - 1))
        if step < 1:
            raise ValueError(
                f"num_levels={self.num_levels} is too small to hold num_rings={self.num_rings} rings"
            )

        indices = [0] + list(range(step, self.num_levels - 1, step)) + [self.num_levels - 1]

        # the integer step does not always yield exactly `num_rings` levels
        assert len(indices) == self.num_rings, (
            f"Cannot distribute {self.num_rings} rings over {self.num_levels} levels "
            f"with a step of {step}: got levels {indices}"
        )

        return jnp.array(indices, dtype=jnp.int64)

    def _indices_ravel(self, levels):
        """
        Compute the flat vertex indices of all the vertices in the given levels.

        Parameters
        ----------
        levels: `jax.Array`
            The integer indices of the levels.

        Returns
        -------
        indices: `jax.Array`
            The vertex indices, level by level, in the order of `levels`.
        """
        # vertex `s` of level `l` is row `l * num_sides + s` of the vertex
        # array obtained by flattening the first two axes of `shape_tube`
        sides = jnp.arange(self.num_sides)
        indices = jnp.asarray(levels)[:, None] * self.num_sides + sides[None, :]

        return jnp.ravel(indices)

    def _indices_rings_compression_ravel(self):
        """
        Compute the integer indices of the vertices in the compression rings.

        Returns
        -------
        indices: `jax.Array`
            The indices.
        """
        return self._indices_ravel(self.levels_rings_comp)

    def _indices_rings_compression_interior_ravel(self):
        """
        Compute the integer indices of the vertices in the unsupported compression rings.

        The first and last compression rings are excluded.

        Returns
        -------
        indices: `jax.Array`
            The indices.
        """
        return self._indices_ravel(self.levels_rings_comp[1:-1])

    def wiggle(self, key):
        """
        Sample random radii and angles from a uniform distribution.

        The radii and the angles are drawn from independent subkeys of `key`.
        Reusing one key for both draws would make them correlated.

        Parameters
        ----------
        key: `jax.random.PRNGKey`
            The random key.

        Returns
        -------
        transform: tuple of `jax.Array`
            The transformation factors for the radii and angles.
        """
        key_radii, key_angle = jrn.split(key)

        return self.wiggle_radii(key_radii), self.wiggle_angle(key_angle)

    def wiggle_radii(self, key):
        """
        Sample random radii from a uniform distribution.

        Parameters
        ----------
        key: `jax.random.PRNGKey`
            The random key.

        Returns
        -------
        radii: `jax.Array`
            The random radius factors, with shape `(num_rings, 2)`.
        """
        shape = (self.num_rings, 2)
        minval = self.minval[:2]
        maxval = self.maxval[:2]

        return jrn.uniform(key, shape=shape, minval=minval, maxval=maxval)

    def wiggle_angle(self, key):
        """
        Sample random angles from a uniform distribution.

        Parameters
        ----------
        key: `jax.random.PRNGKey`
            The random key.

        Returns
        -------
        angles: `jax.Array`
            The random angles in degrees, with shape `(num_rings,)`.
        """
        shape = (self.num_rings,)
        minval = self.minval[2]
        maxval = self.maxval[2]

        return jrn.uniform(key, shape=shape, minval=minval, maxval=maxval)

    def evaluate_points(self, transform):
        """
        Generate wiggled points.

        Parameters
        ----------
        transform: tuple of `jax.Array`
            The random radii and angles.

        Returns
        -------
        points: `jax.Array`
            The points, flattened into a 1D array.
        """
        points = self._points_from_transform(transform)

        return jnp.ravel(points)

    def points_on_tube(self, key=None, wiggle=False):
        """
        Evaluate wiggled points on the tube.

        Parameters
        ----------
        key: `jax.random.PRNGKey`
            The random key. Only used if `wiggle` is `True`.
        wiggle: `bool`, optional
            Whether to wiggle the points at random.

        Returns
        -------
        points: `jax.Array`
            The points on the tube, with shape `shape_tube`.
        """
        transform = self.wiggle(key) if wiggle else None

        return self._points_from_transform(transform)

    def _points_from_transform(self, transform=None):
        """
        Evaluate the points on the tube, optionally applying a transformation to the compression rings.

        Parameters
        ----------
        transform: tuple of `jax.Array`, optional
            The radius factors and angles of the compression rings.
            If `None`, all levels keep the reference radius.

        Returns
        -------
        points: `jax.Array`
            The points on the tube, with shape `shape_tube`.
        """
        heights = jnp.linspace(0.0, self.height, self.num_levels)
        radii = jnp.ones(shape=(self.num_levels, 2)) * self.radius
        # NOTE(review): levels that are not compression rings are rotated by
        # 1.0 degree (ones, not zeros). This is kept as-is to preserve the
        # existing geometry; confirm it is intended.
        angles = jnp.ones(shape=(self.num_levels,))

        if transform is not None:
            wiggle_radii, wiggle_angle = transform
            # the sampled radii are factors relative to the reference radius
            radii = radii.at[self.levels_rings_comp, :].set(wiggle_radii * self.radius)
            angles = angles.at[self.levels_rings_comp].set(wiggle_angle)

        return points_on_ellipses(
            radii[:, 0],
            radii[:, 1],
            heights,
            self.num_sides,
            angles,
        )

    def _check_array_shapes(self, num_rings, minval, maxval):
        """
        Verify that input shapes are consistent.

        Parameters
        ----------
        num_rings: `int`
            The number of rings.
        minval: `jax.Array`
            The minimum values of the space of random transformations.
        maxval: `jax.Array`
            The maximum values of the space of random transformations.
        """
        shape = (3, )
        minval_shape = minval.shape
        maxval_shape = maxval.shape

        assert minval_shape == shape, f"{minval_shape} vs. {shape}"
        assert maxval_shape == shape, f"{maxval_shape} vs. {shape}"
class CircularTubePointGenerator(EllipticalTubePointGenerator):
    """
    A generator that outputs points evaluated on a wiggled circular tube.

    A single random radius factor is sampled per compression ring and used
    for both in-plane axes, so every level stays a circle.
    """
    def wiggle_radii(self, key):
        """
        Sample random radii from a uniform distribution.

        Parameters
        ----------
        key: `jax.random.PRNGKey`
            The random key.

        Returns
        -------
        radii: `jax.Array`
            The random radius factors, with shape `(num_rings,)`.
        """
        shape = (self.num_rings,)
        minval = self.minval[0]
        maxval = self.maxval[0]

        return jrn.uniform(key, shape=shape, minval=minval, maxval=maxval)

    def evaluate_points(self, transform):
        """
        Generate wiggled points.

        Parameters
        ----------
        transform: tuple of `jax.Array`
            The random radii and angles. The radii have shape `(num_rings,)`,
            as produced by `wiggle_radii`.

        Returns
        -------
        points: `jax.Array`
            The points, flattened into a 1D array.
        """
        wiggle_radii, _ = transform

        # (num_rings, 2) radii describe ellipses: keep the elliptical behavior
        # so that callers passing such transforms are unaffected
        if jnp.ndim(wiggle_radii) == 2:
            return super().evaluate_points(transform)

        # 1D radii cannot be scattered into the (num_levels, 2) radii used by
        # the elliptical parent, so circles are evaluated here
        points = self._points_on_circles(transform)

        return jnp.ravel(points)

    def points_on_tube(self, key=None, wiggle=False):
        """
        Evaluate wiggled points on the tube.

        Parameters
        ----------
        key: `jax.random.PRNGKey`
            The random key. Only used if `wiggle` is `True`.
        wiggle: `bool`, optional
            Whether to wiggle the points at random.

        Returns
        -------
        points: `jax.Array`
            The points on the tube, with shape `shape_tube`.
        """
        transform = self.wiggle(key) if wiggle else None

        return self._points_on_circles(transform)

    def _points_on_circles(self, transform=None):
        """
        Evaluate the points on the circular tube, optionally applying a transformation to the compression rings.

        Parameters
        ----------
        transform: tuple of `jax.Array`, optional
            The radius factors, with shape `(num_rings,)`, and the angles of
            the compression rings. If `None`, all levels keep the reference radius.

        Returns
        -------
        points: `jax.Array`
            The points on the tube, with shape `shape_tube`.
        """
        heights = jnp.linspace(0.0, self.height, self.num_levels)
        radii = jnp.ones(shape=(self.num_levels,)) * self.radius
        # NOTE(review): levels that are not compression rings are rotated by
        # 1.0 degree (ones, not zeros). This is kept as-is to preserve the
        # existing geometry; confirm it is intended.
        angles = jnp.ones(shape=(self.num_levels,))

        if transform is not None:
            wiggle_radii, wiggle_angle = transform
            # the sampled radii are factors relative to the reference radius
            radii = radii.at[self.levels_rings_comp].set(wiggle_radii * self.radius)
            angles = angles.at[self.levels_rings_comp].set(wiggle_angle)

        # equal radii along both axes make every ellipse a circle
        return points_on_ellipses(
            radii,
            radii,
            heights,
            self.num_sides,
            angles,
        )
# ===============================================================================
# Helper functions
# ===============================================================================
def points_on_ellipse_xy(radius_1, radius_2, num_sides, angle=0.0):
    """
    Sample evenly parametrized points on an ellipse centered at the origin of the XY plane.

    Parameters
    ----------
    radius_1: `float`
        The radius of the ellipse along the X axis, before rotation.
    radius_2: `float`
        The radius of the ellipse along the Y axis, before rotation.
    num_sides: `int`
        The number of sides of the ellipse, i.e. the number of sampled points.
    angle: `float`, optional
        The counterclockwise rotation of the ellipse in degrees relative to the X axis.

    Returns
    -------
    points: `jax.Array`
        The points, with shape `(num_sides, 2)`.

    Notes
    -----
    The first and last points are not equal: the closing point of the
    parametrization is dropped.
    """
    # column of parameters 2*pi*k/num_sides for k = 0..num_sides-1
    # (the closing parameter 2*pi is sliced off)
    params = (2 * jnp.pi * jnp.linspace(0.0, 1.0, num_sides + 1))[:-1, None]

    ellipse = jnp.concatenate(
        (radius_1 * jnp.cos(params), radius_2 * jnp.sin(params)),
        axis=1,
    )

    theta = jnp.radians(angle)
    cos_t = jnp.cos(theta)
    sin_t = jnp.sin(theta)

    # transpose of the counterclockwise rotation matrix, so that the row
    # vectors can be rotated with a right multiplication
    rotation_t = jnp.array([
        [cos_t, sin_t],
        [-sin_t, cos_t],
    ])

    return ellipse @ rotation_t
def points_on_ellipse(radius_1, radius_2, height, num_sides, angle=0.0):
    """
    Sample points on a horizontal ellipse lifted to a given height.

    Parameters
    ----------
    radius_1: `float`
        The radius of the ellipse along the X axis.
    radius_2: `float`
        The radius of the ellipse along the Y axis.
    height: `float`
        The Z coordinate shared by all the points.
    num_sides: `int`
        The number of sides of the ellipse.
    angle: `float`, optional
        The angle of the ellipse in degrees relative to the X axis.

    Returns
    -------
    points: `jax.Array`
        The points, with shape `(num_sides, 3)`.

    Notes
    -----
    The first and last points are not equal.
    """
    planar = points_on_ellipse_xy(radius_1, radius_2, num_sides, angle)

    # one Z column, identical for every planar point
    elevation = height * jnp.ones((num_sides, 1))

    return jnp.concatenate((planar, elevation), axis=1)
def points_on_ellipses(radius_1, radius_2, heights, num_sides, angles):
    """
    Sample points on a stack of ellipses, one per entry of `heights`.

    Parameters
    ----------
    radius_1: `jax.Array`
        The radii of the ellipses along the X axis.
    radius_2: `jax.Array`
        The radii of the ellipses along the Y axis.
    heights: `jax.Array`
        The heights of the ellipses.
    num_sides: `int`
        The number of sides shared by all the ellipses.
    angles: `jax.Array`
        The angles of the ellipses in degrees relative to the X axis.

    Returns
    -------
    points: `jax.Array`
        The points on the ellipses, with shape `(len(heights), num_sides, 3)`.

    Notes
    -----
    The first and last points per ellipse are not equal.
    """
    def ellipse_at(r1, r2, h, a):
        # num_sides is captured from the enclosing scope, so it is not vectorized
        return points_on_ellipse(r1, r2, h, num_sides, a)

    # map over the leading axis of every per-ellipse argument
    return vmap(ellipse_at)(radius_1, radius_2, heights, angles)
# ===============================================================================
# Main
# ===============================================================================