| import torch.nn as nn
|
| import torch
|
| import numpy as np
|
| from model.encoder import ImageEncoder, RobertaEncoder
|
| import torch.nn.functional as F
|
class LVL(nn.Module):
    """Two-tower model pairing an image encoder with a text encoder.

    Both towers emit embeddings that are L2-normalised along the last
    dimension before being returned.

    The module also owns two learnable scalar parameters, ``t_prime`` and
    ``b``. Neither is read by any method of this class.
    NOTE(review): presumably they are the temperature / bias terms of a
    contrastive loss computed elsewhere -- confirm against the training code.
    """

    def __init__(self):
        """Build both encoders and register the two scalar parameters."""
        super().__init__()

        # Registration order (image encoder, text encoder, t_prime, b) is
        # kept stable because it determines parameter / state-dict ordering.
        self.image_encoder = ImageEncoder()
        self.text_encoder = RobertaEncoder()

        # 0-dim tensor initialised to log(0.07).
        initial_t_prime = torch.ones([]) * np.log(0.07)
        self.t_prime = nn.Parameter(initial_t_prime)

        # 0-dim tensor initialised to zero.
        self.b = nn.Parameter(torch.zeros([]))

    def get_images_features(self, images: torch.Tensor) -> torch.Tensor:
        """Encode ``images`` and L2-normalise the result.

        Args:
            images: Input forwarded unchanged to ``self.image_encoder``.

        Returns:
            The encoder output scaled to unit L2 norm along the last dimension.
        """
        return F.normalize(self.image_encoder(images), p=2, dim=-1)

    def get_texts_feature(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Encode tokenised text and L2-normalise the result.

        Args:
            input_ids: First positional argument for ``self.text_encoder``.
            attention_mask: Second positional argument for
                ``self.text_encoder``.

        Returns:
            The encoder output scaled to unit L2 norm along the last dimension.
        """
        return F.normalize(
            self.text_encoder(input_ids, attention_mask), p=2, dim=-1
        )

    def forward(self, images, input_ids, attention_mask):
        """Compute normalised embeddings for a batch of image/text inputs.

        Args:
            images: Tensor of shape (batch_size, 3, 224, 224).
            input_ids: Tensor of shape (batch_size, seq_length).
            attention_mask: Tensor of shape (batch_size, seq_length).

        Returns:
            A tuple ``(image_embeddings, text_embeddings)``; each is
            normalised and ready for similarity calculation.
        """
        return (
            self.get_images_features(images),
            self.get_texts_feature(input_ids, attention_mask),
        )
|
|
|