|
|
| """
|
| Backbone modules.
|
| """
|
| import os.path
|
| import shutil
|
| import tarfile
|
| import tempfile
|
|
|
| import torch
|
| from pytorch_pretrained_bert import BertConfig, cached_path, CONFIG_NAME, WEIGHTS_NAME, load_tf_weights_in_bert
|
| from pytorch_pretrained_bert.modeling import BertLayerNorm, PRETRAINED_MODEL_ARCHIVE_MAP, logger, BERT_CONFIG_NAME, \
|
| BertEmbeddings, BertEncoder
|
| from pytorch_pretrained_bert.modeling_transfo_xl import TF_WEIGHTS_NAME
|
| from torch import nn
|
|
|
| from lib.utils.misc import NestedTensor
|
|
|
|
|
class BertPreTrainedModel(nn.Module):
    """Abstract base class that handles weight initialization and provides a simple
    interface for downloading and loading pretrained checkpoints.
    """

    def __init__(self, config, *inputs, **kwargs):
        """Store ``config`` after validating its type.

        Raises:
            ValueError: if ``config`` is not a ``BertConfig`` instance.
        """
        super(BertPreTrainedModel, self).__init__()
        if not isinstance(config, BertConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
                "To create a model from a Google pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                ))
        self.config = config

    def init_bert_weights(self, module):
        """Initialize the weights of ``module`` in place (meant for ``nn.Module.apply``).

        - ``nn.Linear`` / ``nn.Embedding`` weights ~ N(0, config.initializer_range).
        - ``BertLayerNorm``: weight set to 1, bias set to 0.
        - ``nn.Linear`` biases (when present) are zeroed.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, BertLayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    @staticmethod
    def _rename_legacy_keys(state_dict):
        """Return a shallow copy of ``state_dict`` with TF-style LayerNorm names renamed.

        Any key containing ``gamma`` gets ``gamma`` replaced by ``weight``. Any key
        containing ``beta`` gets ``beta`` replaced by ``bias``; that rule is applied
        to the original key and wins over the ``gamma`` rule. Renamed entries move to
        the end of the mapping.

        The input mapping is NOT modified. Its ``_metadata`` attribute (used by
        ``nn.Module._load_from_state_dict``), if any, is carried over to the copy.
        """
        metadata = getattr(state_dict, '_metadata', None)
        renamed = state_dict.copy()
        # Iterate over a snapshot: entries are popped and re-inserted during the loop.
        for old_key in list(renamed.keys()):
            new_key = None
            if 'gamma' in old_key:
                new_key = old_key.replace('gamma', 'weight')
            if 'beta' in old_key:
                new_key = old_key.replace('beta', 'bias')
            if new_key:
                renamed[new_key] = renamed.pop(old_key)
        if metadata is not None:
            renamed._metadata = metadata
        return renamed

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
        """Instantiate a model from a pretrained model file or a PyTorch state dict.

        The pretrained archive is downloaded and cached if needed.

        Params:
            pretrained_model_name_or_path: either
                - a key of ``PRETRAINED_MODEL_ARCHIVE_MAP`` (e.g. ``bert-base-uncased``,
                  ``bert-large-uncased``, ``bert-base-cased``, ``bert-large-cased``,
                  ``bert-base-multilingual-uncased``, ``bert-base-multilingual-cased``,
                  ``bert-base-chinese``), or
                - a path or URL to a directory or a ``.tar.gz`` archive. It must contain
                  a config file (``CONFIG_NAME``, falling back to ``BERT_CONFIG_NAME``)
                  and the PyTorch weights file ``WEIGHTS_NAME``, or
                - with ``from_tf=True``, a directory (it is not extracted) holding the
                  config file and the TensorFlow checkpoint ``TF_WEIGHTS_NAME``.
            from_tf: load the weights from a locally saved TensorFlow checkpoint.
            cache_dir: optional folder in which downloaded checkpoints are cached.
            state_dict: optional state dictionary (e.g. ``collections.OrderedDict``)
                to use instead of the archive's weights. It is not modified.
            *inputs, **kwargs: extra arguments forwarded to ``cls(config, ...)``
                (e.g. num_labels for a classification head).

        Returns:
            The model with the weights loaded. If the archive cannot be resolved,
            an error is logged and ``None`` is returned. With ``from_tf=True``,
            the result of ``load_tf_weights_in_bert`` is returned.

        Raises:
            RuntimeError: if loading the state dict reported errors.
        """
        state_dict = kwargs.pop('state_dict', None)
        cache_dir = kwargs.pop('cache_dir', None)
        from_tf = kwargs.pop('from_tf', False)

        if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
            archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
        else:
            archive_file = pretrained_model_name_or_path

        try:
            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
        except EnvironmentError:
            logger.error(
                "Model name '{}' was not found in model name list ({}). "
                "We assumed '{}' was a path or url but couldn't find any file "
                "associated to this path or url.".format(
                    pretrained_model_name_or_path,
                    ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
                    archive_file))
            return None
        if resolved_archive_file == archive_file:
            logger.info("loading archive file {}".format(archive_file))
        else:
            logger.info("loading archive file {} from cache at {}".format(
                archive_file, resolved_archive_file))

        tempdir = None
        try:
            if os.path.isdir(resolved_archive_file) or from_tf:
                serialization_dir = resolved_archive_file
            else:
                # Extract the .tar.gz archive into a temporary directory.
                # NOTE(review): extractall() trusts member paths in the archive;
                # only load archives from trusted sources.
                tempdir = tempfile.mkdtemp()
                logger.info("extracting archive file {} to temp dir {}".format(
                    resolved_archive_file, tempdir))
                with tarfile.open(resolved_archive_file, 'r:gz') as archive:
                    archive.extractall(tempdir)
                serialization_dir = tempdir

            config_file = os.path.join(serialization_dir, CONFIG_NAME)
            if not os.path.exists(config_file):
                # Fall back to the legacy config file name.
                config_file = os.path.join(serialization_dir, BERT_CONFIG_NAME)
            config = BertConfig.from_json_file(config_file)
            logger.info("Model config {}".format(config))

            model = cls(config, *inputs, **kwargs)
            if state_dict is None and not from_tf:
                weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
                # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
                state_dict = torch.load(weights_path, map_location='cpu')
        finally:
            # Remove the extraction directory even when loading fails part-way.
            if tempdir:
                shutil.rmtree(tempdir)

        if from_tf:
            # from_tf never extracts, so serialization_dir is still valid here.
            weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME)
            return load_tf_weights_in_bert(model, weights_path)

        # Work on a renamed copy so that a caller-supplied state_dict is left untouched.
        state_dict = cls._rename_legacy_keys(state_dict)

        missing_keys = []
        unexpected_keys = []
        error_msgs = []
        metadata = getattr(state_dict, '_metadata', None)

        def load(module, prefix=''):
            """Recursively load ``state_dict`` into ``module`` and its children."""
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + '.')

        # Checkpoints saved from a wrapper model prefix every key with 'bert.'. When
        # this model has no 'bert' attribute, look the weights up under that prefix.
        start_prefix = ''
        if not hasattr(model, 'bert') and any(s.startswith('bert.') for s in state_dict.keys()):
            start_prefix = 'bert.'
        load(model, prefix=start_prefix)
        if len(missing_keys) > 0:
            logger.info("Weights of {} not initialized from pretrained model: {}".format(
                model.__class__.__name__, missing_keys))
        if len(unexpected_keys) > 0:
            logger.info("Weights from pretrained model not used in {}: {}".format(
                model.__class__.__name__, unexpected_keys))
        if len(error_msgs) > 0:
            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                model.__class__.__name__, "\n\t".join(error_msgs)))
        return model
|
|
|
|
|
class BertModel(BertPreTrainedModel):
    """BERT model ("Bidirectional Encoder Representations from Transformers"),
    embeddings + encoder only (no pooling head).

    Params:
        config: a BertConfig class instance with the configuration to build a new model

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary.
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length]
            with token type indices in [0, 1]. Type 0 corresponds to a `sentence A` and
            type 1 to a `sentence B` token. Defaults to all zeros.
        `attention_mask`: an optional tensor of shape [batch_size, sequence_length] with
            values in [0, 1]. 1 marks positions to attend to; 0 marks positions (e.g.
            padding) that are suppressed. Defaults to all ones.
        `output_all_encoded_layers`: controls the content of the output, as described
            below. Default: `True`.

    Outputs: `encoded_layers` only (this model does NOT return a pooled output):
        - `output_all_encoded_layers=True`: the sequence of hidden states returned by
          the encoder, one per attention block (e.g. 12 for BERT-base, 24 for
          BERT-large). Each is a torch.FloatTensor of size
          [batch_size, sequence_length, hidden_size].
        - `output_all_encoded_layers=False`: only the last block's hidden states, of
          shape [batch_size, sequence_length, hidden_size].

    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

    config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

    model = BertModel(config=config)
    all_encoder_layers = model(input_ids, token_type_ids, input_mask)
    ```
    """

    def __init__(self, config):
        super(BertModel, self).__init__(config)
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True):
        """Embed ``input_ids`` and run the encoder; see the class docstring for shapes."""
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # [batch_size, seq_len] -> [batch_size, 1, 1, seq_len] so the mask broadcasts
        # over heads and query positions (presumably it is added to the attention
        # scores inside BertEncoder).
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # Match the model's parameter dtype (e.g. fp16), then map 1 -> 0.0 (attend)
        # and 0 -> -10000.0 (effectively removed once the scores are softmaxed).
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        embedding_output = self.embeddings(input_ids, token_type_ids)
        encoded_layers = self.encoder(embedding_output,
                                      extended_attention_mask,
                                      output_all_encoded_layers=output_all_encoded_layers)
        if not output_all_encoded_layers:
            # The encoder still returns a sequence here; unwrap the last element.
            encoded_layers = encoded_layers[-1]
        return encoded_layers
|
|
|
|
|
class BERT(nn.Module):
    """Language backbone wrapping a pretrained ``BertModel``.

    Args:
        name: pretrained model name passed to ``BertModel.from_pretrained`` when
            ``path`` is None or does not exist (e.g. ``'bert-base-uncased'``).
        path: optional local path to load instead of ``name``.
        train_bert: if False, every BERT parameter is frozen (``requires_grad=False``).
        hidden_dim: not used by this module; kept for interface compatibility.
        max_len: not used by this module; kept for interface compatibility.
        enc_num: 1-based index of the encoder layer whose output is returned.
            Values <= 0 return the raw word-embedding lookup instead.

    Attributes:
        num_channels: feature width of the returned features, taken from the loaded
            model's ``config.hidden_size``.
    """

    def __init__(self, name: str, path: str, train_bert: bool, hidden_dim: int, max_len: int, enc_num):
        super().__init__()
        self.enc_num = enc_num
        if path is not None and os.path.exists(path):
            self.bert = BertModel.from_pretrained(path)
        else:
            self.bert = BertModel.from_pretrained(name)
        if self.bert is None:
            # from_pretrained logs an error and returns None when it cannot resolve
            # the checkpoint; fail here instead of with an obscure AttributeError later.
            raise OSError("Could not load BERT from path {!r} or name {!r}".format(path, name))

        # Read the width from the loaded config. The previous name check returned 1024
        # for every model except 'bert-base-uncased', which is wrong for other base
        # models such as 'bert-base-cased' or 'bert-base-chinese'.
        self.num_channels = self.bert.config.hidden_size

        if not train_bert:
            print('Language Model Bert has been frozen!')
            for parameter in self.bert.parameters():
                parameter.requires_grad_(False)

    def forward(self, tensor_list: NestedTensor):
        """Encode token ids and return features together with an inverted mask.

        Args:
            tensor_list: ``tensors`` holds token ids and ``mask`` is passed to BERT as
                its attention mask (1 = attend).

        Returns:
            NestedTensor(features, mask). The returned ``mask`` is the boolean inverse
            of the input mask, i.e. True where the input mask was 0/False.
        """
        if self.enc_num > 0:
            all_encoder_layers = self.bert(tensor_list.tensors, token_type_ids=None,
                                           attention_mask=tensor_list.mask)
            # Output of the enc_num-th encoder layer (1-based).
            xs = all_encoder_layers[self.enc_num - 1]
        else:
            # Only the word-embedding table lookup; bypasses the rest of the
            # embedding module and the encoder.
            xs = self.bert.embeddings.word_embeddings(tensor_list.tensors)

        mask = ~tensor_list.mask.to(torch.bool)
        return NestedTensor(xs, mask)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_bert(cfg):
    """Construct the language backbone selected by ``cfg``.

    The backbone is marked trainable iff ``cfg.TRAIN.LR > 0``. Only the
    ``"pytorch"`` implementation (``cfg.MODEL.LANGUAGE.IMPLEMENT``) is supported.

    Raises:
        ValueError: for any other implementation name.
    """
    trainable = cfg.TRAIN.LR > 0
    implementation = cfg.MODEL.LANGUAGE.IMPLEMENT
    if implementation != "pytorch":
        raise ValueError("Undefined BERT TYPE '%s'" % implementation)

    language_cfg = cfg.MODEL.LANGUAGE
    return BERT(
        language_cfg.TYPE,
        language_cfg.PATH,
        trainable,
        language_cfg.BERT.HIDDEN_DIM,
        language_cfg.BERT.MAX_QUERY_LEN,
        language_cfg.BERT.ENC_NUM,
    )
|
|
|
|
|