zaydzuhri committed · Commit a8e4d4c · verified · 1 Parent(s): d9de648

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. LICENSE +21 -0
  2. README.md +523 -0
  3. config.json +35 -0
  4. configs/delta_net_1B.json +29 -0
  5. configs/delta_net_340M.json +27 -0
  6. configs/dsmtp_transformer_120M.json +19 -0
  7. configs/dsmtp_transformer_1B.json +23 -0
  8. configs/dsmtp_transformer_340M.json +19 -0
  9. configs/dsmtp_transformer_7B.json +22 -0
  10. configs/gla_340M.json +24 -0
  11. configs/gla_7B.json +25 -0
  12. configs/gsa_340M.json +29 -0
  13. configs/hgrn2_340M.json +20 -0
  14. configs/mtp_transformer_120M.json +19 -0
  15. configs/mtp_transformer_1B.json +23 -0
  16. configs/mtp_transformer_340M.json +19 -0
  17. configs/mtp_transformer_7B.json +22 -0
  18. configs/top_transformer_120M.json +20 -0
  19. configs/top_transformer_1B.json +24 -0
  20. configs/top_transformer_340M.json +20 -0
  21. configs/top_transformer_7B.json +23 -0
  22. configs/transformer_120M.json +18 -0
  23. configs/transformer_1B.json +22 -0
  24. configs/transformer_340M.json +18 -0
  25. configs/transformer_7B.json +21 -0
  26. fla/__init__.py +110 -0
  27. fla/layers/__pycache__/gla.cpython-312.pyc +0 -0
  28. fla/layers/__pycache__/hgrn.cpython-312.pyc +0 -0
  29. fla/layers/__pycache__/multiscale_retention.cpython-312.pyc +0 -0
  30. fla/layers/rwkv6.py +307 -0
  31. fla/ops/__pycache__/__init__.cpython-312.pyc +0 -0
  32. fla/ops/abc/__init__.py +7 -0
  33. fla/ops/abc/chunk.py +1116 -0
  34. fla/ops/abc/naive.py +96 -0
  35. fla/ops/attn/parallel.py +629 -0
  36. fla/ops/based/__init__.py +9 -0
  37. fla/ops/based/fused_chunk.py +374 -0
  38. fla/ops/based/naive.py +72 -0
  39. fla/ops/based/parallel.py +410 -0
  40. fla/ops/common/__init__.py +1 -0
  41. fla/ops/common/chunk_h.py +422 -0
  42. fla/ops/common/chunk_h_split.py +677 -0
  43. fla/ops/common/chunk_o.py +668 -0
  44. fla/ops/common/chunk_scaled_dot_kkt.py +126 -0
  45. fla/ops/common/fused_recurrent.py +575 -0
  46. fla/ops/common/utils.py +69 -0
  47. fla/ops/delta_rule/__pycache__/__init__.cpython-312.pyc +0 -0
  48. fla/ops/delta_rule/chunk.py +373 -0
  49. fla/ops/delta_rule/naive.py +120 -0
  50. fla/ops/delta_rule/wy_fast.py +340 -0
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2023-2025 Songlin Yang, Yu Zhang
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,523 @@
1
+ <div align="center">
2
+
3
+ # 🔥 Flame: Flash Linear Attention Made Easy
4
+ # This is a fork for the paper:
5
+ # "Predicting the Order of Upcoming Tokens Improves Language Modeling"
6
+
7
+ </div>
8
+
9
+ ## Instructions for Token Order Prediction
10
+
11
+ This fork only works with an older commit of flame and torchtitan (the latter forked at https://github.com/Erland366/torchtitan; see .gitmodules), so the setup looks like this:
12
+
13
+ ```bash
14
+ git clone https://github.com/zaydzuhri/flame.git
15
+ cd flame
16
+ git checkout token-order-prediction
17
+ git submodule update --init --recursive --remote
18
+ pip install .
19
+ pip install wheel
20
+ pip install flash-attn==2.7.3 --no-build-isolation --no-cache-dir
21
+ ```
22
+ The flash-linear-attention submodule has been changed to link to our fork: https://github.com/zaydzuhri/flash-linear-attention/tree/token-order-prediction
23
+ So there is no need to clone it manually.
24
+
25
+ Then prepare the fineweb-edu 100B sample the same way as described in the flame repo guide below, or:
26
+ ```py
27
+ from datasets import load_dataset
28
+ dataset = load_dataset("HuggingFaceFW/fineweb-edu", name="sample-100BT", num_proc=32, cache_dir="~/.cache")
29
+ ```
30
+
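The training commands in this README pass `~/.cache/HuggingFaceFW___fineweb-edu/sample-100BT` to `--training.dataset`, so it may help to check that the download landed where those commands expect. The sketch below is only an illustration: `expected_dataset_dir` is a hypothetical helper, and the `namespace___name/config` folder layout is an assumption inferred from that path, not something this README documents.

```python
from pathlib import Path


def expected_dataset_dir(repo_id: str, config_name: str, cache_dir: str) -> Path:
    """Guess the on-disk folder for a Hub dataset config (hypothetical helper)."""
    # Assumption: "<namespace>/<name>" is stored as "<namespace>___<name>",
    # as suggested by the path passed to --training.dataset.
    folder = repo_id.replace("/", "___")
    return Path(cache_dir).expanduser() / folder / config_name


if __name__ == "__main__":
    path = expected_dataset_dir("HuggingFaceFW/fineweb-edu", "sample-100BT", "~/.cache")
    print(path, "exists" if path.is_dir() else "is missing")
```

If the folder is reported missing, the layout on your `datasets` version may differ from this assumption, so inspect the cache directory before changing the training command.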
31
+ These are the training commands used in the paper:
32
+ ```bash
33
+ NGPU=8 bash train.sh --job.config_file flame/models/fla.toml --job.dump_folder exp/vanilla.340M.batch16.seqlen4096.context4096.warmup1000.update1.steps100000.lr3e-4.cosine --model.config configs/transformer_340M.json --model.tokenizer_path fla-hub/transformer-1.3B-100B --optimizer.name AdamW --optimizer.eps 1e-15 --optimizer.lr 3e-4 --lr_scheduler.warmup_steps 1000 --lr_scheduler.lr_min 0.1 --lr_scheduler.decay_type cosine --training.batch_size 16 --training.seq_len 4096 --training.context_len 4096 --training.gradient_accumulation_steps 1 --training.steps 100000 --training.max_norm 1.0 --training.skip_nan_inf --training.dataset ~/.cache/HuggingFaceFW___fineweb-edu/sample-100BT --training.dataset_split train --training.num_workers 32 --training.prefetch_factor 2 --training.seed 79 --training.compile --checkpoint.interval 10000 --checkpoint.load_step -1 --metrics.log_freq 5 --checkpoint.hf_upload_enabled --checkpoint.hf_repo_base_name zaydzuhri/vanilla-340M-4096-batch16-steps100000 --comm.init_timeout_seconds 600 --comm.train_timeout_seconds 300
34
+
35
+ NGPU=8 bash train.sh --job.config_file flame/models/fla.toml --job.dump_folder exp/mtp.340M.batch16.seqlen4096.context4096.warmup1000.update1.steps100000.lr3e-4.cosine --model.config configs/mtp_transformer_340M.json --model.tokenizer_path fla-hub/transformer-1.3B-100B --optimizer.name AdamW --optimizer.eps 1e-15 --optimizer.lr 3e-4 --lr_scheduler.warmup_steps 1000 --lr_scheduler.lr_min 0.1 --lr_scheduler.decay_type cosine --training.batch_size 16 --training.seq_len 4096 --training.context_len 4096 --training.gradient_accumulation_steps 1 --training.steps 100000 --training.max_norm 1.0 --training.skip_nan_inf --training.dataset ~/.cache/HuggingFaceFW___fineweb-edu/sample-100BT --training.dataset_split train --training.num_workers 32 --training.prefetch_factor 2 --training.seed 79 --training.compile --checkpoint.interval 10000 --checkpoint.load_step -1 --metrics.log_freq 5 --checkpoint.hf_upload_enabled --checkpoint.hf_repo_base_name zaydzuhri/mtp-340M-4096-batch16-steps100000 --comm.init_timeout_seconds 1800 --comm.train_timeout_seconds 1800
36
+
37
+ NGPU=8 bash train.sh --job.config_file flame/models/fla.toml --job.dump_folder exp/top.340M.batch8.seqlen4096.context4096.warmup1000.update2.steps100000.lr3e-4.cosine --model.config configs/top_transformer_340M.json --model.tokenizer_path fla-hub/transformer-1.3B-100B --optimizer.name AdamW --optimizer.eps 1e-15 --optimizer.lr 3e-4 --lr_scheduler.warmup_steps 1000 --lr_scheduler.lr_min 0.1 --lr_scheduler.decay_type cosine --training.batch_size 8 --training.seq_len 4096 --training.context_len 4096 --training.gradient_accumulation_steps 2 --training.steps 100000 --training.max_norm 1.0 --training.skip_nan_inf --training.dataset ~/.cache/HuggingFaceFW___fineweb-edu/sample-100BT --training.dataset_split train --training.num_workers 32 --training.prefetch_factor 2 --training.seed 79 --training.compile --checkpoint.interval 10000 --checkpoint.load_step -1 --metrics.log_freq 5 --checkpoint.hf_upload_enabled --checkpoint.hf_repo_base_name zaydzuhri/top-340M-4096-batch16-steps100000 --comm.init_timeout_seconds 600 --comm.train_timeout_seconds 300
38
+
39
+ NGPU=8 bash train.sh --job.config_file flame/models/fla.toml --job.dump_folder exp/vanilla.1B.batch8.seqlen4096.context4096.warmup2000.update2.steps200000.lr2e-4.cosine --model.config configs/transformer_1B.json --model.tokenizer_path fla-hub/transformer-1.3B-100B --optimizer.name AdamW --optimizer.eps 1e-15 --optimizer.lr 2e-4 --lr_scheduler.warmup_steps 2000 --lr_scheduler.lr_min 0.1 --lr_scheduler.decay_type cosine --training.batch_size 8 --training.seq_len 4096 --training.context_len 4096 --training.gradient_accumulation_steps 2 --training.steps 200000 --training.max_norm 1.0 --training.skip_nan_inf --training.dataset ~/.cache/HuggingFaceFW___fineweb-edu/sample-100BT --training.dataset_split train --training.num_workers 32 --training.prefetch_factor 2 --training.seed 79 --training.compile --checkpoint.interval 10000 --checkpoint.load_step -1 --metrics.log_freq 5 --checkpoint.hf_upload_enabled --checkpoint.hf_repo_base_name zaydzuhri/vanilla-1B-4096-batch8x2-steps200000 --comm.init_timeout_seconds 600 --comm.train_timeout_seconds 300
40
+
41
+ NGPU=8 bash train.sh --job.config_file flame/models/fla.toml --job.dump_folder exp/mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine --model.config configs/mtp_transformer_1B.json --model.tokenizer_path fla-hub/transformer-1.3B-100B --optimizer.name AdamW --optimizer.eps 1e-15 --optimizer.lr 2e-4 --lr_scheduler.warmup_steps 2000 --lr_scheduler.lr_min 0.1 --lr_scheduler.decay_type cosine --training.batch_size 16 --training.seq_len 4096 --training.context_len 4096 --training.gradient_accumulation_steps 1 --training.steps 200000 --training.max_norm 1.0 --training.skip_nan_inf --training.dataset ~/.cache/HuggingFaceFW___fineweb-edu/sample-100BT --training.dataset_split train --training.num_workers 32 --training.prefetch_factor 2 --training.seed 79 --training.compile --checkpoint.interval 10000 --checkpoint.load_step -1 --metrics.log_freq 5 --checkpoint.hf_upload_enabled --checkpoint.hf_repo_base_name zaydzuhri/mtp-1B-4096-batch16x1-steps200000 --comm.init_timeout_seconds 1800 --comm.train_timeout_seconds 1800
42
+
43
+ NGPU=8 bash train.sh --job.config_file flame/models/fla.toml --job.dump_folder exp/top.1B.batch8.seqlen4096.context4096.warmup2000.update2.steps200000.lr2e-4.cosine --model.config configs/top_transformer_1B.json --model.tokenizer_path fla-hub/transformer-1.3B-100B --optimizer.name AdamW --optimizer.eps 1e-15 --optimizer.lr 2e-4 --lr_scheduler.warmup_steps 2000 --lr_scheduler.lr_min 0.1 --lr_scheduler.decay_type cosine --training.batch_size 8 --training.seq_len 4096 --training.context_len 4096 --training.gradient_accumulation_steps 2 --training.steps 200000 --training.max_norm 1.0 --training.skip_nan_inf --training.dataset ~/.cache/HuggingFaceFW___fineweb-edu/sample-100BT --training.dataset_split train --training.num_workers 32 --training.prefetch_factor 2 --training.seed 79 --training.compile --checkpoint.interval 10000 --checkpoint.load_step -1 --metrics.log_freq 5 --checkpoint.hf_upload_enabled --checkpoint.hf_repo_base_name zaydzuhri/top-1B-4096-batch8x2-steps200000 --comm.init_timeout_seconds 600 --comm.train_timeout_seconds 300
44
+
45
+ NGPU=8 bash train.sh --job.config_file flame/models/fla.toml --job.dump_folder exp/vanilla.7B.batch8.seqlen4096.context4096.warmup2000.update2.steps200000.lr1.2e-4.cosine --model.config configs/transformer_7B.json --model.tokenizer_path fla-hub/transformer-1.3B-100B --optimizer.name AdamW --optimizer.eps 1e-15 --optimizer.lr 1.2e-4 --lr_scheduler.warmup_steps 2000 --lr_scheduler.lr_min 0.1 --lr_scheduler.decay_type cosine --training.batch_size 8 --training.seq_len 4096 --training.context_len 4096 --training.gradient_accumulation_steps 2 --training.steps 200000 --training.max_norm 1.0 --training.skip_nan_inf --training.dataset ~/.cache/HuggingFaceFW___fineweb-edu/sample-100BT --training.dataset_split train --training.num_workers 32 --training.prefetch_factor 2 --training.seed 79 --training.compile --checkpoint.interval 10000 --checkpoint.load_step -1 --metrics.log_freq 5 --checkpoint.hf_upload_enabled --checkpoint.hf_repo_base_name zaydzuhri/vanilla-7B-4096-batch8x2-steps200000 --comm.init_timeout_seconds 1800 --comm.train_timeout_seconds 1800
46
+
47
+ NGPU=8 bash train.sh --job.config_file flame/models/fla.toml --job.dump_folder exp/mtp.7B.batch8.seqlen4096.context4096.warmup2000.update2.steps200000.lr1.2e-4.cosine --model.config configs/mtp_transformer_7B.json --model.tokenizer_path fla-hub/transformer-1.3B-100B --optimizer.name AdamW --optimizer.eps 1e-15 --optimizer.lr 1.2e-4 --lr_scheduler.warmup_steps 2000 --lr_scheduler.lr_min 0.1 --lr_scheduler.decay_type cosine --training.batch_size 8 --training.seq_len 4096 --training.context_len 4096 --training.gradient_accumulation_steps 2 --training.steps 200000 --training.max_norm 1.0 --training.skip_nan_inf --training.dataset ~/.cache/HuggingFaceFW___fineweb-edu/sample-100BT --training.dataset_split train --training.num_workers 32 --training.prefetch_factor 2 --training.seed 79 --training.compile --checkpoint.interval 10000 --checkpoint.load_step -1 --metrics.log_freq 5 --checkpoint.hf_upload_enabled --checkpoint.hf_repo_base_name zaydzuhri/mtp-7B-4096-batch8x2-steps200000 --comm.init_timeout_seconds 1800 --comm.train_timeout_seconds 1800
48
+
49
+ NGPU=8 bash train.sh --job.config_file flame/models/fla.toml --job.dump_folder exp/top.7B.batch8.seqlen4096.context4096.warmup2000.update2.steps200000.lr1.2e-4.cosine --model.config configs/top_transformer_7B.json --model.tokenizer_path fla-hub/transformer-1.3B-100B --optimizer.name AdamW --optimizer.eps 1e-15 --optimizer.lr 1.2e-4 --lr_scheduler.warmup_steps 2000 --lr_scheduler.lr_min 0.1 --lr_scheduler.decay_type cosine --training.batch_size 8 --training.seq_len 4096 --training.context_len 4096 --training.gradient_accumulation_steps 2 --training.steps 200000 --training.max_norm 1.0 --training.skip_nan_inf --training.dataset ~/.cache/HuggingFaceFW___fineweb-edu/sample-100BT --training.dataset_split train --training.num_workers 32 --training.prefetch_factor 2 --training.seed 79 --training.compile --checkpoint.interval 10000 --checkpoint.load_step -1 --metrics.log_freq 5 --checkpoint.hf_upload_enabled --checkpoint.hf_repo_base_name zaydzuhri/top-7B-4096-batch8x2-steps200000 --comm.init_timeout_seconds 1800 --comm.train_timeout_seconds 1800
50
+ ```
51
+
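The nine commands above share most of their flags and differ mainly in the model config, learning rate, warmup, step count, and the batch size / accumulation split. As a rough sketch of that structure, the hypothetical helper below assembles a command from the shared and per-run values. It is not part of flame, it only uses flags that appear in the commands above, and it deliberately omits the dataset, checkpoint, metrics, and timeout flags, so its output is not a drop-in replacement for them.

```python
import shlex

# Flags whose values are identical in every command listed above.
SHARED_FLAGS = {
    "job.config_file": "flame/models/fla.toml",
    "model.tokenizer_path": "fla-hub/transformer-1.3B-100B",
    "optimizer.name": "AdamW",
    "optimizer.eps": "1e-15",
    "lr_scheduler.lr_min": "0.1",
    "lr_scheduler.decay_type": "cosine",
    "training.seq_len": "4096",
    "training.context_len": "4096",
    "training.max_norm": "1.0",
    "training.seed": "79",
}


def build_train_command(model_config, dump_folder, lr, warmup_steps, steps,
                        batch_size, grad_accum_steps, ngpu=8):
    """Assemble a partial train.sh command line (hypothetical helper)."""
    flags = dict(SHARED_FLAGS)
    flags.update({
        "job.dump_folder": dump_folder,
        "model.config": model_config,
        "optimizer.lr": lr,
        "lr_scheduler.warmup_steps": warmup_steps,
        "training.steps": steps,
        "training.batch_size": batch_size,
        "training.gradient_accumulation_steps": grad_accum_steps,
    })
    args = ["bash", "train.sh"]
    for name, value in flags.items():
        args += [f"--{name}", str(value)]
    # Boolean switches take no value.
    args += ["--training.skip_nan_inf", "--training.compile"]
    return f"NGPU={ngpu} " + shlex.join(args)


if __name__ == "__main__":
    print(build_train_command(
        model_config="configs/top_transformer_340M.json",
        dump_folder="exp/top.340M.debug",  # illustrative folder name
        lr="3e-4", warmup_steps=1000, steps=100000,
        batch_size=8, grad_accum_steps=2,
    ))
```

Keeping the shared values in one place makes it easier to see that, for example, the 340M runs differ from the 1B runs only in learning rate, warmup, and step count.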
52
+ Check out the wandb for training logs (although it is very unorganized lol): https://wandb.ai/zaydzuhri/fla
53
+
54
+ Feel free to DM @zmkzmkz on X for any questions regarding the paper or this code!
55
+
56
+ ## Flame
57
+
58
+ Welcome to 🔥 `flame`, a minimal and efficient framework built on `torchtitan` for training Flash Linear Attention (FLA) models (and more broadly, arbitrary autoregressive language models) with blazing efficiency.
59
+
60
+ **Feature Highlights:**
61
+
62
+ - 🚀 Minimal, easy-to-use, extensible training framework
63
+ - 🤗 Seamless integration with `fla` and `transformers`
64
+ - 🔄 Zero-cost data preprocessing: online tokenization, dataset shuffling, and multiple datasets support
65
+ - 🔮 4D parallelism (coming soon)
66
+
67
+ ## Setup
68
+
69
+ To get started, clone the `flame` repository and install the required dependencies:
70
+
71
+ ```bash
72
+ git clone https://github.com/fla-org/flame.git
73
+ cd flame
74
+ pip install .
75
+ ```
76
+
77
+ `flame` manages minimal dependencies, only including `fla` and `torchtitan` as submodules.
78
+ After installation, initialize and update the submodules:
79
+ ```sh
80
+ git submodule update --init --recursive
81
+ ```
82
+
83
+ ## Dataset Preparation
84
+ To download the dataset to your local disk, create a new Python file with the following content and execute it:
85
+
86
+ ```py
87
+ from datasets import load_dataset
88
+
89
+ # load fineweb-edu with parallel processing
90
+ dataset = load_dataset("HuggingFaceFW/fineweb-edu", name="default", num_proc=64, cache_dir="/your/cache/path")
91
+
92
+ # or load a subset with roughly 100B tokens, suitable for small- or medium-sized experiments
93
+ dataset = load_dataset("HuggingFaceFW/fineweb-edu", name="sample-100BT", num_proc=64, cache_dir="/your/cache/path")
94
+ ```
95
+
96
+ ## Training Recipes
97
+
98
+ Here's an example of training a 340M FLA Transformer model with a LLaMA-like architecture from scratch on a 100BT subset of the Fineweb-edu corpus in streaming mode.
99
+
100
+ > [!WARNING]
101
+ > If the dataset is not downloaded beforehand, the streaming mode will attempt to fetch it from a remote server and download it on-the-fly, which can be highly unstable during training due to network issues.
102
+ > For stable training, ensure the dataset is downloaded locally (see [**Dataset Preparation**](#dataset-preparation)). Otherwise, we assume you are only testing the new corpus.
103
+
104
+ ```sh
105
+ bash train.sh \
106
+ --job.config_file flame/models/fla.toml \
107
+ --job.dump_folder exp/transformer-340M-4K-10B/batch1.seqlen65536.context4096.warmup1024.update1.steps20480.lr3e-4.cosine \
108
+ --model.config configs/transformer_340M.json \
109
+ --model.tokenizer_path fla-hub/transformer-1.3B-100B \
110
+ --optimizer.name AdamW \
111
+ --optimizer.eps 1e-15 \
112
+ --optimizer.lr 3e-4 \
113
+ --lr_scheduler.warmup_steps 1024 \
114
+ --lr_scheduler.lr_min 0.1 \
115
+ --lr_scheduler.decay_type cosine \
116
+ --training.batch_size 1 \
117
+ --training.seq_len 65536 \
118
+ --training.context_len 4096 \
119
+ --training.varlen \
120
+ --training.gradient_accumulation_steps 1 \
121
+ --training.steps 20480 \
122
+ --training.max_norm 1.0 \
123
+ --training.skip_nan_inf \
124
+ --training.dataset HuggingFaceFW/fineweb-edu \
125
+ --training.dataset_name sample-100BT \
126
+ --training.dataset_split train \
127
+ --training.streaming \
128
+ --training.num_workers 32 \
129
+ --training.prefetch_factor 2 \
130
+ --training.seed 42 \
131
+ --training.compile \
132
+ --checkpoint.interval 2048 \
133
+ --checkpoint.load_step -1 \
134
+ --checkpoint.keep_latest_k 2 \
135
+ --metrics.log_freq 1
136
+ ```
137
+
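The recipe above combines `--training.varlen` with `--training.seq_len 65536` and `--training.context_len 4096`: many samples are concatenated into one long sequence, and no single segment may exceed the context length. The sketch below illustrates one way such segment boundaries could be computed. It assumes that samples longer than `context_len` are split into several segments and that packing stops once `seq_len` tokens are filled; `pack_boundaries` is a hypothetical function, and flame's actual data pipeline may differ in these details.

```python
def pack_boundaries(sample_lengths, seq_len, context_len):
    """Return cumulative segment boundaries for one packed sequence.

    Each segment is at most `context_len` tokens, and the packed
    sequence never grows beyond `seq_len` tokens.
    """
    boundaries = [0]
    for length in sample_lengths:
        remaining = length
        while remaining > 0 and boundaries[-1] < seq_len:
            # Take as much of the sample as the context window and the
            # remaining room in the packed sequence allow.
            piece = min(remaining, context_len, seq_len - boundaries[-1])
            boundaries.append(boundaries[-1] + piece)
            remaining -= piece
        if boundaries[-1] >= seq_len:
            break
    return boundaries


if __name__ == "__main__":
    # Three samples of 3000, 6000 and 2000 tokens packed into 16384 slots.
    print(pack_boundaries([3000, 6000, 2000], seq_len=16384, context_len=4096))
```

Boundaries of this kind are what variable-length attention kernels typically consume so that tokens do not attend across segment borders.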
138
+ You can specify the number of GPUs by setting the environment variable `NGPU`, which defaults to 8.
139
+ **For single-GPU debugging, set `NGPU=1`.**
140
+
141
+ We provide several [config files](https://github.com/fla-org/flame/tree/main/configs) for different models.
142
+ By default, the learning rate is set to 3e-4 with a cosine scheduler. Other schedulers, such as WSD (wsd), are also supported.
143
+
144
+ **Key parameters:**
145
+ - `--lr_scheduler.decay_ratio`: The proportion of the steps allocated to the decay phase. The learning rate will remain stable after the warmup period and only start decaying during the last `decay_ratio` portion of the total training steps, which is known as the Warmup-Stable-Decay (WSD) schedule.
146
+ - `--lr_scheduler.warmup_steps`: The number of steps for the learning rate warmup phase.
147
+ - `--training.steps`: Total number of training steps.
148
+ - `--training.batch_size`: Batch size per device, must be 1 if `--training.varlen` is set.
149
+ - `--training.seq_len`: The length of each sequence in the batch, which is concatenated from multiple samples.
150
+ - `--training.context_len`: The max allowed length of a sample. For non-varlen mode, this is equivalent to `seq_len`.
151
+ - `--training.varlen`: Whether to conduct variable-length sequence training.
152
+ - `--training.dataset_mode`: Choose `pretrain` (default) to stream fixed-length chunks or `finetune` to keep per-example sequences and pad within the batch.
153
+ - `--training.gradient_accumulation_steps`: Number of gradient accumulation steps.
154
+
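To make the interaction of `warmup_steps`, `decay_ratio`, and `lr_min` concrete, here is a minimal sketch of a Warmup-Stable-Decay multiplier. It assumes `lr_min` is a ratio of the peak learning rate (the commands above pass `0.1`) and that the decay phase spans the last `decay_ratio` fraction of all steps; the function name and the exact rounding are illustrative and may not match the scheduler in torchtitan step for step.

```python
import math


def lr_multiplier(step, total_steps, warmup_steps, decay_ratio=None,
                  lr_min=0.0, decay_type="cosine"):
    """Factor applied to the peak learning rate at a given step (sketch)."""
    if decay_ratio is None:
        # No stable phase: decay starts right after warmup.
        decay_steps = total_steps - warmup_steps
    else:
        decay_steps = round(total_steps * decay_ratio)
    decay_start = total_steps - decay_steps

    if step < warmup_steps:
        return step / warmup_steps  # linear warmup from 0 to 1
    if step < decay_start:
        return 1.0  # stable phase at the peak learning rate

    progress = min(1.0, (step - decay_start) / max(1, decay_steps))
    if decay_type == "linear":
        shape = 1.0 - progress
    elif decay_type == "cosine":
        shape = 0.5 * (1.0 + math.cos(math.pi * progress))
    else:
        raise ValueError(f"unsupported decay_type: {decay_type}")
    # Rescale so the multiplier ends at lr_min instead of 0.
    return lr_min + (1.0 - lr_min) * shape


if __name__ == "__main__":
    for s in (0, 50, 500, 900, 1000):
        print(s, lr_multiplier(s, total_steps=1000, warmup_steps=100,
                               decay_ratio=0.2, lr_min=0.1))
```

With `decay_ratio=None` the same function reduces to plain warmup followed by cosine (or linear) decay over the rest of training.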
155
+ > [!WARNING]
156
+ > The total number of sequences processed per optimizer step, referred to as `global_batch_size`, is calculated as `batch_size × gradient_accumulation_steps × num_gpus`.
157
+ > Each step processes `global_batch_size * seq_len` tokens.
158
+ > Monitor the value of `global_batch_size`, `warmup_steps`, and `steps` carefully when modifying any of the hyperparameters!
159
+
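As a worked example of the formula in the warning above, the short script below computes tokens per step and total training tokens for settings that appear in this README. It assumes all `NGPU=8` devices act as data-parallel ranks; the function and variable names are illustrative only.

```python
def tokens_per_step(batch_size, grad_accum_steps, num_gpus, seq_len):
    """Tokens consumed by one optimizer step across all devices."""
    global_batch_size = batch_size * grad_accum_steps * num_gpus
    return global_batch_size * seq_len


# (batch_size, grad_accum_steps, num_gpus, seq_len, steps) taken from the commands in this README.
runs = {
    "340M, batch 16 x accum 1": (16, 1, 8, 4096, 100_000),
    "1B/7B, batch 8 x accum 2": (8, 2, 8, 4096, 200_000),
    "varlen recipe, batch 1": (1, 1, 8, 65536, 20_480),
}

if __name__ == "__main__":
    for name, (bs, accum, gpus, seq_len, steps) in runs.items():
        per_step = tokens_per_step(bs, accum, gpus, seq_len)
        print(f"{name}: {per_step:,} tokens/step, {per_step * steps / 1e9:.1f}B tokens total")
```

All three settings work out to 524,288 tokens per step, which is why a batch size of 16 with one accumulation step and a batch size of 8 with two accumulation steps are interchangeable here.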
160
+ For a detailed explanation of all parameters, run:
161
+
162
+ ```sh
163
+ bash train.sh -h
164
+ ```
165
+
166
+ <details>
167
+ <summary>Usage</summary>
168
+
169
+ ```py
170
+ options:
171
+ -h, --help show this help message and exit
172
+ --job.config_file JOB.CONFIG_FILE
173
+ Job config file
174
+ --job.dump_folder JOB.DUMP_FOLDER
175
+ Folder to dump job outputs
176
+ --job.description JOB.DESCRIPTION
177
+ Description of the job
178
+ --job.use_for_integration_test
179
+ Add this config to the integration test suite
180
+ --job.print_args Print the args to terminal
181
+ --model.config MODEL.CONFIG
182
+ Path to the model config
183
+ --model.norm_type MODEL.NORM_TYPE
184
+ Type of layer normalization to use [layernorm,
185
+ np_layernorm, rmsnorm, fused_rmsnorm]
186
+ --model.tokenizer_path MODEL.TOKENIZER_PATH
187
+ Tokenizer path
188
+ --profiling.enable_profiling
189
+ Whether to enable pytorch profiler
190
+ --profiling.save_traces_folder PROFILING.SAVE_TRACES_FOLDER
191
+ Trace files location
192
+ --profiling.profile_freq PROFILING.PROFILE_FREQ
193
+ How often to collect profiler traces, in iterations
194
+ --profiling.enable_memory_snapshot
195
+ Whether to dump memory snapshot
196
+ --profiling.save_memory_snapshot_folder PROFILING.SAVE_MEMORY_SNAPSHOT_FOLDER
197
+ Memory snapshot files location
198
+ --optimizer.name OPTIMIZER.NAME
199
+ Optimizer to use
200
+ --optimizer.eps OPTIMIZER.EPS
201
+ Epsilon value for the optimizer.
202
+ --optimizer.fused Whether the fused implementation(CUDA only) is used.
203
+ --optimizer.scheduler {wsd,cosine,linear}
204
+ Scheduler to use. Currently supported: wsd, cosine,
205
+ and linear.
206
+ --optimizer.lr OPTIMIZER.LR
207
+ Learning rate to use
208
+ --optimizer.min_lr_ratio OPTIMIZER.MIN_LR_RATIO
209
+ Min lr ratio for lr scheduler
210
+ --optimizer.early_step_in_backward
211
+ Whether to apply optimizer in the backward. Caution,
212
+ optimizer_in_backward is not compatible with gradients
213
+ clipping, users should not call
214
+ register_post_accumulate_grad_hook after the optimizer
215
+ is built.
216
+ --training.batch_size TRAINING.BATCH_SIZE
217
+ Batch size
218
+ --training.seq_len TRAINING.SEQ_LEN
219
+ Sequence length
220
+ --training.context_len TRAINING.CONTEXT_LEN
221
+ Max length allowed for each sequence
222
+ --training.varlen Whether to take sequences of variable length as input
223
+ --training.warmup_steps TRAINING.WARMUP_STEPS
224
+ Steps for lr scheduler warmup, normally 1/5 of
225
+ --training.steps
226
+ --training.gradient_accumulation_steps TRAINING.GRADIENT_ACCUMULATION_STEPS
227
+ Number of steps to accumulate gradients before
228
+ updating parameters
229
+ --training.steps TRAINING.STEPS
230
+ How many train steps to run
231
+ --training.max_norm TRAINING.MAX_NORM
232
+ Max norm for gradient clipping
233
+ --training.skip_nan_inf
234
+ Skip batch updates when NaN or INF gradients are
235
+ encountered during training
236
+ --training.dataset TRAINING.DATASET
237
+ Dataset to use, with comma separated values
238
+ --training.dataset_name TRAINING.DATASET_NAME
239
+ The name of the dataset config, with comma separated
240
+ values if provided
241
+ --training.dataset_split TRAINING.DATASET_SPLIT
242
+ Dataset split to use, with comma separated values if
243
+ provided
244
+ --training.data_dir TRAINING.DATA_DIR
245
+ Data dirs to use, with comma separated values if
246
+ provided
247
+ --training.data_files TRAINING.DATA_FILES
248
+ Data files to use, with comma separated values if
249
+ provided
250
+ --training.data_probs TRAINING.DATA_PROBS
251
+ Data sampling probabilities, with comma separated
252
+ values if provided
253
+ --training.streaming Whether to load dataset in streaming mode, used for
254
+ huge dataset
255
+ --training.num_workers TRAINING.NUM_WORKERS
256
+ Number of subprocesses to use for data loading. 0
257
+ means that the data will be loaded in the main
258
+ process.
259
+ --training.prefetch_factor TRAINING.PREFETCH_FACTOR
260
+ Number of batches loaded in advance by each worker. 2
261
+ means there will be a total of 2 * num_workers batches
262
+ prefetched across all workers.
263
+ --training.data_parallel_replicate_degree TRAINING.DATA_PARALLEL_REPLICATE_DEGREE
264
+ The `data_parallel_replicate_degree` argument
265
+ specifies the degree of data parallelism for weight
266
+ replication. When this value is greater than 1,
267
+ weights will be replicated across
268
+ `data_parallel_replicate_degree` ranks. If
269
+ `data_parallel_shard_degree` is also greater than 1,
270
+ the parallelism method used is HSDP (Hybrid Sharded
271
+ Data Parallelism). Otherwise, the parallelism method
272
+ used is DDP (Distributed Data Parallelism). 1 means
273
+ disabled.
274
+ --training.data_parallel_shard_degree TRAINING.DATA_PARALLEL_SHARD_DEGREE
275
+ The `data_parallel_shard_degree` argument specifies
276
+ the degree of data parallelism for weight sharding.
277
+ When this value is greater than 1, weights will be
278
+ sharded across `data_parallel_shard_degree` ranks. If
279
+ `data_parallel_replicate_degree` is also greater than
280
+ 1, the parallelism method used is HSDP (Hybrid Sharded
281
+ Data Parallelism). Otherwise, the parallelism method
282
+ used is FSDP (Fully Sharded Data Parallelism). -1
283
+ means leftover ranks will be used (After
284
+ DP_REPLICATE/SP/PP). Note that only
285
+ `data_parallel_shard_degree` can be negative. 1 means
286
+ disabled.
287
+ --training.enable_cpu_offload
288
+ Whether to apply CPU offloading of parameters,
289
+ gradients, and optimizer states in FSDP
290
+ --training.tensor_parallel_degree TRAINING.TENSOR_PARALLEL_DEGREE
291
+ Tensor Parallelism degree. 1 means disabled.
292
+ --training.disable_loss_parallel
293
+ Whether to apply loss parallel when sequence parallel
294
+ is enabled
295
+ --training.mixed_precision_param {bfloat16,float32}
296
+ torch dtype to use for parameters when applying mixed
297
+ precision via FSDP. This feature only takes effect
298
+ when data_parallel_shard_degree > 1
299
+ --training.mixed_precision_reduce {float32}
300
+ torch dtype to use for reductions when applying mixed
301
+ precision via FSDP. This feature only takes effect
302
+ when data_parallel_shard_degree > 1
303
+ --training.compile Whether to compile the model
304
+ --training.gc_freq TRAINING.GC_FREQ
305
+ Python garbage control scheduling interval, in steps
306
+ --training.seed TRAINING.SEED
307
+ Choose the base RNG seed used for training
308
+ --training.deterministic
309
+ Use deterministic algorithms wherever possible, may be
310
+ slower
311
+ --metrics.log_freq METRICS.LOG_FREQ
312
+ How often to log metrics to TensorBoard, in iterations
313
+ --metrics.enable_tensorboard
314
+ Whether to log metrics to TensorBoard
315
+ --metrics.disable_color_printing
316
+ Whether to disable color printing in logs
317
+ --metrics.save_tb_folder METRICS.SAVE_TB_FOLDER
318
+ Folder to dump TensorBoard states
319
+ --metrics.rank_0_only
320
+ Whether to save TensorBoard metrics only for rank 0 or
321
+ for all ranks. When pipeline_parallel_degree is > 1,
322
+ this option uses the 0th rank of the last stage
323
+ pipeline group, which is the only stage that computes
324
+ loss metrics.
325
+ --metrics.enable_wandb
326
+ Whether to log metrics to Weights & Biases
327
+ --experimental.enable_async_tensor_parallel
328
+ Whether to apply async tensor parallel (currently only
329
+ effective when compile is enabled)
330
+ --experimental.pipeline_parallel_degree EXPERIMENTAL.PIPELINE_PARALLEL_DEGREE
331
+ Pipeline Parallelism degree, or number of ranks. 1
332
+ means disabled. If using looped schedules, this still
333
+ specifies the number of physical ranks, not the number
334
+ of stages. Stages per rank are inferred from split
335
+ points degree, and schedule.
336
+ --experimental.pipeline_parallel_split_points EXPERIMENTAL.PIPELINE_PARALLEL_SPLIT_POINTS [EXPERIMENTAL.PIPELINE_PARALLEL_SPLIT_POINTS ...]
337
+ Specify comma-separated names of modules to use as the
338
+ beginning of a split point. e.g. "layers.0,layers.2"
339
+ will cause the model to be split into 3 stages, the
340
+ first containing all the layers up to layers.0, the
341
+ second containing layers.0 and up to layers.2, the
342
+ third containing layers.2 and all the remaining
343
+ layers. Note: fully-automated splitting may be enabled
344
+ in the future, but currently the split points must be
345
+ specified manually.
346
+ --experimental.pipeline_parallel_schedule EXPERIMENTAL.PIPELINE_PARALLEL_SCHEDULE
347
+ Specify the Pipeline Parallel schedule to use. The
348
+ supported schedules are: https://github.com/pytorch/py
349
+ torch/blob/de4c2a3b4e89d96334dc678d1c3f2ae51a6630a0/to
350
+ rch/distributed/pipelining/schedules.py#L2161. The
351
+ schedule must be compatible with the split points and
352
+ stages_per_rank. Looped schedules (e.g.
353
+ Interleaved1F1B) require specifying
354
+ pipeline_parallel_degree = number of ranks, and
355
+ split_points = number of stages - 1
356
+ --experimental.pipeline_parallel_schedule_csv EXPERIMENTAL.PIPELINE_PARALLEL_SCHEDULE_CSV
357
+ Specify the path to the pipeline parallel schedule csv
358
+ file to use. The pipeline_parallel_schedule argument
359
+ must be either PipelineScheduleSingle,
360
+ PipelineScheduleMulti, or _PipelineScheduleRuntime.
361
+ --experimental.pipeline_parallel_microbatches EXPERIMENTAL.PIPELINE_PARALLEL_MICROBATCHES
362
+ How many microbatches to split the global training
363
+ batch into when using pipeline parallelism. The global
364
+ training batch size must be evenly divisible by the
365
+ number of microbatches. The default value will be the
366
+ number of pipeline stages, if unspecified.
367
+ --experimental.enable_compiled_autograd
368
+ Enable CompiledAutograd to compile the backward.
369
+ --experimental.context_parallel_degree EXPERIMENTAL.CONTEXT_PARALLEL_DEGREE
370
+ Context parallelism degree. 1 means disabled.
371
+ --experimental.context_parallel_rotate_method EXPERIMENTAL.CONTEXT_PARALLEL_ROTATE_METHOD
372
+ The collective to use in context parallel SDPA for kv
373
+ shards exchange. 'allgather' means to all-gather all
374
+ kv shards on ranks after the first sub-SDPA
375
+ computation, 'alltoall' means to all-to-all shuffle
376
+ the kv shards. The default value is 'allgather'.
377
+ --checkpoint.enable_checkpoint
378
+ Whether to enable checkpoint
379
+ --checkpoint.folder CHECKPOINT.FOLDER
380
+ The folder to store the checkpoints. When
381
+ enable_checkpoint is set to true, checkpoints will be
382
+ in {--job.dump_folder}/{--checkpoint.folder}.
383
+ --checkpoint.interval_type CHECKPOINT.INTERVAL_TYPE
384
+ Checkpointing interval unit of measurement ['step',
385
+ 'seconds']
386
+ --checkpoint.interval CHECKPOINT.INTERVAL
387
+ Checkpointing interval, in steps or seconds depending
388
+ on --checkpoint.interval_type
389
+ --checkpoint.model_weights_only
390
+ When model_weights_only=True, only model weights will
391
+ be saved at the end of training. With this,
392
+ checkpoints can be loaded using `torch.load(...,
393
+ weights_only=True)` after conversion. When
394
+ model_weights_only=False, the full checkpoint will be
395
+ saved. A full checkpoint includes model, optimizer and
396
+ train_state, which can be used to resume training. The
397
+ default value is false.
398
+ --checkpoint.export_dtype {float16,bfloat16,float32}
399
+ Converts to the specified precision when training
400
+ completes and model_weights_only=true. Currently
401
+ supports float32, float16, and bfloat16. The default
402
+ value is float32.
403
+ --checkpoint.create_seed_checkpoint
404
+ Initializes the full model without applying
405
+ parallelisms, and then saves it as a seed checkpoint.
406
+ Note: requires user to call train.py without
407
+ specifying any parallelisms, e.g. NGPU=1. Could be
408
+ implemented as a separate script, but this way shares
409
+ more code.
410
+ --checkpoint.async_mode CHECKPOINT.ASYNC_MODE
411
+ Which async checkpoint mode to use. Currently there
412
+ are 3 different modes. 1. "disabled": synchronized
413
+ checkpointing will be used. 2. "async":
414
+ torch.distributed.checkpoint.async_save will be used.
415
+ 3. "async_with_pinned_mem": this option utilizes a
416
+ dedicated pinned memory space and creates a separate
417
+ process for faster GPU->CPU transfer performance and
418
+ eliminating GIL contention. The cost is increased CPU
419
+ memory usage. If insufficient CPU memory is available,
420
+ performance may degrade due to memory paging. For most
421
+ users, "async" should suffice as the performance
422
+ overhead is typically small (on the order of tens of
423
+ seconds) compared to checkpointing frequency. This
424
+ mode can be employed to pursue near-zero checkpointing
425
+ times (e.g., < 1 second) given appropriate hardware
426
+ support such as ample CPU memory and fast PCIe.
427
+ "disabled" is the default mode.
428
+ --checkpoint.keep_latest_k CHECKPOINT.KEEP_LATEST_K
429
+ Keeps only the latest k checkpoints, purging older
430
+ ones. If 0, keep all checkpoints. 0 is the default
431
+ value.
432
+ --checkpoint.load_step CHECKPOINT.LOAD_STEP
433
+ Load the checkpoint at the specified step. If -1, load
434
+ the latest checkpoint.
435
+ --float8.enable_float8_linear
436
+ If true, swaps `torch.nn.Linear` with `Float8Linear`.
437
+ This feature requires you to install 'torchao' which
438
+ can be found here: https://github.com/pytorch/ao
439
+ --float8.enable_fsdp_float8_all_gather
440
+ Whether to enable float8 all-gather in FSDP
441
+ --float8.precompute_float8_dynamic_scale_for_fsdp
442
+ Whether to precompute float8 scales dynamically for FSDP
443
+ --float8.scaling_type_input {dynamic,delayed}
444
+ float8 scaling for input, dynamic (default) or delayed
445
+ --float8.scaling_type_weight FLOAT8.SCALING_TYPE_WEIGHT
446
+ float8 scaling for weight, dynamic (default) or delayed
447
+ --float8.scaling_type_grad_output FLOAT8.SCALING_TYPE_GRAD_OUTPUT
448
+ float8 scaling for grad_output, dynamic (default) or delayed
449
+ --comm.init_timeout_seconds COMM.INIT_TIMEOUT_SECONDS
450
+ Timeout for communication operations, during
451
+ initialization and first train step.
452
+ --comm.train_timeout_seconds COMM.TRAIN_TIMEOUT_SECONDS
453
+ Timeout for communication operations after the first
454
+ train step -- usually a tighter bound than during
455
+ initialization.
456
+ --comm.trace_buf_size COMM.TRACE_BUF_SIZE
457
+ Flight recorder ring buffer size, >0 means recording
458
+ by default, 0 means disabled
459
+ --memory_estimation.enabled
460
+ Whether to estimate memory usage for FSDP
461
+ --memory_estimation.disable_fake_mode
462
+ Whether to estimate memory under FakeTensorMode
463
+ ```
464
+ </details>
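The `--checkpoint.keep_latest_k` option above describes a simple retention policy. The following is a minimal sketch of that policy, not the code `flame` or `torchtitan` actually runs: the helper name `checkpoints_to_purge` is hypothetical, and it assumes checkpoint folders are named `step-<N>`, as in the `step-0` example used elsewhere in this README.

```python
import re


def checkpoints_to_purge(folder_names, keep_latest_k):
    """Return the checkpoint folders a keep-latest-k policy would delete.

    Assumption: folders are named `step-<N>`; other names are ignored.
    keep_latest_k == 0 means "keep everything", as in the help text.
    """
    if keep_latest_k < 0:
        raise ValueError("keep_latest_k must be >= 0")
    if keep_latest_k == 0:
        return []

    steps = []
    for name in folder_names:
        match = re.fullmatch(r"step-(\d+)", name)
        if match:
            steps.append((int(match.group(1)), name))

    steps.sort()  # oldest step first
    # Everything except the last k entries is purged.
    return [name for _, name in steps[:-keep_latest_k]]
```

Sorting by the numeric step rather than by folder name avoids treating `step-1000` as older than `step-200`.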
465
+
466
+ ### Training with `torch.compile`
467
+
468
+ `torch.compile`, introduced in `torch 2.0`, can accelerate training with minimal code changes.
469
+ In `flame`, you can enable `torch.compile` by adding the `--training.compile` flag to your training script.
470
+
471
+ However, `fla` integrates numerous fused kernels for acceleration, which may conflict with `torch.compile`.
472
+ We are actively working on resolving these issues to make compilation transparent to users.
473
+ In the meantime, please ensure you are using the latest dependencies.
474
+
475
+ Specifically, **we recommend using `torch>=2.6` and `triton>=3.0`**.
476
+
477
+ ### Training with multiple datasets
478
+
479
+ If you wish to train a model with all-round capabilities (e.g., code, math, and multilingual ability), it's necessary to train on multiple datasets.
480
+ `flame` allows training with multiple datasets easily.
481
+ For example, you can specify the following arguments to train on 6 datasets with different proportions:
482
+
483
+ ```sh
484
+ --training.dataset HuggingFaceFW/fineweb-edu,opencsg/Fineweb-Edu-Chinese-V2.1,OpenCoder-LLM/opc-fineweb-code-corpus,math-ai/AutoMathText,EleutherAI/proof-pile-2,OpenCoder-LLM/opc-fineweb-math-corpus \
485
+ --training.data_probs 0.6,0.15,0.15,0.014,0.058,0.028 \
486
+ ```
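Before launching a long run, it can help to sanity-check the two comma-separated arguments against each other. The sketch below is illustrative only: `parse_dataset_mix` is a hypothetical helper, and the requirement that the probabilities sum to 1 is an assumption about how the mixture is interpreted, not something this README states.

```python
import math


def parse_dataset_mix(datasets_arg, probs_arg):
    """Pair each dataset name with its sampling probability.

    Assumptions (not confirmed by the README): one probability per dataset,
    no negative values, and the probabilities sum to 1.
    """
    names = [name.strip() for name in datasets_arg.split(",")]
    probs = [float(p) for p in probs_arg.split(",")]

    if len(names) != len(probs):
        raise ValueError(f"{len(names)} datasets but {len(probs)} probabilities")
    if any(p < 0 for p in probs):
        raise ValueError("probabilities must be non-negative")
    if not math.isclose(sum(probs), 1.0, abs_tol=1e-6):
        raise ValueError(f"probabilities sum to {sum(probs)}, expected 1.0")

    return dict(zip(names, probs))


mix = parse_dataset_mix(
    "HuggingFaceFW/fineweb-edu,opencsg/Fineweb-Edu-Chinese-V2.1,"
    "OpenCoder-LLM/opc-fineweb-code-corpus,math-ai/AutoMathText,"
    "EleutherAI/proof-pile-2,OpenCoder-LLM/opc-fineweb-math-corpus",
    "0.6,0.15,0.15,0.014,0.058,0.028",
)
```

The six proportions in the example above pass this check: the three large shares add up to 0.9 and the three math-related shares to 0.1.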
487
+
488
+ ### ~Finalizing training~
489
+
490
+ > [!NOTE]
491
+ > We have done this conversion automatically in the training script since our latest updates.
492
+
493
+ Once training is complete, you may want to convert the distributed checkpoints (DCPs) into the 🤗 format for broader use.
494
+ To facilitate this, we provide a straightforward conversion script:
495
+
496
+ ```sh
497
+ python -m flame.utils.convert_dcp_to_hf --path <path_to_model> --step <step> --config <path_to_config> --tokenizer <path_to_tokenizer>
498
+ ```
499
+ After this, your model will be in the 🤗 format, ready to be shared or deployed.
500
+ You can then publish your model with `huggingface_hub`.
501
+
502
+ ### Continual training
503
+
504
+ If you wish to build upon a strong pre-trained model (in 🤗 format) and continue training, we also offer a script to convert the 🤗 format model back into DCP format.
505
+ This allows you to seamlessly resume training with `flame`.
506
+ ```sh
507
+ python -m flame.utils.convert_hf_to_dcp --model <path_to_hf> --checkpoint <path_to_dcp/checkpoint/step-0>
508
+ ```
509
+ Here, `<path_to_dcp>` is the directory where your distributed checkpoints will be stored.
510
+ The checkpoint is intentionally saved at `<step-0>` within the checkpoint folder to ensure it is loadable by `flame` during the initial training step, similar to how a seed checkpoint is handled.
511
+
512
+ Once the conversion is complete, you can proceed with training using `flame` as usual, continuing from where the pretrained model left off.
513
+
514
+ ## Multi-node training
515
+
516
+ If you have access to GPUs on multiple nodes, consider using them to speed up training.
517
+ This process is straightforward and well-documented in the PyTorch [docs](https://pytorch.org/docs/stable/elastic/run.html).
518
+
519
+ To set up multi-node training:
520
+ * Set the environment variables `MASTER_ADDR=<ip>` and `MASTER_PORT=<port>` before running the training script across all nodes.
521
+ * If you're using a job scheduler like Slurm, it will handle these variables for you.
522
+
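A quick way to catch a misconfigured node is to check both variables before launching. This is a minimal sketch under stated assumptions: `rendezvous_endpoint` is a hypothetical helper and is not part of `flame` or `torchtitan`. It only reads the two environment variables named above.

```python
import os


def rendezvous_endpoint(env=None):
    """Build an `<addr>:<port>` string from MASTER_ADDR and MASTER_PORT.

    `env` defaults to the process environment; passing a dict makes the
    check easy to exercise without touching real environment variables.
    """
    env = os.environ if env is None else env
    addr = env.get("MASTER_ADDR")
    port = env.get("MASTER_PORT")
    if not addr or not port:
        raise RuntimeError("MASTER_ADDR and MASTER_PORT must both be set")

    port_number = int(port)  # raises ValueError on non-numeric input
    if not 0 < port_number < 65536:
        raise ValueError(f"MASTER_PORT out of range: {port_number}")
    return f"{addr}:{port_number}"
```

Running the same check on every node makes it obvious when one node was started with a different address or port.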
523
+ `torchtitan` provides a [Slurm script](https://github.com/pytorch/torchtitan/blob/main/multinode_trainer.slurm) for multi-node training, which you can use as a reference or starting point.
config.json ADDED
@@ -0,0 +1,35 @@
1
+ {
2
+ "architectures": [
3
+ "TOPTransformerForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "bos_token_id": 1,
7
+ "elementwise_affine": true,
8
+ "eos_token_id": 2,
9
+ "fuse_cross_entropy": true,
10
+ "fuse_norm": true,
11
+ "fuse_swiglu": true,
12
+ "hidden_act": "swish",
13
+ "hidden_ratio": 4,
14
+ "hidden_size": 4096,
15
+ "initializer_range": 0.006,
16
+ "intermediate_size": 14336,
17
+ "max_position_embeddings": 2048,
18
+ "model_type": "top_transformer",
19
+ "norm_eps": 1e-06,
20
+ "num_heads": 32,
21
+ "num_hidden_layers": 30,
22
+ "num_kv_heads": 8,
23
+ "qk_norm": false,
24
+ "qkv_bias": false,
25
+ "rope_theta": 10000.0,
26
+ "tie_word_embeddings": false,
27
+ "top_loss_ratio": 0.5,
28
+ "top_window_size": 4096,
29
+ "torch_dtype": "float32",
30
+ "transformers_version": "4.51.3",
31
+ "use_cache": true,
32
+ "use_top_loss": true,
33
+ "vocab_size": 32000,
34
+ "window_size": null
35
+ }
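The attention-related fields in this config determine the per-head width and the grouped-query layout. The sketch below derives them from the values shown above. `attention_shapes` is a hypothetical helper, and treating a `null` `num_kv_heads` as "same as `num_heads`" is an assumption that this diff does not confirm.

```python
import json

# Only the fields needed for the derivation, copied from the config above.
CONFIG_JSON = '{"hidden_size": 4096, "num_heads": 32, "num_kv_heads": 8}'


def attention_shapes(config):
    """Derive head width and grouped-query layout from a config dict."""
    hidden_size = config["hidden_size"]
    num_heads = config["num_heads"]
    # Assumption: a missing or null num_kv_heads means plain multi-head attention.
    num_kv_heads = config.get("num_kv_heads") or num_heads

    if hidden_size % num_heads != 0:
        raise ValueError("hidden_size must be divisible by num_heads")
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be divisible by num_kv_heads")

    head_dim = hidden_size // num_heads
    return {
        "head_dim": head_dim,
        "queries_per_kv_head": num_heads // num_kv_heads,
        "kv_dim": head_dim * num_kv_heads,
    }


shapes = attention_shapes(json.loads(CONFIG_JSON))
```

With these values each head is 128 wide, four query heads share one key/value head, and the key and value projections are 1024 wide each.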
configs/delta_net_1B.json ADDED
@@ -0,0 +1,29 @@
1
+ {
2
+ "attn": null,
3
+ "attn_mode": "chunk",
4
+ "bos_token_id": 1,
5
+ "conv_size": 4,
6
+ "eos_token_id": 2,
7
+ "expand_k": 1,
8
+ "expand_v": 1,
9
+ "fuse_cross_entropy": true,
10
+ "fuse_norm": true,
11
+ "hidden_act": "swish",
12
+ "hidden_ratio": 4,
13
+ "hidden_size": 2048,
14
+ "initializer_range": 0.006,
15
+ "intermediate_size": null,
16
+ "model_type": "delta_net",
17
+ "norm_eps": 1e-06,
18
+ "num_heads": 16,
19
+ "num_hidden_layers": 24,
20
+ "pad_token_id": 2,
21
+ "qk_activation": "silu",
22
+ "qk_norm": "l2",
23
+ "tie_word_embeddings": false,
24
+ "use_beta": true,
25
+ "use_cache": true,
26
+ "use_gate": false,
27
+ "use_output_norm": true,
28
+ "use_short_conv": true
29
+ }
configs/delta_net_340M.json ADDED
@@ -0,0 +1,27 @@
1
+ {
2
+ "attn_mode": "chunk",
3
+ "bos_token_id": 1,
4
+ "conv_size": 4,
5
+ "eos_token_id": 2,
6
+ "expand_k": 1,
7
+ "expand_v": 1,
8
+ "fuse_cross_entropy": true,
9
+ "hidden_act": "swish",
10
+ "hidden_ratio": 4,
11
+ "hidden_size": 1024,
12
+ "initializer_range": 0.006,
13
+ "intermediate_size": null,
14
+ "model_type": "delta_net",
15
+ "norm_eps": 1e-06,
16
+ "norm_first": false,
17
+ "num_heads": 8,
18
+ "num_hidden_layers": 24,
19
+ "qk_activation": "silu",
20
+ "qk_norm": "l2",
21
+ "tie_word_embeddings": false,
22
+ "use_beta": true,
23
+ "use_cache": true,
24
+ "use_gate": false,
25
+ "use_output_norm": true,
26
+ "use_short_conv": true
27
+ }
configs/dsmtp_transformer_120M.json ADDED
@@ -0,0 +1,19 @@
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "hidden_act": "swish",
8
+ "hidden_size": 768,
9
+ "initializer_range": 0.02,
10
+ "max_position_embeddings": 4096,
11
+ "model_type": "dsmtp_transformer",
12
+ "num_heads": 12,
13
+ "num_hidden_layers": 14,
14
+ "norm_eps": 1e-06,
15
+ "tie_word_embeddings": true,
16
+ "use_cache": true,
17
+ "vocab_size": 32000,
18
+ "n_future_tokens": 3
19
+ }
configs/dsmtp_transformer_1B.json ADDED
@@ -0,0 +1,23 @@
1
+ {
2
+ "bos_token_id": 1,
3
+ "elementwise_affine": true,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "fuse_swiglu": true,
8
+ "hidden_act": "swish",
9
+ "hidden_ratio": 4,
10
+ "hidden_size": 2048,
11
+ "initializer_range": 0.006,
12
+ "intermediate_size": null,
13
+ "max_position_embeddings": 8192,
14
+ "model_type": "dsmtp_transformer",
15
+ "norm_eps": 1e-06,
16
+ "num_heads": 32,
17
+ "num_hidden_layers": 32,
18
+ "num_kv_heads": null,
19
+ "pad_token_id": 2,
20
+ "rope_theta": 10000.0,
21
+ "tie_word_embeddings": false,
22
+ "n_future_tokens": 4
23
+ }
configs/dsmtp_transformer_340M.json ADDED
@@ -0,0 +1,19 @@
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "hidden_act": "swish",
8
+ "hidden_size": 1024,
9
+ "initializer_range": 0.006,
10
+ "max_position_embeddings": 8192,
11
+ "model_type": "dsmtp_transformer",
12
+ "num_heads": 16,
13
+ "num_hidden_layers": 24,
14
+ "norm_eps": 1e-06,
15
+ "tie_word_embeddings": false,
16
+ "use_cache": true,
17
+ "vocab_size": 32000,
18
+ "n_future_tokens": 4
19
+ }
configs/dsmtp_transformer_7B.json ADDED
@@ -0,0 +1,22 @@
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "hidden_act": "swish",
8
+ "hidden_ratio": 4,
9
+ "hidden_size": 4096,
10
+ "initializer_range": 0.006,
11
+ "intermediate_size": 14336,
12
+ "model_type": "dsmtp_transformer",
13
+ "norm_eps": 1e-06,
14
+ "num_heads": 32,
15
+ "num_hidden_layers": 30,
16
+ "num_kv_heads": 8,
17
+ "rope_theta": 10000.0,
18
+ "tie_word_embeddings": false,
19
+ "use_cache": true,
20
+ "window_size": null,
21
+ "n_future_tokens": 4
22
+ }
configs/gla_340M.json ADDED
@@ -0,0 +1,24 @@
1
+ {
2
+ "attn_mode": "chunk",
3
+ "bos_token_id": 1,
4
+ "clamp_min": null,
5
+ "eos_token_id": 2,
6
+ "expand_k": 0.5,
7
+ "expand_v": 1,
8
+ "fuse_cross_entropy": true,
9
+ "fuse_norm": true,
10
+ "hidden_act": "swish",
11
+ "hidden_ratio": 4,
12
+ "hidden_size": 1024,
13
+ "initializer_range": 0.006,
14
+ "intermediate_size": null,
15
+ "model_type": "gla",
16
+ "num_heads": 4,
17
+ "num_hidden_layers": 24,
18
+ "norm_eps": 1e-06,
19
+ "tie_word_embeddings": false,
20
+ "use_cache": true,
21
+ "use_gk": true,
22
+ "use_gv": false,
23
+ "vocab_size": 32000
24
+ }
configs/gla_7B.json ADDED
@@ -0,0 +1,25 @@
1
+ {
2
+ "attn": null,
3
+ "attn_mode": "chunk",
4
+ "bos_token_id": 1,
5
+ "eos_token_id": 2,
6
+ "expand_k": 0.5,
7
+ "expand_v": 1,
8
+ "fuse_cross_entropy": true,
9
+ "fuse_norm": true,
10
+ "hidden_act": "swish",
11
+ "hidden_ratio": 4,
12
+ "hidden_size": 4096,
13
+ "initializer_range": 0.006,
14
+ "intermediate_size": 11008,
15
+ "model_type": "gla",
16
+ "norm_eps": 1e-06,
17
+ "num_heads": 16,
18
+ "num_hidden_layers": 32,
19
+ "tie_word_embeddings": false,
20
+ "use_cache": true,
21
+ "use_gk": true,
22
+ "use_gv": false,
23
+ "use_output_gate": true,
24
+ "use_short_conv": false
25
+ }
configs/gsa_340M.json ADDED
@@ -0,0 +1,29 @@
1
+ {
2
+ "bos_token_id": 1,
3
+ "conv_size": 4,
4
+ "eos_token_id": 2,
5
+ "expand_k": 1,
6
+ "expand_v": 1,
7
+ "elementwise_affine": false,
8
+ "feature_map": "swish",
9
+ "fuse_cross_entropy": true,
10
+ "fuse_norm": true,
11
+ "gate_logit_normalizer": 4,
12
+ "hidden_act": "swish",
13
+ "hidden_ratio": 4,
14
+ "hidden_size": 1024,
15
+ "initializer_range": 0.006,
16
+ "intermediate_size": null,
17
+ "model_type": "gsa",
18
+ "num_heads": 4,
19
+ "num_hidden_layers": 24,
20
+ "num_slots": 64,
21
+ "norm_eps": 1e-06,
22
+ "share_conv_kernel": true,
23
+ "tie_word_embeddings": false,
24
+ "use_cache": true,
25
+ "use_norm": true,
26
+ "use_output_gate": true,
27
+ "use_rope": false,
28
+ "use_short_conv": false
29
+ }
configs/hgrn2_340M.json ADDED
@@ -0,0 +1,20 @@
1
+ {
2
+ "attn_mode": "chunk",
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "expand_ratio": 128,
6
+ "fuse_cross_entropy": true,
7
+ "fuse_norm": true,
8
+ "hidden_act": "swish",
9
+ "hidden_ratio": 4,
10
+ "hidden_size": 1024,
11
+ "initializer_range": 0.006,
12
+ "intermediate_size": null,
13
+ "model_type": "hgrn2",
14
+ "num_heads": 8,
15
+ "num_hidden_layers": 24,
16
+ "norm_eps": 1e-06,
17
+ "tie_word_embeddings": false,
18
+ "use_cache": true,
19
+ "vocab_size": 32000
20
+ }
configs/mtp_transformer_120M.json ADDED
@@ -0,0 +1,19 @@
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": false,
7
+ "hidden_act": "swish",
8
+ "hidden_size": 768,
9
+ "initializer_range": 0.02,
10
+ "max_position_embeddings": 4096,
11
+ "model_type": "mtp_transformer",
12
+ "num_heads": 12,
13
+ "num_hidden_layers": 14,
14
+ "norm_eps": 1e-06,
15
+ "tie_word_embeddings": true,
16
+ "use_cache": true,
17
+ "vocab_size": 32000,
18
+ "n_future_tokens": 4
19
+ }
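The `n_future_tokens` field suggests that each position is trained to predict several upcoming tokens rather than only the next one. The sketch below shows the general multi-token-prediction idea on plain Python lists. It is an illustration under that assumption: `future_token_targets` is a hypothetical helper, and how `mtp_transformer` actually builds, pads, or masks its targets is not shown in this diff.

```python
def future_token_targets(token_ids, n_future_tokens):
    """For each position t, collect the ids at t+1 .. t+n_future_tokens.

    Positions that do not have a full window of future tokens are dropped
    here; a real implementation might pad or mask them instead.
    """
    if n_future_tokens < 1:
        raise ValueError("n_future_tokens must be >= 1")
    last_start = len(token_ids) - n_future_tokens
    return [
        token_ids[t + 1 : t + 1 + n_future_tokens]
        for t in range(last_start)
    ]


targets = future_token_targets([10, 11, 12, 13, 14, 15], n_future_tokens=4)
```

With `n_future_tokens=1` this reduces to ordinary next-token targets, which makes it easy to compare against the baseline `transformer` configs.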
configs/mtp_transformer_1B.json ADDED
@@ -0,0 +1,23 @@
1
+ {
2
+ "bos_token_id": 1,
3
+ "elementwise_affine": true,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "fuse_swiglu": true,
8
+ "hidden_act": "swish",
9
+ "hidden_ratio": 4,
10
+ "hidden_size": 2048,
11
+ "initializer_range": 0.006,
12
+ "intermediate_size": null,
13
+ "max_position_embeddings": 8192,
14
+ "model_type": "mtp_transformer",
15
+ "norm_eps": 1e-06,
16
+ "num_heads": 32,
17
+ "num_hidden_layers": 32,
18
+ "num_kv_heads": null,
19
+ "pad_token_id": 2,
20
+ "rope_theta": 10000.0,
21
+ "tie_word_embeddings": false,
22
+ "n_future_tokens": 4
23
+ }
configs/mtp_transformer_340M.json ADDED
@@ -0,0 +1,19 @@
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "hidden_act": "swish",
8
+ "hidden_size": 1024,
9
+ "initializer_range": 0.006,
10
+ "max_position_embeddings": 8192,
11
+ "model_type": "mtp_transformer",
12
+ "num_heads": 16,
13
+ "num_hidden_layers": 24,
14
+ "norm_eps": 1e-06,
15
+ "tie_word_embeddings": false,
16
+ "use_cache": true,
17
+ "vocab_size": 32000,
18
+ "n_future_tokens": 4
19
+ }
configs/mtp_transformer_7B.json ADDED
@@ -0,0 +1,22 @@
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "hidden_act": "swish",
8
+ "hidden_ratio": 4,
9
+ "hidden_size": 4096,
10
+ "initializer_range": 0.006,
11
+ "intermediate_size": 14336,
12
+ "model_type": "mtp_transformer",
13
+ "norm_eps": 1e-06,
14
+ "num_heads": 32,
15
+ "num_hidden_layers": 30,
16
+ "num_kv_heads": 8,
17
+ "rope_theta": 10000.0,
18
+ "tie_word_embeddings": false,
19
+ "use_cache": true,
20
+ "window_size": null,
21
+ "n_future_tokens": 4
22
+ }
configs/top_transformer_120M.json ADDED
@@ -0,0 +1,20 @@
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": false,
7
+ "hidden_act": "swish",
8
+ "hidden_size": 768,
9
+ "initializer_range": 0.02,
10
+ "max_position_embeddings": 4096,
11
+ "model_type": "top_transformer",
12
+ "num_heads": 12,
13
+ "num_hidden_layers": 14,
14
+ "norm_eps": 1e-06,
15
+ "tie_word_embeddings": true,
16
+ "use_cache": true,
17
+ "vocab_size": 32000,
18
+ "use_top_loss": true,
19
+ "top_window_size": 2048
20
+ }
configs/top_transformer_1B.json ADDED
@@ -0,0 +1,24 @@
1
+ {
2
+ "bos_token_id": 1,
3
+ "elementwise_affine": true,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "fuse_swiglu": true,
8
+ "hidden_act": "swish",
9
+ "hidden_ratio": 4,
10
+ "hidden_size": 2048,
11
+ "initializer_range": 0.006,
12
+ "intermediate_size": null,
13
+ "max_position_embeddings": 8192,
14
+ "model_type": "top_transformer",
15
+ "norm_eps": 1e-06,
16
+ "num_heads": 32,
17
+ "num_hidden_layers": 32,
18
+ "num_kv_heads": null,
19
+ "pad_token_id": 2,
20
+ "rope_theta": 10000.0,
21
+ "tie_word_embeddings": false,
22
+ "use_top_loss": true,
23
+ "top_window_size": 4096
24
+ }
configs/top_transformer_340M.json ADDED
@@ -0,0 +1,20 @@
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "hidden_act": "swish",
8
+ "hidden_size": 1024,
9
+ "initializer_range": 0.006,
10
+ "max_position_embeddings": 8192,
11
+ "model_type": "top_transformer",
12
+ "num_heads": 16,
13
+ "num_hidden_layers": 24,
14
+ "norm_eps": 1e-06,
15
+ "tie_word_embeddings": false,
16
+ "use_cache": true,
17
+ "vocab_size": 32000,
18
+ "use_top_loss": true,
19
+ "top_window_size": 4096
20
+ }
configs/top_transformer_7B.json ADDED
@@ -0,0 +1,23 @@
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "hidden_act": "swish",
8
+ "hidden_ratio": 4,
9
+ "hidden_size": 4096,
10
+ "initializer_range": 0.006,
11
+ "intermediate_size": 14336,
12
+ "model_type": "top_transformer",
13
+ "norm_eps": 1e-06,
14
+ "num_heads": 32,
15
+ "num_hidden_layers": 30,
16
+ "num_kv_heads": 8,
17
+ "rope_theta": 10000.0,
18
+ "tie_word_embeddings": false,
19
+ "use_cache": true,
20
+ "window_size": null,
21
+ "use_top_loss": true,
22
+ "top_window_size": 4096
23
+ }
configs/transformer_120M.json ADDED
@@ -0,0 +1,18 @@
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": false,
7
+ "hidden_act": "swish",
8
+ "hidden_size": 768,
9
+ "initializer_range": 0.02,
10
+ "max_position_embeddings": 4096,
11
+ "model_type": "transformer",
12
+ "num_heads": 12,
13
+ "num_hidden_layers": 14,
14
+ "norm_eps": 1e-06,
15
+ "tie_word_embeddings": true,
16
+ "use_cache": true,
17
+ "vocab_size": 32000
18
+ }
configs/transformer_1B.json ADDED
@@ -0,0 +1,22 @@
1
+ {
2
+ "bos_token_id": 1,
3
+ "elementwise_affine": true,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "fuse_swiglu": true,
8
+ "hidden_act": "swish",
9
+ "hidden_ratio": 4,
10
+ "hidden_size": 2048,
11
+ "initializer_range": 0.006,
12
+ "intermediate_size": null,
13
+ "max_position_embeddings": 8192,
14
+ "model_type": "transformer",
15
+ "norm_eps": 1e-06,
16
+ "num_heads": 32,
17
+ "num_hidden_layers": 32,
18
+ "num_kv_heads": null,
19
+ "pad_token_id": 2,
20
+ "rope_theta": 10000.0,
21
+ "tie_word_embeddings": false
22
+ }
configs/transformer_340M.json ADDED
@@ -0,0 +1,18 @@
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "hidden_act": "swish",
8
+ "hidden_size": 1024,
9
+ "initializer_range": 0.006,
10
+ "max_position_embeddings": 8192,
11
+ "model_type": "transformer",
12
+ "num_heads": 16,
13
+ "num_hidden_layers": 24,
14
+ "norm_eps": 1e-06,
15
+ "tie_word_embeddings": false,
16
+ "use_cache": true,
17
+ "vocab_size": 32000
18
+ }
configs/transformer_7B.json ADDED
@@ -0,0 +1,21 @@
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "hidden_act": "swish",
8
+ "hidden_ratio": 4,
9
+ "hidden_size": 4096,
10
+ "initializer_range": 0.006,
11
+ "intermediate_size": 14336,
12
+ "model_type": "transformer",
13
+ "norm_eps": 1e-06,
14
+ "num_heads": 32,
15
+ "num_hidden_layers": 30,
16
+ "num_kv_heads": 8,
17
+ "rope_theta": 10000.0,
18
+ "tie_word_embeddings": false,
19
+ "use_cache": true,
20
+ "window_size": null
21
+ }
fla/__init__.py ADDED
@@ -0,0 +1,110 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from fla.layers import (
4
+ ABCAttention,
5
+ Attention,
6
+ BasedLinearAttention,
7
+ BitAttention,
8
+ DeltaNet,
9
+ GatedDeltaNet,
10
+ GatedDeltaProduct,
11
+ GatedLinearAttention,
12
+ GatedSlotAttention,
13
+ HGRN2Attention,
14
+ HGRNAttention,
15
+ LightNetAttention,
16
+ LinearAttention,
17
+ MultiScaleRetention,
18
+ NativeSparseAttention,
19
+ ReBasedLinearAttention,
20
+ RWKV6Attention,
21
+ RWKV7Attention
22
+ )
23
+ from fla.models import (
24
+ ABCForCausalLM,
25
+ ABCModel,
26
+ BitNetForCausalLM,
27
+ BitNetModel,
28
+ DeltaNetForCausalLM,
29
+ DeltaNetModel,
30
+ GatedDeltaNetForCausalLM,
31
+ GatedDeltaNetModel,
32
+ GatedDeltaProductForCausalLM,
33
+ GatedDeltaProductModel,
34
+ GLAForCausalLM,
35
+ GLAModel,
36
+ GSAForCausalLM,
37
+ GSAModel,
38
+ HGRN2ForCausalLM,
39
+ HGRN2Model,
40
+ HGRNForCausalLM,
41
+ LightNetForCausalLM,
42
+ LightNetModel,
43
+ LinearAttentionForCausalLM,
44
+ LinearAttentionModel,
45
+ NSAForCausalLM,
46
+ NSAModel,
47
+ RetNetForCausalLM,
48
+ RetNetModel,
49
+ RWKV6ForCausalLM,
50
+ RWKV6Model,
51
+ RWKV7ForCausalLM,
52
+ RWKV7Model,
53
+ TransformerForCausalLM,
54
+ TransformerModel
55
+ )
56
+
57
+ __all__ = [
58
+ 'ABCAttention',
59
+ 'Attention',
60
+ 'BasedLinearAttention',
61
+ 'BitAttention',
62
+ 'DeltaNet',
63
+ 'GatedDeltaNet',
64
+ 'GatedDeltaProduct',
65
+ 'GatedLinearAttention',
66
+ 'GatedSlotAttention',
67
+ 'HGRNAttention',
68
+ 'HGRN2Attention',
69
+ 'LightNetAttention',
70
+ 'LinearAttention',
71
+ 'MultiScaleRetention',
72
+ 'NativeSparseAttention',
73
+ 'ReBasedLinearAttention',
74
+ 'RWKV6Attention',
75
+ 'RWKV7Attention',
76
+ 'ABCForCausalLM',
77
+ 'ABCModel',
78
+ 'BitNetForCausalLM',
79
+ 'BitNetModel',
80
+ 'DeltaNetForCausalLM',
81
+ 'DeltaNetModel',
82
+ 'GatedDeltaNetForCausalLM',
83
+ 'GatedDeltaNetModel',
84
+ 'GatedDeltaProductForCausalLM',
85
+ 'GatedDeltaProductModel',
86
+ 'GLAForCausalLM',
87
+ 'GLAModel',
88
+ 'GSAForCausalLM',
89
+ 'GSAModel',
90
+ 'HGRNForCausalLM',
91
+ 'HGRNModel',
92
+ 'HGRN2ForCausalLM',
93
+ 'HGRN2Model',
94
+ 'LightNetForCausalLM',
95
+ 'LightNetModel',
96
+ 'LinearAttentionForCausalLM',
97
+ 'LinearAttentionModel',
98
+ 'NSAForCausalLM',
99
+ 'NSAModel',
100
+ 'RetNetForCausalLM',
101
+ 'RetNetModel',
102
+ 'RWKV6ForCausalLM',
103
+ 'RWKV6Model',
104
+ 'RWKV7ForCausalLM',
105
+ 'RWKV7Model',
106
+ 'TransformerForCausalLM',
107
+ 'TransformerModel',
108
+ ]
109
+
110
+ __version__ = '0.1.2'
fla/layers/__pycache__/gla.cpython-312.pyc ADDED
Binary file (13.3 kB). View file
 
fla/layers/__pycache__/hgrn.cpython-312.pyc ADDED
Binary file (6.7 kB). View file
 
fla/layers/__pycache__/multiscale_retention.cpython-312.pyc ADDED
Binary file (12.4 kB). View file
 
fla/layers/rwkv6.py ADDED
@@ -0,0 +1,307 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ # "Eagle and Finch: RWKV with Matrix-Valued States and Dynamic Recurrence"[https://arxiv.org/abs/2404.05892]
5
+
6
+ from __future__ import annotations
7
+
8
+ from typing import TYPE_CHECKING, Optional, Tuple
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from einops import rearrange
13
+
14
+ from fla.modules import GroupNorm
15
+ from fla.modules.activations import ACT2FN
16
+ from fla.ops.rwkv6 import chunk_rwkv6, fused_recurrent_rwkv6
17
+
18
+ if TYPE_CHECKING:
19
+ from fla.models.utils import Cache
20
+
21
+
22
+ class RWKV6Attention(nn.Module):
23
+
24
+ def __init__(
25
+ self,
26
+ mode: str = 'chunk',
27
+ hidden_size: int = 1024,
28
+ expand_k: float = 0.5,
29
+ expand_v: float = 1.0,
30
+ num_heads: int = 4,
31
+ gate_fn: str = 'swish',
32
+ proj_low_rank_dim: int = 32,
33
+ gate_low_rank_dim: int = 64,
34
+ fuse_norm: bool = True,
35
+ elementwise_affine: Optional[bool] = True,
36
+ norm_eps: float = 1e-5,
37
+ layer_idx: Optional[int] = None,
38
+ **kwargs
39
+ ) -> None:
40
+ super().__init__()
41
+
42
+ self.mode = mode
43
+ self.hidden_size = hidden_size
44
+ self.expand_k = expand_k
45
+ self.expand_v = expand_v
46
+ self.num_heads = num_heads
47
+ self.proj_low_rank_dim = proj_low_rank_dim
48
+ self.gate_low_rank_dim = gate_low_rank_dim
49
+
50
+ self.key_dim = int(hidden_size * expand_k)
51
+ self.value_dim = int(hidden_size * expand_v)
52
+ self.layer_idx = layer_idx
53
+
54
+ assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
55
+ assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
56
+ assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
57
+
58
+ self.head_k_dim = self.key_dim // num_heads
59
+ self.head_v_dim = self.value_dim // num_heads
60
+
61
+ self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
62
+ self.x_proj = nn.Sequential(
63
+ LerpLinear(hidden_size, proj_low_rank_dim * 5),
64
+ nn.Tanh(),
65
+ nn.Linear(proj_low_rank_dim * 5, hidden_size, bias=False)
66
+ )
67
+ self.x_bias = nn.Parameter(torch.zeros(5, hidden_size))
68
+
69
+ self.r_proj = DDLerpLinear(hidden_size, self.key_dim)
70
+ self.w_proj = DDLerpLinear(hidden_size, self.key_dim, low_rank_dim=gate_low_rank_dim)
71
+ self.k_proj = DDLerpLinear(hidden_size, self.key_dim)
72
+ self.v_proj = DDLerpLinear(hidden_size, self.value_dim)
73
+ self.g_proj = DDLerpLinear(hidden_size, self.value_dim)
74
+ self.bonus = nn.Parameter(torch.zeros(num_heads, self.head_k_dim))
75
+
76
+ # TODO: fuse GroupNorm and output gate
77
+ self.g_norm = GroupNorm(self.num_heads, self.value_dim, elementwise_affine=elementwise_affine, bias=True, eps=norm_eps)
78
+ self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
79
+ self.gate_fn = ACT2FN[gate_fn]
80
+
81
+ self.apply(self._initialize_weights)
82
+
83
+ def _initialize_weights(self, module: nn.Module):
84
+ if getattr(module, "_is_hf_initialized", False):
85
+ return
86
+ if isinstance(module, nn.Linear):
87
+ nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
88
+ if module.bias is not None:
89
+ nn.init.zeros_(module.bias)
90
+ if isinstance(module, nn.Parameter):
91
+ nn.init.xavier_uniform_(module, gain=2 ** -2.5)
92
+ module._is_hf_initialized = True
93
+
94
+ def forward(
95
+ self,
96
+ hidden_states: torch.Tensor,
97
+ attention_mask: Optional[torch.Tensor] = None,
98
+ past_key_values: Optional[Cache] = None,
99
+ use_cache: Optional[bool] = False,
100
+ output_attentions: Optional[bool] = False,
101
+ **kwargs
102
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
103
+ if attention_mask is not None:
104
+ assert len(attention_mask.shape) == 2, (
105
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
106
+ "for padding purposes (0 indicating padding). "
107
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
108
+ )
109
+
110
+ batch_size, seq_len, hidden_size = hidden_states.shape
111
+ # for short inputs (at most 64 tokens), fall back to the fused recurrent kernel; the chunk kernel is slower there
112
+ mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
113
+
114
+ last_state = None
115
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
116
+ last_state = past_key_values[self.layer_idx]
117
+
118
+ if attention_mask is not None:
119
+ hidden_states = hidden_states.mul_(attention_mask[:, -hidden_states.shape[-2]:, None])
120
+ if hidden_states.shape[1] == 1 and last_state is not None:
121
+ shifted = last_state['conv_state'].unsqueeze(1)
122
+ else:
123
+ shifted = self.time_shift(hidden_states)
124
+ if last_state is not None:
125
+ shifted[:, 0] = last_state['conv_state']
126
+
127
+ delta = shifted - hidden_states
128
+ x = self.x_proj[0](hidden_states, delta).view(batch_size, seq_len, -1, self.proj_low_rank_dim)
129
+ x = torch.einsum('b t n r, h n r-> b t n h', self.x_proj[1](x), self.x_proj[2].weight.view(hidden_size, 5, -1))
130
+
131
+ r, w, k, v, g = x.add_(self.x_bias).unbind(-2)
132
+ r = self.r_proj(hidden_states, r, delta)
133
+ w = self.w_proj(hidden_states, w, delta)
134
+ k = self.k_proj(hidden_states, k, delta)
135
+ v = self.v_proj(hidden_states, v, delta)
136
+ g = self.g_proj(hidden_states, g, delta)
137
+
138
+ # dealing with left-padding
139
+ if attention_mask is not None:
140
+ v = v.mul_(attention_mask[:, -v.shape[-2]:, None])
141
+ r, w, k = map(lambda x: rearrange(x, 'b t (h d) -> b t h d', d=self.head_k_dim), (r, w, k))
142
+ v = rearrange(v, 'b t (h d) -> b t h d', d=self.head_v_dim)
143
+ w = -torch.exp(w)
144
+ u = self.bonus
145
+
146
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
147
+ cu_seqlens = kwargs.get('cu_seqlens', None)
148
+ if mode == 'fused_recurrent':
149
+ o, recurrent_state = fused_recurrent_rwkv6(
150
+ r=r,
151
+ k=k,
152
+ v=v,
153
+ w=w,
154
+ u=u,
155
+ scale=1.,
156
+ initial_state=recurrent_state,
157
+ output_final_state=use_cache,
158
+ cu_seqlens=cu_seqlens,
159
+ head_first=False
160
+ )
161
+ elif mode == 'chunk':
162
+ o, recurrent_state = chunk_rwkv6(
163
+ q=r,
164
+ k=k,
165
+ v=v,
166
+ g=w,
167
+ u=u,
168
+ scale=1.,
169
+ initial_state=recurrent_state,
170
+ output_final_state=use_cache,
171
+ cu_seqlens=cu_seqlens,
172
+ head_first=False
173
+ )
174
+ else:
175
+ raise NotImplementedError(f"Not supported mode `{mode}`.")
176
+
177
+ if past_key_values is not None:
178
+ past_key_values.update(
179
+ recurrent_state=recurrent_state,
180
+ conv_state=hidden_states[:, -1],
181
+ layer_idx=self.layer_idx,
182
+ offset=seq_len
183
+ )
184
+
185
+ o = self.g_norm(rearrange(o, '... h d -> ... (h d)')) * self.gate_fn(g)
186
+ o = self.o_proj(o)
187
+
188
+ return o, None, past_key_values
189
+
190
+
191
+ class LoRA(nn.Module):
192
+
193
+ def __init__(
194
+ self,
195
+ input_dim: int,
196
+ output_dim: int,
197
+ low_rank_dim: int,
198
+ bias: Optional[bool] = True,
199
+ activation: Optional[str] = 'tanh'
200
+ ):
201
+ super().__init__()
202
+
203
+ self.input_dim = input_dim
204
+ self.output_dim = output_dim
205
+ self.low_rank_dim = low_rank_dim
206
+ self.bias = bias
207
+
208
+ if activation is None:
209
+ self.activation = nn.Identity()
210
+ elif activation == 'sigmoid':
211
+ self.activation = nn.Sigmoid()
212
+ elif activation == 'tanh':
213
+ self.activation = nn.Tanh()
214
+ elif activation == 'relu':
215
+ self.activation = nn.ReLU()
216
+ else:
217
+ raise ValueError(f"Not supported activation `{activation}`.")
218
+
219
+ self.lora = nn.Sequential(
220
+ nn.Linear(input_dim, low_rank_dim, bias=False),
221
+ self.activation,
222
+ nn.Linear(low_rank_dim, output_dim, bias=bias)
223
+ )
224
+
225
+ def __repr__(self) -> str:
226
+ s = f"{self.__class__.__name__}("
227
+ s += f"input_dim={self.input_dim}, low_rank_dim={self.low_rank_dim}, output_dim={self.output_dim}"
228
+ if not self.bias:
229
+ s += f", bias={self.bias}"
230
+ s += ")"
231
+ return s
232
+
233
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
234
+ return self.lora(x)
235
+
236
+
237
+ class LerpLinear(nn.Module):
238
+
239
+ def __init__(
240
+ self,
241
+ input_dim: int,
242
+ output_dim: int,
243
+ low_rank_dim: Optional[int] = None
244
+ ):
245
+ super().__init__()
246
+
247
+ self.input_dim = input_dim
248
+ self.output_dim = output_dim
249
+ self.low_rank_dim = low_rank_dim
250
+
251
+ self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
252
+ if low_rank_dim is None:
253
+ self.linear = nn.Linear(input_dim, output_dim, bias=False)
254
+ else:
255
+ self.linear = LoRA(input_dim, output_dim, low_rank_dim)
256
+ self.mu = nn.Parameter(torch.zeros(input_dim))
257
+
258
+ def __repr__(self) -> str:
259
+ s = f"{self.__class__.__name__}({self.input_dim}, {self.output_dim}"
260
+ if self.low_rank_dim is not None:
261
+ s += f", low_rank_dim={self.low_rank_dim}"
262
+ s += ")"
263
+ return s
264
+
265
+ def forward(self, x: torch.Tensor, delta: Optional[torch.Tensor] = None) -> torch.Tensor:
266
+ if delta is None:
267
+ shifted = self.time_shift(x)
268
+ if len(shifted.shape) == 2:
269
+ shifted = shifted.unsqueeze(1)
270
+ delta = shifted - x
271
+ return self.linear(x + delta * self.mu)
272
+
273
+
274
+ class DDLerpLinear(nn.Module):
275
+
276
+ def __init__(
277
+ self,
278
+ input_dim: int,
279
+ output_dim: int,
280
+ low_rank_dim: Optional[int] = None
281
+ ):
282
+ super().__init__()
283
+
284
+ self.input_dim = input_dim
285
+ self.output_dim = output_dim
286
+ self.low_rank_dim = low_rank_dim
287
+
288
+ self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
289
+ if low_rank_dim is None:
290
+ self.linear = nn.Linear(input_dim, output_dim, bias=False)
291
+ else:
292
+ self.linear = LoRA(input_dim, output_dim, low_rank_dim)
293
+
294
+ def __repr__(self) -> str:
295
+ s = f"{self.__class__.__name__}({self.input_dim}, {self.output_dim}"
296
+ if self.low_rank_dim is not None:
297
+ s += f", low_rank_dim={self.low_rank_dim}"
298
+ s += ")"
299
+ return s
300
+
301
+ def forward(self, x: torch.Tensor, mu: torch.Tensor, delta: Optional[torch.Tensor] = None) -> torch.Tensor:
302
+ if delta is None:
303
+ shifted = self.time_shift(x)
304
+ if len(shifted.shape) == 2:
305
+ shifted = shifted.unsqueeze(1)
306
+ delta = shifted - x
307
+ return self.linear(x + delta * mu)
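The token-shift interpolation used by `LerpLinear` and `DDLerpLinear` above can be sketched in plain Python, assuming a single sequence stored as a list of per-step vectors. The helper names `time_shift` and `lerp_inputs` are hypothetical and exist only for this illustration; the layers in the diff operate on batched tensors and pass the result through `self.linear`. The sketch mirrors `nn.ZeroPad2d((0, 0, 1, -1))`, which shifts the sequence one step along time and zero-fills the first step.

```python
from typing import List

Vector = List[float]


def time_shift(xs: List[Vector]) -> List[Vector]:
    """Shift a sequence one step to the right along time, zero-filling step 0."""
    if not xs:
        return []
    zero = [0.0] * len(xs[0])
    return [zero] + [list(x) for x in xs[:-1]]


def lerp_inputs(xs: List[Vector], mu: Vector) -> List[Vector]:
    """Return x_t + (x_{t-1} - x_t) * mu for every time step t."""
    shifted = time_shift(xs)
    return [
        [x + (s - x) * m for x, s, m in zip(x_t, s_t, mu)]
        for x_t, s_t in zip(xs, shifted)
    ]


xs = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
# mu = 0 keeps the current token, mu = 1 takes the previous token.
print(lerp_inputs(xs, mu=[0.0, 1.0]))
```

In `LerpLinear` the mixing vector is the learned parameter `self.mu`, whereas `DDLerpLinear` receives `mu` from the caller on every forward pass, which is what makes the interpolation data-dependent.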
fla/ops/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (1.9 kB).
fla/ops/abc/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .chunk import chunk_abc
4
+
5
+ __all__ = [
6
+ 'chunk_abc'
7
+ ]
fla/ops/abc/chunk.py ADDED
@@ -0,0 +1,1116 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.utils import logcumsumexp_fwd_kernel, softmax_bwd, softmax_fwd
11
+ from fla.ops.utils.op import exp
12
+ from fla.utils import input_guard
13
+
14
+
15
+ @triton.jit(do_not_specialize=['T'])
16
+ def chunk_abc_fwd_kernel_h(
17
+ k,
18
+ v,
19
+ z,
20
+ h,
21
+ h0,
22
+ ht,
23
+ T,
24
+ K: tl.constexpr,
25
+ V: tl.constexpr,
26
+ BT: tl.constexpr,
27
+ BK: tl.constexpr,
28
+ BV: tl.constexpr,
29
+ NT: tl.constexpr,
30
+ NORMK: tl.constexpr,
31
+ USE_INITIAL_STATE: tl.constexpr,
32
+ STORE_FINAL_STATE: tl.constexpr
33
+ ):
34
+ i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
35
+
36
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
37
+ if USE_INITIAL_STATE:
38
+ p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
39
+ b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
40
+ if NORMK:
41
+ p_z0 = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), (i_k * BK,), (BK,), (0,))
42
+ else:
43
+ p_z0 = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), (i_v * BV,), (BV,), (0,))
44
+ b_zp = tl.load(p_z0).to(tl.float32)
45
+ for i_t in range(NT):
46
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
47
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
48
+ p_h = tl.make_block_ptr(h + i_bh * NT*K*V + i_t * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
49
+
50
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
51
+ # [BK, BT]
52
+ b_k = tl.load(p_k, boundary_check=(0, 1))
53
+ # [BT, BV]
54
+ b_v = tl.load(p_v, boundary_check=(0, 1))
55
+ if NORMK:
56
+ p_zc = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,))
57
+ # [BK,]
58
+ b_zc = tl.load(p_zc, boundary_check=(0,))
59
+ b_r, b_zp = exp(b_zp - b_zc), b_zc
60
+ # [BK, BV]
61
+ b_h = b_h * b_r[:, None]
62
+ b_k = exp(b_k - b_zc[:, None]).to(b_k.dtype)
63
+ else:
64
+ p_zc = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), ((i_t * BT + BT - 1) * V + i_v * BV,), (BV,), (0,))
65
+ # [BV,]
66
+ b_zc = tl.load(p_zc, boundary_check=(0,))
67
+ b_r, b_zp = exp(b_zp - b_zc), b_zc
68
+ # [BK, BV]
69
+ b_h = b_h * b_r[None, :]
70
+ b_v = exp(b_v - b_zc[None, :]).to(b_v.dtype)
71
+ # [BK, BV]
72
+ b_h += tl.dot(b_k, b_v, allow_tf32=False)
73
+
74
+ if STORE_FINAL_STATE:
75
+ p_h = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
76
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
77
+
78
+
79
+ @triton.jit(do_not_specialize=['T'])
80
+ def chunk_abc_fwd_kernel_intra_K(
81
+ v,
82
+ z,
83
+ o,
84
+ A,
85
+ T,
86
+ V: tl.constexpr,
87
+ BT: tl.constexpr,
88
+ BC: tl.constexpr,
89
+ BV: tl.constexpr,
90
+ NC: tl.constexpr
91
+ ):
92
+ i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
93
+ i_t, i_i = i_c // NC, i_c % NC
94
+
95
+ p_z = tl.make_block_ptr(z + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
96
+ p_zn = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), ((i_t * BT + i_i * BC) * V + i_v * BV,), (BV,), (0,))
97
+ # [BV,]
98
+ b_zn = tl.load(p_zn, boundary_check=(0,))
99
+ # [BC, BV]
100
+ b_o = tl.zeros([BC, BV], dtype=tl.float32)
101
+ for i_j in range(0, i_i):
102
+ p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
103
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
104
+ # [BC, BV]
105
+ b_v = tl.load(p_v, boundary_check=(0, 1))
106
+ # [BC, BC]
107
+ b_A = tl.load(p_A, boundary_check=(0, 1))
108
+ b_o += tl.dot(b_A, exp(b_v - b_zn[None, :]).to(b_v.dtype), allow_tf32=False)
109
+ b_z = tl.load(p_z, boundary_check=(0, 1))
110
+ b_o *= exp(b_zn[None, :] - b_z)
111
+
112
+ o_i = tl.arange(0, BC)
113
+ o_A = i_bh * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC
114
+ m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
115
+ for j in range(0, BC):
116
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T * V,), (1,), ((i_t * BT + i_i * BC + j) * V + i_v * BV,), (BV,), (0,))
117
+ # [BC,]
118
+ b_A = tl.load(A + o_A + j, mask=m_A, other=0)
119
+ # [BV,]
120
+ b_v = tl.load(p_v, boundary_check=(0,)).to(tl.float32)
121
+ # [BC, BV]
122
+ # avoid 0 * inf = nan
123
+ m_i = o_i[:, None] >= j
124
+ b_o += tl.where(m_i, b_A[:, None] * exp(b_v[None, :] - b_z), 0)
125
+ p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
126
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
127
+
128
+
129
+ @triton.jit(do_not_specialize=['T'])
130
+ def chunk_abc_fwd_kernel_K(
131
+ q,
132
+ k,
133
+ z,
134
+ h,
135
+ o,
136
+ A,
137
+ scale,
138
+ T,
139
+ K: tl.constexpr,
140
+ V: tl.constexpr,
141
+ BT: tl.constexpr,
142
+ BK: tl.constexpr,
143
+ BV: tl.constexpr,
144
+ NT: tl.constexpr
145
+ ):
146
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
147
+ i_p = tl.maximum(i_t * BT - 1, 0)
148
+
149
+ o_i = tl.arange(0, BT)
150
+ m_s = o_i[:, None] >= o_i[None, :]
151
+
152
+ b_o = tl.zeros([BT, BV], dtype=tl.float32)
153
+ b_A = tl.zeros([BT, BT], dtype=tl.float32)
154
+ for i_k in range(tl.cdiv(K, BK)):
155
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
156
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
157
+ p_h = tl.make_block_ptr(h + i_bh * NT*K*V + i_t * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
158
+
159
+ # [BT, BK]
160
+ b_q = tl.load(p_q, boundary_check=(0, 1))
161
+ b_q = (b_q * scale).to(b_q.dtype)
162
+ # [BK, BT]
163
+ b_k = tl.load(p_k, boundary_check=(0, 1))
164
+ # [BK, BV]
165
+ b_h = tl.load(p_h, boundary_check=(0, 1))
166
+ # [BT, BV]
167
+ b_o += tl.dot(b_q, b_h, allow_tf32=False)
168
+ # [BT, BT]
169
+ b_A += tl.dot(b_q, b_k, allow_tf32=False)
170
+ p_z = tl.make_block_ptr(z + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
171
+ p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
172
+ # [BT, BV]
173
+ b_z = tl.load(p_z, boundary_check=(0, 1))
174
+ # [BV,]
175
+ p_zp = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), (i_p * V + i_v * BV,), (BV,), (0,))
176
+ b_zp = tl.load(p_zp, boundary_check=(0,))
177
+ b_o = b_o * exp(b_zp[None, :] - b_z)
178
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
179
+
180
+ p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
181
+ # [BT, BT]
182
+ b_A = tl.where(m_s, b_A, 0.)
183
+ if i_v == 0:
184
+ tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
185
+
186
+
187
+ @triton.jit(do_not_specialize=['T'])
188
+ def chunk_abc_fwd_kernel_intra_V(
189
+ q,
190
+ k,
191
+ z,
192
+ A,
193
+ scale,
194
+ T,
195
+ K: tl.constexpr,
196
+ BT: tl.constexpr,
197
+ BC: tl.constexpr,
198
+ BK: tl.constexpr,
199
+ NC: tl.constexpr
200
+ ):
201
+ i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
202
+ i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC
203
+ n_bh = tl.num_programs(2)
204
+
205
+ if i_i > i_j:
206
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
207
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
208
+ p_z = tl.make_block_ptr(z + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
209
+ p_A = tl.make_block_ptr(A + (i_k*n_bh+i_bh)*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
210
+ p_zn = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), ((i_t * BT + i_i * BC) * K + i_k * BK,), (BK,), (0,))
211
+ # [BK,]
212
+ b_zn = tl.load(p_zn, boundary_check=(0,))
213
+ # [BC, BK]
214
+ b_q = tl.load(p_q, boundary_check=(0, 1))
215
+ b_z = tl.load(p_z, boundary_check=(0, 1))
216
+ b_q = (b_q * exp(b_zn[None, :] - b_z) * scale).to(b_q.dtype)
217
+ # [BK, BC]
218
+ b_k = tl.load(p_k, boundary_check=(0, 1))
219
+ b_k = exp(b_k - b_zn[:, None]).to(b_k.dtype)
220
+ # [BC, BC]
221
+ b_A = tl.dot(b_q, b_k, allow_tf32=False)
222
+ tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1))
223
+ elif i_i == i_j:
224
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
225
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T * K,), (1,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,))
226
+ p_z = tl.make_block_ptr(z + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
227
+ # [BC, BK]
228
+ b_q = tl.load(p_q, boundary_check=(0, 1))
229
+ b_z = tl.load(p_z, boundary_check=(0, 1))
230
+
231
+ o_i = tl.arange(0, BC)
232
+ o_A = (i_bh + i_k * n_bh) * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC
233
+ m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
234
+ for j in range(0, BC):
235
+ # [BK,]
236
+ b_k = tl.load(p_k, boundary_check=(0,)).to(tl.float32)
237
+ # [BC,]
238
+ b_A = tl.sum(b_q * exp(b_k[None, :] - b_z) * scale, 1)
239
+ b_A = tl.where(o_i >= j, b_A, 0.)
240
+ tl.store(A + o_A + j, b_A.to(b_q.dtype), mask=m_A)
241
+
242
+ p_k = tl.advance(p_k, (K,))
243
+
244
+
245
+ @triton.jit(do_not_specialize=['T'])
246
+ def chunk_abc_fwd_kernel_V(
247
+ q,
248
+ v,
249
+ z,
250
+ h,
251
+ o,
252
+ A,
253
+ scale,
254
+ T,
255
+ K: tl.constexpr,
256
+ V: tl.constexpr,
257
+ BT: tl.constexpr,
258
+ BK: tl.constexpr,
259
+ BV: tl.constexpr,
260
+ NT: tl.constexpr
261
+ ):
262
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
263
+ i_p = tl.maximum(i_t * BT - 1, 0)
264
+
265
+ b_o = tl.zeros([BT, BV], dtype=tl.float32)
266
+ for i_k in range(tl.cdiv(K, BK)):
267
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
268
+ p_z = tl.make_block_ptr(z + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
269
+ p_h = tl.make_block_ptr(h + i_bh * NT*K*V + i_t * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
270
+ p_zp = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), (i_p * K + i_k * BK,), (BK,), (0,))
271
+
272
+ # [BT, BK]
273
+ b_q = tl.load(p_q, boundary_check=(0, 1))
274
+ b_q = (b_q * scale).to(b_q.dtype)
275
+ # [BT, BK]
276
+ b_z = tl.load(p_z, boundary_check=(0, 1))
277
+ # [BK,]
278
+ b_zp = tl.load(p_zp, boundary_check=(0,))
279
+ b_q = (b_q * exp(b_zp[None, :] - b_z)).to(b_q.dtype)
280
+ # [BK, BV]
281
+ b_h = tl.load(p_h, boundary_check=(0, 1))
282
+ # NOTE: the always-true guard below is intentional: the kernel works with it, though it is not known why
283
+ # [BT, BV]
284
+ if i_k >= 0:
285
+ b_o += tl.dot(b_q, b_h, allow_tf32=False)
286
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
287
+ p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
288
+ p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
289
+ # [BT, BV]
290
+ b_v = tl.load(p_v, boundary_check=(0, 1))
291
+ # [BT, BT]
292
+ b_A = tl.load(p_A, boundary_check=(0, 1))
293
+ b_o += tl.dot(b_A.to(b_v.dtype), b_v, allow_tf32=False)
294
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
295
+
296
+
297
+ @triton.jit(do_not_specialize=['T'])
298
+ def chunk_abc_bwd_kernel_dh(
299
+ q,
300
+ z,
301
+ do,
302
+ dh,
303
+ scale,
304
+ T,
305
+ K: tl.constexpr,
306
+ V: tl.constexpr,
307
+ BT: tl.constexpr,
308
+ BK: tl.constexpr,
309
+ BV: tl.constexpr,
310
+ NT: tl.constexpr,
311
+ NORMK: tl.constexpr
312
+ ):
313
+ i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
314
+
315
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
316
+ b_zp = tl.full([BK if NORMK else BV], float('inf'), dtype=tl.float32)
317
+ for i_t in range(NT - 1, -1, -1):
318
+ i_p = tl.maximum(i_t * BT - 1, 0)
319
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
320
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
321
+ p_dh = tl.make_block_ptr(dh + i_bh * NT*K*V + i_t * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
322
+
323
+ # [BK, BT]
324
+ b_q = tl.load(p_q, boundary_check=(0, 1))
325
+ b_q = (b_q * scale).to(b_q.dtype)
326
+ # [BT, BV]
327
+ b_do = tl.load(p_do, boundary_check=(0, 1))
328
+
329
+ tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
330
+ if NORMK:
331
+ p_z = tl.make_block_ptr(z + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
332
+ p_zc = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), (i_p * K + i_k * BK,), (BK,), (0,))
333
+ # [BK,]
334
+ b_zc = tl.load(p_zc, boundary_check=(0,))
335
+ b_r, b_zp = exp(b_zc - b_zp), b_zc
336
+ # [BK, BT]
337
+ b_z = tl.load(p_z, boundary_check=(0, 1))
338
+ b_q = (b_q * exp(b_zc[:, None] - b_z)).to(b_q.dtype)
339
+ # [BK, BV]
340
+ b_dh = b_dh * b_r[:, None]
341
+ else:
342
+ p_z = tl.make_block_ptr(z + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
343
+ p_zc = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), (i_p * V + i_v * BV,), (BV,), (0,))
344
+ # [BV,]
345
+ b_zc = tl.load(p_zc, boundary_check=(0,))
346
+ b_r, b_zp = exp(b_zc - b_zp), b_zc
347
+ # [BT, BV]
348
+ b_z = tl.load(p_z, boundary_check=(0, 1))
349
+ b_do = (b_do * exp(b_zc[None, :] - b_z)).to(b_do.dtype)
350
+ # [BK, BV]
351
+ b_dh = b_dh * b_r[None, :]
352
+ # [BK, BV]
353
+ b_dh += tl.dot(b_q, b_do, allow_tf32=False)
354
+
355
+
356
+ @triton.jit(do_not_specialize=['T'])
357
+ def chunk_abc_bwd_kernel_V(
358
+ k,
359
+ v,
360
+ z,
361
+ h,
362
+ A,
363
+ do,
364
+ dh,
365
+ dq,
366
+ dk,
367
+ dv,
368
+ dA,
369
+ scale,
370
+ T,
371
+ K: tl.constexpr,
372
+ V: tl.constexpr,
373
+ BT: tl.constexpr,
374
+ BK: tl.constexpr,
375
+ BV: tl.constexpr,
376
+ NT: tl.constexpr
377
+ ):
378
+ i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
379
+ i_p = tl.maximum(i_t * BT - 1, 0)
380
+ n_bh = tl.num_programs(2)
381
+
382
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
383
+ p_zc = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,))
384
+ p_A = tl.make_block_ptr(A + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1))
385
+
386
+ # [BK,]
387
+ b_zc = tl.load(p_zc, boundary_check=(0,))
388
+ # [BT, BK]
389
+ b_k = tl.load(p_k, boundary_check=(0, 1))
390
+ b_k = exp(b_k - b_zc[None, :]).to(b_k.dtype)
391
+ # [BT, BT]
392
+ b_A = tl.load(p_A, boundary_check=(0, 1))
393
+
394
+ b_dq = tl.zeros([BT, BK], dtype=tl.float32)
395
+ b_dk = tl.zeros([BT, BK], dtype=tl.float32)
396
+ b_dA = tl.zeros([BT, BT], dtype=tl.float32)
397
+ for i_v in range(tl.cdiv(V, BV)):
398
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
399
+ p_h = tl.make_block_ptr(h + i_bh * NT*K*V + i_t * V * K, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
400
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
401
+ p_dh = tl.make_block_ptr(dh + i_bh * NT*K*V + i_t * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
402
+ p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh) * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
403
+
404
+ # [BT, BV]
405
+ b_v = tl.load(p_v, boundary_check=(0, 1))
406
+ # [BV, BK]
407
+ b_h = tl.load(p_h, boundary_check=(0, 1))
408
+ # [BT, BV]
409
+ b_do = tl.load(p_do, boundary_check=(0, 1))
410
+ # [BK, BV]
411
+ b_dh = tl.load(p_dh, boundary_check=(0, 1))
412
+
413
+ # [BT, BV]
414
+ b_dv = tl.dot(b_k, b_dh, allow_tf32=False)
415
+ if i_k == 0:
416
+ b_dv += tl.dot(b_A.to(b_do.dtype), b_do, allow_tf32=False)
417
+ b_do = (b_do * scale).to(b_do.dtype)
418
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
419
+ # [BT, BT]
420
+ b_dA += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)
421
+ # [BT, BK]
422
+ b_dq += tl.dot(b_do, b_h, allow_tf32=False)
423
+ # [BT, BK]
424
+ b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)
425
+ p_z = tl.make_block_ptr(z + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
426
+ p_zp = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), (i_p * K + i_k * BK,), (BK,), (0,))
427
+ # [BK,]
428
+ b_zp = tl.load(p_zp, boundary_check=(0,))
429
+ # [BT, BK]
430
+ b_z = tl.load(p_z, boundary_check=(0, 1))
431
+ b_z = exp(b_zp[None, :] - b_z)
432
+ # [BT, BK]
433
+ b_dq = b_dq * b_z
434
+ b_dk = b_dk * b_k
435
+
436
+ p_dq = tl.make_block_ptr(dq + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
437
+ p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
438
+ p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT,), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
439
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
440
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
441
+
442
+ o_i = tl.arange(0, BT)
443
+ m_s = o_i[:, None] >= o_i[None, :]
444
+ # [BT, BT]
445
+ b_dA = tl.where(m_s, b_dA, 0.).to(b_k.dtype)
446
+ if i_k == 0:
447
+ tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1))
448
+
449
+
450
+ @triton.jit(do_not_specialize=['T'])
451
+ def chunk_abc_bwd_kernel_intra_V(
452
+ q,
453
+ k,
454
+ z,
455
+ dA,
456
+ dq,
457
+ dk,
458
+ T,
459
+ K: tl.constexpr,
460
+ BT: tl.constexpr,
461
+ BC: tl.constexpr,
462
+ BK: tl.constexpr,
463
+ NC: tl.constexpr
464
+ ):
465
+ i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
466
+ i_t, i_i = i_c // NC, i_c % NC
467
+
468
+ p_z = tl.make_block_ptr(z + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
469
+ p_zn = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), ((i_t * BT + i_i * BC) * K + i_k * BK,), (BK,), (0,))
470
+ # [BK,]
471
+ b_zn = tl.load(p_zn, boundary_check=(0,))
472
+ # [BC, BK]
473
+ b_z = tl.load(p_z, boundary_check=(0, 1))
474
+ b_zq = exp(b_zn[None, :] - b_z)
475
+ b_dq = tl.zeros([BC, BK], dtype=tl.float32)
476
+ for i_j in range(0, i_i):
477
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
478
+ p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
479
+ # [BC, BK]
480
+ b_k = tl.load(p_k, boundary_check=(0, 1))
481
+ b_kz = exp(b_k - b_zn[None, :]).to(b_k.dtype)
482
+ # [BC, BC]
483
+ b_dA = tl.load(p_dA, boundary_check=(0, 1))
484
+ # [BC, BK]
485
+ b_dq += tl.dot(b_dA, b_kz, allow_tf32=False)
486
+ b_dq *= b_zq
487
+
488
+ o_i = tl.arange(0, BC)
489
+ o_dA = i_bh * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC
490
+ m_dA = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
491
+ for j in range(0, BC):
492
+ p_kj = tl.make_block_ptr(k + i_bh * T*K, (T * K,), (1,), ((i_t * BT + i_i*BC+j) * K + i_k * BK,), (BK,), (0,))
493
+ # [BC,]
494
+ b_dA = tl.load(dA + o_dA + j, mask=m_dA, other=0)
495
+ # [BK,]
496
+ b_kj = tl.load(p_kj, boundary_check=(0,)).to(tl.float32)
497
+ # [BC, BK]
498
+ m_i = o_i[:, None] >= j
499
+ # [BC, BK]
500
+ b_dq += tl.where(m_i, b_dA[:, None] * exp(b_kj[None, :] - b_z), 0.)
501
+ p_dq = tl.make_block_ptr(dq + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
502
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
503
+
504
+ tl.debug_barrier()
505
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
506
+ p_zn = tl.make_block_ptr(z + i_bh * T*K, (T*K,), (1,), ((i_t * BT + i_i * BC + BC - 1) * K + i_k * BK,), (BK,), (0,))
507
+ # [BK,]
508
+ b_zn = tl.load(p_zn, boundary_check=(0,))
509
+ # [BC, BK]
510
+ b_k = tl.load(p_k, boundary_check=(0, 1))
511
+ b_kz = exp(b_k - b_zn[None, :])
512
+ b_dk = tl.zeros([BC, BK], dtype=tl.float32)
513
+ for i_j in range(i_i + 1, NC):
514
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
515
+ p_z = tl.make_block_ptr(z + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
516
+ p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_j * BC, i_i * BC), (BC, BC), (1, 0))
517
+ # [BC, BK]
518
+ b_q = tl.load(p_q, boundary_check=(0, 1))
519
+ b_z = tl.load(p_z, boundary_check=(0, 1))
520
+ b_qz = (b_q * exp(b_zn[None, :] - b_z)).to(b_q.dtype)
521
+ # [BC, BC]
522
+ b_dA = tl.load(p_dA, boundary_check=(0, 1))
523
+ # [BC, BK]
524
+ b_dk += tl.dot(tl.trans(b_dA), b_qz, allow_tf32=False)
525
+ b_dk *= b_kz
526
+
527
+ o_dA = i_bh * T * BT + (i_t * BT + i_i * BC) * BT + i_i * BC + tl.arange(0, BC)
528
+ for j in range(0, BC):
529
+ p_qj = tl.make_block_ptr(q + i_bh * T*K, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,))
530
+ p_zj = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,))
531
+ # [BC,]
532
+ b_dA = tl.load(dA + o_dA + j * BT, mask=(i_t * BT + i_i * BC + j < T), other=0)
533
+ # [BK,]
534
+ b_qj = tl.load(p_qj, boundary_check=(0,)).to(tl.float32)
535
+ b_zj = tl.load(p_zj, boundary_check=(0,)).to(tl.float32)
536
+ # [BC, BK]
537
+ m_i = o_i[:, None] <= j
538
+ b_dk += tl.where(m_i, b_dA[:, None] * b_qj[None, :] * exp(b_k - b_zj[None, :]), 0.)
539
+ p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
540
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
541
+
542
+
543
+ @triton.jit(do_not_specialize=['T'])
544
+ def chunk_abc_bwd_kernel_intra_K(
545
+ v,
546
+ z,
547
+ do,
548
+ dA,
549
+ scale,
550
+ T,
551
+ V: tl.constexpr,
552
+ BT: tl.constexpr,
553
+ BC: tl.constexpr,
554
+ BV: tl.constexpr,
555
+ NC: tl.constexpr
556
+ ):
557
+ i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
558
+ i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC
559
+ n_bh = tl.num_programs(2)
560
+
561
+ if i_i > i_j:
562
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, i_t * BT + i_j * BC), (BV, BC), (0, 1))
563
+ p_z = tl.make_block_ptr(z + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
564
+ p_zn = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), ((i_t * BT + i_i * BC) * V + i_v * BV,), (BV,), (0,))
565
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
566
+ p_dA = tl.make_block_ptr(dA+(i_bh+i_v*n_bh)*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
567
+ # [BV,]
568
+ b_zn = tl.load(p_zn, boundary_check=(0,))
569
+ # [BC, BV]
570
+ b_z = tl.load(p_z, boundary_check=(0, 1))
571
+ b_do = tl.load(p_do, boundary_check=(0, 1))
572
+ b_do = (b_do * exp(b_zn[None, :] - b_z) * scale).to(b_do.dtype)
573
+ # [BV, BC]
574
+ b_v = tl.load(p_v, boundary_check=(0, 1))
575
+ b_v = exp(b_v - b_zn[:, None]).to(b_v.dtype)
576
+ # [BC, BC]
577
+ b_dA = tl.dot(b_do, b_v, allow_tf32=False)
578
+ tl.store(p_dA, b_dA.to(dA.dtype.element_ty), boundary_check=(0, 1))
579
+ elif i_i == i_j:
580
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T * V,), (1,), ((i_t * BT + i_j * BC) * V + i_v * BV,), (BV,), (0,))
581
+ p_z = tl.make_block_ptr(z + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
582
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
583
+ # [BC, BV]
584
+ b_z = tl.load(p_z, boundary_check=(0, 1))
585
+ b_do = tl.load(p_do, boundary_check=(0, 1)) * scale
586
+
587
+ o_i = tl.arange(0, BC)
588
+ o_A = (i_bh + i_v * n_bh) * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC
589
+ m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
590
+ for j in range(0, BC):
591
+ # [BV,]
592
+ b_v = tl.load(p_v, boundary_check=(0,)).to(tl.float32)
593
+ # [BC,]
594
+ b_dA = tl.sum(b_do * exp(b_v[None, :] - b_z), 1)
595
+ b_dA = tl.where(o_i >= j, b_dA, 0)
596
+ tl.store(dA + o_A + j, b_dA.to(b_do.dtype), mask=m_A)
597
+
598
+ p_v = tl.advance(p_v, (V,))
599
+
600
+
601
+ @triton.jit(do_not_specialize=['T'])
602
+ def chunk_abc_bwd_kernel_K(
603
+ q,
604
+ k,
605
+ v,
606
+ z,
607
+ h,
608
+ A,
609
+ do,
610
+ dh,
611
+ dq,
612
+ dk,
613
+ dv,
614
+ dA,
615
+ scale,
616
+ T,
617
+ K: tl.constexpr,
618
+ V: tl.constexpr,
619
+ BT: tl.constexpr,
620
+ BK: tl.constexpr,
621
+ BV: tl.constexpr,
622
+ NT: tl.constexpr
623
+ ):
624
+ i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
625
+ i_p = tl.maximum(i_t * BT - 1, 0)
626
+ n_bh = tl.num_programs(2)
627
+
628
+ o_i = tl.arange(0, BT)
629
+ m_s = o_i[:, None] >= o_i[None, :]
630
+
631
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
632
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
633
+ p_A = tl.make_block_ptr(A + (i_k*n_bh+i_bh) * T * BT, (T, BT, ), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
634
+
635
+ # [BT, BK]
636
+ b_q = tl.load(p_q, boundary_check=(0, 1))
637
+ b_k = tl.load(p_k, boundary_check=(0, 1))
638
+ # [BT, BT]
639
+ b_A = tl.dot((b_q * scale).to(b_q.dtype), tl.trans(b_k), allow_tf32=False)
640
+ b_A = tl.where(m_s, b_A, 0.)
641
+ tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
642
+
643
+ b_dq = tl.zeros([BT, BK], dtype=tl.float32)
644
+ b_dk = tl.zeros([BT, BK], dtype=tl.float32)
645
+ for i_v in range(tl.cdiv(V, BV)):
646
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
647
+ p_z = tl.make_block_ptr(z + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
648
+ p_zp = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), (i_p * V + i_v * BV,), (BV,), (0,))
649
+ p_zc = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), ((i_t * BT + BT - 1) * V + i_v * BV,), (BV,), (0,))
650
+ p_h = tl.make_block_ptr(h + i_bh * NT*K*V + i_t * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
651
+
652
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
653
+ p_dh = tl.make_block_ptr(dh + i_bh * NT*K*V + i_t * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
654
+ p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh) * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
655
+
656
+ # [BV,]
657
+ b_zp = tl.load(p_zp, boundary_check=(0,))
658
+ b_zc = tl.load(p_zc, boundary_check=(0,))
659
+ # [BT, BV]
660
+ b_v = tl.load(p_v, boundary_check=(0, 1))
661
+ b_v = exp(b_v - b_zc[None, :]).to(b_v.dtype)
662
+ b_z = tl.load(p_z, boundary_check=(0, 1))
663
+ b_z = exp(b_zp[None, :] - b_z)
664
+ # [BV, BK]
665
+ b_h = tl.load(p_h, boundary_check=(0, 1))
666
+ # [BT, BV]
667
+ b_do = tl.load(p_do, boundary_check=(0, 1))
668
+ b_do = (b_do * b_z * scale).to(b_do.dtype)
669
+ # [BK, BV]
670
+ b_dh = tl.load(p_dh, boundary_check=(0, 1))
671
+
672
+ # [BT, BK]
673
+ b_dq += tl.dot(b_do, b_h, allow_tf32=False)
674
+ b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)
675
+ # [BT, BV]
676
+ b_dv = b_v * tl.dot(b_k, b_dh, allow_tf32=False)
677
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
678
+ p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT, ), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
679
+ # [BT, BT]
680
+ b_dA = tl.load(p_dA, boundary_check=(0, 1))
681
+ # [BT, BK]
682
+ b_dq += tl.dot(b_dA, b_k, allow_tf32=False)
683
+ b_dk += tl.dot(tl.trans(b_dA).to(b_k.dtype), b_q, allow_tf32=False)
684
+
685
+ p_dq = tl.make_block_ptr(dq + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
686
+ p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
687
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
688
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
689
+
690
+
691
+ @triton.jit(do_not_specialize=['T'])
692
+ def chunk_abc_bwd_kernel_intra_KV(
693
+ v,
694
+ z,
695
+ A,
696
+ do,
697
+ dv,
698
+ T,
699
+ V: tl.constexpr,
700
+ BT: tl.constexpr,
701
+ BC: tl.constexpr,
702
+ BV: tl.constexpr,
703
+ NC: tl.constexpr
704
+ ):
705
+ i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
706
+ i_t, i_i = i_c // NC, i_c % NC
707
+
708
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
709
+ p_zn = tl.make_block_ptr(z + i_bh * T*V, (T*V,), (1,), ((i_t * BT + i_i * BC + BC - 1) * V + i_v * BV,), (BV,), (0,))
710
+ # [BV,]
711
+ b_zn = tl.load(p_zn, boundary_check=(0,))
712
+ # [BC, BV]
713
+ b_v = tl.load(p_v, boundary_check=(0, 1))
714
+ b_dv = tl.zeros([BC, BV], dtype=tl.float32)
715
+ for i_j in range(i_i + 1, NC):
716
+ p_z = tl.make_block_ptr(z + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
717
+ p_A = tl.make_block_ptr(A + i_bh * T * BT, (BT, T), (1, BT), (i_i * BC, i_t * BT + i_j * BC), (BC, BC), (0, 1))
718
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
719
+ # [BC, BV]
720
+ b_z = tl.load(p_z, boundary_check=(0, 1))
721
+ b_do = tl.load(p_do, boundary_check=(0, 1))
722
+ b_do = (b_do * exp(b_zn[None, :] - b_z)).to(b_do.dtype)
723
+ # [BC, BC]
724
+ b_A = tl.load(p_A, boundary_check=(0, 1))
725
+ b_dv += tl.dot(b_A, b_do, allow_tf32=False)
726
+ b_dv *= exp(b_v - b_zn[None, :])
727
+
728
+ o_i = tl.arange(0, BC)
729
+ for j in range(0, BC):
730
+ p_z = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), ((i_t * BT + i_i * BC + j) * V + i_v * BV,), (BV,), (0,))
731
+ p_A = tl.make_block_ptr(A + i_bh * T * BT, (T * BT,), (1,), ((i_t * BT + i_i * BC + j) * BT + i_i * BC,), (BC,), (0,))
732
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T * V,), (1,), ((i_t * BT + i_i * BC + j) * V + i_v * BV,), (BV,), (0,))
733
+ # [BC,]
734
+ b_A = tl.load(p_A, boundary_check=(0,))
735
+ # [BV,]
736
+ b_z = tl.load(p_z, boundary_check=(0,))
737
+ b_do = tl.load(p_do, boundary_check=(0,))
738
+ # [BC, BV]
739
+ m_i = o_i[:, None] <= j
740
+ b_dv += tl.where(m_i, exp(b_v - b_z[None, :]) * b_A[:, None] * b_do[None, :], 0.)
741
+ p_dv = tl.make_block_ptr(dv + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
742
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
743
+
744
+
745
+ @triton.jit(do_not_specialize=['T'])
746
+ def chunk_abc_bwd_kernel_rcum_inter(
747
+ s,
748
+ z,
749
+ ss,
750
+ doo,
751
+ T,
752
+ S: tl.constexpr,
753
+ BT: tl.constexpr,
754
+ BS: tl.constexpr,
755
+ NT: tl.constexpr
756
+ ):
757
+ i_m, i_bh = tl.program_id(0), tl.program_id(1)
758
+
759
+ b_sp = tl.zeros([BS,], dtype=tl.float32)
760
+ b_zp = tl.full([BS,], float('inf'), dtype=tl.float32)
761
+ for i_t in range(NT - 1, -1, -1):
762
+ p_s = tl.make_block_ptr(s + i_bh * T*S, (T, S), (S, 1), (i_t * BT, i_m * BS), (BT, BS), (1, 0))
763
+ p_z = tl.make_block_ptr(z + i_bh * T*S, (T, S), (S, 1), (i_t * BT, i_m * BS), (BT, BS), (1, 0))
764
+ p_zc = tl.make_block_ptr(z + i_bh * T*S, (T*S,), (1,), ((i_t * BT) * S + i_m * BS,), (BS,), (0,))
765
+ p_ss = tl.make_block_ptr(ss + i_bh * T*S, (T, S), (S, 1), (i_t * BT, i_m * BS), (BT, BS), (1, 0))
766
+ p_doo = tl.make_block_ptr(doo + i_bh * T*S, (T, S), (S, 1), (i_t * BT, i_m * BS), (BT, BS), (1, 0))
767
+ # [BS,]
768
+ b_zc = tl.load(p_zc, boundary_check=(0,))
769
+ # [BT, BS]
770
+ b_s = tl.load(p_s, boundary_check=(0, 1))
771
+ b_z = tl.load(p_z, boundary_check=(0, 1))
772
+ b_ss = tl.load(p_ss, boundary_check=(0, 1))
773
+
774
+ b_doo = exp(b_s - b_zp[None, :]) * b_sp[None, :]
775
+ tl.store(p_doo, b_doo.to(p_doo.dtype.element_ty), boundary_check=(0, 1))
776
+ # [BS,]
777
+ b_sp = b_sp * exp(b_zc - b_zp) + tl.sum(b_ss * exp(b_zc[None, :] - b_z), 0)
778
+ b_zp = b_zc
779
+
780
+
781
+ @triton.jit(do_not_specialize=['T'])
782
+ def chunk_abc_bwd_kernel_rcum_intra(
783
+ s,
784
+ z,
785
+ ss,
786
+ doo,
787
+ T,
788
+ S: tl.constexpr,
789
+ BT: tl.constexpr,
790
+ BC: tl.constexpr,
791
+ BS: tl.constexpr,
792
+ NC: tl.constexpr
793
+ ):
794
+ i_s, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
795
+ i_t, i_i = i_c // NC, i_c % NC
796
+
797
+ o_i = tl.arange(0, BC)
798
+ m_o = tl.full([BC, BC], 1., dtype=tl.float32)
799
+
800
+ p_s = tl.make_block_ptr(s + i_bh * T*S, (T, S), (S, 1), (i_t * BT + i_i * BC, i_s * BS), (BC, BS), (1, 0))
801
+ p_zn = tl.make_block_ptr(z + i_bh * T*S, (T*S,), (1,), ((i_t * BT + i_i * BC + BC - 1) * S + i_s * BS,), (BS,), (0,))
802
+ p_doo = tl.make_block_ptr(doo + i_bh * T*S, (T, S), (S, 1), (i_t * BT + i_i * BC, i_s * BS), (BC, BS), (1, 0))
803
+ # [BC, BS]
804
+ b_s = tl.load(p_s, boundary_check=(0, 1))
805
+ # [BS,]
806
+ b_zn = tl.load(p_zn, boundary_check=(0,))
807
+
808
+ b_doo = tl.zeros([BC, BS], dtype=tl.float32)
809
+ for i_j in range(i_i + 1, NC):
810
+ p_z = tl.make_block_ptr(z + i_bh * T*S, (T, S), (S, 1), (i_t * BT + i_j * BC, i_s * BS), (BC, BS), (1, 0))
811
+ p_ss = tl.make_block_ptr(ss + i_bh * T*S, (T, S), (S, 1), (i_t * BT + i_j * BC, i_s * BS), (BC, BS), (1, 0))
812
+ # [BC, BS]
813
+ b_z = tl.load(p_z, boundary_check=(0, 1))
814
+ b_ss = tl.load(p_ss, boundary_check=(0, 1))
815
+ # [BC, BS]
816
+ b_doo += b_ss * exp(b_zn[None, :] - b_z)
817
+ b_doo = exp(b_s - b_zn[None, :]) * tl.dot(m_o.to(b_s.dtype), b_doo.to(b_s.dtype), allow_tf32=False)
818
+
819
+ for j in range(0, BC):
820
+ p_z = tl.make_block_ptr(z + i_bh * T*S, (T*S,), (1,), ((i_t * BT + i_i * BC + j) * S + i_s * BS,), (BS,), (0,))
821
+ p_ss = tl.make_block_ptr(ss + i_bh * T*S, (T*S,), (1,), ((i_t * BT + i_i * BC + j) * S + i_s * BS,), (BS,), (0,))
822
+ # [BS,]
823
+ b_z = tl.load(p_z, boundary_check=(0,))
824
+ b_ss = tl.load(p_ss, boundary_check=(0,))
825
+ # [BC, BS]
826
+ m_i = o_i[:, None] <= j
827
+ b_doo += tl.where(m_i, exp(b_s - b_z[None, :]) * b_ss[None, :], 0.)
828
+ b_doo += tl.load(p_doo, boundary_check=(0, 1))
829
+ tl.store(p_doo, b_doo.to(p_doo.dtype.element_ty), boundary_check=(0, 1))
830
+
831
+
832
+ class ChunkABCFunction(torch.autograd.Function):
833
+
834
+ @staticmethod
835
+ @input_guard
836
+ def forward(ctx, q, k, v, s, initial_state, output_final_state):
837
+ B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1]
838
+ BT, BC = 64, 16
839
+ BK = min(64, triton.next_power_of_2(K))
840
+ BV = min(64, triton.next_power_of_2(V))
841
+ BM = min(64, triton.next_power_of_2(M))
842
+ NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC)
843
+ NV, NM = triton.cdiv(V, BV), triton.cdiv(M, BM)
844
+ num_warps = 4 if BK == 64 else 2
845
+ num_stages = 1
846
+
847
+ def fwd_pre(s, B, H, T, S):
848
+ # keep cumulative normalizer in fp32
849
+ z = torch.empty_like(s, dtype=torch.float)
850
+ grid = (B * H,)
851
+ logcumsumexp_fwd_kernel[grid](
852
+ s, z,
853
+ T=T, S=S
854
+ )
855
+ return z
856
+
857
+ def fwd_inner(q, k, v, z, B, H, T, K, V, BT, BK, BV, NT, normk=False, h0=None, ht=None):
858
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
859
+ h = q.new_empty(B, H, NT * K, V)
860
+ grid = (NV, NK, B * H)
861
+ chunk_abc_fwd_kernel_h[grid](
862
+ k, v, z, h, h0, ht,
863
+ T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
864
+ NORMK=normk,
865
+ USE_INITIAL_STATE=h0 is not None,
866
+ STORE_FINAL_STATE=ht is not None,
867
+ num_warps=num_warps,
868
+ num_stages=num_stages
869
+ )
870
+ return h
871
+
872
+ final_state = None
873
+ if output_final_state:
874
+ final_state = (q.new_empty(B, H, K, M, dtype=torch.float),
875
+ q.new_empty(B, H, M, V, dtype=torch.float))
876
+
877
+ z = fwd_pre(s, B, H, T, M)
878
+ scale = K ** -0.5
879
+ hk = fwd_inner(
880
+ q=q, k=k, v=s, z=z,
881
+ B=B, H=H, T=T, K=K, V=M, BT=BT, BK=BK, BV=BM, NT=NT,
882
+ normk=False,
883
+ h0=initial_state[0] if initial_state is not None else None,
884
+ ht=final_state[0] if final_state is not None else None
885
+ )
886
+ ok1 = torch.empty_like(s)
887
+ Ak = q.new_empty(B, H, T, BT)
888
+ grid = (NM, NT, B * H)
889
+ chunk_abc_fwd_kernel_K[grid](
890
+ q, k, z, hk, ok1, Ak,
891
+ scale=scale,
892
+ T=T, K=K, V=M, BT=BT, BK=BK, BV=BM, NT=NT,
893
+ num_warps=num_warps,
894
+ num_stages=num_stages
895
+ )
896
+ ok0 = torch.empty_like(s)
897
+ grid = (NM, NT * NC, B * H)
898
+ chunk_abc_fwd_kernel_intra_K[grid](
899
+ s, z, ok0, Ak,
900
+ T=T, V=M, BT=BT, BC=BC, BV=BM, NC=NC,
901
+ num_warps=2,
902
+ num_stages=num_stages
903
+ )
904
+ ok = ok0.add_(ok1)
905
+
906
+ scale = 1.
907
+ # p is kept in fp32 for safe softmax backward
908
+ p = softmax_fwd(ok, dtype=torch.float)
909
+ qv = p.to(q.dtype)
910
+
911
+ scale = 1.
912
+ hv = fwd_inner(
913
+ q=qv, k=s, v=v, z=z,
914
+ B=B, H=H, T=T, K=M, V=V, BT=BT, BK=BM, BV=BV, NT=NT,
915
+ normk=True,
916
+ h0=initial_state[1] if initial_state is not None else None,
917
+ ht=final_state[1] if final_state is not None else None
918
+ )
919
+ Av = q.new_zeros(NM, B, H, T, BT)
920
+ grid = (NM, NT * NC * NC, B * H)
921
+ chunk_abc_fwd_kernel_intra_V[grid](
922
+ qv, s, z, Av,
923
+ scale=scale,
924
+ T=T, K=M, BT=BT, BC=BC, BK=BM, NC=NC,
925
+ num_warps=2,
926
+ num_stages=num_stages
927
+ )
928
+ Av = Av.sum(0)
929
+ ov = torch.empty_like(v)
930
+ grid = (NV, NT, B * H)
931
+ chunk_abc_fwd_kernel_V[grid](
932
+ qv, v, z, hv, ov, Av,
933
+ scale=scale,
934
+ T=T,
935
+ K=M,
936
+ V=V,
937
+ BT=BT,
938
+ BK=BM,
939
+ BV=BV,
940
+ NT=NT,
941
+ num_warps=num_warps,
942
+ num_stages=num_stages
943
+ )
944
+ ctx.save_for_backward(q, k, v, s, z, ok, p, hk, hv, Av)
945
+ ctx.BT = BT
946
+ return ov, final_state
947
+
948
+ @staticmethod
949
+ @input_guard
950
+ def backward(ctx, dov, dht=None):
951
+ q, k, v, s, z, ok, p, hk, hv, Av = ctx.saved_tensors
952
+ B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1]
953
+ BT, BC = ctx.BT, 16
954
+ BK = min(64, triton.next_power_of_2(K))
955
+ BV = min(64, triton.next_power_of_2(V))
956
+ BM = min(64, triton.next_power_of_2(M))
957
+ NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC)
958
+ NK, NM = triton.cdiv(K, BK), triton.cdiv(M, BM)
959
+ num_warps = 4 if BK == 64 else 2
960
+ num_stages = 1
961
+
962
+ def bwd_inner(q, z, do, B, H, T, K, V, BT, BK, BV, NT, scale, normk=False):
963
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
964
+ dh = q.new_empty(B, H, NT * K, V)
965
+ grid = (NK, NV, B * H)
966
+ chunk_abc_bwd_kernel_dh[grid](
967
+ q, z, do, dh,
968
+ scale=scale,
969
+ T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
970
+ NORMK=normk,
971
+ num_warps=num_warps,
972
+ num_stages=num_stages
973
+ )
974
+ return dh
975
+
976
+ def bwd_post(s, z, ss, B, H, T, S, BT, BC, BS, NT, NC, NS):
977
+ doo = torch.empty_like(s)
978
+ grid = (NS, B * H)
979
+ chunk_abc_bwd_kernel_rcum_inter[grid](
980
+ s, z, ss, doo,
981
+ T=T, S=S, BT=BT, BS=BS, NT=NT,
982
+ num_warps=num_warps,
983
+ num_stages=num_stages
984
+ )
985
+ grid = (NS, NT * NC, B * H)
986
+ chunk_abc_bwd_kernel_rcum_intra[grid](
987
+ s, z, ss, doo,
988
+ T=T, S=S, BT=BT, BC=BC, BS=BS, NC=NC,
989
+ num_warps=num_warps,
990
+ num_stages=num_stages
991
+ )
992
+ return doo
993
+
994
+ scale = 1.
995
+ qv = p.to(q.dtype)
996
+ dhv = bwd_inner(
997
+ qv, z, dov,
998
+ B=B, H=H, T=T, K=M, V=V, BT=BT, BK=BM, BV=BV, NT=NT,
999
+ scale=scale,
1000
+ normk=True
1001
+ )
1002
+ dp1 = torch.empty_like(p)
1003
+ dsv1 = torch.empty_like(s, dtype=torch.float)
1004
+ dv = v.new_empty(NM, *v.shape)
1005
+ dAv = q.new_zeros(B, H, T, BT)
1006
+ grid = (NM, NT, B * H)
1007
+ chunk_abc_bwd_kernel_V[grid](
1008
+ s, v, z, hv, Av, dov, dhv, dp1, dsv1, dv, dAv,
1009
+ scale=scale,
1010
+ T=T, K=M, V=V, BT=BT, BK=BM, BV=BV, NT=NT,
1011
+ num_warps=num_warps,
1012
+ num_stages=num_stages
1013
+ )
1014
+ dv = dv.sum(0)
1015
+ dp0 = torch.empty_like(p)
1016
+ dsv0 = s.new_zeros(s.shape, dtype=torch.float)
1017
+ grid = (NM, NT * NC, B * H)
1018
+ chunk_abc_bwd_kernel_intra_V[grid](
1019
+ qv, s, z, dAv, dp0, dsv0,
1020
+ T=T, K=M, BT=BT, BC=BC, BK=BM, NC=NC,
1021
+ num_warps=2,
1022
+ num_stages=num_stages
1023
+ )
1024
+ dp = dp1.add_(dp0)
1025
+ dsv = dsv1.add_(dsv0)
1026
+
1027
+ # softmax gradient, equivalent to:
1028
+ # dok = p * (dp - (p * dp).sum(-1, True))
1029
+ dok = softmax_bwd(p, dp, dtype=ok.dtype)
1030
+
1031
+ scale = K ** -0.5
1032
+ dhk = bwd_inner(
1033
+ q, z, dok,
1034
+ B=B, H=H, T=T, K=K, V=M, BT=BT, BK=BK, BV=BM, NT=NT,
1035
+ scale=scale,
1036
+ normk=False
1037
+ )
1038
+ dAk = q.new_zeros(NM, B, H, T, BT)
1039
+ grid = (NM, NT * NC * NC, B * H)
1040
+ chunk_abc_bwd_kernel_intra_K[grid](
1041
+ s, z, dok, dAk,
1042
+ scale=scale,
1043
+ T=T, V=M, BT=BT, BC=BC, BV=BM, NC=NC,
1044
+ num_warps=2,
1045
+ num_stages=num_stages
1046
+ )
1047
+ dAk = dAk.sum(0)
1048
+
1049
+ Ak = q.new_zeros(NK, B, H, T, BT)
1050
+ dq = torch.empty_like(q)
1051
+ dk = torch.empty_like(k)
1052
+ dsk1 = s.new_empty(NK, *s.shape, dtype=torch.float)
1053
+ grid = (NK, NT, B * H)
1054
+ chunk_abc_bwd_kernel_K[grid](
1055
+ q, k, s, z, hk, Ak, dok, dhk, dq, dk, dsk1, dAk,
1056
+ scale=scale,
1057
+ T=T, K=K, V=M, BT=BT, BK=BK, BV=BM, NT=NT,
1058
+ num_warps=num_warps,
1059
+ num_stages=num_stages
1060
+ )
1061
+ Ak = Ak.sum(0)
1062
+ dsk1 = dsk1.sum(0)
1063
+ dsk0 = torch.empty_like(s, dtype=torch.float)
1064
+ grid = (NM, NT * NC, B * H)
1065
+ chunk_abc_bwd_kernel_intra_KV[grid](
1066
+ s, z, Ak, dok, dsk0,
1067
+ T=T, V=M, BT=BT, BC=BC, BV=BM, NC=NC,
1068
+ num_warps=2,
1069
+ num_stages=num_stages
1070
+ )
1071
+ ds = dsv.add_(dsk1.add_(dsk0))
1072
+ ds -= bwd_post(s, z, ok * dok + p * dp, B, H, T, M, BT, BC, BM, NT, NC, NM)
1073
+ ds = ds.to(s.dtype)
1074
+ return dq, dk, dv, ds, None, None
1075
+
1076
+
1077
+ @torch.compiler.disable
1078
+ def chunk_abc(
1079
+ q: torch.Tensor,
1080
+ k: torch.Tensor,
1081
+ v: torch.Tensor,
1082
+ s: torch.Tensor,
1083
+ initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
1084
+ output_final_state: bool = False,
1085
+ head_first: bool = True
1086
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1087
+ r"""
1088
+ Args:
1089
+ q (torch.Tensor):
1090
+ queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`
1091
+ k (torch.Tensor):
1092
+ keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`
1093
+ v (torch.Tensor):
1094
+ values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`
1095
+ s (torch.Tensor):
1096
+ slot representations of shape `[B, H, T, M]` if `head_first=True` else `[B, T, H, M]`
1097
+ initial_state (Optional[Tuple[torch.Tensor, torch.Tensor]]):
1098
+ Initial states of shape `[B, H, K, M]` and `[B, H, M, V]`. Default: `None`.
1099
+ output_final_state (Optional[bool]):
1100
+ Whether to output the final state of shape `[B, H, K, M]` and `[B, H, M, V]`. Default: `False`.
1101
+ head_first (Optional[bool]):
1102
+ Whether the inputs are in the head-first format.
1103
+ Default: `True`.
1104
+
1105
+ Returns:
1106
+ o (torch.Tensor):
1107
+ Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
1108
+ final_state (Optional[Tuple[torch.Tensor, torch.Tensor]]):
1109
+ Final states of shapes `[B, H, K, M]` and `[B, H, M, V]` if `output_final_state=True`, else `None`.
1110
+ """
1111
+ if not head_first:
1112
+ q, k, v, s = map(lambda x: x.transpose(1, 2), (q, k, v, s))
1113
+ o, final_state = ChunkABCFunction.apply(q, k, v, s, initial_state, output_final_state)
1114
+ if not head_first:
1115
+ o = o.transpose(1, 2)
1116
+ return o, final_state
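Review note: the comment in `ChunkABCFunction.backward` states that the softmax gradient is equivalent to `dok = p * (dp - (p * dp).sum(-1, True))`. The sketch below checks that identity on a single row using plain Python and central finite differences. It is an illustration only: the helper names (`softmax`, `softmax_bwd_closed_form`, `softmax_bwd_finite_diff`) are hypothetical and are not the `softmax_fwd` / `softmax_bwd` kernels used in this file.

```python
import math


def softmax(x):
    # Numerically stable softmax over a single row.
    m = max(x)
    e = [math.exp(v - m) for v in x]
    total = sum(e)
    return [v / total for v in e]


def softmax_bwd_closed_form(p, dp):
    # dx_i = p_i * (dp_i - sum_j p_j * dp_j), the identity quoted in backward().
    inner = sum(pi * di for pi, di in zip(p, dp))
    return [pi * (di - inner) for pi, di in zip(p, dp)]


def softmax_bwd_finite_diff(x, dp, eps=1e-6):
    # Central differences of the scalar loss L(x) = sum_i softmax(x)_i * dp_i.
    grads = []
    for i in range(len(x)):
        xp = list(x)
        xm = list(x)
        xp[i] += eps
        xm[i] -= eps
        loss_p = sum(a * b for a, b in zip(softmax(xp), dp))
        loss_m = sum(a * b for a, b in zip(softmax(xm), dp))
        grads.append((loss_p - loss_m) / (2 * eps))
    return grads


x = [0.5, -1.0, 2.0, 0.0]    # logits (the role of `ok`)
dp = [1.0, -2.0, 0.5, 3.0]   # upstream gradient w.r.t. the probabilities
closed = softmax_bwd_closed_form(softmax(x), dp)
numeric = softmax_bwd_finite_diff(x, dp)
```

Because the probabilities sum to one, the closed-form gradient sums to zero along the softmax axis, which is a cheap sanity check for any implementation of this step.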
fla/ops/abc/naive.py ADDED
@@ -0,0 +1,96 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Optional
4
+
5
+ import torch
6
+ from einops import repeat
7
+
8
+
9
+ def naive_recurrent_abc(
10
+ q: torch.Tensor,
11
+ k: torch.Tensor,
12
+ v: torch.Tensor,
13
+ s: torch.Tensor,
14
+ g: Optional[torch.Tensor] = None,
15
+ scale: Optional[float] = None,
16
+ initial_state: Optional[torch.Tensor] = None,
17
+ output_final_state: Optional[bool] = False
18
+ ) -> torch.Tensor:
19
+ dtype = q.dtype
20
+
21
+ NG = q.shape[1]//k.shape[1]
22
+ # [batch_size, n_heads, seq_len, n_slots]
23
+ if g is None:
24
+ z = s.float().logcumsumexp(2)
25
+ g = torch.cat((z[:, :, :1], z[:, :, :-1]), 2) - z
26
+ s = torch.exp(s - z)
27
+ q, k, v, s, g = map(lambda x: x.float(), (q, k, v, s, g))
28
+ k, v, s, g = map(lambda x: repeat(x, 'b h t d -> b (h g) t d', g=NG), (k, v, s, g))
29
+ if initial_state is not None:
30
+ initial_state = tuple(map(lambda x: repeat(x, 'b h k v -> b (h g) k v', g=NG), initial_state))
31
+
32
+ B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1]
33
+
34
+ hk = torch.zeros(B, H, K, M, dtype=torch.float, device=q.device)
35
+ ok = torch.zeros_like(s)
36
+
37
+ if scale is None:
38
+ scale = q.shape[-1] ** -0.5
39
+
40
+ final_state = None
41
+ if initial_state is not None:
42
+ hk += initial_state[0]
43
+
44
+ for i in range(T):
45
+ q_i = q[:, :, i] * scale
46
+ k_i = k[:, :, i]
47
+ v_i = s[:, :, i]
48
+ g_i = g[:, :, i].exp()
49
+ hk = hk * g_i[..., None, :] + k_i[..., None] * v_i[..., None, :]
50
+ ok[:, :, i] = (q_i[..., None] * hk).sum(-2)
51
+
52
+ qv = ok.softmax(-1)
53
+ hv = torch.zeros(B, H, M, V, dtype=torch.float, device=q.device)
54
+ ov = torch.zeros_like(v)
55
+ if initial_state is not None:
56
+ hv += initial_state[1]
57
+
58
+ for i in range(T):
59
+ q_i = qv[:, :, i]
60
+ k_i = s[:, :, i]
61
+ v_i = v[:, :, i]
62
+ g_i = g[:, :, i].exp()
63
+ hv = hv * g_i[..., :, None] + k_i[..., None] * v_i[..., None, :]
64
+ ov[:, :, i] = (q_i[..., None] * hv).sum(-2)
65
+
66
+ if output_final_state:
67
+ final_state = (hk.view(B, -1, NG, K, M)[:, :, 0], hv.view(B, -1, NG, M, V)[:, :, 0])
68
+ return ov.to(dtype), final_state
69
+
70
+
71
+ def naive_cumsum_abc(
72
+ q: torch.Tensor,
73
+ k: torch.Tensor,
74
+ v: torch.Tensor,
75
+ s: torch.Tensor
76
+ ) -> torch.Tensor:
77
+ """
78
+ A simple implementation of vanilla ABC that is more aligned with the descriptions in the paper.
79
+ This is just for demonstration purposes, with no guarantee of numerical stability.
80
+ """
81
+
82
+ dtype = q.dtype
83
+ q, k, v, s = map(lambda x: x.float(), (q, k, v, s))
84
+
85
+ scale = q.shape[-1] ** -0.5
86
+ # [batch_size, n_heads, seq_len, n_slots]
87
+ s = (s - s.max(2, True)[0]).exp()
88
+ z = s.cumsum(2)
89
+ # [batch_size, n_heads, seq_len, n_slots, d_head]
90
+ K = (s.unsqueeze(-1) * k.unsqueeze(-2)).cumsum(2) / z.unsqueeze(-1)
91
+ V = (s.unsqueeze(-1) * v.unsqueeze(-2)).cumsum(2) / z.unsqueeze(-1)
92
+ # [batch_size, n_heads, seq_len, n_slots]
93
+ p = torch.einsum('...d,...md->...m', q * scale, K).softmax(-1)
94
+ # [batch_size, n_heads, seq_len, d_head]
95
+ o = torch.einsum('...m,...md->...d', p, V)
96
+ return o.to(dtype), None
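Review note: `naive_recurrent_abc` builds its decay as `g = z[t-1] - z[t]` from `z = logcumsumexp(s)` and feeds `exp(s - z)` into the recurrence, while `naive_cumsum_abc` divides a cumulative sum by a cumulative normalizer. The sketch below suggests why the two formulations should agree, using one slot and scalar keys in plain Python. All names here are hypothetical and the reduction to scalars is an assumption made for brevity; it is not a replacement for either function.

```python
import math


def logcumsumexp(s):
    # Running log(sum(exp(s[:t+1]))), computed stably.
    out, running = [], -math.inf
    for v in s:
        m = max(running, v)
        running = m + math.log(math.exp(running - m) + math.exp(v - m))
        out.append(running)
    return out


def gated_recurrence(s, k):
    # h_t = h_{t-1} * exp(z_{t-1} - z_t) + k_t * exp(s_t - z_t), with g_0 = 0.
    z = logcumsumexp(s)
    h, out = 0.0, []
    for t in range(len(s)):
        g = (z[t - 1] if t > 0 else z[0]) - z[t]
        h = h * math.exp(g) + k[t] * math.exp(s[t] - z[t])
        out.append(h)
    return out


def normalized_cumsum(s, k):
    # cumsum(exp(s) * k) / cumsum(exp(s)), shifted by max(s) as in naive_cumsum_abc.
    m = max(s)
    num = den = 0.0
    out = []
    for s_t, k_t in zip(s, k):
        w = math.exp(s_t - m)
        num += w * k_t
        den += w
        out.append(num / den)
    return out


s = [0.3, -1.2, 2.0, 0.7]    # slot logits over four time steps
k = [1.0, -0.5, 2.0, 0.25]   # scalar stand-ins for the keys
recurrent = gated_recurrence(s, k)
cumulative = normalized_cumsum(s, k)
```

The agreement follows by induction: if `h[t-1]` equals the normalized sum up to `t-1`, multiplying by `exp(z[t-1] - z[t])` rescales its normalizer to the new one, and the added term contributes the current step's weighted key.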
fla/ops/attn/parallel.py ADDED
@@ -0,0 +1,629 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+ from einops import rearrange, reduce
10
+
11
+ from fla.ops.common.utils import prepare_chunk_indices
12
+ from fla.ops.utils.op import exp, log
13
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, check_shared_mem, contiguous
14
+
15
+
16
+ @triton.heuristics({
17
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
18
+ })
19
+ @triton.autotune(
20
+ configs=[
21
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
22
+ for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else [])
23
+ for num_stages in [2, 3, 4, 5]
24
+ ],
25
+ key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'],
26
+ )
27
+ @triton.jit
28
+ def parallel_attn_fwd_kernel(
29
+ q,
30
+ k,
31
+ v,
32
+ o,
33
+ lse,
34
+ scale,
35
+ offsets,
36
+ indices,
37
+ T,
38
+ B: tl.constexpr,
39
+ H: tl.constexpr,
40
+ HQ: tl.constexpr,
41
+ G: tl.constexpr,
42
+ K: tl.constexpr,
43
+ V: tl.constexpr,
44
+ BT: tl.constexpr,
45
+ BS: tl.constexpr,
46
+ BK: tl.constexpr,
47
+ BV: tl.constexpr,
48
+ USE_OFFSETS: tl.constexpr
49
+ ):
50
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
51
+ i_b, i_hq = i_bh // HQ, i_bh % HQ
52
+ i_h = i_hq // G
53
+
54
+ if USE_OFFSETS:
55
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
56
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
57
+ T = eos - bos
58
+ else:
59
+ i_n = i_b
60
+ bos, eos = i_n * T, i_n * T + T
61
+
62
+ p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
63
+ p_o = tl.make_block_ptr(o + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
64
+ p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
65
+
66
+ # the Q block is kept in the shared memory throughout the whole kernel
67
+ # [BT, BK]
68
+ b_q = tl.load(p_q, boundary_check=(0, 1))
69
+ b_q = (b_q * scale).to(b_q.dtype)
70
+ # [BT, BV]
71
+ b_o = tl.zeros([BT, BV], dtype=tl.float32)
72
+
73
+ b_m = tl.full([BT], float('-inf'), dtype=tl.float32)
74
+ b_acc = tl.zeros([BT], dtype=tl.float32)
75
+ for i_s in range(0, i_t * BT, BS):
76
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
77
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
78
+ # [BK, BS]
79
+ b_k = tl.load(p_k, boundary_check=(0, 1))
80
+ # [BS, BV]
81
+ b_v = tl.load(p_v, boundary_check=(0, 1))
82
+ # [BT, BS]
83
+ b_s = tl.dot(b_q, b_k)
84
+
85
+ # [BT]
86
+ b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m
87
+ b_r = exp(b_mp - b_m)
88
+ # [BT, BS]
89
+ b_p = exp(b_s - b_m[:, None])
90
+ # [BT]
91
+ b_acc = b_acc * b_r + tl.sum(b_p, 1)
92
+ # [BT, BV]
93
+ b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v)
94
+
95
+ b_mp = b_m
96
+
97
+ # [BT]
98
+ o_q = i_t * BT + tl.arange(0, BT)
99
+ for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
100
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
101
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
102
+
103
+ # [BS]
104
+ o_k = i_s + tl.arange(0, BS)
105
+ # [BK, BS]
106
+ b_k = tl.load(p_k, boundary_check=(0, 1))
107
+ # [BS, BV]
108
+ b_v = tl.load(p_v, boundary_check=(0, 1))
109
+ # [BT, BS]
110
+ b_s = tl.dot(b_q, b_k)
111
+ b_s = tl.where(o_q[:, None] >= o_k[None, :], b_s, float('-inf'))
112
+
113
+ # [BT]
114
+ b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m
115
+ b_r = exp(b_mp - b_m)
116
+ # [BT, BS]
117
+ b_p = exp(b_s - b_m[:, None])
118
+ # [BT]
119
+ b_acc = b_acc * b_r + tl.sum(b_p, 1)
120
+ # [BT, BV]
121
+ b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v)
122
+
123
+ b_mp = b_m
124
+ b_o = b_o / b_acc[:, None]
125
+ b_m += log(b_acc)
126
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
127
+ tl.store(p_lse, b_m.to(p_lse.dtype.element_ty), boundary_check=(0,))
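Review note: `parallel_attn_fwd_kernel` keeps a running maximum `b_m`, a running denominator `b_acc`, and an unnormalized output `b_o`, rescaling the latter two by `exp(b_mp - b_m)` whenever the maximum grows, and finally stores `b_m + log(b_acc)` as the log-sum-exp. The sketch below mirrors that update for a single query row with scalar values in plain Python. The function names and the scalar simplification are assumptions for illustration; causal masking and the Triton tiling are omitted.

```python
import math


def online_softmax_attention(scores, values, block_size):
    # Process the key/value axis in blocks, as the kernel does with BS.
    m = -math.inf   # running max (role of b_m)
    acc = 0.0       # running sum of exp(score - m) (role of b_acc)
    o = 0.0         # unnormalized output (role of b_o)
    for start in range(0, len(scores), block_size):
        s_blk = scores[start:start + block_size]
        v_blk = values[start:start + block_size]
        m_prev, m = m, max(m, max(s_blk))
        r = math.exp(m_prev - m)                 # rescale factor (role of b_r)
        p = [math.exp(x - m) for x in s_blk]
        acc = acc * r + sum(p)
        o = o * r + sum(pi * vi for pi, vi in zip(p, v_blk))
    return o / acc, m + math.log(acc)


def reference_attention(scores, values):
    # One-shot softmax-weighted sum and log-sum-exp for comparison.
    m = max(scores)
    w = [math.exp(x - m) for x in scores]
    total = sum(w)
    out = sum(wi * vi for wi, vi in zip(w, values)) / total
    return out, m + math.log(total)


scores = [0.1, 2.5, -0.7, 1.3, 3.0, -2.0, 0.4]
values = [1.0, -1.0, 0.5, 2.0, -0.5, 4.0, 0.0]
out_online, lse_online = online_softmax_attention(scores, values, block_size=3)
out_ref, lse_ref = reference_attention(scores, values)
```

Storing the log-sum-exp lets the backward kernels recompute the probabilities as `exp(b_s - b_lse)` without saving the full attention matrix.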
128
+
129
+
130
+ @triton.jit
131
+ def parallel_attn_bwd_kernel_preprocess(
132
+ o,
133
+ do,
134
+ delta,
135
+ B: tl.constexpr,
136
+ V: tl.constexpr
137
+ ):
138
+ i_n = tl.program_id(0)
139
+ o_d = tl.arange(0, B)
140
+ m_d = o_d < V
141
+
142
+ b_o = tl.load(o + i_n * V + o_d, mask=m_d, other=0)
143
+ b_do = tl.load(do + i_n * V + o_d, mask=m_d, other=0).to(tl.float32)
144
+ b_delta = tl.sum(b_o * b_do)
145
+
146
+ tl.store(delta + i_n, b_delta.to(delta.dtype.element_ty))
147
+
148
+
149
+ @triton.heuristics({
150
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
151
+ })
152
+ @triton.autotune(
153
+ configs=[
154
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
155
+ for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else [])
156
+ for num_stages in [2, 3, 4, 5]
157
+ ],
158
+ key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'],
159
+ )
160
+ @triton.jit(do_not_specialize=['T'])
161
+ def parallel_attn_bwd_kernel_dq(
162
+ q,
163
+ k,
164
+ v,
165
+ lse,
166
+ delta,
167
+ do,
168
+ dq,
169
+ scale,
170
+ offsets,
171
+ indices,
172
+ T,
173
+ B: tl.constexpr,
174
+ H: tl.constexpr,
175
+ HQ: tl.constexpr,
176
+ G: tl.constexpr,
177
+ K: tl.constexpr,
178
+ V: tl.constexpr,
179
+ BT: tl.constexpr,
180
+ BS: tl.constexpr,
181
+ BK: tl.constexpr,
182
+ BV: tl.constexpr,
183
+ USE_OFFSETS: tl.constexpr
184
+ ):
185
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
186
+ i_b, i_hq = i_bh // HQ, i_bh % HQ
187
+ i_h = i_hq // G
188
+
189
+ if USE_OFFSETS:
190
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
191
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
192
+ T = eos - bos
193
+ else:
194
+ i_n = i_b
195
+ bos, eos = i_n * T, i_n * T + T
196
+
197
+ p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
198
+ p_dq = tl.make_block_ptr(dq + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
199
+ p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
200
+ p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
201
+ p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
202
+
203
+ # [BT, BK]
204
+ b_q = tl.load(p_q, boundary_check=(0, 1))
205
+ b_q = (b_q * scale).to(b_q.dtype)
206
+ # [BT, BV]
207
+ b_do = tl.load(p_do, boundary_check=(0, 1))
208
+ # [BT]
209
+ b_lse = tl.load(p_lse, boundary_check=(0,))
210
+ b_delta = tl.load(p_delta, boundary_check=(0,))
211
+
212
+ # [BT, BK]
213
+ b_dq = tl.zeros([BT, BK], dtype=tl.float32)
214
+ for i_s in range(0, i_t * BT, BS):
215
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
216
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_s), (BV, BS), (0, 1))
217
+ # [BK, BS]
218
+ b_k = tl.load(p_k, boundary_check=(0, 1))
219
+ # [BV, BS]
220
+ b_v = tl.load(p_v, boundary_check=(0, 1))
221
+
222
+ # [BT, BS]
223
+ b_s = tl.dot(b_q, b_k)
224
+ b_p = exp(b_s - b_lse[:, None])
225
+
226
+ # [BT, BV] @ [BV, BS] -> [BT, BS]
227
+ b_dp = tl.dot(b_do, b_v)
228
+ b_ds = b_p * (b_dp.to(tl.float32) - b_delta[:, None])
229
+ # [BT, BS] @ [BS, BK] -> [BT, BK]
230
+ b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k))
231
+
232
+ # [BT]
233
+ o_q = i_t * BT + tl.arange(0, BT)
234
+ for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
235
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
236
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_s), (BV, BS), (0, 1))
237
+ # [BS]
238
+ o_k = i_s + tl.arange(0, BS)
239
+ # [BK, BS]
240
+ b_k = tl.load(p_k, boundary_check=(0, 1))
241
+ # [BV, BS]
242
+ b_v = tl.load(p_v, boundary_check=(0, 1))
243
+
244
+ # [BT, BS]
245
+ b_s = tl.dot(b_q, b_k)
246
+ b_p = exp(b_s - b_lse[:, None])
247
+ b_p = tl.where(o_q[:, None] >= o_k[None, :], b_p, 0)
248
+
249
+ # [BT, BV] @ [BV, BS] -> [BT, BS]
250
+ b_dp = tl.dot(b_do, b_v)
251
+ b_ds = b_p * (b_dp.to(tl.float32) - b_delta[:, None])
252
+ # [BT, BS] @ [BS, BK] -> [BT, BK]
253
+ b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k))
254
+
255
+ b_dq *= scale
256
+
257
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
258
+
259
+
260
+ @triton.heuristics({
261
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
262
+ })
263
+ @triton.autotune(
264
+ configs=[
265
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
266
+ for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else [])
267
+ for num_stages in [2, 3, 4, 5]
268
+ ],
269
+ key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'],
270
+ )
271
+ @triton.jit(do_not_specialize=['T'])
272
+ def parallel_attn_bwd_kernel_dkv(
273
+ q,
274
+ k,
275
+ v,
276
+ lse,
277
+ delta,
278
+ do,
279
+ dk,
280
+ dv,
281
+ offsets,
282
+ indices,
283
+ scale,
284
+ T,
285
+ B: tl.constexpr,
286
+ H: tl.constexpr,
287
+ HQ: tl.constexpr,
288
+ G: tl.constexpr,
289
+ K: tl.constexpr,
290
+ V: tl.constexpr,
291
+ BT: tl.constexpr,
292
+ BS: tl.constexpr,
293
+ BK: tl.constexpr,
294
+ BV: tl.constexpr,
295
+ USE_OFFSETS: tl.constexpr
296
+ ):
297
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
298
+ i_b, i_hq = i_bh // HQ, i_bh % HQ
299
+ i_h = i_hq // G
300
+
301
+ if USE_OFFSETS:
302
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
303
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
304
+ T = eos - bos
305
+ else:
306
+ i_n = i_b
307
+ bos, eos = i_n * T, i_n * T + T
308
+
309
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
310
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
311
+ p_dk = tl.make_block_ptr(dk + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
312
+ p_dv = tl.make_block_ptr(dv + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
313
+
314
+ # [BT, BK]
315
+ b_k = tl.load(p_k, boundary_check=(0, 1))
316
+ b_dk = tl.zeros([BT, BK], dtype=tl.float32)
317
+ # [BT, BV]
318
+ b_v = tl.load(p_v, boundary_check=(0, 1))
319
+ b_dv = tl.zeros([BT, BV], dtype=tl.float32)
320
+
321
+ o_k = i_t * BT + tl.arange(0, BT)
322
+ for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
323
+ p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_s, 0), (BS, BK), (1, 0))
324
+ p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
325
+ p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
326
+ p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
327
+
328
+ # [BS]
329
+ o_q = i_s + tl.arange(0, BS)
330
+ # [BS, BK]
331
+ b_q = tl.load(p_q, boundary_check=(0, 1))
332
+ b_q = (b_q * scale).to(b_q.dtype)
333
+ # [BS, BV]
334
+ b_do = tl.load(p_do, boundary_check=(0, 1))
335
+ # [BS]
336
+ b_lse = tl.load(p_lse, boundary_check=(0,))
337
+ b_delta = tl.load(p_delta, boundary_check=(0,))
338
+ # [BT, BS]
339
+ b_s = tl.dot(b_k, tl.trans(b_q))
340
+ b_p = exp(b_s - b_lse[None, :])
341
+ b_p = tl.where(o_k[:, None] <= o_q[None, :], b_p, 0)
342
+ # [BT, BS] @ [BS, BV] -> [BT, BV]
343
+ b_dv += tl.dot(b_p.to(b_do.dtype), b_do)
344
+ # [BT, BV] @ [BV, BS] -> [BT, BS]
345
+ b_dp = tl.dot(b_v, tl.trans(b_do))
346
+ # [BT, BS]
347
+ b_ds = b_p * (b_dp - b_delta[None, :])
348
+ # [BT, BS] @ [BS, BK] -> [BT, BK]
349
+ b_dk += tl.dot(b_ds.to(b_q.dtype), b_q)
350
+
351
+ for i_s in range((i_t + 1) * BT, tl.cdiv(T, BS) * BS, BS):
352
+ p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_s, 0), (BS, BK), (1, 0))
353
+ p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
354
+ p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
355
+ p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
356
+
357
+ # [BS]
358
+ o_q = i_s + tl.arange(0, BS)
359
+ # [BS, BK]
360
+ b_q = tl.load(p_q, boundary_check=(0, 1))
361
+ b_q = (b_q * scale).to(b_q.dtype)
362
+ # [BS, BV]
363
+ b_do = tl.load(p_do, boundary_check=(0, 1))
364
+ # [BS]
365
+ b_lse = tl.load(p_lse, boundary_check=(0,))
366
+ b_delta = tl.load(p_delta, boundary_check=(0,))
367
+ # [BT, BS]
368
+ b_s = tl.dot(b_k, tl.trans(b_q))
369
+ b_p = exp(b_s - b_lse[None, :])
370
+ # [BT, BS] @ [BS, BV] -> [BT, BV]
371
+ b_dv += tl.dot(b_p.to(b_do.dtype), b_do)
372
+ # [BT, BV] @ [BV, BS] -> [BT, BS]
373
+ b_dp = tl.dot(b_v, tl.trans(b_do))
374
+ # [BT, BS]
375
+ b_ds = b_p * (b_dp - b_delta[None, :])
376
+ # [BT, BS] @ [BS, BK] -> [BT, BK]
377
+ b_dk += tl.dot(b_ds.to(b_q.dtype), b_q)
378
+
379
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
380
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
381
+
382
+
383
+ def parallel_attn_fwd(
384
+ q: torch.Tensor,
385
+ k: torch.Tensor,
386
+ v: torch.Tensor,
387
+ scale: float,
388
+ chunk_size: int = 128,
389
+ offsets: Optional[torch.LongTensor] = None,
390
+ indices: Optional[torch.LongTensor] = None,
391
+ ):
392
+ B, T, H, K, V = *k.shape, v.shape[-1]
393
+ HQ = q.shape[2]
394
+ G = HQ // H
395
+ BT = chunk_size
396
+ if check_shared_mem('hopper', q.device.index):
397
+ BS = min(64, max(16, triton.next_power_of_2(T)))
398
+ BK = min(256, max(16, triton.next_power_of_2(K)))
399
+ BV = min(256, max(16, triton.next_power_of_2(V)))
400
+ elif check_shared_mem('ampere', q.device.index):
401
+ BS = min(32, max(16, triton.next_power_of_2(T)))
402
+ BK = min(256, max(16, triton.next_power_of_2(K)))
403
+ BV = min(128, max(16, triton.next_power_of_2(V)))
404
+ else:
405
+ BS = min(32, max(16, triton.next_power_of_2(T)))
406
+ BK = min(256, max(16, triton.next_power_of_2(K)))
407
+ BV = min(64, max(16, triton.next_power_of_2(V)))
408
+ NK = triton.cdiv(K, BK)
409
+ NV = triton.cdiv(V, BV)
410
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
411
+ assert NK == 1, "The key dimension cannot be larger than 256"
412
+
413
+ o = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device)
414
+ lse = torch.empty(B, T, HQ, dtype=torch.float, device=q.device)
415
+
416
+ grid = (NV, NT, B * HQ)
417
+ parallel_attn_fwd_kernel[grid](
418
+ q=q,
419
+ k=k,
420
+ v=v,
421
+ o=o,
422
+ lse=lse,
423
+ scale=scale,
424
+ offsets=offsets,
425
+ indices=indices,
426
+ B=B,
427
+ T=T,
428
+ H=H,
429
+ HQ=HQ,
430
+ G=G,
431
+ K=K,
432
+ V=V,
433
+ BT=BT,
434
+ BS=BS,
435
+ BK=BK,
436
+ BV=BV,
437
+ )
438
+ return o, lse
439
+
440
+
441
+ def parallel_attn_bwd_preprocess(
442
+ o: torch.Tensor,
443
+ do: torch.Tensor
444
+ ):
445
+ V = o.shape[-1]
446
+ delta = torch.empty_like(o[..., 0], dtype=torch.float32)
447
+ parallel_attn_bwd_kernel_preprocess[(delta.numel(),)](
448
+ o=o,
449
+ do=do,
450
+ delta=delta,
451
+ B=triton.next_power_of_2(V),
452
+ V=V,
453
+ )
454
+ return delta
455
+
456
+
457
+ def parallel_attn_bwd(
458
+ q: torch.Tensor,
459
+ k: torch.Tensor,
460
+ v: torch.Tensor,
461
+ o: torch.Tensor,
462
+ lse: torch.Tensor,
463
+ do: torch.Tensor,
464
+ scale: Optional[float] = None,
465
+ chunk_size: int = 128,
466
+ offsets: Optional[torch.LongTensor] = None,
467
+ indices: Optional[torch.LongTensor] = None,
468
+ ):
469
+ B, T, H, K, V = *k.shape, v.shape[-1]
470
+ HQ = q.shape[2]
471
+ G = HQ // H
472
+ BT = chunk_size
473
+ BS = max(16, triton.next_power_of_2(T))
474
+ BS = min(32, BS) if check_shared_mem('ampere') else min(16, BS)
475
+ BK = max(16, triton.next_power_of_2(K))
476
+ BV = max(16, triton.next_power_of_2(V))
477
+ NV = triton.cdiv(V, BV)
478
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
479
+
480
+ delta = parallel_attn_bwd_preprocess(o, do)
481
+
482
+ dq = torch.empty(B, T, HQ, K, dtype=k.dtype if H == HQ else torch.float, device=q.device)
483
+ dk = torch.empty(B, T, HQ, K, dtype=k.dtype if H == HQ else torch.float, device=q.device)
484
+ dv = torch.empty(B, T, HQ, V, dtype=v.dtype if H == HQ else torch.float, device=q.device)
485
+ grid = (NV, NT, B * HQ)
486
+ parallel_attn_bwd_kernel_dq[grid](
487
+ q=q,
488
+ k=k,
489
+ v=v,
490
+ lse=lse,
491
+ delta=delta,
492
+ do=do,
493
+ dq=dq,
494
+ offsets=offsets,
495
+ indices=indices,
496
+ scale=scale,
497
+ T=T,
498
+ B=B,
499
+ H=H,
500
+ HQ=HQ,
501
+ G=G,
502
+ K=K,
503
+ V=V,
504
+ BT=BT,
505
+ BS=BS,
506
+ BK=BK,
507
+ BV=BV
508
+ )
509
+ parallel_attn_bwd_kernel_dkv[grid](
510
+ q=q,
511
+ k=k,
512
+ v=v,
513
+ lse=lse,
514
+ delta=delta,
515
+ do=do,
516
+ dk=dk,
517
+ dv=dv,
518
+ offsets=offsets,
519
+ indices=indices,
520
+ scale=scale,
521
+ T=T,
522
+ B=B,
523
+ H=H,
524
+ HQ=HQ,
525
+ G=G,
526
+ K=K,
527
+ V=V,
528
+ BT=BT,
529
+ BS=BS,
530
+ BK=BK,
531
+ BV=BV
532
+ )
533
+ dk = reduce(dk, 'b t (h g) k -> b t h k', g=G, reduction='sum')
534
+ dv = reduce(dv, 'b t (h g) v -> b t h v', g=G, reduction='sum')
535
+ return dq, dk, dv
536
+
537
+
538
+ @torch.compile
539
+ class ParallelAttentionFunction(torch.autograd.Function):
540
+
541
+ @staticmethod
542
+ @contiguous
543
+ @autocast_custom_fwd
544
+ def forward(ctx, q, k, v, scale, offsets):
545
+ ctx.dtype = q.dtype
546
+
547
+ chunk_size = min(128, max(16, triton.next_power_of_2(q.shape[1])))
548
+ # 2-d indices denoting the offsets of chunks in each sequence
549
+ # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64,
550
+ # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be
551
+ # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
552
+ indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None
553
+
554
+ o, lse = parallel_attn_fwd(
555
+ q=q,
556
+ k=k,
557
+ v=v,
558
+ scale=scale,
559
+ chunk_size=chunk_size,
560
+ offsets=offsets,
561
+ indices=indices
562
+ )
563
+ ctx.save_for_backward(q, k, v, o, lse)
564
+ ctx.chunk_size = chunk_size
565
+ ctx.offsets = offsets
566
+ ctx.indices = indices
567
+ ctx.scale = scale
568
+ return o.to(q.dtype)
569
+
570
+ @staticmethod
571
+ @contiguous
572
+ @autocast_custom_bwd
573
+ def backward(ctx, do):
574
+ q, k, v, o, lse = ctx.saved_tensors
575
+ dq, dk, dv = parallel_attn_bwd(
576
+ q=q,
577
+ k=k,
578
+ v=v,
579
+ o=o,
580
+ lse=lse,
581
+ do=do,
582
+ scale=ctx.scale,
583
+ chunk_size=ctx.chunk_size,
584
+ offsets=ctx.offsets,
585
+ indices=ctx.indices
586
+ )
587
+ return dq.to(q), dk.to(k), dv.to(v), None, None, None, None, None, None, None, None
588
+
589
+
590
+ def parallel_attn(
591
+ q: torch.Tensor,
592
+ k: torch.Tensor,
593
+ v: torch.Tensor,
594
+ scale: Optional[float] = None,
595
+ cu_seqlens: Optional[torch.LongTensor] = None,
596
+ head_first: bool = False
597
+ ) -> torch.Tensor:
598
+ r"""
599
+ Args:
600
+ q (torch.Tensor):
601
+ queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
602
+ k (torch.Tensor):
603
+ keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
604
+ GQA will be applied if HQ is divisible by H.
605
+ v (torch.Tensor):
606
+ values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
607
+ scale (Optional[float]):
608
+ Scale factor for attention scores.
609
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
610
+ cu_seqlens (Optional[torch.LongTensor]):
611
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
612
+ consistent with the FlashAttention API.
613
+ head_first (Optional[bool]):
614
+ Whether the inputs are in the head-first format. Default: `False`.
615
+
616
+ Returns:
617
+ o (torch.Tensor):
618
+ Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
619
+ """
620
+ if scale is None:
621
+ scale = k.shape[-1] ** -0.5
622
+ if cu_seqlens is not None:
623
+ assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
624
+ if head_first:
625
+ q, k, v = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v))
626
+ o = ParallelAttentionFunction.apply(q, k, v, scale, cu_seqlens)
627
+ if head_first:
628
+ o = rearrange(o, 'b t h d -> b h t d')
629
+ return o
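The comment inside `ParallelAttentionFunction.forward` describes how `indices` is laid out for variable-length batches. The following is a minimal pure-Python sketch of that layout. The helper name `chunk_indices_sketch` is hypothetical; it is not the library's `prepare_chunk_indices`, whose implementation is not shown in this diff. It assumes each sequence delimited by consecutive `offsets` entries contributes `ceil(length / chunk_size)` pairs of `[sequence index, chunk index]`.

```python
from typing import List


def chunk_indices_sketch(offsets: List[int], chunk_size: int) -> List[List[int]]:
    """Return [sequence_index, chunk_index] pairs for packed sequences.

    Hypothetical illustration only; `offsets` plays the role of `cu_seqlens`.
    """
    pairs: List[List[int]] = []
    for seq, (bos, eos) in enumerate(zip(offsets[:-1], offsets[1:])):
        # ceiling division: number of chunks needed to cover this sequence
        num_chunks = -(-(eos - bos) // chunk_size)
        pairs.extend([seq, chunk] for chunk in range(num_chunks))
    return pairs


# the example from the comment: sequences of length 100 and 256, chunk size 64
example = chunk_indices_sketch([0, 100, 356], 64)
print(example)
```

Under these assumptions the number of pairs equals the number of chunk-level programs, which is consistent with the wrappers using `len(indices)` as `NT` when `offsets` is given.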
fla/ops/based/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .fused_chunk import fused_chunk_based
4
+ from .parallel import parallel_based
5
+
6
+ __all__ = [
7
+ 'fused_chunk_based',
8
+ 'parallel_based'
9
+ ]
fla/ops/based/fused_chunk.py ADDED
@@ -0,0 +1,374 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
11
+
12
+
13
+ @triton.jit(do_not_specialize=['T'])
14
+ def fused_chunk_based_fwd_kernel(
15
+ q,
16
+ k,
17
+ v,
18
+ o,
19
+ z,
20
+ scale, # K ** -0.5
21
+ T,
22
+ B: tl.constexpr,
23
+ H: tl.constexpr,
24
+ K: tl.constexpr,
25
+ V: tl.constexpr,
26
+ BT: tl.constexpr,
27
+ BK: tl.constexpr,
28
+ BV: tl.constexpr,
29
+ ):
30
+ # indices
31
+ i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
32
+
33
+ o_i = tl.arange(0, BT)
34
+
35
+ # [BT, BT]
36
+ m_s = o_i[:, None] >= o_i[None, :]
37
+
38
+ # [BV], zero-order taylor expansion
39
+ b_h_0o = tl.zeros([BV], dtype=tl.float32)
40
+ # [BK, BV], first-order taylor expansion
41
+ b_h_1o = tl.zeros([BK, BV], dtype=tl.float32)
42
+ # [BK, BK, BV] second-order taylor expansion
43
+ b_h_2o = tl.zeros([BK*BK, BV], dtype=tl.float32)
44
+
45
+ # make block pointers
46
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (0, i_k * BK), (BT, BK), (1, 0))
47
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, 0), (BK, BT), (0, 1))
48
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (0, i_v * BV), (BT, BV), (1, 0))
49
+ p_o = tl.make_block_ptr(o + (i_bh + i_k*B*H) * T*V, (T, V), (V, 1), (0, i_v * BV), (BT, BV), (1, 0))
50
+
51
+ p_z = z + (i_bh + i_k * B * H) * T + tl.arange(0, BT)
52
+ k_2o = tl.zeros([1, BK * BK], dtype=tl.float32)
53
+ k_1o = tl.zeros([1, BK], dtype=tl.float32)
54
+ k_0o = 0
55
+
56
+ for i in range(0, tl.cdiv(T, BT)):
57
+ # [BK, BT]
58
+ b_k = tl.load(p_k, boundary_check=(0, 1))
59
+ # [BK*BK, BT]
60
+ b_k_2o = b_k[:, None, :] * b_k[None, :, :]
61
+ b_k_2o = tl.reshape(b_k_2o, [BK * BK, BT]).to(b_k.dtype)
62
+ # [BT, BV]
63
+ b_v = tl.load(p_v, boundary_check=(0, 1))
64
+ # [BT, BK]
65
+ b_q = (tl.load(p_q, boundary_check=(0, 1)) * scale).to(b_k.dtype)
66
+ b_o = tl.zeros([BT, BV], dtype=tl.float32)
67
+ b_z = tl.zeros([BT], dtype=tl.float32)
68
+
69
+ # interchunk
70
+ # zero-order
71
+ b_o += b_h_0o
72
+ b_z += k_0o
73
+ # first-order
74
+ b_o += tl.dot(b_q, b_h_1o.to(b_q.dtype), allow_tf32=False)
75
+ b_z += tl.sum(b_q * k_1o, axis=1)
76
+ # second-order
77
+ b_q_2o = b_q[:, :, None] * b_q[:, None, :]
78
+ b_q_2o = tl.reshape(b_q_2o, [BT, BK * BK]).to(b_k.dtype)
79
+ b_o += tl.dot(b_q_2o, b_h_2o.to(b_q_2o.dtype), allow_tf32=False) * 0.5
80
+ b_z += tl.sum(b_q_2o * k_2o, axis=1) * 0.5
81
+
82
+ # update running statistics
83
+ k_1o += tl.sum(b_k, axis=1)[None, :]
84
+ k_2o += tl.sum(b_k_2o, axis=1)[None, :]
85
+ k_0o += BT
86
+
87
+ # intrachunk
88
+ # [BT, BT]
89
+ b_s = tl.dot(b_q, b_k, allow_tf32=False)
90
+ b_s = 1 + b_s + 0.5 * b_s * b_s
91
+ b_s = tl.where(m_s, b_s, 0)
92
+ b_z += tl.sum(b_s, axis=1)
93
+ b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)
94
+ # [BT, BV]
95
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
96
+ tl.store(p_z, b_z.to(p_z.dtype.element_ty), mask=(i * BT + tl.arange(0, BT)) < T)
97
+
98
+ # update hidden state
99
+ # [BK, BV]
100
+ b_h_2o = b_h_2o + tl.dot(b_k_2o.to(b_v.dtype), b_v, allow_tf32=False)
101
+ b_h_1o = b_h_1o + tl.dot(b_k, b_v, allow_tf32=False)
102
+ b_h_0o = b_h_0o + tl.sum(b_v, axis=0)
103
+
104
+ p_q = tl.advance(p_q, (BT, 0))
105
+ p_k = tl.advance(p_k, (0, BT))
106
+ p_v = tl.advance(p_v, (BT, 0))
107
+ p_o = tl.advance(p_o, (BT, 0))
108
+ p_z += BT
109
+
110
+
111
+ # Similar to Algorithm 1 of https://arxiv.org/abs/2006.16236
112
+ @triton.jit
113
+ def fused_chunk_based_bwd_kernel(
114
+ # NV: number of splits in the V dimension. NK: number of splits in the K dimension
115
+ q,
116
+ k,
117
+ v,
118
+ do,
119
+ dz,
120
+ dq,
121
+ dk,
122
+ dv,
123
+ scale, # K ** -0.5
124
+ T,
125
+ B: tl.constexpr,
126
+ H: tl.constexpr,
127
+ K: tl.constexpr,
128
+ V: tl.constexpr,
129
+ BT: tl.constexpr,
130
+ BK: tl.constexpr,
131
+ BV: tl.constexpr,
132
+ ):
133
+ i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
134
+
135
+ o_i = tl.arange(0, BT)
136
+ m_s = o_i[:, None] >= o_i[None, :]
137
+
138
+ # [BV], zero-order taylor expansion
139
+ # b_h_0o = tl.zeros([BV], dtype=tl.float32)
140
+ # [BV, BK], first-order taylor expansion (stored transposed in this pass)
141
+ b_h_1o = tl.zeros([BV, BK], dtype=tl.float32)
142
+ # [BV, BK*BK], second-order taylor expansion (stored transposed in this pass)
143
+ b_h_2o = tl.zeros([BV, BK*BK], dtype=tl.float32)
144
+
145
+ k_1o = tl.zeros([1, BK], dtype=tl.float32)
146
+ k_2o = tl.zeros([1, BK * BK], dtype=tl.float32)
147
+
148
+ for i in range(0, tl.cdiv(T, BT)):
149
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i * BT, i_k * BK), (BT, BK), (1, 0))
150
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i * BT, i_k * BK), (BT, BK), (1, 0))
151
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, i * BT), (BV, BT), (0, 1))
152
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i * BT, i_v * BV), (BT, BV), (1, 0))
153
+ p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * T*K, (T, K), (K, 1), (i*BT, i_k*BK), (BT, BK), (1, 0))
154
+ p_dz = dz + (i_bh) * T + tl.arange(0, BT) + i * BT
155
+ b_dq = tl.zeros([BT, BK], dtype=tl.float32)
156
+
157
+ # load tensors
158
+ # [BT, BK]
159
+ b_q = tl.load(p_q, boundary_check=(0, 1))
160
+ b_q = (b_q * scale).to(b_q.dtype)
161
+ b_k = tl.load(p_k, boundary_check=(0, 1))
162
+ b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
163
+ b_dz = tl.load(p_dz, mask=(tl.arange(0, BT) + i * BT) < T)
164
+ # [BV, BT]
165
+ b_v = tl.load(p_v, boundary_check=(0, 1))
166
+
167
+ # inter-chunk
168
+ b_dq += tl.dot(b_do, (b_h_1o).to(b_do.dtype), allow_tf32=False)
169
+ if i_v == 0:
170
+ b_dq += b_dz[:, None] * k_1o
171
+ b_dq_2o = tl.dot(b_do, (b_h_2o).to(b_do.dtype), allow_tf32=False) * 0.5
172
+ if i_v == 0:
173
+ b_dq_2o += (b_dz[:, None] * k_2o) * 0.5
174
+ b_dq_2o = tl.reshape(b_dq_2o, [BT, BK, BK])
175
+ b_dq += tl.sum(b_dq_2o * b_q[:, :, None], axis=1)
176
+ b_dq += tl.sum(b_dq_2o * b_q[:, None, :], axis=2)
177
+ b_dq *= scale
178
+
179
+ # intra-chunk
180
+ # [BT, BT]
181
+ b_ds = tl.dot(b_do, b_v, allow_tf32=False)
182
+ if i_v == 0:
183
+ b_ds += b_dz[:, None]
184
+ b_ds = tl.where(m_s, b_ds, 0) * scale
185
+ b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)
186
+ b_s = tl.where(m_s, b_s, 0)
187
+ b_dq += tl.dot((b_ds * (1 + b_s)).to(b_q.dtype), b_k, allow_tf32=False)
188
+
189
+ # store
190
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
191
+
192
+ # update hidden state
193
+ # [BT, BK*BK]
194
+ b_k_2o = b_k[:, :, None] * b_k[:, None, :]
195
+ b_k_2o = tl.reshape(b_k_2o, [BT, BK * BK]).to(b_k.dtype)
196
+ # [BV, BK*BK]
197
+ b_h_2o = b_h_2o + tl.dot(b_v, b_k_2o.to(b_v.dtype), allow_tf32=False)
198
+ # [BV, BK]
199
+ b_h_1o = b_h_1o + tl.dot(b_v, b_k, allow_tf32=False)
200
+
201
+ if i_v == 0:
202
+ # update running statistics
203
+ k_1o += tl.sum(b_k, axis=0)[None, :]
204
+ k_2o += tl.sum(b_k_2o, axis=0)[None, :]
205
+
206
+ tl.debug_barrier()
207
+ b_h_1o = None
208
+ b_h_2o = None
209
+
210
+ # [BK, BV], first-order taylor expansion
211
+ b_dh_1o = tl.zeros([BK, BV], dtype=tl.float32)
212
+ # [BK, BK, BV] second-order taylor expansion
213
+ b_dh_2o = tl.zeros([BK*BK, BV], dtype=tl.float32)
214
+ b_dh_0o = tl.zeros([BV], dtype=tl.float32)
215
+ m_s = tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :]
216
+
217
+ dq_1o = tl.zeros([1, BK], dtype=tl.float32)
218
+ dq_2o = tl.zeros([BK * BK, 1], dtype=tl.float32)
219
+
220
+ for i in range(tl.cdiv(T, BT) * BT - BT, -BT, -BT):
221
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, i), (BK, BT), (0, 1))
222
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i, i_k * BK), (BT, BK), (1, 0))
223
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i, i_v * BV), (BT, BV), (1, 0))
224
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i, i_v * BV), (BT, BV), (1, 0))
225
+ p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * T*K, (T, K), (K, 1), (i, i_k*BK), (BT, BK), (1, 0))
226
+ p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * T*V, (T, V), (V, 1), (i, i_v*BV), (BT, BV), (1, 0))
227
+ p_dz = dz + (i_bh) * T + tl.arange(0, BT) + i
228
+
229
+ b_dk = tl.zeros([BT, BK], dtype=tl.float32)
230
+ b_dv = tl.zeros([BT, BV], dtype=tl.float32)
231
+
232
+ b_q = tl.load(p_q, boundary_check=(0, 1))
233
+ b_k = tl.load(p_k, boundary_check=(0, 1))
234
+ b_v = tl.load(p_v, boundary_check=(0, 1))
235
+ b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
236
+ b_dz = tl.load(p_dz, mask=(tl.arange(0, BT)+i) < T)
237
+ b_q = (b_q * scale).to(b_k.dtype)
238
+
239
+ # intra chunk
240
+ b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False)
241
+ if i_v == 0:
242
+ b_ds += b_dz[None, :]
243
+ b_ds = tl.where(m_s, b_ds, 0)
244
+ b_s = tl.dot(b_k, b_q, allow_tf32=False)
245
+ b_s2 = 1 + b_s + 0.5 * b_s * b_s
246
+ b_s = tl.where(m_s, b_s, 0)
247
+ b_s2 = tl.where(m_s, b_s2, 0)
248
+ b_ds *= (1+b_s)
249
+
250
+ b_dk += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_q), allow_tf32=False)
251
+ b_dv += tl.dot(b_s2.to(b_do.dtype), b_do, allow_tf32=False)
252
+
253
+ # inter chunk
254
+ b_k_2o = b_k[:, :, None] * b_k[:, None, :]
255
+ b_k_2o = tl.reshape(b_k_2o, [BT, BK * BK]).to(b_k.dtype)
256
+
257
+ b_dv += tl.dot(b_k, b_dh_1o.to(b_k.dtype), allow_tf32=False)
258
+ b_dv += tl.dot(b_k_2o, b_dh_2o.to(b_k.dtype), allow_tf32=False)
259
+ b_dv += b_dh_0o
260
+
261
+ b_dk += tl.dot(b_v, tl.trans(b_dh_1o).to(b_k.dtype), allow_tf32=False)
262
+
263
+ if i_v == 0:
264
+ b_dk += dq_1o
265
+
266
+ b_dk_2o = tl.dot(b_dh_2o.to(b_k.dtype), tl.trans(b_v), allow_tf32=False)
267
+ if i_v == 0:
268
+ b_dk_2o += dq_2o
269
+ b_dk_2o = tl.reshape(b_dk_2o, [BK, BK, BT])
270
+ b_k_fp32 = tl.trans(b_k.to(tl.float32))
271
+ b_dk2 = tl.sum(b_dk_2o * b_k_fp32[:, None, :], axis=0)
272
+ b_dk2 += tl.sum(b_dk_2o * b_k_fp32[None, :, :], axis=1)
273
+ b_dk += tl.trans(b_dk2)
274
+
275
+ # hidden state update
276
+ b_dh_0o += tl.sum(b_do, axis=0)
277
+ b_dh_1o = b_dh_1o + tl.dot(b_q, b_do, allow_tf32=False)
278
+ b_q_2o = b_q[None, :, :] * b_q[:, None, :]
279
+ b_q_2o = tl.reshape(b_q_2o, [BK * BK, BT]).to(b_k.dtype)
280
+ b_dh_2o = b_dh_2o + tl.dot(b_q_2o, b_do, allow_tf32=False) * 0.5
281
+
282
+ if i_v == 0:
283
+ dq_1o += (tl.sum(b_dz[None, :] * b_q, axis=1))[None, :]
284
+ dq_2o += (tl.sum(b_dz[None, :] * b_q_2o, axis=1) * 0.5)[:, None]
285
+
286
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
287
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
288
+
289
+
290
+ class FusedChunkBasedFunction(torch.autograd.Function):
291
+
292
+ @staticmethod
293
+ @input_guard
294
+ @autocast_custom_fwd
295
+ def forward(ctx, q, k, v, scale=1):
296
+ B, H, T, K, V = *k.shape, v.shape[-1]
297
+
298
+ scale = scale
299
+ BT = 16
300
+ BK, BV = min(K, 16), min(V, 32)
301
+ BK, BV = max(BK, 16), max(BV, 16)
302
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
303
+
304
+ num_warps = 4
305
+
306
+ # the norm of o might explode, so we need to use float32 here
307
+ o = q.new_empty(NK, B, H, T, V, dtype=torch.float32)
308
+ z = q.new_empty(NK, B, H, T, dtype=torch.float32)
309
+
310
+ grid = (NV, NK, B * H)
311
+ fused_chunk_based_fwd_kernel[grid](
312
+ q, k, v, o, z,
313
+ scale,
314
+ T=T, B=B, H=H, K=K, V=V, BT=BT, BK=BK, BV=BV,
315
+ num_warps=num_warps,
316
+ )
317
+ o = o.sum(0)
318
+ z = z.sum(0)
319
+ ctx.save_for_backward(q, k, v)
320
+ ctx.scale = scale
321
+ return o.to(q.dtype), z.to(z.dtype)
322
+
323
+ @staticmethod
324
+ @input_guard
325
+ @autocast_custom_bwd
326
+ def backward(ctx, do, dz):
327
+ q, k, v = ctx.saved_tensors
328
+ B, H, T, K, V = *k.shape, v.shape[-1]
329
+ scale = ctx.scale
330
+
331
+ BT = 16
332
+ BK, BV = min(K, 16), min(V, 32)
333
+ BK, BV = max(BK, 16), max(BV, 16)
334
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
335
+ num_stages = 1
336
+ num_warps = 4
337
+
338
+ dq = q.new_empty(NV, B, H, T, K)
339
+ dk = q.new_empty(NV, B, H, T, K)
340
+ dv = q.new_empty(NK, B, H, T, V)
341
+ grid = (NV, NK, B * H)
342
+
343
+ fused_chunk_based_bwd_kernel[grid](
344
+ q, k, v, do, dz, dq, dk, dv,
345
+ scale,
346
+ T=T, B=B, H=H, K=K, V=V, BT=BT, BK=BK, BV=BV,
347
+ num_warps=num_warps,
348
+ num_stages=num_stages
349
+ )
350
+ dq = dq.sum(0)
351
+ dk = dk.sum(0)
352
+ dv = dv.sum(0)
353
+ return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None
354
+
355
+
356
+ def fused_chunk_based(
357
+ q: torch.Tensor,
358
+ k: torch.Tensor,
359
+ v: torch.Tensor,
360
+ scale: Optional[float] = None,
361
+ use_norm: bool = True,
362
+ head_first: bool = True
363
+ ):
364
+ assert q.shape[-1] <= 16, 'only support feature dimension up to 16.'
365
+ if scale is None:
366
+ scale = q.shape[-1] ** -0.5
367
+ if not head_first:
368
+ q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
369
+ o, z = FusedChunkBasedFunction.apply(q, k, v, scale)
370
+ if use_norm:
371
+ o = o / (z[..., None] + 1e-6)
372
+ if not head_first:
373
+ o = o.transpose(1, 2)
374
+ return o.to(q.dtype)
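The kernels above keep separate zero-, first- and second-order states. One way to see why this works is that the score `1 + s + 0.5 * s * s` (with `s` the scaled query-key dot product) can be written as an inner product of explicit feature maps. The sketch below checks that identity numerically in plain Python; the names `taylor_features` and `taylor_score` are illustrative and do not appear in the source.

```python
import math
import random


def taylor_features(x):
    """phi(x) = [1, x_i, x_i * x_j / sqrt(2)] over all index pairs (i, j)."""
    second = [a * b / math.sqrt(2.0) for a in x for b in x]
    return [1.0] + list(x) + second


def taylor_score(q, k):
    """1 + s + s^2 / 2 with s = q . k, the second-order expansion of exp(s)."""
    s = sum(a * b for a, b in zip(q, k))
    return 1.0 + s + 0.5 * s * s


random.seed(0)
q = [random.uniform(-1, 1) for _ in range(4)]
k = [random.uniform(-1, 1) for _ in range(4)]

# inner product of the feature maps vs. the closed-form score
lhs = sum(a * b for a, b in zip(taylor_features(q), taylor_features(k)))
rhs = taylor_score(q, k)
print(abs(lhs - rhs) < 1e-9)
```

The feature vector has `1 + K + K * K` entries, which is why the second-order state grows quadratically in the key dimension and why `fused_chunk_based` asserts a feature dimension of at most 16.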
fla/ops/based/naive.py ADDED
@@ -0,0 +1,72 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Optional
4
+
5
+ import torch
6
+ from einops import rearrange
7
+
8
+
9
+ def naive_parallel_based(
10
+ q: torch.Tensor,
11
+ k: torch.Tensor,
12
+ v: torch.Tensor,
13
+ scale: Optional[float] = None,
14
+ use_norm: bool = True
15
+ ):
16
+ if scale is None:
17
+ scale = q.shape[-1] ** -0.5
18
+ q = q * scale
19
+ attn = q @ k.transpose(-2, -1)
20
+ attn = 1 + attn + 1/2 * (attn ** 2)
21
+ attn.masked_fill_(~torch.tril(torch.ones(
22
+ q.shape[-2], q.shape[-2], dtype=torch.bool, device=q.device)), 0)
23
+ o = attn @ v
24
+ if use_norm:
25
+ z = attn.sum(-1)
26
+ return o / (z[..., None] + 1e-6)
27
+ else:
28
+ return o
29
+
30
+
31
+ def naive_chunk_based(q, k, v, chunk_size=256):
32
+ q = q * (q.shape[-1] ** -0.5)
33
+ # compute normalizer.
34
+ k_cumsum = torch.cumsum(k, dim=-2)
35
+ kk_cumsum = torch.cumsum(k.unsqueeze(-1) * k.unsqueeze(-2), dim=-3)
36
+ # first order
37
+ z = (q * k_cumsum).sum(-1)
38
+ # second order
39
+ z += (q.unsqueeze(-1) * q.unsqueeze(-2) * kk_cumsum).sum((-1, -2)) * 0.5
40
+ # zero-th order
41
+ z += (torch.arange(0, q.shape[-2]).to(z.device) * 1.0 + 1.0)[None, None, :]
42
+
43
+ # compute o
44
+ # constant term
45
+ _o = v.cumsum(-2)
46
+
47
+ q = rearrange(q, 'b h (n c) d -> b h n c d', c=chunk_size)
48
+
49
+ k = rearrange(k, 'b h (n c) d -> b h n c d', c=chunk_size)
50
+ v = rearrange(v, 'b h (n c) d -> b h n c d', c=chunk_size)
51
+
52
+ intra_chunk_attn = q @ k.transpose(-2, -1)
53
+ intra_chunk_attn = intra_chunk_attn + 1/2 * (intra_chunk_attn ** 2)
54
+ intra_chunk_attn.masked_fill_(~torch.tril(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device)), 0)
55
+ o = intra_chunk_attn @ v
56
+
57
+ # quadratic term
58
+ kv = torch.einsum('b h n c x, b h n c y, b h n c z -> b h n x y z', k, k, v)
59
+ kv = kv.cumsum(2)
60
+ kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2)
61
+
62
+ o += 0.5 * torch.einsum('b h n x y z, b h n c x, b h n c y -> b h n c z', kv, q, q)
63
+
64
+ # linear term
65
+ kv = torch.einsum('b h n c x, b h n c y -> b h n x y', k, v)
66
+ kv = kv.cumsum(2)
67
+ kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2)
68
+ o += torch.einsum('b h n x y, b h n c x -> b h n c y', kv, q)
69
+
70
+ o = rearrange(o, 'b h n c d -> b h (n c) d')
71
+ o = o + _o
72
+ return o / (z[..., None] + 1e-6)
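The two naive functions above compute the same quantity in a quadratic (masked) form and in a cumulative-state form. As a hedged, single-head illustration of why these agree, the sketch below compares a causal quadratic loop against a token-level recurrence over running sums. It is plain Python with hypothetical names (`based_parallel`, `based_recurrent`), assumes the query is already scaled, and does no chunking, so it is not a substitute for `naive_chunk_based`.

```python
import random


def based_parallel(q, k, v):
    """Causal quadratic form: o_t = sum_{s<=t} (1 + a + a^2/2) v_s, a = q_t . k_s."""
    T, V = len(q), len(v[0])
    out, norm = [], []
    for t in range(T):
        o_t, z_t = [0.0] * V, 0.0
        for s in range(t + 1):
            a = sum(x * y for x, y in zip(q[t], k[s]))
            w = 1.0 + a + 0.5 * a * a
            z_t += w
            o_t = [o + w * x for o, x in zip(o_t, v[s])]
        out.append(o_t)
        norm.append(z_t)
    return out, norm


def based_recurrent(q, k, v):
    """Same result from running sums of v, k v^T and (k x k) v^T."""
    T, K, V = len(q), len(q[0]), len(v[0])
    h0 = [0.0] * V                                # zero-order state
    h1 = [[0.0] * V for _ in range(K)]            # first-order state
    h2 = [[0.0] * V for _ in range(K * K)]        # second-order state
    k0, k1, k2 = 0.0, [0.0] * K, [0.0] * (K * K)  # normalizer statistics
    out, norm = [], []
    for t in range(T):
        kk = [a * b for a in k[t] for b in k[t]]
        qq = [a * b for a in q[t] for b in q[t]]
        # fold the current token into the states (the causal mask includes s == t)
        k0 += 1.0
        h0 = [a + b for a, b in zip(h0, v[t])]
        for i in range(K):
            k1[i] += k[t][i]
            for j in range(V):
                h1[i][j] += k[t][i] * v[t][j]
        for i in range(K * K):
            k2[i] += kk[i]
            for j in range(V):
                h2[i][j] += kk[i] * v[t][j]
        # read out with the query
        o_t = []
        for j in range(V):
            first = sum(q[t][i] * h1[i][j] for i in range(K))
            second = sum(qq[i] * h2[i][j] for i in range(K * K))
            o_t.append(h0[j] + first + 0.5 * second)
        z_t = k0 + sum(a * b for a, b in zip(q[t], k1))
        z_t += 0.5 * sum(a * b for a, b in zip(qq, k2))
        out.append(o_t)
        norm.append(z_t)
    return out, norm


random.seed(0)
T, K, V = 5, 3, 2
q = [[random.uniform(-1, 1) for _ in range(K)] for _ in range(T)]
k = [[random.uniform(-1, 1) for _ in range(K)] for _ in range(T)]
v = [[random.uniform(-1, 1) for _ in range(V)] for _ in range(T)]
o_p, z_p = based_parallel(q, k, v)
o_r, z_r = based_recurrent(q, k, v)
max_err = max(abs(a - b) for row_p, row_r in zip(o_p, o_r) for a, b in zip(row_p, row_r))
max_err = max(max_err, max(abs(a - b) for a, b in zip(z_p, z_r)))
print(max_err < 1e-9)
```

Dividing each `o_t` by `z_t + 1e-6` would then give the normalized output that the functions above return when `use_norm` is enabled.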
fla/ops/based/parallel.py ADDED
@@ -0,0 +1,410 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
11
+
12
+ # Based: An Educational and Effective Sequence Mixer
13
+ # https://hazyresearch.stanford.edu/blog/2023-12-11-zoology2-based
14
+
15
+
16
+ @triton.jit(do_not_specialize=['T'])
17
+ def parallel_based_fwd_kernel(
18
+ q,
19
+ k,
20
+ v,
21
+ o,
22
+ z,
23
+ scale,
24
+ T,
25
+ B: tl.constexpr,
26
+ H: tl.constexpr,
27
+ K: tl.constexpr,
28
+ V: tl.constexpr,
29
+ BTL: tl.constexpr,
30
+ BTS: tl.constexpr,
31
+ BK: tl.constexpr,
32
+ BV: tl.constexpr,
33
+ ):
34
+ # i_c: chunk index. used for sequence parallelism
35
+ i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
36
+ NV = tl.cdiv(V, BV)
37
+ i_k = i_kv // (NV)
38
+ i_v = i_kv % (NV)
39
+
40
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0))
41
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, 0), (BK, BTS), (0, 1))
42
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (0, i_v * BV), (BTS, BV), (1, 0))
43
+
44
+ # [BQ, BD] block Q, in the shared memory throughout the whole kernel
45
+ b_q = tl.load(p_q, boundary_check=(0, 1))
46
+ b_q = (b_q * scale).to(b_q.dtype)
47
+ b_o = tl.zeros([BTL, BV], dtype=tl.float32)
48
+ b_z = tl.zeros([BTL], dtype=tl.float32)
49
+
50
+ # Q block and K block have no overlap
51
+ # no need for mask, thereby saving flops
52
+ for _ in range(0, i_c * BTL, BTS):
53
+ # [BK, BTS]
54
+ b_k = tl.load(p_k, boundary_check=(0, 1))
55
+
56
+ # [BTS, BV]
57
+ b_v = tl.load(p_v, boundary_check=(0, 1))
58
+ # [BTL, BTS]
59
+ b_s = tl.dot(b_q, (b_k), allow_tf32=False)
60
+ b_s = 1 + b_s + 0.5 * b_s * b_s
61
+ b_z += tl.sum(b_s, axis=1)
62
+
63
+ # [BTL, BV]
64
+ b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)
65
+ p_k = tl.advance(p_k, (0, BTS))
66
+ p_v = tl.advance(p_v, (BTS, 0))
67
+
68
+ # inter-chunk pass done; sync before the masked intra-chunk pass
69
+ tl.debug_barrier()
70
+ o_q = tl.arange(0, BTL)
71
+ # # sync threads, easy for compiler to optimize
72
+ # tl.debug_barrier()
73
+
74
+ o_k = tl.arange(0, BTS)
75
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_c * BTL), (BK, BTS), (0, 1))
76
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_c * BTL, i_v * BV), (BTS, BV), (1, 0))
77
+ # Q block and K block have overlap. masks required
78
+ for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS):
79
+ # [BK, BTS]
80
+ b_k = tl.load(p_k, boundary_check=(0, 1))
81
+ # [BTS, BV]
82
+ b_v = tl.load(p_v, boundary_check=(0, 1))
83
+ # [BTL, BTS]
84
+ m_s = o_q[:, None] >= o_k[None, :]
85
+ b_s = tl.dot(b_q, b_k, allow_tf32=False)
86
+ b_s = 1 + b_s + 0.5 * b_s * b_s
87
+ b_s = tl.where(m_s, b_s, 0)
88
+ b_z += tl.sum(b_s, axis=1)
89
+ # [BTL, BV]
90
+ b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)
91
+
92
+ p_k = tl.advance(p_k, (0, BTS))
93
+ p_v = tl.advance(p_v, (BTS, 0))
94
+ o_k += BTS
95
+
96
+ p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * T*V, (T, V), (V, 1), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))
97
+ p_z = z + (i_bh + B * H * i_k) * T + i_c * BTL + tl.arange(0, BTL)
98
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
99
+ tl.store(p_z, b_z.to(p_z.dtype.element_ty), mask=((i_c * BTL + tl.arange(0, BTL)) < T))
100
+
101
+
102
+ @triton.jit
103
+ def _parallel_based_bwd_dq(
104
+ i_bh,
105
+ i_c,
106
+ i_k,
107
+ i_v,
108
+ q,
109
+ k,
110
+ v,
111
+ do,
112
+ dz,
113
+ dq,
114
+ scale,
115
+ T,
116
+ B: tl.constexpr,
117
+ H: tl.constexpr,
118
+ BTL: tl.constexpr,
119
+ BTS: tl.constexpr,
120
+ BK: tl.constexpr,
121
+ BV: tl.constexpr,
122
+ K: tl.constexpr,
123
+ V: tl.constexpr,
124
+ ):
125
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_c * BTL, i_v * BV), (BTL, BV), (1, 0))
126
+ p_q = tl.make_block_ptr(q + (i_bh) * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
127
+ b_q = tl.load(p_q, boundary_check=(0, 1))
128
+ b_q = (b_q * scale).to(b_q.dtype)
129
+
130
+ b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
131
+ b_dq = tl.zeros([BTL, BK], dtype=tl.float32)
132
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (0, i_k * BK), (BTS, BK), (1, 0))
133
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, 0), (BV, BTS), (0, 1))
134
+ p_dz = dz + i_bh * T + i_c * BTL + tl.arange(0, BTL)
135
+ b_dz = tl.load(p_dz, mask=(i_c * BTL + tl.arange(0, BTL)) < T)
136
+
137
+ for _ in range(0, i_c * BTL, BTS):
138
+ # [BTS, BK]
139
+ b_k = tl.load(p_k, boundary_check=(0, 1))
140
+ # [BV, BTS]
141
+ b_v = tl.load(p_v, boundary_check=(0, 1))
142
+ # [BTL, BTS]
143
+ b_ds = tl.dot(b_do, b_v, allow_tf32=False)
144
+ if i_v == 0:
145
+ b_ds += b_dz[:, None]
146
+ else:
147
+ b_ds = b_ds
148
+ b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)
149
+ # [BTL, BK]
150
+ b_dq += tl.dot((b_ds * (1 + b_s)).to(b_v.dtype), b_k, allow_tf32=False)
151
+ p_k = tl.advance(p_k, (BTS, 0))
152
+ p_v = tl.advance(p_v, (0, BTS))
153
+
154
+ b_dq *= scale
155
+ o_q = tl.arange(0, BTL)
156
+ o_k = tl.arange(0, BTS)
157
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_c * BTL, i_k * BK), (BTS, BK), (1, 0))
158
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, i_c * BTL), (BV, BTS), (0, 1))
159
+ # Q block and K block have overlap. masks required
160
+ for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS):
161
+ # [BTS, BK]
162
+ b_k = tl.load(p_k, boundary_check=(0, 1))
163
+ # [BV, BTS]
164
+ b_v = tl.load(p_v, boundary_check=(0, 1))
165
+ # [BTL, BTS]
166
+ m_s = o_q[:, None] >= o_k[None, :]
167
+ b_ds = tl.dot(b_do, b_v, allow_tf32=False)
168
+ if i_v == 0:
169
+ b_ds += b_dz[:, None]
170
+ else:
171
+ b_ds = b_ds
172
+ b_ds = tl.where(m_s, b_ds, 0) * scale
173
+ b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)
174
+ b_s = tl.where(m_s, b_s, 0)
175
+ # [BTL, BK]
176
+ b_dq += tl.dot((b_ds + b_ds * b_s).to(b_k.dtype), b_k, allow_tf32=False)
177
+ p_k = tl.advance(p_k, (BTS, 0))
178
+ p_v = tl.advance(p_v, (0, BTS))
179
+ o_k += BTS
180
+ p_dq = tl.make_block_ptr(dq + (i_bh + B * H * i_v) * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
181
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
182
+ return
183
+
184
+
185
+ @triton.jit
186
+ def _parallel_based_bwd_dkv(
187
+ i_bh,
188
+ i_c,
189
+ i_k,
190
+ i_v,
191
+ q,
192
+ k,
193
+ v,
194
+ do,
195
+ dz,
196
+ dk,
197
+ dv,
198
+ scale,
199
+ T,
200
+ B: tl.constexpr,
201
+ H: tl.constexpr,
202
+ BTL: tl.constexpr,
203
+ BTS: tl.constexpr,
204
+ BK: tl.constexpr,
205
+ BV: tl.constexpr,
206
+ K: tl.constexpr,
207
+ V: tl.constexpr,
208
+ ):
209
+ # compute dk dv
210
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0))
211
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_c * BTL, i_v * BV), (BTL, BV), (1, 0))
212
+ b_k, b_v = tl.load(p_k, boundary_check=(0, 1)), tl.load(p_v, boundary_check=(0, 1))
213
+ b_dk, b_dv = tl.zeros([BTL, BK], dtype=tl.float32), tl.zeros([BTL, BV], dtype=tl.float32)
214
+
215
+ for i in range((tl.cdiv(T, BTS) * BTS)-BTS, (i_c + 1) * BTL - BTS, -BTS):
216
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, i), (BK, BTS), (0, 1))
217
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (V, T), (1, V), (i_v * BV, i), (BV, BTS), (0, 1))
218
+ p_dz = dz + i_bh * T + i + tl.arange(0, BTS)
219
+ b_q = tl.load(p_q, boundary_check=(0, 1)) # [BK, BTS]
220
+ b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) # [BV, BTS]
221
+ b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T)
222
+ b_s = tl.dot(b_k.to(b_q.dtype), b_q, allow_tf32=False) * scale # [BTL, BTS]
223
+ b_s2 = 1 + b_s + 0.5 * b_s * b_s
224
+ b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False)
225
+ b_ds = tl.dot(b_v, b_do, allow_tf32=False) * scale
226
+ if i_v == 0:
227
+ b_ds += b_dz[None, :] * scale
228
+ else:
229
+ b_ds = b_ds
230
+ b_dk += tl.dot((b_ds + b_ds * b_s).to(b_q.dtype), tl.trans(b_q), allow_tf32=False)
231
+
232
+ tl.debug_barrier()
233
+ o_q, o_k = tl.arange(0, BTS), tl.arange(0, BTL)
234
+ for i in range(i_c*BTL, (i_c+1)*BTL, BTS):
235
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, i), (BK, BTS), (0, 1))
236
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (V, T), (1, V), (i_v * BV, i), (BV, BTS), (0, 1))
237
+ p_dz = dz + i_bh * T + i + tl.arange(0, BTS)
238
+ b_q = tl.load(p_q, boundary_check=(0, 1)) # [BK, BTS]
239
+ b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
240
+ b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T)
241
+ # [BTL, BTS]
242
+ m_s = o_k[:, None] <= o_q[None, :]
243
+ b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale
244
+ b_s2 = 1 + b_s + 0.5 * b_s * b_s
245
+ b_s = tl.where(m_s, b_s, 0)
246
+ b_s2 = tl.where(m_s, b_s2, 0)
247
+
248
+ b_ds = tl.dot(b_v, b_do, allow_tf32=False)
249
+ if i_v == 0:
250
+ b_ds += b_dz[None, :]
251
+ else:
252
+ b_ds = b_ds
253
+ b_ds = tl.where(m_s, b_ds, 0) * scale
254
+ # b_dv: [BTL, BV], b_dk: [BTL, BK]
255
+ b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False)
256
+ b_dk += tl.dot((b_ds + b_ds * b_s).to(b_q.dtype), tl.trans(b_q), allow_tf32=False)
257
+ o_q += BTS
258
+
259
+ p_dk = tl.make_block_ptr(dk + (i_bh + B * H * i_v) * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
260
+ p_dv = tl.make_block_ptr(dv + (i_bh + B * H * i_k) * T*V, (T, V), (V, 1), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))
261
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
262
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
263
+ return
264
+
265
+
266
+ @triton.jit(do_not_specialize=['T'])
267
+ def parallel_based_bwd_kernel(
268
+ q,
269
+ k,
270
+ v,
271
+ do,
272
+ dz,
273
+ dq,
274
+ dk,
275
+ dv,
276
+ scale,
277
+ T,
278
+ B: tl.constexpr,
279
+ H: tl.constexpr,
280
+ K: tl.constexpr,
281
+ V: tl.constexpr,
282
+ BTL: tl.constexpr,
283
+ BTS: tl.constexpr,
284
+ BK: tl.constexpr,
285
+ BV: tl.constexpr,
286
+ ):
287
+ i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
288
+ NV = tl.cdiv(V, BV)
289
+ i_k = i_kv // (NV)
290
+ i_v = i_kv % NV
291
+ _parallel_based_bwd_dq(
292
+ i_bh, i_c, i_k, i_v,
293
+ q, k, v, do, dz, dq,
294
+ scale, T, B, H, BTL, BTS, BK, BV, K, V
295
+ )
296
+ tl.debug_barrier()
297
+ _parallel_based_bwd_dkv(
298
+ i_bh, i_c, i_k, i_v,
299
+ q, k, v, do, dz, dk, dv,
300
+ scale, T, B, H, BTL, BTS, BK, BV, K, V
301
+ )
302
+
303
+
304
+ class ParallelBasedFunction(torch.autograd.Function):
305
+
306
+ @staticmethod
307
+ @input_guard
308
+ @autocast_custom_fwd
309
+ def forward(ctx, q, k, v, scale):
310
+ BTL, BTS = 128, 32
311
+ assert BTL % BTS == 0
312
+ # assert q.shape[-1] % 16 == 0
313
+ BK = min(128, triton.next_power_of_2(k.shape[-1]))
314
+ BV = min(128, triton.next_power_of_2(v.shape[-1]))
315
+ BK, BV = max(BK, 16), max(BV, 16)
316
+ B, H, T, K, V = *k.shape, v.shape[-1]
317
+ num_stages = 2
318
+ num_warps = 4
319
+ NK = triton.cdiv(K, BK)
320
+ NV = triton.cdiv(V, BV)
321
+ grid = (NK * NV, triton.cdiv(T, BTL), B * H)
322
+
323
+ assert NK == 1, "K must fit in one block (NK == 1); splitting K across programs causes synchronization issues"
324
+
325
+ o = torch.empty(NK, B, H, T, V, device=q.device)
326
+ z = torch.empty(NK, B, H, T, device=q.device)
327
+ parallel_based_fwd_kernel[grid](
328
+ q, k, v, o, z,
329
+ scale,
330
+ B=B,
331
+ H=H,
332
+ T=T,
333
+ K=K,
334
+ V=V,
335
+ BTL=BTL,
336
+ BTS=BTS,
337
+ BK=BK,
338
+ BV=BV,
339
+ num_warps=num_warps,
340
+ num_stages=num_stages
341
+ )
342
+ ctx.save_for_backward(q, k, v)
343
+ ctx.scale = scale
344
+ return o.sum(0).to(q.dtype), z.sum(0).to(q.dtype)
345
+
346
+ @staticmethod
347
+ @input_guard
348
+ @autocast_custom_bwd
349
+ def backward(ctx, do, dz):
350
+ q, k, v = ctx.saved_tensors
351
+ scale = ctx.scale
352
+ BTL, BTS = 64, 32
353
+ assert BTL % BTS == 0
354
+ BK = min(128, triton.next_power_of_2(k.shape[-1]))
355
+ BV = min(128, triton.next_power_of_2(v.shape[-1]))
356
+ BK, BV = max(BK, 16), max(BV, 16)
357
+ B, H, T, K, V = *k.shape, v.shape[-1]
358
+ num_stages = 2
359
+ num_warps = 4
360
+ NK = triton.cdiv(K, BK)
361
+ NV = triton.cdiv(V, BV)
362
+ grid = (NK * NV, triton.cdiv(T, BTL), B * H)
363
+
364
+ assert NK == 1, "K must fit in one block (NK == 1); splitting K across programs causes synchronization issues"
365
+
366
+ dq = torch.empty(NV, B, H, T, K, dtype=q.dtype, device=q.device)
367
+ dk = torch.empty(NV, B, H, T, K, dtype=q.dtype, device=q.device)
368
+ dv = torch.empty(NK, B, H, T, V, dtype=q.dtype, device=q.device)
369
+
370
+ parallel_based_bwd_kernel[grid](
371
+ q, k, v, do, dz, dq, dk, dv,
372
+ scale,
373
+ B=B,
374
+ H=H,
375
+ T=T,
376
+ K=K,
377
+ V=V,
378
+ BTL=BTL,
379
+ BTS=BTS,
380
+ BK=BK,
381
+ BV=BV,
382
+ num_warps=num_warps,
383
+ num_stages=num_stages
384
+ )
385
+
386
+ return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype), None
387
+
388
+
389
+ triton_parallel_based = ParallelBasedFunction.apply
390
+
391
+
+ def parallel_based(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     scale: Optional[float] = None,
+     use_norm: bool = True,
+     head_first: bool = True
+ ):
+     assert q.shape[-1] <= 128, "only supports feature dims up to 128"
+     if scale is None:
+         scale = q.shape[-1] ** -0.5
+     if not head_first:
+         q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
+     o, z = triton_parallel_based(q, k, v, scale)
+     if use_norm:
+         o = o / (z[..., None] + 1e-6)
+     if not head_first:
+         o = o.transpose(1, 2)
+     return o.to(q.dtype)
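The kernels above are easier to check against a naive reference. The following is a minimal sketch, not part of the commit, of what `parallel_based` appears to compute for a single head: causal attention whose weights are the second-order Taylor expansion `1 + s + 0.5 * s * s` of `exp(s)`, with `s = scale * (q_i . k_j)`, optionally normalized by the running sum `z` plus `1e-6`. The function name `taylor_causal_attention` and the plain-list inputs are assumptions made for illustration. It is pure Python and O(T^2), and it does not reproduce the dtype casts done inside the Triton kernels.

```python
def taylor_causal_attention(q, k, v, scale=None, use_norm=True, eps=1e-6):
    """Naive O(T^2) reference for one head (hypothetical helper).

    q, k: T rows of K floats; v: T rows of V floats.
    Returns T rows of V floats.
    """
    T, K, V = len(q), len(q[0]), len(v[0])
    if scale is None:
        scale = K ** -0.5  # same default as the wrapper above
    out = []
    for i in range(T):
        acc = [0.0] * V
        z = 0.0
        for j in range(i + 1):  # causal: only keys at or before the query
            s = scale * sum(a * b for a, b in zip(q[i], k[j]))
            w = 1.0 + s + 0.5 * s * s  # 2nd-order Taylor expansion of exp(s)
            z += w
            for d in range(V):
                acc[d] += w * v[j][d]
        if use_norm:
            acc = [a / (z + eps) for a in acc]
        out.append(acc)
    return out


q = [[1.0, 0.0], [0.0, 2.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[2.0], [4.0]]
out = taylor_causal_attention(q, k, v, scale=1.0)
```

Because `1 + s + 0.5 * s * s` equals `((1 + s) ** 2 + 1) / 2`, every weight is strictly positive, so the normalizer `z` cannot vanish. The derivative of the weight with respect to `s` is `1 + s`, which matches the `b_ds * (1 + b_s)` factor in the backward kernels.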
fla/ops/common/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # -*- coding: utf-8 -*-
fla/ops/common/chunk_h.py ADDED
@@ -0,0 +1,422 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.common.utils import prepare_chunk_offsets
11
+ from fla.ops.utils.op import exp
12
+ from fla.utils import check_shared_mem
13
+
14
+ BKV_LIST = [32, 64] if check_shared_mem() else [16, 32]
15
+
16
+
17
+ @triton.heuristics({
18
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
19
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
20
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
21
+ })
22
+ @triton.autotune(
23
+ configs=[
24
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
25
+ for BK in BKV_LIST
26
+ for BV in BKV_LIST
27
+ for num_warps in [1, 2, 4, 8]
28
+ for num_stages in [2, 3, 4]
29
+ ],
30
+ key=['BT', 'USE_G', 'USE_GK', 'USE_GV']
31
+ )
32
+ @triton.jit(do_not_specialize=['T'])
33
+ def chunk_fwd_kernel_h(
34
+ k,
35
+ v,
36
+ h,
37
+ g,
38
+ gk,
39
+ gv,
40
+ h0,
41
+ ht,
42
+ offsets,
43
+ split_offsets,
44
+ T,
45
+ H: tl.constexpr,
46
+ K: tl.constexpr,
47
+ V: tl.constexpr,
48
+ BT: tl.constexpr,
49
+ BS: tl.constexpr,
50
+ BK: tl.constexpr,
51
+ BV: tl.constexpr,
52
+ USE_G: tl.constexpr,
53
+ USE_GK: tl.constexpr,
54
+ USE_GV: tl.constexpr,
55
+ USE_INITIAL_STATE: tl.constexpr,
56
+ STORE_FINAL_STATE: tl.constexpr,
57
+ USE_OFFSETS: tl.constexpr,
58
+ HEAD_FIRST: tl.constexpr
59
+ ):
60
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
61
+ i_n, i_h = i_nh // H, i_nh % H
62
+ if USE_OFFSETS:
63
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
64
+ T = eos - bos
65
+ NT = tl.cdiv(T, BT)
66
+ NS = tl.cdiv(T, BS)
67
+ boh = tl.load(split_offsets + i_n).to(tl.int32)
68
+ else:
69
+ bos, eos = i_n * T, i_n * T + T
70
+ NT = tl.cdiv(T, BT)
71
+ NS = tl.cdiv(T, BS)
72
+ boh = i_n * NS
73
+
74
+ # [BK, BV]
75
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
76
+ if USE_INITIAL_STATE:
77
+ p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
78
+ b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
79
+
80
+ for i_t in range(NT):
81
+ i_s = i_t // (BS // BT)
82
+ if HEAD_FIRST:
83
+ p_k = tl.make_block_ptr(k + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
84
+ p_v = tl.make_block_ptr(v + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
85
+
86
+ o_h = (i_nh * NS + i_s).to(tl.int64) * K*V
87
+ p_h = tl.make_block_ptr(h + o_h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
88
+ else:
89
+ p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
90
+ p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
91
+
92
+ o_h = ((boh + i_s) * H + i_h).to(tl.int64) * K*V
93
+ p_h = tl.make_block_ptr(h + o_h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
94
+
95
+ if i_t % (BS // BT) == 0:
96
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
97
+ # [BK, BT]
98
+ b_k = tl.load(p_k, boundary_check=(0, 1))
99
+ # [BT, BV]
100
+ b_v = tl.load(p_v, boundary_check=(0, 1))
101
+ last_idx = min((i_t + 1) * BT, T) - 1
102
+
103
+ # scalar decay
104
+ if USE_G:
105
+ if HEAD_FIRST:
106
+ b_g_last = tl.load(g + i_nh * T + last_idx)
107
+ p_g = g + i_nh * T + i_t * BT + tl.arange(0, BT)
108
+ p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT)
109
+ else:
110
+ b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
111
+ p_g = g + bos*H + (i_t * BT + tl.arange(0, BT)) * H + i_h
112
+ b_h *= exp(b_g_last)
113
+ b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.)
114
+ b_v = (b_v * exp(b_g_last - b_g)[:, None]).to(b_v.dtype)
115
+
116
+ # vector decay, h = Diag(gk) @ h
117
+ if USE_GK:
118
+ if HEAD_FIRST:
119
+ p_gk = tl.make_block_ptr(gk + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
120
+ p_gk_last = gk + i_nh * T*K + last_idx * K + i_k * BK + tl.arange(0, BK)
121
+ p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
122
+ else:
123
+ p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
124
+ p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
125
+
126
+ b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
127
+ b_h *= exp(b_gk_last)[:, None]
128
+
129
+ b_gk = tl.load(p_gk, boundary_check=(0, 1))
130
+ b_k = (b_k * exp(b_gk_last[:, None] - b_gk)).to(b_k.dtype)
131
+
132
+ # vector decay, h = h @ Diag(gv)
133
+ if USE_GV:
134
+ if HEAD_FIRST:
135
+ p_gv = tl.make_block_ptr(gv + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
136
+ p_gv_last = gv + i_nh * T*V + last_idx * V + i_v * BV + tl.arange(0, BV)
137
+ p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
138
+ else:
139
+ p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
140
+ p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
141
+
142
+ b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
143
+ b_h *= exp(b_gv_last)[None, :]
144
+
145
+ b_gv = tl.load(p_gv, boundary_check=(0, 1))
146
+ b_v = (b_v * exp(b_gv_last[None, :] - b_gv)).to(b_v.dtype)
147
+
148
+ b_h += tl.dot(b_k, b_v)
149
+
150
+ if STORE_FINAL_STATE:
151
+ p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
152
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
153
+
154
+
155
+ @triton.heuristics({
156
+ 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None,
157
+ 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
158
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
159
+ })
160
+ @triton.autotune(
161
+ configs=[
162
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
163
+ for BK in BKV_LIST
164
+ for BV in BKV_LIST
165
+ for num_warps in [1, 2, 4, 8]
166
+ for num_stages in [2, 3, 4]
167
+ ],
168
+ key=['BT', 'USE_G', 'USE_GK', 'USE_GV']
169
+ )
170
+ @triton.jit(do_not_specialize=['T'])
171
+ def chunk_bwd_kernel_dh(
172
+ q,
173
+ g,
174
+ gk,
175
+ gv,
176
+ do,
177
+ dh,
178
+ dht,
179
+ dh0,
180
+ offsets,
181
+ split_offsets,
182
+ scale,
183
+ T,
184
+ HQ: tl.constexpr,
185
+ H: tl.constexpr,
186
+ K: tl.constexpr,
187
+ V: tl.constexpr,
188
+ BT: tl.constexpr,
189
+ BS: tl.constexpr,
190
+ BK: tl.constexpr,
191
+ BV: tl.constexpr,
192
+ NG: tl.constexpr,
193
+ USE_G: tl.constexpr,
194
+ USE_GK: tl.constexpr,
195
+ USE_GV: tl.constexpr,
196
+ STORE_INITIAL_STATE_GRADIENT: tl.constexpr,
197
+ USE_FINAL_STATE_GRADIENT: tl.constexpr,
198
+ USE_OFFSETS: tl.constexpr,
199
+ HEAD_FIRST: tl.constexpr
200
+ ):
201
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
202
+ i_bg = i_nh // NG
203
+ i_n, i_hq = i_nh // HQ, i_nh % HQ
204
+ i_h = i_hq // NG
205
+ if USE_OFFSETS:
206
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
207
+ T = eos - bos
208
+ NT = tl.cdiv(T, BT)
209
+ NS = tl.cdiv(T, BS)
210
+ boh = tl.load(split_offsets + i_n).to(tl.int32)
211
+ else:
212
+ bos, eos = i_n * T, i_n * T + T
213
+ NT = tl.cdiv(T, BT)
214
+ NS = tl.cdiv(T, BS)
215
+ boh = i_n * NS
216
+
217
+ # [BK, BV]
218
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
219
+ if USE_FINAL_STATE_GRADIENT:
220
+ p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
221
+ b_dh += tl.load(p_dht, boundary_check=(0, 1)).to(tl.float32)
222
+
223
+ for i_t in range(NT - 1, -1, -1):
224
+ i_s = i_t // (BS // BT)
225
+ if HEAD_FIRST:
226
+ o_dh = (i_nh * NS + i_s).to(tl.int64) * K*V
227
+ p_dh = tl.make_block_ptr(dh + o_dh, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
228
+ else:
229
+ o_dh = ((boh + i_s) * H + i_h).to(tl.int64) * K*V
230
+ p_dh = tl.make_block_ptr(dh + o_dh, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
231
+
232
+ if i_t % (BS // BT) == 0:
233
+ tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
234
+ last_idx = min(i_t * BT + BT, T) - 1
235
+ # [BK, BT]
236
+ if HEAD_FIRST:
237
+ p_q = tl.make_block_ptr(q + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
238
+ p_do = tl.make_block_ptr(do + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
239
+ else:
240
+ p_q = tl.make_block_ptr(q + (bos*HQ + i_hq) * K, (K, T), (1, HQ*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
241
+ p_do = tl.make_block_ptr(do + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
242
+ b_q = tl.load(p_q, boundary_check=(0, 1))
243
+ b_q = (b_q * scale).to(b_q.dtype)
244
+ # [BT, BV]
245
+ b_do = tl.load(p_do, boundary_check=(0, 1))
246
+
247
+ if USE_G:
248
+ if HEAD_FIRST:
249
+ p_g = g + i_bg * T + i_t * BT + tl.arange(0, BT)
250
+ p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT)
251
+ b_g_last = tl.load(g + i_bg * T + last_idx)
252
+ else:
253
+ p_g = g + (bos + i_t * BT + tl.arange(0, BT)) * H + i_h
254
+ b_g_last = tl.load(g + (bos + last_idx) * H + i_h)
255
+ b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.)
256
+ b_q = (b_q * exp(b_g)[None, :]).to(b_q.dtype)
257
+
258
+ b_dh *= exp(b_g_last)
259
+
260
+ if USE_GK:
261
+ if HEAD_FIRST:
262
+ p_gk = tl.make_block_ptr(gk + i_bg * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
263
+ p_gk_last = gk + (i_bg * T + last_idx) * K + i_k * BK + tl.arange(0, BK)
264
+ p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
265
+ else:
266
+ p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
267
+ p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
268
+
269
+ b_gk = tl.load(p_gk, boundary_check=(0, 1))
270
+ b_q = (b_q * exp(b_gk)).to(b_q.dtype)
271
+ b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
272
+ b_dh *= exp(b_gk_last)[:, None]
273
+
274
+ if USE_GV:
275
+ if HEAD_FIRST:
276
+ p_gv = tl.make_block_ptr(gv + i_bg * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
277
+ p_gv_last = gv + (i_bg * T + last_idx) * V + i_v * BV + tl.arange(0, BV)
278
+ p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
279
+ else:
280
+ p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
281
+ p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
282
+
283
+ b_gv = tl.load(p_gv, boundary_check=(0, 1))
284
+ b_do = (b_do * exp(b_gv)).to(b_do.dtype)
285
+
286
+ b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
287
+ b_dh *= exp(b_gv_last)[None, :]
288
+
289
+ b_dh += tl.dot(b_q, b_do)
290
+
291
+ if STORE_INITIAL_STATE_GRADIENT:
292
+ p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
293
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
294
+
295
+
296
+ def chunk_fwd_h(
297
+ k: torch.Tensor,
298
+ v: torch.Tensor,
299
+ g: torch.Tensor,
300
+ gk: torch.Tensor,
301
+ gv: torch.Tensor,
302
+ h0: torch.Tensor,
303
+ output_final_state: bool,
304
+ offsets: Optional[torch.Tensor] = None,
305
+ head_first: bool = True,
306
+ chunk_size: int = 64,
307
+ split_size: Optional[int] = None,
308
+ states_in_fp32: bool = False
309
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
310
+ if head_first:
311
+ B, H, T, K, V = *k.shape, v.shape[-1]
312
+ else:
313
+ B, T, H, K, V = *k.shape, v.shape[-1]
314
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
315
+ BS = BT if split_size is None else min(split_size, max(16, triton.next_power_of_2(T)))
316
+ assert BS % BT == 0, f"The `split_size` (got {BS}) must be a multiple of `chunk_size` {BT}"
317
+ # N: the actual number of sequences in the batch with either equal or variable lengths
318
+ if offsets is None:
319
+ split_offsets, N, NS = None, B, triton.cdiv(T, BS)
320
+ else:
321
+ split_offsets = prepare_chunk_offsets(offsets, BS)
322
+ N, NS = len(offsets) - 1, split_offsets[-1]
323
+
324
+ if head_first:
325
+ h = k.new_empty(B, H, NS, K, V, dtype=k.dtype if not states_in_fp32 else torch.float)
326
+ else:
327
+ h = k.new_empty(B, NS, H, K, V, dtype=k.dtype if not states_in_fp32 else torch.float)
328
+ ht = k.new_empty(N, H, K, V, dtype=torch.float) if output_final_state else None
329
+ def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * H)
330
+ chunk_fwd_kernel_h[grid](
331
+ k=k,
332
+ v=v,
333
+ h=h,
334
+ g=g,
335
+ gk=gk,
336
+ gv=gv,
337
+ h0=h0,
338
+ ht=ht,
339
+ offsets=offsets,
340
+ split_offsets=split_offsets,
341
+ T=T,
342
+ H=H,
343
+ K=K,
344
+ V=V,
345
+ BT=BT,
346
+ BS=BS,
347
+ USE_G=g is not None,
348
+ USE_GK=gk is not None,
349
+ USE_GV=gv is not None,
350
+ HEAD_FIRST=head_first
351
+ )
352
+ return h, ht
353
+
354
+
355
+ def chunk_bwd_dh(
356
+ q: torch.Tensor,
357
+ k: torch.Tensor,
358
+ v: torch.Tensor,
359
+ g: torch.Tensor,
360
+ gk: torch.Tensor,
361
+ gv: torch.Tensor,
362
+ do: torch.Tensor,
363
+ h0: torch.Tensor,
364
+ dht: torch.Tensor,
365
+ scale: float,
366
+ offsets: Optional[torch.Tensor] = None,
367
+ head_first: bool = True,
368
+ chunk_size: int = 64,
369
+ split_size: Optional[int] = None,
370
+ states_in_fp32: bool = False
371
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
372
+ if head_first:
373
+ B, H, T, K, V = *k.shape, v.shape[-1]
374
+ HQ = q.shape[1]
375
+ else:
376
+ B, T, H, K, V = *k.shape, v.shape[-1]
377
+ HQ = q.shape[2]
378
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
379
+ BS = BT if split_size is None else min(split_size, max(16, triton.next_power_of_2(T)))
380
+ assert BS % BT == 0, f"The `split_size` (got {BS}) must be a multiple of `chunk_size` {BT}"
381
+ # N: the actual number of sequences in the batch with either equal or variable lengths
382
+ # NG: number of query heads that share each key/value head (GQA group size)
383
+ if offsets is None:
384
+ split_offsets, N, NS = None, B, triton.cdiv(T, BS)
385
+ else:
386
+ split_offsets = prepare_chunk_offsets(offsets, BS)
387
+ N, NS = len(offsets) - 1, split_offsets[-1]
388
+ NG = HQ // H
389
+
390
+ if head_first:
391
+ dh = k.new_empty(B, HQ, NS, K, V, dtype=k.dtype if not states_in_fp32 else torch.float)
392
+ else:
393
+ dh = k.new_empty(B, NS, HQ, K, V, dtype=k.dtype if not states_in_fp32 else torch.float)
394
+ dh0 = torch.empty_like(h0, dtype=torch.float) if h0 is not None else None
395
+
396
+ def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * HQ)
397
+ chunk_bwd_kernel_dh[grid](
398
+ q=q,
399
+ g=g,
400
+ gk=gk,
401
+ gv=gv,
402
+ do=do,
403
+ dh=dh,
404
+ dht=dht,
405
+ dh0=dh0,
406
+ offsets=offsets,
407
+ split_offsets=split_offsets,
408
+ scale=scale,
409
+ T=T,
410
+ HQ=HQ,
411
+ H=H,
412
+ K=K,
413
+ V=V,
414
+ BT=BT,
415
+ BS=BS,
416
+ NG=NG,
417
+ USE_G=g is not None,
418
+ USE_GK=gk is not None,
419
+ USE_GV=gv is not None,
420
+ HEAD_FIRST=head_first
421
+ )
422
+ return dh, dh0
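The scalar-decay branch of `chunk_fwd_kernel_h` (`USE_G`) can be sanity-checked with a small reference. The sketch below is not part of the commit. It assumes that `g` holds the chunk-local cumulative sum of per-token log decays, which is what the kernel's use of `b_g_last` as the decay of a whole chunk suggests. Under that assumption, the chunk-level update `h = exp(g_last) * h + sum_j exp(g_last - g_j) * k_j v_j^T` should agree with the token-by-token recurrence `h = exp(gamma_t) * h + k_t v_t^T`. The names `recurrent_state`, `chunked_states`, and `gamma` are hypothetical.

```python
import math


def recurrent_state(k, v, gamma):
    """Token-by-token recurrence: h <- exp(gamma_t) * h + k_t v_t^T."""
    K, V = len(k[0]), len(v[0])
    h = [[0.0] * V for _ in range(K)]
    for k_t, v_t, g_t in zip(k, v, gamma):
        d = math.exp(g_t)
        for a in range(K):
            for b in range(V):
                h[a][b] = d * h[a][b] + k_t[a] * v_t[b]
    return h


def chunked_states(k, v, gamma, chunk_size):
    """Chunk-level update. Returns (states at chunk starts, final state)."""
    T, K, V = len(k), len(k[0]), len(v[0])
    h = [[0.0] * V for _ in range(K)]
    starts = []
    for t0 in range(0, T, chunk_size):
        t1 = min(t0 + chunk_size, T)
        starts.append([row[:] for row in h])  # state seen by this chunk
        g, total = [], 0.0
        for t in range(t0, t1):  # chunk-local cumulative log decay
            total += gamma[t]
            g.append(total)
        g_last = g[-1]
        for a in range(K):
            for b in range(V):
                # each token is decayed from its position to the chunk end
                acc = sum(
                    k[t][a] * v[t][b] * math.exp(g_last - g[t - t0])
                    for t in range(t0, t1)
                )
                h[a][b] = math.exp(g_last) * h[a][b] + acc
    return starts, h
```

The list `starts` mirrors what the kernel writes to `h` (the state before each chunk is applied when `BS == BT`), and the returned final state mirrors `ht`. Comparing `chunked_states(...)[1]` with `recurrent_state(...)` on a short sequence whose length is not a multiple of `chunk_size` also exercises the `last_idx = min((i_t + 1) * BT, T) - 1` boundary handling.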
fla/ops/common/chunk_h_split.py ADDED
@@ -0,0 +1,677 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.utils.op import exp
11
+
12
+
13
+ @triton.heuristics({
14
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
15
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
16
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
17
+ })
18
+ @triton.autotune(
19
+ configs=[
20
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
21
+ for BK in [32, 64]
22
+ for BV in [32, 64]
23
+ for num_warps in [2, 4, 8]
24
+ for num_stages in [2, 3]
25
+ ],
26
+ key=['BT', 'USE_G', 'USE_GK', 'USE_GV'],
27
+ )
28
+ @triton.jit(do_not_specialize=['T'])
29
+ def chunk_fwd_kernel_h_split(
30
+ k,
31
+ v,
32
+ g,
33
+ gk,
34
+ gv,
35
+ hs,
36
+ hr,
37
+ h0,
38
+ ht,
39
+ offsets,
40
+ split_indices,
41
+ T,
42
+ S: tl.constexpr,
43
+ H: tl.constexpr,
44
+ K: tl.constexpr,
45
+ V: tl.constexpr,
46
+ BT: tl.constexpr,
47
+ BK: tl.constexpr,
48
+ BV: tl.constexpr,
49
+ USE_G: tl.constexpr,
50
+ USE_GK: tl.constexpr,
51
+ USE_GV: tl.constexpr,
52
+ USE_INITIAL_STATE: tl.constexpr,
53
+ STORE_FINAL_STATE: tl.constexpr,
54
+ USE_OFFSETS: tl.constexpr,
55
+ HEAD_FIRST: tl.constexpr
56
+ ):
57
+ # handle one split at a time
58
+ # i_h: head index
59
+ # i_n: sequence index
60
+ # i_s: local split index inside a sequence
61
+ i_k, i_v, i_sh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
62
+ i_ss, i_h = i_sh // H, i_sh % H
63
+ if USE_OFFSETS:
64
+ i_n, i_s = tl.load(split_indices + i_ss * 2).to(tl.int32), tl.load(split_indices + i_ss * 2 + 1).to(tl.int32)
65
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
66
+ T = eos - bos
67
+ NS = tl.cdiv(T, S)
68
+ else:
69
+ NS = tl.cdiv(T, S)
70
+ i_n, i_s = i_ss // NS, i_ss % NS
71
+ bos, eos = i_n * T, i_n * T + T
72
+ i_nh = i_n * H + i_h
73
+
74
+ # [BK, BV]
75
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
76
+ # for the first split, we directly store the state as the final result
77
+ if i_s == 0:
78
+ if USE_INITIAL_STATE:
79
+ p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
80
+ b_h += tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
81
+ p_hr = tl.make_block_ptr(hr + i_sh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
82
+ tl.store(p_hr, b_h.to(p_hr.dtype.element_ty), boundary_check=(0, 1))
83
+ for i_t in range(tl.cdiv(i_s * S, BT), tl.cdiv(min(i_s * S + S, T), BT)):
84
+ if HEAD_FIRST:
85
+ p_k = tl.make_block_ptr(k + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
86
+ p_v = tl.make_block_ptr(v + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
87
+ else:
88
+ p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
89
+ p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
90
+ # [BK, BT]
91
+ b_k = tl.load(p_k, boundary_check=(0, 1))
92
+ # [BT, BV]
93
+ b_v = tl.load(p_v, boundary_check=(0, 1))
94
+ last_idx = min(i_t * BT + BT, T) - 1
95
+
96
+ # scalar decay
97
+ if USE_G:
98
+ if HEAD_FIRST:
99
+ b_g_last = tl.load(g + i_nh * T + last_idx)
100
+ p_g = g + i_nh * T + i_t * BT + tl.arange(0, BT)
101
+ p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT)
102
+ else:
103
+ b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
104
+ p_g = g + bos*H + (i_t * BT + tl.arange(0, BT)) * H + i_h
105
+ b_h *= exp(b_g_last)
106
+ b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.)
107
+ b_v = (b_v * exp(b_g_last - b_g)[:, None]).to(b_v.dtype)
108
+
109
+ # vector decay, h = Diag(gk) @ h
110
+ if USE_GK:
111
+ if HEAD_FIRST:
112
+ p_gk = tl.make_block_ptr(gk + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
113
+ p_gk_last = gk + i_nh * T*K + last_idx * K + i_k * BK + tl.arange(0, BK)
114
+ p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
115
+ else:
116
+ p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
117
+ p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
118
+
119
+ b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
120
+ b_h *= exp(b_gk_last)[:, None]
121
+
122
+ b_gk = tl.load(p_gk, boundary_check=(0, 1))
123
+ b_k = (b_k * exp(b_gk_last[:, None] - b_gk)).to(b_k.dtype)
124
+
125
+ # vector decay, h = h @ Diag(gv)
126
+ if USE_GV:
127
+ if HEAD_FIRST:
128
+ p_gv = tl.make_block_ptr(gv + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
129
+ p_gv_last = gv + i_nh * T*V + last_idx * V + i_v * BV + tl.arange(0, BV)
130
+ p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
131
+ else:
132
+ p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
133
+ p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
134
+
135
+ b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
136
+ b_h *= exp(b_gv_last)[None, :]
137
+
138
+ b_gv = tl.load(p_gv, boundary_check=(0, 1))
139
+ b_v = (b_v * exp(b_gv_last[None, :] - b_gv)).to(b_v.dtype)
140
+
141
+ b_h += tl.dot(b_k, b_v)
142
+
143
+ # if there is more than one split, we store the result to (unreduced) hs
144
+ # otherwise, we store the result to ht as the final state
145
+ if NS > 1:
146
+ p_hs = tl.make_block_ptr(hs + i_sh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
147
+ tl.store(p_hs, b_h.to(p_hs.dtype.element_ty), boundary_check=(0, 1))
148
+ elif STORE_FINAL_STATE:
149
+ p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
150
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
151
+
152
+
153
+ @triton.heuristics({
154
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
155
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
156
+ })
157
+ @triton.autotune(
158
+ configs=[
159
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
160
+ for BK in [32, 64]
161
+ for BV in [32, 64]
162
+ for num_warps in [2, 4, 8]
163
+ for num_stages in [2, 3, 4]
164
+ ],
165
+ key=['BT', 'USE_G', 'USE_GK', 'USE_GV'],
166
+ )
167
+ @triton.jit(do_not_specialize=['T'])
168
+ def chunk_fwd_kernel_h_reduction(
169
+ g,
170
+ gk,
171
+ gv,
172
+ hs,
173
+ hr,
174
+ ht,
175
+ offsets,
176
+ split_offsets,
177
+ T,
178
+ S: tl.constexpr,
179
+ H: tl.constexpr,
180
+ K: tl.constexpr,
181
+ V: tl.constexpr,
182
+ BT: tl.constexpr,
183
+ BK: tl.constexpr,
184
+ BV: tl.constexpr,
185
+ USE_G: tl.constexpr,
186
+ USE_GK: tl.constexpr,
187
+ USE_GV: tl.constexpr,
188
+ STORE_FINAL_STATE: tl.constexpr,
189
+ USE_OFFSETS: tl.constexpr,
190
+ HEAD_FIRST: tl.constexpr
191
+ ):
192
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
193
+ i_n, i_h = i_nh // H, i_nh % H
194
+ if USE_OFFSETS:
195
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
196
+ T = eos - bos
197
+ NS = tl.cdiv(T, S)
198
+ boh = tl.load(split_offsets + i_n).to(tl.int32)
199
+ else:
200
+ bos, eos = i_n * T, i_n * T + T
201
+ NS = tl.cdiv(T, S)
202
+ boh = i_n * NS
203
+
204
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
205
+ # skip the first split
206
+ for i_s in range(1, NS):
207
+ p_hs = tl.make_block_ptr(hs + ((boh + i_s-1) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
208
+ p_hr = tl.make_block_ptr(hr + ((boh + i_s) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
209
+ b_h += tl.load(p_hs, boundary_check=(0, 1)).to(tl.float32)
210
+ tl.store(p_hr, b_h.to(p_hr.dtype.element_ty), boundary_check=(0, 1))
211
+
212
+ for i_t in range(tl.cdiv(i_s * S, BT), tl.cdiv(min(i_s * S + S, T), BT)):
213
+ last_idx = min(i_t * BT + BT, T) - 1
214
+ # scalar decay
215
+ if USE_G:
216
+ if HEAD_FIRST:
217
+ b_g_last = tl.load(g + i_nh * T + last_idx)
218
+ else:
219
+ b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
220
+ b_h *= exp(b_g_last)
221
+
222
+ # vector decay, h = Diag(gk) @ h
223
+ if USE_GK:
224
+ if HEAD_FIRST:
225
+ p_gk_last = gk + i_nh * T*K + last_idx * K + i_k * BK + tl.arange(0, BK)
226
+ p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
227
+ else:
228
+ p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
229
+
230
+ b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
231
+ b_h *= exp(b_gk_last)[:, None]
232
+
233
+ # vector decay, h = h @ Diag(gv)
234
+ if USE_GV:
235
+ if HEAD_FIRST:
236
+ p_gv_last = gv + i_nh * T*V + last_idx * V + i_v * BV + tl.arange(0, BV)
237
+ p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
238
+ else:
239
+ p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
240
+
241
+ b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
242
+ b_h *= exp(b_gv_last)[None, :]
243
+
244
+ if NS > 1:
245
+ if STORE_FINAL_STATE:
246
+ p_hs = tl.make_block_ptr(hs + ((boh + NS-1) * H + i_h)*K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
247
+ p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
248
+ b_h += tl.load(p_hs, boundary_check=(0, 1)).to(tl.float32)
249
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
250
+
251
+
252
+ @triton.heuristics({
253
+ 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
254
+ 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None,
255
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
256
+ })
257
+ @triton.autotune(
258
+ configs=[
259
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
260
+ for BK in [32, 64]
261
+ for BV in [32, 64]
262
+ for num_warps in [2, 4, 8]
263
+ for num_stages in [2, 3]
264
+ ],
265
+ key=['BT', 'USE_G', 'USE_GK', 'USE_GV'],
266
+ )
267
+ @triton.jit(do_not_specialize=['T'])
268
+ def chunk_bwd_kernel_dh_split(
269
+ q,
270
+ g,
271
+ gk,
272
+ gv,
273
+ do,
274
+ dht,
275
+ dhs,
276
+ dhr,
277
+ dh0,
278
+ offsets,
279
+ split_indices,
280
+ scale,
281
+ T,
282
+ S: tl.constexpr,
283
+ HQ: tl.constexpr,
284
+ H: tl.constexpr,
285
+ K: tl.constexpr,
286
+ V: tl.constexpr,
287
+ BT: tl.constexpr,
288
+ BK: tl.constexpr,
289
+ BV: tl.constexpr,
290
+ NG: tl.constexpr,
291
+ USE_G: tl.constexpr,
292
+ USE_GK: tl.constexpr,
293
+ USE_GV: tl.constexpr,
294
+ USE_FINAL_STATE_GRADIENT: tl.constexpr,
295
+ STORE_INITIAL_STATE_GRADIENT: tl.constexpr,
296
+ USE_OFFSETS: tl.constexpr,
297
+ HEAD_FIRST: tl.constexpr
298
+ ):
299
+ # handle one split at a time
300
+ # i_h: head index
301
+ # i_n: sequence index
302
+ # i_s: local split index inside a sequence
303
+ i_k, i_v, i_sh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
304
+ i_ss, i_hq = i_sh // HQ, i_sh % HQ
305
+ if USE_OFFSETS:
306
+ i_n, i_s = tl.load(split_indices + i_ss * 2).to(tl.int32), tl.load(split_indices + i_ss * 2 + 1).to(tl.int32)
307
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
308
+ T = eos - bos
309
+ NS = tl.cdiv(T, S)
310
+ else:
311
+ NS = tl.cdiv(T, S)
312
+ i_n, i_s = i_ss // NS, i_ss % NS
313
+ bos, eos = i_n * T, i_n * T + T
314
+ i_nh = i_n * HQ + i_hq
315
+ i_ng, i_h = i_nh // NG, i_hq // NG
316
+
317
+ # [BK, BV]
318
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
319
+ if i_s == NS - 1:
320
+ if USE_FINAL_STATE_GRADIENT:
321
+ p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
322
+ b_dh += tl.load(p_dht, boundary_check=(0, 1)).to(tl.float32)
323
+ p_dhr = tl.make_block_ptr(dhr + i_sh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
324
+ tl.store(p_dhr, b_dh.to(p_dhr.dtype.element_ty), boundary_check=(0, 1))
325
+
326
+ for i_t in range(tl.cdiv(min(i_s * S + S, T), BT) - 1, tl.cdiv(i_s * S, BT) - 1, -1):
327
+ if HEAD_FIRST:
328
+ p_q = tl.make_block_ptr(q + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
329
+ p_do = tl.make_block_ptr(do + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
330
+ else:
331
+ p_q = tl.make_block_ptr(q + (bos*HQ + i_hq) * K, (K, T), (1, HQ*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
332
+ p_do = tl.make_block_ptr(do + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
333
+
334
+ b_q = tl.load(p_q, boundary_check=(0, 1))
335
+ b_q = (b_q * scale).to(b_q.dtype)
336
+ # [BT, BV]
337
+ b_do = tl.load(p_do, boundary_check=(0, 1))
338
+
339
+ last_idx = min(i_t * BT + BT, T) - 1
340
+ if USE_G:
341
+ if HEAD_FIRST:
342
+ p_g = g + i_ng * T + i_t * BT + tl.arange(0, BT)
343
+ p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT)
344
+ b_g_last = tl.load(g + i_ng * T + last_idx)
345
+ else:
346
+ p_g = g + (bos + i_t * BT + tl.arange(0, BT)) * H + i_h
347
+ b_g_last = tl.load(g + (bos + last_idx) * H + i_h)
348
+ b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.)
349
+ b_q = (b_q * exp(b_g)[None, :]).to(b_q.dtype)
350
+ b_dh *= exp(b_g_last)
351
+
352
+ if USE_GK:
353
+ if HEAD_FIRST:
354
+ p_gk = tl.make_block_ptr(gk + i_ng * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
355
+ p_gk_last = gk + (i_ng * T + last_idx) * K + i_k * BK + tl.arange(0, BK)
356
+ p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
357
+ else:
358
+ p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
359
+ p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
360
+
361
+ b_gk = tl.load(p_gk, boundary_check=(0, 1))
362
+ b_q = (b_q * exp(b_gk)).to(b_q.dtype)
363
+ b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
364
+ b_dh *= exp(b_gk_last)[:, None]
365
+
366
+ if USE_GV:
367
+ if HEAD_FIRST:
368
+ p_gv = tl.make_block_ptr(gv + i_ng * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
369
+ p_gv_last = gv + (i_ng * T + last_idx) * V + i_v * BV + tl.arange(0, BV)
370
+ p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
371
+ else:
372
+ p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
373
+ p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
374
+
375
+ b_gv = tl.load(p_gv, boundary_check=(0, 1))
376
+ b_do = (b_do * exp(b_gv)).to(b_do.dtype)
377
+
378
+ b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
379
+ b_dh *= exp(b_gv_last)[None, :]
380
+
381
+ b_dh += tl.dot(b_q, b_do)
382
+
383
+ if NS > 1:
384
+ p_dhs = tl.make_block_ptr(dhs + i_sh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
385
+ tl.store(p_dhs, b_dh.to(p_dhs.dtype.element_ty), boundary_check=(0, 1))
386
+ elif STORE_INITIAL_STATE_GRADIENT:
387
+ p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
388
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
389
+
390
+
391
+ @triton.heuristics({
392
+ 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None,
393
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
394
+ })
395
+ @triton.autotune(
396
+ configs=[
397
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
398
+ for BK in [32, 64]
399
+ for BV in [32, 64]
400
+ for num_warps in [2, 4, 8]
401
+ for num_stages in [2, 3, 4]
402
+ ],
403
+ key=['BT', 'USE_G', 'USE_GK', 'USE_GV'],
404
+ )
405
+ @triton.jit(do_not_specialize=['T'])
406
+ def chunk_bwd_kernel_dh_reduction(
407
+ g,
408
+ gk,
409
+ gv,
410
+ dhs,
411
+ dhr,
412
+ dh0,
413
+ offsets,
414
+ split_offsets,
415
+ T,
416
+ S: tl.constexpr,
417
+ H: tl.constexpr,
418
+ HQ: tl.constexpr,
419
+ K: tl.constexpr,
420
+ V: tl.constexpr,
421
+ BT: tl.constexpr,
422
+ BK: tl.constexpr,
423
+ BV: tl.constexpr,
424
+ NG: tl.constexpr,
425
+ USE_G: tl.constexpr,
426
+ USE_GK: tl.constexpr,
427
+ USE_GV: tl.constexpr,
428
+ STORE_INITIAL_STATE_GRADIENT: tl.constexpr,
429
+ USE_OFFSETS: tl.constexpr,
430
+ HEAD_FIRST: tl.constexpr
431
+ ):
432
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
433
+ i_n, i_hq = i_nh // HQ, i_nh % HQ
434
+ i_ng, i_h = i_nh // NG, i_hq // NG
435
+ if USE_OFFSETS:
436
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
437
+ T = eos - bos
438
+ NS = tl.cdiv(T, S)
439
+ boh = tl.load(split_offsets + i_n).to(tl.int32)
440
+ else:
441
+ bos, eos = i_n * T, i_n * T + T
442
+ NS = tl.cdiv(T, S)
443
+ boh = i_n * NS
444
+
445
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
446
+ for i_s in range(NS - 2, -1, -1):
447
+ # dhs and dhr are allocated as [NS, HQ, K, V], so they are indexed by the query head i_hq
+ p_dhs = tl.make_block_ptr(dhs + ((boh+i_s+1) * HQ + i_hq) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
448
+ p_dhr = tl.make_block_ptr(dhr + ((boh+i_s) * HQ + i_hq) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
449
+ b_dh += tl.load(p_dhs, boundary_check=(0, 1)).to(tl.float32)
450
+ tl.store(p_dhr, b_dh.to(p_dhr.dtype.element_ty), boundary_check=(0, 1))
451
+
452
+ for i_t in range(tl.cdiv(min(i_s * S + S, T), BT) - 1, tl.cdiv(i_s * S, BT) - 1, -1):
453
+ last_idx = min(i_t * BT + BT, T) - 1
454
+ # scalar decay
455
+ if USE_G:
456
+ if HEAD_FIRST:
457
+ b_g_last = tl.load(g + i_ng * T + last_idx)
458
+ else:
459
+ b_g_last = tl.load(g + (bos + last_idx) * H + i_h)
460
+ b_dh *= exp(b_g_last)
461
+
462
+ if USE_GK:
463
+ if HEAD_FIRST:
464
+ p_gk_last = gk + (i_ng * T + last_idx) * K + i_k * BK + tl.arange(0, BK)
465
+ p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
466
+ else:
467
+ p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
468
+
469
+ b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
470
+ b_dh *= exp(b_gk_last)[:, None]
471
+
472
+ if USE_GV:
473
+ if HEAD_FIRST:
474
+ p_gv_last = gv + (i_ng * T + last_idx) * V + i_v * BV + tl.arange(0, BV)
475
+ p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
476
+ else:
477
+ p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
478
+
479
+ b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
480
+ b_dh *= exp(b_gv_last)[None, :]
481
+
482
+ if NS > 1:
483
+ if STORE_INITIAL_STATE_GRADIENT:
484
+ p_dhs = tl.make_block_ptr(dhs + (boh * HQ + i_hq)*K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
485
+ p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
486
+ b_dh += tl.load(p_dhs, boundary_check=(0, 1)).to(tl.float32)
487
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
488
+
489
+
490
+ def chunk_fwd_h(
491
+ k: torch.Tensor,
492
+ v: torch.Tensor,
493
+ g: torch.Tensor,
494
+ gk: torch.Tensor,
495
+ gv: torch.Tensor,
496
+ h0: torch.Tensor,
497
+ output_final_state: bool,
498
+ offsets: Optional[torch.LongTensor] = None,
499
+ split_offsets: Optional[torch.LongTensor] = None,
500
+ split_indices: Optional[torch.LongTensor] = None,
501
+ head_first: bool = True,
502
+ chunk_size: int = 64,
503
+ split_size: int = 256,
504
+ states_in_fp32: bool = True
505
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
506
+ if head_first:
507
+ B, H, T, K, V = *k.shape, v.shape[-1]
508
+ else:
509
+ B, T, H, K, V = *k.shape, v.shape[-1]
510
+ # B: batch size
511
+ # N: the actual number of sequences in the batch
512
+ # H: number of heads
513
+ # T: sequence length, can be variable across sequences
514
+ # S: split size, a multiple of chunk size
515
+ # BT: chunk size
516
+ S, BT = split_size, chunk_size
517
+ assert S % BT == 0, f"The `split_size` (got {S}) must be a multiple of `chunk_size` (got {BT})"
518
+ if offsets is None:
519
+ N = B
520
+ NS = N * triton.cdiv(T, S)
521
+ else:
522
+ N = len(offsets) - 1
523
+ NS = split_offsets[-1]
524
+
525
+ # unreduced kv states per split
526
+ hs = k.new_empty(NS, H, K, V, dtype=torch.float)
527
+ # reduced states per split
528
+ hr = k.new_empty(NS, H, K, V, dtype=torch.float if states_in_fp32 else k.dtype)
529
+ ht = k.new_empty(N, H, K, V, dtype=torch.float) if output_final_state else None
530
+ # parallelized over splits
531
+ def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), NS * H)
532
+ chunk_fwd_kernel_h_split[grid](
533
+ k=k,
534
+ v=v,
535
+ g=g,
536
+ gk=gk,
537
+ gv=gv,
538
+ hs=hs,
539
+ hr=hr,
540
+ h0=h0,
541
+ ht=ht,
542
+ offsets=offsets,
543
+ split_indices=split_indices,
544
+ T=T,
545
+ S=S,
546
+ H=H,
547
+ K=K,
548
+ V=V,
549
+ BT=BT,
550
+ USE_G=g is not None,
551
+ USE_GK=gk is not None,
552
+ USE_GV=gv is not None,
553
+ HEAD_FIRST=head_first
554
+ )
555
+ def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * H)
556
+ chunk_fwd_kernel_h_reduction[grid](
557
+ g=g,
558
+ gk=gk,
559
+ gv=gv,
560
+ hs=hs,
561
+ hr=hr,
562
+ ht=ht,
563
+ offsets=offsets,
564
+ split_offsets=split_offsets,
565
+ T=T,
566
+ S=S,
567
+ H=H,
568
+ K=K,
569
+ V=V,
570
+ BT=BT,
571
+ USE_G=g is not None,
572
+ USE_GK=gk is not None,
573
+ USE_GV=gv is not None,
574
+ HEAD_FIRST=head_first
575
+ )
576
+ return hr, ht
577
+
578
+
579
+ def chunk_bwd_dh(
580
+ q: torch.Tensor,
581
+ k: torch.Tensor,
582
+ v: torch.Tensor,
583
+ g: torch.Tensor,
584
+ gk: torch.Tensor,
585
+ gv: torch.Tensor,
586
+ do: torch.Tensor,
587
+ h0: torch.Tensor,
588
+ dht: torch.Tensor,
589
+ scale: float,
590
+ offsets: Optional[torch.Tensor] = None,
591
+ split_offsets: Optional[torch.Tensor] = None,
592
+ split_indices: Optional[torch.Tensor] = None,
593
+ head_first: bool = True,
594
+ chunk_size: int = 64,
595
+ split_size: int = 256,
596
+ states_in_fp32: bool = True
597
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
598
+ if head_first:
599
+ B, H, T, K, V = *k.shape, v.shape[-1]
600
+ HQ = q.shape[1]
601
+ else:
602
+ B, T, H, K, V = *k.shape, v.shape[-1]
603
+ HQ = q.shape[2]
604
+ # B: batch size
605
+ # N: the actual number of sequences in the batch
606
+ # H: number of heads
607
+ # T: sequence length, can be variable across sequences
608
+ # S: split size, a multiple of chunk size
609
+ # BT: chunk size
610
+ S, BT = max(chunk_size, min(split_size, triton.next_power_of_2(T))), chunk_size
611
+ assert S % BT == 0, f"The `split_size` (got {S}) must be a multiple of `chunk_size` (got {BT})"
612
+ if offsets is None:
613
+ N = B
614
+ NS = N * triton.cdiv(T, S)
615
+ else:
616
+ N = len(offsets) - 1
617
+ NS = split_offsets[-1]
618
+ # number of groups in GQA
619
+ NG = HQ // H
620
+
621
+ dhs = q.new_empty(NS, HQ, K, V, dtype=torch.float)
622
+ dhr = q.new_empty(NS, HQ, K, V, dtype=torch.float if states_in_fp32 else k.dtype)
623
+ dh0 = torch.empty_like(h0, dtype=torch.float) if h0 is not None else None
624
+
625
+ # parallelized over splits
626
+ def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), NS * HQ)
627
+ chunk_bwd_kernel_dh_split[grid](
628
+ q=q,
629
+ g=g,
630
+ gk=gk,
631
+ gv=gv,
632
+ do=do,
633
+ dht=dht,
634
+ dhs=dhs,
635
+ dhr=dhr,
636
+ dh0=dh0,
637
+ offsets=offsets,
638
+ split_indices=split_indices,
639
+ scale=scale,
640
+ T=T,
641
+ S=S,
642
+ HQ=HQ,
643
+ H=H,
644
+ K=K,
645
+ V=V,
646
+ BT=BT,
647
+ NG=NG,
648
+ USE_G=g is not None,
649
+ USE_GK=gk is not None,
650
+ USE_GV=gv is not None,
651
+ HEAD_FIRST=head_first,
652
+ )
653
+
654
+ def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * HQ)
655
+ chunk_bwd_kernel_dh_reduction[grid](
656
+ g=g,
657
+ gk=gk,
658
+ gv=gv,
659
+ dhs=dhs,
660
+ dhr=dhr,
661
+ dh0=dh0,
662
+ offsets=offsets,
663
+ split_offsets=split_offsets,
664
+ T=T,
665
+ S=S,
666
+ HQ=HQ,
667
+ H=H,
668
+ K=K,
669
+ V=V,
670
+ BT=BT,
671
+ NG=NG,
672
+ USE_G=g is not None,
673
+ USE_GK=gk is not None,
674
+ USE_GV=gv is not None,
675
+ HEAD_FIRST=head_first
676
+ )
677
+ return dhr, dh0
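
The split and reduction kernels in `chunk_h_split.py` can be read as a two-pass scheme: each split first accumulates a local state starting from zero, and a short sequential pass then carries the running state across split boundaries, decaying it by each split's total decay. The following is a minimal pure-Python sketch of that idea for the scalar-decay case only. It is not the library's implementation: the helper names (`step`, `sequential_states`, `split_then_reduce`, `max_abs_diff`) are hypothetical, it assumes per-token log-decays rather than the per-chunk `exp(b_g_last)` factors used by the kernels, it ignores the initial state `h0`, and it assumes at least one token.

```python
import math


def zeros(rows, cols):
    return [[0.0] * cols for _ in range(rows)]


def step(h, log_decay, k, v):
    """One recurrence step: h <- exp(log_decay) * h + outer(k, v)."""
    d = math.exp(log_decay)
    return [[d * h[i][j] + k[i] * v[j] for j in range(len(v))] for i in range(len(k))]


def sequential_states(ks, vs, log_decays, split):
    """Reference: run token by token, recording the state entering each split."""
    h = zeros(len(ks[0]), len(vs[0]))
    entering = []
    for t in range(len(ks)):
        if t % split == 0:
            entering.append(h)
        h = step(h, log_decays[t], ks[t], vs[t])
    return entering, h


def split_then_reduce(ks, vs, log_decays, split):
    """Two-pass version: independent per-split states, then a sequential reduction."""
    n_k, n_v, T = len(ks[0]), len(vs[0]), len(ks)
    n_splits = math.ceil(T / split)

    # Pass 1 (parallel over splits in the kernels): every split starts from zero.
    local, total_decay = [], []
    for s in range(n_splits):
        h, d = zeros(n_k, n_v), 0.0
        for t in range(s * split, min((s + 1) * split, T)):
            h = step(h, log_decays[t], ks[t], vs[t])
            d += log_decays[t]
        local.append(h)
        total_decay.append(d)

    # Pass 2 (sequential over splits): carry the running state across boundaries.
    carry = zeros(n_k, n_v)
    entering = [carry]  # state entering split 0 (no initial state in this sketch)
    for s in range(1, n_splits):
        # add the previous split's local state: this is the state entering split s
        carry = [[c + l for c, l in zip(cr, lr)] for cr, lr in zip(carry, local[s - 1])]
        entering.append(carry)
        # decay the carried state across the whole of split s
        scale = math.exp(total_decay[s])
        carry = [[scale * c for c in row] for row in carry]
    final = [[c + l for c, l in zip(cr, lr)] for cr, lr in zip(carry, local[-1])]
    return entering, final


def max_abs_diff(a, b):
    return max(abs(x - y) for ra, rb in zip(a, b) for x, y in zip(ra, rb))
```

Calling both functions on the same inputs and comparing the results with `max_abs_diff` should give differences at the level of floating-point rounding, because the recurrence is linear in the state. In this reading, the list returned as `entering` plays the role of `hr`, and `final` plays the role of `ht`.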
fla/ops/common/chunk_o.py ADDED
@@ -0,0 +1,668 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.utils.op import exp, safe_exp
11
+ from fla.utils import check_shared_mem, is_nvidia_hopper
12
+
13
+ BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
14
+ NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]
15
+
16
+
17
+ @triton.heuristics({
18
+ 'USE_G': lambda args: args['g'] is not None,
19
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
20
+ })
21
+ @triton.autotune(
22
+ configs=[
23
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
24
+ for BK in BKV_LIST
25
+ for BV in BKV_LIST
26
+ for num_warps in NUM_WARPS
27
+ for num_stages in [2, 3, 4]
28
+ ],
29
+ key=['H', 'K', 'V', 'BT'],
30
+ )
31
+ @triton.jit(do_not_specialize=['T'])
32
+ def chunk_fwd_kernel_o(
33
+ q,
34
+ k,
35
+ v,
36
+ h,
37
+ g,
38
+ o,
39
+ offsets,
40
+ indices,
41
+ scale,
42
+ T,
43
+ H: tl.constexpr,
44
+ K: tl.constexpr,
45
+ V: tl.constexpr,
46
+ BT: tl.constexpr,
47
+ BK: tl.constexpr,
48
+ BV: tl.constexpr,
49
+ USE_G: tl.constexpr,
50
+ USE_OFFSETS: tl.constexpr,
51
+ HEAD_FIRST: tl.constexpr
52
+ ):
53
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
54
+ i_b, i_h = i_bh // H, i_bh % H
55
+
56
+ if USE_OFFSETS:
57
+ i_tg = i_t
58
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
59
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
60
+ T = eos - bos
61
+ NT = tl.cdiv(T, BT)
62
+ else:
63
+ NT = tl.cdiv(T, BT)
64
+ i_tg = i_b * NT + i_t
65
+ bos, eos = i_b * T, i_b * T + T
66
+
67
+ s_qk = K if HEAD_FIRST else H*K
68
+ s_vo = V if HEAD_FIRST else H*V
69
+ s_g = 1 if HEAD_FIRST else H
70
+ # offset calculation
71
+ q += (i_bh * T*K) if HEAD_FIRST else ((bos * H + i_h) * K)
72
+ k += (i_bh * T*K) if HEAD_FIRST else ((bos * H + i_h) * K)
73
+ v += (i_bh * T*V) if HEAD_FIRST else ((bos * H + i_h) * V)
74
+ o += (i_bh * T*V) if HEAD_FIRST else ((bos * H + i_h) * V)
75
+ h += ((i_bh * NT + i_t).to(tl.int64) * K*V) if HEAD_FIRST else ((i_tg * H + i_h).to(tl.int64) * K*V)
76
+
77
+ b_o = tl.zeros([BT, BV], dtype=tl.float32)
78
+ b_A = tl.zeros([BT, BT], dtype=tl.float32)
79
+
80
+ for i_k in range(tl.cdiv(K, BK)):
81
+ p_q = tl.make_block_ptr(q, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
82
+ p_k = tl.make_block_ptr(k, (K, T), (1, s_qk), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
83
+ p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
84
+ # [BT, BK]
85
+ b_q = tl.load(p_q, boundary_check=(0, 1))
86
+ # [BK, BT]
87
+ b_k = tl.load(p_k, boundary_check=(0, 1))
88
+ # [BK, BV]
89
+ b_h = tl.load(p_h, boundary_check=(0, 1))
90
+
91
+ # [BT, BK] @ [BK, BV] -> [BT, BV]
92
+ b_o += tl.dot(b_q, b_h)
93
+ # [BT, BK] @ [BK, BT] -> [BT, BT]
94
+ b_A += tl.dot(b_q, b_k)
95
+
96
+ if USE_G:
97
+ g += (i_bh * T) if HEAD_FIRST else (bos * H + i_h)
98
+ p_g = tl.make_block_ptr(g, (T,), (s_g,), (i_t * BT,), (BT,), (0,))
99
+ b_g = tl.load(p_g, boundary_check=(0,))
100
+ b_o = b_o * exp(b_g)[:, None]
101
+ b_A = b_A * safe_exp(b_g[:, None] - b_g[None, :])
102
+
103
+ o_i = tl.arange(0, BT)
104
+ m_A = o_i[:, None] >= o_i[None, :]
105
+ b_A = tl.where(m_A, b_A, 0)
106
+
107
+ p_v = tl.make_block_ptr(v, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
108
+ p_o = tl.make_block_ptr(o, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
109
+ b_v = tl.load(p_v, boundary_check=(0, 1))
110
+
111
+ # to fix mma -> mma layout conversion
112
+ # already solved by triton v3.2 or higher
113
+ b_o = b_o * scale + tl.dot(b_A.to(b_v.dtype), b_v) * scale
114
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
115
+
116
+
117
+ @triton.heuristics({
118
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
119
+ 'USE_G': lambda args: args['g'] is not None,
120
+ 'USE_DW': lambda args: args['dw'] is not None
121
+ })
122
+ @triton.autotune(
123
+ configs=[
124
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
125
+ for num_warps in NUM_WARPS
126
+ for num_stages in [2, 3, 4]
127
+ ],
128
+ key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'USE_G', 'USE_DW'],
129
+ )
130
+ @triton.jit(do_not_specialize=['T'])
131
+ def chunk_bwd_kernel_dqkwg(
132
+ q,
133
+ k,
134
+ v,
135
+ h,
136
+ g,
137
+ do,
138
+ dh,
139
+ dq,
140
+ dk,
141
+ dg,
142
+ w,
143
+ dv,
144
+ dw,
145
+ offsets,
146
+ indices,
147
+ scale,
148
+ B: tl.constexpr,
149
+ T,
150
+ H: tl.constexpr,
151
+ K: tl.constexpr,
152
+ V: tl.constexpr,
153
+ BT: tl.constexpr,
154
+ BK: tl.constexpr,
155
+ BV: tl.constexpr,
156
+ USE_G: tl.constexpr,
157
+ USE_DW: tl.constexpr,
158
+ USE_OFFSETS: tl.constexpr,
159
+ HEAD_FIRST: tl.constexpr
160
+ ):
161
+ i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
162
+ i_b, i_h = i_bh // H, i_bh % H
163
+ if USE_G:
164
+ dg += i_k * B * H * T
165
+ if USE_OFFSETS:
166
+ i_tg = i_t
167
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
168
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
169
+ T = eos - bos
170
+ NT = tl.cdiv(T, BT)
171
+ else:
172
+ NT = tl.cdiv(T, BT)
173
+ i_tg = i_b * NT + i_t
174
+ bos, eos = i_b * T, i_b * T + T
175
+
176
+ # offset calculation
177
+ v += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V
178
+ do += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V
179
+ h += (i_bh * NT + i_t).to(tl.int64) * K*V if HEAD_FIRST else (i_tg * H + i_h).to(tl.int64) * K*V
180
+ dh += (i_bh * NT + i_t).to(tl.int64) * K*V if HEAD_FIRST else (i_tg * H + i_h).to(tl.int64) * K*V
181
+ q += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
182
+ k += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
183
+ dq += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
184
+ dk += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
185
+ s_qk = K if HEAD_FIRST else H*K
186
+ s_vo = V if HEAD_FIRST else H*V
187
+ s_g = 1 if HEAD_FIRST else H
188
+
189
+ # for delta rule only
190
+ if USE_DW:
191
+ dw += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
192
+ dv += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V
193
+ w += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
194
+
195
+ b_dq = tl.zeros([BT, BK], dtype=tl.float32)
196
+ b_dk = tl.zeros([BT, BK], dtype=tl.float32)
197
+ b_ds = tl.zeros([BT, BT], dtype=tl.float32)
198
+ b_dg_last = tl.zeros([1,], dtype=tl.float32) if USE_G else None
199
+ b_dw = tl.zeros([BT, BK], dtype=tl.float32) if USE_DW else None
200
+
201
+ for i_v in range(tl.cdiv(V, BV)):
202
+ p_v = tl.make_block_ptr(v, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
203
+ p_do = tl.make_block_ptr(do, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
204
+ p_h = tl.make_block_ptr(h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
205
+ p_dh = tl.make_block_ptr(dh, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
206
+ # [BT, BV]
207
+ b_v = tl.load(p_v, boundary_check=(0, 1))
208
+ b_do = tl.load(p_do, boundary_check=(0, 1))
209
+ # [BV, BK]
210
+ b_h = tl.load(p_h, boundary_check=(0, 1))
211
+ b_dh = tl.load(p_dh, boundary_check=(0, 1))
212
+ if USE_G:
213
+ b_dg_last += (tl.sum(b_h * b_dh))
214
+ # [BT, BV] @ [BV, BT] -> [BT, BT]
215
+ b_ds += tl.dot(b_do, tl.trans(b_v))
216
+ # [BT, BV] @ [BV, BK] -> [BT, BK]
217
+ b_dq += tl.dot(b_do, b_h.to(b_do.dtype))
218
+ # [BT, BV] @ [BV, BK] -> [BT, BK]
219
+ b_dk += tl.dot(b_v, b_dh.to(b_v.dtype))
220
+ if USE_DW:
221
+ p_dv = tl.make_block_ptr(dv, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
222
+ b_dv = tl.load(p_dv, boundary_check=(0, 1))
223
+ b_dw += tl.dot(b_dv.to(b_v.dtype), b_h.to(b_v.dtype))
224
+
225
+ if USE_DW and not USE_G:
226
+ p_dw = tl.make_block_ptr(dw, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
227
+ tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1))
228
+
229
+ tl.debug_barrier()
230
+ o_i = tl.arange(0, BT)
231
+ p_q = tl.make_block_ptr(q, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
232
+ p_k = tl.make_block_ptr(k, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
233
+ b_q = tl.load(p_q, boundary_check=(0, 1))
234
+ b_k = tl.load(p_k, boundary_check=(0, 1))
235
+
236
+ p_dq = tl.make_block_ptr(dq, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
237
+ p_dk = tl.make_block_ptr(dk, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
238
+
239
+ if USE_G:
240
+ b_dg = tl.zeros([BT,], dtype=tl.float32)
241
+ g += i_bh * T if HEAD_FIRST else bos * H + i_h
242
+ dg += i_bh * T if HEAD_FIRST else bos * H + i_h
243
+ p_g = tl.make_block_ptr(g, (T,), (s_g,), (i_t * BT,), (BT,), (0,))
244
+ b_g = tl.load(p_g, boundary_check=(0,))
245
+ b_g_last = tl.load(g + (min(i_t * BT + BT, T) - 1) * s_g)
246
+ b_dg_last *= exp(b_g_last)
247
+
248
+ if USE_DW:
249
+ p_w = tl.make_block_ptr(w, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
250
+ p_dw = tl.make_block_ptr(dw, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
251
+ b_w = tl.load(p_w, boundary_check=(0, 1))
252
+ b_dw = b_dw * exp(b_g)[:, None]
253
+ tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1))
254
+ b_dg -= tl.sum(b_w * b_dw, axis=1)
255
+
256
+ b_dq = b_dq * exp(b_g)[:, None] * scale
257
+ b_dg += tl.sum(b_dq * b_q, axis=1)
258
+
259
+ b_dk = b_dk * safe_exp(-b_g + b_g_last)[:, None]
260
+ b_dg -= tl.sum(b_k * b_dk, axis=1)
261
+ b_dg_last += tl.sum(b_dk * b_k)
262
+
263
+ b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds * safe_exp(b_g[:, None] - b_g[None, :]), 0) * scale
264
+ b_ds2 = b_ds * tl.dot(b_q, tl.trans(b_k))
265
+ b_dg += tl.sum(b_ds2, axis=1)
266
+ b_dg -= tl.sum(b_ds2, axis=0)
267
+
268
+ b_ds = b_ds.to(b_k.dtype)
269
+ # [BT, BK]
270
+ b_dq += tl.dot(b_ds, b_k)
271
+ b_dk += tl.dot(tl.trans(b_ds), b_q)
272
+ p_dg = tl.make_block_ptr(dg, (T,), (s_g,), (i_t * BT,), (BT,), (0,))
273
+ # (SY 09/21) revcumsum in a separate kernel due to strange triton compiler issue
274
+ # b_dg = tl.dot(tl.where(o_i[:, None] <= o_i[None, :], 1., 0.), b_dg, allow_tf32=False) + b_dg_last)
275
+ b_dg = tl.where(o_i < min(BT, T-i_t*BT) - 1, b_dg, b_dg + b_dg_last)
276
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
277
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
278
+ tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,))
279
+ else:
280
+ b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds, 0)
281
+ b_ds = b_ds.to(b_k.dtype)
282
+ b_dq += tl.dot(b_ds, b_k)
283
+ b_dk += tl.dot(tl.trans(b_ds), b_q) * scale
284
+ b_dq *= scale
285
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
286
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
287
+
288
+
289
+ @triton.heuristics({
290
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
291
+ 'USE_G': lambda args: args['g'] is not None,
292
+ })
293
+ @triton.autotune(
294
+ configs=[
295
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
296
+ for num_warps in [2, 4, 8]
297
+ for num_stages in [2, 3, 4]
298
+ ],
299
+ key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'USE_G'],
300
+ )
301
+ @triton.jit(do_not_specialize=['T'])
302
+ def chunk_bwd_kernel_dv(
303
+ q,
304
+ k,
305
+ g,
306
+ do,
307
+ dv,
308
+ dh,
309
+ offsets,
310
+ indices,
311
+ scale,
312
+ T,
313
+ H: tl.constexpr,
314
+ K: tl.constexpr,
315
+ V: tl.constexpr,
316
+ BT: tl.constexpr,
317
+ BK: tl.constexpr,
318
+ BV: tl.constexpr,
319
+ USE_G: tl.constexpr,
320
+ USE_OFFSETS: tl.constexpr,
321
+ HEAD_FIRST: tl.constexpr
322
+ ):
323
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
324
+ i_b, i_h = i_bh // H, i_bh % H
325
+ if USE_OFFSETS:
326
+ i_tg = i_t
327
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
328
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
329
+ T = eos - bos
330
+ NT = tl.cdiv(T, BT)
331
+ else:
332
+ NT = tl.cdiv(T, BT)
333
+ i_tg = i_b * NT + i_t
334
+ bos, eos = i_b * T, i_b * T + T
335
+
336
+ b_dv = tl.zeros([BT, BV], dtype=tl.float32)
337
+
338
+ # offset calculation
339
+ q += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
340
+ k += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
341
+ do += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V
342
+ dv += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V
343
+ s_qk = K if HEAD_FIRST else H*K
344
+ s_vo = V if HEAD_FIRST else H*V
345
+ s_g = 1 if HEAD_FIRST else H
346
+ dh += (i_bh * NT + i_t).to(tl.int64) * K*V if HEAD_FIRST else (i_tg * H + i_h).to(tl.int64) * K*V
347
+
348
+ b_A = tl.zeros([BT, BT], dtype=tl.float32)
349
+ for i_k in range(tl.cdiv(K, BK)):
350
+ p_k = tl.make_block_ptr(k, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
351
+ p_q = tl.make_block_ptr(q, (K, T), (1, s_qk), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
352
+ b_q = tl.load(p_q, boundary_check=(0, 1))
353
+ b_k = tl.load(p_k, boundary_check=(0, 1))
354
+ b_A += tl.dot(b_k, b_q)
355
+ p_dh = tl.make_block_ptr(dh, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
356
+ b_dh = tl.load(p_dh, boundary_check=(0, 1))
357
+ b_dv += tl.dot(b_k, b_dh.to(b_k.dtype))
358
+
359
+ if USE_G:
360
+ g += (i_bh * T) if HEAD_FIRST else (bos * H + i_h)
361
+ p_g = tl.make_block_ptr(g, (T,), (s_g,), (i_t * BT,), (BT,), (0,))
362
+ b_g = tl.load(p_g, boundary_check=(0,))
363
+ b_g_last = tl.load(g + (min(i_t * BT + BT, T) - 1) * s_g)
364
+ b_dv *= safe_exp(-b_g + b_g_last)[:, None]
365
+
366
+ mask = (tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :])
367
+ if USE_G:
368
+ b_A = tl.where(mask, b_A * safe_exp(b_g[None, :] - b_g[:, None]) * scale, 0).to(do.dtype.element_ty)
369
+ else:
370
+ b_A = tl.where(mask, b_A * scale, 0).to(do.dtype.element_ty)
371
+ p_do = tl.make_block_ptr(do, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
372
+ p_dv = tl.make_block_ptr(dv, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
373
+ b_do = tl.load(p_do, boundary_check=(0, 1))
374
+ b_dv += tl.dot(b_A.to(b_do.dtype), b_do)
375
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
376
+
377
+
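Note on the variable-length path: with `USE_OFFSETS`, the kernels above read `indices` as flat `(i_n, i_t)` pairs (sequence index, chunk index within that sequence) and then look up `offsets[i_n]` and `offsets[i_n + 1]` as the sequence bounds. The following is a minimal pure-Python sketch of how such pairs could be built from cumulative sequence lengths. `build_chunk_indices` is a hypothetical helper written for illustration; the library's own `prepare_chunk_indices` is not shown in this diff and may differ.

```python
from typing import List, Tuple


def cdiv(a: int, b: int) -> int:
    """Ceiling division, the same quantity `triton.cdiv` computes."""
    return (a + b - 1) // b


def build_chunk_indices(offsets: List[int], chunk_size: int) -> List[Tuple[int, int]]:
    """Hypothetical helper: one (sequence index, chunk index) pair per chunk."""
    pairs = []
    for i_n in range(len(offsets) - 1):
        seq_len = offsets[i_n + 1] - offsets[i_n]  # T = eos - bos in the kernel
        for i_t in range(cdiv(seq_len, chunk_size)):
            pairs.append((i_n, i_t))
    return pairs


# Three packed sequences of lengths 100, 30 and 64, split into 64-token chunks.
offsets = [0, 100, 130, 194]
pairs = build_chunk_indices(offsets, chunk_size=64)
```

Under this assumption the number of pairs is the total chunk count, which is what the wrappers use as `NT = len(indices)` when `offsets` is given.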
378
+ @triton.heuristics({
379
+ 'USE_G': lambda args: args['g'] is not None,
380
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
381
+ })
382
+ @triton.autotune(
383
+ configs=[
384
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
385
+ for num_warps in NUM_WARPS
386
+ for num_stages in [2, 3, 4]
387
+ ],
388
+ key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'USE_G'],
389
+ )
390
+ @triton.jit(do_not_specialize=['T'])
391
+ def chunk_bwd_kernel_dv_local(
392
+ q,
393
+ k,
394
+ g,
395
+ do,
396
+ dv,
397
+ offsets,
398
+ indices,
399
+ scale,
400
+ T,
401
+ H: tl.constexpr,
402
+ K: tl.constexpr,
403
+ V: tl.constexpr,
404
+ BT: tl.constexpr,
405
+ BK: tl.constexpr,
406
+ BV: tl.constexpr,
407
+ USE_G: tl.constexpr,
408
+ USE_OFFSETS: tl.constexpr,
409
+ HEAD_FIRST: tl.constexpr
410
+ ):
411
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
412
+ i_b, i_h = i_bh // H, i_bh % H
413
+ if USE_OFFSETS:
414
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
415
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
416
+ T = eos - bos
417
+ else:
418
+ bos, eos = i_b * T, i_b * T + T
419
+
420
+ # offset calculation
421
+ q += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
422
+ k += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
423
+ do += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V
424
+ dv += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V
425
+ s_qk = K if HEAD_FIRST else H*K
426
+ s_vo = V if HEAD_FIRST else H*V
427
+ s_g = 1 if HEAD_FIRST else H
428
+
429
+ b_A = tl.zeros([BT, BT], dtype=tl.float32)
430
+ for i_k in range(tl.cdiv(K, BK)):
431
+ p_k = tl.make_block_ptr(k, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
432
+ p_q = tl.make_block_ptr(q, (K, T), (1, s_qk), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
433
+ b_q = tl.load(p_q, boundary_check=(0, 1))
434
+ b_k = tl.load(p_k, boundary_check=(0, 1))
435
+ b_A += tl.dot(b_k, b_q)
436
+
437
+ if USE_G:
438
+ g += (i_bh * T) if HEAD_FIRST else (bos * H + i_h)
439
+ p_g = tl.make_block_ptr(g, (T,), (s_g,), (i_t * BT,), (BT,), (0,))
440
+ b_g = tl.load(p_g, boundary_check=(0,))
441
+
442
+ mask = (tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :])
443
+ if USE_G:
444
+ b_A = tl.where(mask, b_A * safe_exp(b_g[None, :] - b_g[:, None]) * scale, 0).to(do.dtype.element_ty)
445
+ else:
446
+ b_A = tl.where(mask, b_A * scale, 0).to(do.dtype.element_ty)
447
+
448
+ for i_v in range(tl.cdiv(V, BV)):
449
+ p_do = tl.make_block_ptr(do, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
450
+ p_dv = tl.make_block_ptr(dv, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
451
+ b_do = tl.load(p_do, boundary_check=(0, 1))
452
+ b_dv = tl.dot(b_A.to(b_do.dtype), b_do)
453
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
454
+
455
+
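For reviewers, the intra-chunk computation in `chunk_bwd_kernel_dv_local` can be checked against a naive loop. The sketch below handles one chunk of one head in pure Python, ignores dtype casts, and assumes `g` is a cumulative log decay that never increases, so `g[j] - g[i] <= 0` for `j >= i` (the range where `safe_exp` presumably behaves like a plain `exp`). `dv_local_reference` is a hypothetical name, not part of the library.

```python
import math
from typing import List, Optional

Matrix = List[List[float]]


def dv_local_reference(q: Matrix, k: Matrix, do: Matrix, scale: float,
                       g: Optional[List[float]] = None) -> Matrix:
    """Naive intra-chunk dv. q, k: [BT][K]; do: [BT][V]; g: length BT or None."""
    BT, V = len(k), len(do[0])
    dv = [[0.0] * V for _ in range(BT)]
    for i in range(BT):
        for j in range(i, BT):  # the kernel's mask keeps entries with i <= j
            a = sum(ki * qj for ki, qj in zip(k[i], q[j])) * scale
            if g is not None:
                a *= math.exp(g[j] - g[i])
            for c in range(V):
                dv[i][c] += a * do[j][c]
    return dv


q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 1.0], [2.0, 0.0]]
do = [[1.0], [10.0]]
dv = dv_local_reference(q, k, do, scale=0.5)
dv_gated = dv_local_reference(q, k, do, scale=0.5, g=[0.0, -1.0])
```

The loop over `j >= i` mirrors the upper-triangular mask applied to `b_A = k @ q^T` before it is multiplied with `do`.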
456
+ def chunk_fwd_o(
457
+ q: torch.Tensor,
458
+ k: torch.Tensor,
459
+ v: torch.Tensor,
460
+ h: torch.Tensor,
461
+ g: Optional[torch.Tensor] = None, # cumsum of log decay
462
+ scale: Optional[float] = None,
463
+ offsets: Optional[torch.LongTensor] = None,
464
+ indices: Optional[torch.LongTensor] = None,
465
+ head_first: bool = True,
466
+ chunk_size: int = 64
467
+ ) -> torch.Tensor:
468
+ if head_first:
469
+ B, H, T, K, V = *q.shape, v.shape[-1]
470
+ else:
471
+ B, T, H, K, V = *q.shape, v.shape[-1]
472
+ if scale is None:
473
+ scale = k.shape[-1] ** -0.5
474
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
475
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
476
+
477
+ o = torch.empty_like(v)
478
+
479
+ def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H)
480
+ chunk_fwd_kernel_o[grid](
481
+ q,
482
+ k,
483
+ v,
484
+ h,
485
+ g,
486
+ o,
487
+ offsets,
488
+ indices,
489
+ scale,
490
+ T=T,
491
+ H=H,
492
+ K=K,
493
+ V=V,
494
+ BT=BT,
495
+ HEAD_FIRST=head_first
496
+ )
497
+ return o
498
+
499
+
500
+ def chunk_bwd_dv(
501
+ q: torch.Tensor,
502
+ k: torch.Tensor,
503
+ g: torch.Tensor,
504
+ do: torch.Tensor,
505
+ dh: torch.Tensor,
506
+ scale: float,
507
+ offsets: Optional[torch.LongTensor] = None,
508
+ indices: Optional[torch.LongTensor] = None,
509
+ head_first: bool = True,
510
+ chunk_size: int = 64
511
+ ) -> torch.Tensor:
512
+ if head_first:
513
+ B, H, T, K, V = *k.shape, do.shape[-1]
514
+ else:
515
+ B, T, H, K, V = *k.shape, do.shape[-1]
516
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
517
+ # H100 can have larger block size
518
+ if check_shared_mem('hopper', k.device.index):
519
+ CONST_TILING = 128
520
+ elif check_shared_mem():
521
+ CONST_TILING = 64
522
+ else:
523
+ CONST_TILING = 32
524
+ BK = min(triton.next_power_of_2(K), CONST_TILING)
525
+ BV = min(triton.next_power_of_2(V), CONST_TILING)
526
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
527
+ NV = triton.cdiv(V, BV)
528
+
529
+ dv = torch.empty_like(do)
530
+ grid = (NV, NT, B * H)
531
+ chunk_bwd_kernel_dv[grid](
532
+ q,
533
+ k,
534
+ g,
535
+ do,
536
+ dv,
537
+ dh,
538
+ offsets,
539
+ indices,
540
+ scale,
541
+ T=T,
542
+ H=H,
543
+ K=K,
544
+ V=V,
545
+ BT=BT,
546
+ BK=BK,
547
+ BV=BV,
548
+ HEAD_FIRST=head_first
549
+ )
550
+ return dv
551
+
552
+
553
+ def chunk_bwd_dv_local(
554
+ q: torch.Tensor,
555
+ k: torch.Tensor,
556
+ g: torch.Tensor,
557
+ do: torch.Tensor,
558
+ dh: torch.Tensor,
559
+ scale: float,
560
+ offsets: Optional[torch.LongTensor] = None,
561
+ indices: Optional[torch.LongTensor] = None,
562
+ head_first: bool = True,
563
+ chunk_size: int = 64
564
+ ) -> torch.Tensor:
565
+ if head_first:
566
+ B, H, T, K, V = *k.shape, do.shape[-1]
567
+ else:
568
+ B, T, H, K, V = *k.shape, do.shape[-1]
569
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
570
+ # H100 can have larger block size
571
+ if check_shared_mem('hopper', k.device.index):
572
+ CONST_TILING = 128
573
+ elif check_shared_mem():
574
+ CONST_TILING = 64
575
+ else:
576
+ CONST_TILING = 32
577
+ BK = min(triton.next_power_of_2(K), CONST_TILING)
578
+ BV = min(triton.next_power_of_2(V), CONST_TILING)
579
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
580
+
581
+ dv = torch.empty_like(do)
582
+ grid = (NT, B * H)
583
+ chunk_bwd_kernel_dv_local[grid](
584
+ q,
585
+ k,
586
+ g,
587
+ do,
588
+ dv,
589
+ offsets,
590
+ indices,
591
+ scale,
592
+ T=T,
593
+ H=H,
594
+ K=K,
595
+ V=V,
596
+ BT=BT,
597
+ BK=BK,
598
+ BV=BV,
599
+ HEAD_FIRST=head_first
600
+ )
601
+ return dv
602
+
603
+
604
+ def chunk_bwd_dqkwg(
605
+ q: torch.Tensor,
606
+ k: torch.Tensor,
607
+ v: torch.Tensor,
608
+ g: torch.Tensor,
609
+ do: torch.Tensor,
610
+ h: torch.Tensor,
611
+ dh: torch.Tensor,
612
+ dv: Optional[torch.Tensor] = None,
613
+ w: Optional[torch.Tensor] = None,
614
+ offsets: Optional[torch.LongTensor] = None,
615
+ indices: Optional[torch.LongTensor] = None,
616
+ chunk_size: int = 64,
617
+ scale: float = 1.0,
618
+ head_first: bool = True,
619
+ ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
620
+
621
+ if head_first:
622
+ B, H, T, K, V = *k.shape, v.shape[-1]
623
+ else:
624
+ B, T, H, K, V = *k.shape, v.shape[-1]
625
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
626
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
627
+
628
+ CONST_TILING = 64 if check_shared_mem() else 32
629
+ BK = min(triton.next_power_of_2(K), CONST_TILING)
630
+ BV = min(triton.next_power_of_2(V), CONST_TILING)
631
+ NK = triton.cdiv(K, BK)
632
+ dq = torch.empty_like(q)
633
+ dk = torch.empty_like(k)
634
+ dg = torch.empty(NK, *g.shape, dtype=torch.float32, device=g.device) if g is not None else None
635
+ dw = torch.empty_like(w) if w is not None else None
636
+
637
+ grid = (NK, NT, B * H)
638
+ chunk_bwd_kernel_dqkwg[grid](
639
+ q=q,
640
+ k=k,
641
+ v=v,
642
+ h=h,
643
+ g=g,
644
+ do=do,
645
+ dh=dh,
646
+ dv=dv,
647
+ w=w,
648
+ dw=dw,
649
+ dq=dq,
650
+ dk=dk,
651
+ dg=dg,
652
+ offsets=offsets,
653
+ indices=indices,
654
+ scale=scale,
655
+ B=B,
656
+ T=T,
657
+ H=H,
658
+ K=K,
659
+ V=V,
660
+ BT=BT,
661
+ BK=BK,
662
+ BV=BV,
663
+ HEAD_FIRST=head_first
664
+ )
665
+
666
+ if dg is not None:
667
+ dg = dg.sum(0)
668
+ return dq, dk, dw, dg
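The wrappers in this file all derive their tile sizes and grid extents the same way. A small pure-Python sketch of that arithmetic is given below; `pick_tiles` is a hypothetical helper, and `const_tiling` is passed in explicitly because `check_shared_mem` is defined elsewhere and its behaviour is not shown in this diff.

```python
from typing import Tuple


def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n, for n >= 1."""
    return 1 << (n - 1).bit_length()


def cdiv(a: int, b: int) -> int:
    """Ceiling division."""
    return (a + b - 1) // b


def pick_tiles(T: int, K: int, V: int, chunk_size: int = 64,
               const_tiling: int = 64) -> Tuple[int, int, int, int, int, int]:
    """Return (BT, BK, BV, NT, NK, NV) following the wrappers' formulas."""
    BT = min(chunk_size, max(16, next_power_of_2(T)))
    BK = min(next_power_of_2(K), const_tiling)
    BV = min(next_power_of_2(V), const_tiling)
    return BT, BK, BV, cdiv(T, BT), cdiv(K, BK), cdiv(V, BV)


long_seq = pick_tiles(T=100, K=96, V=256)
short_seq = pick_tiles(T=5, K=16, V=16)
```

The `max(16, ...)` floor keeps very short sequences on a 16-row tile, which is the usual minimum dimension for `tl.dot`.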
fla/ops/common/chunk_scaled_dot_kkt.py ADDED
@@ -0,0 +1,126 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.common.utils import prepare_chunk_indices
11
+
12
+
13
+ @triton.heuristics({
14
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
15
+ })
16
+ @triton.autotune(
17
+ configs=[
18
+ triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages)
19
+ for BK in [32, 64, 128]
20
+ for num_warps in [2, 4, 8]
21
+ for num_stages in [2, 3, 4]
22
+ ],
23
+ key=['H', 'K', 'BT', 'USE_OFFSETS'],
24
+ )
25
+ @triton.jit(do_not_specialize=['T'])
26
+ def chunk_scaled_dot_kkt_fwd_kernel(
27
+ k,
28
+ beta,
29
+ A,
30
+ offsets,
31
+ indices,
32
+ T,
33
+ H: tl.constexpr,
34
+ K: tl.constexpr,
35
+ BT: tl.constexpr,
36
+ BK: tl.constexpr,
37
+ HEAD_FIRST: tl.constexpr,
38
+ USE_OFFSETS: tl.constexpr,
39
+ ):
40
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
41
+ i_b, i_h = i_bh // H, i_bh % H
42
+ if USE_OFFSETS:
43
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
44
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
45
+ T = eos - bos
46
+ else:
47
+ bos, eos = i_b * T, i_b * T + T
48
+ o_t = tl.arange(0, BT)
49
+
50
+ if HEAD_FIRST:
51
+ p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
52
+ else:
53
+ p_beta = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
54
+ b_beta = tl.load(p_beta, boundary_check=(0,))
55
+
56
+ b_A = tl.zeros([BT, BT], dtype=tl.float32)
57
+ for i_k in range(tl.cdiv(K, BK)):
58
+ if HEAD_FIRST:
59
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
60
+ else:
61
+ p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
62
+ b_k = tl.load(p_k, boundary_check=(0, 1))
63
+ b_kb = b_k * b_beta[:, None]
64
+ b_A += tl.dot(b_kb.to(b_k.dtype), tl.trans(b_k))
65
+
66
+ b_A = tl.where(o_t[:, None] > o_t[None, :], b_A, 0)
67
+ if HEAD_FIRST:
68
+ p_A = tl.make_block_ptr(A + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
69
+ else:
70
+ p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (BT*H, 1), (i_t * BT, 0), (BT, BT), (1, 0))
71
+ tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
72
+
73
+
74
+ def chunk_scaled_dot_kkt_fwd(
75
+ k: torch.Tensor,
76
+ beta: torch.Tensor,
77
+ cu_seqlens: Optional[torch.LongTensor] = None,
78
+ head_first: bool = False,
79
+ chunk_size: int = 64,
80
+ output_dtype: torch.dtype = torch.float32
81
+ ) -> torch.Tensor:
82
+ r"""
83
+ Compute the strictly lower-triangular part of `(beta * K) @ K^T` within each chunk.
84
+
85
+ Args:
86
+ k (torch.Tensor):
87
+ The key tensor of shape `[B, T, H, K]` if not `head_first` else `[B, H, T, K]`.
88
+ beta (torch.Tensor):
89
+ The beta tensor of shape `[B, T, H]` if not `head_first` else `[B, H, T]`.
90
+ cu_seqlens (torch.LongTensor):
91
+ The cumulative sequence lengths of the input tensor.
92
+ Default: None
93
+ head_first (bool):
94
+ If False, the input/output tensor is in the shape of `[B, T, H, K]`.
95
+ If True, the input/output tensor is in the shape of `[B, H, T, K]`.
96
+ Default: False
97
+ chunk_size (int):
98
+ The chunk size. Default: 64.
99
+ output_dtype (torch.dtype):
100
+ The dtype of the output tensor. Default: `torch.float32`
101
+
102
+ Returns:
103
+ beta * K * K^T of shape `[B, T, H, BT]` if not `head_first` else `[B, H, T, BT]`,
104
+ where `BT` is the chunk size.
105
+ """
106
+ if head_first:
107
+ B, H, T, K = k.shape
108
+ else:
109
+ B, T, H, K = k.shape
110
+ BT = chunk_size
111
+ indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
112
+ NT = triton.cdiv(T, BT) if cu_seqlens is None else len(indices)
113
+ A = torch.empty(B, *((H, T) if head_first else (T, H)), BT, device=k.device, dtype=output_dtype)
114
+ chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)](
115
+ k=k,
116
+ beta=beta,
117
+ A=A,
118
+ offsets=cu_seqlens,
119
+ indices=indices,
120
+ T=T,
121
+ H=H,
122
+ K=K,
123
+ BT=BT,
124
+ HEAD_FIRST=head_first
125
+ )
126
+ return A
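As a cross-check for `chunk_scaled_dot_kkt_fwd`, the kernel's output for a single head can be reproduced with a naive loop. The sketch below is pure Python, ignores dtype casts and variable-length offsets, and uses the hypothetical name `scaled_dot_kkt_reference`; it is an illustration of the masking and layout, not the library's implementation.

```python
from typing import List


def scaled_dot_kkt_reference(k: List[List[float]], beta: List[float],
                             chunk_size: int) -> List[List[float]]:
    """k: [T][K], beta: [T]. Returns A: [T][chunk_size] for one head."""
    T = len(k)
    A = [[0.0] * chunk_size for _ in range(T)]
    for t in range(T):
        start = (t // chunk_size) * chunk_size  # first row of t's chunk
        for s in range(start, t):  # strictly lower triangle: s < t
            dot = sum(a * b for a, b in zip(k[t], k[s]))
            A[t][s - start] = beta[t] * dot
    return A


k = [[1.0, 0.0], [1.0, 1.0], [0.0, 2.0]]
beta = [0.5, 2.0, 1.0]
A = scaled_dot_kkt_reference(k, beta, chunk_size=2)
```

Row `t` only interacts with earlier rows of its own chunk, so the diagonal and every cross-chunk entry stay zero.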
fla/ops/common/fused_recurrent.py ADDED
@@ -0,0 +1,575 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.utils import chunk_global_cumsum
11
+ from fla.ops.utils.op import exp
12
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
13
+
14
+
15
+ @triton.heuristics({
16
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
17
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
18
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
19
+ })
20
+ @triton.autotune(
21
+ configs=[
22
+ triton.Config({}, num_warps=num_warps)
23
+ for num_warps in [1, 2, 4]
24
+ ],
25
+ key=["BK", "BV", "USE_GK", "USE_GV", "USE_G"],
26
+ )
27
+ @triton.jit(do_not_specialize=['T'])
28
+ def fused_recurrent_fwd_kernel(
29
+ q,
30
+ k,
31
+ v,
32
+ g,
33
+ gk,
34
+ gv,
35
+ o,
36
+ h0,
37
+ ht,
38
+ offsets,
39
+ scale,
40
+ T,
41
+ B: tl.constexpr,
42
+ H: tl.constexpr,
43
+ K: tl.constexpr,
44
+ V: tl.constexpr,
45
+ BK: tl.constexpr,
46
+ BV: tl.constexpr,
47
+ REVERSE: tl.constexpr,
48
+ USE_G: tl.constexpr,
49
+ USE_GK: tl.constexpr,
50
+ USE_GV: tl.constexpr,
51
+ USE_INITIAL_STATE: tl.constexpr,
52
+ STORE_FINAL_STATE: tl.constexpr,
53
+ USE_OFFSETS: tl.constexpr,
54
+ HEAD_FIRST: tl.constexpr
55
+ ):
56
+ # indices
57
+ i_v, i_k, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64), tl.program_id(2).to(tl.int64)
58
+ i_n, i_h = i_nh // H, i_nh % H
59
+ if USE_OFFSETS:
60
+ bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
61
+ all = T
62
+ T = eos - bos
63
+ else:
64
+ bos, eos = i_n * T, i_n * T + T
65
+ all = B * T
66
+
67
+ if HEAD_FIRST:
68
+ p_q = q + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK)
69
+ p_k = k + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK)
70
+ p_v = v + i_nh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV)
71
+ p_o = o + (i_k * B*H + i_nh) * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV)
72
+ if USE_G:
73
+ p_g = g + i_nh * T + ((T-1) if REVERSE else 0)
74
+ if USE_GK:
75
+ p_gk = gk + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK)
76
+ if USE_GV:
77
+ p_gv = gv + i_nh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV)
78
+ else:
79
+ p_q = q + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
80
+ p_k = k + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
81
+ p_v = v + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
82
+ p_o = o + ((i_k * all + bos) + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
83
+ if USE_G:
84
+ p_g = g + (bos + ((T-1) if REVERSE else 0)) * H + i_h
85
+ if USE_GK:
86
+ p_gk = gk + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
87
+ if USE_GV:
88
+ p_gv = gv + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
89
+
90
+ mask_k = (i_k * BK + tl.arange(0, BK)) < K
91
+ mask_v = (i_v * BV + tl.arange(0, BV)) < V
92
+ mask_h = mask_k[None, :] & mask_v[:, None]
93
+ b_h = tl.zeros([BV, BK], dtype=tl.float32)
94
+
95
+ if USE_INITIAL_STATE:
96
+ p_h0 = h0 + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
97
+ b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
98
+
99
+ for _ in range(0, T):
100
+ b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
101
+ b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
102
+ b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
103
+ if USE_GK:
104
+ b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)
105
+ b_h = b_h * exp(b_gk[None, :])
106
+ if USE_GV:
107
+ b_gv = tl.load(p_gv, mask=mask_v, other=0).to(tl.float32)
108
+ b_h = b_h * exp(b_gv[:, None])
109
+ if USE_G:
110
+ b_g = tl.load(p_g).to(tl.float32)
111
+ b_h = b_h * exp(b_g)
112
+ b_h += b_k[None, :] * b_v[:, None]
113
+ b_o = b_h * b_q[None, :]
114
+ b_o = tl.sum(b_o, axis=1)
115
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
116
+ p_q += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
117
+ p_k += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
118
+ p_v += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
119
+ p_o += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
120
+ if USE_GK:
121
+ p_gk += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
122
+ if USE_GV:
123
+ p_gv += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
124
+ if USE_G:
125
+ p_g += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H)
126
+
127
+ if STORE_FINAL_STATE:
128
+ p_ht = ht + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
129
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
130
+
131
+
132
+ @triton.heuristics({
133
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
134
+ 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None,
135
+ 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
136
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
137
+ })
138
+ @triton.autotune(
139
+ configs=[
140
+ triton.Config({}, num_warps=num_warps)
141
+ for num_warps in [1, 2, 4]
142
+ ],
143
+ key=['BK', 'BV', 'USE_GK', 'USE_GV', 'USE_G'],
144
+ )
145
+ @triton.jit(do_not_specialize=['T'])
146
+ def fused_recurrent_bwd_kernel(
147
+ q,
148
+ k,
149
+ v,
150
+ g,
151
+ gk,
152
+ gv,
153
+ h0,
154
+ do,
155
+ dq,
156
+ dk,
157
+ dv,
158
+ dht,
159
+ dh0,
160
+ offsets,
161
+ scale,
162
+ T,
163
+ B: tl.constexpr,
164
+ H: tl.constexpr,
165
+ K: tl.constexpr,
166
+ V: tl.constexpr,
167
+ BK: tl.constexpr,
168
+ BV: tl.constexpr,
169
+ REVERSE: tl.constexpr,
170
+ USE_G: tl.constexpr,
171
+ USE_GK: tl.constexpr,
172
+ USE_GV: tl.constexpr,
173
+ USE_INITIAL_STATE: tl.constexpr,
174
+ STORE_INITIAL_STATE_GRADIENT: tl.constexpr,
175
+ USE_FINAL_STATE_GRADIENT: tl.constexpr,
176
+ USE_OFFSETS: tl.constexpr,
177
+ HEAD_FIRST: tl.constexpr
178
+ ):
179
+ i_v, i_k, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64), tl.program_id(2).to(tl.int64)
180
+ i_n, i_h = i_nh // H, i_nh % H
181
+ if USE_OFFSETS:
182
+ bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
183
+ all = T
184
+ T = eos - bos
185
+ else:
186
+ bos, eos = i_n * T, i_n * T + T
187
+ all = B * T
188
+
189
+ if HEAD_FIRST:
190
+ p_k = k + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK)
191
+ p_v = v + i_nh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV)
192
+ p_do = do + i_nh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV)
193
+ p_dq = dq + (i_v * B*H + i_nh) * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK)
194
+ if USE_G:
195
+ p_g = g + i_nh * T + ((T-1) if REVERSE else 0)
196
+ if USE_GK:
197
+ p_gk = gk + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK)
198
+ if USE_GV:
199
+ p_gv = gv + i_nh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV)
200
+ else:
201
+ p_k = k + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
202
+ p_v = v + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
203
+ p_do = do + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
204
+ p_dq = dq + ((i_v * all + bos) + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
205
+ if USE_G:
206
+ p_g = g + (bos + ((T-1) if REVERSE else 0)) * H + i_h
207
+ if USE_GK:
208
+ p_gk = gk + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
209
+ if USE_GV:
210
+ p_gv = gv + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
211
+
212
+ mask_k = i_k * BK + tl.arange(0, BK) < K
213
+ mask_v = i_v * BV + tl.arange(0, BV) < V
214
+ mask_h = mask_k[:, None] & mask_v[None, :]
215
+
216
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
217
+ if USE_INITIAL_STATE:
218
+ p_h0 = h0 + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
219
+ b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
220
+
221
+ for _ in range(0, T):
222
+ b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
223
+ b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
224
+ b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32)
225
+ if USE_G:
226
+ b_g = tl.load(p_g).to(tl.float32)
227
+ b_h = b_h * exp(b_g)
228
+ if USE_GK:
229
+ b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)
230
+ b_h = b_h * exp(b_gk[:, None])
231
+ if USE_GV:
232
+ b_gv = tl.load(p_gv, mask=mask_v, other=0).to(tl.float32)
233
+ b_h = b_h * exp(b_gv[None, :])
234
+ b_h += b_k[:, None] * b_v[None, :]
235
+ b_dq = b_h * b_do[None, :]
236
+ b_dq = tl.sum(b_dq, axis=1) * scale
237
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), mask=mask_k)
238
+
239
+ p_k += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
240
+ p_v += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
241
+ p_do += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
242
+ p_dq += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
243
+ if USE_G:
244
+ p_g += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H)
245
+ if USE_GK:
246
+ p_gk += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
247
+ if USE_GV:
248
+ p_gv += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
249
+
250
+ # sync threads
251
+ tl.debug_barrier()
252
+
253
+ if HEAD_FIRST:
254
+ p_q = q + i_nh * T*K + ((T - 1) * K if not REVERSE else 0) + i_k * BK + tl.arange(0, BK)
255
+ p_k = k + i_nh * T*K + ((T - 1) * K if not REVERSE else 0) + i_k * BK + tl.arange(0, BK)
256
+ p_v = v + i_nh * T*V + ((T - 1) * V if not REVERSE else 0) + i_v * BV + tl.arange(0, BV)
257
+ p_do = do + i_nh * T*V + ((T - 1) * V if not REVERSE else 0) + i_v * BV + tl.arange(0, BV)
258
+ p_dk = dk + (i_v * B*H + i_nh) * T*K + ((T - 1) * K if not REVERSE else 0) + i_k * BK + tl.arange(0, BK)
259
+ p_dv = dv + (i_k * B*H + i_nh) * T*V + ((T - 1) * V if not REVERSE else 0) + i_v * BV + tl.arange(0, BV)
260
+ if USE_G:
261
+ p_g = g + i_nh * T + ((T - 1) if not REVERSE else 0)
262
+ if USE_GK:
263
+ p_gk = gk + i_nh * T*K + ((T - 1) * K if not REVERSE else 0) + i_k * BK + tl.arange(0, BK)
264
+ if USE_GV:
265
+ p_gv = gv + i_nh * T*V + ((T - 1) * V if not REVERSE else 0) + i_v * BV + tl.arange(0, BV)
266
+ else:
267
+ p_q = q + (bos + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
268
+ p_k = k + (bos + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
269
+ p_v = v + (bos + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
270
+ p_do = do + (bos + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
271
+ p_dk = dk + ((i_v * all + bos) + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
272
+ p_dv = dv + ((i_k * all + bos) + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
273
+ if USE_G:
274
+ p_g = g + (bos + ((T - 1) if not REVERSE else 0)) * H + i_h
275
+ if USE_GK:
276
+ p_gk = gk + (bos + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
277
+ if USE_GV:
278
+ p_gv = gv + (bos + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
279
+
280
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
281
+ if USE_FINAL_STATE_GRADIENT:
282
+ p_dht = dht + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
283
+ b_dh += tl.load(p_dht, mask=mask_h, other=0).to(tl.float32)
284
+
285
+ for _ in range(T):
286
+ b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
287
+ b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
288
+ b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
289
+ b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32)
290
+ b_dh += b_q[:, None] * b_do[None, :]
291
+ b_dk = tl.sum(b_dh * b_v[None, :], axis=1)
292
+ b_dv = tl.sum(b_dh * b_k[:, None], axis=0)
293
+ if USE_G:
294
+ b_g = tl.load(p_g).to(tl.float32)
295
+ b_dh *= exp(b_g)
296
+ if USE_GK:
297
+ b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)
298
+ b_dh *= exp(b_gk)[:, None]
299
+ if USE_GV:
300
+ b_gv = tl.load(p_gv, mask=mask_v, other=0).to(tl.float32)
301
+ b_dh *= exp(b_gv)[None, :]
302
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_k)
303
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_v)
304
+
305
+ p_q += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * K
306
+ p_k += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * K
307
+ p_v += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * V
308
+ p_do += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * V
309
+ p_dk += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * K
310
+ p_dv += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * V
311
+ if USE_G:
312
+ p_g += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H)
313
+ if USE_GK:
314
+ p_gk += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * K
315
+ if USE_GV:
316
+ p_gv += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * V
317
+
318
+ if STORE_INITIAL_STATE_GRADIENT:
319
+ p_dh0 = dh0 + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
320
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask_h)
321
+
322
+
323
+ def fused_recurrent_fwd(
324
+ q: torch.Tensor,
325
+ k: torch.Tensor,
326
+ v: torch.Tensor,
327
+ g: Optional[torch.Tensor] = None,
328
+ gk: Optional[torch.Tensor] = None,
329
+ gv: Optional[torch.Tensor] = None,
330
+ scale: Optional[float] = None,
331
+ initial_state: Optional[torch.Tensor] = None,
332
+ output_final_state: bool = False,
333
+ reverse: bool = False,
334
+ offsets: Optional[torch.LongTensor] = None,
335
+ head_first: bool = True
336
+ ):
337
+ if head_first:
338
+ B, H, T, K, V = *k.shape, v.shape[-1]
339
+ else:
340
+ B, T, H, K, V = *k.shape, v.shape[-1]
341
+ N = B if offsets is None else len(offsets) - 1
342
+ BK, BV = min(K, 64), min(V, 64)
343
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
344
+
345
+ h0 = initial_state
346
+ if output_final_state:
347
+ ht = q.new_empty(N, H, K, V, dtype=torch.float32)
348
+ else:
349
+ ht = None
350
+ o = q.new_empty(NK, *v.shape, dtype=torch.float32)
351
+
352
+ grid = (NV, NK, N * H)
353
+ fused_recurrent_fwd_kernel[grid](
354
+ q,
355
+ k,
356
+ v,
357
+ g,
358
+ gk,
359
+ gv,
360
+ o,
361
+ h0,
362
+ ht,
363
+ offsets,
364
+ scale,
365
+ T=T,
366
+ B=B,
367
+ H=H,
368
+ K=K,
369
+ V=V,
370
+ BK=BK,
371
+ BV=BV,
372
+ USE_G=g is not None,
373
+ USE_GK=gk is not None,
374
+ USE_GV=gv is not None,
375
+ REVERSE=reverse,
376
+ HEAD_FIRST=head_first
377
+ )
378
+ o = o.sum(0)
379
+ return o, ht
380
+
381
+
382
+ def fused_recurrent_bwd(
383
+ q: torch.Tensor,
384
+ k: torch.Tensor,
385
+ v: torch.Tensor,
386
+ g: Optional[torch.Tensor] = None,
387
+ gk: Optional[torch.Tensor] = None,
388
+ gv: Optional[torch.Tensor] = None,
389
+ o: Optional[torch.Tensor] = None,
390
+ do: Optional[torch.Tensor] = None,
391
+ dht: Optional[torch.Tensor] = None,
392
+ scale: Optional[float] = None,
393
+ initial_state: Optional[torch.Tensor] = None,
394
+ reverse: bool = False,
395
+ offsets: Optional[torch.LongTensor] = None,
396
+ head_first: bool = True
397
+ ):
398
+ if head_first:
399
+ B, H, T, K, V = *k.shape, v.shape[-1]
400
+ else:
401
+ B, T, H, K, V = *k.shape, v.shape[-1]
402
+ N = B if offsets is None else len(offsets) - 1
403
+
404
+ BK, BV = min(K, 64), min(V, 64)
405
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
406
+
407
+ dq = q.new_empty(NV, *q.shape, dtype=torch.float32)
408
+ dk = q.new_empty(NV, *k.shape, dtype=torch.float32)
409
+ dv = q.new_empty(NK, *v.shape, dtype=torch.float32)
410
+ h0 = initial_state
411
+ dh0 = torch.empty_like(initial_state) if initial_state is not None else None
412
+
413
+ grid = (NV, NK, N * H)
414
+ fused_recurrent_bwd_kernel[grid](
415
+ q,
416
+ k,
417
+ v,
418
+ g,
419
+ gk,
420
+ gv,
421
+ h0,
422
+ do,
423
+ dq,
424
+ dk,
425
+ dv,
426
+ dht,
427
+ dh0,
428
+ offsets,
429
+ scale,
430
+ B=B,
431
+ T=T,
432
+ H=H,
433
+ K=K,
434
+ V=V,
435
+ BK=BK,
436
+ BV=BV,
437
+ USE_G=g is not None,
438
+ USE_GK=gk is not None,
439
+ USE_GV=gv is not None,
440
+ REVERSE=reverse,
441
+ HEAD_FIRST=head_first
442
+ )
443
+ dq = dq.sum(0)
444
+ dk = dk.sum(0)
445
+ dv = dv.sum(0)
446
+ dg, dgk, dgv = None, None, None
447
+ if g is not None:
448
+ dg = chunk_global_cumsum(
449
+ (dq * q.float() - dk * k.float()).sum(-1),
450
+ reverse=not reverse,
451
+ offsets=offsets,
452
+ head_first=head_first
453
+ )
454
+ if gk is not None:
455
+ dgk = chunk_global_cumsum(
456
+ dq * q.float() - dk * k.float(),
457
+ reverse=not reverse,
458
+ offsets=offsets,
459
+ head_first=head_first
460
+ )
461
+ if gv is not None:
462
+ dgv = chunk_global_cumsum(
463
+ do.float() * o.float() - dv * v.float(),
464
+ reverse=not reverse,
465
+ offsets=offsets,
466
+ head_first=head_first
467
+ )
468
+
469
+ return dq, dk, dv, dg, dgk, dgv, dh0
470
+
471
+
472
+ class FusedRecurrentFunction(torch.autograd.Function):
473
+
474
+ @staticmethod
475
+ @input_guard
476
+ @autocast_custom_fwd
477
+ def forward(
478
+ ctx,
479
+ q: torch.Tensor,
480
+ k: torch.Tensor,
481
+ v: torch.Tensor,
482
+ g: Optional[torch.Tensor] = None,
483
+ gk: Optional[torch.Tensor] = None,
484
+ gv: Optional[torch.Tensor] = None,
485
+ scale: Optional[float] = None,
486
+ initial_state: Optional[torch.Tensor] = None,
487
+ output_final_state: bool = False,
488
+ reverse: bool = False,
489
+ offsets: Optional[torch.LongTensor] = None,
490
+ head_first: bool = True
491
+ ):
492
+ o, ht = fused_recurrent_fwd(
493
+ q=q,
494
+ k=k,
495
+ v=v,
496
+ g=g,
497
+ gk=gk,
498
+ gv=gv,
499
+ scale=scale,
500
+ initial_state=initial_state,
501
+ output_final_state=output_final_state,
502
+ reverse=reverse,
503
+ offsets=offsets,
504
+ head_first=head_first
505
+ )
506
+ ctx.save_for_backward(q, k, v, g, gk, gv, initial_state, o)
507
+ ctx.scale = scale
508
+ ctx.reverse = reverse
509
+ ctx.offsets = offsets
510
+ ctx.head_first = head_first
511
+ return o.to(q.dtype), ht
512
+
513
+ @staticmethod
514
+ @input_guard
515
+ @autocast_custom_bwd
516
+ def backward(ctx, do, dht):
517
+ q, k, v, g, gk, gv, initial_state, o = ctx.saved_tensors
518
+ # combining a non-zero final-state gradient with trainable gates is not supported yet.
519
+ if dht is not None:
520
+ if not dht.eq(0).all():
521
+ if g is not None:
522
+ assert g.requires_grad is False, "Cannot load final state gradient and use gates at the same time"
523
+ if gk is not None:
524
+ assert gk.requires_grad is False, "Cannot load final state gradient and use gates at the same time"
525
+ if gv is not None:
526
+ assert gv.requires_grad is False, "Cannot load final state gradient and use gates at the same time"
527
+ dq, dk, dv, dg, dgk, dgv, dh0 = fused_recurrent_bwd(
528
+ q=q,
529
+ k=k,
530
+ v=v,
531
+ g=g,
532
+ gk=gk,
533
+ gv=gv,
534
+ o=o,
535
+ do=do,
536
+ dht=dht,
537
+ scale=ctx.scale,
538
+ initial_state=initial_state,
539
+ reverse=ctx.reverse,
540
+ offsets=ctx.offsets,
541
+ head_first=ctx.head_first
542
+ )
543
+ return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dg, dgk, dgv, None, dh0, None, None, None, None
544
+
545
+
546
+ def fused_recurrent(
547
+ q: torch.Tensor,
548
+ k: torch.Tensor,
549
+ v: torch.Tensor,
550
+ g: Optional[torch.Tensor] = None,
551
+ gk: Optional[torch.Tensor] = None,
552
+ gv: Optional[torch.Tensor] = None,
553
+ scale: Optional[float] = None,
554
+ initial_state: Optional[torch.Tensor] = None,
555
+ output_final_state: bool = False,
556
+ reverse: bool = False,
557
+ cu_seqlens: Optional[torch.LongTensor] = None,
558
+ head_first: bool = True
559
+ ):
560
+ if scale is None:
561
+ scale = k.shape[-1] ** -0.5
562
+ return FusedRecurrentFunction.apply(
563
+ q,
564
+ k,
565
+ v,
566
+ g,
567
+ gk,
568
+ gv,
569
+ scale,
570
+ initial_state,
571
+ output_final_state,
572
+ reverse,
573
+ cu_seqlens,
574
+ head_first
575
+ )
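The wrappers above allocate `o` with a leading `NK` dimension and reduce it with `o.sum(0)`. The following is a minimal plain-Python sketch of why that reduction is valid, assuming a single head and an ungated recurrence `h_t = h_{t-1} + k_t v_t^T`, `o_t = scale * q_t^T h_t`. The helper names `recurrent_output` and `split_k_output` are hypothetical and not part of `fla`; the real kernels run in Triton on blocked tensors.

```python
from typing import List

Vector = List[float]


def recurrent_output(q: List[Vector], k: List[Vector], v: List[Vector], scale: float) -> List[Vector]:
    """Reference recurrence for one head: h_t = h_{t-1} + k_t v_t^T, o_t = scale * q_t^T h_t."""
    d_k, d_v = len(k[0]), len(v[0])
    h = [[0.0] * d_v for _ in range(d_k)]  # running state of shape [d_k, d_v]
    outputs = []
    for q_t, k_t, v_t in zip(q, k, v):
        for i in range(d_k):
            for j in range(d_v):
                h[i][j] += k_t[i] * v_t[j]
        outputs.append([scale * sum(q_t[i] * h[i][j] for i in range(d_k)) for j in range(d_v)])
    return outputs


def split_k_output(q: List[Vector], k: List[Vector], v: List[Vector], scale: float, block_k: int) -> List[Vector]:
    """Run the recurrence independently on each K-block, then sum the partial outputs."""
    d_k, d_v = len(k[0]), len(v[0])
    total = [[0.0] * d_v for _ in q]
    for start in range(0, d_k, block_k):
        q_blk = [row[start:start + block_k] for row in q]
        k_blk = [row[start:start + block_k] for row in k]
        partial = recurrent_output(q_blk, k_blk, v, scale)
        for t, row in enumerate(partial):
            for j, x in enumerate(row):
                total[t][j] += x  # plays the role of `o.sum(0)` over the NK axis
    return total
```

Because the output is a dot product over the key dimension, each K-block contributes an independent partial sum, so the blocks never need to see each other's state.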
fla/ops/common/utils.py ADDED
@@ -0,0 +1,69 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ import torch
5
+ import triton
6
+ import triton.language as tl
7
+
8
+ from fla.utils import tensor_cache
9
+
10
+
11
+ @triton.autotune(
12
+ configs=[
13
+ triton.Config({}, num_warps=num_warps)
14
+ for num_warps in [4, 8, 16, 32]
15
+ ],
16
+ key=['B'],
17
+ )
18
+ @triton.jit
19
+ def prepare_position_ids_kernel(
20
+ y,
21
+ offsets,
22
+ B: tl.constexpr
23
+ ):
24
+ i_n = tl.program_id(0)
25
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
26
+ T = eos - bos
27
+
28
+ o = tl.arange(0, B)
29
+ for i in range(0, tl.cdiv(T, B) * B, B):
30
+ o_i = o + i
31
+ tl.store(y + bos + o_i, o_i, o_i < T)
32
+
33
+
34
+ @tensor_cache
35
+ def prepare_lens(offsets: torch.LongTensor) -> torch.LongTensor:
36
+ return offsets[1:] - offsets[:-1]
37
+
38
+
39
+ @tensor_cache
40
+ def prepare_position_ids(offsets: torch.LongTensor) -> torch.LongTensor:
41
+ return torch.cat([torch.arange(n, dtype=offsets.dtype, device=offsets.device) for n in prepare_lens(offsets).unbind()])
42
+
43
+
44
+ @tensor_cache
45
+ def prepare_sequence_ids(position_ids: torch.LongTensor) -> torch.LongTensor:
46
+ return position_ids.eq(0).cumsum(0) - 1
47
+
48
+
49
+ @tensor_cache
50
+ def prepare_token_indices(offsets: torch.LongTensor) -> torch.LongTensor:
51
+ position_ids = prepare_position_ids(offsets)
52
+ return torch.stack([prepare_sequence_ids(position_ids), position_ids], 1).to(offsets)
53
+
54
+
55
+ @tensor_cache
56
+ def prepare_chunk_indices(
57
+ offsets: torch.LongTensor,
58
+ chunk_size: int
59
+ ) -> torch.LongTensor:
60
+ indices = torch.cat([torch.arange(n) for n in triton.cdiv(prepare_lens(offsets), chunk_size).tolist()])
61
+ return torch.stack([prepare_sequence_ids(indices), indices], 1).to(offsets)
62
+
63
+
64
+ @tensor_cache
65
+ def prepare_chunk_offsets(
66
+ offsets: torch.LongTensor,
67
+ chunk_size: int
68
+ ) -> torch.LongTensor:
69
+ return torch.cat([offsets.new_tensor([0]), triton.cdiv(prepare_lens(offsets), chunk_size)]).cumsum(-1)
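To make the variable-length helpers above concrete, here is a plain-Python sketch of the values they are expected to produce for a given `offsets` (cumulative sequence lengths). The function names below are hypothetical stand-ins that operate on Python lists rather than tensors, and they assume every sequence is non-empty.

```python
from itertools import accumulate
from typing import List, Tuple


def lens_from_offsets(offsets: List[int]) -> List[int]:
    """Per-sequence lengths: offsets[i + 1] - offsets[i]."""
    return [end - start for start, end in zip(offsets[:-1], offsets[1:])]


def position_ids(offsets: List[int]) -> List[int]:
    """Position of every token inside its own sequence, concatenated."""
    return [p for n in lens_from_offsets(offsets) for p in range(n)]


def chunk_indices(offsets: List[int], chunk_size: int) -> List[Tuple[int, int]]:
    """(sequence id, chunk id within that sequence) for every chunk."""
    pairs = []
    for seq_id, n in enumerate(lens_from_offsets(offsets)):
        num_chunks = -(-n // chunk_size)  # ceiling division
        pairs.extend((seq_id, c) for c in range(num_chunks))
    return pairs


def chunk_offsets(offsets: List[int], chunk_size: int) -> List[int]:
    """Cumulative number of chunks before each sequence, with the total appended."""
    counts = [-(-n // chunk_size) for n in lens_from_offsets(offsets)]
    return [0] + list(accumulate(counts))
```

For `offsets = [0, 100, 356]` and `chunk_size = 64`, the two sequences have 2 and 4 chunks, which matches the example given in the comment inside `ChunkDeltaRuleFunction.forward`.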
fla/ops/delta_rule/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (357 Bytes).
 
fla/ops/delta_rule/chunk.py ADDED
@@ -0,0 +1,373 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ from einops import rearrange
9
+
10
+ from fla.modules.l2norm import l2norm_bwd, l2norm_fwd
11
+ from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu, chunk_gated_delta_rule_fwd_h
12
+ from fla.ops.common.chunk_o import chunk_bwd_dqkwg, chunk_bwd_dv_local, chunk_fwd_o
13
+ from fla.ops.common.utils import prepare_chunk_indices
14
+ from fla.ops.delta_rule.wy_fast import bwd_prepare_wy_repr, fwd_prepare_wy_repr, fwd_recompute_w_u
15
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
16
+
17
+
18
+ def chunk_delta_rule_fwd(
19
+ q: torch.Tensor,
20
+ k: torch.Tensor,
21
+ v: torch.Tensor,
22
+ beta: torch.Tensor,
23
+ scale: float,
24
+ initial_state: torch.Tensor,
25
+ output_final_state: bool,
26
+ offsets: Optional[torch.LongTensor] = None,
27
+ indices: Optional[torch.LongTensor] = None,
28
+ head_first: bool = True,
29
+ chunk_size: int = 64
30
+ ):
31
+ T = q.shape[2] if head_first else q.shape[1]
32
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
33
+ # obtain WY representation. u is actually the new v.
34
+ w, u, A = fwd_prepare_wy_repr(
35
+ k=k,
36
+ v=v,
37
+ beta=beta,
38
+ offsets=offsets,
39
+ indices=indices,
40
+ head_first=head_first,
41
+ chunk_size=BT
42
+ )
43
+
44
+ h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
45
+ k=k,
46
+ w=w,
47
+ u=u,
48
+ g=None,
49
+ initial_state=initial_state,
50
+ output_final_state=output_final_state,
51
+ offsets=offsets,
52
+ indices=indices,
53
+ head_first=head_first,
54
+ chunk_size=BT
55
+ )
56
+ o = chunk_fwd_o(
57
+ q=q,
58
+ k=k,
59
+ v=v_new,
60
+ h=h,
61
+ g=None,
62
+ scale=scale,
63
+ offsets=offsets,
64
+ indices=indices,
65
+ head_first=head_first,
66
+ chunk_size=BT
67
+ )
68
+ return o, A, final_state
69
+
70
+
71
+ def chunk_delta_rule_bwd(
72
+ q: torch.Tensor,
73
+ k: torch.Tensor,
74
+ v: torch.Tensor,
75
+ beta: torch.Tensor,
76
+ A: torch.Tensor,
77
+ scale: float,
78
+ initial_state: torch.Tensor,
79
+ do: torch.Tensor,
80
+ dht: torch.Tensor,
81
+ offsets: Optional[torch.LongTensor] = None,
82
+ indices: Optional[torch.LongTensor] = None,
83
+ head_first: bool = True,
84
+ chunk_size: int = 64
85
+ ):
86
+ T = q.shape[2] if head_first else q.shape[1]
87
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
88
+ w, u = fwd_recompute_w_u(
89
+ k=k,
90
+ v=v,
91
+ beta=beta,
92
+ A=A,
93
+ offsets=offsets,
94
+ indices=indices,
95
+ head_first=head_first,
96
+ chunk_size=BT
97
+ )
98
+ h, v_new, _ = chunk_gated_delta_rule_fwd_h(
99
+ k=k,
100
+ w=w,
101
+ u=u,
102
+ g=None,
103
+ initial_state=initial_state,
104
+ output_final_state=False,
105
+ offsets=offsets,
106
+ indices=indices,
107
+ head_first=head_first,
108
+ chunk_size=BT
109
+ )
110
+ dv = chunk_bwd_dv_local(
111
+ q=q,
112
+ k=k,
113
+ do=do,
114
+ g=None,
115
+ dh=None,
116
+ scale=scale,
117
+ offsets=offsets,
118
+ indices=indices,
119
+ head_first=head_first,
120
+ chunk_size=BT
121
+ )
122
+ dh, dh0, dv = chunk_gated_delta_rule_bwd_dhu(
123
+ q=q,
124
+ k=k,
125
+ w=w,
126
+ g=None,
127
+ h0=initial_state,
128
+ dht=dht,
129
+ do=do,
130
+ dv=dv,
131
+ scale=scale,
132
+ offsets=offsets,
133
+ indices=indices,
134
+ head_first=head_first,
135
+ chunk_size=BT
136
+ )
137
+ dq, dk, dw, _ = chunk_bwd_dqkwg(
138
+ q=q,
139
+ k=k,
140
+ v=v_new,
141
+ h=h,
142
+ w=w,
143
+ dv=dv,
144
+ do=do,
145
+ dh=dh,
146
+ g=None,
147
+ scale=scale,
148
+ offsets=offsets,
149
+ indices=indices,
150
+ head_first=head_first,
151
+ chunk_size=BT
152
+ )
153
+ dk2, dv, db = bwd_prepare_wy_repr(
154
+ k=k,
155
+ v=v,
156
+ beta=beta,
157
+ A=A,
158
+ dw=dw,
159
+ du=dv,
160
+ offsets=offsets,
161
+ indices=indices,
162
+ head_first=head_first,
163
+ chunk_size=BT
164
+ )
165
+ dk.add_(dk2)
166
+ return dq, dk, dv, db, dh0
167
+
168
+
169
+ class ChunkDeltaRuleFunction(torch.autograd.Function):
170
+
171
+ @staticmethod
172
+ @input_guard
173
+ @autocast_custom_fwd
174
+ def forward(
175
+ ctx,
176
+ q: torch.Tensor,
177
+ k: torch.Tensor,
178
+ v: torch.Tensor,
179
+ beta: torch.Tensor,
180
+ scale: float,
181
+ initial_state: torch.Tensor,
182
+ output_final_state: bool,
183
+ offsets: Optional[torch.LongTensor] = None,
184
+ head_first: bool = True,
185
+ use_qk_l2norm_in_kernel: bool = True
186
+ ):
187
+ T = q.shape[2] if head_first else q.shape[1]
188
+ chunk_size = min(64, max(triton.next_power_of_2(T), 16))
189
+
190
+ q_orig = q
191
+ k_orig = k
192
+
193
+ if use_qk_l2norm_in_kernel:
194
+ q = l2norm_fwd(q)
195
+ k = l2norm_fwd(k)
196
+
197
+ # 2-d indices denoting the offsets of chunks in each sequence
198
+ # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64,
199
+ # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be
200
+ # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
201
+ indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None
202
+
203
+ o, A, final_state = chunk_delta_rule_fwd(
204
+ q=q,
205
+ k=k,
206
+ v=v,
207
+ beta=beta,
208
+ scale=scale,
209
+ initial_state=initial_state,
210
+ output_final_state=output_final_state,
211
+ offsets=offsets,
212
+ indices=indices,
213
+ head_first=head_first,
214
+ chunk_size=chunk_size
215
+ )
216
+ ctx.save_for_backward(q_orig, k_orig, v, beta, A, initial_state)
217
+ ctx.chunk_size = chunk_size
218
+ ctx.scale = scale
219
+ ctx.offsets = offsets
220
+ ctx.indices = indices
221
+ ctx.head_first = head_first
222
+ ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
223
+ return o.to(q.dtype), final_state
224
+
225
+ @staticmethod
226
+ @input_guard
227
+ @autocast_custom_bwd
228
+ def backward(
229
+ ctx,
230
+ do: torch.Tensor,
231
+ dht: torch.Tensor
232
+ ):
233
+ q, k, v, beta, A, initial_state = ctx.saved_tensors
234
+ use_qk_l2norm_in_kernel = ctx.use_qk_l2norm_in_kernel
235
+ if use_qk_l2norm_in_kernel:
236
+ q, q_orig = l2norm_fwd(q), q
237
+ k, k_orig = l2norm_fwd(k), k
238
+
239
+ dq, dk, dv, db, dh0 = chunk_delta_rule_bwd(
240
+ q=q,
241
+ k=k,
242
+ v=v,
243
+ beta=beta,
244
+ A=A,
245
+ scale=ctx.scale,
246
+ initial_state=initial_state,
247
+ do=do,
248
+ dht=dht,
249
+ offsets=ctx.offsets,
250
+ indices=ctx.indices,
251
+ head_first=ctx.head_first,
252
+ chunk_size=ctx.chunk_size
253
+ )
254
+ if use_qk_l2norm_in_kernel:
255
+ dq = l2norm_bwd(q_orig, dq)
256
+ dk = l2norm_bwd(k_orig, dk)
257
+ return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), db.to(beta.dtype), None, dh0, None, None, None, None, None, None
258
+
259
+
260
+ @torch.compiler.disable
261
+ def chunk_delta_rule(
262
+ q: torch.Tensor,
263
+ k: torch.Tensor,
264
+ v: torch.Tensor,
265
+ beta: torch.Tensor,
266
+ scale: Optional[float] = None,
267
+ initial_state: Optional[torch.Tensor] = None,
268
+ output_final_state: bool = False,
269
+ cu_seqlens: Optional[torch.LongTensor] = None,
270
+ head_first: bool = False,
271
+ use_qk_l2norm_in_kernel: bool = False
272
+ ):
273
+ r"""
274
+ Args:
275
+ q (torch.Tensor):
276
+ queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
277
+ k (torch.Tensor):
278
+ keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
279
+ v (torch.Tensor):
280
+ values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
281
+ beta (torch.Tensor):
282
+ betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
283
+ scale (Optional[float]):
284
+ Scale factor applied to the queries.
285
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
286
+ initial_state (Optional[torch.Tensor]):
287
+ Initial state of shape `[N, H, K, V]` for `N` input sequences.
288
+ For equal-length input sequences, `N` equals the batch size `B`.
289
+ Default: `None`.
290
+ output_final_state (Optional[bool]):
291
+ Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
292
+ cu_seqlens (torch.LongTensor):
293
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
294
+ consistent with the FlashAttention API.
295
+ head_first (Optional[bool]):
296
+ Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
297
+ Default: `False`.
298
+ use_qk_l2norm_in_kernel (Optional[bool]):
299
+ Whether to apply L2 normalization to `q` and `k` inside the kernel to save GPU memory.
300
+ Default: `False`.
301
+
302
+ Returns:
303
+ o (torch.Tensor):
304
+ Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
305
+ final_state (torch.Tensor):
306
+ Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
307
+
308
+ Examples::
309
+ >>> import torch
310
+ >>> import torch.nn.functional as F
311
+ >>> from einops import rearrange
312
+ >>> from fla.ops.delta_rule import chunk_delta_rule
313
+ # inputs with equal lengths
314
+ >>> B, T, H, K, V = 4, 2048, 4, 512, 512
315
+ >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
316
+ >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
317
+ >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
318
+ >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
319
+ >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
320
+ >>> o, ht = chunk_delta_rule(
321
+ q, k, v, beta,
322
+ initial_state=h0,
323
+ output_final_state=True
324
+ )
325
+ # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
326
+ >>> q, k, v, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta))
327
+ # for a batch with 4 sequences, `cu_seqlens` is expected to hold 5 start/end positions
328
+ >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
329
+ >>> o_var, ht_var = chunk_delta_rule(
330
+ q, k, v, beta,
331
+ initial_state=h0,
332
+ output_final_state=True,
333
+ cu_seqlens=cu_seqlens
334
+ )
335
+ """
336
+ assert q.dtype == k.dtype == v.dtype
337
+ assert q.dtype != torch.float32, "ChunkDeltaRuleFunction does not support float32. Please use bfloat16."
338
+ assert len(beta.shape) == 3, "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
339
+
340
+ if cu_seqlens is not None:
341
+ if q.shape[0] != 1:
342
+ raise ValueError(
343
+ f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
344
+ "Please flatten variable-length inputs before processing."
345
+ )
346
+ if head_first:
347
+ raise RuntimeError(
348
+ "Sequences with variable lengths are not supported for head-first mode"
349
+ )
350
+ if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
351
+ raise ValueError(
352
+ f"The number of initial states is expected to be equal to the number of input sequences, "
353
+ f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
354
+ )
355
+ if head_first:
356
+ q, k, v = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v))
357
+ beta = rearrange(beta, 'b h t -> b t h')
358
+ scale = k.shape[-1] ** -0.5 if scale is None else scale
359
+ o, final_state = ChunkDeltaRuleFunction.apply(
360
+ q,
361
+ k,
362
+ v,
363
+ beta,
364
+ scale,
365
+ initial_state,
366
+ output_final_state,
367
+ cu_seqlens,
368
+ False,
369
+ use_qk_l2norm_in_kernel
370
+ )
371
+ if head_first:
372
+ o = rearrange(o, 'b t h v -> b h t v')
373
+ return o, final_state
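The chunk size used in `ChunkDeltaRuleFunction.forward` follows the rule `min(64, max(next_power_of_2(T), 16))`. The sketch below reproduces that rule in plain Python to show how it behaves for short and long sequences. `next_power_of_2` here is a local stand-in written for illustration (assumed to match the usual "smallest power of two not below `n`" meaning), and `pick_chunk_size` is a hypothetical name.

```python
def next_power_of_2(n: int) -> int:
    """Smallest power of two that is >= n, for n >= 1."""
    return 1 << (n - 1).bit_length()


def pick_chunk_size(seq_len: int, max_chunk: int = 64, min_chunk: int = 16) -> int:
    """Clamp the next power of two of the sequence length into [min_chunk, max_chunk]."""
    return min(max_chunk, max(next_power_of_2(seq_len), min_chunk))
```

Under this rule very short inputs still use a chunk of 16, while any sequence longer than 32 tokens uses the full chunk of 64.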
fla/ops/delta_rule/naive.py ADDED
@@ -0,0 +1,120 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+ from einops import rearrange
5
+
6
+
7
+ def delta_rule_recurrence(q, k, v, beta, initial_state=None, output_final_state=True):
8
+ orig_dtype = q.dtype
9
+ b, h, l, d_k = q.shape
10
+ q, k, v, beta = map(lambda x: x.float(), [q, k, v, beta])
11
+ d_v = v.shape[-1]
12
+ o = torch.zeros_like(v)
13
+ S = torch.zeros(b, h, d_k, d_v).to(v)
14
+ q = q * (d_k ** -0.5)
15
+
16
+ if beta.ndim < v.ndim:
17
+ beta = beta[..., None]
18
+
19
+ if initial_state is not None:
20
+ S += initial_state
21
+
22
+ for i in range(l):
23
+ _k = k[:, :, i]
24
+ _q = q[:, :, i]
25
+ _v = v[:, :, i].clone()
26
+ beta_i = beta[:, :, i]
27
+ _v = _v - (S.clone() * _k[..., None]).sum(-2)
28
+ _v = _v * beta_i
29
+ S = S.clone() + _k.unsqueeze(-1) * _v.unsqueeze(-2)
30
+ o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S)
31
+ S = None if output_final_state is False else S
32
+ return o.to(orig_dtype), S
33
+
34
+
35
+ def delta_rule_chunkwise(q, k, v, beta, chunk_size=32):
36
+ b, h, l, d_k = q.shape
37
+ d_v = v.shape[-1]
38
+ q = q * (d_k ** -0.5)
39
+ v = v * beta[..., None]
40
+ k_beta = k * beta[..., None]
41
+
42
+ assert l % chunk_size == 0
43
+
44
+ # compute (I - tri(diag(beta) KK^T))^{-1}
45
+ mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=0)
46
+ q, k, v, k_beta = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), [q, k, v, k_beta])
47
+ attn = -(k_beta @ k.transpose(-1, -2)).masked_fill(mask, 0)
48
+ for i in range(1, chunk_size):
49
+ attn[..., i, :i] = attn[..., i, :i] + (attn[..., i, :, None].clone() * attn[..., :, :i].clone()).sum(-2)
50
+ attn = attn + torch.eye(chunk_size, dtype=torch.float, device=q.device)
51
+
52
+ u = attn @ v
53
+ w = attn @ k_beta
54
+ S = k.new_zeros(b, h, d_k, d_v)
55
+ o = torch.zeros_like(v)
56
+ mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=1)
57
+ for i in range(0, l // chunk_size):
58
+ q_i, k_i = q[:, :, i], k[:, :, i]
59
+ attn = (q_i @ k_i.transpose(-1, -2)).masked_fill_(mask, 0)
60
+ u_i = u[:, :, i] - w[:, :, i] @ S
61
+ o_inter = q_i @ S
62
+ o[:, :, i] = o_inter + attn @ u_i
63
+ S = S + k_i.transpose(-1, -2) @ u_i
64
+
65
+ return rearrange(o, 'b h n c d -> b h (n c) d'), S
66
+
67
+
68
+ def delta_rule_parallel(q, k, v, beta, BM=128, BN=32):
69
+ b, h, l, d_k = q.shape
70
+ # d_v = v.shape[-1]
71
+ q = q * (d_k ** -0.5)
72
+ v = v * beta[..., None]
73
+ k_beta = k * beta[..., None]
74
+ # compute (I - tri(diag(beta) KK^T))^{-1}
75
+ q, k, v, k_beta = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=BN), [q, k, v, k_beta])
76
+ mask = torch.triu(torch.ones(BN, BN, dtype=torch.bool, device=q.device), diagonal=0)
77
+ T = -(k_beta @ k.transpose(-1, -2)).masked_fill(mask, 0)
78
+ for i in range(1, BN):
79
+ T[..., i, :i] = T[..., i, :i].clone() + (T[..., i, :, None].clone() * T[..., :, :i].clone()).sum(-2)
80
+ T = T + torch.eye(BN, dtype=torch.float, device=q.device)
81
+
82
+ mask2 = torch.triu(torch.ones(BN, BN, dtype=torch.bool, device=q.device), diagonal=1)
83
+ A_local = (q @ k.transpose(-1, -2)).masked_fill(mask2, 0) @ T
84
+ o_intra = A_local @ v
85
+
86
+ # apply cumprod transition matrices on k to the last position within the chunk
87
+ k = k - ((k @ k.transpose(-1, -2)).masked_fill(mask, 0) @ T).transpose(-1, -2) @ k_beta
88
+ # apply cumprod transition matrices on q to the first position within the chunk
89
+ q = q - A_local @ k_beta
90
+ o_intra = A_local @ v
91
+
92
+ A = torch.zeros(b, h, l, l, device=q.device)
93
+
94
+ q, k, v, k_beta, o_intra = map(lambda x: rearrange(x, 'b h n c d -> b h (n c) d'), [q, k, v, k_beta, o_intra])
95
+ o = torch.empty_like(v)
96
+ for i in range(0, l, BM):
97
+ q_i = q[:, :, i:i+BM]
98
+ o_i = o_intra[:, :, i:i+BM]
99
+ # intra block
100
+ for j in range(i + BM - 2 * BN, i-BN, -BN):
101
+ k_j = k[:, :, j:j+BN]
102
+ A_ij = q_i @ k_j.transpose(-1, -2)
103
+ mask = torch.arange(i, i+BM) >= (j + BN)
104
+ A_ij = A_ij.masked_fill_(~mask[:, None].to(A_ij.device), 0)
105
+ A[:, :, i:i+BM, j:j+BN] = A_ij
106
+ q_i = q_i - A_ij @ k_beta[:, :, j:j+BN]
107
+ o_i += A_ij @ v[:, :, j:j+BN]
108
+ # inter block
109
+ for j in range(i - BN, -BN, -BN):
110
+ k_j = k[:, :, j:j+BN]
111
+ A_ij = q_i @ k_j.transpose(-1, -2)
112
+ A[:, :, i:i+BM, j:j+BN] = A_ij
113
+ q_i = q_i - A_ij @ k_beta[:, :, j:j+BN]
114
+ o_i += A_ij @ v[:, :, j:j+BN]
115
+ o[:, :, i:i+BM] = o_i
116
+
117
+ for i in range(0, l//BN):
118
+ A[:, :, i*BN:i*BN+BN, i*BN:i*BN+BN] = A_local[:, :, i]
119
+
120
+ return o, A
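The per-step update in `delta_rule_recurrence` above can be read as an error-correcting write into an associative memory. The following plain-Python sketch shows one such step for a single head and checks the "overwrite" behaviour. `delta_rule_step` and `read` are hypothetical helper names, and the `d_k ** -0.5` query scaling of the reference code is left out for clarity.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def delta_rule_step(S: Matrix, k: Vector, v: Vector, beta: float) -> Matrix:
    """One delta-rule update on a [d_k, d_v] state: S <- S + k (beta * (v - S^T k))^T."""
    d_k, d_v = len(S), len(S[0])
    # Value currently retrieved for key k, i.e. S^T k.
    v_old = [sum(S[i][j] * k[i] for i in range(d_k)) for j in range(d_v)]
    # Correction toward the new value, scaled by the writing strength beta.
    delta = [beta * (v[j] - v_old[j]) for j in range(d_v)]
    return [[S[i][j] + k[i] * delta[j] for j in range(d_v)] for i in range(d_k)]


def read(S: Matrix, q: Vector) -> Vector:
    """Retrieve q^T S from the state."""
    return [sum(q[i] * S[i][j] for i in range(len(S))) for j in range(len(S[0]))]
```

With a unit-norm key and `beta = 1`, writing a new value fully replaces the old one stored under that key; with `beta = 0.5` the stored value moves halfway toward the new one.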
fla/ops/delta_rule/wy_fast.py ADDED
@@ -0,0 +1,340 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.common.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
11
+ from fla.ops.utils.solve_tril import solve_tril
12
+ from fla.utils import check_shared_mem, is_nvidia_hopper
13
+
14
+ NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]
15
+
16
+
17
+ @triton.heuristics({
18
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
19
+ })
20
+ @triton.autotune(
21
+ configs=[
22
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
23
+ for num_warps in [2, 4, 8]
24
+ for num_stages in [2, 3, 4]
25
+ ],
26
+ key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'HEAD_FIRST', 'USE_OFFSETS'],
27
+ )
28
+ @triton.jit(do_not_specialize=['T'])
29
+ def fwd_recompute_w_u_kernel(
30
+ k,
31
+ v,
32
+ beta,
33
+ w,
34
+ u,
35
+ A,
36
+ offsets,
37
+ indices,
38
+ T,
39
+ H: tl.constexpr,
40
+ K: tl.constexpr,
41
+ V: tl.constexpr,
42
+ BT: tl.constexpr,
43
+ BK: tl.constexpr,
44
+ BV: tl.constexpr,
45
+ HEAD_FIRST: tl.constexpr,
46
+ USE_OFFSETS: tl.constexpr
47
+ ):
48
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
49
+ i_b, i_h = i_bh // H, i_bh % H
50
+ if USE_OFFSETS:
51
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
52
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
53
+ T = eos - bos
54
+ else:
55
+ bos, eos = i_b * T, i_b * T + T
56
+
57
+ if HEAD_FIRST:
58
+ p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
59
+ p_A = tl.make_block_ptr(A + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
60
+ else:
61
+ p_beta = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
62
+ p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
63
+ b_beta = tl.load(p_beta, boundary_check=(0,))
64
+ b_A = tl.load(p_A, boundary_check=(0, 1))
65
+
66
+ for i_v in range(tl.cdiv(V, BV)):
67
+ if HEAD_FIRST:
68
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
69
+ p_u = tl.make_block_ptr(u + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
70
+ else:
71
+ p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
72
+ p_u = tl.make_block_ptr(u + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
73
+ b_v = tl.load(p_v, boundary_check=(0, 1))
74
+ b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)
75
+ b_u = tl.dot(b_A.to(b_vb.dtype), b_vb, allow_tf32=False)
76
+ tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1))
77
+
78
+ for i_k in range(tl.cdiv(K, BK)):
79
+ if HEAD_FIRST:
80
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
81
+ p_w = tl.make_block_ptr(w + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
82
+ else:
83
+ p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
84
+ p_w = tl.make_block_ptr(w + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
85
+ b_k = tl.load(p_k, boundary_check=(0, 1))
86
+ b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
87
+ b_w = tl.dot(b_A.to(b_kb.dtype), b_kb, allow_tf32=False)
88
+ tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
89
+
90
+
91
+ @triton.heuristics({
92
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
93
+ })
94
+ @triton.autotune(
95
+ configs=[
96
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
97
+ for num_warps in NUM_WARPS
98
+ for num_stages in [2, 3, 4]
99
+ ],
100
+ key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'HEAD_FIRST', 'USE_OFFSETS'],
101
+ )
102
+ @triton.jit(do_not_specialize=['T'])
103
+ def bwd_prepare_wy_repr_kernel(
104
+ k,
105
+ v,
106
+ beta,
107
+ A,
108
+ dw,
109
+ du,
110
+ dk,
111
+ dv,
112
+ dbeta,
113
+ offsets,
114
+ indices,
115
+ T,
116
+ H: tl.constexpr,
117
+ K: tl.constexpr,
118
+ V: tl.constexpr,
119
+ BT: tl.constexpr,
120
+ BK: tl.constexpr,
121
+ BV: tl.constexpr,
122
+ HEAD_FIRST: tl.constexpr,
123
+ USE_OFFSETS: tl.constexpr
124
+ ):
125
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
126
+ i_b, i_h = i_bh // H, i_bh % H
127
+ if USE_OFFSETS:
128
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
129
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
130
+ T = eos - bos
131
+ else:
132
+ bos, eos = i_b * T, i_b * T + T
133
+
134
+ if HEAD_FIRST:
135
+ p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
136
+ p_A = tl.make_block_ptr(A + i_bh * T*BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1))
137
+ else:
138
+ p_beta = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
139
+ p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1))
140
+
141
+ b_beta = tl.load(p_beta, boundary_check=(0,))
142
+ b_A = tl.load(p_A, boundary_check=(0, 1))
143
+
144
+ b_dbeta = tl.zeros([BT], dtype=tl.float32)
145
+ b_dA = tl.zeros([BT, BT], dtype=tl.float32)
146
+ for i_v in range(tl.cdiv(V, BV)):
147
+ if HEAD_FIRST:
148
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
149
+ p_dv = tl.make_block_ptr(dv + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
150
+ p_du = tl.make_block_ptr(du + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
151
+ else:
152
+ p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
153
+ p_dv = tl.make_block_ptr(dv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
154
+ p_du = tl.make_block_ptr(du + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
155
+
156
+ b_v = tl.load(p_v, boundary_check=(0, 1))
157
+ b_v_beta = (b_v * b_beta[:, None]).to(b_v.dtype)
158
+ b_du = tl.load(p_du, boundary_check=(0, 1))
159
+ b_dA += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False)
160
+ b_dv_beta = tl.dot(b_A, b_du, allow_tf32=False)
161
+ b_dv = b_dv_beta * b_beta[:, None]
162
+ b_dbeta += tl.sum(b_dv_beta * b_v, 1)
163
+
164
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
165
+
166
+ for i_k in range(tl.cdiv(K, BK)):
167
+ if HEAD_FIRST:
168
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
169
+ p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
170
+ p_dw = tl.make_block_ptr(dw + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
171
+ else:
172
+ p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
173
+ p_dk = tl.make_block_ptr(dk + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
174
+ p_dw = tl.make_block_ptr(dw + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
175
+ b_k = tl.load(p_k, boundary_check=(0, 1))
176
+ b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)
177
+ b_dw = tl.load(p_dw, boundary_check=(0, 1))
178
+ b_dA += tl.dot(b_dw, tl.trans(b_k_beta), allow_tf32=False)
179
+ b_dk_beta = tl.dot(b_A, b_dw, allow_tf32=False)
180
+ b_dk = b_dk_beta * b_beta[:, None]
181
+ b_dbeta += tl.sum(b_dk_beta * b_k, 1)
182
+
183
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
184
+
185
+ b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA, 0)
186
+ b_dA = tl.dot(b_dA.to(b_A.dtype), b_A)
187
+ b_dA = tl.dot(b_A, b_dA.to(b_A.dtype))
188
+ b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_dA, 0).to(k.dtype.element_ty)
189
+
190
+ for i_k in range(tl.cdiv(K, BK)):
191
+ if HEAD_FIRST:
192
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
193
+ p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
194
+ else:
195
+ p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
196
+ p_dk = tl.make_block_ptr(dk + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
197
+ b_k = tl.load(p_k, boundary_check=(0, 1))
198
+ b_dk = tl.load(p_dk, boundary_check=(0, 1))
199
+ b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)
200
+
201
+ b_dk_beta = tl.dot(b_dA, b_k, allow_tf32=False)
202
+ b_dbeta += tl.sum(b_dk_beta * b_k, 1)
203
+ b_dk += tl.dot(tl.trans(b_dA), b_k_beta, allow_tf32=False)
204
+ b_dk += b_dk_beta * b_beta[:, None]
205
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
206
+
207
+ if HEAD_FIRST:
208
+ p_dbeta = tl.make_block_ptr(dbeta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
209
+ else:
210
+ p_dbeta = tl.make_block_ptr(dbeta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
211
+ tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0,))
212
+
213
+
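The step in the backward kernel that masks `b_dA`, multiplies it on both sides by `b_A` and negates the result appears to be the usual gradient through a matrix inverse: if `A = M^{-1}` with `M` unit lower triangular, then `dM = -A^T dA A^T` restricted to the strictly lower part. This reading assumes that `b_A` holds the transposed block (it is loaded with swapped strides). The sketch below checks the identity against finite differences in plain Python. All function names are hypothetical, and the matrices are made-up test values.

```python
def matmul(a, b):
    return [[sum(a[i][p] * b[p][j] for p in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def transpose(a):
    return [list(col) for col in zip(*a)]


def strict_lower(a):
    """Keep entries with row > column, zero the rest."""
    return [[x if i > j else 0.0 for j, x in enumerate(row)]
            for i, row in enumerate(a)]


def inv_unit_lower(M):
    """Invert a unit lower-triangular matrix by forward substitution."""
    n = len(M)
    inv = [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]
    for i in range(n):
        for j in range(i):
            inv[i][j] = -sum(M[i][p] * inv[p][j] for p in range(j, i))
    return inv


def grad_through_inverse(A, dA):
    """dM = -A^T dA A^T, masked to the strictly lower triangle."""
    At = transpose(A)
    inner = matmul(At, matmul(strict_lower(dA), At))
    return strict_lower([[-x for x in row] for row in inner])


def loss(M, G):
    inv = inv_unit_lower(M)
    return sum(G[i][j] * inv[i][j] for i in range(len(M)) for j in range(len(M)))


def numeric_grad(M, G, i, j, eps=1e-4):
    """Central finite difference of the loss with respect to M[i][j]."""
    plus = [row[:] for row in M]
    minus = [row[:] for row in M]
    plus[i][j] += eps
    minus[i][j] -= eps
    return (loss(plus, G) - loss(minus, G)) / (2 * eps)


M = [[1.0, 0.0, 0.0], [0.5, 1.0, 0.0], [-0.25, 0.75, 1.0]]
G = [[0.3, -1.0, 2.0], [1.5, 0.7, -0.4], [-2.0, 0.9, 1.1]]  # upstream dL/dA
analytic = grad_through_inverse(inv_unit_lower(M), G)
```

Masking `dA` before the products does not change the strictly lower part of the result, because the diagonal and upper entries of a unit lower-triangular inverse are constants.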
214
+ def fwd_prepare_wy_repr(
215
+ k: torch.Tensor,
216
+ v: torch.Tensor,
217
+ beta: torch.Tensor,
218
+ offsets: Optional[torch.LongTensor],
219
+ indices: Optional[torch.LongTensor],
220
+ head_first: bool = False,
221
+ chunk_size: int = 64
222
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
223
+ A = chunk_scaled_dot_kkt_fwd(
224
+ k=k,
225
+ beta=beta,
226
+ cu_seqlens=offsets,
227
+ head_first=head_first,
228
+ chunk_size=chunk_size,
229
+ output_dtype=torch.float32
230
+ )
231
+ A = solve_tril(
232
+ A=A,
233
+ cu_seqlens=offsets,
234
+ head_first=head_first,
235
+ output_dtype=k.dtype
236
+ )
237
+
238
+ w, u = fwd_recompute_w_u(
239
+ k=k,
240
+ v=v,
241
+ beta=beta,
242
+ A=A,
243
+ offsets=offsets,
244
+ indices=indices,
245
+ head_first=head_first,
246
+ chunk_size=chunk_size
247
+ )
248
+ return w, u, A
249
+
250
+
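Read together with `fwd_recompute_w_u_kernel`, the function above can be summarised per chunk as follows. The last two lines follow directly from the kernel. The first line is an assumption about what `chunk_scaled_dot_kkt_fwd` and `solve_tril` compute, since their bodies are not shown in this diff.

```latex
\begin{aligned}
A &= \bigl(I + \operatorname{tril}(\operatorname{diag}(\beta)\, K K^{\top},\, -1)\bigr)^{-1} \\
W &= A \,\operatorname{diag}(\beta)\, K \\
U &= A \,\operatorname{diag}(\beta)\, V
\end{aligned}
```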
251
+ def fwd_recompute_w_u(
252
+ k: torch.Tensor,
253
+ v: torch.Tensor,
254
+ beta: torch.Tensor,
255
+ A: torch.Tensor,
256
+ offsets: Optional[torch.LongTensor],
257
+ indices: Optional[torch.LongTensor],
258
+ head_first: bool,
259
+ chunk_size: int
260
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
261
+ if head_first:
262
+ B, H, T, K, V = *k.shape, v.shape[-1]
263
+ else:
264
+ B, T, H, K, V = *k.shape, v.shape[-1]
265
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
266
+ CONST_TILING = 64 if check_shared_mem() else 32
267
+ BK = min(triton.next_power_of_2(K), CONST_TILING)
268
+ BV = min(triton.next_power_of_2(V), CONST_TILING)
269
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
270
+
271
+ u = torch.empty_like(v)
272
+ w = torch.empty_like(k)
273
+ fwd_recompute_w_u_kernel[(NT, B*H)](
274
+ k,
275
+ v,
276
+ beta,
277
+ w,
278
+ u,
279
+ A,
280
+ offsets=offsets,
281
+ indices=indices,
282
+ T=T,
283
+ H=H,
284
+ K=K,
285
+ V=V,
286
+ BT=BT,
287
+ BK=BK,
288
+ BV=BV,
289
+ HEAD_FIRST=head_first
290
+ )
291
+ return w, u
292
+
293
+
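The tile-size selection in `fwd_recompute_w_u` can be tried out without a GPU. The sketch below mirrors that arithmetic in plain Python for the fixed-length case (`offsets is None`). `pick_tiles` is a hypothetical helper, `next_power_of_2` and `cdiv` are local stand-ins for the Triton utilities, and `large_shared_mem` stands in for the result of `check_shared_mem()`, whose logic is not shown here.

```python
def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n, for n >= 1."""
    return 1 << (n - 1).bit_length()


def cdiv(a: int, b: int) -> int:
    """Ceiling division for positive integers."""
    return -(-a // b)


def pick_tiles(T, K, V, chunk_size=64, large_shared_mem=True):
    """Return (BT, BK, BV, NT) the way the host wrapper derives them."""
    BT = min(chunk_size, max(next_power_of_2(T), 16))
    const_tiling = 64 if large_shared_mem else 32
    BK = min(next_power_of_2(K), const_tiling)
    BV = min(next_power_of_2(V), const_tiling)
    NT = cdiv(T, BT)  # number of time chunks, i.e. grid size along axis 0
    return BT, BK, BV, NT


short = pick_tiles(T=10, K=128, V=100)
long = pick_tiles(T=1000, K=128, V=100)
small_smem = pick_tiles(T=1000, K=16, V=24, large_shared_mem=False)
```

Note that very short sequences still get a chunk of at least 16 rows, and the head dimensions are capped at 64 or 32 columns per tile, with the kernel looping over the remaining tiles.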
294
+ def bwd_prepare_wy_repr(
295
+ k: torch.Tensor,
296
+ v: torch.Tensor,
297
+ beta: torch.Tensor,
298
+ A: torch.Tensor,
299
+ dw: torch.Tensor,
300
+ du: torch.Tensor,
301
+ offsets: Optional[torch.LongTensor],
302
+ indices: Optional[torch.LongTensor],
303
+ head_first: bool,
304
+ chunk_size: int
305
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
306
+ if head_first:
307
+ B, H, T, K, V = *k.shape, v.shape[-1]
308
+ else:
309
+ B, T, H, K, V = *k.shape, v.shape[-1]
310
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
311
+ CONST_TILING = 64 if check_shared_mem() else 32
312
+ BK = min(triton.next_power_of_2(K), CONST_TILING)
313
+ BV = min(triton.next_power_of_2(V), CONST_TILING)
314
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
315
+
316
+ dk = torch.empty_like(k)
317
+ dv = torch.empty_like(v)
318
+ dbeta = torch.empty_like(beta)
319
+ bwd_prepare_wy_repr_kernel[(NT, B * H)](
320
+ k,
321
+ v,
322
+ beta,
323
+ A,
324
+ dw,
325
+ du,
326
+ dk,
327
+ dv,
328
+ dbeta,
329
+ offsets=offsets,
330
+ indices=indices,
331
+ T=T,
332
+ H=H,
333
+ K=K,
334
+ V=V,
335
+ BT=BT,
336
+ BK=BK,
337
+ BV=BV,
338
+ HEAD_FIRST=head_first
339
+ )
340
+ return dk, dv, dbeta
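
In the variable-length path (`USE_OFFSETS`), each program reads a pair from `indices` and then looks up the sequence bounds in `offsets`. How `indices` is built is not part of this diff. The sketch below assumes it is a flat list of `(sequence index, chunk index)` pairs, which is consistent with the kernel loading `indices + i_t * 2` and `indices + i_t * 2 + 1`. Both `build_chunk_indices` and `chunk_rows` are hypothetical helpers.

```python
def build_chunk_indices(offsets, BT):
    """One (sequence index, chunk index) pair per BT-sized chunk."""
    indices = []
    for n in range(len(offsets) - 1):
        seq_len = offsets[n + 1] - offsets[n]
        for c in range(-(-seq_len // BT)):  # ceil(seq_len / BT) chunks
            indices.append((n, c))
    return indices


def chunk_rows(offsets, indices, program, BT):
    """Global token range [start, stop) handled by one program."""
    i_n, i_t = indices[program]
    bos, eos = offsets[i_n], offsets[i_n + 1]
    start = bos + i_t * BT
    return start, min(start + BT, eos)


# Two packed sequences of lengths 5 and 65, chunked with BT = 64.
offsets = [0, 5, 70]
indices = build_chunk_indices(offsets, BT=64)
NT = len(indices)  # matches `NT = len(indices)` in the host wrappers
```

With this layout the launch grid stays `(NT, B*H)`, and each program recovers its own sequence length as `eos - bos`.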