Johnblick187 committed on
Commit
3401c5a
·
verified ·
1 Parent(s): 08c719a

Update modeling_grok2.py

Browse files
Files changed (1) hide show
  1. modeling_grok2.py +32 -71
modeling_grok2.py CHANGED
@@ -28,6 +28,7 @@ Architecture:
28
  Sparse MoE: 8 experts, top-2, SwiGLU (w1=gate, w3=up, w2=down)
29
  4x RMSNorm per layer (no bias)
30
  RoPE with scaled theta
 
31
  """
32
 
33
  import math
@@ -157,7 +158,7 @@ class Grok2Attention(nn.Module):
157
  self.o_proj = nn.Linear(config.num_attention_heads * config.head_dim, config.hidden_size, bias=False)
158
  self.rotary_emb = Grok2RotaryEmbedding(config.head_dim, config.max_position_embeddings, config.rope_theta)
159
 
160
- def forward(self, hidden_states, attention_mask=None, past_key_value=None, use_cache=False):
161
  B, T, _ = hidden_states.shape
162
 
163
  q = self.q_proj(hidden_states).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
@@ -169,21 +170,6 @@ class Grok2Attention(nn.Module):
169
  sin = sin[:, :, :T, :self.head_dim]
170
  q, k = apply_rotary_emb(q, k, cos, sin)
171
 
172
- if past_key_value is not None:
173
- if hasattr(past_key_value, 'update'):
174
- # DynamicCache — use the official update method
175
- k, v = past_key_value.update(k, v, self._layer_idx)
176
- elif hasattr(past_key_value, 'key_cache'):
177
- layer_idx = getattr(self, '_layer_idx', 0)
178
- if layer_idx < len(past_key_value.key_cache):
179
- k = torch.cat([past_key_value.key_cache[layer_idx], k], dim=2)
180
- v = torch.cat([past_key_value.value_cache[layer_idx], v], dim=2)
181
- else:
182
- k = torch.cat([past_key_value[0], k], dim=2)
183
- v = torch.cat([past_key_value[1], v], dim=2)
184
-
185
- present = (k, v) if use_cache else None
186
-
187
  # GQA expand
188
  k = k.repeat_interleave(self.num_kv_groups, dim=1)
189
  v = v.repeat_interleave(self.num_kv_groups, dim=1)
@@ -191,16 +177,14 @@ class Grok2Attention(nn.Module):
191
  scale = math.sqrt(self.head_dim)
192
  attn = torch.matmul(q, k.transpose(-2, -1)) / scale
193
 
194
- # Attn logit softcapping
195
  if self.attn_softcap > 0:
196
  attn = attn / self.attn_softcap
197
  attn = torch.tanh(attn)
198
  attn = attn * self.attn_softcap
199
 
200
- kv_len = k.shape[2]
201
  causal = torch.triu(
202
- torch.full((T, kv_len), float("-inf"), device=q.device, dtype=q.dtype),
203
- diagonal=1 + kv_len - T
204
  )
205
  attn = attn + causal.unsqueeze(0).unsqueeze(0)
206
 
@@ -210,12 +194,11 @@ class Grok2Attention(nn.Module):
210
  attn = F.softmax(attn, dim=-1, dtype=torch.float32).to(q.dtype)
211
  out = torch.matmul(attn, v)
212
  out = out.transpose(1, 2).contiguous().view(B, T, -1)
213
- return self.o_proj(out), present
214
 
215
 
216
  # ── MoE Expert ────────────────────────────────────────────────────────────────
217
  class Grok2Expert(nn.Module):
218
- """Single expert: SwiGLU with w1=gate, w3=up, w2=down."""
219
  def __init__(self, hidden_size, moe_intermediate_size):
220
  super().__init__()
221
  self.w1 = nn.Linear(hidden_size, moe_intermediate_size, bias=False)
@@ -244,9 +227,8 @@ class Grok2SparseMoE(nn.Module):
244
  B, T, H = x.shape
245
  x_flat = x.view(-1, H)
246
 
247
- router_logits = self.gate(x_flat) # [B*T, n_experts]
248
 
249
- # Router softcapping
250
  if self.router_softcap > 0:
251
  router_logits = router_logits / self.router_softcap
252
  router_logits = torch.tanh(router_logits)
@@ -268,7 +250,7 @@ class Grok2SparseMoE(nn.Module):
268
  return out.view(B, T, H)
269
 
270
 
271
- # ── Dense MLP (residual path) ─────────────────────────────────────────────────
272
  class Grok2MLP(nn.Module):
273
  def __init__(self, config: Grok2Config):
274
  super().__init__()
@@ -285,27 +267,23 @@ class Grok2DecoderLayer(nn.Module):
285
  def __init__(self, config: Grok2Config, layer_idx: int):
286
  super().__init__()
287
  self.layer_idx = layer_idx
288
- self.pre_attn_norm = Grok2RMSNorm(config.hidden_size, config.rms_norm_eps)
289
- self.self_attn = Grok2Attention(config)
290
- self.self_attn._layer_idx = layer_idx
291
- self.post_attn_norm = Grok2RMSNorm(config.hidden_size, config.rms_norm_eps)
292
- self.pre_moe_norm = Grok2RMSNorm(config.hidden_size, config.rms_norm_eps)
293
  self.block_sparse_moe = Grok2SparseMoE(config)
294
- self.mlp = Grok2MLP(config)
295
- self.post_moe_norm = Grok2RMSNorm(config.hidden_size, config.rms_norm_eps)
296
 
297
- def forward(self, hidden_states, attention_mask=None, past_key_value=None, use_cache=False):
298
- # Attention block
299
  residual = hidden_states
300
  hidden_states = self.pre_attn_norm(hidden_states)
301
- hidden_states, present = self.self_attn(
302
- hidden_states, attention_mask=attention_mask,
303
- past_key_value=past_key_value, use_cache=use_cache
304
- )
305
  hidden_states = self.post_attn_norm(hidden_states)
306
  hidden_states = residual + hidden_states
307
 
308
- # MoE + dense residual block
309
  residual = hidden_states
310
  hidden_states = self.pre_moe_norm(hidden_states)
311
  moe_out = self.block_sparse_moe(hidden_states)
@@ -313,7 +291,7 @@ class Grok2DecoderLayer(nn.Module):
313
  hidden_states = self.post_moe_norm(moe_out + mlp_out)
314
  hidden_states = residual + hidden_states
315
 
316
- return hidden_states, present
317
 
318
 
319
  # ── Model ─────────────────────────────────────────────────────────────────────
@@ -327,25 +305,11 @@ class Grok2Model(nn.Module):
327
  ])
328
  self.norm = Grok2RMSNorm(config.hidden_size, config.rms_norm_eps)
329
 
330
- def forward(self, input_ids, attention_mask=None, past_key_values=None, use_cache=False):
331
  hidden_states = self.embed_tokens(input_ids) * self.embedding_multiplier_scale
332
- presents = [] if use_cache else None
333
-
334
- for i, layer in enumerate(self.layers):
335
- pkv = None
336
- if past_key_values is not None:
337
- if hasattr(past_key_values, 'key_cache'):
338
- pkv = past_key_values
339
- else:
340
- pkv = past_key_values[i] if i < len(past_key_values) else None
341
- hidden_states, present = layer(
342
- hidden_states, attention_mask=attention_mask,
343
- past_key_value=pkv, use_cache=use_cache
344
- )
345
- if use_cache and present is not None:
346
- presents.append(present)
347
-
348
- return self.norm(hidden_states), presents
349
 
350
 
351
  # ── CausalLM ──────────────────────────────────────────────────────────────────
@@ -356,8 +320,8 @@ class Grok1ForCausalLM(PreTrainedModel, GenerationMixin):
356
 
357
  def __init__(self, config: Grok2Config):
358
  super().__init__(config)
359
- self.model = Grok2Model(config)
360
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
361
  self.output_multiplier_scale = config.output_multiplier_scale
362
  self.final_logit_softcapping = config.final_logit_softcapping
363
  self.post_init()
@@ -372,13 +336,10 @@ class Grok1ForCausalLM(PreTrainedModel, GenerationMixin):
372
  past_key_values=None,
373
  inputs_embeds=None,
374
  labels=None,
375
- use_cache=True,
376
  **kwargs,
377
  ):
378
- hidden_states, presents = self.model(
379
- input_ids, attention_mask=attention_mask,
380
- past_key_values=past_key_values, use_cache=False
381
- )
382
 
383
  logits = self.lm_head(hidden_states) * self.output_multiplier_scale
384
 
@@ -398,15 +359,15 @@ class Grok1ForCausalLM(PreTrainedModel, GenerationMixin):
398
  )
399
 
400
  return CausalLMOutputWithPast(
401
- loss=loss, logits=logits, past_key_values=presents
 
 
402
  )
403
 
404
- def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
405
- if past_key_values is not None:
406
- input_ids = input_ids[:, -1:]
407
- return {"input_ids": input_ids, "past_key_values": past_key_values, "use_cache": True}
408
 
409
 
410
- # ── Register with AutoModel ───────────────────────────────────────────────────
411
  AutoConfig.register("grok2", Grok2Config)
412
  AutoModelForCausalLM.register(Grok2Config, Grok1ForCausalLM)
 
28
  Sparse MoE: 8 experts, top-2, SwiGLU (w1=gate, w3=up, w2=down)
29
  4x RMSNorm per layer (no bias)
30
  RoPE with scaled theta
31
+ KV cache disabled — forward pass only, no past_key_values
32
  """
33
 
34
  import math
 
158
  self.o_proj = nn.Linear(config.num_attention_heads * config.head_dim, config.hidden_size, bias=False)
159
  self.rotary_emb = Grok2RotaryEmbedding(config.head_dim, config.max_position_embeddings, config.rope_theta)
160
 
161
+ def forward(self, hidden_states, attention_mask=None, **kwargs):
162
  B, T, _ = hidden_states.shape
163
 
164
  q = self.q_proj(hidden_states).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
 
170
  sin = sin[:, :, :T, :self.head_dim]
171
  q, k = apply_rotary_emb(q, k, cos, sin)
172
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
  # GQA expand
174
  k = k.repeat_interleave(self.num_kv_groups, dim=1)
175
  v = v.repeat_interleave(self.num_kv_groups, dim=1)
 
177
  scale = math.sqrt(self.head_dim)
178
  attn = torch.matmul(q, k.transpose(-2, -1)) / scale
179
 
 
180
  if self.attn_softcap > 0:
181
  attn = attn / self.attn_softcap
182
  attn = torch.tanh(attn)
183
  attn = attn * self.attn_softcap
184
 
 
185
  causal = torch.triu(
186
+ torch.full((T, T), float("-inf"), device=q.device, dtype=q.dtype),
187
+ diagonal=1
188
  )
189
  attn = attn + causal.unsqueeze(0).unsqueeze(0)
190
 
 
194
  attn = F.softmax(attn, dim=-1, dtype=torch.float32).to(q.dtype)
195
  out = torch.matmul(attn, v)
196
  out = out.transpose(1, 2).contiguous().view(B, T, -1)
197
+ return self.o_proj(out)
198
 
199
 
200
  # ── MoE Expert ────────────────────────────────────────────────────────────────
201
  class Grok2Expert(nn.Module):
 
202
  def __init__(self, hidden_size, moe_intermediate_size):
203
  super().__init__()
204
  self.w1 = nn.Linear(hidden_size, moe_intermediate_size, bias=False)
 
227
  B, T, H = x.shape
228
  x_flat = x.view(-1, H)
229
 
230
+ router_logits = self.gate(x_flat)
231
 
 
232
  if self.router_softcap > 0:
233
  router_logits = router_logits / self.router_softcap
234
  router_logits = torch.tanh(router_logits)
 
250
  return out.view(B, T, H)
251
 
252
 
253
+ # ── Dense MLP ─────────────────────────────────────────────────────────────────
254
  class Grok2MLP(nn.Module):
255
  def __init__(self, config: Grok2Config):
256
  super().__init__()
 
267
  def __init__(self, config: Grok2Config, layer_idx: int):
268
  super().__init__()
269
  self.layer_idx = layer_idx
270
+ self.pre_attn_norm = Grok2RMSNorm(config.hidden_size, config.rms_norm_eps)
271
+ self.self_attn = Grok2Attention(config)
272
+ self.post_attn_norm = Grok2RMSNorm(config.hidden_size, config.rms_norm_eps)
273
+ self.pre_moe_norm = Grok2RMSNorm(config.hidden_size, config.rms_norm_eps)
 
274
  self.block_sparse_moe = Grok2SparseMoE(config)
275
+ self.mlp = Grok2MLP(config)
276
+ self.post_moe_norm = Grok2RMSNorm(config.hidden_size, config.rms_norm_eps)
277
 
278
+ def forward(self, hidden_states, attention_mask=None, **kwargs):
279
+ # Attention
280
  residual = hidden_states
281
  hidden_states = self.pre_attn_norm(hidden_states)
282
+ hidden_states = self.self_attn(hidden_states, attention_mask=attention_mask)
 
 
 
283
  hidden_states = self.post_attn_norm(hidden_states)
284
  hidden_states = residual + hidden_states
285
 
286
+ # MoE + dense residual
287
  residual = hidden_states
288
  hidden_states = self.pre_moe_norm(hidden_states)
289
  moe_out = self.block_sparse_moe(hidden_states)
 
291
  hidden_states = self.post_moe_norm(moe_out + mlp_out)
292
  hidden_states = residual + hidden_states
293
 
294
+ return hidden_states
295
 
296
 
297
  # ── Model ─────────────────────────────────────────────────────────────────────
 
305
  ])
306
  self.norm = Grok2RMSNorm(config.hidden_size, config.rms_norm_eps)
307
 
308
+ def forward(self, input_ids, attention_mask=None, **kwargs):
309
  hidden_states = self.embed_tokens(input_ids) * self.embedding_multiplier_scale
310
+ for layer in self.layers:
311
+ hidden_states = layer(hidden_states, attention_mask=attention_mask)
312
+ return self.norm(hidden_states)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
313
 
314
 
315
  # ── CausalLM ──────────────────────────────────────────────────────────────────
 
320
 
321
  def __init__(self, config: Grok2Config):
322
  super().__init__(config)
323
+ self.model = Grok2Model(config)
324
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
325
  self.output_multiplier_scale = config.output_multiplier_scale
326
  self.final_logit_softcapping = config.final_logit_softcapping
327
  self.post_init()
 
336
  past_key_values=None,
337
  inputs_embeds=None,
338
  labels=None,
339
+ use_cache=None,
340
  **kwargs,
341
  ):
342
+ hidden_states = self.model(input_ids, attention_mask=attention_mask)
 
 
 
343
 
344
  logits = self.lm_head(hidden_states) * self.output_multiplier_scale
345
 
 
359
  )
360
 
361
  return CausalLMOutputWithPast(
362
+ loss=loss,
363
+ logits=logits,
364
+ past_key_values=None,
365
  )
366
 
367
+ def prepare_inputs_for_generation(self, input_ids, **kwargs):
368
+ return {"input_ids": input_ids}
 
 
369
 
370
 
371
+ # ── Register ──────────────────────────────────────────────────────────────────
372
  AutoConfig.register("grok2", Grok2Config)
373
  AutoModelForCausalLM.register(Grok2Config, Grok1ForCausalLM)