import torch
from torch import nn
from transformers import PreTrainedModel
from transformers.models.dinov2.modeling_dinov2 import Dinov2Encoder

from .configuration_radzero import AlignTransformerConfig


def build_align_transformer(config):
    """Build an alignment model from *config*.

    Args:
        config: A configuration object whose ``model_type`` selects the
            architecture. Only ``"align_transformer"`` is supported.

    Returns:
        An :class:`AlignTransformer` instance.

    Raises:
        NotImplementedError: If ``config.model_type`` is anything other
            than ``"align_transformer"``.
    """
    if config.model_type == "align_transformer":
        model = AlignTransformer(config)
    else:
        raise NotImplementedError()
    return model


class AlignTransformer(PreTrainedModel):
    """Optional transformer stack + layer norm applied to vision tokens.

    Each sub-module is independently optional, controlled by the config:
    a DINOv2 encoder (when ``config.num_hidden_layers`` is truthy) and a
    final :class:`~torch.nn.LayerNorm` (when ``config.use_layer_norm``).
    """

    def __init__(self, config: AlignTransformerConfig):
        super().__init__(config)
        # NOTE(review): never assigned a module in this file — presumably
        # attached externally after construction; confirm against callers.
        self.projector = None
        # A falsy num_hidden_layers (0/None) disables the encoder entirely.
        if config.num_hidden_layers:
            self.transformer_layers = Dinov2Encoder(config)
        else:
            self.transformer_layers = None
        if config.use_layer_norm:
            self.layer_norm = nn.LayerNorm(config.hidden_size)
        else:
            self.layer_norm = None

    def forward(self, vision_tokens):
        """Run the configured sub-modules over *vision_tokens*.

        Args:
            vision_tokens: Token tensor; the slicing below assumes layout
                ``(batch, 1 + num_patches, hidden)`` with the CLS token
                first — TODO confirm against the caller.

        Returns:
            The (possibly projected, encoded, and normalized) token tensor
            with the same leading layout.
        """
        if self.projector is not None:
            # Project only the patch tokens; the CLS token bypasses the
            # projector and is re-attached afterwards.
            cls_token = vision_tokens[:, :1]
            patch_tokens = vision_tokens[:, 1:]
            patch_tokens = self.projector(patch_tokens)["last_hidden_state"]
            vision_tokens = torch.cat([cls_token, patch_tokens], dim=1)
        if self.transformer_layers is not None:
            vision_tokens = self.transformer_layers(vision_tokens)["last_hidden_state"]
        if self.layer_norm is not None:
            vision_tokens = self.layer_norm(vision_tokens)
        return vision_tokens