File size: 6,446 Bytes
c8b42eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
"""
Pose head implementation
Downstream heads assume inputs of size BCHW (B: batch, C: channels, H: height, W: width);
The Pose head implementation is based on Reloc3r and MaRePo
References:
https://github.com/ffrivera0/reloc3r/blob/main/reloc3r/pose_head.py
"""

import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

from uniception.models.prediction_heads.base import PredictionHeadInput, SummaryTaskOutput


class ResConvBlock(nn.Module):
    """
    1x1 convolution residual block implementation based on Reloc3r & MaRePo
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *args,
        **kwargs,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.head_skip = (
            nn.Identity()
            if self.in_channels == self.out_channels
            else nn.Conv2d(self.in_channels, self.out_channels, 1, 1, 0)
        )
        self.res_conv1 = nn.Conv2d(self.in_channels, self.out_channels, 1, 1, 0)
        self.res_conv2 = nn.Conv2d(self.out_channels, self.out_channels, 1, 1, 0)
        self.res_conv3 = nn.Conv2d(self.out_channels, self.out_channels, 1, 1, 0)

    def forward(self, res):
        x = F.relu(self.res_conv1(res))
        x = F.relu(self.res_conv2(x))
        x = F.relu(self.res_conv3(x))
        res = self.head_skip(res) + x
        return res


class PoseHead(nn.Module):
    """
    Pose regression head implementation based on Reloc3r & MaRePo
    """

    def __init__(
        self,
        patch_size: int,
        input_feature_dim: int,
        num_resconv_block: int = 2,
        rot_representation_dim: int = 4,
        pretrained_checkpoint_path: str = None,
        *args,
        **kwargs,
    ):
        """
        Initialize the pose head.

        Args:
            patch_size : int, the patch size of the transformer used to generate the input features
            input_feature_dim : int, the input feature dimension
            num_resconv_block : int, the number of residual convolution blocks
            rot_representation_dim : int, the dimension of the rotation representation
            pretrained_checkpoint_path : str, path to pretrained checkpoint (default: None)
        """
        super().__init__()
        self.patch_size = patch_size
        self.input_feature_dim = input_feature_dim
        self.num_resconv_block = num_resconv_block
        self.rot_representation_dim = rot_representation_dim
        self.pretrained_checkpoint_path = pretrained_checkpoint_path

        # Initialize the hidden dimension of the pose head based on the patch size
        self.output_dim = 4 * (self.patch_size**2)

        # Initialize the projection layer for the hidden dimension of the pose head
        self.proj = nn.Conv2d(
            in_channels=self.input_feature_dim,
            out_channels=self.output_dim,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )

        # Initialize sequential layers of the pose head
        self.res_conv = nn.ModuleList(
            [copy.deepcopy(ResConvBlock(self.output_dim, self.output_dim)) for _ in range(self.num_resconv_block)]
        )
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.more_mlps = nn.Sequential(
            nn.Linear(self.output_dim, self.output_dim),
            nn.ReLU(),
            nn.Linear(self.output_dim, self.output_dim),
            nn.ReLU(),
        )
        self.fc_t = nn.Linear(self.output_dim, 3)
        self.fc_rot = nn.Linear(self.output_dim, self.rot_representation_dim)

        # Load the pretrained checkpoint if provided
        if self.pretrained_checkpoint_path is not None:
            print(f"Loading pretrained pose head from {self.pretrained_checkpoint_path}")
            ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False)
            print(self.load_state_dict(ckpt["model"]))

    def forward(self, feature_input: PredictionHeadInput):
        """
        Forward interface for the pose head.
        The pose head requires an adapter on the final output to get the pose.

        Args:
            feature_input : PredictionHeadInput, the input features
            - last_feature : torch.Tensor, the last feature tensor

        Returns:
            SummaryTaskOutput, the output of the pose head
            - decoded_channels : torch.Tensor, the decoded channels
        """
        # Get the patch-level features from the input
        feat = feature_input.last_feature  # (B, C, H, W)

        # Check the input dimensions
        assert (
            feat.shape[1] == self.input_feature_dim
        ), f"Input feature dimension {feat.shape[1]} does not match expected dimension {self.input_feature_dim}"

        # Apply the projection layer to the patch-level features
        feat = self.proj(feat)  # (B, PC, H, W)

        # Apply the residual convolution blocks to the projected features
        for i in range(self.num_resconv_block):
            feat = self.res_conv[i](feat)

        # Apply the average pooling layer to the residual convolution output
        feat = self.avgpool(feat)  # (B, PC, 1, 1)

        # Flatten the average pooled features
        feat = feat.view(feat.size(0), -1)  # (B, PC)

        # Apply the more MLPs to the flattened features
        feat = self.more_mlps(feat)  # (B, PC)

        # Apply the final linear layers to the more MLPs output
        feat_t = self.fc_t(feat)  # (B, 3)
        feat_rot = self.fc_rot(feat)  # (B, self.rot_representation_dim)

        # Concatenate the translation and rotation features
        output_feat = torch.cat([feat_t, feat_rot], dim=1)  # (B, 3 + self.rot_representation_dim

        return SummaryTaskOutput(decoded_channels=output_feat)


if __name__ == "__main__":
    # Init an example pose head
    pose_head = PoseHead(
        patch_size=16,
        input_feature_dim=768,
        num_resconv_block=2,
        rot_representation_dim=4,
        pretrained_checkpoint_path=None,
    )

    # Create a dummy input tensor with shape (B, C, H, W)
    dummy_input = torch.randn(1, 768, 14, 14)  # Example input

    # Run dummy forward pass
    output = pose_head(PredictionHeadInput(last_feature=dummy_input))

    # Check the output shape
    assert output.decoded_channels.shape == (1, 7), "Output shape mismatch"

    print("Pose head test passed!")