| | import torch.nn as nn |
| | from .net_utils import single_conv, double_conv, double_conv_down, double_conv_up, PosEnSine |
| | from .transformer_basics import OurMultiheadAttention |
| |
|
| |
|
| | class TransformerDecoderUnit(nn.Module): |
| | def __init__(self, feat_dim, n_head=8, pos_en_flag=True, attn_type='softmax', P=None): |
| | super(TransformerDecoderUnit, self).__init__() |
| | self.feat_dim = feat_dim |
| | self.attn_type = attn_type |
| | self.pos_en_flag = pos_en_flag |
| | self.P = P |
| |
|
| | self.pos_en = PosEnSine(self.feat_dim // 2) |
| | self.attn = OurMultiheadAttention(feat_dim, n_head) |
| |
|
| | self.linear1 = nn.Conv2d(self.feat_dim, self.feat_dim, 1) |
| | self.linear2 = nn.Conv2d(self.feat_dim, self.feat_dim, 1) |
| | self.activation = nn.ReLU(inplace=True) |
| |
|
| | self.norm = nn.BatchNorm2d(self.feat_dim) |
| |
|
| | def forward(self, q, k, v): |
| | if self.pos_en_flag: |
| | q_pos_embed = self.pos_en(q) |
| | k_pos_embed = self.pos_en(k) |
| | else: |
| | q_pos_embed = 0 |
| | k_pos_embed = 0 |
| |
|
| | |
| | out = self.attn( |
| | q=q + q_pos_embed, k=k + k_pos_embed, v=v, attn_type=self.attn_type, P=self.P |
| | )[0] |
| |
|
| | |
| | out2 = self.linear2(self.activation(self.linear1(out))) |
| | out = out + out2 |
| | out = self.norm(out) |
| |
|
| | return out |
| |
|
| |
|
| | class Unet(nn.Module): |
| | def __init__(self, in_ch, feat_ch, out_ch): |
| | super().__init__() |
| | self.conv_in = single_conv(in_ch, feat_ch) |
| |
|
| | self.conv1 = double_conv_down(feat_ch, feat_ch) |
| | self.conv2 = double_conv_down(feat_ch, feat_ch) |
| | self.conv3 = double_conv(feat_ch, feat_ch) |
| | self.conv4 = double_conv_up(feat_ch, feat_ch) |
| | self.conv5 = double_conv_up(feat_ch, feat_ch) |
| | self.conv6 = double_conv(feat_ch, out_ch) |
| |
|
| | def forward(self, x): |
| | feat0 = self.conv_in(x) |
| | feat1 = self.conv1(feat0) |
| | feat2 = self.conv2(feat1) |
| | feat3 = self.conv3(feat2) |
| | feat3 = feat3 + feat2 |
| | feat4 = self.conv4(feat3) |
| | feat4 = feat4 + feat1 |
| | feat5 = self.conv5(feat4) |
| | feat5 = feat5 + feat0 |
| | feat6 = self.conv6(feat5) |
| |
|
| | return feat0, feat1, feat2, feat3, feat4, feat6 |
| |
|
| |
|
| | class Texformer(nn.Module): |
| | def __init__(self, opts): |
| | super().__init__() |
| | self.feat_dim = opts.feat_dim |
| | src_ch = opts.src_ch |
| | tgt_ch = opts.tgt_ch |
| | out_ch = opts.out_ch |
| | self.mask_fusion = opts.mask_fusion |
| |
|
| | if not self.mask_fusion: |
| | v_ch = out_ch |
| | else: |
| | v_ch = 2 + 3 |
| |
|
| | self.unet_q = Unet(tgt_ch, self.feat_dim, self.feat_dim) |
| | self.unet_k = Unet(src_ch, self.feat_dim, self.feat_dim) |
| | self.unet_v = Unet(v_ch, self.feat_dim, self.feat_dim) |
| |
|
| | self.trans_dec = nn.ModuleList( |
| | [ |
| | None, None, None, |
| | TransformerDecoderUnit(self.feat_dim, opts.nhead, True, 'softmax'), |
| | TransformerDecoderUnit(self.feat_dim, opts.nhead, True, 'dotproduct'), |
| | TransformerDecoderUnit(self.feat_dim, opts.nhead, True, 'dotproduct') |
| | ] |
| | ) |
| |
|
| | self.conv0 = double_conv(self.feat_dim, self.feat_dim) |
| | self.conv1 = double_conv_down(self.feat_dim, self.feat_dim) |
| | self.conv2 = double_conv_down(self.feat_dim, self.feat_dim) |
| | self.conv3 = double_conv(self.feat_dim, self.feat_dim) |
| | self.conv4 = double_conv_up(self.feat_dim, self.feat_dim) |
| | self.conv5 = double_conv_up(self.feat_dim, self.feat_dim) |
| |
|
| | if not self.mask_fusion: |
| | self.conv6 = nn.Sequential( |
| | single_conv(self.feat_dim, self.feat_dim), |
| | nn.Conv2d(self.feat_dim, out_ch, 3, 1, 1) |
| | ) |
| | else: |
| | self.conv6 = nn.Sequential( |
| | single_conv(self.feat_dim, self.feat_dim), |
| | nn.Conv2d(self.feat_dim, 2 + 3 + 1, 3, 1, 1) |
| | ) |
| | self.sigmoid = nn.Sigmoid() |
| |
|
| | self.tanh = nn.Tanh() |
| |
|
| | def forward(self, q, k, v): |
| | print('qkv', q.shape, k.shape, v.shape) |
| | q_feat = self.unet_q(q) |
| | k_feat = self.unet_k(k) |
| | v_feat = self.unet_v(v) |
| |
|
| | print('q_feat', len(q_feat)) |
| | outputs = [] |
| | for i in range(3, len(q_feat)): |
| | print(i, q_feat[i].shape, k_feat[i].shape, v_feat[i].shape) |
| | outputs.append(self.trans_dec[i](q_feat[i], k_feat[i], v_feat[i])) |
| | print('outputs', outputs[-1].shape) |
| |
|
| | f0 = self.conv0(outputs[2]) |
| | f1 = self.conv1(f0) |
| | f1 = f1 + outputs[1] |
| | f2 = self.conv2(f1) |
| | f2 = f2 + outputs[0] |
| | f3 = self.conv3(f2) |
| | f3 = f3 + outputs[0] + f2 |
| | f4 = self.conv4(f3) |
| | f4 = f4 + outputs[1] + f1 |
| | f5 = self.conv5(f4) |
| | f5 = f5 + outputs[2] + f0 |
| | if not self.mask_fusion: |
| | out = self.tanh(self.conv6(f5)) |
| | else: |
| | out_ = self.conv6(f5) |
| | out = [self.tanh(out_[:, :2]), self.tanh(out_[:, 2:5]), self.sigmoid(out_[:, 5:])] |
| | return out |
| |
|