Update modeling_plm.py
Browse files- modeling_plm.py +1 -1
modeling_plm.py
CHANGED
|
@@ -92,7 +92,7 @@ class PLMRMSNorm(nn.Module):
|
|
| 92 |
super().__init__()
|
| 93 |
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 94 |
self.variance_epsilon = eps
|
| 95 |
-
|
| 96 |
def forward(self, hidden_states):
|
| 97 |
input_dtype = hidden_states.dtype
|
| 98 |
hidden_states = hidden_states.to(torch.float32)
|
|
|
|
| 92 |
super().__init__()
|
| 93 |
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 94 |
self.variance_epsilon = eps
|
| 95 |
+
# NOTE: This RMSNorm is modified to align with TransformerEngine's norm (TENorm); see https://github.com/NVIDIA/TransformerEngine/issues/1132
|
| 96 |
def forward(self, hidden_states):
|
| 97 |
input_dtype = hidden_states.dtype
|
| 98 |
hidden_states = hidden_states.to(torch.float32)
|