File size: 6,601 Bytes
f4a0919
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Multi-plane occupancy head."""

import torch
from torch import nn
from torch.nn import functional as F


class _UpProjection(nn.Module):
    """Resize a feature map and refine it with a two-branch residual block.

    The input is bilinearly resized (``align_corners=True``) to the requested
    spatial size and then fed to two parallel branches whose outputs are
    summed and rectified:

    * main branch:     5x5 conv -> BN -> ReLU -> 3x3 conv -> BN
    * shortcut branch: 5x5 conv -> BN
    """

    def __init__(self, num_input_features, num_output_features):
        """Build the up-projection layers.

        Args:
            num_input_features: Channel count of the incoming feature map.
            num_output_features: Channel count produced by the block.
        """
        super().__init__()
        c_in, c_out = num_input_features, num_output_features

        # Main branch (registration order kept stable for state_dict / init RNG).
        self.conv1 = nn.Conv2d(c_in, c_out, kernel_size=5, stride=1, padding=2, bias=False)
        self.bn1 = nn.BatchNorm2d(c_out)
        self.relu = nn.ReLU(inplace=True)
        self.conv1_2 = nn.Conv2d(c_out, c_out, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1_2 = nn.BatchNorm2d(c_out)

        # Shortcut branch: a single conv straight from the resized input.
        self.conv2 = nn.Conv2d(c_in, c_out, kernel_size=5, stride=1, padding=2, bias=False)
        self.bn2 = nn.BatchNorm2d(c_out)

    def forward(self, x, size):
        """Resize ``x`` to ``size`` and apply the residual up-projection.

        Args:
            x: Input feature map, passed to ``F.interpolate`` in bilinear mode.
            size: Target spatial size ``(height, width)``.

        Returns:
            Rectified sum of the main and shortcut branches.
        """
        resized = F.interpolate(x, size=size, mode="bilinear", align_corners=True)
        hidden = self.relu(self.bn1(self.conv1(resized)))
        main = self.bn1_2(self.conv1_2(hidden))
        shortcut = self.bn2(self.conv2(resized))
        return self.relu(main + shortcut)


class D(nn.Module):
    """Decoder that progressively upsamples the deepest input feature map.

    ``block_channel`` lists channel counts deepest-first: ``block_channel[0]``
    is the channel count of ``x_block4``. The deepest map is projected with a
    1x1 conv and then walked back up through four up-projection stages, ending
    at twice the spatial size of ``x_block1`` with ``block_channel[3] // 4``
    channels.
    """

    def __init__(self, block_channel):
        """Initialize the decoder module.

        Args:
            block_channel: Sequence of at least four channel counts,
                deepest first.
        """
        super().__init__()
        half_width = block_channel[3] // 2
        quarter_width = half_width // 2

        # 1x1 channel reduction applied to the deepest feature map.
        self.conv = nn.Conv2d(
            block_channel[0], block_channel[1], kernel_size=1, stride=1, bias=False
        )
        self.bn = nn.BatchNorm2d(block_channel[1])

        # Upsampling stages, each halving (or stepping down) the channel width.
        self.up1 = _UpProjection(num_input_features=block_channel[1],
                                 num_output_features=block_channel[2])
        self.up2 = _UpProjection(num_input_features=block_channel[2],
                                 num_output_features=block_channel[3])
        self.up3 = _UpProjection(num_input_features=block_channel[3],
                                 num_output_features=half_width)
        self.up4 = _UpProjection(num_input_features=half_width,
                                 num_output_features=quarter_width)

    def forward(self, x_block1, x_block2, x_block3, x_block4):
        """Decode the deepest map back to 2x the resolution of ``x_block1``.

        ``x_block1``..``x_block3`` are only used for their spatial sizes
        (dims 2 and 3); ``x_block4`` is the tensor actually decoded.
        """
        decoded = F.relu(self.bn(self.conv(x_block4)))
        # Match each intermediate stage to the next-shallower input's size.
        for stage, reference in ((self.up1, x_block3), (self.up2, x_block2), (self.up3, x_block1)):
            decoded = stage(decoded, [reference.size(2), reference.size(3)])
        # Final stage doubles the shallowest resolution.
        return self.up4(decoded, [x_block1.size(2) * 2, x_block1.size(3) * 2])


class MFF(nn.Module):
    """Multi-feature fusion module.

    Each of the four input feature maps is resized to a common spatial size
    by its own up-projection branch emitting ``num_features // 4`` channels.
    The branch outputs are concatenated along the channel axis and fused by a
    5x5 conv + BN + ReLU, producing ``num_features`` channels.
    """

    def __init__(self, block_channel, num_features=64):
        """Initialize the multi-feature fusion module.

        Args:
            block_channel: Channel counts of the four inputs, deepest first:
                ``block_channel[3]`` is the channel count of ``x_block1`` and
                ``block_channel[0]`` that of ``x_block4``.
            num_features: Channels of the fused output. It is split evenly
                across the four branches, so it must be divisible by 4.

        Raises:
            ValueError: If ``num_features`` is not divisible by 4.
        """
        super().__init__()
        if num_features % 4 != 0:
            raise ValueError(f"num_features must be divisible by 4, got {num_features}")
        # Split the output width evenly so the concatenation of the four
        # branches always matches the fusion conv's input width (a fixed 16
        # per branch only worked for num_features == 64).
        branch_features = num_features // 4

        self.up1 = _UpProjection(num_input_features=block_channel[3],
                                 num_output_features=branch_features)
        self.up2 = _UpProjection(num_input_features=block_channel[2],
                                 num_output_features=branch_features)
        self.up3 = _UpProjection(num_input_features=block_channel[1],
                                 num_output_features=branch_features)
        self.up4 = _UpProjection(num_input_features=block_channel[0],
                                 num_output_features=branch_features)

        self.conv = nn.Conv2d(
            num_features, num_features, kernel_size=5, stride=1, padding=2, bias=False
        )
        self.bn = nn.BatchNorm2d(num_features)

    def forward(self, x_block1, x_block2, x_block3, x_block4, size):
        """Resize all four inputs to ``size`` and fuse them.

        Args:
            x_block1..x_block4: Input feature maps, shallowest to deepest.
            size: Common target spatial size ``(height, width)``.

        Returns:
            Fused feature map with ``num_features`` channels at ``size``.
        """
        x_m1 = self.up1(x_block1, size)
        x_m2 = self.up2(x_block2, size)
        x_m3 = self.up3(x_block3, size)
        x_m4 = self.up4(x_block4, size)

        x = self.bn(self.conv(torch.cat((x_m1, x_m2, x_m3, x_m4), 1)))
        x = F.relu(x)
        return x


class R(nn.Module):
    """Occupancy prediction head.

    Resizes the input feature map to ``target_size``, refines it with two
    5x5 conv + BN + ReLU layers (size-preserving, padding 2) and projects it
    to ``num_class`` output channels with a biased 5x5 conv.
    """

    def __init__(self, channel, num_class=1, target_size=(120, 160)):
        """Initialize the occupancy head module.

        Args:
            channel: Channel count of the input feature map; kept unchanged
                through the resize and refinement layers.
            num_class: Number of output channels of the final conv.
            target_size: ``(height, width)`` the input is resized to before
                prediction. Defaults to ``(120, 160)``, the value this head
                previously hard-coded.
        """
        super().__init__()

        self.target_size = tuple(target_size)
        self.resize = _UpProjection(num_input_features=channel, num_output_features=channel)

        self.conv0 = nn.Conv2d(channel, channel, kernel_size=5, stride=1, padding=2, bias=False)
        self.bn0 = nn.BatchNorm2d(channel)

        self.conv1 = nn.Conv2d(channel, channel, kernel_size=5, stride=1, padding=2, bias=False)
        self.bn1 = nn.BatchNorm2d(channel)

        # Final projection keeps a bias since no BatchNorm follows it.
        self.conv2 = nn.Conv2d(channel, num_class, kernel_size=5, stride=1, padding=2, bias=True)

    def forward(self, x):
        """Predict ``num_class`` raw (un-activated) maps at ``target_size``."""
        x0 = self.resize(x, self.target_size)  # resize to self.target_size
        x0 = self.conv0(x0)
        x0 = self.bn0(x0)
        x0 = F.relu(x0)

        x1 = self.conv1(x0)
        x1 = self.bn1(x1)
        x1 = F.relu(x1)

        x2 = self.conv2(x1)
        return x2


class MultiPlaneOccupancyHead(nn.Module):
    """Multi-plane occupancy head.

    Runs a progressive decoder (``D``) and a multi-feature fusion module
    (``MFF``) over four input feature maps, concatenates their outputs along
    the channel axis and predicts ``num_classes`` maps with ``R``.
    """

    def __init__(self, num_classes=100):
        """Initialize the multi-plane occupancy head.

        Args:
            num_classes: Number of output channels of the prediction head.
                Defaults to 100, the value this head previously hard-coded.
        """
        super().__init__()
        # Channel counts listed deepest-first: block_channel[0] must match the
        # 'res5' input and block_channel[3] the 'res2' input.
        block_channel = [2048, 1024, 512, 256]
        # Input keys listed shallowest-first, mapped to x_block1..x_block4.
        self.feature_key = ['res2', 'res3', 'res4', 'res5']
        feature_channels = 64

        self.D = D(block_channel)
        self.MFF = MFF(block_channel, feature_channels)
        # D emits block_channel[-1] // 4 channels; MFF emits feature_channels.
        head_channels = block_channel[-1] // 4 + feature_channels
        self.num_classes = num_classes
        self.prediction = R(head_channels, self.num_classes)

    def forward(self, x):
        """Predict occupancy maps from a mapping of backbone features.

        Args:
            x: Mapping containing the keys in ``self.feature_key``. Presumably
                ResNet-style stage outputs as (N, C, H, W) tensors — confirm
                against the backbone in use.

        Returns:
            Output of ``self.prediction`` with ``num_classes`` channels.
        """
        keys = self.feature_key
        x_block1 = x[keys[0]]
        x_block2 = x[keys[1]]
        x_block3 = x[keys[2]]
        x_block4 = x[keys[3]]

        x_decoder = self.D(x_block1, x_block2, x_block3, x_block4)
        # Fuse all four inputs at the decoder's output resolution.
        x_mff = self.MFF(
            x_block1, x_block2, x_block3, x_block4, [x_decoder.size(2), x_decoder.size(3)]
        )

        x_feat = torch.cat((x_decoder, x_mff), 1)
        occ_pred = self.prediction(x_feat)
        return occ_pred