zaydzuhri commited on
Commit
9b1e3d4
·
verified ·
1 Parent(s): 4ffecd9

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. fla/models/abc/modeling_abc.py +418 -0
  2. fla/models/gated_deltanet/__init__.py +12 -0
  3. fla/models/mamba/__init__.py +13 -0
  4. fla/models/mamba2/__init__.py +13 -0
  5. fla/models/rwkv7/__init__.py +13 -0
  6. fla/models/rwkv7/modeling_rwkv7.py +505 -0
  7. fla/ops/abc/__pycache__/chunk.cpython-312.pyc +0 -0
  8. fla/ops/attn/__pycache__/parallel.cpython-312.pyc +0 -0
  9. fla/ops/based/__pycache__/parallel.cpython-312.pyc +0 -0
  10. fla/ops/based/fused_chunk.py +374 -0
  11. fla/ops/common/__pycache__/chunk_delta_h.cpython-312.pyc +0 -0
  12. fla/ops/common/__pycache__/chunk_scaled_dot_kkt.cpython-312.pyc +0 -0
  13. fla/ops/common/__pycache__/utils.cpython-312.pyc +0 -0
  14. fla/ops/delta_rule/__pycache__/__init__.cpython-312.pyc +0 -0
  15. fla/ops/delta_rule/wy_fast.py +340 -0
  16. fla/ops/gated_delta_rule/__pycache__/__init__.cpython-312.pyc +0 -0
  17. fla/ops/gated_delta_rule/__pycache__/chunk.cpython-312.pyc +0 -0
  18. fla/ops/gated_delta_rule/__pycache__/wy_fast.cpython-312.pyc +0 -0
  19. fla/ops/gated_delta_rule/wy_fast.py +620 -0
  20. fla/ops/generalized_delta_rule/dplr/__init__.py +7 -0
  21. fla/ops/generalized_delta_rule/dplr/__pycache__/__init__.cpython-312.pyc +0 -0
  22. fla/ops/generalized_delta_rule/dplr/__pycache__/chunk.cpython-312.pyc +0 -0
  23. fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_bwd.cpython-312.pyc +0 -0
  24. fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_fwd.cpython-312.pyc +0 -0
  25. fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_h_fwd.cpython-312.pyc +0 -0
  26. fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_fwd.cpython-312.pyc +0 -0
  27. fla/ops/generalized_delta_rule/dplr/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
  28. fla/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_bwd.cpython-312.pyc +0 -0
  29. fla/ops/generalized_delta_rule/dplr/chunk_A_fwd.py +324 -0
  30. fla/ops/generalized_delta_rule/dplr/chunk_h_fwd.py +197 -0
  31. fla/ops/generalized_delta_rule/dplr/chunk_o_bwd.py +464 -0
  32. fla/ops/generalized_delta_rule/dplr/chunk_o_fwd.py +138 -0
  33. fla/ops/generalized_delta_rule/dplr/fused_recurrent.py +292 -0
  34. fla/ops/generalized_delta_rule/dplr/naive.py +96 -0
  35. fla/ops/generalized_delta_rule/dplr/wy_fast_bwd.py +184 -0
  36. fla/ops/generalized_delta_rule/dplr/wy_fast_fwd.py +318 -0
  37. fla/ops/generalized_delta_rule/iplr/__init__.py +7 -0
  38. fla/ops/generalized_delta_rule/iplr/__pycache__/__init__.cpython-312.pyc +0 -0
  39. fla/ops/generalized_delta_rule/iplr/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
  40. fla/ops/generalized_delta_rule/iplr/chunk.py +528 -0
  41. fla/ops/generalized_delta_rule/iplr/wy_fast.py +338 -0
  42. fla/ops/gla/__pycache__/__init__.cpython-312.pyc +0 -0
  43. fla/ops/gla/__pycache__/chunk.cpython-312.pyc +0 -0
  44. fla/ops/gla/__pycache__/fused_chunk.cpython-312.pyc +0 -0
  45. fla/ops/gsa/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
  46. fla/ops/hgrn/__pycache__/__init__.cpython-312.pyc +0 -0
  47. fla/ops/lightning_attn/__pycache__/__init__.cpython-312.pyc +0 -0
  48. fla/ops/linear_attn/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
  49. fla/ops/linear_attn/__pycache__/utils.cpython-312.pyc +0 -0
  50. fla/ops/nsa/__pycache__/__init__.cpython-312.pyc +0 -0
fla/models/abc/modeling_abc.py ADDED
@@ -0,0 +1,418 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import logging
16
+ from transformers.utils.deprecation import deprecate_kwarg
17
+
18
+ from fla.layers.abc import ABCAttention
19
+ from fla.layers.attn import Attention
20
+ from fla.models.abc.configuration_abc import ABCConfig
21
+ from fla.models.utils import Cache
22
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
23
+ from fla.modules import GatedMLP as ABCMLP
24
+ from fla.modules import RMSNorm
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+ if TYPE_CHECKING:
29
+ from transformers.processing_utils import Unpack
30
+
31
+
32
+ class ABCBlock(nn.Module):
33
+ def __init__(self, config: ABCConfig, layer_idx: int):
34
+ super().__init__()
35
+
36
+ self.config = config
37
+ self.layer_idx = layer_idx
38
+
39
+ self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
40
+ if config.attn is not None and layer_idx in config.attn['layers']:
41
+ self.attn = Attention(
42
+ hidden_size=config.hidden_size,
43
+ num_heads=config.attn['num_heads'],
44
+ num_kv_heads=config.attn['num_kv_heads'],
45
+ qkv_bias=config.attn['qkv_bias'],
46
+ window_size=config.attn['window_size'],
47
+ rope_theta=config.attn['rope_theta'],
48
+ max_position_embeddings=config.max_position_embeddings,
49
+ layer_idx=layer_idx
50
+ )
51
+ else:
52
+ self.attn = ABCAttention(
53
+ hidden_size=config.hidden_size,
54
+ expand_k=config.expand_k,
55
+ expand_v=config.expand_v,
56
+ num_heads=config.num_heads,
57
+ num_slots=config.num_slots,
58
+ use_short_conv=config.use_short_conv,
59
+ conv_size=config.conv_size,
60
+ gate_fn=config.hidden_act,
61
+ elementwise_affine=config.elementwise_affine,
62
+ norm_eps=config.norm_eps,
63
+ use_rope=config.use_rope,
64
+ clamp_min=config.clamp_min,
65
+ clamp_max=config.clamp_max,
66
+ fuse_norm=config.fuse_norm,
67
+ layer_idx=layer_idx
68
+ )
69
+ self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
70
+ self.mlp = ABCMLP(
71
+ hidden_size=config.hidden_size,
72
+ hidden_ratio=config.hidden_ratio,
73
+ intermediate_size=config.intermediate_size,
74
+ hidden_act=config.hidden_act,
75
+ fuse_swiglu=config.fuse_swiglu
76
+ )
77
+
78
+ def forward(
79
+ self,
80
+ hidden_states: torch.Tensor,
81
+ attention_mask: Optional[torch.Tensor] = None,
82
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
83
+ use_cache: Optional[bool] = False,
84
+ output_attentions: Optional[bool] = False,
85
+ **kwargs: Unpack[Dict]
86
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
87
+
88
+ residual = hidden_states
89
+
90
+ hidden_states = self.attn_norm(hidden_states)
91
+ hidden_states, attentions, past_key_values = self.attn(
92
+ hidden_states=hidden_states,
93
+ attention_mask=attention_mask,
94
+ past_key_values=past_key_values,
95
+ use_cache=use_cache,
96
+ output_attentions=output_attentions,
97
+ **kwargs
98
+ )
99
+ if self.config.fuse_norm:
100
+ hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
101
+ else:
102
+ hidden_states = residual + hidden_states
103
+ residual = hidden_states
104
+ hidden_states = self.mlp_norm(hidden_states)
105
+ hidden_states = self.mlp(hidden_states)
106
+ hidden_states = residual + hidden_states
107
+
108
+ outputs = (hidden_states, attentions, past_key_values)
109
+
110
+ return outputs
111
+
112
+
113
+ class ABCPreTrainedModel(PreTrainedModel):
114
+
115
+ config_class = ABCConfig
116
+ base_model_prefix = 'model'
117
+ supports_gradient_checkpointing = True
118
+ _no_split_modules = ['ABCBlock']
119
+ _supports_cache_class = True
120
+
121
+ def __init__(self, *inputs, **kwargs):
122
+ super().__init__(*inputs, **kwargs)
123
+
124
+ def _init_weights(
125
+ self,
126
+ module: nn.Module,
127
+ prenorm_residual_strategy: Optional[str] = 'rescale',
128
+ num_residuals_per_layer: int = 2,
129
+ ):
130
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
131
+ # Slightly different from the TF version which uses truncated_normal for initialization
132
+ # cf https://github.com/pytorch/pytorch/pull/5617
133
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
134
+ if module.bias is not None:
135
+ nn.init.zeros_(module.bias)
136
+ elif isinstance(module, nn.Embedding):
137
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
138
+ elif hasattr(module, 'reset_parameters'):
139
+ module.reset_parameters()
140
+
141
+ if prenorm_residual_strategy is not None:
142
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
143
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
144
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
145
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
146
+ #
147
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
148
+ p = None
149
+ if hasattr(module, 'o_proj'):
150
+ p = module.o_proj.weight
151
+ elif hasattr(module, 'down_proj'):
152
+ p = module.down_proj.weight
153
+ if p is not None:
154
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
155
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
156
+ # We need to reinit p since this code could be called multiple times
157
+ # Having just p *= scale would repeatedly scale it down
158
+ if prenorm_residual_strategy == 'rescale':
159
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
160
+ with torch.no_grad():
161
+ p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
162
+ elif prenorm_residual_strategy == 'zero':
163
+ nn.init.zeros_(p)
164
+ else:
165
+ raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")
166
+
167
+
168
+ class ABCModel(ABCPreTrainedModel):
169
+
170
+ def __init__(self, config: ABCConfig):
171
+ super().__init__(config)
172
+ self.padding_idx = config.pad_token_id
173
+ self.vocab_size = config.vocab_size
174
+
175
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
176
+ self.layers = nn.ModuleList([ABCBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
177
+ self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
178
+
179
+ self.gradient_checkpointing = False
180
+
181
+ self.post_init()
182
+
183
+ def get_input_embeddings(self):
184
+ return self.embeddings
185
+
186
+ def set_input_embeddings(self, value):
187
+ self.embeddings = value
188
+
189
+ def forward(
190
+ self,
191
+ input_ids: Optional[torch.LongTensor] = None,
192
+ attention_mask: Optional[torch.Tensor] = None, # noqa
193
+ inputs_embeds: Optional[torch.FloatTensor] = None,
194
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
195
+ use_cache: Optional[bool] = None,
196
+ output_attentions: Optional[bool] = None,
197
+ output_hidden_states: Optional[bool] = None,
198
+ return_dict: Optional[bool] = None,
199
+ **kwargs: Unpack[Dict]
200
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
201
+ if output_attentions:
202
+ warnings.warn("`ABCModel` does not `output_attentions` now, setting it to `False`.")
203
+ output_attentions = False
204
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
205
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
206
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
207
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
208
+
209
+ # retrieve input_ids and inputs_embeds
210
+ if input_ids is not None and inputs_embeds is not None:
211
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
212
+ if input_ids is None and inputs_embeds is None:
213
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
214
+
215
+ if inputs_embeds is None:
216
+ inputs_embeds = self.embeddings(input_ids)
217
+ hidden_states = inputs_embeds
218
+
219
+ if use_cache and not isinstance(past_key_values, Cache):
220
+ past_key_values = Cache.from_legacy_cache(past_key_values)
221
+
222
+ if self.gradient_checkpointing and self.training and use_cache:
223
+ logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
224
+ use_cache = False
225
+
226
+ all_hidden_states = () if output_hidden_states else None
227
+ all_attns = () if output_attentions else None
228
+ for layer in self.layers:
229
+ if output_hidden_states:
230
+ all_hidden_states += (hidden_states,)
231
+
232
+ if self.gradient_checkpointing and self.training:
233
+ hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
234
+ layer.__call__,
235
+ hidden_states,
236
+ attention_mask,
237
+ past_key_values,
238
+ use_cache,
239
+ output_attentions,
240
+ **kwargs
241
+ )
242
+ else:
243
+ hidden_states, attentions, past_key_values = layer(
244
+ hidden_states,
245
+ attention_mask,
246
+ past_key_values=past_key_values,
247
+ use_cache=use_cache,
248
+ output_attentions=output_attentions,
249
+ **kwargs
250
+ )
251
+
252
+ if output_attentions:
253
+ all_attns += (attentions,)
254
+
255
+ hidden_states = self.norm(hidden_states)
256
+
257
+ # add hidden states from the last decoder layer
258
+ if output_hidden_states:
259
+ all_hidden_states += (hidden_states,)
260
+
261
+ if not return_dict:
262
+ return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
263
+ return BaseModelOutputWithPast(
264
+ last_hidden_state=hidden_states,
265
+ past_key_values=past_key_values,
266
+ hidden_states=all_hidden_states,
267
+ attentions=all_attns
268
+ )
269
+
270
+
271
+ class ABCForCausalLM(ABCPreTrainedModel, GenerationMixin):
272
+
273
+ _tied_weights_keys = ["lm_head.weight"]
274
+
275
+ def __init__(self, config):
276
+ super().__init__(config)
277
+ self.model = ABCModel(config)
278
+ self.vocab_size = config.vocab_size
279
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
280
+ self.criterion = None
281
+
282
+ # Initialize weights and apply final processing
283
+ self.post_init()
284
+
285
+ def get_input_embeddings(self):
286
+ return self.model.embeddings
287
+
288
+ def set_input_embeddings(self, value):
289
+ self.model.embeddings = value
290
+
291
+ def get_output_embeddings(self):
292
+ return self.lm_head
293
+
294
+ def set_output_embeddings(self, new_embeddings):
295
+ self.lm_head = new_embeddings
296
+
297
+ def set_decoder(self, decoder):
298
+ self.model = decoder
299
+
300
+ def get_decoder(self):
301
+ return self.model
302
+
303
+ def generate(self, *args, **kwargs):
304
+ try:
305
+ return super().generate(*args, **kwargs)
306
+ except AttributeError as exception:
307
+ if 'past_key_values' in str(exception):
308
+ raise AttributeError(
309
+ f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
310
+ f"which is not supported for {self.__class__.__name__}. "
311
+ f"Try another generation strategy instead. "
312
+ f"For the available generation strategies, check this doc: "
313
+ f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
314
+ )
315
+ else:
316
+ raise exception
317
+
318
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
319
+ def prepare_inputs_for_generation(
320
+ self,
321
+ input_ids: torch.LongTensor = None,
322
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
323
+ attention_mask: Optional[torch.Tensor] = None,
324
+ inputs_embeds: Optional[torch.Tensor] = None,
325
+ use_cache: bool = True,
326
+ logits_to_keep: Optional[int] = None,
327
+ **kwargs
328
+ ):
329
+ # only last token for `inputs_ids` if the `past_key_values` is not empty.
330
+ if past_key_values is not None and len(past_key_values) > 0:
331
+ input_ids = input_ids[:, -1:]
332
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
333
+ if inputs_embeds is not None and len(past_key_values) == 0:
334
+ model_inputs = {'inputs_embeds': inputs_embeds}
335
+ else:
336
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
337
+ # recompiles graphs as the stride of the inputs is a guard.
338
+ # Ref: https://github.com/huggingface/transformers/pull/29114
339
+ # TODO: use `next_tokens` directly instead.
340
+ model_inputs = {'input_ids': input_ids.contiguous()}
341
+
342
+ if logits_to_keep is not None:
343
+ model_inputs['logits_to_keep'] = logits_to_keep
344
+
345
+ model_inputs.update({
346
+ 'past_key_values': past_key_values,
347
+ 'use_cache': use_cache,
348
+ 'attention_mask': attention_mask,
349
+ })
350
+ return model_inputs
351
+
352
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
353
+ def forward(
354
+ self,
355
+ input_ids: torch.LongTensor = None,
356
+ attention_mask: Optional[torch.Tensor] = None,
357
+ inputs_embeds: Optional[torch.Tensor] = None,
358
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
359
+ labels: Optional[torch.LongTensor] = None,
360
+ use_cache: Optional[bool] = None,
361
+ output_attentions: Optional[bool] = None,
362
+ output_hidden_states: Optional[bool] = None,
363
+ return_dict: Optional[bool] = None,
364
+ logits_to_keep: Optional[int] = 0,
365
+ **kwargs: Unpack[Dict]
366
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
367
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
368
+ output_hidden_states = (
369
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
370
+ )
371
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
372
+
373
+ outputs = self.model(
374
+ input_ids=input_ids,
375
+ attention_mask=attention_mask,
376
+ inputs_embeds=inputs_embeds,
377
+ past_key_values=past_key_values,
378
+ use_cache=use_cache,
379
+ output_attentions=output_attentions,
380
+ output_hidden_states=output_hidden_states,
381
+ return_dict=return_dict,
382
+ **kwargs
383
+ )
384
+
385
+ hidden_states = outputs[0]
386
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
387
+
388
+ loss, logits = None, None
389
+ if not fuse_linear_and_cross_entropy or labels is None:
390
+ logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
391
+ if labels is not None:
392
+ if getattr(self, 'criterion', None) is None:
393
+ if fuse_linear_and_cross_entropy:
394
+ criterion = FusedLinearCrossEntropyLoss()
395
+ elif self.config.fuse_cross_entropy:
396
+ criterion = FusedCrossEntropyLoss(inplace_backward=True)
397
+ else:
398
+ criterion = nn.CrossEntropyLoss()
399
+ else:
400
+ criterion = self.criterion
401
+ labels = labels.to(hidden_states.device)
402
+ labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
403
+ if fuse_linear_and_cross_entropy:
404
+ loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
405
+ else:
406
+ loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
407
+
408
+ if not return_dict:
409
+ output = (logits,) + outputs[1:]
410
+ return (loss,) + output if loss is not None else output
411
+
412
+ return CausalLMOutputWithPast(
413
+ loss=loss,
414
+ logits=logits,
415
+ past_key_values=outputs.past_key_values,
416
+ hidden_states=outputs.hidden_states,
417
+ attentions=outputs.attentions,
418
+ )
fla/models/gated_deltanet/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.gated_deltanet.configuration_gated_deltanet import GatedDeltaNetConfig
6
+ from fla.models.gated_deltanet.modeling_gated_deltanet import GatedDeltaNetForCausalLM, GatedDeltaNetModel
7
+
8
+ AutoConfig.register(GatedDeltaNetConfig.model_type, GatedDeltaNetConfig)
9
+ AutoModel.register(GatedDeltaNetConfig, GatedDeltaNetModel)
10
+ AutoModelForCausalLM.register(GatedDeltaNetConfig, GatedDeltaNetForCausalLM)
11
+
12
+ __all__ = ['GatedDeltaNetConfig', 'GatedDeltaNetForCausalLM', 'GatedDeltaNetModel']
fla/models/mamba/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.mamba.configuration_mamba import MambaConfig
6
+ from fla.models.mamba.modeling_mamba import MambaBlock, MambaForCausalLM, MambaModel
7
+
8
+ AutoConfig.register(MambaConfig.model_type, MambaConfig, True)
9
+ AutoModel.register(MambaConfig, MambaModel, True)
10
+ AutoModelForCausalLM.register(MambaConfig, MambaForCausalLM, True)
11
+
12
+
13
+ __all__ = ['MambaConfig', 'MambaForCausalLM', 'MambaModel', 'MambaBlock']
fla/models/mamba2/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.mamba2.configuration_mamba2 import Mamba2Config
6
+ from fla.models.mamba2.modeling_mamba2 import Mamba2ForCausalLM, Mamba2Model
7
+
8
+ AutoConfig.register(Mamba2Config.model_type, Mamba2Config, True)
9
+ AutoModel.register(Mamba2Config, Mamba2Model, True)
10
+ AutoModelForCausalLM.register(Mamba2Config, Mamba2ForCausalLM, True)
11
+
12
+
13
+ __all__ = ['Mamba2Config', 'Mamba2ForCausalLM', 'Mamba2Model']
fla/models/rwkv7/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.rwkv7.configuration_rwkv7 import RWKV7Config
6
+ from fla.models.rwkv7.modeling_rwkv7 import RWKV7ForCausalLM, RWKV7Model
7
+
8
+ AutoConfig.register(RWKV7Config.model_type, RWKV7Config, True)
9
+ AutoModel.register(RWKV7Config, RWKV7Model, True)
10
+ AutoModelForCausalLM.register(RWKV7Config, RWKV7ForCausalLM, True)
11
+
12
+
13
+ __all__ = ['RWKV7Config', 'RWKV7ForCausalLM', 'RWKV7Model']
fla/models/rwkv7/modeling_rwkv7.py ADDED
@@ -0,0 +1,505 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import logging
16
+ from transformers.utils.deprecation import deprecate_kwarg
17
+
18
+ from fla.layers.attn import Attention
19
+ from fla.layers.rwkv7 import RWKV7Attention
20
+ from fla.models.rwkv7.configuration_rwkv7 import RWKV7Config
21
+ from fla.models.utils import Cache
22
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, LayerNorm
23
+ from fla.modules.activations import ACT2FN
24
+
25
+ if TYPE_CHECKING:
26
+ from transformers.processing_utils import Unpack
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
+ class RWKV7FeedForward(nn.Module):
32
+
33
+ def __init__(
34
+ self,
35
+ hidden_size: int,
36
+ hidden_ratio: Optional[int] = None,
37
+ intermediate_size: Optional[int] = None,
38
+ hidden_act: str = 'sqrelu',
39
+ layer_idx: int = None
40
+ ) -> RWKV7FeedForward:
41
+ super().__init__()
42
+
43
+ self.hidden_size = hidden_size
44
+ if hidden_ratio is None:
45
+ hidden_ratio = 4
46
+ if intermediate_size is None:
47
+ intermediate_size = int(hidden_size * hidden_ratio)
48
+ intermediate_size = 32 * ((intermediate_size + 32 - 1) // 32)
49
+ self.hidden_ratio = hidden_ratio
50
+ self.intermediate_size = intermediate_size
51
+
52
+ self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
53
+
54
+ self.x_k = nn.Parameter(torch.zeros(hidden_size))
55
+
56
+ self.key = nn.Linear(hidden_size, intermediate_size, bias=False)
57
+ self.value = nn.Linear(intermediate_size, hidden_size, bias=False)
58
+ self.act_fn = ACT2FN[hidden_act]
59
+
60
+ self.layer_idx = layer_idx
61
+
62
+ def forward(
63
+ self,
64
+ x: torch.Tensor,
65
+ attention_mask: Optional[torch.Tensor] = None,
66
+ state: Optional[Cache] = None
67
+ ) -> torch.Tensor:
68
+ if attention_mask is not None:
69
+ x = x.mul(attention_mask[:, -x.shape[-2]:, None])
70
+ if x.shape[1] == 1 and state is not None and state[self.layer_idx]['ffn_state'] is not None:
71
+ shifted = state[self.layer_idx]['ffn_state'].unsqueeze(1)
72
+ else:
73
+ shifted = self.time_shift(x)
74
+ if state is not None and state[self.layer_idx]['ffn_state'] is not None:
75
+ shifted[:, 0] = state[self.layer_idx]['ffn_state'][-1]
76
+ if state is not None:
77
+ # no need to update the offset twice
78
+ state.update(ffn_state=x[:, -1], layer_idx=self.layer_idx, offset=0)
79
+ return self.value(self.act_fn(self.key(x.addcmul(shifted - x, self.x_k)))), state
80
+
81
+
82
+ class RWKV7Block(nn.Module):
83
+
84
+ def __init__(
85
+ self,
86
+ config: RWKV7Config,
87
+ layer_idx: int
88
+ ) -> RWKV7Block:
89
+ super().__init__()
90
+
91
+ self.config = config
92
+ self.layer_idx = layer_idx
93
+
94
+ if config.norm_first and layer_idx == 0:
95
+ self.pre_norm = (LayerNorm if config.fuse_norm else nn.LayerNorm)(
96
+ config.hidden_size,
97
+ bias=config.norm_bias,
98
+ eps=config.norm_eps
99
+ )
100
+ self.attn_norm = (LayerNorm if config.fuse_norm else nn.LayerNorm)(
101
+ config.hidden_size,
102
+ bias=config.norm_bias,
103
+ eps=config.norm_eps
104
+ )
105
+ if config.attn is not None and layer_idx in config.attn['layers']:
106
+ self.attn = Attention(
107
+ hidden_size=config.hidden_size,
108
+ num_heads=config.attn['num_heads'],
109
+ num_kv_heads=config.attn['num_kv_heads'],
110
+ qkv_bias=config.attn['qkv_bias'],
111
+ window_size=config.attn['window_size'],
112
+ rope_theta=config.attn['rope_theta'],
113
+ max_position_embeddings=config.max_position_embeddings,
114
+ layer_idx=layer_idx
115
+ )
116
+ else:
117
+ self.attn = RWKV7Attention(
118
+ mode=config.attn_mode,
119
+ hidden_size=config.hidden_size,
120
+ head_dim=config.head_dim,
121
+ num_heads=config.num_heads,
122
+ decay_low_rank_dim=config.decay_low_rank_dim,
123
+ gate_low_rank_dim=config.gate_low_rank_dim,
124
+ a_low_rank_dim=config.a_low_rank_dim,
125
+ v_low_rank_dim=config.v_low_rank_dim,
126
+ norm_eps=config.norm_eps,
127
+ fuse_norm=config.fuse_norm,
128
+ layer_idx=layer_idx,
129
+ value_dim=config.value_dim[layer_idx]
130
+ )
131
+ self.ffn_norm = (LayerNorm if config.fuse_norm else nn.LayerNorm)(
132
+ config.hidden_size,
133
+ bias=config.norm_bias,
134
+ eps=config.norm_eps
135
+ )
136
+ self.ffn = RWKV7FeedForward(
137
+ hidden_size=config.hidden_size,
138
+ hidden_ratio=config.hidden_ratio,
139
+ intermediate_size=config.intermediate_size,
140
+ hidden_act=config.hidden_act,
141
+ layer_idx=layer_idx
142
+ )
143
+
144
+ def forward(
145
+ self,
146
+ hidden_states: torch.Tensor,
147
+ attention_mask: Optional[torch.Tensor] = None,
148
+ past_key_values: Optional[Cache] = None,
149
+ use_cache: Optional[bool] = False,
150
+ output_attentions: Optional[bool] = False,
151
+ v_first: torch.Tensor = None,
152
+ **kwargs,
153
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
154
+ residual = self.pre_norm(hidden_states) if hasattr(self, 'pre_norm') else hidden_states
155
+ hidden_states = self.attn_norm(residual)
156
+ hidden_states, attentions, past_key_values, v_first = self.attn(
157
+ hidden_states=hidden_states,
158
+ attention_mask=attention_mask,
159
+ past_key_values=past_key_values,
160
+ use_cache=use_cache,
161
+ output_attentions=output_attentions,
162
+ v_first=v_first,
163
+ **kwargs
164
+ )
165
+ if self.config.fuse_norm:
166
+ hidden_states, residual = self.ffn_norm(hidden_states, residual, True)
167
+ else:
168
+ hidden_states = residual + hidden_states
169
+ residual = hidden_states
170
+ hidden_states = self.ffn_norm(hidden_states)
171
+ hidden_states, past_key_values = self.ffn(hidden_states, attention_mask, past_key_values)
172
+ hidden_states = residual + hidden_states
173
+
174
+ outputs = (hidden_states, attentions, past_key_values, v_first)
175
+
176
+ return outputs
177
+
178
+
179
+ class RWKV7PreTrainedModel(PreTrainedModel):
180
+
181
+ config_class = RWKV7Config
182
+ base_model_prefix = 'model'
183
+ supports_gradient_checkpointing = True
184
+ _no_split_modules = ['RWKV7Block']
185
+ _supports_cache_class = True
186
+ _skip_keys_device_placement = ["past_key_values"]
187
+
188
+ def __init__(self, *inputs, **kwargs):
189
+ super().__init__(*inputs, **kwargs)
190
+
191
+ def _init_weights(
192
+ self,
193
+ module: nn.Module,
194
+ rescale_prenorm_residual: bool = True,
195
+ num_residuals_per_layer: int = 2,
196
+ ):
197
+ warnings.warn(
198
+ "RWKV-7 employs a carefully designed initialization strategy tailored to its architecture. "
199
+ "The detailed initialization scheme is currently not implemented here but can be found in the "
200
+ "official code repository. We emphasize that using the recommended initialization is essential "
201
+ "for replicating the results in RWKV-7 paper. Deviations from the prescribed initialization "
202
+ "may lead to performance degradation.\n"
203
+ "Alternatively, please generate initial weights from the official RWKV code repository, and "
204
+ "convert the PyTorch checkpoint into FLA supported format."
205
+ )
206
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
207
+ # Slightly different from the TF version which uses truncated_normal for initialization
208
+ # cf https://github.com/pytorch/pytorch/pull/5617
209
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
210
+ if module.bias is not None:
211
+ nn.init.zeros_(module.bias)
212
+ elif isinstance(module, nn.Parameter):
213
+ nn.init.normal_(module, mean=0.0, std=self.config.initializer_range)
214
+ elif isinstance(module, nn.Embedding):
215
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
216
+ elif hasattr(module, 'reset_parameters'):
217
+ module.reset_parameters()
218
+
219
+ if rescale_prenorm_residual:
220
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
221
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
222
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
223
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
224
+ #
225
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
226
+ p = None
227
+ if hasattr(module, 'o_proj'):
228
+ p = module.o_proj.weight
229
+ elif hasattr(module, 'down_proj'):
230
+ p = module.down_proj.weight
231
+ if p is not None:
232
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
233
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
234
+ # We need to reinit p since this code could be called multiple times
235
+ # Having just p *= scale would repeatedly scale it down
236
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
237
+ with torch.no_grad():
238
+ p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
239
+
240
+
241
+ class RWKV7Model(RWKV7PreTrainedModel):
242
+
243
+ def __init__(self, config: RWKV7Config):
244
+ super().__init__(config)
245
+ self.padding_idx = config.pad_token_id
246
+ self.vocab_size = config.vocab_size
247
+
248
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
249
+ self.layers = nn.ModuleList([RWKV7Block(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
250
+ self.norm = (LayerNorm if config.fuse_norm else nn.LayerNorm)(
251
+ config.hidden_size,
252
+ bias=config.norm_bias,
253
+ eps=config.norm_eps
254
+ )
255
+
256
+ self.gradient_checkpointing = False
257
+
258
+ self.post_init()
259
+
260
+ def get_input_embeddings(self):
261
+ return self.embeddings
262
+
263
+ def set_input_embeddings(self, value):
264
+ self.embeddings = value
265
+
266
+ def forward(
267
+ self,
268
+ input_ids: Optional[torch.LongTensor] = None,
269
+ attention_mask: Optional[torch.Tensor] = None, # noqa
270
+ inputs_embeds: Optional[torch.FloatTensor] = None,
271
+ past_key_values: Optional[Cache] = None,
272
+ use_cache: Optional[bool] = None,
273
+ output_attentions: Optional[bool] = None,
274
+ output_hidden_states: Optional[bool] = None,
275
+ return_dict: Optional[bool] = None,
276
+ **kwargs: Unpack[Dict]
277
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
278
+ if output_attentions:
279
+ warnings.warn("`RWKV7Model` does not `output_attentions` now, setting it to `False`.")
280
+ output_attentions = False
281
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
282
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
283
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
284
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
285
+
286
+ # retrieve input_ids and inputs_embeds
287
+ if input_ids is not None and inputs_embeds is not None:
288
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
289
+ if input_ids is None and inputs_embeds is None:
290
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
291
+
292
+ if inputs_embeds is None:
293
+ inputs_embeds = self.embeddings(input_ids)
294
+ hidden_states = inputs_embeds
295
+
296
+ if use_cache and not isinstance(past_key_values, Cache):
297
+ past_key_values = Cache.from_legacy_cache(past_key_values)
298
+
299
+ if self.gradient_checkpointing and self.training and use_cache:
300
+ logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
301
+ use_cache = False
302
+
303
+ all_hidden_states = () if output_hidden_states else None
304
+ all_attns = () if output_attentions else None
305
+
306
+ v_first = torch.zeros_like(hidden_states)
307
+ for layer in self.layers:
308
+ if output_hidden_states:
309
+ all_hidden_states += (hidden_states,)
310
+
311
+ if self.gradient_checkpointing and self.training:
312
+ hidden_states, attentions, past_key_values, v_first = self._gradient_checkpointing_func(
313
+ layer.__call__,
314
+ hidden_states,
315
+ attention_mask,
316
+ past_key_values,
317
+ use_cache,
318
+ output_attentions,
319
+ v_first,
320
+ **kwargs
321
+ )
322
+ else:
323
+ hidden_states, attentions, past_key_values, v_first = layer(
324
+ hidden_states,
325
+ attention_mask=attention_mask,
326
+ past_key_values=past_key_values,
327
+ use_cache=use_cache,
328
+ output_attentions=output_attentions,
329
+ v_first=v_first,
330
+ **kwargs
331
+ )
332
+
333
+ if output_attentions:
334
+ all_attns += (attentions,)
335
+
336
+ hidden_states = self.norm(hidden_states)
337
+
338
+ # add hidden states from the last decoder layer
339
+ if output_hidden_states:
340
+ all_hidden_states += (hidden_states,)
341
+
342
+ if not return_dict:
343
+ return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
344
+ return BaseModelOutputWithPast(
345
+ last_hidden_state=hidden_states,
346
+ past_key_values=past_key_values,
347
+ hidden_states=all_hidden_states,
348
+ attentions=all_attns
349
+ )
350
+
351
+
352
+ class RWKV7ForCausalLM(RWKV7PreTrainedModel, GenerationMixin):
353
+
354
+ _tied_weights_keys = ["lm_head.weight"]
355
+
356
+ def __init__(self, config):
357
+ super().__init__(config)
358
+ self.model = RWKV7Model(config)
359
+ self.vocab_size = config.vocab_size
360
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
361
+ self.criterion = None
362
+
363
+ # Initialize weights and apply final processing
364
+ self.post_init()
365
+
366
+ def get_input_embeddings(self):
367
+ return self.model.embeddings
368
+
369
+ def set_input_embeddings(self, value):
370
+ self.model.embeddings = value
371
+
372
+ def get_output_embeddings(self):
373
+ return self.lm_head
374
+
375
+ def set_output_embeddings(self, new_embeddings):
376
+ self.lm_head = new_embeddings
377
+
378
+ def set_decoder(self, decoder):
379
+ self.model = decoder
380
+
381
+ def get_decoder(self):
382
+ return self.model
383
+
384
+ def generate(self, *args, **kwargs):
385
+ try:
386
+ return super().generate(*args, **kwargs)
387
+ except AttributeError as exception:
388
+ if 'past_key_values' in str(exception):
389
+ raise AttributeError(
390
+ f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
391
+ f"which is not supported for {self.__class__.__name__}. "
392
+ f"Try another generation strategy instead. "
393
+ f"For the available generation strategies, check this doc: "
394
+ f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
395
+ )
396
+ else:
397
+ raise exception
398
+
399
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
400
+ def prepare_inputs_for_generation(
401
+ self,
402
+ input_ids: torch.LongTensor = None,
403
+ past_key_values: Optional[Cache] = None,
404
+ attention_mask: Optional[torch.Tensor] = None,
405
+ inputs_embeds: Optional[torch.Tensor] = None,
406
+ use_cache: bool = True,
407
+ logits_to_keep: Optional[int] = None,
408
+ **kwargs
409
+ ):
410
+ # only last token for `inputs_ids` if the `past_key_values` is not empty.
411
+ if past_key_values is not None and len(past_key_values) > 0:
412
+ input_ids = input_ids[:, -1:]
413
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
414
+ if inputs_embeds is not None and len(past_key_values) == 0:
415
+ model_inputs = {'inputs_embeds': inputs_embeds}
416
+ else:
417
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
418
+ # recompiles graphs as the stride of the inputs is a guard.
419
+ # Ref: https://github.com/huggingface/transformers/pull/29114
420
+ # TODO: use `next_tokens` directly instead.
421
+ model_inputs = {'input_ids': input_ids.contiguous()}
422
+
423
+ if logits_to_keep is not None:
424
+ model_inputs['logits_to_keep'] = logits_to_keep
425
+
426
+ model_inputs.update({
427
+ 'past_key_values': past_key_values,
428
+ 'use_cache': use_cache,
429
+ 'attention_mask': attention_mask,
430
+ })
431
+ return model_inputs
432
+
433
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
434
+ def forward(
435
+ self,
436
+ input_ids: torch.LongTensor = None,
437
+ attention_mask: Optional[torch.Tensor] = None,
438
+ inputs_embeds: Optional[torch.Tensor] = None,
439
+ past_key_values: Optional[Cache] = None,
440
+ labels: Optional[torch.LongTensor] = None,
441
+ shift_labels: Optional[torch.LongTensor] = None,
442
+ use_cache: Optional[bool] = None,
443
+ output_attentions: Optional[bool] = None,
444
+ output_hidden_states: Optional[bool] = None,
445
+ return_dict: Optional[bool] = None,
446
+ logits_to_keep: Optional[int] = 0,
447
+ **kwargs: Unpack[Dict]
448
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
449
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
450
+ output_hidden_states = (
451
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
452
+ )
453
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
454
+
455
+ outputs = self.model(
456
+ input_ids=input_ids,
457
+ attention_mask=attention_mask,
458
+ inputs_embeds=inputs_embeds,
459
+ past_key_values=past_key_values,
460
+ use_cache=use_cache,
461
+ output_attentions=output_attentions,
462
+ output_hidden_states=output_hidden_states,
463
+ return_dict=return_dict,
464
+ **kwargs
465
+ )
466
+
467
+ hidden_states = outputs[0]
468
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
469
+
470
+ loss, logits = None, None
471
+ has_labels = (labels is not None) or (shift_labels is not None)
472
+ if not (fuse_linear_and_cross_entropy and has_labels):
473
+ logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
474
+ if has_labels:
475
+ if getattr(self, 'criterion', None) is None:
476
+ if fuse_linear_and_cross_entropy:
477
+ criterion = FusedLinearCrossEntropyLoss()
478
+ elif self.config.fuse_cross_entropy:
479
+ criterion = FusedCrossEntropyLoss(inplace_backward=True)
480
+ else:
481
+ criterion = nn.CrossEntropyLoss()
482
+ else:
483
+ criterion = self.criterion
484
+
485
+ # shift_labels: See https://github.com/huggingface/transformers/pull/36607/files.
486
+ if shift_labels is None:
487
+ shift_labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
488
+ shift_labels = shift_labels.to(hidden_states.device)
489
+
490
+ if fuse_linear_and_cross_entropy:
491
+ loss = criterion(hidden_states, shift_labels, self.lm_head.weight, self.lm_head.bias)
492
+ else:
493
+ loss = criterion(logits.view(shift_labels.numel(), -1), shift_labels.view(-1))
494
+
495
+ if not return_dict:
496
+ output = (logits,) + outputs[1:]
497
+ return (loss,) + output if loss is not None else output
498
+
499
+ return CausalLMOutputWithPast(
500
+ loss=loss,
501
+ logits=logits,
502
+ past_key_values=outputs.past_key_values,
503
+ hidden_states=outputs.hidden_states,
504
+ attentions=outputs.attentions,
505
+ )
fla/ops/abc/__pycache__/chunk.cpython-312.pyc ADDED
Binary file (72 kB). View file
 
fla/ops/attn/__pycache__/parallel.cpython-312.pyc ADDED
Binary file (33.2 kB). View file
 
fla/ops/based/__pycache__/parallel.cpython-312.pyc ADDED
Binary file (22.6 kB). View file
 
fla/ops/based/fused_chunk.py ADDED
@@ -0,0 +1,374 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
11
+
12
+
13
+ @triton.jit(do_not_specialize=['T'])
14
+ def fused_chunk_based_fwd_kernel(
15
+ q,
16
+ k,
17
+ v,
18
+ o,
19
+ z,
20
+ scale, # K ** -0.5
21
+ T,
22
+ B: tl.constexpr,
23
+ H: tl.constexpr,
24
+ K: tl.constexpr,
25
+ V: tl.constexpr,
26
+ BT: tl.constexpr,
27
+ BK: tl.constexpr,
28
+ BV: tl.constexpr,
29
+ ):
30
+ # indices
31
+ i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
32
+
33
+ o_i = tl.arange(0, BT)
34
+
35
+ # [BT, BT]
36
+ m_s = o_i[:, None] >= o_i[None, :]
37
+
38
+ # [BV], zero-order taylor expansion
39
+ b_h_0o = tl.zeros([BV], dtype=tl.float32)
40
+ # [BK, BV], first-order taylor expansion
41
+ b_h_1o = tl.zeros([BK, BV], dtype=tl.float32)
42
+ # [BK, BK, BV] second-order taylor expansion
43
+ b_h_2o = tl.zeros([BK*BK, BV], dtype=tl.float32)
44
+
45
+ # make block pointers
46
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (0, i_k * BK), (BT, BK), (1, 0))
47
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, 0), (BK, BT), (0, 1))
48
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (0, i_v * BV), (BT, BV), (1, 0))
49
+ p_o = tl.make_block_ptr(o + (i_bh + i_k*B*H) * T*V, (T, V), (V, 1), (0, i_v * BV), (BT, BV), (1, 0))
50
+
51
+ p_z = z + (i_bh + i_k * B * H) * T + tl.arange(0, BT)
52
+ k_2o = tl.zeros([1, BK * BK], dtype=tl.float32)
53
+ k_1o = tl.zeros([1, BK], dtype=tl.float32)
54
+ k_0o = 0
55
+
56
+ for i in range(0, tl.cdiv(T, BT)):
57
+ # [BK, BT]
58
+ b_k = tl.load(p_k, boundary_check=(0, 1))
59
+ # [BK*BK, BT]
60
+ b_k_2o = b_k[:, None, :] * b_k[None, :, :]
61
+ b_k_2o = tl.reshape(b_k_2o, [BK * BK, BT]).to(b_k.dtype)
62
+ # [BT, BV]
63
+ b_v = tl.load(p_v, boundary_check=(0, 1))
64
+ # [BT, BK]
65
+ b_q = (tl.load(p_q, boundary_check=(0, 1)) * scale).to(b_k.dtype)
66
+ b_o = tl.zeros([BT, BV], dtype=tl.float32)
67
+ b_z = tl.zeros([BT], dtype=tl.float32)
68
+
69
+ # interchunk
70
+ # zero-order
71
+ b_o += b_h_0o
72
+ b_z += k_0o
73
+ # first-order
74
+ b_o += tl.dot(b_q, b_h_1o.to(b_q.dtype), allow_tf32=False)
75
+ b_z += tl.sum(b_q * k_1o, axis=1)
76
+ # second-order
77
+ b_q_2o = b_q[:, :, None] * b_q[:, None, :]
78
+ b_q_2o = tl.reshape(b_q_2o, [BT, BK * BK]).to(b_k.dtype)
79
+ b_o += tl.dot(b_q_2o, b_h_2o.to(b_q_2o.dtype), allow_tf32=False) * 0.5
80
+ b_z += tl.sum(b_q_2o * k_2o, axis=1) * 0.5
81
+
82
+ # update running statistics
83
+ k_1o += tl.sum(b_k, axis=1)[None, :]
84
+ k_2o += tl.sum(b_k_2o, axis=1)[None, :]
85
+ k_0o += BT
86
+
87
+ # intrachunk
88
+ # [BT, BT]
89
+ b_s = tl.dot(b_q, b_k, allow_tf32=False)
90
+ b_s = 1 + b_s + 0.5 * b_s * b_s
91
+ b_s = tl.where(m_s, b_s, 0)
92
+ b_z += tl.sum(b_s, axis=1)
93
+ b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)
94
+ # [TB, BV]
95
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
96
+ tl.store(p_z, b_z.to(p_z.dtype.element_ty), mask=(i * BT + tl.arange(0, BT)) < T)
97
+
98
+ # update hidden state
99
+ # [BK, BV]
100
+ b_h_2o = b_h_2o + tl.dot(b_k_2o.to(b_v.dtype), b_v, allow_tf32=False)
101
+ b_h_1o = b_h_1o + tl.dot(b_k, b_v, allow_tf32=False)
102
+ b_h_0o = b_h_0o + tl.sum(b_v, axis=0)
103
+
104
+ p_q = tl.advance(p_q, (BT, 0))
105
+ p_k = tl.advance(p_k, (0, BT))
106
+ p_v = tl.advance(p_v, (BT, 0))
107
+ p_o = tl.advance(p_o, (BT, 0))
108
+ p_z += BT
109
+
110
+
111
+ # Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
112
+ @triton.jit
113
+ def fused_chunk_based_bwd_kernel(
114
+ # NV: number of split in the V dimension. NK: number of split in the K dimension
115
+ q,
116
+ k,
117
+ v,
118
+ do,
119
+ dz,
120
+ dq,
121
+ dk,
122
+ dv,
123
+ scale, # K ** -0.5
124
+ T,
125
+ B: tl.constexpr,
126
+ H: tl.constexpr,
127
+ K: tl.constexpr,
128
+ V: tl.constexpr,
129
+ BT: tl.constexpr,
130
+ BK: tl.constexpr,
131
+ BV: tl.constexpr,
132
+ ):
133
+ i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
134
+
135
+ o_i = tl.arange(0, BT)
136
+ m_s = o_i[:, None] >= o_i[None, :]
137
+
138
+ # [BV], zero-order taylor expansion
139
+ # b_h_0o = tl.zeros([BV], dtype=tl.float32)
140
+ # [BK, BV], first-order taylor expansion
141
+ b_h_1o = tl.zeros([BV, BK], dtype=tl.float32)
142
+ # [BK, BK, BV] second-order taylor expansion
143
+ b_h_2o = tl.zeros([BV, BK*BK], dtype=tl.float32)
144
+
145
+ k_1o = tl.zeros([1, BK], dtype=tl.float32)
146
+ k_2o = tl.zeros([1, BK * BK], dtype=tl.float32)
147
+
148
+ for i in range(0, tl.cdiv(T, BT)):
149
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i * BT, i_k * BK), (BT, BK), (1, 0))
150
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i * BT, i_k * BK), (BT, BK), (1, 0))
151
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, i * BT), (BV, BT), (0, 1))
152
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i * BT, i_v * BV), (BT, BV), (1, 0))
153
+ p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * T*K, (T, K), (K, 1), (i*BT, i_k*BK), (BT, BK), (1, 0))
154
+ p_dz = dz + (i_bh) * T + tl.arange(0, BT) + i * BT
155
+ b_dq = tl.zeros([BT, BK], dtype=tl.float32)
156
+
157
+ # load tensors
158
+ # [BT, BK]
159
+ b_q = tl.load(p_q, boundary_check=(0, 1))
160
+ b_q = (b_q * scale).to(b_q.dtype)
161
+ b_k = tl.load(p_k, boundary_check=(0, 1))
162
+ b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
163
+ b_dz = tl.load(p_dz, mask=(tl.arange(0, BT) + i * BT) < T)
164
+ # [BV, BT]
165
+ b_v = tl.load(p_v, boundary_check=(0, 1))
166
+
167
+ # inter-chunk
168
+ b_dq += tl.dot(b_do, (b_h_1o).to(b_do.dtype), allow_tf32=False)
169
+ if i_v == 0:
170
+ b_dq += b_dz[:, None] * k_1o
171
+ b_dq_2o = tl.dot(b_do, (b_h_2o).to(b_do.dtype), allow_tf32=False) * 0.5
172
+ if i_v == 0:
173
+ b_dq_2o += (b_dz[:, None] * k_2o) * 0.5
174
+ b_dq_2o = tl.reshape(b_dq_2o, [BT, BK, BK])
175
+ b_dq += tl.sum(b_dq_2o * b_q[:, :, None], axis=1)
176
+ b_dq += tl.sum(b_dq_2o * b_q[:, None, :], axis=2)
177
+ b_dq *= scale
178
+
179
+ # intra-chunk
180
+ # [BT, BT]
181
+ b_ds = tl.dot(b_do, b_v, allow_tf32=False)
182
+ if i_v == 0:
183
+ b_ds += b_dz[:, None]
184
+ b_ds = tl.where(m_s, b_ds, 0) * scale
185
+ b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)
186
+ b_s = tl.where(m_s, b_s, 0)
187
+ b_dq += tl.dot((b_ds * (1 + b_s)).to(b_q.dtype), b_k, allow_tf32=False)
188
+
189
+ # store
190
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
191
+
192
+ # update hidden state
193
+ # [BT, BK*BK]
194
+ b_k_2o = b_k[:, :, None] * b_k[:, None, :]
195
+ b_k_2o = tl.reshape(b_k_2o, [BT, BK * BK]).to(b_k.dtype)
196
+ # [BV, BK*BK]
197
+ b_h_2o = b_h_2o + tl.dot(b_v, b_k_2o.to(b_v.dtype), allow_tf32=False)
198
+ # [BV, BK]
199
+ b_h_1o = b_h_1o + tl.dot(b_v, b_k, allow_tf32=False)
200
+
201
+ if i_v == 0:
202
+ # update running statistics
203
+ k_1o += tl.sum(b_k, axis=0)[None, :]
204
+ k_2o += tl.sum(b_k_2o, axis=0)[None, :]
205
+
206
+ tl.debug_barrier()
207
+ b_h_1o = None
208
+ b_h_2o = None
209
+
210
+ # [BK, BV], first-order taylor expansion
211
+ b_dh_1o = tl.zeros([BK, BV], dtype=tl.float32)
212
+ # [BK, BK, BV] second-order taylor expansion
213
+ b_dh_2o = tl.zeros([BK*BK, BV], dtype=tl.float32)
214
+ b_dh_0o = tl.zeros([BV], dtype=tl.float32)
215
+ m_s = tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :]
216
+
217
+ dq_1o = tl.zeros([1, BK], dtype=tl.float32)
218
+ dq_2o = tl.zeros([BK * BK, 1], dtype=tl.float32)
219
+
220
+ for i in range(tl.cdiv(T, BT) * BT - BT, -BT, -BT):
221
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, i), (BK, BT), (0, 1))
222
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i, i_k * BK), (BT, BK), (1, 0))
223
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i, i_v * BV), (BT, BV), (1, 0))
224
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i, i_v * BV), (BT, BV), (1, 0))
225
+ p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * T*K, (T, K), (K, 1), (i, i_k*BK), (BT, BK), (1, 0))
226
+ p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * T*V, (T, V), (V, 1), (i, i_v*BV), (BT, BV), (1, 0))
227
+ p_dz = dz + (i_bh) * T + tl.arange(0, BT) + i
228
+
229
+ b_dk = tl.zeros([BT, BK], dtype=tl.float32)
230
+ b_dv = tl.zeros([BT, BV], dtype=tl.float32)
231
+
232
+ b_q = tl.load(p_q, boundary_check=(0, 1))
233
+ b_k = tl.load(p_k, boundary_check=(0, 1))
234
+ b_v = tl.load(p_v, boundary_check=(0, 1))
235
+ b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
236
+ b_dz = tl.load(p_dz, mask=(tl.arange(0, BT)+i) < T)
237
+ b_q = (b_q * scale).to(b_k.dtype)
238
+
239
+ # intra chunk
240
+ b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False)
241
+ if i_v == 0:
242
+ b_ds += b_dz[None, :]
243
+ b_ds = tl.where(m_s, b_ds, 0)
244
+ b_s = tl.dot(b_k, b_q, allow_tf32=False)
245
+ b_s2 = 1 + b_s + 0.5 * b_s * b_s
246
+ b_s = tl.where(m_s, b_s, 0)
247
+ b_s2 = tl.where(m_s, b_s2, 0)
248
+ b_ds *= (1+b_s)
249
+
250
+ b_dk += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_q), allow_tf32=False)
251
+ b_dv += tl.dot(b_s2.to(b_do.dtype), b_do, allow_tf32=False)
252
+
253
+ # inter chunk
254
+ b_k_2o = b_k[:, :, None] * b_k[:, None, :]
255
+ b_k_2o = tl.reshape(b_k_2o, [BT, BK * BK]).to(b_k.dtype)
256
+
257
+ b_dv += tl.dot(b_k, b_dh_1o.to(b_k.dtype), allow_tf32=False)
258
+ b_dv += tl.dot(b_k_2o, b_dh_2o.to(b_k.dtype), allow_tf32=False)
259
+ b_dv += b_dh_0o
260
+
261
+ b_dk += tl.dot(b_v, tl.trans(b_dh_1o).to(b_k.dtype), allow_tf32=False)
262
+
263
+ if i_v == 0:
264
+ b_dk += dq_1o
265
+
266
+ b_dk_2o = tl.dot(b_dh_2o.to(b_k.dtype), tl.trans(b_v), allow_tf32=False)
267
+ if i_v == 0:
268
+ b_dk_2o += dq_2o
269
+ b_dk_2o = tl.reshape(b_dk_2o, [BK, BK, BT])
270
+ b_k_fp32 = tl.trans(b_k.to(tl.float32))
271
+ b_dk2 = tl.sum(b_dk_2o * b_k_fp32[:, None, :], axis=0)
272
+ b_dk2 += tl.sum(b_dk_2o * b_k_fp32[None, :, :], axis=1)
273
+ b_dk += tl.trans(b_dk2)
274
+
275
+ # hidden state update
276
+ b_dh_0o += tl.sum(b_do, axis=0)
277
+ b_dh_1o = b_dh_1o + tl.dot(b_q, b_do, allow_tf32=False)
278
+ b_q_2o = b_q[None, :, :] * b_q[:, None, :]
279
+ b_q_2o = tl.reshape(b_q_2o, [BK * BK, BT]).to(b_k.dtype)
280
+ b_dh_2o = b_dh_2o + tl.dot(b_q_2o, b_do, allow_tf32=False) * 0.5
281
+
282
+ if i_v == 0:
283
+ dq_1o += (tl.sum(b_dz[None, :] * b_q, axis=1))[None, :]
284
+ dq_2o += (tl.sum(b_dz[None, :] * b_q_2o, axis=1) * 0.5)[:, None]
285
+
286
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
287
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
288
+
289
+
290
+ class FusedChunkBasedFunction(torch.autograd.Function):
291
+
292
+ @staticmethod
293
+ @input_guard
294
+ @autocast_custom_fwd
295
+ def forward(ctx, q, k, v, scale=1):
296
+ B, H, T, K, V = *k.shape, v.shape[-1]
297
+
298
+ scale = scale
299
+ BT = 16
300
+ BK, BV = min(K, 16), min(V, 32)
301
+ BK, BV = max(BK, 16), max(BV, 16)
302
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
303
+
304
+ num_warps = 4
305
+
306
+ # the norm of o might explode, so we need to use float32 here
307
+ o = q.new_empty(NK, B, H, T, V, dtype=torch.float32)
308
+ z = q.new_empty(NK, B, H, T, dtype=torch.float32)
309
+
310
+ grid = (NV, NK, B * H)
311
+ fused_chunk_based_fwd_kernel[grid](
312
+ q, k, v, o, z,
313
+ scale,
314
+ T=T, B=B, H=H, K=K, V=V, BT=BT, BK=BK, BV=BV,
315
+ num_warps=num_warps,
316
+ )
317
+ o = o.sum(0)
318
+ z = z.sum(0)
319
+ ctx.save_for_backward(q, k, v)
320
+ ctx.scale = scale
321
+ return o.to(q.dtype), z.to(z.dtype)
322
+
323
+ @staticmethod
324
+ @input_guard
325
+ @autocast_custom_bwd
326
+ def backward(ctx, do, dz):
327
+ q, k, v = ctx.saved_tensors
328
+ B, H, T, K, V = *k.shape, v.shape[-1]
329
+ scale = ctx.scale
330
+
331
+ BT = 16
332
+ BK, BV = min(K, 16), min(V, 32)
333
+ BK, BV = max(BK, 16), max(BV, 16)
334
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
335
+ num_stages = 1
336
+ num_warps = 4
337
+
338
+ dq = q.new_empty(NV, B, H, T, K)
339
+ dk = q.new_empty(NV, B, H, T, K)
340
+ dv = q.new_empty(NK, B, H, T, V)
341
+ grid = (NV, NK, B * H)
342
+
343
+ fused_chunk_based_bwd_kernel[grid](
344
+ q, k, v, do, dz, dq, dk, dv,
345
+ scale,
346
+ T=T, B=B, H=H, K=K, V=V, BT=BT, BK=BK, BV=BV,
347
+ num_warps=num_warps,
348
+ num_stages=num_stages
349
+ )
350
+ dq = dq.sum(0)
351
+ dk = dk.sum(0)
352
+ dv = dv.sum(0)
353
+ return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None
354
+
355
+
356
+ def fused_chunk_based(
357
+ q: torch.Tensor,
358
+ k: torch.Tensor,
359
+ v: torch.Tensor,
360
+ scale: Optional[float] = None,
361
+ use_norm: bool = True,
362
+ head_first: bool = True
363
+ ):
364
+ assert q.shape[-1] <= 16, 'only support feature dimension up to 16.'
365
+ if scale is None:
366
+ scale = q.shape[-1] ** -0.5
367
+ if not head_first:
368
+ q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
369
+ o, z = FusedChunkBasedFunction.apply(q, k, v, scale)
370
+ if use_norm:
371
+ o = o / (z[..., None] + 1e-6)
372
+ if not head_first:
373
+ o = o.transpose(1, 2)
374
+ return o.to(q.dtype)
fla/ops/common/__pycache__/chunk_delta_h.cpython-312.pyc ADDED
Binary file (24 kB). View file
 
fla/ops/common/__pycache__/chunk_scaled_dot_kkt.cpython-312.pyc ADDED
Binary file (6.77 kB). View file
 
fla/ops/common/__pycache__/utils.cpython-312.pyc ADDED
Binary file (4.45 kB). View file
 
fla/ops/delta_rule/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (393 Bytes). View file
 
fla/ops/delta_rule/wy_fast.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.common.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
11
+ from fla.ops.utils.solve_tril import solve_tril
12
+ from fla.utils import check_shared_mem, is_nvidia_hopper
13
+
14
+ NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]
15
+
16
+
17
+ @triton.heuristics({
18
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
19
+ })
20
+ @triton.autotune(
21
+ configs=[
22
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
23
+ for num_warps in [2, 4, 8]
24
+ for num_stages in [2, 3, 4]
25
+ ],
26
+ key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'HEAD_FIRST', 'USE_OFFSETS'],
27
+ )
28
+ @triton.jit(do_not_specialize=['T'])
29
+ def fwd_recompute_w_u_kernel(
30
+ k,
31
+ v,
32
+ beta,
33
+ w,
34
+ u,
35
+ A,
36
+ offsets,
37
+ indices,
38
+ T,
39
+ H: tl.constexpr,
40
+ K: tl.constexpr,
41
+ V: tl.constexpr,
42
+ BT: tl.constexpr,
43
+ BK: tl.constexpr,
44
+ BV: tl.constexpr,
45
+ HEAD_FIRST: tl.constexpr,
46
+ USE_OFFSETS: tl.constexpr
47
+ ):
48
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
49
+ i_b, i_h = i_bh // H, i_bh % H
50
+ if USE_OFFSETS:
51
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
52
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
53
+ T = eos - bos
54
+ else:
55
+ bos, eos = i_b * T, i_b * T + T
56
+
57
+ if HEAD_FIRST:
58
+ p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
59
+ p_A = tl.make_block_ptr(A + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
60
+ else:
61
+ p_beta = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
62
+ p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
63
+ b_beta = tl.load(p_beta, boundary_check=(0,))
64
+ b_A = tl.load(p_A, boundary_check=(0, 1))
65
+
66
+ for i_v in range(tl.cdiv(V, BV)):
67
+ if HEAD_FIRST:
68
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
69
+ p_u = tl.make_block_ptr(u + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
70
+ else:
71
+ p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
72
+ p_u = tl.make_block_ptr(u + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
73
+ b_v = tl.load(p_v, boundary_check=(0, 1))
74
+ b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)
75
+ b_u = tl.dot(b_A.to(b_vb.dtype), b_vb, allow_tf32=False)
76
+ tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1))
77
+
78
+ for i_k in range(tl.cdiv(K, BK)):
79
+ if HEAD_FIRST:
80
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
81
+ p_w = tl.make_block_ptr(w + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
82
+ else:
83
+ p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
84
+ p_w = tl.make_block_ptr(w + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
85
+ b_k = tl.load(p_k, boundary_check=(0, 1))
86
+ b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
87
+ b_w = tl.dot(b_A.to(b_kb.dtype), b_kb, allow_tf32=False)
88
+ tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
89
+
90
+
91
+ @triton.heuristics({
92
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
93
+ })
94
+ @triton.autotune(
95
+ configs=[
96
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
97
+ for num_warps in NUM_WARPS
98
+ for num_stages in [2, 3, 4]
99
+ ],
100
+ key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'HEAD_FIRST', 'USE_OFFSETS'],
101
+ )
102
+ @triton.jit(do_not_specialize=['T'])
103
+ def bwd_prepare_wy_repr_kernel(
104
+ k,
105
+ v,
106
+ beta,
107
+ A,
108
+ dw,
109
+ du,
110
+ dk,
111
+ dv,
112
+ dbeta,
113
+ offsets,
114
+ indices,
115
+ T,
116
+ H: tl.constexpr,
117
+ K: tl.constexpr,
118
+ V: tl.constexpr,
119
+ BT: tl.constexpr,
120
+ BK: tl.constexpr,
121
+ BV: tl.constexpr,
122
+ HEAD_FIRST: tl.constexpr,
123
+ USE_OFFSETS: tl.constexpr
124
+ ):
125
+ i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
126
+ i_b, i_h = i_bh // H, i_bh % H
127
+ if USE_OFFSETS:
128
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
129
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
130
+ T = eos - bos
131
+ else:
132
+ bos, eos = i_b * T, i_b * T + T
133
+
134
+ if HEAD_FIRST:
135
+ p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
136
+ p_A = tl.make_block_ptr(A + i_bh * T*BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1))
137
+ else:
138
+ p_beta = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
139
+ p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1))
140
+
141
+ b_beta = tl.load(p_beta, boundary_check=(0,))
142
+ b_A = tl.load(p_A, boundary_check=(0, 1))
143
+
144
+ b_dbeta = tl.zeros([BT], dtype=tl.float32)
145
+ b_dA = tl.zeros([BT, BT], dtype=tl.float32)
146
+ for i_v in range(tl.cdiv(V, BV)):
147
+ if HEAD_FIRST:
148
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
149
+ p_dv = tl.make_block_ptr(dv + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
150
+ p_du = tl.make_block_ptr(du + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
151
+ else:
152
+ p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
153
+ p_dv = tl.make_block_ptr(dv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
154
+ p_du = tl.make_block_ptr(du + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
155
+
156
+ b_v = tl.load(p_v, boundary_check=(0, 1))
157
+ b_v_beta = (b_v * b_beta[:, None]).to(b_v.dtype)
158
+ b_du = tl.load(p_du, boundary_check=(0, 1))
159
+ b_dA += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False)
160
+ b_dv_beta = tl.dot(b_A, b_du, allow_tf32=False)
161
+ b_dv = b_dv_beta * b_beta[:, None]
162
+ b_dbeta += tl.sum(b_dv_beta * b_v, 1)
163
+
164
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
165
+
166
+ for i_k in range(tl.cdiv(K, BK)):
167
+ if HEAD_FIRST:
168
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
169
+ p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
170
+ p_dw = tl.make_block_ptr(dw + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
171
+ else:
172
+ p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
173
+ p_dk = tl.make_block_ptr(dk + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
174
+ p_dw = tl.make_block_ptr(dw + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
175
+ b_k = tl.load(p_k, boundary_check=(0, 1))
176
+ b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)
177
+ b_dw = tl.load(p_dw, boundary_check=(0, 1))
178
+ b_dA += tl.dot(b_dw, tl.trans(b_k_beta), allow_tf32=False)
179
+ b_dk_beta = tl.dot(b_A, b_dw, allow_tf32=False)
180
+ b_dk = b_dk_beta * b_beta[:, None]
181
+ b_dbeta += tl.sum(b_dk_beta * b_k, 1)
182
+
183
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
184
+
185
+ b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA, 0)
186
+ b_dA = tl.dot(b_dA.to(b_A.dtype), b_A)
187
+ b_dA = tl.dot(b_A, b_dA.to(b_A.dtype))
188
+ b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_dA, 0).to(k.dtype.element_ty)
189
+
190
+ for i_k in range(tl.cdiv(K, BK)):
191
+ if HEAD_FIRST:
192
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
193
+ p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
194
+ else:
195
+ p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
196
+ p_dk = tl.make_block_ptr(dk + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
197
+ b_k = tl.load(p_k, boundary_check=(0, 1))
198
+ b_dk = tl.load(p_dk, boundary_check=(0, 1))
199
+ b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)
200
+
201
+ b_dk_beta = tl.dot(b_dA, b_k, allow_tf32=False)
202
+ b_dbeta += tl.sum(b_dk_beta * b_k, 1)
203
+ b_dk += tl.dot(tl.trans(b_dA), b_k_beta, allow_tf32=False)
204
+ b_dk += b_dk_beta * b_beta[:, None]
205
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
206
+
207
+ if HEAD_FIRST:
208
+ p_dbeta = tl.make_block_ptr(dbeta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
209
+ else:
210
+ p_dbeta = tl.make_block_ptr(dbeta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
211
+ tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0,))
212
+
213
+
214
+ def fwd_prepare_wy_repr(
215
+ k: torch.Tensor,
216
+ v: torch.Tensor,
217
+ beta: torch.Tensor,
218
+ offsets: Optional[torch.LongTensor],
219
+ indices: Optional[torch.LongTensor],
220
+ head_first: bool = False,
221
+ chunk_size: int = 64
222
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
223
+ A = chunk_scaled_dot_kkt_fwd(
224
+ k=k,
225
+ beta=beta,
226
+ cu_seqlens=offsets,
227
+ head_first=head_first,
228
+ chunk_size=chunk_size,
229
+ output_dtype=torch.float32
230
+ )
231
+ A = solve_tril(
232
+ A=A,
233
+ cu_seqlens=offsets,
234
+ head_first=head_first,
235
+ output_dtype=k.dtype
236
+ )
237
+
238
+ w, u = fwd_recompute_w_u(
239
+ k=k,
240
+ v=v,
241
+ beta=beta,
242
+ A=A,
243
+ offsets=offsets,
244
+ indices=indices,
245
+ head_first=head_first,
246
+ chunk_size=chunk_size
247
+ )
248
+ return w, u, A
249
+
250
+
251
+ def fwd_recompute_w_u(
252
+ k: torch.Tensor,
253
+ v: torch.Tensor,
254
+ beta: torch.Tensor,
255
+ A: torch.Tensor,
256
+ offsets: Optional[torch.LongTensor],
257
+ indices: Optional[torch.LongTensor],
258
+ head_first: bool,
259
+ chunk_size: int
260
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
261
+ if head_first:
262
+ B, H, T, K, V = *k.shape, v.shape[-1]
263
+ else:
264
+ B, T, H, K, V = *k.shape, v.shape[-1]
265
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
266
+ CONST_TILING = 64 if check_shared_mem() else 32
267
+ BK = min(triton.next_power_of_2(K), CONST_TILING)
268
+ BV = min(triton.next_power_of_2(V), CONST_TILING)
269
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
270
+
271
+ u = torch.empty_like(v)
272
+ w = torch.empty_like(k)
273
+ fwd_recompute_w_u_kernel[(NT, B*H)](
274
+ k,
275
+ v,
276
+ beta,
277
+ w,
278
+ u,
279
+ A,
280
+ offsets=offsets,
281
+ indices=indices,
282
+ T=T,
283
+ H=H,
284
+ K=K,
285
+ V=V,
286
+ BT=BT,
287
+ BK=BK,
288
+ BV=BV,
289
+ HEAD_FIRST=head_first
290
+ )
291
+ return w, u
292
+
293
+
294
+ def bwd_prepare_wy_repr(
295
+ k: torch.Tensor,
296
+ v: torch.Tensor,
297
+ beta: torch.Tensor,
298
+ A: torch.Tensor,
299
+ dw: torch.Tensor,
300
+ du: torch.Tensor,
301
+ offsets: Optional[torch.LongTensor],
302
+ indices: Optional[torch.LongTensor],
303
+ head_first: bool,
304
+ chunk_size: int
305
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
306
+ if head_first:
307
+ B, H, T, K, V = *k.shape, v.shape[-1]
308
+ else:
309
+ B, T, H, K, V = *k.shape, v.shape[-1]
310
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
311
+ CONST_TILING = 64 if check_shared_mem() else 32
312
+ BK = min(triton.next_power_of_2(K), CONST_TILING)
313
+ BV = min(triton.next_power_of_2(V), CONST_TILING)
314
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
315
+
316
+ dk = torch.empty_like(k)
317
+ dv = torch.empty_like(v)
318
+ dbeta = torch.empty_like(beta)
319
+ bwd_prepare_wy_repr_kernel[(NT, B * H)](
320
+ k,
321
+ v,
322
+ beta,
323
+ A,
324
+ dw,
325
+ du,
326
+ dk,
327
+ dv,
328
+ dbeta,
329
+ offsets=offsets,
330
+ indices=indices,
331
+ T=T,
332
+ H=H,
333
+ K=K,
334
+ V=V,
335
+ BT=BT,
336
+ BK=BK,
337
+ BV=BV,
338
+ HEAD_FIRST=head_first
339
+ )
340
+ return dk, dv, dbeta
fla/ops/gated_delta_rule/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (351 Bytes). View file
 
fla/ops/gated_delta_rule/__pycache__/chunk.cpython-312.pyc ADDED
Binary file (14.4 kB). View file
 
fla/ops/gated_delta_rule/__pycache__/wy_fast.cpython-312.pyc ADDED
Binary file (45.1 kB). View file
 
fla/ops/gated_delta_rule/wy_fast.py ADDED
@@ -0,0 +1,620 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.utils.op import safe_exp
11
+ from fla.utils import check_shared_mem
12
+
13
+
14
+ @triton.heuristics({
15
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
16
+ })
17
+ @triton.autotune(
18
+ configs=[
19
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
20
+ for num_warps in [2, 4, 8]
21
+ for num_stages in [2, 3, 4]
22
+ ],
23
+ key=['H', 'K', 'BT', 'BK', 'BC', 'HEAD_FIRST', 'USE_OFFSETS'],
24
+ )
25
+ @triton.jit(do_not_specialize=['T'])
26
+ def fwd_prepare_wy_repr_kernel_chunk32(
27
+ k,
28
+ g,
29
+ beta,
30
+ Aw,
31
+ Au,
32
+ offsets,
33
+ indices,
34
+ T,
35
+ H: tl.constexpr,
36
+ K: tl.constexpr,
37
+ BT: tl.constexpr,
38
+ BK: tl.constexpr,
39
+ BC: tl.constexpr,
40
+ HEAD_FIRST: tl.constexpr,
41
+ USE_OFFSETS: tl.constexpr
42
+ ):
43
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
44
+ i_b, i_h = i_bh // H, i_bh % H
45
+ if USE_OFFSETS:
46
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
47
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
48
+ T = eos - bos
49
+ else:
50
+ bos, eos = i_b * T, i_b * T + T
51
+
52
+ b_Aw = tl.zeros([BC, BC], dtype=tl.float32)
53
+ if HEAD_FIRST:
54
+ p_beta = tl.make_block_ptr(beta + i_bh*T, (T,), (1,), (i_t * BT,), (BT,), (0,))
55
+ else:
56
+ p_beta = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
57
+
58
+ b_beta = tl.load(p_beta, boundary_check=(0,))
59
+
60
+ for i_k in range(tl.cdiv(K, BK)):
61
+ if HEAD_FIRST:
62
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
63
+ else:
64
+ p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
65
+ b_k = tl.load(p_k, boundary_check=(0, 1))
66
+ b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
67
+ b_Aw += tl.dot(b_kb, tl.trans(b_k))
68
+
69
+ b_Aw = -tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_Aw, 0)
70
+
71
+ if HEAD_FIRST:
72
+ p_g = tl.make_block_ptr(g + i_bh*T, (T,), (1,), (i_t * BT,), (BT,), (0,))
73
+ else:
74
+ p_g = tl.make_block_ptr(g + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
75
+
76
+ b_g = tl.load(p_g, boundary_check=(0,))
77
+ b_Au = b_Aw * safe_exp(b_g[:, None] - b_g[None, :])
78
+
79
+ for i in range(1, BC):
80
+ mask = tl.arange(0, BC) == i
81
+ b_aw = tl.sum(tl.where(mask[:, None], b_Aw, 0), 0)
82
+ b_au = tl.sum(tl.where(mask[:, None], b_Au, 0), 0)
83
+ b_aw = b_aw + tl.sum(b_aw[:, None] * b_Aw, 0) * (tl.arange(0, BC) < i)
84
+ b_au = b_au + tl.sum(b_au[:, None] * b_Au, 0) * (tl.arange(0, BC) < i)
85
+ b_Aw = tl.where(mask[:, None], b_aw, b_Aw)
86
+ b_Au = tl.where(mask[:, None], b_au, b_Au)
87
+
88
+ # blockwise computation of lower triangular matrix's inverse
89
+ # i.e., [A11, 0; A21, A22]^-1 = [A11^-1, 0; -A22^-1 A21 A11^-1, A22^-1]
90
+ b_Aw += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
91
+ b_Au += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
92
+ if HEAD_FIRST:
93
+ p_Aw = tl.make_block_ptr(Aw + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
94
+ p_Au = tl.make_block_ptr(Au + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
95
+ else:
96
+ p_Aw = tl.make_block_ptr(Aw + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
97
+ p_Au = tl.make_block_ptr(Au + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
98
+ tl.store(p_Aw, b_Aw.to(p_Aw.dtype.element_ty), boundary_check=(0, 1))
99
+ tl.store(p_Au, b_Au.to(p_Au.dtype.element_ty), boundary_check=(0, 1))
100
+
101
+
102
+ @triton.heuristics({
103
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
104
+ })
105
+ @triton.autotune(
106
+ configs=[
107
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
108
+ for num_warps in [2, 4, 8]
109
+ for num_stages in [2, 3, 4]
110
+ ],
111
+ key=['H', 'K', 'BT', 'BK', 'BC', 'USE_OFFSETS', 'HEAD_FIRST'],
112
+ )
113
+ @triton.jit(do_not_specialize=['T'])
114
+ def fwd_prepare_wy_repr_kernel_chunk64(
115
+ k,
116
+ g,
117
+ beta,
118
+ Aw,
119
+ Au,
120
+ offsets,
121
+ indices,
122
+ T,
123
+ H: tl.constexpr,
124
+ K: tl.constexpr,
125
+ BT: tl.constexpr,
126
+ BK: tl.constexpr,
127
+ BC: tl.constexpr,
128
+ USE_OFFSETS: tl.constexpr,
129
+ HEAD_FIRST: tl.constexpr
130
+ ):
131
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
132
+ i_b, i_h = i_bh // H, i_bh % H
133
+ if USE_OFFSETS:
134
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
135
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
136
+ T = eos - bos
137
+ else:
138
+ bos, eos = i_b * T, i_b * T + T
139
+
140
+ b_Aw = tl.zeros([BC, BC], dtype=tl.float32)
141
+ b_Aw2 = tl.zeros([BC, BC], dtype=tl.float32)
142
+ b_Aw3 = tl.zeros([BC, BC], dtype=tl.float32)
143
+ if HEAD_FIRST:
144
+ p_beta = tl.make_block_ptr(beta + i_bh*T, (T,), (1,), (i_t * BT,), (BC,), (0,))
145
+ p_beta2 = tl.make_block_ptr(beta + i_bh*T, (T,), (1,), (i_t * BT + BC,), (BC,), (0,))
146
+ else:
147
+ p_beta = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BC,), (0,))
148
+ p_beta2 = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT + BC,), (BC,), (0,))
149
+
150
+ b_beta = tl.load(p_beta, boundary_check=(0,))
151
+ b_beta2 = tl.load(p_beta2, boundary_check=(0,))
152
+
153
+ for i_k in range(tl.cdiv(K, BK)):
154
+ if HEAD_FIRST:
155
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
156
+ p_k2 = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT + BC, i_k * BK), (BC, BK), (1, 0))
157
+ else:
158
+ p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
159
+ p_k2 = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + BC, i_k * BK), (BC, BK), (1, 0))
160
+ b_k = tl.load(p_k, boundary_check=(0, 1))
161
+ b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
162
+ b_k2 = tl.load(p_k2, boundary_check=(0, 1))
163
+ b_kb2 = (b_k2 * b_beta2[:, None]).to(b_k2.dtype)
164
+ b_Aw += tl.dot(b_kb, tl.trans(b_k))
165
+ b_Aw2 += tl.dot(b_kb2, tl.trans(b_k2))
166
+ b_Aw3 += tl.dot(b_kb2, tl.trans(b_k))
167
+
168
+ b_Aw = -tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_Aw, 0)
169
+ b_Aw2 = -tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_Aw2, 0)
170
+
171
+ if HEAD_FIRST:
172
+ p_g = tl.make_block_ptr(g + i_bh*T, (T,), (1,), (i_t * BT,), (BC,), (0,))
173
+ p_g2 = tl.make_block_ptr(g + i_bh*T, (T,), (1,), (i_t * BT + BC,), (BC,), (0,))
174
+ else:
175
+ p_g = tl.make_block_ptr(g + bos*H + i_h, (T,), (H,), (i_t * BT,), (BC,), (0,))
176
+ p_g2 = tl.make_block_ptr(g + bos*H + i_h, (T,), (H,), (i_t * BT + BC,), (BC,), (0,))
177
+ b_g = tl.load(p_g, boundary_check=(0,))
178
+ b_g2 = tl.load(p_g2, boundary_check=(0,))
179
+
180
+ mask_c = tl.arange(0, BC)[:, None] >= tl.arange(0, BC)[None, :]
181
+ mask_g = i_t * BT + tl.arange(0, BC) < T
182
+ mask_g2 = i_t * BT + BC + tl.arange(0, BC) < T
183
+
184
+ b_Au = tl.where(mask_g[None, :] & mask_c, b_Aw * safe_exp(b_g[:, None] - b_g[None, :]), 0)
185
+ b_Au2 = tl.where(mask_g2[None, :] & mask_c, b_Aw2 * safe_exp(b_g2[:, None] - b_g2[None, :]), 0)
186
+ b_Au3 = tl.where(mask_g[None, :], b_Aw3 * safe_exp(b_g2[:, None] - b_g[None, :]), 0)
187
+
188
+ for i in range(1, BC):
189
+ mask = tl.arange(0, BC) == i
190
+ b_aw = tl.sum(tl.where(mask[:, None], b_Aw, 0), 0)
191
+ b_aw2 = tl.sum(tl.where(mask[:, None], b_Aw2, 0), 0)
192
+ b_au = tl.sum(tl.where(mask[:, None], b_Au, 0), 0)
193
+ b_au2 = tl.sum(tl.where(mask[:, None], b_Au2, 0), 0)
194
+ b_aw = b_aw + tl.sum(b_aw[:, None] * b_Aw, 0) * (tl.arange(0, BC) < i)
195
+ b_aw2 = b_aw2 + tl.sum(b_aw2[:, None] * b_Aw2, 0) * (tl.arange(0, BC) < i)
196
+ b_au = b_au + tl.sum(b_au[:, None] * b_Au, 0) * (tl.arange(0, BC) < i)
197
+ b_au2 = b_au2 + tl.sum(b_au2[:, None] * b_Au2, 0) * (tl.arange(0, BC) < i)
198
+ b_Aw = tl.where(mask[:, None], b_aw, b_Aw)
199
+ b_Aw2 = tl.where(mask[:, None], b_aw2, b_Aw2)
200
+ b_Au = tl.where(mask[:, None], b_au, b_Au)
201
+ b_Au2 = tl.where(mask[:, None], b_au2, b_Au2)
202
+ # blockwise computation of lower triangular matrix's inverse
203
+ # i.e., [A11, 0; A21, A22]^-1 = [A11^-1, 0; -A22^-1 A21 A11^-1, A22^-1]
204
+ b_Aw += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
205
+ b_Aw2 += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
206
+ # improve precision by disallowing tf32.
207
+ b_Aw3 = -tl.dot(tl.dot(b_Aw2, b_Aw3, allow_tf32=False), b_Aw, allow_tf32=False)
208
+ b_Au += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
209
+ b_Au2 += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
210
+ b_Au3 = -tl.dot(tl.dot(b_Au2, b_Au3, allow_tf32=False), b_Au, allow_tf32=False)
211
+
212
+ if HEAD_FIRST:
213
+ p_Aw1 = tl.make_block_ptr(Aw + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
214
+ p_Aw2 = tl.make_block_ptr(Aw + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
215
+ p_Aw3 = tl.make_block_ptr(Aw + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
216
+ p_Aw4 = tl.make_block_ptr(Aw + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))
217
+ p_Au1 = tl.make_block_ptr(Au + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
218
+ p_Au2 = tl.make_block_ptr(Au + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
219
+ p_Au3 = tl.make_block_ptr(Au + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
220
+ p_Au4 = tl.make_block_ptr(Au + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))
221
+ else:
222
+ p_Aw1 = tl.make_block_ptr(Aw + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
223
+ p_Aw2 = tl.make_block_ptr(Aw + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
224
+ p_Aw3 = tl.make_block_ptr(Aw + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
225
+ p_Aw4 = tl.make_block_ptr(Aw + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))
226
+ p_Au1 = tl.make_block_ptr(Au + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
227
+ p_Au2 = tl.make_block_ptr(Au + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
228
+ p_Au3 = tl.make_block_ptr(Au + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
229
+ p_Au4 = tl.make_block_ptr(Au + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))
230
+
231
+ tl.store(p_Aw1, b_Aw.to(p_Aw1.dtype.element_ty), boundary_check=(0, 1))
232
+ tl.store(p_Aw2, b_Aw2.to(p_Aw2.dtype.element_ty), boundary_check=(0, 1))
233
+ tl.store(p_Aw3, b_Aw3.to(p_Aw3.dtype.element_ty), boundary_check=(0, 1))
234
+ tl.store(p_Aw4, tl.zeros([BC, BC], dtype=tl.float32).to(p_Aw4.dtype.element_ty), boundary_check=(0, 1))
235
+ tl.store(p_Au1, b_Au.to(p_Au1.dtype.element_ty), boundary_check=(0, 1))
236
+ tl.store(p_Au2, b_Au2.to(p_Au2.dtype.element_ty), boundary_check=(0, 1))
237
+ tl.store(p_Au3, b_Au3.to(p_Au3.dtype.element_ty), boundary_check=(0, 1))
238
+ tl.store(p_Au4, tl.zeros([BC, BC], dtype=tl.float32).to(p_Au4.dtype.element_ty), boundary_check=(0, 1))
239
+
240
+
241
+ @triton.heuristics({
242
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
243
+ })
244
+ @triton.autotune(
245
+ configs=[
246
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
247
+ for num_warps in [2, 4, 8]
248
+ for num_stages in [2, 3, 4]
249
+ ],
250
+ key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'HEAD_FIRST', 'USE_OFFSETS'],
251
+ )
252
+ @triton.jit(do_not_specialize=['T'])
253
+ def fwd_recompute_w_u_kernel(
254
+ k,
255
+ v,
256
+ beta,
257
+ w,
258
+ u,
259
+ Aw,
260
+ Au,
261
+ offsets,
262
+ indices,
263
+ T,
264
+ H: tl.constexpr,
265
+ K: tl.constexpr,
266
+ V: tl.constexpr,
267
+ BT: tl.constexpr,
268
+ BK: tl.constexpr,
269
+ BV: tl.constexpr,
270
+ HEAD_FIRST: tl.constexpr,
271
+ USE_OFFSETS: tl.constexpr
272
+ ):
273
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
274
+ i_b, i_h = i_bh // H, i_bh % H
275
+ if USE_OFFSETS:
276
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
277
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
278
+ T = eos - bos
279
+ else:
280
+ bos, eos = i_b * T, i_b * T + T
281
+ if HEAD_FIRST:
282
+ p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
283
+ p_Au = tl.make_block_ptr(Au + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
284
+ else:
285
+ p_beta = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
286
+ p_Au = tl.make_block_ptr(Au + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
287
+ b_beta = tl.load(p_beta, boundary_check=(0,))
288
+ b_Au = tl.load(p_Au, boundary_check=(0, 1))
289
+
290
+ for i_v in range(tl.cdiv(V, BV)):
291
+ if HEAD_FIRST:
292
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
293
+ p_u = tl.make_block_ptr(u + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
294
+ else:
295
+ p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
296
+ p_u = tl.make_block_ptr(u + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
297
+ b_v = tl.load(p_v, boundary_check=(0, 1))
298
+ b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)
299
+ b_u = tl.dot(b_Au, b_vb, allow_tf32=False)
300
+ tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
301
+
302
+ tl.debug_barrier()
303
+ b_Au = None
304
+ if HEAD_FIRST:
305
+ p_Aw = tl.make_block_ptr(Aw + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
306
+ else:
307
+ p_Aw = tl.make_block_ptr(Aw + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
308
+ b_Aw = tl.load(p_Aw, boundary_check=(0, 1))
309
+
310
+ for i_k in range(tl.cdiv(K, BK)):
311
+ if HEAD_FIRST:
312
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
313
+ p_w = tl.make_block_ptr(w + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
314
+ else:
315
+ p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
316
+ p_w = tl.make_block_ptr(w + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
317
+ b_k = tl.load(p_k, boundary_check=(0, 1))
318
+ b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
319
+ b_w = tl.dot(b_Aw, b_kb)
320
+ tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
321
+
322
+
323
+ def fwd_prepare_wy_repr(
324
+ k: torch.Tensor,
325
+ v: torch.Tensor,
326
+ g: torch.Tensor,
327
+ beta: torch.Tensor,
328
+ offsets: Optional[torch.LongTensor],
329
+ indices: Optional[torch.LongTensor],
330
+ head_first: bool = True,
331
+ chunk_size: int = 64
332
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
333
+ if head_first:
334
+ B, H, T, K = k.shape
335
+ else:
336
+ B, T, H, K = k.shape
337
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
338
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
339
+ BC = min(BT, 32)
340
+ BK = min(triton.next_power_of_2(K), 64)
341
+ # bf16 should be good enough.
342
+ Aw = torch.empty(B, *((H, T) if head_first else (T, H)), BT, device=k.device, dtype=k.dtype)
343
+ Au = torch.empty(B, *((H, T) if head_first else (T, H)), BT, device=k.device, dtype=k.dtype)
344
+
345
+ fwd_fn = fwd_prepare_wy_repr_kernel_chunk64 if BT == 64 else fwd_prepare_wy_repr_kernel_chunk32
346
+ fwd_fn[(NT, B*H)](
347
+ k=k,
348
+ g=g,
349
+ beta=beta,
350
+ Aw=Aw,
351
+ Au=Au,
352
+ offsets=offsets,
353
+ indices=indices,
354
+ T=T,
355
+ H=H,
356
+ K=K,
357
+ BT=BT,
358
+ BK=BK,
359
+ BC=BC,
360
+ HEAD_FIRST=head_first
361
+ )
362
+ w, u = fwd_recompute_w_u(
363
+ k=k,
364
+ v=v,
365
+ beta=beta,
366
+ Aw=Aw,
367
+ Au=Au,
368
+ offsets=offsets,
369
+ indices=indices,
370
+ head_first=head_first,
371
+ chunk_size=chunk_size
372
+ )
373
+ return w, u, Aw, Au
374
+
375
+
376
+ def fwd_recompute_w_u(
377
+ k: torch.Tensor,
378
+ v: torch.Tensor,
379
+ beta: torch.Tensor,
380
+ Aw: torch.Tensor,
381
+ Au: torch.Tensor,
382
+ offsets: Optional[torch.LongTensor],
383
+ indices: Optional[torch.LongTensor],
384
+ head_first: bool,
385
+ chunk_size: int
386
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
387
+ if head_first:
388
+ B, H, T, K, V = *k.shape, v.shape[-1]
389
+ else:
390
+ B, T, H, K, V = *k.shape, v.shape[-1]
391
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
392
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
393
+ BK = min(triton.next_power_of_2(K), 64)
394
+ BV = min(triton.next_power_of_2(V), 64)
395
+
396
+ u = torch.empty_like(v)
397
+ w = torch.empty_like(k)
398
+ fwd_recompute_w_u_kernel[(NT, B*H)](
399
+ k=k,
400
+ v=v,
401
+ beta=beta,
402
+ w=w,
403
+ u=u,
404
+ Aw=Aw,
405
+ Au=Au,
406
+ offsets=offsets,
407
+ indices=indices,
408
+ T=T,
409
+ H=H,
410
+ K=K,
411
+ V=V,
412
+ BT=BT,
413
+ BK=BK,
414
+ BV=BV,
415
+ HEAD_FIRST=head_first
416
+ )
417
+ return w, u
418
+
419
+
420
+ @triton.heuristics({
421
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
422
+ })
423
+ @triton.autotune(
424
+ configs=[
425
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
426
+ for num_warps in [2, 4]
427
+ for num_stages in [2, 3, 4]
428
+ ],
429
+ key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'HEAD_FIRST', 'USE_OFFSETS']
430
+ )
431
+ @triton.jit(do_not_specialize=['T'])
432
+ def bwd_prepare_wy_repr_kernel(
433
+ k,
434
+ v,
435
+ beta,
436
+ g,
437
+ Aw,
438
+ Au,
439
+ dw,
440
+ du,
441
+ dk,
442
+ dv,
443
+ dbeta,
444
+ dg,
445
+ offsets,
446
+ indices,
447
+ T,
448
+ H: tl.constexpr,
449
+ K: tl.constexpr,
450
+ V: tl.constexpr,
451
+ BT: tl.constexpr,
452
+ BK: tl.constexpr,
453
+ BV: tl.constexpr,
454
+ HEAD_FIRST: tl.constexpr,
455
+ USE_OFFSETS: tl.constexpr
456
+ ):
457
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
458
+ i_b, i_h = i_bh // H, i_bh % H
459
+ if USE_OFFSETS:
460
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
461
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
462
+ T = eos - bos
463
+ else:
464
+ bos, eos = i_b * T, i_b * T + T
465
+
466
+ b_dbeta = tl.zeros([BT], dtype=tl.float32)
467
+ b_dA = tl.zeros([BT, BT], dtype=tl.float32)
468
+ if HEAD_FIRST:
469
+ p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
470
+ p_A = tl.make_block_ptr(Aw + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1))
471
+ else:
472
+ p_beta = tl.make_block_ptr(beta + (bos*H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,))
473
+ p_A = tl.make_block_ptr(Aw + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1))
474
+
475
+ b_A = tl.load(p_A, boundary_check=(0, 1))
476
+ b_beta = tl.load(p_beta, boundary_check=(0,))
477
+
478
+ for i_k in range(tl.cdiv(K, BK)):
479
+ if HEAD_FIRST:
480
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
481
+ p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
482
+ p_dw = tl.make_block_ptr(dw + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
483
+ else:
484
+ p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
485
+ p_dk = tl.make_block_ptr(dk + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
486
+ p_dw = tl.make_block_ptr(dw + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
487
+ b_k = tl.load(p_k, boundary_check=(0, 1))
488
+ b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)
489
+ b_dw = tl.load(p_dw, boundary_check=(0, 1))
490
+ b_dA += tl.dot(b_dw, tl.trans(b_k_beta), allow_tf32=False)
491
+ b_dk_beta = tl.dot(b_A, b_dw, allow_tf32=False)
492
+ b_dk = b_dk_beta * b_beta[:, None]
493
+ b_dbeta += tl.sum(b_dk_beta * b_k, 1)
494
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
495
+
496
+ b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA, 0)
497
+ b_dA = tl.dot(b_dA.to(b_A.dtype), b_A)
498
+ b_dA = tl.dot(b_A, b_dA.to(b_A.dtype))
499
+ b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_dA, 0).to(k.dtype.element_ty)
500
+
501
+ if HEAD_FIRST:
502
+ p_A = tl.make_block_ptr(Au + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1))
503
+ else:
504
+ p_A = tl.make_block_ptr(Au + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1))
505
+ b_A = tl.load(p_A, boundary_check=(0, 1))
506
+ b_dA2 = tl.zeros([BT, BT], dtype=tl.float32)
507
+
508
+ for i_v in range(tl.cdiv(V, BV)):
509
+ if HEAD_FIRST:
510
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
511
+ p_dv = tl.make_block_ptr(dv + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
512
+ p_du = tl.make_block_ptr(du + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
513
+ else:
514
+ p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
515
+ p_dv = tl.make_block_ptr(dv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
516
+ p_du = tl.make_block_ptr(du + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
517
+ b_v = tl.load(p_v, boundary_check=(0, 1))
518
+ b_v_beta = (b_v * b_beta[:, None]).to(b_v.dtype)
519
+ b_du = tl.load(p_du, boundary_check=(0, 1))
520
+ b_dA2 += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False)
521
+ b_dv_beta = tl.dot(b_A, b_du, allow_tf32=False)
522
+ b_dv = b_dv_beta * b_beta[:, None]
523
+ b_dbeta += tl.sum(b_dv_beta * b_v, 1)
524
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
525
+
526
+ b_dA2 = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA2, 0)
527
+ b_dA2 = tl.dot(b_dA2.to(b_A.dtype), b_A)
528
+ b_dA2 = tl.dot(b_A, b_dA2.to(b_A.dtype))
529
+ b_dA2 = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_dA2, 0).to(k.dtype.element_ty)
530
+ if HEAD_FIRST:
531
+ p_g = tl.make_block_ptr(g + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
532
+ else:
533
+ p_g = tl.make_block_ptr(g + (bos*H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,))
534
+ b_g = tl.load(p_g, boundary_check=(0,))
535
+ b_dA2 *= safe_exp(b_g[:, None] - b_g[None, :])
536
+ b_dA += b_dA2
537
+ b_dA = b_dA.to(k.dtype.element_ty)
538
+ b_A = tl.zeros([BT, BT], dtype=tl.float32)
539
+
540
+ for i_k in range(tl.cdiv(K, BK)):
541
+ if HEAD_FIRST:
542
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
543
+ p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
544
+ else:
545
+ p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
546
+ p_dk = tl.make_block_ptr(dk + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
547
+ b_k = tl.load(p_k, boundary_check=(0, 1))
548
+ b_dk = tl.load(p_dk, boundary_check=(0, 1))
549
+ b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)
550
+ b_A += tl.dot(b_k_beta, tl.trans(b_k))
551
+ b_dk_beta = tl.dot(b_dA, b_k, allow_tf32=False)
552
+ b_dbeta += tl.sum(b_dk_beta * b_k, 1)
553
+ b_dk += tl.dot(tl.trans(b_dA), b_k_beta, allow_tf32=False)
554
+ b_dk += b_dk_beta * b_beta[:, None]
555
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
556
+ b_dA2 *= b_A
557
+ b_dg = tl.sum(b_dA2, axis=1) - tl.sum(b_dA2, axis=0)
558
+ if HEAD_FIRST:
559
+ p_dg = tl.make_block_ptr(dg + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
560
+ p_dbeta = tl.make_block_ptr(dbeta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
561
+ else:
562
+ p_dg = tl.make_block_ptr(dg + (bos*H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,))
563
+ p_dbeta = tl.make_block_ptr(dbeta + (bos*H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,))
564
+ tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,))
565
+ tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0,))
566
+
567
+
568
+ def bwd_prepare_wy_repr(
569
+ k: torch.Tensor,
570
+ v: torch.Tensor,
571
+ g: torch.Tensor,
572
+ beta: torch.Tensor,
573
+ Aw: torch.Tensor,
574
+ Au: torch.Tensor,
575
+ dw: torch.Tensor,
576
+ du: torch.Tensor,
577
+ offsets: Optional[torch.LongTensor],
578
+ indices: Optional[torch.LongTensor],
579
+ head_first: bool,
580
+ chunk_size: int
581
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
582
+ if head_first:
583
+ B, H, T, K, V = *k.shape, v.shape[-1]
584
+ else:
585
+ B, T, H, K, V = *k.shape, v.shape[-1]
586
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
587
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
588
+ CONST_TILING = 64 if check_shared_mem() else 32
589
+ BK = min(triton.next_power_of_2(K), CONST_TILING)
590
+ BV = min(triton.next_power_of_2(V), CONST_TILING)
591
+
592
+ dk = torch.empty_like(k)
593
+ dv = torch.empty_like(v)
594
+ dbeta = torch.empty_like(beta)
595
+ dg = torch.empty_like(g)
596
+ bwd_prepare_wy_repr_kernel[(NT, B * H)](
597
+ k=k,
598
+ v=v,
599
+ beta=beta,
600
+ g=g,
601
+ Aw=Aw,
602
+ Au=Au,
603
+ dw=dw,
604
+ du=du,
605
+ dk=dk,
606
+ dv=dv,
607
+ dbeta=dbeta,
608
+ dg=dg,
609
+ offsets=offsets,
610
+ indices=indices,
611
+ T=T,
612
+ H=H,
613
+ K=K,
614
+ V=V,
615
+ BT=BT,
616
+ BK=BK,
617
+ BV=BV,
618
+ HEAD_FIRST=head_first
619
+ )
620
+ return dk, dv, dbeta, dg
fla/ops/generalized_delta_rule/dplr/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from .chunk import chunk_dplr_delta_rule
2
+ from .fused_recurrent import fused_recurrent_dplr_delta_rule
3
+
4
+ __all__ = [
5
+ 'chunk_dplr_delta_rule',
6
+ 'fused_recurrent_dplr_delta_rule'
7
+ ]
fla/ops/generalized_delta_rule/dplr/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (360 Bytes). View file
 
fla/ops/generalized_delta_rule/dplr/__pycache__/chunk.cpython-312.pyc ADDED
Binary file (11.6 kB). View file
 
fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_bwd.cpython-312.pyc ADDED
Binary file (30.6 kB). View file
 
fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_fwd.cpython-312.pyc ADDED
Binary file (25.5 kB). View file
 
fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_h_fwd.cpython-312.pyc ADDED
Binary file (12.5 kB). View file
 
fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_fwd.cpython-312.pyc ADDED
Binary file (8.94 kB). View file
 
fla/ops/generalized_delta_rule/dplr/__pycache__/fused_recurrent.cpython-312.pyc ADDED
Binary file (14.5 kB). View file
 
fla/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_bwd.cpython-312.pyc ADDED
Binary file (13.2 kB). View file
 
fla/ops/generalized_delta_rule/dplr/chunk_A_fwd.py ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.utils.op import exp, gather
11
+ from fla.utils import is_gather_supported, use_cuda_graph
12
+
13
+
14
+ @triton.heuristics({
15
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
16
+ })
17
+ @triton.autotune(
18
+ configs=[
19
+ triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages)
20
+ for BK in [32, 64]
21
+ for num_warps in [2, 4, 8, 16]
22
+ for num_stages in [2, 3, 4]
23
+ ],
24
+ key=['BC', 'K'],
25
+ use_cuda_graph=use_cuda_graph,
26
+ )
27
+ @triton.jit(do_not_specialize=['T'])
28
+ def chunk_dplr_fwd_A_kernel_intra_sub_inter(
29
+ q,
30
+ k,
31
+ a,
32
+ b,
33
+ gi, # cumsum
34
+ ge, # before cumsum
35
+ Aqk,
36
+ Aqb,
37
+ Aab,
38
+ Aak,
39
+ offsets,
40
+ indices,
41
+ scale: tl.constexpr,
42
+ T,
43
+ H: tl.constexpr,
44
+ K: tl.constexpr,
45
+ BT: tl.constexpr,
46
+ BC: tl.constexpr,
47
+ BK: tl.constexpr,
48
+ NC: tl.constexpr,
49
+ USE_OFFSETS: tl.constexpr,
50
+ HEAD_FIRST: tl.constexpr,
51
+ ):
52
+ i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
53
+ i_b, i_h = i_bh // H, i_bh % H
54
+ i_i, i_j = i_c // NC, i_c % NC
55
+ if USE_OFFSETS:
56
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
57
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
58
+ T = eos - bos
59
+ else:
60
+ bos, eos = i_b * T, i_b * T + T
61
+
62
+ if i_t * BT + i_i * BC >= T:
63
+ return
64
+ if i_i <= i_j:
65
+ return
66
+
67
+ b_Aqk = tl.zeros([BC, BC], dtype=tl.float32)
68
+ b_Aqb = tl.zeros([BC, BC], dtype=tl.float32)
69
+ b_Aab = tl.zeros([BC, BC], dtype=tl.float32)
70
+ b_Aak = tl.zeros([BC, BC], dtype=tl.float32)
71
+ for i_k in range(tl.cdiv(K, BK)):
72
+ o_k = i_k * BK + tl.arange(0, BK)
73
+ m_k = o_k < K
74
+
75
+ if HEAD_FIRST:
76
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
77
+ p_a = tl.make_block_ptr(a + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
78
+ p_gq_i = tl.make_block_ptr(gi + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
79
+ p_gq_e = tl.make_block_ptr(ge + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
80
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
81
+ p_b = tl.make_block_ptr(b + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
82
+ p_gk = tl.make_block_ptr(gi + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
83
+ p_gn = tl.max_contiguous(tl.multiple_of(gi + (i_bh * T + i_t * BT + i_i * BC - 1) * K + o_k, BK), BK)
84
+ else:
85
+ p_q = tl.make_block_ptr(q + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
86
+ p_a = tl.make_block_ptr(a + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
87
+ p_gq_i = tl.make_block_ptr(gi + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
88
+ p_gq_e = tl.make_block_ptr(ge + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
89
+ p_k = tl.make_block_ptr(k + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
90
+ p_b = tl.make_block_ptr(b + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
91
+ p_gk = tl.make_block_ptr(gi + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
92
+ p_gn = gi + (bos + i_t * BT + i_i * BC - 1) * H*K + i_h * K + o_k
93
+ # [BK,]
94
+ b_gn = tl.load(p_gn, mask=m_k, other=0).to(tl.float32)
95
+ # [BC, BK]
96
+ b_q = tl.load(p_q, boundary_check=(0, 1))
97
+ b_a = tl.load(p_a, boundary_check=(0, 1))
98
+ b_gq_i = tl.load(p_gq_i, boundary_check=(0, 1))
99
+ b_gq_e = tl.load(p_gq_e, boundary_check=(0, 1))
100
+ b_ag = b_a * exp(b_gq_e - b_gn[None, :])
101
+ b_qg = b_q * exp(b_gq_i - b_gn[None, :]) * scale
102
+ # [BK, BC]
103
+ b_k = tl.load(p_k, boundary_check=(0, 1))
104
+ b_b = tl.load(p_b, boundary_check=(0, 1))
105
+ b_gk = tl.load(p_gk, boundary_check=(0, 1)).to(tl.float32)
106
+ tmp = exp(b_gn[:, None] - b_gk)
107
+ b_kg = b_k * tmp
108
+ b_bg = b_b * tmp
109
+ # [BC, BC] using tf32 to improve precision here.
110
+ b_Aab += tl.dot(b_ag, b_bg)
111
+ b_Aak += tl.dot(b_ag, b_kg)
112
+ b_Aqk += tl.dot(b_qg, b_kg)
113
+ b_Aqb += tl.dot(b_qg, b_bg)
114
+
115
+ if HEAD_FIRST:
116
+ p_Aqk = tl.make_block_ptr(Aqk + i_bh*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
117
+ p_Aqb = tl.make_block_ptr(Aqb + i_bh*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
118
+ p_Aab = tl.make_block_ptr(Aab + i_bh*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
119
+ p_Aak = tl.make_block_ptr(Aak + i_bh*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
120
+ else:
121
+ p_Aqk = tl.make_block_ptr(Aqk + (bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
122
+ p_Aqb = tl.make_block_ptr(Aqb + (bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
123
+ p_Aab = tl.make_block_ptr(Aab + (bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
124
+ p_Aak = tl.make_block_ptr(Aak + (bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
125
+ tl.store(p_Aqk, b_Aqk.to(Aqk.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
126
+ tl.store(p_Aqb, b_Aqb.to(Aqb.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
127
+ tl.store(p_Aab, b_Aab.to(Aab.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
128
+ tl.store(p_Aak, b_Aak.to(Aak.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
129
+
130
+
131
+ @triton.heuristics({
132
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
133
+ })
134
+ @triton.autotune(
135
+ configs=[
136
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
137
+ for num_warps in [2, 4, 8, 16, 32]
138
+ for num_stages in [2, 3, 4]
139
+ ],
140
+ key=['BK', 'BT'],
141
+ use_cuda_graph=use_cuda_graph,
142
+ )
143
+ @triton.jit(do_not_specialize=['T'])
144
+ def chunk_dplr_fwd_A_kernel_intra_sub_intra(
145
+ q,
146
+ k,
147
+ a,
148
+ b,
149
+ gi,
150
+ ge,
151
+ qg,
152
+ kg,
153
+ ag,
154
+ bg,
155
+ Aqk,
156
+ Aqb,
157
+ Aab,
158
+ Aak,
159
+ offsets,
160
+ indices,
161
+ scale: tl.constexpr,
162
+ T,
163
+ H: tl.constexpr,
164
+ K: tl.constexpr,
165
+ BT: tl.constexpr,
166
+ BC: tl.constexpr,
167
+ BK: tl.constexpr,
168
+ NC: tl.constexpr,
169
+ USE_OFFSETS: tl.constexpr,
170
+ HEAD_FIRST: tl.constexpr,
171
+ GATHER_SUPPORTED: tl.constexpr
172
+ ):
173
+ i_t, i_i, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
174
+ i_b, i_h = i_bh // H, i_bh % H
175
+ i_j = i_i
176
+ if USE_OFFSETS:
177
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
178
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
179
+ T = eos - bos
180
+ else:
181
+ bos, eos = i_b * T, i_b * T + T
182
+
183
+ if i_t * BT + i_i * BC >= T:
184
+ return
185
+
186
+ o_i = tl.arange(0, BC)
187
+ o_k = tl.arange(0, BK)
188
+ m_k = o_k < K
189
+ m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
190
+ last_idx = min((i_t+1) * BT, T) - 1
191
+ if HEAD_FIRST:
192
+ o_A = i_bh * T*BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC
193
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
194
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
195
+ p_a = tl.make_block_ptr(a + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
196
+ p_b = tl.make_block_ptr(b + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
197
+ p_gi = tl.make_block_ptr(gi + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
198
+ p_ge = tl.make_block_ptr(ge + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
199
+ p_g_last = gi + i_bh * T*K + last_idx * K + tl.arange(0, BK)
200
+ b_g_last = tl.load(p_g_last, mask=m_k, other=0)
201
+
202
+ p_qg = tl.make_block_ptr(qg + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
203
+ p_kg = tl.make_block_ptr(kg + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
204
+ p_ag = tl.make_block_ptr(ag + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
205
+ p_bg = tl.make_block_ptr(bg + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
206
+ else:
207
+ o_A = (bos + i_t * BT + i_i * BC + tl.arange(0, BC)) * H*BT + i_h * BT + i_j * BC
208
+ p_q = tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
209
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
210
+ p_a = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
211
+ p_b = tl.make_block_ptr(b + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
212
+ p_gi = tl.make_block_ptr(gi + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
213
+ p_ge = tl.make_block_ptr(ge + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
214
+ p_g_last = gi + (bos * H + i_h) * K + last_idx * H * K + tl.arange(0, BK)
215
+ b_g_last = tl.load(p_g_last, mask=m_k, other=0)
216
+ p_qg = tl.make_block_ptr(qg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
217
+ p_kg = tl.make_block_ptr(kg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
218
+ p_ag = tl.make_block_ptr(ag + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
219
+ p_bg = tl.make_block_ptr(bg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
220
+
221
+ b_q = tl.load(p_q, boundary_check=(0, 1))
222
+ b_q = b_q * scale
223
+ b_k = tl.load(p_k, boundary_check=(0, 1))
224
+ b_a = tl.load(p_a, boundary_check=(0, 1))
225
+ b_b = tl.load(p_b, boundary_check=(0, 1))
226
+ b_gi = tl.load(p_gi, boundary_check=(0, 1)).to(tl.float32)
227
+ b_ge = tl.load(p_ge, boundary_check=(0, 1)).to(tl.float32)
228
+
229
+ # deal with decay term.
230
+ g_exp = exp(b_gi)
231
+ g_exp_inv = exp(-b_gi + b_g_last[None, :])
232
+ b_qg = b_q * g_exp
233
+ b_kg = b_k * g_exp_inv
234
+ b_bg = b_b * g_exp_inv
235
+ b_ag = b_a * exp(b_ge)
236
+ tl.store(p_qg, b_qg.to(p_qg.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
237
+ tl.store(p_bg, b_bg.to(p_bg.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
238
+ tl.store(p_ag, b_ag.to(p_ag.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
239
+ tl.store(p_kg, b_kg.to(p_kg.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
240
+ # tl.debug_barrier()
241
+
242
+ b_q = b_q.to(b_k.dtype)
243
+ # inner attn
244
+ for j in range(0, min(BC, T - i_t * BT - i_i * BC)):
245
+ # a trick to index the j-th row of b_k, b_g, b_b
246
+ if GATHER_SUPPORTED:
247
+ row_idx = tl.full([1, BK], j, dtype=tl.int16)
248
+ # [1, BK]
249
+ b_k_j = gather(b_k, row_idx, axis=0)
250
+ b_gk_j = gather(b_gi, row_idx, axis=0)
251
+ b_b_j = gather(b_b, row_idx, axis=0)
252
+ else:
253
+ mask = tl.arange(0, BC) == j
254
+ b_k_j = tl.sum(tl.where(mask[:, None], b_k, 0), 0)[None, :]
255
+ b_gk_j = tl.sum(tl.where(mask[:, None], b_gi, 0), 0)[None, :]
256
+ b_b_j = tl.sum(tl.where(mask[:, None], b_b, 0), 0)[None, :]
257
+ mask = tl.arange(0, BC) == j
258
+ tmp = exp(b_gi - b_gk_j)
259
+ b_A_qk = tl.sum(b_q * b_k_j * tmp, 1)
260
+ b_A_qk = tl.where(o_i >= j, b_A_qk, 0.)
261
+ b_A_qb = tl.sum(b_q * b_b_j * tmp, 1)
262
+ b_A_qb = tl.where(o_i >= j, b_A_qb, 0.)
263
+ tmp2 = exp(b_ge - b_gk_j)
264
+ b_A_ak = tl.sum(b_a * b_k_j * tmp2, 1)
265
+ b_A_ak = tl.where(o_i > j, b_A_ak, 0.)
266
+ b_A_ab = tl.sum(b_a * b_b_j * tmp2, 1)
267
+ b_A_ab = tl.where(o_i > j, b_A_ab, 0.)
268
+ tl.store(Aqk + o_A + j, b_A_qk.to(dtype=Aqk.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A)
269
+ tl.store(Aqb + o_A + j, b_A_qb.to(dtype=Aqb.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A)
270
+ tl.store(Aab + o_A + j, b_A_ab.to(dtype=Aqb.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A)
271
+ tl.store(Aak + o_A + j, b_A_ak.to(dtype=Aqk.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A)
272
+
273
+
274
+ def chunk_fwd_intra_dplr_fn(
275
+ q: torch.Tensor,
276
+ k: torch.Tensor,
277
+ a: torch.Tensor,
278
+ b: torch.Tensor,
279
+ gi: torch.Tensor,
280
+ ge: torch.Tensor,
281
+ scale: float,
282
+ chunk_size: int,
283
+ offsets: Optional[torch.LongTensor] = None,
284
+ indices: Optional[torch.LongTensor] = None,
285
+ head_first: bool = True,
286
+ ):
287
+ if head_first:
288
+ B, H, T, K = k.shape
289
+ else:
290
+ B, T, H, K = k.shape
291
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
292
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
293
+ BC = min(16, BT)
294
+ NC = triton.cdiv(BT, BC)
295
+
296
+ Aqk = q.new_empty(B, *((H, T) if head_first else (T, H)), BT, dtype=q.dtype)
297
+ Aqb = q.new_empty(B, *((H, T) if head_first else (T, H)), BT, dtype=q.dtype)
298
+ # involving matrix inverse and it'd be better to use float here.
299
+ Aab = q.new_empty(B, *((H, T) if head_first else (T, H)), BT, dtype=torch.float)
300
+ Aak = q.new_empty(B, *((H, T) if head_first else (T, H)), BT, dtype=torch.float)
301
+ grid = (NT, NC * NC, B * H)
302
+
303
+ chunk_dplr_fwd_A_kernel_intra_sub_inter[grid](
304
+ q=q, k=k, a=a, b=b, gi=gi, ge=ge, Aqk=Aqk, Aqb=Aqb, Aab=Aab, Aak=Aak,
305
+ offsets=offsets, indices=indices,
306
+ scale=scale,
307
+ T=T, H=H, K=K, BT=BT, BC=BC, NC=NC,
308
+ HEAD_FIRST=head_first
309
+ )
310
+ grid = (NT, NC, B * H)
311
+ BK = triton.next_power_of_2(K)
312
+ qg = torch.empty_like(q)
313
+ kg = torch.empty_like(k, dtype=q.dtype)
314
+ ag = torch.empty_like(a, dtype=q.dtype)
315
+ bg = torch.empty_like(b, dtype=q.dtype)
316
+ chunk_dplr_fwd_A_kernel_intra_sub_intra[grid](
317
+ q=q, k=k, a=a, b=b, gi=gi, ge=ge, Aqk=Aqk, Aqb=Aqb, Aab=Aab, Aak=Aak,
318
+ qg=qg, kg=kg, ag=ag, bg=bg,
319
+ offsets=offsets, indices=indices,
320
+ scale=scale,
321
+ T=T, H=H, K=K, BT=BT, BC=BC, BK=BK, HEAD_FIRST=head_first, NC=NC,
322
+ GATHER_SUPPORTED=is_gather_supported
323
+ )
324
+ return Aab, Aqk, Aak, Aqb, qg, kg, ag, bg
fla/ops/generalized_delta_rule/dplr/chunk_h_fwd.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.common.utils import prepare_chunk_offsets
11
+ from fla.ops.utils.op import exp
12
+ from fla.utils import check_shared_mem, use_cuda_graph
13
+
14
+
15
+ @triton.heuristics({
16
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
17
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
18
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
19
+ })
20
+ @triton.autotune(
21
+ configs=[
22
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
23
+ for num_warps in [2, 4, 8, 16, 32]
24
+ for num_stages in [2, 3, 4]
25
+ ],
26
+ key=['BT', 'BK', 'BV'],
27
+ use_cuda_graph=use_cuda_graph,
28
+ )
29
+ @triton.jit(do_not_specialize=['T'])
30
+ def chunk_dplr_fwd_kernel_h(
31
+ kg,
32
+ v,
33
+ w,
34
+ bg,
35
+ u,
36
+ v_new,
37
+ gk,
38
+ h,
39
+ h0,
40
+ ht,
41
+ offsets,
42
+ chunk_offsets,
43
+ T,
44
+ H: tl.constexpr,
45
+ K: tl.constexpr,
46
+ V: tl.constexpr,
47
+ BT: tl.constexpr,
48
+ BC: tl.constexpr,
49
+ BK: tl.constexpr,
50
+ BV: tl.constexpr,
51
+ NT: tl.constexpr,
52
+ USE_INITIAL_STATE: tl.constexpr,
53
+ STORE_FINAL_STATE: tl.constexpr,
54
+ USE_OFFSETS: tl.constexpr,
55
+ HEAD_FIRST: tl.constexpr,
56
+ ):
57
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
58
+ i_n, i_h = i_nh // H, i_nh % H
59
+ if USE_OFFSETS:
60
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
61
+ T = eos - bos
62
+ NT = tl.cdiv(T, BT)
63
+ boh = tl.load(chunk_offsets + i_n).to(tl.int32)
64
+ else:
65
+ bos, eos = i_n * T, i_n * T + T
66
+ NT = tl.cdiv(T, BT)
67
+ boh = i_n * NT
68
+
69
+ # [BK, BV]
70
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
71
+ if USE_INITIAL_STATE:
72
+ p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
73
+ b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
74
+
75
+ for i_t in range(NT):
76
+ if HEAD_FIRST:
77
+ p_h = tl.make_block_ptr(h + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
78
+ else:
79
+ p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
80
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
81
+
82
+ b_hc = tl.zeros([BK, BV], dtype=tl.float32)
83
+ # since we need to make all DK in the SRAM. we face serve SRAM memory burden. By subchunking we allievate such burden
84
+ for i_c in range(tl.cdiv(min(BT, T - i_t * BT), BC)):
85
+ if HEAD_FIRST:
86
+ p_kg = tl.make_block_ptr(kg + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
87
+ p_bg = tl.make_block_ptr(bg + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
88
+ p_w = tl.make_block_ptr(w + i_nh * T*K, (T, K), (K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
89
+ p_v = tl.make_block_ptr(v + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
90
+ p_u = tl.make_block_ptr(u + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
91
+ p_v_new = tl.make_block_ptr(v_new+i_nh*T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
92
+ else:
93
+ p_kg = tl.make_block_ptr(kg+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
94
+ p_bg = tl.make_block_ptr(bg+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
95
+ p_w = tl.make_block_ptr(w+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
96
+ p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
97
+ p_u = tl.make_block_ptr(u+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
98
+ p_v_new = tl.make_block_ptr(v_new+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT+i_c*BC, i_v * BV), (BC, BV), (1, 0))
99
+ # [BK, BC]
100
+ b_kg = tl.load(p_kg, boundary_check=(0, 1))
101
+ b_v = tl.load(p_v, boundary_check=(0, 1))
102
+ b_w = tl.load(p_w, boundary_check=(0, 1))
103
+ b_bg = tl.load(p_bg, boundary_check=(0, 1))
104
+ b_v2 = tl.dot(b_w, b_h.to(b_w.dtype)) + tl.load(p_u, boundary_check=(0, 1))
105
+ b_hc += tl.dot(b_kg, b_v)
106
+ b_hc += tl.dot(b_bg.to(b_hc.dtype), b_v2)
107
+ tl.store(p_v_new, b_v2.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
108
+
109
+ last_idx = min((i_t + 1) * BT, T) - 1
110
+ if HEAD_FIRST:
111
+ b_g_last = tl.load(gk + i_nh * T * K + last_idx * K + tl.arange(0, BK), mask=tl.arange(0, BK) < K).to(tl.float32)
112
+ else:
113
+ b_g_last = tl.load(gk + (bos + last_idx) * H * K + i_h * K +
114
+ tl.arange(0, BK), mask=tl.arange(0, BK) < K).to(tl.float32)
115
+ b_h *= exp(b_g_last[:, None])
116
+ b_h += b_hc
117
+
118
+ if STORE_FINAL_STATE:
119
+ p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
120
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
121
+
122
+
123
+ def chunk_dplr_fwd_h(
124
+ kg: torch.Tensor,
125
+ v: torch.Tensor,
126
+ w: torch.Tensor,
127
+ u: torch.Tensor,
128
+ bg: torch.Tensor,
129
+ gk: torch.Tensor,
130
+ initial_state: Optional[torch.Tensor] = None,
131
+ output_final_state: bool = False,
132
+ offsets: Optional[torch.LongTensor] = None,
133
+ indices: Optional[torch.LongTensor] = None,
134
+ head_first: bool = True,
135
+ chunk_size: int = 64
136
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
137
+ if head_first:
138
+ B, H, T, K, V = *kg.shape, u.shape[-1]
139
+ else:
140
+ B, T, H, K, V = *kg.shape, u.shape[-1]
141
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
142
+ # N: the actual number of sequences in the batch with either equal or variable lengths
143
+ if offsets is None:
144
+ N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
145
+ else:
146
+ N, NT, chunk_offsets = len(offsets) - 1, len(indices), prepare_chunk_offsets(offsets, BT)
147
+ BK = triton.next_power_of_2(K)
148
+ assert BK <= 256, "current kernel does not support head dimension larger than 256."
149
+ # H100 can have larger block size
150
+
151
+ if check_shared_mem('hopper', kg.device.index):
152
+ BV = 64
153
+ BC = 64 if K <= 128 else 32
154
+ elif check_shared_mem('ampere', kg.device.index): # A100
155
+ BV = 32
156
+ BC = 32
157
+ else:
158
+ BV = 16
159
+ BC = 16
160
+
161
+ BC = min(BT, BC)
162
+ NK = triton.cdiv(K, BK)
163
+ NV = triton.cdiv(V, BV)
164
+ assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
165
+
166
+ if head_first:
167
+ h = kg.new_empty(B, H, NT, K, V)
168
+ else:
169
+ h = kg.new_empty(B, NT, H, K, V)
170
+ final_state = kg.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
171
+ v_new = torch.empty_like(u)
172
+ grid = (NK, NV, N * H)
173
+ chunk_dplr_fwd_kernel_h[grid](
174
+ kg=kg,
175
+ v=v,
176
+ w=w,
177
+ bg=bg,
178
+ u=u,
179
+ v_new=v_new,
180
+ h=h,
181
+ gk=gk,
182
+ h0=initial_state,
183
+ ht=final_state,
184
+ offsets=offsets,
185
+ chunk_offsets=chunk_offsets,
186
+ T=T,
187
+ H=H,
188
+ K=K,
189
+ V=V,
190
+ BT=BT,
191
+ BC=BC,
192
+ BK=BK,
193
+ BV=BV,
194
+ NT=NT,
195
+ HEAD_FIRST=head_first
196
+ )
197
+ return h, v_new, final_state
fla/ops/generalized_delta_rule/dplr/chunk_o_bwd.py ADDED
@@ -0,0 +1,464 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.utils.op import exp
11
+ from fla.utils import check_shared_mem, use_cuda_graph
12
+
13
+ BK_LIST = [32, 64, 128] if check_shared_mem() else [16, 32]
14
+
15
+
16
+ @triton.heuristics({
17
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
18
+ })
19
+ @triton.autotune(
20
+ configs=[
21
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
22
+ for num_warps in [2, 4, 8, 16, 32]
23
+ for num_stages in [2, 3, 4]
24
+ ],
25
+ key=['BV', 'BT'],
26
+ use_cuda_graph=use_cuda_graph,
27
+ )
28
+ @triton.jit(do_not_specialize=['T'])
29
+ def chunk_dplr_bwd_kernel_dAu(
30
+ v,
31
+ do,
32
+ v_new,
33
+ A_qb,
34
+ dA_qk,
35
+ dA_qb,
36
+ dv_new,
37
+ offsets,
38
+ indices,
39
+ scale: tl.constexpr,
40
+ T,
41
+ H: tl.constexpr,
42
+ V: tl.constexpr,
43
+ BT: tl.constexpr,
44
+ BV: tl.constexpr,
45
+ USE_OFFSETS: tl.constexpr,
46
+ HEAD_FIRST: tl.constexpr
47
+ ):
48
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
49
+ i_b, i_h = i_bh // H, i_bh % H
50
+ if USE_OFFSETS:
51
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
52
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
53
+ else:
54
+ bos, eos = i_b * T, i_b * T + T
55
+ T = eos - bos
56
+
57
+ b_dA_qk = tl.zeros([BT, BT], dtype=tl.float32)
58
+ b_dA_qb = tl.zeros([BT, BT], dtype=tl.float32)
59
+
60
+ if HEAD_FIRST:
61
+ p_A_qb = tl.make_block_ptr(A_qb + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
62
+ else:
63
+ p_A_qb = tl.make_block_ptr(A_qb + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
64
+
65
+ b_A_qb = tl.load(p_A_qb, boundary_check=(0, 1))
66
+ # causal mask
67
+ b_A_qb = tl.where(tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :], b_A_qb, 0.).to(b_A_qb.dtype)
68
+
69
+ for i_v in range(tl.cdiv(V, BV)):
70
+ if HEAD_FIRST:
71
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
72
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, i_t * BT), (BV, BT), (0, 1))
73
+ p_v_new = tl.make_block_ptr(v_new + i_bh * T*V, (V, T), (1, V), (i_v * BV, i_t * BT), (BV, BT), (0, 1))
74
+ p_dv_new = tl.make_block_ptr(dv_new + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
75
+ else:
76
+ p_do = tl.make_block_ptr(do + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
77
+ p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t * BT), (BV, BT), (0, 1))
78
+ p_v_new = tl.make_block_ptr(v_new + (bos*H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t * BT), (BV, BT), (0, 1))
79
+ p_dv_new = tl.make_block_ptr(dv_new + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
80
+ b_v = tl.load(p_v, boundary_check=(0, 1))
81
+ b_do = tl.load(p_do, boundary_check=(0, 1))
82
+ b_v_new = tl.load(p_v_new, boundary_check=(0, 1))
83
+ b_dA_qk += tl.dot(b_do, b_v)
84
+ b_dA_qb += tl.dot(b_do, b_v_new)
85
+ b_dv_new = tl.dot(tl.trans(b_A_qb), b_do)
86
+ # for recurrent
87
+ tl.store(p_dv_new, b_dv_new.to(p_dv_new.dtype.element_ty), boundary_check=(0, 1))
88
+
89
+ if HEAD_FIRST:
90
+ p_dA_qk = tl.make_block_ptr(dA_qk + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
91
+ p_dA_qb = tl.make_block_ptr(dA_qb + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
92
+ else:
93
+ p_dA_qk = tl.make_block_ptr(dA_qk + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
94
+ p_dA_qb = tl.make_block_ptr(dA_qb + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
95
+ m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :]
96
+ b_dA_qk = tl.where(m_s, b_dA_qk * scale, 0.)
97
+ tl.store(p_dA_qk, b_dA_qk.to(p_dA_qk.dtype.element_ty), boundary_check=(0, 1))
98
+ b_dA_qb = tl.where(m_s, b_dA_qb * scale, 0.)
99
+ tl.store(p_dA_qb, b_dA_qb.to(p_dA_qb.dtype.element_ty), boundary_check=(0, 1))
100
+
101
+
102
+ @triton.heuristics({
103
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
104
+ })
105
+ @triton.autotune(
106
+ configs=[
107
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
108
+ for num_warps in [2, 4, 8, 16, 32]
109
+ for num_stages in [2, 3, 4]
110
+ ],
111
+ key=['BT', 'BK', 'BV'],
112
+ use_cuda_graph=use_cuda_graph,
113
+ )
114
+ @triton.jit
115
+ def chunk_dplr_bwd_o_kernel(
116
+ v,
117
+ v_new,
118
+ h,
119
+ do,
120
+ dh,
121
+ dk,
122
+ db,
123
+ w,
124
+ dq,
125
+ dv,
126
+ dw,
127
+ gk,
128
+ dgk_last,
129
+ k,
130
+ b,
131
+ offsets,
132
+ indices,
133
+ T,
134
+ H: tl.constexpr,
135
+ K: tl.constexpr,
136
+ V: tl.constexpr,
137
+ BT: tl.constexpr,
138
+ BK: tl.constexpr,
139
+ BV: tl.constexpr,
140
+ USE_OFFSETS: tl.constexpr,
141
+ HEAD_FIRST: tl.constexpr,
142
+ ):
143
+ i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
144
+ i_b, i_h = i_bh // H, i_bh % H
145
+
146
+ if USE_OFFSETS:
147
+ i_tg = i_t
148
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
149
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
150
+ T = eos - bos
151
+ NT = tl.cdiv(T, BT)
152
+ else:
153
+ NT = tl.cdiv(T, BT)
154
+ i_tg = i_b * NT + i_t
155
+ bos, eos = i_b * T, i_b * T + T
156
+
157
+ # offset calculation
158
+ v += i_bh * T * V if HEAD_FIRST else (bos * H + i_h) * V
159
+ v_new += i_bh * T * V if HEAD_FIRST else (bos * H + i_h) * V
160
+ do += i_bh * T * V if HEAD_FIRST else (bos * H + i_h) * V
161
+ h += (i_bh * NT + i_t) * K*V if HEAD_FIRST else (i_tg * H + i_h) * K * V
162
+ dh += (i_bh * NT + i_t) * K*V if HEAD_FIRST else (i_tg * H + i_h) * K * V
163
+ dk += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
164
+ k += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
165
+ db += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
166
+ b += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
167
+ dw += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
168
+ dv += i_bh * T * V if HEAD_FIRST else (bos * H + i_h) * V
169
+ dq += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
170
+ w += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
171
+ # CHECK HEAD_FIRST is FALSE
172
+ dgk_last += (i_bh * NT + i_t) * K if HEAD_FIRST else (i_tg * H + i_h) * K
173
+ gk += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
174
+
175
+ stride_qk = K if HEAD_FIRST else H*K
176
+ stride_vo = V if HEAD_FIRST else H*V
177
+
178
+ b_dq = tl.zeros([BT, BK], dtype=tl.float32)
179
+ b_dk = tl.zeros([BT, BK], dtype=tl.float32)
180
+ b_dw = tl.zeros([BT, BK], dtype=tl.float32)
181
+ b_db = tl.zeros([BT, BK], dtype=tl.float32)
182
+ b_dgk_last = tl.zeros([BK], dtype=tl.float32)
183
+
184
+ for i_v in range(tl.cdiv(V, BV)):
185
+ p_v = tl.make_block_ptr(v, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
186
+ p_v_new = tl.make_block_ptr(v_new, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
187
+ p_do = tl.make_block_ptr(do, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
188
+ p_h = tl.make_block_ptr(h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
189
+ p_dh = tl.make_block_ptr(dh, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
190
+ # [BT, BV]
191
+ b_v = tl.load(p_v, boundary_check=(0, 1))
192
+ b_v_new = tl.load(p_v_new, boundary_check=(0, 1))
193
+ b_do = tl.load(p_do, boundary_check=(0, 1))
194
+ # [BV, BK]
195
+ b_h = tl.load(p_h, boundary_check=(0, 1))
196
+ b_dh = tl.load(p_dh, boundary_check=(0, 1))
197
+ b_dgk_last += tl.sum((b_h * b_dh).to(tl.float32), axis=0)
198
+
199
+ # [BT, BV] @ [BV, BK] -> [BT, BK]
200
+ b_dq += tl.dot(b_do, b_h.to(b_do.dtype))
201
+ # [BT, BV] @ [BV, BK] -> [BT, BK]
202
+ b_dk += tl.dot(b_v, b_dh.to(b_v.dtype))
203
+ b_db += tl.dot(b_v_new, b_dh.to(b_v_new.dtype))
204
+ p_dv = tl.make_block_ptr(dv, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
205
+ b_dv = tl.load(p_dv, boundary_check=(0, 1))
206
+ b_dw += tl.dot(b_dv.to(b_v.dtype), b_h.to(b_v.dtype))
207
+
208
+ m_k = (i_k*BK+tl.arange(0, BK)) < K
209
+ last_idx = min(i_t * BT + BT, T) - 1
210
+ b_gk_last = tl.load(gk + last_idx * stride_qk + i_k*BK + tl.arange(0, BK), mask=m_k, other=float('-inf'))
211
+ b_dgk_last *= exp(b_gk_last)
212
+ p_k = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
213
+ p_b = tl.make_block_ptr(b, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
214
+ b_k = tl.load(p_k, boundary_check=(0, 1))
215
+ b_b = tl.load(p_b, boundary_check=(0, 1))
216
+ b_dgk_last += tl.sum(b_k * b_dk, axis=0)
217
+ b_dgk_last += tl.sum(b_b * b_db, axis=0)
218
+ tl.store(dgk_last + tl.arange(0, BK) + i_k * BK, b_dgk_last, mask=m_k)
219
+
220
+ p_dw = tl.make_block_ptr(dw, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
221
+ p_dk = tl.make_block_ptr(dk, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
222
+ p_db = tl.make_block_ptr(db, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
223
+ p_dq = tl.make_block_ptr(dq, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
224
+ tl.store(p_dw, b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1))
225
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
226
+ tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0, 1))
227
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
228
+
229
+
230
+ @triton.heuristics({
231
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
232
+ })
233
+ @triton.autotune(
234
+ configs=[
235
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
236
+ for num_warps in [2, 4, 8, 16, 32]
237
+ for num_stages in [2, 3, 4]
238
+ for BK in BK_LIST
239
+ for BV in BK_LIST
240
+ ],
241
+ key=['BT', 'BK', 'BV'],
242
+ use_cuda_graph=use_cuda_graph,
243
+ )
244
+ @triton.jit
245
+ def chunk_dplr_bwd_kernel_dv(
246
+ A_qk,
247
+ kg,
248
+ do,
249
+ dv,
250
+ dh,
251
+ offsets,
252
+ indices,
253
+ T,
254
+ H: tl.constexpr,
255
+ K: tl.constexpr,
256
+ V: tl.constexpr,
257
+ BT: tl.constexpr,
258
+ BK: tl.constexpr,
259
+ BV: tl.constexpr,
260
+ USE_OFFSETS: tl.constexpr,
261
+ HEAD_FIRST: tl.constexpr,
262
+ ):
263
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
264
+ i_b, i_h = i_bh // H, i_bh % H
265
+ if USE_OFFSETS:
266
+ i_tg = i_t
267
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
268
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
269
+ T = eos - bos
270
+ NT = tl.cdiv(T, BT)
271
+ else:
272
+ NT = tl.cdiv(T, BT)
273
+ i_tg = i_b * NT + i_t
274
+ bos, eos = i_b * T, i_b * T + T
275
+
276
+ b_dv = tl.zeros([BT, BV], dtype=tl.float32)
277
+
278
+ # offset calculation
279
+ A_qk += i_bh * T * BT if HEAD_FIRST else (bos * H + i_h) * BT
280
+ do += i_bh * T * V if HEAD_FIRST else (bos * H + i_h) * V
281
+ dv += i_bh * T * V if HEAD_FIRST else (bos * H + i_h) * V
282
+ kg += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
283
+ dh += (i_bh * NT + i_t) * K*V if HEAD_FIRST else (i_tg * H + i_h) * K*V
284
+
285
+ stride_qk = K if HEAD_FIRST else H*K
286
+ stride_vo = V if HEAD_FIRST else H*V
287
+ stride_A = BT if HEAD_FIRST else H*BT
288
+
289
+ for i_k in range(tl.cdiv(K, BK)):
290
+ p_dh = tl.make_block_ptr(dh, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
291
+ p_kg = tl.make_block_ptr(kg, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
292
+ b_dh = tl.load(p_dh, boundary_check=(0, 1))
293
+ b_kg = tl.load(p_kg, boundary_check=(0, 1))
294
+ b_dv += tl.dot(b_kg, b_dh.to(b_kg.dtype))
295
+
296
+ p_Aqk = tl.make_block_ptr(A_qk, (BT, T), (1, stride_A), (0, i_t * BT), (BT, BT), (0, 1))
297
+ b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], tl.load(p_Aqk, boundary_check=(0, 1)), 0)
298
+ p_do = tl.make_block_ptr(do, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
299
+ p_dv = tl.make_block_ptr(dv, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
300
+ b_do = tl.load(p_do, boundary_check=(0, 1))
301
+ b_dv += tl.dot(b_A.to(b_do.dtype), b_do)
302
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
303
+
304
+
305
+ def chunk_dplr_bwd_dv(
306
+ A_qk: torch.Tensor,
307
+ kg: torch.Tensor,
308
+ do: torch.Tensor,
309
+ dh: torch.Tensor,
310
+ offsets: Optional[torch.LongTensor] = None,
311
+ indices: Optional[torch.LongTensor] = None,
312
+ head_first: bool = True,
313
+ chunk_size: int = 64
314
+ ) -> torch.Tensor:
315
+ if head_first:
316
+ B, H, T, K, V = *kg.shape, do.shape[-1]
317
+ else:
318
+ B, T, H, K, V = *kg.shape, do.shape[-1]
319
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
320
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
321
+
322
+ dv = torch.empty_like(do)
323
+
324
+ def grid(meta): return (
325
+ triton.cdiv(V, meta['BV']),
326
+ NT,
327
+ B * H
328
+ )
329
+ chunk_dplr_bwd_kernel_dv[grid](
330
+ A_qk=A_qk,
331
+ kg=kg,
332
+ do=do,
333
+ dv=dv,
334
+ dh=dh,
335
+ offsets=offsets,
336
+ indices=indices,
337
+ T=T,
338
+ H=H,
339
+ K=K,
340
+ V=V,
341
+ BT=BT,
342
+ HEAD_FIRST=head_first
343
+ )
344
+ return dv
345
+
346
+
347
+ def chunk_dplr_bwd_o(
348
+ k: torch.Tensor,
349
+ b: torch.Tensor,
350
+ v: torch.Tensor,
351
+ v_new: torch.Tensor,
352
+ gk: torch.Tensor,
353
+ do: torch.Tensor,
354
+ h: torch.Tensor,
355
+ dh: torch.Tensor,
356
+ dv: torch.Tensor,
357
+ w: torch.Tensor,
358
+ offsets: Optional[torch.LongTensor] = None,
359
+ indices: Optional[torch.LongTensor] = None,
360
+ chunk_size: int = 64,
361
+ scale: float = 1.0,
362
+ head_first: bool = True,
363
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
364
+
365
+ if head_first:
366
+ B, H, T, K, V = *w.shape, v.shape[-1]
367
+ else:
368
+ B, T, H, K, V = *w.shape, v.shape[-1]
369
+
370
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
371
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
372
+
373
+ BK = min(triton.next_power_of_2(K), 64) if check_shared_mem() else min(triton.next_power_of_2(K), 32)
374
+ BV = min(triton.next_power_of_2(V), 64) if check_shared_mem() else min(triton.next_power_of_2(K), 32)
375
+ NK = triton.cdiv(K, BK)
376
+ dq = torch.empty_like(k)
377
+ dk = torch.empty_like(k)
378
+ dw = torch.empty_like(w)
379
+ db = torch.empty_like(b)
380
+ grid = (NK, NT, B * H)
381
+
382
+ dgk_last = torch.empty(B, H, NT, K, dtype=torch.float, device=w.device) if head_first \
383
+ else torch.empty(B, NT, H, K, dtype=torch.float, device=w.device)
384
+
385
+ chunk_dplr_bwd_o_kernel[grid](
386
+ k=k,
387
+ b=b,
388
+ v=v,
389
+ v_new=v_new,
390
+ h=h,
391
+ do=do,
392
+ dh=dh,
393
+ dq=dq,
394
+ dk=dk,
395
+ db=db,
396
+ dgk_last=dgk_last,
397
+ w=w,
398
+ dv=dv,
399
+ dw=dw,
400
+ gk=gk,
401
+ offsets=offsets,
402
+ indices=indices,
403
+ T=T,
404
+ H=H,
405
+ K=K,
406
+ V=V,
407
+ BT=BT,
408
+ BK=BK,
409
+ BV=BV,
410
+ HEAD_FIRST=head_first,
411
+ )
412
+ return dq, dk, dw, db, dgk_last
413
+
414
+
415
+ def chunk_dplr_bwd_dAu(
416
+ v: torch.Tensor,
417
+ v_new: torch.Tensor,
418
+ do: torch.Tensor,
419
+ A_qb: torch.Tensor,
420
+ scale: float,
421
+ offsets: Optional[torch.LongTensor] = None,
422
+ indices: Optional[torch.LongTensor] = None,
423
+ head_first: bool = True,
424
+ chunk_size: int = 64
425
+ ) -> torch.Tensor:
426
+ if head_first:
427
+ B, H, T, V = v.shape
428
+ else:
429
+ B, T, H, V = v.shape
430
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
431
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
432
+
433
+ if check_shared_mem('ampere'): # A100
434
+ BV = min(triton.next_power_of_2(V), 128)
435
+ elif check_shared_mem('ada'): # 4090
436
+ BV = min(triton.next_power_of_2(V), 64)
437
+ else:
438
+ BV = min(triton.next_power_of_2(V), 32)
439
+
440
+ grid = (NT, B * H)
441
+ dA_qk = torch.empty(B, H, T, BT, dtype=torch.float, device=v.device) if head_first \
442
+ else torch.empty(B, T, H, BT, dtype=torch.float, device=v.device)
443
+ dA_qb = torch.empty(B, H, T, BT, dtype=torch.float, device=v.device) if head_first \
444
+ else torch.empty(B, T, H, BT, dtype=torch.float, device=v.device)
445
+ dv_new = torch.empty_like(v_new)
446
+ chunk_dplr_bwd_kernel_dAu[grid](
447
+ v=v,
448
+ do=do,
449
+ v_new=v_new,
450
+ A_qb=A_qb,
451
+ dA_qk=dA_qk,
452
+ dA_qb=dA_qb,
453
+ dv_new=dv_new,
454
+ offsets=offsets,
455
+ indices=indices,
456
+ scale=scale,
457
+ T=T,
458
+ H=H,
459
+ V=V,
460
+ BT=BT,
461
+ BV=BV,
462
+ HEAD_FIRST=head_first
463
+ )
464
+ return dv_new, dA_qk, dA_qb
fla/ops/generalized_delta_rule/dplr/chunk_o_fwd.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.utils import check_shared_mem, use_cuda_graph
11
+
12
+ BK_LIST = [32, 64, 128] if check_shared_mem() else [16, 32]
13
+
14
+
15
+ @triton.heuristics({
16
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
17
+ })
18
+ @triton.autotune(
19
+ configs=[
20
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
21
+ for BK in BK_LIST
22
+ for BV in BK_LIST
23
+ for num_warps in [2, 4, 8, 16, 32]
24
+ for num_stages in [2, 3, 4]
25
+ ],
26
+ key=['BT'],
27
+ use_cuda_graph=use_cuda_graph,
28
+ )
29
+ @triton.jit(do_not_specialize=['T'])
30
+ def chunk_dplr_fwd_kernel_o(
31
+ qg,
32
+ v,
33
+ v_new,
34
+ A_qk,
35
+ A_qb,
36
+ h,
37
+ o,
38
+ offsets,
39
+ indices,
40
+ T,
41
+ H: tl.constexpr,
42
+ K: tl.constexpr,
43
+ V: tl.constexpr,
44
+ BT: tl.constexpr,
45
+ BK: tl.constexpr,
46
+ BV: tl.constexpr,
47
+ USE_OFFSETS: tl.constexpr,
48
+ HEAD_FIRST: tl.constexpr,
49
+ ):
50
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
51
+ i_b, i_h = i_bh // H, i_bh % H
52
+
53
+ if USE_OFFSETS:
54
+ i_tg = i_t
55
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
56
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
57
+ T = eos - bos
58
+ NT = tl.cdiv(T, BT)
59
+ else:
60
+ NT = tl.cdiv(T, BT)
61
+ i_tg = i_b * NT + i_t
62
+ bos, eos = i_b * T, i_b * T + T
63
+
64
+ b_o = tl.zeros([BT, BV], dtype=tl.float32)
65
+ for i_k in range(tl.cdiv(K, BK)):
66
+ if HEAD_FIRST:
67
+ p_qg = tl.make_block_ptr(qg + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
68
+ p_h = tl.make_block_ptr(h + (i_bh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
69
+ else:
70
+ p_qg = tl.make_block_ptr(qg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
71
+ p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
72
+ b_qg = tl.load(p_qg, boundary_check=(0, 1))
73
+ b_h = tl.load(p_h, boundary_check=(0, 1))
74
+ b_o += tl.dot(b_qg, b_h)
75
+
76
+ if HEAD_FIRST:
77
+ p_Aqk = tl.make_block_ptr(A_qk + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
78
+ p_Aqb = tl.make_block_ptr(A_qb + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
79
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
80
+ p_v_new = tl.make_block_ptr(v_new + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
81
+ p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
82
+ else:
83
+ p_Aqk = tl.make_block_ptr(A_qk + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
84
+ p_Aqb = tl.make_block_ptr(A_qb + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
85
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
86
+ p_v_new = tl.make_block_ptr(v_new + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
87
+ p_o = tl.make_block_ptr(o + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
88
+
89
+ m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :]
90
+ b_Aqk = tl.load(p_Aqk, boundary_check=(0, 1))
91
+ b_Aqb = tl.load(p_Aqb, boundary_check=(0, 1))
92
+ b_Aqk = tl.where(m_s, b_Aqk, 0)
93
+ b_Aqb = tl.where(m_s, b_Aqb, 0)
94
+ b_v = tl.load(p_v, boundary_check=(0, 1))
95
+ b_v_new = tl.load(p_v_new, boundary_check=(0, 1))
96
+ b_o = b_o + tl.dot(b_Aqk.to(b_v.dtype), b_v) + tl.dot(b_Aqb.to(b_v_new.dtype), b_v_new)
97
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
98
+
99
+
100
+ def chunk_dplr_fwd_o(
101
+ qg: torch.Tensor,
102
+ v: torch.Tensor,
103
+ v_new: torch.Tensor,
104
+ A_qk: torch.Tensor,
105
+ A_qb: torch.Tensor,
106
+ h: torch.Tensor,
107
+ offsets: Optional[torch.LongTensor] = None,
108
+ indices: Optional[torch.LongTensor] = None,
109
+ head_first: bool = True,
110
+ chunk_size: int = 64
111
+ ) -> torch.Tensor:
112
+ if head_first:
113
+ B, H, T, K, V = *qg.shape, v.shape[-1]
114
+ else:
115
+ B, T, H, K, V = *qg.shape, v.shape[-1]
116
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
117
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
118
+
119
+ o = torch.empty_like(v)
120
+ def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H)
121
+ chunk_dplr_fwd_kernel_o[grid](
122
+ qg=qg,
123
+ v=v,
124
+ v_new=v_new,
125
+ A_qk=A_qk,
126
+ A_qb=A_qb,
127
+ h=h,
128
+ o=o,
129
+ offsets=offsets,
130
+ indices=indices,
131
+ T=T,
132
+ H=H,
133
+ K=K,
134
+ V=V,
135
+ BT=BT,
136
+ HEAD_FIRST=head_first
137
+ )
138
+ return o
fla/ops/generalized_delta_rule/dplr/fused_recurrent.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.utils.op import exp
11
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard, use_cuda_graph
12
+
13
+
14
+ @triton.heuristics({
15
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
16
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
17
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
18
+ })
19
+ @triton.autotune(
20
+ configs=[
21
+ triton.Config({'BV': BV}, num_warps=num_warps, num_stages=num_stages)
22
+ for BV in [16, 32, 64]
23
+ for num_warps in [2, 4, 8, 16]
24
+ for num_stages in [2, 3, 4]
25
+ ],
26
+ key=['BK'],
27
+ use_cuda_graph=use_cuda_graph,
28
+ )
29
+ @triton.jit(do_not_specialize=['T'])
30
+ def fused_recurrent_dplr_delta_rule_fwd_kernel(
31
+ q,
32
+ k,
33
+ v,
34
+ a,
35
+ b,
36
+ gk,
37
+ o,
38
+ h0,
39
+ ht,
40
+ offsets,
41
+ scale,
42
+ T,
43
+ B: tl.constexpr,
44
+ H: tl.constexpr,
45
+ K: tl.constexpr,
46
+ V: tl.constexpr,
47
+ BK: tl.constexpr,
48
+ BV: tl.constexpr,
49
+ REVERSE: tl.constexpr,
50
+ USE_INITIAL_STATE: tl.constexpr,
51
+ STORE_FINAL_STATE: tl.constexpr,
52
+ USE_OFFSETS: tl.constexpr,
53
+ HEAD_FIRST: tl.constexpr
54
+ ):
55
+ i_v, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64)
56
+ i_n, i_h = i_nh // H, i_nh % H
57
+
58
+ if USE_OFFSETS:
59
+ bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
60
+ T = eos - bos
61
+ else:
62
+ bos, eos = i_n * T, i_n * T + T
63
+
64
+ o_k = tl.arange(0, BK)
65
+ o_v = i_v * BV + tl.arange(0, BV)
66
+ if HEAD_FIRST:
67
+ p_q = q + i_nh * T*K + ((T-1) * K if REVERSE else 0) + o_k
68
+ p_k = k + i_nh * T*K + ((T-1) * K if REVERSE else 0) + o_k
69
+ p_a = a + i_nh * T*K + ((T-1) * K if REVERSE else 0) + o_k
70
+ p_b = b + i_nh * T*K + ((T-1) * K if REVERSE else 0) + o_k
71
+ p_gk = gk + i_nh * T*K + ((T-1) * K if REVERSE else 0) + o_k
72
+ p_v = v + i_nh * T*V + ((T-1) * V if REVERSE else 0) + o_v
73
+ p_o = o + i_nh * T*V + ((T-1) * V if REVERSE else 0) + o_v
74
+
75
+ else:
76
+ p_q = q + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
77
+ p_k = k + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
78
+ p_a = a + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
79
+ p_b = b + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
80
+ p_gk = gk + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
81
+ p_v = v + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v
82
+ p_o = o + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v
83
+
84
+ mask_k = o_k < K
85
+ mask_v = o_v < V
86
+ mask_h = mask_k[None, :] & mask_v[:, None]
87
+ b_h = tl.zeros([BV, BK], dtype=tl.float32)
88
+
89
+ if USE_INITIAL_STATE:
90
+ p_h0 = h0 + i_nh * K*V + o_k[None, :] * V + o_v[:, None]
91
+ b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
92
+
93
+ for _ in range(0, T):
94
+ b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
95
+ b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
96
+ b_a = tl.load(p_a, mask=mask_k, other=0).to(tl.float32)
97
+ b_b = tl.load(p_b, mask=mask_k, other=0).to(tl.float32)
98
+ b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)
99
+ b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
100
+
101
+ tmp = tl.sum(b_h * b_a[None, :], axis=1)
102
+ b_h = exp(b_gk)[None, :] * b_h + (tmp[:, None] * b_b[None, :] + b_k[None, :] * b_v[:, None])
103
+ b_o = tl.sum(b_h * b_q[None, :], axis=1)
104
+
105
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
106
+ p_q += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
107
+ p_k += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
108
+ p_a += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
109
+ p_b += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
110
+ p_gk += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
111
+ p_v += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
112
+ p_o += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
113
+
114
+ if STORE_FINAL_STATE:
115
+ p_ht = ht + i_nh * K*V + o_k[None, :] * V + o_v[:, None]
116
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
117
+
118
+
119
+ def fused_recurrent_dplr_delta_rule_fwd(
120
+ q: torch.Tensor,
121
+ k: torch.Tensor,
122
+ v: torch.Tensor,
123
+ a: torch.Tensor,
124
+ b: torch.Tensor,
125
+ gk: torch.Tensor,
126
+ scale: Optional[float] = 1.0,
127
+ initial_state: Optional[torch.Tensor] = None,
128
+ output_final_state: bool = False,
129
+ reverse: bool = False,
130
+ offsets: Optional[torch.LongTensor] = None,
131
+ head_first: bool = True
132
+ ):
133
+ if head_first:
134
+ B, H, T, K, V = *k.shape, v.shape[-1]
135
+ else:
136
+ B, T, H, K, V = *k.shape, v.shape[-1]
137
+ N = B if offsets is None else len(offsets) - 1
138
+ BK = triton.next_power_of_2(K)
139
+
140
+ h0 = initial_state
141
+ if output_final_state:
142
+ ht = q.new_empty(N, H, K, V, dtype=torch.float32)
143
+ else:
144
+ ht = None
145
+ o = torch.empty_like(v)
146
+
147
+ def grid(meta): return (triton.cdiv(V, meta['BV']), N * H)
148
+ fused_recurrent_dplr_delta_rule_fwd_kernel[grid](
149
+ q,
150
+ k,
151
+ v,
152
+ a,
153
+ b,
154
+ gk,
155
+ o,
156
+ h0,
157
+ ht,
158
+ offsets,
159
+ scale,
160
+ T=T,
161
+ B=B,
162
+ H=H,
163
+ K=K,
164
+ V=V,
165
+ BK=BK,
166
+ REVERSE=reverse,
167
+ HEAD_FIRST=head_first
168
+ )
169
+ return o, ht
170
+
171
+
172
+ class FusedRecurrentDPLRDeltaRuleFunction(torch.autograd.Function):
173
+
174
+ @staticmethod
175
+ @input_guard
176
+ @autocast_custom_fwd
177
+ def forward(
178
+ ctx,
179
+ q: torch.Tensor,
180
+ k: torch.Tensor,
181
+ v: torch.Tensor,
182
+ a: torch.Tensor,
183
+ b: torch.Tensor,
184
+ gk: torch.Tensor,
185
+ scale: Optional[float] = 1.0,
186
+ initial_state: Optional[torch.Tensor] = None,
187
+ output_final_state: bool = False,
188
+ reverse: bool = False,
189
+ offsets: Optional[torch.LongTensor] = None,
190
+ head_first: bool = False
191
+ ):
192
+ o, ht = fused_recurrent_dplr_delta_rule_fwd(
193
+ q=q,
194
+ k=k,
195
+ v=v,
196
+ a=a,
197
+ b=b,
198
+ gk=gk,
199
+ scale=scale,
200
+ initial_state=initial_state,
201
+ output_final_state=output_final_state,
202
+ reverse=reverse,
203
+ offsets=offsets,
204
+ head_first=head_first
205
+ )
206
+ return o, ht
207
+
208
+ @staticmethod
209
+ @input_guard
210
+ @autocast_custom_bwd
211
+ def backward(ctx, do, dht):
212
+ raise NotImplementedError(
213
+ "Backward pass for fused_recurrent_dplr_delta_rule is not implemented and will not be supported. "
214
+ "This kernel is only for inference. "
215
+ "For training, please use `chunk_dplr_delta_rule`."
216
+ )
217
+
218
+
219
+ def fused_recurrent_dplr_delta_rule(
220
+ q: torch.Tensor,
221
+ k: torch.Tensor,
222
+ v: torch.Tensor,
223
+ a: torch.Tensor,
224
+ b: torch.Tensor,
225
+ gk: torch.Tensor,
226
+ scale: Optional[float] = 1.0,
227
+ initial_state: Optional[torch.Tensor] = None,
228
+ output_final_state: bool = False,
229
+ reverse: bool = False,
230
+ cu_seqlens: Optional[torch.Tensor] = None,
231
+ head_first: bool = False
232
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
233
+ r"""
234
+ This function computes the recurrence S_t = S_t @ (I + a_t b_t^T) + v_t k_t^T in a recurrent manner.
235
+
236
+ Args:
237
+ q (torch.Tensor):
238
+ queries of shape `[B, H, T, K]`
239
+ k (torch.Tensor):
240
+ keys of shape `[B, H, T, K]`
241
+ v (torch.Tensor):
242
+ values of shape `[B, H, T, V]`
243
+ a (torch.Tensor):
244
+ as of shape `[B, H, T, K]`
245
+ b (torch.Tensor):
246
+ bs of shape `[B, H, T, K]`
247
+ gk (torch.Tensor):
248
+ gk of shape `[B, H, T, K]`
249
+ scale (Optional[int]):
250
+ Scale factor for the RetNet attention scores.
251
+ If None, it will default to `1 / sqrt(K)`. Default: `1.0`.
252
+ initial_state (Optional[torch.Tensor]):
253
+ Initial state of shape `[B, H, K, V]`. Default: `None`.
254
+ output_final_state (Optional[bool]):
255
+ Whether to output the final state of shape `[B, H, K, V]`. Default: `False`.
256
+ reverse (Optional[bool]):
257
+ If `True`, process the state passing in reverse order. Default: `False`.
258
+ cu_seqlens (Optional[torch.Tensor]):
259
+ Cumulative sequence lengths of shape `[N + 1]` used for variable-length training,
260
+ consistent with the FlashAttention API.
261
+ head_first (Optional[bool]):
262
+ Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
263
+ Default: `False`.
264
+ """
265
+ if cu_seqlens is not None:
266
+ if q.shape[0] != 1:
267
+ raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
268
+ f"Please flatten variable-length inputs before processing.")
269
+ if head_first:
270
+ raise RuntimeError("Sequences with variable lengths are not supported for head-first mode")
271
+ if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
272
+ raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, "
273
+ f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.")
274
+ if scale is None:
275
+ scale = q.shape[-1] ** -0.5
276
+ else:
277
+ assert scale > 0, "scale must be positive"
278
+ o, final_state = FusedRecurrentDPLRDeltaRuleFunction.apply(
279
+ q,
280
+ k,
281
+ v,
282
+ a,
283
+ b,
284
+ gk,
285
+ scale,
286
+ initial_state,
287
+ output_final_state,
288
+ reverse,
289
+ cu_seqlens,
290
+ head_first
291
+ )
292
+ return o, final_state
fla/ops/generalized_delta_rule/dplr/naive.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+ from einops import rearrange
5
+
6
+ # S_t = S_t @ (I + alpha_t beta_t^T) + v_t k_t^T
7
+ # q, k, alpha, beta [B, H, L, D_K]
8
+ # v [B, H, L, D_V]
9
+
10
+
11
+ def dplr_recurrence(q, k, v, alpha, beta, gk, initial_state=None, output_final_state=True):
12
+ orig_dtype = q.dtype
13
+ b, h, l, d_k = q.shape
14
+ q, k, v, beta, gk = map(lambda x: x.float(), [q, k, v, beta, gk])
15
+ d_v = v.shape[-1]
16
+ o = torch.zeros_like(v)
17
+ S = torch.zeros(b, h, d_k, d_v).to(v)
18
+ q = q * (d_k ** -0.5)
19
+
20
+ if initial_state is not None:
21
+ S += initial_state
22
+
23
+ for i in range(l):
24
+ _k = k[:, :, i]
25
+ _q = q[:, :, i]
26
+ _v = v[:, :, i]
27
+ _alpha = alpha[:, :, i].clone()
28
+ _beta = beta[:, :, i].clone()
29
+ _kv = _k[..., None] * _v[..., None, :] + (S.clone() * _alpha[..., None]).sum(-2, keepdim=True) * _beta[..., None]
30
+ S = S.clone() * gk[:, :, i].exp()[..., None] + _kv
31
+ o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S)
32
+ S = None if output_final_state is False else S
33
+ return o.to(orig_dtype), S
34
+
35
+
36
+ def dplr_chunkwise(q, k, v, alpha, beta, gk, initial_state=None, output_final_state=True, chunk_size=32):
37
+ b, h, l, d_k = q.shape
38
+ d_v = v.shape[-1]
39
+ q = q * (d_k ** -0.5)
40
+ v = v
41
+ assert l % chunk_size == 0
42
+
43
+ S = k.new_zeros(b, h, d_k, d_v).to(q)
44
+ if initial_state is not None:
45
+ S += initial_state
46
+
47
+ # note that diagonal is masked.
48
+ mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=0)
49
+ q, k, v, alpha, beta, gk = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d',
50
+ c=chunk_size).float(), [q, k, v, alpha, beta, gk])
51
+
52
+ gk_cumsum = gk.cumsum(-2)
53
+
54
+ # v2 = (alpha @ k.transpose(-1, -2)).masked_fill_(mask, 0) @ v
55
+ A_ab = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device)
56
+ A_qk = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device)
57
+ A_ak = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device)
58
+ A_qb = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device)
59
+
60
+ for i in range(chunk_size):
61
+ alpha_i = alpha[:, :, :, i, None]
62
+ q_i = q[:, :, :, i, None]
63
+ gk_i = gk_cumsum[:, :, :, i, None]
64
+ mask = (torch.arange(chunk_size) <= i).to(q.device)
65
+ attn_i = (gk_i - gk_cumsum).masked_fill(~mask.unsqueeze(-1), float('-inf')).exp()
66
+ A_qk[:, :, :, i, :] = (q_i * k * attn_i).sum(-1).clone()
67
+ A_qb[:, :, :, i, :] = (q_i * beta * attn_i).sum(-1).clone()
68
+ mask = (torch.arange(chunk_size) < i).to(q.device)
69
+ # shift by one.
70
+ attn_i = (gk_i - gk[:, :, :, i, None] - gk_cumsum).masked_fill(~mask.unsqueeze(-1), float('-inf')).exp()
71
+ A_ab[:, :, :, i, :] = (alpha_i * beta * attn_i).sum(-1).clone()
72
+ A_ak[:, :, :, i, :] = (alpha_i * k * attn_i).sum(-1).clone()
73
+
74
+ A_ab = A_ab
75
+ for i in range(1, chunk_size):
76
+ A_ab[..., i, :i] = A_ab[..., i, :i].clone() + (A_ab[..., i, :, None].clone() * A_ab[..., :, :i].clone()).sum(-2)
77
+
78
+ A_ab = A_ab + torch.eye(chunk_size, dtype=torch.float, device=q.device)
79
+ u = A_ab @ (A_ak @ v)
80
+ w = A_ab @ ((gk_cumsum-gk).exp() * alpha)
81
+
82
+ o = torch.zeros_like(v)
83
+ mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=1)
84
+ for i in range(0, l // chunk_size):
85
+ q_i, k_i, v_i, u_i, w_i, beta_i = q[:, :, i], k[:, :, i], v[:, :, i], u[:, :, i], w[:, :, i], beta[:, :, i]
86
+ v2_i = u_i + w_i @ S
87
+
88
+ o_1 = A_qk[:, :, i] @ v_i
89
+ o_2 = A_qb[:, :, i] @ v2_i
90
+ o_3 = (q_i * gk_cumsum[:, :, i].exp()) @ S
91
+ o[:, :, i] = o_1 + o_2 + o_3
92
+ decay = (gk_cumsum[:, :, i, -1, None] - gk_cumsum[:, :, i]).exp()
93
+ S = S*gk_cumsum[:, :, i, -1, :, None].exp() + (k_i * decay).transpose(-1, -2) @ v_i + \
94
+ (beta_i * decay).transpose(-1, -2) @ v2_i
95
+ S = None if output_final_state is False else S
96
+ return rearrange(o, 'b h n c d -> b h (n c) d'), S
fla/ops/generalized_delta_rule/dplr/wy_fast_bwd.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.utils import check_shared_mem, is_intel_alchemist, use_cuda_graph
11
+
12
+ # https://github.com/intel/intel-xpu-backend-for-triton/issues/3449
13
+ triton_config = {'grf_mode': 'large'} if is_intel_alchemist else {}
14
+
15
+
16
+ @triton.heuristics({
17
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
18
+ })
19
+ @triton.autotune(
20
+ configs=[
21
+ triton.Config(triton_config, num_warps=num_warps, num_stages=num_stages)
22
+ for num_warps in [2, 4, 8, 16, 32]
23
+ for num_stages in [2, 3, 4]
24
+ ],
25
+ key=['BT', 'BK', 'BV'],
26
+ use_cuda_graph=use_cuda_graph,
27
+ )
28
+ @triton.jit(do_not_specialize=['T'])
29
+ def bwd_prepare_wy_repr_kernel(
30
+ A_ab_inv,
31
+ A_ak,
32
+ ag,
33
+ v,
34
+ dw,
35
+ du,
36
+ dv,
37
+ dv0,
38
+ dag,
39
+ dAak,
40
+ dAab,
41
+ offsets,
42
+ indices,
43
+ T,
44
+ H: tl.constexpr,
45
+ K: tl.constexpr,
46
+ V: tl.constexpr,
47
+ BT: tl.constexpr,
48
+ BK: tl.constexpr,
49
+ BV: tl.constexpr,
50
+ USE_OFFSETS: tl.constexpr,
51
+ HEAD_FIRST: tl.constexpr
52
+ ):
53
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
54
+ i_b, i_h = i_bh // H, i_bh % H
55
+ if USE_OFFSETS:
56
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
57
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
58
+ T = eos - bos
59
+ else:
60
+ bos, eos = i_b * T, i_b * T + T
61
+
62
+ if HEAD_FIRST:
63
+ p_Aab_inv_t = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1))
64
+ p_Aak_t = tl.make_block_ptr(A_ak + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1))
65
+ p_dAak = tl.make_block_ptr(dAak + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
66
+ p_dAab = tl.make_block_ptr(dAab + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
67
+ else:
68
+ p_Aak_t = tl.make_block_ptr(A_ak + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1))
69
+ p_Aab_inv_t = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1))
70
+ p_dAak = tl.make_block_ptr(dAak + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
71
+ p_dAab = tl.make_block_ptr(dAab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
72
+
73
+ b_A_ab_inv_t = tl.load(p_Aab_inv_t, boundary_check=(0, 1))
74
+ b_A_ak_t = tl.load(p_Aak_t, boundary_check=(0, 1))
75
+ b_A_ak_t = tl.where(tl.arange(0, BT)[:, None] < tl.arange(0, BT)[None, :], b_A_ak_t, 0)
76
+ b_A_ab_inv_t = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A_ab_inv_t, 0)
77
+ b_A_tmp_t = tl.dot(b_A_ak_t, b_A_ab_inv_t).to(v.dtype.element_ty)
78
+ b_dA_tmp = tl.zeros([BT, BT], dtype=tl.float32)
79
+
80
+ for i_v in range(tl.cdiv(V, BV)):
81
+ if HEAD_FIRST:
82
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
83
+ p_dv = tl.make_block_ptr(dv + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
84
+ p_dv0 = tl.make_block_ptr(dv0 + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
85
+ p_du = tl.make_block_ptr(du + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
86
+ else:
87
+ p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
88
+ p_dv = tl.make_block_ptr(dv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
89
+ p_dv0 = tl.make_block_ptr(dv0 + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
90
+ p_du = tl.make_block_ptr(du + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
91
+ b_v = tl.load(p_v, boundary_check=(0, 1))
92
+ b_du = tl.load(p_du, boundary_check=(0, 1))
93
+ b_dA_tmp += tl.dot(b_du.to(b_v.dtype), tl.trans(b_v))
94
+ b_dv0 = tl.load(p_dv0, boundary_check=(0, 1))
95
+ b_dv = b_dv0 + tl.dot(b_A_tmp_t, b_du)
96
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
97
+
98
+ b_dA_tmp = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA_tmp, 0)
99
+ b_dA_ak = tl.dot(b_A_ab_inv_t, b_dA_tmp)
100
+ b_dA_ak = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA_ak, 0)
101
+ tl.store(p_dAak, b_dA_ak, boundary_check=(0, 1))
102
+ b_dA_ab_inv = tl.dot(b_dA_tmp, b_A_ak_t)
103
+
104
+ for i_k in range(tl.cdiv(K, BK)):
105
+ if HEAD_FIRST:
106
+ p_ag = tl.make_block_ptr(ag + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
107
+ p_dag = tl.make_block_ptr(dag + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
108
+ p_dw = tl.make_block_ptr(dw + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
109
+ else:
110
+ p_ag = tl.make_block_ptr(ag + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
111
+ p_dag = tl.make_block_ptr(dag + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
112
+ p_dw = tl.make_block_ptr(dw + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
113
+ b_ag = tl.load(p_ag, boundary_check=(0, 1))
114
+ b_dw = tl.load(p_dw, boundary_check=(0, 1))
115
+ b_dA_ab_inv += tl.dot(b_dw, tl.trans(b_ag))
116
+ b_dag = tl.dot(b_A_ab_inv_t.to(b_dw.dtype), b_dw)
117
+ tl.store(p_dag, b_dag.to(p_dag.dtype.element_ty), boundary_check=(0, 1))
118
+
119
+ # if we know dL/dA^(-1), for dL/dA, we can use the following formula:
120
+ # dL/dA = -(A^(-1))^T @ (dL/dA^(-1)) @ (A^(-1))^T
121
+ # in the fwd pass we use fwd substitution to calculate (I-lower(A_ab))^-1.
122
+ # denote A = I - lower(A_ab), B = A^-1
123
+ # in the backward pass.
124
+ # dL/dA = -(B)^T @ (dL/dB) @ B^T
125
+ # dL/dA_ab = lower(B^T @ dL/dB @ B^T)
126
+ b_dA_ab_inv = tl.where(tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :], b_dA_ab_inv, 0)
127
+ b_dA_ab_inv = tl.dot(b_A_ab_inv_t, b_dA_ab_inv)
128
+ b_dA_ab_inv = tl.dot(b_dA_ab_inv, b_A_ab_inv_t)
129
+ b_dA_ab_inv = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA_ab_inv, 0)
130
+ tl.store(p_dAab, b_dA_ab_inv, boundary_check=(0, 1))
131
+
132
+
133
+ def chunk_dplr_bwd_wy(
134
+ A_ab_inv: torch.Tensor,
135
+ A_ak: torch.Tensor,
136
+ v: torch.Tensor,
137
+ ag: torch.Tensor,
138
+ dw: torch.Tensor,
139
+ du: torch.Tensor,
140
+ dv0: torch.Tensor,
141
+ offsets: Optional[torch.LongTensor],
142
+ indices: Optional[torch.LongTensor],
143
+ head_first: bool,
144
+ chunk_size: int,
145
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
146
+ A_ab_inv, A_ak, v, ag, dw, du = map(lambda x: x.contiguous(), [A_ab_inv, A_ak, v, ag, dw, du])
147
+ if head_first:
148
+ B, H, T, K, V = *dw.shape, du.shape[-1]
149
+ else:
150
+ B, T, H, K, V = *dw.shape, du.shape[-1]
151
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
152
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
153
+ BK = min(triton.next_power_of_2(K), 64)
154
+ BV = min(triton.next_power_of_2(V), 64) if check_shared_mem() else min(triton.next_power_of_2(V), 32)
155
+
156
+ dA_ab = torch.empty_like(A_ab_inv, dtype=torch.float)
157
+ dA_ak = torch.empty_like(A_ak, dtype=torch.float)
158
+ dv = torch.empty_like(v)
159
+ dag = torch.empty_like(ag)
160
+
161
+ bwd_prepare_wy_repr_kernel[(NT, B * H)](
162
+ A_ab_inv=A_ab_inv,
163
+ A_ak=A_ak,
164
+ ag=ag,
165
+ v=v,
166
+ dw=dw,
167
+ du=du,
168
+ dv=dv,
169
+ dv0=dv0,
170
+ dag=dag,
171
+ dAak=dA_ak,
172
+ dAab=dA_ab,
173
+ offsets=offsets,
174
+ indices=indices,
175
+ T=T,
176
+ H=H,
177
+ K=K,
178
+ V=V,
179
+ BT=BT,
180
+ BK=BK,
181
+ BV=BV,
182
+ HEAD_FIRST=head_first
183
+ )
184
+ return dA_ab, dA_ak, dv, dag
fla/ops/generalized_delta_rule/dplr/wy_fast_fwd.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2024, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.utils.op import gather
11
+ from fla.utils import is_gather_supported, use_cuda_graph
12
+
13
+
14
+ @triton.heuristics({
15
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
16
+ })
17
+ @triton.autotune(
18
+ configs=[
19
+ triton.Config({}, num_warps=num_warps)
20
+ for num_warps in [1, 2, 4, 8, 16]
21
+ ],
22
+ key=['BT'],
23
+ use_cuda_graph=use_cuda_graph,
24
+ )
25
+ @triton.jit(do_not_specialize=['T'])
26
+ def fwd_prepare_wy_repr_kernel_chunk32(
27
+ A_ab,
28
+ A_ab_inv,
29
+ offsets,
30
+ indices,
31
+ T,
32
+ H: tl.constexpr,
33
+ BT: tl.constexpr,
34
+ BC: tl.constexpr, # placeholder, do not delete
35
+ USE_OFFSETS: tl.constexpr,
36
+ HEAD_FIRST: tl.constexpr
37
+ ):
38
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
39
+ i_b, i_h = i_bh // H, i_bh % H
40
+ if USE_OFFSETS:
41
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
42
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
43
+ T = eos - bos
44
+ else:
45
+ bos, eos = i_b * T, i_b * T + T
46
+ if HEAD_FIRST:
47
+ p_Aab = tl.make_block_ptr(A_ab + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
48
+ p_Aab_inv = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
49
+ else:
50
+ p_Aab = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
51
+ p_Aab_inv = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
52
+ b_A_ab = tl.load(p_Aab, boundary_check=(0, 1))
53
+ b_A_ab = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A_ab, 0)
54
+ for i in range(1, BT):
55
+ mask = tl.arange(0, BT) == i
56
+ b_a = tl.sum(tl.where(mask[:, None], b_A_ab, 0), 0)
57
+ b_a = b_a + tl.sum(b_a[:, None] * b_A_ab, 0) * (tl.arange(0, BT) < i)
58
+ b_A_ab = tl.where(mask[:, None], b_a, b_A_ab)
59
+ b_A_ab += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :]
60
+ tl.store(p_Aab_inv, b_A_ab.to(p_Aab_inv.dtype.element_ty), boundary_check=(0, 1))
61
+
62
+
63
+ @triton.heuristics({
64
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
65
+ })
66
+ @triton.autotune(
67
+ configs=[
68
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
69
+ for num_warps in [2, 4, 8]
70
+ for num_stages in [2, 3, 4]
71
+ ],
72
+ key=['BC'],
73
+ use_cuda_graph=use_cuda_graph,
74
+ )
75
+ @triton.jit(do_not_specialize=['T'])
76
+ def fwd_prepare_wy_repr_kernel_chunk64(
77
+ A_ab,
78
+ A_ab_inv,
79
+ offsets,
80
+ indices,
81
+ T,
82
+ H: tl.constexpr,
83
+ BT: tl.constexpr,
84
+ BC: tl.constexpr,
85
+ USE_OFFSETS: tl.constexpr,
86
+ HEAD_FIRST: tl.constexpr,
87
+ GATHER_SUPPORTED: tl.constexpr = is_gather_supported
88
+ ):
89
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
90
+ i_b, i_h = i_bh // H, i_bh % H
91
+ if USE_OFFSETS:
92
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
93
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
94
+ T = eos - bos
95
+ else:
96
+ bos, eos = i_b * T, i_b * T + T
97
+
98
+ if HEAD_FIRST:
99
+
100
+ p_A1 = tl.make_block_ptr(A_ab + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
101
+ p_A2 = tl.make_block_ptr(A_ab + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
102
+ p_A3 = tl.make_block_ptr(A_ab + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
103
+ p_A_inv1 = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
104
+ p_A_inv2 = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
105
+ p_A_inv3 = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
106
+ p_A_inv4 = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))
107
+ else:
108
+ p_A1 = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
109
+ p_A2 = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
110
+ p_A3 = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
111
+ p_A_inv1 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
112
+ p_A_inv2 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
113
+ p_A_inv3 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
114
+ p_A_inv4 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))
115
+
116
+ b_A = tl.load(p_A1, boundary_check=(0, 1))
117
+ b_A2 = tl.load(p_A2, boundary_check=(0, 1))
118
+ b_A3 = tl.load(p_A3, boundary_check=(0, 1))
119
+ b_A = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A, 0)
120
+ b_A2 = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A2, 0)
121
+
122
+ for i in range(1, BC):
123
+ if GATHER_SUPPORTED:
124
+ row_idx = tl.full([1, BC], i, dtype=tl.int16)
125
+ # [1, BK] -> [BK]
126
+ b_a = tl.sum(gather(b_A, row_idx, axis=0), 0)
127
+ b_a2 = tl.sum(gather(b_A2, row_idx, axis=0), 0)
128
+ else:
129
+ mask = tl.arange(0, BC) == i
130
+ b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
131
+ b_a2 = tl.sum(tl.where(mask[:, None], b_A2, 0), 0)
132
+ mask = tl.arange(0, BC) == i
133
+ # b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
134
+ # b_a2 = tl.sum(tl.where(mask[:, None], b_A2, 0), 0)
135
+ b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BC) < i)
136
+ b_a2 = b_a2 + tl.sum(b_a2[:, None] * b_A2, 0) * (tl.arange(0, BC) < i)
137
+ b_A = tl.where(mask[:, None], b_a, b_A)
138
+ b_A2 = tl.where(mask[:, None], b_a2, b_A2)
139
+
140
+ # blockwise computation of lower triangular matrix's inverse
141
+ # i.e., [A11, 0; A21, A22]^-1 = [A11^-1, 0; -A22^-1 A21 A11^-1, A22^-1]
142
+ b_A += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
143
+ b_A2 += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
144
+ b_A3 = tl.dot(tl.dot(b_A2, b_A3), b_A)
145
+ # tl.debug_barrier()
146
+ tl.store(p_A_inv1, b_A.to(p_A_inv1.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
147
+ tl.store(p_A_inv2, b_A2.to(p_A_inv2.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
148
+ tl.store(p_A_inv3, b_A3.to(p_A_inv3.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
149
+ # causal mask
150
+ tl.store(p_A_inv4, tl.zeros([BC, BC], dtype=tl.float32).to(p_A_inv4.dtype.element_ty), boundary_check=(0, 1))
151
+
152
+
153
+ @triton.heuristics({
154
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
155
+ })
156
+ @triton.autotune(
157
+ configs=[
158
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
159
+ for num_warps in [2, 4, 8, 16, 32]
160
+ for num_stages in [2, 3, 4]
161
+ ],
162
+ key=['BT', 'BK', 'BV'],
163
+ use_cuda_graph=use_cuda_graph,
164
+ )
165
+ @triton.jit(do_not_specialize=['T'])
166
+ def fwd_wu_kernel(
167
+ u,
168
+ w,
169
+ ag,
170
+ v,
171
+ A_ab_inv,
172
+ A_ak,
173
+ offsets,
174
+ indices,
175
+ T,
176
+ H: tl.constexpr,
177
+ K: tl.constexpr,
178
+ V: tl.constexpr,
179
+ BT: tl.constexpr,
180
+ BK: tl.constexpr,
181
+ BV: tl.constexpr,
182
+ USE_OFFSETS: tl.constexpr,
183
+ HEAD_FIRST: tl.constexpr,
184
+ ):
185
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
186
+ i_b, i_h = i_bh // H, i_bh % H
187
+ if USE_OFFSETS:
188
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
189
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
190
+ T = eos - bos
191
+ else:
192
+ bos, eos = i_b * T, i_b * T + T
193
+
194
+ if HEAD_FIRST:
195
+ p_A_ab_inv = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
196
+ p_A_ak = tl.make_block_ptr(A_ak + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
197
+ else:
198
+ p_A_ab_inv = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
199
+ p_A_ak = tl.make_block_ptr(A_ak + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
200
+ b_Aab_inv = tl.load(p_A_ab_inv, boundary_check=(0, 1))
201
+ b_Aak = tl.load(p_A_ak, boundary_check=(0, 1))
202
+ o_s = tl.arange(0, BT)
203
+ b_Aab_inv = tl.where(o_s[:, None] >= o_s[None, :], b_Aab_inv, 0)
204
+ b_Aak = tl.where(o_s[:, None] > o_s[None, :], b_Aak, 0)
205
+ # let's use tf32 here
206
+ b_Aak = tl.dot(b_Aab_inv, b_Aak)
207
+ # (SY 01/04) should be bf16 or tf32? To verify.
208
+ b_Aak = b_Aak.to(v.dtype.element_ty, fp_downcast_rounding="rtne")
209
+ b_Aab_inv = b_Aab_inv.to(ag.dtype.element_ty, fp_downcast_rounding="rtne")
210
+
211
+ for i_k in range(tl.cdiv(K, BK)):
212
+ if HEAD_FIRST:
213
+ p_ag = tl.make_block_ptr(ag + i_bh * T * K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
214
+ p_w = tl.make_block_ptr(w + i_bh * T * K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
215
+ else:
216
+ p_ag = tl.make_block_ptr(ag + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
217
+ p_w = tl.make_block_ptr(w + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
218
+ b_ag = tl.load(p_ag, boundary_check=(0, 1))
219
+ b_w = tl.dot(b_Aab_inv, b_ag) # both bf16 or fp16
220
+ tl.store(p_w, b_w.to(p_w.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
221
+
222
+ for i_v in range(tl.cdiv(V, BV)):
223
+ if HEAD_FIRST:
224
+ p_v = tl.make_block_ptr(v + i_bh * T * V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
225
+ p_u = tl.make_block_ptr(u + i_bh * T * V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
226
+ else:
227
+ p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
228
+ p_u = tl.make_block_ptr(u + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
229
+ b_v = tl.load(p_v, boundary_check=(0, 1))
230
+ b_u = tl.dot(b_Aak, b_v) # both bf16 or fp16
231
+ tl.store(p_u, b_u.to(p_u.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
232
+
233
+
234
+ def fwd_prepare_wy_repr(
235
+ ag: torch.Tensor,
236
+ v: torch.Tensor,
237
+ A_ak: torch.Tensor,
238
+ A_ab: torch.Tensor,
239
+ offsets: Optional[torch.LongTensor],
240
+ indices: Optional[torch.LongTensor],
241
+ head_first: bool = True,
242
+ chunk_size: int = 64
243
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
244
+ if head_first:
245
+ B, H, T, K = ag.shape
246
+ else:
247
+ B, T, H, K = ag.shape
248
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
249
+
250
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
251
+ BC = min(BT, 32)
252
+ fwd_fn = fwd_prepare_wy_repr_kernel_chunk64 if BT == 64 else fwd_prepare_wy_repr_kernel_chunk32
253
+ A_ab_inv = torch.empty_like(A_ab)
254
+ fwd_fn[(NT, B * H)](
255
+ A_ab=A_ab,
256
+ A_ab_inv=A_ab_inv,
257
+ offsets=offsets,
258
+ indices=indices,
259
+ T=T,
260
+ H=H,
261
+ BT=BT,
262
+ BC=BC,
263
+ HEAD_FIRST=head_first
264
+ )
265
+ w, u = fwd_wu(
266
+ ag=ag,
267
+ v=v,
268
+ A_ak=A_ak,
269
+ A_ab_inv=A_ab_inv,
270
+ offsets=offsets,
271
+ indices=indices,
272
+ head_first=head_first,
273
+ chunk_size=BT
274
+ )
275
+ return w, u, A_ab_inv
276
+
277
+
278
+ def fwd_wu(
279
+ ag: torch.Tensor,
280
+ v: torch.Tensor,
281
+ A_ak: torch.Tensor,
282
+ A_ab_inv: torch.Tensor,
283
+ offsets: Optional[torch.LongTensor],
284
+ indices: Optional[torch.LongTensor],
285
+ head_first: bool,
286
+ chunk_size: int
287
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
288
+ if head_first:
289
+ B, H, T, K, V = *ag.shape, v.shape[-1]
290
+ else:
291
+ B, T, H, K, V = *ag.shape, v.shape[-1]
292
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
293
+
294
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
295
+ BK = min(triton.next_power_of_2(K), 64)
296
+ BV = min(triton.next_power_of_2(V), 64)
297
+
298
+ u = torch.empty_like(v)
299
+ w = torch.empty_like(ag)
300
+ fwd_wu_kernel[(NT, B*H)](
301
+ ag=ag,
302
+ v=v,
303
+ A_ak=A_ak,
304
+ A_ab_inv=A_ab_inv,
305
+ w=w,
306
+ u=u,
307
+ offsets=offsets,
308
+ indices=indices,
309
+ T=T,
310
+ H=H,
311
+ K=K,
312
+ V=V,
313
+ BT=BT,
314
+ BK=BK,
315
+ BV=BV,
316
+ HEAD_FIRST=head_first
317
+ )
318
+ return w, u
fla/ops/generalized_delta_rule/iplr/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from .chunk import chunk_iplr_delta_rule
2
+ from .fused_recurrent import fused_recurrent_iplr_delta_rule
3
+
4
+ __all__ = [
5
+ 'chunk_iplr_delta_rule',
6
+ 'fused_recurrent_iplr_delta_rule'
7
+ ]
fla/ops/generalized_delta_rule/iplr/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (360 Bytes). View file
 
fla/ops/generalized_delta_rule/iplr/__pycache__/fused_recurrent.cpython-312.pyc ADDED
Binary file (27.4 kB). View file
 
fla/ops/generalized_delta_rule/iplr/chunk.py ADDED
@@ -0,0 +1,528 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.common.chunk_delta_h import prepare_chunk_offsets
11
+ from fla.ops.generalized_delta_rule.iplr.wy_fast import fwd_prepare_wy_repr
12
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, check_shared_mem, input_guard, use_cuda_graph
13
+
14
+ BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
15
+
16
+
17
+ @triton.heuristics({
18
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
19
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
20
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
21
+ })
22
+ @triton.autotune(
23
+ configs=[
24
+ triton.Config({}, num_warps=num_warps)
25
+ for num_warps in [2, 4, 8, 16]
26
+ ],
27
+ key=['BT', 'BK', 'BV'],
28
+ use_cuda_graph=use_cuda_graph,
29
+ )
30
+ @triton.jit(do_not_specialize=['T'])
31
+ def chunk_generalized_iplr_delta_rule_fwd_kernel_h(
32
+ k,
33
+ v,
34
+ d,
35
+ b,
36
+ u,
37
+ v_new,
38
+ h,
39
+ h0,
40
+ ht,
41
+ offsets,
42
+ chunk_offsets,
43
+ T,
44
+ H: tl.constexpr,
45
+ K: tl.constexpr,
46
+ V: tl.constexpr,
47
+ BT: tl.constexpr,
48
+ BC: tl.constexpr,
49
+ BK: tl.constexpr,
50
+ BV: tl.constexpr,
51
+ NT: tl.constexpr,
52
+ USE_INITIAL_STATE: tl.constexpr,
53
+ STORE_FINAL_STATE: tl.constexpr,
54
+ USE_OFFSETS: tl.constexpr,
55
+ HEAD_FIRST: tl.constexpr,
56
+ ):
57
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
58
+ i_n, i_h = i_nh // H, i_nh % H
59
+ if USE_OFFSETS:
60
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
61
+ T = eos - bos
62
+ NT = tl.cdiv(T, BT)
63
+ boh = tl.load(chunk_offsets + i_n).to(tl.int32)
64
+ else:
65
+ bos, eos = i_n * T, i_n * T + T
66
+ NT = tl.cdiv(T, BT)
67
+ boh = i_n * NT
68
+
69
+ # [BK, BV]
70
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
71
+ if USE_INITIAL_STATE:
72
+ p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
73
+ b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
74
+
75
+ for i_t in range(NT):
76
+ if HEAD_FIRST:
77
+ p_h = tl.make_block_ptr(h + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
78
+ else:
79
+ p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
80
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
81
+ b_hc = tl.zeros([BK, BV], dtype=tl.float32)
82
+ # since we need to make all DK in the SRAM. we face serve SRAM memory burden. By subchunking we allievate such burden
83
+ for i_c in range(tl.cdiv(min(BT, T - i_t * BT), BC)):
84
+ if HEAD_FIRST:
85
+ p_k = tl.make_block_ptr(k + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
86
+ p_b = tl.make_block_ptr(b + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
87
+ p_d = tl.make_block_ptr(d + i_nh * T*K, (T, K), (K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
88
+ p_v = tl.make_block_ptr(v + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
89
+ p_u = tl.make_block_ptr(u + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
90
+ p_v_new = tl.make_block_ptr(v_new+i_nh*T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
91
+ else:
92
+ p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
93
+ p_b = tl.make_block_ptr(b+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
94
+ p_d = tl.make_block_ptr(d+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
95
+ p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
96
+ p_u = tl.make_block_ptr(u+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
97
+ p_v_new = tl.make_block_ptr(v_new+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT+i_c*BC, i_v * BV), (BC, BV), (1, 0))
98
+ # [BK, BC]
99
+ b_k = tl.load(p_k, boundary_check=(0, 1))
100
+ b_v = tl.load(p_v, boundary_check=(0, 1))
101
+ b_d = tl.load(p_d, boundary_check=(0, 1))
102
+ b_b = tl.load(p_b, boundary_check=(0, 1))
103
+ b_v2 = tl.dot(b_d, b_h.to(b_d.dtype)) + tl.load(p_u, boundary_check=(0, 1))
104
+ b_hc += tl.dot(b_k, b_v)
105
+ b_hc += tl.dot(b_b, b_v2.to(b_k.dtype))
106
+ tl.store(p_v_new, b_v2.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
107
+ b_h += b_hc
108
+
109
+ if STORE_FINAL_STATE:
110
+ p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
111
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
112
+
113
+
114
+ @triton.heuristics({
115
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
116
+ })
117
+ @triton.autotune(
118
+ configs=[
119
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
120
+ for BK in BKV_LIST
121
+ for BV in BKV_LIST
122
+ for num_warps in [2, 4, 8]
123
+ for num_stages in [2, 3]
124
+ ],
125
+ key=['BT'],
126
+ use_cuda_graph=use_cuda_graph,
127
+ )
128
+ @triton.jit(do_not_specialize=['T'])
129
+ def chunk_generalized_iplr_delta_rule_fwd_kernel_o(
130
+ q,
131
+ k,
132
+ v,
133
+ u,
134
+ b,
135
+ h,
136
+ o,
137
+ offsets,
138
+ indices,
139
+ scale,
140
+ T,
141
+ H: tl.constexpr,
142
+ K: tl.constexpr,
143
+ V: tl.constexpr,
144
+ BT: tl.constexpr,
145
+ BK: tl.constexpr,
146
+ BV: tl.constexpr,
147
+ USE_OFFSETS: tl.constexpr,
148
+ HEAD_FIRST: tl.constexpr,
149
+ ):
150
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
151
+ i_b, i_h = i_bh // H, i_bh % H
152
+
153
+ if USE_OFFSETS:
154
+ i_tg = i_t
155
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
156
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
157
+ T = eos - bos
158
+ NT = tl.cdiv(T, BT)
159
+ else:
160
+ NT = tl.cdiv(T, BT)
161
+ i_tg = i_b * NT + i_t
162
+ bos, eos = i_b * T, i_b * T + T
163
+
164
+ # offset calculation
165
+ q += (i_bh * T * K) if HEAD_FIRST else ((bos * H + i_h) * K)
166
+ k += (i_bh * T * K) if HEAD_FIRST else ((bos * H + i_h) * K)
167
+ b += (i_bh * T * K) if HEAD_FIRST else ((bos * H + i_h) * K)
168
+ v += (i_bh * T * V) if HEAD_FIRST else ((bos * H + i_h) * V)
169
+ u += (i_bh * T * V) if HEAD_FIRST else ((bos * H + i_h) * V)
170
+ o += (i_bh * T * V) if HEAD_FIRST else ((bos * H + i_h) * V)
171
+ h += ((i_bh * NT + i_t) * K * V) if HEAD_FIRST else ((i_tg * H + i_h) * K * V)
172
+ stride_qk = K if HEAD_FIRST else H*K
173
+ stride_vo = V if HEAD_FIRST else H*V
174
+
175
+ b_o = tl.zeros([BT, BV], dtype=tl.float32)
176
+ b_Aqk = tl.zeros([BT, BT], dtype=tl.float32)
177
+ b_Aqb = tl.zeros([BT, BT], dtype=tl.float32)
178
+
179
+ for i_k in range(tl.cdiv(K, BK)):
180
+ p_q = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
181
+ p_k = tl.make_block_ptr(k, (K, T), (1, stride_qk), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
182
+ p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
183
+ p_b = tl.make_block_ptr(b, (K, T), (1, stride_qk), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
184
+ # [BT, BK]
185
+ b_q = tl.load(p_q, boundary_check=(0, 1))
186
+ # [BK, BT]
187
+ b_k = tl.load(p_k, boundary_check=(0, 1))
188
+ b_b = tl.load(p_b, boundary_check=(0, 1))
189
+ # [BK, BV]
190
+ b_h = tl.load(p_h, boundary_check=(0, 1))
191
+ # [BT, BK] @ [BK, BV] -> [BT, BV]
192
+ b_o += tl.dot(b_q, b_h)
193
+ # [BT, BK] @ [BK, BT] -> [BT, BT]
194
+ b_Aqk += tl.dot(b_q, b_k)
195
+ # [BT, BK] @ [BK, BT] -> [BT, BT]
196
+ b_Aqb += tl.dot(b_q, b_b)
197
+
198
+ o_i = tl.arange(0, BT)
199
+ m_A = o_i[:, None] >= o_i[None, :]
200
+ b_Aqk = tl.where(m_A, b_Aqk, 0)
201
+ b_Aqb = tl.where(m_A, b_Aqb, 0)
202
+
203
+ p_v = tl.make_block_ptr(v, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
204
+ p_u = tl.make_block_ptr(u, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
205
+ p_o = tl.make_block_ptr(o, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
206
+ b_v = tl.load(p_v, boundary_check=(0, 1))
207
+ b_u = tl.load(p_u, boundary_check=(0, 1))
208
+ b_o = (b_o + tl.dot(b_Aqk.to(b_v.dtype), b_v) + tl.dot(b_Aqb.to(b_u.dtype), b_u)) * scale
209
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
210
+
211
+
212
+ def chunk_generalized_iplr_delta_rule_fwd_o(
213
+ q: torch.Tensor,
214
+ k: torch.Tensor,
215
+ v: torch.Tensor,
216
+ v_new: torch.Tensor,
217
+ b: torch.Tensor,
218
+ h: torch.Tensor,
219
+ scale: Optional[float] = None,
220
+ offsets: Optional[torch.LongTensor] = None,
221
+ indices: Optional[torch.LongTensor] = None,
222
+ head_first: bool = True,
223
+ chunk_size: int = 64
224
+ ) -> torch.Tensor:
225
+ if head_first:
226
+ B, H, T, K, V = *q.shape, v.shape[-1]
227
+ else:
228
+ B, T, H, K, V = *q.shape, v.shape[-1]
229
+ if scale is None:
230
+ scale = k.shape[-1] ** -0.5
231
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
232
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
233
+
234
+ o = torch.empty_like(v)
235
+
236
+ def grid(meta): return (
237
+ triton.cdiv(V, meta['BV']),
238
+ NT,
239
+ B * H
240
+ )
241
+ chunk_generalized_iplr_delta_rule_fwd_kernel_o[grid](
242
+ q=q,
243
+ k=k,
244
+ v=v,
245
+ u=v_new,
246
+ b=b,
247
+ h=h,
248
+ o=o,
249
+ offsets=offsets,
250
+ indices=indices,
251
+ scale=scale,
252
+ T=T,
253
+ H=H,
254
+ K=K,
255
+ V=V,
256
+ BT=BT,
257
+ HEAD_FIRST=head_first
258
+ )
259
+ return o
260
+
261
+
262
+ def chunk_generalized_iplr_delta_rule_fwd_h(
263
+ k: torch.Tensor,
264
+ v: torch.Tensor,
265
+ w: torch.Tensor,
266
+ u: torch.Tensor,
267
+ b: torch.Tensor,
268
+ initial_state: Optional[torch.Tensor] = None,
269
+ output_final_state: bool = False,
270
+ offsets: Optional[torch.LongTensor] = None,
271
+ indices: Optional[torch.LongTensor] = None,
272
+ head_first: bool = True,
273
+ chunk_size: int = 64
274
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
275
+ if head_first:
276
+ B, H, T, K, V = *k.shape, u.shape[-1]
277
+ else:
278
+ B, T, H, K, V = *k.shape, u.shape[-1]
279
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
280
+ # N: the actual number of sequences in the batch with either equal or variable lengths
281
+ if offsets is None:
282
+ N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
283
+ else:
284
+ N, NT, chunk_offsets = len(offsets) - 1, len(indices), prepare_chunk_offsets(offsets, BT)
285
+
286
+ BK = triton.next_power_of_2(K)
287
+ assert BK <= 256, "current kernel does not support head dimension larger than 256."
288
+ # H100 can have larger block size
289
+
290
+ if check_shared_mem('hopper', k.device.index):
291
+ BV = 64
292
+ BC = 64 if K <= 128 else 32
293
+ elif check_shared_mem('ampere', k.device.index): # A100
294
+ BV = 32
295
+ BC = 32
296
+ else:
297
+ BV = 16
298
+ BC = 16
299
+
300
+ BC = min(BT, BC)
301
+ NK = triton.cdiv(K, BK)
302
+ NV = triton.cdiv(V, BV)
303
+
304
+ assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
305
+
306
+ if head_first:
307
+ h = k.new_empty(B, H, NT, K, V)
308
+ else:
309
+ h = k.new_empty(B, NT, H, K, V)
310
+ final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
311
+
312
+ v_new = torch.empty_like(u)
313
+ grid = (NK, NV, N * H)
314
+
315
+ chunk_generalized_iplr_delta_rule_fwd_kernel_h[grid](
316
+ k=k,
317
+ v=v,
318
+ d=w,
319
+ b=b,
320
+ u=u,
321
+ v_new=v_new,
322
+ h=h,
323
+ h0=initial_state,
324
+ ht=final_state,
325
+ offsets=offsets,
326
+ chunk_offsets=chunk_offsets,
327
+ T=T,
328
+ H=H,
329
+ K=K,
330
+ V=V,
331
+ BT=BT,
332
+ BC=BC,
333
+ BK=BK,
334
+ BV=BV,
335
+ NT=NT,
336
+ HEAD_FIRST=head_first
337
+ )
338
+ return h, v_new, final_state
339
+
340
+
341
+ def chunk_generalized_iplr_delta_rule_fwd(
342
+ q: torch.Tensor,
343
+ k: torch.Tensor,
344
+ v: torch.Tensor,
345
+ a: torch.Tensor,
346
+ b: torch.Tensor,
347
+ scale: float,
348
+ initial_state: torch.Tensor,
349
+ output_final_state: bool,
350
+ offsets: Optional[torch.LongTensor] = None,
351
+ indices: Optional[torch.LongTensor] = None,
352
+ head_first: bool = True,
353
+ chunk_size: int = 64
354
+ ):
355
+ T = q.shape[2] if head_first else q.shape[1]
356
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
357
+ w, u, _ = fwd_prepare_wy_repr(
358
+ a=a,
359
+ b=b,
360
+ k=k,
361
+ v=v,
362
+ offsets=offsets,
363
+ indices=indices,
364
+ head_first=head_first,
365
+ chunk_size=BT
366
+ )
367
+
368
+ h, v_new, final_state = chunk_generalized_iplr_delta_rule_fwd_h(
369
+ k=k,
370
+ v=v,
371
+ b=b,
372
+ w=w,
373
+ u=u,
374
+ initial_state=initial_state,
375
+ output_final_state=output_final_state,
376
+ offsets=offsets,
377
+ indices=indices,
378
+ head_first=head_first,
379
+ chunk_size=BT
380
+ )
381
+ o = chunk_generalized_iplr_delta_rule_fwd_o(
382
+ q=q,
383
+ k=k,
384
+ v=v,
385
+ v_new=v_new,
386
+ b=b,
387
+ h=h,
388
+ scale=scale,
389
+ offsets=offsets,
390
+ indices=indices,
391
+ head_first=head_first,
392
+ chunk_size=BT
393
+ )
394
+ return o, final_state
395
+
396
+
397
+ class ChunkGeneralizedIPLRDeltaRuleFunction(torch.autograd.Function):
398
+
399
+ @staticmethod
400
+ @input_guard
401
+ @autocast_custom_fwd
402
+ def forward(
403
+ ctx,
404
+ q: torch.Tensor,
405
+ k: torch.Tensor,
406
+ v: torch.Tensor,
407
+ a: torch.Tensor,
408
+ b: torch.Tensor,
409
+ scale: float,
410
+ initial_state: torch.Tensor,
411
+ output_final_state: bool,
412
+ offsets: Optional[torch.LongTensor] = None,
413
+ head_first: bool = True
414
+ ):
415
+ chunk_size = 64
416
+
417
+ # 2-d indices denoting the offsets of chunks in each sequence
418
+ # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64,
419
+ # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be
420
+ # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
421
+ indices = None
422
+ if offsets is not None:
423
+ indices = torch.cat([torch.arange(n) for n in triton.cdiv(offsets[1:] - offsets[:-1], chunk_size).tolist()])
424
+ indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets)
425
+
426
+ o, final_state = chunk_generalized_iplr_delta_rule_fwd(
427
+ q=q,
428
+ k=k,
429
+ v=v,
430
+ a=a,
431
+ b=b,
432
+ scale=scale,
433
+ initial_state=initial_state,
434
+ output_final_state=output_final_state,
435
+ offsets=offsets,
436
+ indices=indices,
437
+ head_first=head_first,
438
+ chunk_size=chunk_size
439
+ )
440
+ return o.to(q.dtype), final_state
441
+
442
+ @staticmethod
443
+ @input_guard
444
+ @autocast_custom_bwd
445
+ def backward(
446
+ ctx,
447
+ do: torch.Tensor,
448
+ dht: torch.Tensor
449
+ ):
450
+ raise NotImplementedError(
451
+ "Backward pass for ChunkGeneralizedIPLRDeltaRuleFunction is not implemented yet. "
452
+ "Stay tuned!"
453
+ )
454
+
455
+
456
+ @torch.compiler.disable
457
+ def chunk_iplr_delta_rule(
458
+ q: torch.Tensor,
459
+ k: torch.Tensor,
460
+ v: torch.Tensor,
461
+ a: torch.Tensor,
462
+ b: torch.Tensor,
463
+ scale: float = None,
464
+ initial_state: torch.Tensor = None,
465
+ output_final_state: bool = False,
466
+ cu_seqlens: Optional[torch.LongTensor] = None,
467
+ head_first: bool = True
468
+ ):
469
+ r"""
470
+ Args:
471
+ q (torch.Tensor):
472
+ queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
473
+ k (torch.Tensor):
474
+ keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
475
+ v (torch.Tensor):
476
+ values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
477
+ a (torch.Tensor):
478
+ activations of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
479
+ b (torch.Tensor):
480
+ betas of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
481
+ scale (Optional[int]):
482
+ Scale factor for the RetNet attention scores.
483
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
484
+ initial_state (Optional[torch.Tensor]):
485
+ Initial state of shape `[N, H, K, V]` for `N` input sequences.
486
+ For equal-length input sequences, `N` equals the batch size `B`.
487
+ Default: `None`.
488
+ output_final_state (Optional[bool]):
489
+ Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
490
+ cu_seqlens (torch.LongTensor):
491
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
492
+ consistent with the FlashAttention API.
493
+ head_first (Optional[bool]):
494
+ Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
495
+ Default: `True`.
496
+
497
+ Returns:
498
+ o (torch.Tensor):
499
+ Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
500
+ final_state (torch.Tensor):
501
+ Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
502
+ """
503
+ assert q.dtype == k.dtype == v.dtype
504
+ assert q.dtype != torch.float32, "ChunkDeltaRuleFunction does not support float32. Please use bfloat16."
505
+
506
+ if cu_seqlens is not None:
507
+ if q.shape[0] != 1:
508
+ raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
509
+ f"Please flatten variable-length inputs before processing.")
510
+ if head_first:
511
+ raise RuntimeError("Sequences with variable lengths are not supported for head-first mode")
512
+ if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
513
+ raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, "
514
+ f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.")
515
+ scale = k.shape[-1] ** -0.5 if scale is None else scale
516
+ o, final_state = ChunkGeneralizedIPLRDeltaRuleFunction.apply(
517
+ q,
518
+ k,
519
+ v,
520
+ a,
521
+ b,
522
+ scale,
523
+ initial_state,
524
+ output_final_state,
525
+ cu_seqlens,
526
+ head_first
527
+ )
528
+ return o, final_state
fla/ops/generalized_delta_rule/iplr/wy_fast.py ADDED
@@ -0,0 +1,338 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
4
+
5
+ from typing import Optional, Tuple
6
+
7
+ import torch
8
+ import triton
9
+ import triton.language as tl
10
+
11
+ from fla.utils import check_shared_mem, is_nvidia_hopper
12
+
13
+ NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]
14
+
15
+
16
+ @triton.heuristics({
17
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
18
+ })
19
+ @triton.autotune(
20
+ configs=[
21
+ triton.Config({}, num_warps=num_warps)
22
+ for num_warps in [1, 2, 4, 8, 16]
23
+ ],
24
+ key=['BK']
25
+ )
26
+ @triton.jit(do_not_specialize=['T'])
27
+ def fwd_prepare_wy_repr_kernel_chunk32(
28
+ a,
29
+ b,
30
+ A,
31
+ offsets,
32
+ indices,
33
+ T,
34
+ H: tl.constexpr,
35
+ K: tl.constexpr,
36
+ BT: tl.constexpr,
37
+ BK: tl.constexpr,
38
+ BC: tl.constexpr, # dummy placeholder
39
+ USE_OFFSETS: tl.constexpr,
40
+ HEAD_FIRST: tl.constexpr,
41
+ ):
42
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
43
+ i_b, i_h = i_bh // H, i_bh % H
44
+ if USE_OFFSETS:
45
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
46
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
47
+ T = eos - bos
48
+ else:
49
+ bos, eos = i_b * T, i_b * T + T
50
+
51
+ b_A = tl.zeros([BT, BT], dtype=tl.float32)
52
+ for i_k in range(tl.cdiv(K, BK)):
53
+ if HEAD_FIRST:
54
+ p_a = tl.make_block_ptr(a + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
55
+ p_b = tl.make_block_ptr(b + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
56
+ else:
57
+ p_a = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
58
+ p_b = tl.make_block_ptr(b + (bos * H + i_h) * K, (K, T), (1, K*H), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
59
+ b_a = tl.load(p_a, boundary_check=(0, 1))
60
+ b_b = tl.load(p_b, boundary_check=(0, 1))
61
+ b_A += tl.dot(b_a, b_b)
62
+
63
+ b_A = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A, 0)
64
+ for i in range(1, BT):
65
+ mask = tl.arange(0, BT) == i
66
+ b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
67
+ b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BT) < i)
68
+ b_A = tl.where(mask[:, None], b_a, b_A)
69
+ b_A += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :]
70
+
71
+ if HEAD_FIRST:
72
+ p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
73
+ else:
74
+ p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
75
+ tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
76
+
77
+
78
+ @triton.heuristics({
79
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
80
+ })
81
+ @triton.autotune(
82
+ configs=[
83
+ triton.Config({}, num_warps=num_warps)
84
+ for num_warps in [1, 2, 4, 8, 16]
85
+ ],
86
+ key=['BK']
87
+ )
88
+ @triton.jit(do_not_specialize=['T'])
89
+ def fwd_prepare_wy_repr_kernel_chunk64(
90
+ a,
91
+ b,
92
+ A,
93
+ offsets,
94
+ indices,
95
+ T,
96
+ H: tl.constexpr,
97
+ K: tl.constexpr,
98
+ BT: tl.constexpr,
99
+ BK: tl.constexpr,
100
+ BC: tl.constexpr,
101
+ USE_OFFSETS: tl.constexpr,
102
+ HEAD_FIRST: tl.constexpr
103
+ ):
104
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
105
+ i_b, i_h = i_bh // H, i_bh % H
106
+ if USE_OFFSETS:
107
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
108
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
109
+ T = eos - bos
110
+ else:
111
+ bos, eos = i_b * T, i_b * T + T
112
+
113
+ b_A = tl.zeros([BC, BC], dtype=tl.float32)
114
+ b_A2 = tl.zeros([BC, BC], dtype=tl.float32)
115
+ b_A3 = tl.zeros([BC, BC], dtype=tl.float32)
116
+
117
+ for i_k in range(tl.cdiv(K, BK)):
118
+ if HEAD_FIRST:
119
+ p_a1 = tl.make_block_ptr(a + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
120
+ p_a2 = tl.make_block_ptr(a + i_bh * T*K, (T, K), (K, 1), (i_t * BT + BC, i_k * BK), (BC, BK), (1, 0))
121
+ p_b1 = tl.make_block_ptr(b + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BC), (0, 1))
122
+ p_b2 = tl.make_block_ptr(b + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + BC), (BK, BC), (0, 1))
123
+ else:
124
+ p_a1 = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
125
+ p_a2 = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + BC, i_k * BK), (BC, BK), (1, 0))
126
+ p_b1 = tl.make_block_ptr(b + (bos * H + i_h) * K, (K, T), (1, K*H), (i_k * BK, i_t * BT), (BK, BC), (0, 1))
127
+ p_b2 = tl.make_block_ptr(b + (bos * H + i_h) * K, (K, T), (1, K*H), (i_k * BK, i_t * BT + BC), (BK, BC), (0, 1))
128
+ b_a1 = tl.load(p_a1, boundary_check=(0, 1))
129
+ b_a2 = tl.load(p_a2, boundary_check=(0, 1))
130
+ b_b1 = tl.load(p_b1, boundary_check=(0, 1))
131
+ b_b2 = tl.load(p_b2, boundary_check=(0, 1))
132
+ b_A += tl.dot(b_a1, b_b1, allow_tf32=False)
133
+ b_A2 += tl.dot(b_a2, b_b2, allow_tf32=False)
134
+ b_A3 += tl.dot(b_a2, b_b1, allow_tf32=False)
135
+
136
+ b_A = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A, 0)
137
+ b_A2 = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A2, 0)
138
+
139
+ for i in range(1, BC):
140
+ mask = tl.arange(0, BC) == i
141
+ b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
142
+ b_a2 = tl.sum(tl.where(mask[:, None], b_A2, 0), 0)
143
+ b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BC) < i)
144
+ b_a2 = b_a2 + tl.sum(b_a2[:, None] * b_A2, 0) * (tl.arange(0, BC) < i)
145
+ b_A = tl.where(mask[:, None], b_a, b_A)
146
+ b_A2 = tl.where(mask[:, None], b_a2, b_A2)
147
+
148
+ # blockwise computation of lower triangular matrix's inverse
149
+ # i.e., [A11, 0; A21, A22]^-1 = [A11^-1, 0; -A22^-1 A21 A11^-1, A22^-1]
150
+ b_A += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
151
+ b_A2 += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
152
+ b_A3 = tl.dot(tl.dot(b_A2, b_A3, allow_tf32=False), b_A, allow_tf32=False)
153
+
154
+ if HEAD_FIRST:
155
+ p_A1 = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
156
+ p_A2 = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
157
+ p_A3 = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
158
+ p_A4 = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))
159
+ else:
160
+ p_A1 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
161
+ p_A2 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
162
+ p_A3 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
163
+ p_A4 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))
164
+ tl.store(p_A1, b_A.to(p_A1.dtype.element_ty), boundary_check=(0, 1))
165
+ tl.store(p_A2, b_A2.to(p_A2.dtype.element_ty), boundary_check=(0, 1))
166
+ tl.store(p_A3, b_A3.to(p_A3.dtype.element_ty), boundary_check=(0, 1))
167
+ # causal mask
168
+ tl.store(p_A4, tl.zeros([BC, BC], dtype=tl.float32).to(p_A4.dtype.element_ty), boundary_check=(0, 1))
169
+
170
+
171
+ @triton.heuristics({
172
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
173
+ })
174
+ @triton.autotune(
175
+ configs=[
176
+ triton.Config({}, num_warps=num_warps)
177
+ for num_warps in NUM_WARPS
178
+ ],
179
+ key=['BT', 'BK', 'BV']
180
+ )
181
+ @triton.jit(do_not_specialize=['T'])
182
+ def fwd_wu_kernel(
183
+ w,
184
+ u,
185
+ a,
186
+ k,
187
+ v,
188
+ A,
189
+ offsets,
190
+ indices,
191
+ T,
192
+ H: tl.constexpr,
193
+ K: tl.constexpr,
194
+ V: tl.constexpr,
195
+ BT: tl.constexpr,
196
+ BK: tl.constexpr,
197
+ BV: tl.constexpr,
198
+ USE_OFFSETS: tl.constexpr,
199
+ HEAD_FIRST: tl.constexpr
200
+ ):
201
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
202
+ i_b, i_h = i_bh // H, i_bh % H
203
+ if USE_OFFSETS:
204
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
205
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
206
+ T = eos - bos
207
+ else:
208
+ bos, eos = i_b * T, i_b * T + T
209
+
210
+ if HEAD_FIRST:
211
+ p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
212
+ else:
213
+ p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
214
+
215
+ b_A = tl.load(p_A, boundary_check=(0, 1))
216
+ b_Aak = tl.zeros([BT, BT], dtype=tl.float32)
217
+
218
+ for i_k in range(tl.cdiv(K, BK)):
219
+ if HEAD_FIRST:
220
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
221
+ p_a = tl.make_block_ptr(a + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
222
+ p_w = tl.make_block_ptr(w + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
223
+ else:
224
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
225
+ p_a = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
226
+ p_w = tl.make_block_ptr(w + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
227
+ b_k = tl.load(p_k, boundary_check=(0, 1))
228
+ b_a = tl.load(p_a, boundary_check=(0, 1))
229
+ b_w = tl.dot(b_A, b_a)
230
+ b_Aak += tl.dot(b_a, tl.trans(b_k))
231
+ tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
232
+
233
+ b_Aak = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_Aak, 0)
234
+ b_Aak = b_Aak.to(k.dtype.element_ty)
235
+
236
+ for i_v in range(tl.cdiv(V, BV)):
237
+ if HEAD_FIRST:
238
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
239
+ p_u = tl.make_block_ptr(u + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
240
+ else:
241
+ p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
242
+ p_u = tl.make_block_ptr(u + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
243
+ b_v = tl.load(p_v, boundary_check=(0, 1))
244
+ b_v = tl.dot(b_Aak, b_v).to(v.dtype.element_ty)
245
+ b_u = tl.dot(b_A, b_v)
246
+ tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
247
+
248
+
249
+ def fwd_prepare_wy_repr(
250
+ a: torch.Tensor,
251
+ b: torch.Tensor,
252
+ v: torch.Tensor,
253
+ k: torch.Tensor,
254
+ offsets: Optional[torch.LongTensor],
255
+ indices: Optional[torch.LongTensor],
256
+ head_first: bool = True,
257
+ chunk_size: int = 64
258
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
259
+ if head_first:
260
+ B, H, T, K = a.shape
261
+ else:
262
+ B, T, H, K = a.shape
263
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
264
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
265
+ BC = min(BT, 32)
266
+ BK = min(triton.next_power_of_2(K), 64)
267
+
268
+ A = torch.empty(B, *((H, T) if head_first else (T, H)), BT, device=a.device, dtype=a.dtype)
269
+ fwd_fn = fwd_prepare_wy_repr_kernel_chunk64 if BT == 64 else fwd_prepare_wy_repr_kernel_chunk32
270
+
271
+ fwd_fn[(NT, B * H)](
272
+ a=a,
273
+ b=b,
274
+ A=A,
275
+ offsets=offsets,
276
+ indices=indices,
277
+ T=T,
278
+ H=H,
279
+ K=K,
280
+ BT=BT,
281
+ BK=BK,
282
+ BC=BC,
283
+ HEAD_FIRST=head_first
284
+ )
285
+ w, u = fwd_wu(
286
+ a=a,
287
+ v=v,
288
+ k=k,
289
+ A=A,
290
+ offsets=offsets,
291
+ indices=indices,
292
+ head_first=head_first,
293
+ chunk_size=chunk_size
294
+ )
295
+ return w, u, A
296
+
297
+
298
+ def fwd_wu(
299
+ a: torch.Tensor,
300
+ v: torch.Tensor,
301
+ k: torch.Tensor,
302
+ A: torch.Tensor,
303
+ offsets: Optional[torch.LongTensor],
304
+ indices: Optional[torch.LongTensor],
305
+ head_first: bool,
306
+ chunk_size: int
307
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
308
+ if head_first:
309
+ B, H, T, K, V = *a.shape, v.shape[-1]
310
+ else:
311
+ B, T, H, K, V = *a.shape, v.shape[-1]
312
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
313
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
314
+ CONST_TILING = 64 if check_shared_mem() else 32
315
+ BK = min(triton.next_power_of_2(K), CONST_TILING)
316
+ BV = min(triton.next_power_of_2(V), CONST_TILING)
317
+
318
+ u = torch.empty_like(v)
319
+ w = torch.empty_like(a)
320
+ fwd_wu_kernel[(NT, B*H)](
321
+ a=a,
322
+ v=v,
323
+ w=w,
324
+ u=u,
325
+ A=A,
326
+ k=k,
327
+ offsets=offsets,
328
+ indices=indices,
329
+ T=T,
330
+ H=H,
331
+ K=K,
332
+ V=V,
333
+ BT=BT,
334
+ BK=BK,
335
+ BV=BV,
336
+ HEAD_FIRST=head_first
337
+ )
338
+ return w, u
fla/ops/gla/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (365 Bytes). View file
 
fla/ops/gla/__pycache__/chunk.cpython-312.pyc ADDED
Binary file (81.8 kB). View file
 
fla/ops/gla/__pycache__/fused_chunk.cpython-312.pyc ADDED
Binary file (35.3 kB). View file
 
fla/ops/gsa/__pycache__/fused_recurrent.cpython-312.pyc ADDED
Binary file (25.8 kB). View file
 
fla/ops/hgrn/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (317 Bytes). View file
 
fla/ops/lightning_attn/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (347 Bytes). View file
 
fla/ops/linear_attn/__pycache__/fused_recurrent.cpython-312.pyc ADDED
Binary file (13.9 kB). View file
 
fla/ops/linear_attn/__pycache__/utils.cpython-312.pyc ADDED
Binary file (586 Bytes). View file
 
fla/ops/nsa/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (300 Bytes). View file