# Source: GenD-Sentinel, src/model/fsfm/models_vit.py
# Commit c29babb ("init") by yermandy
# -*- coding: utf-8 -*-
# Author: Gaojian Wang@ZJUICSR
# --------------------------------------------------------
# This source code is licensed under the Attribution-NonCommercial 4.0 International License.
# You can find the license in the LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
from functools import partial
import timm.models.vision_transformer
import torch
import torch.nn as nn
from timm.models.helpers import load_pretrained
from timm.models.vision_transformer import default_cfgs
class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
    """Vision Transformer with support for global average pooling.

    Subclass of timm's ``VisionTransformer`` that overrides feature extraction.
    With ``global_pool=True``, the tokens after the class token are mean-pooled
    and normalized by a dedicated ``fc_norm``, and the backbone's final ``norm``
    is deleted. Otherwise the full token sequence, normalized by ``norm``, is
    returned.

    Args:
        global_pool: Enable mean pooling over the non-class tokens.
        **kwargs: Forwarded unchanged to timm's ``VisionTransformer``.
    """

    def __init__(self, global_pool=False, **kwargs):
        super().__init__(**kwargs)
        self.global_pool = global_pool
        if self.global_pool:
            # Do not require the caller to pass these explicitly (indexing
            # kwargs directly raised KeyError). Default to the
            # LayerNorm(eps=1e-6) that every factory in this module passes.
            norm_layer = kwargs.get("norm_layer") or partial(nn.LayerNorm, eps=1e-6)
            embed_dim = kwargs.get("embed_dim")
            if embed_dim is None:
                # Use the width timm actually built the model with.
                embed_dim = self.embed_dim
            self.fc_norm = norm_layer(embed_dim)
            del self.norm  # unused in pooling mode; fc_norm replaces it

    def forward_features(self, x):
        """Embed ``x``, add positional embeddings and run the transformer blocks.

        Returns:
            With ``global_pool=True``: ``fc_norm`` applied to the mean over all
            tokens except the class token (which is prepended at index 0).
            Otherwise: the whole token sequence (class token first) after
            ``norm``, not just the class token. The original code kept this
            "for fas code".
        """
        B = x.shape[0]
        x = self.patch_embed(x)
        # Prepend the learnable class token to every sample in the batch.
        # (cls_tokens impl credited to Phil Wang.)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)
        for blk in self.blocks:
            x = blk(x)

        if self.global_pool:
            # Global average pool over the patch tokens, skipping index 0 (class token).
            return self.fc_norm(x[:, 1:, :].mean(dim=1))
        return self.norm(x)

    def forward(self, x):
        """Return ``self.head`` applied to ``forward_features(x)``.

        NOTE(review): with ``global_pool=False``, ``forward_features`` returns
        every token, so the head is applied per token rather than to the class
        token only. Confirm that callers expect this.
        """
        return self.head(self.forward_features(x))
def _conv_filter(state_dict, patch_size=16):
"""convert patch embedding weight from manual patchify + linear proj to conv"""
out_dict = {}
for k, v in state_dict.items():
if "patch_embed.proj.weight" in k:
v = v.reshape((v.shape[0], 3, patch_size, patch_size))
out_dict[k] = v
return out_dict
def vit_small_patch16(pretrained=False, **kwargs):
    """Build a ViT-Small/16 using the MoCo v3 configuration.

    Architecture: 16x16 patches, width 384, 12 blocks, 12 heads, MLP ratio 4,
    biased QKV projections and LayerNorm(eps=1e-6). timm's own "ViT-small"
    config differs (width 768, depth 8, 8 heads, MLP ratio 3). That is
    presumably why loading timm's ``vit_small_patch16_224`` weights is disabled
    here, even though its ``default_cfg`` is still attached.

    Args:
        pretrained: Accepted for signature compatibility only; no weights are
            loaded. When True, a ``UserWarning`` is emitted so callers are not
            silently handed a non-pretrained model.
        **kwargs: Forwarded to ``VisionTransformer``.

    Returns:
        The constructed model with ``default_cfg`` attached.
    """
    model = VisionTransformer(
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
    model.default_cfg = default_cfgs["vit_small_patch16_224"]
    if pretrained:
        import warnings  # local import: only needed on this rarely-taken path

        warnings.warn(
            "vit_small_patch16: pretrained=True is not supported; "
            "no pretrained weights were loaded.",
            stacklevel=2,
        )
    return model
def vit_base_patch16(pretrained=False, **kwargs):
    """Create a ViT-Base/16 backbone, optionally loading pretrained weights.

    Architecture: 16x16 patches, width 768, 12 blocks, 12 heads, MLP ratio 4,
    biased QKV projections and LayerNorm(eps=1e-6).

    Args:
        pretrained: When True, load weights with timm's ``load_pretrained``,
            passing the patch-embedding weight through ``_conv_filter``.
        **kwargs: Extra ``VisionTransformer`` arguments. ``in_chans``
            (default 3) is also forwarded to the weight loader.

    Returns:
        The constructed model with ``default_cfg`` attached.
    """
    architecture = dict(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    net = VisionTransformer(**architecture, **kwargs)
    net.default_cfg = default_cfgs["vit_base_patch16_224"]
    if not pretrained:
        return net

    # For timm 0.6.12 the call differs: resolve the config first with
    # resolve_pretrained_cfg("vit_base_patch16_224", ...) and pass it to
    # load_pretrained as ``pretrained_cfg=``.
    load_pretrained(
        net,
        num_classes=net.num_classes,
        in_chans=kwargs.get("in_chans", 3),
        filter_fn=_conv_filter,
    )
    return net
def vit_large_patch16(pretrained=False, **kwargs):
    """Create a ViT-Large/16 backbone, optionally loading pretrained weights.

    Architecture: 16x16 patches, width 1024, 24 blocks, 16 heads, MLP ratio 4,
    biased QKV projections and LayerNorm(eps=1e-6).

    Args:
        pretrained: When True, load weights with timm's ``load_pretrained``
            (no ``filter_fn`` is applied for this size).
        **kwargs: Extra ``VisionTransformer`` arguments. ``in_chans``
            (default 3) is also forwarded to the weight loader.

    Returns:
        The constructed model with ``default_cfg`` attached.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    model = VisionTransformer(
        norm_layer=layer_norm,
        qkv_bias=True,
        mlp_ratio=4,
        num_heads=16,
        depth=24,
        embed_dim=1024,
        patch_size=16,
        **kwargs,
    )
    model.default_cfg = default_cfgs["vit_large_patch16_224"]

    if pretrained:
        # For timm 0.6.12, pass pretrained_cfg=resolve_pretrained_cfg(
        # "vit_large_patch16_224", ...) to load_pretrained instead.
        channels = kwargs.get("in_chans", 3)
        load_pretrained(model, num_classes=model.num_classes, in_chans=channels)
    return model
def vit_huge_patch14(pretrained=False, **kwargs):
    """Build a ViT-Huge/14 backbone.

    Architecture: 14x14 patches, width 1280, 32 blocks, 16 heads, MLP ratio 4,
    biased QKV projections and LayerNorm(eps=1e-6). No ``default_cfg`` is
    attached.

    Args:
        pretrained: Accepted for signature compatibility only; no weights are
            loaded. When True, a ``UserWarning`` is emitted so callers are not
            silently handed a non-pretrained model.
        **kwargs: Forwarded to ``VisionTransformer``.

    Returns:
        The constructed model.
    """
    model = VisionTransformer(
        patch_size=14,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
    if pretrained:
        import warnings  # local import: only needed on this rarely-taken path

        warnings.warn(
            "vit_huge_patch14: pretrained=True is not supported; "
            "no pretrained weights were loaded.",
            stacklevel=2,
        )
    return model