| | import re |
| |
|
| | from keras.src.api_export import keras_export |
| | from keras.src.backend import KerasTensor |
| | from keras.src.backend import any_symbolic_tensors |
| | from keras.src.ops.core import shape |
| | from keras.src.ops.numpy import prod |
| | from keras.src.ops.numpy import reshape |
| | from keras.src.ops.numpy import transpose |
| | from keras.src.ops.operation import Operation |
| |
|
| |
|
| | def _create_axes_map(axes, input_shape, axes_lengths): |
| | axes_map = {} |
| |
|
| | for axis, dim in zip(axes, input_shape): |
| | |
| | grouped_axes = re.match(r"\(([\w\s]+)\)", axis) |
| |
|
| | if grouped_axes: |
| | inner_axes = grouped_axes.group(1).split() |
| | known_axes = [a for a in inner_axes if a in axes_lengths] |
| | inferred_axes = [a for a in inner_axes if a not in axes_lengths] |
| |
|
| | if inferred_axes: |
| | inferred_axis = inferred_axes[0] |
| | known_product = prod([axes_lengths[a] for a in known_axes]) |
| | axes_lengths[inferred_axis] = dim // known_product |
| |
|
| | axes_map.update({a: axes_lengths[a] for a in inner_axes}) |
| | else: |
| | axes_map[axis] = dim |
| |
|
| | return axes_map |
| |
|
| |
|
| | def _create_grouped_axes(axes): |
| | grouped_output_axes = [] |
| | for axis in axes: |
| | grouped_axes = re.match(r"\(([\w\s]+)\)", axis) |
| |
|
| | if grouped_axes: |
| | inner_axes = grouped_axes.group(1).split() |
| | grouped_output_axes.append(inner_axes) |
| | else: |
| | grouped_output_axes.append([axis]) |
| |
|
| | return grouped_output_axes |
| |
|
| |
|
| | def _flatten_group(axes): |
| | return [x for xs in axes for x in xs] |
| |
|
| |
|
def _get_transpose_order(from_shape, to_shape):
    """Return, for each name in `to_shape`, its index among the source axes.

    Args:
        from_shape: Source pattern tokens; grouped tokens are expanded into
            their elementary axis names before indexing.
        to_shape: Flat sequence of elementary axis names in target order.

    Returns:
        List of integer positions, one per entry of `to_shape`.
    """
    source_names = _flatten_group(_create_grouped_axes(from_shape))

    order = []
    for name in to_shape:
        order.append(source_names.index(name))
    return order
| |
|
| |
|
| | def _compute_output_shape(axes_map, grouped_axes): |
| | output_shape = [] |
| | for group in grouped_axes: |
| | size = 1 |
| | for axis in group: |
| | size *= axes_map[axis] |
| | output_shape.append(size) |
| |
|
| | return tuple(output_shape) |
| |
|
| |
|
| | def _compute_decomposed_shape(input_axes, axes_lengths, axes_map): |
| | reshaped_input_axes = [] |
| | reshaped_sizes = [] |
| |
|
| | for axis in input_axes: |
| | if "(" in axis: |
| | inner_axes = re.findall(r"\w+", axis) |
| | sizes = [axes_lengths[a] for a in inner_axes] |
| | reshaped_input_axes.extend(inner_axes) |
| | reshaped_sizes.extend(sizes) |
| | else: |
| | reshaped_input_axes.append(axis) |
| | reshaped_sizes.append(axes_map[axis]) |
| |
|
| | return reshaped_sizes |
| |
|
| |
|
class Rearrange(Operation):
    """Operation wrapper around `rearrange`.

    `call` forwards to `rearrange`; `compute_output_spec` derives the output
    shape from the pattern without rearranging any data.
    """

    def call(self, tensor, pattern, **axes_lengths):
        return rearrange(tensor, pattern, **axes_lengths)

    def compute_output_spec(self, tensor, pattern, **axes_lengths):
        # Split "<input> -> <output>" and tokenize each side into bare axis
        # names and parenthesized groups.
        lhs, rhs = re.split(r"\s*->\s*", pattern)
        source_tokens = re.findall(r"\w+|\(.*?\)", lhs)
        target_tokens = re.findall(r"\w+|\(.*?\)", rhs)
        source_dims = shape(tensor)

        sizes = _create_axes_map(source_tokens, source_dims, axes_lengths)
        target_groups = _create_grouped_axes(target_tokens)
        result_shape = _compute_output_shape(sizes, target_groups)

        return KerasTensor(shape=result_shape, dtype=tensor.dtype)
| |
|
| |
|
@keras_export("keras.ops.rearrange")
def rearrange(tensor, pattern, **axes_lengths):
    """Rearranges the axes of a Keras tensor according to a specified pattern,
    einops-style.

    Args:
        tensor: Input Keras tensor.
        pattern: String describing the rearrangement in einops notation,
            of the form `"<input axes> -> <output axes>"`.
        **axes_lengths: Keyword arguments specifying lengths of axes
            when axes decomposition is used.

    Returns:
        Tensor: A Keras tensor with rearranged axes.

    Raises:
        ValueError: If `pattern` does not contain exactly one `->`.

    Follows the logic of:

    1. If decomposition is needed, reshape to match decomposed dimensions.
    2. Permute known and inferred axes to match the form of the output.
    3. Reshape to match the desired output shape.


    Example Usage:

    ```
    >>> import numpy as np
    >>> from keras.ops import rearrange
    >>> images = np.random.rand(32, 30, 40, 3) # BHWC format

    # Reordering to BCHW
    >>> rearrange(images, 'b h w c -> b c h w').shape
    TensorShape([32, 3, 30, 40])

    # "Merge" along first axis - concat images from a batch
    >>> rearrange(images, 'b h w c -> (b h) w c').shape
    TensorShape([960, 40, 3])

    # "Merge" along second axis - concat images horizontally
    >>> rearrange(images, 'b h w c -> h (b w) c').shape
    TensorShape([30, 1280, 3])

    # Flatten images into a CHW vector
    >>> rearrange(images, 'b h w c -> b (c h w)').shape
    TensorShape([32, 3600])

    # Decompose H and W axes into 4 smaller patches
    >>> rearrange(images, 'b (h1 h) (w1 w) c -> (b h1 w1) h w c', h1=2, w1=2).shape
    TensorShape([128, 15, 20, 3])

    # Space-to-depth decomposition of input axes
    >>> rearrange(images, 'b (h h1) (w w1) c -> b h w (c h1 w1)', h1=2, w1=2).shape
    TensorShape([32, 15, 20, 12])
    ```
    """
    # Validate up front so a malformed pattern fails with a clear message
    # instead of a tuple-unpacking error.
    pattern_sides = re.split(r"\s*->\s*", pattern)
    if len(pattern_sides) != 2:
        raise ValueError(
            "`pattern` must contain exactly one '->' separating the input "
            f"axes from the output axes. Received: pattern={pattern!r}"
        )

    if any_symbolic_tensors((tensor,)):
        return Rearrange().symbolic_call(tensor, pattern, **axes_lengths)

    # Tokenize each side into bare axis names and parenthesized groups.
    input_pattern, output_pattern = pattern_sides
    input_axes = re.findall(r"\w+|\(.*?\)", input_pattern)
    output_axes = re.findall(r"\w+|\(.*?\)", output_pattern)
    input_shape = shape(tensor)

    # Build the axis-name -> length map and the flattened output axes.
    # `_create_axes_map` also stores inferred lengths in `axes_lengths`,
    # which `_compute_decomposed_shape` reads below, so it must run first.
    axes_map = _create_axes_map(input_axes, input_shape, axes_lengths)
    grouped_output_axes = _create_grouped_axes(output_axes)
    flattened_output_axes = _flatten_group(grouped_output_axes)

    # 1. Axes decomposition: split grouped input axes into separate dims.
    # NOTE(review): `_compute_decomposed_shape` returns a list, so this
    # inequality is presumably always true wherever `tensor.shape` is a
    # plain tuple; the reshape is then a no-op rather than being skipped.
    decomposed_shapes = _compute_decomposed_shape(
        input_axes, axes_lengths, axes_map
    )
    if decomposed_shapes != tensor.shape:
        tensor = reshape(tensor, decomposed_shapes)

    # 2. Transpose the elementary axes into the order the output uses.
    permute_order = _get_transpose_order(input_axes, flattened_output_axes)
    tensor = transpose(tensor, permute_order)

    # 3. Reshape to the final shape, merging grouped output axes.
    output_shape = _compute_output_shape(axes_map, grouped_output_axes)
    tensor = reshape(tensor, output_shape)

    return tensor
| |
|