vae-fdm / src /neural_fdm /generators /generator.py
Efradeca's picture
Upload folder using huggingface_hub
fc7d689 verified
raw
history blame contribute delete
541 Bytes
class PointGenerator:
    """
    Base interface for callables that sample random points on a target shape.

    This class holds no sampling logic of its own: invoking an instance
    raises `NotImplementedError`, so concrete generators are expected to
    override `__call__`.
    """

    def __call__(self, key, wiggle=True):
        """
        Sample points on the target shape.

        Parameters
        ----------
        key: `jax.random.PRNGKey`
            The random key.
        wiggle: `bool`
            If True, the points are wiggled.

        Returns
        -------
        points: `jax.Array`
            The points on the target shape.

        Raises
        ------
        NotImplementedError
            Always, in this base class; the arguments are not inspected.
        """
        # Bind the message first so the raise statement stays short.
        message = "Subclasses must implement this method."
        raise NotImplementedError(message)