"""Vision-Transformer based encoders and the factory that builds them.

All encoders return a ``(features, pad_info, size)`` triple, where
``pad_info`` and ``size`` are forwarded unchanged from the patch-embedding
module.
"""
import torch.nn as nn
import torch
from torch.nn import functional as F

from .vit.utils import trunc_normal_
from .vit.vision_transformer import VisionTransformer
from ..feature_extractor.clova_impl import ResNet_FeatureExtractor
from .addon_module import *
from ..common.mae_posembed import get_2d_sincos_pos_embed

__all__ = ['ViTEncoder', 'ViTEncoderV2', 'ViTEncoderV3', 'TRIGBaseEncoder', 'create_vit_modeling']


def _forward_with_sliced_pos_embed(encoder, x):
    """Encode ``x`` using the leading slice of ``encoder.pos_embed``.

    Shared implementation of ``ViTEncoderV2.forward`` and
    ``ViTEncoderV3.forward``: instead of interpolating, the positional table
    is truncated to ``num_patches + 1`` entries (class token included).

    Args:
        encoder: a ``ViTEncoder`` instance providing ``patch_embed``,
            ``cls_token``, ``pos_embed``, ``pos_drop``, ``blocks`` and ``norm``.
        x: 4-D input batch.

    Returns:
        Tuple ``(features, pad_info, size)``.
    """
    B, _, _, _ = x.shape
    x, pad_info, size, _ = encoder.patch_embed(x)
    _, numpatches, *_ = x.shape

    cls_tokens = encoder.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
    x = torch.cat((cls_tokens, x), dim=1)
    x = x + encoder.pos_embed[:, :(numpatches + 1)]
    x = encoder.pos_drop(x)

    for blk in encoder.blocks:
        x = blk(x)

    x = encoder.norm(x)
    return x, pad_info, size


class ViTEncoder(VisionTransformer):
    """ViT encoder with a 2-D patch grid and an interpolatable positional table.

    Keyword arguments ``img_size``, ``in_chans``, ``patch_size`` and
    ``embed_dim`` are required; ``hybrid_backbone`` is optional and selects
    ``HybridEmbed`` over ``PatchEmbed`` when it is not ``None``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        embed_kwargs = dict(
            img_size=kwargs['img_size'],
            in_chans=kwargs['in_chans'],
            patch_size=kwargs['patch_size'],
            embed_dim=kwargs['embed_dim'],
        )
        # .get(): a caller that omits ``hybrid_backbone`` means "no backbone".
        backbone = kwargs.get('hybrid_backbone')
        if backbone is None:
            self.patch_embed = PatchEmbed(**embed_kwargs)
        else:
            self.patch_embed = HybridEmbed(backbone=backbone, **embed_kwargs)

        num_patches = self.patch_embed.num_patches
        # One extra slot for the class token.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, kwargs['embed_dim']))
        self.emb_height = self.patch_embed.grid_size[0]
        self.emb_width = self.patch_embed.grid_size[1]
        trunc_normal_(self.pos_embed, std=.02)
        # Re-run weight init because patch_embed was replaced after super().__init__.
        self.apply(self._init_weights)

    def reset_classifier(self, num_classes):
        """Replace the head with a fresh linear layer (identity if ``num_classes`` <= 0)."""
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def interpolating_pos_embedding(self, embedding, height, width):
        """Resize the positional table to the patch grid of a ``height`` x ``width`` input.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174

        Args:
            embedding: token tensor whose dim 1 holds class token + patches.
            height: input height, divided by ``patch_size[0]`` to get grid rows.
            width: input width, divided by ``patch_size[1]`` to get grid columns.

        Returns:
            Tensor of shape ``(1, 1 + rows * cols, dim)``; ``self.pos_embed``
            itself when the target grid equals the stored grid.
        """
        npatch = embedding.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        h0 = height // self.patch_embed.patch_size[0]
        w0 = width // self.patch_embed.patch_size[1]

        # The stored grid may be non-square, so compare grid shapes rather
        # than ``height == width`` (which is only valid for square layouts and
        # could reuse the table for a different grid with the same patch count).
        if npatch == N and h0 == self.emb_height and w0 == self.emb_width:
            return self.pos_embed

        class_pos_embedding = self.pos_embed[:, 0]
        patch_pos_embedding = self.pos_embed[:, 1:]
        dim = self.pos_embed.shape[-1]

        # add a small number to avoid floating point error
        # https://github.com/facebookresearch/dino/issues/8
        h0 = h0 + 0.1
        w0 = w0 + 0.1

        patch_pos_embedding = nn.functional.interpolate(
            patch_pos_embedding.reshape(1, self.emb_height, self.emb_width, dim).permute(0, 3, 1, 2),
            scale_factor=(h0 / self.emb_height, w0 / self.emb_width),
            mode='bicubic',
            align_corners=False
        )
        assert int(h0) == patch_pos_embedding.shape[-2] and int(w0) == patch_pos_embedding.shape[-1]
        patch_pos_embedding = patch_pos_embedding.permute(0, 2, 3, 1).view(1, -1, dim)
        class_pos_embedding = class_pos_embedding.unsqueeze(0)

        return torch.cat((class_pos_embedding, patch_pos_embedding), dim=1)

    def forward_features(self, x):
        """Embed, add positional information and run the transformer blocks.

        Returns:
            Tuple ``(features, pad_info, size)``; ``pad_info`` and ``size``
            come straight from ``self.patch_embed``.
        """
        B, _, _, _ = x.shape

        x, pad_info, size, interpolating_pos = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)

        if interpolating_pos:
            x = x + self.interpolating_pos_embedding(x, size['height'], size['width'])
        else:
            x = x + self.pos_embed

        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)

        return x, pad_info, size


class TRIGBaseEncoder(ViTEncoder):
    """Encoder using a 1-D (single-row) patch layout via ``HybridEmbed1D``.

    https://arxiv.org/pdf/2111.08314.pdf
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.patch_embed = HybridEmbed1D(
            backbone=kwargs['hybrid_backbone'],
            img_size=kwargs['img_size'],
            in_chans=kwargs['in_chans'],
            patch_size=kwargs['patch_size'],
            embed_dim=kwargs['embed_dim'],
        )
        num_patches = self.patch_embed.num_patches
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, kwargs['embed_dim']))
        # Single row of patches: only the width axis is ever interpolated.
        self.emb_height = 1
        self.emb_width = self.patch_embed.grid_size[1]
        trunc_normal_(self.pos_embed, std=.02)
        self.apply(self._init_weights)

    def interpolating_pos_embedding(self, embedding, height, width):
        """Resize the positional table along the width axis only.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174

        Args:
            embedding: token tensor whose dim 1 holds class token + patches.
            height: accepted for interface parity with ``ViTEncoder``; the
                grid has one row, so it does not affect the result.
            width: input width, divided by ``patch_embed.window_width``.

        Returns:
            Tensor of shape ``(1, 1 + cols, dim)``; ``self.pos_embed`` itself
            when the target width equals the stored width.
        """
        npatch = embedding.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        w0 = width // self.patch_embed.window_width

        # Height plays no role for a 1-row grid, so the reuse test is based on
        # the number of windows instead of ``height == width``.
        if npatch == N and w0 == self.emb_width:
            return self.pos_embed

        class_pos_embedding = self.pos_embed[:, 0]
        patch_pos_embedding = self.pos_embed[:, 1:]
        dim = self.pos_embed.shape[-1]

        # add a small number to avoid floating point error
        # https://github.com/facebookresearch/dino/issues/8
        w0 = w0 + 0.1

        patch_pos_embedding = nn.functional.interpolate(
            patch_pos_embedding.reshape(1, self.emb_height, self.emb_width, dim).permute(0, 3, 1, 2),
            scale_factor=(1, w0 / self.emb_width),
            mode='bicubic',
            align_corners=False
        )

        assert int(w0) == patch_pos_embedding.shape[-1]
        patch_pos_embedding = patch_pos_embedding.permute(0, 2, 3, 1).view(1, -1, dim)
        class_pos_embedding = class_pos_embedding.unsqueeze(0)

        return torch.cat((class_pos_embedding, patch_pos_embedding), dim=1)

    def forward_features(self, x):
        """Embed, add positional information and run the transformer blocks.

        Returns:
            Tuple ``(features, padinfo, size)`` as produced by ``patch_embed``.
        """
        B, _, _, _ = x.shape
        x, padinfo, size, interpolating_pos = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)  # cls_tokens is init_embedding in TRIG paper

        if interpolating_pos:
            x = x + self.interpolating_pos_embedding(x, size['height'], size['width'])
        else:
            x = x + self.pos_embed

        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)

        return x, padinfo, size


class ViTEncoderV2(ViTEncoder):
    """``ViTEncoder`` variant that slices the positional table instead of interpolating it."""

    def forward(self, x):
        """Return ``(features, pad_info, size)`` using a truncated ``pos_embed``."""
        return _forward_with_sliced_pos_embed(self, x)


class ViTEncoderV3(ViTEncoder):
    """``ViTEncoder`` variant with a frozen 2-D sin-cos positional table."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Drop the learnable table created by the parent before registering
        # the non-trainable replacement under the same name.
        if hasattr(self, 'pos_embed'):
            del self.pos_embed
        num_patches = self.patch_embed.num_patches
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, kwargs['embed_dim']),
            requires_grad=False,
        )
        self.initialize_posembed()

    def initialize_posembed(self):
        """Fill ``pos_embed`` with values from ``get_2d_sincos_pos_embed`` (class token included)."""
        pos_embed = get_2d_sincos_pos_embed(
            self.pos_embed.shape[-1],
            self.patch_embed.grid_size[0],
            self.patch_embed.grid_size[1],
            cls_token=True
        )
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

    def forward(self, x):
        """Return ``(features, pad_info, size)`` using a truncated ``pos_embed``."""
        return _forward_with_sliced_pos_embed(self, x)


def create_vit_modeling(opt):
    """Build the encoder described by ``opt['SequenceModeling']['params']``.

    Args:
        opt: configuration mapping. Reads ``imgH``, ``max_dimension`` and the
            ``SequenceModeling.params`` entries ``backbone``,
            ``patching_style``, ``fix_embed``, ``interpolate_embed``,
            ``patch_size``, ``input_channel``, ``depth``, ``hidden_size`` and
            ``num_heads``.

    Returns:
        An instance of ``ViTEncoder``, ``ViTEncoderV2``, ``ViTEncoderV3`` or
        ``TRIGBaseEncoder``.

    Raises:
        ValueError: if a backbone is configured with a name other than
            ``'resnet'`` or ``'cnn'``.
    """
    seq_modeling = opt['SequenceModeling']['params']
    backbone_cfg = seq_modeling['backbone']

    backbone = None
    if backbone_cfg is not None:
        backbone_name = backbone_cfg['name']
        if backbone_name == 'resnet':
            param_kwargs = dict()
            if backbone_cfg.get('pretrained', None) is not None:
                param_kwargs['pretrained'] = backbone_cfg['pretrained']
            if backbone_cfg.get('weight_dir', None) is not None:
                param_kwargs['weight_dir'] = backbone_cfg['weight_dir']
            print('kwargs', param_kwargs)

            backbone = ResNet_FeatureExtractor(
                backbone_cfg['input_channel'],
                backbone_cfg['output_channel'],
                backbone_cfg['gcb'],
                **param_kwargs
            )
        elif backbone_name != 'cnn':
            # Previously this fell through and crashed later with an
            # UnboundLocalError on ``backbone``.
            raise ValueError(
                "Unsupported backbone name {!r}; expected 'resnet' or 'cnn'".format(backbone_name)
            )

    # A configured imgH overrides the height component of max_dimension.
    max_dimension = (opt['imgH'], opt['max_dimension'][1]) if opt['imgH'] else opt['max_dimension']

    if seq_modeling['patching_style'] == '2d':
        if seq_modeling.get('fix_embed', False):
            encoder = ViTEncoderV3
        elif not seq_modeling.get('interpolate_embed', True):
            encoder = ViTEncoderV2
        else:
            encoder = ViTEncoder
    else:
        encoder = TRIGBaseEncoder

    encoder_seq_modeling = encoder(
        img_size=max_dimension,
        patch_size=seq_modeling['patch_size'],
        in_chans=seq_modeling['input_channel'],
        depth=seq_modeling['depth'],
        num_classes=0,
        embed_dim=seq_modeling['hidden_size'],
        num_heads=seq_modeling['num_heads'],
        hybrid_backbone=backbone
    )

    return encoder_seq_modeling