Update modeling_grok2.py
Browse files- modeling_grok2.py +2 -4
modeling_grok2.py
CHANGED
|
@@ -104,10 +104,8 @@ class Grok2RMSNorm(nn.Module):
|
|
| 104 |
self.eps = eps
|
| 105 |
|
| 106 |
def forward(self, x):
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
| 110 |
-
return (self.weight * x).to(orig)
|
| 111 |
|
| 112 |
|
| 113 |
# ββ RoPE ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
| 104 |
self.eps = eps
|
| 105 |
|
| 106 |
def forward(self, x):
|
| 107 |
+
variance = x.pow(2).mean(-1, keepdim=True)
|
| 108 |
+
return self.weight * x * torch.rsqrt(variance + self.eps)
|
|
|
|
|
|
|
| 109 |
|
| 110 |
|
| 111 |
# ββ RoPE ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|