Update modeling_grok2.py
Browse files- modeling_grok2.py +3 -1
modeling_grok2.py
CHANGED
|
@@ -178,6 +178,8 @@ class Grok2Expert(nn.Module):
|
|
| 178 |
self.w3 = nn.Linear(hidden_size, moe_intermediate_size, bias=False)
|
| 179 |
|
| 180 |
def forward(self, x):
|
|
|
|
|
|
|
| 181 |
return self.w2(F.silu(self.w1(x)) * self.w3(x))
|
| 182 |
|
| 183 |
|
|
@@ -345,4 +347,4 @@ class Grok1ForCausalLM(PreTrainedModel, GenerationMixin):
|
|
| 345 |
|
| 346 |
# ── Register ──────────────────────────────────────────────────────────────────
|
| 347 |
AutoConfig.register("grok2", Grok2Config)
|
| 348 |
-
AutoModelForCausalLM.register(Grok2Config, Grok1ForCausalLM)
|
|
|
|
| 178 |
self.w3 = nn.Linear(hidden_size, moe_intermediate_size, bias=False)
|
| 179 |
|
| 180 |
def forward(self, x):
|
| 181 |
+
device = self.w1.weight.device
|
| 182 |
+
x = x.to(device)
|
| 183 |
return self.w2(F.silu(self.w1(x)) * self.w3(x))
|
| 184 |
|
| 185 |
|
|
|
|
| 347 |
|
| 348 |
# ── Register ──────────────────────────────────────────────────────────────────
|
| 349 |
AutoConfig.register("grok2", Grok2Config)
|
| 350 |
+
AutoModelForCausalLM.register(Grok2Config, Grok1ForCausalLM)
|