Spaces:
Build error
Build error
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from positional_encodings.torch_encodings import PositionalEncoding2D | |
class LayerNorm2D(nn.Module):
    """Apply ``nn.LayerNorm`` across the channel axis of a channels-first map.

    The input ``[N, C, H, W]`` is moved to channels-last so LayerNorm
    normalizes over ``C``, then moved back to ``[N, C, H, W]``.
    """

    def __init__(self, embed_dim):
        super().__init__()
        self.layer_norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        channels_last = x.permute(0, 2, 3, 1)  # [N, H, W, C]
        normalized = self.layer_norm(channels_last)
        return normalized.permute(0, 3, 1, 2)  # back to [N, C, H, W]
class Image_Adaptor(nn.Module):
    """Lift raw images to ``adp_channels`` feature maps with a small conv stack.

    Three convolutions (kernel sizes 4, 2, 2, all with ``padding='same'``)
    map ``in_channels -> adp_channels // 4 -> adp_channels // 4 -> adp_channels``.
    A channel-wise LayerNorm2D + GELU sits between consecutive convolutions,
    and dropout is applied to the final output.
    """

    def __init__(self, in_channels, adp_channels, dropout=0.1):
        super().__init__()
        hidden = adp_channels // 4
        layers = [
            nn.Conv2d(in_channels, hidden, kernel_size=4, padding='same'),
            LayerNorm2D(hidden),
            nn.GELU(),
            nn.Conv2d(hidden, hidden, kernel_size=2, padding='same'),
            LayerNorm2D(hidden),
            nn.GELU(),
            nn.Conv2d(hidden, adp_channels, kernel_size=2, padding='same'),
        ]
        self.adaptor = nn.Sequential(*layers)
        self.dropout = nn.Dropout(dropout)

    def forward(self, images):
        """
        input:  [N, in_channels, H, W]
        output: [N, adp_channels, H, W]
        """
        return self.dropout(self.adaptor(images))
class Positional_Encoding(nn.Module):
    """Adapt ``PositionalEncoding2D`` to channels-first feature maps.

    ``self.pe`` is fed a channels-last ``[N, H, W, C]`` view of the input, and
    its result is permuted back to ``[N, C, H, W]``. Only the encoding is
    returned; this module does not add it to the input.
    """

    def __init__(self, adp_channels):
        super().__init__()
        self.pe = PositionalEncoding2D(adp_channels)

    def forward(self, adapt_imgs):
        """
        input:  [N, adp_channels, H, W]
        output: [N, adp_channels, H, W]
        """
        channels_last = adapt_imgs.permute(0, 2, 3, 1)  # [N, H, W, C]
        encoding = self.pe(channels_last)
        return encoding.permute(0, 3, 1, 2)  # [N, C, H, W]
class GeGLU(nn.Module):
    """Gated GELU projection: ``GELU(wi_0(x)) * wi_1(x)``.

    Both projections are bias-free Linear maps from ``emb_channels`` to
    ``ffn_size``, so ``[..., emb_channels]`` becomes ``[..., ffn_size]``.
    """

    def __init__(self, emb_channels, ffn_size):
        super().__init__()
        self.wi_0 = nn.Linear(emb_channels, ffn_size, bias=False)  # gated branch
        self.wi_1 = nn.Linear(emb_channels, ffn_size, bias=False)  # linear branch
        self.act = nn.GELU()

    def forward(self, x):
        gate = self.act(self.wi_0(x))
        return gate * self.wi_1(x)
class Feed_Forward(nn.Module):
    """Position-wise feed-forward block: GeGLU expand -> dropout -> GeGLU project back.

    NOTE(review): the projection back to ``in_channels`` is also a gated GeGLU
    rather than a plain Linear; kept as-is to preserve parameters/checkpoints.
    """

    def __init__(self, in_channels, ffw_channels, dropout=0.1):
        super().__init__()
        self.ln1 = GeGLU(in_channels, ffw_channels)   # in_channels -> ffw_channels
        self.dropout = nn.Dropout(dropout)
        self.ln2 = GeGLU(ffw_channels, in_channels)   # ffw_channels -> in_channels

    def forward(self, x):
        """
        input:  [N, H, W, channels]
        output: [N, H, W, channels]
        """
        expanded = self.dropout(self.ln1(x))
        return self.ln2(expanded)
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention over channels-last 2D feature maps.

    The spatial grid of each input is flattened into a token sequence,
    attention is computed per head, and the result is reshaped back to the
    query's ``[N, H, W, channels]`` layout.
    """

    def __init__(self, channels, num_attn_heads, dropout=0.1):
        super().__init__()
        # NOTE: despite its name, `head_size` stores the NUMBER of heads; the
        # attribute name is kept for backward compatibility.
        self.head_size = num_attn_heads
        self.channels = channels
        self.attn_size = channels // num_attn_heads  # width of each head
        self.scale = self.attn_size ** -0.5          # 1 / sqrt(head width)
        assert num_attn_heads * self.attn_size == channels, "Input channels of attention must divisible by number of attention head!"
        self.lq = nn.Linear(channels, self.head_size*self.attn_size, bias=False)
        self.lk = nn.Linear(channels, self.head_size*self.attn_size, bias=False)
        self.lv = nn.Linear(channels, self.head_size*self.attn_size, bias=False)
        self.lout = nn.Linear(self.head_size*self.attn_size, channels, bias=False)
        self.dropout = nn.Dropout(dropout)  # applied to the attention weights

    def forward(self, q, k, v):
        """Attend from ``q`` to ``k``/``v``.

        Args:
            q: queries, [N, H, W, channels].
            k, v: keys and values, [N, H_kv, W_kv, channels]; the token count
                H_kv * W_kv may differ from H * W.

        Returns:
            [N, H, W, channels].
        """
        bz, H, W, C = q.shape
        # Flatten the spatial grid into tokens. `reshape` (not `view`) so that
        # non-contiguous inputs such as transposed/permuted maps are accepted.
        q = q.reshape(bz, -1, C)  # [N, H*W, C]
        k = k.reshape(bz, -1, C)  # [N, H_kv*W_kv, C]
        v = v.reshape(bz, -1, C)  # [N, H_kv*W_kv, C]
        # Project and split into heads -> [N, heads, tokens, head width].
        q = self.lq(q).view(bz, -1, self.head_size, self.attn_size).transpose(1, 2)
        k = self.lk(k).view(bz, -1, self.head_size, self.attn_size).transpose(1, 2)
        v = self.lv(v).view(bz, -1, self.head_size, self.attn_size).transpose(1, 2)
        # Scaled scores, computed out of place rather than mutating q.
        scores = torch.matmul(q * self.scale, k.transpose(-1, -2))  # [N, heads, H*W, H_kv*W_kv]
        attn = self.dropout(torch.softmax(scores, dim=-1))
        x = torch.matmul(attn, v)  # [N, heads, H*W, head width]
        # Merge heads and restore the query's spatial layout.
        x = x.transpose(1, 2).reshape(bz, H, W, C)  # [N, H, W, C]
        return self.lout(x)
class Transformer_Encoder_Layer(nn.Module):
    """One encoder layer with pre-normalized residual sub-blocks.

    x -> x + dropout(attn(norm(x)))  (self-attention: q = k = v = norm(x))
      -> x + dropout(ffw(norm(x)))
    """

    def __init__(self, channels, num_attn_heads, ffw_channels, dropout=0.1):
        super().__init__()
        self.attn_norm = nn.LayerNorm(channels)
        self.attn_layer = MultiHeadAttention(channels, num_attn_heads, dropout)
        self.attn_dropout = nn.Dropout(dropout)
        self.ffw_norm = nn.LayerNorm(channels)
        self.ffw_layer = Feed_Forward(channels, ffw_channels, dropout)
        self.ffw_dropout = nn.Dropout(dropout)

    def forward(self, adp_pos_imgs):
        """
        input:  [N, H, W, channels]
        output: [N, H, W, channels]
        """
        residual = adp_pos_imgs
        normed = self.attn_norm(residual)
        hidden = self.attn_dropout(self.attn_layer(normed, normed, normed)) + residual

        normed = self.ffw_norm(hidden)
        return self.ffw_dropout(self.ffw_layer(normed)) + hidden
class Transformer_Encoder(nn.Module):
    """Stack of Transformer_Encoder_Layer applied to a channels-first feature map.

    The map is moved to channels-last for the layers, then projected from
    ``in_channels`` to ``out_channels`` with a Linear. It is moved back to
    channels-first and passed through LayerNorm2D and dropout.
    """

    def __init__(self, in_channels, out_channels, num_layers, num_attn_heads, ffw_channels, dropout=0.1):
        super().__init__()
        layers = [
            Transformer_Encoder_Layer(in_channels, num_attn_heads, ffw_channels, dropout)
            for _ in range(num_layers)
        ]
        self.encoder_layers = nn.ModuleList(layers)
        self.linear = nn.Linear(in_channels, out_channels)
        self.last_norm = LayerNorm2D(out_channels)
        self.dropout = nn.Dropout(dropout)

    def forward(self, adp_pos_imgs):
        """
        input:  [N, in_channels, H, W]
        output: [N, out_channels, H, W]
        """
        hidden = adp_pos_imgs.permute(0, 2, 3, 1)  # [N, H, W, in_channels]
        for encoder_layer in self.encoder_layers:
            hidden = encoder_layer(hidden)
        projected = self.linear(hidden).permute(0, 3, 1, 2)  # [N, out_channels, H, W]
        return self.dropout(self.last_norm(projected))
class Double_Conv(nn.Module):
    """Two stacked (3x3 Conv2d -> BatchNorm2d -> ReLU) stages.

    Both convolutions use kernel_size=3 with padding=1 (stride 1), so the
    spatial size is preserved; only the channel count changes.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, X):
        """
        input:  [N, in_channels, H, W]
        output: [N, out_channels, H, W]   (spatial size unchanged; Down does the pooling)
        """
        return self.double_conv(X)
class Down(nn.Module):
    """Downsampling stage: 2x2 max-pool followed by a Double_Conv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = [nn.MaxPool2d(2), Double_Conv(in_channels, out_channels)]
        self.down = nn.Sequential(*stages)

    def forward(self, X):
        """
        input:  [N, in_channels, H, W]
        output: [N, out_channels, H//2, W//2]
        """
        result = self.down(X)
        return result
class Up(nn.Module):
    """Upsampling stage of the U-Net decoder.

    Upsample with a stride-2 transposed conv (halving the channels), align
    the result to the skip map's spatial size, concatenate it after the skip
    map on the channel axis, then apply Double_Conv.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels//2, kernel_size=2, stride=2)
        self.conv = Double_Conv(in_channels, out_channels)

    def forward(self, X1, X2):
        """
        input:  X1 : [N, in_channels, H // 2, W // 2]   (deeper feature map)
                X2 : [N, in_channels // 2, H, W]        (skip connection)
        output: X  : [N, out_channels, H, W]
        """
        upsampled = self.up(X1)
        # Pad (negative values crop) so the upsampled map matches X2 spatially;
        # for an odd positive gap the extra row/column goes to the bottom/right.
        gap_h = X2.shape[-2] - upsampled.shape[-2]
        gap_w = X2.shape[-1] - upsampled.shape[-1]
        top, left = gap_h // 2, gap_w // 2
        upsampled = F.pad(upsampled, (left, gap_w - left, top, gap_h - top))
        merged = torch.cat((X2, upsampled), dim=-3)  # skip features first
        return self.conv(merged)
class Out_Conv(nn.Module):
    """Final 1x1 convolution mapping ``adp_channels`` features to ``out_channels``."""

    def __init__(self, adp_channels, out_channels):
        super().__init__()
        self.out_conv = nn.Conv2d(adp_channels, out_channels, kernel_size=1)

    def forward(self, X):
        """[N, adp_channels, H, W] -> [N, out_channels, H, W]."""
        projected = self.out_conv(X)
        return projected
class Trans_UNet(nn.Module):
    """U-Net with a Transformer encoder at the bottleneck.

    The pipeline runs as follows:
    - Image_Adaptor, then add a 2D positional encoding.
    - Five Down stages (adp_channels * 2, 4, 8, 16, 32 channels).
    - Transformer_Encoder on the deepest map.
    - Five Up stages using the encoder maps as skip connections.
    - A 1x1 Out_Conv, then a sigmoid.
    """

    def __init__(self,
                 in_channels,
                 adp_channels,
                 out_channels,
                 trans_num_layers=5,
                 trans_num_attn_heads=8,
                 trans_ffw_channels=1024,
                 dropout=0.1):
        super().__init__()
        self.img_adaptor = Image_Adaptor(in_channels, adp_channels, dropout)
        self.pos_encoding = Positional_Encoding(adp_channels)
        self.down1 = Down(adp_channels * 1, adp_channels * 2)
        self.down2 = Down(adp_channels * 2, adp_channels * 4)
        self.down3 = Down(adp_channels * 4, adp_channels * 8)
        self.down4 = Down(adp_channels * 8, adp_channels * 16)
        self.down5 = Down(adp_channels * 16, adp_channels * 32)
        self.trans_encoder = Transformer_Encoder(adp_channels * 32, adp_channels * 32, trans_num_layers, trans_num_attn_heads, trans_ffw_channels, dropout)
        self.up5 = Up(adp_channels * 32, adp_channels * 16)
        self.up4 = Up(adp_channels * 16, adp_channels * 8)
        self.up3 = Up(adp_channels * 8, adp_channels * 4)
        self.up2 = Up(adp_channels * 4, adp_channels * 2)
        self.up1 = Up(adp_channels * 2, adp_channels * 1)
        self.out_conv = Out_Conv(adp_channels, out_channels)
        self.sigmoid = nn.Sigmoid()

    def forward(self, images):
        """Run the full encoder / bottleneck / decoder pass.

        Args:
            images: [N, in_channels, H, W]. NOTE(review): five 2x max-pools
                are applied, so H and W presumably need to be >= 32 -- confirm.

        Returns:
            [N, out_channels, H, W] after a sigmoid. Each Up stage pads or
            crops to its skip map, and the last skip is the adapted image.
        """
        adp_imgs = self.img_adaptor(images)
        pos_enc = self.pos_encoding(adp_imgs)
        # Out-of-place add: `+=` would mutate the adaptor's output tensor in
        # place, and that tensor may alias another tensor (its final Dropout
        # can hand back its input unchanged).
        adp_imgs = adp_imgs + pos_enc

        # Encoder path.
        d1 = self.down1(adp_imgs)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)

        # Transformer bottleneck on the deepest map.
        x = self.trans_encoder(d5)

        # Decoder path with skip connections.
        u5 = self.up5(x, d4)
        u4 = self.up4(u5, d3)
        u3 = self.up3(u4, d2)
        u2 = self.up2(u3, d1)
        u1 = self.up1(u2, adp_imgs)

        x = self.out_conv(u1)
        out = self.sigmoid(x)
        return out