"""
modeling_grok2.py — Grok 2 for transformers, full multi-GPU support.

Parameters and activations stay in the model dtype (bf16). Numerically
sensitive reductions (RMSNorm variance, attention softmax) are evaluated in
float32 and cast back. Device-aware at every operation.
"""
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig, GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers import AutoConfig, AutoModelForCausalLM


# ── Config ────────────────────────────────────────────────────────────────────

class Grok2Config(PretrainedConfig):
    """Hyper-parameters for the Grok 2 model defined in this file.

    ``head_dim`` is derived as ``hidden_size // num_attention_heads``. Token-id
    and embedding-tying settings, plus any extra keyword arguments, are passed
    on to ``PretrainedConfig.__init__``.
    """

    model_type = "grok2"

    def __init__(
        self,
        vocab_size=131072,
        hidden_size=8192,
        num_hidden_layers=64,
        num_attention_heads=64,
        num_key_value_heads=8,
        intermediate_size=32768,
        moe_intermediate_size=16384,
        num_local_experts=8,
        num_experts_per_tok=2,
        max_position_embeddings=131072,
        rope_theta=208533496.0,
        rms_norm_eps=1e-5,
        embedding_multiplier_scale=90.50966799187809,
        output_multiplier_scale=0.5,
        final_logit_softcapping=50.0,
        attn_logit_softcapping=30.0,
        router_logit_softcapping=30.0,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        **kwargs,
    ):
        # Transformer geometry.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = hidden_size // num_attention_heads
        # Dense MLP and mixture-of-experts sizes.
        self.intermediate_size = intermediate_size
        self.moe_intermediate_size = moe_intermediate_size
        self.num_local_experts = num_local_experts
        self.num_experts_per_tok = num_experts_per_tok
        # Positional encoding and normalisation.
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta
        self.rms_norm_eps = rms_norm_eps
        # Input/output scaling and tanh soft-capping limits (<= 0 disables a cap).
        self.embedding_multiplier_scale = embedding_multiplier_scale
        self.output_multiplier_scale = output_multiplier_scale
        self.final_logit_softcapping = final_logit_softcapping
        self.attn_logit_softcapping = attn_logit_softcapping
        self.router_logit_softcapping = router_logit_softcapping
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )


# ── RMSNorm ───────────────────────────────────────────────────────────────────

class Grok2RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dim with a learned scale."""

    def __init__(self, hidden_size, eps=1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x):
        """Normalise ``x`` over its last dim; the result keeps ``x``'s dtype/device."""
        input_dtype = x.dtype
        # A mean of hidden_size squares loses precision in bf16 (and can
        # overflow in fp16), so reduce in float32 and cast back afterwards.
        x32 = x.float()
        variance = x32.pow(2).mean(-1, keepdim=True)
        normed = (x32 * torch.rsqrt(variance + self.eps)).to(input_dtype)
        return self.weight.to(normed.device, input_dtype) * normed


# ── RoPE ──────────────────────────────────────────────────────────────────────

def rotate_half(x):
    """Map the last-dim halves ``[x1, x2]`` to ``[-x2, x1]``."""
    x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return torch.cat([-x2, x1], dim=-1)


def apply_rotary_emb(q, k, cos, sin):
    """Apply rotary position embedding to ``q`` and ``k``; returns ``(q_rot, k_rot)``."""
    return (q * cos) + (rotate_half(q) * sin), \
           (k * cos) + (rotate_half(k) * sin)


class Grok2RotaryEmbedding(nn.Module):
    """Produce rotary cos/sin tables shaped ``(1, 1, seq_len, dim)``.

    The tables are cached and rebuilt whenever a longer sequence, a different
    device, or a different dtype is requested.
    """

    def __init__(self, dim, max_pos=131072, base=208533496.0, scaling_factor=16.0):
        super().__init__()
        # NOTE(review): the effective RoPE base is base * scaling_factor;
        # confirm this matches the reference Grok 2 positional scaling.
        base = base * scaling_factor
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._cached_len = 0
        self._cached_device = None
        self._cached_dtype = None
        self._cos = None
        self._sin = None

    def _build_cache(self, seq_len, device, dtype):
        # Angles are computed in float32 and only the final tables are cast.
        t = torch.arange(seq_len, device=device).float()
        freqs = torch.outer(t, self.inv_freq.to(device))
        emb = torch.cat([freqs, freqs], dim=-1)
        self._cos = emb.cos().to(dtype)[None, None, :, :]
        self._sin = emb.sin().to(dtype)[None, None, :, :]
        self._cached_len = seq_len
        self._cached_device = device
        self._cached_dtype = dtype

    def forward(self, seq_len, device, dtype):
        """Return ``(cos, sin)`` for positions ``0..seq_len-1`` on ``device`` in ``dtype``."""
        stale = (
            self._cos is None
            or seq_len > self._cached_len
            or device != self._cached_device
            # Previously missing: a dtype change would return stale-dtype tables.
            or dtype != self._cached_dtype
        )
        if stale:
            self._build_cache(seq_len, device, dtype)
        return self._cos[:, :, :seq_len, :], self._sin[:, :, :seq_len, :]


# ── Attention ─────────────────────────────────────────────────────────────────

class Grok2Attention(nn.Module):
    """Causal grouped-query self-attention with RoPE and tanh logit soft-capping.

    There is no KV cache: every call attends over exactly the ``T`` positions it
    receives, and rotary positions always start at 0.
    """

    def __init__(self, config: Grok2Config):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.head_dim = config.head_dim
        # Number of query heads that share each key/value head.
        self.num_kv_groups = self.num_heads // self.num_kv_heads
        self.attn_softcap = config.attn_logit_softcapping

        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * config.head_dim, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * config.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * config.head_dim, bias=False)
        self.o_proj = nn.Linear(config.num_attention_heads * config.head_dim, config.hidden_size, bias=False)
        self.rotary_emb = Grok2RotaryEmbedding(config.head_dim, config.max_position_embeddings, config.rope_theta)

    @staticmethod
    def _additive_mask(attention_mask, device, dtype):
        """Return ``attention_mask`` as an additive score bias on ``device`` in ``dtype``.

        A 2-D ``(batch, key_len)`` padding mask of 1 (attend) / 0 (ignore) is
        converted to a ``(batch, 1, 1, key_len)`` bias of 0 / ``finfo(dtype).min``.
        A finite minimum is used instead of -inf so that a row whose keys are
        all padding yields a uniform softmax rather than NaN. Masks of any
        other rank are assumed to already be additive and are only moved/cast.
        """
        mask = attention_mask.to(device=device)
        if mask.dim() == 2:
            keep = mask[:, None, None, :].to(dtype)
            return (1.0 - keep) * torch.finfo(dtype).min
        return mask.to(dtype=dtype)

    def forward(self, hidden_states, attention_mask=None, **kwargs):
        """Attend over ``hidden_states`` of shape ``(B, T, hidden_size)``.

        ``attention_mask`` may be a 2-D 1/0 padding mask or an additive bias
        broadcastable to ``(B, heads, T, T)``. Returns ``(B, T, hidden_size)``.
        """
        B, T, _ = hidden_states.shape
        device = hidden_states.device
        dtype = hidden_states.dtype

        q = self.q_proj(hidden_states).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(hidden_states).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(T, device, dtype)
        cos = cos[:, :, :T, :self.head_dim]
        sin = sin[:, :, :T, :self.head_dim]
        q, k = apply_rotary_emb(q, k, cos, sin)

        # GQA: repeat each key/value head for the query heads that share it.
        k = k.repeat_interleave(self.num_kv_groups, dim=1)
        v = v.repeat_interleave(self.num_kv_groups, dim=1)

        attn = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if self.attn_softcap > 0:
            attn = torch.tanh(attn / self.attn_softcap) * self.attn_softcap

        # Masks are added after soft-capping so that -inf/min survive the tanh.
        causal = torch.triu(
            torch.full((T, T), float("-inf"), device=device, dtype=dtype), diagonal=1
        )
        attn = attn + causal[None, None, :, :]
        if attention_mask is not None:
            attn = attn + self._additive_mask(attention_mask, device, dtype)

        # Normalise in float32 for stability, then return to the model dtype.
        attn = F.softmax(attn, dim=-1, dtype=torch.float32).to(dtype)
        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(B, T, -1)
        return self.o_proj(out)
# ── MoE Expert ────────────────────────────────────────────────────────────────

class Grok2Expert(nn.Module):
    """One SwiGLU feed-forward expert: ``w2(silu(w1(x)) * w3(x))``."""

    def __init__(self, hidden_size, moe_intermediate_size):
        super().__init__()
        self.w1 = nn.Linear(hidden_size, moe_intermediate_size, bias=False)
        self.w2 = nn.Linear(moe_intermediate_size, hidden_size, bias=False)
        self.w3 = nn.Linear(hidden_size, moe_intermediate_size, bias=False)

    def forward(self, x):
        """Apply the expert; the result lives on ``w2``'s device."""
        # A device map may place w1/w3/w2 on different devices, so move the
        # activations next to each weight right before it is used.
        gate = F.silu(self.w1(x.to(self.w1.weight.device)))
        up = self.w3(x.to(self.w3.weight.device))
        out_device = self.w2.weight.device
        return self.w2(gate.to(out_device) * up.to(out_device))


# ── Sparse MoE ────────────────────────────────────────────────────────────────

class Grok2SparseMoE(nn.Module):
    """Top-k routed mixture of ``Grok2Expert`` blocks.

    Router logits are tanh soft-capped (when enabled) and softmaxed. The top-k
    probabilities are renormalised to sum to 1 per token and used to weight the
    chosen experts' outputs.
    """

    def __init__(self, config: Grok2Config):
        super().__init__()
        self.num_experts = config.num_local_experts
        self.top_k = config.num_experts_per_tok
        self.router_softcap = config.router_logit_softcapping
        self.gate = nn.Linear(config.hidden_size, config.num_local_experts, bias=False)
        self.experts = nn.ModuleList([
            Grok2Expert(config.hidden_size, config.moe_intermediate_size)
            for _ in range(config.num_local_experts)
        ])

    def forward(self, x):
        """Route each token of ``x`` ``(B, T, H)`` to its experts; returns ``(B, T, H)``."""
        B, T, H = x.shape
        device = x.device
        dtype = x.dtype
        x_flat = x.reshape(-1, H)

        router_logits = self.gate(x_flat)
        if self.router_softcap > 0:
            router_logits = torch.tanh(router_logits / self.router_softcap) * self.router_softcap
        # Routing probabilities in float32; only the final weights are cast.
        router_weights = F.softmax(router_logits, dim=-1, dtype=torch.float32)
        top_weights, top_indices = router_weights.topk(self.top_k, dim=-1)
        top_weights = top_weights / top_weights.sum(dim=-1, keepdim=True)
        # The gate may run on another device; keep routing data next to x.
        top_weights = top_weights.to(device=device, dtype=dtype)
        top_indices = top_indices.to(device)

        out = torch.zeros_like(x_flat)
        for e, expert in enumerate(self.experts):
            # All (token, slot) pairs routed to expert e. topk returns distinct
            # experts per token, so token_idx contains no duplicates.
            token_idx, slot_idx = torch.where(top_indices == e)
            if token_idx.numel() == 0:
                continue
            expert_device = next(expert.parameters()).device
            expert_out = expert(x_flat[token_idx].to(device=expert_device, dtype=dtype))
            weights = top_weights[token_idx, slot_idx].unsqueeze(-1)
            out.index_add_(0, token_idx, weights * expert_out.to(device=device, dtype=dtype))

        return out.view(B, T, H)


# ── Dense MLP ─────────────────────────────────────────────────────────────────

class Grok2MLP(nn.Module):
    """Dense SwiGLU MLP: ``down(silu(gate(x)) * up(x))``."""

    def __init__(self, config: Grok2Config):
        super().__init__()
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)

    def forward(self, x):
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))


# ── Decoder Layer ─────────────────────────────────────────────────────────────

class Grok2DecoderLayer(nn.Module):
    """One transformer block with sandwich norms around attention and the FFN.

    The feed-forward part sums a sparse MoE and a dense MLP that both read the
    same pre-normed input.
    """

    def __init__(self, config: Grok2Config, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.pre_attn_norm = Grok2RMSNorm(config.hidden_size, config.rms_norm_eps)
        self.self_attn = Grok2Attention(config)
        self.post_attn_norm = Grok2RMSNorm(config.hidden_size, config.rms_norm_eps)
        self.pre_moe_norm = Grok2RMSNorm(config.hidden_size, config.rms_norm_eps)
        self.block_sparse_moe = Grok2SparseMoE(config)
        self.mlp = Grok2MLP(config)
        self.post_moe_norm = Grok2RMSNorm(config.hidden_size, config.rms_norm_eps)

    def forward(self, hidden_states, attention_mask=None, **kwargs):
        """Run the block; the output stays on the input's device and dtype."""
        device = hidden_states.device
        dtype = hidden_states.dtype

        # Attention: pre-norm -> attention -> post-norm -> residual add.
        residual = hidden_states
        hidden_states = self.pre_attn_norm(hidden_states)
        hidden_states = self.self_attn(hidden_states, attention_mask=attention_mask)
        hidden_states = self.post_attn_norm(hidden_states.to(device=device, dtype=dtype))
        hidden_states = residual + hidden_states.to(device=device, dtype=dtype)

        # Feed-forward: (MoE + dense MLP) on the same normed input -> post-norm -> residual add.
        residual = hidden_states
        normed = self.pre_moe_norm(hidden_states)
        moe_out = self.block_sparse_moe(normed)
        mlp_out = self.mlp(normed)
        combined = moe_out.to(device=device, dtype=dtype) + mlp_out.to(device=device, dtype=dtype)
        hidden_states = self.post_moe_norm(combined)
        hidden_states = residual + hidden_states.to(device=device, dtype=dtype)
        return hidden_states


# ── Model ─────────────────────────────────────────────────────────────────────

class Grok2Model(nn.Module):
    """Embedding (scaled), a stack of decoder layers, and a final RMSNorm."""

    def __init__(self, config: Grok2Config):
        super().__init__()
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.embedding_multiplier_scale = config.embedding_multiplier_scale
        self.layers = nn.ModuleList([
            Grok2DecoderLayer(config, i) for i in range(config.num_hidden_layers)
        ])
        self.norm = Grok2RMSNorm(config.hidden_size, config.rms_norm_eps)

    @staticmethod
    def _to_additive_mask(attention_mask, dtype):
        """Turn a 2-D ``(batch, seq)`` 1/0 padding mask into an additive bias.

        Returns a ``(batch, 1, 1, seq)`` tensor holding 0 for kept keys and
        ``torch.finfo(dtype).min`` for padded keys, which broadcasts over heads
        and query positions. ``None`` and masks of any other rank are returned
        unchanged (treated as already additive).
        """
        if attention_mask is None or attention_mask.dim() != 2:
            return attention_mask
        keep = attention_mask[:, None, None, :].to(dtype)
        return (1.0 - keep) * torch.finfo(dtype).min

    def forward(self, input_ids=None, attention_mask=None, inputs_embeds=None, **kwargs):
        """Encode ``input_ids`` (or precomputed ``inputs_embeds`` when ids are absent).

        Raises:
            ValueError: if neither ``input_ids`` nor ``inputs_embeds`` is given.
        """
        if input_ids is not None:
            inputs_embeds = self.embed_tokens(input_ids)
        elif inputs_embeds is None:
            raise ValueError("Either input_ids or inputs_embeds must be provided.")
        hidden_states = inputs_embeds * self.embedding_multiplier_scale
        # Convert the padding mask once; each attention layer moves it to its device.
        attention_mask = self._to_additive_mask(attention_mask, hidden_states.dtype)
        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask=attention_mask)
        return self.norm(hidden_states)


# ── CausalLM ──────────────────────────────────────────────────────────────────

class Grok1ForCausalLM(PreTrainedModel, GenerationMixin):
    """Grok 2 causal LM head: scaled, tanh soft-capped logits over the vocabulary.

    No KV cache is implemented. Generation re-encodes the full sequence at each
    step, and ``past_key_values`` is always returned as ``None``.
    """

    config_class = Grok2Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = False

    def __init__(self, config: Grok2Config):
        super().__init__(config)
        self.model = Grok2Model(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.output_multiplier_scale = config.output_multiplier_scale
        self.final_logit_softcapping = config.final_logit_softcapping
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        past_key_values=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        **kwargs,
    ):
        """Compute logits and, when ``labels`` are given, the shifted next-token loss.

        ``past_key_values`` and ``use_cache`` are accepted for API compatibility
        but ignored (no KV cache).
        """
        hidden_states = self.model(input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds)

        # lm_head may live on a different device than the last decoder layer.
        hidden_states = hidden_states.to(self.lm_head.weight.device)
        logits = self.lm_head(hidden_states) * self.output_multiplier_scale
        if self.final_logit_softcapping > 0:
            logits = torch.tanh(logits / self.final_logit_softcapping) * self.final_logit_softcapping

        loss = None
        if labels is not None:
            # Cross-entropy in float32 for precision; labels follow the logits' device.
            shift_logits = logits[..., :-1, :].float()
            shift_labels = labels[..., 1:].to(shift_logits.device)
            loss = F.cross_entropy(
                shift_logits.reshape(-1, shift_logits.size(-1)),
                shift_labels.reshape(-1),
                ignore_index=-100,
            )

        return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=None)

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        # No KV cache: the whole sequence is re-encoded each step, so the full
        # padding mask must travel with it (previously it was dropped).
        return {"input_ids": input_ids, "attention_mask": attention_mask}


# Name that matches the architecture; the original class name is kept for
# existing references.
Grok2ForCausalLM = Grok1ForCausalLM


# ── Register ──────────────────────────────────────────────────────────────────

AutoConfig.register("grok2", Grok2Config)
AutoModelForCausalLM.register(Grok2Config, Grok1ForCausalLM)