| """ |
Backends in `einops` are organized to meet the following requirements:
- backends are not imported unless they are actually needed, because
    - backends may not be installed
    - importing all available backends would lead to a significant memory footprint
    - backends may be present but installed with errors (and never used);
      importing them may lead to crashes
- a backend should be either symbolic or imperative
    - this determines which methods (from_numpy/to_numpy or create_symbol/eval_symbol) should be defined
- if a backend can't provide symbols for shape dimensions, UnknownSize objects are used
| """ |
|
|
| import sys |
|
|
__author__ = "Alex Rogozhnikov"

# framework_name -> backend instance; filled lazily by get_backend
_loaded_backends: dict = dict()
# type(tensor) -> backend instance; cache that lets get_backend skip the scan on repeat calls
_type2backend: dict = dict()
# when True, get_backend prints which backend subclasses it probes and instantiates
_debug_importing = False
|
|
|
|
def get_backend(tensor) -> "AbstractBackend":
    """
    Return the backend able to handle ``tensor`` (e.g. the numpy backend for ``numpy.ndarray``).

    Lookup order:
      1. the per-type cache ``_type2backend``;
      2. backends that were already instantiated (``_loaded_backends``);
      3. every direct or indirect subclass of ``AbstractBackend`` that has not been
         instantiated yet and whose ``framework_name`` is present in ``sys.modules``;
         each such backend is instantiated and registered in ``_loaded_backends``.
    The matching backend is cached under ``type(tensor)`` before being returned.

    Raises:
        RuntimeError: if no backend recognizes the tensor.
    """
    tensor_type = type(tensor)
    cached = _type2backend.get(tensor_type, None)
    if cached is not None:
        return cached

    # snapshot, because the discovery loop below adds entries to _loaded_backends
    already_loaded = list(_loaded_backends.items())
    for _, candidate in already_loaded:
        if candidate.is_appropriate_type(tensor):
            _type2backend[tensor_type] = candidate
            return candidate

    # collect all subclasses, including indirect ones (e.g. JaxBackend derives from NumpyBackend)
    pending = AbstractBackend.__subclasses__()
    all_subclasses = []
    while pending:
        subclass = pending.pop()
        pending += subclass.__subclasses__()
        all_subclasses.append(subclass)

    known_names = [name for name, _ in already_loaded]
    for BackendSubclass in all_subclasses:
        if _debug_importing:
            print("Testing for subclass of ", BackendSubclass)
        if BackendSubclass.framework_name in known_names:
            continue
        # only frameworks whose module is already in sys.modules are instantiated
        if BackendSubclass.framework_name not in sys.modules:
            continue
        if _debug_importing:
            print("Imported backend for ", BackendSubclass.framework_name)
        backend = BackendSubclass()
        _loaded_backends[backend.framework_name] = backend
        if backend.is_appropriate_type(tensor):
            _type2backend[tensor_type] = backend
            return backend

    raise RuntimeError(f"Tensor type unknown to einops {type(tensor)}")
|
|
|
|
class AbstractBackend:
    """Base class for einops backends; much of its interface exists only for testing/debugging.

    Subclasses set ``framework_name`` and override what their framework supports:
    imperative backends implement ``from_numpy``/``to_numpy``, symbolic ones implement
    ``create_symbol``/``eval_symbol``. The default ``shape``/``reshape``/``transpose``/``reduce``
    call the identically named methods/attributes on the tensor itself.
    """

    framework_name: str

    def is_appropriate_type(self, tensor):
        """Return whether this backend can handle ``tensor``."""
        raise NotImplementedError()

    def from_numpy(self, x):
        """Convert a numpy array into this framework's tensor type (imperative backends only)."""
        raise NotImplementedError("framework doesn't support imperative execution")

    def to_numpy(self, x):
        """Convert a framework tensor back into a numpy array (imperative backends only)."""
        raise NotImplementedError("framework doesn't support imperative execution")

    def create_symbol(self, shape):
        """Create a placeholder symbol of the given shape (symbolic backends only)."""
        raise NotImplementedError("framework doesn't support symbolic computations")

    def eval_symbol(self, symbol, symbol_value_pairs):
        """Evaluate ``symbol``; ``symbol_value_pairs`` is a sequence of (symbol, value) pairs."""
        raise NotImplementedError("framework doesn't support symbolic computations")

    def arange(self, start, stop):
        """Return the integer range [start, stop) as a framework tensor."""
        raise NotImplementedError("framework doesn't implement arange")

    def shape(self, x):
        """Return a tuple of ints and/or "shape symbols" that evaluate to the actual sizes."""
        return x.shape

    def reshape(self, x, shape):
        return x.reshape(shape)

    def transpose(self, x, axes):
        return x.transpose(axes)

    def reduce(self, x, operation, axes):
        """Apply the tensor method named ``operation`` (e.g. "sum") over ``axes``."""
        reducer = getattr(x, operation)
        return reducer(axis=axes)

    def stack_on_zeroth_dimension(self, tensors: list):
        raise NotImplementedError()

    def add_axis(self, x, new_position):
        """Insert a length-1 axis at ``new_position``."""
        raise NotImplementedError()

    def add_axes(self, x, n_axes, pos2len):
        """Insert axes at the positions in ``pos2len`` and repeat along them to the given lengths.

        ``n_axes`` is the length of the repeat list handed to ``tile``, i.e. it should equal the
        rank of ``x`` after the insertions; axes not listed in ``pos2len`` are repeated once.
        """
        tile_counts = [1] * n_axes
        for position, length in pos2len.items():
            x = self.add_axis(x, position)
            tile_counts[position] = length
        return self.tile(x, tuple(tile_counts))

    def tile(self, x, repeats):
        """Repeat ``x``; ``repeats`` has one entry per axis of ``x``."""
        raise NotImplementedError()

    def concat(self, tensors, axis: int):
        """Concatenate ``tensors`` along ``axis``.

        Tensors are assumed to share device, dtype and every dimension except ``axis``.
        """
        raise NotImplementedError()

    def is_float_type(self, x):
        """Return whether ``x`` has a floating-point dtype."""
        raise NotImplementedError()

    def layers(self):
        raise NotImplementedError("backend does not provide layers")

    def __repr__(self):
        return "<einops backend for {}>".format(self.framework_name)

    def einsum(self, pattern, *x):
        raise NotImplementedError("backend does not support einsum")
|
|
|
|
class UnknownSize:
    """Stand-in for a shape element whose size a symbolic framework does not expose.

    Multiplication (either side) and floor division return the very same object,
    it compares equal to anything, and every instance hashes like ``None``.
    """

    def __eq__(self, other):
        # the real size is unknown, so treat it as matching any value
        return True

    def __hash__(self):
        # consistent with __eq__: all instances share one hash
        return hash(None)

    def __mul__(self, other):
        # unknown times anything stays unknown
        return self

    __rmul__ = __mul__

    def __floordiv__(self, other):
        return self
|
|
|
|
class NumpyBackend(AbstractBackend):
    """Imperative backend for ``numpy.ndarray``.

    Every array operation goes through ``self.np``; ``JaxBackend`` reuses this class
    by pointing ``self.np`` at ``jax.numpy``.
    """

    framework_name = "numpy"

    def __init__(self):
        import numpy as np

        self.np = np

    def is_appropriate_type(self, tensor):
        return isinstance(tensor, self.np.ndarray)

    def from_numpy(self, x):
        # numpy arrays are already in this backend's format
        return x

    def to_numpy(self, x):
        return x

    def arange(self, start, stop):
        return self.np.arange(start, stop)

    def stack_on_zeroth_dimension(self, tensors: list):
        return self.np.stack(tensors)

    def tile(self, x, repeats):
        return self.np.tile(x, repeats)

    def concat(self, tensors, axis: int):
        return self.np.concatenate(tensors, axis=axis)

    def is_float_type(self, x):
        # dtype is compared against the names of floating-point dtypes
        float_names = ("float16", "float32", "float64", "float128", "bfloat16")
        return x.dtype in float_names

    def add_axis(self, x, new_position):
        return self.np.expand_dims(x, new_position)

    def einsum(self, pattern, *x):
        return self.np.einsum(pattern, *x)
|
|
|
|
class JaxBackend(NumpyBackend):
    """Backend for jax arrays: NumpyBackend with ``self.np`` switched to ``jax.numpy``."""

    framework_name = "jax"

    def __init__(self):
        super().__init__()
        # keep plain numpy around for conversions back to host arrays
        self.onp = self.np

        import jax.numpy as jnp

        self.np = jnp

    def from_numpy(self, x):
        return self.np.asarray(x)

    def to_numpy(self, x):
        return self.onp.asarray(x)
|
|
|
|
class TorchBackend(AbstractBackend):
    """Imperative backend for ``torch.Tensor``."""

    framework_name = "torch"

    def __init__(self):
        import torch

        self.torch = torch
        # imported only for its side effects; the module is not used here.
        # NOTE(review): presumably registers einops ops with torch (e.g. for torch.compile) — confirm
        from . import _torch_specific  # noqa: F401

    def is_appropriate_type(self, tensor):
        return isinstance(tensor, self.torch.Tensor)

    def from_numpy(self, x):
        tensor = self.torch.from_numpy(x)
        # gradients are attached only to floating-point tensors
        if self.is_float_type(tensor):
            tensor.requires_grad = True
        return tensor

    def to_numpy(self, x):
        return x.detach().cpu().numpy()

    def arange(self, start, stop):
        return self.torch.arange(start, stop, dtype=self.torch.int64)

    def reduce(self, x, operation, reduced_axes):
        # reductions that are called once with all dims at the same time
        method_name = {"min": "amin", "max": "amax", "sum": "sum", "mean": "mean"}.get(operation)
        if method_name is not None:
            return getattr(x, method_name)(dim=reduced_axes)
        if operation in ("any", "all", "prod"):
            # reduced one dim per call, from the last axis backwards so that the
            # indices of the axes still to be reduced remain valid
            for axis in reversed(sorted(reduced_axes)):
                x = getattr(x, operation)(dim=axis)
            return x
        raise NotImplementedError("Unknown reduction ", operation)

    def transpose(self, x, axes):
        return x.permute(axes)

    def stack_on_zeroth_dimension(self, tensors: list):
        return self.torch.stack(tensors)

    def add_axes(self, x, n_axes, pos2len):
        # -1 in Tensor.expand keeps an axis at its current size, so only the
        # newly inserted axes are broadcast (no data copy, unlike tile)
        target_sizes = [-1] * n_axes
        for position, length in pos2len.items():
            x = self.add_axis(x, position)
            target_sizes[position] = length
        return x.expand(target_sizes)

    def tile(self, x, repeats):
        return x.repeat(repeats)

    def concat(self, tensors, axis: int):
        return self.torch.cat(tensors, dim=axis)

    def add_axis(self, x, new_position):
        return self.torch.unsqueeze(x, new_position)

    def is_float_type(self, x):
        float_dtypes = [self.torch.float16, self.torch.float32, self.torch.float64, self.torch.bfloat16]
        return x.dtype in float_dtypes

    def layers(self):
        from .layers import torch as torch_layers

        return torch_layers

    def einsum(self, pattern, *x):
        return self.torch.einsum(pattern, *x)
|
|
|
|
class CupyBackend(AbstractBackend):
    """Imperative backend for ``cupy.ndarray``; mirrors the numpy backend through the cupy API."""

    framework_name = "cupy"

    def __init__(self):
        import cupy as cp

        self.cupy = cp

    def is_appropriate_type(self, tensor):
        return isinstance(tensor, self.cupy.ndarray)

    def from_numpy(self, x):
        return self.cupy.asarray(x)

    def to_numpy(self, x):
        return self.cupy.asnumpy(x)

    def arange(self, start, stop):
        return self.cupy.arange(start, stop)

    def stack_on_zeroth_dimension(self, tensors: list):
        return self.cupy.stack(tensors)

    def tile(self, x, repeats):
        return self.cupy.tile(x, repeats)

    def concat(self, tensors, axis: int):
        return self.cupy.concatenate(tensors, axis=axis)

    def add_axis(self, x, new_position):
        return self.cupy.expand_dims(x, new_position)

    def is_float_type(self, x):
        # dtype is compared against the names of floating-point dtypes
        float_names = ("float16", "float32", "float64", "float128", "bfloat16")
        return x.dtype in float_names

    def einsum(self, pattern, *x):
        return self.cupy.einsum(pattern, *x)
|
|
|
|
class HashableTuple:
    """Tuple-like wrapper that stays hashable even when its elements (e.g. symbols) are not.

    No ``__eq__``/``__hash__`` is defined, so hashing and equality are the default
    identity-based ones inherited from ``object``.
    """

    def __init__(self, elements: tuple):
        self.elements = elements

    def __len__(self):
        return len(self.elements)

    def __getitem__(self, item):
        return self.elements[item]

    def __iter__(self):
        return iter(self.elements)
|
|
| |
|
|
|
|
class TensorflowBackend(AbstractBackend):
    """Backend for ``tf.Tensor`` / ``tf.Variable``.

    Conversions to/from numpy require eager execution. ``shape`` works both eagerly
    and in graph mode (where some dimensions may only be known symbolically).
    """

    framework_name = "tensorflow"

    def __init__(self):
        import tensorflow

        self.tf = tensorflow

    def is_appropriate_type(self, tensor):
        return isinstance(tensor, (self.tf.Tensor, self.tf.Variable))

    def from_numpy(self, x):
        assert self.tf.executing_eagerly()
        return self.tf.convert_to_tensor(x)

    def to_numpy(self, x):
        assert self.tf.executing_eagerly()
        return x.numpy()

    def arange(self, start, stop):
        return self.tf.range(start, stop)

    def shape(self, x):
        """Return the shape of ``x`` as a hashable tuple-like object.

        Eager mode: a tuple of ints, with ``UnknownSize()`` for dimensions reported as ``None``.
        Graph mode: static sizes where they are known (truthy), otherwise the matching
        component of ``tf.shape(x)``; wrapped in ``HashableTuple`` if any element is unhashable.
        """
        if self.tf.executing_eagerly():
            return tuple(UnknownSize() if d is None else int(d) for d in x.shape)

        static_shape = x.shape.as_list()
        dynamic_shape = self.tf.shape(x)
        # falls back to the symbolic component when the static size is None (or 0)
        shape = tuple([s or dynamic_shape[dim] for dim, s in enumerate(static_shape)])
        try:
            hash(shape)
        except TypeError:
            # hash() signals unhashable elements (symbolic tensors) with TypeError;
            # narrower than the former BaseException, which also swallowed
            # KeyboardInterrupt/SystemExit
            return HashableTuple(shape)
        return shape

    def reduce(self, x, operation, axes):
        # dispatches to tf.reduce_<operation>, e.g. tf.reduce_sum
        return getattr(self.tf, "reduce_" + operation)(x, axis=axes)

    def reshape(self, x, shape):
        return self.tf.reshape(x, shape)

    def transpose(self, x, axes):
        return self.tf.transpose(x, axes)

    def stack_on_zeroth_dimension(self, tensors: list):
        return self.tf.stack(tensors)

    def tile(self, x, repeats):
        return self.tf.tile(x, repeats)

    def concat(self, tensors, axis: int):
        return self.tf.concat(tensors, axis=axis)

    def add_axis(self, x, new_position):
        return self.tf.expand_dims(x, new_position)

    def is_float_type(self, x):
        return x.dtype in ("float16", "float32", "float64", "float128", "bfloat16")

    def layers(self):
        from .layers import tensorflow

        return tensorflow

    def einsum(self, pattern, *x):
        return self.tf.einsum(pattern, *x)
|
|
|
|
class TFKerasBackend(AbstractBackend):
    """Symbolic backend for Keras tensors from ``tf.keras``.

    Symbols are ``keras.Input`` placeholders. They are evaluated by building a
    ``keras.models.Model`` from the placeholders to the symbol and running it on one batch.
    """

    framework_name = "tensorflow.keras"

    def __init__(self):
        import tensorflow

        self.tf = tensorflow
        self.keras = tensorflow.keras
        self.K = tensorflow.keras.backend

    def is_appropriate_type(self, tensor):
        if not self.tf.is_tensor(tensor):
            return False
        return self.K.is_keras_tensor(tensor)

    def create_symbol(self, shape):
        return self.keras.Input(batch_shape=shape)

    def eval_symbol(self, symbol, symbol_value_pairs):
        inputs = [placeholder for placeholder, _ in symbol_value_pairs]
        model = self.keras.models.Model(inputs, symbol)
        values = [value for _, value in symbol_value_pairs]
        return model.predict_on_batch(values)

    def arange(self, start, stop):
        return self.K.arange(start, stop)

    def shape(self, x):
        # elements of K.shape are not necessarily hashable, hence the wrapper
        return HashableTuple(tuple(self.K.shape(x)))

    def reduce(self, x, operation, axes):
        # dispatches to the keras backend function of the same name, e.g. K.sum
        reducer = getattr(self.K, operation)
        return reducer(x, axis=axes)

    def reshape(self, x, shape):
        return self.K.reshape(x, shape)

    def transpose(self, x, axes):
        return self.K.permute_dimensions(x, axes)

    def stack_on_zeroth_dimension(self, tensors: list):
        return self.K.stack(tensors)

    def tile(self, x, repeats):
        return self.K.tile(x, repeats)

    def concat(self, tensors, axis: int):
        return self.K.concatenate(tensors, axis=axis)

    def add_axis(self, x, new_position):
        return self.K.expand_dims(x, new_position)

    def is_float_type(self, x):
        # K.dtype returns the dtype name as a string
        dtype_name = self.K.dtype(x)
        return "float" in dtype_name

    def layers(self):
        from .layers import keras as keras_layers

        return keras_layers
|
|
|
|
class OneFlowBackend(AbstractBackend):
    """Imperative backend for ``oneflow.Tensor`` (torch-like API)."""

    framework_name = "oneflow"

    def __init__(self):
        import oneflow

        self.flow = oneflow

    def is_appropriate_type(self, tensor):
        return isinstance(tensor, self.flow.Tensor)

    def from_numpy(self, x):
        tensor = self.flow.from_numpy(x)
        # gradients are attached only to floating-point tensors
        if self.is_float_type(tensor):
            tensor.requires_grad = True
        return tensor

    def to_numpy(self, x):
        return x.detach().cpu().numpy()

    def arange(self, start, stop):
        return self.flow.arange(start, stop, dtype=self.flow.int64)

    def reduce(self, x, operation, reduced_axes):
        # one axis per call, last axis first so indices of the remaining axes stay valid
        for axis in sorted(reduced_axes, reverse=True):
            if operation in ("min", "max"):
                # these return a pair; only the first element (presumably the values) is kept
                x, _ = getattr(x, operation)(dim=axis)
            elif operation in ("sum", "mean", "prod", "any", "all"):
                x = getattr(x, operation)(dim=axis)
            else:
                raise NotImplementedError("Unknown reduction ", operation)
        return x

    def transpose(self, x, axes):
        return x.permute(axes)

    def stack_on_zeroth_dimension(self, tensors: list):
        return self.flow.stack(tensors)

    def add_axes(self, x, n_axes, pos2len):
        # -1 leaves an existing axis as is; only the inserted axes are expanded
        target_sizes = [-1] * n_axes
        for position, length in pos2len.items():
            x = self.add_axis(x, position)
            target_sizes[position] = length
        return x.expand(*target_sizes)

    def tile(self, x, repeats):
        return x.repeat(repeats)

    def concat(self, tensors, axis: int):
        return self.flow.concat(tensors, dim=axis)

    def add_axis(self, x, new_position):
        return self.flow.unsqueeze(x, new_position)

    def is_float_type(self, x):
        float_dtypes = [self.flow.float16, self.flow.float32, self.flow.float64]
        return x.dtype in float_dtypes

    def layers(self):
        from .layers import oneflow as oneflow_layers

        return oneflow_layers

    def einsum(self, pattern, *x):
        return self.flow.einsum(pattern, *x)
|
|
|
|
class PaddleBackend(AbstractBackend):
    """Imperative backend for paddle tensors."""

    framework_name = "paddle"

    def __init__(self):
        import paddle as pd

        self.paddle = pd

    def is_appropriate_type(self, tensor):
        return self.paddle.is_tensor(tensor)

    def from_numpy(self, x):
        converted = self.paddle.to_tensor(x)
        # enable gradient tracking on the new tensor
        converted.stop_gradient = False
        return converted

    def to_numpy(self, x):
        return x.detach().numpy()

    def arange(self, start, stop):
        return self.paddle.arange(start, stop, dtype=self.paddle.int64)

    def reduce(self, x, operation, axes):
        reduces_everything = len(axes) == x.ndim
        result = super().reduce(x, operation, axes)
        if reduces_everything:
            # a full reduction is squeezed on axis 0 to yield a 0-d result
            # (NOTE(review): presumably paddle returns shape [1] here — confirm)
            result = result.squeeze(0)
        return result

    def transpose(self, x, axes):
        return x.transpose(axes)

    def add_axes(self, x, n_axes, pos2len):
        # -1 leaves an existing axis as is; only the inserted axes are expanded
        target_sizes = [-1] * n_axes
        for position, length in pos2len.items():
            x = self.add_axis(x, position)
            target_sizes[position] = length
        return x.expand(target_sizes)

    def stack_on_zeroth_dimension(self, tensors: list):
        return self.paddle.stack(tensors)

    def reshape(self, x, shape):
        return x.reshape(shape)

    def tile(self, x, repeats):
        return x.tile(repeats)

    def concat(self, tensors, axis: int):
        return self.paddle.concat(tensors, axis=axis)

    def add_axis(self, x, new_position):
        return x.unsqueeze(new_position)

    def is_float_type(self, x):
        float_dtypes = [self.paddle.float16, self.paddle.float32, self.paddle.float64]
        return x.dtype in float_dtypes

    def layers(self):
        from .layers import paddle as paddle_layers

        return paddle_layers

    def einsum(self, pattern, *x):
        return self.paddle.einsum(pattern, *x)

    def shape(self, x):
        return tuple(x.shape)
|
|
|
|
class TinygradBackend(AbstractBackend):
    """Imperative backend for ``tinygrad.Tensor``."""

    framework_name = "tinygrad"

    def __init__(self):
        import tinygrad as tg

        self.tinygrad = tg

    def is_appropriate_type(self, tensor):
        return isinstance(tensor, self.tinygrad.Tensor)

    def from_numpy(self, x):
        return self.tinygrad.Tensor(x)

    def to_numpy(self, x):
        return x.numpy()

    def arange(self, start, stop):
        return self.tinygrad.Tensor.arange(start, stop)

    def shape(self, x):
        return x.shape

    def reshape(self, x, shape):
        return x.reshape(shape)

    def transpose(self, x, axes):
        return x.permute(axes)

    def reduce(self, x, operation, axes):
        # one axis per call, last axis first so indices of the remaining axes stay valid
        for axis in sorted(axes, reverse=True):
            reducer = getattr(x, operation)
            x = reducer(axis=axis)
        return x

    def stack_on_zeroth_dimension(self, tensors: list):
        return self.tinygrad.Tensor.stack(tensors)

    def add_axis(self, x, new_position):
        return x.unsqueeze(new_position)

    def tile(self, x, repeats):
        return x.repeat(repeats)

    def concat(self, tensors, axis: int):
        # a single tensor is returned unchanged; otherwise the first one is
        # concatenated with the rest via Tensor.cat
        head = tensors[0]
        if len(tensors) <= 1:
            return head
        return head.cat(*tensors[1:], dim=axis)

    def is_float_type(self, x):
        return self.tinygrad.dtypes.is_float(x.dtype)

    def einsum(self, pattern, *x):
        return self.tinygrad.Tensor.einsum(pattern, *x)
|
|
|
|
class PyTensorBackend(AbstractBackend):
    """Backend for ``pytensor`` tensor variables.

    Supports both symbolic use (``create_symbol``/``eval_symbol``) and imperative
    conversion (``from_numpy`` builds a constant, ``to_numpy`` evaluates it).
    """

    framework_name = "pytensor"

    def __init__(self):
        from pytensor import tensor

        self.pt = tensor

    def is_appropriate_type(self, tensor):
        return isinstance(tensor, self.pt.TensorVariable)

    def is_float_type(self, x):
        return x.dtype in self.pt.type.float_dtypes

    def from_numpy(self, x):
        return self.pt.as_tensor(x)

    def to_numpy(self, x):
        return x.eval()

    def create_symbol(self, shape):
        # a bare dimension is promoted to a 1-tuple. The tuple-of-types form of
        # isinstance also works before Python 3.10, where `tuple | list` raises TypeError.
        if not isinstance(shape, (tuple, list)):
            shape = (shape,)
        return self.pt.tensor(shape=shape)

    def eval_symbol(self, symbol, symbol_value_pairs):
        # pairs are (symbol, value); pytensor's eval takes them as a mapping
        return symbol.eval(dict(symbol_value_pairs))

    def arange(self, start, stop):
        return self.pt.arange(start, stop)

    def shape(self, x):
        # use the static size from the type where it is known (not None),
        # otherwise the symbolic component of x.shape
        return tuple(
            static_dim if static_dim is not None else symbolic_dim
            for static_dim, symbolic_dim in zip(x.type.shape, x.shape)
        )

    def stack_on_zeroth_dimension(self, tensors: list):
        return self.pt.stack(tensors)

    def tile(self, x, repeats):
        return self.pt.tile(x, repeats)

    def concat(self, tensors, axis: int):
        return self.pt.concatenate(tensors, axis=axis)

    def add_axis(self, x, new_position):
        return self.pt.expand_dims(x, new_position)

    def einsum(self, pattern, *x):
        return self.pt.einsum(pattern, *x)
|
|
|
|
class MLXBackend(AbstractBackend):
    """Imperative backend for ``mlx.core.array``.

    Axis insertion is provided by ``add_axis``; the inherited ``AbstractBackend.add_axes``
    combines it with ``tile`` to insert and repeat new axes.
    """

    framework_name = "mlx"

    def __init__(self):
        import mlx.core as mx
        import numpy as np

        self.mx = mx
        # numpy is used to convert mlx arrays back into host ndarrays
        self.np = np

    def is_appropriate_type(self, tensor):
        return isinstance(tensor, self.mx.array)

    def from_numpy(self, x):
        return self.mx.array(x)

    def to_numpy(self, x):
        # bfloat16 arrays are upcast to float32 before conversion
        # (presumably because numpy has no native bfloat16 dtype)
        if x.dtype == self.mx.bfloat16:
            x = x.astype(self.mx.float32)
        return self.np.array(x)

    def arange(self, start, stop):
        return self.mx.arange(start, stop)

    def stack_on_zeroth_dimension(self, tensors: list):
        return self.mx.stack(tensors)

    def add_axis(self, x, new_position):
        # Previously misnamed `add_axes(x, new_position)`: that overrode the inherited
        # add_axes(x, n_axes, pos2len) with an incompatible signature and left add_axis
        # unimplemented, so any axis insertion on mlx arrays failed.
        return self.mx.expand_dims(x, new_position)

    def tile(self, x, repeats):
        return self.mx.tile(x, repeats)

    def concat(self, tensors, axis: int):
        return self.mx.concatenate(tensors, axis=axis)

    def is_float_type(self, x):
        return self.mx.issubdtype(x.dtype, self.mx.floating)

    def einsum(self, pattern, *x):
        return self.mx.einsum(pattern, *x)
|
|