File size: 8,032 Bytes
29b9c56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
from __future__ import absolute_import

import numpy as np

from dtcwt.numpy import Transform2d as Transform2d_np
from dtcwt.numpy import Pyramid


def appropriate_complex_type_for(X):
    """Return a complex data type suitable for values derived from X.

    The choice depends only on the dtype of X (after conversion with
    :py:func:`numpy.asanyarray`); no data is copied or converted.

    * complex64 / complex128 input: the input's own dtype is returned.
    * float32 input: ``numpy.complex64``.
    * anything else (float64, integers, booleans, other precisions):
      ``numpy.complex128``.

    Parameters
    ----------
    X: array_like
        The data whose dtype drives the choice.

    Returns
    -------
    dtype or type
        Either the dtype of X (when it is already complex64/complex128) or
        one of ``numpy.complex64`` / ``numpy.complex128``.

    """
    dtype = np.asanyarray(X).dtype

    # np.issubdtype is the supported replacement for np.issubsctype, which
    # was removed in NumPy 2.0.
    if np.issubdtype(dtype, np.complex64) or \
            np.issubdtype(dtype, np.complex128):
        return dtype
    if np.issubdtype(dtype, np.float32):
        return np.complex64

    # float64 and every non-floating input (which would be promoted to float64
    # when converted to floating point) map to double precision complex. The
    # same is used as a cautious default for any other dtype.
    return np.complex128


def asfarray(X):
    """Return X as a floating point (or complex) array, preserving its dtype
    where possible.

    Behaves like the former :py:func:`numpy.asfarray` called with the dtype of
    X: arrays that already have an inexact (floating point or complex) dtype
    are passed through without copying, anything else is converted to
    ``numpy.float64``.

    Parameters
    ----------
    X: array_like
        The input data.

    Returns
    -------
    ndarray
        X itself if it is already an inexact ndarray, otherwise a float64
        conversion of X.

    """
    X = np.asanyarray(X)
    # np.asfarray was removed in NumPy 2.0; this reproduces its rule: keep
    # inexact dtypes, promote everything else to double precision.
    dtype = X.dtype if np.issubdtype(X.dtype, np.inexact) else np.float64
    return np.asarray(X, dtype=dtype)


class Transform2d(object):
    """
    An implementation of the 2D DT-CWT via numpy.

    This is a thin batching wrapper around :py:class:`dtcwt.numpy.Transform2d`
    which applies the transform independently to every image of a 2, 3 or 4
    dimensional input.

    Parameters
    ----------
    biort: str or np.array
        The biorthogonal wavelet family to use. Passed unchanged to
        :py:class:`dtcwt.numpy.Transform2d`.
    qshift: str or np.array
        The quarter shift wavelet family to use. Passed unchanged to
        :py:class:`dtcwt.numpy.Transform2d`.

    .. note::

        *biort* and *qshift* are the wavelets which parameterise the transform.
        If *biort* or *qshift* are strings, they are used as an argument to the
        :py:func:`dtcwt.coeffs.biort` or :py:func:`dtcwt.coeffs.qshift`
        functions.  Otherwise, they are interpreted as tuples of vectors giving
        filter coefficients. In the *biort* case, this should be (h0o, g0o, h1o,
        g1o). In the *qshift* case, this should be (h0a, h0b, g0a, g0b, h1a,
        h1b, g1a, g1b).

    .. codeauthor:: Fergal Cotter <fbc23@cam.ac.uk>, Feb 2018
    """
    def __init__(self, biort='near_sym_a', qshift='qshift_a'):
        self.xfm = Transform2d_np(biort, qshift)

    def forward(self, X, nlevels=3, include_scale=False):
        """ Perform a forward transform on an image with multiple channels.

        Will perform the DTCWT independently on each channel. Data format for
        the input must have the height and width as the last 2 dimensions.

        Parameters
        ----------
        X: np.array
            Input image which you wish to transform. Can be 2, 3, or 4
            dimensions, but height and width must be the last 2.
        nlevels: int
            Number of levels of the dtcwt transform to calculate.
        include_scale: bool
            Whether or not to return the lowpass results at each scale of the
            transform, or only at the highest scale (as is custom for
            multiresolution analysis)

        Returns
        -------
            Yl: ndarray
                Lowpass output
            Yh: list(ndarray)
                Highpass outputs. Will be complex and have one more dimension
                than the input representing the 6 orientations of the wavelets.
                This extra dimension will be the third last dimension. The first
                entry in the list is the first scale.
            Yscale: list(ndarray)
                Only returns if include_scale was true. A list of lowpass
                outputs at each scale.

        Raises
        ------
        ValueError
            If X does not have 2, 3 or 4 dimensions.

        .. codeauthor:: Fergal Cotter <fbc23@cam.ac.uk>, Feb 2018
        """
        X = asfarray(X)
        s = X.shape
        if len(s) not in (2, 3, 4):
            raise ValueError(
                'X must have 2, 3 or 4 dimensions, got {0}'.format(len(s)))

        # Reshape the inputs to all be 3d inputs of shape (batch, h, w)
        if len(s) == 2:
            X = np.reshape(X, (1, *s))
        elif len(s) == 4:
            X = np.reshape(X, (s[0]*s[1], s[2], s[3]))
        batch = X.shape[0]
        levels = range(nlevels)

        # The output sizes are only known once one image has been transformed,
        # so the first pyramid is used to allocate the batched outputs.
        p = self.xfm.forward(X[0], nlevels, include_scale)
        Yl = np.zeros((batch, *p.lowpass.shape), dtype=X.dtype)
        complex_dtype = appropriate_complex_type_for(X)
        Yh = [np.zeros((batch, 6, *p.highpasses[i].shape[0:2]),
                       dtype=complex_dtype) for i in levels]
        if include_scale:
            Yscale = [np.zeros((batch, *p.scales[i].shape), dtype=X.dtype)
                      for i in levels]

        for n in range(batch):
            if n > 0:
                p = self.xfm.forward(X[n], nlevels, include_scale)
            Yl[n] = p.lowpass
            for i in levels:
                # The wrapped transform returns (h, w, 6); move the
                # orientation axis in front of the spatial axes.
                Yh[i][n] = p.highpasses[i].transpose((2, 0, 1))
                if include_scale:
                    Yscale[i][n] = p.scales[i]

        # Reshape output to match input
        if len(s) == 2:
            Yl = Yl[0]
            Yh = [Yh[i][0] for i in levels]
            if include_scale:
                Yscale = [Yscale[i][0] for i in levels]
        elif len(s) == 4:
            Yl = np.reshape(Yl, (s[0], s[1], *Yl.shape[-2:]))
            Yh = [np.reshape(Yh[i], (s[0], s[1], *Yh[i].shape[-3:]))
                  for i in levels]
            if include_scale:
                Yscale = [np.reshape(Yscale[i],
                                     (s[0], s[1], *Yscale[i].shape[-2:]))
                          for i in levels]

        if include_scale:
            return Yl, Yh, Yscale
        return Yl, Yh

    def _inverse_single(self, Yl, Yh, n, gain_mask):
        """Reconstruct image *n* of a (batch, h, w) lowpass array and its
        matching list of (batch, 6, h', w') highpass arrays."""
        # The wrapped transform expects the orientation axis last: (h', w', 6).
        highpasses = [np.transpose(Yh_level[n], (1, 2, 0)) for Yh_level in Yh]
        return self.xfm.inverse(Pyramid(Yl[n], highpasses),
                                gain_mask=gain_mask)

    def inverse(self, Yl, Yh, gain_mask=None):
        """
        Perform an inverse transform on an image with multiple channels.

        Parameters
        ----------
        Yl: ndarray
            The lowpass coefficients. Can be 2, 3, or 4 dimensions
        Yh: list(ndarray)
            The complex high pass coefficients. Must be compatible with the
            lowpass coefficients. Should have one more dimension. E.g if Yl
            was of shape [batch, ch, h, w], then the Yh's should be each of
            shape [batch, ch, 6, h', w'] (with h' and w' being dependent on the
            scale).
        gain_mask: None or ndarray
            Can use this to set subbands to have non-unit gain. Should be
            an array of size [nlevels, 6] and can be complex or real. Useful
            for masking out subbands. The same mask is applied to every image
            in the batch.

        Returns
        -------
        X: ndarray
            The reconstructed data, with the same number of dimensions as Yl.

        Raises
        ------
        ValueError
            If Yl does not have 2, 3 or 4 dimensions.

        .. codeauthor:: Fergal Cotter <fbc23@cam.ac.uk>, Feb 2018
        """
        Yl = np.asanyarray(Yl)
        Yh = [np.asanyarray(Yh_level) for Yh_level in Yh]
        s = Yl.shape
        if len(s) not in (2, 3, 4):
            raise ValueError(
                'Yl must have 2, 3 or 4 dimensions, got {0}'.format(len(s)))

        # Reshape the inputs to all be 3d inputs of shape (batch, h, w)
        if len(s) == 2:
            Yl = np.reshape(Yl, (1, *s))
            Yh = [np.reshape(Yh_level, (1, *Yh_level.shape))
                  for Yh_level in Yh]
        elif len(s) == 4:
            Yl = np.reshape(Yl, (s[0]*s[1], *s[-2:]))
            Yh = [np.reshape(Yh_level, (s[0]*s[1], *Yh_level.shape[-3:]))
                  for Yh_level in Yh]
        batch = Yl.shape[0]

        # The output size is only known once one image has been reconstructed,
        # so the first result is used to allocate the batched output.
        first = self._inverse_single(Yl, Yh, 0, gain_mask)
        x = np.zeros((batch, *first.shape), dtype=first.dtype)
        x[0] = first
        for n in range(1, batch):
            # gain_mask must be applied to every image, not only the first.
            x[n] = self._inverse_single(Yl, Yh, n, gain_mask)

        # Reshape output to match input
        if len(s) == 2:
            return x[0]
        if len(s) == 4:
            return np.reshape(x, (s[0], s[1], *x.shape[-2:]))
        return x