import jax
import numpy as np
from jax import lax

from keras.src import backend
from keras.src.backend.common.backend_utils import (
    compute_conv_transpose_padding_args_for_jax,
)
from keras.src.backend.numpy.core import cast
from keras.src.backend.numpy.core import convert_to_tensor
from keras.src.backend.numpy.core import is_tensor
from keras.src.utils.module_utils import scipy


def relu(x):
    """Return the elementwise maximum of `x` and zero, in the dtype of `x`."""
    x = convert_to_tensor(x)
    return np.maximum(x, np.array(0.0, x.dtype))


def relu6(x):
    """Clamp `x` elementwise to the interval `[0, 6]`."""
    x = convert_to_tensor(x)
    # np.clip incorrectly promote bfloat16 to float32, so we replace it with
    # np.minimum and np.maximum here
    return np.minimum(
        np.maximum(x, np.array(0.0, x.dtype)), np.array(6.0, x.dtype)
    )


def sigmoid(x):
    """Return `1 / (1 + exp(-x))` elementwise."""
    x = convert_to_tensor(x)
    return np.array(1.0, x.dtype) / (np.array(1.0, x.dtype) + np.exp(-x))


def sparse_sigmoid(x):
    """Piecewise-linear sigmoid: 0 for `x <= -1`, 1 for `x >= 1`, else
    `0.5 * (x + 1)`."""
    x = convert_to_tensor(x)
    return np.where(
        x <= -1,
        np.array(0.0, x.dtype),
        np.where(
            x >= 1, np.array(1.0, x.dtype), np.array(0.5 * (x + 1), x.dtype)
        ),
    )


def tanh(x):
    """Return the elementwise hyperbolic tangent of `x`."""
    return np.tanh(x)


def tanh_shrink(x):
    """Return `x - tanh(x)` elementwise."""
    x = convert_to_tensor(x)
    return x - np.tanh(x)


def softplus(x):
    """Return `log(exp(x) + 1)` elementwise, computed via `np.logaddexp`."""
    x = convert_to_tensor(x)
    return np.logaddexp(x, np.array(0.0, x.dtype))


def softsign(x):
    """Return `x / (1 + |x|)` elementwise."""
    x = convert_to_tensor(x)
    return x / (np.array(1.0, x.dtype) + np.abs(x))


def soft_shrink(x, threshold=0.5):
    """Shrink `x` towards zero by `threshold`.

    Values above `threshold` are reduced by it, values below `-threshold` are
    increased by it, and everything in between becomes zero.
    """
    # Convert first so that `x.dtype` is available for non-array inputs,
    # matching the other activations in this module.
    x = convert_to_tensor(x)
    return np.where(
        x > threshold,
        np.array(x - threshold, dtype=x.dtype),
        np.where(
            x < -threshold,
            np.array(x + threshold, dtype=x.dtype),
            np.array(0.0, dtype=x.dtype),
        ),
    )


def sparse_plus(x):
    """Return 0 for `x <= -1`, `(x + 1)**2 / 4` for `-1 < x < 1`, else `x`."""
    # Convert first so that `x.dtype` is available for non-array inputs.
    x = convert_to_tensor(x)
    return np.where(
        x <= -1,
        np.zeros_like(x, dtype=x.dtype),
        np.where(x < 1, np.array((1 / 4) * (x + 1) ** 2, dtype=x.dtype), x),
    )


def silu(x):
    """Return `x * sigmoid(x)` elementwise."""
    x = convert_to_tensor(x)
    return x * sigmoid(x)


def squareplus(x, b=4):
    """Return `(x + sqrt(x**2 + b)) / 2` elementwise."""
    x = convert_to_tensor(x)
    b = convert_to_tensor(b, dtype=x.dtype)
    y = x + np.sqrt(x**2 + b)
    return y / 2


def log_sigmoid(x):
    """Return `-softplus(-x)` elementwise."""
    x = convert_to_tensor(x)
    return -softplus(-x)


def leaky_relu(x, negative_slope=0.2):
    """Return `max(x, negative_slope * x)` elementwise."""
    x = convert_to_tensor(x)
    return np.maximum(x, np.array(negative_slope, x.dtype) * x)


def hard_sigmoid(x):
    """Return `x / 6 + 0.5` clamped elementwise to `[0, 1]`."""
    # Convert first so that `x.dtype` is available for non-array inputs.
    x = convert_to_tensor(x)
    # python numbers will be promoted to float64 by np, so it's necessary to
    # first convert the python numbers to np scalars
    x = x / np.array(6.0, x.dtype) + np.array(0.5, x.dtype)
    return np.where(
        x <= 0.0,
        np.array(0.0, x.dtype),
        np.where(x >= 1.0, np.array(1.0, x.dtype), x),
    )


def hard_silu(x):
    """Return `x * hard_sigmoid(x)` elementwise."""
    x = convert_to_tensor(x)
    return x * hard_sigmoid(x)


def elu(x, alpha=1.0):
    """Return `x` where `x >= 0`, else `alpha * (exp(x) - 1)`."""
    x = convert_to_tensor(x)
    return np.where(
        x >= np.array(0.0, x.dtype), x, np.array(alpha, x.dtype) * np.expm1(x)
    )


def selu(
    x,
    alpha=1.6732632423543772848170429916717,
    scale=1.0507009873554804934193349852946,
):
    """Return `scale * elu(x, alpha)` elementwise."""
    x = convert_to_tensor(x)
    return np.array(scale, x.dtype) * elu(x, alpha)


def gelu(x, approximate=True):
    """Gaussian error linear unit.

    Args:
        x: input tensor.
        approximate: if truthy, use the tanh-based approximation; otherwise
            use the exact form based on `scipy.special.erf`.
    """
    x = convert_to_tensor(x)
    # followed by JAX's implementation
    if approximate:
        sqrt_2_over_pi = np.sqrt(2 / np.pi).astype(x.dtype)
        cdf = np.array(0.5, x.dtype) * (
            np.array(1.0, x.dtype)
            + np.tanh(
                sqrt_2_over_pi
                * (x + np.array(0.044715, x.dtype) * (x**3).astype(x.dtype))
            )
        )
        return x * cdf
    else:
        sqrt_2 = np.sqrt(2).astype(x.dtype)
        return (
            x
            * (scipy.special.erf(x / sqrt_2) + 1).astype(x.dtype)
            / np.array(2, x.dtype)
        )


def celu(x, alpha=1.0):
    """Return `max(x, 0) + alpha * (exp(min(x, 0) / alpha) - 1)`."""
    x = convert_to_tensor(x)
    alpha = np.array(alpha, x.dtype)
    return np.maximum(x, np.array(0.0, dtype=x.dtype)) + alpha * np.expm1(
        np.minimum(x, np.array(0.0, dtype=x.dtype)) / alpha
    )


def glu(x, axis=-1):
    """Split `x` in two halves along `axis` and return `x1 * sigmoid(x2)`.

    Raises:
        ValueError: if the size of `x` along `axis` is odd.
    """
    x = convert_to_tensor(x)
    if x.shape[axis] % 2 != 0:
        raise ValueError(
            "axis size must be divisible by 2. "
            f"Received: x.shape={x.shape} with axis={axis}"
        )
    x1, x2 = np.split(x, 2, axis)
    return x1 * (1 / (1 + np.exp(-x2)))


def hard_tanh(x):
    """Clamp `x` elementwise to `[-1, 1]`, keeping the dtype of `x`."""
    x = convert_to_tensor(x)
    min_val = np.asarray(-1.0, x.dtype)
    max_val = np.asarray(1.0, x.dtype)
    return np.array(np.clip(x, min_val, max_val), dtype=x.dtype)


def hard_shrink(x, threshold=0.5):
    """Return `x` where `|x| > threshold`, else zero."""
    x = convert_to_tensor(x)
    threshold = np.asarray(threshold, x.dtype)
    return np.array(
        np.where(np.abs(x) > threshold, x, np.array(0.0, dtype=x.dtype)),
        dtype=x.dtype,
    )


def threshold(x, threshold, default_value):
    """Return `x` where `x > threshold`, else `default_value`."""
    x = convert_to_tensor(x)
    return np.where(x > threshold, x, np.array(default_value, dtype=x.dtype))


def softmax(x, axis=None):
    """Softmax over `axis`; the maximum is subtracted before exponentiating."""
    exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)


def log_softmax(x, axis=None):
    """Log-softmax over `axis`, computed with the max-shift trick."""
    max_x = np.max(x, axis=axis, keepdims=True)
    logsumexp = np.log(np.exp(x - max_x).sum(axis=axis, keepdims=True))
    return x - max_x - logsumexp


def sparsemax(logits, axis=-1):
    """Sparsemax of `logits` along `axis`.

    Entries outside the support are set to exactly zero.
    """
    # Sort logits along the specified axis in descending order
    logits = convert_to_tensor(logits)
    logits_sorted = -1.0 * np.sort(-1.0 * logits, axis=axis)
    logits_cumsum = np.cumsum(logits_sorted, axis=axis)
    r = np.arange(1, logits.shape[axis] + 1)
    r_shape = [1] * logits.ndim
    r_shape[axis] = -1  # Broadcast to match the target axis
    r = r.reshape(r_shape)
    support = logits_sorted - (logits_cumsum - 1) / r > 0
    # Find the threshold
    k = np.sum(support, axis=axis, keepdims=True)
    logits_cumsum_safe = np.where(support, logits_cumsum, 0.0)
    tau = (np.sum(logits_cumsum_safe, axis=axis, keepdims=True) - 1) / k
    output = np.maximum(logits - tau, 0.0)
    return output


def _convert_to_spatial_operand(
    x,
    num_spatial_dims,
    data_format="channels_last",
    include_batch_and_channels=True,
):
    """Convert an int or a sequence into a per-dimension tuple.

    An int is repeated `num_spatial_dims` times. When
    `include_batch_and_channels` is true, entries of 1 are added for the batch
    and channel positions according to `data_format`.
    """
    # Normalise to a tuple so that the concatenations below also work when the
    # caller passes a list.
    x = (x,) * num_spatial_dims if isinstance(x, int) else tuple(x)
    if not include_batch_and_channels:
        return x
    if data_format == "channels_last":
        x = (1,) + x + (1,)
    else:
        x = (1,) + (1,) + x
    return x


def _pool(
    inputs,
    initial_value,
    reduce_fn,
    pool_size,
    strides=None,
    padding="valid",
):
    """Helper function to define pooling functions.

    Args:
        inputs: input data of shape `N+2`.
        initial_value: the initial value for the reduction.
        reduce_fn: a reduce function of the form `(T, T) -> T`.
        pool_size: a sequence of `N` integers, representing the window size to
            reduce over.
        strides: a sequence of `N` integers, representing the inter-window
            strides (default: `(1, ..., 1)`).
        padding: either the string `same` or `valid`.

    Returns:
        The output of the reduction for each window slice.

    Raises:
        ValueError: if `padding` is neither `"same"` nor `"valid"`.
    """
    if padding not in ("same", "valid"):
        raise ValueError(
            f"Invalid padding '{padding}', must be 'same' or 'valid'."
        )
    padding = padding.upper()
    return np.array(
        lax.reduce_window(
            inputs,
            initial_value,
            reduce_fn,
            pool_size,
            strides,
            padding,
        )
    )


def max_pool(
    inputs,
    pool_size,
    strides=None,
    padding="valid",
    data_format=None,
):
    """Max pooling over the spatial dimensions of `inputs`.

    `strides` defaults to `pool_size` when it is `None`.
    """
    data_format = backend.standardize_data_format(data_format)
    num_spatial_dims = inputs.ndim - 2
    pool_size = _convert_to_spatial_operand(
        pool_size, num_spatial_dims, data_format
    )
    strides = pool_size if strides is None else strides
    strides = _convert_to_spatial_operand(
        strides, num_spatial_dims, data_format
    )
    return _pool(inputs, -np.inf, lax.max, pool_size, strides, padding)


def average_pool(
    inputs,
    pool_size,
    strides=None,
    padding="valid",
    data_format=None,
):
    """Average pooling over the spatial dimensions of `inputs`.

    `strides` defaults to `pool_size` when it is `None`. With `"same"`
    padding, each window is divided by the number of input entries it actually
    covers rather than by the full window size.
    """
    data_format = backend.standardize_data_format(data_format)
    num_spatial_dims = inputs.ndim - 2
    pool_size = _convert_to_spatial_operand(
        pool_size, num_spatial_dims, data_format
    )
    strides = pool_size if strides is None else strides
    strides = _convert_to_spatial_operand(
        strides, num_spatial_dims, data_format
    )

    pooled = _pool(inputs, 0.0, lax.add, pool_size, strides, padding)
    if padding == "valid":
        # Avoid the extra reduce_window.
        return pooled / np.prod(pool_size)
    else:
        # Count the number of valid entries at each input point, then use that
        # for computing average. Assumes that any two arrays of same shape will
        # be padded the same. Avoid broadcasting on axis where pooling is
        # skipped.
        shape = [
            (a if b != 1 else 1) for (a, b) in zip(inputs.shape, pool_size)
        ]
        window_counts = _pool(
            np.ones(shape, inputs.dtype),
            0.0,
            lax.add,
            pool_size,
            strides,
            padding,
        )
        return pooled / window_counts


def _convert_to_lax_conv_dimension_numbers(
    num_spatial_dims,
    data_format="channels_last",
    transpose=False,
):
    """Create a `lax.ConvDimensionNumbers` for the given inputs."""
    num_dims = num_spatial_dims + 2

    if data_format == "channels_last":
        spatial_dims = tuple(range(1, num_dims - 1))
        inputs_dn = (0, num_dims - 1) + spatial_dims
    else:
        spatial_dims = tuple(range(2, num_dims))
        inputs_dn = (0, 1) + spatial_dims

    if transpose:
        kernel_dn = (num_dims - 2, num_dims - 1) + tuple(range(num_dims - 2))
    else:
        kernel_dn = (num_dims - 1, num_dims - 2) + tuple(range(num_dims - 2))

    return lax.ConvDimensionNumbers(
        lhs_spec=inputs_dn, rhs_spec=kernel_dn, out_spec=inputs_dn
    )


def conv(
    inputs,
    kernel,
    strides=1,
    padding="valid",
    data_format=None,
    dilation_rate=1,
):
    """N-D convolution via `jax.lax.conv_general_dilated`.

    The feature group count is the number of input channels divided by
    `kernel.shape[-2]`.

    Raises:
        ValueError: if the number of input channels is not a multiple of
            `kernel.shape[-2]`.
    """
    data_format = backend.standardize_data_format(data_format)
    num_spatial_dims = inputs.ndim - 2
    dimension_numbers = _convert_to_lax_conv_dimension_numbers(
        num_spatial_dims,
        data_format,
        transpose=False,
    )
    strides = _convert_to_spatial_operand(
        strides,
        num_spatial_dims,
        data_format,
        include_batch_and_channels=False,
    )
    dilation_rate = _convert_to_spatial_operand(
        dilation_rate,
        num_spatial_dims,
        data_format,
        include_batch_and_channels=False,
    )
    if data_format == "channels_last":
        channels = inputs.shape[-1]
    else:
        channels = inputs.shape[1]
    kernel_in_channels = kernel.shape[-2]
    if channels % kernel_in_channels > 0:
        raise ValueError(
            "The number of input channels must be evenly divisible by "
            f"kernel's in_channels. Received input channels {channels} and "
            f"kernel in_channels {kernel_in_channels}. "
        )
    feature_group_count = channels // kernel_in_channels
    return np.array(
        jax.lax.conv_general_dilated(
            inputs,
            kernel if is_tensor(kernel) else kernel.numpy(),
            strides,
            padding,
            rhs_dilation=dilation_rate,
            dimension_numbers=dimension_numbers,
            feature_group_count=feature_group_count,
        )
    )


def depthwise_conv(
    inputs,
    kernel,
    strides=1,
    padding="valid",
    data_format=None,
    dilation_rate=1,
):
    """Depthwise convolution.

    The kernel's last two axes are reshaped to
    `(1, channels * kernel.shape[-1])` and the convolution is run with one
    feature group per input channel.
    """
    data_format = backend.standardize_data_format(data_format)
    num_spatial_dims = inputs.ndim - 2
    dimension_numbers = _convert_to_lax_conv_dimension_numbers(
        num_spatial_dims,
        data_format,
        transpose=False,
    )
    strides = _convert_to_spatial_operand(
        strides,
        num_spatial_dims,
        data_format,
        include_batch_and_channels=False,
    )
    dilation_rate = _convert_to_spatial_operand(
        dilation_rate,
        num_spatial_dims,
        data_format,
        include_batch_and_channels=False,
    )
    feature_group_count = (
        inputs.shape[-1] if data_format == "channels_last" else inputs.shape[1]
    )
    kernel = np.reshape(
        kernel if is_tensor(kernel) else kernel.numpy(),
        kernel.shape[:-2] + (1, feature_group_count * kernel.shape[-1]),
    )
    return np.array(
        jax.lax.conv_general_dilated(
            inputs,
            kernel,
            strides,
            padding,
            rhs_dilation=dilation_rate,
            dimension_numbers=dimension_numbers,
            feature_group_count=feature_group_count,
        )
    )


def separable_conv(
    inputs,
    depthwise_kernel,
    pointwise_kernel,
    strides=1,
    padding="valid",
    data_format=None,
    dilation_rate=1,
):
    """Depthwise convolution followed by a stride-1, `"valid"` convolution
    with `pointwise_kernel`."""
    data_format = backend.standardize_data_format(data_format)
    depthwise_conv_output = depthwise_conv(
        inputs,
        depthwise_kernel,
        strides,
        padding,
        data_format,
        dilation_rate,
    )
    return conv(
        depthwise_conv_output,
        pointwise_kernel,
        strides=1,
        padding="valid",
        data_format=data_format,
        dilation_rate=dilation_rate,
    )


def conv_transpose(
    inputs,
    kernel,
    strides=1,
    padding="valid",
    output_padding=None,
    data_format=None,
    dilation_rate=1,
):
    """Transposed convolution via `jax.lax.conv_transpose`.

    Explicit padding values are obtained from
    `compute_conv_transpose_padding_args_for_jax`.
    """
    data_format = backend.standardize_data_format(data_format)
    num_spatial_dims = inputs.ndim - 2
    padding_values = compute_conv_transpose_padding_args_for_jax(
        input_shape=inputs.shape,
        kernel_shape=kernel.shape,
        strides=strides,
        padding=padding,
        output_padding=output_padding,
        dilation_rate=dilation_rate,
    )
    dimension_numbers = _convert_to_lax_conv_dimension_numbers(
        num_spatial_dims,
        data_format,
        transpose=False,
    )
    strides = _convert_to_spatial_operand(
        strides,
        num_spatial_dims,
        data_format,
        include_batch_and_channels=False,
    )
    dilation_rate = _convert_to_spatial_operand(
        dilation_rate,
        num_spatial_dims,
        data_format,
        include_batch_and_channels=False,
    )

    return np.array(
        jax.lax.conv_transpose(
            inputs,
            kernel if is_tensor(kernel) else kernel.numpy(),
            strides,
            padding=padding_values,
            rhs_dilation=dilation_rate,
            dimension_numbers=dimension_numbers,
            transpose_kernel=True,
        )
    )


def one_hot(x, num_classes, axis=-1, dtype="float32", sparse=False):
    """One-hot encode integer indices.

    Negative indices produce an all-zero row. When `num_classes` is falsy it
    is set to `max(x) + 1`. The new class axis is appended last and then moved
    to `axis`.

    Raises:
        ValueError: if `sparse` is true.
    """
    if sparse:
        raise ValueError("Unsupported value `sparse=True` with numpy backend")
    x = convert_to_tensor(x)
    input_shape = x.shape

    x = x.reshape(-1)
    if not num_classes:
        num_classes = np.max(x) + 1

    batch_size = x.shape[0]
    categorical = np.zeros((batch_size, num_classes), dtype=dtype)
    valid_indices = x >= 0
    categorical[np.arange(batch_size)[valid_indices], x[valid_indices]] = 1

    # First, reshape the array with the extra dimension at the end
    output_shape = input_shape + (num_classes,)
    categorical = np.reshape(categorical, output_shape)

    # Then, move this new dimension to the right place (according to axis)
    if axis != -1:
        categorical = np.moveaxis(categorical, -1, axis)

    return categorical


def multi_hot(x, num_classes, axis=-1, dtype="float32", sparse=False):
    """Multi-hot encode indices: one-hot encode, then take the maximum.

    The maximum is taken over axis 1 when `x` has more than one dimension,
    otherwise over axis 0.

    Raises:
        ValueError: if `sparse` is true.
    """
    if sparse:
        raise ValueError("Unsupported value `sparse=True` with numpy backend")
    x = convert_to_tensor(x)
    reduction_axis = 1 if len(x.shape) > 1 else 0
    outputs = np.max(
        one_hot(cast(x, "int32"), num_classes, axis=axis, dtype=dtype),
        axis=reduction_axis,
    )
    return outputs


def categorical_crossentropy(target, output, from_logits=False, axis=-1):
    """Cross-entropy between `target` and `output`, summed over `axis`.

    When `from_logits` is false, `output` is normalised to sum to one along
    `axis` and clipped to `[epsilon, 1 - epsilon]` before the log is taken.

    Raises:
        ValueError: if the shapes differ or the inputs have rank 0.
    """
    target = np.array(target)
    output = np.array(output)

    if target.shape != output.shape:
        raise ValueError(
            "Arguments `target` and `output` must have the same shape. "
            "Received: "
            f"target.shape={target.shape}, output.shape={output.shape}"
        )
    if len(target.shape) < 1:
        raise ValueError(
            "Arguments `target` and `output` must be at least rank 1. "
            "Received: "
            f"target.shape={target.shape}, output.shape={output.shape}"
        )

    if from_logits:
        log_prob = log_softmax(output, axis=axis)
    else:
        output = output / np.sum(output, axis, keepdims=True)
        output = np.clip(output, backend.epsilon(), 1.0 - backend.epsilon())
        log_prob = np.log(output)
    return -np.sum(target * log_prob, axis=axis)


def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
    """Cross-entropy where `target` holds integer class indices.

    A trailing size-1 axis on `target` is squeezed when `target` and `output`
    have the same rank. The target is one-hot encoded along `axis` before the
    reduction.

    Raises:
        ValueError: if `output` has rank 0, or if `target.shape` differs from
            `output.shape[:-1]`.
    """
    target = np.array(target, dtype="int32")
    output = np.array(output)
    if len(target.shape) == len(output.shape) and target.shape[-1] == 1:
        target = np.squeeze(target, axis=-1)

    if len(output.shape) < 1:
        raise ValueError(
            "Argument `output` must be at least rank 1. "
            "Received: "
            f"output.shape={output.shape}"
        )
    if target.shape != output.shape[:-1]:
        raise ValueError(
            "Arguments `target` and `output` must have the same shape "
            "up until the last dimension: "
            f"target.shape={target.shape}, output.shape={output.shape}"
        )
    if from_logits:
        log_prob = log_softmax(output, axis=axis)
    else:
        output = output / np.sum(output, axis, keepdims=True)
        output = np.clip(output, backend.epsilon(), 1.0 - backend.epsilon())
        log_prob = np.log(output)
    target = one_hot(target, output.shape[axis], axis=axis)
    return -np.sum(target * log_prob, axis=axis)


def binary_crossentropy(target, output, from_logits=False):
    """Elementwise binary cross-entropy.

    When `from_logits` is true, `output` is passed through `sigmoid` first.
    Probabilities are clipped to `[epsilon, 1 - epsilon]`.

    Raises:
        ValueError: if `target` and `output` have different shapes.
    """
    target = np.array(target)
    output = np.array(output)

    if target.shape != output.shape:
        raise ValueError(
            "Arguments `target` and `output` must have the same shape. "
            "Received: "
            f"target.shape={target.shape}, output.shape={output.shape}"
        )

    if from_logits:
        output = sigmoid(output)

    output = np.clip(output, backend.epsilon(), 1.0 - backend.epsilon())
    bce = target * np.log(output)
    bce += (1.0 - target) * np.log(1.0 - output)
    return -bce


def moments(x, axes, keepdims=False, synchronized=False):
    """Return the mean and variance of `x` over `axes`.

    float16 inputs are processed in float32 and the results are cast back.

    Raises:
        NotImplementedError: if `synchronized` is true.
    """
    if synchronized:
        raise NotImplementedError(
            "Argument synchronized=True is not supported with NumPy."
        )
    axes = tuple(axes) if isinstance(axes, list) else axes
    # The dynamic range of float16 is too limited for statistics. As a
    # workaround, we simply perform the operations on float32 and convert back
    # to float16
    need_cast = False
    ori_dtype = backend.standardize_dtype(x.dtype)
    if ori_dtype == "float16":
        need_cast = True
        x = cast(x, "float32")

    mean = np.mean(x, axes, keepdims=True)

    # The variance is computed using $Var = E[|x|^2] - |E[x]|^2$, It is faster
    # but less numerically stable.
    variance = np.mean(np.square(x), axis=axes, keepdims=True) - np.square(mean)

    if not keepdims:
        mean = np.squeeze(mean, axes)
        variance = np.squeeze(variance, axes)
    if need_cast:
        # Clip to the float16 range first to avoid overflow when casting the
        # float32 results back to float16.
        mean = np.clip(mean, np.finfo(np.float16).min, np.finfo(np.float16).max)
        variance = np.clip(
            variance, np.finfo(np.float16).min, np.finfo(np.float16).max
        )
        mean = cast(mean, ori_dtype)
        variance = cast(variance, ori_dtype)
    return mean, variance


def batch_normalization(
    x, mean, variance, axis, offset=None, scale=None, epsilon=1e-3
):
    """Normalise `x` with the given statistics along `axis`.

    Computes `(x - mean) / sqrt(variance + epsilon)`, multiplied by `scale`
    and shifted by `offset` when those are given. The statistics are reshaped
    so that they broadcast along `axis`.
    """
    shape = [1] * len(x.shape)
    shape[axis] = mean.shape[0]
    mean = np.reshape(mean, shape)
    variance = np.reshape(variance, shape)

    inv = 1.0 / np.sqrt(variance + epsilon)
    if scale is not None:
        scale = np.reshape(scale, shape)
        inv = inv * scale

    res = -mean * inv
    if offset is not None:
        offset = np.reshape(offset, shape)
        res = res + offset

    return x * inv + res


def ctc_loss(target, output, target_length, output_length, mask_index=0):
    """Per-sequence CTC loss.

    Args:
        target: integer labels, unpacked as `(batch, max_label_length)`.
        output: scores, unpacked as `(batch, max_input_length, num_classes)`;
            `log_softmax` is applied over the last axis.
        target_length: number of valid labels per sequence.
        output_length: number of valid time steps per sequence.
        mask_index: class index used as the blank label.

    Returns:
        One loss value per batch entry.
    """
    # Ref: https://github.com/google-deepmind/optax
    # optax.ctc_loss_with_forward_probs
    target = convert_to_tensor(target, dtype="int32")
    output = convert_to_tensor(output)
    target_length = convert_to_tensor(target_length, "int32")
    output_length = convert_to_tensor(output_length, "int32")
    batch_size, max_input_length, num_classes = output.shape
    batch_size, max_label_length = target.shape
    log_epsilon = -1e5

    # Ensure that the dtype promotion behavior matches that of `tf.nn.ctc_loss`
    dtype = backend.result_type(output.dtype, "float32")
    output = output.astype(dtype)

    def _lengths_to_paddings(lengths, max_length):
        # True at positions at or beyond each sequence's length.
        indices = np.arange(max_length).reshape(
            (1,) * lengths.ndim + (max_length,)
        )
        lengths = np.expand_dims(lengths, axis=-1)
        elem_valid = indices < lengths
        return np.logical_not(elem_valid)

    target_paddings = _lengths_to_paddings(target_length, max_label_length)
    output_paddings = _lengths_to_paddings(output_length, max_input_length)
    target_paddings = target_paddings.astype(output.dtype)
    output_paddings = output_paddings.astype(output.dtype)

    logprobs = log_softmax(output, axis=-1)
    label_lengths = max_label_length - np.sum(target_paddings, axis=1).astype(
        np.int32
    )

    # repeat[b, n] == 1.0 when label[b, n] == label[b, n+1].
    repeat = (target[:, :-1] == target[:, 1:]).astype(np.float32)
    repeat = np.pad(repeat, ((0, 0), (0, 1)))

    logprobs_phi = logprobs[:, :, mask_index : mask_index + 1]  # [B, T, 1]
    logprobs_phi = np.transpose(logprobs_phi, (1, 0, 2))  # [T, B, 1]

    _one_hot = one_hot(target, num_classes=num_classes)  # [B, N, K]
    logprobs_emit = np.einsum("btk,bnk->btn", logprobs, _one_hot)
    logprobs_emit = np.transpose(logprobs_emit, (1, 0, 2))  # [T, B, N]

    # [B, N]
    logalpha_phi_init = (
        np.ones((batch_size, max_label_length + 1), dtype=output.dtype)
        * log_epsilon
    )
    logalpha_phi_init[:, 0] = 0.0
    logalpha_emit_init = (
        np.ones((batch_size, max_label_length), dtype=output.dtype)
        * log_epsilon
    )

    def update_phi_score(phi, added_score):
        # Update `phi[:, 1:]`` with adding `added_score` in log space.
        return np.concatenate(
            [phi[:, :1], np.logaddexp(phi[:, 1:], added_score)], axis=-1
        )

    def loop_body(prev, x):
        prev_phi, prev_emit = prev
        # emit-to-phi epsilon transition, except if the next label is repetition
        prev_phi_orig = prev_phi
        prev_phi = update_phi_score(prev_phi, prev_emit + log_epsilon * repeat)

        logprob_emit, logprob_phi, pad = x

        # phi-to-emit transition
        next_emit = np.logaddexp(
            prev_phi[:, :-1] + logprob_emit, prev_emit + logprob_emit
        )
        # self-loop transition
        next_phi = prev_phi + logprob_phi
        # emit-to-phi blank transition only when the next label is repetition
        next_phi = update_phi_score(
            next_phi, prev_emit + logprob_phi + log_epsilon * (1.0 - repeat)
        )

        # Padded time steps carry the previous state through unchanged.
        pad = pad.reshape((batch_size, 1))
        next_emit = pad * prev_emit + (1.0 - pad) * next_emit
        next_phi = pad * prev_phi_orig + (1.0 - pad) * next_phi

        return (next_phi, next_emit), (next_phi, next_emit)

    def np_scan(f, init, xs):
        # Loop-based scan: thread the carry through `f` and stack each
        # component of the per-step outputs along a new leading axis.
        carry = init
        ys = []
        for x in zip(*xs):
            carry, y = f(carry, x)
            ys.append(y)
        result = [np.stack(column) for column in zip(*ys)]
        return carry, result

    xs = (logprobs_emit, logprobs_phi, output_paddings.transpose((1, 0)))
    _, (logalpha_phi, logalpha_emit) = np_scan(
        loop_body, (logalpha_phi_init, logalpha_emit_init), xs
    )

    # last row needs to be updated with the last epsilon transition
    logalpha_phi_last = update_phi_score(logalpha_phi[-1], logalpha_emit[-1])
    logalpha_phi[-1] = logalpha_phi_last

    # extract per_seq_loss
    # [B, N+1]
    _one_hot = one_hot(label_lengths, num_classes=max_label_length + 1)
    per_seq_loss = -np.einsum("bn,bn->b", logalpha_phi_last, _one_hot)
    return per_seq_loss


def _ctc_greedy_decode(
    inputs,
    sequence_lengths,
    merge_repeated=True,
    mask_index=None,
):
    """Greedy CTC decoding.

    Takes the argmax class at every time step, optionally blanks out repeated
    labels, and moves the remaining labels to the front of each row. Blank and
    padded positions are reported as -1. `mask_index` defaults to the last
    class.

    Returns:
        A tuple `(indices, scores)`: `indices` has a leading axis of size 1
        followed by `(batch, max_length)`; `scores` has shape `(batch, 1)` and
        holds the negated sum of the per-step maxima.
    """
    inputs = convert_to_tensor(inputs)
    sequence_lengths = convert_to_tensor(sequence_lengths, dtype="int32")
    batch_size, max_length, num_classes = inputs.shape

    if mask_index is None:
        mask_index = num_classes - 1

    indices = np.argmax(inputs, axis=-1).astype("int32")
    scores = np.max(inputs, axis=-1)

    seqlen_mask = np.arange(max_length)[None, :]
    seqlen_mask = seqlen_mask >= sequence_lengths[:, None]

    indices = np.where(seqlen_mask, mask_index, indices)
    scores = np.where(seqlen_mask, 0.0, scores)

    if merge_repeated:
        repeat_mask = indices[:, 1:] == indices[:, :-1]
        repeat_mask = np.pad(repeat_mask, ((0, 0), (1, 0)))
        indices = np.where(repeat_mask, mask_index, indices)

    # We set to -1 for blank labels
    invalid_mask = indices == mask_index
    indices = np.where(invalid_mask, -1, indices)

    # We rearrange the indices by moving `mask_index` to the end of the array
    order = np.expand_dims(np.arange(max_length), axis=0)  # [1, N]
    order = np.tile(order, (batch_size, 1))  # [B, N]
    order = np.where(invalid_mask, max_length, order)
    order = np.argsort(order, axis=-1)
    indices = np.take_along_axis(indices, order, axis=-1)

    scores = -np.sum(scores, axis=1)[:, None]
    indices = np.expand_dims(indices, axis=0)
    return indices, scores


def _ctc_beam_search_decode(
    inputs,
    sequence_lengths,
    beam_width=100,
    top_paths=1,
    mask_index=None,
):
    """Beam-search CTC decoding.

    `log_softmax` is applied to `inputs` over the class axis. Each batch entry
    is decoded independently, keeping `beam_width` candidate paths per step.
    `mask_index` defaults to the last class; unused path slots hold -1.

    Returns:
        A tuple `(paths, scores)` where `paths` is transposed so that the
        `top_paths` axis comes first, and `scores` holds the merged log-scores
        of the returned paths.
    """
    inputs = convert_to_tensor(inputs)
    sequence_lengths = convert_to_tensor(sequence_lengths)

    batch_size, max_seq_len, num_classes = inputs.shape
    inputs = log_softmax(inputs, axis=-1)
    seqlen_mask = np.arange(max_seq_len)[None, :] >= sequence_lengths[:, None]

    if mask_index is None:
        mask_index = num_classes - 1

    # This is a workaround for the fact that np.argsort does not support
    # the order parameter which is used to break ties when scores are equal.
    # For compatibility with the tensorflow implementation, we flip the inputs
    # and the mask_index, and then flip the classes back to the correct indices
    inputs = np.flip(inputs, axis=2)
    mask_index = num_classes - mask_index - 1

    _pad = -1

    init_paths = np.full(
        (batch_size, 2 * beam_width, max_seq_len), _pad, dtype=np.int32
    )

    num_init_paths = np.min(np.array([num_classes, beam_width]))
    max_classes = np.argsort(inputs[:, 0], axis=1)[:, -num_init_paths:]
    init_classes = np.where(max_classes == mask_index, _pad, max_classes)
    init_paths[:, :num_init_paths, 0] = init_classes

    init_scores = np.full(
        (batch_size, 2 * beam_width), -np.inf, dtype=inputs.dtype
    )
    init_scores[:, :num_init_paths] = np.take_along_axis(
        inputs[:, 0], max_classes, axis=1
    )
    init_masked = init_paths[:, :, 0] == _pad

    def _extend_paths(paths, scores, masked, x):
        # Expand every candidate path by every class for the current step.
        paths = np.repeat(paths, num_classes, axis=0)
        scores = np.repeat(scores, num_classes)
        masked = np.repeat(masked, num_classes)

        path_tail_index = np.argmax(paths == _pad, axis=1)
        paths_arange = np.arange(2 * beam_width * num_classes)
        path_tails = paths[paths_arange, path_tail_index - 1]
        path_tails = np.where(path_tail_index == 0, _pad, path_tails)

        classes = np.arange(num_classes)
        classes[mask_index] = _pad
        classes = np.tile(classes, 2 * beam_width)

        prev_masked = masked
        masked = classes == _pad

        # A label equal to the path's tail is not appended unless the previous
        # step was a blank.
        masked_repeat = ~prev_masked & (path_tails == classes)
        classes = np.where(masked_repeat, _pad, classes)
        paths[paths_arange, path_tail_index] = classes

        x = np.tile(x, 2 * beam_width)
        scores = scores + x

        return paths, scores, masked

    def _merge_scores(unique_inverse, scores):
        # Log-sum-exp the scores of candidates that map to the same unique
        # path; `np.add.at` accumulates correctly over repeated indices.
        scores_max = np.max(scores)
        scores_exp = np.exp(scores - scores_max)
        merged = np.zeros_like(scores)
        np.add.at(merged, unique_inverse, scores_exp)
        return np.log(merged) + scores_max

    def _prune_paths(paths, scores, masked):
        # Deduplicate candidates, merge their scores and keep the best
        # `beam_width` paths, each in an emit and a blank variant.
        paths, unique_inverse = np.unique(paths, return_inverse=True, axis=0)
        pad_size = (2 * num_classes * beam_width) - len(paths)
        if pad_size > 0:
            paths = np.pad(paths, [[0, pad_size], [0, 0]], constant_values=_pad)
        paths = paths[: 2 * num_classes * beam_width]
        if len(unique_inverse.shape) >= 2:
            unique_inverse = np.squeeze(unique_inverse, axis=1)

        emit_scores = np.where(masked, -np.inf, scores)
        mask_scores = np.where(masked, scores, -np.inf)

        emit_scores = _merge_scores(unique_inverse, emit_scores)
        mask_scores = _merge_scores(unique_inverse, mask_scores)

        total_scores = np.logaddexp(emit_scores, mask_scores)
        top_indices = np.argsort(total_scores, kind="stable")[-beam_width:]

        paths = paths[top_indices]
        emit_scores = emit_scores[top_indices]
        mask_scores = mask_scores[top_indices]

        paths = np.tile(paths, (2, 1))
        scores = np.concatenate([emit_scores, mask_scores])
        masked = np.concatenate(
            [np.zeros(beam_width, bool), np.ones(beam_width, bool)]
        )

        return paths, scores, masked

    def _decode_step(paths, scores, masked, x):
        paths, scores, masked = _extend_paths(paths, scores, masked, x)
        paths, scores, masked = _prune_paths(paths, scores, masked)
        return paths, scores, masked

    def _step(prev, x):
        paths, scores, masked = prev
        x, seqlen_mask = x
        # Time steps beyond the sequence length leave the beam untouched.
        if not seqlen_mask:
            paths, scores, masked = _decode_step(paths, scores, masked, x)
        return (paths, scores, masked), None

    def _decode_batch(
        init_paths, init_scores, init_masked, inputs, seqlen_mask
    ):
        def np_scan_only_carry(f, init, xs):
            carry = init
            for x in zip(*xs):
                carry, y = f(carry, x)
            return carry, None

        (paths, scores, masked), _ = np_scan_only_carry(
            _step,
            (init_paths, init_scores, init_masked),
            (inputs[1:], seqlen_mask[1:]),
        )

        paths, unique_inverse = np.unique(paths, return_inverse=True, axis=0)
        pad_size = (2 * num_classes * beam_width) - len(paths)
        if pad_size > 0:
            paths = np.pad(paths, [[0, pad_size], [0, 0]], constant_values=_pad)
        paths = paths[: 2 * num_classes * beam_width]
        if len(unique_inverse.shape) >= 2:
            unique_inverse = np.squeeze(unique_inverse, axis=1)
        scores = _merge_scores(unique_inverse, scores)

        top_indices = np.argsort(scores)[-top_paths:][::-1]
        paths = paths[top_indices]
        scores = scores[top_indices]

        return paths, scores

    results = [
        _decode_batch(p, s, m, i, sm)
        for p, s, m, i, sm in zip(
            init_paths, init_scores, init_masked, inputs, seqlen_mask
        )
    ]
    paths = np.stack([r[0] for r in results])
    scores = np.stack([r[1] for r in results])

    # convert classes back to the correct indices
    paths = np.where(paths == _pad, _pad, num_classes - paths - 1)
    paths = np.transpose(paths, [1, 0, 2])
    return paths, scores


def ctc_decode(
    inputs,
    sequence_lengths,
    strategy="greedy",
    beam_width=100,
    top_paths=1,
    merge_repeated=True,
    mask_index=0,
):
    """Decode CTC outputs with the `"greedy"` or `"beam_search"` strategy.

    `inputs` is cast to the result type of its dtype and float32 before
    decoding.

    Raises:
        ValueError: if `strategy` is not one of the supported values.
    """
    inputs = convert_to_tensor(inputs)
    dtype = backend.result_type(inputs.dtype, "float32")
    inputs = cast(inputs, dtype)

    if strategy == "greedy":
        return _ctc_greedy_decode(
            inputs,
            sequence_lengths,
            merge_repeated=merge_repeated,
            mask_index=mask_index,
        )
    elif strategy == "beam_search":
        return _ctc_beam_search_decode(
            inputs,
            sequence_lengths,
            beam_width=beam_width,
            top_paths=top_paths,
            mask_index=mask_index,
        )
    else:
        raise ValueError(
            f"Invalid strategy {strategy}. Supported values are "
            "'greedy' and 'beam_search'."
        )


def psnr(x1, x2, max_val):
    """Return `20 * log10(max_val) - 10 * log10(mean((x1 - x2)**2))`.

    Raises:
        ValueError: if `x1` and `x2` have different shapes.
    """
    # Convert first so that `.shape` / `.dtype` are available for non-array
    # inputs.
    x1 = convert_to_tensor(x1)
    x2 = convert_to_tensor(x2)
    if x1.shape != x2.shape:
        raise ValueError(
            f"Input shapes {x1.shape} and {x2.shape} must "
            "match for PSNR calculation. "
        )

    max_val = convert_to_tensor(max_val, dtype=x2.dtype)
    mse = np.mean(np.square(x1 - x2))
    result = 20 * np.log10(max_val) - 10 * np.log10(mse)
    return result


def _get_large_negative(dtype):
    """Return a large negative scalar of `dtype`, used to mask logits."""
    dtype = backend.standardize_dtype(dtype)
    val = 65500.0 if dtype == "float16" else 3.38953e38
    return np.asarray(val * -0.7, dtype=dtype)


def _apply_masks(logits, mask, is_causal):
    """Replace masked-out logits with a large negative value.

    `mask` (if given) and a lower-triangular causal mask (if `is_causal`) are
    combined with a logical AND. `logits` is returned unchanged when neither
    applies.
    """
    if mask is None and not is_causal:
        return logits

    combined_mask = np.ones_like(logits, dtype=np.bool_)
    if mask is not None:
        combined_mask = np.logical_and(combined_mask, mask)

    if is_causal:
        T, S = logits.shape[2], logits.shape[3]
        mask = np.tril(np.ones((T, S), dtype=np.bool_))
        mask = mask[None, None, :, :]
        combined_mask = np.logical_and(combined_mask, mask)

    padded_logits = np.where(
        combined_mask, logits, _get_large_negative(logits.dtype)
    )
    return padded_logits


def _dot_product_attention_xla(query, key, value, bias, mask, is_causal, scale):
    """Reference attention: scaled logits, optional bias and masks, softmax
    in float32, then a weighted sum of `value`.

    The result is cast to the dtype of `key`.
    """
    original_dtype = key.dtype
    logits_dtype = np.promote_types(query.dtype, np.float32)
    if backend.standardize_dtype(key.dtype) == "bfloat16":
        # `np.einsum` doesn't support bfloat16
        key = key.astype("float32")
        value = value.astype("float32")
    logits = np.einsum("BTNH,BSNH->BNTS", query, key)
    logits = logits.astype(logits_dtype)
    logits *= np.array(scale, dtype=logits.dtype)

    if bias is not None:
        logits = (logits + bias).astype(logits.dtype)

    padded_logits = _apply_masks(logits, mask, is_causal)

    # Softmax and it is always carried out in fp32.
    padded_logits = padded_logits.astype(np.float32)
    probs = softmax(padded_logits, axis=-1).astype(original_dtype)
    encoded_dtype = probs.dtype
    if backend.standardize_dtype(probs.dtype) == "bfloat16":
        # `np.einsum` doesn't support bfloat16
        probs = probs.astype("float32")
        value = value.astype("float32")
    encoded = np.einsum("BNTS,BSNH->BTNH", probs, value)
    encoded = encoded.astype(encoded_dtype)
    return encoded


def dot_product_attention(
    query,
    key,
    value,
    bias=None,
    mask=None,
    scale=None,
    is_causal=False,
    flash_attention=None,
    attn_logits_soft_cap=None,
):
    """Scaled dot-product attention on 4D inputs.

    Args:
        query, key, value: 4D tensors.
        bias: optional tensor added to the logits.
        mask: optional boolean mask; false entries are masked out.
        scale: logit multiplier; defaults to `1 / sqrt(key.shape[-1])`.
        is_causal: whether to also apply a lower-triangular mask.
        flash_attention: must be falsy or `None` on this backend.
        attn_logits_soft_cap: accepted in the signature but not read by this
            implementation.

    Raises:
        ValueError: if `flash_attention` is true or `query` is not 4D.
    """
    if flash_attention is None:
        flash_attention = False
    if flash_attention:
        raise ValueError("Flash attention is not supported in numpy backend.")

    # Ref: jax.nn.dot_product_attention
    # https://github.com/jax-ml/jax/blob/jax-v0.4.32/jax/_src/nn/functions.py#L828
    # Not support `query_seq_lengths` and `key_value_seq_lengths` args
    query = convert_to_tensor(query)
    key = convert_to_tensor(key)
    value = convert_to_tensor(value)
    if len(query.shape) != 4:
        raise ValueError(
            "`dot_product_attention` only supports 4D inputs. "
            f"Received: query.shape={query.shape}, key.shape={key.shape}, "
            f"value.shape={value.shape}."
        )
    _, _, _, H = key.shape
    scale = (1.0 / np.sqrt(H)) if scale is None else scale
    return _dot_product_attention_xla(
        query, key, value, bias, mask, is_causal, scale
    )