Add files using upload-large-folder tool
- README.md +188 -0
- fla3/ops/path_attn/__pycache__/parallel_path_bwd_inter_dkv.cpython-310.pyc +0 -0
- fla3/ops/path_attn/__pycache__/prepare_k_cache.cpython-310.pyc +0 -0
- fla3/ops/retention/__pycache__/parallel.cpython-312.pyc +0 -0
- fla3/ops/rwkv7/__pycache__/chunk.cpython-312.pyc +0 -0
- fla3/ops/rwkv7/__pycache__/fused_addcmul.cpython-310.pyc +0 -0
- fla3/ops/rwkv7/__pycache__/fused_k_update.cpython-310.pyc +0 -0
- fla3/ops/rwkv7/fused_recurrent.py +328 -0
- fla3/ops/simple_gla/__pycache__/chunk.cpython-312.pyc +0 -0
- fla3/ops/simple_gla/__pycache__/fused_recurrent.cpython-310.pyc +0 -0
- fla3/ops/simple_gla/__pycache__/parallel.cpython-310.pyc +0 -0
- fla3/ops/simple_gla/__pycache__/parallel.cpython-312.pyc +0 -0
- fla3/ops/simple_gla/fused_recurrent.py +108 -0
- fla3/ops/simple_gla/naive.py +54 -0
- fla3/ops/ttt/naive.py +126 -0
- fla3/ops/utils/__pycache__/__init__.cpython-310.pyc +0 -0
- fla3/ops/utils/__pycache__/__init__.cpython-312.pyc +0 -0
- fla3/ops/utils/__pycache__/asm.cpython-310.pyc +0 -0
- fla3/ops/utils/__pycache__/asm.cpython-312.pyc +0 -0
- fla3/ops/utils/__pycache__/cumsum.cpython-310.pyc +0 -0
- fla3/ops/utils/__pycache__/cumsum.cpython-312.pyc +0 -0
- fla3/ops/utils/__pycache__/index.cpython-310.pyc +0 -0
- fla3/ops/utils/__pycache__/index.cpython-312.pyc +0 -0
- fla3/ops/utils/__pycache__/logcumsumexp.cpython-310.pyc +0 -0
- fla3/ops/utils/__pycache__/logsumexp.cpython-310.pyc +0 -0
- fla3/ops/utils/__pycache__/logsumexp.cpython-312.pyc +0 -0
- fla3/ops/utils/__pycache__/matmul.cpython-310.pyc +0 -0
- fla3/ops/utils/__pycache__/op.cpython-312.pyc +0 -0
- fla3/ops/utils/__pycache__/pooling.cpython-310.pyc +0 -0
- fla3/ops/utils/cumsum.py +414 -0
- fla3/ops/utils/index.py +83 -0
- fla3/ops/utils/logcumsumexp.py +52 -0
- fla3/ops/utils/matmul.py +245 -0
- fla3/ops/utils/op.py +39 -0
- fla3/ops/utils/pack.py +208 -0
- fla3/ops/utils/pooling.py +207 -0
- fla3/ops/utils/softmax.py +111 -0
- fla3/ops/utils/solve_tril.py +276 -0
- flame/__init__.py +0 -0
- flame/__pycache__/__init__.cpython-310.pyc +0 -0
- flame/__pycache__/__init__.cpython-312.pyc +0 -0
- flame/__pycache__/data.cpython-310.pyc +0 -0
- flame/__pycache__/data.cpython-312.pyc +0 -0
- flame/__pycache__/logging.cpython-310.pyc +0 -0
- flame/__pycache__/logging.cpython-312.pyc +0 -0
- flame/__pycache__/parser.cpython-310.pyc +0 -0
- flame/__pycache__/parser.cpython-312.pyc +0 -0
- flame/data.py +246 -0
- flame/logging.py +118 -0
- flame/parser.py +94 -0
README.md
ADDED
@@ -0,0 +1,188 @@
<div align="center">

# 🔥 Flame: Flash Linear Attention Made Easy

</div>

> [!IMPORTANT]
> The `flame` project has been migrated to a new project built on torchtitan.
> Please visit the [new repository](https://github.com/fla-org/flame) for details and updates.
>
> The code here is now **archived as legacy**, and no future updates will be synchronized here.

A minimal framework for training FLA models, whether from scratch or through finetuning.

Built on the robust infrastructure of 🤗, `flame` enables you to train large language models with just a few lines of code:
we use `datasets` for data processing, `transformers` for model definitions, and `accelerate`[^1] for seamless distributed training.

In this README, we will guide you through the process of using `flame` to train GLA models.

## Setup

To get started, you'll need to install the required packages.
Both `fla` and `flame` have minimal dependencies.
Clone the `fla` repository and install the necessary packages as follows:

```bash
git clone https://github.com/sustcsonglin/flash-linear-attention.git
cd flash-linear-attention
pip install .
pip install accelerate
```

> [!CAUTION]
> The 🤗 `tokenizers` have some [memory leak issues](https://github.com/huggingface/tokenizers/issues/1539) when processing very long documents.
> To address this, please ensure you install `tokenizers>=0.20.4`.

## Preprocessing

Before training, you need to download and pre-tokenize your dataset.
We provide a straightforward script for this.
For instance, to tokenize a 10B-token sample of the `fineweb-edu` dataset, run:

```bash
python preprocess.py \
  --dataset HuggingFaceFW/fineweb-edu \
  --name sample-10BT \
  --split train \
  --context_length 2048
```

Similarly, to tokenize a locally downloaded 100BT sample:

```bash
python preprocess.py \
  --dataset /mnt/jfzn/msj/fineweb100B_hf/datasets--HuggingFaceFW--fineweb-edu/sample/100BT \
  --name sample-100BT \
  --split train \
  --context_length 2048
```

This will cache the processed dataset at `data/HuggingFaceFW/fineweb-edu/sample-10BT/train`.
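
For reference, the preprocessing step amounts to tokenizing documents and packing the token ids into fixed-length rows. Below is a minimal sketch of that idea only; it is not the actual `preprocess.py`, and the tokenizer checkpoint is just a placeholder:

```python
# Minimal pre-tokenization sketch (NOT the actual preprocess.py).
# Assumes `datasets` and `transformers` are installed; the tokenizer is a placeholder choice.
from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('mistralai/Mistral-7B-v0.1')  # placeholder tokenizer
dataset = load_dataset('HuggingFaceFW/fineweb-edu', name='sample-10BT', split='train')

context_length = 2048

def tokenize(examples):
    # concatenate token ids across documents and split them into fixed-length chunks
    ids = []
    for text in examples['text']:
        ids.extend(tokenizer(text)['input_ids'])
    n = len(ids) // context_length * context_length
    return {'input_ids': [ids[i:i + context_length] for i in range(0, n, context_length)]}

tokenized = dataset.map(tokenize, batched=True, remove_columns=dataset.column_names)
tokenized.save_to_disk('data/HuggingFaceFW/fineweb-edu/sample-10BT/train')
```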

GLA uses a subset of SlimPajama for pretraining [in the paper](https://proceedings.mlr.press/v235/yang24ab.html).
Given the size of the dataset, the fastest way to download it is using `git lfs` (refer to [this issue](https://huggingface.co/datasets/cerebras/SlimPajama-627B/discussions/2)):

```bash
git lfs install
git clone https://huggingface.co/datasets/cerebras/SlimPajama-627B --depth 1
python preprocess.py \
  --dataset SlimPajama-627B \
  --split train \
  --context_length 2048
```

## Training from scratch

To train a 340M model from scratch, execute the following command:

```bash
bash train.sh \
  type=gla \
  lr=3e-4 \
  scheduler=cosine_with_min_lr \
  batch=32 \
  update=1 \
  warmup=1024 \
  steps=20480 \
  context=2048 \
  gpus=8 \
  nodes=1 \
  path=exp/gla-340M-10B \
  project=fla \
  model=configs/gla_340M.json \
  data=HuggingFaceFW/fineweb-edu \
  name=sample-10BT \
  cache=data/HuggingFaceFW/fineweb-edu/sample-10BT/train
```

Key parameters:

| Argument  | Corresponding option          | Default              |
| :-------- | :---------------------------- | :------------------- |
| lr        | `learning_rate`               | `3e-4`               |
| scheduler | `lr_scheduler_type`           | `cosine_with_min_lr` |
| batch     | `batch_size`                  | `32`                 |
| update    | `gradient_accumulation_steps` | `1`                  |
| context   | `context_length`              | `2048`               |
| gpus      | `num_gpus_per_node`           | `8`                  |
| nodes     | `num_nodes`                   | `1`                  |
| warmup    | `warmup_steps`                | `1024`               |
| steps     | `max_steps`                   | `20480`              |

The learning rate is set to `3e-4` by default, paired with a cosine scheduler.
Other scheduler types like WSD (`warmup_stable_decay`)[^2] are also supported.

The total number of tokens processed per batch, referred to as `global_batch_size`, is calculated as
`batch_size × gradient_accumulation_steps × context_length × num_gpus_per_node × num_nodes`.
For instance, in the 340M model example, the `global_batch_size` comes to $32 \times 1 \times 2048 \times 8 \times 1 = 524,288$ (0.5M tokens).

The `warmup_steps` parameter indicates the number of steps for the learning rate warmup phase, while `max_steps` represents the maximum number of training steps.
Each step processes `global_batch_size` tokens.
Consequently, `warmup=1024` and `steps=20480` correspond to processing roughly 0.5B and 10B tokens, respectively.
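
As a quick sanity check, the token accounting above can be reproduced in a few lines of plain Python (this is just arithmetic, not part of `flame`):

```python
# Token accounting for the 340M example above (not part of flame, just arithmetic).
batch_size = 32
gradient_accumulation_steps = 1
context_length = 2048
num_gpus_per_node = 8
num_nodes = 1

global_batch_size = (batch_size * gradient_accumulation_steps
                     * context_length * num_gpus_per_node * num_nodes)
print(global_batch_size)                 # 524288 tokens (~0.5M) per step

warmup_steps, max_steps = 1024, 20480
print(warmup_steps * global_batch_size)  # ~0.54B tokens during warmup
print(max_steps * global_batch_size)     # ~10.7B tokens in total
```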

:warning: Monitor the values of `global_batch_size`, `warmup_steps`, and `max_steps` carefully when modifying any of these hyperparameters!

`flame` also supports resuming interrupted training by specifying the checkpoint path.
Simply use the following command:

```bash
bash train.sh \
  type=gla \
  lr=3e-4 \
  steps=20480 \
  batch=32 \
  update=1 \
  warmup=1024 \
  context=2048 \
  gpus=8 \
  nodes=1 \
  path=exp/gla-340M-10B \
  project=fla \
  model=configs/gla_340M.json \
  data=HuggingFaceFW/fineweb-edu \
  name=sample-10BT \
  cache=data/HuggingFaceFW/fineweb-edu/sample-10BT/train \
  checkpoint=exp/gla-340M-10B/checkpoint-8192
```

You can also use `wandb` to monitor your training process effectively.



## Continual Pretraining

`flame` supports continual training from a pretrained checkpoint.
Below, we provide an example of how to finetune Mistral-7B to GLA.
You can follow similar steps to reproduce the results in the [GSA paper](https://arxiv.org/abs/2409.07146):

1. Initialize a brand-new GLA-7B model from the config and copy the matched pretrained weights from Mistral-7B:

   ```bash
   cd ../utils
   python convert_from_llama.py \
     --model mistralai/Mistral-7B-v0.1 \
     --config ../training/configs/gla_7B.json \
     --output ../training/converted/gla-7B
   cd -
   ```

2. Directly launch training from the converted checkpoint:

   ```bash
   bash train.sh \
     type=gla \
     lr=3e-5 \
     steps=10240 \
     batch=4 \
     update=8 \
     warmup=512 \
     context=2048 \
     path=exp/gla-7B-20B \
     project=fla \
     model=converted/gla-7B \
     data=SlimPajama-627B \
     cache=data/SlimPajama-627B/train
   ```

Please be aware that finetuning on a single node may not be the most efficient approach.
If available, consider leveraging multi-node GPUs for optimal performance.
You can find guidance on how to launch a multi-node job in the [accelerate tutorial](https://github.com/huggingface/accelerate/blob/main/examples/slurm/submit_multinode.sh).

[^1]: The `accelerate` library supports various distributed frameworks, like `deepspeed` and `megatron` for large-scale training. We use `deepspeed` in our case.
[^2]: https://arxiv.org/abs/2404.06395
fla3/ops/path_attn/__pycache__/parallel_path_bwd_inter_dkv.cpython-310.pyc
ADDED
Binary file (5.47 kB)

fla3/ops/path_attn/__pycache__/prepare_k_cache.cpython-310.pyc
ADDED
Binary file (2.26 kB)

fla3/ops/retention/__pycache__/parallel.cpython-312.pyc
ADDED
Binary file (3.75 kB)

fla3/ops/rwkv7/__pycache__/chunk.cpython-312.pyc
ADDED
Binary file (2.56 kB)

fla3/ops/rwkv7/__pycache__/fused_addcmul.cpython-310.pyc
ADDED
Binary file (6.37 kB)

fla3/ops/rwkv7/__pycache__/fused_k_update.cpython-310.pyc
ADDED
Binary file (3.93 kB)
fla3/ops/rwkv7/fused_recurrent.py
ADDED
@@ -0,0 +1,328 @@

# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

import warnings
from typing import Optional, Tuple

import torch
import triton
import triton.language as tl
from einops import rearrange

from fla.ops.generalized_delta_rule import fused_recurrent_dplr_delta_rule
from fla.ops.utils.op import exp
from fla.utils import input_guard, use_cuda_graph


@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BV': BV}, num_warps=num_warps, num_stages=num_stages)
        for BV in [16, 32, 64]
        for num_warps in [2, 4, 8, 16, 32]
        for num_stages in [2, 3, 4]
    ],
    key=['BK'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def fused_recurrent_rwkv7_fwd_kernel(
    r,
    w,
    k,
    v,
    kk,
    a,
    o,
    h0,
    ht,
    cu_seqlens,
    scale,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    REVERSE: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    IS_DECODE: tl.constexpr,
):
    i_v, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64)
    i_n, i_h = i_nh // H, i_nh % H

    if IS_VARLEN:
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T

    o_k = tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)
    p_r = r + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
    p_w = w + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
    p_k = k + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
    p_v = v + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v
    p_a = a + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
    p_kk = kk + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k

    p_o = o + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v

    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_k[None, :] & mask_v[:, None]
    b_h = tl.zeros([BV, BK], dtype=tl.float32)

    if USE_INITIAL_STATE:
        p_h0 = h0 + i_nh * K*V + o_k[None, :] * V + o_v[:, None]
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    if IS_DECODE:
        b_r = tl.load(p_r, mask=mask_k, other=0).to(tl.float32) * scale
        b_w = tl.load(p_w, mask=mask_k, other=0).to(tl.float32)
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        b_a = tl.load(p_a, mask=mask_k, other=0).to(tl.float32)
        b_kk = tl.load(p_kk, mask=mask_k, other=0).to(tl.float32)
        b_act_a = -b_kk
        b_b = b_kk * b_a

        tmp = tl.sum(b_h * b_act_a[None, :], axis=1)
        b_h = exp(b_w)[None, :] * b_h + (tmp[:, None] * b_b[None, :] + b_k[None, :] * b_v[:, None])
        b_o = tl.sum(b_h * b_r[None, :], axis=1)

        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
    else:
        for _ in range(0, T):
            b_r = tl.load(p_r, mask=mask_k, other=0).to(tl.float32) * scale
            b_w = tl.load(p_w, mask=mask_k, other=0).to(tl.float32)
            b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
            b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
            b_a = tl.load(p_a, mask=mask_k, other=0).to(tl.float32)
            b_kk = tl.load(p_kk, mask=mask_k, other=0).to(tl.float32)
            b_act_a = -b_kk
            b_b = b_kk * b_a

            tmp = tl.sum(b_h * b_act_a[None, :], axis=1)
            b_h = exp(b_w)[None, :] * b_h + (tmp[:, None] * b_b[None, :] + b_k[None, :] * b_v[:, None])
            b_o = tl.sum(b_h * b_r[None, :], axis=1)

            tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
            p_r += (-1 if REVERSE else 1) * H*K
            p_w += (-1 if REVERSE else 1) * H*K
            p_k += (-1 if REVERSE else 1) * H*K
            p_v += (-1 if REVERSE else 1) * H*V
            p_a += (-1 if REVERSE else 1) * H*K
            p_kk += (-1 if REVERSE else 1) * H*K
            p_o += (-1 if REVERSE else 1) * H*V

    if STORE_FINAL_STATE:
        p_ht = ht + i_nh * K*V + o_k[None, :] * V + o_v[:, None]
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)


@input_guard
def fused_recurrent_rwkv7_fwd(
    r: torch.Tensor,
    w: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    kk: torch.Tensor,
    a: torch.Tensor,
    scale: Optional[float] = 1.0,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    reverse: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
):
    B, T, H, K, V = *k.shape, v.shape[-1]
    N = B if cu_seqlens is None else len(cu_seqlens) - 1
    BK = triton.next_power_of_2(K)
    IS_DECODE = (T == 1)

    h0 = initial_state
    if not output_final_state:
        ht = None
    else:
        ht = r.new_empty(N, H, K, V, dtype=torch.float32)
    o = torch.empty_like(v)

    def grid(meta): return (triton.cdiv(V, meta['BV']), N * H)
    fused_recurrent_rwkv7_fwd_kernel[grid](
        r,
        w,
        k,
        v,
        kk,
        a,
        o,
        h0,
        ht,
        cu_seqlens,
        scale,
        T=T,
        B=B,
        H=H,
        K=K,
        V=V,
        BK=BK,
        REVERSE=reverse,
        IS_DECODE=IS_DECODE
    )
    return o, ht


def fused_recurrent_rwkv7(
    r: torch.Tensor,
    w: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    scale: float = 1.0,
    initial_state: torch.Tensor = None,
    output_final_state: bool = True,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = False,
):
    """
    Args:
        r (torch.Tensor):
            r of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        w (torch.Tensor):
            log decay of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        k (torch.Tensor):
            k of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        v (torch.Tensor):
            v of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        a (torch.Tensor):
            a of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        b (torch.Tensor):
            b of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        scale (float):
            scale of the attention.
        initial_state (torch.Tensor):
            initial state of shape `[B, H, K, V]` if cu_seqlens is None else `[N, H, K, V]` where N = len(cu_seqlens) - 1.
        output_final_state (bool):
            whether to output the final state.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
        head_first (bool):
            whether to use head first. Recommended to be False to avoid extra transposes.
            Default: `False`.
    """
    return fused_recurrent_dplr_delta_rule(
        q=r,
        k=k,
        v=v,
        a=a,
        b=b,
        gk=w,
        scale=scale,
        initial_state=initial_state,
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
        head_first=head_first,
    )


def fused_mul_recurrent_rwkv7(
    r: torch.Tensor,
    w: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    kk: torch.Tensor,
    a: torch.Tensor,
    scale: Optional[float] = 1.0,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    reverse: bool = False,
    cu_seqlens: Optional[torch.Tensor] = None,
    head_first: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    This function computes the recurrence
    `S_t = S_{t-1} @ (diag(exp(w_t)) - kk_t (kk_t * a_t)^T) + v_t k_t^T`
    in a token-by-token (recurrent) manner, matching the fused kernel above.

    Args:
        r (torch.Tensor):
            queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        w (torch.Tensor):
            log-space decay of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        v (torch.Tensor):
            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        kk (torch.Tensor):
            kk of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`;
            combined with `a` to form the rank-1 state update (`-kk` and `kk * a` in the kernel).
        a (torch.Tensor):
            a of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        scale (Optional[float]):
            Scale factor for the attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `1.0`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, H, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
        reverse (Optional[bool]):
            If `True`, process the state passing in reverse order. Default: `False`.
        cu_seqlens (Optional[torch.Tensor]):
            Cumulative sequence lengths of shape `[N + 1]` used for variable-length training,
            consistent with the FlashAttention API.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
            Default: `False`.
    """
    if head_first:
        raise DeprecationWarning(
            "head_first is deprecated and will be removed in a future version. "
            "Please use head_first=False for now instead."
        )
        r, w, k, v, kk, a = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (r, w, k, v, kk, a))
    if not head_first and r.shape[1] < r.shape[2]:
        warnings.warn(
            f"Input tensor shape suggests potential format mismatch: seq_len ({r.shape[1]}) < num_heads ({r.shape[2]}). "
            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
            "when head_first=False was specified. "
            "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
        )
    if cu_seqlens is not None:
        if r.shape[0] != 1:
            raise ValueError(
                f"The batch size is expected to be 1 rather than {r.shape[0]} when using `cu_seqlens`."
                f"Please flatten variable-length inputs before processing."
            )
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
            )
    if scale is None:
        scale = r.shape[-1] ** -0.5
    else:
        assert scale > 0, "scale must be positive"
    o, final_state = fused_recurrent_rwkv7_fwd(
        r,
        w,
        k,
        v,
        kk,
        a,
        scale,
        initial_state,
        output_final_state,
        reverse,
        cu_seqlens,
    )
    if head_first:
        o = rearrange(o, 'b t h ... -> b h t ...')
    return o, final_state
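
A small smoke test for the entry points above is sketched below. It is not part of this commit: the `a=-kk, b=kk * a` mapping simply mirrors what the fused kernel constructs internally, and a CUDA device plus an importable `fla` package are assumed.

```python
# Sketch: exercising fused_recurrent_rwkv7 / fused_mul_recurrent_rwkv7 on random inputs.
# Not part of this commit; shapes follow the [B, T, H, K/V] convention documented above.
import torch
import torch.nn.functional as F
from fla.ops.rwkv7.fused_recurrent import fused_recurrent_rwkv7, fused_mul_recurrent_rwkv7

B, T, H, K, V = 2, 128, 4, 64, 64
device, dtype = 'cuda', torch.bfloat16

r = torch.randn(B, T, H, K, device=device, dtype=dtype)
w = -torch.rand(B, T, H, K, device=device, dtype=dtype)   # log-space decay, kept negative
k = torch.randn(B, T, H, K, device=device, dtype=dtype)
v = torch.randn(B, T, H, V, device=device, dtype=dtype)
kk = F.normalize(torch.randn(B, T, H, K, device=device, dtype=dtype), dim=-1)
a = torch.rand(B, T, H, K, device=device, dtype=dtype)

# DPLR-based path: takes (r, w, k, v, a, b)
o, ht = fused_recurrent_rwkv7(r, w, k, v, a=-kk, b=kk * a, scale=1.0, output_final_state=True)

# Fused-multiply path: takes (r, w, k, v, kk, a) and builds -kk / kk * a inside the kernel
o2, ht2 = fused_mul_recurrent_rwkv7(r, w, k, v, kk, a, scale=1.0, output_final_state=True)
print(o.shape, ht.shape)
```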
fla3/ops/simple_gla/__pycache__/chunk.cpython-312.pyc
ADDED
Binary file (10.6 kB)

fla3/ops/simple_gla/__pycache__/fused_recurrent.cpython-310.pyc
ADDED
Binary file (4.13 kB)

fla3/ops/simple_gla/__pycache__/parallel.cpython-310.pyc
ADDED
Binary file (17 kB)

fla3/ops/simple_gla/__pycache__/parallel.cpython-312.pyc
ADDED
Binary file (35.6 kB)
fla3/ops/simple_gla/fused_recurrent.py
ADDED
@@ -0,0 +1,108 @@

# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from typing import Optional, Tuple

import torch

from fla.ops.common.fused_recurrent import fused_recurrent


def fused_recurrent_simple_gla(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    reverse: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]`.
        v (torch.Tensor):
            values of shape `[B, T, H, V]`.
        g (torch.Tensor):
            Forget gates of shape `[B, T, H]`.
            Compared to GLA, the gating is head-wise instead of elementwise.
        scale (Optional[float]):
            Scale factor for the attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, H, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
        reverse (Optional[bool]):
            If `True`, process the state passing in reverse order. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, H, V]`.
        final_state (torch.Tensor):
            Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.

    Examples::
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from einops import rearrange
        >>> from fla.ops.simple_gla import fused_recurrent_simple_gla
        # inputs with equal lengths
        >>> B, T, H, K, V = 4, 2048, 4, 512, 512
        >>> q = torch.randn(B, T, H, K, device='cuda')
        >>> k = torch.randn(B, T, H, K, device='cuda')
        >>> v = torch.randn(B, T, H, V, device='cuda')
        >>> g = F.logsigmoid(torch.randn(B, T, H, device='cuda'))
        >>> h0 = torch.randn(B, H, K, V, device='cuda')
        >>> o, ht = fused_recurrent_simple_gla(
            q, k, v, g,
            initial_state=h0,
            output_final_state=True
        )
        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
        >>> q, k, v, g = map(lambda x: rearrange(x, 'b t h ... -> 1 (b t) h ...'), (q, k, v, g))
        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
        >>> o_var, ht_var = fused_recurrent_simple_gla(
            q, k, v, g,
            initial_state=h0,
            output_final_state=True,
            cu_seqlens=cu_seqlens
        )
        >>> assert o.allclose(o_var.view(o.shape))
        >>> assert ht.allclose(ht_var)
    """
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
                f"Please flatten variable-length inputs before processing."
            )
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
            )
    if scale is None:
        scale = k.shape[-1] ** -0.5
    o, final_state = fused_recurrent(
        q=q,
        k=k,
        v=v,
        g=g,
        scale=scale,
        initial_state=initial_state,
        output_final_state=output_final_state,
        reverse=reverse,
        cu_seqlens=cu_seqlens
    )
    return o, final_state
fla3/ops/simple_gla/naive.py
ADDED
@@ -0,0 +1,54 @@

# -*- coding: utf-8 -*-

import torch
from einops import rearrange


def torch_simple_gla(q, k, v, g, chunk_size=64, scale=None):
    if scale is None:
        scale = (q.shape[-1] ** -0.5)
    q = rearrange(q, 'b h (n c) d -> b h n c d', c=chunk_size) * scale
    k = rearrange(k, 'b h (n c) d -> b h n c d', c=chunk_size)
    v = rearrange(v, 'b h (n c) d -> b h n c d', c=chunk_size)
    g = rearrange(g, 'b h (n c) -> b h n c', c=chunk_size)
    g = g.cumsum(-1)
    kv = k.transpose(-1, -2) @ (v * (-g + g[:, :, :, -1, None]).exp()[..., None])
    S = torch.zeros_like(kv)

    for i in range(1, g.shape[-2]):
        S[:, :, i] = S[:, :, i-1].clone() * g[:, :, i-1, -1, None, None].exp() + kv[:, :, i-1]

    inter = (q * g[..., None].exp()) @ S
    attn = q @ k.transpose(-1, -2)
    attn = attn * (g[..., None] - g[..., None, :]).exp()
    attn = attn.masked_fill(torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), 0)
    intra = attn @ v
    o = inter + intra
    return rearrange(o, 'b h n c d -> b h (n c) d')


def torch_simple_gla_recurrent(q, k, v, g, scale=None, initial_state=None, output_final_state=True):
    B, H, T, DK = q.shape
    original_dtype = q.dtype
    q, k, v, g = q.float(), k.float(), v.float(), g.float()
    if scale is None:
        scale = DK ** -0.5
    q = q * scale
    _, _, _, DV = v.shape
    if initial_state is None:
        S = torch.zeros(B, H, DK, DV).to(q)
    else:
        S = initial_state
    o = torch.zeros(B, H, T, DV).to(q)
    for i in range(T):
        gate = g[:, :, i].exp()
        key = k[:, :, i]
        value = v[:, :, i]
        kv = key.unsqueeze(-1) * value.unsqueeze(-2)
        S = S.clone() * gate.unsqueeze(-1).unsqueeze(-1) + kv
        q_i = q[:, :, i, :]
        o_i = (q_i.unsqueeze(-1) * S).sum(-2)
        o[:, :, i] = o_i
    if not output_final_state:
        S = None
    return o.to(original_dtype), S
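
A quick way to read these two references is to check that the chunked and the fully recurrent variants agree on random inputs. The sketch below is not part of this commit; the import path is an assumption (this repo ships the file under `fla3/ops/simple_gla/naive.py`), and the head-first `[B, H, T, D]` layout used by both functions is assumed.

```python
# Sketch (not part of this commit): chunked vs. recurrent reference on random inputs.
import torch
import torch.nn.functional as F
from fla.ops.simple_gla.naive import torch_simple_gla, torch_simple_gla_recurrent  # assumed import path

B, H, T, D = 2, 4, 256, 64
q = torch.randn(B, H, T, D)
k = torch.randn(B, H, T, D)
v = torch.randn(B, H, T, D)
g = F.logsigmoid(torch.randn(B, H, T))  # head-wise log-space forget gate

o_chunk = torch_simple_gla(q, k, v, g, chunk_size=64)   # T must be divisible by chunk_size
o_rec, _ = torch_simple_gla_recurrent(q, k, v, g)
print((o_chunk - o_rec).abs().max())  # should be near 0, up to floating-point error
```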
fla3/ops/ttt/naive.py
ADDED
@@ -0,0 +1,126 @@

# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang, Yuqi Pan

import torch
import torch.nn.functional as F


def ttt_linear(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    eta: torch.Tensor,
    scale: float,
    eps: float,
    mini_batch_size: int,
    initial_state: torch.Tensor,
    initial_state_bias: torch.Tensor,
    output_final_state: bool
):
    B, H, T, D = q.shape
    BT = mini_batch_size
    NT = T // BT
    # [NT, B, H, mini_batch_size, D]
    _q = q.reshape(B, H, NT, BT, D).permute(2, 0, 1, 3, 4)
    _k = k.reshape(B, H, NT, BT, D).permute(2, 0, 1, 3, 4)
    _v = v.reshape(B, H, NT, BT, D).permute(2, 0, 1, 3, 4)
    # [NT, B, H, BT, 1]
    _eta = eta.reshape(B, H, NT, BT, 1).permute(2, 0, 1, 3, 4)
    # [H, 1, D]
    w = w.reshape(H, 1, D).to(torch.float32)
    b = b.reshape(H, 1, D).to(torch.float32)

    h = torch.zeros((B, H, D, D), device=v.device, dtype=torch.float32) if initial_state is None else initial_state
    hb = torch.zeros((B, H, 1, D), device=v.device, dtype=torch.float32) if initial_state_bias is None else initial_state_bias
    q *= scale
    # [NT, B, H, BT, D]
    o = torch.empty_like(_v)

    for i in range(NT):
        q_i, k_i, v_i, eta_i = [x[i] for x in [_q, _k, _v, _eta]]
        kh = k_i @ h + hb
        reconstruction_target = v_i - k_i

        mean = kh.mean(-1, True)
        var = kh.var(-1, unbiased=False, keepdim=True).to(torch.float32)
        rstd = torch.sqrt(var + eps).to(torch.float32)
        kh_hat = (kh - mean) / rstd

        g = w * kh_hat + b - reconstruction_target
        g *= w
        v_new = (D * g - g.sum(-1, True) - kh_hat * (g * kh_hat).sum(-1, True)) / (rstd * D)

        Attn = torch.tril(q_i @ k_i.transpose(-2, -1))
        o_i = q_i @ h - (eta_i * Attn) @ v_new + hb - torch.tril(eta_i.expand_as(Attn)) @ v_new
        h = h - (eta_i[:, :, -1, :, None] * k_i).transpose(-1, -2) @ v_new
        hb = hb - torch.sum(eta_i[:, :, -1, :, None] * v_new, dim=-2, keepdim=True)
        # layer norm with residuals

        mean = o_i.mean(dim=-1, keepdim=True)
        var = o_i.var(dim=-1, unbiased=False, keepdim=True).to(torch.float32)
        rstd = torch.sqrt(var + eps).to(torch.float32)
        o[i] = o_i + (o_i - mean) / rstd * w + b

    # [B, H, T, D]
    o = o.permute(1, 2, 0, 3, 4).reshape(B, H, T, D)
    h = h if output_final_state else None
    hb = hb if output_final_state else None
    return o, h, hb


def chunk_ttt_linear_ref(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    eta: torch.Tensor,
    scale: float = None,
    eps: float = 1e-6,
    mini_batch_size: int = 16,
    initial_state: torch.Tensor = None,
    initial_state_bias: torch.Tensor = None,
    output_final_state: bool = False,
    head_first: bool = False,
):
    assert q.dtype == k.dtype == v.dtype
    assert k.shape[-1] == v.shape[-1], "The key and value dimension must be the same."
    if isinstance(eta, float):
        eta = torch.full_like(q[:, :, :, :1], eta)
    if scale is None:
        scale = k.shape[-1] ** -0.5
    if not head_first:
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        eta = eta.transpose(1, 2)
    T = q.shape[-2]
    padded = (mini_batch_size - (T % mini_batch_size)) % mini_batch_size
    if padded > 0:
        q = F.pad(q, (0, 0, 0, padded))
        k = F.pad(k, (0, 0, 0, padded))
        v = F.pad(v, (0, 0, 0, padded))
        eta = F.pad(eta, (0, 0, 0, padded))
        eta[:, :, -1, :] = eta[:, :, -(padded+1), :]
    assert q.shape[-2] % mini_batch_size == 0, "Sequence length should be a multiple of mini_batch_size."
    q, k, v, eta, w, b = map(lambda x: x.to(torch.float32), [q, k, v, eta, w, b])
    o, final_state, final_state_bias = ttt_linear(
        q,
        k,
        v,
        w,
        b,
        eta,
        scale,
        eps,
        mini_batch_size,
        initial_state,
        initial_state_bias,
        output_final_state,
    )
    o = o[:, :, :T, :].contiguous()
    if not head_first:
        o = o.transpose(1, 2)
    return o, final_state, final_state_bias
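
For orientation, the reference above can be driven end to end with random tensors. The sketch below is not part of this commit; the import path is an assumption (the file ships under `fla3/ops/ttt/naive.py`), and treating `w`/`b` as per-head LayerNorm parameters follows how they are reshaped inside `ttt_linear`.

```python
# Sketch (not part of this commit): running the TTT-linear reference on random inputs.
import torch
from fla.ops.ttt.naive import chunk_ttt_linear_ref  # assumed import path

B, T, H, D = 2, 64, 4, 32
q = torch.randn(B, T, H, D)
k = torch.randn(B, T, H, D)
v = torch.randn(B, T, H, D)
w = torch.ones(H, D)    # per-head LayerNorm weight, reshaped to [H, 1, D] internally
b = torch.zeros(H, D)   # per-head LayerNorm bias
eta = 0.1               # scalar step size; broadcast to [B, T, H, 1] internally

o, h, hb = chunk_ttt_linear_ref(q, k, v, w, b, eta, mini_batch_size=16, output_final_state=True)
print(o.shape, h.shape, hb.shape)  # [B, T, H, D], [B, H, D, D], [B, H, 1, D]
```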
fla3/ops/utils/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (1.16 kB)

fla3/ops/utils/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (1.2 kB)

fla3/ops/utils/__pycache__/asm.cpython-310.pyc
ADDED
Binary file (482 Bytes)

fla3/ops/utils/__pycache__/asm.cpython-312.pyc
ADDED
Binary file (543 Bytes)

fla3/ops/utils/__pycache__/cumsum.cpython-310.pyc
ADDED
Binary file (10.3 kB)

fla3/ops/utils/__pycache__/cumsum.cpython-312.pyc
ADDED
Binary file (21.4 kB)

fla3/ops/utils/__pycache__/index.cpython-310.pyc
ADDED
Binary file (3.12 kB)

fla3/ops/utils/__pycache__/index.cpython-312.pyc
ADDED
Binary file (5.48 kB)

fla3/ops/utils/__pycache__/logcumsumexp.cpython-310.pyc
ADDED
Binary file (1.54 kB)

fla3/ops/utils/__pycache__/logsumexp.cpython-310.pyc
ADDED
Binary file (2.25 kB)

fla3/ops/utils/__pycache__/logsumexp.cpython-312.pyc
ADDED
Binary file (3.66 kB)

fla3/ops/utils/__pycache__/matmul.cpython-310.pyc
ADDED
Binary file (5.29 kB)

fla3/ops/utils/__pycache__/op.cpython-312.pyc
ADDED
Binary file (1.56 kB)

fla3/ops/utils/__pycache__/pooling.cpython-310.pyc
ADDED
Binary file (5.61 kB)
fla3/ops/utils/cumsum.py
ADDED
|
@@ -0,0 +1,414 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
| 3 |
+
|
| 4 |
+
import warnings
|
| 5 |
+
from typing import Optional
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import triton
|
| 9 |
+
import triton.language as tl
|
| 10 |
+
|
| 11 |
+
from ...ops.utils.index import prepare_chunk_indices
|
| 12 |
+
from ...utils import check_shared_mem, input_guard
|
| 13 |
+
|
| 14 |
+
BS_LIST = [32, 64] if check_shared_mem() else [16, 32]
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@triton.heuristics({
|
| 18 |
+
'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
|
| 19 |
+
})
|
| 20 |
+
@triton.autotune(
|
| 21 |
+
configs=[
|
| 22 |
+
triton.Config({}, num_warps=num_warps)
|
| 23 |
+
for num_warps in [1, 2, 4, 8]
|
| 24 |
+
],
|
| 25 |
+
key=['B', 'H', 'BT', 'IS_VARLEN', 'REVERSE']
|
| 26 |
+
)
|
| 27 |
+
@triton.jit(do_not_specialize=['T'])
|
| 28 |
+
def chunk_local_cumsum_scalar_kernel(
|
| 29 |
+
s,
|
| 30 |
+
o,
|
| 31 |
+
cu_seqlens,
|
| 32 |
+
chunk_indices,
|
| 33 |
+
T,
|
| 34 |
+
B: tl.constexpr,
|
| 35 |
+
H: tl.constexpr,
|
| 36 |
+
BT: tl.constexpr,
|
| 37 |
+
REVERSE: tl.constexpr,
|
| 38 |
+
IS_VARLEN: tl.constexpr,
|
| 39 |
+
HEAD_FIRST: tl.constexpr,
|
| 40 |
+
):
|
| 41 |
+
i_t, i_bh = tl.program_id(0), tl.program_id(1)
|
| 42 |
+
i_b, i_h = i_bh // H, i_bh % H
|
| 43 |
+
if IS_VARLEN:
|
| 44 |
+
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
|
| 45 |
+
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
|
| 46 |
+
T = eos - bos
|
| 47 |
+
else:
|
| 48 |
+
bos, eos = i_b * T, i_b * T + T
|
| 49 |
+
|
| 50 |
+
if HEAD_FIRST:
|
| 51 |
+
p_s = tl.make_block_ptr(s + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,))
|
| 52 |
+
p_o = tl.make_block_ptr(o + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,))
|
| 53 |
+
else:
|
| 54 |
+
p_s = tl.make_block_ptr(s + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
|
| 55 |
+
p_o = tl.make_block_ptr(o + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
|
| 56 |
+
# [BT]
|
| 57 |
+
b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32)
|
| 58 |
+
b_o = tl.cumsum(b_s, axis=0)
|
| 59 |
+
if REVERSE:
|
| 60 |
+
b_z = tl.sum(b_s, axis=0)
|
| 61 |
+
b_o = -b_o + b_z[None] + b_s
|
| 62 |
+
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
@triton.heuristics({
|
| 66 |
+
'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
|
| 67 |
+
})
|
| 68 |
+
@triton.autotune(
|
| 69 |
+
configs=[
|
| 70 |
+
triton.Config({'BS': BS}, num_warps=num_warps)
|
| 71 |
+
for BS in BS_LIST
|
| 72 |
+
for num_warps in [2, 4, 8]
|
| 73 |
+
],
|
| 74 |
+
key=['B', 'H', 'S', 'BT', 'IS_VARLEN', 'REVERSE']
|
| 75 |
+
)
|
| 76 |
+
@triton.jit(do_not_specialize=['T'])
|
| 77 |
+
def chunk_local_cumsum_vector_kernel(
|
| 78 |
+
s,
|
| 79 |
+
o,
|
| 80 |
+
cu_seqlens,
|
| 81 |
+
chunk_indices,
|
| 82 |
+
T,
|
| 83 |
+
B: tl.constexpr,
|
| 84 |
+
H: tl.constexpr,
|
| 85 |
+
S: tl.constexpr,
|
| 86 |
+
BT: tl.constexpr,
|
| 87 |
+
BS: tl.constexpr,
|
| 88 |
+
REVERSE: tl.constexpr,
|
| 89 |
+
IS_VARLEN: tl.constexpr,
|
| 90 |
+
HEAD_FIRST: tl.constexpr,
|
| 91 |
+
):
|
| 92 |
+
i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
| 93 |
+
i_b, i_h = i_bh // H, i_bh % H
|
| 94 |
+
if IS_VARLEN:
|
| 95 |
+
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
|
| 96 |
+
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
|
| 97 |
+
T = eos - bos
|
| 98 |
+
else:
|
| 99 |
+
bos, eos = i_b * T, i_b * T + T
|
| 100 |
+
|
| 101 |
+
o_i = tl.arange(0, BT)
|
| 102 |
+
if REVERSE:
|
| 103 |
+
m_s = tl.where(o_i[:, None] <= o_i[None, :], 1., 0.)
|
| 104 |
+
else:
|
| 105 |
+
m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)
|
| 106 |
+
|
| 107 |
+
if HEAD_FIRST:
|
| 108 |
+
p_s = tl.make_block_ptr(s + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
|
| 109 |
+
p_o = tl.make_block_ptr(o + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
|
| 110 |
+
else:
|
| 111 |
+
p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
|
| 112 |
+
p_o = tl.make_block_ptr(o + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
|
| 113 |
+
# [BT, BS]
|
| 114 |
+
b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
|
| 115 |
+
b_o = tl.dot(m_s, b_s, allow_tf32=False)
|
| 116 |
+
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
@triton.heuristics({
|
| 120 |
+
'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
|
| 121 |
+
})
|
| 122 |
+
@triton.autotune(
|
| 123 |
+
configs=[
|
| 124 |
+
triton.Config({'BT': BT}, num_warps=num_warps, num_stages=num_stages)
|
| 125 |
+
for BT in [32, 64, 128, 256]
|
| 126 |
+
for num_warps in [2, 4, 8]
|
| 127 |
+
for num_stages in [1, 2, 3, 4]
|
| 128 |
+
],
|
| 129 |
+
key=['B', 'H', 'IS_VARLEN', 'REVERSE']
|
| 130 |
+
)
|
| 131 |
+
@triton.jit(do_not_specialize=['T'])
|
| 132 |
+
def chunk_global_cumsum_scalar_kernel(
|
| 133 |
+
s,
|
| 134 |
+
o,
|
| 135 |
+
cu_seqlens,
|
| 136 |
+
T,
|
| 137 |
+
B: tl.constexpr,
|
| 138 |
+
H: tl.constexpr,
|
| 139 |
+
BT: tl.constexpr,
|
| 140 |
+
REVERSE: tl.constexpr,
|
| 141 |
+
IS_VARLEN: tl.constexpr,
|
| 142 |
+
HEAD_FIRST: tl.constexpr,
|
| 143 |
+
):
|
| 144 |
+
i_nh = tl.program_id(0)
|
| 145 |
+
i_n, i_h = i_nh // H, i_nh % H
|
| 146 |
+
if IS_VARLEN:
|
| 147 |
+
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
|
| 148 |
+
else:
|
| 149 |
+
bos, eos = i_n * T, i_n * T + T
|
| 150 |
+
T = eos - bos
|
| 151 |
+
|
| 152 |
+
b_z = tl.zeros([], dtype=tl.float32)
|
| 153 |
+
NT = tl.cdiv(T, BT)
|
| 154 |
+
for i_c in range(NT):
|
| 155 |
+
i_t = NT-1-i_c if REVERSE else i_c
|
| 156 |
+
if HEAD_FIRST:
|
| 157 |
+
p_s = tl.make_block_ptr(s + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,))
|
| 158 |
+
p_o = tl.make_block_ptr(o + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,))
|
| 159 |
+
else:
|
| 160 |
+
p_s = tl.make_block_ptr(s + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
|
| 161 |
+
p_o = tl.make_block_ptr(o + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
|
| 162 |
+
b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32)
|
| 163 |
+
b_o = tl.cumsum(b_s, axis=0)
|
| 164 |
+
b_ss = tl.sum(b_s, 0)
|
| 165 |
+
if REVERSE:
|
| 166 |
+
b_o = -b_o + b_ss + b_s
|
| 167 |
+
b_o += b_z
|
| 168 |
+
if i_c >= 0:
|
| 169 |
+
b_z += b_ss
|
| 170 |
+
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
@triton.heuristics({
|
| 174 |
+
'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
|
| 175 |
+
})
|
| 176 |
+
@triton.autotune(
|
| 177 |
+
configs=[
|
| 178 |
+
triton.Config({'BT': BT}, num_warps=num_warps, num_stages=num_stages)
|
| 179 |
+
for BT in [16, 32, 64, 128]
|
| 180 |
+
for num_warps in [2, 4, 8]
|
| 181 |
+
for num_stages in [1, 2, 3, 4]
|
| 182 |
+
],
|
| 183 |
+
key=['B', 'H', 'S', 'IS_VARLEN', 'REVERSE']
|
| 184 |
+
)
|
| 185 |
+
@triton.jit(do_not_specialize=['T'])
|
| 186 |
+
def chunk_global_cumsum_vector_kernel(
|
| 187 |
+
s,
|
| 188 |
+
z,
|
| 189 |
+
cu_seqlens,
|
| 190 |
+
T,
|
| 191 |
+
B: tl.constexpr,
|
| 192 |
+
H: tl.constexpr,
|
| 193 |
+
S: tl.constexpr,
|
| 194 |
+
BT: tl.constexpr,
|
| 195 |
+
BS: tl.constexpr,
|
| 196 |
+
REVERSE: tl.constexpr,
|
| 197 |
+
IS_VARLEN: tl.constexpr,
|
| 198 |
+
HEAD_FIRST: tl.constexpr,
|
| 199 |
+
):
|
| 200 |
+
i_s, i_nh = tl.program_id(0), tl.program_id(1)
|
| 201 |
+
i_n, i_h = i_nh // H, i_nh % H
|
| 202 |
+
if IS_VARLEN:
|
| 203 |
+
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
|
| 204 |
+
else:
|
| 205 |
+
bos, eos = i_n * T, i_n * T + T
|
| 206 |
+
T = eos - bos
|
| 207 |
+
|
| 208 |
+
o_i = tl.arange(0, BT)
|
| 209 |
+
if REVERSE:
|
| 210 |
+
m_s = tl.where(o_i[:, None] <= o_i[None, :], 1., 0.)
|
| 211 |
+
else:
|
| 212 |
+
m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)
|
| 213 |
+
|
| 214 |
+
b_z = tl.zeros([BS], dtype=tl.float32)
|
| 215 |
+
NT = tl.cdiv(T, BT)
|
| 216 |
+
for i_c in range(NT):
|
| 217 |
+
i_t = NT-1-i_c if REVERSE else i_c
|
| 218 |
+
if HEAD_FIRST:
|
| 219 |
+
p_s = tl.make_block_ptr(s + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
|
| 220 |
+
p_z = tl.make_block_ptr(z + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
|
| 221 |
+
else:
|
| 222 |
+
p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
|
| 223 |
+
p_z = tl.make_block_ptr(z + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
|
| 224 |
+
# [BT, BS]
|
| 225 |
+
b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
|
| 226 |
+
b_c = b_z[None, :] + tl.dot(m_s, b_s, allow_tf32=False)
|
| 227 |
+
tl.store(p_z, b_c.to(p_z.dtype.element_ty), boundary_check=(0, 1))
|
| 228 |
+
if i_c >= 0:
|
| 229 |
+
b_z += tl.sum(b_s, 0)
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def chunk_local_cumsum_scalar(
|
| 233 |
+
g: torch.Tensor,
|
| 234 |
+
chunk_size: int,
|
| 235 |
+
reverse: bool = False,
|
| 236 |
+
cu_seqlens: Optional[torch.Tensor] = None,
|
| 237 |
+
head_first: bool = False,
|
| 238 |
+
output_dtype: Optional[torch.dtype] = torch.float
|
| 239 |
+
) -> torch.Tensor:
|
| 240 |
+
if head_first:
|
| 241 |
+
B, H, T = g.shape
|
| 242 |
+
else:
|
| 243 |
+
B, T, H = g.shape
|
| 244 |
+
assert chunk_size == 2**(chunk_size.bit_length()-1), "chunk_size must be a power of 2"
|
| 245 |
+
BT = chunk_size
|
| 246 |
+
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
|
| 247 |
+
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
| 248 |
+
g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
|
| 249 |
+
grid = (NT, B * H)
|
| 250 |
+
chunk_local_cumsum_scalar_kernel[grid](
|
| 251 |
+
g_org,
|
| 252 |
+
g,
|
| 253 |
+
cu_seqlens,
|
| 254 |
+
chunk_indices,
|
| 255 |
+
T=T,
|
| 256 |
+
B=B,
|
| 257 |
+
H=H,
|
| 258 |
+
BT=BT,
|
| 259 |
+
HEAD_FIRST=head_first,
|
| 260 |
+
REVERSE=reverse
|
| 261 |
+
)
|
| 262 |
+
return g
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def chunk_local_cumsum_vector(
|
| 266 |
+
g: torch.Tensor,
|
| 267 |
+
chunk_size: int,
|
| 268 |
+
reverse: bool = False,
|
| 269 |
+
cu_seqlens: Optional[torch.Tensor] = None,
|
| 270 |
+
head_first: bool = False,
|
| 271 |
+
output_dtype: Optional[torch.dtype] = torch.float
|
| 272 |
+
) -> torch.Tensor:
|
| 273 |
+
if head_first:
|
| 274 |
+
B, H, T, S = g.shape
|
| 275 |
+
else:
|
| 276 |
+
B, T, H, S = g.shape
|
| 277 |
+
BT = chunk_size
|
| 278 |
+
chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
|
| 279 |
+
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
| 280 |
+
assert chunk_size == 2**(chunk_size.bit_length()-1), "chunk_size must be a power of 2"
|
| 281 |
+
|
| 282 |
+
g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
|
| 283 |
+
def grid(meta): return (triton.cdiv(meta['S'], meta['BS']), NT, B * H)
|
| 284 |
+
# keep the cumulative normalizer in fp32
|
| 285 |
+
# this kernel is equivalent to
|
| 286 |
+
# g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1)
|
| 287 |
+
chunk_local_cumsum_vector_kernel[grid](
|
| 288 |
+
g_org,
|
| 289 |
+
g,
|
| 290 |
+
cu_seqlens,
|
| 291 |
+
chunk_indices,
|
| 292 |
+
T=T,
|
| 293 |
+
B=B,
|
| 294 |
+
H=H,
|
| 295 |
+
S=S,
|
| 296 |
+
BT=BT,
|
| 297 |
+
HEAD_FIRST=head_first,
|
| 298 |
+
REVERSE=reverse
|
| 299 |
+
)
|
| 300 |
+
return g
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
@input_guard
|
| 304 |
+
def chunk_global_cumsum_scalar(
|
| 305 |
+
s: torch.Tensor,
|
| 306 |
+
reverse: bool = False,
|
| 307 |
+
cu_seqlens: Optional[torch.Tensor] = None,
|
| 308 |
+
head_first: bool = False,
|
| 309 |
+
output_dtype: Optional[torch.dtype] = torch.float
|
| 310 |
+
) -> torch.Tensor:
|
| 311 |
+
if head_first:
|
| 312 |
+
B, H, T = s.shape
|
| 313 |
+
else:
|
| 314 |
+
B, T, H = s.shape
|
| 315 |
+
N = len(cu_seqlens) - 1 if cu_seqlens is not None else B
|
| 316 |
+
|
| 317 |
+
z = torch.empty_like(s, dtype=output_dtype or s.dtype)
|
| 318 |
+
grid = (N * H,)
|
| 319 |
+
chunk_global_cumsum_scalar_kernel[grid](
|
| 320 |
+
s,
|
| 321 |
+
z,
|
| 322 |
+
cu_seqlens,
|
| 323 |
+
T=T,
|
| 324 |
+
B=B,
|
| 325 |
+
H=H,
|
| 326 |
+
HEAD_FIRST=head_first,
|
| 327 |
+
REVERSE=reverse
|
| 328 |
+
)
|
| 329 |
+
return z
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
@input_guard
|
| 333 |
+
def chunk_global_cumsum_vector(
|
| 334 |
+
s: torch.Tensor,
|
| 335 |
+
reverse: bool = False,
|
| 336 |
+
cu_seqlens: Optional[torch.Tensor] = None,
|
| 337 |
+
head_first: bool = False,
|
| 338 |
+
output_dtype: Optional[torch.dtype] = torch.float
|
| 339 |
+
) -> torch.Tensor:
|
| 340 |
+
if head_first:
|
| 341 |
+
B, H, T, S = s.shape
|
| 342 |
+
else:
|
| 343 |
+
B, T, H, S = s.shape
|
| 344 |
+
N = len(cu_seqlens) - 1 if cu_seqlens is not None else B
|
| 345 |
+
BS = min(32, triton.next_power_of_2(S))
|
| 346 |
+
|
| 347 |
+
z = torch.empty_like(s, dtype=output_dtype or s.dtype)
|
| 348 |
+
grid = (triton.cdiv(S, BS), N * H)
|
| 349 |
+
chunk_global_cumsum_vector_kernel[grid](
|
| 350 |
+
s,
|
| 351 |
+
z,
|
| 352 |
+
cu_seqlens,
|
| 353 |
+
T=T,
|
| 354 |
+
B=B,
|
| 355 |
+
H=H,
|
| 356 |
+
S=S,
|
| 357 |
+
BS=BS,
|
| 358 |
+
HEAD_FIRST=head_first,
|
| 359 |
+
REVERSE=reverse
|
| 360 |
+
)
|
| 361 |
+
return z
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
@input_guard
|
| 365 |
+
def chunk_global_cumsum(
|
| 366 |
+
s: torch.Tensor,
|
| 367 |
+
reverse: bool = False,
|
| 368 |
+
cu_seqlens: Optional[torch.Tensor] = None,
|
| 369 |
+
head_first: bool = False,
|
| 370 |
+
output_dtype: Optional[torch.dtype] = torch.float
|
| 371 |
+
) -> torch.Tensor:
|
| 372 |
+
if cu_seqlens is not None:
|
| 373 |
+
assert s.shape[0] == 1, "Only batch size 1 is supported when cu_seqlens are provided"
|
| 374 |
+
if len(s.shape) == 3:
|
| 375 |
+
return chunk_global_cumsum_scalar(s, reverse, cu_seqlens, head_first, output_dtype)
|
| 376 |
+
elif len(s.shape) == 4:
|
| 377 |
+
return chunk_global_cumsum_vector(s, reverse, cu_seqlens, head_first, output_dtype)
|
| 378 |
+
else:
|
| 379 |
+
raise ValueError(
|
| 380 |
+
f"Unsupported input shape {s.shape}. "
|
| 381 |
+
f"which should be [B, T, H]/[B, T, H, D] if `head_first=False` "
|
| 382 |
+
f"or [B, H, T]/[B, H, T, D] otherwise"
|
| 383 |
+
)
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
@input_guard
|
| 387 |
+
def chunk_local_cumsum(
|
| 388 |
+
g: torch.Tensor,
|
| 389 |
+
chunk_size: int,
|
| 390 |
+
reverse: bool = False,
|
| 391 |
+
cu_seqlens: Optional[torch.Tensor] = None,
|
| 392 |
+
head_first: bool = False,
|
| 393 |
+
output_dtype: Optional[torch.dtype] = torch.float,
|
| 394 |
+
**kwargs
|
| 395 |
+
) -> torch.Tensor:
|
| 396 |
+
if not head_first and g.shape[1] < g.shape[2]:
|
| 397 |
+
warnings.warn(
|
| 398 |
+
f"Input tensor shape suggests potential format mismatch: seq_len ({g.shape[1]}) < num_heads ({g.shape[2]}). "
|
| 399 |
+
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
|
| 400 |
+
"when head_first=False was specified. "
|
| 401 |
+
"Please verify your input tensor format matches the expected shape [B, T, H, ...]."
|
| 402 |
+
)
|
| 403 |
+
if cu_seqlens is not None:
|
| 404 |
+
assert g.shape[0] == 1, "Only batch size 1 is supported when cu_seqlens are provided"
|
| 405 |
+
if len(g.shape) == 3:
|
| 406 |
+
return chunk_local_cumsum_scalar(g, chunk_size, reverse, cu_seqlens, head_first, output_dtype)
|
| 407 |
+
elif len(g.shape) == 4:
|
| 408 |
+
return chunk_local_cumsum_vector(g, chunk_size, reverse, cu_seqlens, head_first, output_dtype)
|
| 409 |
+
else:
|
| 410 |
+
raise ValueError(
|
| 411 |
+
f"Unsupported input shape {g.shape}. "
|
| 412 |
+
f"which should be (B, T, H, D) if `head_first=False` "
|
| 413 |
+
f"or (B, H, T, D) otherwise"
|
| 414 |
+
)
|
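For intuition, the local and global cumulative-sum wrappers above reduce to plain chunked/full cumulative sums in the simple dense case. A minimal PyTorch sketch, assuming non-varlen inputs with `head_first=False` and `T` divisible by `chunk_size` (names here are illustrative, not part of the kernels):

```python
import torch

def chunk_local_cumsum_ref(g: torch.Tensor, chunk_size: int, reverse: bool = False) -> torch.Tensor:
    # g: [B, T, H]; cumulative sum restarted at every chunk boundary
    B, T, H = g.shape
    g = g.view(B, T // chunk_size, chunk_size, H).float()
    g = g.flip(2).cumsum(2).flip(2) if reverse else g.cumsum(2)
    return g.view(B, T, H)

def chunk_global_cumsum_ref(s: torch.Tensor, reverse: bool = False) -> torch.Tensor:
    # s: [B, T, H]; cumulative sum over the whole sequence
    s = s.float()
    return s.flip(1).cumsum(1).flip(1) if reverse else s.cumsum(1)
```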
fla3/ops/utils/index.py
ADDED
|
@@ -0,0 +1,83 @@
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
import triton
|
| 7 |
+
import triton.language as tl
|
| 8 |
+
|
| 9 |
+
from ...utils import tensor_cache
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@triton.autotune(
|
| 13 |
+
configs=[
|
| 14 |
+
triton.Config({}, num_warps=num_warps)
|
| 15 |
+
for num_warps in [4, 8, 16, 32]
|
| 16 |
+
],
|
| 17 |
+
key=['B'],
|
| 18 |
+
)
|
| 19 |
+
@triton.jit
|
| 20 |
+
def prepare_position_ids_kernel(
|
| 21 |
+
y,
|
| 22 |
+
cu_seqlens,
|
| 23 |
+
B: tl.constexpr
|
| 24 |
+
):
|
| 25 |
+
i_n = tl.program_id(0)
|
| 26 |
+
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
|
| 27 |
+
T = eos - bos
|
| 28 |
+
|
| 29 |
+
o = tl.arange(0, B)
|
| 30 |
+
for i in range(0, tl.cdiv(T, B) * B, B):
|
| 31 |
+
o_i = o + i
|
| 32 |
+
tl.store(y + bos + o_i, o_i, o_i < T)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@tensor_cache
|
| 36 |
+
def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
|
| 37 |
+
return cu_seqlens[1:] - cu_seqlens[:-1]
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@tensor_cache
|
| 41 |
+
def prepare_lens_from_mask(mask: torch.BoolTensor) -> torch.LongTensor:
|
| 42 |
+
return mask.sum(dim=-1, dtype=torch.int32)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
@tensor_cache
|
| 46 |
+
def prepare_cu_seqlens_from_mask(mask: torch.BoolTensor, out_dtype: torch.dtype = torch.int32) -> torch.LongTensor:
|
| 47 |
+
return F.pad(prepare_lens_from_mask(mask).cumsum(dim=0, dtype=out_dtype), (1, 0))
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@tensor_cache
|
| 51 |
+
def prepare_position_ids(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
|
| 52 |
+
return torch.cat([
|
| 53 |
+
torch.arange(n, dtype=cu_seqlens.dtype, device=cu_seqlens.device)
|
| 54 |
+
for n in prepare_lens(cu_seqlens).unbind()
|
| 55 |
+
])
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
@tensor_cache
|
| 59 |
+
def prepare_sequence_ids(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
|
| 60 |
+
return prepare_position_ids(cu_seqlens).eq(0).cumsum(0) - 1
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
@tensor_cache
|
| 64 |
+
def prepare_token_indices(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
|
| 65 |
+
position_ids = prepare_position_ids(cu_seqlens)
|
| 66 |
+
return torch.stack([prepare_sequence_ids(cu_seqlens), position_ids], 1).to(cu_seqlens)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
@tensor_cache
|
| 70 |
+
def prepare_chunk_indices(
|
| 71 |
+
cu_seqlens: torch.LongTensor,
|
| 72 |
+
chunk_size: int
|
| 73 |
+
) -> torch.LongTensor:
|
| 74 |
+
indices = torch.cat([torch.arange(n) for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()])
|
| 75 |
+
return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
@tensor_cache
|
| 79 |
+
def prepare_chunk_offsets(
|
| 80 |
+
cu_seqlens: torch.LongTensor,
|
| 81 |
+
chunk_size: int
|
| 82 |
+
) -> torch.LongTensor:
|
| 83 |
+
return torch.cat([cu_seqlens.new_tensor([0]), triton.cdiv(prepare_lens(cu_seqlens), chunk_size)]).cumsum(-1)
|
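A hypothetical usage example of the index helpers above (the import path is assumed from the file layout; values chosen for illustration):

```python
import torch
from fla3.ops.utils.index import (
    prepare_lens, prepare_position_ids, prepare_sequence_ids,
    prepare_chunk_indices, prepare_chunk_offsets
)

cu_seqlens = torch.tensor([0, 5, 12], dtype=torch.int32)

prepare_lens(cu_seqlens)               # tensor([5, 7]): per-sequence lengths
prepare_position_ids(cu_seqlens)       # tensor([0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 5, 6])
prepare_sequence_ids(cu_seqlens)       # tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1])

# With chunk_size=4, the sequences span ceil(5/4)=2 and ceil(7/4)=2 chunks
prepare_chunk_indices(cu_seqlens, 4)   # [[0, 0], [0, 1], [1, 0], [1, 1]]
prepare_chunk_offsets(cu_seqlens, 4)   # tensor([0, 2, 4])
```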
fla3/ops/utils/logcumsumexp.py
ADDED
|
@@ -0,0 +1,52 @@
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
| 3 |
+
|
| 4 |
+
import triton
|
| 5 |
+
import triton.language as tl
|
| 6 |
+
|
| 7 |
+
from ...ops.utils.op import exp, log
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@triton.autotune(
|
| 11 |
+
configs=[
|
| 12 |
+
triton.Config({'BT': BT}, num_warps=num_warps)
|
| 13 |
+
for BT in [16, 32, 64]
|
| 14 |
+
for num_warps in [2, 4, 8]
|
| 15 |
+
],
|
| 16 |
+
key=['S']
|
| 17 |
+
)
|
| 18 |
+
@triton.jit(do_not_specialize=['T'])
|
| 19 |
+
def logcumsumexp_fwd_kernel(
|
| 20 |
+
s,
|
| 21 |
+
z,
|
| 22 |
+
T,
|
| 23 |
+
S: tl.constexpr,
|
| 24 |
+
BT: tl.constexpr
|
| 25 |
+
):
|
| 26 |
+
i_bh = tl.program_id(0)
|
| 27 |
+
o_i = tl.arange(0, BT)
|
| 28 |
+
m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)
|
| 29 |
+
|
| 30 |
+
b_mp = tl.full([S,], float('-inf'), dtype=tl.float32)
|
| 31 |
+
b_zp = tl.zeros([S,], dtype=tl.float32)
|
| 32 |
+
for i_t in range(tl.cdiv(T, BT)):
|
| 33 |
+
p_s = tl.make_block_ptr(s + i_bh * T*S, (T, S), (S, 1), (i_t * BT, 0), (BT, S), (1, 0))
|
| 34 |
+
p_z = tl.make_block_ptr(z + i_bh * T*S, (T, S), (S, 1), (i_t * BT, 0), (BT, S), (1, 0))
|
| 35 |
+
|
| 36 |
+
# [BT, S]
|
| 37 |
+
b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
|
| 38 |
+
# [S,]
|
| 39 |
+
b_mc = tl.max(b_s, 0)
|
| 40 |
+
b_mc = tl.maximum(b_mp, b_mc)
|
| 41 |
+
b_zp = b_zp * exp(b_mp - b_mc)
|
| 42 |
+
# [BT, S]
|
| 43 |
+
b_s = exp(b_s - b_mc)
|
| 44 |
+
b_z = tl.dot(m_s, b_s, allow_tf32=False) + b_zp
|
| 45 |
+
# [S,]
|
| 46 |
+
b_zc = tl.max(b_z, 0)
|
| 47 |
+
b_mp = b_mc
|
| 48 |
+
b_zp = b_zc
|
| 49 |
+
# [BT, BS]
|
| 50 |
+
# small eps to prevent underflows
|
| 51 |
+
b_z = log(tl.where(b_z != 0, b_z, 1e-20)) + b_mc
|
| 52 |
+
tl.store(p_z, b_z.to(p_z.dtype.element_ty), boundary_check=(0, 1))
|
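The forward kernel above computes a numerically stable running log-sum-exp over the time dimension; a hedged reference sketch using PyTorch (shapes assumed as [N, T, S], with N = batch * heads):

```python
import torch

def logcumsumexp_ref(s: torch.Tensor) -> torch.Tensor:
    # Running log-sum-exp along T, accumulated in fp32 as the kernel does
    return torch.logcumsumexp(s.float(), dim=1)
```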
fla3/ops/utils/matmul.py
ADDED
|
@@ -0,0 +1,245 @@
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
| 3 |
+
|
| 4 |
+
# code adapted from
|
| 5 |
+
# https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
|
| 6 |
+
|
| 7 |
+
from typing import Optional
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import triton
|
| 11 |
+
import triton.language as tl
|
| 12 |
+
|
| 13 |
+
from ...ops.utils.op import exp
|
| 14 |
+
from ...utils import input_guard
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
|
| 18 |
+
# - A list of `triton.Config` objects that define different configurations of
|
| 19 |
+
# meta-parameters (e.g., `BM`) and compilation options (e.g., `num_warps`) to try
|
| 20 |
+
# - An auto-tuning *key* whose change in values will trigger evaluation of all the
|
| 21 |
+
# provided configs
|
| 22 |
+
@triton.heuristics({
|
| 23 |
+
'HAS_ALPHA': lambda args: args['alpha'] is not None,
|
| 24 |
+
'HAS_BETA': lambda args: args['beta'] is not None
|
| 25 |
+
})
|
| 26 |
+
@triton.autotune(
|
| 27 |
+
configs=[
|
| 28 |
+
triton.Config({'BM': 128, 'BK': 64, 'BN': 256, 'G': 4}, num_stages=3, num_warps=8),
|
| 29 |
+
triton.Config({'BM': 64, 'BK': 32, 'BN': 256, 'G': 4}, num_stages=4, num_warps=4),
|
| 30 |
+
triton.Config({'BM': 128, 'BK': 32, 'BN': 128, 'G': 4}, num_stages=4, num_warps=4),
|
| 31 |
+
triton.Config({'BM': 128, 'BK': 32, 'BN': 64, 'G': 4}, num_stages=4, num_warps=4),
|
| 32 |
+
triton.Config({'BM': 64, 'BK': 32, 'BN': 128, 'G': 4}, num_stages=4, num_warps=4),
|
| 33 |
+
triton.Config({'BM': 128, 'BK': 32, 'BN': 32, 'G': 4}, num_stages=4, num_warps=4),
|
| 34 |
+
triton.Config({'BM': 64, 'BK': 32, 'BN': 32, 'G': 4}, num_stages=5, num_warps=2),
|
| 35 |
+
triton.Config({'BM': 32, 'BK': 32, 'BN': 64, 'G': 4}, num_stages=5, num_warps=2),
|
| 36 |
+
# Good config for fp8 inputs.
|
| 37 |
+
# triton.Config({'BM': 128, 'BK': 128, 'BN': 256, 'G': 4}, num_stages=3, num_warps=8),
|
| 38 |
+
# triton.Config({'BM': 256, 'BK': 128, 'BN': 128, 'G': 4}, num_stages=3, num_warps=8),
|
| 39 |
+
# triton.Config({'BM': 256, 'BK': 128, 'BN': 64, 'G': 4}, num_stages=4, num_warps=4),
|
| 40 |
+
# triton.Config({'BM': 64, 'BK': 128, 'BN': 256, 'G': 4}, num_stages=4, num_warps=4),
|
| 41 |
+
# triton.Config({'BM': 128, 'BK': 128, 'BN': 128, 'G': 4}, num_stages=4, num_warps=4),
|
| 42 |
+
# triton.Config({'BM': 128, 'BK': 64, 'BN': 64, 'G': 4}, num_stages=4, num_warps=4),
|
| 43 |
+
# triton.Config({'BM': 64, 'BK': 64, 'BN': 128, 'G': 4}, num_stages=4, num_warps=4),
|
| 44 |
+
# triton.Config({'BM': 128, 'BK': 64, 'BN': 32, 'G': 4}, num_stages=4, num_warps=4)
|
| 45 |
+
],
|
| 46 |
+
key=['M', 'N', 'K']
|
| 47 |
+
)
|
| 48 |
+
@triton.jit
|
| 49 |
+
def matmul_kernel(
|
| 50 |
+
# Pointers to matrices
|
| 51 |
+
a,
|
| 52 |
+
b,
|
| 53 |
+
c,
|
| 54 |
+
input,
|
| 55 |
+
alpha,
|
| 56 |
+
beta,
|
| 57 |
+
# Matrix dimensions
|
| 58 |
+
M,
|
| 59 |
+
N,
|
| 60 |
+
K,
|
| 61 |
+
# The stride variables represent how much to increase the ptr by when moving by 1
|
| 62 |
+
# element in a particular dimension. E.g. `stride_am` is how much to increase `a`
|
| 63 |
+
# by to get the element one row down (A has M rows).
|
| 64 |
+
stride_ab, stride_am, stride_ak, # a: batch, M, K
|
| 65 |
+
stride_bk, stride_bn, # b: K, N
|
| 66 |
+
stride_cb, stride_cm, stride_cn, # c: batch, M, N
|
| 67 |
+
# Meta-parameters
|
| 68 |
+
BM: tl.constexpr,
|
| 69 |
+
BK: tl.constexpr,
|
| 70 |
+
BN: tl.constexpr,
|
| 71 |
+
G: tl.constexpr,
|
| 72 |
+
ACTIVATION: tl.constexpr,
|
| 73 |
+
HAS_INPUT: tl.constexpr,
|
| 74 |
+
HAS_ALPHA: tl.constexpr,
|
| 75 |
+
HAS_BETA: tl.constexpr,
|
| 76 |
+
ALLOW_TF32: tl.constexpr,
|
| 77 |
+
X_DIM: tl.constexpr = 1,
|
| 78 |
+
):
|
| 79 |
+
"""Kernel for computing the matmul C = A x B.
|
| 80 |
+
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
|
| 81 |
+
"""
|
| 82 |
+
# -----------------------------------------------------------
|
| 83 |
+
# Map program ids `pid` to the block of C it should compute.
|
| 84 |
+
# This is done in a grouped ordering to promote L2 data reuse.
|
| 85 |
+
# See above `L2 Cache Optimizations` section for details.
|
| 86 |
+
i_b, i_m, i_n = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
| 87 |
+
|
| 88 |
+
NM, NN = tl.num_programs(1), tl.num_programs(2)
|
| 89 |
+
i_m, i_n = tl.swizzle2d(i_m, i_n, NM, NN, G)
|
| 90 |
+
|
| 91 |
+
# ----------------------------------------------------------
|
| 92 |
+
# Create pointers for the first blocks of A and B.
|
| 93 |
+
# We will advance this pointer as we move in the K direction
|
| 94 |
+
# and accumulate
|
| 95 |
+
# `p_a` is a block of [BM, BK] pointers
|
| 96 |
+
# `p_b` is a block of [BK, BN] pointers
|
| 97 |
+
# See above `Pointer Arithmetic` section for details
|
| 98 |
+
a_batch_ptr = a + i_b * stride_ab
|
| 99 |
+
o_am = (i_m * BM + tl.arange(0, BM)) % M
|
| 100 |
+
o_bn = (i_n * BN + tl.arange(0, BN)) % N
|
| 101 |
+
o_k = tl.arange(0, BK)
|
| 102 |
+
|
| 103 |
+
p_a = a_batch_ptr + (o_am[:, None] * stride_am + o_k[None, :] * stride_ak)
|
| 104 |
+
p_b = b + (o_k[:, None] * stride_bk + o_bn[None, :] * stride_bn)
|
| 105 |
+
|
| 106 |
+
b_acc = tl.zeros((BM, BN), dtype=tl.float32)
|
| 107 |
+
for k in range(0, tl.cdiv(K, BK)):
|
| 108 |
+
# Load the next block of A and B, generate a mask by checking the K dimension.
|
| 109 |
+
# If it is out of bounds, set it to 0.
|
| 110 |
+
b_a = tl.load(p_a, mask=o_k[None, :] < K - k * BK, other=0.0)
|
| 111 |
+
b_b = tl.load(p_b, mask=o_k[:, None] < K - k * BK, other=0.0)
|
| 112 |
+
# We accumulate along the K dimension.
|
| 113 |
+
b_acc = tl.dot(b_a, b_b, acc=b_acc, allow_tf32=ALLOW_TF32)
|
| 114 |
+
# Advance the ptrs to the next K block.
|
| 115 |
+
p_a += BK * stride_ak
|
| 116 |
+
p_b += BK * stride_bk
|
| 117 |
+
|
| 118 |
+
o_cm = i_m * BM + tl.arange(0, BM)
|
| 119 |
+
o_cn = i_n * BN + tl.arange(0, BN)
|
| 120 |
+
mask = (o_cm[:, None] < M) & (o_cn[None, :] < N)
|
| 121 |
+
|
| 122 |
+
b_c = b_acc
|
| 123 |
+
# You can fuse arbitrary activation functions here
|
| 124 |
+
# while the b_acc is still in FP32!
|
| 125 |
+
if ACTIVATION == "leaky_relu":
|
| 126 |
+
b_c = leaky_relu(b_c)
|
| 127 |
+
elif ACTIVATION == "relu":
|
| 128 |
+
b_c = relu(b_c)
|
| 129 |
+
elif ACTIVATION == "sigmoid":
|
| 130 |
+
b_c = sigmoid(b_c)
|
| 131 |
+
elif ACTIVATION == "tanh":
|
| 132 |
+
b_c = tanh(b_c)
|
| 133 |
+
|
| 134 |
+
if HAS_ALPHA:
|
| 135 |
+
b_c *= tl.load(alpha)
|
| 136 |
+
|
| 137 |
+
if HAS_INPUT:
|
| 138 |
+
p_i = input + (stride_cm * o_cm[:, None] if X_DIM == 2 else 0) + stride_cn * o_cn[None, :]
|
| 139 |
+
mask_p = (o_cn[None, :] < N) if X_DIM == 1 else mask
|
| 140 |
+
b_i = tl.load(p_i, mask=mask_p, other=0.0).to(tl.float32)
|
| 141 |
+
if HAS_BETA:
|
| 142 |
+
b_i *= tl.load(beta)
|
| 143 |
+
b_c += b_i
|
| 144 |
+
|
| 145 |
+
# -----------------------------------------------------------
|
| 146 |
+
# Write back the block of the output matrix C with masks.
|
| 147 |
+
c_batch_ptr = c + i_b * stride_cb
|
| 148 |
+
p_c = c_batch_ptr + stride_cm * o_cm[:, None] + stride_cn * o_cn[None, :]
|
| 149 |
+
tl.store(p_c, b_c.to(c.dtype.element_ty), mask=mask)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
# We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `matmul_kernel`.
|
| 153 |
+
@triton.jit
|
| 154 |
+
def leaky_relu(x):
|
| 155 |
+
return tl.where(x >= 0, x, 0.01 * x)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
@triton.jit
|
| 159 |
+
def sigmoid(x):
|
| 160 |
+
# σ(x) = 1 / (1 + exp(-x))
|
| 161 |
+
return 1.0 / (1.0 + exp(-x))
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
@triton.jit
|
| 165 |
+
def tanh(x):
|
| 166 |
+
# tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
|
| 167 |
+
# 2 * sigmoid(2x) - 1
|
| 168 |
+
return (exp(x) - exp(-x)) / (exp(x) + exp(-x))
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
@triton.jit
|
| 172 |
+
def relu(x):
|
| 173 |
+
# ReLU(x) = max(0, x)
|
| 174 |
+
return tl.maximum(x, 0.0)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
@input_guard
|
| 178 |
+
def matmul(a, b, activation=''):
|
| 179 |
+
assert a.dim() in [2, 3], "a must be 2D or 3D"
|
| 180 |
+
assert b.dim() == 2, "b must be 2D"
|
| 181 |
+
assert a.shape[-1] == b.shape[0], f"Incompatible dimensions: A {a.shape}, B {b.shape}"
|
| 182 |
+
|
| 183 |
+
if a.dim() == 2:
|
| 184 |
+
a_dim = 2
|
| 185 |
+
a = a.unsqueeze(0).contiguous() # (1, M, K)
|
| 186 |
+
else:
|
| 187 |
+
a_dim = 3
|
| 188 |
+
allow_tf32 = False if a.dtype == torch.float32 else True
|
| 189 |
+
|
| 190 |
+
B, M, K = a.shape[0], a.shape[1], a.shape[2]
|
| 191 |
+
K_b, N = b.shape
|
| 192 |
+
assert K == K_b, f"Incompatible K dimension: A {K} vs B {K_b}"
|
| 193 |
+
c = a.new_empty(B, M, N)
|
| 194 |
+
|
| 195 |
+
def grid(meta): return (B, triton.cdiv(M, meta['BM']), triton.cdiv(N, meta['BN']))
|
| 196 |
+
matmul_kernel[grid](
|
| 197 |
+
a, b, c, None, None, None,
|
| 198 |
+
M, N, K,
|
| 199 |
+
a.stride(0), a.stride(1), a.stride(2), # stride_ab, stride_am, stride_ak
|
| 200 |
+
b.stride(0), b.stride(1), # stride_bk, stride_bn (b.dim() == 2)
|
| 201 |
+
c.stride(0), c.stride(1), c.stride(2), # stride_cb, stride_cm, stride_cn
|
| 202 |
+
ACTIVATION=activation,
|
| 203 |
+
ALLOW_TF32=allow_tf32,
|
| 204 |
+
HAS_INPUT=False,
|
| 205 |
+
)
|
| 206 |
+
return c.squeeze(0) if a_dim == 2 else c
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
@input_guard
|
| 210 |
+
def addmm(
|
| 211 |
+
x: torch.Tensor,
|
| 212 |
+
a: torch.Tensor,
|
| 213 |
+
b: torch.Tensor,
|
| 214 |
+
alpha: Optional[float] = None,
|
| 215 |
+
beta: Optional[float] = None,
|
| 216 |
+
) -> torch.Tensor:
|
| 217 |
+
assert a.dim() in [2, 3], "a must be 2D or 3D"
|
| 218 |
+
assert b.dim() == 2, "b must be 2D"
|
| 219 |
+
assert a.shape[-1] == b.shape[0], f"Incompatible dimensions: A {a.shape}, B {b.shape}"
|
| 220 |
+
|
| 221 |
+
if a.dim() == 2:
|
| 222 |
+
a_dim = 2
|
| 223 |
+
a = a.unsqueeze(0).contiguous() # (1, M, K)
|
| 224 |
+
else:
|
| 225 |
+
a_dim = 3
|
| 226 |
+
allow_tf32 = False if a.dtype == torch.float32 else True
|
| 227 |
+
|
| 228 |
+
B, M, K = a.shape[0], a.shape[1], a.shape[2]
|
| 229 |
+
K_b, N = b.shape
|
| 230 |
+
assert K == K_b, f"Incompatible K dimension: A {K} vs B {K_b}"
|
| 231 |
+
c = a.new_empty(B, M, N)
|
| 232 |
+
|
| 233 |
+
def grid(meta): return (B, triton.cdiv(M, meta['BM']), triton.cdiv(N, meta['BN']))
|
| 234 |
+
matmul_kernel[grid](
|
| 235 |
+
a, b, c, x, alpha, beta,
|
| 236 |
+
M, N, K,
|
| 237 |
+
a.stride(0), a.stride(1), a.stride(2), # stride_ab, stride_am, stride_ak
|
| 238 |
+
b.stride(0), b.stride(1), # stride_bk, stride_bn (b.dim() == 2)
|
| 239 |
+
c.stride(0), c.stride(1), c.stride(2), # stride_cb, stride_cm, stride_cn
|
| 240 |
+
ACTIVATION=None,
|
| 241 |
+
ALLOW_TF32=allow_tf32,
|
| 242 |
+
HAS_INPUT=True,
|
| 243 |
+
X_DIM=x.dim(),
|
| 244 |
+
)
|
| 245 |
+
return c.squeeze(0) if a_dim == 2 else c
|
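A hedged usage sketch of the two wrappers above against their PyTorch counterparts (the import path is assumed from the file layout; shapes are illustrative):

```python
import torch
from fla3.ops.utils.matmul import matmul, addmm

a = torch.randn(64, 128, device='cuda', dtype=torch.float16)
b = torch.randn(128, 256, device='cuda', dtype=torch.float16)
x = torch.randn(256, device='cuda', dtype=torch.float16)  # 1D input broadcast over rows

c = matmul(a, b)          # ~ a @ b
y = addmm(x, a, b)        # ~ x + a @ b (alpha/beta left as None)

assert torch.allclose(c, a @ b, atol=1e-2, rtol=1e-2)
assert torch.allclose(y, x + a @ b, atol=1e-2, rtol=1e-2)
```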
fla3/ops/utils/op.py
ADDED
|
@@ -0,0 +1,39 @@
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) 2024, Songlin Yang, Yu Zhang
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
import triton
|
| 7 |
+
import triton.language as tl
|
| 8 |
+
import triton.language.extra.libdevice as tldevice
|
| 9 |
+
|
| 10 |
+
from ...utils import is_gather_supported
|
| 11 |
+
|
| 12 |
+
if os.environ.get('FLA_USE_FAST_OPS', '0') == '1':
|
| 13 |
+
div = tldevice.fast_dividef
|
| 14 |
+
exp = tldevice.fast_expf
|
| 15 |
+
log = tldevice.fast_logf
|
| 16 |
+
log2 = tldevice.fast_log2f
|
| 17 |
+
else:
|
| 18 |
+
@triton.jit
|
| 19 |
+
def div_normal(x, y):
|
| 20 |
+
return x / y
|
| 21 |
+
div = div_normal
|
| 22 |
+
exp = tl.exp
|
| 23 |
+
log = tl.log
|
| 24 |
+
log2 = tl.log2
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@triton.jit
|
| 28 |
+
def safe_exp(x):
|
| 29 |
+
return exp(tl.where(x <= 0, x, float('-inf')))
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
if not is_gather_supported:
|
| 33 |
+
@triton.jit
|
| 34 |
+
def gather(src, index, axis, _builder=None):
|
| 35 |
+
# This is a fallback implementation when tl.gather is not supported
|
| 36 |
+
# To satisfy the Triton compiler, no actual gather is performed here; src is returned unchanged
|
| 37 |
+
return src
|
| 38 |
+
else:
|
| 39 |
+
gather = tl.gather
|
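Setting the environment variable `FLA_USE_FAST_OPS=1` swaps in the libdevice fast-math variants; otherwise the plain `tl.exp`/`tl.log` ops are used. For intuition, `safe_exp` clamps positive inputs so the exponential can never overflow; a hedged PyTorch equivalent:

```python
import torch

def safe_exp_ref(x: torch.Tensor) -> torch.Tensor:
    # exp(x) where x <= 0, and exactly 0 elsewhere (exp(-inf) == 0)
    return torch.where(x <= 0, x, torch.full_like(x, float('-inf'))).exp()
```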
fla3/ops/utils/pack.py
ADDED
|
@@ -0,0 +1,208 @@
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
| 3 |
+
|
| 4 |
+
# Code adapted from https://github.com/mayank31398/cute-kernels
|
| 5 |
+
|
| 6 |
+
from typing import Optional
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import triton
|
| 10 |
+
import triton.language as tl
|
| 11 |
+
|
| 12 |
+
from ...ops.utils.index import prepare_lens
|
| 13 |
+
from ...utils import input_guard
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@triton.autotune(
|
| 17 |
+
configs=[
|
| 18 |
+
triton.Config({}, num_warps=num_warps)
|
| 19 |
+
for num_warps in [4, 8, 16, 32]
|
| 20 |
+
],
|
| 21 |
+
key=['D', 'PADDING_SIDE', 'PACK']
|
| 22 |
+
)
|
| 23 |
+
@triton.jit
|
| 24 |
+
def packunpack_sequence_kernel(
|
| 25 |
+
x,
|
| 26 |
+
y,
|
| 27 |
+
cu_seqlens,
|
| 28 |
+
S,
|
| 29 |
+
D,
|
| 30 |
+
BD: tl.constexpr,
|
| 31 |
+
PADDING_SIDE: tl.constexpr,
|
| 32 |
+
PACK: tl.constexpr,
|
| 33 |
+
):
|
| 34 |
+
i_d, i_s, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
| 35 |
+
bos, eos = tl.load(cu_seqlens + i_b), tl.load(cu_seqlens + i_b + 1)
|
| 36 |
+
|
| 37 |
+
T = eos - bos
|
| 38 |
+
if PADDING_SIDE == 'left':
|
| 39 |
+
NP = S - T
|
| 40 |
+
if i_s < NP:
|
| 41 |
+
return
|
| 42 |
+
i_t = bos + (i_s - NP)
|
| 43 |
+
else:
|
| 44 |
+
if i_s >= T:
|
| 45 |
+
return
|
| 46 |
+
i_t = bos + i_s
|
| 47 |
+
|
| 48 |
+
o_d = i_d * BD + tl.arange(0, BD)
|
| 49 |
+
mask = o_d < D
|
| 50 |
+
|
| 51 |
+
if PACK:
|
| 52 |
+
b_x = tl.load(x + (i_b * S + i_s) * D + o_d, mask=mask)
|
| 53 |
+
tl.store(y + i_t * D + o_d, b_x, mask=mask)
|
| 54 |
+
else:
|
| 55 |
+
b_x = tl.load(x + i_t * D + o_d, mask=mask)
|
| 56 |
+
tl.store(y + (i_b * S + i_s) * D + o_d, b_x, mask=mask)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def pack_sequence_fwdbwd(
|
| 60 |
+
x: torch.Tensor,
|
| 61 |
+
cu_seqlens: torch.Tensor,
|
| 62 |
+
padding_side: str,
|
| 63 |
+
) -> torch.Tensor:
|
| 64 |
+
B, S = x.shape[:2]
|
| 65 |
+
D = x.numel() // (B * S)
|
| 66 |
+
BD = min(triton.next_power_of_2(D), 4096)
|
| 67 |
+
ND = triton.cdiv(D, BD)
|
| 68 |
+
|
| 69 |
+
y = torch.empty(cu_seqlens[-1].item(), *x.shape[2:], device=x.device, dtype=x.dtype)
|
| 70 |
+
packunpack_sequence_kernel[ND, S, B](
|
| 71 |
+
x=x,
|
| 72 |
+
y=y,
|
| 73 |
+
cu_seqlens=cu_seqlens,
|
| 74 |
+
S=S,
|
| 75 |
+
D=D,
|
| 76 |
+
BD=BD,
|
| 77 |
+
PADDING_SIDE=padding_side,
|
| 78 |
+
PACK=True,
|
| 79 |
+
)
|
| 80 |
+
return y
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def unpack_sequence_fwdbwd(
|
| 84 |
+
x: torch.Tensor,
|
| 85 |
+
cu_seqlens: torch.Tensor,
|
| 86 |
+
padding_side: str,
|
| 87 |
+
desired_shape: torch.Size,
|
| 88 |
+
) -> torch.Tensor:
|
| 89 |
+
if desired_shape is None:
|
| 90 |
+
desired_shape = (len(cu_seqlens) - 1, prepare_lens(cu_seqlens).max().item(), *x.shape[1:])
|
| 91 |
+
y = torch.zeros(desired_shape, device=x.device, dtype=x.dtype)
|
| 92 |
+
B, S = y.shape[:2]
|
| 93 |
+
D = y.numel() // (B * S)
|
| 94 |
+
BD = min(triton.next_power_of_2(D), 4096)
|
| 95 |
+
ND = triton.cdiv(D, BD)
|
| 96 |
+
|
| 97 |
+
packunpack_sequence_kernel[ND, S, B](
|
| 98 |
+
x=x,
|
| 99 |
+
y=y,
|
| 100 |
+
cu_seqlens=cu_seqlens,
|
| 101 |
+
S=S,
|
| 102 |
+
D=D,
|
| 103 |
+
BD=BD,
|
| 104 |
+
PADDING_SIDE=padding_side,
|
| 105 |
+
PACK=False,
|
| 106 |
+
)
|
| 107 |
+
return y
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class PackSequenceFunction(torch.autograd.Function):
|
| 111 |
+
|
| 112 |
+
@staticmethod
|
| 113 |
+
@input_guard
|
| 114 |
+
def forward(
|
| 115 |
+
ctx,
|
| 116 |
+
x: torch.Tensor,
|
| 117 |
+
cu_seqlens: torch.Tensor,
|
| 118 |
+
padding_side: str,
|
| 119 |
+
) -> torch.Tensor:
|
| 120 |
+
assert padding_side in ['left', 'right']
|
| 121 |
+
assert x.ndim >= 2
|
| 122 |
+
|
| 123 |
+
ctx.cu_seqlens = cu_seqlens
|
| 124 |
+
ctx.padding_side = padding_side
|
| 125 |
+
ctx.desired_shape = x.shape
|
| 126 |
+
|
| 127 |
+
y = pack_sequence_fwdbwd(
|
| 128 |
+
x=x,
|
| 129 |
+
cu_seqlens=cu_seqlens,
|
| 130 |
+
padding_side=padding_side,
|
| 131 |
+
)
|
| 132 |
+
return y
|
| 133 |
+
|
| 134 |
+
@staticmethod
|
| 135 |
+
@input_guard
|
| 136 |
+
def backward(ctx, dy: torch.Tensor) -> tuple[torch.Tensor | None]:
|
| 137 |
+
dx = unpack_sequence_fwdbwd(
|
| 138 |
+
x=dy,
|
| 139 |
+
cu_seqlens=ctx.cu_seqlens,
|
| 140 |
+
padding_side=ctx.padding_side,
|
| 141 |
+
desired_shape=ctx.desired_shape,
|
| 142 |
+
)
|
| 143 |
+
return dx, *[None] * 10
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
class UnpackSequenceFunction(torch.autograd.Function):
|
| 147 |
+
|
| 148 |
+
@staticmethod
|
| 149 |
+
@input_guard
|
| 150 |
+
def forward(
|
| 151 |
+
ctx,
|
| 152 |
+
x: torch.Tensor,
|
| 153 |
+
cu_seqlens: torch.Tensor,
|
| 154 |
+
padding_side: str,
|
| 155 |
+
desired_shape: Optional[torch.Size] = None,
|
| 156 |
+
) -> torch.Tensor:
|
| 157 |
+
assert padding_side in ['left', 'right']
|
| 158 |
+
assert x.ndim >= 2
|
| 159 |
+
if desired_shape is not None:
|
| 160 |
+
assert desired_shape[0] == cu_seqlens.shape[0] - 1
|
| 161 |
+
assert desired_shape[2:] == x.shape[1:]
|
| 162 |
+
|
| 163 |
+
ctx.cu_seqlens = cu_seqlens
|
| 164 |
+
ctx.padding_side = padding_side
|
| 165 |
+
|
| 166 |
+
y = unpack_sequence_fwdbwd(
|
| 167 |
+
x=x,
|
| 168 |
+
cu_seqlens=cu_seqlens,
|
| 169 |
+
padding_side=padding_side,
|
| 170 |
+
desired_shape=desired_shape,
|
| 171 |
+
)
|
| 172 |
+
return y
|
| 173 |
+
|
| 174 |
+
@staticmethod
|
| 175 |
+
@input_guard
|
| 176 |
+
def backward(ctx, dy: torch.Tensor) -> tuple[torch.Tensor | None]:
|
| 177 |
+
dx = pack_sequence_fwdbwd(
|
| 178 |
+
x=dy,
|
| 179 |
+
cu_seqlens=ctx.cu_seqlens,
|
| 180 |
+
padding_side=ctx.padding_side,
|
| 181 |
+
)
|
| 182 |
+
return dx, None, None, None
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def pack_sequence(
|
| 186 |
+
x: torch.Tensor,
|
| 187 |
+
cu_seqlens: torch.Tensor,
|
| 188 |
+
padding_side: str = 'left'
|
| 189 |
+
) -> torch.Tensor:
|
| 190 |
+
return PackSequenceFunction.apply(
|
| 191 |
+
x,
|
| 192 |
+
cu_seqlens,
|
| 193 |
+
padding_side,
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def unpack_sequence(
|
| 198 |
+
x: torch.Tensor,
|
| 199 |
+
cu_seqlens: torch.Tensor,
|
| 200 |
+
padding_side: str = 'left',
|
| 201 |
+
desired_shape: Optional[torch.Size] = None,
|
| 202 |
+
) -> torch.Tensor:
|
| 203 |
+
return UnpackSequenceFunction.apply(
|
| 204 |
+
x,
|
| 205 |
+
cu_seqlens,
|
| 206 |
+
padding_side,
|
| 207 |
+
desired_shape,
|
| 208 |
+
)
|
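A hypothetical usage example of `pack_sequence`/`unpack_sequence` (import path assumed from the file layout; shapes chosen for illustration):

```python
import torch
from fla3.ops.utils.pack import pack_sequence, unpack_sequence

# Two sequences of lengths 2 and 3, left-padded into a [B=2, S=4, D=3] batch
cu_seqlens = torch.tensor([0, 2, 5], dtype=torch.int32, device='cuda')
x = torch.randn(2, 4, 3, device='cuda')

packed = pack_sequence(x, cu_seqlens, padding_side='left')    # [5, 3]: valid tokens only
unpacked = unpack_sequence(packed, cu_seqlens, padding_side='left',
                           desired_shape=x.shape)             # [2, 4, 3], zeros in pad slots
```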
fla3/ops/utils/pooling.py
ADDED
|
@@ -0,0 +1,207 @@
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
| 3 |
+
|
| 4 |
+
from typing import Optional, Tuple
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import triton
|
| 8 |
+
import triton.language as tl
|
| 9 |
+
|
| 10 |
+
from ...ops.utils.index import prepare_chunk_indices
|
| 11 |
+
from ...utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@triton.heuristics({
|
| 15 |
+
'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
|
| 16 |
+
})
|
| 17 |
+
@triton.autotune(
|
| 18 |
+
configs=[
|
| 19 |
+
triton.Config({'BD': BD}, num_warps=num_warps)
|
| 20 |
+
for BD in [16, 32, 64, 128]
|
| 21 |
+
for num_warps in [1, 2, 4, 8]
|
| 22 |
+
],
|
| 23 |
+
key=['BT']
|
| 24 |
+
)
|
| 25 |
+
@triton.jit(do_not_specialize=['T'])
|
| 26 |
+
def mean_pooling_fwd_kernel(
|
| 27 |
+
x,
|
| 28 |
+
o,
|
| 29 |
+
cu_seqlens,
|
| 30 |
+
chunk_indices,
|
| 31 |
+
T,
|
| 32 |
+
H: tl.constexpr,
|
| 33 |
+
D: tl.constexpr,
|
| 34 |
+
BT: tl.constexpr,
|
| 35 |
+
BD: tl.constexpr,
|
| 36 |
+
IS_VARLEN: tl.constexpr
|
| 37 |
+
):
|
| 38 |
+
i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
| 39 |
+
i_b, i_h = i_bh // H, i_bh % H
|
| 40 |
+
if IS_VARLEN:
|
| 41 |
+
i_tg = i_t
|
| 42 |
+
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
|
| 43 |
+
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
|
| 44 |
+
T = eos - bos
|
| 45 |
+
NT = tl.cdiv(T, BT)
|
| 46 |
+
else:
|
| 47 |
+
NT = tl.cdiv(T, BT)
|
| 48 |
+
i_tg = i_b * NT + i_t
|
| 49 |
+
bos, eos = i_b * T, i_b * T + T
|
| 50 |
+
|
| 51 |
+
p_x = tl.make_block_ptr(x + (bos * H + i_h) * D, (T, D), (H*D, 1), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
|
| 52 |
+
p_o = tl.make_block_ptr(o + (i_tg * H + i_h) * D, (D,), (1,), (i_d * BD,), (BD,), (0,))
|
| 53 |
+
# [BT, BD]
|
| 54 |
+
b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)
|
| 55 |
+
# [BD]
|
| 56 |
+
b_o = tl.sum(b_x, axis=0) / min(BT, T - i_t * BT)
|
| 57 |
+
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@triton.heuristics({
|
| 61 |
+
'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
|
| 62 |
+
})
|
| 63 |
+
@triton.autotune(
|
| 64 |
+
configs=[
|
| 65 |
+
triton.Config({'BD': BD}, num_warps=num_warps)
|
| 66 |
+
for BD in [16, 32, 64, 128]
|
| 67 |
+
for num_warps in [1, 2, 4, 8]
|
| 68 |
+
],
|
| 69 |
+
key=['BT']
|
| 70 |
+
)
|
| 71 |
+
@triton.jit(do_not_specialize=['T'])
|
| 72 |
+
def mean_pooling_bwd_kernel(
|
| 73 |
+
do,
|
| 74 |
+
dx,
|
| 75 |
+
cu_seqlens,
|
| 76 |
+
chunk_indices,
|
| 77 |
+
T,
|
| 78 |
+
H: tl.constexpr,
|
| 79 |
+
D: tl.constexpr,
|
| 80 |
+
BT: tl.constexpr,
|
| 81 |
+
BD: tl.constexpr,
|
| 82 |
+
IS_VARLEN: tl.constexpr
|
| 83 |
+
):
|
| 84 |
+
i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
| 85 |
+
i_b, i_h = i_bh // H, i_bh % H
|
| 86 |
+
if IS_VARLEN:
|
| 87 |
+
i_tg = i_t
|
| 88 |
+
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
|
| 89 |
+
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
|
| 90 |
+
T = eos - bos
|
| 91 |
+
NT = tl.cdiv(T, BT)
|
| 92 |
+
else:
|
| 93 |
+
NT = tl.cdiv(T, BT)
|
| 94 |
+
i_tg = i_b * NT + i_t
|
| 95 |
+
bos, eos = i_b * T, i_b * T + T
|
| 96 |
+
|
| 97 |
+
p_dx = tl.make_block_ptr(dx + (bos * H + i_h) * D, (T, D), (H*D, 1), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
|
| 98 |
+
p_do = tl.make_block_ptr(do + (i_tg * H + i_h) * D, (D,), (1,), (i_d * BD,), (BD,), (0,))
|
| 99 |
+
# [BD]
|
| 100 |
+
b_do = tl.load(p_do, boundary_check=(0,)).to(tl.float32)
|
| 101 |
+
# [BT, BD]
|
| 102 |
+
b_dx = b_do / tl.full((BT,), min(BT, T - i_t * BT), dtype=tl.float32)[:, None]
|
| 103 |
+
tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), boundary_check=(0, 1))
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def mean_pooling_fwd(
|
| 107 |
+
x: torch.Tensor,
|
| 108 |
+
chunk_size: int,
|
| 109 |
+
cu_seqlens: Optional[torch.LongTensor] = None
|
| 110 |
+
) -> torch.Tensor:
|
| 111 |
+
B, T, H, D = x.shape
|
| 112 |
+
BT = chunk_size
|
| 113 |
+
chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
|
| 114 |
+
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
| 115 |
+
|
| 116 |
+
o = x.new_empty(B, NT, H, D)
|
| 117 |
+
def grid(meta): return (triton.cdiv(D, meta['BD']), NT, B * H)
|
| 118 |
+
mean_pooling_fwd_kernel[grid](
|
| 119 |
+
x,
|
| 120 |
+
o,
|
| 121 |
+
cu_seqlens,
|
| 122 |
+
chunk_indices,
|
| 123 |
+
T=T,
|
| 124 |
+
H=H,
|
| 125 |
+
D=D,
|
| 126 |
+
BT=BT,
|
| 127 |
+
)
|
| 128 |
+
return o
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def mean_pooling_bwd(
|
| 132 |
+
do: torch.Tensor,
|
| 133 |
+
batch_size: int,
|
| 134 |
+
seq_len: int,
|
| 135 |
+
chunk_size: int,
|
| 136 |
+
cu_seqlens: Optional[torch.LongTensor] = None
|
| 137 |
+
) -> torch.Tensor:
|
| 138 |
+
B, T, H, D = batch_size, seq_len, *do.shape[-2:]
|
| 139 |
+
BT = chunk_size
|
| 140 |
+
chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
|
| 141 |
+
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
| 142 |
+
|
| 143 |
+
dx = do.new_empty(B, T, H, D)
|
| 144 |
+
def grid(meta): return (triton.cdiv(D, meta['BD']), NT, B * H)
|
| 145 |
+
mean_pooling_bwd_kernel[grid](
|
| 146 |
+
do,
|
| 147 |
+
dx,
|
| 148 |
+
cu_seqlens,
|
| 149 |
+
chunk_indices,
|
| 150 |
+
T=T,
|
| 151 |
+
H=H,
|
| 152 |
+
D=D,
|
| 153 |
+
BT=BT,
|
| 154 |
+
)
|
| 155 |
+
return dx
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
class MeanPoolingFunction(torch.autograd.Function):
|
| 159 |
+
|
| 160 |
+
@staticmethod
|
| 161 |
+
@input_guard
|
| 162 |
+
@autocast_custom_fwd
|
| 163 |
+
def forward(
|
| 164 |
+
ctx,
|
| 165 |
+
x: torch.Tensor,
|
| 166 |
+
chunk_size: int,
|
| 167 |
+
cu_seqlens: Optional[torch.LongTensor] = None
|
| 168 |
+
) -> torch.Tensor:
|
| 169 |
+
o = mean_pooling_fwd(x, chunk_size, cu_seqlens)
|
| 170 |
+
ctx.batch_size = x.shape[0]
|
| 171 |
+
ctx.seq_len = x.shape[1]
|
| 172 |
+
ctx.chunk_size = chunk_size
|
| 173 |
+
ctx.cu_seqlens = cu_seqlens
|
| 174 |
+
return o
|
| 175 |
+
|
| 176 |
+
@staticmethod
|
| 177 |
+
@input_guard
|
| 178 |
+
@autocast_custom_bwd
|
| 179 |
+
def backward(
|
| 180 |
+
ctx, do
|
| 181 |
+
) -> Tuple[torch.Tensor, None, None]:
|
| 182 |
+
batch_size = ctx.batch_size
|
| 183 |
+
seq_len = ctx.seq_len
|
| 184 |
+
chunk_size = ctx.chunk_size
|
| 185 |
+
cu_seqlens = ctx.cu_seqlens
|
| 186 |
+
dx = mean_pooling_bwd(do, batch_size, seq_len, chunk_size, cu_seqlens)
|
| 187 |
+
return dx, None, None
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def mean_pooling(
|
| 191 |
+
x: torch.Tensor,
|
| 192 |
+
chunk_size: int,
|
| 193 |
+
cu_seqlens: Optional[torch.LongTensor] = None,
|
| 194 |
+
head_first: bool = False
|
| 195 |
+
) -> torch.Tensor:
|
| 196 |
+
if head_first:
|
| 197 |
+
x = x.transpose(1, 2)
|
| 198 |
+
if cu_seqlens is not None:
|
| 199 |
+
if x.shape[0] != 1:
|
| 200 |
+
raise ValueError(
|
| 201 |
+
f"The batch size is expected to be 1 rather than {x.shape[0]} when using `cu_seqlens`."
|
| 202 |
+
f"Please ..tten variable-length inputs before processing."
|
| 203 |
+
)
|
| 204 |
+
o = MeanPoolingFunction.apply(x, chunk_size, cu_seqlens)
|
| 205 |
+
if head_first:
|
| 206 |
+
o = o.transpose(1, 2)
|
| 207 |
+
return o
|
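For the dense case with `T` divisible by `chunk_size`, the forward pass above reduces to a per-chunk mean over the time dimension; a minimal reference sketch under that assumption:

```python
import torch

def mean_pooling_ref(x: torch.Tensor, chunk_size: int) -> torch.Tensor:
    # x: [B, T, H, D] with T % chunk_size == 0  ->  [B, T // chunk_size, H, D]
    B, T, H, D = x.shape
    return x.view(B, T // chunk_size, chunk_size, H, D).float().mean(dim=2).to(x.dtype)
```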
fla3/ops/utils/softmax.py
ADDED
|
@@ -0,0 +1,111 @@
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) 2023-2024, Songlin Yang, Yu Zhang
|
| 3 |
+
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import triton
|
| 8 |
+
import triton.language as tl
|
| 9 |
+
|
| 10 |
+
from ...ops.utils.op import exp
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@triton.autotune(
|
| 14 |
+
configs=[
|
| 15 |
+
triton.Config({}, num_warps=1),
|
| 16 |
+
triton.Config({}, num_warps=2),
|
| 17 |
+
triton.Config({}, num_warps=4),
|
| 18 |
+
triton.Config({}, num_warps=8),
|
| 19 |
+
triton.Config({}, num_warps=16),
|
| 20 |
+
triton.Config({}, num_warps=32)
|
| 21 |
+
],
|
| 22 |
+
key=['D']
|
| 23 |
+
)
|
| 24 |
+
@triton.jit
|
| 25 |
+
def softmax_fwd_kernel(
|
| 26 |
+
x,
|
| 27 |
+
p,
|
| 28 |
+
D: tl.constexpr,
|
| 29 |
+
B: tl.constexpr
|
| 30 |
+
):
|
| 31 |
+
i_n = tl.program_id(0)
|
| 32 |
+
o_d = tl.arange(0, B)
|
| 33 |
+
m_d = o_d < D
|
| 34 |
+
|
| 35 |
+
b_x = tl.load(x + i_n * D + o_d, mask=m_d, other=-float('inf'))
|
| 36 |
+
b_m = tl.max(b_x, 0)
|
| 37 |
+
b_x = exp(b_x - b_m)
|
| 38 |
+
b_p = b_x / tl.sum(b_x, 0)
|
| 39 |
+
|
| 40 |
+
tl.store(p + i_n * D + o_d, b_p.to(p.dtype.element_ty), mask=m_d)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@triton.autotune(
|
| 44 |
+
configs=[
|
| 45 |
+
triton.Config({}, num_warps=1),
|
| 46 |
+
triton.Config({}, num_warps=2),
|
| 47 |
+
triton.Config({}, num_warps=4),
|
| 48 |
+
triton.Config({}, num_warps=8),
|
| 49 |
+
triton.Config({}, num_warps=16),
|
| 50 |
+
triton.Config({}, num_warps=32)
|
| 51 |
+
],
|
| 52 |
+
key=['D']
|
| 53 |
+
)
|
| 54 |
+
@triton.jit
|
| 55 |
+
def softmax_bwd_kernel(
|
| 56 |
+
p,
|
| 57 |
+
dp,
|
| 58 |
+
ds,
|
| 59 |
+
D: tl.constexpr,
|
| 60 |
+
B: tl.constexpr
|
| 61 |
+
):
|
| 62 |
+
i_n = tl.program_id(0)
|
| 63 |
+
o_d = tl.arange(0, B)
|
| 64 |
+
m_d = o_d < D
|
| 65 |
+
|
| 66 |
+
b_p = tl.load(p + i_n * D + o_d, mask=m_d, other=0.)
|
| 67 |
+
b_dp = tl.load(dp + i_n * D + o_d, mask=m_d, other=0.)
|
| 68 |
+
b_pp = tl.sum(b_p * b_dp, 0)
|
| 69 |
+
b_ds = b_p * b_dp - b_p * b_pp
|
| 70 |
+
tl.store(ds + i_n * D + o_d, b_ds.to(ds.dtype.element_ty), mask=m_d)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def softmax_fwd(
|
| 74 |
+
x: torch.Tensor,
|
| 75 |
+
dtype: Optional[torch.dtype] = torch.float
|
| 76 |
+
) -> torch.Tensor:
|
| 77 |
+
shape = x.shape
|
| 78 |
+
x = x.view(-1, x.shape[-1])
|
| 79 |
+
|
| 80 |
+
N, D = x.shape
|
| 81 |
+
B = triton.next_power_of_2(D)
|
| 82 |
+
|
| 83 |
+
p = torch.empty_like(x, dtype=dtype)
|
| 84 |
+
softmax_fwd_kernel[(N,)](
|
| 85 |
+
x=x,
|
| 86 |
+
p=p,
|
| 87 |
+
D=D,
|
| 88 |
+
B=B
|
| 89 |
+
)
|
| 90 |
+
return p.view(*shape)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def softmax_bwd(
|
| 94 |
+
p: torch.Tensor,
|
| 95 |
+
dp: torch.Tensor,
|
| 96 |
+
dtype: Optional[torch.dtype] = torch.float
|
| 97 |
+
) -> torch.Tensor:
|
| 98 |
+
shape = p.shape
|
| 99 |
+
p = p.view(-1, p.shape[-1])
|
| 100 |
+
ds = torch.empty_like(p, dtype=dtype)
|
| 101 |
+
|
| 102 |
+
N, D = p.shape
|
| 103 |
+
B = triton.next_power_of_2(D)
|
| 104 |
+
softmax_bwd_kernel[(N,)](
|
| 105 |
+
p=p,
|
| 106 |
+
dp=dp,
|
| 107 |
+
ds=ds,
|
| 108 |
+
D=D,
|
| 109 |
+
B=B
|
| 110 |
+
)
|
| 111 |
+
return ds.view(*shape)
|
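A hedged sanity check against PyTorch: the forward matches a row-wise softmax over the last dimension, and the backward applies the usual softmax Jacobian-vector product (import path assumed from the file layout):

```python
import torch
from fla3.ops.utils.softmax import softmax_fwd, softmax_bwd

x = torch.randn(8, 100, device='cuda', dtype=torch.bfloat16)
p = softmax_fwd(x)                                   # fp32 output by default
dp = torch.randn_like(p)
ds = softmax_bwd(p, dp)

p_ref = torch.softmax(x.float(), dim=-1)
ds_ref = p * (dp - (p * dp).sum(-1, keepdim=True))   # p * (dp - <p, dp>)
```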
fla3/ops/utils/solve_tril.py
ADDED
|
@@ -0,0 +1,276 @@
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
| 3 |
+
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import triton
|
| 8 |
+
import triton.language as tl
|
| 9 |
+
|
| 10 |
+
from ...ops.utils.index import prepare_chunk_indices
|
| 11 |
+
from ...utils import input_guard
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@triton.heuristics({
|
| 15 |
+
'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
|
| 16 |
+
})
|
| 17 |
+
@triton.autotune(
|
| 18 |
+
configs=[
|
| 19 |
+
triton.Config({}, num_warps=num_warps, num_stages=num_stages)
|
| 20 |
+
for num_warps in [1, 2, 4, 8]
|
| 21 |
+
for num_stages in [2, 3, 4, 5]
|
| 22 |
+
],
|
| 23 |
+
key=['BT'],
|
| 24 |
+
)
|
| 25 |
+
@triton.jit(do_not_specialize=['T'])
|
| 26 |
+
def solve_tril_16x16_kernel(
|
| 27 |
+
A,
|
| 28 |
+
Ad,
|
| 29 |
+
cu_seqlens,
|
| 30 |
+
chunk_indices,
|
| 31 |
+
T,
|
| 32 |
+
H: tl.constexpr,
|
| 33 |
+
BT: tl.constexpr,
|
| 34 |
+
IS_VARLEN: tl.constexpr,
|
| 35 |
+
):
|
| 36 |
+
i_t, i_bh = tl.program_id(0), tl.program_id(1)
|
| 37 |
+
i_b, i_h = i_bh // H, i_bh % H
|
| 38 |
+
if IS_VARLEN:
|
| 39 |
+
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
|
| 40 |
+
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
|
| 41 |
+
T = eos - bos
|
| 42 |
+
else:
|
| 43 |
+
bos, eos = i_b * T, i_b * T + T
|
| 44 |
+
|
| 45 |
+
A = A + (bos*H + i_h) * BT
|
| 46 |
+
Ad = Ad + (bos*H + i_h) * 16
|
| 47 |
+
|
| 48 |
+
offset = (i_t * 16) % BT
|
| 49 |
+
p_A = tl.make_block_ptr(A, (T, BT), (H*BT, 1), (i_t * 16, offset), (16, 16), (1, 0))
|
| 50 |
+
p_Ai = tl.make_block_ptr(Ad, (T, 16), (H*16, 1), (i_t * 16, 0), (16, 16), (1, 0))
|
| 51 |
+
b_A = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32)
|
| 52 |
+
b_A = -tl.where(tl.arange(0, 16)[:, None] > tl.arange(0, 16)[None, :], b_A, 0)
|
| 53 |
+
|
| 54 |
+
o_i = tl.arange(0, 16)
|
| 55 |
+
for i in range(1, min(16, T-i_t*16)):
|
| 56 |
+
b_a = -tl.load(A + (i_t * 16 + i) * H*BT + o_i + offset)
|
| 57 |
+
b_a = b_a + tl.sum(b_a[:, None] * b_A, 0)
|
| 58 |
+
mask = o_i == i
|
| 59 |
+
b_A = tl.where(mask[:, None], b_a, b_A)
|
| 60 |
+
b_A += o_i[:, None] == o_i[None, :]
|
| 61 |
+
tl.store(p_Ai, b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
@triton.heuristics({
|
| 65 |
+
'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
|
| 66 |
+
})
|
| 67 |
+
@triton.autotune(
|
| 68 |
+
configs=[
|
| 69 |
+
triton.Config({}, num_warps=num_warps, num_stages=num_stages)
|
| 70 |
+
for num_warps in [1, 2, 4, 8]
|
| 71 |
+
for num_stages in [2, 3, 4, 5]
|
| 72 |
+
],
|
| 73 |
+
key=['H', 'BT', 'IS_VARLEN'],
|
| 74 |
+
)
|
| 75 |
+
@triton.jit(do_not_specialize=['T'])
|
| 76 |
+
def merge_16x16_to_32x32_inverse_kernel(
|
| 77 |
+
A,
|
| 78 |
+
Ad,
|
| 79 |
+
Ai,
|
| 80 |
+
cu_seqlens,
|
| 81 |
+
chunk_indices,
|
| 82 |
+
T,
|
| 83 |
+
H: tl.constexpr,
|
| 84 |
+
BT: tl.constexpr,
|
| 85 |
+
IS_VARLEN: tl.constexpr
|
| 86 |
+
):
|
| 87 |
+
i_t, i_bh = tl.program_id(0), tl.program_id(1)
|
| 88 |
+
i_b, i_h = i_bh // H, i_bh % H
|
| 89 |
+
if IS_VARLEN:
|
| 90 |
+
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
|
| 91 |
+
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
|
| 92 |
+
T = eos - bos
|
| 93 |
+
else:
|
| 94 |
+
bos, eos = i_b * T, i_b * T + T
|
| 95 |
+
|
| 96 |
+
A += (bos*H + i_h) * 32
|
| 97 |
+
Ad += (bos*H + i_h) * 16
|
| 98 |
+
Ai += (bos*H + i_h) * 32
|
| 99 |
+
|
| 100 |
+
p_A_21 = tl.make_block_ptr(A, (T, 32), (H*32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0))
|
| 101 |
+
p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (H*16, 1), (i_t * 32, 0), (16, 16), (1, 0))
|
| 102 |
+
p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (H*16, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0))
|
| 103 |
+
p_Ai_11 = tl.make_block_ptr(Ai, (T, 32), (H*32, 1), (i_t * 32, 0), (16, 16), (1, 0))
|
| 104 |
+
p_Ai_22 = tl.make_block_ptr(Ai, (T, 32), (H*32, 1), (i_t * 32 + 16, 16), (16, 16), (1, 0))
|
| 105 |
+
p_Ai_21 = tl.make_block_ptr(Ai, (T, 32), (H*32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0))
|
| 106 |
+
|
| 107 |
+
A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
|
| 108 |
+
Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32)
|
| 109 |
+
Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32)
|
| 110 |
+
Ai_21 = -tl.dot(tl.dot(Ai_22, A_21, input_precision='ieee'), Ai_11, input_precision='ieee')
|
| 111 |
+
tl.store(p_Ai_11, Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
|
| 112 |
+
tl.store(p_Ai_22, Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
|
| 113 |
+
tl.store(p_Ai_21, Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
@triton.heuristics({
|
| 117 |
+
'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
|
| 118 |
+
})
|
| 119 |
+
@triton.autotune(
|
| 120 |
+
configs=[
|
| 121 |
+
triton.Config({}, num_warps=num_warps, num_stages=num_stages)
|
| 122 |
+
for num_warps in [2, 4, 8]
|
| 123 |
+
for num_stages in [2, 3, 4, 5]
|
| 124 |
+
],
|
| 125 |
+
key=['H', 'BT', 'IS_VARLEN'],
|
| 126 |
+
)
|
| 127 |
+
@triton.jit(do_not_specialize=['T'])
|
| 128 |
+
def merge_16x16_to_64x64_inverse_kernel(
|
| 129 |
+
A,
|
| 130 |
+
Ad,
|
| 131 |
+
Ai,
|
| 132 |
+
cu_seqlens,
|
| 133 |
+
chunk_indices,
|
| 134 |
+
T,
|
| 135 |
+
H: tl.constexpr,
|
| 136 |
+
BT: tl.constexpr,
|
| 137 |
+
IS_VARLEN: tl.constexpr
|
| 138 |
+
):
|
| 139 |
+
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    A += (bos*H + i_h) * 64
    Ad += (bos*H + i_h) * 16
    Ai += (bos*H + i_h) * 64

    p_A_21 = tl.make_block_ptr(A, (T, 64), (H*64, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0))
    p_A_32 = tl.make_block_ptr(A, (T, 64), (H*64, 1), (i_t * 64 + 32, 16), (16, 16), (1, 0))
    p_A_31 = tl.make_block_ptr(A, (T, 64), (H*64, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0))
    p_A_43 = tl.make_block_ptr(A, (T, 64), (H*64, 1), (i_t * 64 + 48, 32), (16, 16), (1, 0))
    p_A_42 = tl.make_block_ptr(A, (T, 64), (H*64, 1), (i_t * 64 + 48, 16), (16, 16), (1, 0))
    p_A_41 = tl.make_block_ptr(A, (T, 64), (H*64, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0))
    p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (H*16, 1), (i_t * 64, 0), (16, 16), (1, 0))
    p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (H*16, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0))
    p_Ad_33 = tl.make_block_ptr(Ad, (T, 16), (H*16, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0))
    p_Ad_44 = tl.make_block_ptr(Ad, (T, 16), (H*16, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0))

    A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
    A_32 = tl.load(p_A_32, boundary_check=(0, 1)).to(tl.float32)
    A_31 = tl.load(p_A_31, boundary_check=(0, 1)).to(tl.float32)
    A_43 = tl.load(p_A_43, boundary_check=(0, 1)).to(tl.float32)
    A_42 = tl.load(p_A_42, boundary_check=(0, 1)).to(tl.float32)
    A_41 = tl.load(p_A_41, boundary_check=(0, 1)).to(tl.float32)

    Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32)
    Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32)
    Ai_33 = tl.load(p_Ad_33, boundary_check=(0, 1)).to(tl.float32)
    Ai_44 = tl.load(p_Ad_44, boundary_check=(0, 1)).to(tl.float32)

    Ai_21 = -tl.dot(tl.dot(Ai_22, A_21, input_precision='ieee'), Ai_11, input_precision='ieee')
    Ai_32 = -tl.dot(tl.dot(Ai_33, A_32, input_precision='ieee'), Ai_22, input_precision='ieee')
    Ai_43 = -tl.dot(tl.dot(Ai_44, A_43, input_precision='ieee'), Ai_33, input_precision='ieee')

    Ai_31 = -tl.dot(
        Ai_33,
        tl.dot(A_31, Ai_11, input_precision='ieee') +
        tl.dot(A_32, Ai_21, input_precision='ieee'),
        input_precision='ieee'
    )
    Ai_42 = -tl.dot(
        Ai_44,
        tl.dot(A_42, Ai_22, input_precision='ieee') +
        tl.dot(A_43, Ai_32, input_precision='ieee'),
        input_precision='ieee'
    )
    Ai_41 = -tl.dot(
        Ai_44,
        tl.dot(A_41, Ai_11, input_precision='ieee') +
        tl.dot(A_42, Ai_21, input_precision='ieee') +
        tl.dot(A_43, Ai_31, input_precision='ieee'),
        input_precision='ieee'
    )

    p_Ai_11 = tl.make_block_ptr(Ai, (T, 64), (H*64, 1), (i_t * 64, 0), (16, 16), (1, 0))
    p_Ai_22 = tl.make_block_ptr(Ai, (T, 64), (H*64, 1), (i_t * 64 + 16, 16), (16, 16), (1, 0))
    p_Ai_33 = tl.make_block_ptr(Ai, (T, 64), (H*64, 1), (i_t * 64 + 32, 32), (16, 16), (1, 0))
    p_Ai_44 = tl.make_block_ptr(Ai, (T, 64), (H*64, 1), (i_t * 64 + 48, 48), (16, 16), (1, 0))
    p_Ai_21 = tl.make_block_ptr(Ai, (T, 64), (H*64, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0))
    p_Ai_31 = tl.make_block_ptr(Ai, (T, 64), (H*64, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0))
    p_Ai_32 = tl.make_block_ptr(Ai, (T, 64), (H*64, 1), (i_t * 64 + 32, 16), (16, 16), (1, 0))
    p_Ai_41 = tl.make_block_ptr(Ai, (T, 64), (H*64, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0))
    p_Ai_42 = tl.make_block_ptr(Ai, (T, 64), (H*64, 1), (i_t * 64 + 48, 16), (16, 16), (1, 0))
    p_Ai_43 = tl.make_block_ptr(Ai, (T, 64), (H*64, 1), (i_t * 64 + 48, 32), (16, 16), (1, 0))
    tl.store(p_Ai_11, Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai_22, Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai_33, Ai_33.to(p_Ai_33.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai_44, Ai_44.to(p_Ai_44.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai_21, Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai_31, Ai_31.to(p_Ai_31.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai_32, Ai_32.to(p_Ai_32.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai_41, Ai_41.to(p_Ai_41.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai_42, Ai_42.to(p_Ai_42.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai_43, Ai_43.to(p_Ai_43.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
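
# NOTE (illustrative sketch, not part of solve_tril.py): the off-diagonal inverse
# blocks above follow the block lower-triangular identity
#     [[L11, 0], [L21, L22]]^-1 = [[L11^-1, 0], [-L22^-1 @ L21 @ L11^-1, L22^-1]],
# which is what the Ai_21/Ai_31/... tl.dot expressions compute. A quick NumPy check
# of the 2x2-block case with assumed 16x16 blocks:
#
#     import numpy as np
#     rng = np.random.default_rng(0)
#     L11 = np.tril(rng.standard_normal((16, 16)), -1) + np.eye(16)  # unit lower triangular
#     L22 = np.tril(rng.standard_normal((16, 16)), -1) + np.eye(16)
#     L21 = rng.standard_normal((16, 16))
#     Ai11, Ai22 = np.linalg.inv(L11), np.linalg.inv(L22)
#     Ai21 = -Ai22 @ L21 @ Ai11                                      # mirrors Ai_21 above
#     L = np.block([[L11, np.zeros((16, 16))], [L21, L22]])
#     Li = np.block([[Ai11, np.zeros((16, 16))], [Ai21, Ai22]])
#     assert np.allclose(L @ Li, np.eye(32))
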
@input_guard
def solve_tril(
    A: torch.Tensor,
    cu_seqlens: Optional[torch.Tensor] = None,
    output_dtype: torch.dtype = torch.float
) -> torch.Tensor:
    """
    Compute the inverse of the unit lower triangular matrix (I + A) for each chunk.
    A should be strictly lower triangular within each chunk, i.e., A.triu() == 0 block-wise.

    Args:
        A (torch.Tensor):
            [B, T, H, BT], where BT is the chunk size and must be 16, 32 or 64.
        cu_seqlens (torch.Tensor):
            The cumulative sequence lengths of the input tensor.
            Default: None.
        output_dtype (torch.dtype):
            The dtype of the output tensor. Default: `torch.float`

    Returns:
        (I + A)^-1 with the same shape as A.
    """
    assert A.shape[-1] in [16, 32, 64]

    B, T, H, BT = A.shape
    Ad = torch.empty(B, T, H, 16, device=A.device, dtype=torch.float if BT != 16 else output_dtype)

    chunk_indices = prepare_chunk_indices(cu_seqlens, 16) if cu_seqlens is not None else None
    NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, 16)
    solve_tril_16x16_kernel[NT, B * H](
        A=A,
        Ad=Ad,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        BT=BT,
    )
    if BT == 16:
        return Ad

    Ai = torch.zeros(B, T, H, BT, device=A.device, dtype=output_dtype)
    merge_fn = merge_16x16_to_32x32_inverse_kernel if BT == 32 else merge_16x16_to_64x64_inverse_kernel
    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT)
    merge_fn[NT, B * H](
        A=A,
        Ad=Ad,
        Ai=Ai,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        BT=BT,
    )
    return Ai
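A minimal usage sketch of solve_tril follows; the shapes, random input, and round-trip check are illustrative assumptions, and it requires a CUDA device since the kernels are Triton.

import torch

from fla3.ops.utils.solve_tril import solve_tril

B, T, H, BT = 1, 128, 2, 64
# build a strictly lower-triangular 64x64 block per (batch, chunk, head), then flatten to [B, T, H, BT]
blk = (torch.randn(B, T // BT, H, BT, BT, device='cuda') / BT).tril(-1)
A = blk.permute(0, 1, 3, 2, 4).reshape(B, T, H, BT)
Ai = solve_tril(A, output_dtype=torch.float)              # (I + A)^-1, same shape as A

# illustrative round-trip check per chunk: (I + A) @ Ai should be close to I
Ai_blk = Ai.reshape(B, T // BT, BT, H, BT).permute(0, 1, 3, 2, 4)
eye = torch.eye(BT, device=A.device)
err = ((eye + blk) @ Ai_blk - eye).abs().max()
print(f'max |(I + A) @ Ai - I| = {err.item():.2e}')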
flame/__init__.py ADDED
File without changes

flame/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (167 Bytes)
flame/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (167 Bytes)
flame/__pycache__/data.cpython-310.pyc ADDED
Binary file (8.17 kB)
flame/__pycache__/data.cpython-312.pyc ADDED
Binary file (14.9 kB)
flame/__pycache__/logging.cpython-310.pyc ADDED
Binary file (3.56 kB)
flame/__pycache__/logging.cpython-312.pyc ADDED
Binary file (6.44 kB)
flame/__pycache__/parser.cpython-310.pyc ADDED
Binary file (2.89 kB)
flame/__pycache__/parser.cpython-312.pyc ADDED
Binary file (4.07 kB)
flame/data.py ADDED
@@ -0,0 +1,246 @@

# -*- coding: utf-8 -*-

from __future__ import annotations

from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Union

import numpy as np
import torch
from datasets import Dataset, IterableDataset
from flame.logging import get_logger
from transformers import PreTrainedTokenizer

logger = get_logger(__name__)


class HuggingfaceDataset(IterableDataset):

    def __init__(
        self,
        dataset: Dataset,
        tokenizer: PreTrainedTokenizer,
        context_len: int = 2048,
        rank: int = 0,
        world_size: int = 1,
        buffer_size: int = 1024
    ) -> HuggingfaceDataset:

        self.dataset = dataset
        self.tokenizer = tokenizer

        self.data = dataset.shard(world_size, rank)
        self.context_len = context_len
        self.rank = rank
        self.world_size = world_size
        self.buffer_size = buffer_size

        if tokenizer.vocab_size < torch.iinfo(torch.int16).max:
            self.dtype = torch.int16
        elif tokenizer.vocab_size < torch.iinfo(torch.int32).max:
            self.dtype = torch.int32
        else:
            self.dtype = torch.int64
        self.states = None
        self.buffer = torch.tensor([], dtype=self.dtype)
        self.tokens = []
        self.rand_id = 0
        self.token_id = 0
        self.rng_state = None
        self._epoch = 0

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self._epoch + self.rank)
        if self.rng_state is not None:
            g.set_state(self.rng_state)

        rand_it = self.randint(0, self.buffer_size, g=g)
        if self.states is not None:
            self.data.load_state_dict(self.states)

        # max number of tokens allowed in the chunk buffer
        n_tokens = self.buffer_size * self.context_len

        while True:
            for sample in self.tokenize(self.data):
                # keep appending the samples to the token buffer
                self.tokens += sample
                # if the token buffer is full, start sampling
                # NOTE: we first convert the token ids to a tensor of shape [n_chunks, context_len] for efficiency
                if len(self.buffer) == 0 and len(self.tokens) >= n_tokens:
                    self.buffer = torch.tensor(self.tokens[:n_tokens], dtype=self.dtype).view(self.buffer_size, -1)
                    self.tokens = self.tokens[n_tokens:]
                if len(self.buffer) == self.buffer_size:
                    yield from self.sample(rand_it)

            n_chunks = len(self.tokens) // self.context_len
            # handle the left tokens in the buffer
            if n_chunks > 0:
                n_tokens = n_chunks * self.context_len
                indices = torch.randperm(n_chunks, generator=g).tolist()
                self.buffer = torch.tensor(self.tokens[:n_tokens], dtype=torch.long).view(n_chunks, -1)
                self.tokens = self.tokens[n_tokens:]
                for i in indices:
                    yield {'input_ids': self.buffer[i]}

    def tokenize(self, data, batch_size: int = 64):
        texts, states = [], []
        for sample in data:
            texts.append(sample['text'])
            states.append(self.data.state_dict())
            if len(texts) == batch_size:
                for s, tokenized in zip(states, self.tokenizer(texts, return_attention_mask=False)['input_ids']):
                    self.states = s
                    yield tokenized
                texts, states = [], []
        if len(texts) > 0:
            for s, tokenized in zip(states, self.tokenizer(texts, return_attention_mask=False)['input_ids']):
                self.states = s
                yield tokenized

    def sample(self, indices):
        n_tokens = (len(self.tokens) // self.context_len) * self.context_len
        while self.token_id < n_tokens:
            i = next(indices)
            start, end = self.token_id, self.token_id + self.context_len
            self.token_id += self.context_len
            yield {'input_ids': self.buffer[i].to(torch.long)}
            self.buffer[i] = torch.tensor(self.tokens[start:end], dtype=self.dtype)
        self.token_id = 0
        self.tokens = self.tokens[n_tokens:]

    def randint(
        self,
        low: int,
        high: int,
        batch_size: int = 1024,
        g: torch.Generator = torch.Generator()
    ) -> Iterable[int]:
        indices = torch.empty(batch_size, dtype=torch.long)
        while True:
            # record the generator states before sampling
            self.rng_state = g.get_state()
            indices = torch.randint(low, high, (batch_size,), out=indices, generator=g)
            for i in indices[self.rand_id:].tolist():
                self.rand_id += 1
                yield i
            self.rand_id = 0

    def set_epoch(self, epoch):
        self._epoch = epoch
        if hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)

    def state_dict(self):
        return {
            'states': self.states,
            'buffer': self.buffer.clone(),
            'tokens': deepcopy(self.tokens),
            'rand_id': self.rand_id,
            'token_id': self.token_id,
            'rng_state': self.rng_state,
            'epoch': self._epoch
        }

    def load_state_dict(self, state_dict):
        self.states = state_dict['states']
        self.buffer = state_dict['buffer'].clone()
        self.tokens = deepcopy(state_dict['tokens'])
        self.rand_id = state_dict['rand_id']
        self.token_id = state_dict['token_id']
        self.rng_state = state_dict['rng_state'].clone() if state_dict['rng_state'] is not None else None
        self._epoch = state_dict['epoch']


@dataclass
class DataCollatorForLanguageModeling:
    """
    Data collator used for language modeling.

    Args:
        tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
            The tokenizer used for encoding the data.
        varlen (`bool`):
            Whether to return sequences with variable lengths.
            If `True`, the offsets indicating the start and end of each sequence will be returned.
            For example, if the sequence lengths are `[4, 8, 12]`,
            the returned `input_ids` will be a long flattened tensor of shape `[1, 24]`, with `offsets` being `[0, 4, 12, 24]`.
            If `False`, the `input_ids` with shape `[batch_size, seq_len]` will be returned directly.
        return_tensors (`str`):
            The type of Tensor to return. Allowable values are "pt".
    """

    tokenizer: PreTrainedTokenizer
    varlen: bool = False
    return_tensors: str = "pt"

    def __call__(
        self,
        examples: List[Union[List[int], Dict[str, Any]]]
    ) -> Dict[str, Any]:
        if not isinstance(examples[0], Dict):
            examples = [{'input_ids': example} for example in examples]

        def tensorize(example: Dict[str, Any]) -> Dict[str, Any]:
            tensorized = {}
            for key in ['input_ids', 'offsets']:
                if key not in example:
                    continue
                if isinstance(example[key], List):
                    tensorized[key] = torch.tensor(example[key], dtype=torch.long)
                elif isinstance(example[key], np.ndarray):
                    tensorized[key] = torch.from_numpy(example[key])
                else:
                    tensorized[key] = example[key]
            return tensorized

        examples = list(map(tensorize, examples))

        if not self.varlen:
            length_of_first = examples[0]['input_ids'].size(0)
            # Check if padding is necessary.
            if all(example['input_ids'].size(0) == length_of_first for example in examples):
                batch = {
                    'input_ids': torch.stack([example['input_ids'] for example in examples], dim=0),
                }
            else:
                # If yes, check if we have a `pad_token`.
                if self.tokenizer._pad_token is None:
                    raise ValueError(
                        f"You are attempting to pad samples but the tokenizer you are using "
                        f"({self.tokenizer.__class__.__name__}) does not have a pad token."
                    )
                batch = self.tokenizer.pad(examples, return_tensors=self.return_tensors, return_attention_mask=False)
        else:
            if len(examples) > 1:
                raise ValueError("The batch size must be 1 for variable length inputs.")
            batch = {
                'input_ids': torch.cat([example['input_ids'] for example in examples], dim=0).unsqueeze(0)
            }
            if 'offsets' in examples[0]:
                batch['offsets'] = torch.cat([example['offsets'] for example in examples], dim=0).unsqueeze(0)
            else:
                # determine boundaries by bos/eos positions
                if self.tokenizer.add_bos_token:
                    offsets = []
                    if batch['input_ids'][0, 0] != self.tokenizer.bos_token_id:
                        offsets.append(torch.tensor([0], dtype=torch.long))
                    offsets.append(torch.where(batch['input_ids'].eq(self.tokenizer.bos_token_id))[1])
                    offsets.append(torch.tensor([len(batch['input_ids'][0])], dtype=torch.long))
                    batch['offsets'] = torch.cat(offsets, dim=0)
                elif self.tokenizer.add_eos_token:
                    offsets = [torch.tensor([0], dtype=torch.long)]
                    offsets.append(torch.where(batch['input_ids'].eq(self.tokenizer.eos_token_id))[1] + 1)
                    if batch['input_ids'][0, -1] != self.tokenizer.eos_token_id:
                        offsets.append(torch.tensor([len(batch['input_ids'][0])], dtype=torch.long))
                    batch['offsets'] = torch.cat(offsets, dim=0)
                else:
                    raise ValueError("You must allow the tokenizer to add either a bos or eos token as separators.")

        labels = batch['input_ids'].clone()
        if self.tokenizer.pad_token_id is not None:
            labels[labels == self.tokenizer.pad_token_id] = -100
        batch["labels"] = labels
        return batch
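A minimal sketch of wiring the two classes above together; the tokenizer name reuses the default from flame/parser.py below, while the dataset name, buffer size, and batch size are illustrative placeholders, not taken from this repository.

import torch
from datasets import load_dataset
from transformers import AutoTokenizer

from flame.data import DataCollatorForLanguageModeling, HuggingfaceDataset

tokenizer = AutoTokenizer.from_pretrained('fla-hub/gla-1.3B-100B')
stream = load_dataset('HuggingFaceFW/fineweb-edu', split='train', streaming=True)
# a small buffer keeps the first samples cheap; real training would use a much larger one
dataset = HuggingfaceDataset(stream, tokenizer, context_len=2048, rank=0, world_size=1, buffer_size=8)
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, varlen=False)

it = iter(dataset)
batch = collator([next(it) for _ in range(4)])
print(batch['input_ids'].shape, batch['labels'].shape)    # torch.Size([4, 2048]) for both

# varlen mode packs one long example and keeps explicit sequence boundaries,
# matching the [4, 8, 12] -> offsets [0, 4, 12, 24] example in the docstring above
varlen_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, varlen=True)
packed = {'input_ids': torch.arange(24), 'offsets': torch.tensor([0, 4, 12, 24])}
vbatch = varlen_collator([packed])
print(vbatch['input_ids'].shape, vbatch['offsets'])       # torch.Size([1, 24]), tensor([[ 0,  4, 12, 24]])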
flame/logging.py ADDED
@@ -0,0 +1,118 @@

# -*- coding: utf-8 -*-

import json
import logging
import os
import sys
import time

from transformers.trainer_callback import (ExportableState, TrainerCallback,
                                            TrainerControl, TrainerState)
from transformers.training_args import TrainingArguments


def get_logger(name: str = None) -> logging.Logger:
    formatter = logging.Formatter(
        fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
    )
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(formatter)

    logger = logging.getLogger(name)
    if 'RANK' in os.environ and int(os.environ['RANK']) == 0:
        logger.setLevel(logging.INFO)
        logger.addHandler(handler)

    return logger


logger = get_logger(__name__)

LOG_FILE_NAME = "trainer_log.jsonl"


class LogCallback(TrainerCallback, ExportableState):
    def __init__(self, start_time: float = None, elapsed_time: float = None):

        self.start_time = time.time() if start_time is None else start_time
        self.elapsed_time = 0 if elapsed_time is None else elapsed_time
        self.last_time = self.start_time

    def on_train_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs
    ):
        r"""
        Event called at the beginning of training.
        """
        if state.is_local_process_zero:
            if not args.resume_from_checkpoint:
                self.start_time = time.time()
                self.elapsed_time = 0
            else:
                self.start_time = state.stateful_callbacks['LogCallback']['start_time']
                self.elapsed_time = state.stateful_callbacks['LogCallback']['elapsed_time']

        if args.save_on_each_node:
            if not state.is_local_process_zero:
                return
        else:
            if not state.is_world_process_zero:
                return

        self.last_time = time.time()
        if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)) and args.overwrite_output_dir:
            logger.warning("Previous log file in this folder will be deleted.")
            os.remove(os.path.join(args.output_dir, LOG_FILE_NAME))

    def on_log(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        logs,
        **kwargs
    ):
        if args.save_on_each_node:
            if not state.is_local_process_zero:
                return
        else:
            if not state.is_world_process_zero:
                return

        self.elapsed_time += time.time() - self.last_time
        self.last_time = time.time()
        if 'num_input_tokens_seen' in logs:
            logs['num_tokens'] = logs.pop('num_input_tokens_seen')
            state.log_history[-1].pop('num_input_tokens_seen')
            throughput = logs['num_tokens'] / args.world_size / self.elapsed_time
            state.log_history[-1]['throughput'] = logs['throughput'] = throughput
        state.stateful_callbacks["LogCallback"] = self.state()

        logs = dict(
            current_steps=state.global_step,
            total_steps=state.max_steps,
            loss=state.log_history[-1].get("loss", None),
            eval_loss=state.log_history[-1].get("eval_loss", None),
            predict_loss=state.log_history[-1].get("predict_loss", None),
            learning_rate=state.log_history[-1].get("learning_rate", None),
            epoch=state.log_history[-1].get("epoch", None),
            percentage=round(state.global_step / state.max_steps * 100, 2) if state.max_steps != 0 else 100,
        )

        os.makedirs(args.output_dir, exist_ok=True)
        with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
            f.write(json.dumps(logs) + "\n")

    def state(self) -> dict:
        return {
            'start_time': self.start_time,
            'elapsed_time': self.elapsed_time
        }

    @classmethod
    def from_state(cls, state):
        return cls(state['start_time'], state['elapsed_time'])
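A minimal sketch of the rank-aware logging above; the RANK value is set here only for single-process illustration.

import os

from flame.logging import get_logger

os.environ.setdefault('RANK', '0')   # get_logger attaches a stdout handler only when RANK == 0
logger = get_logger('flame.example')
logger.info('visible on the main process only')
# The LogCallback class above is meant to be passed to a transformers Trainer via
# callbacks=[LogCallback()]; it appends per-step loss/throughput records to
# <output_dir>/trainer_log.jsonl.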
flame/parser.py ADDED
@@ -0,0 +1,94 @@

# -*- coding: utf-8 -*-

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Optional

import transformers
from transformers import HfArgumentParser, TrainingArguments

from flame.logging import get_logger

logger = get_logger(__name__)


@dataclass
class TrainingArguments(TrainingArguments):

    model_name_or_path: str = field(
        default=None,
        metadata={
            "help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."
        },
    )
    tokenizer: str = field(
        default="fla-hub/gla-1.3B-100B",
        metadata={"help": "Name of the tokenizer to use."}
    )
    use_fast_tokenizer: bool = field(
        default=False,
        metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."},
    )
    from_config: bool = field(
        default=True,
        metadata={"help": "Whether to initialize models from scratch."},
    )
    dataset: Optional[str] = field(
        default=None,
        metadata={"help": "The dataset(s) to use. Use commas to separate multiple datasets."},
    )
    dataset_name: Optional[str] = field(
        default=None,
        metadata={"help": "The name of provided dataset(s) to use."},
    )
    cache_dir: str = field(
        default=None,
        metadata={"help": "Path to the cached tokenized dataset."},
    )
    split: str = field(
        default="train",
        metadata={"help": "Which dataset split to use for training and evaluation."},
    )
    streaming: bool = field(
        default=False,
        metadata={"help": "Enable dataset streaming."},
    )
    hf_hub_token: Optional[str] = field(
        default=None,
        metadata={"help": "Auth token to log in with Hugging Face Hub."},
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the pre-processing."},
    )
    buffer_size: int = field(
        default=2048,
        metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."},
    )
    context_length: int = field(
        default=2048,
        metadata={"help": "The context length of the tokenized inputs in the dataset."},
    )
    varlen: bool = field(
        default=False,
        metadata={"help": "Enable training with variable length inputs."},
    )


def get_train_args():
    parser = HfArgumentParser(TrainingArguments)
    args, unknown_args = parser.parse_args_into_dataclasses(return_remaining_strings=True)

    if unknown_args:
        print(parser.format_help())
        print("Got unknown args, potentially deprecated arguments: {}".format(unknown_args))
        raise ValueError("Some specified arguments are not used by the HfArgumentParser: {}".format(unknown_args))

    if args.should_log:
        transformers.utils.logging.set_verbosity(args.get_process_log_level())
        transformers.utils.logging.enable_default_handler()
        transformers.utils.logging.enable_explicit_format()
    # set seeds manually
    transformers.set_seed(args.seed)
    return args
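A minimal sketch of driving the argument parser above from a script; the script name, output directory, and dataset value are placeholders, and any field of the TrainingArguments dataclass can be passed the same way.

import sys

from flame.parser import get_train_args

# placeholder command line, as a launcher such as torchrun would supply it
sys.argv = [
    'train.py',
    '--output_dir', 'exp/test',
    '--dataset', 'HuggingFaceFW/fineweb-edu',
    '--context_length', '2048',
]
args = get_train_args()
print(args.tokenizer, args.dataset, args.context_length, args.varlen)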