# This source code is written based on https://github.com/facebookresearch/MCC
# The original code base is licensed under the license found in the LICENSE file in the root directory.
import torch
import torch.nn as nn
import torchvision
from functools import partial
from timm.models.vision_transformer import Block
from utils.pos_embed import get_2d_sincos_pos_embed
from utils.layers import Bottleneck_Conv
class CoordEmb(nn.Module):
    """
    Encode the seen coordinate map into a lower-resolution feature map.

    The coordinate map is divided into non-overlapping windows of
    ``win_size x win_size`` pixels. Each window is separately summarized
    into a single CLS token by a window-wise self-attention block with a
    fixed 2D sin-cos positional encoding local to the window.
    """

    def __init__(self, embed_dim, win_size=8, num_heads=8):
        super().__init__()
        self.embed_dim = embed_dim
        self.win_size = win_size
        # Fixed (non-trainable) 2D sin-cos positional embedding:
        # one slot per pixel in a window, plus one for the CLS token.
        self.two_d_pos_embed = nn.Parameter(
            torch.zeros(1, self.win_size * self.win_size + 1, embed_dim),
            requires_grad=False)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Projects each 3D coordinate to the embedding dimension.
        self.pos_embed = nn.Linear(3, embed_dim)
        self.blocks = nn.ModuleList([
            # Each block is a residual block:
            # layernorm -> attention -> layernorm -> mlp.
            Block(embed_dim, num_heads=num_heads, mlp_ratio=2.0, qkv_bias=True,
                  norm_layer=partial(nn.LayerNorm, eps=1e-6))
            for _ in range(1)
        ])
        # Learned token substituted at positions with no valid coordinate.
        self.invalid_coord_token = nn.Parameter(torch.zeros(embed_dim,))
        self.initialize_weights()

    def initialize_weights(self):
        torch.nn.init.normal_(self.cls_token, std=.02)
        two_d_pos_embed = get_2d_sincos_pos_embed(
            self.two_d_pos_embed.shape[-1], self.win_size, cls_token=True)
        self.two_d_pos_embed.data.copy_(
            torch.from_numpy(two_d_pos_embed).float().unsqueeze(0))
        torch.nn.init.normal_(self.invalid_coord_token, std=.02)

    def forward(self, coord_obj, mask_obj):
        """
        Args:
            coord_obj: [B, H, W, 3] seen coordinate map.
            mask_obj: [B, H, W] boolean validity mask (True = valid coord);
                inferred from the masked indexing below — confirm with caller.

        Returns:
            [B, H/ws * W/ws, C] tensor: one CLS token per window.
        """
        # [B, H, W, C]
        emb = self.pos_embed(coord_obj)
        # Replace embeddings of invalid coordinates with the learned
        # invalid-coordinate token. (Single masked assignment; equivalent
        # to zeroing the positions and then adding the token.)
        emb[~mask_obj] = self.invalid_coord_token
        B, H, W, C = emb.shape
        # [B, H/ws, ws, W/ws, ws, C]
        emb = emb.view(B, H // self.win_size, self.win_size,
                       W // self.win_size, self.win_size, C)
        # [B * H/ws * W/ws, ws*ws, C]: one row per window
        emb = emb.permute(0, 1, 3, 2, 4, 5).contiguous().view(
            -1, self.win_size * self.win_size, C)
        # Add the positional encoding local to each window (skip CLS slot).
        emb = emb + self.two_d_pos_embed[:, 1:, :]
        # [1, 1, C]: CLS token gets its own positional slot.
        cls_token = self.cls_token + self.two_d_pos_embed[:, :1, :]
        # [B * H/ws * W/ws, 1, C]
        cls_tokens = cls_token.expand(emb.shape[0], -1, -1)
        # [B * H/ws * W/ws, ws*ws + 1, C]
        emb = torch.cat((cls_tokens, emb), dim=1)
        # Transformer (single block); attention is confined to each window
        # because windows are laid out along the batch dimension.
        for blk in self.blocks:
            emb = blk(emb)
        # Return the CLS token of each window: [B, H/ws*W/ws, C].
        return emb[:, 0].view(B, (H // self.win_size) * (W // self.win_size), -1)
class CoordEncAtt(nn.Module):
    """
    Transformer-based encoder for the seen surface.

    Window-level tokens produced by ``CoordEmb`` are prepended with a
    learnable CLS token and processed by a stack of transformer blocks,
    followed by a final layer norm.
    """

    def __init__(self,
                 embed_dim=768, n_blocks=12, num_heads=12, win_size=8,
                 mlp_ratio=4., norm_layer=partial(nn.LayerNorm, eps=1e-6), drop_path=0.1):
        super().__init__()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.coord_embed = CoordEmb(embed_dim, win_size, num_heads)
        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True,
                  norm_layer=norm_layer, drop_path=drop_path)
            for _ in range(n_blocks)
        ])
        self.norm = norm_layer(embed_dim)
        self.initialize_weights()

    def initialize_weights(self):
        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02)
        # since the cutoff (2.) is far out in the tail.
        torch.nn.init.normal_(self.cls_token, std=.02)
        # Initialize every nn.Linear and nn.LayerNorm submodule.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # Follow the official JAX ViT: xavier_uniform for linear weights.
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0)

    def forward(self, coord_obj, mask_obj):
        """Encode the seen coordinate map; returns [B, 1 + H/ws*W/ws, C]."""
        # [B, H/ws*W/ws, C] window tokens from the coordinate embedder.
        tokens = self.coord_embed(coord_obj, mask_obj)
        # Prepend the CLS token: [B, 1 + H/ws*W/ws, C].
        cls = self.cls_token.expand(tokens.shape[0], -1, -1)
        tokens = torch.cat((cls, tokens), dim=1)
        # Apply the transformer blocks, then the final norm.
        for block in self.blocks:
            tokens = block(tokens)
        return self.norm(tokens)
class CoordEncRes(nn.Module):
    """
    ResNet-50-based encoder for the seen surface.

    Produces one global feature through the (replaced) classification head
    and a grid of local features captured via a forward hook on an
    intermediate layer; which layer is hooked depends on ``opt.arch.win_size``.
    """

    def __init__(self, opt):
        super().__init__()
        self.encoder = torchvision.models.resnet50(pretrained=True)
        # Replace the classification head with a projection to the latent dim.
        self.encoder.fc = nn.Sequential(
            Bottleneck_Conv(2048),
            Bottleneck_Conv(2048),
            nn.Linear(2048, opt.arch.latent_dim)
        )
        # The hooked intermediate activation is stashed here during forward().
        self.seen_feature = None

        def feature_hook(module, inputs, output):
            self.seen_feature = output

        # Attach the hook to the layer whose downsampling matches win_size.
        assert opt.arch.depth.dsp == 1
        if opt.arch.win_size == 16:
            # layer3 output has 1024 channels.
            self.encoder.layer3.register_forward_hook(feature_hook)
            self.depth_feat_proj = nn.Sequential(
                Bottleneck_Conv(1024),
                Bottleneck_Conv(1024),
                nn.Conv2d(1024, opt.arch.latent_dim, 1)
            )
        elif opt.arch.win_size == 32:
            # layer4 output has 2048 channels.
            self.encoder.layer4.register_forward_hook(feature_hook)
            self.depth_feat_proj = nn.Sequential(
                Bottleneck_Conv(2048),
                Bottleneck_Conv(2048),
                nn.Conv2d(2048, opt.arch.latent_dim, 1)
            )
        else:
            # Carry the explanation in the exception rather than printing it.
            raise NotImplementedError(
                'Make sure win_size is 16 or 32 when using resnet backbone!')

    def forward(self, coord_obj, mask_obj):
        """
        Args:
            coord_obj: [B, 3, H, W] seen coordinate map (channels-first,
                as expected by the ResNet backbone).
            mask_obj: 4-D validity mask multiplied into coord_obj
                (presumably [B, 1, H, W] — confirm with caller).

        Returns:
            [B, 1 + H/ws*W/ws, C]: global token followed by local tokens.
        """
        batch_size = coord_obj.shape[0]
        assert len(coord_obj.shape) == len(mask_obj.shape) == 4
        # Zero out invalid coordinates before feeding the CNN.
        coord_obj = coord_obj * mask_obj.float()
        # [B, 1, C] global feature; running the encoder also fills
        # self.seen_feature via the forward hook.
        global_feat = self.encoder(coord_obj).unsqueeze(1)
        # [B, C, H/ws*W/ws] local features projected from the hooked activation.
        local_feat = self.depth_feat_proj(self.seen_feature).view(
            batch_size, global_feat.shape[-1], -1)
        # [B, H/ws*W/ws, C]
        local_feat = local_feat.permute(0, 2, 1).contiguous()
        # [B, 1 + H/ws*W/ws, C]
        return torch.cat([global_feat, local_feat], dim=1)