lgcharpe committed
Commit 9e04fd3 · verified · 1 Parent(s): 45e963e

Update modeling_gptbert.py

Files changed (1): modeling_gptbert.py (+11 -8)
modeling_gptbert.py CHANGED

@@ -688,7 +688,17 @@ class GptBertPreTrainedModel(PreTrainedModel):
     _supports_flex_attn = False

     def _init_weights(self, module):
-        pass
+        std = math.sqrt(2.0 / (5.0 * self.hidden_size))
+
+        if isinstance(module, nn.Linear):
+            nn.init.trunc_normal_(module.weight.data, mean=0.0, std=std, a=-2*std, b=2*std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            nn.init.trunc_normal_(module.weight.data, mean=0.0, std=std, a=-2*std, b=2*std)
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)


 class GptBertModel(GptBertPreTrainedModel):
@@ -879,17 +889,11 @@ class Classifier(nn.Module):
         self.emb2vocab.bias.zero_()

     def forward(self, x: torch.Tensor):
-        print(x)
         x = self.pre_norm(x)
-        print(x)
         x = self.dropout(x)
-        print(x)
         x = self.projection(x)
-        print(x)
         x = gelu_new(x)
-        print(x)
         x = self.post_norm(x)
-        print(x)
         return self.emb2vocab(x)


@@ -1043,7 +1047,6 @@ class GptBertForSequenceClassification(GptBertModel):

         sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask)
         logits = self.head(sequence_output[:, 0, :])
-        print(logits)

         loss = None
         if labels is not None:
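For reference, the initialization this commit introduces draws every Linear and Embedding weight from a truncated normal with std = sqrt(2 / (5 · hidden_size)), clipped at ±2 std, and resets LayerNorm to identity (weight 1, bias 0). Below is a minimal standalone sketch of the same scheme, assuming a toy hidden size of 512; `init_weights` and `block` are illustrative names, not part of the repository.

import math

import torch
import torch.nn as nn

HIDDEN_SIZE = 512  # illustrative; the real value comes from the model config

# Same scheme as the commit: truncated normal with
# std = sqrt(2 / (5 * hidden_size)), clipped to [-2*std, 2*std].
std = math.sqrt(2.0 / (5.0 * HIDDEN_SIZE))


def init_weights(module: nn.Module) -> None:
    """Mirror of the `_init_weights` logic, usable via `model.apply(...)`."""
    if isinstance(module, nn.Linear):
        nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-2 * std, b=2 * std)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-2 * std, b=2 * std)
    elif isinstance(module, nn.LayerNorm):
        nn.init.zeros_(module.bias)
        nn.init.ones_(module.weight)


# Hypothetical toy module, only to show the recursive traversal via apply().
block = nn.Sequential(
    nn.Embedding(100, HIDDEN_SIZE),
    nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE),
    nn.LayerNorm(HIDDEN_SIZE),
)
block.apply(init_weights)

# Sanity check: the empirical std should land close to the target.
print(f"target std: {std:.4f}")
print(f"linear std: {block[1].weight.std().item():.4f}")

The value sqrt(2/(5d)) matches the "SmallInit" scale proposed by Nguyen & Salazar (2019, "Transformers without Tears"), which shrinks the usual Xavier scale to keep activations small early in training.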
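The rest of the commit strips leftover debugging `print` calls from `Classifier.forward` and from the sequence-classification path. If the intermediate activations are still needed occasionally, PyTorch forward hooks give the same visibility without editing `forward`. The sketch below uses a hypothetical `nn.Sequential` stand-in for the classifier head, with illustrative sizes; none of these names come from the repository.

import torch
import torch.nn as nn

# Hypothetical stand-in for the Classifier head; sizes are illustrative.
head = nn.Sequential(
    nn.LayerNorm(512),
    nn.Dropout(0.1),
    nn.Linear(512, 512),
    nn.GELU(),
    nn.LayerNorm(512),
    nn.Linear(512, 32000),
)

# Capture intermediate activations without touching forward().
activations = {}

def make_hook(name):
    def hook(module, inputs, output):
        activations[name] = output.detach()
    return hook

handles = [layer.register_forward_hook(make_hook(f"stage_{i}"))
           for i, layer in enumerate(head)]

_ = head(torch.randn(2, 512))
print({k: tuple(v.shape) for k, v in activations.items()})

# Hooks are removable, so no inspection code ships inside the model definition.
for h in handles:
    h.remove()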