| """Tweaked version of corresponding AllenNLP file""" |
| import logging |
| from copy import deepcopy |
| from typing import Dict |
|
|
| import torch |
| import torch.nn.functional as F |
| from allennlp.modules.token_embedders.token_embedder import TokenEmbedder |
| from allennlp.nn import util |
| from transformers import AutoModel, PreTrainedModel |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
class PretrainedBertModel:
    """
    In some instances you may want to load the same BERT model twice
    (e.g. to use as a token embedder and also as a pooling layer).
    This factory provides a cache so that you don't actually have to load the model twice.
    """

    # Class-level cache mapping model name -> loaded model instance,
    # shared by every caller of ``load``.
    _cache: Dict[str, "PreTrainedModel"] = {}

    @classmethod
    def load(cls, model_name: str, cache_model: bool = True) -> "PreTrainedModel":
        """
        Return the pretrained model for ``model_name``, loading it at most once.

        Parameters
        ----------
        model_name : ``str``
            Model name or path handed to ``AutoModel.from_pretrained``.
        cache_model : ``bool``, optional (default = ``True``)
            If ``True``, store a freshly loaded model in the cache so later
            calls with the same name return the same instance.
        """
        # Use ``cls._cache`` on both the lookup and the return so a subclass
        # that shadows ``_cache`` reads the same dictionary it checked.
        # (The original checked ``cls._cache`` but returned from the
        # hard-coded ``PretrainedBertModel._cache``, which could KeyError
        # for such a subclass.)
        if model_name in cls._cache:
            return cls._cache[model_name]

        model = AutoModel.from_pretrained(model_name)
        if cache_model:
            cls._cache[model_name] = model

        return model
|
|
|
|
class BertEmbedder(TokenEmbedder):
    """
    A ``TokenEmbedder`` that produces BERT embeddings for your tokens.
    Should be paired with a ``BertIndexer``, which produces wordpiece ids.
    Most likely you probably want to use ``PretrainedBertEmbedder``
    for one of the named pretrained models, not this base class.

    Parameters
    ----------
    bert_model: ``PreTrainedModel``
        The BERT model being wrapped.
    top_layer_only: ``bool``, optional (default = ``False``)
        Accepted for API compatibility.  NOTE(review): this flag is currently
        unused -- ``self._scalar_mix`` is always ``None``, so ``forward``
        always returns the top layer regardless of this value.
    max_pieces : int, optional (default: 512)
        The BERT embedder uses positional embeddings and so has a corresponding
        maximum length for its input ids. Assuming the inputs are windowed
        and padded appropriately by this length, the embedder will split them into a
        large batch, feed them into BERT, and recombine the output as if it was a
        longer sequence.
    num_start_tokens : int, optional (default: 1)
        The number of starting special tokens input to BERT (usually 1, i.e., [CLS])
    num_end_tokens : int, optional (default: 1)
        The number of ending tokens input to BERT (usually 1, i.e., [SEP])
    """

    def __init__(
        self,
        bert_model: PreTrainedModel,
        top_layer_only: bool = False,
        max_pieces: int = 512,
        num_start_tokens: int = 1,
        num_end_tokens: int = 1
    ) -> None:
        super().__init__()
        # Deep-copy so this embedder owns its own weights even when the model
        # came from ``PretrainedBertModel``'s shared cache.
        self.bert_model = deepcopy(bert_model)
        self.output_dim = bert_model.config.hidden_size
        self.max_pieces = max_pieces
        self.num_start_tokens = num_start_tokens
        self.num_end_tokens = num_end_tokens
        # A scalar mix is never constructed here, so the ``forward`` branch
        # that uses it is dead; only the top layer is ever returned.
        self._scalar_mix = None

    def set_weights(self, freeze):
        """Freeze (``freeze=True``) or unfreeze all BERT parameters."""
        for param in self.bert_model.parameters():
            param.requires_grad = not freeze
        return

    def get_output_dim(self) -> int:
        """Return the hidden size of the wrapped BERT model."""
        return self.output_dim

    def forward(
        self,
        input_ids: torch.LongTensor,
        offsets: torch.LongTensor = None
    ) -> torch.Tensor:
        """
        Parameters
        ----------
        input_ids : ``torch.LongTensor``
            The (batch_size, ..., max_sequence_length) tensor of wordpiece ids.
        offsets : ``torch.LongTensor``, optional
            The BERT embeddings are one per wordpiece. However it's possible/likely
            you might want one per original token. In that case, ``offsets``
            represents the indices of the desired wordpiece for each original token.
            Depending on how your token indexer is configured, this could be the
            position of the last wordpiece for each token, or it could be the position
            of the first wordpiece for each token.
            For example, if you had the sentence "Definitely not", and if the corresponding
            wordpieces were ["Def", "##in", "##ite", "##ly", "not"], then the input_ids
            would be 5 wordpiece ids, and the "last wordpiece" offsets would be [3, 4].
            If offsets are provided, the returned tensor will contain only the wordpiece
            embeddings at those positions, and (in particular) will contain one embedding
            per token. If offsets are not provided, the entire tensor of wordpiece embeddings
            will be returned.
        """

        batch_size, full_seq_len = input_ids.size(0), input_ids.size(-1)
        initial_dims = list(input_ids.shape[:-1])

        # If the sequence is longer than BERT's positional-embedding limit,
        # chop it into ``max_pieces``-sized windows and stack them along the
        # batch dimension; the outputs are re-assembled further down.
        needs_split = full_seq_len > self.max_pieces
        last_window_size = 0
        if needs_split:
            split_input_ids = list(input_ids.split(self.max_pieces, dim=-1))

            # Zero-pad the final (shorter) window up to ``max_pieces`` so all
            # windows can be concatenated into one batch.
            last_window_size = split_input_ids[-1].size(-1)
            padding_amount = self.max_pieces - last_window_size
            split_input_ids[-1] = F.pad(split_input_ids[-1], pad=[0, padding_amount], value=0)

            input_ids = torch.cat(split_input_ids, dim=0)

        # Wordpiece id 0 is treated as padding and masked out of attention.
        input_mask = (input_ids != 0).long()

        # ``[0]`` takes the first element of the model output. Depending on
        # the underlying transformers model/config this is either a single
        # (batch, seq, hidden) tensor or a sequence of per-layer tensors --
        # the shape check below normalizes both to a 4-d
        # (layers, batch, seq, hidden) tensor. TODO confirm against the
        # transformers version in use.
        all_encoder_layers = self.bert_model(
            input_ids=util.combine_initial_dims(input_ids),
            attention_mask=util.combine_initial_dims(input_mask),
        )[0]
        if len(all_encoder_layers[0].shape) == 3:
            all_encoder_layers = torch.stack(all_encoder_layers)
        elif len(all_encoder_layers[0].shape) == 2:
            all_encoder_layers = torch.unsqueeze(all_encoder_layers, dim=0)

        if needs_split:
            # Undo the windowing: the windows were stacked along the batch
            # dimension, so split them back out (dim=1 is batch here) and
            # re-concatenate along the sequence dimension (dim=2).
            unpacked_embeddings = torch.split(all_encoder_layers, batch_size, dim=1)
            unpacked_embeddings = torch.cat(unpacked_embeddings, dim=2)

            # Assuming the windows overlap, keep for each position the copy
            # with the most context: the stride is the number of "fresh"
            # positions per window, centered between the special tokens.
            stride = (self.max_pieces - self.num_start_tokens - self.num_end_tokens) // 2
            stride_offset = stride // 2 + self.num_start_tokens

            # The first window keeps its leading positions (incl. [CLS]).
            first_window = list(range(stride_offset))

            # Middle positions: from every window, the central ``stride``-wide
            # band where the window has maximal context on both sides.
            max_context_windows = [
                i
                for i in range(full_seq_len)
                if stride_offset - 1 < i % self.max_pieces < stride_offset + stride
            ]

            # ``lookback`` is the length of the final (possibly partial)
            # window before padding.
            if full_seq_len % self.max_pieces == 0:
                lookback = self.max_pieces
            else:
                lookback = full_seq_len % self.max_pieces

            # The last window keeps its trailing positions (incl. [SEP]).
            final_window_start = full_seq_len - lookback + stride_offset + stride
            final_window = list(range(final_window_start, full_seq_len))

            select_indices = first_window + max_context_windows + final_window

            initial_dims.append(len(select_indices))

            recombined_embeddings = unpacked_embeddings[:, :, select_indices]
        else:
            recombined_embeddings = all_encoder_layers

        # NOTE(review): this masks on the embeddings themselves (elementwise
        # non-zero over the hidden dim), not on the input ids -- only used by
        # the scalar-mix branch below, which is currently unreachable.
        input_mask = (recombined_embeddings != 0).long()

        if self._scalar_mix is not None:
            mix = self._scalar_mix(recombined_embeddings, input_mask)
        else:
            # ``_scalar_mix`` is always None in this class: take the last
            # entry along the layer dimension (the top layer).
            mix = recombined_embeddings[-1]

        if offsets is None:
            # Return all wordpiece embeddings, restoring the original
            # leading dimensions.
            dims = initial_dims if needs_split else input_ids.size()
            return util.uncombine_initial_dims(mix, dims)
        else:
            # Gather one embedding per original token at the given wordpiece
            # positions: flatten leading dims, index rows with a range vector,
            # then restore the offsets' original shape.
            offsets2d = util.combine_initial_dims(offsets)

            range_vector = util.get_range_vector(
                offsets2d.size(0), device=util.get_device_of(mix)
            ).unsqueeze(1)

            selected_embeddings = mix[range_vector, offsets2d]

            return util.uncombine_initial_dims(selected_embeddings, offsets.size())
|
|
|
|
| |
class PretrainedBertEmbedder(BertEmbedder):
    """
    A ``BertEmbedder`` that loads its weights by name through the
    ``PretrainedBertModel`` cache.

    Parameters
    ----------
    pretrained_model : ``str``
        Either the name of a pretrained model (e.g. ``'bert-base-uncased'``)
        or a path/URL understood by ``AutoModel.from_pretrained``.
    requires_grad : ``bool``, optional (default = ``False``)
        If ``True``, gradients are computed for the BERT parameters
        (fine-tuning); otherwise all weights are frozen.
    top_layer_only : ``bool``, optional (default = ``False``)
        Forwarded to ``BertEmbedder``.
    special_tokens_fix : ``int``, optional (default = 0)
        If non-zero, grow the wordpiece embedding matrix by one row to make
        room for an additional special token.
    """

    def __init__(
        self,
        pretrained_model: str,
        requires_grad: bool = False,
        top_layer_only: bool = False,
        special_tokens_fix: int = 0,
    ) -> None:
        bert = PretrainedBertModel.load(pretrained_model)

        # Freeze (or unfreeze) every weight before handing the model over.
        for weight in bert.parameters():
            weight.requires_grad = requires_grad

        super().__init__(bert_model=bert, top_layer_only=top_layer_only)

        if special_tokens_fix:
            try:
                current_vocab = self.bert_model.embeddings.word_embeddings.num_embeddings
            except AttributeError:
                # Models without ``embeddings.word_embeddings`` (e.g.
                # XLNet-style) expose ``word_embedding`` instead; the +5
                # offset compensates for their extra built-in special tokens.
                current_vocab = self.bert_model.word_embedding.num_embeddings + 5
            self.bert_model.resize_token_embeddings(current_vocab + 1)
|
|