2D Wavelet Transforms in Pytorch
================================

|build-status| |docs| |doi|

.. |build-status| image:: https://travis-ci.org/fbcotter/pytorch_wavelets.png?branch=master
    :alt: build status
    :scale: 100%
    :target: https://travis-ci.org/fbcotter/pytorch_wavelets

.. |docs| image:: https://readthedocs.org/projects/pytorch-wavelets/badge/?version=latest
    :target: https://pytorch-wavelets.readthedocs.io/en/latest/?badge=latest
    :alt: Documentation Status

.. |doi| image:: https://zenodo.org/badge/146817005.svg
   :target: https://zenodo.org/badge/latestdoi/146817005
   
The full documentation is also available `here`__.

__ http://pytorch-wavelets.readthedocs.io/

This package provides support for computing the 2D discrete wavelet transform
and the 2D dual-tree complex wavelet transform, their inverses, and passing
gradients through both using PyTorch.

The implementation is designed to be used with batches of multichannel images.
Inputs follow the standard PyTorch 'NCHW' data format (batch, channels, height,
width).
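
If your images are stored channels-last, they need to be rearranged before
being passed to the transforms. The following is a minimal sketch using plain
PyTorch; the tensor names and sizes are made up for illustration.

.. code:: python

    import torch

    # Hypothetical batch stored channels-last: (N, H, W, C)
    x_nhwc = torch.randn(10, 64, 64, 3)

    # Move the channel axis to position 1 to obtain (N, C, H, W)
    x_nchw = x_nhwc.permute(0, 3, 1, 2).contiguous()
    print(x_nchw.shape)  # torch.Size([10, 3, 64, 64])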

We have also added layers that implement a 2D DTCWT-based scatternet. This is
similar to the Morlet-based scatternet in `KymatIO`__, but is roughly 10 times
faster.

If you use this repo, please cite my PhD thesis, chapter 3: https://doi.org/10.17863/CAM.53748.

__ https://github.com/kymatio/kymatio

New in version 1.3.0
~~~~~~~~~~~~~~~~~~~~

- Added 1D DWT support

.. code:: python

    import torch
    from pytorch_wavelets import DWT1DForward, DWT1DInverse  # or simply DWT1D, IDWT1D
    dwt = DWT1DForward(wave='db6', J=3)
    X = torch.randn(10, 5, 100)
    yl, yh = dwt(X)
    print(yl.shape)     # torch.Size([10, 5, 22])
    print(yh[0].shape)  # torch.Size([10, 5, 55])
    print(yh[1].shape)  # torch.Size([10, 5, 33])
    print(yh[2].shape)  # torch.Size([10, 5, 22])
    idwt = DWT1DInverse(wave='db6')
    x = idwt((yl, yh))
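
The coefficient lengths printed above are consistent with the PyWavelets
length convention, ``floor((N + L - 1) / 2)`` for a signal of length ``N`` and
a filter of length ``L`` (``ceil(N / 2)`` for periodization). The sketch below
reproduces them in pure Python. It assumes that convention applies here, and
the helper names ``dwt_coeff_len`` and ``dwt_shapes`` are illustrative, not part
of this package.

.. code:: python

    def dwt_coeff_len(n, filt_len, mode='zero'):
        """Length of one level of DWT coefficients (PyWavelets convention)."""
        if mode == 'periodization':
            return (n + 1) // 2          # ceil(n / 2)
        return (n + filt_len - 1) // 2   # floor((n + L - 1) / 2)

    def dwt_shapes(n, filt_len, J, mode='zero'):
        """Return (lowpass length, [highpass lengths, finest first])."""
        highs = []
        for _ in range(J):
            n = dwt_coeff_len(n, filt_len, mode)
            highs.append(n)
        return n, highs

    # 'db6' has 12 filter taps; the input length is 100 and J=3
    print(dwt_shapes(100, 12, J=3))  # (22, [55, 33, 22])

The same rule applied along each spatial axis gives the 2D shapes, for example
``dwt_shapes(64, 6, J=3)`` for a 64 by 64 image with ``db3``.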

New in version 1.2.0
~~~~~~~~~~~~~~~~~~~~

- Added a DTCWT based ScatterNet

.. code:: python

    import torch
    from pytorch_wavelets import ScatLayer
    scat = ScatLayer()
    X = torch.randn(10,5,64,64)
    # A first order scatternet with 6 orientations and one lowpass channel
    # gives 7 times the input channel dimension
    Z = scat(X)
    print(Z.shape)  # torch.Size([10, 35, 32, 32])
    # A second order scatternet with 6 orientations and one lowpass channel
    # gives 7^2 times the input channel dimension
    scat2 = torch.nn.Sequential(ScatLayer(), ScatLayer())
    Z = scat2(X)
    print(Z.shape)  # torch.Size([10, 245, 16, 16])
    # We also have a slightly more specialized, but slower, second order scatternet
    from pytorch_wavelets import ScatLayerj2
    scat2a = ScatLayerj2()
    Z = scat2a(X)
    print(Z.shape)  # torch.Size([10, 245, 16, 16])
    # All of these also work on the GPU
    scat2a.cuda()
    Z = scat2a(X.cuda())
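
The output shapes above follow a simple pattern: each scattering order
multiplies the channel count by 7 (6 orientations plus one lowpass) and, judging
by the printed shapes, halves the height and width. The helper below is a
hypothetical pure-Python sketch of that arithmetic, assuming even spatial sizes;
it is not a function provided by this package.

.. code:: python

    def scat_output_shape(shape, order=1, n_orient=6):
        """Predict the (N, C, H, W) shape after `order` stacked scattering layers."""
        n, c, h, w = shape
        for _ in range(order):
            c *= n_orient + 1   # 6 orientations + 1 lowpass per input channel
            h //= 2             # spatial size halves at each order
            w //= 2
        return (n, c, h, w)

    print(scat_output_shape((10, 5, 64, 64), order=1))  # (10, 35, 32, 32)
    print(scat_output_shape((10, 5, 64, 64), order=2))  # (10, 245, 16, 16)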

New in version 1.1.0
~~~~~~~~~~~~~~~~~~~~

- Fixed a memory problem with the DWT
- Cleaned up the backend code for the DTCWT calculation; performance is similar
- Both the DTCWT and the DWT should now be more memory efficient
- Removed the need to specify the number of scales for DTCWTInverse

New in version 1.0.0
~~~~~~~~~~~~~~~~~~~~
Version 1.0.0 adds support for separable DWT calculation and more padding
schemes, such as symmetric, zero and periodization.

You also no longer need to specify the number of channels when creating the
wavelet transform classes.

Speed Tests
~~~~~~~~~~~
We compare computing the DTCWT with the Python package and the DWT with
PyWavelets against computing both in pytorch_wavelets on a GTX1080. The numpy
methods were run on a 14 core Xeon Phi machine using Intel's parallel Python.
For the DTCWT we use the ``near_sym_a`` filters for the first scale and the
``qshift_a`` filters for subsequent scales. For the DWT we use the ``db4``
filters.

For a fixed input size, but varying the number of scales (from 1 to 4) we have
the following speeds (averaged over 5 runs):

.. raw:: html

    <img src="docs/scale.png" width="700px">

For an input size with height and width 512 by 512, we also vary the batch size
for a 3 scale transform. The resulting speeds were:

.. raw:: html

    <img src="docs/batchsize.png" width="700px">
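
The benchmark code behind these plots is not reproduced here. As a rough sketch
only, an average over several runs could be measured as follows.
``average_time`` and the stand-in workload are hypothetical, and the untimed
warm-up call is an assumption rather than a description of how the figures
above were produced.

.. code:: python

    import time

    def average_time(fn, runs=5, sync=None):
        """Return the mean wall-clock time in seconds of fn() over `runs` calls."""
        fn()                      # warm-up call, not timed (an assumption)
        if sync is not None:
            sync()
        total = 0.0
        for _ in range(runs):
            start = time.perf_counter()
            fn()
            if sync is not None:
                sync()            # wait for pending work before reading the clock
            total += time.perf_counter() - start
        return total / runs

    # Stand-in workload; replace with e.g. ``lambda: xfm(X)``
    mean_s = average_time(lambda: sum(i * i for i in range(100_000)))
    print(f"{mean_s * 1e3:.3f} ms per call")

When timing on a GPU, pass ``torch.cuda.synchronize`` as ``sync`` so that
asynchronous kernels have finished before the clock is read.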

Installation
````````````
The easiest way to install ``pytorch_wavelets`` is to clone the repo and pip install
it. Later versions will be released on PyPI but the docs need to be updated first::

    $ git clone https://github.com/fbcotter/pytorch_wavelets
    $ cd pytorch_wavelets
    $ pip install .

(The ``develop`` command may be more useful if you intend to make any
significant modifications to the library.) A test suite is provided so that you
may verify the code works on your system::

    $ pip install -r tests/requirements.txt
    $ pytest tests/

Example Use
```````````
For the DWT, note that the highpass output has an extra dimension, in which we
stack the (lh, hl, hh) coefficients. Also note that the Yh output has the
finest detail coefficients first and the coarsest last (the opposite of the
PyWavelets ordering).

.. code:: python

    import torch
    from pytorch_wavelets import DWTForward, DWTInverse
    xfm = DWTForward(J=3, wave='db3', mode='zero')
    X = torch.randn(10,5,64,64)
    Yl, Yh = xfm(X)
    print(Yl.shape)     # torch.Size([10, 5, 12, 12])
    print(Yh[0].shape)  # torch.Size([10, 5, 3, 34, 34])
    print(Yh[1].shape)  # torch.Size([10, 5, 3, 19, 19])
    print(Yh[2].shape)  # torch.Size([10, 5, 3, 12, 12])
    ifm = DWTInverse(wave='db3', mode='zero')
    Y = ifm((Yl, Yh))

For the DTCWT:

.. code:: python

    import torch
    from pytorch_wavelets import DTCWTForward, DTCWTInverse
    xfm = DTCWTForward(J=3, biort='near_sym_b', qshift='qshift_b')
    X = torch.randn(10,5,64,64)
    Yl, Yh = xfm(X)
    print(Yl.shape)     # torch.Size([10, 5, 16, 16])
    print(Yh[0].shape)  # torch.Size([10, 5, 6, 32, 32, 2])
    print(Yh[1].shape)  # torch.Size([10, 5, 6, 16, 16, 2])
    print(Yh[2].shape)  # torch.Size([10, 5, 6, 8, 8, 2])
    ifm = DTCWTInverse(biort='near_sym_b', qshift='qshift_b')
    Y = ifm((Yl, Yh))

Some initial notes:

- Yh returned is a tuple. There are 2 extra dimensions. The first comes between
  the channel dimension of the input and the row dimension, and holds the
  6 orientations of the DTCWT. The second is the final dimension, which holds
  the real and imaginary parts (complex numbers are not native to pytorch).
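
As an illustration of working with the trailing real/imaginary dimension, the
sketch below computes coefficient magnitudes from a random stand-in tensor with
the same shape as ``Yh[0]`` in the example above. In PyTorch versions that
provide complex tensors, ``torch.view_as_complex`` gives an equivalent complex
view; the variable names are made up for illustration.

.. code:: python

    import torch

    # Stand-in for Yh[0]: (N, C, 6 orientations, H, W, 2 = real/imag)
    yh0 = torch.randn(10, 5, 6, 32, 32, 2)

    # Magnitude from the stacked real and imaginary parts
    real, imag = yh0[..., 0], yh0[..., 1]
    mag = torch.sqrt(real ** 2 + imag ** 2)
    print(mag.shape)  # torch.Size([10, 5, 6, 32, 32])

    # Equivalent complex view; the last dimension must have size 2 and the
    # tensor should be contiguous
    yh0_c = torch.view_as_complex(yh0.contiguous())
    print(torch.allclose(yh0_c.abs(), mag))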

Running on the GPU
~~~~~~~~~~~~~~~~~~
This should come as no surprise to PyTorch users: the DWT and DTCWT transforms
can be moved to the GPU with the usual ``.cuda()`` call:

.. code:: python

    import torch
    from pytorch_wavelets import DTCWTForward, DTCWTInverse
    xfm = DTCWTForward(J=3, biort='near_sym_b', qshift='qshift_b').cuda()
    X = torch.randn(10,5,64,64).cuda()
    Yl, Yh = xfm(X) 
    ifm = DTCWTInverse(biort='near_sym_b', qshift='qshift_b').cuda()
    Y = ifm((Yl, Yh))
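
If the same script also has to run on machines without a GPU, a common PyTorch
pattern is to pick the device at run time. This is a sketch that assumes the
transforms behave like ordinary ``torch.nn.Module`` objects, so that
``.to(device)`` works in the same way as ``.cuda()`` above.

.. code:: python

    import torch
    from pytorch_wavelets import DTCWTForward, DTCWTInverse

    # Fall back to the CPU when CUDA is not available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    xfm = DTCWTForward(J=3, biort='near_sym_b', qshift='qshift_b').to(device)
    ifm = DTCWTInverse(biort='near_sym_b', qshift='qshift_b').to(device)

    X = torch.randn(10, 5, 64, 64, device=device)
    Yl, Yh = xfm(X)
    Y = ifm((Yl, Yh))
    print(Y.shape, Y.device)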

The automated tests cannot test the GPU functionality, but do check that the
code runs on the CPU. To test whether the repo is working on your GPU, download
the repo, ensure you have PyTorch with CUDA enabled (the tests will check
whether :code:`torch.cuda.is_available()` returns true), and run the following
from the base of the repo:

.. code::

    pip install -r tests/requirements.txt
    pytest tests/

Backpropagation
~~~~~~~~~~~~~~~
It is possible to pass gradients through the forward and inverse transforms.
All you need to do is ensure that the input to each has its ``requires_grad``
attribute set to true.
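
A minimal sketch of this, reusing the ``DWTForward`` and ``DWTInverse`` classes
from the examples above. The input size and the toy loss are arbitrary choices
made for illustration.

.. code:: python

    import torch
    from pytorch_wavelets import DWTForward, DWTInverse

    xfm = DWTForward(J=2, wave='db3', mode='zero')
    ifm = DWTInverse(wave='db3', mode='zero')

    # The input must track gradients for them to flow back through the transforms
    X = torch.randn(4, 3, 32, 32, requires_grad=True)

    Yl, Yh = xfm(X)      # forward transform
    Y = ifm((Yl, Yh))    # inverse transform

    # Toy loss (hypothetical): mean energy of the reconstruction
    loss = Y.pow(2).mean()
    loss.backward()

    print(X.grad.shape)  # torch.Size([4, 3, 32, 32])

After ``loss.backward()``, ``X.grad`` holds the gradient of the loss with
respect to the input, having passed through both transforms.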



Provenance
~~~~~~~~~~
Based on the Dual-Tree Complex Wavelet Transform Pack for MATLAB by Nick
Kingsbury, Cambridge University. The original README can be found in
ORIGINAL_README.txt.  This file outlines the conditions of use of the original
MATLAB toolbox.

Further information on the DT CWT can be obtained from papers
downloadable from my website (given below). The best tutorial is in
the 1999 Royal Society Paper. In particular this explains the conversion
between 'real' quad-number subimages and pairs of complex subimages. 
The Q-shift filters are explained in the ICIP 2000 paper and in more detail
in the May 2001 paper for the Journal on Applied and Computational 
Harmonic Analysis.

This code is copyright and is supplied free of charge for research
purposes only. In return for supplying the code, all I ask is that, if
you use the algorithms, you give due reference to this work in any
papers that you write and that you let me know if you find any good
applications for the DT CWT. If the applications are good, I would be
very interested in collaboration. I accept no liability arising from use
of these algorithms.

Nick Kingsbury, 
Cambridge University, June 2003.

Dr N G Kingsbury,
Dept. of Engineering, University of Cambridge,
Trumpington St., Cambridge CB2 1PZ, UK., or
Trinity College, Cambridge CB2 1TQ, UK.
Phone: (0 or +44) 1223 338514 / 332647;  Home: 1954 211152;
Fax: 1223 338564 / 332662;  E-mail: ngk@eng.cam.ac.uk
Web home page: http://www.eng.cam.ac.uk/~ngk/

.. vim:sw=4:sts=4:et