Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| Encoder class for Patch Embedder | |
| """ | |
| import math | |
| from functools import partial | |
| from typing import Callable, Optional, Tuple, Union | |
| import torch | |
| import torch.nn as nn | |
| from torch.nn.init import trunc_normal_ | |
| from uniception.models.encoders.base import ( | |
| UniCeptionViTEncoderBase, | |
| ViTEncoderInput, | |
| ViTEncoderNonImageInput, | |
| ViTEncoderOutput, | |
| ) | |
def make_2tuple(x):
    """
    Normalize a size specification into a 2-tuple.

    Args:
        x (int, tuple, list): Either a single int, which is duplicated into (x, x),
            or a tuple/list holding exactly two entries.

    Returns:
        tuple: A tuple with two entries.

    Raises:
        ValueError: If x is a tuple/list whose length is not 2.
        TypeError: If x is neither an int nor a tuple/list.
    """
    # Explicit raises instead of asserts: asserts are removed under `python -O`,
    # which would let malformed sizes through unchecked.
    if isinstance(x, (tuple, list)):
        if len(x) != 2:
            raise ValueError(f"Expected exactly 2 entries, got {len(x)}: {x!r}")
        # Lists (e.g. coming from a parsed config) are converted so callers always get a tuple
        return tuple(x)
    if not isinstance(x, int):
        raise TypeError(f"Expected an int or a tuple/list with 2 entries, got {type(x).__name__}: {x!r}")
    return (x, x)
class PatchEmbedder(UniCeptionViTEncoderBase):
    """
    UniCeption Patch Embedder.

    Projects a (B, C, H, W) input into non-overlapping patch features with a strided convolution,
    adds a learnable positional encoding (resized with bicubic interpolation when the input patch grid
    differs from the stored one) and returns the result as a (B, E, H / Patch_Size, W / Patch_Size) map.
    """

    def __init__(
        self,
        name: str,
        data_norm_type: str = "patch_embedder",
        input_size: Union[int, Tuple[int, int]] = 518,
        patch_size: int = 14,
        in_chans: int = 3,
        enc_embed_dim: int = 1024,
        norm_layer: Optional[Callable] = None,
        post_pe_norm_layer: Optional[Callable] = partial(nn.LayerNorm, eps=1e-6),
        interpolate_antialias: bool = False,
        interpolate_offset: float = 0.1,
        pretrained_checkpoint_path: Optional[str] = None,
        *args,
        **kwargs,
    ):
        """
        Patch Encoder for extracting patch-wise features from a spatial input of size (B, C, H, W).
        Learnable positional encoding is also applied to the patch-wise features.

        Args:
            name (str): Name forwarded to the base class.
            data_norm_type (str): Data normalization type forwarded to the base class.
            input_size (int, tuple): Reference input size (H, W) that defines the positional encoding grid.
            patch_size (int): Side length of each square patch.
            in_chans (int): Number of input channels.
            enc_embed_dim (int): Embedding dimension of the patch features.
            norm_layer (Callable, optional): Factory called with enc_embed_dim to build the norm applied
                right after the projection. Identity when None.
            post_pe_norm_layer (Callable, optional): Factory called with enc_embed_dim to build the norm
                applied after the positional encoding is added. Identity when None.
            interpolate_antialias (bool): Passed as `antialias` to the positional encoding interpolation.
            interpolate_offset (float): Offset added to the target grid size when computing the
                interpolation scale factor. A falsy value selects interpolation by explicit output size.
            pretrained_checkpoint_path (str, optional): Checkpoint whose "model" entry is loaded as state dict.
        """
        # Init the base class
        super().__init__(
            name=name,
            data_norm_type=data_norm_type,
            patch_size=patch_size,
            *args,
            **kwargs,
        )

        # Init the Patch Embedder specific attributes
        patch_HW = make_2tuple(patch_size)
        self.input_size = make_2tuple(input_size)
        self.patches_resolution = (self.input_size[0] // patch_HW[0], self.input_size[1] // patch_HW[1])
        self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]
        self.in_chans = in_chans
        self.enc_embed_dim = enc_embed_dim

        # Init the Patch Embedder layers
        self.proj = nn.Conv2d(in_chans, enc_embed_dim, kernel_size=patch_HW, stride=patch_HW)
        self.norm = norm_layer(enc_embed_dim) if norm_layer else nn.Identity()

        # Init the learnable positional encodings
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, self.enc_embed_dim))
        trunc_normal_(self.pos_embed, std=0.02)
        self.interpolate_antialias = interpolate_antialias
        self.interpolate_offset = interpolate_offset

        # Init the norm layer after positional encoding
        self.post_pe_norm = post_pe_norm_layer(enc_embed_dim) if post_pe_norm_layer else nn.Identity()

        # Load the pretrained checkpoint if provided
        self.pretrained_checkpoint_path = pretrained_checkpoint_path
        if self.pretrained_checkpoint_path:
            print(f"Loading custom pretrained Patch Embedder checkpoint from {self.pretrained_checkpoint_path} ...")
            # map_location="cpu" lets checkpoints saved from an accelerator load on hosts without one;
            # load_state_dict copies the values into the already allocated parameters.
            # NOTE(security): weights_only=False unpickles arbitrary objects - only load trusted checkpoints.
            ckpt = torch.load(self.pretrained_checkpoint_path, map_location="cpu", weights_only=False)
            print(self.load_state_dict(ckpt["model"]))

    def _pos_embed_grid(self):
        "Return the (rows, cols) patch grid on which the stored positional encoding is laid out."
        num_positions = self.pos_embed.shape[1]
        grid_h, grid_w = self.patches_resolution
        if grid_h * grid_w != num_positions:
            # The table does not match the configured resolution (e.g. it was replaced after init):
            # fall back to a square grid, which is the only layout recoverable from the count alone.
            grid_h = grid_w = math.isqrt(num_positions)
            assert num_positions == grid_h * grid_w
        return grid_h, grid_w

    def interpolate_pos_encoding(self, features, height, width):
        """
        Interpolate the positional encoding to the expected size.

        Args:
            features (torch.Tensor): Input tensor of shape (B, N, C).
            height (int, float): Height of the input tensor.
            width (int, float): Width of the input tensor.

        Returns:
            torch.Tensor: Interpolated positional encoding tensor of shape (1, N, C).
        """
        previous_dtype = features.dtype
        npatch = features.shape[1]
        N = self.pos_embed.shape[1]
        dim = features.shape[-1]
        # self.patch_size is not assigned in this class; presumably set by the base class from `patch_size`
        height0 = height // self.patch_size
        width0 = width // self.patch_size
        grid_h, grid_w = self._pos_embed_grid()

        # Fast path: the input patch grid is exactly the stored grid, so no resampling is needed.
        # Comparing the grids (not just the patch count) keeps non-square layouts from being mixed up.
        if npatch == N and (height0, width0) == (grid_h, grid_w):
            return self.pos_embed

        patch_pos_embed = self.pos_embed.float()
        kwargs = {}
        if self.interpolate_offset:
            # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
            # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
            sh = float(height0 + self.interpolate_offset) / grid_h
            sw = float(width0 + self.interpolate_offset) / grid_w
            kwargs["scale_factor"] = (sh, sw)
        else:
            # Simply specify an output size instead of a scale factor
            kwargs["size"] = (height0, width0)

        # (1, N, C) -> (1, C, grid_h, grid_w) so the table can be resized like an image
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, grid_h, grid_w, dim).permute(0, 3, 1, 2),
            mode="bicubic",
            antialias=self.interpolate_antialias,
            **kwargs,
        )
        assert (height0, width0) == patch_pos_embed.shape[-2:]

        # (1, C, height0, width0) -> (1, height0 * width0, C)
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

        return patch_pos_embed.to(previous_dtype)

    def forward(self, encoder_input: Union[ViTEncoderInput, ViTEncoderNonImageInput]) -> ViTEncoderOutput:
        """
        Patch Embedder Forward Pass

        Args:
            encoder_input (Union[ViTEncoderInput, ViTEncoderNonImageInput]): Input data for the encoder.
                If input type is ViTEncoderInput, input data must contain image normalization type and normalized image tensor.
                If input type is ViTEncoderNonImageInput, input data must contain a tensor of size (B, C, H, W).

        Returns:
            ViTEncoderOutput: Output data from the encoder.

        Raises:
            ValueError: If encoder_input is neither a ViTEncoderInput nor a ViTEncoderNonImageInput.
        """
        # Get the input data and verify normalization if the input type is ViTEncoderInput
        if isinstance(encoder_input, ViTEncoderInput):
            self._check_data_normalization_type(encoder_input.data_norm_type)
            input_data = encoder_input.image
        elif isinstance(encoder_input, ViTEncoderNonImageInput):
            input_data = encoder_input.data
        else:
            raise ValueError("Unsupported input type for Patch Embedder.")

        # Check the dtype and shape of the input
        assert isinstance(input_data, torch.Tensor), "Input must be a torch.Tensor"
        assert input_data.ndim == 4, "Input must be of shape (B, C, H, W)"
        height, width = input_data.shape[-2:]
        assert (
            height % self.patch_size == 0 and width % self.patch_size == 0
        ), f"Input shape must be divisible by patch size: {self.patch_size}"

        # Patchify the input data and project into expected latent space
        features = self.proj(input_data)  # (B, C, H, W) -> (B, E, H / Patch_Size, W / Patch_Size)
        features = features.flatten(2).transpose(
            1, 2
        )  # (B, E, H / Patch_Size, W / Patch_Size) -> (B, H / Patch_Size * W / Patch_Size, E)
        features = self.norm(features)  # Normalize the features after patch embedding
        features = features + self.interpolate_pos_encoding(
            features, height, width
        )  # (B, H / Patch_Size * W / Patch_Size, E)
        features = self.post_pe_norm(features)  # Normalize the features after positional encoding

        # Resize the features to the expected shape
        # (B x Num_patches x Embed_dim) -> (B x Embed_dim x H / Patch_Size x W / Patch_Size)
        features = features.permute(0, 2, 1)
        features = features.reshape(
            -1, self.enc_embed_dim, height // self.patch_size, width // self.patch_size
        ).contiguous()

        return ViTEncoderOutput(features=features)
if __name__ == "__main__":
    # Smoke test: build each variant, run one forward pass and check the output feature map shape.
    # A 518 x 518 input split into 14 x 14 patches gives a 37 x 37 patch grid.
    expected_shape = (1, 1024, 37, 37)
    shape_error = "Output features must have shape (1, 1024, 37, 37)"

    # Init Patch Embedder for images as input
    patch_embedder = PatchEmbedder(
        name="patch_embedder",
        data_norm_type="patch_embedder",
        input_size=518,
        patch_size=14,
        in_chans=3,
        enc_embed_dim=1024,
    )

    # Test dummy image input
    dummy_image = torch.randn(1, 3, 518, 518)
    patch_embedder_output = patch_embedder(ViTEncoderInput(data_norm_type="patch_embedder", image=dummy_image))
    assert patch_embedder_output.features.shape == expected_shape, shape_error

    # Init Patch Embedder for non-image data as input
    patch_embedder = PatchEmbedder(
        name="patch_embedder",
        data_norm_type="patch_embedder",
        input_size=518,
        patch_size=14,
        in_chans=6,
        enc_embed_dim=1024,
    )

    # Test dummy 6-channel non-image input (this variant was previously built but never run)
    dummy_data = torch.randn(1, 6, 518, 518)
    patch_embedder_output = patch_embedder(ViTEncoderNonImageInput(data=dummy_data))
    assert patch_embedder_output.features.shape == expected_shape, shape_error

    # Init Patch Embedder for single channel input
    patch_embedder = PatchEmbedder(
        name="patch_embedder",
        data_norm_type="patch_embedder",
        input_size=518,
        patch_size=14,
        in_chans=1,
        enc_embed_dim=1024,
    )

    # Test dummy non-image input
    dummy_image = torch.randn(1, 1, 518, 518)
    patch_embedder_output = patch_embedder(ViTEncoderNonImageInput(data=dummy_image))
    assert patch_embedder_output.features.shape == expected_shape, shape_error

    print("All variants of Patch Embedder have been initialized successfully!")