from functools import partial
from typing import Tuple, List, Optional

import torch
from torch import Tensor, nn

from mmengine.model import BaseModule, normal_init
from mmdet.registry import MODELS
from mmdet.models.layers import PatchEmbed

from ext.meta.sam_meta import checkpoint_dict
from ext.sam.common import LayerNorm2d
from ext.sam.image_encoder import Block
from utils.load_checkpoint import load_checkpoint_with_prefix
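
# Overview (added comment): this neck fuses multi-scale backbone features into
# a single SAM-style feature map. Each selected level is resampled to a common
# 1/16 stride, tagged with a learned level embedding, summed, passed through a
# short stack of SAM ViT blocks, and projected by a small convolutional neck.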
@MODELS.register_module()
class MultiLayerTransformerNeck(BaseModule):
    STRIDE = 16
    def __init__(
            self,
            input_size: Tuple[int, int],
            in_channels: List[int],
            embed_channels: int,
            out_channels: int,
            layer_ids: Tuple[int, ...] = (0, 1, 2, 3),
            strides: Tuple[int, ...] = (4, 8, 16, 32),
            embedding_path: Optional[str] = None,
            fix: bool = False,
            init_cfg: Optional[dict] = None,
    ) -> None:
        # init_cfg is handled manually at the end of __init__ (see the
        # 'Pretrained' branch below), so it is not forwarded to BaseModule.
        super().__init__(init_cfg=None)
        self.transformer_size = (input_size[0] // self.STRIDE, input_size[1] // self.STRIDE)
        self.layer_ids = layer_ids

        self.patch_embeds = nn.ModuleList()
        for idx, in_ch in enumerate(in_channels):
            if idx in layer_ids:
                if strides[idx] > self.STRIDE:
                    # Level is coarser than the transformer grid: upsample
                    # with a transposed convolution.
                    patch_embed = PatchEmbed(
                        conv_type=nn.ConvTranspose2d,
                        in_channels=in_ch,
                        embed_dims=embed_channels,
                        kernel_size=strides[idx] // self.STRIDE,
                        stride=strides[idx] // self.STRIDE,
                        input_size=(input_size[0] // strides[idx], input_size[1] // strides[idx])
                    )
                else:
                    # Level is finer than or equal to the transformer grid:
                    # downsample with a strided convolution (kernel 1 when equal).
                    patch_embed = PatchEmbed(
                        in_channels=in_ch,
                        embed_dims=embed_channels,
                        kernel_size=self.STRIDE // strides[idx],
                        stride=self.STRIDE // strides[idx],
                        input_size=(input_size[0] // strides[idx], input_size[1] // strides[idx])
                    )
                self.patch_embeds.append(patch_embed)
            else:
                self.patch_embeds.append(nn.Identity())
        if embedding_path is not None:
            # Initialize the positional embedding from a SAM image-encoder
            # checkpoint; the split below means e.g. 'sam_vit_h' looks up
            # checkpoint_dict['vit_h'].
            assert embedding_path.startswith('sam_')
            embedding_ckpt = embedding_path.split('_', maxsplit=1)[1]
            path = checkpoint_dict[embedding_ckpt]
            state_dict = load_checkpoint_with_prefix(path, prefix='image_encoder')
            pos_embed = state_dict['pos_embed']
        else:
            # Placeholder; the real values are expected to be loaded later
            # from a checkpoint (e.g. via init_cfg).
            pos_embed = torch.zeros(1, input_size[0] // self.STRIDE, input_size[1] // self.STRIDE, embed_channels)
        self.register_buffer('pos_embed', pos_embed)
        self.level_encoding = nn.Embedding(len(layer_ids), embed_channels)

        # A short SAM-style ViT trunk: windowed attention in every block
        # except one global-attention block (window_size=0 disables windowing).
        depth = 5
        global_attn_indexes = [4]
        window_size = 14

        self.blocks = nn.ModuleList()
        for i in range(depth):
            block = Block(
                dim=embed_channels,
                num_heads=16,
                mlp_ratio=4,
                qkv_bias=True,
                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                act_layer=nn.GELU,
                use_rel_pos=True,
                rel_pos_zero_init=True,
                window_size=window_size if i not in global_attn_indexes else 0,
                input_size=self.transformer_size,
            )
            self.blocks.append(block)
        self.neck = nn.Sequential(
            nn.Conv2d(
                embed_channels,
                out_channels,
                kernel_size=1,
                bias=False,
            ),
            LayerNorm2d(out_channels),
            nn.Conv2d(
                out_channels,
                out_channels,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            LayerNorm2d(out_channels),
        )
        self.fix = fix
        if self.fix:
            # Freeze the whole neck: eval mode plus no gradients.
            self.train(mode=False)
            for param in self.parameters():
                param.requires_grad = False

        if init_cfg is not None:
            assert init_cfg['type'] == 'Pretrained'
            checkpoint_path = init_cfg['checkpoint']
            state_dict = load_checkpoint_with_prefix(checkpoint_path, prefix=init_cfg['prefix'])
            self.load_state_dict(state_dict, strict=True)
            self._is_init = True
    def init_weights(self):
        normal_init(self.level_encoding, mean=0, std=1)
    def train(self, mode: bool = True) -> torch.nn.Module:
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
        # A fixed neck stays in eval mode regardless of the requested mode.
        if self.fix:
            super().train(mode=False)
        else:
            super().train(mode=mode)
        return self
    def forward(self, inputs: Tuple[Tensor, ...]) -> Tensor:
        input_embeddings = []
        level_cnt = 0
        for idx, feat in enumerate(inputs):
            if idx not in self.layer_ids:
                continue
            # Resample to the common 1/16 grid; PatchEmbed returns a flattened
            # (B, H*W, C) tensor plus its spatial size, which unflatten restores.
            feat, size = self.patch_embeds[idx](feat)
            feat = feat.unflatten(1, size)
            # Tag each level with its learned level embedding.
            feat = feat + self.level_encoding.weight[level_cnt]
            input_embeddings.append(feat)
            level_cnt += 1
        # Fuse levels by summation, add the positional embedding, and run the
        # transformer blocks in (B, H, W, C) layout.
        feat = sum(input_embeddings)
        feat = feat + self.pos_embed
        for block in self.blocks:
            feat = block(feat)
        # Back to (B, C, H, W) for the convolutional projection.
        feat = feat.permute(0, 3, 1, 2).contiguous()
        feat = self.neck(feat)
        return feat
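
# ---------------------------------------------------------------------------
# Hypothetical usage sketch (added; not part of the original module). A
# minimal smoke test assuming a 1024x1024 input, ResNet-style channel widths
# (256, 512, 1024, 2048), and an illustrative embed width of 256; all of these
# values are assumptions, not from the source. Running it requires
# mmdet/mmengine and the project's ext.sam modules to be importable.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    in_channels = [256, 512, 1024, 2048]
    strides = (4, 8, 16, 32)
    neck = MultiLayerTransformerNeck(
        input_size=(1024, 1024),
        in_channels=in_channels,
        embed_channels=256,
        out_channels=256,
    )
    # Dummy multi-scale features at strides 4/8/16/32.
    feats = tuple(
        torch.rand(1, ch, 1024 // s, 1024 // s)
        for ch, s in zip(in_channels, strides)
    )
    out = neck(feats)
    print(out.shape)  # expected: torch.Size([1, 256, 64, 64])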