|
|
| import math |
| from itertools import chain |
| from typing import Any, Optional |
| from omegaconf import OmegaConf |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.nn.functional import interpolate |
| from einops.layers.torch import Rearrange |
|
|
| from transformers import PretrainedConfig, PreTrainedModel |
| from transformers import AutoConfig, AutoModel, AutoProcessor, AutoImageProcessor |
| from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTModel |
|
|
| def handle_feature_output( |
| x: torch.Tensor, feature_reduce_method: Optional[str] = None, num_discard_tokens: int = 0 |
| ) -> torch.Tensor: |
| """Handle feature output from transformer. |
| |
| Args: |
| x (torch.Tensor): input feature to be handled. shape is |
| [B, 1+H*W+N, C] if including both CLS and register tokens. |
| [B, 1+H*W, C] for standard model (N=0). |
| [B, H*W, C] for model without CLS. |
| feature_reduce_method (Optional[str]): method to select token. Options: |
| - `mean_pooling`: average over spatial tokens (non CLS tokens), output shape = [B, C]. |
| - `max_pooling`: max over spatial tokens, output shape = [B, C]. |
| - `cls`: return CLS token only, output shape = [B, C]. |
| - `identity`: return the feature without touching it, output shape = input shape. |
| - `None`: return spatial tokens, output shape = [B, H*W, C] (assuming input is [B, 1+H*W, C]). |
            Suppose the raw feature is in shape [B, 1+H*W, C]; the leading `1` is the CLS token.
| num_discard_tokens (int): |
| number of tokens to be discarded. Assuming they are at the end of the sequence. |
| Returns: |
| torch.Tensor: selected feature tokens. |
| """ |
|
|
| match feature_reduce_method: |
| case "mean_pooling": |
| return torch.mean(x[:, 1 : x.size(1) - num_discard_tokens], dim=1) |
| case "max_pooling": |
| return torch.amax(x[:, 1 : x.size(1) - num_discard_tokens], dim=1) |
| case "cls": |
| return x[:, 0] |
| case "identity": |
| return x |
| case None: |
| return x[:, 1 : x.size(1) - num_discard_tokens] |
| case _: |
            raise NotImplementedError(f"feature_reduce_method {feature_reduce_method} is not implemented.")
|
|
|
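# Illustrative usage of `handle_feature_output` (a minimal sketch; the shapes assume a
# ViT-style feature with one CLS token followed by 196 spatial tokens, which is an example
# assumption rather than something required by the function):
#
#   feats = torch.randn(2, 1 + 196, 384)                   # [B, 1+H*W, C]
#   pooled = handle_feature_output(feats, "mean_pooling")  # [2, 384]
#   spatial = handle_feature_output(feats)                 # [2, 196, 384]
#   cls_only = handle_feature_output(feats, "cls")         # [2, 384]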
|
| class ViTEmbeddingsNoCLS(ViTEmbeddings): |
| """ViT Embedding Module without CLS token.""" |
|
|
| def __init__(self, config: AutoConfig, use_mask_token: bool = False): |
| """Initialization. |
| |
| Args: |
| config (AutoConfig): config for ViT. |
| use_mask_token (bool, optional): whether to use mask token. Defaults to False. |
| """ |
| super(ViTEmbeddingsNoCLS, self).__init__(config, use_mask_token=use_mask_token) |
| self.cls_token = None |
|
|
| def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: |
| """ |
        This method interpolates the pre-trained position encodings so that the model can be
        used on higher-resolution images.
| |
| Source: |
| https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 |
| """ |
|
|
| num_patches = embeddings.shape[1] |
| num_positions = self.position_embeddings.shape[1] - 1 |
        if num_patches == num_positions and height == width:
            # this embedding module has no CLS token, so drop the CLS position slot
            return self.position_embeddings[:, 1:]
| patch_pos_embed = self.position_embeddings[:, 1:] |
| dim = embeddings.shape[-1] |
| h0 = height // self.config.patch_size |
| w0 = width // self.config.patch_size |
| |
| |
| h0, w0 = h0 + 0.1, w0 + 0.1 |
| patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) |
| patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) |
| patch_pos_embed = nn.functional.interpolate( |
| patch_pos_embed, |
| scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)), |
| mode="bicubic", |
| align_corners=False, |
| ) |
| assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1] |
| patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) |
| return patch_pos_embed |
|
|
| def forward( |
| self, |
| pixel_values: torch.Tensor, |
| bool_masked_pos: Optional[torch.BoolTensor] = None, |
| interpolate_pos_encoding: bool = False, |
| ) -> torch.Tensor: |
| batch_size, num_channels, height, width = pixel_values.shape |
| embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) |
|
|
| if bool_masked_pos is not None: |
| seq_length = embeddings.shape[1] |
| mask_tokens = self.mask_token.expand(batch_size, seq_length, -1) |
| |
| mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) |
| embeddings = embeddings * (1.0 - mask) + mask_tokens * mask |
|
|
| |
| if interpolate_pos_encoding: |
| embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) |
| else: |
| embeddings = embeddings + self.position_embeddings[:, 1:] |
|
|
| embeddings = self.dropout(embeddings) |
|
|
| return embeddings |
|
|
|
|
| |
| class ViTModelNoCLS(ViTModel): |
| """ViT Model without CLS token.""" |
|
|
| def __init__(self, config: AutoConfig, add_pooling_layer: bool = True, use_mask_token: bool = False) -> None: |
| super(ViTModelNoCLS, self).__init__(config, add_pooling_layer, use_mask_token) |
| self.embeddings = ViTEmbeddingsNoCLS(config, use_mask_token=use_mask_token) |
| self.no_cls = True |
|
|
| def _init_weights(self, module: nn.Linear | nn.Conv2d | nn.LayerNorm) -> None: |
| """Initialize the weights""" |
| if isinstance(module, (nn.Linear, nn.Conv2d)): |
| |
| |
| module.weight.data = nn.init.trunc_normal_( |
| module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range |
| ).to(module.weight.dtype) |
| if module.bias is not None: |
| module.bias.data.zero_() |
| elif isinstance(module, nn.LayerNorm): |
| module.bias.data.zero_() |
| module.weight.data.fill_(1.0) |
| elif isinstance(module, ViTEmbeddings): |
| module.position_embeddings.data = nn.init.trunc_normal_( |
| module.position_embeddings.data.to(torch.float32), |
| mean=0.0, |
| std=self.config.initializer_range, |
| ).to(module.position_embeddings.dtype) |
|
|
|
|
| |
| class ViTEmbeddingsReg(ViTEmbeddings): |
| """ |
| ViT Embedding Module with register tokens. https://openreview.net/forum?id=2dnO3LLiJ1 |
| """ |
|
|
| def __init__(self, config: AutoConfig, use_mask_token: bool = False, num_reg_tokens: int = 7): |
| super(ViTEmbeddingsReg, self).__init__(config, use_mask_token=use_mask_token) |
| self.reg_token = nn.Parameter(torch.randn(1, num_reg_tokens, config.hidden_size)) |
| self.num_reg_tokens = num_reg_tokens |
| self.reg_pos_embed = nn.Parameter(torch.randn(1, num_reg_tokens, config.hidden_size)) |
|
|
| self.reg_pos_embed.data = nn.init.trunc_normal_( |
| self.reg_pos_embed.data.to(torch.float32), |
| mean=0.0, |
| std=self.config.initializer_range, |
| ).to(self.reg_pos_embed.dtype) |
|
|
| self.reg_token.data = nn.init.trunc_normal_( |
| self.reg_token.data.to(torch.float32), |
| mean=0.0, |
| std=self.config.initializer_range, |
| ).to(self.reg_token.dtype) |
|
|
| def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: |
| """ |
        This method interpolates the pre-trained position encodings so that the model can be
        used on higher-resolution images.
| |
| Source: |
| https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 |
| """ |
|
|
| num_patches = embeddings.shape[1] - 1 - self.num_reg_tokens |
| num_positions = self.position_embeddings.shape[1] - 1 |
        if num_patches == num_positions and height == width:
            # include the register-token positions so the shape matches [CLS, patches, registers]
            return torch.cat([self.position_embeddings, self.reg_pos_embed], dim=1)
| class_pos_embed = self.position_embeddings[:, 0] |
| patch_pos_embed = self.position_embeddings[:, 1:] |
| reg_pos_embed = self.reg_pos_embed |
| dim = embeddings.shape[-1] |
| h0 = height // self.config.patch_size |
| w0 = width // self.config.patch_size |
| |
| |
| h0, w0 = h0 + 0.1, w0 + 0.1 |
| patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) |
| patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) |
| patch_pos_embed = nn.functional.interpolate( |
| patch_pos_embed, |
| scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)), |
| mode="bicubic", |
| align_corners=False, |
| ) |
| assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1] |
| patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) |
| return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed, reg_pos_embed), dim=1) |
|
|
| def forward( |
| self, |
| pixel_values: torch.Tensor, |
| bool_masked_pos: Optional[torch.BoolTensor] = None, |
| interpolate_pos_encoding: bool = False, |
| ) -> torch.Tensor: |
| batch_size, num_channels, height, width = pixel_values.shape |
| embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) |
|
|
| if bool_masked_pos is not None: |
| seq_length = embeddings.shape[1] |
| mask_tokens = self.mask_token.expand(batch_size, seq_length, -1) |
| |
| mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) |
| embeddings = embeddings * (1.0 - mask) + mask_tokens * mask |
|
|
| |
| cls_tokens = self.cls_token.expand(batch_size, -1, -1) |
| reg_tokens = self.reg_token.expand(batch_size, -1, -1) |
| embeddings = torch.cat((cls_tokens, embeddings, reg_tokens), dim=1) |
|
|
| |
| if interpolate_pos_encoding: |
| embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) |
| else: |
| embeddings = embeddings + torch.cat([self.position_embeddings, self.reg_pos_embed], dim=1) |
|
|
| embeddings = self.dropout(embeddings) |
|
|
| return embeddings |
|
|
|
|
| |
| class ViTModelReg(ViTModel): |
| """ViT Model with register tokens. https://openreview.net/forum?id=2dnO3LLiJ1""" |
|
|
| def __init__( |
| self, config: AutoConfig, add_pooling_layer: bool = True, use_mask_token: bool = False, num_reg_tokens: int = 7 |
| ): |
| super(ViTModelReg, self).__init__(config, add_pooling_layer, use_mask_token) |
| self.embeddings = ViTEmbeddingsReg(config, use_mask_token=use_mask_token, num_reg_tokens=num_reg_tokens) |
| self.num_reg_tokens = num_reg_tokens |
|
|
| def _init_weights(self, module: nn.Linear | nn.Conv2d | nn.LayerNorm) -> None: |
| """Initialize the weights""" |
| if isinstance(module, (nn.Linear, nn.Conv2d)): |
| |
| |
| module.weight.data = nn.init.trunc_normal_( |
| module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range |
| ).to(module.weight.dtype) |
| if module.bias is not None: |
| module.bias.data.zero_() |
| elif isinstance(module, nn.LayerNorm): |
| module.bias.data.zero_() |
| module.weight.data.fill_(1.0) |
| elif isinstance(module, ViTEmbeddings): |
| module.position_embeddings.data = nn.init.trunc_normal_( |
| module.position_embeddings.data.to(torch.float32), |
| mean=0.0, |
| std=self.config.initializer_range, |
| ).to(module.position_embeddings.dtype) |
| module.cls_token.data = nn.init.trunc_normal_( |
| module.cls_token.data.to(torch.float32), |
| mean=0.0, |
| std=self.config.initializer_range, |
| ).to(module.cls_token.dtype) |
|
|
|
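# Note on token layout: ViTEmbeddingsReg / ViTModelReg arrange the sequence as
# [CLS, patch tokens, register tokens], so with a 224x224 input and 16x16 patches the hidden
# states contain 1 + 196 + num_reg_tokens entries; downstream code (e.g. TheiaModel below)
# strips the trailing register tokens before feature translation.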
|
| class TorchImageProcessor: |
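    """Lightweight torch-only image preprocessor.

    Mirrors the resize, rescale, and normalize steps of the wrapped Hugging Face image
    processor so that preprocessing can stay on the same device as the model.
    """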
| def __init__(self, processor): |
| |
| self.mean = torch.tensor(processor.image_mean, dtype=torch.float32).reshape((1, 3, 1, 1)) |
| self.std = torch.tensor(processor.image_std, dtype=torch.float32).reshape((1, 3, 1, 1)) |
| self.width = processor.size['width'] |
| self.height = processor.size['height'] |
|
|
| def __call__(self, x, |
| do_resize: bool = True, |
| do_rescale: bool = True, |
| do_normalize: bool = True, |
| device='cuda'): |
| |
| if do_resize: |
| |
| |
| x = F.interpolate( |
| x, |
| size=(self.height, self.width), |
| mode='bilinear', |
| align_corners=False |
| ) |
| |
| |
| if do_rescale: |
| x = x / 255. |
| if do_normalize: |
| x = x - self.mean.to(device) |
| x = x / self.std.to(device) |
| return {'pixel_values': x} |
|
|
|
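# Illustrative usage of TorchImageProcessor (a minimal sketch; the DeiT checkpoint name and
# the float NCHW input in [0, 255] are example assumptions, and the wrapped processor is
# assumed to expose `size` as a dict with "height"/"width" keys):
#
#   hf_processor = AutoProcessor.from_pretrained("facebook/deit-small-patch16-224", use_fast=True)
#   gpu_processor = TorchImageProcessor(hf_processor)
#   images = torch.randint(0, 256, (4, 3, 256, 256)).float()
#   batch = gpu_processor(images, device="cpu")  # {"pixel_values": resized, rescaled, normalized tensor}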
|
| class DeiT(nn.Module): |
| """DeiT model. |
| |
| Paper: Training data-efficient image transformers & distillation through attention |
| https://arxiv.org/abs/2012.12877 |
| Huggingface Reference: https://huggingface.co/docs/transformers/en/model_doc/deit |
| |
| Attributes: |
| model_name (str): name of the model. |
| pretrained (bool): whether to use pretrained weights. |
| """ |
|
|
| def __init__( |
| self, |
| model_name: str = "facebook/deit-small-patch16-224", |
| pretrained: bool = False, |
| image_size: int = 224, |
| ): |
| super().__init__() |
| self.image_size = image_size |
| model = AutoModel.from_pretrained(model_name) |
| if pretrained: |
| self.model = model |
| else: |
| deit_config = model.config |
| self.model = AutoModel.from_config(deit_config) |
| del model |
|
|
| self.model.pooler = nn.Identity() |
|
|
| |
| self.processor = AutoProcessor.from_pretrained(model_name, use_fast=True) |
| self.gpu_processor = TorchImageProcessor(self.processor) |
|
|
| def get_feature_size( |
| self, |
| keep_spatial: bool = False, |
| return_torch_size: bool = False, |
| ) -> torch.Size | tuple[int, ...]: |
| """Get the size of the feature. |
| |
| Args: |
| keep_spatial (bool): keep spatial dim of the feature shape. Defaults to False. |
| return_torch_size (bool): if true, return torch.Size type. Defaults to False. |
| |
| Returns: |
| torch.Size | tuple[int, ...]: returned feature shape. |
| """ |
| with torch.inference_mode(): |
            image_size = (self.image_size, self.image_size)
| x = torch.zeros((1, *image_size, 3), dtype=torch.uint8) |
| y = self.forward(x)[:, 1:] |
| size = y.size()[1:][::-1] |
| if keep_spatial: |
                assert math.isqrt(size[-1]) ** 2 == size[-1], "feature length must be a perfect square"
                h = w = math.isqrt(size[-1])
| size = (size[0], h, w) |
| if return_torch_size: |
| size = torch.Size(size) |
| return size |
|
|
| def forward( |
| self, |
| x: torch.Tensor, |
| do_resize: bool = True, |
| interpolate_pos_encoding: Optional[bool] = None, |
| do_rescale: bool = True, |
| do_normalize: bool = True, |
| ) -> torch.Tensor: |
| """Forward pass of the model |
| |
| Args: |
| x (torch.Tensor): model input. |
| |
            - arguments for self.processor. Details can be found at
| https://huggingface.co/docs/transformers/v4.41.3/en/model_doc/deit#transformers.DeiTImageProcessor |
| do_resize (bool): if do resizing in processor. Defaults to True. |
| interpolate_pos_encoding (bool): if interpolate the positional embedding. Defaults to None. |
| do_rescale (bool): if do rescaling (0-255 -> 0-1) in processor. Defaults to True. |
| do_normalize (bool): if do normalize in processor. Defaults to True. |
| |
| Returns: |
| torch.Tensor: model output. |
| """ |
| |
| |
| |
| if x.shape[-1] == 3: |
| x = x.permute(0, 3, 1, 2) |
| input = self.gpu_processor(x, device=self.model.device) |
|
|
| y = self.model(**input, interpolate_pos_encoding=interpolate_pos_encoding) |
| return y.last_hidden_state |
|
|
|
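# Illustrative usage of the DeiT wrapper (a minimal sketch; running it downloads the Hugging
# Face checkpoint named in `model_name`, and the uint8 [B, H, W, C] input mirrors the layout
# described in the forward docstring):
#
#   encoder = DeiT(model_name="facebook/deit-small-patch16-224", pretrained=True)
#   images = torch.zeros((2, 224, 224, 3), dtype=torch.uint8)
#   tokens = encoder(images)                          # last_hidden_state, [B, num_tokens, hidden_dim]
#   print(encoder.get_feature_size(keep_spatial=True))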
|
| class DeiTNoCLS(nn.Module): |
| """Modified DeiT model without CLS token.""" |
|
|
| def __init__( |
| self, model_name: str = "nocls-facebook/deit-small-patch16-224", pretrained: bool = False, image_size: int = 224 |
| ): |
| super().__init__() |
| self.image_size = image_size |
| pretrained_model_name = model_name.replace("nocls-", "") |
| deit_config = AutoConfig.from_pretrained(pretrained_model_name) |
| self.model = ViTModelNoCLS(deit_config) |
| if pretrained: |
| pretrained_model = AutoModel.from_pretrained(pretrained_model_name) |
            pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in self.model.state_dict()}
            self.model.load_state_dict(pretrained_dict, strict=False)
| del pretrained_model, pretrained_dict |
|
|
| self.model.pooler = nn.Identity() |
| self.processor = AutoProcessor.from_pretrained(pretrained_model_name) |
| self.no_cls = True |
|
|
| def get_feature_size( |
| self, |
| keep_spatial: bool = False, |
| return_torch_size: bool = False, |
| ) -> torch.Size | tuple[int, ...]: |
| """Get the size of the feature. |
| |
| Args: |
| keep_spatial (bool): keep spatial dim of the feature shape. Defaults to False. |
| return_torch_size (bool): if true, return torch.Size type. Defaults to False. |
| |
| Returns: |
| torch.Size | tuple[int, ...]: returned feature shape. |
| """ |
| with torch.inference_mode(): |
| image_size = (self.image_size, self.image_size) |
| x = torch.zeros((1, *image_size, 3), dtype=torch.uint8) |
| y = self.forward(x) |
| size = y.size()[1:][::-1] |
| if keep_spatial: |
                assert math.isqrt(size[-1]) ** 2 == size[-1], "feature length must be a perfect square"
                h = w = math.isqrt(size[-1])
| size = (size[0], h, w) |
| if return_torch_size: |
| size = torch.Size(size) |
| return size |
|
|
| def forward( |
| self, |
| x: torch.Tensor, |
| do_resize: bool = True, |
| interpolate_pos_encoding: Optional[bool] = None, |
| do_rescale: bool = True, |
| do_normalize: bool = True, |
| ) -> torch.Tensor: |
| """Forward pass of the model |
| |
| Args: |
| x (torch.Tensor): model input. |
| |
            - arguments for self.processor. Details can be found at
| https://huggingface.co/docs/transformers/v4.41.3/en/model_doc/deit#transformers.DeiTImageProcessor |
| do_resize (bool): if do resizing in processor. Defaults to True. |
| do_rescale (bool): if do rescaling (0-255 -> 0-1) in processor. Defaults to True. |
| do_normalize (bool): if do normalize in processor. Defaults to True. |
| |
| - argument for forward |
| interpolate_pos_encoding (bool): if interpolate the positional embedding. Defaults to None. |
| |
| Returns: |
| torch.Tensor: model output. |
| """ |
| input = self.processor( |
| x, return_tensors="pt", do_resize=do_resize, do_rescale=do_rescale, do_normalize=do_normalize |
| ).to(self.model.device) |
| y = self.model(**input, interpolate_pos_encoding=interpolate_pos_encoding) |
| return y.last_hidden_state |
|
|
|
|
| class DeiTReg(nn.Module): |
| """Modified DeiT model with register tokens.""" |
|
|
| def __init__( |
| self, |
| model_name: str = "reg-facebook/deit-small-patch16-224", |
| pretrained: bool = False, |
| image_size: int = 224, |
| num_reg_tokens: int = 7, |
| ): |
| super().__init__() |
| self.image_size = image_size |
| pretrained_model_name = model_name.replace("reg-", "") |
| deit_config = AutoConfig.from_pretrained(pretrained_model_name) |
| self.model = ViTModelReg(deit_config, num_reg_tokens=num_reg_tokens) |
| if pretrained: |
| pretrained_model = AutoModel.from_pretrained(pretrained_model_name) |
            pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in self.model.state_dict()}
            self.model.load_state_dict(pretrained_dict, strict=False)
| del pretrained_model, pretrained_dict |
|
|
| self.model.pooler = nn.Identity() |
| self.processor = AutoProcessor.from_pretrained(pretrained_model_name) |
| self.num_reg_tokens = num_reg_tokens |
|
|
| def get_feature_size( |
| self, |
| keep_spatial: bool = False, |
| return_torch_size: bool = False, |
| ) -> torch.Size | tuple[int, ...]: |
| """Get the size of the feature. |
| |
| Args: |
| keep_spatial (bool): keep spatial dim of the feature shape. Defaults to False. |
| return_torch_size (bool): if true, return torch.Size type. Defaults to False. |
| |
| Returns: |
| torch.Size | tuple[int, ...]: returned feature shape. |
| """ |
| with torch.inference_mode(): |
| image_size = (self.image_size, self.image_size) |
| x = torch.zeros((1, *image_size, 3), dtype=torch.uint8) |
| y = self.forward(x)[:, 1 : -self.num_reg_tokens] |
| size = y.size()[1:][::-1] |
| if keep_spatial: |
                assert math.isqrt(size[-1]) ** 2 == size[-1], "feature length must be a perfect square"
                h = w = math.isqrt(size[-1])
| size = (size[0], h, w) |
| if return_torch_size: |
| size = torch.Size(size) |
| return size |
|
|
| def forward( |
| self, |
| x: torch.Tensor, |
| do_resize: bool = True, |
| interpolate_pos_encoding: Optional[bool] = None, |
| do_rescale: bool = True, |
| do_normalize: bool = True, |
| ) -> torch.Tensor: |
| """Forward pass of the model |
| |
| Args: |
| x (torch.Tensor): model input. |
| |
            - arguments for self.processor. Details can be found at
| https://huggingface.co/docs/transformers/v4.41.3/en/model_doc/deit#transformers.DeiTImageProcessor |
| do_resize (bool): if do resizing in processor. Defaults to True. |
| interpolate_pos_encoding (bool): if interpolate the positional embedding. Defaults to None. |
| do_rescale (bool): if do rescaling (0-255 -> 0-1) in processor. Defaults to True. |
| do_normalize (bool): if do normalize in processor. Defaults to True. |
| |
| Returns: |
| torch.Tensor: model output. |
| """ |
| input = self.processor( |
| x, return_tensors="pt", do_resize=do_resize, do_rescale=do_rescale, do_normalize=do_normalize |
| ).to(self.model.device) |
| y = self.model(**input, interpolate_pos_encoding=interpolate_pos_encoding) |
| return y.last_hidden_state |
|
|
|
|
| def build_backbone(model_name: str, pretrained: bool = False, image_size: int = 224, **kwargs: Any) -> nn.Module: |
| """Build the backbone visual encoder of robot vision foundation model. |
| |
| Args: |
| model_name (str): name of the model. |
| pretrained (bool): whether to use pretrained weights. Defaults to False. |
        image_size (int): size of the image. Assumes a square image. Defaults to 224.
| kwargs (Any): any kwargs specific to some models. For example, |
| `num_reg_tokens` for `DeiTReg` when `"reg"` in `model_name` |
| |
| Returns: |
| nn.Module: backbone network. |
| """ |
| if "reg" in model_name: |
| return DeiTReg(model_name=model_name, pretrained=pretrained, image_size=image_size, **kwargs) |
| elif "nocls" in model_name: |
| return DeiTNoCLS(model_name=model_name, pretrained=pretrained, image_size=image_size, **kwargs) |
| elif "deit" in model_name: |
| return DeiT(model_name=model_name, pretrained=pretrained, image_size=image_size) |
| else: |
| raise NotImplementedError(f"Requested {model_name} is not implemented.") |
|
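# Illustrative calls to build_backbone (a minimal sketch; the model names follow the
# "nocls-"/"reg-" prefix convention handled above and still resolve to Hugging Face hub ids):
#
#   plain = build_backbone("facebook/deit-small-patch16-224", pretrained=True)
#   nocls = build_backbone("nocls-facebook/deit-small-patch16-224")
#   reg = build_backbone("reg-facebook/deit-small-patch16-224", num_reg_tokens=7)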
|
| class Interpolation(nn.Module): |
    """A thin nn.Module wrapper around nn.functional.interpolate.
| |
| Attributes: |
| target_size (tuple[int, int] | torch.Size): target spatial size of this interpolation. |
| """ |
|
|
| def __init__(self, target_size: tuple[int, int] | torch.Size) -> None: |
| super().__init__() |
| self.target_size = target_size |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| """Very simple forward pass to call interpolate().""" |
| return interpolate(x, self.target_size) |
|
|
|
|
| class LinearAdapterHead(nn.Module): |
| """Adapter head contains a single linear layer.""" |
| def __init__( |
| self, source_size: tuple[int, ...] | torch.Size, target_size: tuple[int, ...] | torch.Size |
| ): |
        """Initialization function for LinearAdapterHead.

        Args:
            source_size (tuple[int, ...] | torch.Size): the size of the source feature.
            target_size (tuple[int, ...] | torch.Size): the size of the target feature.
        """
| super().__init__() |
|
|
| self.source_size = source_size |
| self.target_size = target_size |
|
|
| source_channel_size = self.source_size[0] |
| target_channel_size = self.target_size[0] |
|
|
| self.adapter = nn.Sequential( |
| nn.Linear(source_channel_size, target_channel_size), |
| ) |
|
|
| def forward(self, x: torch.Tensor, backbone_no_cls: bool = False) -> torch.Tensor: |
        """Forward pass for the adapter. Uses the CLS token, so the backbone must keep it."""
        assert not backbone_no_cls, "LinearAdapterHead expects a backbone with a CLS token."
| |
| |
| x = x[:, 0] |
| x = self.adapter(x) |
| return x |
|
|
|
|
| class MLPAdapterHead(nn.Module): |
| """MLP Adapter module. |
| |
    Transforms features from source shape [B, (H_s*W_s), C_s] to target shape [B, (H_t*W_t), C_t].
| Will first do interpolation to match the spatial size [H_t, W_t], |
| followed by MLP to project to the target channel dimension [C_t]. |
| |
| Attributes: |
| source_size (tuple[int, ...] | torch.Size): the size of the source feature. [C, H, W] |
| target_size (tuple[int, ...] | torch.Size): the size of the target feature. [C, H, W] |
| adapter (nn.Module): the adapter module. |
| interpolation (nn.Module): interpolation to adjust sizes before MLP. |
| """ |
|
|
| def __init__( |
| self, source_size: tuple[int, ...] | torch.Size, target_size: tuple[int, ...] | torch.Size, num_layer: int |
| ): |
| """Initialization function for MLPAdapter. |
| |
| Args: |
| source_size (tuple[int, ...] | torch.Size): the size of the source feature. |
| target_size (tuple[int, ...] | torch.Size): the size of the target feature. |
| num_layer (int): number of MLP layers (One linear layer if num_layer = 1). |
| """ |
| super().__init__() |
        assert num_layer >= 1, f"`num_layer` in {self._get_name()} should be >= 1. Got {num_layer}."
|
|
| self.source_size = source_size |
| self.target_size = target_size |
|
|
| source_channel_size = self.source_size[0] |
| target_channel_size = self.target_size[0] |
|
|
| self.interpolation = nn.Sequential( |
| nn.Identity(), |
| ) |
| if self.source_size[1] != self.target_size[1]: |
| self.interpolation = nn.Sequential( |
| Rearrange("b (h w) c-> b c h w", h=self.source_size[1], w=self.source_size[2]), |
| Interpolation(self.target_size[1:]), |
| Rearrange("b c h w-> b (h w) c"), |
| ) |
|
|
| if num_layer == 1: |
| self.adapter = nn.Sequential( |
| nn.Linear(source_channel_size, target_channel_size), |
| ) |
| elif num_layer >= 2: |
| hidden_dim = source_channel_size * 2 |
| self.adapter = nn.Sequential( |
| nn.Linear(source_channel_size, hidden_dim), |
| *list( |
| chain.from_iterable([[nn.ReLU(), nn.Linear(hidden_dim, hidden_dim)] for _ in range(num_layer - 2)]) |
| ), |
| nn.ReLU(), |
| nn.Linear(hidden_dim, target_channel_size), |
| ) |
|
|
| def forward(self, x: torch.Tensor, backbone_no_cls: bool = False) -> torch.Tensor: |
| """Forward pass for the adapter. First interpolation then MLP.""" |
| |
| if not backbone_no_cls: |
| x = x[:, 1:] |
| |
| x = self.interpolation(x) |
| x = self.adapter(x) |
| return x |
|
|
|
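# Illustrative shapes for MLPAdapterHead (a minimal sketch; the 16x16 -> 32x32 sizes and the
# channel dimensions are arbitrary example values, not constraints of the class):
#
#   head = MLPAdapterHead(source_size=(384, 16, 16), target_size=(768, 32, 32), num_layer=2)
#   tokens = torch.randn(2, 1 + 16 * 16, 384)   # [B, 1+H*W, C] with a CLS token
#   out = head(tokens)                           # [2, 32 * 32, 768]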
|
| class ConvAdapterHead(nn.Module): |
| """Convolutional Adapter module. |
| |
    Transforms features from source shape [B, (H_s*W_s), C_s] to target shape [B, (H_t*W_t), C_t].
    Uses CNNs to map the channel and spatial sizes jointly.
    Note: only works for spatial sizes (16, 16), (any, any) with any <= 14, and (64, 64) for now.
| |
| Attributes: |
| source_size (tuple[int, ...] | torch.Size): the size of the source feature. |
| target_size (tuple[int, ...] | torch.Size): the size of the target feature. |
| adapter (nn.Module): the adapter module. |
| interpolation (nn.Module): interpolation to adjust sizes before MLP. |
| """ |
|
|
| def __init__( |
| self, |
| source_size: tuple[int, ...] | torch.Size, |
| target_size: tuple[int, ...] | torch.Size, |
| ): |
| """Initialization function for ConvAdapter. |
| |
| Args: |
| source_size (tuple[int, ...] | torch.Size): the size of the source feature. |
| target_size (tuple[int, ...] | torch.Size): the size of the target feature. |
| """ |
| super().__init__() |
| self.source_size = source_size |
| self.target_size = target_size |
|
|
| hidden_dim = self.source_size[0] * 2 |
| source_channel_size = self.source_size[0] |
| target_channel_size = self.target_size[0] |
|
|
| if self.source_size[1] < 12: |
| raise NotImplementedError("feature spatial size smaller than 12x12 is not supported.") |
| elif self.source_size[1] < 16: |
| self.pad = nn.Sequential( |
| Rearrange("b (h w) c-> b c h w", h=self.source_size[1], w=self.source_size[2]), |
| nn.ConvTranspose2d( |
| source_channel_size, |
| source_channel_size, |
| kernel_size=3, |
| stride=1, |
| output_padding=14 - self.source_size[1], |
| ), |
| ) |
| self.source_size = (self.source_size[0], 16, 16) |
| elif self.source_size[1] == 16 or self.source_size[1] == 64: |
| self.pad = nn.Sequential( |
| Rearrange("b (h w) c-> b c h w", h=self.source_size[1], w=self.source_size[2]), |
| ) |
| else: |
| raise NotImplementedError("feature spatial size (>=16x16) other than 16x16 and 64x64 is not supported.") |
|
|
| if self.source_size[1] < self.target_size[1]: |
| self.adapter = nn.Sequential( |
| nn.LayerNorm(self.source_size), |
| nn.ConvTranspose2d(source_channel_size, hidden_dim, kernel_size=3, stride=2, padding=1), |
| nn.ReLU(), |
| nn.LayerNorm([hidden_dim, 31, 31]), |
| nn.ConvTranspose2d(hidden_dim, hidden_dim, kernel_size=3, stride=2, output_padding=1), |
| nn.ReLU(), |
| nn.LayerNorm([hidden_dim, 64, 64]), |
| nn.ConvTranspose2d(hidden_dim, target_channel_size, kernel_size=3, stride=1, padding=1), |
| Rearrange("b c h w-> b (h w) c"), |
| ) |
| elif self.source_size[1] == self.target_size[1]: |
| self.adapter = nn.Sequential( |
| nn.LayerNorm(self.source_size), |
| nn.Conv2d(source_channel_size, hidden_dim, kernel_size=3, padding=1), |
| nn.ReLU(), |
| nn.LayerNorm([hidden_dim, *self.source_size[1:]]), |
| nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1), |
| nn.ReLU(), |
| nn.LayerNorm([hidden_dim, *self.source_size[1:]]), |
| nn.Conv2d(hidden_dim, target_channel_size, kernel_size=3, padding=1), |
| Rearrange("b c h w-> b (h w) c"), |
| ) |
| else: |
| self.adapter = nn.Sequential( |
| nn.LayerNorm(self.source_size), |
| nn.Conv2d(source_channel_size, hidden_dim, kernel_size=3, stride=2, padding=1), |
| nn.ReLU(), |
| nn.LayerNorm([hidden_dim, 32, 32]), |
| nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=2, padding=1), |
| nn.ReLU(), |
| nn.LayerNorm([hidden_dim, 16, 16]), |
| nn.Conv2d(hidden_dim, target_channel_size, kernel_size=3, padding=1), |
| Rearrange("b c h w-> b (h w) c"), |
| ) |
|
|
| def forward(self, x: torch.Tensor, backbone_no_cls: bool = False) -> torch.Tensor: |
| """Forward pass for ConvAdapter""" |
| |
| if not backbone_no_cls: |
| x = x[:, 1:] |
| |
| x = self.pad(x) |
| x = self.adapter(x) |
| return x |
|
|
|
|
| class LightConvAdapterHead(nn.Module): |
| """Light Convolutional Adapter module. |
| |
    Transforms features from source shape [B, (H_s*W_s), C_s] to target shape [B, (H_t*W_t), C_t].
    Uses CNNs to map the channel and spatial sizes jointly.
    Note: only works for source sizes (H_s, W_s) of (16, 16) or (any, any) with 12 <= any <= 14,
    and target sizes (H_t, W_t) of (16, 16) and (64, 64) for now.
| |
| Attributes: |
| source_size (tuple[int, ...] | torch.Size): the size of the source feature, |
| channel first (C, H, W). |
| target_size (tuple[int, ...] | torch.Size): the size of the target feature, |
| channel first (C, H, W). |
| adapter (nn.Module): the adapter module. |
| interpolation (nn.Module): interpolation to adjust sizes before MLP. |
| """ |
|
|
| def __init__( |
| self, |
| source_size: tuple[int, ...] | torch.Size, |
| target_size: tuple[int, ...] | torch.Size, |
| hidden_size_factor: int | float = 1.0, |
| ): |
        """Initialization function for LightConvAdapterHead.
| |
| Args: |
| source_size (tuple[int, ...] | torch.Size): the size of the source feature. |
| target_size (tuple[int, ...] | torch.Size): the size of the target feature. |
| hidden_size_factor (int | float): the size of hidden dim of feature translator |
| as a factor of input feature hidden dim. |
| """ |
| super().__init__() |
        if source_size[1] != source_size[2] or target_size[1] != target_size[2]:
            raise NotImplementedError(
                "Currently does not support non-square feature maps like source size "
                f"{source_size} and target size {target_size}."
            )
| self.source_size = source_size |
| self.target_size = target_size |
| self.hidden_size_factor = hidden_size_factor |
|
|
| hidden_dim = int(self.source_size[0] * hidden_size_factor) |
| source_channel_size = self.source_size[0] |
| target_channel_size = self.target_size[0] |
|
|
| if self.source_size[1] < 12: |
| raise NotImplementedError("feature spatial size smaller than 12x12 is not supported.") |
| elif self.source_size[1] < 16 and self.target_size[1] >= 16: |
| self.pad = nn.Sequential( |
| Rearrange("b (h w) c-> b c h w", h=self.source_size[1], w=self.source_size[2]), |
| nn.ConvTranspose2d( |
| source_channel_size, |
| source_channel_size, |
| kernel_size=3, |
| stride=1, |
| output_padding=14 - self.source_size[1], |
| ), |
| ) |
| self.source_size = (self.source_size[0], 16, 16) |
| elif (self.source_size[1] == 16 or self.source_size[1] == 64) or \ |
| (self.source_size[1] == 14 and self.target_size[1] == 14): |
| |
| self.pad = nn.Sequential( |
| Rearrange("b (h w) c-> b c h w", h=self.source_size[1], w=self.source_size[2]), |
| ) |
| elif self.target_size[1] < 14: |
| self.pad = nn.Sequential( |
| Rearrange("b (h w) c-> b c h w", h=self.source_size[1], w=self.source_size[2]), |
| ) |
| else: |
| raise NotImplementedError("feature spatial size larger than 16x16 (other than 64x64) is not supported.") |
|
|
| if self.source_size[1] == 16 and self.target_size[1] == 64: |
| self.adapter = nn.Sequential( |
| nn.LayerNorm(self.source_size), |
| nn.ConvTranspose2d(source_channel_size, hidden_dim, kernel_size=3, stride=2, padding=1), |
| nn.ReLU(), |
| nn.LayerNorm([hidden_dim, 31, 31]), |
| nn.ConvTranspose2d(hidden_dim, hidden_dim, kernel_size=3, stride=2, output_padding=1), |
| nn.ReLU(), |
| nn.LayerNorm([hidden_dim, 64, 64]), |
| Rearrange("b c h w-> b (h w) c"), |
| nn.Linear(hidden_dim, target_channel_size), |
| ) |
| elif self.source_size[1] == self.target_size[1]: |
| self.adapter = nn.Sequential( |
| nn.LayerNorm(self.source_size), |
| nn.Conv2d(source_channel_size, hidden_dim, kernel_size=3, padding=1), |
| nn.ReLU(), |
| nn.LayerNorm([hidden_dim, *self.source_size[1:]]), |
| nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1), |
| nn.ReLU(), |
| nn.LayerNorm([hidden_dim, *self.source_size[1:]]), |
| Rearrange("b c h w-> b (h w) c"), |
| nn.Linear(hidden_dim, target_channel_size), |
| ) |
| elif self.source_size[1] == 64 and self.target_size[1] == 16: |
| self.adapter = nn.Sequential( |
| nn.LayerNorm(self.source_size), |
| nn.Conv2d(source_channel_size, hidden_dim, kernel_size=3, stride=2, padding=1), |
| nn.ReLU(), |
| nn.LayerNorm([hidden_dim, 32, 32]), |
| nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=2, padding=1), |
| nn.ReLU(), |
| nn.LayerNorm([hidden_dim, 16, 16]), |
| Rearrange("b c h w-> b (h w) c"), |
| nn.Linear(hidden_dim, target_channel_size), |
| ) |
| elif self.target_size[1] == 7: |
| self.adapter = nn.Sequential( |
| nn.LayerNorm(self.source_size), |
| nn.Conv2d(source_channel_size, hidden_dim, kernel_size=4, stride=2, padding=1), |
| nn.ReLU(), |
| nn.LayerNorm([hidden_dim, 7, 7]), |
| Rearrange("b c h w-> b (h w) c"), |
| nn.Linear(hidden_dim, target_channel_size) |
| ) |
| else: |
            raise NotImplementedError(f"{self.source_size} to {self.target_size} is not supported.")
|
|
| def forward(self, x: torch.Tensor, backbone_no_cls: bool = False) -> torch.Tensor: |
        """Forward pass for LightConvAdapterHead."""
| |
| if not backbone_no_cls: |
| x = x[:, 1:] |
| x = self.pad(x) |
| x = self.adapter(x) |
| return x |
|
|
|
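# Illustrative shapes for LightConvAdapterHead (a minimal sketch; a 16x16 -> 64x64 upsampling
# head with example channel sizes, mirroring the branch structure above):
#
#   head = LightConvAdapterHead(source_size=(1024, 16, 16), target_size=(1280, 64, 64))
#   tokens = torch.randn(2, 1 + 16 * 16, 1024)   # [B, 1+H*W, C] with a CLS token
#   out = head(tokens)                            # [2, 64 * 64, 1280]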
|
| class FeatureTranslator(nn.Module): |
| """Base class for the feature translator. |
| |
| The flow is backbone_adapter -> translator_stem -> translator_heads. |
| |
| Attributes: |
| backbone_feature_size (torch.Size): the size of features of the backbone. |
| target_feature_sizes (dict[str, torch.Size | tuple[int, ...]]): the sizes of features of target models. |
        translator_hidden_size (int): the hidden dim of the translator. Defaults to 1024.
| target_model_names (list[str]): convenient attribute to hold all the names of the target models. |
| |
| backbone_adapter (nn.Module): the adapter to map channel dim of backbone to the translator hidden dim. |
| translator_stem (nn.Module): the shared stem for all target models. |
| translator_heads (nn.ModuleDict): specific heads for different target models. |
| """ |
|
|
| def __init__( |
| self, |
| backbone_feature_size: torch.Size, |
| target_feature_sizes: dict[str, torch.Size | tuple[int, ...]], |
| translator_hidden_size: int = 1024, |
| ) -> None: |
        """Initialization function for FeatureTranslator.
| |
| Args: |
| backbone_feature_size (torch.Size): the size of features of the backbone. |
| target_feature_sizes (dict[str, torch.Size | tuple[int, ...]]): the sizes of features of target models. |
            translator_hidden_size (int): the hidden dim of the translator. Defaults to 1024.
| """ |
| super().__init__() |
| self.backbone_feature_size = backbone_feature_size |
| self.target_feature_sizes = target_feature_sizes |
| self.translator_hidden_size = translator_hidden_size |
| self.target_model_names = list(target_feature_sizes.keys()) |
| self.legit_target_model_name_map: dict[str, str] = {t: t.replace(".", "_") for t in self.target_model_names} |
| self.translator_heads: nn.ModuleDict = None |
|
|
| self.backbone_adapter = nn.Sequential( |
| nn.LayerNorm(self.backbone_feature_size[0]), |
| nn.Linear( |
| self.backbone_feature_size[0], |
| self.translator_hidden_size, |
| ), |
| ) |
| self.translator_stem: nn.Module = nn.Identity() |
| self.build_translator_heads() |
|
|
| def build_translator_heads(self) -> None: |
| """Build translator heads to match the dimension of each target feature set. |
| |
| Example: |
| translator_heads: dict[str, nn.Module] = ... |
| self.translator_heads = nn.ModuleDict(translator_heads) |
| """ |
| raise NotImplementedError("build_translator_heads() should be overridden") |
|
|
| def forward( |
| self, x: torch.Tensor, target_model_names: Optional[list[str]] = None, backbone_no_cls: bool = False |
    ) -> dict[str, torch.Tensor]:
| """Forward pass for a base feature translator. |
| |
| Args: |
| x (torch.Tensor): input features from the backbone. [B, (1)+H*W, C]. |
| (1) means optional CLS token. If `backbone_no_cls==True`, then [B, H*W, C]. |
| target_model_names (Optional[list[str]]): names of the target models. |
| backbone_no_cls (bool): indicate backbone has cls token or not. |
| Can use it to customize whether to drop cls. |
| |
| Returns: |
| dict[str, torch.Tensor]: predicted features for target models. |
| """ |
| |
| x = self.backbone_adapter(x) |
| x = self.translator_stem(x) |
| target_model_names = target_model_names if target_model_names is not None else self.target_model_names |
| features = {t: self.translator_heads[self.legit_target_model_name_map[t]](x, backbone_no_cls=backbone_no_cls) for t in target_model_names} |
| return features |
|
|
|
|
| class MLPFeatureTranslator(FeatureTranslator): |
| def __init__( |
| self, |
| backbone_feature_size: torch.Size, |
| target_feature_sizes: dict[str, torch.Size | tuple[int, ...]], |
| translator_hidden_size: int = 1024, |
| translator_n_layer: int = 3, |
| ) -> None: |
        """Initialization function for MLPFeatureTranslator.
| |
| Args: |
| backbone_feature_size (torch.Size): the size of features of the backbone. |
| target_feature_sizes (dict[str, torch.Size | tuple[int, ...]]): the sizes of features of target models. |
            translator_hidden_size (int): the hidden dim of the translator. Defaults to 1024.
| translator_n_layer (int): number of MLP layers. Defaults to 3. |
| """ |
| self.translator_n_layer = translator_n_layer |
|
|
| super().__init__( |
| backbone_feature_size=backbone_feature_size, |
| target_feature_sizes=target_feature_sizes, |
| translator_hidden_size=translator_hidden_size, |
| ) |
|
|
    def build_translator_heads(self) -> None:
        """Build MLP translator heads to match the dimension of each target feature set."""
| translator_heads = {} |
| source_size = (self.translator_hidden_size, *self.backbone_feature_size[1:]) |
| for target_model, target_size in self.target_feature_sizes.items(): |
| head = MLPAdapterHead(source_size=source_size, target_size=target_size, num_layer=self.translator_n_layer) |
| translator_heads[self.legit_target_model_name_map[target_model]] = head |
| self.translator_heads = nn.ModuleDict(translator_heads) |
|
|
|
|
| class ConvFeatureTranslator(FeatureTranslator): |
| def __init__( |
| self, |
| backbone_feature_size: torch.Size, |
| target_feature_sizes: dict[str, torch.Size | tuple[int, ...]], |
| translator_hidden_size: int = 1024, |
| ) -> None: |
        """Initialization function for ConvFeatureTranslator.
| |
| Args: |
| backbone_feature_size (torch.Size): the size of features of the backbone. |
| target_feature_sizes (dict[str, torch.Size | tuple[int, ...]]): the sizes of features of target models. |
            translator_hidden_size (int): the hidden dim of the translator. Defaults to 1024.
| """ |
| super().__init__( |
| backbone_feature_size=backbone_feature_size, |
| target_feature_sizes=target_feature_sizes, |
| translator_hidden_size=translator_hidden_size, |
| ) |
|
|
    def build_translator_heads(self) -> None:
        """Build convolutional translator heads to match the dimension of each target feature set."""
| translator_heads = {} |
| source_size = (self.translator_hidden_size, *self.backbone_feature_size[1:]) |
| for target_model, target_size in self.target_feature_sizes.items(): |
| head = ConvAdapterHead(source_size=source_size, target_size=target_size) |
| translator_heads[self.legit_target_model_name_map[target_model]] = head |
| self.translator_heads = nn.ModuleDict(translator_heads) |
|
|
|
|
| class LightConvFeatureTranslator(FeatureTranslator): |
| def __init__( |
| self, |
| backbone_feature_size: torch.Size, |
| target_feature_sizes: dict[str, torch.Size | tuple[int, ...]], |
| translator_hidden_size: int = 1024, |
| hidden_size_factor: int | float = 1.0, |
| ) -> None: |
        """Initialization function for LightConvFeatureTranslator.
        It is a lighter translator compared to ConvFeatureTranslator.
| |
| Args: |
| backbone_feature_size (torch.Size): the size of features of the backbone. |
| target_feature_sizes (dict[str, torch.Size | tuple[int, ...]]): the sizes of features of target models. |
            translator_hidden_size (int): the hidden dim of the translator. Defaults to 1024.
            hidden_size_factor (int | float): the size of the hidden dim of the feature translator
                as a factor of the input feature hidden dim. Defaults to 1.0.
| """ |
| self.hidden_size_factor = hidden_size_factor |
| super().__init__( |
| backbone_feature_size=backbone_feature_size, |
| target_feature_sizes=target_feature_sizes, |
| translator_hidden_size=translator_hidden_size, |
| ) |
| self.backbone_adapter = nn.Identity() |
|
|
    def build_translator_heads(self) -> None:
        """Build light convolutional translator heads to match the dimension of each target feature set."""
| translator_heads = {} |
| for target_model, target_size in self.target_feature_sizes.items(): |
| if "_cls" in target_model: |
| head = LinearAdapterHead( |
| source_size=self.backbone_feature_size, |
| target_size=target_size |
| ) |
| else: |
| head = LightConvAdapterHead( |
| source_size=self.backbone_feature_size, |
| target_size=target_size, |
| hidden_size_factor=self.hidden_size_factor |
| ) |
| translator_heads[self.legit_target_model_name_map[target_model]] = head |
| self.translator_heads = nn.ModuleDict(translator_heads) |
|
|
|
|
class TransformerFeatureTranslator(FeatureTranslator):
| def __init__( |
| self, |
| backbone_feature_size: torch.Size, |
| target_feature_sizes: dict[str, torch.Size | tuple[int, int]], |
| translator_hidden_size: int = 1024, |
| translator_n_layers: int = 2, |
| translator_n_heads: int = 8, |
| translator_activation: str = "gelu", |
| ) -> None: |
| super().__init__( |
| backbone_feature_size=backbone_feature_size, |
| target_feature_sizes=target_feature_sizes, |
| translator_hidden_size=translator_hidden_size, |
| ) |
|
|
| self.translator_stem = nn.TransformerDecoder( |
| nn.TransformerDecoderLayer( |
| d_model=translator_hidden_size, |
| nhead=translator_n_heads, |
| dim_feedforward=translator_hidden_size * 2, |
| activation=translator_activation, |
| batch_first=True, |
| norm_first=True, |
| ), |
| num_layers=translator_n_layers, |
| ) |
|
|
| self.decode_tokens = nn.Parameter( |
| torch.randn((1, math.prod(self.backbone_feature_size[1:]), translator_hidden_size)) |
| ) |
|
|
| self.target_model_emb = nn.ParameterDict( |
| { |
| self.legit_target_model_name_map[t]: torch.randn(1, 1, translator_hidden_size) |
| for t in self.target_model_names |
| } |
| ) |
|
|
| def build_translator_heads(self) -> None: |
| """Build Transformer translator heads to match the dimension of each target feature set.""" |
| translator_heads = {} |
| for target_model, target_size in self.target_feature_sizes.items(): |
| head = MLPAdapterHead( |
| source_size=(self.translator_hidden_size, *self.backbone_feature_size[1:]), |
| target_size=target_size, |
| num_layer=2, |
| ) |
| translator_heads[self.legit_target_model_name_map[target_model]] = head |
| self.translator_heads = nn.ModuleDict(translator_heads) |
|
|
| def forward( |
| self, x: torch.Tensor, target_model_names: Optional[list[str]] = None, backbone_no_cls: bool = False |
    ) -> dict[str, torch.Tensor]:
        """Forward pass for the transformer feature translator.
| |
| Args: |
| x (torch.Tensor): input features from the backbone. |
| target_model_names (Optional[str]): names of the target models. |
| backbone_no_cls (bool): indicate backbone has cls token or not. |
| Can use it to customize whether to drop cls. |
| |
| Returns: |
| dict[str, torch.Tensor]: predicted features for target models. |
| """ |
| if not backbone_no_cls: |
| x = x[:, 1:] |
| x = self.backbone_adapter(x) |
| features = {} |
| target_model_names = target_model_names if target_model_names is not None else self.target_model_names |
| for t in target_model_names: |
| feature = self.translator_stem( |
| torch.cat( |
| [ |
| self.decode_tokens.repeat(x.size(0), 1, 1), |
| self.target_model_emb[self.legit_target_model_name_map[t]].repeat(x.size(0), 1, 1), |
| ], |
| dim=1, |
| ), |
| memory=x, |
| )[:, 1:, ...] |
| features[t] = self.translator_heads[self.legit_target_model_name_map[t]](feature) |
| return features |
|
|
|
|
| def build_feature_translator(translator_type: str, **kwargs: Any) -> FeatureTranslator: |
    """Handy function to build a feature translator given its type.
| |
| Args: |
        translator_type (str): the type of the translator,
            one of `"mlp"`, `"conv"`, `"lconv"`, or `"transformer"` (alias `"trans"`).
            At the moment we are actively using `"lconv"`.
        kwargs (Any): keyword arguments forwarded to the chosen translator's constructor.
| |
| Returns: |
| FeatureTranslator: the corresponding FeatureTranslator |
| """ |
| if translator_type == "mlp": |
| return MLPFeatureTranslator(**kwargs) |
| elif translator_type == "conv": |
| return ConvFeatureTranslator(**kwargs) |
| elif translator_type == "lconv": |
| return LightConvFeatureTranslator(**kwargs) |
| elif translator_type == "transformer" or translator_type == "trans": |
        return TransformerFeatureTranslator(**kwargs)
| else: |
| raise NotImplementedError(f"Requested {translator_type} is not implemented yet.") |
|
|
|
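# Illustrative construction of the "lconv" translator (a minimal sketch; the backbone and
# target feature sizes are example values in channel-first (C, H, W) order):
#
#   translator = build_feature_translator(
#       "lconv",
#       backbone_feature_size=(384, 14, 14),
#       target_feature_sizes={"facebook/dinov2-large": (1024, 16, 16)},
#       hidden_size_factor=1.0,
#   )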
|
| class TheiaConfig(PretrainedConfig): |
| def __init__( |
| self, |
| backbone: str | nn.Module = "facebook/deit-tiny-patch16-224", |
| pretrained: bool = False, |
| target_feature_sizes: Optional[dict[str, torch.Size | tuple[int, ...]]] = None, |
| translator_type: str = "lconv", |
| translator_hidden_size_factor: float | int = 1.0, |
| target_loss_weights: Optional[dict[str, float]] = None, |
| feature_reduce_method: Optional[str] = None, |
| feature_neck: bool = False, |
| feature_neck_hidden_dim: int = 256, |
| forward_neck: bool = False, |
| feature_neck_nonlinearity: str = "relu", |
        image_size: int = 224,
| num_reg_tokens: int = 0, |
| **kwargs: Any |
| ): |
| self.backbone = backbone |
| self.pretrained = pretrained |
| self.target_feature_sizes = target_feature_sizes |
| self.translator_type = translator_type |
| self.translator_hidden_size_factor = translator_hidden_size_factor |
| self.target_loss_weights = target_loss_weights |
| self.feature_reduce_method = feature_reduce_method |
| self.feature_neck = feature_neck |
| self.feature_neck_hidden_dim = feature_neck_hidden_dim |
| self.forward_neck = forward_neck |
| self.feature_neck_nonlinearity = feature_neck_nonlinearity |
        self.image_size = image_size
| self.num_reg_tokens = num_reg_tokens |
| super().__init__(**kwargs) |
|
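# Illustrative TheiaConfig (a minimal sketch; the backbone name and the single DINOv2 target
# size are example values, and target feature sizes are channel-first (C, H, W)):
#
#   config = TheiaConfig(
#       backbone="facebook/deit-small-patch16-224",
#       pretrained=False,
#       target_feature_sizes={"facebook/dinov2-large": (1024, 16, 16)},
#       translator_type="lconv",
#   )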
|
| class TheiaModel(PreTrainedModel): |
| config_class = TheiaConfig |
|
|
| def __init__(self, config: TheiaConfig): |
| super().__init__(config) |
|
|
| self.target_feature_sizes = config.target_feature_sizes |
| self.preprocessor = None |
| self.pretrained = config.pretrained |
|
|
| |
| self.image_size = config.image_size |
| if "reg" in config.backbone: |
            self.backbone: nn.Module = build_backbone(
                config.backbone, config.pretrained, image_size=config.image_size, num_reg_tokens=config.num_reg_tokens
            )
| else: |
| self.backbone: nn.Module = build_backbone(config.backbone, config.pretrained, image_size=config.image_size) |
|
|
| |
| self.feature_reduce_method = config.feature_reduce_method |
| self.no_cls = hasattr(self.backbone, "no_cls") |
| self.num_reg_tokens = self.backbone.num_reg_tokens if hasattr(self.backbone, "num_reg_tokens") else 0 |
|
|
| |
| backbone_feature_size = self.backbone.get_feature_size(keep_spatial=True) |
| if self.target_feature_sizes: |
| translator_kwargs = { |
| "hidden_size_factor": config.translator_hidden_size_factor |
| } |
| translator_kwargs["backbone_feature_size"] = backbone_feature_size |
| translator_kwargs["target_feature_sizes"] = config.target_feature_sizes |
| self.translator = build_feature_translator( |
| config.translator_type, **translator_kwargs |
| ) |
| else: |
| self.translator = None |
|
|
| self.feature_neck = config.feature_neck |
| self.feature_neck_hidden_dim = config.feature_neck_hidden_dim |
| self.forward_neck = config.forward_neck |
| if self.feature_neck: |
| num_tokens_edge = self.backbone.model.config.image_size // self.backbone.model.config.patch_size |
| self.neck = nn.Sequential( |
| Rearrange("b (h w) c -> b c h w", h=num_tokens_edge, w=num_tokens_edge), |
| nn.Conv2d(self.backbone.model.config.hidden_size, self.feature_neck_hidden_dim, kernel_size=4, stride=2, padding=1), |
| nn.ReLU() if config.feature_neck_nonlinearity == 'relu' else nn.Tanh(), |
| nn.Conv2d(self.feature_neck_hidden_dim, self.feature_neck_hidden_dim, kernel_size=3, stride=2), |
| nn.ReLU() if config.feature_neck_nonlinearity == 'relu' else nn.Tanh(), |
| nn.Conv2d(self.feature_neck_hidden_dim, self.feature_neck_hidden_dim, kernel_size=3, stride=1), |
| nn.ReLU() if config.feature_neck_nonlinearity == 'relu' else nn.Tanh(), |
| nn.Flatten() |
| ) |
| else: |
| self.neck = None |
|
|
| |
| self.mse_loss = nn.MSELoss() |
| self.l1_loss = nn.SmoothL1Loss() |
| self.cos_loss = nn.CosineEmbeddingLoss() |
| self.cos_target = torch.ones((1), dtype=torch.int, requires_grad=False) |
| self.target_loss_weights = config.target_loss_weights |
|
|
| def load_pretrained_weights(self, checkpoint_path: str) -> None: |
| """ |
| Load weights from `checkpoint_path` manually. |
| |
| Args: |
| checkpoint_path (str): path to the weights. |
| """ |
| |
| if checkpoint_path: |
| weights_dict = torch.load(checkpoint_path, map_location="cpu") |
| |
| pretrained_dict = {k: v for k, v in weights_dict.items() if k in self.state_dict()} |
| self.load_state_dict(pretrained_dict, strict=False) |
|
|
| def freeze_translator(self) -> None: |
| """Freeze feature translators `self.translator`.""" |
| if self.translator is not None: |
| for param in self.translator.parameters(): |
| param.requires_grad = False |
|
|
| def freeze_backbone(self) -> None: |
| """Freeze backbone (encoder) `self.backbone`. """ |
| self.freeze_encoder() |
|
|
| def freeze_encoder(self) -> None: |
| """Freeze backbone (encoder) `self.backbone`. """ |
| for param in self.backbone.parameters(): |
| param.requires_grad = False |
|
|
| def freeze_neck(self) -> None: |
| """Freeze feature neck `self.neck`.""" |
| if self.neck is not None: |
| for param in self.neck.parameters(): |
| param.requires_grad = False |
| |
| def freeze_everything(self) -> None: |
| """Freeze all parameters in the model.""" |
| self.freeze_translator() |
| self.freeze_neck() |
| self.freeze_encoder() |
|
|
| def unfreeze_translator(self) -> None: |
| if self.translator is not None: |
| for param in self.translator.parameters(): |
| param.requires_grad = True |
|
|
| def unfreeze_backbone(self) -> None: |
| "Set parameters in backbone (encoder) `self.backbone` trainable." |
| self.unfreeze_encoder() |
|
|
| def unfreeze_encoder(self) -> None: |
| "Set parameters in backbone (encoder) `self.backbone` trainable." |
| for param in self.backbone.parameters(): |
| param.requires_grad = True |
|
|
| def unfreeze_neck(self) -> None: |
| "Set parameters in feature neck `self.neck` trainable." |
| if self.neck is not None: |
| for param in self.neck.parameters(): |
| param.requires_grad = True |
| |
| def unfreeze_everything(self) -> None: |
| """Set all parameters trainable.""" |
| self.unfreeze_translator() |
| self.unfreeze_neck() |
| self.unfreeze_encoder() |
|
|
| def set_forward_neck(self, forward_neck: bool = True) -> None: |
| """ |
| Set `self.forward_neck` to `forward_neck` value. |
| |
| Args: |
| forward_neck (bool): whether forward the feature through the random initialized neck. |
| If set to True, the output from `self.forward()` will be in shape [batch_size, self.config.feature_neck_hidden_dim] |
| """ |
| self.forward_neck = forward_neck |
|
|
| def forward_feature(self, x: torch.Tensor, **kwargs: Any) -> torch.Tensor: |
| """Forward RVFM feature only (before translators). |
| |
| Args: |
| x (torch.Tensor): input image. By default it accepts images |
| in shape [B, H, W, C] or [B, C, H, W], pixel range [0,255], torch.uint8. |
| kwargs (Any): kwargs including mainly those for huggingface preprocessor: |
| `do_resize` (bool) defaults to True. |
| `interpolate_pos_encoding` (Optional[bool]) defaults to None. |
| `do_rescale` (bool) defaults to True. |
| `do_normalize` (bool) defaults to True. |
| |
| Returns: |
| torch.Tensor: RVFM feature. |
| """ |
| feature = self.backbone(x, **kwargs) |
| |
| |
| |
| return handle_feature_output(feature, num_discard_tokens=self.num_reg_tokens) |
|
|
| def forward(self, x: torch.Tensor, target_model_names: Optional[list[str]] = None, **kwargs: Any) -> dict[str, torch.Tensor] | torch.Tensor: |
| """Forward pass of Robot Vision Foundation Model. |
| |
| Args: |
| x (torch.Tensor): input image. By default it accepts images |
| in shape [B, H, W, C] or [B, C, H, W], pixel range [0,255], torch.uint8. |
| target_model_names (Optional[list[str]]): names of the target foundation models. |
| kwargs (Any): kwargs including mainly those for huggingface preprocessor: |
| `do_resize` (bool) defaults to True. |
| `interpolate_pos_encoding` (Optional[bool]) defaults to None. |
| `do_rescale` (bool) defaults to True. |
| `do_normalize` (bool) defaults to True. |
| |
| Returns: |
| if `self.forward_neck`: |
| torch.Tensor: compact vector feature passed through the neck. [B, C_neck] |
| else: |
| dict[str, torch.Tensor]: features that match to each foundation model. |
| Each feature is in [B, (H*W), C] or [B, C]. |
| """ |
| if self.forward_neck: |
| x = self.forward_feature(x) |
| return self.neck(x) |
| else: |
| x = self.backbone(x, **kwargs) |
| if self.num_reg_tokens > 0: |
| x = x[:, :-self.num_reg_tokens] |
| features = self.translator(x, target_model_names, backbone_no_cls=self.no_cls) |
| return features |
|
|
| def get_loss(self, pred_features: dict[str, torch.Tensor], y: dict[str, torch.Tensor]) -> dict[str, Any]: |
| """Get loss terms given predictions and targets. |
| |
| Args: |
| pred_features (dict[str, torch.Tensor]): predictions. |
| y (dict[str, torch.Tensor]): targets. |
| |
| Returns: |
            dict[str, Any]: averaged loss terms and per-model loss values.
| """ |
| mse_loss_avg, cos_loss_avg, l1_loss_avg = 0, 0, 0 |
| mse_losses_per_model = {} |
| cos_losses_per_model = {} |
| l1_losses_per_model = {} |
|
|
| for t in pred_features: |
| pred = pred_features[t] |
| target = y[t] |
|
|
| |
| mse_loss = self.mse_loss(pred, target) |
            weight = self.target_loss_weights[t] if self.target_loss_weights else 1.0 / len(pred_features)
|
|
| |
| l1_loss = self.l1_loss(pred, target) |
|
|
| |
| pred_norm = F.normalize(pred.flatten(start_dim=1), dim=1, p=2) |
| target_norm = F.normalize(target.flatten(start_dim=1), dim=1, p=2) |
            cos_target = self.cos_target.repeat(pred.size(0)).to(pred.device)  # avoid shadowing the feature target
            cos_loss = self.cos_loss(pred_norm, target_norm, cos_target)
|
|
| mse_loss_avg += mse_loss * weight |
| cos_loss_avg += cos_loss / len(pred_features) |
| l1_loss_avg += l1_loss * weight |
|
|
| mse_losses_per_model[t] = mse_loss.item() |
| cos_losses_per_model[t] = cos_loss.item() |
| l1_losses_per_model[t] = l1_loss.item() |
|
|
| return { |
| "mse_loss": mse_loss_avg, |
| "cos_loss": cos_loss_avg, |
| "l1_loss": l1_loss_avg, |
| "mse_losses_per_model": mse_losses_per_model, |
| "cos_losses_per_model": cos_losses_per_model, |
| "l1_losses_per_model": l1_losses_per_model, |
| } |
|
|
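# Illustrative end-to-end use of TheiaModel (a minimal sketch; it assumes a config like the
# example above, and the random targets stand in for cached teacher features of matching shapes):
#
#   model = TheiaModel(config)
#   images = torch.zeros((2, 224, 224, 3), dtype=torch.uint8)
#   predictions = model(images)                    # dict: target model name -> predicted feature
#   targets = {k: torch.randn_like(v) for k, v in predictions.items()}
#   losses = model.get_loss(predictions, targets)  # {"mse_loss": ..., "cos_loss": ..., "l1_loss": ...}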