Spaces:
Build error
Build error
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from functools import partial | |
| from utils.layers import get_embedder | |
| from utils.layers import LayerScale | |
| from timm.models.vision_transformer import Mlp, DropPath | |
| from utils.pos_embed import get_2d_sincos_pos_embed | |
class ImplFuncAttention(nn.Module):
    """Multi-head attention over a ``[latent tokens, point tokens]`` sequence.

    Latent tokens attend only to latent tokens. Each point token attends to
    every latent token plus itself, and never to other point tokens, so the
    point outputs are independent of one another.

    Args:
        dim: token channel width ``C``; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        qkv_bias: whether the fused q/k/v projection has a bias term.
        attn_drop: dropout probability applied to the attention weights.
        proj_drop: dropout probability applied after the output projection.
        last_layer: if True, ``forward`` skips the latent update and returns
            only the (projected) point tokens.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., last_layer=False):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.last_layer = last_layer

    def forward(self, x, N_points):
        """Run the restricted attention.

        Args:
            x: tensor ``[B, N, C]``. The first ``N - N_points`` tokens are
                latents and the last ``N_points`` tokens are query points.
            N_points: number of trailing point tokens, ``0 <= N_points <= N``.

        Returns:
            ``(output, attn)``. ``output`` is ``[B, N_points, C]`` when
            ``last_layer`` is set, otherwise ``[B, N, C]`` (latents, then
            points). ``attn`` is ``[B, N_points, N_latent]``: the head-averaged
            weight each point gives each latent token, taken after attention
            dropout. The self weight is dropped.

        Raises:
            ValueError: if ``N_points`` is outside ``[0, N]``.
        """
        B, N, C = x.shape
        if not 0 <= N_points <= N:
            raise ValueError(f'N_points must be in [0, {N}], got {N_points}')
        N_latent = N - N_points
        head_dim = C // self.num_heads

        # [3, B, num_heads, N, head_dim]
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, head_dim).permute(2, 0, 3, 1, 4)
        # each [B, num_heads, N, head_dim]
        q, k, v = qkv.unbind(0)
        # Split at N_latent instead of using negative indices: with
        # N_points == 0, ``:-0`` would select nothing and ``-0:`` everything.
        q_latent, k_latent, v_latent = q[:, :, :N_latent], k[:, :, :N_latent], v[:, :, :N_latent]
        q_points, k_points, v_points = q[:, :, N_latent:], k[:, :, N_latent:], v[:, :, N_latent:]

        # --- point tokens: attend to all latents and to themselves only ---
        # [B, num_heads, N_points, N_latent]
        attn_cross = (q_points @ k_latent.transpose(-2, -1)) * self.scale
        # Only the diagonal (self) score of the point-point block is needed,
        # so compute it as a row-wise dot product: [B, num_heads, N_points, 1]
        attn_self = torch.sum(q_points * k_points, dim=-1, keepdim=True) * self.scale
        # Softmax jointly over (latents, self): [B, num_heads, N_points, N_latent + 1]
        attn_joint = torch.cat([attn_cross, attn_self], dim=-1).softmax(dim=-1)
        attn_joint = self.attn_drop(attn_joint)
        attn_to_latent = attn_joint[..., :N_latent]
        attn_to_self = attn_joint[..., N_latent:]
        # Weighted latent values plus own value, each merged back to [B, N_points, C]
        sum_cross = (attn_to_latent @ v_latent).transpose(1, 2).reshape(B, N_points, C)
        sum_self = (attn_to_self * v_points).transpose(1, 2).reshape(B, N_points, C)
        output_points = sum_cross + sum_self
        # Head-averaged point->latent weights: [B, N_points, N_latent]
        attn_vis = attn_to_latent.mean(dim=1)

        if self.last_layer:
            # The final block only needs the point tokens.
            output = self.proj_drop(self.proj(output_points))
            return output, attn_vis

        # --- latent tokens: plain self-attention, never attending to points ---
        # [B, num_heads, N_latent, N_latent]
        attn_latent = (q_latent @ k_latent.transpose(-2, -1)) * self.scale
        attn_latent = self.attn_drop(attn_latent.softmax(dim=-1))
        # [B, N_latent, C]
        output_latent = (attn_latent @ v_latent).transpose(1, 2).reshape(B, N_latent, C)

        # [B, N, C]
        output = torch.cat([output_latent, output_points], dim=1)
        output = self.proj_drop(self.proj(output))
        return output, attn_vis
class ImplFuncBlock(nn.Module):
    """Pre-norm transformer block built around :class:`ImplFuncAttention`.

    Each sub-layer is applied as ``x + drop_path(ls(sublayer(norm(x))))``.
    LayerScale is used only when ``init_values`` is truthy, and DropPath only
    when ``drop_path > 0``; otherwise both are identities.

    With ``last_layer=True`` the attention returns only the trailing point
    tokens. The block then applies the residual and MLP to those tokens
    alone and returns ``[B, unseen_size, C]``.
    """

    def __init__(
            self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,
            drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, last_layer=False):
        super().__init__()
        self.last_layer = last_layer
        self.norm1 = norm_layer(dim)
        self.attn = ImplFuncAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, last_layer=last_layer)
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x, unseen_size):
        """Apply attention and MLP sub-layers.

        Args:
            x: tensor ``[B, N, C]`` whose last ``unseen_size`` tokens are
                query points and whose leading tokens are latents.
            unseen_size: number of trailing point tokens.

        Returns:
            ``(x_out, attn_vis)``. ``x_out`` is ``[B, N, C]``, or
            ``[B, unseen_size, C]`` for the last layer. ``attn_vis`` is the
            point->latent attention map returned by the attention module.
        """
        attn_out, attn_vis = self.attn(self.norm1(x), unseen_size)
        if self.last_layer:
            # The attention output covers only the point tokens, so take the
            # residual from those tokens. Slice from N - unseen_size rather
            # than -unseen_size, which would keep every token when
            # unseen_size == 0.
            x = x[:, x.shape[1] - unseen_size:]
        x = x + self.drop_path1(self.ls1(attn_out))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return x, attn_vis
class LinearProj3D(nn.Module):
    """Map 3D points to ``embed_dim``-wide features with a single linear layer.

    When ``posenc_res > 0`` the xyz coordinates first pass through the
    embedder returned by ``get_embedder``. Otherwise the 3 raw coordinates
    feed the linear layer directly.
    """

    def __init__(self, embed_dim, posenc_res=0):
        super().__init__()
        self.embed_dim = embed_dim
        self.embed_fn = None
        in_features = 3
        if posenc_res > 0:
            # get_embedder also reports the width of its encoded output
            self.embed_fn, in_features = get_embedder(posenc_res, input_dims=3)
        self.proj = nn.Linear(in_features, embed_dim)

    def forward(self, points_3D):
        """Project ``points_3D`` (last dim 3) to features of width ``embed_dim``."""
        feats = points_3D if self.embed_fn is None else self.embed_fn(points_3D)
        return self.proj(feats)
class MLPBlocks(nn.Module):
    """Coordinate MLP that predicts one scalar per query point.

    The first layer's input is ``[points (optionally positionally encoded),
    per-point latent feature]``. Hidden layers use ``Softplus(beta=100)``,
    and the last layer is linear with one output channel. A layer whose index
    is in ``skip_in`` receives the first-layer input concatenated to its own
    input, scaled by ``1/sqrt(2)``.

    Args:
        num_hidden_layers: number of hidden layers of width ``n_channels``.
        n_channels: hidden layer width.
        latent_dim: channel width of the per-point latent feature.
        skip_in: indices of layers that receive the skip connection.
            Defaults to none. The sequence is copied at construction.
        posenc_res: positional-encoding resolution passed to
            ``get_embedder``; ``0`` disables positional encoding.
    """

    def __init__(self, num_hidden_layers, n_channels, latent_dim,
                 skip_in=None, posenc_res=0):
        super().__init__()
        # Layer widths: input, hidden..., scalar output
        self.dims = [3 + latent_dim] + [n_channels] * num_hidden_layers + [1]
        self.num_layers = len(self.dims)
        # Private copy: avoids a shared mutable default and aliasing the caller's list
        self.skip_in = list(skip_in) if skip_in is not None else []
        # Positional embedder for the xyz part of the input
        self.embed_fn = None
        if posenc_res > 0:
            embed_fn, input_ch = get_embedder(posenc_res, input_dims=3)
            self.embed_fn = embed_fn
            # Encoded points replace the raw xyz in the first-layer input
            self.dims[0] += (input_ch - 3)
        self.layers = nn.ModuleList([])
        for l, (in_dim, out_dim) in enumerate(zip(self.dims[:-1], self.dims[1:])):
            if l in self.skip_in:
                # Skip layers also see the full first-layer input
                in_dim += self.dims[0]
            self.layers.append(nn.Linear(in_dim, out_dim))
        # Kept for parameter initialisation by callers
        self.posenc_res = posenc_res
        self.softplus = nn.Softplus(beta=100)

    def forward(self, points, proj_latent):
        """Evaluate the MLP.

        Args:
            points: query coordinates with last dim 3, e.g. ``[B, N, 3]``.
            proj_latent: per-point latent features with last dim
                ``latent_dim`` and the same leading dims as ``points``.

        Returns:
            Tensor with the same leading dims and a last dim of 1.
        """
        if self.embed_fn is not None:
            points = self.embed_fn(points)
        # [B, N, posenc+C]
        inputs = torch.cat([points, proj_latent], dim=-1)
        x = inputs
        last = len(self.layers) - 1
        for l, layer in enumerate(self.layers):
            if l in self.skip_in:
                # Concatenation rescaled by 1/sqrt(2), presumably to keep the
                # activation scale comparable -- NOTE(review): confirm intent
                x = torch.cat([x, inputs], -1) / np.sqrt(2)
            x = layer(x)
            if l < last:
                x = self.softplus(x)
        return x
class Implicit(nn.Module):
    """
    Implicit function conditioned on depth encodings.

    Latent tokens (depth features, optionally concatenated with semantic
    features) and projected query points pass through a stack of
    :class:`ImplFuncBlock` layers, the last of which keeps only the point
    tokens. The point features are then decoded to one logit per point,
    either by :class:`MLPBlocks` (``n_layers_mlp > 0``) or by a linear head.

    Args:
        num_patches: number of patch tokens. The latent sequence is expected
            to hold ``1 + num_patches`` tokens. The sin-cos grid uses
            ``int(num_patches ** .5)``, so presumably a perfect square.
        latent_dim: channel width of the (concatenated) latent input.
        semantic: if True, ``latent_semantic`` is concatenated to
            ``latent_depth`` along channels; ``latent_dim`` must be the
            combined width.
        n_channels: working width of attention and MLP layers.
        n_blocks_attn: total attention blocks, including the final one.
        n_layers_mlp: hidden layers of the decoding MLP; ``<= 0`` uses a
            linear prediction head instead.
        num_heads: attention heads per block.
        posenc_3D: positional-encoding resolution for the decoding MLP.
        mlp_ratio: hidden-width ratio of the per-block MLPs.
        norm_layer: normalisation layer factory.
        drop_path: DropPath probability per block.
        skip_in: skip-connection layer indices for the decoding MLP
            (defaults to none).
        pos_perlayer: add the positional embedding before every block
            instead of only before the first.
    """

    def __init__(self,
                 num_patches, latent_dim=768, semantic=False, n_channels=512,
                 n_blocks_attn=2, n_layers_mlp=6, num_heads=16, posenc_3D=0,
                 mlp_ratio=4., norm_layer=partial(nn.LayerNorm, eps=1e-6), drop_path=0.1,
                 skip_in=None, pos_perlayer=True):
        super().__init__()
        self.num_patches = num_patches
        self.pos_perlayer = pos_perlayer
        self.semantic = semantic
        # projection to the same number of channels, no posenc
        self.point_proj = LinearProj3D(n_channels)
        self.latent_proj = nn.Linear(latent_dim, n_channels, bias=True)
        # Fixed sin-cos embedding for the 1 + num_patches latent tokens,
        # filled in by initialize_weights()
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, n_channels), requires_grad=False)
        # n_blocks_attn - 1 regular blocks followed by one last-layer block
        self.blocks_attn = nn.ModuleList([
            ImplFuncBlock(
                n_channels, num_heads, mlp_ratio,
                qkv_bias=True, norm_layer=norm_layer, drop_path=drop_path
            ) for _ in range(n_blocks_attn - 1)])
        self.blocks_attn.append(
            ImplFuncBlock(
                n_channels, num_heads, mlp_ratio,
                qkv_bias=True, norm_layer=norm_layer, drop_path=drop_path, last_layer=True
            )
        )
        self.norm = norm_layer(n_channels)
        self.impl_mlp = None
        if n_layers_mlp > 0:
            # Pass a fresh list so no mutable default is shared between instances
            self.impl_mlp = MLPBlocks(n_layers_mlp, n_channels, n_channels,
                                      skip_in=list(skip_in) if skip_in is not None else [],
                                      posenc_res=posenc_3D)
        else:
            # occupancy logit prediction head
            self.pred_head = nn.Linear(n_channels, 1, bias=True)
        self.initialize_weights()

    def initialize_weights(self):
        """Fill the sin-cos positional embedding and re-initialise Linear/LayerNorm layers."""
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.num_patches**.5), cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform Linear weights with zero bias; LayerNorm reset to identity."""
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            # Affine params are None when elementwise_affine=False
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def forward(self, latent_depth, latent_semantic, points_3D):
        """Predict one occupancy logit per query point.

        Args:
            latent_depth: latent tokens ``[B, 1 + num_patches, C_depth]``.
            latent_semantic: semantic tokens concatenated along channels
                when ``self.semantic`` is True; otherwise ignored.
            points_3D: query points ``[B, n_points, 3]``.

        Returns:
            ``(logits, attn_vis)``. ``logits`` is ``[B, n_points]``.
            ``attn_vis`` is ``[B, n_points, 1 + num_patches]``: the
            point->latent attention averaged over heads and layers.

        Raises:
            ValueError: if the number of latent tokens does not match the
                positional embedding length.
        """
        latent = torch.cat([latent_depth, latent_semantic], dim=-1) if self.semantic else latent_depth
        # [B, 1+n_patches, C_dec]
        latent = self.latent_proj(latent)
        N_latent = latent.shape[1]
        if N_latent != self.pos_embed.shape[1]:
            raise ValueError(
                f'expected {self.pos_embed.shape[1]} latent tokens, got {N_latent}')
        # [B, n_points, C_dec]
        points_feat = self.point_proj(points_3D)
        N_points = points_feat.shape[1]
        # [B, 1+n_patches+n_points, C_dec]
        output = torch.cat([latent, points_feat], dim=1)

        attn_vis = []
        for l, blk in enumerate(self.blocks_attn):
            if self.pos_perlayer or l == 0:
                # Out-of-place: add the positional embedding to the latent
                # tokens only, without writing into a tracked activation
                output = torch.cat([output[:, :N_latent] + self.pos_embed, output[:, N_latent:]], dim=1)
            output, attn = blk(output, N_points)
            attn_vis.append(attn)
        # The last block returns only the point tokens: [B, n_points, C_dec]
        output = self.norm(output)
        # Average of attention weights across layers: [B, n_points, N_latent]
        attn_vis = torch.stack(attn_vis, dim=-1).mean(dim=-1)

        if self.impl_mlp is not None:
            output = self.impl_mlp(points_3D, output)
        else:
            # [B, n_points, 1]
            output = self.pred_head(output)
        # occupancy logits [B, n_points] and the attention weights
        return output.squeeze(-1), attn_vis