Update modeling_gptbert.py
modeling_gptbert.py  (+11 -8)
@@ -688,7 +688,17 @@ class GptBertPreTrainedModel(PreTrainedModel):
     _supports_flex_attn = False
 
     def _init_weights(self, module):
-
+        std = math.sqrt(2.0 / (5.0 * self.hidden_size))
+
+        if isinstance(module, nn.Linear):
+            nn.init.trunc_normal_(module.weight.data, mean=0.0, std=std, a=-2*std, b=2*std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            nn.init.trunc_normal_(module.weight.data, mean=0.0, std=std, a=-2*std, b=2*std)
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
 
 
 class GptBertModel(GptBertPreTrainedModel):
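Note: the new standard deviation, std = sqrt(2 / (5 * hidden_size)), matches the "small init" scheme of Nguyen & Salazar (Transformers without Tears), and the a=-2*std, b=2*std bounds make trunc_normal_ resample anything beyond two standard deviations. A minimal standalone sketch of the scheme (the hidden size of 768 is an assumed example value, not taken from this repo's config):

import math
import torch.nn as nn

hidden_size = 768  # assumed example value; the real model reads this from its config
std = math.sqrt(2.0 / (5.0 * hidden_size))  # "small init": Var(W) = 2 / (5 * d)

linear = nn.Linear(hidden_size, hidden_size)
nn.init.trunc_normal_(linear.weight, mean=0.0, std=std, a=-2 * std, b=2 * std)

# every weight now lies within two standard deviations of zero
assert linear.weight.abs().max().item() <= 2 * std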
@@ -879,17 +889,11 @@ class Classifier(nn.Module):
         self.emb2vocab.bias.zero_()
 
     def forward(self, x: torch.Tensor):
-        print(x)
         x = self.pre_norm(x)
-        print(x)
         x = self.dropout(x)
-        print(x)
         x = self.projection(x)
-        print(x)
         x = gelu_new(x)
-        print(x)
         x = self.post_norm(x)
-        print(x)
         return self.emb2vocab(x)
 
 
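Note: with the debug prints gone, forward is a standard LM-head pipeline (norm, dropout, projection, GELU, norm, vocabulary projection). A self-contained sketch of the same shape, with assumed dimensions; nn.GELU(approximate="tanh") stands in for gelu_new, which is the tanh-approximated GELU:

import torch
import torch.nn as nn

class ClassifierSketch(nn.Module):
    """Minimal sketch of the head; all dimensions are assumed example values."""
    def __init__(self, hidden_size=768, vocab_size=32768, dropout=0.1):
        super().__init__()
        self.pre_norm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout)
        self.projection = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.GELU(approximate="tanh")  # stand-in for gelu_new
        self.post_norm = nn.LayerNorm(hidden_size)
        self.emb2vocab = nn.Linear(hidden_size, vocab_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pre_norm(x)
        x = self.dropout(x)
        x = self.projection(x)
        x = self.activation(x)
        x = self.post_norm(x)
        return self.emb2vocab(x)

head = ClassifierSketch()
logits = head(torch.randn(2, 16, 768))  # (batch, seq, hidden) -> (batch, seq, vocab)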
@@ -1043,7 +1047,6 @@ class GptBertForSequenceClassification(GptBertModel):
 
         sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask)
         logits = self.head(sequence_output[:, 0, :])
-        print(logits)
 
         loss = None
         if labels is not None:
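Note: since this file ships as remote code, both fixes take effect the next time the model is loaded. A hedged usage sketch (the repo id is a placeholder, and loading via AutoModelForSequenceClassification assumes the repo's auto_map points at GptBertForSequenceClassification):

from transformers import AutoModelForSequenceClassification, AutoTokenizer

repo = "org/gpt-bert"  # placeholder repo id
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForSequenceClassification.from_pretrained(repo, trust_remote_code=True)

inputs = tokenizer("GPT-BERT, now without debug prints.", return_tensors="pt")
outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
# the forward pass now runs silently; no intermediate tensors are printed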