YongganFu commited on
Commit
3357a8e
·
verified ·
1 Parent(s): bf1039a

Upload FastSLMForCausalLM

Browse files
config.json CHANGED
@@ -1,49 +1,28 @@
1
  {
2
  "architectures": [
3
- "JambaForCausalLM"
4
  ],
5
  "attention_dropout": 0.0,
6
  "attn_hidden_size": -1,
7
- "attn_implementation": "flash_attention_2",
8
- "attn_implementation_new": "flash_attention_2",
9
- "attn_layer_offset": 4,
10
- "attn_layer_period": 8,
11
- "attn_reuse_every_i_layer": -1,
12
  "auto_map": {
13
- "AutoConfig": "configuration_jamba.JambaConfig",
14
- "AutoModelForCausalLM": "modeling_jamba.JambaForCausalLM"
15
  },
16
  "bos_token_id": 1,
17
  "calc_logits_for_entire_prompt": false,
18
- "compact_gating": false,
19
- "compute_attn_mat": false,
20
  "d_conv": 4,
21
- "dense_public_ffn_structure": false,
22
- "double_v_dim": false,
23
- "enable_mod": false,
24
  "eos_token_id": 2,
25
- "expert_layer_offset": 1,
26
- "expert_layer_period": 2,
27
  "ffn_expand_ratio": 3,
28
- "ffn_reuse_every_i_layer": -1,
29
- "ffn_sharing_config": null,
30
- "fully_parallel_jamba": false,
31
- "fused_multihead_config": null,
32
  "global_attn_idx": [],
33
- "gradient_checkpoint_layer": null,
34
- "hash_grid_config": null,
35
- "hash_grid_config_mlp": null,
36
  "hidden_act": "silu",
37
  "hidden_size": 3072,
38
- "hybrid_block_indices": [],
39
  "hybrid_decoder_layer": "mamba",
40
  "initializer_range": 0.02,
41
  "intermediate_size": 0,
42
  "kq_head_dim": -1,
43
  "kq_norm": "none",
44
- "kv_reuse_every_i_layer": -1,
45
- "kv_reuse_group": null,
46
- "kv_weight_reuse": false,
47
  "layer_type": [
48
  "m",
49
  "a",
@@ -120,89 +99,38 @@
120
  "m2",
121
  "f"
122
  ],
123
- "layerwise_memory_token": false,
124
- "local_expand_ratio": 1,
125
- "local_global_dual_branch": false,
126
- "local_global_dual_branch_merge_op": "mean",
127
- "lookback_mode": "",
128
- "macro_arch": "",
129
  "mamba2_headdim": 64,
130
- "mamba_attnaug_config": null,
131
  "mamba_conv_bias": true,
132
  "mamba_d_conv": 4,
133
  "mamba_d_state": 128,
134
  "mamba_dt_rank": 192,
135
  "mamba_expand": 2,
136
  "mamba_inner_layernorms": true,
137
- "mamba_latent_size": null,
138
- "mamba_multihead_config": null,
139
  "mamba_proj_bias": false,
140
- "mamba_reuse_every_i_layer": -1,
141
- "max_position_embeddings": 22528,
142
- "memory_tokens_interspersed_every": 0,
143
  "mlp_hidden_act": "silu",
144
- "mod_topk": 2,
145
  "model_type": "jamba",
146
- "moe_config": null,
147
- "nGPT_config": {
148
- "extra_grad": false,
149
- "gate_scaling": false,
150
- "init_norm": false,
151
- "learned_scaling": false,
152
- "norm_bc": false,
153
- "norm_gating": false,
154
- "norm_ssm_input": false,
155
- "post_norm": false,
156
- "qk_norm": false,
157
- "weight_norm": true
158
- },
159
- "nGPT_mode": null,
160
  "new_seq_length": 2048,
161
- "no_dt_bias": false,
162
  "num_attention_heads": 24,
163
- "num_attn_per_ffn": 3,
164
  "num_experts": 1,
165
  "num_experts_per_tok": 1,
166
- "num_ffn": 1,
167
  "num_hidden_layers": 36,
168
  "num_key_value_heads": 6,
169
- "num_mamba": 1,
170
  "num_memory_tokens": 256,
171
  "orig_max_position_embeddings": 4096,
172
- "other_args": null,
173
  "output_router_logits": false,
174
  "pad_token_id": 0,
175
- "public_ffn_structure": false,
176
- "pure_linear_attn": false,
177
- "reduce_attn_ratio": 0.5,
178
- "reduce_method": "mean",
179
- "repeat_ffn": null,
180
  "rms_norm_eps": 1e-06,
181
  "rope": true,
182
  "rope_theta": 10000.0,
183
  "rope_type": "ntk",
184
  "router_aux_loss_coef": 0.001,
185
- "save_input_output": false,
186
- "self_attn_type": null,
187
- "seq_length": 1024,
188
- "sequential_jamba": false,
189
- "share_kv": false,
190
- "shared_module_attn": "",
191
- "shared_module_mamba": "",
192
  "sliding_window": null,
193
- "sliding_window_size": null,
194
- "supernet_config": null,
195
- "swa_full_head": false,
196
  "tie_word_embeddings": true,
197
  "torch_dtype": "bfloat16",
198
- "transformers_version": "4.45.0",
199
  "use_cache": false,
200
- "use_mamba2": false,
201
  "use_mamba_kernels": true,
202
- "use_nGPT": true,
203
- "use_nemotron5": false,
204
  "v_head_dim": -1,
205
- "visual_attn": false,
206
- "visual_entropy": false,
207
  "vocab_size": 131072
208
  }
 
1
  {
2
  "architectures": [
3
+ "FastSLMForCausalLM"
4
  ],
5
  "attention_dropout": 0.0,
6
  "attn_hidden_size": -1,
7
+ "attn_implementation": "fused_mha",
8
+ "attn_implementation_new": "fused_mha",
 
 
 
9
  "auto_map": {
10
+ "AutoConfig": "configuration_fast_slm.FastSLMConfig",
11
+ "AutoModelForCausalLM": "modeling_fast_slm.FastSLMForCausalLM"
12
  },
13
  "bos_token_id": 1,
14
  "calc_logits_for_entire_prompt": false,
 
 
15
  "d_conv": 4,
 
 
 
16
  "eos_token_id": 2,
 
 
17
  "ffn_expand_ratio": 3,
 
 
 
 
18
  "global_attn_idx": [],
 
 
 
19
  "hidden_act": "silu",
20
  "hidden_size": 3072,
 
21
  "hybrid_decoder_layer": "mamba",
22
  "initializer_range": 0.02,
23
  "intermediate_size": 0,
24
  "kq_head_dim": -1,
25
  "kq_norm": "none",
 
 
 
26
  "layer_type": [
27
  "m",
28
  "a",
 
99
  "m2",
100
  "f"
101
  ],
 
 
 
 
 
 
102
  "mamba2_headdim": 64,
 
103
  "mamba_conv_bias": true,
104
  "mamba_d_conv": 4,
105
  "mamba_d_state": 128,
106
  "mamba_dt_rank": 192,
107
  "mamba_expand": 2,
108
  "mamba_inner_layernorms": true,
 
 
109
  "mamba_proj_bias": false,
110
+ "max_position_embeddings": 29000,
 
 
111
  "mlp_hidden_act": "silu",
 
112
  "model_type": "jamba",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  "new_seq_length": 2048,
 
114
  "num_attention_heads": 24,
 
115
  "num_experts": 1,
116
  "num_experts_per_tok": 1,
 
117
  "num_hidden_layers": 36,
118
  "num_key_value_heads": 6,
 
119
  "num_memory_tokens": 256,
120
  "orig_max_position_embeddings": 4096,
 
121
  "output_router_logits": false,
122
  "pad_token_id": 0,
 
 
 
 
 
123
  "rms_norm_eps": 1e-06,
124
  "rope": true,
125
  "rope_theta": 10000.0,
126
  "rope_type": "ntk",
127
  "router_aux_loss_coef": 0.001,
 
 
 
 
 
 
 
128
  "sliding_window": null,
 
 
 
129
  "tie_word_embeddings": true,
130
  "torch_dtype": "bfloat16",
131
+ "transformers_version": "4.48.2",
132
  "use_cache": false,
 
133
  "use_mamba_kernels": true,
 
 
134
  "v_head_dim": -1,
 
 
135
  "vocab_size": 131072
136
  }
configuration_fast_slm.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 AI21 Labs Ltd. and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Jamba model configuration"""
16
+ import math
17
+
18
+ from transformers.configuration_utils import PretrainedConfig
19
+ from transformers.utils import logging
20
+
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+
25
+ class FastSLMConfig(PretrainedConfig):
26
+ r"""
27
+ This is the configuration class to store the configuration of a [`JambaModel`]. It is used to instantiate a
28
+ Jamba model according to the specified arguments, defining the model architecture. Instantiating a configuration
29
+ with the defaults will yield a similar configuration to that of the jamba-small architecture.
30
+
31
+ [ai21labs/jamba-small](https://huggingface.co/ai21labs/Jamba-v0.1)
32
+
33
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
34
+ documentation from [`PretrainedConfig`] for more information.
35
+
36
+
37
+ Args:
38
+ vocab_size (`int`, *optional*, defaults to 65536):
39
+ Vocabulary size of the Jamba model. Defines the number of different tokens that can be represented by the
40
+ `inputs_ids` passed when calling [`JambaModel`]
41
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
42
+ Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the
43
+ model has a output word embedding layer.
44
+ hidden_size (`int`, *optional*, defaults to 4096):
45
+ Dimension of the hidden representations.
46
+ intermediate_size (`int`, *optional*, defaults to 14336):
47
+ Dimension of the MLP representations.
48
+ num_hidden_layers (`int`, *optional*, defaults to 32):
49
+ Number of hidden layers in the Transformer encoder.
50
+ num_attention_heads (`int`, *optional*, defaults to 32):
51
+ Number of attention heads for each attention layer in the Transformer encoder.
52
+ num_key_value_heads (`int`, *optional*, defaults to 8):
53
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
54
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
55
+ `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
56
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
57
+ by meanpooling all the original heads within that group. For more details checkout [this
58
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
59
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
60
+ The non-linear activation function (function or string) in the decoder.
61
+ initializer_range (`float`, *optional*, defaults to 0.02):
62
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
63
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
64
+ The epsilon used by the rms normalization layers.
65
+ use_cache (`bool`, *optional*, defaults to `True`):
66
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
67
+ relevant if `config.is_decoder=True`.
68
+ calc_logits_for_entire_prompt (`bool`, *optional*, defaults to `False`):
69
+ Whether or not to calculate logits for entire prompt during generation. If `False`, only the logits of the
70
+ last prompt token will be calculated, which are the only logits needed for generation. For long sequences,
71
+ the logits for the entire sequence may use a lot of memory so setting `calc_logits_for_entire_prompt=False`
72
+ will reduce memory footprint significantly.
73
+ Note: some generation features may not be available if this is set to `False`.
74
+ output_router_logits (`bool`, *optional*, defaults to `False`):
75
+ Whether or not the router logits should be returned by the model. Enabling this will also
76
+ allow the model to output the auxiliary loss. See [here]() for more details
77
+ router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
78
+ The aux loss factor for the total loss.
79
+ pad_token_id (`int`, *optional*, defaults to 0):
80
+ The id of the padding token.
81
+ bos_token_id (`int`, *optional*, defaults to 1):
82
+ The id of the "beginning-of-sequence" token.
83
+ eos_token_id (`int`, *optional*, defaults to 2):
84
+ The id of the "end-of-sequence" token.
85
+ sliding_window (`int`, *optional*):
86
+ Sliding window attention window size. If not specified, will default to `None`.
87
+ n_ctx (`int`, *optional*, defaults to 262144):
88
+ This value doesn't have any real effect. The maximum sequence length that this model is intended to be
89
+ used with. It can be used with longer sequences, but performance may degrade.
90
+ attention_dropout (`float`, *optional*, defaults to 0.0):
91
+ The dropout ratio for the attention probabilities.
92
+ num_experts_per_tok (`int`, *optional*, defaults to 2):
93
+ The number of experts to root per-token, can be also interpreted as the `top-p` routing
94
+ parameter
95
+ num_experts (`int`, *optional*, defaults to 16):
96
+ Number of experts per Sparse MLP layer.
97
+ use_mamba_kernels (`bool`, *optional*, defaults to `True`):
98
+ Flag indicating whether or not to use the fast mamba kernels. These are available only if `mamba-ssm` and
99
+ `causal-conv1d` are installed, and the mamba modules are running on a CUDA device. Raises ValueError if
100
+ `True` and kernels are not available
101
+ mamba_d_state (`int`, *optional*, defaults to 16):
102
+ The dimension the mamba state space latents
103
+ mamba_d_conv (`int`, *optional*, defaults to 4):
104
+ The size of the mamba convolution kernel
105
+ mamba_expand (`int`, *optional*, defaults to 2):
106
+ Expanding factor (relative to hidden_size) used to determine the mamba intermediate size
107
+ mamba_dt_rank (`Union[int,str]`, *optional*, defaults to `"auto"`):
108
+ Rank of the the mamba discretization projection matrix. `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)`
109
+ mamba_conv_bias (`bool`, *optional*, defaults to `True`):
110
+ Flag indicating whether or not to use bias in the convolution layer of the mamba mixer block.
111
+ mamba_proj_bias (`bool`, *optional*, defaults to `False`):
112
+ Flag indicating whether or not to use bias in the input and output projections (["in_proj", "out_proj"]) of the mamba mixer block
113
+ mamba_inner_layernorms (`bool`, *optional*, defaults to `True`):
114
+ Flag indicating whether or not to apply layernorms to internal mamba activations
115
+
116
+ """
117
+
118
+ model_type = "jamba"
119
+ keys_to_ignore_at_inference = ["past_key_values"]
120
+
121
+ def __init__(
122
+ self,
123
+ vocab_size=65536,
124
+ tie_word_embeddings=False,
125
+ hidden_size=4096,
126
+ intermediate_size=14336,
127
+ num_hidden_layers=32,
128
+ num_attention_heads=32,
129
+ num_key_value_heads=8,
130
+ hidden_act="silu",
131
+ initializer_range=0.02,
132
+ rms_norm_eps=1e-6,
133
+ use_cache=True,
134
+ calc_logits_for_entire_prompt=False,
135
+ output_router_logits=False,
136
+ router_aux_loss_coef=0.001,
137
+ pad_token_id=0,
138
+ bos_token_id=1,
139
+ eos_token_id=2,
140
+ sliding_window=None,
141
+ max_position_embeddings=262144,
142
+ orig_max_position_embeddings=None,
143
+ attention_dropout=0.0,
144
+ num_experts_per_tok=2,
145
+ num_experts=16,
146
+ use_mamba_kernels=True,
147
+ mamba_d_state=16,
148
+ mamba_d_conv=4,
149
+ mamba_expand=2,
150
+ mamba_dt_rank="auto",
151
+ mamba_conv_bias=True,
152
+ mamba_proj_bias=False,
153
+ mamba_inner_layernorms=True,
154
+
155
+ hybrid_decoder_layer='mamba',
156
+
157
+ global_attn_idx=None,
158
+
159
+ attn_implementation_new='flash_attention_2',
160
+
161
+ mamba2_headdim=64,
162
+
163
+ rope_type=None,
164
+
165
+ layer_types=None,
166
+
167
+ ffn_expand_ratio=None,
168
+
169
+ d_conv=4,
170
+
171
+ **kwargs,
172
+ ):
173
+ self.vocab_size = vocab_size
174
+ self.tie_word_embeddings = tie_word_embeddings
175
+ self.hidden_size = hidden_size
176
+ self.intermediate_size = intermediate_size
177
+ self.num_hidden_layers = num_hidden_layers
178
+ self.num_attention_heads = num_attention_heads
179
+ self.sliding_window = sliding_window
180
+ self.max_position_embeddings = max_position_embeddings
181
+ self.orig_max_position_embeddings = orig_max_position_embeddings
182
+ self.attention_dropout = attention_dropout
183
+
184
+ # for backward compatibility
185
+ if num_key_value_heads is None:
186
+ num_key_value_heads = num_attention_heads
187
+
188
+ self.num_key_value_heads = num_key_value_heads
189
+ self.hidden_act = hidden_act
190
+ self.initializer_range = initializer_range
191
+ self.rms_norm_eps = rms_norm_eps
192
+
193
+ self.use_cache = use_cache
194
+ self.calc_logits_for_entire_prompt = calc_logits_for_entire_prompt
195
+ self.output_router_logits = output_router_logits
196
+ self.router_aux_loss_coef = router_aux_loss_coef
197
+
198
+ self.num_experts_per_tok = num_experts_per_tok
199
+ self.num_experts = num_experts
200
+
201
+ self.use_mamba_kernels = use_mamba_kernels
202
+ self.mamba_d_state = mamba_d_state
203
+ self.mamba_d_conv = mamba_d_conv
204
+ self.mamba_expand = mamba_expand
205
+ self.mamba_dt_rank = math.ceil(self.hidden_size / 16) if mamba_dt_rank == "auto" else mamba_dt_rank
206
+ self.mamba_conv_bias = mamba_conv_bias
207
+ self.mamba_proj_bias = mamba_proj_bias
208
+ self.mamba_inner_layernorms = mamba_inner_layernorms
209
+
210
+ # added by Xin
211
+ self.kq_norm = kwargs.pop("kq_norm", None)
212
+ self.rope = kwargs.pop("rope", False)
213
+ self.rope_theta = kwargs.pop("rope_theta", 10000.0)
214
+ self.num_memory_tokens = kwargs.pop("num_memory_tokens", 0)
215
+ self.attn_hidden_size = kwargs.pop("attn_hidden_size", -1)
216
+ self.kq_head_dim = kwargs.pop("kq_head_dim", -1)
217
+ self.v_head_dim = kwargs.pop("v_head_dim", -1)
218
+
219
+ #! adhoc change
220
+ self.new_seq_length = 2048
221
+
222
+ self.hybrid_decoder_layer = hybrid_decoder_layer
223
+
224
+ self.global_attn_idx = global_attn_idx
225
+
226
+ self.attn_implementation_new = attn_implementation_new
227
+
228
+ self.mamba2_headdim = mamba2_headdim
229
+
230
+ self.rope_type = rope_type
231
+
232
+ self.layer_types = layer_types
233
+
234
+ self.ffn_expand_ratio = ffn_expand_ratio
235
+
236
+ self.d_conv = d_conv
237
+
238
+ self.mlp_hidden_act = kwargs.pop("mlp_hidden_act", "silu")
239
+
240
+ super().__init__(
241
+ pad_token_id=pad_token_id,
242
+ bos_token_id=bos_token_id,
243
+ eos_token_id=eos_token_id,
244
+ tie_word_embeddings=tie_word_embeddings,
245
+ **kwargs,
246
+ )
delta_net.py CHANGED
@@ -10,9 +10,15 @@ import torch.nn as nn
10
  from einops import rearrange
11
  from torch.nn import functional as F
12
 
 
13
  from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution
14
  from fla.ops.delta_rule import chunk_delta_rule, fused_recurrent_delta_rule
15
 
 
 
 
 
 
16
  if TYPE_CHECKING:
17
  from transformers.processing_utils import Unpack
18
 
@@ -97,12 +103,6 @@ class DeltaNet(nn.Module):
97
 
98
  assert self.qk_activation in ['silu', 'relu', 'elu', 'identity']
99
  assert self.qk_norm in ['l2', 'sum']
100
-
101
- self.config = config
102
- if self.config is not None and self.config.use_nGPT and 'extra_grad' in self.config.nGPT_config and self.config.nGPT_config['extra_grad']:
103
- self.weight_norm = True
104
- else:
105
- self.weight_norm = False
106
 
107
  if d_model is not None:
108
  hidden_size = d_model
@@ -199,7 +199,7 @@ class DeltaNet(nn.Module):
199
  last_state = None
200
  if past_key_values is not None and len(past_key_values) > self.layer_idx:
201
  last_state = past_key_values[self.layer_idx]
202
-
203
  if self.use_short_conv:
204
  conv_state_q, conv_state_k, conv_state_v = None, None, None
205
  if last_state is not None:
@@ -208,9 +208,7 @@ class DeltaNet(nn.Module):
208
  position_ids = kwargs.get('position_ids', None)
209
 
210
  q = self.q_proj(hidden_states)
211
- if self.weight_norm:
212
- q = q / self.q_proj.weight.norm(p=2, dim=1)
213
-
214
  q, conv_state_q = self.q_conv1d(x=q,
215
  mask=conv_mask,
216
  cache=conv_state_q,
@@ -218,8 +216,7 @@ class DeltaNet(nn.Module):
218
  seq_idx=position_ids)
219
 
220
  k = self.k_proj(hidden_states)
221
- if self.weight_norm:
222
- k = k / self.k_proj.weight.norm(p=2, dim=1)
223
  k, conv_state_k = self.k_conv1d(x=k,
224
  mask=conv_mask,
225
  cache=conv_state_k,
@@ -227,8 +224,7 @@ class DeltaNet(nn.Module):
227
  seq_idx=position_ids)
228
 
229
  v = self.v_proj(hidden_states)
230
- if self.weight_norm:
231
- v = v / self.v_proj.weight.norm(p=2, dim=1)
232
  v, conv_state_v = self.v_conv1d(x=v,
233
  mask=conv_mask,
234
  cache=conv_state_v,
@@ -239,11 +235,6 @@ class DeltaNet(nn.Module):
239
  k = self.k_proj(hidden_states)
240
  v = self.v_proj(hidden_states)
241
 
242
- if self.weight_norm:
243
- q = q / self.q_proj.weight.norm(p=2, dim=1)
244
- k = k / self.k_proj.weight.norm(p=2, dim=1)
245
- v = v / self.v_proj.weight.norm(p=2, dim=1)
246
-
247
  if self.qk_activation == 'silu':
248
  q, k = self.silu(q), self.silu(k)
249
 
@@ -267,10 +258,6 @@ class DeltaNet(nn.Module):
267
 
268
  if self.use_beta:
269
  beta = self.b_proj(hidden_states)
270
-
271
- if self.weight_norm:
272
- beta = beta / self.b_proj.weight.norm(p=2, dim=1)
273
-
274
  beta = beta.sigmoid()
275
  else:
276
  beta = q.new_ones(q.shape[0], q.shape[1], q.shape[2])
@@ -283,6 +270,7 @@ class DeltaNet(nn.Module):
283
  beta = beta.mul(attention_mask[:, -beta.shape[-2]:, None])
284
 
285
  recurrent_state = last_state['recurrent_state'] if last_state is not None else None
 
286
  cu_seqlens = kwargs.get('cu_seqlens', None)
287
  if mode == 'fused_recurrent':
288
  o, recurrent_state = fused_recurrent_delta_rule(
@@ -327,7 +315,161 @@ class DeltaNet(nn.Module):
327
  o = rearrange(o, 'b t h d -> b t (h d)')
328
  o = self.o_proj(o)
329
 
330
- if self.weight_norm:
331
- o = o / self.o_proj.weight.norm(p=2, dim=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
332
 
333
- return o, None, past_key_values
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  from einops import rearrange
11
  from torch.nn import functional as F
12
 
13
+ import fla
14
  from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution
15
  from fla.ops.delta_rule import chunk_delta_rule, fused_recurrent_delta_rule
16
 
17
+ from typing import Any, Dict, List, Optional, Tuple
18
+
19
+ import torch
20
+ import transformers
21
+
22
  if TYPE_CHECKING:
23
  from transformers.processing_utils import Unpack
24
 
 
103
 
104
  assert self.qk_activation in ['silu', 'relu', 'elu', 'identity']
105
  assert self.qk_norm in ['l2', 'sum']
 
 
 
 
 
 
106
 
107
  if d_model is not None:
108
  hidden_size = d_model
 
199
  last_state = None
200
  if past_key_values is not None and len(past_key_values) > self.layer_idx:
201
  last_state = past_key_values[self.layer_idx]
202
+
203
  if self.use_short_conv:
204
  conv_state_q, conv_state_k, conv_state_v = None, None, None
205
  if last_state is not None:
 
208
  position_ids = kwargs.get('position_ids', None)
209
 
210
  q = self.q_proj(hidden_states)
211
+
 
 
212
  q, conv_state_q = self.q_conv1d(x=q,
213
  mask=conv_mask,
214
  cache=conv_state_q,
 
216
  seq_idx=position_ids)
217
 
218
  k = self.k_proj(hidden_states)
219
+
 
220
  k, conv_state_k = self.k_conv1d(x=k,
221
  mask=conv_mask,
222
  cache=conv_state_k,
 
224
  seq_idx=position_ids)
225
 
226
  v = self.v_proj(hidden_states)
227
+
 
228
  v, conv_state_v = self.v_conv1d(x=v,
229
  mask=conv_mask,
230
  cache=conv_state_v,
 
235
  k = self.k_proj(hidden_states)
236
  v = self.v_proj(hidden_states)
237
 
 
 
 
 
 
238
  if self.qk_activation == 'silu':
239
  q, k = self.silu(q), self.silu(k)
240
 
 
258
 
259
  if self.use_beta:
260
  beta = self.b_proj(hidden_states)
 
 
 
 
261
  beta = beta.sigmoid()
262
  else:
263
  beta = q.new_ones(q.shape[0], q.shape[1], q.shape[2])
 
270
  beta = beta.mul(attention_mask[:, -beta.shape[-2]:, None])
271
 
272
  recurrent_state = last_state['recurrent_state'] if last_state is not None else None
273
+
274
  cu_seqlens = kwargs.get('cu_seqlens', None)
275
  if mode == 'fused_recurrent':
276
  o, recurrent_state = fused_recurrent_delta_rule(
 
315
  o = rearrange(o, 'b t h d -> b t (h d)')
316
  o = self.o_proj(o)
317
 
318
+ return o, None, past_key_values
319
+
320
+
321
+ class Cache(transformers.cache_utils.Cache):
322
+ """
323
+ A cache used for storing hidden states produced by flash linear attention models.
324
+
325
+ It stores the states of each layer as the tensor of shape `[batch_size, key_dim, value_dim]`.
326
+ """
327
+
328
+ is_compileable = True
329
+
330
+ def __init__(
331
+ self,
332
+ seen_tokens: int = 0
333
+ ) -> Cache:
334
+ super().__init__()
335
+
336
+ self.states: List[Dict[str, Any]] = []
337
+
338
+ self._seen_tokens = seen_tokens # Used in `generate` to keep tally of how many tokens the cache has seen
339
+
340
+ def __getitem__(self, layer_idx: int) -> Dict[str, Any]:
341
+ if layer_idx < len(self):
342
+ return self.states[layer_idx]
343
+ else:
344
+ raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
345
+
346
+ def __iter__(self):
347
+ for state in self.states:
348
+ yield state
349
+
350
+ def __len__(self):
351
+ return len(self.states)
352
+
353
+ def reset(self):
354
+ for state in self.states:
355
+ for key in state:
356
+ if state[key] is not None:
357
+ if type(state[key]) == tuple:
358
+ for subkey in state[key]:
359
+ subkey.zero_()
360
+ else:
361
+ state[key].zero_()
362
+ self._seen_tokens = 0
363
+
364
 
365
+ def update(
366
+ self,
367
+ recurrent_state: Optional[Tuple[torch.Tensor]] = None,
368
+ attn_state: Optional[Tuple[torch.Tensor]] = None,
369
+ conv_state: Optional[Tuple[torch.Tensor]] = None,
370
+ ffn_state: Optional[Tuple[torch.Tensor]] = None,
371
+ layer_idx: int = 0,
372
+ offset: Optional[int] = 1,
373
+ cache_kwargs: Optional[Dict[str, Any]] = None,
374
+ ) -> Dict[str, Any]:
375
+ """
376
+ Args:
377
+ recurrent_state (`torch.Tensor`):
378
+ The new recurrent state to cache.
379
+ attn_state (`Tuple[torch.Tensor]`):
380
+ The new attention key/value states to cache.
381
+ conv_state (`Tuple[torch.Tensor]`):
382
+ The new convolution state to cache.
383
+ ffn_state (`Tuple[torch.Tensor]`):
384
+ The new feed-forward state to cache.
385
+ layer_idx (`int`, defaults to 0):
386
+ The index of the layer to cache the states for.
387
+ offset (`int`, defaults to 1):
388
+ The number of new tokens being processed.
389
+ cache_kwargs (`Dict[str, Any]`):
390
+ Additional arguments for the cache subclass.
391
+
392
+ Return:
393
+ Dictionary of the updated state.
394
+ """
395
+
396
+ if cache_kwargs is None:
397
+ cache_kwargs = {}
398
+ if attn_state is not None:
399
+ input_size = attn_state[0].shape[1]
400
+ window_size = cache_kwargs.get('window_size', None)
401
+ if not (isinstance(attn_state, Tuple) or isinstance(attn_state, List)):
402
+ raise ValueError("`attn_state` must be a tuple of tensors for key/value states")
403
+ if len(self.states) <= layer_idx:
404
+ # update the number of seen tokens
405
+ if layer_idx == 0:
406
+ self._seen_tokens += offset
407
+ if attn_state is not None:
408
+ if window_size is not None and input_size > window_size:
409
+ attn_state = [state[:, -window_size:].contiguous() for state in attn_state]
410
+ state = dict(
411
+ recurrent_state=recurrent_state,
412
+ attn_state=attn_state,
413
+ conv_state=conv_state,
414
+ ffn_state=ffn_state
415
+ )
416
+ self.states.append(state)
417
+ else:
418
+ # update the number of seen tokens
419
+ if layer_idx == len(self.states) - 1:
420
+ self._seen_tokens += offset
421
+ state = self.states[layer_idx]
422
+ if recurrent_state is not None:
423
+ state['recurrent_state'].copy_(recurrent_state)
424
+ if attn_state is not None:
425
+ if window_size is not None and state['attn_state'][0].shape[1] == window_size:
426
+ for i, (old_state, new_state) in enumerate(zip(state['attn_state'], attn_state)):
427
+ # DO NOT allocate new memory if the cache is full
428
+ # roll the key/value states to the left by `input_size`
429
+ old_state = old_state.roll(-input_size, 1)
430
+ # replace the last `input_size` tokens with the new key/value states
431
+ old_state[:, -input_size:] = new_state
432
+ state['attn_state'][i].copy_(old_state)
433
+ else:
434
+ attn_state = [
435
+ torch.cat([old_state, new_state], 1)
436
+ for old_state, new_state in zip(state['attn_state'], attn_state)
437
+ ]
438
+ state['attn_state'].copy_(attn_state)
439
+ if conv_state is not None:
440
+ conv_state_q, conv_state_k, conv_state_v = state['conv_state']
441
+ conv_state_q.copy_(conv_state[0])
442
+ conv_state_k.copy_(conv_state[1])
443
+ conv_state_v.copy_(conv_state[2])
444
+ if ffn_state is not None:
445
+ state['ffn_state'].copy_(ffn_state)
446
+
447
+ return state
448
+
449
+ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
450
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
451
+ if len(self.states) <= layer_idx:
452
+ return 0
453
+ return self._seen_tokens
454
+
455
+ def get_max_length(self) -> Optional[int]:
456
+ """Returns the maximum sequence length of the cached states. Cache does not have a maximum length."""
457
+ return None
458
+
459
+ def to_legacy_cache(self) -> Tuple:
460
+ return tuple(self.states)
461
+
462
+ @classmethod
463
+ @torch.compiler.disable
464
+ def from_legacy_cache(
465
+ cls,
466
+ past_key_values: Optional[Tuple] = None,
467
+ seen_tokens: int = 0
468
+ ) -> Cache:
469
+ """Converts a cache in the legacy cache format into an equivalent `Cache`."""
470
+
471
+ cache = cls(seen_tokens)
472
+ if isinstance(past_key_values, list):
473
+ for layer_idx in range(len(past_key_values)):
474
+ cache.states.append(past_key_values[layer_idx])
475
+ return cache
fused_mha_with_cache.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import Optional, Tuple
3
+
4
+ from .triton_attention import (
5
+ fused_mha_with_paged_cache, fused_mha_with_cache
6
+ )
7
+
8
+ dtype_int = torch.int32
9
+
10
+ def fused_mha_interface(
11
+ query_states: torch.Tensor, # [batch, q_len, heads, head_dim]
12
+ key_states: torch.Tensor, # [batch, kv_len, heads, head_dim]
13
+ value_states: torch.Tensor, # [batch, kv_len, heads, head_dim]
14
+ k_cache: torch.Tensor, # [MAX_BATCH_SIZE, MAX_SEQ_LEN, N_HEADS, D_HEAD] or [num_pages, page_size, n, d] for paged attn
15
+ v_cache: torch.Tensor, # [MAX_BATCH_SIZE, MAX_SEQ_LEN, N_HEADS, D_HEAD]
16
+ position_ids: torch.Tensor=None,
17
+ page_table: torch.Tensor=None, # [b, max_num_pages_per_seq] # loc of the block page in the cache.
18
+ max_seq_len = None,
19
+ ) -> torch.Tensor:
20
+ """
21
+ Replacement for _flash_attention_forward(...) that uses
22
+ Triton’s fused_mha_with_paged_cache under the hood.
23
+ Returns: [batch, q_len, heads*head_dim]
24
+ """
25
+ # unpack shapes
26
+ b, ql, n_heads, head_dim = query_states.shape
27
+ _, kvl, n_kv_heads, _ = key_states.shape
28
+
29
+ q = query_states.reshape(b, ql, n_heads * head_dim)
30
+ k = key_states.reshape(b, kvl, n_kv_heads * head_dim)
31
+ v = value_states.reshape(b, kvl, n_kv_heads * head_dim)
32
+
33
+ if position_ids is not None:
34
+ if ql == 1: # Generate phase - single token
35
+ input_pos = position_ids[:, -1] # Use the last position for each sequence
36
+ else: # Context phase - multiple tokens
37
+ input_pos = position_ids[:, 0] # Use the starting position for each sequence
38
+ else:
39
+ # Fallback: assume starting from 0 for all sequences
40
+ input_pos = torch.zeros(b, device=q.device, dtype=torch.int32)
41
+
42
+ freqs_cis = None
43
+
44
+ if page_table is None:
45
+ y = torch.ops.attention.fused_mha_with_cache(
46
+ q, k, v,
47
+ input_pos,
48
+ k_cache, v_cache,
49
+ freqs_cis,
50
+ )
51
+
52
+
53
+ else:
54
+ batch_size = b
55
+
56
+ # cache_loc: identity mapping [0, 1, ..., b-1]
57
+ cache_loc = torch.arange(batch_size, device=q.device, dtype=dtype_int)
58
+
59
+ # input_positions: assume pure context (all start from 0)
60
+ input_positions = torch.zeros(batch_size, device=q.device, dtype=dtype_int)
61
+
62
+ # seq_len: each sequence length is kvl
63
+ seq_len = torch.full((batch_size,), kvl, device=q.device, dtype=dtype_int)
64
+
65
+ # seq_start: flattened starting index for each sequence
66
+ seq_start = (seq_len.cumsum(0) - seq_len).to(dtype=dtype_int)
67
+
68
+ assert max_seq_len is not None, "max_seq_len must be provided when using paged attention."
69
+
70
+ y = torch.ops.attention.fused_mha_with_paged_cache(
71
+ q, k, v,
72
+ input_positions, cache_loc,
73
+ seq_len, seq_start,
74
+ page_table, max_seq_len,
75
+ k_cache, v_cache,
76
+ freqs_cis,
77
+ )
78
+
79
+ y = y.view(b, ql, n_heads, head_dim)
80
+
81
+ return y
82
+
83
+
84
+
85
+ def main():
86
+ #––– Test hyperparameters –––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––
87
+ batch_size = 1
88
+ q_len = 1
89
+ kv_len = 1
90
+ num_heads = 16
91
+ n_kv_heads = 16
92
+ head_dim = 128
93
+
94
+ max_batch_size = 1
95
+ max_seq_len = 1024
96
+
97
+ page_size = 256
98
+
99
+ device = "cuda"
100
+
101
+ #––– Random query, key, value tensors –––––––––––––––––––––––––––––––––––––––––––––––––––
102
+ query_states = torch.randn(batch_size, q_len, num_heads, head_dim, device=device)
103
+ key_states = torch.randn(batch_size, kv_len, num_heads, head_dim, device=device)
104
+ value_states = torch.randn(batch_size, kv_len, num_heads, head_dim, device=device)
105
+
106
+ k_cache = torch.randn(max_batch_size, max_seq_len, num_heads, head_dim, device=device)
107
+ v_cache = torch.randn(max_batch_size, max_seq_len, num_heads, head_dim, device=device)
108
+
109
+ attn_out = fused_mha_interface(
110
+ query_states,
111
+ key_states,
112
+ value_states,
113
+ k_cache=k_cache,
114
+ v_cache=v_cache,
115
+ )
116
+
117
+ expected_shape = (batch_size, q_len, num_heads, head_dim)
118
+ print(f"[test] output shape: {attn_out.shape} (expected {expected_shape})")
119
+
120
+ if attn_out.shape == expected_shape:
121
+ print("[test] ✅ Success: output tensor has correct shape.")
122
+ else:
123
+ print("[test] ❌ Failure: shape mismatch.")
124
+
125
+ if __name__ == "__main__":
126
+ main()
mamba2.py CHANGED
@@ -18,9 +18,6 @@ try:
18
  except ImportError:
19
  causal_conv1d_varlen_states = None
20
 
21
- import sys
22
- # sys.path.insert(0, '/lustre/fsw/portfolios/nvr/users/yongganf/TLM/')
23
-
24
  from mamba_ssm.ops.triton.selective_state_update import selective_state_update
25
  from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated
26
 
@@ -124,13 +121,10 @@ class Mamba2(nn.Module):
124
  # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
125
  inv_dt = dt + torch.log(-torch.expm1(-dt))
126
 
127
- if config.no_dt_bias:
128
- self.dt_bias = None
129
- else:
130
- self.dt_bias = nn.Parameter(inv_dt)
131
- # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
132
- # name.endswith("bias") in param_grouping.py
133
- self.dt_bias._no_weight_decay = True
134
 
135
  assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
136
  A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(*A_init_range)
@@ -154,39 +148,6 @@ class Mamba2(nn.Module):
154
  process_group=self.process_group, sequence_parallel=self.sequence_parallel,
155
  **factory_kwargs)
156
 
157
- self.mamba_multihead_config = config.mamba_multihead_config
158
- if self.mamba_multihead_config is not None:
159
- assert self.mamba_multihead_config['alpha_mode'] == 'sparsity' or self.mamba_multihead_config['alpha_mode'] == 'cummax'
160
-
161
- if self.mamba_multihead_config['alpha_mode'] == 'cummax':
162
- self.learned_dt_scale = nn.Parameter(torch.ones(1, device=device))
163
-
164
- if self.mamba_multihead_config['alpha_mode'] == 'sparsity':
165
- if 'use_learned_thres' in self.mamba_multihead_config and self.mamba_multihead_config['use_learned_thres']:
166
- self.learned_thres = nn.Parameter(torch.zeros(self.nheads, device=device))
167
- self.smooth_factor = self.mamba_multihead_config['smooth_factor']
168
- self.detach_dt = self.mamba_multihead_config['detach_dt']
169
-
170
- if 'use_cummax' in self.mamba_multihead_config and self.mamba_multihead_config['use_cummax']:
171
- self.use_cummax = True
172
- self.cummax_lower_bound = self.mamba_multihead_config['cummax_lower_bound']
173
- else:
174
- self.use_cummax = False
175
-
176
- else:
177
- self.learned_thres = None
178
- self.smooth_factor = None
179
- self.detach_dt = None
180
-
181
- self.sparsity_split = self.mamba_multihead_config['sparsity_split']
182
- self.sparsity_ratio = self.mamba_multihead_config['sparsity_ratio']
183
-
184
- if self.config.layerwise_memory_token:
185
- assert self.config.num_memory_tokens > 0
186
- self.memory_tokens = nn.Parameter(torch.randn(self.config.num_memory_tokens, self.config.hidden_size))
187
- else:
188
- self.memory_tokens = None
189
-
190
 
191
  def forward(self, hidden_states, attention_mask=None, past_key_value=None, seqlen=None, seq_idx=None, cu_seqlens=None, inference_params=None):
192
  """
@@ -198,11 +159,6 @@ class Mamba2(nn.Module):
198
  """
199
  # assert past_key_value is None, "Not implemented yet!!!"
200
 
201
- if self.memory_tokens is not None:
202
- hidden_states = hidden_states[:,self.config.num_memory_tokens:,...]
203
- mem = repeat(self.memory_tokens, 'n d -> b n d', b = hidden_states.shape[0]) # prepend the memory to every segment of m by repeating the memory tokens
204
- hidden_states, mem_packed_shape = pack((mem, hidden_states), 'b * d')
205
-
206
  seqlen_og = seqlen
207
  if seqlen is None:
208
  batch, seqlen, dim = hidden_states.shape
@@ -211,19 +167,18 @@ class Mamba2(nn.Module):
211
  batch = batch_seqlen // seqlen
212
 
213
  conv_state, ssm_state = None, None
 
214
  if inference_params is not None:
215
  inference_batch = cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch
216
  conv_state, ssm_state = self._get_states_from_cache(inference_params, inference_batch)
 
217
  if inference_params.seqlen_offset > 0:
218
  # The states are updated inplace
219
  out, _, _ = self.step(hidden_states, conv_state, ssm_state)
220
- return out
221
 
222
  zxbcdt = self.in_proj(hidden_states) # (B, L, d_in_proj) or (B * L, d_in_proj)
223
 
224
- if self.config.use_nGPT and 'extra_grad' in self.config.nGPT_config and self.config.nGPT_config['extra_grad']:
225
- zxbcdt = zxbcdt / self.in_proj.weight.norm(p=2, dim=1)
226
-
227
  if seqlen_og is not None:
228
  zxbcdt = rearrange(zxbcdt, "(b l) d -> b l d", l=seqlen)
229
  # If the model is loaded in fp16, without the .float() here, A might be -inf
@@ -261,6 +216,7 @@ class Mamba2(nn.Module):
261
  [d_mlp, d_mlp, self.d_ssm, self.d_ssm + 2 * self.ngroups * self.d_state, self.nheads],
262
  dim=-1
263
  )
 
264
  if conv_state is not None:
265
  if cu_seqlens is None:
266
  # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
@@ -288,27 +244,9 @@ class Mamba2(nn.Module):
288
  activation=self.activation,
289
  # seq_idx=seq_idx,
290
  ).transpose(1, 2)
 
291
  x, B, C = torch.split(xBC, [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)
292
 
293
- no_dt_bias = False
294
- if self.mamba_multihead_config is not None and self.mamba_multihead_config['alpha_mode'] == 'cummax': ### todo: implement this in the fused kernel
295
- dt = dt + self.dt_bias
296
- dt = torch.nn.functional.softmax(dt, dim=-1)
297
- dt = torch.cumsum(dt, dim=-1)
298
- dt = dt * self.learned_dt_scale
299
-
300
- no_dt_bias = True
301
-
302
- if self.mamba_multihead_config is not None and self.mamba_multihead_config['alpha_mode'] == 'sparsity':
303
- dt = dt + self.dt_bias
304
-
305
- if self.learned_thres is not None:
306
- dt = self.sparsify_learned_thres(dt)
307
- else:
308
- dt = self.split_and_sparsify(dt, self.sparsity_split, self.sparsity_ratio)
309
-
310
- no_dt_bias = True
311
-
312
 
313
  y = mamba_chunk_scan_combined(
314
  rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
@@ -317,9 +255,10 @@ class Mamba2(nn.Module):
317
  rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
318
  rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
319
  chunk_size=self.chunk_size,
320
- D=rearrange(self.D, "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.D,
 
321
  z=rearrange(z, "b l (h p) -> b l h p", p=self.headdim) if not self.rmsnorm else None,
322
- dt_bias=self.dt_bias if not no_dt_bias else None,
323
  dt_softplus=True,
324
  seq_idx=seq_idx,
325
  cu_seqlens=cu_seqlens,
@@ -336,186 +275,153 @@ class Mamba2(nn.Module):
336
  ssm_state.copy_(varlen_states)
337
  y = rearrange(y, "b l h p -> b l (h p)")
338
  if self.rmsnorm:
339
- y = self.norm(y, z)
 
 
 
340
  if d_mlp > 0:
341
  y = torch.cat([F.silu(z0) * x0, y], dim=-1)
342
  if seqlen_og is not None:
343
  y = rearrange(y, "b l d -> (b l) d")
344
 
345
- if self.config.use_nGPT and 'extra_grad' in self.config.nGPT_config and self.config.nGPT_config['extra_grad']:
346
- y = y / self.out_proj.weight.norm(p=2, dim=0)
347
-
348
  out = self.out_proj(y)
349
-
350
  return out, past_key_value
351
 
352
 
353
- def sparsify_learned_thres(self, dt):
354
- """
355
- Args:
356
- dt: Tensor of shape [bs, seq_len, nheads]
357
- Returns:
358
- pruned_dt: Pruned tensor with the same shape as dt
359
- """
360
- # Compute sigmoid scores
361
 
362
- if self.use_cummax:
363
- learned_thres = torch.nn.functional.softmax(self.learned_thres, dim=-1)
364
- learned_thres = torch.cumsum(learned_thres, dim=-1) - self.cummax_lower_bound ## keep the dt_normalized larger than 1 - self.cummax_lower_bound
365
-
366
- dt_normalized = (dt - dt.min(dim=-1, keepdim=True)[0]) / (dt.max(dim=-1, keepdim=True)[0] - dt.min(dim=-1, keepdim=True)[0])
367
-
368
- scores = torch.sigmoid((dt_normalized.detach() - self.learned_thres) / self.smooth_factor)
369
-
370
  else:
371
- if self.detach_dt:
372
- scores = torch.sigmoid((dt.detach() - self.learned_thres) / self.smooth_factor)
373
- else:
374
- scores = torch.sigmoid((dt - self.learned_thres) / self.smooth_factor)
375
-
376
- # Generate binary mask for pruning (forward pass)
377
- mask = (scores >= 0.5).float()
378
-
379
- # Apply mask in the forward pass and backward using sigmoid
380
- pruned_dt = (dt * mask - dt * scores).detach() + dt * scores
381
-
382
- # print(pruned_dt.mean())
383
-
384
- return pruned_dt
385
-
386
-
387
- def split_and_sparsify(self, dt, sparsity_split, sparsity_ratio):
388
- """
389
- dt: a torch.Tensor of shape [bs, seq_len, dim]
390
- sparsity_split: list of ratios (e.g., [0.4, 0.3, 0.3]) that sum to 1
391
- and define how to split dt along the last dimension
392
- sparsity_ratio: list of ratios (e.g., [0.2, 0.5, 0.3]) that sum to 1
393
- and define how many time steps (along seq_len) to keep
394
- """
395
- bs, seq_len, dim = dt.shape
396
-
397
- assert sum(sparsity_split) == 1
398
-
399
- # Compute the exact split sizes (watching out for integer rounding)
400
- split_sizes = [int(r * dim) for r in sparsity_split]
401
- # Fix potential off-by-one rounding in the last split
402
- split_sizes[-1] = dim - sum(split_sizes[:-1])
403
-
404
- # Split the original tensor along the last dimension
405
- splitted_tensors = torch.split(dt, split_sizes, dim=-1)
406
-
407
- results = []
408
- for i, sub_tensor in enumerate(splitted_tensors):
409
- # sub_tensor has shape [bs, seq_len, split_dim_i]
410
- k = int(sparsity_ratio[i] * seq_len)
411
-
412
- ### Strategy 1: keep at least one token
413
- k = max(k, 1)
414
-
415
- ### Strategy 2: the #tokens is the same as training
416
- # if self.config.orig_max_position_embeddings is not None:
417
- # k = int(self.config.orig_max_position_embeddings * self.sparsity_ratio[i])
418
- # else:
419
- # assert self.config.max_position_embeddings is not None
420
- # k = int(self.config.max_position_embeddings * self.sparsity_ratio[i])
421
-
422
- # k = min(seq_len, k)
423
-
424
- # print(self.config.max_position_embeddings, sparsity_ratio[i], seq_len, k)
425
-
426
- # 1) Average over the feature dimension (the last dim),
427
- # resulting in shape [bs, seq_len]
428
- averaged_values = sub_tensor.mean(dim=-1)
429
-
430
- # 2) Get top-k indices (along seq_len = dim=1)
431
- topk_values, _ = torch.topk(averaged_values, k=k, dim=1)
432
- # The smallest value among the top-k per batch element
433
- threshold = topk_values[:, -1].unsqueeze(-1) # shape [bs, 1]
434
-
435
- # 3) Create a mask of shape [bs, seq_len] => True if >= threshold
436
- averaged_mask = (averaged_values >= threshold)
437
-
438
- # 4) Expand that mask back to [bs, seq_len, split_dim_i]
439
- mask_3d = averaged_mask.unsqueeze(-1).expand_as(sub_tensor)
440
-
441
- # 5) Zero out everything that is not in top-k
442
- sparsified_sub = sub_tensor * mask_3d
443
 
444
- # print((sparsified_sub == 0).float().mean().item())
445
- # input()
446
-
447
- results.append(sparsified_sub)
448
-
449
- # Concatenate the results back along the last dimension
450
- output = torch.cat(results, dim=-1)
451
- return output
452
-
453
- def step(self, hidden_states, conv_state, ssm_state):
454
- dtype = hidden_states.dtype
455
- assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
456
- zxbcdt = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
457
  d_mlp = (zxbcdt.shape[-1] - 2 * self.d_ssm - 2 * self.ngroups * self.d_state - self.nheads) // 2
458
- z0, x0, z, xBC, dt = torch.split(
459
- zxbcdt,
460
- [d_mlp, d_mlp, self.d_ssm, self.d_ssm + 2 * self.ngroups * self.d_state, self.nheads],
461
- dim=-1
462
- )
463
-
464
- # Conv step
465
- if causal_conv1d_update is None:
466
- conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
467
- conv_state[:, :, -1] = xBC
468
- xBC = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
469
- if self.conv1d.bias is not None:
470
- xBC = xBC + self.conv1d.bias
471
- xBC = self.act(xBC).to(dtype=dtype)
472
  else:
473
- xBC = causal_conv1d_update(
474
- xBC,
475
- conv_state,
476
- rearrange(self.conv1d.weight, "d 1 w -> d w"),
477
- self.conv1d.bias,
478
- self.activation,
479
  )
480
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
481
  x, B, C = torch.split(xBC, [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)
482
  A = -torch.exp(self.A_log.float()) # (nheads,)
483
 
484
- # SSM step
485
- if selective_state_update is None:
486
- assert self.ngroups == 1, "Only support ngroups=1 for this inference code path"
487
- # Discretize A and B
488
- dt = F.softplus(dt + self.dt_bias.to(dtype=dt.dtype)) # (batch, nheads)
489
- dA = torch.exp(dt * A) # (batch, nheads)
490
- x = rearrange(x, "b (h p) -> b h p", p=self.headdim)
491
- dBx = torch.einsum("bh,bn,bhp->bhpn", dt, B, x)
492
- ssm_state.copy_(ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx)
493
- y = torch.einsum("bhpn,bn->bhp", ssm_state.to(dtype), C)
494
- y = y + rearrange(self.D.to(dtype), "h -> h 1") * x
495
- y = rearrange(y, "b h p -> b (h p)")
496
- if not self.rmsnorm:
497
- y = y * self.act(z) # (B D)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
498
  else:
499
- A = repeat(A, "h -> h p n", p=self.headdim, n=self.d_state).to(dtype=torch.float32)
500
- dt = repeat(dt, "b h -> b h p", p=self.headdim)
501
- dt_bias = repeat(self.dt_bias, "h -> h p", p=self.headdim)
502
- D = repeat(self.D, "h -> h p", p=self.headdim)
503
- B = rearrange(B, "b (g n) -> b g n", g=self.ngroups)
504
- C = rearrange(C, "b (g n) -> b g n", g=self.ngroups)
505
- x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.headdim)
506
- if not self.rmsnorm:
507
- z = rearrange(z, "b (h p) -> b h p", p=self.headdim)
508
- y = selective_state_update(
509
- ssm_state, x_reshaped, dt, A, B, C, D, z=z if not self.rmsnorm else None,
510
- dt_bias=dt_bias, dt_softplus=True
 
 
 
 
511
  )
512
- y = rearrange(y, "b h p -> b (h p)")
 
 
 
 
513
  if self.rmsnorm:
514
  y = self.norm(y, z)
515
  if d_mlp > 0:
516
  y = torch.cat([F.silu(z0) * x0, y], dim=-1)
517
  out = self.out_proj(y)
518
- return out.unsqueeze(1), conv_state, ssm_state
 
 
 
 
 
 
519
 
520
  def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
521
  device = self.out_proj.weight.device
@@ -555,873 +461,4 @@ class Mamba2(nn.Module):
555
  if initialize_states:
556
  conv_state.zero_()
557
  ssm_state.zero_()
558
- return conv_state, ssm_state
559
-
560
-
561
- class Mamba2_Fused(nn.Module):
562
- def __init__(
563
- self,
564
- config,
565
- layer_idx=None, # Absorb kwarg for general module
566
- reuse_kv=False,
567
- conv_init=None,
568
- d_ssm=None, # If not None, we only apply SSM on this many dimensions, the rest uses gated MLP
569
- ngroups=1,
570
- A_init_range=(1, 16),
571
- D_has_hdim=False,
572
- rmsnorm=True,
573
- norm_before_gate=False,
574
- dt_min=0.001,
575
- dt_max=0.1,
576
- dt_init_floor=1e-4,
577
- dt_limit=(0.0, float("inf")),
578
- bias=False,
579
- conv_bias=True,
580
- # Fused kernel and sharding options
581
- chunk_size=256,
582
- use_mem_eff_path=False, # True,
583
- process_group=None,
584
- sequence_parallel=True,
585
- device=None,
586
- dtype=None,
587
- ):
588
- factory_kwargs = {"device": device, "dtype": dtype}
589
- super().__init__()
590
-
591
- self.config = config
592
- self.d_model = config.hidden_size
593
- self.d_state = config.mamba_d_state
594
- self.d_conv = config.mamba_d_conv
595
-
596
- self.conv_init = conv_init
597
- self.expand = config.mamba_expand
598
- self.process_group = process_group
599
- self.sequence_parallel = sequence_parallel
600
- self.world_size = 1 if process_group is None else process_group.size()
601
- self.local_rank = 0 if process_group is None else process_group.rank()
602
- self.d_inner = (self.expand * self.d_model) // self.world_size
603
- assert self.d_inner * self.world_size == self.expand * self.d_model
604
- self.headdim = config.mamba2_headdim
605
- self.d_ssm = self.d_inner if d_ssm is None else d_ssm // self.world_size
606
- assert ngroups % self.world_size == 0
607
- self.ngroups = ngroups // self.world_size
608
- assert self.d_ssm % self.headdim == 0
609
- self.nheads = self.d_ssm // self.headdim
610
- self.D_has_hdim = D_has_hdim
611
- self.rmsnorm = rmsnorm
612
- self.norm_before_gate = norm_before_gate
613
- self.dt_limit = dt_limit
614
- self.activation = "silu"
615
- self.chunk_size = chunk_size
616
- self.use_mem_eff_path = use_mem_eff_path
617
- self.layer_idx = layer_idx
618
-
619
- assert (self.d_model * self.expand / self.headdim) % 8 == 0
620
-
621
- self.fused_multihead_config = config.fused_multihead_config
622
- assert self.fused_multihead_config['expand_v'], "Only implemented Hymba for Mamba"
623
-
624
- self.reuse_kv = reuse_kv
625
-
626
- self.hidden_size = config.hidden_size
627
- self.attn_hidden_size = config.hidden_size
628
- self.num_attention_heads = config.num_attention_heads
629
- self.num_key_value_heads = config.num_key_value_heads
630
-
631
- self.k_hidden_size = int(self.num_key_value_heads/self.num_attention_heads * self.attn_hidden_size)
632
- self.v_hidden_size = int(self.num_key_value_heads/self.num_attention_heads * self.attn_hidden_size * self.expand) if self.fused_multihead_config['expand_v'] else int(self.num_key_value_heads/self.num_attention_heads * self.attn_hidden_size)
633
-
634
- if self.fused_multihead_config['expand_v']:
635
- config.v_head_dim = self.d_inner // self.num_attention_heads
636
-
637
- self.self_attn = config.attn_op(config, layer_idx, attn_only_wo_proj=True, reuse_kv=reuse_kv)
638
-
639
- if self.reuse_kv: # Order: [q, z, x, B, C, dt]
640
- d_in_proj = self.attn_hidden_size + 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
641
- else: # Order: [q, k, v, z, x, B, C, dt]
642
- d_in_proj = self.attn_hidden_size + self.k_hidden_size + self.v_hidden_size + 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
643
-
644
- if self.process_group is None:
645
- self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs)
646
- else:
647
- self.in_proj = ColumnParallelLinear(self.d_model, d_in_proj * self.world_size, bias=bias,
648
- process_group=self.process_group, sequence_parallel=self.sequence_parallel,
649
- **factory_kwargs)
650
-
651
- self.pre_avg_layernorm1 = JambaRMSNorm(self.d_inner, eps=config.rms_norm_eps)
652
- self.pre_avg_layernorm2 = JambaRMSNorm(self.d_inner, eps=config.rms_norm_eps)
653
-
654
- conv_dim = self.d_ssm + 2 * self.ngroups * self.d_state
655
- self.conv1d = nn.Conv1d(
656
- in_channels=conv_dim,
657
- out_channels=conv_dim,
658
- bias=conv_bias,
659
- kernel_size=self.d_conv,
660
- groups=conv_dim,
661
- padding=self.d_conv - 1,
662
- **factory_kwargs,
663
- )
664
- if self.conv_init is not None:
665
- nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)
666
-
667
- self.act = nn.SiLU()
668
-
669
- # Initialize log dt bias
670
- dt = torch.exp(
671
- torch.rand(self.nheads, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
672
- + math.log(dt_min)
673
- )
674
- dt = torch.clamp(dt, min=dt_init_floor)
675
- # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
676
- inv_dt = dt + torch.log(-torch.expm1(-dt))
677
- self.dt_bias = nn.Parameter(inv_dt)
678
- # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
679
- # name.endswith("bias") in param_grouping.py
680
- self.dt_bias._no_weight_decay = True
681
-
682
- assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
683
- A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(*A_init_range)
684
- A_log = torch.log(A).to(dtype=dtype)
685
- self.A_log = nn.Parameter(A_log)
686
- self.A_log._no_weight_decay = True
687
-
688
- # D "skip" parameter
689
- self.D = nn.Parameter(torch.ones(self.d_ssm if self.D_has_hdim else self.nheads, device=device))
690
- self.D._no_weight_decay = True
691
-
692
- if self.rmsnorm:
693
- assert RMSNormGated is not None
694
- self.norm = RMSNormGated(self.d_ssm, eps=1e-5, norm_before_gate=self.norm_before_gate,
695
- group_size=self.d_ssm // ngroups, **factory_kwargs)
696
-
697
- if self.process_group is None:
698
- self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
699
- else:
700
- self.out_proj = RowParallelLinear(self.d_inner * self.world_size, self.d_model, bias=bias,
701
- process_group=self.process_group, sequence_parallel=self.sequence_parallel,
702
- **factory_kwargs)
703
-
704
- def forward(self, hidden_states, attention_mask=None, past_key_value=None, position_ids=None, kv_last_layer=None, use_cache=False, use_swa=False, seqlen=None, seq_idx=None, cu_seqlens=None, inference_params=None):
705
- """
706
- hidden_states: (batch, seqlen, hidden_dim) if seqlen=None.
707
- If seqlen is not None, hidden_states is (batch * seqlen, hidden_dim). This is so that when we
708
- split hidden_states during sequence parallel, we split the batch * seqlen dimension
709
- (in case batch is small).
710
- Returns: same shape as u
711
- """
712
- # assert past_key_value is None, "Not implemented yet!!!"
713
-
714
- seqlen_og = seqlen
715
- if seqlen is None:
716
- batch, seqlen, dim = hidden_states.shape
717
- else:
718
- batch_seqlen, dim = hidden_states.shape
719
- batch = batch_seqlen // seqlen
720
-
721
- conv_state, ssm_state = None, None
722
- if inference_params is not None:
723
- inference_batch = cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch
724
- conv_state, ssm_state = self._get_states_from_cache(inference_params, inference_batch)
725
- if inference_params.seqlen_offset > 0:
726
- # The states are updated inplace
727
- out, _, _ = self.step(hidden_states, conv_state, ssm_state)
728
- return out
729
-
730
- zxbcdt = self.in_proj(hidden_states) # (B, L, d_in_proj) or (B * L, d_in_proj)
731
-
732
- if self.reuse_kv:
733
- query_states, zxbcdt = zxbcdt.tensor_split((self.attn_hidden_size,), dim=-1)
734
- # query_states = query_states.transpose(1,2)
735
- else:
736
- query_states, key_states, value_states, zxbcdt = zxbcdt.tensor_split((self.attn_hidden_size, self.attn_hidden_size + self.k_hidden_size, self.attn_hidden_size + self.k_hidden_size + self.v_hidden_size), dim=-1)
737
-
738
- # query_states = query_states.transpose(1,2)
739
- # key_states = key_states.transpose(1,2)
740
- # value_states = value_states.transpose(1,2)
741
-
742
- if self.reuse_kv:
743
- assert kv_last_layer is not None
744
- attn_outputs, attn_key_value = self.self_attn(attention_mask=attention_mask, position_ids=position_ids, query_states=query_states, kv_last_layer=kv_last_layer, use_swa=use_swa, use_cache=use_cache, past_key_value=past_key_value)
745
- else:
746
- if 'use_linear_attn' in self.fused_multihead_config and self.fused_multihead_config['use_linear_attn'] and self.linear_attn_op == 'gla':
747
- attn_outputs, _, attn_key_value = self.self_attn(hidden_states=value_states, position_ids=position_ids, attention_mask=attention_mask, Q=query_states, K=key_states, V=value_states, past_key_value=past_key_value)
748
- else:
749
- attn_outputs, attn_key_value = self.self_attn(attention_mask=attention_mask, position_ids=position_ids, query_states=query_states, key_states=key_states, value_states=value_states, use_swa=use_swa, use_cache=use_cache, past_key_value=past_key_value)
750
-
751
-
752
- if seqlen_og is not None:
753
- zxbcdt = rearrange(zxbcdt, "(b l) d -> b l d", l=seqlen)
754
- # If the model is loaded in fp16, without the .float() here, A might be -inf
755
- A = -torch.exp(self.A_log.float()) # (nheads) or (d_inner, d_state)
756
- dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
757
- if self.use_mem_eff_path and inference_params is None:
758
- out = mamba_split_conv1d_scan_combined(
759
- zxbcdt,
760
- rearrange(self.conv1d.weight, "d 1 w -> d w"),
761
- self.conv1d.bias,
762
- self.dt_bias,
763
- A,
764
- D=rearrange(self.D, "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.D,
765
- chunk_size=self.chunk_size,
766
- seq_idx=seq_idx,
767
- activation=self.activation,
768
- rmsnorm_weight=self.norm.weight if self.rmsnorm else None,
769
- rmsnorm_eps=self.norm.eps if self.rmsnorm else 1e-6,
770
- outproj_weight=self.out_proj.weight,
771
- outproj_bias=self.out_proj.bias,
772
- headdim=None if self.D_has_hdim else self.headdim,
773
- ngroups=self.ngroups,
774
- norm_before_gate=self.norm_before_gate,
775
- **dt_limit_kwargs,
776
- )
777
- if seqlen_og is not None:
778
- out = rearrange(out, "b l d -> (b l) d")
779
- if self.process_group is not None:
780
- reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
781
- out = reduce_fn(out, self.process_group)
782
- else:
783
- d_mlp = (zxbcdt.shape[-1] - 2 * self.d_ssm - 2 * self.ngroups * self.d_state - self.nheads) // 2
784
-
785
- z0, x0, z, xBC, dt = torch.split(
786
- zxbcdt,
787
- [d_mlp, d_mlp, self.d_ssm, self.d_ssm + 2 * self.ngroups * self.d_state, self.nheads],
788
- dim=-1
789
- )
790
- if conv_state is not None:
791
- if cu_seqlens is None:
792
- # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
793
- # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
794
- xBC_t = rearrange(xBC, "b l d -> b d l")
795
- conv_state.copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W)
796
- else:
797
- assert causal_conv1d_varlen_states is not None, "varlen inference requires causal_conv1d package"
798
- assert batch == 1, "varlen inference only supports batch dimension 1"
799
- conv_varlen_states = causal_conv1d_varlen_states(
800
- xBC.squeeze(0), cu_seqlens, state_len=conv_state.shape[-1]
801
- )
802
- conv_state.copy_(conv_varlen_states)
803
- assert self.activation in ["silu", "swish"]
804
- if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
805
- assert seq_idx is None, "varlen conv1d requires the causal_conv1d package"
806
- xBC = self.act(
807
- self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, -(self.dconv - 1):]
808
- ) # (B, L, self.d_ssm + 2 * ngroups * d_state)
809
- else:
810
- xBC = causal_conv1d_fn(
811
- xBC.transpose(1, 2),
812
- rearrange(self.conv1d.weight, "d 1 w -> d w"),
813
- bias=self.conv1d.bias,
814
- activation=self.activation,
815
- # seq_idx=seq_idx,
816
- ).transpose(1, 2)
817
- x, B, C = torch.split(xBC, [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)
818
-
819
- y = mamba_chunk_scan_combined(
820
- rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
821
- dt,
822
- A,
823
- rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
824
- rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
825
- chunk_size=self.chunk_size,
826
- D=rearrange(self.D, "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.D,
827
- z=rearrange(z, "b l (h p) -> b l h p", p=self.headdim) if not self.rmsnorm else None,
828
- dt_bias=self.dt_bias,
829
- dt_softplus=True,
830
- seq_idx=seq_idx,
831
- cu_seqlens=cu_seqlens,
832
- **dt_limit_kwargs,
833
- return_final_states=ssm_state is not None,
834
- return_varlen_states=cu_seqlens is not None and inference_params is not None,
835
- )
836
- if ssm_state is not None:
837
- y, last_state, *rest = y
838
- if cu_seqlens is None:
839
- ssm_state.copy_(last_state)
840
- else:
841
- varlen_states = rest[0]
842
- ssm_state.copy_(varlen_states)
843
- y = rearrange(y, "b l h p -> b l (h p)")
844
- if self.rmsnorm:
845
- y = self.norm(y, z)
846
- if d_mlp > 0:
847
- y = torch.cat([F.silu(z0) * x0, y], dim=-1)
848
- if seqlen_og is not None:
849
- y = rearrange(y, "b l d -> (b l) d")
850
-
851
- scan_outputs = y
852
- if 'repeat_v' in self.fused_multihead_config and self.fused_multihead_config['repeat_v']:
853
- num_repeat = scan_outputs.shape[-1] // attn_outputs.shape[-1]
854
- attn_outputs = attn_outputs.repeat(1, 1, num_repeat)
855
-
856
- hidden_states = (self.pre_avg_layernorm1(attn_outputs) + self.pre_avg_layernorm2(scan_outputs)) / 2
857
- out = self.out_proj(hidden_states)
858
-
859
- return out, attn_key_value, past_key_value
860
-
861
-
862
- def step(self, hidden_states, conv_state, ssm_state):
863
- dtype = hidden_states.dtype
864
- assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
865
- zxbcdt = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
866
- d_mlp = (zxbcdt.shape[-1] - 2 * self.d_ssm - 2 * self.ngroups * self.d_state - self.nheads) // 2
867
- z0, x0, z, xBC, dt = torch.split(
868
- zxbcdt,
869
- [d_mlp, d_mlp, self.d_ssm, self.d_ssm + 2 * self.ngroups * self.d_state, self.nheads],
870
- dim=-1
871
- )
872
-
873
- # Conv step
874
- if causal_conv1d_update is None:
875
- conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
876
- conv_state[:, :, -1] = xBC
877
- xBC = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
878
- if self.conv1d.bias is not None:
879
- xBC = xBC + self.conv1d.bias
880
- xBC = self.act(xBC).to(dtype=dtype)
881
- else:
882
- xBC = causal_conv1d_update(
883
- xBC,
884
- conv_state,
885
- rearrange(self.conv1d.weight, "d 1 w -> d w"),
886
- self.conv1d.bias,
887
- self.activation,
888
- )
889
-
890
- x, B, C = torch.split(xBC, [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)
891
- A = -torch.exp(self.A_log.float()) # (nheads,)
892
-
893
- # SSM step
894
- if selective_state_update is None:
895
- assert self.ngroups == 1, "Only support ngroups=1 for this inference code path"
896
- # Discretize A and B
897
- dt = F.softplus(dt + self.dt_bias.to(dtype=dt.dtype)) # (batch, nheads)
898
- dA = torch.exp(dt * A) # (batch, nheads)
899
- x = rearrange(x, "b (h p) -> b h p", p=self.headdim)
900
- dBx = torch.einsum("bh,bn,bhp->bhpn", dt, B, x)
901
- ssm_state.copy_(ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx)
902
- y = torch.einsum("bhpn,bn->bhp", ssm_state.to(dtype), C)
903
- y = y + rearrange(self.D.to(dtype), "h -> h 1") * x
904
- y = rearrange(y, "b h p -> b (h p)")
905
- if not self.rmsnorm:
906
- y = y * self.act(z) # (B D)
907
- else:
908
- A = repeat(A, "h -> h p n", p=self.headdim, n=self.d_state).to(dtype=torch.float32)
909
- dt = repeat(dt, "b h -> b h p", p=self.headdim)
910
- dt_bias = repeat(self.dt_bias, "h -> h p", p=self.headdim)
911
- D = repeat(self.D, "h -> h p", p=self.headdim)
912
- B = rearrange(B, "b (g n) -> b g n", g=self.ngroups)
913
- C = rearrange(C, "b (g n) -> b g n", g=self.ngroups)
914
- x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.headdim)
915
- if not self.rmsnorm:
916
- z = rearrange(z, "b (h p) -> b h p", p=self.headdim)
917
- y = selective_state_update(
918
- ssm_state, x_reshaped, dt, A, B, C, D, z=z if not self.rmsnorm else None,
919
- dt_bias=dt_bias, dt_softplus=True
920
- )
921
- y = rearrange(y, "b h p -> b (h p)")
922
- if self.rmsnorm:
923
- y = self.norm(y, z)
924
- if d_mlp > 0:
925
- y = torch.cat([F.silu(z0) * x0, y], dim=-1)
926
- out = self.out_proj(y)
927
-
928
- print(out)
929
- input()
930
- return out.unsqueeze(1), conv_state, ssm_state
931
-
932
- def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
933
- device = self.out_proj.weight.device
934
- conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
935
- conv_state = torch.zeros(
936
- batch_size, self.d_conv, self.conv1d.weight.shape[0], device=device, dtype=conv_dtype
937
- ).transpose(1, 2)
938
- ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype
939
- ssm_state = torch.zeros(
940
- batch_size, self.nheads, self.headdim, self.d_state, device=device, dtype=ssm_dtype
941
- )
942
- return conv_state, ssm_state
943
-
944
- def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
945
- assert self.layer_idx is not None
946
- if self.layer_idx not in inference_params.key_value_memory_dict:
947
- batch_shape = (batch_size,)
948
- conv_state = torch.zeros(
949
- batch_size,
950
- self.d_conv,
951
- self.conv1d.weight.shape[0],
952
- device=self.conv1d.weight.device,
953
- dtype=self.conv1d.weight.dtype,
954
- ).transpose(1, 2)
955
- ssm_state = torch.zeros(
956
- batch_size,
957
- self.nheads,
958
- self.headdim,
959
- self.d_state,
960
- device=self.in_proj.weight.device,
961
- dtype=self.in_proj.weight.dtype,
962
- )
963
- inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
964
- else:
965
- conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
966
- # TODO: What if batch size changes between generation, and we reuse the same states?
967
- if initialize_states:
968
- conv_state.zero_()
969
- ssm_state.zero_()
970
- return conv_state, ssm_state
971
-
972
-
973
- class Mamba2_Multihead(nn.Module):
974
- def __init__(
975
- self,
976
- config,
977
- conv_init=None,
978
- headdim=64,
979
- d_ssm=None, # If not None, we only apply SSM on this many dimensions, the rest uses gated MLP
980
- ngroups=1,
981
- A_init_range=(1, 16),
982
- D_has_hdim=False,
983
- rmsnorm=True,
984
- norm_before_gate=False,
985
- dt_min=0.001,
986
- dt_max=0.1,
987
- dt_init_floor=1e-4,
988
- dt_limit=(0.0, float("inf")),
989
- bias=False,
990
- conv_bias=True,
991
- # Fused kernel and sharding options
992
- chunk_size=256,
993
- use_mem_eff_path=False, # True,
994
- layer_idx=None, # Absorb kwarg for general module
995
- process_group=None,
996
- sequence_parallel=True,
997
- device=None,
998
- dtype=None,
999
- ):
1000
- factory_kwargs = {"device": device, "dtype": dtype}
1001
- super().__init__()
1002
-
1003
- self.config = config
1004
- self.d_model = config.hidden_size
1005
- self.d_state = config.mamba_d_state
1006
- self.d_conv = config.mamba_d_conv
1007
-
1008
- self.conv_init = conv_init
1009
- self.expand = config.mamba_expand
1010
- self.process_group = process_group
1011
- self.sequence_parallel = sequence_parallel
1012
- self.world_size = 1 if process_group is None else process_group.size()
1013
- self.local_rank = 0 if process_group is None else process_group.rank()
1014
- self.d_inner = (self.expand * self.d_model) // self.world_size
1015
- assert self.d_inner * self.world_size == self.expand * self.d_model
1016
- self.headdim = config.mamba2_headdim
1017
- self.d_ssm = self.d_inner if d_ssm is None else d_ssm // self.world_size
1018
- assert ngroups % self.world_size == 0
1019
- self.ngroups = ngroups // self.world_size
1020
- assert self.d_ssm % self.headdim == 0
1021
- self.nheads = self.d_ssm // self.headdim
1022
- self.D_has_hdim = D_has_hdim
1023
- self.rmsnorm = rmsnorm
1024
- self.norm_before_gate = norm_before_gate
1025
- self.dt_limit = dt_limit
1026
- self.activation = "silu"
1027
- self.chunk_size = chunk_size
1028
- self.use_mem_eff_path = use_mem_eff_path
1029
- self.layer_idx = layer_idx
1030
-
1031
- assert (self.d_model * self.expand / self.headdim) % 8 == 0
1032
-
1033
- self.mamba_multihead_config = config.mamba_multihead_config
1034
- self.share_ratio = self.mamba_multihead_config['share_ratio']
1035
-
1036
- self.reuse_ssm = self.mamba_multihead_config['reuse_ssm']
1037
- self.num_ssm_param = 1 if self.reuse_ssm else self.share_ratio
1038
-
1039
- if self.reuse_ssm:
1040
- if self.mamba_multihead_config['alpha_mode'] == 'learnable':
1041
- self.alpha = nn.Parameter(torch.ones(self.share_ratio))
1042
- elif self.mamba_multihead_config['alpha_mode'] == 'manual':
1043
- manual_alpha_base = self.mamba_multihead_config['manual_alpha_base']
1044
- self.alpha = [1 / manual_alpha_base ** k for k in range(self.share_ratio)]
1045
- else:
1046
- raise ValueError(f"No such alpha_mode: {self.mamba_multihead_config['alpha_mode']}")
1047
-
1048
- # Order: [z, x, B, C, dt]
1049
- d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads * self.num_ssm_param
1050
- if self.process_group is None:
1051
- self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs)
1052
- else:
1053
- self.in_proj = ColumnParallelLinear(self.d_model, d_in_proj * self.world_size, bias=bias,
1054
- process_group=self.process_group, sequence_parallel=self.sequence_parallel,
1055
- **factory_kwargs)
1056
-
1057
- conv_dim = self.d_ssm + 2 * self.ngroups * self.d_state
1058
- self.conv1d = nn.Conv1d(
1059
- in_channels=conv_dim,
1060
- out_channels=conv_dim,
1061
- bias=conv_bias,
1062
- kernel_size=self.d_conv,
1063
- groups=conv_dim,
1064
- padding=self.d_conv - 1,
1065
- **factory_kwargs,
1066
- )
1067
- if self.conv_init is not None:
1068
- nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)
1069
-
1070
- self.act = nn.SiLU()
1071
-
1072
- # Initialize log dt bias
1073
- dt = torch.exp(
1074
- torch.rand(self.nheads, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
1075
- + math.log(dt_min)
1076
- )
1077
- dt = torch.clamp(dt, min=dt_init_floor)
1078
- # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
1079
- inv_dt = dt + torch.log(-torch.expm1(-dt))
1080
- self.dt_bias = nn.ParameterList([nn.Parameter(inv_dt) for _ in range(self.num_ssm_param)])
1081
- # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
1082
- # name.endswith("bias") in param_grouping.py
1083
- self.dt_bias._no_weight_decay = True
1084
-
1085
- assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
1086
- A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(*A_init_range)
1087
- A_log = torch.log(A).to(dtype=dtype)
1088
- self.A_log = nn.ParameterList([nn.Parameter(A_log) for _ in range(self.num_ssm_param)])
1089
- self.A_log._no_weight_decay = True
1090
-
1091
- # D "skip" parameter
1092
- self.D = nn.ParameterList([nn.Parameter(torch.ones(self.d_ssm if self.D_has_hdim else self.nheads, device=device)) for _ in range(self.num_ssm_param)])
1093
- self.D._no_weight_decay = True
1094
-
1095
- if self.rmsnorm:
1096
- assert RMSNormGated is not None
1097
- self.norm = RMSNormGated(self.d_ssm, eps=1e-5, norm_before_gate=self.norm_before_gate,
1098
- group_size=self.d_ssm // ngroups, **factory_kwargs)
1099
-
1100
- if self.process_group is None:
1101
- self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
1102
- else:
1103
- self.out_proj = RowParallelLinear(self.d_inner * self.world_size, self.d_model, bias=bias,
1104
- process_group=self.process_group, sequence_parallel=self.sequence_parallel,
1105
- **factory_kwargs)
1106
-
1107
-
1108
- if self.mamba_multihead_config['merge_op'] == 'norm':
1109
- self.multihead_layernorm = nn.ModuleList([JambaRMSNorm(self.d_ssm, eps=config.rms_norm_eps) for _ in range(self.share_ratio)])
1110
- elif self.mamba_multihead_config['merge_op'] == 'scalar_gate':
1111
- self.multi_head_selection_layer = nn.Linear(self.d_ssm, self.share_ratio)
1112
- elif self.mamba_multihead_config['merge_op'] == 'concat':
1113
- assert self.d_ssm % self.share_ratio == 0
1114
- self.multihead_layernorm = nn.ModuleList([JambaRMSNorm(self.d_ssm, eps=config.rms_norm_eps) for _ in range(self.share_ratio)])
1115
- self.reduction_layer = nn.Linear(self.d_ssm, self.d_ssm//self.share_ratio)
1116
-
1117
-
1118
- def forward(self, hidden_states, attention_mask=None, past_key_value=None, seqlen=None, seq_idx=None, cu_seqlens=None, inference_params=None):
1119
- """
1120
- hidden_states: (batch, seqlen, hidden_dim) if seqlen=None.
1121
- If seqlen is not None, hidden_states is (batch * seqlen, hidden_dim). This is so that when we
1122
- split hidden_states during sequence parallel, we split the batch * seqlen dimension
1123
- (in case batch is small).
1124
- Returns: same shape as u
1125
- """
1126
- assert past_key_value is None, "Not implemented yet!!!"
1127
-
1128
- seqlen_og = seqlen
1129
- if seqlen is None:
1130
- batch, seqlen, dim = hidden_states.shape
1131
- else:
1132
- batch_seqlen, dim = hidden_states.shape
1133
- batch = batch_seqlen // seqlen
1134
-
1135
- conv_state, ssm_state = None, None
1136
- if inference_params is not None:
1137
- inference_batch = cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch
1138
- conv_state, ssm_state = self._get_states_from_cache(inference_params, inference_batch)
1139
- if inference_params.seqlen_offset > 0:
1140
- # The states are updated inplace
1141
- out, _, _ = self.step(hidden_states, conv_state, ssm_state)
1142
- return out
1143
-
1144
- zxbcdt = self.in_proj(hidden_states) # (B, L, d_in_proj) or (B * L, d_in_proj)
1145
- if seqlen_og is not None:
1146
- zxbcdt = rearrange(zxbcdt, "(b l) d -> b l d", l=seqlen)
1147
-
1148
- dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
1149
- if self.use_mem_eff_path and inference_params is None:
1150
- # If the model is loaded in fp16, without the .float() here, A might be -inf
1151
- A = -torch.exp(self.A_log.float()) # (nheads) or (d_inner, d_state)
1152
-
1153
- out = mamba_split_conv1d_scan_combined(
1154
- zxbcdt,
1155
- rearrange(self.conv1d.weight, "d 1 w -> d w"),
1156
- self.conv1d.bias,
1157
- self.dt_bias,
1158
- A,
1159
- D=rearrange(self.D, "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.D,
1160
- chunk_size=self.chunk_size,
1161
- seq_idx=seq_idx,
1162
- activation=self.activation,
1163
- rmsnorm_weight=self.norm.weight if self.rmsnorm else None,
1164
- rmsnorm_eps=self.norm.eps if self.rmsnorm else 1e-6,
1165
- outproj_weight=self.out_proj.weight,
1166
- outproj_bias=self.out_proj.bias,
1167
- headdim=None if self.D_has_hdim else self.headdim,
1168
- ngroups=self.ngroups,
1169
- norm_before_gate=self.norm_before_gate,
1170
- **dt_limit_kwargs,
1171
- )
1172
- if seqlen_og is not None:
1173
- out = rearrange(out, "b l d -> (b l) d")
1174
- if self.process_group is not None:
1175
- reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
1176
- out = reduce_fn(out, self.process_group)
1177
- else:
1178
- d_mlp = (zxbcdt.shape[-1] - 2 * self.d_ssm - 2 * self.ngroups * self.d_state - self.nheads * self.num_ssm_param) // 2
1179
- z0, x0, z, xBC, dt = torch.split(
1180
- zxbcdt,
1181
- [d_mlp, d_mlp, self.d_ssm, self.d_ssm + 2 * self.ngroups * self.d_state, self.nheads * self.num_ssm_param],
1182
- dim=-1
1183
- )
1184
-
1185
- if conv_state is not None:
1186
- if cu_seqlens is None:
1187
- # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
1188
- # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
1189
- xBC_t = rearrange(xBC, "b l d -> b d l")
1190
- conv_state.copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W)
1191
- else:
1192
- assert causal_conv1d_varlen_states is not None, "varlen inference requires causal_conv1d package"
1193
- assert batch == 1, "varlen inference only supports batch dimension 1"
1194
- conv_varlen_states = causal_conv1d_varlen_states(
1195
- xBC.squeeze(0), cu_seqlens, state_len=conv_state.shape[-1]
1196
- )
1197
- conv_state.copy_(conv_varlen_states)
1198
- assert self.activation in ["silu", "swish"]
1199
- if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
1200
- assert seq_idx is None, "varlen conv1d requires the causal_conv1d package"
1201
- xBC = self.act(
1202
- self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, -(self.dconv - 1):]
1203
- ) # (B, L, self.d_ssm + 2 * ngroups * d_state)
1204
- else:
1205
- xBC = causal_conv1d_fn(
1206
- xBC.transpose(1, 2),
1207
- rearrange(self.conv1d.weight, "d 1 w -> d w"),
1208
- bias=self.conv1d.bias,
1209
- activation=self.activation,
1210
- seq_idx=seq_idx,
1211
- ).transpose(1, 2)
1212
- x, B, C = torch.split(xBC, [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)
1213
-
1214
- x = rearrange(x, "b l (h p) -> b l h p", p=self.headdim)
1215
- B = rearrange(B, "b l (g n) -> b l g n", g=self.ngroups)
1216
- C = rearrange(C, "b l (g n) -> b l g n", g=self.ngroups)
1217
-
1218
- outputs_list = []
1219
- dt_list = dt
1220
- for i in range(self.num_ssm_param):
1221
- dt = dt_list[..., self.nheads*i:self.nheads*(i+1)]
1222
- A = -torch.exp(self.A_log[i].float()) # (nheads) or (d_inner, d_state)
1223
- D = rearrange(self.D[i], "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.D[i]
1224
- dt_bias = self.dt_bias[i]
1225
-
1226
- if self.reuse_ssm:
1227
- #### duplicate heads with different decays
1228
- if self.mamba_multihead_config['alpha_mode'] == 'learnable':
1229
- decay = self.alpha # [share_ratio]
1230
- elif self.mamba_multihead_config['alpha_mode'] == 'manual':
1231
- decay = torch.tensor(self.alpha).to(dt) # [share_ratio]
1232
-
1233
- dt = dt.repeat(1, 1, self.share_ratio) # [bs, seq_len, self.nheads * share_ratio]
1234
- decay = decay.view(-1, 1).repeat(1, self.nheads).view(-1) # [self.nheads * share_ratio]
1235
- dt = dt * decay # [bs, seq_len, nheads * share_ratio]
1236
-
1237
- dt_bias = dt_bias.repeat(self.share_ratio) * decay # [nheads * share_ratio]
1238
-
1239
- x = x.repeat(1,1,self.share_ratio,1) # [bs, seq_len, nheads * share_ratio, head_dim]
1240
- D = D.repeat(self.share_ratio,1) if self.D_has_hdim else D.repeat(self.share_ratio) # [nheads * share_ratio]
1241
- A = A.repeat(self.share_ratio) # [nheads * share_ratio]
1242
-
1243
- y = mamba_chunk_scan_combined(
1244
- x,
1245
- dt,
1246
- A,
1247
- B,
1248
- C,
1249
- chunk_size=self.chunk_size,
1250
- D=D,
1251
- z=rearrange(z, "b l (h p) -> b l h p", p=self.headdim).repeat(1,1,self.share_ratio,1) if not self.rmsnorm else None,
1252
- dt_bias=dt_bias,
1253
- dt_softplus=True,
1254
- seq_idx=seq_idx,
1255
- cu_seqlens=cu_seqlens,
1256
- **dt_limit_kwargs,
1257
- return_final_states=ssm_state is not None,
1258
- return_varlen_states=cu_seqlens is not None and inference_params is not None,
1259
- )
1260
- if ssm_state is not None:
1261
- y, last_state, *rest = y
1262
- if cu_seqlens is None:
1263
- ssm_state.copy_(last_state)
1264
- else:
1265
- varlen_states = rest[0]
1266
- ssm_state.copy_(varlen_states)
1267
-
1268
- outputs_list.append(y)
1269
-
1270
- if len(outputs_list) > 1:
1271
- y = torch.cat(outputs_list, dim=2)
1272
-
1273
- #### merge heads
1274
- num_repeat = y.shape[2] // self.nheads
1275
- head_outputs = torch.chunk(y, num_repeat, dim=2)
1276
- head_outputs = [rearrange(item, "b l h p -> b l (h p)") for item in head_outputs]
1277
-
1278
- if self.mamba_multihead_config['merge_op'] == 'norm':
1279
- y = sum([self.multihead_layernorm[k](item) for k, item in enumerate(head_outputs)])
1280
-
1281
- elif self.mamba_multihead_config['merge_op'] == 'concat':
1282
- head_outputs = [self.reduction_layer(self.multihead_layernorm[k](item)) for k, item in enumerate(head_outputs)]
1283
- y = torch.cat(head_outputs, dim=-1)
1284
- else:
1285
- raise ValueError(f"No such merge_op: {self.mamba_multihead_config['merge_op']}")
1286
-
1287
- if self.rmsnorm:
1288
- y = self.norm(y, z)
1289
- if d_mlp > 0:
1290
- y = torch.cat([F.silu(z0) * x0, y], dim=-1)
1291
- if seqlen_og is not None:
1292
- y = rearrange(y, "b l d -> (b l) d")
1293
- out = self.out_proj(y)
1294
- return out, past_key_value
1295
-
1296
- def step(self, hidden_states, conv_state, ssm_state):
1297
- dtype = hidden_states.dtype
1298
- assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
1299
- zxbcdt = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
1300
- d_mlp = (zxbcdt.shape[-1] - 2 * self.d_ssm - 2 * self.ngroups * self.d_state - self.nheads) // 2
1301
- z0, x0, z, xBC, dt = torch.split(
1302
- zxbcdt,
1303
- [d_mlp, d_mlp, self.d_ssm, self.d_ssm + 2 * self.ngroups * self.d_state, self.nheads],
1304
- dim=-1
1305
- )
1306
-
1307
- # Conv step
1308
- if causal_conv1d_update is None:
1309
- conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
1310
- conv_state[:, :, -1] = xBC
1311
- xBC = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
1312
- if self.conv1d.bias is not None:
1313
- xBC = xBC + self.conv1d.bias
1314
- xBC = self.act(xBC).to(dtype=dtype)
1315
- else:
1316
- xBC = causal_conv1d_update(
1317
- xBC,
1318
- conv_state,
1319
- rearrange(self.conv1d.weight, "d 1 w -> d w"),
1320
- self.conv1d.bias,
1321
- self.activation,
1322
- )
1323
-
1324
- x, B, C = torch.split(xBC, [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)
1325
- A = -torch.exp(self.A_log.float()) # (nheads,)
1326
-
1327
- # SSM step
1328
- if selective_state_update is None:
1329
- assert self.ngroups == 1, "Only support ngroups=1 for this inference code path"
1330
- # Discretize A and B
1331
- dt = F.softplus(dt + self.dt_bias.to(dtype=dt.dtype)) # (batch, nheads)
1332
- dA = torch.exp(dt * A) # (batch, nheads)
1333
- x = rearrange(x, "b (h p) -> b h p", p=self.headdim)
1334
- dBx = torch.einsum("bh,bn,bhp->bhpn", dt, B, x)
1335
- ssm_state.copy_(ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx)
1336
- y = torch.einsum("bhpn,bn->bhp", ssm_state.to(dtype), C)
1337
- y = y + rearrange(self.D.to(dtype), "h -> h 1") * x
1338
- y = rearrange(y, "b h p -> b (h p)")
1339
- if not self.rmsnorm:
1340
- y = y * self.act(z) # (B D)
1341
- else:
1342
- A = repeat(A, "h -> h p n", p=self.headdim, n=self.d_state).to(dtype=torch.float32)
1343
- dt = repeat(dt, "b h -> b h p", p=self.headdim)
1344
- dt_bias = repeat(self.dt_bias, "h -> h p", p=self.headdim)
1345
- D = repeat(self.D, "h -> h p", p=self.headdim)
1346
- B = rearrange(B, "b (g n) -> b g n", g=self.ngroups)
1347
- C = rearrange(C, "b (g n) -> b g n", g=self.ngroups)
1348
- x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.headdim)
1349
- if not self.rmsnorm:
1350
- z = rearrange(z, "b (h p) -> b h p", p=self.headdim)
1351
- y = selective_state_update(
1352
- ssm_state, x_reshaped, dt, A, B, C, D, z=z if not self.rmsnorm else None,
1353
- dt_bias=dt_bias, dt_softplus=True
1354
- )
1355
- y = rearrange(y, "b h p -> b (h p)")
1356
- if self.rmsnorm:
1357
- y = self.norm(y, z)
1358
- if d_mlp > 0:
1359
- y = torch.cat([F.silu(z0) * x0, y], dim=-1)
1360
- out = self.out_proj(y)
1361
- return out.unsqueeze(1), conv_state, ssm_state
1362
-
1363
- def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
1364
- device = self.out_proj.weight.device
1365
- conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
1366
- conv_state = torch.zeros(
1367
- batch_size, self.d_conv, self.conv1d.weight.shape[0], device=device, dtype=conv_dtype
1368
- ).transpose(1, 2)
1369
- ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype
1370
- ssm_state = torch.zeros(
1371
- batch_size, self.nheads, self.headdim, self.d_state, device=device, dtype=ssm_dtype
1372
- )
1373
- return conv_state, ssm_state
1374
-
1375
- def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
1376
- assert self.layer_idx is not None
1377
- if self.layer_idx not in inference_params.key_value_memory_dict:
1378
- batch_shape = (batch_size,)
1379
- conv_state = torch.zeros(
1380
- batch_size,
1381
- self.d_conv,
1382
- self.conv1d.weight.shape[0],
1383
- device=self.conv1d.weight.device,
1384
- dtype=self.conv1d.weight.dtype,
1385
- ).transpose(1, 2)
1386
- ssm_state = torch.zeros(
1387
- batch_size,
1388
- self.nheads,
1389
- self.headdim,
1390
- self.d_state,
1391
- device=self.in_proj.weight.device,
1392
- dtype=self.in_proj.weight.dtype,
1393
- )
1394
- inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
1395
- else:
1396
- conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
1397
- # TODO: What if batch size changes between generation, and we reuse the same states?
1398
- if initialize_states:
1399
- conv_state.zero_()
1400
- ssm_state.zero_()
1401
- return conv_state, ssm_state
1402
-
1403
-
1404
-
1405
-
1406
- # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Jamba
1407
- class JambaRMSNorm(nn.Module):
1408
- def __init__(self, hidden_size, eps=1e-6):
1409
- """
1410
- JambaRMSNorm is equivalent to T5LayerNorm
1411
- """
1412
- super().__init__()
1413
- self.weight = nn.Parameter(torch.ones(hidden_size))
1414
- self.variance_epsilon = eps
1415
-
1416
- def forward(self, hidden_states):
1417
- input_dtype = hidden_states.dtype
1418
- hidden_states = hidden_states.to(torch.float32)
1419
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
1420
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
1421
- return self.weight * hidden_states.to(input_dtype)
1422
-
1423
-
1424
-
1425
-
1426
-
1427
-
 
18
  except ImportError:
19
  causal_conv1d_varlen_states = None
20
 
 
 
 
21
  from mamba_ssm.ops.triton.selective_state_update import selective_state_update
22
  from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated
23
 
 
121
  # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
122
  inv_dt = dt + torch.log(-torch.expm1(-dt))
123
 
124
+ self.dt_bias = nn.Parameter(inv_dt)
125
+ # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
126
+ # name.endswith("bias") in param_grouping.py
127
+ self.dt_bias._no_weight_decay = True
 
 
 
128
 
129
  assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
130
  A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(*A_init_range)
 
148
  process_group=self.process_group, sequence_parallel=self.sequence_parallel,
149
  **factory_kwargs)
150
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
  def forward(self, hidden_states, attention_mask=None, past_key_value=None, seqlen=None, seq_idx=None, cu_seqlens=None, inference_params=None):
153
  """
 
159
  """
160
  # assert past_key_value is None, "Not implemented yet!!!"
161
 
 
 
 
 
 
162
  seqlen_og = seqlen
163
  if seqlen is None:
164
  batch, seqlen, dim = hidden_states.shape
 
167
  batch = batch_seqlen // seqlen
168
 
169
  conv_state, ssm_state = None, None
170
+
171
  if inference_params is not None:
172
  inference_batch = cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch
173
  conv_state, ssm_state = self._get_states_from_cache(inference_params, inference_batch)
174
+
175
  if inference_params.seqlen_offset > 0:
176
  # The states are updated inplace
177
  out, _, _ = self.step(hidden_states, conv_state, ssm_state)
178
+ return out, past_key_value
179
 
180
  zxbcdt = self.in_proj(hidden_states) # (B, L, d_in_proj) or (B * L, d_in_proj)
181
 
 
 
 
182
  if seqlen_og is not None:
183
  zxbcdt = rearrange(zxbcdt, "(b l) d -> b l d", l=seqlen)
184
  # If the model is loaded in fp16, without the .float() here, A might be -inf
 
216
  [d_mlp, d_mlp, self.d_ssm, self.d_ssm + 2 * self.ngroups * self.d_state, self.nheads],
217
  dim=-1
218
  )
219
+
220
  if conv_state is not None:
221
  if cu_seqlens is None:
222
  # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
 
244
  activation=self.activation,
245
  # seq_idx=seq_idx,
246
  ).transpose(1, 2)
247
+
248
  x, B, C = torch.split(xBC, [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)
249
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
 
251
  y = mamba_chunk_scan_combined(
252
  rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
 
255
  rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
256
  rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
257
  chunk_size=self.chunk_size,
258
+ # D=rearrange(self.D, "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.D,
259
+ D=self.D,
260
  z=rearrange(z, "b l (h p) -> b l h p", p=self.headdim) if not self.rmsnorm else None,
261
+ dt_bias=self.dt_bias,
262
  dt_softplus=True,
263
  seq_idx=seq_idx,
264
  cu_seqlens=cu_seqlens,
 
275
  ssm_state.copy_(varlen_states)
276
  y = rearrange(y, "b l h p -> b l (h p)")
277
  if self.rmsnorm:
278
+ y_full = y
279
+ z_full = z
280
+
281
+ y = self.norm(y_full, z_full)
282
  if d_mlp > 0:
283
  y = torch.cat([F.silu(z0) * x0, y], dim=-1)
284
  if seqlen_og is not None:
285
  y = rearrange(y, "b l d -> (b l) d")
286
 
 
 
 
287
  out = self.out_proj(y)
288
+
289
  return out, past_key_value
290
 
291
 
292
+ def step(self, hidden_states, conv_state, ssm_state):
293
+ dtype = hidden_states.dtype
294
+ # Remove single token limitation - now supports hidden_states.shape[1] > 1
295
+ batch_size, seq_len, _ = hidden_states.shape
 
 
 
 
296
 
297
+ if seq_len == 1:
298
+ # Single token case - keep existing optimized path
299
+ zxbcdt = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
 
 
 
 
 
300
  else:
301
+ # Multi-token case - process without squeezing
302
+ zxbcdt = self.in_proj(hidden_states) # (B L 2D)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
303
 
 
 
 
 
 
 
 
 
 
 
 
 
 
304
  d_mlp = (zxbcdt.shape[-1] - 2 * self.d_ssm - 2 * self.ngroups * self.d_state - self.nheads) // 2
305
+
306
+ if seq_len == 1:
307
+ z0, x0, z, xBC, dt = torch.split(
308
+ zxbcdt,
309
+ [d_mlp, d_mlp, self.d_ssm, self.d_ssm + 2 * self.ngroups * self.d_state, self.nheads],
310
+ dim=-1
311
+ )
 
 
 
 
 
 
 
312
  else:
313
+ z0, x0, z, xBC, dt = torch.split(
314
+ zxbcdt,
315
+ [d_mlp, d_mlp, self.d_ssm, self.d_ssm + 2 * self.ngroups * self.d_state, self.nheads],
316
+ dim=-1
 
 
317
  )
318
 
319
+ # Conv step - handle both single and multi-token cases
320
+ if seq_len == 1:
321
+ # Single token optimized path
322
+ if causal_conv1d_update is None:
323
+ conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
324
+ conv_state[:, :, -1] = xBC
325
+ xBC = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
326
+ if self.conv1d.bias is not None:
327
+ xBC = xBC + self.conv1d.bias
328
+ xBC = self.act(xBC).to(dtype=dtype)
329
+ else:
330
+ xBC = causal_conv1d_update(
331
+ xBC,
332
+ conv_state,
333
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
334
+ self.conv1d.bias,
335
+ self.activation,
336
+ )
337
+ else:
338
+ # Multi-token case - update conv_state and process sequence
339
+ # Update conv_state with the new sequence
340
+ xBC_t = rearrange(xBC, "b l d -> b d l")
341
+ conv_state.copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W)
342
+
343
+ # Process convolution for the full sequence
344
+ if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
345
+ xBC = self.act(
346
+ self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, -(self.d_conv - 1):]
347
+ ) # (B, L, self.d_ssm + 2 * ngroups * d_state)
348
+ else:
349
+ xBC = causal_conv1d_fn(
350
+ xBC.transpose(1, 2),
351
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
352
+ bias=self.conv1d.bias,
353
+ activation=self.activation,
354
+ ).transpose(1, 2)
355
+
356
  x, B, C = torch.split(xBC, [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)
357
  A = -torch.exp(self.A_log.float()) # (nheads,)
358
 
359
+ # SSM step - handle both single and multi-token cases
360
+ if seq_len == 1:
361
+ # Single token optimized path
362
+ if selective_state_update is None:
363
+ assert self.ngroups == 1, "Only support ngroups=1 for this inference code path"
364
+ # Discretize A and B
365
+ dt = F.softplus(dt + self.dt_bias.to(dtype=dt.dtype)) # (batch, nheads)
366
+ dA = torch.exp(dt * A) # (batch, nheads)
367
+ x = rearrange(x, "b (h p) -> b h p", p=self.headdim)
368
+ dBx = torch.einsum("bh,bn,bhp->bhpn", dt, B, x)
369
+ ssm_state.copy_(ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx)
370
+ y = torch.einsum("bhpn,bn->bhp", ssm_state.to(dtype), C)
371
+ y = y + rearrange(self.D.to(dtype), "h -> h 1") * x
372
+ y = rearrange(y, "b h p -> b (h p)")
373
+ if not self.rmsnorm:
374
+ y = y * self.act(z) # (B D)
375
+ else:
376
+ A = repeat(A, "h -> h p n", p=self.headdim, n=self.d_state).to(dtype=torch.float32)
377
+ dt = repeat(dt, "b h -> b h p", p=self.headdim)
378
+ dt_bias = repeat(self.dt_bias, "h -> h p", p=self.headdim)
379
+ D = repeat(self.D, "h -> h p", p=self.headdim)
380
+ B = rearrange(B, "b (g n) -> b g n", g=self.ngroups)
381
+ C = rearrange(C, "b (g n) -> b g n", g=self.ngroups)
382
+ x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.headdim)
383
+ if not self.rmsnorm:
384
+ z = rearrange(z, "b (h p) -> b h p", p=self.headdim)
385
+ y = selective_state_update(
386
+ ssm_state, x_reshaped, dt, A, B, C, D, z=z if not self.rmsnorm else None,
387
+ dt_bias=dt_bias, dt_softplus=True
388
+ )
389
+ y = rearrange(y, "b h p -> b (h p)")
390
  else:
391
+ # Multi-token case - use mamba_chunk_scan_combined similar to forward method
392
+ dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
393
+
394
+ y = mamba_chunk_scan_combined(
395
+ rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
396
+ dt,
397
+ A,
398
+ rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
399
+ rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
400
+ chunk_size=self.chunk_size,
401
+ D=rearrange(self.D, "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.D,
402
+ z=rearrange(z, "b l (h p) -> b l h p", p=self.headdim) if not self.rmsnorm else None,
403
+ dt_bias=self.dt_bias,
404
+ dt_softplus=True,
405
+ **dt_limit_kwargs,
406
+ return_final_states=True,
407
  )
408
+ # Extract final state and update ssm_state
409
+ y, final_ssm_state = y
410
+ ssm_state.copy_(final_ssm_state)
411
+ y = rearrange(y, "b l h p -> b l (h p)")
412
+
413
  if self.rmsnorm:
414
  y = self.norm(y, z)
415
  if d_mlp > 0:
416
  y = torch.cat([F.silu(z0) * x0, y], dim=-1)
417
  out = self.out_proj(y)
418
+
419
+ # Ensure output shape consistency
420
+ if seq_len == 1 and out.dim() == 2:
421
+ out = out.unsqueeze(1) # (B, 1, D)
422
+
423
+ return out, conv_state, ssm_state
424
+
425
 
426
  def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
427
  device = self.out_proj.weight.device
 
461
  if initialize_states:
462
  conv_state.zero_()
463
  ssm_state.zero_()
464
+ return conv_state, ssm_state
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model-00001-of-00002.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:5e8a0875ed4decf5cbbf676868cbba137f3248a5a592a85597f31614080a25c6
3
  size 4987939472
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1b9b0f9876dd16860790782a1f166be5173253c2ea303c4e3ab40b0b8218af2f
3
  size 4987939472
model-00002-of-00002.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:9f361fe0361ab0e101d95f5161ee1a724501ae664e0d27c28496d9288b71ebc3
3
  size 512102640
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:915482f874991ab97f5e2afbb2064d38f829ef6a76907ddf337a495e14bba382
3
  size 512102640
modeling_fast_slm.py ADDED
The diff for this file is too large to render. See raw diff
 
triton_attention.py ADDED
@@ -0,0 +1,2714 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Custom ops for MHA/XQA attention."""
2
+
3
+ import math
4
+ from dataclasses import astuple
5
+ from typing import List, Optional
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import triton
10
+
11
+ from triton import language as tl
12
+
13
+ from abc import ABC, abstractmethod
14
+ from dataclasses import dataclass, field, fields
15
+ from typing import Dict, List, Literal, Optional, Protocol, Sequence, Tuple, Type, Union
16
+
17
+ import torch
18
+ from torch.export import Dim
19
+
20
+
21
+ @triton.jit
22
+ def update_kv_cache(
23
+ k_ptr, # [B*S, N, D]
24
+ v_ptr, # [B*S, N, D]
25
+ seq_len_ptr, # [b] # length of each sequence in a batch
26
+ seq_start_indices_ptr, # [b] # start indices of a sequence in flattened q/k/v.
27
+ k_cache_ptr, # [MAX_BATCH_SIZE, MAX_SEQ_LEN, N_HEADS, D_HEAD]
28
+ v_cache_ptr, # [MAX_BATCH_SIZE, MAX_SEQ_LEN, N_HEADS, D_HEAD]
29
+ input_pos_ptr, # Specifies the sequence index in the caches at which to write the provided kv
30
+ cache_loc_ptr, # Specifies the batch index for each of the input sequences
31
+ MAX_SEQ_LENGTH: tl.constexpr,
32
+ N_KV_HEADS: tl.constexpr,
33
+ Q_D_HEAD: tl.constexpr,
34
+ V_D_HEAD: tl.constexpr,
35
+ SEQ_BLOCK: tl.constexpr,
36
+ GENERATE_ONLY: tl.constexpr,
37
+ ):
38
+ batch_id = tl.program_id(axis=0)
39
+ head_id = tl.program_id(axis=1)
40
+ seq_block_id = tl.program_id(axis=2)
41
+
42
+ # Each program is responsible for a block of tokens in a single batch.
43
+ if GENERATE_ONLY:
44
+ seq_start_index = batch_id
45
+ seq_len: tl.constexpr = 1
46
+ else:
47
+ seq_start_index = tl.load(seq_start_indices_ptr + batch_id)
48
+ seq_len = tl.load(seq_len_ptr + batch_id)
49
+
50
+ # cache is [bsnd]
51
+ # cache_loc_ptr stores the batch index for the sequences provided to the kernel.
52
+ cache_loc = tl.load(cache_loc_ptr + batch_id)
53
+
54
+ kv_position = tl.load(input_pos_ptr + batch_id)
55
+
56
+ K_D_HEAD: tl.constexpr = Q_D_HEAD
57
+ k_cache_batch_offset = cache_loc * N_KV_HEADS * MAX_SEQ_LENGTH * K_D_HEAD
58
+ v_cache_batch_offset = cache_loc * N_KV_HEADS * MAX_SEQ_LENGTH * V_D_HEAD
59
+
60
+ k_dhead_offsets = tl.arange(0, triton.next_power_of_2(K_D_HEAD))
61
+ k_dhead_mask = k_dhead_offsets < K_D_HEAD
62
+
63
+ v_dhead_offsets = tl.arange(0, triton.next_power_of_2(V_D_HEAD))
64
+ v_dhead_mask = v_dhead_offsets < V_D_HEAD
65
+
66
+ seq_offsets = seq_block_id * SEQ_BLOCK + tl.arange(0, SEQ_BLOCK)
67
+ seq_mask = seq_offsets < seq_len
68
+
69
+ k_load_mask = seq_mask[:, None] * k_dhead_mask[None, :]
70
+ v_load_mask = seq_mask[:, None] * v_dhead_mask[None, :]
71
+
72
+ k_batch_offset = seq_start_index * N_KV_HEADS * K_D_HEAD
73
+ v_batch_offset = seq_start_index * N_KV_HEADS * V_D_HEAD
74
+ # Write back to kv-caches
75
+ ks = tl.load(
76
+ k_ptr
77
+ + k_batch_offset
78
+ + seq_offsets[:, None] * N_KV_HEADS * K_D_HEAD
79
+ + head_id * K_D_HEAD
80
+ + k_dhead_offsets[None, :],
81
+ mask=k_load_mask,
82
+ )
83
+ vs = tl.load(
84
+ v_ptr
85
+ + v_batch_offset
86
+ + seq_offsets[:, None] * N_KV_HEADS * V_D_HEAD
87
+ + head_id * V_D_HEAD
88
+ + v_dhead_offsets[None, :],
89
+ mask=v_load_mask,
90
+ )
91
+
92
+ kv_writeback_seq_offsets = seq_offsets + kv_position
93
+
94
+ k_cache_offset = (
95
+ k_cache_batch_offset
96
+ + kv_writeback_seq_offsets[:, None] * K_D_HEAD * N_KV_HEADS
97
+ + head_id * K_D_HEAD
98
+ + k_dhead_offsets[None, :]
99
+ )
100
+
101
+ v_cache_offset = (
102
+ v_cache_batch_offset
103
+ + kv_writeback_seq_offsets[:, None] * V_D_HEAD * N_KV_HEADS
104
+ + head_id * V_D_HEAD
105
+ + v_dhead_offsets[None, :]
106
+ )
107
+ tl.store(k_cache_ptr + k_cache_offset, ks, k_load_mask)
108
+ tl.store(v_cache_ptr + v_cache_offset, vs, v_load_mask)
109
+
110
+
111
+ @triton.jit
112
+ def gqa_attention_kv_stage1(
113
+ q_ptr, # [Batch, 1, N_HEADS, D_HEAD]
114
+ k_cache_ptr, # [MAX_BATCH_SIZE, MAX_SEQ_LEN, N_HEADS, D_HEAD]
115
+ v_cache_ptr, # [MAX_BATCH_SIZE, MAX_SEQ_LEN, N_HEADS, D_HEAD]
116
+ cache_loc_ptr, # [Batch] # Specifies the batch index for each of the generate tokens.
117
+ input_pos_ptr, # [Batch]
118
+ output_values_ptr, # [Batch, N_HEADS, num_blocks, D_HEAD]
119
+ output_logsumexp_ptr, # [Batch, N_HEADS, num_blocks]
120
+ num_blocks,
121
+ MAX_SEQ_LEN: tl.constexpr, # Maximum supported sequence length
122
+ N_HEADS: tl.constexpr, # Number of heads
123
+ N_KV_HEADS: tl.constexpr, # Number of KV heads.
124
+ Q_D_HEAD: tl.constexpr, # Dimension of each query head.
125
+ V_D_HEAD: tl.constexpr, # Dimension of each key/value head
126
+ SEQ_BLOCK_SIZE: tl.constexpr, # Block size used for tiling the sequence dim.
127
+ HEAD_BLOCK_SIZE: tl.constexpr, # pad to 16 if HEAD_RATIO is < 16 to invoke tensor cores.
128
+ ):
129
+ """Attention kernel to be used for generate-only batches.
130
+
131
+ Specialized for GQA.
132
+
133
+ Assumes that kv caches have been updated.
134
+
135
+ Supports non-power-of-2 D_HEAD
136
+
137
+ Uses flash decoding.
138
+ KV-cache layout is assumed to be [Batch,Seq, Head, Dim]
139
+ 1. Fetch the K-cache from 0 to input_pos
140
+ 2. Fetch the V-cache from 0 to input_pos
141
+ 3. A = Q*K^T [1,D_HEAD] * [1,seq_len,D_HEAD] -> [1, seq_len]
142
+ 4. S = softmax(A)
143
+ 5. O = S*V [1, seq_len] * [1, seq_len, D_HEAD] -> [1, D_HEAD]
144
+ """
145
+ # Assume KV-cache layout: [Batch, Seq, Head, Dim]
146
+ # A program is responsible for 1 batch, 1 head and a block of sequences.
147
+ batch_id = tl.program_id(axis=0)
148
+ kv_head_id = tl.program_id(axis=1)
149
+ seq_block_id = tl.program_id(axis=2)
150
+
151
+ kv_position = tl.load(input_pos_ptr + batch_id)
152
+ kv_batch_id = tl.load(cache_loc_ptr + batch_id)
153
+ K_D_HEAD: tl.constexpr = Q_D_HEAD
154
+ batch_offset = kv_batch_id * N_KV_HEADS * MAX_SEQ_LEN
155
+
156
+ # Offsets for the block of sequences this program processes.
157
+ seq_start_pos = seq_block_id * SEQ_BLOCK_SIZE
158
+
159
+ # The number of Q heads that map to each KV head.
160
+ HEAD_RATIO: tl.constexpr = N_HEADS // N_KV_HEADS # This needs to be a power-of-2
161
+ if seq_start_pos > kv_position:
162
+ return
163
+ seq_offsets = seq_start_pos + tl.arange(0, SEQ_BLOCK_SIZE)
164
+ seq_mask = seq_offsets <= kv_position
165
+
166
+ # Need to pad the head dim to 16 if HEAD_RATIO is < 16 so that tensor cores can be invoked
167
+ #
168
+ head_offsets = kv_head_id * HEAD_RATIO + tl.arange(0, HEAD_BLOCK_SIZE)
169
+ head_mask = head_offsets < (kv_head_id * HEAD_RATIO + HEAD_RATIO)
170
+ # Assuming D_HEAD is a power of 2
171
+ q_dhead_offsets = tl.arange(0, triton.next_power_of_2(Q_D_HEAD))
172
+ q_dhead_mask = q_dhead_offsets < Q_D_HEAD
173
+
174
+ v_dhead_offsets = tl.arange(0, triton.next_power_of_2(V_D_HEAD))
175
+ v_dhead_mask = v_dhead_offsets < V_D_HEAD
176
+
177
+ sm_scale: tl.constexpr = 1.0 / (Q_D_HEAD**0.5)
178
+
179
+ # Program loads the entire Q for the head assigned to it.
180
+ # [NUM_HEADS, Q_D_HEAD]
181
+ q_batch_offset = batch_id * N_HEADS * Q_D_HEAD
182
+ q_head_offsets = head_offsets * Q_D_HEAD
183
+
184
+ # Q layout : BSND
185
+ q = tl.load(
186
+ q_ptr + q_batch_offset + q_head_offsets[:, None] + q_dhead_offsets[None, :],
187
+ mask=head_mask[:, None] * q_dhead_mask[None, :],
188
+ other=0.0,
189
+ )
190
+
191
+ # [BSND]
192
+ k_block_offsets = (
193
+ batch_offset * K_D_HEAD
194
+ + seq_offsets[:, None] * K_D_HEAD * N_KV_HEADS
195
+ + kv_head_id * K_D_HEAD
196
+ + q_dhead_offsets[None, :]
197
+ )
198
+ k_mask = seq_mask[:, None] * q_dhead_mask[None, :] # K and Q share the same head dim
199
+ k = tl.load(k_cache_ptr + k_block_offsets, mask=k_mask, other=0.0)
200
+
201
+ v_block_offsets = (
202
+ batch_offset * V_D_HEAD
203
+ + seq_offsets[:, None] * V_D_HEAD * N_KV_HEADS
204
+ + kv_head_id * V_D_HEAD
205
+ + v_dhead_offsets[None, :]
206
+ )
207
+ v_mask = seq_mask[:, None] * v_dhead_mask[None, :]
208
+
209
+ # [seq_block, V_D_HEAD]
210
+ v = tl.load(v_cache_ptr + v_block_offsets, mask=v_mask, other=0.0)
211
+
212
+ # Note: check the output precision of the sum.
213
+ # compute q*K^T
214
+ # [NUM_HEADS, Q_D_HEAD] * [seq_block, Q_D_HEAD], sum along axis 1
215
+ attn = tl.dot(q, k.trans()) # [N, seq_block]
216
+ attn = attn.to(tl.float32)
217
+ attn *= sm_scale
218
+ max_attn = tl.max(attn, axis=1) # [N, 1]
219
+ # Set to -inf attn values where mask is not set. This forces exp(attn) to 0.
220
+ attn = tl.where(head_mask[:, None] * seq_mask[None, :], attn, float("-inf"))
221
+ exp_attn = tl.exp(attn - max_attn[:, None])
222
+
223
+ sumexp = tl.sum(exp_attn, axis=1) # [N, 1]
224
+
225
+ # [NUM_HEADS, seq_len] * [seq_len, V_D_HEAD], sum along axis 0
226
+ output = tl.dot(exp_attn.to(v.dtype), v)
227
+
228
+ output = output / sumexp[:, None] # [N, D_HEAD]
229
+
230
+ # We store the log-sum-exp after removing the max.
231
+ logsumexp = tl.log(sumexp) + max_attn
232
+ # when seq_mask is all false, max_attn will be -inf and sumexp is zero
233
+
234
+ tl.store(
235
+ output_values_ptr
236
+ + batch_id * N_HEADS * V_D_HEAD * num_blocks
237
+ + head_offsets[:, None] * V_D_HEAD * num_blocks
238
+ + seq_block_id * V_D_HEAD
239
+ + v_dhead_offsets[None, :],
240
+ output,
241
+ mask=head_mask[:, None] * v_dhead_mask[None, :],
242
+ )
243
+ tl.store(
244
+ output_logsumexp_ptr
245
+ + batch_id * N_HEADS * num_blocks
246
+ + head_offsets * num_blocks
247
+ + seq_block_id,
248
+ logsumexp,
249
+ mask=head_mask,
250
+ )
251
+
252
+
253
+ @triton.jit
254
+ def attention_kv_stage1(
255
+ q_ptr, # [Batch, 1, N_HEADS, D_HEAD]
256
+ k_cache_ptr, # [MAX_BATCH_SIZE, MAX_SEQ_LEN, N_HEADS, D_HEAD]
257
+ v_cache_ptr, # [MAX_BATCH_SIZE, MAX_SEQ_LEN, N_HEADS, D_HEAD]
258
+ cache_loc_ptr, # [Batch] # Specifies the batch index for each of the generate tokens.
259
+ input_pos_ptr, # [Batch]
260
+ output_values_ptr, # [Batch, N_HEADS, num_blocks, D_HEAD]
261
+ output_logsumexp_ptr, # [Batch, N_HEADS, num_blocks]
262
+ num_blocks,
263
+ MAX_SEQ_LEN: tl.constexpr, # Maximum supported sequence length
264
+ N_HEADS: tl.constexpr, # Number of heads
265
+ N_KV_HEADS: tl.constexpr, # Number of KV heads.
266
+ D_HEAD: tl.constexpr, # Dimension of each head.
267
+ SEQ_BLOCK_SIZE: tl.constexpr, # Block size used for tiling the sequence dim.
268
+ ):
269
+ """Attention kernel to be used for generate-only batches.
270
+
271
+ Assumes that kv caches have been updated.
272
+
273
+ Uses flash decoding.
274
+ KV-cache layout is assumed to be [Batch,Seq, Head, Dim]
275
+ 1. Fetch the K-cache from 0 to input_pos
276
+ 2. Fetch the V-cache from 0 to input_pos
277
+ 3. A = Q*K^T [1,D_HEAD] * [1,seq_len,D_HEAD] -> [1, seq_len]
278
+ 4. S = softmax(A)
279
+ 5. O = S*V [1, seq_len] * [1, seq_len, D_HEAD] -> [1, D_HEAD]
280
+ """
281
+ # Assume KV-cache layout: [Batch, Seq, Head, Dim]
282
+ # A program is responsible for 1 batch, 1 head and a block of sequences.
283
+ batch_id = tl.program_id(axis=0)
284
+ head_id = tl.program_id(axis=1)
285
+ seq_block_id = tl.program_id(axis=2)
286
+ epsilon: tl.constexpr = 1e-38 # float32 smallest positive number
287
+
288
+ kv_position = tl.load(input_pos_ptr + batch_id)
289
+ kv_batch_id = tl.load(cache_loc_ptr + batch_id)
290
+ kv_batch_offset = kv_batch_id * N_KV_HEADS * MAX_SEQ_LEN * D_HEAD
291
+ # Offsets for the block of sequences this program processes.
292
+ seq_start_pos = seq_block_id * SEQ_BLOCK_SIZE
293
+
294
+ if seq_start_pos > kv_position:
295
+ return
296
+ seq_offsets = seq_start_pos + tl.arange(0, SEQ_BLOCK_SIZE)
297
+ seq_mask = seq_offsets <= kv_position
298
+ # Assuming D_HEAD is a power of 2
299
+ dhead_offsets = tl.arange(0, triton.next_power_of_2(D_HEAD))
300
+ dhead_mask = dhead_offsets < D_HEAD
301
+
302
+ HEAD_RATIO: tl.constexpr = N_HEADS // N_KV_HEADS
303
+ kv_head_offset = (head_id // HEAD_RATIO) * D_HEAD
304
+
305
+ sm_scale: tl.constexpr = 1.0 / (D_HEAD**0.5)
306
+
307
+ # Program loads the entire Q for the head assigned to it.
308
+ # [D_HEAD]
309
+ q_batch_offset = batch_id * N_HEADS * D_HEAD
310
+ q_head_offset = head_id * D_HEAD
311
+ q = tl.load(q_ptr + q_batch_offset + q_head_offset + dhead_offsets, mask=dhead_mask)
312
+
313
+ kv_block_offsets = (
314
+ kv_batch_offset
315
+ + seq_offsets[:, None] * D_HEAD * N_KV_HEADS
316
+ + kv_head_offset
317
+ + dhead_offsets[None, :]
318
+ ) # [BSND]
319
+ kv_mask = seq_mask[:, None] * dhead_mask[None, :]
320
+
321
+ # [seq_block, D_HEAD]
322
+ k = tl.load(k_cache_ptr + kv_block_offsets, mask=kv_mask, other=0.0)
323
+ v = tl.load(v_cache_ptr + kv_block_offsets, mask=kv_mask, other=0.0)
324
+
325
+ # Note: check the output precision of the sum.
326
+ # compute q*K^T
327
+ # [D_HEAD] * [seq_block, D_HEAD], sum along axis 1
328
+ attn = tl.sum(q[None, :].to(tl.float32) * k.to(tl.float32), axis=1) # [seq_block]
329
+
330
+ attn *= sm_scale
331
+ max_attn = tl.max(attn)
332
+ # Set to -inf attn values where mask is not set. This forces exp(attn) to 0.
333
+ attn = tl.where(seq_mask, attn, float("-inf"))
334
+ exp_attn = tl.exp(attn - max_attn)
335
+ exp_attn = tl.where(exp_attn == 0, epsilon, exp_attn)
336
+ sumexp = tl.sum(exp_attn, axis=0) # scalar.
337
+
338
+ # [seq_len] * [seq_len, D_HEAD], sum along axis 0
339
+ output = tl.sum(exp_attn[:, None] * v, axis=0) # [D_HEAD]
340
+
341
+ output = output / sumexp
342
+
343
+ # We store the log-sum-exp after removing the max.
344
+ logsumexp = tl.log(sumexp) + max_attn
345
+ # when seq_mask is all false, max_attn will be -inf and sumexp is zero
346
+
347
+ tl.store(
348
+ output_values_ptr
349
+ + batch_id * N_HEADS * D_HEAD * num_blocks
350
+ + head_id * D_HEAD * num_blocks
351
+ + seq_block_id * D_HEAD
352
+ + dhead_offsets,
353
+ output,
354
+ mask=dhead_mask,
355
+ )
356
+ tl.store(
357
+ output_logsumexp_ptr
358
+ + batch_id * N_HEADS * num_blocks
359
+ + head_id * num_blocks
360
+ + seq_block_id,
361
+ logsumexp,
362
+ )
363
+
364
+
365
+ @triton.jit
366
+ def attention_kv_stage2(
367
+ values_ptr, # [Batch, N_HEADS, num_blocks, D_HEAD]
368
+ logsumexp_ptr, # [Batch, N_HEADS, num_blocks]
369
+ output_ptr, # [Batch, N_HEADS, D_HEAD]
370
+ input_pos_ptr,
371
+ NUM_BLOCKS: tl.constexpr,
372
+ N_HEADS: tl.constexpr,
373
+ D_HEAD: tl.constexpr,
374
+ SEQ_BLOCK_SIZE: tl.constexpr, # Nearest power of 2 for num_blocks
375
+ ):
376
+ # There are batch * N_HEADS programs
377
+ batch_id = tl.program_id(axis=0)
378
+ head_id = tl.program_id(axis=1)
379
+
380
+ dhead_offsets = tl.arange(0, triton.next_power_of_2(D_HEAD))
381
+ dhead_mask = dhead_offsets < D_HEAD
382
+
383
+ kv_position = tl.load(input_pos_ptr + batch_id)
384
+ block_id = kv_position // SEQ_BLOCK_SIZE + 1
385
+
386
+ NUM_BLOCKS_POW2: tl.constexpr = triton.next_power_of_2(NUM_BLOCKS)
387
+ block_offsets = tl.arange(0, NUM_BLOCKS_POW2)
388
+
389
+ block_mask = block_offsets < block_id
390
+ logsumexp = tl.load(
391
+ logsumexp_ptr + batch_id * N_HEADS * NUM_BLOCKS + head_id * NUM_BLOCKS + block_offsets,
392
+ mask=block_mask,
393
+ other=float("-inf"),
394
+ )
395
+ max_logsumexp = tl.max(logsumexp)
396
+ sumexp = tl.exp(logsumexp - max_logsumexp) # [NUM_BLOCKS_POW2]
397
+
398
+ aggregate_sumexp = tl.sum(sumexp, axis=0)
399
+
400
+ values_offsets = block_offsets[:, None] * D_HEAD + dhead_offsets[None, :]
401
+ values_mask = block_mask[:, None] * dhead_mask[None, :]
402
+
403
+ values = tl.load(
404
+ values_ptr
405
+ + batch_id * N_HEADS * D_HEAD * NUM_BLOCKS
406
+ + head_id * D_HEAD * NUM_BLOCKS
407
+ + values_offsets,
408
+ mask=values_mask,
409
+ other=0.0,
410
+ ) # [BLOCK_SIZE, D_HEAD]
411
+ values *= sumexp[:, None]
412
+ values /= aggregate_sumexp
413
+
414
+ output = tl.sum(values, axis=0) # [DHEAD]
415
+
416
+ tl.store(
417
+ output_ptr + batch_id * N_HEADS * D_HEAD + head_id * D_HEAD + dhead_offsets,
418
+ output,
419
+ mask=dhead_mask,
420
+ )
421
+
422
+
423
+ @triton.jit
424
+ def context_attention_kv(
425
+ q_ptr, # [bsnd]
426
+ k_ptr, # [bsnd]
427
+ v_ptr, # [bsnd]
428
+ k_cache_ptr, # [bsnd]
429
+ v_cache_ptr, # [bsnd]
430
+ seq_len,
431
+ o_ptr,
432
+ softmax_scale,
433
+ N_HEADS: tl.constexpr, # Number of heads
434
+ N_KV_HEADS: tl.constexpr, # Number of KV heads.
435
+ Q_D_HEAD: tl.constexpr, # Dimension of each query head.
436
+ V_D_HEAD: tl.constexpr, # Dimension of each value head.
437
+ SEQ_BLOCK: tl.constexpr,
438
+ MAX_SEQ_LENGTH: tl.constexpr,
439
+ ):
440
+ """Kernel for context phase.
441
+
442
+ Assuming:
443
+ 1. Self-attention [seqlen(Q) == seqlen(K)]
444
+ 2. Causal attention
445
+ 3. QKV layout: [bsnd]
446
+ """
447
+ batch_id = tl.program_id(axis=0)
448
+ head_id = tl.program_id(axis=1)
449
+ seq_block_id = tl.program_id(axis=2)
450
+
451
+ HEAD_RATIO: tl.constexpr = N_HEADS // N_KV_HEADS
452
+ K_D_HEAD: tl.constexpr = Q_D_HEAD
453
+
454
+ q_dhead_offsets = tl.arange(0, triton.next_power_of_2(Q_D_HEAD))
455
+ q_dhead_mask = q_dhead_offsets < Q_D_HEAD
456
+
457
+ v_dhead_offsets = tl.arange(0, triton.next_power_of_2(V_D_HEAD))
458
+ v_dhead_mask = v_dhead_offsets < V_D_HEAD
459
+
460
+ seq_offsets = seq_block_id * SEQ_BLOCK + tl.arange(0, SEQ_BLOCK)
461
+ seq_mask = seq_offsets < seq_len
462
+
463
+ q_load_mask = seq_mask[:, None] * q_dhead_mask[None, :]
464
+
465
+ q_batch_offset = batch_id * seq_len * N_HEADS
466
+ kv_batch_offset = batch_id * seq_len * N_KV_HEADS
467
+
468
+ k_head_offset = (head_id // HEAD_RATIO) * K_D_HEAD
469
+ v_head_offset = (head_id // HEAD_RATIO) * V_D_HEAD
470
+
471
+ # Q will stay in SRAM
472
+ q = tl.load(
473
+ q_ptr
474
+ + q_batch_offset * Q_D_HEAD
475
+ + seq_offsets[:, None] * N_HEADS * Q_D_HEAD
476
+ + head_id * Q_D_HEAD
477
+ + q_dhead_offsets[None, :],
478
+ mask=q_load_mask,
479
+ )
480
+ acc = tl.zeros([SEQ_BLOCK, triton.next_power_of_2(V_D_HEAD)], dtype=tl.float32)
481
+ lse_i = tl.zeros([SEQ_BLOCK], dtype=tl.float32) - float("inf")
482
+ m_i = tl.zeros([SEQ_BLOCK], dtype=tl.float32) - float("inf")
483
+
484
+ for s in range(0, seq_block_id + 1, 1):
485
+ kv_seq_offsets = s * SEQ_BLOCK + tl.arange(0, SEQ_BLOCK)
486
+ kv_seq_mask = kv_seq_offsets < seq_len
487
+ k_load_mask = kv_seq_mask[:, None] * q_dhead_mask[None, :]
488
+
489
+ k = tl.load(
490
+ k_ptr
491
+ + kv_batch_offset * K_D_HEAD
492
+ + kv_seq_offsets[:, None] * N_KV_HEADS * K_D_HEAD
493
+ + k_head_offset
494
+ + q_dhead_offsets[None, :],
495
+ mask=k_load_mask,
496
+ )
497
+ qk = tl.zeros([SEQ_BLOCK, SEQ_BLOCK], dtype=tl.float32)
498
+ qk += tl.dot(q, k.trans())
499
+ # causal mask
500
+ qk = tl.where(seq_offsets[:, None] >= kv_seq_offsets[None, :], qk, float("-inf"))
501
+ qk *= softmax_scale
502
+ # rowmax
503
+ m_ij = tl.maximum(tl.max(qk, 1), lse_i)
504
+ p = tl.exp(qk - m_ij[:, None]) # [S,S]
505
+ v = tl.load(
506
+ v_ptr
507
+ + kv_batch_offset * V_D_HEAD
508
+ + kv_seq_offsets[:, None] * N_KV_HEADS * V_D_HEAD
509
+ + v_head_offset
510
+ + v_dhead_offsets[None, :],
511
+ mask=kv_seq_mask[:, None] * v_dhead_mask[None, :],
512
+ )
513
+
514
+ l_ij = tl.sum(p, 1)
515
+ acc_scale = tl.exp(m_i - m_ij)
516
+ acc = acc * acc_scale[:, None]
517
+ p = p.to(v.dtype)
518
+ acc += tl.dot(p, v)
519
+ m_i = m_ij
520
+ l_i_new = tl.exp(lse_i - m_ij) + l_ij
521
+ lse_i = m_ij + tl.log(l_i_new)
522
+
523
+ o_scale = tl.exp(m_i - lse_i)
524
+
525
+ acc = acc * o_scale[:, None]
526
+
527
+ tl.store(
528
+ o_ptr
529
+ + batch_id * seq_len * N_HEADS * V_D_HEAD
530
+ + seq_offsets[:, None] * N_HEADS * V_D_HEAD
531
+ + head_id * V_D_HEAD
532
+ + v_dhead_offsets[None, :],
533
+ acc,
534
+ mask=seq_mask[:, None] * v_dhead_mask[None, :],
535
+ )
536
+
537
+ # Write back to kv-caches
538
+
539
+ ks = tl.load(
540
+ k_ptr
541
+ + kv_batch_offset * K_D_HEAD
542
+ + seq_offsets[:, None] * N_KV_HEADS * K_D_HEAD
543
+ + k_head_offset
544
+ + q_dhead_offsets[None, :],
545
+ mask=seq_mask[:, None] * q_dhead_mask[None, :],
546
+ )
547
+ vs = tl.load(
548
+ v_ptr
549
+ + kv_batch_offset * V_D_HEAD
550
+ + seq_offsets[:, None] * N_KV_HEADS * V_D_HEAD
551
+ + v_head_offset
552
+ + v_dhead_offsets[None, :],
553
+ mask=seq_mask[:, None] * v_dhead_mask[None, :],
554
+ )
555
+ # cache is [bsnd]
556
+ k_cache_offset = (
557
+ batch_id * N_KV_HEADS * MAX_SEQ_LENGTH * K_D_HEAD
558
+ + seq_offsets[:, None] * K_D_HEAD * N_KV_HEADS
559
+ + k_head_offset
560
+ + q_dhead_offsets[None, :]
561
+ )
562
+
563
+ v_cache_offset = (
564
+ batch_id * N_KV_HEADS * MAX_SEQ_LENGTH * V_D_HEAD
565
+ + seq_offsets[:, None] * V_D_HEAD * N_KV_HEADS
566
+ + v_head_offset
567
+ + v_dhead_offsets[None, :]
568
+ )
569
+ tl.store(k_cache_ptr + k_cache_offset, ks, seq_mask[:, None] * q_dhead_mask[None, :])
570
+ tl.store(v_cache_ptr + v_cache_offset, vs, seq_mask[:, None] * v_dhead_mask[None, :])
571
+
572
+
573
+ @triton.jit
574
+ def context_attention_kv_flattened(
575
+ q_ptr, # [b*s,nd]
576
+ seq_len_ptr, # [b] # length of each sequence in a batch
577
+ seq_start_indices_ptr, # [b] # start indices of a sequence in flattened q/k/v.
578
+ k_cache_ptr, # [bsnd]
579
+ v_cache_ptr, # [bsnd]
580
+ input_pos_ptr, # [b] # specifies the location in the sequence where kv must be written back.
581
+ cache_loc_ptr, # [b] # location of the sequence in the cache.
582
+ o_ptr,
583
+ softmax_scale: tl.constexpr,
584
+ N_HEADS: tl.constexpr, # Number of heads
585
+ N_KV_HEADS: tl.constexpr, # Number of KV heads.
586
+ Q_D_HEAD: tl.constexpr, # Dimension of each query head.
587
+ V_D_HEAD: tl.constexpr, # Dimension of each value head.
588
+ SEQ_BLOCK: tl.constexpr,
589
+ MAX_SEQ_LENGTH: tl.constexpr,
590
+ ):
591
+ """Kernel for context phase.
592
+
593
+ Assumes that kv caches have been updated.
594
+ Assuming QKV layout: [b*s,n,d]
595
+ """
596
+ batch_id = tl.program_id(axis=0)
597
+ head_id = tl.program_id(axis=1)
598
+ seq_block_id = tl.program_id(axis=2)
599
+
600
+ # Each program is responsible for a block of tokens in a single batch.
601
+ seq_start_index = tl.load(seq_start_indices_ptr + batch_id)
602
+ seq_len = tl.load(seq_len_ptr + batch_id)
603
+ K_D_HEAD: tl.constexpr = Q_D_HEAD
604
+ HEAD_RATIO: tl.constexpr = N_HEADS // N_KV_HEADS
605
+
606
+ # cache is [bsnd]
607
+ # cache_loc_ptr stores the batch index for the sequences provided to the kernel.
608
+ cache_loc = tl.load(cache_loc_ptr + batch_id)
609
+
610
+ cache_batch_offset = cache_loc * N_KV_HEADS * MAX_SEQ_LENGTH
611
+ cache_head_offset = head_id // HEAD_RATIO
612
+
613
+ q_dhead_offsets = tl.arange(0, triton.next_power_of_2(Q_D_HEAD))
614
+ q_dhead_mask = q_dhead_offsets < Q_D_HEAD
615
+
616
+ v_dhead_offsets = tl.arange(0, triton.next_power_of_2(V_D_HEAD))
617
+ v_dhead_mask = v_dhead_offsets < V_D_HEAD
618
+
619
+ seq_offsets = seq_block_id * SEQ_BLOCK + tl.arange(0, SEQ_BLOCK)
620
+ seq_mask = seq_offsets < seq_len
621
+
622
+ # Q will stay in SRAM
623
+ q = tl.load(
624
+ q_ptr
625
+ + seq_start_index * N_HEADS * Q_D_HEAD
626
+ + seq_offsets[:, None] * N_HEADS * Q_D_HEAD
627
+ + head_id * Q_D_HEAD
628
+ + q_dhead_offsets[None, :],
629
+ mask=seq_mask[:, None] * q_dhead_mask[None, :],
630
+ )
631
+
632
+ acc = tl.zeros([SEQ_BLOCK, triton.next_power_of_2(V_D_HEAD)], dtype=tl.float32)
633
+ lse_i = tl.zeros([SEQ_BLOCK], dtype=tl.float32) - float("inf")
634
+ m_i = tl.zeros([SEQ_BLOCK], dtype=tl.float32) - float("inf")
635
+
636
+ # Loop over the entire KV-history
637
+ # input_pos_ptr stores the location at which kv must be written back for the given batch.
638
+ kv_position = tl.load(input_pos_ptr + batch_id)
639
+ num_blocks = (kv_position + seq_len + SEQ_BLOCK - 1) // SEQ_BLOCK
640
+ for s in range(0, num_blocks + 1, 1):
641
+ kv_seq_offsets = s * SEQ_BLOCK + tl.arange(0, SEQ_BLOCK)
642
+ kv_seq_mask = kv_seq_offsets < (kv_position + seq_len)
643
+
644
+ k = tl.load(
645
+ k_cache_ptr
646
+ + cache_batch_offset * K_D_HEAD
647
+ + kv_seq_offsets[:, None] * K_D_HEAD * N_KV_HEADS
648
+ + cache_head_offset * K_D_HEAD
649
+ + q_dhead_offsets[None, :],
650
+ mask=kv_seq_mask[:, None] * q_dhead_mask[None, :],
651
+ )
652
+ qk = tl.zeros([SEQ_BLOCK, SEQ_BLOCK], dtype=tl.float32)
653
+ qk += tl.dot(q, k.trans())
654
+ qk = tl.where(
655
+ (seq_offsets[:, None] + kv_position) >= kv_seq_offsets[None, :], qk, float("-inf")
656
+ )
657
+ qk *= softmax_scale
658
+ # rowmax
659
+ m_ij = tl.maximum(tl.max(qk, 1), lse_i)
660
+ p = tl.exp(qk - m_ij[:, None])
661
+ v = tl.load(
662
+ v_cache_ptr
663
+ + cache_batch_offset * V_D_HEAD
664
+ + kv_seq_offsets[:, None] * V_D_HEAD * N_KV_HEADS
665
+ + cache_head_offset * V_D_HEAD
666
+ + v_dhead_offsets[None, :],
667
+ mask=kv_seq_mask[:, None] * v_dhead_mask[None, :],
668
+ )
669
+
670
+ l_ij = tl.sum(p, 1)
671
+ acc_scale = tl.exp(m_i - m_ij)
672
+ acc = acc * acc_scale[:, None]
673
+ p = p.to(v.dtype)
674
+ acc += tl.dot(p, v)
675
+ m_i = m_ij
676
+ l_i_new = tl.exp(lse_i - m_ij) + l_ij
677
+ lse_i = m_ij + tl.log(l_i_new)
678
+
679
+ o_scale = tl.exp(m_i - lse_i)
680
+
681
+ acc = acc * o_scale[:, None]
682
+
683
+ tl.store(
684
+ o_ptr
685
+ + seq_start_index * N_HEADS * V_D_HEAD
686
+ + seq_offsets[:, None] * N_HEADS * V_D_HEAD
687
+ + head_id * V_D_HEAD
688
+ + v_dhead_offsets[None, :],
689
+ acc,
690
+ mask=seq_mask[:, None] * v_dhead_mask[None, :],
691
+ )
692
+
693
+
694
@triton.jit
def update_kv_cache_rope_fusion(
    q_ptr,  # [B*S, N, D] flattened queries (pre-rope)
    k_ptr,  # [B*S, N_KV, D] flattened keys (pre-rope)
    v_ptr,  # [B*S, N_KV, D] flattened values
    seq_len_ptr,  # [b] # length of each sequence in a batch
    seq_start_indices_ptr,  # [b] # start indices of a sequence in flattened q/k/v.
    q_rope_ptr,  # [B*S, N, D], roped q result
    k_cache_ptr,  # [MAX_BATCH_SIZE, MAX_SEQ_LEN, N_KV_HEADS, D_HEAD] — NOTE(review): original comment said N_HEADS, but offsets below stride by N_KV_HEADS; confirm layout
    v_cache_ptr,  # [MAX_BATCH_SIZE, MAX_SEQ_LEN, N_KV_HEADS, D_HEAD] — same note as k_cache
    input_pos_ptr,  # Specifies the sequence index in the caches at which to write the provided kv
    cache_loc_ptr,  # Specifies the batch index for each of the input sequences
    f_ptr,  # [MAX_SEQ_LEN, D_HEAD//2, 2] # interleaved (cos, sin) frequencies for rope embedding.
    MAX_SEQ_LENGTH: tl.constexpr,
    N_HEADS: tl.constexpr,
    N_KV_HEADS: tl.constexpr,
    D_HEAD: tl.constexpr,
    SEQ_BLOCK: tl.constexpr,
    HEAD_BLOCK_SIZE: tl.constexpr,  # pad to 16 if HEAD_RATIO is < 16 to invoke tensor cores.
    GENERATE_ONLY: tl.constexpr,
):
    """Fuse q and k rope with the update_kv_cache kernel.

    The input is interleaved as [2, D//2] in the D_HEAD dim.
    - Writes q_rope with the post-rope-embedding q values.
    - Writes k_cache with the post-rope-embedding k values.
    - Writes v_cache with v unchanged.
    For rope computation, q and k are loaded and stored as pairs of [D//2] halves.

    Grid: (batch, kv_head, seq_block). Each program handles one kv head and a
    SEQ_BLOCK-sized slice of tokens for one sequence; the HEAD_RATIO query heads
    mapped to that kv head are processed together (GQA).
    """
    batch_id = tl.program_id(axis=0)
    kv_head_id = tl.program_id(axis=1)
    seq_block_id = tl.program_id(axis=2)

    # Each program is responsible for a block of tokens in a single batch.
    # In generate-only mode every sequence has exactly one token, so the flattened
    # start index is just the batch index and seq_len is a compile-time 1.
    if GENERATE_ONLY:
        seq_start_index = batch_id
        seq_len: tl.constexpr = 1
    else:
        seq_start_index = tl.load(seq_start_indices_ptr + batch_id)
        seq_len = tl.load(seq_len_ptr + batch_id)

    # cache is [bsnd]
    # cache_loc_ptr stores the batch index for the sequences provided to the kernel.
    cache_loc = tl.load(cache_loc_ptr + batch_id)

    # Sequence position in the cache at which this sequence's new tokens are written.
    kv_position = tl.load(input_pos_ptr + batch_id)

    cache_batch_offset = cache_loc * N_KV_HEADS * MAX_SEQ_LENGTH * D_HEAD
    cache_head_offset = kv_head_id * D_HEAD

    # Assuming D_HEAD is a power of 2
    dhead_offsets = tl.arange(0, D_HEAD)
    dhead_mask = dhead_offsets < D_HEAD

    seq_offsets = seq_block_id * SEQ_BLOCK + tl.arange(0, SEQ_BLOCK)
    seq_mask = seq_offsets < seq_len

    load_mask = seq_mask[:, None] * dhead_mask[None, :]

    HEAD_RATIO: tl.constexpr = N_HEADS // N_KV_HEADS  # This needs to be a power-of-2
    # Query heads served by this kv head; HEAD_BLOCK_SIZE may be padded above
    # HEAD_RATIO (see parameter comment), so mask off the padding.
    q_head_offsets = kv_head_id * HEAD_RATIO + tl.arange(0, HEAD_BLOCK_SIZE)
    q_head_mask = q_head_offsets < (kv_head_id * HEAD_RATIO + HEAD_RATIO)

    q_batch_offset = seq_start_index * N_HEADS * D_HEAD

    kv_batch_offset = seq_start_index * N_KV_HEADS * D_HEAD
    kv_head_offset = cache_head_offset

    D2: tl.constexpr = D_HEAD // 2
    # input is interleaved as [2, D//2] in dim [D_HEAD]: first half pairs with second half.
    d2_offsets = tl.arange(0, D2)
    dhead_offsets1 = d2_offsets
    dhead_offsets2 = d2_offsets + D2
    d2_mask = dhead_offsets2 < D_HEAD
    d2_load_mask = seq_mask[:, None] * d2_mask[None, :]

    # offsets of [bsn]: 3-D offsets [token, q_head, dim-half] for the query halves.
    q_offsets_base = (
        q_batch_offset
        + seq_offsets[:, None, None] * N_HEADS * D_HEAD
        + q_head_offsets[None, :, None] * D_HEAD
    )
    q_offsets1 = q_offsets_base + dhead_offsets1[None, None, :]
    q_offsets2 = q_offsets_base + dhead_offsets2[None, None, :]
    q_mask = d2_load_mask[:, None, :] * q_head_mask[None, :, None]

    # Compute rope in fp32 for accuracy regardless of the storage dtype.
    q1 = tl.load(q_ptr + q_offsets1, mask=q_mask).to(tl.float32)
    q2 = tl.load(q_ptr + q_offsets2, mask=q_mask).to(tl.float32)

    k_offsets_base = kv_batch_offset + seq_offsets[:, None] * N_KV_HEADS * D_HEAD + kv_head_offset
    k_offsets1 = k_offsets_base + dhead_offsets1[None, :]
    k_offsets2 = k_offsets_base + dhead_offsets2[None, :]

    k1 = tl.load(k_ptr + k_offsets1, mask=d2_load_mask).to(tl.float32)
    k2 = tl.load(k_ptr + k_offsets2, mask=d2_load_mask).to(tl.float32)

    # -----------------------------------
    # torch version sin/cos
    # cos and sin values are interleaved in the frequencies tensor, so stride by 2.
    # Frequencies are indexed by absolute cache position (kv_position + local offset).
    f_offsets = seq_offsets[:, None] * D2 + d2_offsets[None, :]
    cos_ref = tl.load(f_ptr + kv_position * D_HEAD + f_offsets * 2, mask=d2_load_mask).to(
        dtype=tl.float32
    )
    sin_ref = tl.load(f_ptr + kv_position * D_HEAD + f_offsets * 2 + 1, mask=d2_load_mask).to(
        dtype=tl.float32
    )

    # Standard rotary rotation applied to the two query halves.
    qs1 = cos_ref[:, None, :] * q1 - sin_ref[:, None, :] * q2
    qs2 = sin_ref[:, None, :] * q1 + cos_ref[:, None, :] * q2

    tl.store(q_rope_ptr + q_offsets1, qs1, mask=q_mask)
    tl.store(q_rope_ptr + q_offsets2, qs2, mask=q_mask)

    # Same rotation for the key halves (no extra head dim: one kv head per program).
    ks1 = cos_ref * k1 - sin_ref * k2
    ks2 = sin_ref * k1 + cos_ref * k2

    # Write back to kv-caches
    vs = tl.load(
        v_ptr
        + kv_batch_offset
        + seq_offsets[:, None] * N_KV_HEADS * D_HEAD
        + kv_head_offset
        + dhead_offsets[None, :],
        mask=load_mask,
    )

    # Shift local token offsets to absolute positions inside the cache.
    kv_writeback_seq_offsets = seq_offsets + kv_position

    cache_offset_base = (
        cache_batch_offset
        + kv_writeback_seq_offsets[:, None] * D_HEAD * N_KV_HEADS
        + cache_head_offset
    )

    k_cache_offset1 = cache_offset_base + dhead_offsets1[None, :]
    k_cache_offset2 = cache_offset_base + dhead_offsets2[None, :]
    tl.store(k_cache_ptr + k_cache_offset1, ks1, mask=d2_load_mask)
    tl.store(k_cache_ptr + k_cache_offset2, ks2, mask=d2_load_mask)

    v_cache_offset = cache_offset_base + dhead_offsets[None, :]
    tl.store(v_cache_ptr + v_cache_offset, vs, load_mask)
835
+
836
+
837
+
838
+ """
839
+ Kernels based on paged KV Cache.
840
+ Parameter infos:
841
+ tensors:
842
+ - q: [b*s, n, d], flattened queries.
843
+ - k/v: [b*s, n, d], flattened key/value.
844
+ - seq_len: [b], length of each sequence in the batch.
845
+ `seq_len` can be 1 (generate) or larger (context).
846
+ - seq_start: [b], start index of each sequence in b*s dim of q/k/v.
847
+ - k_cache/v_cache: [num_pages, PAGE_SIZE, n, d], paged KV Cache.
848
+ New-coming k/v is split into small group of PAGE_SIZE, and then
849
+ mapped to non-contiguous memory in KV Cache.
850
+ - page_table: [b, max_num_pages_per_seq], mapping logic of each sequence.
851
+ - cache_loc: [b], mapping logic of `batch_id` in q/k/v to index in `page_table`.
852
+ - cache_len: [b], existing cached k/v length of each sequence.
853
+
854
+ constexpr:
855
+ - N_HEADS/N_KV_HEADS: shape of dim [n] in q or k/v.
856
+ - D_HEAD: shape of dim [d] in q/k/v.
857
+ Assuming power of 2.
858
+ - SEQ_BLOCK: block size to split dim [s].
859
+ Assuming power of 2.
860
+ Split k/v in update kernel and split q in context/generate kernel.
861
+ - MAX_SEQ_LENGTH: seq_len <= MAX_SEQ_LENGTH.
862
+ - PAGE_SIZE: shape of each kv cache page,
863
+ Assuming power of 2 and SEQ_BLOCK % PAGE_SIZE = 0.
864
+ - PAGE_TABLE_STRIDE: stride of dim [b] in `page_table`.
865
+
866
+ KV Cache access logic in update kernel:
867
+ 1. batch_id i access k[seq_start[i] : seq_start[i] + seq_len[i]]
868
+ and can be split into pages [a:b] in the sequence.
869
+ 2. Look up cache_len[i] to find if the sequence has cached k/v.
870
+ 3. Look up page_table[cache_loc[i], cache_len[i] + a : cache_len[i] + b]
871
+ to get the corresponding pages in the k_cache, with result [c:d].
872
+ 4. Then update k_cache[c:d] with the k value.
873
+
874
+ """
875
+
876
+
877
@triton.jit
def update_paged_kv_cache(
    k_ptr,  # [B*S, N, D]
    v_ptr,  # [B*S, N, D]
    seq_len_ptr,  # [b] # length of each sequence in a batch
    seq_start_indices_ptr,  # [b] # start indices of a sequence in flattened q/k/v.
    k_cache_ptr,  # [num_pages, page_size, n, d]
    v_cache_ptr,  # [num_pages, page_size, n, d]
    cache_loc_ptr,  # [b] # index of the sequence in the page table.
    cache_len_ptr,  # [b] # length of the sequence already in kv cache.
    page_table_ptr,  # [b, max_num_pages_per_seq] # loc of the block page in the cache.
    N_KV_HEADS: tl.constexpr,  # Number of KV heads.
    D_HEAD: tl.constexpr,  # Dimension of each head.
    SEQ_BLOCK: tl.constexpr,
    MAX_SEQ_LENGTH: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    PAGE_TABLE_STRIDE: tl.constexpr,
    GENERATE_ONLY: tl.constexpr,
):
    """Append the incoming k/v tokens of each sequence into the paged KV cache.

    Grid: (batch, kv_head, seq_block). Each program copies one SEQ_BLOCK slice of
    tokens for one kv head into the physical pages resolved via `page_table`.
    See the module-level note above for the page lookup logic.
    """
    batch_id = tl.program_id(axis=0)
    head_id = tl.program_id(axis=1)
    seq_block_id = tl.program_id(axis=2)

    # Each program is responsible for a block of tokens in a single batch.
    # Generate-only batches have exactly one token per sequence.
    if GENERATE_ONLY:
        seq_start_index = batch_id
        seq_len: tl.constexpr = 1
    else:
        seq_start_index = tl.load(seq_start_indices_ptr + batch_id)
        seq_len = tl.load(seq_len_ptr + batch_id)

    cache_len = tl.load(cache_len_ptr + batch_id)

    # cache is [num_pages, page_size, n, d]
    # cache_loc_ptr stores the batch index for the sequences provided to the kernel.
    cache_loc = tl.load(cache_loc_ptr + batch_id)
    cache_head_offset = head_id * D_HEAD

    # Assuming D_HEAD is a power of 2
    dhead_offsets = tl.arange(0, D_HEAD)
    dhead_mask = dhead_offsets < D_HEAD

    seq_offsets = seq_block_id * SEQ_BLOCK + tl.arange(0, SEQ_BLOCK)
    seq_mask = seq_offsets < seq_len

    load_mask = seq_mask[:, None] * dhead_mask[None, :]

    kv_batch_offset = seq_start_index * N_KV_HEADS * D_HEAD
    kv_head_offset = cache_head_offset

    # Stage the incoming k/v tiles before the paged writeback.
    ks = tl.load(
        k_ptr
        + kv_batch_offset
        + seq_offsets[:, None] * N_KV_HEADS * D_HEAD
        + kv_head_offset
        + dhead_offsets[None, :],
        mask=load_mask,
    )
    vs = tl.load(
        v_ptr
        + kv_batch_offset
        + seq_offsets[:, None] * N_KV_HEADS * D_HEAD
        + kv_head_offset
        + dhead_offsets[None, :],
        mask=load_mask,
    )

    # assuming SEQ_BLOCK can be divided by PAGE_SIZE and PAGE_SIZE is a power of 2.
    SEQ_BLOCK_PAGE: tl.constexpr = SEQ_BLOCK // PAGE_SIZE
    MAX_NUM_PAGES: tl.constexpr = (MAX_SEQ_LENGTH + PAGE_SIZE - 1) // PAGE_SIZE
    # cache_len // PAGE_SIZE means history pages
    # if decode sequence, then seq_len = 1 and only seq_block_id = 0 works,
    kv_pages = seq_block_id * SEQ_BLOCK_PAGE + tl.arange(0, SEQ_BLOCK_PAGE) + cache_len // PAGE_SIZE
    cache_pages = tl.load(
        page_table_ptr + cache_loc * PAGE_TABLE_STRIDE + kv_pages, mask=kv_pages < MAX_NUM_PAGES
    )

    page_offsets = tl.arange(0, PAGE_SIZE)
    # shape [SEQ_BLOCK], means [cache_pages, page_offsets]
    cache_seq_offset = tl.reshape(
        cache_pages[:, None] * PAGE_SIZE + page_offsets[None, :], [SEQ_BLOCK]
    )
    # write offset inside the page
    # NOTE(review): this adds the same in-page shift to every token, which is only
    # correct when the write does not straddle a page boundary (e.g. context phase
    # with cache_len == 0, or decode with seq_len == 1) — confirm callers guarantee this.
    cache_seq_offset += cache_len % PAGE_SIZE

    cache_offsets = (
        cache_seq_offset[:, None] * N_KV_HEADS * D_HEAD + kv_head_offset + dhead_offsets[None, :]
    )
    tl.store(k_cache_ptr + cache_offsets, ks, load_mask)
    tl.store(v_cache_ptr + cache_offsets, vs, load_mask)
968
+
969
+
970
# Two-stage (flash-decoding) generate attention: stage 1 computes a partial
# softmax-attention output plus its log-sum-exp per sequence block; a second
# stage (not in this kernel) combines the per-block partials using the
# logsumexp values to produce the final output.
@triton.jit
def attention_kv_paged_stage1(
    q_ptr,  # [Batch, 1, N_HEADS, D_HEAD]
    k_cache_ptr,  # [NUM_PAGES, PAGE_SIZE, N_HEADS, D_HEAD]
    v_cache_ptr,  # [NUM_PAGES, PAGE_SIZE, N_HEADS, D_HEAD]
    cache_loc_ptr,  # [Batch] # Specifies the batch index for each of the generate tokens.
    page_table_ptr,  # [Batch, num_pages_per_seq]
    cache_len_ptr,  # [Batch] # Number of tokens in kv cache.
    output_values_ptr,  # [Batch, N_HEADS, num_blocks, D_HEAD] partial outputs
    output_logsumexp_ptr,  # [Batch, N_HEADS, num_blocks] per-block logsumexp
    num_blocks,
    MAX_SEQ_LEN: tl.constexpr,  # Maximum supported sequence length
    N_HEADS: tl.constexpr,  # Number of heads
    N_KV_HEADS: tl.constexpr,  # Number of KV heads.
    D_HEAD: tl.constexpr,  # Dimension of each head.
    # Block size used for tiling the sequence dim.
    SEQ_BLOCK_SIZE: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    PAGE_TABLE_STRIDE: tl.constexpr,
):
    """Attention kernel to be used during the generate phase.

    Uses flash decoding.
    1. Fetch the K-cache pages covering this sequence block
    2. Fetch the V-cache pages covering this sequence block
    3. A = Q*K^T [1,D_HEAD] * [1,seq_len,D_HEAD] -> [1, seq_len]
    4. S = softmax(A)
    5. O = S*V [1, seq_len] * [1, seq_len, D_HEAD] -> [1, D_HEAD]
    Partial O and logsumexp per block are written out for stage-2 reduction.
    """
    # A program is responsible for 1 batch, 1 head and a block of sequences.
    batch_id = tl.program_id(axis=0)
    head_id = tl.program_id(axis=1)
    seq_block_id = tl.program_id(axis=2)

    SEQ_BLOCK_PAGE: tl.constexpr = SEQ_BLOCK_SIZE // PAGE_SIZE
    MAX_NUM_PAGES: tl.constexpr = MAX_SEQ_LEN // PAGE_SIZE

    cache_loc = tl.load(cache_loc_ptr + batch_id)
    seq_len = tl.load(cache_len_ptr + batch_id)
    # Offsets for the block of sequences this program processes.
    seq_start_pos = seq_block_id * SEQ_BLOCK_SIZE

    # Blocks entirely beyond the cached length have nothing to do.
    if seq_start_pos > seq_len:
        return
    seq_offsets = seq_start_pos + tl.arange(0, SEQ_BLOCK_SIZE)
    # `<=` (not `<`): position seq_len itself is attended as well — presumably the
    # just-written current token; NOTE(review) confirm against the cache-update caller.
    seq_mask = seq_offsets <= seq_len
    # Assuming D_HEAD is a power of 2
    dhead_offsets = tl.arange(0, D_HEAD)
    dhead_mask = dhead_offsets < D_HEAD

    # GQA: HEAD_RATIO query heads share one kv head.
    HEAD_RATIO: tl.constexpr = N_HEADS // N_KV_HEADS
    cache_head_offset = (head_id // HEAD_RATIO) * D_HEAD

    sm_scale: tl.constexpr = 1 / (D_HEAD**0.5)

    # Program loads the entire Q for the head assigned to it.
    # [D_HEAD]
    q_batch_offset = batch_id * N_HEADS * D_HEAD
    q_head_offset = head_id * D_HEAD
    q = tl.load(q_ptr + q_batch_offset + q_head_offset + dhead_offsets)

    kv_mask = seq_mask[:, None] * dhead_mask[None, :]

    # Resolve the physical pages holding this logical sequence block.
    kv_pages = seq_block_id * SEQ_BLOCK_PAGE + tl.arange(0, SEQ_BLOCK_PAGE)
    cache_pages = tl.load(
        page_table_ptr + cache_loc * PAGE_TABLE_STRIDE + kv_pages, mask=kv_pages < MAX_NUM_PAGES
    )

    page_offsets = tl.arange(0, PAGE_SIZE)
    # shape [SEQ_BLOCK], means [cache_pages, page_offsets]
    # token offsets in the paged kv cache
    cache_seq_offset = tl.reshape(
        cache_pages[:, None] * PAGE_SIZE + page_offsets[None, :], [SEQ_BLOCK_SIZE]
    )

    cache_offsets = (
        cache_seq_offset[:, None] * N_KV_HEADS * D_HEAD + cache_head_offset + dhead_offsets[None, :]
    )

    k = tl.load(k_cache_ptr + cache_offsets, mask=kv_mask)
    v = tl.load(v_cache_ptr + cache_offsets, mask=kv_mask)

    # Note: check the output precision of the sum.
    # compute q*K^T
    # [D_HEAD] * [seq_block, D_HEAD], sum along axis 1
    attn = tl.sum(q[None, :] * k, axis=1)  # [seq_block]
    attn = attn.to(tl.float32)
    attn *= sm_scale
    max_attn = tl.max(attn)
    # Set to -inf attn values where mask is not set. This forces exp(attn) to 0.
    attn = tl.where(seq_mask, attn, float("-inf"))
    exp_attn = tl.exp(attn - max_attn)

    sumexp = tl.sum(exp_attn, axis=0)  # scalar.

    # [seq_len] * [seq_len, D_HEAD], sum along axis 0
    output = tl.sum(exp_attn[:, None] * v, axis=0)  # [D_HEAD]

    # Per-block normalized partial output; stage 2 re-weights via logsumexp.
    output = output / sumexp

    # We store the log-sum-exp after removing the max.
    logsumexp = tl.log(sumexp) + max_attn
    # when seq_mask is all false, max_attn will be -inf and sumexp is zero

    tl.store(
        output_values_ptr
        + batch_id * N_HEADS * D_HEAD * num_blocks
        + head_id * D_HEAD * num_blocks
        + seq_block_id * D_HEAD
        + dhead_offsets,
        output,
    )
    tl.store(
        output_logsumexp_ptr
        + batch_id * N_HEADS * num_blocks
        + head_id * num_blocks
        + seq_block_id,
        logsumexp,
    )
1092
+
1093
+
1094
@triton.jit
def context_attention_kv_paged(
    q_ptr,  # [b*s,nd]
    seq_len_ptr,  # [b] # length of each sequence in a batch
    seq_start_ptr,  # [b] # start indices of a sequence in flattened q/k/v.
    k_cache_ptr,  # [num_pages, page_size, n, d]
    v_cache_ptr,  # [num_pages, page_size, n, d]
    cache_loc_ptr,  # [b] # index of the sequence in the page table.
    cache_len_ptr,  # [Batch] # Number of tokens in kv cache.
    page_table_ptr,  # [b, max_num_pages_per_seq] # loc of the block page in the cache.
    softmax_scale,
    o_ptr,  # [b*s, n, d] attention output, same layout as q
    N_HEADS: tl.constexpr,  # Number of heads
    N_KV_HEADS: tl.constexpr,  # Number of KV heads.
    D_HEAD: tl.constexpr,  # Dimension of each head.
    SEQ_BLOCK: tl.constexpr,
    MAX_SEQ_LENGTH: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    PAGE_TABLE_STRIDE: tl.constexpr,
):
    """Flash-attention kernel for the context phase over a paged KV cache.

    NOTE(review): an earlier docstring claimed "Fuses rope", but no rope math
    appears in this kernel — k/v are read straight from the cache.

    Assuming:
    1. Self-attention [seqlen(Q) == seqlen(K)]
    2. Causal attention (queries may also attend to `cache_len` earlier cached tokens)
    3. QKV layout: [b*s,n,d]
    Uses the standard online-softmax recurrence (running max m_i, running
    logsumexp lse_i, rescaled accumulator acc).
    """
    batch_id = tl.program_id(axis=0)
    head_id = tl.program_id(axis=1)
    seq_block_id = tl.program_id(axis=2)

    # Each program is responsible for a block of tokens in a single batch.
    seq_start_index = tl.load(seq_start_ptr + batch_id)
    seq_len = tl.load(seq_len_ptr + batch_id)

    HEAD_RATIO: tl.constexpr = N_HEADS // N_KV_HEADS  # GQA group size

    # assuming SEQ_BLOCK can be divided by PAGE_SIZE and PAGE_SIZE is a power of 2.
    SEQ_BLOCK_PAGE: tl.constexpr = SEQ_BLOCK // PAGE_SIZE
    MAX_NUM_PAGES: tl.constexpr = (MAX_SEQ_LENGTH + PAGE_SIZE - 1) // PAGE_SIZE

    # cache is [num_pages, page_size, n, d]
    # cache_loc_ptr stores the batch index for the sequences provided to the kernel.
    cache_loc = tl.load(cache_loc_ptr + batch_id)
    table_batch_offset = cache_loc * PAGE_TABLE_STRIDE

    # Assuming D_HEAD is a power of 2
    dhead_offsets = tl.arange(0, D_HEAD)
    dhead_mask = dhead_offsets < D_HEAD

    seq_offsets = tl.arange(0, SEQ_BLOCK)
    q_seq_offsets = seq_block_id * SEQ_BLOCK + seq_offsets
    seq_mask = q_seq_offsets < seq_len

    load_mask = seq_mask[:, None] * dhead_mask[None, :]

    q_batch_offset = seq_start_index * N_HEADS * D_HEAD
    q_head_offset = head_id * D_HEAD
    cache_head_offset = (head_id // HEAD_RATIO) * D_HEAD

    # Q will stay in SRAM
    q = tl.load(
        q_ptr
        + q_batch_offset
        + q_seq_offsets[:, None] * N_HEADS * D_HEAD
        + q_head_offset
        + dhead_offsets[None, :],
        mask=load_mask,
    )
    # Online-softmax state: accumulator, running logsumexp, running max.
    acc = tl.zeros([SEQ_BLOCK, D_HEAD], dtype=tl.float32)
    lse_i = tl.zeros([SEQ_BLOCK], dtype=tl.float32) - float("inf")
    m_i = tl.zeros([SEQ_BLOCK], dtype=tl.float32) - float("inf")

    cache_len = tl.load(cache_len_ptr + batch_id)
    total_len = cache_len + seq_len
    num_blocks = (total_len + SEQ_BLOCK - 1) // SEQ_BLOCK
    # NOTE(review): iterates num_blocks + 1 times; the final iteration is fully
    # masked out (kv_seq_mask all false) so it only costs time — confirm intentional.
    for s in range(0, num_blocks + 1, 1):
        kv_pages = s * SEQ_BLOCK_PAGE + tl.arange(0, SEQ_BLOCK_PAGE)
        cache_pages = tl.load(
            page_table_ptr + table_batch_offset + kv_pages, mask=kv_pages < MAX_NUM_PAGES
        )

        page_offsets = tl.arange(0, PAGE_SIZE)
        # shape [SEQ_BLOCK], means [cache_pages, page_offsets]
        # physical token offsets in the paged kv cache
        cache_seq_offset = tl.reshape(
            cache_pages[:, None] * PAGE_SIZE + page_offsets[None, :], [SEQ_BLOCK]
        )
        cache_offsets = (
            cache_seq_offset[:, None] * N_KV_HEADS * D_HEAD
            + cache_head_offset
            + dhead_offsets[None, :]
        )

        # logical kv tokens offsets
        kv_seq_offsets = s * SEQ_BLOCK + seq_offsets
        kv_seq_mask = kv_seq_offsets < total_len
        kv_load_mask = kv_seq_mask[:, None] * dhead_mask[None, :]

        k = tl.load(k_cache_ptr + cache_offsets, mask=kv_load_mask)
        qk = tl.zeros([SEQ_BLOCK, SEQ_BLOCK], dtype=tl.float32)
        qk += tl.dot(q, k.trans())
        # causal mask, need to use kv_seq_offsets; queries are shifted by cache_len
        # since they sit after the cached prefix in the logical sequence.
        qk = tl.where(
            (q_seq_offsets[:, None] + cache_len) >= kv_seq_offsets[None, :], qk, float("-inf")
        )

        qk *= softmax_scale
        # rowmax
        m_ij = tl.maximum(tl.max(qk, 1), lse_i)
        p = tl.exp(qk - m_ij[:, None])
        v = tl.load(v_cache_ptr + cache_offsets, mask=kv_load_mask)

        # Online-softmax update: rescale the accumulator to the new max,
        # fold in this block's contribution, then update the running logsumexp.
        l_ij = tl.sum(p, 1)
        acc_scale = tl.exp(m_i - m_ij)
        acc = acc * acc_scale[:, None]
        p = p.to(v.dtype)
        acc += tl.dot(p, v)
        m_i = m_ij
        l_i_new = tl.exp(lse_i - m_ij) + l_ij
        lse_i = m_ij + tl.log(l_i_new)

    # Final normalization of the accumulated (unnormalized) output.
    o_scale = tl.exp(m_i - lse_i)

    acc = acc * o_scale[:, None]

    tl.store(
        o_ptr
        + q_batch_offset
        + q_seq_offsets[:, None] * N_HEADS * D_HEAD
        + q_head_offset
        + dhead_offsets[None, :],
        acc,
        mask=load_mask,
    )
1230
+
1231
+
1232
+
1233
@dataclass
class PositionalEmbeddingConfig:
    """Configuration of the positional embedding applied to q/k.

    Only rotary embeddings ("rope") or no positional embedding (None) are
    supported; construction validates the combination of fields.
    """

    mode: Optional[Literal["rope"]] = None
    rope_theta: float = 10000.0
    rope_scale: float = 1.0

    def __post_init__(self):
        # Validate eagerly so misconfiguration surfaces at construction time.
        assert self.mode in (None, "rope"), f"Invalid mode: {self.mode}."
        if self.mode != "rope":
            return
        assert self.rope_theta > 0, f"Invalid rope theta: {self.rope_theta}."
1245
+
1246
+
1247
@dataclass
class CacheConfig:
    """Describes how the KV cache should be configured."""

    # Storage dtype of the cache tensors; None means "use the model dtype".
    # NOTE(review): the None-default semantics are resolved by the consumer — confirm.
    dtype: Optional[torch.dtype] = None
1252
+
1253
+
1254
@dataclass
class AttentionInfo:
    """Describes one attention op's configuration.

    Collected by the kvcache transformation and handed to the
    AttentionDescriptor methods so the attention op knows how it is set up.
    """

    # Number of query heads.
    num_heads: int
    # Number of key/value heads (GQA when smaller than num_heads).
    num_kv_heads: int
    # Embedding size of each head.
    head_dim: int
    dtype: torch.dtype

    cache_config: CacheConfig
    pos_embd_config: PositionalEmbeddingConfig
    # Embedding size of the decoupled q/k slice that carries rope information.
    # When rope_dim != 0, that slice is the last part of the tensor: [-rope_dim:].
    rope_dim: Optional[int] = 0
1272
+
1273
+
1274
+ @dataclass
1275
+ class SequenceInfo:
1276
+ """A dataclass to hold information about how the sequence is laid out and stored in cache.
1277
+
1278
+ We assume the sequence + cache is laid out in the following way:
1279
+
1280
+ - input_ids: [id_0, ..., id_{s_total-1}]
1281
+ flattened sequence of [b, 1] or [1, s_total]. We use [b, 1] to denote generate-only batches.
1282
+ - seq_len: [s_0, s_1, ..., s_{b-1}] such that s_total = sum(s_i)
1283
+ Describes how long each sequence is. For example,
1284
+ input_ids[:s_0] will correspond to sequence 0 in the batch and input_ids[s_0:s_1] will
1285
+ correspond to sequence 1 in the batch.
1286
+ - input_pos: [pos_0, ..., pos_{b-1}]
1287
+ Corresponds to the total number of tokens that has been already been cached for each sequence
1288
+ in the batch.
1289
+ - cache_loc: [c0, ...., c_{np-1}] where np is total number of pages allocated to describe all
1290
+ sequences in the batch.
1291
+ - pages_per_seq: [ps_0, ps_1, ..., ps_{b-1}] where ps_i is the number of pages allocated for
1292
+ sequence i. Note that, for example, cache_loc[p_0:p_1] will correspond to the pages associated
1293
+ with sequence 1 in the batch.
1294
+
1295
+ Here are a couple of notes to emphasize this notation:
1296
+
1297
+ - The total number of allocated token space for sequence i is given by ps_i * page_size. This is
1298
+ the total number of tokens that can be cached for each sequence.
1299
+
1300
+ - NOTE: It must hold that pos_i + s_i <= ps_i * page_size for all i in [0, b-1]. Moreover, it is
1301
+ the responsibility of the cache manager and/or runtime to ensure sufficient page allocation
1302
+ for each sequence.
1303
+
1304
+ """
1305
+
1306
+ ## USE TO INITIALIZE DATA CLASS ###############################################################
1307
+ # max_seq_len corresponds the maximum number of tokens in any sequence. It includes the tokens in the
1308
+ # input sequence and the tokens generated by the model.
1309
+ max_seq_len: int = 1
1310
+ # max_batch_size corresponds to the maximum number of sequences (or requests) that the model can process.
1311
+ max_batch_size: int = 1
1312
+ # page_size is the granularity with which the cache pages are allocated for a paged kv cache.
1313
+ # For an unpaged cache, the page size should be set to max_seq_len.
1314
+ # Also note that two sequences in a batch can not share a page.
1315
+ page_size: int = 0
1316
+ # max_num_tokens is the maximum number of tokens that the model can process across all sequences in the batch.
1317
+ # If a batch is composed of context-only requests of input sequence length ISL,
1318
+ # then the maximum number of sequences possible in the batch is min (max_batch_size, max_num_tokens // ISL).
1319
+ # Similarly, if a batch is composed of generate-only requests,
1320
+ # then the maximum number of sequences possible in the batch is min (max_batch_size, max_num_tokens).
1321
+ max_num_tokens: int = 0
1322
+
1323
+ ## [UPDATE WITH CARE] TENSOR FIELDS THAT WILL BE PASSED TO PREPARE_METADATA OP #################
1324
+ # input_ids MUST ALWAYS BE THE FIRST FIELD
1325
+ input_ids: torch.Tensor = field(default_factory=lambda: torch.zeros(1, 1, dtype=torch.int))
1326
+ seq_len: torch.Tensor = field(default_factory=lambda: torch.ones(1, dtype=torch.int))
1327
+ input_pos: torch.Tensor = field(default_factory=lambda: torch.zeros(1, dtype=torch.int))
1328
+ cache_loc: torch.Tensor = field(default_factory=lambda: torch.arange(1, dtype=torch.int))
1329
+ pages_per_seq: torch.Tensor = field(default_factory=lambda: torch.ones(1, dtype=torch.int))
1330
+ ################################################################################################
1331
+
1332
+ ## PRIVATE FIELDS ##############################################################################
1333
+ _sequence_lengths: List[int] = field(default_factory=list)
1334
+ _num_pages: int = 1
1335
+
1336
+ def __post_init__(self):
1337
+ if self.page_size < 1:
1338
+ self.page_size = self.max_seq_len
1339
+ if self.max_num_tokens < 1:
1340
+ self.max_num_tokens = self.max_batch_size * self.max_seq_len
1341
+ # if the provided max_num_tokens is less than the max_batch_size * max_seq_len,
1342
+ # we use the provided max_num_tokens to calculate the number of pages
1343
+ total_tokens = min(self.max_num_tokens, self.max_batch_size * self.max_seq_len)
1344
+ self._num_pages = (total_tokens) // self.page_size + (total_tokens % self.page_size > 0)
1345
+ self.input_ids = torch.ones(self.max_batch_size, 1, dtype=torch.int)
1346
+ self.seq_len = torch.empty(self.max_batch_size, dtype=torch.int)
1347
+ self.input_pos = torch.empty_like(self.seq_len)
1348
+ self.cache_loc = torch.empty(self.num_pages, dtype=torch.int)
1349
+ self.pages_per_seq = torch.empty_like(self.seq_len)
1350
+
1351
+ # dynamic shape descriptors for tensor args
1352
+ self._dynamic_shapes: Optional[Tuple[Dict[str, Dim]]] = None
1353
+
1354
+ # keep a list-like object of sequence lengths for simplicity as well
1355
+ self._sequence_lengths = [0] * self.max_batch_size
1356
+
1357
+ # call reset once to initialize the tensors
1358
+ self.reset()
1359
+
1360
+ @property
1361
+ def device(self) -> torch.device:
1362
+ return self.input_pos.device
1363
+
1364
+ @property
1365
+ def args(self) -> List[torch.Tensor]:
1366
+ args = []
1367
+ for f in fields(self):
1368
+ val = getattr(self, f.name)
1369
+ if isinstance(val, torch.Tensor):
1370
+ args.append(val)
1371
+ return args
1372
+
1373
+ @property
1374
+ def extra_arg_names(self) -> List[str]:
1375
+ """Return extra arg names for the prepare_metadata op beyond input_ids."""
1376
+ return [f.name for f in fields(self) if isinstance(getattr(self, f.name), torch.Tensor)][1:]
1377
+
1378
+ @property
1379
+ def dynamic_shapes(self) -> Tuple[Dict[str, Dim]]:
1380
+ """Return dynamic shapes of sequence info tensors.
1381
+
1382
+ NOTE: will be lazily initialized since the Dim object is not picklable for multi-processing.
1383
+ """
1384
+ if self._dynamic_shapes is None:
1385
+ dynamic_shapes = ({},)
1386
+ if self.max_batch_size > 1:
1387
+ dynamic_shapes[0][0] = Dim("batch_size", max=self.max_batch_size)
1388
+ dynamic_shapes[0][1] = Dim("seq_len", max=self.max_seq_len)
1389
+ dynamic_shapes += ({},) * len(self.extra_arg_names)
1390
+ self._dynamic_shapes = dynamic_shapes
1391
+ return self._dynamic_shapes
1392
+
1393
+ @property
1394
+ def num_sequences(self) -> int:
1395
+ return len(self._sequence_lengths)
1396
+
1397
+ @property
1398
+ def sequence_lengths(self) -> List[int]:
1399
+ return self._sequence_lengths
1400
+
1401
+ @property
1402
+ def input_positions(self) -> List[int]:
1403
+ return self.input_pos[: self.num_sequences].tolist()
1404
+
1405
+ @property
1406
+ def is_generate(self) -> bool:
1407
+ return all(sl == 1 for sl in self.sequence_lengths)
1408
+
1409
+ @property
1410
+ def num_pages(self) -> int:
1411
+ return self._num_pages
1412
+
1413
+ @num_pages.setter
1414
+ def num_pages(self, value):
1415
+ self._num_pages = value
1416
+ # update the cache_loc tensor
1417
+ self.cache_loc.resize_(value)
1418
+
1419
+ @property
1420
+ def is_paged(self) -> bool:
1421
+ return self.page_size < self.max_seq_len
1422
+
1423
+ @property
1424
+ def page_assignments(self) -> List[List[int]]:
1425
+ """Return the page assignments for each sequence."""
1426
+ pages_per_seq = self.pages_per_seq[: self.num_sequences].tolist()
1427
+ return [
1428
+ c_loc_one_seq.tolist()
1429
+ for c_loc_one_seq in torch.split(self.cache_loc[: sum(pages_per_seq)], pages_per_seq)
1430
+ ]
1431
+
1432
+ @classmethod
1433
+ def _get_sanitized_seq_len(cls, input_ids: torch.Tensor, seq_len: torch.Tensor) -> torch.Tensor:
1434
+ """Sanitize sequence lengths.
1435
+
1436
+ We want to cover the following scenarios with this function:
1437
+
1438
+ 1. Pre-fill:
1439
+ input_ids: [1, s_total, ...]
1440
+ seq_len: [s_0, s_1, ..., s_{b-1}, 0, 0, ..., 0]
1441
+ ---> returns [s_0, s_1, ..., s_{b-1}]
1442
+ 2. Decode:
1443
+ input_ids: [b, 1, ...]
1444
+ seq_len: [1, 1, ..., 1, 0, 0, ..., ..., ..., ..., 0]
1445
+ |---- b ----|--- (max_batch_size - b) ---|
1446
+ --> returns [1,] * b
1447
+ 3. Decode in Cudagraph:
1448
+ input_ids: [b_cudagraph, 1, ...]
1449
+ seq_len: [1, 1, ..., 1, 0, 0, ..., ..., ..., ..., 0]
1450
+ |---- b ----|--- (max_batch_size - b) ---|
1451
+
1452
+ --> returns [1,] * b_cudagraph
1453
+ Here b <= b_cudagraph. We want to make sure that the seq_len is one-padded to
1454
+ b_cudagraph.
1455
+
1456
+ # TODO: I could see one possible issue with this approach in the future.
1457
+ # If we have b < b_cudagraph we now one-pad. However, we don't pad the cache location
1458
+ # information. What could happen is that the for the padded sequences the cache location
1459
+ # tensors point to allocated pages. This could lead to a situation where we write into
1460
+ # allocated cache pages polluting the cache of other sequences. Now this is not an issue
1461
+ # if we write the dummy sequences into unallocated cache pages... One fix could be to
1462
+ # pad not only the seq len but also pad the cache locations by just repeating the last
1463
+ # valid cache location in the batch. This would ensure that the dummy sequences just
1464
+ # repeats valid computation...
1465
+ """
1466
+ _, s = input_ids.shape[:2]
1467
+ num_seq = cls._get_sanitized_num_sequences(input_ids, seq_len)
1468
+ if s > 1:
1469
+ return seq_len[:num_seq].detach().clone()
1470
+ else:
1471
+ return torch.ones(num_seq, dtype=seq_len.dtype, device=seq_len.device)
1472
+
1473
+ @staticmethod
1474
+ def _get_sanitized_num_sequences(input_ids: torch.Tensor, seq_len: torch.Tensor) -> int:
1475
+ """Get number of sequences.
1476
+
1477
+ We makes sure that this function is compatible with both torch graph capture and cudagraph.
1478
+ Both can be a bit temparamental when trying to extract the number of sequences from a tensor
1479
+ with max_batch_size or max_batch_size*max_seq_len.
1480
+ """
1481
+ b, s = input_ids.shape[:2]
1482
+ if s > 1:
1483
+ num_seq = torch.sum(seq_len > 0)
1484
+ assert seq_len[num_seq:].sum() == 0, "seq_len should be zero-padded"
1485
+ else:
1486
+ num_seq = b
1487
+ return num_seq
1488
+
1489
+ def to(self, *args, **kwargs) -> None:
1490
+ for f in fields(self):
1491
+ val = getattr(self, f.name)
1492
+ if isinstance(val, torch.Tensor):
1493
+ setattr(self, f.name, val.to(*args, **kwargs))
1494
+
1495
    def sync(self, other: "SequenceInfo") -> None:
        """Copy sequence state from ``other`` into this SequenceInfo.

        ``input_ids`` is replaced wholesale (its shape may differ between
        batches); other tensor fields are copied in-place into their leading
        slots; non-tensor fields must already match.
        """
        for f in fields(self):
            val = getattr(self, f.name)
            val_other = getattr(other, f.name)
            if f.name == "input_ids":
                # shape can change between syncs -> replace rather than copy in-place
                setattr(self, f.name, val_other.to(self.device))
            elif f.name == "_sequence_lengths":
                # plain python list; assigned by reference
                self._sequence_lengths = val_other
            elif isinstance(val, torch.Tensor):
                # partial in-place copy: other's tensors may be shorter (fewer
                # active sequences); the tail keeps its previous values.
                val[: len(val_other)] = val_other.to(self.device)
            else:
                assert val == val_other, f"Field {f.name} mismatch: {val} != {val_other}."
1507
+
1508
    def reset(self) -> None:
        """Reset the sequence information.

        After reset the sequence information should correspond to a "generate-only" batch of
        sequences (b, s==1) without cache history.
        """
        # set a dummy sequence corresponding to a generate-only batch
        self.nest_sequences(torch.zeros(self.max_batch_size, 1, dtype=torch.int))

        # reset cache information: positions back to zero, identity page mapping,
        # and one page per sequence
        self.input_pos.zero_()
        self.cache_loc[:] = torch.arange(self.num_pages, dtype=torch.int, device=self.device)
        self.pages_per_seq.fill_(1)
1521
+
1522
    def _set_example_sequence(self) -> None:
        """Set an example sequence for export purposes."""
        self.reset()
        # small all-ones batch, clamped so it fits any configuration
        input_ids = torch.ones(
            min(2, self.max_batch_size),
            min(4, self.max_seq_len),
            dtype=torch.int,
            device=self.device,
        )
        self.nest_sequences(input_ids)
        # keep the plain [b, s] layout rather than the flattened view produced by
        # nest_sequences -- presumably export expects a standard batched tensor.
        self.input_ids = input_ids
1533
+
1534
    def _set_max_num_tokens_sample(self) -> None:
        """Set an example sequence with max_num_tokens."""
        self.reset()
        # distribute the token budget evenly over the batch
        seq_len = self.max_num_tokens // self.max_batch_size
        input_ids = torch.ones(
            self.max_batch_size,
            seq_len,
            dtype=torch.int,
            device=self.device,
        )
        # NOTE(review): assumes seq_len is a multiple of page_size; otherwise this
        # integer division under-counts pages -- confirm upstream invariants.
        self.pages_per_seq.fill_(seq_len // self.page_size)
        self.nest_sequences(input_ids)
1546
+
1547
+ def _set_generate_only_batch(self) -> None:
1548
+ """Set an example sequence for generate-only batch."""
1549
+ self.reset()
1550
+ self.nest_sequences([[1]] * self.max_batch_size)
1551
+
1552
    def nest_sequences(self, input_ids: Sequence[Sequence[int]]) -> None:
        """Create and store a flattened list of input_ids from the provided list of sequences.

        This i/f will also update any relevant sequence information.

        Args:
            input_ids: one token-id sequence per entry; entries may be python
                lists or 1D tensors (tensors are detached, not copied).
        """
        # set new sequence lengths; zero-pad the unused tail of self.seq_len
        seq_lens = [len(ids) for ids in input_ids]
        self.seq_len.zero_()
        self.seq_len[: len(seq_lens)].copy_(torch.tensor(seq_lens), non_blocking=True)

        # set new input_ids as new tensor from flattened input_ids
        ids_tnsr_list = [
            lst.detach() if isinstance(lst, torch.Tensor) else torch.tensor(lst, dtype=torch.int)
            for lst in input_ids
        ]
        self.input_ids = torch.cat(ids_tnsr_list, dim=0).to(self.device)

        # set derivative properties
        self._sequence_lengths = seq_lens

        # use [b,1] shape to indicate generate-only batch, otherwise use [1,total_len]
        if self.is_generate:
            self.input_ids = self.input_ids.view(-1, 1, *self.input_ids.shape[1:])
        else:
            self.input_ids = self.input_ids.view(1, -1, *self.input_ids.shape[1:])
1577
+
1578
+ def unnest_sequences(self, t_nested: torch.Tensor) -> List[torch.Tensor]:
1579
+ t_squeezed = t_nested.squeeze(1) if self.is_generate else t_nested.squeeze(0)
1580
+ return list(torch.split(t_squeezed, self.sequence_lengths))
1581
+
1582
    def update_pos(self, seq_len: Union[torch.Tensor, List[int], int], reset: bool = False) -> None:
        """Update the starting position for each sequence in the cache.

        Args:
            seq_len: per-sequence increment (tensor/list) or a scalar applied to
                the whole batch.
            reset: if ``True``, ``input_pos`` is overwritten with ``seq_len``
                instead of being incremented.
        """
        if not isinstance(seq_len, torch.Tensor):
            seq_len = torch.tensor(seq_len, dtype=torch.int)
        # a scalar (0-dim) seq_len applies to all max_batch_size slots
        bs = len(seq_len) if seq_len.dim() > 0 else self.max_batch_size

        if reset:
            self.input_pos[:bs] = seq_len.to(self.device)
        else:
            self.input_pos[:bs] += seq_len.to(self.device)
1595
+
1596
    def assign_cache_loc(self, page_assignments: Sequence[Sequence[int]]) -> None:
        """Set the cache location and pages_per_seq tensors from page assignments.

        Args:
            page_assignments: one list of page indices per sequence.
        """
        # flatten [[pages of seq 0], [pages of seq 1], ...] into one index list
        cache_loc_flat = torch.tensor(
            [p_idx for pages in page_assignments for p_idx in pages], dtype=torch.int
        )
        self.cache_loc[: len(cache_loc_flat)].copy_(cache_loc_flat, non_blocking=True)

        pages_per_seq = torch.tensor([len(p) for p in page_assignments], dtype=torch.int)
        self.pages_per_seq[: len(pages_per_seq)].copy_(pages_per_seq, non_blocking=True)
1605
+
1606
+
1607
# Scalar argument types that may be baked into the exported graph as constants.
Constant = Union[int, float, str, None]
1608
+
1609
+
1610
class MHACallable(Protocol):
    """Structural type of a cached attention op: positional q/k/v, metadata,
    caches, buffers, and constants; returns the attention output tensor."""

    def __call__(
        self,
        *qkv_metadata_and_caches: Union[torch.Tensor, Constant],
    ) -> torch.Tensor: ...
1615
+
1616
+
1617
class PrepareMetadataCallable(Protocol):
    """Structural type of the op converting standardized sequence info into the
    metadata tensors consumed by a specific attention backend."""

    def __call__(
        self,
        input_ids: torch.Tensor,
        seq_len: torch.Tensor,
        input_pos: torch.Tensor,
        cache_loc: torch.Tensor,
        pages_per_seq: torch.Tensor,
        page_size: int,
    ) -> List[torch.Tensor]: ...
1627
+
1628
+
1629
class GetCacheCallable(Protocol):
    """Factory that allocates a cache tensor for a given SequenceInfo."""

    def __call__(self, sequence_info: SequenceInfo) -> torch.Tensor: ...
1631
+
1632
+
1633
class GetBufferCallable(GetCacheCallable):
    """Alias type for global-buffer factories (same signature as cache factories)."""

    pass
1635
+
1636
+
1637
class GetAttentionInfo(Protocol):
    """Zero-argument callback returning the AttentionInfo of the attention node
    currently being processed (used by cache/buffer initializers)."""

    # BUG FIX: ``__call__`` previously omitted ``self``, making the protocol's
    # declared signature invalid for structural checks. The forward reference is
    # quoted since AttentionInfo is defined elsewhere in this module.
    def __call__(self) -> "AttentionInfo": ...
1639
+
1640
+
1641
# Mapping from cache/buffer argument name (as used in the attention op
# signature) to the callback that allocates it.
CacheInitializerDict = Dict[str, GetCacheCallable]
BufferInitializerDict = Dict[str, GetBufferCallable]
1643
+
1644
+
1645
class AttentionDescriptor(ABC):
    """An interface to define a functional attention operator.

    The main logic is contained with the actual attention op as well as the prepare_metadata op. The
    prepare_metadata op is responsible for converting the standardized sequence info into metadata
    specific to the attention op.
    """

    @classmethod
    @abstractmethod
    def is_paged(cls) -> bool:
        """Return if the attention op is paged or not."""

    @classmethod
    def get_attention_op(cls) -> Tuple[MHACallable, int]:
        """Get the attention op and the number of arguments corresponding to qkv.

        The attention_op should follow the below signature:

        ```
        def attention_op(
            *qkv, # list of tensors corresponding to Q, K, V as in original op
            *metadata, # global info about the sequences as returned by the prepare_metadata op
            *caches, # contains layer-specific caches per provided cache initializers
            *buffers, # global buffers used by the attention op as provided by buffer initializers
            *constants, # basic arguments (int, float, str, None) added as CONSTANTS in the graph
        ) -> torch.Tensor: ...
        ```

        **Note that the attention op should be a valid torch custom op, which comes with
        restrictions on the supported types in the signature.**

        **Note that the `qkv` tuple should be consistent across both the cached attention
        op and the op that it is replacing.**

        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]:
        """Get the prepare_metadata op.

        The prepare_metadata op should follow the below signature:

        ```
        def prepare_metadata(
            input_ids: torch.Tensor,
            seq_len: torch.Tensor,
            input_pos: torch.Tensor,
            cache_loc: torch.Tensor,
        ) -> List[torch.Tensor]: ...
        ```
        The metadata should contain all necessary global information required for the underlying
        attention op to process the input sequence and the returned list of tensors will be passed
        on to each invocation of the attention op in the graph.

        prepare_metadata is called once at the beginning of the forward pass.

        **Note that the prepare_metadata op should be a valid torch custom op, which comes with
        restrictions on the supported types in the signature.**
        """
        # BUG FIX: this previously did ``return NotImplementedError``, handing the
        # exception *class* to any subclass that forgot to override -- raise instead.
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def get_cache_initializers(cls, get_info: GetAttentionInfo) -> CacheInitializerDict:
        """Provide a dictionary of function pointers that can be used to initialize the caches.

        The key corresponds to the argument name used in the attention op signature. The function
        key doesn't need to be unique across multiple attention nodes in the graph. The key used to
        describe the cache in the graph will be patched with the attention node index to ensure
        uniqueness.

        ``get_cache_initializers`` will be called *once* after the model initialization and before
        the initial forward pass for each attention op detected in the graph. The caches will be
        managed by the global CacheManager and passed back to the attention op during the forward
        pass.

        If the cache initializer requires information about the attention op, the ``get_info``
        function can be called **inside** the cache initializer to retrieve the necessary
        information.
        """
        raise NotImplementedError

    @classmethod
    def get_global_buffer_initializers(cls, get_info: GetAttentionInfo) -> BufferInitializerDict:
        """Provide a dictionary of function pointers that can be used to initialize buffers.

        The key corresponds to the buffer name used in the graph module and will **not**
        be patched unlike a cache key. Hence, it is a **global** key that is shared across all
        attention ops in the model much like a regular buffer in an nn.Module. That means if this
        i/f is called for multiple attention ops, the same buffer will be shared across all of them
        if this function provides the same key multiple times.

        Buffers are initialized *once* after the model initialization and before the initial
        forward pass for each attention op detected in the graph. The buffer will be managed by the
        global CacheManager and passed back to the attention op during the forward pass.

        If the buffer initializer requires information about the attention op, the ``get_info``
        function can be called **inside** the buffer initializer to retrieve the necessary
        information.
        """
        return {}

    @classmethod
    def get_constants(cls, attention_info: AttentionInfo) -> List[Constant]:
        """Provide a list of constant arguments to be passed to the attention op.

        The constant arguments are passed to the attention op as additional arguments after the
        caches and buffers. The constants are expected to be of type int, float, str, or None.
        """
        return []
1758
+
1759
+
1760
class AttentionRegistry:
    """A simple registry to look up different attention implementations."""

    # NOTE(review): ``register`` actually returns a decorator, not a descriptor
    # class; its annotation mirrors the original interface -- confirm intent.
    _attention_registry: Dict[str, Type["AttentionDescriptor"]] = {}

    @classmethod
    def register(cls, kernel_source: str) -> Type["AttentionDescriptor"]:
        """Return a class decorator registering its target under ``kernel_source``."""

        def decorator(descriptor_cls: Type["AttentionDescriptor"]):
            # fail fast on accidental double-registration
            assert kernel_source not in cls._attention_registry, (
                f"Attention source {kernel_source} already registered."
            )
            cls._attention_registry[kernel_source] = descriptor_cls
            return descriptor_cls

        return decorator

    @classmethod
    def get(cls, kernel_source: str) -> Type["AttentionDescriptor"]:
        """Look up a registered descriptor; asserts that it exists."""
        assert cls.has(kernel_source), f"Attention source {kernel_source} not registered."
        return cls._attention_registry[kernel_source]

    @classmethod
    def has(cls, kernel_source: str) -> bool:
        """Check whether ``kernel_source`` has a registered descriptor."""
        return kernel_source in cls._attention_registry
1784
+
1785
+
1786
+
1787
+ @torch.library.custom_op("attention::scaled_dot_product_attention", mutates_args=())
1788
+ def scaled_dot_product_attention(
1789
+ query: torch.Tensor,
1790
+ key: torch.Tensor,
1791
+ value: torch.Tensor,
1792
+ attn_mask: Optional[torch.Tensor] = None,
1793
+ dropout_p: float = 0.0,
1794
+ is_causal: bool = False,
1795
+ scale: Optional[float] = None,
1796
+ ) -> torch.Tensor:
1797
+ """A carbon copy of torch.nn.functional.scaled_dot_product_attention as custom op.
1798
+
1799
+ Using this custom op instead of using the functional directly ensures consistent representation
1800
+ of the vanilla sdpa in a graph.
1801
+ """
1802
+ return F.scaled_dot_product_attention(
1803
+ query,
1804
+ key,
1805
+ value,
1806
+ attn_mask=attn_mask,
1807
+ dropout_p=dropout_p,
1808
+ is_causal=is_causal,
1809
+ scale=scale,
1810
+ )
1811
+
1812
+
1813
@scaled_dot_product_attention.register_fake
def scaled_dot_product_attention_fake(
    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
):
    """Fake (meta) implementation of scaled_dot_product_attention.

    Only shape/dtype propagation matters for tracing: sdpa output always has the
    query's shape, so an empty tensor like ``query`` suffices.
    """
    return torch.empty_like(query)
1819
+
1820
+
1821
def _generate_mha(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    cache_locs: torch.Tensor,
    input_pos: torch.Tensor,
    out: torch.Tensor,
):
    """Decode-phase attention (one new token per sequence) against the KV cache.

    Appends the new k/v token at ``input_pos`` into the cache rows selected by
    ``cache_locs``, then runs a two-stage split-K attention (per-block partials
    + logsumexp, then a reduction) writing into ``out`` in-place.

    NOTE(review): assumes q is laid out [b, 1, n_heads, d_head] (bsnd) as
    produced by the callers in this file -- confirm for any new call site.
    """
    b, (n_heads, q_d_head) = q.shape[0], q.shape[-2:]
    max_seq_len, n_kv_heads = k_cache.shape[1:3]
    v_d_head = v.shape[-1]
    device = q.device

    # head block covers one full GQA group (n_heads // n_kv_heads query heads)
    HEAD_BLOCK_SIZE = max(16, triton.next_power_of_2(n_heads // n_kv_heads))
    SEQ_BLOCK_SIZE = 256
    num_blocks = (max_seq_len + SEQ_BLOCK_SIZE - 1) // SEQ_BLOCK_SIZE

    # stage-1 scratch: per-block partial values and logsumexp (logsumexp starts
    # at -inf so empty blocks drop out of the stage-2 reduction)
    stage1_output_values = torch.empty(
        b, n_heads, num_blocks, v_d_head, device=device, dtype=torch.float32
    )
    stage1_output_logsumexp = torch.empty(
        b, n_heads, num_blocks, device=device, dtype=torch.float32
    ) - float("inf")

    # write the single new k/v token per sequence into the cache
    (
        update_kv_cache[(b, n_kv_heads, 1)](
            k,
            v,
            None,
            None,
            k_cache,
            v_cache,
            input_pos,
            cache_locs,
            max_seq_len,
            n_kv_heads,
            q_d_head,
            v_d_head,
            1,
            GENERATE_ONLY=True,
        ),
    )

    # stage 1: each program attends q against one SEQ_BLOCK_SIZE slice of the cache
    gqa_attention_kv_stage1[
        (
            b,
            n_kv_heads,
            num_blocks,
        )
    ](
        q,
        k_cache,
        v_cache,
        cache_locs,
        input_pos,
        stage1_output_values,
        stage1_output_logsumexp,
        num_blocks,
        max_seq_len,
        n_heads,
        n_kv_heads,
        q_d_head,
        v_d_head,
        SEQ_BLOCK_SIZE,
        HEAD_BLOCK_SIZE,
    )
    # stage 2: combine the per-block partials via their logsumexp weights
    attention_kv_stage2[(b, n_heads, 1)](
        stage1_output_values,
        stage1_output_logsumexp,
        out,
        input_pos,
        num_blocks,
        n_heads,
        v_d_head,
        SEQ_BLOCK_SIZE,
    )
1899
+
1900
+
1901
def _context_mha(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    out: torch.Tensor,
):
    """Context-phase (prefill) attention for a regular [b, s, n, d] batch.

    Launches one fused kernel that both fills the KV cache and computes
    attention, writing the result into ``out`` in-place.
    """
    b, s, n_heads, q_d_head = q.shape
    max_seq_len, n_kv_heads = k_cache.shape[1:3]
    v_d_head = v.shape[-1]

    SEQ_BLOCK = 128
    # standard 1/sqrt(d) attention scaling
    softmax_scale = 1.0 / math.sqrt(q_d_head)
    # one program per (batch, head, seq-block)
    grid = (b, n_heads, (s + SEQ_BLOCK - 1) // SEQ_BLOCK)
    context_attention_kv[grid](
        q,
        k,
        v,
        k_cache,
        v_cache,
        s,
        out,
        softmax_scale,
        n_heads,
        n_kv_heads,
        q_d_head,
        v_d_head,
        SEQ_BLOCK,
        max_seq_len,
        num_stages=2,
    )
1933
+
1934
+
1935
@torch.library.custom_op("attention::fused_mha_with_cache", mutates_args=())
def fused_mha_with_cache(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    input_pos: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    freqs_cis: Optional[torch.Tensor],
) -> torch.Tensor:
    """Fused MHA with cache that takes raw input from q, k, v GEMMs.

    q/k/v arrive as [b, s, n*h_d]; rope (if ``freqs_cis`` is given) and cache
    update happen inside. Returns [b, s, n*h_d].
    """
    # b, s info
    b, s = q.shape[:2]
    head_dim = k_cache.shape[-1]

    # reshapes with num_heads and head_dim
    q = q.view(b, s, -1, head_dim)
    k = k.view(b, s, -1, head_dim)
    v = v.view(b, s, -1, head_dim)

    # rope embedding
    if freqs_cis is not None:
        q = torch.ops.rope.apply_rope_with_input_pos(q, freqs_cis, input_pos, "bsnd")
        k = torch.ops.rope.apply_rope_with_input_pos(k, freqs_cis, input_pos, "bsnd")

    # attention (assumed layout is bsnd)
    y = torch.empty_like(q)
    if s > 1:
        # context phase
        _context_mha(q, k, v, k_cache, v_cache, y)
    else:
        # generate phase; identity mapping: cache row i belongs to batch entry i
        cache_locs = torch.arange(0, b, device=q.device, dtype=torch.int32)
        _generate_mha(q, k, v, k_cache, v_cache, cache_locs, input_pos, y)

    return y.view(b, s, -1)  # [b,s,n*h_d]
1971
+
1972
+
1973
@fused_mha_with_cache.register_fake
def fused_mha_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    input_pos: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    freqs_cis: torch.Tensor,
):
    """Fake (meta) implementation: the op's output has q's [b, s, n*h_d] shape."""
    return torch.empty_like(q.contiguous())
1984
+
1985
+
1986
def _flattened_context_mha(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    input_pos: torch.Tensor,
    cache_loc: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    seq_len: torch.Tensor,
    seq_start: torch.Tensor,
    out: torch.Tensor,
) -> None:
    """Mixed context+generate attention on flattened [s_total, n, d] inputs.

    Per-sequence boundaries come from ``seq_len``/``seq_start``; k/v are first
    written into the cache rows chosen by ``cache_loc`` at offset ``input_pos``,
    then attention is computed into ``out`` in-place.
    """
    # NOTE: s_total == sum(seq_len)
    s_total, n_heads, q_d_head = q.shape
    max_cache_seq_len, n_kv_heads = k_cache.shape[1:3]
    v_d_head = v.shape[-1]
    BATCH_SIZE: int = len(input_pos)
    SEQ_BLOCK = 32
    # fill the cache with the new tokens of every sequence
    (
        update_kv_cache[(BATCH_SIZE, n_kv_heads, (max(seq_len) + SEQ_BLOCK - 1) // SEQ_BLOCK)](
            k,
            v,
            seq_len,
            seq_start,
            k_cache,
            v_cache,
            input_pos,
            cache_loc,
            max_cache_seq_len,
            n_kv_heads,
            q_d_head,
            v_d_head,
            32,
            GENERATE_ONLY=False,
        ),
    )
    # TODO: use input_pos to get the correct cache locations
    softmax_scale = 1.0 / math.sqrt(q_d_head)
    grid = (BATCH_SIZE, n_heads, (max(seq_len) + SEQ_BLOCK - 1) // SEQ_BLOCK)
    context_attention_kv_flattened[grid](
        q,
        seq_len,
        seq_start,
        k_cache,
        v_cache,
        input_pos,
        cache_loc,
        out,
        softmax_scale,
        n_heads,
        n_kv_heads,
        q_d_head,
        v_d_head,
        SEQ_BLOCK,
        max_cache_seq_len,
        num_stages=2,
    )
2043
+
2044
+
2045
@torch.library.custom_op("attention::fused_flattened_mha_with_cache", mutates_args=())
def fused_flattened_mha_with_cache(
    # Q, K, V
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    # METADATA
    seq_len: torch.Tensor,
    input_pos: torch.Tensor,
    cache_loc: torch.Tensor,
    seq_start: torch.Tensor,
    # CACHES
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    # BUFFERS
    freqs_cis: torch.Tensor,
    # CONSTANTS
    # <none>
) -> torch.Tensor:
    """Flattened & fused MHA with cache that takes raw input from q, k, v GEMMs.

    NOTE: this op can also handle seq_len==0, which might be useful for CUDAGRAPH.
    """
    # b, s info
    # NOTE: b, s are just the shapes of the input tensor q; not necessarily the number of sequences.
    # Generally speaking, we expect one of two cases here:
    # 1. b > 0, s==1: this indicates a generate-only batch of tokens.
    # 2. b==1, s > 0: this indicates a mixed context+generate phase. The actual number of sequences
    #    and number of tokens per sequence are encoded in seq_len and seq_start.
    head_dim = k_cache.shape[-1]
    b, s, d = q.shape

    # reshapes with num_heads and head_dim
    if s == 1:
        bs_view = (b, s)
    else:
        bs_view = (b * s,)
    q = q.view(*bs_view, q.shape[2] // head_dim, head_dim)
    k = k.view(*bs_view, k.shape[2] // head_dim, head_dim)
    v = v.view(*bs_view, v.shape[2] // head_dim, head_dim)

    # rope embedding for generate-only or mixed
    if freqs_cis is not None and freqs_cis.numel() > 0:
        if s == 1:
            rope_args = (freqs_cis, input_pos, "bsnd")
            fn_rope = torch.ops.rope.apply_rope_with_input_pos
        else:
            rope_args = (freqs_cis, input_pos, seq_len, seq_start)
            fn_rope = torch.ops.rope.apply_rope_on_flattened_inputs
        q = fn_rope(q, *rope_args)
        k = fn_rope(k, *rope_args)

    # run attention
    y = torch.empty_like(q)
    if s == 1:
        # generate-only phase
        _generate_mha(q, k, v, k_cache, v_cache, cache_loc, input_pos, y)
    else:
        # mixed context + generate phase
        _flattened_context_mha(
            q,
            k,
            v,
            input_pos,
            cache_loc,
            k_cache,
            v_cache,
            seq_len,
            seq_start,
            y,
        )

    return y.view(b, s, d)  # [b,s,n*h_d]
2118
+
2119
+
2120
@fused_flattened_mha_with_cache.register_fake
def fused_flattened_mha_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    seq_len: torch.Tensor,
    input_pos: torch.Tensor,
    cache_loc: torch.Tensor,
    seq_start: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    freqs_cis: torch.Tensor,
):
    """Fake (meta) implementation: the op's output has q's [b, s, n*h_d] shape."""
    return torch.empty_like(q.contiguous())
2134
+
2135
+
2136
def _generate_mha_rope_fusion(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    freqs_cis: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    cache_locs: torch.Tensor,
    input_pos: torch.Tensor,
    out: torch.Tensor,
):
    """Decode-phase attention with rope fused into the kernels.

    The cache-update kernel applies rope to both q (into ``q_rope``) and k while
    appending k/v at ``input_pos``; the usual two-stage split-K attention then
    runs on the roped tensors, writing into ``out`` in-place.

    NOTE(review): assumes q and v share the same head dim (d_head is used for
    both) -- confirm for models with differing k/v head dims.
    """
    b, (n_heads, d_head) = q.shape[0], q.shape[-2:]
    max_seq_len, n_kv_heads = k_cache.shape[1:3]
    device = q.device

    SEQ_BLOCK_SIZE = 64
    num_blocks = (max_seq_len + SEQ_BLOCK_SIZE - 1) // SEQ_BLOCK_SIZE
    # stage-1 scratch: per-block partial values and logsumexp (-inf so empty
    # blocks drop out of the stage-2 reduction)
    stage1_output_values = torch.empty(
        b, n_heads, num_blocks, d_head, device=device, dtype=torch.float32
    )
    stage1_output_logsumexp = torch.empty(
        b, n_heads, num_blocks, device=device, dtype=torch.float32
    ) - float("inf")
    q_rope = torch.empty_like(q)
    # head block covers one full GQA group; computed once (the original
    # recomputed the identical value a second time below -- removed).
    HEAD_BLOCK_SIZE = max(16, triton.next_power_of_2(n_heads // n_kv_heads))

    # rope q/k and append the new k/v token per sequence into the cache
    (
        update_kv_cache_rope_fusion[(b, n_kv_heads, 1)](
            q,
            k,
            v,
            None,
            None,
            q_rope,
            k_cache,
            v_cache,
            input_pos,
            cache_locs,
            freqs_cis,
            max_seq_len,
            n_heads,
            n_kv_heads,
            d_head,
            1,
            HEAD_BLOCK_SIZE,
            GENERATE_ONLY=True,
        ),
    )

    # stage 1: attend roped q against each SEQ_BLOCK_SIZE slice of the cache
    gqa_attention_kv_stage1[
        (
            b,
            n_kv_heads,
            num_blocks,
        )
    ](
        q_rope,
        k_cache,
        v_cache,
        cache_locs,
        input_pos,
        stage1_output_values,
        stage1_output_logsumexp,
        num_blocks,
        max_seq_len,
        n_heads,
        n_kv_heads,
        d_head,
        d_head,
        SEQ_BLOCK_SIZE,
        HEAD_BLOCK_SIZE,
    )
    # stage 2: combine the per-block partials via their logsumexp weights
    attention_kv_stage2[(b, n_heads, 1)](
        stage1_output_values,
        stage1_output_logsumexp,
        out,
        input_pos,
        num_blocks,
        n_heads,
        d_head,
        SEQ_BLOCK_SIZE,
    )
2219
+
2220
+
2221
def _flattened_context_mha_rope_fusion(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    freqs_cis: torch.Tensor,
    input_pos: torch.Tensor,
    cache_loc: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    seq_len: torch.Tensor,
    seq_start: torch.Tensor,
    out: torch.Tensor,
) -> None:
    """Mixed context+generate attention on flattened inputs with rope fused in.

    The cache-update kernel ropes q (into ``q_rope``) and k while writing k/v
    into the cache; flattened context attention then runs on the roped q,
    writing into ``out`` in-place.
    """
    # NOTE: s_total == sum(seq_len)
    s_total, n_heads, d_head = q.shape
    max_cache_seq_len, n_kv_heads = k_cache.shape[1:3]
    BATCH_SIZE: int = len(input_pos)
    SEQ_BLOCK = 32
    q_rope = torch.empty_like(q)
    # head block covers one full GQA group (n_heads // n_kv_heads query heads)
    HEAD_BLOCK_SIZE = max(16, triton.next_power_of_2(n_heads // n_kv_heads))
    (
        update_kv_cache_rope_fusion[
            (BATCH_SIZE, n_kv_heads, (max(seq_len) + SEQ_BLOCK - 1) // SEQ_BLOCK)
        ](
            q,
            k,
            v,
            seq_len,
            seq_start,
            q_rope,
            k_cache,
            v_cache,
            input_pos,
            cache_loc,
            freqs_cis,
            max_cache_seq_len,
            n_heads,
            n_kv_heads,
            d_head,
            32,
            HEAD_BLOCK_SIZE,
            GENERATE_ONLY=False,
        ),
    )
    # TODO: use input_pos to get the correct cache locations
    softmax_scale = 1.0 / math.sqrt(d_head)
    grid = (BATCH_SIZE, n_heads, (max(seq_len) + SEQ_BLOCK - 1) // SEQ_BLOCK)
    context_attention_kv_flattened[grid](
        q_rope,
        seq_len,
        seq_start,
        k_cache,
        v_cache,
        input_pos,
        cache_loc,
        out,
        softmax_scale,
        n_heads,
        n_kv_heads,
        d_head,
        d_head,
        SEQ_BLOCK,
        max_cache_seq_len,
        num_stages=2,
    )
2286
+
2287
+
2288
@torch.library.custom_op("attention::fused_flattened_mha_with_cache_rope_fusion", mutates_args=())
def fused_flattened_mha_with_cache_rope_fusion(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    input_pos: torch.Tensor,
    cache_loc: torch.Tensor,
    seq_len: torch.Tensor,
    seq_start: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    freqs_cis: Optional[torch.Tensor],
) -> torch.Tensor:
    """Flattened & fused MHA with cache that takes raw input from q, k, v GEMMs.

    Fuses k rope in update_kv_cache and q rope in attention.
    NOTE: this op can also handle seq_len==0, which might be useful for CUDAGRAPH.

    Note that this op's metadata order (input_pos, cache_loc, seq_len, seq_start)
    differs from fused_flattened_mha_with_cache's (seq_len, input_pos, cache_loc,
    seq_start).
    """
    # this op only handles requests with rope embedding; otherwise fall back.
    if freqs_cis is None:
        # BUG FIX: the fallback previously forwarded the metadata tensors in
        # *this* op's order (input_pos, cache_loc, seq_len, seq_start), silently
        # swapping them against fused_flattened_mha_with_cache's signature
        # (seq_len, input_pos, cache_loc, seq_start). Reordered to match.
        return fused_flattened_mha_with_cache(
            q,
            k,
            v,
            seq_len,
            input_pos,
            cache_loc,
            seq_start,
            k_cache,
            v_cache,
            freqs_cis,
        )

    # b, s info
    # NOTE: b, s are just the shapes of the input tensor q; not necessarily the number of sequences.
    # Generally speaking, we expect one of two cases here:
    # 1. b > 0, s==1: this indicates a generate-only batch of tokens.
    # 2. b==1, s > 0: this indicates a mixed context+generate phase. The actual number of sequences
    #    and number of tokens per sequence are encoded in seq_len and seq_start.
    b, s, d = q.shape
    head_dim = k_cache.shape[-1]

    # reshapes with num_heads and head_dim
    if s == 1:
        bs_view = (b, s)
    else:
        bs_view = (b * s,)
    q = q.view(*bs_view, q.shape[2] // head_dim, head_dim)
    k = k.view(*bs_view, k.shape[2] // head_dim, head_dim)
    v = v.view(*bs_view, v.shape[2] // head_dim, head_dim)

    # run attention (rope is applied inside the fused kernels)
    y = torch.empty_like(q)
    if s == 1:
        # generate-only phase
        _generate_mha_rope_fusion(q, k, v, freqs_cis, k_cache, v_cache, cache_loc, input_pos, y)
    else:
        # mixed context + generate phase
        _flattened_context_mha_rope_fusion(
            q,
            k,
            v,
            freqs_cis,
            input_pos,
            cache_loc,
            k_cache,
            v_cache,
            seq_len,
            seq_start,
            y,
        )

    return y.view(b, s, d)  # [b,s,n*h_d]
2361
+
2362
+
2363
@fused_flattened_mha_with_cache_rope_fusion.register_fake
def fused_flattened_mha_with_cache_rope_fusion_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    input_pos: torch.Tensor,
    cache_loc: torch.Tensor,
    seq_len: torch.Tensor,
    seq_start: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    freqs_cis: torch.Tensor,
):
    """Fake (meta) implementation: the op's output has q's [b, s, n*h_d] shape."""
    return torch.empty_like(q.contiguous())
2377
+
2378
+
2379
def _paged_generate_mha(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    page_table: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    cache_loc: torch.Tensor,
    input_pos: torch.Tensor,
    out: torch.Tensor,
    max_seq_len: int,
):
    """Decode-phase attention against a paged KV cache.

    Appends the new k/v token via the page table, then runs two-stage split-K
    attention over the pages, writing into ``out`` in-place. ``max_seq_len`` is
    passed explicitly since the paged cache shape no longer encodes it.
    """
    b, (n_heads, d_head) = q.shape[0], q.shape[-2:]
    PAGE_SIZE, n_kv_heads = k_cache.shape[1:3]
    device = q.device

    # one seq block per cache page so a program never straddles pages
    SEQ_BLOCK_SIZE = PAGE_SIZE  # 256
    num_blocks = (max_seq_len + SEQ_BLOCK_SIZE - 1) // SEQ_BLOCK_SIZE
    # stage-1 scratch: per-block partial values and logsumexp (-inf so empty
    # blocks drop out of the stage-2 reduction)
    stage1_output_values = torch.empty(
        b, n_heads, num_blocks, d_head, device=device, dtype=torch.float32
    )
    stage1_output_logsumexp = torch.empty(
        b, n_heads, num_blocks, device=device, dtype=torch.float32
    ) - float("inf")

    # write the single new k/v token per sequence into its page
    (
        update_paged_kv_cache[(b, n_kv_heads, 1)](
            k,
            v,
            None,
            None,
            k_cache,
            v_cache,
            cache_loc,
            input_pos,
            page_table,
            n_kv_heads,
            d_head,
            SEQ_BLOCK_SIZE,
            max_seq_len,
            PAGE_SIZE,
            page_table.stride(0),
            GENERATE_ONLY=True,
        ),
    )

    # stage 1: attend q against each page of the sequence
    attention_kv_paged_stage1[
        (
            b,
            n_heads,
            num_blocks,
        )
    ](
        q,
        k_cache,
        v_cache,
        cache_loc,
        page_table,
        input_pos,
        stage1_output_values,
        stage1_output_logsumexp,
        num_blocks,
        max_seq_len,
        n_heads,
        n_kv_heads,
        d_head,
        SEQ_BLOCK_SIZE,
        PAGE_SIZE,
        page_table.stride(0),
    )
    # stage 2: combine the per-block partials via their logsumexp weights
    attention_kv_stage2[(b, n_heads, 1)](
        stage1_output_values,
        stage1_output_logsumexp,
        out,
        input_pos,
        num_blocks,
        n_heads,
        d_head,
        SEQ_BLOCK_SIZE,
    )
2459
+
2460
+
2461
def _paged_context_mha(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    input_pos: torch.Tensor,
    cache_loc: torch.Tensor,
    page_table: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    seq_len: torch.Tensor,
    seq_start: torch.Tensor,
    out: torch.Tensor,
    max_seq_len: int,  # max cache length of sequence, kv_cache shape don't provide this info.
) -> None:
    """Mixed context+generate attention on flattened inputs with a paged cache.

    Writes k/v into pages resolved through ``page_table``, then runs flattened
    paged context attention into ``out`` in-place.
    """
    # NOTE: s_total == sum(seq_len)
    s_total, n_heads, d_head = q.shape
    PAGE_SIZE, n_kv_heads = k_cache.shape[1:3]
    BATCH_SIZE = len(input_pos)
    # one seq block per cache page so a program never straddles pages
    SEQ_BLOCK = PAGE_SIZE  # 32
    (
        update_paged_kv_cache[
            (BATCH_SIZE, n_kv_heads, (max(seq_len) + SEQ_BLOCK - 1) // SEQ_BLOCK)
        ](
            k,
            v,
            seq_len,
            seq_start,
            k_cache,
            v_cache,
            cache_loc,
            input_pos,
            page_table,
            n_kv_heads,
            d_head,
            SEQ_BLOCK,
            max_seq_len,
            PAGE_SIZE,
            page_table.stride(0),
            GENERATE_ONLY=False,
        ),
    )
    # standard 1/sqrt(d) attention scaling
    softmax_scale = 1.0 / math.sqrt(d_head)
    grid = (BATCH_SIZE, n_heads, (max(seq_len) + SEQ_BLOCK - 1) // SEQ_BLOCK)
    context_attention_kv_paged[grid](
        q,
        seq_len,
        seq_start,
        k_cache,
        v_cache,
        cache_loc,
        input_pos,
        page_table,
        softmax_scale,
        out,
        n_heads,
        n_kv_heads,
        d_head,
        SEQ_BLOCK,
        max_seq_len,
        PAGE_SIZE,
        page_table.stride(0),
        num_stages=2,
    )
2524
+
2525
+
2526
@torch.library.custom_op("attention::fused_mha_with_paged_cache", mutates_args=())
def fused_mha_with_paged_cache(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    input_pos: torch.Tensor,
    cache_loc: torch.Tensor,
    seq_len: torch.Tensor,
    seq_start: torch.Tensor,
    page_table: torch.Tensor,
    max_seq_len: int,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    freqs_cis: Optional[torch.Tensor],
) -> torch.Tensor:
    """Fused MHA with paged cache that takes raw input from q, k, v GEMMs.

    NOTE: this op can also handle seq_len==0, which might be useful for CUDAGRAPH.
    """
    # NOTE(review): the helpers below write into k_cache/v_cache in place, yet
    # mutates_args=() declares no mutation — confirm this is intentional (e.g.
    # caches treated as opaque persistent buffers for tracing/compile).
    # b, s info
    # NOTE: b, s are just the shapes of the input tensor q; not necessarily the number of sequences.
    # Generally speaking, we expect one of two cases here:
    # 1. b > 0, s==1: this indicates a generate-only batch of tokens.
    # 2. b==1, s > 0: this indicates a mixed context+generate phase. The actual number of sequences
    #    and number of tokens per sequence are encoded in seq_len and seq_start.
    # Assuming that context seq_len always > 0.
    b, s, d = q.shape
    # head_dim comes from the cache's last dim; q/k/v arrive flat as [b, s, n*head_dim].
    head_dim = k_cache.shape[-1]

    # reshapes with num_heads and head_dim
    # Generate-only keeps a [b, s] leading shape; mixed phase flattens to [b*s].
    if s == 1:
        bs_view = (b, s)
    else:
        bs_view = (b * s,)
    q = q.view(*bs_view, q.shape[2] // head_dim, head_dim)
    k = k.view(*bs_view, k.shape[2] // head_dim, head_dim)
    v = v.view(*bs_view, v.shape[2] // head_dim, head_dim)

    # rope embedding for generate-only or mixed
    # The two rope ops expect different position metadata; pick by phase.
    if freqs_cis is not None:
        if s == 1:
            rope_args = (freqs_cis, input_pos, "bsnd")
            fn_rope = torch.ops.rope.apply_rope_with_input_pos
        else:
            rope_args = (freqs_cis, input_pos, seq_len, seq_start)
            fn_rope = torch.ops.rope.apply_rope_on_flattened_inputs
        q = fn_rope(q, *rope_args)
        k = fn_rope(k, *rope_args)

    # run attention
    y = torch.empty_like(q)
    if s == 1:
        # generate-only phase
        _paged_generate_mha(
            q, k, v, page_table, k_cache, v_cache, cache_loc, input_pos, y, max_seq_len
        )
    else:
        # mixed context + generate phase
        _paged_context_mha(
            q,
            k,
            v,
            input_pos,
            cache_loc,
            page_table,
            k_cache,
            v_cache,
            seq_len,
            seq_start,
            y,
            max_seq_len,
        )

    return y.view(b, s, d)  # [b,s,n*h_d]
2600
+
2601
+
2602
@fused_mha_with_paged_cache.register_fake
def fused_mha_with_paged_cache_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    input_pos: torch.Tensor,
    cache_loc: torch.Tensor,
    seq_len: torch.Tensor,
    seq_start: torch.Tensor,
    page_table: torch.Tensor,
    max_seq_len: int,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    freqs_cis: Optional[torch.Tensor],
) -> torch.Tensor:
    """Fake (meta) implementation: output matches q's shape/dtype; no compute."""
    q_template = q.contiguous()
    return torch.empty_like(q_template)
2618
+
2619
+
2620
@torch.library.custom_op("attention::prepare_fused_mha_metadata", mutates_args=())
def prepare_fused_mha_metadata(
    input_ids: torch.Tensor,
    seq_len: torch.Tensor,
    input_pos: torch.Tensor,
    cache_loc: torch.Tensor,
    pages_per_seq: torch.Tensor,
    page_size: int,
) -> List[torch.Tensor]:
    """Sanitize per-sequence metadata for the fused MHA ops.

    Truncates seq_len/input_pos/cache_loc to the active number of sequences and
    computes each sequence's start offset in the flattened token dimension.
    """
    num_seq = SequenceInfo._get_sanitized_num_sequences(input_ids, seq_len)
    active_lens = seq_len[:num_seq]
    # Exclusive prefix sum: sequence i starts after all tokens of sequences < i.
    starts = torch.zeros_like(active_lens)
    starts[1:] = torch.cumsum(seq_len[: num_seq - 1], 0)
    return (
        active_lens.clone(),
        input_pos[:num_seq].clone(),
        cache_loc[:num_seq].clone(),
        starts,
    )
2638
+
2639
+
2640
@prepare_fused_mha_metadata.register_fake
def prepare_fused_mha_metadata_fake(
    input_ids, seq_len, input_pos, cache_loc, pages_per_seq, page_size
):
    """Fake (meta) implementation mirroring the real op's four outputs."""
    templates = (seq_len, input_pos, cache_loc, seq_len)
    return tuple(torch.empty_like(t) for t in templates)
2650
+
2651
+
2652
@AttentionRegistry.register("TritonWithFlattenedInputs")
class TritonWithFlattenedInputs(AttentionDescriptor):
    """Attention backend descriptor wiring the Triton flattened-input MHA op
    into the AttentionRegistry (non-paged cache variant)."""

    @classmethod
    def is_paged(cls):
        """Return if the attention op is paged or not."""
        return False

    @classmethod
    def get_attention_op(cls):
        # Returns (op, count); presumably count is the number of qkv-style
        # tensor inputs consumed by the op — TODO confirm against registry contract.
        return torch.ops.attention.fused_flattened_mha_with_cache, 3

    @classmethod
    def get_prepare_metadata_op(cls):
        # (op, count); presumably the number of metadata tensors the op returns
        # — TODO confirm against registry contract.
        return torch.ops.attention.prepare_fused_mha_metadata, 4

    @classmethod
    def get_cache_initializers(cls, get_info):
        # Both k and v caches share one initializer; shape is
        # [num_pages, page_size, n_kv_heads, head_dim] even though this
        # descriptor is non-paged (si.is_paged is asserted False below).
        def _get_cache(si: SequenceInfo):
            assert not si.is_paged, "Paged cache not supported for TritonWithFlattenedInputs"
            attention_info = get_info()
            return torch.empty(
                si.num_pages,
                si.page_size,
                attention_info.num_kv_heads,
                attention_info.head_dim,
                device=si.device,
                # cache dtype override falls back to the model dtype
                dtype=attention_info.cache_config.dtype or attention_info.dtype,
            )

        return {"k_cache": _get_cache, "v_cache": _get_cache}

    @classmethod
    def get_global_buffer_initializers(cls, get_info):
        attention_info = get_info()
        head_dim = attention_info.head_dim
        pos_embd_config = attention_info.pos_embd_config

        def _get_freqs_cis(si: SequenceInfo):
            # No positional embedding -> empty placeholder buffer.
            if pos_embd_config.mode is None:
                return torch.empty(0, device=si.device)
            assert pos_embd_config.mode == "rope", f"Mode {pos_embd_config.mode=} not supported"
            assert pos_embd_config.rope_scale == 1.0, f"{pos_embd_config.rope_scale=} not supported"
            rope_theta = pos_embd_config.rope_theta
            # 2x max_seq_len of precomputed frequencies — presumably headroom
            # for cache positions beyond the nominal max; confirm with callers.
            return cls._precompute_freqs_cis(2 * si.max_seq_len, head_dim, rope_theta).to(si.device)

        # Buffer key encodes the full pos-embd config so distinct configs get
        # distinct buffers; '.' replaced since it would break the key format.
        k_full = "_".join(map(str, ["freqs_cis", *astuple(pos_embd_config)])).replace(".", "_")
        return {k_full: _get_freqs_cis}

    @staticmethod
    def _precompute_freqs_cis(
        seq_len: int, head_dim: int, rope_theta: Optional[float] = None
    ) -> torch.Tensor:
        """Precompute RoPE cos/sin table of shape [seq_len, head_dim//2, 2] (fp16)."""
        if rope_theta is None:
            rope_theta = 1e4
        # Standard RoPE inverse frequencies: theta^(-2i/d) for i in [0, d/2).
        freqs = 1.0 / (
            rope_theta ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim)
        )
        t = torch.arange(seq_len)
        # Outer product: angle per (position, frequency) pair.
        freqs = torch.outer(t, freqs)
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
        # cos and sin (real and img) are packed
        cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
        return cache.to(dtype=torch.float16)