|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Multi-plane occupancy head.""" |
|
|
|
|
|
import torch |
|
|
from torch import nn |
|
|
from torch.nn import functional as F |
|
|
|
|
|
|
|
|
class _UpProjection(nn.Module): |
|
|
"""Up-projection block.""" |
|
|
|
|
|
def __init__(self, num_input_features, num_output_features): |
|
|
"""Initialize the up-projection block.""" |
|
|
super().__init__() |
|
|
self.conv1 = nn.Conv2d( |
|
|
num_input_features, num_output_features, kernel_size=5, stride=1, padding=2, bias=False |
|
|
) |
|
|
self.bn1 = nn.BatchNorm2d(num_output_features) |
|
|
self.relu = nn.ReLU(inplace=True) |
|
|
self.conv1_2 = nn.Conv2d( |
|
|
num_output_features, num_output_features, kernel_size=3, stride=1, padding=1, bias=False |
|
|
) |
|
|
self.bn1_2 = nn.BatchNorm2d(num_output_features) |
|
|
self.conv2 = nn.Conv2d( |
|
|
num_input_features, num_output_features, kernel_size=5, stride=1, padding=2, bias=False |
|
|
) |
|
|
self.bn2 = nn.BatchNorm2d(num_output_features) |
|
|
|
|
|
def forward(self, x, size): |
|
|
"""Forward pass.""" |
|
|
x = F.interpolate(x, size=size, mode="bilinear", align_corners=True) |
|
|
x_conv1 = self.relu(self.bn1(self.conv1(x))) |
|
|
bran1 = self.bn1_2(self.conv1_2(x_conv1)) |
|
|
bran2 = self.bn2(self.conv2(x)) |
|
|
out = self.relu(bran1 + bran2) |
|
|
return out |
|
|
|
|
|
|
|
|
class D(nn.Module):
    """Decoder module.

    Starts from the deepest feature map (``x_block4``), reduces it with a
    1x1 conv + BN + ReLU, then applies four up-projections. The up-projections
    target, in turn, the spatial sizes of ``x_block3``, ``x_block2``,
    ``x_block1`` and finally twice the size of ``x_block1``. The output has
    ``(block_channel[3] // 2) // 2`` channels.
    """

    def __init__(self, block_channel):
        """Initialize the decoder module.

        Args:
            block_channel: Sequence of at least four channel widths. Index 0 is
                the channel count of ``x_block4`` (the input of ``self.conv``);
                indices 1-3 are the widths produced by the successive stages.
        """
        super().__init__()
        c0, c1, c2, c3 = block_channel[0], block_channel[1], block_channel[2], block_channel[3]
        half = c3 // 2

        # Channel reduction of the deepest feature map.
        self.conv = nn.Conv2d(c0, c1, kernel_size=1, stride=1, bias=False)
        self.bn = nn.BatchNorm2d(c1)

        # Progressive up-projections; the last two halve the width each time.
        self.up1 = _UpProjection(num_input_features=c1, num_output_features=c2)
        self.up2 = _UpProjection(num_input_features=c2, num_output_features=c3)
        self.up3 = _UpProjection(num_input_features=c3, num_output_features=half)
        self.up4 = _UpProjection(num_input_features=half, num_output_features=half // 2)

    def forward(self, x_block1, x_block2, x_block3, x_block4):
        """Decode ``x_block4`` up to twice the spatial size of ``x_block1``.

        Args:
            x_block1, x_block2, x_block3: Feature maps whose spatial sizes
                (dims 2 and 3) set the up-projection targets.
            x_block4: Feature map fed into the decoder; must have
                ``block_channel[0]`` channels.

        Returns:
            Decoded feature map of spatial size
            ``(2 * x_block1.size(2), 2 * x_block1.size(3))``.
        """
        h1, w1 = x_block1.size(2), x_block1.size(3)

        out = F.relu(self.bn(self.conv(x_block4)))
        out = self.up1(out, [x_block3.size(2), x_block3.size(3)])
        out = self.up2(out, [x_block2.size(2), x_block2.size(3)])
        out = self.up3(out, [h1, w1])
        return self.up4(out, [h1 * 2, w1 * 2])
|
|
|
|
|
|
|
|
class MFF(nn.Module):
    """Multi-feature fusion module.

    Up-projects each of the four input feature maps to a common spatial size
    and ``num_features // 4`` channels. The four results are concatenated
    along the channel axis (``num_features`` channels in total) and fused
    with a 5x5 conv + BN + ReLU.
    """

    def __init__(self, block_channel, num_features=64):
        """Initialize the multi-feature fusion module.

        Args:
            block_channel: Channel widths of the input features, deepest first.
                ``block_channel[3]`` is the width of ``x_block1`` and
                ``block_channel[0]`` is the width of ``x_block4``.
            num_features: Number of fused output channels. Must be a positive
                multiple of 4, because each of the four branches contributes a
                quarter of them. The default (64) gives 16 channels per branch.

        Raises:
            ValueError: If ``num_features`` is not a positive multiple of 4.
        """
        super().__init__()
        # Each branch emits a quarter of the fused width, so the concatenation
        # always matches the input width expected by ``self.conv``. A fixed
        # per-branch width would silently break for any other ``num_features``.
        if num_features <= 0 or num_features % 4 != 0:
            raise ValueError(
                f"num_features must be a positive multiple of 4, got {num_features}"
            )
        branch_features = num_features // 4

        self.up1 = _UpProjection(
            num_input_features=block_channel[3], num_output_features=branch_features
        )
        self.up2 = _UpProjection(
            num_input_features=block_channel[2], num_output_features=branch_features
        )
        self.up3 = _UpProjection(
            num_input_features=block_channel[1], num_output_features=branch_features
        )
        self.up4 = _UpProjection(
            num_input_features=block_channel[0], num_output_features=branch_features
        )

        # Fusion of the concatenated branches (padding keeps spatial size).
        self.conv = nn.Conv2d(
            num_features, num_features, kernel_size=5, stride=1, padding=2, bias=False
        )
        self.bn = nn.BatchNorm2d(num_features)

    def forward(self, x_block1, x_block2, x_block3, x_block4, size):
        """Fuse the four feature maps at spatial size ``size``.

        Args:
            x_block1, x_block2, x_block3, x_block4: Input feature maps with
                ``block_channel[3]``, ``[2]``, ``[1]`` and ``[0]`` channels
                respectively.
            size: Common target spatial size ``(H, W)`` for all branches.

        Returns:
            Tensor with ``num_features`` channels at spatial size ``size``.
        """
        x_m1 = self.up1(x_block1, size)
        x_m2 = self.up2(x_block2, size)
        x_m3 = self.up3(x_block3, size)
        x_m4 = self.up4(x_block4, size)

        x = self.bn(self.conv(torch.cat((x_m1, x_m2, x_m3, x_m4), 1)))
        return F.relu(x)
|
|
|
|
|
|
|
|
class R(nn.Module):
    """Occupancy prediction head.

    Resamples the input feature map to a fixed spatial size with an
    up-projection block, then refines it with two 5x5 conv + BN + ReLU layers.
    A final 5x5 conv (with bias) maps it to ``num_class`` output channels.
    """

    def __init__(self, channel, num_class=1, target_size=(120, 160)):
        """Initialize the occupancy head module.

        Args:
            channel: Number of input channels; kept constant through the
                intermediate layers.
            num_class: Number of output channels of the final convolution.
            target_size: Output spatial size ``(height, width)`` that the input
                is resampled to. Defaults to ``(120, 160)``.

        Raises:
            ValueError: If ``target_size`` does not have exactly two entries.
        """
        super().__init__()

        # Previously hard-coded; exposed so callers can pick the output grid.
        target_size = tuple(target_size)
        if len(target_size) != 2:
            raise ValueError(f"target_size must be (height, width), got {target_size!r}")
        self.target_size = target_size
        self.resize = _UpProjection(num_input_features=channel, num_output_features=channel)

        self.conv0 = nn.Conv2d(channel, channel, kernel_size=5, stride=1, padding=2, bias=False)
        self.bn0 = nn.BatchNorm2d(channel)

        self.conv1 = nn.Conv2d(channel, channel, kernel_size=5, stride=1, padding=2, bias=False)
        self.bn1 = nn.BatchNorm2d(channel)

        # Output projection; bias enabled since no BatchNorm follows it.
        self.conv2 = nn.Conv2d(channel, num_class, kernel_size=5, stride=1, padding=2, bias=True)

    def forward(self, x):
        """Predict ``num_class`` maps at ``self.target_size``.

        Args:
            x: Input feature map with ``channel`` channels (4-D, as required by
                the bilinear resampling in ``_UpProjection``).

        Returns:
            Raw (pre-activation) output of ``self.conv2`` with ``num_class``
            channels at spatial size ``self.target_size``.
        """
        x0 = self.resize(x, self.target_size)
        x0 = F.relu(self.bn0(self.conv0(x0)))
        x1 = F.relu(self.bn1(self.conv1(x0)))
        return self.conv2(x1)
|
|
|
|
|
|
|
|
class MultiPlaneOccupancyHead(nn.Module):
    """Multi-plane occupancy head.

    Processes four backbone feature maps with two parallel branches:

    * a top-down decoder (``D``);
    * a multi-feature fusion branch (``MFF``).

    The two outputs are concatenated along the channel axis and mapped to
    ``num_classes`` output channels by the prediction head ``R``.
    """

    def __init__(self, num_classes=100):
        """Initialize the multi-plane occupancy head.

        Args:
            num_classes: Number of output channels of the prediction head.
                Defaults to 100.

        Raises:
            ValueError: If ``num_classes`` is smaller than 1.
        """
        super().__init__()
        if num_classes < 1:
            raise ValueError(f"num_classes must be >= 1, got {num_classes}")

        # Channel widths of the inputs, deepest first: the feature under
        # 'res5' must have 2048 channels and the one under 'res2' 256.
        # NOTE(review): presumably a ResNet-style backbone; confirm.
        block_channel = [2048, 1024, 512, 256]
        # Keys read from the input mapping, shallowest first.
        self.feature_key = ['res2', 'res3', 'res4', 'res5']
        feature_channels = 64

        self.D = D(block_channel)
        self.MFF = MFF(block_channel, feature_channels)
        # D emits block_channel[-1] // 4 channels and MFF emits feature_channels;
        # both are concatenated before the prediction head.
        head_channels = block_channel[-1] // 4 + feature_channels
        self.num_classes = num_classes
        self.prediction = R(head_channels, self.num_classes)

    def forward(self, x):
        """Predict occupancy from a mapping of backbone features.

        Args:
            x: Mapping indexed by the names in ``self.feature_key``
                (``'res2'`` .. ``'res5'``). Each value is assumed to be a 4-D
                (N, C, H, W) tensor; TODO confirm against the backbone.

        Returns:
            Output of ``self.prediction`` with ``self.num_classes`` channels.
        """
        keys = self.feature_key
        x_block1, x_block2, x_block3, x_block4 = x[keys[0]], x[keys[1]], x[keys[2]], x[keys[3]]

        x_decoder = self.D(x_block1, x_block2, x_block3, x_block4)
        # Resample every input to the decoder's spatial size so the two
        # branches can be concatenated channel-wise.
        x_mff = self.MFF(
            x_block1, x_block2, x_block3, x_block4, [x_decoder.size(2), x_decoder.size(3)]
        )

        x_feat = torch.cat((x_decoder, x_mff), 1)
        return self.prediction(x_feat)
|
|
|