andersonbcdefg commited on
Commit
b4bad42
·
1 Parent(s): 5a63953

Upload modeling_bert.py

Browse files
Files changed (1) hide show
  1. modeling_bert.py +763 -0
modeling_bert.py ADDED
@@ -0,0 +1,763 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, Tri Dao.
2
+ # This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation.
3
+ # https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
4
+ # https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py
5
+
6
+ # Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
7
+
8
+ import logging
9
+ import re
10
+ from collections import OrderedDict
11
+ from collections.abc import Sequence
12
+ from functools import partial
13
+ from typing import Any, Mapping
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ from einops import rearrange
19
+ from transformers import BertConfig, PretrainedConfig
20
+ from transformers.models.bert.modeling_bert import (
21
+ BaseModelOutputWithPoolingAndCrossAttentions,
22
+ BertForPreTrainingOutput,
23
+ )
24
+
25
+ from flash_attn.bert_padding import (
26
+ index_first_axis,
27
+ index_first_axis_residual,
28
+ pad_input,
29
+ unpad_input,
30
+ )
31
+ from flash_attn.modules.block import Block
32
+ from flash_attn.modules.embedding import BertEmbeddings
33
+ from flash_attn.modules.mha import MHA
34
+ from flash_attn.modules.mlp import FusedMLP, Mlp
35
+ from flash_attn.utils.pretrained import state_dict_from_pretrained
36
+
37
+ try:
38
+ from flash_attn.ops.fused_dense import FusedDense
39
+ except ImportError:
40
+ FusedDense = None
41
+
42
+ try:
43
+ from flash_attn.ops.layer_norm import dropout_add_layer_norm, layer_norm
44
+ except ImportError:
45
+ dropout_add_layer_norm, layer_norm = None, None
46
+
47
+ try:
48
+ from flash_attn.losses.cross_entropy import CrossEntropyLoss
49
+ except ImportError:
50
+ CrossEntropyLoss = None
51
+
52
+
53
+ logger = logging.getLogger(__name__)
54
+
55
+
56
def create_mixer_cls(config, cross_attn=False, return_residual=False):
    """Build a partial constructor for the MHA mixer configured from `config`.

    Rotary-embedding kwargs are only forwarded when the config asks for rotary
    position embeddings; otherwise MHA uses its defaults.
    """
    rotary_kwargs = {}
    if config.position_embedding_type == "rotary":
        rotary_kwargs = {
            "rotary_emb_dim": getattr(config, "rotary_emb_dim", config.hidden_size),
            "rotary_emb_base": getattr(config, "rotary_emb_base", 10000.0),
            "rotary_emb_scale_base": getattr(config, "rotary_emb_scale_base", None),
            "rotary_emb_interleaved": getattr(config, "rotary_emb_interleaved", False),
        }
    return partial(
        MHA,
        num_heads=config.num_attention_heads,
        cross_attn=cross_attn,
        dropout=config.attention_probs_dropout_prob,
        causal=False,  # BERT attention is bidirectional
        fused_bias_fc=getattr(config, "fused_bias_fc", False),
        use_flash_attn=getattr(config, "use_flash_attn", False),
        return_residual=return_residual,
        **rotary_kwargs,
    )
77
+
78
+
79
def create_mlp_cls(config, layer_idx=None, return_residual=False):
    """Build a partial constructor for the MLP sub-block (plain or fused)."""
    inner_dim = config.intermediate_size
    approx_gelu_acts = ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
    if getattr(config, "fused_mlp", False):
        # The fused kernel hard-codes the tanh-approximated gelu.
        assert config.hidden_act in approx_gelu_acts, (
            "fused_mlp only supports approximate gelu"
        )
        if FusedMLP is None:
            raise ImportError("fused_dense is not installed")
        # mlp_checkpoint_lvl may be a sequence giving a per-layer checkpoint level.
        ckpt_lvl = getattr(config, "mlp_checkpoint_lvl", 0)
        if isinstance(ckpt_lvl, Sequence):
            assert layer_idx is not None
            ckpt_lvl = ckpt_lvl[layer_idx]
        return partial(
            FusedMLP,
            hidden_features=inner_dim,
            checkpoint_lvl=ckpt_lvl,
            return_residual=return_residual,
        )
    gelu_approx = "tanh" if config.hidden_act in approx_gelu_acts else "none"
    return partial(
        Mlp,
        hidden_features=inner_dim,
        activation=partial(F.gelu, approximate=gelu_approx),
        return_residual=return_residual,
    )
113
+
114
+
115
def create_block(config, layer_idx=None):
    """Assemble one post-norm Transformer block (mixer + MLP + norms)."""
    last_layer_subset = getattr(config, "last_layer_subset", False)
    cross_attn = last_layer_subset and layer_idx == config.num_hidden_layers - 1
    # TD [2022-12-19]: For cross attention (last layer), we actually want to return the
    # residual x_kv, not residual x. But it's annoying to change the API (and it only
    # affects one layer) so we just choose not to return residual in this case.
    return_residual = not cross_attn
    return Block(
        config.hidden_size,
        create_mixer_cls(config, cross_attn, return_residual=return_residual),
        create_mlp_cls(config, layer_idx, return_residual=return_residual),
        norm_cls=partial(nn.LayerNorm, eps=config.layer_norm_eps),
        prenorm=False,  # BERT uses post-norm blocks
        resid_dropout1=config.hidden_dropout_prob,
        resid_dropout2=config.hidden_dropout_prob,
        fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False),
        return_residual=return_residual,
    )
137
+
138
+
139
+ # https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/src/transformers/models/bert/modeling_bert.py#L748
140
+ def _init_weights(module, initializer_range=0.02):
141
+ if isinstance(module, nn.Linear):
142
+ nn.init.normal_(module.weight, std=initializer_range)
143
+ if module.bias is not None:
144
+ nn.init.zeros_(module.bias)
145
+ elif isinstance(module, nn.Embedding):
146
+ nn.init.normal_(module.weight, std=initializer_range)
147
+ if module.padding_idx is not None:
148
+ nn.init.zeros_(module.weight[module.padding_idx])
149
+
150
+
151
class BertEncoder(nn.Module):
    """Stack of Transformer blocks.

    Supports an unpadded (flash-attn) execution path and, optionally, computing
    the last layer only for a subset of tokens (MLPerf-style optimization).
    """

    def __init__(self, config: BertConfig):
        super().__init__()
        self.use_flash_attn = getattr(config, "use_flash_attn", False)
        self.layers = nn.ModuleList(
            [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
        )

    def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
        """If subset_mask is not None, we only want output for the subset of the sequence.
        This means that we only compute the last layer output for these tokens.
        subset_mask: (batch, seqlen), dtype=torch.bool
        """
        if key_padding_mask is None or not self.use_flash_attn:
            # Padded path: every layer sees the full (batch, seqlen, d) tensor.
            mixer_kwargs = (
                {"key_padding_mask": key_padding_mask} if key_padding_mask is not None else None
            )
            for layer in self.layers:
                hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
            if subset_mask is not None:
                hidden_states = hidden_states[subset_mask]
        else:
            # Unpadded path: drop padding tokens and run flash attention on the
            # packed (total_nnz, d) tensor.
            batch, seqlen = hidden_states.shape[:2]
            hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
                hidden_states, key_padding_mask
            )
            mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
            if subset_mask is None:
                for layer in self.layers:
                    hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
                hidden_states = pad_input(hidden_states, indices, batch, seqlen)
            else:
                # Run all but the last layer on the full packed sequence, then run
                # the last (cross-attention) layer only on the subset tokens.
                for layer in self.layers[:-1]:
                    hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
                if key_padding_mask is not None:
                    subset_idx = torch.nonzero(
                        subset_mask[key_padding_mask], as_tuple=False
                    ).flatten()
                    subset_seqlens = (subset_mask & key_padding_mask).sum(
                        dim=-1, dtype=torch.int32
                    )
                else:
                    subset_idx = torch.nonzero(subset_mask, as_tuple=False).flatten()
                    subset_seqlens = subset_mask.sum(dim=-1, dtype=torch.int32)
                # Fixed: dtype was written `torch.torch.int32` (it only worked
                # because `torch.torch` aliases `torch`). Also de-duplicated the
                # cu_seqlens computation across the two branches above.
                subset_cu_seqlens = F.pad(
                    torch.cumsum(subset_seqlens, dim=0, dtype=torch.int32), (1, 0)
                )
                hidden_states_subset, hidden_states = index_first_axis_residual(
                    hidden_states, subset_idx
                )
                # It's ok to set max_seqlen_q to be much larger
                mixer_kwargs = {
                    "x_kv": hidden_states,
                    "cu_seqlens": subset_cu_seqlens,
                    "max_seqlen": max_seqlen_in_batch,
                    "cu_seqlens_k": cu_seqlens,
                    "max_seqlen_k": max_seqlen_in_batch,
                }
                hidden_states = self.layers[-1](hidden_states_subset, mixer_kwargs=mixer_kwargs)
        return hidden_states
212
+
213
+
214
class BertPooler(nn.Module):
    """Projects a hidden state through a tanh-activated dense layer.

    By default "pools" by taking the first ([CLS]) token of each sequence.
    """

    def __init__(self, config):
        super().__init__()
        fused_bias_fc = getattr(config, "fused_bias_fc", False)
        if fused_bias_fc and FusedDense is None:
            raise ImportError("fused_dense is not installed")
        linear_cls = FusedDense if fused_bias_fc else nn.Linear
        self.dense = linear_cls(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states, pool=True):
        # When the caller has already selected the token(s) of interest,
        # pass pool=False to skip the [CLS] slicing.
        token_states = hidden_states[:, 0] if pool else hidden_states
        return self.activation(self.dense(token_states))
231
+
232
+
233
class BertPredictionHeadTransform(nn.Module):
    """Dense -> gelu -> LayerNorm transform applied before the MLM decoder."""

    def __init__(self, config):
        super().__init__()
        fused_bias_fc = getattr(config, "fused_bias_fc", False)
        if fused_bias_fc and FusedDense is None:
            raise ImportError("fused_dense is not installed")
        self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
        if self.fused_dropout_add_ln and layer_norm is None:
            raise ImportError("dropout_add_layer_norm is not installed")
        linear_cls = FusedDense if fused_bias_fc else nn.Linear
        self.dense = linear_cls(config.hidden_size, config.hidden_size)
        approx_gelu_acts = ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
        self.transform_act_fn = nn.GELU(
            approximate="tanh" if config.hidden_act in approx_gelu_acts else "none"
        )
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.transform_act_fn(self.dense(hidden_states))
        if self.fused_dropout_add_ln:
            # Fused CUDA LayerNorm kernel, reusing the module's parameters.
            return layer_norm(
                hidden_states, self.layer_norm.weight, self.layer_norm.bias, self.layer_norm.eps
            )
        return self.layer_norm(hidden_states)
262
+
263
+
264
class BertLMPredictionHead(nn.Module):
    """MLM head: feature transform followed by a vocab-sized decoder projection."""

    def __init__(self, config):
        super().__init__()
        fused_bias_fc = getattr(config, "fused_bias_fc", False)
        if fused_bias_fc and FusedDense is None:
            raise ImportError("fused_dense is not installed")
        linear_cls = FusedDense if fused_bias_fc else nn.Linear

        self.transform = BertPredictionHeadTransform(config)

        # The decoder weight gets tied to the input word embeddings elsewhere
        # (see tie_weights); the bias stays an independent per-token parameter.
        self.decoder = linear_cls(config.hidden_size, config.vocab_size, bias=True)

    def forward(self, hidden_states):
        return self.decoder(self.transform(hidden_states))
282
+
283
+
284
class BertPreTrainingHeads(nn.Module):
    """Bundles the MLM prediction head with the binary NSP classifier."""

    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):
        mlm_scores = self.predictions(sequence_output)
        nsp_scores = self.seq_relationship(pooled_output)
        return mlm_scores, nsp_scores
294
+
295
+
296
class BertPreTrainedModel(nn.Module):
    """An abstract class to handle weights initialization and a simple
    interface for downloading and loading pretrained models.
    """

    def __init__(self, config, *inputs, **kwargs):
        super().__init__()
        if not isinstance(config, BertConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
                "To create a model from a Google pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                )
            )
        self.config = config

    @classmethod
    def from_pretrained(cls, model_name, config, *inputs, **kwargs):
        """Instantiate a BertPreTrainedModel from a pre-trained checkpoint.

        Downloads/caches the pretrained weights, remaps their keys into the
        flash_attn layout, and loads them non-strictly (missing/unexpected keys
        are logged, not fatal).

        Params:
            model_name: name or path of the pretrained Huggingface checkpoint.
            config: BertConfig for the model being constructed.
            *inputs, **kwargs: forwarded to the concrete class constructor
                (e.g. num_labels for BertForSequenceClassification).
        """
        model = cls(config, *inputs, **kwargs)
        load_return = model.load_state_dict(
            remap_state_dict(state_dict_from_pretrained(model_name), config), strict=False
        )
        # Surface missing/unexpected keys for debugging.
        logger.info(load_return)
        return model
337
+
338
+
339
class BertModel(BertPreTrainedModel):
    """BERT embeddings + encoder stack with an optional pooler."""

    def __init__(self, config: BertConfig, add_pooling_layer=True):
        super().__init__(config)
        self.pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
        # Round the vocab size up to a multiple so the LM-head matmul is GPU-friendly.
        remainder = config.vocab_size % self.pad_vocab_size_multiple
        if remainder != 0:
            config.vocab_size += self.pad_vocab_size_multiple - remainder
        self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
        if self.fused_dropout_add_ln and layer_norm is None:
            raise ImportError("dropout_add_layer_norm is not installed")
        assert config.hidden_act in ["gelu", "gelu_new", "gelu_fast", "gelu_pytorch_tanh"]

        self.embeddings = BertEmbeddings(
            config.hidden_size,
            config.vocab_size,
            config.max_position_embeddings,
            config.type_vocab_size,
            padding_idx=config.pad_token_id,
        )
        self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
        self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config) if add_pooling_layer else None

        self.apply(partial(_init_weights, initializer_range=config.initializer_range))

    def forward(
        self,
        input_ids,
        position_ids=None,
        token_type_ids=None,
        attention_mask=None,
        masked_tokens_mask=None,
    ):
        """If masked_tokens_mask is not None (i.e. last_layer_subset == True in
        BertForPreTraining), we only want the output for the masked tokens. This means
        that we only compute the last layer output for these tokens.
        masked_tokens_mask: (batch, seqlen), dtype=torch.bool
        """
        hidden_states = self.embeddings(
            input_ids, position_ids=position_ids, token_type_ids=token_type_ids
        )
        # TD [2022-12:18]: Don't need to force residual in fp32.
        # BERT applies the embedding LayerNorm *before* the embedding dropout.
        if self.fused_dropout_add_ln:
            hidden_states = layer_norm(
                hidden_states, self.emb_ln.weight, self.emb_ln.bias, self.emb_ln.eps
            )
        else:
            hidden_states = self.emb_ln(hidden_states)
        hidden_states = self.emb_drop(hidden_states)

        if masked_tokens_mask is not None:
            batch_size, seqlen = input_ids.shape[:2]
            # Always keep the first column ([CLS]) so the pooler has its input.
            first_col_mask = torch.zeros(
                batch_size, seqlen, dtype=torch.bool, device=input_ids.device
            )
            first_col_mask[:, 0] = True
            subset_mask = masked_tokens_mask | first_col_mask
        else:
            subset_mask = None

        sequence_output = self.encoder(
            hidden_states, key_padding_mask=attention_mask, subset_mask=subset_mask
        )

        if masked_tokens_mask is None:
            pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
        else:
            # TD [2022-03-01]: the indexing here is very tricky.
            if attention_mask is not None:
                subset_idx = subset_mask[attention_mask]
                pool_input = sequence_output[first_col_mask[attention_mask][subset_idx]]
                sequence_output = sequence_output[masked_tokens_mask[attention_mask][subset_idx]]
            else:
                pool_input = sequence_output[first_col_mask[subset_mask]]
                sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
            pooled_output = self.pooler(pool_input, pool=False) if self.pooler is not None else None

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
        )
424
+
425
+
426
class BertForPreTraining(BertPreTrainedModel):
    """BERT with MLM + NSP pre-training heads and MLPerf-style optimizations."""

    def __init__(self, config: BertConfig):
        super().__init__(config)
        # If dense_seq_output, only the hidden states of the masked tokens
        # (around 15%) are passed to the classifier heads.
        self.dense_seq_output = getattr(config, "dense_seq_output", False)
        # If last_layer_subset, the last layer is only computed for the tokens
        # needed by the MLM loss and the next-sentence prediction.
        self.last_layer_subset = getattr(config, "last_layer_subset", False)
        if self.last_layer_subset:
            assert self.dense_seq_output, "last_layer_subset requires dense_seq_output"
        use_xentropy = getattr(config, "use_xentropy", False)
        if use_xentropy and CrossEntropyLoss is None:
            raise ImportError("xentropy_cuda is not installed")
        loss_cls = (
            partial(CrossEntropyLoss, inplace_backward=True)
            if use_xentropy
            else nn.CrossEntropyLoss
        )

        self.bert = BertModel(config)
        self.cls = BertPreTrainingHeads(config)
        # Labels of 0 mark unmasked tokens; NSP uses -1 as its ignore label.
        self.mlm_loss = loss_cls(ignore_index=0)
        self.nsp_loss = loss_cls(ignore_index=-1)

        # Initialize weights and apply final processing
        self.apply(partial(_init_weights, initializer_range=config.initializer_range))
        self.tie_weights()

    def tie_weights(self):
        # Share the MLM decoder weight with the input word embeddings.
        self.cls.predictions.decoder.weight = self.bert.embeddings.word_embeddings.weight

    def forward(
        self,
        input_ids,
        position_ids=None,
        token_type_ids=None,
        attention_mask=None,
        labels=None,
        next_sentence_label=None,
    ):
        """
        If labels are provided, they must be 0 for masked out tokens (as specified in the
        attention mask).
        Outputs:
            if `labels` and `next_sentence_label` are not `None`:
                the total loss: masked-LM loss + next-sentence classification loss.
            if `labels` or `next_sentence_label` is `None`:
                - masked LM logits of shape [batch_size, sequence_length, vocab_size]
                - next-sentence logits of shape [batch_size, 2]
        """
        masked_tokens_mask = labels > 0 if (self.last_layer_subset and labels is not None) else None
        outputs = self.bert(
            input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask.bool() if attention_mask is not None else None,
            masked_tokens_mask=masked_tokens_mask,
        )
        sequence_output, pooled_output = outputs.last_hidden_state, outputs.pooler_output
        if self.dense_seq_output and labels is not None:
            masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten()
            if not self.last_layer_subset:
                # With last_layer_subset the encoder already returned only the
                # masked tokens; otherwise gather them here.
                sequence_output = index_first_axis(
                    rearrange(sequence_output, "b s d -> (b s) d"), masked_token_idx
                )
        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

        total_loss = None
        if labels is not None and next_sentence_label is not None:
            if self.dense_seq_output and labels is not None:
                # prediction_scores are already flattened over masked tokens.
                masked_lm_loss = self.mlm_loss(
                    prediction_scores, labels.flatten()[masked_token_idx]
                )
            else:
                masked_lm_loss = self.mlm_loss(
                    rearrange(prediction_scores, "... v -> (...) v"),
                    rearrange(labels, "... -> (...)"),
                )
            next_sentence_loss = self.nsp_loss(
                rearrange(seq_relationship_score, "... t -> (...) t"),
                rearrange(next_sentence_label, "... -> (...)"),
            )
            total_loss = masked_lm_loss.float() + next_sentence_loss.float()

        return BertForPreTrainingOutput(
            loss=total_loss,
            prediction_logits=prediction_scores,
            seq_relationship_logits=seq_relationship_score,
        )
521
+
522
+
523
+ def remap_state_dict(state_dict, config: PretrainedConfig):
524
+ """
525
+ Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
526
+ """
527
+
528
+ # LayerNorm
529
+ def key_mapping_ln_gamma_beta(key):
530
+ key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key)
531
+ key = re.sub(r"LayerNorm.beta$", "LayerNorm.bias", key)
532
+ return key
533
+
534
+ state_dict = OrderedDict((key_mapping_ln_gamma_beta(k), v) for k, v in state_dict.items())
535
+
536
+ # Layers
537
+ def key_mapping_layers(key):
538
+ return re.sub(r"^bert.encoder.layer.", "bert.encoder.layers.", key)
539
+
540
+ state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
541
+
542
+ # LayerNorm
543
+ def key_mapping_ln(key):
544
+ key = re.sub(r"^bert.embeddings.LayerNorm.", "bert.emb_ln.", key)
545
+ key = re.sub(
546
+ r"^bert.encoder.layers.(\d+).attention.output.LayerNorm.(weight|bias)",
547
+ r"bert.encoder.layers.\1.norm1.\2",
548
+ key,
549
+ )
550
+ key = re.sub(
551
+ r"^bert.encoder.layers.(\d+).output.LayerNorm.(weight|bias)",
552
+ r"bert.encoder.layers.\1.norm2.\2",
553
+ key,
554
+ )
555
+ key = re.sub(
556
+ r"^cls.predictions.transform.LayerNorm.(weight|bias)",
557
+ r"cls.predictions.transform.layer_norm.\1",
558
+ key,
559
+ )
560
+ return key
561
+
562
+ state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
563
+
564
+ # MLP
565
+ def key_mapping_mlp(key):
566
+ key = re.sub(
567
+ r"^bert.encoder.layers.(\d+).intermediate.dense.(weight|bias)",
568
+ r"bert.encoder.layers.\1.mlp.fc1.\2",
569
+ key,
570
+ )
571
+ key = re.sub(
572
+ r"^bert.encoder.layers.(\d+).output.dense.(weight|bias)",
573
+ r"bert.encoder.layers.\1.mlp.fc2.\2",
574
+ key,
575
+ )
576
+ return key
577
+
578
+ state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
579
+
580
+ # Attention
581
+ last_layer_subset = getattr(config, "last_layer_subset", False)
582
+ for d in range(config.num_hidden_layers):
583
+ Wq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.weight")
584
+ Wk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.weight")
585
+ Wv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.weight")
586
+ bq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.bias")
587
+ bk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.bias")
588
+ bv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.bias")
589
+ if not (last_layer_subset and d == config.num_hidden_layers - 1):
590
+ state_dict[f"bert.encoder.layers.{d}.mixer.Wqkv.weight"] = torch.cat(
591
+ [Wq, Wk, Wv], dim=0
592
+ )
593
+ state_dict[f"bert.encoder.layers.{d}.mixer.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0)
594
+ else:
595
+ state_dict[f"bert.encoder.layers.{d}.mixer.Wq.weight"] = Wq
596
+ state_dict[f"bert.encoder.layers.{d}.mixer.Wkv.weight"] = torch.cat([Wk, Wv], dim=0)
597
+ state_dict[f"bert.encoder.layers.{d}.mixer.Wq.bias"] = bq
598
+ state_dict[f"bert.encoder.layers.{d}.mixer.Wkv.bias"] = torch.cat([bk, bv], dim=0)
599
+
600
+ def key_mapping_attn(key):
601
+ return re.sub(
602
+ r"^bert.encoder.layers.(\d+).attention.output.dense.(weight|bias)",
603
+ r"bert.encoder.layers.\1.mixer.out_proj.\2",
604
+ key,
605
+ )
606
+
607
+ state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
608
+
609
+ def key_mapping_decoder_bias(key):
610
+ return re.sub(r"^cls.predictions.bias", "cls.predictions.decoder.bias", key)
611
+
612
+ state_dict = OrderedDict((key_mapping_decoder_bias(k), v) for k, v in state_dict.items())
613
+
614
+ # Word embedding
615
+ pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
616
+ if pad_vocab_size_multiple > 1:
617
+ word_embeddings = state_dict["bert.embeddings.word_embeddings.weight"]
618
+ state_dict["bert.embeddings.word_embeddings.weight"] = F.pad(
619
+ word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0])
620
+ )
621
+ decoder_weight = state_dict["cls.predictions.decoder.weight"]
622
+ state_dict["cls.predictions.decoder.weight"] = F.pad(
623
+ decoder_weight, (0, 0, 0, config.vocab_size - decoder_weight.shape[0])
624
+ )
625
+ # If the vocab was padded, we want to set the decoder bias for those padded indices to be
626
+ # strongly negative (i.e. the decoder shouldn't predict those indices).
627
+ # TD [2022-05-09]: I don't think it affects the MLPerf training.
628
+ decoder_bias = state_dict["cls.predictions.decoder.bias"]
629
+ state_dict["cls.predictions.decoder.bias"] = F.pad(
630
+ decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0
631
+ )
632
+
633
+ return state_dict
634
+
635
+
636
+ def inv_remap_state_dict(state_dict, config: PretrainedConfig):
637
+ """
638
+ Map the state_dict of a flash_attn model to be Huggingface BERT compatible.
639
+
640
+ This function is meant to be the inverse of remap_state_dict.
641
+ """
642
+ # Word embedding
643
+ pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
644
+ if pad_vocab_size_multiple > 1:
645
+ word_embeddings = state_dict["bert.embeddings.word_embeddings.weight"]
646
+ decoder_weight = state_dict["cls.predictions.decoder.weight"]
647
+ decoder_bias = state_dict["cls.predictions.decoder.bias"]
648
+ # unpad embeddings
649
+ state_dict["bert.embeddings.word_embeddings.weight"] = word_embeddings[
650
+ : config.orig_vocab_size, :
651
+ ]
652
+ state_dict["cls.predictions.decoder.weight"] = decoder_weight[: config.orig_vocab_size, :]
653
+ state_dict["cls.predictions.decoder.bias"] = decoder_bias[: config.orig_vocab_size]
654
+
655
+ for d in range(config.num_hidden_layers):
656
+ last_layer_subset = getattr(config, "last_layer_subset", False)
657
+ if not last_layer_subset or d != (config.num_hidden_layers - 1):
658
+ Wqkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.weight")
659
+ Wqkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.bias")
660
+ state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = Wqkv_weights[
661
+ : Wqkv_weights.shape[0] // 3, :
662
+ ]
663
+ state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = Wqkv_weights[
664
+ Wqkv_weights.shape[0] // 3 : 2 * Wqkv_weights.shape[0] // 3, :
665
+ ]
666
+ state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = Wqkv_weights[
667
+ 2 * Wqkv_weights.shape[0] // 3 :, :
668
+ ]
669
+ state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wqkv_biases[
670
+ : Wqkv_biases.shape[0] // 3
671
+ ]
672
+ state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wqkv_biases[
673
+ Wqkv_biases.shape[0] // 3 : 2 * Wqkv_biases.shape[0] // 3
674
+ ]
675
+ state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = Wqkv_biases[
676
+ 2 * Wqkv_biases.shape[0] // 3 :
677
+ ]
678
+ else:
679
+ Wq_weight = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.weight")
680
+ Wkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.weight")
681
+ Wq_bias = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.bias")
682
+ Wkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.bias")
683
+ state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = Wq_weight
684
+ state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = Wkv_weights[
685
+ : Wkv_weights.shape[0] // 2, :
686
+ ]
687
+ state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = Wkv_weights[
688
+ Wkv_weights.shape[0] // 2 :, :
689
+ ]
690
+ state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wq_bias
691
+ state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wkv_biases[
692
+ : Wkv_biases.shape[0] // 2
693
+ ]
694
+ state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = Wkv_biases[
695
+ Wkv_biases.shape[0] // 2 :
696
+ ]
697
+
698
+ def inv_key_mapping_ln(key):
699
+ key = re.sub(r"bert.emb_ln.", "bert.embeddings.LayerNorm.", key)
700
+ key = re.sub(
701
+ r"bert.encoder.layers.(\d+).norm1.(weight|bias)",
702
+ r"bert.encoder.layers.\1.attention.output.LayerNorm.\2",
703
+ key,
704
+ )
705
+ key = re.sub(
706
+ r"bert.encoder.layers.(\d+).norm2.(weight|bias)",
707
+ r"bert.encoder.layers.\1.output.LayerNorm.\2",
708
+ key,
709
+ )
710
+ key = re.sub(
711
+ r"cls.predictions.transform.layer_norm.(weight|bias)",
712
+ r"cls.predictions.transform.LayerNorm.\1",
713
+ key,
714
+ )
715
+ return key
716
+
717
+ def inv_key_mapping_ln_gamma_beta(key):
718
+ key = re.sub(r"LayerNorm.weight$", "LayerNorm.gamma", key)
719
+ key = re.sub(r"LayerNorm.bias$", "LayerNorm.beta", key)
720
+ return key
721
+
722
+ def inv_key_mapping_layers(key):
723
+ return re.sub(r"bert.encoder.layers.", "bert.encoder.layer.", key)
724
+
725
+ def inv_key_mapping_mlp(key):
726
+ key = re.sub(
727
+ r"bert.encoder.layer.(\d+).mlp.fc1.(weight|bias)",
728
+ r"bert.encoder.layer.\1.intermediate.dense.\2",
729
+ key,
730
+ )
731
+ key = re.sub(
732
+ r"bert.encoder.layer.(\d+).mlp.fc2.(weight|bias)",
733
+ r"bert.encoder.layer.\1.output.dense.\2",
734
+ key,
735
+ )
736
+ return key
737
+
738
+ def inv_key_mapping_attn(key):
739
+ return re.sub(
740
+ r"bert.encoder.layer.(\d+).mixer.out_proj.(weight|bias)",
741
+ r"bert.encoder.layer.\1.attention.output.dense.\2",
742
+ key,
743
+ )
744
+
745
+ def inv_key_mapping_decoder_bias(key):
746
+ return re.sub(r"cls.predictions.decoder.bias", "cls.predictions.bias", key)
747
+
748
+ state_dict = OrderedDict((inv_key_mapping_ln(key), value) for key, value in state_dict.items())
749
+ state_dict = OrderedDict(
750
+ (inv_key_mapping_ln_gamma_beta(key), value) for key, value in state_dict.items()
751
+ )
752
+ state_dict = OrderedDict(
753
+ (inv_key_mapping_layers(key), value) for key, value in state_dict.items()
754
+ )
755
+ state_dict = OrderedDict((inv_key_mapping_mlp(key), value) for key, value in state_dict.items())
756
+ state_dict = OrderedDict(
757
+ (inv_key_mapping_attn(key), value) for key, value in state_dict.items()
758
+ )
759
+ state_dict = OrderedDict(
760
+ (inv_key_mapping_decoder_bias(key), value) for key, value in state_dict.items()
761
+ )
762
+
763
+ return state_dict