File size: 3,673 Bytes
1a0d68d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 | import numpy as np
from autograd.builtins import NamedTupleVSpace
from autograd.extend import VSpace
class ArrayVSpace(VSpace):
    """Vector space of NumPy arrays with a fixed shape and dtype.

    The space is described only by the ``shape`` and ``dtype`` of the example
    value given to the constructor; the value itself is not stored.
    """

    def __init__(self, value):
        # np.asarray lets Python/NumPy scalars be treated as 0-d arrays.
        value = np.asarray(value)
        self.shape = value.shape
        self.dtype = value.dtype

    @property
    def size(self):
        """Total number of entries in an array of this space, as an ``int``.

        ``np.prod(())`` evaluates to the float ``1.0``, so the product is cast
        to ``int`` to give 0-d (scalar) spaces an integer size like every
        other shape.
        """
        return int(np.prod(self.shape))

    @property
    def ndim(self):
        """Number of axes of arrays in this space."""
        return len(self.shape)

    def zeros(self):
        """Return an all-zeros array with this space's shape and dtype."""
        return np.zeros(self.shape, dtype=self.dtype)

    def ones(self):
        """Return an all-ones array with this space's shape and dtype."""
        return np.ones(self.shape, dtype=self.dtype)

    def standard_basis(self):
        """Yield one array per entry, with that entry set to 1 and all others 0.

        Entries are visited in ``np.ndindex`` order (last axis fastest).
        """
        for idxs in np.ndindex(*self.shape):
            vect = np.zeros(self.shape, dtype=self.dtype)
            vect[idxs] = 1
            yield vect

    def randn(self):
        """Return standard-normal samples of this space's shape, cast to its dtype."""
        # np.array wraps the Python float np.random.randn() returns for shape ().
        return np.array(np.random.randn(*self.shape)).astype(self.dtype)

    def _inner_prod(self, x, y):
        """Return the dot product of ``x`` and ``y`` after flattening both."""
        return np.dot(np.ravel(x), np.ravel(y))
class ComplexArrayVSpace(ArrayVSpace):
    """Vector space of complex NumPy arrays, viewed as a real space.

    Each complex entry contributes two real directions (real and imaginary
    parts). This is reflected in ``size``, ``standard_basis`` and the real
    part taken in ``_inner_prod``.
    """

    iscomplex = True

    @property
    def size(self):
        """Number of real degrees of freedom: two per complex entry, as an ``int``."""
        # np.prod(()) is the float 1.0 for 0-d spaces, hence the int cast.
        return int(np.prod(self.shape)) * 2

    def ones(self):
        """Return an array of this space's shape and dtype filled with ``1 + 1j``."""
        return np.ones(self.shape, dtype=self.dtype) + 1.0j * np.ones(self.shape, dtype=self.dtype)

    def standard_basis(self):
        """Yield, for each entry in ``np.ndindex`` order, a basis array with 1 there, then one with 1j."""
        for idxs in np.ndindex(*self.shape):
            for v in [1.0, 1.0j]:
                vect = np.zeros(self.shape, dtype=self.dtype)
                vect[idxs] = v
                yield vect

    def randn(self):
        """Return an array whose real and imaginary parts come from separate standard-normal draws."""
        # The real part is drawn first, then the imaginary part (same RNG order as before).
        real_part = np.array(np.random.randn(*self.shape)).astype(self.dtype)
        imag_part = np.array(np.random.randn(*self.shape)).astype(self.dtype)
        return real_part + 1.0j * imag_part

    def _inner_prod(self, x, y):
        """Return Re(conj(x) . y) over the flattened arrays (the real inner product)."""
        return np.real(np.dot(np.conj(np.ravel(x)), np.ravel(y)))

    def _covector(self, x):
        """Return the complex conjugate of ``x``."""
        return np.conj(x)
# ndarray values get their vector space chosen per value: ComplexArrayVSpace
# when np.iscomplexobj is true, otherwise ArrayVSpace.
VSpace.register(
    np.ndarray,
    lambda x: (ComplexArrayVSpace if np.iscomplexobj(x) else ArrayVSpace)(x),
)

# Real scalar types always map to ArrayVSpace.
for type_ in (float, np.longdouble, np.float64, np.float32, np.float16):
    ArrayVSpace.register(type_)

# Complex scalar types always map to ComplexArrayVSpace.
for type_ in (complex, np.clongdouble, np.complex64, np.complex128):
    ComplexArrayVSpace.register(type_)
# Define and register a NamedTupleVSpace subclass for each np.linalg result
# type (EigResult, EighResult, QRResult, SlogdetResult, SVDResult).
# The types are taken from np.linalg._linalg on NumPy >= 2.0.0 and from
# np.linalg.linalg on NumPy >= 1.25.0; older versions register nothing.
# Both checks use NumpyVersion: a plain string comparison such as
# np.__version__ >= "1.25" is lexicographic and wrongly treats e.g. "1.9.0"
# as newer than "1.25". NumpyVersion requires a full "X.Y.Z" string.
_numpy_version = np.lib.NumpyVersion(np.__version__)
if _numpy_version >= "2.0.0":
    _linalg_impl = np.linalg._linalg
elif _numpy_version >= "1.25.0":
    _linalg_impl = np.linalg.linalg
else:
    _linalg_impl = None

if _linalg_impl is not None:

    class EigResultVSpace(NamedTupleVSpace):
        seq_type = _linalg_impl.EigResult

    class EighResultVSpace(NamedTupleVSpace):
        seq_type = _linalg_impl.EighResult

    class QRResultVSpace(NamedTupleVSpace):
        seq_type = _linalg_impl.QRResult

    class SlogdetResultVSpace(NamedTupleVSpace):
        seq_type = _linalg_impl.SlogdetResult

    class SVDResultVSpace(NamedTupleVSpace):
        seq_type = _linalg_impl.SVDResult

    # Register each vspace class for its own seq_type so the class and the
    # result type it handles cannot get out of sync.
    for _vspace_cls in (
        EigResultVSpace,
        EighResultVSpace,
        QRResultVSpace,
        SlogdetResultVSpace,
        SVDResultVSpace,
    ):
        _vspace_cls.register(_vspace_cls.seq_type)
|