add rms eps to init_weights
Browse files- modeling_glm2.py +5 -1
modeling_glm2.py
CHANGED
|
@@ -373,6 +373,10 @@ class gLM2PreTrainedModel(PreTrainedModel):
|
|
| 373 |
# Force the buffer to update
|
| 374 |
with torch.no_grad():
|
| 375 |
module.inv_freq.copy_(inv_freq)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 376 |
|
| 377 |
|
| 378 |
class gLM2Model(gLM2PreTrainedModel):
|
|
@@ -476,4 +480,4 @@ class gLM2LMHead(nn.Module):
|
|
| 476 |
config.dim, config.vocab_size, bias=False)
|
| 477 |
|
| 478 |
def forward(self, features):
|
| 479 |
-
return self.proj_output(self.norm(features))
|
|
|
|
| 373 |
# Force the buffer to update
|
| 374 |
with torch.no_grad():
|
| 375 |
module.inv_freq.copy_(inv_freq)
|
| 376 |
+
elif isinstance(module, RMSNorm):
|
| 377 |
+
if hasattr(module, "variance_epsilon"):
|
| 378 |
+
with torch.no_grad():
|
| 379 |
+
module.variance_epsilon.fill_(self.config.norm_eps)
|
| 380 |
|
| 381 |
|
| 382 |
class gLM2Model(gLM2PreTrainedModel):
|
|
|
|
| 480 |
config.dim, config.vocab_size, bias=False)
|
| 481 |
|
| 482 |
def forward(self, features):
|
| 483 |
+
return self.proj_output(self.norm(features))
|