Johnblick187 commited on
Commit
e3c9bf2
Β·
verified Β·
1 Parent(s): 731441f

Update modeling_grok2.py

Browse files
Files changed (1) hide show
  1. modeling_grok2.py +3 -3
modeling_grok2.py CHANGED
@@ -1,6 +1,6 @@
1
  """
2
  modeling_grok2.py β€” Custom Grok 2 modeling code for transformers.
3
- Allows Heretic and AutoModel to load Johnblick187/grok-2.
4
 
5
  Exact tensor key names:
6
  model.embed_tokens.weight [131072, 8192]
@@ -347,7 +347,7 @@ class Grok2Model(nn.Module):
347
 
348
 
349
  # ── CausalLM ──────────────────────────────────────────────────────────────────
350
- class Grok2ForCausalLM(PreTrainedModel, GenerationMixin):
351
  config_class = Grok2Config
352
  base_model_prefix = "model"
353
  supports_gradient_checkpointing = False
@@ -407,4 +407,4 @@ class Grok2ForCausalLM(PreTrainedModel, GenerationMixin):
407
 
408
  # ── Register with AutoModel ───────────────────────────────────────────────────
409
  AutoConfig.register("grok2", Grok2Config)
410
- AutoModelForCausalLM.register(Grok2Config, Grok2ForCausalLM)
 
1
  """
2
  modeling_grok2.py β€” Custom Grok 2 modeling code for transformers.
3
+ Allows AutoModel to load Johnblick187/grok-2.
4
 
5
  Exact tensor key names:
6
  model.embed_tokens.weight [131072, 8192]
 
347
 
348
 
349
  # ── CausalLM ──────────────────────────────────────────────────────────────────
350
+ class Grok1ForCausalLM(PreTrainedModel, GenerationMixin):
351
  config_class = Grok2Config
352
  base_model_prefix = "model"
353
  supports_gradient_checkpointing = False
 
407
 
408
  # ── Register with AutoModel ───────────────────────────────────────────────────
409
  AutoConfig.register("grok2", Grok2Config)
410
+ AutoModelForCausalLM.register(Grok2Config, Grok1ForCausalLM)