# GenD-Sentinel: src/model/forada/layer.py
import torch
from timm.layers import to_2tuple
from torch import nn
from torch.nn import functional as F
class LayerNorm(nn.Module):
"""
A LayerNorm variant, popularized by Transformers, that performs point-wise mean and
variance normalization over the channel dimension for inputs that have shape
(batch_size, channels, height, width).
https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa B950
"""
def __init__(self, normalized_shape, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.bias = nn.Parameter(torch.zeros(normalized_shape))
self.eps = eps
self.normalized_shape = (normalized_shape,)
def forward(self, x: torch.Tensor):
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x
class MLP(nn.Module):
"""Very simple multi-layer perceptron (also called FFN)"""
def __init__(self, input_dim, hidden_dim, output_dim, num_layers, affine_func=nn.Linear):
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(affine_func(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
def forward(self, x: torch.Tensor):
for i, layer in enumerate(self.layers):
            # apply ReLU after every layer except the last
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
return x
class Fusion(nn.Module):
def __init__(self, clip_dim, adapter_dim):
super().__init__()
self.clip_dim = clip_dim
self.adapter_dim = adapter_dim
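        # Project CLIP patch features (clip_dim channels) to the adapter width
        # (adapter_dim) with a channel-wise LayerNorm followed by a 1x1 conv.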
self.proj = nn.Sequential(
LayerNorm(clip_dim),
nn.Conv2d(clip_dim, adapter_dim, kernel_size=1),
)
def forward(self, x, clip_x, spatial_shape):
h, w = spatial_shape
n, l, d = clip_x.shape
if l == h * w:
clip_x = clip_x.permute(0, 2, 1).view(n, d, h, w) # NLD->NDL->NDhw
else:
clip_x = clip_x.permute(0, 2, 1).view(n, d, 14, 14) # NLD->NDL->NDhw
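            # Assumed layout: the CLIP tokens form a 14x14 patch grid (196 tokens,
            # e.g. a ViT with 16-pixel patches on 224x224 input); they are resized
            # below to the adapter's 16x16 grid. Both sizes are hard-coded.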
clip_x = F.interpolate(
clip_x.contiguous(),
size=(16, 16),
mode="bilinear",
align_corners=False,
            )  # (N, D, 14, 14) -> (N, D, 16, 16)
clip_x = self.proj(clip_x).view(n, self.adapter_dim, h * w).permute(0, 2, 1)
x = x + clip_x # NLD
return x
class MaskPostXrayProcess(nn.Module):
def __init__(self, in_c):
super().__init__()
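        # Mask decoder head: reduce the query channels Q -> Q/2 -> Q/4 -> 1, then
        # upsample the 16x16 map to a 256x256 mask with a stride-16 transposed conv.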
        self.process = nn.Sequential(
            nn.Conv2d(
                in_channels=in_c, out_channels=in_c // 2, kernel_size=3, stride=1, padding=1
            ),  # (N, Q, h, w) -> (N, Q/2, h, w)
            nn.BatchNorm2d(in_c // 2),
            nn.ReLU(),
            nn.Conv2d(in_channels=in_c // 2, out_channels=in_c // 4, kernel_size=3, stride=1, padding=1),  # (N, Q/4, h, w)
            nn.BatchNorm2d(in_c // 4),
            nn.ReLU(),
            nn.Conv2d(in_channels=in_c // 4, out_channels=1, kernel_size=1, stride=1, padding=0),  # (N, 1, h, w)
            nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=16, stride=16),  # (N, 1, 16, 16) -> (N, 1, 256, 256)
            # nn.Upsample(size=(256, 256), mode='bilinear', align_corners=True)  # (N, 1, 256, 256)
        )
def forward(self, x, if_boundaries):
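        # Expected shapes: x is (N, Q, 16, 16) per-query maps; if_boundaries is an
        # (N, L) mask with L = 256 marking which patches lie on a boundary.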
x = x.reshape(x.shape[0], x.shape[1], -1) # (N Q 256)
x = x.permute(0, 2, 1) # (N L Q)
        if_boundaries = if_boundaries.unsqueeze(-1)  # (N, L, 1): patches that are not boundaries are zeroed out
x = x * if_boundaries # (N L Q) * (N L 1)
x = x.permute(0, 2, 1) # (N Q L)
x = x.reshape(x.shape[0], x.shape[1], 16, 16)
        post_x = self.process(x)  # (N, 1, 256, 256)
return post_x
class PostClipProcess(nn.Module):
"""
NQD -> ND -> N2
"""
def __init__(self, num_quires, embed_dim):
super().__init__()
self.first_process = nn.Sequential(
            nn.Conv1d(
                in_channels=num_quires, out_channels=num_quires // 2, kernel_size=3, stride=1, padding=1
            ),  # (N, Q, D) -> (N, Q/2, D)
nn.BatchNorm1d(num_quires // 2),
nn.ReLU(),
nn.Conv1d(in_channels=num_quires // 2, out_channels=num_quires // 4, kernel_size=3, stride=1, padding=1),
nn.BatchNorm1d(num_quires // 4),
nn.ReLU(),
nn.Conv1d(in_channels=num_quires // 4, out_channels=1, kernel_size=3, stride=1, padding=1),
)
# self.norm = VT_LN(embed_dim)
self.second_process = nn.Sequential( # ND->N2
nn.Linear(in_features=embed_dim, out_features=embed_dim // 2),
nn.ReLU(),
nn.Linear(in_features=embed_dim // 2, out_features=embed_dim // 4),
nn.ReLU(),
# nn.Linear(in_features=embed_dim // 4, out_features=embed_dim // 8),
# nn.ReLU(),
nn.Linear(in_features=embed_dim // 4, out_features=2),
)
def forward(self, x):
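        # Example (illustrative): with num_quires=100 and embed_dim=768 this maps
        # (N, 100, 768) query features to (N, 2) logits.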
        x = self.first_process(x)  # (N, Q, D) -> (N, 1, D)
        x = x.squeeze(1)  # (N, 1, D) -> (N, D)
x = self.second_process(x)
return x
class VT_LN(nn.LayerNorm):
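    """LayerNorm that casts its input to float32, normalizes, and casts the
    result back to the original dtype, which keeps the statistics numerically
    stable under fp16/bf16 mixed precision (the same trick as CLIP's LayerNorm).
    """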
def forward(self, x: torch.Tensor):
orig_type = x.dtype
ret = super().forward(x.type(torch.float32))
return ret.type(orig_type)
class PatchEmbed(nn.Module):
def __init__(self, img_size=256, patch_size=16, in_chans=3, embed_dim=192, norm_layer=None, bias=False, **kwargs):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
self.norm = VT_LN(embed_dim)
def forward(self, x):
x = self.proj(x)
x = x.reshape(x.shape[0], x.shape[1], -1) # NDL
x = x.permute(0, 2, 1) # NDL->NLD
# x = self.norm(x)
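        # NOTE: self.norm is constructed in __init__ but left unapplied here, so
        # patch tokens are returned without normalization.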
return x
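

if __name__ == "__main__":
    # Illustrative smoke test (not part of the original module): a quick shape
    # check for PatchEmbed and Fusion under assumed sizes -- 256x256 RGB input,
    # 16-pixel patches, a 192-dim adapter embedding, and 768-dim CLIP features
    # on a 14x14 patch grid (196 tokens). Adjust to the real configuration.
    embed = PatchEmbed(img_size=256, patch_size=16, in_chans=3, embed_dim=192)
    tokens = embed(torch.randn(2, 3, 256, 256))
    print(tokens.shape)  # torch.Size([2, 256, 192]), i.e. (N, L, D)

    fuse = Fusion(clip_dim=768, adapter_dim=192)
    clip_feats = torch.randn(2, 196, 768)  # flattened 14x14 CLIP patch grid
    fused = fuse(tokens, clip_feats, spatial_shape=(16, 16))
    print(fused.shape)  # torch.Size([2, 256, 192])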