import torch

def model_init(m):
    """Xavier-initialize Conv2d/Linear weights and zero their biases."""
    if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)):
        torch.nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:  # layers built with bias=False have no bias to zero
            torch.nn.init.zeros_(m.bias)

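# A minimal usage sketch (assumed, not part of the original listing):
# nn.Module.apply walks every submodule recursively, so the initializer
# above can be applied to a whole network in one call, e.g.
#
#     net = torch.nn.Linear(8, 2)
#     net.apply(model_init)
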
class NLB(torch.nn.Module):
    """Non-local block: self-attention over all spatial positions of a feature map."""

    def __init__(self, in_ch, relu_a=0.01):  # relu_a is unused here; kept for API compatibility
        super().__init__()
        self.inter_ch = in_ch // 2  # attention embedding width: half the input channels

        # 1x1 convolutions produce the query (theta), key (phi), and value (g) embeddings.
        self.theta_layer = torch.nn.Conv2d(in_channels=in_ch, out_channels=self.inter_ch,
                                           kernel_size=1, padding=0)
        self.phi_layer = torch.nn.Conv2d(in_channels=in_ch, out_channels=self.inter_ch,
                                         kernel_size=1, padding=0)
        self.g_layer = torch.nn.Conv2d(in_channels=in_ch, out_channels=self.inter_ch,
                                       kernel_size=1, padding=0)
        self.atten_act = torch.nn.Softmax(dim=-1)
        self.out_cnn = torch.nn.Conv2d(in_channels=self.inter_ch, out_channels=in_ch,
                                       kernel_size=1, padding=0)

    def forward(self, x):
        mbsz, _, h, w = x.size()

        # Flatten the spatial dimensions: theta and g are (N, HW, C'), phi is (N, C', HW).
        theta = self.theta_layer(x).view(mbsz, self.inter_ch, -1).permute(0, 2, 1)
        phi = self.phi_layer(x).view(mbsz, self.inter_ch, -1)
        g = self.g_layer(x).view(mbsz, self.inter_ch, -1).permute(0, 2, 1)

        # (N, HW, HW) attention map over all pairs of spatial positions.
        theta_phi = self.atten_act(torch.matmul(theta, phi))

        # Weighted sum of values; contiguous() is required before view() after a permute().
        theta_phi_g = torch.matmul(theta_phi, g).permute(0, 2, 1).contiguous().view(
            mbsz, self.inter_ch, h, w)

        # Project back to in_ch channels and add the residual connection.
        return torch.add(self.out_cnn(theta_phi_g), x)


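# Shape sanity check (a sketch, assuming a 64-channel feature map; not part of
# the original listing). The block preserves its input shape, so it can be
# inserted between convolution layers without further plumbing:
#
#     nlb = NLB(in_ch=64)
#     assert nlb(torch.rand(4, 64, 9, 9)).shape == (4, 64, 9, 9)
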
class BraggNN(torch.nn.Module):
    """CNN for Bragg peak localization: conv stack + non-local block + dense head."""

    def __init__(self, imgsz, fcsz=(64, 32, 16, 8)):
        super().__init__()
        self.cnn_ops = []
        cnn_out_chs = (64, 32, 8)
        cnn_in_chs = (1, ) + cnn_out_chs[:-1]
        fsz = imgsz
        for ic, oc in zip(cnn_in_chs, cnn_out_chs):
            self.cnn_ops += [
                torch.nn.Conv2d(in_channels=ic, out_channels=oc, kernel_size=3,
                                stride=1, padding=0),
                torch.nn.LeakyReLU(negative_slope=0.01),
            ]
            fsz -= 2  # each unpadded 3x3 conv shrinks the feature map by 2 pixels
        self.nlb = NLB(in_ch=cnn_out_chs[0])

        self.dense_ops = []
        dense_in_chs = (fsz * fsz * cnn_out_chs[-1], ) + fcsz[:-1]
        for ic, oc in zip(dense_in_chs, fcsz):
            self.dense_ops += [
                torch.nn.Linear(ic, oc),
                torch.nn.LeakyReLU(negative_slope=0.01),
            ]
        # Output layer: the (x, y) peak-center coordinates.
        self.dense_ops += [torch.nn.Linear(fcsz[-1], 2), ]

        self.cnn_layers = torch.nn.Sequential(*self.cnn_ops)
        self.dense_layers = torch.nn.Sequential(*self.dense_ops)

    def forward(self, x):
        _out = x
        # First convolution only, so the non-local block attends over the
        # 64-channel features before any downmixing.
        for layer in self.cnn_layers[:1]:
            _out = layer(_out)

        _out = self.nlb(_out)

        # Remaining activation and convolution layers.
        for layer in self.cnn_layers[1:]:
            _out = layer(_out)

        _out = _out.flatten(start_dim=1)
        for layer in self.dense_layers:
            _out = layer(_out)

        return _out

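
# End-to-end smoke test (a sketch, not from the original listing; imgsz=11 is
# an assumed patch size, chosen so the three unpadded convs leave a 5x5 map).
if __name__ == "__main__":
    model = BraggNN(imgsz=11)
    model.apply(model_init)  # recursively Xavier-initialize all conv/linear layers

    with torch.no_grad():
        preds = model(torch.rand(4, 1, 11, 11))  # batch of four 11x11 patches
    assert preds.shape == (4, 2), preds.shape  # one (x, y) prediction per patch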