Medal-S-V1.0 / model /base_bert.py
spc819's picture
Upload 69 files
7f3dfd7 verified
raw
history blame contribute delete
908 Bytes
import torch.nn as nn
import torch
from transformers import BertModel, AutoTokenizer
class BaseBERT(nn.Module):
    """Text encoder: BERT first-token feature plus a learned per-modality embedding.

    The text is tokenized and run through a pretrained ``BertModel``. The hidden
    state at sequence position 0 (the ``[CLS]`` slot for BERT tokenizers) is
    taken and a learned embedding selected by an integer ``modality`` id is added
    to it.

    NOTE(review): what each modality id means (e.g. which imaging modality) is
    defined by the caller, not here.
    """

    # Class-level fallback so instances created/pickled before ``max_length``
    # became a constructor argument still truncate at 64 tokens.
    max_length = 64

    def __init__(self, basebert_checkpoint='bert-base-uncased', num_modalities=4, max_length=64):
        """Load tokenizer and BERT weights and create the modality table.

        Args:
            basebert_checkpoint: HuggingFace model id or local path passed to
                both ``AutoTokenizer.from_pretrained`` and
                ``BertModel.from_pretrained``.
            num_modalities: Number of rows in the modality embedding table;
                valid ``modality`` ids are ``0 .. num_modalities - 1``.
            max_length: Maximum number of tokens kept per text (longer texts
                are truncated).
        """
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(basebert_checkpoint)
        self.model = BertModel.from_pretrained(basebert_checkpoint)
        # Size the modality table from the checkpoint rather than hard-coding
        # 768, so the sum in forward() also works for wider BERT variants.
        self.modality_embed = nn.Embedding(num_modalities, self.model.config.hidden_size)
        self.max_length = max_length

    def forward(self, text, modality):
        """Encode ``text`` and add the embedding of ``modality``.

        Args:
            text: A string or a list of strings (a batch) accepted by the
                tokenizer.
            modality: Integer modality id(s) - a LongTensor, int, or list of
                ints; presumably one id per text item (TODO confirm with callers).

        Returns:
            Tensor of shape ``(batch, hidden_size)`` for a batch of texts with
            one modality id each: the position-0 hidden state plus the
            modality embedding.
        """
        # Send inputs to wherever the module actually lives instead of forcing
        # the current CUDA device (which raises on CPU-only machines).
        device = next(self.model.parameters()).device
        encoded = self.tokenizer(
            text,
            truncation=True,
            padding=True,
            return_tensors='pt',
            max_length=self.max_length,
        ).to(device=device)
        text_feature = self.model(**encoded).last_hidden_state[:, 0, :]
        # Accept ints/lists as well as tensors, and co-locate the ids with the
        # embedding table.
        modality = torch.as_tensor(modality, device=self.modality_embed.weight.device)
        modality_feature = self.modality_embed(modality)
        # Out-of-place add: ``+=`` would write into a view of
        # last_hidden_state, mutating an autograd output in place.
        return text_feature + modality_feature