File size: 1,862 Bytes
d62394f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class Decoder(nn.Module):
    """Fuses per-feature image and "pla" maps into a single saliency-style map.

    Each input feature map (assumed single-channel, NCHW with C == 1 —
    TODO confirm against callers) is passed through its own small
    conv/BN/ReLU head, resized to ``shape``, concatenated along the channel
    dimension, and reduced to one channel by a conv/BN/sigmoid output head.

    Args:
        shape: target spatial size (H, W) every feature map is resized to.
        num_img_feat: number of image feature maps expected in the input.
        num_pla_feat: number of "pla" feature maps expected in the input.
    """

    def __init__(self, shape, num_img_feat, num_pla_feat):
        super().__init__()
        self.shape = shape
        self.img_model = self._make_layer(num_img_feat)
        self.pla_model = self._make_layer(num_pla_feat)

        # One output head over all resized maps stacked channel-wise.
        self.combined = self._make_output(num_img_feat + num_pla_feat)

        # Kaiming (He) normal init, fan_out mode: std = sqrt(2 / (k*k*out_ch)).
        # Equivalent to the hand-rolled m.weight.data.normal_(0, sqrt(2/n))
        # but uses the supported nn.init API instead of deprecated .data access.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def _make_layer(self, num_feat):
        """Build ``num_feat`` independent 1->1 channel conv/BN/ReLU heads."""
        return nn.ModuleList(
            nn.Sequential(
                nn.Conv2d(1, 1, 3, padding=1),
                nn.BatchNorm2d(1),
                nn.ReLU(inplace=True),
            )
            for _ in range(num_feat)
        )

    def _make_output(self, planes, readout=1):
        """Build the fusion head: ``planes`` channels -> ``readout`` channel(s) in [0, 1]."""
        return nn.Sequential(
            nn.Conv2d(planes, readout, 3, stride=1, padding=1),
            nn.BatchNorm2d(readout),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Run the decoder.

        Args:
            x: tuple ``(img_feat, pla_feat)`` — two sequences of single-channel
               NCHW tensors, one per head built in ``__init__``.

        Returns:
            Tensor of shape (N, 1, *self.shape) with values in [0, 1].
        """
        img_feat, pla_feat = x
        feat = []

        # Process each map with its dedicated head, then resize to the common
        # output shape so the channel-wise concat below is well-defined.
        for a, b in zip(img_feat, self.img_model):
            feat.append(F.interpolate(b(a), self.shape))

        for a, b in zip(pla_feat, self.pla_model):
            feat.append(F.interpolate(b(a), self.shape))

        feat = torch.cat(feat, dim=1)
        return self.combined(feat)


def build_decoder(model_path, *args, map_location="cpu"):
    """Construct a ``Decoder`` and load weights from a checkpoint file.

    Args:
        model_path: path to a ``torch.save``-d checkpoint containing a
            ``"state_dict"`` entry for the decoder.
        *args: positional arguments forwarded to ``Decoder``
            (shape, num_img_feat, num_pla_feat).
        map_location: device remapping passed to ``torch.load``. Defaults to
            ``"cpu"`` so GPU-saved checkpoints load on CPU-only machines;
            previously omitted, which raised on hosts without the saving
            device. The loaded tensors are copied into the freshly built
            (CPU) model either way, so the result is unchanged.

    Returns:
        A ``Decoder`` with the checkpoint's weights loaded.
    """
    decoder = Decoder(*args)
    # weights_only=True avoids arbitrary-code pickle execution on untrusted files.
    checkpoint = torch.load(model_path, map_location=map_location, weights_only=True)
    decoder.load_state_dict(checkpoint["state_dict"])
    return decoder