|
|
import math |
|
|
|
|
|
import numpy as np |
|
|
|
|
|
from keras.src import tree |
|
|
from keras.src.api_export import keras_export |
|
|
from keras.src.backend.common.backend_utils import canonicalize_axis |
|
|
from keras.src.backend.common.backend_utils import to_tuple_or_list |
|
|
|
|
|
|
|
|
def broadcast_shapes(shape1, shape2):
    """Broadcast input shapes to a unified shape.

    Args:
        shape1: A tuple or list of integers.
        shape2: A tuple or list of integers.

    Returns:
        output_shape (list of integers or `None`): The broadcasted shape.

    Example:
    >>> broadcast_shapes((5, 3), (1, 3))
    [5, 3]
    """
    # Convert to list for mutability. (This sentence was previously
    # misplaced inside the docstring.)
    shape1 = list(shape1)
    shape2 = list(shape2)
    # Keep the originals around for the error message below.
    origin_shape1 = shape1
    origin_shape2 = shape2

    # Left-pad the shorter shape with 1s so both ranks match.
    if len(shape1) > len(shape2):
        shape2 = [1] * (len(shape1) - len(shape2)) + shape2
    if len(shape1) < len(shape2):
        shape1 = [1] * (len(shape2) - len(shape1)) + shape1
    output_shape = list(shape1)
    for i in range(len(shape1)):
        if shape1[i] == 1:
            # A size-1 dim broadcasts to whatever the other shape has.
            output_shape[i] = shape2[i]
        elif shape1[i] is None:
            # Unknown dim: the result is unknown unless the other side pins it.
            output_shape[i] = None if shape2[i] == 1 else shape2[i]
        else:
            if shape2[i] == 1 or shape2[i] is None or shape2[i] == shape1[i]:
                output_shape[i] = shape1[i]
            else:
                raise ValueError(
                    "Cannot broadcast shape, the failure dim has value "
                    f"{shape1[i]}, which cannot be broadcasted to {shape2[i]}. "
                    f"Input shapes are: {origin_shape1} and {origin_shape2}."
                )

    return output_shape
|
|
|
|
|
|
|
|
def compute_expand_dims_output_shape(input_shape, axis):
    """Compute the output shape for the `expand_dims` operation.

    Args:
        input_shape: Input shape.
        axis: int or sequence of ints for the axis to expand.

    Returns:
        Tuple of ints: The output shape after the `expand_dims` operation.
    """
    input_shape = list(input_shape)
    if axis is None:
        # Default: append the new axis at the end.
        axis = len(input_shape)
    axis = to_tuple_or_list(axis)
    output_ndim = len(axis) + len(input_shape)
    # Negative axes are resolved against the *output* rank.
    axis = [canonicalize_axis(a, output_ndim) for a in axis]
    remaining_dims = iter(input_shape)
    output_shape = []
    for position in range(output_ndim):
        # Inserted axes get size 1; everything else consumes the next
        # original dimension in order.
        output_shape.append(1 if position in axis else next(remaining_dims))
    return tuple(output_shape)
|
|
|
|
|
|
|
|
def compute_pooling_output_shape(
    input_shape,
    pool_size,
    strides,
    padding="valid",
    data_format="channels_last",
):
    """Computes the output shape of pooling operations.

    Args:
        input_shape: Input shape. Must be a tuple of integers.
        pool_size: Size of the pooling operation. Must be a tuple of integers.
        strides: Stride of the pooling operation. Must be a tuple of integers.
            Defaults to `pool_size`.
        padding: Padding method. Available methods are `"valid"` or `"same"`.
            Defaults to `"valid"`.
        data_format: String, either `"channels_last"` or `"channels_first"`.
            The ordering of the dimensions in the inputs. `"channels_last"`
            corresponds to inputs with shape `(batch, height, width, channels)`
            while `"channels_first"` corresponds to inputs with shape
            `(batch, channels, height, weight)`. Defaults to `"channels_last"`.

    Returns:
        Tuple of ints: The output shape of the pooling operation.

    Examples:

    # Basic usage with square pooling on a single image
    >>> compute_pooling_output_shape((1, 4, 4, 1), (2, 2))
    (1, 2, 2, 1)

    # Strided pooling on a single image with strides different from pool_size
    >>> compute_pooling_output_shape((1, 4, 4, 1), (2, 2), strides=(1, 1))
    (1, 3, 3, 1)

    # Pooling on a batch of images
    >>> compute_pooling_output_shape((32, 4, 4, 3), (2, 2))
    (32, 2, 2, 3)
    """
    if strides is None:
        strides = pool_size
    original_shape = list(input_shape)
    input_shape = np.array(input_shape)
    if data_format == "channels_last":
        spatial_shape = input_shape[1:-1]
    else:
        spatial_shape = input_shape[2:]
    # Unknown (`None`) spatial dims are temporarily marked -1 so the numpy
    # arithmetic below stays valid; they are restored to `None` at the end.
    none_dims = [i for i, dim in enumerate(spatial_shape) if dim is None]
    for i in none_dims:
        spatial_shape[i] = -1
    pool_size = np.array(pool_size)
    if padding == "valid":
        output_spatial_shape = (
            np.floor((spatial_shape - pool_size) / strides) + 1
        )
        for i, dim in enumerate(output_spatial_shape):
            # Negative sizes coming from a marked -1 dim are expected;
            # anywhere else they mean the pool is larger than the input.
            if i not in none_dims and dim < 0:
                raise ValueError(
                    "Computed output size would be negative. Received: "
                    f"`inputs.shape={input_shape}` and `pool_size={pool_size}`."
                )
    elif padding == "same":
        output_spatial_shape = np.floor((spatial_shape - 1) / strides) + 1
    else:
        raise ValueError(
            "Argument `padding` must be either 'valid' or 'same'. Received: "
            f"padding={padding}"
        )
    output_spatial_shape = [int(dim) for dim in output_spatial_shape]
    for i in none_dims:
        output_spatial_shape[i] = None
    output_spatial_shape = tuple(output_spatial_shape)
    if data_format == "channels_last":
        return (
            (original_shape[0],)
            + output_spatial_shape
            + (original_shape[-1],)
        )
    return (
        original_shape[0],
        original_shape[1],
    ) + output_spatial_shape
|
|
|
|
|
|
|
|
def compute_conv_output_shape(
    input_shape,
    filters,
    kernel_size,
    strides=1,
    padding="valid",
    data_format="channels_last",
    dilation_rate=1,
):
    """Compute the output shape of conv ops."""
    # Build the full kernel shape: spatial dims + (in_channels, out_channels).
    if data_format == "channels_last":
        spatial_shape = input_shape[1:-1]
        kernel_shape = kernel_size + (input_shape[-1], filters)
    else:
        spatial_shape = input_shape[2:]
        kernel_shape = kernel_size + (input_shape[1], filters)
    if len(kernel_shape) != len(input_shape):
        raise ValueError(
            "Kernel shape must have the same length as input, but received "
            f"kernel of shape {kernel_shape} and "
            f"input of shape {input_shape}."
        )
    num_spatial_dims = len(spatial_shape)
    # Scalars apply uniformly to every spatial dimension.
    if isinstance(dilation_rate, int):
        dilation_rate = (dilation_rate,) * num_spatial_dims
    if isinstance(strides, int):
        strides = (strides,) * num_spatial_dims
    if len(dilation_rate) != num_spatial_dims:
        raise ValueError(
            "Dilation must be None, scalar or tuple/list of length of "
            "inputs' spatial shape, but received "
            f"`dilation_rate={dilation_rate}` and "
            f"input of shape {input_shape}."
        )
    spatial_shape = np.array(spatial_shape)
    # Unknown (`None`) spatial dims are temporarily marked -1 so the numpy
    # arithmetic below stays valid; they are restored to `None` at the end.
    none_dims = [i for i, dim in enumerate(spatial_shape) if dim is None]
    for i in none_dims:
        spatial_shape[i] = -1

    kernel_spatial_shape = np.array(kernel_shape[:-2])
    dilation_rate = np.array(dilation_rate)
    if padding == "valid":
        # Effective kernel extent is dilation_rate * (k - 1) + 1.
        output_spatial_shape = (
            np.floor(
                (spatial_shape - dilation_rate * (kernel_spatial_shape - 1) - 1)
                / strides
            )
            + 1
        )
        for i, dim in enumerate(output_spatial_shape):
            # Negative sizes from a marked -1 dim are expected; anywhere
            # else they mean the (dilated) kernel exceeds the input.
            if i not in none_dims and dim < 0:
                raise ValueError(
                    "Computed output size would be negative. Received "
                    f"`inputs shape={input_shape}`, "
                    f"`kernel shape={kernel_shape}`, "
                    f"`dilation_rate={dilation_rate}`."
                )
    elif padding == "same" or padding == "causal":
        output_spatial_shape = np.floor((spatial_shape - 1) / strides) + 1
    else:
        raise ValueError(
            "`padding` must be either `'valid'` or `'same'`. Received "
            f"{padding}."
        )
    output_spatial_shape = tuple(
        None if i in none_dims else int(dim)
        for i, dim in enumerate(output_spatial_shape)
    )
    if data_format == "channels_last":
        return (input_shape[0],) + output_spatial_shape + (kernel_shape[-1],)
    return (input_shape[0], kernel_shape[-1]) + output_spatial_shape
|
|
|
|
|
|
|
|
def compute_matmul_output_shape(shape1, shape2):
    """Compute the output shape of a `matmul` operation.

    Args:
        shape1: Shape of the left operand.
        shape2: Shape of the right operand.

    Returns:
        Tuple of ints: The output shape for the `matmul` operation.
    """
    # Following `np.matmul` semantics, 1-D operands are promoted to 2-D
    # (a row vector on the left, a column vector on the right) and the
    # corresponding singleton output dim is removed again at the end.
    # Bug fix: the original ranks must be remembered *before* reassignment;
    # previously the `len(shape1) == 1` checks after the reassignment were
    # dead code, so 1-D operands kept their promoted singleton dims
    # (e.g. `(3,) @ (3,)` produced `(1, 1)` instead of `()`).
    shape1_is_1d = len(shape1) == 1
    shape2_is_1d = len(shape2) == 1
    if shape1_is_1d:
        shape1 = (1, shape1[0])
    if shape2_is_1d:
        shape2 = (shape2[0], 1)
    if (
        shape1[-1] is not None
        and shape2[-2] is not None
        and shape1[-1] != shape2[-2]
    ):
        raise ValueError(
            "Inner dimensions (`x1.shape[-1]` and `x2.shape[-2]`) must be "
            f"equal, but received `x1.shape={shape1}` and "
            f"`x2.shape={shape2}`."
        )

    # Batch dims broadcast; the last two dims follow matrix-product rules.
    leading_shape = broadcast_shapes(shape1[:-2], shape2[:-2])
    last_2_dims_shape = [shape1[-2], shape2[-1]]
    output_shape = leading_shape + last_2_dims_shape
    if shape1_is_1d:
        del output_shape[-2]
    if shape2_is_1d:
        del output_shape[-1]
    return tuple(output_shape)
|
|
|
|
|
|
|
|
def compute_reshape_output_shape(input_shape, newshape, newshape_arg_name):
    """Converts `-1` in `newshape` to either an actual dimension or `None`.

    This utility does not special case the 0th dimension (batch size).
    """
    unknown_dim_count = newshape.count(-1)
    if unknown_dim_count > 1:
        raise ValueError(
            "There must be at most one unknown dimension (-1) in "
            f"{newshape_arg_name}. Received: {newshape_arg_name}={newshape}."
        )

    # With unknown input dims the -1 entry cannot be resolved to a number;
    # it simply becomes `None` in the output.
    if None in input_shape:
        return tuple(None if dim == -1 else dim for dim in newshape)

    input_size = math.prod(input_shape)

    if unknown_dim_count == 0:
        # Fully specified target: only the total size needs to match.
        if input_size != math.prod(newshape):
            raise ValueError(
                "The total size of the tensor must be unchanged. Received: "
                f"input_shape={input_shape}, {newshape_arg_name}={newshape}"
            )
        return newshape

    # Exactly one -1: infer it from the remaining (known) dimensions.
    unknown_dim_index = newshape.index(-1)
    known_output_size = math.prod(dim for dim in newshape if dim != -1)

    if known_output_size == 0 or input_size % known_output_size != 0:
        raise ValueError(
            "The total size of the tensor must be unchanged, however, the "
            "input size cannot by divided by the specified dimensions in "
            f"{newshape_arg_name}. Received: input_shape={input_shape}, "
            f"{newshape_arg_name}={newshape}"
        )

    output_shape = list(newshape)
    output_shape[unknown_dim_index] = input_size // known_output_size
    return tuple(output_shape)
|
|
|
|
|
|
|
|
def compute_transpose_output_shape(input_shape, axes):
    """Compute the output shape for the `transpose` operation.

    Args:
        input_shape: Input shape.
        axes: Permutation of the dimensions for the `transpose` operation.

    Returns:
        Tuple of ints: The output shape after the `transpose` operation.
    """
    input_shape = list(input_shape)
    if axes is None:
        # Default transpose: reverse all dimensions.
        input_shape.reverse()
        return tuple(input_shape)

    if len(axes) != len(input_shape):
        raise ValueError(
            "axis must be a list of the same length as the input shape, "
            f"expected {len(input_shape)}, but received {len(axes)}."
        )
    permuted_dims = [input_shape[ax] for ax in axes]
    return tuple(permuted_dims)
|
|
|
|
|
|
|
|
def compute_take_along_axis_output_shape(input_shape, indices_shape, axis):
    """Compute the output shape of a `take_along_axis` operation.

    Args:
        input_shape: Shape of the array values are taken from.
        indices_shape: Shape of the indices array.
        axis: Axis along which to take values, or `None` to operate on the
            flattened input (matching `np.take_along_axis` semantics).

    Returns:
        List of ints (or `None` for unknown dims): the output shape.

    Raises:
        ValueError: If the two shapes have different ranks (after
            flattening when `axis is None`).
    """
    input_shape = list(input_shape)
    indices_shape = list(indices_shape)
    if axis is None:
        # `axis=None` flattens the input to 1-D first; an unknown dim makes
        # the flattened size unknown too.
        input_shape = (
            [None] if None in input_shape else [int(np.prod(input_shape))]
        )
        # Bug fix: the code below previously ran with `axis` still `None`,
        # so `input_shape[axis]` raised `TypeError: list indices must be
        # integers`. After flattening, the only remaining dim is axis 0.
        axis = 0

    if len(input_shape) != len(indices_shape):
        raise ValueError(
            "`x` and `indices` must have the same number of dimensions, "
            f"but receive shape {input_shape} and {indices_shape}."
        )

    # Along `axis`, the output size is dictated by the indices.
    input_shape[axis] = indices_shape[axis]
    output_shape = broadcast_shapes(input_shape, indices_shape)
    return output_shape
|
|
|
|
|
|
|
|
def reduce_shape(shape, axis=None, keepdims=False):
    """Compute the shape produced by reducing `shape` over `axis`.

    Args:
        shape: The input shape, as a tuple or list of ints.
        axis: Iterable of ints to reduce over, or `None` to reduce over
            every dimension.
        keepdims: If `True`, reduced dimensions are kept with size 1;
            otherwise they are removed.

    Returns:
        Tuple: the reduced shape.
    """
    shape = list(shape)
    if axis is None:
        # Full reduction: all-ones shape, or a scalar (empty) shape.
        return tuple([1] * len(shape)) if keepdims else tuple()

    if keepdims:
        for ax in axis:
            shape[ax] = 1
        return tuple(shape)
    # Delete from the highest index down so earlier deletions do not
    # shift the remaining axes.
    for ax in sorted(axis, reverse=True):
        del shape[ax]
    return tuple(shape)
|
|
|
|
|
|
|
|
@keras_export("keras.utils.get_source_inputs")
def get_source_inputs(tensor):
    """Returns the list of input tensors necessary to compute `tensor`.

    Output will always be a list of tensors
    (potentially with 1 element).

    Args:
        tensor: The tensor to start from.

    Returns:
        List of input tensors.
    """
    # A tensor without Keras history is not part of a functional graph;
    # return it unchanged. (NOTE: this returns the raw tensor, not a list.)
    if not hasattr(tensor, "_keras_history"):
        return tensor

    operation, node_index, _ = tensor._keras_history
    if not operation or not operation._inbound_nodes:
        return [tensor]

    node = operation._inbound_nodes[node_index]
    if node.is_input:
        # Reached an input node: this is a source.
        return tree.flatten(node.output_tensors)

    # Recurse upstream through every inbound tensor, de-duplicating by
    # identity while preserving discovery order.
    source_tensors = []
    for inbound_tensor in node.input_tensors:
        for source in get_source_inputs(inbound_tensor):
            if all(source is not existing for existing in source_tensors):
                source_tensors.append(source)
    return source_tensors
|
|
|