# MultiEmbedTR / modeling_multimodal.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, AutoModel, AutoConfig
from .configuration_multimodal import MultimodalConfig


class ProjectionHead(nn.Module):
    """MLP projection head mapping encoder features into the shared embedding space.

    A GELU bottleneck with dropout, an optional residual connection when the
    input and output widths match, and a final LayerNorm.
    """

    def __init__(self, in_dim, out_dim, hidden_mult=2, p_drop=0.4):
        super().__init__()
        h = int(hidden_mult * out_dim)
        self.net = nn.Sequential(
            nn.Linear(in_dim, h),
            nn.GELU(),
            nn.Dropout(p_drop),
            nn.Linear(h, out_dim),
        )
        self.ln = nn.LayerNorm(out_dim)
        # The residual connection is only shape-compatible when dims match.
        self.use_residual = (in_dim == out_dim)

    def forward(self, x):
        y = self.net(x)
        if self.use_residual:
            y = y + x
        return self.ln(y)
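
# A minimal shape-check sketch for ProjectionHead (a usage illustration; the
# dimensions below are made up for the example, not taken from MultimodalConfig):
#
#     head = ProjectionHead(in_dim=768, out_dim=512)
#     y = head(torch.randn(4, 768))       # -> (4, 512); no residual, dims differ
#     head_sq = ProjectionHead(512, 512)  # residual path active: in_dim == out_dim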


def masked_mean_pool(last_hidden_state, attention_mask):
    """Mean-pool token embeddings, ignoring padding positions."""
    mask = attention_mask.unsqueeze(-1).type_as(last_hidden_state)
    summed = (last_hidden_state * mask).sum(dim=1)
    # Clamp to avoid division by zero for fully masked (empty) sequences.
    lengths = mask.sum(dim=1).clamp(min=1e-6)
    return summed / lengths
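
# Worked example of masked_mean_pool on toy tensors (a sketch, assuming the
# usual HF convention of 1 = real token, 0 = padding in attention_mask):
#
#     hidden = torch.ones(1, 4, 8)              # (batch, seq_len, hidden)
#     hidden[0, 2:] = 5.0                       # last two positions are padding
#     mask = torch.tensor([[1, 1, 0, 0]])
#     pooled = masked_mean_pool(hidden, mask)   # averages only the 2 real tokens
#     assert torch.allclose(pooled, torch.ones(1, 8))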


class MultiEmbedTR(PreTrainedModel):
    """Dual-encoder (CLIP-style) text-image model with projection heads."""

    config_class = MultimodalConfig

    def __init__(self, config: MultimodalConfig):
        super().__init__(config)
        # Build both backbones from their configs (randomly initialized here;
        # weights are loaded later via from_pretrained / load_state_dict).
        text_cfg = AutoConfig.from_pretrained(config.text_model_name, trust_remote_code=True)
        vis_cfg = AutoConfig.from_pretrained(config.vision_model_name)
        self.text_encoder = AutoModel.from_config(text_cfg, trust_remote_code=True)
        self.vision_encoder = AutoModel.from_config(vis_cfg)
        # Project each modality into the shared embedding space.
        self.text_proj = ProjectionHead(config.text_dim, config.embed_dim)
        self.image_proj = ProjectionHead(config.image_dim, config.embed_dim)
        # Learnable temperature, stored in log space so exp() keeps it positive.
        self.logit_scale = nn.Parameter(
            torch.tensor(math.log(config.temperature_init), dtype=torch.float)
        )
        self.post_init()

    def encode_text(self, input_ids, attention_mask):
        out = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        if self.config.use_mean_pooling_for_text:
            pooled = masked_mean_pool(out.last_hidden_state, attention_mask)
        else:
            # Fall back to the [CLS] token representation.
            pooled = out.last_hidden_state[:, 0, :]
        return F.normalize(self.text_proj(pooled), dim=-1)

    def encode_image(self, pixel_values):
        out = self.vision_encoder(
            pixel_values=pixel_values,
            return_dict=True,
        )
        # ViT-style encoders expose a [CLS] token at position 0.
        cls = out.last_hidden_state[:, 0, :]
        return F.normalize(self.image_proj(cls), dim=-1)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        pixel_values=None,
        return_dict=True,
        **kwargs,
    ):
        # Either modality may be absent; encode whatever inputs were provided.
        text_embeds = None
        image_embeds = None
        if input_ids is not None:
            text_embeds = self.encode_text(input_ids, attention_mask)
        if pixel_values is not None:
            image_embeds = self.encode_image(pixel_values)
        if not return_dict:
            return text_embeds, image_embeds
        return {
            "text_embeds": text_embeds,
            "image_embeds": image_embeds,
            # exp() recovers the positive scaling factor for similarity logits.
            "logit_scale": self.logit_scale.exp(),
        }
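

if __name__ == "__main__":
    # Smoke-test sketch, not part of the model definition. The config field
    # names come from the attribute accesses above; the backbone checkpoints
    # and dimensions below are illustrative assumptions, not this repo's
    # actual defaults. Note: building the model fetches the two backbone
    # configs from the Hub.
    cfg = MultimodalConfig(
        text_model_name="dbmdz/bert-base-turkish-cased",        # assumed example
        vision_model_name="google/vit-base-patch16-224-in21k",  # assumed example
        text_dim=768,
        image_dim=768,
        embed_dim=512,
        temperature_init=0.07,
        use_mean_pooling_for_text=True,
    )
    model = MultiEmbedTR(cfg).eval()

    input_ids = torch.randint(0, 1000, (2, 16))
    attention_mask = torch.ones(2, 16, dtype=torch.long)
    pixel_values = torch.randn(2, 3, 224, 224)

    with torch.no_grad():
        out = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
        )

    # CLIP-style similarity: both embeddings are L2-normalized, so the scaled
    # dot product is a scaled cosine similarity between every text/image pair.
    logits = out["logit_scale"] * out["text_embeds"] @ out["image_embeds"].T
    print(logits.shape)  # torch.Size([2, 2])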