Johnblick187 commited on
Commit
b0bca5a
Β·
verified Β·
1 Parent(s): 3401c5a

Update modeling_grok2.py

Browse files
Files changed (1) hide show
  1. modeling_grok2.py +2 -4
modeling_grok2.py CHANGED
@@ -104,10 +104,8 @@ class Grok2RMSNorm(nn.Module):
104
  self.eps = eps
105
 
106
  def forward(self, x):
107
- orig = x.dtype
108
- x = x.float()
109
- x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
110
- return (self.weight * x).to(orig)
111
 
112
 
113
  # ── RoPE ──────────────────────────────────────────────────────────────────────
 
104
  self.eps = eps
105
 
106
  def forward(self, x):
107
+ variance = x.pow(2).mean(-1, keepdim=True)
108
+ return self.weight * x * torch.rsqrt(variance + self.eps)
 
 
109
 
110
 
111
  # ── RoPE ──────────────────────────────────────────────────────────────────────