"""Self-contained modeling file for trust_remote_code use.

This file merges mup_models.py and hf_wrapper.py into a single module with no
imports from looped_scaling.*. It is intended to be placed alongside a
config.json that sets ``auto_map`` / ``model_type = "loop-lm"`` so that
HuggingFace's ``from_pretrained(..., trust_remote_code=True)`` can load it
without requiring the looped_scaling package to be installed.

Supported model variants: "base" (MuTransformer), "looped" (LoopedTransformer),
"moe" (MoETransformer), "looped-moe" (LoopedMoETransformer).
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, einsum, reduce, repeat
from typing import Optional
from torch import Tensor
from torch.nn.functional import scaled_dot_product_attention as sdpa  # for flash attention
from torch.nn.functional import grouped_mm, silu
from transformers import PretrainedConfig, PreTrainedModel
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast

BASE_D_MODEL = 128
BASE_D_FF = 384

""" Standard Transformer and Components implemented with muP """


# ---------------------------------------------------------------------------
# Numerically stable softmax (inlined from looped_scaling/utils.py)
# ---------------------------------------------------------------------------

def softmax(logits: Tensor, dim: int) -> Tensor:
    logits = logits.float()
    # max over the softmax dimension, kept for broadcasting
    max_values = torch.max(logits, dim=dim, keepdim=True).values

    # subtract the max so the largest element is 0 (numerical stability)
    shifted = logits - max_values

    # exponentiate the shifted terms
    shifted_exps = torch.exp(shifted)

    # sum the exponentials over the softmax dimension
    shifted_exp_sums = torch.sum(shifted_exps, dim=dim, keepdim=True)

    # normalize
    return shifted_exps / shifted_exp_sums


# y = Wx (no bias terms!)
class Linear(nn.Module):
    def __init__(self, in_features, out_features, width_ratio, std_base, device=None, dtype=None):
        super().__init__()

        # Register parameter first so shape is always stored (required for HF meta-device loading)
        self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype, device=device))

        # for muP, derive initial std deviation from given base model's std_deviation and width ratio
        std_scaled = std_base / math.sqrt(width_ratio)
        nn.init.trunc_normal_(self.weight, mean=0.0, std=std_scaled, a=-3*std_scaled, b=3*std_scaled)

    def forward(self, x: Tensor) -> Tensor:
        # PyTorch convention: d_in is the last dim of x on the input side
        # ("... d_in"), and the output dim comes last on the output side ("... d_out")
        return einsum(self.weight, x, "d_out d_in, ... d_in -> ... d_out")
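
# A quick numeric sanity check of the muP init (a sketch; values are approximate
# because trunc_normal clips at +-3 std):
#   base = Linear(128, 128, width_ratio=1.0, std_base=0.1)    # weight std ~= 0.1
#   wide = Linear(1024, 1024, width_ratio=8.0, std_base=0.1)  # weight std ~= 0.1/sqrt(8) ~= 0.035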

class Embedding(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, device=None, dtype=None):
        super().__init__()

        # Register parameter first so shape is always stored (required for HF meta-device loading)
        self.weight = nn.Parameter(torch.empty(num_embeddings, embedding_dim, dtype=dtype, device=device))

        # normalize the embeddings to spec
        nn.init.trunc_normal_(self.weight, mean=0.0, std=1.0, a=-3, b=3)

    def forward(self, token_ids: Tensor) -> Tensor:
        # for every id, we need to pull the row vector associated
        return self.weight[token_ids]

class RMSNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-5, device=None, dtype=None):
        super().__init__()

        # for muP no gain parameter on the rms
        self.d_model = d_model
        self.eps = eps

    def forward(self, x: Tensor) -> Tensor:
        # upcast input to torch.float32
        in_dtype = x.dtype
        x = x.to(torch.float32)

        # calculate the RMS scalar
        # scalar for every ex. in batch, for every emb in sequence
        mean_squared_sum = (1/self.d_model)*einsum(x, x, "... seq d, ... seq d -> ... seq")
        rms = torch.sqrt(mean_squared_sum + self.eps)

        # for muP, no gain on rms norm as is normally applied.
        rms_norm = einsum(x, 1/rms, "... seq d, ... seq -> ... seq d")

        # return result to original dtype
        return rms_norm.to(in_dtype)

class PositionwiseFeedforward(nn.Module):
    # SwiGLU(x) = W2(SiLU(W1x)⊙W3x)
    def __init__(self, d_model: int, d_ff: int, width_ratio: float, device=None, dtype=None):
        super().__init__()

        # for muP, calculate the base model's standard deviation
        w_std_base = math.sqrt(2/(BASE_D_MODEL+BASE_D_FF))  # same for all W because d_model+d_ff = d_ff+d_model

        # initialize parameters of SWiGLU FFN
        self.w1 = Linear(d_model, d_ff, width_ratio, w_std_base, device=device, dtype=dtype)
        self.w2 = Linear(d_ff, d_model, width_ratio, w_std_base, device=device, dtype=dtype)
        self.w3 = Linear(d_model, d_ff, width_ratio, w_std_base, device=device, dtype=dtype)

    def forward(self, x: Tensor) -> Tensor:
        # FFN = W2*(SiLU(W1*X) dot W3X)
        silu_in = self.w1(x)
        silu_out = silu(silu_in)  # silu_in * torch.sigmoid(silu_in)
        gate = self.w3(x)
        gated_prod = silu_out * gate
        final_prod = self.w2(gated_prod)
        return final_prod

class RotaryPositionalEmbedding(nn.Module):
    def __init__(self, theta: float, d_k: int, max_seq_len: int, device=None, dtype=None):
        """
        theta: float Θ value for the RoPE
        d_k: int dimension of query and key vectors
        max_seq_len: int Maximum sequence length that will be inputted
        device: torch.device | None = None Device to store the buffer on
        """
        super().__init__()
        rotations = torch.empty(max_seq_len, d_k//2, 2, 2, device=device, dtype=dtype)

        # initialize one 2x2 rotation matrix per (position, feature-pair)
        for i in range(max_seq_len):
            for k in range(d_k//2):
                angle = i/(theta**(2*k/d_k))
                rot = torch.tensor([[math.cos(angle), -math.sin(angle)],
                                    [math.sin(angle), math.cos(angle)]])
                rotations[i, k, :] = rot

        self.register_buffer("rotations", rotations, persistent=True)


    def forward(self, x: Tensor, token_positions: Tensor) -> Tensor:
        """
        self.rotations shape: (seq_dim, feature_dim, 2, 2)
        x: tensor of shape (..., seq_dim, feature_dim)
        token_positions: tensor of shape (..., seq_dim)
        """
        # get the correct rotation matrices
        # by default, 0'th dim of array_indexed is index dim, last dim of indices is feature dim
        rot = self.rotations[token_positions].to(dtype=x.dtype)  # match activation dtype (buffer is float32, activations may be bfloat16)

        # rearrange by every two elements along feature dim of input x
        x_pairs = rearrange(x, "... seq_dim (feature_dim i) -> ... seq_dim feature_dim i", i=2)

        # apply rotations to these. for each pairwise position is A@x->y : (ixj)@(j,)->(i,)
        y_pairs = einsum(rot, x_pairs, "... seq_dim feature_dim i j, ... seq_dim feature_dim j -> ... seq_dim feature_dim i")

        # reshape y_pairs back to original shape
        y = rearrange(y_pairs, "... seq_dim feature_dim i -> ... seq_dim (feature_dim i)")

        return y
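
# Usage sketch (assumed shapes): rotate per-head queries of shape (batch, heads, seq, d_k)
#   rope = RotaryPositionalEmbedding(theta=10000.0, d_k=64, max_seq_len=1024)
#   pos = torch.arange(seq_len).unsqueeze(0)  # (1, seq) broadcasts over batch/heads
#   q_rot = rope(q_heads, pos)                # same shape as q_heads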

def scaled_dot_product_attention(
        Q: Tensor,
        K: Tensor,
        V: Tensor,
        mask: Optional[Tensor] = None,
        ) -> Tensor:
    """
    Given key (K), query (Q), and value (V) tensors, return
    the output of your scaled dot product attention implementation.

    Args:
        let m be seq length of inputs, n be seq length of outputs
        d_k is look-up dim, d_v is value dim
        Q (Float[Tensor, "batch ... n d_k"]): Query tensor
        K (Float[Tensor, "batch ... m d_k"]): Key tensor
        V (Float[Tensor, "batch ... m d_v"]): Values tensor
        mask (Float[Tensor, " ... n m"] | None): Mask tensor
    Returns:
        Float[Tensor, " ... n d_v"]: Output of SDPA
    """

    # get the key feature dim (should be last dim of Q and K)
    d_k = Q.shape[-1]
    assert d_k == K.shape[-1]

    # calculate the weighted scores (similarity product). for muP, scale by d_k not sqrt(d_k)
    scores = einsum(Q, K, "... n d_k, ... m d_k -> ... n m") / d_k

    # apply the mask if there is one
    if mask is not None:
        bool_mask = mask.bool()  # compatible if somehow, input is mask bool or if float
        attn_mask = torch.where(bool_mask, 0.0, float('-inf')).to(scores.dtype)
        scores = scores + attn_mask

    # compute the attention weights
    weights = softmax(scores, dim=-1)  # the softmax should be taken over the m inputs at an i'th output pos.

    # return weights@V
    return einsum(weights, V, "... n m, ... m d_v -> ... n d_v")

class MultiheadSelfAttention(nn.Module):
    """
    Args:
        d_model (int): Dimensionality of the feedforward input and output.
        num_heads (int): Number of heads to use in multi-headed attention.
        max_seq_len (int): Maximum sequence length to pre-cache if your implementation does that.
        q_proj_weight (Float[Tensor, "d_k d_in"]): Weights for the Q projection
        k_proj_weight (Float[Tensor, "d_k d_in"]): Weights for the K projection
        v_proj_weight (Float[Tensor, "d_k d_in"]): Weights for the V projection
        o_proj_weight (Float[Tensor, "d_model d_v"]): Weights for the output projection
        in_features (Float[Tensor, "... sequence_length d_in"]): Tensor to run your implementation on.

    Returns:
        Float[Tensor, " ... sequence_length d_out"]: Tensor with the output of running your optimized, batched multi-headed attention
        implementation with the given QKV projection weights and input features.
    """
    def __init__(self, d_model: int, num_heads: int, max_seq_len: int = None, theta: float = None, width_ratio: float = 1.0, device=None, dtype=None):
        super().__init__()

        # initialize the multi-head self attention weights as 1 large matrix (which will be sliced)
        assert d_model % num_heads == 0, f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"

        self.d_model = d_model
        self.num_heads = num_heads

        # for muP, calculate standard deviation of base model
        attn_std_base = math.sqrt(2/(BASE_D_MODEL+BASE_D_MODEL))

        # for muP, initialize the Wq,Wk,Wv,Wo linear weights with width_ratio and base model's stddev
        self.q_proj = Linear(d_model, d_model, width_ratio, attn_std_base, device=device, dtype=dtype)
        self.k_proj = Linear(d_model, d_model, width_ratio, attn_std_base, device=device, dtype=dtype)
        self.v_proj = Linear(d_model, d_model, width_ratio, attn_std_base, device=device, dtype=dtype)
        self.output_proj = Linear(d_model, d_model, width_ratio, attn_std_base, device=device, dtype=dtype)

        # # Removed for torch sdpa, uncomment if using normal code
        # if max_seq_len:
        #     causal_mask = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=dtype, device=device))
        #     self.register_buffer("causal_mask", causal_mask, persistent=False)
        # else:
        #     self.register_buffer("causal_mask", None, persistent=False)

        assert theta is None or max_seq_len is not None, "max_seq_len must be provided when theta is given for multi-head self attention with RoPE."

        if theta:
            d_k = d_model//num_heads
            self.rope = RotaryPositionalEmbedding(theta, d_k, max_seq_len, device, dtype)
        else:
            self.rope = None

    def forward(self, x: Tensor, token_positions: Optional[Tensor] = None) -> Tensor:
        # get Q, K, V matrices
        Q = self.q_proj(x)  # output shape is [batch seq d_model]
        K = self.k_proj(x)
        V = self.v_proj(x)

        # # create causal mask interpreting the second-to-last dim as the seq dim
        # if self.causal_mask is None:
        #     seq_dim = x.shape[-2]
        #     cmask = torch.tril(torch.ones(seq_dim, seq_dim, dtype=x.dtype, device=x.device))
        # else:
        #     # Slice the pre-computed mask to match actual sequence length (could be < than max_seq_len)
        #     seq_dim = x.shape[-2]
        #     cmask = self.causal_mask[:seq_dim, :seq_dim]

        # get slice size for multi-head self attention
        d_k = self.d_model // self.num_heads
        d_v = self.d_model // self.num_heads

        q_heads = rearrange(Q, "... seq (heads d_k) -> ... heads seq d_k", d_k=d_k)
        k_heads = rearrange(K, "... seq (heads d_k) -> ... heads seq d_k", d_k=d_k)

        # apply RoPE to q_heads and k_heads
        if self.rope:
            seq_dim = x.shape[-2]  # x is (b,s,d)
            if token_positions is None:
                token_positions = torch.arange(seq_dim, device=x.device)
                token_positions = rearrange(token_positions, "seq -> 1 seq")  # 1 seq allows broadcast across batch dim

            q_heads = self.rope(q_heads, token_positions)
            k_heads = self.rope(k_heads, token_positions)

        v_heads = rearrange(V, "... seq (heads d_v) -> ... heads seq d_v", d_v=d_v)

        #mha_heads = scaled_dot_product_attention(q_heads, k_heads, v_heads, cmask)
        mha_heads = sdpa(q_heads, k_heads, v_heads, is_causal=True, scale=1.0/d_k)  # muP: attention scaled by 1/d_k rather than 1/sqrt(d_k)
        mha = rearrange(mha_heads, "... heads seq d_v -> ... seq (heads d_v)")

        # apply o_proj_weight to the concatenated multi-head attention product
        out = self.output_proj(mha)

        return out

class PrenormBlock(nn.Module):
    def __init__(self,
                 d_model: int,
                 num_heads: int,
                 d_ff: int,
                 max_seq_len: int,
                 theta: float,
                 width_ratio: float,
                 device=None,
                 dtype=None):
        super().__init__()
        # norm layer
        self.ln1 = RMSNorm(d_model, device=device, dtype=dtype)
        # mhsa with rope
        self.attn = MultiheadSelfAttention(d_model, num_heads, max_seq_len, theta, width_ratio, device, dtype)
        # add step
        # norm layer
        self.ln2 = RMSNorm(d_model, device=device, dtype=dtype)
        # positionwise feed forward
        self.ffn = PositionwiseFeedforward(d_model, d_ff, width_ratio, device, dtype)
        # add to output

    def forward(self, x: Tensor, token_positions: Optional[Tensor] = None) -> Tensor:

        # first Tx operation, Norm + MHSA w/ RoPE
        norm1_out = self.ln1(x)
        # we may have to define token_positions if it is not given
        attn_out = self.attn(norm1_out, token_positions)

        # ensure no broadcasting, elementwise addition on [batch seq d_model]
        assert(x.shape == attn_out.shape)
        resid1_out = attn_out + x

        # second Tx operation, Norm + SwiGLU
        norm2_out = self.ln2(resid1_out)
        ffn_out = self.ffn(norm2_out)

        # ensure no broadcasting, elementwise addition
        assert(ffn_out.shape == resid1_out.shape)
        final_out = resid1_out + ffn_out
        return final_out

class MuTransformer(nn.Module):
    def __init__(
            self, vocab_size: int,
            context_length: int,
            d_model: int,
            num_layers: int,
            num_heads: int,
            d_ff: int,
            rope_theta: float,
            width_ratio: float = 1.0,
            weight_tying: bool = False,
            device=None, dtype=None):
        super().__init__()
        self.token_embeddings = Embedding(vocab_size, d_model, device=device, dtype=dtype)
        self.layers = nn.ModuleList([PrenormBlock(d_model, num_heads, d_ff, context_length, rope_theta, width_ratio, device, dtype) for _ in range(num_layers)])
        self.ln_final = RMSNorm(d_model, device=device, dtype=dtype)
        self.weight_tying = weight_tying
        if weight_tying:
            self.lm_head = self.token_embeddings.weight
        else:
            std_base_lm_head = math.sqrt(2/(BASE_D_MODEL+vocab_size))
            self.lm_head = Linear(d_model, vocab_size, width_ratio=width_ratio, std_base=std_base_lm_head, device=device, dtype=dtype)
        self.width_ratio = width_ratio

    def forward(self, x: Tensor) -> Tensor:
        # 1. token embed step, no muP alpha_in
        x = self.token_embeddings(x)

        # 2. prenorm blocks step
        for layer in self.layers:
            x = layer(x)

        # 3. Final norm
        x = self.ln_final(x)

        # 4. unembed layer, muP implemented as scaling on init variance and lr of lm_head, not output scaling
        if self.weight_tying:
            x = einsum(x, self.lm_head, "... s d, v d -> ... s v")/self.width_ratio
        else:
            x = self.lm_head(x)

        # 5. return output, no muP alpha_out
        return x
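
# Construction sketch (hypothetical hyperparameters): a width-ratio-1 base model
#   model = MuTransformer(vocab_size=50257, context_length=1024, d_model=128,
#                         num_layers=4, num_heads=2, d_ff=384, rope_theta=10000.0)
#   logits = model(torch.randint(0, 50257, (2, 16)))  # -> (2, 16, 50257)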

""" Looped Language Models implemented with MuP """

class LoopedStack(nn.Module):
    def __init__(
            self,
            context_length: int,
            d_model: int,
            num_layers_in_stack: int,
            num_heads: int,
            d_ff: int,
            rope_theta: float,
            width_ratio: float = 1.0,
            mixture_of_experts: bool = False,
            num_experts: Optional[int] = None,
            num_active: Optional[int] = None,
            device=None, dtype=None):
        super().__init__()
        if mixture_of_experts:
            # self.layers = nn.ModuleList([MoEPrenormBlock(d_model, num_heads, d_ff, num_experts, num_active,
            #                                              context_length, rope_theta, width_ratio, device, dtype)
            #                              for _ in range(num_layers_in_stack)])
            self.layers = nn.ModuleList([GroupedMoEPrenormBlock(d_model, num_heads, d_ff, num_experts, num_active,
                                                                context_length, rope_theta, width_ratio, device, dtype)
                                         for _ in range(num_layers_in_stack)])
        else:
            self.layers = nn.ModuleList([PrenormBlock(d_model, num_heads, d_ff, context_length, rope_theta,
                                                      width_ratio, device, dtype) for _ in range(num_layers_in_stack)])
        self.mixture_of_experts = mixture_of_experts

    def forward(self, x: Tensor):  # returns x, or (x, lb_total, lz_total) when mixture_of_experts
        # prenorm blocks step
        if self.mixture_of_experts:
            lb_total = 0
            lz_total = 0
            # sum up load balancing and z-losses across each layer
            for layer in self.layers:
                x, lb, lz = layer(x)
                lb_total += lb
                lz_total += lz
            return x, lb_total, lz_total
        else:
            for layer in self.layers:
                x = layer(x)
            return x

class LoopedTransformer(nn.Module):
    def __init__(
            self,
            vocab_size: int,
            context_length: int,
            d_model: int,
            num_layers_in_stack: int,
            num_stacks: int,
            num_heads: int,
            d_ff: int,
            rope_theta: float,
            width_ratio: float = 1.0,
            weight_tying: bool = False,
            device=None, dtype=None):
        super().__init__()
        self.num_stacks = num_stacks

        self.token_embeddings = Embedding(vocab_size, d_model, device=device, dtype=dtype)
        self.stack = LoopedStack(context_length, d_model, num_layers_in_stack, num_heads, d_ff, rope_theta, width_ratio, device=device, dtype=dtype)
        self.ln_final = RMSNorm(d_model, device=device, dtype=dtype)
        self.weight_tying = weight_tying
        self.width_ratio = width_ratio

        if weight_tying:
            self.lm_head = self.token_embeddings.weight
        else:
            std_base_lm_head = math.sqrt(2/(BASE_D_MODEL+vocab_size))
            self.lm_head = Linear(d_model, vocab_size, width_ratio, std_base_lm_head, device=device, dtype=dtype)

    def forward(self, x: Tensor) -> Tensor:
        # token embed step
        x = self.token_embeddings(x)

        # repeated calls to stack
        for i in range(self.num_stacks):
            x = self.stack(x)

        # final norm
        x = self.ln_final(x)

        # Vocab projection or lm_head
        if self.weight_tying:
            x = einsum(x, self.lm_head, "... s d, v d -> ... s v")/self.width_ratio
        else:
            x = self.lm_head(x)

        return x
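
# For example, num_layers_in_stack=8 with num_stacks=2 applies 16 layer passes
# per forward pass while storing parameters for only the 8 shared layers.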

""" Mixture-of-Experts Implementation in muP """

# Router Class
class Router(nn.Module):
    def __init__(self, d_model: int, num_experts: int, num_active=None, width_ratio: float = 1.0, device=None, dtype=None):
        super().__init__()
        # the router is a single linear gate; muP init derives its std from the base model dims
        std_base = math.sqrt(2/(BASE_D_MODEL+num_experts))
        self.gate = Linear(d_model, num_experts, width_ratio, std_base, device=device, dtype=dtype)  # adjusted for muP
        self.num_active = num_active

    def forward(self, x: Tensor):
        # returns scores, top_k_scores, top_k_indices
        logits = self.gate(x)  # should be shape (batch, seq, n_routers)

        # probs
        probs = softmax(logits, dim=-1)

        # get top_k
        top_scores, top_experts = torch.topk(probs, k=self.num_active, dim=-1)

        # renormalize the top scores so weighted sum of expert products can be taken
        score_sums = torch.sum(top_scores, dim=-1, keepdim=True)  # (batch, seq, 1)
        top_scores = top_scores/score_sums

        return logits, probs, top_scores, top_experts
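
# Routing sketch (hypothetical sizes): top-2 of 8 experts on a (2, 16, 1024) batch
#   router = Router(d_model=1024, num_experts=8, num_active=2, width_ratio=8.0)
#   logits, probs, top_scores, top_experts = router(x)
#   # probs: (2, 16, 8); top_scores / top_experts: (2, 16, 2); each row of
#   # top_scores sums to 1 after renormalization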

class MoEPrenormBlock(nn.Module):
    def __init__(self, d_model: int, num_heads: int, d_ff: int, num_experts: int, num_active: int,
                  max_seq_len: int, theta: float, width_ratio: float = 1.0, device=None, dtype=None):
        super().__init__()
        # norm layer before mHSA+RoPE
        self.ln1 = RMSNorm(d_model, device=device, dtype=dtype)

        # mhsa with rope
        self.attn = MultiheadSelfAttention(d_model, num_heads, max_seq_len, theta, width_ratio, device, dtype)

        # norm layer before position-wise feedforward
        self.ln2 = RMSNorm(d_model, device=device, dtype=dtype)

        # router
        self.router = Router(d_model, num_experts, num_active, width_ratio=width_ratio, device=device, dtype=dtype)

        # save MoE hyperparams
        self.num_experts = num_experts
        self.num_active = num_active

        # initialize MoE FFNs as a module list
        d_ff_expert = d_ff // num_active
        self.experts = nn.ModuleList([PositionwiseFeedforward(d_model, d_ff_expert, width_ratio, device, dtype) for _ in range(num_experts)])  # adjusted for muP

    def forward(self, x: Tensor, token_positions: Optional[Tensor] = None) -> tuple[Tensor, Tensor, Tensor]:
        # input dims
        batch, seq, dim = x.shape

        # first Tx operation, Norm + MHSA w/ RoPE
        norm1_out = self.ln1(x)
        # we may have to define token_positions if it is not given
        attn_out = self.attn(norm1_out, token_positions)

        # ensure no broadcasting, elementwise addition on [batch seq d_model]
        assert(x.shape == attn_out.shape)
        resid1_out = attn_out + x

        # prenorm before position-wise feedforward
        norm2_out = self.ln2(resid1_out)

        # get scores from Router. returns shape (batch,seq,k)
        logits, probs, top_scores, top_experts = self.router(norm2_out)  # logits and probs are (batch, seq, n_routers)
        expert_mean_probs = torch.mean(probs, dim=(0, 1))  # take mean across batch and seq dims

        # apply mixture of experts
        experts_out = torch.zeros_like(norm2_out)  # copies shape, device and dtype
        total_tokens_assigned = batch*seq*self.num_active
        lb_sum = 0

        for expert_idx in range(self.num_experts):
            # get masks for expert selection
            expert_mask = (top_experts == expert_idx)
            embed_mask = expert_mask.any(dim=-1)  # if any of the k is expert, we want to transform embed
            if not embed_mask.any(): continue
            pi = expert_mean_probs[expert_idx].item()
            fi = (expert_mask.sum().item())/total_tokens_assigned  # fraction of routed tokens assigned to this expert
            lb_sum += fi*pi

            # extract embeds and weights for activated experts
            weights = top_scores[expert_mask]  # (num_embeds)
            expert_embeds = norm2_out[embed_mask]  # (num_embeds, hidden_dim)

            # forward for the correct experts
            expert_out = self.experts[expert_idx](expert_embeds)  # Vanilla Implementation

            # map back to experts output
            experts_out[embed_mask] += weights.unsqueeze(-1)*expert_out  # broadcast elementwise multiply by hidden dim

        # calculate batch's load balancing loss
        lb = self.num_experts*lb_sum

        # calculate batch's router z loss
        logsumexp = torch.logsumexp(logits.float(), dim=-1)
        lz = torch.mean(logsumexp ** 2)

        # ensure no broadcasting, elementwise addition
        assert(experts_out.shape == resid1_out.shape)
        final_out = resid1_out + experts_out
        return final_out, lb, lz
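
# Note: lb = num_experts * sum_e f_e * p_e is minimized at 1.0 under perfectly
# uniform routing (f_e = p_e = 1/E): with E=4, lb = 4 * (4 * 1/16) = 1.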


class GroupedMoEPrenormBlock(nn.Module):
    @staticmethod
    def _init_expert_weights(num_experts, in_features, out_features, width_ratio, std_base, device, dtype) -> nn.Parameter:
        w = torch.empty(num_experts, in_features, out_features, device=device, dtype=dtype)  # (batch, in, out)
        std_scaled = std_base / math.sqrt(width_ratio)
        nn.init.trunc_normal_(w, mean=0.0, std=std_scaled, a=-3*std_scaled, b=3*std_scaled)
        return nn.Parameter(w)

    def __init__(self, d_model: int, num_heads: int, d_ff: int, num_experts: int, num_active: int,
                  max_seq_len: int, theta: float, width_ratio: float = 1.0, device=None, dtype=None):
        super().__init__()
        # norm layer before mHSA+RoPE
        self.ln1 = RMSNorm(d_model, device=device, dtype=dtype)

        # mhsa with rope
        self.attn = MultiheadSelfAttention(d_model, num_heads, max_seq_len, theta, width_ratio, device, dtype)

        # norm layer before position-wise feedforward
        self.ln2 = RMSNorm(d_model, device=device, dtype=dtype)

        # router
        self.router = Router(d_model, num_experts, num_active, width_ratio=width_ratio, device=device, dtype=dtype)

        # save MoE hyperparams
        self.num_experts = num_experts
        self.num_active = num_active

        # initialize MoE FFNs as a module list
        d_ff_expert = d_ff // num_active

        # expose and stack the MoE SwiGLU weights for all experts; registering them directly as parameters lets the muP optimizer scale their learning rate by width_ratio
        w_std_base = math.sqrt(2 / (BASE_D_MODEL + BASE_D_FF))
        self.experts_w1 = self._init_expert_weights(num_experts, d_model, d_ff_expert, width_ratio, w_std_base, device, dtype)
        self.experts_w2 = self._init_expert_weights(num_experts, d_ff_expert, d_model, width_ratio, w_std_base, device, dtype)
        self.experts_w3 = self._init_expert_weights(num_experts, d_model, d_ff_expert, width_ratio, w_std_base, device, dtype)

    def forward(self, x: Tensor, token_positions: Optional[Tensor] = None) -> tuple[Tensor, Tensor, Tensor]:
        batch, seq, dim = x.shape
        total_tokens = batch * seq

        # first Tx operation, Norm + MHSA w/ RoPE
        norm1_out = self.ln1(x)
        attn_out = self.attn(norm1_out, token_positions)

        assert(x.shape == attn_out.shape)
        resid1_out = attn_out + x

        # prenorm before position-wise feedforward
        norm2_out = self.ln2(resid1_out)

        # get scores from Router. returns shape (batch, seq, k)
        logits, probs, top_scores, top_experts = self.router(norm2_out)

        # flatten to 2D for grouped_mm
        x_flat = rearrange(norm2_out, 'b s d -> (b s) d')                         # (total_tokens, d_model)
        flat_expert_ids = rearrange(top_experts, 'b s k -> (b s k)')               # (total_tokens * k,)
        flat_scores = rearrange(top_scores, 'b s k -> (b s k)')                    # (total_tokens * k,)
        flat_positions = torch.arange(total_tokens, device=x.device)               # (total_tokens)
        flat_token_ids = repeat(flat_positions, 'n -> (n k)', k=self.num_active)   # (total_tokens * k)

        # sort by expert
        sort_indices = flat_expert_ids.argsort(stable=True)
        sorted_expert_ids = flat_expert_ids[sort_indices]
        sorted_token_ids = flat_token_ids[sort_indices]
        sorted_scores = flat_scores[sort_indices]
        sorted_x = x_flat[sorted_token_ids]                                        # (total_tokens * k, d_model)

        # build offs (cumulative token counts per expert)
        counts = torch.bincount(sorted_expert_ids, minlength=self.num_experts)
        offs = counts.cumsum(0).to(torch.int32)                                    # (num_experts,)
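        # e.g. counts = [3, 0, 5] -> offs = [3, 3, 8]: expert 0 owns sorted rows
        # 0:3, expert 1 owns none, expert 2 owns rows 3:8; these cumulative end
        # offsets are the group layout grouped_mm expects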

        # grouped SwiGLU: W2(SiLU(W1 x) dot W3 x)
        h1 = grouped_mm(sorted_x, self.experts_w1, offs=offs)
        h3 = grouped_mm(sorted_x, self.experts_w3, offs=offs)
        gated = silu(h1) * h3
        expert_out = grouped_mm(gated, self.experts_w2, offs=offs)                # (total_tokens * k, d_model)

        # weight by router scores and scatter-add back
        expert_out = einsum(expert_out, sorted_scores, 'n d, n -> n d')
        output_flat = torch.zeros(total_tokens, dim, device=x.device, dtype=expert_out.dtype)
        output_flat.index_add_(0, sorted_token_ids, expert_out)

        # reshape back to (batch, seq, d_model)
        experts_out = rearrange(output_flat, '(b s) d -> b s d', b=batch, s=seq)

        # aux losses
        fi = counts.float() / (total_tokens * self.num_active)
        pi = reduce(probs, 'b s e -> e', 'mean')
        lb = self.num_experts * einsum(fi, pi, 'e, e ->')

        logsumexp = torch.logsumexp(logits.float(), dim=-1)
        lz = reduce(logsumexp ** 2, '... -> ', 'mean')

        # residual connection
        assert(experts_out.shape == resid1_out.shape)
        final_out = resid1_out + experts_out
        return final_out, lb, lz


# MoE Implementation
class MoETransformer(nn.Module):
    def __init__(
            self, vocab_size: int,
            context_length: int,
            d_model: int,
            num_layers: int,
            num_heads: int,
            d_ff: int,
            num_experts: int,
            num_active: int,
            rope_theta: float,
            width_ratio: float = 1.0,
            device=None, dtype=None):
        super().__init__()
        self.token_embeddings = Embedding(vocab_size, d_model, device=device, dtype=dtype)
        self.num_layers = num_layers
        # self.layers = nn.ModuleList([MoEPrenormBlock(d_model, num_heads, d_ff, num_experts, num_active,
        #                                              context_length, rope_theta, width_ratio, device, dtype) for _ in range(num_layers)])
        self.layers = nn.ModuleList([GroupedMoEPrenormBlock(d_model, num_heads, d_ff, num_experts, num_active,
                                                            context_length, rope_theta, width_ratio, device, dtype) for _ in range(num_layers)])
        self.ln_final = RMSNorm(d_model, device=device, dtype=dtype)

        # only non-tied embeddings now
        std_base_lm_head = math.sqrt(2/(BASE_D_MODEL+vocab_size))
        self.lm_head = Linear(d_model, vocab_size, width_ratio=width_ratio, std_base=std_base_lm_head, device=device, dtype=dtype)

    def forward(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
        # collect aux losses
        lb_total = 0
        lz_total = 0

        # 1. token embed step
        x = self.token_embeddings(x)

        # 2. prenorm blocks step
        for layer in self.layers:
            x, lb, lz = layer(x)
            lb_total += lb
            lz_total += lz

        # 3. Final norm
        x = self.ln_final(x)

        # 4. Vocab projection or lm_head
        x = self.lm_head(x)

        # calculate average layer aux loss
        lb_avg = lb_total / self.num_layers
        lz_avg = lz_total / self.num_layers

        return x, lb_avg, lz_avg

class LoopedMoETransformer(nn.Module):
    def __init__(
            self, vocab_size: int,
            context_length: int,
            d_model: int,
            num_layers_in_stack: int,
            num_stacks: int,
            num_heads: int,
            d_ff: int,
            num_experts: int,
            num_active: int,
            rope_theta: float,
            width_ratio: float,
            device=None, dtype=None):
        super().__init__()
        self.stack_depth = num_stacks
        self.total_layers = num_stacks*num_layers_in_stack
        self.token_embeddings = Embedding(vocab_size, d_model, device=device, dtype=dtype)
        self.stack = LoopedStack(context_length, d_model, num_layers_in_stack, num_heads,
                                 d_ff, rope_theta, width_ratio, mixture_of_experts=True,
                                 num_experts=num_experts, num_active=num_active,
                                 device=device, dtype=dtype)  # parameters for loop with MoE
        self.ln_final = RMSNorm(d_model, device=device, dtype=dtype)

        # scale lm head
        std_base_lm_head = math.sqrt(2/(BASE_D_MODEL+vocab_size))
        self.lm_head = Linear(d_model, vocab_size, width_ratio=width_ratio, std_base=std_base_lm_head, device=device, dtype=dtype)


    def forward(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
        # collect aux losses
        lb_total = 0
        lz_total = 0

        # token embed step
        x = self.token_embeddings(x)

        # repeated calls to stack
        for i in range(self.stack_depth):
            x, lb, lz = self.stack(x)
            lb_total += lb
            lz_total += lz

        # final norm
        x = self.ln_final(x)

        # Vocab projection or lm_head
        x = self.lm_head(x)

        # calculate aux loss averages
        lb_avg = lb_total / self.total_layers
        lz_avg = lz_total / self.total_layers

        return x, lb_avg, lz_avg


# ---------------------------------------------------------------------------
# HuggingFace wrapper (from hf_wrapper.py)
# ---------------------------------------------------------------------------

class LoopLMConfig(PretrainedConfig):
    """Config for all four loop-lm model variants."""

    model_type = "loop-lm"

    def __init__(
        self,
        # which of the four architectures to use
        model_variant: str = "base",        # "base" | "looped" | "moe" | "looped-moe"
        # shared
        vocab_size: int = 50257,
        context_length: int = 1024,
        d_model: int = 1024,
        num_heads: int = 16,
        d_ff: int = 2752,
        rope_theta: float = 10000.0,
        width_ratio: float = 8.0,           # d_model / base_d_model (128); set at training time
        # base + moe only
        num_layers: int = 16,
        # base + looped only
        weight_tying: bool = False,
        # looped + looped-moe only
        num_layers_in_stack: int = 8,
        num_stacks: int = 2,
        # moe + looped-moe only
        num_experts: int = 8,
        num_active: int = 2,
        # aux loss weights — used when forward() is called with labels
        lb_loss_factor: float = 0.01,
        lz_loss_factor: float = 0.001,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.model_variant = model_variant
        self.vocab_size = vocab_size
        self.context_length = context_length
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_ff = d_ff
        self.rope_theta = rope_theta
        self.width_ratio = width_ratio
        self.num_layers = num_layers
        self.weight_tying = weight_tying
        self.num_layers_in_stack = num_layers_in_stack
        self.num_stacks = num_stacks
        self.num_experts = num_experts
        self.num_active = num_active
        self.lb_loss_factor = lb_loss_factor
        self.lz_loss_factor = lz_loss_factor
        # lm-evaluation-harness looks for this attribute to cap sequence length
        self.max_length = context_length


class LoopLMForCausalLM(PreTrainedModel, GenerationMixin):
    """Causal LM wrapper over all four looped-scaling variants.

    Implements the HuggingFace PreTrainedModel interface so you can:
      - Upload/download via push_to_hub / from_pretrained
      - Run lm-evaluation-harness evals
      - Fine-tune with TRL's SFTTrainer / DPOTrainer
    """

    config_class = LoopLMConfig
    # tell HF which parameter holds the output logits for generation
    _keys_to_ignore_on_load_missing = []

    def __init__(self, config: LoopLMConfig):
        super().__init__(config)
        self.model = self._build_inner_model(config)
        self.post_init()

    # ------------------------------------------------------------------
    # Model construction
    # ------------------------------------------------------------------

    def _build_inner_model(self, config: LoopLMConfig):
        kw = dict(
            vocab_size=config.vocab_size,
            context_length=config.context_length,
            d_model=config.d_model,
            num_heads=config.num_heads,
            d_ff=config.d_ff,
            rope_theta=config.rope_theta,
            width_ratio=config.width_ratio,
            # device=None so weights are placed on CPU; caller uses .to(device)
        )
        v = config.model_variant
        if v == "base":
            return MuTransformer(
                **kw,
                num_layers=config.num_layers,
                weight_tying=config.weight_tying,
            )
        elif v == "looped":
            return LoopedTransformer(
                **kw,
                num_layers_in_stack=config.num_layers_in_stack,
                num_stacks=config.num_stacks,
                weight_tying=config.weight_tying,
            )
        elif v == "moe":
            return MoETransformer(
                **kw,
                num_layers=config.num_layers,
                num_experts=config.num_experts,
                num_active=config.num_active,
            )
        elif v == "looped-moe":
            return LoopedMoETransformer(
                **kw,
                num_layers_in_stack=config.num_layers_in_stack,
                num_stacks=config.num_stacks,
                num_experts=config.num_experts,
                num_active=config.num_active,
            )
        else:
            raise ValueError(f"Unknown model_variant: {v!r}. Choose from: base, looped, moe, looped-moe")

    # ------------------------------------------------------------------
    # Embedding access (required by some HF utilities)
    # ------------------------------------------------------------------

    def get_input_embeddings(self):
        return self.model.token_embeddings

    def set_input_embeddings(self, value):
        self.model.token_embeddings = value

    # ------------------------------------------------------------------
    # Forward
    # ------------------------------------------------------------------

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,   # causal mask is handled internally
        labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """
        Args:
            input_ids: (batch, seq)
            attention_mask: ignored — models use a built-in causal mask
            labels: (batch, seq) token ids; if provided, returns cross-entropy loss.
                    For MoE variants, aux losses (lb + lz) are added to the CE loss.
        """
        is_moe = self.config.model_variant in ("moe", "looped-moe")

        if is_moe:
            logits, lb, lz = self.model(input_ids)
        else:
            logits = self.model(input_ids)
            lb = lz = 0.0

        loss = None
        if labels is not None:
            # shift so that tokens < n predict token n (HF causal-LM convention);
            # F.cross_entropy ignores the default -100 padding label
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            ce_loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )
            aux = self.config.lb_loss_factor * lb + self.config.lz_loss_factor * lz
            loss = (ce_loss + aux) if self.training else ce_loss

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
        )
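
    # Training-step sketch (assumes a pre-tokenized LongTensor `batch`):
    #   out = model(input_ids=batch, labels=batch)
    #   out.loss.backward()  # CE plus weighted lb/lz aux terms for MoE variants in training mode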

    # ------------------------------------------------------------------
    # Generation support (no KV cache — generation is correct but slow)
    # ------------------------------------------------------------------

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        return {"input_ids": input_ids}
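
# Generation sketch: generate() is inherited from GenerationMixin; with no KV
# cache each step re-runs the full forward, so decoding is correct but slow.
#   ids = model.generate(input_ids, max_new_tokens=32, do_sample=False)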