import warnings

import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmengine.model import BaseModule

from mmdet.registry import MODELS
from ..layers import PatchEmbed


def expand_tensor_along_second_dim(x, num):
    """Expand ``x`` of shape (N, C, H, W) to ``num`` channels by cyclic
    repetition along dim 1."""
    assert x.size(1) <= num
    repeat_times = num // x.size(1)
    remainder = num % x.size(1)
    x = x.repeat(1, repeat_times, 1, 1)
    if remainder != 0:
        # Pad with the first ``remainder`` channels so the result has
        # exactly ``num`` channels.
        x = torch.cat([x, x[:, :remainder]], dim=1)
    return x
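
# Usage sketch (shapes are illustrative): a 2-channel input expanded to 5
# channels repeats the channels cyclically as [0, 1, 0, 1, 0]:
#   x = torch.randn(4, 2, 32, 32)
#   expand_tensor_along_second_dim(x, 5).shape  # torch.Size([4, 5, 32, 32])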


def extract_tensor_along_second_dim(x, m):
    """Subsample ``x`` of shape (N, C, H, W) down to ``m`` channels by
    picking ``m`` evenly spaced channel indices."""
    idx = torch.linspace(0, x.size(1) - 1, m).long().to(x.device)
    x = torch.index_select(x, 1, idx)
    return x
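
# Usage sketch (shapes are illustrative): reducing 10 channels to 3 keeps
# channels 0, 4 and 9 (linspace(0, 9, 3) -> [0., 4.5, 9.] -> long()):
#   x = torch.randn(4, 10, 32, 32)
#   extract_tensor_along_second_dim(x, 3).shape  # torch.Size([4, 3, 32, 32])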


@MODELS.register_module()
class No_backbone_ST(BaseModule):
    """Backbone-free stem that projects the input with a 1x1 conv plus a
    norm layer and returns the result as the feature pyramid."""

    def __init__(self,
                 in_channels=3,
                 embed_dims=96,
                 strides=(1, 2, 2, 4),
                 patch_size=(1, 2, 2, 4),
                 patch_norm=True,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 pretrained=None,
                 num_levels=2,
                 init_cfg=None):
        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be specified at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            # Overwrite the local variable so the Pretrained config survives
            # BaseModule.__init__, which re-assigns self.init_cfg.
            init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is not None:
            raise TypeError('pretrained must be a str or None')

        super(No_backbone_ST, self).__init__(init_cfg=init_cfg)
        assert strides[0] == patch_size[0], 'Use non-overlapping patch embed.'
        self.embed_dims = embed_dims
        self.in_channels = in_channels

        self.patch_embed = PatchEmbed(
            in_channels=in_channels,
            embed_dims=embed_dims,
            conv_type='Conv2d',
            kernel_size=patch_size[0],
            stride=strides[0],
            norm_cfg=norm_cfg if patch_norm else None,
            init_cfg=None)
        self.num_levels = num_levels
        # 1x1 projection used in forward(); patch_embed above and mlp below
        # are alternative projections that forward() currently does not use.
        self.conv = nn.Conv2d(in_channels, embed_dims, kernel_size=1)
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, embed_dims),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Linear(embed_dims, embed_dims),
            nn.LeakyReLU(negative_slope=0.2))
        # forward() always applies self.norm, so fall back to Identity when
        # no norm_cfg is given.
        if norm_cfg is not None:
            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
        else:
            self.norm = nn.Identity()
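
    # Hypothetical config snippet (the surrounding fields are illustrative)
    # showing how the registered backbone would be selected in a model config:
    #   model = dict(
    #       backbone=dict(type='No_backbone_ST', in_channels=3,
    #                     embed_dims=96, num_levels=2),
    #       ...)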

    def train(self, mode=True):
        """Convert the model into training mode."""
        super(No_backbone_ST, self).train(mode)

    def forward(self, x):
        # Match the channel count expected by the projection layers: subsample
        # extra channels, or (assumed intent, since the helper is otherwise
        # unused) cyclically expand when the input has too few channels.
        if self.in_channels < x.size(1):
            x = extract_tensor_along_second_dim(x, self.in_channels)
        elif self.in_channels > x.size(1):
            x = expand_tensor_along_second_dim(x, self.in_channels)

        outs = []
        # 1x1 conv projection, then the norm applied over flattened tokens.
        out = self.conv(x)
        out = self.norm(out.flatten(2).transpose(1, 2))
        # Restore the (N, C, H, W) layout; the 1x1 conv keeps H and W.
        out = out.permute(0, 2, 1).reshape(
            x.size(0), self.embed_dims, x.size(2), x.size(3)).contiguous()
        outs.append(out)
        if self.num_levels > 1:
            # Second pyramid level: a detached global-average-pooled map.
            mean = outs[0].mean(dim=(2, 3), keepdim=True).detach()
            outs.append(mean)
        return outs
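

# Minimal usage sketch (assumes this module is imported so the registry entry
# exists; the input shape is illustrative):
#   backbone = MODELS.build(
#       dict(type='No_backbone_ST', in_channels=3, embed_dims=96,
#            num_levels=2))
#   feats = backbone(torch.randn(1, 3, 224, 224))
#   [f.shape for f in feats]
#   # [torch.Size([1, 96, 224, 224]), torch.Size([1, 96, 1, 1])]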