Spaces:
Running
Running
| import torch | |
| from rstor.architecture.nafnet import NAFNet | |
def test_nafnet():
    """Smoke-test NAFNet: a forward pass must return a tensor of the input's shape.

    Builds a deliberately tiny model (width=2, enc_blk_nums=[1, 1],
    middle_blk_num=1, dec_blk_nums=[1, 2]) so the test runs quickly, then
    feeds it a random (2, 3, 128, 128) tensor.
    """
    # Shape shared by the input and the expected output. Dim 1 matches
    # img_channel=3 (presumably NCHW layout).
    expected_shape = (2, 3, 128, 128)

    net = NAFNet(
        img_channel=3,
        width=2,
        middle_blk_num=1,
        enc_blk_nums=[1, 1],
        dec_blk_nums=[1, 2],
    )

    inputs = torch.rand(*expected_shape)
    outputs = net(inputs)

    # torch.Size compares equal to a plain tuple of the same dimensions.
    assert outputs.shape == expected_shape