# Hosting-page residue (not part of the module), kept for provenance:
# joebruce1313's picture / Upload 38004 files / 1f5470c verified
"""Legacy Keras 1/2 backend functions."""
import itertools
import numpy as np
from keras.src import backend
from keras.src.api_export import keras_export
from keras.src.utils.module_utils import tensorflow as tf
# Keep handles to the Python builtins: this module defines its own `any` and
# `all` functions below, which shadow the builtins at module scope.
py_any = any
py_all = all
@keras_export("keras._legacy.backend.abs")
def abs(x):
    """DEPRECATED. Return `tf.abs(x)`."""
    absolute = tf.abs(x)
    return absolute
@keras_export("keras._legacy.backend.all")
def all(x, axis=None, keepdims=False):
    """DEPRECATED. Logical AND of `x` (cast to bool) reduced over `axis`."""
    as_bool = tf.cast(x, tf.bool)
    return tf.reduce_all(as_bool, axis=axis, keepdims=keepdims)
@keras_export("keras._legacy.backend.any")
def any(x, axis=None, keepdims=False):
    """DEPRECATED. Logical OR of `x` (cast to bool) reduced over `axis`."""
    as_bool = tf.cast(x, tf.bool)
    return tf.reduce_any(as_bool, axis=axis, keepdims=keepdims)
@keras_export("keras._legacy.backend.argmax")
def argmax(x, axis=-1):
    """DEPRECATED. Index of the maximum of `x` along `axis` (`tf.argmax`)."""
    indices = tf.argmax(x, axis=axis)
    return indices
@keras_export("keras._legacy.backend.argmin")
def argmin(x, axis=-1):
    """DEPRECATED. Index of the minimum of `x` along `axis` (`tf.argmin`)."""
    indices = tf.argmin(x, axis=axis)
    return indices
@keras_export("keras._legacy.backend.arange")
def arange(start, stop=None, step=1, dtype="int32"):
    """DEPRECATED. Build a 1-D range tensor with `tf.range`.

    When only `start` is given and it is negative, it is replaced by 0
    before being handed to `tf.range`. The result is cast to `dtype` unless
    `dtype` is `"int32"`.
    """
    lone_negative_bound = stop is None and start < 0
    first = 0 if lone_negative_bound else start
    sequence = tf.range(first, limit=stop, delta=step, name="arange")
    if dtype != "int32":
        return tf.cast(sequence, dtype)
    return sequence
@keras_export("keras._legacy.backend.batch_dot")
def batch_dot(x, y, axes=None):
    """DEPRECATED. Batch-wise dot product of `x` and `y`.

    Both inputs are brought to rank 3 (transposing the contracted axis into
    place and squashing any extra dimensions), multiplied with `tf.matmul`,
    and the result is reshaped back.

    Args:
        x: Tensor of rank >= 2 with a static rank.
        y: Tensor of rank >= 2 with a static rank.
        axes: `None`, an int (used for both inputs) or a pair of ints giving
            the axis of `x` and the axis of `y` to contract. When `None`,
            the last axis of `x` is contracted with the last axis of `y` if
            `y` has rank 2, otherwise with the second-to-last axis of `y`.

    Returns:
        The tensor produced by the batched matmul, reshaped as described in
        the inline comments below.

    Raises:
        ValueError: if either input has rank < 2, if the static batch sizes
            differ, if an element of `axes` is itself a list/tuple, if axis
            0 is targeted, or if the contracted static dimensions differ.
    """
    x_shape = x.shape
    y_shape = y.shape
    x_ndim = len(x_shape)
    y_ndim = len(y_shape)
    if x_ndim < 2 or y_ndim < 2:
        raise ValueError(
            "Cannot do batch_dot on inputs "
            "with rank < 2. "
            "Received inputs with tf.shapes "
            + str(x_shape)
            + " and "
            + str(y_shape)
            + "."
        )
    x_batch_size = x_shape[0]
    y_batch_size = y_shape[0]
    # Batch sizes can only be compared when both are statically known.
    if x_batch_size is not None and y_batch_size is not None:
        if x_batch_size != y_batch_size:
            raise ValueError(
                "Cannot do batch_dot on inputs "
                "with different batch sizes. "
                "Received inputs with tf.shapes "
                + str(x_shape)
                + " and "
                + str(y_shape)
                + "."
            )
    # A single int applies to both inputs.
    if isinstance(axes, int):
        axes = [axes, axes]
    if axes is None:
        if y_ndim == 2:
            axes = [x_ndim - 1, y_ndim - 1]
        else:
            axes = [x_ndim - 1, y_ndim - 2]
    # `py_any` is the Python builtin; `any` is shadowed in this module.
    if py_any(isinstance(a, (list, tuple)) for a in axes):
        raise ValueError(
            "Multiple target dimensions are not supported. "
            + "Expected: None, int, (int, int), "
            + "Provided: "
            + str(axes)
        )
    # if tuple, convert to list.
    axes = list(axes)
    # convert negative indices.
    if axes[0] < 0:
        axes[0] += x_ndim
    if axes[1] < 0:
        axes[1] += y_ndim
    # sanity checks
    if 0 in axes:
        raise ValueError(
            "Cannot perform batch_dot over axis 0. "
            "If your inputs are not batched, "
            "add a dummy batch dimension to your "
            "inputs using K.expand_dims(x, 0)"
        )
    a0, a1 = axes
    d1 = x_shape[a0]
    d2 = y_shape[a1]
    # The contracted dimensions must agree whenever both are static.
    if d1 is not None and d2 is not None and d1 != d2:
        raise ValueError(
            "Cannot do batch_dot on inputs with tf.shapes "
            + str(x_shape)
            + " and "
            + str(y_shape)
            + " with axes="
            + str(axes)
            + ". x.shape[%d] != y.shape[%d] (%d != %d)."
            % (axes[0], axes[1], d1, d2)
        )
    # backup ndims. Need them later.
    orig_x_ndim = x_ndim
    orig_y_ndim = y_ndim
    # if rank is 2, expand to 3.
    if x_ndim == 2:
        x = tf.expand_dims(x, 1)
        # the contracted axis of x moved one position to the right.
        a0 += 1
        x_ndim += 1
    if y_ndim == 2:
        y = tf.expand_dims(y, 2)
        y_ndim += 1
    # bring x's dimension to be reduced to last axis.
    if a0 != x_ndim - 1:
        pattern = list(range(x_ndim))
        for i in range(a0, x_ndim - 1):
            pattern[i] = pattern[i + 1]
        pattern[-1] = a0
        x = tf.transpose(x, pattern)
    # bring y's dimension to be reduced to axis 1.
    if a1 != 1:
        pattern = list(range(y_ndim))
        for i in range(a1, 1, -1):
            pattern[i] = pattern[i - 1]
        pattern[1] = a1
        y = tf.transpose(y, pattern)
    # normalize both inputs to rank 3.
    if x_ndim > 3:
        # squash middle dimensions of x.
        x_shape = tf.shape(x)
        # remembered so the matmul output can be un-squashed afterwards.
        x_mid_dims = x_shape[1:-1]
        x_squashed_shape = tf.stack([x_shape[0], -1, x_shape[-1]])
        x = tf.reshape(x, x_squashed_shape)
        x_squashed = True
    else:
        x_squashed = False
    if y_ndim > 3:
        # squash trailing dimensions of y.
        y_shape = tf.shape(y)
        # remembered so the matmul output can be un-squashed afterwards.
        y_trail_dims = y_shape[2:]
        y_squashed_shape = tf.stack([y_shape[0], y_shape[1], -1])
        y = tf.reshape(y, y_squashed_shape)
        y_squashed = True
    else:
        y_squashed = False
    result = tf.matmul(x, y)
    # if inputs were squashed, we have to reshape the matmul output.
    output_shape = tf.shape(result)
    do_reshape = False
    if x_squashed:
        output_shape = tf.concat(
            [output_shape[:1], x_mid_dims, output_shape[-1:]], 0
        )
        do_reshape = True
    if y_squashed:
        output_shape = tf.concat([output_shape[:-1], y_trail_dims], 0)
        do_reshape = True
    if do_reshape:
        result = tf.reshape(result, output_shape)
    # if the inputs were originally rank 2, we remove the added 1 dim.
    # Note: only one squeeze is applied; the x case takes precedence.
    if orig_x_ndim == 2:
        result = tf.squeeze(result, 1)
    elif orig_y_ndim == 2:
        result = tf.squeeze(result, -1)
    return result
@keras_export("keras._legacy.backend.batch_flatten")
def batch_flatten(x):
    """DEPRECATED. Reshape `x` to 2-D, keeping the size of dims 1+ as dim 1.

    The second dimension is the product of the dynamic shape of `x` from
    axis 1 onward; the first dimension is inferred (`-1`).
    """
    per_sample_size = prod(tf.shape(x)[1:])
    flat_shape = tf.stack([-1, per_sample_size])
    return tf.reshape(x, flat_shape)
@keras_export("keras._legacy.backend.batch_get_value")
def batch_get_value(tensors):
    """DEPRECATED. Return a list with `.numpy()` of every item in `tensors`."""
    values = []
    for tensor in tensors:
        values.append(tensor.numpy())
    return values
@keras_export("keras._legacy.backend.batch_set_value")
def batch_set_value(tuples):
    """DEPRECATED. Assign each `(variable, value)` pair in `tuples`.

    Values are converted to NumPy arrays of the variable's dtype before
    `assign`. Nothing happens unless TF is executing eagerly or the call is
    made inside a `tf.function`.
    """
    if not (tf.executing_eagerly() or tf.inside_function()):
        return
    for target, new_value in tuples:
        target.assign(np.asarray(new_value, dtype=target.dtype.name))
@keras_export("keras._legacy.backend.batch_normalization")
def batch_normalization(x, mean, var, beta, gamma, axis=-1, epsilon=1e-3):
    """DEPRECATED. Apply `tf.nn.batch_normalization` to `x`.

    `beta` is used as the offset and `gamma` as the scale. The `axis`
    argument is accepted but not used by this implementation.
    """
    normalized = tf.nn.batch_normalization(
        x,
        mean=mean,
        variance=var,
        offset=beta,
        scale=gamma,
        variance_epsilon=epsilon,
    )
    return normalized
@keras_export("keras._legacy.backend.bias_add")
def bias_add(x, bias, data_format=None):
    """DEPRECATED. Add `bias` to `x`.

    A rank-1 bias goes through `tf.nn.bias_add` with the TF data format
    matching `data_format`. A bias of rank `ndim(x) - 1` is reshaped with a
    leading 1 (and, for `channels_first`, with its last dim moved to axis 1)
    and added by broadcasting when `x` has rank 3, 4 or 5.

    Raises:
        ValueError: on an unknown `data_format` or when the bias rank is
            neither 1 nor `ndim(x) - 1`.
    """
    if data_format is None:
        data_format = backend.image_data_format()
    if data_format not in {"channels_first", "channels_last"}:
        raise ValueError(f"Unknown data_format: {data_format}")
    channels_first = data_format == "channels_first"
    bias_shape = bias.shape
    bias_rank = len(bias_shape)
    # `ndim(x)` is only consulted when the bias is not rank 1.
    if bias_rank != 1 and bias_rank != ndim(x) - 1:
        raise ValueError(
            f"Unexpected bias dimensions {len(bias_shape)}. "
            f"Expected it to be 1 or {ndim(x) - 1} dimensions"
        )
    if bias_rank == 1:
        tf_format = "NCHW" if channels_first else "NHWC"
        return tf.nn.bias_add(x, bias, data_format=tf_format)
    if ndim(x) in (3, 4, 5):
        if channels_first:
            # move the channel dim of the bias right after the batch dim.
            broadcast_shape = (1, bias_shape[-1]) + bias_shape[:-1]
        else:
            broadcast_shape = (1,) + bias_shape
        return x + reshape(bias, broadcast_shape)
    return tf.nn.bias_add(x, bias)
@keras_export("keras._legacy.backend.binary_crossentropy")
def binary_crossentropy(target, output, from_logits=False):
    """DEPRECATED. Element-wise binary crossentropy.

    With `from_logits=True` this delegates to
    `tf.nn.sigmoid_cross_entropy_with_logits`. Otherwise `output` is clipped
    to `[epsilon, 1 - epsilon]` and the crossentropy is computed from the
    clipped probabilities (with epsilon added inside each log).
    """
    labels = tf.convert_to_tensor(target)
    preds = tf.convert_to_tensor(output)
    if from_logits:
        return tf.nn.sigmoid_cross_entropy_with_logits(
            labels=labels, logits=preds
        )
    eps_tensor = tf.convert_to_tensor(backend.epsilon(), preds.dtype)
    preds = tf.clip_by_value(preds, eps_tensor, 1.0 - eps_tensor)
    # Compute cross entropy from probabilities.
    positive_term = labels * tf.math.log(preds + backend.epsilon())
    negative_term = (1 - labels) * tf.math.log(
        1 - preds + backend.epsilon()
    )
    return -(positive_term + negative_term)
@keras_export("keras._legacy.backend.binary_focal_crossentropy")
def binary_focal_crossentropy(
    target,
    output,
    apply_class_balancing=False,
    alpha=0.25,
    gamma=2.0,
    from_logits=False,
):
    """DEPRECATED. Binary crossentropy scaled by a focal factor.

    The focal factor is `(1 - p_t) ** gamma`, where `p_t` mixes the
    (sigmoid of the) output and its complement according to `target`. When
    `apply_class_balancing` is truthy the loss is additionally weighted by
    `alpha` for targets and `1 - alpha` for their complement.
    """
    if from_logits:
        probs = tf.sigmoid(output)
    else:
        probs = output
    p_t = target * probs + (1 - target) * (1 - probs)
    # Down-weights well-classified elements.
    modulation = tf.pow(1.0 - p_t, gamma)
    plain_bce = binary_crossentropy(
        target=target,
        output=output,
        from_logits=from_logits,
    )
    loss = modulation * plain_bce
    if not apply_class_balancing:
        return loss
    class_weight = target * alpha + (1 - target) * (1 - alpha)
    return class_weight * loss
@keras_export("keras._legacy.backend.cast")
def cast(x, dtype):
    """DEPRECATED. Cast `x` to `dtype` with `tf.cast`."""
    converted = tf.cast(x, dtype=dtype)
    return converted
@keras_export("keras._legacy.backend.cast_to_floatx")
def cast_to_floatx(x):
    """DEPRECATED. Convert `x` to the backend default float type.

    TF tensors, variables and sparse tensors are cast with `tf.cast`;
    anything else is converted to a NumPy array of that dtype.
    """
    float_type = backend.floatx()
    tf_types = (tf.Tensor, tf.Variable, tf.SparseTensor)
    if not isinstance(x, tf_types):
        return np.asarray(x, dtype=float_type)
    return tf.cast(x, dtype=float_type)
@keras_export("keras._legacy.backend.categorical_crossentropy")
def categorical_crossentropy(target, output, from_logits=False, axis=-1):
    """DEPRECATED. Categorical crossentropy between `target` and `output`.

    With `from_logits=True` this delegates to
    `tf.nn.softmax_cross_entropy_with_logits`. Otherwise `output` is
    renormalized along `axis`, clipped to `[epsilon, 1 - epsilon]`, and
    `-sum(target * log(output))` is reduced over `axis`. The static shapes
    of both inputs must be compatible.
    """
    labels = tf.convert_to_tensor(target)
    preds = tf.convert_to_tensor(output)
    labels.shape.assert_is_compatible_with(preds.shape)
    if from_logits:
        return tf.nn.softmax_cross_entropy_with_logits(
            labels=labels, logits=preds, axis=axis
        )
    # Rescale so that the class probabilities of each sample sum to 1;
    # the crossentropy below relies on that.
    normalizer = tf.reduce_sum(preds, axis=axis, keepdims=True)
    preds = preds / normalizer
    # Compute cross entropy from probabilities.
    eps_tensor = tf.convert_to_tensor(backend.epsilon(), preds.dtype)
    preds = tf.clip_by_value(preds, eps_tensor, 1.0 - eps_tensor)
    weighted_log = labels * tf.math.log(preds)
    return -tf.reduce_sum(weighted_log, axis=axis)
@keras_export("keras._legacy.backend.categorical_focal_crossentropy")
def categorical_focal_crossentropy(
    target,
    output,
    alpha=0.25,
    gamma=2.0,
    from_logits=False,
    axis=-1,
):
    """DEPRECATED. Categorical crossentropy with focal weighting.

    `output` is softmaxed first when `from_logits` is set, then
    renormalized along `axis` and clipped to `[epsilon, 1 - epsilon]`. Each
    crossentropy term is multiplied by `alpha * (1 - output) ** gamma`
    before being summed over `axis`.
    """
    labels = tf.convert_to_tensor(target)
    probs = tf.convert_to_tensor(output)
    labels.shape.assert_is_compatible_with(probs.shape)
    if from_logits:
        probs = tf.nn.softmax(probs, axis=axis)
    # Rescale so that the class probabilities of each sample sum to 1;
    # the crossentropy below relies on that.
    probs = probs / tf.reduce_sum(probs, axis=axis, keepdims=True)
    eps_tensor = tf.convert_to_tensor(backend.epsilon(), probs.dtype)
    probs = tf.clip_by_value(probs, eps_tensor, 1.0 - eps_tensor)
    # Per-class crossentropy terms.
    per_class_ce = -labels * tf.math.log(probs)
    # Focal modulation and alpha weighting.
    modulation = tf.pow(1.0 - probs, gamma)
    weights = tf.multiply(modulation, alpha)
    weighted_ce = tf.multiply(weights, per_class_ce)
    return tf.reduce_sum(weighted_ce, axis=axis)
@keras_export("keras._legacy.backend.clip")
def clip(x, min_value, max_value):
    """DEPRECATED. Clip `x` to `[min_value, max_value]`.

    When both bounds are Python numbers and `max_value < min_value`, the
    upper bound is raised to `min_value`. A `None` bound is replaced by the
    matching infinity.
    """
    numeric = (int, float)
    both_numeric = isinstance(min_value, numeric) and isinstance(
        max_value, numeric
    )
    if both_numeric and max_value < min_value:
        max_value = min_value
    lower = -np.inf if min_value is None else min_value
    upper = np.inf if max_value is None else max_value
    return tf.clip_by_value(x, lower, upper)
@keras_export("keras._legacy.backend.concatenate")
def concatenate(tensors, axis=-1):
    """DEPRECATED. Concatenate `tensors` along `axis`.

    A negative `axis` is resolved against the rank of the first tensor (or
    set to 0 when that rank is falsy). All-sparse inputs use
    `tf.compat.v1.sparse_concat`, all-ragged inputs use `tf.concat`
    directly, and any other mix is densified with `to_dense` first.
    """
    if axis < 0:
        first_rank = ndim(tensors[0])
        axis = axis % first_rank if first_rank else 0
    # `py_all` is the Python builtin; `all` is shadowed in this module.
    if py_all(is_sparse(t) for t in tensors):
        return tf.compat.v1.sparse_concat(axis, tensors)
    if py_all(isinstance(t, tf.RaggedTensor) for t in tensors):
        return tf.concat(tensors, axis)
    dense_inputs = [to_dense(t) for t in tensors]
    return tf.concat(dense_inputs, axis)
@keras_export("keras._legacy.backend.constant")
def constant(value, dtype=None, shape=None, name=None):
    """DEPRECATED. Create a `tf.constant`, defaulting to the backend floatx."""
    resolved_dtype = backend.floatx() if dtype is None else dtype
    return tf.constant(value, dtype=resolved_dtype, shape=shape, name=name)
def _preprocess_conv1d_input(x, data_format):
tf_data_format = "NWC" # to pass TF Conv2dNative operations
if data_format == "channels_first":
tf_data_format = "NCW"
return x, tf_data_format
def _preprocess_conv2d_input(x, data_format, force_transpose=False):
tf_data_format = "NHWC"
if data_format == "channels_first":
if force_transpose:
x = tf.transpose(x, (0, 2, 3, 1)) # NCHW -> NHWC
else:
tf_data_format = "NCHW"
return x, tf_data_format
def _preprocess_conv3d_input(x, data_format):
tf_data_format = "NDHWC"
if data_format == "channels_first":
tf_data_format = "NCDHW"
return x, tf_data_format
def _preprocess_padding(padding):
if padding == "same":
padding = "SAME"
elif padding == "valid":
padding = "VALID"
else:
raise ValueError(f"Invalid padding: {padding}")
return padding
@keras_export("keras._legacy.backend.conv1d")
def conv1d(
    x, kernel, strides=1, padding="valid", data_format=None, dilation_rate=1
):
    """DEPRECATED. 1-D convolution via `tf.compat.v1.nn.convolution`.

    `padding="causal"` is implemented by left-padding `x` with
    `dilation_rate * (kernel_width - 1)` steps through `temporal_padding`
    and then running a `"valid"` convolution.

    Raises:
        ValueError: on an unknown `data_format` or padding mode.
    """
    if data_format is None:
        data_format = backend.image_data_format()
    if data_format not in {"channels_first", "channels_last"}:
        raise ValueError(f"Unknown data_format: {data_format}")
    kernel_dims = kernel.shape.as_list()
    if padding == "causal":
        # causal (dilated) convolution: pad only on the left.
        causal_pad = dilation_rate * (kernel_dims[0] - 1)
        x = temporal_padding(x, (causal_pad, 0))
        padding = "valid"
    tf_padding = _preprocess_padding(padding)
    x, tf_data_format = _preprocess_conv1d_input(x, data_format)
    outputs = tf.compat.v1.nn.convolution(
        input=x,
        filter=kernel,
        dilation_rate=dilation_rate,
        strides=strides,
        padding=tf_padding,
        data_format=tf_data_format,
    )
    computed_channels_last = tf_data_format == "NWC"
    if data_format == "channels_first" and computed_channels_last:
        outputs = tf.transpose(outputs, (0, 2, 1))  # NWC -> NCW
    return outputs
@keras_export("keras._legacy.backend.conv2d")
def conv2d(
    x,
    kernel,
    strides=(1, 1),
    padding="valid",
    data_format=None,
    dilation_rate=(1, 1),
):
    """DEPRECATED. 2-D convolution via `tf.compat.v1.nn.convolution`.

    Raises:
        ValueError: on an unknown `data_format` or padding mode.
    """
    if data_format is None:
        data_format = backend.image_data_format()
    if data_format not in {"channels_first", "channels_last"}:
        raise ValueError(f"Unknown data_format: {data_format}")
    x, tf_data_format = _preprocess_conv2d_input(x, data_format)
    tf_padding = _preprocess_padding(padding)
    outputs = tf.compat.v1.nn.convolution(
        input=x,
        filter=kernel,
        dilation_rate=dilation_rate,
        strides=strides,
        padding=tf_padding,
        data_format=tf_data_format,
    )
    computed_channels_last = tf_data_format == "NHWC"
    if data_format == "channels_first" and computed_channels_last:
        outputs = tf.transpose(outputs, (0, 3, 1, 2))  # NHWC -> NCHW
    return outputs
@keras_export("keras._legacy.backend.conv2d_transpose")
def conv2d_transpose(
    x,
    kernel,
    output_shape,
    strides=(1, 1),
    padding="valid",
    data_format=None,
    dilation_rate=(1, 1),
):
    """DEPRECATED. 2-D transposed convolution.

    Uses `tf.compat.v1.nn.conv2d_transpose` when `dilation_rate` is
    `(1, 1)` and `tf.nn.atrous_conv2d_transpose` otherwise. The dilated
    path forces an NHWC computation for `channels_first` inputs and
    transposes the result back.

    Args:
        x: Input tensor.
        kernel: Kernel tensor.
        output_shape: Sequence (or tensor) giving the output shape in the
            layout of `data_format`; a `None` batch entry is replaced by the
            dynamic batch size of `x`.
        strides: Sequence of 2 ints (tuple or list).
        padding: `"same"` or `"valid"`.
        data_format: `"channels_first"`, `"channels_last"` or `None` for
            the backend default.
        dilation_rate: Sequence of 2 ints (tuple or list); both entries must
            be equal when dilation is used.

    Raises:
        ValueError: on an unknown `data_format` or padding mode, or when the
            two dilation rates differ.
    """
    if data_format is None:
        data_format = backend.image_data_format()
    if data_format not in {"channels_first", "channels_last"}:
        raise ValueError(f"Unknown data_format: {data_format}")
    # Normalize to tuples: a list such as `[1, 1]` would otherwise compare
    # unequal to `(1, 1)` (wrongly selecting the dilated path) and would
    # fail in the tuple concatenations on `strides` below.
    strides = tuple(strides)
    dilation_rate = tuple(dilation_rate)
    # `atrous_conv2d_transpose` only supports NHWC format, even on GPU.
    force_transpose = (
        data_format == "channels_first" and dilation_rate != (1, 1)
    )
    x, tf_data_format = _preprocess_conv2d_input(
        x, data_format, force_transpose
    )
    if data_format == "channels_first" and tf_data_format == "NHWC":
        # The requested shape is NCHW; reorder it to match the NHWC compute.
        output_shape = (
            output_shape[0],
            output_shape[2],
            output_shape[3],
            output_shape[1],
        )
    if output_shape[0] is None:
        output_shape = (tf.shape(x)[0],) + tuple(output_shape[1:])
    if isinstance(output_shape, (tuple, list)):
        output_shape = tf.stack(list(output_shape))
    padding = _preprocess_padding(padding)
    if tf_data_format == "NHWC":
        strides = (1,) + strides + (1,)
    else:
        strides = (1, 1) + strides
    if dilation_rate == (1, 1):
        x = tf.compat.v1.nn.conv2d_transpose(
            x,
            kernel,
            output_shape,
            strides,
            padding=padding,
            data_format=tf_data_format,
        )
    else:
        if dilation_rate[0] != dilation_rate[1]:
            raise ValueError(
                "Expected the 2 dimensions of the `dilation_rate` argument "
                "to be equal to each other. "
                f"Received: dilation_rate={dilation_rate}"
            )
        x = tf.nn.atrous_conv2d_transpose(
            x, kernel, output_shape, rate=dilation_rate[0], padding=padding
        )
    if data_format == "channels_first" and tf_data_format == "NHWC":
        x = tf.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW
    return x
@keras_export("keras._legacy.backend.conv3d")
def conv3d(
    x,
    kernel,
    strides=(1, 1, 1),
    padding="valid",
    data_format=None,
    dilation_rate=(1, 1, 1),
):
    """DEPRECATED. 3-D convolution via `tf.compat.v1.nn.convolution`.

    Raises:
        ValueError: on an unknown `data_format` or padding mode.
    """
    if data_format is None:
        data_format = backend.image_data_format()
    if data_format not in {"channels_first", "channels_last"}:
        raise ValueError(f"Unknown data_format: {data_format}")
    x, tf_data_format = _preprocess_conv3d_input(x, data_format)
    tf_padding = _preprocess_padding(padding)
    outputs = tf.compat.v1.nn.convolution(
        input=x,
        filter=kernel,
        dilation_rate=dilation_rate,
        strides=strides,
        padding=tf_padding,
        data_format=tf_data_format,
    )
    computed_channels_last = tf_data_format == "NDHWC"
    if data_format == "channels_first" and computed_channels_last:
        # NDHWC -> NCDHW
        outputs = tf.transpose(outputs, (0, 4, 1, 2, 3))
    return outputs
@keras_export("keras._legacy.backend.cos")
def cos(x):
    """DEPRECATED. Return `tf.cos(x)`."""
    cosine = tf.cos(x)
    return cosine
@keras_export("keras._legacy.backend.count_params")
def count_params(x):
    """DEPRECATED. Product of the static dimensions of `x` (`np.prod`)."""
    static_dims = x.shape.as_list()
    return np.prod(static_dims)
@keras_export("keras._legacy.backend.ctc_batch_cost")
def ctc_batch_cost(y_true, y_pred, input_length, label_length):
    """DEPRECATED. CTC loss for each batch element.

    The last axis of both length arguments is squeezed away and they are
    cast to int32. `y_true` is converted to a sparse tensor, `y_pred` is
    transposed with `perm=[1, 0, 2]` and passed through
    `log(. + epsilon)`, and the loss from `tf.compat.v1.nn.ctc_loss` is
    returned with an extra axis added at position 1.
    """
    label_lengths = tf.cast(tf.squeeze(label_length, axis=-1), tf.int32)
    input_lengths = tf.cast(tf.squeeze(input_length, axis=-1), tf.int32)
    sparse_labels = tf.cast(
        ctc_label_dense_to_sparse(y_true, label_lengths), tf.int32
    )
    swapped = tf.transpose(y_pred, perm=[1, 0, 2])
    log_probs = tf.math.log(swapped + backend.epsilon())
    per_sample_loss = tf.compat.v1.nn.ctc_loss(
        inputs=log_probs,
        labels=sparse_labels,
        sequence_length=input_lengths,
    )
    return tf.expand_dims(per_sample_loss, 1)
@keras_export("keras._legacy.backend.ctc_label_dense_to_sparse")
def ctc_label_dense_to_sparse(labels, label_lengths):
    """DEPRECATED. Convert dense CTC labels to a `tf.SparseTensor`.

    A boolean mask is built in which, for every batch row, the positions
    below that row's label length are set. The (batch, position) indices of
    the set entries and the matching values of `labels` form the sparse
    tensor.

    Args:
        labels: Dense label tensor; indexed as rank 2 below.
        label_lengths: Per-row label lengths, iterated by the scan below
            (assumed 1-D with one entry per batch row — TODO confirm).

    Returns:
        A `tf.SparseTensor` with int64 indices and dense shape equal to the
        dynamic shape of `labels`.
    """
    label_shape = tf.shape(labels)
    num_batches_tns = tf.stack([label_shape[0]])
    max_num_labels_tns = tf.stack([label_shape[1]])

    def range_less_than(old_input, current_input):
        # One mask row: True where the position index < current length.
        return tf.expand_dims(tf.range(tf.shape(old_input)[1]), 0) < tf.fill(
            max_num_labels_tns, current_input
        )

    # The initializer only provides the shape/dtype of a mask row.
    init = tf.cast(tf.fill([1, label_shape[1]], 0), tf.bool)
    dense_mask = tf.compat.v1.scan(
        range_less_than, label_lengths, initializer=init, parallel_iterations=1
    )
    # Drop the singleton axis introduced by `expand_dims` in the scan body.
    dense_mask = dense_mask[:, 0, :]
    # Position index of every entry, laid out like `labels`.
    label_array = tf.reshape(
        tf.tile(tf.range(0, label_shape[1]), num_batches_tns), label_shape
    )
    label_ind = tf.compat.v1.boolean_mask(label_array, dense_mask)
    # Batch index of every entry, laid out like `labels`.
    batch_array = tf.transpose(
        tf.reshape(
            tf.tile(tf.range(0, label_shape[0]), max_num_labels_tns),
            reverse(label_shape, 0),
        )
    )
    batch_ind = tf.compat.v1.boolean_mask(batch_array, dense_mask)
    # Pair up (batch, position) into one row per retained entry.
    indices = tf.transpose(
        tf.reshape(concatenate([batch_ind, label_ind], axis=0), [2, -1])
    )
    vals_sparse = tf.compat.v1.gather_nd(labels, indices)
    return tf.SparseTensor(
        tf.cast(indices, tf.int64), vals_sparse, tf.cast(label_shape, tf.int64)
    )
@keras_export("keras._legacy.backend.ctc_decode")
def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1):
    """DEPRECATED. Decode CTC outputs with greedy or beam search.

    `y_pred` is transposed with `perm=[1, 0, 2]` and passed through
    `log(. + epsilon)` before decoding. Each decoded sparse tensor is
    re-wrapped with a dense shape taken from the first two dynamic
    dimensions of `y_pred` and densified with `-1` as the fill value.

    Returns:
        A `(decoded_dense, log_prob)` tuple, where `decoded_dense` is a
        list of dense tensors.
    """
    dynamic_shape = tf.shape(y_pred)
    dense_shape = (dynamic_shape[0], dynamic_shape[1])
    log_probs = tf.math.log(
        tf.transpose(y_pred, perm=[1, 0, 2]) + backend.epsilon()
    )
    lengths = tf.cast(input_length, tf.int32)
    if greedy:
        decoded, log_prob = tf.nn.ctc_greedy_decoder(
            inputs=log_probs, sequence_length=lengths
        )
    else:
        decoded, log_prob = tf.compat.v1.nn.ctc_beam_search_decoder(
            inputs=log_probs,
            sequence_length=lengths,
            beam_width=beam_width,
            top_paths=top_paths,
        )
    decoded_dense = [
        tf.sparse.to_dense(
            sp_input=tf.SparseTensor(path.indices, path.values, dense_shape),
            default_value=-1,
        )
        for path in decoded
    ]
    return (decoded_dense, log_prob)
@keras_export("keras._legacy.backend.cumsum")
def cumsum(x, axis=0):
    """DEPRECATED. Cumulative sum of `x` along `axis` (`tf.cumsum`)."""
    running_total = tf.cumsum(x, axis=axis)
    return running_total
@keras_export("keras._legacy.backend.cumprod")
def cumprod(x, axis=0):
    """DEPRECATED. Cumulative product of `x` along `axis`."""
    running_product = tf.math.cumprod(x, axis=axis)
    return running_product
@keras_export("keras._legacy.backend.depthwise_conv2d")
def depthwise_conv2d(
    x,
    depthwise_kernel,
    strides=(1, 1),
    padding="valid",
    data_format=None,
    dilation_rate=(1, 1),
):
    """DEPRECATED. 2-D depthwise convolution via `tf.nn.depthwise_conv2d`.

    Args:
        x: Input tensor.
        depthwise_kernel: Depthwise kernel tensor.
        strides: Sequence of 2 ints (tuple or list); expanded to the
            4-element form of the selected TF data format.
        padding: `"same"` or `"valid"`.
        data_format: `"channels_first"`, `"channels_last"` or `None` for
            the backend default.
        dilation_rate: Passed through as `dilations`.

    Raises:
        ValueError: on an unknown `data_format` or padding mode.
    """
    if data_format is None:
        data_format = backend.image_data_format()
    if data_format not in {"channels_first", "channels_last"}:
        raise ValueError(f"Unknown data_format: {data_format}")
    x, tf_data_format = _preprocess_conv2d_input(x, data_format)
    padding = _preprocess_padding(padding)
    # Accept lists as well: tuple + list concatenation would raise TypeError.
    strides = tuple(strides)
    if tf_data_format == "NHWC":
        strides = (1,) + strides + (1,)
    else:
        strides = (1, 1) + strides
    x = tf.nn.depthwise_conv2d(
        x,
        depthwise_kernel,
        strides=strides,
        padding=padding,
        dilations=dilation_rate,
        data_format=tf_data_format,
    )
    if data_format == "channels_first" and tf_data_format == "NHWC":
        x = tf.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW
    return x
@keras_export("keras._legacy.backend.dot")
def dot(x, y):
    """DEPRECATED. Multiply `x` by `y`.

    When either input has rank > 2, the last axis of `x` is contracted with
    the second-to-last axis of `y` by flattening both to 2-D, running
    `tf.matmul`, and reshaping back. Otherwise a plain (sparse-aware)
    matmul is used.
    """

    def _mixed_shape(tensor):
        # Static dims where known, dynamic scalar tensors elsewhere.
        return tuple(
            static if static is not None else dynamic
            for static, dynamic in zip(
                tensor.shape, tf.unstack(tf.shape(tensor))
            )
        )

    if ndim(x) is not None and (ndim(x) > 2 or ndim(y) > 2):
        x_shape = _mixed_shape(x)
        y_shape = _mixed_shape(y)
        # Move the contracted (second-to-last) axis of y to the front.
        y_axes = list(range(ndim(y)))
        contracted_axis = y_axes.pop(-2)
        y_perm = [contracted_axis] + y_axes
        x_2d = tf.reshape(x, [-1, x_shape[-1]])
        y_2d = tf.reshape(tf.transpose(y, perm=y_perm), [y_shape[-2], -1])
        product = tf.matmul(x_2d, y_2d)
        return tf.reshape(
            product, x_shape[:-1] + y_shape[:-2] + y_shape[-1:]
        )
    if is_sparse(x):
        return tf.sparse.sparse_dense_matmul(x, y)
    return tf.matmul(x, y)
@keras_export("keras._legacy.backend.dropout")
def dropout(x, level, noise_shape=None, seed=None):
    """DEPRECATED. Apply `tf.nn.dropout` with rate `level`.

    When `seed` is `None` one is drawn from NumPy's global RNG.
    """
    op_seed = np.random.randint(10e6) if seed is None else seed
    return tf.nn.dropout(x, rate=level, noise_shape=noise_shape, seed=op_seed)
@keras_export("keras._legacy.backend.dtype")
def dtype(x):
    """DEPRECATED. Name of the base dtype of `x`, as a string."""
    base = x.dtype.base_dtype
    return base.name
@keras_export("keras._legacy.backend.elu")
def elu(x, alpha=1.0):
    """DEPRECATED. Exponential linear unit.

    `tf.nn.elu` is used as-is when `alpha == 1`; otherwise its output is
    scaled by `alpha` wherever `x` is not positive.
    """
    activated = tf.nn.elu(x)
    if alpha == 1:
        return activated
    return tf.where(x > 0, activated, alpha * activated)
@keras_export("keras._legacy.backend.equal")
def equal(x, y):
    """DEPRECATED. Element-wise `x == y` (`tf.equal`)."""
    comparison = tf.equal(x, y)
    return comparison
@keras_export("keras._legacy.backend.eval")
def eval(x):
    """DEPRECATED. Densify `x` with `to_dense` and return its value."""
    dense = to_dense(x)
    return get_value(dense)
@keras_export("keras._legacy.backend.exp")
def exp(x):
    """DEPRECATED. Return `tf.exp(x)`."""
    exponential = tf.exp(x)
    return exponential
@keras_export("keras._legacy.backend.expand_dims")
def expand_dims(x, axis=-1):
    """DEPRECATED. Insert a size-1 dimension at `axis` (`tf.expand_dims`)."""
    expanded = tf.expand_dims(x, axis=axis)
    return expanded
@keras_export("keras._legacy.backend.eye")
def eye(size, dtype=None, name=None):
    """DEPRECATED. Identity matrix wrapped with `variable`.

    `dtype` defaults to the backend floatx.
    """
    if dtype is None:
        dtype = backend.floatx()
    identity = tf.eye(size, dtype=tf.as_dtype(dtype))
    return variable(identity, dtype, name)
@keras_export("keras._legacy.backend.flatten")
def flatten(x):
    """DEPRECATED. Reshape `x` to a single dimension."""
    one_dimensional = [-1]
    return tf.reshape(x, one_dimensional)
@keras_export("keras._legacy.backend.foldl")
def foldl(fn, elems, initializer=None, name=None):
    """DEPRECATED. Thin wrapper around `tf.compat.v1.foldl`."""
    folded = tf.compat.v1.foldl(
        fn, elems, initializer=initializer, name=name
    )
    return folded
@keras_export("keras._legacy.backend.foldr")
def foldr(fn, elems, initializer=None, name=None):
    """DEPRECATED. Thin wrapper around `tf.compat.v1.foldr`."""
    folded = tf.compat.v1.foldr(
        fn, elems, initializer=initializer, name=name
    )
    return folded
@keras_export("keras._legacy.backend.gather")
def gather(reference, indices):
    """DEPRECATED. Gather `indices` from `reference` (`tf.compat.v1.gather`)."""
    gathered = tf.compat.v1.gather(reference, indices)
    return gathered
@keras_export("keras._legacy.backend.get_value")
def get_value(x):
    """DEPRECATED. Return the value of `x` as a NumPy array.

    Non-tensors are returned unchanged. In eager execution (or for eager
    tensors) `.numpy()` is called directly; otherwise it is called inside
    an eager-mode context for objects whose `_in_graph_mode` is falsy, and
    inside `tf.init_scope()` for everything else.
    """
    if not tf.is_tensor(x):
        return x
    runs_eagerly = tf.executing_eagerly() or isinstance(
        x, tf.__internal__.EagerTensor
    )
    if runs_eagerly:
        return x.numpy()
    if getattr(x, "_in_graph_mode", True):
        scope = tf.init_scope()
    else:
        # This is a variable which was created in an eager context, but is
        # being evaluated from a Graph.
        scope = tf.__internal__.eager_context.eager_mode()
    with scope:
        return x.numpy()
@keras_export("keras._legacy.backend.gradients")
def gradients(loss, variables):
    """DEPRECATED. Gradients of `loss` w.r.t. `variables`.

    Uses `tf.compat.v1.gradients` with `colocate_gradients_with_ops=True`.
    """
    grads = tf.compat.v1.gradients(
        loss, variables, colocate_gradients_with_ops=True
    )
    return grads
@keras_export("keras._legacy.backend.greater")
def greater(x, y):
    """DEPRECATED. Element-wise `x > y` (`tf.greater`)."""
    comparison = tf.greater(x, y)
    return comparison
@keras_export("keras._legacy.backend.greater_equal")
def greater_equal(x, y):
    """DEPRECATED. Element-wise `x >= y` (`tf.greater_equal`)."""
    comparison = tf.greater_equal(x, y)
    return comparison
@keras_export("keras._legacy.backend.hard_sigmoid")
def hard_sigmoid(x):
    """DEPRECATED. Piecewise-linear sigmoid: `clip(0.2 * x + 0.5, 0, 1)`."""
    slope = tf.convert_to_tensor(0.2, dtype=x.dtype)
    offset = tf.convert_to_tensor(0.5, dtype=x.dtype)
    scaled = tf.multiply(x, slope)
    shifted = tf.add(scaled, offset)
    return tf.clip_by_value(shifted, 0.0, 1.0)
@keras_export("keras._legacy.backend.in_top_k")
def in_top_k(predictions, targets, k):
    """DEPRECATED. Thin wrapper around `tf.compat.v1.math.in_top_k`."""
    hits = tf.compat.v1.math.in_top_k(predictions, targets, k)
    return hits
@keras_export("keras._legacy.backend.int_shape")
def int_shape(x):
    """DEPRECATED. Static shape of `x` as a tuple.

    Returns `None` when reading or converting the shape raises
    `ValueError`.
    """
    try:
        static_shape = x.shape
        if isinstance(static_shape, tuple):
            return static_shape
        return tuple(static_shape.as_list())
    except ValueError:
        return None
@keras_export("keras._legacy.backend.is_sparse")
def is_sparse(tensor):
    """DEPRECATED. Whether `tensor` is sparse.

    An object carrying a non-`None` `_type_spec` is judged by that spec;
    anything else by an `isinstance` check against `tf.SparseTensor`.
    """
    type_spec = getattr(tensor, "_type_spec", None)
    if type_spec is None:
        return isinstance(tensor, tf.SparseTensor)
    return isinstance(type_spec, tf.SparseTensorSpec)
@keras_export("keras._legacy.backend.l2_normalize")
def l2_normalize(x, axis=None):
    """DEPRECATED. L2-normalize `x` along `axis`."""
    normalized = tf.linalg.l2_normalize(x, axis=axis)
    return normalized
@keras_export("keras._legacy.backend.less")
def less(x, y):
    """DEPRECATED. Element-wise `x < y` (`tf.less`)."""
    comparison = tf.less(x, y)
    return comparison
@keras_export("keras._legacy.backend.less_equal")
def less_equal(x, y):
    """DEPRECATED. Element-wise `x <= y` (`tf.less_equal`)."""
    comparison = tf.less_equal(x, y)
    return comparison
@keras_export("keras._legacy.backend.log")
def log(x):
    """DEPRECATED. Return `tf.math.log(x)`."""
    logarithm = tf.math.log(x)
    return logarithm
@keras_export("keras._legacy.backend.map_fn")
def map_fn(fn, elems, name=None, dtype=None):
    """DEPRECATED. Thin wrapper around `tf.compat.v1.map_fn`."""
    mapped = tf.compat.v1.map_fn(fn, elems, name=name, dtype=dtype)
    return mapped
@keras_export("keras._legacy.backend.max")
def max(x, axis=None, keepdims=False):
    """DEPRECATED. Maximum of `x` reduced over `axis` (`tf.reduce_max`)."""
    reduced = tf.reduce_max(x, axis=axis, keepdims=keepdims)
    return reduced
@keras_export("keras._legacy.backend.maximum")
def maximum(x, y):
    """DEPRECATED. Element-wise maximum of `x` and `y` (`tf.maximum`)."""
    larger = tf.maximum(x, y)
    return larger
@keras_export("keras._legacy.backend.mean")
def mean(x, axis=None, keepdims=False):
    """DEPRECATED. Mean of `x` reduced over `axis`.

    Boolean inputs are cast to the backend floatx first.
    """
    is_boolean = x.dtype.base_dtype == tf.bool
    values = tf.cast(x, backend.floatx()) if is_boolean else x
    return tf.reduce_mean(values, axis=axis, keepdims=keepdims)
@keras_export("keras._legacy.backend.min")
def min(x, axis=None, keepdims=False):
    """DEPRECATED. Minimum of `x` reduced over `axis` (`tf.reduce_min`)."""
    reduced = tf.reduce_min(x, axis=axis, keepdims=keepdims)
    return reduced
@keras_export("keras._legacy.backend.minimum")
def minimum(x, y):
    """DEPRECATED. Element-wise minimum of `x` and `y` (`tf.minimum`)."""
    smaller = tf.minimum(x, y)
    return smaller
@keras_export("keras._legacy.backend.moving_average_update")
def moving_average_update(x, value, momentum):
    """DEPRECATED. Update `x` in place towards `value`.

    Subtracts `(x - value) * (1 - momentum)` from `x` via `assign_sub`,
    after casting `value` and `momentum` to the dtype of `x`.
    """
    momentum_cast = tf.cast(momentum, x.dtype)
    value_cast = tf.cast(value, x.dtype)
    decrement = (x - value_cast) * (1 - momentum_cast)
    return x.assign_sub(decrement)
@keras_export("keras._legacy.backend.name_scope")
def name_scope(name):
    """DEPRECATED. Return `tf.name_scope(name)`."""
    scope = tf.name_scope(name)
    return scope
@keras_export("keras._legacy.backend.ndim")
def ndim(x):
    """DEPRECATED. Static rank of `x`, read from `x.shape.rank`."""
    static_shape = x.shape
    return static_shape.rank
@keras_export("keras._legacy.backend.not_equal")
def not_equal(x, y):
    """DEPRECATED. Element-wise `x != y` (`tf.not_equal`)."""
    comparison = tf.not_equal(x, y)
    return comparison
@keras_export("keras._legacy.backend.one_hot")
def one_hot(indices, num_classes):
    """DEPRECATED. One-hot encode `indices` along a new last axis."""
    encoded = tf.one_hot(indices, depth=num_classes, axis=-1)
    return encoded
@keras_export("keras._legacy.backend.ones")
def ones(shape, dtype=None, name=None):
    """DEPRECATED. All-ones tensor, created inside `tf.init_scope()`.

    `dtype` defaults to the backend floatx. The tensor is wrapped with
    `variable` when every entry of its static shape is truthy; otherwise
    the plain tensor is returned.
    """
    with tf.init_scope():
        if dtype is None:
            dtype = backend.floatx()
        filled = tf.ones(shape=shape, dtype=tf.as_dtype(dtype), name=name)
        # `py_all` is the Python builtin; `all` is shadowed in this module.
        every_dim_truthy = py_all(filled.shape.as_list())
        if not every_dim_truthy:
            return filled
        return variable(filled, dtype=dtype, name=name)
@keras_export("keras._legacy.backend.ones_like")
def ones_like(x, dtype=None, name=None):
    """DEPRECATED. Thin wrapper around `tf.ones_like`."""
    filled = tf.ones_like(x, dtype=dtype, name=name)
    return filled
@keras_export("keras._legacy.backend.permute_dimensions")
def permute_dimensions(x, pattern):
    """DEPRECATED. Transpose `x` using `pattern` as the permutation."""
    permuted = tf.transpose(x, perm=pattern)
    return permuted
@keras_export("keras._legacy.backend.pool2d")
def pool2d(
    x,
    pool_size,
    strides=(1, 1),
    padding="valid",
    data_format=None,
    pool_mode="max",
):
    """DEPRECATED. 2-D max or average pooling.

    Args:
        x: Input tensor.
        pool_size: Sequence of 2 ints (tuple or list).
        strides: Sequence of 2 ints (tuple or list).
        padding: `"same"` or `"valid"`.
        data_format: `"channels_first"`, `"channels_last"` or `None` for
            the backend default.
        pool_mode: `"max"` or `"avg"`.

    Raises:
        ValueError: on an unknown `data_format`, padding or pooling mode, or
            when `pool_size`/`strides` do not have exactly 2 entries.
    """
    if data_format is None:
        data_format = backend.image_data_format()
    if data_format not in {"channels_first", "channels_last"}:
        raise ValueError(f"Unknown data_format: {data_format}")
    if len(pool_size) != 2:
        raise ValueError("`pool_size` must be a tuple of 2 integers.")
    if len(strides) != 2:
        raise ValueError("`strides` must be a tuple of 2 integers.")
    # Lists pass the length checks above but would raise TypeError in the
    # tuple concatenations below, so normalize them here.
    pool_size = tuple(pool_size)
    strides = tuple(strides)
    x, tf_data_format = _preprocess_conv2d_input(x, data_format)
    padding = _preprocess_padding(padding)
    # Expand to the 4-element form with 1s on the batch and channel axes.
    if tf_data_format == "NHWC":
        strides = (1,) + strides + (1,)
        pool_size = (1,) + pool_size + (1,)
    else:
        strides = (1, 1) + strides
        pool_size = (1, 1) + pool_size
    if pool_mode == "max":
        x = tf.compat.v1.nn.max_pool(
            x, pool_size, strides, padding=padding, data_format=tf_data_format
        )
    elif pool_mode == "avg":
        x = tf.compat.v1.nn.avg_pool(
            x, pool_size, strides, padding=padding, data_format=tf_data_format
        )
    else:
        raise ValueError("Invalid pooling mode: " + str(pool_mode))
    if data_format == "channels_first" and tf_data_format == "NHWC":
        x = tf.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW
    return x
@keras_export("keras._legacy.backend.pool3d")
def pool3d(
    x,
    pool_size,
    strides=(1, 1, 1),
    padding="valid",
    data_format=None,
    pool_mode="max",
):
    """DEPRECATED. 3-D max or average pooling.

    Args:
        x: Input tensor.
        pool_size: Sequence of 3 ints (tuple or list); its length is not
            validated here.
        strides: Sequence of 3 ints (tuple or list); its length is not
            validated here.
        padding: `"same"` or `"valid"`.
        data_format: `"channels_first"`, `"channels_last"` or `None` for
            the backend default.
        pool_mode: `"max"` or `"avg"`.

    Raises:
        ValueError: on an unknown `data_format`, padding or pooling mode.
    """
    if data_format is None:
        data_format = backend.image_data_format()
    if data_format not in {"channels_first", "channels_last"}:
        raise ValueError(f"Unknown data_format: {data_format}")
    x, tf_data_format = _preprocess_conv3d_input(x, data_format)
    padding = _preprocess_padding(padding)
    # Accept lists as well: tuple + list concatenation would raise TypeError.
    pool_size = tuple(pool_size)
    strides = tuple(strides)
    # Expand to the 5-element form with 1s on the batch and channel axes.
    if tf_data_format == "NDHWC":
        strides = (1,) + strides + (1,)
        pool_size = (1,) + pool_size + (1,)
    else:
        strides = (1, 1) + strides
        pool_size = (1, 1) + pool_size
    if pool_mode == "max":
        x = tf.nn.max_pool3d(
            x, pool_size, strides, padding=padding, data_format=tf_data_format
        )
    elif pool_mode == "avg":
        x = tf.nn.avg_pool3d(
            x, pool_size, strides, padding=padding, data_format=tf_data_format
        )
    else:
        raise ValueError("Invalid pooling mode: " + str(pool_mode))
    if data_format == "channels_first" and tf_data_format == "NDHWC":
        x = tf.transpose(x, (0, 4, 1, 2, 3))
    return x
@keras_export("keras._legacy.backend.pow")
def pow(x, a):
    """DEPRECATED. Element-wise `x ** a` (`tf.pow`)."""
    powered = tf.pow(x, a)
    return powered
@keras_export("keras._legacy.backend.prod")
def prod(x, axis=None, keepdims=False):
    """DEPRECATED. Product of `x` reduced over `axis` (`tf.reduce_prod`)."""
    reduced = tf.reduce_prod(x, axis=axis, keepdims=keepdims)
    return reduced
@keras_export("keras._legacy.backend.random_bernoulli")
def random_bernoulli(shape, p=0.0, dtype=None, seed=None):
    """DEPRECATED. Tensor of 0/1 samples, 1 where a uniform draw is `<= p`.

    `dtype` defaults to the backend floatx; when `seed` is `None` one is
    drawn from NumPy's global RNG.
    """
    if dtype is None:
        dtype = backend.floatx()
    if seed is None:
        seed = np.random.randint(10e6)
    uniform_draws = tf.random.uniform(shape, dtype=dtype, seed=seed)
    is_success = uniform_draws <= p
    return tf.where(
        is_success,
        tf.ones(shape, dtype=dtype),
        tf.zeros(shape, dtype=dtype),
    )
@keras_export("keras._legacy.backend.random_normal")
def random_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
    """DEPRECATED. Samples from a normal distribution (`tf.random.normal`).

    `dtype` defaults to the backend floatx; when `seed` is `None` one is
    drawn from NumPy's global RNG.
    """
    if dtype is None:
        dtype = backend.floatx()
    if seed is None:
        seed = np.random.randint(10e6)
    samples = tf.random.normal(
        shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed
    )
    return samples
@keras_export("keras._legacy.backend.random_normal_variable")
def random_normal_variable(
    shape, mean, scale, dtype=None, name=None, seed=None
):
    """DEPRECATED. Variable initialized from a normal distribution.

    The initial value comes from `tf.compat.v1.random_normal_initializer`
    and is wrapped with `variable`. `dtype` defaults to the backend floatx.
    """
    if dtype is None:
        dtype = backend.floatx()
    tf_dtype = tf.as_dtype(dtype)
    if seed is None:
        # ensure that randomness is conditioned by the Numpy RNG
        seed = np.random.randint(10e8)
    initializer = tf.compat.v1.random_normal_initializer(
        mean, scale, dtype=tf_dtype, seed=seed
    )
    initial_value = initializer(shape)
    return variable(initial_value, dtype=dtype, name=name)
@keras_export("keras._legacy.backend.random_uniform")
def random_uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
    """DEPRECATED. Samples from a uniform distribution (`tf.random.uniform`).

    `dtype` defaults to the backend floatx; when `seed` is `None` one is
    drawn from NumPy's global RNG.
    """
    if dtype is None:
        dtype = backend.floatx()
    if seed is None:
        seed = np.random.randint(10e6)
    samples = tf.random.uniform(
        shape, minval=minval, maxval=maxval, dtype=dtype, seed=seed
    )
    return samples
@keras_export("keras._legacy.backend.random_uniform_variable")
def random_uniform_variable(shape, low, high, dtype=None, name=None, seed=None):
    """DEPRECATED. Variable initialized from a uniform distribution.

    The initial value comes from `tf.compat.v1.random_uniform_initializer`
    and is wrapped with `variable`. `dtype` defaults to the backend floatx.
    """
    if dtype is None:
        dtype = backend.floatx()
    tf_dtype = tf.as_dtype(dtype)
    if seed is None:
        # ensure that randomness is conditioned by the Numpy RNG
        seed = np.random.randint(10e8)
    initializer = tf.compat.v1.random_uniform_initializer(
        low, high, dtype=tf_dtype, seed=seed
    )
    initial_value = initializer(shape)
    return variable(initial_value, dtype=dtype, name=name)
@keras_export("keras._legacy.backend.reshape")
def reshape(x, shape):
    """DEPRECATED. Reshape `x` to `shape` (`tf.reshape`)."""
    reshaped = tf.reshape(x, shape)
    return reshaped
@keras_export("keras._legacy.backend.relu")
def relu(x, alpha=0.0, max_value=None, threshold=0.0):
    """DEPRECATED. Rectified linear unit with optional leak, cap and threshold.

    Args:
        x: Input; tensors, variables, NumPy arrays, lists and tuples are
            all seen in practice (see the comment below).
        alpha: Slope applied to the negative part; `0.0` disables it.
        max_value: Optional upper bound on the positive part.
        threshold: Values not greater than this are zeroed in the positive
            part.

    Returns:
        The activated tensor.
    """
    # While x can be a tensor or variable, we also see cases where
    # numpy arrays, lists, tuples are passed as well.
    # lists, tuples do not have 'dtype' attribute.
    dtype = getattr(x, "dtype", backend.floatx())
    if alpha != 0.0:
        # Plain leaky ReLU: no cap and no threshold, so use the native op.
        if max_value is None and threshold == 0:
            return tf.nn.leaky_relu(x, alpha=alpha)
        # Magnitude of the part of x lying below the threshold; it is
        # scaled by alpha and subtracted at the end.
        if threshold != 0:
            negative_part = tf.nn.relu(-x + threshold)
        else:
            negative_part = tf.nn.relu(-x)
    else:
        # Placeholder only: it is never used, because the final subtraction
        # is guarded by the same `alpha != 0.0` test.
        negative_part = 1
    clip_max = max_value is not None
    if threshold != 0:
        # computes x for x > threshold else 0
        x = x * tf.cast(tf.greater(x, threshold), dtype=dtype)
    elif max_value == 6:
        # if no threshold, then can use nn.relu6 native TF op for performance
        x = tf.nn.relu6(x)
        # relu6 already caps at 6, so no further clipping is needed.
        clip_max = False
    else:
        x = tf.nn.relu(x)
    if clip_max:
        max_value = tf.convert_to_tensor(max_value, dtype=x.dtype)
        zero = tf.convert_to_tensor(0, dtype=x.dtype)
        x = tf.clip_by_value(x, zero, max_value)
    if alpha != 0.0:
        alpha = tf.convert_to_tensor(alpha, dtype=x.dtype)
        x -= alpha * negative_part
    return x
@keras_export("keras._legacy.backend.repeat")
def repeat(x, n):
    """DEPRECATED.

    Inserts a new axis at position 1 of a rank-2 tensor and tiles it `n`
    times along that axis.

    Args:
        x: Tensor for which `ndim(x) == 2` (asserted).
        n: Number of repetitions along the inserted axis.

    Returns:
        The tiled rank-3 tensor.
    """
    assert ndim(x) == 2
    expanded = tf.expand_dims(x, 1)
    return tf.tile(expanded, tf.stack([1, n, 1]))
@keras_export("keras._legacy.backend.repeat_elements")
def repeat_elements(x, rep, axis):
    """DEPRECATED.

    Repeats the elements of `x` along `axis`, mimicking `np.repeat`.

    Args:
        x: Tensor with known rank (`x.shape.as_list()` is used).
        rep: Number of times each element is repeated.
        axis: Axis to repeat along. Negative values are interpreted
            relative to the rank of `x`.

    Returns:
        A tensor whose size along `axis` is `rep` times that of `x`.
    """
    x_shape = x.shape.as_list()
    # Normalize negative axes up front. The dynamic branch below derives an
    # auxiliary axis as `axis + 1`; with e.g. `axis=-1` that would become 0,
    # so the whole tensor would be tiled instead of each element repeated.
    if axis < 0:
        axis += len(x_shape)

    # For static axis
    if x_shape[axis] is not None:
        # slices along the repeat axis
        splits = tf.split(value=x, num_or_size_splits=x_shape[axis], axis=axis)
        # repeat each slice the given number of reps
        x_rep = [s for s in splits for _ in range(rep)]
        return concatenate(x_rep, axis)

    # Here we use tf.tile to mimic behavior of np.repeat so that
    # we can handle dynamic shapes (that include None).
    # To do that, we need an auxiliary axis to repeat elements along
    # it and then merge them along the desired axis.

    # Repeating
    auxiliary_axis = axis + 1
    x_shape = tf.shape(x)
    x_rep = tf.expand_dims(x, axis=auxiliary_axis)
    # Tile multiples must be integers, so build them as int32 from the start.
    reps = np.ones(len(x.shape) + 1, dtype="int32")
    reps[auxiliary_axis] = rep
    x_rep = tf.tile(x_rep, reps)

    # Merging
    reps = np.delete(reps, auxiliary_axis)
    reps[axis] = rep
    reps = tf.constant(reps, dtype="int32")
    x_shape *= reps
    x_rep = tf.reshape(x_rep, x_shape)

    # Fix shape representation: restore the static dims of the input (the
    # repeated axis is None here, so it stays unknown).
    x_shape = x.shape.as_list()
    x_rep.set_shape(x_shape)
    return x_rep
@keras_export("keras._legacy.backend.resize_images")
def resize_images(
    x, height_factor, width_factor, data_format, interpolation="nearest"
):
    """DEPRECATED.

    Scales the two spatial dimensions of `x` by integer factors using
    `tf.image.resize`.

    Args:
        x: Input tensor; spatial axes are (2, 3) for `"channels_first"` and
            (1, 2) for `"channels_last"`.
        height_factor: Multiplier for the first spatial axis (cast to int32).
        width_factor: Multiplier for the second spatial axis (cast to int32).
        data_format: `"channels_first"` or `"channels_last"`.
        interpolation: Key selecting a `tf.image.ResizeMethod`.

    Returns:
        The resized tensor, in the same data format as the input.

    Raises:
        ValueError: If `data_format` or `interpolation` is not recognized.
    """
    if data_format == "channels_first":
        first_spatial, last_spatial = 2, 3
    elif data_format == "channels_last":
        first_spatial, last_spatial = 1, 2
    else:
        raise ValueError(f"Invalid `data_format` argument: {data_format}")

    spatial = slice(first_spatial, last_spatial + 1)
    target_size = x.shape[spatial]
    if target_size.is_fully_defined():
        target_size = tf.constant(target_size.as_list(), dtype="int32")
    else:
        # Fall back to the runtime shape when a spatial dim is unknown.
        target_size = tf.shape(x)[spatial]
    factors = np.array([height_factor, width_factor], dtype="int32")
    target_size *= tf.constant(factors)

    channels_first = data_format == "channels_first"
    if channels_first:
        # tf.image.resize is fed a channels-last layout.
        x = permute_dimensions(x, [0, 2, 3, 1])

    methods = {
        "area": tf.image.ResizeMethod.AREA,
        "bicubic": tf.image.ResizeMethod.BICUBIC,
        "bilinear": tf.image.ResizeMethod.BILINEAR,
        "gaussian": tf.image.ResizeMethod.GAUSSIAN,
        "lanczos3": tf.image.ResizeMethod.LANCZOS3,
        "lanczos5": tf.image.ResizeMethod.LANCZOS5,
        "mitchellcubic": tf.image.ResizeMethod.MITCHELLCUBIC,
        "nearest": tf.image.ResizeMethod.NEAREST_NEIGHBOR,
    }
    supported = '"' + '", "'.join(methods.keys()) + '"'
    if interpolation not in methods:
        raise ValueError(
            "`interpolation` argument should be one of: "
            f'{supported}. Received: "{interpolation}".'
        )
    x = tf.image.resize(x, target_size, method=methods[interpolation])

    if channels_first:
        x = permute_dimensions(x, [0, 3, 1, 2])
    return x
@keras_export("keras._legacy.backend.resize_volumes")
def resize_volumes(x, depth_factor, height_factor, width_factor, data_format):
    """DEPRECATED.

    Upscales the three spatial axes of `x` by applying `repeat_elements`
    once per axis.

    Args:
        x: Input tensor; spatial axes are 2..4 for `"channels_first"` and
            1..3 for `"channels_last"`.
        depth_factor: Repetition count for the first spatial axis.
        height_factor: Repetition count for the second spatial axis.
        width_factor: Repetition count for the third spatial axis.
        data_format: `"channels_first"` or `"channels_last"`.

    Returns:
        The tensor after all three repetitions.

    Raises:
        ValueError: If `data_format` is not one of the two accepted values.
    """
    if data_format == "channels_first":
        first_spatial_axis = 2
    elif data_format == "channels_last":
        first_spatial_axis = 1
    else:
        raise ValueError(f"Invalid data_format: {data_format}")

    output = x
    factors = (depth_factor, height_factor, width_factor)
    for offset, factor in enumerate(factors):
        output = repeat_elements(
            output, factor, axis=first_spatial_axis + offset
        )
    return output
@keras_export("keras._legacy.backend.reverse")
def reverse(x, axes):
    """DEPRECATED.

    Reverses `x` along the given axes with `tf.reverse`.

    Args:
        x: Tensor to reverse.
        axes: A single `int` (wrapped in a list) or a value passed to
            `tf.reverse` unchanged.

    Returns:
        The result of `tf.reverse`.
    """
    axis_list = [axes] if isinstance(axes, int) else axes
    return tf.reverse(x, axis_list)
@keras_export("keras._legacy.backend.rnn")
def rnn(
    step_function,
    inputs,
    initial_states,
    go_backwards=False,
    mask=None,
    constants=None,
    unroll=False,
    input_length=None,
    time_major=False,
    zero_output_for_mask=False,
    return_all_outputs=True,
):
    """DEPRECATED.

    Iterates `step_function` over the time dimension of `inputs`, either
    with a Python loop (`unroll=True`) or with `tf.compat.v1.while_loop`.

    Args:
        step_function: Callable invoked as
            `step_function(input_t, tuple(states) + tuple(constants))` and
            expected to return an `(output, new_states)` pair.
        inputs: Tensor or nested structure of tensors. Unless `time_major`
            is set, the first two axes of every tensor are swapped so that
            axis 0 is time and axis 1 is batch.
        initial_states: Iterable of state tensors; converted to a tuple.
        go_backwards: If true, the time axis is traversed in reverse.
        mask: Optional tensor. It is cast to bool, a rank-2 mask gets an
            extra axis via `expand_dims`, and its first two axes are swapped
            unless `time_major` is set.
        constants: Optional values appended to the states on every call of
            `step_function`. Defaults to an empty list.
        unroll: Selects the Python-loop implementation; needs a static
            number of time steps.
        input_length: Only read on the non-unrolled path. When not `None`,
            `tf.reduce_max(input_length)` bounds the loop iterations; when
            it is a `tf.Tensor` and `mask` is `None`, it is also used to
            derive a per-step mask.
        time_major: If true, `inputs`, `mask` and the returned `outputs`
            are left in time-major layout.
        zero_output_for_mask: If true, masked steps emit zeros instead of
            the previous step's output.
        return_all_outputs: If false, only the latest output is kept.
            Forced to `True` when TF2 behavior is not enabled.

    Returns:
        A tuple `(last_output, outputs, new_states)`.

    Raises:
        ValueError: If `unroll` is set and the static number of time steps
            is falsy, or if `_expand_mask` receives a nested structure.
    """
    if not tf.__internal__.tf2.enabled():
        return_all_outputs = True  # Not supported in TF1.

    def swap_batch_timestep(input_t):
        # Swap the batch and timestep dim for the incoming tensor.
        axes = list(range(len(input_t.shape)))
        axes[0], axes[1] = 1, 0
        return tf.transpose(input_t, axes)

    if not time_major:
        inputs = tf.nest.map_structure(swap_batch_timestep, inputs)

    # From here on every input is time-major: axis 0 is time, axis 1 batch.
    flatted_inputs = tf.nest.flatten(inputs)
    # Static (possibly None) sizes, used for unrolling and shape inference.
    time_steps = flatted_inputs[0].shape[0]
    batch = flatted_inputs[0].shape[1]
    # Dynamic number of time steps, used by the while_loop path.
    time_steps_t = tf.shape(flatted_inputs[0])[0]

    for input_ in flatted_inputs:
        # NOTE(review): return value is discarded; presumably this raises
        # when the rank is known to be below 3 -- confirm against TF docs.
        input_.shape.with_rank_at_least(3)

    if mask is not None:
        if mask.dtype != tf.bool:
            mask = tf.cast(mask, tf.bool)
        if len(mask.shape) == 2:
            mask = expand_dims(mask)
        if not time_major:
            mask = swap_batch_timestep(mask)

    if constants is None:
        constants = []

    # tf.where needs its condition tensor to be the same shape as its two
    # result tensors, but in our case the condition (mask) tensor is
    # (nsamples, 1), and inputs are (nsamples, ndimensions) or even more.
    # So we need to broadcast the mask to match the shape of inputs.
    # That's what the tile call does, it just repeats the mask along its
    # second dimension n times.
    def _expand_mask(mask_t, input_t, fixed_dim=1):
        if tf.nest.is_nested(mask_t):
            raise ValueError(
                f"mask_t is expected to be tensor, but got {mask_t}"
            )
        if tf.nest.is_nested(input_t):
            raise ValueError(
                f"input_t is expected to be tensor, but got {input_t}"
            )
        # Append trailing axes until the mask has the same rank as input_t.
        rank_diff = len(input_t.shape) - len(mask_t.shape)
        for _ in range(rank_diff):
            mask_t = tf.expand_dims(mask_t, -1)
        # Leave the first `fixed_dim` axes alone and tile the remaining ones
        # by the static sizes of input_t.
        multiples = [1] * fixed_dim + input_t.shape.as_list()[fixed_dim:]
        return tf.tile(mask_t, multiples)

    if unroll:
        if not time_steps:
            raise ValueError("Unrolling requires a fixed number of timesteps.")
        states = tuple(initial_states)
        successive_states = []
        successive_outputs = []

        # Process the input tensors. The input tensor need to be split on the
        # time_step dim, and reverse if go_backwards is True. In the case of
        # nested input, the input is flattened and then transformed
        # individually.  The result of this will be a tuple of lists, each of
        # the item in tuple is list of the tensor with shape (batch, feature)
        def _process_single_input_t(input_t):
            input_t = tf.unstack(input_t)  # unstack for time_step dim
            if go_backwards:
                input_t.reverse()
            return input_t

        if tf.nest.is_nested(inputs):
            processed_input = tf.nest.map_structure(
                _process_single_input_t, inputs
            )
        else:
            processed_input = (_process_single_input_t(inputs),)

        def _get_input_tensor(time):
            # Pick the slice for step `time` from every processed input and
            # restore the structure of `inputs`.
            inp = [t_[time] for t_ in processed_input]
            return tf.nest.pack_sequence_as(inputs, inp)

        if mask is not None:
            mask_list = tf.unstack(mask)
            if go_backwards:
                mask_list.reverse()

            for i in range(time_steps):
                inp = _get_input_tensor(i)
                mask_t = mask_list[i]
                output, new_states = step_function(
                    inp, tuple(states) + tuple(constants)
                )
                tiled_mask_t = _expand_mask(mask_t, output)

                # Masked positions carry the previous output forward (zeros
                # on the very first step).
                if not successive_outputs:
                    prev_output = zeros_like(output)
                else:
                    prev_output = successive_outputs[-1]

                output = tf.where(tiled_mask_t, output, prev_output)

                # Masked positions keep their previous state as well.
                flat_states = tf.nest.flatten(states)
                flat_new_states = tf.nest.flatten(new_states)
                tiled_mask_t = tuple(
                    _expand_mask(mask_t, s) for s in flat_states
                )
                flat_final_states = tuple(
                    tf.where(m, s, ps)
                    for m, s, ps in zip(
                        tiled_mask_t, flat_new_states, flat_states
                    )
                )
                states = tf.nest.pack_sequence_as(states, flat_final_states)

                if return_all_outputs:
                    successive_outputs.append(output)
                    successive_states.append(states)
                else:
                    # Keep only the most recent step.
                    successive_outputs = [output]
                    successive_states = [states]
            last_output = successive_outputs[-1]
            new_states = successive_states[-1]
            outputs = tf.stack(successive_outputs)

            if zero_output_for_mask:
                # Replace the carried-forward values at masked positions
                # with zeros.
                last_output = tf.where(
                    _expand_mask(mask_list[-1], last_output),
                    last_output,
                    zeros_like(last_output),
                )
                outputs = tf.where(
                    _expand_mask(mask, outputs, fixed_dim=2),
                    outputs,
                    zeros_like(outputs),
                )

        else:  # mask is None
            for i in range(time_steps):
                inp = _get_input_tensor(i)
                output, states = step_function(
                    inp, tuple(states) + tuple(constants)
                )
                if return_all_outputs:
                    successive_outputs.append(output)
                    successive_states.append(states)
                else:
                    successive_outputs = [output]
                    successive_states = [states]
            last_output = successive_outputs[-1]
            new_states = successive_states[-1]
            outputs = tf.stack(successive_outputs)

    else:  # Unroll == False
        states = tuple(initial_states)

        # Create input tensor array, if the inputs is nested tensors, then it
        # will be flattened first, and tensor array will be created one per
        # flattened tensor.
        input_ta = tuple(
            tf.TensorArray(
                dtype=inp.dtype,
                size=time_steps_t,
                tensor_array_name=f"input_ta_{i}",
            )
            for i, inp in enumerate(flatted_inputs)
        )
        # Fill each array with its input, reversed in time if requested.
        input_ta = tuple(
            (
                ta.unstack(input_)
                if not go_backwards
                else ta.unstack(reverse(input_, 0))
            )
            for ta, input_ in zip(input_ta, flatted_inputs)
        )

        # Get the time(0) input and compute the output for that, the output will
        # be used to determine the dtype of output tensor array. Don't read from
        # input_ta due to TensorArray clear_after_read default to True.
        input_time_zero = tf.nest.pack_sequence_as(
            inputs, [inp[0] for inp in flatted_inputs]
        )
        # output_time_zero is used to determine the cell output shape and its
        # dtype.  the value is discarded.
        output_time_zero, _ = step_function(
            input_time_zero, tuple(initial_states) + tuple(constants)
        )

        # A single slot suffices when only the latest output is kept.
        output_ta_size = time_steps_t if return_all_outputs else 1
        output_ta = tuple(
            tf.TensorArray(
                dtype=out.dtype,
                size=output_ta_size,
                element_shape=out.shape,
                tensor_array_name=f"output_ta_{i}",
            )
            for i, out in enumerate(tf.nest.flatten(output_time_zero))
        )

        time = tf.constant(0, dtype="int32", name="time")

        if input_length is None:
            max_iterations = time_steps_t
        else:
            max_iterations = tf.reduce_max(input_length)

        while_loop_kwargs = {
            "cond": lambda time, *_: time < time_steps_t,
            "maximum_iterations": max_iterations,
            "parallel_iterations": 32,
            "swap_memory": True,
        }
        # Choose how the per-step mask is obtained: from an explicit `mask`,
        # from a tensor `input_length`, or not at all.
        if mask is not None:
            if go_backwards:
                mask = reverse(mask, 0)

            mask_ta = tf.TensorArray(
                dtype=tf.bool, size=time_steps_t, tensor_array_name="mask_ta"
            )
            mask_ta = mask_ta.unstack(mask)

            def masking_fn(time):
                return mask_ta.read(time)

            def compute_masked_output(mask_t, flat_out, flat_mask):
                # Where the mask is true take the new value `o`, otherwise
                # the fallback `fm`.
                tiled_mask_t = tuple(
                    _expand_mask(mask_t, o, fixed_dim=len(mask_t.shape))
                    for o in flat_out
                )
                return tuple(
                    tf.where(m, o, fm)
                    for m, o, fm in zip(tiled_mask_t, flat_out, flat_mask)
                )

        elif isinstance(input_length, tf.Tensor):
            if go_backwards:
                max_len = tf.reduce_max(input_length, axis=0)
                rev_input_length = tf.subtract(max_len - 1, input_length)

                def masking_fn(time):
                    return tf.less(rev_input_length, time)

            else:

                def masking_fn(time):
                    return tf.greater(input_length, time)

            def compute_masked_output(mask_t, flat_out, flat_mask):
                # Unlike the explicit-mask variant, the mask is not tiled
                # here before being handed to `tf.compat.v1.where`.
                return tuple(
                    tf.compat.v1.where(mask_t, o, zo)
                    for (o, zo) in zip(flat_out, flat_mask)
                )

        else:
            masking_fn = None

        if masking_fn is not None:
            # Mask for the T output will be base on the output of T - 1. In the
            # case T = 0, a zero filled tensor will be used.
            flat_zero_output = tuple(
                tf.zeros_like(o) for o in tf.nest.flatten(output_time_zero)
            )

            def _step(time, output_ta_t, prev_output, *states):
                """RNN step function.

                Args:
                    time: Current timestep value.
                    output_ta_t: TensorArray.
                    prev_output: tuple of outputs from time - 1.
                    *states: List of states.

                Returns:
                    Tuple: `(time + 1, output_ta_t, output) + tuple(new_states)`
                """
                current_input = tuple(ta.read(time) for ta in input_ta)
                # maybe set shape.
                current_input = tf.nest.pack_sequence_as(inputs, current_input)
                mask_t = masking_fn(time)
                output, new_states = step_function(
                    current_input, tuple(states) + tuple(constants)
                )
                # mask output
                flat_output = tf.nest.flatten(output)
                flat_mask_output = (
                    flat_zero_output
                    if zero_output_for_mask
                    else tf.nest.flatten(prev_output)
                )
                flat_new_output = compute_masked_output(
                    mask_t, flat_output, flat_mask_output
                )

                # mask states
                flat_state = tf.nest.flatten(states)
                flat_new_state = tf.nest.flatten(new_states)
                # Copy the static shape of each incoming state onto its
                # successor.
                for state, new_state in zip(flat_state, flat_new_state):
                    if isinstance(new_state, tf.Tensor):
                        new_state.set_shape(state.shape)
                flat_final_state = compute_masked_output(
                    mask_t, flat_new_state, flat_state
                )
                new_states = tf.nest.pack_sequence_as(
                    new_states, flat_final_state
                )

                # Always overwrite slot 0 when only the latest output is kept.
                ta_index_to_write = time if return_all_outputs else 0
                output_ta_t = tuple(
                    ta.write(ta_index_to_write, out)
                    for ta, out in zip(output_ta_t, flat_new_output)
                )

                return (time + 1, output_ta_t, tuple(flat_new_output)) + tuple(
                    new_states
                )

            final_outputs = tf.compat.v1.while_loop(
                body=_step,
                loop_vars=(time, output_ta, flat_zero_output) + states,
                **while_loop_kwargs,
            )
            # Skip final_outputs[2] which is the output for final timestep.
            new_states = final_outputs[3:]
        else:

            def _step(time, output_ta_t, *states):
                """RNN step function.

                Args:
                    time: Current timestep value.
                    output_ta_t: TensorArray.
                    *states: List of states.

                Returns:
                    Tuple: `(time + 1,output_ta_t) + tuple(new_states)`
                """
                current_input = tuple(ta.read(time) for ta in input_ta)
                current_input = tf.nest.pack_sequence_as(inputs, current_input)
                output, new_states = step_function(
                    current_input, tuple(states) + tuple(constants)
                )
                flat_state = tf.nest.flatten(states)
                flat_new_state = tf.nest.flatten(new_states)
                # Copy the static shape of each incoming state onto its
                # successor.
                for state, new_state in zip(flat_state, flat_new_state):
                    if isinstance(new_state, tf.Tensor):
                        new_state.set_shape(state.shape)

                flat_output = tf.nest.flatten(output)
                ta_index_to_write = time if return_all_outputs else 0
                output_ta_t = tuple(
                    ta.write(ta_index_to_write, out)
                    for ta, out in zip(output_ta_t, flat_output)
                )

                new_states = tf.nest.pack_sequence_as(
                    initial_states, flat_new_state
                )
                return (time + 1, output_ta_t) + tuple(new_states)

            final_outputs = tf.compat.v1.while_loop(
                body=_step,
                loop_vars=(time, output_ta) + states,
                **while_loop_kwargs,
            )
            new_states = final_outputs[2:]

        output_ta = final_outputs[1]

        # Materialize the TensorArrays; the last output is the final entry
        # along the time axis.
        outputs = tuple(o.stack() for o in output_ta)
        last_output = tuple(o[-1] for o in outputs)

        outputs = tf.nest.pack_sequence_as(output_time_zero, outputs)
        last_output = tf.nest.pack_sequence_as(output_time_zero, last_output)

    # static shape inference
    def set_shape(output_):
        if isinstance(output_, tf.Tensor):
            shape = output_.shape.as_list()
            if return_all_outputs:
                shape[0] = time_steps
            else:
                shape[0] = 1
            shape[1] = batch
            output_.set_shape(shape)
        return output_

    outputs = tf.nest.map_structure(set_shape, outputs)

    # Return outputs in the layout the caller passed in.
    if not time_major:
        outputs = tf.nest.map_structure(swap_batch_timestep, outputs)

    return last_output, outputs, new_states
@keras_export("keras._legacy.backend.round")
def round(x):
    """DEPRECATED.

    Thin wrapper around `tf.round`.

    Returns:
        The result of `tf.round(x)`.
    """
    rounded = tf.round(x)
    return rounded
@keras_export("keras._legacy.backend.separable_conv2d")
def separable_conv2d(
    x,
    depthwise_kernel,
    pointwise_kernel,
    strides=(1, 1),
    padding="valid",
    data_format=None,
    dilation_rate=(1, 1),
):
    """DEPRECATED.

    2D separable convolution built on `tf.nn.separable_conv2d`.

    Args:
        x: Input tensor, preprocessed by `_preprocess_conv2d_input`.
        depthwise_kernel: Forwarded to `tf.nn.separable_conv2d`.
        pointwise_kernel: Forwarded to `tf.nn.separable_conv2d`.
        strides: Sequence of exactly two stride values.
        padding: Padding mode, translated by `_preprocess_padding`.
        data_format: `"channels_first"` or `"channels_last"`; defaults to
            `backend.image_data_format()` when `None`.
        dilation_rate: Forwarded as `dilations`.

    Returns:
        The convolution result, transposed back to channels-first when the
        op ran in NHWC for a `"channels_first"` request.

    Raises:
        ValueError: On an unknown `data_format` or when `strides` does not
            have length 2.
    """
    if data_format is None:
        data_format = backend.image_data_format()
    if data_format not in {"channels_first", "channels_last"}:
        raise ValueError(f"Unknown data_format: {data_format}")
    if len(strides) != 2:
        raise ValueError("`strides` must be a tuple of 2 integers.")

    x, tf_data_format = _preprocess_conv2d_input(x, data_format)
    padding = _preprocess_padding(padding)

    spatial_strides = strides if isinstance(strides, tuple) else tuple(strides)
    # Put the two spatial strides where the op's data format expects them.
    ran_as_nhwc = tf_data_format == "NHWC"
    if ran_as_nhwc:
        full_strides = (1,) + spatial_strides + (1,)
    else:
        full_strides = (1, 1) + spatial_strides

    result = tf.nn.separable_conv2d(
        x,
        depthwise_kernel,
        pointwise_kernel,
        strides=full_strides,
        padding=padding,
        dilations=dilation_rate,
        data_format=tf_data_format,
    )
    if data_format == "channels_first" and ran_as_nhwc:
        result = tf.transpose(result, (0, 3, 1, 2))  # NHWC -> NCHW
    return result
@keras_export("keras._legacy.backend.set_value")
def set_value(x, value):
    """DEPRECATED.

    Converts `value` to a NumPy array of `x`'s dtype and assigns it to `x`.

    Args:
        x: Object exposing `dtype.name` and `assign`.
        value: Array-like to assign.
    """
    as_array = np.asarray(value, dtype=x.dtype.name)
    x.assign(as_array)
@keras_export("keras._legacy.backend.shape")
def shape(x):
    """DEPRECATED.

    Thin wrapper around `tf.shape`.

    Returns:
        The result of `tf.shape(x)`.
    """
    dynamic_shape = tf.shape(x)
    return dynamic_shape
@keras_export("keras._legacy.backend.sigmoid")
def sigmoid(x):
    """DEPRECATED.

    Thin wrapper around `tf.sigmoid`.

    Returns:
        The result of `tf.sigmoid(x)`.
    """
    return tf.sigmoid(x)
@keras_export("keras._legacy.backend.sign")
def sign(x):
    """DEPRECATED.

    Thin wrapper around `tf.sign`.

    Returns:
        The result of `tf.sign(x)`.
    """
    signs = tf.sign(x)
    return signs
@keras_export("keras._legacy.backend.sin")
def sin(x):
    """DEPRECATED.

    Thin wrapper around `tf.sin`.

    Returns:
        The result of `tf.sin(x)`.
    """
    sine = tf.sin(x)
    return sine
@keras_export("keras._legacy.backend.softmax")
def softmax(x, axis=-1):
    """DEPRECATED.

    Softmax of `x` along one axis or several axes.

    Args:
        x: Tensor whose `shape.rank` must be greater than 1.
        axis: An `int` (handled by `tf.nn.softmax`) or another axis spec,
            for which the softmax is computed manually.

    Returns:
        The softmax output, with the input stored on it as `_keras_logits`.

    Raises:
        ValueError: If `x.shape.rank <= 1`.
    """
    if x.shape.rank <= 1:
        raise ValueError(
            f"Cannot apply softmax to a tensor that is 1D. Received input: {x}"
        )

    if not isinstance(axis, int):
        # nn.softmax does not support tuple axis.
        peak = tf.reduce_max(x, axis=axis, keepdims=True)
        exponentials = tf.exp(x - peak)
        normalizer = tf.reduce_sum(exponentials, axis=axis, keepdims=True)
        output = exponentials / normalizer
    else:
        output = tf.nn.softmax(x, axis=axis)

    # Cache the logits to use for crossentropy loss.
    output._keras_logits = x
    return output
@keras_export("keras._legacy.backend.softplus")
def softplus(x):
    """DEPRECATED.

    Thin wrapper around `tf.math.softplus`.

    Returns:
        The result of `tf.math.softplus(x)`.
    """
    activated = tf.math.softplus(x)
    return activated
@keras_export("keras._legacy.backend.softsign")
def softsign(x):
    """DEPRECATED.

    Thin wrapper around `tf.math.softsign`.

    Returns:
        The result of `tf.math.softsign(x)`.
    """
    activated = tf.math.softsign(x)
    return activated
@keras_export("keras._legacy.backend.sparse_categorical_crossentropy")
def sparse_categorical_crossentropy(
    target, output, from_logits=False, axis=-1, ignore_class=None
):
    """DEPRECATED.

    Cross-entropy between integer targets and predictions, computed with
    `tf.nn.sparse_softmax_cross_entropy_with_logits`.

    Args:
        target: Labels; converted to a tensor and cast to int64.
        output: Predictions. Unless `from_logits` is set, they are clipped
            to `[epsilon, 1 - epsilon]` and passed through `log` first.
        from_logits: Whether `output` already holds logits.
        axis: Axis of `output` holding the classes; it is moved last.
        ignore_class: Optional label value whose entries are dropped from
            the computation and scattered back into the result.

    Returns:
        The loss tensor. When `ignore_class` is given, the validity mask is
        attached to it as `_keras_mask`.

    Raises:
        ValueError: If `output` has unknown rank and `axis` is not -1.
    """
    target = cast(tf.convert_to_tensor(target), "int64")
    output = tf.convert_to_tensor(output)

    if not from_logits:
        eps = tf.convert_to_tensor(backend.epsilon(), output.dtype)
        output = tf.math.log(tf.clip_by_value(output, eps, 1 - eps))

    # Permute output so that the last axis contains the logits/probabilities.
    if isinstance(output.shape, (tuple, list)):
        output_rank = len(output.shape)
    else:
        output_rank = output.shape.ndims

    if output_rank is None:
        if axis != -1:
            raise ValueError(
                "Cannot compute sparse categorical crossentropy with `axis={}` "
                "on an output tensor with unknown rank".format(axis)
            )
    else:
        axis = axis % output_rank
        if axis != output_rank - 1:
            class_last = [*range(axis), *range(axis + 1, output_rank), axis]
            output = tf.transpose(output, perm=class_last)

    # Try to adjust the shape so that rank of labels = rank of logits - 1.
    output_shape = tf.shape(output)
    target_rank = target.shape.ndims
    needs_flatten = not (
        target_rank is None
        or output_rank is None
        or target_rank == output_rank - 1
    )
    if needs_flatten:
        target = flatten(target)
        output = tf.reshape(output, [-1, output_shape[-1]])

    ignoring = ignore_class is not None
    if ignoring:
        # Keep only the entries whose label differs from `ignore_class`.
        valid_mask = tf.not_equal(target, cast(ignore_class, target.dtype))
        target = target[valid_mask]
        output = output[valid_mask]

    res = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=target, logits=output
    )

    if ignoring:
        # Scatter the kept losses back into a tensor of the pre-flatten
        # shape.
        res_shape = cast(output_shape[:-1], "int64")
        valid_mask = tf.reshape(valid_mask, res_shape)
        res = tf.scatter_nd(tf.where(valid_mask), res, res_shape)
        res._keras_mask = valid_mask
        return res

    if needs_flatten and output_rank >= 3:
        # If our output includes timesteps or
        # spatial dimensions we need to reshape
        res = tf.reshape(res, output_shape[:-1])
    return res
@keras_export("keras._legacy.backend.spatial_2d_padding")
def spatial_2d_padding(x, padding=((1, 1), (1, 1)), data_format=None):
    """DEPRECATED.

    Pads the two spatial axes of a 4-axis tensor via `tf.compat.v1.pad`.

    Args:
        x: Tensor to pad.
        padding: Two pairs, one `(before, after)` pair per spatial axis.
        data_format: `"channels_first"` or `"channels_last"`; defaults to
            `backend.image_data_format()` when `None`.

    Returns:
        The padded tensor.

    Raises:
        ValueError: If `data_format` is not recognized.
    """
    assert len(padding) == 2
    for spatial_dim in (0, 1):
        assert len(padding[spatial_dim]) == 2

    if data_format is None:
        data_format = backend.image_data_format()
    if data_format not in {"channels_first", "channels_last"}:
        raise ValueError(f"Unknown data_format: {data_format}")

    spatial_pads = [list(padding[0]), list(padding[1])]
    if data_format == "channels_first":
        pattern = [[0, 0], [0, 0]] + spatial_pads
    else:
        pattern = [[0, 0]] + spatial_pads + [[0, 0]]
    return tf.compat.v1.pad(x, pattern)
@keras_export("keras._legacy.backend.spatial_3d_padding")
def spatial_3d_padding(x, padding=((1, 1), (1, 1), (1, 1)), data_format=None):
    """DEPRECATED.

    Pads the three spatial axes of a 5-axis tensor via `tf.compat.v1.pad`.

    Args:
        x: Tensor to pad.
        padding: Three pairs, one `(before, after)` pair per spatial axis.
        data_format: `"channels_first"` or `"channels_last"`; defaults to
            `backend.image_data_format()` when `None`.

    Returns:
        The padded tensor.

    Raises:
        ValueError: If `data_format` is not recognized.
    """
    assert len(padding) == 3
    for spatial_dim in (0, 1, 2):
        assert len(padding[spatial_dim]) == 2

    if data_format is None:
        data_format = backend.image_data_format()
    if data_format not in {"channels_first", "channels_last"}:
        raise ValueError(f"Unknown data_format: {data_format}")

    spatial_pads = [
        [padding[spatial_dim][0], padding[spatial_dim][1]]
        for spatial_dim in range(3)
    ]
    if data_format == "channels_first":
        pattern = [[0, 0], [0, 0]] + spatial_pads
    else:
        pattern = [[0, 0]] + spatial_pads + [[0, 0]]
    return tf.compat.v1.pad(x, pattern)
@keras_export("keras._legacy.backend.sqrt")
def sqrt(x):
    """DEPRECATED.

    Element-wise square root with the input clamped at zero from below.

    Args:
        x: Tensor exposing `dtype`.

    Returns:
        `tf.sqrt` of `tf.maximum(x, 0)`.
    """
    lower_bound = tf.convert_to_tensor(0.0, x.dtype)
    return tf.sqrt(tf.maximum(x, lower_bound))
@keras_export("keras._legacy.backend.square")
def square(x):
    """DEPRECATED.

    Thin wrapper around `tf.square`.

    Returns:
        The result of `tf.square(x)`.
    """
    squared = tf.square(x)
    return squared
@keras_export("keras._legacy.backend.squeeze")
def squeeze(x, axis):
    """DEPRECATED.

    Calls `tf.squeeze` with `axis` wrapped in a one-element list.

    Returns:
        The result of `tf.squeeze(x, [axis])`.
    """
    axes = [axis]
    return tf.squeeze(x, axes)
@keras_export("keras._legacy.backend.stack")
def stack(x, axis=0):
    """DEPRECATED.

    Thin wrapper around `tf.stack`.

    Args:
        x: Values passed to `tf.stack`.
        axis: Axis forwarded as the `axis` keyword.

    Returns:
        The result of `tf.stack(x, axis=axis)`.
    """
    stacked = tf.stack(x, axis=axis)
    return stacked
@keras_export("keras._legacy.backend.std")
def std(x, axis=None, keepdims=False):
    """DEPRECATED.

    Standard deviation via `tf.math.reduce_std`.

    Args:
        x: Tensor; boolean input is first cast to `backend.floatx()`.
        axis: Forwarded as `axis`.
        keepdims: Forwarded as `keepdims`.

    Returns:
        The result of `tf.math.reduce_std`.
    """
    is_boolean = x.dtype.base_dtype == tf.bool
    values = tf.cast(x, backend.floatx()) if is_boolean else x
    return tf.math.reduce_std(values, axis=axis, keepdims=keepdims)
@keras_export("keras._legacy.backend.stop_gradient")
def stop_gradient(variables):
    """DEPRECATED.

    Applies `tf.stop_gradient` to a tensor or to each tensor in a sequence.

    Args:
        variables: A single value, or a `list`/`tuple` of values.

    Returns:
        For a `list` or `tuple`, a list holding `tf.stop_gradient` of each
        element; otherwise `tf.stop_gradient(variables)`.
    """
    if isinstance(variables, (list, tuple)):
        # Return a real list rather than a lazy `map` object, which can only
        # be consumed once and does not support `len` or indexing.
        return [tf.stop_gradient(variable_) for variable_ in variables]
    return tf.stop_gradient(variables)
@keras_export("keras._legacy.backend.sum")
def sum(x, axis=None, keepdims=False):
    """DEPRECATED.

    Thin wrapper around `tf.reduce_sum`.

    Args:
        x: Tensor to reduce.
        axis: Second positional argument of `tf.reduce_sum`.
        keepdims: Third positional argument of `tf.reduce_sum`.

    Returns:
        The result of `tf.reduce_sum(x, axis, keepdims)`.
    """
    total = tf.reduce_sum(x, axis, keepdims)
    return total
@keras_export("keras._legacy.backend.switch")
def switch(condition, then_expression, else_expression):
    """DEPRECATED.

    Selects between two expressions based on `condition`.

    When `ndim(condition)` is falsy, `tf.compat.v1.cond` is used. Otherwise
    the condition is reshaped and tiled as needed and `tf.where` is used.

    Args:
        condition: Tensor; cast to bool when it has another dtype.
        then_expression: Tensor, or callable returning one.
        else_expression: Tensor, or callable returning one.

    Returns:
        The selected tensor.

    Raises:
        ValueError: If `condition` has a higher rank than `then_expression`.
    """
    if condition.dtype != tf.bool:
        condition = tf.cast(condition, "bool")
    cond_ndim = ndim(condition)

    if not cond_ndim:
        # tf.compat.v1.cond takes callables, so wrap plain values.
        if callable(then_expression):
            then_fn = then_expression
        else:
            then_fn = lambda: then_expression  # noqa: E731
        if callable(else_expression):
            else_fn = else_expression
        else:
            else_fn = lambda: else_expression  # noqa: E731
        return tf.compat.v1.cond(condition, then_fn, else_fn)

    # tf.where needs its condition tensor
    # to be the same shape as its two
    # result tensors
    if callable(then_expression):
        then_expression = then_expression()
    if callable(else_expression):
        else_expression = else_expression()

    expr_ndim = ndim(then_expression)
    if cond_ndim > expr_ndim:
        raise ValueError(
            "Rank of `condition` should be less than or"
            " equal to rank of `then_expression` and "
            "`else_expression`. ndim(condition)="
            + str(cond_ndim)
            + ", ndim(then_expression)="
            + str(expr_ndim)
        )

    if cond_ndim > 1:
        # Pad the condition's shape with trailing 1s up to the expression's
        # rank, then tile every axis where the expression is larger.
        missing_axes = expr_ndim - cond_ndim
        padded_cond_shape = tf.concat(
            [tf.shape(condition), [1] * missing_axes], axis=0
        )
        condition = tf.reshape(condition, padded_cond_shape)
        expr_shape = tf.shape(then_expression)
        size_gap = expr_shape - padded_cond_shape
        multiples = tf.where(
            size_gap > 0, expr_shape, tf.ones_like(expr_shape)
        )
        condition = tf.tile(condition, multiples)

    return tf.where(condition, then_expression, else_expression)
@keras_export("keras._legacy.backend.tanh")
def tanh(x):
    """DEPRECATED.

    Thin wrapper around `tf.tanh`.

    Returns:
        The result of `tf.tanh(x)`.
    """
    activated = tf.tanh(x)
    return activated
@keras_export("keras._legacy.backend.temporal_padding")
def temporal_padding(x, padding=(1, 1)):
    """DEPRECATED.

    Pads the middle axis of a 3-axis tensor via `tf.compat.v1.pad`.

    Args:
        x: Tensor to pad.
        padding: Pair `(before, after)` applied to axis 1.

    Returns:
        The padded tensor.
    """
    assert len(padding) == 2
    before, after = padding[0], padding[1]
    return tf.compat.v1.pad(x, [[0, 0], [before, after], [0, 0]])
@keras_export("keras._legacy.backend.tile")
def tile(x, n):
    """DEPRECATED.

    Tiles `x` with `tf.tile`.

    Args:
        x: Tensor to tile.
        n: Multiples; a bare `int` is wrapped in a one-element list.

    Returns:
        The result of `tf.tile`.
    """
    multiples = [n] if isinstance(n, int) else n
    return tf.tile(x, multiples)
@keras_export("keras._legacy.backend.to_dense")
def to_dense(tensor):
    """DEPRECATED.

    Converts a sparse tensor to dense; other inputs are returned unchanged.

    Args:
        tensor: Value tested with `is_sparse`.

    Returns:
        `tf.sparse.to_dense(tensor)` for sparse input, else `tensor` itself.
    """
    if not is_sparse(tensor):
        return tensor
    return tf.sparse.to_dense(tensor)
@keras_export("keras._legacy.backend.transpose")
def transpose(x):
    """DEPRECATED.

    Thin wrapper around `tf.transpose` called without a permutation.

    Returns:
        The result of `tf.transpose(x)`.
    """
    transposed = tf.transpose(x)
    return transposed
@keras_export("keras._legacy.backend.truncated_normal")
def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
    """DEPRECATED.

    Forwards to `tf.random.truncated_normal`, filling in defaults for
    `dtype` and `seed` when they are not supplied.

    Args:
        shape: Shape argument of `tf.random.truncated_normal`.
        mean: Passed positionally as the second argument.
        stddev: Passed positionally as the third argument.
        dtype: Result dtype; `backend.floatx()` is used when `None`.
        seed: Op seed; when `None`, one is drawn from NumPy's global RNG.

    Returns:
        Whatever `tf.random.truncated_normal` returns for these arguments.
    """
    dtype = backend.floatx() if dtype is None else dtype
    # No explicit seed: derive one from the NumPy global random state.
    seed = np.random.randint(10e6) if seed is None else seed
    return tf.random.truncated_normal(
        shape, mean, stddev, dtype=dtype, seed=seed
    )
@keras_export("keras._legacy.backend.update")
def update(x, new_x):
    """DEPRECATED.

    Thin wrapper around `tf.compat.v1.assign`.

    Returns:
        The result of `tf.compat.v1.assign(x, new_x)`.
    """
    assigned = tf.compat.v1.assign(x, new_x)
    return assigned
@keras_export("keras._legacy.backend.update_add")
def update_add(x, increment):
    """DEPRECATED.

    Thin wrapper around `tf.compat.v1.assign_add`.

    Returns:
        The result of `tf.compat.v1.assign_add(x, increment)`.
    """
    assigned = tf.compat.v1.assign_add(x, increment)
    return assigned
@keras_export("keras._legacy.backend.update_sub")
def update_sub(x, decrement):
    """DEPRECATED.

    Thin wrapper around `tf.compat.v1.assign_sub`.

    Returns:
        The result of `tf.compat.v1.assign_sub(x, decrement)`.
    """
    assigned = tf.compat.v1.assign_sub(x, decrement)
    return assigned
@keras_export("keras._legacy.backend.var")
def var(x, axis=None, keepdims=False):
    """DEPRECATED.

    Variance via `tf.math.reduce_variance`.

    Args:
        x: Tensor; boolean input is first cast to `backend.floatx()`.
        axis: Forwarded as `axis`.
        keepdims: Forwarded as `keepdims`.

    Returns:
        The result of `tf.math.reduce_variance`.
    """
    is_boolean = x.dtype.base_dtype == tf.bool
    values = tf.cast(x, backend.floatx()) if is_boolean else x
    return tf.math.reduce_variance(values, axis=axis, keepdims=keepdims)
@keras_export("keras._legacy.backend.variable")
def variable(value, dtype=None, name=None, constraint=None):
    """DEPRECATED.

    Creates a `tf.Variable`, or a `tf.SparseTensor` for values that expose
    a `tocoo` method.

    Args:
        value: Initial value. Objects with `tocoo()` take the sparse path.
        dtype: Dtype name; `backend.floatx()` is used when `None`. Only
            applied on the dense path.
        name: Name for the `tf.Variable`.
        constraint: Constraint forwarded to `tf.Variable`.

    Returns:
        A `tf.SparseTensor` carrying a `_keras_shape` attribute, or a
        `tf.Variable`.
    """
    if dtype is None:
        dtype = backend.floatx()

    if not hasattr(value, "tocoo"):
        return tf.Variable(
            value, dtype=tf.as_dtype(dtype), name=name, constraint=constraint
        )

    # Sparse path: build (row, col) index pairs from the COO representation.
    coo = value.tocoo()
    row_column = np.expand_dims(coo.row, 1)
    col_column = np.expand_dims(coo.col, 1)
    indices = np.concatenate((row_column, col_column), 1)
    sparse = tf.SparseTensor(
        indices=indices,
        values=coo.data,
        dense_shape=coo.shape,
    )
    sparse._keras_shape = coo.shape
    return sparse
@keras_export("keras._legacy.backend.zeros")
def zeros(shape, dtype=None, name=None):
    """DEPRECATED.

    Creates an all-zeros value inside `tf.init_scope()`.

    Args:
        shape: Shape passed to `tf.zeros`.
        dtype: Dtype name; `backend.floatx()` is used when `None`.
        name: Name passed to `tf.zeros` and, if used, to `variable`.

    Returns:
        The result of `variable(...)` when every entry of the static shape
        is truthy, otherwise the `tf.zeros` tensor itself.
    """
    with tf.init_scope():
        dtype = backend.floatx() if dtype is None else dtype
        zero_tensor = tf.zeros(shape=shape, dtype=tf.as_dtype(dtype), name=name)
        static_dims = zero_tensor.shape.as_list()
        if not py_all(static_dims):
            # Some dimension is falsy (None or 0): return the tensor as is.
            return zero_tensor
        return variable(zero_tensor, dtype=dtype, name=name)
@keras_export("keras._legacy.backend.zeros_like")
def zeros_like(x, dtype=None, name=None):
    """DEPRECATED.

    Thin wrapper around `tf.zeros_like`.

    Args:
        x: Reference tensor.
        dtype: Forwarded as `dtype`.
        name: Forwarded as `name`.

    Returns:
        The result of `tf.zeros_like(x, dtype=dtype, name=name)`.
    """
    zero_tensor = tf.zeros_like(x, dtype=dtype, name=name)
    return zero_tensor