Spaces:
Build error
Build error
| # This source code is written based on https://github.com/facebookresearch/MCC | |
| # The original code base is licensed under the license found in the LICENSE file in the root directory. | |
| import torch | |
| import torch.nn as nn | |
| import torchvision | |
| from functools import partial | |
| from timm.models.vision_transformer import Block, PatchEmbed | |
| from utils.pos_embed import get_2d_sincos_pos_embed | |
| from utils.layers import Bottleneck_Conv | |
class RGBEncAtt(nn.Module):
    """
    Seen surface encoder based on transformer.

    The input image is split into patches by a ``PatchEmbed`` layer, a fixed
    (non-trainable) 2D sin-cos positional embedding is added, a learnable
    class token is prepended, and the resulting sequence is passed through
    ``n_blocks`` transformer ``Block`` layers followed by a final norm layer.
    """

    def __init__(self,
                 img_size=224, embed_dim=768, n_blocks=12, num_heads=12, win_size=16,
                 mlp_ratio=4., norm_layer=partial(nn.LayerNorm, eps=1e-6), drop_path=0.1):
        super().__init__()

        # learnable class token, one vector of size embed_dim
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # patch embedding of the RGB input (patch size = win_size)
        self.rgb_embed = PatchEmbed(img_size, win_size, 3, embed_dim)

        # frozen positional embedding: one slot per patch plus one for the class token
        n_tokens = self.rgb_embed.num_patches + 1
        self.pos_embed = nn.Parameter(torch.zeros(1, n_tokens, embed_dim), requires_grad=False)

        # stack of transformer blocks, all built with the same hyper-parameters
        layers = []
        for _ in range(n_blocks):
            layers.append(
                Block(
                    embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
                    drop_path=drop_path
                )
            )
        self.blocks = nn.ModuleList(layers)

        self.norm = norm_layer(embed_dim)

        self.initialize_weights()

    def initialize_weights(self):
        """Fill the positional embedding and (re-)initialise all trainable weights."""
        # fixed cos-sin pattern for the positional encoding (class-token slot included)
        embed_dim = self.pos_embed.shape[-1]
        grid_size = int(self.rgb_embed.num_patches**.5)
        sincos = get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=True)
        sincos = torch.from_numpy(sincos).float().unsqueeze(0)
        self.pos_embed.data.copy_(sincos)

        # treat the patch projection like an nn.Linear (instead of nn.Conv2d):
        # xavier init on the weight viewed as [out_channels, everything else]
        proj_weight = self.rgb_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(proj_weight.view(proj_weight.shape[0], -1))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=.02)

        # finally visit every submodule to initialise nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-module initialiser used with ``nn.Module.apply``."""
        if isinstance(m, nn.LayerNorm):
            # identity-like normalisation: zero shift, unit scale
            nn.init.zeros_(m.bias)
            nn.init.ones_(m.weight)
            return

        if not isinstance(m, nn.Linear):
            # every other module type keeps whatever initialisation it already has
            return

        # we use xavier_uniform following official JAX ViT
        torch.nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

    def forward(self, rgb_obj):
        """Encode ``rgb_obj`` into a token sequence with the class token at index 0."""
        # [B, H/ws*W/ws, C] patch tokens with their positional embedding added
        patch_tokens = self.rgb_embed(rgb_obj) + self.pos_embed[:, 1:, :]

        # [1, 1, C] -> [B, 1, C]: class token (with its own positional slot), one copy per sample
        cls_with_pos = self.cls_token + self.pos_embed[:, :1, :]
        cls_batch = cls_with_pos.expand(patch_tokens.shape[0], -1, -1)

        # [B, H/ws*W/ws+1, C]
        tokens = torch.cat((cls_batch, patch_tokens), dim=1)

        # apply Transformer blocks
        for layer in self.blocks:
            tokens = layer(tokens)

        # [B, H/ws*W/ws+1, C]
        return self.norm(tokens)
class RGBEncRes(nn.Module):
    """
    RGB encoder based on resnet.

    A torchvision ResNet-50 with ImageNet weights produces a global feature
    through its replaced ``fc`` head, while a forward hook on an intermediate
    stage captures a feature map that is projected to ``opt.arch.latent_dim``
    channels and used as local features.

    Args:
        opt: options object; ``opt.arch.win_size`` (16 or 32) selects the
            hooked stage and ``opt.arch.latent_dim`` is the output feature size.

    Raises:
        NotImplementedError: if ``opt.arch.win_size`` is neither 16 nor 32.
    """

    # win_size -> (name of the ResNet stage to hook, channel count fed to the projection)
    _HOOK_STAGES = {16: ('layer3', 1024), 32: ('layer4', 2048)}

    def __init__(self, opt):
        super().__init__()

        # validate the configuration before building the pretrained backbone so
        # that a bad win_size fails immediately, with the reason in the exception
        stage = self._HOOK_STAGES.get(opt.arch.win_size)
        if stage is None:
            raise NotImplementedError('Make sure win_size is 16 or 32 when using resnet backbone!')
        stage_name, stage_channels = stage

        # `pretrained=True` is deprecated in newer torchvision releases in favour of
        # `weights=`; IMAGENET1K_V1 is the checkpoint `pretrained=True` refers to.
        # Fall back to the legacy argument when the weights enum does not exist.
        weights_enum = getattr(torchvision.models, 'ResNet50_Weights', None)
        if weights_enum is None:
            self.encoder = torchvision.models.resnet50(pretrained=True)
        else:
            self.encoder = torchvision.models.resnet50(weights=weights_enum.IMAGENET1K_V1)

        # replace the classification head so the backbone outputs latent_dim features
        self.encoder.fc = nn.Sequential(
            Bottleneck_Conv(2048),
            Bottleneck_Conv(2048),
            nn.Linear(2048, opt.arch.latent_dim)
        )

        # define hooks: the output of the hooked stage is stashed on the module
        # during encoder.forward and consumed right afterwards in forward()
        self.rgb_feature = None

        def feature_hook(model, input, output):
            self.rgb_feature = output

        # attach hooks
        getattr(self.encoder, stage_name).register_forward_hook(feature_hook)

        # project the hooked feature map to latent_dim channels
        self.rgb_feat_proj = nn.Sequential(
            Bottleneck_Conv(stage_channels),
            Bottleneck_Conv(stage_channels),
            nn.Conv2d(stage_channels, opt.arch.latent_dim, 1)
        )

    def forward(self, rgb_obj):
        """
        Encode a batch of images into one global token followed by local tokens.

        Args:
            rgb_obj: 4-dimensional image batch whose first dimension is the batch size.

        Returns:
            Tensor of shape [B, 1+H/ws*W/ws, C]; index 0 along dim 1 is the global feature.

        Raises:
            ValueError: if ``rgb_obj`` is not 4-dimensional.
        """
        # explicit check instead of `assert`, which is stripped under `python -O`
        if len(rgb_obj.shape) != 4:
            raise ValueError(
                'rgb_obj must be a 4-dimensional tensor, got shape {}'.format(tuple(rgb_obj.shape))
            )
        batch_size = rgb_obj.shape[0]

        # [B, 1, C]; running the encoder also triggers the hook that fills self.rgb_feature
        global_feat = self.encoder(rgb_obj).unsqueeze(1)

        # [B, C, H/ws*W/ws]
        local_feat = self.rgb_feat_proj(self.rgb_feature).view(batch_size, global_feat.shape[-1], -1)

        # [B, H/ws*W/ws, C]
        local_feat = local_feat.permute(0, 2, 1).contiguous()

        # [B, 1+H/ws*W/ws, C]
        rgb_embedding = torch.cat([global_feat, local_feat], dim=1)
        return rgb_embedding