import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, AutoModel, AutoConfig

from .configuration_multimodal import MultimodalConfig


class ProjectionHead(nn.Module):
    """Two-layer MLP projection (Linear -> GELU -> Dropout -> Linear).

    A residual connection is applied only when ``in_dim == out_dim``;
    the output is always LayerNorm-ed.
    """

    def __init__(self, in_dim, out_dim, hidden_mult=2, p_drop=0.4):
        super().__init__()
        hidden_dim = int(hidden_mult * out_dim)
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(p_drop),
            nn.Linear(hidden_dim, out_dim),
        )
        self.ln = nn.LayerNorm(out_dim)
        # Residual add is only shape-valid when input and output dims match.
        self.use_residual = in_dim == out_dim

    def forward(self, x):
        y = self.net(x)
        if self.use_residual:
            y = y + x
        return self.ln(y)


def masked_mean_pool(last_hidden_state, attention_mask):
    """Mean-pool token embeddings over valid (mask == 1) positions.

    Args:
        last_hidden_state: token embeddings — assumes (batch, seq, dim);
            TODO confirm with encoder output contract.
        attention_mask: (batch, seq) 0/1 mask of valid tokens.

    Returns:
        (batch, dim) pooled embeddings. The clamp guards against
        division by zero for an all-padding row.
    """
    mask = attention_mask.unsqueeze(-1).type_as(last_hidden_state)
    summed = (last_hidden_state * mask).sum(dim=1)
    lengths = mask.sum(dim=1).clamp(min=1e-6)
    return summed / lengths


class MultiEmbedTR(PreTrainedModel):
    """Dual-encoder producing L2-normalized text/image embeddings in a shared space."""

    config_class = MultimodalConfig

    def __init__(self, config: MultimodalConfig):
        super().__init__(config)
        # Sub-encoders are built from config only (random init here);
        # actual weights are expected via from_pretrained / load_state_dict.
        text_cfg = AutoConfig.from_pretrained(config.text_model_name, trust_remote_code=True)
        vis_cfg = AutoConfig.from_pretrained(config.vision_model_name)
        self.text_encoder = AutoModel.from_config(text_cfg, trust_remote_code=True)
        self.vision_encoder = AutoModel.from_config(vis_cfg)
        self.text_proj = ProjectionHead(config.text_dim, config.embed_dim)
        self.image_proj = ProjectionHead(config.image_dim, config.embed_dim)
        # NOTE(review): stored as log(temperature_init); forward() exposes
        # exp(logit_scale), which therefore equals temperature_init itself.
        # CLIP-style models conventionally use log(1 / temperature) so the
        # exposed scale is the *inverse* temperature — confirm the intended
        # semantics of config.temperature_init before training with this.
        self.logit_scale = nn.Parameter(
            torch.tensor(math.log(config.temperature_init), dtype=torch.float)
        )
        self.post_init()

    def encode_text(self, input_ids, attention_mask=None):
        """Encode tokenized text into L2-normalized embeddings.

        Robustness fix: ``forward`` may be called without a mask, and
        ``masked_mean_pool`` would crash on ``None``; fall back to an
        all-ones mask (every token treated as valid) in that case.
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        out = self.text_encoder(
            input_ids=input_ids, attention_mask=attention_mask, return_dict=True
        )
        if self.config.use_mean_pooling_for_text:
            pooled = masked_mean_pool(out.last_hidden_state, attention_mask)
        else:
            # First-token ([CLS]-style) pooling.
            pooled = out.last_hidden_state[:, 0, :]
        return F.normalize(self.text_proj(pooled), dim=-1)

    def encode_image(self, pixel_values):
        """Encode images into L2-normalized embeddings (first-token pooling)."""
        out = self.vision_encoder(pixel_values=pixel_values, return_dict=True)
        cls = out.last_hidden_state[:, 0, :]
        return F.normalize(self.image_proj(cls), dim=-1)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        pixel_values=None,
        return_dict=True,
        **kwargs,
    ):
        """Encode whichever modalities are provided.

        Returns a dict (or tuple when ``return_dict=False``) of
        ``text_embeds`` / ``image_embeds`` — either may be None when the
        corresponding input is absent — plus ``logit_scale`` = exp(parameter).
        """
        text_embeds = None
        image_embeds = None
        if input_ids is not None:
            text_embeds = self.encode_text(input_ids, attention_mask)
        if pixel_values is not None:
            image_embeds = self.encode_image(pixel_values)
        if not return_dict:
            return text_embeds, image_embeds
        return {
            "text_embeds": text_embeds,
            "image_embeds": image_embeds,
            "logit_scale": self.logit_scale.exp(),
        }