UNet_DCP_1024 / models /UNet_p.py
qijie.wei
first commit
c5f4ee2
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple
def rand(size, val=0.01):
out = torch.zeros(size)
nn.init.uniform_(out, -val, val)
return out
# from medsam
def window_partition(x: torch.Tensor, window_size: int):
B, C, H, W = x.size()
pad_h = (window_size - H % window_size) % window_size
pad_w = (window_size - W % window_size) % window_size
if pad_h > 0 or pad_w > 0:
x = F.pad(x, (0, pad_w, 0, pad_h))
Hp, Wp = H + pad_h, W + pad_w
x = x.view(B, C, Hp // window_size, window_size, Wp // window_size, window_size)
windows = x.permute(0, 2, 4, 1, 3, 5).contiguous().view(-1, C, window_size, window_size)
return windows, (Hp, Wp), (Hp // window_size, Wp // window_size)
def prompt_partition(prompt: torch.Tensor, h_windows: int, w_windows: int):
# prompt: B, C, H, W
B, C, H, W = prompt.size()
prompt = prompt.view(B, 1, 1, C, H, W)
prompt = prompt.repeat((1, h_windows, w_windows, 1, 1, 1)).contiguous().view(-1, C, H, W)
return prompt
def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]):
# windows: B * Hp // window_size * Wp // window_size, C, window_size, window_size
Hp, Wp = pad_hw
H, W = hw
B = (windows.shape[0] * window_size * window_size) // (Hp * Wp)
# 0 1 2 3 4 5
x = windows.view(B, Hp // window_size, Wp // window_size, -1, window_size, window_size)
x = x.permute(0, 3, 1, 4, 2, 5).contiguous().view(B, -1, Hp, Wp)
if Hp > H or Wp > W:
x = x[:, :, :H, :W].contiguous()
return x
class GELU(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
cdf = 0.5 * (1 + torch.erf(x / 2**0.5))
return x * cdf
class OneLayerRes(nn.Module):
def __init__(self, in_features, out_features, kernel_size, padding) -> None:
super().__init__()
self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, padding=padding)
self.weight = nn.Parameter(torch.zeros(1), requires_grad=True)
def forward(self, x):
x = x + self.weight * self.conv(x)
return x
class MLP(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=GELU, drop=0.2):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
return x
class MultiHeadSelfAttention(nn.Module):
def __init__(self, dim, num_heads=8, drop_rate=0.2):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.norm = nn.LayerNorm(dim)
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=False)
self.drop = nn.Dropout(drop_rate)
self.proj = nn.Linear(dim, dim)
def forward(self, x, heat=False):
B, N, C = x.shape
out = self.norm(x)
qkv = self.qkv(out).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.drop(attn)
out = (attn @ v).transpose(1, 2).reshape(B, N, C)
out = self.proj(out)
out = self.drop(out)
out = x + out
if heat:
return out, attn
return out
class MultiHeadAttention2D_POS(nn.Module):
def __init__(self, dim_q, dim_k, dim_v, embed_dim, num_heads=8, drop_rate=0.2, embed_dim_ratio=4, stride=1, slide=0):
super().__init__()
self.stride = stride
self.num_heads = num_heads
self.slide = slide
self.embed_dim_qk = embed_dim // embed_dim_ratio
if self.embed_dim_qk % num_heads != 0:
self.embed_dim_qk = (self.embed_dim_qk // num_heads + 1) * num_heads
self.embed_dim_v = embed_dim
if self.embed_dim_v % num_heads != 0:
self.embed_dim_v = (self.embed_dim_v // num_heads + 1) * num_heads
head_dim = self.embed_dim_qk // num_heads
self.scale = head_dim ** -0.5
self.conv_q = nn.Conv2d(in_channels=dim_q, out_channels=self.embed_dim_qk, kernel_size=stride, padding=0, stride=stride)
self.conv_k = nn.Conv2d(in_channels=dim_k, out_channels=self.embed_dim_qk, kernel_size=stride, padding=0, stride=stride)
self.conv_v = nn.Conv2d(in_channels=dim_v, out_channels=self.embed_dim_v, kernel_size=stride, padding=0, stride=stride)
self.drop = nn.Dropout(drop_rate)
self.proj_out = nn.Conv2d(in_channels=self.embed_dim_v, out_channels=dim_q, kernel_size=3, padding=1)
if self.stride > 1:
self.upsample = nn.Upsample(scale_factor=stride, mode='bilinear')
else:
self.upsample = nn.Identity()
self.gamma = nn.Parameter(torch.zeros(1), requires_grad=True)
def forward(self, q, k, v, heat=False):
B, _, H_q, W_q = q.size()
_, _, H_kv, W_kv = k.size()
H_q = H_q // self.stride
W_q = W_q // self.stride
H_kv = H_kv // self.stride
W_kv = W_kv // self.stride
proj_q = self.conv_q(q).reshape(B, self.num_heads, self.embed_dim_qk // self.num_heads, H_q * W_q).permute(0, 1, 3, 2).contiguous()
proj_k = self.conv_k(k).reshape(B, self.num_heads, self.embed_dim_qk // self.num_heads, H_kv * W_kv).permute(0, 1, 3, 2).contiguous()
proj_v = self.conv_v(v).reshape(B, self.num_heads, self.embed_dim_v // self.num_heads, H_kv * W_kv).permute(0, 1, 3, 2).contiguous()
attn = (proj_q @ proj_k.transpose(-2, -1)).contiguous() * self.scale # B, self.num_heads, H_q * W_q, H_kv * W_kv
attn = attn.softmax(dim=-1)
attn = self.drop(attn)
out = (attn @ proj_v) # B, self.num_heads, H_q * W_q, self.embed_dim // self.num_heads
out = out.transpose(2, 3).contiguous().reshape(B, self.embed_dim_v, H_q, W_q)
if self.slide > 0:
out = out[:, :, self.slide // self.stride:]
q = q[:, :, self.slide:]
out = self.proj_out(out)
out = self.upsample(out)
out = self.drop(out)
out = q + out * self.gamma
return out
class MultiHeadAttention2D_CHA(nn.Module):
def __init__(self, dim_q, dim_kv, stride, num_heads=8, drop_rate=0.2, slide=0):
super().__init__()
self.num_heads = num_heads
self.stride = stride
self.slide = slide
self.dim_q_out = dim_q - slide
self.conv_q = nn.Conv2d(in_channels=dim_q, out_channels=dim_q * num_heads, kernel_size=stride, stride=stride, groups=dim_q)
self.conv_k = nn.Conv2d(in_channels=dim_kv, out_channels=dim_kv * num_heads, kernel_size=stride, stride=stride, groups=dim_kv)
self.conv_v = nn.Conv2d(in_channels=dim_kv, out_channels=dim_kv * num_heads, kernel_size=stride, stride=stride, groups=dim_kv)
self.drop = nn.Dropout(drop_rate)
self.proj_out = nn.ConvTranspose2d(in_channels=self.dim_q_out * num_heads, out_channels=self.dim_q_out, kernel_size=stride, stride=stride, groups=self.dim_q_out)
self.gamma = nn.Parameter(torch.zeros(1), requires_grad=True)
def forward(self, q, k, v, heat=False):
B, C_q, H_q, W_q = q.size()
_, C_kv, H_kv, W_kv = k.size()
proj_q = self.conv_q(q).reshape(B, self.num_heads, C_q, -1) # batch_size * num_heads * dim_q * (H * W)
proj_k = self.conv_k(k).reshape(B, self.num_heads, C_kv, -1)
proj_v = self.conv_v(v).reshape(B, self.num_heads, C_kv, -1) # batch_size * num_heads * dim_kv * (H * W)
scale = proj_q.size(3) ** -0.5
attn = (proj_q @ proj_k.transpose(-2, -1)).contiguous() * scale # batch_size, num_heads, dim_q, dim_kv
attn = attn.softmax(dim=-1)
attn = self.drop(attn)
out = (attn @ proj_v) # batch_size, num_heads, dim_q, (H * W)
if self.slide > 0:
out = out[:, :, :-self.slide]
out = out.reshape(B, self.num_heads * self.dim_q_out, H_q // self.stride, W_q // self.stride)
out = self.proj_out(out)
out = self.drop(out)
out = q + out * self.gamma
return out
class MultiHeadAttention2D_Dual2_2(nn.Module):
def __init__(self, dim_pos, dim_cha, embed_dim, att_fusion, num_heads=8, drop_rate=0.2, embed_dim_ratio=4, stride=1, cha_slide=0, pos_slide=0, use_conv=True):
super().__init__()
self.pos_att = MultiHeadAttention2D_POS(dim_q=dim_pos, dim_k=dim_pos, dim_v=dim_pos, embed_dim=embed_dim, num_heads=num_heads, drop_rate=drop_rate, embed_dim_ratio=embed_dim_ratio, stride=stride, slide=pos_slide)
self.cha_att = MultiHeadAttention2D_CHA(dim_q=dim_cha, dim_kv=dim_cha, num_heads=num_heads, drop_rate=drop_rate, slide=cha_slide, stride=stride)
self.att_fusion = att_fusion # concat, add
if att_fusion == 'concat':
channel_in = 2 * (dim_pos - cha_slide)
if att_fusion == 'add':
channel_in = (dim_pos - cha_slide)
channel_out = dim_pos - cha_slide
self.use_conv = use_conv
if use_conv:
self.conv_out = nn.Sequential(nn.Dropout2d(drop_rate, True), nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1))
else:
self.conv_out = nn.Identity()
def forward(self, qkv_pos, qkv_cha, heat=False):
if qkv_cha is None:
qkv_cha = qkv_pos
out_pos = self.pos_att(qkv_pos, qkv_pos, qkv_pos, heat)
out_cha = self.cha_att(qkv_cha, qkv_cha, qkv_cha, heat)
C = out_pos.size(1)
H = out_cha.size(2)
if self.att_fusion == 'concat':
out = torch.cat([out_pos[:, :, -H:], out_cha[:, :C, :]], dim=1)
if self.att_fusion == 'add':
out = (out_pos[:, :, -H:] + out_cha[:, :C, :]) / 2
out = self.conv_out(out)
return out
class ResMLP(MLP):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=GELU, drop=0.2):
super().__init__(in_features=in_features, hidden_features=hidden_features, out_features=out_features, act_layer=act_layer, drop=drop)
self.norm = nn.LayerNorm(in_features)
def forward(self, x):
out = self.norm(x)
out = self.fc1(out)
out = self.act(out)
out = self.drop(out)
out = self.fc2(out)
out = out + x
return out
class MHSABlock(nn.Module):
def __init__(self, dim, num_heads=8, drop_rate=0.2) -> None:
super().__init__()
self.mhsa = MultiHeadSelfAttention(dim=dim, num_heads=num_heads, drop_rate=drop_rate)
self.mlp = ResMLP(in_features=dim, hidden_features=dim*4, out_features=dim)
def forward(self, x, heat=False):
if heat:
x, attn = self.mhsa(x, heat=True)
else:
x = self.mhsa(x)
x = self.mlp(x)
if heat:
return x, attn
return x
class SelfAttentionBlocks(nn.Module):
def __init__(self, dim, block_num, num_heads=8, drop_rate=0.2):
super().__init__()
self.block_num = block_num
assert self.block_num >= 1
self.blocks = nn.ModuleList([MHSABlock(dim=dim, num_heads=num_heads, drop_rate=drop_rate)
for i in range(self.block_num)])
def forward(self, x, heat=False):
attns = []
for blk in self.blocks:
if heat:
x, attn = blk(x, heat=True)
attns.append(attn)
else:
x = blk(x)
if heat:
return x, attns
return x
class conv_block(nn.Module):
def __init__(self,ch_in,ch_out):
super(conv_block,self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=True),
nn.BatchNorm2d(ch_out),
nn.ReLU(inplace=True),
nn.Conv2d(ch_out, ch_out, kernel_size=3,stride=1,padding=1,bias=True),
nn.BatchNorm2d(ch_out),
nn.ReLU(inplace=True)
)
def forward(self,x):
x = self.conv(x)
return x
class up_conv(nn.Module):
def __init__(self,ch_in,ch_out):
super(up_conv,self).__init__()
self.up = nn.Sequential(
nn.Upsample(scale_factor=2),
nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=1,padding=1,bias=True),
nn.BatchNorm2d(ch_out),
nn.ReLU(inplace=True)
)
def forward(self,x):
x = self.up(x)
return x
class Recurrent_block(nn.Module):
def __init__(self,ch_out,t=2):
super(Recurrent_block,self).__init__()
self.t = t
self.ch_out = ch_out
self.conv = nn.Sequential(
nn.Conv2d(ch_out,ch_out,kernel_size=3,stride=1,padding=1,bias=True),
nn.BatchNorm2d(ch_out),
nn.ReLU(inplace=True)
)
def forward(self,x):
for i in range(self.t):
if i==0:
x1 = self.conv(x)
x1 = self.conv(x+x1)
return x1
class RRCNN_block(nn.Module):
def __init__(self,ch_in,ch_out,t=2):
super(RRCNN_block,self).__init__()
self.RCNN = nn.Sequential(
Recurrent_block(ch_out,t=t),
Recurrent_block(ch_out,t=t)
)
self.Conv_1x1 = nn.Conv2d(ch_in,ch_out,kernel_size=1,stride=1,padding=0)
def forward(self,x):
x = self.Conv_1x1(x)
x1 = self.RCNN(x)
return x+x1
class single_conv(nn.Module):
def __init__(self,ch_in,ch_out):
super(single_conv,self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=True),
nn.BatchNorm2d(ch_out),
nn.ReLU(inplace=True)
)
def forward(self,x):
x = self.conv(x)
return x
class Attention_block(nn.Module):
def __init__(self,F_g, F_l, F_int):
super(Attention_block,self).__init__()
self.W_g = nn.Sequential(
nn.Conv2d(F_g, F_int, kernel_size=1,stride=1,padding=0,bias=True),
nn.BatchNorm2d(F_int)
)
self.W_x = nn.Sequential(
nn.Conv2d(F_l, F_int, kernel_size=1,stride=1,padding=0,bias=True),
nn.BatchNorm2d(F_int)
)
self.psi = nn.Sequential(
nn.Conv2d(F_int, 1, kernel_size=1,stride=1,padding=0,bias=True),
nn.BatchNorm2d(1),
nn.Sigmoid()
)
self.relu = nn.ReLU(inplace=True)
def forward(self,g,x):
g1 = self.W_g(g)
x1 = self.W_x(x)
psi = self.relu(g1+x1)
psi = self.psi(psi)
return x*psi
class R2AttUNetDecoder(nn.Module):
def __init__(self, channels, t=2):
super(R2AttUNetDecoder,self).__init__()
self.Upsample = nn.Upsample(scale_factor=2, mode='bilinear')
self.Up5 = up_conv(ch_in=channels[4], ch_out=channels[3])
self.Att5 = Attention_block(F_g=channels[3], F_l=channels[3], F_int=channels[3]//2)
self.Up_RRCNN5 = RRCNN_block(ch_in=2 * channels[3], ch_out=channels[3], t=t)
self.Up4 = up_conv(ch_in=channels[3], ch_out=channels[2])
self.Att4 = Attention_block(F_g=channels[2], F_l=channels[2], F_int=channels[2]//2)
self.Up_RRCNN4 = RRCNN_block(ch_in=2 * channels[2], ch_out=channels[2], t=t)
self.Up3 = up_conv(ch_in=channels[2], ch_out=channels[1])
self.Att3 = Attention_block(F_g=channels[1], F_l=channels[1], F_int=channels[1]//2)
self.Up_RRCNN3 = RRCNN_block(ch_in=2 * channels[1], ch_out=channels[1], t=t)
self.Up2 = up_conv(ch_in=channels[1], ch_out=channels[0])
self.Att2 = Attention_block(F_g=channels[0], F_l=channels[0], F_int=channels[0]//2)
self.Up_RRCNN2 = RRCNN_block(ch_in=2 * channels[0], ch_out=channels[0], t=t)
def forward(self, x1, x2, x3, x4, x5):
out = self.Up5(x5)
x4_att = self.Att5(g=out, x=x4)
out = torch.cat((x4_att, out),dim=1)
out = self.Up_RRCNN5(out)
out = self.Up4(out)
x3_att = self.Att4(g=out, x=x3)
out = torch.cat((x3_att, out),dim=1)
out = self.Up_RRCNN4(out)
out = self.Up3(out)
x2_att = self.Att3(g=out, x=x2)
out = torch.cat((x2_att, out),dim=1)
out = self.Up_RRCNN3(out)
out = self.Up2(out)
x1_att = self.Att2(g=out, x=x1)
out = torch.cat((x1_att, out),dim=1)
out = self.Up_RRCNN2(out)
out = self.Upsample(out)
return out
class ConvBlock(nn.Module):
def __init__(self, ch_in, ch_out, kernel_size=3, stride=1, padding=0, bias=True):
super(ConvBlock, self).__init__()
self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
self.bn1 = nn.BatchNorm2d(ch_out)
self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
self.bn2 = nn.BatchNorm2d(ch_out)
self.activate = nn.LeakyReLU(negative_slope=0.01)
for m in self.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
nn.init.kaiming_normal_(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward(self, x):
out = self.conv1(x)
out = self.bn1(out)
out = self.activate(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.activate(out)
return out
class UNetDecoder(nn.Module):
def __init__(self, channels):
super(UNetDecoder,self).__init__()
self.Upsample = nn.Upsample(scale_factor=2, mode='bilinear')
self.Up5 = up_conv(ch_in=channels[4], ch_out=channels[3])
self.conv5 = ConvBlock(ch_in=2 * channels[3], ch_out=channels[3], kernel_size=3, stride=1, padding=1)
self.Up4 = up_conv(ch_in=channels[3], ch_out=channels[2])
self.conv4 = ConvBlock(ch_in=2 * channels[2], ch_out=channels[2], kernel_size=3, stride=1, padding=1)
self.Up3 = up_conv(ch_in=channels[2], ch_out=channels[1])
self.conv3 = ConvBlock(ch_in=2 * channels[1], ch_out=channels[1], kernel_size=3, stride=1, padding=1)
self.Up2 = up_conv(ch_in=channels[1], ch_out=channels[0])
self.conv2 = ConvBlock(ch_in=2 * channels[0], ch_out=channels[0], kernel_size=3, stride=1, padding=1)
def forward(self, x1, x2, x3, x4, x5):
out = self.Up5(x5)
out = torch.cat((x4, out),dim=1)
out = self.conv5(out)
out = self.Up4(out)
out = torch.cat((x3, out),dim=1)
out = self.conv4(out)
out = self.Up3(out)
out = torch.cat((x2, out),dim=1)
out = self.conv3(out)
out = self.Up2(out)
out = torch.cat((x1, out),dim=1)
out = self.conv2(out)
out = self.Upsample(out)
return out
class U_Net_P(nn.Module):
def __init__(self, encoder, decoder, output_ch, num_classes):
super(U_Net_P, self).__init__()
self.encoder = encoder
self.decoder = decoder
self.Last_Conv = nn.Conv2d(output_ch, num_classes, kernel_size=3, stride=1, padding=1)
def forward(self, x):
# encoding path
x1, x2, x3, x4, x5 = self.encoder(x)
x = self.decoder(x1, x2, x3, x4, x5)
x = self.Last_Conv(x)
return x
class Prompt_U_Net_P_DCP(nn.Module):
def __init__(self, encoder, decoder, output_ch, num_classes, dataset_idx, encoder_channels, prompt_init, pos_promot_channels, cha_promot_channels, embed_ratio, strides, local_window_sizes, att_fusion, use_conv):
super(Prompt_U_Net_P_DCP, self).__init__()
self.dataset_idx = dataset_idx
self.local_window_sizes = local_window_sizes
self.encoder = encoder
self.decoder = decoder
self.Last_Conv = nn.Conv2d(output_ch, num_classes, kernel_size=3, stride=1, padding=1)
if prompt_init == 'zero':
p_init = torch.zeros
elif prompt_init == 'one':
p_init = torch.ones
elif prompt_init == 'rand':
p_init = rand
else:
raise Exception(prompt_init)
self.pos_promot_channels = pos_promot_channels
pos_p1 = p_init((1, encoder_channels[0], pos_promot_channels[0], local_window_sizes[0]))
pos_p2 = p_init((1, encoder_channels[1], pos_promot_channels[1], local_window_sizes[1]))
pos_p3 = p_init((1, encoder_channels[2], pos_promot_channels[2], local_window_sizes[2]))
pos_p4 = p_init((1, encoder_channels[3], pos_promot_channels[3], local_window_sizes[3]))
pos_p5 = p_init((1, encoder_channels[4], pos_promot_channels[4], local_window_sizes[4]))
self.pos_promot1 = nn.ParameterDict({str(k): nn.Parameter(pos_p1.detach().clone(), requires_grad=True) for k in dataset_idx})
self.pos_promot2 = nn.ParameterDict({str(k): nn.Parameter(pos_p2.detach().clone(), requires_grad=True) for k in dataset_idx})
self.pos_promot3 = nn.ParameterDict({str(k): nn.Parameter(pos_p3.detach().clone(), requires_grad=True) for k in dataset_idx})
self.pos_promot4 = nn.ParameterDict({str(k): nn.Parameter(pos_p4.detach().clone(), requires_grad=True) for k in dataset_idx})
self.pos_promot5 = nn.ParameterDict({str(k): nn.Parameter(pos_p5.detach().clone(), requires_grad=True) for k in dataset_idx})
self.cha_promot_channels = cha_promot_channels
cha_p1 = p_init((1, cha_promot_channels[0], local_window_sizes[0], local_window_sizes[0]))
cha_p2 = p_init((1, cha_promot_channels[1], local_window_sizes[1], local_window_sizes[1]))
cha_p3 = p_init((1, cha_promot_channels[2], local_window_sizes[2], local_window_sizes[2]))
cha_p4 = p_init((1, cha_promot_channels[3], local_window_sizes[3], local_window_sizes[3]))
cha_p5 = p_init((1, cha_promot_channels[4], local_window_sizes[4], local_window_sizes[4]))
self.cha_promot1 = nn.ParameterDict({str(k): nn.Parameter(cha_p1.detach().clone(), requires_grad=True) for k in dataset_idx})
self.cha_promot2 = nn.ParameterDict({str(k): nn.Parameter(cha_p2.detach().clone(), requires_grad=True) for k in dataset_idx})
self.cha_promot3 = nn.ParameterDict({str(k): nn.Parameter(cha_p3.detach().clone(), requires_grad=True) for k in dataset_idx})
self.cha_promot4 = nn.ParameterDict({str(k): nn.Parameter(cha_p4.detach().clone(), requires_grad=True) for k in dataset_idx})
self.cha_promot5 = nn.ParameterDict({str(k): nn.Parameter(cha_p5.detach().clone(), requires_grad=True) for k in dataset_idx})
self.strides = strides
self.att1 = MultiHeadAttention2D_Dual2_2(dim_pos=encoder_channels[0], dim_cha=encoder_channels[0] + cha_promot_channels[0], embed_dim=encoder_channels[0], att_fusion=att_fusion, num_heads=8, embed_dim_ratio=embed_ratio, stride=strides[0], pos_slide=0, cha_slide=0, use_conv=use_conv)
self.att2 = MultiHeadAttention2D_Dual2_2(dim_pos=encoder_channels[1], dim_cha=encoder_channels[1] + cha_promot_channels[1], embed_dim=encoder_channels[1], att_fusion=att_fusion, num_heads=8, embed_dim_ratio=embed_ratio, stride=strides[1], pos_slide=0, cha_slide=0, use_conv=use_conv)
self.att3 = MultiHeadAttention2D_Dual2_2(dim_pos=encoder_channels[2], dim_cha=encoder_channels[2] + cha_promot_channels[2], embed_dim=encoder_channels[2], att_fusion=att_fusion, num_heads=8, embed_dim_ratio=embed_ratio, stride=strides[2], pos_slide=0, cha_slide=0, use_conv=use_conv)
self.att4 = MultiHeadAttention2D_Dual2_2(dim_pos=encoder_channels[3], dim_cha=encoder_channels[3] + cha_promot_channels[3], embed_dim=encoder_channels[3], att_fusion=att_fusion, num_heads=8, embed_dim_ratio=embed_ratio, stride=strides[3], pos_slide=0, cha_slide=0, use_conv=use_conv)
self.att5 = MultiHeadAttention2D_Dual2_2(dim_pos=encoder_channels[4], dim_cha=encoder_channels[4] + cha_promot_channels[4], embed_dim=encoder_channels[4], att_fusion=att_fusion, num_heads=8, embed_dim_ratio=embed_ratio, stride=strides[4], pos_slide=0, cha_slide=0, use_conv=use_conv)
def get_cha_prompts(self, dataset_idx, batch_size):
if len(dataset_idx) != batch_size:
raise Exception(dataset_idx, self.dataset_idx, batch_size)
promots1 = torch.concatenate([self.cha_promot1[str(i)] for i in dataset_idx], dim=0)
promots2 = torch.concatenate([self.cha_promot2[str(i)] for i in dataset_idx], dim=0)
promots3 = torch.concatenate([self.cha_promot3[str(i)] for i in dataset_idx], dim=0)
promots4 = torch.concatenate([self.cha_promot4[str(i)] for i in dataset_idx], dim=0)
promots5 = torch.concatenate([self.cha_promot5[str(i)] for i in dataset_idx], dim=0)
return promots1, promots2, promots3, promots4, promots5
def get_pos_prompts(self, dataset_idx, batch_size):
if len(dataset_idx) != batch_size:
raise Exception(dataset_idx, self.dataset_idx)
promots1 = torch.concatenate([self.pos_promot1[str(i)] for i in dataset_idx], dim=0)
promots2 = torch.concatenate([self.pos_promot2[str(i)] for i in dataset_idx], dim=0)
promots3 = torch.concatenate([self.pos_promot3[str(i)] for i in dataset_idx], dim=0)
promots4 = torch.concatenate([self.pos_promot4[str(i)] for i in dataset_idx], dim=0)
promots5 = torch.concatenate([self.pos_promot5[str(i)] for i in dataset_idx], dim=0)
return promots1, promots2, promots3, promots4, promots5
def forward(self, x, dataset_idx, return_features=False):
if isinstance(dataset_idx, torch.Tensor):
dataset_idx = list(dataset_idx.cpu().numpy())
cha_promots1, cha_promots2, cha_promots3, cha_promots4, cha_promots5 = self.get_cha_prompts(dataset_idx=dataset_idx, batch_size=x.size(0))
pos_promots1, pos_promots2, pos_promots3, pos_promots4, pos_promots5 = self.get_pos_prompts(dataset_idx=dataset_idx, batch_size=x.size(0))
x1, x2, x3, x4, x5 = self.encoder(x)
if return_features:
pre_x1, pre_x2, pre_x3, pre_x4, pre_x5 = x1.detach().clone(), x2.detach().clone(), x3.detach().clone(), x4.detach().clone(), x5.detach().clone()
h1, w1 = x1.size()[2:]
h2, w2 = x2.size()[2:]
h3, w3 = x3.size()[2:]
h4, w4 = x4.size()[2:]
h5, w5 = x5.size()[2:]
x1, (Hp1, Wp1), (h_win1, w_win1) = window_partition(x1, self.local_window_sizes[0])
x2, (Hp2, Wp2), (h_win2, w_win2) = window_partition(x2, self.local_window_sizes[1])
x3, (Hp3, Wp3), (h_win3, w_win3) = window_partition(x3, self.local_window_sizes[2])
x4, (Hp4, Wp4), (h_win4, w_win4) = window_partition(x4, self.local_window_sizes[3])
x5, (Hp5, Wp5), (h_win5, w_win5) = window_partition(x5, self.local_window_sizes[4])
cha_promots1 = prompt_partition(cha_promots1, h_win1, w_win1)
cha_promots2 = prompt_partition(cha_promots2, h_win2, w_win2)
cha_promots3 = prompt_partition(cha_promots3, h_win3, w_win3)
cha_promots4 = prompt_partition(cha_promots4, h_win4, w_win4)
cha_promots5 = prompt_partition(cha_promots5, h_win5, w_win5)
pos_promots1 = prompt_partition(pos_promots1, h_win1, w_win1)
pos_promots2 = prompt_partition(pos_promots2, h_win2, w_win2)
pos_promots3 = prompt_partition(pos_promots3, h_win3, w_win3)
pos_promots4 = prompt_partition(pos_promots4, h_win4, w_win4)
pos_promots5 = prompt_partition(pos_promots5, h_win5, w_win5)
cha_x1, cha_x2, cha_x3, cha_x4, cha_x5 = torch.cat([x1, cha_promots1], dim=1), torch.cat([x2, cha_promots2], dim=1), torch.cat([x3, cha_promots3], dim=1), torch.cat([x4, cha_promots4], dim=1), torch.cat([x5, cha_promots5], dim=1)
pos_x1, pos_x2, pos_x3, pos_x4, pos_x5 = torch.cat([pos_promots1, x1], dim=2), torch.cat([pos_promots2, x2], dim=2), torch.cat([pos_promots3, x3], dim=2), torch.cat([pos_promots4, x4], dim=2), torch.cat([pos_promots5, x5], dim=2)
x1, x2, x3, x4, x5 = self.att1(pos_x1, cha_x1), self.att2(pos_x2, cha_x2), self.att3(pos_x3, cha_x3), self.att4(pos_x4, cha_x4), self.att5(pos_x5, cha_x5)
x1 = window_unpartition(x1, self.local_window_sizes[0], (Hp1, Wp1), (h1, w1))
x2 = window_unpartition(x2, self.local_window_sizes[1], (Hp2, Wp2), (h2, w2))
x3 = window_unpartition(x3, self.local_window_sizes[2], (Hp3, Wp3), (h3, w3))
x4 = window_unpartition(x4, self.local_window_sizes[3], (Hp4, Wp4), (h4, w4))
x5 = window_unpartition(x5, self.local_window_sizes[4], (Hp5, Wp5), (h5, w5))
if return_features:
pro_x1, pro_x2, pro_x3, pro_x4, pro_x5 = x1.detach().clone(), x2.detach().clone(), x3.detach().clone(), x4.detach().clone(), x5.detach().clone()
return (pre_x1, pre_x2, pre_x3, pre_x4, pre_x5), (pro_x1, pro_x2, pro_x3, pro_x4, pro_x5)
x = self.decoder(x1, x2, x3, x4, x5)
x = self.Last_Conv(x)
return x