|
|
from keras.src import backend |
|
|
from keras.src.api_export import keras_export |
|
|
|
|
|
|
|
|
@keras_export("keras.random.normal") |
|
|
def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): |
|
|
"""Draw random samples from a normal (Gaussian) distribution. |
|
|
|
|
|
Args: |
|
|
shape: The shape of the random values to generate. |
|
|
mean: Float, defaults to 0. Mean of the random values to generate. |
|
|
stddev: Float, defaults to 1. Standard deviation of the random values |
|
|
to generate. |
|
|
dtype: Optional dtype of the tensor. Only floating point types are |
|
|
supported. If not specified, `keras.config.floatx()` is used, |
|
|
which defaults to `float32` unless you configured it otherwise (via |
|
|
`keras.config.set_floatx(float_dtype)`). |
|
|
seed: Optional Python integer or instance of |
|
|
`keras.random.SeedGenerator`. |
|
|
By default, the `seed` argument is `None`, and an internal global |
|
|
`keras.random.SeedGenerator` is used. The `seed` argument can be |
|
|
used to ensure deterministic (repeatable) random number generation. |
|
|
Note that passing an integer as the `seed` value will produce the |
|
|
same random values for each call. To generate different random |
|
|
values for repeated calls, an instance of |
|
|
`keras.random.SeedGenerator` must be provided as the `seed` value. |
|
|
Remark concerning the JAX backend: When tracing functions with the |
|
|
JAX backend the global `keras.random.SeedGenerator` is not |
|
|
supported. Therefore, during tracing the default value `seed=None` |
|
|
will produce an error, and a `seed` argument must be provided. |
|
|
""" |
|
|
return backend.random.normal( |
|
|
shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed |
|
|
) |
|
|
|
|
|
|
|
|
@keras_export("keras.random.categorical") |
|
|
def categorical(logits, num_samples, dtype="int32", seed=None): |
|
|
"""Draws samples from a categorical distribution. |
|
|
|
|
|
This function takes as input `logits`, a 2-D input tensor with shape |
|
|
(batch_size, num_classes). Each row of the input represents a categorical |
|
|
distribution, with each column index containing the log-probability for a |
|
|
given class. |
|
|
|
|
|
The function will output a 2-D tensor with shape (batch_size, num_samples), |
|
|
where each row contains samples from the corresponding row in `logits`. |
|
|
Each column index contains an independent samples drawn from the input |
|
|
distribution. |
|
|
|
|
|
Args: |
|
|
logits: 2-D Tensor with shape (batch_size, num_classes). Each row |
|
|
should define a categorical distribution with the unnormalized |
|
|
log-probabilities for all classes. |
|
|
num_samples: Int, the number of independent samples to draw for each |
|
|
row of the input. This will be the second dimension of the output |
|
|
tensor's shape. |
|
|
dtype: Optional dtype of the output tensor. |
|
|
seed: Optional Python integer or instance of |
|
|
`keras.random.SeedGenerator`. |
|
|
By default, the `seed` argument is `None`, and an internal global |
|
|
`keras.random.SeedGenerator` is used. The `seed` argument can be |
|
|
used to ensure deterministic (repeatable) random number generation. |
|
|
Note that passing an integer as the `seed` value will produce the |
|
|
same random values for each call. To generate different random |
|
|
values for repeated calls, an instance of |
|
|
`keras.random.SeedGenerator` must be provided as the `seed` value. |
|
|
Remark concerning the JAX backend: When tracing functions with the |
|
|
JAX backend the global `keras.random.SeedGenerator` is not |
|
|
supported. Therefore, during tracing the default value seed=None |
|
|
will produce an error, and a `seed` argument must be provided. |
|
|
|
|
|
Returns: |
|
|
A 2-D tensor with (batch_size, num_samples). |
|
|
""" |
|
|
logits_shape = list(backend.convert_to_tensor(logits).shape) |
|
|
if len(logits_shape) != 2: |
|
|
raise ValueError( |
|
|
"`logits` should be a 2-D tensor with shape " |
|
|
f"[batch_size, num_classes]. Received: logits={logits}" |
|
|
) |
|
|
return backend.random.categorical( |
|
|
logits, num_samples, dtype=dtype, seed=seed |
|
|
) |
|
|
|
|
|
|
|
|
@keras_export("keras.random.uniform") |
|
|
def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None): |
|
|
"""Draw samples from a uniform distribution. |
|
|
|
|
|
The generated values follow a uniform distribution in the range |
|
|
`[minval, maxval)`. The lower bound `minval` is included in the range, |
|
|
while the upper bound `maxval` is excluded. |
|
|
|
|
|
`dtype` must be a floating point type, the default range is `[0, 1)`. |
|
|
|
|
|
Args: |
|
|
shape: The shape of the random values to generate. |
|
|
minval: Float, defaults to 0. Lower bound of the range of |
|
|
random values to generate (inclusive). |
|
|
maxval: Float, defaults to 1. Upper bound of the range of |
|
|
random values to generate (exclusive). |
|
|
dtype: Optional dtype of the tensor. Only floating point types are |
|
|
supported. If not specified, `keras.config.floatx()` is used, |
|
|
which defaults to `float32` unless you configured it otherwise (via |
|
|
`keras.config.set_floatx(float_dtype)`) |
|
|
seed: Optional Python integer or instance of |
|
|
`keras.random.SeedGenerator`. |
|
|
By default, the `seed` argument is `None`, and an internal global |
|
|
`keras.random.SeedGenerator` is used. The `seed` argument can be |
|
|
used to ensure deterministic (repeatable) random number generation. |
|
|
Note that passing an integer as the `seed` value will produce the |
|
|
same random values for each call. To generate different random |
|
|
values for repeated calls, an instance of |
|
|
`keras.random.SeedGenerator` must be provided as the `seed` value. |
|
|
Remark concerning the JAX backend: When tracing functions with the |
|
|
JAX backend the global `keras.random.SeedGenerator` is not |
|
|
supported. Therefore, during tracing the default value seed=None |
|
|
will produce an error, and a `seed` argument must be provided. |
|
|
""" |
|
|
if dtype and not backend.is_float_dtype(dtype): |
|
|
raise ValueError( |
|
|
"`keras.random.uniform` requires a floating point `dtype`. " |
|
|
f"Received: dtype={dtype} " |
|
|
) |
|
|
return backend.random.uniform( |
|
|
shape, minval=minval, maxval=maxval, dtype=dtype, seed=seed |
|
|
) |
|
|
|
|
|
|
|
|
@keras_export("keras.random.randint") |
|
|
def randint(shape, minval, maxval, dtype="int32", seed=None): |
|
|
"""Draw random integers from a uniform distribution. |
|
|
|
|
|
The generated values follow a uniform distribution in the range |
|
|
`[minval, maxval)`. The lower bound `minval` is included in the range, |
|
|
while the upper bound `maxval` is excluded. |
|
|
|
|
|
`dtype` must be an integer type. |
|
|
|
|
|
Args: |
|
|
shape: The shape of the random values to generate. |
|
|
minval: Float, defaults to 0. Lower bound of the range of |
|
|
random values to generate (inclusive). |
|
|
maxval: Float, defaults to 1. Upper bound of the range of |
|
|
random values to generate (exclusive). |
|
|
dtype: Optional dtype of the tensor. Only integer types are |
|
|
supported. If not specified, `keras.config.floatx()` is used, |
|
|
which defaults to `float32` unless you configured it otherwise (via |
|
|
`keras.config.set_floatx(float_dtype)`) |
|
|
seed: Optional Python integer or instance of |
|
|
`keras.random.SeedGenerator`. |
|
|
By default, the `seed` argument is `None`, and an internal global |
|
|
`keras.random.SeedGenerator` is used. The `seed` argument can be |
|
|
used to ensure deterministic (repeatable) random number generation. |
|
|
Note that passing an integer as the `seed` value will produce the |
|
|
same random values for each call. To generate different random |
|
|
values for repeated calls, an instance of |
|
|
`keras.random.SeedGenerator` must be provided as the `seed` value. |
|
|
Remark concerning the JAX backend: When tracing functions with the |
|
|
JAX backend the global `keras.random.SeedGenerator` is not |
|
|
supported. Therefore, during tracing the default value seed=None |
|
|
will produce an error, and a `seed` argument must be provided. |
|
|
""" |
|
|
if dtype and not backend.is_int_dtype(dtype): |
|
|
raise ValueError( |
|
|
"`keras.random.randint` requires an integer `dtype`. " |
|
|
f"Received: dtype={dtype} " |
|
|
) |
|
|
return backend.random.randint( |
|
|
shape, minval=minval, maxval=maxval, dtype=dtype, seed=seed |
|
|
) |
|
|
|
|
|
|
|
|
@keras_export("keras.random.truncated_normal") |
|
|
def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): |
|
|
"""Draw samples from a truncated normal distribution. |
|
|
|
|
|
The values are drawn from a normal distribution with specified mean and |
|
|
standard deviation, discarding and re-drawing any samples that are more |
|
|
than two standard deviations from the mean. |
|
|
|
|
|
Args: |
|
|
shape: The shape of the random values to generate. |
|
|
mean: Float, defaults to 0. Mean of the random values to generate. |
|
|
stddev: Float, defaults to 1. Standard deviation of the random values |
|
|
to generate. |
|
|
dtype: Optional dtype of the tensor. Only floating point types are |
|
|
supported. If not specified, `keras.config.floatx()` is used, |
|
|
which defaults to `float32` unless you configured it otherwise (via |
|
|
`keras.config.set_floatx(float_dtype)`) |
|
|
seed: Optional Python integer or instance of |
|
|
`keras.random.SeedGenerator`. |
|
|
By default, the `seed` argument is `None`, and an internal global |
|
|
`keras.random.SeedGenerator` is used. The `seed` argument can be |
|
|
used to ensure deterministic (repeatable) random number generation. |
|
|
Note that passing an integer as the `seed` value will produce the |
|
|
same random values for each call. To generate different random |
|
|
values for repeated calls, an instance of |
|
|
`keras.random.SeedGenerator` must be provided as the `seed` value. |
|
|
Remark concerning the JAX backend: When tracing functions with the |
|
|
JAX backend the global `keras.random.SeedGenerator` is not |
|
|
supported. Therefore, during tracing the default value seed=None |
|
|
will produce an error, and a `seed` argument must be provided. |
|
|
""" |
|
|
return backend.random.truncated_normal( |
|
|
shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed |
|
|
) |
|
|
|
|
|
|
|
|
@keras_export("keras.random.dropout") |
|
|
def dropout(inputs, rate, noise_shape=None, seed=None): |
|
|
return backend.random.dropout( |
|
|
inputs, rate, noise_shape=noise_shape, seed=seed |
|
|
) |
|
|
|
|
|
|
|
|
@keras_export("keras.random.shuffle") |
|
|
def shuffle(x, axis=0, seed=None): |
|
|
"""Shuffle the elements of a tensor uniformly at random along an axis. |
|
|
|
|
|
Args: |
|
|
x: The tensor to be shuffled. |
|
|
axis: An integer specifying the axis along which to shuffle. Defaults to |
|
|
`0`. |
|
|
seed: Optional Python integer or instance of |
|
|
`keras.random.SeedGenerator`. |
|
|
By default, the `seed` argument is `None`, and an internal global |
|
|
`keras.random.SeedGenerator` is used. The `seed` argument can be |
|
|
used to ensure deterministic (repeatable) random number generation. |
|
|
Note that passing an integer as the `seed` value will produce the |
|
|
same random values for each call. To generate different random |
|
|
values for repeated calls, an instance of |
|
|
`keras.random.SeedGenerator` must be provided as the `seed` value. |
|
|
Remark concerning the JAX backend: When tracing functions with the |
|
|
JAX backend the global `keras.random.SeedGenerator` is not |
|
|
supported. Therefore, during tracing the default value seed=None |
|
|
will produce an error, and a `seed` argument must be provided. |
|
|
""" |
|
|
return backend.random.shuffle(x, axis=axis, seed=seed) |
|
|
|
|
|
|
|
|
@keras_export("keras.random.gamma") |
|
|
def gamma(shape, alpha, dtype=None, seed=None): |
|
|
"""Draw random samples from the Gamma distribution. |
|
|
|
|
|
Args: |
|
|
shape: The shape of the random values to generate. |
|
|
alpha: Float, the parameter of the distribution. |
|
|
dtype: Optional dtype of the tensor. Only floating point types are |
|
|
supported. If not specified, `keras.config.floatx()` is used, |
|
|
which defaults to `float32` unless you configured it otherwise (via |
|
|
`keras.config.set_floatx(float_dtype)`). |
|
|
seed: Optional Python integer or instance of |
|
|
`keras.random.SeedGenerator`. |
|
|
By default, the `seed` argument is `None`, and an internal global |
|
|
`keras.random.SeedGenerator` is used. The `seed` argument can be |
|
|
used to ensure deterministic (repeatable) random number generation. |
|
|
Note that passing an integer as the `seed` value will produce the |
|
|
same random values for each call. To generate different random |
|
|
values for repeated calls, an instance of |
|
|
`keras.random.SeedGenerator` must be provided as the `seed` value. |
|
|
Remark concerning the JAX backend: When tracing functions with the |
|
|
JAX backend the global `keras.random.SeedGenerator` is not |
|
|
supported. Therefore, during tracing the default value seed=None |
|
|
will produce an error, and a `seed` argument must be provided. |
|
|
""" |
|
|
return backend.random.gamma(shape, alpha=alpha, dtype=dtype, seed=seed) |
|
|
|
|
|
|
|
|
@keras_export("keras.random.binomial") |
|
|
def binomial(shape, counts, probabilities, dtype=None, seed=None): |
|
|
"""Draw samples from a Binomial distribution. |
|
|
|
|
|
The values are drawn from a Binomial distribution with |
|
|
specified trial count and probability of success. |
|
|
|
|
|
Args: |
|
|
shape: The shape of the random values to generate. |
|
|
counts: A number or array of numbers representing the |
|
|
number of trials. It must be broadcastable with `probabilities`. |
|
|
probabilities: A float or array of floats representing the |
|
|
probability of success of an individual event. |
|
|
It must be broadcastable with `counts`. |
|
|
dtype: Optional dtype of the tensor. Only floating point types are |
|
|
supported. If not specified, `keras.config.floatx()` is used, |
|
|
which defaults to `float32` unless you configured it otherwise (via |
|
|
`keras.config.set_floatx(float_dtype)`). |
|
|
seed: Optional Python integer or instance of |
|
|
`keras.random.SeedGenerator`. |
|
|
By default, the `seed` argument is `None`, and an internal global |
|
|
`keras.random.SeedGenerator` is used. The `seed` argument can be |
|
|
used to ensure deterministic (repeatable) random number generation. |
|
|
Note that passing an integer as the `seed` value will produce the |
|
|
same random values for each call. To generate different random |
|
|
values for repeated calls, an instance of |
|
|
`keras.random.SeedGenerator` must be provided as the `seed` value. |
|
|
Remark concerning the JAX backend: When tracing functions with the |
|
|
JAX backend the global `keras.random.SeedGenerator` is not |
|
|
supported. Therefore, during tracing the default value seed=None |
|
|
will produce an error, and a `seed` argument must be provided. |
|
|
""" |
|
|
return backend.random.binomial( |
|
|
shape, |
|
|
counts=counts, |
|
|
probabilities=probabilities, |
|
|
dtype=dtype, |
|
|
seed=seed, |
|
|
) |
|
|
|
|
|
|
|
|
@keras_export("keras.random.beta") |
|
|
def beta(shape, alpha, beta, dtype=None, seed=None): |
|
|
"""Draw samples from a Beta distribution. |
|
|
|
|
|
The values are drawn from a Beta distribution parametrized |
|
|
by alpha and beta. |
|
|
|
|
|
Args: |
|
|
shape: The shape of the random values to generate. |
|
|
alpha: Float or an array of floats representing the first |
|
|
parameter alpha. Must be broadcastable with `beta` and `shape`. |
|
|
beta: Float or an array of floats representing the second |
|
|
parameter beta. Must be broadcastable with `alpha` and `shape`. |
|
|
dtype: Optional dtype of the tensor. Only floating point types are |
|
|
supported. If not specified, `keras.config.floatx()` is used, |
|
|
which defaults to `float32` unless you configured it otherwise (via |
|
|
`keras.config.set_floatx(float_dtype)`). |
|
|
seed: Optional Python integer or instance of |
|
|
`keras.random.SeedGenerator`. |
|
|
By default, the `seed` argument is `None`, and an internal global |
|
|
`keras.random.SeedGenerator` is used. The `seed` argument can be |
|
|
used to ensure deterministic (repeatable) random number generation. |
|
|
Note that passing an integer as the `seed` value will produce the |
|
|
same random values for each call. To generate different random |
|
|
values for repeated calls, an instance of |
|
|
`keras.random.SeedGenerator` must be provided as the `seed` value. |
|
|
Remark concerning the JAX backend: When tracing functions with the |
|
|
JAX backend the global `keras.random.SeedGenerator` is not |
|
|
supported. Therefore, during tracing the default value seed=None |
|
|
will produce an error, and a `seed` argument must be provided. |
|
|
""" |
|
|
return backend.random.beta( |
|
|
shape=shape, alpha=alpha, beta=beta, dtype=dtype, seed=seed |
|
|
) |
|
|
|