# Authored by infinity1096 — initial commit (c8b42eb).
# NOTE: this line was raw VCS metadata at the top of the module (a bare name,
# which raised NameError on import); preserved here as a comment instead.
"""
Encoder Class for Cosmos
"""
import torch
from uniception.models.encoders.base import UniCeptionViTEncoderBase, ViTEncoderInput, ViTEncoderOutput
from uniception.models.libs.cosmos_tokenizer.modules import ContinuousFormulation, EncoderType
from uniception.models.libs.cosmos_tokenizer.networks import TokenizerConfigs
class CosmosEncoder(UniCeptionViTEncoderBase):
    """UniCeption wrapper around the Cosmos continuous-image (CI) tokenizer encoder."""

    def __init__(
        self,
        name: str,
        data_norm_type: str = "cosmos",
        patch_size: int = 8,
        pretrained_checkpoint_path: str = None,
        *args,
        **kwargs,
    ):
        """
        Cosmos Encoder for extracting spatial features from images.
        Args:
            name (str): Name of the encoder.
            data_norm_type (str): Image normalization type. Default: "cosmos"
            patch_size (int): Patch size for the encoder. Default: 8
            pretrained_checkpoint_path (str): Path to the pretrained checkpoint. Default: None
        Keyword Args:
            encoder (str): Key into EncoderType selecting the backbone. Default: EncoderType.Default.name
            formulation (str): Key into ContinuousFormulation. Default: ContinuousFormulation.AE.name
        """
        # Init the base class
        super().__init__(name=name, data_norm_type=data_norm_type, patch_size=patch_size, *args, **kwargs)

        # Init Cosmos Encoder specific attributes from the continuous-image ("CI") config,
        # with the spatial compression factor tied to the requested patch size.
        tokenizer_config = TokenizerConfigs["CI"].value.copy()
        tokenizer_config.update(dict(spatial_compression=self.patch_size))

        # Pop the three meta-parameters the network constructors take explicitly;
        # the remaining config entries are forwarded as plain kwargs.
        z_factor = tokenizer_config.pop("z_factor")
        z_channels = tokenizer_config.pop("z_channels")
        latent_channels = tokenizer_config.pop("latent_channels")
        encoder_name = kwargs.get("encoder", EncoderType.Default.name)
        self.encoder = EncoderType[encoder_name].value(z_channels=z_factor * z_channels, **tokenizer_config)
        # 1x1 conv projecting encoder output channels to the latent channel count.
        self.quant_conv = torch.nn.Conv2d(z_factor * z_channels, z_factor * latent_channels, 1)
        formulation_name = kwargs.get("formulation", ContinuousFormulation.AE.name)
        self.distribution = ContinuousFormulation[formulation_name].value()

        # Load the pretrained checkpoint
        if pretrained_checkpoint_path is not None:
            print(f"Loading custom pretrained Cosmos checkpoint from {pretrained_checkpoint_path}")
            # SECURITY: weights_only=False unpickles arbitrary objects — only load trusted checkpoints.
            ckpt = torch.load(pretrained_checkpoint_path, weights_only=False)
            print(self.load_state_dict(ckpt["model"]))

    def encode(self, input_tensor: torch.Tensor) -> "torch.Tensor | tuple[torch.Tensor, ...]":
        """Encodes an image into a latent embedding or code.
        Args:
            input_tensor: The input tensor Bx3xHxW layout, range [-1..1].
        Returns:
            For continuous image (CI) tokenizer, the tuple contains:
                - The latent embedding, Bx16x(h)x(w), where the compression
                rate is (H/h x W/w), and channel dimension of 16.
            For discrete image (DI) tokenizer, the tuple contains:
                - The indices, Bx(h)x(w), from a codebook of size 64K, which
                corresponds to FSQ levels of (8,8,8,5,5,5).
                - The discrete code, Bx6x(h)x(w), where the compression rate is
                again (H/h x W/w), and channel dimension of 6.
            If the formulation yields a bare tensor, that tensor is returned directly.
        """
        x = self.encoder(input_tensor)
        x = self.quant_conv(x)
        output_latent = self.distribution(x)
        if isinstance(output_latent, torch.Tensor):
            return output_latent
        # Drop the formulation's trailing auxiliary output (e.g. log-variance terms).
        return output_latent[:-1]

    def forward(self, encoder_input: ViTEncoderInput) -> ViTEncoderOutput:
        """
        Cosmos Encoder Forward Pass
        Args:
            encoder_input (ViTEncoderInput): Input data for the encoder. Input data must contain image normalization type and normalized image tensor.
        Returns:
            ViTEncoderOutput: Output data from the encoder.
        """
        # Check image normalization type
        self._check_data_normalization_type(encoder_input.data_norm_type)

        # Check the dtype and shape of the input image
        assert isinstance(encoder_input.image, torch.Tensor), "Input must be a torch.Tensor"
        assert encoder_input.image.ndim == 4, "Input must be of shape (B, C, H, W)"
        batch_size, channels, height, width = encoder_input.image.shape
        assert channels == 3, "Input must have 3 channels"
        assert (
            height % self.patch_size == 0 and width % self.patch_size == 0
        ), f"Input shape must be divisible by patch size: {self.patch_size}"

        # Extract the features from the Cosmos encoder.
        # NOTE(review): encode() may return a bare tensor (see encode); indexing [0] on a
        # bare tensor would select the first batch element rather than the latent — this
        # path assumes the default AE formulation returns a tuple. TODO confirm.
        features = self.encode(encoder_input.image)[0].contiguous()
        return ViTEncoderOutput(features=features)
if __name__ == "__main__":
    # Smoke-test: build untrained Cosmos encoders for each supported patch size.
    # (The previous single-iteration `for is_continuous in [True]:` wrapper was dead
    # code with an unused loop variable and has been removed.)
    for patch_size in [8, 16]:
        encoder = CosmosEncoder(name="cosmos", patch_size=patch_size)

    # Re-initialize from pretrained checkpoints (paths relative to this file's package).
    PRETRAINED_JIT_CHECKPOINTS = {
        ("CI", 8): "../../../checkpoints/encoders/cosmos/Cosmos-Tokenizer-CI8x8/encoder.pth",
        ("CI", 16): "../../../checkpoints/encoders/cosmos/Cosmos-Tokenizer-CI16x16/encoder.pth",
    }
    for patch_size in [8, 16]:
        encoder = CosmosEncoder(
            name="cosmos",
            patch_size=patch_size,
            pretrained_checkpoint_path=PRETRAINED_JIT_CHECKPOINTS[("CI", patch_size)],
        )

    # Example inference on the last-built encoder (requires a CUDA device).
    dummy_image = torch.randn(1, 3, 256, 256).cuda()
    encoder_input = ViTEncoderInput(data_norm_type="cosmos", image=dummy_image)
    encoder = encoder.cuda()
    encoder_output = encoder(encoder_input)