Spaces:
Runtime error
Runtime error
| # -*- coding: utf-8 -*- | |
| # @Time : 2024/7/26 上午11:19 | |
| # @Author : xiaoshun | |
| # @Email : 3038523973@qq.com | |
| # @File : dbnet.py | |
| # @Software: PyCharm | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from einops import rearrange | |
| # from models.Transformer.ViT import truncated_normal_ | |
| # Decoder细化卷积模块 | |
class SBR(nn.Module):
    """Decoder refinement block: 1x3 conv, then 3x1 conv, plus a skip connection.

    Both convolutions keep the channel count and (via padding) the spatial
    size, and each is followed by batch normalisation and an in-place ReLU.

    Args:
        in_ch: Number of input channels; the output has the same number.
    """

    def __init__(self, in_ch):
        super().__init__()

        def _conv_bn_relu(kernel, pad):
            # One conv -> BN -> ReLU unit with an asymmetric kernel.
            return nn.Sequential(
                nn.Conv2d(in_ch, in_ch, kernel_size=kernel, stride=1, padding=pad),
                nn.BatchNorm2d(in_ch),
                nn.ReLU(inplace=True),
            )

        self.conv1x3 = _conv_bn_relu((1, 3), (0, 1))
        self.conv3x1 = _conv_bn_relu((3, 1), (1, 0))

    def forward(self, x):
        """Apply the 1x3 conv, then the 3x1 conv, and add the input back."""
        refined = self.conv1x3(x)
        refined = self.conv3x1(refined)
        return refined + x
| # 下采样卷积模块 stage 1,2,3 | |
class c_stage123(nn.Module):
    """2x down-sampling residual CNN block with two 3x3 convolutions.

    Main path: 3x3 conv (stride 2) -> BN -> ReLU -> 3x3 conv -> BN -> ReLU.
    Shortcut:  3x3 max-pool (stride 2) -> 1x1 conv to match the channel count.
    The two paths are summed.

    Args:
        in_chans: Channels of the input feature map.
        out_chans: Channels of the output feature map.
    """

    def __init__(self, in_chans, out_chans):
        super().__init__()
        main_path = [
            nn.Conv2d(in_chans, out_chans, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(out_chans),
            nn.ReLU(),
            nn.Conv2d(out_chans, out_chans, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_chans),
            nn.ReLU(),
        ]
        self.stage123 = nn.Sequential(*main_path)
        self.conv1x1_123 = nn.Conv2d(in_chans, out_chans, kernel_size=1)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        """Return the sum of the convolutional path and the pooled shortcut."""
        features = self.stage123(x)
        # Shortcut: halve the resolution first, then project the channels.
        shortcut = self.conv1x1_123(self.maxpool(x))
        return features + shortcut
| # 下采样卷积模块 stage4,5 | |
class c_stage45(nn.Module):
    """2x down-sampling residual CNN block with three 3x3 convolutions.

    Main path: three conv -> BN -> ReLU units; only the first has stride 2.
    Shortcut:  3x3 max-pool (stride 2) -> 1x1 conv to match the channel count.
    The two paths are summed.

    Args:
        in_chans: Channels of the input feature map.
        out_chans: Channels of the output feature map.
    """

    def __init__(self, in_chans, out_chans):
        super().__init__()
        units = []
        src_chans, stride = in_chans, 2
        for _ in range(3):
            units += [
                nn.Conv2d(src_chans, out_chans, kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(out_chans),
                nn.ReLU(),
            ]
            # Every unit after the first keeps both resolution and width.
            src_chans, stride = out_chans, 1
        self.stage45 = nn.Sequential(*units)
        self.conv1x1_45 = nn.Conv2d(in_chans, out_chans, kernel_size=1)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        """Return the sum of the convolutional path and the pooled shortcut."""
        features = self.stage45(x)
        # Shortcut: halve the resolution first, then project the channels.
        shortcut = self.conv1x1_45(self.maxpool(x))
        return features + shortcut
class Identity(nn.Module):
    """Pass-through module: ``forward`` returns its argument unchanged."""

    def forward(self, x):
        return x
| # 轻量卷积模块 | |
class DepthwiseConv2d(nn.Module):
    """Depthwise-separable convolution: depthwise conv -> BatchNorm -> 1x1 pointwise conv.

    Args:
        in_chans: Channels of the input feature map.
        out_chans: Channels produced by the pointwise convolution.
        kernel_size: Kernel size of the depthwise convolution.
        stride: Stride of the depthwise convolution.
        padding: Padding of the depthwise convolution.
        dilation: Dilation of the depthwise convolution.
    """

    def __init__(self, in_chans, out_chans, kernel_size=1, stride=1, padding=0, dilation=1):
        super().__init__()
        # groups == in_chans: each input channel is filtered independently.
        self.depthwise = nn.Conv2d(
            in_chans,
            in_chans,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_chans,
        )
        self.bn = nn.BatchNorm2d(in_chans)
        # 1x1 convolution that mixes channels and sets the output width.
        self.pointwise = nn.Conv2d(in_chans, out_chans, kernel_size=1)

    def forward(self, x):
        """Run the depthwise conv, normalise, then project with the pointwise conv."""
        return self.pointwise(self.bn(self.depthwise(x)))
| # residual skip connection 残差跳跃连接 | |
class Residual(nn.Module):
    """Wrap a callable with a residual (skip) connection.

    Args:
        fn: Callable applied to the input; its result is added to the input.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, input, **kwargs):
        """Return ``fn(input, **kwargs) + input``."""
        return self.fn(input, **kwargs) + input
| # layer norm plus 层归一化 | |
class PreNorm(nn.Module):
    """Apply ``LayerNorm`` over the last dimension before calling a wrapped callable.

    Args:
        dim: Size of the last dimension, passed to ``nn.LayerNorm``.
        fn: Callable invoked on the normalised input.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, input, **kwargs):
        """Normalise ``input`` and forward it, with any keyword arguments, to ``fn``."""
        normed = self.norm(input)
        return self.fn(normed, **kwargs)
| # FeedForward层使得representation的表达能力更强 | |
class FeedForward(nn.Module):
    """Two-layer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: Input and output feature size.
        hidden_dim: Width of the hidden layer.
        dropout: Dropout probability used after the activation and after the
            second linear layer.
    """

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        expand = nn.Linear(dim, hidden_dim)
        project = nn.Linear(hidden_dim, dim)
        self.net = nn.Sequential(
            expand,
            nn.GELU(),
            nn.Dropout(dropout),
            project,
            nn.Dropout(dropout),
        )

    def forward(self, input):
        """Run the MLP on ``input``."""
        return self.net(input)
class ConvAttnetion(nn.Module):
    """Multi-head self-attention with depthwise-separable convolutional q/k/v projections.

    Instead of the linear q/k/v projection used in ViT, the token sequence is
    reshaped back into a square ``img_size x img_size`` feature map and
    projected with three ``DepthwiseConv2d`` layers.

    Args:
        dim: Channel size of the incoming (and outgoing) tokens.
        img_size: Side length of the square token grid; the patch tokens must
            number ``img_size * img_size``.
        heads: Number of attention heads.
        dim_head: Channels per attention head.
        kernel_size: Kernel size of the q/k/v depthwise convolutions.
        q_stride: Stride of the query convolution.
        k_stride: Stride of the key convolution.
        v_stride: Stride of the value convolution.
        dropout: Dropout probability after the output projection.
        last_stage: If True, the first token is treated as a class token that
            bypasses the convolutions and is concatenated back onto q, k and v.
            NOTE(review): that concatenation only lines up when
            ``dim == heads * dim_head`` - confirm before enabling.
    """

    def __init__(self, dim, img_size, heads=8, dim_head=64, kernel_size=3, q_stride=1, k_stride=1, v_stride=1,
                 dropout=0., last_stage=False):
        super().__init__()
        self.last_stage = last_stage
        self.img_size = img_size
        inner_dim = dim_head * heads
        # With a single head of width ``dim`` the projection would be redundant.
        project_out = not (heads == 1 and dim_head == dim)
        self.heads = heads
        self.scale = dim_head ** (-0.5)
        # The same padding (derived from the query stride) is shared by q, k and v.
        pad = (kernel_size - q_stride) // 2
        self.to_q = DepthwiseConv2d(in_chans=dim, out_chans=inner_dim, kernel_size=kernel_size, stride=q_stride,
                                    padding=pad)
        self.to_k = DepthwiseConv2d(in_chans=dim, out_chans=inner_dim, kernel_size=kernel_size, stride=k_stride,
                                    padding=pad)
        self.to_v = DepthwiseConv2d(in_chans=dim, out_chans=inner_dim, kernel_size=kernel_size, stride=v_stride,
                                    padding=pad)
        self.to_out = nn.Sequential(
            nn.Linear(in_features=inner_dim, out_features=dim),
            nn.Dropout(dropout)
        ) if project_out else Identity()

    def forward(self, x):
        """Attend over a token sequence.

        Args:
            x: Tokens of shape ``(batch, tokens, dim)``.

        Returns:
            Tensor of shape ``(batch, query_tokens, dim)``.
        """
        heads = self.heads

        cls_token = None
        if self.last_stage:
            # Split off the class token and give it the per-head layout of q/k/v.
            cls_token = rearrange(torch.unsqueeze(x[:, 0], dim=1), 'b n (h d) -> b h n d', h=heads)
            x = x[:, 1:]

        # (batch, l*w, channels) -> (batch, channels, l, w) so the convolutions can run.
        x = rearrange(x, 'b (l w) n -> b n l w', l=self.img_size, w=self.img_size)

        # Each projection: (batch, heads*d, l, w) -> (batch, heads, l*w, d).
        q = rearrange(self.to_q(x), 'b (h d) l w -> b h (l w) d', h=heads)
        k = rearrange(self.to_k(x), 'b (h d) l w -> b h (l w) d', h=heads)
        v = rearrange(self.to_v(x), 'b (h d) l w -> b h (l w) d', h=heads)

        if cls_token is not None:
            q = torch.cat([cls_token, q], dim=2)
            v = torch.cat([cls_token, v], dim=2)
            k = torch.cat([cls_token, k], dim=2)

        # Scaled dot-product attention: softmax(q @ k^T * scale) @ v.
        attention = q.matmul(k.permute(0, 1, 3, 2)) * self.scale
        attention = F.softmax(attention, dim=-1)

        # (batch, heads, tokens, d) -> (batch, tokens, heads*d). Merging the head
        # axes (rather than reshaping to the INPUT channel count) keeps the layout
        # valid for ``to_out`` even when heads * dim_head != dim.
        out = attention.matmul(v).permute(0, 2, 1, 3).flatten(2)

        return self.to_out(out)
| # Reshape Layers | |
class Rearrange(nn.Module):
    """Reshape layer converting between feature maps and token sequences.

    Two pattern strings are supported:

    * ``'b c h w -> b (h w) c'``: feature map to token sequence.
    * ``'b (h w) c -> b c h w'``: token sequence to feature map.

    Args:
        string: One of the two pattern strings above.
        h: Height of the token grid.
        w: Width of the token grid.

    Raises:
        ValueError: From ``forward`` if ``string`` is not a supported pattern
            or the input has the wrong number of dimensions.
    """

    _TO_TOKENS = 'b c h w -> b (h w) c'
    _TO_MAP = 'b (h w) c -> b c h w'

    def __init__(self, string, h, w):
        super().__init__()
        self.string = string
        self.h = h
        self.w = w

    def forward(self, input):
        """Apply the configured rearrangement to ``input``."""
        if self.string == self._TO_TOKENS:
            if input.dim() != 4:
                raise ValueError(f"pattern {self.string!r} expects a 4-D input, got {input.dim()}-D")
            batch = input.shape[0]
            # Flatten the spatial axes to h*w, then move channels last.
            return torch.reshape(input, shape=(batch, -1, self.h * self.w)).permute(0, 2, 1)
        if self.string == self._TO_MAP:
            if input.dim() != 3:
                raise ValueError(f"pattern {self.string!r} expects a 3-D input, got {input.dim()}-D")
            batch = input.shape[0]
            # Unfold the token axis into (h, w), then move channels first.
            return torch.reshape(input, shape=(batch, self.h, self.w, -1)).permute(0, 3, 1, 2)
        # Previously an unknown pattern fell through to an unbound local variable.
        raise ValueError(f"unsupported rearrange pattern: {self.string!r}")
| # Transformer layers | |
class Transformer(nn.Module):
    """Stack of pre-norm transformer blocks (convolutional attention + MLP).

    Each block is a pair ``[PreNorm(ConvAttnetion), PreNorm(FeedForward)]``;
    both sub-layers are applied with a residual connection.

    Args:
        dim: Token channel size.
        img_size: Side length of the square token grid, forwarded to the attention.
        depth: Number of blocks.
        heads: Number of attention heads.
        dim_head: Channels per attention head.
        mlp_dim: Hidden width of the feed-forward layer.
        dropout: Dropout probability forwarded to attention and feed-forward.
        last_stage: Forwarded to ``ConvAttnetion``.
    """

    def __init__(self, dim, img_size, depth, heads, dim_head, mlp_dim, dropout=0., last_stage=False):
        super().__init__()
        blocks = []
        for _ in range(depth):
            attention = PreNorm(
                dim=dim,
                fn=ConvAttnetion(dim, img_size, heads=heads, dim_head=dim_head, dropout=dropout,
                                 last_stage=last_stage),
            )
            mlp = PreNorm(dim=dim, fn=FeedForward(dim=dim, hidden_dim=mlp_dim, dropout=dropout))
            blocks.append(nn.ModuleList([attention, mlp]))
        # Nested ModuleLists so that every sub-layer's parameters are registered.
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        """Run every block in order, each sub-layer with a residual connection."""
        for block in self.layers:
            attention, mlp = block
            x = x + attention(x)
            x = x + mlp(x)
        return x
class DBNet(nn.Module):
    """Dual-branch (Transformer + CNN) encoder/decoder segmentation network.

    The encoder runs a four-stage convolutional-attention Transformer branch
    and a five-stage CNN branch side by side and fuses them after every stage
    (``CTmerge1`` .. ``CTmerge4``). The decoder upsamples through four
    depthwise-separable stages with skip connections, and a 1x1 convolution
    produces per-pixel class logits at the input resolution.

    NOTE: the channel widths of the CNN branch, the merge blocks and the
    decoder are hard-coded and only line up with the Transformer branch for
    the default ``dim=64`` and ``heads=(1, 3, 6, 6)``.

    Args:
        img_size: Side length of the (square) input image. The token grids are
            sized ``img_size // 4``, ``// 8``, ``// 16`` and ``// 32``.
        in_channels: Number of channels of the input image.
        num_classes: Number of output channels of the final 1x1 convolution.
        dim: Embedding size of the first Transformer stage; also used as the
            per-head width in every stage.
        kernels: Kernel size of the embedding convolution of each stage.
        strides: Stride of the embedding convolution of each stage.
        heads: Attention heads per stage; the embedding size is multiplied by
            ``heads[i] // heads[i - 1]`` from one stage to the next.
        depth: Number of Transformer blocks per stage.
        pool: ``'cls'`` or ``'mean'``; validated and stored, not used in ``forward``.
        dropout: Dropout probability inside the Transformer blocks.
        emb_dropout: Accepted for interface compatibility; unused.
        scale_dim: Multiplier of the embedding size giving the MLP hidden width.

    Raises:
        ValueError: If ``pool`` is neither ``'cls'`` nor ``'mean'``.
    """

    def __init__(self, img_size, in_channels, num_classes, dim=64, kernels=(7, 3, 3, 3), strides=(4, 2, 2, 2),
                 heads=(1, 3, 6, 6),
                 depth=(1, 2, 10, 10), pool='cls', dropout=0., emb_dropout=0., scale_dim=4, ):
        super().__init__()
        if pool not in ('cls', 'mean'):
            raise ValueError('pool type must be either cls or mean pooling')
        self.pool = pool
        self.dim = dim
        # ``in_channels`` is reused below as the running stage width, so keep
        # the image channel count for the CNN branch.
        image_channels = in_channels

        ### Transformer branch ###
        # stage1: conv embedding with stride 4, then Transformer blocks
        grid = img_size // 4
        self.stage1_conv_embed = self._conv_embed(in_channels, dim, kernels[0], strides[0], 2, grid)
        self.stage1_transformer = self._transformer_stage(
            dim, grid, depth[0], heads[0], self.dim, dim * scale_dim, dropout)

        # stage2
        in_channels = dim
        dim = (heads[1] // heads[0]) * dim
        grid = img_size // 8
        self.stage2_conv_embed = self._conv_embed(in_channels, dim, kernels[1], strides[1], 1, grid)
        self.stage2_transformer = self._transformer_stage(
            dim, grid, depth[1], heads[1], self.dim, dim * scale_dim, dropout)

        # stage3
        in_channels = dim
        dim = (heads[2] // heads[1]) * dim
        grid = img_size // 16
        self.stage3_conv_embed = self._conv_embed(in_channels, dim, kernels[2], strides[2], 1, grid)
        self.stage3_transformer = self._transformer_stage(
            dim, grid, depth[2], heads[2], self.dim, dim * scale_dim, dropout)

        # stage4
        in_channels = dim
        dim = (heads[3] // heads[2]) * dim
        grid = img_size // 32
        self.stage4_conv_embed = self._conv_embed(in_channels, dim, kernels[3], strides[3], 1, grid)
        self.stage4_transformer = self._transformer_stage(
            dim, grid, depth[3], heads[3], self.dim, dim * scale_dim, dropout)

        ### CNN branch ###
        # The first stage consumes the raw image, so it must honour the
        # caller's channel count (it was previously fixed at 3).
        self.c_stage1 = c_stage123(in_chans=image_channels, out_chans=64)
        self.c_stage2 = c_stage123(in_chans=64, out_chans=128)
        self.c_stage3 = c_stage123(in_chans=128, out_chans=384)
        self.c_stage4 = c_stage45(in_chans=384, out_chans=512)
        self.c_stage5 = c_stage45(in_chans=512, out_chans=1024)
        self.c_max = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Not used in forward(); kept so existing checkpoints still load.
        self.up_conv1 = nn.Conv2d(in_channels=192, out_channels=128, kernel_size=1)
        self.up_conv2 = nn.Conv2d(in_channels=384, out_channels=512, kernel_size=1)

        ### CTmerge: fuse the concatenated Transformer and CNN features ###
        self.CTmerge1 = self._conv_bn_relu_stack(128, 64, 64)
        self.CTmerge2 = self._conv_bn_relu_stack(320, 128, 128)
        self.CTmerge3 = self._conv_bn_relu_stack(768, 512, 384, 384)
        self.CTmerge4 = self._conv_bn_relu_stack(896, 640, 512, 512)

        ### decoder ###
        self.decoder4 = self._decoder_stage(1408, 1024, 512)
        self.decoder3 = self._decoder_stage(896, 512, 384)
        self.decoder2 = self._decoder_stage(576, 256, 192)
        self.decoder1 = self._decoder_stage(256, 64, 16)
        self.sbr4 = SBR(512)
        self.sbr3 = SBR(384)
        self.sbr2 = SBR(192)
        # Not used in forward(); kept so existing checkpoints still load.
        self.sbr1 = SBR(16)
        self.head = nn.Conv2d(in_channels=16, out_channels=num_classes, kernel_size=1)

    @staticmethod
    def _conv_embed(in_ch, out_ch, kernel, stride, padding, grid):
        """Build a patch embedding: strided conv -> flatten to tokens -> LayerNorm."""
        return nn.Sequential(
            nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=kernel, stride=stride, padding=padding),
            Rearrange('b c h w -> b (h w) c', h=grid, w=grid),
            nn.LayerNorm(out_ch)
        )

    @staticmethod
    def _transformer_stage(dim, grid, depth, heads, dim_head, mlp_dim, dropout):
        """Build Transformer blocks followed by a reshape back to a feature map."""
        return nn.Sequential(
            Transformer(dim=dim, img_size=grid, depth=depth, heads=heads, dim_head=dim_head, mlp_dim=mlp_dim,
                        dropout=dropout),
            Rearrange('b (h w) c -> b c h w', h=grid, w=grid)
        )

    @staticmethod
    def _conv_bn_relu_stack(*channels):
        """Build 3x3 conv -> BN -> ReLU units between consecutive channel widths."""
        layers = []
        for in_ch, out_ch in zip(channels, channels[1:]):
            layers += [
                nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
            ]
        return nn.Sequential(*layers)

    @staticmethod
    def _decoder_stage(in_ch, mid_ch, out_ch):
        """Build a decoder stage: two 3x3 depthwise-separable convs, then GELU."""
        return nn.Sequential(
            DepthwiseConv2d(in_chans=in_ch, out_chans=mid_ch, kernel_size=3, stride=1, padding=1),
            DepthwiseConv2d(in_chans=mid_ch, out_chans=out_ch, kernel_size=3, stride=1, padding=1),
            nn.GELU()
        )

    @staticmethod
    def _resize_like(x, ref):
        """Bilinearly resize ``x`` to the spatial size of ``ref``."""
        return F.interpolate(x, size=ref.size()[2:], mode='bilinear', align_corners=True)

    def forward(self, input):
        """Compute segmentation logits.

        Shape comments assume a ``1 x 3 x 224 x 224`` input and the default
        constructor arguments.

        Args:
            input: Image batch of shape ``(batch, in_channels, img_size, img_size)``.

        Returns:
            Logits of shape ``(batch, num_classes, H, W)`` at the input's spatial size.
        """
        ### encoder ###
        # stage1: Transformer features concatenated with pooled CNN features
        t_s1 = self.stage1_conv_embed(input)  # 1*3*224*224 --> 1*3136*64
        t_s1 = self.stage1_transformer(t_s1)  # 1*3136*64 --> 1*64*56*56
        c_s1 = self.c_stage1(input)  # 1*3*224*224 --> 1*64*112*112
        stage1 = self.CTmerge1(torch.cat([t_s1, self.c_max(c_s1)], dim=1))  # 1*64*56*56

        # stage2: upsampled Transformer features concatenated with CNN features
        t_s2 = self.stage2_conv_embed(stage1)  # 1*64*56*56 --> 1*784*192
        t_s2 = self.stage2_transformer(t_s2)  # 1*784*192 --> 1*192*28*28
        c_s2 = self.c_stage2(c_s1)  # 1*64*112*112 --> 1*128*56*56
        stage2 = self.CTmerge2(torch.cat([c_s2, self._resize_like(t_s2, c_s2)], dim=1))  # 1*128*56*56

        # stage3: Transformer features concatenated with pooled CNN features
        t_s3 = self.stage3_conv_embed(t_s2)  # 1*192*28*28 --> 1*196*384
        t_s3 = self.stage3_transformer(t_s3)  # 1*196*384 --> 1*384*14*14
        c_s3 = self.c_stage3(stage2)  # 1*128*56*56 --> 1*384*28*28
        stage3 = self.CTmerge3(torch.cat([t_s3, self.c_max(c_s3)], dim=1))  # 1*384*14*14

        # stage4: upsampled Transformer features concatenated with CNN features
        t_s4 = self.stage4_conv_embed(stage3)  # 1*384*14*14 --> 1*49*384
        t_s4 = self.stage4_transformer(t_s4)  # 1*49*384 --> 1*384*7*7
        c_s4 = self.c_stage4(c_s3)  # 1*384*28*28 --> 1*512*14*14
        stage4 = self.CTmerge4(torch.cat([c_s4, self._resize_like(t_s4, c_s4)], dim=1))  # 1*512*14*14

        # CNN-only stage 5
        c_s5 = self.c_stage5(stage4)  # 1*512*14*14 --> 1*1024*7*7

        ### decoder ###
        decoder4 = torch.cat([c_s5, t_s4], dim=1)  # 1*1408*7*7
        decoder4 = self.decoder4(decoder4)  # 1*512*7*7
        decoder4 = self._resize_like(decoder4, c_s3)  # 1*512*28*28
        decoder4 = self.sbr4(decoder4)  # 1*512*28*28

        decoder3 = torch.cat([decoder4, c_s3], dim=1)  # 1*896*28*28
        decoder3 = self.decoder3(decoder3)  # 1*384*28*28
        decoder3 = self._resize_like(decoder3, t_s2)  # 1*384*28*28
        decoder3 = self.sbr3(decoder3)  # 1*384*28*28

        decoder2 = torch.cat([decoder3, t_s2], dim=1)  # 1*576*28*28
        decoder2 = self.decoder2(decoder2)  # 1*192*28*28
        decoder2 = self._resize_like(decoder2, c_s1)  # 1*192*112*112
        decoder2 = self.sbr2(decoder2)  # 1*192*112*112

        decoder1 = torch.cat([decoder2, c_s1], dim=1)  # 1*256*112*112
        decoder1 = self.decoder1(decoder1)  # 1*16*112*112

        final = self._resize_like(decoder1, input)  # 1*16*224*224
        final = self.head(final)  # 1*num_classes*224*224
        return final
if __name__ == '__main__':
    # Smoke test: run one forward pass and print the output shape.
    # Fall back to the CPU so the script also runs on hosts without a GPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    x = torch.rand(1, 3, 224, 224).to(device)
    model = DBNet(img_size=224, in_channels=3, num_classes=7).to(device)
    # No gradients are needed for a shape check.
    with torch.no_grad():
        y = model(x)
    print(y.shape)
    # torch.Size([1, 7, 224, 224])