Update modeling_grok2.py
Browse files- modeling_grok2.py +2 -1
modeling_grok2.py
CHANGED
|
@@ -180,7 +180,8 @@ class Grok2Expert(nn.Module):
|
|
| 180 |
def forward(self, x):
|
| 181 |
device = self.w1.weight.device
|
| 182 |
x = x.to(device)
|
| 183 |
-
|
|
|
|
| 184 |
|
| 185 |
|
| 186 |
# ββ Sparse MoE ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
| 180 |
def forward(self, x):
|
| 181 |
device = self.w1.weight.device
|
| 182 |
x = x.to(device)
|
| 183 |
+
h = F.silu(self.w1(x)) * self.w3(x.to(self.w3.weight.device))
|
| 184 |
+
return self.w2(h.to(self.w2.weight.device))
|
| 185 |
|
| 186 |
|
| 187 |
# ββ Sparse MoE ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|