Grok-2 / modeling_grok2.py
Johnblick187's picture
Update modeling_grok2.py
6e76888 verified
"""
modeling_grok2.py — Grok 2 for transformers, full multi-GPU support.
Pure bf16 throughout. Device-aware at every operation.
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig, GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers import AutoConfig, AutoModelForCausalLM
# ── Config ────────────────────────────────────────────────────────────────────
class Grok2Config(PretrainedConfig):
    """Hyper-parameter container for the Grok-2 modules defined in this file.

    Every constructor argument is stored on the instance under the same
    name.  ``head_dim`` is not an argument: it is derived as
    ``hidden_size // num_attention_heads``.  The special-token ids,
    ``tie_word_embeddings`` and any extra keyword arguments are forwarded
    to ``PretrainedConfig.__init__``.
    """

    model_type = "grok2"

    def __init__(
        self,
        vocab_size=131072,
        hidden_size=8192,
        num_hidden_layers=64,
        num_attention_heads=64,
        num_key_value_heads=8,
        intermediate_size=32768,
        moe_intermediate_size=16384,
        num_local_experts=8,
        num_experts_per_tok=2,
        max_position_embeddings=131072,
        rope_theta=208533496.0,
        rms_norm_eps=1e-5,
        embedding_multiplier_scale=90.50966799187809,
        output_multiplier_scale=0.5,
        final_logit_softcapping=50.0,
        attn_logit_softcapping=30.0,
        router_logit_softcapping=30.0,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        **kwargs,
    ):
        # Vocabulary and depth/width of the transformer stack.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers

        # Attention geometry (grouped-query: fewer KV heads than Q heads).
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = hidden_size // num_attention_heads

        # Feed-forward sizes: dense MLP and per-expert MoE width.
        self.intermediate_size = intermediate_size
        self.moe_intermediate_size = moe_intermediate_size
        self.num_local_experts = num_local_experts
        self.num_experts_per_tok = num_experts_per_tok

        # Positional encoding and normalisation.
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta
        self.rms_norm_eps = rms_norm_eps

        # Input/output scaling factors.
        self.embedding_multiplier_scale = embedding_multiplier_scale
        self.output_multiplier_scale = output_multiplier_scale

        # tanh soft-capping thresholds; the modules skip capping when <= 0.
        self.final_logit_softcapping = final_logit_softcapping
        self.attn_logit_softcapping = attn_logit_softcapping
        self.router_logit_softcapping = router_logit_softcapping

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
# ── RMSNorm ───────────────────────────────────────────────────────────────────
class Grok2RMSNorm(nn.Module):
    """Root-mean-square layer norm with a learned per-channel scale.

    The computation is carried out in the dtype of the input; no upcast
    is performed.
    """

    def __init__(self, hidden_size, eps=1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x):
        """Normalise ``x`` over its last dimension and apply ``weight``."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        normalised = x * inv_rms
        # The scale may live on another device/dtype than the activations.
        scale = self.weight.to(x.device, x.dtype)
        return scale * normalised
# ── RoPE ──────────────────────────────────────────────────────────────────────
def rotate_half(x):
    """Return ``x`` with the halves of its last axis swapped and the new
    leading half negated: ``[a, b] -> [-b, a]``."""
    split = x.shape[-1] // 2
    front = x[..., :split]
    back = x[..., split:]
    return torch.cat((-back, front), dim=-1)
def apply_rotary_emb(q, k, cos, sin):
    """Apply rotary position embedding to a query/key pair.

    Each tensor is combined as ``t * cos + rotate_half(t) * sin``.
    Returns the rotated ``(q, k)`` tuple.
    """
    q_rotated = (q * cos) + (rotate_half(q) * sin)
    k_rotated = (k * cos) + (rotate_half(k) * sin)
    return q_rotated, k_rotated
class Grok2RotaryEmbedding(nn.Module):
    """Lazily cached cos/sin tables for rotary position embeddings.

    Args:
        dim: Size of the last axis of the returned tables (the head dim).
        max_pos: Accepted for interface compatibility; not used, because
            the tables are built on demand for the requested length.
        base: RoPE frequency base.
        scaling_factor: Multiplier applied to ``base`` before the inverse
            frequencies are computed.
    """

    def __init__(self, dim, max_pos=131072, base=208533496.0, scaling_factor=16.0):
        super().__init__()
        base = base * scaling_factor
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Cache state is initialised explicitly so forward() never has to
        # probe for missing attributes.
        self._cos = None
        self._sin = None
        self._cached_len = 0
        self._cached_device = None
        self._cached_dtype = None

    def _build_cache(self, seq_len, device, dtype):
        """(Re)compute tables of shape ``(1, 1, seq_len, dim)``."""
        # Angles are computed in float32 and only then cast to ``dtype``.
        t = torch.arange(seq_len, device=device).float()
        freqs = torch.outer(t, self.inv_freq.to(device))
        emb = torch.cat([freqs, freqs], dim=-1)
        self._cos = emb.cos().to(dtype)[None, None, :, :]
        self._sin = emb.sin().to(dtype)[None, None, :, :]
        self._cached_len = seq_len
        self._cached_device = device
        self._cached_dtype = dtype

    def forward(self, seq_len, device, dtype):
        """Return ``(cos, sin)`` for positions ``0 .. seq_len - 1``.

        The cache is rebuilt when it is empty, too short, on another
        device, or in another dtype.  The dtype check matters: without it
        a module first used in one precision would keep returning tables
        in that stale precision.
        """
        stale = (
            self._cos is None
            or seq_len > self._cached_len
            or device != self._cached_device
            or dtype != self._cached_dtype
        )
        if stale:
            self._build_cache(seq_len, device, dtype)
        return self._cos[:, :, :seq_len, :], self._sin[:, :, :seq_len, :]
# ── Attention ─────────────────────────────────────────────────────────────────
class Grok2Attention(nn.Module):
    """Causal grouped-query self-attention with RoPE and tanh soft-capping.

    No key/value cache is kept: every call attends over the full input
    sequence it is given.
    """

    def __init__(self, config: Grok2Config):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.head_dim = config.head_dim
        # Number of query heads that share one key/value head.
        self.num_kv_groups = self.num_heads // self.num_kv_heads
        self.attn_softcap = config.attn_logit_softcapping
        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * config.head_dim, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * config.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * config.head_dim, bias=False)
        self.o_proj = nn.Linear(config.num_attention_heads * config.head_dim, config.hidden_size, bias=False)
        self.rotary_emb = Grok2RotaryEmbedding(config.head_dim, config.max_position_embeddings, config.rope_theta)

    def forward(self, hidden_states, attention_mask=None, **kwargs):
        """Run self-attention over ``hidden_states`` of shape ``(B, T, hidden)``.

        Args:
            hidden_states: Input activations, ``(B, T, hidden_size)``.
            attention_mask: Optional.  A 2-D ``(B, T)`` tensor is treated as
                a padding mask in the tokenizer convention (non-zero = real
                token, zero = padding) and blocks the padded key positions.
                A tensor of any other rank is treated as an additive bias
                and is added to the attention scores via broadcasting.
            **kwargs: Ignored.

        Returns:
            Tensor of shape ``(B, T, hidden_size)``.
        """
        B, T, _ = hidden_states.shape
        device = hidden_states.device
        dtype = hidden_states.dtype

        q = self.q_proj(hidden_states).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(hidden_states).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(T, device, dtype)
        q, k = apply_rotary_emb(q, k, cos, sin)

        # GQA expand: replicate each KV head for its group of query heads.
        k = k.repeat_interleave(self.num_kv_groups, dim=1)
        v = v.repeat_interleave(self.num_kv_groups, dim=1)

        attn = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if self.attn_softcap > 0:
            attn = torch.tanh(attn / self.attn_softcap) * self.attn_softcap

        # blocked[..., i, j] is True where query i must not see key j.
        positions = torch.arange(T, device=device)
        blocked = (positions[None, :] > positions[:, None])[None, None, :, :]

        if attention_mask is not None:
            attention_mask = attention_mask.to(device)
            if attention_mask.dim() == 2:
                # 0/1 padding mask: adding it to the scores would not mask
                # anything, so convert it to blocked key columns instead.
                blocked = blocked | (attention_mask == 0)[:, None, None, :]
            else:
                attn = attn + attention_mask.to(dtype)

        # Use the most negative finite value rather than -inf so that a row
        # whose keys are all blocked (e.g. a left-padded query position)
        # yields a uniform distribution instead of NaN.
        attn = attn.masked_fill(blocked, torch.finfo(dtype).min)
        attn = F.softmax(attn, dim=-1).to(dtype)

        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(B, T, -1)
        return self.o_proj(out)
# ── MoE Expert ────────────────────────────────────────────────────────────────
class Grok2Expert(nn.Module):
    """One gated feed-forward expert: ``w2(silu(w1(x)) * w3(x))``.

    Each projection is evaluated on the device where its own weight
    lives, so the three linears may be placed on different devices.
    The result is returned on the device of ``w2``.
    """

    def __init__(self, hidden_size, moe_intermediate_size):
        super().__init__()
        self.w1 = nn.Linear(hidden_size, moe_intermediate_size, bias=False)
        self.w2 = nn.Linear(moe_intermediate_size, hidden_size, bias=False)
        self.w3 = nn.Linear(hidden_size, moe_intermediate_size, bias=False)

    def forward(self, x):
        """Apply the expert to ``x`` (last dimension = ``hidden_size``)."""
        gate_device = self.w1.weight.device
        up_device = self.w3.weight.device
        down_device = self.w2.weight.device

        activated = F.silu(self.w1(x.to(gate_device)))
        projected = self.w3(x.to(up_device))
        # Bring both branches to the down-projection's device before mixing.
        mixed = activated.to(down_device) * projected.to(down_device)
        return self.w2(mixed)
# ── Sparse MoE ────────────────────────────────────────────────────────────────
class Grok2SparseMoE(nn.Module):
    """Top-k routed mixture of experts.

    A linear router scores every token, the scores are optionally
    soft-capped with tanh, soft-maxed over all experts, and the ``top_k``
    highest weights per token are renormalised to sum to one.
    """

    def __init__(self, config: Grok2Config):
        super().__init__()
        self.num_experts = config.num_local_experts
        self.top_k = config.num_experts_per_tok
        self.router_softcap = config.router_logit_softcapping
        self.gate = nn.Linear(config.hidden_size, config.num_local_experts, bias=False)
        self.experts = nn.ModuleList([
            Grok2Expert(config.hidden_size, config.moe_intermediate_size)
            for _ in range(config.num_local_experts)
        ])

    def forward(self, x):
        """Route ``x`` of shape ``(B, T, H)`` and return the mixed output.

        Returns a tensor of the same shape, device and dtype as ``x``.
        """
        B, T, H = x.shape
        device = x.device
        dtype = x.dtype
        # reshape (not view) so non-contiguous inputs are accepted too.
        x_flat = x.reshape(-1, H)

        router_logits = self.gate(x_flat)
        if self.router_softcap > 0:
            router_logits = torch.tanh(router_logits / self.router_softcap) * self.router_softcap
        router_weights = F.softmax(router_logits, dim=-1)
        top_weights, top_indices = router_weights.topk(self.top_k, dim=-1)
        top_weights = top_weights / top_weights.sum(dim=-1, keepdim=True)

        out = torch.zeros_like(x_flat)
        # Visit each expert once and hand it every token routed to it,
        # whichever top-k slot selected it.  This replaces the former
        # top_k x num_experts double loop, which could run the same expert
        # once per slot.
        for expert_id, expert in enumerate(self.experts):
            token_idx, slot_idx = torch.where(top_indices == expert_id)
            if token_idx.numel() == 0:
                continue
            # Move tokens to expert's device, compute, move result back.
            expert_device = next(expert.parameters()).device
            expert_in = x_flat[token_idx].to(device=expert_device, dtype=dtype)
            expert_out = expert(expert_in).to(device=device, dtype=dtype)
            routing = top_weights[token_idx, slot_idx].unsqueeze(-1)
            # topk yields distinct experts per token, so token_idx has no
            # duplicates here and the indexed in-place add is well defined.
            out[token_idx] += routing * expert_out
        return out.view(B, T, H)
# ── Dense MLP ─────────────────────────────────────────────────────────────────
class Grok2MLP(nn.Module):
    """Dense gated feed-forward block: ``down(silu(gate(x)) * up(x))``."""

    def __init__(self, config: Grok2Config):
        super().__init__()
        model_width = config.hidden_size
        inner_width = config.intermediate_size
        self.gate_proj = nn.Linear(model_width, inner_width, bias=False)
        self.up_proj = nn.Linear(model_width, inner_width, bias=False)
        self.down_proj = nn.Linear(inner_width, model_width, bias=False)

    def forward(self, x):
        """Apply the gated MLP to ``x`` (last dimension = ``hidden_size``)."""
        gated = F.silu(self.gate_proj(x))
        lifted = self.up_proj(x)
        return self.down_proj(gated * lifted)
# ── Decoder Layer ─────────────────────────────────────────────────────────────
class Grok2DecoderLayer(nn.Module):
    """One transformer layer: attention, then a sparse-MoE + dense-MLP block.

    Both sub-blocks are wrapped in a pre-norm and a post-norm, and the
    post-normed output is added to the residual stream.  Intermediate
    results are moved back to the device and dtype of the layer input
    before they are combined.
    """

    def __init__(self, config: Grok2Config, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        eps = config.rms_norm_eps
        width = config.hidden_size
        self.pre_attn_norm = Grok2RMSNorm(width, eps)
        self.self_attn = Grok2Attention(config)
        self.post_attn_norm = Grok2RMSNorm(width, eps)
        self.pre_moe_norm = Grok2RMSNorm(width, eps)
        self.block_sparse_moe = Grok2SparseMoE(config)
        self.mlp = Grok2MLP(config)
        self.post_moe_norm = Grok2RMSNorm(width, eps)

    def forward(self, hidden_states, attention_mask=None, **kwargs):
        """Apply the layer to ``hidden_states``; ``**kwargs`` are ignored."""
        # Device/dtype of the residual stream; everything is brought back here.
        home = {"device": hidden_states.device, "dtype": hidden_states.dtype}

        # Attention block
        attn_out = self.self_attn(
            self.pre_attn_norm(hidden_states),
            attention_mask=attention_mask,
        )
        attn_out = self.post_attn_norm(attn_out.to(**home))
        hidden_states = hidden_states + attn_out.to(**home)

        # MoE + dense residual block: both branches read the same normed input.
        normed = self.pre_moe_norm(hidden_states)
        sparse_out = self.block_sparse_moe(normed)
        dense_out = self.mlp(normed)
        ffn_out = self.post_moe_norm(sparse_out.to(**home) + dense_out.to(**home))
        return hidden_states + ffn_out.to(**home)
# ── Model ─────────────────────────────────────────────────────────────────────
class Grok2Model(nn.Module):
    """Token embedding, the stack of decoder layers, and a final RMSNorm."""

    def __init__(self, config: Grok2Config):
        super().__init__()
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.embedding_multiplier_scale = config.embedding_multiplier_scale
        self.layers = nn.ModuleList(
            Grok2DecoderLayer(config, layer_idx)
            for layer_idx in range(config.num_hidden_layers)
        )
        self.norm = Grok2RMSNorm(config.hidden_size, config.rms_norm_eps)

    def forward(self, input_ids, attention_mask=None, **kwargs):
        """Embed ``input_ids``, run every layer, and return the normed states.

        ``attention_mask`` is passed unchanged to each layer; ``**kwargs``
        are ignored.
        """
        embedded = self.embed_tokens(input_ids)
        hidden_states = embedded * self.embedding_multiplier_scale
        for decoder_layer in self.layers:
            hidden_states = decoder_layer(hidden_states, attention_mask=attention_mask)
        return self.norm(hidden_states)
# ── CausalLM ──────────────────────────────────────────────────────────────────
class Grok1ForCausalLM(PreTrainedModel, GenerationMixin):
    """Grok-2 decoder with a language-modelling head.

    Logits are scaled by ``output_multiplier_scale`` and, when
    ``final_logit_softcapping`` is positive, soft-capped with tanh.
    No key/value cache is implemented: ``past_key_values`` and
    ``use_cache`` are accepted but ignored, and the returned
    ``past_key_values`` is always ``None``.
    """

    config_class = Grok2Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = False

    def __init__(self, config: Grok2Config):
        super().__init__(config)
        self.model = Grok2Model(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.output_multiplier_scale = config.output_multiplier_scale
        self.final_logit_softcapping = config.final_logit_softcapping
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def get_output_embeddings(self):
        return self.lm_head

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        past_key_values=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        **kwargs,
    ):
        """Compute next-token logits and, when ``labels`` is given, the loss.

        Args:
            input_ids: Token ids, required.
            attention_mask: Passed through to the decoder.
            past_key_values, use_cache: Ignored (no cache support).
            inputs_embeds: Not supported; the decoder only embeds ids.
            labels: Optional targets.  They are shifted by one position
                and entries equal to ``-100`` are excluded from the loss.

        Returns:
            ``CausalLMOutputWithPast`` with ``loss`` (or ``None``) and
            ``logits``.

        Raises:
            ValueError: If ``input_ids`` is ``None``.
        """
        if input_ids is None:
            # Fail with an explicit message instead of an opaque error
            # from inside the embedding lookup.
            raise ValueError(
                "input_ids is required; inputs_embeds is not supported by this model."
            )

        hidden_states = self.model(input_ids, attention_mask=attention_mask)
        # Move to lm_head device
        hidden_states = hidden_states.to(self.lm_head.weight.device)
        logits = self.lm_head(hidden_states) * self.output_multiplier_scale
        if self.final_logit_softcapping > 0:
            logits = torch.tanh(logits / self.final_logit_softcapping) * self.final_logit_softcapping

        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            # Labels normally arrive on the input device; on a sharded
            # model the logits may live elsewhere, so align them first.
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )
        return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=None)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Return the model inputs for one generation step.

        Only ``input_ids`` is forwarded, so the full sequence is re-run on
        every step.  NOTE(review): any ``attention_mask`` in ``kwargs`` is
        dropped here, so padded batches are not masked during generation.
        """
        return {"input_ids": input_ids}
# ── Register ──────────────────────────────────────────────────────────────────
# Import-time side effect: register the "grok2" model type with the
# transformers Auto classes so this config/model pair can be resolved
# through AutoConfig and AutoModelForCausalLM.
AutoConfig.register("grok2", Grok2Config)
# Map the config class to the causal-LM class defined in this file.
AutoModelForCausalLM.register(Grok2Config, Grok1ForCausalLM)