Johnblick187 committed on
Commit
22bdd44
·
verified ·
1 Parent(s): e068954

Update modeling_grok2.py

Browse files
Files changed (1) hide show
  1. modeling_grok2.py +3 -1
modeling_grok2.py CHANGED
@@ -178,6 +178,8 @@ class Grok2Expert(nn.Module):
178
  self.w3 = nn.Linear(hidden_size, moe_intermediate_size, bias=False)
179
 
180
  def forward(self, x):
 
 
181
  return self.w2(F.silu(self.w1(x)) * self.w3(x))
182
 
183
 
@@ -345,4 +347,4 @@ class Grok1ForCausalLM(PreTrainedModel, GenerationMixin):
345
 
346
  # ── Register ──────────────────────────────────────────────────────────────────
347
  AutoConfig.register("grok2", Grok2Config)
348
- AutoModelForCausalLM.register(Grok2Config, Grok1ForCausalLM)
 
178
  self.w3 = nn.Linear(hidden_size, moe_intermediate_size, bias=False)
179
 
180
  def forward(self, x):
181
+ device = self.w1.weight.device
182
+ x = x.to(device)
183
  return self.w2(F.silu(self.w1(x)) * self.w3(x))
184
 
185
 
 
347
 
348
  # ── Register ──────────────────────────────────────────────────────────────────
349
  AutoConfig.register("grok2", Grok2Config)
350
+ AutoModelForCausalLM.register(Grok2Config, Grok1ForCausalLM)