davda54 committed on
Commit
9b2ac85
·
verified ·
1 Parent(s): fe327a7

Update modeling_norbert.py

Browse files
Files changed (1) hide show
  1. modeling_norbert.py +10 -3
modeling_norbert.py CHANGED
@@ -414,10 +414,11 @@ class Classifier(nn.Module):
414
  nn.Dropout(drop_out),
415
  nn.Linear(config.hidden_size, num_labels)
416
  )
417
- self.initialize(config.hidden_size)
 
418
 
419
- def initialize(self, hidden_size):
420
- std = math.sqrt(2.0 / (5.0 * hidden_size))
421
  nn.init.trunc_normal_(self.nonlinearity[1].weight, mean=0.0, std=std, a=-2*std, b=2*std)
422
  nn.init.trunc_normal_(self.nonlinearity[-1].weight, mean=0.0, std=std, a=-2*std, b=2*std)
423
  self.nonlinearity[1].bias.data.zero_()
@@ -438,6 +439,9 @@ class NorbertForSequenceClassification(NorbertModel):
438
  self.num_labels = config.num_labels
439
  self.head = Classifier(config, self.num_labels)
440
 
 
 
 
441
  def forward(
442
  self,
443
  input_ids: Optional[torch.Tensor] = None,
@@ -504,6 +508,9 @@ class NorbertForTokenClassification(NorbertModel):
504
  self.num_labels = config.num_labels
505
  self.head = Classifier(config, self.num_labels)
506
 
 
 
 
507
  def forward(
508
  self,
509
  input_ids: Optional[torch.Tensor] = None,
 
414
  nn.Dropout(drop_out),
415
  nn.Linear(config.hidden_size, num_labels)
416
  )
417
+ self.hidden_size = config.hidden_size
418
+ self.initialize()
419
 
420
+ def initialize(self):
421
+ std = math.sqrt(2.0 / (5.0 * self.hidden_size))
422
  nn.init.trunc_normal_(self.nonlinearity[1].weight, mean=0.0, std=std, a=-2*std, b=2*std)
423
  nn.init.trunc_normal_(self.nonlinearity[-1].weight, mean=0.0, std=std, a=-2*std, b=2*std)
424
  self.nonlinearity[1].bias.data.zero_()
 
439
  self.num_labels = config.num_labels
440
  self.head = Classifier(config, self.num_labels)
441
 
442
+ def post_init(self):
443
+ self.head.initialize()
444
+
445
  def forward(
446
  self,
447
  input_ids: Optional[torch.Tensor] = None,
 
508
  self.num_labels = config.num_labels
509
  self.head = Classifier(config, self.num_labels)
510
 
511
+ def post_init(self):
512
+ self.head.initialize()
513
+
514
  def forward(
515
  self,
516
  input_ids: Optional[torch.Tensor] = None,