Spaces:
Sleeping
Sleeping
File size: 7,155 Bytes
29b9c56 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 | import torch
from pytorch_wavelets import DTCWTForward, DTCWTInverse
# from pytorch_wavelets.dtcwt.transform2d import DTCWTForward2, DTCWTInverse
import argparse
import py3nvml
import torch.nn.functional as F
import torch.nn as nn
import pytorch_wavelets.dwt.transform2d as dwt
import pytorch_wavelets.dwt.lowlevel as lowlevel
import pytorch_wavelets.dtcwt.lowlevel2 as lowlevel2
from pytorch_wavelets.dtcwt.coeffs import level1, qshift
# Command-line interface for the profiling entry point at the bottom of the
# file. NOTE: the first positional argument of ArgumentParser is ``prog`` (the
# program name printed in usage), so the summary must go in ``description``.
parser = argparse.ArgumentParser(
    description='Profile the forward and inverse dtcwt in pytorch')
parser.add_argument('--no_grad', action='store_true',
                    help='Dont calculate the gradients')
parser.add_argument('--ref', action='store_true',
                    help='Compare to doing the DTCWT with ffts')
parser.add_argument('-c', '--convolution', action='store_true',
                    help='Profile an 11x11 convolution')
parser.add_argument('--dwt', action='store_true',
                    help='Profile dwt instead of dtcwt')
parser.add_argument('--fb', action='store_true',
                    help='Do the 4 fb implementation of the dtcwt')
parser.add_argument('-f', '--forward', action='store_true',
                    help='Only do forward transform (default is fwd and inv)')
parser.add_argument('-i', '--inverse', action='store_true',
                    help='Only do inverse transform (default is fwd and inv)')
parser.add_argument('-j', type=int, default=2,
                    help='number of scales of transform to do')
parser.add_argument('--no_hp', action='store_true',
                    help='Skip the highpass outputs of the forward transform '
                         '(passed to DTCWTForward as skip_hps)')
parser.add_argument('-s', '--size', default=0, type=int,
                    help='spatial size of input')
parser.add_argument('--device', default='cuda', choices=['cuda', 'cpu'],
                    help='which device to test')
parser.add_argument('--batch', default=16, type=int,
                    help='Number of images in parallel')

# When True, the __main__ block additionally runs a 3x3 conv "gain" layer on
# the magnitudes of the first highpass output ``yh[0]``. Only the forward and
# --fb modes produce ``yh``, so enabling this with other modes fails.
ICIP = False
def forward(size, no_grad, J, no_hp=False, dev='cuda'):
    """Profile the forward DTCWT.

    Runs ``DTCWTForward`` five times on a random input. Unless ``no_grad`` is
    set, a gradient of ones is backpropagated through the lowpass output
    after each pass.

    Args:
        size: shape of the random input tensor (unpacked into ``torch.randn``).
        no_grad: if True, the input does not require grad and no backward
            pass is run.
        J: number of scales of the transform.
        no_hp: forwarded to ``DTCWTForward`` as ``skip_hps``.
        dev: device to allocate the input and run the transform on.

    Returns:
        ``(Yl, Yh)`` from the last forward pass.
    """
    # Allocate directly on ``dev`` so ``x`` is a leaf tensor there.
    # ``randn(...).to(dev)`` would make a CPU leaf and add a device-to-host
    # gradient copy to every profiled backward pass.
    x = torch.randn(*size, device=dev, requires_grad=not no_grad)
    xfm = DTCWTForward(J=J, skip_hps=no_hp, o_dim=1, mode='symmetric').to(dev)
    for _ in range(5):
        Yl, Yh = xfm(x)
        if not no_grad:
            Yl.backward(torch.ones_like(Yl))
    return Yl, Yh
def inverse(size, no_grad, J, no_hp=False, dev='cuda'):
    """Profile the inverse DTCWT on random coefficients.

    Builds random coefficients sized for a ``J``-scale transform of an input
    of shape ``size`` = (N, C, H, W). Runs ``DTCWTInverse`` on them five
    times, backpropagating a gradient of ones after each pass unless
    ``no_grad`` is set.

    Coefficient shapes built here:
      * lowpass ``yl``: (N, C, H >> (J-1), W >> (J-1))
      * highpass ``yh[j-1]`` for j = 1..J: (N, C, 6, H >> j, W >> j, 2)
        -- presumably 6 orientations with (real, imag) in the last dim;
        TODO confirm this matches DTCWTInverse's default layout.

    ``no_hp`` is accepted for signature parity with ``forward`` and
    ``end_to_end`` but is not used.

    Returns:
        The reconstruction from the last pass.
    """
    n, c, h, w = size[0], size[1], size[2], size[3]
    needs_grad = not no_grad
    # Allocate directly on ``dev`` so the coefficients are leaf tensors there
    # and no host<->device copy is part of the profiled backward pass.
    yl = torch.randn(n, c, h >> (J - 1), w >> (J - 1),
                     device=dev, requires_grad=needs_grad)
    yh = [torch.randn(n, c, 6, h >> j, w >> j, 2,
                      device=dev, requires_grad=needs_grad)
          for j in range(1, J + 1)]
    ifm = DTCWTInverse().to(dev)
    for _ in range(5):
        Y = ifm((yl, yh))
        if not no_grad:
            Y.backward(torch.ones_like(Y))
    return Y
def end_to_end(size, no_grad, J, no_hp=False, dev='cuda'):
    """Profile a forward DTCWT followed by its inverse.

    Each of the five iterations runs ``DTCWTForward`` then ``DTCWTInverse``
    on the same random input. Unless ``no_grad`` is set, a gradient of ones
    is backpropagated through the reconstruction.

    Args:
        size: shape of the random input tensor (unpacked into ``torch.randn``).
        no_grad: if True, the input does not require grad and no backward
            pass is run.
        J: number of scales of the forward transform.
        no_hp: forwarded to ``DTCWTForward`` as ``skip_hps``.
        dev: device to allocate the input and run the transforms on.

    Returns:
        The reconstruction from the last iteration.
    """
    x = torch.randn(*size, device=dev, requires_grad=not no_grad)
    xfm = DTCWTForward(J=J, skip_hps=no_hp).to(dev)
    ifm = DTCWTInverse().to(dev)
    for _ in range(5):
        # The forward pass must run inside the loop. PyTorch frees the graph
        # after each backward (no retain_graph), so reusing one forward
        # output would backprop through a freed graph on iteration 2. It
        # would also leave the forward transform out of an "end to end"
        # measurement.
        Yl, Yh = xfm(x)
        Y = ifm((Yl, Yh))
        if not no_grad:
            Y.backward(torch.ones_like(Y))
    return Y
def reference_conv(size, no_grad=False, dev='cuda'):
    """Profile a single dense 11x11 convolution as a baseline.

    Convolves a random input of shape ``size`` with 10 random 11x11 filters
    that each span all input channels. Padding 5 keeps the spatial size
    unchanged.

    Args:
        size: shape of the random input, (N, C, H, W).
        no_grad: if True, the input does not require grad and no backward
            pass is run.
        dev: device to allocate tensors and run the convolution on.

    Returns:
        The convolution output, shape (N, 10, H, W).
    """
    x = torch.randn(*size, device=dev, requires_grad=not no_grad)
    # With groups=1 the weight's second dim must equal the input channel
    # count. The previous hard-coded 1 failed for any multi-channel input
    # (the profiler feeds 5 channels).
    w = torch.randn(10, x.shape[-3], 11, 11, device=dev)
    y = F.conv2d(x, w, padding=5, groups=1)
    if not no_grad:
        y.backward(torch.ones_like(y))
    return y
def reference_fftconv(size, J, no_grad=False, dev='cuda'):
    """Profile an FFT-domain filter bank as a rough stand-in for the DTCWT.

    The random input of shape (N, C, H, W) is zero-padded by ``9*J - 1`` on
    the bottom/right and transformed with a 2-D real FFT. It is then
    multiplied by ``12*J + 1`` random frequency-domain filters and
    transformed back. Unless ``no_grad`` is set, a gradient of ones is
    backpropagated through the result.

    Uses ``torch.fft.rfft2``/``irfft2`` where available. Falls back to the
    legacy ``torch.rfft``/``torch.irfft`` (removed in PyTorch 1.8) on older
    releases.

    Args:
        size: shape of the random input, (N, C, H, W).
        J: number of scales; sets the padding and the number of filters.
        no_grad: if True, the input does not require grad and no backward
            pass is run.
        dev: device to allocate tensors and run on.

    Returns:
        The filtered output, shape
        (N, C * (12*J + 1), H + 9*J - 1, W + 9*J - 1).
    """
    x = torch.randn(*size, device=dev, requires_grad=not no_grad)
    # Make the rough assumption that the wavelet sizes are 9*J spatial size
    sz = 9 * J
    # Zero-pad so the circular FFT product behaves like a linear convolution
    # with filters of up to ``sz`` taps.
    xp = F.pad(x, (0, sz - 1, 0, sz - 1))
    # Presumably 6 complex (12 real) subbands per scale plus one lowpass.
    nfilt = 12 * J + 1
    if hasattr(torch, 'rfft'):
        # Legacy API (PyTorch < 1.8): complex values carried as a trailing
        # dimension of size 2.
        FX = torch.unsqueeze(torch.rfft(xp, 2), dim=2)
        FW = torch.randn(1, 1, nfilt, FX.shape[-3], FX.shape[-2], 2,
                         device=dev)
        FYr = FX[..., 0] * FW[..., 0] - FX[..., 1] * FW[..., 1]
        FYi = FX[..., 0] * FW[..., 1] + FX[..., 1] * FW[..., 0]
        FY = torch.stack((FYr, FYi), dim=-1)
        FY = FY.view(FY.shape[0], -1, *FY.shape[-3:])
        Y = torch.irfft(FY, 2, signal_sizes=xp.shape[-2:])
    else:
        # torch.fft API: native complex tensors.
        FX = torch.unsqueeze(torch.fft.rfft2(xp), dim=2)
        FW = torch.randn(1, 1, nfilt, FX.shape[-2], FX.shape[-1],
                         dtype=FX.dtype, device=dev)
        FY = FX * FW
        FY = FY.reshape(FY.shape[0], -1, *FY.shape[-2:])
        Y = torch.fft.irfft2(FY, s=xp.shape[-2:])
    if not no_grad:
        Y.backward(torch.ones_like(Y))
    return Y
def separable_dwt(size, J, no_grad=False, dev='cuda'):
    """Profile the separable DWT (``db5`` wavelet, ``zero`` padding mode).

    Runs ``dwt.DWTForward`` five times on a random input. Unless ``no_grad``
    is set, a gradient of ones is backpropagated through the first highpass
    output ``yh[0]`` after each pass.

    Args:
        size: shape of the random input tensor (unpacked into ``torch.randn``).
        J: number of decomposition levels.
        no_grad: if True, the input does not require grad and no backward
            pass is run.
        dev: device to allocate the input and run the transform on.

    Returns:
        Scalar summaries of the last pass: ``(yl.mean(), [y.mean() for y in yh])``.
    """
    # Allocate on ``dev`` so the backward pass has no host<->device copy.
    x = torch.randn(*size, device=dev, requires_grad=not no_grad)
    xfm = dwt.DWTForward(J, wave='db5', mode='zero').to(dev)
    for _ in range(5):
        yl, yh = xfm(x)
        if not no_grad:
            yh[0].backward(torch.ones_like(yh[0]))
    return yl.mean(), [y.mean() for y in yh]
def selesnick_dtcwt(size, J, no_grad=False, dev='cuda'):
    """Profile the "4 filter bank" DTCWT implementation (``DTCWTForward2``).

    Runs ``DTCWTForward2`` five times on a random input. Unless ``no_grad``
    is set, a gradient of ones is backpropagated through the lowpass output
    after each pass.

    Args:
        size: shape of the random input tensor (unpacked into ``torch.randn``).
        J: number of scales.
        no_grad: if True, the input does not require grad and no backward
            pass is run.
        dev: device to allocate the input and run the transform on.

    Returns:
        ``(Yl, Yh)`` from the last forward pass.

    Raises:
        ImportError: if this pytorch_wavelets install has no ``DTCWTForward2``.
    """
    # DTCWTForward2 is not imported at module level (that import is
    # commented out), so referencing it bare always raised NameError.
    # Import it here and fail with a clear message if it is unavailable.
    try:
        from pytorch_wavelets.dtcwt.transform2d import DTCWTForward2
    except ImportError as err:
        raise ImportError(
            'selesnick_dtcwt (--fb) needs DTCWTForward2 from '
            'pytorch_wavelets.dtcwt.transform2d, which is not available '
            'in this install') from err
    x = torch.randn(*size, device=dev, requires_grad=not no_grad)
    xfm = DTCWTForward2(J=J, mode='symmetric').to(dev)
    for _ in range(5):
        Yl, Yh = xfm(x)
        if not no_grad:
            Yl.backward(torch.ones_like(Yl))
    return Yl, Yh
def test_dtcwt(size, J, no_grad=False, dev='cuda'):
    """Run three levels of a non-separable quad analysis filter bank.

    Builds filters from the 'farras' level-1 set via
    ``lowlevel2.prep_filt_quad_afb2d_nonsep``. Then applies
    ``lowlevel.afb2d_nonsep`` three times, each time feeding the reshaped
    lowpass output back in as the next input. Nothing is returned, and the
    ``yh`` output of each level is discarded. Not called from the __main__
    block.

    NOTE(review): ``J`` is accepted but unused -- the loop always runs 3
    levels.
    NOTE(review): filters come from ``lowlevel2`` but are applied with
    ``lowlevel.afb2d_nonsep``; confirm both helpers agree on the filter
    layout.
    """
    x = torch.randn(*size, requires_grad=(not no_grad)).to(dev)
    # level1('farras') yields 8 filters; only h0a, h0b, h1a, h1b are used.
    h0a, h0b, _, _, h1a, h1b, _, _ = level1('farras')
    # Four groups of (col lo, col hi, row lo, row hi) filters, presumably one
    # per tree combination (aa, ab, ba, bb) -- TODO confirm against
    # prep_filt_quad_afb2d_nonsep's signature.
    filts = lowlevel2.prep_filt_quad_afb2d_nonsep(
        h0a, h1a, h0a, h1a,
        h0a, h1a, h0b, h1b,
        h0b, h1b, h0a, h1a,
        h0b, h1b, h0b, h1b, device=dev)
    for j in range(3):
        yl, yh = lowlevel.afb2d_nonsep(x, filts, mode='zero')
        # Fold all non-batch, non-spatial dims of the lowpass into the
        # channel axis so the next level gets a 4-D input.
        x = yl.reshape(yl.shape[0], -1, yl.shape[-2], yl.shape[-1])
def test_dtcwt2(size, J, no_grad=False, dev='cuda'):
    """Run three levels of the separable quad analysis filter bank.

    The column/row filters come from the 'farras' level-1 set through
    ``lowlevel2.prep_filt_quad_afb2d``. Each level applies
    ``lowlevel2.quad_afb2d`` to the previous level's lowpass output.

    Returns:
        ``(lowpass, highs)``: the lowpass after the third level and a list
        of the three per-level highpass outputs, in order.

    NOTE(review): ``J`` is accepted but unused -- always 3 levels.
    """
    lowpass = torch.randn(*size, requires_grad=(not no_grad)).to(dev)
    h0a, h0b, _, _, h1a, h1b, _, _ = level1('farras')
    col_filts, row_filts = lowlevel2.prep_filt_quad_afb2d(
        h0a, h1a, h0b, h1b, device=dev)

    highs = []
    for _level in range(3):
        # The lowpass output is the input to the next level.
        lowpass, band = lowlevel2.quad_afb2d(
            lowpass, col_filts, row_filts, mode='zero')
        highs.append(band)
    return lowpass, highs
if __name__ == "__main__":
    args = parser.parse_args()
    on_gpu = args.device == 'cuda'
    # Reserve a free GPU only when the run will actually use one.
    if on_gpu:
        py3nvml.grab_gpus(1)
    # Inputs have 5 channels; -s/--size 0 (the default) means 128x128.
    side = args.size if args.size > 0 else 128
    size = (args.batch, 5, side, side)

    if args.ref:
        print('Running dtcwt with FFTs')
        reference_fftconv(size, args.j, args.no_grad, args.device)
    elif args.convolution:
        print('Running 11x11 convolution')
        reference_conv(size, args.no_grad, args.device)
    elif args.dwt:
        print('Running separable dwt')
        separable_dwt(size, args.j, args.no_grad, args.device)
    elif args.fb:
        print('Running 4 dwts')
        yl, yh = selesnick_dtcwt(size, args.j, args.no_grad, args.device)
        # yl, yh = test_dtcwt2(size, args.j, no_grad=args.no_grad, dev=args.device)
    elif args.forward:
        print('Running forward transform')
        yl, yh = forward(size, args.no_grad, args.j, args.no_hp, args.device)
    elif args.inverse:
        print('Running inverse transform')
        inverse(size, args.no_grad, args.j, args.no_hp, args.device)
    else:
        print('Running end to end')
        end_to_end(size, args.no_grad, args.j, args.no_hp, args.device)

    if ICIP:
        # Gain layer on the magnitudes of the first highpass output. Needs
        # ``yh``, which only the --fb and forward branches above assign.
        n, _, c, h, w, _ = yh[0].shape
        mag = torch.sqrt(yh[0][..., 0]**2 + yh[0][..., 1]**2 + 0.01) - 0.1
        mag = mag.view(n, 6 * c, h, w)
        gain1 = nn.Conv2d(6 * c, c, 3, padding=1).to(args.device)
        y = gain1(mag)

    # CUDA kernels run asynchronously; wait for all queued work to finish.
    # torch.cuda.synchronize() raises on machines without CUDA, so skip it
    # for --device cpu.
    if on_gpu:
        torch.cuda.synchronize()
|