| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| import torch |
|
|
| import MinkowskiEngine as ME |
| import MinkowskiEngine.MinkowskiFunctional as MF |
|
|
| from tests.python.common import data_loader |
|
|
|
|
class UNet(ME.MinkowskiNetwork):
    """Small three-level sparse U-Net built from MinkowskiEngine layers.

    Encoder: three convolution + batch-norm stages with strides 1, 2 and 2,
    producing 8, 16 and 32 channels respectively. Decoder: two transposed
    convolution + batch-norm stages with stride 2, each followed by
    concatenation with the matching encoder output. A final 1x1 convolution
    maps the concatenated features to ``out_nchannel`` channels.

    Args:
        in_nchannel: number of input feature channels.
        out_nchannel: number of output feature channels.
        D: spatial dimension passed to every Minkowski layer.
    """

    def __init__(self, in_nchannel, out_nchannel, D):
        super().__init__(D)

        def conv_bn(layer_cls, c_in, c_out, stride):
            # Sequential(conv, bn) keeps submodule indices 0 / 1, so the
            # state_dict keys remain "blockX.0.*" and "blockX.1.*".
            return torch.nn.Sequential(
                layer_cls(
                    in_channels=c_in,
                    out_channels=c_out,
                    kernel_size=3,
                    stride=stride,
                    dimension=D),
                ME.MinkowskiBatchNorm(c_out))

        # Encoder (registered in the same order as before).
        self.block1 = conv_bn(ME.MinkowskiConvolution, in_nchannel, 8, 1)
        self.block2 = conv_bn(ME.MinkowskiConvolution, 8, 16, 2)
        self.block3 = conv_bn(ME.MinkowskiConvolution, 16, 32, 2)

        # Decoder. block2_tr consumes cat(block3_tr output, block2 output),
        # i.e. 16 + 16 = 32 channels.
        self.block3_tr = conv_bn(ME.MinkowskiConvolutionTranspose, 32, 16, 2)
        self.block2_tr = conv_bn(ME.MinkowskiConvolutionTranspose, 32, 16, 2)

        # Head input is cat(block2_tr output, block1 output): 16 + 8 = 24.
        self.conv1_tr = ME.MinkowskiConvolution(
            in_channels=24,
            out_channels=out_nchannel,
            kernel_size=1,
            stride=1,
            dimension=D)

    def forward(self, x):
        """Run the U-Net on sparse tensor ``x`` and return the head output.

        The skip connections concatenate the batch-norm outputs of block1 and
        block2, taken before their ReLU, with the upsampled decoder features.
        """
        skip1 = self.block1(x)
        skip2 = self.block2(MF.relu(skip1))
        bottleneck = MF.relu(self.block3(MF.relu(skip2)))

        up2 = ME.cat(MF.relu(self.block3_tr(bottleneck)), skip2)
        up1 = ME.cat(MF.relu(self.block2_tr(up2)), skip1)
        return self.conv1_tr(up1)
|
|
|
|
if __name__ == '__main__':
    # Build a 2-D U-Net mapping 3 input channels to 5 output channels.
    net = UNet(3, 5, D=2)
    print(net)

    # data_loader is unpacked as (coordinates, features, labels); the labels
    # are not used by this forward-only demo.
    coords, feat, _label = data_loader()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    net = net.to(device)
    # Named `sinput` (not `input`) to avoid shadowing the builtin input().
    # Features are passed first, then coordinates.
    sinput = ME.SparseTensor(feat, coords, device=device)

    # Forward pass only; no loss or optimizer step is performed here.
    output = net(sinput)
|
|