joebruce1313's picture
Upload 38004 files
1f5470c verified
import numpy as np
from keras.src.backend import standardize_dtype
from keras.src.backend.common import dtypes
from keras.src.backend.jax.math import fft as jax_fft
from keras.src.backend.jax.math import fft2 as jax_fft2
from keras.src.backend.numpy.core import convert_to_tensor
from keras.src.utils.module_utils import scipy
def _segment_reduction_fn(
    data, segment_ids, reduction_method, num_segments, sorted
):
    """Reduce entries of `data` that share a segment id with a NumPy ufunc.

    Args:
        data: Array whose first axis is selected by `segment_ids`.
        segment_ids: Array of segment ids. Entries with a negative id are
            skipped.
        reduction_method: ufunc whose `.at` method accumulates into the
            output (`np.add` and `np.maximum` are the ones used in this
            file).
        num_segments: Size of the first axis of the output. When `None`,
            `np.amax(segment_ids) + 1` is used.
        sorted: If truthy, ids and data are passed to the ufunc in the
            order given; otherwise they are first reordered by ascending
            segment id.

    Returns:
        Array shaped like the kept data, with its first dimension replaced
        by `num_segments`.
    """
    if num_segments is None:
        num_segments = np.amax(segment_ids) + 1

    # Negative ids mark entries that must not contribute to any segment.
    keep = segment_ids >= 0
    kept_data = data[keep]
    kept_ids = segment_ids[keep]

    # Same trailing dimensions as the data, one leading slot per segment.
    out_shape = [num_segments, *kept_data.shape[1:]]

    if reduction_method == np.maximum:
        # Segments start at -inf so that any accumulated value replaces it.
        result = np.ones(out_shape, dtype=kept_data.dtype) * -np.inf
    else:
        result = np.zeros(out_shape, dtype=kept_data.dtype)

    if not sorted:
        order = np.argsort(kept_ids)
        kept_ids = kept_ids[order]
        kept_data = kept_data[order]

    reduction_method.at(result, kept_ids, kept_data)
    return result
def segment_sum(data, segment_ids, num_segments=None, sorted=False):
    """Sum the entries of `data` that share a segment id.

    Thin wrapper that calls `_segment_reduction_fn` with `np.add` as the
    reduction; all other arguments are forwarded unchanged.
    """
    return _segment_reduction_fn(
        data,
        segment_ids,
        reduction_method=np.add,
        num_segments=num_segments,
        sorted=sorted,
    )
def segment_max(data, segment_ids, num_segments=None, sorted=False):
    """Take the maximum of the entries of `data` that share a segment id.

    Thin wrapper that calls `_segment_reduction_fn` with `np.maximum` as
    the reduction; all other arguments are forwarded unchanged.
    """
    return _segment_reduction_fn(
        data,
        segment_ids,
        reduction_method=np.maximum,
        num_segments=num_segments,
        sorted=sorted,
    )
def top_k(x, k, sorted=False):
    """Return the `k` largest entries along the last axis of `x`.

    Args:
        x: NumPy array to search along its last axis.
        k: Number of entries to keep per slice. `k == 0` yields results
            whose last axis has length 0.
        sorted: If truthy, the kept entries are ordered from largest to
            smallest. Otherwise they are the `k` largest entries in the
            order produced by `np.argpartition`.

    Returns:
        Tuple `(values, indices)`. `indices` index into the last axis of
        `x`, and `values` are the entries of `x` at those indices.
    """
    if sorted:
        # Full ascending sort, reversed, then keep the leading k columns.
        descending = np.argsort(x, axis=-1)[..., ::-1]
        top_k_indices = descending[..., :k]
    elif k == 0:
        # `[..., -k:]` would be `[..., 0:]` here and keep the whole axis, so
        # build the empty index array explicitly.
        top_k_indices = np.zeros((*x.shape[:-1], 0), dtype=np.intp)
    else:
        # Partition so that the k largest values sit in the last k slots.
        top_k_indices = np.argpartition(x, -k, axis=-1)[..., -k:]
    top_k_values = np.take_along_axis(x, top_k_indices, axis=-1)
    return top_k_values, top_k_indices
def in_top_k(targets, predictions, k):
    """Check, per row, whether the target's prediction ranks in the top `k`.

    Args:
        targets: 1-D array of indices into the last axis of `predictions`
            (it is given a trailing axis before the lookup).
        predictions: Array of scores.
        k: Number of top entries to compare against.

    Returns:
        Boolean array that is true where the prediction at the target index
        is greater than or equal to at least one of the values returned by
        `top_k(predictions, k)`.
    """
    target_column = targets[:, None]
    best_values, _ = top_k(predictions, k)
    target_scores = np.take_along_axis(predictions, target_column, axis=-1)
    return np.any(target_scores >= best_values, axis=-1)
def logsumexp(x, axis=None, keepdims=False):
    """Delegate to `scipy.special.logsumexp`.

    Args:
        x: Input values.
        axis: Forwarded as the `axis` argument.
        keepdims: Forwarded as the `keepdims` argument.

    Returns:
        Whatever `scipy.special.logsumexp` returns for these arguments.
    """
    reduce_kwargs = {"axis": axis, "keepdims": keepdims}
    return scipy.special.logsumexp(x, **reduce_kwargs)
def qr(x, mode="reduced"):
    """Compute a QR factorization with `np.linalg.qr`.

    Args:
        x: Matrix (or stack of matrices) to factorize.
        mode: Either `"reduced"` or `"complete"`; forwarded to
            `np.linalg.qr`.

    Returns:
        The result of `np.linalg.qr(x, mode=mode)`.

    Raises:
        ValueError: If `mode` is not one of the two supported values.
    """
    supported_modes = {"reduced", "complete"}
    if mode in supported_modes:
        return np.linalg.qr(x, mode=mode)
    raise ValueError(
        "`mode` argument value not supported. "
        "Expected one of {'reduced', 'complete'}. "
        f"Received: mode={mode}"
    )
def extract_sequences(x, sequence_length, sequence_stride):
    """Slice the last axis of `x` into (possibly overlapping) windows.

    Args:
        x: NumPy array with at least one dimension.
        sequence_length: Number of samples in each window.
        sequence_stride: Offset, in samples, between the starts of
            consecutive windows.

    Returns:
        Array of shape `(*x.shape[:-1], num_sequences, sequence_length)`.
        When the last axis of `x` is shorter than `sequence_length`,
        `num_sequences` is 0. The result is built with `as_strided`, so its
        windows alias the memory of `x` instead of copying it.
    """
    *_, num_samples = x.shape
    num_sequences = (
        num_samples - (sequence_length - sequence_stride)
    ) // sequence_stride
    # An input shorter than one window holds no complete sequence; without
    # the clamp the count goes negative and `as_strided` rejects the shape.
    num_sequences = max(num_sequences, 0)

    sample_step = x.strides[-1]
    shape = x.shape[:-1] + (num_sequences, sequence_length)
    strides = x.strides[:-1] + (sequence_stride * sample_step, sample_step)
    return np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides)
def _get_complex_tensor_from_tuple(x):
    """Combine a `(real, imag)` pair into a single complex array.

    Args:
        x: Tuple or list of exactly two arrays, the real and imaginary
            parts. Both must have the same shape and a floating dtype.

    Returns:
        `real + 1j * imag`.

    Raises:
        ValueError: If `x` is not a 2-element tuple/list, if the two shapes
            differ, or if either part is not of a floating dtype.
    """
    if not isinstance(x, (tuple, list)) or len(x) != 2:
        raise ValueError(
            "Input `x` should be a tuple of two tensors - real and "
            "imaginary. "
            f"Received: x={x}"
        )
    # The parts are used as given; no tensor conversion happens here, so
    # each must already expose `.shape` and `.dtype`.
    real, imag = x
    # Check shapes.
    if real.shape != imag.shape:
        raise ValueError(
            "Input `x` should be a tuple of two tensors - real and "
            "imaginary. "
            "Both the real and imaginary parts should have the same shape. "
            f"Received: x[0].shape = {real.shape}, x[1].shape = {imag.shape}"
        )
    # Ensure dtype is float.
    if not np.issubdtype(real.dtype, np.floating) or not np.issubdtype(
        imag.dtype, np.floating
    ):
        raise ValueError(
            "At least one tensor in input `x` is not of type float. "
            f"Received: x={x}."
        )
    return real + 1j * imag
def fft(x):
    """Run `jax_fft` on `x` and return its two outputs as NumPy arrays.

    Returns:
        Tuple `(real, imag)`, each wrapped with `np.array`.
    """
    outputs = jax_fft(x)
    real_part, imag_part = outputs
    return np.array(real_part), np.array(imag_part)
def fft2(x):
    """Run `jax_fft2` on `x` and return its two outputs as NumPy arrays.

    Returns:
        Tuple `(real, imag)`, each wrapped with `np.array`.
    """
    outputs = jax_fft2(x)
    real_part, imag_part = outputs
    return np.array(real_part), np.array(imag_part)
def ifft2(x):
    """Apply `np.fft.ifft2` to a `(real, imag)` pair.

    Args:
        x: Pair accepted by `_get_complex_tensor_from_tuple`.

    Returns:
        Tuple `(real, imag)` holding the real and imaginary parts of the
        inverse transform.
    """
    spectrum = np.fft.ifft2(_get_complex_tensor_from_tuple(x))
    return np.real(spectrum), np.imag(spectrum)
def rfft(x, fft_length=None):
    """Apply `np.fft.rfft` along the last axis of `x`.

    Args:
        x: Array to transform; its dtype is reused for both outputs.
        fft_length: Forwarded to `np.fft.rfft` as `n`.

    Returns:
        Tuple `(real, imag)` with both parts cast to `x.dtype`.
    """
    spectrum = np.fft.rfft(x, n=fft_length, axis=-1, norm="backward")
    # Recast so both parts carry the input dtype (the FFT result may be
    # wider than the input).
    out_dtype = x.dtype
    real_part = np.real(spectrum).astype(out_dtype)
    imag_part = np.imag(spectrum).astype(out_dtype)
    return real_part, imag_part
def irfft(x, fft_length=None):
    """Apply `np.fft.irfft` along the last axis of a `(real, imag)` pair.

    Args:
        x: Pair accepted by `_get_complex_tensor_from_tuple`.
        fft_length: Forwarded to `np.fft.irfft` as `n`.

    Returns:
        The inverse transform cast to the dtype of `x[0]`.
    """
    spectrum = _get_complex_tensor_from_tuple(x)
    signal = np.fft.irfft(spectrum, n=fft_length, axis=-1, norm="backward")
    # Recast so the output carries the dtype of the real part.
    return signal.astype(x[0].dtype)
def stft(
    x, sequence_length, sequence_stride, fft_length, window="hann", center=True
):
    """Short-time Fourier transform built on `scipy.signal.stft`.

    Args:
        x: Input signal; its dtype must standardize to `float32` or
            `float64`.
        sequence_length: Length of the window before zero-padding.
        sequence_stride: Hop between consecutive frames.
        fft_length: FFT size; must not be smaller than `sequence_length`.
        window: `"hann"`, `"hamming"`, an array of length
            `sequence_length`, or `None` for an all-ones window.
        center: If truthy, the last axis of `x` is reflect-padded by
            `fft_length // 2` on both sides before framing.

    Returns:
        Tuple `(real, imag)` cast to the dtype of the converted input, with
        the last two axes of the SciPy output swapped.

    Raises:
        TypeError: If the input dtype is not `float32`/`float64`.
        ValueError: If `fft_length < sequence_length`, if `window` is an
            unsupported string, or if the window is not 1-D of length
            `sequence_length`.
    """
    input_kind = standardize_dtype(x.dtype)
    if input_kind not in {"float32", "float64"}:
        raise TypeError(
            "Invalid input type. Expected `float32` or `float64`. "
            f"Received: input type={x.dtype}"
        )
    if fft_length < sequence_length:
        raise ValueError(
            "`fft_length` must equal or larger than `sequence_length`. "
            f"Received: sequence_length={sequence_length}, "
            f"fft_length={fft_length}"
        )
    if isinstance(window, str) and window not in {"hann", "hamming"}:
        raise ValueError(
            "If a string is passed to `window`, it must be one of "
            f'`"hann"`, `"hamming"`. Received: window={window}'
        )

    x = convert_to_tensor(x)
    ori_dtype = x.dtype

    if center:
        half = fft_length // 2
        pad_width = [(0, 0)] * len(x.shape)
        pad_width[-1] = (half, half)
        x = np.pad(x, pad_width, mode="reflect")

    # Zero-padding that widens the window from `sequence_length` samples to
    # the frame size handed to SciPy.
    l_pad = (fft_length - sequence_length) // 2
    r_pad = fft_length - sequence_length - l_pad
    frame_size = sequence_length + l_pad + r_pad

    if window is None:
        win = np.ones(frame_size, dtype=x.dtype)
    else:
        if isinstance(window, str):
            window_values = scipy.signal.get_window(window, sequence_length)
        else:
            window_values = window
        win = convert_to_tensor(window_values, dtype=x.dtype)
        if len(win.shape) != 1 or win.shape[-1] != sequence_length:
            raise ValueError(
                "The shape of `window` must be equal to [sequence_length]."
                f"Received: window shape={win.shape}"
            )
        win = np.pad(win, [[l_pad, r_pad]])

    spectrum = scipy.signal.stft(
        x,
        fs=1.0,
        window=win,
        nperseg=frame_size,
        noverlap=frame_size - sequence_stride,
        nfft=fft_length,
        boundary=None,
        padded=False,
    )[-1]

    # Rescale by the window sum, then swap to (..., num_sequences, fft_bins).
    spectrum = spectrum / np.sqrt(1.0 / win.sum() ** 2)
    spectrum = np.swapaxes(spectrum, -2, -1)
    real_part = np.real(spectrum).astype(ori_dtype)
    imag_part = np.imag(spectrum).astype(ori_dtype)
    return real_part, imag_part
def istft(
    x,
    sequence_length,
    sequence_stride,
    fft_length,
    length=None,
    window="hann",
    center=True,
):
    """Inverse short-time Fourier transform built on `scipy.signal.istft`.

    Args:
        x: Pair `(real, imag)` accepted by `_get_complex_tensor_from_tuple`.
            Axis -2 is passed to SciPy as the time axis and axis -1 as the
            frequency axis.
        sequence_length: Length of the window before zero-padding.
        sequence_stride: Hop between consecutive frames.
        fft_length: FFT size.
        length: If given, the number of output samples to keep after the
            start offset.
        window: `"hann"`, `"hamming"`, an array of length
            `sequence_length`, or `None` for an all-ones window.
        center: If truthy, `fft_length // 2` samples are dropped from both
            ends of the result (only from the start when `length` is
            given). Evaluated by truthiness, matching the `if center:`
            check in `stft`.

    Returns:
        The reconstructed signal, sliced along its last axis.

    Raises:
        ValueError: If the window is not 1-D of length `sequence_length`,
            or if `x` is rejected by `_get_complex_tensor_from_tuple`.
    """
    x = _get_complex_tensor_from_tuple(x)
    dtype = np.real(x).dtype

    expected_output_len = fft_length + sequence_stride * (x.shape[-2] - 1)
    # Zero-padding that widens the window from `sequence_length` samples to
    # the frame size handed to SciPy.
    l_pad = (fft_length - sequence_length) // 2
    r_pad = fft_length - sequence_length - l_pad
    frame_size = sequence_length + l_pad + r_pad

    if window is not None:
        if isinstance(window, str):
            win = convert_to_tensor(
                scipy.signal.get_window(window, sequence_length), dtype=dtype
            )
        else:
            win = convert_to_tensor(window, dtype=dtype)
        if len(win.shape) != 1 or win.shape[-1] != sequence_length:
            raise ValueError(
                "The shape of `window` must be equal to [sequence_length]."
                f"Received: window shape={win.shape}"
            )
        win = np.pad(win, [[l_pad, r_pad]])
    else:
        win = np.ones(frame_size, dtype=dtype)

    x = scipy.signal.istft(
        x,
        fs=1.0,
        window=win,
        nperseg=frame_size,
        noverlap=frame_size - sequence_stride,
        nfft=fft_length,
        boundary=False,
        time_axis=-2,
        freq_axis=-1,
    )[-1]

    # scale
    x = x / win.sum() if window is not None else x / sequence_stride

    # Use truthiness rather than `is True` / `is False`, so values such as
    # `np.True_` or `1` trim both ends consistently.
    trim = fft_length // 2 if center else 0
    start = trim
    if length is not None:
        end = start + length
    elif center:
        # `-0` as a slice end would discard everything, so only trim the
        # tail when there is something to trim.
        end = -trim if trim else None
    else:
        end = expected_output_len
    return x[..., start:end]
def rsqrt(x):
    """Return the reciprocal of the square root of `x`, element-wise."""
    root = np.sqrt(x)
    return 1.0 / root
def erf(x):
    """Evaluate `scipy.special.erf` on `x` and wrap the result in an array."""
    values = scipy.special.erf(x)
    return np.array(values)
def erfinv(x):
    """Evaluate `scipy.special.erfinv` on `x` and wrap the result in an array."""
    values = scipy.special.erfinv(x)
    return np.array(values)
def solve(a, b):
    """Solve the linear system `a @ x = b` with `np.linalg.solve`.

    Both arguments are passed through `convert_to_tensor` first.
    """
    coefficients, right_hand_side = convert_to_tensor(a), convert_to_tensor(b)
    return np.linalg.solve(coefficients, right_hand_side)
def norm(x, ord=None, axis=None, keepdims=False):
    """Compute a matrix or vector norm with `np.linalg.norm`.

    Args:
        x: Input, converted with `convert_to_tensor`.
        ord: Forwarded to `np.linalg.norm`.
        axis: Forwarded to `np.linalg.norm`.
        keepdims: Forwarded to `np.linalg.norm`.

    Returns:
        The norm cast to the standardized input dtype, or, when that dtype
        name contains `"int"` or equals `"bool"`, to
        `dtypes.result_type(x.dtype, "float32")`.
    """
    x = convert_to_tensor(x)
    out_dtype = standardize_dtype(x.dtype)
    is_integer_like = "int" in out_dtype
    if is_integer_like or out_dtype == "bool":
        out_dtype = dtypes.result_type(x.dtype, "float32")
    value = np.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims)
    return value.astype(out_dtype)
def logdet(x):
    """Return element 1 of the backend `slogdet(x)` result.

    NOTE(review): presumably that element is the log of the absolute
    determinant, as with `np.linalg.slogdet` -- confirm against the backend
    `slogdet`.
    """
    from keras.src.backend.numpy.numpy import slogdet

    # In NumPy slogdet is more stable than `np.log(np.linalg.det(x))`. See
    # https://numpy.org/doc/stable/reference/generated/numpy.linalg.slogdet.html
    sign_and_logabsdet = slogdet(x)
    return sign_and_logabsdet[1]