| | |
| | |
| | |
| | |
| |
|
| | from functools import partial |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.utils.checkpoint as cp |
| |
|
| | from ...ops.modules import MSDeformAttn |
| | from .drop_path import DropPath |
| |
|
| |
|
def get_reference_points(spatial_shapes, device):
    """Build normalized pixel-center reference points for a list of feature maps.

    For every ``(H_, W_)`` pair the centers of an ``H_ x W_`` grid are generated
    in row-major order (all columns of row 0, then row 1, ...) and divided by
    the grid size, so each coordinate is a fraction of the map extent.

    Args:
        spatial_shapes: Iterable of ``(height, width)`` pairs, one per level.
        device: Device on which the coordinate tensors are created.

    Returns:
        A float32 tensor of shape ``(1, sum(H_ * W_), 1, 2)``; the last
        dimension holds ``(x, y)`` and levels are concatenated in input order.
    """
    per_level = []
    for H_, W_ in spatial_shapes:
        ys = torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device)
        xs = torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device)
        # Broadcasting reproduces the row-major ("ij") grid explicitly. Calling
        # torch.meshgrid without ``indexing`` warns on recent PyTorch and relies
        # on a default that is announced to change.
        ref_y = ys[:, None].expand(-1, W_).reshape(-1)[None] / H_
        ref_x = xs[None, :].expand(H_, -1).reshape(-1)[None] / W_
        per_level.append(torch.stack((ref_x, ref_y), -1))
    reference_points = torch.cat(per_level, 1)
    # Insert a singleton axis between the point axis and the (x, y) axis.
    return reference_points[:, :, None]
| |
|
| |
|
def deform_inputs(x, patch_size):
    """Prepare the two argument triples used by the deformable-attention blocks.

    Two grids are derived from the spatial size of ``x``: a three-level pyramid
    at strides 8, 16 and 32, and a single grid at stride ``patch_size``.

    Args:
        x: 4-D tensor; only its last two sizes (height, width) and its device
            are used.
        patch_size: Stride of the single-level grid.

    Returns:
        ``(deform_inputs1, deform_inputs2)``, each a list
        ``[reference_points, spatial_shapes, level_start_index]``.
        ``deform_inputs1`` pairs reference points on the ``patch_size`` grid
        with the shapes of the three-level pyramid; ``deform_inputs2`` pairs
        reference points on the pyramid with the shape of the ``patch_size``
        grid.
    """
    _, _, h, w = x.shape
    device = x.device
    pyramid = [(h // 8, w // 8), (h // 16, w // 16), (h // 32, w // 32)]
    token_grid = [(h // patch_size, w // patch_size)]

    def _level_layout(shapes):
        # Shapes as a long tensor plus the flat offset at which each level starts.
        shapes_t = torch.as_tensor(shapes, dtype=torch.long, device=device)
        offsets = shapes_t.prod(1).cumsum(0)[:-1]
        return shapes_t, torch.cat((shapes_t.new_zeros((1,)), offsets))

    pyramid_shapes, pyramid_starts = _level_layout(pyramid)
    deform_inputs1 = [get_reference_points(token_grid, device), pyramid_shapes, pyramid_starts]

    grid_shapes, grid_starts = _level_layout(token_grid)
    deform_inputs2 = [get_reference_points(pyramid, device), grid_shapes, grid_starts]

    return deform_inputs1, deform_inputs2
| |
|
| |
|
class ConvFFN(nn.Module):
    """Two-layer feed-forward block with a DWConv between the linear layers.

    Args:
        in_features: Width of the input tokens.
        hidden_features: Width of the hidden layer; falls back to
            ``in_features`` when falsy.
        out_features: Width of the output tokens; falls back to
            ``in_features`` when falsy.
        act_layer: Zero-argument factory for the activation module.
        drop: Dropout probability applied after the activation and after the
            second linear layer.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
        super().__init__()
        output_width = out_features or in_features
        hidden_width = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_width)
        self.dwconv = DWConv(hidden_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, output_width)
        self.drop = nn.Dropout(drop)

    def forward(self, x, H, W):
        """Apply fc1 -> dwconv -> act -> drop -> fc2 -> drop; ``H``/``W`` are forwarded to the DWConv."""
        hidden = self.dwconv(self.fc1(x), H, W)
        hidden = self.drop(self.act(hidden))
        return self.drop(self.fc2(hidden))
| |
|
| |
|
| | class DWConv(nn.Module): |
| | def __init__(self, dim=768): |
| | super().__init__() |
| | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) |
| |
|
| | def forward(self, x, H, W): |
| | B, N, C = x.shape |
| | n = N // 21 |
| | x1 = x[:, 0 : 16 * n, :].transpose(1, 2).view(B, C, H * 2, W * 2).contiguous() |
| | x2 = x[:, 16 * n : 20 * n, :].transpose(1, 2).view(B, C, H, W).contiguous() |
| | x3 = x[:, 20 * n :, :].transpose(1, 2).view(B, C, H // 2, W // 2).contiguous() |
| | x1 = self.dwconv(x1).flatten(2).transpose(1, 2) |
| | x2 = self.dwconv(x2).flatten(2).transpose(1, 2) |
| | x3 = self.dwconv(x3).flatten(2).transpose(1, 2) |
| | x = torch.cat([x1, x2, x3], dim=1) |
| | return x |
| |
|
| |
|
class Extractor(nn.Module):
    """Residual MSDeformAttn update of ``query`` from ``feat``, optionally followed by a ConvFFN.

    Args:
        dim: Token width, used for the norms, the attention and the ConvFFN.
        num_heads: Forwarded to MSDeformAttn as ``n_heads``.
        n_points: Forwarded to MSDeformAttn as ``n_points``.
        n_levels: Forwarded to MSDeformAttn as ``n_levels``.
        deform_ratio: Forwarded to MSDeformAttn as ``ratio``.
        with_cffn: When true, a normed ConvFFN residual branch is added.
        cffn_ratio: ConvFFN hidden width as a fraction of ``dim``.
        drop: Dropout probability inside the ConvFFN.
        drop_path: DropPath rate on the ConvFFN branch (identity when not > 0).
        norm_layer: Factory taking ``dim`` and returning a norm module.
        with_cp: Run the forward pass under ``torch.utils.checkpoint`` when the
            query requires grad.
    """

    def __init__(
        self,
        dim,
        num_heads=6,
        n_points=4,
        n_levels=1,
        deform_ratio=1.0,
        with_cffn=True,
        cffn_ratio=0.25,
        drop=0.0,
        drop_path=0.0,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        with_cp=False,
    ):
        super().__init__()
        self.query_norm = norm_layer(dim)
        self.feat_norm = norm_layer(dim)
        self.attn = MSDeformAttn(
            d_model=dim, n_levels=n_levels, n_heads=num_heads, n_points=n_points, ratio=deform_ratio
        )
        self.with_cffn = with_cffn
        self.with_cp = with_cp
        if with_cffn:
            ffn_hidden = int(dim * cffn_ratio)
            self.ffn = ConvFFN(in_features=dim, hidden_features=ffn_hidden, drop=drop)
            self.ffn_norm = norm_layer(dim)
            if drop_path > 0.0:
                self.drop_path = DropPath(drop_path)
            else:
                self.drop_path = nn.Identity()

    def forward(self, query, reference_points, feat, spatial_shapes, level_start_index, H, W):
        """Return ``query`` updated by attention over ``feat`` (and the ConvFFN when enabled).

        ``H`` and ``W`` are only forwarded to the ConvFFN.
        """

        def _attend_and_mix(q, f):
            attended = self.attn(
                self.query_norm(q), reference_points, self.feat_norm(f), spatial_shapes, level_start_index, None
            )
            updated = q + attended
            if not self.with_cffn:
                return updated
            return updated + self.drop_path(self.ffn(self.ffn_norm(updated), H, W))

        # Only the two tensors are passed through checkpoint; the rest is captured by the closure.
        if self.with_cp and query.requires_grad:
            return cp.checkpoint(_attend_and_mix, query, feat)
        return _attend_and_mix(query, feat)
| |
|
| |
|
class Injector(nn.Module):
    """Adds a learnably scaled MSDeformAttn read-out of ``feat`` to ``query``.

    Args:
        dim: Token width, used for the norms, the attention and ``gamma``.
        num_heads: Forwarded to MSDeformAttn as ``n_heads``.
        n_points: Forwarded to MSDeformAttn as ``n_points``.
        n_levels: Forwarded to MSDeformAttn as ``n_levels``.
        deform_ratio: Forwarded to MSDeformAttn as ``ratio``.
        norm_layer: Factory taking ``dim`` and returning a norm module.
        init_values: Initial value of every entry of the per-channel scale
            ``gamma``; with the default 0.0 the attention term is initially
            multiplied by zero.
        with_cp: Run the forward pass under ``torch.utils.checkpoint`` when the
            query requires grad.
    """

    def __init__(
        self,
        dim,
        num_heads=6,
        n_points=4,
        n_levels=1,
        deform_ratio=1.0,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        init_values=0.0,
        with_cp=False,
    ):
        super().__init__()
        self.with_cp = with_cp
        self.query_norm = norm_layer(dim)
        self.feat_norm = norm_layer(dim)
        self.attn = MSDeformAttn(
            d_model=dim, n_levels=n_levels, n_heads=num_heads, n_points=n_points, ratio=deform_ratio
        )
        gamma_init = init_values * torch.ones(dim)
        self.gamma = nn.Parameter(gamma_init, requires_grad=True)

    def forward(self, query, reference_points, feat, spatial_shapes, level_start_index):
        """Return ``query + gamma * attn(norm(query), norm(feat))``."""

        def _inject(q, f):
            attended = self.attn(
                self.query_norm(q), reference_points, self.feat_norm(f), spatial_shapes, level_start_index, None
            )
            return q + self.gamma * attended

        # Only the two tensors are passed through checkpoint; the rest is captured by the closure.
        if self.with_cp and query.requires_grad:
            return cp.checkpoint(_inject, query, feat)
        return _inject(query, feat)
| |
|
| |
|
class InteractionBlock(nn.Module):
    """One injector, a run of externally supplied blocks, then one or three extractors.

    Args:
        dim: Token width shared by the injector and extractors.
        num_heads: Attention heads for the injector and extractors.
        n_points: Sampling points for the injector and extractors.
        norm_layer: Factory taking ``dim`` and returning a norm module.
        drop: Dropout probability inside the extractors' ConvFFN.
        drop_path: DropPath rate for the extractors.
        with_cffn: Whether the extractors include a ConvFFN.
        cffn_ratio: ConvFFN hidden width as a fraction of ``dim``.
        init_values: Initial value of the injector's ``gamma``.
        deform_ratio: Forwarded to the injector and extractors.
        extra_extractor: When true, two additional extractors are appended.
        with_cp: Enables checkpointing inside the injector and extractors.
    """

    def __init__(
        self,
        dim,
        num_heads=6,
        n_points=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        drop=0.0,
        drop_path=0.0,
        with_cffn=True,
        cffn_ratio=0.25,
        init_values=0.0,
        deform_ratio=1.0,
        extra_extractor=False,
        with_cp=False,
    ):
        super().__init__()

        # Keyword arguments shared by the injector and every extractor.
        attn_cfg = dict(
            dim=dim,
            num_heads=num_heads,
            n_points=n_points,
            norm_layer=norm_layer,
            deform_ratio=deform_ratio,
            with_cp=with_cp,
        )
        extractor_cfg = dict(attn_cfg, with_cffn=with_cffn, cffn_ratio=cffn_ratio, drop=drop, drop_path=drop_path)

        self.injector = Injector(n_levels=3, init_values=init_values, **attn_cfg)
        self.extractor = Extractor(n_levels=1, **extractor_cfg)
        if extra_extractor:
            # The extra extractors rely on Extractor's default ``n_levels``.
            self.extra_extractors = nn.Sequential(*(Extractor(**extractor_cfg) for _ in range(2)))
        else:
            self.extra_extractors = None

    def forward(self, x, c, blocks, deform_inputs1, deform_inputs2, H_c, W_c, H_toks, W_toks):
        """Inject ``c`` into ``x``, run ``blocks`` on ``x``, then extract from ``x`` back into ``c``.

        ``deform_inputs1`` / ``deform_inputs2`` are indexable as
        ``[reference_points, spatial_shapes, level_start_index]`` and feed the
        injector / extractors respectively. ``H_toks``/``W_toks`` are passed to
        every block; ``H_c``/``W_c`` to every extractor.

        Returns:
            The tuple ``(x, c)``.
        """
        x = self.injector(
            query=x,
            reference_points=deform_inputs1[0],
            feat=c,
            spatial_shapes=deform_inputs1[1],
            level_start_index=deform_inputs1[2],
        )
        for blk in blocks:
            x = blk(x, H_toks, W_toks)

        stages = [self.extractor]
        if self.extra_extractors is not None:
            stages.extend(self.extra_extractors)
        for stage in stages:
            c = stage(
                query=c,
                reference_points=deform_inputs2[0],
                feat=x,
                spatial_shapes=deform_inputs2[1],
                level_start_index=deform_inputs2[2],
                H=H_c,
                W=W_c,
            )
        return x, c
| |
|
| |
|
class InteractionBlockWithCls(nn.Module):
    """Injector / blocks / extractor stage that carries a separate leading token through the blocks.

    The leading token ``cls`` is prepended to ``x`` after injection, travels
    through ``blocks`` with it, and is split off again before extraction.

    Args:
        dim: Token width shared by the injector and extractors.
        num_heads: Attention heads for the injector and extractors.
        n_points: Sampling points for the injector and extractors.
        norm_layer: Factory taking ``dim`` and returning a norm module.
        drop: Dropout probability inside the extractors' ConvFFN.
        drop_path: DropPath rate for the extractors.
        with_cffn: Whether the extractors include a ConvFFN.
        cffn_ratio: ConvFFN hidden width as a fraction of ``dim``.
        init_values: Initial value of the injector's ``gamma``.
        deform_ratio: Forwarded to the injector and extractors.
        extra_extractor: When true, two additional extractors are appended.
        with_cp: Enables checkpointing inside the injector and extractors.
    """

    def __init__(
        self,
        dim,
        num_heads=6,
        n_points=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        drop=0.0,
        drop_path=0.0,
        with_cffn=True,
        cffn_ratio=0.25,
        init_values=0.0,
        deform_ratio=1.0,
        extra_extractor=False,
        with_cp=False,
    ):
        super().__init__()

        # Keyword arguments shared by the injector and every extractor.
        attn_cfg = dict(
            dim=dim,
            num_heads=num_heads,
            n_points=n_points,
            norm_layer=norm_layer,
            deform_ratio=deform_ratio,
            with_cp=with_cp,
        )
        extractor_cfg = dict(attn_cfg, with_cffn=with_cffn, cffn_ratio=cffn_ratio, drop=drop, drop_path=drop_path)

        self.injector = Injector(n_levels=3, init_values=init_values, **attn_cfg)
        self.extractor = Extractor(n_levels=1, **extractor_cfg)
        if extra_extractor:
            # The extra extractors rely on Extractor's default ``n_levels``.
            self.extra_extractors = nn.Sequential(*(Extractor(**extractor_cfg) for _ in range(2)))
        else:
            self.extra_extractors = None

    def forward(self, x, c, cls, blocks, deform_inputs1, deform_inputs2, H_c, W_c, H_toks, W_toks):
        """Inject ``c`` into ``x``, run ``blocks`` on ``[cls; x]``, then extract from ``x`` into ``c``.

        ``deform_inputs1`` / ``deform_inputs2`` are indexable as
        ``[reference_points, spatial_shapes, level_start_index]`` and feed the
        injector / extractors respectively. ``H_toks``/``W_toks`` are passed to
        every block; ``H_c``/``W_c`` to every extractor.

        Returns:
            The tuple ``(x, c, cls)`` where ``cls`` is the first token along
            dim 1 after the blocks and ``x`` holds the remaining tokens.
        """
        x = self.injector(
            query=x,
            reference_points=deform_inputs1[0],
            feat=c,
            spatial_shapes=deform_inputs1[1],
            level_start_index=deform_inputs1[2],
        )

        # The blocks see the extra token in front of the injected tokens.
        tokens = torch.cat((cls, x), dim=1)
        for blk in blocks:
            tokens = blk(tokens, H_toks, W_toks)
        cls = tokens[:, :1]
        x = tokens[:, 1:]

        stages = [self.extractor]
        if self.extra_extractors is not None:
            stages.extend(self.extra_extractors)
        for stage in stages:
            c = stage(
                query=c,
                reference_points=deform_inputs2[0],
                feat=x,
                spatial_shapes=deform_inputs2[1],
                level_start_index=deform_inputs2[2],
                H=H_c,
                W=W_c,
            )
        return x, c, cls
| |
|
| |
|
class SpatialPriorModule(nn.Module):
    """Convolutional stem producing four feature levels projected to ``embed_dim`` channels.

    The stem (three 3x3 conv/SyncBatchNorm/ReLU groups plus a max-pool) is
    followed by three stride-2 conv groups; a 1x1 convolution projects each
    level to ``embed_dim`` channels.

    Args:
        inplanes: Channel width of the stem; later stages use 2x and 4x this.
        embed_dim: Output channels of the four 1x1 projections.
        with_cp: Run the forward pass under ``torch.utils.checkpoint`` when the
            input requires grad.
    """

    def __init__(self, inplanes=64, embed_dim=384, with_cp=False):
        super().__init__()
        self.with_cp = with_cp

        def conv_bn_relu(c_in, c_out, stride):
            # One 3x3 conv / SyncBatchNorm / ReLU group, returned as a flat list of layers.
            return [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1, bias=False),
                nn.SyncBatchNorm(c_out),
                nn.ReLU(inplace=True),
            ]

        self.stem = nn.Sequential(
            *conv_bn_relu(3, inplanes, 2),
            *conv_bn_relu(inplanes, inplanes, 1),
            *conv_bn_relu(inplanes, inplanes, 1),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        self.conv2 = nn.Sequential(*conv_bn_relu(inplanes, 2 * inplanes, 2))
        self.conv3 = nn.Sequential(*conv_bn_relu(2 * inplanes, 4 * inplanes, 2))
        self.conv4 = nn.Sequential(*conv_bn_relu(4 * inplanes, 4 * inplanes, 2))

        # 1x1 projections of each level to the embedding width.
        self.fc1 = nn.Conv2d(inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True)
        self.fc2 = nn.Conv2d(2 * inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True)
        self.fc3 = nn.Conv2d(4 * inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True)
        self.fc4 = nn.Conv2d(4 * inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True)

    def forward(self, x):
        """Return ``(c1, c2, c3, c4)``.

        ``c1`` keeps its 4-D ``(B, embed_dim, H, W)`` layout; ``c2``-``c4`` are
        flattened to ``(B, H * W, embed_dim)`` token sequences.
        """

        def _pyramid(img):
            f1 = self.stem(img)
            f2 = self.conv2(f1)
            f3 = self.conv3(f2)
            f4 = self.conv4(f3)

            c1 = self.fc1(f1)
            projected = (self.fc2(f2), self.fc3(f3), self.fc4(f4))

            bs, dim, _, _ = c1.shape
            # Only the three coarser levels are turned into token sequences.
            c2, c3, c4 = (level.view(bs, dim, -1).transpose(1, 2) for level in projected)
            return c1, c2, c3, c4

        if self.with_cp and x.requires_grad:
            return cp.checkpoint(_pyramid, x)
        return _pyramid(x)
| |
|