|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
import unittest |
|
|
|
|
|
from MinkowskiEngine import MinkowskiDirectMaxPoolingFunction |
|
|
|
|
|
from utils.gradcheck import gradcheck |
|
|
|
|
|
|
|
|
class TestCase(unittest.TestCase):
    """Forward/backward and gradient checks for ``MinkowskiDirectMaxPoolingFunction``.

    Each test builds random index maps with a given integer dtype, runs the
    function on CPU and (when ``torch.cuda.is_available()``) on CUDA, and
    asserts that ``gradcheck`` accepts the analytical gradient.
    """

    def _check_direct_max_pool(self, index_dtype):
        """Run forward, backward and ``gradcheck`` for maps of ``index_dtype``.

        The same random maps and features are checked on every available
        device. CPU is always checked; CUDA is checked only when a device is
        present. Each device runs in its own ``subTest``, so a failure on one
        device does not hide the result on the other.

        Args:
            index_dtype: dtype of ``in_map``/``out_map`` (e.g. ``torch.int32``
                or ``torch.int64``).
        """
        in_nrows, out_nrows, nchannels, npairs = 5, 3, 16, 10
        pool = MinkowskiDirectMaxPoolingFunction()

        # Index ranges match the row counts: in_map values index rows of the
        # (in_nrows, nchannels) feature matrix, and out_map values are below
        # out_nrows. Presumably each pair (in_map[i], out_map[i]) routes an
        # input row into the max of an output row -- confirm against the
        # kernel docs.
        base_in_map = torch.randint(0, in_nrows, (npairs,), dtype=index_dtype)
        base_out_map = torch.randint(0, out_nrows, (npairs,), dtype=index_dtype)
        base_feat = torch.rand(in_nrows, nchannels, dtype=torch.float64)

        devices = ["cpu"]
        if torch.cuda.is_available():
            devices.append("cuda")

        for device in devices:
            with self.subTest(device=device):
                in_map = base_in_map.to(device)
                out_map = base_out_map.to(device)
                # Detach before moving so every device gets its own fresh
                # leaf. Otherwise the CUDA copy would be a non-leaf whose
                # gradients flow back to the CPU tensor instead of the tensor
                # gradcheck perturbs.
                in_feat = base_feat.detach().to(device).requires_grad_()

                out_feat = pool.apply(in_map, out_map, in_feat, out_nrows)
                print(out_feat)
                out_feat.sum().backward()

                self.assertTrue(
                    gradcheck(
                        pool,
                        (in_map, out_map, in_feat, out_nrows),
                    )
                )

    def test(self):
        """Check the pooling function with int32 index maps."""
        self._check_direct_max_pool(torch.int32)

    def test_long(self):
        """Check the pooling function with int64 index maps."""
        self._check_direct_max_pool(torch.int64)
|
|
|