""" Pose head implementation Downstream heads assume inputs of size BCHW (B: batch, C: channels, H: height, W: width); The Pose head implementation is based on Reloc3r and MaRePo References: https://github.com/ffrivera0/reloc3r/blob/main/reloc3r/pose_head.py """ import copy import torch import torch.nn as nn import torch.nn.functional as F from uniception.models.prediction_heads.base import PredictionHeadInput, SummaryTaskOutput class ResConvBlock(nn.Module): """ 1x1 convolution residual block implementation based on Reloc3r & MaRePo """ def __init__( self, in_channels: int, out_channels: int, *args, **kwargs, ): super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.head_skip = ( nn.Identity() if self.in_channels == self.out_channels else nn.Conv2d(self.in_channels, self.out_channels, 1, 1, 0) ) self.res_conv1 = nn.Conv2d(self.in_channels, self.out_channels, 1, 1, 0) self.res_conv2 = nn.Conv2d(self.out_channels, self.out_channels, 1, 1, 0) self.res_conv3 = nn.Conv2d(self.out_channels, self.out_channels, 1, 1, 0) def forward(self, res): x = F.relu(self.res_conv1(res)) x = F.relu(self.res_conv2(x)) x = F.relu(self.res_conv3(x)) res = self.head_skip(res) + x return res class PoseHead(nn.Module): """ Pose regression head implementation based on Reloc3r & MaRePo """ def __init__( self, patch_size: int, input_feature_dim: int, num_resconv_block: int = 2, rot_representation_dim: int = 4, pretrained_checkpoint_path: str = None, *args, **kwargs, ): """ Initialize the pose head. Args: patch_size : int, the patch size of the transformer used to generate the input features input_feature_dim : int, the input feature dimension num_resconv_block : int, the number of residual convolution blocks rot_representation_dim : int, the dimension of the rotation representation pretrained_checkpoint_path : str, path to pretrained checkpoint (default: None) """ super().__init__() self.patch_size = patch_size self.input_feature_dim = input_feature_dim self.num_resconv_block = num_resconv_block self.rot_representation_dim = rot_representation_dim self.pretrained_checkpoint_path = pretrained_checkpoint_path # Initialize the hidden dimension of the pose head based on the patch size self.output_dim = 4 * (self.patch_size**2) # Initialize the projection layer for the hidden dimension of the pose head self.proj = nn.Conv2d( in_channels=self.input_feature_dim, out_channels=self.output_dim, kernel_size=1, stride=1, padding=0, bias=True, ) # Initialize sequential layers of the pose head self.res_conv = nn.ModuleList( [copy.deepcopy(ResConvBlock(self.output_dim, self.output_dim)) for _ in range(self.num_resconv_block)] ) self.avgpool = nn.AdaptiveAvgPool2d(1) self.more_mlps = nn.Sequential( nn.Linear(self.output_dim, self.output_dim), nn.ReLU(), nn.Linear(self.output_dim, self.output_dim), nn.ReLU(), ) self.fc_t = nn.Linear(self.output_dim, 3) self.fc_rot = nn.Linear(self.output_dim, self.rot_representation_dim) # Load the pretrained checkpoint if provided if self.pretrained_checkpoint_path is not None: print(f"Loading pretrained pose head from {self.pretrained_checkpoint_path}") ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False) print(self.load_state_dict(ckpt["model"])) def forward(self, feature_input: PredictionHeadInput): """ Forward interface for the pose head. The pose head requires an adapter on the final output to get the pose. Args: feature_input : PredictionHeadInput, the input features - last_feature : torch.Tensor, the last feature tensor Returns: SummaryTaskOutput, the output of the pose head - decoded_channels : torch.Tensor, the decoded channels """ # Get the patch-level features from the input feat = feature_input.last_feature # (B, C, H, W) # Check the input dimensions assert ( feat.shape[1] == self.input_feature_dim ), f"Input feature dimension {feat.shape[1]} does not match expected dimension {self.input_feature_dim}" # Apply the projection layer to the patch-level features feat = self.proj(feat) # (B, PC, H, W) # Apply the residual convolution blocks to the projected features for i in range(self.num_resconv_block): feat = self.res_conv[i](feat) # Apply the average pooling layer to the residual convolution output feat = self.avgpool(feat) # (B, PC, 1, 1) # Flatten the average pooled features feat = feat.view(feat.size(0), -1) # (B, PC) # Apply the more MLPs to the flattened features feat = self.more_mlps(feat) # (B, PC) # Apply the final linear layers to the more MLPs output feat_t = self.fc_t(feat) # (B, 3) feat_rot = self.fc_rot(feat) # (B, self.rot_representation_dim) # Concatenate the translation and rotation features output_feat = torch.cat([feat_t, feat_rot], dim=1) # (B, 3 + self.rot_representation_dim return SummaryTaskOutput(decoded_channels=output_feat) if __name__ == "__main__": # Init an example pose head pose_head = PoseHead( patch_size=16, input_feature_dim=768, num_resconv_block=2, rot_representation_dim=4, pretrained_checkpoint_path=None, ) # Create a dummy input tensor with shape (B, C, H, W) dummy_input = torch.randn(1, 768, 14, 14) # Example input # Run dummy forward pass output = pose_head(PredictionHeadInput(last_feature=dummy_input)) # Check the output shape assert output.decoded_channels.shape == (1, 7), "Output shape mismatch" print("Pose head test passed!")