File size: 16,018 Bytes
25986db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""

Backbone modules.

"""
import os.path
import shutil
import tarfile
import tempfile

import torch
from pytorch_pretrained_bert import BertConfig, cached_path, CONFIG_NAME, WEIGHTS_NAME, load_tf_weights_in_bert
from pytorch_pretrained_bert.modeling import BertLayerNorm, PRETRAINED_MODEL_ARCHIVE_MAP, logger, BERT_CONFIG_NAME, \
    BertEmbeddings, BertEncoder
from pytorch_pretrained_bert.modeling_transfo_xl import TF_WEIGHTS_NAME
from torch import nn

from lib.utils.misc import NestedTensor


class BertPreTrainedModel(nn.Module):
    """ An abstract class to handle weights initialization and

        a simple interface for downloading and loading pretrained checkpoints.

    """
    def __init__(self, config, *inputs, **kwargs):
        """Store ``config`` on the module after validating its type.

        Raises:
            ValueError: if ``config`` is not a ``BertConfig`` instance.
        """
        super(BertPreTrainedModel, self).__init__()
        if not isinstance(config, BertConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
                "To create a model from a Google pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                ))
        self.config = config

    def init_bert_weights(self, module):
        """ Initialize the weights of ``module`` in place (meant for ``nn.Module.apply``).

        - ``nn.Linear`` / ``nn.Embedding`` weights ~ N(0, config.initializer_range)
        - ``BertLayerNorm``: weight = 1, bias = 0
        - ``nn.Linear`` bias (if present) = 0
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, BertLayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
        """

        Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.

        Download and cache the pre-trained model file if needed.



        Params:

            pretrained_model_name_or_path: either:

                - a str with the name of a pre-trained model to load selected in the list of:

                    . `bert-base-uncased`

                    . `bert-large-uncased`

                    . `bert-base-cased`

                    . `bert-large-cased`

                    . `bert-base-multilingual-uncased`

                    . `bert-base-multilingual-cased`

                    . `bert-base-chinese`

                - a path or url to a pretrained model archive containing:

                    . `bert_config.json` a configuration file for the model

                    . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance

                - a path or url to a pretrained model archive containing:

                    . `bert_config.json` a configuration file for the model

                    . `model.chkpt` a TensorFlow checkpoint

            from_tf: should we load the weights from a locally saved TensorFlow checkpoint

            cache_dir: an optional path to a folder in which the pre-trained checkpoints will be cached.

            state_dict: an optional state dictionary (collections.OrderedDict object) to use instead
                of Google pre-trained checkpoints. It is not modified by this method.

            *inputs, **kwargs: additional input for the specific Bert class

                (ex: num_labels for BertForSequenceClassification)

        Returns:
            The loaded model, or ``None`` if the archive could not be resolved (the reason is logged).

        Raises:
            RuntimeError: if loading the state dict into the model produced errors.

        """
        state_dict = kwargs.pop('state_dict', None)
        cache_dir = kwargs.pop('cache_dir', None)
        from_tf = kwargs.pop('from_tf', False)

        archive_file = PRETRAINED_MODEL_ARCHIVE_MAP.get(
            pretrained_model_name_or_path, pretrained_model_name_or_path)
        # redirect to the cache, if necessary
        try:
            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
        except EnvironmentError:
            logger.error(
                "Model name '{}' was not found in model name list ({}). "
                "We assumed '{}' was a path or url but couldn't find any file "
                "associated to this path or url.".format(
                    pretrained_model_name_or_path,
                    ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
                    archive_file))
            return None
        if resolved_archive_file == archive_file:
            logger.info("loading archive file {}".format(archive_file))
        else:
            logger.info("loading archive file {} from cache at {}".format(
                archive_file, resolved_archive_file))

        tempdir = None
        try:
            if os.path.isdir(resolved_archive_file) or from_tf:
                serialization_dir = resolved_archive_file
            else:
                # Extract archive to temp dir
                tempdir = tempfile.mkdtemp()
                logger.info("extracting archive file {} to temp dir {}".format(
                    resolved_archive_file, tempdir))
                # NOTE(review): extractall trusts member paths inside the archive;
                # only load archives from trusted sources.
                with tarfile.open(resolved_archive_file, 'r:gz') as archive:
                    archive.extractall(tempdir)
                serialization_dir = tempdir
            # Load config
            config_file = os.path.join(serialization_dir, CONFIG_NAME)
            if not os.path.exists(config_file):
                # Backward compatibility with old naming format
                config_file = os.path.join(serialization_dir, BERT_CONFIG_NAME)
            config = BertConfig.from_json_file(config_file)
            logger.info("Model config {}".format(config))
            # Instantiate model.
            model = cls(config, *inputs, **kwargs)
            if state_dict is None and not from_tf:
                weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
                # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
                state_dict = torch.load(weights_path, map_location='cpu')
        finally:
            # Always remove the extracted archive, even if loading above failed.
            if tempdir:
                shutil.rmtree(tempdir)

        if from_tf:
            # Directly load from a TensorFlow checkpoint
            weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME)
            return load_tf_weights_in_bert(model, weights_path)

        # Copy first so neither the caller's dict nor its _metadata is mutated below;
        # _load_from_state_dict may also modify the dict it is given.
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        # Rename legacy LayerNorm parameter names (gamma/beta -> weight/bias).
        # Collect the renames first so the dict is not resized while iterating it.
        renames = []
        for key in state_dict.keys():
            new_key = None
            if 'gamma' in key:
                new_key = key.replace('gamma', 'weight')
            if 'beta' in key:
                new_key = key.replace('beta', 'bias')
            if new_key:
                renames.append((key, new_key))
        for old_key, new_key in renames:
            state_dict[new_key] = state_dict.pop(old_key)

        missing_keys = []
        unexpected_keys = []
        error_msgs = []

        def load(module, prefix=''):
            # Recursively load each submodule's own parameters under its dotted prefix.
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + '.')

        # Checkpoints saved from wrapper heads prefix keys with 'bert.'; strip it for a bare model.
        start_prefix = ''
        if not hasattr(model, 'bert') and any(s.startswith('bert.') for s in state_dict.keys()):
            start_prefix = 'bert.'
        load(model, prefix=start_prefix)
        if len(missing_keys) > 0:
            logger.info("Weights of {} not initialized from pretrained model: {}".format(
                model.__class__.__name__, missing_keys))
        if len(unexpected_keys) > 0:
            logger.info("Weights from pretrained model not used in {}: {}".format(
                model.__class__.__name__, unexpected_keys))
        if len(error_msgs) > 0:
            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                               model.__class__.__name__, "\n\t".join(error_msgs)))
        return model


class BertModel(BertPreTrainedModel):
    """BERT model ("Bidirectional Embedding Representations from a Transformer") without a pooler.

    Only the embeddings and the encoder stack are built; unlike the upstream BertModel this
    class does NOT compute a pooled `[CLS]` output.

    Params:

        config: a BertConfig class instance with the configuration to build a new model

    Inputs:

        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
            `extract_features.py`, `run_classifier.py` and `run_squad.py`)

        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
            types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
            a `sentence B` token (see BERT paper for more details). Defaults to all zeros.

        `attention_mask`: an optional tensor of shape [batch_size, sequence_length] with values
            in [0, 1]: 1 for positions to attend to, 0 for padding. Defaults to all ones.

        `output_all_encoded_layers`: boolean which controls the content of the output as described
            below. Default: `True`.

    Outputs: `encoded_layers` only (there is no `pooled_output`):

        - `output_all_encoded_layers=True`: the encoder's full output, one entry per attention
            block (i.e. 12 for BERT-base, 24 for BERT-large), each of size
            [batch_size, sequence_length, hidden_size].

        - `output_all_encoded_layers=False`: only the hidden states of the last attention block,
            of shape [batch_size, sequence_length, hidden_size].

    Example usage:

    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

    config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

    model = BertModel(config=config)
    all_encoder_layers = model(input_ids, token_type_ids, input_mask)
    ```
    """
    def __init__(self, config):
        super(BertModel, self).__init__(config)
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        # The pooler is intentionally not built: callers only consume encoder hidden states.
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True):
        """Embed ``input_ids`` and run the encoder; see the class docstring for shapes/outputs."""
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # this attention mask is more simple than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        embedding_output = self.embeddings(input_ids, token_type_ids)
        encoded_layers = self.encoder(embedding_output,
                                      extended_attention_mask,
                                      output_all_encoded_layers=output_all_encoded_layers)
        # No pooling step: return the encoder output directly. When only the last layer is
        # requested, unwrap it from the encoder's (indexable) return value.
        if not output_all_encoded_layers:
            encoded_layers = encoded_layers[-1]
        return encoded_layers


class BERT(nn.Module):
    """Language backbone wrapping a pretrained ``BertModel``.

    ``forward`` returns a ``NestedTensor`` holding per-token features and a boolean mask.
    """

    def __init__(self, name: str, path: str, train_bert: bool, hidden_dim: int, max_len: int, enc_num):
        """
        Args:
            name: pretrained model name (e.g. ``'bert-base-uncased'``); used when ``path`` is
                ``None`` or does not exist.
            path: optional local model path/archive; preferred over ``name`` when it exists.
            train_bert: if False, every BERT parameter is frozen (``requires_grad=False``).
            hidden_dim: not used by this module; kept for interface compatibility.
            max_len: not used by this module; kept for interface compatibility.
            enc_num: how many encoder layers to run through; the output of layer ``enc_num``
                is used. ``0`` (or less) uses the raw word embeddings instead.

        Raises:
            RuntimeError: if the pretrained model could not be loaded.
        """
        super().__init__()
        self.enc_num = enc_num
        source = path if path is not None and os.path.exists(path) else name
        self.bert = BertModel.from_pretrained(source)
        if self.bert is None:
            # from_pretrained logs the reason and returns None instead of raising; fail
            # here with a clear message rather than later with an AttributeError.
            raise RuntimeError("Failed to load pretrained BERT model from '{}'".format(source))
        # Take the feature width from the loaded config instead of guessing from the name,
        # so e.g. 'bert-base-cased' / 'bert-base-chinese' correctly report 768 channels.
        self.num_channels = self.bert.config.hidden_size

        if not train_bert:
            print('Language Model Bert has been frozen!')
            for parameter in self.bert.parameters():
                parameter.requires_grad_(False)

    def forward(self, tensor_list: NestedTensor):
        """Encode ``tensor_list.tensors`` (token ids) and return ``NestedTensor(features, mask)``.

        ``tensor_list.mask`` is passed to BERT as the attention mask (1 = attend), so the
        returned mask, which is its boolean negation, is True at non-attended positions.
        """
        if self.enc_num > 0:
            all_encoder_layers = self.bert(tensor_list.tensors, token_type_ids=None, attention_mask=tensor_list.mask)
            # use the output of the X-th transformer encoder layers
            xs = all_encoder_layers[self.enc_num - 1]
        else:
            xs = self.bert.embeddings.word_embeddings(tensor_list.tensors)

        mask = tensor_list.mask.to(torch.bool)
        mask = ~mask
        out = NestedTensor(xs, mask)

        return out


# def build_bert(cfg):
#     # position_embedding = build_position_encoding(cfg)
#     train_bert = cfg.MODEL.LANGUAGE.BERT.LR > 0
#     bert = BERT(cfg.MODEL.LANGUAGE.TYPE, cfg.MODEL.LANGUAGE.PATH, train_bert, cfg.MODEL.LANGUAGE.BERT.HIDDEN_DIM,
#                 cfg.MODEL.LANGUAGE.BERT.MAX_QUERY_LEN, cfg.MODEL.LANGUAGE.BERT.ENC_NUM)
#     # model = Joiner(bert, position_embedding)
#     # model.num_channels = bert.num_channels
#     return bert

def build_bert(cfg):
    """Construct the language backbone described by ``cfg``.

    BERT is trainable only when ``cfg.TRAIN.LR`` is positive. Only the ``"pytorch"``
    implementation (``cfg.MODEL.LANGUAGE.IMPLEMENT``) is supported.

    Raises:
        ValueError: if ``cfg.MODEL.LANGUAGE.IMPLEMENT`` is not ``"pytorch"``.
    """
    # Freeze the language model unless a positive learning rate is configured.
    train_bert = cfg.TRAIN.LR > 0
    implementation = cfg.MODEL.LANGUAGE.IMPLEMENT
    if implementation != "pytorch":
        raise ValueError("Undefined BERT TYPE '%s'" % implementation)

    language = cfg.MODEL.LANGUAGE
    bert_cfg = language.BERT
    return BERT(
        language.TYPE,
        language.PATH,
        train_bert,
        bert_cfg.HIDDEN_DIM,
        bert_cfg.MAX_QUERY_LEN,
        bert_cfg.ENC_NUM,
    )