"""
Encoder Class for Cosmos
"""
import torch
from uniception.models.encoders.base import UniCeptionViTEncoderBase, ViTEncoderInput, ViTEncoderOutput
from uniception.models.libs.cosmos_tokenizer.modules import ContinuousFormulation, EncoderType
from uniception.models.libs.cosmos_tokenizer.networks import TokenizerConfigs
class CosmosEncoder(UniCeptionViTEncoderBase):
    "UniCeption Cosmos Encoder"

    def __init__(
        self,
        name: str,
        data_norm_type: str = "cosmos",
        patch_size: int = 8,
        pretrained_checkpoint_path: str = None,
        *args,
        **kwargs,
    ):
        """
        Cosmos Encoder for extracting spatial features from images.

        Args:
            name (str): Name of the encoder.
            data_norm_type (str): Image normalization type. Default: "cosmos"
            patch_size (int): Patch size for the encoder. Default: 8
            pretrained_checkpoint_path (str): Path to the pretrained checkpoint. Default: None
        """
        # Init the base class
        super().__init__(name=name, data_norm_type=data_norm_type, patch_size=patch_size, *args, **kwargs)

        # Init Cosmos Encoder specific attributes
        tokenizer_config = TokenizerConfigs["CI"].value.copy()
        tokenizer_config.update(dict(spatial_compression=self.patch_size))
        z_factor = tokenizer_config["z_factor"]
        z_channels = tokenizer_config["z_channels"]
        latent_channels = tokenizer_config["latent_channels"]
        encoder_name = kwargs.get("encoder", EncoderType.Default.name)
        print(tokenizer_config)
        del tokenizer_config["z_factor"]
        del tokenizer_config["z_channels"]
        del tokenizer_config["latent_channels"]
        self.encoder = EncoderType[encoder_name].value(z_channels=z_factor * z_channels, **tokenizer_config)
        self.quant_conv = torch.nn.Conv2d(z_factor * z_channels, z_factor * latent_channels, 1)
        formulation_name = kwargs.get("formulation", ContinuousFormulation.AE.name)
        self.distribution = ContinuousFormulation[formulation_name].value()

        # Load the pretrained checkpoint
        if pretrained_checkpoint_path is not None:
            print(f"Loading custom pretrained Cosmos checkpoint from {pretrained_checkpoint_path}")
            ckpt = torch.load(pretrained_checkpoint_path, weights_only=False)
            print(self.load_state_dict(ckpt["model"]))
    def encode(self, input_tensor: torch.Tensor) -> tuple[torch.Tensor]:
        """Encodes an image into a latent embedding or code.

        Args:
            input_tensor: The input tensor Bx3xHxW layout, range [-1..1].

        Returns:
            For continuous image (CI) tokenizer, the tuple contains:
                - The latent embedding, Bx16x(h)x(w), where the compression
                  rate is (H/h x W/w), and channel dimension of 16.
            For discrete image (DI) tokenizer, the tuple contains:
                - The indices, Bx(h)x(w), from a codebook of size 64K, which
                  corresponds to FSQ levels of (8,8,8,5,5,5).
                - The discrete code, Bx6x(h)x(w), where the compression rate is
                  again (H/h x W/w), and channel dimension of 6.
        """
        x = self.encoder(input_tensor)
        x = self.quant_conv(x)
        output_latent = self.distribution(x)
        if isinstance(output_latent, torch.Tensor):
            return output_latent
        return output_latent[:-1]
    def forward(self, encoder_input: ViTEncoderInput) -> ViTEncoderOutput:
        """
        Cosmos Encoder Forward Pass

        Args:
            encoder_input (ViTEncoderInput): Input data for the encoder. Input data must contain image normalization type and normalized image tensor.

        Returns:
            ViTEncoderOutput: Output data from the encoder.
        """
        # Check image normalization type
        self._check_data_normalization_type(encoder_input.data_norm_type)

        # Check the dtype and shape of the input image
        assert isinstance(encoder_input.image, torch.Tensor), "Input must be a torch.Tensor"
        assert encoder_input.image.ndim == 4, "Input must be of shape (B, C, H, W)"
        batch_size, channels, height, width = encoder_input.image.shape
        assert channels == 3, "Input must have 3 channels"
        assert (
            height % self.patch_size == 0 and width % self.patch_size == 0
        ), f"Input shape must be divisible by patch size: {self.patch_size}"

        # Extract the features from the Cosmos encoder
        features = self.encode(encoder_input.image)[0].contiguous()

        return ViTEncoderOutput(features=features)
if __name__ == "__main__":
    # Initialize different variants of the Cosmos Encoder, untrained
    for is_continuous in [True]:
        for patch_size in [8, 16]:
            encoder = CosmosEncoder(name="cosmos", patch_size=patch_size)

    # Initialize from trained checkpoint, with/without jit inference capability
    PRETRAINED_JIT_CHECKPOINTS = {
        ("CI", 8): "../../../checkpoints/encoders/cosmos/Cosmos-Tokenizer-CI8x8/encoder.pth",
        ("CI", 16): "../../../checkpoints/encoders/cosmos/Cosmos-Tokenizer-CI16x16/encoder.pth",
    }
    for patch_size in [8, 16]:
        encoder = CosmosEncoder(
            name="cosmos",
            patch_size=patch_size,
            pretrained_checkpoint_path=PRETRAINED_JIT_CHECKPOINTS[("CI", patch_size)],
        )
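    # A hedged sketch, not in the original file: __init__ also reads optional
    # "encoder" and "formulation" kwargs, resolved via the EncoderType and
    # ContinuousFormulation enums; the values below simply spell out the defaults.
    # This assumes the base class tolerates the extra kwargs forwarded to it.
    encoder_explicit = CosmosEncoder(
        name="cosmos",
        patch_size=8,
        encoder=EncoderType.Default.name,
        formulation=ContinuousFormulation.AE.name,
    )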
    # Example inference
    dummy_image = torch.randn(1, 3, 256, 256).cuda()
    encoder_input = ViTEncoderInput(data_norm_type="cosmos", image=dummy_image)
    encoder = encoder.cuda()
    encoder_output = encoder(encoder_input)
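    # A minimal sanity check, not in the original file: per the encode() docstring,
    # the CI tokenizer yields a Bx16x(h)x(w) latent with h = H / patch_size and
    # w = W / patch_size, so this 256x256 dummy input should give features of
    # roughly (1, 16, 256 // patch_size, 256 // patch_size). Printed rather than
    # asserted, since the exact channel count comes from the tokenizer config.
    print(
        f"Cosmos encoder features: {tuple(encoder_output.features.shape)} "
        f"(expected approximately (1, 16, {256 // patch_size}, {256 // patch_size}))"
    )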