import torch.nn as nn

from modules.token_embedders.bert_encoder import BertEncoder
from utils.nn_utils import batched_index_select, gelu


class BertEmbedModel(nn.Module):
    """This class acts as an embedding layer backed by a `BertEncoder`.

    `forward` does not return its results; it stores them in the batch
    dict it receives, under the keys `seq_encoder_reprs` and `seq_cls_repr`.
    """

    def __init__(self, cfg, vocab, rel_mlp=False):
        """This function constructs `BertEmbedModel` components and
        sets `BertEmbedModel` parameters

        Arguments:
            cfg {object} -- config parameters; the attributes
                `bert_model_name`, `fine_tune`, `bert_output_size` and
                `bert_dropout` are read here
            vocab {Vocabulary} -- vocabulary (not referenced by this
                constructor)

        Keyword Arguments:
            rel_mlp {bool} -- if True, `forward` stores the encoder's
                sequence output unchanged; otherwise the output is first
                indexed with `batch_inputs['wordpiece_tokens_index']`
                (default: {False})
        """

        super().__init__()
        self.rel_mlp = rel_mlp
        self.activation = gelu
        self.bert_encoder = BertEncoder(bert_model_name=cfg.bert_model_name,
                                        trainable=cfg.fine_tune,
                                        output_size=cfg.bert_output_size,
                                        activation=self.activation,
                                        dropout=cfg.bert_dropout)
        self.encoder_output_size = self.bert_encoder.get_output_dims()

    def forward(self, batch_inputs):
        """This function propagates forwardly

        The encoder outputs are written into `batch_inputs` in place
        (`seq_encoder_reprs` and `seq_cls_repr`); nothing is returned.

        Arguments:
            batch_inputs {dict} -- batch input data; must contain
                `wordpiece_tokens`, may contain `wordpiece_segment_ids`,
                and must contain `wordpiece_tokens_index` unless
                `rel_mlp` is set
        """

        # Segment ids are optional: forward them to the encoder only when
        # the batch actually carries them.
        encoder_args = [batch_inputs['wordpiece_tokens']]
        if 'wordpiece_segment_ids' in batch_inputs:
            encoder_args.append(batch_inputs['wordpiece_segment_ids'])
        seq_repr, cls_repr = self.bert_encoder(*encoder_args)

        if not self.rel_mlp:
            # Keep only the encoder positions listed in the index tensor.
            seq_repr = batched_index_select(seq_repr,
                                            batch_inputs['wordpiece_tokens_index'])

        batch_inputs['seq_encoder_reprs'] = seq_repr
        batch_inputs['seq_cls_repr'] = cls_repr

    def get_hidden_size(self):
        """This function returns embedding dimensions

        Returns:
            int -- embedding dimensions, as reported by the encoder's
                `get_output_dims()` at construction time
        """

        return self.encoder_output_size