import torch
import torch.nn as nn
from transformers import BertModel, AutoTokenizer


class BaseBERT(nn.Module):
    """Text encoder that fuses a BERT [CLS] embedding with a learned modality embedding.

    Wraps a pretrained BERT model plus its tokenizer and adds a small
    embedding table (4 modalities x 768 dims) whose rows are summed onto
    the [CLS] sentence representation.
    """

    def __init__(self, basebert_checkpoint='bert-base-uncased'):
        """
        Args:
            basebert_checkpoint: HuggingFace checkpoint name or local path
                used for both the tokenizer and the BERT encoder.
        """
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(basebert_checkpoint)
        self.model = BertModel.from_pretrained(basebert_checkpoint)
        # One learned 768-dim vector per modality id (4 modalities).
        # NOTE(review): 768 assumes a bert-base hidden size — a bert-large
        # checkpoint would need 1024; confirm intended checkpoints.
        self.modality_embed = nn.Embedding(4, 768)

    def forward(self, text, modality):
        """Encode `text` and add the modality embedding.

        Args:
            text: str or list[str] — raw input sentence(s).
            modality: LongTensor of modality ids in [0, 4), one per sentence.

        Returns:
            Tensor of shape (batch, 768): [CLS] embedding + modality embedding.
        """
        # Place tokenized inputs on the same device as the model weights
        # instead of assuming a CUDA device exists (torch.cuda.current_device()
        # raises on CPU-only machines and ignores the model's actual device).
        device = next(self.model.parameters()).device
        encoded = self.tokenizer(
            text,
            truncation=True,
            padding=True,
            return_tensors='pt',
            max_length=64,
        ).to(device=device)
        # [CLS] token representation as the sentence embedding.
        text_feature = self.model(**encoded).last_hidden_state[:, 0, :]
        modality_feature = self.modality_embed(modality)
        # Out-of-place add: `+=` would mutate a view of last_hidden_state
        # in place, which can break autograd and corrupt the model output
        # for any other consumer.
        return text_feature + modality_feature