| import math |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
class Decoder(nn.Module):
    """Fuse two groups of single-channel feature maps into one sigmoid map.

    Every input feature map goes through its own ``Conv3x3 -> BatchNorm -> ReLU``
    block (1 input channel, 1 output channel). It is then resized to ``shape``
    with ``F.interpolate`` (default ``"nearest"`` mode). All resized maps are
    concatenated along dim 1. A final ``Conv3x3 -> BatchNorm -> Sigmoid`` head
    maps the ``num_img_feat + num_pla_feat`` concatenated channels to one channel.

    Args:
        shape: Output spatial size, passed to ``F.interpolate`` as ``size``
            (e.g. ``(H, W)``).
        num_img_feat: Number of feature maps used from the first input group
            (``img_feat`` in :meth:`forward`).
        num_pla_feat: Number of feature maps used from the second input group
            (``pla_feat`` in :meth:`forward`).
    """

    def __init__(self, shape, num_img_feat, num_pla_feat):
        super().__init__()
        self.shape = shape
        self.img_model = self._make_layer(num_img_feat)
        self.pla_model = self._make_layer(num_pla_feat)
        self.combined = self._make_output(num_img_feat + num_pla_feat)

        # Kaiming-normal init using fan-out (kernel area * out_channels) for
        # every conv. BatchNorm starts as the identity affine transform.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                nn.init.normal_(m.weight, mean=0.0, std=math.sqrt(2.0 / n))
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def _make_layer(self, num_feat):
        """Return a ModuleList of ``num_feat`` independent 1->1 channel conv blocks."""
        return nn.ModuleList(
            nn.Sequential(
                nn.Conv2d(1, 1, 3, padding=1),
                nn.BatchNorm2d(1),
                nn.ReLU(inplace=True),
            )
            for _ in range(num_feat)
        )

    def _make_output(self, planes, readout=1):
        """Return the fusion head: 3x3 conv (``planes`` -> ``readout``), BatchNorm, Sigmoid."""
        return nn.Sequential(
            nn.Conv2d(planes, readout, 3, stride=1, padding=1),
            nn.BatchNorm2d(readout),
            nn.Sigmoid(),
        )

    def _run_branch(self, inputs, branch, name):
        """Apply each block of ``branch`` to its matching input and resize to ``self.shape``.

        Raises:
            ValueError: if ``inputs`` holds fewer maps than ``branch`` has blocks.
                Without this check, the missing channels only surface later as
                an opaque channel-mismatch error in ``self.combined``.
        """
        inputs = list(inputs)
        if len(inputs) < len(branch):
            raise ValueError(
                f"expected at least {len(branch)} {name} feature maps, got {len(inputs)}"
            )
        # zip stops at the shorter sequence, so any inputs beyond len(branch)
        # are ignored (unchanged from the original behaviour).
        return [F.interpolate(block(t), self.shape) for t, block in zip(inputs, branch)]

    def forward(self, x):
        """Fuse the two feature groups.

        Args:
            x: A pair ``(img_feat, pla_feat)`` of sequences of tensors. Each
                tensor must have exactly 1 channel, as required by the
                per-feature ``Conv2d(1, 1, 3)``. Presumably shaped
                ``(N, 1, H_i, W_i)``; spatial sizes may differ between maps.

        Returns:
            Tensor with 1 channel and spatial size ``self.shape``, values in [0, 1].

        Raises:
            ValueError: if either group has fewer maps than the decoder was built for.
        """
        img_feat, pla_feat = x
        feat = self._run_branch(img_feat, self.img_model, "img_feat")
        feat += self._run_branch(pla_feat, self.pla_model, "pla_feat")
        return self.combined(torch.cat(feat, dim=1))
|
|
|
|
def build_decoder(model_path, *args):
    """Construct a ``Decoder`` and load pretrained weights into it.

    Args:
        model_path: Path or file-like object accepted by ``torch.load``. The
            loaded object must be a mapping with a ``"state_dict"`` entry.
        *args: Positional arguments forwarded unchanged to ``Decoder``.

    Returns:
        The decoder with the checkpoint's weights loaded. It is left in
        training mode (the ``nn.Module`` default); call ``.eval()`` before
        inference.

    Raises:
        KeyError: if the checkpoint has no ``"state_dict"`` entry.
    """
    decoder = Decoder(*args)
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint cannot run arbitrary code.
    # map_location="cpu" lets GPU-saved checkpoints load on CPU-only hosts.
    # load_state_dict copies the values into the decoder's own parameters,
    # so the result is the same either way.
    checkpoint = torch.load(model_path, map_location="cpu", weights_only=True)
    decoder.load_state_dict(checkpoint["state_dict"])
    return decoder
|
|