import pdb

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from .conv_layers import DepthwiseSeparableConv, BasicBlock, Bottleneck, MBConv, FusedMBConv, ConvNormAct
from .trans_layers import TransformerBlock


class BidirectionAttention(nn.Module):
    """Cross-attention computed once and used in both directions.

    A single similarity matrix between feature-map tokens and semantic-map
    tokens is normalised along each axis in turn, so that the feature map
    attends to the semantic map and the semantic map attends to the feature
    map in the same pass.

    Args:
        feat_dim: channels of the input feature map.
        map_dim: channels of the semantic map (input and output).
        out_dim: channels of the output feature map.
        heads: number of attention heads.
        dim_head: channels per head.
        attn_drop: dropout probability on the map -> feature attention.
        proj_drop: dropout probability on both output projections.
        map_size: spatial side length the semantic map must have.
        proj_type: 'linear' (1x1 conv) or 'depthwise'
            (DepthwiseSeparableConv) for the feature-side projections.
    """

    def __init__(self, feat_dim, map_dim, out_dim, heads=4, dim_head=64, attn_drop=0.,
                 proj_drop=0., map_size=16, proj_type='depthwise'):
        super().__init__()

        self.inner_dim = dim_head * heads
        self.feat_dim = feat_dim
        self.map_dim = map_dim
        self.heads = heads
        self.scale = dim_head ** (-0.5)
        self.dim_head = dim_head
        self.map_size = map_size

        assert proj_type in ['linear', 'depthwise']

        # Query and value are produced by one projection and split in forward().
        if proj_type == 'linear':
            self.feat_qv = nn.Conv2d(feat_dim, self.inner_dim * 2, kernel_size=1, bias=False)
            self.feat_out = nn.Conv2d(self.inner_dim, out_dim, kernel_size=1, bias=False)
        else:
            self.feat_qv = DepthwiseSeparableConv(feat_dim, self.inner_dim * 2)
            self.feat_out = DepthwiseSeparableConv(self.inner_dim, out_dim)

        self.map_qv = nn.Conv2d(map_dim, self.inner_dim * 2, kernel_size=1, bias=False)
        self.map_out = nn.Conv2d(self.inner_dim, map_dim, kernel_size=1, bias=False)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

    def _split_heads(self, t, h, w):
        """Reshape (B, dim_head*heads, h, w) into (B, heads, h*w, dim_head)."""
        return rearrange(
            t, 'b (dim_head heads) h w -> b heads (h w) dim_head',
            dim_head=self.dim_head, heads=self.heads, h=h, w=w,
        )

    def _merge_heads(self, t, h, w, **axes):
        """Inverse of `_split_heads`: (B, heads, h*w, dim_head) -> (B, C, h, w)."""
        return rearrange(
            t, 'b heads (h w) dim_head -> b (dim_head heads) h w',
            h=h, w=w, dim_head=self.dim_head, heads=self.heads, **axes,
        )

    def forward(self, feat, semantic_map):
        """Run the bidirectional attention.

        Args:
            feat: feature map of shape (B, C, H, W).
            semantic_map: semantic map whose spatial size must be
                (map_size, map_size); the rearrange below enforces this.

        Returns:
            Tuple ``(feat_out, map_out)`` with the spatial sizes of ``feat``
            and ``semantic_map`` respectively.
        """
        B, C, H, W = feat.shape

        feat_q, feat_v = self.feat_qv(feat).chunk(2, dim=1)  # each B, inner_dim, H, W
        map_q, map_v = self.map_qv(semantic_map).chunk(2, dim=1)  # each B, inner_dim, rs, rs

        feat_q = self._split_heads(feat_q, H, W)
        feat_v = self._split_heads(feat_v, H, W)
        map_q = self._split_heads(map_q, self.map_size, self.map_size)
        map_v = self._split_heads(map_v, self.map_size, self.map_size)

        # attn[b, h, i, j]: similarity of feature token i and map token j.
        attn = torch.einsum('bhid,bhjd->bhij', feat_q, map_q)
        attn *= self.scale

        # Feature -> map direction: normalise over map tokens. The semantic
        # map is very concise, so no dropout is applied here; adding dropout
        # might make training unstable.
        feat_map_attn = F.softmax(attn, dim=-1)
        # Map -> feature direction: normalise over feature tokens.
        map_feat_attn = self.attn_drop(F.softmax(attn, dim=-2))

        feat_out = torch.einsum('bhij,bhjd->bhid', feat_map_attn, map_v)
        feat_out = self._merge_heads(feat_out, H, W)

        # 'bhji' reads the same matrix transposed, so the sum runs over
        # feature tokens.
        map_out = torch.einsum('bhji,bhjd->bhid', map_feat_attn, feat_v)
        map_out = self._merge_heads(map_out, self.map_size, self.map_size, b=B)

        feat_out = self.proj_drop(self.feat_out(feat_out))
        map_out = self.proj_drop(self.map_out(map_out))

        return feat_out, map_out


class BidirectionAttentionBlock(nn.Module):
    """Pre-norm bidirectional attention followed by a conv feed-forward.

    Args:
        feat_dim: channels of the input feature map.
        map_dim: channels of the semantic map.
        out_dim: channels of the output feature map.
        heads: number of attention heads.
        dim_head: channels per head.
        norm: normalisation layer class applied to both inputs.
        act: activation class forwarded to the conv sub-blocks.
        expansion: expansion factor forwarded to the feed-forward block.
        attn_drop: attention dropout probability.
        proj_drop: projection dropout probability (also forwarded to MBConv
            as ``p`` in the 'depthwise' case).
        map_size: spatial side length of the semantic map.
        proj_type: 'linear' or 'depthwise'.
    """

    def __init__(self, feat_dim, map_dim, out_dim, heads, dim_head, norm=nn.BatchNorm2d,
                 act=nn.GELU, expansion=4, attn_drop=0., proj_drop=0., map_size=8,
                 proj_type='depthwise'):
        super().__init__()

        # NOTE(review): `True` passes these asserts, but `norm(feat_dim)`
        # below would then call a bool; only classes or a falsy value work.
        assert norm in [nn.BatchNorm2d, nn.InstanceNorm2d, True, False]
        assert act in [nn.ReLU, nn.ReLU6, nn.GELU, nn.SiLU, True, False]
        assert proj_type in ['linear', 'depthwise']

        self.norm1 = norm(feat_dim) if norm else nn.Identity()  # norm layer for feature map
        self.norm2 = norm(map_dim) if norm else nn.Identity()  # norm layer for semantic map

        self.attn = BidirectionAttention(
            feat_dim, map_dim, out_dim, heads=heads, dim_head=dim_head,
            attn_drop=attn_drop, proj_drop=proj_drop, map_size=map_size,
            proj_type=proj_type,
        )

        # Identity shortcut unless the channel count changes.
        self.shortcut = nn.Sequential()
        if feat_dim != out_dim:
            self.shortcut = ConvNormAct(feat_dim, out_dim, kernel_size=1, padding=0,
                                        norm=norm, act=act, preact=True)

        if proj_type == 'linear':
            # Two 1x1 convolutions.
            self.feedforward = FusedMBConv(out_dim, out_dim, expansion=expansion,
                                           kernel_size=1, act=act, norm=norm)
        else:
            # Depthwise convolution variant.
            self.feedforward = MBConv(out_dim, out_dim, expansion=expansion,
                                      kernel_size=3, act=act, norm=norm, p=proj_drop)

    def forward(self, x, semantic_map):
        """Return the updated ``(feature map, semantic map)`` pair."""
        feat = self.norm1(x)
        mapp = self.norm2(semantic_map)

        out, mapp = self.attn(feat, mapp)

        # Residual connections use the un-normalised inputs.
        out += self.shortcut(x)
        out = self.feedforward(out)

        mapp += semantic_map

        return out, mapp


class PatchMerging(nn.Module):
    """
    Modified patch merging layer that works as down-sampling.

    Each 2x2 spatial neighbourhood is stacked along the channel axis
    (C -> 4C, spatial size halved), normalised and reduced to ``out_dim``
    channels. An optional 1x1 conv projects the semantic map from ``dim`` to
    ``out_dim`` channels as well.
    """

    def __init__(self, dim, out_dim, norm=nn.BatchNorm2d, proj_type='depthwise', map_proj=True):
        super().__init__()
        self.dim = dim
        if proj_type == 'linear':
            self.reduction = nn.Conv2d(4 * dim, out_dim, kernel_size=1, bias=False)
        else:
            self.reduction = DepthwiseSeparableConv(4 * dim, out_dim)

        self.norm = norm(4 * dim)

        # NOTE(review): with map_proj=False no `map_projection` attribute
        # exists, so forward() must then be called without a semantic map.
        if map_proj:
            self.map_projection = nn.Conv2d(dim, out_dim, kernel_size=1, bias=False)

    def forward(self, x, semantic_map=None):
        """
        x: B, C, H, W (H and W must be even for the four slices to match)
        semantic_map: optional map with ``dim`` channels

        Returns ``(x, semantic_map)``; ``semantic_map`` stays None if it was
        None on input.
        """
        x0 = x[:, :, 0::2, 0::2]
        x1 = x[:, :, 1::2, 0::2]
        x2 = x[:, :, 0::2, 1::2]
        x3 = x[:, :, 1::2, 1::2]
        x = torch.cat([x0, x1, x2, x3], 1)  # B, 4C, H/2, W/2

        x = self.norm(x)
        x = self.reduction(x)  # 4C -> out_dim

        if semantic_map is not None:
            semantic_map = self.map_projection(semantic_map)  # dim -> out_dim

        return x, semantic_map


class BasicLayer(nn.Module):
    """
    A basic transformer layer for one stage.

    No downsample or upsample operation happens in this layer; they are
    wrapped in the down_block or up_block of UTNet.
    """

    def __init__(self, feat_dim, map_dim, out_dim, num_blocks, heads=4, dim_head=64,
                 expansion=1, attn_drop=0., proj_drop=0., map_size=8,
                 proj_type='depthwise', norm=nn.BatchNorm2d, act=nn.GELU):
        super().__init__()

        # Only the first block may change the channel count
        # (feat_dim -> out_dim); the rest map out_dim -> out_dim.
        self.blocks = nn.ModuleList([
            BidirectionAttentionBlock(
                feat_dim if idx == 0 else out_dim, map_dim, out_dim, heads, dim_head,
                expansion=expansion, attn_drop=attn_drop, proj_drop=proj_drop,
                map_size=map_size, proj_type=proj_type, norm=norm, act=act,
            )
            for idx in range(num_blocks)
        ])

    def forward(self, x, semantic_map):
        """Apply the blocks in order; with zero blocks the inputs pass through."""
        for block in self.blocks:
            x, semantic_map = block(x, semantic_map)

        return x, semantic_map


class SemanticMapGeneration(nn.Module):
    """Pool a feature map into a fixed-size ``map_size x map_size`` semantic map.

    Args:
        feat_dim: channels of the input feature map.
        map_dim: channels of the generated semantic map.
        map_size: spatial side length of the generated semantic map.
    """

    def __init__(self, feat_dim, map_dim, map_size):
        super().__init__()

        self.map_size = map_size
        self.map_dim = map_dim
        self.map_code_num = map_size * map_size  # one code per map position

        self.base_proj = nn.Conv2d(feat_dim, map_dim, kernel_size=3, padding=1, bias=False)
        self.semantic_proj = nn.Conv2d(feat_dim, self.map_code_num, kernel_size=3, padding=1,
                                       bias=False)

    def forward(self, x):
        """Return a semantic map of shape (B, map_dim, map_size, map_size)."""
        B, C, H, W = x.shape
        feat = self.base_proj(x)  # B, map_dim, H, W
        weight_map = self.semantic_proj(x)  # B, map_code_num, H, W

        # For every code, a softmax over all spatial positions.
        weight_map = weight_map.view(B, self.map_code_num, -1)
        weight_map = F.softmax(weight_map, dim=2)  # B, map_code_num, HW
        feat = feat.view(B, self.map_dim, -1)  # B, map_dim, HW

        # Each code is the weighted sum of features over positions.
        semantic_map = torch.einsum('bij,bkj->bik', feat, weight_map)

        return semantic_map.view(B, self.map_dim, self.map_size, self.map_size)


class SemanticMapFusion(nn.Module):
    """Fuse several semantic maps with a shared transformer.

    Args:
        in_dim_list: channel count of each input map, in order.
        dim: common channel count used inside the fusion transformer.
        heads: number of heads passed to the TransformerBlock.
        depth: depth passed to the TransformerBlock.
        norm: accepted for interface compatibility; not used here.
    """

    def __init__(self, in_dim_list, dim, heads, depth=1, norm=nn.BatchNorm2d):
        super().__init__()

        self.dim = dim

        # project all maps to the same channel num
        self.in_proj = nn.ModuleList([
            nn.Conv2d(in_dim, dim, kernel_size=1, bias=False) for in_dim in in_dim_list
        ])

        self.fusion = TransformerBlock(dim, depth, heads, dim // heads, dim,
                                       attn_drop=0., proj_drop=0.)

        # project all maps back to their origin channel num
        self.out_proj = nn.ModuleList([
            nn.Conv2d(dim, in_dim, kernel_size=1, bias=False) for in_dim in in_dim_list
        ])

    def forward(self, map_list):
        """Fuse ``map_list`` and return a list of maps with the original channels.

        The spatial size of the first map is used for every map when
        reshaping back.
        """
        B, _, H, W = map_list[0].shape

        # Each map becomes a token sequence: B, L, C where L = HW.
        proj_maps = [
            self.in_proj[i](sem_map).view(B, self.dim, -1).permute(0, 2, 1)
            for i, sem_map in enumerate(map_list)
        ]

        # Concatenate along the token axis so attention spans all maps.
        proj_maps = torch.cat(proj_maps, dim=1)
        attned_maps = self.fusion(proj_maps)

        attned_maps = attned_maps.chunk(len(map_list), dim=1)

        maps_out = [
            self.out_proj[i](attned_maps[i].permute(0, 2, 1).view(B, self.dim, H, W))
            for i in range(len(map_list))
        ]

        return maps_out


#######################################################################
# UTNet blocks for one stage; each contains conv blocks and trans blocks


class inconv(nn.Module):
    """Input stem: a 3x3 convolution followed by one conv block."""

    def __init__(self, in_ch, out_ch, block=BasicBlock, norm=nn.BatchNorm2d, act=nn.GELU):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False)

        self.conv2 = block(out_ch, out_ch, norm=norm, act=act)

    def forward(self, x):
        """Apply the stem to ``x``.

        A 5-D input has its dimension 1 squeezed first (a no-op unless that
        dimension has size 1).
        """
        # `x.shape == 5` compared a torch.Size with an int and was never
        # true; the tensor rank is what has to be checked.
        if x.dim() == 5:
            x = x.squeeze(1)
        out = self.conv1(x)  # in_ch -> out_ch, spatial size unchanged
        out = self.conv2(out)  # conv block (conv, norm, act, residual)

        return out


class Down_block(nn.Module):
    """Encoder stage: patch merging, conv blocks, then transformer blocks.

    Args:
        in_ch: input feature channels.
        out_ch: output feature channels.
        conv_num: number of conv blocks.
        trans_num: number of bidirectional attention blocks.
        conv_block: class used for the conv blocks.
        heads, dim_head, expansion, attn_drop, proj_drop, map_size,
        proj_type, norm, act: forwarded to the sub-layers.
        map_generate: if True, a new semantic map is generated from the
            conv features, replacing any incoming one.
        map_proj: forwarded to PatchMerging.
        map_dim: semantic map channels; defaults to ``out_ch``.
    """

    def __init__(self, in_ch, out_ch, conv_num, trans_num, conv_block=BasicBlock, heads=4,
                 dim_head=64, expansion=4, attn_drop=0., proj_drop=0., map_size=8,
                 proj_type='depthwise', norm=nn.BatchNorm2d, act=nn.GELU,
                 map_generate=False, map_proj=True, map_dim=None):
        super().__init__()

        map_dim = out_ch if map_dim is None else map_dim
        self.map_generate = map_generate
        if map_generate:
            self.map_gen = SemanticMapGeneration(out_ch, map_dim, map_size)

        self.patch_merging = PatchMerging(in_ch, out_ch, proj_type=proj_type, norm=norm,
                                          map_proj=map_proj)

        self.conv_blocks = nn.Sequential(*[
            conv_block(out_ch, out_ch, norm=norm, act=act) for _ in range(conv_num)
        ])
        self.trans_blocks = BasicLayer(
            out_ch, map_dim, out_ch, num_blocks=trans_num,
            heads=heads, dim_head=dim_head, norm=norm, act=act, expansion=expansion,
            attn_drop=attn_drop, proj_drop=proj_drop, map_size=map_size,
            proj_type=proj_type,
        )

    def forward(self, x, semantic_map=None):
        """Return ``(features, semantic_map)`` for this stage."""
        x, semantic_map = self.patch_merging(x, semantic_map)  # in_ch -> out_ch

        out = self.conv_blocks(x)
        if self.map_generate:
            semantic_map = self.map_gen(out)  # B, map_dim, map_size, map_size

        out, semantic_map = self.trans_blocks(out, semantic_map)

        return out, semantic_map


class Up_block(nn.Module):
    """Decoder stage: upsample, fuse with the skip feature, attend, convolve.

    Args:
        in_ch: channels of the low-resolution input feature.
        out_ch: channels of the skip feature and of the output.
        conv_num: number of conv blocks.
        trans_num: number of bidirectional attention blocks.
        conv_block: class used for the conv blocks.
        heads, dim_head, expansion, attn_drop, proj_drop, map_size,
        proj_type, norm, act: forwarded to the sub-layers.
        map_dim: semantic map channels; defaults to ``out_ch``.
        map_shortcut: if True, the map reduction expects the concatenation
            of both semantic maps (``in_ch + out_ch`` channels).
    """

    def __init__(self, in_ch, out_ch, conv_num, trans_num, conv_block=BasicBlock, heads=4,
                 dim_head=64, expansion=1, attn_drop=0., proj_drop=0., map_size=8,
                 proj_type='linear', norm=nn.BatchNorm2d, act=nn.GELU, map_dim=None,
                 map_shortcut=False):
        super().__init__()

        self.reduction = nn.Conv2d(in_ch + out_ch, out_ch, kernel_size=1, padding=0, bias=False)
        self.norm = norm(in_ch + out_ch)
        self.map_shortcut = map_shortcut

        map_dim = out_ch if map_dim is None else map_dim
        if map_shortcut:
            self.map_reduction = nn.Conv2d(in_ch + out_ch, map_dim, kernel_size=1, bias=False)
        else:
            self.map_reduction = nn.Conv2d(in_ch, map_dim, kernel_size=1, bias=False)

        self.trans_blocks = BasicLayer(
            out_ch, map_dim, out_ch, num_blocks=trans_num,
            heads=heads, dim_head=dim_head, norm=norm, act=act, expansion=expansion,
            attn_drop=attn_drop, proj_drop=proj_drop, map_size=map_size,
            proj_type=proj_type,
        )

        self.conv_blocks = nn.Sequential(*[
            conv_block(out_ch, out_ch, norm=norm, act=act) for _ in range(conv_num)
        ])

    def forward(self, x1, x2, map1, map2=None):
        """Return ``(features, semantic_map)`` for this stage.

        Args:
            x1: low-res feature.
            x2: high-res feature from the skip connection.
            map1: semantic map from the previous low-res layer.
            map2: semantic map from the encoder shortcut path; may be None
                if we don't have the map from the encoder.
        """
        x1 = F.interpolate(x1, size=x2.shape[-2:], mode='bilinear', align_corners=True)
        feat = torch.cat([x1, x2], dim=1)

        out = self.reduction(self.norm(feat))

        # NOTE(review): with map_shortcut=True and map2=None, map1 alone is
        # fed to a reduction built for in_ch + out_ch input channels.
        if self.map_shortcut and map2 is not None:
            semantic_map = torch.cat([map1, map2], dim=1)
        else:
            semantic_map = map1
        semantic_map = self.map_reduction(semantic_map)

        out, semantic_map = self.trans_blocks(out, semantic_map)
        out = self.conv_blocks(out)

        return out, semantic_map