import torch
from pytorch_wavelets import DTCWTForward, DTCWTInverse
# from pytorch_wavelets.dtcwt.transform2d import DTCWTForward2, DTCWTInverse
import argparse
import py3nvml
import torch.nn.functional as F
import torch.nn as nn
import pytorch_wavelets.dwt.transform2d as dwt
import pytorch_wavelets.dwt.lowlevel as lowlevel
import pytorch_wavelets.dtcwt.lowlevel2 as lowlevel2
from pytorch_wavelets.dtcwt.coeffs import level1, qshift

parser = argparse.ArgumentParser(
    description='Profile the forward and inverse dtcwt in pytorch')
parser.add_argument('--no_grad', action='store_true',
                    help="Don't calculate the gradients")
parser.add_argument('--ref', action='store_true',
                    help='Compare to doing the DTCWT with FFTs')
parser.add_argument('-c', '--convolution', action='store_true',
                    help='Profile an 11x11 convolution')
parser.add_argument('--dwt', action='store_true',
                    help='Profile the separable DWT instead of the DTCWT')
parser.add_argument('--fb', action='store_true',
                    help='Use the 4 filter-bank implementation of the DTCWT')
parser.add_argument('-f', '--forward', action='store_true',
                    help='Only do the forward transform (default is fwd and inv)')
parser.add_argument('-i', '--inverse', action='store_true',
                    help='Only do the inverse transform (default is fwd and inv)')
parser.add_argument('-j', type=int, default=2,
                    help='Number of scales of transform to do')
parser.add_argument('--no_hp', action='store_true',
                    help='Skip the highpass outputs of the forward transform')
parser.add_argument('-s', '--size', default=0, type=int,
                    help='Spatial size of the input (default 128)')
parser.add_argument('--device', default='cuda', choices=['cuda', 'cpu'],
                    help='Which device to test')
parser.add_argument('--batch', default=16, type=int,
                    help='Number of images in parallel')

ICIP = False


def forward(size, no_grad, J, no_hp=False, dev='cuda'):
    """Run 5 forward DTCWTs (and backwards passes unless no_grad is set)."""
    x = torch.randn(*size, requires_grad=(not no_grad)).to(dev)
    xfm = DTCWTForward(J=J, skip_hps=no_hp, o_dim=1, mode='symmetric').to(dev)
    for _ in range(5):
        Yl, Yh = xfm(x)
        if not no_grad:
            Yl.backward(torch.ones_like(Yl))
    return Yl, Yh


def inverse(size, no_grad, J, no_hp=False, dev='cuda'):
    """Run 5 inverse DTCWTs on randomly generated coefficients."""
    yl = torch.randn(size[0], size[1], size[2] >> (J - 1), size[3] >> (J - 1),
                     requires_grad=(not no_grad)).to(dev)
    yh = [torch.randn(size[0], size[1], 6, size[2] >> j, size[3] >> j, 2,
                      requires_grad=(not no_grad)).to(dev)
          for j in range(1, J + 1)]
    ifm = DTCWTInverse().to(dev)
    for _ in range(5):
        Y = ifm((yl, yh))
        if not no_grad:
            Y.backward(torch.ones_like(Y))
    return Y


def end_to_end(size, no_grad, J, no_hp=False, dev='cuda'):
    """Forward transform once, then run 5 inverse transforms on the result."""
    x = torch.randn(*size, requires_grad=(not no_grad)).to(dev)
    xfm = DTCWTForward(J=J, skip_hps=no_hp).to(dev)
    ifm = DTCWTInverse().to(dev)
    Yl, Yh = xfm(x)
    for _ in range(5):
        Y = ifm((Yl, Yh))
        if not no_grad:
            Y.backward(torch.ones_like(Y))
    return Y


def reference_conv(size, no_grad=False, dev='cuda'):
    """Reference point: a single 11x11 convolution (and its backwards pass)."""
    x = torch.randn(*size, requires_grad=(not no_grad)).to(dev)
    # Weight in_channels must match the input channel count for groups=1
    w = torch.randn(10, size[1], 11, 11).to(dev)
    y = F.conv2d(x, w, padding=5, groups=1)
    if not no_grad:
        y.backward(torch.ones_like(y))
    return y


def reference_fftconv(size, J, no_grad=False, dev='cuda'):
    """Reference point: do the DTCWT filtering as pointwise products of FFTs.

    Note: uses the legacy torch.rfft/torch.irfft API, which has been removed
    from recent PyTorch releases.
    """
    x = torch.randn(*size, requires_grad=(not no_grad)).to(dev)
    # Make the rough assumption that the wavelets have a spatial support of 9*J
    sz = 9 * J
    xp = F.pad(x, (0, sz - 1, 0, sz - 1))
    FX = torch.rfft(xp, 2)
    FX = torch.unsqueeze(FX, dim=2)
    FW = torch.randn(1, 1, 12 * J + 1, FX.shape[-3], FX.shape[-2], 2, device=dev)
    # Complex multiplication done manually on the (real, imag) last dimension
    FYr = FX[..., 0] * FW[..., 0] - FX[..., 1] * FW[..., 1]
    FYi = FX[..., 0] * FW[..., 1] + FX[..., 1] * FW[..., 0]
    FY = torch.stack((FYr, FYi), dim=-1)
    FY = FY.view(FY.shape[0], -1, *FY.shape[-3:])
    Y = torch.irfft(FY, 2, signal_sizes=xp.shape[-2:])
    if not no_grad:
        Y.backward(torch.ones_like(Y))


def separable_dwt(size, J, no_grad=False, dev='cuda'):
    """Run 5 separable db5 DWTs for comparison with the DTCWT."""
    x = torch.randn(*size, requires_grad=(not no_grad)).to(dev)
    xfm = dwt.DWTForward(J, wave='db5', mode='zero').to(dev)
    for _ in range(5):
        yl, yh = xfm(x)
        if not no_grad:
            yh[0].backward(torch.ones_like(yh[0]))
    return yl.mean(), [y.mean() for y in yh]


def selesnick_dtcwt(size, J, no_grad=False, dev='cuda'):
    """DTCWT implemented as 4 parallel filter banks.

    Requires DTCWTForward2 from pytorch_wavelets.dtcwt.transform2d (see the
    commented-out import at the top of this file).
    """
    x = torch.randn(*size, requires_grad=(not no_grad)).to(dev)
    xfm = DTCWTForward2(J=J, mode='symmetric').to(dev)
    for _ in range(5):
        Yl, Yh = xfm(x)
        if not no_grad:
            Yl.backward(torch.ones_like(Yl))
    return Yl, Yh


def test_dtcwt(size, J, no_grad=False, dev='cuda'):
    """Quad filter-bank DTCWT built from the non-separable lowlevel routines."""
    x = torch.randn(*size, requires_grad=(not no_grad)).to(dev)
    h0a, h0b, _, _, h1a, h1b, _, _ = level1('farras')
    filts = lowlevel2.prep_filt_quad_afb2d_nonsep(
        h0a, h1a, h0a, h1a,
        h0a, h1a, h0b, h1b,
        h0b, h1b, h0a, h1a,
        h0b, h1b, h0b, h1b, device=dev)
    for j in range(J):
        yl, yh = lowlevel.afb2d_nonsep(x, filts, mode='zero')
        x = yl.reshape(yl.shape[0], -1, yl.shape[-2], yl.shape[-1])


def test_dtcwt2(size, J, no_grad=False, dev='cuda'):
    """Quad filter-bank DTCWT built from the row/column lowlevel routines."""
    x = torch.randn(*size, requires_grad=(not no_grad)).to(dev)
    h0a, h0b, _, _, h1a, h1b, _, _ = level1('farras')
    cols, rows = lowlevel2.prep_filt_quad_afb2d(h0a, h1a, h0b, h1b, device=dev)
    yh = []
    for j in range(J):
        x, y = lowlevel2.quad_afb2d(x, cols, rows, mode='zero')
        yh.append(y)
    return x, yh


if __name__ == "__main__":
    args = parser.parse_args()
    py3nvml.grab_gpus(1)
    if args.size > 0:
        size = (args.batch, 5, args.size, args.size)
    else:
        size = (args.batch, 5, 128, 128)

    if args.ref:
        print('Running dtcwt with FFTs')
        reference_fftconv(size, args.j, args.no_grad, args.device)
    elif args.convolution:
        print('Running 11x11 convolution')
        reference_conv(size, args.no_grad, args.device)
    elif args.dwt:
        print('Running separable dwt')
        separable_dwt(size, args.j, args.no_grad, args.device)
    elif args.fb:
        print('Running 4 dwts')
        yl, yh = selesnick_dtcwt(size, args.j, args.no_grad, args.device)
        # yl, yh = test_dtcwt2(size, args.j, no_grad=args.no_grad, dev=args.device)
    else:
        if args.forward:
            print('Running forward transform')
            yl, yh = forward(size, args.no_grad, args.j, args.no_hp, args.device)
        elif args.inverse:
            print('Running inverse transform')
            inverse(size, args.no_grad, args.j, args.no_hp, args.device)
        else:
            print('Running end to end')
            end_to_end(size, args.no_grad, args.j, args.no_hp, args.device)

    if ICIP:
        # Extra work for the ICIP experiment: complex magnitude of the
        # finest-scale coefficients, then a 3x3 convolution mixing the 6
        # orientations back down to c channels.
        n, _, c, h, w, _ = yh[0].shape
        mag = torch.sqrt(yh[0][..., 0]**2 + yh[0][..., 1]**2 + 0.01) - 0.1
        mag = mag.view(n, 6*c, h, w)
        gain1 = nn.Conv2d(6*c, c, 3, padding=1).cuda()
        y = gain1(mag)

    torch.cuda.synchronize()
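# Example invocations (the filename "profile_dtcwt.py" is only illustrative;
# use whatever name this script is saved under). The script prints only which
# path it ran, so it is presumably meant to be run under an external profiler
# such as nvprof:
#
#   python profile_dtcwt.py -f -j 3 -s 256    # forward DTCWT only, 3 scales, 256x256 inputs
#   python profile_dtcwt.py --dwt --no_grad   # separable db5 DWT, no backwards passes
#   python profile_dtcwt.py --ref             # FFT-based reference implementation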