# Source: Hugging Face Hub upload by AlexWortega
# ("Upload model.py with huggingface_hub", commit 8314313, verified).
"""200M-active / ~1.2B-total MoE model — fresh-pretrain target.
Architecture: same DeepSeekMoE-style 1-shared + 32-routed top-2 MoE,
GQA attention with QK-Norm and partial RoPE, tied embed/lm_head.
Two material differences from the 100M-active sibling at
`~/ml-intern-runs/moe-100m-volta-week/model.py`:
1. Bigger config defaults (vocab=151 936, d_model=640, n_layers=16,
n_q_heads=10, n_kv_heads=2, head_dim=64, d_ff=1024,
n_routed_experts=16 (Variant A; originally 32), top_k=2, moe_first_layer=1).
2. **Tiled cross-entropy loss.** Vocab=151 936 × micro_bs=8 ×
seq_len=2048 in fp16 ≈ 4.7 GB just for the logits, again the same
for softmax intermediates. Instead we never materialize the full
logit tensor: we tile the post-final-norm hidden state into
`ce_chunk_tokens`-sized slices along (B·S), do the per-slice
`F.linear(h, embed.weight) → logits` and `F.cross_entropy(..., reduction='sum')`
inside a `torch.utils.checkpoint`, sum the partial CE values, and divide by
the number of non-ignored labels. Peak
resident logit buffer is `seq_chunk_size · vocab · 4` (fp32 for
numerical safety) ≈ 0.3 GB at chunk=512.
The chunked-CE path is mathematically equivalent to a single
`F.cross_entropy` on the full `(N, V)` logits with `reduction='mean'`,
to fp32 precision — verified by `tests/test_chunked_ce.py`.
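The single-shot reference path is kept gated behind
`MoEModelConfig.use_chunked_ce=False` for the verification test only.
When `use_liger_ce=True` (the default) and `liger_kernel` is importable, the
training loss (no logits requested) is instead computed by Liger's fused
linear+CE kernel, with the chunked path as the fallback.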
"""
from __future__ import annotations
import math
from dataclasses import dataclass, asdict
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as ckpt
# ============================ Config ============================
@dataclass
class MoEModelConfig:
    """Hyper-parameters for ``MoEModel``.

    The defaults are the ~200M-active "Variant A" setup described in the
    comments below: 16 layers of width 640, GQA attention (10 query / 2 KV
    heads of size 64), and 1 shared + 16 routed SwiGLU experts with top-2
    routing from layer ``moe_first_layer`` onward.
    """

    # ---- trunk ----
    vocab_size: int = 151936
    d_model: int = 640
    n_layers: int = 16
    # ---- attention ----
    n_q_heads: int = 10
    n_kv_heads: int = 2  # must divide n_q_heads (grouped-query attention)
    head_dim: int = 64
    rope_partial: int = 32  # leading head dims that receive RoPE (even, <= head_dim)
    rope_theta: float = 10000.0
    # ---- FFN / MoE ----
    d_ff: int = 1024
    # Variant A (router-stability rescue): dropped 32 -> 16 routed experts
    # after two NaN cascades on 32+1. d_ff stays at 1024; active stays at
    # ~200M, total drops 1.07B -> ~620M. Half the router load = 2x easier
    # to balance under fp16 on V100 (no bf16 tensor cores), which lets us
    # use moderate aux/bias rather than the aggressive ones that NaN'd.
    n_routed_experts: int = 16
    n_shared_experts: int = 1
    top_k: int = 2
    moe_first_layer: int = 1  # layers with index >= this use MoE; earlier ones a dense FFN
    router_z_coef: float = 1e-3
    # Additive Gaussian noise applied to ``sel_logits`` (logit + bias) during
    # training, before the top-k pick. Breaks routing lock-in so dead experts
    # can occasionally win top-2 and the bias controller has something to
    # work with. Set non-zero during a router-recovery resume. 0.0 = noise
    # off (standard inference + post-recovery training). Eval is always
    # noise-free regardless of this value.
    router_noise_std: float = 0.0
    # Variant A on 2 GPUs (CUDA_VISIBLE_DEVICES=2,3): moderate coeffs.
    # 1e-3 aux + 1e-3 bias is 10x lower than the aux=1e-2 / bias=5e-3 that
    # NaN'd on 32+1 experts, and we have 16 experts now (2x easier to
    # balance). Half DDP all-reduce noise from 2 vs 4 GPUs further helps
    # router stability.
    # Bias controller (SigmoidRouter.step_bias_update) uses the fractional
    # load error ``bias += bias_update_rate * (1/E - c_i / sum(c))``, adds
    # +0.05 to experts below 10% of their fair share, and clamps the bias to
    # [-10, 10]. The older magnitude-based ``err = (mean - c) / mean`` formula
    # is no longer used.
    router_aux_coef: float = 1e-3
    bias_update_rate: float = 1e-3
    max_seq_len: int = 2048  # also sizes the cached RoPE cos/sin tables
    tie_embeddings: bool = True  # reuse embed.weight as the LM head
    rms_eps: float = 1e-6
    init_std: float = 0.02  # std of the normal init for nn.Embedding weights
    mup_base_d: int = 512  # logits are scaled by sqrt(mup_base_d / d_model)
    attn_backend: str = "sdpa"  # "sdpa" or "fa_volta" (Triton FA fwd+bwd on V100)
    moe_backend: str = "grouped"  # "bmm" = per-expert for-loop (legacy);
    #                               "grouped" = stacked-weight bmm (fast)
    moe_capacity_factor: float = 1.25  # only used when moe_backend="grouped".
    # 1.0 = no padding (drops overflow);
    # 1.25 = ~6% drops at CV=0.5 (acceptable);
    # 2.0 = no drops up to CV≈0.5 but 2x bmm work.
    smear_gate: bool = True  # learned per-KV-head scale on the value vectors
    use_chunked_ce: bool = True
    ce_chunk_tokens: int = 512  # per-chunk token count for tiled CE
    ce_checkpoint_chunks: bool = True  # recompute each CE chunk in backward
    # Pass-2 optimization: route CE through Liger's fused linear+CE kernel
    # which never materializes the (N, V) logit tensor. Falls back to the
    # chunked CE path above if liger_kernel is not importable. Mathematically
    # equivalent to `cross_entropy(F.linear(h, embed.T) * mup, labels)` to
    # fp32 precision (verified by tests/test_liger_ce.py).
    use_liger_ce: bool = True

    def as_dict(self):
        """Return the config as a plain ``dict`` via ``dataclasses.asdict``."""
        return asdict(self)
def small_config(**overrides) -> MoEModelConfig:
    """Build a default ``MoEModelConfig`` with selected fields overridden.

    Args:
        **overrides: field-name -> value pairs applied on top of the defaults.

    Returns:
        A fresh config instance.

    Raises:
        ValueError: if a key is not a dataclass field of ``MoEModelConfig``.
            Only real fields are accepted; a plain ``hasattr`` check would also
            accept method names such as ``as_dict`` and silently replace the
            method with data.
    """
    from dataclasses import fields

    cfg = MoEModelConfig()
    valid_keys = {f.name for f in fields(cfg)}
    for k, v in overrides.items():
        if k not in valid_keys:
            raise ValueError(f"unknown config key: {k}")
        setattr(cfg, k, v)
    return cfg
# ============================ Norms ============================
class RMSNorm(nn.Module):
    """Root-mean-square norm over the last dim with a learned per-channel scale.

    Statistics are computed in float32. The normalized activations are cast
    back to the input dtype *before* the weight multiply, so the output dtype
    follows ordinary type promotion between that dtype and ``weight``.
    """

    def __init__(self, d: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d))
        self.eps = eps

    def forward(self, x):
        orig_dtype = x.dtype
        xf = x.float()
        inv_rms = torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        normed = (xf * inv_rms).to(orig_dtype)
        return normed * self.weight
class QKNorm(nn.Module):
    """Per-head RMS normalization for query/key tensors.

    Inputs must end in ``(n_heads, head_dim)`` (e.g. ``(B, S, H, D)``). After a
    float32 RMS normalization over ``head_dim``, the result (back in the input
    dtype) is scaled by a learned ``(H, D)`` weight and a learned per-head
    ``(H, 1)`` gain.
    """

    def __init__(self, n_heads, head_dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(n_heads, head_dim))
        self.gain = nn.Parameter(torch.ones(n_heads, 1))
        self.eps = eps

    def forward(self, x):
        in_dtype = x.dtype
        xf = x.float()
        scale = torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        normed = (xf * scale).to(in_dtype)
        # Leading singleton dims broadcast over (batch, seq).
        per_channel = self.weight[None, None]
        per_head = self.gain[None, None]
        return normed * per_channel * per_head
# ============================ RoPE ============================
def _build_cos_sin(seq_len, dim, theta, device, dtype):
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
t = torch.arange(seq_len, device=device, dtype=torch.float32)
freqs = torch.outer(t, inv_freq)
cos = freqs.cos().repeat_interleave(2, dim=-1)
sin = freqs.sin().repeat_interleave(2, dim=-1)
return cos.to(dtype), sin.to(dtype)
def _rotate_half_pairs(x):
x1 = x[..., 0::2]
x2 = x[..., 1::2]
return torch.stack((-x2, x1), dim=-1).flatten(-2)
class PartialRoPE(nn.Module):
    """Rotary position embedding applied to the first ``rope_dim`` channels.

    Channels ``[0, rope_dim)`` of every head are rotated in interleaved
    (even, odd) pairs; channels ``[rope_dim, head_dim)`` pass through
    unchanged. cos/sin tables for ``max_seq_len`` positions are built once in
    float32 and registered as non-persistent buffers (they follow module
    device moves but are not written to the state dict).

    Args:
        head_dim: channels per head.
        rope_dim: leading channels that receive RoPE; must be even and
            ``<= head_dim``.
        max_seq_len: number of positions in the cached tables.
        theta: RoPE base.

    Raises:
        ValueError: if ``rope_dim`` is odd or larger than ``head_dim``.
    """

    def __init__(self, head_dim, rope_dim, max_seq_len, theta=10000.0):
        super().__init__()
        # Explicit raise rather than ``assert`` so the check survives ``python -O``.
        if rope_dim > head_dim or rope_dim % 2 != 0:
            raise ValueError(
                f"rope_dim must be even and <= head_dim "
                f"(got rope_dim={rope_dim}, head_dim={head_dim})"
            )
        self.head_dim = head_dim
        self.rope_dim = rope_dim
        self.max_seq_len = max_seq_len
        self.theta = theta
        cos, sin = _build_cos_sin(max_seq_len, rope_dim, theta, "cpu", torch.float32)
        self.register_buffer("cos_cached", cos, persistent=False)
        self.register_buffer("sin_cached", sin, persistent=False)

    def forward(self, q, k, position_ids=None):
        """Apply RoPE to ``q`` and ``k``.

        Args:
            q, k: 4-D tensors laid out as ``(B, S, H, head_dim)``; dim 1 is the
                sequence axis.
            position_ids: optional integer positions, ``(S,)`` or ``(B, S)``.
                Defaults to ``0 .. S-1``.

        Returns:
            ``(q_rot, k_rot)`` with the input shapes and dtypes.

        Raises:
            ValueError: on the default-position path, if ``S`` exceeds
                ``max_seq_len`` (the cached tables would be too short).
        """
        S = q.size(1)
        if position_ids is None:
            if S > self.max_seq_len:
                raise ValueError(
                    f"sequence length {S} exceeds the cached RoPE table "
                    f"(max_seq_len={self.max_seq_len})"
                )
            cos = self.cos_cached[:S].to(q.dtype)
            sin = self.sin_cached[:S].to(q.dtype)
        else:
            cos = self.cos_cached[position_ids].to(q.dtype)
            sin = self.sin_cached[position_ids].to(q.dtype)
        # Shape tables as (batch_or_1, S, 1, rope_dim) to broadcast over heads.
        if cos.dim() == 2:
            cos = cos.view(1, S, 1, self.rope_dim)
            sin = sin.view(1, S, 1, self.rope_dim)
        else:
            cos = cos.view(cos.size(0), S, 1, self.rope_dim)
            sin = sin.view(sin.size(0), S, 1, self.rope_dim)

        def _apply(x):
            x_rot = x[..., :self.rope_dim]
            x_pass = x[..., self.rope_dim:]
            x_rot = x_rot * cos + _rotate_half_pairs(x_rot) * sin
            return torch.cat([x_rot, x_pass], dim=-1)

        return _apply(q), _apply(k)
# ============================ Attention ============================
def _repeat_kv(x, n_rep):
if n_rep == 1: return x
B, S, H, D = x.shape
return x[:, :, :, None, :].expand(B, S, H, n_rep, D).reshape(B, S, H * n_rep, D)
class GQAAttention(nn.Module):
    """Grouped-query self-attention with QK-Norm, partial RoPE and an optional
    learned per-KV-head gain ("smear") on the values.

    ``n_q_heads`` query heads share ``n_kv_heads`` key/value heads; K/V are
    repeated ``n_q_heads // n_kv_heads`` times before attention.

    Raises:
        ValueError: if ``n_q_heads`` is not a multiple of ``n_kv_heads``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.n_q, self.n_kv, self.d_h = cfg.n_q_heads, cfg.n_kv_heads, cfg.head_dim
        if self.n_q % self.n_kv != 0:
            raise ValueError(
                f"n_q_heads ({self.n_q}) must be a multiple of n_kv_heads ({self.n_kv})"
            )
        self.n_rep = self.n_q // self.n_kv
        d = cfg.d_model
        self.q_proj = nn.Linear(d, self.n_q * self.d_h, bias=False)
        self.k_proj = nn.Linear(d, self.n_kv * self.d_h, bias=False)
        self.v_proj = nn.Linear(d, self.n_kv * self.d_h, bias=False)
        self.o_proj = nn.Linear(self.n_q * self.d_h, d, bias=False)
        self.q_norm = QKNorm(self.n_q, self.d_h, eps=cfg.rms_eps)
        self.k_norm = QKNorm(self.n_kv, self.d_h, eps=cfg.rms_eps)
        self.rope = PartialRoPE(self.d_h, cfg.rope_partial, cfg.max_seq_len, cfg.rope_theta)
        if cfg.smear_gate:
            self.smear = nn.Parameter(torch.ones(self.n_kv))
        else:
            self.smear = None

    def forward(self, x, attn_mask=None):
        """Attend over ``x`` of shape ``(B, S, d_model)``.

        Args:
            x: input activations.
            attn_mask: ``None`` for causal self-attention. Otherwise a mask in
                the format accepted by ``F.scaled_dot_product_attention``
                (boolean "True = attend" or additive float), broadcastable to
                ``(B, n_q_heads, S, S)``; it replaces the causal mask.

        Returns:
            ``(B, S, d_model)`` output of ``o_proj``.
        """
        B, S, _ = x.shape
        q = self.q_proj(x).view(B, S, self.n_q, self.d_h)
        k = self.k_proj(x).view(B, S, self.n_kv, self.d_h)
        v = self.v_proj(x).view(B, S, self.n_kv, self.d_h)
        q = self.q_norm(q)
        k = self.k_norm(k)
        q, k = self.rope(q, k)
        if self.smear is not None:
            v = v * self.smear.view(1, 1, self.n_kv, 1)
        backend = getattr(self.cfg, "attn_backend", "sdpa")
        # The FA-Volta Triton kernel is half-precision only, expects the
        # kv-repeated (B, S, H, D) layout, and takes no explicit mask, so it
        # is only used for plain causal attention. SDPA uses (B, H, S, D).
        use_fa = (
            backend == "fa_volta"
            and attn_mask is None
            and q.dtype in (torch.float16, torch.bfloat16)
        )
        if use_fa:
            from flash_attn_volta.autograd import flash_attn
            k_rep = _repeat_kv(k, self.n_rep)
            v_rep = _repeat_kv(v, self.n_rep)
            out = flash_attn(q.contiguous(), k_rep.contiguous(), v_rep.contiguous(),
                             causal=True)
            out = out.contiguous().view(B, S, self.n_q * self.d_h)
        else:
            qh = q.transpose(1, 2)
            kh = _repeat_kv(k, self.n_rep).transpose(1, 2)
            vh = _repeat_kv(v, self.n_rep).transpose(1, 2)
            # The mask must be forwarded: dropping it (with is_causal=False)
            # would silently yield unmasked, bidirectional attention.
            out = F.scaled_dot_product_attention(
                qh, kh, vh, attn_mask=attn_mask, is_causal=attn_mask is None,
            )
            out = out.transpose(1, 2).contiguous().view(B, S, self.n_q * self.d_h)
        return self.o_proj(out)
# ============================ Experts / Router / MoE ============================
class SwiGLUExpert(nn.Module):
    """Bias-free SwiGLU feed-forward: ``down(silu(gate(x)) * up(x))``."""

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.gate = nn.Linear(d_model, d_ff, bias=False)
        self.up = nn.Linear(d_model, d_ff, bias=False)
        self.down = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x):
        activated = F.silu(self.gate(x))
        return self.down(activated * self.up(x))
class SigmoidRouter(nn.Module):
    """Top-k sigmoid router with a non-learned load-balancing bias.

    Expert *selection* uses ``logits + bias`` (plus optional Gaussian noise in
    training mode). The combine *weights* are the un-biased sigmoid scores of
    the selected experts, renormalized to sum to 1 per token. ``bias`` is a
    buffer moved outside autograd by :meth:`step_bias_update`.

    Args:
        d_model: input width.
        n_experts: number of routed experts.
        top_k: experts selected per token.
        z_coef, aux_coef: stored as attributes only; ``forward`` returns the
            raw, un-scaled ``z_loss`` / ``aux_loss``.
        bias_update_rate: step size used by :meth:`step_bias_update`.
        noise_std: std of the selection noise; 0 disables it.
    """

    def __init__(self, d_model, n_experts, top_k,
                 z_coef=1e-3, aux_coef=1e-3, bias_update_rate=1e-3,
                 noise_std=0.0):
        super().__init__()
        self.n_experts = n_experts
        self.top_k = top_k
        self.w = nn.Parameter(torch.zeros(n_experts, d_model))
        nn.init.normal_(self.w, std=0.02)
        # Selection-only bias: a buffer (no gradient).
        self.register_buffer("bias", torch.zeros(n_experts))
        self.z_coef = z_coef
        self.aux_coef = aux_coef
        self.bias_update_rate = bias_update_rate
        self.noise_std = noise_std

    def forward(self, x_flat):
        """Route a flat ``(N, d_model)`` batch of tokens.

        Returns:
            topk_idx: ``(N, top_k)`` selected expert indices.
            topk_weight: ``(N, top_k)`` float32 combine weights (rows sum to ~1).
            stats: dict with
                ``z_loss`` — mean squared logsumexp of the router logits;
                ``aux_loss`` — ``E * sum_i f_i * P_i`` where ``f_i`` is the
                (non-differentiable) per-expert selection rate and ``P_i`` the
                mean sigmoid score, which carries gradient into ``self.w``;
                ``counts`` — ``(n_experts,)`` float tokens routed per expert;
                ``router_cv`` — std/mean of ``counts``;
                ``router_entropy_bits`` — entropy of the normalized mean score.
        """
        # Route entirely in fp32. ``torch.autocast(device_type=...)`` replaces
        # the deprecated, CUDA-only ``torch.cuda.amp.autocast``.
        with torch.autocast(device_type=x_flat.device.type, enabled=False):
            x32 = x_flat.float()
            logits = F.linear(x32, self.w.float())
            scores = torch.sigmoid(logits)
            sel_logits = logits + self.bias.float().unsqueeze(0)
            if self.training and self.noise_std > 0:
                # Additive Gaussian noise breaks load-imbalance lock-in.
                sel_logits = sel_logits + torch.randn_like(sel_logits) * self.noise_std
            _, topk_idx = torch.topk(sel_logits, k=self.top_k, dim=-1)
            topk_weight = scores.gather(-1, topk_idx)
            topk_weight = topk_weight / (topk_weight.sum(dim=-1, keepdim=True) + 1e-9)

            lse = torch.logsumexp(logits, dim=-1)
            z_loss = (lse ** 2).mean()

            # f_i comes from hard top-k counts (no gradient). P_i must be
            # computed OUTSIDE no_grad, otherwise aux_loss is a constant and
            # never reaches the router weights.
            with torch.no_grad():
                one_hot = F.one_hot(topk_idx, num_classes=self.n_experts).sum(dim=1)
                f_i = one_hot.float().mean(dim=0)
            p_i = scores.mean(dim=0)
            aux_loss = self.n_experts * (f_i * p_i).sum()

            # Diagnostics only.
            with torch.no_grad():
                counts = one_hot.sum(dim=0).float()
                cv = counts.std() / counts.mean().clamp_min(1.0)
                p_avg = scores.mean(dim=0).clamp_min(1e-9)
                p_avg = p_avg / p_avg.sum()
                entropy = -(p_avg * p_avg.log()).sum() / math.log(2.0)
        return topk_idx, topk_weight, {"z_loss": z_loss, "aux_loss": aux_loss,
                                       "counts": counts, "router_cv": cv,
                                       "router_entropy_bits": entropy}

    @torch.no_grad()
    def step_bias_update(self, counts):
        """Nudge the selection bias toward uniform expert load.

        ``counts`` holds the number of tokens routed to each expert (length
        ``n_experts``). With fractional load ``p_i = c_i / sum(c)`` and uniform
        target ``1/E``::

            bias += bias_update_rate * (1/E - p_i)   # positive = underloaded
            bias += 0.05  where p_i < 0.1 / E          # starved-expert boost
            bias  = clamp(bias, -10, +10)

        History: the earlier ``err = (mean - c) / max(mean, 1)`` update was
        asymmetric (overloaded experts were pushed down without bound while
        starved ones rose by at most +1 per step); on the 100B run it drove
        biases to [-23, +7] with 102 / 240 expert slots dead
        (see ``DEAD_EXPERTS.md``). The rate-proportional term alone proved too
        weak to revive starved experts, hence the constant boost; the ±10
        clamp lets the bias compete with router logits in that range while
        preventing runaway. At 0.05/step a starved expert reaches +10 in
        200 steps.

        Caller MUST all-reduce ``counts`` across ranks before calling this in
        a DDP setting — otherwise ranks compute different updates from local
        counts, and DDP's default ``broadcast_buffers=True`` then overwrites
        every rank's bias with rank 0's local view.
        """
        counts_f = counts.float()
        total = counts_f.sum().clamp_min(1.0)
        p_i = counts_f / total
        target_p = 1.0 / self.n_experts
        update = (target_p - p_i) * self.bias_update_rate
        starved = (p_i < 0.1 * target_p).float()
        update = update + starved * 0.05
        self.bias.add_(update)
        self.bias.clamp_(min=-10.0, max=10.0)
def _moe_dispatch_bmm(x, topk_idx, topk_weight, experts):
"""Per-expert dispatch — kept for parity / unit tests.
Issues a single GPU->CPU sync (via ``offsets.cpu().tolist()``) at the
start of each MoE forward so the python for-loop can slice the sorted
token list with integer offsets. Also runs a 1-token "dust" pass on
every expert each step so DDP's ``find_unused_parameters=True`` path
sees all expert grads. Slow but correct; superseded by
``_moe_dispatch_grouped`` on the fast path.
"""
N, K = topk_idx.shape
flat_expert = topk_idx.reshape(-1)
flat_weight = topk_weight.reshape(-1).to(x.dtype)
flat_token = torch.arange(N, device=x.device).repeat_interleave(K)
order = torch.argsort(flat_expert, stable=False)
flat_expert_s = flat_expert[order]
flat_token_s = flat_token[order]
flat_weight_s = flat_weight[order]
n_experts = len(experts)
counts = torch.bincount(flat_expert_s, minlength=n_experts)
offsets = torch.cumsum(counts, dim=0)
out = torch.zeros_like(x)
offsets_cpu = offsets.cpu().tolist()
counts_cpu = counts.cpu().tolist()
start = 0
x_dust = x[:1]
for e in range(n_experts):
y_dust = experts[e](x_dust)
out.index_add_(0, flat_token[:1], (y_dust * 0.0).to(out.dtype))
end = offsets_cpu[e]
if counts_cpu[e] == 0:
start = end; continue
tok_idx = flat_token_s[start:end]
w = flat_weight_s[start:end].unsqueeze(-1)
x_e = x.index_select(0, tok_idx)
y_e = experts[e](x_e)
out.index_add_(0, tok_idx, (y_e * w).to(out.dtype))
start = end
return out
def _moe_dispatch_grouped(x, topk_idx, topk_weight,
gate_w, up_w, down_w,
capacity_factor: float = 1.5):
"""Token-permuted, capacity-padded grouped-bmm MoE dispatch.
Inputs:
x: [N, d]
topk_idx: [N, K] long
topk_weight: [N, K]
gate_w, up_w: [E, d_ff, d]
down_w: [E, d, d_ff]
capacity_factor: pad each expert's slot count to
``ceil(N * K / E * capacity_factor)``. Tokens beyond capacity
are dropped (contribute 0); their topk weight is wasted but the
router still receives gradient through the still-routed top-k
partner. A factor of 1.5 leaves slack for CV up to ~1.5.
Returns:
out: [N, d] - sum over the top-k expert outputs, each multiplied
by the matching topk_weight, with dropped tokens
contributing zero.
Stays GPU-resident throughout — no .cpu() / .item() sync. Issues
3 batched-bmm kernels (gate, up, down) plus 1 sort + 1 bincount +
index_select/scatter, regardless of expert count.
"""
N, K = topk_idx.shape
E, d_ff, d = gate_w.shape
NK = N * K
capacity = max(1, int(math.ceil(NK / E * capacity_factor)))
flat_e = topk_idx.reshape(-1) # [NK]
flat_t = (torch.arange(N, device=x.device, dtype=torch.long)
.repeat_interleave(K)) # [NK]
flat_w = topk_weight.reshape(-1).to(x.dtype) # [NK]
order = torch.argsort(flat_e, stable=False)
sorted_e = flat_e[order] # [NK]
sorted_t = flat_t[order]
sorted_w = flat_w[order]
# Per-expert slot index (0..count[e]-1). Tokens with slot >= capacity
# are dropped.
counts = torch.bincount(sorted_e, minlength=E) # [E]
expert_start = counts.cumsum(0) - counts # [E]
global_pos = torch.arange(NK, device=x.device, dtype=torch.long)
slot_in_expert = global_pos - expert_start.index_select(0, sorted_e)
keep = slot_in_expert < capacity
slot_idx = sorted_e * capacity + slot_in_expert # flat [E*capacity]
kept_slot = slot_idx[keep]
kept_tok = sorted_t[keep]
kept_w = sorted_w[keep]
# Gather tokens into [E*capacity, d] dense buffer (zero where unused).
x_grouped = x.new_zeros((E * capacity, d))
x_kept = x.index_select(0, kept_tok) # [n_kept, d]
x_grouped.index_copy_(0, kept_slot, x_kept)
x_grouped = x_grouped.view(E, capacity, d)
# All-expert SwiGLU forward via 3 batched matmuls.
# gate_w: [E, d_ff, d]; x_grouped: [E, capacity, d]
g = torch.einsum("etd,efd->etf", x_grouped, gate_w)
u = torch.einsum("etd,efd->etf", x_grouped, up_w)
h = F.silu(g) * u # [E, capacity, d_ff]
y = torch.einsum("etf,edf->etd", h, down_w) # [E, capacity, d]
# Scatter back with topk weights.
y_flat = y.view(E * capacity, d)
y_kept = y_flat.index_select(0, kept_slot) * kept_w.unsqueeze(-1)
out = x.new_zeros((N, d))
out.index_add_(0, kept_tok, y_kept)
return out
class MoEFFN(nn.Module):
    """Mixture-of-experts SwiGLU FFN: optional shared expert + top-k routed experts.

    ``cfg.moe_backend`` selects the routed-expert implementation:

    * ``"bmm"`` — legacy per-expert python loop over an ``nn.ModuleList`` of
      ``SwiGLUExpert`` (``routed_experts``). Kept for the chunked-CE unit
      test and as a fallback; it is also what any value other than
      ``"grouped"`` selects.
    * ``"grouped"`` — token-permuted, capacity-padded grouped matmul
      (``_moe_dispatch_grouped``) over stacked weights ``self.gate`` /
      ``self.up`` (``[E, d_ff, d]``) and ``self.down`` (``[E, d, d_ff]``).
      ``routed_experts`` is then ``None`` and the state-dict keys are the flat
      ``gate``/``up``/``down``; ``MoEModel.load_state_dict`` stacks legacy
      ``routed_experts.i.{gate,up,down}.weight`` keys into these tensors.

    A single shared ``SwiGLUExpert`` is added to every token when
    ``n_shared_experts > 0``. ``forward`` returns ``(y, router_stats)``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.n_routed = cfg.n_routed_experts
        self.n_shared = cfg.n_shared_experts
        self.backend = getattr(cfg, "moe_backend", "bmm")
        self.capacity_factor = getattr(cfg, "moe_capacity_factor", 1.5)
        if self.backend == "grouped":
            n_e, d_model, d_ff = self.n_routed, cfg.d_model, cfg.d_ff
            # Allocated uninitialized; MoEModel._init_grouped_moe fills them.
            self.gate = nn.Parameter(torch.empty(n_e, d_ff, d_model))
            self.up = nn.Parameter(torch.empty(n_e, d_ff, d_model))
            self.down = nn.Parameter(torch.empty(n_e, d_model, d_ff))
            self.routed_experts = None
        else:
            self.routed_experts = nn.ModuleList(
                SwiGLUExpert(cfg.d_model, cfg.d_ff) for _ in range(self.n_routed)
            )
            self.gate = None
            self.up = None
            self.down = None
        self.shared_expert = (
            SwiGLUExpert(cfg.d_model, cfg.d_ff) if self.n_shared > 0 else None
        )
        self.router = SigmoidRouter(
            d_model=cfg.d_model,
            n_experts=self.n_routed,
            top_k=cfg.top_k,
            z_coef=cfg.router_z_coef,
            aux_coef=cfg.router_aux_coef,
            bias_update_rate=cfg.bias_update_rate,
            noise_std=getattr(cfg, "router_noise_std", 0.0),
        )

    def forward(self, x):
        B, S, d = x.shape
        tokens = x.reshape(B * S, d)
        topk_idx, topk_weight, aux = self.router(tokens)
        if self.backend == "grouped":
            routed = _moe_dispatch_grouped(
                tokens, topk_idx, topk_weight,
                self.gate, self.up, self.down,
                capacity_factor=self.capacity_factor,
            )
        else:
            routed = _moe_dispatch_bmm(tokens, topk_idx, topk_weight, self.routed_experts)
        y = routed if self.shared_expert is None else routed + self.shared_expert(tokens)
        return y.view(B, S, d), aux
# ============================ Block ============================
class Block(nn.Module):
    """Pre-norm transformer block: ``x + attn(norm(x))`` then ``x + ffn(norm(x))``.

    Layers with ``layer_idx >= cfg.moe_first_layer`` use ``MoEFFN``; earlier
    layers use a dense ``SwiGLUExpert``. ``forward`` returns ``(x, aux)`` where
    ``aux`` is the MoE router stats dict, or ``None`` for dense layers.
    """

    def __init__(self, cfg, layer_idx):
        super().__init__()
        self.cfg = cfg
        self.layer_idx = layer_idx
        self.attn_norm = RMSNorm(cfg.d_model, eps=cfg.rms_eps)
        self.attn = GQAAttention(cfg)
        self.ffn_norm = RMSNorm(cfg.d_model, eps=cfg.rms_eps)
        self.is_moe = layer_idx >= cfg.moe_first_layer
        self.ffn = MoEFFN(cfg) if self.is_moe else SwiGLUExpert(cfg.d_model, cfg.d_ff)

    def forward(self, x, attn_mask=None):
        h = x + self.attn(self.attn_norm(x), attn_mask=attn_mask)
        normed = self.ffn_norm(h)
        if not self.is_moe:
            return h + self.ffn(normed), None
        ffn_out, aux = self.ffn(normed)
        return h + ffn_out, aux
# ============================ Tiled CE loss ============================
def _ce_chunk_forward(h_chunk: torch.Tensor,
embed_weight: torch.Tensor,
labels_chunk: torch.Tensor,
mup_scale: float,
reduction: str = "sum") -> torch.Tensor:
"""Compute CE on a single (N_chunk, D) slice. Returns a scalar tensor
representing the *sum* of CE over the chunk's non-ignored positions
(or 'mean' if reduction='mean')."""
logits = F.linear(h_chunk, embed_weight)
logits = (logits * mup_scale).float()
return F.cross_entropy(logits, labels_chunk,
ignore_index=-100, reduction=reduction)
def tiled_cross_entropy(h_flat: torch.Tensor,
                        embed_weight: torch.Tensor,
                        labels_flat: torch.Tensor,
                        mup_scale: float,
                        chunk_size: int = 512,
                        use_checkpoint: bool = True) -> torch.Tensor:
    """Mean cross-entropy over ``labels_flat``, computed in token chunks.

    Equivalent (to fp32 precision) to::

        logits = F.linear(h_flat, embed_weight) * mup_scale
        F.cross_entropy(logits.float(), labels_flat, ignore_index=-100,
                        reduction='mean')

    except that when every label is ``-100`` this returns 0 instead of NaN
    (the divisor is clamped to at least 1). Gradient flows to ``h_flat`` and
    ``embed_weight``.

    Memory: the peak resident logit buffer is ``chunk_size · vocab · 4`` bytes;
    the full ``(N, V)`` logits are never materialized. With
    ``use_checkpoint=True`` (and a grad-requiring chunk) each chunk's
    linear + CE is wrapped in ``torch.utils.checkpoint``, so backward
    recomputes it instead of keeping logits and softmax intermediates alive;
    the cost is one extra chunk forward during backward.

    Args:
        h_flat: ``(N, D)`` hidden states.
        embed_weight: ``(V, D)`` output projection.
        labels_flat: ``(N,)`` targets; ``-100`` is ignored.
        mup_scale: scalar multiplier applied to the logits.
        chunk_size: tokens per chunk.
        use_checkpoint: checkpoint each chunk when it requires grad.

    Returns:
        0-d float32 tensor.
    """

    def _chunk_ce_sum(h_c, w_c, lbl_c, scale):
        # Summed CE of one chunk; logits upcast to fp32 for numerical safety.
        logits = (F.linear(h_c, w_c) * scale).float()
        return F.cross_entropy(logits, lbl_c, ignore_index=-100, reduction="sum")

    N = h_flat.size(0)
    total_sum = h_flat.new_zeros((), dtype=torch.float32)
    n_valid = (labels_flat != -100).sum().clamp_min(1).to(torch.float32)
    # The scale enters the checkpointed function as a 0-d fp32 tensor. It is
    # chunk-invariant, so it is built at most once rather than per chunk.
    mup_t = None
    for start in range(0, N, chunk_size):
        h_i = h_flat[start:start + chunk_size]
        lbl_i = labels_flat[start:start + chunk_size]
        if use_checkpoint and h_i.requires_grad:
            if mup_t is None:
                mup_t = torch.tensor(mup_scale, device=h_i.device, dtype=torch.float32)
            ce_sum_i = ckpt.checkpoint(_chunk_ce_sum, h_i, embed_weight, lbl_i, mup_t,
                                       use_reentrant=True)
        else:
            ce_sum_i = _chunk_ce_sum(h_i, embed_weight, lbl_i, mup_scale)
        total_sum = total_sum + ce_sum_i
    return total_sum / n_valid
# ============================ Liger fused linear+CE ============================
_LIGER_AVAILABLE = None
_LIGER_LOSS_FN = None
def _try_import_liger():
"""Resolve the Liger fused linear+CE loss class lazily and cache it.
Returns the class object on success, ``None`` on import failure. The
module-level cache means the import + class lookup happens once per
process even though we may instantiate the loss many times.
"""
global _LIGER_AVAILABLE, _LIGER_LOSS_FN
if _LIGER_AVAILABLE is False:
return None
if _LIGER_LOSS_FN is not None:
return _LIGER_LOSS_FN
try:
from liger_kernel.transformers.fused_linear_cross_entropy import (
LigerFusedLinearCrossEntropyLoss,
)
_LIGER_LOSS_FN = LigerFusedLinearCrossEntropyLoss
_LIGER_AVAILABLE = True
return _LIGER_LOSS_FN
except Exception:
_LIGER_AVAILABLE = False
return None
def _maybe_disable_dynamo(fn):
"""Mark ``fn`` opaque to ``torch._dynamo`` if dynamo is importable.
Liger's Triton kernel is incompatible with Inductor's launcher rewrite
(it gets called with ``num_warps`` as a kwarg that the rewritten
launcher does not accept). We don't *want* Inductor to inline this
call anyway — the whole point of Liger is that its hand-tuned kernel
is already faster than anything dynamo would synthesize.
"""
try:
import torch._dynamo as _dynamo
return _dynamo.disable(fn)
except Exception:
return fn
@_maybe_disable_dynamo
def liger_fused_cross_entropy(h_flat: torch.Tensor,
                              embed_weight: torch.Tensor,
                              labels_flat: torch.Tensor,
                              mup_scale: float) -> torch.Tensor:
    """Liger fused linear+CE.

    A single Triton kernel computes
    ``cross_entropy(F.linear(h, embed_weight) * mup_scale, labels)`` without
    ever materializing the ``(N, V)`` logit tensor.

    Equivalence to ``tiled_cross_entropy`` follows from linearity::

        F.linear(h * mup_scale, embed_weight) == F.linear(h, embed_weight) * mup_scale

    so ``h_flat`` is pre-scaled and passed as Liger's ``_input``.

    Args:
        h_flat: ``[N, D]`` hidden states (typically fp16 in training).
        embed_weight: ``[V, D]`` tied embed / lm_head weight.
        labels_flat: ``[N]`` long tensor; ``-100`` entries are ignored.
        mup_scale: scalar multiplier applied to the hidden state pre-linear.

    Returns:
        Scalar loss, mean over non-ignored positions.

    Raises:
        RuntimeError: if ``liger_kernel`` cannot be imported.
    """
    cls = _try_import_liger()
    if cls is None:
        raise RuntimeError(
            "liger_kernel not importable — install with "
            "`python3.10 -m pip install liger-kernel==0.3.0 --no-deps`."
        )
    loss_fn = cls(ignore_index=-100, reduction="mean")
    h_scaled = h_flat * mup_scale
    # Inside autocast, Liger reads the autocast GPU dtype to decide the
    # internal logits dtype but accumulates ``grad_weight`` in
    # ``weight.dtype`` (fp32) and ``_input_chunk`` in ``_input.dtype`` (fp32 if
    # the upstream RMSNorm promoted), so addmm sees mat1=Half and mat2=Float
    # and rejects. Fix by casting both _input and lin_weight to the autocast
    # dtype before the call; autograd's .to() casts the gradient back to fp32
    # at parameter accumulation time.
    if torch.is_autocast_enabled():
        # ``torch.get_autocast_gpu_dtype`` is deprecated in newer PyTorch in
        # favour of ``torch.get_autocast_dtype("cuda")``; use whichever exists.
        get_dtype = getattr(torch, "get_autocast_dtype", None)
        if get_dtype is not None:
            dt = get_dtype("cuda")
        else:
            dt = torch.get_autocast_gpu_dtype()
        h_scaled = h_scaled.to(dt)
        embed_in = embed_weight.to(dt)
    else:
        embed_in = embed_weight
    return loss_fn(lin_weight=embed_in, _input=h_scaled,
                   target=labels_flat, bias=None)
# ============================ Top-level model ============================
_ROUTED_PARAM_SUFFIXES = (".gate", ".up", ".down")
def _is_routed_expert_param(name: str) -> bool:
"""True if the parameter name belongs to the routed-expert FFN stack.
Covers both layouts:
* legacy: ``blocks.{i}.ffn.routed_experts.{e}.{gate,up,down}.weight``
* grouped: ``blocks.{i}.ffn.{gate,up,down}`` (stacked [E, *, *])
The grouped-tensor names collide with the shared-expert's
``blocks.{i}.ffn.shared_expert.{gate,up,down}.weight`` — that case is
filtered out by the ``shared_expert`` clause earlier in the caller's
classification chain, so this function only needs to ID the routed
stack vs everything else.
"""
if "routed_experts" in name:
return True
# Grouped layout: blocks.{i}.ffn.{gate,up,down} (no further suffix)
parts = name.split(".")
if len(parts) >= 4 and parts[0] == "blocks" and parts[2] == "ffn":
if parts[3] in ("gate", "up", "down") and len(parts) == 4:
return True
return False
class MoEModel(nn.Module):
    """Decoder-only MoE language model.

    Layout: ``embed -> Block x n_layers -> final RMSNorm -> LM head``. The LM
    head is ``embed.weight`` when ``cfg.tie_embeddings`` (otherwise a separate
    bias-free ``nn.Linear``), and logits are scaled by
    ``mup_scale = sqrt(cfg.mup_base_d / cfg.d_model)``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.embed = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.blocks = nn.ModuleList([Block(cfg, i) for i in range(cfg.n_layers)])
        self.final_norm = RMSNorm(cfg.d_model, eps=cfg.rms_eps)
        if cfg.tie_embeddings:
            self.lm_head = None
        else:
            self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        self.mup_scale = math.sqrt(cfg.mup_base_d / cfg.d_model)
        self.apply(self._init_weights)
        # ``apply`` above only re-initializes nn.Linear / nn.Embedding modules;
        # the raw stacked nn.Parameter tensors of the grouped MoE backend are
        # allocated uninitialized and need an explicit init.
        if getattr(cfg, "moe_backend", "bmm") == "grouped":
            self._init_grouped_moe()

    def _init_weights(self, m):
        """Per-module init hook for ``self.apply``.

        nn.Linear: orthogonal, then rescaled by ``sqrt(fan_out) / sqrt(fan_in)``;
        bias (if any) zeroed. nn.Embedding: normal(0, ``cfg.init_std``).
        Other modules are left untouched.
        """
        cfg = self.cfg
        if isinstance(m, nn.Linear):
            with torch.no_grad():
                nn.init.orthogonal_(m.weight)
                fan_in = m.weight.size(1)
                m.weight.mul_(1.0 / math.sqrt(fan_in) * math.sqrt(m.weight.size(0)))
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0.0, std=cfg.init_std)

    def _init_grouped_moe(self):
        """Initialize stacked grouped-backend expert weights.

        Matches the per-expert ``SwiGLUExpert`` init: each expert slice gets an
        independent orthogonal init followed by the
        ``sqrt(fan_out) / sqrt(fan_in)`` rescale, so the stacked tensor is
        statistically equivalent to a ModuleList of separately initialized
        experts.
        """
        for blk in self.blocks:
            if not blk.is_moe:
                continue
            moe = blk.ffn
            if moe.backend != "grouped":
                continue
            for w in (moe.gate, moe.up, moe.down):
                with torch.no_grad():
                    # w: [E, out, in]
                    fan_in = w.size(-1)
                    fan_out = w.size(-2)
                    for e in range(w.size(0)):
                        nn.init.orthogonal_(w[e])
                        w[e].mul_(1.0 / math.sqrt(fan_in) * math.sqrt(fan_out))

    def _convert_legacy_moe_keys(self, state_dict):
        """Stack legacy per-expert routed weights into grouped-backend tensors.

        If the model uses the grouped backend and ``state_dict`` carries
        ``blocks.{i}.ffn.routed_experts.{e}.{gate,up,down}.weight`` for an MoE
        layer, those keys are replaced by ``blocks.{i}.ffn.{gate,up,down}``
        (stacked along a new leading expert dim). Returns ``state_dict``
        unchanged for the bmm backend; otherwise a converted shallow copy.
        """
        if getattr(self.cfg, "moe_backend", "bmm") != "grouped":
            return state_dict
        new_sd = dict(state_dict)
        n_experts = self.cfg.n_routed_experts
        for li, blk in enumerate(self.blocks):
            if not blk.is_moe:
                continue
            prefix = f"blocks.{li}.ffn"
            if f"{prefix}.routed_experts.0.gate.weight" not in new_sd:
                continue
            for which in ("gate", "up", "down"):
                per_expert = [
                    new_sd.pop(f"{prefix}.routed_experts.{e}.{which}.weight")
                    for e in range(n_experts)
                ]
                new_sd[f"{prefix}.{which}"] = torch.stack(per_expert, dim=0)
        return new_sd

    def load_state_dict(self, state_dict, strict=True, assign=False):
        """``nn.Module.load_state_dict`` after converting legacy MoE keys."""
        state_dict = self._convert_legacy_moe_keys(state_dict)
        return super().load_state_dict(state_dict, strict=strict, assign=assign)

    def _lm_head_weight(self):
        """Output projection weight: ``embed.weight`` when tied, else ``lm_head.weight``."""
        return self.embed.weight if self.lm_head is None else self.lm_head.weight

    def forward(self, input_ids, labels=None, return_aux=True,
                return_logits: bool = False):
        """Run the model.

        Args:
            input_ids: ``(B, S)`` token ids.
            labels: optional ``(B, S)`` targets used as given (no internal
                shift); ``-100`` entries are ignored.
            return_aux: also return the aggregated router stats.
            return_logits: with ``labels``, additionally materialize the full
                ``(B, S, V)`` logits — only for eval / generation paths that fit.

        Returns:
            ``(logits, loss, aux_total)`` if ``return_aux`` else
            ``(logits, loss)``.

            With ``labels``: ``loss`` is the mean CE and ``logits`` is ``None``
            unless ``return_logits``. The loss comes from Liger's fused
            linear+CE when ``cfg.use_liger_ce`` is set, Liger imports, logits
            are not requested and the tensors are not on CPU; otherwise from
            ``tiled_cross_entropy`` when ``cfg.use_chunked_ce``; otherwise from
            a single-shot ``F.cross_entropy`` reference.

            Without ``labels``: full logits and ``loss=None``.

            ``aux_total`` holds ``z_loss`` / ``aux_loss`` summed over MoE
            layers, ``counts_per_layer`` (detached per-layer expert counts, in
            block order — the input of :meth:`step_router_biases`), and
            per-layer averages ``router_cv`` / ``router_entropy_bits``.
        """
        x = self.embed(input_ids)
        aux_total = {"z_loss": 0.0, "aux_loss": 0.0,
                     "router_cv_sum": 0.0, "router_entropy_sum": 0.0,
                     "n_moe": 0, "counts_per_layer": []}
        for blk in self.blocks:
            x, aux = blk(x)
            if aux is None:
                continue
            aux_total["z_loss"] = aux_total["z_loss"] + aux["z_loss"]
            aux_total["aux_loss"] = aux_total["aux_loss"] + aux["aux_loss"]
            aux_total["router_cv_sum"] = aux_total["router_cv_sum"] + aux["router_cv"].detach()
            aux_total["router_entropy_sum"] = (
                aux_total["router_entropy_sum"] + aux["router_entropy_bits"].detach()
            )
            aux_total["n_moe"] += 1
            aux_total["counts_per_layer"].append(aux["counts"].detach())
        x = self.final_norm(x)
        head_w = self._lm_head_weight()

        loss = None
        logits = None
        if labels is None:
            logits = F.linear(x, head_w) * self.mup_scale
        else:
            B, S, D = x.shape
            h_flat = x.reshape(B * S, D)
            lbl_flat = labels.reshape(-1).long()
            # Liger's fused kernel is a Triton kernel, so it cannot run on CPU
            # tensors (e.g. the CPU smoke test); fall back to PyTorch paths there.
            use_liger = (
                getattr(self.cfg, "use_liger_ce", False)
                and _try_import_liger() is not None
                and not return_logits
                and h_flat.device.type != "cpu"
            )
            if use_liger:
                loss = liger_fused_cross_entropy(
                    h_flat, head_w, lbl_flat, self.mup_scale,
                )
            elif self.cfg.use_chunked_ce:
                loss = tiled_cross_entropy(
                    h_flat, head_w, lbl_flat, self.mup_scale,
                    chunk_size=self.cfg.ce_chunk_tokens,
                    use_checkpoint=self.cfg.ce_checkpoint_chunks,
                )
            else:
                logits_full = F.linear(h_flat, head_w) * self.mup_scale
                loss = F.cross_entropy(logits_full.float(), lbl_flat,
                                       ignore_index=-100, reduction="mean")
                if return_logits:
                    logits = logits_full.view(B, S, -1)
            if return_logits and logits is None:
                # The loss path did not build logits: materialize the full
                # (B, S, V) tensor in one shot (no tiling), so only request
                # this when it fits in memory.
                logits = (F.linear(h_flat, head_w) * self.mup_scale).view(B, S, -1)

        if return_aux:
            n_moe = max(1, aux_total["n_moe"])
            aux_total["router_cv"] = aux_total["router_cv_sum"] / n_moe
            aux_total["router_entropy_bits"] = aux_total["router_entropy_sum"] / n_moe
            return logits, loss, aux_total
        return logits, loss

    @torch.no_grad()
    def step_router_biases(self, counts_per_layer):
        """Apply ``step_bias_update`` to every MoE layer's router.

        ``counts_per_layer`` must list one counts tensor per MoE layer in block
        order (as in ``aux_total["counts_per_layer"]``). Under DDP, all-reduce
        the counts first (see ``SigmoidRouter.step_bias_update``).
        """
        i = 0
        for blk in self.blocks:
            if blk.is_moe:
                blk.ffn.router.step_bias_update(counts_per_layer[i])
                i += 1

    def num_parameters(self, only_active=False):
        """Count parameters.

        With ``only_active=True``, each routed-expert tensor is counted at the
        fraction ``top_k / n_routed_experts`` (truncated to int per tensor),
        approximating the parameters touched per token.
        """
        if not only_active:
            return sum(p.numel() for p in self.parameters())
        total = 0
        for n, p in self.named_parameters():
            if _is_routed_expert_param(n):
                total += int(p.numel() * self.cfg.top_k / self.cfg.n_routed_experts)
            else:
                total += p.numel()
        return total

    def param_breakdown(self):
        """Return named parameter-count buckets for comparison against the design target.

        Classification is by substring of the parameter name, first match
        wins. Note: the pre-attention ``attn_norm`` contains ``"attn"`` and so
        lands in the ``attn`` bucket, while ``ffn_norm`` / ``final_norm`` land
        in ``norms``. Buffers (e.g. router biases) are not counted.
        """
        b = {"embed": 0, "attn": 0, "router": 0,
             "shared_expert": 0, "routed_experts": 0,
             "dense_ffn": 0, "norms": 0, "lm_head": 0, "other": 0}
        attn_markers = ("attn", "q_proj", "k_proj", "v_proj", "o_proj",
                        "q_norm", "k_norm", "rope", "smear")
        for n, p in self.named_parameters():
            num = p.numel()
            if "embed" in n:
                b["embed"] += num
            elif "lm_head" in n:
                b["lm_head"] += num
            elif any(marker in n for marker in attn_markers):
                b["attn"] += num
            elif "router" in n:
                b["router"] += num
            elif "shared_expert" in n:
                b["shared_expert"] += num
            elif _is_routed_expert_param(n):
                b["routed_experts"] += num
            elif "ffn" in n and ("gate" in n or "up" in n or "down" in n):
                b["dense_ffn"] += num
            elif "norm" in n:
                b["norms"] += num
            else:
                b["other"] += num
        return b
if __name__ == "__main__":
    # Quick standalone smoke test: build the default config, report parameter
    # counts, and run one forward pass with labels.
    cfg = small_config()
    m = MoEModel(cfg)
    print(f"params total {m.num_parameters()/1e6:.2f} M "
          f"active {m.num_parameters(only_active=True)/1e6:.2f} M")
    bd = m.param_breakdown()
    for k, v in bd.items():
        print(f" {k}: {v/1e6:.2f} M")
    ids = torch.randint(0, cfg.vocab_size, (2, 64))
    # No backward pass here, so skip building the autograd graph.
    with torch.no_grad():
        logits, loss, aux = m(ids, labels=ids)
    # With labels the model returns logits=None unless return_logits=True;
    # report the shape rather than dumping a (B, S, V) tensor if present.
    logits_desc = None if logits is None else tuple(logits.shape)
    print(f"logits {logits_desc} loss {loss.item():.3f} router_cv {aux['router_cv'].item():.3f}")