|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
import unittest |
|
|
|
|
|
from MinkowskiEngine import ( |
|
|
SparseTensor, |
|
|
MinkowskiGlobalSumPooling, |
|
|
MinkowskiBroadcastFunction, |
|
|
MinkowskiBroadcastAddition, |
|
|
MinkowskiBroadcastMultiplication, |
|
|
MinkowskiBroadcast, |
|
|
MinkowskiBroadcastConcatenation, |
|
|
BroadcastMode, |
|
|
) |
|
|
|
|
|
from utils.gradcheck import gradcheck |
|
|
from tests.python.common import data_loader |
|
|
|
|
|
|
|
|
class TestBroadcast(unittest.TestCase):
    """Tests for the MinkowskiEngine broadcast layers and ``MinkowskiBroadcastFunction``."""

    # Absolute tolerance for comparing features computed on different devices.
    TOL = 1e-5

    @staticmethod
    def _load_features(in_channels):
        """Return ``(coords, feats)`` from ``data_loader`` with double, grad-enabled feats.

        Double precision is used so the finite-difference gradcheck is accurate.
        """
        coords, feats, _ = data_loader(in_channels)
        feats = feats.double()
        feats.requires_grad_()
        return coords, feats

    @staticmethod
    def _global_pool(sinput):
        """Global-sum-pool ``sinput`` and return it detached, with features requiring grad.

        The pooled tensor is detached so gradcheck treats its features as an
        independent input rather than differentiating through the pooling op.
        """
        sinput_glob = MinkowskiGlobalSumPooling()(sinput).detach()
        sinput_glob.F.requires_grad_()
        return sinput_glob

    @staticmethod
    def _gradcheck_broadcast(fn, sinput, sinput_glob, operation_type):
        """Run ``gradcheck`` on ``fn`` for a single broadcast ``operation_type``.

        Returns whatever ``utils.gradcheck.gradcheck`` returns (asserted truthy
        by the callers).
        """
        return gradcheck(
            fn,
            (
                sinput.F,
                sinput_glob.F,
                operation_type,
                sinput.coordinate_map_key,
                sinput_glob.coordinate_map_key,
                sinput.coordinate_manager,
            ),
        )

    def _assert_features_close(self, actual, expected):
        """Assert two SparseTensors have equal-shaped features within ``TOL`` (absolute)."""
        actual_f = actual.F.detach().cpu()
        expected_f = expected.F.detach().cpu()
        self.assertEqual(actual_f.shape, expected_f.shape)
        # Compare the absolute error: a signed `a - b < tol` check would accept
        # arbitrarily large negative differences.
        self.assertTrue(
            torch.allclose(actual_f, expected_f, rtol=0, atol=self.TOL)
        )

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is not available")
    def test_broadcast_gpu(self):
        in_channels = 2
        coords, feats = self._load_features(in_channels)

        broadcast_add = MinkowskiBroadcastAddition()
        broadcast_mul = MinkowskiBroadcastMultiplication()
        broadcast_cat = MinkowskiBroadcastConcatenation()

        # Reference results computed on the CPU.
        cpu_input = SparseTensor(feats, coords)
        cpu_glob = self._global_pool(cpu_input)
        cpu_add = broadcast_add(cpu_input, cpu_glob)
        cpu_mul = broadcast_mul(cpu_input, cpu_glob)
        cpu_cat = broadcast_cat(cpu_input, cpu_glob)

        # Same data on the GPU; results must match the CPU reference.
        device = torch.device("cuda")
        gpu_input = SparseTensor(feats, coords, device=device)
        # Global features must require grad so gradcheck also covers their gradient.
        gpu_glob = self._global_pool(gpu_input)
        gpu_add = broadcast_add(gpu_input, gpu_glob)
        gpu_mul = broadcast_mul(gpu_input, gpu_glob)
        gpu_cat = broadcast_cat(gpu_input, gpu_glob)

        self._assert_features_close(gpu_add, cpu_add)
        self._assert_features_close(gpu_mul, cpu_mul)
        self._assert_features_close(gpu_cat, cpu_cat)

        fn = MinkowskiBroadcastFunction()
        for module in (broadcast_add, broadcast_mul):
            self.assertTrue(
                self._gradcheck_broadcast(
                    fn, gpu_input, gpu_glob, module.operation_type
                )
            )

    def test_broadcast(self):
        in_channels = 2
        coords, feats = self._load_features(in_channels)

        sinput = SparseTensor(feats, coords)
        sinput_glob = self._global_pool(sinput)

        broadcast = MinkowskiBroadcast()
        broadcast_cat = MinkowskiBroadcastConcatenation()
        broadcast_add = MinkowskiBroadcastAddition()
        broadcast_mul = MinkowskiBroadcastMultiplication()

        bcast_out = broadcast(sinput, sinput_glob)
        cat_out = broadcast_cat(sinput, sinput_glob)
        add_out = broadcast_add(sinput, sinput_glob)
        mul_out = broadcast_mul(sinput, sinput_glob)

        # Every output lives on the input's points; concatenation doubles the
        # channel count because the global features have the same width here.
        n_points, n_channels = sinput.F.shape
        self.assertEqual(tuple(bcast_out.F.shape), (n_points, n_channels))
        self.assertEqual(tuple(cat_out.F.shape), (n_points, 2 * n_channels))
        self.assertEqual(tuple(add_out.F.shape), (n_points, n_channels))
        self.assertEqual(tuple(mul_out.F.shape), (n_points, n_channels))

        # NOTE(review): assumes broadcast add/mul equal the input combined
        # row-wise with the plain broadcast output — confirm if this fails.
        in_f = sinput.F.detach()
        glob_f = bcast_out.F.detach()
        self.assertTrue(torch.allclose(add_out.F.detach(), in_f + glob_f))
        self.assertTrue(torch.allclose(mul_out.F.detach(), in_f * glob_f))

        fn = MinkowskiBroadcastFunction()
        for module in (broadcast_add, broadcast_mul):
            self.assertTrue(
                self._gradcheck_broadcast(
                    fn, sinput, sinput_glob, module.operation_type
                )
            )
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Entry point when this file is executed directly: discover and run every
    # TestCase defined in this module.
    unittest.main()
|
|
|