Johnblick187 committed on
Commit
6e76888
·
verified ·
1 Parent(s): 71f9cbe

Update modeling_grok2.py

Browse files
Files changed (1) hide show
  1. modeling_grok2.py +7 -2
modeling_grok2.py CHANGED
@@ -180,8 +180,13 @@ class Grok2Expert(nn.Module):
180
  def forward(self, x):
181
  device = self.w1.weight.device
182
  x = x.to(device)
183
- h = F.silu(self.w1(x)) * self.w3(x.to(self.w3.weight.device))
184
- return self.w2(h.to(self.w2.weight.device))
 
 
 
 
 
185
 
186
 
187
  # ── Sparse MoE ────────────────────────────────────────────────────────────────
 
180
  def forward(self, x):
181
  device = self.w1.weight.device
182
  x = x.to(device)
183
+ d1 = self.w1.weight.device
184
+ d3 = self.w3.weight.device
185
+ d2 = self.w2.weight.device
186
+ gate = F.silu(self.w1(x.to(d1)))
187
+ up = self.w3(x.to(d3))
188
+ h = gate.to(d2) * up.to(d2)
189
+ return self.w2(h)
190
 
191
 
192
  # ── Sparse MoE ────────────────────────────────────────────────────────────────