TCMVince commited on
Commit
f5e20fc
·
verified ·
1 Parent(s): 1e904cb

Update mlm.py

Browse files
Files changed (1) hide show
  1. mlm.py +3 -2
mlm.py CHANGED
@@ -505,8 +505,9 @@ class BertEnergyModelForSequenceClassification(BertPreTrainedModel):
505
 
506
  #self.dropout = nn.Dropout(dropout)
507
  output_dim = config.hidden_size
508
- self.norm = nn.LayerNorm(output_dim, eps=config.layer_norm_eps)
509
- self.classifier = nn.Linear(output_dim, num_labels)
 
510
 
511
  self.post_init()
512
 
 
505
 
506
  #self.dropout = nn.Dropout(dropout)
507
  output_dim = config.hidden_size
508
+ embed_dim = config.embedding_dim
509
+ self.norm = nn.LayerNorm(config.embedding_dim, eps=config.layer_norm_eps)
510
+ self.classifier = nn.Linear(config.embedding_dim, num_labels)
511
 
512
  self.post_init()
513