davda54 commited on
Commit
8b27574
·
verified ·
1 Parent(s): 042ceea

Update modeling_gptbert.py

Browse files
Files changed (1) hide show
  1. modeling_gptbert.py +21 -3
modeling_gptbert.py CHANGED
@@ -138,6 +138,24 @@ class Embedding(nn.Module):
138
  return self.dropout(word_embedding)
139
 
140
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  class Classifier(nn.Module):
142
  def __init__(self, config: GptBertConfig, n_labels: int):
143
  super().__init__()
@@ -146,7 +164,7 @@ class Classifier(nn.Module):
146
  self.projection = CastedLinearIn(config.hidden_size, config.hidden_size, bias=False)
147
  self.post_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
148
  self.dropout = nn.Dropout(config.classifier_dropout)
149
- self.emb2vocab = CastedLinearIn(config.hidden_size, n_labels, bias=True)
150
 
151
  def forward(self, x: torch.Tensor):
152
  x = self.pre_norm(x.float()).type_as(x)
@@ -154,7 +172,7 @@ class Classifier(nn.Module):
154
  x = gelu_new(x)
155
  x = self.post_norm(x.float()).type_as(x)
156
  x = self.dropout(x)
157
- x = self.emb2vocab(x)
158
  return x
159
 
160
 
@@ -571,7 +589,7 @@ class GptBertModel(GptBertPreTrainedModel):
571
 
572
  self.embedding = Embedding(config)
573
  self.encoder = Encoder(config)
574
- self.classifier = Classifier(config, config.vocab_size) if add_mlm_layer else None
575
  self.set_window_length(config)
576
  self.gradient_checkpointing = False
577
  self.post_init()
 
138
  return self.dropout(word_embedding)
139
 
140
 
141
+ class LMClassifier(nn.Module):
142
+ def __init__(self, config: GptBertConfig, n_labels: int):
143
+ super().__init__()
144
+
145
+ self.pre_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
146
+ self.projection = CastedLinearIn(config.hidden_size, config.hidden_size, bias=False)
147
+ self.post_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
148
+ self.emb2vocab = CastedLinearIn(config.hidden_size, n_labels, bias=True)
149
+
150
+ def forward(self, x: torch.Tensor):
151
+ x = self.pre_norm(x.float()).type_as(x)
152
+ x = self.projection(x)
153
+ x = gelu_new(x)
154
+ x = self.post_norm(x.float()).type_as(x)
155
+ x = self.emb2vocab(x)
156
+ return x
157
+
158
+
159
  class Classifier(nn.Module):
160
  def __init__(self, config: GptBertConfig, n_labels: int):
161
  super().__init__()
 
164
  self.projection = CastedLinearIn(config.hidden_size, config.hidden_size, bias=False)
165
  self.post_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
166
  self.dropout = nn.Dropout(config.classifier_dropout)
167
+ self.output_projection = CastedLinearIn(config.hidden_size, n_labels, bias=True)
168
 
169
  def forward(self, x: torch.Tensor):
170
  x = self.pre_norm(x.float()).type_as(x)
 
172
  x = gelu_new(x)
173
  x = self.post_norm(x.float()).type_as(x)
174
  x = self.dropout(x)
175
+ x = self.output_projection(x)
176
  return x
177
 
178
 
 
589
 
590
  self.embedding = Embedding(config)
591
  self.encoder = Encoder(config)
592
+ self.classifier = LMClassifier(config, config.vocab_size) if add_mlm_layer else None
593
  self.set_window_length(config)
594
  self.gradient_checkpointing = False
595
  self.post_init()