import torch
import torch.nn.functional as F
from torch import nn
from torch import Tensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce


class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
        )

    def forward(self, x):
        return self.double_conv(x)


class Encoder_patch(nn.Module):
    """Small CNN encoder applied to each patch, producing one emb_size vector per patch."""

    def __init__(self, n_channels, emb_size=512, bilinear=True):
        super(Encoder_patch, self).__init__()
        self.n_channels = n_channels
        self.emb_size = emb_size
        self.bilinear = bilinear

        self.conv1 = DoubleConv(n_channels, 128)
        self.conv2 = DoubleConv(128, 256)
        self.conv3 = DoubleConv(256, emb_size)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        # Global average pool to 1x1, then flatten to (batch, emb_size).
        x = torch.flatten(F.adaptive_avg_pool2d(x, 1), start_dim=1)
        return x


class PatchEmbedding(nn.Module):
    """Splits the image into patches, encodes each patch with Encoder_patch,
    then prepends a cls token and adds learned positional embeddings."""

    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768, img_size: int = 224):
        super().__init__()
        self.patch_size = patch_size

        self.projection = nn.Sequential(
            # Cut the image into an (img_size // patch_size) ** 2 grid of patches.
            Rearrange('b c (ph h) (pw w) -> b c (ph pw) h w', c=in_channels, h=patch_size,
                      ph=img_size // patch_size, w=patch_size, pw=img_size // patch_size),
            # Fold the patch dimension into the batch so every patch is encoded independently.
            Rearrange('b c p h w -> (b p) c h w'),
            Encoder_patch(n_channels=in_channels, emb_size=emb_size),
            # Restore the patch dimension: (batch, num_patches, emb_size).
            Rearrange('(b p) d -> b p d', p=(img_size // patch_size) ** 2),
        )
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        self.positions = nn.Parameter(torch.randn((img_size // patch_size) ** 2 + 1, emb_size))

    def forward(self, x: Tensor) -> Tensor:
        b, _, _, _ = x.shape
        x = self.projection(x)
        # Prepend a learnable cls token for every sample in the batch.
        cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b)
        x = torch.cat([cls_tokens, x], dim=1)
        # Add learned positional embeddings.
        x += self.positions
        return x


class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size: int = 768, num_heads: int = 8, dropout: float = 0):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads

        # One linear layer produces queries, keys and values in a single pass.
        self.qkv = nn.Linear(emb_size, emb_size * 3)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        # Split the fused projection into (queries, keys, values), each (b, heads, n, head_dim).
        qkv = rearrange(self.qkv(x), "b n (h d qkv) -> (qkv) b h n d", h=self.num_heads, qkv=3)
        queries, keys, values = qkv[0], qkv[1], qkv[2]

        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
        if mask is not None:
            fill_value = torch.finfo(torch.float32).min
            energy = energy.masked_fill(~mask, fill_value)

        # Scaled dot-product attention: scale the logits before the softmax.
        scaling = self.emb_size ** (1 / 2)
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)

        out = torch.einsum('bhal, bhlv -> bhav', att, values)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.projection(out)
        return out


class ResidualAdd(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        res = x
        x = self.fn(x, **kwargs)
        x += res
        return x


class FeedForwardBlock(nn.Sequential):
    def __init__(self, emb_size: int, expansion: int = 4, drop_p: float = 0.):
        super().__init__(
            nn.Linear(emb_size, expansion * emb_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(expansion * emb_size, emb_size),
        )


class TransformerEncoderBlock(nn.Sequential):
    def __init__(self,
                 emb_size: int = 768,
                 drop_p: float = 0.,
                 forward_expansion: int = 4,
                 forward_drop_p: float = 0.,
                 **kwargs):
        super().__init__(
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                MultiHeadAttention(emb_size, **kwargs),
                nn.Dropout(drop_p)
            )),
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                FeedForwardBlock(
                    emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
                nn.Dropout(drop_p)
            )),
        )


class TransformerEncoder(nn.Sequential):
    def __init__(self, depth: int = 12, **kwargs):
        super().__init__(*[TransformerEncoderBlock(**kwargs) for _ in range(depth)])


class dependencymap(nn.Module):
    """Broadcasts each token embedding back onto its patch location to build a
    per-pixel context map and a regional importance map."""

    def __init__(self, emb_size: int = 768, n_regions: int = 256, patch_size: int = 16, img_size: int = 256,
                 output_ch: int = 64, cuda=True):
        super().__init__()
        self.patch_size = patch_size
        self.img_size = img_size
        self.emb_size = emb_size
        self.output_ch = output_ch
        # Stored as use_cuda so the flag does not shadow nn.Module.cuda().
        self.use_cuda = cuda
        self.outconv = nn.Sequential(
            nn.Conv2d(emb_size, output_ch, kernel_size=1, padding=0),
            nn.BatchNorm2d(output_ch),
            nn.Sigmoid()
        )
        self.out2 = nn.Sigmoid()

        self.gpool = nn.AdaptiveAvgPool1d(1)

    def forward(self, x):
        x_gpool = self.gpool(x)
        coeff = torch.zeros((x.size()[0], self.emb_size, self.img_size, self.img_size))
        coeff2 = torch.zeros((x.size()[0], 1, self.img_size, self.img_size))
        if self.use_cuda:
            coeff = coeff.cuda()
            coeff2 = coeff2.cuda()
        grid = self.img_size // self.patch_size
        for i in range(grid):
            for j in range(grid):
                # Token index of patch (i, j) in row-major order over the patch grid.
                idx = i * grid + j
                value = x[:, idx]
                value = value.view(value.size()[0], value.size()[1], 1, 1)
                coeff[:, :, self.patch_size * i:self.patch_size * (i + 1),
                      self.patch_size * j:self.patch_size * (j + 1)] = value.repeat(1, 1, self.patch_size, self.patch_size)

                value = x_gpool[:, idx]
                value = value.view(value.size()[0], value.size()[1], 1, 1)
                coeff2[:, :, self.patch_size * i:self.patch_size * (i + 1),
                       self.patch_size * j:self.patch_size * (j + 1)] = value.repeat(1, 1, self.patch_size, self.patch_size)

        global_contextual = self.outconv(coeff)
        regional_distribution = self.out2(coeff2)
        return [global_contextual, regional_distribution, self.out2(x_gpool)]


class transmunet(nn.Sequential):
    """Patch embedding -> Transformer encoder -> dependency map."""

    def __init__(self,
                 in_channels: int = 3,
                 patch_size: int = 16,
                 emb_size: int = 1024,
                 img_size: int = 256,
                 depth: int = 2,
                 n_regions: int = (256 // 16) ** 2,
                 output_ch: int = 64,
                 cuda=True,
                 **kwargs):
        super().__init__(
            PatchEmbedding(in_channels, patch_size, emb_size, img_size),
            TransformerEncoder(depth, emb_size=emb_size, **kwargs),
            dependencymap(emb_size, n_regions, patch_size, img_size, output_ch, cuda)
        )
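

# A minimal smoke test added here for illustration (not part of the original module):
# it builds a deliberately small transmunet on the CPU (cuda=False) and checks the
# three output shapes. The sizes below (img_size=64, emb_size=256, depth=1) are
# assumptions chosen only to keep the forward pass cheap.
if __name__ == "__main__":
    model = transmunet(in_channels=3, patch_size=16, emb_size=256, img_size=64,
                       depth=1, n_regions=(64 // 16) ** 2, output_ch=64, cuda=False)
    dummy = torch.randn(2, 3, 64, 64)
    global_contextual, regional_distribution, region_scores = model(dummy)
    print(global_contextual.shape)      # expected: torch.Size([2, 64, 64, 64])
    print(regional_distribution.shape)  # expected: torch.Size([2, 1, 64, 64])
    print(region_scores.shape)          # expected: torch.Size([2, 17, 1])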