from functools import partial
from itertools import repeat

import collections.abc as container_abcs

import logging
import os
import json
from collections import OrderedDict

import numpy as np
import scipy.ndimage
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from einops.layers.torch import Rearrange

from timm.models.layers import DropPath, trunc_normal_
from torchinfo import summary

# Registry of model builder functions, keyed by the module name they live in.
_model_entrypoints = {}


def register_model(fn):
    """Register a model builder under the name of the module that defines it."""
    module_name_split = fn.__module__.split('.')
    model_name = module_name_split[-1]

    _model_entrypoints[model_name] = fn

    return fn


def model_entrypoints(model_name):
    return _model_entrypoints[model_name]


def is_model(model_name):
    return model_name in _model_entrypoints


def _ntuple(n):
    def parse(x):
        if isinstance(x, container_abcs.Iterable):
            return x
        return tuple(repeat(x, n))

    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple


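# Example of the helpers above: scalars are broadcast to n-tuples, while
# iterables pass through unchanged.
#   to_2tuple(7)      -> (7, 7)
#   to_2tuple((3, 5)) -> (3, 5)

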
class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class Mlp(nn.Module):
    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(self,
                 dim_in,
                 dim_out,
                 num_heads,
                 qkv_bias=False,
                 attn_drop=0.,
                 proj_drop=0.,
                 method='dw_bn',
                 kernel_size=3,
                 stride_kv=1,
                 stride_q=1,
                 padding_kv=1,
                 padding_q=1,
                 with_cls_token=True,
                 **kwargs
                 ):
        super().__init__()
        self.stride_kv = stride_kv
        self.stride_q = stride_q
        self.dim = dim_out
        self.num_heads = num_heads
        self.scale = dim_out ** -0.5
        self.with_cls_token = with_cls_token

        self.conv_proj_q = self._build_projection(
            dim_in, dim_out, kernel_size, padding_q,
            stride_q, 'linear' if method == 'avg' else method
        )
        self.conv_proj_k = self._build_projection(
            dim_in, dim_out, kernel_size, padding_kv,
            stride_kv, method
        )
        self.conv_proj_v = self._build_projection(
            dim_in, dim_out, kernel_size, padding_kv,
            stride_kv, method
        )

        self.proj_q = nn.Linear(dim_in, dim_out, bias=qkv_bias)
        self.proj_k = nn.Linear(dim_in, dim_out, bias=qkv_bias)
        self.proj_v = nn.Linear(dim_in, dim_out, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim_out, dim_out)
        self.proj_drop = nn.Dropout(proj_drop)

    def _build_projection(self,
                          dim_in,
                          dim_out,
                          kernel_size,
                          padding,
                          stride,
                          method):
        if method == 'dw_bn':
            proj = nn.Sequential(OrderedDict([
                ('conv', nn.Conv2d(
                    dim_in,
                    dim_in,
                    kernel_size=kernel_size,
                    padding=padding,
                    stride=stride,
                    bias=False,
                    groups=dim_in
                )),
                ('bn', nn.BatchNorm2d(dim_in)),
                ('rearrange', Rearrange('b c h w -> b (h w) c')),
            ]))
        elif method == 'avg':
            proj = nn.Sequential(OrderedDict([
                ('avg', nn.AvgPool2d(
                    kernel_size=kernel_size,
                    padding=padding,
                    stride=stride,
                    ceil_mode=True
                )),
                ('rearrange', Rearrange('b c h w -> b (h w) c')),
            ]))
        elif method == 'linear':
            proj = None
        else:
            raise ValueError('Unknown method ({})'.format(method))

        return proj

    def forward_conv(self, x, h, w):
        if self.with_cls_token:
            cls_token, x = torch.split(x, [1, h*w], 1)

        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)

        if self.conv_proj_q is not None:
            q = self.conv_proj_q(x)
        else:
            q = rearrange(x, 'b c h w -> b (h w) c')

        if self.conv_proj_k is not None:
            k = self.conv_proj_k(x)
        else:
            k = rearrange(x, 'b c h w -> b (h w) c')

        if self.conv_proj_v is not None:
            v = self.conv_proj_v(x)
        else:
            v = rearrange(x, 'b c h w -> b (h w) c')

        if self.with_cls_token:
            q = torch.cat((cls_token, q), dim=1)
            k = torch.cat((cls_token, k), dim=1)
            v = torch.cat((cls_token, v), dim=1)

        return q, k, v

    def forward(self, x, h, w):
        if (
            self.conv_proj_q is not None
            or self.conv_proj_k is not None
            or self.conv_proj_v is not None
        ):
            q, k, v = self.forward_conv(x, h, w)
        else:
            # All projections are purely linear: use the tokens as-is.
            q = k = v = x

        q = rearrange(self.proj_q(q), 'b t (h d) -> b h t d', h=self.num_heads)
        k = rearrange(self.proj_k(k), 'b t (h d) -> b h t d', h=self.num_heads)
        v = rearrange(self.proj_v(v), 'b t (h d) -> b h t d', h=self.num_heads)

        attn_score = torch.einsum('bhlk,bhtk->bhlt', [q, k]) * self.scale
        attn = F.softmax(attn_score, dim=-1)
        attn = self.attn_drop(attn)

        x = torch.einsum('bhlt,bhtv->bhlv', [attn, v])
        x = rearrange(x, 'b h t d -> b t (h d)')

        x = self.proj(x)
        x = self.proj_drop(x)

        return x

    @staticmethod
    def compute_macs(module, input, output):
        # input: (B, T, C) token sequence; T includes the cls token if present.
        input = input[0]
        flops = 0

        _, T, C = input.shape
        H = W = int(np.sqrt(T-1)) if module.with_cls_token else int(np.sqrt(T))

        H_Q = H / module.stride_q
        W_Q = W / module.stride_q
        T_Q = H_Q * W_Q + 1 if module.with_cls_token else H_Q * W_Q

        H_KV = H / module.stride_kv
        W_KV = W / module.stride_kv
        T_KV = H_KV * W_KV + 1 if module.with_cls_token else H_KV * W_KV

        # attention scores: q @ k^T
        flops += T_Q * T_KV * module.dim
        # attention-weighted values: attn @ v
        flops += T_Q * module.dim * T_KV

        if (
            hasattr(module, 'conv_proj_q')
            and hasattr(module.conv_proj_q, 'conv')
        ):
            params = sum(
                [
                    p.numel()
                    for p in module.conv_proj_q.conv.parameters()
                ]
            )
            flops += params * H_Q * W_Q

        if (
            hasattr(module, 'conv_proj_k')
            and hasattr(module.conv_proj_k, 'conv')
        ):
            params = sum(
                [
                    p.numel()
                    for p in module.conv_proj_k.conv.parameters()
                ]
            )
            flops += params * H_KV * W_KV

        if (
            hasattr(module, 'conv_proj_v')
            and hasattr(module.conv_proj_v, 'conv')
        ):
            params = sum(
                [
                    p.numel()
                    for p in module.conv_proj_v.conv.parameters()
                ]
            )
            flops += params * H_KV * W_KV

        params = sum([p.numel() for p in module.proj_q.parameters()])
        flops += params * T_Q
        params = sum([p.numel() for p in module.proj_k.parameters()])
        flops += params * T_KV
        params = sum([p.numel() for p in module.proj_v.parameters()])
        flops += params * T_KV
        params = sum([p.numel() for p in module.proj.parameters()])
        flops += params * T

        module.__flops__ += flops


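# Quick illustrative check of the attention block above (not used by the
# model): run a dummy 14x14 token grid with a cls token through Attention and
# print the output shape. All sizes here are arbitrary assumptions.
def _attention_shape_example():
    attn = Attention(dim_in=64, dim_out=64, num_heads=4, with_cls_token=True)
    x = torch.randn(2, 1 + 14 * 14, 64)  # (batch, cls token + h*w tokens, channels)
    out = attn(x, 14, 14)
    print(out.shape)  # expected: torch.Size([2, 197, 64])

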
class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP sub-blocks, each wrapped
    in a residual connection (so dim_in must equal dim_out)."""

    def __init__(self,
                 dim_in,
                 dim_out,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 **kwargs):
        super().__init__()

        self.with_cls_token = kwargs['with_cls_token']

        self.norm1 = norm_layer(dim_in)
        self.attn = Attention(
            dim_in, dim_out, num_heads, qkv_bias, attn_drop, drop,
            **kwargs
        )

        self.drop_path = DropPath(drop_path) \
            if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim_out)

        dim_mlp_hidden = int(dim_out * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim_out,
            hidden_features=dim_mlp_hidden,
            act_layer=act_layer,
            drop=drop
        )

    def forward(self, x, h, w):
        res = x

        x = self.norm1(x)
        attn = self.attn(x, h, w)
        x = res + self.drop_path(attn)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x


class ConvEmbed(nn.Module):
    """Image to Conv Embedding."""

    def __init__(self,
                 patch_size=7,
                 in_chans=1,
                 embed_dim=64,
                 stride=4,
                 padding=2,
                 norm_layer=None):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size

        self.proj = nn.Conv2d(
            in_chans, embed_dim,
            kernel_size=patch_size,
            stride=stride,
            padding=padding
        )
        self.norm = norm_layer(embed_dim) if norm_layer else None

    def forward(self, x):
        x = self.proj(x)

        B, C, H, W = x.shape
        x = rearrange(x, 'b c h w -> b (h w) c')
        if self.norm:
            x = self.norm(x)
        x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W)

        return x


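# Sizing note for ConvEmbed (illustrative, assuming the defaults above): with
# patch_size=7, stride=4, padding=2, each spatial side S maps to
# floor((S + 2*2 - 7) / 4) + 1 tokens, e.g. a 224-pixel side becomes 56.

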
class VisionTransformer(nn.Module):
    """Vision Transformer with support for patch or hybrid CNN input stage."""

    def __init__(self,
                 patch_size=16,
                 patch_stride=16,
                 patch_padding=0,
                 in_chans=1,
                 embed_dim=768,
                 depth=12,
                 num_heads=12,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 init='trunc_norm',
                 **kwargs):
        super().__init__()
        self.num_features = self.embed_dim = embed_dim

        self.rearrange = None

        self.patch_embed = ConvEmbed(
            patch_size=patch_size,
            in_chans=in_chans,
            stride=patch_stride,
            padding=patch_padding,
            embed_dim=embed_dim,
            norm_layer=norm_layer
        )

        with_cls_token = kwargs['with_cls_token']
        if with_cls_token:
            self.cls_token = nn.Parameter(
                torch.zeros(1, 1, embed_dim)
            )
        else:
            self.cls_token = None

        self.pos_drop = nn.Dropout(p=drop_rate)
        # Stochastic depth decay rule: drop-path rate increases linearly per block.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        blocks = []
        for j in range(depth):
            blocks.append(
                Block(
                    dim_in=embed_dim,
                    dim_out=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[j],
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                    **kwargs
                )
            )
        self.blocks = nn.ModuleList(blocks)

        if self.cls_token is not None:
            trunc_normal_(self.cls_token, std=.02)

        if init == 'xavier':
            self.apply(self._init_weights_xavier)
        else:
            self.apply(self._init_weights_trunc_normal)

    def _init_weights_trunc_normal(self, m):
        if isinstance(m, nn.Linear):
            logging.info('=> init weight of Linear from trunc norm')
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                logging.info('=> init bias of Linear to zeros')
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def _init_weights_xavier(self, m):
        if isinstance(m, nn.Linear):
            logging.info('=> init weight of Linear from xavier uniform')
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                logging.info('=> init bias of Linear to zeros')
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        x = self.patch_embed(x)
        B, C, H, W = x.size()

        x = rearrange(x, 'b c h w -> b (h w) c')

        cls_tokens = None
        if self.cls_token is not None:
            cls_tokens = self.cls_token.expand(B, -1, -1)
            x = torch.cat((cls_tokens, x), dim=1)

        x = self.pos_drop(x)

        for i, blk in enumerate(self.blocks):
            x = blk(x, H, W)

        if self.cls_token is not None:
            cls_tokens, x = torch.split(x, [1, H*W], 1)
        x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W)

        return x, cls_tokens


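# Illustrative shape walk-through for one stage (sizes assumed, not fixed by
# this file): with patch_size=7, patch_stride=4, patch_padding=2 and
# embed_dim=64, a (1, 1, 224, 224) input becomes a 56x56 token grid, each
# Block sees 3136 (+1 cls) tokens of width 64, and the stage returns a
# (1, 64, 56, 56) feature map plus the cls token (or None).

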
class ConvolutionalVisionTransformer(nn.Module):
    def __init__(self,
                 in_chans=1,
                 num_classes=1000,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 init='trunc_norm',
                 spec=None):
        super().__init__()
        self.num_classes = num_classes

        self.num_stages = spec['NUM_STAGES']
        for i in range(self.num_stages):
            kwargs = {
                'patch_size': spec['PATCH_SIZE'][i],
                'patch_stride': spec['PATCH_STRIDE'][i],
                'patch_padding': spec['PATCH_PADDING'][i],
                'embed_dim': spec['DIM_EMBED'][i],
                'depth': spec['DEPTH'][i],
                'num_heads': spec['NUM_HEADS'][i],
                'mlp_ratio': spec['MLP_RATIO'][i],
                'qkv_bias': spec['QKV_BIAS'][i],
                'drop_rate': spec['DROP_RATE'][i],
                'attn_drop_rate': spec['ATTN_DROP_RATE'][i],
                'drop_path_rate': spec['DROP_PATH_RATE'][i],
                'with_cls_token': spec['CLS_TOKEN'][i],
                'method': spec['QKV_PROJ_METHOD'][i],
                'kernel_size': spec['KERNEL_QKV'][i],
                'padding_q': spec['PADDING_Q'][i],
                'padding_kv': spec['PADDING_KV'][i],
                'stride_kv': spec['STRIDE_KV'][i],
                'stride_q': spec['STRIDE_Q'][i],
            }

            stage = VisionTransformer(
                in_chans=in_chans,
                init=init,
                act_layer=act_layer,
                norm_layer=norm_layer,
                **kwargs
            )
            setattr(self, f'stage{i}', stage)

            in_chans = spec['DIM_EMBED'][i]

        dim_embed = spec['DIM_EMBED'][-1]
        self.norm = norm_layer(dim_embed)
        self.cls_token = spec['CLS_TOKEN'][-1]

        # No classification head here; forward returns the pooled features as-is.
        self.head = nn.Identity()

    def init_weights(self, pretrained='', pretrained_layers=[], verbose=True):
        if os.path.isfile(pretrained):
            pretrained_dict = torch.load(pretrained, map_location='cpu')
            logging.info(f'=> loading pretrained model {pretrained}')
            model_dict = self.state_dict()
            pretrained_dict = {
                k: v for k, v in pretrained_dict.items()
                if k in model_dict.keys()
            }
            need_init_state_dict = {}
            for k, v in pretrained_dict.items():
                need_init = (
                    k.split('.')[0] in pretrained_layers
                    or (pretrained_layers and pretrained_layers[0] == '*')
                )
                if need_init:
                    if verbose:
                        logging.info(f'=> init {k} from {pretrained}')
                    if 'pos_embed' in k and v.size() != model_dict[k].size():
                        size_pretrained = v.size()
                        size_new = model_dict[k].size()
                        logging.info(
                            '=> load_pretrained: resized variant: {} to {}'
                            .format(size_pretrained, size_new)
                        )

                        ntok_new = size_new[1]
                        ntok_new -= 1

                        posemb_tok, posemb_grid = v[:, :1], v[0, 1:]

                        gs_old = int(np.sqrt(len(posemb_grid)))
                        gs_new = int(np.sqrt(ntok_new))

                        logging.info(
                            '=> load_pretrained: grid-size from {} to {}'
                            .format(gs_old, gs_new)
                        )

                        posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)
                        zoom = (gs_new / gs_old, gs_new / gs_old, 1)
                        posemb_grid = scipy.ndimage.zoom(
                            posemb_grid, zoom, order=1
                        )
                        posemb_grid = posemb_grid.reshape(1, gs_new ** 2, -1)
                        v = torch.tensor(
                            np.concatenate([posemb_tok, posemb_grid], axis=1)
                        )

                    need_init_state_dict[k] = v
            self.load_state_dict(need_init_state_dict, strict=False)

    @torch.jit.ignore
    def no_weight_decay(self):
        layers = set()
        for i in range(self.num_stages):
            layers.add(f'stage{i}.pos_embed')
            layers.add(f'stage{i}.cls_token')

        return layers

    def forward_features(self, x):
        for i in range(self.num_stages):
            x, cls_tokens = getattr(self, f'stage{i}')(x)

        if self.cls_token:
            x = self.norm(cls_tokens)
            # Drop only the token dimension so a batch of size 1 keeps its batch dim.
            x = torch.squeeze(x, dim=1)
        else:
            x = rearrange(x, 'b c h w -> b (h w) c')
            x = self.norm(x)
            x = torch.mean(x, dim=1)

        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)

        return x


@register_model
def get_cls_model(config, **kwargs):
    msvit_spec = config.MODEL.SPEC
    msvit = ConvolutionalVisionTransformer(
        in_chans=1,
        num_classes=config.MODEL.NUM_CLASSES,
        act_layer=QuickGELU,
        norm_layer=partial(LayerNorm, eps=1e-5),
        init=getattr(msvit_spec, 'INIT', 'trunc_norm'),
        spec=msvit_spec
    )

    return msvit


def build_model(config, **kwargs):
    model_name = config.MODEL.NAME
    if not is_model(model_name):
        raise ValueError(f'Unknown model: {model_name}')

    return model_entrypoints(model_name)(config, **kwargs)


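# Hedged usage sketch for the registry above: `build_model` expects a config
# object exposing MODEL.NAME, MODEL.NUM_CLASSES and MODEL.SPEC (e.g. a yacs
# CfgNode). The concrete config class and values below are assumptions, not
# defined in this file, and MODEL.NAME must match the module name this file
# is imported as (that is how @register_model keys the entrypoint).
#
#   from yacs.config import CfgNode as CN
#   cfg = CN(new_allowed=True)
#   cfg.MODEL = CN(new_allowed=True)
#   cfg.MODEL.NAME = 'cls_cvt'              # assumed module/file name
#   cfg.MODEL.NUM_CLASSES = 1000
#   cfg.MODEL.SPEC = CN(new_allowed=True)   # per-stage keys, see get_cls_model
#   model = build_model(cfg)

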
def cvt13(**kwargs):
    with open('config.json', 'r') as f:
        config = json.load(f)
    return ConvolutionalVisionTransformer(spec=config['MODEL']['SPEC'])


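# The 'config.json' read above is assumed (it is not shipped with this file)
# to hold a MODEL.SPEC mapping with one value per stage for every key accessed
# in ConvolutionalVisionTransformer.__init__. A CvT-13-like spec, purely as an
# illustration, could look like:
#
#   {"MODEL": {"SPEC": {
#       "NUM_STAGES": 3,
#       "PATCH_SIZE": [7, 3, 3], "PATCH_STRIDE": [4, 2, 2], "PATCH_PADDING": [2, 1, 1],
#       "DIM_EMBED": [64, 192, 384], "DEPTH": [1, 2, 10], "NUM_HEADS": [1, 3, 6],
#       "MLP_RATIO": [4.0, 4.0, 4.0], "QKV_BIAS": [true, true, true],
#       "DROP_RATE": [0.0, 0.0, 0.0], "ATTN_DROP_RATE": [0.0, 0.0, 0.0],
#       "DROP_PATH_RATE": [0.0, 0.0, 0.1], "CLS_TOKEN": [false, false, true],
#       "QKV_PROJ_METHOD": ["dw_bn", "dw_bn", "dw_bn"], "KERNEL_QKV": [3, 3, 3],
#       "PADDING_Q": [1, 1, 1], "PADDING_KV": [1, 1, 1],
#       "STRIDE_KV": [2, 2, 2], "STRIDE_Q": [1, 1, 1]}}}

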
if __name__ == '__main__':
    with open('config.json', 'r') as f:
        config = json.load(f)
    model = ConvolutionalVisionTransformer(spec=config['MODEL']['SPEC'])
    print(summary(model))
    print(summary(model, input_size=(4, 1, 128, 301)))