Update modeling_grok2.py
Browse files- modeling_grok2.py +7 -2
modeling_grok2.py
CHANGED
|
@@ -180,8 +180,13 @@ class Grok2Expert(nn.Module):
|
|
| 180 |
def forward(self, x):
|
| 181 |
device = self.w1.weight.device
|
| 182 |
x = x.to(device)
|
| 183 |
-
|
| 184 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
|
| 186 |
|
| 187 |
# ββ Sparse MoE ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
| 180 |
def forward(self, x):
|
| 181 |
device = self.w1.weight.device
|
| 182 |
x = x.to(device)
|
| 183 |
+
d1 = self.w1.weight.device
|
| 184 |
+
d3 = self.w3.weight.device
|
| 185 |
+
d2 = self.w2.weight.device
|
| 186 |
+
gate = F.silu(self.w1(x.to(d1)))
|
| 187 |
+
up = self.w3(x.to(d3))
|
| 188 |
+
h = gate.to(d2) * up.to(d2)
|
| 189 |
+
return self.w2(h)
|
| 190 |
|
| 191 |
|
| 192 |
# ββ Sparse MoE ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|