import torch
from torch import nn
from transformers import PreTrainedModel
from transformers.models.dinov2.modeling_dinov2 import Dinov2Encoder

from .configuration_radzero import AlignTransformerConfig


def build_align_transformer(config):
    """Instantiate the alignment module matching ``config.model_type``."""
    if config.model_type == "align_transformer":
        model = AlignTransformer(config)
    else:
        raise NotImplementedError(f"Unsupported model_type: {config.model_type}")

    return model


class AlignTransformer(PreTrainedModel):
    config_class = AlignTransformerConfig

    def __init__(self, config: AlignTransformerConfig):
        super().__init__(config)

        # Optional projector over the patch tokens; left unset here and
        # only applied in forward() if assigned externally.
        self.projector = None

        # Optional stack of DINOv2-style transformer layers.
        if config.num_hidden_layers:
            self.transformer_layers = Dinov2Encoder(config)
        else:
            self.transformer_layers = None

        # Optional final layer norm over the aligned tokens.
        if config.use_layer_norm:
            self.layer_norm = nn.LayerNorm(config.hidden_size)
        else:
            self.layer_norm = None

    def forward(self, vision_tokens):
        # Project the patch tokens while leaving the [CLS] token untouched.
        if self.projector is not None:
            cls_token = vision_tokens[:, :1]
            patch_tokens = vision_tokens[:, 1:]

            patch_tokens = self.projector(patch_tokens)["last_hidden_state"]
            vision_tokens = torch.cat([cls_token, patch_tokens], dim=1)

        # Refine the full token sequence with the DINOv2 encoder layers.
        if self.transformer_layers is not None:
            vision_tokens = self.transformer_layers(vision_tokens)["last_hidden_state"]

        # Normalize the aligned tokens before returning them.
        if self.layer_norm is not None:
            vision_tokens = self.layer_norm(vision_tokens)

        return vision_tokens
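

# --- Usage sketch (illustrative; not part of the original module). A minimal
# smoke test under these assumptions: AlignTransformerConfig subclasses a
# Dinov2-style PretrainedConfig (so encoder defaults such as
# num_attention_heads exist), it sets model_type = "align_transformer", and
# the field values below are placeholders rather than values from this
# repository. Because of the relative import above, run this as a module
# inside its package (python -m <package>.modeling_radzero) rather than as a
# standalone script.
if __name__ == "__main__":
    config = AlignTransformerConfig(
        hidden_size=768,
        num_hidden_layers=2,
        use_layer_norm=True,
    )
    model = build_align_transformer(config)

    # One [CLS] token plus a 14x14 grid of patch tokens.
    vision_tokens = torch.randn(1, 1 + 196, config.hidden_size)
    aligned = model(vision_tokens)
    print(aligned.shape)  # expected: torch.Size([1, 197, 768])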