| """ Vision Transformer (ViT) in PyTorch |
| A PyTorch implement of Vision Transformers as described in: |
| 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' |
| - https://arxiv.org/abs/2010.11929 |
| `How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers` |
| - https://arxiv.org/abs/2106.10270 |
| The official jax code is released and available at https://github.com/google-research/vision_transformer |
| Acknowledgments: |
| * The paper authors for releasing code and weights, thanks! |
| * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out |
| for some einops/einsum fun |
| * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT |
| * Bert reference code checks against Huggingface Transformers and Tensorflow Bert |
| Hacked together by / Copyright 2020, Ross Wightman |
| """ |
|
|
| import math |
| import logging |
| from functools import partial |
| from collections import OrderedDict |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torch.utils.checkpoint |
|
|
| from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD |
| from timm.models.helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply, adapt_input_conv, checkpoint_seq |
| from timm.models.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ |
| from timm.models.registry import register_model |
|
|
| _logger = logging.getLogger(__name__) |
|
|
|
|
def _cfg(url='', **kwargs):
    """Build a default pretrained-weight config dict for a ViT variant.

    Any keyword argument overrides the corresponding default entry.
    """
    cfg = dict(
        url=url,
        num_classes=1000, input_size=(3, 224, 224), pool_size=None,
        crop_pct=.9, interpolation='bicubic', fixed_input_size=True,
        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
        first_conv='patch_embed.proj', classifier='head',
    )
    cfg.update(kwargs)
    return cfg
|
|
|
|
# Pretrained-weight configs keyed by model variant name.
default_cfgs = {

    # ImageNet-1k fine-tuned weights (AugReg recipe, weights ported from JAX)
    'vit_tiny_patch16_224': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
    'vit_tiny_patch16_384': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
        input_size=(3, 384, 384), crop_pct=1.0),
    'vit_small_patch32_224': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
    'vit_small_patch32_384': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
        input_size=(3, 384, 384), crop_pct=1.0),
    'vit_small_patch16_224': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
    'vit_small_patch16_384': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
        input_size=(3, 384, 384), crop_pct=1.0),
    'vit_base_patch32_224': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
    'vit_base_patch32_384': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
        input_size=(3, 384, 384), crop_pct=1.0),
    'vit_base_patch16_224': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
    'vit_base_patch16_384': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
        input_size=(3, 384, 384), crop_pct=1.0),
    'vit_base_patch8_224': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
    'vit_large_patch32_224': _cfg(
        url='',  # no official pretrained weights available for this variant
        ),
    'vit_large_patch32_384': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
        input_size=(3, 384, 384), crop_pct=1.0),
    'vit_large_patch16_224': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
    'vit_large_patch16_384': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
        input_size=(3, 384, 384), crop_pct=1.0),

    # architecture definitions only, no released weights
    'vit_large_patch14_224': _cfg(url=''),
    'vit_huge_patch14_224': _cfg(url=''),
    'vit_giant_patch14_224': _cfg(url=''),
    'vit_gigantic_patch14_224': _cfg(url=''),

    'vit_base2_patch32_256': _cfg(url='', input_size=(3, 256, 256), crop_pct=0.95),

    # ImageNet-21k pretrained weights (no 1k fine-tune); note the 21843-class head
    'vit_tiny_patch16_224_in21k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
        num_classes=21843),
    'vit_small_patch32_224_in21k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
        num_classes=21843),
    'vit_small_patch16_224_in21k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
        num_classes=21843),
    'vit_base_patch32_224_in21k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz',
        num_classes=21843),
    'vit_base_patch16_224_in21k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
        num_classes=21843),
    'vit_base_patch8_224_in21k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
        num_classes=21843),
    'vit_large_patch32_224_in21k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
        num_classes=21843),
    'vit_large_patch16_224_in21k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz',
        num_classes=21843),
    'vit_huge_patch14_224_in21k': _cfg(
        url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz',
        hf_hub_id='timm/vit_huge_patch14_224_in21k',
        num_classes=21843),

    # SAM-trained (Sharpness-Aware Minimization) weights, https://arxiv.org/abs/2106.01548
    'vit_base_patch32_224_sam': _cfg(
        url='https://storage.googleapis.com/vit_models/sam/ViT-B_32.npz'),
    'vit_base_patch16_224_sam': _cfg(
        url='https://storage.googleapis.com/vit_models/sam/ViT-B_16.npz'),

    # DINO self-supervised weights (no classifier head, hence num_classes=0)
    'vit_small_patch16_224_dino': _cfg(
        url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
    'vit_small_patch8_224_dino': _cfg(
        url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
    'vit_base_patch16_224_dino': _cfg(
        url='https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
    'vit_base_patch8_224_dino': _cfg(
        url='https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),

    # MIIL (ImageNet-21K-P) weights; note raw 0..1 normalization (mean 0, std 1)
    'vit_base_patch16_224_miil_in21k': _cfg(
        url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/vit_base_patch16_224_in21k_miil.pth',
        mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221,
    ),
    'vit_base_patch16_224_miil': _cfg(
        url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm'
            '/vit_base_patch16_224_1k_miil_84_4.pth',
        mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear',
    ),

    # experimental depth/parallelism variants, no weights
    'vit_small_patch16_36x1_224': _cfg(url=''),
    'vit_small_patch16_18x2_224': _cfg(url=''),
    'vit_base_patch16_18x2_224': _cfg(url=''),
}
|
|
|
|
class Attention_LoRA(nn.Module):
    """Multi-head self-attention with per-task low-rank (LoRA) adapters on k/v.

    Each of ``n_tasks`` tasks owns a rank-``r`` adapter pair (A: dim->r,
    B: r->dim) for both the key and value projections.  During the forward
    pass the adapters of tasks ``0..task`` are summed and added on top of the
    shared ``qkv`` projection output, so earlier tasks' updates are preserved.
    Also maintains running feature second-moment matrices (``matrix`` /
    ``cur_matrix``) used by the surrounding continual-learning training code.

    Args:
        dim: embedding dimension
        num_heads: number of attention heads
        qkv_bias: add a bias term to the qkv projection
        qk_scale: optional override for the attention scale (default head_dim**-0.5)
        attn_drop: attention-probability dropout rate
        proj_drop: output-projection dropout rate
        r: LoRA rank
        n_tasks: number of per-task adapter pairs to allocate
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., r=64, n_tasks=10):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.attn_gradients = None  # populated via register_hook in forward()
        self.attention_map = None   # last attention map when register_hook=True
        # fix: `self.rank = r` was assigned twice in the original; keep one
        self.rank = r

        # one (A, B) adapter pair per task, for keys and for values
        self.lora_A_k = nn.ModuleList([nn.Linear(dim, r, bias=False) for _ in range(n_tasks)])
        self.lora_B_k = nn.ModuleList([nn.Linear(r, dim, bias=False) for _ in range(n_tasks)])
        self.lora_A_v = nn.ModuleList([nn.Linear(dim, r, bias=False) for _ in range(n_tasks)])
        self.lora_B_v = nn.ModuleList([nn.Linear(r, dim, bias=False) for _ in range(n_tasks)])

        # running (uncentered) feature covariance accumulators, kept on CPU;
        # plain tensors (not buffers), so they are excluded from the state_dict
        self.matrix = torch.zeros(dim, dim)
        self.n_matrix = 0
        self.cur_matrix = torch.zeros(dim, dim)
        self.n_cur_matrix = 0

    def init_param(self):
        """Standard LoRA init: A ~ Kaiming-uniform, B = 0 (adapters start as no-ops)."""
        for t in range(len(self.lora_A_k)):
            nn.init.kaiming_uniform_(self.lora_A_k[t].weight, a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.lora_A_v[t].weight, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B_k[t].weight)
            nn.init.zeros_(self.lora_B_v[t].weight)

    def save_attn_gradients(self, attn_gradients):
        """Hook target: stash the gradient of the attention map."""
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        """Return the last stored attention-map gradient (or None)."""
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        """Stash the forward attention map for later inspection."""
        self.attention_map = attention_map

    def get_attention_map(self):
        """Return the last stored attention map (or None)."""
        return self.attention_map

    def forward(self, x, task, register_hook=False, get_feat=False, get_cur_feat=False):
        """Attention over ``x`` (B, N, C); ``task`` selects which LoRA adapters apply.

        A ``task`` value below -0.5 (e.g. -1) disables all adapters.
        ``get_feat`` / ``get_cur_feat`` update the running x^T x statistics.
        """
        if get_feat:
            # incremental mean of token outer products over all tokens seen so far
            self.matrix = (self.matrix * self.n_matrix + torch.bmm(x.detach().permute(0, 2, 1), x.detach()).sum(dim=0).cpu()) / (self.n_matrix + x.shape[0] * x.shape[1])
            self.n_matrix += x.shape[0] * x.shape[1]
        if get_cur_feat:
            self.cur_matrix = (self.cur_matrix * self.n_cur_matrix + torch.bmm(x.detach().permute(0, 2, 1), x.detach()).sum(dim=0).cpu()) / (self.n_cur_matrix + x.shape[0] * x.shape[1])
            self.n_cur_matrix += x.shape[0] * x.shape[1]

        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        if task > -0.5:
            # sum the B@A deltas of tasks 0..task and add them to k and v
            weight_k = torch.stack([torch.mm(self.lora_B_k[t].weight, self.lora_A_k[t].weight) for t in range(task + 1)], dim=0).sum(dim=0)
            weight_v = torch.stack([torch.mm(self.lora_B_v[t].weight, self.lora_A_v[t].weight) for t in range(task + 1)], dim=0).sum(dim=0)
            k = k + F.linear(x, weight_k).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
            v = v + F.linear(x, weight_v).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        if register_hook:
            self.save_attention_map(attn)
            attn.register_hook(self.save_attn_gradients)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def get_matrix(self, task):
        """Return the (k, v) weight deltas B@A for a single task."""
        matrix_k = torch.mm(self.lora_B_k[task].weight, self.lora_A_k[task].weight)
        matrix_v = torch.mm(self.lora_B_v[task].weight, self.lora_A_v[task].weight)
        return matrix_k, matrix_v

    def get_pre_matrix(self, task):
        """Return the summed (k, v) weight deltas of all tasks strictly before ``task``."""
        with torch.no_grad():
            weight_k = torch.stack([torch.mm(self.lora_B_k[t].weight, self.lora_A_k[t].weight) for t in range(task)], dim=0).sum(dim=0)
            weight_v = torch.stack([torch.mm(self.lora_B_v[t].weight, self.lora_A_v[t].weight) for t in range(task)], dim=0).sum(dim=0)
        return weight_k, weight_v
|
|
class LayerScale(nn.Module):
    """Learnable per-channel scaling (the `gamma` of CaiT-style layer scale).

    Multiplies the input by a trainable vector of size ``dim``, initialized
    to ``init_values``.  With ``inplace=True`` the input tensor is mutated.
    """

    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
|
|
|
|
class Block(nn.Module):
    """Pre-norm transformer encoder block built around Attention_LoRA.

    Both the attention and MLP branches are residual, each with optional
    LayerScale (when ``init_values`` is set) and stochastic depth
    (when ``drop_path`` > 0).  ``task`` is forwarded to the LoRA attention.
    """

    def __init__(
            self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,
            drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, n_tasks=10, r=64):
        super().__init__()
        # attention branch (submodule creation order preserved for RNG parity)
        self.norm1 = norm_layer(dim)
        self.attn = Attention_LoRA(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop,
            proj_drop=drop, n_tasks=n_tasks, r=r)
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        # MLP branch
        self.norm2 = norm_layer(dim)
        hidden = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=hidden, act_layer=act_layer, drop=drop)
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x, task, register_hook=False, get_feat=False, get_cur_feat=False):
        attn_out = self.attn(
            self.norm1(x), task, register_hook=register_hook, get_feat=get_feat, get_cur_feat=get_cur_feat)
        x = x + self.drop_path1(self.ls1(attn_out))
        mlp_out = self.mlp(self.norm2(x))
        x = x + self.drop_path2(self.ls2(mlp_out))
        return x
|
|
|
|
class ParallelBlock(nn.Module):
    """Transformer block with ``num_parallel`` attention and FFN branches
    applied side by side; branch outputs are summed into the residual.

    NOTE(review): each 'attn' submodule is an Attention_LoRA wrapped in
    nn.Sequential, which calls forward(x) only — Attention_LoRA.forward
    requires a ``task`` argument, so invoking this block as written would
    raise a TypeError.  Presumably this class is unused or predates the
    LoRA change; confirm before relying on it.
    """

    def __init__(
            self, dim, num_heads, num_parallel=2, mlp_ratio=4., qkv_bias=False, init_values=None,
            drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.num_parallel = num_parallel
        self.attns = nn.ModuleList()
        self.ffns = nn.ModuleList()
        # each branch: norm -> attn/mlp -> optional LayerScale -> optional DropPath
        for _ in range(num_parallel):
            self.attns.append(nn.Sequential(OrderedDict([
                ('norm', norm_layer(dim)),
                ('attn', Attention_LoRA(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)),
                ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()),
                ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity())
            ])))
            self.ffns.append(nn.Sequential(OrderedDict([
                ('norm', norm_layer(dim)),
                ('mlp', Mlp(dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)),
                ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()),
                ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity())
            ])))

    def _forward_jit(self, x):
        # TorchScript-friendly variant: stack branch outputs and reduce with sum
        x = x + torch.stack([attn(x) for attn in self.attns]).sum(dim=0)
        x = x + torch.stack([ffn(x) for ffn in self.ffns]).sum(dim=0)
        return x

    @torch.jit.ignore
    def _forward(self, x):
        # eager variant: plain Python sum avoids the intermediate stack
        x = x + sum(attn(x) for attn in self.attns)
        x = x + sum(ffn(x) for ffn in self.ffns)
        return x

    def forward(self, x):
        # dispatch to the scripting-safe path only when scripted/traced
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            return self._forward_jit(x)
        else:
            return self._forward(x)
|
|
|
|
class VisionTransformer(nn.Module):
    """ Vision Transformer
    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
        - https://arxiv.org/abs/2010.11929

    This variant adds per-task LoRA adapters (via Attention_LoRA inside each
    Block) plus a "grow" forward path with extra class tokens for
    continual-learning setups.  forward() returns a feature dict, not logits.
    """

    def __init__(
            self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token',
            embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None,
            drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='', init_values=None,
            embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block, n_tasks=10, rank=64):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head
            global_pool (str): type of global pooling for final sequence (default: 'token')
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            weight_init: (str): weight init scheme
            init_values: (float): layer-scale init values
            embed_layer (nn.Module): patch embedding layer
            norm_layer: (nn.Module): normalization layer
            act_layer: (nn.Module): MLP activation layer
            block_fn (nn.Module): transformer block class (default: Block with LoRA attention)
            n_tasks (int): number of LoRA adapter sets per attention layer
            rank (int): LoRA rank passed to each block
        """
        super().__init__()
        assert global_pool in ('', 'avg', 'token')
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.num_classes = num_classes
        self.global_pool = global_pool
        self.num_features = self.embed_dim = embed_dim  # num_features kept for timm API consistency
        self.num_tokens = 1  # single class token (no distillation token)
        self.grad_checkpointing = False

        self.patch_embed = embed_layer(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # pool of extra class tokens consumed by forward_features_grow;
        # NOTE(review): 5000 looks like an upper bound on 2*num_classes — confirm
        self.cls_token_grow = nn.Parameter(torch.zeros(1, 5000, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        # NOTE(review): pos_embed_grow is initialized but never added to tokens
        # in either forward path below — possibly dead; confirm with training code
        self.pos_embed_grow = nn.Parameter(torch.zeros(1, num_patches + 1000, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth decay rule: linearly increasing drop rate per block
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.Sequential(*[
            block_fn(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, init_values=init_values,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer,n_tasks=n_tasks,r=rank)
            for i in range(depth)])
        use_fc_norm = self.global_pool == 'avg'
        self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()

        # optional pre-logits representation layer
        self.representation_size = representation_size
        self.pre_logits = nn.Identity()
        if representation_size:
            self._reset_representation(representation_size)

        # classifier head (fc_norm used instead of final norm when avg-pooling)
        self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
        final_chs = self.representation_size if self.representation_size else self.embed_dim
        self.head = nn.Linear(final_chs, num_classes) if num_classes > 0 else nn.Identity()
        self.out_dim = final_chs

        if weight_init != 'skip':
            self.init_weights(weight_init)

    def _reset_representation(self, representation_size):
        # (re)build the pre-logits projection; Identity when size is falsy
        self.representation_size = representation_size
        if self.representation_size:
            self.pre_logits = nn.Sequential(OrderedDict([
                ('fc', nn.Linear(self.embed_dim, self.representation_size)),
                ('act', nn.Tanh())
            ]))
        else:
            self.pre_logits = nn.Identity()

    def init_weights(self, mode=''):
        """Initialize tokens/pos-embeds, then apply the per-module scheme for `mode`."""
        assert mode in ('jax', 'jax_nlhb', 'moco', '')
        # 'nlhb' = negative log of head bias, biases the head toward uniform priors
        head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.pos_embed_grow, std=.02)
        nn.init.normal_(self.cls_token, std=1e-6)
        nn.init.normal_(self.cls_token_grow, std=1e-6)
        named_apply(get_init_weights_vit(mode, head_bias), self)

    def _init_weights(self, m):
        # kept for backward compat with external .apply(model._init_weights) callers
        init_weights_vit_timm(m)

    @torch.jit.ignore()
    def load_pretrained(self, checkpoint_path, prefix=''):
        """Load weights from an official JAX .npz checkpoint."""
        _load_weights(self, checkpoint_path, prefix)

    @torch.jit.ignore
    def no_weight_decay(self):
        # parameter names excluded from weight decay by timm optimizers
        return {'pos_embed', 'cls_token', 'dist_token'}

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        # regex-based parameter grouping used by timm layer-wise LR decay
        return dict(
            stem=r'^cls_token|pos_embed|patch_embed',
            blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes: int, global_pool=None, representation_size=None):
        """Replace the classification head (and optionally pooling / pre-logits)."""
        self.num_classes = num_classes
        if global_pool is not None:
            assert global_pool in ('', 'avg', 'token')
            self.global_pool = global_pool
        if representation_size is not None:
            self._reset_representation(representation_size)
        final_chs = self.representation_size if self.representation_size else self.embed_dim
        self.head = nn.Linear(final_chs, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        """Patchify, prepend cls token, add pos embed, run blocks, final norm.

        NOTE(review): Block.forward requires a `task` argument, but
        nn.Sequential here calls each block with x only — this path would
        raise a TypeError as written; presumably the training code iterates
        self.blocks manually with a task id.  Confirm before using directly.
        """
        x = self.patch_embed(x)
        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)

        x = self.pos_drop(x + self.pos_embed)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.blocks, x)
        else:
            x = self.blocks(x)
        x = self.norm(x)
        return x

    def forward_features_grow(self, x, class_num):
        """Like forward_features, but prepends 2*class_num extra class tokens
        (taken from cls_token_grow) AFTER positional embedding is applied,
        so the grown tokens carry no positional information."""
        x = self.patch_embed(x)

        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = self.pos_drop(x + self.pos_embed)
        x = torch.cat((self.cls_token_grow[:, :class_num*2, :].expand(x.shape[0], -1, -1), x), dim=1)

        # see NOTE in forward_features about the missing `task` argument
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.blocks, x)
        else:
            x = self.blocks(x)
        x = self.norm(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        """Pool the token sequence and run (optionally) the classifier head."""
        if self.global_pool:
            x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
        x = self.fc_norm(x)
        x = self.pre_logits(x)
        return x if pre_logits else self.head(x)

    def forward(self, x, grow_flag=False, numcls=0):
        """Return pooled features (NOT logits) as {'fmaps': [feat], 'features': feat}."""
        if not grow_flag:
            x = self.forward_features(x)
        else:
            x = self.forward_features_grow(x, numcls)

        # pooling duplicated from forward_head; self.head is intentionally skipped
        if self.global_pool:
            x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
        x = self.fc_norm(x)
        return {
            'fmaps': [x],
            'features': x
        }
|
|
|
|
def init_weights_vit_timm(module: nn.Module, name: str = ''):
    """ ViT weight initialization, original timm impl (for reproducibility) """
    if not isinstance(module, nn.Linear):
        return
    trunc_normal_(module.weight, std=.02)
    if module.bias is not None:
        nn.init.zeros_(module.bias)
|
|
|
|
def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0.):
    """ ViT weight initialization, matching JAX (Flax) impl """
    if isinstance(module, nn.Linear):
        if name.startswith('head'):
            # classifier starts at zero; bias may encode a log-prior
            nn.init.zeros_(module.weight)
            nn.init.constant_(module.bias, head_bias)
        elif name.startswith('pre_logits'):
            lecun_normal_(module.weight)
            nn.init.zeros_(module.bias)
        else:
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                # Flax gives MLP biases a tiny normal init, zeros elsewhere
                if 'mlp' in name:
                    nn.init.normal_(module.bias, std=1e-6)
                else:
                    nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Conv2d):
        lecun_normal_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
|
|
|
|
def init_weights_vit_moco(module: nn.Module, name: str = ''):
    """ ViT weight initialization, matching moco-v3 impl minus fixed PatchEmbed """
    if not isinstance(module, nn.Linear):
        return
    if 'qkv' in name:
        # treat q, k, v as three separate matrices: fan_in uses out_features // 3
        fan = module.weight.shape[0] // 3 + module.weight.shape[1]
        bound = math.sqrt(6. / float(fan))
        nn.init.uniform_(module.weight, -bound, bound)
    else:
        nn.init.xavier_uniform_(module.weight)
    if module.bias is not None:
        nn.init.zeros_(module.bias)
|
|
|
|
def get_init_weights_vit(mode='jax', head_bias: float = 0.):
    """Map an init-mode string onto the matching per-module init function."""
    if 'jax' in mode:
        return partial(init_weights_vit_jax, head_bias=head_bias)
    if 'moco' in mode:
        return init_weights_vit_moco
    return init_weights_vit_timm
|
|
|
|
@torch.no_grad()
def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
    """ Load weights from .npz checkpoints for official Google Brain Flax implementation
    """
    import numpy as np

    def _n2p(w, t=True):
        # numpy -> torch, transposing from Flax (HWIO / row-major) layout when t=True
        if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
            w = w.flatten()
        if t:
            if w.ndim == 4:
                w = w.transpose([3, 2, 0, 1])
            elif w.ndim == 3:
                w = w.transpose([2, 0, 1])
            elif w.ndim == 2:
                w = w.transpose([1, 0])
        return torch.from_numpy(w)

    w = np.load(checkpoint_path)
    if not prefix and 'opt/target/embedding/kernel' in w:
        # checkpoint saved with optimizer state wrapper; weights live under opt/target/
        prefix = 'opt/target/'

    if hasattr(model.patch_embed, 'backbone'):
        # hybrid model: load the ResNet(V2) stem + stages before the transformer
        backbone = model.patch_embed.backbone
        stem_only = not hasattr(backbone, 'stem')
        stem = backbone if stem_only else backbone.stem
        stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
        stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
        stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
        if not stem_only:
            for i, stage in enumerate(backbone.stages):
                for j, block in enumerate(stage.blocks):
                    bp = f'{prefix}block{i + 1}/unit{j + 1}/'
                    for r in range(3):
                        getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
                        getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
                        getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
                    if block.downsample is not None:
                        block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
                        block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
                        block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
        embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
    else:
        # plain ViT: adapt the patch-embedding conv to the model's input channels
        embed_conv_w = adapt_input_conv(
            model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
    model.patch_embed.proj.weight.copy_(embed_conv_w)
    model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
    model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
    pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
    if pos_embed_w.shape != model.pos_embed.shape:
        # checkpoint trained at a different resolution; resample the grid
        pos_embed_w = resize_pos_embed(
            pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
    model.pos_embed.copy_(pos_embed_w)
    model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
    model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
    # only load the classifier when the class counts match
    if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
        model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
        model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
    if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
        model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
        model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
    for i, block in enumerate(model.blocks.children()):
        block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
        mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
        block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
        block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
        # Flax stores q/k/v separately; fuse them into the single qkv projection
        block.attn.qkv.weight.copy_(torch.cat([
            _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
        block.attn.qkv.bias.copy_(torch.cat([
            _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
        block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
        block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
        for r in range(2):
            getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
            getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
        block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
        block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
|
|
|
|
def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
    """Rescale a position embedding grid to a new token count.

    The first ``num_tokens`` entries (cls/dist tokens) are kept as-is; the
    remaining entries are treated as a square grid and bicubic-interpolated
    to ``gs_new`` (inferred as square from ``posemb_new`` when not given).
    """
    log = logging.getLogger(__name__)
    log.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
    ntok_new = posemb_new.shape[1]
    if num_tokens:
        posemb_tok = posemb[:, :num_tokens]
        posemb_grid = posemb[0, num_tokens:]
        ntok_new -= num_tokens
    else:
        posemb_tok = posemb[:, :0]
        posemb_grid = posemb[0]
    gs_old = int(math.sqrt(len(posemb_grid)))
    if not len(gs_new):
        # assume a square target grid when none is provided
        gs_new = [int(math.sqrt(ntok_new))] * 2
    assert len(gs_new) >= 2
    log.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new)
    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
    posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False)
    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
    return torch.cat([posemb_tok, posemb_grid], dim=1)
|
|
|
|
def checkpoint_filter_fn(state_dict, model):
    """ convert patch embedding weight from manual patchify + linear proj to conv"""
    # DeiT-style checkpoints nest everything under a 'model' key
    if 'model' in state_dict:
        state_dict = state_dict['model']
    filtered = {}
    for name, tensor in state_dict.items():
        if 'patch_embed.proj.weight' in name and len(tensor.shape) < 4:
            # linear patchify weight -> conv kernel shape
            O, I, H, W = model.patch_embed.proj.weight.shape
            tensor = tensor.reshape(O, -1, H, W)
        elif name == 'pos_embed' and tensor.shape != model.pos_embed.shape:
            # resample to the model's grid when resolutions differ
            tensor = resize_pos_embed(
                tensor, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
        filtered[name] = tensor
    return filtered
|
|
|
|
def _create_vision_transformer(variant, pretrained=False, **kwargs):
    """Instantiate a VisionTransformer variant via timm's build helper.

    ``variant`` selects an entry from default_cfgs; extra kwargs are passed
    through to the VisionTransformer constructor.
    """
    if kwargs.get('features_only', None):
        raise RuntimeError('features_only not implemented for Vision Transformer models.')

    # NOTE(review): newer timm returns a PretrainedCfg dataclass here, which
    # does not support dict-style indexing — this code assumes an older timm
    # where the cfg is a plain dict; confirm against the pinned timm version.
    pretrained_cfg = resolve_pretrained_cfg(variant)
    default_num_classes = pretrained_cfg['num_classes']
    num_classes = kwargs.get('num_classes', default_num_classes)
    repr_size = kwargs.pop('representation_size', None)
    if repr_size is not None and num_classes != default_num_classes:
        # fine-tuning to a different class count: drop the pre-logits layer,
        # matching how the official 21k -> 1k fine-tunes were done
        _logger.warning("Removing representation layer for fine-tuning.")
        repr_size = None

    # NOTE(review): this del raises KeyError unless the caller passed a
    # 'pretrained_cfg' kwarg — presumably all register_model entrypoints do;
    # verify, or guard with kwargs.pop('pretrained_cfg', None).
    if pretrained_cfg:
        del kwargs['pretrained_cfg']

    model = build_model_with_cfg(
        VisionTransformer, variant, pretrained,
        pretrained_cfg=pretrained_cfg,
        representation_size=repr_size,
        # .npz checkpoints use the custom Flax loader (_load_weights)
        pretrained_custom_load='npz' in pretrained_cfg['url'],
        pretrained_filter_fn=checkpoint_filter_fn,
        **kwargs)
    return model