Update modeling_grok2.py
Browse files- modeling_grok2.py +32 -71
modeling_grok2.py
CHANGED
|
@@ -28,6 +28,7 @@ Architecture:
|
|
| 28 |
Sparse MoE: 8 experts, top-2, SwiGLU (w1=gate, w3=up, w2=down)
|
| 29 |
4x RMSNorm per layer (no bias)
|
| 30 |
RoPE with scaled theta
|
|
|
|
| 31 |
"""
|
| 32 |
|
| 33 |
import math
|
|
@@ -157,7 +158,7 @@ class Grok2Attention(nn.Module):
|
|
| 157 |
self.o_proj = nn.Linear(config.num_attention_heads * config.head_dim, config.hidden_size, bias=False)
|
| 158 |
self.rotary_emb = Grok2RotaryEmbedding(config.head_dim, config.max_position_embeddings, config.rope_theta)
|
| 159 |
|
| 160 |
-
def forward(self, hidden_states, attention_mask=None,
|
| 161 |
B, T, _ = hidden_states.shape
|
| 162 |
|
| 163 |
q = self.q_proj(hidden_states).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
|
|
@@ -169,21 +170,6 @@ class Grok2Attention(nn.Module):
|
|
| 169 |
sin = sin[:, :, :T, :self.head_dim]
|
| 170 |
q, k = apply_rotary_emb(q, k, cos, sin)
|
| 171 |
|
| 172 |
-
if past_key_value is not None:
|
| 173 |
-
if hasattr(past_key_value, 'update'):
|
| 174 |
-
# DynamicCache — use the official update method
|
| 175 |
-
k, v = past_key_value.update(k, v, self._layer_idx)
|
| 176 |
-
elif hasattr(past_key_value, 'key_cache'):
|
| 177 |
-
layer_idx = getattr(self, '_layer_idx', 0)
|
| 178 |
-
if layer_idx < len(past_key_value.key_cache):
|
| 179 |
-
k = torch.cat([past_key_value.key_cache[layer_idx], k], dim=2)
|
| 180 |
-
v = torch.cat([past_key_value.value_cache[layer_idx], v], dim=2)
|
| 181 |
-
else:
|
| 182 |
-
k = torch.cat([past_key_value[0], k], dim=2)
|
| 183 |
-
v = torch.cat([past_key_value[1], v], dim=2)
|
| 184 |
-
|
| 185 |
-
present = (k, v) if use_cache else None
|
| 186 |
-
|
| 187 |
# GQA expand
|
| 188 |
k = k.repeat_interleave(self.num_kv_groups, dim=1)
|
| 189 |
v = v.repeat_interleave(self.num_kv_groups, dim=1)
|
|
@@ -191,16 +177,14 @@ class Grok2Attention(nn.Module):
|
|
| 191 |
scale = math.sqrt(self.head_dim)
|
| 192 |
attn = torch.matmul(q, k.transpose(-2, -1)) / scale
|
| 193 |
|
| 194 |
-
# Attn logit softcapping
|
| 195 |
if self.attn_softcap > 0:
|
| 196 |
attn = attn / self.attn_softcap
|
| 197 |
attn = torch.tanh(attn)
|
| 198 |
attn = attn * self.attn_softcap
|
| 199 |
|
| 200 |
-
kv_len = k.shape[2]
|
| 201 |
causal = torch.triu(
|
| 202 |
-
torch.full((T,
|
| 203 |
-
diagonal=1
|
| 204 |
)
|
| 205 |
attn = attn + causal.unsqueeze(0).unsqueeze(0)
|
| 206 |
|
|
@@ -210,12 +194,11 @@ class Grok2Attention(nn.Module):
|
|
| 210 |
attn = F.softmax(attn, dim=-1, dtype=torch.float32).to(q.dtype)
|
| 211 |
out = torch.matmul(attn, v)
|
| 212 |
out = out.transpose(1, 2).contiguous().view(B, T, -1)
|
| 213 |
-
return self.o_proj(out)
|
| 214 |
|
| 215 |
|
| 216 |
# ── MoE Expert ────────────────────────────────────────────────────────────────
|
| 217 |
class Grok2Expert(nn.Module):
|
| 218 |
-
"""Single expert: SwiGLU with w1=gate, w3=up, w2=down."""
|
| 219 |
def __init__(self, hidden_size, moe_intermediate_size):
|
| 220 |
super().__init__()
|
| 221 |
self.w1 = nn.Linear(hidden_size, moe_intermediate_size, bias=False)
|
|
@@ -244,9 +227,8 @@ class Grok2SparseMoE(nn.Module):
|
|
| 244 |
B, T, H = x.shape
|
| 245 |
x_flat = x.view(-1, H)
|
| 246 |
|
| 247 |
-
router_logits = self.gate(x_flat)
|
| 248 |
|
| 249 |
-
# Router softcapping
|
| 250 |
if self.router_softcap > 0:
|
| 251 |
router_logits = router_logits / self.router_softcap
|
| 252 |
router_logits = torch.tanh(router_logits)
|
|
@@ -268,7 +250,7 @@ class Grok2SparseMoE(nn.Module):
|
|
| 268 |
return out.view(B, T, H)
|
| 269 |
|
| 270 |
|
| 271 |
-
# ── Dense MLP
|
| 272 |
class Grok2MLP(nn.Module):
|
| 273 |
def __init__(self, config: Grok2Config):
|
| 274 |
super().__init__()
|
|
@@ -285,27 +267,23 @@ class Grok2DecoderLayer(nn.Module):
|
|
| 285 |
def __init__(self, config: Grok2Config, layer_idx: int):
|
| 286 |
super().__init__()
|
| 287 |
self.layer_idx = layer_idx
|
| 288 |
-
self.pre_attn_norm
|
| 289 |
-
self.self_attn
|
| 290 |
-
self.
|
| 291 |
-
self.
|
| 292 |
-
self.pre_moe_norm = Grok2RMSNorm(config.hidden_size, config.rms_norm_eps)
|
| 293 |
self.block_sparse_moe = Grok2SparseMoE(config)
|
| 294 |
-
self.mlp
|
| 295 |
-
self.post_moe_norm
|
| 296 |
|
| 297 |
-
def forward(self, hidden_states, attention_mask=None,
|
| 298 |
-
# Attention
|
| 299 |
residual = hidden_states
|
| 300 |
hidden_states = self.pre_attn_norm(hidden_states)
|
| 301 |
-
hidden_states
|
| 302 |
-
hidden_states, attention_mask=attention_mask,
|
| 303 |
-
past_key_value=past_key_value, use_cache=use_cache
|
| 304 |
-
)
|
| 305 |
hidden_states = self.post_attn_norm(hidden_states)
|
| 306 |
hidden_states = residual + hidden_states
|
| 307 |
|
| 308 |
-
# MoE + dense residual
|
| 309 |
residual = hidden_states
|
| 310 |
hidden_states = self.pre_moe_norm(hidden_states)
|
| 311 |
moe_out = self.block_sparse_moe(hidden_states)
|
|
@@ -313,7 +291,7 @@ class Grok2DecoderLayer(nn.Module):
|
|
| 313 |
hidden_states = self.post_moe_norm(moe_out + mlp_out)
|
| 314 |
hidden_states = residual + hidden_states
|
| 315 |
|
| 316 |
-
return hidden_states
|
| 317 |
|
| 318 |
|
| 319 |
# ── Model ─────────────────────────────────────────────────────────────────────
|
|
@@ -327,25 +305,11 @@ class Grok2Model(nn.Module):
|
|
| 327 |
])
|
| 328 |
self.norm = Grok2RMSNorm(config.hidden_size, config.rms_norm_eps)
|
| 329 |
|
| 330 |
-
def forward(self, input_ids, attention_mask=None,
|
| 331 |
hidden_states = self.embed_tokens(input_ids) * self.embedding_multiplier_scale
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
pkv = None
|
| 336 |
-
if past_key_values is not None:
|
| 337 |
-
if hasattr(past_key_values, 'key_cache'):
|
| 338 |
-
pkv = past_key_values
|
| 339 |
-
else:
|
| 340 |
-
pkv = past_key_values[i] if i < len(past_key_values) else None
|
| 341 |
-
hidden_states, present = layer(
|
| 342 |
-
hidden_states, attention_mask=attention_mask,
|
| 343 |
-
past_key_value=pkv, use_cache=use_cache
|
| 344 |
-
)
|
| 345 |
-
if use_cache and present is not None:
|
| 346 |
-
presents.append(present)
|
| 347 |
-
|
| 348 |
-
return self.norm(hidden_states), presents
|
| 349 |
|
| 350 |
|
| 351 |
# ── CausalLM ──────────────────────────────────────────────────────────────────
|
|
@@ -356,8 +320,8 @@ class Grok1ForCausalLM(PreTrainedModel, GenerationMixin):
|
|
| 356 |
|
| 357 |
def __init__(self, config: Grok2Config):
|
| 358 |
super().__init__(config)
|
| 359 |
-
self.model
|
| 360 |
-
self.lm_head
|
| 361 |
self.output_multiplier_scale = config.output_multiplier_scale
|
| 362 |
self.final_logit_softcapping = config.final_logit_softcapping
|
| 363 |
self.post_init()
|
|
@@ -372,13 +336,10 @@ class Grok1ForCausalLM(PreTrainedModel, GenerationMixin):
|
|
| 372 |
past_key_values=None,
|
| 373 |
inputs_embeds=None,
|
| 374 |
labels=None,
|
| 375 |
-
use_cache=
|
| 376 |
**kwargs,
|
| 377 |
):
|
| 378 |
-
hidden_states
|
| 379 |
-
input_ids, attention_mask=attention_mask,
|
| 380 |
-
past_key_values=past_key_values, use_cache=False
|
| 381 |
-
)
|
| 382 |
|
| 383 |
logits = self.lm_head(hidden_states) * self.output_multiplier_scale
|
| 384 |
|
|
@@ -398,15 +359,15 @@ class Grok1ForCausalLM(PreTrainedModel, GenerationMixin):
|
|
| 398 |
)
|
| 399 |
|
| 400 |
return CausalLMOutputWithPast(
|
| 401 |
-
loss=loss,
|
|
|
|
|
|
|
| 402 |
)
|
| 403 |
|
| 404 |
-
def prepare_inputs_for_generation(self, input_ids,
|
| 405 |
-
|
| 406 |
-
input_ids = input_ids[:, -1:]
|
| 407 |
-
return {"input_ids": input_ids, "past_key_values": past_key_values, "use_cache": True}
|
| 408 |
|
| 409 |
|
| 410 |
-
# ── Register
|
| 411 |
AutoConfig.register("grok2", Grok2Config)
|
| 412 |
AutoModelForCausalLM.register(Grok2Config, Grok1ForCausalLM)
|
|
|
|
| 28 |
Sparse MoE: 8 experts, top-2, SwiGLU (w1=gate, w3=up, w2=down)
|
| 29 |
4x RMSNorm per layer (no bias)
|
| 30 |
RoPE with scaled theta
|
| 31 |
+
KV cache disabled — forward pass only, no past_key_values
|
| 32 |
"""
|
| 33 |
|
| 34 |
import math
|
|
|
|
| 158 |
self.o_proj = nn.Linear(config.num_attention_heads * config.head_dim, config.hidden_size, bias=False)
|
| 159 |
self.rotary_emb = Grok2RotaryEmbedding(config.head_dim, config.max_position_embeddings, config.rope_theta)
|
| 160 |
|
| 161 |
+
def forward(self, hidden_states, attention_mask=None, **kwargs):
|
| 162 |
B, T, _ = hidden_states.shape
|
| 163 |
|
| 164 |
q = self.q_proj(hidden_states).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
|
|
|
|
| 170 |
sin = sin[:, :, :T, :self.head_dim]
|
| 171 |
q, k = apply_rotary_emb(q, k, cos, sin)
|
| 172 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
# GQA expand
|
| 174 |
k = k.repeat_interleave(self.num_kv_groups, dim=1)
|
| 175 |
v = v.repeat_interleave(self.num_kv_groups, dim=1)
|
|
|
|
| 177 |
scale = math.sqrt(self.head_dim)
|
| 178 |
attn = torch.matmul(q, k.transpose(-2, -1)) / scale
|
| 179 |
|
|
|
|
| 180 |
if self.attn_softcap > 0:
|
| 181 |
attn = attn / self.attn_softcap
|
| 182 |
attn = torch.tanh(attn)
|
| 183 |
attn = attn * self.attn_softcap
|
| 184 |
|
|
|
|
| 185 |
causal = torch.triu(
|
| 186 |
+
torch.full((T, T), float("-inf"), device=q.device, dtype=q.dtype),
|
| 187 |
+
diagonal=1
|
| 188 |
)
|
| 189 |
attn = attn + causal.unsqueeze(0).unsqueeze(0)
|
| 190 |
|
|
|
|
| 194 |
attn = F.softmax(attn, dim=-1, dtype=torch.float32).to(q.dtype)
|
| 195 |
out = torch.matmul(attn, v)
|
| 196 |
out = out.transpose(1, 2).contiguous().view(B, T, -1)
|
| 197 |
+
return self.o_proj(out)
|
| 198 |
|
| 199 |
|
| 200 |
# ── MoE Expert ────────────────────────────────────────────────────────────────
|
| 201 |
class Grok2Expert(nn.Module):
|
|
|
|
| 202 |
def __init__(self, hidden_size, moe_intermediate_size):
|
| 203 |
super().__init__()
|
| 204 |
self.w1 = nn.Linear(hidden_size, moe_intermediate_size, bias=False)
|
|
|
|
| 227 |
B, T, H = x.shape
|
| 228 |
x_flat = x.view(-1, H)
|
| 229 |
|
| 230 |
+
router_logits = self.gate(x_flat)
|
| 231 |
|
|
|
|
| 232 |
if self.router_softcap > 0:
|
| 233 |
router_logits = router_logits / self.router_softcap
|
| 234 |
router_logits = torch.tanh(router_logits)
|
|
|
|
| 250 |
return out.view(B, T, H)
|
| 251 |
|
| 252 |
|
| 253 |
+
# ── Dense MLP ─────────────────────────────────────────────────────────────────
|
| 254 |
class Grok2MLP(nn.Module):
|
| 255 |
def __init__(self, config: Grok2Config):
|
| 256 |
super().__init__()
|
|
|
|
| 267 |
def __init__(self, config: Grok2Config, layer_idx: int):
|
| 268 |
super().__init__()
|
| 269 |
self.layer_idx = layer_idx
|
| 270 |
+
self.pre_attn_norm = Grok2RMSNorm(config.hidden_size, config.rms_norm_eps)
|
| 271 |
+
self.self_attn = Grok2Attention(config)
|
| 272 |
+
self.post_attn_norm = Grok2RMSNorm(config.hidden_size, config.rms_norm_eps)
|
| 273 |
+
self.pre_moe_norm = Grok2RMSNorm(config.hidden_size, config.rms_norm_eps)
|
|
|
|
| 274 |
self.block_sparse_moe = Grok2SparseMoE(config)
|
| 275 |
+
self.mlp = Grok2MLP(config)
|
| 276 |
+
self.post_moe_norm = Grok2RMSNorm(config.hidden_size, config.rms_norm_eps)
|
| 277 |
|
| 278 |
+
def forward(self, hidden_states, attention_mask=None, **kwargs):
|
| 279 |
+
# Attention
|
| 280 |
residual = hidden_states
|
| 281 |
hidden_states = self.pre_attn_norm(hidden_states)
|
| 282 |
+
hidden_states = self.self_attn(hidden_states, attention_mask=attention_mask)
|
|
|
|
|
|
|
|
|
|
| 283 |
hidden_states = self.post_attn_norm(hidden_states)
|
| 284 |
hidden_states = residual + hidden_states
|
| 285 |
|
| 286 |
+
# MoE + dense residual
|
| 287 |
residual = hidden_states
|
| 288 |
hidden_states = self.pre_moe_norm(hidden_states)
|
| 289 |
moe_out = self.block_sparse_moe(hidden_states)
|
|
|
|
| 291 |
hidden_states = self.post_moe_norm(moe_out + mlp_out)
|
| 292 |
hidden_states = residual + hidden_states
|
| 293 |
|
| 294 |
+
return hidden_states
|
| 295 |
|
| 296 |
|
| 297 |
# ── Model ─────────────────────────────────────────────────────────────────────
|
|
|
|
| 305 |
])
|
| 306 |
self.norm = Grok2RMSNorm(config.hidden_size, config.rms_norm_eps)
|
| 307 |
|
| 308 |
+
def forward(self, input_ids, attention_mask=None, **kwargs):
|
| 309 |
hidden_states = self.embed_tokens(input_ids) * self.embedding_multiplier_scale
|
| 310 |
+
for layer in self.layers:
|
| 311 |
+
hidden_states = layer(hidden_states, attention_mask=attention_mask)
|
| 312 |
+
return self.norm(hidden_states)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 313 |
|
| 314 |
|
| 315 |
# ── CausalLM ──────────────────────────────────────────────────────────────────
|
|
|
|
| 320 |
|
| 321 |
def __init__(self, config: Grok2Config):
|
| 322 |
super().__init__(config)
|
| 323 |
+
self.model = Grok2Model(config)
|
| 324 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 325 |
self.output_multiplier_scale = config.output_multiplier_scale
|
| 326 |
self.final_logit_softcapping = config.final_logit_softcapping
|
| 327 |
self.post_init()
|
|
|
|
| 336 |
past_key_values=None,
|
| 337 |
inputs_embeds=None,
|
| 338 |
labels=None,
|
| 339 |
+
use_cache=None,
|
| 340 |
**kwargs,
|
| 341 |
):
|
| 342 |
+
hidden_states = self.model(input_ids, attention_mask=attention_mask)
|
|
|
|
|
|
|
|
|
|
| 343 |
|
| 344 |
logits = self.lm_head(hidden_states) * self.output_multiplier_scale
|
| 345 |
|
|
|
|
| 359 |
)
|
| 360 |
|
| 361 |
return CausalLMOutputWithPast(
|
| 362 |
+
loss=loss,
|
| 363 |
+
logits=logits,
|
| 364 |
+
past_key_values=None,
|
| 365 |
)
|
| 366 |
|
| 367 |
+
def prepare_inputs_for_generation(self, input_ids, **kwargs):
|
| 368 |
+
return {"input_ids": input_ids}
|
|
|
|
|
|
|
| 369 |
|
| 370 |
|
| 371 |
+
# ── Register ──────────────────────────────────────────────────────────────────
|
| 372 |
AutoConfig.register("grok2", Grok2Config)
|
| 373 |
AutoModelForCausalLM.register(Grok2Config, Grok1ForCausalLM)
|