Update modeling_norbert.py
Browse files
- modeling_norbert.py +10 -3
modeling_norbert.py
CHANGED
|
@@ -414,10 +414,11 @@ class Classifier(nn.Module):
|
|
| 414 |
nn.Dropout(drop_out),
|
| 415 |
nn.Linear(config.hidden_size, num_labels)
|
| 416 |
)
|
| 417 |
-
self.
|
|
|
|
| 418 |
|
| 419 |
-
def initialize(self, hidden_size):
|
| 420 |
-
std = math.sqrt(2.0 / (5.0 * hidden_size))
|
| 421 |
nn.init.trunc_normal_(self.nonlinearity[1].weight, mean=0.0, std=std, a=-2*std, b=2*std)
|
| 422 |
nn.init.trunc_normal_(self.nonlinearity[-1].weight, mean=0.0, std=std, a=-2*std, b=2*std)
|
| 423 |
self.nonlinearity[1].bias.data.zero_()
|
|
@@ -438,6 +439,9 @@ class NorbertForSequenceClassification(NorbertModel):
|
|
| 438 |
self.num_labels = config.num_labels
|
| 439 |
self.head = Classifier(config, self.num_labels)
|
| 440 |
|
|
|
|
|
|
|
|
|
|
| 441 |
def forward(
|
| 442 |
self,
|
| 443 |
input_ids: Optional[torch.Tensor] = None,
|
|
@@ -504,6 +508,9 @@ class NorbertForTokenClassification(NorbertModel):
|
|
| 504 |
self.num_labels = config.num_labels
|
| 505 |
self.head = Classifier(config, self.num_labels)
|
| 506 |
|
|
|
|
|
|
|
|
|
|
| 507 |
def forward(
|
| 508 |
self,
|
| 509 |
input_ids: Optional[torch.Tensor] = None,
|
|
|
|
| 414 |
nn.Dropout(drop_out),
|
| 415 |
nn.Linear(config.hidden_size, num_labels)
|
| 416 |
)
|
| 417 |
+
self.hidden_size = config.hidden_size
|
| 418 |
+
self.initialize()
|
| 419 |
|
| 420 |
+
def initialize(self):
|
| 421 |
+
std = math.sqrt(2.0 / (5.0 * self.hidden_size))
|
| 422 |
nn.init.trunc_normal_(self.nonlinearity[1].weight, mean=0.0, std=std, a=-2*std, b=2*std)
|
| 423 |
nn.init.trunc_normal_(self.nonlinearity[-1].weight, mean=0.0, std=std, a=-2*std, b=2*std)
|
| 424 |
self.nonlinearity[1].bias.data.zero_()
|
|
|
|
| 439 |
self.num_labels = config.num_labels
|
| 440 |
self.head = Classifier(config, self.num_labels)
|
| 441 |
|
| 442 |
+
def post_init(self):
|
| 443 |
+
self.head.initialize()
|
| 444 |
+
|
| 445 |
def forward(
|
| 446 |
self,
|
| 447 |
input_ids: Optional[torch.Tensor] = None,
|
|
|
|
| 508 |
self.num_labels = config.num_labels
|
| 509 |
self.head = Classifier(config, self.num_labels)
|
| 510 |
|
| 511 |
+
def post_init(self):
|
| 512 |
+
self.head.initialize()
|
| 513 |
+
|
| 514 |
def forward(
|
| 515 |
self,
|
| 516 |
input_ids: Optional[torch.Tensor] = None,
|