import numpy as np from keras.src import tree from keras.src.backend import config from keras.src.backend import standardize_dtype from keras.src.backend.common import dtypes from keras.src.backend.common.backend_utils import standardize_axis_for_numpy from keras.src.backend.numpy.core import convert_to_tensor def rot90(array, k=1, axes=(0, 1)): """Rotate an array by 90 degrees in the specified plane.""" if array.ndim < 2: raise ValueError( "Input array must have at least 2 dimensions. " f"Received: array.ndim={array.ndim}" ) if len(axes) != 2 or axes[0] == axes[1]: raise ValueError( f"Invalid axes: {axes}. Axes must be a tuple " "of two different dimensions." ) return np.rot90(array, k=k, axes=axes) def add(x1, x2): if not isinstance(x1, (int, float)): x1 = convert_to_tensor(x1) if not isinstance(x2, (int, float)): x2 = convert_to_tensor(x2) dtype = dtypes.result_type( getattr(x1, "dtype", type(x1)), getattr(x2, "dtype", type(x2)), ) x1 = convert_to_tensor(x1, dtype) x2 = convert_to_tensor(x2, dtype) return np.add(x1, x2) def einsum(subscripts, *operands, **kwargs): operands = tree.map_structure(convert_to_tensor, operands) dtypes_to_resolve = list(set(standardize_dtype(x.dtype) for x in operands)) # When operands are of int8, we cast the result to int32 to align with # the behavior of jax. if len(dtypes_to_resolve) == 1 and dtypes_to_resolve[0] == "int8": compute_dtype = "int32" # prevent overflow result_dtype = "int32" else: result_dtype = dtypes.result_type(*dtypes_to_resolve) compute_dtype = result_dtype # TODO: np.einsum doesn't support bfloat16 if compute_dtype == "bfloat16": compute_dtype = "float32" operands = tree.map_structure(lambda x: x.astype(compute_dtype), operands) return np.einsum(subscripts, *operands, **kwargs).astype(result_dtype) def subtract(x1, x2): if not isinstance(x1, (int, float)): x1 = convert_to_tensor(x1) if not isinstance(x2, (int, float)): x2 = convert_to_tensor(x2) dtype = dtypes.result_type( getattr(x1, "dtype", type(x1)), getattr(x2, "dtype", type(x2)), ) x1 = convert_to_tensor(x1, dtype) x2 = convert_to_tensor(x2, dtype) return np.subtract(x1, x2) def matmul(x1, x2): x1 = convert_to_tensor(x1) x2 = convert_to_tensor(x2) # When both x1 and x2 are of int8, we cast the outputs to int32 to align # with jax x1_dtype = standardize_dtype(x1.dtype) x2_dtype = standardize_dtype(x2.dtype) if x1_dtype == "int8" and x2_dtype == "int8": dtype = "int32" else: dtype = dtypes.result_type(x1.dtype, x2.dtype) x1 = x1.astype(dtype) x2 = x2.astype(dtype) return np.matmul(x1, x2).astype(dtype) def multiply(x1, x2): if not isinstance(x1, (int, float)): x1 = convert_to_tensor(x1) if not isinstance(x2, (int, float)): x2 = convert_to_tensor(x2) dtype = dtypes.result_type( getattr(x1, "dtype", type(x1)), getattr(x2, "dtype", type(x2)), ) x1 = convert_to_tensor(x1, dtype) x2 = convert_to_tensor(x2, dtype) return np.multiply(x1, x2) def mean(x, axis=None, keepdims=False): axis = standardize_axis_for_numpy(axis) x = convert_to_tensor(x) ori_dtype = standardize_dtype(x.dtype) if "int" in ori_dtype or ori_dtype == "bool": result_dtype = dtypes.result_type(x.dtype, "float32") else: result_dtype = ori_dtype return np.mean(x, axis=axis, keepdims=keepdims).astype(result_dtype) def max(x, axis=None, keepdims=False, initial=None): axis = standardize_axis_for_numpy(axis) return np.max(x, axis=axis, keepdims=keepdims, initial=initial) def ones(shape, dtype=None): dtype = dtype or config.floatx() return np.ones(shape, dtype=dtype) def zeros(shape, dtype=None): dtype = dtype or config.floatx() return np.zeros(shape, dtype=dtype) def absolute(x): return np.absolute(x) def abs(x): return absolute(x) def all(x, axis=None, keepdims=False): axis = standardize_axis_for_numpy(axis) return np.all(x, axis=axis, keepdims=keepdims) def angle(x): x = convert_to_tensor(x) if standardize_dtype(x.dtype) == "int64": dtype = config.floatx() else: dtype = dtypes.result_type(x.dtype, float) x = x.astype(dtype) return np.angle(x) def any(x, axis=None, keepdims=False): axis = standardize_axis_for_numpy(axis) return np.any(x, axis=axis, keepdims=keepdims) def amax(x, axis=None, keepdims=False): axis = standardize_axis_for_numpy(axis) return np.amax(x, axis=axis, keepdims=keepdims) def amin(x, axis=None, keepdims=False): axis = standardize_axis_for_numpy(axis) return np.amin(x, axis=axis, keepdims=keepdims) def append(x1, x2, axis=None): axis = standardize_axis_for_numpy(axis) x1 = convert_to_tensor(x1) x2 = convert_to_tensor(x2) dtype = dtypes.result_type(x1.dtype, x2.dtype) x1 = x1.astype(dtype) x2 = x2.astype(dtype) return np.append(x1, x2, axis=axis) def arange(start, stop=None, step=None, dtype=None): if dtype is None: dtypes_to_resolve = [ getattr(start, "dtype", type(start)), getattr(step, "dtype", type(step)), ] if stop is not None: dtypes_to_resolve.append(getattr(stop, "dtype", type(stop))) dtype = dtypes.result_type(*dtypes_to_resolve) return np.arange(start, stop, step=step, dtype=dtype) def arccos(x): x = convert_to_tensor(x) if standardize_dtype(x.dtype) == "int64": dtype = config.floatx() else: dtype = dtypes.result_type(x.dtype, float) x = x.astype(dtype) return np.arccos(x) def arccosh(x): x = convert_to_tensor(x) if standardize_dtype(x.dtype) == "int64": dtype = config.floatx() else: dtype = dtypes.result_type(x.dtype, float) x = x.astype(dtype) return np.arccosh(x) def arcsin(x): x = convert_to_tensor(x) if standardize_dtype(x.dtype) == "int64": dtype = config.floatx() else: dtype = dtypes.result_type(x.dtype, float) x = x.astype(dtype) return np.arcsin(x) def arcsinh(x): x = convert_to_tensor(x) if standardize_dtype(x.dtype) == "int64": dtype = config.floatx() else: dtype = dtypes.result_type(x.dtype, float) x = x.astype(dtype) return np.arcsinh(x) def arctan(x): x = convert_to_tensor(x) if standardize_dtype(x.dtype) == "int64": dtype = config.floatx() else: dtype = dtypes.result_type(x.dtype, float) x = x.astype(dtype) return np.arctan(x) def arctan2(x1, x2): x1 = convert_to_tensor(x1) x2 = convert_to_tensor(x2) dtype = dtypes.result_type(x1.dtype, x2.dtype, float) x1 = x1.astype(dtype) x2 = x2.astype(dtype) return np.arctan2(x1, x2) def arctanh(x): x = convert_to_tensor(x) if standardize_dtype(x.dtype) == "int64": dtype = config.floatx() else: dtype = dtypes.result_type(x.dtype, float) x = x.astype(dtype) return np.arctanh(x) def argmax(x, axis=None, keepdims=False): x = convert_to_tensor(x) axis = standardize_axis_for_numpy(axis) dtype = standardize_dtype(x.dtype) if "float" not in dtype or x.ndim == 0: return np.argmax(x, axis=axis, keepdims=keepdims).astype("int32") dtype = dtypes.result_type(dtype, "float32") x = x.astype(dtype) is_negative_zero = (x == 0.0) & np.signbit(x) x = np.where(is_negative_zero, -np.finfo(x.dtype).tiny, x) return np.argmax(x, axis=axis, keepdims=keepdims).astype("int32") def argmin(x, axis=None, keepdims=False): x = convert_to_tensor(x) axis = standardize_axis_for_numpy(axis) dtype = standardize_dtype(x.dtype) if "float" not in dtype or x.ndim == 0: return np.argmin(x, axis=axis, keepdims=keepdims).astype("int32") dtype = dtypes.result_type(dtype, "float32") x = x.astype(dtype) is_negative_zero = (x == 0.0) & np.signbit(x) x = np.where(is_negative_zero, -np.finfo(x.dtype).tiny, x) return np.argmin(x, axis=axis, keepdims=keepdims).astype("int32") def argsort(x, axis=-1): axis = standardize_axis_for_numpy(axis) return np.argsort(x, axis=axis).astype("int32") def array(x, dtype=None): return convert_to_tensor(x, dtype=dtype) def average(x, axis=None, weights=None): axis = standardize_axis_for_numpy(axis) x = convert_to_tensor(x) dtypes_to_resolve = [x.dtype, float] if weights is not None: weights = convert_to_tensor(weights) dtypes_to_resolve.append(weights.dtype) dtype = dtypes.result_type(*dtypes_to_resolve) x = x.astype(dtype) if weights is not None: weights = weights.astype(dtype) return np.average(x, weights=weights, axis=axis) def bartlett(x): x = convert_to_tensor(x) return np.bartlett(x).astype(config.floatx()) def hamming(x): x = convert_to_tensor(x) return np.hamming(x).astype(config.floatx()) def bincount(x, weights=None, minlength=0, sparse=False): if sparse: raise ValueError("Unsupported value `sparse=True` with numpy backend") x = convert_to_tensor(x) dtypes_to_resolve = [x.dtype] if weights is not None: weights = convert_to_tensor(weights) dtypes_to_resolve.append(weights.dtype) dtype = dtypes.result_type(*dtypes_to_resolve) else: dtype = "int32" if len(x.shape) == 2: if weights is None: def bincount_fn(arr): return np.bincount(arr, minlength=minlength) bincounts = list(map(bincount_fn, x)) else: def bincount_fn(arr_w): return np.bincount( arr_w[0], weights=arr_w[1], minlength=minlength ) bincounts = list(map(bincount_fn, zip(x, weights))) return np.stack(bincounts).astype(dtype) return np.bincount(x, weights, minlength).astype(dtype) def bitwise_and(x, y): x = convert_to_tensor(x) y = convert_to_tensor(y) return np.bitwise_and(x, y) def bitwise_invert(x): x = convert_to_tensor(x) return np.bitwise_not(x) def bitwise_not(x): return bitwise_invert(x) def bitwise_or(x, y): x = convert_to_tensor(x) y = convert_to_tensor(y) return np.bitwise_or(x, y) def bitwise_xor(x, y): x = convert_to_tensor(x) y = convert_to_tensor(y) return np.bitwise_xor(x, y) def bitwise_left_shift(x, y): x = convert_to_tensor(x) if not isinstance(y, int): y = convert_to_tensor(y) return np.left_shift(x, y) def left_shift(x, y): return bitwise_left_shift(x, y) def bitwise_right_shift(x, y): x = convert_to_tensor(x) if not isinstance(y, int): y = convert_to_tensor(y) return np.right_shift(x, y) def right_shift(x, y): return bitwise_right_shift(x, y) def blackman(x): x = convert_to_tensor(x) return np.blackman(x).astype(config.floatx()) def broadcast_to(x, shape): return np.broadcast_to(x, shape) def ceil(x): x = convert_to_tensor(x) if standardize_dtype(x.dtype) == "int64": dtype = config.floatx() else: dtype = dtypes.result_type(x.dtype, float) x = x.astype(dtype) return np.ceil(x) def clip(x, x_min, x_max): x = convert_to_tensor(x) dtype = standardize_dtype(x.dtype) if dtype == "bool": dtype = "int32" return np.clip(x, x_min, x_max).astype(dtype) def concatenate(xs, axis=0): axis = standardize_axis_for_numpy(axis) dtype_set = set([getattr(x, "dtype", type(x)) for x in xs]) if len(dtype_set) > 1: dtype = dtypes.result_type(*dtype_set) xs = tree.map_structure( lambda x: convert_to_tensor(x).astype(dtype), xs ) return np.concatenate(xs, axis=axis) def conjugate(x): return np.conjugate(x) def conj(x): return conjugate(x) def copy(x): return np.copy(x) def cos(x): x = convert_to_tensor(x) if standardize_dtype(x.dtype) == "int64": dtype = config.floatx() else: dtype = dtypes.result_type(x.dtype, float) x = x.astype(dtype) return np.cos(x) def cosh(x): x = convert_to_tensor(x) if standardize_dtype(x.dtype) == "int64": dtype = config.floatx() else: dtype = dtypes.result_type(x.dtype, float) x = x.astype(dtype) return np.cosh(x) def count_nonzero(x, axis=None): axis = standardize_axis_for_numpy(axis) # np.count_nonzero will return python int when axis=None, so we need # to convert_to_tensor return convert_to_tensor(np.count_nonzero(x, axis=axis)).astype("int32") def cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=None): axis = standardize_axis_for_numpy(axis) x1 = convert_to_tensor(x1) x2 = convert_to_tensor(x2) dtype = dtypes.result_type(x1.dtype, x2.dtype) x1 = x1.astype(dtype) x2 = x2.astype(dtype) return np.cross( x1, x2, axisa=axisa, axisb=axisb, axisc=axisc, axis=axis, ) def cumprod(x, axis=None, dtype=None): axis = standardize_axis_for_numpy(axis) dtype = dtypes.result_type(dtype or x.dtype) if dtype == "bool": dtype = "int32" return np.cumprod(x, axis=axis, dtype=dtype) def cumsum(x, axis=None, dtype=None): axis = standardize_axis_for_numpy(axis) dtype = dtypes.result_type(dtype or x.dtype) if dtype == "bool": dtype = "int32" return np.cumsum(x, axis=axis, dtype=dtype) def diag(x, k=0): return np.diag(x, k=k) def diagflat(x, k=0): return np.diagflat(x, k=k) def diagonal(x, offset=0, axis1=0, axis2=1): axis1 = standardize_axis_for_numpy(axis1) axis2 = standardize_axis_for_numpy(axis2) return np.diagonal(x, offset=offset, axis1=axis1, axis2=axis2) def diff(a, n=1, axis=-1): return np.diff(a, n=n, axis=axis) def digitize(x, bins): return np.digitize(x, bins).astype(np.int32) def dot(x, y): x = convert_to_tensor(x) y = convert_to_tensor(y) dtype = dtypes.result_type(x.dtype, y.dtype) x = x.astype(dtype) y = y.astype(dtype) return np.dot(x, y) def empty(shape, dtype=None): dtype = dtype or config.floatx() return np.empty(shape, dtype=dtype) def equal(x1, x2): return np.equal(x1, x2) def exp(x): x = convert_to_tensor(x) ori_dtype = standardize_dtype(x.dtype) if "int" in ori_dtype or ori_dtype == "bool": x = x.astype(config.floatx()) return np.exp(x) def exp2(x): x = convert_to_tensor(x) ori_dtype = standardize_dtype(x.dtype) if "int" in ori_dtype or ori_dtype == "bool": x = x.astype(config.floatx()) return np.exp2(x) def expand_dims(x, axis): axis = standardize_axis_for_numpy(axis) return np.expand_dims(x, axis) def expm1(x): x = convert_to_tensor(x) ori_dtype = standardize_dtype(x.dtype) if "int" in ori_dtype or ori_dtype == "bool": x = x.astype(config.floatx()) return np.expm1(x) def flip(x, axis=None): axis = standardize_axis_for_numpy(axis) return np.flip(x, axis=axis) def floor(x): x = convert_to_tensor(x) dtype = ( config.floatx() if standardize_dtype(x.dtype) == "int64" else dtypes.result_type(x.dtype, float) ) x = x.astype(dtype) return np.floor(x) def full(shape, fill_value, dtype=None): dtype = dtype or config.floatx() return np.full(shape, fill_value, dtype=dtype) def full_like(x, fill_value, dtype=None): return np.full_like(x, fill_value, dtype=dtype) def greater(x1, x2): return np.greater(x1, x2) def greater_equal(x1, x2): return np.greater_equal(x1, x2) def hstack(xs): dtype_set = set([getattr(x, "dtype", type(x)) for x in xs]) if len(dtype_set) > 1: dtype = dtypes.result_type(*dtype_set) xs = tree.map_structure( lambda x: convert_to_tensor(x).astype(dtype), xs ) return np.hstack(xs) def identity(n, dtype=None): dtype = dtype or config.floatx() return np.identity(n, dtype=dtype) def imag(x): return np.imag(x) def isclose(x1, x2, rtol=1e-5, atol=1e-8, equal_nan=False): return np.isclose(x1, x2, rtol, atol, equal_nan) def isfinite(x): return np.isfinite(x) def isinf(x): return np.isinf(x) def isnan(x): return np.isnan(x) def less(x1, x2): return np.less(x1, x2) def less_equal(x1, x2): return np.less_equal(x1, x2) def linspace( start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0 ): axis = standardize_axis_for_numpy(axis) if dtype is None: dtypes_to_resolve = [ getattr(start, "dtype", type(start)), getattr(stop, "dtype", type(stop)), float, ] dtype = dtypes.result_type(*dtypes_to_resolve) return np.linspace( start, stop, num=num, endpoint=endpoint, retstep=retstep, dtype=dtype, axis=axis, ) def log(x): x = convert_to_tensor(x) dtype = ( config.floatx() if standardize_dtype(x.dtype) == "int64" else dtypes.result_type(x.dtype, float) ) return np.log(x, dtype=dtype) def log10(x): x = convert_to_tensor(x) dtype = ( config.floatx() if standardize_dtype(x.dtype) == "int64" else dtypes.result_type(x.dtype, float) ) return np.log10(x, dtype=dtype) def log1p(x): x = convert_to_tensor(x) dtype = ( config.floatx() if standardize_dtype(x.dtype) == "int64" else dtypes.result_type(x.dtype, float) ) return np.log1p(x, dtype=dtype) def log2(x): x = convert_to_tensor(x) dtype = ( config.floatx() if standardize_dtype(x.dtype) == "int64" else dtypes.result_type(x.dtype, float) ) return np.log2(x, dtype=dtype) def logaddexp(x1, x2): x1 = convert_to_tensor(x1) x2 = convert_to_tensor(x2) dtype = dtypes.result_type(x1.dtype, x2.dtype, float) x1 = x1.astype(dtype) x2 = x2.astype(dtype) return np.logaddexp(x1, x2) def logical_and(x1, x2): return np.logical_and(x1, x2) def logical_not(x): return np.logical_not(x) def logical_or(x1, x2): return np.logical_or(x1, x2) def logspace(start, stop, num=50, endpoint=True, base=10, dtype=None, axis=0): if dtype is None: dtypes_to_resolve = [ getattr(start, "dtype", type(start)), getattr(stop, "dtype", type(stop)), float, ] dtype = dtypes.result_type(*dtypes_to_resolve) return np.logspace( start, stop, num=num, endpoint=endpoint, base=base, dtype=dtype, axis=axis, ) def maximum(x1, x2): if not isinstance(x1, (int, float)): x1 = convert_to_tensor(x1) if not isinstance(x2, (int, float)): x2 = convert_to_tensor(x2) dtype = dtypes.result_type( getattr(x1, "dtype", type(x1)), getattr(x2, "dtype", type(x2)), ) x1 = convert_to_tensor(x1, dtype) x2 = convert_to_tensor(x2, dtype) return np.maximum(x1, x2) def median(x, axis=None, keepdims=False): dtype = dtypes.result_type(x.dtype, float) return np.median(x, axis=axis, keepdims=keepdims).astype(dtype) def meshgrid(*x, indexing="xy"): return np.meshgrid(*x, indexing=indexing) def min(x, axis=None, keepdims=False, initial=None): axis = standardize_axis_for_numpy(axis) return np.min(x, axis=axis, keepdims=keepdims, initial=initial) def minimum(x1, x2): if not isinstance(x1, (int, float)): x1 = convert_to_tensor(x1) if not isinstance(x2, (int, float)): x2 = convert_to_tensor(x2) dtype = dtypes.result_type( getattr(x1, "dtype", type(x1)), getattr(x2, "dtype", type(x2)), ) x1 = convert_to_tensor(x1, dtype) x2 = convert_to_tensor(x2, dtype) return np.minimum(x1, x2) def mod(x1, x2): x1 = convert_to_tensor(x1) x2 = convert_to_tensor(x2) dtype = dtypes.result_type(x1.dtype, x2.dtype) if dtype == "bool": dtype = "int32" x1 = x1.astype(dtype) x2 = x2.astype(dtype) return np.mod(x1, x2) def moveaxis(x, source, destination): return np.moveaxis(x, source=source, destination=destination) def nan_to_num(x, nan=0.0, posinf=None, neginf=None): return np.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf) def ndim(x): return np.ndim(x) def nonzero(x): return tuple(indices.astype("int32") for indices in np.nonzero(x)) def not_equal(x1, x2): return np.not_equal(x1, x2) def zeros_like(x, dtype=None): return np.zeros_like(x, dtype=dtype) def ones_like(x, dtype=None): return np.ones_like(x, dtype=dtype) def outer(x1, x2): x1 = convert_to_tensor(x1) x2 = convert_to_tensor(x2) dtype = dtypes.result_type(x1.dtype, x2.dtype) x1 = x1.astype(dtype) x2 = x2.astype(dtype) return np.outer(x1, x2) def pad(x, pad_width, mode="constant", constant_values=None): kwargs = {} if constant_values is not None: if mode != "constant": raise ValueError( "Argument `constant_values` can only be " "provided when `mode == 'constant'`. " f"Received: mode={mode}" ) kwargs["constant_values"] = constant_values return np.pad(x, pad_width, mode=mode, **kwargs) def prod(x, axis=None, keepdims=False, dtype=None): axis = standardize_axis_for_numpy(axis) x = convert_to_tensor(x) if dtype is None: dtype = dtypes.result_type(x.dtype) if dtype in ("bool", "int8", "int16"): dtype = "int32" elif dtype in ("uint8", "uint16"): dtype = "uint32" return np.prod(x, axis=axis, keepdims=keepdims, dtype=dtype) def quantile(x, q, axis=None, method="linear", keepdims=False): axis = standardize_axis_for_numpy(axis) x = convert_to_tensor(x) ori_dtype = standardize_dtype(x.dtype) # np.quantile doesn't support bool if ori_dtype == "bool": x = x.astype(config.floatx()) if ori_dtype == "int64": dtype = config.floatx() else: dtype = dtypes.result_type(x.dtype, float) return np.quantile( x, q, axis=axis, method=method, keepdims=keepdims ).astype(dtype) def ravel(x): return np.ravel(x) def unravel_index(x, shape): dtype = dtypes.result_type(x.dtype) return tuple( indices.astype(dtype) for indices in np.unravel_index(x, shape) ) def real(x): return np.real(x) def reciprocal(x): return np.reciprocal(x) def repeat(x, repeats, axis=None): return np.repeat(x, repeats, axis=axis) def reshape(x, newshape): return np.reshape(x, newshape) def roll(x, shift, axis=None): return np.roll(x, shift, axis=axis) def searchsorted(sorted_sequence, values, side="left"): if ndim(sorted_sequence) != 1: raise ValueError( "`searchsorted` only supports 1-D sorted sequences. " "You can use `keras.ops.vectorized_map` " "to extend it to N-D sequences. Received: " f"sorted_sequence.shape={sorted_sequence.shape}" ) out_type = ( "int32" if len(sorted_sequence) <= np.iinfo(np.int32).max else "int64" ) return np.searchsorted(sorted_sequence, values, side=side).astype(out_type) def sign(x): return np.sign(x) def signbit(x): return np.signbit(x) def sin(x): x = convert_to_tensor(x) if standardize_dtype(x.dtype) == "int64": dtype = config.floatx() else: dtype = dtypes.result_type(x.dtype, float) x = x.astype(dtype) return np.sin(x) def sinh(x): x = convert_to_tensor(x) if standardize_dtype(x.dtype) == "int64": dtype = config.floatx() else: dtype = dtypes.result_type(x.dtype, float) x = x.astype(dtype) return np.sinh(x) def size(x): return np.size(x) def sort(x, axis=-1): axis = standardize_axis_for_numpy(axis) return np.sort(x, axis=axis) def split(x, indices_or_sections, axis=0): axis = standardize_axis_for_numpy(axis) return np.split(x, indices_or_sections, axis=axis) def stack(x, axis=0): axis = standardize_axis_for_numpy(axis) dtype_set = set([getattr(a, "dtype", type(a)) for a in x]) if len(dtype_set) > 1: dtype = dtypes.result_type(*dtype_set) x = tree.map_structure(lambda a: convert_to_tensor(a).astype(dtype), x) return np.stack(x, axis=axis) def std(x, axis=None, keepdims=False): axis = standardize_axis_for_numpy(axis) x = convert_to_tensor(x) ori_dtype = standardize_dtype(x.dtype) if "int" in ori_dtype or ori_dtype == "bool": x = x.astype(config.floatx()) return np.std(x, axis=axis, keepdims=keepdims) def swapaxes(x, axis1, axis2): return np.swapaxes(x, axis1=axis1, axis2=axis2) def take(x, indices, axis=None): axis = standardize_axis_for_numpy(axis) return np.take(x, indices, axis=axis) def take_along_axis(x, indices, axis=None): axis = standardize_axis_for_numpy(axis) return np.take_along_axis(x, indices, axis=axis) def tan(x): x = convert_to_tensor(x) if standardize_dtype(x.dtype) == "int64": dtype = config.floatx() else: dtype = dtypes.result_type(x.dtype, float) x = x.astype(dtype) return np.tan(x) def tanh(x): x = convert_to_tensor(x) if standardize_dtype(x.dtype) == "int64": dtype = config.floatx() else: dtype = dtypes.result_type(x.dtype, float) x = x.astype(dtype) return np.tanh(x) def tensordot(x1, x2, axes=2): axes = tuple(axes) if isinstance(axes, list) else axes x1 = convert_to_tensor(x1) x2 = convert_to_tensor(x2) dtype = dtypes.result_type(x1.dtype, x2.dtype) x1 = x1.astype(dtype) x2 = x2.astype(dtype) return np.tensordot(x1, x2, axes=axes) def round(x, decimals=0): return np.round(x, decimals=decimals) def tile(x, repeats): return np.tile(x, repeats) def trace(x, offset=0, axis1=0, axis2=1): axis1 = standardize_axis_for_numpy(axis1) axis2 = standardize_axis_for_numpy(axis2) x = convert_to_tensor(x) dtype = standardize_dtype(x.dtype) if dtype not in ("int64", "uint32", "uint64"): dtype = dtypes.result_type(dtype, "int32") return np.trace(x, offset=offset, axis1=axis1, axis2=axis2, dtype=dtype) def tri(N, M=None, k=0, dtype=None): dtype = dtype or config.floatx() return np.tri(N, M=M, k=k, dtype=dtype) def tril(x, k=0): return np.tril(x, k=k) def triu(x, k=0): return np.triu(x, k=k) def trunc(x): x = convert_to_tensor(x) dtype = standardize_dtype(x.dtype) if "int" in dtype or "bool" == dtype: return x return np.trunc(x) def vdot(x1, x2): x1 = convert_to_tensor(x1) x2 = convert_to_tensor(x2) dtype = dtypes.result_type(x1.dtype, x2.dtype) x1 = x1.astype(dtype) x2 = x2.astype(dtype) return np.vdot(x1, x2) def inner(x1, x2): x1 = convert_to_tensor(x1) x2 = convert_to_tensor(x2) dtype = dtypes.result_type(x1.dtype, x2.dtype) x1 = x1.astype(dtype) x2 = x2.astype(dtype) return np.inner(x1, x2) def vstack(xs): dtype_set = set([getattr(x, "dtype", type(x)) for x in xs]) if len(dtype_set) > 1: dtype = dtypes.result_type(*dtype_set) xs = tree.map_structure( lambda x: convert_to_tensor(x).astype(dtype), xs ) return np.vstack(xs) def vectorize(pyfunc, *, excluded=None, signature=None): return np.vectorize(pyfunc, excluded=excluded, signature=signature) def where(condition, x1, x2): if x1 is not None and x2 is not None: if not isinstance(x1, (int, float)): x1 = convert_to_tensor(x1) if not isinstance(x2, (int, float)): x2 = convert_to_tensor(x2) dtype = dtypes.result_type( getattr(x1, "dtype", type(x1)), getattr(x2, "dtype", type(x2)), ) x1 = convert_to_tensor(x1, dtype) x2 = convert_to_tensor(x2, dtype) return np.where(condition, x1, x2) else: return np.where(condition) def divide(x1, x2): if not isinstance(x1, (int, float)): x1 = convert_to_tensor(x1) if not isinstance(x2, (int, float)): x2 = convert_to_tensor(x2) dtype = dtypes.result_type( getattr(x1, "dtype", type(x1)), getattr(x2, "dtype", type(x2)), float, ) x1 = convert_to_tensor(x1, dtype) x2 = convert_to_tensor(x2, dtype) return np.divide(x1, x2) def divide_no_nan(x1, x2): if not isinstance(x1, (int, float)): x1 = convert_to_tensor(x1) if not isinstance(x2, (int, float)): x2 = convert_to_tensor(x2) dtype = dtypes.result_type( getattr(x1, "dtype", type(x1)), getattr(x2, "dtype", type(x2)), float, ) x1 = convert_to_tensor(x1, dtype) x2 = convert_to_tensor(x2, dtype) # No need for the double-where trick since we don't calculate gradients in # numpy backend. return np.where(x2 == 0, np.array(0, dtype=dtype), np.divide(x1, x2)) def true_divide(x1, x2): return divide(x1, x2) def power(x1, x2): if not isinstance(x1, (int, float)): x1 = convert_to_tensor(x1) if not isinstance(x2, (int, float)): x2 = convert_to_tensor(x2) dtype = dtypes.result_type( getattr(x1, "dtype", type(x1)), getattr(x2, "dtype", type(x2)), ) x1 = convert_to_tensor(x1, dtype) x2 = convert_to_tensor(x2, dtype) return np.power(x1, x2) def negative(x): return np.negative(x) def square(x): x = convert_to_tensor(x) if standardize_dtype(x.dtype) == "bool": x = x.astype("int32") return np.square(x) def sqrt(x): x = convert_to_tensor(x) # upcast to float64 for int64 which matches JAX's behavior dtype = ( config.floatx() if standardize_dtype(x.dtype) == "int64" else dtypes.result_type(x.dtype, float) ) return np.sqrt(x, dtype=dtype) def squeeze(x, axis=None): axis = standardize_axis_for_numpy(axis) return np.squeeze(x, axis=axis) def transpose(x, axes=None): axes = tuple(axes) if isinstance(axes, list) else axes return np.transpose(x, axes=axes) def var(x, axis=None, keepdims=False): axis = standardize_axis_for_numpy(axis) x = convert_to_tensor(x) compute_dtype = dtypes.result_type(x.dtype, "float32") result_dtype = dtypes.result_type(x.dtype, float) return np.var(x, axis=axis, keepdims=keepdims, dtype=compute_dtype).astype( result_dtype ) def sum(x, axis=None, keepdims=False): axis = standardize_axis_for_numpy(axis) dtype = standardize_dtype(x.dtype) # follow jax's rule if dtype in ("bool", "int8", "int16"): dtype = "int32" elif dtype in ("uint8", "uint16"): dtype = "uint32" return np.sum(x, axis=axis, keepdims=keepdims).astype(dtype) def eye(N, M=None, k=0, dtype=None): dtype = dtype or config.floatx() return np.eye(N, M=M, k=k, dtype=dtype) def floor_divide(x1, x2): if not isinstance(x1, (int, float)): x1 = convert_to_tensor(x1) if not isinstance(x2, (int, float)): x2 = convert_to_tensor(x2) dtype = dtypes.result_type( getattr(x1, "dtype", type(x1)), getattr(x2, "dtype", type(x2)) ) x1 = convert_to_tensor(x1, dtype) x2 = convert_to_tensor(x2, dtype) return np.floor_divide(x1, x2) def logical_xor(x1, x2): return np.logical_xor(x1, x2) def correlate(x1, x2, mode="valid"): dtype = dtypes.result_type( getattr(x1, "dtype", type(x1)), getattr(x2, "dtype", type(x2)), ) if dtype == "int64": dtype = "float64" elif dtype not in ["bfloat16", "float16", "float64"]: dtype = "float32" x1 = convert_to_tensor(x1, dtype) x2 = convert_to_tensor(x2, dtype) return np.correlate(x1, x2, mode) def select(condlist, choicelist, default=0): return np.select(condlist, choicelist, default=default) def slogdet(x): return tuple(np.linalg.slogdet(x)) def argpartition(x, kth, axis=-1): return np.argpartition(x, kth, axis).astype("int32") def histogram(x, bins, range): return np.histogram(x, bins=bins, range=range)