Spaces:
Sleeping
Sleeping
File size: 3,061 Bytes
29b9c56 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 | import torch
import argparse
import py3nvml
import timeit
# Command-line interface for the profiling script.
# The summary text is passed as `description`: ArgumentParser's first
# positional parameter is `prog`, so passing it positionally would make the
# text appear as the program name in the usage line instead.
parser = argparse.ArgumentParser(description='Profile the dwt')
# Positional: which backend and which transform to benchmark.
parser.add_argument('method', choices=['torch', 'numpy'],
                    help='Method to use to calculate dwt')
parser.add_argument('xfm', choices=['dwt', 'dtcwt'],
                    help='which transform to use')
# Optional switches controlling what is timed and on what input.
parser.add_argument('-f', '--forward', action='store_true',
                    help='Only do forward transform (default is fwd and inv)')
parser.add_argument('-j', type=int, default=2,
                    help='number of scales of transform to do')
parser.add_argument('-s', '--size', default=0, type=int,
                    help='spatial size of input')
parser.add_argument('--device', default='cuda', choices=['cuda', 'cpu'],
                    help='which device to test')
parser.add_argument('--wave', default='db4',
                    help='which wavelet to use')
parser.add_argument('--batch', default=16, type=int,
                    help='Number of images in parallel')
if __name__ == "__main__":
    # Benchmark driver: build a timeit statement/setup pair for the selected
    # backend (torch or numpy) and transform (dwt or dtcwt), run it five
    # times and print the average wall-clock time per run.
    args = parser.parse_args()
    py3nvml.grab_gpus(1)

    # A non-positive --size falls back to 128x128 single-channel images.
    side = args.size if args.size > 0 else 128
    size = (args.batch, 1, side, side)

    # -f/--forward restricts the timed statement to the forward transform;
    # without it the forward + inverse round trip is timed.
    roundtrip = not args.forward

    if args.method == 'torch':
        if args.xfm == 'dwt':
            setup_lines = [
                "import torch",
                "from pytorch_wavelets import DWT, IDWT",
                "x = torch.randn(*{sz}).to('{dev}')",
                "xfm = DWT(J={J}, wave='{wave}').to('{dev}')",
                "ifm = IDWT(wave='{wave}').to('{dev}')",
            ]
        else:
            setup_lines = [
                "import torch",
                "from pytorch_wavelets import DTCWTForward, DTCWTInverse",
                "x = torch.randn(*{sz}).to('{dev}')",
                "xfm = DTCWTForward(J={J}).to('{dev}')",
                "ifm = DTCWTInverse(J={J}).to('{dev}')",
            ]
        stmt_lines = ['ifm(xfm(x))' if roundtrip else 'xfm(x)']
        if args.device == 'cuda':
            # CUDA work is queued asynchronously; wait for it inside the
            # timed statement so the measurement covers the actual compute
            # rather than just the kernel launches.
            stmt_lines.append('torch.cuda.synchronize()')
    elif args.xfm == 'dwt':
        setup_lines = [
            "import numpy as np",
            "import pywt",
            "x = np.random.randn(*{sz})",
            "xfm = lambda a: pywt.wavedec2(a, '{wave}', level={J}, mode='reflect')",
            "ifm = lambda a: pywt.waverec2(a, '{wave}', mode='reflect')",
        ]
        stmt_lines = ['ifm(xfm(x))' if roundtrip else 'xfm(x)']
    else:
        setup_lines = [
            "import numpy as np",
            "import dtcwt",
            "x = np.random.randn(*{sz})",
            "xfm = dtcwt.Transform2d(biort='near_sym_a', qshift='qshift_a')",
        ]
        # This backend is driven one 2-D image at a time, so the timed
        # statement loops over the batch and channel axes of x.
        per_image = 'xfm.forward(c, nlevels={J})'
        if roundtrip:
            per_image = 'xfm.inverse(' + per_image + ')'
        stmt_lines = [
            'for b in x:',
            '    for c in b:',
            '        ' + per_image,
        ]

    # Fill the placeholders once; str.format ignores unused keyword fields.
    fields = dict(sz=size, dev=args.device, J=args.j, wave=args.wave)
    stmt = '\n'.join(stmt_lines).format(**fields)
    setup = '\n'.join(setup_lines).format(**fields)

    t = timeit.Timer(stmt, setup=setup)
    print('5 run average is {:.3f}s'.format(t.timeit(number=5)/5))
# end_to_end(args.method, args.wave, size, args.j, args.device)
|