"""
Encoder Class for Cosmos
"""

from typing import Optional

import torch

from uniception.models.encoders.base import UniCeptionViTEncoderBase, ViTEncoderInput, ViTEncoderOutput
from uniception.models.libs.cosmos_tokenizer.modules import ContinuousFormulation, EncoderType
from uniception.models.libs.cosmos_tokenizer.networks import TokenizerConfigs


class CosmosEncoder(UniCeptionViTEncoderBase):
    "Uniception Cosmos Encoder"

    def __init__(
        self,
        name: str,
        data_norm_type: str = "cosmos",
        patch_size: int = 8,
        pretrained_checkpoint_path: Optional[str] = None,
        *args,
        **kwargs,
    ):
        """
        Cosmos Encoder for extracting spatial features from images.

        Args:
            name (str): Name of the encoder.
            data_norm_type (str): Image normalization type. Default: "cosmos"
            patch_size (int): Patch size for the encoder. Default: 8
            pretrained_checkpoint_path (Optional[str]): Path to the pretrained checkpoint. Default: None
        """
        # Init the base class
        super().__init__(name=name, data_norm_type=data_norm_type, patch_size=patch_size, *args, **kwargs)

        # Init Cosmos Encoder specific attributes
        tokenizer_config = TokenizerConfigs["CI"].value.copy()
        tokenizer_config.update(dict(spatial_compression=self.patch_size))

        # Pop the latent-space settings; the remaining config entries are
        # forwarded as-is to the encoder network constructor
        z_factor = tokenizer_config.pop("z_factor")
        z_channels = tokenizer_config.pop("z_channels")
        latent_channels = tokenizer_config.pop("latent_channels")
        encoder_name = kwargs.get("encoder", EncoderType.Default.name)
        self.encoder = EncoderType[encoder_name].value(z_channels=z_factor * z_channels, **tokenizer_config)
        self.quant_conv = torch.nn.Conv2d(z_factor * z_channels, z_factor * latent_channels, 1)
        formulation_name = kwargs.get("formulation", ContinuousFormulation.AE.name)
        self.distribution = ContinuousFormulation[formulation_name].value()

        # Load the pretrained checkpoint
        if pretrained_checkpoint_path is not None:
            print(f"Loading custom pretrained Cosmos checkpoint from {pretrained_checkpoint_path}")
            ckpt = torch.load(pretrained_checkpoint_path, weights_only=False)
            print(self.load_state_dict(ckpt["model"]))

    def encode(self, input_tensor: torch.Tensor) -> tuple[torch.Tensor]:
        """Encodes an image into a latent embedding or code.

        Args:
            input_tensor: The input tensor Bx3xHxW layout, range [-1..1].
        Returns:
            For continuous image (CI) tokenizer, the tuple contains:
                - The latent embedding, Bx16x(h)x(w), where the compression
                rate is (H/h x W/w), and channel dimension of 16.
            For discrete image (DI) tokenizer, the tuple contains:
                - The indices, Bx(h)x(w), from a codebook of size 64K, which
                corresponds to FSQ levels of (8,8,8,5,5,5).
                - The discrete code, Bx6x(h)x(w), where the compression rate is
                again (H/h x W/w), and channel dimension of 6.
        """
        # Encode to the bottleneck, then project to the latent channel count
        x = self.encoder(input_tensor)
        x = self.quant_conv(x)
        output_latent = self.distribution(x)

        # The formulation may return a (latents, auxiliary) tuple; drop the
        # trailing auxiliary term and keep only the latent outputs
        if isinstance(output_latent, torch.Tensor):
            return output_latent
        return output_latent[:-1]

    def forward(self, encoder_input: ViTEncoderInput) -> ViTEncoderOutput:
        """
        Cosmos Encoder Forward Pass

        Args:
            encoder_input (ViTEncoderInput): Input data for the encoder. Input data must contain image normalization type and normalized image tensor.

        Returns:
            ViTEncoderOutput: Output data from the encoder.
        """
        # Check image normalization type
        self._check_data_normalization_type(encoder_input.data_norm_type)

        # Check the dtype and shape of the input image
        assert isinstance(encoder_input.image, torch.Tensor), "Input must be a torch.Tensor"
        assert encoder_input.image.ndim == 4, "Input must be of shape (B, C, H, W)"
        batch_size, channels, height, width = encoder_input.image.shape
        assert channels == 3, "Input must have 3 channels"
        assert (
            height % self.patch_size == 0 and width % self.patch_size == 0
        ), f"Input shape must be divisible by patch size: {self.patch_size}"

        # Extract the features using the Cosmos encoder
        features = self.encode(encoder_input.image)[0].contiguous()

        return ViTEncoderOutput(features=features)


if __name__ == "__main__":

    # initialize different variants of the Cosmos Encoder, untrained
    for patch_size in [8, 16]:
        encoder = CosmosEncoder(name="cosmos", patch_size=patch_size)

    # initialize from trained checkpoint, with/without jit inference capability
    PRETRAINED_JIT_CHECKPOINTS = {
        ("CI", 8): "../../../checkpoints/encoders/cosmos/Cosmos-Tokenizer-CI8x8/encoder.pth",
        ("CI", 16): "../../../checkpoints/encoders/cosmos/Cosmos-Tokenizer-CI16x16/encoder.pth",
    }

    for patch_size in [8, 16]:

        encoder = CosmosEncoder(
            name="cosmos",
            patch_size=patch_size,
            pretrained_checkpoint_path=PRETRAINED_JIT_CHECKPOINTS[("CI", patch_size)],
        )

    # example inference (real images should be normalized to [-1, 1], per the encode() docstring)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dummy_image = torch.randn(1, 3, 256, 256).to(device)

    encoder_input = ViTEncoderInput(data_norm_type="cosmos", image=dummy_image)

    encoder = encoder.to(device)
    encoder_output = encoder(encoder_input)
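
    # A minimal sanity check, assuming the CI tokenizer's 16-channel latent
    # space and a spatial compression factor equal to patch_size (see the
    # encode() docstring). For a 256x256 input and the last patch_size from
    # the loop above (16), this expects a (1, 16, 16, 16) feature map; other
    # encoder or formulation configurations may differ.
    expected_shape = (1, 16, 256 // patch_size, 256 // patch_size)
    print(f"Output feature shape: {tuple(encoder_output.features.shape)} (expected {expected_shape})")
    assert tuple(encoder_output.features.shape) == expected_shape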