import torch
from torch import nn
from transformers import PreTrainedModel
from transformers.models.dinov2.modeling_dinov2 import Dinov2Encoder
from .configuration_radzero import AlignTransformerConfig


def build_align_transformer(config):
    """Factory that builds the alignment module for ``config.model_type``."""
    if config.model_type == "align_transformer":
        model = AlignTransformer(config)
    else:
        raise NotImplementedError(f"Unsupported model_type: {config.model_type!r}")
    return model


class AlignTransformer(PreTrainedModel):
    """Refines vision tokens with optional DINOv2 encoder layers and a final layer norm."""

    config_class = AlignTransformerConfig

    def __init__(self, config: AlignTransformerConfig):
        super().__init__(config)
        # Optional projector applied to patch tokens only; not instantiated here.
        self.projector = None
        # Zero hidden layers means the encoder stage is skipped entirely.
        if config.num_hidden_layers:
            self.transformer_layers = Dinov2Encoder(config)
        else:
            self.transformer_layers = None
        if config.use_layer_norm:
            self.layer_norm = nn.LayerNorm(config.hidden_size)
        else:
            self.layer_norm = None

    def forward(self, vision_tokens):
        # vision_tokens: (batch, 1 + num_patches, hidden_size), [CLS] token first.
        if self.projector is not None:
            cls_token = vision_tokens[:, :1]
            patch_tokens = vision_tokens[:, 1:]
            # Project the patch tokens only, then re-attach the [CLS] token.
            patch_tokens = self.projector(patch_tokens)["last_hidden_state"]
            vision_tokens = torch.cat([cls_token, patch_tokens], dim=1)
        if self.transformer_layers is not None:
            vision_tokens = self.transformer_layers(vision_tokens)["last_hidden_state"]
        if self.layer_norm is not None:
            vision_tokens = self.layer_norm(vision_tokens)
        return vision_tokens
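

# Minimal usage sketch (illustrative only, kept as comments because the relative
# import above prevents running this module directly). The keyword arguments
# below are assumptions about AlignTransformerConfig; the real config in
# .configuration_radzero may use different names or defaults, and Dinov2Encoder
# additionally expects the usual DINOv2 attention/MLP fields.
#
#   config = AlignTransformerConfig(
#       num_hidden_layers=2,
#       hidden_size=768,
#       use_layer_norm=True,
#   )
#   model = build_align_transformer(config)
#   # One [CLS] token followed by 16 x 16 = 256 patch tokens.
#   vision_tokens = torch.randn(2, 1 + 256, config.hidden_size)
#   refined = model(vision_tokens)  # shape preserved: (2, 257, 768)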