Spaces:
Sleeping
Sleeping
| import torch | |
| import argparse | |
| import py3nvml | |
| import timeit | |
# Command-line interface for the wavelet-transform benchmark.
# NOTE: the text must be passed as ``description=``; the first positional
# argument of ArgumentParser is ``prog``, which would replace the program
# name in the usage line instead of describing the tool.
parser = argparse.ArgumentParser(description='Profile the dwt')
parser.add_argument('method', choices=['torch', 'numpy'],
                    help='Method to use to calculate dwt')
parser.add_argument('xfm', choices=['dwt', 'dtcwt'],
                    help='which transform to use')
parser.add_argument('-f', '--forward', action='store_true',
                    help='Only do forward transform (default is fwd and inv)')
parser.add_argument('-j', type=int, default=2,
                    help='number of scales of transform to do')
# A size of 0 (or less) means "use the script's built-in default size".
parser.add_argument('-s', '--size', default=0, type=int,
                    help='spatial size of input')
# Only meaningful for the 'torch' method; the numpy paths always run on CPU.
parser.add_argument('--device', default='cuda', choices=['cuda', 'cpu'],
                    help='which device to test')
# Only meaningful for the 'dwt' transform; the DTCWT uses fixed filters.
parser.add_argument('--wave', default='db4',
                    help='which wavelet to use')
parser.add_argument('--batch', default=16, type=int,
                    help='Number of images in parallel')
| if __name__ == "__main__": | |
| args = parser.parse_args() | |
| py3nvml.grab_gpus(1) | |
| if args.size > 0: | |
| size = (args.batch, 1, args.size, args.size) | |
| else: | |
| size = (args.batch, 1, 128, 128) | |
| if args.method == 'torch': | |
| if args.xfm == 'dwt': | |
| t = timeit.Timer('ifm(xfm(x))', | |
| setup=""" | |
| import torch | |
| from pytorch_wavelets import DWT, IDWT | |
| x = torch.randn(*{sz}).to('{dev}') | |
| xfm = DWT(J={J}, wave='{wave}').to('{dev}') | |
| ifm = IDWT(wave='{wave}').to('{dev}')""".format(sz=size, dev=args.device, J=args.j, | |
| wave=args.wave)) | |
| print('5 run average is {:.3f}s'.format(t.timeit(number=5)/5)) | |
| else: | |
| t = timeit.Timer('ifm(xfm(x))', | |
| setup=""" | |
| import torch | |
| from pytorch_wavelets import DTCWTForward, DTCWTInverse | |
| x = torch.randn(*{sz}).to('{dev}') | |
| xfm = DTCWTForward(J={J}).to('{dev}') | |
| ifm = DTCWTInverse(J={J}).to('{dev}')""".format(sz=size, dev=args.device, J=args.j)) | |
| print('5 run average is {:.3f}s'.format(t.timeit(number=5)/5)) | |
| else: | |
| if args.xfm == 'dwt': | |
| t = timeit.Timer('ifm(xfm(x))', | |
| setup=""" | |
| import numpy as np | |
| import pywt | |
| x = np.random.randn(*{sz}) | |
| xfm = lambda a: pywt.wavedec2(a, '{wave}', level={J}, mode='reflect') | |
| ifm = lambda a: pywt.waverec2(a, '{wave}', mode='reflect') | |
| """.format(sz=size, wave=args.wave, J=args.j)) | |
| print('5 run average is {:.3f}s'.format(t.timeit(number=5)/5)) | |
| else: | |
| t = timeit.Timer(""" | |
| for b in x: | |
| for c in b: | |
| xfm.inverse(xfm.forward(c, nlevels={J})) | |
| """.format(J=args.j), setup=""" | |
| import numpy as np | |
| import dtcwt | |
| x = np.random.randn(*{sz}) | |
| xfm = dtcwt.Transform2d(biort='near_sym_a', qshift='qshift_a') | |
| """.format(sz=size)) | |
| print('5 run average is {:.3f}s'.format(t.timeit(number=5)/5)) | |
| # end_to_end(args.method, args.wave, size, args.j, args.device) | |