| import math | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class Decoder(nn.Module):
    """Fuse per-channel image features and planar features into one output map.

    Every incoming feature map is assumed to be single-channel (NOTE(review):
    each branch is a 1->1 conv — confirm callers pass 1-channel tensors). Each
    map runs through its own conv/BN/ReLU block, is resized to ``shape``,
    concatenated along the channel axis, and reduced to a single sigmoid-
    activated channel by the ``combined`` head.
    """

    def __init__(self, shape, num_img_feat, num_pla_feat):
        super().__init__()
        self.shape = shape
        # One independent conv block per feature map in each branch.
        self.img_model = self._make_layer(num_img_feat)
        self.pla_model = self._make_layer(num_pla_feat)
        self.combined = self._make_output(num_img_feat + num_pla_feat)
        # He-style (fan-out) init for convs; identity init for batch norms.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                fan_out = (
                    module.kernel_size[0] * module.kernel_size[1] * module.out_channels
                )
                module.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def _make_layer(self, num_feat):
        """Return a ModuleList of `num_feat` independent 1->1 conv blocks."""
        blocks = [
            nn.Sequential(
                nn.Conv2d(1, 1, 3, padding=1),
                nn.BatchNorm2d(1),
                nn.ReLU(inplace=True),
            )
            for _ in range(num_feat)
        ]
        return nn.ModuleList(blocks)

    def _make_output(self, planes, readout=1):
        """Return the fusion head: conv to `readout` channels, BN, sigmoid."""
        return nn.Sequential(
            nn.Conv2d(planes, readout, 3, stride=1, padding=1),
            nn.BatchNorm2d(readout),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Run both branches and fuse.

        Args:
            x: pair ``(img_feat, pla_feat)`` of sequences of (N, 1, H, W) tensors,
               one tensor per conv block in the matching branch.

        Returns:
            Tensor of shape (N, 1, *self.shape) with values in (0, 1).
        """
        img_feat, pla_feat = x
        resized = [
            F.interpolate(block(feat), self.shape)
            for feat, block in zip(img_feat, self.img_model)
        ]
        resized += [
            F.interpolate(block(feat), self.shape)
            for feat, block in zip(pla_feat, self.pla_model)
        ]
        return self.combined(torch.cat(resized, dim=1))
def build_decoder(model_path, *args):
    """Construct a Decoder and restore its weights from a checkpoint file.

    Args:
        model_path: Path to a torch checkpoint whose payload contains a
            "state_dict" entry compatible with Decoder.
        *args: Positional arguments forwarded to Decoder
            (shape, num_img_feat, num_pla_feat).

    Returns:
        A Decoder instance with the checkpoint parameters loaded (on CPU).

    Raises:
        KeyError: If the checkpoint has no "state_dict" entry.
        RuntimeError: If the state dict does not match the model's parameters.
    """
    decoder = Decoder(*args)
    # map_location="cpu" lets GPU-saved checkpoints load on CPU-only hosts;
    # load_state_dict then copies into the model's own (CPU) parameters.
    # weights_only=True keeps the load restricted to tensor data.
    checkpoint = torch.load(model_path, map_location="cpu", weights_only=True)
    decoder.load_state_dict(checkpoint["state_dict"])
    return decoder