# NOTE: removed non-Python web-page residue that preceded the imports
# ("...'s picture / Upload 38004 files / ... verified"); it broke the module.
import numpy as np
from keras.src import tree
from keras.src.backend import config
from keras.src.backend import standardize_dtype
from keras.src.backend.common import dtypes
from keras.src.backend.common.backend_utils import standardize_axis_for_numpy
from keras.src.backend.numpy.core import convert_to_tensor
def rot90(array, k=1, axes=(0, 1)):
    """Rotate `array` by 90 degrees, `k` times, in the plane given by `axes`."""
    if array.ndim < 2:
        raise ValueError(
            "Input array must have at least 2 dimensions. "
            f"Received: array.ndim={array.ndim}"
        )
    if not (len(axes) == 2 and axes[0] != axes[1]):
        raise ValueError(
            f"Invalid axes: {axes}. Axes must be a tuple "
            "of two different dimensions."
        )
    return np.rot90(array, k=k, axes=axes)
def add(x1, x2):
    """Element-wise sum under Keras dtype promotion."""
    if not isinstance(x1, (int, float)):
        x1 = convert_to_tensor(x1)
    if not isinstance(x2, (int, float)):
        x2 = convert_to_tensor(x2)
    # Python scalars promote via their type, tensors via their dtype.
    dtype = dtypes.result_type(
        getattr(x1, "dtype", type(x1)), getattr(x2, "dtype", type(x2))
    )
    return np.add(convert_to_tensor(x1, dtype), convert_to_tensor(x2, dtype))
def einsum(subscripts, *operands, **kwargs):
    """Einstein-summation contraction of `operands` per `subscripts`.

    Dtype handling mirrors jax: all-int8 operands compute and return
    int32; bfloat16 computes in float32 (np.einsum lacks bfloat16
    support) and the result is cast back to bfloat16.
    """
    operands = tree.map_structure(convert_to_tensor, operands)
    dtypes_to_resolve = list(set(standardize_dtype(x.dtype) for x in operands))
    # When operands are of int8, we cast the result to int32 to align with
    # the behavior of jax.
    if len(dtypes_to_resolve) == 1 and dtypes_to_resolve[0] == "int8":
        compute_dtype = "int32"  # prevent overflow
        result_dtype = "int32"
    else:
        result_dtype = dtypes.result_type(*dtypes_to_resolve)
        compute_dtype = result_dtype
    # TODO: np.einsum doesn't support bfloat16
    if compute_dtype == "bfloat16":
        compute_dtype = "float32"
    operands = tree.map_structure(lambda x: x.astype(compute_dtype), operands)
    return np.einsum(subscripts, *operands, **kwargs).astype(result_dtype)
def subtract(x1, x2):
    """Element-wise difference under Keras dtype promotion."""
    if not isinstance(x1, (int, float)):
        x1 = convert_to_tensor(x1)
    if not isinstance(x2, (int, float)):
        x2 = convert_to_tensor(x2)
    dtype = dtypes.result_type(
        getattr(x1, "dtype", type(x1)), getattr(x2, "dtype", type(x2))
    )
    return np.subtract(
        convert_to_tensor(x1, dtype), convert_to_tensor(x2, dtype)
    )
def matmul(x1, x2):
    """Matrix product; int8 x int8 results are cast to int32 (jax rule)."""
    x1 = convert_to_tensor(x1)
    x2 = convert_to_tensor(x2)
    both_int8 = (
        standardize_dtype(x1.dtype) == "int8"
        and standardize_dtype(x2.dtype) == "int8"
    )
    dtype = "int32" if both_int8 else dtypes.result_type(x1.dtype, x2.dtype)
    return np.matmul(x1.astype(dtype), x2.astype(dtype)).astype(dtype)
def multiply(x1, x2):
    """Element-wise product under Keras dtype promotion."""
    if not isinstance(x1, (int, float)):
        x1 = convert_to_tensor(x1)
    if not isinstance(x2, (int, float)):
        x2 = convert_to_tensor(x2)
    dtype = dtypes.result_type(
        getattr(x1, "dtype", type(x1)), getattr(x2, "dtype", type(x2))
    )
    return np.multiply(
        convert_to_tensor(x1, dtype), convert_to_tensor(x2, dtype)
    )
def mean(x, axis=None, keepdims=False):
    """Arithmetic mean; int/bool inputs yield a floating result dtype."""
    axis = standardize_axis_for_numpy(axis)
    x = convert_to_tensor(x)
    kind = standardize_dtype(x.dtype)
    out_dtype = (
        dtypes.result_type(x.dtype, "float32")
        if "int" in kind or kind == "bool"
        else kind
    )
    return np.mean(x, axis=axis, keepdims=keepdims).astype(out_dtype)
def max(x, axis=None, keepdims=False, initial=None):
    """Maximum over `axis`, optionally seeded with `initial`."""
    return np.max(
        x, axis=standardize_axis_for_numpy(axis), keepdims=keepdims,
        initial=initial,
    )
def ones(shape, dtype=None):
    """All-ones tensor; `dtype` defaults to the configured floatx."""
    return np.ones(shape, dtype=dtype or config.floatx())
def zeros(shape, dtype=None):
    """All-zeros tensor; `dtype` defaults to the configured floatx."""
    return np.zeros(shape, dtype=dtype or config.floatx())
def absolute(x):
    """Element-wise absolute value."""
    return np.absolute(x)
def abs(x):
    """Element-wise absolute value (alias of `absolute`, inlined)."""
    return np.absolute(x)
def all(x, axis=None, keepdims=False):
    """True where every element along `axis` is truthy."""
    return np.all(x, axis=standardize_axis_for_numpy(axis), keepdims=keepdims)
def angle(x):
    """Angle of a complex (or real) tensor, computed in a floating dtype."""
    x = convert_to_tensor(x)
    target = (
        config.floatx()
        if standardize_dtype(x.dtype) == "int64"
        else dtypes.result_type(x.dtype, float)
    )
    return np.angle(x.astype(target))
def any(x, axis=None, keepdims=False):
    """True where at least one element along `axis` is truthy."""
    return np.any(x, axis=standardize_axis_for_numpy(axis), keepdims=keepdims)
def amax(x, axis=None, keepdims=False):
    """Maximum of a tensor along `axis`."""
    return np.amax(x, axis=standardize_axis_for_numpy(axis), keepdims=keepdims)
def amin(x, axis=None, keepdims=False):
    """Minimum of a tensor along `axis`."""
    return np.amin(x, axis=standardize_axis_for_numpy(axis), keepdims=keepdims)
def append(x1, x2, axis=None):
    """Concatenate `x2` onto `x1` along `axis` under a common dtype."""
    axis = standardize_axis_for_numpy(axis)
    x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
    common = dtypes.result_type(x1.dtype, x2.dtype)
    return np.append(x1.astype(common), x2.astype(common), axis=axis)
def arange(start, stop=None, step=None, dtype=None):
    """Evenly spaced values; dtype is inferred from the endpoints if absent."""
    if dtype is None:
        candidates = [
            getattr(start, "dtype", type(start)),
            getattr(step, "dtype", type(step)),
        ]
        if stop is not None:
            candidates.append(getattr(stop, "dtype", type(stop)))
        dtype = dtypes.result_type(*candidates)
    return np.arange(start, stop, step=step, dtype=dtype)
def arccos(x):
    """Element-wise inverse cosine in a floating dtype."""
    x = convert_to_tensor(x)
    target = (
        config.floatx()
        if standardize_dtype(x.dtype) == "int64"
        else dtypes.result_type(x.dtype, float)
    )
    return np.arccos(x.astype(target))
def arccosh(x):
    """Element-wise inverse hyperbolic cosine in a floating dtype."""
    x = convert_to_tensor(x)
    target = (
        config.floatx()
        if standardize_dtype(x.dtype) == "int64"
        else dtypes.result_type(x.dtype, float)
    )
    return np.arccosh(x.astype(target))
def arcsin(x):
    """Element-wise inverse sine in a floating dtype."""
    x = convert_to_tensor(x)
    target = (
        config.floatx()
        if standardize_dtype(x.dtype) == "int64"
        else dtypes.result_type(x.dtype, float)
    )
    return np.arcsin(x.astype(target))
def arcsinh(x):
    """Element-wise inverse hyperbolic sine in a floating dtype."""
    x = convert_to_tensor(x)
    target = (
        config.floatx()
        if standardize_dtype(x.dtype) == "int64"
        else dtypes.result_type(x.dtype, float)
    )
    return np.arcsinh(x.astype(target))
def arctan(x):
    """Element-wise inverse tangent in a floating dtype."""
    x = convert_to_tensor(x)
    target = (
        config.floatx()
        if standardize_dtype(x.dtype) == "int64"
        else dtypes.result_type(x.dtype, float)
    )
    return np.arctan(x.astype(target))
def arctan2(x1, x2):
    """Element-wise four-quadrant arctan of x1/x2 in a common float dtype."""
    x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
    common = dtypes.result_type(x1.dtype, x2.dtype, float)
    return np.arctan2(x1.astype(common), x2.astype(common))
def arctanh(x):
    """Element-wise inverse hyperbolic tangent in a floating dtype."""
    x = convert_to_tensor(x)
    target = (
        config.floatx()
        if standardize_dtype(x.dtype) == "int64"
        else dtypes.result_type(x.dtype, float)
    )
    return np.arctanh(x.astype(target))
def argmax(x, axis=None, keepdims=False):
    """Index of the maximum along `axis`, returned as int32.

    For float inputs, -0.0 is nudged to a tiny negative value first so
    that +0.0 wins ties against -0.0 (np.argmax alone would pick the
    first occurrence regardless of sign).
    """
    x = convert_to_tensor(x)
    axis = standardize_axis_for_numpy(axis)
    dtype = standardize_dtype(x.dtype)
    # Non-float data has no signed zero; scalars need no tie-breaking.
    if "float" not in dtype or x.ndim == 0:
        return np.argmax(x, axis=axis, keepdims=keepdims).astype("int32")
    dtype = dtypes.result_type(dtype, "float32")
    x = x.astype(dtype)
    # Detect exact -0.0 entries via the sign bit.
    is_negative_zero = (x == 0.0) & np.signbit(x)
    # Replace -0.0 with the smallest representable negative magnitude so
    # it loses to +0.0 in the comparison.
    x = np.where(is_negative_zero, -np.finfo(x.dtype).tiny, x)
    return np.argmax(x, axis=axis, keepdims=keepdims).astype("int32")
def argmin(x, axis=None, keepdims=False):
    """Index of the minimum along `axis`, returned as int32.

    For float inputs, -0.0 is nudged to a tiny negative value first so
    that -0.0 wins ties against +0.0 (mirrors the handling in `argmax`).
    """
    x = convert_to_tensor(x)
    axis = standardize_axis_for_numpy(axis)
    dtype = standardize_dtype(x.dtype)
    # Non-float data has no signed zero; scalars need no tie-breaking.
    if "float" not in dtype or x.ndim == 0:
        return np.argmin(x, axis=axis, keepdims=keepdims).astype("int32")
    dtype = dtypes.result_type(dtype, "float32")
    x = x.astype(dtype)
    # Detect exact -0.0 entries via the sign bit.
    is_negative_zero = (x == 0.0) & np.signbit(x)
    # Replace -0.0 with a genuinely negative value so it beats +0.0.
    x = np.where(is_negative_zero, -np.finfo(x.dtype).tiny, x)
    return np.argmin(x, axis=axis, keepdims=keepdims).astype("int32")
def argsort(x, axis=-1):
    """Indices that would sort `x` along `axis`, as int32."""
    return np.argsort(x, axis=standardize_axis_for_numpy(axis)).astype("int32")
def array(x, dtype=None):
    """Create a backend tensor from `x`, optionally casting to `dtype`."""
    return convert_to_tensor(x, dtype=dtype)
def average(x, axis=None, weights=None):
    """(Weighted) average of `x` along `axis` in a floating dtype."""
    axis = standardize_axis_for_numpy(axis)
    x = convert_to_tensor(x)
    if weights is None:
        common = dtypes.result_type(x.dtype, float)
        return np.average(x.astype(common), weights=None, axis=axis)
    weights = convert_to_tensor(weights)
    common = dtypes.result_type(x.dtype, float, weights.dtype)
    return np.average(
        x.astype(common), weights=weights.astype(common), axis=axis
    )
def bartlett(x):
    """Bartlett window of length `x`, cast to the configured floatx."""
    return np.bartlett(convert_to_tensor(x)).astype(config.floatx())
def hamming(x):
    """Hamming window of length `x`, cast to the configured floatx."""
    return np.hamming(convert_to_tensor(x)).astype(config.floatx())
def bincount(x, weights=None, minlength=0, sparse=False):
    """Count occurrences of each value in integer tensor `x`.

    Supports 1-D input (plain np.bincount) and 2-D input (a bincount
    per row, stacked). With `weights`, the result dtype follows dtype
    promotion of x and weights; otherwise it is int32. `sparse` is
    rejected: the numpy backend has no sparse tensors.
    """
    if sparse:
        raise ValueError("Unsupported value `sparse=True` with numpy backend")
    x = convert_to_tensor(x)
    dtypes_to_resolve = [x.dtype]
    if weights is not None:
        weights = convert_to_tensor(weights)
        dtypes_to_resolve.append(weights.dtype)
        dtype = dtypes.result_type(*dtypes_to_resolve)
    else:
        dtype = "int32"
    if len(x.shape) == 2:
        # Batched case: np.bincount only handles 1-D, so map over rows.
        if weights is None:
            def bincount_fn(arr):
                return np.bincount(arr, minlength=minlength)
            bincounts = list(map(bincount_fn, x))
        else:
            def bincount_fn(arr_w):
                return np.bincount(
                    arr_w[0], weights=arr_w[1], minlength=minlength
                )
            bincounts = list(map(bincount_fn, zip(x, weights)))
        # NOTE(review): np.stack assumes each row produced an
        # equal-length count vector; rows share max value via minlength.
        return np.stack(bincounts).astype(dtype)
    return np.bincount(x, weights, minlength).astype(dtype)
def bitwise_and(x, y):
    """Element-wise bitwise AND."""
    return np.bitwise_and(convert_to_tensor(x), convert_to_tensor(y))
def bitwise_invert(x):
    """Element-wise bitwise NOT."""
    return np.bitwise_not(convert_to_tensor(x))
def bitwise_not(x):
    """Alias of `bitwise_invert`."""
    return bitwise_invert(x)
def bitwise_or(x, y):
    """Element-wise bitwise OR."""
    return np.bitwise_or(convert_to_tensor(x), convert_to_tensor(y))
def bitwise_xor(x, y):
    """Element-wise bitwise XOR."""
    return np.bitwise_xor(convert_to_tensor(x), convert_to_tensor(y))
def bitwise_left_shift(x, y):
    """Element-wise left shift; a plain int `y` is passed through as-is."""
    x = convert_to_tensor(x)
    shift = y if isinstance(y, int) else convert_to_tensor(y)
    return np.left_shift(x, shift)
def left_shift(x, y):
    """Alias of `bitwise_left_shift`."""
    return bitwise_left_shift(x, y)
def bitwise_right_shift(x, y):
    """Element-wise right shift; a plain int `y` is passed through as-is."""
    x = convert_to_tensor(x)
    shift = y if isinstance(y, int) else convert_to_tensor(y)
    return np.right_shift(x, shift)
def right_shift(x, y):
    """Alias of `bitwise_right_shift`."""
    return bitwise_right_shift(x, y)
def blackman(x):
    """Blackman window of length `x`, cast to the configured floatx."""
    return np.blackman(convert_to_tensor(x)).astype(config.floatx())
def broadcast_to(x, shape):
    """Broadcast `x` to `shape` following numpy broadcasting rules."""
    return np.broadcast_to(x, shape)
def ceil(x):
    """Element-wise ceiling, computed in a floating dtype."""
    x = convert_to_tensor(x)
    target = (
        config.floatx()
        if standardize_dtype(x.dtype) == "int64"
        else dtypes.result_type(x.dtype, float)
    )
    return np.ceil(x.astype(target))
def clip(x, x_min, x_max):
    """Clamp `x` into [x_min, x_max]; bool inputs come back as int32."""
    x = convert_to_tensor(x)
    out_dtype = standardize_dtype(x.dtype)
    if out_dtype == "bool":
        out_dtype = "int32"
    return np.clip(x, x_min, x_max).astype(out_dtype)
def concatenate(xs, axis=0):
    """Join tensors along `axis`, promoting to a common dtype when mixed."""
    axis = standardize_axis_for_numpy(axis)
    distinct = {getattr(x, "dtype", type(x)) for x in xs}
    if len(distinct) > 1:
        common = dtypes.result_type(*distinct)
        xs = tree.map_structure(
            lambda x: convert_to_tensor(x).astype(common), xs
        )
    return np.concatenate(xs, axis=axis)
def conjugate(x):
    """Element-wise complex conjugate."""
    return np.conjugate(x)
def conj(x):
    """Element-wise complex conjugate (alias of `conjugate`, inlined)."""
    return np.conjugate(x)
def copy(x):
    """Return an independent copy of `x`."""
    return np.copy(x)
def cos(x):
    """Element-wise cosine in a floating dtype."""
    x = convert_to_tensor(x)
    target = (
        config.floatx()
        if standardize_dtype(x.dtype) == "int64"
        else dtypes.result_type(x.dtype, float)
    )
    return np.cos(x.astype(target))
def cosh(x):
    """Element-wise hyperbolic cosine in a floating dtype."""
    x = convert_to_tensor(x)
    target = (
        config.floatx()
        if standardize_dtype(x.dtype) == "int64"
        else dtypes.result_type(x.dtype, float)
    )
    return np.cosh(x.astype(target))
def count_nonzero(x, axis=None):
    """Number of non-zero entries along `axis`, as int32."""
    axis = standardize_axis_for_numpy(axis)
    # With axis=None numpy returns a Python int, so re-tensorize it.
    return convert_to_tensor(np.count_nonzero(x, axis=axis)).astype("int32")
def cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=None):
    """Cross product of vectors held in `x1` and `x2` (common dtype)."""
    axis = standardize_axis_for_numpy(axis)
    x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
    common = dtypes.result_type(x1.dtype, x2.dtype)
    return np.cross(
        x1.astype(common),
        x2.astype(common),
        axisa=axisa,
        axisb=axisb,
        axisc=axisc,
        axis=axis,
    )
def cumprod(x, axis=None, dtype=None):
    """Cumulative product along `axis`; bool accumulates as int32."""
    axis = standardize_axis_for_numpy(axis)
    target = dtypes.result_type(dtype or x.dtype)
    if target == "bool":
        target = "int32"
    return np.cumprod(x, axis=axis, dtype=target)
def cumsum(x, axis=None, dtype=None):
    """Cumulative sum along `axis`; bool accumulates as int32."""
    axis = standardize_axis_for_numpy(axis)
    target = dtypes.result_type(dtype or x.dtype)
    if target == "bool":
        target = "int32"
    return np.cumsum(x, axis=axis, dtype=target)
def diag(x, k=0):
    """Extract a diagonal or build a diagonal matrix, offset by `k`."""
    return np.diag(x, k=k)
def diagflat(x, k=0):
    """Create a 2-D tensor with flattened `x` on the `k`-th diagonal."""
    return np.diagflat(x, k=k)
def diagonal(x, offset=0, axis1=0, axis2=1):
    """View of the diagonal of `x` taken over (`axis1`, `axis2`)."""
    return np.diagonal(
        x,
        offset=offset,
        axis1=standardize_axis_for_numpy(axis1),
        axis2=standardize_axis_for_numpy(axis2),
    )
def diff(a, n=1, axis=-1):
    """`n`-th discrete difference along `axis`."""
    return np.diff(a, n=n, axis=axis)
def digitize(x, bins):
    """Bin index of each element of `x`, returned as int32."""
    return np.digitize(x, bins).astype(np.int32)
def dot(x, y):
    """Dot product after promoting both inputs to a common dtype."""
    x, y = convert_to_tensor(x), convert_to_tensor(y)
    common = dtypes.result_type(x.dtype, y.dtype)
    return np.dot(x.astype(common), y.astype(common))
def empty(shape, dtype=None):
    """Uninitialized tensor of `shape`; dtype defaults to floatx."""
    return np.empty(shape, dtype=dtype or config.floatx())
def equal(x1, x2):
    """Element-wise equality comparison."""
    return np.equal(x1, x2)
def exp(x):
    """Element-wise exponential; int/bool inputs compute in floatx."""
    x = convert_to_tensor(x)
    kind = standardize_dtype(x.dtype)
    if "int" in kind or kind == "bool":
        x = x.astype(config.floatx())
    return np.exp(x)
def exp2(x):
    """Element-wise 2**x; int/bool inputs compute in floatx."""
    x = convert_to_tensor(x)
    kind = standardize_dtype(x.dtype)
    if "int" in kind or kind == "bool":
        x = x.astype(config.floatx())
    return np.exp2(x)
def expand_dims(x, axis):
    """Insert a length-1 dimension at `axis`."""
    return np.expand_dims(x, standardize_axis_for_numpy(axis))
def expm1(x):
    """Element-wise exp(x) - 1; int/bool inputs compute in floatx."""
    x = convert_to_tensor(x)
    kind = standardize_dtype(x.dtype)
    if "int" in kind or kind == "bool":
        x = x.astype(config.floatx())
    return np.expm1(x)
def flip(x, axis=None):
    """Reverse element order along `axis` (all axes when None)."""
    return np.flip(x, axis=standardize_axis_for_numpy(axis))
def floor(x):
    """Element-wise floor, computed in a floating dtype."""
    x = convert_to_tensor(x)
    if standardize_dtype(x.dtype) == "int64":
        target = config.floatx()
    else:
        target = dtypes.result_type(x.dtype, float)
    return np.floor(x.astype(target))
def full(shape, fill_value, dtype=None):
    """Tensor of `shape` filled with `fill_value`; dtype defaults to floatx."""
    return np.full(shape, fill_value, dtype=dtype or config.floatx())
def full_like(x, fill_value, dtype=None):
    """Tensor shaped like `x`, filled with `fill_value`."""
    return np.full_like(x, fill_value, dtype=dtype)
def greater(x1, x2):
    """Element-wise `x1 > x2`."""
    return np.greater(x1, x2)
def greater_equal(x1, x2):
    """Element-wise `x1 >= x2`."""
    return np.greater_equal(x1, x2)
def hstack(xs):
    """Horizontally stack tensors, promoting mixed dtypes to a common one."""
    distinct = {getattr(x, "dtype", type(x)) for x in xs}
    if len(distinct) > 1:
        common = dtypes.result_type(*distinct)
        xs = tree.map_structure(
            lambda x: convert_to_tensor(x).astype(common), xs
        )
    return np.hstack(xs)
def identity(n, dtype=None):
    """n x n identity matrix; dtype defaults to floatx."""
    return np.identity(n, dtype=dtype or config.floatx())
def imag(x):
    """Imaginary part of `x`."""
    return np.imag(x)
def isclose(x1, x2, rtol=1e-5, atol=1e-8, equal_nan=False):
    """Element-wise closeness within relative/absolute tolerances."""
    return np.isclose(x1, x2, rtol, atol, equal_nan)
def isfinite(x):
    """True where `x` is neither infinite nor NaN."""
    return np.isfinite(x)
def isinf(x):
    """True where `x` is positive or negative infinity."""
    return np.isinf(x)
def isnan(x):
    """True where `x` is NaN."""
    return np.isnan(x)
def less(x1, x2):
    """Element-wise `x1 < x2`."""
    return np.less(x1, x2)
def less_equal(x1, x2):
    """Element-wise `x1 <= x2`."""
    return np.less_equal(x1, x2)
def linspace(
    start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0
):
    """`num` evenly spaced samples from `start` to `stop`."""
    axis = standardize_axis_for_numpy(axis)
    if dtype is None:
        # Infer a floating dtype from the endpoints.
        dtype = dtypes.result_type(
            getattr(start, "dtype", type(start)),
            getattr(stop, "dtype", type(stop)),
            float,
        )
    return np.linspace(
        start,
        stop,
        num=num,
        endpoint=endpoint,
        retstep=retstep,
        dtype=dtype,
        axis=axis,
    )
def log(x):
    """Natural logarithm computed in a floating dtype."""
    x = convert_to_tensor(x)
    if standardize_dtype(x.dtype) == "int64":
        target = config.floatx()
    else:
        target = dtypes.result_type(x.dtype, float)
    return np.log(x, dtype=target)
def log10(x):
    """Base-10 logarithm computed in a floating dtype."""
    x = convert_to_tensor(x)
    if standardize_dtype(x.dtype) == "int64":
        target = config.floatx()
    else:
        target = dtypes.result_type(x.dtype, float)
    return np.log10(x, dtype=target)
def log1p(x):
    """log(1 + x) computed in a floating dtype."""
    x = convert_to_tensor(x)
    if standardize_dtype(x.dtype) == "int64":
        target = config.floatx()
    else:
        target = dtypes.result_type(x.dtype, float)
    return np.log1p(x, dtype=target)
def log2(x):
    """Base-2 logarithm computed in a floating dtype."""
    x = convert_to_tensor(x)
    if standardize_dtype(x.dtype) == "int64":
        target = config.floatx()
    else:
        target = dtypes.result_type(x.dtype, float)
    return np.log2(x, dtype=target)
def logaddexp(x1, x2):
    """log(exp(x1) + exp(x2)) in a common floating dtype."""
    x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
    common = dtypes.result_type(x1.dtype, x2.dtype, float)
    return np.logaddexp(x1.astype(common), x2.astype(common))
def logical_and(x1, x2):
    """Element-wise logical AND."""
    return np.logical_and(x1, x2)
def logical_not(x):
    """Element-wise logical NOT."""
    return np.logical_not(x)
def logical_or(x1, x2):
    """Element-wise logical OR."""
    return np.logical_or(x1, x2)
def logspace(start, stop, num=50, endpoint=True, base=10, dtype=None, axis=0):
    """`num` samples spaced evenly on a log scale from base**start."""
    if dtype is None:
        # Infer a floating dtype from the endpoints.
        dtype = dtypes.result_type(
            getattr(start, "dtype", type(start)),
            getattr(stop, "dtype", type(stop)),
            float,
        )
    return np.logspace(
        start,
        stop,
        num=num,
        endpoint=endpoint,
        base=base,
        dtype=dtype,
        axis=axis,
    )
def maximum(x1, x2):
    """Element-wise maximum under Keras dtype promotion."""
    if not isinstance(x1, (int, float)):
        x1 = convert_to_tensor(x1)
    if not isinstance(x2, (int, float)):
        x2 = convert_to_tensor(x2)
    dtype = dtypes.result_type(
        getattr(x1, "dtype", type(x1)), getattr(x2, "dtype", type(x2))
    )
    return np.maximum(
        convert_to_tensor(x1, dtype), convert_to_tensor(x2, dtype)
    )
def median(x, axis=None, keepdims=False):
    """Median along `axis`, always returned in a floating dtype.

    Fix: the input was previously used without `convert_to_tensor`, so a
    plain Python list (which has no `.dtype`) raised AttributeError, and
    `axis` skipped `standardize_axis_for_numpy` — inconsistent with the
    sibling reductions (`mean`, `std`, `var`, `quantile`).
    """
    axis = standardize_axis_for_numpy(axis)
    x = convert_to_tensor(x)
    dtype = dtypes.result_type(x.dtype, float)
    return np.median(x, axis=axis, keepdims=keepdims).astype(dtype)
def meshgrid(*x, indexing="xy"):
    """Coordinate matrices from coordinate vectors."""
    return np.meshgrid(*x, indexing=indexing)
def min(x, axis=None, keepdims=False, initial=None):
    """Minimum over `axis`, optionally seeded with `initial`."""
    return np.min(
        x, axis=standardize_axis_for_numpy(axis), keepdims=keepdims,
        initial=initial,
    )
def minimum(x1, x2):
    """Element-wise minimum under Keras dtype promotion."""
    if not isinstance(x1, (int, float)):
        x1 = convert_to_tensor(x1)
    if not isinstance(x2, (int, float)):
        x2 = convert_to_tensor(x2)
    dtype = dtypes.result_type(
        getattr(x1, "dtype", type(x1)), getattr(x2, "dtype", type(x2))
    )
    return np.minimum(
        convert_to_tensor(x1, dtype), convert_to_tensor(x2, dtype)
    )
def mod(x1, x2):
    """Element-wise remainder; bool operands are promoted to int32."""
    x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
    common = dtypes.result_type(x1.dtype, x2.dtype)
    if common == "bool":
        common = "int32"
    return np.mod(x1.astype(common), x2.astype(common))
def moveaxis(x, source, destination):
    """Move axes of `x` from `source` positions to `destination`."""
    return np.moveaxis(x, source=source, destination=destination)
def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
    """Replace NaN/inf entries with finite substitutes."""
    return np.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf)
def ndim(x):
    """Number of dimensions of `x`."""
    return np.ndim(x)
def nonzero(x):
    """Indices of non-zero elements, one int32 array per dimension."""
    return tuple(idx.astype("int32") for idx in np.nonzero(x))
def not_equal(x1, x2):
    """Element-wise `x1 != x2`."""
    return np.not_equal(x1, x2)
def zeros_like(x, dtype=None):
    """All-zeros tensor with the shape (and by default dtype) of `x`."""
    return np.zeros_like(x, dtype=dtype)
def ones_like(x, dtype=None):
    """All-ones tensor with the shape (and by default dtype) of `x`."""
    return np.ones_like(x, dtype=dtype)
def outer(x1, x2):
    """Outer product after promoting both inputs to a common dtype."""
    x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
    common = dtypes.result_type(x1.dtype, x2.dtype)
    return np.outer(x1.astype(common), x2.astype(common))
def pad(x, pad_width, mode="constant", constant_values=None):
    """Pad `x` by `pad_width`; `constant_values` is constant-mode only."""
    kwargs = {}
    if constant_values is not None:
        if mode != "constant":
            raise ValueError(
                "Argument `constant_values` can only be "
                "provided when `mode == 'constant'`. "
                f"Received: mode={mode}"
            )
        kwargs["constant_values"] = constant_values
    return np.pad(x, pad_width, mode=mode, **kwargs)
def prod(x, axis=None, keepdims=False, dtype=None):
    """Product reduction; small ints accumulate as 32-bit (jax rule)."""
    axis = standardize_axis_for_numpy(axis)
    x = convert_to_tensor(x)
    if dtype is None:
        dtype = dtypes.result_type(x.dtype)
        if dtype in ("bool", "int8", "int16"):
            dtype = "int32"
        elif dtype in ("uint8", "uint16"):
            dtype = "uint32"
    return np.prod(x, axis=axis, keepdims=keepdims, dtype=dtype)
def quantile(x, q, axis=None, method="linear", keepdims=False):
    """q-th quantile of `x` along `axis`, returned in a floating dtype."""
    axis = standardize_axis_for_numpy(axis)
    x = convert_to_tensor(x)
    ori_dtype = standardize_dtype(x.dtype)
    # np.quantile doesn't support bool
    if ori_dtype == "bool":
        x = x.astype(config.floatx())
    # int64 maps to the configured floatx rather than float64.
    if ori_dtype == "int64":
        dtype = config.floatx()
    else:
        dtype = dtypes.result_type(x.dtype, float)
    return np.quantile(
        x, q, axis=axis, method=method, keepdims=keepdims
    ).astype(dtype)
def ravel(x):
    """Flatten `x` to one dimension."""
    return np.ravel(x)
def unravel_index(x, shape):
    """Convert flat indices into coordinate arrays for `shape`."""
    out_dtype = dtypes.result_type(x.dtype)
    return tuple(
        coords.astype(out_dtype) for coords in np.unravel_index(x, shape)
    )
def real(x):
    """Real part of `x`."""
    return np.real(x)
def reciprocal(x):
    """Element-wise 1/x."""
    return np.reciprocal(x)
def repeat(x, repeats, axis=None):
    """Repeat each element of `x` `repeats` times along `axis`."""
    return np.repeat(x, repeats, axis=axis)
def reshape(x, newshape):
    """Reshape `x` to `newshape` without changing its data."""
    return np.reshape(x, newshape)
def roll(x, shift, axis=None):
    """Circularly shift elements of `x` by `shift` along `axis`."""
    return np.roll(x, shift, axis=axis)
def searchsorted(sorted_sequence, values, side="left"):
    """Insertion indices for `values` into a 1-D `sorted_sequence`."""
    if np.ndim(sorted_sequence) != 1:
        raise ValueError(
            "`searchsorted` only supports 1-D sorted sequences. "
            "You can use `keras.ops.vectorized_map` "
            "to extend it to N-D sequences. Received: "
            f"sorted_sequence.shape={sorted_sequence.shape}"
        )
    # int32 indices unless the sequence is too long to address with them.
    if len(sorted_sequence) <= np.iinfo(np.int32).max:
        out_type = "int32"
    else:
        out_type = "int64"
    return np.searchsorted(sorted_sequence, values, side=side).astype(out_type)
def sign(x):
    """Element-wise sign indicator."""
    return np.sign(x)
def signbit(x):
    """True where the sign bit is set (negative values and -0.0)."""
    return np.signbit(x)
def sin(x):
    """Element-wise sine in a floating dtype."""
    x = convert_to_tensor(x)
    target = (
        config.floatx()
        if standardize_dtype(x.dtype) == "int64"
        else dtypes.result_type(x.dtype, float)
    )
    return np.sin(x.astype(target))
def sinh(x):
    """Element-wise hyperbolic sine in a floating dtype."""
    x = convert_to_tensor(x)
    target = (
        config.floatx()
        if standardize_dtype(x.dtype) == "int64"
        else dtypes.result_type(x.dtype, float)
    )
    return np.sinh(x.astype(target))
def size(x):
    """Total number of elements in `x`."""
    return np.size(x)
def sort(x, axis=-1):
    """Sorted copy of `x` along `axis`."""
    return np.sort(x, axis=standardize_axis_for_numpy(axis))
def split(x, indices_or_sections, axis=0):
    """Split `x` into sub-tensors along `axis`."""
    return np.split(x, indices_or_sections, axis=standardize_axis_for_numpy(axis))
def stack(x, axis=0):
    """Stack tensors on a new `axis`, promoting mixed dtypes."""
    axis = standardize_axis_for_numpy(axis)
    distinct = {getattr(a, "dtype", type(a)) for a in x}
    if len(distinct) > 1:
        common = dtypes.result_type(*distinct)
        x = tree.map_structure(
            lambda a: convert_to_tensor(a).astype(common), x
        )
    return np.stack(x, axis=axis)
def std(x, axis=None, keepdims=False):
    """Standard deviation; int/bool inputs compute in floatx."""
    axis = standardize_axis_for_numpy(axis)
    x = convert_to_tensor(x)
    kind = standardize_dtype(x.dtype)
    if "int" in kind or kind == "bool":
        x = x.astype(config.floatx())
    return np.std(x, axis=axis, keepdims=keepdims)
def swapaxes(x, axis1, axis2):
    """Interchange two axes of `x`."""
    return np.swapaxes(x, axis1=axis1, axis2=axis2)
def take(x, indices, axis=None):
    """Gather elements of `x` at `indices` along `axis`."""
    return np.take(x, indices, axis=standardize_axis_for_numpy(axis))
def take_along_axis(x, indices, axis=None):
    """Gather values from `x` using matching-shape `indices` along `axis`."""
    return np.take_along_axis(x, indices, axis=standardize_axis_for_numpy(axis))
def tan(x):
    """Element-wise tangent in a floating dtype."""
    x = convert_to_tensor(x)
    target = (
        config.floatx()
        if standardize_dtype(x.dtype) == "int64"
        else dtypes.result_type(x.dtype, float)
    )
    return np.tan(x.astype(target))
def tanh(x):
    """Element-wise hyperbolic tangent in a floating dtype."""
    x = convert_to_tensor(x)
    target = (
        config.floatx()
        if standardize_dtype(x.dtype) == "int64"
        else dtypes.result_type(x.dtype, float)
    )
    return np.tanh(x.astype(target))
def tensordot(x1, x2, axes=2):
    """Tensor contraction over `axes` in a common dtype."""
    if isinstance(axes, list):
        axes = tuple(axes)
    x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
    common = dtypes.result_type(x1.dtype, x2.dtype)
    return np.tensordot(x1.astype(common), x2.astype(common), axes=axes)
def round(x, decimals=0):
    """Round to `decimals` places (banker's rounding, as in numpy)."""
    return np.round(x, decimals=decimals)
def tile(x, repeats):
    """Tile `x` `repeats` times per dimension."""
    return np.tile(x, repeats)
def trace(x, offset=0, axis1=0, axis2=1):
    """Sum along the `offset` diagonal over (`axis1`, `axis2`).

    Small integer/bool dtypes accumulate as at least int32; int64 and
    the 32/64-bit unsigned dtypes keep their own accumulator width.
    """
    axis1 = standardize_axis_for_numpy(axis1)
    axis2 = standardize_axis_for_numpy(axis2)
    x = convert_to_tensor(x)
    dtype = standardize_dtype(x.dtype)
    if dtype not in ("int64", "uint32", "uint64"):
        dtype = dtypes.result_type(dtype, "int32")
    return np.trace(x, offset=offset, axis1=axis1, axis2=axis2, dtype=dtype)
def tri(N, M=None, k=0, dtype=None):
    """Lower-triangular matrix of ones; dtype defaults to floatx."""
    return np.tri(N, M=M, k=k, dtype=dtype or config.floatx())
def tril(x, k=0):
    """Zero out entries above the `k`-th diagonal."""
    return np.tril(x, k=k)
def triu(x, k=0):
    """Zero out entries below the `k`-th diagonal."""
    return np.triu(x, k=k)
def trunc(x):
    """Truncate toward zero; int/bool inputs are returned unchanged."""
    x = convert_to_tensor(x)
    kind = standardize_dtype(x.dtype)
    if "int" in kind or kind == "bool":
        return x
    return np.trunc(x)
def vdot(x1, x2):
    """Flattened dot product in a common dtype."""
    x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
    common = dtypes.result_type(x1.dtype, x2.dtype)
    return np.vdot(x1.astype(common), x2.astype(common))
def inner(x1, x2):
    """Inner product over the last axes in a common dtype."""
    x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
    common = dtypes.result_type(x1.dtype, x2.dtype)
    return np.inner(x1.astype(common), x2.astype(common))
def vstack(xs):
    """Vertically stack tensors, promoting mixed dtypes to a common one."""
    distinct = {getattr(x, "dtype", type(x)) for x in xs}
    if len(distinct) > 1:
        common = dtypes.result_type(*distinct)
        xs = tree.map_structure(
            lambda x: convert_to_tensor(x).astype(common), xs
        )
    return np.vstack(xs)
def vectorize(pyfunc, *, excluded=None, signature=None):
    """Wrap `pyfunc` so it maps element-wise over tensor inputs."""
    return np.vectorize(pyfunc, excluded=excluded, signature=signature)
def where(condition, x1, x2):
    """Select from `x1`/`x2` by `condition`; with either None, fall back
    to `np.where(condition)` (index form)."""
    if x1 is None or x2 is None:
        return np.where(condition)
    if not isinstance(x1, (int, float)):
        x1 = convert_to_tensor(x1)
    if not isinstance(x2, (int, float)):
        x2 = convert_to_tensor(x2)
    dtype = dtypes.result_type(
        getattr(x1, "dtype", type(x1)), getattr(x2, "dtype", type(x2))
    )
    return np.where(
        condition, convert_to_tensor(x1, dtype), convert_to_tensor(x2, dtype)
    )
def divide(x1, x2):
    """True division in a common floating dtype."""
    if not isinstance(x1, (int, float)):
        x1 = convert_to_tensor(x1)
    if not isinstance(x2, (int, float)):
        x2 = convert_to_tensor(x2)
    dtype = dtypes.result_type(
        getattr(x1, "dtype", type(x1)),
        getattr(x2, "dtype", type(x2)),
        float,
    )
    return np.divide(
        convert_to_tensor(x1, dtype), convert_to_tensor(x2, dtype)
    )
def divide_no_nan(x1, x2):
    """Element-wise division that yields 0 wherever the divisor is 0."""
    if not isinstance(x1, (int, float)):
        x1 = convert_to_tensor(x1)
    if not isinstance(x2, (int, float)):
        x2 = convert_to_tensor(x2)
    # Promote to a common floating dtype before dividing.
    dtype = dtypes.result_type(
        getattr(x1, "dtype", type(x1)),
        getattr(x2, "dtype", type(x2)),
        float,
    )
    x1 = convert_to_tensor(x1, dtype)
    x2 = convert_to_tensor(x2, dtype)
    # No need for the double-where trick since we don't calculate gradients in
    # numpy backend.
    return np.where(x2 == 0, np.array(0, dtype=dtype), np.divide(x1, x2))
def true_divide(x1, x2):
    """Alias of `divide`."""
    return divide(x1, x2)
def power(x1, x2):
    """Element-wise `x1 ** x2` under Keras dtype promotion."""
    if not isinstance(x1, (int, float)):
        x1 = convert_to_tensor(x1)
    if not isinstance(x2, (int, float)):
        x2 = convert_to_tensor(x2)
    dtype = dtypes.result_type(
        getattr(x1, "dtype", type(x1)), getattr(x2, "dtype", type(x2))
    )
    return np.power(
        convert_to_tensor(x1, dtype), convert_to_tensor(x2, dtype)
    )
def negative(x):
    """Element-wise negation."""
    return np.negative(x)
def square(x):
    """Element-wise square; bool inputs are promoted to int32 first."""
    x = convert_to_tensor(x)
    if standardize_dtype(x.dtype) == "bool":
        x = x.astype("int32")
    return np.square(x)
def sqrt(x):
    """Element-wise square root in a floating dtype."""
    x = convert_to_tensor(x)
    # int64 maps to the configured floatx, matching JAX's behavior.
    if standardize_dtype(x.dtype) == "int64":
        target = config.floatx()
    else:
        target = dtypes.result_type(x.dtype, float)
    return np.sqrt(x, dtype=target)
def squeeze(x, axis=None):
    """Remove length-1 dimensions (all of them when `axis` is None)."""
    return np.squeeze(x, axis=standardize_axis_for_numpy(axis))
def transpose(x, axes=None):
    """Permute dimensions of `x` (full reversal when `axes` is None)."""
    if isinstance(axes, list):
        axes = tuple(axes)
    return np.transpose(x, axes=axes)
def var(x, axis=None, keepdims=False):
    """Variance along `axis`.

    Computation is forced up to at least float32 for accuracy, while the
    returned dtype follows standard float promotion of the input dtype.
    """
    axis = standardize_axis_for_numpy(axis)
    x = convert_to_tensor(x)
    compute_dtype = dtypes.result_type(x.dtype, "float32")
    result_dtype = dtypes.result_type(x.dtype, float)
    return np.var(x, axis=axis, keepdims=keepdims, dtype=compute_dtype).astype(
        result_dtype
    )
def sum(x, axis=None, keepdims=False):
    """Sum reduction; small int/bool dtypes accumulate as 32-bit (jax rule).

    Fix: `x.dtype` was previously read without converting `x`, so plain
    Python lists (which have no `.dtype`) raised AttributeError — the
    sibling reduction `prod` already calls `convert_to_tensor`.
    """
    axis = standardize_axis_for_numpy(axis)
    x = convert_to_tensor(x)
    dtype = standardize_dtype(x.dtype)
    # follow jax's rule
    if dtype in ("bool", "int8", "int16"):
        dtype = "int32"
    elif dtype in ("uint8", "uint16"):
        dtype = "uint32"
    return np.sum(x, axis=axis, keepdims=keepdims).astype(dtype)
def eye(N, M=None, k=0, dtype=None):
    """2-D tensor with ones on the `k`-th diagonal; dtype defaults to floatx."""
    return np.eye(N, M=M, k=k, dtype=dtype or config.floatx())
def floor_divide(x1, x2):
    """Element-wise floored division under Keras dtype promotion."""
    if not isinstance(x1, (int, float)):
        x1 = convert_to_tensor(x1)
    if not isinstance(x2, (int, float)):
        x2 = convert_to_tensor(x2)
    dtype = dtypes.result_type(
        getattr(x1, "dtype", type(x1)), getattr(x2, "dtype", type(x2))
    )
    return np.floor_divide(
        convert_to_tensor(x1, dtype), convert_to_tensor(x2, dtype)
    )
def logical_xor(x1, x2):
    """Element-wise logical XOR."""
    return np.logical_xor(x1, x2)
def correlate(x1, x2, mode="valid"):
    """Cross-correlation of two 1-D tensors.

    Dtype policy: int64 promotes to float64; everything other than
    bfloat16/float16/float64 computes in float32.
    """
    dtype = dtypes.result_type(
        getattr(x1, "dtype", type(x1)),
        getattr(x2, "dtype", type(x2)),
    )
    if dtype == "int64":
        dtype = "float64"
    elif dtype not in ["bfloat16", "float16", "float64"]:
        dtype = "float32"
    x1 = convert_to_tensor(x1, dtype)
    x2 = convert_to_tensor(x2, dtype)
    return np.correlate(x1, x2, mode)
def select(condlist, choicelist, default=0):
    """Pick elements from `choicelist` by the first matching condition."""
    return np.select(condlist, choicelist, default=default)
def slogdet(x):
    """Sign and log-absolute-determinant of `x`, as a tuple."""
    return tuple(np.linalg.slogdet(x))
def argpartition(x, kth, axis=-1):
    """Indices of a partial sort placing the `kth` element in place, int32."""
    return np.argpartition(x, kth, axis).astype("int32")
def histogram(x, bins, range):
    """Histogram counts and bin edges of `x`.

    Note: `range` intentionally shadows the builtin to match the
    public Keras op signature.
    """
    return np.histogram(x, bins=bins, range=range)