Text Generation
PyTorch
English
Chinese
plm
conversational
custom_code
jjw0126 commited on
Commit
b08ef4d
·
verified ·
1 Parent(s): 06c8d07

Update modeling_plm.py

Browse files
Files changed (1) hide show
  1. modeling_plm.py +1 -1
modeling_plm.py CHANGED
@@ -92,7 +92,7 @@ class PLMRMSNorm(nn.Module):
92
  super().__init__()
93
  self.weight = nn.Parameter(torch.ones(hidden_size))
94
  self.variance_epsilon = eps
95
-
96
  def forward(self, hidden_states):
97
  input_dtype = hidden_states.dtype
98
  hidden_states = hidden_states.to(torch.float32)
 
92
  super().__init__()
93
  self.weight = nn.Parameter(torch.ones(hidden_size))
94
  self.variance_epsilon = eps
95
+ # We modify RMSNorm to align with TENorm, https://github.com/NVIDIA/TransformerEngine/issues/1132
96
  def forward(self, hidden_states):
97
  input_dtype = hidden_states.dtype
98
  hidden_states = hidden_states.to(torch.float32)