# TUHs's picture
# Upload 207 files
# 29b9c56
import torch
import pytest
from torch.autograd import gradcheck
from pytorch_wavelets.dwt.lowlevel import AFB2D, SFB2D
from pytorch_wavelets import DWTForward, DWTInverse
import py3nvml
from contextlib import contextmanager
# Tolerances handed to torch.autograd.gradcheck by the tests below:
# ATOL is the absolute tolerance, EPS the finite-difference step size.
ATOL = 1e-4
EPS = 1e-4

# Run on the GPU when torch can see one; otherwise fall back to the CPU.
HAVE_GPU = torch.cuda.is_available()
dev = torch.device('cuda' if HAVE_GPU else 'cpu')


@contextmanager
def set_double_precision():
    """Temporarily make float64 torch's default floating-point dtype.

    Tensors created without an explicit ``dtype`` inside the ``with`` body
    come out as float64. On exit, normal or via an exception, the default
    dtype that was active on entry is put back.
    """
    previous_dtype = torch.get_default_dtype()
    try:
        torch.set_default_dtype(torch.float64)
        yield
    finally:
        # Restore unconditionally so a failing test cannot leak float64
        # as the default into later tests.
        torch.set_default_dtype(previous_dtype)


def setup():
    """Reserve one GPU for this test module through py3nvml.

    Calls ``py3nvml.grab_gpus`` requesting a single GPU with
    ``gpu_fraction=0.5`` and ``env_set_ok=True``.

    NOTE(review): ``HAVE_GPU``/``dev`` are computed at import time with
    ``torch.cuda.is_available()``, i.e. before this runs. If ``grab_gpus``
    works by setting ``CUDA_VISIBLE_DEVICES`` (presumably), torch may already
    have probed CUDA by then — confirm the restriction actually takes effect.
    """
    py3nvml.grab_gpus(1, gpu_fraction=0.5, env_set_ok=True)


def setup_module(module=None):
    """Native pytest module-level setup hook; delegates to :func:`setup`.

    pytest only ever ran a bare module-level ``setup()`` through its nose
    compatibility layer. That layer was deprecated in pytest 7.2 and removed
    in pytest 8, so without this hook ``setup`` is silently never called.
    pytest looks for ``setup_module`` before falling back to ``setup``, so
    the GPU reservation still happens only once on older versions.
    """
    setup()


@pytest.mark.skip("These tests take a very long time to compute")
@pytest.mark.parametrize("mode", [0, 1, 6])
def test_fwd(mode):
    """Gradcheck the 2D analysis filter bank ``AFB2D`` for each ``mode``.

    Everything runs under ``set_double_precision`` because gradcheck's
    numerical/analytical comparison is designed for float64 inputs and is
    likely to fail at lower precision. ``DWTForward`` is used only as a
    source of the four analysis filters passed to ``AFB2D.apply``.
    ``mode`` is forwarded unchanged as the last argument (presumably an
    integer padding-mode code understood by AFB2D — confirm).
    """
    with set_double_precision():
        x = torch.randn(1, 3, 16, 16, device=dev, requires_grad=True)
        # Built inside the context, presumably so its filter tensors are
        # float64 like ``x`` — confirm against DWTForward.
        xfm = DWTForward(J=2).to(dev)
        # Named ``inputs`` rather than ``input`` to avoid shadowing the builtin.
        inputs = (x, xfm.h0_row, xfm.h1_row, xfm.h0_col, xfm.h1_col, mode)
        # gradcheck raises on mismatch by default and returns True on success.
        assert gradcheck(AFB2D.apply, inputs, eps=EPS, atol=ATOL)


@pytest.mark.skip("These tests take a very long time to compute")
@pytest.mark.parametrize("mode", [0, 1, 6])
def test_inv_j2(mode):
    """Gradcheck the 2D synthesis filter bank ``SFB2D`` for each ``mode``.

    Runs under ``set_double_precision`` because gradcheck is designed for
    float64 inputs. ``DWTInverse`` is used only as a source of the four
    synthesis filters passed to ``SFB2D.apply``. ``mode`` is forwarded
    unchanged as the last argument (presumably an integer padding-mode code
    understood by SFB2D — confirm).
    """
    with set_double_precision():
        low = torch.randn(1, 3, 16, 16, device=dev, requires_grad=True)
        # Extra dim of size 3 presumably holds the three highpass subbands
        # per channel — confirm against SFB2D.
        high = torch.randn(1, 3, 3, 16, 16, device=dev, requires_grad=True)
        ifm = DWTInverse().to(dev)
        # Named ``inputs`` rather than ``input`` to avoid shadowing the builtin.
        inputs = (low, high, ifm.g0_row, ifm.g1_row, ifm.g0_col, ifm.g1_col,
                  mode)
        # gradcheck raises on mismatch by default and returns True on success.
        assert gradcheck(SFB2D.apply, inputs, eps=EPS, atol=ATOL)