| import math | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from transformers import PreTrainedModel, AutoModel, AutoConfig | |
| from .configuration_multimodal import MultimodalConfig | |
class ProjectionHead(nn.Module):
    """Project an encoder representation into the shared embedding space.

    A two-layer MLP (Linear -> GELU -> Dropout -> Linear) followed by a
    LayerNorm. When the input and output widths match, the input is added
    back as a residual before normalization.
    """

    def __init__(self, in_dim, out_dim, hidden_mult=2, p_drop=0.4):
        super().__init__()
        hidden_width = int(hidden_mult * out_dim)
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_width),
            nn.GELU(),
            nn.Dropout(p_drop),
            nn.Linear(hidden_width, out_dim),
        )
        self.ln = nn.LayerNorm(out_dim)
        # A skip connection is only shape-compatible when widths agree.
        self.use_residual = in_dim == out_dim

    def forward(self, x):
        projected = self.net(x) + x if self.use_residual else self.net(x)
        return self.ln(projected)
def masked_mean_pool(last_hidden_state, attention_mask):
    """Mean-pool token embeddings over the sequence axis, skipping padding.

    ``last_hidden_state`` is ``(batch, seq, dim)`` and ``attention_mask`` is
    ``(batch, seq)`` with 1 for real tokens and 0 for padding. Returns a
    ``(batch, dim)`` tensor. The token count is clamped to a small positive
    value so an all-padding row cannot divide by zero.
    """
    weights = attention_mask.type_as(last_hidden_state).unsqueeze(-1)
    token_sum = torch.sum(last_hidden_state * weights, dim=1)
    token_count = torch.sum(weights, dim=1).clamp(min=1e-6)
    return token_sum / token_count
class MultiEmbedTR(PreTrainedModel):
    """Dual-encoder text/image embedding model.

    Wraps a text backbone and a vision backbone (both instantiated from their
    configs, i.e. with random weights until loaded via ``from_pretrained``),
    projects each modality into a shared space with a ``ProjectionHead``, and
    L2-normalizes the result so dot products are cosine similarities.
    ``logit_scale`` is a learnable temperature stored in log space.
    """

    config_class = MultimodalConfig

    def __init__(self, config: MultimodalConfig):
        super().__init__(config)
        text_cfg = AutoConfig.from_pretrained(config.text_model_name, trust_remote_code=True)
        vis_cfg = AutoConfig.from_pretrained(config.vision_model_name)
        # from_config builds architecture only; weights come from a later
        # from_pretrained / load_state_dict on the composite model.
        self.text_encoder = AutoModel.from_config(text_cfg, trust_remote_code=True)
        self.vision_encoder = AutoModel.from_config(vis_cfg)
        self.text_proj = ProjectionHead(config.text_dim, config.embed_dim)
        self.image_proj = ProjectionHead(config.image_dim, config.embed_dim)
        # Stored as log(temperature); exp() in forward keeps the scale positive.
        self.logit_scale = nn.Parameter(
            torch.tensor(math.log(config.temperature_init), dtype=torch.float)
        )
        self.post_init()

    def encode_text(self, input_ids, attention_mask=None):
        """Return L2-normalized text embeddings of shape (batch, embed_dim).

        Pools either via masked mean over tokens or the first ([CLS]) token,
        depending on ``config.use_mean_pooling_for_text``.
        """
        if attention_mask is None:
            # Bug fix: forward() may pass attention_mask=None; without this
            # guard, mean pooling crashed on None.unsqueeze(-1). Treat every
            # position as a real token.
            attention_mask = torch.ones_like(input_ids)
        out = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        if self.config.use_mean_pooling_for_text:
            pooled = masked_mean_pool(out.last_hidden_state, attention_mask)
        else:
            pooled = out.last_hidden_state[:, 0, :]
        return F.normalize(self.text_proj(pooled), dim=-1)

    def encode_image(self, pixel_values):
        """Return L2-normalized image embeddings of shape (batch, embed_dim).

        Uses the first token of the vision encoder output as the pooled
        representation (assumes a CLS-style token at index 0 — standard for
        ViT-family backbones).
        """
        out = self.vision_encoder(
            pixel_values=pixel_values,
            return_dict=True
        )
        cls = out.last_hidden_state[:, 0, :]
        return F.normalize(self.image_proj(cls), dim=-1)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        pixel_values=None,
        return_dict=True,
        **kwargs
    ):
        """Encode whichever modalities are provided.

        Either input may be None; the corresponding embedding is then None.
        Returns (text_embeds, image_embeds) when ``return_dict`` is False,
        otherwise a dict that also carries the exponentiated logit scale.
        """
        text_embeds = None
        image_embeds = None
        if input_ids is not None:
            text_embeds = self.encode_text(input_ids, attention_mask)
        if pixel_values is not None:
            image_embeds = self.encode_image(pixel_values)
        if not return_dict:
            return text_embeds, image_embeds
        return {
            "text_embeds": text_embeds,
            "image_embeds": image_embeds,
            "logit_scale": self.logit_scale.exp(),
        }