davda54 committed on
Commit
f84759d
·
verified ·
1 Parent(s): e70d142

Update modeling_norbert.py

Browse files
Files changed (1) hide show
  1. modeling_norbert.py +31 -32
modeling_norbert.py CHANGED
@@ -22,11 +22,6 @@ class Encoder(nn.Module):
22
  def __init__(self, config, activation_checkpointing=False):
23
  super().__init__()
24
  self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.num_hidden_layers)])
25
-
26
- for i, layer in enumerate(self.layers):
27
- layer.mlp.mlp[1].weight.data *= math.sqrt(1.0 / (2.0 * (1 + i)))
28
- layer.mlp.mlp[-2].weight.data *= math.sqrt(1.0 / (2.0 * (1 + i)))
29
-
30
  self.activation_checkpointing = activation_checkpointing
31
 
32
  def forward(self, hidden_states, attention_mask, relative_embedding):
@@ -119,11 +114,7 @@ class Attention(nn.Module):
119
  self.pre_layer_norm = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=False)
120
  self.post_layer_norm = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=True)
121
 
122
- position_indices = torch.arange(config.max_position_embeddings, dtype=torch.long).unsqueeze(1) \
123
- - torch.arange(config.max_position_embeddings, dtype=torch.long).unsqueeze(0)
124
- position_indices = self.make_log_bucket_position(position_indices, config.position_bucket_size, config.max_position_embeddings)
125
- position_indices = config.position_bucket_size - 1 + position_indices
126
- self.register_buffer("position_indices", position_indices.contiguous(), persistent=False)
127
 
128
  self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
129
  self.scale = 1.0 / math.sqrt(3 * self.head_size)
@@ -140,13 +131,14 @@ class Attention(nn.Module):
140
  batch_size, key_len, _ = hidden_states.size()
141
  query_len = key_len
142
 
143
- # Recompute position_indices if sequence length exceeds the precomputed size
144
- if self.position_indices.size(0) < query_len:
145
- position_indices = torch.arange(query_len, dtype=torch.long).unsqueeze(1) \
146
  - torch.arange(query_len, dtype=torch.long).unsqueeze(0)
147
- position_indices = self.make_log_bucket_position(position_indices, self.config.position_bucket_size, 512)
148
- position_indices = self.config.position_bucket_size - 1 + position_indices
149
- self.position_indices = position_indices.to(hidden_states.device)
 
150
 
151
  # Pre-LN and project query/key/value.
152
  hidden_states = self.pre_layer_norm(hidden_states) # shape: [B, T, D]
@@ -222,6 +214,8 @@ class NorbertPreTrainedModel(PreTrainedModel):
222
  config_class = NorbertConfig
223
  base_model_prefix = "norbert3"
224
  supports_gradient_checkpointing = True
 
 
225
 
226
  def _set_gradient_checkpointing(self, module, value=False):
227
  if isinstance(module, Encoder):
@@ -230,15 +224,12 @@ class NorbertPreTrainedModel(PreTrainedModel):
230
  def _init_weights(self, module):
231
  std = math.sqrt(2.0 / (5.0 * self.hidden_size))
232
 
233
- if isinstance(module, nn.Linear):
234
- nn.init.trunc_normal_(module.weight.data, mean=0.0, std=std, a=-2*std, b=2*std)
235
- if module.bias is not None:
236
- module.bias.data.zero_()
237
- elif isinstance(module, nn.Embedding):
238
- nn.init.trunc_normal_(module.weight.data, mean=0.0, std=std, a=-2*std, b=2*std)
239
- elif isinstance(module, nn.LayerNorm):
240
- module.bias.data.zero_()
241
- module.weight.data.fill_(1.0)
242
 
243
 
244
  class NorbertModel(NorbertPreTrainedModel):
@@ -251,6 +242,8 @@ class NorbertModel(NorbertPreTrainedModel):
251
  self.transformer = Encoder(config, activation_checkpointing=gradient_checkpointing)
252
  self.classifier = MaskClassifier(config, self.embedding.word_embedding.weight) if add_mlm_layer else None
253
 
 
 
254
  def get_input_embeddings(self):
255
  return self.embedding.word_embedding
256
 
@@ -315,16 +308,18 @@ class NorbertModel(NorbertPreTrainedModel):
315
 
316
 
317
  class NorbertForMaskedLM(NorbertModel):
318
- _keys_to_ignore_on_load_unexpected = ["head"]
 
319
 
320
  def __init__(self, config, **kwargs):
321
  super().__init__(config, add_mlm_layer=True, **kwargs)
 
322
 
323
  def get_output_embeddings(self):
324
- return self.classifier.nonlinearity[-1].weight
325
 
326
  def set_output_embeddings(self, new_embeddings):
327
- self.classifier.nonlinearity[-1].weight = new_embeddings
328
 
329
  def forward(
330
  self,
@@ -386,13 +381,14 @@ class Classifier(nn.Module):
386
 
387
 
388
  class NorbertForSequenceClassification(NorbertModel):
389
- _keys_to_ignore_on_load_unexpected = ["classifier"]
390
 
391
  def __init__(self, config, **kwargs):
392
  super().__init__(config, add_mlm_layer=False, **kwargs)
393
 
394
  self.num_labels = config.num_labels
395
  self.head = Classifier(config, self.num_labels)
 
396
 
397
  def forward(
398
  self,
@@ -451,13 +447,14 @@ class NorbertForSequenceClassification(NorbertModel):
451
 
452
 
453
  class NorbertForTokenClassification(NorbertModel):
454
- _keys_to_ignore_on_load_unexpected = ["classifier"]
455
 
456
  def __init__(self, config, **kwargs):
457
  super().__init__(config, add_mlm_layer=False, **kwargs)
458
 
459
  self.num_labels = config.num_labels
460
  self.head = Classifier(config, self.num_labels)
 
461
 
462
  def forward(
463
  self,
@@ -498,13 +495,14 @@ class NorbertForTokenClassification(NorbertModel):
498
 
499
 
500
  class NorbertForQuestionAnswering(NorbertModel):
501
- _keys_to_ignore_on_load_unexpected = ["classifier"]
502
 
503
  def __init__(self, config, **kwargs):
504
  super().__init__(config, add_mlm_layer=False, **kwargs)
505
 
506
  self.num_labels = config.num_labels
507
  self.head = Classifier(config, self.num_labels)
 
508
 
509
  def forward(
510
  self,
@@ -565,13 +563,14 @@ class NorbertForQuestionAnswering(NorbertModel):
565
 
566
 
567
  class NorbertForMultipleChoice(NorbertModel):
568
- _keys_to_ignore_on_load_unexpected = ["classifier"]
569
 
570
  def __init__(self, config, **kwargs):
571
  super().__init__(config, add_mlm_layer=False, **kwargs)
572
 
573
  self.num_labels = getattr(config, "num_labels", 2)
574
  self.head = Classifier(config, self.num_labels)
 
575
 
576
  def forward(
577
  self,
 
22
  def __init__(self, config, activation_checkpointing=False):
23
  super().__init__()
24
  self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.num_hidden_layers)])
 
 
 
 
 
25
  self.activation_checkpointing = activation_checkpointing
26
 
27
  def forward(self, hidden_states, attention_mask, relative_embedding):
 
114
  self.pre_layer_norm = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=False)
115
  self.post_layer_norm = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=True)
116
 
117
+ self.position_indices = None
 
 
 
 
118
 
119
  self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
120
  self.scale = 1.0 / math.sqrt(3 * self.head_size)
 
131
  batch_size, key_len, _ = hidden_states.size()
132
  query_len = key_len
133
 
134
+ # Recompute position_indices at the beginning or if sequence length exceeds the precomputed size
135
+ if self.position_indices is None or self.position_indices.size(0) < query_len:
136
+ self.position_indices = torch.arange(query_len, dtype=torch.long).unsqueeze(1) \
137
  - torch.arange(query_len, dtype=torch.long).unsqueeze(0)
138
+ self.position_indices = self.make_log_bucket_position(self.position_indices, self.config.position_bucket_size, 512)
139
+ self.position_indices = self.config.position_bucket_size - 1 + self.position_indices
140
+ if self.position_indices.device != hidden_states.device:
141
+ self.position_indices = self.position_indices.to(hidden_states.device)
142
 
143
  # Pre-LN and project query/key/value.
144
  hidden_states = self.pre_layer_norm(hidden_states) # shape: [B, T, D]
 
214
  config_class = NorbertConfig
215
  base_model_prefix = "norbert3"
216
  supports_gradient_checkpointing = True
217
+ _tied_weights_keys = {}
218
+ _keys_to_ignore_on_load_unexpected = [r".*position_indices.*"]
219
 
220
  def _set_gradient_checkpointing(self, module, value=False):
221
  if isinstance(module, Encoder):
 
224
  def _init_weights(self, module):
225
  std = math.sqrt(2.0 / (5.0 * self.hidden_size))
226
 
227
+ if isinstance(module, nn.Linear) or isinstance(module, nn.Embedding):
228
+ nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-2*std, b=2*std)
229
+ elif isinstance(module, nn.LayerNorm) and module.weight is not None:
230
+ nn.init.ones_(module.weight)
231
+ if hasattr(module, "bias") and module.bias is not None:
232
+ nn.init.zeros_(module.bias)
 
 
 
233
 
234
 
235
  class NorbertModel(NorbertPreTrainedModel):
 
242
  self.transformer = Encoder(config, activation_checkpointing=gradient_checkpointing)
243
  self.classifier = MaskClassifier(config, self.embedding.word_embedding.weight) if add_mlm_layer else None
244
 
245
+ self.post_init()
246
+
247
  def get_input_embeddings(self):
248
  return self.embedding.word_embedding
249
 
 
308
 
309
 
310
  class NorbertForMaskedLM(NorbertModel):
311
+ _keys_to_ignore_on_load_unexpected = ["head", r".*position_indices.*"]
312
+ _tied_weights_keys = {"classifier.nonlinearity.5.weight": "embedding.word_embedding.weight"}
313
 
314
  def __init__(self, config, **kwargs):
315
  super().__init__(config, add_mlm_layer=True, **kwargs)
316
+ self.post_init()
317
 
318
  def get_output_embeddings(self):
319
+ return self.classifier.nonlinearity[-1]
320
 
321
  def set_output_embeddings(self, new_embeddings):
322
+ self.classifier.nonlinearity[-1] = new_embeddings
323
 
324
  def forward(
325
  self,
 
381
 
382
 
383
  class NorbertForSequenceClassification(NorbertModel):
384
+ _keys_to_ignore_on_load_unexpected = ["classifier", r".*position_indices.*"]
385
 
386
  def __init__(self, config, **kwargs):
387
  super().__init__(config, add_mlm_layer=False, **kwargs)
388
 
389
  self.num_labels = config.num_labels
390
  self.head = Classifier(config, self.num_labels)
391
+ self.post_init()
392
 
393
  def forward(
394
  self,
 
447
 
448
 
449
  class NorbertForTokenClassification(NorbertModel):
450
+ _keys_to_ignore_on_load_unexpected = ["classifier", r".*position_indices.*"]
451
 
452
  def __init__(self, config, **kwargs):
453
  super().__init__(config, add_mlm_layer=False, **kwargs)
454
 
455
  self.num_labels = config.num_labels
456
  self.head = Classifier(config, self.num_labels)
457
+ self.post_init()
458
 
459
  def forward(
460
  self,
 
495
 
496
 
497
  class NorbertForQuestionAnswering(NorbertModel):
498
+ _keys_to_ignore_on_load_unexpected = ["classifier", r".*position_indices.*"]
499
 
500
  def __init__(self, config, **kwargs):
501
  super().__init__(config, add_mlm_layer=False, **kwargs)
502
 
503
  self.num_labels = config.num_labels
504
  self.head = Classifier(config, self.num_labels)
505
+ self.post_init()
506
 
507
  def forward(
508
  self,
 
563
 
564
 
565
  class NorbertForMultipleChoice(NorbertModel):
566
+ _keys_to_ignore_on_load_unexpected = ["classifier", r".*position_indices.*"]
567
 
568
  def __init__(self, config, **kwargs):
569
  super().__init__(config, add_mlm_layer=False, **kwargs)
570
 
571
  self.num_labels = getattr(config, "num_labels", 2)
572
  self.head = Classifier(config, self.num_labels)
573
+ self.post_init()
574
 
575
  def forward(
576
  self,