| import torch | |
| from torch import nn | |
| from transformers import PreTrainedModel | |
| from transformers.models.dinov2.modeling_dinov2 import Dinov2Encoder | |
| from .configuration_radzero import AlignTransformerConfig | |
def build_align_transformer(config):
    """Instantiate the alignment model selected by *config*.

    Only ``config.model_type == "align_transformer"`` is supported; any
    other value raises ``NotImplementedError``.
    """
    if config.model_type != "align_transformer":
        raise NotImplementedError()
    return AlignTransformer(config)
class AlignTransformer(PreTrainedModel):
    """Refines a sequence of vision tokens with three optional stages:
    a projector (over patch tokens only), a Dinov2 transformer encoder,
    and a final LayerNorm. Each stage is skipped when its module is None.
    """

    def __init__(self, config: AlignTransformerConfig):
        super().__init__(config)
        # No projector is built here; forward() guards on it, so it is
        # presumably attached externally after construction — TODO confirm.
        self.projector = None
        # Truthy num_hidden_layers enables the encoder stage.
        self.transformer_layers = Dinov2Encoder(config) if config.num_hidden_layers else None
        self.layer_norm = nn.LayerNorm(config.hidden_size) if config.use_layer_norm else None

    def forward(self, vision_tokens):
        """Apply the enabled stages in order and return the refined tokens.

        Assumes vision_tokens is (batch, 1 + num_patches, ...) with the
        leading token kept out of the projector — TODO confirm with caller.
        """
        if self.projector is not None:
            # Split off the first token; only the remaining tokens are projected.
            lead, patches = vision_tokens[:, :1], vision_tokens[:, 1:]
            patches = self.projector(patches)["last_hidden_state"]
            vision_tokens = torch.cat([lead, patches], dim=1)
        if self.transformer_layers is not None:
            vision_tokens = self.transformer_layers(vision_tokens)["last_hidden_state"]
        if self.layer_norm is not None:
            vision_tokens = self.layer_norm(vision_tokens)
        return vision_tokens