| class PointGenerator: | |
| """ | |
| A generator that samples random points on a target shape. | |
| """ | |
| def __call__(self, key, wiggle=True): | |
| """ | |
| Generate points. | |
| Parameters | |
| ---------- | |
| key: `jax.random.PRNGKey` | |
| The random key. | |
| wiggle: `bool` | |
| If True, the points are wiggled. | |
| Returns | |
| ------- | |
| points: `jax.Array` | |
| The points on the target shape. | |
| """ | |
| raise NotImplementedError("Subclasses must implement this method.") | |