import torch import torch.nn as nn import torch.nn.functional as F from typing import Tuple def rand(size, val=0.01): out = torch.zeros(size) nn.init.uniform_(out, -val, val) return out # from medsam def window_partition(x: torch.Tensor, window_size: int): B, C, H, W = x.size() pad_h = (window_size - H % window_size) % window_size pad_w = (window_size - W % window_size) % window_size if pad_h > 0 or pad_w > 0: x = F.pad(x, (0, pad_w, 0, pad_h)) Hp, Wp = H + pad_h, W + pad_w x = x.view(B, C, Hp // window_size, window_size, Wp // window_size, window_size) windows = x.permute(0, 2, 4, 1, 3, 5).contiguous().view(-1, C, window_size, window_size) return windows, (Hp, Wp), (Hp // window_size, Wp // window_size) def prompt_partition(prompt: torch.Tensor, h_windows: int, w_windows: int): # prompt: B, C, H, W B, C, H, W = prompt.size() prompt = prompt.view(B, 1, 1, C, H, W) prompt = prompt.repeat((1, h_windows, w_windows, 1, 1, 1)).contiguous().view(-1, C, H, W) return prompt def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]): # windows: B * Hp // window_size * Wp // window_size, C, window_size, window_size Hp, Wp = pad_hw H, W = hw B = (windows.shape[0] * window_size * window_size) // (Hp * Wp) # 0 1 2 3 4 5 x = windows.view(B, Hp // window_size, Wp // window_size, -1, window_size, window_size) x = x.permute(0, 3, 1, 4, 2, 5).contiguous().view(B, -1, Hp, Wp) if Hp > H or Wp > W: x = x[:, :, :H, :W].contiguous() return x class GELU(nn.Module): def __init__(self): super().__init__() def forward(self, x): cdf = 0.5 * (1 + torch.erf(x / 2**0.5)) return x * cdf class OneLayerRes(nn.Module): def __init__(self, in_features, out_features, kernel_size, padding) -> None: super().__init__() self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, padding=padding) self.weight = nn.Parameter(torch.zeros(1), requires_grad=True) def forward(self, x): x = x + self.weight * self.conv(x) return x class MLP(nn.Module): def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=GELU, drop=0.2): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features) self.act = act_layer() self.fc2 = nn.Linear(hidden_features, out_features) self.drop = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) return x class MultiHeadSelfAttention(nn.Module): def __init__(self, dim, num_heads=8, drop_rate=0.2): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads self.norm = nn.LayerNorm(dim) self.scale = head_dim ** -0.5 self.qkv = nn.Linear(dim, dim * 3, bias=False) self.drop = nn.Dropout(drop_rate) self.proj = nn.Linear(dim, dim) def forward(self, x, heat=False): B, N, C = x.shape out = self.norm(x) qkv = self.qkv(out).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] attn = (q @ k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) attn = self.drop(attn) out = (attn @ v).transpose(1, 2).reshape(B, N, C) out = self.proj(out) out = self.drop(out) out = x + out if heat: return out, attn return out class MultiHeadAttention2D_POS(nn.Module): def __init__(self, dim_q, dim_k, dim_v, embed_dim, num_heads=8, drop_rate=0.2, embed_dim_ratio=4, stride=1, slide=0): super().__init__() self.stride = stride self.num_heads = num_heads self.slide = slide self.embed_dim_qk = embed_dim // embed_dim_ratio if self.embed_dim_qk % num_heads != 0: self.embed_dim_qk = (self.embed_dim_qk // num_heads + 1) * num_heads self.embed_dim_v = embed_dim if self.embed_dim_v % num_heads != 0: self.embed_dim_v = (self.embed_dim_v // num_heads + 1) * num_heads head_dim = self.embed_dim_qk // num_heads self.scale = head_dim ** -0.5 self.conv_q = nn.Conv2d(in_channels=dim_q, out_channels=self.embed_dim_qk, kernel_size=stride, padding=0, stride=stride) self.conv_k = nn.Conv2d(in_channels=dim_k, out_channels=self.embed_dim_qk, kernel_size=stride, padding=0, stride=stride) self.conv_v = nn.Conv2d(in_channels=dim_v, out_channels=self.embed_dim_v, kernel_size=stride, padding=0, stride=stride) self.drop = nn.Dropout(drop_rate) self.proj_out = nn.Conv2d(in_channels=self.embed_dim_v, out_channels=dim_q, kernel_size=3, padding=1) if self.stride > 1: self.upsample = nn.Upsample(scale_factor=stride, mode='bilinear') else: self.upsample = nn.Identity() self.gamma = nn.Parameter(torch.zeros(1), requires_grad=True) def forward(self, q, k, v, heat=False): B, _, H_q, W_q = q.size() _, _, H_kv, W_kv = k.size() H_q = H_q // self.stride W_q = W_q // self.stride H_kv = H_kv // self.stride W_kv = W_kv // self.stride proj_q = self.conv_q(q).reshape(B, self.num_heads, self.embed_dim_qk // self.num_heads, H_q * W_q).permute(0, 1, 3, 2).contiguous() proj_k = self.conv_k(k).reshape(B, self.num_heads, self.embed_dim_qk // self.num_heads, H_kv * W_kv).permute(0, 1, 3, 2).contiguous() proj_v = self.conv_v(v).reshape(B, self.num_heads, self.embed_dim_v // self.num_heads, H_kv * W_kv).permute(0, 1, 3, 2).contiguous() attn = (proj_q @ proj_k.transpose(-2, -1)).contiguous() * self.scale # B, self.num_heads, H_q * W_q, H_kv * W_kv attn = attn.softmax(dim=-1) attn = self.drop(attn) out = (attn @ proj_v) # B, self.num_heads, H_q * W_q, self.embed_dim // self.num_heads out = out.transpose(2, 3).contiguous().reshape(B, self.embed_dim_v, H_q, W_q) if self.slide > 0: out = out[:, :, self.slide // self.stride:] q = q[:, :, self.slide:] out = self.proj_out(out) out = self.upsample(out) out = self.drop(out) out = q + out * self.gamma return out class MultiHeadAttention2D_CHA(nn.Module): def __init__(self, dim_q, dim_kv, stride, num_heads=8, drop_rate=0.2, slide=0): super().__init__() self.num_heads = num_heads self.stride = stride self.slide = slide self.dim_q_out = dim_q - slide self.conv_q = nn.Conv2d(in_channels=dim_q, out_channels=dim_q * num_heads, kernel_size=stride, stride=stride, groups=dim_q) self.conv_k = nn.Conv2d(in_channels=dim_kv, out_channels=dim_kv * num_heads, kernel_size=stride, stride=stride, groups=dim_kv) self.conv_v = nn.Conv2d(in_channels=dim_kv, out_channels=dim_kv * num_heads, kernel_size=stride, stride=stride, groups=dim_kv) self.drop = nn.Dropout(drop_rate) self.proj_out = nn.ConvTranspose2d(in_channels=self.dim_q_out * num_heads, out_channels=self.dim_q_out, kernel_size=stride, stride=stride, groups=self.dim_q_out) self.gamma = nn.Parameter(torch.zeros(1), requires_grad=True) def forward(self, q, k, v, heat=False): B, C_q, H_q, W_q = q.size() _, C_kv, H_kv, W_kv = k.size() proj_q = self.conv_q(q).reshape(B, self.num_heads, C_q, -1) # batch_size * num_heads * dim_q * (H * W) proj_k = self.conv_k(k).reshape(B, self.num_heads, C_kv, -1) proj_v = self.conv_v(v).reshape(B, self.num_heads, C_kv, -1) # batch_size * num_heads * dim_kv * (H * W) scale = proj_q.size(3) ** -0.5 attn = (proj_q @ proj_k.transpose(-2, -1)).contiguous() * scale # batch_size, num_heads, dim_q, dim_kv attn = attn.softmax(dim=-1) attn = self.drop(attn) out = (attn @ proj_v) # batch_size, num_heads, dim_q, (H * W) if self.slide > 0: out = out[:, :, :-self.slide] out = out.reshape(B, self.num_heads * self.dim_q_out, H_q // self.stride, W_q // self.stride) out = self.proj_out(out) out = self.drop(out) out = q + out * self.gamma return out class MultiHeadAttention2D_Dual2_2(nn.Module): def __init__(self, dim_pos, dim_cha, embed_dim, att_fusion, num_heads=8, drop_rate=0.2, embed_dim_ratio=4, stride=1, cha_slide=0, pos_slide=0, use_conv=True): super().__init__() self.pos_att = MultiHeadAttention2D_POS(dim_q=dim_pos, dim_k=dim_pos, dim_v=dim_pos, embed_dim=embed_dim, num_heads=num_heads, drop_rate=drop_rate, embed_dim_ratio=embed_dim_ratio, stride=stride, slide=pos_slide) self.cha_att = MultiHeadAttention2D_CHA(dim_q=dim_cha, dim_kv=dim_cha, num_heads=num_heads, drop_rate=drop_rate, slide=cha_slide, stride=stride) self.att_fusion = att_fusion # concat, add if att_fusion == 'concat': channel_in = 2 * (dim_pos - cha_slide) if att_fusion == 'add': channel_in = (dim_pos - cha_slide) channel_out = dim_pos - cha_slide self.use_conv = use_conv if use_conv: self.conv_out = nn.Sequential(nn.Dropout2d(drop_rate, True), nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1)) else: self.conv_out = nn.Identity() def forward(self, qkv_pos, qkv_cha, heat=False): if qkv_cha is None: qkv_cha = qkv_pos out_pos = self.pos_att(qkv_pos, qkv_pos, qkv_pos, heat) out_cha = self.cha_att(qkv_cha, qkv_cha, qkv_cha, heat) C = out_pos.size(1) H = out_cha.size(2) if self.att_fusion == 'concat': out = torch.cat([out_pos[:, :, -H:], out_cha[:, :C, :]], dim=1) if self.att_fusion == 'add': out = (out_pos[:, :, -H:] + out_cha[:, :C, :]) / 2 out = self.conv_out(out) return out class ResMLP(MLP): def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=GELU, drop=0.2): super().__init__(in_features=in_features, hidden_features=hidden_features, out_features=out_features, act_layer=act_layer, drop=drop) self.norm = nn.LayerNorm(in_features) def forward(self, x): out = self.norm(x) out = self.fc1(out) out = self.act(out) out = self.drop(out) out = self.fc2(out) out = out + x return out class MHSABlock(nn.Module): def __init__(self, dim, num_heads=8, drop_rate=0.2) -> None: super().__init__() self.mhsa = MultiHeadSelfAttention(dim=dim, num_heads=num_heads, drop_rate=drop_rate) self.mlp = ResMLP(in_features=dim, hidden_features=dim*4, out_features=dim) def forward(self, x, heat=False): if heat: x, attn = self.mhsa(x, heat=True) else: x = self.mhsa(x) x = self.mlp(x) if heat: return x, attn return x class SelfAttentionBlocks(nn.Module): def __init__(self, dim, block_num, num_heads=8, drop_rate=0.2): super().__init__() self.block_num = block_num assert self.block_num >= 1 self.blocks = nn.ModuleList([MHSABlock(dim=dim, num_heads=num_heads, drop_rate=drop_rate) for i in range(self.block_num)]) def forward(self, x, heat=False): attns = [] for blk in self.blocks: if heat: x, attn = blk(x, heat=True) attns.append(attn) else: x = blk(x) if heat: return x, attns return x class conv_block(nn.Module): def __init__(self,ch_in,ch_out): super(conv_block,self).__init__() self.conv = nn.Sequential( nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=True), nn.BatchNorm2d(ch_out), nn.ReLU(inplace=True), nn.Conv2d(ch_out, ch_out, kernel_size=3,stride=1,padding=1,bias=True), nn.BatchNorm2d(ch_out), nn.ReLU(inplace=True) ) def forward(self,x): x = self.conv(x) return x class up_conv(nn.Module): def __init__(self,ch_in,ch_out): super(up_conv,self).__init__() self.up = nn.Sequential( nn.Upsample(scale_factor=2), nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=1,padding=1,bias=True), nn.BatchNorm2d(ch_out), nn.ReLU(inplace=True) ) def forward(self,x): x = self.up(x) return x class Recurrent_block(nn.Module): def __init__(self,ch_out,t=2): super(Recurrent_block,self).__init__() self.t = t self.ch_out = ch_out self.conv = nn.Sequential( nn.Conv2d(ch_out,ch_out,kernel_size=3,stride=1,padding=1,bias=True), nn.BatchNorm2d(ch_out), nn.ReLU(inplace=True) ) def forward(self,x): for i in range(self.t): if i==0: x1 = self.conv(x) x1 = self.conv(x+x1) return x1 class RRCNN_block(nn.Module): def __init__(self,ch_in,ch_out,t=2): super(RRCNN_block,self).__init__() self.RCNN = nn.Sequential( Recurrent_block(ch_out,t=t), Recurrent_block(ch_out,t=t) ) self.Conv_1x1 = nn.Conv2d(ch_in,ch_out,kernel_size=1,stride=1,padding=0) def forward(self,x): x = self.Conv_1x1(x) x1 = self.RCNN(x) return x+x1 class single_conv(nn.Module): def __init__(self,ch_in,ch_out): super(single_conv,self).__init__() self.conv = nn.Sequential( nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=True), nn.BatchNorm2d(ch_out), nn.ReLU(inplace=True) ) def forward(self,x): x = self.conv(x) return x class Attention_block(nn.Module): def __init__(self,F_g, F_l, F_int): super(Attention_block,self).__init__() self.W_g = nn.Sequential( nn.Conv2d(F_g, F_int, kernel_size=1,stride=1,padding=0,bias=True), nn.BatchNorm2d(F_int) ) self.W_x = nn.Sequential( nn.Conv2d(F_l, F_int, kernel_size=1,stride=1,padding=0,bias=True), nn.BatchNorm2d(F_int) ) self.psi = nn.Sequential( nn.Conv2d(F_int, 1, kernel_size=1,stride=1,padding=0,bias=True), nn.BatchNorm2d(1), nn.Sigmoid() ) self.relu = nn.ReLU(inplace=True) def forward(self,g,x): g1 = self.W_g(g) x1 = self.W_x(x) psi = self.relu(g1+x1) psi = self.psi(psi) return x*psi class R2AttUNetDecoder(nn.Module): def __init__(self, channels, t=2): super(R2AttUNetDecoder,self).__init__() self.Upsample = nn.Upsample(scale_factor=2, mode='bilinear') self.Up5 = up_conv(ch_in=channels[4], ch_out=channels[3]) self.Att5 = Attention_block(F_g=channels[3], F_l=channels[3], F_int=channels[3]//2) self.Up_RRCNN5 = RRCNN_block(ch_in=2 * channels[3], ch_out=channels[3], t=t) self.Up4 = up_conv(ch_in=channels[3], ch_out=channels[2]) self.Att4 = Attention_block(F_g=channels[2], F_l=channels[2], F_int=channels[2]//2) self.Up_RRCNN4 = RRCNN_block(ch_in=2 * channels[2], ch_out=channels[2], t=t) self.Up3 = up_conv(ch_in=channels[2], ch_out=channels[1]) self.Att3 = Attention_block(F_g=channels[1], F_l=channels[1], F_int=channels[1]//2) self.Up_RRCNN3 = RRCNN_block(ch_in=2 * channels[1], ch_out=channels[1], t=t) self.Up2 = up_conv(ch_in=channels[1], ch_out=channels[0]) self.Att2 = Attention_block(F_g=channels[0], F_l=channels[0], F_int=channels[0]//2) self.Up_RRCNN2 = RRCNN_block(ch_in=2 * channels[0], ch_out=channels[0], t=t) def forward(self, x1, x2, x3, x4, x5): out = self.Up5(x5) x4_att = self.Att5(g=out, x=x4) out = torch.cat((x4_att, out),dim=1) out = self.Up_RRCNN5(out) out = self.Up4(out) x3_att = self.Att4(g=out, x=x3) out = torch.cat((x3_att, out),dim=1) out = self.Up_RRCNN4(out) out = self.Up3(out) x2_att = self.Att3(g=out, x=x2) out = torch.cat((x2_att, out),dim=1) out = self.Up_RRCNN3(out) out = self.Up2(out) x1_att = self.Att2(g=out, x=x1) out = torch.cat((x1_att, out),dim=1) out = self.Up_RRCNN2(out) out = self.Upsample(out) return out class ConvBlock(nn.Module): def __init__(self, ch_in, ch_out, kernel_size=3, stride=1, padding=0, bias=True): super(ConvBlock, self).__init__() self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias) self.bn1 = nn.BatchNorm2d(ch_out) self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias) self.bn2 = nn.BatchNorm2d(ch_out) self.activate = nn.LeakyReLU(negative_slope=0.01) for m in self.modules(): if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): nn.init.kaiming_normal_(m.weight) if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.BatchNorm2d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) def forward(self, x): out = self.conv1(x) out = self.bn1(out) out = self.activate(out) out = self.conv2(out) out = self.bn2(out) out = self.activate(out) return out class UNetDecoder(nn.Module): def __init__(self, channels): super(UNetDecoder,self).__init__() self.Upsample = nn.Upsample(scale_factor=2, mode='bilinear') self.Up5 = up_conv(ch_in=channels[4], ch_out=channels[3]) self.conv5 = ConvBlock(ch_in=2 * channels[3], ch_out=channels[3], kernel_size=3, stride=1, padding=1) self.Up4 = up_conv(ch_in=channels[3], ch_out=channels[2]) self.conv4 = ConvBlock(ch_in=2 * channels[2], ch_out=channels[2], kernel_size=3, stride=1, padding=1) self.Up3 = up_conv(ch_in=channels[2], ch_out=channels[1]) self.conv3 = ConvBlock(ch_in=2 * channels[1], ch_out=channels[1], kernel_size=3, stride=1, padding=1) self.Up2 = up_conv(ch_in=channels[1], ch_out=channels[0]) self.conv2 = ConvBlock(ch_in=2 * channels[0], ch_out=channels[0], kernel_size=3, stride=1, padding=1) def forward(self, x1, x2, x3, x4, x5): out = self.Up5(x5) out = torch.cat((x4, out),dim=1) out = self.conv5(out) out = self.Up4(out) out = torch.cat((x3, out),dim=1) out = self.conv4(out) out = self.Up3(out) out = torch.cat((x2, out),dim=1) out = self.conv3(out) out = self.Up2(out) out = torch.cat((x1, out),dim=1) out = self.conv2(out) out = self.Upsample(out) return out class U_Net_P(nn.Module): def __init__(self, encoder, decoder, output_ch, num_classes): super(U_Net_P, self).__init__() self.encoder = encoder self.decoder = decoder self.Last_Conv = nn.Conv2d(output_ch, num_classes, kernel_size=3, stride=1, padding=1) def forward(self, x): # encoding path x1, x2, x3, x4, x5 = self.encoder(x) x = self.decoder(x1, x2, x3, x4, x5) x = self.Last_Conv(x) return x class Prompt_U_Net_P_DCP(nn.Module): def __init__(self, encoder, decoder, output_ch, num_classes, dataset_idx, encoder_channels, prompt_init, pos_promot_channels, cha_promot_channels, embed_ratio, strides, local_window_sizes, att_fusion, use_conv): super(Prompt_U_Net_P_DCP, self).__init__() self.dataset_idx = dataset_idx self.local_window_sizes = local_window_sizes self.encoder = encoder self.decoder = decoder self.Last_Conv = nn.Conv2d(output_ch, num_classes, kernel_size=3, stride=1, padding=1) if prompt_init == 'zero': p_init = torch.zeros elif prompt_init == 'one': p_init = torch.ones elif prompt_init == 'rand': p_init = rand else: raise Exception(prompt_init) self.pos_promot_channels = pos_promot_channels pos_p1 = p_init((1, encoder_channels[0], pos_promot_channels[0], local_window_sizes[0])) pos_p2 = p_init((1, encoder_channels[1], pos_promot_channels[1], local_window_sizes[1])) pos_p3 = p_init((1, encoder_channels[2], pos_promot_channels[2], local_window_sizes[2])) pos_p4 = p_init((1, encoder_channels[3], pos_promot_channels[3], local_window_sizes[3])) pos_p5 = p_init((1, encoder_channels[4], pos_promot_channels[4], local_window_sizes[4])) self.pos_promot1 = nn.ParameterDict({str(k): nn.Parameter(pos_p1.detach().clone(), requires_grad=True) for k in dataset_idx}) self.pos_promot2 = nn.ParameterDict({str(k): nn.Parameter(pos_p2.detach().clone(), requires_grad=True) for k in dataset_idx}) self.pos_promot3 = nn.ParameterDict({str(k): nn.Parameter(pos_p3.detach().clone(), requires_grad=True) for k in dataset_idx}) self.pos_promot4 = nn.ParameterDict({str(k): nn.Parameter(pos_p4.detach().clone(), requires_grad=True) for k in dataset_idx}) self.pos_promot5 = nn.ParameterDict({str(k): nn.Parameter(pos_p5.detach().clone(), requires_grad=True) for k in dataset_idx}) self.cha_promot_channels = cha_promot_channels cha_p1 = p_init((1, cha_promot_channels[0], local_window_sizes[0], local_window_sizes[0])) cha_p2 = p_init((1, cha_promot_channels[1], local_window_sizes[1], local_window_sizes[1])) cha_p3 = p_init((1, cha_promot_channels[2], local_window_sizes[2], local_window_sizes[2])) cha_p4 = p_init((1, cha_promot_channels[3], local_window_sizes[3], local_window_sizes[3])) cha_p5 = p_init((1, cha_promot_channels[4], local_window_sizes[4], local_window_sizes[4])) self.cha_promot1 = nn.ParameterDict({str(k): nn.Parameter(cha_p1.detach().clone(), requires_grad=True) for k in dataset_idx}) self.cha_promot2 = nn.ParameterDict({str(k): nn.Parameter(cha_p2.detach().clone(), requires_grad=True) for k in dataset_idx}) self.cha_promot3 = nn.ParameterDict({str(k): nn.Parameter(cha_p3.detach().clone(), requires_grad=True) for k in dataset_idx}) self.cha_promot4 = nn.ParameterDict({str(k): nn.Parameter(cha_p4.detach().clone(), requires_grad=True) for k in dataset_idx}) self.cha_promot5 = nn.ParameterDict({str(k): nn.Parameter(cha_p5.detach().clone(), requires_grad=True) for k in dataset_idx}) self.strides = strides self.att1 = MultiHeadAttention2D_Dual2_2(dim_pos=encoder_channels[0], dim_cha=encoder_channels[0] + cha_promot_channels[0], embed_dim=encoder_channels[0], att_fusion=att_fusion, num_heads=8, embed_dim_ratio=embed_ratio, stride=strides[0], pos_slide=0, cha_slide=0, use_conv=use_conv) self.att2 = MultiHeadAttention2D_Dual2_2(dim_pos=encoder_channels[1], dim_cha=encoder_channels[1] + cha_promot_channels[1], embed_dim=encoder_channels[1], att_fusion=att_fusion, num_heads=8, embed_dim_ratio=embed_ratio, stride=strides[1], pos_slide=0, cha_slide=0, use_conv=use_conv) self.att3 = MultiHeadAttention2D_Dual2_2(dim_pos=encoder_channels[2], dim_cha=encoder_channels[2] + cha_promot_channels[2], embed_dim=encoder_channels[2], att_fusion=att_fusion, num_heads=8, embed_dim_ratio=embed_ratio, stride=strides[2], pos_slide=0, cha_slide=0, use_conv=use_conv) self.att4 = MultiHeadAttention2D_Dual2_2(dim_pos=encoder_channels[3], dim_cha=encoder_channels[3] + cha_promot_channels[3], embed_dim=encoder_channels[3], att_fusion=att_fusion, num_heads=8, embed_dim_ratio=embed_ratio, stride=strides[3], pos_slide=0, cha_slide=0, use_conv=use_conv) self.att5 = MultiHeadAttention2D_Dual2_2(dim_pos=encoder_channels[4], dim_cha=encoder_channels[4] + cha_promot_channels[4], embed_dim=encoder_channels[4], att_fusion=att_fusion, num_heads=8, embed_dim_ratio=embed_ratio, stride=strides[4], pos_slide=0, cha_slide=0, use_conv=use_conv) def get_cha_prompts(self, dataset_idx, batch_size): if len(dataset_idx) != batch_size: raise Exception(dataset_idx, self.dataset_idx, batch_size) promots1 = torch.concatenate([self.cha_promot1[str(i)] for i in dataset_idx], dim=0) promots2 = torch.concatenate([self.cha_promot2[str(i)] for i in dataset_idx], dim=0) promots3 = torch.concatenate([self.cha_promot3[str(i)] for i in dataset_idx], dim=0) promots4 = torch.concatenate([self.cha_promot4[str(i)] for i in dataset_idx], dim=0) promots5 = torch.concatenate([self.cha_promot5[str(i)] for i in dataset_idx], dim=0) return promots1, promots2, promots3, promots4, promots5 def get_pos_prompts(self, dataset_idx, batch_size): if len(dataset_idx) != batch_size: raise Exception(dataset_idx, self.dataset_idx) promots1 = torch.concatenate([self.pos_promot1[str(i)] for i in dataset_idx], dim=0) promots2 = torch.concatenate([self.pos_promot2[str(i)] for i in dataset_idx], dim=0) promots3 = torch.concatenate([self.pos_promot3[str(i)] for i in dataset_idx], dim=0) promots4 = torch.concatenate([self.pos_promot4[str(i)] for i in dataset_idx], dim=0) promots5 = torch.concatenate([self.pos_promot5[str(i)] for i in dataset_idx], dim=0) return promots1, promots2, promots3, promots4, promots5 def forward(self, x, dataset_idx, return_features=False): if isinstance(dataset_idx, torch.Tensor): dataset_idx = list(dataset_idx.cpu().numpy()) cha_promots1, cha_promots2, cha_promots3, cha_promots4, cha_promots5 = self.get_cha_prompts(dataset_idx=dataset_idx, batch_size=x.size(0)) pos_promots1, pos_promots2, pos_promots3, pos_promots4, pos_promots5 = self.get_pos_prompts(dataset_idx=dataset_idx, batch_size=x.size(0)) x1, x2, x3, x4, x5 = self.encoder(x) if return_features: pre_x1, pre_x2, pre_x3, pre_x4, pre_x5 = x1.detach().clone(), x2.detach().clone(), x3.detach().clone(), x4.detach().clone(), x5.detach().clone() h1, w1 = x1.size()[2:] h2, w2 = x2.size()[2:] h3, w3 = x3.size()[2:] h4, w4 = x4.size()[2:] h5, w5 = x5.size()[2:] x1, (Hp1, Wp1), (h_win1, w_win1) = window_partition(x1, self.local_window_sizes[0]) x2, (Hp2, Wp2), (h_win2, w_win2) = window_partition(x2, self.local_window_sizes[1]) x3, (Hp3, Wp3), (h_win3, w_win3) = window_partition(x3, self.local_window_sizes[2]) x4, (Hp4, Wp4), (h_win4, w_win4) = window_partition(x4, self.local_window_sizes[3]) x5, (Hp5, Wp5), (h_win5, w_win5) = window_partition(x5, self.local_window_sizes[4]) cha_promots1 = prompt_partition(cha_promots1, h_win1, w_win1) cha_promots2 = prompt_partition(cha_promots2, h_win2, w_win2) cha_promots3 = prompt_partition(cha_promots3, h_win3, w_win3) cha_promots4 = prompt_partition(cha_promots4, h_win4, w_win4) cha_promots5 = prompt_partition(cha_promots5, h_win5, w_win5) pos_promots1 = prompt_partition(pos_promots1, h_win1, w_win1) pos_promots2 = prompt_partition(pos_promots2, h_win2, w_win2) pos_promots3 = prompt_partition(pos_promots3, h_win3, w_win3) pos_promots4 = prompt_partition(pos_promots4, h_win4, w_win4) pos_promots5 = prompt_partition(pos_promots5, h_win5, w_win5) cha_x1, cha_x2, cha_x3, cha_x4, cha_x5 = torch.cat([x1, cha_promots1], dim=1), torch.cat([x2, cha_promots2], dim=1), torch.cat([x3, cha_promots3], dim=1), torch.cat([x4, cha_promots4], dim=1), torch.cat([x5, cha_promots5], dim=1) pos_x1, pos_x2, pos_x3, pos_x4, pos_x5 = torch.cat([pos_promots1, x1], dim=2), torch.cat([pos_promots2, x2], dim=2), torch.cat([pos_promots3, x3], dim=2), torch.cat([pos_promots4, x4], dim=2), torch.cat([pos_promots5, x5], dim=2) x1, x2, x3, x4, x5 = self.att1(pos_x1, cha_x1), self.att2(pos_x2, cha_x2), self.att3(pos_x3, cha_x3), self.att4(pos_x4, cha_x4), self.att5(pos_x5, cha_x5) x1 = window_unpartition(x1, self.local_window_sizes[0], (Hp1, Wp1), (h1, w1)) x2 = window_unpartition(x2, self.local_window_sizes[1], (Hp2, Wp2), (h2, w2)) x3 = window_unpartition(x3, self.local_window_sizes[2], (Hp3, Wp3), (h3, w3)) x4 = window_unpartition(x4, self.local_window_sizes[3], (Hp4, Wp4), (h4, w4)) x5 = window_unpartition(x5, self.local_window_sizes[4], (Hp5, Wp5), (h5, w5)) if return_features: pro_x1, pro_x2, pro_x3, pro_x4, pro_x5 = x1.detach().clone(), x2.detach().clone(), x3.detach().clone(), x4.detach().clone(), x5.detach().clone() return (pre_x1, pre_x2, pre_x3, pre_x4, pre_x5), (pro_x1, pro_x2, pro_x3, pro_x4, pro_x5) x = self.decoder(x1, x2, x3, x4, x5) x = self.Last_Conv(x) return x