"""Self-contained modeling file for trust_remote_code use.
This file merges mup_models.py and hf_wrapper.py into a single module with no
imports from looped_scaling.*. It is intended to be placed alongside a
config.json that sets ``auto_map`` / ``model_type = "loop-lm"`` so that
HuggingFace's ``from_pretrained(..., trust_remote_code=True)`` can load it
without requiring the looped_scaling package to be installed.
Supported model variants: "base" (MuTransformer), "looped" (LoopedTransformer),
"moe" (MoETransformer), "looped-moe" (LoopedMoETransformer).
"""
import torch
import math
import sys
import torch.nn as nn
import torch.nn.functional as F
from collections.abc import Callable, Iterable
from einops import rearrange, einsum, reduce, repeat
from typing import IO, Any, BinaryIO, Optional
from torch import Tensor
from collections import Counter, defaultdict
from torch.nn.functional import scaled_dot_product_attention as sdpa # for flash attention
from torch.nn.functional import grouped_mm, silu
from transformers import PretrainedConfig, PreTrainedModel, AutoConfig, AutoModelForCausalLM
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
BASE_D_MODEL = 128
BASE_D_FF = 384
""" Standard Transformer and Components implemented with muP """
# ---------------------------------------------------------------------------
# Numerically stable softmax (inlined from looped_scaling/utils.py)
# ---------------------------------------------------------------------------
def softmax(logits: Tensor, dim: int) -> Tensor:
    """Numerically stable softmax over `dim`, computed in float32.

    The input is upcast to float32 and the per-slice maximum is subtracted
    before exponentiation so large logits cannot overflow. The result stays
    in float32; callers handle any downcast.
    """
    shifted = logits.float()
    shifted = shifted - shifted.max(dim=dim, keepdim=True).values
    exps = shifted.exp()
    return exps / exps.sum(dim=dim, keepdim=True)
# y = Wx (no bias terms!)
class Linear(nn.Module):
    """Bias-free linear layer (y = W x) with muP-scaled initialization."""

    def __init__(self, in_features, out_features, width_ratio, std_base, device=None, dtype=None):
        super().__init__()
        # Allocate the parameter before init so its shape is always recorded
        # (required for HF meta-device loading).
        self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype, device=device))
        # muP: shrink the base model's init std by sqrt(width_ratio).
        std = std_base / math.sqrt(width_ratio)
        nn.init.trunc_normal_(self.weight, mean=0.0, std=std, a=-3 * std, b=3 * std)

    def forward(self, x: Tensor) -> Tensor:
        # Contract x's trailing input dim against W^T so d_out lands last,
        # matching the PyTorch convention for linear layers.
        return x @ self.weight.t()
class Embedding(nn.Module):
    """Token-embedding table; rows are looked up by integer token id."""

    def __init__(self, num_embeddings, embedding_dim, device=None, dtype=None):
        super().__init__()
        # Allocate the parameter before init so its shape is always recorded
        # (required for HF meta-device loading).
        self.weight = nn.Parameter(torch.empty(num_embeddings, embedding_dim, dtype=dtype, device=device))
        # Unit-variance truncated-normal init, clipped at +/- 3 std.
        nn.init.trunc_normal_(self.weight, mean=0.0, std=1.0, a=-3, b=3)

    def forward(self, token_ids: Tensor) -> Tensor:
        """Gather the embedding row for each id; output is token_ids' shape + (dim,)."""
        return self.weight[token_ids]
class RMSNorm(nn.Module):
    """Gain-free RMS normalization (muP keeps no learned scale parameter)."""

    def __init__(self, d_model: int, eps: float = 1e-5, device=None, dtype=None):
        super().__init__()
        # No weight/gain parameter on purpose — muP variant of RMSNorm.
        self.d_model = d_model
        self.eps = eps

    def forward(self, x: Tensor) -> Tensor:
        # Normalize in float32 for stability, then restore the input dtype.
        orig_dtype = x.dtype
        x32 = x.to(torch.float32)
        # Root-mean-square over the feature dim: one scalar per position.
        rms = torch.sqrt(x32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (x32 / rms).to(orig_dtype)
class PositionwiseFeedforward(nn.Module):
    """SwiGLU feed-forward block: W2(SiLU(W1 x) * W3 x), muP-initialized."""

    def __init__(self, d_model: int, d_ff: int, width_ratio: float, device=None, dtype=None):
        super().__init__()
        # Base-model init std; symmetric in (d_model, d_ff), so one value
        # serves all three projection matrices.
        std_base = math.sqrt(2 / (BASE_D_MODEL + BASE_D_FF))
        self.w1 = Linear(d_model, d_ff, width_ratio, std_base, device=device, dtype=dtype)
        self.w2 = Linear(d_ff, d_model, width_ratio, std_base, device=device, dtype=dtype)
        self.w3 = Linear(d_model, d_ff, width_ratio, std_base, device=device, dtype=dtype)

    def forward(self, x: Tensor) -> Tensor:
        """Project up twice, gate with SiLU, project back down."""
        return self.w2(silu(self.w1(x)) * self.w3(x))
class RotaryPositionalEmbedding(nn.Module):
    """Rotary positional embedding with a precomputed rotation-matrix bank.

    Args:
        theta: RoPE base Θ controlling the per-pair rotation frequencies.
        d_k: dimension of the query/key vectors (must be even).
        max_seq_len: maximum sequence length to precompute rotations for.
        device / dtype: placement of the cached ``rotations`` buffer.
    """

    def __init__(self, theta: float, d_k: int, max_seq_len: int, device=None, dtype=None):
        super().__init__()
        # Vectorized build of the (max_seq_len, d_k//2, 2, 2) rotation bank.
        # The previous per-(position, pair) Python loop performed
        # O(max_seq_len * d_k / 2) individual CPU tensor constructions and
        # dominated init time for long contexts; this computes the same
        # angles angle[i, k] = i / theta**(2k / d_k) in one shot.
        positions = torch.arange(max_seq_len, device=device, dtype=torch.float32)
        inv_freq = theta ** (-torch.arange(0, d_k, 2, device=device, dtype=torch.float32) / d_k)
        angles = torch.outer(positions, inv_freq)  # (max_seq_len, d_k//2)
        cos, sin = torch.cos(angles), torch.sin(angles)
        # Row-major 2x2 blocks: [[cos, -sin], [sin, cos]].
        rotations = torch.stack(
            (torch.stack((cos, -sin), dim=-1), torch.stack((sin, cos), dim=-1)),
            dim=-2,
        )
        if dtype is not None:
            rotations = rotations.to(dtype)
        self.register_buffer("rotations", rotations, persistent=True)

    def forward(self, x: Tensor, token_positions: Tensor) -> Tensor:
        """Rotate consecutive feature pairs of ``x`` by their position's angle.

        x: (..., seq, d_k); token_positions: (..., seq) indices into the bank.
        """
        # Select per-position matrices and match the activation dtype
        # (buffer may be float32 while activations are bfloat16).
        rot = self.rotations[token_positions].to(dtype=x.dtype)  # (..., seq, d_k//2, 2, 2)
        # View features as (d_k//2, 2) consecutive pairs and apply each 2x2
        # rotation: y_pair = R @ x_pair.
        x_pairs = x.reshape(*x.shape[:-1], -1, 2)
        y_pairs = (rot @ x_pairs.unsqueeze(-1)).squeeze(-1)
        return y_pairs.reshape(x.shape)
def scaled_dot_product_attention(
    Q: Tensor,
    K: Tensor,
    V: Tensor,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """muP-scaled dot-product attention.

    Scores are scaled by 1/d_k (muP convention) rather than 1/sqrt(d_k).

    Args:
        Q: (..., n, d_k) query tensor.
        K: (..., m, d_k) key tensor.
        V: (..., m, d_v) value tensor.
        mask: optional (..., n, m); truthy entries are attended, falsy
            entries are blocked with -inf before the softmax. Accepts bool
            or float masks.

    Returns:
        (..., n, d_v) attention output.
    """
    d_k = Q.shape[-1]
    assert d_k == K.shape[-1]
    # Similarity scores with the muP 1/d_k scaling.
    scores = (Q @ K.transpose(-2, -1)) / d_k
    if mask is not None:
        keep = mask.bool()
        scores = scores + torch.where(keep, 0.0, float('-inf')).to(scores.dtype)
    # Softmax over the m key positions for every query position.
    weights = softmax(scores, dim=-1)
    return weights @ V
class MultiheadSelfAttention(nn.Module):
    """Multi-head causal self-attention with optional RoPE, muP-initialized.

    Args:
        d_model: model width; also the total width across all heads.
        num_heads: number of attention heads (must divide d_model).
        max_seq_len: longest sequence the RoPE table is precomputed for.
        theta: RoPE base; falsy disables rotary embeddings.
        width_ratio: muP width multiplier used to scale projection inits.
    """

    def __init__(self, d_model: int, num_heads: int, max_seq_len: int = None, theta: float = None, width_ratio: float = 1.0, device=None, dtype=None):
        super().__init__()
        assert d_model % num_heads == 0, f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
        self.d_model = d_model
        self.num_heads = num_heads
        # muP: all four projections share the base model's init std.
        proj_std_base = math.sqrt(2/(BASE_D_MODEL+BASE_D_MODEL))
        self.q_proj = Linear(d_model, d_model, width_ratio, proj_std_base, device=device, dtype=dtype)
        self.k_proj = Linear(d_model, d_model, width_ratio, proj_std_base, device=device, dtype=dtype)
        self.v_proj = Linear(d_model, d_model, width_ratio, proj_std_base, device=device, dtype=dtype)
        self.output_proj = Linear(d_model, d_model, width_ratio, proj_std_base, device=device, dtype=dtype)
        assert theta is None or max_seq_len is not None, "max_seq_len must be provided when theta is given for multi-head self attention with RoPE."
        if theta:
            self.rope = RotaryPositionalEmbedding(theta, d_model // num_heads, max_seq_len, device, dtype)
        else:
            self.rope = None

    def forward(self, x: Tensor, token_positions: Optional[Tensor] = None) -> Tensor:
        head_dim = self.d_model // self.num_heads

        def split_heads(t: Tensor) -> Tensor:
            # (..., seq, heads*hd) -> (..., heads, seq, hd)
            t = t.reshape(*t.shape[:-1], self.num_heads, head_dim)
            return t.transpose(-3, -2)

        q = split_heads(self.q_proj(x))
        k = split_heads(self.k_proj(x))
        v = split_heads(self.v_proj(x))
        # Rotate queries/keys when RoPE is enabled.
        if self.rope:
            if token_positions is None:
                seq_len = x.shape[-2]  # x is (batch, seq, d_model)
                token_positions = torch.arange(seq_len, device=x.device)
                # Leading singleton dim broadcasts positions over the batch.
                token_positions = token_positions.reshape(1, -1)
            q = self.rope(q, token_positions)
            k = self.rope(k, token_positions)
        # Flash-attention path; muP uses a 1/d_k scale instead of 1/sqrt(d_k).
        attn = sdpa(q, k, v, is_causal=True, scale=1.0/head_dim)
        # (..., heads, seq, hd) -> (..., seq, heads*hd)
        merged = attn.transpose(-3, -2).reshape(*x.shape[:-1], self.d_model)
        return self.output_proj(merged)
class PrenormBlock(nn.Module):
    """Pre-norm transformer block: RMSNorm→attention residual, then
    RMSNorm→SwiGLU residual."""

    def __init__(self,
                 d_model: int,
                 num_heads: int,
                 d_ff: int,
                 max_seq_len: int,
                 theta: float,
                 width_ratio: float,
                 device=None,
                 dtype=None):
        super().__init__()
        # Attention sub-block: pre-norm + MHSA with RoPE.
        self.ln1 = RMSNorm(d_model, device=device, dtype=dtype)
        self.attn = MultiheadSelfAttention(d_model, num_heads, max_seq_len, theta, width_ratio, device, dtype)
        # Feed-forward sub-block: pre-norm + SwiGLU FFN.
        self.ln2 = RMSNorm(d_model, device=device, dtype=dtype)
        self.ffn = PositionwiseFeedforward(d_model, d_ff, width_ratio, device, dtype)

    def forward(self, x: Tensor, token_positions: Optional[Tensor] = None) -> Tensor:
        # Attention with residual add; shapes must match exactly so the
        # addition cannot silently broadcast.
        attn_out = self.attn(self.ln1(x), token_positions)
        assert(x.shape == attn_out.shape)
        hidden = x + attn_out
        # Feed-forward with residual add, same pattern.
        ffn_out = self.ffn(self.ln2(hidden))
        assert(ffn_out.shape == hidden.shape)
        return hidden + ffn_out
class MuTransformer(nn.Module):
    """Standard decoder-only transformer with muP init/scaling."""

    def __init__(
            self, vocab_size: int,
            context_length: int,
            d_model: int,
            num_layers: int,
            num_heads: int,
            d_ff: int,
            rope_theta: float,
            width_ratio: float = 1.0,
            weight_tying: bool = False,
            device=None, dtype=None):
        super().__init__()
        self.token_embeddings = Embedding(vocab_size, d_model, device=device, dtype=dtype)
        self.layers = nn.ModuleList(
            PrenormBlock(d_model, num_heads, d_ff, context_length, rope_theta, width_ratio, device, dtype)
            for _ in range(num_layers)
        )
        self.ln_final = RMSNorm(d_model, device=device, dtype=dtype)
        self.weight_tying = weight_tying
        if weight_tying:
            # Tied head: the vocab projection reuses the embedding matrix.
            self.lm_head = self.token_embeddings.weight
        else:
            head_std_base = math.sqrt(2/(BASE_D_MODEL+vocab_size))
            self.lm_head = Linear(d_model, vocab_size, width_ratio=width_ratio, std_base=head_std_base, device=device, dtype=dtype)
        self.width_ratio = width_ratio

    def forward(self, x: Tensor) -> Tensor:
        # Embed tokens (muP applies no input multiplier here).
        h = self.token_embeddings(x)
        for block in self.layers:
            h = block(h)
        h = self.ln_final(h)
        # Unembed. The tied head applies the muP 1/width_ratio correction
        # explicitly; the untied head bakes it into its init/lr instead.
        if self.weight_tying:
            return (h @ self.lm_head.t()) / self.width_ratio
        return self.lm_head(h)
""" Looped Language Models implemented with MuP """
class LoopedStack(nn.Module):
    """A reusable stack of prenorm blocks (dense or MoE) for looped models.

    forward returns a plain tensor for dense stacks, or a
    (tensor, load_balance_loss, z_loss) triple for MoE stacks.
    """

    def __init__(
            self,
            context_length: int,
            d_model: int,
            num_layers_in_stack: int,
            num_heads: int,
            d_ff: int,
            rope_theta: float,
            width_ratio: float = 1.0,
            mixture_of_experts: bool = False,
            num_experts: Optional[int] = None,
            num_active: Optional[int] = None,
            device=None, dtype=None):
        super().__init__()
        if mixture_of_experts:
            # Grouped-GEMM MoE blocks.
            blocks = [GroupedMoEPrenormBlock(d_model, num_heads, d_ff, num_experts, num_active,
                                             context_length, rope_theta, width_ratio, device, dtype)
                      for _ in range(num_layers_in_stack)]
        else:
            # Dense SwiGLU blocks.
            blocks = [PrenormBlock(d_model, num_heads, d_ff, context_length, rope_theta,
                                   width_ratio, device, dtype)
                      for _ in range(num_layers_in_stack)]
        self.layers = nn.ModuleList(blocks)
        self.mixture_of_experts = mixture_of_experts

    def forward(self, x: Tensor) -> Tensor:
        if not self.mixture_of_experts:
            for block in self.layers:
                x = block(x)
            return x
        # MoE path: accumulate per-layer aux losses alongside activations.
        lb_total = 0
        lz_total = 0
        for block in self.layers:
            x, lb, lz = block(x)
            lb_total += lb
            lz_total += lz
        return x, lb_total, lz_total
class LoopedTransformer(nn.Module):
    """Transformer that reuses one dense block stack `num_stacks` times."""

    def __init__(
            self,
            vocab_size: int,
            context_length: int,
            d_model: int,
            num_layers_in_stack: int,
            num_stacks: int,
            num_heads: int,
            d_ff: int,
            rope_theta: float,
            width_ratio: float = 1.0,
            weight_tying: bool = False,
            device=None, dtype=None):
        super().__init__()
        self.num_stacks = num_stacks
        self.token_embeddings = Embedding(vocab_size, d_model, device=device, dtype=dtype)
        # A single stack whose weights are shared across all loop iterations.
        self.stack = LoopedStack(context_length, d_model, num_layers_in_stack, num_heads, d_ff, rope_theta, width_ratio, device=device, dtype=dtype)
        self.ln_final = RMSNorm(d_model, device=device, dtype=dtype)
        self.weight_tying = weight_tying
        self.width_ratio = width_ratio
        if weight_tying:
            # Tied head: the vocab projection reuses the embedding matrix.
            self.lm_head = self.token_embeddings.weight
        else:
            head_std_base = math.sqrt(2/(BASE_D_MODEL+vocab_size))
            self.lm_head = Linear(d_model, vocab_size, width_ratio, head_std_base, device=device, dtype=dtype)

    def forward(self, x: Tensor) -> Tensor:
        h = self.token_embeddings(x)
        # Looped depth: apply the shared stack repeatedly.
        for _ in range(self.num_stacks):
            h = self.stack(h)
        h = self.ln_final(h)
        # Tied head applies the muP 1/width_ratio correction explicitly.
        if self.weight_tying:
            return (h @ self.lm_head.t()) / self.width_ratio
        return self.lm_head(h)
""" Mixture-of-Experts Implementation in muP """
# Router Class
class Router(nn.Module):
    """Top-k softmax router over experts, with a muP-scaled linear gate."""

    def __init__(self, d_model: int, num_experts: int, num_active=None, width_ratio: float = 1.0, device=None, dtype=None):
        super().__init__()
        # The gate is a single muP-initialized projection to expert logits.
        gate_std_base = math.sqrt(2/(BASE_D_MODEL+num_experts))
        self.gate = Linear(d_model, num_experts, width_ratio, gate_std_base, device=device, dtype=dtype)
        self.num_active = num_active

    def forward(self, x: Tensor):
        """Return (logits, probs, renormalized top-k scores, top-k expert ids)."""
        logits = self.gate(x)  # (batch, seq, num_experts)
        probs = softmax(logits, dim=-1)
        top_scores, top_experts = torch.topk(probs, k=self.num_active, dim=-1)
        # Renormalize so each token's selected expert weights sum to one,
        # allowing a weighted sum of expert outputs.
        top_scores = top_scores / top_scores.sum(dim=-1, keepdim=True)
        return logits, probs, top_scores, top_experts
class MoEPrenormBlock(nn.Module):
    """Pre-norm transformer block with a token-choice mixture-of-experts FFN.

    The attention sub-block matches PrenormBlock; the dense SwiGLU FFN is
    replaced by `num_experts` SwiGLU experts of hidden width
    d_ff // num_active, of which each token uses its top `num_active`.
    forward additionally returns the batch's load-balancing loss and router
    z-loss.
    """
    def __init__(self, d_model: int, num_heads: int, d_ff: int, num_experts: int, num_active: int,
                 max_seq_len: int, theta: float, width_ratio: float = 1.0, device=None, dtype=None):
        super().__init__()
        # Pre-norm before attention.
        self.ln1 = RMSNorm(d_model, device=device, dtype=dtype)
        # Multi-head self attention with RoPE.
        self.attn = MultiheadSelfAttention(d_model, num_heads, max_seq_len, theta, width_ratio, device, dtype)
        # Pre-norm before the MoE feed-forward.
        self.ln2 = RMSNorm(d_model, device=device, dtype=dtype)
        # Top-k softmax router over experts.
        self.router = Router(d_model, num_experts, num_active, width_ratio=width_ratio, device=device, dtype=dtype)
        # MoE hyperparameters.
        self.num_experts = num_experts
        self.num_active = num_active
        # Each expert is narrower so active compute matches a dense d_ff FFN.
        d_ff_expert = d_ff // num_active
        self.experts = nn.ModuleList([PositionwiseFeedforward(d_model, d_ff_expert, width_ratio, device, dtype) for _ in range(num_experts)])  # muP-initialized experts
    def forward(self, x: Tensor, token_positions: Optional[Tensor] = None) -> "tuple[Tensor, float, Tensor]":
        """Run attention + MoE FFN; return (output, lb_loss, lz_loss)."""
        batch, seq, dim = x.shape
        # Attention sub-block: pre-norm, MHSA with RoPE, residual add.
        norm1_out = self.ln1(x)
        attn_out = self.attn(norm1_out, token_positions)
        # Shapes must match exactly so the residual add cannot broadcast.
        assert(x.shape == attn_out.shape)
        resid1_out = attn_out + x
        # Pre-norm before the MoE feed-forward.
        norm2_out = self.ln2(resid1_out)
        # Router outputs: logits/probs are (batch, seq, num_experts);
        # top_scores/top_experts are (batch, seq, num_active).
        logits, probs, top_scores, top_experts = self.router(norm2_out)
        # Mean routing probability per expert (the "p_i" term of the LB loss).
        expert_mean_probs = torch.mean(probs, dim=(0, 1))
        # Accumulate expert outputs into a zero tensor shaped like the input
        # (same shape, device, and dtype).
        experts_out = torch.zeros_like(norm2_out)
        total_tokens_assigned = batch*seq*self.num_active
        lb_sum = 0
        for expert_idx in range(self.num_experts):
            # (batch, seq, k) mask of routing slots assigned to this expert.
            expert_mask = (top_experts == expert_idx)
            # (batch, seq) mask of tokens using this expert in any slot.
            embed_mask = expert_mask.any(dim=-1)
            if not embed_mask.any(): continue
            pi = expert_mean_probs[expert_idx].item()
            # Fraction of all routed slots assigned to this expert ("f_i").
            fi = (expert_mask.sum().item())/total_tokens_assigned
            lb_sum += fi*pi
            # Gather router weights and embeddings for this expert's tokens.
            # topk indices are distinct per token, so each token contributes
            # at most one slot here and the two gathers stay aligned.
            weights = top_scores[expert_mask] # (num_selected,)
            expert_embeds = norm2_out[embed_mask] # (num_selected, d_model)
            # Dense per-expert forward (vanilla, non-grouped implementation).
            expert_out = self.experts[expert_idx](expert_embeds)
            # Scatter the score-weighted expert output back to its tokens.
            experts_out[embed_mask] += weights.unsqueeze(-1)*expert_out
        # Load-balancing loss: N * sum_i f_i * p_i.
        lb = self.num_experts*lb_sum
        # Router z-loss: mean squared logsumexp of the gate logits.
        logsumexp = torch.logsumexp(logits.float(), dim=-1)
        lz = torch.mean(logsumexp ** 2)
        # Shapes must match exactly so the residual add cannot broadcast.
        assert(experts_out.shape == resid1_out.shape)
        final_out = resid1_out + experts_out
        return final_out, lb, lz
class GroupedMoEPrenormBlock(nn.Module):
    """Pre-norm MoE transformer block using a grouped-GEMM expert forward.

    Routing semantics match MoEPrenormBlock, but all experts' SwiGLU weights
    are stored as stacked (num_experts, in, out) tensors and applied with
    ``grouped_mm`` over expert-sorted tokens instead of a Python loop over
    expert modules. forward returns (output, lb_loss, lz_loss).
    """
    @staticmethod
    def _init_expert_weights(num_experts, in_features, out_features, width_ratio, std_base, device, dtype) -> nn.Parameter:
        # Stacked per-expert weights laid out (expert, in, out) so each token
        # group is multiplied as x @ W (note: transposed vs. Linear's layout).
        w = torch.empty(num_experts, in_features, out_features, device=device, dtype=dtype) # (expert, in, out)
        # muP: shrink the base init std by sqrt(width_ratio).
        std_scaled = std_base / math.sqrt(width_ratio)
        nn.init.trunc_normal_(w, mean=0.0, std=std_scaled, a=-3*std_scaled, b=3*std_scaled)
        return nn.Parameter(w)
    def __init__(self, d_model: int, num_heads: int, d_ff: int, num_experts: int, num_active: int,
                 max_seq_len: int, theta: float, width_ratio: float = 1.0, device=None, dtype=None):
        super().__init__()
        # Pre-norm before attention.
        self.ln1 = RMSNorm(d_model, device=device, dtype=dtype)
        # Multi-head self attention with RoPE.
        self.attn = MultiheadSelfAttention(d_model, num_heads, max_seq_len, theta, width_ratio, device, dtype)
        # Pre-norm before the MoE feed-forward.
        self.ln2 = RMSNorm(d_model, device=device, dtype=dtype)
        # Top-k softmax router over experts.
        self.router = Router(d_model, num_experts, num_active, width_ratio=width_ratio, device=device, dtype=dtype)
        # MoE hyperparameters.
        self.num_experts = num_experts
        self.num_active = num_active
        # Each expert is narrower so active compute matches a dense d_ff FFN.
        d_ff_expert = d_ff // num_active
        # Stack all experts' SwiGLU weights; with the experts exposed as plain
        # parameters the optimizer can apply the muP width_ratio scaling.
        w_std_base = math.sqrt(2 / (BASE_D_MODEL + BASE_D_FF))
        self.experts_w1 = self._init_expert_weights(num_experts, d_model, d_ff_expert, width_ratio, w_std_base, device, dtype)
        self.experts_w2 = self._init_expert_weights(num_experts, d_ff_expert, d_model, width_ratio, w_std_base, device, dtype)
        self.experts_w3 = self._init_expert_weights(num_experts, d_model, d_ff_expert, width_ratio, w_std_base, device, dtype)
    def forward(self, x: Tensor, token_positions: Optional[Tensor] = None) -> "tuple[Tensor, Tensor, Tensor]":
        """Run attention + grouped MoE FFN; return (output, lb_loss, lz_loss)."""
        batch, seq, dim = x.shape
        total_tokens = batch * seq
        # Attention sub-block: pre-norm, MHSA with RoPE, residual add.
        norm1_out = self.ln1(x)
        attn_out = self.attn(norm1_out, token_positions)
        # Shapes must match exactly so the residual add cannot broadcast.
        assert(x.shape == attn_out.shape)
        resid1_out = attn_out + x
        # Pre-norm before the MoE feed-forward.
        norm2_out = self.ln2(resid1_out)
        # Router outputs: logits/probs are (batch, seq, num_experts);
        # top_scores/top_experts are (batch, seq, num_active).
        logits, probs, top_scores, top_experts = self.router(norm2_out)
        # Flatten (batch, seq) into a single token axis for grouped_mm.
        x_flat = rearrange(norm2_out, 'b s d -> (b s) d') # (total_tokens, d_model)
        flat_expert_ids = rearrange(top_experts, 'b s k -> (b s k)') # (total_tokens * k,)
        flat_scores = rearrange(top_scores, 'b s k -> (b s k)') # (total_tokens * k,)
        flat_positions = torch.arange(total_tokens, device=x.device) # (total_tokens)
        # Each token id appears once per active routing slot.
        flat_token_ids = repeat(flat_positions, 'n -> (n k)', k=self.num_active) # (total_tokens * k)
        # Stable sort groups all slots of the same expert contiguously while
        # preserving token order within each expert's group.
        sort_indices = flat_expert_ids.argsort(stable=True)
        sorted_expert_ids = flat_expert_ids[sort_indices]
        sorted_token_ids = flat_token_ids[sort_indices]
        sorted_scores = flat_scores[sort_indices]
        sorted_x = x_flat[sorted_token_ids] # (total_tokens * k, d_model)
        # Group boundaries: cumulative per-expert counts mark each expert's
        # segment end within the sorted token list.
        counts = torch.bincount(sorted_expert_ids, minlength=self.num_experts)
        offs = counts.cumsum(0).to(torch.int32)  # (num_experts,)
        # Grouped SwiGLU over expert segments: W2(SiLU(W1 x) * W3 x).
        h1 = grouped_mm(sorted_x, self.experts_w1, offs=offs)
        h3 = grouped_mm(sorted_x, self.experts_w3, offs=offs)
        gated = silu(h1) * h3
        expert_out = grouped_mm(gated, self.experts_w2, offs=offs) # (total_tokens * k, d_model)
        # Weight each slot's output by its router score, then scatter-add all
        # slots back to their source token rows.
        expert_out = einsum(expert_out, sorted_scores, 'n d, n -> n d')
        output_flat = torch.zeros(total_tokens, dim, device=x.device, dtype=expert_out.dtype)
        output_flat.index_add_(0, sorted_token_ids, expert_out)
        # Restore the (batch, seq, d_model) layout.
        experts_out = rearrange(output_flat, '(b s) d -> b s d', b=batch, s=seq)
        # Aux losses: load balancing N * sum_i f_i * p_i, where f_i is the
        # fraction of routed slots per expert and p_i the mean router prob.
        fi = counts.float() / (total_tokens * self.num_active)
        pi = reduce(probs, 'b s e -> e', 'mean')
        lb = self.num_experts * einsum(fi, pi, 'e, e ->')
        # Router z-loss: mean squared logsumexp of the gate logits.
        logsumexp = torch.logsumexp(logits.float(), dim=-1)
        lz = reduce(logsumexp ** 2, '... -> ', 'mean')
        # Residual connection; shapes must match so the add cannot broadcast.
        assert(experts_out.shape == resid1_out.shape)
        final_out = resid1_out + experts_out
        return final_out, lb, lz
# MoE Implementation
class MoETransformer(nn.Module):
    """Decoder-only transformer whose FFNs are grouped mixture-of-experts.

    forward returns (logits, avg_lb_loss, avg_lz_loss), with the auxiliary
    losses averaged over layers.
    """

    def __init__(
            self, vocab_size: int,
            context_length: int,
            d_model: int,
            num_layers: int,
            num_heads: int,
            d_ff: int,
            num_experts: int,
            num_active: int,
            rope_theta: float,
            width_ratio: float = 1.0,
            device=None, dtype=None):
        super().__init__()
        self.token_embeddings = Embedding(vocab_size, d_model, device=device, dtype=dtype)
        self.num_layers = num_layers
        self.layers = nn.ModuleList(
            GroupedMoEPrenormBlock(d_model, num_heads, d_ff, num_experts, num_active,
                                   context_length, rope_theta, width_ratio, device, dtype)
            for _ in range(num_layers)
        )
        self.ln_final = RMSNorm(d_model, device=device, dtype=dtype)
        # Untied output head only, with muP-scaled init.
        head_std_base = math.sqrt(2/(BASE_D_MODEL+vocab_size))
        self.lm_head = Linear(d_model, vocab_size, width_ratio=width_ratio, std_base=head_std_base, device=device, dtype=dtype)

    def forward(self, x: Tensor) -> Tensor:
        h = self.token_embeddings(x)
        # Accumulate per-layer auxiliary losses while running the blocks.
        lb_total = 0
        lz_total = 0
        for block in self.layers:
            h, lb, lz = block(h)
            lb_total += lb
            lz_total += lz
        h = self.ln_final(h)
        logits = self.lm_head(h)
        # Report per-layer averages of the auxiliary losses.
        return logits, lb_total / self.num_layers, lz_total / self.num_layers
class LoopedMoETransformer(nn.Module):
    """Looped transformer whose shared stack uses mixture-of-experts FFNs.

    forward returns (logits, avg_lb_loss, avg_lz_loss), with the auxiliary
    losses averaged over num_stacks * num_layers_in_stack effective layers.
    """

    def __init__(
            self, vocab_size: int,
            context_length: int,
            d_model: int,
            num_layers_in_stack: int,
            num_stacks: int,
            num_heads: int,
            d_ff: int,
            num_experts: int,
            num_active: int,
            rope_theta: float,
            width_ratio: float,
            device=None, dtype=None):
        super().__init__()
        self.stack_depth = num_stacks
        self.total_layers = num_stacks*num_layers_in_stack
        self.token_embeddings = Embedding(vocab_size, d_model, device=device, dtype=dtype)
        # A single MoE stack shared across every loop iteration.
        self.stack = LoopedStack(context_length, d_model, num_layers_in_stack, num_heads,
                                 d_ff, rope_theta, width_ratio, mixture_of_experts=True,
                                 num_experts=num_experts, num_active=num_active,
                                 device=device, dtype=dtype)
        self.ln_final = RMSNorm(d_model, device=device, dtype=dtype)
        # Untied output head with muP-scaled init.
        head_std_base = math.sqrt(2/(BASE_D_MODEL+vocab_size))
        self.lm_head = Linear(d_model, vocab_size, width_ratio=width_ratio, std_base=head_std_base, device=device, dtype=dtype)

    def forward(self, x: Tensor) -> Tensor:
        h = self.token_embeddings(x)
        # Looped depth: reuse the shared stack, accumulating aux losses.
        lb_total = 0
        lz_total = 0
        for _ in range(self.stack_depth):
            h, lb, lz = self.stack(h)
            lb_total += lb
            lz_total += lz
        h = self.ln_final(h)
        logits = self.lm_head(h)
        # Average aux losses over the total effective layer count.
        return logits, lb_total / self.total_layers, lz_total / self.total_layers
# ---------------------------------------------------------------------------
# HuggingFace wrapper (from hf_wrapper.py)
# ---------------------------------------------------------------------------
class LoopLMConfig(PretrainedConfig):
    """Config for all four loop-lm model variants.

    ``model_variant`` selects the architecture; the remaining fields are
    grouped below by which variants actually consume them.
    """

    model_type = "loop-lm"

    def __init__(
        self,
        # which of the four architectures to use
        model_variant: str = "base",  # "base" | "looped" | "moe" | "looped-moe"
        # shared
        vocab_size: int = 50257,
        context_length: int = 1024,
        d_model: int = 1024,
        num_heads: int = 16,
        d_ff: int = 2752,
        rope_theta: float = 10000.0,
        width_ratio: float = 8.0,  # d_model / base_d_model (128); set at training time
        # base + moe only
        num_layers: int = 16,
        # base + looped only
        weight_tying: bool = False,
        # looped + looped-moe only
        num_layers_in_stack: int = 8,
        num_stacks: int = 2,
        # moe + looped-moe only
        num_experts: int = 8,
        num_active: int = 2,
        # aux loss weights — used when forward() is called with labels
        lb_loss_factor: float = 0.01,
        lz_loss_factor: float = 0.001,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Mirror every constructor argument onto the config instance, in the
        # same order as the signature.
        for field, value in (
            ("model_variant", model_variant),
            ("vocab_size", vocab_size),
            ("context_length", context_length),
            ("d_model", d_model),
            ("num_heads", num_heads),
            ("d_ff", d_ff),
            ("rope_theta", rope_theta),
            ("width_ratio", width_ratio),
            ("num_layers", num_layers),
            ("weight_tying", weight_tying),
            ("num_layers_in_stack", num_layers_in_stack),
            ("num_stacks", num_stacks),
            ("num_experts", num_experts),
            ("num_active", num_active),
            ("lb_loss_factor", lb_loss_factor),
            ("lz_loss_factor", lz_loss_factor),
        ):
            setattr(self, field, value)
        # lm-evaluation-harness looks for this attribute to cap sequence length
        self.max_length = context_length
class LoopLMForCausalLM(PreTrainedModel, GenerationMixin):
    """Causal LM wrapper over all four looped-scaling variants.

    Implements the HuggingFace PreTrainedModel interface so you can:
      - Upload/download via push_to_hub / from_pretrained
      - Run lm-evaluation-harness evals
      - Fine-tune with TRL's SFTTrainer / DPOTrainer
    """

    config_class = LoopLMConfig
    # tell HF which parameter holds the output logits for generation
    _keys_to_ignore_on_load_missing = []

    def __init__(self, config: LoopLMConfig):
        super().__init__(config)
        # Inner model is one of the plain nn.Module variants defined above.
        self.model = self._build_inner_model(config)
        self.post_init()

    # ------------------------------------------------------------------
    # Model construction
    # ------------------------------------------------------------------
    def _build_inner_model(self, config: LoopLMConfig):
        """Instantiate the (non-HF) inner model selected by ``config.model_variant``."""
        kw = dict(
            vocab_size=config.vocab_size,
            context_length=config.context_length,
            d_model=config.d_model,
            num_heads=config.num_heads,
            d_ff=config.d_ff,
            rope_theta=config.rope_theta,
            width_ratio=config.width_ratio,
            # device=None so weights are placed on CPU; caller uses .to(device)
        )
        v = config.model_variant
        if v == "base":
            return MuTransformer(
                **kw,
                num_layers=config.num_layers,
                weight_tying=config.weight_tying,
            )
        elif v == "looped":
            return LoopedTransformer(
                **kw,
                num_layers_in_stack=config.num_layers_in_stack,
                num_stacks=config.num_stacks,
                weight_tying=config.weight_tying,
            )
        elif v == "moe":
            return MoETransformer(
                **kw,
                num_layers=config.num_layers,
                num_experts=config.num_experts,
                num_active=config.num_active,
            )
        elif v == "looped-moe":
            return LoopedMoETransformer(
                **kw,
                num_layers_in_stack=config.num_layers_in_stack,
                num_stacks=config.num_stacks,
                num_experts=config.num_experts,
                num_active=config.num_active,
            )
        else:
            raise ValueError(f"Unknown model_variant: {v!r}. Choose from: base, looped, moe, looped-moe")

    # ------------------------------------------------------------------
    # Embedding access (required by some HF utilities)
    # ------------------------------------------------------------------
    def get_input_embeddings(self):
        return self.model.token_embeddings

    def set_input_embeddings(self, value):
        self.model.token_embeddings = value

    def get_output_embeddings(self):
        # Lets HF utilities (e.g. weight-tying checks, resize helpers) find
        # the output projection.
        return self.model.lm_head

    # ------------------------------------------------------------------
    # Forward
    # ------------------------------------------------------------------
    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,  # causal mask is handled internally
        labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """
        Args:
            input_ids: (batch, seq)
            attention_mask: ignored — models use a built-in causal mask
            labels: (batch, seq) token ids in the HF convention: *unshifted*
                (typically equal to input_ids), with -100 marking positions to
                ignore. If provided, returns cross-entropy loss; for MoE
                variants, aux losses (lb + lz) are added during training.
        """
        is_moe = self.config.model_variant in ("moe", "looped-moe")
        if is_moe:
            logits, lb, lz = self.model(input_ids)
        else:
            logits = self.model(input_ids)
            lb = lz = 0.0
        loss = None
        if labels is not None:
            # HF causal-LM convention: logits at position t predict token t+1,
            # so drop the last logit and the first label before the CE loss.
            # (Without this shift, trainers that pass labels=input_ids — e.g.
            # TRL's SFTTrainer — would train the model to copy its input.)
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            ce_loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,  # HF masking convention for padding/prompt tokens
            )
            aux = self.config.lb_loss_factor * lb + self.config.lz_loss_factor * lz
            # Aux losses only shape training; report pure CE at eval time.
            loss = (ce_loss + aux) if self.training else ce_loss
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
        )

    # ------------------------------------------------------------------
    # Generation support (no KV cache — generation is correct but slow)
    # ------------------------------------------------------------------
    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # No cache: the whole prefix is re-fed to the model at every step.
        return {"input_ids": input_ids}
|