danieldk (HF Staff) committed
Commit 8e88928 · verified · 1 Parent(s): d033399

Build uploaded using `kernels`.

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full set.
Files changed (50):
  1. build/torch210-cxx11-cu126-x86_64-linux/__init__.py +0 -14
  2. build/torch210-cxx11-cu126-x86_64-linux/_mamba_ssm_b2a7fd5.abi3.so +0 -3
  3. build/torch210-cxx11-cu126-x86_64-linux/_ops.py +0 -9
  4. build/torch210-cxx11-cu126-x86_64-linux/distributed/__init__.py +0 -0
  5. build/torch210-cxx11-cu126-x86_64-linux/distributed/distributed_utils.py +0 -144
  6. build/torch210-cxx11-cu126-x86_64-linux/distributed/tensor_parallel.py +0 -296
  7. build/torch210-cxx11-cu126-x86_64-linux/mamba_ssm/__init__.py +0 -26
  8. build/torch210-cxx11-cu126-x86_64-linux/metadata.json +0 -1
  9. build/torch210-cxx11-cu126-x86_64-linux/models/__init__.py +0 -0
  10. build/torch210-cxx11-cu126-x86_64-linux/models/config_mamba.py +0 -18
  11. build/torch210-cxx11-cu126-x86_64-linux/models/mixer_seq_simple.py +0 -309
  12. build/torch210-cxx11-cu126-x86_64-linux/modules/__init__.py +0 -0
  13. build/torch210-cxx11-cu126-x86_64-linux/modules/block.py +0 -107
  14. build/torch210-cxx11-cu126-x86_64-linux/modules/mamba2.py +0 -502
  15. build/torch210-cxx11-cu126-x86_64-linux/modules/mamba2_simple.py +0 -229
  16. build/torch210-cxx11-cu126-x86_64-linux/modules/mamba_simple.py +0 -339
  17. build/torch210-cxx11-cu126-x86_64-linux/modules/mha.py +0 -294
  18. build/torch210-cxx11-cu126-x86_64-linux/modules/mlp.py +0 -34
  19. build/torch210-cxx11-cu126-x86_64-linux/modules/ssd_minimal.py +0 -111
  20. build/torch210-cxx11-cu126-x86_64-linux/ops/__init__.py +0 -0
  21. build/torch210-cxx11-cu126-x86_64-linux/ops/selective_scan_interface.py +0 -446
  22. build/torch210-cxx11-cu126-x86_64-linux/ops/triton/__init__.py +0 -0
  23. build/torch210-cxx11-cu126-x86_64-linux/ops/triton/k_activations.py +0 -169
  24. build/torch210-cxx11-cu126-x86_64-linux/ops/triton/layer_norm.py +0 -1113
  25. build/torch210-cxx11-cu126-x86_64-linux/ops/triton/layernorm_gated.py +0 -437
  26. build/torch210-cxx11-cu126-x86_64-linux/ops/triton/selective_state_update.py +0 -285
  27. build/torch210-cxx11-cu126-x86_64-linux/ops/triton/softplus.py +0 -15
  28. build/torch210-cxx11-cu126-x86_64-linux/ops/triton/ssd_bmm.py +0 -262
  29. build/torch210-cxx11-cu126-x86_64-linux/ops/triton/ssd_chunk_scan.py +0 -0
  30. build/torch210-cxx11-cu126-x86_64-linux/ops/triton/ssd_chunk_state.py +0 -997
  31. build/torch210-cxx11-cu126-x86_64-linux/ops/triton/ssd_combined.py +0 -998
  32. build/torch210-cxx11-cu126-x86_64-linux/ops/triton/ssd_state_passing.py +0 -348
  33. build/torch210-cxx11-cu126-x86_64-linux/utils/__init__.py +0 -0
  34. build/torch210-cxx11-cu126-x86_64-linux/utils/generation.py +0 -390
  35. build/torch210-cxx11-cu126-x86_64-linux/utils/hf.py +0 -23
  36. build/torch210-cxx11-cu126-x86_64-linux/utils/torch.py +0 -21
  37. build/torch210-cxx11-cu128-x86_64-linux/__init__.py +0 -14
  38. build/torch210-cxx11-cu128-x86_64-linux/_mamba_ssm_b2a7fd5.abi3.so +0 -3
  39. build/torch210-cxx11-cu128-x86_64-linux/_ops.py +0 -9
  40. build/torch210-cxx11-cu128-x86_64-linux/distributed/__init__.py +0 -0
  41. build/torch210-cxx11-cu128-x86_64-linux/distributed/distributed_utils.py +0 -144
  42. build/torch210-cxx11-cu128-x86_64-linux/distributed/tensor_parallel.py +0 -296
  43. build/torch210-cxx11-cu128-x86_64-linux/mamba_ssm/__init__.py +0 -26
  44. build/torch210-cxx11-cu128-x86_64-linux/metadata.json +0 -1
  45. build/torch210-cxx11-cu128-x86_64-linux/models/__init__.py +0 -0
  46. build/torch210-cxx11-cu128-x86_64-linux/models/config_mamba.py +0 -18
  47. build/torch210-cxx11-cu128-x86_64-linux/models/mixer_seq_simple.py +0 -309
  48. build/torch210-cxx11-cu128-x86_64-linux/modules/__init__.py +0 -0
  49. build/torch210-cxx11-cu128-x86_64-linux/modules/block.py +0 -107
  50. build/torch210-cxx11-cu128-x86_64-linux/modules/mamba2.py +0 -502
build/torch210-cxx11-cu126-x86_64-linux/__init__.py DELETED
@@ -1,14 +0,0 @@
__version__ = "2.2.4"

from .ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
from .modules.mamba_simple import Mamba
from .modules.mamba2 import Mamba2
from .models.mixer_seq_simple import MambaLMHeadModel

__all__ = [
    "selective_scan_fn",
    "mamba_inner_fn",
    "Mamba",
    "Mamba2",
    "MambaLMHeadModel",
]
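For reference, a minimal usage sketch of the public API re-exported above; it assumes the built wheel is importable as `mamba_ssm` and that a CUDA device is available (the kernels are GPU-only):

# Sketch, not part of the commit: assumes `mamba_ssm` is importable and CUDA is present.
import torch
from mamba_ssm import Mamba

layer = Mamba(d_model=256, layer_idx=0).to("cuda")
x = torch.randn(2, 64, 256, device="cuda")  # (batch, seqlen, d_model)
y = layer(x)                                # same shape as x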
build/torch210-cxx11-cu126-x86_64-linux/_mamba_ssm_b2a7fd5.abi3.so DELETED
@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:19b5ffd35a9fd55231325ac14270580c019395c0acb3e4e251518042b50b1aed
size 444257200
build/torch210-cxx11-cu126-x86_64-linux/_ops.py DELETED
@@ -1,9 +0,0 @@
import torch

from . import _mamba_ssm_b2a7fd5

ops = torch.ops._mamba_ssm_b2a7fd5


def add_op_namespace_prefix(op_name: str):
    """
    Prefix op by namespace.
    """
    return f"_mamba_ssm_b2a7fd5::{op_name}"
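A short sketch of how `_ops.py` is consumed elsewhere in the package; `selective_scan_fwd` below is a hypothetical op name standing in for whatever ops the shared object actually registers:

# Sketch: `selective_scan_fwd` is a hypothetical op name, used only to illustrate.
full_name = add_op_namespace_prefix("selective_scan_fwd")
assert full_name == "_mamba_ssm_b2a7fd5::selective_scan_fwd"
# Registered kernels are then reachable as attributes of `ops`,
# e.g. ops.selective_scan_fwd(...).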
build/torch210-cxx11-cu126-x86_64-linux/distributed/__init__.py DELETED
File without changes
build/torch210-cxx11-cu126-x86_64-linux/distributed/distributed_utils.py DELETED
@@ -1,144 +0,0 @@
from typing import Optional

import torch
from torch import Tensor
from torch.distributed import ProcessGroup

# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
# version of PyTorch. The following 4 lines are for backward compatibility with
# older PyTorch.
if "all_gather_into_tensor" not in dir(torch.distributed):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
if "reduce_scatter_tensor" not in dir(torch.distributed):
    torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base


# Raw operation, does not support autograd, but does support async
def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    world_size = torch.distributed.get_world_size(process_group)
    output = torch.empty(
        world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device
    )
    handle = torch.distributed.all_gather_into_tensor(
        output, input_.contiguous(), group=process_group, async_op=async_op
    )
    return output, handle


# Raw operation, does not support autograd, but does support async
def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    world_size = torch.distributed.get_world_size(process_group)
    assert input_.shape[0] % world_size == 0
    output = torch.empty(
        input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device
    )
    handle = torch.distributed.reduce_scatter_tensor(
        output, input_.contiguous(), group=process_group, async_op=async_op
    )
    return output, handle


# Raw operation, does not support autograd, but does support async
def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    input_ = input_.contiguous()
    handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op)
    return input_, handle


class AllGatherFunc(torch.autograd.Function):
    """Gather the input from the sequence parallel region and concatenate."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = all_gather_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        grad_input, _ = reduce_scatter_raw(grad_output, ctx.process_group)
        return grad_input, None


# Supports autograd, but does not support async
all_gather = AllGatherFunc.apply


class ReduceScatterFunc(torch.autograd.Function):
    """Reduce-scatter the input across the sequence parallel region."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = reduce_scatter_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        grad_input, _ = all_gather_raw(grad_output, ctx.process_group)
        return grad_input, None


# Supports autograd, but does not support async
reduce_scatter = ReduceScatterFunc.apply


class AllReduceFunc(torch.autograd.Function):
    """All-reduce the input across the process group."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = all_reduce_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        return grad_output, None


# Supports autograd, but does not support async
all_reduce = AllReduceFunc.apply


def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):
    # We want to iterate over parameters with _shared_params=True in the same order,
    # as different ranks might have different numbers of parameters (e.g., only rank 0 has bias).
    params_shared = {
        name: p for name, p in model.named_parameters() if getattr(p, "_shared_params", False)
    }
    for _, p in sorted(params_shared.items()):
        with torch.no_grad():
            # Broadcast needs src to be the global rank, not the group rank
            torch.distributed.broadcast(
                p, src=torch.distributed.get_global_rank(process_group, 0), group=process_group
            )


# Ref: https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/optimizer/optimizer.py#L256
def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):
    # We want to iterate over parameters with _sequence_parallel=True in the same order,
    # as different ranks might have different numbers of parameters (e.g., only rank 0 has bias).
    params_seqparallel = {
        name: p for name, p in model.named_parameters() if getattr(p, "_sequence_parallel", False)
    }
    grads = [p.grad for _, p in sorted(params_seqparallel.items())]
    if grads:
        with torch.no_grad():
            coalesced = torch._utils._flatten_dense_tensors(grads)
            torch.distributed.all_reduce(coalesced, group=process_group)
            for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
                buf.copy_(synced)


def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int:
    """Get the dim for the local rank derived from splitting dim on world_size processes.

    The split may not be even across the world_size processes.
    """
    multiple = dim // multiple_of
    div = multiple // world_size
    mod = multiple % world_size
    local_multiple = div + int(local_rank < mod)
    return local_multiple * multiple_of
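The uneven-split rule in `get_dim_for_local_rank` gives each of the first `mod` ranks one extra chunk of size `multiple_of`. A self-contained sketch of that arithmetic (no process group needed):

# Sketch: dim=10 split over 3 ranks in units of multiple_of=2 gives 5 chunks,
# so ranks receive 2, 2, 1 chunks, i.e. local dims 4, 4, 2.
sizes = [get_dim_for_local_rank(10, world_size=3, local_rank=r, multiple_of=2) for r in range(3)]
assert sizes == [4, 4, 2] and sum(sizes) == 10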
build/torch210-cxx11-cu126-x86_64-linux/distributed/tensor_parallel.py DELETED
@@ -1,296 +0,0 @@
# Copyright (c) 2024, Tri Dao.
# The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.distributed import ProcessGroup
from ..utils.torch import custom_bwd, custom_fwd

from einops import rearrange

from ..distributed.distributed_utils import (
    all_gather_raw,
    all_reduce,
    all_reduce_raw,
    reduce_scatter,
    reduce_scatter_raw,
)


class ParallelLinearFunc(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True):
        """
        If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
        with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
        """
        ctx.compute_weight_gradient = weight.requires_grad
        ctx.process_group = process_group
        ctx.sequence_parallel = sequence_parallel

        if torch.is_autocast_enabled():
            x = x.to(dtype=torch.get_autocast_gpu_dtype())
        x = x.contiguous()
        if process_group is not None and sequence_parallel:
            # We want to kick off the all_gather early, before weight dtype conversion
            total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
        else:
            total_x = x

        if torch.is_autocast_enabled():
            weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
            bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None
        weight = weight.contiguous()
        if process_group is not None and sequence_parallel:
            handle_x.wait()
        batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
        batch_dim = batch_shape.numel()
        # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
        output = F.linear(total_x, weight, bias)
        if ctx.compute_weight_gradient:
            ctx.save_for_backward(x, weight)
        else:
            ctx.save_for_backward(weight)
        return output

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        grad_output = grad_output.contiguous()
        process_group = ctx.process_group
        sequence_parallel = ctx.sequence_parallel
        if ctx.compute_weight_gradient:
            x, weight = ctx.saved_tensors
            if process_group is not None and sequence_parallel:
                total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
            else:
                total_x = x
        else:
            (weight,) = ctx.saved_tensors
            total_x = None
        batch_shape = grad_output.shape[:-1]
        batch_dim = batch_shape.numel()
        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
        if ctx.needs_input_grad[0]:
            grad_input = F.linear(grad_output, weight.t())
            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
            if process_group is not None:
                reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
                grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
        else:
            grad_input = None
        if ctx.needs_input_grad[1]:
            assert ctx.compute_weight_gradient
            if process_group is not None and sequence_parallel:
                handle_x.wait()
            grad_weight = torch.einsum(
                "bo,bi->oi", grad_output, total_x.reshape(batch_dim, total_x.shape[-1])
            )
        else:
            grad_weight = None
        grad_bias = grad_output.sum(dim=0) if ctx.needs_input_grad[2] else None
        if process_group is not None and ctx.needs_input_grad[0]:
            handle_grad_input.wait()
        return grad_input, grad_weight, grad_bias, None, None


def parallel_linear_func(
    x: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    process_group: Optional[ProcessGroup] = None,
    sequence_parallel: bool = True,
):
    return ParallelLinearFunc.apply(x, weight, bias, process_group, sequence_parallel)


class ColumnParallelLinear(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        process_group: ProcessGroup,
        bias: bool = True,
        sequence_parallel=True,
        multiple_of=1,
        device=None,
        dtype=None,
    ) -> None:
        world_size = torch.distributed.get_world_size(process_group)
        if out_features % multiple_of:
            raise ValueError(f"out_features ({out_features}) must be a multiple of {multiple_of}")
        multiple = out_features // multiple_of
        # We want to split @multiple across world_size, but it could be an uneven split
        div = multiple // world_size
        mod = multiple % world_size
        # The first @mod ranks get @div + 1 copies, the rest get @div copies
        local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
        super().__init__(
            in_features, local_multiple * multiple_of, bias=bias, device=device, dtype=dtype
        )
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel

    def forward(self, x):
        # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
        # we do an all_gather of x before doing the matmul.
        # If not, then the input is already gathered.
        return parallel_linear_func(
            x,
            self.weight,
            self.bias,
            process_group=self.process_group,
            sequence_parallel=self.sequence_parallel,
        )


class RowParallelLinear(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        process_group: ProcessGroup,
        bias: bool = True,
        sequence_parallel=True,
        multiple_of=1,
        device=None,
        dtype=None,
    ) -> None:
        world_size = torch.distributed.get_world_size(process_group)
        rank = torch.distributed.get_rank(process_group)
        if in_features % multiple_of:
            raise ValueError(f"in_features ({in_features}) must be a multiple of {multiple_of}")
        multiple = in_features // multiple_of
        # We want to split @multiple across world_size, but it could be an uneven split
        div = multiple // world_size
        mod = multiple % world_size
        # The first @mod ranks get @div + 1 copies, the rest get @div copies
        local_multiple = div + int(rank < mod)
        # Only rank 0 will have bias
        super().__init__(
            local_multiple * multiple_of,
            out_features,
            bias=bias and rank == 0,
            device=device,
            dtype=dtype,
        )
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel

    def forward(self, x):
        """
        We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
        a reduce_scatter of the result.
        """
        out = parallel_linear_func(x, self.weight, self.bias)
        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
        return reduce_fn(out, self.process_group)


class VocabParallelEmbedding(nn.Embedding):
    def __init__(self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs):
        self.process_group = process_group
        if process_group is not None:
            world_size = torch.distributed.get_world_size(process_group)
            if num_embeddings % world_size != 0:
                raise ValueError(
                    f"num_embeddings ({num_embeddings}) must be divisible by "
                    f"world_size ({world_size})"
                )
            if world_size > 1 and padding_idx is not None:
                raise RuntimeError("ParallelEmbedding does not support padding_idx")
        else:
            world_size = 1
        super().__init__(num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs)

    def forward(self, input: Tensor) -> Tensor:
        if self.process_group is None:
            return super().forward(input)
        else:
            rank = torch.distributed.get_rank(self.process_group)
            vocab_size = self.num_embeddings
            vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size
            # Create a mask of valid vocab ids (1 means it needs to be masked).
            input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index)
            input = input - vocab_start_index
            input[input_ids_mask] = 0
            embeddings = super().forward(input)
            embeddings[input_ids_mask] = 0.0
            return embeddings


class ColumnParallelEmbedding(nn.Embedding):
    def __init__(self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs):
        self.process_group = process_group
        if process_group is not None:
            world_size = torch.distributed.get_world_size(process_group)
            if embedding_dim % world_size != 0:
                raise ValueError(
                    f"embedding_dim ({embedding_dim}) must be divisible by "
                    f"world_size ({world_size})"
                )
        else:
            world_size = 1
        super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)


class ParallelEmbeddings(nn.Module):
    def __init__(
        self,
        embed_dim,
        vocab_size,
        max_position_embeddings,
        process_group,
        padding_idx=None,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ):
        """
        If max_position_embeddings <= 0, there's no position embeddings
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel
        self.word_embeddings = VocabParallelEmbedding(
            vocab_size,
            embed_dim,
            padding_idx=padding_idx,
            process_group=process_group,
            **factory_kwargs,
        )
        self.max_position_embeddings = max_position_embeddings
        if self.max_position_embeddings > 0:
            self.position_embeddings = ColumnParallelEmbedding(
                max_position_embeddings, embed_dim, process_group=process_group, **factory_kwargs
            )

    def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
        """
        input_ids: (batch, seqlen)
        position_ids: (batch, seqlen)
        """
        batch_size, seqlen = input_ids.shape
        world_size = torch.distributed.get_world_size(self.process_group)
        embeddings = self.word_embeddings(input_ids)
        if self.max_position_embeddings > 0:
            if position_ids is None:
                position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
            position_embeddings = self.position_embeddings(position_ids)
            if world_size <= 1:
                embeddings = embeddings + position_embeddings
            else:
                partition_dim = self.position_embeddings.embedding_dim
                rank = torch.distributed.get_rank(self.process_group)
                embeddings[
                    ..., rank * partition_dim : (rank + 1) * partition_dim
                ] += position_embeddings
        if combine_batch_seqlen_dim:
            embeddings = rearrange(embeddings, "b s d -> (b s) d")
        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
        return embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group)
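In a tensor-parallel MLP, the column-parallel layer's local output dimension must equal the row-parallel layer's local input dimension. A shape-only sketch for the even-split case, using plain `nn.Linear` stand-ins rather than a real process group:

# Sketch with plain nn.Linear; real usage passes a torch.distributed ProcessGroup
# to ColumnParallelLinear / RowParallelLinear instead.
import torch.nn as nn

d_model, d_hidden, world_size = 1024, 4096, 4
local_hidden = d_hidden // world_size          # each rank's shard of the hidden dim

fc1_local = nn.Linear(d_model, local_hidden)   # column-parallel shard
fc2_local = nn.Linear(local_hidden, d_model)   # row-parallel shard
# fc2's partial outputs are summed across ranks via all_reduce, or
# reduce_scatter'd along the sequence dimension when sequence_parallel=True.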
build/torch210-cxx11-cu126-x86_64-linux/mamba_ssm/__init__.py DELETED
@@ -1,26 +0,0 @@
import ctypes
import sys

import importlib.util
from pathlib import Path
from types import ModuleType


def _import_from_path(file_path: Path) -> ModuleType:
    # We cannot use the module name as-is: after adding it to `sys.modules`,
    # it would also be used for other imports. So, we make a module name that
    # depends on the path, using the hex-encoded hash of the path to keep it
    # unique.
    path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
    module_name = path_hash
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None:
        raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
    module = importlib.util.module_from_spec(spec)
    if module is None:
        raise ImportError(f"Cannot load module {module_name} from spec")
    sys.modules[module_name] = module
    spec.loader.exec_module(module)  # type: ignore
    return module


globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
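This shim re-exports the package's top-level `__init__.py` under a path-dependent module name, so multiple builds can coexist in `sys.modules`. A direct-call sketch; the path below is hypothetical:

# Sketch: /tmp/example_pkg/__init__.py is a hypothetical path, for illustration only.
mod = _import_from_path(Path("/tmp/example_pkg/__init__.py"))
print(mod.__name__)  # a hex digest derived from the absolute path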
build/torch210-cxx11-cu126-x86_64-linux/metadata.json DELETED
@@ -1 +0,0 @@
{"python-depends":[]}
build/torch210-cxx11-cu126-x86_64-linux/models/__init__.py DELETED
File without changes
build/torch210-cxx11-cu126-x86_64-linux/models/config_mamba.py DELETED
@@ -1,18 +0,0 @@
from dataclasses import dataclass, field


@dataclass
class MambaConfig:
    d_model: int = 2560
    d_intermediate: int = 0
    n_layer: int = 64
    vocab_size: int = 50277
    ssm_cfg: dict = field(default_factory=dict)
    attn_layer_idx: list = field(default_factory=list)
    attn_cfg: dict = field(default_factory=dict)
    rms_norm: bool = True
    residual_in_fp32: bool = True
    fused_add_norm: bool = True
    pad_vocab_size_multiple: int = 8
    tie_embeddings: bool = True
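Because `MambaConfig` is a plain dataclass, configurations round-trip losslessly through dicts, which is how `MambaLMHeadModel.from_pretrained` below consumes `config.json`. A small sketch:

# Sketch: dataclass configs round-trip through plain dicts.
from dataclasses import asdict

cfg = MambaConfig(d_model=768, n_layer=24, vocab_size=50277)
restored = MambaConfig(**asdict(cfg))
assert restored == cfg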
build/torch210-cxx11-cu126-x86_64-linux/models/mixer_seq_simple.py DELETED
@@ -1,309 +0,0 @@
# Copyright (c) 2023, Albert Gu, Tri Dao.

import math
from functools import partial
import json
import os
import copy

from collections import namedtuple

import torch
import torch.nn as nn

from .config_mamba import MambaConfig
from ..modules.mamba_simple import Mamba
from ..modules.mamba2 import Mamba2
from ..modules.mha import MHA
from ..modules.mlp import GatedMLP
from ..modules.block import Block
from ..utils.generation import GenerationMixin
from ..utils.hf import load_config_hf, load_state_dict_hf

try:
    from ..ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None


def create_block(
    d_model,
    d_intermediate,
    ssm_cfg=None,
    attn_layer_idx=None,
    attn_cfg=None,
    norm_epsilon=1e-5,
    rms_norm=False,
    residual_in_fp32=False,
    fused_add_norm=False,
    layer_idx=None,
    device=None,
    dtype=None,
):
    if ssm_cfg is None:
        ssm_cfg = {}
    if attn_layer_idx is None:
        attn_layer_idx = []
    if attn_cfg is None:
        attn_cfg = {}
    factory_kwargs = {"device": device, "dtype": dtype}
    if layer_idx not in attn_layer_idx:
        # Create a copy of the config to modify
        ssm_cfg = copy.deepcopy(ssm_cfg) if ssm_cfg is not None else {}
        ssm_layer = ssm_cfg.pop("layer", "Mamba1")
        if ssm_layer not in ["Mamba1", "Mamba2"]:
            raise ValueError(f"Invalid ssm_layer: {ssm_layer}, only support Mamba1 and Mamba2")
        mixer_cls = partial(
            Mamba2 if ssm_layer == "Mamba2" else Mamba,
            layer_idx=layer_idx,
            **ssm_cfg,
            **factory_kwargs,
        )
    else:
        mixer_cls = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs)
    norm_cls = partial(
        nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
    )
    if d_intermediate == 0:
        mlp_cls = nn.Identity
    else:
        mlp_cls = partial(
            GatedMLP, hidden_features=d_intermediate, out_features=d_model, **factory_kwargs
        )
    block = Block(
        d_model,
        mixer_cls,
        mlp_cls,
        norm_cls=norm_cls,
        fused_add_norm=fused_add_norm,
        residual_in_fp32=residual_in_fp32,
    )
    block.layer_idx = layer_idx
    return block


# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(
    module,
    n_layer,
    initializer_range=0.02,  # Now only used for embedding layer.
    rescale_prenorm_residual=True,
    n_residuals_per_layer=1,  # Change to 2 if we have MLP
):
    if isinstance(module, nn.Linear):
        if module.bias is not None:
            if not getattr(module.bias, "_no_reinit", False):
                nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=initializer_range)

    if rescale_prenorm_residual:
        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
        # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
        # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
        # > -- GPT-2 :: https://openai.com/blog/better-language-models/
        #
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        for name, p in module.named_parameters():
            if name in ["out_proj.weight", "fc2.weight"]:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                # We need to reinit p since this code could be called multiple times
                # Having just p *= scale would repeatedly scale it down
                nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                with torch.no_grad():
                    p /= math.sqrt(n_residuals_per_layer * n_layer)


class MixerModel(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_layer: int,
        d_intermediate: int,
        vocab_size: int,
        ssm_cfg=None,
        attn_layer_idx=None,
        attn_cfg=None,
        norm_epsilon: float = 1e-5,
        rms_norm: bool = False,
        initializer_cfg=None,
        fused_add_norm=False,
        residual_in_fp32=False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32

        self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)

        # We change the order of residual and layer norm:
        # Instead of LN -> Attn / MLP -> Add, we do:
        # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
        # the main branch (output of MLP / Mixer). The model definition is unchanged.
        # This is for performance reason: we can fuse add + layer_norm.
        self.fused_add_norm = fused_add_norm
        if self.fused_add_norm:
            if layer_norm_fn is None or rms_norm_fn is None:
                raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")

        self.layers = nn.ModuleList(
            [
                create_block(
                    d_model,
                    d_intermediate=d_intermediate,
                    ssm_cfg=ssm_cfg,
                    attn_layer_idx=attn_layer_idx,
                    attn_cfg=attn_cfg,
                    norm_epsilon=norm_epsilon,
                    rms_norm=rms_norm,
                    residual_in_fp32=residual_in_fp32,
                    fused_add_norm=fused_add_norm,
                    layer_idx=i,
                    **factory_kwargs,
                )
                for i in range(n_layer)
            ]
        )

        self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
            d_model, eps=norm_epsilon, **factory_kwargs
        )

        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
                n_residuals_per_layer=1 if d_intermediate == 0 else 2,  # 2 if we have MLP
            )
        )

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return {
            i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
            for i, layer in enumerate(self.layers)
        }

    def forward(self, input_ids, inference_params=None, **mixer_kwargs):
        hidden_states = self.embedding(input_ids)
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                hidden_states, residual, inference_params=inference_params, **mixer_kwargs
            )
        if not self.fused_add_norm:
            residual = (hidden_states + residual) if residual is not None else hidden_states
            hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
        else:
            # Set prenorm=False here since we don't need the residual
            hidden_states = layer_norm_fn(
                hidden_states,
                self.norm_f.weight,
                self.norm_f.bias,
                eps=self.norm_f.eps,
                residual=residual,
                prenorm=False,
                residual_in_fp32=self.residual_in_fp32,
                is_rms_norm=isinstance(self.norm_f, RMSNorm),
            )
        return hidden_states


class MambaLMHeadModel(nn.Module, GenerationMixin):

    def __init__(
        self,
        config: MambaConfig,
        initializer_cfg=None,
        device=None,
        dtype=None,
    ) -> None:
        self.config = config
        d_model = config.d_model
        n_layer = config.n_layer
        d_intermediate = config.d_intermediate
        vocab_size = config.vocab_size
        ssm_cfg = config.ssm_cfg
        attn_layer_idx = config.attn_layer_idx
        attn_cfg = config.attn_cfg
        rms_norm = config.rms_norm
        residual_in_fp32 = config.residual_in_fp32
        fused_add_norm = config.fused_add_norm
        pad_vocab_size_multiple = config.pad_vocab_size_multiple
        factory_kwargs = {"device": device, "dtype": dtype}

        super().__init__()
        if vocab_size % pad_vocab_size_multiple != 0:
            vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
        self.backbone = MixerModel(
            d_model=d_model,
            n_layer=n_layer,
            d_intermediate=d_intermediate,
            vocab_size=vocab_size,
            ssm_cfg=ssm_cfg,
            attn_layer_idx=attn_layer_idx,
            attn_cfg=attn_cfg,
            rms_norm=rms_norm,
            initializer_cfg=initializer_cfg,
            fused_add_norm=fused_add_norm,
            residual_in_fp32=residual_in_fp32,
            **factory_kwargs,
        )
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)

        # Initialize weights and apply final processing
        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
            )
        )
        self.tie_weights()

    def tie_weights(self):
        if self.config.tie_embeddings:
            self.lm_head.weight = self.backbone.embedding.weight

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

    def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0, **mixer_kwargs):
        """
        "position_ids" is just to be compatible with Transformer generation. We don't use it.
        num_last_tokens: if > 0, only return the logits for the last n tokens
        """
        hidden_states = self.backbone(input_ids, inference_params=inference_params, **mixer_kwargs)
        if num_last_tokens > 0:
            hidden_states = hidden_states[:, -num_last_tokens:]
        lm_logits = self.lm_head(hidden_states)
        CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
        return CausalLMOutput(logits=lm_logits)

    @classmethod
    def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
        config_data = load_config_hf(pretrained_model_name)
        config = MambaConfig(**config_data)
        model = cls(config, device=device, dtype=dtype, **kwargs)
        model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype))
        return model

    def save_pretrained(self, save_directory):
        """
        Minimal implementation of save_pretrained for MambaLMHeadModel.
        Save the model and its configuration file to a directory.
        """
        # Ensure save_directory exists
        os.makedirs(save_directory, exist_ok=True)

        # Save the model's state_dict
        model_path = os.path.join(save_directory, "pytorch_model.bin")
        torch.save(self.state_dict(), model_path)

        # Save the configuration of the model
        config_path = os.path.join(save_directory, "config.json")
        with open(config_path, "w") as f:
            json.dump(self.config.__dict__, f, indent=4)
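A minimal end-to-end sketch of the `from_pretrained` / `forward` / `save_pretrained` API defined above; it assumes a CUDA device and the compiled kernels, and uses `state-spaces/mamba-130m` purely as an example checkpoint:

# Sketch: assumes CUDA and the compiled kernels are available.
import torch

model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m", device="cuda", dtype=torch.float16)
input_ids = torch.randint(0, model.config.vocab_size, (1, 32), device="cuda")
logits = model(input_ids, num_last_tokens=1).logits  # (1, 1, padded_vocab_size)
model.save_pretrained("/tmp/mamba-checkpoint")       # writes pytorch_model.bin + config.json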
build/torch210-cxx11-cu126-x86_64-linux/modules/__init__.py DELETED
File without changes
build/torch210-cxx11-cu126-x86_64-linux/modules/block.py DELETED
@@ -1,107 +0,0 @@
# Copyright (c) 2024, Tri Dao, Albert Gu.
from typing import Optional

import torch
from torch import nn, Tensor

from ..ops.triton.layer_norm import RMSNorm, layer_norm_fn


class Block(nn.Module):
    def __init__(
        self,
        dim,
        mixer_cls,
        mlp_cls,
        norm_cls=nn.LayerNorm,
        fused_add_norm=False,
        residual_in_fp32=False,
    ):
        """
        Simple block wrapping a mixer class with LayerNorm/RMSNorm and a residual connection.

        This Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA/MLP -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Add -> LN -> Mixer, returning both
        the hidden_states (output of the mixer) and the residual.
        This is purely for performance reasons, as we can fuse add and LayerNorm.
        The residual needs to be provided (except for the very first block).
        """
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.norm = norm_cls(dim)
        self.mixer = mixer_cls(dim)
        if mlp_cls is not nn.Identity:
            self.norm2 = norm_cls(dim)
            self.mlp = mlp_cls(dim)
        else:
            self.mlp = None
        if self.fused_add_norm:
            assert RMSNorm is not None, "RMSNorm import fails"
            assert isinstance(
                self.norm, (nn.LayerNorm, RMSNorm)
            ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"

    def forward(
        self,
        hidden_states: Tensor,
        residual: Optional[Tensor] = None,
        inference_params=None,
        **mixer_kwargs,
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: hidden_states = Mixer(LN(residual))
        """
        if not self.fused_add_norm:
            residual = (hidden_states + residual) if residual is not None else hidden_states
            hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
            hidden_states, residual = layer_norm_fn(
                hidden_states,
                self.norm.weight,
                self.norm.bias,
                residual=residual,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                eps=self.norm.eps,
                is_rms_norm=isinstance(self.norm, RMSNorm),
            )
        hidden_states = self.mixer(hidden_states, inference_params=inference_params, **mixer_kwargs)

        if self.mlp is not None:
            if not self.fused_add_norm:
                residual = hidden_states + residual
                hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                if self.residual_in_fp32:
                    residual = residual.to(torch.float32)
            else:
                hidden_states, residual = layer_norm_fn(
                    hidden_states,
                    self.norm2.weight,
                    self.norm2.bias,
                    residual=residual,
                    prenorm=True,
                    residual_in_fp32=self.residual_in_fp32,
                    eps=self.norm2.eps,
                    is_rms_norm=isinstance(self.norm2, RMSNorm),
                )
            hidden_states = self.mlp(hidden_states)

        return hidden_states, residual

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
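The unfused branch of `Block.forward` follows a simple reference dataflow; the sketch below uses a plain `nn.Linear` as a stand-in mixer to make the Add -> LN -> Mixer ordering concrete (the first block in a stack is called with residual=None):

# Sketch: nn.Linear stands in for the Mamba / MHA mixer.
import torch
import torch.nn as nn

norm, mixer = nn.LayerNorm(16), nn.Linear(16, 16)
hidden, residual = torch.randn(2, 8, 16), None

residual = hidden + residual if residual is not None else hidden  # Add
hidden = mixer(norm(residual))                                    # LN -> Mixer
# The next block receives both (hidden, residual).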
build/torch210-cxx11-cu126-x86_64-linux/modules/mamba2.py DELETED
@@ -1,502 +0,0 @@
1
- # Copyright (c) 2024, Tri Dao, Albert Gu.
2
-
3
- import math
4
-
5
- import torch
6
- import torch.nn as nn
7
- import torch.nn.functional as F
8
-
9
- from einops import rearrange, repeat
10
-
11
- try:
12
- from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
13
- except ImportError:
14
- causal_conv1d_fn, causal_conv1d_update = None, None
15
-
16
- try:
17
- from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states
18
- except ImportError:
19
- causal_conv1d_varlen_states = None
20
-
21
- try:
22
- from ..ops.triton.selective_state_update import selective_state_update
23
- except ImportError:
24
- selective_state_update = None
25
-
26
- from ..ops.triton.layernorm_gated import RMSNorm as RMSNormGated
27
-
28
- from ..distributed.tensor_parallel import ColumnParallelLinear, RowParallelLinear
29
- from ..distributed.distributed_utils import all_reduce, reduce_scatter
30
-
31
- from ..ops.triton.ssd_combined import mamba_chunk_scan_combined
32
- from ..ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
33
-
34
- from huggingface_hub import PyTorchModelHubMixin
35
-
36
-
37
- class Mamba2(nn.Module, PyTorchModelHubMixin):
38
- def __init__(
39
- self,
40
- d_model,
41
- d_state=128,
42
- d_conv=4,
43
- conv_init=None,
44
- expand=2,
45
- headdim=64,
46
- d_ssm=None, # If not None, we only apply SSM on this many dimensions, the rest uses gated MLP
47
- ngroups=1,
48
- A_init_range=(1, 16),
49
- D_has_hdim=False,
50
- rmsnorm=True,
51
- norm_before_gate=False,
52
- dt_min=0.001,
53
- dt_max=0.1,
54
- dt_init_floor=1e-4,
55
- dt_limit=(0.0, float("inf")),
56
- bias=False,
57
- conv_bias=True,
58
- # Fused kernel and sharding options
59
- chunk_size=256,
60
- use_mem_eff_path=True,
61
- layer_idx=None, # Absorb kwarg for general module
62
- process_group=None,
63
- sequence_parallel=True,
64
- device=None,
65
- dtype=None,
66
- ):
67
- factory_kwargs = {"device": device, "dtype": dtype}
68
- super().__init__()
69
- self.d_model = d_model
70
- self.d_state = d_state
71
- self.d_conv = d_conv
72
- self.conv_init = conv_init
73
- self.expand = expand
74
- self.process_group = process_group
75
- self.sequence_parallel = sequence_parallel
76
- self.world_size = 1 if process_group is None else process_group.size()
77
- self.local_rank = 0 if process_group is None else process_group.rank()
78
- self.d_inner = (self.expand * self.d_model) // self.world_size
79
- assert self.d_inner * self.world_size == self.expand * self.d_model
80
- self.headdim = headdim
81
- self.d_ssm = self.d_inner if d_ssm is None else d_ssm // self.world_size
82
- assert ngroups % self.world_size == 0
83
- self.ngroups = ngroups // self.world_size
84
- assert self.d_ssm % self.headdim == 0
85
- self.nheads = self.d_ssm // self.headdim
86
- self.D_has_hdim = D_has_hdim
87
- self.rmsnorm = rmsnorm
88
- self.norm_before_gate = norm_before_gate
89
- self.dt_limit = dt_limit
90
- self.activation = "silu"
91
- self.chunk_size = chunk_size
92
- self.use_mem_eff_path = use_mem_eff_path
93
- self.layer_idx = layer_idx
94
-
95
- # Order: [z, x, B, C, dt]
96
- d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
97
- if self.process_group is None:
98
- self.in_proj = nn.Linear(
99
- self.d_model, d_in_proj, bias=bias, **factory_kwargs
100
- )
101
- else:
102
- self.in_proj = ColumnParallelLinear(
103
- self.d_model,
104
- d_in_proj * self.world_size,
105
- bias=bias,
106
- process_group=self.process_group,
107
- sequence_parallel=self.sequence_parallel,
108
- **factory_kwargs,
109
- )
110
-
111
- conv_dim = self.d_ssm + 2 * self.ngroups * self.d_state
112
- self.conv1d = nn.Conv1d(
113
- in_channels=conv_dim,
114
- out_channels=conv_dim,
115
- bias=conv_bias,
116
- kernel_size=d_conv,
117
- groups=conv_dim,
118
- padding=d_conv - 1,
119
- **factory_kwargs,
120
- )
121
- if self.conv_init is not None:
122
- nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)
123
-
124
- self.act = nn.SiLU()
125
-
126
- # Initialize log dt bias
127
- dt = torch.exp(
128
- torch.rand(self.nheads, **factory_kwargs)
129
- * (math.log(dt_max) - math.log(dt_min))
130
- + math.log(dt_min)
131
- )
132
- dt = torch.clamp(dt, min=dt_init_floor)
133
- # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
134
- inv_dt = dt + torch.log(-torch.expm1(-dt))
135
- self.dt_bias = nn.Parameter(inv_dt)
136
- # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
137
- # name.endswith("bias") in param_grouping.py
138
- self.dt_bias._no_weight_decay = True
139
-
140
- assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
141
- A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(
142
- *A_init_range
143
- )
144
- A_log = torch.log(A).to(dtype=dtype)
145
- self.A_log = nn.Parameter(A_log)
146
- self.A_log._no_weight_decay = True
147
-
148
- # D "skip" parameter
149
- self.D = nn.Parameter(
150
- torch.ones(self.d_ssm if self.D_has_hdim else self.nheads, device=device)
151
- )
152
- self.D._no_weight_decay = True
153
-
154
- if self.rmsnorm:
155
- assert RMSNormGated is not None
156
- self.norm = RMSNormGated(
157
- self.d_ssm,
158
- eps=1e-5,
159
- norm_before_gate=self.norm_before_gate,
160
- group_size=self.d_ssm // ngroups,
161
- **factory_kwargs,
162
- )
163
-
164
- if self.process_group is None:
165
- self.out_proj = nn.Linear(
166
- self.d_inner, self.d_model, bias=bias, **factory_kwargs
167
- )
168
- else:
169
- self.out_proj = RowParallelLinear(
170
- self.d_inner * self.world_size,
171
- self.d_model,
172
- bias=bias,
173
- process_group=self.process_group,
174
- sequence_parallel=self.sequence_parallel,
175
- **factory_kwargs,
176
- )
177
-
178
- def forward(
179
- self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_params=None
180
- ):
181
- """
182
- u: (batch, seqlen, hidden_dim) if seqlen=None.
183
- If seqlen is not None, u is (batch * seqlen, hidden_dim). This is so that when we
184
- split u during sequence parallel, we split the batch * seqlen dimension
185
- (in case batch is small).
186
- Returns: same shape as u
187
- """
188
- seqlen_og = seqlen
189
- if seqlen is None:
190
- batch, seqlen, dim = u.shape
191
- else:
192
- batch_seqlen, dim = u.shape
193
- batch = batch_seqlen // seqlen
194
-
195
- conv_state, ssm_state = None, None
196
- if inference_params is not None:
197
- inference_batch = (
198
- cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch
199
- )
200
- conv_state, ssm_state = self._get_states_from_cache(
201
- inference_params, inference_batch
202
- )
203
- if inference_params.seqlen_offset > 0:
204
- # The states are updated inplace
205
- out, _, _ = self.step(u, conv_state, ssm_state)
206
- return out
207
-
208
- zxbcdt = self.in_proj(u) # (B, L, d_in_proj) or (B * L, d_in_proj)
209
- if seqlen_og is not None:
210
- zxbcdt = rearrange(zxbcdt, "(b l) d -> b l d", l=seqlen)
211
- # If the model is loaded in fp16, without the .float() here, A might be -inf
212
- A = -torch.exp(self.A_log.float()) # (nheads) or (d_inner, d_state)
213
- dt_limit_kwargs = (
214
- {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
215
- )
216
- if self.use_mem_eff_path and inference_params is None:
217
- out = mamba_split_conv1d_scan_combined(
218
- zxbcdt,
219
- rearrange(self.conv1d.weight, "d 1 w -> d w"),
220
- self.conv1d.bias,
221
- self.dt_bias,
222
- A,
223
- D=(
224
- rearrange(self.D, "(h p) -> h p", p=self.headdim)
225
- if self.D_has_hdim
226
- else self.D
227
- ),
228
- chunk_size=self.chunk_size,
229
- seq_idx=seq_idx,
230
- activation=self.activation,
231
- rmsnorm_weight=self.norm.weight if self.rmsnorm else None,
232
- rmsnorm_eps=self.norm.eps if self.rmsnorm else 1e-6,
233
- outproj_weight=self.out_proj.weight,
234
- outproj_bias=self.out_proj.bias,
235
- headdim=None if self.D_has_hdim else self.headdim,
236
- ngroups=self.ngroups,
237
- norm_before_gate=self.norm_before_gate,
238
- **dt_limit_kwargs,
239
- )
240
- if seqlen_og is not None:
241
- out = rearrange(out, "b l d -> (b l) d")
242
- if self.process_group is not None:
243
- reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
244
- out = reduce_fn(out, self.process_group)
245
- else:
246
- d_mlp = (
247
- zxbcdt.shape[-1]
248
- - 2 * self.d_ssm
249
- - 2 * self.ngroups * self.d_state
250
- - self.nheads
251
- ) // 2
252
- z0, x0, z, xBC, dt = torch.split(
253
- zxbcdt,
254
- [
255
- d_mlp,
256
- d_mlp,
257
- self.d_ssm,
258
- self.d_ssm + 2 * self.ngroups * self.d_state,
259
- self.nheads,
260
- ],
261
- dim=-1,
262
- )
263
- if conv_state is not None:
264
- if cu_seqlens is None:
265
- # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
266
- # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
267
- xBC_t = rearrange(xBC, "b l d -> b d l")
268
- conv_state.copy_(
269
- F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))
270
- ) # Update state (B D W)
271
- else:
272
- assert (
273
- causal_conv1d_varlen_states is not None
274
- ), "varlen inference requires causal_conv1d package"
275
- assert (
276
- batch == 1
277
- ), "varlen inference only supports batch dimension 1"
278
- conv_varlen_states = causal_conv1d_varlen_states(
279
- xBC.squeeze(0), cu_seqlens, state_len=conv_state.shape[-1]
280
- )
281
- conv_state.copy_(conv_varlen_states)
282
- assert self.activation in ["silu", "swish"]
283
- if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
284
- assert (
285
- seq_idx is None
286
- ), "varlen conv1d requires the causal_conv1d package"
287
- xBC = self.act(
288
- self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[
289
- :, : -(self.d_conv - 1)
290
- ]
291
- ) # (B, L, self.d_ssm + 2 * ngroups * d_state)
292
- else:
293
- xBC = causal_conv1d_fn(
294
- xBC.transpose(1, 2),
295
- rearrange(self.conv1d.weight, "d 1 w -> d w"),
296
- bias=self.conv1d.bias,
297
- activation=self.activation,
298
- seq_idx=seq_idx,
299
- ).transpose(1, 2)
300
- x, B, C = torch.split(
301
- xBC,
302
- [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state],
303
- dim=-1,
304
- )
305
- y = mamba_chunk_scan_combined(
306
- rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
307
- dt,
308
- A,
309
- rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
310
- rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
311
- chunk_size=self.chunk_size,
312
- D=(
313
- rearrange(self.D, "(h p) -> h p", p=self.headdim)
314
- if self.D_has_hdim
315
- else self.D
316
- ),
317
- z=(
318
- rearrange(z, "b l (h p) -> b l h p", p=self.headdim)
319
- if not self.rmsnorm
320
- else None
321
- ),
322
- dt_bias=self.dt_bias,
323
- dt_softplus=True,
324
- seq_idx=seq_idx,
325
- cu_seqlens=cu_seqlens,
326
- **dt_limit_kwargs,
327
- return_final_states=ssm_state is not None,
328
- return_varlen_states=cu_seqlens is not None
329
- and inference_params is not None,
330
- )
331
- if ssm_state is not None:
332
- y, last_state, *rest = y
333
- if cu_seqlens is None:
334
- ssm_state.copy_(last_state)
335
- else:
336
- varlen_states = rest[0]
337
- ssm_state.copy_(varlen_states)
338
- y = rearrange(y, "b l h p -> b l (h p)")
339
- if self.rmsnorm:
340
- y = self.norm(y, z)
341
- if d_mlp > 0:
342
- y = torch.cat([F.silu(z0) * x0, y], dim=-1)
343
- if seqlen_og is not None:
344
- y = rearrange(y, "b l d -> (b l) d")
345
- out = self.out_proj(y)
346
- return out
347
-
348
- def step(self, hidden_states, conv_state, ssm_state):
349
- dtype = hidden_states.dtype
350
- assert (
351
- hidden_states.shape[1] == 1
352
- ), "Only support decoding with 1 token at a time for now"
353
- zxbcdt = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
354
- d_mlp = (
355
- zxbcdt.shape[-1]
356
- - 2 * self.d_ssm
357
- - 2 * self.ngroups * self.d_state
358
- - self.nheads
359
- ) // 2
360
- z0, x0, z, xBC, dt = torch.split(
361
- zxbcdt,
362
- [
363
- d_mlp,
364
- d_mlp,
365
- self.d_ssm,
366
- self.d_ssm + 2 * self.ngroups * self.d_state,
367
- self.nheads,
368
- ],
369
- dim=-1,
370
- )
371
-
372
- # Conv step
373
- if causal_conv1d_update is None:
374
- conv_state.copy_(
375
- torch.roll(conv_state, shifts=-1, dims=-1)
376
- ) # Update state (B D W)
377
- conv_state[:, :, -1] = xBC
378
- xBC = torch.sum(
379
- conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1
380
- ) # (B D)
381
- if self.conv1d.bias is not None:
382
- xBC = xBC + self.conv1d.bias
383
- xBC = self.act(xBC).to(dtype=dtype)
384
- else:
385
- xBC = causal_conv1d_update(
386
- xBC,
387
- conv_state,
388
- rearrange(self.conv1d.weight, "d 1 w -> d w"),
389
- self.conv1d.bias,
390
- self.activation,
391
- )
392
-
393
- x, B, C = torch.split(
394
- xBC,
395
- [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state],
396
- dim=-1,
397
- )
398
- A = -torch.exp(self.A_log.float()) # (nheads,)
399
-
400
- # SSM step
401
- if selective_state_update is None:
402
- assert (
403
- self.ngroups == 1
404
- ), "Only support ngroups=1 for this inference code path"
405
- # Discretize A and B
406
- dt = F.softplus(dt + self.dt_bias.to(dtype=dt.dtype)) # (batch, nheads)
407
- dA = torch.exp(dt * A) # (batch, nheads)
408
- x = rearrange(x, "b (h p) -> b h p", p=self.headdim)
409
- dBx = torch.einsum("bh,bn,bhp->bhpn", dt, B, x)
410
- ssm_state.copy_(ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx)
411
- y = torch.einsum("bhpn,bn->bhp", ssm_state.to(dtype), C)
412
- y = y + rearrange(self.D.to(dtype), "h -> h 1") * x
413
- y = rearrange(y, "b h p -> b (h p)")
414
- if not self.rmsnorm:
415
- y = y * self.act(z) # (B D)
416
- else:
417
- A = repeat(A, "h -> h p n", p=self.headdim, n=self.d_state).to(
418
- dtype=torch.float32
419
- )
420
- dt = repeat(dt, "b h -> b h p", p=self.headdim)
421
- dt_bias = repeat(self.dt_bias, "h -> h p", p=self.headdim)
422
- D = repeat(self.D, "h -> h p", p=self.headdim)
423
- B = rearrange(B, "b (g n) -> b g n", g=self.ngroups)
424
- C = rearrange(C, "b (g n) -> b g n", g=self.ngroups)
425
- x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.headdim)
426
- if not self.rmsnorm:
427
- z = rearrange(z, "b (h p) -> b h p", p=self.headdim)
428
- y = selective_state_update(
429
- ssm_state,
430
- x_reshaped,
431
- dt,
432
- A,
433
- B,
434
- C,
435
- D,
436
- z=z if not self.rmsnorm else None,
437
- dt_bias=dt_bias,
438
- dt_softplus=True,
439
- )
440
- y = rearrange(y, "b h p -> b (h p)")
441
- if self.rmsnorm:
442
- y = self.norm(y, z)
443
- if d_mlp > 0:
444
- y = torch.cat([F.silu(z0) * x0, y], dim=-1)
445
- out = self.out_proj(y)
446
- return out.unsqueeze(1), conv_state, ssm_state
447
-
448
- def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
449
- device = self.out_proj.weight.device
450
- conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
451
- conv_state = torch.zeros(
452
- batch_size,
453
- self.d_conv,
454
- self.conv1d.weight.shape[0],
455
- device=device,
456
- dtype=conv_dtype,
457
- ).transpose(1, 2)
458
- ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype
459
- ssm_state = torch.zeros(
460
- batch_size,
461
- self.nheads,
462
- self.headdim,
463
- self.d_state,
464
- device=device,
465
- dtype=ssm_dtype,
466
- )
467
- return conv_state, ssm_state
468
-
469
- def _get_states_from_cache(
470
- self, inference_params, batch_size, initialize_states=False
471
- ):
472
- assert self.layer_idx is not None
473
- if self.layer_idx not in inference_params.key_value_memory_dict:
474
- batch_shape = (batch_size,)
475
- conv_state = torch.zeros(
476
- batch_size,
477
- self.d_conv,
478
- self.conv1d.weight.shape[0],
479
- device=self.conv1d.weight.device,
480
- dtype=self.conv1d.weight.dtype,
481
- ).transpose(1, 2)
482
- ssm_state = torch.zeros(
483
- batch_size,
484
- self.nheads,
485
- self.headdim,
486
- self.d_state,
487
- device=self.in_proj.weight.device,
488
- dtype=self.in_proj.weight.dtype,
489
- )
490
- inference_params.key_value_memory_dict[self.layer_idx] = (
491
- conv_state,
492
- ssm_state,
493
- )
494
- else:
495
- conv_state, ssm_state = inference_params.key_value_memory_dict[
496
- self.layer_idx
497
- ]
498
- # TODO: What if batch size changes between generation, and we reuse the same states?
499
- if initialize_states:
500
- conv_state.zero_()
501
- ssm_state.zero_()
502
- return conv_state, ssm_state
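
A note on the fallback branch in the deleted `step` method above: when the fused `selective_state_update` Triton kernel is not importable, the single-token SSM update is done in plain PyTorch. The sketch below is only an illustration of that recurrence (it assumes `ngroups=1` and the shapes given in the comments; the function name and test values are made up for the example, it is not the kernel itself):

```python
import torch
import torch.nn.functional as F

def ssm_decode_step(ssm_state, x, dt, A, B, C, D, dt_bias):
    # ssm_state: (batch, nheads, headdim, d_state), updated in place
    # x: (batch, nheads, headdim); dt: (batch, nheads)
    # A, D, dt_bias: (nheads,); B, C: (batch, d_state)  [ngroups=1]
    dt = F.softplus(dt + dt_bias)                      # discretization step size
    dA = torch.exp(dt * A)                             # per-head decay
    dBx = torch.einsum("bh,bn,bhp->bhpn", dt, B, x)    # discretized input term
    ssm_state.copy_(ssm_state * dA[..., None, None] + dBx)
    y = torch.einsum("bhpn,bn->bhp", ssm_state, C)
    return y + D[:, None] * x                          # D "skip" connection

batch, nheads, headdim, d_state = 2, 4, 8, 16
state = torch.zeros(batch, nheads, headdim, d_state)
y = ssm_decode_step(
    state,
    torch.randn(batch, nheads, headdim),
    torch.randn(batch, nheads),
    -torch.rand(nheads),
    torch.randn(batch, d_state),
    torch.randn(batch, d_state),
    torch.ones(nheads),
    torch.zeros(nheads),
)
```
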
build/torch210-cxx11-cu126-x86_64-linux/modules/mamba2_simple.py DELETED
@@ -1,229 +0,0 @@
1
- # Copyright (c) 2024, Tri Dao, Albert Gu.
2
-
3
- import math
4
- import torch
5
- import torch.nn as nn
6
- import torch.nn.functional as F
7
-
8
- from einops import rearrange, repeat
9
-
10
- try:
11
- from causal_conv1d import causal_conv1d_fn
12
- except ImportError:
13
- causal_conv1d_fn = None
14
-
15
- try:
16
- from ..ops.triton.layernorm_gated import RMSNorm as RMSNormGated, LayerNorm
17
- except ImportError:
18
- RMSNormGated, LayerNorm = None, None
19
-
20
- from ..ops.triton.ssd_combined import mamba_chunk_scan_combined
21
- from ..ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
22
-
23
-
24
- class Mamba2Simple(nn.Module):
25
- def __init__(
26
- self,
27
- d_model,
28
- d_state=64,
29
- d_conv=4,
30
- conv_init=None,
31
- expand=2,
32
- headdim=128,
33
- ngroups=1,
34
- A_init_range=(1, 16),
35
- dt_min=0.001,
36
- dt_max=0.1,
37
- dt_init_floor=1e-4,
38
- dt_limit=(0.0, float("inf")),
39
- learnable_init_states=False,
40
- activation="swish",
41
- bias=False,
42
- conv_bias=True,
43
- # Fused kernel and sharding options
44
- chunk_size=256,
45
- use_mem_eff_path=True,
46
- layer_idx=None, # Absorb kwarg for general module
47
- device=None,
48
- dtype=None,
49
- ):
50
- factory_kwargs = {"device": device, "dtype": dtype}
51
- super().__init__()
52
- self.d_model = d_model
53
- self.d_state = d_state
54
- self.d_conv = d_conv
55
- self.conv_init = conv_init
56
- self.expand = expand
57
- self.d_inner = self.expand * self.d_model
58
- self.headdim = headdim
59
- self.ngroups = ngroups
60
- assert self.d_inner % self.headdim == 0
61
- self.nheads = self.d_inner // self.headdim
62
- self.dt_limit = dt_limit
63
- self.learnable_init_states = learnable_init_states
64
- self.activation = activation
65
- self.chunk_size = chunk_size
66
- self.use_mem_eff_path = use_mem_eff_path
67
- self.layer_idx = layer_idx
68
-
69
- # Order: [z, x, B, C, dt]
70
- d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
71
- self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs)
72
-
73
- conv_dim = self.d_inner + 2 * self.ngroups * self.d_state
74
- self.conv1d = nn.Conv1d(
75
- in_channels=conv_dim,
76
- out_channels=conv_dim,
77
- bias=conv_bias,
78
- kernel_size=d_conv,
79
- groups=conv_dim,
80
- padding=d_conv - 1,
81
- **factory_kwargs,
82
- )
83
- if self.conv_init is not None:
84
- nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)
85
- # self.conv1d.weight._no_weight_decay = True
86
-
87
- if self.learnable_init_states:
88
- self.init_states = nn.Parameter(
89
- torch.zeros(self.nheads, self.headdim, self.d_state, **factory_kwargs)
90
- )
91
- self.init_states._no_weight_decay = True
92
-
93
- self.act = nn.SiLU()
94
-
95
- # Initialize log dt bias
96
- dt = torch.exp(
97
- torch.rand(self.nheads, **factory_kwargs)
98
- * (math.log(dt_max) - math.log(dt_min))
99
- + math.log(dt_min)
100
- )
101
- dt = torch.clamp(dt, min=dt_init_floor)
102
- # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
103
- inv_dt = dt + torch.log(-torch.expm1(-dt))
104
- self.dt_bias = nn.Parameter(inv_dt)
105
- # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
106
- # name.endswith("bias") in param_grouping.py
107
- self.dt_bias._no_weight_decay = True
108
-
109
- # A parameter
110
- assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
111
- A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(
112
- *A_init_range
113
- )
114
- A_log = torch.log(A).to(dtype=dtype)
115
- self.A_log = nn.Parameter(A_log)
116
- # self.register_buffer("A_log", torch.zeros(self.nheads, dtype=torch.float32, device=device), persistent=True)
117
- self.A_log._no_weight_decay = True
118
-
119
- # D "skip" parameter
120
- self.D = nn.Parameter(torch.ones(self.nheads, device=device))
121
- self.D._no_weight_decay = True
122
-
123
- # Extra normalization layer right before output projection
124
- assert RMSNormGated is not None
125
- self.norm = RMSNormGated(
126
- self.d_inner, eps=1e-5, norm_before_gate=False, **factory_kwargs
127
- )
128
-
129
- self.out_proj = nn.Linear(
130
- self.d_inner, self.d_model, bias=bias, **factory_kwargs
131
- )
132
-
133
- def forward(self, u, seq_idx=None):
134
- """
135
- u: (B, L, D)
136
- Returns: same shape as u
137
- """
138
- batch, seqlen, dim = u.shape
139
-
140
- zxbcdt = self.in_proj(u) # (B, L, d_in_proj)
141
- A = -torch.exp(self.A_log) # (nheads) or (d_inner, d_state)
142
- initial_states = (
143
- repeat(self.init_states, "... -> b ...", b=batch)
144
- if self.learnable_init_states
145
- else None
146
- )
147
- dt_limit_kwargs = (
148
- {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
149
- )
150
-
151
- if self.use_mem_eff_path:
152
- # Fully fused path
153
- out = mamba_split_conv1d_scan_combined(
154
- zxbcdt,
155
- rearrange(self.conv1d.weight, "d 1 w -> d w"),
156
- self.conv1d.bias,
157
- self.dt_bias,
158
- A,
159
- D=self.D,
160
- chunk_size=self.chunk_size,
161
- seq_idx=seq_idx,
162
- activation=self.activation,
163
- rmsnorm_weight=self.norm.weight,
164
- rmsnorm_eps=self.norm.eps,
165
- outproj_weight=self.out_proj.weight,
166
- outproj_bias=self.out_proj.bias,
167
- headdim=self.headdim,
168
- ngroups=self.ngroups,
169
- norm_before_gate=False,
170
- initial_states=initial_states,
171
- **dt_limit_kwargs,
172
- )
173
- else:
174
- z, xBC, dt = torch.split(
175
- zxbcdt,
176
- [
177
- self.d_inner,
178
- self.d_inner + 2 * self.ngroups * self.d_state,
179
- self.nheads,
180
- ],
181
- dim=-1,
182
- )
183
- dt = F.softplus(dt + self.dt_bias) # (B, L, nheads)
184
- assert self.activation in ["silu", "swish"]
185
-
186
- # 1D Convolution
187
- if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
188
- xBC = self.act(
189
- self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)
190
- ) # (B, L, self.d_inner + 2 * ngroups * d_state)
191
- xBC = xBC[:, :seqlen, :]
192
- else:
193
- xBC = causal_conv1d_fn(
194
- x=xBC.transpose(1, 2),
195
- weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
196
- bias=self.conv1d.bias,
197
- activation=self.activation,
198
- ).transpose(1, 2)
199
-
200
- # Split into 3 main branches: X, B, C
201
- # These correspond to V, K, Q respectively in the SSM/attention duality
202
- x, B, C = torch.split(
203
- xBC,
204
- [
205
- self.d_inner,
206
- self.ngroups * self.d_state,
207
- self.ngroups * self.d_state,
208
- ],
209
- dim=-1,
210
- )
211
- y = mamba_chunk_scan_combined(
212
- rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
213
- dt,
214
- A,
215
- rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
216
- rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
217
- chunk_size=self.chunk_size,
218
- D=self.D,
219
- z=None,
220
- seq_idx=seq_idx,
221
- initial_states=initial_states,
222
- **dt_limit_kwargs,
223
- )
224
- y = rearrange(y, "b l h p -> b l (h p)")
225
-
226
- # Multiply "gate" branch and apply extra normalization layer
227
- y = self.norm(y, z)
228
- out = self.out_proj(y)
229
- return out
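
The `in_proj` of `Mamba2Simple` above packs `[z, x, B, C, dt]` into one projection of width `2 * d_inner + 2 * ngroups * d_state + nheads`. A small sketch of that bookkeeping (illustrative dimensions, plain PyTorch, no GPU required) shows how the non-fused path splits it back apart:

```python
import torch

d_model, expand, d_state, headdim, ngroups = 256, 2, 64, 64, 1
d_inner = expand * d_model                                  # 512
nheads = d_inner // headdim                                 # 8
d_in_proj = 2 * d_inner + 2 * ngroups * d_state + nheads    # 1160

zxbcdt = torch.randn(2, 16, d_in_proj)  # (batch, seqlen, d_in_proj)
z, xBC, dt = torch.split(
    zxbcdt,
    [d_inner, d_inner + 2 * ngroups * d_state, nheads],
    dim=-1,
)
print(z.shape, xBC.shape, dt.shape)  # gate, conv input (x | B | C), per-head dt
```
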
build/torch210-cxx11-cu126-x86_64-linux/modules/mamba_simple.py DELETED
@@ -1,339 +0,0 @@
1
- # Copyright (c) 2023, Tri Dao, Albert Gu.
2
-
3
- import math
4
- from typing import Optional
5
-
6
- import torch
7
- import torch.nn as nn
8
- import torch.nn.functional as F
9
- from torch import Tensor
10
-
11
- from einops import rearrange, repeat
12
-
13
- from ..ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
14
-
15
- try:
16
- from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
17
- except ImportError:
18
- causal_conv1d_fn, causal_conv1d_update = None, None
19
-
20
- try:
21
- from ..ops.triton.selective_state_update import selective_state_update
22
- except ImportError:
23
- selective_state_update = None
24
-
25
- try:
26
- from ..ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
27
- except ImportError:
28
- RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
29
-
30
-
31
- class Mamba(nn.Module):
32
- def __init__(
33
- self,
34
- d_model,
35
- d_state=16,
36
- d_conv=4,
37
- expand=2,
38
- dt_rank="auto",
39
- dt_min=0.001,
40
- dt_max=0.1,
41
- dt_init="random",
42
- dt_scale=1.0,
43
- dt_init_floor=1e-4,
44
- conv_bias=True,
45
- bias=False,
46
- use_fast_path=True, # Fused kernel options
47
- layer_idx=None,
48
- device=None,
49
- dtype=None,
50
- ):
51
- factory_kwargs = {"device": device, "dtype": dtype}
52
- super().__init__()
53
- self.d_model = d_model
54
- self.d_state = d_state
55
- self.d_conv = d_conv
56
- self.expand = expand
57
- self.d_inner = int(self.expand * self.d_model)
58
- self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
59
- self.use_fast_path = use_fast_path
60
- self.layer_idx = layer_idx
61
-
62
- self.in_proj = nn.Linear(
63
- self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs
64
- )
65
-
66
- self.conv1d = nn.Conv1d(
67
- in_channels=self.d_inner,
68
- out_channels=self.d_inner,
69
- bias=conv_bias,
70
- kernel_size=d_conv,
71
- groups=self.d_inner,
72
- padding=d_conv - 1,
73
- **factory_kwargs,
74
- )
75
-
76
- self.activation = "silu"
77
- self.act = nn.SiLU()
78
-
79
- self.x_proj = nn.Linear(
80
- self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
81
- )
82
- self.dt_proj = nn.Linear(
83
- self.dt_rank, self.d_inner, bias=True, **factory_kwargs
84
- )
85
-
86
- # Initialize special dt projection to preserve variance at initialization
87
- dt_init_std = self.dt_rank**-0.5 * dt_scale
88
- if dt_init == "constant":
89
- nn.init.constant_(self.dt_proj.weight, dt_init_std)
90
- elif dt_init == "random":
91
- nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
92
- else:
93
- raise NotImplementedError
94
-
95
- # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
96
- dt = torch.exp(
97
- torch.rand(self.d_inner, **factory_kwargs)
98
- * (math.log(dt_max) - math.log(dt_min))
99
- + math.log(dt_min)
100
- ).clamp(min=dt_init_floor)
101
- # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
102
- inv_dt = dt + torch.log(-torch.expm1(-dt))
103
- with torch.no_grad():
104
- self.dt_proj.bias.copy_(inv_dt)
105
- # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
106
- self.dt_proj.bias._no_reinit = True
107
-
108
- # S4D real initialization
109
- A = repeat(
110
- torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
111
- "n -> d n",
112
- d=self.d_inner,
113
- ).contiguous()
114
- A_log = torch.log(A) # Keep A_log in fp32
115
- self.A_log = nn.Parameter(A_log)
116
- self.A_log._no_weight_decay = True
117
-
118
- # D "skip" parameter
119
- self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32
120
- self.D._no_weight_decay = True
121
-
122
- self.out_proj = nn.Linear(
123
- self.d_inner, self.d_model, bias=bias, **factory_kwargs
124
- )
125
-
126
- def forward(self, hidden_states, inference_params=None):
127
- """
128
- hidden_states: (B, L, D)
129
- Returns: same shape as hidden_states
130
- """
131
- batch, seqlen, dim = hidden_states.shape
132
-
133
- conv_state, ssm_state = None, None
134
- if inference_params is not None:
135
- conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
136
- if inference_params.seqlen_offset > 0:
137
- # The states are updated inplace
138
- out, _, _ = self.step(hidden_states, conv_state, ssm_state)
139
- return out
140
-
141
- # We do matmul and transpose BLH -> HBL at the same time
142
- xz = rearrange(
143
- self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
144
- "d (b l) -> b d l",
145
- l=seqlen,
146
- )
147
- if self.in_proj.bias is not None:
148
- xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")
149
-
150
- A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
151
- # In the backward pass we write dx and dz next to each other to avoid torch.cat
152
- if (
153
- self.use_fast_path
154
- and causal_conv1d_fn is not None
155
- and inference_params is None
156
- ): # Doesn't support outputting the states
157
- out = mamba_inner_fn(
158
- xz,
159
- self.conv1d.weight,
160
- self.conv1d.bias,
161
- self.x_proj.weight,
162
- self.dt_proj.weight,
163
- self.out_proj.weight,
164
- self.out_proj.bias,
165
- A,
166
- None, # input-dependent B
167
- None, # input-dependent C
168
- self.D.float(),
169
- delta_bias=self.dt_proj.bias.float(),
170
- delta_softplus=True,
171
- )
172
- else:
173
- x, z = xz.chunk(2, dim=1)
174
- # Compute short convolution
175
- if conv_state is not None:
176
- # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
177
- # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
178
- conv_state.copy_(
179
- F.pad(x, (self.d_conv - x.shape[-1], 0))
180
- ) # Update state (B D W)
181
- if causal_conv1d_fn is None:
182
- x = self.act(self.conv1d(x)[..., :seqlen])
183
- else:
184
- assert self.activation in ["silu", "swish"]
185
- x = causal_conv1d_fn(
186
- x=x,
187
- weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
188
- bias=self.conv1d.bias,
189
- activation=self.activation,
190
- )
191
-
192
- # We're careful here about the layout, to avoid extra transposes.
193
- # We want dt to have d as the slowest moving dimension
194
- # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
195
- x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
196
- dt, B, C = torch.split(
197
- x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1
198
- )
199
- dt = self.dt_proj.weight @ dt.t()
200
- dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
201
- B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
202
- C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
203
- assert self.activation in ["silu", "swish"]
204
- y = selective_scan_fn(
205
- x,
206
- dt,
207
- A,
208
- B,
209
- C,
210
- self.D.float(),
211
- z=z,
212
- delta_bias=self.dt_proj.bias.float(),
213
- delta_softplus=True,
214
- return_last_state=ssm_state is not None,
215
- )
216
- if ssm_state is not None:
217
- y, last_state = y
218
- ssm_state.copy_(last_state)
219
- y = rearrange(y, "b d l -> b l d")
220
- out = self.out_proj(y)
221
- return out
222
-
223
- def step(self, hidden_states, conv_state, ssm_state):
224
- dtype = hidden_states.dtype
225
- assert (
226
- hidden_states.shape[1] == 1
227
- ), "Only support decoding with 1 token at a time for now"
228
- xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
229
- x, z = xz.chunk(2, dim=-1) # (B D)
230
-
231
- # Conv step
232
- if causal_conv1d_update is None:
233
- conv_state.copy_(
234
- torch.roll(conv_state, shifts=-1, dims=-1)
235
- ) # Update state (B D W)
236
- conv_state[:, :, -1] = x
237
- x = torch.sum(
238
- conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1
239
- ) # (B D)
240
- if self.conv1d.bias is not None:
241
- x = x + self.conv1d.bias
242
- x = self.act(x).to(dtype=dtype)
243
- else:
244
- x = causal_conv1d_update(
245
- x,
246
- conv_state,
247
- rearrange(self.conv1d.weight, "d 1 w -> d w"),
248
- self.conv1d.bias,
249
- self.activation,
250
- )
251
-
252
- x_db = self.x_proj(x) # (B dt_rank+2*d_state)
253
- dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
254
- # Don't add dt_bias here
255
- dt = F.linear(dt, self.dt_proj.weight) # (B d_inner)
256
- A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
257
-
258
- # SSM step
259
- if selective_state_update is None:
260
- # Discretize A and B
261
- dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
262
- dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
263
- dB = torch.einsum("bd,bn->bdn", dt, B)
264
- ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
265
- y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
266
- y = y + self.D.to(dtype) * x
267
- y = y * self.act(z) # (B D)
268
- else:
269
- y = selective_state_update(
270
- ssm_state,
271
- x,
272
- dt,
273
- A,
274
- B,
275
- C,
276
- self.D,
277
- z=z,
278
- dt_bias=self.dt_proj.bias,
279
- dt_softplus=True,
280
- )
281
-
282
- out = self.out_proj(y)
283
- return out.unsqueeze(1), conv_state, ssm_state
284
-
285
- def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
286
- device = self.out_proj.weight.device
287
- conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
288
- conv_state = torch.zeros(
289
- batch_size,
290
- self.d_model * self.expand,
291
- self.d_conv,
292
- device=device,
293
- dtype=conv_dtype,
294
- )
295
- ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
296
- # ssm_dtype = torch.float32
297
- ssm_state = torch.zeros(
298
- batch_size,
299
- self.d_model * self.expand,
300
- self.d_state,
301
- device=device,
302
- dtype=ssm_dtype,
303
- )
304
- return conv_state, ssm_state
305
-
306
- def _get_states_from_cache(
307
- self, inference_params, batch_size, initialize_states=False
308
- ):
309
- assert self.layer_idx is not None
310
- if self.layer_idx not in inference_params.key_value_memory_dict:
311
- batch_shape = (batch_size,)
312
- conv_state = torch.zeros(
313
- batch_size,
314
- self.d_model * self.expand,
315
- self.d_conv,
316
- device=self.conv1d.weight.device,
317
- dtype=self.conv1d.weight.dtype,
318
- )
319
- ssm_state = torch.zeros(
320
- batch_size,
321
- self.d_model * self.expand,
322
- self.d_state,
323
- device=self.dt_proj.weight.device,
324
- dtype=self.dt_proj.weight.dtype,
325
- # dtype=torch.float32,
326
- )
327
- inference_params.key_value_memory_dict[self.layer_idx] = (
328
- conv_state,
329
- ssm_state,
330
- )
331
- else:
332
- conv_state, ssm_state = inference_params.key_value_memory_dict[
333
- self.layer_idx
334
- ]
335
- # TODO: What if batch size changes between generation, and we reuse the same states?
336
- if initialize_states:
337
- conv_state.zero_()
338
- ssm_state.zero_()
339
- return conv_state, ssm_state
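
The `dt` initialization above samples the step size log-uniformly in `[dt_min, dt_max]` and stores its inverse-softplus in `dt_proj.bias`, so that `F.softplus(dt_proj.bias)` recovers the sampled values. A short, self-contained check of that identity (illustrative sizes):

```python
import math
import torch
import torch.nn.functional as F

dt_min, dt_max, d_inner = 1e-3, 0.1, 512
dt = torch.exp(
    torch.rand(d_inner) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)
).clamp(min=1e-4)
inv_dt = dt + torch.log(-torch.expm1(-dt))   # inverse of softplus
assert torch.allclose(F.softplus(inv_dt), dt, atol=1e-6)
```
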
build/torch210-cxx11-cu126-x86_64-linux/modules/mha.py DELETED
@@ -1,294 +0,0 @@
1
- # Copyright (c) 2024, Tri Dao, Albert Gu.
2
-
3
- import math
4
-
5
- import torch
6
- import torch.nn as nn
7
- import torch.nn.functional as F
8
- from einops import rearrange
9
-
10
- try:
11
- from flash_attn import flash_attn_with_kvcache
12
- except ImportError:
13
- flash_attn_with_kvcache = None
14
-
15
- try:
16
- from flash_attn.layers.rotary import RotaryEmbedding
17
- except ImportError:
18
- RotaryEmbedding = None
19
-
20
- try:
21
- from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
22
- except ImportError:
23
- causal_conv1d_fn, causal_conv1d_update = None, None
24
-
25
-
26
- def _update_kv_cache(kv, inference_params, layer_idx):
27
- """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
28
- # Pre-allocate memory for key-values for inference.
29
- num_heads, head_dim = kv.shape[-2:]
30
- assert layer_idx in inference_params.key_value_memory_dict
31
- kv_cache, _ = inference_params.key_value_memory_dict[layer_idx]
32
- # Adjust key and value for inference
33
- batch_start = inference_params.batch_size_offset
34
- batch_end = batch_start + kv.shape[0]
35
- sequence_start = inference_params.seqlen_offset
36
- sequence_end = sequence_start + kv.shape[1]
37
- assert batch_end <= kv_cache.shape[0]
38
- assert sequence_end <= kv_cache.shape[1]
39
- assert kv_cache is not None
40
- kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
41
- return kv_cache[batch_start:batch_end, :sequence_end, ...]
42
-
43
-
44
- class MHA(nn.Module):
45
- """Multi-head self-attention and cross-attention"""
46
-
47
- def __init__(
48
- self,
49
- embed_dim,
50
- num_heads,
51
- num_heads_kv=None,
52
- head_dim=None, # If None, use embed_dim // num_heads
53
- mlp_dim=0,
54
- qkv_proj_bias=True,
55
- out_proj_bias=True,
56
- softmax_scale=None,
57
- causal=False,
58
- layer_idx=None,
59
- d_conv=0,
60
- rotary_emb_dim=0,
61
- rotary_emb_base=10000.0,
62
- rotary_emb_interleaved=False,
63
- device=None,
64
- dtype=None,
65
- ) -> None:
66
- """
67
- num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
68
- return_residual: whether to return the input x along with the output. This is for
69
- performance reason: for post-norm architecture, returning the input allows us
70
- to fuse the backward of nn.Linear with the residual connection.
71
- """
72
- factory_kwargs = {"device": device, "dtype": dtype}
73
- super().__init__()
74
- self.embed_dim = embed_dim
75
- self.layer_idx = layer_idx
76
- self.d_conv = d_conv
77
- self.rotary_emb_dim = rotary_emb_dim
78
- self.softmax_scale = softmax_scale
79
- self.causal = causal
80
-
81
- self.num_heads = num_heads
82
- self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
83
- assert (
84
- self.num_heads % self.num_heads_kv == 0
85
- ), "num_heads must be divisible by num_heads_kv"
86
- if head_dim is None:
87
- assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
88
- self.head_dim = head_dim if head_dim is not None else self.embed_dim // num_heads
89
- self.mlp_dim = math.ceil(mlp_dim / 256) * 256
90
- qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
91
- out_dim = self.head_dim * self.num_heads
92
-
93
- if self.rotary_emb_dim > 0:
94
- assert RotaryEmbedding is not None, "rotary requires flash_attn to be installed"
95
- self.rotary_emb = RotaryEmbedding(
96
- self.rotary_emb_dim,
97
- base=rotary_emb_base,
98
- interleaved=rotary_emb_interleaved,
99
- device=device,
100
- )
101
-
102
- self.in_proj = nn.Linear(embed_dim, qkv_dim + self.mlp_dim, bias=qkv_proj_bias, **factory_kwargs)
103
- if self.d_conv > 0:
104
- self.conv1d = nn.Conv1d(
105
- qkv_dim, qkv_dim, kernel_size=self.d_conv, padding=self.d_conv - 1, groups=qkv_dim,
106
- **factory_kwargs
107
- )
108
- self.out_proj = nn.Linear(out_dim + self.mlp_dim // 2, embed_dim, bias=out_proj_bias, **factory_kwargs)
109
-
110
- def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
111
- dtype = self.out_proj.weight.dtype if dtype is None else dtype
112
- device = self.out_proj.weight.device
113
- if self.d_conv > 0:
114
- conv_state = torch.zeros(
115
- batch_size, self.conv1d.weight.shape[0], self.d_conv, device=device, dtype=dtype
116
- )
117
- else:
118
- conv_state = None
119
- kv_cache = torch.empty(
120
- batch_size, max_seqlen, 2, self.num_heads_kv, self.head_dim, dtype=dtype, device=device,
121
- )
122
- return kv_cache, conv_state
123
-
124
- def _update_kv_cache(self, kv, inference_params):
125
- """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
126
- assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
127
- return _update_kv_cache(kv, inference_params, self.layer_idx)
128
-
129
- def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
130
- """
131
- Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
132
- q: (batch_size, seqlen_q, nheads, head_dim)
133
- kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
134
- """
135
- assert inference_params is not None and inference_params.seqlen_offset > 0
136
- if self.rotary_emb_dim > 0:
137
- self.rotary_emb._update_cos_sin_cache(
138
- inference_params.max_seqlen, device=q.device, dtype=q.dtype
139
- )
140
- rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
141
- else:
142
- rotary_cos, rotary_sin = None, None
143
- batch = q.shape[0]
144
- kv_cache, _ = inference_params.key_value_memory_dict[self.layer_idx]
145
- kv_cache = kv_cache[:batch]
146
- cache_seqlens = (
147
- inference_params.lengths_per_sample[:batch]
148
- if inference_params.lengths_per_sample is not None
149
- else inference_params.seqlen_offset
150
- )
151
- assert flash_attn_with_kvcache is not None, "flash_attn must be installed"
152
- context = flash_attn_with_kvcache(
153
- q,
154
- kv_cache[:, :, 0],
155
- kv_cache[:, :, 1],
156
- kv[:, :, 0],
157
- kv[:, :, 1],
158
- rotary_cos=rotary_cos,
159
- rotary_sin=rotary_sin,
160
- cache_seqlens=cache_seqlens,
161
- softmax_scale=self.softmax_scale,
162
- causal=self.causal,
163
- rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
164
- )
165
- return context
166
-
167
- def _update_kvcache_attention(self, q, kv, inference_params):
168
- """Write kv to inference_params, then do attention"""
169
- if (
170
- inference_params.seqlen_offset == 0
171
- or flash_attn_with_kvcache is None
172
- ):
173
- # TODO: this only uses seqlen_offset and not lengths_per_sample.
174
- kv = self._update_kv_cache(kv, inference_params)
175
- k, v = kv.unbind(dim=-3)
176
- k = torch.repeat_interleave(k, dim=2, repeats=self.num_heads // self.num_heads_kv)
177
- v = torch.repeat_interleave(v, dim=2, repeats=self.num_heads // self.num_heads_kv)
178
- return F.scaled_dot_product_attention(
179
- q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=self.causal, scale=self.softmax_scale
180
- ).transpose(1, 2)
181
- else:
182
- batch = q.shape[0]
183
- kv_cache, _ = inference_params.key_value_memory_dict[self.layer_idx]
184
- kv_cache = kv_cache[:batch]
185
- cache_seqlens = (
186
- inference_params.lengths_per_sample[:batch]
187
- if inference_params.lengths_per_sample is not None
188
- else inference_params.seqlen_offset
189
- )
190
- return flash_attn_with_kvcache(
191
- q,
192
- kv_cache[:, :, 0],
193
- kv_cache[:, :, 1],
194
- kv[:, :, 0],
195
- kv[:, :, 1],
196
- cache_seqlens=cache_seqlens,
197
- softmax_scale=self.softmax_scale,
198
- causal=self.causal,
199
- )
200
-
201
- def forward(self, x, inference_params=None):
202
- """
203
- Arguments:
204
- x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
205
- cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
206
- is the sum of the sequence lengths in the batch.
207
- inference_params: for generation. Adapted from Megatron-LM (and Apex)
208
- https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
209
- """
210
- if inference_params is not None and self.layer_idx not in inference_params.key_value_memory_dict:
211
- inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache(
212
- x.shape[0], inference_params.max_seqlen, dtype=x.dtype
213
- )
214
- seqlen_offset = (
215
- 0
216
- if inference_params is None
217
- else (
218
- inference_params.lengths_per_sample
219
- if inference_params.lengths_per_sample is not None
220
- else inference_params.seqlen_offset
221
- )
222
- )
223
- rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
224
- qkv = self.in_proj(x)
225
- if self.mlp_dim > 0:
226
- qkv, x_mlp = qkv.split([qkv.shape[-1] - self.mlp_dim, self.mlp_dim], dim=-1)
227
- x_mlp_up, x_mlp_gate = x_mlp.chunk(2, dim=-1)
228
- x_mlp = x_mlp_up * F.silu(x_mlp_gate)
229
- if self.d_conv > 0:
230
- # The inference code for conv1d is pretty messy, should clean it up
231
- if (inference_params is None or inference_params.seqlen_offset == 0):
232
- if causal_conv1d_fn is None:
233
- qkv = rearrange(
234
- self.conv1d(rearrange(qkv, "b s d -> b d s"))[..., :-(self.d_conv - 1)], "b d s -> b s d"
235
- ).contiguous()
236
- else:
237
- qkv = causal_conv1d_fn(
238
- qkv.transpose(1, 2),
239
- rearrange(self.conv1d.weight, "d 1 w -> d w"),
240
- self.conv1d.bias
241
- ).transpose(1, 2)
242
- if inference_params is not None:
243
- _, conv_state = inference_params.key_value_memory_dict[self.layer_idx]
244
- # If we just take qkv[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
245
- # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
246
- qkv_t = rearrange(qkv, "b l d -> b d l")
247
- conv_state.copy_(F.pad(qkv_t, (self.d_conv - qkv_t.shape[-1], 0))) # Update state (B D W)
248
- else:
249
- _, conv_state = inference_params.key_value_memory_dict[self.layer_idx]
250
- assert qkv.shape[1] == 1, "Only support decoding with 1 token at a time for now"
251
- qkv = qkv.squeeze(1)
252
- # Conv step
253
- if causal_conv1d_update is None:
254
- conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
255
- conv_state[:, :, -1] = qkv
256
- qkv = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
257
- if self.conv1d.bias is not None:
258
- qkv = qkv + self.conv1d.bias
259
- else:
260
- qkv = causal_conv1d_update(
261
- qkv,
262
- conv_state,
263
- rearrange(self.conv1d.weight, "d 1 w -> d w"),
264
- self.conv1d.bias
265
- )
266
- qkv = qkv.unsqueeze(1)
267
- q, kv = qkv.split([self.num_heads * self.head_dim, self.num_heads_kv * 2 * self.head_dim], dim=-1)
268
- q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
269
- kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
270
- if (
271
- inference_params is None
272
- or inference_params.seqlen_offset == 0
273
- or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
274
- ):
275
- if self.rotary_emb_dim > 0:
276
- q, kv = self.rotary_emb(
277
- q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
278
- )
279
- if inference_params is None:
280
- k, v = kv.unbind(dim=-3)
281
- k = torch.repeat_interleave(k, dim=2, repeats=self.num_heads // self.num_heads_kv)
282
- v = torch.repeat_interleave(v, dim=2, repeats=self.num_heads // self.num_heads_kv)
283
- context = F.scaled_dot_product_attention(
284
- q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=self.causal, scale=self.softmax_scale
285
- ).transpose(1, 2)
286
- else:
287
- context = self._update_kvcache_attention(q, kv, inference_params)
288
- else:
289
- context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
290
- context = rearrange(context, "... h d -> ... (h d)")
291
- if self.mlp_dim > 0:
292
- context = torch.cat([context, x_mlp], dim=-1)
293
- out = self.out_proj(context)
294
- return out
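
When `flash_attn_with_kvcache` is unavailable, or no decoding cache is in use, the MHA module above falls back to PyTorch SDPA, repeating the KV heads so that MQA/GQA layouts match the query heads. A minimal sketch of that fallback (illustrative shapes, causal attention, default softmax scale):

```python
import torch
import torch.nn.functional as F

batch, seqlen, num_heads, num_heads_kv, head_dim = 2, 16, 8, 2, 64
q = torch.randn(batch, seqlen, num_heads, head_dim)
kv = torch.randn(batch, seqlen, 2, num_heads_kv, head_dim)

k, v = kv.unbind(dim=-3)
k = torch.repeat_interleave(k, dim=2, repeats=num_heads // num_heads_kv)
v = torch.repeat_interleave(v, dim=2, repeats=num_heads // num_heads_kv)
out = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
).transpose(1, 2)   # (batch, seqlen, num_heads, head_dim)
```
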
build/torch210-cxx11-cu126-x86_64-linux/modules/mlp.py DELETED
@@ -1,34 +0,0 @@
1
- # Copyright (c) 2024, Tri Dao, Albert Gu.
2
- from torch import nn
3
- from torch.nn import functional as F
4
-
5
-
6
- class GatedMLP(nn.Module):
7
- def __init__(
8
- self,
9
- in_features,
10
- hidden_features=None,
11
- out_features=None,
12
- activation=F.silu,
13
- bias=False,
14
- multiple_of=128,
15
- device=None,
16
- dtype=None,
17
- ):
18
- factory_kwargs = {"device": device, "dtype": dtype}
19
- super().__init__()
20
- out_features = out_features if out_features is not None else in_features
21
- hidden_features = (
22
- hidden_features if hidden_features is not None else int(8 * in_features / 3)
23
- )
24
- hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
25
- self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias, **factory_kwargs)
26
- self.activation = activation
27
- self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, **factory_kwargs)
28
-
29
- def forward(self, x):
30
- y = self.fc1(x)
31
- y, gate = y.chunk(2, dim=-1)
32
- y = y * self.activation(gate)
33
- y = self.fc2(y)
34
- return y
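
`GatedMLP` above defaults the hidden width to `8/3 * in_features` rounded up to a multiple of `multiple_of`, and `fc1` produces the value and gate halves that are combined with SiLU. A small sketch of the same computation without the module wrapper (illustrative sizes):

```python
import torch
import torch.nn.functional as F

in_features, multiple_of = 768, 128
hidden = int(8 * in_features / 3)
hidden = (hidden + multiple_of - 1) // multiple_of * multiple_of  # 2048 here

fc1 = torch.nn.Linear(in_features, 2 * hidden, bias=False)
fc2 = torch.nn.Linear(hidden, in_features, bias=False)
x = torch.randn(2, 16, in_features)
y, gate = fc1(x).chunk(2, dim=-1)
out = fc2(y * F.silu(gate))   # (2, 16, 768)
```
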
build/torch210-cxx11-cu126-x86_64-linux/modules/ssd_minimal.py DELETED
@@ -1,111 +0,0 @@
1
- # Copyright (c) 2024, Albert Gu and Tri Dao.
2
- """Minimal implementation of SSD.
3
-
4
- This is the same as Listing 1 from the paper.
5
- """
6
-
7
- import torch
8
- import torch.nn.functional as F
9
- from einops import rearrange, repeat
10
-
11
- from ..ops.triton.ssd_combined import mamba_chunk_scan_combined
12
-
13
-
14
- def segsum_unstable(x):
15
- """Naive segment sum calculation."""
16
- T = x.size(-1)
17
- x_cumsum = torch.cumsum(x, dim=-1)
18
- x_segsum = x_cumsum[..., :, None] - x_cumsum[..., None, :]
19
- mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
20
- x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
21
- return x_segsum
22
-
23
-
24
- def segsum(x):
25
- """More stable segment sum calculation."""
26
- T = x.size(-1)
27
- x = repeat(x, "... d -> ... d e", e=T)
28
- mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1)
29
- x = x.masked_fill(~mask, 0)
30
- x_segsum = torch.cumsum(x, dim=-2)
31
- mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
32
- x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
33
- return x_segsum
34
-
35
-
36
- def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None):
37
- """
38
- Arguments:
39
- X: (batch, length, n_heads, d_head)
40
- A: (batch, length, n_heads)
41
- B: (batch, length, n_heads, d_state)
42
- C: (batch, length, n_heads, d_state)
43
- Return:
44
- Y: (batch, length, n_heads, d_head)
45
- """
46
- assert X.dtype == A.dtype == B.dtype == C.dtype
47
- assert X.shape[1] % block_len == 0
48
-
49
- # Rearrange into blocks/chunks
50
- X, A, B, C = [
51
- rearrange(x, "b (c l) ... -> b c l ...", l=block_len) for x in (X, A, B, C)
52
- ]
53
-
54
- A = rearrange(A, "b c l h -> b h c l")
55
- A_cumsum = torch.cumsum(A, dim=-1)
56
-
57
- # 1. Compute the output for each intra-chunk (diagonal blocks)
58
- L = torch.exp(segsum(A))
59
- Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X)
60
-
61
- # 2. Compute the state for each intra-chunk
62
- # (right term of low-rank factorization of off-diagonal blocks; B terms)
63
- decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
64
- states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X)
65
-
66
- # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
67
- # (middle term of factorization of off-diag blocks; A terms)
68
- if initial_states is None:
69
- initial_states = torch.zeros_like(states[:, :1])
70
- states = torch.cat([initial_states, states], dim=1)
71
- decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))
72
- new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
73
- states, final_state = new_states[:, :-1], new_states[:, -1]
74
-
75
- # 4. Compute state -> output conversion per chunk
76
- # (left term of low-rank factorization of off-diagonal blocks; C terms)
77
- state_decay_out = torch.exp(A_cumsum)
78
- Y_off = torch.einsum("bclhn,bchpn,bhcl->bclhp", C, states, state_decay_out)
79
-
80
- # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
81
- Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p")
82
- return Y, final_state
83
-
84
-
85
- # Simple test
86
- def test_correctness():
87
- torch.manual_seed(42)
88
-
89
- ## Dimensions
90
- # Denoted (B, T, Q, D, P) in the paper
91
- batch, seqlen, chunk_size, dim, headdim = 1, 2048, 64, 2048, 64
92
- nheads = dim // headdim # (H) in the paper
93
- ngroups = 1 # (G) in the paper
94
- dstate = 64 # (N) in the paper
95
- dtype = torch.float32
96
- device = "cuda"
97
-
98
- x = torch.randn(batch, seqlen, nheads, headdim, dtype=dtype, device=device)
99
- dt = F.softplus(
100
- torch.randn(batch, seqlen, nheads, dtype=torch.float32, device=device) - 4
101
- ).requires_grad_()
102
- A = (
103
- -torch.exp(torch.rand(nheads, dtype=torch.float32, device=device))
104
- ).requires_grad_()
105
- B = torch.randn(batch, seqlen, ngroups, dstate, dtype=dtype, device=device)
106
- C = torch.randn(batch, seqlen, ngroups, dstate, dtype=dtype, device=device)
107
- D = torch.randn(nheads, dtype=dtype, device=device)
108
-
109
- # Comparing fused version and minimal version
110
- y = mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None)
111
- y_min, _ = ssd_minimal_discrete(x * dt.unsqueeze(-1), A * dt, B, C, chunk_size)
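
The segment sum is the piece of `ssd_minimal_discrete` above that builds the lower-triangular decay matrix `L = exp(segsum(A))` used in step 1. The sketch below mirrors the stabilized `segsum` helper with a plain `expand` in place of `einops.repeat` and checks its semantics on a tiny input: `segsum(a)[i, j]` sums `a[j+1..i]` on the lower triangle and is `-inf` above the diagonal (values are illustrative):

```python
import torch

def segsum(x):
    T = x.size(-1)
    xm = x[..., None].expand(*x.shape, T)          # entry [d, e] = x[..., d]
    mask = torch.tril(torch.ones(T, T, dtype=torch.bool), diagonal=-1)
    xm = xm.masked_fill(~mask, 0)                  # keep strictly-lower entries
    out = torch.cumsum(xm, dim=-2)                 # out[i, j] = sum(x[j+1..i])
    mask = torch.tril(torch.ones(T, T, dtype=torch.bool), diagonal=0)
    return out.masked_fill(~mask, float("-inf"))

a = torch.tensor([0.1, 0.2, 0.3, 0.4])
S = segsum(a)
assert torch.isclose(S[2, 0], a[1] + a[2])         # sum over positions 1..2
assert torch.isinf(S[0, 1]) and S[0, 1] < 0        # upper triangle masked out
```
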
build/torch210-cxx11-cu126-x86_64-linux/ops/__init__.py DELETED
File without changes
build/torch210-cxx11-cu126-x86_64-linux/ops/selective_scan_interface.py DELETED
@@ -1,446 +0,0 @@
1
- # Copyright (c) 2023, Tri Dao, Albert Gu.
2
-
3
- import torch
4
- import torch.nn.functional as F
5
- from ..utils.torch import custom_fwd, custom_bwd
6
-
7
- from einops import rearrange, repeat
8
-
9
- try:
10
- from causal_conv1d import causal_conv1d_fn
11
- from causal_conv1d.causal_conv1d_interface import causal_conv1d_cuda
12
- except ImportError:
13
- causal_conv1d_fn = None
14
- causal_conv1d_cuda = None
15
-
16
- from .triton.layer_norm import _layer_norm_fwd
17
-
18
- from .._ops import ops
19
-
20
-
21
- class SelectiveScanFn(torch.autograd.Function):
22
-
23
- @staticmethod
24
- def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
25
- return_last_state=False):
26
- if u.stride(-1) != 1:
27
- u = u.contiguous()
28
- if delta.stride(-1) != 1:
29
- delta = delta.contiguous()
30
- if D is not None:
31
- D = D.contiguous()
32
- if B.stride(-1) != 1:
33
- B = B.contiguous()
34
- if C.stride(-1) != 1:
35
- C = C.contiguous()
36
- if z is not None and z.stride(-1) != 1:
37
- z = z.contiguous()
38
- if B.dim() == 3:
39
- B = rearrange(B, "b dstate l -> b 1 dstate l")
40
- ctx.squeeze_B = True
41
- if C.dim() == 3:
42
- C = rearrange(C, "b dstate l -> b 1 dstate l")
43
- ctx.squeeze_C = True
44
- out, x, *rest = ops.selective_scan_fwd(
45
- u, delta, A, B, C, D, z, delta_bias, delta_softplus
46
- )
47
- ctx.delta_softplus = delta_softplus
48
- ctx.has_z = z is not None
49
- last_state = x[:, :, -1, 1::2] # (batch, dim, dstate)
50
- if not ctx.has_z:
51
- ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
52
- return out if not return_last_state else (out, last_state)
53
- else:
54
- ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
55
- out_z = rest[0]
56
- return out_z if not return_last_state else (out_z, last_state)
57
-
58
- @staticmethod
59
- def backward(ctx, dout, *args):
60
- if not ctx.has_z:
61
- u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
62
- z = None
63
- out = None
64
- else:
65
- u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
66
- if dout.stride(-1) != 1:
67
- dout = dout.contiguous()
68
- # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
69
- # backward of selective_scan_cuda with the backward of chunk).
70
- # Here we just pass in None and dz will be allocated in the C++ code.
71
- du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = ops.selective_scan_bwd(
72
- u,
73
- delta,
74
- A,
75
- B,
76
- C,
77
- D,
78
- z,
79
- delta_bias,
80
- dout,
81
- x,
82
- out,
83
- None,
84
- ctx.delta_softplus,
85
- False, # option to recompute out_z, not used here
86
- )
87
- dz = rest[0] if ctx.has_z else None
88
- dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
89
- dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
90
- return (du, ddelta, dA, dB, dC,
91
- dD if D is not None else None,
92
- dz,
93
- ddelta_bias if delta_bias is not None else None,
94
- None,
95
- None)
96
-
97
-
98
- def rms_norm_forward(
99
- x,
100
- weight,
101
- bias,
102
- eps=1e-6,
103
- is_rms_norm=True,
104
- ):
105
- # x (b l) d
106
- if x.stride(-1) != 1:
107
- x = x.contiguous()
108
- weight = weight.contiguous()
109
- if bias is not None:
110
- bias = bias.contiguous()
111
- y = _layer_norm_fwd(
112
- x, weight, bias, eps, None, residual_dtype=None, is_rms_norm=is_rms_norm
113
- )[0]
114
- # y (b l) d
115
- return y
116
-
117
-
118
- def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
119
- return_last_state=False):
120
- """if return_last_state is True, returns (out, last_state)
121
- last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
122
- not considered in the backward pass.
123
- """
124
- return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
125
-
126
-
127
- def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
128
- return_last_state=False):
129
- """
130
- u: r(B D L)
131
- delta: r(B D L)
132
- A: c(D N) or r(D N)
133
- B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
134
- C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
135
- D: r(D)
136
- z: r(B D L)
137
- delta_bias: r(D), fp32
138
-
139
- out: r(B D L)
140
- last_state (optional): r(B D dstate) or c(B D dstate)
141
- """
142
- dtype_in = u.dtype
143
- u = u.float()
144
- delta = delta.float()
145
- if delta_bias is not None:
146
- delta = delta + delta_bias[..., None].float()
147
- if delta_softplus:
148
- delta = F.softplus(delta)
149
- batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
150
- is_variable_B = B.dim() >= 3
151
- is_variable_C = C.dim() >= 3
152
- if A.is_complex():
153
- if is_variable_B:
154
- B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
155
- if is_variable_C:
156
- C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
157
- else:
158
- B = B.float()
159
- C = C.float()
160
- x = A.new_zeros((batch, dim, dstate))
161
- ys = []
162
- deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
163
- if not is_variable_B:
164
- deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
165
- else:
166
- if B.dim() == 3:
167
- deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
168
- else:
169
- B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
170
- deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
171
- if is_variable_C and C.dim() == 4:
172
- C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
173
- last_state = None
174
- for i in range(u.shape[2]):
175
- x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
176
- if not is_variable_C:
177
- y = torch.einsum('bdn,dn->bd', x, C)
178
- else:
179
- if C.dim() == 3:
180
- y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
181
- else:
182
- y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
183
- if i == u.shape[2] - 1:
184
- last_state = x
185
- if y.is_complex():
186
- y = y.real * 2
187
- ys.append(y)
188
- y = torch.stack(ys, dim=2) # (batch dim L)
189
- out = y if D is None else y + u * rearrange(D, "d -> d 1")
190
- if z is not None:
191
- out = out * F.silu(z)
192
- out = out.to(dtype=dtype_in)
193
- return out if not return_last_state else (out, last_state)
194
-
195
-
196
- class MambaInnerFn(torch.autograd.Function):
197
-
198
- @staticmethod
199
- @custom_fwd
200
- def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
201
- out_proj_weight, out_proj_bias,
202
- A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
203
- C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1, b_rms_weight=None, c_rms_weight= None, dt_rms_weight= None, b_c_dt_rms_eps=1e-6):
204
- """
205
- xz: (batch, dim, seqlen)
206
- """
207
- assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
208
- assert checkpoint_lvl in [0, 1]
209
- L = xz.shape[-1]
210
- delta_rank = delta_proj_weight.shape[1]
211
- d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
212
- if torch.is_autocast_enabled():
213
- x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
214
- delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
215
- out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
216
- out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype())
217
- if out_proj_bias is not None else None)
218
- if xz.stride(-1) != 1:
219
- xz = xz.contiguous()
220
- conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
221
- x, z = xz.chunk(2, dim=1)
222
- conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
223
- conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
224
- x, conv1d_weight, conv1d_bias, None, None, None, True
225
- )
226
- # We're being very careful here about the layout, to avoid extra transposes.
227
- # We want delta to have d as the slowest moving dimension
228
- # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
229
- x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight) # (bl d)
230
- delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L)
231
- ctx.is_variable_B = B is None
232
- ctx.is_variable_C = C is None
233
- ctx.B_proj_bias_is_None = B_proj_bias is None
234
- ctx.C_proj_bias_is_None = C_proj_bias is None
235
- if B is None: # variable B
236
- B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl dstate)
237
- if B_proj_bias is not None:
238
- B = B + B_proj_bias.to(dtype=B.dtype)
239
- if not A.is_complex():
240
- # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
241
- B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
242
- else:
243
- B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
244
- else:
245
- if B.stride(-1) != 1:
246
- B = B.contiguous()
247
- if C is None: # variable C
248
- C = x_dbl[:, -d_state:] # (bl dstate)
249
- if C_proj_bias is not None:
250
- C = C + C_proj_bias.to(dtype=C.dtype)
251
- if not A.is_complex():
252
- # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
253
- C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
254
- else:
255
- C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
256
- else:
257
- if C.stride(-1) != 1:
258
- C = C.contiguous()
259
- if D is not None:
260
- D = D.contiguous()
261
-
262
- if b_rms_weight is not None:
263
- B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
264
- B = rms_norm_forward(B, b_rms_weight, bias=None, eps=b_c_dt_rms_eps)
265
- B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
266
- if c_rms_weight is not None:
267
- C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
268
- C = rms_norm_forward(C, c_rms_weight, bias=None, eps=b_c_dt_rms_eps)
269
- C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
270
- if dt_rms_weight is not None:
271
- delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
272
- delta = rms_norm_forward(delta, dt_rms_weight, bias=None, eps=b_c_dt_rms_eps)
273
- delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()
274
-
275
- out, scan_intermediates, out_z = ops.selective_scan_fwd(
276
- conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
277
- )
278
- ctx.delta_softplus = delta_softplus
279
- ctx.out_proj_bias_is_None = out_proj_bias is None
280
- ctx.checkpoint_lvl = checkpoint_lvl
281
- ctx.b_rms_weight = b_rms_weight
282
- ctx.c_rms_weight = c_rms_weight
283
- ctx.dt_rms_weight = dt_rms_weight
284
- ctx.b_c_dt_rms_eps = b_c_dt_rms_eps
285
- if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass
286
- conv1d_out, delta = None, None
287
- ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight,
288
- delta_proj_weight, out_proj_weight, conv1d_out, delta,
289
- A, B, C, D, delta_bias, scan_intermediates, b_rms_weight, c_rms_weight, dt_rms_weight, out)
290
- return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)
291
-
292
- @staticmethod
293
- @custom_bwd
294
- def backward(ctx, dout):
295
- # dout: (batch, seqlen, dim)
296
- assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
297
- (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight,
298
- conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, b_rms_weight, c_rms_weight, dt_rms_weight, out) = ctx.saved_tensors
299
- L = xz.shape[-1]
300
- delta_rank = delta_proj_weight.shape[1]
301
- d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
302
- x, z = xz.chunk(2, dim=1)
303
- if dout.stride(-1) != 1:
304
- dout = dout.contiguous()
305
- if ctx.checkpoint_lvl == 1:
306
- conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
307
- x, conv1d_weight, conv1d_bias, None, None, None, True
308
- )
309
- delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
310
- "d (b l) -> b d l", l = L)
311
- if dt_rms_weight is not None:
312
- delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
313
- delta = rms_norm_forward(delta, ctx.dt_rms_weight, None, ctx.b_c_dt_rms_eps)
314
- delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()
315
- if b_rms_weight is not None:
316
- # Recompute & RMSNorm B
317
- B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
318
- B = rms_norm_forward(
319
- B, ctx.b_rms_weight, None, ctx.b_c_dt_rms_eps
320
- )
321
- B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
322
- if c_rms_weight is not None:
323
- # Recompute & RMSNorm C
324
- C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
325
- C = rms_norm_forward(
326
- C, ctx.c_rms_weight, None, ctx.b_c_dt_rms_eps
327
- )
328
- C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
329
-
330
- # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
331
- # backward of selective_scan_cuda with the backward of chunk).
332
- dxz = torch.empty_like(xz) # (batch, dim, seqlen)
333
- dx, dz = dxz.chunk(2, dim=1)
334
- dout = rearrange(dout, "b l e -> e (b l)")
335
- dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
336
- dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = (
337
- ops.selective_scan_bwd(
338
- conv1d_out,
339
- delta,
340
- A,
341
- B,
342
- C,
343
- D,
344
- z,
345
- delta_bias,
346
- dout_y,
347
- scan_intermediates,
348
- out,
349
- dz,
350
- ctx.delta_softplus,
351
- True, # option to recompute out_z
352
- )
353
- )
354
- dout_proj_weight = torch.einsum(
355
- "eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)")
356
- )
357
- dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
358
- dD = dD if D is not None else None
359
- dx_dbl = torch.empty_like(x_dbl)
360
- dB_proj_bias = None
361
- if ctx.is_variable_B:
362
- if not A.is_complex():
363
- dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
364
- else:
365
- dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
366
- dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
367
- dx_dbl[:, delta_rank:delta_rank + d_state] = dB # (bl d)
368
- dB = None
369
- dC_proj_bias = None
370
- if ctx.is_variable_C:
371
- if not A.is_complex():
372
- dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
373
- else:
374
- dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
375
- dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
376
- dx_dbl[:, -d_state:] = dC # (bl d)
377
- dC = None
378
- ddelta = rearrange(ddelta, "b d l -> d (b l)")
379
- ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
380
- dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
381
- dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
382
- dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
383
- dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
384
- dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
385
- # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
386
- # backward of conv1d with the backward of chunk).
387
- dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
388
- x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True
389
- )
390
- dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
391
- dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
392
- return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
393
- dout_proj_weight, dout_proj_bias,
394
- dA, dB, dC, dD,
395
- ddelta_bias if delta_bias is not None else None,
396
- # 6-None are delta_softplus, checkpoint_lvl, b_rms_weight, c_rms_weight, dt_rms_weight, b_c_dt_rms_eps
397
- dB_proj_bias, dC_proj_bias, None, None, None, None, None, None)
398
-
399
-
400
- def mamba_inner_fn(
401
- xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
402
- out_proj_weight, out_proj_bias,
403
- A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
404
- C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1, b_rms_weight= None, c_rms_weight= None, dt_rms_weight= None, b_c_dt_rms_eps=1e-6
405
- ):
406
- return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
407
- out_proj_weight, out_proj_bias,
408
- A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus, checkpoint_lvl, b_rms_weight, c_rms_weight, dt_rms_weight, b_c_dt_rms_eps)
409
-
410
-
411
- def mamba_inner_ref(
412
- xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
413
- out_proj_weight, out_proj_bias,
414
- A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
415
- C_proj_bias=None, delta_softplus=True
416
- ):
417
- assert causal_conv1d_fn is not None, "causal_conv1d_fn is not available. Please install causal-conv1d."
418
- L = xz.shape[-1]
419
- delta_rank = delta_proj_weight.shape[1]
420
- d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
421
- x, z = xz.chunk(2, dim=1)
422
- x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu")
423
- # We're being very careful here about the layout, to avoid extra transposes.
424
- # We want delta to have d as the slowest moving dimension
425
- # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
426
- x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight) # (bl d)
427
- delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
428
- delta = rearrange(delta, "d (b l) -> b d l", l=L)
429
- if B is None: # variable B
430
- B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl d)
431
- if B_proj_bias is not None:
432
- B = B + B_proj_bias.to(dtype=B.dtype)
433
- if not A.is_complex():
434
- B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
435
- else:
436
- B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
437
- if C is None: # variable C
438
- C = x_dbl[:, -d_state:] # (bl d)
439
- if C_proj_bias is not None:
440
- C = C + C_proj_bias.to(dtype=C.dtype)
441
- if not A.is_complex():
442
- C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
443
- else:
444
- C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
445
- y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True)
446
- return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)
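A minimal usage sketch for the deleted mamba_inner_fn above. The tensor shapes, the import path, and the fp16/fp32 split are assumptions based on a typical Mamba-1 configuration, not taken from this diff, and running it needs a CUDA device with the compiled selective-scan and causal-conv1d kernels:

    import torch
    # import path as in the upstream mamba_ssm package (assumption for this build)
    from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn

    batch, seqlen = 2, 128
    d_model, d_inner, d_state, dt_rank, width = 1024, 2048, 16, 64, 4
    dev, dtype = "cuda", torch.float16

    xz = torch.randn(batch, 2 * d_inner, seqlen, device=dev, dtype=dtype, requires_grad=True)
    conv1d_weight = torch.randn(d_inner, 1, width, device=dev, dtype=dtype, requires_grad=True)
    conv1d_bias = torch.randn(d_inner, device=dev, dtype=dtype, requires_grad=True)
    x_proj_weight = torch.randn(dt_rank + 2 * d_state, d_inner, device=dev, dtype=dtype, requires_grad=True)
    delta_proj_weight = torch.randn(d_inner, dt_rank, device=dev, dtype=dtype, requires_grad=True)
    out_proj_weight = torch.randn(d_model, d_inner, device=dev, dtype=dtype, requires_grad=True)
    A = -torch.rand(d_inner, d_state, device=dev, dtype=torch.float32)  # A stays in fp32
    D = torch.ones(d_inner, device=dev, dtype=torch.float32)

    # B and C are left as None so they are computed from x_proj_weight (the "variable B/C" path)
    out = mamba_inner_fn(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                         out_proj_weight, None, A, None, None, D, delta_softplus=True)
    out.sum().backward()  # exercises the MambaInnerFn backward shown in this file

mamba_inner_ref computes the same result without the fused kernel and is useful as a numerical reference when debugging.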
build/torch210-cxx11-cu126-x86_64-linux/ops/triton/__init__.py DELETED
File without changes
build/torch210-cxx11-cu126-x86_64-linux/ops/triton/k_activations.py DELETED
@@ -1,169 +0,0 @@
1
- # Copyright (c) 2024, Tri Dao, Albert Gu.
2
-
3
- import torch
4
-
5
- import triton
6
- import triton.language as tl
7
-
8
-
9
- @triton.autotune(
10
- configs=[
11
- triton.Config({'BLOCK_N': 32}),
12
- triton.Config({'BLOCK_N': 64}),
13
- triton.Config({'BLOCK_N': 128}),
14
- triton.Config({'BLOCK_N': 256}),
15
- triton.Config({'BLOCK_N': 512}),
16
- triton.Config({'BLOCK_N': 1024}),
17
- ],
18
- key=['ncols'],
19
- )
20
- @triton.jit
21
- def _swiglu_fwd_kernel(
22
- X,
23
- Y,
24
- OUT,
25
- stride_x_row, # how much to increase the pointer when moving by 1 row
26
- stride_y_row,
27
- stride_out_row,
28
- ncols,
29
- BLOCK_N: tl.constexpr,
30
- ):
31
- # Map the program id to the row of X and Y it should compute.
32
- row = tl.program_id(0)
33
- start_col = tl.program_id(1) * BLOCK_N
34
- X += row * stride_x_row
35
- Y += row * stride_y_row
36
- OUT += row * stride_out_row
37
- cols = start_col + tl.arange(0, BLOCK_N)
38
- x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32)
39
- y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32)
40
- out = x * tl.sigmoid(x) * y
41
- tl.store(OUT + cols, out, mask=cols < ncols)
42
-
43
-
44
- def _swiglu_fwd(xy, out=None):
45
- if xy.stride(-1) != 1:
46
- xy = xy.contiguous()
47
- batch_shape = xy.shape[:-1]
48
- xy = xy.reshape(-1, xy.shape[-1])
49
- x, y = xy.chunk(2, dim=-1)
50
- if out is None:
51
- out = torch.empty_like(x)
52
- else:
53
- out = out.reshape(-1, out.shape[-1])
54
- assert out.shape == x.shape
55
- assert out.stride(-1) == 1
56
- M, N = x.shape
57
- grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))
58
- with torch.cuda.device(x.device.index):
59
- _swiglu_fwd_kernel[grid](x, y, out, x.stride(0), y.stride(0), out.stride(0), N)
60
- return out.reshape(*batch_shape, out.shape[-1])
61
-
62
-
63
- @triton.autotune(
64
- configs=[
65
- triton.Config({'BLOCK_N': 32}),
66
- triton.Config({'BLOCK_N': 64}),
67
- triton.Config({'BLOCK_N': 128}),
68
- triton.Config({'BLOCK_N': 256}),
69
- triton.Config({'BLOCK_N': 512}),
70
- triton.Config({'BLOCK_N': 1024}),
71
- ],
72
- key=['ncols'],
73
- )
74
- @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["OUT"] is not None})
75
- @triton.jit
76
- def _swiglu_bwd_kernel(
77
- X,
78
- Y,
79
- DOUT,
80
- OUT,
81
- DX,
82
- DY,
83
- stride_x_row, # how much to increase the pointer when moving by 1 row
84
- stride_y_row,
85
- stride_dout_row,
86
- stride_out_row,
87
- stride_dx_row,
88
- stride_dy_row,
89
- ncols,
90
- BLOCK_N: tl.constexpr,
91
- RECOMPUTE_OUTPUT: tl.constexpr,
92
- ):
93
- # Map the program id to the row of X and Y it should compute.
94
- row = tl.program_id(0)
95
- start_col = tl.program_id(1) * BLOCK_N
96
- X += row * stride_x_row
97
- Y += row * stride_y_row
98
- DOUT += row * stride_dout_row
99
- if RECOMPUTE_OUTPUT:
100
- OUT += row * stride_out_row
101
- DX += row * stride_dx_row
102
- DY += row * stride_dy_row
103
- cols = start_col + tl.arange(0, BLOCK_N)
104
- x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32)
105
- y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32)
106
- dout = tl.load(DOUT + cols, mask=cols < ncols, other=0.).to(tl.float32)
107
- x_sigmoid = tl.sigmoid(x)
108
- dx = x_sigmoid * (1 + x * (1 - x_sigmoid)) * y * dout
109
- dy = x * x_sigmoid * dout
110
- tl.store(DX + cols, dx, mask=cols < ncols)
111
- tl.store(DY + cols, dy, mask=cols < ncols)
112
- if RECOMPUTE_OUTPUT:
113
- out = x * x_sigmoid * y
114
- tl.store(OUT + cols, out, mask=cols < ncols)
115
-
116
-
117
- def _swiglu_bwd(xy, dout, dxy=None, recompute_output=False, out=None):
118
- if xy.stride(-1) != 1:
119
- xy = xy.contiguous()
120
- if dout.stride(-1) != 1:
121
- dout = dout.contiguous()
122
- batch_shape = xy.shape[:-1]
123
- xy = xy.reshape(-1, xy.shape[-1])
124
- x, y = xy.chunk(2, dim=-1)
125
- dout = dout.reshape(-1, dout.shape[-1])
126
- assert dout.shape == x.shape
127
- if dxy is None:
128
- dxy = torch.empty_like(xy)
129
- else:
130
- dxy = dxy.reshape(-1, dxy.shape[-1])
131
- assert dxy.shape == xy.shape
132
- dx, dy = dxy.chunk(2, dim=-1)
133
- assert dx.stride(-1) == 1
134
- assert dy.stride(-1) == 1
135
- if recompute_output:
136
- if out is None:
137
- out = torch.empty_like(x)
138
- else:
139
- out = out.reshape(-1, out.shape[-1])
140
- assert out.shape == x.shape
141
- assert out.stride(-1) == 1
142
- M, N = x.shape
143
- grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))
144
- with torch.cuda.device(x.device.index):
145
- _swiglu_bwd_kernel[grid](x, y, dout, out if recompute_output else None, dx, dy,
146
- x.stride(0), y.stride(0), dout.stride(0),
147
- out.stride(0) if recompute_output else 0,
148
- dx.stride(0), dy.stride(0),
149
- N)
150
- if not recompute_output:
151
- return dxy.reshape(*batch_shape, dxy.shape[-1])
152
- else:
153
- return dxy.reshape(*batch_shape, dxy.shape[-1]), out.reshape(*batch_shape, out.shape[-1])
154
-
155
-
156
- class SwiGLU(torch.autograd.Function):
157
-
158
- @staticmethod
159
- def forward(ctx, xy):
160
- ctx.save_for_backward(xy)
161
- return _swiglu_fwd(xy)
162
-
163
- @staticmethod
164
- def backward(ctx, dout):
165
- xy, = ctx.saved_tensors
166
- return _swiglu_bwd(xy, dout)
167
-
168
-
169
- swiglu = SwiGLU.apply
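The deleted k_activations.py exposes swiglu as a fused SwiGLU over a tensor whose last dimension packs x and y side by side. A small sketch (the import path is an assumption for this build; it needs a CUDA device with Triton installed):

    import torch
    import torch.nn.functional as F
    from mamba_ssm.ops.triton.k_activations import swiglu  # assumed import path

    xy = torch.randn(4, 256, device="cuda", dtype=torch.float16, requires_grad=True)  # last dim = 2 * d
    out = swiglu(xy)            # shape (4, 128): silu(x) * y with x, y = xy.chunk(2, dim=-1)
    out.sum().backward()        # runs the fused _swiglu_bwd kernel above

    # reference check against plain PyTorch
    x, y = xy.detach().float().chunk(2, dim=-1)
    ref = F.silu(x) * y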
build/torch210-cxx11-cu126-x86_64-linux/ops/triton/layer_norm.py DELETED
@@ -1,1113 +0,0 @@
1
- # Copyright (c) 2024, Tri Dao.
2
- # Implement dropout + residual + layer_norm / rms_norm.
3
-
4
- # Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
5
- # For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
6
- # This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
7
- # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
8
-
9
- import math
10
- import warnings
11
-
12
- import torch
13
- import torch.nn.functional as F
14
- from ...utils.torch import custom_bwd, custom_fwd
15
-
16
- import triton
17
- import triton.language as tl
18
-
19
-
20
- def layer_norm_ref(
21
- x,
22
- weight,
23
- bias,
24
- residual=None,
25
- x1=None,
26
- weight1=None,
27
- bias1=None,
28
- eps=1e-6,
29
- dropout_p=0.0,
30
- rowscale=None,
31
- prenorm=False,
32
- dropout_mask=None,
33
- dropout_mask1=None,
34
- upcast=False,
35
- ):
36
- dtype = x.dtype
37
- if upcast:
38
- x = x.float()
39
- weight = weight.float()
40
- bias = bias.float() if bias is not None else None
41
- residual = residual.float() if residual is not None else residual
42
- x1 = x1.float() if x1 is not None else None
43
- weight1 = weight1.float() if weight1 is not None else None
44
- bias1 = bias1.float() if bias1 is not None else None
45
- if x1 is not None:
46
- assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
47
- if rowscale is not None:
48
- x = x * rowscale[..., None]
49
- if dropout_p > 0.0:
50
- if dropout_mask is not None:
51
- x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
52
- else:
53
- x = F.dropout(x, p=dropout_p)
54
- if x1 is not None:
55
- if dropout_mask1 is not None:
56
- x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
57
- else:
58
- x1 = F.dropout(x1, p=dropout_p)
59
- if x1 is not None:
60
- x = x + x1
61
- if residual is not None:
62
- x = (x + residual).to(x.dtype)
63
- out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(
64
- dtype
65
- )
66
- if weight1 is None:
67
- return out if not prenorm else (out, x)
68
- else:
69
- out1 = F.layer_norm(
70
- x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps
71
- ).to(dtype)
72
- return (out, out1) if not prenorm else (out, out1, x)
73
-
74
-
75
- def rms_norm_ref(
76
- x,
77
- weight,
78
- bias,
79
- residual=None,
80
- x1=None,
81
- weight1=None,
82
- bias1=None,
83
- eps=1e-6,
84
- dropout_p=0.0,
85
- rowscale=None,
86
- prenorm=False,
87
- dropout_mask=None,
88
- dropout_mask1=None,
89
- upcast=False,
90
- ):
91
- dtype = x.dtype
92
- if upcast:
93
- x = x.float()
94
- weight = weight.float()
95
- bias = bias.float() if bias is not None else None
96
- residual = residual.float() if residual is not None else residual
97
- x1 = x1.float() if x1 is not None else None
98
- weight1 = weight1.float() if weight1 is not None else None
99
- bias1 = bias1.float() if bias1 is not None else None
100
- if x1 is not None:
101
- assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
102
- if rowscale is not None:
103
- x = x * rowscale[..., None]
104
- if dropout_p > 0.0:
105
- if dropout_mask is not None:
106
- x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
107
- else:
108
- x = F.dropout(x, p=dropout_p)
109
- if x1 is not None:
110
- if dropout_mask1 is not None:
111
- x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
112
- else:
113
- x1 = F.dropout(x1, p=dropout_p)
114
- if x1 is not None:
115
- x = x + x1
116
- if residual is not None:
117
- x = (x + residual).to(x.dtype)
118
- rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
119
- out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(dtype)
120
- if weight1 is None:
121
- return out if not prenorm else (out, x)
122
- else:
123
- out1 = ((x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)).to(
124
- dtype
125
- )
126
- return (out, out1) if not prenorm else (out, out1, x)
127
-
128
- def config_prune(configs):
129
-
130
- if torch.version.hip:
131
- try:
132
- # set warp size based on the GCN architecture
133
- gcn_arch_name = torch.cuda.get_device_properties(0).gcnArchName
134
- if "gfx10" in gcn_arch_name or "gfx11" in gcn_arch_name:
135
- # radeon
136
- warp_size = 32
137
- else:
138
- # instinct
139
- warp_size = 64
140
- except AttributeError as e:
141
- # fall back to crude method to set warp size
142
- device_name = torch.cuda.get_device_properties(0).name
143
- if 'instinct' in device_name.lower():
144
- warp_size = 64
145
- else:
146
- warp_size = 32
147
- warnings.warn(f"{e}, warp size set to {warp_size} based on device name: {device_name}", UserWarning)
148
-
149
- else:
150
- # cuda
151
- warp_size = 32
152
-
153
- max_block_sz = 1024
154
- max_num_warps = max_block_sz // warp_size
155
- pruned_configs = [config for config in configs if config.num_warps <= max_num_warps]
156
- return pruned_configs
157
-
158
- configs_autotune = [
159
- triton.Config({}, num_warps=1),
160
- triton.Config({}, num_warps=2),
161
- triton.Config({}, num_warps=4),
162
- triton.Config({}, num_warps=8),
163
- triton.Config({}, num_warps=16),
164
- triton.Config({}, num_warps=32),
165
- ]
166
-
167
- pruned_configs_autotune = config_prune(configs_autotune)
168
-
169
- @triton.autotune(
170
- configs = pruned_configs_autotune,
171
- key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
172
- )
173
- # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
174
- # @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
175
- @triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
176
- @triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
177
- @triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
178
- @triton.jit
179
- def _layer_norm_fwd_1pass_kernel(
180
- X, # pointer to the input
181
- Y, # pointer to the output
182
- W, # pointer to the weights
183
- B, # pointer to the biases
184
- RESIDUAL, # pointer to the residual
185
- X1,
186
- W1,
187
- B1,
188
- Y1,
189
- RESIDUAL_OUT, # pointer to the residual
190
- ROWSCALE,
191
- SEEDS, # Dropout seeds for each row
192
- DROPOUT_MASK,
193
- Mean, # pointer to the mean
194
- Rstd, # pointer to the 1/std
195
- stride_x_row, # how much to increase the pointer when moving by 1 row
196
- stride_y_row,
197
- stride_res_row,
198
- stride_res_out_row,
199
- stride_x1_row,
200
- stride_y1_row,
201
- M, # number of rows in X
202
- N, # number of columns in X
203
- eps, # epsilon to avoid division by zero
204
- dropout_p, # Dropout probability
205
- IS_RMS_NORM: tl.constexpr,
206
- BLOCK_N: tl.constexpr,
207
- HAS_RESIDUAL: tl.constexpr,
208
- STORE_RESIDUAL_OUT: tl.constexpr,
209
- HAS_BIAS: tl.constexpr,
210
- HAS_DROPOUT: tl.constexpr,
211
- STORE_DROPOUT_MASK: tl.constexpr,
212
- HAS_ROWSCALE: tl.constexpr,
213
- HAS_X1: tl.constexpr,
214
- HAS_W1: tl.constexpr,
215
- HAS_B1: tl.constexpr,
216
- ):
217
- # Map the program id to the row of X and Y it should compute.
218
- row = tl.program_id(0)
219
- X += row * stride_x_row
220
- Y += row * stride_y_row
221
- if HAS_RESIDUAL:
222
- RESIDUAL += row * stride_res_row
223
- if STORE_RESIDUAL_OUT:
224
- RESIDUAL_OUT += row * stride_res_out_row
225
- if HAS_X1:
226
- X1 += row * stride_x1_row
227
- if HAS_W1:
228
- Y1 += row * stride_y1_row
229
- # Compute mean and variance
230
- cols = tl.arange(0, BLOCK_N)
231
- x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
232
- if HAS_ROWSCALE:
233
- rowscale = tl.load(ROWSCALE + row).to(tl.float32)
234
- x *= rowscale
235
- if HAS_DROPOUT:
236
- # Compute dropout mask
237
- # 7 rounds is good enough, and reduces register pressure
238
- keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
239
- x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
240
- if STORE_DROPOUT_MASK:
241
- tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
242
- if HAS_X1:
243
- x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)
244
- if HAS_ROWSCALE:
245
- rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)
246
- x1 *= rowscale
247
- if HAS_DROPOUT:
248
- # Compute dropout mask
249
- # 7 rounds is good enough, and reduces register pressure
250
- keep_mask = (
251
- tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
252
- )
253
- x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
254
- if STORE_DROPOUT_MASK:
255
- tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)
256
- x += x1
257
- if HAS_RESIDUAL:
258
- residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
259
- x += residual
260
- if STORE_RESIDUAL_OUT:
261
- tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
262
- if not IS_RMS_NORM:
263
- mean = tl.sum(x, axis=0) / N
264
- tl.store(Mean + row, mean)
265
- xbar = tl.where(cols < N, x - mean, 0.0)
266
- var = tl.sum(xbar * xbar, axis=0) / N
267
- else:
268
- xbar = tl.where(cols < N, x, 0.0)
269
- var = tl.sum(xbar * xbar, axis=0) / N
270
- rstd = 1 / tl.sqrt(var + eps)
271
- tl.store(Rstd + row, rstd)
272
- # Normalize and apply linear transformation
273
- mask = cols < N
274
- w = tl.load(W + cols, mask=mask).to(tl.float32)
275
- if HAS_BIAS:
276
- b = tl.load(B + cols, mask=mask).to(tl.float32)
277
- x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
278
- y = x_hat * w + b if HAS_BIAS else x_hat * w
279
- # Write output
280
- tl.store(Y + cols, y, mask=mask)
281
- if HAS_W1:
282
- w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
283
- if HAS_B1:
284
- b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
285
- y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
286
- tl.store(Y1 + cols, y1, mask=mask)
287
-
288
-
289
- def _layer_norm_fwd(
290
- x,
291
- weight,
292
- bias,
293
- eps,
294
- residual=None,
295
- x1=None,
296
- weight1=None,
297
- bias1=None,
298
- dropout_p=0.0,
299
- rowscale=None,
300
- out_dtype=None,
301
- residual_dtype=None,
302
- is_rms_norm=False,
303
- return_dropout_mask=False,
304
- ):
305
- if residual is not None:
306
- residual_dtype = residual.dtype
307
- M, N = x.shape
308
- assert x.stride(-1) == 1
309
- if residual is not None:
310
- assert residual.stride(-1) == 1
311
- assert residual.shape == (M, N)
312
- assert weight.shape == (N,)
313
- assert weight.stride(-1) == 1
314
- if bias is not None:
315
- assert bias.stride(-1) == 1
316
- assert bias.shape == (N,)
317
- if x1 is not None:
318
- assert x1.shape == x.shape
319
- assert rowscale is None
320
- assert x1.stride(-1) == 1
321
- if weight1 is not None:
322
- assert weight1.shape == (N,)
323
- assert weight1.stride(-1) == 1
324
- if bias1 is not None:
325
- assert bias1.shape == (N,)
326
- assert bias1.stride(-1) == 1
327
- if rowscale is not None:
328
- assert rowscale.is_contiguous()
329
- assert rowscale.shape == (M,)
330
- # allocate output
331
- y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
332
- assert y.stride(-1) == 1
333
- if weight1 is not None:
334
- y1 = torch.empty_like(y)
335
- assert y1.stride(-1) == 1
336
- else:
337
- y1 = None
338
- if (
339
- residual is not None
340
- or (residual_dtype is not None and residual_dtype != x.dtype)
341
- or dropout_p > 0.0
342
- or rowscale is not None
343
- or x1 is not None
344
- ):
345
- residual_out = torch.empty(
346
- M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
347
- )
348
- assert residual_out.stride(-1) == 1
349
- else:
350
- residual_out = None
351
- mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
352
- rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
353
- if dropout_p > 0.0:
354
- seeds = torch.randint(
355
- 2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64
356
- )
357
- else:
358
- seeds = None
359
- if return_dropout_mask and dropout_p > 0.0:
360
- dropout_mask = torch.empty(M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool)
361
- else:
362
- dropout_mask = None
363
- # Less than 64KB per feature: enqueue fused kernel
364
- MAX_FUSED_SIZE = 65536 // x.element_size()
365
- BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
366
- if N > BLOCK_N:
367
- raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
368
- with torch.cuda.device(x.device.index):
369
- _layer_norm_fwd_1pass_kernel[(M,)](
370
- x,
371
- y,
372
- weight,
373
- bias,
374
- residual,
375
- x1,
376
- weight1,
377
- bias1,
378
- y1,
379
- residual_out,
380
- rowscale,
381
- seeds,
382
- dropout_mask,
383
- mean,
384
- rstd,
385
- x.stride(0),
386
- y.stride(0),
387
- residual.stride(0) if residual is not None else 0,
388
- residual_out.stride(0) if residual_out is not None else 0,
389
- x1.stride(0) if x1 is not None else 0,
390
- y1.stride(0) if y1 is not None else 0,
391
- M,
392
- N,
393
- eps,
394
- dropout_p,
395
- is_rms_norm,
396
- BLOCK_N,
397
- residual is not None,
398
- residual_out is not None,
399
- bias is not None,
400
- dropout_p > 0.0,
401
- dropout_mask is not None,
402
- rowscale is not None,
403
- )
404
- # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
405
- if dropout_mask is not None and x1 is not None:
406
- dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)
407
- else:
408
- dropout_mask1 = None
409
- return (
410
- y,
411
- y1,
412
- mean,
413
- rstd,
414
- residual_out if residual_out is not None else x,
415
- seeds,
416
- dropout_mask,
417
- dropout_mask1,
418
- )
419
-
420
-
421
- @triton.autotune(
422
- configs=pruned_configs_autotune,
423
- key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"],
424
- )
425
- # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
426
- # @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
427
- # @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
428
- @triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None})
429
- @triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None})
430
- @triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None})
431
- @triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None})
432
- @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
433
- @triton.jit
434
- def _layer_norm_bwd_kernel(
435
- X, # pointer to the input
436
- W, # pointer to the weights
437
- B, # pointer to the biases
438
- Y, # pointer to the output to be recomputed
439
- DY, # pointer to the output gradient
440
- DX, # pointer to the input gradient
441
- DW, # pointer to the partial sum of weights gradient
442
- DB, # pointer to the partial sum of biases gradient
443
- DRESIDUAL,
444
- W1,
445
- DY1,
446
- DX1,
447
- DW1,
448
- DB1,
449
- DRESIDUAL_IN,
450
- ROWSCALE,
451
- SEEDS,
452
- Mean, # pointer to the mean
453
- Rstd, # pointer to the 1/std
454
- stride_x_row, # how much to increase the pointer when moving by 1 row
455
- stride_y_row,
456
- stride_dy_row,
457
- stride_dx_row,
458
- stride_dres_row,
459
- stride_dy1_row,
460
- stride_dx1_row,
461
- stride_dres_in_row,
462
- M, # number of rows in X
463
- N, # number of columns in X
464
- eps, # epsilon to avoid division by zero
465
- dropout_p,
466
- rows_per_program,
467
- IS_RMS_NORM: tl.constexpr,
468
- BLOCK_N: tl.constexpr,
469
- HAS_DRESIDUAL: tl.constexpr,
470
- STORE_DRESIDUAL: tl.constexpr,
471
- HAS_BIAS: tl.constexpr,
472
- HAS_DROPOUT: tl.constexpr,
473
- HAS_ROWSCALE: tl.constexpr,
474
- HAS_DY1: tl.constexpr,
475
- HAS_DX1: tl.constexpr,
476
- HAS_B1: tl.constexpr,
477
- RECOMPUTE_OUTPUT: tl.constexpr,
478
- ):
479
- # Map the program id to the elements of X, DX, and DY it should compute.
480
- row_block_id = tl.program_id(0)
481
- row_start = row_block_id * rows_per_program
482
- # Do not early exit if row_start >= M, because we need to write DW and DB
483
- cols = tl.arange(0, BLOCK_N)
484
- mask = cols < N
485
- X += row_start * stride_x_row
486
- if HAS_DRESIDUAL:
487
- DRESIDUAL += row_start * stride_dres_row
488
- if STORE_DRESIDUAL:
489
- DRESIDUAL_IN += row_start * stride_dres_in_row
490
- DY += row_start * stride_dy_row
491
- DX += row_start * stride_dx_row
492
- if HAS_DY1:
493
- DY1 += row_start * stride_dy1_row
494
- if HAS_DX1:
495
- DX1 += row_start * stride_dx1_row
496
- if RECOMPUTE_OUTPUT:
497
- Y += row_start * stride_y_row
498
- w = tl.load(W + cols, mask=mask).to(tl.float32)
499
- if RECOMPUTE_OUTPUT and HAS_BIAS:
500
- b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
501
- if HAS_DY1:
502
- w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
503
- dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
504
- if HAS_BIAS:
505
- db = tl.zeros((BLOCK_N,), dtype=tl.float32)
506
- if HAS_DY1:
507
- dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
508
- if HAS_B1:
509
- db1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
510
- row_end = min((row_block_id + 1) * rows_per_program, M)
511
- for row in range(row_start, row_end):
512
- # Load data to SRAM
513
- x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
514
- dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
515
- if HAS_DY1:
516
- dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32)
517
- if not IS_RMS_NORM:
518
- mean = tl.load(Mean + row)
519
- rstd = tl.load(Rstd + row)
520
- # Compute dx
521
- xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
522
- xhat = tl.where(mask, xhat, 0.0)
523
- if RECOMPUTE_OUTPUT:
524
- y = xhat * w + b if HAS_BIAS else xhat * w
525
- tl.store(Y + cols, y, mask=mask)
526
- wdy = w * dy
527
- dw += dy * xhat
528
- if HAS_BIAS:
529
- db += dy
530
- if HAS_DY1:
531
- wdy += w1 * dy1
532
- dw1 += dy1 * xhat
533
- if HAS_B1:
534
- db1 += dy1
535
- if not IS_RMS_NORM:
536
- c1 = tl.sum(xhat * wdy, axis=0) / N
537
- c2 = tl.sum(wdy, axis=0) / N
538
- dx = (wdy - (xhat * c1 + c2)) * rstd
539
- else:
540
- c1 = tl.sum(xhat * wdy, axis=0) / N
541
- dx = (wdy - xhat * c1) * rstd
542
- if HAS_DRESIDUAL:
543
- dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
544
- dx += dres
545
- # Write dx
546
- if STORE_DRESIDUAL:
547
- tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
548
- if HAS_DX1:
549
- if HAS_DROPOUT:
550
- keep_mask = (
551
- tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
552
- )
553
- dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
554
- else:
555
- dx1 = dx
556
- tl.store(DX1 + cols, dx1, mask=mask)
557
- if HAS_DROPOUT:
558
- keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
559
- dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
560
- if HAS_ROWSCALE:
561
- rowscale = tl.load(ROWSCALE + row).to(tl.float32)
562
- dx *= rowscale
563
- tl.store(DX + cols, dx, mask=mask)
564
-
565
- X += stride_x_row
566
- if HAS_DRESIDUAL:
567
- DRESIDUAL += stride_dres_row
568
- if STORE_DRESIDUAL:
569
- DRESIDUAL_IN += stride_dres_in_row
570
- if RECOMPUTE_OUTPUT:
571
- Y += stride_y_row
572
- DY += stride_dy_row
573
- DX += stride_dx_row
574
- if HAS_DY1:
575
- DY1 += stride_dy1_row
576
- if HAS_DX1:
577
- DX1 += stride_dx1_row
578
- tl.store(DW + row_block_id * N + cols, dw, mask=mask)
579
- if HAS_BIAS:
580
- tl.store(DB + row_block_id * N + cols, db, mask=mask)
581
- if HAS_DY1:
582
- tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask)
583
- if HAS_B1:
584
- tl.store(DB1 + row_block_id * N + cols, db1, mask=mask)
585
-
586
-
587
- def _layer_norm_bwd(
588
- dy,
589
- x,
590
- weight,
591
- bias,
592
- eps,
593
- mean,
594
- rstd,
595
- dresidual=None,
596
- dy1=None,
597
- weight1=None,
598
- bias1=None,
599
- seeds=None,
600
- dropout_p=0.0,
601
- rowscale=None,
602
- has_residual=False,
603
- has_x1=False,
604
- is_rms_norm=False,
605
- x_dtype=None,
606
- recompute_output=False,
607
- ):
608
- M, N = x.shape
609
- assert x.stride(-1) == 1
610
- assert dy.stride(-1) == 1
611
- assert dy.shape == (M, N)
612
- if dresidual is not None:
613
- assert dresidual.stride(-1) == 1
614
- assert dresidual.shape == (M, N)
615
- assert weight.shape == (N,)
616
- assert weight.stride(-1) == 1
617
- if bias is not None:
618
- assert bias.stride(-1) == 1
619
- assert bias.shape == (N,)
620
- if dy1 is not None:
621
- assert weight1 is not None
622
- assert dy1.shape == dy.shape
623
- assert dy1.stride(-1) == 1
624
- if weight1 is not None:
625
- assert weight1.shape == (N,)
626
- assert weight1.stride(-1) == 1
627
- if bias1 is not None:
628
- assert bias1.shape == (N,)
629
- assert bias1.stride(-1) == 1
630
- if seeds is not None:
631
- assert seeds.is_contiguous()
632
- assert seeds.shape == (M if not has_x1 else M * 2,)
633
- if rowscale is not None:
634
- assert rowscale.is_contiguous()
635
- assert rowscale.shape == (M,)
636
- # allocate output
637
- dx = (
638
- torch.empty_like(x)
639
- if x_dtype is None
640
- else torch.empty(M, N, dtype=x_dtype, device=x.device)
641
- )
642
- dresidual_in = (
643
- torch.empty_like(x)
644
- if has_residual
645
- and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1)
646
- else None
647
- )
648
- dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None
649
- y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
650
- if recompute_output:
651
- assert weight1 is None, "recompute_output is not supported with parallel LayerNorm"
652
-
653
- # Less than 64KB per feature: enqueue fused kernel
654
- MAX_FUSED_SIZE = 65536 // x.element_size()
655
- BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
656
- if N > BLOCK_N:
657
- raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
658
- sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
659
- _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
660
- _db = (
661
- torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
662
- if bias is not None
663
- else None
664
- )
665
- _dw1 = torch.empty_like(_dw) if weight1 is not None else None
666
- _db1 = torch.empty_like(_db) if bias1 is not None else None
667
- rows_per_program = math.ceil(M / sm_count)
668
- grid = (sm_count,)
669
- with torch.cuda.device(x.device.index):
670
- _layer_norm_bwd_kernel[grid](
671
- x,
672
- weight,
673
- bias,
674
- y,
675
- dy,
676
- dx,
677
- _dw,
678
- _db,
679
- dresidual,
680
- weight1,
681
- dy1,
682
- dx1,
683
- _dw1,
684
- _db1,
685
- dresidual_in,
686
- rowscale,
687
- seeds,
688
- mean,
689
- rstd,
690
- x.stride(0),
691
- 0 if not recompute_output else y.stride(0),
692
- dy.stride(0),
693
- dx.stride(0),
694
- dresidual.stride(0) if dresidual is not None else 0,
695
- dy1.stride(0) if dy1 is not None else 0,
696
- dx1.stride(0) if dx1 is not None else 0,
697
- dresidual_in.stride(0) if dresidual_in is not None else 0,
698
- M,
699
- N,
700
- eps,
701
- dropout_p,
702
- rows_per_program,
703
- is_rms_norm,
704
- BLOCK_N,
705
- dresidual is not None,
706
- dresidual_in is not None,
707
- bias is not None,
708
- dropout_p > 0.0,
709
- )
710
- dw = _dw.sum(0).to(weight.dtype)
711
- db = _db.sum(0).to(bias.dtype) if bias is not None else None
712
- dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None
713
- db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None
714
- # Don't need to compute dresidual_in separately in this case
715
- if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
716
- dresidual_in = dx
717
- if has_x1 and dropout_p == 0.0:
718
- dx1 = dx
719
- return (
720
- (dx, dw, db, dresidual_in, dx1, dw1, db1)
721
- if not recompute_output
722
- else (dx, dw, db, dresidual_in, dx1, dw1, db1, y)
723
- )
724
-
725
-
726
- class LayerNormFn(torch.autograd.Function):
727
- @staticmethod
728
- def forward(
729
- ctx,
730
- x,
731
- weight,
732
- bias,
733
- residual=None,
734
- x1=None,
735
- weight1=None,
736
- bias1=None,
737
- eps=1e-6,
738
- dropout_p=0.0,
739
- rowscale=None,
740
- prenorm=False,
741
- residual_in_fp32=False,
742
- is_rms_norm=False,
743
- return_dropout_mask=False,
744
- ):
745
- x_shape_og = x.shape
746
- # reshape input data into 2D tensor
747
- x = x.reshape(-1, x.shape[-1])
748
- if x.stride(-1) != 1:
749
- x = x.contiguous()
750
- if residual is not None:
751
- assert residual.shape == x_shape_og
752
- residual = residual.reshape(-1, residual.shape[-1])
753
- if residual.stride(-1) != 1:
754
- residual = residual.contiguous()
755
- if x1 is not None:
756
- assert x1.shape == x_shape_og
757
- assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
758
- x1 = x1.reshape(-1, x1.shape[-1])
759
- if x1.stride(-1) != 1:
760
- x1 = x1.contiguous()
761
- weight = weight.contiguous()
762
- if bias is not None:
763
- bias = bias.contiguous()
764
- if weight1 is not None:
765
- weight1 = weight1.contiguous()
766
- if bias1 is not None:
767
- bias1 = bias1.contiguous()
768
- if rowscale is not None:
769
- rowscale = rowscale.reshape(-1).contiguous()
770
- residual_dtype = (
771
- residual.dtype
772
- if residual is not None
773
- else (torch.float32 if residual_in_fp32 else None)
774
- )
775
- y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(
776
- x,
777
- weight,
778
- bias,
779
- eps,
780
- residual,
781
- x1,
782
- weight1,
783
- bias1,
784
- dropout_p=dropout_p,
785
- rowscale=rowscale,
786
- residual_dtype=residual_dtype,
787
- is_rms_norm=is_rms_norm,
788
- return_dropout_mask=return_dropout_mask,
789
- )
790
- ctx.save_for_backward(
791
- residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
792
- )
793
- ctx.x_shape_og = x_shape_og
794
- ctx.eps = eps
795
- ctx.dropout_p = dropout_p
796
- ctx.is_rms_norm = is_rms_norm
797
- ctx.has_residual = residual is not None
798
- ctx.has_x1 = x1 is not None
799
- ctx.prenorm = prenorm
800
- ctx.x_dtype = x.dtype
801
- y = y.reshape(x_shape_og)
802
- y1 = y1.reshape(x_shape_og) if y1 is not None else None
803
- residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None
804
- dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
805
- dropout_mask1 = dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None
806
- if not return_dropout_mask:
807
- if weight1 is None:
808
- return y if not prenorm else (y, residual_out)
809
- else:
810
- return (y, y1) if not prenorm else (y, y1, residual_out)
811
- else:
812
- if weight1 is None:
813
- return (
814
- (y, dropout_mask, dropout_mask1)
815
- if not prenorm
816
- else (y, residual_out, dropout_mask, dropout_mask1)
817
- )
818
- else:
819
- return (
820
- (y, y1, dropout_mask, dropout_mask1)
821
- if not prenorm
822
- else (y, y1, residual_out, dropout_mask, dropout_mask1)
823
- )
824
-
825
- @staticmethod
826
- def backward(ctx, dy, *args):
827
- x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors
828
- dy = dy.reshape(-1, dy.shape[-1])
829
- if dy.stride(-1) != 1:
830
- dy = dy.contiguous()
831
- assert dy.shape == x.shape
832
- if weight1 is not None:
833
- dy1, args = args[0], args[1:]
834
- dy1 = dy1.reshape(-1, dy1.shape[-1])
835
- if dy1.stride(-1) != 1:
836
- dy1 = dy1.contiguous()
837
- assert dy1.shape == x.shape
838
- else:
839
- dy1 = None
840
- if ctx.prenorm:
841
- dresidual = args[0]
842
- dresidual = dresidual.reshape(-1, dresidual.shape[-1])
843
- if dresidual.stride(-1) != 1:
844
- dresidual = dresidual.contiguous()
845
- assert dresidual.shape == x.shape
846
- else:
847
- dresidual = None
848
- dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd(
849
- dy,
850
- x,
851
- weight,
852
- bias,
853
- ctx.eps,
854
- mean,
855
- rstd,
856
- dresidual,
857
- dy1,
858
- weight1,
859
- bias1,
860
- seeds,
861
- ctx.dropout_p,
862
- rowscale,
863
- ctx.has_residual,
864
- ctx.has_x1,
865
- ctx.is_rms_norm,
866
- x_dtype=ctx.x_dtype,
867
- )
868
- return (
869
- dx.reshape(ctx.x_shape_og),
870
- dw,
871
- db,
872
- dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
873
- dx1.reshape(ctx.x_shape_og) if dx1 is not None else None,
874
- dw1,
875
- db1,
876
- None,
877
- None,
878
- None,
879
- None,
880
- None,
881
- None,
882
- None,
883
- )
884
-
885
-
886
- def layer_norm_fn(
887
- x,
888
- weight,
889
- bias,
890
- residual=None,
891
- x1=None,
892
- weight1=None,
893
- bias1=None,
894
- eps=1e-6,
895
- dropout_p=0.0,
896
- rowscale=None,
897
- prenorm=False,
898
- residual_in_fp32=False,
899
- is_rms_norm=False,
900
- return_dropout_mask=False,
901
- ):
902
- return LayerNormFn.apply(
903
- x,
904
- weight,
905
- bias,
906
- residual,
907
- x1,
908
- weight1,
909
- bias1,
910
- eps,
911
- dropout_p,
912
- rowscale,
913
- prenorm,
914
- residual_in_fp32,
915
- is_rms_norm,
916
- return_dropout_mask,
917
- )
918
-
919
-
920
- def rms_norm_fn(
921
- x,
922
- weight,
923
- bias,
924
- residual=None,
925
- x1=None,
926
- weight1=None,
927
- bias1=None,
928
- eps=1e-6,
929
- dropout_p=0.0,
930
- rowscale=None,
931
- prenorm=False,
932
- residual_in_fp32=False,
933
- return_dropout_mask=False,
934
- ):
935
- return LayerNormFn.apply(
936
- x,
937
- weight,
938
- bias,
939
- residual,
940
- x1,
941
- weight1,
942
- bias1,
943
- eps,
944
- dropout_p,
945
- rowscale,
946
- prenorm,
947
- residual_in_fp32,
948
- True,
949
- return_dropout_mask,
950
- )
951
-
952
-
953
- class RMSNorm(torch.nn.Module):
954
-
955
- def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, dtype=None):
956
- factory_kwargs = {"device": device, "dtype": dtype}
957
- super().__init__()
958
- self.eps = eps
959
- if dropout_p > 0.0:
960
- self.drop = torch.nn.Dropout(dropout_p)
961
- else:
962
- self.drop = None
963
- self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
964
- self.register_parameter("bias", None)
965
- self.reset_parameters()
966
-
967
- def reset_parameters(self):
968
- torch.nn.init.ones_(self.weight)
969
-
970
- def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
971
- return rms_norm_fn(
972
- x,
973
- self.weight,
974
- self.bias,
975
- residual=residual,
976
- eps=self.eps,
977
- dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,
978
- prenorm=prenorm,
979
- residual_in_fp32=residual_in_fp32,
980
- )
981
-
982
-
983
- class LayerNormLinearFn(torch.autograd.Function):
984
- @staticmethod
985
- @custom_fwd
986
- def forward(
987
- ctx,
988
- x,
989
- norm_weight,
990
- norm_bias,
991
- linear_weight,
992
- linear_bias,
993
- residual=None,
994
- eps=1e-6,
995
- prenorm=False,
996
- residual_in_fp32=False,
997
- is_rms_norm=False,
998
- ):
999
- x_shape_og = x.shape
1000
- # reshape input data into 2D tensor
1001
- x = x.reshape(-1, x.shape[-1])
1002
- if x.stride(-1) != 1:
1003
- x = x.contiguous()
1004
- if residual is not None:
1005
- assert residual.shape == x_shape_og
1006
- residual = residual.reshape(-1, residual.shape[-1])
1007
- if residual.stride(-1) != 1:
1008
- residual = residual.contiguous()
1009
- norm_weight = norm_weight.contiguous()
1010
- if norm_bias is not None:
1011
- norm_bias = norm_bias.contiguous()
1012
- residual_dtype = (
1013
- residual.dtype
1014
- if residual is not None
1015
- else (torch.float32 if residual_in_fp32 else None)
1016
- )
1017
- y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd(
1018
- x,
1019
- norm_weight,
1020
- norm_bias,
1021
- eps,
1022
- residual,
1023
- out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(),
1024
- residual_dtype=residual_dtype,
1025
- is_rms_norm=is_rms_norm,
1026
- )
1027
- y = y.reshape(x_shape_og)
1028
- dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
1029
- linear_weight = linear_weight.to(dtype)
1030
- linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
1031
- out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
1032
- # We don't store y, will be recomputed in the backward pass to save memory
1033
- ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd)
1034
- ctx.x_shape_og = x_shape_og
1035
- ctx.eps = eps
1036
- ctx.is_rms_norm = is_rms_norm
1037
- ctx.has_residual = residual is not None
1038
- ctx.prenorm = prenorm
1039
- ctx.x_dtype = x.dtype
1040
- ctx.linear_bias_is_none = linear_bias is None
1041
- return out if not prenorm else (out, residual_out.reshape(x_shape_og))
1042
-
1043
- @staticmethod
1044
- @custom_bwd
1045
- def backward(ctx, dout, *args):
1046
- x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
1047
- dout = dout.reshape(-1, dout.shape[-1])
1048
- dy = F.linear(dout, linear_weight.t())
1049
- dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
1050
- if dy.stride(-1) != 1:
1051
- dy = dy.contiguous()
1052
- assert dy.shape == x.shape
1053
- if ctx.prenorm:
1054
- dresidual = args[0]
1055
- dresidual = dresidual.reshape(-1, dresidual.shape[-1])
1056
- if dresidual.stride(-1) != 1:
1057
- dresidual = dresidual.contiguous()
1058
- assert dresidual.shape == x.shape
1059
- else:
1060
- dresidual = None
1061
- dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd(
1062
- dy,
1063
- x,
1064
- norm_weight,
1065
- norm_bias,
1066
- ctx.eps,
1067
- mean,
1068
- rstd,
1069
- dresidual=dresidual,
1070
- has_residual=ctx.has_residual,
1071
- is_rms_norm=ctx.is_rms_norm,
1072
- x_dtype=ctx.x_dtype,
1073
- recompute_output=True,
1074
- )
1075
- dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
1076
- return (
1077
- dx.reshape(ctx.x_shape_og),
1078
- dnorm_weight,
1079
- dnorm_bias,
1080
- dlinear_weight,
1081
- dlinear_bias,
1082
- dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
1083
- None,
1084
- None,
1085
- None,
1086
- None,
1087
- )
1088
-
1089
-
1090
- def layer_norm_linear_fn(
1091
- x,
1092
- norm_weight,
1093
- norm_bias,
1094
- linear_weight,
1095
- linear_bias,
1096
- residual=None,
1097
- eps=1e-6,
1098
- prenorm=False,
1099
- residual_in_fp32=False,
1100
- is_rms_norm=False,
1101
- ):
1102
- return LayerNormLinearFn.apply(
1103
- x,
1104
- norm_weight,
1105
- norm_bias,
1106
- linear_weight,
1107
- linear_bias,
1108
- residual,
1109
- eps,
1110
- prenorm,
1111
- residual_in_fp32,
1112
- is_rms_norm,
1113
- )
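The deleted layer_norm.py bundles dropout, residual add, and LayerNorm/RMSNorm into one Triton kernel. A minimal sketch of the RMSNorm module and the functional entry point (the import path is an assumption for this build; it needs a CUDA device with Triton installed):

    import torch
    from mamba_ssm.ops.triton.layer_norm import RMSNorm, rms_norm_fn  # assumed import path

    hidden = 1024
    norm = RMSNorm(hidden, eps=1e-5).to("cuda", dtype=torch.float16)
    x = torch.randn(2, 128, hidden, device="cuda", dtype=torch.float16)
    residual = torch.randn_like(x)

    # fused residual add + RMSNorm; prenorm=True also returns the updated residual stream
    y, new_residual = norm(x, residual=residual, prenorm=True, residual_in_fp32=True)

    # functional form without a residual; matches rms_norm_ref up to numerics
    y2 = rms_norm_fn(x, norm.weight, norm.bias, eps=1e-5)

layer_norm_linear_fn additionally fuses the following linear projection and recomputes the normalized activations in the backward pass instead of storing them.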
build/torch210-cxx11-cu126-x86_64-linux/ops/triton/layernorm_gated.py DELETED
@@ -1,437 +0,0 @@
1
- # Copyright (c) 2024, Tri Dao.
2
- # Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
3
- # For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
4
- # This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
5
- # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
6
-
7
- import math
8
-
9
- import torch
10
- import torch.nn.functional as F
11
-
12
- import triton
13
- import triton.language as tl
14
-
15
- from einops import rearrange
16
-
17
-
18
- def rms_norm_ref(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, upcast=True):
19
- dtype = x.dtype
20
- N = x.shape[-1]
21
- weight = weight.float()
22
- bias = bias.float() if bias is not None else None
23
- if upcast:
24
- x = x.float()
25
- z = z.float() if z is not None else z
26
- if z is not None and not norm_before_gate:
27
- x = x * F.silu(z)
28
- if group_size is None:
29
- rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
30
- out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
31
- else:
32
- x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
33
- rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps)
34
- out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight
35
- if bias is not None:
36
- out = out + bias
37
- if z is not None and norm_before_gate:
38
- out *= F.silu(z)
39
- return out.to(dtype)
40
-
41
-
42
- @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
43
- @triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
44
- @triton.jit
45
- def _layer_norm_fwd_1pass_kernel(
46
- X, # pointer to the input
47
- Y, # pointer to the output
48
- W, # pointer to the weights
49
- B, # pointer to the biases
50
- Z, # pointer to the other branch
51
- Mean, # pointer to the mean
52
- Rstd, # pointer to the 1/std
53
- stride_x_row, # how much to increase the pointer when moving by 1 row
54
- stride_y_row,
55
- stride_z_row,
56
- M, # number of rows in X
57
- N, # number of columns in X
58
- eps, # epsilon to avoid division by zero
59
- BLOCK_N: tl.constexpr,
60
- HAS_BIAS: tl.constexpr,
61
- HAS_Z: tl.constexpr,
62
- NORM_BEFORE_GATE: tl.constexpr,
63
- IS_RMS_NORM: tl.constexpr,
64
- ):
65
- # Map the program id to the row of X and Y it should compute.
66
- row = tl.program_id(0)
67
- group = tl.program_id(1)
68
- X += row * stride_x_row + group * N
69
- Y += row * stride_y_row + group * N
70
- if HAS_Z:
71
- Z += row * stride_z_row + group * N
72
- if not IS_RMS_NORM:
73
- Mean += group * M
74
- Rstd += group * M
75
- W += group * N
76
- if HAS_BIAS:
77
- B += group * N
78
- # Compute mean and variance
79
- cols = tl.arange(0, BLOCK_N)
80
- x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
81
- if HAS_Z and not NORM_BEFORE_GATE:
82
- z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
83
- x *= z * tl.sigmoid(z)
84
- if not IS_RMS_NORM:
85
- mean = tl.sum(x, axis=0) / N
86
- tl.store(Mean + row, mean)
87
- xbar = tl.where(cols < N, x - mean, 0.)
88
- var = tl.sum(xbar * xbar, axis=0) / N
89
- else:
90
- xbar = tl.where(cols < N, x, 0.)
91
- var = tl.sum(xbar * xbar, axis=0) / N
92
- rstd = 1 / tl.sqrt(var + eps)
93
- tl.store(Rstd + row, rstd)
94
- # Normalize and apply linear transformation
95
- mask = cols < N
96
- w = tl.load(W + cols, mask=mask).to(tl.float32)
97
- if HAS_BIAS:
98
- b = tl.load(B + cols, mask=mask).to(tl.float32)
99
- x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
100
- y = x_hat * w + b if HAS_BIAS else x_hat * w
101
- if HAS_Z and NORM_BEFORE_GATE:
102
- z = tl.load(Z + cols, mask=mask).to(tl.float32)
103
- y *= z * tl.sigmoid(z)
104
- # Write output
105
- tl.store(Y + cols, y, mask=mask)
106
-
107
-
108
- def _layer_norm_fwd(x, weight, bias, eps, z=None, out=None, group_size=None, norm_before_gate=True, is_rms_norm=False):
109
- M, N = x.shape
110
- if group_size is None:
111
- group_size = N
112
- assert N % group_size == 0
113
- ngroups = N // group_size
114
- assert x.stride(-1) == 1
115
- if z is not None:
116
- assert z.stride(-1) == 1
117
- assert z.shape == (M, N)
118
- assert weight.shape == (N,)
119
- assert weight.stride(-1) == 1
120
- if bias is not None:
121
- assert bias.stride(-1) == 1
122
- assert bias.shape == (N,)
123
- # allocate output
124
- if out is not None:
125
- assert out.shape == x.shape
126
- else:
127
- out = torch.empty_like(x)
128
- assert out.stride(-1) == 1
129
- mean = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device) if not is_rms_norm else None
130
- rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)
131
- # Less than 64KB per feature: enqueue fused kernel
132
- MAX_FUSED_SIZE = 65536 // x.element_size()
133
- BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
134
- if group_size > BLOCK_N:
135
- raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
136
- # heuristics for number of warps
137
- num_warps = min(max(BLOCK_N // 256, 1), 8)
138
- grid = (M, ngroups)
139
- with torch.cuda.device(x.device.index):
140
- _layer_norm_fwd_1pass_kernel[grid](x, out, weight, bias, z, mean, rstd,
141
- x.stride(0), out.stride(0), z.stride(0) if z is not None else 0,
142
- M, group_size, eps,
143
- BLOCK_N=BLOCK_N,
144
- NORM_BEFORE_GATE=norm_before_gate,
145
- IS_RMS_NORM=is_rms_norm,
146
- num_warps=num_warps)
147
- return out, mean, rstd
148
-
149
-
150
-
151
- @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
152
- @triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
153
- @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
154
- @triton.jit
155
- def _layer_norm_bwd_kernel(
156
- X, # pointer to the input
157
- W, # pointer to the weights
158
- B, # pointer to the biases
159
- Z, # pointer to the other branch
160
- Y, # pointer to the output to be recomputed
161
- DY, # pointer to the output gradient
162
- DX, # pointer to the input gradient
163
- DW, # pointer to the partial sum of weights gradient
164
- DB, # pointer to the partial sum of biases gradient
165
- DZ, # pointer to the other branch
166
- Mean, # pointer to the mean
167
- Rstd, # pointer to the 1/std
168
- stride_x_row, # how much to increase the pointer when moving by 1 row
169
- stride_z_row,
170
- stride_y_row,
171
- stride_dy_row,
172
- stride_dx_row,
173
- stride_dz_row,
174
- stride_dw_row,
175
- stride_db_row,
176
- M, # number of rows in X
177
- N, # number of columns in X
178
- eps, # epsilon to avoid division by zero
179
- rows_per_program,
180
- NORM_BEFORE_GATE: tl.constexpr,
181
- IS_RMS_NORM: tl.constexpr,
182
- HAS_BIAS: tl.constexpr,
183
- HAS_Z: tl.constexpr,
184
- RECOMPUTE_OUTPUT: tl.constexpr,
185
- BLOCK_N: tl.constexpr,
186
- ):
187
- # Map the program id to the elements of X, DX, and DY it should compute.
188
- row_block_id = tl.program_id(0)
189
- group = tl.program_id(1)
190
- row_start = row_block_id * rows_per_program
191
- cols = tl.arange(0, BLOCK_N)
192
- mask = cols < N
193
- X += row_start * stride_x_row + group * N
194
- if HAS_Z:
195
- Z += row_start * stride_z_row + group * N
196
- DZ += row_start * stride_dz_row + group * N
197
- DY += row_start * stride_dy_row + group * N
198
- DX += row_start * stride_dx_row + group * N
199
- if RECOMPUTE_OUTPUT:
200
- Y += row_start * stride_y_row + group * N
201
- if not IS_RMS_NORM:
202
- Mean += group * M
203
- Rstd += group * M
204
- W += group * N
205
- w = tl.load(W + cols, mask=mask).to(tl.float32)
206
- if (RECOMPUTE_OUTPUT or HAS_Z) and HAS_BIAS:
207
- B += group * N
208
- b = tl.load(B + cols, mask=mask, other=0.).to(tl.float32)
209
- dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
210
- if HAS_BIAS:
211
- db = tl.zeros((BLOCK_N,), dtype=tl.float32)
212
- row_end = min((row_block_id + 1) * rows_per_program, M)
213
- for row in range(row_start, row_end):
214
- # Load data to SRAM
215
- x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
216
- dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
217
- if not IS_RMS_NORM:
218
- mean = tl.load(Mean + row)
219
- if HAS_Z and not NORM_BEFORE_GATE:
220
- z = tl.load(Z + cols, mask=mask, other=0.).to(tl.float32)
221
- x_og = x
222
- x = x_og * z * tl.sigmoid(z)
223
- rstd = tl.load(Rstd + row)
224
- # Compute dx
225
- xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
226
- xhat = tl.where(mask, xhat, 0.)
227
- if HAS_Z and NORM_BEFORE_GATE:
228
- z = tl.load(Z + cols, mask=mask, other=0.).to(tl.float32)
229
- z_sigmoid = tl.sigmoid(z)
230
- y = xhat * w + b if HAS_BIAS else xhat * w
231
- if RECOMPUTE_OUTPUT:
232
- tl.store(Y + cols, y * z * z_sigmoid, mask=mask)
233
- dz = dy * y * z_sigmoid * (1 + z * (1 - z_sigmoid))
234
- tl.store(DZ + cols, dz, mask=mask)
235
- dy *= z * z_sigmoid
236
- else:
237
- if RECOMPUTE_OUTPUT:
238
- y = xhat * w + b if HAS_BIAS else xhat * w
239
- tl.store(Y + cols, y, mask=mask)
240
- wdy = w * dy
241
- c1 = tl.sum(xhat * wdy, axis=0) / N
242
- if not IS_RMS_NORM:
243
- c2 = tl.sum(wdy, axis=0) / N
244
- dx = (wdy - (xhat * c1 + c2)) * rstd
245
- else:
246
- dx = (wdy - xhat * c1) * rstd
247
- dw += dy * xhat
248
- if HAS_BIAS:
249
- db += dy
250
- if HAS_Z and not NORM_BEFORE_GATE:
251
- z_sigmoid = tl.sigmoid(z)
252
- dz = dx * x_og * z_sigmoid * (1 + z * (1 - z_sigmoid))
253
- tl.store(DZ + cols, dz, mask=mask)
254
- dx *= z * z_sigmoid
255
- # Write dx
256
- tl.store(DX + cols, dx, mask=mask)
257
-
258
- X += stride_x_row
259
- if HAS_Z:
260
- Z += stride_z_row
261
- DZ += stride_dz_row
262
- if RECOMPUTE_OUTPUT:
263
- Y += stride_y_row
264
- DY += stride_dy_row
265
- DX += stride_dx_row
266
- tl.store(DW + row_block_id * stride_dw_row + group * N + cols, dw, mask=mask)
267
- if HAS_BIAS:
268
- tl.store(DB + row_block_id * stride_db_row + group * N + cols, db, mask=mask)
269
-
270
-
271
- def _layer_norm_bwd(dy, x, weight, bias, eps, mean, rstd, z=None, group_size=None,
272
- norm_before_gate=True, is_rms_norm=False, recompute_output=False, dz=None, out=None):
273
- M, N = x.shape
274
- if group_size is None:
275
- group_size = N
276
- assert N % group_size == 0
277
- ngroups = N // group_size
278
- assert x.stride(-1) == 1
279
- assert dy.stride(-1) == 1
280
- assert dy.shape == (M, N)
281
- if z is not None:
282
- assert z.stride(-1) == 1
283
- assert z.shape == (M, N)
284
- assert weight.shape == (N,)
285
- assert weight.stride(-1) == 1
286
- if bias is not None:
287
- assert bias.stride(-1) == 1
288
- assert bias.shape == (N,)
289
- # allocate output
290
- dx = torch.empty_like(x)
291
- if dz is not None:
292
- assert z is not None
293
- assert dz.shape == z.shape
294
- assert dz.stride(-1) == 1
295
- else:
296
- dz = torch.empty_like(z) if z is not None else None
297
- if recompute_output:
298
- if out is None:
299
- out = torch.empty_like(x)
300
- assert out.shape == x.shape
301
-
302
- # Less than 64KB per feature: enqueue fused kernel
303
- MAX_FUSED_SIZE = 65536 // x.element_size()
304
- BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
305
- if group_size > BLOCK_N:
306
- raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
307
- # heuristics for number of warps
308
- num_warps = min(max(BLOCK_N // 256, 1), 8)
309
- sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
310
- # If group size is small (e.g., 64), we're only using 1 warp. So having just 108 programs
311
- # would limit the occupancy.
312
- nrow_groups = math.ceil(sm_count * math.ceil(4 / num_warps) / ngroups)
313
- _dw = torch.empty((nrow_groups, N), dtype=torch.float32, device=weight.device)
314
- _db = torch.empty((nrow_groups, N), dtype=torch.float32, device=bias.device) if bias is not None else None
315
- rows_per_program = math.ceil(M / nrow_groups)
316
- grid = (nrow_groups, ngroups)
317
- with torch.cuda.device(x.device.index):
318
- _layer_norm_bwd_kernel[grid](x, weight, bias, z, out if recompute_output else None,
319
- dy, dx, _dw, _db, dz, mean, rstd,
320
- x.stride(0),
321
- z.stride(0) if z is not None else 0,
322
- 0 if not recompute_output else out.stride(0),
323
- dy.stride(0), dx.stride(0),
324
- dz.stride(0) if dz is not None else 0,
325
- _dw.stride(0),
326
- _db.stride(0) if _db is not None else 0,
327
- M, group_size, eps,
328
- rows_per_program,
329
- BLOCK_N=BLOCK_N,
330
- NORM_BEFORE_GATE=norm_before_gate,
331
- IS_RMS_NORM=is_rms_norm,
332
- num_warps=num_warps)
333
- dw = _dw.sum(0).to(weight.dtype)
334
- db = _db.sum(0).to(bias.dtype) if bias is not None else None
335
- return (dx, dw, db, dz) if not recompute_output else (dx, dw, db, dz, out)
336
-
337
-
338
- class LayerNormFn(torch.autograd.Function):
339
-
340
- @staticmethod
341
- def forward(ctx, x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True,
342
- is_rms_norm=False):
343
- """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
344
- """
345
-
346
- x_shape_og = x.shape
347
- # reshape input data into 2D tensor
348
- x = x.reshape(-1, x.shape[-1])
349
- if x.stride(-1) != 1:
350
- x = x.contiguous()
351
- if z is not None:
352
- assert z.shape == x_shape_og
353
- z = z.reshape(-1, z.shape[-1])
354
- if z.stride(-1) != 1:
355
- z = z.contiguous()
356
- weight = weight.contiguous()
357
- if bias is not None:
358
- bias = bias.contiguous()
359
- y, mean, rstd = _layer_norm_fwd(x, weight, bias, eps, z=z, group_size=group_size, norm_before_gate=norm_before_gate, is_rms_norm=is_rms_norm)
360
- ctx.save_for_backward(x, weight, bias, mean, rstd, z)
361
- ctx.x_shape_og = x_shape_og
362
- ctx.eps = eps
363
- ctx.group_size = group_size
364
- ctx.norm_before_gate = norm_before_gate
365
- ctx.is_rms_norm = is_rms_norm
366
- return y.reshape(x_shape_og)
367
-
368
- @staticmethod
369
- def backward(ctx, dy):
370
- x, weight, bias, mean, rstd, z = ctx.saved_tensors
371
- dy = dy.reshape(-1, dy.shape[-1])
372
- if dy.stride(-1) != 1:
373
- dy = dy.contiguous()
374
- assert dy.shape == x.shape
375
- dx, dw, db, dz = _layer_norm_bwd(dy, x, weight, bias, ctx.eps, mean, rstd, z, ctx.group_size,
376
- ctx.norm_before_gate, ctx.is_rms_norm)
377
- return dx.reshape(ctx.x_shape_og), dw, db, dz.reshape(ctx.x_shape_og) if dz is not None else None, None, None, None, None
378
-
379
-
380
- def layernorm_fn(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, is_rms_norm=False):
381
- return LayerNormFn.apply(x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm)
382
-
383
-
384
- def rmsnorm_fn(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True):
385
- return LayerNormFn.apply(x, weight, bias, z, eps, group_size, norm_before_gate, True)
386
-
387
-
388
- class LayerNorm(torch.nn.Module):
389
-
390
- def __init__(self, hidden_size, eps=1e-5, group_size=None, norm_before_gate=True, device=None, dtype=None):
391
- """If group_size is not None, we do GroupNorm with each group having group_size elements.
392
- group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
393
- """
394
-
395
- factory_kwargs = {"device": device, "dtype": dtype}
396
- super().__init__()
397
- self.eps = eps
398
- self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
399
- self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
400
- self.group_size = group_size
401
- self.norm_before_gate = norm_before_gate
402
- self.reset_parameters()
403
-
404
- def reset_parameters(self):
405
- torch.nn.init.ones_(self.weight)
406
- torch.nn.init.zeros_(self.bias)
407
-
408
- def forward(self, x, z=None):
409
- """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
410
- """
411
- return layernorm_fn(x, self.weight, self.bias, z=z, group_size=self.group_size, eps=self.eps,
412
- norm_before_gate=self.norm_before_gate)
413
-
414
-
415
- class RMSNorm(torch.nn.Module):
416
-
417
- def __init__(self, hidden_size, eps=1e-5, group_size=None, norm_before_gate=True, device=None, dtype=None):
418
- """If group_size is not None, we do GroupNorm with each group having group_size elements.
419
- group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
420
- """
421
- factory_kwargs = {"device": device, "dtype": dtype}
422
- super().__init__()
423
- self.eps = eps
424
- self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
425
- self.register_parameter("bias", None)
426
- self.group_size = group_size
427
- self.norm_before_gate = norm_before_gate
428
- self.reset_parameters()
429
-
430
- def reset_parameters(self):
431
- torch.nn.init.ones_(self.weight)
432
-
433
- def forward(self, x, z=None):
434
- """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
435
- """
436
- return rmsnorm_fn(x, self.weight, self.bias, z=z, eps=self.eps, group_size=self.group_size,
437
- norm_before_gate=self.norm_before_gate)
 
build/torch210-cxx11-cu126-x86_64-linux/ops/triton/selective_state_update.py DELETED
@@ -1,285 +0,0 @@
1
- # Copyright (c) 2024, Tri Dao, Albert Gu.
2
-
3
- """We want triton==2.1.0 or triton==2.2.0 or triton==2.3.0 for this
4
- """
5
-
6
- import math
7
- import torch
8
- import torch.nn.functional as F
9
-
10
- import triton
11
- import triton.language as tl
12
-
13
- from einops import rearrange, repeat
14
-
15
- from .softplus import softplus
16
-
17
-
18
- @triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
19
- @triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
20
- @triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
21
- @triton.heuristics({"HAS_STATE_BATCH_INDICES": lambda args: args["state_batch_indices_ptr"] is not None})
22
- @triton.heuristics({"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])})
23
- @triton.jit
24
- def _selective_scan_update_kernel(
25
- # Pointers to matrices
26
- state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr, state_batch_indices_ptr,
27
- # Matrix dimensions
28
- batch, nheads, dim, dstate, nheads_ngroups_ratio,
29
- # Strides
30
- stride_state_batch, stride_state_head, stride_state_dim, stride_state_dstate,
31
- stride_x_batch, stride_x_head, stride_x_dim,
32
- stride_dt_batch, stride_dt_head, stride_dt_dim,
33
- stride_dt_bias_head, stride_dt_bias_dim,
34
- stride_A_head, stride_A_dim, stride_A_dstate,
35
- stride_B_batch, stride_B_group, stride_B_dstate,
36
- stride_C_batch, stride_C_group, stride_C_dstate,
37
- stride_D_head, stride_D_dim,
38
- stride_z_batch, stride_z_head, stride_z_dim,
39
- stride_out_batch, stride_out_head, stride_out_dim,
40
- # Meta-parameters
41
- DT_SOFTPLUS: tl.constexpr,
42
- TIE_HDIM: tl.constexpr,
43
- BLOCK_SIZE_M: tl.constexpr,
44
- HAS_DT_BIAS: tl.constexpr,
45
- HAS_D: tl.constexpr,
46
- HAS_Z: tl.constexpr,
47
- HAS_STATE_BATCH_INDICES: tl.constexpr,
48
- BLOCK_SIZE_DSTATE: tl.constexpr,
49
- ):
50
- pid_m = tl.program_id(axis=0)
51
- pid_b = tl.program_id(axis=1)
52
- pid_h = tl.program_id(axis=2)
53
-
54
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
55
- out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
56
- out_ptrs = out_ptr + offs_m * stride_out_dim
57
-
58
- if HAS_STATE_BATCH_INDICES:
59
- state_batch_indices_ptr += pid_b
60
- state_batch_idx = tl.load(state_batch_indices_ptr)
61
- # Skip padding tokens
62
- if state_batch_idx < 0:
63
- tl.store(out_ptrs, 0.0, mask=offs_m < dim)
64
- return
65
- state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head
66
- else:
67
- state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
68
-
69
- x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
70
- dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
71
- if HAS_DT_BIAS:
72
- dt_bias_ptr += pid_h * stride_dt_bias_head
73
- A_ptr += pid_h * stride_A_head
74
- B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group
75
- C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group
76
- if HAS_Z:
77
- z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head
78
-
79
- offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
80
- state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate)
81
- x_ptrs = x_ptr + offs_m * stride_x_dim
82
- dt_ptrs = dt_ptr + offs_m * stride_dt_dim
83
- if HAS_DT_BIAS:
84
- dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
85
- if HAS_D:
86
- D_ptr += pid_h * stride_D_head
87
- A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate)
88
- B_ptrs = B_ptr + offs_n * stride_B_dstate
89
- C_ptrs = C_ptr + offs_n * stride_C_dstate
90
- if HAS_D:
91
- D_ptrs = D_ptr + offs_m * stride_D_dim
92
- if HAS_Z:
93
- z_ptrs = z_ptr + offs_m * stride_z_dim
94
-
95
- state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0)
96
- x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
97
- if not TIE_HDIM:
98
- dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
99
- if HAS_DT_BIAS:
100
- dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
101
- if DT_SOFTPLUS:
102
- dt = tl.where(dt <= 20.0, softplus(dt), dt)
103
- A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
104
- dA = tl.exp(A * dt[:, None])
105
- else:
106
- dt = tl.load(dt_ptr).to(tl.float32)
107
- if HAS_DT_BIAS:
108
- dt += tl.load(dt_bias_ptr).to(tl.float32)
109
- if DT_SOFTPLUS:
110
- dt = tl.where(dt <= 20.0, softplus(dt), dt)
111
- A = tl.load(A_ptr).to(tl.float32)
112
- dA = tl.exp(A * dt) # scalar, not a matrix
113
-
114
- B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
115
- C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
116
- if HAS_D:
117
- D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
118
- if HAS_Z:
119
- z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
120
-
121
- if not TIE_HDIM:
122
- dB = B[None, :] * dt[:, None]
123
- else:
124
- dB = B * dt # vector of size (dstate,)
125
- state = state * dA + dB * x[:, None]
126
- tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))
127
- out = tl.sum(state * C[None, :], axis=1)
128
- if HAS_D:
129
- out += x * D
130
- if HAS_Z:
131
- out *= z * tl.sigmoid(z)
132
- tl.store(out_ptrs, out, mask=offs_m < dim)
133
-
134
-
135
- def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False,
136
- state_batch_indices=None):
137
- """
138
- Argument:
139
- state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
140
- x: (batch, dim) or (batch, nheads, dim)
141
- dt: (batch, dim) or (batch, nheads, dim)
142
- A: (dim, dstate) or (nheads, dim, dstate)
143
- B: (batch, dstate) or (batch, ngroups, dstate)
144
- C: (batch, dstate) or (batch, ngroups, dstate)
145
- D: (dim,) or (nheads, dim)
146
- z: (batch, dim) or (batch, nheads, dim)
147
- dt_bias: (dim,) or (nheads, dim)
148
- Return:
149
- out: (batch, dim) or (batch, nheads, dim)
150
- """
151
- has_heads = state.dim() > 3
152
- if state.dim() == 3:
153
- state = state.unsqueeze(1)
154
- if x.dim() == 2:
155
- x = x.unsqueeze(1)
156
- if dt.dim() == 2:
157
- dt = dt.unsqueeze(1)
158
- if A.dim() == 2:
159
- A = A.unsqueeze(0)
160
- if B.dim() == 2:
161
- B = B.unsqueeze(1)
162
- if C.dim() == 2:
163
- C = C.unsqueeze(1)
164
- if D is not None and D.dim() == 1:
165
- D = D.unsqueeze(0)
166
- if z is not None and z.dim() == 2:
167
- z = z.unsqueeze(1)
168
- if dt_bias is not None and dt_bias.dim() == 1:
169
- dt_bias = dt_bias.unsqueeze(0)
170
- _, nheads, dim, dstate = state.shape
171
- batch = x.shape[0]
172
- if x.shape != (batch, nheads, dim):
173
- print(f"{state.shape} {x.shape} {batch} {nheads} {dim}")
174
- assert x.shape == (batch, nheads, dim)
175
- assert dt.shape == x.shape
176
- assert A.shape == (nheads, dim, dstate)
177
- ngroups = B.shape[1]
178
- assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
179
- assert B.shape == (batch, ngroups, dstate)
180
- assert C.shape == B.shape
181
- if D is not None:
182
- assert D.shape == (nheads, dim)
183
- if z is not None:
184
- assert z.shape == x.shape
185
- if dt_bias is not None:
186
- assert dt_bias.shape == (nheads, dim)
187
- if state_batch_indices is not None:
188
- assert state_batch_indices.shape == (batch,)
189
- out = torch.empty_like(x)
190
- grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)
191
- z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0))
192
- # We don't want autotune since it will overwrite the state
193
- # We instead tune by hand.
194
- BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16
195
- else ((16, 4) if dstate <= 32 else
196
- ((8, 4) if dstate <= 64 else
197
- ((4, 4) if dstate <= 128 else
198
- ((4, 8))))))
199
- tie_hdim = A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride(-1) == 0 and dt_bias.stride(-1) == 0
200
- with torch.cuda.device(x.device.index):
201
- _selective_scan_update_kernel[grid](
202
- state, x, dt, dt_bias, A, B, C, D, z, out, state_batch_indices,
203
- batch, nheads, dim, dstate, nheads // ngroups,
204
- state.stride(0), state.stride(1), state.stride(2), state.stride(3),
205
- x.stride(0), x.stride(1), x.stride(2),
206
- dt.stride(0), dt.stride(1), dt.stride(2),
207
- *(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0,
208
- A.stride(0), A.stride(1), A.stride(2),
209
- B.stride(0), B.stride(1), B.stride(2),
210
- C.stride(0), C.stride(1), C.stride(2),
211
- *(D.stride(0), D.stride(1)) if D is not None else 0,
212
- z_strides[0], z_strides[1], z_strides[2],
213
- out.stride(0), out.stride(1), out.stride(2),
214
- dt_softplus,
215
- tie_hdim,
216
- BLOCK_SIZE_M,
217
- num_warps=num_warps,
218
- )
219
- if not has_heads:
220
- out = out.squeeze(1)
221
- return out
222
-
223
-
224
- def selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
225
- """
226
- Argument:
227
- state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
228
- x: (batch, dim) or (batch, nheads, dim)
229
- dt: (batch, dim) or (batch, nheads, dim)
230
- A: (dim, dstate) or (nheads, dim, dstate)
231
- B: (batch, dstate) or (batch, ngroups, dstate)
232
- C: (batch, dstate) or (batch, ngroups, dstate)
233
- D: (dim,) or (nheads, dim)
234
- z: (batch, dim) or (batch, nheads, dim)
235
- dt_bias: (dim,) or (nheads, dim)
236
- Return:
237
- out: (batch, dim) or (batch, nheads, dim)
238
- """
239
- has_heads = state.dim() > 3
240
- if state.dim() == 3:
241
- state = state.unsqueeze(1)
242
- if x.dim() == 2:
243
- x = x.unsqueeze(1)
244
- if dt.dim() == 2:
245
- dt = dt.unsqueeze(1)
246
- if A.dim() == 2:
247
- A = A.unsqueeze(0)
248
- if B.dim() == 2:
249
- B = B.unsqueeze(1)
250
- if C.dim() == 2:
251
- C = C.unsqueeze(1)
252
- if D is not None and D.dim() == 1:
253
- D = D.unsqueeze(0)
254
- if z is not None and z.dim() == 2:
255
- z = z.unsqueeze(1)
256
- if dt_bias is not None and dt_bias.dim() == 1:
257
- dt_bias = dt_bias.unsqueeze(0)
258
- batch, nheads, dim, dstate = state.shape
259
- assert x.shape == (batch, nheads, dim)
260
- assert dt.shape == x.shape
261
- assert A.shape == (nheads, dim, dstate)
262
- ngroups = B.shape[1]
263
- assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
264
- assert B.shape == (batch, ngroups, dstate)
265
- assert C.shape == B.shape
266
- if D is not None:
267
- assert D.shape == (nheads, dim)
268
- if z is not None:
269
- assert z.shape == x.shape
270
- if dt_bias is not None:
271
- assert dt_bias.shape == (nheads, dim)
272
- dt = dt + dt_bias
273
- dt = F.softplus(dt) if dt_softplus else dt
274
- dA = torch.exp(rearrange(dt, "b h d -> b h d 1") * A) # (batch, nheads, dim, dstate)
275
- B = repeat(B, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate)
276
- C = repeat(C, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate)
277
- dB = rearrange(dt, "b h d -> b h d 1") * rearrange(B, "b h n -> b h 1 n") # (batch, nheads, dim, dstate)
278
- state.copy_(state * dA + dB * rearrange(x, "b h d -> b h d 1"))  # (batch, nheads, dim, dstate)
279
- out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C)
280
- if D is not None:
281
- out += (x * D).to(out.dtype)
282
- out = (out if z is None else out * F.silu(z)).to(x.dtype)
283
- if not has_heads:
284
- out = out.squeeze(1)
285
- return out
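As a reference point, a hedged sketch of calling selective_state_update with the shapes documented above; sizes and values are illustrative and a CUDA device is assumed.

import torch

batch, nheads, dim, dstate, ngroups = 2, 4, 64, 16, 1
device = "cuda"
state = torch.zeros(batch, nheads, dim, dstate, device=device)
x = torch.randn(batch, nheads, dim, device=device)
dt = torch.rand(batch, nheads, dim, device=device)
A = -torch.rand(nheads, dim, dstate, device=device)
B = torch.randn(batch, ngroups, dstate, device=device)
C = torch.randn(batch, ngroups, dstate, device=device)
# Updates `state` in place and returns the output of shape (batch, nheads, dim).
out = selective_state_update(state, x, dt, A, B, C, dt_softplus=True)

Because the kernel writes the updated SSM state back into `state`, the same buffer can be reused across decoding steps.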
 
build/torch210-cxx11-cu126-x86_64-linux/ops/triton/softplus.py DELETED
@@ -1,15 +0,0 @@
1
- import triton
2
- import triton.language as tl
3
- from packaging import version
4
-
5
- TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0")
6
-
7
-
8
- if TRITON3:
9
- @triton.jit
10
- def softplus(dt):
11
- return tl.math.log(tl.math.exp(dt) + 1)
12
- else:
13
- @triton.jit
14
- def softplus(dt):
15
- return tl.math.log1p(tl.exp(dt))
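Both branches above implement the same softplus function, log(1 + exp(x)); a quick equivalence check in plain PyTorch (illustrative, not part of the deleted file):

import torch

x = torch.linspace(-10.0, 10.0, steps=5)
assert torch.allclose(torch.log(torch.exp(x) + 1), torch.log1p(torch.exp(x)))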
 
build/torch210-cxx11-cu126-x86_64-linux/ops/triton/ssd_bmm.py DELETED
@@ -1,262 +0,0 @@
1
- # Copyright (c) 2024, Tri Dao, Albert Gu.
2
-
3
- """We want triton==2.1.0 or 2.2.0 for this
4
- """
5
-
6
- import math
7
- import torch
8
- import torch.nn.functional as F
9
-
10
- import triton
11
- import triton.language as tl
12
-
13
- from einops import rearrange, repeat
14
-
15
-
16
- def init_to_zero(names):
17
- return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]
18
-
19
-
20
- @triton.autotune(
21
- configs=[
22
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),
23
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
24
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
25
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
26
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
27
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
28
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
29
- triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
30
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),
31
- ],
32
- key=['chunk_size', 'K', 'IS_CAUSAL'],
33
- )
34
- @triton.jit
35
- def _bmm_chunk_fwd_kernel(
36
- # Pointers to matrices
37
- a_ptr, b_ptr, out_ptr, seq_idx_ptr,
38
- # Matrix dimensions
39
- seqlen, chunk_size, K, ngroups,
40
- stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak,
41
- stride_b_batch, stride_b_seqlen, stride_b_head, stride_bk,
42
- stride_out_batch, stride_out_chunk, stride_out_head, stride_outm, stride_outn,
43
- stride_seq_idx_batch, stride_seq_idx_seqlen,
44
- # Meta-parameters
45
- IS_CAUSAL: tl.constexpr,
46
- dot_dtype: tl.constexpr,
47
- HAS_SEQ_IDX: tl.constexpr,
48
- BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
49
- ):
50
- pid_b = tl.program_id(axis=1)
51
- pid_ch = tl.program_id(axis=2)
52
- pid_c = pid_ch // ngroups
53
- pid_h = pid_ch - pid_c * ngroups
54
- num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N)
55
- pid_m = tl.program_id(axis=0) // num_pid_n
56
- pid_n = tl.program_id(axis=0) % num_pid_n
57
- if IS_CAUSAL:
58
- if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M:
59
- return
60
- a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head
61
- b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head
62
- if HAS_SEQ_IDX:
63
- seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
64
-
65
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
66
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
67
- offs_k = tl.arange(0, BLOCK_SIZE_K)
68
- a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + offs_k[None, :] * stride_ak)
69
- b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen)
70
- chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
71
-
72
- acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
73
- for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
74
- a = tl.load(a_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < K - k * BLOCK_SIZE_K), other=0.0).to(dot_dtype)
75
- b = tl.load(b_ptrs, mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_n[None, :] < chunk_size_limit), other=0.0).to(dot_dtype)
76
- acc += tl.dot(a, b)
77
- a_ptrs += BLOCK_SIZE_K * stride_ak
78
- b_ptrs += BLOCK_SIZE_K * stride_bk
79
-
80
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
81
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
82
- if HAS_SEQ_IDX:
83
- chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
84
- seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
85
- seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, mask=offs_n < chunk_size_limit, other=-2)
86
- acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0)
87
- out = acc.to(out_ptr.dtype.element_ty)
88
-
89
- out_ptr += pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head
90
- out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + offs_n[None, :] * stride_outn)
91
- tl.store(out_ptrs, out, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size))
92
-
93
-
94
- @triton.autotune(
95
- configs=[
96
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 64}, num_stages=3, num_warps=8),
97
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
98
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
99
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
100
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
101
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
102
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2),
103
- triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2),
104
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=2),
105
- ],
106
- key=['chunk_size', 'K'],
107
- )
108
- @triton.jit
109
- def _bmm_chunk_bwd_kernel(
110
- # Pointers to matrices
111
- a_ptr, dout_ptr, db_ptr, res_ptr,
112
- # Matrix dimensions
113
- seqlen, chunk_size, K, ngroups,
114
- stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak,
115
- stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_csize_m, stride_dout_csize_n,
116
- stride_db_batch, stride_db_seqlen, stride_db_head, stride_db_k,
117
- stride_res_batch, stride_res_seqlen, stride_res_head, stride_res_k,
118
- # Meta-parameters
119
- dot_dtype: tl.constexpr,
120
- HAS_RESIDUAL: tl.constexpr,
121
- BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_CS: tl.constexpr,
122
- ):
123
- pid_b = tl.program_id(axis=1)
124
- pid_ch = tl.program_id(axis=2)
125
- pid_c = pid_ch // ngroups
126
- pid_h = pid_ch - pid_c * ngroups
127
- num_pid_n = tl.cdiv(K, BLOCK_SIZE_N)
128
- pid_m = tl.program_id(axis=0) // num_pid_n
129
- pid_n = tl.program_id(axis=0) % num_pid_n
130
-
131
- a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head
132
- dout_ptr += pid_b * stride_dout_batch + pid_c * stride_dout_chunk + pid_h * stride_dout_head
133
-
134
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
135
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
136
- offs_cs = tl.arange(0, BLOCK_SIZE_CS)
137
- dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_csize_n + offs_cs[None, :] * stride_dout_csize_m)
138
- a_ptrs = a_ptr + (offs_cs[:, None] * stride_a_seqlen + offs_n[None, :] * stride_ak)
139
- chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
140
-
141
- acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
142
- for cs in range(0, tl.cdiv(chunk_size_limit, BLOCK_SIZE_CS)):
143
- dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_cs[None, :] < chunk_size_limit - cs * BLOCK_SIZE_CS), other=0.0).to(dot_dtype)
144
- a = tl.load(a_ptrs, mask=(offs_cs[:, None] < chunk_size_limit - cs * BLOCK_SIZE_CS) & (offs_n[None, :] < K), other=0.0).to(dot_dtype)
145
- acc += tl.dot(dout, a)
146
- dout_ptrs += BLOCK_SIZE_CS * stride_dout_csize_m
147
- a_ptrs += BLOCK_SIZE_CS * stride_a_seqlen
148
-
149
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
150
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
151
- if HAS_RESIDUAL:
152
- res_ptr += pid_b * stride_res_batch + pid_c * chunk_size * stride_res_seqlen + pid_h * stride_res_head
153
- res_ptrs = res_ptr + (offs_m[:, None] * stride_res_seqlen + offs_n[None, :] * stride_res_k)
154
- res = tl.load(res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < K)).to(tl.float32)
155
- acc += res
156
- db = acc.to(db_ptr.dtype.element_ty)
157
-
158
- db_ptr += pid_b * stride_db_batch + pid_c * chunk_size * stride_db_seqlen + pid_h * stride_db_head
159
- db_ptrs = db_ptr + (offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_k)
160
- tl.store(db_ptrs, db, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < K))
161
-
162
-
163
- def _bmm_chunk_fwd(a, b, chunk_size, seq_idx=None, causal=False, output_dtype=None):
164
- """
165
- Argument:
166
- a: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
167
- b: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
168
- seq_idx: (batch, seqlen) or None. out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out.
169
- causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are
170
- guaranteed to be correct.
171
- Return:
172
- out: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size)
173
- """
174
- # Check constraints.
175
- has_groups = a.dim() == 4
176
- if not has_groups:
177
- batch, seqlen, k = a.shape
178
- else:
179
- batch, seqlen, ngroups, k = a.shape
180
- assert b.shape == a.shape
181
- if seq_idx is not None:
182
- assert seq_idx.shape == (batch, seqlen)
183
- if a.stride(-1) != 1 and a.stride(1) != 1:
184
- a = a.contiguous()
185
- if b.stride(-1) != 1 and b.stride(1) != 1:
186
- b = b.contiguous()
187
- nchunks = math.ceil(seqlen / chunk_size)
188
- # Allocates output.
189
- out_dtype = a.dtype if output_dtype is None else output_dtype
190
- out = torch.empty((batch, nchunks, chunk_size, chunk_size) if not has_groups else (batch, nchunks, ngroups, chunk_size, chunk_size),
191
- device=a.device, dtype=out_dtype)
192
- dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else
193
- (tl.float16 if a.dtype == torch.float16 or b.dtype == torch.float16 else tl.float32))
194
- grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(chunk_size, META['BLOCK_SIZE_N']),
195
- batch, nchunks if not has_groups else nchunks * ngroups)
196
- with torch.cuda.device(a.device.index):
197
- _bmm_chunk_fwd_kernel[grid](
198
- a, b, out, seq_idx,
199
- seqlen, chunk_size, k, ngroups if has_groups else 1,
200
- a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1),
201
- b.stride(0), b.stride(1), 0 if not has_groups else b.stride(2), b.stride(-1),
202
- out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-2), out.stride(-1),
203
- *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
204
- causal,
205
- dot_dtype,
206
- HAS_SEQ_IDX=seq_idx is not None,
207
- )
208
- return out
209
-
210
-
211
- def _bmm_chunk_bwd(a, dout, residual=None, out=None):
212
- """
213
- Argument:
214
- a: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
215
- dout: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size)
216
- residual: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
217
- Return:
218
- out: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
219
-
220
- If there was seq_idx in the fwd pass, then dout[i, j] for seq_idx[i] != seq_idx[j] should already be
221
- zeroed out before calling this function.
222
- """
223
- # Check constraints.
224
- has_groups = a.dim() == 4
225
- if not has_groups:
226
- batch, seqlen, k = a.shape
227
- else:
228
- batch, seqlen, ngroups, k = a.shape
229
- nchunks, chunk_size = dout.shape[1], dout.shape[-1]
230
- if a.stride(-1) != 1 and a.stride(-2) != 1:
231
- a = a.contiguous()
232
- if dout.stride(-1) != 1 and dout.stride(-2) != 1:
233
- dout = dout.contiguous()
234
- if residual is not None:
235
- assert residual.shape == (batch, seqlen, k) if not has_groups else (batch, seqlen, ngroups, k)
236
- if residual.stride(-1) != 1 and residual.stride(1) != 1:
237
- residual = residual.contiguous()
238
- # Allocates output.
239
- if out is not None:
240
- assert out.shape == a.shape
241
- assert out.stride(-1) == 1 or out.stride(1) == 1
242
- else:
243
- out = torch.empty_like(a)
244
- dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or dout.dtype == torch.bfloat16 else
245
- (tl.float16 if a.dtype == torch.float16 or dout.dtype == torch.float16 else tl.float32))
246
- grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(k, META['BLOCK_SIZE_N']), batch,
247
- nchunks if not has_groups else nchunks * ngroups)
248
- residual_strides = ((residual.stride(0), residual.stride(1), 0 if not has_groups else residual.stride(2),
249
- residual.stride(-1))
250
- if residual is not None else (0, 0, 0, 0))
251
- with torch.cuda.device(a.device.index):
252
- _bmm_chunk_bwd_kernel[grid](
253
- a, dout, out, residual,
254
- seqlen, chunk_size, k, ngroups if has_groups else 1,
255
- a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1),
256
- dout.stride(0), dout.stride(1), 0 if not has_groups else dout.stride(2), dout.stride(-2), dout.stride(-1),
257
- out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-1),
258
- residual_strides[0], residual_strides[1], residual_strides[2], residual_strides[3],
259
- dot_dtype,
260
- HAS_RESIDUAL=residual is not None,
261
- )
262
- return out
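For readers checking the semantics, a rough PyTorch reference for the ungrouped forward path of _bmm_chunk_fwd. This is a sketch only: it assumes seqlen is a multiple of chunk_size and ignores the seq_idx and causal handling of the kernel above.

import torch
from einops import rearrange

def bmm_chunk_ref(a, b, chunk_size):
    # a, b: (batch, seqlen, k) -> out: (batch, nchunks, chunk_size, chunk_size)
    a = rearrange(a, "b (c l) k -> b c l k", l=chunk_size)
    b = rearrange(b, "b (c l) k -> b c l k", l=chunk_size)
    # out[c, i, j] = sum_k a[c, i, k] * b[c, j, k], i.e. a_chunk @ b_chunk^T
    return torch.einsum("bclk,bcsk->bcls", a, b)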
 
build/torch210-cxx11-cu126-x86_64-linux/ops/triton/ssd_chunk_scan.py DELETED
The diff for this file is too large to render. See raw diff
 
build/torch210-cxx11-cu126-x86_64-linux/ops/triton/ssd_chunk_state.py DELETED
@@ -1,997 +0,0 @@
1
- # Copyright (c) 2024, Tri Dao, Albert Gu.
2
-
3
- """We want triton==2.1.0 or 2.2.0 for this
4
- """
5
-
6
- import math
7
- import torch
8
- import torch.nn.functional as F
9
-
10
- import triton
11
- import triton.language as tl
12
-
13
- from einops import rearrange, repeat
14
-
15
- from .softplus import softplus
16
-
17
-
18
- def init_to_zero(names):
19
- return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]
20
-
21
- @triton.autotune(
22
- configs=[
23
- triton.Config({'BLOCK_SIZE_H': 1}),
24
- triton.Config({'BLOCK_SIZE_H': 2}),
25
- triton.Config({'BLOCK_SIZE_H': 4}),
26
- triton.Config({'BLOCK_SIZE_H': 8}),
27
- triton.Config({'BLOCK_SIZE_H': 16}),
28
- triton.Config({'BLOCK_SIZE_H': 32}),
29
- triton.Config({'BLOCK_SIZE_H': 64}),
30
- ],
31
- key=['chunk_size', 'nheads'],
32
- )
33
- @triton.jit
34
- def _chunk_cumsum_fwd_kernel(
35
- # Pointers to matrices
36
- dt_ptr, A_ptr, dt_bias_ptr, dt_out_ptr, dA_cumsum_ptr,
37
- # Matrix dimension
38
- batch, seqlen, nheads, chunk_size,
39
- dt_min, dt_max,
40
- # Strides
41
- stride_dt_batch, stride_dt_seqlen, stride_dt_head,
42
- stride_A_head,
43
- stride_dt_bias_head,
44
- stride_dt_out_batch, stride_dt_out_chunk, stride_dt_out_head, stride_dt_out_csize,
45
- stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
46
- # Meta-parameters
47
- DT_SOFTPLUS: tl.constexpr,
48
- HAS_DT_BIAS: tl.constexpr,
49
- BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_CHUNK: tl.constexpr,
50
- ):
51
- pid_b = tl.program_id(axis=0)
52
- pid_c = tl.program_id(axis=1)
53
- pid_h = tl.program_id(axis=2)
54
- dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen
55
- dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk
56
- dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk
57
-
58
- offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
59
- offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
60
- dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen)
61
- A_ptrs = A_ptr + offs_h * stride_A_head
62
- dt_out_ptrs = dt_out_ptr + (offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize)
63
- dA_cs_ptrs = dA_cumsum_ptr + (offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize)
64
- chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
65
-
66
- dt = tl.load(dt_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)
67
- if HAS_DT_BIAS:
68
- dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0).to(tl.float32)
69
- dt += dt_bias[:, None]
70
- if DT_SOFTPLUS:
71
- dt = tl.where(dt <= 20.0, softplus(dt), dt)
72
- # As of Triton 2.2.0, tl.clamp is not available yet
73
- # dt = tl.clamp(dt, dt_min, dt_max)
74
- dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)
75
- dt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0)
76
- tl.store(dt_out_ptrs, dt, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))
77
- A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)
78
- dA = dt * A[:, None]
79
- dA_cs = tl.cumsum(dA, axis=1)
80
- tl.store(dA_cs_ptrs, dA_cs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))
81
-
82
-
83
- @triton.autotune(
84
- configs=[
85
- triton.Config({'BLOCK_SIZE_H': 1}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])),
86
- triton.Config({'BLOCK_SIZE_H': 2}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])),
87
- triton.Config({'BLOCK_SIZE_H': 4}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])),
88
- triton.Config({'BLOCK_SIZE_H': 8}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])),
89
- triton.Config({'BLOCK_SIZE_H': 16}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])),
90
- triton.Config({'BLOCK_SIZE_H': 32}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])),
91
- triton.Config({'BLOCK_SIZE_H': 64}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])),
92
- ],
93
- key=['chunk_size', 'nheads'],
94
- )
95
- @triton.jit
96
- def _chunk_cumsum_bwd_kernel(
97
- # Pointers to matrices
98
- ddA_ptr, ddt_out_ptr, dt_ptr, A_ptr, dt_bias_ptr,
99
- ddt_ptr, dA_ptr, ddt_bias_ptr,
100
- # Matrix dimensions
101
- batch, seqlen, nheads, chunk_size,
102
- dt_min, dt_max,
103
- # Strides
104
- stride_ddA_batch, stride_ddA_chunk, stride_ddA_head, stride_ddA_csize,
105
- stride_ddt_out_batch, stride_ddt_out_chunk, stride_ddt_out_head, stride_ddt_out_csize,
106
- stride_dt_batch, stride_dt_seqlen, stride_dt_head,
107
- stride_A_head,
108
- stride_dt_bias_head,
109
- stride_ddt_batch, stride_ddt_seqlen, stride_ddt_head,
110
- stride_dA_head,
111
- stride_ddt_bias_head,
112
- # Meta-parameters
113
- DT_SOFTPLUS: tl.constexpr,
114
- HAS_DT_BIAS: tl.constexpr,
115
- BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_CHUNK: tl.constexpr,
116
- ):
117
- pid_b = tl.program_id(axis=0)
118
- pid_c = tl.program_id(axis=1)
119
- pid_h = tl.program_id(axis=2)
120
- ddt_out_ptr += pid_b * stride_ddt_out_batch + pid_c * stride_ddt_out_chunk
121
- ddA_ptr += pid_b * stride_ddA_batch + pid_c * stride_ddA_chunk
122
- dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen
123
- ddt_ptr += pid_b * stride_ddt_batch + pid_c * chunk_size * stride_ddt_seqlen
124
-
125
- offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
126
- offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
127
- ddt_out_ptrs = ddt_out_ptr + (offs_h[:, None] * stride_ddt_out_head + offs_c[None, :] * stride_ddt_out_csize)
128
- ddA_ptrs = ddA_ptr + (offs_h[:, None] * stride_ddA_head + offs_c[None, :] * stride_ddA_csize)
129
- dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen)
130
- ddt_ptrs = ddt_ptr + (offs_h[:, None] * stride_ddt_head + offs_c[None, :] * stride_ddt_seqlen)
131
- A_ptrs = A_ptr + offs_h * stride_A_head
132
- chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
133
-
134
- ddA = tl.load(ddA_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)
135
- ddt_out = tl.load(ddt_out_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)
136
- A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)
137
- ddt = ddA * A[:, None] + ddt_out
138
- dt = tl.load(dt_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)
139
- if HAS_DT_BIAS:
140
- dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0).to(tl.float32)
141
- dt += dt_bias[:, None]
142
- if DT_SOFTPLUS:
143
- dt_presoftplus = dt
144
- dt = tl.where(dt <= 20.0, softplus(dt), dt)
145
- clamp_mask = (dt < dt_min) | (dt > dt_max)
146
- # As of Triton 2.2.0, tl.clamp is not available yet
147
- # dt = tl.clamp(dt, dt_min, dt_max)
148
- dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)
149
- dt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0)
150
- ddt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), ddt, 0.0)
151
- ddt = tl.where(clamp_mask, 0.0, ddt)
152
- if DT_SOFTPLUS:
153
- ddt = tl.where(dt_presoftplus <= 20.0, ddt * tl.sigmoid(dt_presoftplus), ddt)
154
- tl.store(ddt_ptrs, ddt, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit))
155
- dA = tl.sum(ddA * dt, axis=1)
156
- tl.atomic_add(dA_ptr + offs_h * stride_dA_head, dA, mask=offs_h < nheads)
157
- if HAS_DT_BIAS:
158
- ddt_bias = tl.sum(ddt, axis=1)
159
- tl.atomic_add(ddt_bias_ptr + offs_h * stride_ddt_bias_head, ddt_bias, mask=offs_h < nheads)
160
-
161
-
162
- @triton.autotune(
163
- configs=[
164
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),
165
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
166
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
167
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
168
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
169
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
170
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
171
- triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
172
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),
173
- ],
174
- key=['hdim', 'dstate', 'chunk_size'],
175
- )
176
- @triton.jit
177
- def _chunk_state_fwd_kernel(
178
- # Pointers to matrices
179
- x_ptr, b_ptr, states_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr,
180
- # Matrix dimensions
181
- hdim, dstate, chunk_size,
182
- batch, seqlen, nheads_ngroups_ratio,
183
- # Strides
184
- stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
185
- stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate,
186
- stride_states_batch, stride_states_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,
187
- stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
188
- stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
189
- stride_seq_idx_batch, stride_seq_idx_seqlen,
190
- # Meta-parameters
191
- HAS_SEQ_IDX: tl.constexpr,
192
- BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
193
- ):
194
- pid_bc = tl.program_id(axis=1)
195
- pid_c = pid_bc // batch
196
- pid_b = pid_bc - pid_c * batch
197
- pid_h = tl.program_id(axis=2)
198
- num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
199
- pid_m = tl.program_id(axis=0) // num_pid_n
200
- pid_n = tl.program_id(axis=0) % num_pid_n
201
- b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head
202
- x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
203
- dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
204
- dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
205
- if HAS_SEQ_IDX:
206
- seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
207
-
208
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
209
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
210
- offs_k = tl.arange(0, BLOCK_SIZE_K)
211
- x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen)
212
- b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen)
213
- dt_ptrs = dt_ptr + offs_k * stride_dt_csize
214
- dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
215
- dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
216
- if HAS_SEQ_IDX:
217
- seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen
218
-
219
- chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
220
- if HAS_SEQ_IDX:
221
- seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
222
-
223
- acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
224
- for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
225
- x = tl.load(x_ptrs, mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k), other=0.0)
226
- b = tl.load(b_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
227
- dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
228
- if HAS_SEQ_IDX:
229
- seq_idx_k = tl.load(seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1)
230
- dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
231
- if not HAS_SEQ_IDX:
232
- # scale = tl.exp((dA_cs_last - dA_cs_k)) * dt_k
233
- scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * dt_k
234
- else:
235
- # scale = tl.where(seq_idx_k == seq_idx_last, tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0)
236
- scale = tl.where((seq_idx_last >= 0) & (seq_idx_k == seq_idx_last), tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * dt_k, 0.0)
237
- b *= scale[:, None]
238
- b = b.to(x_ptr.dtype.element_ty)
239
- acc += tl.dot(x, b)
240
- x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
241
- b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
242
- dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
243
- dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
244
- if HAS_SEQ_IDX:
245
- seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen
246
- states = acc.to(states_ptr.dtype.element_ty)
247
-
248
- states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head
249
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
250
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
251
- states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate)
252
- c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
253
- tl.store(states_ptrs, states, mask=c_mask)
254
-
255
-
256
- @triton.autotune(
257
- configs=[
258
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
259
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
260
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
261
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
262
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
263
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
264
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
265
- triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
266
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
267
- ],
268
- key=['chunk_size', 'hdim', 'dstate'],
269
- )
270
- @triton.jit
271
- def _chunk_state_bwd_dx_kernel(
272
- # Pointers to matrices
273
- x_ptr, b_ptr, dstates_ptr, dt_ptr, dA_cumsum_ptr,
274
- dx_ptr, ddt_ptr, ddA_cumsum_ptr,
275
- # Matrix dimensions
276
- chunk_size, hdim, dstate,
277
- batch, seqlen, nheads_ngroups_ratio,
278
- # Strides
279
- stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
280
- stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate,
281
- stride_dstates_batch, stride_dstates_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,
282
- stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
283
- stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
284
- stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim,
285
- stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize,
286
- stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize,
287
- # Meta-parameters
288
- BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
289
- BLOCK_SIZE_DSTATE: tl.constexpr,
290
- ):
291
- pid_bc = tl.program_id(axis=1)
292
- pid_c = pid_bc // batch
293
- pid_b = pid_bc - pid_c * batch
294
- pid_h = tl.program_id(axis=2)
295
- num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
296
- pid_m = tl.program_id(axis=0) // num_pid_n
297
- pid_n = tl.program_id(axis=0) % num_pid_n
298
- x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
299
- b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head
300
- dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_states_head
301
- dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
302
- ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head
303
- ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head
304
- dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
305
-
306
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
307
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
308
-
309
- chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
310
- # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
311
- offs_k = tl.arange(0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K)
312
- b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate)
313
- dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate)
314
- if BLOCK_SIZE_DSTATE <= 128:
315
- b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate), other=0.0)
316
- dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0)
317
- dstates = dstates.to(b_ptr.dtype.element_ty)
318
- acc = tl.dot(b, dstates)
319
- else:
320
- acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
321
- for k in range(0, dstate, BLOCK_SIZE_K):
322
- b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate - k), other=0.0)
323
- dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0)
324
- dstates = dstates.to(b_ptr.dtype.element_ty)
325
- acc += tl.dot(b, dstates)
326
- b_ptrs += BLOCK_SIZE_K * stride_b_dstate
327
- dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate
328
-
329
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
330
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
331
-
332
- dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
333
- dt_ptrs = dt_ptr + offs_m * stride_dt_csize
334
- dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
335
- dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
336
- dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
337
- # acc *= tl.exp(dA_cs_last - dA_cs_m)[:, None]
338
- acc *= tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0))[:, None]
339
-
340
- x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
341
- x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
342
- ddt = tl.sum(acc * x, axis=1)
343
- ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
344
- tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)
345
- ddA_cs = -(ddt * dt_m)
346
- ddA_cs_last = -tl.sum(ddA_cs)
347
- ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
348
- tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)
349
- tl.atomic_add(ddA_cumsum_ptr + (chunk_size - 1) * stride_ddA_cs_csize, ddA_cs_last)
350
-
351
- dx = (acc * dt_m[:, None]).to(dx_ptr.dtype.element_ty)
352
- dx_ptr += pid_b * stride_dx_batch + pid_c * chunk_size * stride_dx_seqlen + pid_h * stride_dx_head
353
- dx_ptrs = dx_ptr + (offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim)
354
- tl.store(dx_ptrs, dx, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim))
355
-
356
-
357
- @triton.autotune(
358
- configs=[
359
- triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
360
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
361
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
362
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
363
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
364
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
365
- triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
366
- triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
367
- ],
368
- key=['chunk_size', 'dstate', 'hdim'],
369
- )
370
- @triton.jit
371
- def _chunk_state_bwd_db_kernel(
372
- # Pointers to matrices
373
- x_ptr, dstates_ptr, b_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr,
374
- db_ptr, ddA_cumsum_ptr,
375
- # Matrix dimensions
376
- chunk_size, dstate, hdim,
377
- batch, seqlen, nheads, nheads_per_program, ngroups,
378
- # Strides
379
- stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
380
- stride_dstates_batch, stride_dstates_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,
381
- stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate,
382
- stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
383
- stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
384
- stride_seq_idx_batch, stride_seq_idx_seqlen,
385
- stride_db_batch, stride_db_seqlen, stride_db_split, stride_db_group, stride_db_dstate,
386
- stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize,
387
- # Meta-parameters
388
- HAS_DDA_CS: tl.constexpr,
389
- HAS_SEQ_IDX: tl.constexpr,
390
- BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
391
- ):
392
- pid_bc = tl.program_id(axis=1)
393
- pid_c = pid_bc // batch
394
- pid_b = pid_bc - pid_c * batch
395
- pid_sg = tl.program_id(axis=2)
396
- pid_s = pid_sg // ngroups
397
- pid_g = pid_sg - pid_s * ngroups
398
- num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
399
- pid_m = tl.program_id(axis=0) // num_pid_n
400
- pid_n = tl.program_id(axis=0) % num_pid_n
401
- x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_x_head
402
- db_ptr += pid_b * stride_db_batch + pid_c * chunk_size * stride_db_seqlen + pid_g * stride_db_group + pid_s * stride_db_split
403
- dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_states_head
404
- dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dt_head
405
- dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head
406
- if HAS_DDA_CS:
407
- b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_g * stride_b_head
408
- ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_ddA_cs_head
409
- if HAS_SEQ_IDX:
410
- seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
411
-
412
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
413
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
414
- offs_k = tl.arange(0, BLOCK_SIZE_K)
415
- x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_k[None, :] * stride_x_hdim)
416
- dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_states_dstate + offs_k[:, None] * stride_states_hdim)
417
- dt_ptrs = dt_ptr + offs_m * stride_dt_csize
418
- dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
419
- if HAS_DDA_CS:
420
- b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_n[None, :] * stride_b_dstate)
421
- ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
422
-
423
- chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
424
- acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
425
- if HAS_DDA_CS:
426
- b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
427
- if HAS_SEQ_IDX:
428
- seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
429
- seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
430
- nheads_iter = min(nheads_per_program, nheads // ngroups - pid_s * nheads_per_program)
431
- for h in range(nheads_iter):
432
- x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0)
433
- dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0)
434
- dstates = dstates.to(x_ptrs.dtype.element_ty)
435
- db = tl.dot(x, dstates)
436
- dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
437
- dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
438
- dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
439
- if not HAS_SEQ_IDX:
440
- # scale = tl.exp(dA_cs_last - dA_cs_m)
441
- scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0))
442
- else:
443
- # scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
444
- scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)), 0.0)
445
- db *= (scale * dt_m)[:, None]
446
- if HAS_DDA_CS:
447
- # This is the gradient wrt (dA_cs_last - dA_cs_m), i.e. the exclusive reverse cumsum
448
- ddA_cs = tl.sum(db * b, axis=1)
449
- tl.atomic_add(ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1)
450
- acc += db
451
- x_ptrs += stride_x_head
452
- dstates_ptrs += stride_states_head
453
- dt_ptrs += stride_dt_head
454
- dA_cumsum_ptr += stride_dA_cs_head
455
- dA_cumsum_ptrs += stride_dA_cs_head
456
- if HAS_DDA_CS:
457
- ddA_cumsum_ptrs += stride_ddA_cs_head
458
-
459
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
460
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
461
- # if HAS_SEQ_IDX:
462
- # seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
463
- # seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
464
- # acc = tl.where(seq_idx_m[:, None] == seq_idx_last, acc, 0.0)
465
- db_ptrs = db_ptr + (offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_dstate)
466
- tl.store(db_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate))
467
-
468
-
469
- @triton.autotune(
470
- configs=[
471
- # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
472
- # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
473
- # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
474
- # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
475
- # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
476
- # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
477
- # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
478
- # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
479
- # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
480
- triton.Config({'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
481
- triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
482
- triton.Config({'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
483
- triton.Config({'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
484
- triton.Config({'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
485
- triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
486
- triton.Config({'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
487
- triton.Config({'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
488
- ],
489
- key=['chunk_size', 'hdim', 'dstate'],
490
- )
491
- @triton.jit
492
- def _chunk_state_bwd_ddAcs_stable_kernel(
493
- # Pointers to matrices
494
- x_ptr, b_ptr, dstates_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr,
495
- ddA_cumsum_ptr,
496
- # Matrix dimensions
497
- chunk_size, hdim, dstate,
498
- batch, seqlen, nheads_ngroups_ratio,
499
- # Strides
500
- stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
501
- stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate,
502
- stride_dstates_batch, stride_dstates_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,
503
- stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
504
- stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
505
- stride_seq_idx_batch, stride_seq_idx_seqlen,
506
- stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize,
507
- # Meta-parameters
508
- HAS_SEQ_IDX: tl.constexpr,
509
- BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
510
- BLOCK_SIZE_DSTATE: tl.constexpr,
511
- ):
512
- pid_bc = tl.program_id(axis=1)
513
- pid_c = pid_bc // batch
514
- pid_b = pid_bc - pid_c * batch
515
- pid_h = tl.program_id(axis=2)
516
- num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
517
- pid_m = tl.program_id(axis=0) // num_pid_n
518
- pid_n = tl.program_id(axis=0) % num_pid_n
519
- x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
520
- b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head
521
- dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_states_head
522
- dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
523
- ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head
524
- dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
525
- if HAS_SEQ_IDX:
526
- seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
527
-
528
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
529
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
530
-
531
- chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
532
- # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
533
- offs_k = tl.arange(0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K)
534
- b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate)
535
- dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate)
536
- if BLOCK_SIZE_DSTATE <= 128:
537
- b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate), other=0.0)
538
- dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0)
539
- dstates = dstates.to(b_ptr.dtype.element_ty)
540
- acc = tl.dot(b, dstates)
541
- else:
542
- acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
543
- for k in range(0, dstate, BLOCK_SIZE_K):
544
- b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate - k), other=0.0)
545
- dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0)
546
- dstates = dstates.to(b_ptr.dtype.element_ty)
547
- acc += tl.dot(b, dstates)
548
- b_ptrs += BLOCK_SIZE_K * stride_b_dstate
549
- dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate
550
-
551
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
552
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
553
-
554
- dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
555
- dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
556
- if not HAS_SEQ_IDX:
557
- # scale = tl.exp(dA_cs_last - dA_cs_m)
558
- scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0))
559
- else:
560
- seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
561
- seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
562
- # scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
563
- scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)), 0.0)
564
- acc *= scale[:, None]
565
-
566
- x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
567
- x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
568
- dt_ptrs = dt_ptr + offs_m * stride_dt_csize
569
- dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
570
- ddt = tl.sum(acc * x, axis=1)
571
- # ddA_cs = -(ddt * dt_m)
572
- # Triton 2.2.0 errors if we have the cumsum here, so we just write it out
573
- # then call torch.cumsum outside this kernel.
574
- # ddA_cs = tl.cumsum(ddt * dt_m)
575
- ddA_cs = ddt * dt_m
576
- ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
577
- # tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)
578
- tl.atomic_add(ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1)
579
-
580
-
581
- @triton.autotune(
582
- configs=[
583
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),
584
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
585
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
586
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
587
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
588
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
589
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
590
- triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
591
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),
592
- ],
593
- key=['hdim', 'dstate', 'chunk_size'],
594
- )
595
- @triton.jit
596
- def _chunk_state_varlen_kernel(
597
- # Pointers to matrices
598
- x_ptr, b_ptr, dt_ptr, dA_cumsum_ptr, chunk_states_ptr, cu_seqlens_ptr, states_ptr,
599
- # Matrix dimensions
600
- hdim, dstate, chunk_size,
601
- seqlen, nheads_ngroups_ratio,
602
- # Strides
603
- stride_x_seqlen, stride_x_head, stride_x_hdim,
604
- stride_b_seqlen, stride_b_head, stride_b_dstate,
605
- stride_dt_chunk, stride_dt_head, stride_dt_csize,
606
- stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
607
- stride_chunk_states_chunk, stride_chunk_states_head, stride_chunk_states_hdim, stride_chunk_states_dstate,
608
- stride_states_batch, stride_states_head, stride_states_hdim, stride_states_dstate,
609
- # Meta-parameters
610
- BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
611
- ):
612
- pid_b = tl.program_id(axis=1)
613
- pid_h = tl.program_id(axis=2)
614
- num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
615
- pid_m = tl.program_id(axis=0) // num_pid_n
616
- pid_n = tl.program_id(axis=0) % num_pid_n
617
- end_idx = tl.load(cu_seqlens_ptr + pid_b + 1)
618
- pid_c = (end_idx - 1) // chunk_size
619
- b_ptr += pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head
620
- x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
621
- dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
622
- dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
623
- chunk_states_ptr += pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head
624
-
625
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
626
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
627
- offs_k = tl.arange(0, BLOCK_SIZE_K)
628
- x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen)
629
- b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen)
630
- dt_ptrs = dt_ptr + offs_k * stride_dt_csize
631
- dA_cs_last = tl.load(dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
632
- dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
633
-
634
- chunk_size_limit = end_idx - pid_c * chunk_size
635
- start_idx = tl.load(cu_seqlens_ptr + pid_b)
636
- start_idx_cur = tl.maximum(start_idx - pid_c * chunk_size, 0)
637
-
638
- acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
639
- for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
640
- x = tl.load(x_ptrs, mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k) & (offs_k[None, :] >= start_idx_cur - k), other=0.0)
641
- b = tl.load(b_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate) & (offs_k[:, None] >= start_idx_cur - k), other=0.0).to(tl.float32)
642
- dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
643
- dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
644
- # scale = tl.where((offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k),
645
- # tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0)
646
- scale = tl.where((offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k),
647
- tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * dt_k, 0.0)
648
- b *= scale[:, None]
649
- b = b.to(x_ptr.dtype.element_ty)
650
- acc += tl.dot(x, b)
651
- x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
652
- b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
653
- dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
654
- dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
655
-
656
- # If the sequence starts after the last chunk idx, we don't need to add the contribution from the last chunk
657
- if start_idx < pid_c * chunk_size:
658
- chunk_states_ptrs = chunk_states_ptr + (offs_m[:, None] * stride_chunk_states_hdim + offs_n[None, :] * stride_chunk_states_dstate)
659
- chunk_states = tl.load(chunk_states_ptrs, mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
660
- # scale = tl.where(start_idx < pid_c * chunk_size, tl.exp(dA_cs_last), 0.0)
661
- scale = tl.exp(dA_cs_last)
662
- acc += chunk_states * scale
663
-
664
- states = acc.to(states_ptr.dtype.element_ty)
665
-
666
- states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
667
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
668
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
669
- states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate)
670
- c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
671
- tl.store(states_ptrs, states, mask=c_mask)
672
-
673
-
674
- def _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))):
675
- batch, seqlen, nheads = dt.shape
676
- assert A.shape == (nheads,)
677
- if dt_bias is not None:
678
- assert dt_bias.shape == (nheads,)
679
- nchunks = math.ceil(seqlen / chunk_size)
680
- dt_out = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32)
681
- dA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32)
682
- grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H']))
683
- with torch.cuda.device(dt.device.index):
684
- _chunk_cumsum_fwd_kernel[grid_chunk_cs](
685
- dt, A, dt_bias, dt_out, dA_cumsum,
686
- batch, seqlen, nheads, chunk_size,
687
- dt_limit[0], dt_limit[1],
688
- dt.stride(0), dt.stride(1), dt.stride(2),
689
- A.stride(0),
690
- dt_bias.stride(0) if dt_bias is not None else 0,
691
- dt_out.stride(0), dt_out.stride(2), dt_out.stride(1), dt_out.stride(3),
692
- dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
693
- dt_softplus,
694
- HAS_DT_BIAS=dt_bias is not None,
695
- BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
696
- )
697
- return dA_cumsum, dt_out
698
-
699
-
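For orientation, here is a minimal pure-PyTorch sketch of what `_chunk_cumsum_fwd` computes (shapes and the softplus/clamp handling follow the kernel launch above; this is an illustrative reference under those assumptions, not the kernel itself):

    import math
    import torch
    import torch.nn.functional as F
    from einops import rearrange

    def chunk_cumsum_ref(dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))):
        # dt: (batch, seqlen, nheads), A: (nheads,) -> both outputs: (batch, nheads, nchunks, chunk_size)
        batch, seqlen, nheads = dt.shape
        nchunks = math.ceil(seqlen / chunk_size)
        dt = dt.float()
        if dt_bias is not None:
            dt = dt + dt_bias.float()
        if dt_softplus:
            dt = F.softplus(dt)
        dt = dt.clamp(min=dt_limit[0], max=dt_limit[1])
        # Pad the sequence up to a whole number of chunks; padded steps contribute dt = 0.
        dt = F.pad(dt, (0, 0, 0, nchunks * chunk_size - seqlen))
        dt_out = rearrange(dt, "b (c l) h -> b h c l", l=chunk_size)
        dA_cumsum = torch.cumsum(dt_out * A.float()[None, :, None, None], dim=-1)
        return dA_cumsum, dt_out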
700
- def _chunk_cumsum_bwd(ddA, ddt_out, dt, A, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf")), ddt=None):
701
- batch, seqlen, nheads = dt.shape
702
- _, _, nchunks, chunk_size = ddA.shape
703
- assert ddA.shape == (batch, nheads, nchunks, chunk_size)
704
- assert ddt_out.shape == (batch, nheads, nchunks, chunk_size)
705
- assert A.shape == (nheads,)
706
- if dt_bias is not None:
707
- assert dt_bias.shape == (nheads,)
708
- ddt_bias = torch.empty_like(dt_bias, dtype=torch.float32)
709
- else:
710
- ddt_bias = None
711
- if ddt is not None:
712
- assert ddt.shape == dt.shape
713
- else:
714
- ddt = torch.empty_like(dt)
715
- dA = torch.empty_like(A, dtype=torch.float32)
716
- grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H']))
717
- with torch.cuda.device(dt.device.index):
718
- _chunk_cumsum_bwd_kernel[grid_chunk_cs](
719
- ddA, ddt_out, dt, A, dt_bias, ddt, dA, ddt_bias,
720
- batch, seqlen, nheads, chunk_size,
721
- dt_limit[0], dt_limit[1],
722
- ddA.stride(0), ddA.stride(2), ddA.stride(1), ddA.stride(3),
723
- ddt_out.stride(0), ddt_out.stride(2), ddt_out.stride(1), ddt_out.stride(3),
724
- dt.stride(0), dt.stride(1), dt.stride(2),
725
- A.stride(0),
726
- dt_bias.stride(0) if dt_bias is not None else 0,
727
- ddt.stride(0), ddt.stride(1), ddt.stride(2),
728
- dA.stride(0),
729
- ddt_bias.stride(0) if ddt_bias is not None else 0,
730
- dt_softplus,
731
- HAS_DT_BIAS=dt_bias is not None,
732
- BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
733
- )
734
- return ddt, dA, ddt_bias
735
-
736
-
737
- def _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=None, states=None, states_in_fp32=True):
738
- batch, seqlen, nheads, headdim = x.shape
739
- _, _, nchunks, chunk_size = dt.shape
740
- _, _, ngroups, dstate = B.shape
741
- assert nheads % ngroups == 0
742
- assert B.shape == (batch, seqlen, ngroups, dstate)
743
- assert dt.shape == (batch, nheads, nchunks, chunk_size)
744
- assert dA_cumsum.shape == dt.shape
745
- if seq_idx is not None:
746
- assert seq_idx.shape == (batch, seqlen)
747
- if states is not None:
748
- assert states.shape == (batch, nchunks, nheads, headdim, dstate)
749
- else:
750
- states_dtype = torch.float32 if states_in_fp32 else B.dtype
751
- states = torch.empty((batch, nchunks, nheads, headdim, dstate), device=x.device, dtype=states_dtype)
752
- grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']),
753
- batch * nchunks, nheads)
754
- with torch.cuda.device(x.device.index):
755
- _chunk_state_fwd_kernel[grid](
756
- x, B, states, dt, dA_cumsum, seq_idx,
757
- headdim, dstate, chunk_size,
758
- batch, seqlen, nheads // ngroups,
759
- x.stride(0), x.stride(1), x.stride(2), x.stride(3),
760
- B.stride(0), B.stride(1), B.stride(2), B.stride(-1),
761
- states.stride(0), states.stride(1), states.stride(2), states.stride(3), states.stride(4),
762
- dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
763
- dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
764
- *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
765
- HAS_SEQ_IDX=seq_idx is not None,
766
- )
767
- return states
768
-
769
-
770
- def _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates, dx=None):
771
- batch, seqlen, nheads, headdim = x.shape
772
- _, _, nchunks, chunk_size = dt.shape
773
- _, _, ngroups, dstate = B.shape
774
- assert nheads % ngroups == 0
775
- assert B.shape == (batch, seqlen, ngroups, dstate)
776
- assert dt.shape == (batch, nheads, nchunks, chunk_size)
777
- assert dA_cumsum.shape == dt.shape
778
- assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
779
- if dx is not None:
780
- assert dx.shape == x.shape
781
- else:
782
- dx = torch.empty_like(x)
783
- ddt = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32)
784
- ddA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=dA_cumsum.device, dtype=torch.float32)
785
- grid_dx = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']),
786
- batch * nchunks, nheads)
787
- with torch.cuda.device(x.device.index):
788
- _chunk_state_bwd_dx_kernel[grid_dx](
789
- x, B, dstates, dt, dA_cumsum, dx, ddt, ddA_cumsum,
790
- chunk_size, headdim, dstate,
791
- batch, seqlen, nheads // ngroups,
792
- x.stride(0), x.stride(1), x.stride(2), x.stride(3),
793
- B.stride(0), B.stride(1), B.stride(2), B.stride(-1),
794
- dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4),
795
- dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
796
- dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
797
- dx.stride(0), dx.stride(1), dx.stride(2), dx.stride(3),
798
- ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3),
799
- ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3),
800
- BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
801
- )
802
- return dx, ddt.to(dt.dtype), ddA_cumsum.to(dA_cumsum.dtype)
803
-
804
-
805
- def _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=None, B=None, ngroups=1):
806
- batch, seqlen, nheads, headdim = x.shape
807
- _, _, nchunks, chunk_size = dt.shape
808
- dstate = dstates.shape[-1]
809
- assert dt.shape == (batch, nheads, nchunks, chunk_size)
810
- assert dA_cumsum.shape == dt.shape
811
- assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
812
- if seq_idx is not None:
813
- assert seq_idx.shape == (batch, seqlen)
814
- if B is not None:
815
- assert B.shape == (batch, seqlen, ngroups, dstate)
816
- B_strides = (B.stride(0), B.stride(1), B.stride(2), B.stride(3))
817
- # Use torch.empty since the Triton kernel will call init_to_zero
818
- ddA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32)
819
- ddA_cumsum_strides = (ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3))
820
- else:
821
- B_strides = (0, 0, 0, 0)
822
- ddA_cumsum = None
823
- ddA_cumsum_strides = (0, 0, 0, 0)
824
- nheads_ngroups_ratio = nheads // ngroups
825
- sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
826
- nheads_per_program = max(min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1)
827
- nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program)
828
- dB = torch.empty(batch, seqlen, nsplits, ngroups, dstate, device=x.device, dtype=torch.float32)
829
- grid_db = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']),
830
- batch * nchunks, nsplits * ngroups)
831
- with torch.cuda.device(x.device.index):
832
- _chunk_state_bwd_db_kernel[grid_db](
833
- x, dstates, B, dt, dA_cumsum, seq_idx, dB, ddA_cumsum,
834
- chunk_size, dstate, headdim,
835
- batch, seqlen, nheads, nheads_per_program, ngroups,
836
- x.stride(0), x.stride(1), x.stride(2), x.stride(3),
837
- dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4),
838
- *B_strides,
839
- dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
840
- dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
841
- *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
842
- dB.stride(0), dB.stride(1), dB.stride(2), dB.stride(3), dB.stride(4),
843
- *ddA_cumsum_strides,
844
- HAS_DDA_CS=ddA_cumsum is not None,
845
- HAS_SEQ_IDX=seq_idx is not None,
846
- BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16),
847
- )
848
- dB = dB.sum(2)
849
- if ddA_cumsum is not None:
850
- # The first element of ddA_cumsum is always zero, since that dA_cumsum does not contribute
851
- # to the state of the chunk.
852
- # torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:])
853
- # But it's easier to just do the cumsum for all elements, the result will be the same.
854
- torch.cumsum(ddA_cumsum, dim=-1, out=ddA_cumsum)
855
- return dB if B is None else (dB, ddA_cumsum)
856
-
857
-
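The split over `nsplits` in `_chunk_state_bwd_db` above exists because each (group, split) program accumulates the dB contribution of `nheads_per_program` heads, and the partial results are reduced afterwards with `dB.sum(2)`; the split count is chosen so the launch grid roughly saturates the GPU. A worked example with made-up sizes:

    import math
    import triton

    batch, nchunks, nheads, ngroups = 1, 8, 64, 1      # made-up sizes
    sm_count = 108                                      # e.g. an A100

    nheads_ngroups_ratio = nheads // ngroups                            # 64
    nheads_per_program = max(min(math.ceil(batch * nchunks * nheads / sm_count),
                                 nheads_ngroups_ratio), 1)              # ceil(512 / 108) = 5
    nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program)     # ceil(64 / 5) = 13
    # dB is then allocated as (batch, seqlen, 13, ngroups, dstate) and summed over dim 2.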
858
- def _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=None):
859
- batch, seqlen, nheads, headdim = x.shape
860
- _, _, nchunks, chunk_size = dt.shape
861
- _, _, ngroups, dstate = B.shape
862
- assert nheads % ngroups == 0
863
- assert B.shape == (batch, seqlen, ngroups, dstate)
864
- assert dt.shape == (batch, nheads, nchunks, chunk_size)
865
- assert dA_cumsum.shape == dt.shape
866
- assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
867
- if seq_idx is not None:
868
- assert seq_idx.shape == (batch, seqlen)
869
- # Use torch.empty since the Triton kernel will call init_to_zero
870
- ddA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32)
871
- grid_ddtcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']),
872
- batch * nchunks, nheads)
873
- with torch.cuda.device(x.device.index):
874
- _chunk_state_bwd_ddAcs_stable_kernel[grid_ddtcs](
875
- x, B, dstates, dt, dA_cumsum, seq_idx, ddA_cumsum,
876
- chunk_size, headdim, dstate,
877
- batch, seqlen, nheads // ngroups,
878
- x.stride(0), x.stride(1), x.stride(2), x.stride(3),
879
- B.stride(0), B.stride(1), B.stride(2), B.stride(-1),
880
- dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4),
881
- dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
882
- dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
883
- *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
884
- ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3),
885
- HAS_SEQ_IDX=seq_idx is not None,
886
- BLOCK_SIZE_M=max(triton.next_power_of_2(chunk_size), 16),
887
- BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
888
- )
889
- torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:])
890
- return ddA_cumsum
891
-
892
-
893
- def chunk_state_varlen(B, x, dt, dA_cumsum, cu_seqlens, chunk_states):
894
- total_seqlen, nheads, headdim = x.shape
895
- _, nchunks, chunk_size = dt.shape
896
- _, ngroups, dstate = B.shape
897
- batch = cu_seqlens.shape[0] - 1
898
- cu_seqlens = cu_seqlens.contiguous()
899
- assert nheads % ngroups == 0
900
- assert B.shape == (total_seqlen, ngroups, dstate)
901
- assert dt.shape == (nheads, nchunks, chunk_size)
902
- assert dA_cumsum.shape == dt.shape
903
- assert chunk_states.shape == (nchunks, nheads, headdim, dstate)
904
- states = torch.empty(batch, nheads, headdim, dstate, dtype=chunk_states.dtype, device=chunk_states.device)
905
- grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']),
906
- batch, nheads)
907
- with torch.cuda.device(x.device.index):
908
- _chunk_state_varlen_kernel[grid](
909
- x, B, dt, dA_cumsum, chunk_states, cu_seqlens, states,
910
- headdim, dstate, chunk_size,
911
- total_seqlen, nheads // ngroups,
912
- x.stride(0), x.stride(1), x.stride(2),
913
- B.stride(0), B.stride(1), B.stride(2),
914
- dt.stride(1), dt.stride(0), dt.stride(2),
915
- dA_cumsum.stride(1), dA_cumsum.stride(0), dA_cumsum.stride(2),
916
- chunk_states.stride(0), chunk_states.stride(1), chunk_states.stride(2), chunk_states.stride(3),
917
- states.stride(0), states.stride(1), states.stride(2), states.stride(3),
918
- )
919
- return states
920
-
921
-
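A small usage sketch for the varlen path (shapes follow the assertions in `chunk_state_varlen` above; the concrete sizes are made up for illustration):

    import torch
    import torch.nn.functional as F

    # Three sequences of lengths 5, 9 and 2 packed along one length-16 stream
    # (the batch dimension has already been squeezed away upstream).
    seqlens = torch.tensor([5, 9, 2])
    cu_seqlens = F.pad(seqlens.cumsum(0), (1, 0)).to(torch.int32).cuda()   # tensor([0, 5, 14, 16])

    # With x: (16, nheads, headdim), B: (16, ngroups, dstate), dt and dA_cumsum: (nheads, nchunks, chunk_size),
    # and chunk_states: (nchunks, nheads, headdim, dstate) from the chunked forward pass,
    #   states = chunk_state_varlen(B, x, dt, dA_cumsum, cu_seqlens, chunk_states)
    # returns a (3, nheads, headdim, dstate) tensor holding the final SSM state of each packed sequence.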
922
- class ChunkStateFn(torch.autograd.Function):
923
-
924
- @staticmethod
925
- def forward(ctx, B, x, dt, dA_cumsum, states_in_fp32=True):
926
- batch, seqlen, nheads, headdim = x.shape
927
- _, _, nchunks, chunk_size = dt.shape
928
- assert seqlen <= nchunks * chunk_size
929
- _, _, ngroups, dstate = B.shape
930
- assert B.shape == (batch, seqlen, ngroups, dstate)
931
- assert dt.shape == (batch, nheads, nchunks, chunk_size)
932
- assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
933
- if B.stride(-1) != 1:
934
- B = B.contiguous()
935
- if x.stride(-1) != 1 and x.stride(1) != 1: # Either M or K dimension should be contiguous
936
- x = x.contiguous()
937
- states = _chunk_state_fwd(B, x, dt, dA_cumsum, states_in_fp32=states_in_fp32)
938
- ctx.save_for_backward(B, x, dt, dA_cumsum)
939
- return states
940
-
941
- @staticmethod
942
- def backward(ctx, dstates):
943
- B, x, dt, dA_cumsum = ctx.saved_tensors
944
- batch, seqlen, nheads, headdim = x.shape
945
- _, _, nchunks, chunk_size = dt.shape
946
- _, _, ngroups, dstate = B.shape
947
- assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
948
- if dstates.stride(-1) != 1:
949
- dstates = dstates.contiguous()
950
- dx, ddt, ddA_cumsum = _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates)
951
- dB = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, ngroups=ngroups)
952
- dB = dB.to(B.dtype)
953
- return dB, dx, ddt, ddA_cumsum, None
954
-
955
-
956
- def chunk_state(B, x, dt, dA_cumsum, states_in_fp32=True):
957
- """
958
- Argument:
959
-        B: (batch, seqlen, ngroups, dstate)
960
- x: (batch, seqlen, nheads, headdim)
961
- dt: (batch, nheads, nchunks, chunk_size)
962
- dA_cumsum: (batch, nheads, nchunks, chunk_size)
963
- Return:
964
- states: (batch, nchunks, nheads, headdim, dstate)
965
- """
966
- return ChunkStateFn.apply(B, x, dt, dA_cumsum, states_in_fp32)
967
-
968
-
969
- def chunk_state_ref(B, x, dt, dA_cumsum):
970
- """
971
- Argument:
972
-        B: (batch, seqlen, ngroups, dstate)
973
- x: (batch, seqlen, nheads, headdim)
974
- dt: (batch, nheads, nchunks, chunk_size)
975
- dA_cumsum: (batch, nheads, nchunks, chunk_size)
976
- Return:
977
- states: (batch, nchunks, nheads, headdim, dstate)
978
- """
979
- # Check constraints.
980
- batch, seqlen, nheads, headdim = x.shape
981
- dstate = B.shape[-1]
982
- _, _, nchunks, chunk_size = dt.shape
983
- assert seqlen <= nchunks * chunk_size
984
- assert x.shape == (batch, seqlen, nheads, headdim)
985
- assert dt.shape == (batch, nheads, nchunks, chunk_size)
986
- ngroups = B.shape[2]
987
- assert nheads % ngroups == 0
988
- assert B.shape == (batch, seqlen, ngroups, dstate)
989
- B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups)
990
- assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
991
- if seqlen < nchunks * chunk_size:
992
- x = F.pad(x, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen))
993
- B = F.pad(B, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen))
994
- x = rearrange(x, "b (c l) h p -> b c l h p", l=chunk_size)
995
- B = rearrange(B, "b (c l) ... -> b c l ...", l=chunk_size)
996
- decay_states = torch.exp((dA_cumsum[:, :, :, -1:] - dA_cumsum))
997
- return torch.einsum("bclhn,bhcl,bhcl,bclhp->bchpn", B.to(x.dtype), decay_states.to(x.dtype), dt.to(x.dtype), x)
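A quick numerical sanity check of the Triton path against the reference above (a rough sketch assuming a CUDA device and small made-up sizes; `chunk_state` and `chunk_state_ref` are the functions defined in this file):

    import torch

    batch, seqlen, nheads, headdim = 2, 512, 8, 64
    ngroups, dstate, chunk_size = 1, 16, 256
    nchunks = seqlen // chunk_size

    x = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
    B = torch.randn(batch, seqlen, ngroups, dstate, device="cuda", dtype=torch.bfloat16)
    dt = torch.rand(batch, nheads, nchunks, chunk_size, device="cuda", dtype=torch.float32)
    dA_cumsum = torch.cumsum(-0.1 * dt, dim=-1)        # a plausible (negative) cumsum of A * dt

    states = chunk_state(B, x, dt, dA_cumsum)          # Triton kernel path
    states_ref = chunk_state_ref(B, x, dt, dA_cumsum)  # einsum reference
    print((states - states_ref).abs().max())           # should agree up to bf16 rounding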
build/torch210-cxx11-cu126-x86_64-linux/ops/triton/ssd_combined.py DELETED
@@ -1,998 +0,0 @@
1
- # Copyright (c) 2024, Tri Dao, Albert Gu.
2
-
3
- """We want triton==2.1.0 or 2.2.0 for this
4
- """
5
-
6
- from typing import Optional
7
-
8
- import math
9
- from packaging import version
10
-
11
- import torch
12
- import torch.nn.functional as F
13
- from torch import Tensor
14
- from ...utils.torch import custom_bwd, custom_fwd
15
-
16
- import triton
17
- import triton.language as tl
18
-
19
- from einops import rearrange, repeat
20
-
21
- try:
22
- from causal_conv1d import causal_conv1d_fn
23
- from causal_conv1d.causal_conv1d_interface import causal_conv1d_cuda
24
- except ImportError:
25
- causal_conv1d_fn = None
26
- causal_conv1d_cuda = None
27
-
28
- from .ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd
29
- from .ssd_chunk_state import _chunk_cumsum_fwd, _chunk_cumsum_bwd
30
- from .ssd_chunk_state import _chunk_state_fwd, _chunk_state_bwd_db
31
- from .ssd_chunk_state import _chunk_state_bwd_ddAcs_stable
32
- from .ssd_chunk_state import chunk_state, chunk_state_ref
33
- from .ssd_chunk_state import chunk_state_varlen
34
- from .ssd_state_passing import _state_passing_fwd, _state_passing_bwd
35
- from .ssd_state_passing import state_passing, state_passing_ref
36
- from .ssd_chunk_scan import _chunk_scan_fwd, _chunk_scan_bwd_dz, _chunk_scan_bwd_dstates
37
- from .ssd_chunk_scan import _chunk_scan_bwd_dC, _chunk_scan_bwd_dcb
38
- from .ssd_chunk_scan import _chunk_scan_bwd_ddAcs_stable
39
- from .ssd_chunk_scan import chunk_scan, chunk_scan_ref
40
- from .ssd_chunk_scan import _chunk_scan_bwd_ddAcs_prev
41
- from .layernorm_gated import rmsnorm_fn, _layer_norm_fwd, _layer_norm_bwd
42
- from .k_activations import _swiglu_fwd, _swiglu_bwd
43
-
44
- TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0")
45
-
46
-
47
- def init_to_zero(names):
48
- return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]
49
-
50
-
51
- def rearrange_and_update_stride(tensor, pattern=None, dim=2):
52
- # ensure tensor.stride(dim) is a multiple of eight after rearranging according to pattern,
53
- # if not call contiguous(), rearrange only if pattern is not None
54
- tensor_rearranged = rearrange(tensor, pattern) if pattern is not None else tensor
55
- return tensor_rearranged.contiguous() if tensor_rearranged.stride(dim) % 8 != 0 else tensor_rearranged
56
-
57
-
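Two notes on the helpers above (a sketch under stated assumptions, not part of the original file): `init_to_zero` is installed as an autotune `pre_hook` because several kernels below accumulate into their outputs with `tl.atomic_add`, so buffers allocated with `torch.empty` must be re-zeroed before every autotuning run; `rearrange_and_update_stride` falls back to `.contiguous()` when the requested stride is not a multiple of 8, presumably to keep that dimension's stride well aligned for the Triton matmuls. For example:

    import torch

    buf = torch.randn(2, 100, 70, device="cuda")
    z = buf[..., :64]                                  # a strided view: z.stride(1) == 70
    z_aligned = rearrange_and_update_stride(z, dim=1)
    assert z_aligned.stride(1) == 64                   # copied, stride now a multiple of 8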
58
- @triton.autotune(
59
- configs=[
60
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr"])),
61
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
62
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
63
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
64
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
65
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
66
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
67
- triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
68
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
69
- ],
70
- key=['chunk_size', 'hdim', 'dstate'],
71
- )
72
- @triton.jit
73
- def _chunk_scan_chunk_state_bwd_dx_kernel(
74
- # Pointers to matrices
75
- x_ptr, cb_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, D_ptr,
76
- b_ptr, dstates_ptr,
77
- dx_ptr, ddt_ptr, dD_ptr,
78
- # Matrix dimensions
79
- chunk_size, hdim, dstate,
80
- batch, seqlen, nheads_ngroups_ratio,
81
- # Strides
82
- stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
83
- stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k,
84
- stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,
85
- stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
86
- stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
87
- stride_seq_idx_batch, stride_seq_idx_seqlen,
88
- stride_D_head,
89
- stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate,
90
- stride_dstates_batch, stride_dstates_chunk, stride_dstates_head, stride_dstates_hdim, stride_dstates_dstate,
91
- stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim,
92
- stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize,
93
- stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_csize, stride_dD_hdim,
94
- # Meta-parameters
95
- HAS_D: tl.constexpr,
96
- D_HAS_HDIM: tl.constexpr,
97
- HAS_SEQ_IDX: tl.constexpr,
98
- BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
99
- BLOCK_SIZE_DSTATE: tl.constexpr,
100
- IS_TRITON_22: tl.constexpr,
101
- ):
102
- pid_bc = tl.program_id(axis=1)
103
- pid_c = pid_bc // batch
104
- pid_b = pid_bc - pid_c * batch
105
- pid_h = tl.program_id(axis=2)
106
- num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
107
- pid_m = tl.program_id(axis=0) // num_pid_n
108
- pid_n = tl.program_id(axis=0) % num_pid_n
109
- x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
110
- cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head
111
- dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head
112
- dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
113
- ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head
114
- dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
115
- b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head
116
- dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_dstates_head
117
- if HAS_SEQ_IDX:
118
- seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
119
-
120
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
121
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
122
-
123
- chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
124
-
125
- acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
126
-
127
- dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
128
-
129
- dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
130
- if not HAS_SEQ_IDX:
131
- # scale = tl.exp(dA_cs_last - dA_cs_m)
132
- scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0))
133
- else:
134
- seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
135
- seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
136
- # scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
137
- scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)), 0.0)
138
- # Might be faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
139
- # However, we're getting error with the Triton compiler 2.1.0 for that code path:
140
- # Unexpected mma -> mma layout conversion
141
- # Triton 2.2.0 fixes this
142
- offs_dstate = tl.arange(0, BLOCK_SIZE_DSTATE if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K)
143
- b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_dstate[None, :] * stride_b_dstate)
144
- dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_dstates_hdim + offs_dstate[:, None] * stride_dstates_dstate)
145
- if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128:
146
- b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate), other=0.0)
147
- dstates = tl.load(dstates_ptrs, mask=(offs_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0)
148
- dstates = dstates.to(b_ptr.dtype.element_ty)
149
- acc = tl.dot(b, dstates) * scale[:, None]
150
- else:
151
- for k in range(0, dstate, BLOCK_SIZE_K):
152
- b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate - k), other=0.0)
153
- dstates = tl.load(dstates_ptrs, mask=(offs_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0)
154
- dstates = dstates.to(b_ptr.dtype.element_ty)
155
- acc += tl.dot(b, dstates)
156
- b_ptrs += BLOCK_SIZE_K * stride_b_dstate
157
- dstates_ptrs += BLOCK_SIZE_K * stride_dstates_dstate
158
- acc *= scale[:, None]
159
-
160
- # x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
161
- # x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
162
- # dt_ptrs = dt_ptr + offs_m * stride_dt_csize
163
- # dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
164
- # ddt = tl.sum(acc * x, axis=1) * dt_m
165
- # ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
166
- # tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)
167
-
168
- offs_k = tl.arange(0, BLOCK_SIZE_K)
169
- cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k)
170
- dout_ptrs = dout_ptr + (offs_k[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim)
171
- dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
172
- K_MAX = chunk_size_limit
173
- K_MIN = pid_m * BLOCK_SIZE_M
174
- cb_ptrs += K_MIN * stride_cb_csize_k
175
- dout_ptrs += K_MIN * stride_dout_seqlen
176
- dA_cumsum_ptrs += K_MIN * stride_dA_cs_csize
177
- for k in range(K_MIN, K_MAX, BLOCK_SIZE_K):
178
- k = tl.multiple_of(k, BLOCK_SIZE_K)
179
- # For some reason setting mask to (offs_m[:, None] < chunk_size_limit) is much slower
180
- cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k), other=0.0)
181
- dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0)
182
- dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32)
183
- # cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])
184
- cb *= tl.exp(tl.minimum((dA_cs_k[None, :] - dA_cs_m[:, None]), 0.0))
185
- # If we don't have the (k + offs_k[None, :] < K_MAX) mask, for indices outside this range,
186
- # we might have dA_cs_m = 0.0 and dA_cs_k very negative, and tl.exp will return inf.
187
- # Multiplying with cb, which is 0.0 outside the range, will make the result NaN.
188
- # This will cause NaN in acc, and hence NaN in dx and ddt.
189
- mask = (k + offs_k[None, :] >= offs_m[:, None]) & (k + offs_k[None, :] < K_MAX)
190
- cb = tl.where(mask, cb, 0.0)
191
- cb = cb.to(dout_ptr.dtype.element_ty)
192
- acc += tl.dot(cb, dout)
193
- cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k
194
- dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen
195
- dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
196
-
197
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
198
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
199
- dt_ptrs = dt_ptr + offs_m * stride_dt_csize
200
- dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
201
- dx = acc * dt_m[:, None]
202
- dx_ptr += pid_b * stride_dx_batch + pid_c * chunk_size * stride_dx_seqlen + pid_h * stride_dx_head
203
- dx_ptrs = dx_ptr + (offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim)
204
- if HAS_D:
205
- dout_res_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim)
206
- dout_res = tl.load(dout_res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
207
- if D_HAS_HDIM:
208
- D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32)
209
- else:
210
- D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
211
- dx += dout_res * D
212
- tl.store(dx_ptrs, dx, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim))
213
-
214
- x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
215
- x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
216
- if HAS_D:
217
- dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize
218
- if D_HAS_HDIM:
219
- dD_ptrs = dD_ptr + offs_n * stride_dD_hdim
220
- dD = tl.sum(dout_res * x, axis=0)
221
- tl.store(dD_ptrs, dD, mask=offs_n < hdim)
222
- else:
223
- dD = tl.sum(dout_res * x)
224
- tl.store(dD_ptr, dD)
225
- ddt = tl.sum(acc * x, axis=1)
226
- ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
227
- tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)
228
-
229
-
230
- def _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=None, seq_idx=None, dx=None):
231
- batch, seqlen, nheads, headdim = x.shape
232
- _, _, nchunks, chunk_size = dt.shape
233
- _, _, ngroups, dstate = B.shape
234
- assert nheads % ngroups == 0
235
- assert B.shape == (batch, seqlen, ngroups, dstate)
236
- assert CB.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
237
- assert dt.shape == (batch, nheads, nchunks, chunk_size)
238
- assert dA_cumsum.shape == dt.shape
239
- assert dout.shape == x.shape
240
- assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
241
- if seq_idx is not None:
242
- assert seq_idx.shape == (batch, seqlen)
243
- if D is not None:
244
- assert D.shape == (nheads, headdim) or D.shape == (nheads,)
245
- assert D.stride(-1) == 1
246
- BLOCK_SIZE_min = 32
247
- dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_min), batch, nchunks, nheads,
248
- headdim if D.dim() == 2 else 1, device=D.device, dtype=torch.float32)
249
- else:
250
- dD = None
251
- dD_strides = ((dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4))
252
- if D is not None else (0, 0, 0, 0, 0))
253
- if dx is None:
254
- dx = torch.empty_like(x)
255
- else:
256
- assert dx.shape == x.shape
257
- ddt = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32)
258
- grid_dx = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']),
259
- batch * nchunks, nheads)
260
- with torch.cuda.device(x.device.index):
261
- _chunk_scan_chunk_state_bwd_dx_kernel[grid_dx](
262
- x, CB, dout, dt, dA_cumsum, seq_idx, D, B, dstates, dx, ddt, dD,
263
- chunk_size, headdim, dstate,
264
- batch, seqlen, nheads // ngroups,
265
- x.stride(0), x.stride(1), x.stride(2), x.stride(3),
266
- CB.stride(0), CB.stride(1), CB.stride(2), CB.stride(-1), CB.stride(-2),
267
- dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),
268
- dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
269
- dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
270
- *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
271
- D.stride(0) if D is not None else 0,
272
- B.stride(0), B.stride(1), B.stride(2), B.stride(3),
273
- dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4),
274
- dx.stride(0), dx.stride(1), dx.stride(2), dx.stride(3),
275
- ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3),
276
- dD_strides[1], dD_strides[2], dD_strides[3], dD_strides[0], dD_strides[4],
277
- D is not None,
278
- D.dim() == 2 if D is not None else True,
279
- HAS_SEQ_IDX=seq_idx is not None,
280
- BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
281
- IS_TRITON_22=TRITON_22
282
- )
283
- if D is not None:
284
- BLOCK_SIZE_actual = _chunk_scan_chunk_state_bwd_dx_kernel.best_config.kwargs["BLOCK_SIZE_M"]
285
- n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual
286
- dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype)
287
- if D.dim() == 1:
288
- dD = rearrange(dD, "h 1 -> h")
289
- return dx, ddt.to(dtype=dt.dtype), dD
290
-
291
-
292
- def _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float("inf"))):
293
- batch, seqlen, nheads, headdim = x.shape
294
- _, _, ngroups, dstate = B.shape
295
- assert nheads % ngroups == 0
296
- assert B.shape == (batch, seqlen, ngroups, dstate)
297
- assert x.shape == (batch, seqlen, nheads, headdim)
298
- assert dt.shape == (batch, seqlen, nheads)
299
- assert A.shape == (nheads,)
300
- assert C.shape == B.shape
301
- if z is not None:
302
- assert z.shape == x.shape
303
- if D is not None:
304
- assert D.shape == (nheads, headdim) or D.shape == (nheads,)
305
- if seq_idx is not None:
306
- assert seq_idx.shape == (batch, seqlen)
307
- if B.stride(-1) != 1:
308
- B = B.contiguous()
309
- if C.stride(-1) != 1:
310
- C = C.contiguous()
311
- if x.stride(-1) != 1 and x.stride(1) != 1: # Either M or K dimension should be contiguous
312
- x = x.contiguous()
313
- if z is not None and z.stride(-1) != 1 and z.stride(1) != 1: # Either M or K dimension should be contiguous
314
- z = z.contiguous()
315
- if D is not None and D.stride(-1) != 1:
316
- D = D.contiguous()
317
- if initial_states is not None:
318
- assert initial_states.shape == (batch, nheads, headdim, dstate)
319
- # # (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, nheads, chunk_size, chunk_size)
320
- # dA_cumsum_tmp0, dt_tmp0 = _chunk_cumsum_fwd(dt[:, :147], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
321
- # dA_cumsum_tmp1, dt_tmp1 = _chunk_cumsum_fwd(dt[:, 147:], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
322
- # dA_cumsum_tmp2, dt_tmp2 = _chunk_cumsum_fwd(dt[:, 147:256], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
323
- dA_cumsum, dt = _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit)
324
- states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True)
325
- # states_tmp0 = _chunk_state_fwd(B[:, :147], x[:, :147], dt_tmp0, dA_cumsum_tmp0, states_in_fp32=True)
326
- # states_tmp1 = _chunk_state_fwd(B[:, 147:], x[:, 147:], dt_tmp1, dA_cumsum_tmp1, states_in_fp32=True)
327
- # states_tmp2 = _chunk_state_fwd(B[:, 147:256], x[:, 147:256], dt_tmp2, dA_cumsum_tmp2, states_in_fp32=True)
328
- states, final_states = _state_passing_fwd(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1],
329
- initial_states=rearrange(initial_states, "... p n -> ... (p n)") if initial_states is not None else None,
330
- seq_idx=seq_idx, chunk_size=chunk_size, out_dtype=C.dtype)
331
- states, final_states = [rearrange(t, "... (p n) -> ... p n", n=dstate) for t in [states, final_states]]
332
- # states_tmp0 = rearrange(_state_passing_fwd(rearrange(states_tmp0, "... p n -> ... (p n)"), dA_cumsum_tmp0[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate)
333
- # states_tmp1 = rearrange(_state_passing_fwd(rearrange(states_tmp1, "... p n -> ... (p n)"), dA_cumsum_tmp1[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate)
334
- CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32)
335
- out, out_x = _chunk_scan_fwd(CB, x, dt, dA_cumsum, C, states, D=D, z=z, seq_idx=seq_idx)
336
- if cu_seqlens is None:
337
- return out, out_x, dt, dA_cumsum, states, final_states
338
- else:
339
- assert batch == 1, "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1"
340
- varlen_states = chunk_state_varlen(B.squeeze(0), x.squeeze(0), dt.squeeze(0), dA_cumsum.squeeze(0),
341
- cu_seqlens, states.squeeze(0))
342
- return out, out_x, dt, dA_cumsum, states, final_states, varlen_states
343
-
344
-
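For orientation, a minimal call into the forward pass above (a sketch with made-up sizes; the argument shapes follow the assertions at the top of `_mamba_chunk_scan_combined_fwd`):

    import torch

    batch, seqlen, nheads, headdim = 1, 1024, 24, 64
    ngroups, dstate, chunk_size = 1, 128, 256
    dtype = torch.bfloat16

    x  = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)
    dt = torch.rand(batch, seqlen, nheads, device="cuda", dtype=torch.float32)
    A  = -torch.rand(nheads, device="cuda", dtype=torch.float32)
    B  = torch.randn(batch, seqlen, ngroups, dstate, device="cuda", dtype=dtype)
    C  = torch.randn(batch, seqlen, ngroups, dstate, device="cuda", dtype=dtype)
    D  = torch.randn(nheads, device="cuda", dtype=torch.float32)

    out, out_x, dt_chunked, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd(
        x, dt, A, B, C, chunk_size, D=D, dt_softplus=True,
    )
    # out: (batch, seqlen, nheads, headdim); final_states: (batch, nheads, headdim, dstate)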
345
- def _mamba_chunk_scan_combined_bwd(dout, x, dt, A, B, C, out, chunk_size, D=None, z=None,
346
- dt_bias=None, initial_states=None, dfinal_states=None, seq_idx=None, dt_softplus=False,
347
- dt_limit=(0.0, float("inf")),
348
- dx=None, ddt=None, dB=None, dC=None, dz=None, recompute_output=False):
349
- if dout.stride(-1) != 1:
350
- dout = dout.contiguous()
351
- batch, seqlen, nheads, headdim = x.shape
352
- nchunks = math.ceil(seqlen / chunk_size)
353
- _, _, ngroups, dstate = B.shape
354
- assert dout.shape == (batch, seqlen, nheads, headdim)
355
- assert dt.shape == (batch, seqlen, nheads)
356
- assert A.shape == (nheads,)
357
- assert nheads % ngroups == 0
358
- assert B.shape == (batch, seqlen, ngroups, dstate)
359
- assert C.shape == B.shape
360
- assert out.shape == x.shape
361
- if initial_states is not None:
362
- assert initial_states.shape == (batch, nheads, headdim, dstate)
363
- if seq_idx is not None:
364
- assert seq_idx.shape == (batch, seqlen)
365
- if dx is not None:
366
- assert dx.shape == x.shape
367
- if dB is not None:
368
- assert dB.shape == B.shape
369
- dB_given = dB
370
- else:
371
- dB_given = torch.empty_like(B)
372
- if dC is not None:
373
- assert dC.shape == C.shape
374
- dC_given = dC
375
- else:
376
- dC_given = torch.empty_like(C)
377
- if dz is not None:
378
- assert z is not None
379
- assert dz.shape == z.shape
380
- if ddt is not None:
381
- assert ddt.shape == dt.shape
382
- ddt_given = ddt
383
- else:
384
- ddt_given = torch.empty_like(dt)
385
- # TD: For some reason Triton (2.1.0 and 2.2.0) errors with
386
- # "[CUDA]: invalid device context" (e.g. during the varlen test), and cloning makes it work. Unclear why.
387
- dt_in = dt.clone()
388
- dA_cumsum, dt = _chunk_cumsum_fwd(dt_in, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus,
389
- dt_limit=dt_limit)
390
- CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32)
391
- states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True)
392
- states, _ = _state_passing_fwd(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1],
393
- initial_states=rearrange(initial_states, "... p n -> ... (p n)") if initial_states is not None else None,
394
- seq_idx=seq_idx, chunk_size=chunk_size)
395
- states = rearrange(states, "... (p n) -> ... p n", n=dstate)
396
- if z is not None:
397
- dz, dout, dD, *rest = _chunk_scan_bwd_dz(x, z, out, dout, chunk_size=chunk_size, has_ddAcs=False, D=D, dz=dz, recompute_output=recompute_output)
398
- outz = rest[0] if recompute_output else out
399
- else:
400
- dz = None
401
- outz = out
402
- dstates = _chunk_scan_bwd_dstates(C, dA_cumsum, dout, seq_idx=seq_idx, dtype=states.dtype)
403
- # dstates has length nchunks, containing the gradient to initial states at index 0 and
404
- # gradient to the states of chunk (nchunks - 2) at index (nchunks - 1)
405
- # Do computation in fp32 but convert dstates and states to fp16/bf16 since dstates and states
406
- # will be used in matmul in the next kernels.
407
- dstates, ddA_chunk_cumsum, dinitial_states, states = _state_passing_bwd(
408
- rearrange(states, "... p n -> ... (p n)"),
409
- dA_cumsum[:, :, :, -1],
410
- rearrange(dstates, "... p n -> ... (p n)"),
411
- dfinal_states=rearrange(dfinal_states, "... p n -> ... (p n)") if dfinal_states is not None else None,
412
- seq_idx=seq_idx,
413
- has_initial_states=initial_states is not None,
414
- dstates_dtype=x.dtype,
415
- states_dtype=x.dtype,
416
- chunk_size=chunk_size,
417
- )
418
- # dstates has length nchunks, containing the gradient to states of chunk 0 at index 0 and
419
- # gradient to the final states at index (nchunks - 1)
420
- # states has length nchunks, containing the initial states at index 0 and the state for chunk (nchunks - 2) at index (nchunks - 1)
421
- # The final states is not stored.
422
- states = rearrange(states, "... (p n) -> ... p n", n=dstate)
423
- dstates = rearrange(dstates, "... (p n) -> ... p n", n=dstate)
424
- dinitial_states = rearrange(dinitial_states, "... (p n) -> ... p n", n=dstate) if dinitial_states is not None else None
425
- dx, ddt, dD_from_x = _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=D, seq_idx=seq_idx, dx=dx)
426
- # dB = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=seq_idx, ngroups=ngroups)
427
- dB, ddA_next = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=seq_idx, B=B, ngroups=ngroups)
428
- # dC = _chunk_scan_bwd_dC(states[:, :-1].to(x.dtype), dA_cumsum, dout, seq_idx=seq_idx, ngroups=ngroups)
429
- dC, ddA_cumsum_prev = _chunk_scan_bwd_dC(states.to(x.dtype), dA_cumsum, dout, seq_idx=seq_idx, C=C, ngroups=ngroups)
430
- # Computing ddA with the dcb kernel is much slower, so we're not using it for now
431
- dCB = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=seq_idx, ngroups=ngroups)
432
- # dCB, ddA_tmp = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=seq_idx, CB=CB, ngroups=ngroups)
433
- dCB = dCB.to(CB.dtype)
434
- _bmm_chunk_bwd(C, dCB, residual=dB, out=dB_given)
435
- _bmm_chunk_bwd(B, rearrange(dCB, "... l s -> ... s l"), residual=dC, out=dC_given)
436
- # If we have z, then dout_x is recomputed in fp32 so dD = (dout_x * x).sum() is more accurate
437
- # than dD_from_x = (dout_x * x).sum() where dout_x is in fp16/bf16
438
- if z is None:
439
- dD = dD_from_x
440
- # Formula for ddA_cumsum, assuming out is the output of the forward pass before adding x * D.
441
- # ddA_cumsum = torch.einsum("bclhp,bclhp->bhcl", out.float(), dout.float()) - ddt * dt
442
- # However, this is numerically unstable: when we do the reverse cumsum on ddA_cumsum, there might
443
- # be a lot of underflow.
444
-
445
- # This is already done as part of bwd_dC kernel
446
- # ddA_cumsum_prev = _chunk_scan_bwd_ddAcs_prev(states[:, :-1], C, dout, dA_cumsum, seq_idx=seq_idx)
447
- ddA_cumsum_prev[..., -1] += ddA_chunk_cumsum
448
- ddA_prev = ddA_cumsum_prev.flip([-1]).cumsum(dim=-1).flip([-1])
449
- # This is already done as part of bwd_dB kernel
450
- # ddA_next = _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=seq_idx)
451
- # We don't need to pass in seq_idx because CB also zeros out entries where seq_idx[i] != seq_idx[j]
452
- ddA = _chunk_scan_bwd_ddAcs_stable(x, dt, dA_cumsum, dout, CB)
453
- ddA += ddA_next + ddA_prev
454
-
455
- ddt_given, dA, ddt_bias = _chunk_cumsum_bwd(ddA, ddt, dt_in, A, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit, ddt=ddt_given)
456
-
457
- # These 2 lines are just to test ddt and dA being computed by old code
458
- # _, dA = selective_scan_bwd(dout, x, dt, A, B, C, D=D.float(), z=z)
459
- # ddt_given.copy_(ddt)
460
-
461
- return_vals = (dx, ddt_given, dA, dB_given, dC_given, dD, dz, ddt_bias, dinitial_states)
462
- return return_vals if not recompute_output else (*return_vals, outz)
463
-
464
-
465
- def selective_scan_bwd(dout, x, dt, A, B, C, D=None, z=None):
466
- """
467
- Argument:
468
- dout: (batch, seqlen, nheads, headdim)
469
- x: (batch, seqlen, nheads, headdim)
470
- dt: (batch, nheads, nchunks, chunk_size) or (batch, nheads, headdim, nchunks, chunk_size)
471
- A: (nheads) or (dim, dstate)
472
- B: (batch, seqlen, ngroups, dstate)
473
- C: (batch, seqlen, ngroups, dstate)
474
- D: (nheads, headdim) or (nheads,)
475
- z: (batch, seqlen, nheads, headdim)
476
- Return:
477
- ddt, dA: gradients of the loss with respect to dt and A
478
- """
479
- import selective_scan
480
-
481
- batch, seqlen, nheads, headdim = x.shape
482
- chunk_size = dt.shape[-1]
483
- _, _, ngroups, dstate = B.shape
484
- assert nheads % ngroups == 0
485
- x = rearrange(x, "b l h p -> b (h p) l")
486
- squeeze_dt = dt.dim() == 4
487
- if dt.dim() == 4:
488
- dt = repeat(dt, "b h c l -> b h p c l", p=headdim)
489
- dt = rearrange(dt, "b h p c l -> b (h p) (c l)", p=headdim)
490
- squeeze_A = A.dim() == 1
491
- if A.dim() == 1:
492
- A = repeat(A, "h -> (h p) n", p=headdim, n=dstate).to(dtype=torch.float32)
493
- else:
494
- A = A.to(dtype=torch.float32)
495
- B = rearrange(B, "b l g n -> b g n l")
496
- C = rearrange(C, "b l g n -> b g n l")
497
- if D is not None:
498
- if D.dim() == 2:
499
- D = rearrange(D, "h p -> (h p)")
500
- else:
501
- D = repeat(D, "h -> (h p)", p=headdim)
502
- if z is not None:
503
- z = rearrange(z, "b l h p -> b (h p) l")
504
-
505
- if x.stride(-1) != 1:
506
- x = x.contiguous()
507
- if dt.stride(-1) != 1:
508
- dt = dt.contiguous()
509
- if D is not None:
510
- D = D.contiguous()
511
- if B.stride(-1) != 1:
512
- B = B.contiguous()
513
- if C.stride(-1) != 1:
514
- C = C.contiguous()
515
- if z is not None and z.stride(-1) != 1:
516
- z = z.contiguous()
517
- _, intermediate, *rest = selective_scan.fwd(x, dt.to(dtype=x.dtype), A, B, C, D, z, None, False)
518
- if z is not None:
519
- out = rest[0]
520
- else:
521
- out = None
522
-
523
- dout = rearrange(dout, "b l h p -> b (h p) l")
524
-
525
- if dout.stride(-1) != 1:
526
- dout = dout.contiguous()
527
- # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
528
- # backward of selective_scan with the backward of chunk).
529
- # Here we just pass in None and dz will be allocated in the C++ code.
530
- _, ddt, dA, *rest = selective_scan.bwd(
531
- x, dt.to(dtype=x.dtype), A, B, C, D, z, None, dout, intermediate, out, None, False,
532
- False # option to recompute out_z, not used here
533
- )
534
- ddt = rearrange(ddt, "b (h p) (c l) -> b h p c l", p=headdim, l=chunk_size)
535
- if squeeze_dt:
536
- ddt = ddt.float().sum(dim=2)
537
- if squeeze_A:
538
- dA = rearrange(dA, "(h p) n -> h p n", p=headdim).sum(dim=(1, 2))
539
- return ddt, dA
540
-
541
-
542
- class MambaChunkScanCombinedFn(torch.autograd.Function):
543
-
544
- @staticmethod
545
- def forward(ctx, x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float("inf")), return_final_states=False, return_varlen_states=False):
546
- ctx.dt_dtype = dt.dtype
547
- if not return_varlen_states:
548
- cu_seqlens = None
549
- else:
550
- assert cu_seqlens is not None, "cu_seqlens must be provided if return_varlen_states is True"
551
- out, out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, cu_seqlens=cu_seqlens, dt_softplus=dt_softplus, dt_limit=dt_limit)
552
- ctx.save_for_backward(out if z is None else out_x, x, dt, dA_cumsum, A, B, C, D, z, dt_bias, initial_states, seq_idx)
553
- ctx.dt_softplus = dt_softplus
554
- ctx.chunk_size = chunk_size
555
- ctx.dt_limit = dt_limit
556
- ctx.return_final_states = return_final_states
557
- ctx.return_varlen_states = return_varlen_states
558
- if not return_varlen_states:
559
- return out if not return_final_states else (out, final_states)
560
- else:
561
- varlen_states = rest[0]
562
- return (out, varlen_states) if not return_final_states else (out, final_states, varlen_states)
563
-
564
- @staticmethod
565
- def backward(ctx, dout, *args):
566
- out, x, dt, dA_cumsum, A, B, C, D, z, dt_bias, initial_states, seq_idx = ctx.saved_tensors
567
- assert not ctx.return_varlen_states, "return_varlen_states is not supported in backward"
568
- dfinal_states = args[0] if ctx.return_final_states else None
569
- dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states = _mamba_chunk_scan_combined_bwd(dout, x, dt, A, B, C, out, ctx.chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, dfinal_states=dfinal_states, seq_idx=seq_idx, dt_softplus=ctx.dt_softplus, dt_limit=ctx.dt_limit)
570
- return dx, ddt, dA, dB, dC, None, dD, dz, ddt_bias, dinitial_states, None, None, None, None, None, None
571
-
572
-
573
- def mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float("inf")), return_final_states=False, return_varlen_states=False):
574
- """
575
- Argument:
576
- x: (batch, seqlen, nheads, headdim)
577
- dt: (batch, seqlen, nheads)
578
- A: (nheads)
579
- B: (batch, seqlen, ngroups, dstate)
580
- C: (batch, seqlen, ngroups, dstate)
581
- chunk_size: int
582
- D: (nheads, headdim) or (nheads,)
583
- z: (batch, seqlen, nheads, headdim)
584
- dt_bias: (nheads,)
585
- initial_states: (batch, nheads, headdim, dstate)
586
- seq_idx: (batch, seqlen)
587
- cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True
588
- dt_softplus: Whether to apply softplus to dt
589
- Return:
590
- out: (batch, seqlen, nheads, headdim)
591
- """
592
- return MambaChunkScanCombinedFn.apply(x, dt, A, B, C, chunk_size, D, z, dt_bias, initial_states, seq_idx, cu_seqlens, dt_softplus, dt_limit, return_final_states, return_varlen_states)
593
-
594
-
595
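# --- Added illustration (not part of the deleted file): a minimal usage sketch for
# mamba_chunk_scan_combined, assuming a CUDA device with the Triton kernels above.
# All sizes below are made up for the example.
import torch

batch, seqlen, nheads, headdim, ngroups, dstate, chunk_size = 2, 512, 8, 64, 1, 128, 256
x = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
dt = torch.rand(batch, seqlen, nheads, device="cuda", dtype=torch.bfloat16)
A = -torch.rand(nheads, device="cuda")  # A is kept negative (decay) and in fp32
B = torch.randn(batch, seqlen, ngroups, dstate, device="cuda", dtype=torch.bfloat16)
C = torch.randn(batch, seqlen, ngroups, dstate, device="cuda", dtype=torch.bfloat16)

out = mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, dt_softplus=True)
# out: (batch, seqlen, nheads, headdim). Pass return_final_states=True to also get
# the final SSM states of shape (batch, nheads, headdim, dstate).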
- def mamba_chunk_scan(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, dt_softplus=False):
596
- """
597
- Argument:
598
- x: (batch, seqlen, nheads, headdim)
599
- dt: (batch, seqlen, nheads)
600
- A: (nheads)
601
- B: (batch, seqlen, ngroups, dstate)
602
- C: (batch, seqlen, ngroups, dstate)
603
- D: (nheads, headdim) or (nheads,)
604
- z: (batch, seqlen, nheads, headdim)
605
- dt_bias: (nheads,)
606
- Return:
607
- out: (batch, seqlen, nheads, headdim)
608
- """
609
- batch, seqlen, nheads, headdim = x.shape
610
- dstate = B.shape[-1]
611
- if seqlen % chunk_size != 0:
612
- dt = F.pad(dt, (0, 0, 0, chunk_size - seqlen % chunk_size))
613
- dt = rearrange(dt, "b (c l) h -> b h c l", l=chunk_size)
614
- dt = dt.float() # We want high precision for this before cumsum
615
- if dt_bias is not None:
616
- dt = dt + rearrange(dt_bias, "h -> h 1 1")
617
- if dt_softplus:
618
- dt = F.softplus(dt)
619
- dA = dt * rearrange(A, "h -> h 1 1")
620
- dA_cumsum = torch.cumsum(dA, dim=-1)
622
- # 1. Compute the state for each chunk
623
- states = chunk_state(B, x, dt, dA_cumsum, states_in_fp32=True)
624
- # 2. Pass the state to all the chunks by weighted cumsum.
625
- states = rearrange(state_passing(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1])[0],
626
- "... (p n) -> ... p n", n=dstate)
627
- # 3. Compute the output for each chunk
628
- out = chunk_scan(B, C, x, dt, dA_cumsum, states, D=D, z=z)
629
- return out
630
-
631
-
632
- def ssd_chunk_scan_combined_ref(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, dt_softplus=False):
633
- """
634
- Argument:
635
- x: (batch, seqlen, nheads, headdim)
636
- dt: (batch, seqlen, nheads)
637
- A: (nheads)
638
- B: (batch, seqlen, ngroups, dstate)
639
- C: (batch, seqlen, ngroups, dstate)
640
- D: (nheads, headdim) or (nheads,)
641
- z: (batch, seqlen, nheads, headdim)
642
- dt_bias: (nheads,)
643
- Return:
644
- out: (batch, seqlen, nheads, headdim)
645
- """
646
- batch, seqlen, nheads, headdim = x.shape
647
- dstate = B.shape[-1]
648
- if seqlen % chunk_size != 0:
649
- dt = F.pad(dt, (0, 0, 0, chunk_size - seqlen % chunk_size))
650
- dt = rearrange(dt, "b (c l) h -> b h c l", l=chunk_size)
651
- dt = dt.float() # We want high precision for this before cumsum
652
- if dt_bias is not None:
653
- dt = dt + rearrange(dt_bias, "h -> h 1 1")
654
- if dt_softplus:
655
- dt = F.softplus(dt)
656
- dA = dt * rearrange(A, "h -> h 1 1")
657
- dA_cumsum = torch.cumsum(dA, dim=-1)
658
- # 1. Compute the state for each chunk
659
- states = chunk_state_ref(B, x, dt, dA_cumsum)
660
- states_dtype = states.dtype
661
- if states.dtype not in [torch.float32, torch.float64]:
662
- states = states.to(torch.float32)
663
- # 2. Pass the state to all the chunks by weighted cumsum.
664
- # state_passing_ref is much less numerically stable
665
- states = rearrange(state_passing_ref(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1])[0],
666
- "... (p n) -> ... p n", n=dstate)
667
- states = states.to(states_dtype)
668
- # 3. Compute the output for each chunk
669
- out = chunk_scan_ref(B, C, x, dt, dA_cumsum, states, D=D, z=z)
670
- return out
671
-
672
-
673
- def ssd_selective_scan(x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))):
674
- """
675
- Argument:
676
- x: (batch, seqlen, nheads, headdim)
677
- dt: (batch, seqlen, nheads) or (batch, seqlen, nheads, headdim)
678
- A: (nheads) or (dim, dstate)
679
- B: (batch, seqlen, ngroups, dstate)
680
- C: (batch, seqlen, ngroups, dstate)
681
- D: (nheads, headdim) or (nheads,)
682
- z: (batch, seqlen, nheads, headdim)
683
- dt_bias: (nheads,) or (nheads, headdim)
684
- Return:
685
- out: (batch, seqlen, nheads, headdim)
686
- """
687
- from ..selective_scan_interface import selective_scan_fn
688
-
689
- batch, seqlen, nheads, headdim = x.shape
690
- _, _, ngroups, dstate = B.shape
691
- x = rearrange(x, "b l h p -> b (h p) l")
692
- if dt.dim() == 3:
693
- dt = repeat(dt, "b l h -> b l h p", p=headdim)
694
- dt = rearrange(dt, "b l h p -> b (h p) l")
695
- if A.dim() == 1:
696
- A = repeat(A, "h -> (h p) n", p=headdim, n=dstate).to(dtype=torch.float32)
697
- else:
698
- A = A.to(dtype=torch.float32)
699
- B = rearrange(B, "b l g n -> b g n l")
700
- C = rearrange(C, "b l g n -> b g n l")
701
- if D is not None:
702
- if D.dim() == 2:
703
- D = rearrange(D, "h p -> (h p)")
704
- else:
705
- D = repeat(D, "h -> (h p)", p=headdim)
706
- if z is not None:
707
- z = rearrange(z, "b l h p -> b (h p) l")
708
- if dt_bias is not None:
709
- if dt_bias.dim() == 1:
710
- dt_bias = repeat(dt_bias, "h -> h p", p=headdim)
711
- dt_bias = rearrange(dt_bias, "h p -> (h p)")
712
- if dt_limit != (0.0, float("inf")):
713
- if dt_bias is not None:
714
- dt = dt + rearrange(dt_bias, "d -> d 1")
715
- if dt_softplus:
716
- dt = F.softplus(dt)
717
- dt = dt.clamp(min=dt_limit[0], max=dt_limit[1]).to(x.dtype)
718
- dt_bias = None
719
- dt_softplus = None
720
- out = selective_scan_fn(x, dt, A, B, C, D=D, z=z, delta_bias=dt_bias, delta_softplus=dt_softplus)
721
- return rearrange(out, "b (h p) l -> b l h p", p=headdim)
722
-
723
-
724
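# --- Added illustration (not from the deleted file): the head-to-channel reshaping that
# ssd_selective_scan performs to feed head-wise SSD parameters into selective_scan_fn,
# which expects one (dstate,)-sized A row per channel. Arbitrary small sizes; runs on CPU.
import torch
from einops import rearrange, repeat

batch, seqlen, nheads, headdim, dstate = 2, 16, 4, 8, 6
x = torch.randn(batch, seqlen, nheads, headdim)
dt = torch.rand(batch, seqlen, nheads)
A = -torch.rand(nheads)

x_chan = rearrange(x, "b l h p -> b (h p) l")                              # (2, 32, 16)
dt_chan = rearrange(repeat(dt, "b l h -> b l h p", p=headdim), "b l h p -> b (h p) l")
A_chan = repeat(A, "h -> (h p) n", p=headdim, n=dstate)                    # one row per channel
assert x_chan.shape == (batch, nheads * headdim, seqlen)
assert dt_chan.shape == (batch, nheads * headdim, seqlen)
assert A_chan.shape == (nheads * headdim, dstate)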
- def mamba_conv1d_scan_ref(xBC, conv1d_weight, conv1d_bias, dt, A, chunk_size, D=None, z=None,
725
- dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf")),
726
- activation="silu", headdim=None, ngroups=1):
727
- """
728
- Argument:
729
- xBC: (batch, seqlen, dim + 2 * ngroups * dstate) where dim == nheads * headdim
730
- conv1d_weight: (dim + 2 * ngroups * dstate, width)
731
- conv1d_bias: (dim + 2 * ngroups * dstate,)
732
- dt: (batch, seqlen, nheads) or (batch, seqlen, nheads, headdim)
733
- A: (nheads)
734
- D: (nheads, headdim) or (nheads,)
735
- z: (batch, seqlen, dim)
736
- dt_bias: (nheads) or (nheads, headdim)
737
- headdim: if D is 1D and z is None, headdim must be passed in
738
- Return:
739
- out: (batch, seqlen, dim)
740
- """
741
- batch, seqlen, nheads = dt.shape[:3]
742
- assert nheads % ngroups == 0
743
- if z is not None:
744
- dim = z.shape[-1]
745
- assert dim % nheads == 0
746
- headdim = dim // nheads
747
- else:
748
- if D.dim() == 1:
749
- assert headdim is not None
750
- else:
751
- headdim = D.shape[1]
752
- dim = nheads * headdim
753
- xBC = rearrange(causal_conv1d_fn(rearrange(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias, activation=activation),
754
- "b d s -> b s d")
755
- dstate = (xBC.shape[-1] - dim) // ngroups // 2
756
- x, B, C = torch.split(xBC, [dim, ngroups * dstate, ngroups * dstate], dim=-1)
757
- x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
758
- B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
759
- C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
760
- z = rearrange(z, "b l (h p) -> b l h p", h=nheads) if z is not None else None
761
- out = ssd_selective_scan(x, dt.to(x.dtype), A, B, C, D=D.float(), z=z, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit)
762
- return rearrange(out, "b s h p -> b s (h p)")
763
-
764
-
765
- class MambaSplitConv1dScanCombinedFn(torch.autograd.Function):
766
-
767
- @staticmethod
768
- @custom_fwd
769
- def forward(ctx, zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states=None, seq_idx=None, dt_limit=(0.0, float("inf")), return_final_states=False, activation="silu",
770
- rmsnorm_weight=None, rmsnorm_eps=1e-6, outproj_weight=None, outproj_bias=None, headdim=None,
771
- ngroups=1, norm_before_gate=True):
772
- assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
773
- assert activation in [None, "silu", "swish"]
774
- if D.dim() == 1:
775
- assert headdim is not None
776
- nheads, = D.shape
777
- else:
778
- nheads, headdim = D.shape
779
- batch, seqlen, _ = zxbcdt.shape
780
- dim = nheads * headdim
781
- assert nheads % ngroups == 0
782
- dstate = (conv1d_weight.shape[0] - dim) // ngroups // 2
783
- d_nonssm = (zxbcdt.shape[-1] - 2 * dim - 2 * ngroups * dstate - nheads) // 2
784
- assert d_nonssm >= 0
785
- assert zxbcdt.shape == (batch, seqlen, 2 * d_nonssm + 2 * dim + 2 * ngroups * dstate + nheads)
786
- assert dt_bias.shape == (nheads,)
787
- assert A.shape == (nheads,)
788
- zx0, z, xBC, dt = torch.split(zxbcdt, [2 * d_nonssm, dim, dim + ngroups * dstate * 2, nheads], dim=-1)
789
- seq_idx = seq_idx.contiguous() if seq_idx is not None else None
790
- xBC_conv = rearrange(
791
- causal_conv1d_cuda.causal_conv1d_fwd(rearrange_and_update_stride(xBC, "b s d -> b d s"),
792
- conv1d_weight, conv1d_bias, seq_idx, None, None, activation in ["silu", "swish"]),
793
- "b d s -> b s d"
794
- )
795
- x, B, C = torch.split(xBC_conv, [dim, ngroups * dstate, ngroups * dstate], dim=-1)
796
- x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
797
- B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
798
- C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
799
- z = rearrange(z, "b l (h p) -> b l h p", h=nheads) if z is not None else None
800
- if rmsnorm_weight is None:
801
- out, out_x, dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size=chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=dt_limit)
802
- out = rearrange(out, "b s h p -> b s (h p)")
803
- rstd = None
804
- if d_nonssm > 0:
805
- out = torch.cat([_swiglu_fwd(zx0), out], dim=-1)
806
- else:
807
- out_x, _, dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size=chunk_size, D=D, z=None, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=dt_limit)
808
- # reshape input data into 2D tensor
809
- x_rms = rearrange(out_x, "b s h p -> (b s) (h p)")
810
- z_rms = rearrange(z, "b s h p -> (b s) (h p)")
811
- rmsnorm_weight = rmsnorm_weight.contiguous()
812
- if d_nonssm == 0:
813
- out = None
814
- else:
815
- out01 = torch.empty((batch, seqlen, d_nonssm + dim), dtype=x_rms.dtype, device=x_rms.device)
816
- out = rearrange(out01[..., d_nonssm:], "b s d -> (b s) d")
817
- _swiglu_fwd(zx0, out=out01[..., :d_nonssm])
818
- out, _, rstd = _layer_norm_fwd(x_rms, rmsnorm_weight, None, rmsnorm_eps, z_rms, out=out,
819
- group_size=dim // ngroups,
820
- norm_before_gate=norm_before_gate, is_rms_norm=True)
821
- if d_nonssm == 0:
822
- out = rearrange(out, "(b s) d -> b s d", b=batch)
823
- else:
824
- out = out01
825
- ctx.outproj_weight_dtype = outproj_weight.dtype if outproj_weight is not None else None
826
- if outproj_weight is not None:
827
- if torch.is_autocast_enabled():
828
- dtype = torch.get_autocast_gpu_dtype()
829
- out, outproj_weight = out.to(dtype), outproj_weight.to(dtype)
830
- outproj_bias = outproj_bias.to(dtype) if outproj_bias is not None else None
831
- out = F.linear(out, outproj_weight, outproj_bias)
832
- else:
833
- assert outproj_bias is None
834
- ctx.save_for_backward(zxbcdt, conv1d_weight, conv1d_bias,
835
- out_x, A, D, dt_bias, initial_states, seq_idx, rmsnorm_weight, rstd, outproj_weight, outproj_bias)
836
- ctx.dt_limit = dt_limit
837
- ctx.return_final_states = return_final_states
838
- ctx.activation = activation
839
- ctx.rmsnorm_eps = rmsnorm_eps
840
- ctx.norm_before_gate = norm_before_gate
841
- ctx.chunk_size = chunk_size
842
- ctx.headdim = headdim
843
- ctx.ngroups = ngroups
844
- return out if not return_final_states else (out, final_states)
845
-
846
- @staticmethod
847
- @custom_bwd
848
- def backward(ctx, dout, *args):
849
- assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
850
- zxbcdt, conv1d_weight, conv1d_bias, out, A, D, dt_bias, initial_states, seq_idx, rmsnorm_weight, rstd, outproj_weight, outproj_bias = ctx.saved_tensors
851
- dfinal_states = args[0] if ctx.return_final_states else None
852
- headdim = ctx.headdim
853
- nheads = D.shape[0]
854
- dim = nheads * headdim
855
- assert nheads % ctx.ngroups == 0
856
- dstate = (conv1d_weight.shape[0] - dim) // ctx.ngroups // 2
857
- d_nonssm = (zxbcdt.shape[-1] - 2 * dim - 2 * ctx.ngroups * dstate - nheads) // 2
858
- assert d_nonssm >= 0
859
- recompute_output = outproj_weight is not None
860
- if recompute_output:
861
- out_recompute = torch.empty(*out.shape[:2], d_nonssm + dim, device=out.device, dtype=out.dtype)
862
- out0_recompute, out1_recompute = out_recompute.split([d_nonssm, dim], dim=-1)
863
- zx0, z, xBC, dt = torch.split(zxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1)
864
- # Recompute x, B, C
865
- xBC_conv = rearrange(
866
- causal_conv1d_cuda.causal_conv1d_fwd(rearrange_and_update_stride(xBC, "b s d -> b d s"),
867
- conv1d_weight, conv1d_bias, seq_idx, None, None, ctx.activation in ["silu", "swish"]),
868
- "b d s -> b s d"
869
- )
870
- x, B, C = torch.split(xBC_conv, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1)
871
- x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
872
- B = rearrange(B, "b l (g n) -> b l g n", g=ctx.ngroups)
873
- C = rearrange(C, "b l (g n) -> b l g n", g=ctx.ngroups)
874
- dzxbcdt = torch.empty_like(zxbcdt)
875
- dzx0, dz, dxBC_given, ddt_given = torch.split(dzxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1)
876
- dxBC = torch.empty_like(xBC)
877
- dx, dB, dC = torch.split(dxBC, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1)
878
- z = rearrange(z, "b l (h p) -> b l h p", h=nheads)
879
- dx = rearrange(dx, "b l (h p) -> b l h p", h=nheads)
880
- dB = rearrange(dB, "b l (g n) -> b l g n", g=ctx.ngroups)
881
- dC = rearrange(dC, "b l (g n) -> b l g n", g=ctx.ngroups)
882
- if outproj_weight is not None:
883
- dout_og = dout
884
- dout = F.linear(dout, outproj_weight.t())
885
- if d_nonssm > 0:
886
- dout0, dout = dout.split([d_nonssm, dim], dim=-1)
887
- _swiglu_bwd(zx0, dout0, dxy=dzx0, recompute_output=True, out=out0_recompute)
888
- dout = rearrange(dout, "b s (h p) -> b s h p", p=headdim)
889
- if rmsnorm_weight is None:
890
- dz = rearrange(dz, "b l (h p) -> b l h p", h=nheads)
891
- dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states, *rest = _mamba_chunk_scan_combined_bwd(
892
- dout, x, dt, A, B, C, out, ctx.chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, dfinal_states=dfinal_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=ctx.dt_limit, dx=dx, ddt=ddt_given, dB=dB, dC=dC, dz=dz, recompute_output=recompute_output
893
- )
894
- out_for_linear = rearrange(rest[0], "b s h p -> b s (h p)") if recompute_output else None
895
- drmsnorm_weight = None
896
- else:
897
- batch = dout.shape[0]
898
- dy_rms = rearrange(dout, "b s h p -> (b s) (h p)")
899
- dz = rearrange(dz, "b l d -> (b l) d")
900
- x_rms = rearrange(out, "b s h p -> (b s) (h p)")
901
- z_rms = rearrange(z, "b s h p -> (b s) (h p)")
902
- out1_recompute = rearrange(out1_recompute, "b s d -> (b s) d") if recompute_output else None
903
- dout, drmsnorm_weight, _, dz, *rest = _layer_norm_bwd(dy_rms, x_rms, rmsnorm_weight, None, ctx.rmsnorm_eps, None, rstd, z_rms, group_size=dim//ctx.ngroups, norm_before_gate=ctx.norm_before_gate, is_rms_norm=True, recompute_output=recompute_output, dz=dz, out=out1_recompute if recompute_output else None)
904
- out_for_linear = out_recompute if recompute_output else None
905
- dout = rearrange(dout, "(b s) (h p) -> b s h p", b=batch, p=headdim)
906
- dx, ddt, dA, dB, dC, dD, _, ddt_bias, dinitial_states = _mamba_chunk_scan_combined_bwd(
907
- dout, x, dt, A, B, C, out, ctx.chunk_size, D=D, z=None, dt_bias=dt_bias, initial_states=initial_states, dfinal_states=dfinal_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=ctx.dt_limit, dx=dx, ddt=ddt_given, dB=dB, dC=dC
908
- )
909
-
910
- if outproj_weight is not None:
911
- doutproj_weight = torch.einsum("bso,bsd->od", dout_og, out_for_linear)
912
- doutproj_bias = dout_og.sum(dim=(0, 1)) if outproj_bias is not None else None
913
- else:
914
- doutproj_weight, doutproj_bias = None, None
915
- dxBC_given = rearrange(dxBC_given, "b s d -> b d s")
916
- dxBC_given_update, dweight, dbias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
917
- rearrange_and_update_stride(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias,
918
- rearrange(dxBC, "b s d -> b d s"), seq_idx, None, None, rearrange_and_update_stride(dxBC_given), False, ctx.activation in ["silu", "swish"]
919
- )
920
- if dxBC_given.stride() != dxBC_given_update.stride():
921
- dxBC_given.copy_(dxBC_given_update)
922
- else:
923
- dxBC_given = dxBC_given_update
924
- dxBC_given = rearrange(dxBC_given, "b d s -> b s d")
925
- return dzxbcdt, dweight, dbias, ddt_bias, dA, dD, None, dinitial_states, None, None, None, None, drmsnorm_weight, None, doutproj_weight, doutproj_bias, None, None, None
926
-
927
-
928
- def mamba_split_conv1d_scan_combined(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states=None, seq_idx=None, dt_limit=(0.0, float("inf")), return_final_states=False, activation="silu", rmsnorm_weight=None, rmsnorm_eps=1e-6, outproj_weight=None, outproj_bias=None, headdim=None, ngroups=1, norm_before_gate=True):
929
- """
930
- Argument:
931
- zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim
932
- conv1d_weight: (dim + 2 * ngroups * dstate, width)
933
- conv1d_bias: (dim + 2 * ngroups * dstate,)
934
- dt_bias: (nheads,)
935
- A: (nheads)
936
- D: (nheads, headdim) or (nheads,)
937
- initial_states: (batch, nheads, headdim, dstate)
938
- seq_idx: (batch, seqlen), int32
939
- rmsnorm_weight: (dim,)
940
- outproj_weight: (out_dim, dim)
941
- outproj_bias: (out_dim,)
942
- headdim: if D is 1D, headdim must be passed in
943
- norm_before_gate: if True, we do RMSNorm(x) * F.silu(z). If False, we do RMSNorm(x * F.silu(z))
944
- Return:
945
- out: (batch, seqlen, dim)
946
- """
947
- return MambaSplitConv1dScanCombinedFn.apply(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)
948
-
949
-
950
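# --- Added shape sketch (not from the deleted file): how the fused projection output
# zxbcdt is packed for mamba_split_conv1d_scan_combined (d_nonssm == 0 case, as in the
# docstring above). The concrete numbers are illustrative assumptions.
nheads, headdim, ngroups, dstate, width = 8, 64, 1, 128, 4
dim = nheads * headdim                               # 512
d_in_proj = 2 * dim + 2 * ngroups * dstate + nheads  # z | x | B | C | dt, concatenated on the last dim
conv_channels = dim + 2 * ngroups * dstate           # only the x, B, C slice goes through the conv1d
# zxbcdt:        (batch, seqlen, d_in_proj)
# conv1d_weight: (conv_channels, width), conv1d_bias: (conv_channels,)
# dt_bias, A:    (nheads,)    D: (nheads,) or (nheads, headdim)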
- def mamba_split_conv1d_scan_ref(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, dt_limit=(0.0, float("inf")), activation="silu", rmsnorm_weight=None, rmsnorm_eps=1e-6, outproj_weight=None, outproj_bias=None, headdim=None, ngroups=1, norm_before_gate=True):
951
- """
952
- Argument:
953
- zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim
954
- conv1d_weight: (dim + 2 * ngroups * dstate, width)
955
- conv1d_bias: (dim + 2 * ngroups * dstate,)
956
- dt_bias: (nheads,)
957
- A: (nheads)
958
- D: (nheads, headdim) or (nheads,)
959
- rmsnorm_weight: (dim,)
960
- outproj_weight: (out_dim, dim)
961
- outproj_bias: (out_dim,)
962
- headdim: if D is 1D, headdim must be passed in
963
- norm_before_gate: if True, we do RMSNorm(x) * F.silu(z). If False, we do RMSNorm(x * F.silu(z))
964
- Return:
965
- out: (batch, seqlen, dim)
966
- """
967
- if D.dim() == 1:
968
- assert headdim is not None
969
- nheads, = D.shape
970
- else:
971
- nheads, headdim = D.shape
972
- assert nheads % ngroups == 0
973
- batch, seqlen, _ = zxbcdt.shape
974
- dim = nheads * headdim
975
- dstate = (zxbcdt.shape[-1] - 2 * dim - nheads) // ngroups // 2
976
- assert zxbcdt.shape == (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads)
977
- assert dt_bias.shape == (nheads,)
978
- assert A.shape == (nheads,)
979
- if rmsnorm_weight is not None:
980
- assert rmsnorm_weight.shape == (dim,)
981
- z, xBC, dt = torch.split(zxbcdt, [dim, dim + 2 * ngroups * dstate, nheads], dim=-1)
982
- xBC = rearrange(causal_conv1d_fn(rearrange(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias, activation=activation),
983
- "b d s -> b s d")
984
- x, B, C = torch.split(xBC, [dim, ngroups * dstate, ngroups * dstate], dim=-1)
985
- x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
986
- B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
987
- C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
988
- z = rearrange(z, "b l (h p) -> b l h p", h=nheads)
989
- out = ssd_selective_scan(x, dt.to(x.dtype), A, B, C, D=D.float(),
990
- z=z if rmsnorm_weight is None else None, dt_bias=dt_bias, dt_softplus=True, dt_limit=dt_limit)
991
- out = rearrange(out, "b s h p -> b s (h p)")
992
- if rmsnorm_weight is not None:
993
- out = rmsnorm_fn(out, rmsnorm_weight, None, z=rearrange(z, "b l h p -> b l (h p)"), eps=rmsnorm_eps,
994
- norm_before_gate=norm_before_gate)
995
- if outproj_weight is not None:
996
- out = F.linear(out, outproj_weight, outproj_bias)
997
- return out
998
-
build/torch210-cxx11-cu126-x86_64-linux/ops/triton/ssd_state_passing.py DELETED
@@ -1,348 +0,0 @@
1
- # Copyright (c) 2024, Tri Dao, Albert Gu.
2
-
3
- """We want triton==2.1.0 or 2.2.0 for this
4
- """
5
-
6
- import math
7
- import torch
8
- import torch.nn.functional as F
9
-
10
- import triton
11
- import triton.language as tl
12
-
13
- from einops import rearrange, repeat
14
-
15
-
16
- @triton.autotune(
17
- configs=[
18
- triton.Config({'BLOCK_SIZE': 64}),
19
- triton.Config({'BLOCK_SIZE': 128}),
20
- triton.Config({'BLOCK_SIZE': 256}),
21
- triton.Config({'BLOCK_SIZE': 512}),
22
- triton.Config({'BLOCK_SIZE': 1024}),
23
- triton.Config({'BLOCK_SIZE': 2048}),
24
- ],
25
- key=['dim'],
26
- )
27
- @triton.jit
28
- def _state_passing_fwd_kernel(
29
- # Pointers to matrices
30
- states_ptr, out_ptr, final_states_ptr, dA_cs_ptr, initstates_ptr, seq_idx_ptr,
31
- # Matrix dimensions
32
- dim, nchunks, seqlen, chunk_size,
33
- # Strides
34
- stride_states_batch, stride_states_chunk, stride_states_head, stride_states_dim,
35
- stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim,
36
- stride_final_states_batch, stride_final_states_head, stride_final_states_dim,
37
- stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head,
38
- stride_initstates_batch, stride_initstates_head, stride_initstates_dim,
39
- stride_seq_idx_batch, stride_seq_idx_seqlen,
40
- # Meta-parameters
41
- HAS_INITSTATES: tl.constexpr,
42
- HAS_SEQ_IDX: tl.constexpr,
43
- BLOCK_SIZE: tl.constexpr,
44
- ):
45
- pid_b = tl.program_id(axis=1)
46
- pid_h = tl.program_id(axis=2)
47
- pid_m = tl.program_id(axis=0)
48
- states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
49
- dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head
50
- out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
51
- final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head
52
- if HAS_INITSTATES:
53
- initstates_ptr += pid_b * stride_initstates_batch + pid_h * stride_initstates_head
54
- if HAS_SEQ_IDX:
55
- seq_idx_ptr += pid_b * stride_seq_idx_batch
56
-
57
- offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
58
- states_ptrs = states_ptr + offs_m * stride_states_dim
59
- out_ptrs = out_ptr + offs_m * stride_out_dim
60
- final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim
61
-
62
- if not HAS_INITSTATES:
63
- states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
64
- else:
65
- initstates_ptrs = initstates_ptr + offs_m * stride_initstates_dim
66
- states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
67
- tl.store(out_ptrs, states, mask=offs_m < dim)
68
- out_ptrs += stride_out_chunk
69
- seq_idx = 0
70
- for c in range(nchunks):
71
- new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
72
- dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
73
- scale = tl.exp(dA_cs)
74
- if HAS_SEQ_IDX:
75
- seq_idx_new = tl.load(seq_idx_ptr + (min((c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen)
76
- scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)
77
- seq_idx = seq_idx_new
78
- states = scale * states + new_states
79
- if c < nchunks - 1:
80
- tl.store(out_ptrs, states, mask=offs_m < dim)
81
- else:
82
- tl.store(final_states_ptrs, states, mask=offs_m < dim)
83
- states_ptrs += stride_states_chunk
84
- dA_cs_ptr += stride_dA_cs_chunk
85
- out_ptrs += stride_out_chunk
86
-
87
-
88
- @triton.autotune(
89
- configs=[
90
- triton.Config({'BLOCK_SIZE': 64}),
91
- triton.Config({'BLOCK_SIZE': 128}),
92
- triton.Config({'BLOCK_SIZE': 256}),
93
- triton.Config({'BLOCK_SIZE': 512}),
94
- triton.Config({'BLOCK_SIZE': 1024}),
95
- triton.Config({'BLOCK_SIZE': 2048}),
96
- ],
97
- key=['dim'],
98
- )
99
- @triton.jit
100
- def _state_passing_bwd_kernel(
101
- # Pointers to matrices
102
- dout_ptr, out_ptr, dA_cs_ptr, dfinal_states_ptr, seq_idx_ptr,
103
- dstates_ptr, ddA_cs_ptr, dinitstates_ptr, states_converted_ptr,
104
- # Matrix dimensions
105
- dim, nchunks, seqlen, chunk_size,
106
- # Strides
107
- stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_dim,
108
- stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim,
109
- stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head,
110
- stride_dfinal_states_batch, stride_dfinal_states_head, stride_dfinal_states_dim,
111
- stride_seq_idx_batch, stride_seq_idx_seqlen,
112
- stride_dstates_batch, stride_dstates_chunk, stride_dstates_head, stride_dstates_dim,
113
- stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head,
114
- stride_dinitstates_batch, stride_dinitstates_head, stride_dinitstates_dim,
115
- # Meta-parameters
116
- CONVERT_STATES: tl.constexpr,
117
- HAS_DFINAL_STATES: tl.constexpr,
118
- HAS_DINITSTATES: tl.constexpr,
119
- HAS_SEQ_IDX: tl.constexpr,
120
- BLOCK_SIZE: tl.constexpr,
121
- ):
122
- pid_b = tl.program_id(axis=1)
123
- pid_h = tl.program_id(axis=2)
124
- pid_m = tl.program_id(axis=0)
125
- dstates_ptr += pid_b * stride_dstates_batch + pid_h * stride_dstates_head + (nchunks - 1) * stride_dstates_chunk
126
- dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + (nchunks - 1) * stride_dA_cs_chunk
127
- ddA_cs_ptr += pid_b * stride_ddA_cs_batch + pid_h * stride_ddA_cs_head + (nchunks - 1) * stride_ddA_cs_chunk + pid_m
128
- out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + (nchunks - 1) * stride_out_chunk
129
- dout_ptr += pid_b * stride_dout_batch + pid_h * stride_dout_head + (nchunks - 1) * stride_dout_chunk
130
- if CONVERT_STATES:
131
- states_converted_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + (nchunks - 1) * stride_out_chunk
132
- if HAS_DFINAL_STATES:
133
- dfinal_states_ptr += pid_b * stride_dfinal_states_batch + pid_h * stride_dfinal_states_head
134
- if HAS_DINITSTATES:
135
- dinitstates_ptr += pid_b * stride_dinitstates_batch + pid_h * stride_dinitstates_head
136
- if HAS_SEQ_IDX:
137
- seq_idx_ptr += pid_b * stride_seq_idx_batch
138
-
139
- offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
140
- dstates_ptrs = dstates_ptr + offs_m * stride_dstates_dim
141
- out_ptrs = out_ptr + offs_m * stride_out_dim
142
- dout_ptrs = dout_ptr + offs_m * stride_dout_dim
143
- if CONVERT_STATES:
144
- states_converted_ptrs = states_converted_ptr + offs_m * stride_out_dim
145
-
146
- if HAS_DFINAL_STATES:
147
- dstates = tl.load(dfinal_states_ptr + offs_m * stride_dfinal_states_dim, mask=offs_m < dim, other=0.0).to(tl.float32)
148
- else:
149
- dstates = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
150
- tl.store(dstates_ptrs, dstates, mask=offs_m < dim)
151
- if HAS_SEQ_IDX:
152
- seq_idx = tl.load(seq_idx_ptr + (seqlen - 1) * stride_seq_idx_seqlen)
153
- dstates_ptrs -= stride_dstates_chunk
154
- for c in range(nchunks - 1):
155
- dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
156
- scale = tl.exp(dA_cs)
157
- if HAS_SEQ_IDX:
158
- seq_idx_new = tl.load(seq_idx_ptr + (((nchunks - c - 1) * chunk_size - 1) * stride_seq_idx_seqlen))
159
- scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)
160
- seq_idx = seq_idx_new
161
- out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
162
- if CONVERT_STATES:
163
- tl.store(states_converted_ptrs, out, mask=offs_m < dim)
164
- ddA = tl.sum(out * dstates) * scale
165
- tl.store(ddA_cs_ptr, ddA)
166
- dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
167
- dstates = scale * dstates + dout
168
- tl.store(dstates_ptrs, dstates, mask=offs_m < dim)
169
- dout_ptrs -= stride_dout_chunk
170
- dstates_ptrs -= stride_dstates_chunk
171
- dA_cs_ptr -= stride_dA_cs_chunk
172
- ddA_cs_ptr -= stride_ddA_cs_chunk
173
- out_ptrs -= stride_out_chunk
174
- if CONVERT_STATES:
175
- states_converted_ptrs -= stride_out_chunk
176
- if CONVERT_STATES:
177
- out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
178
- tl.store(states_converted_ptrs, out, mask=offs_m < dim)
179
- if not HAS_DINITSTATES:
180
- tl.store(ddA_cs_ptr, 0.0)
181
- else:
182
- dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
183
- scale = tl.exp(dA_cs)
184
- if HAS_SEQ_IDX:
185
- scale = tl.where(seq_idx == 0, scale, 0.0)
186
- out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
187
- ddA = tl.sum(out * dstates) * scale
188
- tl.store(ddA_cs_ptr, ddA)
189
- dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
190
- dstates = scale * dstates + dout
191
- tl.store(dinitstates_ptr + offs_m * stride_dinitstates_dim, dstates, mask=offs_m < dim)
192
-
193
-
194
- def _state_passing_fwd(states, dA_chunk_cumsum, initial_states=None, seq_idx=None, chunk_size=None,
195
- out_dtype=None):
196
- batch, nchunks, nheads, dim = states.shape
197
- assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)
198
- if initial_states is not None:
199
- assert initial_states.shape == (batch, nheads, dim)
200
- if seq_idx is not None:
201
- assert chunk_size is not None
202
- seqlen = seq_idx.shape[-1]
203
- assert seq_idx.shape == (batch, seqlen)
204
- out_dtype = states.dtype if out_dtype is None else out_dtype
205
- out = torch.empty((batch, nchunks, nheads, dim), device=states.device, dtype=out_dtype)
206
- final_states = torch.empty((batch, nheads, dim), device=states.device, dtype=torch.float32)
207
- grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads)
208
- with torch.cuda.device(states.device.index):
209
- _state_passing_fwd_kernel[grid](
210
- states, out, final_states, dA_chunk_cumsum, initial_states, seq_idx,
211
- dim, nchunks, seqlen if seq_idx is not None else 0, chunk_size if seq_idx is not None else 0,
212
- states.stride(0), states.stride(1), states.stride(2), states.stride(3),
213
- out.stride(0), out.stride(1), out.stride(2), out.stride(3),
214
- final_states.stride(0), final_states.stride(1), final_states.stride(2),
215
- dA_chunk_cumsum.stride(0), dA_chunk_cumsum.stride(2), dA_chunk_cumsum.stride(1),
216
- *((initial_states.stride(0), initial_states.stride(1), initial_states.stride(2))
217
- if initial_states is not None else (0, 0, 0)),
218
- *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
219
- HAS_INITSTATES=initial_states is not None,
220
- HAS_SEQ_IDX=seq_idx is not None,
221
- )
222
- return out, final_states
223
-
224
-
225
- def _state_passing_bwd(
226
- states, dA_chunk_cumsum, dout, dfinal_states=None, seq_idx=None, has_initial_states=None,
227
- dstates_dtype=None, states_dtype=None, chunk_size=None
228
- ):
229
- """
230
- states contains the initial_states at index 0. The final states are not included in states.
231
- """
232
- batch, nchunks, nheads, dim = states.shape
233
- assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)
234
- assert dout.shape == (batch, nchunks, nheads, dim)
235
- if seq_idx is not None:
236
- assert chunk_size is not None
237
- seqlen = seq_idx.shape[-1]
238
- assert seq_idx.shape == (batch, seqlen)
239
- dstates = torch.empty_like(dout, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype)
240
- if states_dtype is not None and states_dtype != states.dtype:
241
- states_converted = torch.empty_like(states, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype)
242
- assert states_converted.stride() == states.stride()
243
- else:
244
- states_converted = None
245
- if has_initial_states:
246
- dinitstates = torch.empty_like(dstates[:, 0])
247
- else:
248
- dinitstates = None
249
- if dfinal_states is not None:
250
- assert dfinal_states.shape == (batch, nheads, dim)
251
- BLOCK_SIZE_min = 64
252
- n_blocks = (dim + BLOCK_SIZE_min - 1) // BLOCK_SIZE_min
253
- ddA_chunk_cumsum = torch.empty(batch, nheads, nchunks, n_blocks,
254
- dtype=torch.float32, device=dA_chunk_cumsum.device)
255
- grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads)
256
- with torch.cuda.device(dout.device.index):
257
- _state_passing_bwd_kernel[grid](
258
- dout, states, dA_chunk_cumsum, dfinal_states, seq_idx,
259
- dstates, ddA_chunk_cumsum, dinitstates, states_converted,
260
- dim, nchunks, seqlen if seq_idx is not None else 0, chunk_size if seq_idx is not None else 0,
261
- dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),
262
- states.stride(0), states.stride(1), states.stride(2), states.stride(3),
263
- dA_chunk_cumsum.stride(0), dA_chunk_cumsum.stride(2), dA_chunk_cumsum.stride(1),
264
- *((dfinal_states.stride(0), dfinal_states.stride(1), dfinal_states.stride(2))
265
- if dfinal_states is not None else (0, 0, 0)),
266
- *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
267
- dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3),
268
- ddA_chunk_cumsum.stride(0), ddA_chunk_cumsum.stride(2), ddA_chunk_cumsum.stride(1),
269
- *((dinitstates.stride(0), dinitstates.stride(1), dinitstates.stride(2))
270
- if dinitstates is not None else (0, 0, 0)),
271
- CONVERT_STATES=states_converted is not None,
272
- HAS_DFINAL_STATES=dfinal_states is not None,
273
- HAS_DINITSTATES=dinitstates is not None,
274
- HAS_SEQ_IDX=seq_idx is not None,
275
- )
276
- BLOCK_SIZE_actual = _state_passing_bwd_kernel.best_config.kwargs["BLOCK_SIZE"]
277
- n_valid_blocks = (dim + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual
278
- ddA_chunk_cumsum = ddA_chunk_cumsum[..., :n_valid_blocks].sum(dim=-1).to(dtype=dA_chunk_cumsum.dtype)
279
- if states_dtype is not None and states_dtype == states.dtype:
280
- states_converted = states
281
- return (dstates, ddA_chunk_cumsum, dinitstates) if states_dtype is None else (dstates, ddA_chunk_cumsum, dinitstates, states_converted)
282
-
283
-
284
- class StatePassingFn(torch.autograd.Function):
285
-
286
- @staticmethod
287
- def forward(ctx, states, dA_chunk_cumsum, initial_states=None):
288
- batch, nchunks, nheads, dim = states.shape
289
- assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)
290
- if states.stride(-1) != 1:
291
- states = states.contiguous()
292
- out, final_states = _state_passing_fwd(states, dA_chunk_cumsum, initial_states)
293
- ctx.save_for_backward(out, dA_chunk_cumsum)
294
- ctx.has_initial_states = initial_states is not None
295
- return out, final_states
296
-
297
- @staticmethod
298
- def backward(ctx, dout, dfinal_states):
299
- out, dA_chunk_cumsum = ctx.saved_tensors
300
- batch, nchunks, nheads, dim = out.shape
301
- assert dout.shape == (batch, nchunks, nheads, dim)
302
- assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)
303
- assert dfinal_states.shape == (batch, nheads, dim)
304
- if dout.stride(-1) != 1:
305
- dout = dout.contiguous()
306
- dstates, ddA_chunk_cumsum, dinitstates = _state_passing_bwd(
307
- out, dA_chunk_cumsum, dout, dfinal_states=dfinal_states , has_initial_states=ctx.has_initial_states
308
- )
309
- return dstates, ddA_chunk_cumsum, dinitstates
310
-
311
-
312
- def state_passing(states, dA_chunk_cumsum, initial_states=None):
313
- """
314
- Argument:
315
- states: (batch, nchunks, nheads, dim)
316
- dA_chunk_cumsum: (batch, nheads, nchunks)
317
- initial_states: (batch, nheads, dim)
318
- Return:
319
- out: (batch, nchunks, nheads, dim)
320
- final_states: (batch, nheads, dim)
321
- """
322
- return StatePassingFn.apply(states, dA_chunk_cumsum, initial_states)
323
-
324
-
325
- def state_passing_ref(states, dA_chunk_cumsum, initial_states=None):
326
- """
327
- Argument:
328
- states: (batch, nchunks, nheads, dim)
329
- dA_chunk_cumsum: (batch, nheads, nchunks)
330
- initial_states: (batch, nheads, dim)
331
- Return:
332
- out: (batch, nchunks, nheads, dim)
333
- final_states: (batch, nheads, dim)
334
- """
335
- if initial_states is None:
336
- initial_states = torch.zeros_like(states[:, 0])
337
- states = torch.cat([rearrange(initial_states, "b h d -> b 1 h d"), states], dim=1)
338
- dA_chunk_cumsum = F.pad(dA_chunk_cumsum, (1, 0))
339
- dA_chunk_cumsum = torch.cumsum(dA_chunk_cumsum, dim=-1)
340
- nchunks = dA_chunk_cumsum.shape[-1]
341
- # (batch, nheads, nchunks, nchunks)
342
- dt_chunk_segment_sum = dA_chunk_cumsum[:, :, :, None] - dA_chunk_cumsum[:, :, None, :]
343
- # (batch, nheads, nchunks, nchunks)
344
- decay_chunk = torch.exp(dt_chunk_segment_sum)
345
- causal_mask = torch.tril(torch.ones(nchunks, nchunks, device=states.device, dtype=bool), diagonal=0)
346
- decay_chunk = decay_chunk.masked_fill(~causal_mask, 0)
347
- out = torch.einsum("bhzc,bchd->bzhd", decay_chunk.to(dtype=states.dtype), states)
348
- return out[:, :-1], out[:, -1]
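# --- Added sanity check (not part of the deleted file): state_passing_ref realizes the
# chunk-level recurrence  running <- exp(dA_chunk_cumsum[..., c]) * running + states[:, c],
# where out[:, c] is the state entering chunk c. Small CPU check against a direct loop:
import torch

batch, nchunks, nheads, dim = 2, 5, 3, 7
states = torch.randn(batch, nchunks, nheads, dim)
dA_cs = -torch.rand(batch, nheads, nchunks)          # negative so exp(.) <= 1
init = torch.randn(batch, nheads, dim)

out_ref, final_ref = state_passing_ref(states, dA_cs, init)

running, outs = init.clone(), []
for c in range(nchunks):
    outs.append(running.clone())
    running = torch.exp(dA_cs[:, :, c])[..., None] * running + states[:, c]

assert torch.allclose(out_ref, torch.stack(outs, dim=1), atol=1e-5)
assert torch.allclose(final_ref, running, atol=1e-5)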
build/torch210-cxx11-cu126-x86_64-linux/utils/__init__.py DELETED
File without changes
build/torch210-cxx11-cu126-x86_64-linux/utils/generation.py DELETED
@@ -1,390 +0,0 @@
1
- # Copyright (c) 2023, Albert Gu, Tri Dao.
2
- import gc
3
- import time
4
- from collections import namedtuple
5
- from dataclasses import dataclass, field
6
- from functools import partial
7
- from typing import Callable, Optional, Sequence, Union
8
-
9
- import torch
10
- import torch.nn.functional as F
11
- from einops import rearrange, repeat
12
- from torch import Tensor
13
- from torch.profiler import ProfilerActivity, profile, record_function
14
- from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput, TextStreamer
15
-
16
-
17
- @dataclass
18
- class InferenceParams:
19
- """Inference parameters that are passed to the main model in order
20
- to efficiently calculate and store the context during inference."""
21
-
22
- max_seqlen: int
23
- max_batch_size: int
24
- seqlen_offset: int = 0
25
- batch_size_offset: int = 0
26
- key_value_memory_dict: dict = field(default_factory=dict)
27
- lengths_per_sample: Optional[Tensor] = None
28
-
29
- def reset(self, max_seqlen, max_batch_size):
30
- self.max_seqlen = max_seqlen
31
- self.max_batch_size = max_batch_size
32
- self.seqlen_offset = 0
33
- if self.lengths_per_sample is not None:
34
- self.lengths_per_sample.zero_()
35
-
36
-
37
- def modify_logits_for_min_p_filtering(logits, min_p):
38
- """Set the logits below the min_p threshold to -inf. Done in-place."""
39
- if min_p <= 0.0 or min_p >= 1.0:
40
- return
41
- indices_to_remove = logits < min_p
42
- logits.masked_fill_(indices_to_remove, float("-Inf"))
43
- # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
44
- # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
45
- def modify_logits_for_top_k_filtering(logits, top_k):
46
- """Set the logits outside the top-k to -inf. Done in-place."""
47
- indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
48
- logits.masked_fill_(indices_to_remove, float("-Inf"))
49
-
50
-
51
- # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
52
- # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
53
- def modify_logits_for_top_p_filtering(logits, top_p):
54
- """Set the logits outside the top-p nucleus to -inf. Done in-place."""
55
- if top_p <= 0.0 or top_p >= 1.0:
56
- return
57
- # First sort and calculate cumulative sum of probabilities.
58
- sorted_logits, sorted_indices = torch.sort(logits, descending=False)
59
- cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
60
- # Remove tokens whose cumulative probability (summed from the smallest logits up) is <= 1 - top_p, keeping the top-p nucleus
61
- sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
62
- # scatter sorted tensors to original indexing
63
- indices_to_remove = sorted_indices_to_remove.scatter(
64
- 1, sorted_indices, sorted_indices_to_remove
65
- )
66
- logits.masked_fill_(indices_to_remove, float("-inf"))
67
-
68
-
69
- def modify_logit_for_repetition_penalty(logits, prev_output_tokens, repetition_penalty=1.0):
70
- """Apply repetition penalty. See https://arxiv.org/abs/1909.05858
71
- logits: (batch_size, vocab_size)
72
- prev_output_tokens: (batch_size, seq_len)
73
- """
74
- if repetition_penalty == 1.0:
75
- return logits
76
- score = torch.gather(logits, 1, prev_output_tokens)
77
- # if score < 0, multiply by the repetition penalty (rather than divide) so the previous token's probability is reduced
78
- score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
79
- logits.scatter_(1, prev_output_tokens, score)
80
- return logits
81
-
82
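# --- Added worked example (not part of the deleted file): with repetition_penalty > 1,
# logits of previously generated tokens are pushed down (divided when positive,
# multiplied when negative) while all other logits are left untouched.
import torch

logits = torch.tensor([[2.0, -1.0, 0.5]])
prev = torch.tensor([[0, 1]])  # tokens 0 and 1 were already generated
out = modify_logit_for_repetition_penalty(logits.clone(), prev, repetition_penalty=2.0)
# token 0: 2.0 / 2.0 = 1.0; token 1: -1.0 * 2.0 = -2.0; token 2 unchanged
assert torch.allclose(out, torch.tensor([[1.0, -2.0, 0.5]]))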
-
83
- def sample(logits, top_k=1, top_p=0.0, min_p=0.0, temperature=1.0):
84
- """Sample from top-k logits.
85
- Arguments:
86
- logits: Tensor of shape (batch_size, vocab_size)
87
- """
88
- if top_k == 1: # Short-circuit for greedy decoding
89
- return logits.argmax(dim=-1)
90
- else:
91
- if top_p > 0.0:
92
- assert top_p <= 1.0, "top-p should be in (0, 1]."
93
- if top_k > 0:
94
- top_k = min(top_k, logits.size(-1)) # Safety check
95
- logits_top, indices = torch.topk(logits, top_k, dim=-1)
96
- if temperature != 1.0:
97
- logits_top /= temperature
98
- modify_logits_for_top_p_filtering(logits_top, top_p)
99
- return indices[
100
- torch.arange(indices.shape[0], device=indices.device),
101
- torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
102
- ]
103
- else:
104
- if min_p > 0.0:
105
- logits_top = logits.clone()
106
- max_prob = logits_top[..., 0].item()
107
- min_prob = max_prob * min_p
108
- modify_logits_for_min_p_filtering(logits_top, min_prob)
109
- if temperature != 1.0:
110
- logits_top /= temperature
111
- return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
112
- # Clone so that when we modify for top_p we don't change the original logits
113
- logits_top = logits / temperature if temperature != 1.0 else logits.clone()
114
- modify_logits_for_top_p_filtering(logits_top, top_p)
115
- return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(
116
- dim=-1
117
- )
118
-
119
-
120
- @torch.inference_mode()
121
- def decode(
122
- input_ids,
123
- model,
124
- max_length,
125
- top_k=1,
126
- top_p=0.0,
127
- min_p=0.0,
128
- temperature=1.0,
129
- repetition_penalty=1.0,
130
- eos_token_id=None,
131
- teacher_outputs=None,
132
- vocab_size=None,
133
- cg=False,
134
- enable_timing=False,
135
- output_scores=False,
136
- streamer: Optional[TextStreamer] = None
137
- ):
138
- """Decoding, either greedy or with top-k or top-p sampling.
139
- If top-k = 0, don't limit the number of candidates (pure sampling).
140
- Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
141
- then top-p.
142
- We assume that all sequences in the same batch have the same length.
143
-
144
- Arguments:
145
- input_ids: (batch, seq_len)
146
- max_length: int
147
- teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
148
- logits, the next token is taken from the teacher_outputs. Useful for testing.
149
- Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
150
- sequences: (batch, max_length)
151
- scores: tuple of tensors, each of shape (batch, vocab_size)
152
- """
153
- if streamer is not None:
154
- streamer.put(input_ids.cpu())
155
-
156
- batch_size, seqlen_og = input_ids.shape
157
- teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
158
- if cg:
159
- if not hasattr(model, "_decoding_cache"):
160
- model._decoding_cache = None
161
- model._decoding_cache = update_graph_cache(
162
- model,
163
- model._decoding_cache,
164
- batch_size,
165
- seqlen_og,
166
- max_length,
167
- )
168
- inference_params = model._decoding_cache.inference_params
169
- inference_params.reset(max_length, batch_size)
170
- else:
171
- inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
172
-
173
- def get_logits(input_ids, inference_params):
174
- decoding = inference_params.seqlen_offset > 0
175
- if decoding:
176
- position_ids = torch.full(
177
- (batch_size, 1),
178
- inference_params.seqlen_offset,
179
- dtype=torch.long,
180
- device=input_ids.device,
181
- )
182
- else:
183
- position_ids = None
184
- if not cg or not decoding:
185
- logits = model(
186
- input_ids,
187
- position_ids=position_ids,
188
- inference_params=inference_params,
189
- num_last_tokens=1,
190
- ).logits.squeeze(dim=1)
191
- else:
192
- logits = model._decoding_cache.run(
193
- input_ids, position_ids, inference_params.seqlen_offset
194
- ).squeeze(dim=1)
195
- return logits[..., :vocab_size] if vocab_size is not None else logits
196
-
197
- def sample_tokens(logits, inference_params):
198
- if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset:
199
- token = sample(logits, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature)
200
- else:
201
- token = teacher_outputs[:, inference_params.seqlen_offset]
202
- # return rearrange(token, "b -> b 1")
203
- return token.unsqueeze(1)
204
-
205
- def should_stop(current_token, inference_params):
206
- if inference_params.seqlen_offset == 0:
207
- return False
208
- if eos_token_id is not None and (current_token == eos_token_id).all():
209
- return True
210
- if inference_params.seqlen_offset >= max_length - 1:
211
- return True
212
- return False
213
-
214
- start = torch.cuda.Event(enable_timing=enable_timing)
215
- end = torch.cuda.Event(enable_timing=enable_timing)
216
-
217
- if enable_timing:
218
- start.record()
219
- scores, sequences = [], [input_ids]
220
- sequences_cat = input_ids
221
- while not should_stop(sequences[-1], inference_params):
222
- logits = get_logits(sequences[-1], inference_params)
223
- if output_scores:
224
- scores.append(logits.clone())
225
- inference_params.seqlen_offset += sequences[-1].shape[1]
226
- if repetition_penalty == 1.0:
227
- sampled_tokens = sample_tokens(logits, inference_params)
228
- else:
229
- logits = modify_logit_for_repetition_penalty(
230
- logits, sequences_cat, repetition_penalty
231
- )
232
- sampled_tokens = sample_tokens(logits, inference_params)
233
- sequences_cat = torch.cat([sequences_cat, sampled_tokens], dim=1)
234
- sequences.append(sampled_tokens)
235
- if streamer is not None:
236
- streamer.put(sampled_tokens.cpu())
237
- if streamer is not None:
238
- streamer.end()
239
- if enable_timing:
240
- end.record()
241
- torch.cuda.synchronize()
242
- print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
243
- output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
244
- return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))
245
-
246
-
247
- class GenerationMixin:
248
- def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
249
- raise NotImplementedError
250
-
251
- def generate(
252
- self,
253
- input_ids,
254
- max_length,
255
- top_k=1,
256
- top_p=0.0,
257
- min_p=0.0,
258
- temperature=1.0,
259
- return_dict_in_generate=False,
260
- output_scores=False,
261
- **kwargs,
262
- ):
263
- output = decode(
264
- input_ids, self, max_length, top_k=top_k, top_p=top_p, min_p = min_p, temperature=temperature, output_scores=output_scores, **kwargs
265
- )
266
- if not output_scores:
267
- output.scores = None
268
- return output if return_dict_in_generate else output.sequences
269
-
270
-
271
- @dataclass
272
- class DecodingCGCache:
273
- max_batch_size: int = 0
274
- max_seqlen: int = 0
275
- device = None
276
- dtype = None
277
- callables: dict = field(default_factory=dict)
278
- mempool = None
279
- inference_params: Optional[InferenceParams] = None
280
- run: Optional[Callable] = None
281
-
282
-
283
- @torch.inference_mode()
284
- def update_graph_cache(
285
- model,
286
- cache,
287
- batch_size,
288
- seqlen_og,
289
- max_seqlen,
290
- decoding_seqlens=(1,),
291
- dtype=None,
292
- n_warmups=2,
293
- ):
294
- if cache is None:
295
- cache = DecodingCGCache()
296
- param_example = next(iter(model.parameters()))
297
- device = param_example.device
298
- if dtype is None:
299
- dtype = param_example.dtype
300
- if (
301
- (device, dtype) != (cache.device, cache.dtype)
302
- or batch_size > cache.max_batch_size
303
- or max_seqlen > cache.max_seqlen
304
- ): # Invalidate the cache
305
- cache.callables = {}
306
- cache.mempool = None
307
- cache.inference_params = None
308
- gc.collect()
309
- cache.device, cache.dtype = device, dtype
310
- cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
311
- assert hasattr(model, "allocate_inference_cache"), "CUDA graph decoding requires that the model has a method allocate_inference_cache"
312
- inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
313
- lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
314
- cache.inference_params = InferenceParams(
315
- max_seqlen=max_seqlen,
316
- max_batch_size=batch_size,
317
- seqlen_offset=seqlen_og,
318
- key_value_memory_dict=inf_cache,
319
- lengths_per_sample=lengths_per_sample,
320
- )
321
- cache.mempool = torch.cuda.graphs.graph_pool_handle()
322
- for decoding_seqlen in decoding_seqlens:
323
- if (batch_size, decoding_seqlen) not in cache.callables:
324
- cache.callables[batch_size, decoding_seqlen] = capture_graph(
325
- model,
326
- cache.inference_params,
327
- batch_size,
328
- max_seqlen,
329
- decoding_seqlen=decoding_seqlen,
330
- mempool=cache.mempool,
331
- n_warmups=n_warmups,
332
- )
333
-
334
- def dispatch(input_ids, position_ids, seqlen):
335
- batch_size, decoding_seqlen = input_ids.shape[:2]
336
- return cache.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen)
337
-
338
- cache.run = dispatch
339
- cache.inference_params.seqlen_offset = 0 # Reset so it's not confusing
340
- return cache
341
-
342
-
343
- def capture_graph(
344
- model, inference_params, batch_size, max_seqlen, decoding_seqlen=1, mempool=None, n_warmups=2
345
- ):
346
- device = next(iter(model.parameters())).device
347
- input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
348
- position_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
349
- seqlen_offset_og = inference_params.seqlen_offset
350
- inference_params.seqlen_offset = max_seqlen - decoding_seqlen
351
- inference_params.lengths_per_sample[:] = inference_params.seqlen_offset
352
-
353
- # Warmup before capture
354
- s = torch.cuda.Stream()
355
- s.wait_stream(torch.cuda.current_stream())
356
- with torch.cuda.stream(s):
357
- for _ in range(n_warmups):
358
- logits = model(
359
- input_ids,
360
- position_ids=position_ids,
361
- inference_params=inference_params,
362
- num_last_tokens=decoding_seqlen,
363
- ).logits
364
- s.synchronize()
365
- # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
366
- # which requires that graph launches and non-captured launches do not overlap (I think,
367
- # that's how I interpret the documentation). I'm not sure if this is required.
368
- if torch.distributed.is_initialized():
369
- torch.distributed.barrier()
370
- torch.cuda.current_stream().wait_stream(s)
371
- # Captures the graph
372
- # To allow capture, automatically sets a side stream as the current stream in the context
373
- graph = torch.cuda.CUDAGraph()
374
- with torch.cuda.graph(graph, pool=mempool):
375
- logits = model(
376
- input_ids,
377
- position_ids=position_ids,
378
- inference_params=inference_params,
379
- num_last_tokens=decoding_seqlen,
380
- ).logits
381
-
382
- def run(new_input_ids, new_position_ids, seqlen):
383
- inference_params.lengths_per_sample[:] = seqlen
384
- input_ids.copy_(new_input_ids)
385
- position_ids.copy_(new_position_ids)
386
- graph.replay()
387
- return logits.clone()
388
-
389
- inference_params.seqlen_offset = seqlen_offset_og
390
- return run
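Editor's note (not part of the diff): a minimal sketch of how the decoding utilities above are typically driven through `GenerationMixin.generate`. It assumes the package exposes `MambaLMHeadModel` (as in the `__init__.py` files elsewhere in this diff) and uses a public checkpoint / tokenizer pair as placeholders.

```python
import torch
from transformers import AutoTokenizer

from mamba_ssm import MambaLMHeadModel  # assumed export, mirroring the package __init__.py in this diff

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")  # placeholder tokenizer
model = MambaLMHeadModel.from_pretrained(
    "state-spaces/mamba-130m", device="cuda", dtype=torch.float16     # placeholder checkpoint
)

input_ids = tokenizer("Structured state space models", return_tensors="pt").input_ids.to("cuda")
out = model.generate(
    input_ids,
    max_length=64,
    top_k=0,                  # 0 disables top-k, leaving nucleus (top-p) filtering only
    top_p=0.9,
    temperature=0.8,
    repetition_penalty=1.1,   # forwarded to decode() through **kwargs
    cg=True,                  # capture the per-token decoding step in a CUDA graph
    return_dict_in_generate=True,
    output_scores=True,
)
print(tokenizer.decode(out.sequences[0]))
```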
build/torch210-cxx11-cu126-x86_64-linux/utils/hf.py DELETED
@@ -1,23 +0,0 @@
1
- import json
2
-
3
- import torch
4
-
5
- from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
6
- from transformers.utils.hub import cached_file
7
-
8
-
9
- def load_config_hf(model_name):
10
- resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)
11
- return json.load(open(resolved_archive_file))
12
-
13
-
14
- def load_state_dict_hf(model_name, device=None, dtype=None):
15
- # If not fp32, then we don't want to load directly to the GPU
16
- mapped_device = "cpu" if dtype not in [torch.float32, None] else device
17
- resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
18
- state_dict = torch.load(resolved_archive_file, map_location=mapped_device)
19
- # Convert dtype before moving to GPU to save memory
20
- if dtype is not None:
21
- state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
22
- state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
23
- return state_dict
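Editor's note: a short sketch of how these two helpers combine (this is what `MambaLMHeadModel.from_pretrained` in `models/mixer_seq_simple.py` does internally); the model id is a placeholder.

```python
import torch

config_data = load_config_hf("state-spaces/mamba-130m")   # plain dict parsed from config.json
state_dict = load_state_dict_hf(
    "state-spaces/mamba-130m", device="cuda", dtype=torch.float16
)  # loaded to CPU first for non-fp32 dtypes, then cast and moved
```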
build/torch210-cxx11-cu126-x86_64-linux/utils/torch.py DELETED
@@ -1,21 +0,0 @@
1
- import torch
2
- from functools import partial
3
- from typing import Callable
4
-
5
- def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool):
6
- def decorator(*args, **kwargs):
7
- if cuda_amp_deprecated:
8
- kwargs["device_type"] = "cuda"
9
- return dec(*args, **kwargs)
10
- return decorator
11
-
12
-
13
- if hasattr(torch.amp, "custom_fwd"): # type: ignore[attr-defined]
14
- deprecated = True
15
- from torch.amp import custom_fwd, custom_bwd # type: ignore[attr-defined]
16
- else:
17
- deprecated = False
18
- from torch.cuda.amp import custom_fwd, custom_bwd
19
-
20
- custom_fwd = custom_amp_decorator(custom_fwd, deprecated)
21
- custom_bwd = custom_amp_decorator(custom_bwd, deprecated)
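Editor's note: a small sketch of what this shim buys. The same decorators work whether PyTorch provides the new `torch.amp` API (which needs `device_type`, injected by the wrapper) or the older `torch.cuda.amp` API; `_Double` is purely illustrative.

```python
import torch


class _Double(torch.autograd.Function):
    @staticmethod
    @custom_fwd            # wrapper adds device_type="cuda" when the new torch.amp API is in use
    def forward(ctx, x):
        return 2 * x

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_out):
        return 2 * grad_out
```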
build/torch210-cxx11-cu128-x86_64-linux/__init__.py DELETED
@@ -1,14 +0,0 @@
1
- __version__ = "2.2.4"
2
-
3
- from .ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
4
- from .modules.mamba_simple import Mamba
5
- from .modules.mamba2 import Mamba2
6
- from .models.mixer_seq_simple import MambaLMHeadModel
7
-
8
- __all__ = [
9
- "selective_scan_fn",
10
- "mamba_inner_fn",
11
- "Mamba",
12
- "Mamba2",
13
- "MambaLMHeadModel",
14
- ]
build/torch210-cxx11-cu128-x86_64-linux/_mamba_ssm_b2a7fd5.abi3.so DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:2cebad781003a612eea29f35ebaf4a1905057ac6e20cdd12a216e4e403b34095
3
- size 610662240
build/torch210-cxx11-cu128-x86_64-linux/_ops.py DELETED
@@ -1,9 +0,0 @@
1
- import torch
2
- from . import _mamba_ssm_b2a7fd5
3
- ops = torch.ops._mamba_ssm_b2a7fd5
4
-
5
- def add_op_namespace_prefix(op_name: str):
6
- """
7
- Prefix op by namespace.
8
- """
9
- return f"_mamba_ssm_b2a7fd5::{op_name}"
build/torch210-cxx11-cu128-x86_64-linux/distributed/__init__.py DELETED
File without changes
build/torch210-cxx11-cu128-x86_64-linux/distributed/distributed_utils.py DELETED
@@ -1,144 +0,0 @@
1
- from typing import Optional
2
-
3
- import torch
4
- from torch import Tensor
5
- from torch.distributed import ProcessGroup
6
-
7
- # `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
8
- # `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
9
- # version of PyTorch. The following 4 lines are for backward compatibility with
10
- # older PyTorch.
11
- if "all_gather_into_tensor" not in dir(torch.distributed):
12
- torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
13
- if "reduce_scatter_tensor" not in dir(torch.distributed):
14
- torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base
15
-
16
-
17
- # Raw operation, does not support autograd, but does support async
18
- def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
19
- world_size = torch.distributed.get_world_size(process_group)
20
- output = torch.empty(
21
- world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device
22
- )
23
- handle = torch.distributed.all_gather_into_tensor(
24
- output, input_.contiguous(), group=process_group, async_op=async_op
25
- )
26
- return output, handle
27
-
28
-
29
- # Raw operation, does not support autograd, but does support async
30
- def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
31
- world_size = torch.distributed.get_world_size(process_group)
32
- assert input_.shape[0] % world_size == 0
33
- output = torch.empty(
34
- input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device
35
- )
36
- handle = torch.distributed.reduce_scatter_tensor(
37
- output, input_.contiguous(), group=process_group, async_op=async_op
38
- )
39
- return output, handle
40
-
41
-
42
- # Raw operation, does not support autograd, but does support async
43
- def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
44
- input_ = input_.contiguous()
45
- handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op)
46
- return input_, handle
47
-
48
-
49
- class AllGatherFunc(torch.autograd.Function):
50
- """Gather the input from sequence parallel region and concatenate."""
51
-
52
- @staticmethod
53
- def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
54
- ctx.process_group = process_group
55
- output, _ = all_gather_raw(input_, process_group)
56
- return output
57
-
58
- @staticmethod
59
- def backward(ctx, grad_output: Tensor):
60
- grad_input, _ = reduce_scatter_raw(grad_output, ctx.process_group)
61
- return grad_input, None
62
-
63
-
64
- # Supports autograd, but does not support async
65
- all_gather = AllGatherFunc.apply
66
-
67
-
68
- class ReduceScatterFunc(torch.autograd.Function):
69
- """Reduce scatter the input from the sequence parallel region and concatenate."""
70
-
71
- @staticmethod
72
- def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
73
- ctx.process_group = process_group
74
- output, _ = reduce_scatter_raw(input_, process_group)
75
- return output
76
-
77
- @staticmethod
78
- def backward(ctx, grad_output: Tensor):
79
- grad_input, _ = all_gather_raw(grad_output, ctx.process_group)
80
- return grad_input, None
81
-
82
-
83
- # Supports autograd, but does not support async
84
- reduce_scatter = ReduceScatterFunc.apply
85
-
86
-
87
- class AllReduceFunc(torch.autograd.Function):
88
- """Gather the input from sequence parallel region and concatenate."""
89
-
90
- @staticmethod
91
- def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
92
- ctx.process_group = process_group
93
- output, _ = all_reduce_raw(input_, process_group)
94
- return output
95
-
96
- @staticmethod
97
- def backward(ctx, grad_output: Tensor):
98
- return grad_output, None
99
-
100
-
101
- # Supports autograd, but does not support async
102
- all_reduce = AllReduceFunc.apply
103
-
104
-
105
- def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):
106
- # We want to iterate over parameters with _shared_params=True in the same order,
107
- # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
108
- params_shared = {
109
- name: p for name, p in model.named_parameters() if getattr(p, "_shared_params", False)
110
- }
111
- for _, p in sorted(params_shared.items()):
112
- with torch.no_grad():
113
- # Broadcast needs src to be global rank, not group rank
114
- torch.distributed.broadcast(
115
- p, src=torch.distributed.get_global_rank(process_group, 0), group=process_group
116
- )
117
-
118
-
119
- # Ref: https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/optimizer/optimizer.py#L256
120
- def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):
121
- # We want to iterate over parameters with _sequence_parallel=True in the same order,
122
- # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
123
- params_seqparallel = {
124
- name: p for name, p in model.named_parameters() if getattr(p, "_sequence_parallel", False)
125
- }
126
- grads = [p.grad for _, p in sorted(params_seqparallel.items())]
127
- if grads:
128
- with torch.no_grad():
129
- coalesced = torch._utils._flatten_dense_tensors(grads)
130
- torch.distributed.all_reduce(coalesced, group=process_group)
131
- for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
132
- buf.copy_(synced)
133
-
134
-
135
- def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int:
136
- """Get the dim for the local rank derived from splitting dim on world_size processes.
137
-
138
- The split may not be even across the world_size processes.
139
- """
140
- multiple = dim // multiple_of
141
- div = multiple // world_size
142
- mod = multiple % world_size
143
- local_multiple = div + int(local_rank < mod)
144
- return local_multiple * multiple_of
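Editor's note: a worked example of the uneven split handled by `get_dim_for_local_rank` (numbers chosen only for illustration).

```python
# Splitting dim=1280 with multiple_of=64 across 3 ranks:
# multiple = 1280 // 64 = 20, div = 20 // 3 = 6, mod = 20 % 3 = 2,
# so ranks 0 and 1 get (6 + 1) * 64 = 448 and rank 2 gets 6 * 64 = 384.
assert get_dim_for_local_rank(1280, world_size=3, local_rank=0, multiple_of=64) == 448
assert get_dim_for_local_rank(1280, world_size=3, local_rank=2, multiple_of=64) == 384
```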
build/torch210-cxx11-cu128-x86_64-linux/distributed/tensor_parallel.py DELETED
@@ -1,296 +0,0 @@
1
- # Copyright (c) 2024, Tri Dao.
2
- # The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py
3
- from typing import Optional
4
-
5
- import torch
6
- import torch.nn as nn
7
- import torch.nn.functional as F
8
- from torch import Tensor
9
- from torch.distributed import ProcessGroup
10
- from ..utils.torch import custom_bwd, custom_fwd
11
-
12
- from einops import rearrange
13
-
14
- from ..distributed.distributed_utils import (
15
- all_gather_raw,
16
- all_reduce,
17
- all_reduce_raw,
18
- reduce_scatter,
19
- reduce_scatter_raw,
20
- )
21
-
22
-
23
- class ParallelLinearFunc(torch.autograd.Function):
24
- @staticmethod
25
- @custom_fwd
26
- def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True):
27
- """
28
- If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
29
- with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
30
- """
31
- ctx.compute_weight_gradient = weight.requires_grad
32
- ctx.process_group = process_group
33
- ctx.sequence_parallel = sequence_parallel
34
-
35
- if torch.is_autocast_enabled():
36
- x = x.to(dtype=torch.get_autocast_gpu_dtype())
37
- x = x.contiguous()
38
- if process_group is not None and sequence_parallel:
39
- # We want to kick off the all_gather early, before weight dtype conversion
40
- total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
41
- else:
42
- total_x = x
43
-
44
- if torch.is_autocast_enabled():
45
- weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
46
- bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None
47
- weight = weight.contiguous()
48
- if process_group is not None and sequence_parallel:
49
- handle_x.wait()
50
- batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
51
- batch_dim = batch_shape.numel()
52
- # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
53
- output = F.linear(total_x, weight, bias)
54
- if ctx.compute_weight_gradient:
55
- ctx.save_for_backward(x, weight)
56
- else:
57
- ctx.save_for_backward(weight)
58
- return output
59
-
60
- @staticmethod
61
- @custom_bwd
62
- def backward(ctx, grad_output):
63
- grad_output = grad_output.contiguous()
64
- process_group = ctx.process_group
65
- sequence_parallel = ctx.sequence_parallel
66
- if ctx.compute_weight_gradient:
67
- x, weight = ctx.saved_tensors
68
- if process_group is not None and sequence_parallel:
69
- total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
70
- else:
71
- total_x = x
72
- else:
73
- (weight,) = ctx.saved_tensors
74
- total_x = None
75
- batch_shape = grad_output.shape[:-1]
76
- batch_dim = batch_shape.numel()
77
- grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
78
- if ctx.needs_input_grad[0]:
79
- grad_input = F.linear(grad_output, weight.t())
80
- grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
81
- if process_group is not None:
82
- reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
83
- grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
84
- else:
85
- grad_input = None
86
- if ctx.needs_input_grad[1]:
87
- assert ctx.compute_weight_gradient
88
- if process_group is not None and sequence_parallel:
89
- handle_x.wait()
90
- grad_weight = torch.einsum(
91
- "bo,bi->oi", grad_output, total_x.reshape(batch_dim, total_x.shape[-1])
92
- )
93
- else:
94
- grad_weight = None
95
- grad_bias = grad_output.sum(dim=0) if ctx.needs_input_grad[2] else None
96
- if process_group is not None and ctx.needs_input_grad[0]:
97
- handle_grad_input.wait()
98
- return grad_input, grad_weight, grad_bias, None, None
99
-
100
-
101
- def parallel_linear_func(
102
- x: Tensor,
103
- weight: Tensor,
104
- bias: Optional[Tensor] = None,
105
- process_group: Optional[ProcessGroup] = None,
106
- sequence_parallel: bool = True,
107
- ):
108
- return ParallelLinearFunc.apply(x, weight, bias, process_group, sequence_parallel)
109
-
110
-
111
- class ColumnParallelLinear(nn.Linear):
112
- def __init__(
113
- self,
114
- in_features: int,
115
- out_features: int,
116
- process_group: ProcessGroup,
117
- bias: bool = True,
118
- sequence_parallel=True,
119
- multiple_of=1,
120
- device=None,
121
- dtype=None,
122
- ) -> None:
123
- world_size = torch.distributed.get_world_size(process_group)
124
- if out_features % multiple_of:
125
- raise ValueError(f"out_features ({out_features}) must be a multiple of {multiple_of}")
126
- multiple = out_features // multiple_of
127
- # We want to split @multiple across world_size, but it could be an uneven split
128
- div = multiple // world_size
129
- mod = multiple % world_size
130
- # The first @mod ranks get @div + 1 copies, the rest get @div copies
131
- local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
132
- super().__init__(
133
- in_features, local_multiple * multiple_of, bias=bias, device=device, dtype=dtype
134
- )
135
- self.process_group = process_group
136
- self.sequence_parallel = sequence_parallel
137
-
138
- def forward(self, x):
139
- # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
140
- # we do an all_gather of x before doing the matmul.
141
- # If not, then the input is already gathered.
142
- return parallel_linear_func(
143
- x,
144
- self.weight,
145
- self.bias,
146
- process_group=self.process_group,
147
- sequence_parallel=self.sequence_parallel,
148
- )
149
-
150
-
151
- class RowParallelLinear(nn.Linear):
152
- def __init__(
153
- self,
154
- in_features: int,
155
- out_features: int,
156
- process_group: ProcessGroup,
157
- bias: bool = True,
158
- sequence_parallel=True,
159
- multiple_of=1,
160
- device=None,
161
- dtype=None,
162
- ) -> None:
163
- world_size = torch.distributed.get_world_size(process_group)
164
- rank = torch.distributed.get_rank(process_group)
165
- if in_features % multiple_of:
166
- raise ValueError(f"in_features ({in_features}) must be a multiple of {multiple_of}")
167
- multiple = in_features // multiple_of
168
- # We want to split @multiple across world_size, but it could be an uneven split
169
- div = multiple // world_size
170
- mod = multiple % world_size
171
- # The first @mod ranks get @div + 1 copies, the rest get @div copies
172
- local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
173
- # Only rank 0 will have bias
174
- super().__init__(
175
- local_multiple * multiple_of,
176
- out_features,
177
- bias=bias and rank == 0,
178
- device=device,
179
- dtype=dtype,
180
- )
181
- self.process_group = process_group
182
- self.sequence_parallel = sequence_parallel
183
-
184
- def forward(self, x):
185
- """
186
- We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
187
- a reduce_scatter of the result.
188
- """
189
- out = parallel_linear_func(x, self.weight, self.bias)
190
- reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
191
- return reduce_fn(out, self.process_group)
192
-
193
-
194
- class VocabParallelEmbedding(nn.Embedding):
195
- def __init__(self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs):
196
- self.process_group = process_group
197
- if process_group is not None:
198
- world_size = torch.distributed.get_world_size(process_group)
199
- if num_embeddings % world_size != 0:
200
- raise ValueError(
201
- f"num_embeddings ({num_embeddings}) must be divisible by "
202
- f"world_size ({world_size})"
203
- )
204
- if world_size > 1 and padding_idx is not None:
205
- raise RuntimeError("ParallelEmbedding does not support padding_idx")
206
- else:
207
- world_size = 1
208
- super().__init__(num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs)
209
-
210
- def forward(self, input: Tensor) -> Tensor:
211
- if self.process_group is None:
212
- return super().forward(input)
213
- else:
214
- rank = torch.distributed.get_rank(self.process_group)
215
- vocab_size = self.num_embeddings
216
- vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size
217
- # Mask out-of-range vocab ids (True means the id belongs to another rank and must be masked).
218
- input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index)
219
- input = input - vocab_start_index
220
- input[input_ids_mask] = 0
221
- embeddings = super().forward(input)
222
- embeddings[input_ids_mask] = 0.0
223
- return embeddings
224
-
225
-
226
- class ColumnParallelEmbedding(nn.Embedding):
227
- def __init__(self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs):
228
- self.process_group = process_group
229
- if process_group is not None:
230
- world_size = torch.distributed.get_world_size(process_group)
231
- if embedding_dim % world_size != 0:
232
- raise ValueError(
233
- f"embedding_dim ({embedding_dim}) must be divisible by "
234
- f"world_size ({world_size})"
235
- )
236
- else:
237
- world_size = 1
238
- super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)
239
-
240
-
241
- class ParallelEmbeddings(nn.Module):
242
- def __init__(
243
- self,
244
- embed_dim,
245
- vocab_size,
246
- max_position_embeddings,
247
- process_group,
248
- padding_idx=None,
249
- sequence_parallel=True,
250
- device=None,
251
- dtype=None,
252
- ):
253
- """
254
- If max_position_embeddings <= 0, there's no position embeddings
255
- """
256
- factory_kwargs = {"device": device, "dtype": dtype}
257
- super().__init__()
258
- self.process_group = process_group
259
- self.sequence_parallel = sequence_parallel
260
- self.word_embeddings = VocabParallelEmbedding(
261
- vocab_size,
262
- embed_dim,
263
- padding_idx=padding_idx,
264
- process_group=process_group,
265
- **factory_kwargs,
266
- )
267
- self.max_position_embeddings = max_position_embeddings
268
- if self.max_position_embeddings > 0:
269
- self.position_embeddings = ColumnParallelEmbedding(
270
- max_position_embeddings, embed_dim, process_group=process_group, **factory_kwargs
271
- )
272
-
273
- def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
274
- """
275
- input_ids: (batch, seqlen)
276
- position_ids: (batch, seqlen)
277
- """
278
- batch_size, seqlen = input_ids.shape
279
- world_size = torch.distributed.get_world_size(self.process_group)
280
- embeddings = self.word_embeddings(input_ids)
281
- if self.max_position_embeddings > 0:
282
- if position_ids is None:
283
- position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
284
- position_embeddings = self.position_embeddings(position_ids)
285
- if world_size <= 1:
286
- embeddings = embeddings + position_embeddings
287
- else:
288
- partition_dim = self.position_embeddings.embedding_dim
289
- rank = torch.distributed.get_rank(self.process_group)
290
- embeddings[
291
- ..., rank * partition_dim : (rank + 1) * partition_dim
292
- ] += position_embeddings
293
- if combine_batch_seqlen_dim:
294
- embeddings = rearrange(embeddings, "b s d -> (b s) d")
295
- reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
296
- return embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group)
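Editor's note: a minimal sketch of the usual column-then-row pairing, assuming `torch.distributed` is already initialized and `tp_group` is the tensor-parallel process group; sizes and the SiLU activation are illustrative.

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F

tp_group = dist.group.WORLD  # placeholder: use your tensor-parallel group
fc1 = ColumnParallelLinear(1024, 4096, process_group=tp_group, sequence_parallel=True, device="cuda")
fc2 = RowParallelLinear(4096, 1024, process_group=tp_group, sequence_parallel=True, device="cuda")

# With sequence_parallel=True the input is sharded along the token dimension:
# fc1 all-gathers it before the matmul, fc2 reduce-scatters the result back.
x = torch.randn(512, 1024, device="cuda", dtype=fc1.weight.dtype)
y = fc2(F.silu(fc1(x)))
```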
build/torch210-cxx11-cu128-x86_64-linux/mamba_ssm/__init__.py DELETED
@@ -1,26 +0,0 @@
1
- import ctypes
2
- import sys
3
-
4
- import importlib
5
- from pathlib import Path
6
- from types import ModuleType
7
-
8
- def _import_from_path(file_path: Path) -> ModuleType:
9
- # We cannot use the module name as-is, after adding it to `sys.modules`,
10
- # it would also be used for other imports. So, we make a module name that
11
- # depends on the path for it to be unique using the hex-encoded hash of
12
- # the path.
13
- path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
14
- module_name = path_hash
15
- spec = importlib.util.spec_from_file_location(module_name, file_path)
16
- if spec is None:
17
- raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
18
- module = importlib.util.module_from_spec(spec)
19
- if module is None:
20
- raise ImportError(f"Cannot load module {module_name} from spec")
21
- sys.modules[module_name] = module
22
- spec.loader.exec_module(module) # type: ignore
23
- return module
24
-
25
-
26
- globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch210-cxx11-cu128-x86_64-linux/metadata.json DELETED
@@ -1 +0,0 @@
1
- {"python-depends":[]}
build/torch210-cxx11-cu128-x86_64-linux/models/__init__.py DELETED
File without changes
build/torch210-cxx11-cu128-x86_64-linux/models/config_mamba.py DELETED
@@ -1,18 +0,0 @@
1
- from dataclasses import dataclass, field
2
-
3
-
4
- @dataclass
5
- class MambaConfig:
6
-
7
- d_model: int = 2560
8
- d_intermediate: int = 0
9
- n_layer: int = 64
10
- vocab_size: int = 50277
11
- ssm_cfg: dict = field(default_factory=dict)
12
- attn_layer_idx: list = field(default_factory=list)
13
- attn_cfg: dict = field(default_factory=dict)
14
- rms_norm: bool = True
15
- residual_in_fp32: bool = True
16
- fused_add_norm: bool = True
17
- pad_vocab_size_multiple: int = 8
18
- tie_embeddings: bool = True
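Editor's note: an illustrative config, following the `ssm_cfg` / `attn_layer_idx` conventions consumed by `create_block` in `models/mixer_seq_simple.py` later in this diff; the `attn_cfg` keyword is an assumption about the `MHA` module's signature.

```python
cfg = MambaConfig(
    d_model=768,
    n_layer=24,
    vocab_size=50280,
    ssm_cfg={"layer": "Mamba2"},  # "layer" is popped by create_block(); the default is "Mamba1"
    attn_layer_idx=[8, 16],       # these layer indices get an MHA mixer instead of an SSM
    attn_cfg={"num_heads": 12},   # assumed MHA keyword argument
    tie_embeddings=True,
)
```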
build/torch210-cxx11-cu128-x86_64-linux/models/mixer_seq_simple.py DELETED
@@ -1,309 +0,0 @@
1
- # Copyright (c) 2023, Albert Gu, Tri Dao.
2
-
3
- import math
4
- from functools import partial
5
- import json
6
- import os
7
- import copy
8
-
9
- from collections import namedtuple
10
-
11
- import torch
12
- import torch.nn as nn
13
-
14
- from .config_mamba import MambaConfig
15
- from ..modules.mamba_simple import Mamba
16
- from ..modules.mamba2 import Mamba2
17
- from ..modules.mha import MHA
18
- from ..modules.mlp import GatedMLP
19
- from ..modules.block import Block
20
- from ..utils.generation import GenerationMixin
21
- from ..utils.hf import load_config_hf, load_state_dict_hf
22
-
23
- try:
24
- from ..ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
25
- except ImportError:
26
- RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
27
-
28
-
29
- def create_block(
30
- d_model,
31
- d_intermediate,
32
- ssm_cfg=None,
33
- attn_layer_idx=None,
34
- attn_cfg=None,
35
- norm_epsilon=1e-5,
36
- rms_norm=False,
37
- residual_in_fp32=False,
38
- fused_add_norm=False,
39
- layer_idx=None,
40
- device=None,
41
- dtype=None,
42
- ):
43
- if ssm_cfg is None:
44
- ssm_cfg = {}
45
- if attn_layer_idx is None:
46
- attn_layer_idx = []
47
- if attn_cfg is None:
48
- attn_cfg = {}
49
- factory_kwargs = {"device": device, "dtype": dtype}
50
- if layer_idx not in attn_layer_idx:
51
- # Create a copy of the config to modify
52
- ssm_cfg = copy.deepcopy(ssm_cfg) if ssm_cfg is not None else {}
53
- ssm_layer = ssm_cfg.pop("layer", "Mamba1")
54
- if ssm_layer not in ["Mamba1", "Mamba2"]:
55
- raise ValueError(f"Invalid ssm_layer: {ssm_layer}, only support Mamba1 and Mamba2")
56
- mixer_cls = partial(
57
- Mamba2 if ssm_layer == "Mamba2" else Mamba,
58
- layer_idx=layer_idx,
59
- **ssm_cfg,
60
- **factory_kwargs
61
- )
62
- else:
63
- mixer_cls = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs)
64
- norm_cls = partial(
65
- nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
66
- )
67
- if d_intermediate == 0:
68
- mlp_cls = nn.Identity
69
- else:
70
- mlp_cls = partial(
71
- GatedMLP, hidden_features=d_intermediate, out_features=d_model, **factory_kwargs
72
- )
73
- block = Block(
74
- d_model,
75
- mixer_cls,
76
- mlp_cls,
77
- norm_cls=norm_cls,
78
- fused_add_norm=fused_add_norm,
79
- residual_in_fp32=residual_in_fp32,
80
- )
81
- block.layer_idx = layer_idx
82
- return block
83
-
84
-
85
- # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
86
- def _init_weights(
87
- module,
88
- n_layer,
89
- initializer_range=0.02, # Now only used for embedding layer.
90
- rescale_prenorm_residual=True,
91
- n_residuals_per_layer=1, # Change to 2 if we have MLP
92
- ):
93
- if isinstance(module, nn.Linear):
94
- if module.bias is not None:
95
- if not getattr(module.bias, "_no_reinit", False):
96
- nn.init.zeros_(module.bias)
97
- elif isinstance(module, nn.Embedding):
98
- nn.init.normal_(module.weight, std=initializer_range)
99
-
100
- if rescale_prenorm_residual:
101
- # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
102
- # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
103
- # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
104
- # > -- GPT-2 :: https://openai.com/blog/better-language-models/
105
- #
106
- # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
107
- for name, p in module.named_parameters():
108
- if name in ["out_proj.weight", "fc2.weight"]:
109
- # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
110
- # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
111
- # We need to reinit p since this code could be called multiple times
112
- # Having just p *= scale would repeatedly scale it down
113
- nn.init.kaiming_uniform_(p, a=math.sqrt(5))
114
- with torch.no_grad():
115
- p /= math.sqrt(n_residuals_per_layer * n_layer)
116
-
117
-
118
- class MixerModel(nn.Module):
119
- def __init__(
120
- self,
121
- d_model: int,
122
- n_layer: int,
123
- d_intermediate: int,
124
- vocab_size: int,
125
- ssm_cfg=None,
126
- attn_layer_idx=None,
127
- attn_cfg=None,
128
- norm_epsilon: float = 1e-5,
129
- rms_norm: bool = False,
130
- initializer_cfg=None,
131
- fused_add_norm=False,
132
- residual_in_fp32=False,
133
- device=None,
134
- dtype=None,
135
- ) -> None:
136
- factory_kwargs = {"device": device, "dtype": dtype}
137
- super().__init__()
138
- self.residual_in_fp32 = residual_in_fp32
139
-
140
- self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)
141
-
142
- # We change the order of residual and layer norm:
143
- # Instead of LN -> Attn / MLP -> Add, we do:
144
- # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
145
- # the main branch (output of MLP / Mixer). The model definition is unchanged.
146
- # This is for performance reason: we can fuse add + layer_norm.
147
- self.fused_add_norm = fused_add_norm
148
- if self.fused_add_norm:
149
- if layer_norm_fn is None or rms_norm_fn is None:
150
- raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
151
-
152
- self.layers = nn.ModuleList(
153
- [
154
- create_block(
155
- d_model,
156
- d_intermediate=d_intermediate,
157
- ssm_cfg=ssm_cfg,
158
- attn_layer_idx=attn_layer_idx,
159
- attn_cfg=attn_cfg,
160
- norm_epsilon=norm_epsilon,
161
- rms_norm=rms_norm,
162
- residual_in_fp32=residual_in_fp32,
163
- fused_add_norm=fused_add_norm,
164
- layer_idx=i,
165
- **factory_kwargs,
166
- )
167
- for i in range(n_layer)
168
- ]
169
- )
170
-
171
- self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
172
- d_model, eps=norm_epsilon, **factory_kwargs
173
- )
174
-
175
- self.apply(
176
- partial(
177
- _init_weights,
178
- n_layer=n_layer,
179
- **(initializer_cfg if initializer_cfg is not None else {}),
180
- n_residuals_per_layer=1 if d_intermediate == 0 else 2, # 2 if we have MLP
181
- )
182
- )
183
-
184
- def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
185
- return {
186
- i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
187
- for i, layer in enumerate(self.layers)
188
- }
189
-
190
- def forward(self, input_ids, inference_params=None, **mixer_kwargs):
191
- hidden_states = self.embedding(input_ids)
192
- residual = None
193
- for layer in self.layers:
194
- hidden_states, residual = layer(
195
- hidden_states, residual, inference_params=inference_params, **mixer_kwargs
196
- )
197
- if not self.fused_add_norm:
198
- residual = (hidden_states + residual) if residual is not None else hidden_states
199
- hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
200
- else:
201
- # Set prenorm=False here since we don't need the residual
202
- hidden_states = layer_norm_fn(
203
- hidden_states,
204
- self.norm_f.weight,
205
- self.norm_f.bias,
206
- eps=self.norm_f.eps,
207
- residual=residual,
208
- prenorm=False,
209
- residual_in_fp32=self.residual_in_fp32,
210
- is_rms_norm=isinstance(self.norm_f, RMSNorm)
211
- )
212
- return hidden_states
213
-
214
-
215
- class MambaLMHeadModel(nn.Module, GenerationMixin):
216
-
217
- def __init__(
218
- self,
219
- config: MambaConfig,
220
- initializer_cfg=None,
221
- device=None,
222
- dtype=None,
223
- ) -> None:
224
- self.config = config
225
- d_model = config.d_model
226
- n_layer = config.n_layer
227
- d_intermediate = config.d_intermediate
228
- vocab_size = config.vocab_size
229
- ssm_cfg = config.ssm_cfg
230
- attn_layer_idx = config.attn_layer_idx
231
- attn_cfg = config.attn_cfg
232
- rms_norm = config.rms_norm
233
- residual_in_fp32 = config.residual_in_fp32
234
- fused_add_norm = config.fused_add_norm
235
- pad_vocab_size_multiple = config.pad_vocab_size_multiple
236
- factory_kwargs = {"device": device, "dtype": dtype}
237
-
238
- super().__init__()
239
- if vocab_size % pad_vocab_size_multiple != 0:
240
- vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
241
- self.backbone = MixerModel(
242
- d_model=d_model,
243
- n_layer=n_layer,
244
- d_intermediate=d_intermediate,
245
- vocab_size=vocab_size,
246
- ssm_cfg=ssm_cfg,
247
- attn_layer_idx=attn_layer_idx,
248
- attn_cfg=attn_cfg,
249
- rms_norm=rms_norm,
250
- initializer_cfg=initializer_cfg,
251
- fused_add_norm=fused_add_norm,
252
- residual_in_fp32=residual_in_fp32,
253
- **factory_kwargs,
254
- )
255
- self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
256
-
257
- # Initialize weights and apply final processing
258
- self.apply(
259
- partial(
260
- _init_weights,
261
- n_layer=n_layer,
262
- **(initializer_cfg if initializer_cfg is not None else {}),
263
- )
264
- )
265
- self.tie_weights()
266
-
267
- def tie_weights(self):
268
- if self.config.tie_embeddings:
269
- self.lm_head.weight = self.backbone.embedding.weight
270
-
271
- def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
272
- return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
273
-
274
- def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0, **mixer_kwargs):
275
- """
276
- "position_ids" is just to be compatible with Transformer generation. We don't use it.
277
- num_last_tokens: if > 0, only return the logits for the last n tokens
278
- """
279
- hidden_states = self.backbone(input_ids, inference_params=inference_params, **mixer_kwargs)
280
- if num_last_tokens > 0:
281
- hidden_states = hidden_states[:, -num_last_tokens:]
282
- lm_logits = self.lm_head(hidden_states)
283
- CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
284
- return CausalLMOutput(logits=lm_logits)
285
-
286
- @classmethod
287
- def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
288
- config_data = load_config_hf(pretrained_model_name)
289
- config = MambaConfig(**config_data)
290
- model = cls(config, device=device, dtype=dtype, **kwargs)
291
- model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype))
292
- return model
293
-
294
- def save_pretrained(self, save_directory):
295
- """
296
- Minimal implementation of save_pretrained for MambaLMHeadModel.
297
- Save the model and its configuration file to a directory.
298
- """
299
- # Ensure save_directory exists
300
- os.makedirs(save_directory, exist_ok=True)
301
-
302
- # Save the model's state_dict
303
- model_path = os.path.join(save_directory, 'pytorch_model.bin')
304
- torch.save(self.state_dict(), model_path)
305
-
306
- # Save the configuration of the model
307
- config_path = os.path.join(save_directory, 'config.json')
308
- with open(config_path, 'w') as f:
309
- json.dump(self.config.__dict__, f, indent=4)
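Editor's note: a tiny end-to-end sketch of the model defined above, with random weights and toy sizes; `num_last_tokens=1` mirrors how the decoding loop in `utils/generation.py` requests only the final position's logits.

```python
import torch

config = MambaConfig(d_model=256, n_layer=4, vocab_size=1000)
model = MambaLMHeadModel(config, device="cuda", dtype=torch.float16)

tokens = torch.randint(0, 1000, (2, 16), device="cuda")
logits = model(tokens, num_last_tokens=1).logits  # (batch=2, 1, vocab_size)
print(logits.shape)
```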
build/torch210-cxx11-cu128-x86_64-linux/modules/__init__.py DELETED
File without changes
build/torch210-cxx11-cu128-x86_64-linux/modules/block.py DELETED
@@ -1,107 +0,0 @@
1
- # Copyright (c) 2024, Tri Dao, Albert Gu.
2
- from typing import Optional
3
-
4
- import torch
5
- from torch import nn, Tensor
6
-
7
- from ..ops.triton.layer_norm import RMSNorm, layer_norm_fn
8
-
9
-
10
- class Block(nn.Module):
11
- def __init__(
12
- self,
13
- dim,
14
- mixer_cls,
15
- mlp_cls,
16
- norm_cls=nn.LayerNorm,
17
- fused_add_norm=False,
18
- residual_in_fp32=False,
19
- ):
20
- """
21
- Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection.
22
-
23
- This Block has a slightly different structure compared to a regular
24
- prenorm Transformer block.
25
- The standard block is: LN -> MHA/MLP -> Add.
26
- [Ref: https://arxiv.org/abs/2002.04745]
27
- Here we have: Add -> LN -> Mixer, returning both
28
- the hidden_states (output of the mixer) and the residual.
29
- This is purely for performance reasons, as we can fuse add and LayerNorm.
30
- The residual needs to be provided (except for the very first block).
31
- """
32
- super().__init__()
33
- self.residual_in_fp32 = residual_in_fp32
34
- self.fused_add_norm = fused_add_norm
35
- self.norm = norm_cls(dim)
36
- self.mixer = mixer_cls(dim)
37
- if mlp_cls is not nn.Identity:
38
- self.norm2 = norm_cls(dim)
39
- self.mlp = mlp_cls(dim)
40
- else:
41
- self.mlp = None
42
- if self.fused_add_norm:
43
- assert RMSNorm is not None, "RMSNorm import fails"
44
- assert isinstance(
45
- self.norm, (nn.LayerNorm, RMSNorm)
46
- ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
47
-
48
- def forward(
49
- self,
50
- hidden_states: Tensor,
51
- residual: Optional[Tensor] = None,
52
- inference_params=None,
53
- **mixer_kwargs
54
- ):
55
- r"""Pass the input through the encoder layer.
56
-
57
- Args:
58
- hidden_states: the sequence to the encoder layer (required).
59
- residual: hidden_states = Mixer(LN(residual))
60
- """
61
- if not self.fused_add_norm:
62
- residual = (
63
- (hidden_states + residual) if residual is not None else hidden_states
64
- )
65
- hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
66
- if self.residual_in_fp32:
67
- residual = residual.to(torch.float32)
68
- else:
69
- hidden_states, residual = layer_norm_fn(
70
- hidden_states,
71
- self.norm.weight,
72
- self.norm.bias,
73
- residual=residual,
74
- prenorm=True,
75
- residual_in_fp32=self.residual_in_fp32,
76
- eps=self.norm.eps,
77
- is_rms_norm=isinstance(self.norm, RMSNorm),
78
- )
79
- hidden_states = self.mixer(
80
- hidden_states, inference_params=inference_params, **mixer_kwargs
81
- )
82
-
83
- if self.mlp is not None:
84
- if not self.fused_add_norm:
85
- residual = hidden_states + residual
86
- hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
87
- if self.residual_in_fp32:
88
- residual = residual.to(torch.float32)
89
- else:
90
- hidden_states, residual = layer_norm_fn(
91
- hidden_states,
92
- self.norm2.weight,
93
- self.norm2.bias,
94
- residual=residual,
95
- prenorm=True,
96
- residual_in_fp32=self.residual_in_fp32,
97
- eps=self.norm2.eps,
98
- is_rms_norm=isinstance(self.norm2, RMSNorm),
99
- )
100
- hidden_states = self.mlp(hidden_states)
101
-
102
- return hidden_states, residual
103
-
104
- def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
105
- return self.mixer.allocate_inference_cache(
106
- batch_size, max_seqlen, dtype=dtype, **kwargs
107
- )
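Editor's note: a schematic of how the Add -> LN -> Mixer restructuring reads at the call site, assuming `blocks` is an `nn.ModuleList` of the `Block` instances above and `embeddings` are the token embeddings.

```python
hidden_states, residual = embeddings, None  # residual is None for the very first block
for block in blocks:
    hidden_states, residual = block(hidden_states, residual)
# The caller then applies the final (optionally fused) add + norm, as MixerModel.forward does.
```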
build/torch210-cxx11-cu128-x86_64-linux/modules/mamba2.py DELETED
@@ -1,502 +0,0 @@
1
- # Copyright (c) 2024, Tri Dao, Albert Gu.
2
-
3
- import math
4
-
5
- import torch
6
- import torch.nn as nn
7
- import torch.nn.functional as F
8
-
9
- from einops import rearrange, repeat
10
-
11
- try:
12
- from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
13
- except ImportError:
14
- causal_conv1d_fn, causal_conv1d_update = None, None
15
-
16
- try:
17
- from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states
18
- except ImportError:
19
- causal_conv1d_varlen_states = None
20
-
21
- try:
22
- from ..ops.triton.selective_state_update import selective_state_update
23
- except ImportError:
24
- selective_state_update = None
25
-
26
- from ..ops.triton.layernorm_gated import RMSNorm as RMSNormGated
27
-
28
- from ..distributed.tensor_parallel import ColumnParallelLinear, RowParallelLinear
29
- from ..distributed.distributed_utils import all_reduce, reduce_scatter
30
-
31
- from ..ops.triton.ssd_combined import mamba_chunk_scan_combined
32
- from ..ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
33
-
34
- from huggingface_hub import PyTorchModelHubMixin
35
-
36
-
37
- class Mamba2(nn.Module, PyTorchModelHubMixin):
38
- def __init__(
39
- self,
40
- d_model,
41
- d_state=128,
42
- d_conv=4,
43
- conv_init=None,
44
- expand=2,
45
- headdim=64,
46
- d_ssm=None, # If not None, we only apply SSM on this many dimensions, the rest uses gated MLP
47
- ngroups=1,
48
- A_init_range=(1, 16),
49
- D_has_hdim=False,
50
- rmsnorm=True,
51
- norm_before_gate=False,
52
- dt_min=0.001,
53
- dt_max=0.1,
54
- dt_init_floor=1e-4,
55
- dt_limit=(0.0, float("inf")),
56
- bias=False,
57
- conv_bias=True,
58
- # Fused kernel and sharding options
59
- chunk_size=256,
60
- use_mem_eff_path=True,
61
- layer_idx=None, # Absorb kwarg for general module
62
- process_group=None,
63
- sequence_parallel=True,
64
- device=None,
65
- dtype=None,
66
- ):
67
- factory_kwargs = {"device": device, "dtype": dtype}
68
- super().__init__()
69
- self.d_model = d_model
70
- self.d_state = d_state
71
- self.d_conv = d_conv
72
- self.conv_init = conv_init
73
- self.expand = expand
74
- self.process_group = process_group
75
- self.sequence_parallel = sequence_parallel
76
- self.world_size = 1 if process_group is None else process_group.size()
77
- self.local_rank = 0 if process_group is None else process_group.rank()
78
- self.d_inner = (self.expand * self.d_model) // self.world_size
79
- assert self.d_inner * self.world_size == self.expand * self.d_model
80
- self.headdim = headdim
81
- self.d_ssm = self.d_inner if d_ssm is None else d_ssm // self.world_size
82
- assert ngroups % self.world_size == 0
83
- self.ngroups = ngroups // self.world_size
84
- assert self.d_ssm % self.headdim == 0
85
- self.nheads = self.d_ssm // self.headdim
86
- self.D_has_hdim = D_has_hdim
87
- self.rmsnorm = rmsnorm
88
- self.norm_before_gate = norm_before_gate
89
- self.dt_limit = dt_limit
90
- self.activation = "silu"
91
- self.chunk_size = chunk_size
92
- self.use_mem_eff_path = use_mem_eff_path
93
- self.layer_idx = layer_idx
94
-
95
- # Order: [z, x, B, C, dt]
96
- d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
97
- if self.process_group is None:
98
- self.in_proj = nn.Linear(
99
- self.d_model, d_in_proj, bias=bias, **factory_kwargs
100
- )
101
- else:
102
- self.in_proj = ColumnParallelLinear(
103
- self.d_model,
104
- d_in_proj * self.world_size,
105
- bias=bias,
106
- process_group=self.process_group,
107
- sequence_parallel=self.sequence_parallel,
108
- **factory_kwargs,
109
- )
110
-
111
- conv_dim = self.d_ssm + 2 * self.ngroups * self.d_state
112
- self.conv1d = nn.Conv1d(
113
- in_channels=conv_dim,
114
- out_channels=conv_dim,
115
- bias=conv_bias,
116
- kernel_size=d_conv,
117
- groups=conv_dim,
118
- padding=d_conv - 1,
119
- **factory_kwargs,
120
- )
121
- if self.conv_init is not None:
122
- nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)
123
-
124
- self.act = nn.SiLU()
125
-
126
- # Initialize log dt bias
127
- dt = torch.exp(
128
- torch.rand(self.nheads, **factory_kwargs)
129
- * (math.log(dt_max) - math.log(dt_min))
130
- + math.log(dt_min)
131
- )
132
- dt = torch.clamp(dt, min=dt_init_floor)
133
- # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
134
- inv_dt = dt + torch.log(-torch.expm1(-dt))
135
- self.dt_bias = nn.Parameter(inv_dt)
136
- # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
137
- # name.endswith("bias") in param_grouping.py
138
- self.dt_bias._no_weight_decay = True
139
-
140
- assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
141
- A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(
142
- *A_init_range
143
- )
144
- A_log = torch.log(A).to(dtype=dtype)
145
- self.A_log = nn.Parameter(A_log)
146
- self.A_log._no_weight_decay = True
147
-
148
- # D "skip" parameter
149
- self.D = nn.Parameter(
150
- torch.ones(self.d_ssm if self.D_has_hdim else self.nheads, device=device)
151
- )
152
- self.D._no_weight_decay = True
153
-
154
- if self.rmsnorm:
155
- assert RMSNormGated is not None
156
- self.norm = RMSNormGated(
157
- self.d_ssm,
158
- eps=1e-5,
159
- norm_before_gate=self.norm_before_gate,
160
- group_size=self.d_ssm // ngroups,
161
- **factory_kwargs,
162
- )
163
-
164
- if self.process_group is None:
165
- self.out_proj = nn.Linear(
166
- self.d_inner, self.d_model, bias=bias, **factory_kwargs
167
- )
168
- else:
169
- self.out_proj = RowParallelLinear(
170
- self.d_inner * self.world_size,
171
- self.d_model,
172
- bias=bias,
173
- process_group=self.process_group,
174
- sequence_parallel=self.sequence_parallel,
175
- **factory_kwargs,
176
- )
177
-
178
-    def forward(
-        self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_params=None
-    ):
-        """
-        u: (batch, seqlen, hidden_dim) if seqlen=None.
-            If seqlen is not None, u is (batch * seqlen, hidden_dim). This is so that when we
-            split u during sequence parallel, we split the batch * seqlen dimension
-            (in case batch is small).
-        Returns: same shape as u
-        """
-        seqlen_og = seqlen
-        if seqlen is None:
-            batch, seqlen, dim = u.shape
-        else:
-            batch_seqlen, dim = u.shape
-            batch = batch_seqlen // seqlen
-
-        conv_state, ssm_state = None, None
-        if inference_params is not None:
-            inference_batch = (
-                cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch
-            )
-            conv_state, ssm_state = self._get_states_from_cache(
-                inference_params, inference_batch
-            )
-            if inference_params.seqlen_offset > 0:
-                # The states are updated inplace
-                out, _, _ = self.step(u, conv_state, ssm_state)
-                return out
-
-        zxbcdt = self.in_proj(u)  # (B, L, d_in_proj) or (B * L, d_in_proj)
-        if seqlen_og is not None:
-            zxbcdt = rearrange(zxbcdt, "(b l) d -> b l d", l=seqlen)
-        # If the model is loaded in fp16, without the .float() here, A might be -inf
-        A = -torch.exp(self.A_log.float())  # (nheads) or (d_inner, d_state)
-        dt_limit_kwargs = (
-            {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
-        )
-        if self.use_mem_eff_path and inference_params is None:
-            out = mamba_split_conv1d_scan_combined(
-                zxbcdt,
-                rearrange(self.conv1d.weight, "d 1 w -> d w"),
-                self.conv1d.bias,
-                self.dt_bias,
-                A,
-                D=(
-                    rearrange(self.D, "(h p) -> h p", p=self.headdim)
-                    if self.D_has_hdim
-                    else self.D
-                ),
-                chunk_size=self.chunk_size,
-                seq_idx=seq_idx,
-                activation=self.activation,
-                rmsnorm_weight=self.norm.weight if self.rmsnorm else None,
-                rmsnorm_eps=self.norm.eps if self.rmsnorm else 1e-6,
-                outproj_weight=self.out_proj.weight,
-                outproj_bias=self.out_proj.bias,
-                headdim=None if self.D_has_hdim else self.headdim,
-                ngroups=self.ngroups,
-                norm_before_gate=self.norm_before_gate,
-                **dt_limit_kwargs,
-            )
-            if seqlen_og is not None:
-                out = rearrange(out, "b l d -> (b l) d")
-            if self.process_group is not None:
-                reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
-                out = reduce_fn(out, self.process_group)
-        else:
-            d_mlp = (
-                zxbcdt.shape[-1]
-                - 2 * self.d_ssm
-                - 2 * self.ngroups * self.d_state
-                - self.nheads
-            ) // 2
-            z0, x0, z, xBC, dt = torch.split(
-                zxbcdt,
-                [
-                    d_mlp,
-                    d_mlp,
-                    self.d_ssm,
-                    self.d_ssm + 2 * self.ngroups * self.d_state,
-                    self.nheads,
-                ],
-                dim=-1,
-            )
-            if conv_state is not None:
-                if cu_seqlens is None:
-                    # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
-                    # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
-                    xBC_t = rearrange(xBC, "b l d -> b d l")
-                    conv_state.copy_(
-                        F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))
-                    )  # Update state (B D W)
-                else:
-                    assert (
-                        causal_conv1d_varlen_states is not None
-                    ), "varlen inference requires causal_conv1d package"
-                    assert (
-                        batch == 1
-                    ), "varlen inference only supports batch dimension 1"
-                    conv_varlen_states = causal_conv1d_varlen_states(
-                        xBC.squeeze(0), cu_seqlens, state_len=conv_state.shape[-1]
-                    )
-                    conv_state.copy_(conv_varlen_states)
-            assert self.activation in ["silu", "swish"]
-            if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
-                assert (
-                    seq_idx is None
-                ), "varlen conv1d requires the causal_conv1d package"
-                xBC = self.act(
-                    self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[
-                        :, : -(self.d_conv - 1)
-                    ]
-                )  # (B, L, self.d_ssm + 2 * ngroups * d_state)
-            else:
-                xBC = causal_conv1d_fn(
-                    xBC.transpose(1, 2),
-                    rearrange(self.conv1d.weight, "d 1 w -> d w"),
-                    bias=self.conv1d.bias,
-                    activation=self.activation,
-                    seq_idx=seq_idx,
-                ).transpose(1, 2)
-            x, B, C = torch.split(
-                xBC,
-                [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state],
-                dim=-1,
-            )
-            y = mamba_chunk_scan_combined(
-                rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
-                dt,
-                A,
-                rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
-                rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
-                chunk_size=self.chunk_size,
-                D=(
-                    rearrange(self.D, "(h p) -> h p", p=self.headdim)
-                    if self.D_has_hdim
-                    else self.D
-                ),
-                z=(
-                    rearrange(z, "b l (h p) -> b l h p", p=self.headdim)
-                    if not self.rmsnorm
-                    else None
-                ),
-                dt_bias=self.dt_bias,
-                dt_softplus=True,
-                seq_idx=seq_idx,
-                cu_seqlens=cu_seqlens,
-                **dt_limit_kwargs,
-                return_final_states=ssm_state is not None,
-                return_varlen_states=cu_seqlens is not None
-                and inference_params is not None,
-            )
-            if ssm_state is not None:
-                y, last_state, *rest = y
-                if cu_seqlens is None:
-                    ssm_state.copy_(last_state)
-                else:
-                    varlen_states = rest[0]
-                    ssm_state.copy_(varlen_states)
-            y = rearrange(y, "b l h p -> b l (h p)")
-            if self.rmsnorm:
-                y = self.norm(y, z)
-            if d_mlp > 0:
-                y = torch.cat([F.silu(z0) * x0, y], dim=-1)
-            if seqlen_og is not None:
-                y = rearrange(y, "b l d -> (b l) d")
-            out = self.out_proj(y)
-        return out
-
-    def step(self, hidden_states, conv_state, ssm_state):
-        dtype = hidden_states.dtype
-        assert (
-            hidden_states.shape[1] == 1
-        ), "Only support decoding with 1 token at a time for now"
-        zxbcdt = self.in_proj(hidden_states.squeeze(1))  # (B 2D)
-        d_mlp = (
-            zxbcdt.shape[-1]
-            - 2 * self.d_ssm
-            - 2 * self.ngroups * self.d_state
-            - self.nheads
-        ) // 2
-        z0, x0, z, xBC, dt = torch.split(
-            zxbcdt,
-            [
-                d_mlp,
-                d_mlp,
-                self.d_ssm,
-                self.d_ssm + 2 * self.ngroups * self.d_state,
-                self.nheads,
-            ],
-            dim=-1,
-        )
-
-        # Conv step
-        if causal_conv1d_update is None:
-            conv_state.copy_(
-                torch.roll(conv_state, shifts=-1, dims=-1)
-            )  # Update state (B D W)
-            conv_state[:, :, -1] = xBC
-            xBC = torch.sum(
-                conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1
-            )  # (B D)
-            if self.conv1d.bias is not None:
-                xBC = xBC + self.conv1d.bias
-            xBC = self.act(xBC).to(dtype=dtype)
-        else:
-            xBC = causal_conv1d_update(
-                xBC,
-                conv_state,
-                rearrange(self.conv1d.weight, "d 1 w -> d w"),
-                self.conv1d.bias,
-                self.activation,
-            )
-
-        x, B, C = torch.split(
-            xBC,
-            [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state],
-            dim=-1,
-        )
-        A = -torch.exp(self.A_log.float())  # (nheads,)
-
-        # SSM step
-        if selective_state_update is None:
-            assert (
-                self.ngroups == 1
-            ), "Only support ngroups=1 for this inference code path"
-            # Discretize A and B
-            dt = F.softplus(dt + self.dt_bias.to(dtype=dt.dtype))  # (batch, nheads)
-            dA = torch.exp(dt * A)  # (batch, nheads)
-            x = rearrange(x, "b (h p) -> b h p", p=self.headdim)
-            dBx = torch.einsum("bh,bn,bhp->bhpn", dt, B, x)
-            ssm_state.copy_(ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx)
-            y = torch.einsum("bhpn,bn->bhp", ssm_state.to(dtype), C)
-            y = y + rearrange(self.D.to(dtype), "h -> h 1") * x
-            y = rearrange(y, "b h p -> b (h p)")
-            if not self.rmsnorm:
-                y = y * self.act(z)  # (B D)
-        else:
-            A = repeat(A, "h -> h p n", p=self.headdim, n=self.d_state).to(
-                dtype=torch.float32
-            )
-            dt = repeat(dt, "b h -> b h p", p=self.headdim)
-            dt_bias = repeat(self.dt_bias, "h -> h p", p=self.headdim)
-            D = repeat(self.D, "h -> h p", p=self.headdim)
-            B = rearrange(B, "b (g n) -> b g n", g=self.ngroups)
-            C = rearrange(C, "b (g n) -> b g n", g=self.ngroups)
-            x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.headdim)
-            if not self.rmsnorm:
-                z = rearrange(z, "b (h p) -> b h p", p=self.headdim)
-            y = selective_state_update(
-                ssm_state,
-                x_reshaped,
-                dt,
-                A,
-                B,
-                C,
-                D,
-                z=z if not self.rmsnorm else None,
-                dt_bias=dt_bias,
-                dt_softplus=True,
-            )
-            y = rearrange(y, "b h p -> b (h p)")
-        if self.rmsnorm:
-            y = self.norm(y, z)
-        if d_mlp > 0:
-            y = torch.cat([F.silu(z0) * x0, y], dim=-1)
-        out = self.out_proj(y)
-        return out.unsqueeze(1), conv_state, ssm_state
-
-    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
-        device = self.out_proj.weight.device
-        conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
-        conv_state = torch.zeros(
-            batch_size,
-            self.d_conv,
-            self.conv1d.weight.shape[0],
-            device=device,
-            dtype=conv_dtype,
-        ).transpose(1, 2)
-        ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype
-        ssm_state = torch.zeros(
-            batch_size,
-            self.nheads,
-            self.headdim,
-            self.d_state,
-            device=device,
-            dtype=ssm_dtype,
-        )
-        return conv_state, ssm_state
-
-    def _get_states_from_cache(
-        self, inference_params, batch_size, initialize_states=False
-    ):
-        assert self.layer_idx is not None
-        if self.layer_idx not in inference_params.key_value_memory_dict:
-            batch_shape = (batch_size,)
-            conv_state = torch.zeros(
-                batch_size,
-                self.d_conv,
-                self.conv1d.weight.shape[0],
-                device=self.conv1d.weight.device,
-                dtype=self.conv1d.weight.dtype,
-            ).transpose(1, 2)
-            ssm_state = torch.zeros(
-                batch_size,
-                self.nheads,
-                self.headdim,
-                self.d_state,
-                device=self.in_proj.weight.device,
-                dtype=self.in_proj.weight.dtype,
-            )
-            inference_params.key_value_memory_dict[self.layer_idx] = (
-                conv_state,
-                ssm_state,
-            )
-        else:
-            conv_state, ssm_state = inference_params.key_value_memory_dict[
-                self.layer_idx
-            ]
-            # TODO: What if batch size changes between generation, and we reuse the same states?
-            if initialize_states:
-                conv_state.zero_()
-                ssm_state.zero_()
-        return conv_state, ssm_state