| """Self-contained modeling file for trust_remote_code use. |
| |
| This file merges mup_models.py and hf_wrapper.py into a single module with no |
| imports from looped_scaling.*. It is intended to be placed alongside a |
| config.json that sets ``auto_map`` / ``model_type = "loop-lm"`` so that |
| HuggingFace's ``from_pretrained(..., trust_remote_code=True)`` can load it |
| without requiring the looped_scaling package to be installed. |
| |
| Supported model variants: "base" (MuTransformer), "looped" (LoopedTransformer), |
| "moe" (MoETransformer), "looped-moe" (LoopedMoETransformer). |
| """ |
|
|
| import torch |
| import math |
| import sys |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from collections.abc import Callable, Iterable |
| from einops import rearrange, einsum, reduce, repeat |
| from typing import IO, Any, BinaryIO, Optional |
| from torch import Tensor |
| from collections import Counter, defaultdict |
| from torch.nn.functional import scaled_dot_product_attention as sdpa |
| from torch.nn.functional import grouped_mm, silu |
| from transformers import PretrainedConfig, PreTrainedModel, AutoConfig, AutoModelForCausalLM |
| from transformers.generation import GenerationMixin |
| from transformers.modeling_outputs import CausalLMOutputWithPast |
|
|
| BASE_D_MODEL = 128 |
| BASE_D_FF = 384 |
|
|
| """ Standard Transformer and Components implemented with muP """ |
|
|
|
|
| |
| |
| |
|
|
def softmax(logits: Tensor, dim: int) -> Tensor:
    """Numerically stable softmax along ``dim``.

    Always computes (and returns) in float32, regardless of input dtype —
    callers rely on the upcast for stability on low-precision logits.
    """
    logits = logits.float()
    # Subtract the per-slice max so exp() never overflows.
    stable = logits - logits.amax(dim=dim, keepdim=True)
    exps = torch.exp(stable)
    return exps / exps.sum(dim=dim, keepdim=True)
|
|
|
|
| |
class Linear(nn.Module):
    """Bias-free linear layer with muP-scaled initialization.

    Weight is stored as (out_features, in_features). The init std is
    std_base / sqrt(width_ratio), truncated at ±3σ, so activation scale
    stays width-invariant under muP width scaling.
    """

    def __init__(self, in_features: int, out_features: int, width_ratio: float, std_base: float, device=None, dtype=None):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype, device=device))
        # muP: shrink the base std as the width ratio grows.
        std_scaled = std_base / math.sqrt(width_ratio)
        nn.init.trunc_normal_(self.weight, mean=0.0, std=std_scaled, a=-3*std_scaled, b=3*std_scaled)

    def forward(self, x: Tensor) -> Tensor:
        # x @ W^T — identical contraction to the previous einops einsum
        # "d_out d_in, ... d_in -> ... d_out", but via the native kernel
        # (no per-call pattern parsing in this hot path).
        return F.linear(x, self.weight)
|
|
class Embedding(nn.Module):
    """Token-embedding lookup table.

    Initialized with a unit-std truncated normal (±3); under muP the
    embedding init does not depend on the width ratio.
    """

    def __init__(self, num_embeddings, embedding_dim, device=None, dtype=None):
        super().__init__()
        table = torch.empty(num_embeddings, embedding_dim, dtype=dtype, device=device)
        nn.init.trunc_normal_(table, mean=0.0, std=1.0, a=-3, b=3)
        self.weight = nn.Parameter(table)

    def forward(self, token_ids: Tensor) -> Tensor:
        # Row gather: output shape is token_ids.shape + (embedding_dim,).
        return self.weight[token_ids]
|
|
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last (feature) dimension.

    NOTE(review): unlike the common RMSNorm formulation there is no learned
    gain parameter — this appears intentional for the muP setup here; confirm.
    Normalization is computed in float32 and cast back to the input dtype.
    """

    def __init__(self, d_model: int, eps: float = 1e-5, device=None, dtype=None):
        super().__init__()
        self.d_model = d_model
        self.eps = eps

    def forward(self, x: Tensor) -> Tensor:
        in_dtype = x.dtype
        # Upcast for numerical stability in the square/sum.
        x = x.to(torch.float32)
        # Mean of squares over the feature dim, normalized by the declared
        # d_model (matches the original einsum-based computation exactly).
        mean_sq = x.pow(2).sum(dim=-1, keepdim=True) / self.d_model
        rms = torch.sqrt(mean_sq + self.eps)
        return (x / rms).to(in_dtype)
|
|
class PositionwiseFeedforward(nn.Module):
    """SwiGLU feed-forward block: w2(silu(w1 x) * w3 x), muP-initialized."""

    def __init__(self, d_model: int, d_ff: int, width_ratio: float, device=None, dtype=None):
        super().__init__()
        # Glorot-style base std computed at the base widths; Linear applies
        # the 1/sqrt(width_ratio) muP correction internally.
        std = math.sqrt(2/(BASE_D_MODEL+BASE_D_FF))
        self.w1 = Linear(d_model, d_ff, width_ratio, std, device=device, dtype=dtype)
        self.w2 = Linear(d_ff, d_model, width_ratio, std, device=device, dtype=dtype)
        self.w3 = Linear(d_model, d_ff, width_ratio, std, device=device, dtype=dtype)

    def forward(self, x: Tensor) -> Tensor:
        # Gated activation: silu(w1 x) scaled elementwise by the gate w3 x.
        return self.w2(silu(self.w1(x)) * self.w3(x))
|
|
class RotaryPositionalEmbedding(nn.Module):
    """RoPE via precomputed 2x2 rotation matrices applied to feature pairs.

    Buffer ``rotations`` has shape (max_seq_len, d_k // 2, 2, 2); entry
    (i, k) rotates by angle i / theta**(2k / d_k).
    """

    def __init__(self, theta: float, d_k: int, max_seq_len: int, device=None, dtype=None):
        """
        theta: float Θ value for the RoPE
        d_k: int dimension of query and key vectors (must be even)
        max_seq_len: int Maximum sequence length that will be inputted
        device: torch.device | None = None Device to store the buffer on
        """
        super().__init__()
        # Vectorized construction. The previous version looped over every
        # (position, pair) in Python and allocated a fresh CPU tensor per
        # entry — O(max_seq_len * d_k/2) Python-level work.
        positions = torch.arange(max_seq_len, device=device, dtype=torch.float32)
        # inv_freq[k] = theta ** (-2k / d_k), so angle(i, k) = i / theta**(2k/d_k).
        inv_freq = theta ** (-torch.arange(0, d_k, 2, device=device, dtype=torch.float32) / d_k)
        angles = positions[:, None] * inv_freq[None, :]                # (L, d_k/2)
        cos, sin = torch.cos(angles), torch.sin(angles)
        row0 = torch.stack([cos, -sin], dim=-1)
        row1 = torch.stack([sin, cos], dim=-1)
        rotations = torch.stack([row0, row1], dim=-2)                  # (L, d_k/2, 2, 2)
        if dtype is not None:
            rotations = rotations.to(dtype)
        self.register_buffer("rotations", rotations, persistent=True)

    def forward(self, x: Tensor, token_positions: Tensor) -> Tensor:
        """
        x: tensor of shape (..., seq_dim, d_k)
        token_positions: integer tensor of shape (..., seq_dim)
        Returns x with each adjacent feature pair rotated by its position angle.
        """
        rot = self.rotations[token_positions].to(dtype=x.dtype)        # (..., seq, d_k/2, 2, 2)
        # Pair adjacent features: (..., seq, d_k) -> (..., seq, d_k/2, 2).
        x_pairs = x.unflatten(-1, (-1, 2))
        # Batched 2x2 matmul; broadcasting covers extra leading dims (heads).
        y_pairs = (rot @ x_pairs.unsqueeze(-1)).squeeze(-1)
        return y_pairs.flatten(-2)
|
|
def scaled_dot_product_attention(
    Q: Tensor,
    K: Tensor,
    V: Tensor,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """
    Compute attention(Q, K, V) with muP-style score scaling.

    NOTE: scores are divided by d_k, NOT the usual sqrt(d_k). This matches
    the ``scale=1.0/d_k`` used by MultiheadSelfAttention below and is the
    muP attention convention — do not "fix" to 1/sqrt(d_k).

    Args:
        let m be seq length of inputs, n be seq length of outputs
        d_k is look-up dim, d_v is value dim
        Q (Float[Tensor, "batch ... n d_k"]): Query tensor
        K (Float[Tensor, "batch ... m d_k"]): Key tensor
        V (Float[Tensor, "batch ... m d_v"]): Values tensor
        mask (Tensor | None): boolean-like mask of shape (..., n, m);
            True/nonzero positions are KEPT, others get -inf before softmax
    Returns:
        Float[Tensor, " ... n d_v"]: Output of SDPA
    """
    d_k = Q.shape[-1]
    assert d_k == K.shape[-1]

    # muP scaling: divide by d_k rather than sqrt(d_k).
    scores = einsum(Q, K, "... n d_k, ... m d_k -> ... n m") / d_k

    if mask is not None:
        # Convert to an additive mask: 0 where attended, -inf where disallowed.
        bool_mask = mask.bool()
        attn_mask = torch.where(bool_mask, 0.0, float('-inf')).to(scores.dtype)
        scores = scores + attn_mask

    # softmax (defined above) upcasts to float32 for numerical stability,
    # so `weights` is float32 regardless of the input dtype.
    weights = softmax(scores, dim=-1)

    return einsum(weights, V, "... n m, ... m d_v -> ... n d_v")
|
|
class MultiheadSelfAttention(nn.Module):
    """
    Causal multi-head self-attention with optional RoPE, muP-initialized.

    Uses torch's fused scaled_dot_product_attention with ``scale=1.0/d_k``
    (the muP attention scaling, not 1/sqrt(d_k)) and ``is_causal=True``,
    so no explicit attention mask is constructed.

    Args:
        d_model (int): model width; split evenly across heads.
        num_heads (int): number of heads (must divide d_model).
        max_seq_len (int | None): RoPE cache length; required when theta is given.
        theta (float | None): RoPE base; if falsy, no positional rotation is applied.
        width_ratio (float): muP width multiplier used to scale the init std.
    """
    def __init__(self, d_model: int, num_heads: int, max_seq_len: int = None, theta: float = None, width_ratio: float = 1.0, device=None, dtype=None):
        super().__init__()

        assert d_model % num_heads == 0, f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"

        self.d_model = d_model
        self.num_heads = num_heads

        # Glorot-style base std at base width; Linear applies the muP
        # 1/sqrt(width_ratio) correction internally.
        attn_std_base = math.sqrt(2/(BASE_D_MODEL+BASE_D_MODEL))

        self.q_proj = Linear(d_model, d_model, width_ratio, attn_std_base, device=device, dtype=dtype)
        self.k_proj = Linear(d_model, d_model, width_ratio, attn_std_base, device=device, dtype=dtype)
        self.v_proj = Linear(d_model, d_model, width_ratio, attn_std_base, device=device, dtype=dtype)
        self.output_proj = Linear(d_model, d_model, width_ratio, attn_std_base, device=device, dtype=dtype)

        assert theta is None or max_seq_len is not None, "max_seq_len must be provided when theta is given for multi-head self attention with RoPE."

        # NOTE: truthiness check — theta == 0.0 also disables RoPE, not just None.
        if theta:
            d_k = d_model//num_heads
            self.rope = RotaryPositionalEmbedding(theta, d_k, max_seq_len, device, dtype)
        else:
            self.rope = None

    def forward(self, x: Tensor, token_positions: Optional[Tensor] = None) -> Tensor:
        """Apply causal MHA to x of shape (..., seq, d_model); same shape out."""
        Q = self.q_proj(x)
        K = self.k_proj(x)
        V = self.v_proj(x)

        # Per-head query/key and value widths (equal here).
        d_k = self.d_model // self.num_heads
        d_v = self.d_model // self.num_heads

        # Split the width into heads before applying RoPE per head.
        q_heads = rearrange(Q, "... seq (heads d_k) -> ... heads seq d_k", d_k=d_k)
        k_heads = rearrange(K, "... seq (heads d_k) -> ... heads seq d_k", d_k=d_k)

        if self.rope:
            seq_dim = x.shape[-2]
            if token_positions is None:
                # Default: contiguous positions 0..seq-1 shared across the batch.
                token_positions = torch.arange(seq_dim, device=x.device)
                token_positions = rearrange(token_positions, "seq -> 1 seq")

            q_heads = self.rope(q_heads, token_positions)
            k_heads = self.rope(k_heads, token_positions)

        v_heads = rearrange(V, "... seq (heads d_v) -> ... heads seq d_v", d_v=d_v)

        # Fused kernel; scale=1/d_k is the muP attention scaling.
        mha_heads = sdpa(q_heads, k_heads, v_heads, is_causal=True, scale=1.0/d_k)
        mha = rearrange(mha_heads, "... heads seq d_v -> ... seq (heads d_v)")

        out = self.output_proj(mha)

        return out
|
|
class PrenormBlock(nn.Module):
    """Pre-norm transformer block: x + attn(ln1(x)), then + ffn(ln2(·))."""

    def __init__(self,
                 d_model: int,
                 num_heads: int,
                 d_ff: int,
                 max_seq_len: int,
                 theta: float,
                 width_ratio: float,
                 device=None,
                 dtype=None):
        super().__init__()
        self.ln1 = RMSNorm(d_model, device=device, dtype=dtype)
        self.attn = MultiheadSelfAttention(d_model, num_heads, max_seq_len, theta, width_ratio, device, dtype)
        self.ln2 = RMSNorm(d_model, device=device, dtype=dtype)
        self.ffn = PositionwiseFeedforward(d_model, d_ff, width_ratio, device, dtype)

    def forward(self, x: Tensor, token_positions: Optional[Tensor] = None) -> Tensor:
        # Attention sub-block with residual connection.
        attn_branch = self.attn(self.ln1(x), token_positions)
        assert(x.shape == attn_branch.shape)
        x = x + attn_branch

        # Feed-forward sub-block with residual connection.
        ffn_branch = self.ffn(self.ln2(x))
        assert(ffn_branch.shape == x.shape)
        return x + ffn_branch
|
|
class MuTransformer(nn.Module):
    """Standard decoder-only transformer with muP-scaled initialization."""

    def __init__(
        self, vocab_size: int,
        context_length: int,
        d_model: int,
        num_layers: int,
        num_heads: int,
        d_ff: int,
        rope_theta: float,
        width_ratio: float = 1.0,
        weight_tying: bool = False,
        device=None, dtype=None):
        super().__init__()
        self.token_embeddings = Embedding(vocab_size, d_model, device=device, dtype=dtype)
        self.layers = nn.ModuleList(
            PrenormBlock(d_model, num_heads, d_ff, context_length, rope_theta, width_ratio, device, dtype)
            for _ in range(num_layers)
        )
        self.ln_final = RMSNorm(d_model, device=device, dtype=dtype)
        self.weight_tying = weight_tying
        if weight_tying:
            # Output head shares the embedding matrix.
            self.lm_head = self.token_embeddings.weight
        else:
            self.lm_head = Linear(d_model, vocab_size, width_ratio=width_ratio,
                                  std_base=math.sqrt(2/(BASE_D_MODEL+vocab_size)),
                                  device=device, dtype=dtype)
        self.width_ratio = width_ratio

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.token_embeddings(x)
        for block in self.layers:
            hidden = block(hidden)
        hidden = self.ln_final(hidden)
        if self.weight_tying:
            # muP readout: the tied head divides logits by the width ratio.
            return einsum(hidden, self.lm_head, "... s d, v d -> ... s v")/self.width_ratio
        return self.lm_head(hidden)
|
|
| """ Looped Language Models implemented with MuP """ |
|
|
class LoopedStack(nn.Module):
    """A stack of blocks meant to be applied repeatedly with shared weights.

    With mixture_of_experts=True the blocks return (x, lb, lz) and forward
    returns the summed aux losses; otherwise forward returns just x.
    """

    def __init__(
        self,
        context_length: int,
        d_model: int,
        num_layers_in_stack: int,
        num_heads: int,
        d_ff: int,
        rope_theta: float,
        width_ratio: float = 1.0,
        mixture_of_experts: bool = False,
        num_experts: Optional[int] = None,
        num_active: Optional[int] = None,
        device=None, dtype=None):
        super().__init__()
        if mixture_of_experts:
            blocks = [GroupedMoEPrenormBlock(d_model, num_heads, d_ff, num_experts, num_active,
                                             context_length, rope_theta, width_ratio, device, dtype)
                      for _ in range(num_layers_in_stack)]
        else:
            blocks = [PrenormBlock(d_model, num_heads, d_ff, context_length, rope_theta,
                                   width_ratio, device, dtype)
                      for _ in range(num_layers_in_stack)]
        self.layers = nn.ModuleList(blocks)
        self.mixture_of_experts = mixture_of_experts

    def forward(self, x: Tensor) -> Tensor:
        if not self.mixture_of_experts:
            for block in self.layers:
                x = block(x)
            return x
        # MoE path: accumulate load-balance and z-loss terms across layers.
        lb_total, lz_total = 0, 0
        for block in self.layers:
            x, lb, lz = block(x)
            lb_total = lb_total + lb
            lz_total = lz_total + lz
        return x, lb_total, lz_total
|
|
class LoopedTransformer(nn.Module):
    """Weight-shared transformer: one stack applied num_stacks times."""

    def __init__(
        self,
        vocab_size: int,
        context_length: int,
        d_model: int,
        num_layers_in_stack: int,
        num_stacks: int,
        num_heads: int,
        d_ff: int,
        rope_theta: float,
        width_ratio: float = 1.0,
        weight_tying: bool = False,
        device=None, dtype=None):
        super().__init__()
        self.num_stacks = num_stacks
        self.token_embeddings = Embedding(vocab_size, d_model, device=device, dtype=dtype)
        self.stack = LoopedStack(context_length, d_model, num_layers_in_stack, num_heads, d_ff, rope_theta, width_ratio, device=device, dtype=dtype)
        self.ln_final = RMSNorm(d_model, device=device, dtype=dtype)
        self.weight_tying = weight_tying
        self.width_ratio = width_ratio
        if weight_tying:
            # Output head shares the embedding matrix.
            self.lm_head = self.token_embeddings.weight
        else:
            self.lm_head = Linear(d_model, vocab_size, width_ratio,
                                  math.sqrt(2/(BASE_D_MODEL+vocab_size)),
                                  device=device, dtype=dtype)

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.token_embeddings(x)
        # Re-apply the same shared-weight stack num_stacks times.
        for _ in range(self.num_stacks):
            hidden = self.stack(hidden)
        hidden = self.ln_final(hidden)
        if self.weight_tying:
            # muP readout correction for the tied head.
            return einsum(hidden, self.lm_head, "... s d, v d -> ... s v")/self.width_ratio
        return self.lm_head(hidden)
|
|
| """ Mixture-of-Experts Implementation in muP """ |
|
|
| |
class Router(nn.Module):
    """Softmax top-k router over experts.

    Returns (logits, probs, top_scores, top_experts) where top_scores are
    the selected probabilities renormalized to sum to 1 per token.
    """

    def __init__(self, d_model: int, num_experts: int, num_active=None, width_ratio: float = 1.0, device=None, dtype=None):
        super().__init__()
        self.gate = Linear(d_model, num_experts, width_ratio,
                           math.sqrt(2/(BASE_D_MODEL+num_experts)),
                           device=device, dtype=dtype)
        self.num_active = num_active

    def forward(self, x: Tensor):
        logits = self.gate(x)
        probs = softmax(logits, dim=-1)
        # Pick the k most probable experts per token.
        top_scores, top_experts = torch.topk(probs, k=self.num_active, dim=-1)
        # Renormalize the selected probabilities so they sum to 1.
        top_scores = top_scores / top_scores.sum(dim=-1, keepdim=True)
        return logits, probs, top_scores, top_experts
|
|
class MoEPrenormBlock(nn.Module):
    """Pre-norm transformer block with a token-choice top-k MoE feed-forward.

    Loop-over-experts reference implementation; GroupedMoEPrenormBlock below
    is the batched grouped-GEMM equivalent. Returns (output, lb, lz):
      lb — load-balance auxiliary loss, num_experts * sum_i f_i * p_i
           (a Python float here, since f_i/p_i are extracted via .item())
      lz — router z-loss, mean of logsumexp(logits)^2 (a tensor)
    """
    def __init__(self, d_model: int, num_heads: int, d_ff: int, num_experts: int, num_active: int,
                 max_seq_len: int, theta: float, width_ratio: float = 1.0, device=None, dtype=None):
        super().__init__()
        self.ln1 = RMSNorm(d_model, device=device, dtype=dtype)
        self.attn = MultiheadSelfAttention(d_model, num_heads, max_seq_len, theta, width_ratio, device, dtype)
        self.ln2 = RMSNorm(d_model, device=device, dtype=dtype)
        self.router = Router(d_model, num_experts, num_active, width_ratio=width_ratio, device=device, dtype=dtype)
        self.num_experts = num_experts
        self.num_active = num_active
        # Shrink each expert so the total ACTIVE width matches the dense d_ff.
        d_ff_expert = d_ff // num_active
        self.experts = nn.ModuleList([PositionwiseFeedforward(d_model, d_ff_expert, width_ratio, device, dtype) for _ in range(num_experts)])

    def forward(self, x: Tensor, token_positions: Optional[Tensor] = None) -> tuple[Tensor, float, Tensor]:
        batch, seq, dim = x.shape

        # Attention sub-block (pre-norm residual).
        norm1_out = self.ln1(x)
        attn_out = self.attn(norm1_out, token_positions)
        assert(x.shape == attn_out.shape)
        resid1_out = attn_out + x

        # MoE sub-block (pre-norm residual).
        norm2_out = self.ln2(resid1_out)

        logits, probs, top_scores, top_experts = self.router(norm2_out)
        # p_i: mean router probability per expert over batch and sequence.
        expert_mean_probs = torch.mean(probs, dim=(0, 1))

        experts_out = torch.zeros_like(norm2_out)
        total_tokens_assigned = batch*seq*self.num_active
        lb_sum = 0

        for expert_idx in range(self.num_experts):
            # Tokens that routed this expert into their top-k.
            expert_mask = (top_experts == expert_idx)
            embed_mask = expert_mask.any(dim=-1)
            # Unassigned experts contribute f_i = 0 to lb, so skipping is exact.
            if not embed_mask.any(): continue
            pi = expert_mean_probs[expert_idx].item()
            # f_i: fraction of all routed slots that chose this expert.
            fi = (expert_mask.sum().item())/total_tokens_assigned
            lb_sum += fi*pi

            # topk returns distinct experts per token, so each token matches
            # expert_mask at most once; both boolean selections below flatten
            # in the same row-major token order and therefore align 1:1.
            weights = top_scores[expert_mask]
            expert_embeds = norm2_out[embed_mask]

            expert_out = self.experts[expert_idx](expert_embeds)

            # Scatter each expert's weighted output back to its tokens.
            experts_out[embed_mask] += weights.unsqueeze(-1)*expert_out

        lb = self.num_experts*lb_sum

        # z-loss discourages the router logits from growing unboundedly.
        logsumexp = torch.logsumexp(logits.float(), dim=-1)
        lz = torch.mean(logsumexp ** 2)

        assert(experts_out.shape == resid1_out.shape)
        final_out = resid1_out + experts_out
        return final_out, lb, lz
|
|
|
|
class GroupedMoEPrenormBlock(nn.Module):
    """Pre-norm transformer block with a grouped-GEMM top-k MoE feed-forward.

    Batched equivalent of MoEPrenormBlock: routed token slots are sorted by
    expert so every expert's rows are contiguous, then all experts run in a
    single grouped matmul per weight. Returns (output, lb, lz) where lb is
    the load-balance loss and lz the router z-loss (both tensors here).
    """
    @staticmethod
    def _init_expert_weights(num_experts, in_features, out_features, width_ratio, std_base, device, dtype) -> nn.Parameter:
        # Stacked expert weights, shape (num_experts, in, out), muP-scaled init.
        w = torch.empty(num_experts, in_features, out_features, device=device, dtype=dtype)
        std_scaled = std_base / math.sqrt(width_ratio)
        nn.init.trunc_normal_(w, mean=0.0, std=std_scaled, a=-3*std_scaled, b=3*std_scaled)
        return nn.Parameter(w)

    def __init__(self, d_model: int, num_heads: int, d_ff: int, num_experts: int, num_active: int,
                 max_seq_len: int, theta: float, width_ratio: float = 1.0, device=None, dtype=None):
        super().__init__()
        self.ln1 = RMSNorm(d_model, device=device, dtype=dtype)
        self.attn = MultiheadSelfAttention(d_model, num_heads, max_seq_len, theta, width_ratio, device, dtype)
        self.ln2 = RMSNorm(d_model, device=device, dtype=dtype)
        self.router = Router(d_model, num_experts, num_active, width_ratio=width_ratio, device=device, dtype=dtype)
        self.num_experts = num_experts
        self.num_active = num_active
        # Shrink each expert so the total ACTIVE width matches the dense d_ff.
        d_ff_expert = d_ff // num_active
        # SwiGLU expert weights — output is w2(silu(w1 x) * w3 x) per expert.
        w_std_base = math.sqrt(2 / (BASE_D_MODEL + BASE_D_FF))
        self.experts_w1 = self._init_expert_weights(num_experts, d_model, d_ff_expert, width_ratio, w_std_base, device, dtype)
        self.experts_w2 = self._init_expert_weights(num_experts, d_ff_expert, d_model, width_ratio, w_std_base, device, dtype)
        self.experts_w3 = self._init_expert_weights(num_experts, d_model, d_ff_expert, width_ratio, w_std_base, device, dtype)

    def forward(self, x: Tensor, token_positions: Optional[Tensor] = None) -> tuple[Tensor, Tensor, Tensor]:
        batch, seq, dim = x.shape
        total_tokens = batch * seq

        # Attention sub-block (pre-norm residual).
        norm1_out = self.ln1(x)
        attn_out = self.attn(norm1_out, token_positions)
        assert(x.shape == attn_out.shape)
        resid1_out = attn_out + x

        # MoE sub-block (pre-norm residual).
        norm2_out = self.ln2(resid1_out)

        logits, probs, top_scores, top_experts = self.router(norm2_out)

        # Flatten (batch, seq, k) routing into parallel 1-D arrays; each routed
        # slot remembers which flat token it came from.
        x_flat = rearrange(norm2_out, 'b s d -> (b s) d')
        flat_expert_ids = rearrange(top_experts, 'b s k -> (b s k)')
        flat_scores = rearrange(top_scores, 'b s k -> (b s k)')
        flat_positions = torch.arange(total_tokens, device=x.device)
        flat_token_ids = repeat(flat_positions, 'n -> (n k)', k=self.num_active)

        # Sort slots by expert so each expert's rows are contiguous; stable
        # sort keeps token order deterministic within an expert.
        sort_indices = flat_expert_ids.argsort(stable=True)
        sorted_expert_ids = flat_expert_ids[sort_indices]
        sorted_token_ids = flat_token_ids[sort_indices]
        sorted_scores = flat_scores[sort_indices]
        sorted_x = x_flat[sorted_token_ids]

        # Group boundaries: offs[i] = end offset of expert i's row block.
        # NOTE(review): assumes grouped_mm takes cumulative END offsets per
        # group — confirm against the torch version in use.
        counts = torch.bincount(sorted_expert_ids, minlength=self.num_experts)
        offs = counts.cumsum(0).to(torch.int32)

        # SwiGLU applied per expert group in three grouped matmuls.
        h1 = grouped_mm(sorted_x, self.experts_w1, offs=offs)
        h3 = grouped_mm(sorted_x, self.experts_w3, offs=offs)
        gated = silu(h1) * h3
        expert_out = grouped_mm(gated, self.experts_w2, offs=offs)

        # Weight each slot by its renormalized router score, then scatter-add
        # the k slots of every token back into a single output row.
        expert_out = einsum(expert_out, sorted_scores, 'n d, n -> n d')
        output_flat = torch.zeros(total_tokens, dim, device=x.device, dtype=expert_out.dtype)
        output_flat.index_add_(0, sorted_token_ids, expert_out)

        experts_out = rearrange(output_flat, '(b s) d -> b s d', b=batch, s=seq)

        # Load-balance loss: num_experts * sum_i f_i * p_i, where f_i is the
        # fraction of routed slots and p_i the mean router probability.
        fi = counts.float() / (total_tokens * self.num_active)
        pi = reduce(probs, 'b s e -> e', 'mean')
        lb = self.num_experts * einsum(fi, pi, 'e, e ->')

        # z-loss discourages the router logits from growing unboundedly.
        logsumexp = torch.logsumexp(logits.float(), dim=-1)
        lz = reduce(logsumexp ** 2, '... -> ', 'mean')

        assert(experts_out.shape == resid1_out.shape)
        final_out = resid1_out + experts_out
        return final_out, lb, lz
|
|
|
|
| |
class MoETransformer(nn.Module):
    """Decoder-only transformer whose FFNs are grouped MoE blocks.

    forward returns (logits, lb_avg, lz_avg) with aux losses averaged
    over layers.
    """

    def __init__(
        self, vocab_size: int,
        context_length: int,
        d_model: int,
        num_layers: int,
        num_heads: int,
        d_ff: int,
        num_experts: int,
        num_active: int,
        rope_theta: float,
        width_ratio: float = 1.0,
        device=None, dtype=None):
        super().__init__()
        self.token_embeddings = Embedding(vocab_size, d_model, device=device, dtype=dtype)
        self.num_layers = num_layers
        self.layers = nn.ModuleList(
            GroupedMoEPrenormBlock(d_model, num_heads, d_ff, num_experts, num_active,
                                   context_length, rope_theta, width_ratio, device, dtype)
            for _ in range(num_layers)
        )
        self.ln_final = RMSNorm(d_model, device=device, dtype=dtype)
        self.lm_head = Linear(d_model, vocab_size, width_ratio=width_ratio,
                              std_base=math.sqrt(2/(BASE_D_MODEL+vocab_size)),
                              device=device, dtype=dtype)

    def forward(self, x: Tensor) -> Tensor:
        lb_total, lz_total = 0, 0
        hidden = self.token_embeddings(x)
        for block in self.layers:
            hidden, lb, lz = block(hidden)
            lb_total = lb_total + lb
            lz_total = lz_total + lz
        logits = self.lm_head(self.ln_final(hidden))
        # Report per-layer averages of the auxiliary losses.
        return logits, lb_total / self.num_layers, lz_total / self.num_layers
|
|
class LoopedMoETransformer(nn.Module):
    """Looped transformer whose shared stack uses grouped MoE blocks.

    forward returns (logits, lb_avg, lz_avg) with aux losses averaged over
    every (stack repetition x layer) application.
    """

    def __init__(
        self, vocab_size: int,
        context_length: int,
        d_model: int,
        num_layers_in_stack: int,
        num_stacks: int,
        num_heads: int,
        d_ff: int,
        num_experts: int,
        num_active: int,
        rope_theta: float,
        width_ratio: float,
        device=None, dtype=None):
        super().__init__()
        self.stack_depth = num_stacks
        # Total block applications, used to average the aux losses.
        self.total_layers = num_stacks*num_layers_in_stack
        self.token_embeddings = Embedding(vocab_size, d_model, device=device, dtype=dtype)
        self.stack = LoopedStack(context_length, d_model, num_layers_in_stack, num_heads,
                                 d_ff, rope_theta, width_ratio, mixture_of_experts=True,
                                 num_experts=num_experts, num_active=num_active,
                                 device=device, dtype=dtype)
        self.ln_final = RMSNorm(d_model, device=device, dtype=dtype)
        self.lm_head = Linear(d_model, vocab_size, width_ratio=width_ratio,
                              std_base=math.sqrt(2/(BASE_D_MODEL+vocab_size)),
                              device=device, dtype=dtype)

    def forward(self, x: Tensor) -> Tensor:
        lb_total, lz_total = 0, 0
        hidden = self.token_embeddings(x)
        # Re-apply the same shared-weight MoE stack, accumulating aux losses.
        for _ in range(self.stack_depth):
            hidden, lb, lz = self.stack(hidden)
            lb_total = lb_total + lb
            lz_total = lz_total + lz
        logits = self.lm_head(self.ln_final(hidden))
        return logits, lb_total / self.total_layers, lz_total / self.total_layers
|
|
|
|
| |
| |
| |
|
|
class LoopLMConfig(PretrainedConfig):
    """Config for all four loop-lm model variants.

    ``model_variant`` selects which inner module LoopLMForCausalLM builds
    ("base", "looped", "moe", "looped-moe"); the remaining fields are
    grouped below by the variants that consume them.
    """

    model_type = "loop-lm"

    def __init__(
        self,
        # architecture selector
        model_variant: str = "base",
        # shared hyperparameters (all variants)
        vocab_size: int = 50257,
        context_length: int = 1024,
        d_model: int = 1024,
        num_heads: int = 16,
        d_ff: int = 2752,
        rope_theta: float = 10000.0,
        width_ratio: float = 8.0,
        # non-looped variants ("base", "moe")
        num_layers: int = 16,
        # dense variants ("base", "looped")
        weight_tying: bool = False,
        # looped variants ("looped", "looped-moe")
        num_layers_in_stack: int = 8,
        num_stacks: int = 2,
        # MoE variants ("moe", "looped-moe")
        num_experts: int = 8,
        num_active: int = 2,
        # aux-loss weights (applied to MoE losses during training)
        lb_loss_factor: float = 0.01,
        lz_loss_factor: float = 0.001,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.model_variant = model_variant
        self.vocab_size = vocab_size
        self.context_length = context_length
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_ff = d_ff
        self.rope_theta = rope_theta
        self.width_ratio = width_ratio
        self.num_layers = num_layers
        self.weight_tying = weight_tying
        self.num_layers_in_stack = num_layers_in_stack
        self.num_stacks = num_stacks
        self.num_experts = num_experts
        self.num_active = num_active
        self.lb_loss_factor = lb_loss_factor
        self.lz_loss_factor = lz_loss_factor
        # Mirror context_length into the generation-facing max_length field.
        self.max_length = context_length
|
|
|
|
class LoopLMForCausalLM(PreTrainedModel, GenerationMixin):
    """Causal LM wrapper over all four looped-scaling variants.

    Implements the HuggingFace PreTrainedModel interface so you can:
    - Upload/download via push_to_hub / from_pretrained
    - Run lm-evaluation-harness evals
    - Fine-tune with TRL's SFTTrainer / DPOTrainer
    """

    config_class = LoopLMConfig
    # Inner modules are built fresh from config; nothing to ignore on load.
    _keys_to_ignore_on_load_missing = []

    def __init__(self, config: LoopLMConfig):
        super().__init__(config)
        self.model = self._build_inner_model(config)
        self.post_init()

    def _build_inner_model(self, config: LoopLMConfig):
        """Instantiate the inner nn.Module selected by config.model_variant."""
        # Arguments common to every variant.
        kw = dict(
            vocab_size=config.vocab_size,
            context_length=config.context_length,
            d_model=config.d_model,
            num_heads=config.num_heads,
            d_ff=config.d_ff,
            rope_theta=config.rope_theta,
            width_ratio=config.width_ratio,
        )
        v = config.model_variant
        if v == "base":
            return MuTransformer(
                **kw,
                num_layers=config.num_layers,
                weight_tying=config.weight_tying,
            )
        elif v == "looped":
            return LoopedTransformer(
                **kw,
                num_layers_in_stack=config.num_layers_in_stack,
                num_stacks=config.num_stacks,
                weight_tying=config.weight_tying,
            )
        elif v == "moe":
            return MoETransformer(
                **kw,
                num_layers=config.num_layers,
                num_experts=config.num_experts,
                num_active=config.num_active,
            )
        elif v == "looped-moe":
            return LoopedMoETransformer(
                **kw,
                num_layers_in_stack=config.num_layers_in_stack,
                num_stacks=config.num_stacks,
                num_experts=config.num_experts,
                num_active=config.num_active,
            )
        else:
            raise ValueError(f"Unknown model_variant: {v!r}. Choose from: base, looped, moe, looped-moe")

    def get_input_embeddings(self):
        return self.model.token_embeddings

    def set_input_embeddings(self, value):
        self.model.token_embeddings = value

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """
        Args:
            input_ids: (batch, seq)
            attention_mask: ignored — models use a built-in causal mask
            labels: (batch, seq) token ids aligned with input_ids (standard
                HF convention, as lm-eval-harness / TRL pass them); shifted
                internally so position t predicts token t+1. -100 entries
                are ignored by the loss. For MoE variants, aux losses
                (lb + lz) are added to the CE loss during training.
        """
        is_moe = self.config.model_variant in ("moe", "looped-moe")

        if is_moe:
            logits, lb, lz = self.model(input_ids)
        else:
            logits = self.model(input_ids)
            lb = lz = 0.0

        loss = None
        if labels is not None:
            # BUGFIX: shift logits/labels inside the model per the HF causal-LM
            # convention (labels == input_ids from SFTTrainer / lm-eval).
            # Previously logits[t] was compared against labels[t] directly.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            ce_loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )
            # Router aux losses only matter for optimization; reported eval
            # loss stays pure cross-entropy.
            aux = self.config.lb_loss_factor * lb + self.config.lz_loss_factor * lz
            loss = ce_loss + aux if self.training else ce_loss

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
        )

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # No KV cache: recompute over the full sequence each decode step.
        # attention_mask is deliberately dropped (causal mask is built in).
        return {"input_ids": input_ids}
|
|