Chenghao-Qiu committed on
Commit
918d868
·
verified ·
1 Parent(s): 993d0e0

Upload bert_layers.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. bert_layers.py +915 -0
bert_layers.py ADDED
@@ -0,0 +1,915 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 MosaicML Examples authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
5
+ # Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
6
+ # Copyright (c) 2022, Tri Dao.
7
+
8
+ import copy
9
+ import logging
10
+ import math
11
+ import warnings
12
+ from typing import List, Optional, Tuple, Union
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ from einops import rearrange
17
+ from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
18
+ from transformers.activations import ACT2FN
19
+ from transformers.modeling_outputs import (MaskedLMOutput,
20
+ SequenceClassifierOutput)
21
+ from transformers.models.bert.modeling_bert import BertPreTrainedModel
22
+ from transformers.modeling_utils import PreTrainedModel
23
+
24
+ from .bert_padding import (index_first_axis,
25
+ index_put_first_axis, pad_input,
26
+ unpad_input, unpad_input_only)
27
+ from .configuration_bert import BertConfig
28
+ try:
29
+ from .flash_attn_triton import flash_attn_qkvpacked_func
30
+ except ImportError as e:
31
+ flash_attn_qkvpacked_func = None
32
+
33
+ logger = logging.getLogger(__name__)
34
+
35
+
36
+ class BertEmbeddings(nn.Module):
37
+
38
+ def __init__(self, config):
39
+ super().__init__()
40
+ self.word_embeddings = nn.Embedding(config.vocab_size,
41
+ config.hidden_size,
42
+ padding_idx=config.pad_token_id)
43
+ # ALiBi doesn't use position embeddings
44
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
45
+ config.hidden_size)
46
+
47
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model
48
+ # variable name and be able to load any TensorFlow checkpoint file
49
+ self.LayerNorm = nn.LayerNorm(config.hidden_size,
50
+ eps=config.layer_norm_eps)
51
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
52
+ self.register_buffer('token_type_ids',
53
+ torch.zeros(config.max_position_embeddings,
54
+ dtype=torch.long),
55
+ persistent=False)
56
+
57
+ def forward(
58
+ self,
59
+ input_ids: Optional[torch.LongTensor] = None,
60
+ token_type_ids: Optional[torch.LongTensor] = None,
61
+ position_ids: Optional[torch.LongTensor] = None,
62
+ inputs_embeds: Optional[torch.FloatTensor] = None,
63
+ past_key_values_length: int = 0,
64
+ ) -> torch.Tensor:
65
+ if (input_ids is not None) == (inputs_embeds is not None):
66
+ raise ValueError('Must specify either input_ids or input_embeds!')
67
+ if input_ids is not None:
68
+ input_shape = input_ids.size()
69
+ else:
70
+ assert inputs_embeds is not None # just for type checking
71
+ input_shape = inputs_embeds.size()[:-1]
72
+
73
+ seq_length = input_shape[1]
74
+
75
+ if position_ids is None:
76
+ # great! ALiBi
77
+ pass
78
+
79
+ # Setting the token_type_ids to the registered buffer in constructor
80
+ # where it is all zeros, which usually occurs when it's auto-generated;
81
+ # registered buffer helps users when tracing the model without passing
82
+ # token_type_ids, solves issue #5664
83
+ if token_type_ids is None:
84
+ if hasattr(self, 'token_type_ids'):
85
+ assert isinstance(self.token_type_ids, torch.LongTensor)
86
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
87
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(
88
+ input_shape[0], seq_length)
89
+ token_type_ids = buffered_token_type_ids_expanded # type: ignore
90
+ else:
91
+ token_type_ids = torch.zeros(input_shape, # type: ignore
92
+ dtype=torch.long,
93
+ device=self.word_embeddings.device) # type: ignore # yapf: disable
94
+
95
+ if inputs_embeds is None:
96
+ inputs_embeds = self.word_embeddings(input_ids)
97
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
98
+
99
+ embeddings = inputs_embeds + token_type_embeddings
100
+ # no position embeddings! ALiBi
101
+ embeddings = self.LayerNorm(embeddings)
102
+ embeddings = self.dropout(embeddings)
103
+ return embeddings
104
+
105
+
106
class BertUnpadSelfAttention(nn.Module):
    """Self-attention over unpadded (packed) token sequences.

    Uses the Triton FlashAttention kernel when it imported successfully and
    attention dropout is zero; otherwise falls back to a plain PyTorch
    implementation.
    """

    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
                config, 'embedding_size'):
            raise ValueError(
                f'The hidden size ({config.hidden_size}) is not a multiple of the number of attention '
                f'heads ({config.num_attention_heads})')

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size /
                                       config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.p_dropout = config.attention_probs_dropout_prob
        # Fused projection: a single matmul produces q, k and v.
        self.Wqkv = nn.Linear(self.all_head_size, 3 * config.hidden_size)

        # Warn if defaulting to pytorch because of import issues
        if flash_attn_qkvpacked_func is None:
            warnings.warn(
                'Unable to import Triton; defaulting MosaicBERT attention implementation to pytorch (this will reduce throughput when using this model).'
            )

    def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
                max_seqlen_in_batch: int, indices: torch.Tensor,
                attn_mask: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
        """Perform self-attention.

        If dropout is zero, then we can use the Triton kernel, so we do that. However, if not, we send through a standard PyTorch
        implementation of self-attention.

        The arguments are unpadded, and our implementations of attention require padded arguments,
        so we first call `pad_input`. Once we compute attention, we re-unpad our outputs for the other layers.
        The pad/unpad operations add overhead, but not sending pad tokens through ffs saves compute.
        It is possible to write an unpadded implementation of attention (in Triton and PyTorch), which we will eventually do.

        Args:
            hidden_states: (total_nnz, dim)
            cu_seqlens: (batch + 1,)
            max_seqlen_in_batch: int
            indices: (total_nnz,)
            attn_mask: (batch, max_seqlen_in_batch)
            bias: (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)

        Returns:
            attention: (total_nnz, dim)
        """
        qkv = self.Wqkv(hidden_states)
        qkv = pad_input(qkv, indices, cu_seqlens.shape[0] - 1,
                        max_seqlen_in_batch)  # batch, max_seqlen_in_batch, thd
        qkv = rearrange(qkv,
                        'b s (t h d) -> b s t h d',
                        t=3,
                        h=self.num_attention_heads)
        if self.p_dropout or flash_attn_qkvpacked_func is None:
            # if we have nonzero attention dropout (e.g. during fine-tuning) or no Triton, compute attention in PyTorch
            q = qkv[:, :, 0, :, :].permute(0, 2, 1, 3)  # b h s d
            k = qkv[:, :, 1, :, :].permute(0, 2, 3, 1)  # b h d s
            v = qkv[:, :, 2, :, :].permute(0, 2, 1, 3)  # b h s d
            attention_scores = torch.matmul(q, k) / math.sqrt(
                self.attention_head_size)
            # `bias` carries both the ALiBi slopes and the -10000 padding mask.
            attention_scores = attention_scores + bias
            attention_probs = nn.functional.softmax(attention_scores, dim=-1)
            attention_probs = self.dropout(attention_probs)
            # Fix: removed stray debug `print('-------------------')` that ran
            # on every forward pass through this branch.
            attention = torch.matmul(attention_probs, v).permute(0, 2, 1,
                                                                 3)  # b s h d
        else:
            # Triton implementation only supports 0 attention dropout
            convert_dtype = qkv.dtype not in [torch.float16, torch.bfloat16]
            if convert_dtype:
                # Triton implementation only supports fp16 and bf16
                orig_dtype = qkv.dtype
                qkv = qkv.to(torch.float16)
                bias_dtype = bias.dtype
                bias = bias.to(torch.float16)
                attention = flash_attn_qkvpacked_func(qkv, bias)
                attention = attention.to(orig_dtype)
                bias = bias.to(bias_dtype)
            else:
                attention = flash_attn_qkvpacked_func(qkv, bias)

        # attn_mask is 1 for attend and 0 for don't
        attention = unpad_input_only(attention, torch.squeeze(attn_mask) == 1)
        return rearrange(attention, 'nnz h d -> nnz (h d)')
193
+
194
+
195
+ # Copy of transformer's library BertSelfOutput that will not be caught by surgery methods looking for HF BERT modules.
196
# Copy of transformer's library BertSelfOutput that will not be caught by surgery methods looking for HF BERT modules.
class BertSelfOutput(nn.Module):
    """Post-attention projection: dense -> dropout -> residual add -> LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor,
                input_tensor: torch.Tensor) -> torch.Tensor:
        """Project attention output and add the residual `input_tensor`."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
211
+
212
+
213
class BertUnpadAttention(nn.Module):
    """Chains attention, Dropout, and LayerNorm for Mosaic BERT."""

    def __init__(self, config):
        super().__init__()
        self.self = BertUnpadSelfAttention(config)
        self.output = BertSelfOutput(config)

    def forward(
        self,
        input_tensor: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_s: int,
        subset_idx: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass for scaled self-attention without padding.

        Arguments:
            input_tensor: (total_nnz, dim)
            cu_seqlens: (batch + 1,)
            max_s: int
            subset_idx: () set of indices whose values we care about at the end of the layer
                (e.g., the masked tokens, if this is the final layer).
            indices: None or (total_nnz,)
            attn_mask: None or (batch, max_seqlen_in_batch)
            bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
        """
        attn_out = self.self(input_tensor, cu_seqlens, max_s, indices,
                             attn_mask, bias)
        if subset_idx is None:
            return self.output(attn_out, input_tensor)
        # Keep only the rows the caller cares about (e.g. masked positions)
        # before the output projection + residual.
        return self.output(index_first_axis(attn_out, subset_idx),
                           index_first_axis(input_tensor, subset_idx))
250
+
251
+
252
class BertGatedLinearUnitMLP(nn.Module):
    """Applies the FFN at the end of each Mosaic BERT layer.

    Compared to the default BERT architecture, this block replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate`
    and :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has similar functionality, but
    introduces Gated Linear Units.

    Note: Mosaic BERT adds parameters in order to implement Gated Linear Units. To keep parameter count consistent with that of a
    standard Hugging Face BERT, scale down `config.intermediate_size` by 2/3. For example, a Mosaic BERT constructed with
    `config.intermediate_size=2048` will have the same parameter footprint as its Hugging Face BERT counterpart constructed
    with the `config.intermediate_size=3072`.
    However, in most cases it will not be necessary to adjust `config.intermediate_size` since, despite the increased
    parameter size, Mosaic BERT typically offers a net higher throughput than a Hugging Face BERT built from the same `config`.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # One fused projection produces both the gate and the value halves.
        self.gated_layers = nn.Linear(config.hidden_size,
                                      config.intermediate_size * 2,
                                      bias=False)
        self.act = nn.GELU(approximate='none')
        self.wo = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.layernorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Compute new hidden states from current hidden states.

        Args:
            hidden_states (torch.Tensor): The (unpadded) hidden states from
                the attention layer [nnz, dim].
        """
        residual = hidden_states
        fused = self.gated_layers(hidden_states)
        # First half gates (through GELU) the second half — the GLU trick.
        gate, value = fused.split(self.config.intermediate_size, dim=1)
        activated = self.dropout(self.act(gate) * value)
        # Project back down, then residual add + post-LN.
        return self.layernorm(self.wo(activated) + residual)
298
+
299
+
300
class BertLayer(nn.Module):
    """Composes the Mosaic BERT attention and FFN blocks into a single layer."""

    def __init__(self, config):
        super(BertLayer, self).__init__()
        self.attention = BertUnpadAttention(config)
        self.mlp = BertGatedLinearUnitMLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        seqlen: int,
        subset_idx: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass for a BERT layer, including both attention and MLP.

        Args:
            hidden_states: (total_nnz, dim)
            cu_seqlens: (batch + 1,)
            seqlen: int
            subset_idx: () set of indices whose values we care about at the end of the layer
                (e.g., the masked tokens, if this is the final layer).
            indices: None or (total_nnz,)
            attn_mask: None or (batch, max_seqlen_in_batch)
            bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
        """
        attended = self.attention(hidden_states, cu_seqlens, seqlen,
                                  subset_idx, indices, attn_mask, bias)
        return self.mlp(attended)
334
+
335
+
336
class BertEncoder(nn.Module):
    """A stack of BERT layers providing the backbone of Mosaic BERT.

    This module is modeled after the Hugging Face BERT's :class:`~transformers.model.bert.modeling_bert.BertEncoder`,
    but with substantial modifications to implement unpadding and ALiBi.

    Compared to the analogous Hugging Face BERT module, this module handles unpadding to reduce unnecessary computation
    at padded tokens, and pre-computes attention biases to implement ALiBi.
    """

    def __init__(self, config):
        super().__init__()
        # Each layer is an independent deep copy, so no weight sharing.
        layer = BertLayer(config)
        self.layer = nn.ModuleList(
            [copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])

        self.num_attention_heads = config.num_attention_heads

        # The alibi mask will be dynamically expanded if it is too small for
        # the input the model receives. But it generally helps to initialize it
        # to a reasonably large size to help pre-allocate CUDA memory.
        # The default `alibi_starting_size` is 512.
        self._current_alibi_size = int(config.alibi_starting_size)
        # NOTE: self.alibi is a plain tensor attribute (not a registered
        # buffer), so it is not part of the state dict and is moved to the
        # input's device lazily in forward().
        self.alibi = torch.zeros(
            (1, self.num_attention_heads, self._current_alibi_size,
             self._current_alibi_size))
        self.rebuild_alibi_tensor(size=config.alibi_starting_size)

    def rebuild_alibi_tensor(self,
                             size: int,
                             device: Optional[Union[torch.device, str]] = None):
        """(Re)build the (1, n_heads, size, size) ALiBi bias tensor in place.

        Sets `self.alibi` and `self._current_alibi_size`.
        """
        # Alibi
        # Following https://github.com/ofirpress/attention_with_linear_biases/issues/5 (Implementation 1)
        # In the causal case, you can exploit the fact that softmax is invariant to a uniform translation
        # of the logits, which makes the math work out *after* applying causal masking. If no causal masking
        # will be applied, it is necessary to construct the diagonal mask.
        n_heads = self.num_attention_heads

        def _get_alibi_head_slopes(n_heads: int) -> List[float]:
            # Per-head geometric slope sequence from the ALiBi paper.

            def get_slopes_power_of_2(n_heads: int) -> List[float]:
                start = (2**(-2**-(math.log2(n_heads) - 3)))
                ratio = start
                return [start * ratio**i for i in range(n_heads)]

            # In the paper, they only train models that have 2^a heads for some a. This function
            # has some good properties that only occur when the input is a power of 2. To
            # maintain that even when the number of heads is not a power of 2, we use a
            # workaround.
            if math.log2(n_heads).is_integer():
                return get_slopes_power_of_2(n_heads)

            closest_power_of_2 = 2**math.floor(math.log2(n_heads))
            slopes_a = get_slopes_power_of_2(closest_power_of_2)
            slopes_b = _get_alibi_head_slopes(2 * closest_power_of_2)
            slopes_b = slopes_b[0::2][:n_heads - closest_power_of_2]
            return slopes_a + slopes_b

        # Symmetric (non-causal) distance matrix: |j - i| for every pair.
        context_position = torch.arange(size, device=device)[:, None]
        memory_position = torch.arange(size, device=device)[None, :]
        relative_position = torch.abs(memory_position - context_position)
        # [n_heads, max_token_length, max_token_length]
        relative_position = relative_position.unsqueeze(0).expand(
            n_heads, -1, -1)
        slopes = torch.Tensor(_get_alibi_head_slopes(n_heads)).to(device)
        # Negative bias grows with distance, scaled per head.
        alibi = slopes.unsqueeze(1).unsqueeze(1) * -relative_position
        # [1, n_heads, max_token_length, max_token_length]
        alibi = alibi.unsqueeze(0)
        assert alibi.shape == torch.Size([1, n_heads, size, size])

        self._current_alibi_size = size
        self.alibi = alibi

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        output_all_encoded_layers: Optional[bool] = True,
        subset_mask: Optional[torch.Tensor] = None,
    ) -> List[torch.Tensor]:
        """Run the layer stack over unpadded tokens.

        Args:
            hidden_states: (batch, seqlen, dim) padded embeddings.
            attention_mask: (batch, seqlen), 1 = attend, 0 = padding.
            output_all_encoded_layers: if True, collect every layer's output.
            subset_mask: optional (batch, seqlen) bool mask; when given, only
                the selected positions are computed by the final layer and the
                result stays unpadded.

        Returns:
            List of hidden-state tensors (one per layer, or just the last).
        """
        # Additive mask: 0 where attending, -10000 at padding positions.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = extended_attention_mask.to(
            dtype=torch.float32)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        attention_mask_bool = attention_mask.bool()
        batch, seqlen = hidden_states.shape[:2]
        # Unpad inputs and mask. It will remove tokens that are padded.
        # Assume ntokens is total number of tokens (padded and non-padded)
        # and ntokens_unpad is total number of non-padded tokens.
        # Then unpadding performs the following compression of the inputs:
        # hidden_states[ntokens,hidden] -> hidden_states[ntokens_unpad,hidden]
        hidden_states, indices, cu_seqlens, _ = unpad_input(
            hidden_states, attention_mask_bool)

        # Add alibi matrix to extended_attention_mask
        if self._current_alibi_size < seqlen:
            # Rebuild the alibi tensor when needed
            warnings.warn(
                f'Increasing alibi size from {self._current_alibi_size} to {seqlen}'
            )
            self.rebuild_alibi_tensor(size=seqlen, device=hidden_states.device)
        elif self.alibi.device != hidden_states.device:
            # Device catch-up
            self.alibi = self.alibi.to(hidden_states.device)
        alibi_bias = self.alibi[:, :, :seqlen, :seqlen]
        attn_bias = extended_attention_mask[:, :, :seqlen, :seqlen]
        # Combined bias carries both padding masking and ALiBi slopes.
        alibi_attn_mask = attn_bias + alibi_bias

        all_encoder_layers = []
        if subset_mask is None:
            for layer_module in self.layer:
                hidden_states = layer_module(hidden_states,
                                             cu_seqlens,
                                             seqlen,
                                             None,
                                             indices,
                                             attn_mask=attention_mask,
                                             bias=alibi_attn_mask)
                if output_all_encoded_layers:
                    all_encoder_layers.append(hidden_states)
            # Pad inputs and mask. It will insert back zero-padded tokens.
            # Assume ntokens is total number of tokens (padded and non-padded)
            # and ntokens_unpad is total number of non-padded tokens.
            # Then padding performs the following de-compression:
            # hidden_states[ntokens_unpad,hidden] -> hidden_states[ntokens,hidden]
            hidden_states = pad_input(hidden_states, indices, batch, seqlen)
        else:
            # Run all but the last layer normally, then restrict the final
            # layer to the subset of positions selected by subset_mask
            # (e.g. masked tokens + the CLS column during MLM pretraining).
            for i in range(len(self.layer) - 1):
                layer_module = self.layer[i]
                hidden_states = layer_module(hidden_states,
                                             cu_seqlens,
                                             seqlen,
                                             None,
                                             indices,
                                             attn_mask=attention_mask,
                                             bias=alibi_attn_mask)
                if output_all_encoded_layers:
                    all_encoder_layers.append(hidden_states)
            # Indices of the selected positions within the unpadded sequence.
            subset_idx = torch.nonzero(subset_mask[attention_mask_bool],
                                       as_tuple=False).flatten()
            hidden_states = self.layer[-1](hidden_states,
                                           cu_seqlens,
                                           seqlen,
                                           subset_idx=subset_idx,
                                           indices=indices,
                                           attn_mask=attention_mask,
                                           bias=alibi_attn_mask)

        if not output_all_encoded_layers:
            all_encoder_layers.append(hidden_states)
        return all_encoder_layers
489
+
490
+
491
class BertPooler(nn.Module):
    """Dense + tanh head over the first-token ([CLS]) representation."""

    def __init__(self, config):
        super(BertPooler, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self,
                hidden_states: torch.Tensor,
                pool: Optional[bool] = True) -> torch.Tensor:
        """Pool by taking the first token's hidden state.

        With ``pool=False`` the input is assumed to already be the first-token
        rows (e.g. pre-selected during unpadded MLM) and is used as-is.
        """
        if pool:
            token_tensor = hidden_states[:, 0]
        else:
            token_tensor = hidden_states
        return self.activation(self.dense(token_tensor))
507
+
508
+
509
class BertPredictionHeadTransform(nn.Module):
    """Dense -> activation -> LayerNorm transform before the MLM decoder."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # config.hidden_act may be a name (looked up in ACT2FN) or a callable.
        self.transform_act_fn = (ACT2FN[config.hidden_act] if isinstance(
            config.hidden_act, str) else config.hidden_act)
        self.LayerNorm = torch.nn.LayerNorm(config.hidden_size, eps=1e-12)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply dense projection, activation, and LayerNorm."""
        activated = self.transform_act_fn(self.dense(hidden_states))
        return self.LayerNorm(activated)
525
+
526
+
527
class BertModel(BertPreTrainedModel):
    """Overall BERT model.

    Args:
        config: a BertConfig class instance with the configuration to build a new model

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
            `extract_features.py`, `run_classifier.py` and `run_squad.py`)
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
            types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
            a `sentence B` token (see BERT paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
            input sequence length in the current batch. It's the mask that we typically use for attention when
            a batch has varying length sentences.
        `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.

    Outputs: Tuple of (encoded_layers, pooled_output)
        `encoded_layers`: controlled by `output_all_encoded_layers` argument:
            - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
                of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
                encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
            - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
                to the last attention block of shape [batch_size, sequence_length, hidden_size],
        `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
            classifier pretrained on top of the hidden state associated to the first character of the
            input (`CLS`) to train on the Next-Sentence task (see BERT's paper).

    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
    config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
    model = BertModel(config=config)
    all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
    ```
    """
    config_class = BertConfig

    def __init__(self, config, add_pooling_layer=True):
        super(BertModel, self).__init__(config)
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        # Pooler is optional so MLM-only use can skip its parameters.
        self.pooler = BertPooler(config) if add_pooling_layer else None
        self.post_init()

    def get_input_embeddings(self):
        # HF convention: expose the word-embedding module for tying/resizing.
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def forward(
        self,
        input_ids: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_all_encoded_layers: Optional[bool] = False,
        masked_tokens_mask: Optional[torch.Tensor] = None,
        **kwargs
    ) -> Tuple[Union[List[torch.Tensor], torch.Tensor], Optional[torch.Tensor]]:
        # Default: attend everywhere, single segment.
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        embedding_output = self.embeddings(input_ids, token_type_ids,
                                           position_ids)

        subset_mask = []
        first_col_mask = []

        if masked_tokens_mask is None:
            subset_mask = None
        else:
            # Always include column 0 (the CLS token) in the subset so the
            # pooler still has something to pool from.
            first_col_mask = torch.zeros_like(masked_tokens_mask)
            first_col_mask[:, 0] = True
            subset_mask = masked_tokens_mask | first_col_mask

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask,
            output_all_encoded_layers=output_all_encoded_layers,
            subset_mask=subset_mask)

        if masked_tokens_mask is None:
            sequence_output = encoder_outputs[-1]
            pooled_output = self.pooler(
                sequence_output) if self.pooler is not None else None
        else:
            # TD [2022-03-01]: the indexing here is very tricky.
            # The final encoder layer returned only the subset rows (masked
            # tokens + CLS, in unpadded order); re-derive which of those rows
            # are masked tokens vs. CLS rows.
            attention_mask_bool = attention_mask.bool()
            subset_idx = subset_mask[attention_mask_bool]  # type: ignore
            sequence_output = encoder_outputs[-1][
                masked_tokens_mask[attention_mask_bool][subset_idx]]
            if self.pooler is not None:
                # CLS rows were already selected, so pool=False (no [:, 0]).
                pool_input = encoder_outputs[-1][
                    first_col_mask[attention_mask_bool][subset_idx]]
                pooled_output = self.pooler(pool_input, pool=False)
            else:
                pooled_output = None

        if not output_all_encoded_layers:
            encoder_outputs = sequence_output

        if self.pooler is not None:
            return encoder_outputs, pooled_output

        return encoder_outputs, None
642
+
643
+
644
+ ###################
645
+ # Bert Heads
646
+ ###################
647
class BertLMPredictionHead(nn.Module):
    """MLM head: transform then decode to vocabulary logits.

    The decoder weight is tied to the input word-embedding matrix; only the
    decoder bias is an independent parameter.
    """

    def __init__(self, config, bert_model_embedding_weights):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)
        vocab_size = bert_model_embedding_weights.size(0)
        hidden_dim = bert_model_embedding_weights.size(1)
        self.decoder = nn.Linear(hidden_dim, vocab_size)
        # Tie decoder weights to the input embeddings.
        self.decoder.weight = bert_model_embedding_weights

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return vocabulary logits for each input row."""
        return self.decoder(self.transform(hidden_states))
662
+
663
+
664
class BertOnlyMLMHead(nn.Module):
    """Thin wrapper exposing only the MLM prediction head."""

    def __init__(self, config, bert_model_embedding_weights):
        super().__init__()
        self.predictions = BertLMPredictionHead(config,
                                                bert_model_embedding_weights)

    def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
        """Return MLM prediction scores for `sequence_output`."""
        return self.predictions(sequence_output)
674
+
675
+
676
class BertOnlyNSPHead(nn.Module):
    """Next-sentence-prediction head: pooled output -> 2-class logits."""

    def __init__(self, config):
        super().__init__()
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, pooled_output: torch.Tensor) -> torch.Tensor:
        """Return the two NSP logits per example."""
        return self.seq_relationship(pooled_output)
685
+
686
+
687
+
688
class BertForMaskedLM(BertPreTrainedModel):
    """BERT with a masked-language-modeling head tied to the input embeddings.

    The MLM head (`self.cls`) decodes hidden states back to vocabulary logits
    using the word-embedding matrix as the decoder weights. Loss is computed
    only at masked positions, working on the flattened (unpadded)
    representation produced by `self.bert`.
    """

    def __init__(self, config):
        super().__init__(config)

        if config.is_decoder:
            warnings.warn(
                'If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for '
                'bi-directional self-attention.')

        # No pooling layer: MLM scores individual token positions, not a
        # pooled sentence summary.
        self.bert = BertModel(config, add_pooling_layer=False)
        # Tie the MLM decoder to the input word-embedding weights.
        self.cls = BertOnlyMLMHead(config,
                                   self.bert.embeddings.word_embeddings.weight)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        # HF convention: expose the tied decoder so `resize_token_embeddings`
        # and weight-tying utilities can find it.
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
        # labels should be a `torch.LongTensor` of shape
        # `(batch_size, sequence_length)`. These are used for computing the
        # masked language modeling loss.
        #
        # Indices should be in `[-100, 0, ..., config.vocab_size]` (see
        # `input_ids` docstring) Tokens with indices set to `-100` are ignored
        # (masked), the loss is only computed for the tokens with labels in `[0,
        # ..., config.vocab_size]`
        #
        # Prediction scores are only computed for masked tokens and the (bs,
        # seqlen) dimensions are flattened
        if (input_ids is not None) == (inputs_embeds is not None):
            raise ValueError('Must specify either input_ids or input_embeds!')

        if labels is None:
            masked_tokens_mask = None
        else:
            # NOTE(review): `labels > 0` also excludes tokens whose target id
            # is 0, which is stricter than the `-100` ignore-index convention
            # documented above — presumably the data pipeline never produces a
            # label of token id 0; confirm against the masking collator.
            masked_tokens_mask = labels > 0

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # `masked_tokens_mask` lets the encoder restrict its final projection
        # to masked positions only (prediction scores come back flattened).
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            masked_tokens_mask=masked_tokens_mask,
        )

        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        loss = None
        if labels is not None:
            # Compute loss on the flattened masked positions only; the same
            # `labels > 0` criterion as above selects the target indices.
            loss_fct = nn.CrossEntropyLoss()
            masked_token_idx = torch.nonzero(labels.flatten() > 0,
                                             as_tuple=False).flatten()
            loss = loss_fct(prediction_scores,
                            labels.flatten()[masked_token_idx])

            assert input_ids is not None, 'Coding error; please open an issue'
            # Scatter the per-masked-token scores back into a dense
            # (batch, seqlen, vocab) tensor; unmasked positions are zeros.
            batch, seqlen = input_ids.shape[:2]
            prediction_scores = rearrange(index_put_first_axis(
                prediction_scores, masked_token_idx, batch * seqlen),
                                          '(b s) d -> b s d',
                                          b=batch)

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        # NOTE(review): `hidden_states` is set to the sequence output
        # (`outputs[0]`), not the tuple of per-layer hidden states that HF's
        # `MaskedLMOutput` normally carries — confirm downstream consumers
        # expect this.
        return MaskedLMOutput(
            loss=loss,
            logits=prediction_scores,
            hidden_states=outputs[0],
            attentions=None,
        )

    def prepare_inputs_for_generation(self, input_ids: torch.Tensor,
                                      attention_mask: torch.Tensor,
                                      **model_kwargs):
        """Append a dummy PAD token so MLM-style generation has a position to fill."""
        input_shape = input_ids.shape
        effective_batch_size = input_shape[0]

        # add a dummy token
        if self.config.pad_token_id is None:
            raise ValueError('The PAD token should be defined for generation')

        # The dummy position is masked out of attention (new zeros column).
        attention_mask = torch.cat([
            attention_mask,
            attention_mask.new_zeros((attention_mask.shape[0], 1))
        ],
                                   dim=-1)
        dummy_token = torch.full((effective_batch_size, 1),
                                 self.config.pad_token_id,
                                 dtype=torch.long,
                                 device=input_ids.device)
        input_ids = torch.cat([input_ids, dummy_token], dim=1)

        return {'input_ids': input_ids, 'attention_mask': attention_mask}
814
+
815
+
816
+
817
class BertForSequenceClassification(BertPreTrainedModel):
    """Bert Model transformer with a sequence classification/regression head.

    This head is just a linear layer on top of the pooled output. Used for,
    e.g., GLUE tasks.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.bert = BertModel(config)
        # Prefer a classifier-specific dropout probability; otherwise fall
        # back to the generic hidden dropout.
        if config.classifier_dropout is not None:
            dropout_prob = config.classifier_dropout
        else:
            dropout_prob = config.hidden_dropout_prob
        self.dropout = nn.Dropout(dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        """Run BERT, classify the pooled output, and optionally compute a loss.

        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss.
            Indices should be in `[0, ..., config.num_labels - 1]`. With
            `config.num_labels == 1` a regression (MSE) loss is computed;
            otherwise a classification (cross-entropy / BCE) loss.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # outputs[1] is the pooled ([CLS]-derived) representation.
        pooled = self.dropout(outputs[1])
        logits = self.classifier(pooled)

        loss = None
        if labels is not None:
            # Infer the problem type once and cache it on the config, mirroring
            # the upstream Hugging Face convention.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = 'regression'
                elif self.num_labels > 1 and labels.dtype in (torch.long,
                                                              torch.int):
                    self.config.problem_type = 'single_label_classification'
                else:
                    self.config.problem_type = 'multi_label_classification'

            problem_type = self.config.problem_type
            if problem_type == 'regression':
                mse = nn.MSELoss()
                if self.num_labels == 1:
                    loss = mse(logits.squeeze(), labels.squeeze())
                else:
                    loss = mse(logits, labels)
            elif problem_type == 'single_label_classification':
                loss = nn.CrossEntropyLoss()(logits.view(-1, self.num_labels),
                                             labels.view(-1))
            elif problem_type == 'multi_label_classification':
                loss = nn.BCEWithLogitsLoss()(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs[0],
            attentions=None,
        )
915
+