# Source imported from duycse1603's repository (commit 6163604, "[Add] source").
import torch.nn as nn
import torch
from torch.nn import functional as F
from .vit.utils import trunc_normal_
from .vit.vision_transformer import VisionTransformer
from ..feature_extractor.clova_impl import ResNet_FeatureExtractor
from .addon_module import *
from ..common.mae_posembed import get_2d_sincos_pos_embed
# Public API of this module: the ViT encoder variants and their factory.
__all__ = ['ViTEncoder', 'ViTEncoderV2', 'ViTEncoderV3', 'TRIGBaseEncoder', 'create_vit_modeling']
class ViTEncoder(VisionTransformer):
    '''
    Vision Transformer encoder whose patch embedding is rebuilt after
    ``super().__init__`` to use the project-specific modules:

      * ``PatchEmbed``  -- raw image patches, when ``hybrid_backbone`` is None;
      * ``HybridEmbed`` -- features produced by a CNN backbone, otherwise.

    The learned positional embedding holds ``num_patches + 1`` rows (the extra
    row is for the class token) and can be bicubically interpolated at forward
    time for inputs whose patch grid differs from the construction-time grid.

    NOTE(review): ``img_size``, ``in_chans``, ``patch_size``, ``embed_dim``
    and ``hybrid_backbone`` are read directly from ``kwargs``, so they must
    be passed as keyword arguments.
    '''
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Replace whatever patch embedding the parent built with the
        # project-specific one (PatchEmbed / HybridEmbed from addon_module).
        if kwargs['hybrid_backbone'] is None:
            self.patch_embed = PatchEmbed(
                img_size=kwargs['img_size'],
                in_chans=kwargs['in_chans'],
                patch_size=kwargs['patch_size'],
                embed_dim=kwargs['embed_dim'],
            )
        else:
            self.patch_embed = HybridEmbed(
                backbone=kwargs['hybrid_backbone'],
                img_size=kwargs['img_size'],
                in_chans=kwargs['in_chans'],
                patch_size=kwargs['patch_size'],
                embed_dim=kwargs['embed_dim'],
            )
        num_patches = self.patch_embed.num_patches
        # +1 row for the class token prepended in forward_features.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches+1, kwargs['embed_dim']))
        # Patch-grid shape remembered for reshaping pos_embed when interpolating.
        self.emb_height = self.patch_embed.grid_size[0]
        self.emb_width = self.patch_embed.grid_size[1]
        trunc_normal_(self.pos_embed, std=.02)
        # Re-run weight init so the freshly rebuilt patch embedding is covered.
        self.apply(self._init_weights)
    def reset_classifier(self, num_classes):
        # Swap the classification head; num_classes == 0 removes it entirely.
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
    def interpolating_pos_embedding(self, embedding, height, width):
        """
        Bicubically resize the patch positional embedding to the patch grid
        implied by an input of ``height`` x ``width`` pixels; the class-token
        row is left untouched and re-attached in front.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """
        npatch = embedding.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        # Fast path (same heuristic as upstream DINO): table already matches.
        if npatch == N and height == width:
            return self.pos_embed
        class_pos_embedding = self.pos_embed[:, 0]
        patch_pos_embedding = self.pos_embed[:, 1:]
        dim = self.pos_embed.shape[-1]
        # Target grid size measured in patches.
        h0 = height // self.patch_embed.patch_size[0]
        w0 = width // self.patch_embed.patch_size[1]
        # add a small number to avoid floating point error
        # https://github.com/facebookresearch/dino/issues/8
        h0 = h0 + 0.1
        w0 = w0 + 0.1
        # (1, N, dim) -> (1, dim, H, W) so interpolate acts on the spatial axes.
        patch_pos_embedding = nn.functional.interpolate(
            patch_pos_embedding.reshape(1, self.emb_height, self.emb_width, dim).permute(0, 3, 1, 2),
            scale_factor=(h0 / self.emb_height, w0 / self.emb_width),
            mode='bicubic',
            align_corners=False
        )
        assert int(h0) == patch_pos_embedding.shape[-2] and int(w0) == patch_pos_embedding.shape[-1]
        # Back to (1, H*W, dim) token layout.
        patch_pos_embedding = patch_pos_embedding.permute(0, 2, 3, 1).view(1, -1, dim)
        class_pos_embedding = class_pos_embedding.unsqueeze(0)
        return torch.cat((class_pos_embedding, patch_pos_embedding), dim=1)
    def forward_features(self, x):
        """
        Tokenize ``x``, add positional embeddings, and run the transformer
        blocks. Returns ``(tokens, pad_info, size)`` — pad_info/size come
        straight from the patch embedding, which returns a 4-tuple whose last
        element flags whether positional interpolation is needed.
        """
        B, C, _, _ = x.shape
        x, pad_info, size, interpolating_pos = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        if interpolating_pos:
            # Input grid differs from the construction-time grid: resize table.
            x = x + self.interpolating_pos_embedding(x, size['height'], size['width'])
        else:
            x = x + self.pos_embed
        x = self.pos_drop(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x, pad_info, size
class TRIGBaseEncoder(ViTEncoder):
    '''
    1-D (width-only) variant of ViTEncoder following TRIG:
    https://arxiv.org/pdf/2111.08314.pdf

    Uses ``HybridEmbed1D`` so the token grid has height 1; positional
    embeddings are therefore only ever interpolated along the width axis.
    '''
    def __init__(self, *args, **kwargs):
        # NOTE(review): the parent __init__ already builds a patch embedding;
        # it is intentionally overwritten here with the 1-D hybrid variant.
        super().__init__(*args, **kwargs)
        self.patch_embed = HybridEmbed1D(
            backbone=kwargs['hybrid_backbone'],
            img_size=kwargs['img_size'],
            in_chans=kwargs['in_chans'],
            patch_size=kwargs['patch_size'],
            embed_dim=kwargs['embed_dim'],
        )
        num_patches = self.patch_embed.num_patches
        # +1 row for the class token ("init_embedding" in the TRIG paper).
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches+1, kwargs['embed_dim']))
        # Single-row token grid: height fixed at 1, width from the embed.
        self.emb_height = 1
        self.emb_width = self.patch_embed.grid_size[1]
        trunc_normal_(self.pos_embed, std=.02)
        self.apply(self._init_weights)
    def interpolating_pos_embedding(self, embedding, height, width):
        """
        Resize the patch positional embedding along the width axis only
        (grid height is fixed at 1); the class-token row is re-attached
        in front unchanged.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """
        npatch = embedding.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        # Fast path inherited from DINO's square-image heuristic.
        if npatch == N and height == width:
            return self.pos_embed
        class_pos_embedding = self.pos_embed[:, 0]
        patch_pos_embedding = self.pos_embed[:, 1:]
        dim = self.pos_embed.shape[-1]
        # Target width in tokens (``window_width`` is a HybridEmbed1D
        # attribute — presumably the horizontal stride; confirm in addon_module).
        w0 = width // self.patch_embed.window_width
        #add a small number to avoid floating point error
        # https://github.com/facebookresearch/dino/issues/8
        w0 = w0 + 0.1
        # Height scale is 1: only the width axis is resampled.
        patch_pos_embedding = nn.functional.interpolate(
            patch_pos_embedding.reshape(1, self.emb_height, self.emb_width, dim).permute(0, 3, 1, 2),
            scale_factor=(1, w0 / self.emb_width),
            mode='bicubic',
            align_corners=False
        )
        assert int(w0) == patch_pos_embedding.shape[-1]
        patch_pos_embedding = patch_pos_embedding.permute(0, 2, 3, 1).view(1, -1, dim)
        class_pos_embedding = class_pos_embedding.unsqueeze(0)
        return torch.cat((class_pos_embedding, patch_pos_embedding), dim=1)
    def forward_features(self, x):
        """
        Tokenize ``x`` with the 1-D hybrid embed, add positional embeddings,
        and run the transformer blocks. Returns ``(tokens, padinfo, size)``.
        """
        B, _, _, _ = x.shape
        x, padinfo, size, interpolating_pos = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)  # cls_tokens is init_embedding in TRIG paper
        if interpolating_pos:
            x = x + self.interpolating_pos_embedding(x, size['height'], size['width'])
        else:
            x = x + self.pos_embed
        x = self.pos_drop(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x, padinfo, size
class ViTEncoderV2(ViTEncoder):
    """ViT encoder variant that truncates (rather than interpolates) the
    learned positional table when the input yields fewer patches than the
    table was sized for."""
    def forward(self, x):
        """Encode a batch of images; returns (tokens, pad_info, size)."""
        batch, _, _, _ = x.shape
        tokens, pad_info, size, _ = self.patch_embed(x)
        n_patches = tokens.shape[1]
        # Prepend one class token per sample (Phil Wang's cls-token trick).
        cls_token = self.cls_token.expand(batch, -1, -1)
        tokens = torch.cat((cls_token, tokens), dim=1)
        # Slice the positional table to exactly the tokens we have.
        tokens = tokens + self.pos_embed[:, :(n_patches + 1)]
        tokens = self.pos_drop(tokens)
        for layer in self.blocks:
            tokens = layer(tokens)
        tokens = self.norm(tokens)
        return tokens, pad_info, size
class ViTEncoderV3(ViTEncoder):
    """ViT encoder using a frozen MAE-style 2-D sin-cos positional table
    instead of the learned one, truncated to the input's token count."""
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Discard the learned positional table installed by the parent.
        if hasattr(self, 'pos_embed'):
            del self.pos_embed
        token_count = self.patch_embed.num_patches + 1  # +1 for cls token
        # Fixed, non-trainable table filled by initialize_posembed below.
        self.pos_embed = nn.Parameter(
            torch.zeros(1, token_count, kwargs['embed_dim']),
            requires_grad=False,
        )
        self.initialize_posembed()
    def initialize_posembed(self):
        """Populate self.pos_embed with 2-D sin-cos values (cls slot included)."""
        grid_h = self.patch_embed.grid_size[0]
        grid_w = self.patch_embed.grid_size[1]
        table = get_2d_sincos_pos_embed(
            self.pos_embed.shape[-1], grid_h, grid_w, cls_token=True
        )
        self.pos_embed.data.copy_(torch.from_numpy(table).float().unsqueeze(0))
    def forward(self, x):
        """Encode a batch of images; returns (tokens, pad_info, size)."""
        batch, _, _, _ = x.shape
        tokens, pad_info, size, _ = self.patch_embed(x)
        n_patches = tokens.shape[1]
        # Prepend one class token per sample (Phil Wang's cls-token trick).
        cls_token = self.cls_token.expand(batch, -1, -1)
        tokens = torch.cat((cls_token, tokens), dim=1)
        # Use only as many positional rows as there are tokens.
        tokens = tokens + self.pos_embed[:, :(n_patches + 1)]
        tokens = self.pos_drop(tokens)
        for layer in self.blocks:
            tokens = layer(tokens)
        tokens = self.norm(tokens)
        return tokens, pad_info, size
def create_vit_modeling(opt):
    """
    Build a ViT-based sequence-modeling encoder from a config dict.

    Args:
        opt: configuration mapping with ``opt['SequenceModeling']['params']``
            (backbone spec, patch size, depth, hidden size, heads, flags) plus
            ``opt['imgH']`` and ``opt['max_dimension']`` for the input size.

    Returns:
        An instance of ViTEncoder / ViTEncoderV2 / ViTEncoderV3 (2-D patching,
        chosen by ``fix_embed`` / ``interpolate_embed``) or TRIGBaseEncoder
        (any other patching style).

    Raises:
        ValueError: if a backbone is configured with an unsupported name.
    """
    seq_modeling = opt['SequenceModeling']['params']
    if seq_modeling['backbone'] is not None:
        backbone_name = seq_modeling['backbone']['name']
        if backbone_name == 'resnet':
            # Optional ResNet kwargs are forwarded only when present.
            param_kwargs = dict()
            if seq_modeling['backbone'].get('pretrained', None) is not None:
                param_kwargs['pretrained'] = seq_modeling['backbone']['pretrained']
            if seq_modeling['backbone'].get('weight_dir', None) is not None:
                param_kwargs['weight_dir'] = seq_modeling['backbone']['weight_dir']
            print('kwargs', param_kwargs)
            backbone = ResNet_FeatureExtractor(
                seq_modeling['backbone']['input_channel'],
                seq_modeling['backbone']['output_channel'],
                seq_modeling['backbone']['gcb'],
                **param_kwargs
            )
        elif backbone_name == 'cnn':
            backbone = None
        else:
            # Bug fix: unknown names previously left `backbone` unbound and
            # crashed later with UnboundLocalError. Fail fast with a clear error.
            raise ValueError(f"Unsupported backbone name: {backbone_name!r}")
    else:
        backbone = None
    # Fix the height to imgH (if given) while keeping the configured max width.
    max_dimension = (opt['imgH'], opt['max_dimension'][1]) if opt['imgH'] else opt['max_dimension']
    if seq_modeling['patching_style'] == '2d':
        if seq_modeling.get('fix_embed', False):
            encoder = ViTEncoderV3  # frozen sin-cos positional embedding
        elif not seq_modeling.get('interpolate_embed', True):
            encoder = ViTEncoderV2  # truncated learned positional embedding
        else:
            encoder = ViTEncoder    # interpolated learned positional embedding
    else:
        encoder = TRIGBaseEncoder   # 1-D (width-only) patching
    encoder_seq_modeling = encoder(
        img_size=max_dimension,
        patch_size=seq_modeling['patch_size'],
        in_chans=seq_modeling['input_channel'],
        depth=seq_modeling['depth'],
        num_classes=0,
        embed_dim=seq_modeling['hidden_size'],
        num_heads=seq_modeling['num_heads'],
        hybrid_backbone=backbone
    )
    return encoder_seq_modeling