Update modeling_grok2.py
Browse files- modeling_grok2.py +3 -3
modeling_grok2.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
"""
|
| 2 |
modeling_grok2.py β Custom Grok 2 modeling code for transformers.
|
| 3 |
-
Allows
|
| 4 |
|
| 5 |
Exact tensor key names:
|
| 6 |
model.embed_tokens.weight [131072, 8192]
|
|
@@ -347,7 +347,7 @@ class Grok2Model(nn.Module):
|
|
| 347 |
|
| 348 |
|
| 349 |
# ββ CausalLM ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 350 |
-
class
|
| 351 |
config_class = Grok2Config
|
| 352 |
base_model_prefix = "model"
|
| 353 |
supports_gradient_checkpointing = False
|
|
@@ -407,4 +407,4 @@ class Grok2ForCausalLM(PreTrainedModel, GenerationMixin):
|
|
| 407 |
|
| 408 |
# ββ Register with AutoModel βββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 409 |
AutoConfig.register("grok2", Grok2Config)
|
| 410 |
-
AutoModelForCausalLM.register(Grok2Config,
|
|
|
|
| 1 |
"""
|
| 2 |
modeling_grok2.py β Custom Grok 2 modeling code for transformers.
|
| 3 |
+
Allows AutoModel to load Johnblick187/grok-2.
|
| 4 |
|
| 5 |
Exact tensor key names:
|
| 6 |
model.embed_tokens.weight [131072, 8192]
|
|
|
|
| 347 |
|
| 348 |
|
| 349 |
# ββ CausalLM ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 350 |
+
class Grok1ForCausalLM(PreTrainedModel, GenerationMixin):
|
| 351 |
config_class = Grok2Config
|
| 352 |
base_model_prefix = "model"
|
| 353 |
supports_gradient_checkpointing = False
|
|
|
|
| 407 |
|
| 408 |
# ββ Register with AutoModel βββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 409 |
AutoConfig.register("grok2", Grok2Config)
|
| 410 |
+
AutoModelForCausalLM.register(Grok2Config, Grok1ForCausalLM)
|