Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| Encoder Class for CroCo & DUSt3R | |
| """ | |
| from functools import partial | |
| from typing import Callable, List, Optional, Tuple, Union | |
| import torch | |
| import torch.nn as nn | |
| from uniception.models.encoders.base import UniCeptionViTEncoderBase, ViTEncoderInput, ViTEncoderOutput | |
| from uniception.models.libs.croco.blocks import Block | |
| from uniception.models.libs.croco.patch_embed import get_patch_embed | |
| from uniception.models.libs.croco.pos_embed import RoPE2D | |
| from uniception.models.utils.intermediate_feature_return import IntermediateFeatureReturner, feature_take_indices | |
class CroCoEncoder(UniCeptionViTEncoderBase):
    "UniCeption CroCov2 Encoder"

    def __init__(
        self,
        name: str,
        data_norm_type: str,
        patch_embed_cls: str = "PatchEmbedDust3R",
        img_size: Union[int, Tuple[int, int]] = (224, 224),
        patch_size: int = 16,
        enc_embed_dim: int = 1024,
        enc_depth: int = 24,
        enc_num_heads: int = 16,
        mlp_ratio: int = 4,
        norm_layer: Callable = partial(nn.LayerNorm, eps=1e-6),
        pos_embed: str = "RoPE100",
        pretrained_checkpoint_path: Optional[str] = None,
        override_checkpoint_attributes: bool = False,
        *args,
        **kwargs,
    ):
        """
        References: https://github.com/naver/dust3r, https://github.com/naver/croco

        Args:
            name (str): Name of the encoder.
            data_norm_type (str): Input data normalization type.
            patch_embed_cls (str, optional): The class to use for patch embedding.
                Defaults to 'PatchEmbedDust3R'. Options: ['PatchEmbedCroCo', 'PatchEmbedDust3R', 'ManyAR_PatchEmbed'].
            img_size (int or Tuple[int, int], optional): The size of the input image. Defaults to (224, 224).
            patch_size (int, optional): The size of the patches to divide the image into. Defaults to 16.
            enc_embed_dim (int, optional): The dimension of the encoder's embedding. Defaults to 1024.
            enc_depth (int, optional): The number of encoder layers/transformer blocks. Defaults to 24.
            enc_num_heads (int, optional): The number of encoder attention heads. Defaults to 16.
            mlp_ratio (int, optional): The MLP ratio used for the CroCo encoder transformer. Defaults to 4.
            norm_layer (Callable, optional): Factory for the normalization layer used in the transformer.
                Defaults to nn.LayerNorm with eps=1e-6.
            pos_embed (str, optional): Positional embedding, given as 'RoPE<freq>' (e.g. 'RoPE100' for a
                RoPE2D with freq=100). Defaults to 'RoPE100'. Only RoPE variants are supported.
            pretrained_checkpoint_path (str, optional): Path to the pretrained checkpoint. Defaults to None
                (random initialization).
            override_checkpoint_attributes (bool, optional): When False, the checkpoint's stored
                'data_norm_type' and 'patch_embed_cls' must match the ones passed here. Defaults to False.

        Raises:
            ImportError: If RoPE2D is unavailable.
            ValueError: If the frequency suffix of a 'RoPE…' pos_embed is not a number.
            NotImplementedError: If pos_embed is not a RoPE variant.
        """
        # Init the base class
        super().__init__(
            name=name,
            data_norm_type=data_norm_type,
            patch_size=patch_size,
            *args,
            **kwargs,
        )

        # Init the CroCo Encoder specific attributes
        self.patch_embed_cls = patch_embed_cls
        self.img_size = img_size
        self.enc_embed_dim = enc_embed_dim
        self.enc_depth = enc_depth
        self.enc_num_heads = enc_num_heads
        self.mlp_ratio = mlp_ratio
        self.norm_layer = norm_layer
        self.pretrained_checkpoint_path = pretrained_checkpoint_path
        self.override_checkpoint_attributes = override_checkpoint_attributes

        # Init the positional embedding
        self.pos_embed = pos_embed
        if pos_embed.startswith("RoPE"):  # eg RoPE100
            self.enc_pos_embed = None  # nothing to add in the encoder with RoPE
            self.dec_pos_embed = None  # nothing to add in the decoder with RoPE
            if RoPE2D is None:
                raise ImportError("Cannot find cuRoPE2D, please install it following the README instructions")
            freq_str = pos_embed[len("RoPE") :]
            try:
                freq = float(freq_str)
            except ValueError as err:
                # Same exception type as before, but with an actionable message (e.g. for a bare "RoPE")
                raise ValueError(
                    f"Invalid RoPE frequency in pos_embed {pos_embed!r}; expected a form like 'RoPE100'"
                ) from err
            self.rope = RoPE2D(freq=freq)
        else:
            raise NotImplementedError("Unknown pos_embed " + pos_embed)

        # Init the patch embedding
        self._set_patch_embed(img_size, patch_size, enc_embed_dim)

        # Init the encoder
        self._set_encoder(enc_depth, enc_embed_dim, enc_num_heads, mlp_ratio, norm_layer, self.rope)

        # Initialize random weights
        self.initialize_weights()

        # Load the pretrained CroCo checkpoint if provided
        if pretrained_checkpoint_path:
            print(f"Loading pretrained CroCo checkpoint from {pretrained_checkpoint_path}")
            # map_location="cpu" lets checkpoints saved from GPU load on CPU-only machines;
            # load_state_dict then copies the tensors into this module's parameters.
            # NOTE(review): weights_only=False unpickles arbitrary objects - only load trusted checkpoints.
            ckpt = torch.load(pretrained_checkpoint_path, map_location="cpu", weights_only=False)
            print(self.load_state_dict(ckpt["model"]))
            if not override_checkpoint_attributes:
                ckpt_data_norm_type = ckpt["data_norm_type"]
                ckpt_patch_embed_cls = ckpt["patch_embed_cls"]
                assert (
                    data_norm_type == ckpt_data_norm_type
                ), f"Data normalization type {data_norm_type} does not match the checkpoint {ckpt_data_norm_type}."
                assert (
                    patch_embed_cls == ckpt_patch_embed_cls
                ), f"Patch embedding class {patch_embed_cls} does not match the checkpoint {ckpt_patch_embed_cls}."
        else:
            print("No pretrained checkpoint provided. Randomly initializing the CroCo encoder.")

    def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768):
        "Set the patch embedding scheme (class selected by self.patch_embed_cls)"
        self.patch_embed = get_patch_embed(self.patch_embed_cls, img_size, patch_size, enc_embed_dim)

    def _set_encoder(self, enc_depth, enc_embed_dim, enc_num_heads, mlp_ratio, norm_layer, rope):
        "Set the encoder: `enc_depth` transformer blocks sharing the given RoPE, followed by a final norm"
        self.enc_blocks = nn.ModuleList(
            [
                Block(enc_embed_dim, enc_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer, rope=rope)
                for _ in range(enc_depth)
            ]
        )
        self.enc_norm = norm_layer(enc_embed_dim)

    def initialize_weights(self):
        "Initialize the weights of the patch embedding and the transformer encoder"
        # Patch embedding
        self.patch_embed._init_weights()
        # Linears and layer norms
        self.apply(self._init_weights)

    def _init_weights(self, m):
        "Initialize the transformer encoder weights (Linear: xavier/zero-bias, LayerNorm: ones/zeros)"
        if isinstance(m, nn.Linear):
            # We use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            # LayerNorm(elementwise_affine=False) has no weight/bias parameters
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def _tokens_to_feature_map(self, tokens: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        Reshape patch tokens into a dense feature map.

        (B x Num_patches x Embed_dim) -> (B x Embed_dim x H / Patch_Size x W / Patch_Size)

        The reshape keeps the batch size explicit, so it raises if Num_patches does not equal
        (H / Patch_Size) * (W / Patch_Size) instead of silently re-batching the data.
        """
        return (
            tokens.permute(0, 2, 1)
            .reshape(tokens.shape[0], self.enc_embed_dim, height // self.patch_size, width // self.patch_size)
            .contiguous()
        )

    def forward(self, encoder_input: ViTEncoderInput) -> ViTEncoderOutput:
        """
        CroCov2 Encoder Forward Pass

        Args:
            encoder_input (ViTEncoderInput): Input data for the encoder. Input data must contain image normalization type
                and normalized image tensor. An optional `true_shape` is forwarded to the patch embedding; when it is
                absent or None, every image is assumed to have the full (H, W) of the input tensor.

        Returns:
            ViTEncoderOutput: Output data from the encoder, with features of shape
                (B x Embed_dim x H / Patch_Size x W / Patch_Size).
        """
        # Check image normalization type
        self._check_data_normalization_type(encoder_input.data_norm_type)

        # Get the true shape of the image for landscape/portrait mode check in patch embedding
        image = encoder_input.image
        batch_size, _, height, width = image.shape
        true_shape = getattr(encoder_input, "true_shape", None)
        if true_shape is None:
            # Build the fallback on the image's device so the patch embedding sees consistent devices
            true_shape = torch.tensor([height, width], device=image.device)[None].repeat(batch_size, 1)

        # Embed the image into patches
        features, pos = self.patch_embed(image, true_shape=true_shape)

        # Now apply the transformer encoder and normalization
        for blk in self.enc_blocks:
            features = blk(features, pos)
        features = self.enc_norm(features)

        # Resize the features to the expected dense feature-map shape
        return ViTEncoderOutput(features=self._tokens_to_feature_map(features, height, width))
class CroCoIntermediateFeatureReturner(CroCoEncoder, IntermediateFeatureReturner):
    "Intermediate Feature Returner for UniCeption CroCo Encoder"

    def __init__(
        self,
        name: str,
        data_norm_type: str,
        patch_embed_cls: str = "PatchEmbedDust3R",
        img_size: Union[int, Tuple[int, int]] = (224, 224),
        patch_size: int = 16,
        enc_embed_dim: int = 1024,
        enc_depth: int = 24,
        enc_num_heads: int = 16,
        mlp_ratio: int = 4,
        norm_layer: Callable = partial(nn.LayerNorm, eps=1e-6),
        pos_embed: str = "RoPE100",
        pretrained_checkpoint_path: Optional[str] = None,
        indices: Optional[Union[int, List[int]]] = None,
        norm_intermediate: bool = True,
        stop_early: bool = False,
        intermediates_only: bool = True,
        *args,
        **kwargs,
    ):
        """
        Intermediate Feature Returner for the CroCo Encoder.

        Args:
            name (str): Name of the encoder.
            data_norm_type (str): Input data normalization type.
            patch_embed_cls (str, optional): The class to use for patch embedding.
                Defaults to 'PatchEmbedDust3R'. Options: ['PatchEmbedCroCo', 'PatchEmbedDust3R', 'ManyAR_PatchEmbed'].
            img_size (int or Tuple[int, int], optional): The size of the input image. Defaults to (224, 224).
            patch_size (int, optional): The size of the patches to divide the image into. Defaults to 16.
            enc_embed_dim (int, optional): The dimension of the encoder's embedding. Defaults to 1024.
            enc_depth (int, optional): The number of encoder layers/transformer blocks. Defaults to 24.
            enc_num_heads (int, optional): The number of encoder attention heads. Defaults to 16.
            mlp_ratio (int, optional): The MLP ratio used for the CroCo encoder transformer. Defaults to 4.
            norm_layer (Callable, optional): Factory for the normalization layer used in the transformer.
                Defaults to nn.LayerNorm with eps=1e-6.
            pos_embed (str, optional): Positional embedding, given as 'RoPE<freq>' (e.g. 'RoPE100').
                Defaults to 'RoPE100'. Only RoPE variants are supported.
            pretrained_checkpoint_path (str, optional): Path to the pretrained checkpoint. Defaults to None.
            indices (Optional[Union[int, List[int]]], optional): Indices of the layers to return. Defaults to None. Options:
            - None: Return all intermediate layers.
            - int: Return the last n layers.
            - List[int]: Return the intermediate layers at the specified indices.
            norm_intermediate (bool, optional): Whether to normalize the intermediate features with the final
                encoder norm. Defaults to True.
            stop_early (bool, optional): Whether to skip the blocks after the last requested index. Defaults to False.
            intermediates_only (bool, optional): Whether to return only the intermediate features. Defaults to True.
            *args, **kwargs: Forwarded to CroCoEncoder (e.g. override_checkpoint_attributes=True).
        """
        # Init the base classes
        CroCoEncoder.__init__(
            self,
            name=name,
            data_norm_type=data_norm_type,
            patch_embed_cls=patch_embed_cls,
            img_size=img_size,
            patch_size=patch_size,
            enc_embed_dim=enc_embed_dim,
            enc_depth=enc_depth,
            enc_num_heads=enc_num_heads,
            mlp_ratio=mlp_ratio,
            norm_layer=norm_layer,
            pos_embed=pos_embed,
            pretrained_checkpoint_path=pretrained_checkpoint_path,
            *args,
            **kwargs,
        )
        IntermediateFeatureReturner.__init__(
            self,
            indices=indices,
            norm_intermediate=norm_intermediate,
            stop_early=stop_early,
            intermediates_only=intermediates_only,
        )

    def _tokens_to_feature_map(self, tokens: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        Reshape patch tokens into a dense feature map.

        (B x Num_patches x Embed_dim) -> (B x Embed_dim x H / Patch_Size x W / Patch_Size)

        The batch size is kept explicit, so a token count that does not equal
        (H / Patch_Size) * (W / Patch_Size) raises instead of silently re-batching the data.
        """
        return (
            tokens.permute(0, 2, 1)
            .reshape(tokens.shape[0], self.enc_embed_dim, height // self.patch_size, width // self.patch_size)
            .contiguous()
        )

    def forward(
        self, encoder_input: ViTEncoderInput
    ) -> Union[List[ViTEncoderOutput], Tuple[ViTEncoderOutput, List[ViTEncoderOutput]]]:
        """
        CroCov2 Encoder Forward Pass with Intermediate Feature Return

        Args:
            encoder_input (ViTEncoderInput): Input data for the encoder. Input data must contain image normalization type
                and normalized image tensor. An optional `true_shape` is forwarded to the patch embedding; when it is
                absent or None, every image is assumed to have the full (H, W) of the input tensor.

        Returns:
            Union[List[ViTEncoderOutput], Tuple[ViTEncoderOutput, List[ViTEncoderOutput]]]: Output data from the encoder.
                If `intermediates_only` is True, returns a list of intermediate features.
                Otherwise, returns a tuple with the final features and a list of intermediate features.
                All features have shape (B x Embed_dim x H / Patch_Size x W / Patch_Size).
        """
        # Check image normalization type
        self._check_data_normalization_type(encoder_input.data_norm_type)

        # Get the true shape of the image for landscape/portrait mode check in patch embedding
        image = encoder_input.image
        batch_size, _, height, width = image.shape
        true_shape = getattr(encoder_input, "true_shape", None)
        if true_shape is None:
            # Build the fallback on the image's device so the patch embedding sees consistent devices
            true_shape = torch.tensor([height, width], device=image.device)[None].repeat(batch_size, 1)

        # Embed the image into patches
        features, pos = self.patch_embed(image, true_shape=true_shape)

        # Get indices of the intermediate features to return
        take_indices, max_index = feature_take_indices(len(self.enc_blocks), self.indices)

        # Get the blocks based on early stopping
        if torch.jit.is_scripting() or not self.stop_early:  # can't slice blocks in torchscript
            blocks = self.enc_blocks
        else:
            blocks = self.enc_blocks[: max_index + 1]

        # Now apply the transformer encoder, collecting (optionally normalized) intermediates as feature maps
        intermediate_features = []
        for blk_idx, blk in enumerate(blocks):
            features = blk(features, pos)
            if blk_idx in take_indices:
                # Normalize intermediates with final norm layer if enabled
                intermediate = self.enc_norm(features) if self.norm_intermediate else features
                intermediate_features.append(
                    ViTEncoderOutput(features=self._tokens_to_feature_map(intermediate, height, width))
                )

        # Return only the intermediate features if enabled
        if self.intermediates_only:
            return intermediate_features

        # Normalize and reshape the final features (output of the last executed block)
        final_features = ViTEncoderOutput(
            features=self._tokens_to_feature_map(self.enc_norm(features), height, width)
        )

        return final_features, intermediate_features
if __name__ == "__main__":
    # Checkpoint paths are relative to the current working directory
    # (presumably this file's directory - confirm before running elsewhere).
    checkpoint_dir = "../../../checkpoints/encoders"

    # Init the pre-trained CroCo, DUSt3R & MASt3R encoders, one per checkpoint.
    # Each row: (name, data_norm_type, checkpoint file, patch_embed_cls, img_size)
    encoder_configs = [
        ("croco", "croco", "CroCo_Encoder_224.pth", "PatchEmbedCroCo", (224, 224)),
        ("dust3r_224", "dust3r", "CroCo_Encoder_224_DUSt3R_linear.pth", "PatchEmbedDust3R", (224, 224)),
        ("dust3r_512", "dust3r", "CroCo_Encoder_512_DUSt3R_linear.pth", "ManyAR_PatchEmbed", (512, 512)),
        ("dust3r_512_dpt", "dust3r", "CroCo_Encoder_512_DUSt3R_dpt.pth", "ManyAR_PatchEmbed", (512, 512)),
        ("mast3r_512", "dust3r", "CroCo_Encoder_512_MASt3R.pth", "ManyAR_PatchEmbed", (512, 512)),
    ]
    for enc_name, enc_norm_type, ckpt_file, enc_patch_embed_cls, enc_img_size in encoder_configs:
        # Each encoder is only constructed (which loads and validates its checkpoint), then released
        CroCoEncoder(
            name=enc_name,
            data_norm_type=enc_norm_type,
            pretrained_checkpoint_path=f"{checkpoint_dir}/{ckpt_file}",
            patch_embed_cls=enc_patch_embed_cls,
            img_size=enc_img_size,
        )

    print("All CroCo & DUSt3R Encoders have been initialized successfully!")

    # Intermediate Feature Returner Tests
    print("Running Intermediate Feature Returner Tests...")
    pretrained_checkpoint_path = f"{checkpoint_dir}/CroCo_Encoder_512_DUSt3R_dpt.pth"

    def run_intermediate_returner(**returner_kwargs):
        """
        Build a DUSt3R 512 DPT intermediate feature returner with the given options and run it
        (without autograd) on a random 224x224 image.

        Returns:
            (returner, output, expected_shape), where expected_shape is the
            (B, Embed_dim, H / Patch_Size, W / Patch_Size) shape every returned feature map should have.
        """
        returner = CroCoIntermediateFeatureReturner(
            name="dust3r_512_dpt",
            data_norm_type="dust3r",
            pretrained_checkpoint_path=pretrained_checkpoint_path,
            patch_embed_cls="ManyAR_PatchEmbed",
            img_size=(512, 512),
            **returner_kwargs,
        )
        dummy_input = ViTEncoderInput(image=torch.randn(1, 3, 224, 224), data_norm_type="dust3r")
        with torch.no_grad():  # inference only: no need to build an autograd graph
            output = returner(dummy_input)
        expected_shape = (1, returner.enc_embed_dim, 224 // returner.patch_size, 224 // returner.patch_size)
        return returner, output, expected_shape

    # Run the intermediate feature returner with last-n index
    _, output, expected_shape = run_intermediate_returner(indices=6)  # Last 6 layers
    assert isinstance(output, list), "Output must be a list of intermediate features"
    assert isinstance(output[0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput"
    assert len(output) == 6, "Output must have length of intermediate features equal to the number of indices"
    assert all(o.features.shape == expected_shape for o in output), "Intermediate features have the wrong shape"

    # Run the intermediate feature returner with specific indices
    _, output, expected_shape = run_intermediate_returner(indices=[0, 2, 4, 6])  # Specific layers
    assert isinstance(output, list), "Output must be a list of intermediate features"
    assert isinstance(output[0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput"
    assert len(output) == 4, "Output must have length of intermediate features equal to the number of indices"
    assert all(o.features.shape == expected_shape for o in output), "Intermediate features have the wrong shape"

    # Test the normalizing of intermediate features: without normalization the last intermediate
    # differs from the (always normalized) final features
    returner, output, expected_shape = run_intermediate_returner(
        indices=[-1], norm_intermediate=False, intermediates_only=False
    )
    assert isinstance(output, tuple), "Output must be a tuple with final features and intermediate features"
    assert isinstance(output[0], ViTEncoderOutput), "First element of output must be the final features"
    assert isinstance(output[1], list), "Second element of output must be a list of intermediate features"
    assert isinstance(output[1][0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput"
    assert output[0].features.shape == expected_shape, "Final features have the wrong shape"
    assert output[1][0].features.shape == expected_shape, "Intermediate features have the wrong shape"
    if not isinstance(returner.enc_norm, torch.nn.Identity):
        assert not torch.equal(
            output[0].features, output[1][0].features
        ), "Final features and intermediate features must be different"

    # With normalization the last intermediate must equal the final features
    _, output, expected_shape = run_intermediate_returner(
        indices=[-1], norm_intermediate=True, intermediates_only=False
    )
    assert isinstance(output, tuple), "Output must be a tuple with final features and intermediate features"
    assert isinstance(output[0], ViTEncoderOutput), "First element of output must be the final features"
    assert isinstance(output[1], list), "Second element of output must be a list of intermediate features"
    assert isinstance(output[1][0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput"
    assert output[0].features.shape == expected_shape, "Final features have the wrong shape"
    assert torch.equal(
        output[0].features, output[1][0].features
    ), "Final features and intermediate features must be same"

    print("All Intermediate Feature Returner Tests have passed successfully!")