| from .. import config |
| import importlib |
| import torch |
| import torch.nn as nn |
| from .. import SparseTensor |
|
|
|
|
| _backends = {} |
|
|
|
|
class SparseConv3d(nn.Module):
    """Sparse 3D convolution layer that delegates to a pluggable backend.

    The implementation lives in the sibling module ``conv_<config.CONV>``,
    which must expose ``sparse_conv3d_init`` and ``sparse_conv3d_forward``.
    All constructor arguments are forwarded to the backend unchanged; their
    exact meaning is defined by the selected backend.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
        super().__init__()
        backend = self._resolve_backend()
        backend.sparse_conv3d_init(
            self, in_channels, out_channels, kernel_size,
            stride, dilation, padding, bias, indice_key,
        )

    @staticmethod
    def _resolve_backend():
        """Return the backend module for ``config.CONV``, importing it on first use.

        Both ``__init__`` and ``forward`` go through this helper so that
        ``forward`` no longer depends on a constructor having populated the
        cache earlier in the same process (e.g. when an instance is restored
        without ``__init__`` being run).
        """
        name = config.CONV
        backend = _backends.get(name)
        if backend is None:
            backend = importlib.import_module(f'..conv_{name}', __name__)
            _backends[name] = backend
        return backend

    def forward(self, x: SparseTensor) -> SparseTensor:
        """Apply the backend's sparse convolution to ``x`` and return its result."""
        return self._resolve_backend().sparse_conv3d_forward(self, x)
|
|
|
|
class SparseInverseConv3d(nn.Module):
    """Sparse inverse 3D convolution layer that delegates to a pluggable backend.

    The implementation lives in the sibling module ``conv_<config.CONV>``,
    which must expose ``sparse_inverse_conv3d_init`` and
    ``sparse_inverse_conv3d_forward``. All constructor arguments are forwarded
    to the backend unchanged; their exact meaning is defined by the selected
    backend.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
        super().__init__()
        backend = self._resolve_backend()
        backend.sparse_inverse_conv3d_init(
            self, in_channels, out_channels, kernel_size,
            stride, dilation, bias, indice_key,
        )

    @staticmethod
    def _resolve_backend():
        """Return the backend module for ``config.CONV``, importing it on first use.

        Both ``__init__`` and ``forward`` go through this helper so that
        ``forward`` no longer depends on a constructor having populated the
        cache earlier in the same process (e.g. when an instance is restored
        without ``__init__`` being run).
        """
        name = config.CONV
        backend = _backends.get(name)
        if backend is None:
            backend = importlib.import_module(f'..conv_{name}', __name__)
            _backends[name] = backend
        return backend

    def forward(self, x: SparseTensor) -> SparseTensor:
        """Apply the backend's sparse inverse convolution to ``x`` and return its result."""
        return self._resolve_backend().sparse_inverse_conv3d_forward(self, x)
|
|