File size: 3,662 Bytes
c8b42eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""
Linear head with MLP implementation
Downstream heads assume inputs of size BCHW (B: batch, C: channels, H: height, W: width)
"""

from typing import Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from uniception.models.prediction_heads.base import PixelTaskOutput, PredictionHeadInput
from uniception.models.utils.transformer_blocks import Mlp


class MLPFeature(nn.Module):
    """
    Map low-resolution patch features to pixel-wise features.

    The head runs an intermediate MLP over the channel dimension of every
    patch, projects the result with a 1x1 convolution to
    ``output_dim * patch_size**2`` channels, and rearranges those channels
    into spatial positions with a pixel shuffle.
    """

    def __init__(
        self,
        input_feature_dim: Union[int, str],
        patch_size: int,
        output_dim: int,
        mlp_ratio: int = 4,
        act_layer=nn.GELU,
        bias=True,
        drop=0.0,
        pretrained_checkpoint_path: Union[str, None] = None,
        *args,
        **kwargs,
    ):
        """
        Initialize the MLP + linear feature mapping.

        Args:
            input_feature_dim : int or str, the input feature dimension. A string
                is evaluated as a Python expression to obtain the dimension.
            patch_size : int, the patch size (used as the pixel-shuffle upscale factor)
            output_dim : int, the output feature dimension
            mlp_ratio : int, the MLP hidden width is int(mlp_ratio * input_feature_dim) (default: 4)
            act_layer : activation layer class forwarded to the MLP (default: nn.GELU)
            bias : bool, forwarded to the MLP; the final 1x1 convolution always uses a bias (default: True)
            drop : float, forwarded to the MLP as its ``drop`` argument (default: 0.0)
            pretrained_checkpoint_path : str, path to pretrained checkpoint (default: None)
            *args, **kwargs : forwarded to ``nn.Module.__init__``
        """

        super().__init__(*args, **kwargs)

        if isinstance(input_feature_dim, str):
            # NOTE(review): eval executes arbitrary Python code. Only pass strings
            # that come from trusted configuration, never from untrusted input.
            input_feature_dim = eval(input_feature_dim)

        self.input_feature_dim = input_feature_dim
        self.output_dim = output_dim
        self.patch_size = patch_size
        self.pretrained_checkpoint_path = pretrained_checkpoint_path

        # Presumably Mlp keeps the feature width (out_features == in_features) - TODO confirm.
        # The 1x1 convolution below requires input_feature_dim channels.
        self.mlp = Mlp(
            in_features=self.input_feature_dim,
            hidden_features=int(mlp_ratio * self.input_feature_dim),
            act_layer=act_layer,
            drop=drop,
            bias=bias,
        )

        # 1x1 convolution == per-patch linear layer. It emits patch_size**2 values
        # per output channel so that pixel_shuffle can unfold them spatially.
        self.linear = nn.Conv2d(
            in_channels=self.input_feature_dim,
            out_channels=self.output_dim * (self.patch_size**2),
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )

        if self.pretrained_checkpoint_path is not None:
            print(f"Loading pretrained linear dense feature head from {self.pretrained_checkpoint_path}")
            # map_location="cpu" lets checkpoints saved from CUDA tensors load on hosts
            # without a GPU; load_state_dict copies the values into this module's
            # parameters, so the module's own device is unaffected.
            # NOTE(review): weights_only=False unpickles arbitrary objects. Only load
            # checkpoints from trusted sources.
            ckpt = torch.load(self.pretrained_checkpoint_path, map_location="cpu", weights_only=False)
            print(self.load_state_dict(ckpt["model"]))

    def forward(self, feature_input: PredictionHeadInput):
        """
        Forward interface for the MLP + linear feature mapping.

        Args:
            feature_input : PredictionHeadInput, the input features
            - last_feature : torch.Tensor, the last feature tensor in (B, C, H, W) layout
              with C == input_feature_dim

        Returns:
            PixelTaskOutput, the output of the feature mapping
            - decoded_channels : torch.Tensor, the decoded channels of shape
              (B, output_dim, H * patch_size, W * patch_size)

        Raises:
            AssertionError : if the channel dimension of last_feature differs from input_feature_dim
        """

        x = feature_input.last_feature

        assert (
            x.shape[1] == self.input_feature_dim
        ), f"Input feature dimension mismatch: {x.shape[1]} != {self.input_feature_dim}"

        # The MLP is applied on channels-last data: (B, C, H, W) -> (B, H, W, C)
        x = x.permute(0, 2, 3, 1).contiguous()
        x = self.mlp(x)
        # Back to channels-first for the convolution: (B, H, W, C) -> (B, C, H, W)
        x = x.permute(0, 3, 1, 2).contiguous()

        x = self.linear(x)
        # (B, output_dim * p^2, H, W) -> (B, output_dim, H * p, W * p)
        x = F.pixel_shuffle(x, self.patch_size)

        return PixelTaskOutput(decoded_channels=x)


if __name__ == "__main__":
    # Smoke test: build a sample head and push one random feature map through it.
    feature_dim = 768
    grid_size = 14

    head = MLPFeature(
        input_feature_dim=feature_dim,
        patch_size=16,
        output_dim=4,
        mlp_ratio=4,
        act_layer=nn.GELU,
    )

    # Random patch features laid out as (B, C, H, W)
    sample_features = torch.randn(1, feature_dim, grid_size, grid_size)

    # Wrap the tensor in the head's input container and run a single forward pass
    head_input = PredictionHeadInput(last_feature=sample_features)
    result = head(head_input)