import math

import torch
from torch import nn
from torch.nn import functional as F

try:
    from timm.layers import to_2tuple
except ImportError:  # timm is not installed, or the release predates ``timm.layers``

    def to_2tuple(x):
        """Return ``x`` as a 2-tuple (local stand-in for ``timm.layers.to_2tuple``).

        Tuples and lists are converted with ``tuple(x)``; any other value is
        repeated twice.
        """
        if isinstance(x, (tuple, list)):
            return tuple(x)
        return (x, x)


def _square_grid_side(num_tokens: int) -> int:
    """Return ``s`` such that ``s * s == num_tokens``.

    Raises:
        ValueError: if ``num_tokens`` is not a perfect square.
    """
    side = math.isqrt(num_tokens)
    if side * side != num_tokens:
        raise ValueError(
            f"cannot arrange {num_tokens} tokens on a square grid"
        )
    return side


class LayerNorm(nn.Module):
    """
    A LayerNorm variant, popularized by Transformers, that performs point-wise mean and
    variance normalization over the channel dimension for inputs that have shape
    (batch_size, channels, height, width).
    https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119  # noqa B950
    """

    def __init__(self, normalized_shape, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.normalized_shape = (normalized_shape,)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` over dim 1, then apply the per-channel scale and shift."""
        centered = x - x.mean(1, keepdim=True)
        # Biased variance (mean of squared deviations) over the channel dim.
        variance = centered.pow(2).mean(1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.eps)
        return self.weight[:, None, None] * normalized + self.bias[:, None, None]


class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN).

    Builds ``num_layers`` layers with ``affine_func``; every layer except the
    last is followed by a ReLU.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, affine_func=nn.Linear):
        super().__init__()
        self.num_layers = num_layers
        hidden = [hidden_dim] * (num_layers - 1)
        in_dims = [input_dim] + hidden
        out_dims = hidden + [output_dim]
        self.layers = nn.ModuleList(
            affine_func(n_in, n_out) for n_in, n_out in zip(in_dims, out_dims)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layers in order; no activation after the final layer."""
        last = self.num_layers - 1
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i < last:
                x = F.relu(x)
        return x


class Fusion(nn.Module):
    """Project CLIP tokens to the adapter width and add them to adapter tokens.

    The projection is a channel ``LayerNorm`` followed by a 1x1 convolution
    from ``clip_dim`` to ``adapter_dim`` channels.
    """

    def __init__(self, clip_dim, adapter_dim):
        super().__init__()
        self.clip_dim = clip_dim
        self.adapter_dim = adapter_dim
        self.proj = nn.Sequential(
            LayerNorm(clip_dim),
            nn.Conv2d(clip_dim, adapter_dim, kernel_size=1),
        )

    def forward(self, x, clip_x, spatial_shape):
        """Return ``x`` plus the projected ``clip_x``.

        Args:
            x: tensor added to the projected features, which have shape
                ``(N, h * w, adapter_dim)``.
            clip_x: 3-D tensor ``(N, L, D)`` of CLIP tokens.
            spatial_shape: ``(h, w)`` grid of the adapter tokens.

        When ``L != h * w`` the CLIP tokens are laid out on a square grid and
        bilinearly resized to ``(h, w)``. The previous code hard-coded a 14x14
        source and a 16x16 target; that case produces the same result here.

        Raises:
            ValueError: if ``L != h * w`` and ``L`` is not a perfect square.
        """
        h, w = spatial_shape
        n, num_tokens, d = clip_x.shape
        clip_x = clip_x.permute(0, 2, 1)  # NLD -> NDL
        if num_tokens == h * w:
            clip_x = clip_x.reshape(n, d, h, w)  # NDL -> NDhw
        else:
            side = _square_grid_side(num_tokens)
            clip_x = clip_x.reshape(n, d, side, side)
            clip_x = F.interpolate(
                clip_x.contiguous(),
                size=(h, w),
                mode="bilinear",
                align_corners=False,
            )  # N D side side -> N D h w
        clip_x = self.proj(clip_x).reshape(n, self.adapter_dim, h * w).permute(0, 2, 1)
        return x + clip_x  # NLD


class MaskPostXrayProcess(nn.Module):
    """Mask per-patch features and decode them into a single-channel map.

    The head reduces ``in_c`` channels to ``in_c // 2``, ``in_c // 4`` and
    finally 1, then a transposed convolution (kernel 16, stride 16) enlarges
    each spatial side by a factor of 16.
    """

    def __init__(self, in_c):
        super().__init__()
        self.process = nn.Sequential(
            nn.Conv2d(in_channels=in_c, out_channels=in_c // 2, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(in_c // 2),
            nn.ReLU(),
            nn.Conv2d(in_channels=in_c // 2, out_channels=in_c // 4, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(in_c // 4),
            nn.ReLU(),
            nn.Conv2d(in_channels=in_c // 4, out_channels=1, kernel_size=1, stride=1, padding=0),
            # (N, 1, s, s) -> (N, 1, 16 * s, 16 * s)
            nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=16, stride=16),
        )

    def forward(self, x, if_boundaries):
        """Zero the non-boundary patches of ``x`` and run the decoding head.

        Args:
            x: tensor whose dims after the first two flatten to ``L`` patches,
                e.g. ``(N, Q, s, s)``.
            if_boundaries: per-patch multiplier, broadcast over the channel
                dim; presumably ``(N, L)`` with 0 for non-boundary patches
                (confirm against the caller).

        The patch grid side is inferred from ``L`` (16 for ``L == 256``, the
        value the previous code hard-coded).

        Raises:
            ValueError: if ``L`` is not a perfect square.
        """
        tokens = x.reshape(x.shape[0], x.shape[1], -1)  # (N, Q, L)
        # Patches that are not boundaries are set to 0: (N, Q, L) * (N, 1, L).
        tokens = tokens * if_boundaries.unsqueeze(-2)
        side = _square_grid_side(tokens.shape[-1])
        grid = tokens.reshape(tokens.shape[0], tokens.shape[1], side, side)
        return self.process(grid)  # (N, 1, 16 * side, 16 * side)


class PostClipProcess(nn.Module):
    """
    NQD -> ND -> N2

    A Conv1d stack collapses the ``num_quires`` channels to one, then an MLP
    maps the remaining ``embed_dim`` features to 2 outputs.
    """

    def __init__(self, num_quires, embed_dim):
        super().__init__()
        self.first_process = nn.Sequential(  # NQD -> N1D
            nn.Conv1d(in_channels=num_quires, out_channels=num_quires // 2, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(num_quires // 2),
            nn.ReLU(),
            nn.Conv1d(in_channels=num_quires // 2, out_channels=num_quires // 4, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(num_quires // 4),
            nn.ReLU(),
            nn.Conv1d(in_channels=num_quires // 4, out_channels=1, kernel_size=3, stride=1, padding=1),
        )
        self.second_process = nn.Sequential(  # ND -> N2
            nn.Linear(in_features=embed_dim, out_features=embed_dim // 2),
            nn.ReLU(),
            nn.Linear(in_features=embed_dim // 2, out_features=embed_dim // 4),
            nn.ReLU(),
            nn.Linear(in_features=embed_dim // 4, out_features=2),
        )

    def forward(self, x):
        """Map ``x`` of shape (N, Q, D) to an (N, 2) tensor."""
        x = self.first_process(x)  # NQD -> N1D
        x = x.squeeze(1)  # N1D -> ND
        return self.second_process(x)


class VT_LN(nn.LayerNorm):
    """``nn.LayerNorm`` that computes in float32 and returns the input dtype."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


class PatchEmbed(nn.Module):
    """Split an image into non-overlapping patches and embed each one.

    ``norm_layer`` and ``**kwargs`` are accepted but not used by this class.
    """

    def __init__(self, img_size=256, patch_size=16, in_chans=3, embed_dim=192, norm_layer=None, bias=False, **kwargs):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
        # Registered but not applied in ``forward``; kept so the module's
        # parameters (and therefore its state-dict keys) stay unchanged.
        self.norm = VT_LN(embed_dim)

    def forward(self, x):
        """Return patch embeddings of shape (N, L, D) for an image batch ``x``."""
        x = self.proj(x)  # (N, D, H / patch, W / patch)
        x = x.flatten(2)  # NDL
        return x.transpose(1, 2)  # NDL -> NLD (``self.norm`` is not applied)