File size: 30,042 Bytes
be761d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
import  math

import torch
import torch.nn as nn
import torch.nn.functional as F

from enum import Enum
from dataclasses import dataclass, field
from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined

from .causal_conv1d_compilable import causal_conv1d_fn, causal_conv1d_update
from .ssm_compilable import mamba_chunk_scan_combined
from .norms import build_norm


class InitStdFactor(Enum):
    """Rule used to shrink each layer's output-projection init std.

    ``BaseMamba.init_weights`` maps every member to a divisor ``factor`` that
    ``SSM.reset_parameters`` applies to the output-projection std:

    * ``DISABLED``      -> 1.0 (no scaling)
    * ``GLOBAL_DEPTH``  -> sqrt(2 * (num_layers + 1))
    * ``CURRENT_DEPTH`` -> sqrt(2 * depth), with depth counted from 1
    * ``DIM_RATIO``     -> model_dim / 4096
    """

    DISABLED = "disabled"
    GLOBAL_DEPTH = "global_depth"
    CURRENT_DEPTH = "current_depth"
    DIM_RATIO = "dim_ratio"


@dataclass
class InitConfig:
    """Ranges used when (re)initializing SSM parameters in ``SSM.reset_parameters``."""

    # dt is sampled log-uniformly in [dt_min, dt_max] for the trainable dt_bias.
    # SSM.__init__ also reuses this pair as the runtime dt clamp (dt_limit).
    dt_max: float = 0.1
    dt_min: float = 0.001

    # Lower clamp applied to the sampled dt before inverting the softplus.
    dt_init_floor: float = 1e-4

    # A_log is initialized as log(Uniform(A_init_min, A_init_max)).
    A_init_min: float = 1
    A_init_max: float = 16


# Shared fallback used by SSM.reset_parameters when a config carries no init_config.
DEFAULT_INIT_CONFIG = InitConfig()


@dataclass
class BaseMambaConfig:
    """
    Hyperparameters shared by the Mamba family of models in this module.

    Field order defines the positional order of the generated ``__init__``;
    append new fields instead of inserting them.
    """
    # --- Model shape ---
    dim: int = 512       # residual-stream (model) width
    num_layers: int = 8  # number of stacked MambaBlock layers
    num_heads: int = 8   # SSM heads; the rounded hidden size must divide by this

    # --- SSM state / convolution ---
    state_dim: int = 128       # SSM state size per B/C group
    num_groups: int = 1        # number of B/C groups
    conv_size: int | None = 4  # causal conv kernel width; None disables the conv branch

    bias: bool = False                   # bias on the input and output linear projections
    conv_bias: bool = True               # bias on the depthwise convolution
    dt_bias: bool = False                # trainable dt bias; when False it is a frozen zero vector
    D_has_head_dim: bool = False         # D skip weight shaped (heads, head_dim) instead of (heads,)
    learnable_init_states: bool = False  # trainable initial SSM state per head

    # --- Hidden-size sizing ---
    ffn_dim_multiplier: float = 2.0  # hidden size = ffn_dim_multiplier * dim ...
    multiple_of: int = 256           # ... rounded up to a multiple of this value

    # --- Normalization ---
    norm_eps: float = 1e-6
    norm_type: str = "rmsnorm"
    
    # --- Kernel options ---
    ssm_chunk_size: int = 256       # chunk size handed to the chunked scan kernels
    use_mem_eff_path: bool = False  # fused split+conv+scan+norm+out-proj kernel

    # --- Initialization ---
    init_use_depth: bool = False          # NOTE(review): not read anywhere in this file
    init_base_std: float | None = None    # fixed base std; None derives it from layer widths
    init_std_factor: str = "disabled"     # an InitStdFactor value, e.g. "global_depth"
    init_config: InitConfig = field(default_factory=InitConfig)


class SSM(nn.Module):
    """
    Mamba-2 style State Space Model (SSM) mixer with selective state updates and
    an optional causal depthwise convolution.

    Two execution modes are selected by ``ssm_impl``:
      * ``"ssm"``: full-sequence chunked scan (training / prompt prefill).
      * ``"ssm_update"``: incremental update for token-by-token generation, which
        reads and writes ``self.cache.conv_cache`` and ``self.cache.state_cache``.
        The cache object is attached externally; this class never creates it.
    """
    def __init__(self, config: BaseMambaConfig) -> None:
        """Build the projections, conv weights and SSM parameters.

        Args:
            config: Model hyperparameters. Every config field is also copied onto
                the module as a plain attribute (e.g. ``self.dim``).

        Raises:
            ValueError: If ``ffn_dim_multiplier`` is None.
            AssertionError: On non-positive sizes or an indivisible head split.
        """
        super().__init__()
        self.config = config
        # Expose config fields as attributes; entries later re-assigned as
        # Parameters (e.g. dt_bias) are moved into nn.Module's parameter registry.
        vars(self).update(vars(config))

        assert self.dim > 0,        "Model dimension (config.dim) must be positive"
        assert self.num_heads > 0,  "Number of heads (config.num_heads) must be positive"
        assert self.state_dim > 0,  "State dimension (config.state_dim) must be positive"

        if self.ffn_dim_multiplier is None:
            raise ValueError(
                "ffn_dim_multiplier must be set to a valid float (e.g. 2.0) "
                "to determine hidden_dim in SSM."
            )
        assert self.ffn_dim_multiplier > 0, "ffn_dim_multiplier must be > 0"

        self.hidden_dim = int(self.ffn_dim_multiplier * self.dim)
        self.hidden_dim = config.multiple_of * ( # Round up to multiple_of
            (self.hidden_dim + self.multiple_of - 1) // self.multiple_of
        )
        
        assert self.hidden_dim % self.num_heads == 0, (
            f"Hidden dim {self.hidden_dim} not divisible by num_heads={self.num_heads}."
        )

        self.head_dim = self.hidden_dim // self.num_heads

        # NOTE: the runtime dt clamp reuses the *init* range (dt_min, dt_max);
        # it is only omitted when it equals the no-op range (0, inf).
        self.dt_limit_kwargs = {}
        dt_limit = (self.init_config.dt_min, self.init_config.dt_max)
        if dt_limit != (0.0, float("inf")):
            self.dt_limit_kwargs = dict(dt_limit=dt_limit)

        # Packed input projection, order: [z, x, B, C, dt]
        d_input = (
            2 * self.hidden_dim
            + 2 * self.num_groups * self.state_dim
            + self.num_heads
        )

        self.input = nn.Linear(self.dim, d_input, bias=self.bias)

        # Only create the conv when conv_size is specified. forward() never calls
        # self.conv1d directly; only its weight/bias are handed to the kernels.
        if self.conv_size is not None:
            conv_dim = self.hidden_dim + 2 * self.num_groups * self.state_dim

            # Depthwise conv over [x, B, C]: groups == in_channels == out_channels.
            self.conv1d = nn.Conv1d(
                in_channels=conv_dim,
                out_channels=conv_dim,
                kernel_size=self.conv_size,
                groups=conv_dim,
                bias=self.conv_bias,
                padding=self.conv_size - 1  # causal padding if the module were called directly
            )

        if config.dt_bias:
            self.dt_bias = nn.Parameter(torch.empty(self.num_heads))
        else:
            # Frozen zero bias keeps the kernel call signature uniform.
            self.dt_bias = nn.Parameter(torch.zeros(self.num_heads), requires_grad=False)

        self.A_log = nn.Parameter(torch.empty(self.num_heads))

        if config.D_has_head_dim:
            self.D = nn.Parameter(torch.ones(self.num_heads, self.head_dim))
        else:
            self.D = nn.Parameter(torch.ones(self.num_heads))
        
        if self.learnable_init_states:
            self.init_states = nn.Parameter(torch.zeros(self.num_heads, self.head_dim, self.state_dim))

        # The fused path reads self.norm.weight / self.norm.eps, so the built norm
        # is expected to expose RMSNorm-style attributes.
        self.norm = build_norm(config.norm_type, dim=self.hidden_dim, eps=self.norm_eps)
        
        self.output = nn.Linear(self.hidden_dim, self.dim, bias=self.bias)

    def _causal_conv(
        self, 
        zxbcdt: torch.Tensor,
        tok_idx: torch.Tensor | None = None,
        cu_seqlens: torch.Tensor | None = None,
        ssm_impl: str = "ssm"
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Split the input projection and run [x, B, C] through the causal conv + SiLU.

        Modes:
        1. ``"ssm"``: full-sequence conv (training / prefill). If a cache is
           attached, the per-sequence conv states are first written into
           ``self.cache.conv_cache`` using ``cu_seqlens``.
        2. ``"ssm_update"``: incremental conv that updates ``self.cache.conv_cache``
           in place. Requires an attached cache.

        Args:
            zxbcdt: Concatenated [z, x, B, C, dt] projection output.
            tok_idx: Sequence ids forwarded to the conv kernel as ``seq_idx``
                in ``"ssm"`` mode.
            cu_seqlens: Cumulative sequence lengths; used only in ``"ssm"`` mode
                when a cache is attached.
            ssm_impl: ``"ssm"`` or ``"ssm_update"``.

        Returns:
            ``(z, x, B, C, dt)``: gate branch, main branch, SSM input/output
            projections (B, C play the role of K, Q in the SSM/attention duality),
            and the raw time-step values.

        Raises:
            NotImplementedError: For any other ``ssm_impl``.
        """
        # Split input into components
        z, xBC, dt = torch.split(
            zxbcdt,
            [
                self.hidden_dim,
                self.hidden_dim + 2 * self.num_groups * self.state_dim,
                self.num_heads,
            ],
            dim=-1,
        )

        if ssm_impl == "ssm":
            if hasattr(self, "cache"):
                conv_varlen_states = causal_conv1d_varlen_states(
                    xBC.squeeze(0),
                    cu_seqlens,
                    state_len=self.cache.conv_cache.shape[-1],
                )
                self.cache.conv_cache.copy_(conv_varlen_states)

            xBC = causal_conv1d_fn(
                x=xBC.transpose(1, 2),
                weight=self.conv1d.weight.squeeze(1),
                bias=self.conv1d.bias,
                activation="silu",
                seq_idx=tok_idx,
            ).transpose(1, 2)
        elif ssm_impl == "ssm_update":
            xBC = causal_conv1d_update(
                x=xBC.squeeze(0),
                conv_state=self.cache.conv_cache,
                weight=self.conv1d.weight.squeeze(1),
                bias=self.conv1d.bias,
                activation="silu",
            ).unsqueeze(0)
        else:
            raise NotImplementedError(f"SSM implementation {ssm_impl} not supported")

        # Split processed tensor into components
        x, B, C = torch.split(
            xBC,
            [
                self.hidden_dim,
                self.num_groups * self.state_dim,
                self.num_groups * self.state_dim,
            ],
            dim=-1,
        )

        return z, x, B, C, dt

    def _non_causal_conv(
        self, zxbcdt: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Split the projection into ``(z, x, B, C, dt)`` with no convolution applied."""
        z, x, B, C, dt = torch.split(
            zxbcdt,
            [
                self.hidden_dim,
                self.hidden_dim,
                self.num_groups * self.state_dim,
                self.num_groups * self.state_dim,
                self.num_heads,
            ],
            dim=-1,
        )
        return z, x, B, C, dt

    def _fwd(self, x, dt, A, B, C, tok_idx, cu_seqlens, initial_states):
        """Full-sequence chunked scan (training / prefill).

        When a cache is attached, the kernel result is unpacked as
        ``(y, varlen_states)`` and the states are copied into
        ``self.cache.state_cache``. NOTE(review): this relies on the local
        ``mamba_chunk_scan_combined`` returning a tuple in that case.

        Returns:
            y of shape (bsz, seq_len, num_heads, head_dim).
        """
        y = mamba_chunk_scan_combined(
            x,
            dt,
            A,
            B,
            C,
            dt_bias=self.dt_bias,
            dt_softplus=True,
            chunk_size=self.ssm_chunk_size,
            D=self.D,
            z=None,
            seq_idx=tok_idx,
            cu_seqlens=cu_seqlens,
            initial_states=initial_states,
            **self.dt_limit_kwargs,
        )

        if hasattr(self, "cache"):
            y, varlen_states = y
            self.cache.state_cache.copy_(varlen_states)

        return y
    
    def _step(self, x, seq_len, dt, A, B, C):
        """Single recurrent update against ``self.cache.state_cache`` (generation).

        Dim 0 is squeezed away, and positions along ``seq_len`` are fed to the
        kernel as independent batch entries of the cached state.
        NOTE(review): this appears to assume bsz == 1 — confirm against callers.
        Per-head A, dt, D and dt_bias are broadcast to (num_heads, head_dim[, state_dim]).
        """
        x = x.squeeze(0)
        A = A[..., None, None].expand(self.num_heads, self.head_dim, self.state_dim)
        dt = dt.permute(1, 2, 0).expand(seq_len, self.num_heads, self.head_dim)
        D = self.D
        if D is not None and D.dim() == 1:
            D = D.unsqueeze(1).expand(self.num_heads, self.head_dim)
        B, C = B.squeeze(0), C.squeeze(0)
        y = selective_state_update(
            self.cache.state_cache,
            x,
            dt,
            A,
            B,
            C,
            D,
            z=None,
            dt_bias=(
                torch.zeros(self.num_heads, self.head_dim).to(x)
                if self.dt_bias is None
                else self.dt_bias.unsqueeze(1).expand(self.num_heads, self.head_dim)
            ),
            dt_softplus=True,
        ).unsqueeze(0)
        
        return y

    def forward(
        self,
        x: torch.Tensor,
        tok_idx: torch.Tensor | None = None,
        cu_seqlens: torch.Tensor | None = None,
        ssm_impl: str = "ssm",
    ) -> torch.Tensor:
        """Apply the SSM mixer.

        Args:
            x: Input of shape (bsz, seq_len, dim).
            tok_idx: Sequence ids for packed batches, forwarded as ``seq_idx`` to
                the conv and scan kernels.
            cu_seqlens: Cumulative sequence lengths; used in ``"ssm"`` mode when a
                cache is attached.
            ssm_impl: ``"ssm"`` (full sequence) or ``"ssm_update"`` (incremental,
                uses the attached cache).

        Returns:
            Tensor of shape (bsz, seq_len, dim).

        Raises:
            NotImplementedError: For an unsupported ``ssm_impl``.
        """
        bsz, seq_len, _ = x.shape

        zxbcdt = self.input(x)

        A = -torch.exp(self.A_log.float())
        initial_states = (
            self.init_states.expand(bsz, -1, -1, -1)
            if self.learnable_init_states else None
        )

        if self.conv_size is not None:
            # The fused kernel always scans the full sequence and never reads or
            # writes the inference cache, so it is only valid for cache-free "ssm"
            # calls; generation and cache prefill take the unfused path below.
            use_fused_path = (
                self.use_mem_eff_path
                and ssm_impl == "ssm"
                and not hasattr(self, "cache")
            )
            if use_fused_path:
                return mamba_split_conv1d_scan_combined(
                    zxbcdt,
                    self.conv1d.weight.squeeze(1),
                    self.conv1d.bias,
                    self.dt_bias,
                    A,
                    D=self.D,
                    chunk_size=self.ssm_chunk_size,
                    seq_idx=tok_idx,
                    activation="silu",
                    rmsnorm_weight=self.norm.weight,
                    rmsnorm_eps=self.norm.eps,
                    outproj_weight=self.output.weight,
                    outproj_bias=self.output.bias,
                    headdim=self.head_dim,
                    ngroups=self.num_groups,
                    norm_before_gate=False, # Post-norm, y = self.norm(y * F.silu(z))
                    initial_states=initial_states,
                    **self.dt_limit_kwargs,
                )

            # Unfused causal-conv path: forward the mode and sequence metadata so
            # "ssm_update" uses the conv cache and packed sequences stay separated.
            z, x, B, C, dt = self._causal_conv(
                zxbcdt,
                tok_idx=tok_idx,
                cu_seqlens=cu_seqlens,
                ssm_impl=ssm_impl,
            )
        else:
            # No convolution: plain split of the projection.
            z, x, B, C, dt = self._non_causal_conv(zxbcdt)

        x = x.view(bsz, seq_len, self.num_heads, self.head_dim)
        B = B.view(bsz, seq_len, self.num_groups, self.state_dim)
        C = C.view(bsz, seq_len, self.num_groups, self.state_dim)

        if ssm_impl == "ssm":
            # (bsz, seq_len, num_heads, head_dim)
            y = self._fwd(x, dt, A, B, C, tok_idx, cu_seqlens, initial_states)
        elif ssm_impl == "ssm_update":
            y = self._step(x, seq_len, dt, A, B, C)
        else:
            raise NotImplementedError(f"SSM implementation {ssm_impl} not supported")

        y = y.view(bsz, seq_len, self.hidden_dim)

        # Gated post-norm, matching norm_before_gate=False on the fused path:
        # y = norm(y * silu(z)).
        y = self.norm(y * F.silu(z))
        out = self.output(y)

        return out

    @torch.inference_mode()
    def reset_parameters(self, init_std, factor) -> None:
        """Re-initialize all parameters of this layer in place.

        Args:
            init_std: Fixed std for the linear and conv weights; when falsy, each
                weight uses fan-in scaling (dim**-0.5, hidden_dim**-0.5, conv_size**-0.5).
            factor: Divisor applied to the output-projection std only.
        """
        config = self.config
        init_config = config.init_config
        if init_config is None:
            init_config = DEFAULT_INIT_CONFIG

        # Linear layers (truncated normal at +/- 3 std)
        in_init_std = init_std or (self.dim ** (-0.5))
        out_init_std = init_std or (self.hidden_dim ** (-0.5))
        out_init_std = out_init_std / factor

        nn.init.trunc_normal_(
            self.input.weight,
            mean=0.0,
            std=in_init_std,
            a=-3 * in_init_std,
            b=3 * in_init_std,
        )

        nn.init.trunc_normal_(
            self.output.weight,
            mean=0.0,
            std=out_init_std,
            a=-3 * out_init_std,
            b=3 * out_init_std,
        )

        # dt bias: only sampled when trainable
        if self.dt_bias is not None and self.dt_bias.requires_grad:
            log_dt_min = math.log(init_config.dt_min)
            log_dt_max = math.log(init_config.dt_max)
            
            # Sample log_dt ~ Uniform[log_dt_min, log_dt_max]
            log_dt = torch.rand(self.num_heads, device=self.dt_bias.device) * (log_dt_max - log_dt_min) + log_dt_min
            dt = torch.exp(log_dt)
            dt = torch.clamp(dt, min=init_config.dt_init_floor)

            # Inverse of softplus so that softplus(dt_bias) == dt:
            # https://github.com/pytorch/pytorch/issues/72759
            inv_dt = dt + torch.log(-torch.expm1(-dt))
            self.dt_bias.copy_(inv_dt)

        elif self.dt_bias is not None:
            # Frozen dt_bias stays at zero.
            self.dt_bias.fill_(0.0)

        # Convolution
        if self.conv_size is not None:
            conv_std = init_std or (self.conv_size ** (-0.5))
            nn.init.trunc_normal_(
                self.conv1d.weight,
                mean=0.0,
                std=conv_std,
                a=-3 * conv_std,
                b=3 * conv_std,
            )
            if self.conv1d.bias is not None:
                nn.init.zeros_(self.conv1d.bias)

        # Learnable init states
        if self.learnable_init_states:
            self.init_states.zero_()

        # Initialize A_log ~ log( Uniform(A_init_min, A_init_max) )
        self.A_log.uniform_(init_config.A_init_min, init_config.A_init_max)
        self.A_log.log_()

        if self.D is not None:
            self.D.data.fill_(1.0)

        # Reset norm parameters
        self.norm.reset_parameters()


class MambaBlock(nn.Module):
    """Pre-norm residual wrapper around one SSM mixer: ``x + ssm(norm(x))``."""

    def __init__(self, config: BaseMambaConfig):
        super().__init__()
        # Registration order (norm, then ssm) is kept for state_dict stability.
        self.norm = build_norm(config.norm_type, dim=config.dim, eps=config.norm_eps)
        self.ssm = SSM(config)

    def forward(
        self,
        x: torch.Tensor,
        tok_idx: torch.Tensor | None,
        cu_seqlens: torch.Tensor | None,
        ssm_impl: str = "ssm",
    ) -> torch.Tensor:
        """Return ``x`` plus the SSM output computed on the normalized input."""
        normed = self.norm(x)
        mixed = self.ssm(
            normed,
            tok_idx=tok_idx,
            cu_seqlens=cu_seqlens,
            ssm_impl=ssm_impl,
        )
        return x + mixed

    @torch.inference_mode()
    def init_weights(self, init_std=None, factor=1.0):
        """Reset the pre-norm, then the SSM (with the given base std and depth factor)."""
        self.norm.reset_parameters()
        self.ssm.reset_parameters(init_std, factor)


class BaseMamba(nn.Module):
    """Stack of ``config.num_layers`` MambaBlocks applied sequentially."""

    def __init__(self, config: BaseMambaConfig):
        super().__init__()
        self.model_dim = config.dim
        self.init_base_std = config.init_base_std

        self.init_config = config.init_config
        self.init_std_factor = InitStdFactor(config.init_std_factor)

        self.layers = nn.ModuleList(
            [MambaBlock(config) for _ in range(config.num_layers)]
        )

    def forward(
        self,
        h: torch.Tensor,
        tok_idx: torch.Tensor | None,
        cu_seqlens: torch.Tensor | None,
        ssm_impl: str = "ssm",
    ) -> torch.Tensor:
        """Thread ``h`` through every block in order and return the result."""
        out = h
        for block in self.layers:
            out = block(out, tok_idx=tok_idx, cu_seqlens=cu_seqlens, ssm_impl=ssm_impl)
        return out

    @torch.inference_mode()
    def reset_parameters(self):
        """No model-level parameters here; subclasses override this."""
        pass

    def _layer_std_factor(self, depth: int) -> float:
        """Divisor for the output-projection init std of the layer at 0-based ``depth``."""
        kind = self.init_std_factor
        if kind is InitStdFactor.CURRENT_DEPTH:
            return (2 * (depth + 1)) ** 0.5
        if kind is InitStdFactor.GLOBAL_DEPTH:
            return (2 * (len(self.layers) + 1)) ** 0.5
        if kind is InitStdFactor.DIM_RATIO:
            return self.model_dim / 4096
        if kind is InitStdFactor.DISABLED:
            return 1.0
        raise KeyError(kind)

    @torch.inference_mode()
    def init_weights(self):
        """Reset model-level parameters, then each block with its depth-scaled factor."""
        self.reset_parameters()
        for depth, block in enumerate(self.layers):
            block.init_weights(self.init_base_std, self._layer_std_factor(depth))


@dataclass
class Mamba2Config(BaseMambaConfig):
    """Language-model settings layered on top of ``BaseMambaConfig``."""
    seed: int = 1337  # NOTE(review): not read anywhere in this file

    # Sentinel: Mamba2.__init__ asserts vocab_size > 0, so this must be set explicitly.
    vocab_size: int = -1
    weight_tying: bool = False                 # share tok_emb and output projection weights
    torch_dtype: torch.dtype = torch.bfloat16  # NOTE(review): not applied by the model code here

    loss_reduction: str = "mean"  # stored on Mamba2; its forward computes no loss itself

    use_attn: bool = False  # NOTE(review): not read anywhere in this file
    softcap: float = 50.0   # NOTE(review): not read anywhere in this file


class Mamba2(BaseMamba):
    """Mamba-2 language model.

    Pipeline: token embedding -> stacked MambaBlocks -> RMSNorm -> vocab projection.
    """

    def __init__(self, config: Mamba2Config) -> None:
        super().__init__(config)
        self.weight_tying = config.weight_tying
        self.loss_reduction = config.loss_reduction

        assert config.vocab_size > 0, "vocab_size must be set and > 0"

        # Construction order (embedding, norm, output) is kept so parameter
        # registration and default-init RNG consumption stay unchanged.
        self.tok_emb = torch.nn.Embedding(config.vocab_size, config.dim)
        self.norm = nn.RMSNorm(config.dim, eps=config.norm_eps)
        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)

        if config.weight_tying:
            # One Parameter shared by the input embedding and the output head.
            self.output.weight = self.tok_emb.weight

        print(f"Model Parameter Count: {self._get_num_params() / 1e6:.2f}M\n")

    def _get_num_params(self):
        """Count parameters, excluding positional embeddings and an untied token embedding."""
        total = sum(param.numel() for param in self.parameters())
        pos_emb = getattr(self, "pos_emb", None)
        if pos_emb is not None:
            total -= pos_emb.weight.numel()
        # A tied embedding is kept in the count because it doubles as the output head.
        if self.tok_emb.weight is not self.output.weight:
            total -= self.tok_emb.weight.numel()
        return total

    def forward(
        self,
        x: torch.Tensor,
        target: torch.Tensor | None = None,
        tok_idx: torch.Tensor | None = None,
        cu_seqlens: torch.Tensor | None = None,
        ssm_impl: str = "ssm",
    ) -> torch.Tensor:
        """Map token ids ``x`` to vocabulary logits.

        ``target`` is accepted but not used; no loss is computed here.
        """
        hidden = self.tok_emb(x)
        hidden = super().forward(hidden, tok_idx=tok_idx, cu_seqlens=cu_seqlens, ssm_impl=ssm_impl)
        return self.output(self.norm(hidden))

    @torch.inference_mode()
    def reset_parameters(self, init_std=None):
        """Reset the final norm, embedding and (if untied) output head.

        Uses ``init_std`` if given, otherwise ``model_dim ** -0.5``; weights are
        drawn from a normal truncated at +/- 3 std.
        """
        super().reset_parameters()
        std = init_std or (self.model_dim ** (-0.5))
        bound = 3 * std
        self.norm.reset_parameters()
        nn.init.trunc_normal_(self.tok_emb.weight, mean=0.0, std=std, a=-bound, b=bound)
        if not self.weight_tying:
            nn.init.trunc_normal_(self.output.weight, mean=0.0, std=std, a=-bound, b=bound)

    @torch.inference_mode()
    def init_weights(self, buffer_device: torch.device = None):
        """Initialize every parameter of the model.

        Args:
            buffer_device: Accepted for interface compatibility with models that
                build device-specific buffers; this model has no such buffers and
                the argument is currently ignored.
        """
        super().init_weights()

    @classmethod
    def from_model_args(cls, config: Mamba2Config) -> "Mamba2":
        """Build a ``Mamba2`` model from a ``Mamba2Config``."""
        return cls(config)


def get_mamba2_flops(
    seq_len: int,
    dim: int,
    num_layers: int,
    vocab_size: int,
    ffn_multiplier: float = 2.0,
    state_dim: int = 128,
    conv_size: int = 4,
    num_heads: int = 8,
    num_groups: int = 1,
    multiple_of: int = 256,
    include_input_embedding: bool = True,
    include_output_logits: bool = True,
    forward_backward_multiplier: float = 1.0,
) -> int:
    """
    Estimate the FLOPs for a Mamba-2 style model using a "Chinchilla-like" shape-based approach.

    By default, this returns the forward-pass cost. If you want a rough
    forward+backward estimate, set `forward_backward_multiplier=3.0` (common
    rule-of-thumb for these models).

    What gets counted:
    • Hidden dimension is rounded up to 'multiple_of' = 256 (as in Mamba).
    • Per-layer:
        1) Input Linear: [dim → 2*hidden_dim + 2*(groups*state_dim) + num_heads]
        2) Depthwise Conv1D: 2*(conv_dim * conv_size), where conv_dim=hidden_dim + 2*groups*state_dim
        3) SSM selective scan: ~9*(dim*state_dim) (from Mamba dev discussion)
        4) Output Linear: [hidden_dim → dim]
    • Each layer’s cost is multiplied by (seq_len * num_layers).
    • Optionally adds:
        - The cost of the input embedding (treating it as a matmul: seq_len×vocab_size × vocab_size×dim).
        - The cost of the final projection [dim → vocab_size].
    • Finally scaled by `forward_backward_multiplier` if desired.

    Args:
        seq_len (int): Sequence length (number of tokens).
        dim (int): Model (embedding) dimension.
        num_layers (int): Number of Mamba layers.
        vocab_size (int): Vocabulary size for final logits projection.
        ffn_multiplier (float): FFN expansion ratio, e.g. 2.0 => hidden_dim=2×dim (rounded up).
        state_dim (int): SSM state dimension (commonly 128).
        conv_size (int): Kernel size for the depthwise conv1d (default=4).
        num_heads (int): Number of heads (slightly affects input-lin out_dim).
        num_groups (int): For "grouped" states in some Mamba variants (usually 1).
        multiple_of (int): Round hidden_dim up to this multiple (commonly 256).
        include_input_embedding (bool): If True, count the cost of an “embedding matmul”
                                        for the input tokens => shape-based approach.
        include_output_logits (bool): If True, count the cost of final [dim → vocab_size].
        forward_backward_multiplier (float): E.g. 1.0 for forward only, 2.0 or 3.0 for forward+backward.

    Returns:
        int: Approximate total FLOPs (multiply-adds) for the selected pass(es),
            as an integer.
    """
    # 0) Input embedding (optional)
    flops_embedding = 0
    if include_input_embedding:
        flops_embedding = 2 * (seq_len * vocab_size * dim)

    # 1) Round up hidden_dim
    raw_hidden_dim = int(ffn_multiplier * dim)
    hidden_dim = multiple_of * ((raw_hidden_dim + multiple_of - 1) // multiple_of)

    # 2) Per-layer forward cost
    out_dim_input = 2*hidden_dim + 2*(num_groups*state_dim) + num_heads
    flops_input_linear = 2 * (dim * out_dim_input)
    conv_dim = hidden_dim + 2*(num_groups*state_dim)
    flops_conv = 2 * (conv_dim * conv_size)
    flops_ssm = 9 * state_dim * dim
    flops_output_linear = 2 * (hidden_dim * dim)
    flops_layer = (flops_input_linear + flops_conv + flops_ssm + flops_output_linear)

    # Multiply by #layers and sequence length
    flops_layers = flops_layer * num_layers * seq_len

    # 3) Final projection [dim → vocab_size] (optional)
    flops_vocab = 0
    if include_output_logits:
        flops_vocab = 2 * (seq_len * dim * vocab_size)

    # 4) Total forward FLOPs
    flops_forward = flops_embedding + flops_layers + flops_vocab

    # 5) Scale for forward+backward if desired
    return int(flops_forward * forward_backward_multiplier)

def get_mamba2_flops_per_token(
    **kwargs
) -> float:
    """
    Estimate FLOPs per token for a Mamba-2 style model.

    Fills in defaults for the optional settings, checks the required keys,
    calls ``get_mamba2_flops`` and divides the total by ``seq_len``. Keys not
    listed below are ignored, so a whole config can be splatted in.

    The expansion ratio is read from ``ffn_dim_multiplier`` (the config field
    name) and forwarded as ``get_mamba2_flops``'s ``ffn_multiplier``.

    Args:
        **kwargs: Required: ``seq_len``, ``dim``, ``num_layers``, ``vocab_size``.
            Optional: ``ffn_dim_multiplier``, ``state_dim``, ``conv_size``,
            ``num_heads``, ``num_groups``, ``multiple_of``,
            ``include_input_embedding``, ``include_output_logits``,
            ``forward_backward_multiplier``.

    Returns:
        float: Approximate FLOPs per token.

    Raises:
        ValueError: If a required key is missing (the first one in the order above).
    """
    params = {
        'ffn_dim_multiplier': 2.0,
        'state_dim': 128,
        'conv_size': 4,
        'num_heads': 8,
        'num_groups': 1,
        'multiple_of': 256,
        'include_input_embedding': True,
        'include_output_logits': True,
        'forward_backward_multiplier': 1.0,
        **kwargs,
    }

    missing = [key for key in ('seq_len', 'dim', 'num_layers', 'vocab_size') if key not in params]
    if missing:
        raise ValueError(f"Missing required parameter: {missing[0]}")

    seq_len = params['seq_len']
    total_flops = get_mamba2_flops(
        seq_len=seq_len,
        dim=params['dim'],
        num_layers=params['num_layers'],
        vocab_size=params['vocab_size'],
        ffn_multiplier=params['ffn_dim_multiplier'],
        state_dim=params['state_dim'],
        conv_size=params['conv_size'],
        num_heads=params['num_heads'],
        num_groups=params['num_groups'],
        multiple_of=params['multiple_of'],
        include_input_embedding=params['include_input_embedding'],
        include_output_logits=params['include_output_logits'],
        forward_backward_multiplier=params['forward_backward_multiplier'],
    )
    return total_flops / seq_len


# Optional activation-checkpointing policy. Returning None instead falls back to the
# default set (default_no_recompute_ops, defined in distributed.py).
def get_no_recompute_ops():
    """Return the set of ops whose outputs are saved rather than recomputed.

    Includes matmuls, the reduce-scatter collective and the Mamba chunked-scan
    forward. ``abs``/``max`` are included because low-precision training
    commonly computes ``max(abs(tensor))`` and it is cheap to keep.
    """
    aten = torch.ops.aten
    saved_ops = [
        aten.mm.default,
        aten._scaled_mm.default,
        torch.ops.c10d_functional.reduce_scatter_tensor.default,
        torch.ops.mamba_ssm.ssm_chunk_scan_combined_fwd.default,
        aten.abs.default,
        aten.max.default,
    ]
    return set(saved_ops)


def main():
    """Smoke test on CUDA: run the upstream ``mamba_ssm.Mamba2`` layer and this
    file's ``Mamba2`` language model on random inputs and print their outputs
    and summary statistics."""
    from mamba_ssm import Mamba2 as MambaRef

    def _report(tensor, heading, mean_label, std_label):
        # Print the tensor, then its mean and standard deviation.
        print(heading, tensor)
        print(mean_label, tensor.mean().item())
        print(std_label, tensor.std().item())

    features = torch.randn(2, 64, 192).cuda()

    # Reference: a single upstream Mamba-2 mixer on float features.
    reference = MambaRef(
        d_model=192,
        expand=2,
        d_conv=4,
        d_state=64,
        headdim=48,
    ).cuda()
    ref_out = reference(features)
    _report(
        ref_out,
        "Mamba reference output: ",
        "Mean of MambaRef output: ",
        "Stddev of MambaRef output: ",
    )

    # This file's language model consumes integer token ids, not float features.
    lm_config = Mamba2Config(vocab_size=200064, use_mem_eff_path=True)
    lm = Mamba2(config=lm_config).cuda()

    token_ids = torch.randint(0, lm_config.vocab_size, (2, 64), dtype=torch.long).cuda()

    logits = lm(token_ids)
    _report(
        logits,
        "Mamba output: ",
        "Mean of Mamba output: ",
        "Stddev of Mamba output: ",
    )

if __name__ == "__main__":
    main()