|
|
import jax |
|
|
import numpy as np |
|
|
from jax import lax |
|
|
|
|
|
from keras.src import backend |
|
|
from keras.src.backend.common.backend_utils import ( |
|
|
compute_conv_transpose_padding_args_for_jax, |
|
|
) |
|
|
from keras.src.backend.numpy.core import cast |
|
|
from keras.src.backend.numpy.core import convert_to_tensor |
|
|
from keras.src.backend.numpy.core import is_tensor |
|
|
from keras.src.utils.module_utils import scipy |
|
|
|
|
|
|
|
|
def relu(x):
    """Apply the rectified linear unit element-wise.

    Args:
        x: Input; converted with `convert_to_tensor`.

    Returns:
        Array equal to `x` where `x` is positive and zero elsewhere,
        in the dtype of the converted input.
    """
    tensor = convert_to_tensor(x)
    # Build the zero in the input dtype so NumPy does not upcast the result.
    zero = np.array(0.0, tensor.dtype)
    return np.maximum(tensor, zero)
|
|
|
|
|
|
|
|
def relu6(x):
    """Apply ReLU capped at 6, i.e. `min(max(x, 0), 6)`, element-wise.

    Args:
        x: Input; converted with `convert_to_tensor`.

    Returns:
        Array with every value limited to the range `[0, 6]`.
    """
    tensor = convert_to_tensor(x)
    # Both bounds are created in the input dtype to avoid dtype promotion.
    lower = np.array(0.0, tensor.dtype)
    upper = np.array(6.0, tensor.dtype)
    rectified = np.maximum(tensor, lower)
    return np.minimum(rectified, upper)
|
|
|
|
|
|
|
|
def sigmoid(x):
    """Compute the logistic sigmoid `1 / (1 + exp(-x))` element-wise.

    Args:
        x: Input; converted with `convert_to_tensor`.

    Returns:
        Array of sigmoid values.
    """
    tensor = convert_to_tensor(x)
    # The constant is built in the input dtype to avoid dtype promotion.
    one = np.array(1.0, tensor.dtype)
    denominator = one + np.exp(-tensor)
    return one / denominator
|
|
|
|
|
|
|
|
def sparse_sigmoid(x):
    """Compute the piecewise-linear sparse sigmoid element-wise.

    The result is `0` for `x <= -1`, `1` for `x >= 1` and
    `0.5 * (x + 1)` in between.

    Args:
        x: Input; converted with `convert_to_tensor`.

    Returns:
        Array of activation values in the input dtype.
    """
    tensor = convert_to_tensor(x)
    dtype = tensor.dtype
    linear_part = np.array(0.5 * (tensor + 1), dtype)
    # Resolve the upper saturation first, then the lower one.
    upper_resolved = np.where(tensor >= 1, np.array(1.0, dtype), linear_part)
    return np.where(tensor <= -1, np.array(0.0, dtype), upper_resolved)
|
|
|
|
|
|
|
|
def tanh(x):
    """Compute the hyperbolic tangent element-wise via `np.tanh`.

    Args:
        x: Input accepted by `np.tanh`.

    Returns:
        The hyperbolic tangent of `x`.
    """
    result = np.tanh(x)
    return result
|
|
|
|
|
|
|
|
def tanh_shrink(x):
    """Compute `x - tanh(x)` element-wise.

    Args:
        x: Input; converted with `convert_to_tensor`.

    Returns:
        Array holding the difference between `x` and its hyperbolic tangent.
    """
    tensor = convert_to_tensor(x)
    squashed = np.tanh(tensor)
    return tensor - squashed
|
|
|
|
|
|
|
|
def softplus(x):
    """Compute softplus, `log(exp(x) + 1)`, element-wise.

    Implemented as `logaddexp(x, 0)`.

    Args:
        x: Input; converted with `convert_to_tensor`.

    Returns:
        Array of softplus values.
    """
    tensor = convert_to_tensor(x)
    zero = np.array(0.0, tensor.dtype)
    return np.logaddexp(tensor, zero)
|
|
|
|
|
|
|
|
def softsign(x):
    """Compute softsign, `x / (1 + |x|)`, element-wise.

    Args:
        x: Input; converted with `convert_to_tensor`.

    Returns:
        Array of softsign values.
    """
    tensor = convert_to_tensor(x)
    denominator = np.array(1.0, tensor.dtype) + np.abs(tensor)
    return tensor / denominator
|
|
|
|
|
|
|
|
def soft_shrink(x, threshold=0.5):
    """Apply soft shrinkage element-wise.

    Values above `threshold` are reduced by `threshold`, values below
    `-threshold` are increased by `threshold`, and everything in between
    becomes zero.

    Args:
        x: Input; converted with `convert_to_tensor` so that Python
            sequences and scalars are accepted like in the sibling
            activations (the function reads `x.dtype`).
        threshold: Half-width of the zeroed band. Defaults to `0.5`.

    Returns:
        Array of shrunk values in the dtype of the converted input.
    """
    x = convert_to_tensor(x)
    dtype = x.dtype
    zero = np.array(0.0, dtype=dtype)
    return np.where(
        x > threshold,
        np.array(x - threshold, dtype=dtype),
        np.where(x < -threshold, np.array(x + threshold, dtype=dtype), zero),
    )
|
|
|
|
|
|
|
|
def sparse_plus(x):
    """Compute the piecewise SparsePlus activation element-wise.

    The result is `0` for `x <= -1`, `(x + 1) ** 2 / 4` for
    `-1 < x < 1`, and `x` otherwise.

    Args:
        x: Input; converted with `convert_to_tensor` so that Python
            sequences and scalars are accepted like in the sibling
            activations (the function reads `x.dtype`).

    Returns:
        Array of activation values.
    """
    x = convert_to_tensor(x)
    quadratic = np.array((1 / 4) * (x + 1) ** 2, dtype=x.dtype)
    return np.where(
        x <= -1,
        np.zeros_like(x, dtype=x.dtype),
        np.where(x < 1, quadratic, x),
    )
|
|
|
|
|
|
|
|
def silu(x):
    """Compute SiLU (swish), `x * sigmoid(x)`, element-wise.

    Args:
        x: Input; converted with `convert_to_tensor`.

    Returns:
        Array of SiLU values.
    """
    tensor = convert_to_tensor(x)
    gate = sigmoid(tensor)
    return tensor * gate
|
|
|
|
|
|
|
|
def squareplus(x, b=4):
    """Compute Squareplus, `(x + sqrt(x**2 + b)) / 2`, element-wise.

    Args:
        x: Input; converted with `convert_to_tensor`.
        b: Constant added under the square root; converted to the dtype
            of `x`. Defaults to `4`.

    Returns:
        Array of Squareplus values.
    """
    tensor = convert_to_tensor(x)
    offset = convert_to_tensor(b, dtype=tensor.dtype)
    root = np.sqrt(tensor**2 + offset)
    total = tensor + root
    return total / 2
|
|
|
|
|
|
|
|
def log_sigmoid(x):
    """Compute `log(sigmoid(x))` element-wise as `-softplus(-x)`.

    Args:
        x: Input; converted with `convert_to_tensor`.

    Returns:
        Array of log-sigmoid values.
    """
    negated = -convert_to_tensor(x)
    return -softplus(negated)
|
|
|
|
|
|
|
|
def leaky_relu(x, negative_slope=0.2):
    """Apply the leaky rectified linear unit element-wise.

    Returns `x` where `x >= 0` and `negative_slope * x` elsewhere.

    Args:
        x: Input; converted with `convert_to_tensor`.
        negative_slope: Multiplier applied to negative inputs.
            Defaults to `0.2`.

    Returns:
        Array of activation values in the dtype of the converted input.
    """
    x = convert_to_tensor(x)
    slope = np.array(negative_slope, x.dtype)
    # Select on the sign of `x` explicitly: `maximum(x, slope * x)` only
    # matches this definition when the slope lies in [0, 1] and picks the
    # wrong branch for slopes above 1 or below 0.
    return np.where(x >= 0, x, slope * x)
|
|
|
|
|
|
|
|
def hard_sigmoid(x):
    """Compute the hard sigmoid, `clip(x / 6 + 0.5, 0, 1)`, element-wise.

    Args:
        x: Input; converted with `convert_to_tensor` so that Python
            sequences and scalars are accepted like in the sibling
            activations (the function reads `x.dtype`).

    Returns:
        Array with values in `[0, 1]`.
    """
    x = convert_to_tensor(x)
    # The constants are created in the dtype of `x` so that the arithmetic
    # is not promoted by Python float literals.
    x = x / np.array(6.0, x.dtype) + np.array(0.5, x.dtype)
    return np.where(
        x <= 0.0,
        np.array(0.0, x.dtype),
        np.where(x >= 1.0, np.array(1.0, x.dtype), x),
    )
|
|
|
|
|
|
|
|
def hard_silu(x):
    """Compute hard SiLU, `x * hard_sigmoid(x)`, element-wise.

    Args:
        x: Input; converted with `convert_to_tensor` so that Python
            sequences and scalars are accepted like in the sibling
            activations.

    Returns:
        Array of hard-SiLU values.
    """
    x = convert_to_tensor(x)
    return x * hard_sigmoid(x)
|
|
|
|
|
|
|
|
def elu(x, alpha=1.0):
    """Apply the exponential linear unit element-wise.

    Returns `x` where `x >= 0` and `alpha * (exp(x) - 1)` elsewhere.

    Args:
        x: Input; converted with `convert_to_tensor`.
        alpha: Scale of the negative branch. Defaults to `1.0`.

    Returns:
        Array of ELU values.
    """
    tensor = convert_to_tensor(x)
    dtype = tensor.dtype
    negative_branch = np.array(alpha, dtype) * np.expm1(tensor)
    is_non_negative = tensor >= np.array(0.0, dtype)
    return np.where(is_non_negative, tensor, negative_branch)
|
|
|
|
|
|
|
|
def selu(
    x,
    alpha=1.6732632423543772848170429916717,
    scale=1.0507009873554804934193349852946,
):
    """Apply the scaled exponential linear unit, `scale * elu(x, alpha)`.

    Args:
        x: Input; converted with `convert_to_tensor`.
        alpha: Forwarded to `elu` as the negative-branch scale.
        scale: Multiplier applied to the ELU output.

    Returns:
        Array of SELU values.
    """
    tensor = convert_to_tensor(x)
    factor = np.array(scale, tensor.dtype)
    return factor * elu(tensor, alpha)
|
|
|
|
|
|
|
|
def gelu(x, approximate=True):
    """Apply the Gaussian error linear unit element-wise.

    Args:
        x: Input; converted with `convert_to_tensor`.
        approximate: If truthy, use the tanh-based approximation;
            otherwise use the `erf`-based form. Defaults to `True`.

    Returns:
        Array of GELU values.
    """
    tensor = convert_to_tensor(x)
    dtype = tensor.dtype

    if not approximate:
        # Form based on the error function.
        root_two = np.sqrt(2).astype(dtype)
        one_plus_erf = (scipy.special.erf(tensor / root_two) + 1).astype(dtype)
        return tensor * one_plus_erf / np.array(2, dtype)

    # Tanh approximation; every constant is kept in the input dtype.
    sqrt_2_over_pi = np.sqrt(2 / np.pi).astype(dtype)
    cubic_term = np.array(0.044715, dtype) * (tensor**3).astype(dtype)
    tanh_term = np.tanh(sqrt_2_over_pi * (tensor + cubic_term))
    cdf = np.array(0.5, dtype) * (np.array(1.0, dtype) + tanh_term)
    return tensor * cdf
|
|
|
|
|
|
|
|
def celu(x, alpha=1.0):
    """Apply the continuously differentiable ELU element-wise.

    Computes `max(x, 0) + alpha * (exp(min(x, 0) / alpha) - 1)`.

    Args:
        x: Input; converted with `convert_to_tensor`.
        alpha: Scale of the negative branch. Defaults to `1.0`.

    Returns:
        Array of CELU values.
    """
    tensor = convert_to_tensor(x)
    alpha = np.array(alpha, tensor.dtype)
    zero = np.array(0.0, dtype=tensor.dtype)
    positive_part = np.maximum(tensor, zero)
    negative_part = alpha * np.expm1(np.minimum(tensor, zero) / alpha)
    return positive_part + negative_part
|
|
|
|
|
|
|
|
def glu(x, axis=-1):
    """Apply the gated linear unit along `axis`.

    The input is split into two equal halves along `axis`; the first
    half is multiplied by the sigmoid of the second half.

    Args:
        x: Input; converted with `convert_to_tensor`.
        axis: Axis to split along. Defaults to `-1`.

    Returns:
        Array whose size along `axis` is half that of the input.

    Raises:
        ValueError: If the size of `x` along `axis` is odd.
    """
    tensor = convert_to_tensor(x)
    if tensor.shape[axis] % 2 != 0:
        raise ValueError(
            "axis size must be divisible by 2. "
            f"Received: x.shape={tensor.shape} with axis={axis}"
        )
    values, gates = np.split(tensor, 2, axis)
    gate_activation = 1 / (1 + np.exp(-gates))
    return values * gate_activation
|
|
|
|
|
|
|
|
def hard_tanh(x):
    """Clip the input to the range `[-1, 1]` element-wise.

    Args:
        x: Input; converted with `convert_to_tensor`.

    Returns:
        Array of clipped values in the dtype of the converted input.
    """
    tensor = convert_to_tensor(x)
    dtype = tensor.dtype
    lower = np.asarray(-1.0, dtype)
    upper = np.asarray(1.0, dtype)
    clipped = np.clip(tensor, lower, upper)
    return np.array(clipped, dtype=dtype)
|
|
|
|
|
|
|
|
def hard_shrink(x, threshold=0.5):
    """Zero out values whose magnitude does not exceed `threshold`.

    Args:
        x: Input; converted with `convert_to_tensor`.
        threshold: Magnitude cutoff, cast to the dtype of `x`.
            Defaults to `0.5`.

    Returns:
        Array equal to `x` where `|x| > threshold` and zero elsewhere.
    """
    tensor = convert_to_tensor(x)
    dtype = tensor.dtype
    cutoff = np.asarray(threshold, dtype)
    zero = np.array(0.0, dtype=dtype)
    kept = np.where(np.abs(tensor) > cutoff, tensor, zero)
    return np.array(kept, dtype=dtype)
|
|
|
|
|
|
|
|
def threshold(x, threshold, default_value):
    """Replace values that do not exceed `threshold` by `default_value`.

    Args:
        x: Input; converted with `convert_to_tensor`.
        threshold: Cutoff; values strictly greater than it are kept.
        default_value: Replacement value, cast to the dtype of `x`.

    Returns:
        Array with the same shape as the converted input.
    """
    tensor = convert_to_tensor(x)
    replacement = np.array(default_value, dtype=tensor.dtype)
    keep = tensor > threshold
    return np.where(keep, tensor, replacement)
|
|
|
|
|
|
|
|
def softmax(x, axis=None):
    """Compute the softmax of `x` along `axis`.

    The maximum along `axis` is subtracted before exponentiating.

    Args:
        x: Input values.
        axis: Axis (or axes) to normalize over; `None` normalizes over
            all elements. Defaults to `None`.

    Returns:
        Array of the same shape as `x` whose entries sum to one along
        `axis`.
    """
    shifted = x - np.max(x, axis=axis, keepdims=True)
    numerators = np.exp(shifted)
    normalizer = np.sum(numerators, axis=axis, keepdims=True)
    return numerators / normalizer
|
|
|
|
|
|
|
|
def log_softmax(x, axis=None):
    """Compute the logarithm of the softmax of `x` along `axis`.

    The maximum along `axis` is subtracted before exponentiating.

    Args:
        x: Input values.
        axis: Axis (or axes) to normalize over; `None` normalizes over
            all elements. Defaults to `None`.

    Returns:
        Array of log-probabilities with the same shape as `x`.
    """
    peak = np.max(x, axis=axis, keepdims=True)
    shifted = x - peak
    log_normalizer = np.log(np.sum(np.exp(shifted), axis=axis, keepdims=True))
    return x - peak - log_normalizer
|
|
|
|
|
|
|
|
def sparsemax(logits, axis=-1):
    """Compute the sparsemax projection of `logits` along `axis`.

    Sparsemax maps the logits to a probability distribution that may
    contain exact zeros: `max(logits - tau, 0)`, where the threshold
    `tau` is chosen so that the result sums to one along `axis`.

    Args:
        logits: Input; converted with `convert_to_tensor`.
        axis: Axis to normalize over. Defaults to `-1`.

    Returns:
        Array of the same shape as `logits`.
    """
    logits = convert_to_tensor(logits)
    # Sort in descending order along `axis`.
    logits_sorted = -1.0 * np.sort(-1.0 * logits, axis=axis)
    logits_cumsum = np.cumsum(logits_sorted, axis=axis)
    # 1-based rank of each sorted entry, shaped to broadcast along `axis`.
    r = np.arange(1, logits.shape[axis] + 1)
    r_shape = [1] * logits.ndim
    r_shape[axis] = -1
    r = r.reshape(r_shape)
    # Sorted entries that belong to the support of the projection.
    support = logits_sorted - (logits_cumsum - 1) / r > 0
    k = np.sum(support, axis=axis, keepdims=True)
    # tau = (sum of the k largest logits - 1) / k. The sum must run over
    # the sorted logits themselves; summing the cumulative sums
    # over-counts whenever k > 1 and the output then no longer sums to 1.
    support_sum = np.sum(
        np.where(support, logits_sorted, 0.0), axis=axis, keepdims=True
    )
    tau = (support_sum - 1) / k
    return np.maximum(logits - tau, 0.0)
|
|
|
|
|
|
|
|
def _convert_to_spatial_operand( |
|
|
x, |
|
|
num_spatial_dims, |
|
|
data_format="channels_last", |
|
|
include_batch_and_channels=True, |
|
|
): |
|
|
|
|
|
x = (x,) * num_spatial_dims if isinstance(x, int) else x |
|
|
if not include_batch_and_channels: |
|
|
return x |
|
|
if data_format == "channels_last": |
|
|
x = (1,) + x + (1,) |
|
|
else: |
|
|
x = (1,) + (1,) + x |
|
|
return x |
|
|
|
|
|
|
|
|
def _pool( |
|
|
inputs, |
|
|
initial_value, |
|
|
reduce_fn, |
|
|
pool_size, |
|
|
strides=None, |
|
|
padding="valid", |
|
|
): |
|
|
"""Helper function to define pooling functions. |
|
|
|
|
|
Args: |
|
|
inputs: input data of shape `N+2`. |
|
|
initial_value: the initial value for the reduction. |
|
|
reduce_fn: a reduce function of the form `(T, T) -> T`. |
|
|
pool_size: a sequence of `N` integers, representing the window size to |
|
|
reduce over. |
|
|
strides: a sequence of `N` integers, representing the inter-window |
|
|
strides (default: `(1, ..., 1)`). |
|
|
padding: either the string `same` or `valid`. |
|
|
|
|
|
Returns: |
|
|
The output of the reduction for each window slice. |
|
|
""" |
|
|
if padding not in ("same", "valid"): |
|
|
raise ValueError( |
|
|
f"Invalid padding '{padding}', must be 'same' or 'valid'." |
|
|
) |
|
|
padding = padding.upper() |
|
|
return np.array( |
|
|
lax.reduce_window( |
|
|
inputs, |
|
|
initial_value, |
|
|
reduce_fn, |
|
|
pool_size, |
|
|
strides, |
|
|
padding, |
|
|
) |
|
|
) |
|
|
|
|
|
|
|
|
def max_pool(
    inputs,
    pool_size,
    strides=None,
    padding="valid",
    data_format=None,
):
    """Max-pool `inputs` over its spatial dimensions.

    Args:
        inputs: Input array; `inputs.ndim - 2` is taken as the number of
            spatial dimensions.
        pool_size: `int` or tuple of ints giving the spatial window size.
        strides: `int` or tuple of ints giving the spatial strides. When
            `None`, the stride equals `pool_size`.
        padding: `"valid"` or `"same"`.
        data_format: Passed through `backend.standardize_data_format`.

    Returns:
        NumPy array with the pooled values.
    """
    data_format = backend.standardize_data_format(data_format)
    num_spatial_dims = inputs.ndim - 2
    # Resolve the default stride before `pool_size` receives its batch
    # and channel entries; otherwise those entries would be added to the
    # strides twice.
    if strides is None:
        strides = pool_size
    pool_size = _convert_to_spatial_operand(
        pool_size, num_spatial_dims, data_format
    )
    strides = _convert_to_spatial_operand(
        strides, num_spatial_dims, data_format
    )
    return _pool(inputs, -np.inf, lax.max, pool_size, strides, padding)
|
|
|
|
|
|
|
|
def average_pool(
    inputs,
    pool_size,
    strides,
    padding,
    data_format=None,
):
    """Average-pool `inputs` over its spatial dimensions.

    Args:
        inputs: Input array; `inputs.ndim - 2` is taken as the number of
            spatial dimensions.
        pool_size: `int` or tuple of ints giving the spatial window size.
        strides: `int` or tuple of ints giving the spatial strides. When
            `None`, the stride equals `pool_size`.
        padding: `"valid"` or `"same"`.
        data_format: Passed through `backend.standardize_data_format`.

    Returns:
        NumPy array with the window averages. With `"same"` padding each
        window sum is divided by the number of input elements the window
        actually covers.
    """
    data_format = backend.standardize_data_format(data_format)
    num_spatial_dims = inputs.ndim - 2
    # Resolve the default stride before `pool_size` receives its batch
    # and channel entries; otherwise those entries would be added to the
    # strides twice.
    if strides is None:
        strides = pool_size
    pool_size = _convert_to_spatial_operand(
        pool_size, num_spatial_dims, data_format
    )
    strides = _convert_to_spatial_operand(
        strides, num_spatial_dims, data_format
    )

    pooled = _pool(inputs, 0.0, lax.add, pool_size, strides, padding)
    if padding == "valid":
        # Every window is full, so divide by the window volume.
        return pooled / np.prod(pool_size)

    # With "same" padding, windows at the border cover fewer elements.
    # Pool an array of ones to count the covered elements per window;
    # dimensions with a window size of 1 are collapsed to size 1 and
    # broadcast in the division.
    shape = [(a if b != 1 else 1) for (a, b) in zip(inputs.shape, pool_size)]
    window_counts = _pool(
        np.ones(shape, inputs.dtype),
        0.0,
        lax.add,
        pool_size,
        strides,
        padding,
    )
    return pooled / window_counts
|
|
|
|
|
|
|
|
def _convert_to_lax_conv_dimension_numbers( |
|
|
num_spatial_dims, |
|
|
data_format="channels_last", |
|
|
transpose=False, |
|
|
): |
|
|
"""Create a `lax.ConvDimensionNumbers` for the given inputs.""" |
|
|
num_dims = num_spatial_dims + 2 |
|
|
|
|
|
if data_format == "channels_last": |
|
|
spatial_dims = tuple(range(1, num_dims - 1)) |
|
|
inputs_dn = (0, num_dims - 1) + spatial_dims |
|
|
else: |
|
|
spatial_dims = tuple(range(2, num_dims)) |
|
|
inputs_dn = (0, 1) + spatial_dims |
|
|
|
|
|
if transpose: |
|
|
kernel_dn = (num_dims - 2, num_dims - 1) + tuple(range(num_dims - 2)) |
|
|
else: |
|
|
kernel_dn = (num_dims - 1, num_dims - 2) + tuple(range(num_dims - 2)) |
|
|
|
|
|
return lax.ConvDimensionNumbers( |
|
|
lhs_spec=inputs_dn, rhs_spec=kernel_dn, out_spec=inputs_dn |
|
|
) |
|
|
|
|
|
|
|
|
def conv(
    inputs,
    kernel,
    strides=1,
    padding="valid",
    data_format=None,
    dilation_rate=1,
):
    """Run an N-D convolution through `jax.lax.conv_general_dilated`.

    The feature group count is the number of input channels divided by
    the size of the kernel's second-to-last axis.

    Args:
        inputs: Input array; `inputs.ndim - 2` is taken as the number of
            spatial dimensions.
        kernel: Kernel; `kernel.numpy()` is called when `is_tensor`
            reports that it is not already a tensor.
        strides: `int` or tuple of ints with the spatial strides.
        padding: Padding argument forwarded to the lax call.
        data_format: Passed through `backend.standardize_data_format`.
        dilation_rate: `int` or tuple of ints with the kernel dilation.

    Returns:
        NumPy array with the convolution output.

    Raises:
        ValueError: If the number of input channels is not a multiple of
            the kernel's input channels.
    """
    data_format = backend.standardize_data_format(data_format)
    num_spatial_dims = inputs.ndim - 2

    channel_axis = -1 if data_format == "channels_last" else 1
    channels = inputs.shape[channel_axis]
    kernel_in_channels = kernel.shape[-2]
    if channels % kernel_in_channels > 0:
        raise ValueError(
            "The number of input channels must be evenly divisible by "
            f"kernel's in_channels. Received input channels {channels} and "
            f"kernel in_channels {kernel_in_channels}. "
        )
    feature_group_count = channels // kernel_in_channels

    def _spatial_only(value):
        # Expand to one entry per spatial dimension, without batch/channel.
        return _convert_to_spatial_operand(
            value,
            num_spatial_dims,
            data_format,
            include_batch_and_channels=False,
        )

    dimension_numbers = _convert_to_lax_conv_dimension_numbers(
        num_spatial_dims, data_format, transpose=False
    )
    kernel_array = kernel if is_tensor(kernel) else kernel.numpy()
    result = jax.lax.conv_general_dilated(
        inputs,
        kernel_array,
        _spatial_only(strides),
        padding,
        rhs_dilation=_spatial_only(dilation_rate),
        dimension_numbers=dimension_numbers,
        feature_group_count=feature_group_count,
    )
    return np.array(result)
|
|
|
|
|
|
|
|
def depthwise_conv(
    inputs,
    kernel,
    strides=1,
    padding="valid",
    data_format=None,
    dilation_rate=1,
):
    """Run a depthwise convolution through `jax.lax.conv_general_dilated`.

    The kernel's last two axes are reshaped to
    `(1, input_channels * kernel.shape[-1])` and the convolution is
    executed with one feature group per input channel.

    Args:
        inputs: Input array; `inputs.ndim - 2` is taken as the number of
            spatial dimensions.
        kernel: Depthwise kernel; `kernel.numpy()` is called when
            `is_tensor` reports that it is not already a tensor.
        strides: `int` or tuple of ints with the spatial strides.
        padding: Padding argument forwarded to the lax call.
        data_format: Passed through `backend.standardize_data_format`.
        dilation_rate: `int` or tuple of ints with the kernel dilation.

    Returns:
        NumPy array with the convolution output.
    """
    data_format = backend.standardize_data_format(data_format)
    num_spatial_dims = inputs.ndim - 2

    def _spatial_only(value):
        # Expand to one entry per spatial dimension, without batch/channel.
        return _convert_to_spatial_operand(
            value,
            num_spatial_dims,
            data_format,
            include_batch_and_channels=False,
        )

    if data_format == "channels_last":
        feature_group_count = inputs.shape[-1]
    else:
        feature_group_count = inputs.shape[1]

    grouped_shape = kernel.shape[:-2] + (
        1,
        feature_group_count * kernel.shape[-1],
    )
    kernel_array = kernel if is_tensor(kernel) else kernel.numpy()
    kernel = np.reshape(kernel_array, grouped_shape)

    dimension_numbers = _convert_to_lax_conv_dimension_numbers(
        num_spatial_dims, data_format, transpose=False
    )
    result = jax.lax.conv_general_dilated(
        inputs,
        kernel,
        _spatial_only(strides),
        padding,
        rhs_dilation=_spatial_only(dilation_rate),
        dimension_numbers=dimension_numbers,
        feature_group_count=feature_group_count,
    )
    return np.array(result)
|
|
|
|
|
|
|
|
def separable_conv(
    inputs,
    depthwise_kernel,
    pointwise_kernel,
    strides=1,
    padding="valid",
    data_format=None,
    dilation_rate=1,
):
    """Run a depthwise convolution followed by a pointwise convolution.

    Args:
        inputs: Input array.
        depthwise_kernel: Kernel passed to `depthwise_conv`.
        pointwise_kernel: Kernel passed to the subsequent `conv` call.
        strides: Strides of the depthwise step; the pointwise step
            always uses a stride of 1.
        padding: Padding of the depthwise step; the pointwise step
            always uses `"valid"`.
        data_format: Passed through `backend.standardize_data_format`.
        dilation_rate: Dilation rate forwarded to both steps.

    Returns:
        NumPy array returned by the pointwise `conv`.
    """
    data_format = backend.standardize_data_format(data_format)
    depthwise_output = depthwise_conv(
        inputs,
        depthwise_kernel,
        strides=strides,
        padding=padding,
        data_format=data_format,
        dilation_rate=dilation_rate,
    )
    return conv(
        depthwise_output,
        pointwise_kernel,
        strides=1,
        padding="valid",
        data_format=data_format,
        dilation_rate=dilation_rate,
    )
|
|
|
|
|
|
|
|
def conv_transpose(
    inputs,
    kernel,
    strides=1,
    padding="valid",
    output_padding=None,
    data_format=None,
    dilation_rate=1,
):
    """Run a transposed convolution through `jax.lax.conv_transpose`.

    Args:
        inputs: Input array; `inputs.ndim - 2` is taken as the number of
            spatial dimensions.
        kernel: Kernel; `kernel.numpy()` is called when `is_tensor`
            reports that it is not already a tensor.
        strides: `int` or tuple of ints with the spatial strides.
        padding: Padding mode, handed to
            `compute_conv_transpose_padding_args_for_jax`.
        output_padding: Handed to the same padding helper.
        data_format: Passed through `backend.standardize_data_format`.
        dilation_rate: `int` or tuple of ints with the kernel dilation.

    Returns:
        NumPy array with the transposed-convolution output.
    """
    data_format = backend.standardize_data_format(data_format)
    num_spatial_dims = inputs.ndim - 2

    # The padding helper receives `strides` and `dilation_rate` exactly
    # as passed in, before they are expanded below.
    padding_values = compute_conv_transpose_padding_args_for_jax(
        input_shape=inputs.shape,
        kernel_shape=kernel.shape,
        strides=strides,
        padding=padding,
        output_padding=output_padding,
        dilation_rate=dilation_rate,
    )

    def _spatial_only(value):
        # Expand to one entry per spatial dimension, without batch/channel.
        return _convert_to_spatial_operand(
            value,
            num_spatial_dims,
            data_format,
            include_batch_and_channels=False,
        )

    dimension_numbers = _convert_to_lax_conv_dimension_numbers(
        num_spatial_dims, data_format, transpose=False
    )
    kernel_array = kernel if is_tensor(kernel) else kernel.numpy()
    result = jax.lax.conv_transpose(
        inputs,
        kernel_array,
        _spatial_only(strides),
        padding=padding_values,
        rhs_dilation=_spatial_only(dilation_rate),
        dimension_numbers=dimension_numbers,
        transpose_kernel=True,
    )
    return np.array(result)
|
|
|
|
|
|
|
|
def one_hot(x, num_classes, axis=-1, dtype="float32", sparse=False):
    """One-hot encode an array of class indices.

    Negative indices produce an all-zero row.

    Args:
        x: Class indices; converted with `convert_to_tensor`.
        num_classes: Size of the class axis. When falsy, `max(x) + 1`
            is used.
        axis: Position of the class axis in the output. Defaults to `-1`.
        dtype: Output dtype. Defaults to `"float32"`.
        sparse: Must be falsy; sparse output is not supported.

    Returns:
        Array of shape `x.shape + (num_classes,)`, with the class axis
        moved to `axis` when `axis != -1`.

    Raises:
        ValueError: If `sparse` is truthy.
    """
    if sparse:
        raise ValueError("Unsupported value `sparse=True` with numpy backend")

    indices = convert_to_tensor(x)
    original_shape = indices.shape
    flat = indices.reshape(-1)
    if not num_classes:
        num_classes = np.max(flat) + 1

    count = flat.shape[0]
    encoded = np.zeros((count, num_classes), dtype=dtype)
    # Only non-negative indices get a 1; negative ones stay all-zero.
    rows = np.flatnonzero(flat >= 0)
    encoded[rows, flat[rows]] = 1

    # Restore the leading shape of the input.
    encoded = np.reshape(encoded, original_shape + (num_classes,))
    if axis != -1:
        encoded = np.moveaxis(encoded, -1, axis)
    return encoded
|
|
|
|
|
|
|
|
def multi_hot(x, num_classes, axis=-1, dtype="float32", sparse=False):
    """Multi-hot encode an array of class indices.

    The indices are one-hot encoded and then reduced with `max` over
    axis 1 (or axis 0 when the input has at most one dimension).

    Args:
        x: Class indices; converted with `convert_to_tensor` and cast
            to `int32`.
        num_classes: Forwarded to `one_hot`.
        axis: Forwarded to `one_hot`. Defaults to `-1`.
        dtype: Output dtype. Defaults to `"float32"`.
        sparse: Must be falsy; sparse output is not supported.

    Returns:
        Array with the multi-hot encoding.

    Raises:
        ValueError: If `sparse` is truthy.
    """
    if sparse:
        raise ValueError("Unsupported value `sparse=True` with numpy backend")
    indices = convert_to_tensor(x)
    reduction_axis = 0 if len(indices.shape) <= 1 else 1
    encoded = one_hot(
        cast(indices, "int32"), num_classes, axis=axis, dtype=dtype
    )
    return np.max(encoded, axis=reduction_axis)
|
|
|
|
|
|
|
|
def categorical_crossentropy(target, output, from_logits=False, axis=-1):
    """Compute the categorical cross-entropy between `target` and `output`.

    Args:
        target: Target distribution with the same shape as `output`.
        output: Logits if `from_logits` is truthy. Otherwise values that
            are normalized to sum to one along `axis` and clipped to
            `[epsilon, 1 - epsilon]` before the logarithm is taken.
        from_logits: Whether `output` holds logits. Defaults to `False`.
        axis: Class axis. Defaults to `-1`.

    Returns:
        Array of losses with `axis` reduced.

    Raises:
        ValueError: If the shapes differ or the inputs are scalars.
    """
    target = np.array(target)
    output = np.array(output)

    if target.shape != output.shape:
        raise ValueError(
            "Arguments `target` and `output` must have the same shape. "
            "Received: "
            f"target.shape={target.shape}, output.shape={output.shape}"
        )
    if len(target.shape) < 1:
        raise ValueError(
            "Arguments `target` and `output` must be at least rank 1. "
            "Received: "
            f"target.shape={target.shape}, output.shape={output.shape}"
        )

    if from_logits:
        log_prob = log_softmax(output, axis=axis)
    else:
        eps = backend.epsilon()
        normalized = output / np.sum(output, axis, keepdims=True)
        log_prob = np.log(np.clip(normalized, eps, 1.0 - eps))
    weighted = target * log_prob
    return -np.sum(weighted, axis=axis)
|
|
|
|
|
|
|
|
def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
    """Compute categorical cross-entropy with integer class targets.

    Args:
        target: Integer class indices, cast to `int32`. A trailing axis
            of size 1 is squeezed when `target` has the same rank as
            `output`.
        output: Logits if `from_logits` is truthy. Otherwise values that
            are normalized to sum to one along `axis` and clipped to
            `[epsilon, 1 - epsilon]` before the logarithm is taken.
        from_logits: Whether `output` holds logits. Defaults to `False`.
        axis: Class axis. Defaults to `-1`.

    Returns:
        Array of losses with `axis` reduced.

    Raises:
        ValueError: If `output` is a scalar, or if `target.shape` does
            not equal `output.shape[:-1]`.
    """
    target = np.array(target, dtype="int32")
    output = np.array(output)
    same_rank = len(target.shape) == len(output.shape)
    if same_rank and target.shape[-1] == 1:
        target = np.squeeze(target, axis=-1)

    if len(output.shape) < 1:
        raise ValueError(
            "Argument `output` must be at least rank 1. "
            "Received: "
            f"output.shape={output.shape}"
        )
    if target.shape != output.shape[:-1]:
        raise ValueError(
            "Arguments `target` and `output` must have the same shape "
            "up until the last dimension: "
            f"target.shape={target.shape}, output.shape={output.shape}"
        )

    if from_logits:
        log_prob = log_softmax(output, axis=axis)
    else:
        eps = backend.epsilon()
        output = output / np.sum(output, axis, keepdims=True)
        output = np.clip(output, eps, 1.0 - eps)
        log_prob = np.log(output)

    # Expand the integer targets so they select one log-probability each.
    encoded_target = one_hot(target, output.shape[axis], axis=axis)
    return -np.sum(encoded_target * log_prob, axis=axis)
|
|
|
|
|
|
|
|
def binary_crossentropy(target, output, from_logits=False):
    """Compute the element-wise binary cross-entropy.

    Args:
        target: Targets with the same shape as `output`.
        output: Probabilities, or logits if `from_logits` is truthy (in
            which case `sigmoid` is applied first). Values are clipped to
            `[epsilon, 1 - epsilon]` before the logarithm is taken.
        from_logits: Whether `output` holds logits. Defaults to `False`.

    Returns:
        Array of per-element losses with the shape of the inputs.

    Raises:
        ValueError: If `target` and `output` have different shapes.
    """
    target = np.array(target)
    output = np.array(output)

    if target.shape != output.shape:
        raise ValueError(
            "Arguments `target` and `output` must have the same shape. "
            "Received: "
            f"target.shape={target.shape}, output.shape={output.shape}"
        )

    if from_logits:
        output = sigmoid(output)

    eps = backend.epsilon()
    output = np.clip(output, eps, 1.0 - eps)
    positive_term = target * np.log(output)
    negative_term = (1.0 - target) * np.log(1.0 - output)
    return -(positive_term + negative_term)
|
|
|
|
|
|
|
|
def moments(x, axes, keepdims=False, synchronized=False):
    """Compute the mean and variance of `x` over `axes`.

    The variance is computed as `mean(x**2) - mean(x)**2`. Inputs of
    dtype `float16` are processed in `float32`, clipped to the `float16`
    range and cast back.

    Args:
        x: Input array.
        axes: Axis or axes to reduce over; a list is converted to a tuple.
        keepdims: Whether to keep the reduced axes with size 1.
            Defaults to `False`.
        synchronized: Must be falsy; not supported here.

    Returns:
        Tuple `(mean, variance)`.

    Raises:
        NotImplementedError: If `synchronized` is truthy.
    """
    if synchronized:
        raise NotImplementedError(
            "Argument synchronized=True is not supported with NumPy."
        )
    if isinstance(axes, list):
        axes = tuple(axes)

    ori_dtype = backend.standardize_dtype(x.dtype)
    need_cast = ori_dtype == "float16"
    if need_cast:
        x = cast(x, "float32")

    mean = np.mean(x, axes, keepdims=True)
    mean_of_squares = np.mean(np.square(x), axis=axes, keepdims=True)
    variance = mean_of_squares - np.square(mean)

    if not keepdims:
        mean = np.squeeze(mean, axes)
        variance = np.squeeze(variance, axes)

    if need_cast:
        # Clip to the float16 range before casting back down.
        half = np.finfo(np.float16)
        mean = cast(np.clip(mean, half.min, half.max), ori_dtype)
        variance = cast(np.clip(variance, half.min, half.max), ori_dtype)
    return mean, variance
|
|
|
|
|
|
|
|
def batch_normalization(
    x, mean, variance, axis, offset=None, scale=None, epsilon=1e-3
):
    """Normalize `x` with the given mean and variance along `axis`.

    Computes `(x - mean) / sqrt(variance + epsilon) * scale + offset`,
    with `mean`, `variance`, `scale` and `offset` reshaped to broadcast
    along `axis`.

    Args:
        x: Input array.
        mean: Mean values; `mean.shape[0]` sets the size of `axis` in
            the broadcast shape.
        variance: Variance values.
        axis: Axis of `x` that the statistics belong to.
        offset: Optional additive term.
        scale: Optional multiplicative term.
        epsilon: Constant added to the variance. Defaults to `1e-3`.

    Returns:
        The normalized array.
    """
    broadcast_shape = [1] * len(x.shape)
    broadcast_shape[axis] = mean.shape[0]
    mean = np.reshape(mean, broadcast_shape)
    variance = np.reshape(variance, broadcast_shape)

    multiplier = 1.0 / np.sqrt(variance + epsilon)
    if scale is not None:
        multiplier = multiplier * np.reshape(scale, broadcast_shape)

    # Fold the mean subtraction into a single additive term.
    shift = -mean * multiplier
    if offset is not None:
        shift = shift + np.reshape(offset, broadcast_shape)

    return x * multiplier + shift
|
|
|
|
|
|
|
|
def ctc_loss(target, output, target_length, output_length, mask_index=0):
    """Compute a per-sequence CTC loss from unnormalized class scores.

    The body runs a forward recursion in log space over the time axis of
    `output`. Two score tables are carried per batch element: `phi`
    (driven by the log-probability of class `mask_index`) and `emit`
    (driven by the log-probabilities of the target labels).

    Args:
        target: Integer labels, converted to `int32`; its shape is
            unpacked as `(batch_size, max_label_length)`.
        output: Class scores; its shape is unpacked as
            `(batch_size, max_input_length, num_classes)`. `log_softmax`
            is applied over the last axis inside this function.
        target_length: Number of valid labels per batch element.
        output_length: Number of valid time steps per batch element.
        mask_index: Class index whose log-probabilities feed the `phi`
            table (presumably the CTC blank class — confirm against
            callers). Defaults to `0`.

    Returns:
        1-D array with one loss value per batch element.
    """
    target = convert_to_tensor(target, dtype="int32")
    output = convert_to_tensor(output)
    target_length = convert_to_tensor(target_length, "int32")
    output_length = convert_to_tensor(output_length, "int32")
    batch_size, max_input_length, num_classes = output.shape
    batch_size, max_label_length = target.shape
    # Large negative constant used both as the initial log-score of the
    # tables below and as an additive penalty inside the recursion.
    log_epsilon = -1e5

    # Promote the scores to at least float32 before doing any math.
    dtype = backend.result_type(output.dtype, "float32")
    output = output.astype(dtype)

    def _lengths_to_paddings(lengths, max_length):
        # Boolean mask that is True at positions >= the given length.
        indices = np.arange(max_length).reshape(
            (1,) * lengths.ndim + (max_length,)
        )
        lengths = np.expand_dims(lengths, axis=-1)
        elem_valid = indices < lengths
        return np.logical_not(elem_valid)

    target_paddings = _lengths_to_paddings(target_length, max_label_length)
    output_paddings = _lengths_to_paddings(output_length, max_input_length)
    # Turn the masks into 0/1 values in the score dtype so they can be
    # used as multiplicative weights.
    target_paddings = target_paddings.astype(output.dtype)
    output_paddings = output_paddings.astype(output.dtype)

    logprobs = log_softmax(output, axis=-1)
    # Number of non-padded labels per batch element.
    label_lengths = max_label_length - np.sum(target_paddings, axis=1).astype(
        np.int32
    )

    # repeat[b, n] is 1.0 when label n equals label n + 1; the last
    # column is zero-padded.
    repeat = (target[:, :-1] == target[:, 1:]).astype(np.float32)
    repeat = np.pad(repeat, ((0, 0), (0, 1)))

    # Log-probability of class `mask_index`, rearranged to put the time
    # axis first.
    logprobs_phi = logprobs[:, :, mask_index : mask_index + 1]
    logprobs_phi = np.transpose(logprobs_phi, (1, 0, 2))

    # Log-probability of each target label at each time step, also with
    # the time axis first.
    _one_hot = one_hot(target, num_classes=num_classes)
    logprobs_emit = np.einsum("btk,bnk->btn", logprobs, _one_hot)
    logprobs_emit = np.transpose(logprobs_emit, (1, 0, 2))

    # Initial tables: everything starts at `log_epsilon` except the
    # first `phi` column, which starts at log(1) = 0.
    logalpha_phi_init = (
        np.ones((batch_size, max_label_length + 1), dtype=output.dtype)
        * log_epsilon
    )
    logalpha_phi_init[:, 0] = 0.0
    logalpha_emit_init = (
        np.ones((batch_size, max_label_length), dtype=output.dtype)
        * log_epsilon
    )

    def update_phi_score(phi, added_score):
        # Keep phi[:, 0] as is and log-add `added_score` to phi[:, 1:].
        return np.concatenate(
            [phi[:, :1], np.logaddexp(phi[:, 1:], added_score)], axis=-1
        )

    def loop_body(prev, x):
        prev_phi, prev_emit = prev
        # Unmodified copy of the incoming phi table, restored on padded
        # time steps below.
        prev_phi_orig = prev_phi
        # Fold the previous emit scores into phi; positions flagged in
        # `repeat` receive the `log_epsilon` penalty.
        prev_phi = update_phi_score(prev_phi, prev_emit + log_epsilon * repeat)

        logprob_emit, logprob_phi, pad = x

        # Emit scores for this step, reached from the updated phi table
        # or from the previous emit table.
        next_emit = np.logaddexp(
            prev_phi[:, :-1] + logprob_emit, prev_emit + logprob_emit
        )
        # Phi scores for this step.
        next_phi = prev_phi + logprob_phi
        # Add the contribution of the previous emit scores; here the
        # penalty applies where `repeat` is 0.
        next_phi = update_phi_score(
            next_phi, prev_emit + logprob_phi + log_epsilon * (1.0 - repeat)
        )

        # Where `pad` is 1 the previous tables are carried over unchanged.
        pad = pad.reshape((batch_size, 1))
        next_emit = pad * prev_emit + (1.0 - pad) * next_emit
        next_phi = pad * prev_phi_orig + (1.0 - pad) * next_phi

        return (next_phi, next_emit), (next_phi, next_emit)

    def np_scan(f, init, xs):
        # Python-loop scan: iterate over the leading axis of every array
        # in `xs`, thread `carry` through `f`, and stack the per-step
        # outputs component-wise.
        carry = init
        ys = []
        for x in zip(*xs):
            carry, y = f(carry, x)
            ys.append(y)
        result = []
        for i in range(len(ys[0])):
            result.append(np.stack([y[i] for y in ys]))
        return carry, result

    xs = (logprobs_emit, logprobs_phi, output_paddings.transpose((1, 0)))
    _, (logalpha_phi, logalpha_emit) = np_scan(
        loop_body, (logalpha_phi_init, logalpha_emit_init), xs
    )

    # Fold the final emit scores into the final phi table.
    logalpha_phi_last = update_phi_score(logalpha_phi[-1], logalpha_emit[-1])
    # NOTE: `logalpha_phi` is not read again after this assignment.
    logalpha_phi[-1] = logalpha_phi_last

    # Select, per batch element, the final phi score at column
    # `label_lengths` and negate it.
    _one_hot = one_hot(label_lengths, num_classes=max_label_length + 1)
    per_seq_loss = -np.einsum("bn,bn->b", logalpha_phi_last, _one_hot)
    return per_seq_loss
|
|
|
|
|
|
|
|
def _ctc_greedy_decode(
    inputs,
    sequence_lengths,
    merge_repeated=True,
    mask_index=None,
):
    """Decode by taking the highest-scoring class at every time step.

    Args:
        inputs: Scores; the shape is unpacked as
            `(batch_size, max_length, num_classes)`.
        sequence_lengths: Number of valid time steps per batch element.
        merge_repeated: Whether consecutive equal classes are collapsed
            into one. Defaults to `True`.
        mask_index: Class index that is removed from the output.
            Defaults to `num_classes - 1` when `None`.

    Returns:
        Tuple `(indices, scores)`. `indices` has a leading axis of size
        1 and holds the kept classes moved to the front of each row,
        padded with `-1`. `scores` has shape `(batch_size, 1)` and holds
        the negated sum of the per-step maximum scores over the valid
        steps.
    """
    inputs = convert_to_tensor(inputs)
    sequence_lengths = convert_to_tensor(sequence_lengths, dtype="int32")
    batch_size, max_length, num_classes = inputs.shape

    if mask_index is None:
        mask_index = num_classes - 1

    best_classes = np.argmax(inputs, axis=-1).astype("int32")
    best_scores = np.max(inputs, axis=-1)

    # Steps at or beyond the sequence length are treated as masked.
    positions = np.arange(max_length)[None, :]
    beyond_length = positions >= sequence_lengths[:, None]
    best_classes = np.where(beyond_length, mask_index, best_classes)
    best_scores = np.where(beyond_length, 0.0, best_scores)

    if merge_repeated:
        # Mark every step that repeats the class of the step before it.
        is_repeat = best_classes[:, 1:] == best_classes[:, :-1]
        is_repeat = np.pad(is_repeat, ((0, 0), (1, 0)))
        best_classes = np.where(is_repeat, mask_index, best_classes)

    # Masked entries become -1.
    dropped = best_classes == mask_index
    best_classes = np.where(dropped, -1, best_classes)

    # Move the kept entries to the front of each row: dropped positions
    # get a sort key of `max_length`, kept ones keep their own position.
    sort_keys = np.tile(
        np.expand_dims(np.arange(max_length), axis=0), (batch_size, 1)
    )
    sort_keys = np.where(dropped, max_length, sort_keys)
    permutation = np.argsort(sort_keys, axis=-1)
    best_classes = np.take_along_axis(best_classes, permutation, axis=-1)

    total_scores = -np.sum(best_scores, axis=1)[:, None]
    return np.expand_dims(best_classes, axis=0), total_scores
|
|
|
|
|
|
|
|
def _ctc_beam_search_decode(
    inputs,
    sequence_lengths,
    beam_width=100,
    top_paths=1,
    mask_index=None,
):
    """Decode with a beam search over the time axis of `inputs`.

    Each batch element is decoded independently in a Python loop. The
    beam keeps `2 * beam_width` rows per element: the first half holds
    the scores of paths whose last step was a regular class, the second
    half the scores of the same paths whose last step was `mask_index`.

    Args:
        inputs: Scores; the shape is unpacked as
            `(batch_size, max_seq_len, num_classes)`. `log_softmax` is
            applied over the last axis inside this function.
        sequence_lengths: Number of valid time steps per batch element.
        beam_width: Number of distinct paths kept after each step.
            Defaults to `100`.
        top_paths: Number of best paths returned per batch element.
            Defaults to `1`.
        mask_index: Class index treated as the masked class. Defaults
            to `num_classes - 1` when `None`.

    Returns:
        Tuple `(paths, scores)`. `paths` is transposed so that the path
        axis comes first, with `-1` as the padding value; `scores` holds
        the matching log-scores per batch element.
    """
    inputs = convert_to_tensor(inputs)
    sequence_lengths = convert_to_tensor(sequence_lengths)

    batch_size, max_seq_len, num_classes = inputs.shape
    inputs = log_softmax(inputs, axis=-1)
    # True for time steps at or beyond the sequence length.
    seqlen_mask = np.arange(max_seq_len)[None, :] >= sequence_lengths[:, None]

    if mask_index is None:
        mask_index = num_classes - 1

    # Reverse the class axis and remap `mask_index` to match; the class
    # ids are mapped back just before returning.
    # NOTE(review): presumably this controls how `argsort` breaks ties
    # between equal scores — confirm.
    inputs = np.flip(inputs, axis=2)
    mask_index = num_classes - mask_index - 1

    # Padding value marking unused path slots.
    _pad = -1

    init_paths = np.full(
        (batch_size, 2 * beam_width, max_seq_len), _pad, dtype=np.int32
    )

    # Seed the beam with the best classes of the first time step; the
    # masked class is stored as padding.
    num_init_paths = np.min(np.array([num_classes, beam_width]))
    max_classes = np.argsort(inputs[:, 0], axis=1)[:, -num_init_paths:]
    init_classes = np.where(max_classes == mask_index, _pad, max_classes)
    init_paths[:, :num_init_paths, 0] = init_classes

    init_scores = np.full(
        (batch_size, 2 * beam_width), -np.inf, dtype=inputs.dtype
    )
    init_scores[:, :num_init_paths] = np.take_along_axis(
        inputs[:, 0], max_classes, axis=1
    )
    # A row counts as "masked" when its first entry is padding.
    init_masked = init_paths[:, :, 0] == _pad

    def _extend_paths(paths, scores, masked, x):
        # Create one candidate per (beam row, class) pair.
        paths = np.repeat(paths, num_classes, axis=0)
        scores = np.repeat(scores, num_classes)
        masked = np.repeat(masked, num_classes)

        # Index of the first padding slot of each path and the class
        # stored just before it (padding when the path is empty).
        path_tail_index = np.argmax(paths == _pad, axis=1)
        paths_arange = np.arange(2 * beam_width * num_classes)
        path_tails = paths[paths_arange, path_tail_index - 1]
        path_tails = np.where(path_tail_index == 0, _pad, path_tails)

        # Candidate class per row; the masked class is written as padding.
        classes = np.arange(num_classes)
        classes[mask_index] = _pad
        classes = np.tile(classes, 2 * beam_width)

        prev_masked = masked
        masked = classes == _pad

        # Repeating the tail class without a masked step in between
        # does not extend the path.
        masked_repeat = ~prev_masked & (path_tails == classes)
        classes = np.where(masked_repeat, _pad, classes)
        paths[paths_arange, path_tail_index] = classes

        # Add this step's log-probability of every candidate class.
        x = np.tile(x, 2 * beam_width)
        scores = scores + x

        return paths, scores, masked

    def _merge_scores(unique_inverse, scores):
        # Log-sum-exp of the scores that map to the same unique path;
        # the maximum is subtracted before exponentiating.
        scores_max = np.max(scores)
        scores_exp = np.exp(scores - scores_max)
        scores = np.zeros_like(scores)
        for i, u in enumerate(unique_inverse):
            scores[u] += scores_exp[i]
        scores = np.log(scores) + scores_max
        return scores

    def _prune_paths(paths, scores, masked):
        # Deduplicate the candidates, then pad or truncate to a fixed
        # number of rows.
        paths, unique_inverse = np.unique(paths, return_inverse=True, axis=0)
        pad_size = (2 * num_classes * beam_width) - len(paths)
        if pad_size > 0:
            paths = np.pad(paths, [[0, pad_size], [0, 0]], constant_values=_pad)
        paths = paths[: 2 * num_classes * beam_width]
        # Flatten the inverse index if `np.unique` returned it as 2-D.
        if len(unique_inverse.shape) >= 2:
            unique_inverse = np.squeeze(unique_inverse, axis=1)

        # Split the scores by whether the last step was masked, and
        # merge duplicates separately for each kind.
        emit_scores = np.where(masked, -np.inf, scores)
        mask_scores = np.where(masked, scores, -np.inf)

        emit_scores = _merge_scores(unique_inverse, emit_scores)
        mask_scores = _merge_scores(unique_inverse, mask_scores)

        # Keep the `beam_width` paths with the highest combined score.
        total_scores = np.logaddexp(emit_scores, mask_scores)
        top_indices = np.argsort(total_scores, kind="stable")[-beam_width:]

        paths = paths[top_indices]
        emit_scores = emit_scores[top_indices]
        mask_scores = mask_scores[top_indices]

        # Rebuild the two-halves layout: emit rows first, masked rows
        # second.
        paths = np.tile(paths, (2, 1))
        scores = np.concatenate([emit_scores, mask_scores])
        masked = np.concatenate(
            [np.zeros(beam_width, bool), np.ones(beam_width, bool)]
        )

        return paths, scores, masked

    def _decode_step(paths, scores, masked, x):
        paths, scores, masked = _extend_paths(paths, scores, masked, x)
        paths, scores, masked = _prune_paths(paths, scores, masked)
        return paths, scores, masked

    def _step(prev, x):
        paths, scores, masked = prev
        x, seqlen_mask = x
        # Time steps beyond the sequence length leave the beam untouched.
        if not seqlen_mask:
            paths, scores, masked = _decode_step(paths, scores, masked, x)
        return (paths, scores, masked), None

    def _decode_batch(
        init_paths, init_scores, init_masked, inputs, seqlen_mask
    ):
        def np_scan_only_carry(f, init, xs):
            # Python-loop scan that keeps only the final carry.
            carry = init
            for x in zip(*xs):
                carry, y = f(carry, x)
            return carry, None

        # The first time step was consumed when seeding the beam.
        (paths, scores, masked), _ = np_scan_only_carry(
            _step,
            (init_paths, init_scores, init_masked),
            (inputs[1:], seqlen_mask[1:]),
        )

        # Final merge of duplicate paths across both halves of the beam.
        paths, unique_inverse = np.unique(paths, return_inverse=True, axis=0)
        pad_size = (2 * num_classes * beam_width) - len(paths)
        if pad_size > 0:
            paths = np.pad(paths, [[0, pad_size], [0, 0]], constant_values=_pad)
        paths = paths[: 2 * num_classes * beam_width]
        if len(unique_inverse.shape) >= 2:
            unique_inverse = np.squeeze(unique_inverse, axis=1)
        scores = _merge_scores(unique_inverse, scores)

        # Best `top_paths` entries, highest score first.
        top_indices = np.argsort(scores)[-top_paths:][::-1]
        paths = paths[top_indices]
        scores = scores[top_indices]

        return paths, scores

    results = [
        _decode_batch(p, s, m, i, sm)
        for p, s, m, i, sm in zip(
            init_paths, init_scores, init_masked, inputs, seqlen_mask
        )
    ]
    paths = np.stack([r[0] for r in results])
    scores = np.stack([r[1] for r in results])

    # Undo the class-axis reversal applied above; padding stays as is.
    paths = np.where(paths == _pad, _pad, num_classes - paths - 1)
    paths = np.transpose(paths, [1, 0, 2])
    return paths, scores
|
|
|
|
|
|
|
|
def ctc_decode(
    inputs,
    sequence_lengths,
    strategy="greedy",
    beam_width=100,
    top_paths=1,
    merge_repeated=True,
    mask_index=0,
):
    """Decode CTC scores with the chosen strategy.

    Args:
        inputs: Scores; converted with `convert_to_tensor` and cast to
            the result type of their dtype and `float32`.
        sequence_lengths: Number of valid time steps per batch element.
        strategy: `"greedy"` or `"beam_search"`. Defaults to `"greedy"`.
        beam_width: Used by the beam search only. Defaults to `100`.
        top_paths: Used by the beam search only. Defaults to `1`.
        merge_repeated: Used by the greedy decoder only.
            Defaults to `True`.
        mask_index: Class index treated as the masked class.
            Defaults to `0`.

    Returns:
        The `(paths, scores)` tuple of the selected decoder.

    Raises:
        ValueError: If `strategy` is not one of the supported values.
    """
    inputs = convert_to_tensor(inputs)
    compute_dtype = backend.result_type(inputs.dtype, "float32")
    inputs = cast(inputs, compute_dtype)

    if strategy == "greedy":
        return _ctc_greedy_decode(
            inputs,
            sequence_lengths,
            merge_repeated=merge_repeated,
            mask_index=mask_index,
        )
    if strategy == "beam_search":
        return _ctc_beam_search_decode(
            inputs,
            sequence_lengths,
            beam_width=beam_width,
            top_paths=top_paths,
            mask_index=mask_index,
        )
    raise ValueError(
        f"Invalid strategy {strategy}. Supported values are "
        "'greedy' and 'beam_search'."
    )
|
|
|
|
|
|
|
|
def psnr(x1, x2, max_val):
    """Compute the peak signal-to-noise ratio between `x1` and `x2`.

    Computes `20 * log10(max_val) - 10 * log10(mse)`, where `mse` is the
    mean squared difference over all elements.

    Args:
        x1: First array.
        x2: Second array with the same shape as `x1`.
        max_val: Peak signal value; converted to the dtype of `x2`.

    Returns:
        The PSNR value.

    Raises:
        ValueError: If `x1` and `x2` have different shapes.
    """
    if x1.shape != x2.shape:
        raise ValueError(
            f"Input shapes {x1.shape} and {x2.shape} must "
            "match for PSNR calculation. "
        )

    peak = convert_to_tensor(max_val, dtype=x2.dtype)
    mean_squared_error = np.mean(np.square(x1 - x2))
    signal_term = 20 * np.log10(peak)
    noise_term = 10 * np.log10(mean_squared_error)
    return signal_term - noise_term
|
|
|
|
|
|
|
|
def _get_large_negative(dtype):
    """Return a large negative scalar in the given dtype.

    The value is `-0.7` times `65500.0` for `float16` and `-0.7` times
    `3.38953e38` for every other dtype.

    Args:
        dtype: Dtype, standardized with `backend.standardize_dtype`.

    Returns:
        0-d NumPy array of the standardized dtype.
    """
    dtype_name = backend.standardize_dtype(dtype)
    if dtype_name == "float16":
        magnitude = 65500.0
    else:
        magnitude = 3.38953e38
    return np.asarray(magnitude * -0.7, dtype=dtype_name)
|
|
|
|
|
|
|
|
def _apply_masks(logits, mask, is_causal):
    """Replace masked-out attention logits by a large negative value.

    Args:
        logits: Attention logits; axes 2 and 3 are used as the two
            sequence axes when building the causal mask.
        mask: Optional boolean mask broadcastable to `logits`; positions
            where it is false are masked out.
        is_causal: Whether to also apply a lower-triangular mask.

    Returns:
        `logits` unchanged when no masking is requested, otherwise a new
        array with the masked positions replaced.
    """
    if mask is None and not is_causal:
        return logits

    keep = np.ones_like(logits, dtype=np.bool_)
    if mask is not None:
        keep = np.logical_and(keep, mask)

    if is_causal:
        num_queries, num_keys = logits.shape[2], logits.shape[3]
        lower_triangle = np.tril(np.ones((num_queries, num_keys), dtype=np.bool_))
        # Add two leading axes so the mask broadcasts over the first two.
        keep = np.logical_and(keep, lower_triangle[None, None, :, :])

    fill_value = _get_large_negative(logits.dtype)
    return np.where(keep, logits, fill_value)
|
|
|
|
|
|
|
|
def _dot_product_attention_xla(query, key, value, bias, mask, is_causal, scale):
    """Compute scaled dot-product attention with NumPy einsums.

    The einsum subscripts treat `query` as `BTNH` and `key`/`value` as
    `BSNH`.

    Args:
        query: Query array.
        key: Key array; its dtype is the dtype of the attention weights
            and of the returned array.
        value: Value array.
        bias: Optional array added to the scaled logits.
        mask: Optional boolean mask forwarded to `_apply_masks`.
        is_causal: Whether to apply a causal mask.
        scale: Factor applied to the logits.

    Returns:
        Array laid out as `BTNH`.
    """
    output_dtype = key.dtype
    logits_dtype = np.promote_types(query.dtype, np.float32)
    if backend.standardize_dtype(key.dtype) == "bfloat16":
        # Run the einsum on float32 copies of bfloat16 keys and values.
        key = key.astype("float32")
        value = value.astype("float32")

    logits = np.einsum("BTNH,BSNH->BNTS", query, key).astype(logits_dtype)
    logits = logits * np.array(scale, dtype=logits.dtype)

    if bias is not None:
        logits = (logits + bias).astype(logits.dtype)

    masked_logits = _apply_masks(logits, mask, is_causal)

    # The softmax runs in float32; the weights are then cast to the
    # original key dtype.
    masked_logits = masked_logits.astype(np.float32)
    weights = softmax(masked_logits, axis=-1).astype(output_dtype)
    result_dtype = weights.dtype
    if backend.standardize_dtype(weights.dtype) == "bfloat16":
        # Same float32 detour for the second einsum.
        weights = weights.astype("float32")
        value = value.astype("float32")

    attended = np.einsum("BNTS,BSNH->BTNH", weights, value)
    return attended.astype(result_dtype)
|
|
|
|
|
|
|
|
def dot_product_attention(
    query,
    key,
    value,
    bias=None,
    mask=None,
    scale=None,
    is_causal=False,
    flash_attention=None,
    attn_logits_soft_cap=None,
):
    """Compute scaled dot-product attention.

    Args:
        query: 4-D query; converted with `convert_to_tensor`.
        key: Key; converted with `convert_to_tensor`. Its shape is
            unpacked into four entries.
        value: Value; converted with `convert_to_tensor`.
        bias: Optional array added to the attention logits.
        mask: Optional boolean attention mask.
        scale: Logit scale. Defaults to `1 / sqrt(H)`, where `H` is the
            last dimension of `key`.
        is_causal: Whether to apply a causal mask. Defaults to `False`.
        flash_attention: Must be falsy; `None` is treated as `False`.
        attn_logits_soft_cap: Accepted but not used by this function.

    Returns:
        Attention output from `_dot_product_attention_xla`.

    Raises:
        ValueError: If `flash_attention` is truthy or `query` is not 4-D.
    """
    # `None` is falsy, so it takes the same path as `False`.
    if flash_attention:
        raise ValueError("Flash attention is not supported in numpy backend.")

    query = convert_to_tensor(query)
    key = convert_to_tensor(key)
    value = convert_to_tensor(value)
    if len(query.shape) != 4:
        raise ValueError(
            "`dot_product_attention` only supports 4D inputs. "
            f"Received: query.shape={query.shape}, key.shape={key.shape}, "
            f"value.shape={value.shape}."
        )

    _, _, _, head_dim = key.shape
    if scale is None:
        scale = 1.0 / np.sqrt(head_dim)
    return _dot_product_attention_xla(
        query, key, value, bias, mask, is_causal, scale
    )
|
|
|