Hajorda's picture
Upload folder using huggingface_hub
1a0d68d verified
from typing import Union
import numpy as np
from autograd.builtins import SequenceBox
from autograd.extend import Box, primitive
from . import numpy_wrapper as anp
# Give every Box a raised NumPy ``__array_priority__``. NumPy consults this
# attribute when deciding which operand's type takes precedence in mixed
# ndarray/Box operations. ArrayBox overrides it with an even higher value.
setattr(Box, "__array_priority__", 90.0)
class ArrayBox(Box):
    """Box that wraps a NumPy array or floating-point scalar during tracing.

    Metadata such as ``shape`` and ``dtype`` is read directly from the wrapped
    ``_value``. Arithmetic, comparison and conversion are routed through the
    autograd NumPy wrapper (``anp``), and indexing goes through an autograd
    ``primitive``, so these operations act on the box rather than on the
    raw value.
    """

    __slots__ = []
    # Higher than the priority assigned to plain Box instances.
    __array_priority__ = 100.0

    @primitive
    def __getitem__(A, idx):
        # Wrapped with ``primitive`` so indexing is treated as a single
        # autograd operation; its gradient is presumably defined elsewhere.
        return A[idx]

    # ------------------------------------------------------------------
    # Read-only metadata taken straight from the wrapped value.
    # ------------------------------------------------------------------
    @property
    def shape(self):
        return self._value.shape

    @property
    def ndim(self):
        return self._value.ndim

    @property
    def size(self):
        return self._value.size

    @property
    def dtype(self):
        return self._value.dtype

    @property
    def T(self):
        # Unlike the attributes above, this goes through ``anp.transpose``
        # and therefore operates on the box itself.
        return anp.transpose(self)

    # ------------------------------------------------------------------
    # Protocol hooks and conversions.
    # ------------------------------------------------------------------
    def __array_namespace__(self, *, api_version: Union[str, None] = None):
        """Return the array-API namespace for boxes: the ``anp`` wrapper module."""
        return anp

    def __len__(self):
        return len(self._value)

    def astype(self, *args, **kwargs):
        """Forward to ``anp._astype`` with the box as the array argument."""
        return anp._astype(self, *args, **kwargs)

    def __hash__(self):
        # __eq__ is overridden below to return an elementwise result, so
        # hashing is defined explicitly and based on object identity.
        return id(self)

    # ------------------------------------------------------------------
    # Unary operators.
    # ------------------------------------------------------------------
    def __neg__(self):
        return anp.negative(self)

    def __abs__(self):
        return anp.abs(self)

    # ------------------------------------------------------------------
    # Binary operators with the box on the left: op(self, other).
    # ------------------------------------------------------------------
    def __add__(self, other):
        return anp.add(self, other)

    def __sub__(self, other):
        return anp.subtract(self, other)

    def __mul__(self, other):
        return anp.multiply(self, other)

    def __pow__(self, other):
        return anp.power(self, other)

    def __div__(self, other):
        # Legacy Python 2 name; Python 3 dispatches "/" to __truediv__.
        return anp.divide(self, other)

    def __mod__(self, other):
        return anp.mod(self, other)

    def __truediv__(self, other):
        return anp.true_divide(self, other)

    def __matmul__(self, other):
        return anp.matmul(self, other)

    # ------------------------------------------------------------------
    # Reflected operators with the box on the right: op(other, self).
    # ------------------------------------------------------------------
    def __radd__(self, other):
        return anp.add(other, self)

    def __rsub__(self, other):
        return anp.subtract(other, self)

    def __rmul__(self, other):
        return anp.multiply(other, self)

    def __rpow__(self, other):
        return anp.power(other, self)

    def __rdiv__(self, other):
        # Legacy Python 2 name; Python 3 dispatches "/" to __rtruediv__.
        return anp.divide(other, self)

    def __rmod__(self, other):
        return anp.mod(other, self)

    def __rtruediv__(self, other):
        return anp.true_divide(other, self)

    def __rmatmul__(self, other):
        return anp.matmul(other, self)

    # ------------------------------------------------------------------
    # Rich comparisons, delegated to the elementwise anp functions.
    # ------------------------------------------------------------------
    def __eq__(self, other):
        return anp.equal(self, other)

    def __ne__(self, other):
        return anp.not_equal(self, other)

    def __gt__(self, other):
        return anp.greater(self, other)

    def __ge__(self, other):
        return anp.greater_equal(self, other)

    def __lt__(self, other):
        return anp.less(self, other)

    def __le__(self, other):
        return anp.less_equal(self, other)
# Register the value types that ArrayBox is responsible for: ndarrays plus
# Python and NumPy floating-point and complex scalars. Integer types are
# deliberately absent from this list.
for type_ in (
    np.ndarray,
    float, np.longdouble, np.float64, np.float32, np.float16,
    complex, np.clongdouble, np.complex64, np.complex128,
):
    ArrayBox.register(type_)
# Each name below is an ndarray method that has a NumPy function of the same
# name. ArrayBox borrows the autograd-wrapped function from ``anp`` as its
# method, so ``box.sum(...)`` behaves like ``anp.sum(box, ...)``.
# Methods whose results carry no gradient:
nondiff_methods = [
    "all", "any", "argmax", "argmin", "argpartition",
    "argsort", "nonzero", "searchsorted", "round",
]
# Methods that are differentiated through their anp counterparts:
diff_methods = [
    "clip", "compress", "cumprod", "cumsum", "diagonal", "max", "mean",
    "min", "prod", "ptp", "ravel", "repeat", "reshape", "squeeze", "std",
    "sum", "swapaxes", "take", "trace", "transpose", "var",
]
for method_name in (*nondiff_methods, *diff_methods):
    setattr(ArrayBox, method_name, vars(anp)[method_name])

# NumPy has no ``flatten`` function (only the ndarray method), so reuse ravel.
ArrayBox.flatten = vars(anp)["ravel"]
# Register NumPy's named-tuple result types for eig/eigh/qr/slogdet/svd with
# SequenceBox (present from NumPy 1.25). The module that defines them is
# ``np.linalg._linalg`` on NumPy >= 2.0 and ``np.linalg.linalg`` before that.
#
# Both thresholds are compared with NumpyVersion. A raw string comparison is
# wrong here: lexicographically "1.9.0" >= "1.25" is True, which would select
# the 1.25 branch on old NumPy and fail with AttributeError.
_numpy_version = np.lib.NumpyVersion(np.__version__)
if _numpy_version >= "2.0.0":
    _linalg_impl = np.linalg._linalg
elif _numpy_version >= "1.25.0":
    _linalg_impl = np.linalg.linalg
else:
    _linalg_impl = None  # older NumPy: nothing to register

if _linalg_impl is not None:
    for _result_name in (
        "EigResult",
        "EighResult",
        "QRResult",
        "SlogdetResult",
        "SVDResult",
    ):
        SequenceBox.register(getattr(_linalg_impl, _result_name))