DTCWT in Pytorch Wavelets
=========================

Pytorch wavelets is a port of `dtcwt_slim`__, which was my first attempt at
doing the DTCWT quickly on a GPU. It has since been cleaned up to run on
pytorch, with the quickest forward and inverse transforms I can make, and with
support for passing gradients back through to the inputs.

For those unfamiliar with the DTCWT, it is a shift-invariant wavelet transform
that comes with limited redundancy. Compared to the undecimated wavelet
transform, which has :math:`2^J` redundancy, the DTCWT only has :math:`2^d`
redundancy (where :math:`d` is the number of input dimensions, i.e. 4:1
redundancy for image transforms). Instead of producing 3 output subbands like
the DWT, it produces 6, which roughly represent 15, 45, 75, 105, 135 and 165
degree wavelets. On top of this, the 6 subbands have real and imaginary outputs
which are in quadrature with each other (similar to windowed sine and cosine
functions, or the Gabor wavelet).

It is possible to calculate similar transforms (such as the Morlet or Gabor)
using Fourier transforms, but the DTCWT is faster as it uses separable
convolutions.

.. image:: dtcwt_bands.png

__ https://github.com/fbcotter/dtcwt_slim

Notes
-----
Because of the above mentioned properties of the DTCWT, the output is slightly
different to the DWT. As mentioned, it is 4:1 redundant in 2D, so we expect
4 times as many coefficients as from the decimated wavelet transform. These
extra coefficients come from:

- 6 subband outputs instead of 3, each with real and imaginary coefficients
  instead of just real. For an :math:`N \times N` image, the first level
  bandpass has :math:`3N^2` instead of :math:`3N^2/4` coefficients.
- The lowpass is always at double the resolution you would expect for its level
  in the wavelet tree. I.e. for a 1 level transform, the lowpass output is
  still :math:`N \times N`. For a two level transform, it is
  :math:`N/2 \times N/2`, and so on.
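
The counting argument above can be checked with a few lines of plain Python.
This is only a sketch: ``dtcwt_coefficient_counts`` is a hypothetical helper
(it is not part of pytorch_wavelets), and it assumes a square image whose side
is divisible by :math:`2^J`.

.. code:: python

    def dtcwt_coefficient_counts(n, levels):
        """Count real-valued DTCWT coefficients for an n x n image.

        Hypothetical helper for illustration; assumes n is divisible by
        2**levels.
        """
        bandpass = []
        for j in range(1, levels + 1):
            side = n // 2 ** j
            # 6 orientations, each with a real and an imaginary part
            bandpass.append(6 * side * side * 2)
        # The lowpass sits at double the usual resolution for its level
        low_side = n // 2 ** (levels - 1)
        lowpass = low_side * low_side
        return bandpass, lowpass

    bandpass, lowpass = dtcwt_coefficient_counts(64, 3)
    total = sum(bandpass) + lowpass
    redundancy = total / (64 * 64)
    print(bandpass, lowpass, redundancy)
    # [12288, 3072, 768] 256 4.0

The first entry, 12288, is :math:`3N^2` for :math:`N = 64`, and the total is
exactly 4 times the number of input pixels.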

Example
-------

.. code:: python

    import torch
    from pytorch_wavelets import DTCWTForward, DTCWTInverse

    # 3-scale forward transform with explicitly chosen filters
    xfm = DTCWTForward(J=3, biort='near_sym_b', qshift='qshift_b')
    X = torch.randn(10, 5, 64, 64)
    Yl, Yh = xfm(X)
    print(Yl.shape)     # torch.Size([10, 5, 16, 16])
    print(Yh[0].shape)  # torch.Size([10, 5, 6, 32, 32, 2])
    print(Yh[1].shape)  # torch.Size([10, 5, 6, 16, 16, 2])
    print(Yh[2].shape)  # torch.Size([10, 5, 6, 8, 8, 2])

    # The inverse takes the (lowpass, bandpass) pair and uses the same filters
    ifm = DTCWTInverse(biort='near_sym_b', qshift='qshift_b')
    Y = ifm((Yl, Yh))

Like with the DWT, the returned Yh is a tuple with one entry per scale. Each
entry has 2 extra dimensions. The first comes between the channel dimension of
the input and the row dimension, and indexes the 6 orientations of the DTCWT.
The second is the final dimension, which holds the real and imaginary parts
(complex numbers are not native to pytorch). I.e. to access the real part of
the 45 degree wavelet at the first scale, you would use
:code:`Yh[0][:,:,1,:,:,0]`, and the imaginary part of the 165 degree wavelet
would be :code:`Yh[0][:,:,5,:,:,1]`.
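
As a rough illustration of this indexing, the sketch below uses a NumPy array
with the same shape as ``Yh[0]`` from the example above, rather than a real
transform output. The ``ANGLES`` tuple is an illustrative name, not something
the library provides. Its ordering follows the list of orientations given
earlier.

.. code:: python

    import numpy as np

    # Stand-in for Yh[0]: (batch, channel, orientation, row, col, real/imag)
    yh0 = np.random.randn(10, 5, 6, 32, 32, 2)

    # Approximate orientation (in degrees) of each of the 6 subbands
    ANGLES = (15, 45, 75, 105, 135, 165)

    # Real part of the 45 degree band, imaginary part of the 165 degree band
    real_45 = yh0[:, :, ANGLES.index(45), :, :, 0]
    imag_165 = yh0[:, :, ANGLES.index(165), :, :, 1]

    # Combine the trailing real/imag dimension into complex values
    z = yh0[..., 0] + 1j * yh0[..., 1]
    magnitude = np.abs(z)
    print(real_45.shape, imag_165.shape, magnitude.shape)
    # (10, 5, 32, 32) (10, 5, 32, 32) (10, 5, 6, 32, 32)

The same slicing expressions work unchanged on the torch tensors returned by
the transform.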

The above images were created by doing a forward transform with an input of
zeros (creates a pyramid with the correct size bands), and then setting the
centre spatial value to 1 for each of the orientations at the third scale. I.e.:

.. code:: python

    import numpy as np
    import torch
    from pytorch_wavelets import DTCWTForward, DTCWTInverse
    xfm = DTCWTForward(J=3)
    ifm = DTCWTInverse()
    x = torch.zeros(1,1,64,64)
    # Create 12 outputs, one for the real and imaginary point spread functions
    # for each of the 6 orientations
    out = np.zeros((12,64,64))
    yl, yh = xfm(x)
    for b in range(6):
      for ri in range(2):
        # Set one centre coefficient at the third scale, invert, then reset it
        yh[2][0,0,b,4,4,ri] = 1
        out[b*2 + ri] = ifm((yl, yh))[0,0].detach().numpy()
        yh[2][0,0,b,4,4,ri] = 0
    # Can now plot the output

Advanced Options
----------------
There is a whole host of advanced options for calculating the DTCWT. The above
example shows the use case that will work most of the time. However, here are
some more ways the DTCWT can be done:

Custom Biorthogonal and Qshift Filters
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Rather than specifying the type of filter for the layer 1 (biort), and layer 2+
(qshift) transforms, you can provide the filters directly. They should be given
as a tuple of array-like objects. For the biorthogonal filters, this is
a 2-tuple of low and highpass filters. For the qshift filters, this will be
a 4-tuple of low for tree a, low for tree b, high for tree a and high for tree
b filters.

E.g.:

.. code:: python
  
  from pytorch_wavelets import DTCWTForward
  from pytorch_wavelets.dtcwt.coeffs import biort
  # The standard style
  xfm1 = DTCWTForward(biort='near_sym_a', J=1)
  # Get our own filters, here we reverse the standard filters so they 
  # still have the right properties, only changing the phase
  h0o, _, h1o, _ = biort('near_sym_a')
  xfm2 = DTCWTForward(biort=(h0o[::-1], h1o[::-1]), J=1)

Note that you must be careful when doing this, as the filters are designed to
have the correct phase properties, so any changes will likely result in a loss
of the quarter shift and hence the shift invariant properties of the transform.

Skipping Highpasses
~~~~~~~~~~~~~~~~~~~
There is the option to not calculate the bandpass outputs at given scales. This
can speed up the transform if you know that there is very little useful content
in some areas of the frequency space. To do this, you can give a list of
booleans to the `skip_hps` parameter (if it is a single boolean, that is then
used for all the scales). The first value corresponds to the first
scale highpass outputs, and a value of true means do not calculate them.

E.g.:

.. code:: python

  import torch
  from pytorch_wavelets import DTCWTForward
  # Skip the bandpass outputs at the first scale only
  xfm = DTCWTForward(J=3, skip_hps=[True, False, False])
  yl, yh = xfm(torch.randn(1, 1, 64, 64))
  print(yh[0].shape)  # torch.Size([0])
  print(yh[1].shape)  # torch.Size([1, 1, 6, 16, 16, 2])

Naturally, the inverse transform happily accepts tensors with 0 shape (or even
`None`'s) and sets that level to be all zeros.
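
To make the effect of `skip_hps` on the output shapes concrete, here is a small
sketch that predicts the bandpass shapes in plain Python. The function
``expected_highpass_shapes`` is a hypothetical helper, not part of
pytorch_wavelets. It simply mirrors the shapes printed in the examples above
and assumes the height and width are divisible by :math:`2^J`.

.. code:: python

  def expected_highpass_shapes(shape, J, skip_hps=False):
      """Predict the bandpass shapes for an input of shape (N, C, H, W).

      Hypothetical helper for illustration only.
      """
      n, c, h, w = shape
      # A single boolean applies to every scale
      if isinstance(skip_hps, bool):
          skip_hps = [skip_hps] * J
      shapes = []
      for j in range(J):
          if skip_hps[j]:
              # Skipped scales come back as empty tensors
              shapes.append((0,))
          else:
              scale = 2 ** (j + 1)
              shapes.append((n, c, 6, h // scale, w // scale, 2))
      return shapes

  shapes = expected_highpass_shapes((1, 1, 64, 64), J=3,
                                    skip_hps=[True, False, False])
  print(shapes)
  # [(0,), (1, 1, 6, 16, 16, 2), (1, 1, 6, 8, 8, 2)]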

Changing the output shape
~~~~~~~~~~~~~~~~~~~~~~~~~
By default the highpass outputs have an extra 2 dimensions, one at the end for
complex values, and one after the channel dimension, for the 6 orientations.
E.g. an input of shape :math:`(N, C_{in}, H_{in}, W_{in})` will give a list of
bandpass coefficients, each with shape
:math:`(N, C_{in}, 6, H_{in}'', W_{in}'', 2)` (we've put primes next to the
height and width as they will change with scale).

You can choose where the orientations and real and imaginary dimensions go with
the options `o_dim` and `ri_dim`, which are by default 2 and -1.
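
To picture what relocating these dimensions means, the sketch below rearranges
an array in the default layout after the fact with ``np.moveaxis``. This is only
an illustration on a NumPy stand-in. It is not how the library implements
`o_dim` and `ri_dim`, and the exact layout produced by a given pair of values
should be checked against the transform's actual output.

.. code:: python

  import numpy as np

  # Default layout for one scale: (N, C, orientation, H, W, real/imag)
  band = np.zeros((1, 3, 6, 16, 16, 2))
  band[0, 1, 4, 2, 3, 1] = 7.0  # mark one coefficient so we can track it

  # Move the real/imag dimension from the end to position 1
  ri_first = np.moveaxis(band, -1, 1)
  # Move the orientation dimension in front of the channels
  o_first = np.moveaxis(band, 2, 1)

  print(ri_first.shape)  # (1, 2, 3, 6, 16, 16)
  print(o_first.shape)   # (1, 6, 3, 16, 16, 2)
  # The marked coefficient is unchanged, only its index order differs
  print(ri_first[0, 1, 1, 4, 2, 3], o_first[0, 4, 1, 2, 3, 1])  # 7.0 7.0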

Including all the lowpasses
~~~~~~~~~~~~~~~~~~~~~~~~~~~
In case you want to get all the intermediate lowpasses, you can with the
`include_scale` parameter. This works a bit like the `skip_hps` where you can
provide a single boolean to apply it to all the scales, or a list of booleans to
fine tune which lowpasses you want.

If any of the values in `include_scale` is true, then the transform output will
change, and the lowpass will be a tuple.

E.g.

.. code:: python

  import torch
  from pytorch_wavelets import DTCWTForward
  xfm1 = DTCWTForward(J=3)
  xfm2 = DTCWTForward(J=3, include_scale=True)
  xfm3 = DTCWTForward(J=3, include_scale=[False, True, True])
  x = torch.randn(1, 1, 64, 64)
  yl, yh = xfm1(x)
  print(yl.shape)  # torch.Size([1, 1, 16, 16])
  # Now do xfm2 which will give back all scales
  yl, yh = xfm2(x)
  for low in yl:
    print(low.shape)
  # torch.Size([1, 1, 64, 64])
  # torch.Size([1, 1, 32, 32])
  # torch.Size([1, 1, 16, 16])
  # Now do xfm3 which will give back the last two scales
  yl, yh = xfm3(x)
  for low in yl:
    print(low.shape)
  # torch.Size([0])
  # torch.Size([1, 1, 32, 32])
  # torch.Size([1, 1, 16, 16])

Note that to do the inverse transform, you have to give the final lowpass
output. You can provide None to indicate it's all zeros, but you cannot provide
all the intermediate lowpasses.
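
One way to handle this is to pick out the final lowpass before calling the
inverse. The sketch below uses a hypothetical helper, ``final_lowpass``, which
is not part of pytorch_wavelets. It assumes the lowpasses are ordered from the
first scale to the last, as in the printed shapes above, and that the last
scale was included.

.. code:: python

  def final_lowpass(yl):
      """Return the coarsest lowpass from a forward transform output.

      Hypothetical helper: yl is either a single tensor (the default) or a
      tuple/list of per-scale lowpasses (when include_scale is used).
      """
      if isinstance(yl, (tuple, list)):
          return yl[-1]
      return yl

  # Strings stand in for tensors here, to keep the sketch self-contained
  picked_from_tuple = final_lowpass(("low1", "low2", "low3"))
  picked_from_single = final_lowpass("low3")
  print(picked_from_tuple, picked_from_single)
  # low3 low3

With real outputs, the inverse call would then look like
``ifm((final_lowpass(yl), yh))``.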