File size: 2,205 Bytes
2659b26 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 | import torch
import torch.nn as nn
import timm
from timm.models.vision_transformer import PatchEmbed
from functools import partial
class vit(timm.models.vision_transformer.VisionTransformer):
    """timm VisionTransformer configured for linear probing.

    Every backbone parameter is frozen; only the freshly created linear
    ``head`` remains trainable.

    Args:
        global_pool: If True, features are the mean of the patch tokens
            (class token excluded), normalised by a new ``fc_norm`` layer, and
            the backbone's final ``norm`` is removed. If False, features are
            the ``norm``-normalised class token.
        **kwargs: Forwarded unchanged to ``VisionTransformer.__init__``
            (e.g. ``embed_dim``, ``depth``, ``num_heads``, ``num_classes``,
            ``norm_layer``).
    """

    def __init__(self, global_pool=False, **kwargs):
        """Build the backbone from ``kwargs`` and prepare it for linear probing."""
        # Forward the architecture kwargs. Without them the parent builds its
        # default backbone regardless of the requested embed_dim/depth/heads,
        # and a head sized from kwargs['embed_dim'] would not match the
        # features the backbone actually produces.
        super().__init__(**kwargs)
        self.global_pool = global_pool
        # Fresh classifier sized from the values the parent stored, so omitting
        # embed_dim/num_classes falls back to the parent's defaults instead of
        # raising KeyError.
        self.head = nn.Linear(self.embed_dim, self.num_classes, bias=True)
        if self.global_pool:
            # Use the caller's norm_layer if given, otherwise the same
            # LayerNorm(eps=1e-6) that timm's VisionTransformer defaults to.
            norm_layer = kwargs.get('norm_layer') or partial(nn.LayerNorm, eps=1e-6)
            self.fc_norm = norm_layer(self.embed_dim)
            # The pooled features are normalised by fc_norm instead.
            del self.norm
        # Linear probing: freeze every parameter (fc_norm included), then
        # re-enable gradients for the classifier head only.
        for param in self.parameters():
            param.requires_grad = False
        for param in self.head.parameters():
            param.requires_grad = True

    def forward_features(self, x):
        """Run the backbone and return one pooled feature vector per sample.

        Args:
            x: Input batch for ``patch_embed``; presumably images shaped
                (B, C, H, W) -- TODO confirm against the data pipeline.

        Returns:
            A (B, D) tensor, D being the token width. It holds the
            ``fc_norm``-normalised mean of the patch tokens when
            ``global_pool`` is set, otherwise the ``norm``-normalised class
            token.
        """
        B = x.shape[0]
        x = self.patch_embed(x)
        # Prepend the learnable class token to the patch-token sequence.
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)
        for blk in self.blocks:
            x = blk(x)
        if self.global_pool:
            # Average the patch tokens only; index 0 is the class token.
            x = x[:, 1:, :].mean(dim=1)
            outcome = self.fc_norm(x)
        else:
            x = self.norm(x)
            outcome = x[:, 0]
        return outcome

    def forward(self, x):
        """Return classifier logits: ``head(forward_features(x))``."""
        x = self.forward_features(x)
        x = self.head(x)
        return x
def vit_base_patch16(**kwargs):
    """Build a ViT-Base/16 probing model (768-dim, 12 blocks, 12 heads).

    ``img_size`` defaults to 224 but may be overridden through ``kwargs``.
    All other keyword arguments are forwarded to :class:`vit`.
    """
    # setdefault instead of a literal img_size=224 keyword: the literal raised
    # TypeError ("multiple values for keyword argument") whenever the caller
    # also supplied img_size. kwargs is a fresh dict, so the caller's mapping
    # is not mutated.
    kwargs.setdefault('img_size', 224)
    model = vit(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
                norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model
def vit_large_patch16(**kwargs):
    """Build a ViT-Large/16 probing model: 1024-dim tokens, 24 blocks, 16 heads.

    Extra keyword arguments are forwarded to :class:`vit`. Repeating any of
    the architecture arguments fixed here raises TypeError.
    """
    arch = dict(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return vit(**arch, **kwargs)
def vit_huge_patch14(**kwargs):
    """Build a ViT-Huge/14 probing model: 1280-dim tokens, 32 blocks, 16 heads.

    Extra keyword arguments are forwarded to :class:`vit`. Repeating any of
    the architecture arguments fixed here raises TypeError.
    """
    arch = dict(
        patch_size=14,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return vit(**arch, **kwargs)
|