.. _random:

Random
======

Random sampling functions in MLX use an implicit global PRNG state by default.
However, all functions take an optional ``key`` keyword argument for when more
fine-grained control or explicit state management is needed.
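
As a minimal sketch of controlling the implicit global state (assuming ``mlx``
is installed and importable as ``mlx.core``), ``mx.random.seed`` resets the
global PRNG, so a sequence of calls made without ``key`` can be reproduced:

.. code-block:: python

  import mlx.core as mx

  # Reset the implicit global PRNG state to a fixed seed.
  mx.random.seed(42)
  a = mx.random.uniform(shape=(3,))

  # Re-seeding with the same value restarts the same global sequence.
  mx.random.seed(42)
  b = mx.random.uniform(shape=(3,))

  print(mx.array_equal(a, b))

Seeding is convenient for scripts. An explicit ``key`` avoids hidden shared
state when several components draw random numbers.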

For example, you can generate random numbers with:

.. code-block:: python

  import mlx.core as mx

  for _ in range(3):
      print(mx.random.uniform())

which will print a different pseudo-random number on each iteration, since each
call updates the implicit global state. Alternatively, you can set the key
explicitly:

.. code-block:: python

  key = mx.random.key(0)
  for _ in range(3):
      print(mx.random.uniform(key=key))

which will print the same pseudo-random number on each iteration, because a
given key always produces the same output.
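
To get different numbers while still managing the state explicitly, a common
pattern is to split the key before every draw and use each subkey exactly once.
The sketch below assumes ``mx.random.split`` returns an array of ``num`` keys
(two by default); the helper name ``draw`` is hypothetical:

.. code-block:: python

  import mlx.core as mx

  def draw(seed, n):
      key = mx.random.key(seed)
      out = []
      for _ in range(n):
          # Split into a new carry key and a fresh subkey for this draw.
          keys = mx.random.split(key)
          key, subkey = keys[0], keys[1]
          # Each subkey is consumed once, so each draw differs.
          out.append(mx.random.uniform(key=subkey).item())
      return out

  print(draw(0, 3))

Starting again from the same seed reproduces the whole sequence, which is the
main benefit over relying on the global state.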

Following `JAX's PRNG design <https://jax.readthedocs.io/en/latest/jep/263-prng.html>`_,
MLX uses a splittable version of Threefry, which is a counter-based PRNG.
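
To illustrate what "counter-based" means, here is a plain-Python sketch of
Threefry-2x32 with 20 rounds. The rotation constants and the key-schedule
constant ``0x1BD11BDA`` follow the published Threefry specification. This is
for illustration only. It is not MLX's kernel, and the ``split`` helper is a
hypothetical simplification whose counter layout is not claimed to match JAX or
MLX. The point is that every output is a pure function of ``(key, counter)``:
there is no mutable state to advance, and any position in the stream can be
computed directly. A useful check for a sketch like this is to compare it
against the known-answer vectors published with Random123.

.. code-block:: python

  MASK = 0xFFFFFFFF
  ROTATIONS = ((13, 15, 26, 6), (17, 29, 16, 24))

  def rotl32(v, r):
      # 32-bit rotate left.
      return ((v << r) | (v >> (32 - r))) & MASK

  def threefry2x32(key, counter):
      # Hash a 2x32-bit counter under a 2x32-bit key (20 rounds).
      k0, k1 = key
      ks = (k0, k1, k0 ^ k1 ^ 0x1BD11BDA)
      x0 = (counter[0] + ks[0]) & MASK
      x1 = (counter[1] + ks[1]) & MASK
      for i in range(5):
          # Four mix rounds, alternating between the two rotation sets.
          for r in ROTATIONS[i % 2]:
              x0 = (x0 + x1) & MASK
              x1 = rotl32(x1, r) ^ x0
          # Key injection after every four rounds.
          x0 = (x0 + ks[(i + 1) % 3]) & MASK
          x1 = (x1 + ks[(i + 2) % 3] + i + 1) & MASK
      return x0, x1

  def split(key, num=2):
      # Hypothetical, simplified split: child i hashes counter (i, 0).
      # The parent key should not be reused after splitting.
      return [threefry2x32(key, (i, 0)) for i in range(num)]

  key = (0, 42)
  # Position 1000 is computed directly, without generating 0..999 first.
  x0, _ = threefry2x32(key, (1000, 0))
  print(x0 / 2**32)  # a float in [0, 1)
  print(split(key))

Because the generator is just a keyed hash of a counter, splitting amounts to
deriving new keys by hashing distinct counters, and independent streams can be
evaluated in parallel.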

.. currentmodule:: mlx.core.random

.. autosummary::
   :toctree: _autosummary

   bernoulli
   categorical
   gumbel
   key
   laplace
   multivariate_normal
   normal
   permutation
   randint
   seed
   split
   truncated_normal
   uniform
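
The samplers above all accept the same optional ``key`` argument. The sketch
below assumes a recent MLX release with the ``normal``, ``randint``,
``permutation``, and ``bernoulli`` signatures shown; the helper name
``draw_samples`` is hypothetical. It splits one root key into independent
subkeys and passes one to each sampler, so the whole batch is reproducible from
a single seed:

.. code-block:: python

  import mlx.core as mx

  def draw_samples(seed):
      # Derive four independent subkeys from a single root key.
      keys = mx.random.split(mx.random.key(seed), num=4)
      return {
          "normal": mx.random.normal(shape=(2, 3), key=keys[0]),
          "randint": mx.random.randint(0, 10, shape=(5,), key=keys[1]),
          "permutation": mx.random.permutation(8, key=keys[2]),
          "bernoulli": mx.random.bernoulli(0.25, shape=(4,), key=keys[3]),
      }

  samples = draw_samples(0)
  for name, value in samples.items():
      print(name, value.tolist())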