File size: 3,064 Bytes
c8b42eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
"""
Linear head implementation
Downstream heads assume inputs of shape BCHW (B: batch, C: channels, H: height, W: width).
The linear head implementation is based on DUSt3R and CroCo v2.
References: https://github.com/naver/dust3r
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from uniception.models.prediction_heads.base import PixelTaskOutput, PredictionHeadInput


class LinearFeature(nn.Module):
    """
    This class implements a linear mapping from the low resolution patch features
    to pixel-wise features.

    A 1x1 convolution projects every patch feature vector to
    ``output_dim * patch_size**2`` channels. ``F.pixel_shuffle`` then rearranges those
    channels into a ``patch_size x patch_size`` block of pixels. A (B, input_feature_dim, H, W)
    input therefore becomes a (B, output_dim, H * patch_size, W * patch_size) output.
    """

    def __init__(
        self,
        input_feature_dim: int,
        output_dim: int,
        patch_size: int,
        pretrained_checkpoint_path: str = None,
        *args,
        **kwargs,
    ):
        """
        Initialize the linear feature mapping.

        Args:
            input_feature_dim : int, the input feature dimension (channels of the patch features)
            output_dim : int, the output feature dimension (channels per output pixel)
            patch_size : int, the patch size, i.e. the spatial upsampling factor applied
                by the pixel shuffle
            pretrained_checkpoint_path : str, path to pretrained checkpoint (default: None).
                When given, the checkpoint's "model" entry is loaded as this module's
                state dict (strict loading).
        """

        super().__init__(*args, **kwargs)

        self.input_feature_dim = input_feature_dim
        self.output_dim = output_dim
        self.patch_size = patch_size
        self.pretrained_checkpoint_path = pretrained_checkpoint_path

        # A 1x1 conv is a per-location linear layer. It emits patch_size**2 sub-pixel values
        # for each output channel, which pixel_shuffle later lays out spatially.
        self.linear = nn.Conv2d(
            in_channels=self.input_feature_dim,
            out_channels=self.output_dim * (self.patch_size**2),
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )

        if self.pretrained_checkpoint_path is not None:
            print(f"Loading pretrained linear dense feature head from {self.pretrained_checkpoint_path}")
            # map_location="cpu" lets checkpoints saved from GPU tensors load on CPU-only
            # machines; load_state_dict copies the values into this module's parameters anyway.
            # NOTE(review): weights_only=False unpickles arbitrary objects, so only load
            # trusted checkpoints. It is kept because checkpoints may carry non-tensor metadata.
            ckpt = torch.load(self.pretrained_checkpoint_path, map_location="cpu", weights_only=False)
            print(self.load_state_dict(ckpt["model"]))

    def forward(self, feature_input: PredictionHeadInput):
        """
        Forward interface for the linear feature mapping.

        Args:
            feature_input : PredictionHeadInput, the input features
            - last_feature : torch.Tensor, the last feature tensor of shape (B, C, H, W)
              with C == input_feature_dim

        Returns:
            PixelTaskOutput, the output of the linear feature mapping
            - decoded_channels : torch.Tensor of shape
              (B, output_dim, H * patch_size, W * patch_size)

        Raises:
            ValueError: if the input is not 4-D or its channel dimension does not
                match ``input_feature_dim``.
        """

        x = feature_input.last_feature

        # Explicit checks instead of `assert`, which is stripped under `python -O`.
        # The rank check matters because Conv2d also accepts unbatched (C, H, W) input,
        # where shape[1] would be H rather than the channel count.
        if x.ndim != 4:
            raise ValueError(f"Expected a 4D (B, C, H, W) input, got shape {tuple(x.shape)}")
        if x.shape[1] != self.input_feature_dim:
            raise ValueError(f"Input feature dimension mismatch: {x.shape[1]} != {self.input_feature_dim}")

        x = self.linear(x)
        # (B, output_dim * p**2, H, W) -> (B, output_dim, H * p, W * p)
        x = F.pixel_shuffle(x, self.patch_size)

        return PixelTaskOutput(decoded_channels=x)


if __name__ == "__main__":
    # Init an example linear feature head
    linear_prediction_head = LinearFeature(input_feature_dim=768, output_dim=4, patch_size=16)

    # Create a dummy input tensor with shape (B, C, H, W): a 14 x 14 grid of 768-dim patch features
    dummy_input = torch.randn(1, 768, 14, 14)

    # Run dummy forward pass; no autograd graph is needed for this smoke test
    with torch.no_grad():
        output = linear_prediction_head(PredictionHeadInput(last_feature=dummy_input))

    # Each patch location is expanded into a 16 x 16 block of pixels with 4 channels,
    # so the result must be (1, 4, 14 * 16, 14 * 16).
    expected_shape = (1, 4, 14 * 16, 14 * 16)
    actual_shape = tuple(output.decoded_channels.shape)
    assert actual_shape == expected_shape, f"Unexpected output shape: {actual_shape} != {expected_shape}"
    print(f"Output shape: {actual_shape}")