import copy from typing import List, Union import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange, repeat from einops.layers.torch import Rearrange LRELU_SLOPE = 0.02 def act_layer(act): if act == "relu": return nn.ReLU() elif act == "lrelu": return nn.LeakyReLU(LRELU_SLOPE) elif act == "elu": return nn.ELU() elif act == "tanh": return nn.Tanh() elif act == "prelu": return nn.PReLU() else: raise ValueError("%s not recognized." % act) def norm_layer2d(norm, channels): if norm == "batch": return nn.BatchNorm2d(channels) elif norm == "instance": return nn.InstanceNorm2d(channels, affine=True) elif norm == "layer": return nn.GroupNorm(1, channels, affine=True) elif norm == "group": return nn.GroupNorm(4, channels, affine=True) else: raise ValueError("%s not recognized." % norm) def norm_layer1d(norm, num_channels): if norm == "batch": return nn.BatchNorm1d(num_channels) elif norm == "instance": return nn.InstanceNorm1d(num_channels, affine=True) elif norm == "layer": return nn.LayerNorm(num_channels) else: raise ValueError("%s not recognized." 
% norm) class FiLMBlock(nn.Module): def __init__(self): super(FiLMBlock, self).__init__() def forward(self, x, gamma, beta): beta = beta.view(x.size(0), x.size(1), 1, 1) gamma = gamma.view(x.size(0), x.size(1), 1, 1) x = gamma * x + beta return x class Conv2DBlock(nn.Module): def __init__( self, in_channels, out_channels, kernel_sizes, strides, norm=None, activation=None, padding_mode="replicate", ): super(Conv2DBlock, self).__init__() padding = ( kernel_sizes // 2 if isinstance(kernel_sizes, int) else (kernel_sizes[0] // 2, kernel_sizes[1] // 2) ) self.conv2d = nn.Conv2d( in_channels, out_channels, kernel_sizes, strides, padding=padding, padding_mode=padding_mode, ) if activation is None: nn.init.xavier_uniform_( self.conv2d.weight, gain=nn.init.calculate_gain("linear") ) nn.init.zeros_(self.conv2d.bias) elif activation == "tanh": nn.init.xavier_uniform_( self.conv2d.weight, gain=nn.init.calculate_gain("tanh") ) nn.init.zeros_(self.conv2d.bias) elif activation == "lrelu": nn.init.kaiming_uniform_( self.conv2d.weight, a=LRELU_SLOPE, nonlinearity="leaky_relu" ) nn.init.zeros_(self.conv2d.bias) elif activation == "relu": nn.init.kaiming_uniform_(self.conv2d.weight, nonlinearity="relu") nn.init.zeros_(self.conv2d.bias) else: raise ValueError() self.activation = None self.norm = None if norm is not None: self.norm = norm_layer2d(norm, out_channels) if activation is not None: self.activation = act_layer(activation) def forward(self, x): x = self.conv2d(x) x = self.norm(x) if self.norm is not None else x x = self.activation(x) if self.activation is not None else x return x class Conv2DFiLMBlock(Conv2DBlock): def __init__( self, in_channels, out_channels, kernel_sizes, strides, norm=None, activation=None, padding_mode="replicate", ): super(Conv2DFiLMBlock, self).__init__( in_channels, out_channels, kernel_sizes, strides, norm, activation, padding_mode, ) self.film = FiLMBlock() def forward(self, x, gamma, beta): x = self.conv2d(x) x = self.norm(x) if self.norm is not None 
else x x = self.film(x, gamma, beta) x = self.activation(x) if self.activation is not None else x return x class Conv3DBlock(nn.Module): def __init__( self, in_channels, out_channels, kernel_sizes: Union[int, list] = 3, strides=1, norm=None, activation=None, padding_mode="replicate", padding=None, ): super(Conv3DBlock, self).__init__() padding = kernel_sizes // 2 if padding is None else padding self.conv3d = nn.Conv3d( in_channels, out_channels, kernel_sizes, strides, padding=padding, padding_mode=padding_mode, ) if activation is None: nn.init.xavier_uniform_( self.conv3d.weight, gain=nn.init.calculate_gain("linear") ) nn.init.zeros_(self.conv3d.bias) elif activation == "tanh": nn.init.xavier_uniform_( self.conv3d.weight, gain=nn.init.calculate_gain("tanh") ) nn.init.zeros_(self.conv3d.bias) elif activation == "lrelu": nn.init.kaiming_uniform_( self.conv3d.weight, a=LRELU_SLOPE, nonlinearity="leaky_relu" ) nn.init.zeros_(self.conv3d.bias) elif activation == "relu": nn.init.kaiming_uniform_(self.conv3d.weight, nonlinearity="relu") nn.init.zeros_(self.conv3d.bias) else: raise ValueError() self.activation = None self.norm = None if norm is not None: raise NotImplementedError("Norm not implemented.") if activation is not None: self.activation = act_layer(activation) self.out_channels = out_channels def forward(self, x): x = self.conv3d(x) x = self.norm(x) if self.norm is not None else x x = self.activation(x) if self.activation is not None else x return x class ConvTranspose3DBlock(nn.Module): def __init__( self, in_channels, out_channels, kernel_sizes: Union[int, list], strides, norm=None, activation=None, padding_mode="zeros", padding=None, ): super(ConvTranspose3DBlock, self).__init__() padding = kernel_sizes // 2 if padding is None else padding self.conv3d = nn.ConvTranspose3d( in_channels, out_channels, kernel_sizes, strides, padding=padding, padding_mode=padding_mode, ) if activation is None: nn.init.xavier_uniform_( self.conv3d.weight, 
gain=nn.init.calculate_gain("linear") ) nn.init.zeros_(self.conv3d.bias) elif activation == "tanh": nn.init.xavier_uniform_( self.conv3d.weight, gain=nn.init.calculate_gain("tanh") ) nn.init.zeros_(self.conv3d.bias) elif activation == "lrelu": nn.init.kaiming_uniform_( self.conv3d.weight, a=LRELU_SLOPE, nonlinearity="leaky_relu" ) nn.init.zeros_(self.conv3d.bias) elif activation == "relu": nn.init.kaiming_uniform_(self.conv3d.weight, nonlinearity="relu") nn.init.zeros_(self.conv3d.bias) else: raise ValueError() self.activation = None self.norm = None if norm is not None: self.norm = norm_layer3d(norm, out_channels) if activation is not None: self.activation = act_layer(activation) def forward(self, x): x = self.conv3d(x) x = self.norm(x) if self.norm is not None else x x = self.activation(x) if self.activation is not None else x return x class Conv2DUpsampleBlock(nn.Module): def __init__( self, in_channels, out_channels, kernel_sizes, strides, norm=None, activation=None, ): super(Conv2DUpsampleBlock, self).__init__() layer = [ Conv2DBlock(in_channels, out_channels, kernel_sizes, 1, norm, activation) ] if strides > 1: layer.append( nn.Upsample(scale_factor=strides, mode="bilinear", align_corners=False) ) convt_block = Conv2DBlock( out_channels, out_channels, kernel_sizes, 1, norm, activation ) layer.append(convt_block) self.conv_up = nn.Sequential(*layer) def forward(self, x): return self.conv_up(x) class Conv3DUpsampleBlock(nn.Module): def __init__( self, in_channels, out_channels, strides, kernel_sizes=3, norm=None, activation=None, ): super(Conv3DUpsampleBlock, self).__init__() layer = [ Conv3DBlock(in_channels, out_channels, kernel_sizes, 1, norm, activation) ] if strides > 1: layer.append( nn.Upsample(scale_factor=strides, mode="trilinear", align_corners=False) ) convt_block = Conv3DBlock( out_channels, out_channels, kernel_sizes, 1, norm, activation ) layer.append(convt_block) self.conv_up = nn.Sequential(*layer) def forward(self, x): return self.conv_up(x) 
class DenseBlock(nn.Module):
    """Linear -> optional 1-D norm -> optional activation, with activation-aware init.

    Args:
        in_features: input size of the ``nn.Linear``.
        out_features: output size of the ``nn.Linear``.
        norm: name accepted by ``norm_layer1d``, or ``None`` for no norm.
        activation: ``None``, ``"tanh"``, ``"lrelu"`` or ``"relu"``; selects both
            the weight initialisation and the activation module.

    Raises:
        ValueError: for any other ``activation`` name.
    """

    def __init__(self, in_features, out_features, norm=None, activation=None):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self._init_linear(self.linear, activation)
        self.norm = norm_layer1d(norm, out_features) if norm is not None else None
        self.activation = act_layer(activation) if activation is not None else None

    @staticmethod
    def _init_linear(linear, activation):
        """Initialise ``linear.weight`` for ``activation`` and zero ``linear.bias``."""
        if activation is None:
            nn.init.xavier_uniform_(
                linear.weight, gain=nn.init.calculate_gain("linear")
            )
        elif activation == "tanh":
            nn.init.xavier_uniform_(linear.weight, gain=nn.init.calculate_gain("tanh"))
        elif activation == "lrelu":
            nn.init.kaiming_uniform_(
                linear.weight, a=LRELU_SLOPE, nonlinearity="leaky_relu"
            )
        elif activation == "relu":
            nn.init.kaiming_uniform_(linear.weight, nonlinearity="relu")
        else:
            raise ValueError(
                "No weight initialisation defined for activation %r." % (activation,)
            )
        nn.init.zeros_(linear.bias)

    def forward(self, x):
        x = self.linear(x)
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        return x


def _concat_low_dim(x, low_dim_ins):
    """Tile per-sample ``low_dim_ins`` over ``x``'s (h, w) grid and concat on dim 1.

    ``x`` is unpacked as ``(b, c, h, w)``; ``low_dim_ins`` gets two trailing
    singleton axes and is repeated ``h`` x ``w`` times.
    """
    _, _, h, w = x.shape
    low_dim_latents = low_dim_ins.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, h, w)
    return torch.cat([x, low_dim_latents], dim=1)


def _build_fc_head(in_features, fc_layers, activation):
    """Stack DenseBlocks sized by ``fc_layers``; hidden layers use ``activation``.

    The final DenseBlock has no activation. With no layer sizes configured the
    head is ``nn.Identity()``, so pooled features pass through unchanged.
    """
    if not fc_layers:
        return nn.Identity()
    layers = []
    features = in_features
    for width in fc_layers[:-1]:
        layers.append(DenseBlock(features, width, activation=activation))
        features = width
    layers.append(DenseBlock(features, fc_layers[-1]))
    return nn.Sequential(*layers)


# Number of language-conditioned conv layers in the *LangAndFcsNet heads.
_NUM_FILM_LAYERS = 3


def _build_film_convs(net, in_channels):
    """Register ``conv{i}``, ``gamma{i}``, ``beta{i}`` (i = 1..3) on ``net``.

    Reads ``net._filters``, ``net._kernel_sizes``, ``net._strides`` and
    ``net._lang_feat_dim``. Attribute names and creation order (conv, gamma,
    beta per layer) are fixed because they determine ``state_dict`` keys and
    the order in which parameters consume the RNG at construction.

    Returns:
        The channel count produced by the last FiLM conv.

    Raises:
        ValueError: if fewer than three filter sizes are configured.
    """
    if len(net._filters) < _NUM_FILM_LAYERS:
        raise ValueError(
            "Expected at least %d filter sizes, got %d."
            % (_NUM_FILM_LAYERS, len(net._filters))
        )
    channels = in_channels
    for idx in range(_NUM_FILM_LAYERS):
        filt = net._filters[idx]
        # NOTE(review): these convs are built without net._norm or
        # net._activation, so no activation runs between them. Kept as-is to
        # stay compatible with existing trained weights.
        setattr(
            net,
            "conv%d" % (idx + 1),
            Conv2DFiLMBlock(
                channels, filt, net._kernel_sizes[idx], net._strides[idx]
            ),
        )
        setattr(net, "gamma%d" % (idx + 1), nn.Linear(net._lang_feat_dim, filt))
        setattr(net, "beta%d" % (idx + 1), nn.Linear(net._lang_feat_dim, filt))
        channels = filt
    return channels


def _apply_film_convs(net, x, lang_goal_emb):
    """Run ``x`` through the FiLM convs registered by ``_build_film_convs``.

    Each layer's scale/shift is projected from ``lang_goal_emb`` by the
    matching ``gamma{i}`` / ``beta{i}`` linear layer.
    """
    for idx in range(1, _NUM_FILM_LAYERS + 1):
        gamma = getattr(net, "gamma%d" % idx)(lang_goal_emb)
        beta = getattr(net, "beta%d" % idx)(lang_goal_emb)
        x = getattr(net, "conv%d" % idx)(x, gamma, beta)
    return x


class SiameseNet(nn.Module):
    """One conv tower per input, fused by a 1x1 conv over their concatenation.

    Layers are created by :meth:`build`, not by the constructor. ``forward``
    takes a sequence with one tensor per entry of ``input_channels`` and also
    keeps the per-tower outputs on ``self.streams``.
    """

    def __init__(
        self,
        input_channels: List[int],
        filters: List[int],
        kernel_sizes: List[int],
        strides: List[int],
        norm: str = None,
        activation: str = "relu",
    ):
        super().__init__()
        self._input_channels = input_channels
        self._filters = filters
        self._kernel_sizes = kernel_sizes
        self._strides = strides
        self._norm = norm
        self._activation = activation
        # The fuse conv maps the concatenated towers back to filters[-1] channels.
        self.output_channels = filters[-1]

    def build(self):
        self._siamese_blocks = nn.ModuleList()
        for tower_in_channels in self._input_channels:
            blocks = []
            channels = tower_in_channels
            for filt, ksize, stride in zip(
                self._filters, self._kernel_sizes, self._strides
            ):
                blocks.append(
                    Conv2DBlock(
                        channels, filt, ksize, stride, self._norm, self._activation
                    )
                )
                # The next conv consumes this conv's output, not the raw input.
                channels = filt
            self._siamese_blocks.append(nn.Sequential(*blocks))
        self._fuse = Conv2DBlock(
            self._filters[-1] * len(self._siamese_blocks),
            self._filters[-1],
            1,
            1,
            self._norm,
            self._activation,
        )

    def forward(self, x):
        if len(x) != len(self._siamese_blocks):
            raise ValueError(
                "Expected a list of tensors of size %d." % len(self._siamese_blocks)
            )
        self.streams = [tower(inp) for inp, tower in zip(x, self._siamese_blocks)]
        return self._fuse(torch.cat(self.streams, 1))


class CNNAndFcsNet(nn.Module):
    """SiameseNet features + tiled low-dim state -> conv stack -> max-pool -> FC head.

    Layers (including the wrapped SiameseNet's) are created by :meth:`build`.
    """

    def __init__(
        self,
        siamese_net: SiameseNet,
        low_dim_state_len: int,
        input_resolution: List[int],
        filters: List[int],
        kernel_sizes: List[int],
        strides: List[int],
        norm: str = None,
        fc_layers: List[int] = None,
        activation: str = "relu",
    ):
        super().__init__()
        self._siamese_net = copy.deepcopy(siamese_net)
        self._input_channels = self._siamese_net.output_channels + low_dim_state_len
        self._filters = filters
        self._kernel_sizes = kernel_sizes
        self._strides = strides
        self._norm = norm
        self._activation = activation
        self._fc_layers = [] if fc_layers is None else fc_layers
        self._input_resolution = input_resolution

    def build(self):
        self._siamese_net.build()
        specs = list(zip(self._filters, self._kernel_sizes, self._strides))
        layers = []
        channels = self._input_channels
        for filt, ksize, stride in specs[:-1]:
            layers.append(
                Conv2DBlock(channels, filt, ksize, stride, self._norm, self._activation)
            )
            channels = filt
        # The last conv is built without norm/activation before max-pooling.
        layers.append(
            Conv2DBlock(
                channels, self._filters[-1], self._kernel_sizes[-1], self._strides[-1]
            )
        )
        self._cnn = nn.Sequential(*layers)
        self._maxp = nn.AdaptiveMaxPool2d(1)
        self._fcs = _build_fc_head(self._filters[-1], self._fc_layers, self._activation)

    def forward(self, observations, low_dim_ins):
        x = self._siamese_net(observations)
        x = self._cnn(_concat_low_dim(x, low_dim_ins))
        x = self._maxp(x).squeeze(-1).squeeze(-1)
        return self._fcs(x)


class CNNLangAndFcsNet(nn.Module):
    """Like CNNAndFcsNet, but the conv stack is three FiLM convs driven by a language embedding.

    Only ``filters[0..2]`` / ``kernel_sizes[0..2]`` / ``strides[0..2]`` are used.
    """

    def __init__(
        self,
        siamese_net: SiameseNet,
        low_dim_state_len: int,
        input_resolution: List[int],
        filters: List[int],
        kernel_sizes: List[int],
        strides: List[int],
        norm: str = None,
        fc_layers: List[int] = None,
        activation: str = "relu",
    ):
        super().__init__()
        self._siamese_net = copy.deepcopy(siamese_net)
        self._input_channels = self._siamese_net.output_channels + low_dim_state_len
        self._filters = filters
        self._kernel_sizes = kernel_sizes
        self._strides = strides
        self._norm = norm
        self._activation = activation
        self._fc_layers = [] if fc_layers is None else fc_layers
        self._input_resolution = input_resolution
        # Input width of the gamma/beta projections applied to lang_goal_emb.
        self._lang_feat_dim = 1024

    def build(self):
        self._siamese_net.build()
        out_channels = _build_film_convs(self, self._input_channels)
        self._maxp = nn.AdaptiveMaxPool2d(1)
        self._fcs = _build_fc_head(out_channels, self._fc_layers, self._activation)

    def forward(self, observations, low_dim_ins, lang_goal_emb):
        x = self._siamese_net(observations)
        x = _concat_low_dim(x, low_dim_ins)
        x = _apply_film_convs(self, x, lang_goal_emb)
        x = self._maxp(x).squeeze(-1).squeeze(-1)
        return self._fcs(x)


# helpers
def pair(t):
    """Return ``t`` unchanged if it is a tuple, else ``(t, t)``."""
    return t if isinstance(t, tuple) else (t, t)


# classes
class PreNorm(nn.Module):
    """Apply LayerNorm over the last dimension, then ``fn``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


class FeedForward(nn.Module):
    """Linear -> GELU -> Dropout -> Linear -> Dropout MLP."""

    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention over ``(b, n, dim)`` tokens.

    The output projection is skipped (``nn.Identity``) only when there is a
    single head whose width already equals ``dim``.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head**-0.5

        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = (
            nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
            if project_out
            else nn.Identity()
        )

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(
            lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv
        )

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)


class Transformer(nn.Module):
    """``depth`` pre-norm residual blocks of self-attention followed by an MLP."""

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.0):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        PreNorm(
                            dim,
                            Attention(
                                dim, heads=heads, dim_head=dim_head, dropout=dropout
                            ),
                        ),
                        PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout)),
                    ]
                )
            )

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x


# ViT IO implementation adapted for baseline
# Source: https://github.com/lucidrains/vit-pytorch
# License: https://github.com/lucidrains/vit-pytorch/blob/main/LICENSE
class ViT(nn.Module):
    """Vision Transformer encoder that returns a patch-grid-shaped feature tensor.

    ``num_classes`` is accepted but unused, and ``pool`` is only validated:
    ``forward`` drops the cls token and returns the remaining patch tokens
    reshaped to ``(b, -1, num_patches_x, num_patches_y)``.
    """

    def __init__(
        self,
        *,
        image_size,
        patch_size,
        num_classes,
        dim,
        depth,
        heads,
        mlp_dim,
        pool="cls",
        channels=3,
        dim_head=64,
        dropout=0.0,
        emb_dropout=0.0
    ):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert (
            image_height % patch_height == 0 and image_width % patch_width == 0
        ), "Image dimensions must be divisible by the patch size."

        self.num_patches_x = image_height // patch_height
        self.num_patches_y = image_width // patch_width
        self.num_patches = self.num_patches_x * self.num_patches_y
        patch_dim = channels * patch_height * patch_width
        assert pool in {
            "cls",
            "mean",
        }, "pool type must be either cls (cls token) or mean (mean pooling)"

        self.to_patch_embedding = nn.Sequential(
            Rearrange(
                "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
                p1=patch_height,
                p2=patch_width,
            ),
            nn.Linear(patch_dim, dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, "1 1 d -> b 1 d", b=b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, : (n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        # NOTE(review): x[:, 1:] is (b, n, dim); reshape() reinterprets that
        # memory as (b, dim, num_patches_x, num_patches_y) instead of
        # transposing, so token and channel axes are interleaved rather than
        # forming a spatial feature map. A layout-preserving alternative is
        # rearrange(x[:, 1:], "b (h w) d -> b d h w", h=..., w=...), but it
        # changes outputs for existing checkpoints, so it is left unchanged.
        x = x[:, 1:].reshape(b, -1, self.num_patches_x, self.num_patches_y)

        return x


class ViTLangAndFcsNet(nn.Module):
    """ViT features + tiled low-dim state -> three language-FiLM convs -> max-pool -> FC head.

    ``forward`` concatenates the entries of ``observations`` along dim 1 before
    feeding the ViT. Only ``filters[0..2]`` / ``kernel_sizes[0..2]`` /
    ``strides[0..2]`` are used.
    """

    def __init__(
        self,
        vit: ViT,
        low_dim_state_len: int,
        input_resolution: List[int],
        filters: List[int],
        kernel_sizes: List[int],
        strides: List[int],
        norm: str = None,
        fc_layers: List[int] = None,
        activation: str = "relu",
    ):
        super().__init__()
        self._vit = copy.deepcopy(vit)
        # For inputs of the configured image size, ViT.forward yields `dim`
        # channels, i.e. the width of its cls_token. Fall back to the previous
        # hard-coded 64 for objects that do not expose cls_token.
        cls_token = getattr(self._vit, "cls_token", None)
        vit_channels = 64 if cls_token is None else int(cls_token.shape[-1])
        self._input_channels = vit_channels + low_dim_state_len
        self._filters = filters
        self._kernel_sizes = kernel_sizes
        self._strides = strides
        self._norm = norm
        self._activation = activation
        self._fc_layers = [] if fc_layers is None else fc_layers
        self._input_resolution = input_resolution
        # Input width of the gamma/beta projections applied to lang_goal_emb.
        self._lang_feat_dim = 1024

    def build(self):
        out_channels = _build_film_convs(self, self._input_channels)
        self._maxp = nn.AdaptiveMaxPool2d(1)
        self._fcs = _build_fc_head(out_channels, self._fc_layers, self._activation)

    def forward(self, observations, low_dim_ins, lang_goal_emb):
        rgb_depth = torch.cat(list(observations), dim=1)
        x = self._vit(rgb_depth)
        x = _concat_low_dim(x, low_dim_ins)
        x = _apply_film_convs(self, x, lang_goal_emb)
        x = self._maxp(x).squeeze(-1).squeeze(-1)
        return self._fcs(x)


class Conv3DInceptionBlockUpsampleBlock(nn.Module):
    """Inception block, optional trilinear upsample by ``scale_factor``, inception block.

    NOTE(review): ``residual`` is accepted but not forwarded to the inception
    blocks; both are built non-residual.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        scale_factor,
        norm=None,
        activation=None,
        residual=False,
    ):
        super().__init__()
        layers = [Conv3DInceptionBlock(in_channels, out_channels, norm, activation)]
        if scale_factor > 1:
            layers.append(
                nn.Upsample(
                    scale_factor=scale_factor, mode="trilinear", align_corners=False
                )
            )
        layers.append(
            Conv3DInceptionBlock(out_channels, out_channels, norm, activation)
        )
        self.conv_up = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv_up(x)


class Conv3DInceptionBlock(nn.Module):
    """3-D inception block concatenating three branches along dim 1.

    Branches: 1x1 conv (``out_channels // 2`` ch); 1x1 -> 3x3 (``out_channels // 4``);
    1x1 -> 3x3 -> 3x3 (``out_channels // 4``, a 5x5 receptive field). With
    ``residual`` the input is prepended, so ``out_channels`` reports the sum.
    ``out_channels`` must be divisible by 4.
    """

    def __init__(
        self, in_channels, out_channels, norm=None, activation=None, residual=False
    ):
        super().__init__()
        self._residual = residual

        cs = out_channels // 4
        assert out_channels % 4 == 0
        latent = 32  # bottleneck width of the 3x3 / 5x5 branches

        self._1x1conv = Conv3DBlock(
            in_channels,
            cs * 2,
            kernel_sizes=1,
            strides=1,
            norm=norm,
            activation=activation,
        )

        self._1x1conv_a = Conv3DBlock(
            in_channels,
            latent,
            kernel_sizes=1,
            strides=1,
            norm=norm,
            activation=activation,
        )
        self._3x3conv = Conv3DBlock(
            latent, cs, kernel_sizes=3, strides=1, norm=norm, activation=activation
        )

        self._1x1conv_b = Conv3DBlock(
            in_channels,
            latent,
            kernel_sizes=1,
            strides=1,
            norm=norm,
            activation=activation,
        )
        self._5x5_via_3x3conv_a = Conv3DBlock(
            latent, latent, kernel_sizes=3, strides=1, norm=norm, activation=activation
        )
        self._5x5_via_3x3conv_b = Conv3DBlock(
            latent, cs, kernel_sizes=3, strides=1, norm=norm, activation=activation
        )
        self.out_channels = out_channels + (in_channels if residual else 0)

    def forward(self, x):
        branches = [x] if self._residual else []
        branches.append(self._1x1conv(x))
        branches.append(self._3x3conv(self._1x1conv_a(x)))
        branches.append(
            self._5x5_via_3x3conv_b(self._5x5_via_3x3conv_a(self._1x1conv_b(x)))
        )
        return torch.cat(branches, 1)


class ConvTransposeUp3DBlock(nn.Module):
    """1x1 conv -> kernel-2 transposed conv (stride ``strides``) -> 1x1 conv.

    The final 1x1 conv receives ``norm`` but no activation. ``residual`` is
    stored but not used by ``forward``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        strides=2,
        padding=0,
        norm=None,
        activation=None,
        residual=False,
    ):
        super().__init__()
        self._residual = residual

        self._1x1conv = Conv3DBlock(
            in_channels,
            out_channels,
            kernel_sizes=1,
            strides=1,
            norm=norm,
            activation=activation,
        )
        self._3x3conv = ConvTranspose3DBlock(
            out_channels,
            out_channels,
            kernel_sizes=2,
            strides=strides,
            norm=norm,
            activation=activation,
            padding=padding,
        )
        self._1x1conv_a = Conv3DBlock(
            out_channels,
            out_channels,
            kernel_sizes=1,
            strides=1,
            norm=norm,
        )
        self.out_channels = out_channels

    def forward(self, x):
        x = self._1x1conv(x)
        x = self._3x3conv(x)
        x = self._1x1conv_a(x)
        return x


class SpatialSoftmax3D(torch.nn.Module):
    """Per-channel soft-argmax over a 3-D grid, returning expected coordinates.

    Coordinates lie in [-1, 1]; the softmax uses temperature 0.01. The output
    is viewed as ``(-1, channel * 3)``.
    """

    def __init__(self, depth, height, width, channel):
        super().__init__()
        self.depth = depth
        self.height = height
        self.width = width
        self.channel = channel
        self.temperature = 0.01
        # NOTE(review): np.meshgrid defaults to indexing="xy", so these arrays
        # have shape (height, depth, width) and pos_x varies along the second
        # flattened axis. That matches a (D, H, W) feature layout only when
        # depth == height, and even then pos_x/pos_y correspond to the H/D
        # axes. Kept as-is because trained models depend on this mapping.
        pos_x, pos_y, pos_z = np.meshgrid(
            np.linspace(-1.0, 1.0, self.depth),
            np.linspace(-1.0, 1.0, self.height),
            np.linspace(-1.0, 1.0, self.width),
        )
        num_cells = self.depth * self.height * self.width
        pos_x = torch.from_numpy(pos_x.reshape(num_cells)).float()
        pos_y = torch.from_numpy(pos_y.reshape(num_cells)).float()
        pos_z = torch.from_numpy(pos_z.reshape(num_cells)).float()
        self.register_buffer("pos_x", pos_x)
        self.register_buffer("pos_y", pos_y)
        self.register_buffer("pos_z", pos_z)

    def forward(self, feature):
        # One row per (sample, channel): (B * C, D * H * W).
        feature = feature.view(-1, self.height * self.width * self.depth)
        softmax_attention = F.softmax(feature / self.temperature, dim=-1)
        expected_x = torch.sum(self.pos_x * softmax_attention, dim=1, keepdim=True)
        expected_y = torch.sum(self.pos_y * softmax_attention, dim=1, keepdim=True)
        expected_z = torch.sum(self.pos_z * softmax_attention, dim=1, keepdim=True)
        expected_xyz = torch.cat([expected_x, expected_y, expected_z], 1)
        return expected_xyz.view(-1, self.channel * 3)