yagizdevre commited on
Commit
b9b3e2d
·
1 Parent(s): b31e1c7

configs added

Browse files
__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .configuration_minimamba import MiniMambaConfig
2
+ from .modeling_minimamba import MiniMamba
added_tokens.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "<|endofprompt|>": 200018
3
+ }
attn.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ try:
9
+ from flash_attn import flash_attn_func
10
+ except ImportError as e:
11
+ print(
12
+ f"Unable to import Triton-based flash attention: {e}. No alternative currently available."
13
+ )
14
+
15
+
16
def nearest_power_of_two(x: int, round_up: bool = False) -> int:
    """Return the power of two nearest to x: floor by default, ceil if round_up."""
    exponent = math.ceil(math.log2(x)) if round_up else math.floor(math.log2(x))
    return 1 << exponent
20
+
21
+ def _generate_slopes(self, n: int):
22
+ start = 2 ** (-(2 ** -(math.log2(n) - 3)))
23
+ return [start * (start**i) for i in range(n)]
24
+
25
+ def _get_alibi_slopes(self, n_heads: int, interpolation_factor: float = 0.25):
26
+ # If n_heads is a power of 2, generate slopes directly
27
+ if math.log2(n_heads).is_integer():
28
+ slopes = self._generate_slopes(n_heads)
29
+ else:
30
+ # Get slopes for the nearest power of two
31
+ n = nearest_power_of_two(n_heads, round_up=False)
32
+ slopes_power_of_two = self._generate_slopes(n)
33
+
34
+ # Generate extra slopes
35
+ extra_slopes = self._generate_slopes(2 * n)
36
+ extra_slopes_trunc = extra_slopes[0::2][: n_heads - n]
37
+ slopes = slopes_power_of_two + extra_slopes_trunc
38
+ slopes = torch.tensor(slopes, device=self.device)
39
+ slopes = slopes * interpolation_factor # https://arxiv.org/pdf/2310.13017
40
+ return slopes
41
+
42
+
43
def precompute_freqs_cis(head_dim: int, max_seq_len: int, theta: float = 10000.0):
    """Precompute the complex rotary-embedding table, shape [max_seq_len, head_dim // 2]."""
    # Inverse frequency for each even dimension index.
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))

    # Angle for every (position, frequency) pair.
    positions = torch.arange(max_seq_len, dtype=torch.float32)
    angles = torch.outer(positions, inv_freq)

    # Unit-magnitude complex exponentials e^{i * angle}.
    return torch.polar(torch.ones_like(angles), angles)
58
+
59
+
60
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """
    Slice freqs_cis ([max_seq_len, half_dim]) down to x's sequence length and
    view it as [1, 1, seq_len, half_dim] so it broadcasts against
    x of shape [B, n_heads, seq_len, head_dim_as_complex].
    """
    seq_len = x.size(2)
    return freqs_cis[:seq_len].view(1, 1, seq_len, -1)
69
+
70
+
71
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rotate query/key tensors by the precomputed complex frequencies (RoPE)."""
    # Pair up the last dimension and reinterpret as complex:
    # [B, n_heads, seq_len, head_dim] -> [B, n_heads, seq_len, head_dim // 2] complex.
    q_complex = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    k_complex = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))

    # Broadcast the frequency table over batch and heads, then rotate by
    # complex multiplication.
    rotation = reshape_for_broadcast(freqs_cis, q_complex)
    q_rotated = torch.view_as_real(q_complex * rotation).reshape(*xq.shape)
    k_rotated = torch.view_as_real(k_complex * rotation).reshape(*xk.shape)

    # Restore the callers' dtypes.
    return q_rotated.type_as(xq), k_rotated.type_as(xk)
92
+
93
+
94
class Attention(nn.Module):
    """Multi-head causal attention via flash-attn, positioned with RoPE or ALiBi.

    Exactly one positional scheme is active: when `config.use_alibi` is set,
    ALiBi slopes are handed to flash-attn; otherwise rotary embeddings are
    applied to q/k before attention.
    """

    def __init__(self, config):
        super(Attention, self).__init__()
        self.dim, self.num_heads = config.dim, config.num_heads
        assert config.dim % config.num_heads == 0, f"dim ({self.dim}) must be divisible num_heads ({self.num_heads})"
        self.head_dim = config.dim // config.num_heads

        # Fused QKV projection and output projection.
        self.c_attn = nn.Linear(self.dim, 3 * self.dim, bias=config.bias)
        self.c_proj = nn.Linear(config.dim, config.dim, bias=config.bias)
        self.c_proj.SCALE_INIT = 1

        self.alibi_slopes = self._get_alibi_slopes(self.num_heads) if config.use_alibi else None
        self.window_size = config.window_size
        self.softcap = config.softcap

        self.dropout = config.dropout
        self.resid_dropout = nn.Dropout(self.dropout)

    def _generate_slopes(self, n: int):
        """Geometric ALiBi slope sequence for n heads (n assumed a power of two)."""
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        return [start * (start**i) for i in range(n)]

    def _get_alibi_slopes(self, num_heads: int, interpolation_factor: float = 0.25):
        """Per-head ALiBi slopes, interpolated per https://arxiv.org/pdf/2310.13017."""
        # If num_heads is a power of 2, generate slopes directly.
        if math.log2(num_heads).is_integer():
            slopes = self._generate_slopes(num_heads)
        else:
            # Slopes for the nearest power of two, plus extra slopes at twice
            # that resolution to cover the remaining heads.
            n = nearest_power_of_two(num_heads, round_up=False)
            slopes_power_of_two = self._generate_slopes(n)
            extra_slopes = self._generate_slopes(2 * n)
            extra_slopes_trunc = extra_slopes[0::2][: num_heads - n]
            slopes = slopes_power_of_two + extra_slopes_trunc
        # Bug fix: this tensor was previously created directly on
        # torch.device("cuda"), which broke CPU-only construction. It is now
        # created on the default device and moved next to q in forward().
        slopes = torch.tensor(slopes)
        slopes = slopes * interpolation_factor  # https://arxiv.org/pdf/2310.13017
        return slopes

    def forward(
        self,
        x: torch.Tensor = None,
        q: torch.Tensor = None,
        k: torch.Tensor = None,
        v: torch.Tensor = None,
        freqs_cis: torch.Tensor = None,
    ) -> torch.Tensor:
        """Self-attention over x; returns a tensor of shape [bsz, q_len, dim].

        Raises:
            ValueError: if neither x nor all of q/k/v are provided.
            NotImplementedError: if distinct q/k/v are given without x — the
                fused c_attn projection only supports a single input stream.
        """
        if x is not None:
            q = k = v = x
        if any(t is None for t in [q, k, v]):
            raise ValueError("Must provide either x for self-attention or q/k/v for cross-attention.")
        if x is None:
            # Bug fix: the fused projection below consumed `x` even on the
            # q/k/v path and crashed on x=None. Fail loudly instead of
            # silently computing self-attention on the wrong tensor.
            raise NotImplementedError(
                "Cross-attention with distinct q/k/v is not supported by the fused c_attn projection."
            )

        bsz, q_len, dim = q.shape
        _, k_len, _ = k.shape
        _, v_len, _ = v.shape

        qkv = self.c_attn(x)
        q, k, v = torch.chunk(qkv, 3, dim=2)

        # flash-attn expects [bsz, seq_len, num_heads, head_dim].
        q = q.view(bsz, q_len, self.num_heads, self.head_dim)
        k = k.view(bsz, k_len, self.num_heads, self.head_dim)
        v = v.view(bsz, v_len, self.num_heads, self.head_dim)

        if self.alibi_slopes is None:  # Use either ALiBi or RoPE
            q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis)

        # Move slopes to the input's device lazily (see _get_alibi_slopes).
        alibi_slopes = self.alibi_slopes.to(q.device) if self.alibi_slopes is not None else None
        y = flash_attn_func(  # https://arxiv.org/pdf/2307.08691
            q=q, k=k, v=v,
            dropout_p=self.dropout if self.training else 0.0,
            causal=True,
            window_size=(self.window_size, 0),  # Set to config.seq_len if full attention
            alibi_slopes=alibi_slopes,  # https://arxiv.org/pdf/2108.12409
            softcap=self.softcap,  # https://arxiv.org/pdf/2408.00118
        )

        y = y.contiguous().view(bsz, q_len, -1)
        y = self.resid_dropout(self.c_proj(y))
        return y
172
+
173
+
174
class MLP(nn.Module):
    """Gated feed-forward block (GELU-gated; https://arxiv.org/pdf/2002.05202)."""

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.dim
        self.intermediate_size = config.dim * config.mlp_scale
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        # Gate branch (tanh-approximated GELU) modulates the up-projection.
        activated_gate = F.gelu(self.gate_proj(x), approximate="tanh")
        fused = activated_gate * self.up_proj(x)
        # Project back down, then apply dropout.
        return self.dropout(self.down_proj(fused))
193
+
194
+
195
class AttentionLayer(nn.Module):
    """Pre-norm transformer block: attention then gated MLP, each with a residual."""

    def __init__(self, config) -> None:
        super(AttentionLayer, self).__init__()
        self.attn_norm = nn.RMSNorm(config.dim)
        self.attn = Attention(config=config)
        self.mlp_norm = nn.RMSNorm(config.dim)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor = None) -> torch.Tensor:
        # Residual around attention, then around the MLP (pre-norm ordering).
        attn_out = self.attn(x=self.attn_norm(x), freqs_cis=freqs_cis)
        x = x + attn_out
        mlp_out = self.mlp(self.mlp_norm(x))
        return x + mlp_out
casual_conv1d_compilable.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple
2
+ import torch
3
+ import causal_conv1d_cuda
4
+
5
# Causal Conv1D forward: registered as a torch custom op so it traces and
# compiles cleanly while dispatching to the CUDA kernel.
@torch.library.custom_op(
    "mamba_causal_conv1d::causal_conv1d_fwd",
    mutates_args=(),
    device_types="cuda",
)
def causal_conv1d_fwd(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    seq_idx: Optional[torch.Tensor] = None,
    activation: Optional[str] = None,
) -> torch.Tensor:
    """Run the causal depthwise conv1d CUDA kernel on x, optionally with SiLU."""
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")

    # The kernel needs x contiguous along at least one of the last two dims.
    if x.stride(2) != 1 and x.stride(1) != 1:
        x = x.contiguous()

    # Optional tensors must also be contiguous for the kernel.
    bias = bias.contiguous() if bias is not None else None
    seq_idx = seq_idx.contiguous() if seq_idx is not None else None

    # The CUDA kernel takes a boolean activation flag.
    apply_silu = activation in ["silu", "swish"]

    return causal_conv1d_cuda.causal_conv1d_fwd(
        x, weight, bias, seq_idx, None, None, apply_silu
    )
38
+
39
# FakeTensor implementation used during tracing/compile: shape propagation only.
@causal_conv1d_fwd.register_fake
def _causal_conv1d_fwd_fake(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    seq_idx: Optional[torch.Tensor] = None,
    activation: Optional[str] = None,
) -> torch.Tensor:
    # The channel count of x must match the depthwise filter count.
    torch._check(x.shape[-2] == weight.shape[0])
    # Output has the same shape/dtype/device as the input.
    return torch.empty_like(x)
50
+
51
# Causal Conv1D backward: a separate custom op that wraps the CUDA kernel.
@torch.library.custom_op(
    "mamba_causal_conv1d::causal_conv1d_bwd",
    mutates_args=(),
    device_types="cuda",
)
def causal_conv1d_bwd(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor],
    dout: torch.Tensor,
    seq_idx: Optional[torch.Tensor],
    activation: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Gradients of causal_conv1d_fwd w.r.t. (x, weight, bias)."""
    # The kernel needs dout contiguous along at least one of the last two dims.
    if dout.stride(2) != 1 and dout.stride(1) != 1:
        dout = dout.contiguous()

    dx, dweight, dbias, _ = causal_conv1d_cuda.causal_conv1d_bwd(
        x, weight, bias, dout, seq_idx, None, None, None, False, activation
    )

    # Custom ops must return tensors, so stand in an empty one when bias is
    # absent; the autograd bridge converts it back to None.
    dbias = dbias if bias is not None else torch.empty((0,), device=dout.device)

    return dx, dweight, dbias
78
+
79
# FakeTensor implementation of the backward op.
# Bug fix: the real op always returns a Tensor for dbias (an empty (0,) tensor
# when bias is None) and is annotated as Tuple[Tensor, Tensor, Tensor]; this
# fake previously returned None in that case, desynchronizing it from the op
# schema. It now mirrors the real op exactly.
@causal_conv1d_bwd.register_fake
def _causal_conv1d_bwd_fake(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor],
    dout: torch.Tensor,
    seq_idx: Optional[torch.Tensor],
    activation: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    dbias = (
        torch.empty_like(bias)
        if bias is not None
        else torch.empty((0,), device=dout.device)
    )
    return torch.empty_like(x), torch.empty_like(weight), dbias
94
+
95
# Stash everything backward needs on the autograd context.
def causal_conv1d_setup_context(ctx, inputs, output):
    x, weight, bias, seq_idx, activation = inputs
    # Translate the activation string into the boolean flag backward expects.
    ctx.activation = activation in ["silu", "swish"]
    ctx.save_for_backward(x, weight, bias, seq_idx)
100
+
101
# Autograd bridge: unpack saved tensors and dispatch to the backward op.
def causal_conv1d_bwd_bridge(ctx, dout):
    x, weight, bias, seq_idx = ctx.saved_tensors
    dx, dweight, dbias = causal_conv1d_bwd(x, weight, bias, dout, seq_idx, ctx.activation)

    # The op returns a placeholder tensor for dbias when bias is None;
    # autograd expects None for absent inputs.
    dbias = dbias if bias is not None else None
    # Gradients for (x, weight, bias, seq_idx, activation); the last two are
    # not differentiable.
    return dx, dweight, dbias, None, None


# Wire the forward op, context setup, and backward bridge into autograd.
torch.library.register_autograd(
    "mamba_causal_conv1d::causal_conv1d_fwd",
    causal_conv1d_bwd_bridge,
    setup_context=causal_conv1d_setup_context,
)
116
+
117
# Public entry point mirroring causal_conv1d's API on top of the custom op.
def causal_conv1d_fn(x, weight, bias=None, seq_idx=None, activation=None):
    return causal_conv1d_fwd(x, weight, bias, seq_idx, activation)
120
+
121
+
122
@torch.library.custom_op(
    "mamba_causal_conv1d::causal_conv1d_update",
    mutates_args=(),
    device_types="cuda",
)
def causal_conv1d_update_fwd(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    activation: Optional[str] = None,
    cache_seqlens: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    x: (batch, dim) or (batch, dim, seqlen)
    conv_state: (batch, dim, state_len), where state_len >= width - 1
    weight: (dim, width)
    bias: (dim,)
    cache_seqlens: (batch,), dtype int32.
        If not None, the conv_state is treated as a circular buffer.
        The conv_state will be updated by copying x to the conv_state starting at the index
        @cache_seqlens % state_len.

    out: (batch, dim) or (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    apply_silu = activation in ["silu", "swish"]

    # The kernel works on (batch, dim, seqlen); add a length-1 axis if needed
    # and strip it again afterwards.
    squeeze_back = x.dim() == 2
    if squeeze_back:
        x = x.unsqueeze(-1)
    out = causal_conv1d_cuda.causal_conv1d_update(
        x, conv_state, weight, bias, apply_silu, cache_seqlens
    )
    return out.squeeze(-1) if squeeze_back else out
159
+
160
# FakeTensor implementation for the update op: shape propagation only.
@causal_conv1d_update_fwd.register_fake
def _causal_conv1d_update_fwd(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    activation: Optional[str] = None,
    cache_seqlens: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # Output matches the input's shape/dtype/device.
    return torch.empty_like(x)
170
+
171
# Convenience wrapper matching causal_conv1d's incremental-update API.
def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
    return causal_conv1d_update_fwd(x, conv_state, weight, bias, activation, cache_seqlens)
173
+
174
# Smoke test: compare the eager custom op, its compiled form, and the
# reference causal_conv1d implementation (requires a CUDA device).
if __name__ == "__main__":
    from causal_conv1d import causal_conv1d_fn as causal_conv1d_fn_ref

    torch.manual_seed(0)

    x = torch.randn(8, 32, 16, device="cuda", requires_grad=True)
    weight = torch.randn(32, 3, device="cuda", requires_grad=True)
    bias = None  # torch.randn(32, device="cuda", requires_grad=True)

    def _run_and_report(fn):
        # Forward + backward, then print summary statistics of out and grads.
        out = fn(x, weight, bias, activation="silu")
        out.sum().backward()
        print(out.min(), out.max(), out.mean(), out.std())
        print(x.grad.min(), x.grad.max(), x.grad.mean(), x.grad.std())
        print(weight.grad.min(), weight.grad.max(), weight.grad.mean(), weight.grad.std())

    print("Custom Implementation")
    _run_and_report(causal_conv1d_fn)

    # Try compiling the function using torch.compile.
    x.grad.zero_(), weight.grad.zero_()
    compiled_conv1d = torch.compile(causal_conv1d_fn)
    print(compiled_conv1d)

    print("Compiled Implementation")
    _run_and_report(compiled_conv1d)

    print("Reference Implementation")
    x.grad.zero_(), weight.grad.zero_()
    _run_and_report(causal_conv1d_fn_ref)
config.json ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "minimamba",
3
+ "_name_or_path": "Mamba_546M",
4
+ "architectures": ["MiniMamba"],
5
+ "dim": 896,
6
+ "num_layers": 56,
7
+ "num_heads": 32,
8
+ "state_dim": 128,
9
+ "num_groups": 1,
10
+ "conv_size": 4,
11
+ "use_mem_eff_path": true,
12
+ "dt_bias": true,
13
+ "D_has_head_dim": true,
14
+ "learnable_init_states": false,
15
+ "ssm_chunk_size": 256,
16
+ "vocab_size": 200064,
17
+ "mlp_scale": 2,
18
+ "multiple_of": 256,
19
+ "norm_eps": 1e-5,
20
+ "init_use_depth": false,
21
+ "init_base_std": null,
22
+ "init_std_factor": "disabled",
23
+ "hidden_act": "silu",
24
+ "bias": false,
25
+ "torch_dtype": "bfloat16",
26
+ "seed": 1337,
27
+ "init_args": {
28
+ "dt_max": 0.1,
29
+ "dt_min": 0.001,
30
+ "dt_init_floor": 1e-4,
31
+ "A_init_min": 0.01,
32
+ "A_init_max": 16
33
+ },
34
+ "seq_len": 8192,
35
+ "window_size": 1024,
36
+ "weight_tying": true,
37
+ "dropout": 0.0,
38
+ "num_epochs": 1,
39
+ "global_bsz": 524288,
40
+ "bsz": 1,
41
+ "warmup_steps": 1907,
42
+ "eval_period": 50,
43
+ "save_period": 500,
44
+ "max_lr": 3.0e-4,
45
+ "min_lr": 3.0e-5,
46
+ "max_norm": 1.0,
47
+ "dilation": 1,
48
+ "fsdp": false,
49
+ "ddp": true,
50
+ "mixed_precision": true,
51
+ "cpu_offload": false,
52
+ "sharding_strategy": "full_shard",
53
+ "state_dict_type": "full",
54
+ "auto_wrap_policy": "partial",
55
+ "backward_prefetch": "backward_pre",
56
+ "forward_prefetch": false,
57
+ "sync_module_states": true,
58
+ "use_orig_params": true,
59
+ "device_id": null,
60
+ "precision": {
61
+ "param": "bfloat16",
62
+ "reduce": "bfloat16",
63
+ "buffer": "bfloat16"
64
+ },
65
+ "fsdp_modules": [
66
+ "MambaBlock",
67
+ "AttentionLayer"
68
+ ],
69
+ "use_activation_checkpointing": true,
70
+ "use_attn": true,
71
+ "use_alibi": true,
72
+ "softcap": 50.0,
73
+ "theta": 10000.0,
74
+ "torch_compile": false
75
+ }
configuration_minimamba.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
class MiniMambaConfig(PretrainedConfig):
    """
    Minimal or extended config class for MiniMamba.
    Inherits from HF's PretrainedConfig so we can do:
        model = MiniMamba.from_pretrained(...)
    and it will load this config automatically.

    This config includes all fields from the provided config.json.

    Bug fixes in this revision:
      - `mlp_scale` was accepted as a parameter but never stored, so
        `config.mlp_scale` was missing for modules that read it; it is now
        assigned.
      - `architectures` used a mutable list default; it now defaults to None.
      - `window_size`, `use_alibi`, and `theta` (read by the attention layer,
        present in config.json) are now explicit parameters with the same
        defaults instead of relying on the kwargs passthrough.
    """
    model_type = "minimamba"

    def __init__(
        self,
        # Standard HF fields:
        model_type="minimamba",
        _name_or_path="Mamba_5460M",
        architectures=None,

        # Key Mamba architecture hyperparameters:
        dim=896,
        num_layers=56,
        num_heads=32,
        state_dim=128,
        num_groups=1,
        conv_size=4,
        use_mem_eff_path=True,
        dt_bias=True,
        D_has_head_dim=True,
        learnable_init_states=False,
        ssm_chunk_size=256,
        vocab_size=200064,
        mlp_scale=2,
        ffn_dim_multiplier=2.0,
        multiple_of=256,
        norm_eps=1e-5,
        init_use_depth=False,
        init_base_std=None,
        init_std_factor="disabled",
        hidden_act="silu",
        bias=False,

        # Torch / training:
        torch_dtype="bfloat16",
        seed=1337,

        # The init_config block nested in JSON:
        init_args=None,  # e.g. dict with dt_max, dt_min, dt_init_floor, ...

        # Additional Mamba or training fields:
        seq_len=8192,
        weight_tying=True,
        dropout=0.0,
        num_epochs=1,
        global_bsz=524288,
        bsz=1,
        warmup_steps=1907,
        eval_period=50,
        save_period=500,
        max_lr=0.0003,
        min_lr=3e-5,
        max_norm=1.0,
        dilation=1,
        fsdp=False,
        ddp=True,
        mixed_precision=True,
        cpu_offload=False,
        sharding_strategy="full_shard",
        state_dict_type="full",
        auto_wrap_policy="partial",
        backward_prefetch="backward_pre",
        forward_prefetch=False,
        sync_module_states=True,
        use_orig_params=True,
        device_id=None,
        precision=None,  # e.g. dict with param="bfloat16", reduce="bfloat16", buffer="bfloat16"
        fsdp_modules=None,  # e.g. ["MambaBlock"]
        use_activation_checkpointing=True,
        use_attn=True,
        softcap=50.0,
        torch_compile=True,
        # New explicit attention-related fields (appended to keep positional
        # compatibility with existing callers; defaults match config.json):
        window_size=1024,
        use_alibi=True,
        theta=10000.0,

        # Now accept arbitrary additional kwargs, to remain flexible:
        **kwargs
    ):
        # Avoid a shared mutable default for the architectures list.
        if architectures is None:
            architectures = ["MiniMamba"]
        super().__init__(
            # In HF, these common keys are typically passed to the parent:
            model_type=model_type,
            _name_or_path=_name_or_path,
            architectures=architectures,
            **kwargs
        )

        self.dim = dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.state_dim = state_dim
        self.num_groups = num_groups
        self.conv_size = conv_size
        self.use_mem_eff_path = use_mem_eff_path
        self.dt_bias = dt_bias
        self.D_has_head_dim = D_has_head_dim
        self.learnable_init_states = learnable_init_states
        self.ssm_chunk_size = ssm_chunk_size
        self.vocab_size = vocab_size
        # Bug fix: previously accepted but never stored.
        self.mlp_scale = mlp_scale
        self.ffn_dim_multiplier = ffn_dim_multiplier
        self.multiple_of = multiple_of
        self.norm_eps = norm_eps
        self.init_use_depth = init_use_depth
        self.init_base_std = init_base_std
        self.init_std_factor = init_std_factor
        self.hidden_act = hidden_act
        self.bias = bias

        self.torch_dtype = torch_dtype
        self.seed = seed

        # Nested init_args (dt_max, dt_min, etc.).
        # Could store it as a dict, or parse out the fields individually:
        self.init_args = init_args or {}

        self.seq_len = seq_len
        self.weight_tying = weight_tying
        self.dropout = dropout
        self.num_epochs = num_epochs
        self.global_bsz = global_bsz
        self.bsz = bsz
        self.warmup_steps = warmup_steps
        self.eval_period = eval_period
        self.save_period = save_period
        self.max_lr = max_lr
        self.min_lr = min_lr
        self.max_norm = max_norm
        self.dilation = dilation
        self.fsdp = fsdp
        self.ddp = ddp
        self.mixed_precision = mixed_precision
        self.cpu_offload = cpu_offload
        self.sharding_strategy = sharding_strategy
        self.state_dict_type = state_dict_type
        self.auto_wrap_policy = auto_wrap_policy
        self.backward_prefetch = backward_prefetch
        self.forward_prefetch = forward_prefetch
        self.sync_module_states = sync_module_states
        self.use_orig_params = use_orig_params
        self.device_id = device_id
        self.precision = precision
        self.fsdp_modules = fsdp_modules
        self.use_activation_checkpointing = use_activation_checkpointing
        self.use_attn = use_attn
        self.softcap = softcap
        self.torch_compile = torch_compile
        self.window_size = window_size
        self.use_alibi = use_alibi
        self.theta = theta

        # If you want to store any leftover kwargs:
        self.extra_args = kwargs
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.py ADDED
@@ -0,0 +1,788 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Bug fix: the file previously began with a stray bare line `model.py`, which
# is evaluated as an attribute access on an undefined name and raises
# NameError the moment the module is imported. It has been removed.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from enum import Enum
from dataclasses import dataclass, field
from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states
from mamba_ssm.ops.triton.selective_state_update import selective_state_update

# --- TODO: These two are always compiled even when kernel.compile is disabled. We should fix this. ---
from causal_conv1d_compilable import causal_conv1d_fn, causal_conv1d_update
from ssm_compilable import mamba_chunk_scan_combined
# -----------------------------------------------------------------------------------------------------

from .norms import build_norm
from .attn import AttentionLayer
from .attn import precompute_freqs_cis
from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
21
+
22
+
23
class InitStdFactor(Enum):
    """How the base initialization std is scaled before use."""

    DISABLED = "disabled"  # Init std is divided by 1.0
    GLOBAL_DEPTH = "global_depth"  # Init std is divided by sqrt(2*num_layers)
    CURRENT_DEPTH = "current_depth"  # Init std is divided by sqrt(2*depth)
    DIM_RATIO = "dim_ratio"  # Init std is divided by model_dim/4096
28
+
29
+
30
@dataclass
class InitConfig:
    """Initialization hyperparameters for the SSM dt and A parameters."""

    # dt (time-step) initialization range and floor.
    dt_max: float = 0.1
    dt_min: float = 0.001
    dt_init_floor: float = 1e-4

    # Range for the A parameter's initialization.
    A_init_min: float = 1
    A_init_max: float = 16


DEFAULT_INIT_CONFIG = InitConfig()
42
+
43
+
44
@dataclass
class BaseMambaConfig:
    """
    Configuration for the Mamba family of models.

    NOTE: field order is part of the dataclass-generated __init__ signature
    and must not be reordered.
    """
    dim: int = 512
    num_layers: int = 8
    num_heads: int = 8

    state_dim: int = 128
    num_groups: int = 1
    conv_size: int | None = 4

    bias: bool = False  # Linear bias
    conv_bias: bool = True  # Convolutional bias
    dt_bias: bool = False
    D_has_head_dim: bool = False
    learnable_init_states: bool = False

    mlp_scale: int = 2
    multiple_of: int = 256  # Enforce that MLP hidden layer size is multiple of a large power of 2

    norm_eps: float = 1e-6
    norm_type: str = "rmsnorm"

    # CUDA-related items
    ssm_chunk_size: int = 256
    use_mem_eff_path: bool = False

    # Initialization-related items
    init_use_depth: bool = False
    init_base_std: float | None = None
    init_std_factor: str = "disabled"  # e.g. "global_depth"
    init_config: InitConfig = field(default_factory=InitConfig)
78
+
79
+
80
+ class SSM(nn.Module):
81
+ """
82
+ State Space Model (SSM) implementation with selective state updates and convolution.
83
+
84
+ Implements the core SSM computation with support for both training and inference modes.
85
+ During inference, uses cached states for efficient token-by-token generation.
86
+ """
87
+ def __init__(self, config: BaseMambaConfig) -> None:
88
+ """Initialize SSM parameters and layers.
89
+ Args:
90
+ config: Configuration containing model hyperparameters
91
+ """
92
+ super().__init__()
93
+ self.config = config
94
+ vars(self).update(vars(config))
95
+
96
+ assert self.dim > 0, "Model dimension (config.dim) must be positive"
97
+ assert self.num_heads > 0, "Number of heads (config.num_heads) must be positive"
98
+ assert self.state_dim > 0, "State dimension (config.state_dim) must be positive"
99
+
100
+ if self.mlp_scale is None:
101
+ raise ValueError(
102
+ "mlp_scale must be set to a valid float (e.g. 2.0) "
103
+ "to determine hidden_dim in SSM."
104
+ )
105
+ assert self.mlp_scale > 0, "mlp_scale must be > 0"
106
+
107
+ self.hidden_dim = int(self.mlp_scale * self.dim)
108
+ self.hidden_dim = config.multiple_of * ( # Round up to multiple_of
109
+ (self.hidden_dim + self.multiple_of - 1) // self.multiple_of
110
+ )
111
+
112
+ assert self.hidden_dim % self.num_heads == 0, (
113
+ f"Hidden dim {self.hidden_dim} not divisible by num_heads={self.num_heads}."
114
+ )
115
+
116
+ self.head_dim = self.hidden_dim // self.num_heads
117
+
118
+ self.dt_limit_kwargs = {}
119
+ dt_limit = (self.init_config.dt_min, self.init_config.dt_max)
120
+ if dt_limit != (0.0, float("inf")):
121
+ self.dt_limit_kwargs = dict(dt_limit=dt_limit)
122
+
123
+ # Order: [z, x, B, C, dt]
124
+ d_input = (
125
+ 2 * self.hidden_dim
126
+ + 2 * self.num_groups * self.state_dim
127
+ + self.num_heads
128
+ )
129
+
130
+ self.input = nn.Linear(self.dim, d_input, bias=self.bias)
131
+
132
+ # Only create Conv1d if self.conv_size is specified
133
+ if self.conv_size is not None:
134
+ conv_dim = self.hidden_dim + 2 * self.num_groups * self.state_dim
135
+
136
+ # Depthwise-ish conv (groups = out_channels)
137
+ # TODO: Check that this is used if causal_conv1d_fn and causal_conv1d_update cannot be imported
138
+ self.conv1d = nn.Conv1d(
139
+ in_channels=conv_dim,
140
+ out_channels=conv_dim,
141
+ kernel_size=self.conv_size,
142
+ groups=conv_dim,
143
+ bias=self.conv_bias, # <- This is a boolean in your config, so pass that or True/False
144
+ padding=self.conv_size - 1 # for "causal" style
145
+ )
146
+
147
+ if config.dt_bias:
148
+ self.dt_bias = nn.Parameter(torch.empty(self.num_heads))
149
+ else:
150
+ self.dt_bias = nn.Parameter(torch.zeros(self.num_heads), requires_grad=False)
151
+
152
+ self.A_log = nn.Parameter(torch.empty(self.num_heads))
153
+
154
+ if config.D_has_head_dim:
155
+ self.D = nn.Parameter(torch.ones(self.num_heads, self.head_dim))
156
+ else:
157
+ self.D = nn.Parameter(torch.ones(self.num_heads))
158
+
159
+ if self.learnable_init_states:
160
+ self.init_states = nn.Parameter(torch.zeros(self.num_heads, self.head_dim, self.state_dim))
161
+
162
+ self.norm = build_norm(config.norm_type, dim=self.hidden_dim, eps=self.norm_eps)
163
+ self.output = nn.Linear(self.hidden_dim, self.dim, bias=self.bias)
164
+
165
    def _causal_conv(
        self,
        zxbcdt: torch.Tensor,
        tok_idx: torch.Tensor | None = None,
        cu_seqlens: torch.Tensor | None = None,
        ssm_impl: str = "ssm"
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # TODO: Make slightly less verbose
        """Processes input through causal convolution path, handling both full sequence and incremental cases.

        This function implements two processing modes:
        1. Full sequence ("ssm"): Used during training and initial prompt processing.
        2. Incremental ("ssm_update"): Used during token-by-token generation.

        Args:
            zxbcdt: Input tensor containing concatenated [z, x, B, C, dt] components
            tok_idx: Token indices for sequence processing. Required for "ssm" mode.
                Defaults to None.
            cu_seqlens: Cumulative sequence lengths for variable length processing.
                Used only in "ssm" mode with caching. Defaults to None.
            ssm_impl: Implementation mode, either "ssm" for full sequence processing
                or "ssm_update" for incremental generation. Defaults to "ssm".

        Returns:
            tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
                Tuple containing separated components (z, x, B, C, dt), where:
                - z: Gating branch
                - x: Main branch
                - B, C: SSM state matrices (analogous to K, Q in attention)
                - dt: Time delta values

        Notes:
            - When using "ssm" mode during inference, a cache should be pre-initialized
              externally. This design allows for flexible caching strategies without
              modifying model code.
            - The "ssm_update" mode requires a cache to exist and will use it for
              incremental state updates during generation.
            - B, C components correspond to Key, Query in the SSM/attention duality.
        """
        # Split the fused projection: z (gate), xBC (conv input), dt (time deltas).
        z, xBC, dt = torch.split(
            zxbcdt,
            [
                self.hidden_dim,
                self.hidden_dim + 2 * self.num_groups * self.state_dim,
                self.num_heads,
            ],
            dim=-1,
        )

        if ssm_impl == "ssm":
            # Full-sequence path. If a cache is attached, snapshot the per-sequence
            # conv states so later "ssm_update" steps can continue from them.
            if hasattr(self, "cache"):
                conv_varlen_states = causal_conv1d_varlen_states(
                    xBC.squeeze(0),
                    cu_seqlens,
                    state_len=self.cache.conv_cache.shape[-1],
                )
                self.cache.conv_cache.copy_(conv_varlen_states)

            # causal_conv1d_fn expects channels-first; transpose in and out.
            xBC = causal_conv1d_fn(
                x=xBC.transpose(1, 2),
                weight=self.conv1d.weight.squeeze(1),
                bias=self.conv1d.bias,
                activation="silu",
                seq_idx=tok_idx,
            ).transpose(1, 2)
        elif ssm_impl == "ssm_update":
            # Single-step path: advance the cached conv state by one token.
            xBC = causal_conv1d_update(
                x=xBC.squeeze(0),
                conv_state=self.cache.conv_cache,
                weight=self.conv1d.weight.squeeze(1),
                bias=self.conv1d.bias,
                activation="silu",
            ).unsqueeze(0)
        else:
            raise NotImplementedError(f"SSM implementation {ssm_impl} not supported")

        # Split the convolved tensor into main branch and SSM state projections.
        x, B, C = torch.split(
            xBC,
            [
                self.hidden_dim,
                self.num_groups * self.state_dim,
                self.num_groups * self.state_dim,
            ],
            dim=-1,
        )

        return z, x, B, C, dt
254
+
255
+ def _non_causal_conv(self, zxbcdt: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
256
+ z, x, B, C, dt = torch.split(
257
+ zxbcdt,
258
+ [
259
+ self.hidden_dim,
260
+ self.hidden_dim,
261
+ self.num_groups * self.state_dim,
262
+ self.num_groups * self.state_dim,
263
+ self.num_heads,
264
+ ],
265
+ dim=-1,
266
+ )
267
+ return z, x, B, C, dt
268
+
269
    def _fwd(self, x, dt, A, B, C, tok_idx, cu_seqlens, initial_states):
        """Full-sequence chunked SSM scan (training and prompt processing).

        Args:
            x: Main-branch activations, (bsz, seq_len, num_heads, head_dim).
            dt: Per-head time deltas (softplus is applied inside the kernel).
            A: Per-head decay rates; caller passes -exp(A_log), so negative.
            B: SSM input projection (Key analogue).
            C: SSM output projection (Query analogue).
            tok_idx: Per-token sequence indices for packed batches, or None.
            cu_seqlens: Cumulative sequence lengths for varlen batches, or None.
            initial_states: Optional initial SSM states (learnable path).

        Returns:
            (bsz, seq_len, num_heads, head_dim)
        """
        y = mamba_chunk_scan_combined(
            x,
            dt,
            A,
            B,
            C,
            dt_bias=self.dt_bias,
            dt_softplus=True,
            chunk_size=self.ssm_chunk_size,
            D=self.D,
            z=None,  # gating z is applied outside the kernel (post-norm path)
            seq_idx=tok_idx,
            cu_seqlens=cu_seqlens,
            initial_states=initial_states,
            **self.dt_limit_kwargs,
        )

        # With a cache attached the kernel also returns the final varlen states;
        # stash them so incremental decoding ("ssm_update") can continue.
        if hasattr(self, "cache"):
            y, varlen_states = y
            self.cache.state_cache.copy_(varlen_states)

        return y
298
+
299
    def _step(self, x, seq_len, dt, A, B, C):
        """Single-step SSM state update for inference / token-by-token generation.

        Advances `self.cache.state_cache` in place via `selective_state_update`
        and returns the output for the current step.
        """
        x = x.squeeze(0)
        # Broadcast the scalar per-head decay over (head_dim, state_dim).
        A = A[..., None, None].expand(self.num_heads, self.head_dim, self.state_dim)
        # assumes dt arrives as (bsz, seq_len, num_heads) before permute — TODO confirm
        dt = dt.permute(1, 2, 0).expand(seq_len, self.num_heads, self.head_dim)
        D = self.D
        if D is not None and D.dim() == 1:
            # Per-head D: broadcast across head_dim for the kernel.
            D = D.unsqueeze(1).expand(self.num_heads, self.head_dim)
        B, C = B.squeeze(0), C.squeeze(0)
        y = selective_state_update(
            self.cache.state_cache,  # updated in place
            x,
            dt,
            A,
            B,
            C,
            D,
            z=None,  # gating applied by the caller
            dt_bias=(
                # Kernel requires a (num_heads, head_dim) bias; zeros when absent.
                torch.zeros(self.num_heads, self.head_dim).to(x)
                if self.dt_bias is None
                else self.dt_bias.unsqueeze(1).expand(self.num_heads, self.head_dim)
            ),
            dt_softplus=True,
        ).unsqueeze(0)

        return y
328
+
329
    def forward(
        self,
        x: torch.Tensor,
        tok_idx: torch.Tensor | None = None,
        cu_seqlens: torch.Tensor | None = None,
        ssm_impl: str = "ssm",
    ) -> torch.Tensor:
        """Apply the Mamba-2 mixer to a batch of hidden states.

        Args:
            x: Input hidden states, (bsz, seq_len, dim).
            tok_idx: Optional per-token sequence indices for packed batches.
            cu_seqlens: Optional cumulative sequence lengths for varlen batches.
            ssm_impl: "ssm" for full-sequence processing, "ssm_update" for
                single-step incremental decoding.

        Returns:
            Output hidden states, (bsz, seq_len, dim).
        """
        bsz, seq_len, _ = x.shape

        # Fused input projection producing concatenated [z, x, B, C, dt].
        zxbcdt = self.input(x)

        # A = -exp(A_log) keeps the SSM decay rates strictly negative.
        A = -torch.exp(self.A_log.float())
        initial_states = (
            self.init_states.expand(bsz, -1, -1, -1)
            if self.learnable_init_states else None
        )

        # Causal conv path
        if self.conv_size is not None:

            # Memory-efficient Triton kernel path: conv + scan + norm + out-proj fused.
            if self.use_mem_eff_path:
                out = mamba_split_conv1d_scan_combined(
                    zxbcdt,
                    self.conv1d.weight.squeeze(1),
                    self.conv1d.bias,
                    self.dt_bias,
                    A,
                    D=self.D,
                    chunk_size=self.ssm_chunk_size,
                    seq_idx=tok_idx,
                    activation="silu",
                    rmsnorm_weight=self.norm.weight,
                    rmsnorm_eps=self.norm.eps,
                    outproj_weight=self.output.weight,
                    outproj_bias=self.output.bias,
                    headdim=self.head_dim,
                    ngroups=self.num_groups,
                    norm_before_gate=False,  # Post-norm, y = self.norm(y * F.silu(z))
                    initial_states=initial_states,
                    **self.dt_limit_kwargs,
                )
                return out
            else:
                # CUDA kernel path
                # NOTE(review): tok_idx / cu_seqlens are not forwarded to
                # _causal_conv here, so they default to None — confirm intended.
                z, x, B, C, dt = self._causal_conv(zxbcdt)
        else:
            # Non-causal conv path
            z, x, B, C, dt = self._non_causal_conv(zxbcdt)

        # Reshape flat projections into per-head / per-group layouts.
        x = x.view(bsz, seq_len, self.num_heads, self.head_dim)
        B = B.view(bsz, seq_len, self.num_groups, self.state_dim)
        C = C.view(bsz, seq_len, self.num_groups, self.state_dim)

        # Chunked SSM scan
        if ssm_impl == "ssm":
            # (bsz, seq_len, num_heads, head_dim)
            y = self._fwd(x, dt, A, B, C, tok_idx, cu_seqlens, initial_states)
        elif ssm_impl == "ssm_update":
            y = self._step(x, seq_len, dt, A, B, C)
        else:
            raise NotImplementedError(f"SSM implementation {ssm_impl} not supported")

        y = y.view(bsz, seq_len, self.hidden_dim)

        # Could be different activation function, including None.
        # Mamba people post_norm here also (sometimes norm(z)*y or norm(z*y))
        # y = self.norm(y) * F.silu(z)
        y = self.norm(y * F.silu(z))
        out = self.output(y)

        return out
401
+
402
    @torch.inference_mode()
    def reset_parameters(self, init_std, factor) -> None:
        """Re-initialize all SSM parameters.

        Args:
            init_std: Base std for the linear projections; when falsy, falls
                back to dim**-0.5 (input) / hidden_dim**-0.5 (output).
            factor: Depth-dependent divisor applied to the output-projection std.
        """
        # NOTE(review): relies on self.config being set by __init__ (not visible
        # in this chunk) — confirm.
        config = self.config
        init_config = config.init_config
        if init_config is None:
            init_config = DEFAULT_INIT_CONFIG

        # Linear layers: truncated-normal at +/- 3 std.
        in_init_std = init_std or (self.dim ** (-0.5))
        out_init_std = init_std or (self.hidden_dim ** (-0.5))
        out_init_std = out_init_std / factor

        nn.init.trunc_normal_(
            self.input.weight,
            mean=0.0,
            std=in_init_std,
            a=-3 * in_init_std,
            b=3 * in_init_std,
        )

        nn.init.trunc_normal_(
            self.output.weight,
            mean=0.0,
            std=out_init_std,
            a=-3 * out_init_std,
            b=3 * out_init_std,
        )

        # SSM: sample dt_bias so that softplus(dt_bias) ~ Uniform[dt_min, dt_max].
        if self.dt_bias is not None and self.dt_bias.requires_grad:
            log_dt_min = math.log(init_config.dt_min)
            log_dt_max = math.log(init_config.dt_max)

            # Sample log_dt ~ Uniform[log_dt_min, log_dt_max]
            log_dt = torch.rand(self.num_heads, device=self.dt_bias.device) * (log_dt_max - log_dt_min) + log_dt_min
            dt = torch.exp(log_dt)
            dt = torch.clamp(dt, min=init_config.dt_init_floor)

            # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
            inv_dt = dt + torch.log(-torch.expm1(-dt))
            self.dt_bias.copy_(inv_dt)

        elif self.dt_bias is not None:
            # If dt_bias is not trainable, we can just keep it zero or set to any constant
            self.dt_bias.fill_(0.0)

        # Convolution
        if self.conv_size is not None:
            conv_std = init_std or (self.conv_size ** (-0.5))
            nn.init.trunc_normal_(
                self.conv1d.weight,
                mean=0.0,
                std=conv_std,
                a=-3 * conv_std,
                b=3 * conv_std,
            )
            if self.conv1d.bias is not None:
                nn.init.zeros_(self.conv1d.bias)

        # Learnable init states start at zero.
        if self.learnable_init_states:
            self.init_states.zero_()

        # Initialize A_log ~ log( Uniform(A_init_min, A_init_max) )
        self.A_log.uniform_(init_config.A_init_min, init_config.A_init_max)
        self.A_log.log_()

        # Skip connection starts as identity.
        if self.D is not None:
            self.D.data.fill_(1.0)

        # Reset norm parameters
        self.norm.reset_parameters()
474
+
475
+
476
class MambaBlock(nn.Module):
    """Pre-norm residual block wrapping a single SSM mixer.

    Computes ``x + SSM(norm(x))``, forwarding packed-batch metadata and the
    SSM implementation selector straight through to the mixer.
    """

    def __init__(self, config: BaseMambaConfig):
        super().__init__()
        self.norm = build_norm(config.norm_type, dim=config.dim, eps=config.norm_eps)
        self.ssm = SSM(config)

    def forward(
        self,
        x: torch.Tensor,
        tok_idx: torch.Tensor | None,
        cu_seqlens: torch.Tensor | None,
        ssm_impl: str = "ssm",
    ) -> torch.Tensor:
        # Pre-norm residual connection.
        normed = self.norm(x)
        mixed = self.ssm(
            normed, tok_idx=tok_idx, cu_seqlens=cu_seqlens, ssm_impl=ssm_impl
        )
        return x + mixed

    @torch.inference_mode()
    def init_weights(self, init_std=None, factor=1.0):
        """Reset the norm and delegate parameter init to the SSM mixer."""
        self.norm.reset_parameters()
        self.ssm.reset_parameters(init_std, factor)
496
+
497
+
498
class BaseMamba(nn.Module):
    """Backbone interleaving Mamba (SSM) blocks with optional attention layers.

    Even-indexed layers are always MambaBlocks; odd-indexed layers are
    AttentionLayers when ``config.use_attn`` is set, otherwise MambaBlocks.
    """

    def __init__(self, config: BaseMambaConfig):
        super().__init__()
        # Bug fix: the original assertion message interpolated self.dim and
        # self.num_heads, which are not set at this point — a failing assert
        # would raise AttributeError instead of the intended message.
        assert config.dim % config.num_heads == 0, (
            f"dim ({config.dim}) must be divisible by num_heads ({config.num_heads})"
        )
        self.head_dim = config.dim // config.num_heads

        self.model_dim = config.dim
        self.init_base_std = config.init_base_std

        self.init_config = config.init_config
        self.init_std_factor = InitStdFactor(config.init_std_factor)

        # From pytorch/pytorch#123411, we set persistent=True for torch.compile and PP compatibility
        self.register_buffer("freqs_cis", precompute_freqs_cis(
            head_dim=self.head_dim,
            max_seq_len=config.seq_len,
            theta=config.theta,
        ), persistent=True)

        self.layers = nn.ModuleList()
        for layer_idx in range(config.num_layers):
            # For more complex %-split arrangements, see https://arxiv.org/pdf/2406.07887
            if layer_idx % 2 == 0:
                self.layers.append(MambaBlock(config))
            else:
                self.layers.append(
                    AttentionLayer(config)
                    if config.use_attn
                    else MambaBlock(config)
                )

    def _unwrap(self, layer: nn.Module) -> nn.Module:
        """Return the underlying module if wrapped in DDP/FSDP (which expose .module)."""
        while hasattr(layer, "module"):
            layer = layer.module
        return layer

    def forward(
        self,
        h: torch.Tensor,
        tok_idx: torch.Tensor | None,
        cu_seqlens: torch.Tensor | None,
        ssm_impl: str = "ssm",
    ) -> torch.Tensor:
        """Run hidden states through all layers, dispatching per layer type.

        Args:
            h: Hidden states, (bsz, seq_len, dim).
            tok_idx: Optional per-token sequence indices (Mamba layers only).
            cu_seqlens: Optional cumulative sequence lengths (Mamba layers only).
            ssm_impl: SSM implementation selector ("ssm" / "ssm_update").

        Returns:
            Transformed hidden states of the same shape.

        Raises:
            ValueError: If a layer is neither a MambaBlock nor an AttentionLayer.
        """
        for layer in self.layers:
            unwrapped_layer = self._unwrap(layer)
            if isinstance(unwrapped_layer, MambaBlock):
                h = unwrapped_layer(h, tok_idx=tok_idx, cu_seqlens=cu_seqlens, ssm_impl=ssm_impl)
            elif isinstance(unwrapped_layer, AttentionLayer):
                h = unwrapped_layer(h, self.freqs_cis)
            else:
                raise ValueError(f"ERROR: Unexpected layer type: {type(unwrapped_layer).__name__}")
        return h

    @torch.inference_mode()
    def reset_parameters(self):
        """No backbone-owned parameters to reset; subclasses extend this."""
        pass

    @torch.inference_mode()
    def init_weights(self):
        """Initialize all Mamba layers with a depth-dependent std factor."""
        self.reset_parameters()
        for depth, layer in enumerate(self.layers):
            factor = {
                InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
                InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5,
                InitStdFactor.DIM_RATIO: self.model_dim / 4096,
                InitStdFactor.DISABLED: 1.0,
            }[self.init_std_factor]

            if not hasattr(layer, "attn"):  # Only initialize Mamba layers
                layer.init_weights(self.init_base_std, factor)
569
+
570
+
571
@dataclass
class Mamba2Config(BaseMambaConfig):
    """Top-level Mamba-2 model configuration (extends BaseMambaConfig)."""

    seed: int = 1337  # RNG seed for reproducibility

    vocab_size: int = -1  # Will error if unchanged, makes you double check!
    seq_len: int = 8192  # max sequence length (sizes the freqs_cis buffer)
    window_size: int = 1024  # presumably the attention sliding-window size — TODO confirm
    weight_tying: bool = False  # tie output projection to token embedding
    torch_dtype: torch.dtype = torch.bfloat16  # parameter/compute dtype

    loss_reduction: str = "mean"  # reduction used by the training loss

    use_attn: bool = True  # interleave attention layers at odd depths
    use_alibi: bool = True  # presumably enables ALiBi slopes in attention — TODO confirm
    dropout: float = 0.0
    softcap: float = 50.0  # logit soft-capping value (attention path)
    theta: float = 10000.0  # RoPE base frequency for precompute_freqs_cis

    device: torch.device = None
    # NOTE(review): `dtype` appears to duplicate `torch_dtype` — confirm which
    # one consumers treat as authoritative.
    dtype: torch.dtype = torch.bfloat16
591
+
592
+
593
class Mamba2(BaseMamba):
    """Mamba-2 language model: token embedding + BaseMamba backbone + LM head."""

    def __init__(self, config: Mamba2Config) -> None:
        super().__init__(config)
        self.weight_tying = config.weight_tying
        self.loss_reduction = config.loss_reduction

        assert config.vocab_size > 0, "vocab_size must be set and > 0"

        self.tok_emb = torch.nn.Embedding(config.vocab_size, config.dim)

        # Final pre-logits normalization.
        self.norm = nn.RMSNorm(config.dim, eps=config.norm_eps)

        self.output = nn.Linear(
            config.dim,
            config.vocab_size,
            bias=False,
        )

        # Share embedding and unembedding weights when requested.
        if config.weight_tying:
            self.output.weight = self.tok_emb.weight

        print("Model Parameter Count: %.2fM\n" % (self._get_num_params() / 1e6,))

    def _get_num_params(self):
        """Count all parameters, excluding positional embeddings if present."""
        n_params = sum(p.numel() for p in self.parameters())

        # pos_emb is not defined in this class; guard in case a subclass adds it.
        if hasattr(self, "pos_emb") and self.pos_emb is not None:
            n_params -= self.pos_emb.weight.numel()

        return n_params

    def forward(
        self,
        x: torch.Tensor,
        target: torch.Tensor | None = None,
        tok_idx: torch.Tensor | None = None,
        cu_seqlens: torch.Tensor | None = None,
        ssm_impl: str = "ssm",
    ) -> torch.Tensor:
        """Return logits of shape (bsz, seq_len, vocab_size) for token ids `x`.

        NOTE(review): `target` is accepted but unused here — loss is computed
        by the caller; confirm the parameter is kept for interface parity.
        """
        h = self.tok_emb(x)
        h = super().forward(h, tok_idx=tok_idx, cu_seqlens=cu_seqlens, ssm_impl=ssm_impl)
        logits = self.output(self.norm(h))
        return logits

    @torch.inference_mode()
    def reset_parameters(self, init_std=None):
        """Initialize embedding/output weights with truncated normal at +/- 3 std."""
        # Either use fixed base std or sqrt model dim
        super().reset_parameters()
        init_std = init_std or (self.model_dim ** (-0.5))
        self.norm.reset_parameters()
        nn.init.trunc_normal_(
            self.tok_emb.weight,
            mean=0.0,
            std=init_std,
            a=-3 * init_std,
            b=3 * init_std,
        )
        # Skip the output head when tied: it shares storage with tok_emb.
        if not self.weight_tying:
            nn.init.trunc_normal_(
                self.output.weight,
                mean=0.0,
                std=init_std,
                a=-3 * init_std,
                b=3 * init_std,
            )

    @torch.inference_mode()
    def init_weights(self, buffer_device: torch.device = None):
        """
        Initialize model parameters and optionally compute buffers on a specific device.

        Args:
            buffer_device (torch.device, optional): If provided, any large or precomputed
                buffers (like RoPE frequency tensors) will be allocated or re-created on
                this device during initialization. This can avoid overhead from transferring
                buffers between CPU and GPU after creation. If None, buffers default to the
                device of the first parameter or CPU.

        Usage:
            - Pass a GPU device (e.g., ``torch.device('cuda')``) when you want to ensure
              buffers are created directly on GPU, preventing extra transfers.
            - Pass a CPU device (e.g., ``torch.device('cpu')``) if you want to keep
              large buffers in CPU memory (common in CPU-offload or pipeline-parallel setups).
            - Leave it as ``None`` to rely on the model's existing parameter device or
              the default PyTorch device context.

        NOTE(review): `buffer_device` is currently unused — this method only
        delegates to BaseMamba.init_weights(); confirm intended.
        """
        super().init_weights()

    @classmethod
    def from_model_args(cls, config: Mamba2Config) -> "Mamba2":
        """
        Initialize a Mamba model from a MambaConfig object.

        Args:
            config (MambaConfig): Mamba configuration arguments.

        Returns:
            Mamba: Mamba-2 model.
        """
        return cls(config)
700
+
701
+
702
if __name__ == '__main__':
    # Smoke test: build a Mamba2 from config.json and run one forward pass.
    import json

    config_path = "config.json"

    with open(config_path, "r") as f:
        config_data = json.load(f)

    # Pick the best available device: CUDA > MPS > CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print("Device:", device)

    # Resolve the dtype name from the JSON (e.g. "bfloat16") to a torch dtype.
    torch_dtype = getattr(torch, config_data["torch_dtype"])
    print("Torch dtype:", torch_dtype)

    # Required keys use [] (KeyError if missing); optional keys use .get() (None).
    dim = config_data["dim"]
    num_heads = config_data["num_heads"]
    num_layers = config_data["num_layers"]
    vocab_size = config_data["vocab_size"]
    bias = config_data["bias"]
    state_dim = config_data["state_dim"]
    num_groups = config_data["num_groups"]
    conv_size = config_data.get("conv_size")
    use_mem_eff_path = config_data.get("use_mem_eff_path")
    dt_bias = config_data["dt_bias"]
    D_has_head_dim = config_data["D_has_head_dim"]
    learnable_init_states = config_data["learnable_init_states"]
    ssm_chunk_size = config_data["ssm_chunk_size"]
    weight_tying = config_data["weight_tying"]
    mlp_scale = config_data.get("mlp_scale")
    multiple_of = config_data["multiple_of"]
    norm_eps = config_data["norm_eps"]
    init_use_depth = config_data["init_use_depth"]
    init_base_std = config_data.get("init_base_std")
    init_std_factor = config_data["init_std_factor"]
    use_attn = config_data["use_attn"]
    softcap = config_data["softcap"]
    torch_compile = config_data["torch_compile"]

    configs = Mamba2Config(
        dim=dim,
        num_layers=num_layers,
        num_heads=num_heads,
        vocab_size=vocab_size,
        bias=bias,
        torch_dtype=torch_dtype,
        state_dim=state_dim,
        num_groups=num_groups,
        conv_size=conv_size,
        use_mem_eff_path=use_mem_eff_path,
        dt_bias=dt_bias,
        D_has_head_dim=D_has_head_dim,
        learnable_init_states=learnable_init_states,
        ssm_chunk_size=ssm_chunk_size,
        weight_tying=weight_tying,
        mlp_scale=mlp_scale,
        multiple_of=multiple_of,
        norm_eps=norm_eps,
        init_use_depth=init_use_depth,
        init_base_std=init_base_std,
        init_std_factor=init_std_factor,
        use_attn=use_attn,
        softcap=softcap,
    )

    print("Configs:")
    for key, value in vars(configs).items():
        print(f"  {key}: {value}")

    model = Mamba2(configs).to(device=device, dtype=torch_dtype)

    # Random token ids in [0, vocab_size) for a single forward pass.
    x = torch.randint(
        0, configs.vocab_size,
        (config_data["bsz"], config_data["seq_len"]),
        dtype=torch.long
    ).to(device)

    outputs = model(x)

    print("Output shape:", outputs.shape)
    print("Sample output:", outputs[0, 0, :10])
    print("Mean of Mamba output: ", outputs.mean().item())
    print("Stddev of Mamba output: ", outputs.std().item())
modeling_minimamba.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from transformers import PreTrainedModel
7
+ from transformers.modeling_outputs import CausalLMOutput
8
+
9
+ from .configuration_minimamba import MiniMambaConfig
10
+ from .model import Mamba2, Mamba2Config
11
+
12
+
13
+
14
class MiniMamba(PreTrainedModel):
    """
    A Hugging Face–style wrapper around a Mamba2 model, providing:
      • forward(...) returning a CausalLMOutput
      • support for HF training loops
      • a naive generate(...) method with top-k/top-p sampling
    """
    config_class = MiniMambaConfig  # Tells HF which config class to use

    def __init__(self, config: MiniMambaConfig) -> None:
        """Bridge a MiniMambaConfig into a Mamba2Config and build the inner model."""
        super().__init__(config)

        # Translate HF config fields into the Mamba2 training config.
        mamba2_args = Mamba2Config(
            dim=config.dim,
            num_layers=config.num_layers,
            num_heads=config.num_heads,
            state_dim=config.state_dim,
            num_groups=config.num_groups,
            conv_size=config.conv_size,
            use_mem_eff_path=config.use_mem_eff_path,
            dt_bias=config.dt_bias,
            D_has_head_dim=config.D_has_head_dim,
            learnable_init_states=config.learnable_init_states,
            ssm_chunk_size=config.ssm_chunk_size,
            vocab_size=config.vocab_size,
            ffn_dim_multiplier=config.ffn_dim_multiplier,
            multiple_of=config.multiple_of,
            norm_eps=config.norm_eps,
            init_use_depth=config.init_use_depth,
            init_base_std=config.init_base_std,
            init_std_factor=config.init_std_factor,
            bias=config.bias,
            softcap=config.softcap,
            use_attn=config.use_attn,
            seed=config.seed,
            mlp_scale=config.mlp_scale,
            # getattr keeps backward compatibility with configs lacking the field.
            weight_tying=getattr(config, "weight_tying", False),
            torch_dtype=getattr(torch, config.torch_dtype) if isinstance(config.torch_dtype, str) else config.torch_dtype,
        )

        # Mamba2 owns the embedding and LM head (and ties them internally when
        # weight_tying is True), so no separate embedding/lm_head is needed here.
        self.mamba = Mamba2(config=mamba2_args)

        # Convenience records of the preferred runtime device/dtype.
        self.device_ = 'cuda' if torch.cuda.is_available() else 'cpu'
        if isinstance(config.torch_dtype, str):
            self.dtype_ = getattr(torch, config.torch_dtype)
        else:
            self.dtype_ = config.torch_dtype

        # Parameter initialization (HF calls _init_weights in some flows).
        self.apply(self._init_weights)

        print("MiniMamba Model Parameter Count: %.2fM\n" % (self._get_num_params() / 1e6,))

    def forward(
        self,
        input_ids: torch.LongTensor,
        labels: torch.LongTensor = None,
        **kwargs
    ) -> CausalLMOutput:
        """
        Forward pass for causal language modeling.
        Returns a CausalLMOutput that includes loss (if labels is provided) and logits.
        """
        logits = self.mamba(input_ids)  # shape: [batch, seq_len, vocab_size]

        loss = None
        if labels is not None:
            # Standard causal-LM shift: predict token t+1 from positions <= t.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)
            )

        return CausalLMOutput(
            loss=loss,
            logits=logits,
        )

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.LongTensor,
        max_new_tokens: int = 50,
        temperature: float = 0.5,
        top_k: int = 50,
        top_p: float = 0.95,
        eos_token_id: int = None,
        pad_token_id: int = 0,  # NOTE: accepted for API parity; currently unused
        **kwargs
    ):
        """
        A naive token-by-token generation loop (top-k/top-p + temperature).
        Recomputes the full sequence each step (no KV/state cache).
        """
        generated_ids = input_ids.clone()

        for _ in range(max_new_tokens):
            # Forward pass to get logits for the last position.
            outputs = self.forward(generated_ids)
            # Bug fix: clone before filtering. outputs.logits[:, -1, :] is a
            # view, and top_k_top_p_filtering mutates its argument in place —
            # without the clone (when temperature == 1.0) we would silently
            # corrupt outputs.logits.
            logits = outputs.logits[:, -1, :].clone()  # (batch_size, vocab_size)

            # Scale by temperature
            if temperature != 1.0:
                logits = logits / temperature

            # Restrict the distribution to top-k / nucleus candidates.
            logits = self.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)

            # Sample next token
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)  # (batch, 1)

            generated_ids = torch.cat([generated_ids, next_token], dim=1)

            # Stop early only when every sequence in the batch emitted EOS.
            if eos_token_id is not None and (next_token == eos_token_id).all():
                break

        return generated_ids

    @staticmethod
    def top_k_top_p_filtering(
        logits: torch.Tensor,
        top_k: int = 50,
        top_p: float = 0.95,
        filter_value: float = float("-inf"),
    ):
        """
        Filters logits using top-k and/or nucleus (top-p) filtering.

        NOTE: mutates ``logits`` in place and also returns it — pass a copy if
        the caller needs the unfiltered values.
        """
        # top_k: drop everything below the k-th largest logit per row.
        if top_k > 0:
            top_k = min(top_k, logits.size(-1))
            indices_to_remove = logits < torch.topk(logits, top_k, dim=-1).values[:, -1, None]
            logits[indices_to_remove] = filter_value

        # top_p (nucleus): keep the smallest prefix of sorted tokens whose
        # cumulative probability exceeds top_p.
        if 0 < top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

            sorted_indices_to_remove = cumulative_probs > top_p

            # Shift right to keep also the first token above threshold
            sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
            sorted_indices_to_remove[:, 0] = False

            # Scatter to get back to original indexing
            indices_to_remove = sorted_indices_to_remove.scatter(
                dim=1, index=sorted_indices, src=sorted_indices_to_remove
            )
            logits[indices_to_remove] = filter_value

        return logits

    def _init_weights(self, module):
        """
        HF calls _init_weights to initialize parameters.
        Mamba2 submodules use their own init; plain Linear/Embedding modules
        get a standard normal(0, 0.02) init.
        """
        if isinstance(module, Mamba2):
            module.init_weights()  # Mamba2's internal init
        elif isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def _get_num_params(self):
        """Count trainable parameters (tied weights are counted once by PyTorch)."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
norms.py ADDED
@@ -0,0 +1,358 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Adapted from https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/norms.py"""
3
+
4
+ import math
5
+
6
+ from functools import partial
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ import triton
12
+ import triton.language as tl
13
+
14
+ from torch.distributed._tensor import Partial, Replicate, Shard
15
+ from torch.distributed._tensor.experimental import local_map
16
+ from torch._utils import _get_available_device_type, _get_device_module
17
+
18
+
19
+ def get_device_info():
20
+ device_type = _get_available_device_type()
21
+
22
+ if device_type is None:
23
+ device_type = "cuda" # Default to CUDA
24
+
25
+ device_module = _get_device_module(device_type)
26
+ return device_type, device_module
27
+
28
+ device_type, device_module = get_device_info()
29
+
30
+ def build_norm(norm_type: str, dim: int, eps: float = 1e-6):
31
+ """
32
+ Builds the specified normalization layer based on the norm_type.
33
+
34
+ Args:
35
+ norm_type (str): The type of normalization layer to build.
36
+ Supported types: layernorm, np_layernorm, rmsnorm, fused_rmsnorm
37
+ dim (int): The dimension of the normalization layer.
38
+ eps (float, optional): The epsilon value for numerical stability. Defaults to 1e-6.
39
+
40
+ Returns:
41
+ The built normalization layer.
42
+
43
+ Raises:
44
+ NotImplementedError: If an unknown norm_type is provided.
45
+ """
46
+ norm_type = norm_type.lower() # Normalize to lowercase
47
+
48
+ if norm_type == "layernorm":
49
+ return nn.LayerNorm(dim, eps=eps, bias=False)
50
+ elif norm_type == "np_layernorm":
51
+ return nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False)
52
+ elif norm_type == "rmsnorm":
53
+ return RMSNorm(dim, eps=eps)
54
+ elif norm_type == "fused_rmsnorm":
55
+ return FusedRMSNorm(dim, eps=eps)
56
+ else:
57
+ raise NotImplementedError(f"Unknown norm_type: '{norm_type}'")
58
+
59
+
60
class FusedRMSNorm(nn.Module):
    """Fused RMS Norm, wraps a fused Triton Kernel"""

    def __init__(
        self,
        dim: int,
        eps: float = 1e-6,
    ):
        """
        Args:
            dim: Size of the last (feature) dimension being normalized.
            eps: Stability term added inside the square root. Defaults to 1e-6.
        """
        super().__init__()
        self.eps = eps
        # Learnable per-feature scale, initialized to identity.
        self.weight = nn.Parameter(torch.ones(dim))
        # fused_rms_norm_fn is a module-level name defined with the Triton
        # kernel elsewhere in this file.
        self.fused_rms_norm_fn = fused_rms_norm_fn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """leverages Triton Fused RMS Norm kernel"""
        return self.fused_rms_norm_fn(
            x,
            self.weight,
            eps=self.eps,
        )

    def reset_parameters(self):
        """Re-initialize the scale to ones."""
        torch.nn.init.ones_(self.weight)  # type: ignore
+
84
+
85
+ class RMSNorm(torch.nn.Module):
86
+ def __init__(self, dim: int, eps: float = 1e-6):
87
+ """
88
+ Initialize the RMSNorm normalization layer.
89
+
90
+ Args:
91
+ dim (int): The dimension of the input tensor.
92
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
93
+
94
+ Attributes:
95
+ eps (float): A small value added to the denominator for numerical stability.
96
+ weight (nn.Parameter): Learnable scaling parameter.
97
+
98
+ """
99
+ super().__init__()
100
+ self.eps = eps
101
+ self.weight = nn.Parameter(torch.ones(dim))
102
+
103
+ def _norm(self, x):
104
+ """
105
+ Apply the RMSNorm normalization to the input tensor.
106
+
107
+ Args:
108
+ x (torch.Tensor): The input tensor.
109
+
110
+ Returns:
111
+ torch.Tensor: The normalized tensor.
112
+
113
+ """
114
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
115
+
116
+ def forward(self, x):
117
+ """
118
+ Forward pass through the RMSNorm layer.
119
+
120
+ Args:
121
+ x (torch.Tensor): The input tensor.
122
+
123
+ Returns:
124
+ torch.Tensor: The output tensor after applying RMSNorm.
125
+
126
+ """
127
+ output = self._norm(x.float()).type_as(x)
128
+ return output * self.weight
129
+
130
+ def reset_parameters(self):
131
+ torch.nn.init.ones_(self.weight) # type: ignore
132
+
133
+
134
+ # FusedRMSNorm in Triton
135
+
136
+ # Credit
137
+ # Tri Dao's Triton LayerNorm: https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py
138
+ # Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
139
+
140
+
141
# FusedRMSNorm forward kernel: one Triton program per row of the
# flattened (M, N) input. Autotune varies only num_warps, keyed on N.
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["N"],
)
@triton.jit
def _rms_norm_fwd_kernel(
    X,  # input pointer, logically (M, N)
    stride_x,  # row stride of X
    Y,  # output pointer, logically (M, N)
    stride_y,  # row stride of Y
    W,  # weight pointer, (N,)
    Rstd,  # output: per-row reciprocal std, (M,) float32 (reused by backward)
    eps,  # numerical-stability epsilon added to the variance
    M,  # num rows  (NOTE(review): unused inside the kernel body)
    N,  # num cols
    block_N: tl.constexpr,  # power-of-two block covering a full row (N <= block_N)
):
    row = tl.program_id(0)
    cols = tl.arange(0, block_N)

    # Load input data and weights; out-of-range lanes read 0.0
    mask = cols < N
    x = tl.load(X + row * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
    w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32)

    # Compute mean and variance (RMS norm uses mean of squares only;
    # masked lanes contribute 0 so the /N average is exact)
    xbar = tl.where(cols < N, x, 0.0)
    var = tl.sum(xbar * xbar, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)

    # Store the reciprocal standard deviation for the backward pass
    tl.store(Rstd + row, rstd)

    # Normalize and apply linear transformation (per-channel scale)
    x_hat = x * rstd
    y = x_hat * w

    # Write output
    tl.store(Y + row * stride_y + cols, y, mask=mask)
187
+
188
+
189
# FusedRMSNorm backward kernel: each program handles a contiguous strip
# of `rows_per_program` rows and accumulates a partial weight gradient
# into its own row of DW; the host sums DW over programs afterwards.
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["N"],
)
@triton.jit
def _rms_norm_bwd_kernel_sm(
    X,  # forward input pointer, (M, N)
    stride_x,
    W,  # weight pointer, (N,)
    DY,  # upstream gradient pointer, (M, N)
    stride_dy,
    DX,  # output: input gradient, (M, N)
    stride_dx,
    Rstd,  # per-row reciprocal std saved by the forward pass
    DW,  # output: per-program partial weight grads, (num_programs, N)
    eps,  # NOTE(review): unused here — eps is already folded into Rstd
    M,  # num rows
    N,  # num cols
    rows_per_program,  # rows assigned to each program
    block_N: tl.constexpr,
):
    row_block_id = tl.program_id(0)
    row_start = row_block_id * rows_per_program
    cols = tl.arange(0, block_N)
    mask = cols < N

    # Load weights once; reused for every row in this program's strip
    w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32)

    # Accumulate gradients for weights across the strip
    dw = tl.zeros((block_N,), dtype=tl.float32)

    row_end = min(row_start + rows_per_program, M)
    for row in range(row_start, row_end):
        # Load input, output gradient, and reciprocal standard deviation
        x = tl.load(X + row * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
        dy = tl.load(DY + row * stride_dy + cols, mask=mask, other=0.0).to(tl.float32)
        rstd = tl.load(Rstd + row)

        # Compute normalized input and gradients:
        #   dx = (w*dy - x_hat * mean(x_hat * w*dy)) * rstd
        x_hat = x * rstd
        wdy = w * dy
        dw += dy * x_hat
        c1 = tl.sum(x_hat * wdy, axis=0) / N
        dx = (wdy - x_hat * c1) * rstd

        # Store input gradient for this row
        tl.store(DX + row * stride_dx + cols, dx, mask=mask)

    # Store this program's partial weight gradient (host sums over axis 0)
    tl.store(DW + row_block_id * N + cols, dw, mask=mask)
247
+
248
+
249
class TritonFusedRMSNorm(torch.autograd.Function):
    """Autograd wrapper around the Triton RMS-norm forward/backward kernels.

    NOTE(review): `partial`, `local_map`, `Shard`, `Replicate`, `Partial`,
    and `device_module` are not defined in the visible portion of this
    file — presumably torch.distributed tensor-parallel helpers imported
    near the top; confirm against the full file.
    """

    @partial(
        local_map,
        out_placements=[Shard(1)],
        in_placements=(None, [Shard(1)], [Replicate()], None),
    )
    @staticmethod
    def forward(ctx, x, weight, eps):
        # Remember the caller's shape so the output can be restored.
        x_shape_start = x.shape

        # Flatten input to 2-D (rows, features); kernels assume unit
        # stride along the feature dimension.
        x = x.view(-1, x.shape[-1])
        if x.stride(-1) != 1:
            x = x.contiguous()
        if weight.stride(-1) != 1:
            weight = weight.contiguous()

        M, N = x.shape
        y = torch.empty_like(x)
        # Per-row reciprocal std, saved for the backward pass.
        rstd = torch.empty((M,), dtype=torch.float32, device=x.device)

        # Cap the block so one row's load stays within 64 KiB.
        max_size = 65536 // x.element_size()
        block_N = min(max_size, triton.next_power_of_2(N))

        # The kernels process a whole row per program, so N must fit.
        if N > block_N:
            raise ValueError(f"N {N} must be <= {block_N=}")

        grid = lambda meta: (M,)  # one program per row
        _rms_norm_fwd_kernel[grid](
            x,
            x.stride(0),
            y,
            y.stride(0),
            weight,
            rstd,
            eps,
            M,
            N,
            block_N,
        )

        ctx.eps = eps
        # NOTE: saves the *flattened* x; backward flattens dy to match.
        ctx.save_for_backward(x, weight, rstd)
        ctx.x_shape_start = x_shape_start

        y = y.reshape(x_shape_start)
        return y

    @partial(
        local_map,
        out_placements=([Shard(1)], [Partial()], None),
        in_placements=(None, [Shard(1)]),
    )
    @staticmethod
    def backward(ctx, dy):
        x, weight, rstd = ctx.saved_tensors
        eps = ctx.eps
        x_shape_start = ctx.x_shape_start

        # Flatten input and output gradients to 2-D, matching forward.
        dy = dy.view(-1, dy.shape[-1])
        if dy.stride(-1) != 1:
            dy = dy.contiguous()

        M, N = dy.shape
        dx = torch.empty_like(x)

        # One program per SM; each accumulates a partial weight gradient
        # into its own row of _dw, reduced below.
        sm_count = device_module.get_device_properties(x.device).multi_processor_count
        _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)

        max_size = 65536 // x.element_size()
        block_N = min(max_size, triton.next_power_of_2(N))
        rows_per_sm = math.ceil(M / sm_count)

        if N > block_N:
            raise ValueError(f"N {N} must be <= {block_N=}")

        grid = lambda meta: (sm_count,)
        _rms_norm_bwd_kernel_sm[grid](
            x,
            x.stride(0),
            weight,
            dy,
            dy.stride(0),
            dx,
            dx.stride(0),
            rstd,
            _dw,
            eps,
            M,
            N,
            rows_per_sm,
            block_N,
        )
        # Reduce the per-SM partial weight gradients and restore shape.
        dw = _dw.sum(0).to(weight.dtype)
        dx = dx.view(x_shape_start)
        # Gradients for (x, weight, eps); eps is non-differentiable.
        return dx, dw, None
346
+
347
+
348
+ # expose fusedRMSNorm as a function
349
def fused_rms_norm_fn(
    x,
    weight,
    eps=1e-6,
):
    """Functional entry point for the fused Triton RMS-norm autograd op."""
    return TritonFusedRMSNorm.apply(x, weight, eps)
special_tokens_map.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<|endoftext|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "unk_token": {
17
+ "content": "<|endoftext|>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ }
23
+ }
ssm_compilable.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Tuple
2
+ import torch
3
+
4
+ from mamba_ssm.ops.triton.ssd_combined import _mamba_chunk_scan_combined_fwd, _mamba_chunk_scan_combined_bwd
5
+
6
+
7
# NOTE(review): this compiled wrapper is not referenced by the custom ops
# below — presumably kept for standalone/benchmark use; confirm before removing.
@torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
def _compiled_mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, cu_seqlens=None, dt_softplus=False, dt_limit=None):
    # Fullgraph-compiled passthrough to the Triton chunk-scan forward.
    return _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, cu_seqlens=cu_seqlens, dt_softplus=dt_softplus, dt_limit=dt_limit)
10
+
11
# NOTE(review): this compiled wrapper is not referenced by the custom ops
# below — presumably kept for standalone/benchmark use; confirm before removing.
@torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
def _compiled_mamba_chunk_scan_combined_bwd(dout, x, dt, A, B, C, out, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, dfinal_states=None, seq_idx=None, dt_softplus=False, dt_limit=None):
    # Fullgraph-compiled passthrough to the Triton chunk-scan backward.
    return _mamba_chunk_scan_combined_bwd(dout, x, dt, A, B, C, out, chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, dfinal_states=dfinal_states, seq_idx=seq_idx, dt_softplus=dt_softplus, dt_limit=dt_limit)
14
+
15
+
16
@torch.library.custom_op(
    "mamba_ssm::ssm_chunk_scan_combined_fwd",
    mutates_args=(),
    device_types="cuda",
)
def ssm_chunk_scan_combined_fwd(
    x: torch.Tensor,
    dt: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    chunk_size: int,
    D: Optional[torch.Tensor] = None,
    z: Optional[torch.Tensor] = None,
    dt_bias: Optional[torch.Tensor] = None,
    initial_states: Optional[torch.Tensor] = None,
    seq_idx: Optional[torch.Tensor] = None,
    cu_seqlens: Optional[torch.Tensor] = None,
    dt_softplus: bool = False,
    dt_limit: Optional[List[float]] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Forward chunk scan wrapped as a torch.library custom op.

    Delegates to the Triton `_mamba_chunk_scan_combined_fwd`. Because the
    op schema declares three Tensor returns, absent outputs are encoded
    as 0-element tensors instead of None:
      - out:   scan output (same shape as x)
      - out_x: pre-gating output when `z` is given, else an empty tensor
      - varlen final states when `cu_seqlens` is given, else an empty tensor

    Intermediate results (dt_out, dA_cumsum, chunk states, final_states)
    are discarded here; the backward op recomputes what it needs.
    """
    out, out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, cu_seqlens=cu_seqlens, dt_softplus=dt_softplus, dt_limit=dt_limit)

    # rest[0] only exists (varlen states) when cu_seqlens was passed.
    return out, out_x if out_x is not None else out.new_empty(0), rest[0] if cu_seqlens is not None else out.new_empty(0)
40
+
41
@ssm_chunk_scan_combined_fwd.register_fake
def _ssm_chunk_scan_combined_fwd_fake(
    x: torch.Tensor,
    dt: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    chunk_size: int,
    D: Optional[torch.Tensor] = None,
    z: Optional[torch.Tensor] = None,
    dt_bias: Optional[torch.Tensor] = None,
    initial_states: Optional[torch.Tensor] = None,
    seq_idx: Optional[torch.Tensor] = None,
    cu_seqlens: Optional[torch.Tensor] = None,
    dt_softplus: bool = False,
    dt_limit: Optional[List[float]] = None
):
    """Fake (meta) implementation for tracing/compile.

    Must mirror the real op's schema exactly: three Tensors. The real op
    substitutes 0-element tensors for absent outputs, so returning None
    here would break FakeTensor propagation — use empty placeholders.
    """
    _, _, n_heads, head_dim = x.shape
    out = torch.empty_like(x)
    # out_x is only produced when gating input z is given.
    out_x = torch.empty_like(x) if z is not None else x.new_empty(0)
    if cu_seqlens is not None:
        # Varlen final states: (num_sequences, nheads, headdim, dstate).
        # dstate is the LAST dim of B (B is (batch, seqlen, ngroups, dstate));
        # the previous B.size(0) was the batch size, a wrong shape.
        varlen_states = x.new_empty(
            (cu_seqlens.size(0) - 1, n_heads, head_dim, B.size(-1))
        )
    else:
        varlen_states = x.new_empty(0)
    return out, out_x, varlen_states
64
+
65
@torch.library.custom_op(
    "mamba_ssm::ssm_chunk_scan_combined_bwd",
    mutates_args=(),
    device_types="cuda",
)
def ssm_chunk_scan_combined_bwd(
    dout: torch.Tensor,
    x: torch.Tensor,
    dt: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    out: torch.Tensor,
    chunk_size: int,
    D: Optional[torch.Tensor] = None,
    z: Optional[torch.Tensor] = None,
    dt_bias: Optional[torch.Tensor] = None,
    initial_states: Optional[torch.Tensor] = None,
    seq_idx: Optional[torch.Tensor] = None,
    dt_softplus: bool = False,
    dt_limit: Optional[List[float]] = None
)-> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Backward chunk scan wrapped as a torch.library custom op.

    Delegates to the Triton `_mamba_chunk_scan_combined_bwd`. The schema
    declares nine Tensor returns, so gradients that the kernel reports as
    None (absent optional inputs) are encoded as 0-element tensors.
    dfinal_states is fixed to None — gradients w.r.t. final states are
    not propagated through this op.
    """
    dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states = _mamba_chunk_scan_combined_bwd(dout, x, dt, A, B, C, out, chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, dfinal_states=None, seq_idx=seq_idx, dt_softplus=dt_softplus, dt_limit=dt_limit)
    return (
        dx,
        ddt,
        dA,
        dB,
        dC,
        dD if dD is not None else dx.new_empty(0),
        dz if dz is not None else dx.new_empty(0),
        ddt_bias if ddt_bias is not None else dx.new_empty(0),
        dinitial_states if dinitial_states is not None else dx.new_empty(0)
    )
99
+
100
@ssm_chunk_scan_combined_bwd.register_fake
def _ssm_chunk_scan_combined_bwd_fake(
    dout: torch.Tensor,
    x: torch.Tensor,
    dt: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    out: torch.Tensor,
    chunk_size: int,
    D: Optional[torch.Tensor] = None,
    z: Optional[torch.Tensor] = None,
    dt_bias: Optional[torch.Tensor] = None,
    initial_states: Optional[torch.Tensor] = None,
    seq_idx: Optional[torch.Tensor] = None,
    dt_softplus: bool = False,
    dt_limit: Optional[List[float]] = None
):
    """Fake (meta) implementation for the backward custom op.

    Must match the real op's declared schema of nine Tensors: the real
    op returns 0-element placeholders (``dx.new_empty(0)``) for absent
    optional gradients, never None, so the fake does the same — the
    previous Nones here diverged from the schema and would break
    FakeTensor propagation under torch.compile.
    """
    dx = torch.empty_like(x)
    return (
        dx,
        torch.empty_like(dt),
        torch.empty_like(A),
        torch.empty_like(B),
        torch.empty_like(C),
        torch.empty_like(D) if D is not None else dx.new_empty(0),
        torch.empty_like(z) if z is not None else dx.new_empty(0),
        torch.empty_like(dt_bias) if dt_bias is not None else dx.new_empty(0),
        torch.empty_like(initial_states) if initial_states is not None else dx.new_empty(0),
    )
129
+
130
+
131
def ssm_chunk_scan_combined_setup_context(ctx, inputs, output):
    """Stash forward tensors and hyper-parameters for the backward bridge.

    ``inputs`` mirrors the forward op's 14 positional arguments and
    ``output`` its three results; everything the backward kernel needs is
    saved on ``ctx``.
    """
    (x, dt, A, B, C, chunk_size, D, z, dt_bias,
     initial_states, seq_idx, cu_seqlens, dt_softplus, dt_limit) = inputs
    out, out_x, state_varlen = output

    # When gating input z is present, backward needs the pre-gating
    # output (out_x) rather than the final gated output.
    saved_out = out_x if z is not None else out
    ctx.save_for_backward(saved_out, x, dt, A, B, C, D, z, dt_bias,
                          initial_states, seq_idx)
    ctx.chunk_size = chunk_size
    ctx.dt_softplus = dt_softplus
    ctx.dt_limit = dt_limit
139
+
140
def ssm_chunk_scan_combined_bridge(ctx, dout, dout_x, dout_state_varlen):
    """Autograd backward bridge for the forward custom op.

    Receives one upstream gradient per forward output; only ``dout`` is
    consumed — gradients w.r.t. out_x and the varlen states are ignored
    by the underlying backward. Returns one gradient (or None) for each
    of the 14 forward inputs, in order.
    """
    out, x, dt, A, B, C, D, z, dt_bias, initial_states, seq_idx = ctx.saved_tensors

    dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states = ssm_chunk_scan_combined_bwd(dout, x, dt, A, B, C, out, ctx.chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, dt_softplus=ctx.dt_softplus, dt_limit=ctx.dt_limit)

    # Positions mirror the forward signature:
    # (x, dt, A, B, C, chunk_size, D, z, dt_bias, initial_states,
    #  seq_idx, cu_seqlens, dt_softplus, dt_limit)
    return (
        dx,
        ddt,
        dA,
        dB,
        dC,
        None,  # chunk_size: int, non-differentiable
        dD if D is not None else None,
        dz if z is not None else None,
        ddt_bias if dt_bias is not None else None,
        dinitial_states if initial_states is not None else None,
        None,  # seq_idx
        None,  # cu_seqlens
        None,  # dt_softplus
        None,  # dt_limit
    )
161
+
162
# Register the custom autograd bridge so the forward custom op is
# differentiable (including under torch.compile): setup_context saves
# tensors during forward, the bridge maps upstream grads to input grads.
torch.library.register_autograd(
    "mamba_ssm::ssm_chunk_scan_combined_fwd",
    ssm_chunk_scan_combined_bridge,
    setup_context=ssm_chunk_scan_combined_setup_context,
)
168
+
169
def mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float("inf"))):
    """
    Chunked selective-state-space scan, routed through the
    compile-friendly custom op.

    Argument:
        x: (batch, seqlen, nheads, headdim)
        dt: (batch, seqlen, nheads)
        A: (nheads)
        B: (batch, seqlen, ngroups, dstate)
        C: (batch, seqlen, ngroups, dstate)
        chunk_size: int
        D: (nheads, headdim) or (nheads,)
        z: (batch, seqlen, nheads, headdim)
        dt_bias: (nheads,)
        initial_states: (batch, nheads, headdim, dstate)
        seq_idx: (batch, seqlen)
        cu_seqlens: (num_sequences + 1) or None
        dt_softplus: Whether to apply softplus to dt
    Return:
        out: (batch, seqlen, nheads, headdim); when cu_seqlens is given,
        a (out, varlen_states) pair is returned instead.
    """
    # The custom op always yields three tensors; the middle one (out_x)
    # is only meaningful internally and is discarded here.
    out, _, varlen_states = ssm_chunk_scan_combined_fwd(
        x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias,
        initial_states=initial_states, seq_idx=seq_idx, cu_seqlens=cu_seqlens,
        dt_softplus=dt_softplus, dt_limit=dt_limit,
    )
    return (out, varlen_states) if cu_seqlens is not None else out
193
+
194
if __name__ == "__main__":
    # Smoke test comparing the eager custom-op path, the torch.compile'd
    # path, and the reference mamba_ssm implementation.
    # Requires a CUDA device and the mamba_ssm package.
    from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined as mamba_chunk_scan_combined_ref

    torch.manual_seed(0)
    torch.cuda.manual_seed(0)

    # Tiny random problem: batch=2, seqlen=3, nheads=4, headdim=5.
    x = torch.randn(2, 3, 4, 5).cuda()
    dt = torch.randn(2, 3, 4).cuda()
    A = torch.randn(4).cuda()
    B = torch.randn(2, 3, 4, 5).cuda()
    C = torch.randn(2, 3, 4, 5).cuda()
    chunk_size = 2
    D = torch.randn(4, 5).cuda()
    z = torch.randn(2, 3, 4, 5).cuda()
    dt_bias = torch.randn(4).cuda()

    out = mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias)

    # NOTE(review): only summary stats are printed for manual comparison;
    # consider torch.testing.assert_close between the three outputs.
    print(out.min(), out.max(), out.mean(), out.std())

    compiled_mamba_chunk_scan_combined = torch.compile(mamba_chunk_scan_combined)
    out = compiled_mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias)

    print(out.min(), out.max(), out.mean(), out.std())

    out_ref = mamba_chunk_scan_combined_ref(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias)

    print(out_ref.min(), out_ref.max(), out_ref.mean(), out_ref.std())
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "199999": {
5
+ "content": "<|endoftext|>",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "200018": {
13
+ "content": "<|endofprompt|>",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ }
20
+ },
21
+ "bos_token": "<|endoftext|>",
22
+ "clean_up_tokenization_spaces": false,
23
+ "eos_token": "<|endoftext|>",
24
+ "model_max_length": 128000,
25
+ "tokenizer_class": "GPT2Tokenizer",
26
+ "unk_token": "<|endoftext|>"
27
+ }
vocab.json ADDED
The diff for this file is too large to render. See raw diff