import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer
import warnings
warnings.filterwarnings('ignore')

# Supported checkpoints mapped to their (model class, tokenizer class) pair.
MODELS = {
    'prajjwal1/bert-mini': (BertModel, BertTokenizer),
}


class Text_Encoder(nn.Module):
    def __init__(self, device):
        super(Text_Encoder, self).__init__()
        self.base_model = 'prajjwal1/bert-mini'
        self.dropout = 0.1

        self.tokenizer = MODELS[self.base_model][1].from_pretrained(self.base_model)

        # BERT-mini backbone without the pooling head; hidden states of all
        # layers are returned in addition to the last one.
        self.bert_layer = MODELS[self.base_model][0].from_pretrained(self.base_model,
                                                                     add_pooling_layer=False,
                                                                     hidden_dropout_prob=self.dropout,
                                                                     attention_probs_dropout_prob=self.dropout,
                                                                     output_hidden_states=True)

        # Projection head on top of the 256-dim BERT-mini hidden state.
        self.linear_layer = nn.Sequential(nn.Linear(256, 256), nn.ReLU(inplace=True))

        self.device = device

    def tokenize(self, caption):
        # Tokenize a caption (or a batch of captions) and move the tensors to
        # the target device. Special tokens ([CLS]/[SEP]) are not added.
        tokenized = self.tokenizer(caption, add_special_tokens=False, padding=True, return_tensors='pt')
        input_ids = tokenized['input_ids']
        attns_mask = tokenized['attention_mask']

        input_ids = input_ids.to(self.device)
        attns_mask = attns_mask.to(self.device)
        return input_ids, attns_mask

    def forward(self, input_ids, attns_mask):
        # Last hidden state of the backbone: (batch, seq_len, 256).
        output = self.bert_layer(input_ids=input_ids, attention_mask=attns_mask)[0]
        # The first-token embedding is used as the sequence-level representation.
        cls_embed = output[:, 0, :]
        text_embed = self.linear_layer(cls_embed)

        return text_embed, output
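
# Minimal usage sketch (not part of the original file): encode a small batch of
# captions with Text_Encoder. The caption strings are made up for illustration;
# with the hard-coded 'prajjwal1/bert-mini' checkpoint (hidden size 256),
# text_embed should come out as (batch, 256) and output as (batch, seq_len, 256).
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    encoder = Text_Encoder(device).to(device)

    captions = ['a dog running on the beach', 'two people riding bicycles']
    input_ids, attns_mask = encoder.tokenize(captions)
    text_embed, output = encoder(input_ids, attns_mask)

    print(text_embed.shape)  # expected: torch.Size([2, 256])
    print(output.shape)      # expected: torch.Size([2, seq_len, 256])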