File size: 3,245 Bytes
2ed9a94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, AutoModel, AutoConfig
from .configuration_multimodal import MultimodalConfig

class ProjectionHead(nn.Module):
    """Two-layer MLP that maps encoder features into the shared embedding space.

    Applies Linear -> GELU -> Dropout -> Linear, adds a residual connection
    when the input and output widths match, and finishes with LayerNorm.
    """

    def __init__(self, in_dim, out_dim, hidden_mult=2, p_drop=0.4):
        super().__init__()
        hidden_dim = int(hidden_mult * out_dim)
        # Attribute names (net / ln / use_residual) are part of the state-dict
        # layout and must not change.
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(p_drop),
            nn.Linear(hidden_dim, out_dim),
        )
        self.ln = nn.LayerNorm(out_dim)
        # Residual path is only shape-compatible when dims agree.
        self.use_residual = in_dim == out_dim

    def forward(self, x):
        """Project `x`; residual-add when possible, then layer-normalize."""
        projected = self.net(x)
        out = projected + x if self.use_residual else projected
        return self.ln(out)

def masked_mean_pool(last_hidden_state, attention_mask):
    """Mean-pool token embeddings over the sequence, ignoring padded positions.

    `attention_mask` (batch, seq) weights each token 0 or 1; the divisor is
    clamped so an all-padding row cannot divide by zero.
    """
    weights = attention_mask.unsqueeze(-1).type_as(last_hidden_state)
    token_sum = torch.sum(last_hidden_state * weights, dim=1)
    token_count = torch.sum(weights, dim=1).clamp(min=1e-6)
    return token_sum / token_count

class MultiEmbedTR(PreTrainedModel):
    """Dual-encoder model producing L2-normalized text and image embeddings.

    Wraps a text encoder and a vision encoder (both instantiated from their
    configs, i.e. randomly initialized until weights are loaded), projects
    each modality into a shared `config.embed_dim` space, and exposes a
    learnable CLIP-style log-temperature (`logit_scale`).
    """

    config_class = MultimodalConfig

    def __init__(self, config: MultimodalConfig):
        super().__init__(config)

        text_cfg = AutoConfig.from_pretrained(config.text_model_name, trust_remote_code=True)
        vis_cfg = AutoConfig.from_pretrained(config.vision_model_name)

        # from_config builds architecture only; pretrained weights are loaded
        # separately via from_pretrained on this wrapper.
        self.text_encoder = AutoModel.from_config(text_cfg, trust_remote_code=True)
        self.vision_encoder = AutoModel.from_config(vis_cfg)

        self.text_proj = ProjectionHead(config.text_dim, config.embed_dim)
        self.image_proj = ProjectionHead(config.image_dim, config.embed_dim)

        # Stored in log space; forward() exposes exp(logit_scale).
        self.logit_scale = nn.Parameter(
            torch.tensor(math.log(config.temperature_init), dtype=torch.float)
        )

        self.post_init()

    def encode_text(self, input_ids, attention_mask=None):
        """Encode tokens to a unit-norm embedding of shape (batch, embed_dim).

        If `attention_mask` is None, all positions are treated as real tokens
        (previously this raised on both pooling paths' mask usage).
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        out = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        if self.config.use_mean_pooling_for_text:
            pooled = masked_mean_pool(out.last_hidden_state, attention_mask)
        else:
            # First-token ([CLS]) pooling.
            pooled = out.last_hidden_state[:, 0, :]
        return F.normalize(self.text_proj(pooled), dim=-1)

    def encode_image(self, pixel_values):
        """Encode images to a unit-norm embedding of shape (batch, embed_dim)."""
        out = self.vision_encoder(
            pixel_values=pixel_values,
            return_dict=True,
        )
        # CLS-token pooling on the vision tower.
        cls = out.last_hidden_state[:, 0, :]
        return F.normalize(self.image_proj(cls), dim=-1)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        pixel_values=None,
        return_dict=True,
        **kwargs,
    ):
        """Encode whichever modalities were supplied.

        Returns a dict with `text_embeds`, `image_embeds` (either may be None
        when its inputs were not given) and the exponentiated `logit_scale`;
        with `return_dict=False`, returns the `(text_embeds, image_embeds)`
        tuple instead.
        """
        text_embeds = None
        image_embeds = None

        if input_ids is not None:
            text_embeds = self.encode_text(input_ids, attention_mask)

        if pixel_values is not None:
            image_embeds = self.encode_image(pixel_values)

        if not return_dict:
            return text_embeds, image_embeds

        return {
            "text_embeds": text_embeds,
            "image_embeds": image_embeds,
            "logit_scale": self.logit_scale.exp(),
        }