File size: 541 Bytes
class PointGenerator:
    """Base class for generators that sample random points on a target shape.

    Concrete generators subclass this and override ``__call__``; calling an
    instance of this base class directly raises ``NotImplementedError``.
    """

    def __call__(self, key, wiggle=True):
        """Sample points on the target shape.

        Parameters
        ----------
        key : jax.random.PRNGKey
            The random key.
        wiggle : bool, optional
            If True (the default), the points are wiggled. What wiggling
            means is left to each subclass to define.

        Returns
        -------
        points : jax.Array
            The points on the target shape.

        Raises
        ------
        NotImplementedError
            Always, in this base class; subclasses must override this method.
        """
        message = "Subclasses must implement this method."
        raise NotImplementedError(message)
|