Spaces:
Sleeping
Sleeping
| import torch | |
| import pytest | |
| from torch.autograd import gradcheck | |
| from pytorch_wavelets.dwt.lowlevel import AFB2D, SFB2D | |
| from pytorch_wavelets import DWTForward, DWTInverse | |
| import py3nvml | |
| from contextlib import contextmanager | |
# Numerical settings handed to torch.autograd.gradcheck in the tests.
ATOL = 1e-4  # absolute tolerance when comparing analytical vs numerical grads
EPS = 1e-4   # finite-difference step size

# Pick the CUDA device when torch reports one is usable, else the CPU.
# NOTE(review): this is evaluated at import time, before any module-level
# setup hook that restricts CUDA_VISIBLE_DEVICES has run; whether such a
# restriction still takes effect depends on when CUDA initialises - verify.
HAVE_GPU = torch.cuda.is_available()
dev = torch.device('cuda' if HAVE_GPU else 'cpu')
@contextmanager
def set_double_precision():
    """Temporarily make ``torch.float64`` the default floating-point dtype.

    On entry the current default dtype is recorded and replaced with
    ``torch.float64``. On exit, whether normal or via an exception, the
    recorded dtype is restored. Floating-point tensors created without an
    explicit ``dtype`` inside the ``with`` block are therefore double
    precision, which is what ``torch.autograd.gradcheck`` expects.

    Without the ``@contextmanager`` decorator this would be a plain generator
    function, and ``with set_double_precision():`` would fail because a
    generator object is not a context manager.

    Yields:
        None.
    """
    old_prec = torch.get_default_dtype()
    try:
        torch.set_default_dtype(torch.float64)
        yield
    finally:
        # Restore even if the body raised, so one failing test cannot leak
        # float64 into every test that runs after it.
        torch.set_default_dtype(old_prec)
| def setup(): | |
| py3nvml.grab_gpus(1, gpu_fraction=0.5, env_set_ok=True) | |
@pytest.mark.parametrize("mode", [0, 1, 2])
def test_fwd(mode):
    """Gradient-check one level of the 2-D analysis filter bank (``AFB2D``).

    Args:
        mode: integer padding-mode code passed straight to ``AFB2D.apply``.
            It is presumably a code produced by
            ``pytorch_wavelets.dwt.lowlevel.mode_to_int``; 0/1/2 are believed
            to be zero / symmetric / periodization (confirm). Without the
            parametrize decorator pytest would look for a fixture named
            ``mode``, and none is defined in this module.
    """
    with set_double_precision():
        # The input and the DWTForward module are created while float64 is
        # the default dtype. The module's filter tensors presumably pick up
        # float64 too (confirm), so gradcheck runs in double precision.
        x = torch.randn(1, 3, 16, 16, device=dev, requires_grad=True)
        xfm = DWTForward(J=2).to(dev)  # only its filter tensors are used below
        inputs = (x, xfm.h0_row, xfm.h1_row, xfm.h0_col, xfm.h1_col, mode)
        # gradcheck raises on a mismatch by default, so no assert is needed.
        gradcheck(AFB2D.apply, inputs, eps=EPS, atol=ATOL)
@pytest.mark.parametrize("mode", [0, 1, 2])
def test_inv_j2(mode):
    """Gradient-check one level of the 2-D synthesis filter bank (``SFB2D``).

    Despite the ``j2`` suffix, this checks a single ``SFB2D.apply`` call. The
    default-constructed ``DWTInverse()`` is only used as a source of filter
    tensors.

    Args:
        mode: integer padding-mode code passed straight to ``SFB2D.apply``.
            It is presumably a code produced by
            ``pytorch_wavelets.dwt.lowlevel.mode_to_int`` (confirm). It is
            parametrized because no ``mode`` fixture is defined in this
            module.
    """
    with set_double_precision():
        low = torch.randn(1, 3, 16, 16, device=dev, requires_grad=True)
        # The extra axis of size 3 presumably holds the three detail subbands
        # for each channel (batch, channel, band, H, W) - confirm.
        high = torch.randn(1, 3, 3, 16, 16, device=dev, requires_grad=True)
        ifm = DWTInverse().to(dev)
        inputs = (low, high, ifm.g0_row, ifm.g1_row, ifm.g0_col, ifm.g1_col,
                  mode)
        # gradcheck raises on a mismatch by default, so no assert is needed.
        gradcheck(SFB2D.apply, inputs, eps=EPS, atol=ATOL)