| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """Multi-plane occupancy head.""" |
|
|
| import torch |
| from torch import nn |
| from torch.nn import functional as F |
|
|
|
|
| class _UpProjection(nn.Module): |
| """Up-projection block.""" |
|
|
| def __init__(self, num_input_features, num_output_features): |
| """Initialize the up-projection block.""" |
| super().__init__() |
| self.conv1 = nn.Conv2d( |
| num_input_features, num_output_features, kernel_size=5, stride=1, padding=2, bias=False |
| ) |
| self.bn1 = nn.BatchNorm2d(num_output_features) |
| self.relu = nn.ReLU(inplace=True) |
| self.conv1_2 = nn.Conv2d( |
| num_output_features, num_output_features, kernel_size=3, stride=1, padding=1, bias=False |
| ) |
| self.bn1_2 = nn.BatchNorm2d(num_output_features) |
| self.conv2 = nn.Conv2d( |
| num_input_features, num_output_features, kernel_size=5, stride=1, padding=2, bias=False |
| ) |
| self.bn2 = nn.BatchNorm2d(num_output_features) |
|
|
| def forward(self, x, size): |
| """Forward pass.""" |
| x = F.interpolate(x, size=size, mode="bilinear", align_corners=True) |
| x_conv1 = self.relu(self.bn1(self.conv1(x))) |
| bran1 = self.bn1_2(self.conv1_2(x_conv1)) |
| bran2 = self.bn2(self.conv2(x)) |
| out = self.relu(bran1 + bran2) |
| return out |
|
|
|
|
class D(nn.Module):
    """Top-down decoder over four backbone feature maps.

    The deepest map (``x_block4``) is first projected with a 1x1 conv, BatchNorm and ReLU
    from ``block_channel[0]`` to ``block_channel[1]`` channels. Four ``_UpProjection``
    stages then upsample it to the spatial sizes of ``x_block3``, ``x_block2`` and
    ``x_block1``, and finally to twice the size of ``x_block1``.

    The channel progression is ``block_channel[1] -> [2] -> [3] -> [3] // 2 -> [3] // 4``.
    The skip maps are used only for their spatial size; their values are never read.
    """

    def __init__(self, block_channel):
        """Build the decoder from the first four entries of ``block_channel``."""
        super().__init__()
        c_top, c_mid1, c_mid2, c_base = (block_channel[i] for i in range(4))
        c_half = c_base // 2

        self.conv = nn.Conv2d(c_top, c_mid1, kernel_size=1, stride=1, bias=False)
        self.bn = nn.BatchNorm2d(c_mid1)
        self.up1 = _UpProjection(num_input_features=c_mid1, num_output_features=c_mid2)
        self.up2 = _UpProjection(num_input_features=c_mid2, num_output_features=c_base)
        self.up3 = _UpProjection(num_input_features=c_base, num_output_features=c_half)
        self.up4 = _UpProjection(num_input_features=c_half, num_output_features=c_half // 2)

    def forward(self, x_block1, x_block2, x_block3, x_block4):
        """Decode ``x_block4`` up to twice the spatial size of ``x_block1``."""

        def _hw(t, scale=1):
            # Dims 2 and 3 are treated as (height, width).
            return [t.size(2) * scale, t.size(3) * scale]

        out = F.relu(self.bn(self.conv(x_block4)))
        out = self.up1(out, _hw(x_block3))
        out = self.up2(out, _hw(x_block2))
        out = self.up3(out, _hw(x_block1))
        return self.up4(out, _hw(x_block1, 2))
|
|
|
|
class MFF(nn.Module):
    """Multi-feature fusion module.

    Each of the four backbone feature maps is resized to a common ``size`` by its own
    ``_UpProjection``. The results are concatenated along the channel axis and fused by
    a 5x5 conv, BatchNorm and ReLU.

    Args:
        block_channel: Channel counts, deepest first. ``block_channel[3]`` must match
            ``x_block1``, ``[2]`` must match ``x_block2``, ``[1]`` must match ``x_block3``
            and ``[0]`` must match ``x_block4``.
        num_features: Channel count of the fused output. Each of the four branches
            contributes ``num_features // 4`` channels. The default of 64 reproduces the
            original 16 channels per branch.

    Raises:
        ValueError: If ``num_features`` is not a positive multiple of 4.
    """

    def __init__(self, block_channel, num_features=64):
        """Build the four projection branches and the fusion conv."""
        super().__init__()
        if num_features <= 0 or num_features % 4:
            raise ValueError(
                f"num_features must be a positive multiple of 4, got {num_features}"
            )
        # This was hard-coded to 16, which only matched the fusion conv's input
        # width when num_features == 64. Deriving it guarantees that the four
        # concatenated branches always have exactly num_features channels.
        branch_features = num_features // 4
        self.up1 = _UpProjection(
            num_input_features=block_channel[3], num_output_features=branch_features
        )
        self.up2 = _UpProjection(
            num_input_features=block_channel[2], num_output_features=branch_features
        )
        self.up3 = _UpProjection(
            num_input_features=block_channel[1], num_output_features=branch_features
        )
        self.up4 = _UpProjection(
            num_input_features=block_channel[0], num_output_features=branch_features
        )

        self.conv = nn.Conv2d(
            num_features, num_features, kernel_size=5, stride=1, padding=2, bias=False
        )
        self.bn = nn.BatchNorm2d(num_features)

    def forward(self, x_block1, x_block2, x_block3, x_block4, size):
        """Resize all four maps to ``size`` (height, width), concatenate and fuse them.

        Returns:
            A tensor with ``num_features`` channels at spatial size ``size``.
        """
        x_m1 = self.up1(x_block1, size)
        x_m2 = self.up2(x_block2, size)
        x_m3 = self.up3(x_block3, size)
        x_m4 = self.up4(x_block4, size)

        x = self.bn(self.conv(torch.cat((x_m1, x_m2, x_m3, x_m4), 1)))
        return F.relu(x)
|
|
|
|
class R(nn.Module):
    """Occupancy head module.

    The input is resized to ``target_size`` by an ``_UpProjection`` and refined by two
    5x5 conv, BatchNorm and ReLU stages. A final biased 5x5 conv then projects it to
    ``num_class`` channels. No activation is applied to that last conv, so the output
    is raw scores.

    Args:
        channel: Number of input channels; kept unchanged through the refinement stages.
        num_class: Number of output channels.
        target_size: ``(height, width)`` of the output map. Defaults to the previously
            hard-coded ``(120, 160)``.

    Raises:
        ValueError: If ``target_size`` does not have exactly two entries.
    """

    def __init__(self, channel, num_class=1, target_size=(120, 160)):
        """Build the resize block, the two refinement stages and the classifier conv."""
        super().__init__()
        target_size = tuple(target_size)
        if len(target_size) != 2:
            raise ValueError(f"target_size must be (height, width), got {target_size}")
        self.target_size = target_size
        self.resize = _UpProjection(num_input_features=channel, num_output_features=channel)

        self.conv0 = nn.Conv2d(channel, channel, kernel_size=5, stride=1, padding=2, bias=False)
        self.bn0 = nn.BatchNorm2d(channel)

        self.conv1 = nn.Conv2d(channel, channel, kernel_size=5, stride=1, padding=2, bias=False)
        self.bn1 = nn.BatchNorm2d(channel)

        # Final projection keeps its bias because no BatchNorm follows it.
        self.conv2 = nn.Conv2d(channel, num_class, kernel_size=5, stride=1, padding=2, bias=True)

    def forward(self, x):
        """Return ``num_class`` score maps at spatial size ``self.target_size``."""
        x0 = self.resize(x, self.target_size)
        x0 = F.relu(self.bn0(self.conv0(x0)))

        x1 = F.relu(self.bn1(self.conv1(x0)))

        return self.conv2(x1)
|
|
|
|
class MultiPlaneOccupancyHead(nn.Module):
    """Multi-plane occupancy head.

    Runs a top-down decoder (``D``) and a multi-feature fusion branch (``MFF``) over four
    backbone feature maps. Their outputs are concatenated along the channel axis and
    mapped to ``num_classes`` channels by ``R``.

    Args:
        num_classes: Number of output channels. Defaults to the previously hard-coded
            100. NOTE(review): presumably one channel per plane; confirm against the
            loss and target definition.
    """

    def __init__(self, num_classes=100):
        """Build the decoder, the fusion branch and the prediction head."""
        super().__init__()
        # Channel counts deepest first: block_channel[0] feeds res5 (x_block4) into
        # D and MFF, and block_channel[3] matches res2 (x_block1).
        block_channel = [2048, 1024, 512, 256]
        self.feature_key = ['res2', 'res3', 'res4', 'res5']
        feature_channels = 64

        self.D = D(block_channel)
        self.MFF = MFF(block_channel, feature_channels)
        # D ends with (block_channel[-1] // 2) // 2 == block_channel[-1] // 4 channels;
        # MFF contributes feature_channels more.
        head_channels = block_channel[-1] // 4 + feature_channels
        self.num_classes = num_classes
        self.prediction = R(head_channels, self.num_classes)

    def forward(self, x):
        """Predict occupancy scores from a dict of backbone features.

        Args:
            x: Mapping that must contain the keys ``'res2'``, ``'res3'``, ``'res4'`` and
                ``'res5'``, each a 4-D (N, C, H, W) feature map.

        Returns:
            The output of ``R``: ``num_classes`` score maps at R's target size.
        """
        blocks = [x[key] for key in self.feature_key[:4]]
        x_decoder = self.D(*blocks)
        # Fuse all scales at the decoder's resolution so the two outputs can be concatenated.
        x_mff = self.MFF(*blocks, [x_decoder.size(2), x_decoder.size(3)])

        x_feat = torch.cat((x_decoder, x_mff), 1)
        return self.prediction(x_feat)
|
|