Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| Encoder Class for Cosmos | |
| """ | |
| import torch | |
| from uniception.models.encoders.base import UniCeptionViTEncoderBase, ViTEncoderInput, ViTEncoderOutput | |
| from uniception.models.libs.cosmos_tokenizer.modules import ContinuousFormulation, EncoderType | |
| from uniception.models.libs.cosmos_tokenizer.networks import TokenizerConfigs | |
class CosmosEncoder(UniCeptionViTEncoderBase):
    """UniCeption encoder wrapping the encoder half of a Cosmos continuous-image (CI) tokenizer."""

    def __init__(
        self,
        name: str,
        data_norm_type: str = "cosmos",
        patch_size: int = 8,
        pretrained_checkpoint_path: str = None,
        *args,
        **kwargs,
    ):
        """
        Cosmos Encoder for extracting spatial features from images.

        Args:
            name (str): Name of the encoder.
            data_norm_type (str): Image normalization type. Default: "cosmos"
            patch_size (int): Patch size for the encoder; also used as the tokenizer's
                spatial compression factor. Default: 8
            pretrained_checkpoint_path (str): Path to the pretrained checkpoint. Default: None
            **kwargs: Optional "encoder" (name of an EncoderType member, default
                EncoderType.Default) and "formulation" (name of a ContinuousFormulation
                member, default ContinuousFormulation.AE). All kwargs, including these two,
                are also forwarded to the base class constructor.
                NOTE(review): the base class must tolerate those extra keys - confirm.
        """
        # Init the base class
        super().__init__(name=name, data_norm_type=data_norm_type, patch_size=patch_size, *args, **kwargs)

        # Init Cosmos Encoder specific attributes.
        # Start from the continuous-image ("CI") tokenizer config, with the spatial
        # compression overridden by the patch size. The copy keeps the shared config intact.
        tokenizer_config = TokenizerConfigs["CI"].value.copy()
        tokenizer_config["spatial_compression"] = self.patch_size

        # These keys size the latent projection. They are removed before the remaining config
        # is splatted into the encoder constructor (z_channels is passed explicitly below).
        z_factor = tokenizer_config.pop("z_factor")
        z_channels = tokenizer_config.pop("z_channels")
        latent_channels = tokenizer_config.pop("latent_channels")

        encoder_name = kwargs.get("encoder", EncoderType.Default.name)
        self.encoder = EncoderType[encoder_name].value(z_channels=z_factor * z_channels, **tokenizer_config)
        # 1x1 convolution projecting encoder output channels to latent channels
        self.quant_conv = torch.nn.Conv2d(z_factor * z_channels, z_factor * latent_channels, 1)

        formulation_name = kwargs.get("formulation", ContinuousFormulation.AE.name)
        self.distribution = ContinuousFormulation[formulation_name].value()

        # Load the pretrained checkpoint
        if pretrained_checkpoint_path is not None:
            print(f"Loading custom pretrained Cosmos checkpoint from {pretrained_checkpoint_path}")
            # map_location="cpu" lets checkpoints saved from GPU load on CPU-only hosts;
            # load_state_dict copies the tensors into this module's parameters.
            # NOTE(review): weights_only=False unpickles arbitrary Python objects - only
            # load checkpoints from trusted sources.
            ckpt = torch.load(pretrained_checkpoint_path, map_location="cpu", weights_only=False)
            # Print the load result so missing/unexpected keys are visible.
            print(self.load_state_dict(ckpt["model"]))

    def encode(self, input_tensor: torch.Tensor) -> "torch.Tensor | tuple[torch.Tensor, ...]":
        """Encode an image batch into a continuous latent embedding.

        Args:
            input_tensor: Image tensor in Bx3xHxW layout, expected range [-1..1].

        Returns:
            If the configured distribution returns a bare tensor, that tensor is returned
            as-is. Otherwise the distribution's output is returned with its last element
            dropped; its first element is the latent embedding. The embedding is spatially
            downsampled by ``patch_size`` (the tokenizer's spatial compression). Its channel
            count is set by ``quant_conv`` (z_factor * latent_channels) and by the chosen
            formulation.
        """
        x = self.encoder(input_tensor)
        x = self.quant_conv(x)
        output_latent = self.distribution(x)
        if isinstance(output_latent, torch.Tensor):
            return output_latent
        # Drop the trailing element (presumably auxiliary distribution outputs - TODO confirm)
        return output_latent[:-1]

    def forward(self, encoder_input: ViTEncoderInput) -> ViTEncoderOutput:
        """
        Cosmos Encoder Forward Pass

        Args:
            encoder_input (ViTEncoderInput): Input data for the encoder. Input data must contain
                image normalization type and normalized image tensor.

        Returns:
            ViTEncoderOutput: Output data from the encoder; ``features`` is the contiguous
                latent embedding with the batch dimension preserved.
        """
        # Check image normalization type
        self._check_data_normalization_type(encoder_input.data_norm_type)

        # Check the dtype and shape of the input image
        image = encoder_input.image
        assert isinstance(image, torch.Tensor), "Input must be a torch.Tensor"
        assert image.ndim == 4, "Input must be of shape (B, C, H, W)"
        _, channels, height, width = image.shape
        assert channels == 3, "Input must have 3 channels"
        assert (
            height % self.patch_size == 0 and width % self.patch_size == 0
        ), f"Input shape must be divisible by patch size: {self.patch_size}"

        # Extract the features from the Cosmos encoder
        latent = self.encode(image)
        # encode() may return a bare tensor. Indexing [0] on it would drop the batch
        # dimension, so only unwrap when a tuple/list is returned.
        if not isinstance(latent, torch.Tensor):
            latent = latent[0]

        return ViTEncoderOutput(features=latent.contiguous())
if __name__ == "__main__":
    # Initialize the Cosmos Encoder variants with untrained weights
    for patch_size in [8, 16]:
        encoder = CosmosEncoder(name="cosmos", patch_size=patch_size)

    # Initialize from the pretrained Cosmos-Tokenizer-CI checkpoints.
    # Relative paths are resolved against the current working directory.
    PRETRAINED_JIT_CHECKPOINTS = {
        ("CI", 8): "../../../checkpoints/encoders/cosmos/Cosmos-Tokenizer-CI8x8/encoder.pth",
        ("CI", 16): "../../../checkpoints/encoders/cosmos/Cosmos-Tokenizer-CI16x16/encoder.pth",
    }
    for patch_size in [8, 16]:
        encoder = CosmosEncoder(
            name="cosmos",
            patch_size=patch_size,
            pretrained_checkpoint_path=PRETRAINED_JIT_CHECKPOINTS[("CI", patch_size)],
        )

    # Example inference with the last encoder built above (patch_size=16).
    # Fall back to CPU when CUDA is unavailable instead of crashing.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    encoder = encoder.to(device).eval()
    dummy_image = torch.randn(1, 3, 256, 256, device=device)
    encoder_input = ViTEncoderInput(data_norm_type="cosmos", image=dummy_image)
    # Inference only: skip autograd bookkeeping
    with torch.no_grad():
        encoder_output = encoder(encoder_input)
    print(f"Output features shape: {tuple(encoder_output.features.shape)}")