File size: 3,673 Bytes
1a0d68d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import numpy as np

from autograd.builtins import NamedTupleVSpace
from autograd.extend import VSpace


class ArrayVSpace(VSpace):
    """Vector space of real-valued NumPy arrays sharing one shape and dtype.

    The shape and dtype are taken from ``np.asarray(value)`` at construction
    time; every vector produced by this space (``zeros``, ``ones``,
    ``standard_basis``, ``randn``) has exactly that shape and dtype.
    """

    def __init__(self, value):
        value = np.asarray(value)
        self.shape = value.shape
        self.dtype = value.dtype

    @property
    def size(self):
        """Number of scalar entries in a vector of this space, as a Python ``int``."""
        # np.prod(()) is the float 1.0, so a 0-d space previously reported a
        # float size. Accumulate in int64 (the default integer may be 32-bit
        # on some platforms) and return a plain int for every shape.
        return int(np.prod(self.shape, dtype=np.int64))

    @property
    def ndim(self):
        """Number of array dimensions."""
        return len(self.shape)

    def zeros(self):
        """Return the all-zeros vector."""
        return np.zeros(self.shape, dtype=self.dtype)

    def ones(self):
        """Return the all-ones vector."""
        return np.ones(self.shape, dtype=self.dtype)

    def standard_basis(self):
        """Yield one-hot vectors, one per index, in ``np.ndindex`` order."""
        for idxs in np.ndindex(*self.shape):
            vect = np.zeros(self.shape, dtype=self.dtype)
            vect[idxs] = 1
            yield vect

    def randn(self):
        """Return standard-normal samples cast to this space's dtype."""
        # np.array(...) also wraps the bare float returned for 0-d shapes.
        return np.array(np.random.randn(*self.shape)).astype(self.dtype)

    def _inner_prod(self, x, y):
        """Return the dot product of ``x`` and ``y`` after flattening both."""
        return np.dot(np.ravel(x), np.ravel(y))


class ComplexArrayVSpace(ArrayVSpace):
    """Vector space of complex-valued NumPy arrays of one shape and dtype.

    Real and imaginary parts count as separate degrees of freedom, the inner
    product is the real part of ``conj(x) . y``, and covectors are complex
    conjugates.
    """

    iscomplex = True

    @property
    def size(self):
        """Number of real degrees of freedom (two per complex entry), as an ``int``."""
        # Same 0-d fix as the real space: np.prod(()) * 2 would be 2.0.
        return 2 * int(np.prod(self.shape, dtype=np.int64))

    def ones(self):
        """Return the vector with every entry equal to ``1 + 1j``."""
        return np.ones(self.shape, dtype=self.dtype) + 1.0j * np.ones(self.shape, dtype=self.dtype)

    def standard_basis(self):
        """Yield, for each index in ``np.ndindex`` order, a real-unit then an imaginary-unit vector."""
        for idxs in np.ndindex(*self.shape):
            for v in [1.0, 1.0j]:
                vect = np.zeros(self.shape, dtype=self.dtype)
                vect[idxs] = v
                yield vect

    def randn(self):
        """Return complex samples with independent standard-normal real and imaginary parts."""
        # Draw the real part first, then the imaginary part, so the sequence
        # of global RNG draws matches the original expression order.
        real = np.array(np.random.randn(*self.shape)).astype(self.dtype)
        imag = np.array(np.random.randn(*self.shape)).astype(self.dtype)
        return real + 1.0j * imag

    def _inner_prod(self, x, y):
        """Return ``Re(conj(x) . y)`` over the flattened arrays."""
        return np.real(np.dot(np.conj(np.ravel(x)), np.ravel(y)))

    def _covector(self, x):
        """Return the complex conjugate of ``x``."""
        return np.conj(x)


# ndarrays pick their vector space from the dtype: complex arrays get the
# conjugating ComplexArrayVSpace, everything else the plain ArrayVSpace.
VSpace.register(
    np.ndarray,
    lambda x: (ComplexArrayVSpace if np.iscomplexobj(x) else ArrayVSpace)(x),
)

# Real Python/NumPy scalar types share the real array space.
for type_ in (float, np.longdouble, np.float64, np.float32, np.float16):
    ArrayVSpace.register(type_)

# Complex Python/NumPy scalar types share the complex array space.
for type_ in (complex, np.clongdouble, np.complex64, np.complex128):
    ComplexArrayVSpace.register(type_)


if np.lib.NumpyVersion(np.__version__) >= "2.0.0":

    class EigResultVSpace(NamedTupleVSpace):
        seq_type = np.linalg._linalg.EigResult

    class EighResultVSpace(NamedTupleVSpace):
        seq_type = np.linalg._linalg.EighResult

    class QRResultVSpace(NamedTupleVSpace):
        seq_type = np.linalg._linalg.QRResult

    class SlogdetResultVSpace(NamedTupleVSpace):
        seq_type = np.linalg._linalg.SlogdetResult

    class SVDResultVSpace(NamedTupleVSpace):
        seq_type = np.linalg._linalg.SVDResult

    EigResultVSpace.register(np.linalg._linalg.EigResult)
    EighResultVSpace.register(np.linalg._linalg.EighResult)
    QRResultVSpace.register(np.linalg._linalg.QRResult)
    SlogdetResultVSpace.register(np.linalg._linalg.SlogdetResult)
    SVDResultVSpace.register(np.linalg._linalg.SVDResult)
elif np.__version__ >= "1.25":

    class EigResultVSpace(NamedTupleVSpace):
        seq_type = np.linalg.linalg.EigResult

    class EighResultVSpace(NamedTupleVSpace):
        seq_type = np.linalg.linalg.EighResult

    class QRResultVSpace(NamedTupleVSpace):
        seq_type = np.linalg.linalg.QRResult

    class SlogdetResultVSpace(NamedTupleVSpace):
        seq_type = np.linalg.linalg.SlogdetResult

    class SVDResultVSpace(NamedTupleVSpace):
        seq_type = np.linalg.linalg.SVDResult

    EigResultVSpace.register(np.linalg.linalg.EigResult)
    EighResultVSpace.register(np.linalg.linalg.EighResult)
    QRResultVSpace.register(np.linalg.linalg.QRResult)
    SlogdetResultVSpace.register(np.linalg.linalg.SlogdetResult)
    SVDResultVSpace.register(np.linalg.linalg.SVDResult)