File size: 541 Bytes
fc7d689
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
class PointGenerator:
    """
    Base interface for samplers of random points on a target shape.

    This class supplies no sampling logic of its own: calling an instance
    raises ``NotImplementedError``. Concrete generators subclass it and
    override ``__call__``.
    """

    def __call__(self, key, wiggle=True):
        """
        Sample points on the target shape.

        Parameters
        ----------
        key : jax.random.PRNGKey
            Random key that drives the sampling.
        wiggle : bool, default True
            Whether the sampled points are wiggled. What "wiggling" means
            precisely is left to the implementing subclass.

        Returns
        -------
        points : jax.Array
            Points lying on the target shape.

        Raises
        ------
        NotImplementedError
            Always, in this base class; subclasses must override this method.
        """
        # The base implementation does not use its inputs; discard them
        # explicitly so linters do not flag them as unused.
        del key, wiggle
        message = "Subclasses must implement this method."
        raise NotImplementedError(message)