File size: 4,564 Bytes
c8b42eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
"""
Encoder class for Global Representation Encoder
"""

from functools import partial
from typing import Callable, List, Optional, Type, Union

import torch
import torch.nn as nn

from uniception.models.encoders.base import EncoderGlobalRepInput, EncoderGlobalRepOutput


class GlobalRepresentationEncoder(nn.Module):
    "UniCeption Global Representation Encoder"

    def __init__(
        self,
        name: str,
        in_chans: int = 3,
        enc_embed_dim: int = 1024,
        intermediate_dims: Optional[List[int]] = None,
        act_layer: Type[nn.Module] = nn.GELU,
        norm_layer: Union[Type[nn.Module], Callable[..., nn.Module]] = partial(nn.LayerNorm, eps=1e-6),
        pretrained_checkpoint_path: Optional[str] = None,
        *args,
        **kwargs,
    ):
        """
        Global Representation Encoder for projecting a global representation to a desired latent dimension.

        Args:
            name (str): Name of the Encoder.
            in_chans (int): Number of input channels.
            enc_embed_dim (int): Embedding dimension of the encoder.
            intermediate_dims (Optional[List[int]]): List of intermediate dimensions of the encoder.
                Defaults to [128, 256, 512] when None. Must be non-empty.
            act_layer (Type[nn.Module]): Activation layer to use in the encoder.
            norm_layer (Union[Type[nn.Module], Callable[..., nn.Module]]): Final normalization layer to use in the encoder.
            pretrained_checkpoint_path (Optional[str]): Path to pretrained checkpoint. (default: None)

        Raises:
            ValueError: If intermediate_dims is an empty list.
        """
        super().__init__(*args, **kwargs)

        # Avoid the shared-mutable-default-argument pitfall: materialize the
        # historical default [128, 256, 512] per instance instead.
        if intermediate_dims is None:
            intermediate_dims = [128, 256, 512]
        if not intermediate_dims:
            raise ValueError("intermediate_dims must contain at least one dimension")

        # Initialize the attributes
        self.name = name
        self.in_chans = in_chans
        self.enc_embed_dim = enc_embed_dim
        self.intermediate_dims = intermediate_dims
        self.pretrained_checkpoint_path = pretrained_checkpoint_path

        # Single shared activation instance; safe to reuse since the
        # activation holds no learnable state.
        self.act_layer = act_layer()

        # Build the projection MLP: in_chans -> intermediate_dims... -> enc_embed_dim.
        # NOTE: the nn.Sequential nesting is intentional — it reproduces the
        # original state_dict key layout so existing checkpoints keep loading.
        self.encoder = nn.Sequential(
            nn.Linear(self.in_chans, self.intermediate_dims[0]),
            self.act_layer,
        )
        for intermediate_idx in range(1, len(self.intermediate_dims)):
            self.encoder = nn.Sequential(
                self.encoder,
                nn.Linear(self.intermediate_dims[intermediate_idx - 1], self.intermediate_dims[intermediate_idx]),
                self.act_layer,
            )
        self.encoder = nn.Sequential(
            self.encoder,
            nn.Linear(self.intermediate_dims[-1], self.enc_embed_dim),
        )

        # Final normalization over the embedding; explicit identity when disabled.
        self.norm_layer = norm_layer(enc_embed_dim) if norm_layer else nn.Identity()
        if isinstance(self.norm_layer, nn.LayerNorm):
            nn.init.constant_(self.norm_layer.bias, 0)
            nn.init.constant_(self.norm_layer.weight, 1.0)

        # Load pretrained weights if provided
        if self.pretrained_checkpoint_path is not None:
            print(
                f"Loading pretrained Global Representation Encoder checkpoint from {self.pretrained_checkpoint_path} ..."
            )
            # SECURITY NOTE: weights_only=False unpickles arbitrary objects;
            # only load checkpoints from trusted sources.
            ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False)
            print(self.load_state_dict(ckpt["model"]))

    def forward(self, encoder_input: EncoderGlobalRepInput) -> EncoderGlobalRepOutput:
        """
        Global Representation Encoder Forward Pass

        Args:
            encoder_input (EncoderGlobalRepInput): Input data for the encoder.
                The provided data must contain a tensor of size (B, C).

        Returns:
            EncoderGlobalRepOutput: Output features from the encoder,
                a tensor of size (B, enc_embed_dim).
        """
        # Get the input data and verify the shape of the input
        input_data = encoder_input.data
        assert input_data.ndim == 2, "Input data must have shape (B, C)"
        assert input_data.shape[1] == self.in_chans, f"Input data must have {self.in_chans} channels"

        # Encode the global representation
        features = self.encoder(input_data)

        # Normalize the output
        features = self.norm_layer(features)

        return EncoderGlobalRepOutput(features=features)


if __name__ == "__main__":
    # Smoke test: project a batch of 4 three-channel global representations
    # into the default 1024-dim latent space and verify the output shape.
    encoder = GlobalRepresentationEncoder(
        name="dummy", in_chans=3, enc_embed_dim=1024, intermediate_dims=[128, 256, 512]
    )
    batch = EncoderGlobalRepInput(data=torch.randn(4, 3))
    output = encoder(batch)
    assert output.features.shape == (4, 1024), "Output features must have shape (B, 1024)"
    print("Global Representation Encoder has been initialized successfully!")