msj19 committed · Commit 50776d2 · verified · 1 Parent(s): 6d0130a

Add files using upload-large-folder tool

Files changed (50):
  1. README.md +188 -0
  2. fla3/ops/path_attn/__pycache__/parallel_path_bwd_inter_dkv.cpython-310.pyc +0 -0
  3. fla3/ops/path_attn/__pycache__/prepare_k_cache.cpython-310.pyc +0 -0
  4. fla3/ops/retention/__pycache__/parallel.cpython-312.pyc +0 -0
  5. fla3/ops/rwkv7/__pycache__/chunk.cpython-312.pyc +0 -0
  6. fla3/ops/rwkv7/__pycache__/fused_addcmul.cpython-310.pyc +0 -0
  7. fla3/ops/rwkv7/__pycache__/fused_k_update.cpython-310.pyc +0 -0
  8. fla3/ops/rwkv7/fused_recurrent.py +328 -0
  9. fla3/ops/simple_gla/__pycache__/chunk.cpython-312.pyc +0 -0
  10. fla3/ops/simple_gla/__pycache__/fused_recurrent.cpython-310.pyc +0 -0
  11. fla3/ops/simple_gla/__pycache__/parallel.cpython-310.pyc +0 -0
  12. fla3/ops/simple_gla/__pycache__/parallel.cpython-312.pyc +0 -0
  13. fla3/ops/simple_gla/fused_recurrent.py +108 -0
  14. fla3/ops/simple_gla/naive.py +54 -0
  15. fla3/ops/ttt/naive.py +126 -0
  16. fla3/ops/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  17. fla3/ops/utils/__pycache__/__init__.cpython-312.pyc +0 -0
  18. fla3/ops/utils/__pycache__/asm.cpython-310.pyc +0 -0
  19. fla3/ops/utils/__pycache__/asm.cpython-312.pyc +0 -0
  20. fla3/ops/utils/__pycache__/cumsum.cpython-310.pyc +0 -0
  21. fla3/ops/utils/__pycache__/cumsum.cpython-312.pyc +0 -0
  22. fla3/ops/utils/__pycache__/index.cpython-310.pyc +0 -0
  23. fla3/ops/utils/__pycache__/index.cpython-312.pyc +0 -0
  24. fla3/ops/utils/__pycache__/logcumsumexp.cpython-310.pyc +0 -0
  25. fla3/ops/utils/__pycache__/logsumexp.cpython-310.pyc +0 -0
  26. fla3/ops/utils/__pycache__/logsumexp.cpython-312.pyc +0 -0
  27. fla3/ops/utils/__pycache__/matmul.cpython-310.pyc +0 -0
  28. fla3/ops/utils/__pycache__/op.cpython-312.pyc +0 -0
  29. fla3/ops/utils/__pycache__/pooling.cpython-310.pyc +0 -0
  30. fla3/ops/utils/cumsum.py +414 -0
  31. fla3/ops/utils/index.py +83 -0
  32. fla3/ops/utils/logcumsumexp.py +52 -0
  33. fla3/ops/utils/matmul.py +245 -0
  34. fla3/ops/utils/op.py +39 -0
  35. fla3/ops/utils/pack.py +208 -0
  36. fla3/ops/utils/pooling.py +207 -0
  37. fla3/ops/utils/softmax.py +111 -0
  38. fla3/ops/utils/solve_tril.py +276 -0
  39. flame/__init__.py +0 -0
  40. flame/__pycache__/__init__.cpython-310.pyc +0 -0
  41. flame/__pycache__/__init__.cpython-312.pyc +0 -0
  42. flame/__pycache__/data.cpython-310.pyc +0 -0
  43. flame/__pycache__/data.cpython-312.pyc +0 -0
  44. flame/__pycache__/logging.cpython-310.pyc +0 -0
  45. flame/__pycache__/logging.cpython-312.pyc +0 -0
  46. flame/__pycache__/parser.cpython-310.pyc +0 -0
  47. flame/__pycache__/parser.cpython-312.pyc +0 -0
  48. flame/data.py +246 -0
  49. flame/logging.py +118 -0
  50. flame/parser.py +94 -0
README.md ADDED
@@ -0,0 +1,188 @@
+ <div align="center">
+
+ # 🔥 Flame: Flash Linear Attention Made Easy
+
+ </div>
+
+ > [!IMPORTANT]
+ > The `flame` project has been migrated to a new project built on torchtitan.
+ > Please visit the [new repository](https://github.com/fla-org/flame) for details and updates.
+ >
+ > The code here is now **archived as legacy**, and no future updates will be synchronized here.
+
+ A minimal framework for training FLA models, whether from scratch or through finetuning.
+
+ Built on the robust infrastructure of 🤗, `flame` enables you to train large language models with just a few lines of code:
+ we use `datasets` for data processing, `transformers` for model definitions, and `accelerate`[^1] for seamless distributed training.
+
+ In this README, we will guide you through the process of using `flame` to train GLA models.
+
+ ## Setup
+
+ To get started, you'll need to install the required packages.
+ Both `fla` and `flame` have minimal dependencies.
+ Clone the `fla` repository and install the necessary packages as follows:
+
+ ```bash
+ git clone https://github.com/sustcsonglin/flash-linear-attention.git
+ cd flash-linear-attention
+ pip install .
+ pip install accelerate
+ ```
+
+ > [!CAUTION]
+ > The 🤗 `tokenizers` have some [memory leak issues](https://github.com/huggingface/tokenizers/issues/1539) when processing very long documents.
+ > To address this, please ensure you install `tokenizers>=0.20.4`.
+
+ ## Preprocessing
+
+ Before training, you need to download and pre-tokenize your dataset.
+ We provide a straightforward script for this.
+ For instance, to tokenize a 10B-token sample of the `fineweb-edu` dataset, run:
+
+ ```bash
+ python preprocess.py \
+   --dataset HuggingFaceFW/fineweb-edu \
+   --name sample-10BT \
+   --split train \
+   --context_length 2048
+ ```
+
+ This will cache the processed dataset at `data/HuggingFaceFW/fineweb-edu/sample-10BT/train`.
+
+ GLA utilizes a subset of SlimPajama for pretraining [in the paper](https://proceedings.mlr.press/v235/yang24ab.html).
+ Given the size of the dataset, the fastest way to download it is using `git lfs` (refer to [this issue](https://huggingface.co/datasets/cerebras/SlimPajama-627B/discussions/2)).
+ ```bash
+ git lfs install
+ git clone https://huggingface.co/datasets/cerebras/SlimPajama-627B --depth 1
+ python preprocess.py \
+   --dataset SlimPajama-627B \
+   --split train \
+   --context_length 2048
+ ```
+
+ ## Training from scratch
+
+ To train a 340M model from scratch, execute the following command:
+
+ ```bash
+ bash train.sh \
+   type=gla \
+   lr=3e-4 \
+   scheduler=cosine_with_min_lr \
+   batch=32 \
+   update=1 \
+   warmup=1024 \
+   steps=20480 \
+   context=2048 \
+   gpus=8 \
+   nodes=1 \
+   path=exp/gla-340M-10B \
+   project=fla \
+   model=configs/gla_340M.json \
+   data=HuggingFaceFW/fineweb-edu \
+   name=sample-10BT \
+   cache=data/HuggingFaceFW/fineweb-edu/sample-10BT/train
+ ```
+
+ Key parameters:
+
+ | Parameter | Description                   | Default              |
+ | :-------- | :---------------------------- | :------------------- |
+ | lr        | `learning_rate`               | `3e-4`               |
+ | scheduler | `lr_scheduler_type`           | `cosine_with_min_lr` |
+ | batch     | `batch_size`                  | `32`                 |
+ | update    | `gradient_accumulation_steps` | `1`                  |
+ | context   | `context_length`              | `2048`               |
+ | gpus      | `num_gpus_per_node`           | `8`                  |
+ | nodes     | `num_nodes`                   | `1`                  |
+ | warmup    | `warmup_steps`                | `1024`               |
+ | steps     | `max_steps`                   | `20480`              |
+
+ The learning rate is set to `3e-4` by default, equipped with a cosine scheduler.
+ Other scheduler types like WSD (`warmup_stable_decay`)[^2] are also supported.
+
+ The total number of tokens processed per batch, referred to as `global_batch_size`, is calculated as
+ `batch_size × gradient_accumulation_steps × context_length × num_gpus_per_node × num_nodes`.
+ For instance, in the 340M model example, the `global_batch_size` comes to $32 \times 1 \times 2048 \times 8 \times 1 = 524,288$ (0.5M tokens).
+
+ The `warmup_steps` parameter indicates the number of steps for the learning rate warmup phase, while `max_steps` represents the maximum number of training steps.
+ Each step processes `global_batch_size` tokens.
+ Consequently, `warmup=1024` and `steps=20480` correspond to processing roughly 0.5B and 10B tokens, respectively.
+
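+ The token accounting above can be sanity-checked with a few lines of plain Python (illustrative variable names mirroring the table, not a `flame` API):
+
+ ```python
+ batch_size = 32
+ gradient_accumulation_steps = 1
+ context_length = 2048
+ num_gpus_per_node, num_nodes = 8, 1
+
+ global_batch_size = (batch_size * gradient_accumulation_steps
+                      * context_length * num_gpus_per_node * num_nodes)
+ assert global_batch_size == 524_288  # ~0.5M tokens per step
+
+ warmup_steps, max_steps = 1024, 20480
+ print(f"warmup: {warmup_steps * global_batch_size / 1e9:.2f}B tokens")  # ~0.54B
+ print(f"total:  {max_steps * global_batch_size / 1e9:.2f}B tokens")     # ~10.74B
+ ```
+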
+ :warning: Monitor the values of `global_batch_size`, `warmup_steps`, and `max_steps` carefully when modifying any of the hyperparameters!
+
+ `flame` also supports resuming interrupted training by specifying the checkpoint path.
+ Simply use the following command:
+
+ ```bash
+ bash train.sh \
+   type=gla \
+   lr=3e-4 \
+   steps=20480 \
+   batch=32 \
+   update=1 \
+   warmup=1024 \
+   context=2048 \
+   gpus=8 \
+   nodes=1 \
+   path=exp/gla-340M-10B \
+   project=fla \
+   model=configs/gla_340M.json \
+   data=HuggingFaceFW/fineweb-edu \
+   name=sample-10BT \
+   cache=data/HuggingFaceFW/fineweb-edu/sample-10BT/train \
+   checkpoint=exp/gla-340M-10B/checkpoint-8192
+ ```
+
+ You can also use `wandb` to monitor your training process effectively.
+
+ ![wandb](https://github.com/user-attachments/assets/05ca031c-1cae-41c9-bfcb-5b6b6d0df729)
+
+ ## Continual Pretraining
+
+ `flame` supports continual training from a pretrained checkpoint.
+ Below, we provide an example of how to finetune Mistral-7B to GLA.
+ You can follow similar steps to reproduce the results in the [GSA paper](https://arxiv.org/abs/2409.07146):
+
+ 1. Initialize a brand-new GLA-7B model from the config and copy the matched pretrained weights from Mistral-7B:
+ ```bash
+ cd ../utils
+ python convert_from_llama.py \
+   --model mistralai/Mistral-7B-v0.1 \
+   --config ../training/configs/gla_7B.json \
+   --output ../training/converted/gla-7B
+ cd -
+ ```
+
+ 2. Directly launch training from the converted checkpoint:
+ ```bash
+ bash train.sh \
+   type=gla \
+   lr=3e-5 \
+   steps=10240 \
+   batch=4 \
+   update=8 \
+   warmup=512 \
+   context=2048 \
+   path=exp/gla-7B-20B \
+   project=fla \
+   model=converted/gla-7B \
+   data=SlimPajama-627B \
+   cache=data/SlimPajama-627B/train
+ ```
+
+ Please be aware that finetuning on a single node may not be the most efficient approach.
+ If available, consider leveraging multi-node GPUs for optimal performance.
+ You can find guidance on how to launch a multi-node job in the [accelerate tutorial](https://github.com/huggingface/accelerate/blob/main/examples/slurm/submit_multinode.sh).
+
+ [^1]: The `accelerate` library supports various distributed frameworks, like `deepspeed` and `megatron`, for large-scale training. We use `deepspeed` in our case.
+ [^2]: https://arxiv.org/abs/2404.06395
fla3/ops/path_attn/__pycache__/parallel_path_bwd_inter_dkv.cpython-310.pyc ADDED
Binary file (5.47 kB). View file
 
fla3/ops/path_attn/__pycache__/prepare_k_cache.cpython-310.pyc ADDED
Binary file (2.26 kB). View file
 
fla3/ops/retention/__pycache__/parallel.cpython-312.pyc ADDED
Binary file (3.75 kB). View file
 
fla3/ops/rwkv7/__pycache__/chunk.cpython-312.pyc ADDED
Binary file (2.56 kB). View file
 
fla3/ops/rwkv7/__pycache__/fused_addcmul.cpython-310.pyc ADDED
Binary file (6.37 kB). View file
 
fla3/ops/rwkv7/__pycache__/fused_k_update.cpython-310.pyc ADDED
Binary file (3.93 kB). View file
 
fla3/ops/rwkv7/fused_recurrent.py ADDED
@@ -0,0 +1,328 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+ import warnings
+ from typing import Optional, Tuple
+
+ import torch
+ import triton
+ import triton.language as tl
+ from einops import rearrange
+
+ from fla.ops.generalized_delta_rule import fused_recurrent_dplr_delta_rule
+ from fla.ops.utils.op import exp
+ from fla.utils import input_guard, use_cuda_graph
+
+
+ @triton.heuristics({
+     'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
+     'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
+     'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
+ })
+ @triton.autotune(
+     configs=[
+         triton.Config({'BV': BV}, num_warps=num_warps, num_stages=num_stages)
+         for BV in [16, 32, 64]
+         for num_warps in [2, 4, 8, 16, 32]
+         for num_stages in [2, 3, 4]
+     ],
+     key=['BK'],
+     use_cuda_graph=use_cuda_graph,
+ )
+ @triton.jit(do_not_specialize=['T'])
+ def fused_recurrent_rwkv7_fwd_kernel(
+     r,
+     w,
+     k,
+     v,
+     kk,
+     a,
+     o,
+     h0,
+     ht,
+     cu_seqlens,
+     scale,
+     T,
+     B: tl.constexpr,
+     H: tl.constexpr,
+     K: tl.constexpr,
+     V: tl.constexpr,
+     BK: tl.constexpr,
+     BV: tl.constexpr,
+     REVERSE: tl.constexpr,
+     USE_INITIAL_STATE: tl.constexpr,
+     STORE_FINAL_STATE: tl.constexpr,
+     IS_VARLEN: tl.constexpr,
+     IS_DECODE: tl.constexpr,
+ ):
+     i_v, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64)
+     i_n, i_h = i_nh // H, i_nh % H
+
+     if IS_VARLEN:
+         bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
+         T = eos - bos
+     else:
+         bos, eos = i_n * T, i_n * T + T
+
+     o_k = tl.arange(0, BK)
+     o_v = i_v * BV + tl.arange(0, BV)
+     p_r = r + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
+     p_w = w + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
+     p_k = k + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
+     p_v = v + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v
+     p_a = a + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
+     p_kk = kk + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
+
+     p_o = o + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v
+
+     mask_k = o_k < K
+     mask_v = o_v < V
+     mask_h = mask_k[None, :] & mask_v[:, None]
+     b_h = tl.zeros([BV, BK], dtype=tl.float32)
+
+     if USE_INITIAL_STATE:
+         p_h0 = h0 + i_nh * K*V + o_k[None, :] * V + o_v[:, None]
+         b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
+
+     if IS_DECODE:
+         b_r = tl.load(p_r, mask=mask_k, other=0).to(tl.float32) * scale
+         b_w = tl.load(p_w, mask=mask_k, other=0).to(tl.float32)
+         b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
+         b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
+         b_a = tl.load(p_a, mask=mask_k, other=0).to(tl.float32)
+         b_kk = tl.load(p_kk, mask=mask_k, other=0).to(tl.float32)
+         b_act_a = -b_kk
+         b_b = b_kk * b_a
+
+         tmp = tl.sum(b_h * b_act_a[None, :], axis=1)
+         b_h = exp(b_w)[None, :] * b_h + (tmp[:, None] * b_b[None, :] + b_k[None, :] * b_v[:, None])
+         b_o = tl.sum(b_h * b_r[None, :], axis=1)
+
+         tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
+     else:
+         for _ in range(0, T):
+             b_r = tl.load(p_r, mask=mask_k, other=0).to(tl.float32) * scale
+             b_w = tl.load(p_w, mask=mask_k, other=0).to(tl.float32)
+             b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
+             b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
+             b_a = tl.load(p_a, mask=mask_k, other=0).to(tl.float32)
+             b_kk = tl.load(p_kk, mask=mask_k, other=0).to(tl.float32)
+             b_act_a = -b_kk
+             b_b = b_kk * b_a
+
+             tmp = tl.sum(b_h * b_act_a[None, :], axis=1)
+             b_h = exp(b_w)[None, :] * b_h + (tmp[:, None] * b_b[None, :] + b_k[None, :] * b_v[:, None])
+             b_o = tl.sum(b_h * b_r[None, :], axis=1)
+
+             tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
+             p_r += (-1 if REVERSE else 1) * H*K
+             p_w += (-1 if REVERSE else 1) * H*K
+             p_k += (-1 if REVERSE else 1) * H*K
+             p_v += (-1 if REVERSE else 1) * H*V
+             p_a += (-1 if REVERSE else 1) * H*K
+             p_kk += (-1 if REVERSE else 1) * H*K
+             p_o += (-1 if REVERSE else 1) * H*V
+
+     if STORE_FINAL_STATE:
+         p_ht = ht + i_nh * K*V + o_k[None, :] * V + o_v[:, None]
+         tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
+
+
+ @input_guard
+ def fused_recurrent_rwkv7_fwd(
+     r: torch.Tensor,
+     w: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     kk: torch.Tensor,
+     a: torch.Tensor,
+     scale: Optional[float] = 1.0,
+     initial_state: Optional[torch.Tensor] = None,
+     output_final_state: bool = False,
+     reverse: bool = False,
+     cu_seqlens: Optional[torch.LongTensor] = None,
+ ):
+     B, T, H, K, V = *k.shape, v.shape[-1]
+     N = B if cu_seqlens is None else len(cu_seqlens) - 1
+     BK = triton.next_power_of_2(K)
+     IS_DECODE = (T == 1)
+
+     h0 = initial_state
+     if not output_final_state:
+         ht = None
+     else:
+         ht = r.new_empty(N, H, K, V, dtype=torch.float32)
+     o = torch.empty_like(v)
+
+     def grid(meta): return (triton.cdiv(V, meta['BV']), N * H)
+     fused_recurrent_rwkv7_fwd_kernel[grid](
+         r,
+         w,
+         k,
+         v,
+         kk,
+         a,
+         o,
+         h0,
+         ht,
+         cu_seqlens,
+         scale,
+         T=T,
+         B=B,
+         H=H,
+         K=K,
+         V=V,
+         BK=BK,
+         REVERSE=reverse,
+         IS_DECODE=IS_DECODE
+     )
+     return o, ht
+
+
+ def fused_recurrent_rwkv7(
+     r: torch.Tensor,
+     w: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     a: torch.Tensor,
+     b: torch.Tensor,
+     scale: float = 1.0,
+     initial_state: torch.Tensor = None,
+     output_final_state: bool = True,
+     cu_seqlens: Optional[torch.LongTensor] = None,
+     head_first: bool = False,
+ ):
+     """
+     Args:
+         r (torch.Tensor):
+             r of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
+         w (torch.Tensor):
+             log decay of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
+         k (torch.Tensor):
+             k of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
+         v (torch.Tensor):
+             v of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
+         a (torch.Tensor):
+             a of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
+         b (torch.Tensor):
+             b of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
+         scale (float):
+             scale of the attention.
+         initial_state (torch.Tensor):
+             initial state of shape `[B, H, K, V]` if cu_seqlens is None else `[N, H, K, V]` where N = len(cu_seqlens) - 1.
+         output_final_state (bool):
+             whether to output the final state.
+         cu_seqlens (torch.LongTensor):
+             Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
+             consistent with the FlashAttention API.
+         head_first (bool):
+             whether to use head-first layout. Recommended to be False to avoid extra transposes.
+             Default: `False`.
+     """
+     return fused_recurrent_dplr_delta_rule(
+         q=r,
+         k=k,
+         v=v,
+         a=a,
+         b=b,
+         gk=w,
+         scale=scale,
+         initial_state=initial_state,
+         output_final_state=output_final_state,
+         cu_seqlens=cu_seqlens,
+         head_first=head_first,
+     )
+
+
+ def fused_mul_recurrent_rwkv7(
+     r: torch.Tensor,
+     w: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     kk: torch.Tensor,
+     a: torch.Tensor,
+     scale: Optional[float] = 1.0,
+     initial_state: Optional[torch.Tensor] = None,
+     output_final_state: bool = False,
+     reverse: bool = False,
+     cu_seqlens: Optional[torch.Tensor] = None,
+     head_first: bool = False,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     r"""
+     This function computes, step by step, the recurrence
+     `S_t = S_{t-1} * exp(w_t) + (S_{t-1} @ (-kk_t)) (kk_t * a_t)^T + v_t k_t^T`
+     (the update implemented by `fused_recurrent_rwkv7_fwd_kernel`).
+
+     Args:
+         r (torch.Tensor):
+             queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
+         w (torch.Tensor):
+             log decay of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. decay term in log space!
+         k (torch.Tensor):
+             keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
+         v (torch.Tensor):
+             values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
+         kk (torch.Tensor):
+             removal keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
+         a (torch.Tensor):
+             in-context learning rates of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
+         scale (Optional[float]):
+             Scale factor for the attention scores.
+             If not provided, it will default to `1 / sqrt(K)`. Default: `1`.
+         initial_state (Optional[torch.Tensor]):
+             Initial state of shape `[N, H, K, V]` for `N` input sequences.
+             For equal-length input sequences, `N` equals the batch size `B`.
+             Default: `None`.
+         output_final_state (Optional[bool]):
+             Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
+         reverse (Optional[bool]):
+             If `True`, process the state passing in reverse order. Default: `False`.
+         cu_seqlens (Optional[torch.Tensor]):
+             Cumulative sequence lengths of shape `[N + 1]` used for variable-length training,
+             consistent with the FlashAttention API.
+         head_first (Optional[bool]):
+             Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
+             Deprecated. Default: `False`.
+     """
+     if head_first:
+         warnings.warn(
+             "head_first is deprecated and will be removed in a future version. "
+             "Please use head_first=False instead.",
+             DeprecationWarning
+         )
+         r, w, k, v, kk, a = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (r, w, k, v, kk, a))
+     if not head_first and r.shape[1] < r.shape[2]:
+         warnings.warn(
+             f"Input tensor shape suggests potential format mismatch: seq_len ({r.shape[1]}) < num_heads ({r.shape[2]}). "
+             "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
+             "when head_first=False was specified. "
+             "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
+         )
+     if cu_seqlens is not None:
+         if r.shape[0] != 1:
+             raise ValueError(
+                 f"The batch size is expected to be 1 rather than {r.shape[0]} when using `cu_seqlens`. "
+                 f"Please flatten variable-length inputs before processing."
+             )
+         if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
+             raise ValueError(
+                 f"The number of initial states is expected to be equal to the number of input sequences, "
+                 f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
+             )
+     if scale is None:
+         scale = r.shape[-1] ** -0.5
+     else:
+         assert scale > 0, "scale must be positive"
+     o, final_state = fused_recurrent_rwkv7_fwd(
+         r,
+         w,
+         k,
+         v,
+         kk,
+         a,
+         scale,
+         initial_state,
+         output_final_state,
+         reverse,
+         cu_seqlens,
+     )
+     if head_first:
+         o = rearrange(o, 'b t h ... -> b h t ...')
+     return o, final_state
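To see what `fused_recurrent_rwkv7_fwd_kernel` computes per step, here is a hedged naive PyTorch reference (an illustrative function, not part of the uploaded module; it assumes seq-first `[B, T, H, ...]` inputs, forward direction, no varlen). It mirrors the kernel's update: decay the state by `exp(w_t)`, remove along `kk_t`, re-inject via `kk_t * a_t`, add `v_t k_t^T`, then read out with `r_t`:

```python
import torch

def naive_mul_recurrent_rwkv7(r, w, k, v, kk, a, scale=1.0):
    # r, w, k, kk, a: [B, T, H, K]; v: [B, T, H, V] (illustrative reference only)
    B, T, H, K = r.shape
    V = v.shape[-1]
    # state is [v, k]-indexed, matching b_h ([BV, BK]) in the kernel
    S = torch.zeros(B, H, V, K, dtype=torch.float32, device=r.device)
    o = torch.zeros(B, T, H, V, dtype=torch.float32, device=r.device)
    for t in range(T):
        r_t = r[:, t].float() * scale
        w_t, k_t, kk_t, a_t = (x[:, t].float() for x in (w, k, kk, a))
        v_t = v[:, t].float()
        removal = torch.einsum('bhvk,bhk->bhv', S, -kk_t)   # `tmp` in the kernel
        S = (S * w_t.exp().unsqueeze(-2)                    # per-key decay exp(w_t)
             + removal.unsqueeze(-1) * (kk_t * a_t).unsqueeze(-2)
             + v_t.unsqueeze(-1) * k_t.unsqueeze(-2))       # v_t k_t^T
        o[:, t] = torch.einsum('bhvk,bhk->bhv', S, r_t)     # readout with r_t
    return o.to(v.dtype), S
```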
fla3/ops/simple_gla/__pycache__/chunk.cpython-312.pyc ADDED
Binary file (10.6 kB). View file
 
fla3/ops/simple_gla/__pycache__/fused_recurrent.cpython-310.pyc ADDED
Binary file (4.13 kB). View file
 
fla3/ops/simple_gla/__pycache__/parallel.cpython-310.pyc ADDED
Binary file (17 kB). View file
 
fla3/ops/simple_gla/__pycache__/parallel.cpython-312.pyc ADDED
Binary file (35.6 kB). View file
 
fla3/ops/simple_gla/fused_recurrent.py ADDED
@@ -0,0 +1,108 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+ from typing import Optional, Tuple
+
+ import torch
+
+ from fla.ops.common.fused_recurrent import fused_recurrent
+
+
+ def fused_recurrent_simple_gla(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     g: torch.Tensor,
+     scale: Optional[float] = None,
+     initial_state: Optional[torch.Tensor] = None,
+     output_final_state: bool = False,
+     reverse: bool = False,
+     cu_seqlens: Optional[torch.LongTensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     r"""
+     Args:
+         q (torch.Tensor):
+             queries of shape `[B, T, H, K]`.
+         k (torch.Tensor):
+             keys of shape `[B, T, H, K]`.
+         v (torch.Tensor):
+             values of shape `[B, T, H, V]`.
+         g (torch.Tensor):
+             Forget gates of shape `[B, T, H]`.
+             Compared to GLA, the gating is head-wise instead of elementwise.
+         scale (Optional[float]):
+             Scale factor for the attention scores.
+             If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
+         initial_state (Optional[torch.Tensor]):
+             Initial state of shape `[N, H, K, V]` for `N` input sequences.
+             For equal-length input sequences, `N` equals the batch size `B`.
+             Default: `None`.
+         output_final_state (Optional[bool]):
+             Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
+         reverse (Optional[bool]):
+             If `True`, process the state passing in reverse order. Default: `False`.
+         cu_seqlens (torch.LongTensor):
+             Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
+             consistent with the FlashAttention API.
+
+     Returns:
+         o (torch.Tensor):
+             Outputs of shape `[B, T, H, V]`.
+         final_state (torch.Tensor):
+             Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
+
+     Examples::
+         >>> import torch
+         >>> import torch.nn.functional as F
+         >>> from einops import rearrange
+         >>> from fla.ops.simple_gla import fused_recurrent_simple_gla
+         # inputs with equal lengths
+         >>> B, T, H, K, V = 4, 2048, 4, 512, 512
+         >>> q = torch.randn(B, T, H, K, device='cuda')
+         >>> k = torch.randn(B, T, H, K, device='cuda')
+         >>> v = torch.randn(B, T, H, V, device='cuda')
+         >>> g = F.logsigmoid(torch.randn(B, T, H, device='cuda'))
+         >>> h0 = torch.randn(B, H, K, V, device='cuda')
+         >>> o, ht = fused_recurrent_simple_gla(
+             q, k, v, g,
+             initial_state=h0,
+             output_final_state=True
+         )
+         # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
+         >>> q, k, v = map(lambda x: rearrange(x, 'b t h d -> 1 (b t) h d'), (q, k, v))
+         >>> g = rearrange(g, 'b t h -> 1 (b t) h')
+         # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions is expected
+         >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
+         >>> o_var, ht_var = fused_recurrent_simple_gla(
+             q, k, v, g,
+             initial_state=h0,
+             output_final_state=True,
+             cu_seqlens=cu_seqlens
+         )
+         >>> assert o.allclose(o_var.view(o.shape))
+         >>> assert ht.allclose(ht_var)
+     """
+     if cu_seqlens is not None:
+         if q.shape[0] != 1:
+             raise ValueError(
+                 f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
+                 f"Please flatten variable-length inputs before processing."
+             )
+         if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
+             raise ValueError(
+                 f"The number of initial states is expected to be equal to the number of input sequences, "
+                 f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
+             )
+     if scale is None:
+         scale = k.shape[-1] ** -0.5
+     o, final_state = fused_recurrent(
+         q=q,
+         k=k,
+         v=v,
+         g=g,
+         scale=scale,
+         initial_state=initial_state,
+         output_final_state=output_final_state,
+         reverse=reverse,
+         cu_seqlens=cu_seqlens
+     )
+     return o, final_state
fla3/ops/simple_gla/naive.py ADDED
@@ -0,0 +1,54 @@
+ # -*- coding: utf-8 -*-
+
+ import torch
+ from einops import rearrange
+
+
+ def torch_simple_gla(q, k, v, g, chunk_size=64, scale=None):
+     if scale is None:
+         scale = (q.shape[-1] ** -0.5)
+     q = rearrange(q, 'b h (n c) d -> b h n c d', c=chunk_size) * scale
+     k = rearrange(k, 'b h (n c) d -> b h n c d', c=chunk_size)
+     v = rearrange(v, 'b h (n c) d -> b h n c d', c=chunk_size)
+     g = rearrange(g, 'b h (n c) -> b h n c', c=chunk_size)
+     g = g.cumsum(-1)
+     kv = k.transpose(-1, -2) @ (v * (-g + g[:, :, :, -1, None]).exp()[..., None])
+     S = torch.zeros_like(kv)
+
+     for i in range(1, g.shape[-2]):
+         S[:, :, i] = S[:, :, i-1].clone() * g[:, :, i-1, -1, None, None].exp() + kv[:, :, i-1]
+
+     inter = (q * g[..., None].exp()) @ S
+     attn = q @ k.transpose(-1, -2)
+     attn = attn * (g[..., None] - g[..., None, :]).exp()
+     attn = attn.masked_fill(torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), 0)
+     intra = attn @ v
+     o = inter + intra
+     return rearrange(o, 'b h n c d -> b h (n c) d')
+
+
+ def torch_simple_gla_recurrent(q, k, v, g, scale=None, initial_state=None, output_final_state=True):
+     B, H, T, DK = q.shape
+     original_dtype = q.dtype
+     q, k, v, g = q.float(), k.float(), v.float(), g.float()
+     if scale is None:
+         scale = DK ** -0.5
+     q = q * scale
+     _, _, _, DV = v.shape
+     if initial_state is None:
+         # allocate the state on the same device/dtype as the (float) inputs
+         S = torch.zeros(B, H, DK, DV, dtype=q.dtype, device=q.device)
+     else:
+         S = initial_state.to(q)
+     o = torch.zeros(B, H, T, DV).to(q)
+     for i in range(T):
+         gate = g[:, :, i].exp()
+         key = k[:, :, i]
+         value = v[:, :, i]
+         kv = key.unsqueeze(-1) * value.unsqueeze(-2)
+         S = S.clone() * gate.unsqueeze(-1).unsqueeze(-1) + kv
+         q_i = q[:, :, i, :]
+         o_i = (q_i.unsqueeze(-1) * S).sum(-2)
+         o[:, :, i] = o_i
+     if not output_final_state:
+         S = None
+     return o.to(original_dtype), S
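With the fused kernel (`fla3/ops/simple_gla/fused_recurrent.py` above) and this naive recurrence side by side, a quick consistency check is natural. A hedged sketch: it assumes the upstream `fla` package is installed for the fused path, that this `fla3` module is importable as uploaded, a CUDA device, and illustrative tolerances. Note the naive reference is head-first, so inputs are transposed:

```python
import torch
import torch.nn.functional as F
from fla.ops.simple_gla import fused_recurrent_simple_gla   # upstream package path
from fla3.ops.simple_gla.naive import torch_simple_gla_recurrent  # path as uploaded here

B, T, H, K, V = 2, 64, 2, 32, 32
q = torch.randn(B, T, H, K, device='cuda')
k = torch.randn(B, T, H, K, device='cuda')
v = torch.randn(B, T, H, V, device='cuda')
g = F.logsigmoid(torch.randn(B, T, H, device='cuda'))  # head-wise log forget gates

o_fused, _ = fused_recurrent_simple_gla(q, k, v, g)
# the naive reference expects head-first [B, H, T, ...] inputs
o_ref, _ = torch_simple_gla_recurrent(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), g.transpose(1, 2)
)
torch.testing.assert_close(o_fused, o_ref.transpose(1, 2), rtol=1e-3, atol=1e-3)
```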
fla3/ops/ttt/naive.py ADDED
@@ -0,0 +1,126 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang, Yuqi Pan
+
+ import torch
+ import torch.nn.functional as F
+
+
+ def ttt_linear(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     w: torch.Tensor,
+     b: torch.Tensor,
+     eta: torch.Tensor,
+     scale: float,
+     eps: float,
+     mini_batch_size: int,
+     initial_state: torch.Tensor,
+     initial_state_bias: torch.Tensor,
+     output_final_state: bool
+ ):
+     B, H, T, D = q.shape
+     BT = mini_batch_size
+     NT = T // BT
+     # scale queries out-of-place before taking the chunked views below,
+     # so the caller's tensor is not mutated
+     q = q * scale
+     # [NT, B, H, mini_batch_size, D]
+     _q = q.reshape(B, H, NT, BT, D).permute(2, 0, 1, 3, 4)
+     _k = k.reshape(B, H, NT, BT, D).permute(2, 0, 1, 3, 4)
+     _v = v.reshape(B, H, NT, BT, D).permute(2, 0, 1, 3, 4)
+     # [NT, B, H, BT, 1]
+     _eta = eta.reshape(B, H, NT, BT, 1).permute(2, 0, 1, 3, 4)
+     # [H, 1, D]
+     w = w.reshape(H, 1, D).to(torch.float32)
+     b = b.reshape(H, 1, D).to(torch.float32)
+
+     h = torch.zeros((B, H, D, D), device=v.device, dtype=torch.float32) if initial_state is None else initial_state
+     hb = torch.zeros((B, H, 1, D), device=v.device, dtype=torch.float32) if initial_state_bias is None else initial_state_bias
+     # [NT, B, H, BT, D]
+     o = torch.empty_like(_v)
+
+     for i in range(NT):
+         q_i, k_i, v_i, eta_i = [x[i] for x in [_q, _k, _v, _eta]]
+         kh = k_i @ h + hb
+         reconstruction_target = v_i - k_i
+
+         mean = kh.mean(-1, True)
+         var = kh.var(-1, unbiased=False, keepdim=True).to(torch.float32)
+         rstd = torch.sqrt(var + eps).to(torch.float32)
+         kh_hat = (kh - mean) / rstd
+
+         g = w * kh_hat + b - reconstruction_target
+         g *= w
+         v_new = (D * g - g.sum(-1, True) - kh_hat * (g * kh_hat).sum(-1, True)) / (rstd * D)
+
+         Attn = torch.tril(q_i @ k_i.transpose(-2, -1))
+         o_i = q_i @ h - (eta_i * Attn) @ v_new + hb - torch.tril(eta_i.expand_as(Attn)) @ v_new
+         h = h - (eta_i[:, :, -1, :, None] * k_i).transpose(-1, -2) @ v_new
+         hb = hb - torch.sum(eta_i[:, :, -1, :, None] * v_new, dim=-2, keepdim=True)
+
+         # layer norm with residuals
+         mean = o_i.mean(dim=-1, keepdim=True)
+         var = o_i.var(dim=-1, unbiased=False, keepdim=True).to(torch.float32)
+         rstd = torch.sqrt(var + eps).to(torch.float32)
+         o[i] = o_i + (o_i - mean) / rstd * w + b
+
+     # [B, H, T, D]
+     o = o.permute(1, 2, 0, 3, 4).reshape(B, H, T, D)
+     h = h if output_final_state else None
+     hb = hb if output_final_state else None
+     return o, h, hb
+
+
+ def chunk_ttt_linear_ref(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     w: torch.Tensor,
+     b: torch.Tensor,
+     eta: torch.Tensor,
+     scale: float = None,
+     eps: float = 1e-6,
+     mini_batch_size: int = 16,
+     initial_state: torch.Tensor = None,
+     initial_state_bias: torch.Tensor = None,
+     output_final_state: bool = False,
+     head_first: bool = False,
+ ):
+     assert q.dtype == k.dtype == v.dtype
+     assert k.shape[-1] == v.shape[-1], "The key and value dimension must be the same."
+     if isinstance(eta, float):
+         eta = torch.full_like(q[:, :, :, :1], eta)
+     if scale is None:
+         scale = k.shape[-1] ** -0.5
+     if not head_first:
+         q = q.transpose(1, 2)
+         k = k.transpose(1, 2)
+         v = v.transpose(1, 2)
+         eta = eta.transpose(1, 2)
+     T = q.shape[-2]
+     padded = (mini_batch_size - (T % mini_batch_size)) % mini_batch_size
+     if padded > 0:
+         q = F.pad(q, (0, 0, 0, padded))
+         k = F.pad(k, (0, 0, 0, padded))
+         v = F.pad(v, (0, 0, 0, padded))
+         eta = F.pad(eta, (0, 0, 0, padded))
+         eta[:, :, -1, :] = eta[:, :, -(padded+1), :]
+     assert q.shape[-2] % mini_batch_size == 0, "Sequence length should be a multiple of mini_batch_size."
+     q, k, v, eta, w, b = map(lambda x: x.to(torch.float32), [q, k, v, eta, w, b])
+     o, final_state, final_state_bias = ttt_linear(
+         q,
+         k,
+         v,
+         w,
+         b,
+         eta,
+         scale,
+         eps,
+         mini_batch_size,
+         initial_state,
+         initial_state_bias,
+         output_final_state,
+     )
+     o = o[:, :, :T, :].contiguous()
+     if not head_first:
+         o = o.transpose(1, 2)
+     return o, final_state, final_state_bias
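For orientation, a minimal smoke test of `chunk_ttt_linear_ref` above (a hedged sketch: shapes follow the signature, LayerNorm parameters are identity, and the `fla3.ops.ttt.naive` import path matches this upload):

```python
import torch

from fla3.ops.ttt.naive import chunk_ttt_linear_ref  # path as uploaded here

B, T, H, D = 2, 64, 4, 32
q, k, v = (torch.randn(B, T, H, D) for _ in range(3))
w = torch.ones(H, D)   # per-head LayerNorm weight (identity)
b = torch.zeros(H, D)  # per-head LayerNorm bias (identity)

o, h, hb = chunk_ttt_linear_ref(
    q, k, v, w, b, eta=0.01, mini_batch_size=16, output_final_state=True
)
print(o.shape, h.shape, hb.shape)
# torch.Size([2, 64, 4, 32]) torch.Size([2, 4, 32, 32]) torch.Size([2, 4, 1, 32])
```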
fla3/ops/utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.16 kB). View file
 
fla3/ops/utils/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (1.2 kB). View file
 
fla3/ops/utils/__pycache__/asm.cpython-310.pyc ADDED
Binary file (482 Bytes). View file
 
fla3/ops/utils/__pycache__/asm.cpython-312.pyc ADDED
Binary file (543 Bytes). View file
 
fla3/ops/utils/__pycache__/cumsum.cpython-310.pyc ADDED
Binary file (10.3 kB). View file
 
fla3/ops/utils/__pycache__/cumsum.cpython-312.pyc ADDED
Binary file (21.4 kB). View file
 
fla3/ops/utils/__pycache__/index.cpython-310.pyc ADDED
Binary file (3.12 kB). View file
 
fla3/ops/utils/__pycache__/index.cpython-312.pyc ADDED
Binary file (5.48 kB). View file
 
fla3/ops/utils/__pycache__/logcumsumexp.cpython-310.pyc ADDED
Binary file (1.54 kB). View file
 
fla3/ops/utils/__pycache__/logsumexp.cpython-310.pyc ADDED
Binary file (2.25 kB). View file
 
fla3/ops/utils/__pycache__/logsumexp.cpython-312.pyc ADDED
Binary file (3.66 kB). View file
 
fla3/ops/utils/__pycache__/matmul.cpython-310.pyc ADDED
Binary file (5.29 kB). View file
 
fla3/ops/utils/__pycache__/op.cpython-312.pyc ADDED
Binary file (1.56 kB). View file
 
fla3/ops/utils/__pycache__/pooling.cpython-310.pyc ADDED
Binary file (5.61 kB). View file
 
fla3/ops/utils/cumsum.py ADDED
@@ -0,0 +1,414 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+ import warnings
+ from typing import Optional
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ from ...ops.utils.index import prepare_chunk_indices
+ from ...utils import check_shared_mem, input_guard
+
+ BS_LIST = [32, 64] if check_shared_mem() else [16, 32]
+
+
+ @triton.heuristics({
+     'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
+ })
+ @triton.autotune(
+     configs=[
+         triton.Config({}, num_warps=num_warps)
+         for num_warps in [1, 2, 4, 8]
+     ],
+     key=['B', 'H', 'BT', 'IS_VARLEN', 'REVERSE']
+ )
+ @triton.jit(do_not_specialize=['T'])
+ def chunk_local_cumsum_scalar_kernel(
+     s,
+     o,
+     cu_seqlens,
+     chunk_indices,
+     T,
+     B: tl.constexpr,
+     H: tl.constexpr,
+     BT: tl.constexpr,
+     REVERSE: tl.constexpr,
+     IS_VARLEN: tl.constexpr,
+     HEAD_FIRST: tl.constexpr,
+ ):
+     i_t, i_bh = tl.program_id(0), tl.program_id(1)
+     i_b, i_h = i_bh // H, i_bh % H
+     if IS_VARLEN:
+         i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
+         bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
+         T = eos - bos
+     else:
+         bos, eos = i_b * T, i_b * T + T
+
+     if HEAD_FIRST:
+         p_s = tl.make_block_ptr(s + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+         p_o = tl.make_block_ptr(o + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+     else:
+         p_s = tl.make_block_ptr(s + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
+         p_o = tl.make_block_ptr(o + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
+     # [BT]
+     b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32)
+     b_o = tl.cumsum(b_s, axis=0)
+     if REVERSE:
+         b_z = tl.sum(b_s, axis=0)
+         b_o = -b_o + b_z[None] + b_s
+     tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))
+
+
+ @triton.heuristics({
+     'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
+ })
+ @triton.autotune(
+     configs=[
+         triton.Config({'BS': BS}, num_warps=num_warps)
+         for BS in BS_LIST
+         for num_warps in [2, 4, 8]
+     ],
+     key=['B', 'H', 'S', 'BT', 'IS_VARLEN', 'REVERSE']
+ )
+ @triton.jit(do_not_specialize=['T'])
+ def chunk_local_cumsum_vector_kernel(
+     s,
+     o,
+     cu_seqlens,
+     chunk_indices,
+     T,
+     B: tl.constexpr,
+     H: tl.constexpr,
+     S: tl.constexpr,
+     BT: tl.constexpr,
+     BS: tl.constexpr,
+     REVERSE: tl.constexpr,
+     IS_VARLEN: tl.constexpr,
+     HEAD_FIRST: tl.constexpr,
+ ):
+     i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+     i_b, i_h = i_bh // H, i_bh % H
+     if IS_VARLEN:
+         i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
+         bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
+         T = eos - bos
+     else:
+         bos, eos = i_b * T, i_b * T + T
+
+     o_i = tl.arange(0, BT)
+     if REVERSE:
+         m_s = tl.where(o_i[:, None] <= o_i[None, :], 1., 0.)
+     else:
+         m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)
+
+     if HEAD_FIRST:
+         p_s = tl.make_block_ptr(s + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
+         p_o = tl.make_block_ptr(o + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
+     else:
+         p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
+         p_o = tl.make_block_ptr(o + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
+     # [BT, BS]
+     b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
+     b_o = tl.dot(m_s, b_s, allow_tf32=False)
+     tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
+
+
+ @triton.heuristics({
+     'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
+ })
+ @triton.autotune(
+     configs=[
+         triton.Config({'BT': BT}, num_warps=num_warps, num_stages=num_stages)
+         for BT in [32, 64, 128, 256]
+         for num_warps in [2, 4, 8]
+         for num_stages in [1, 2, 3, 4]
+     ],
+     key=['B', 'H', 'IS_VARLEN', 'REVERSE']
+ )
+ @triton.jit(do_not_specialize=['T'])
+ def chunk_global_cumsum_scalar_kernel(
+     s,
+     o,
+     cu_seqlens,
+     T,
+     B: tl.constexpr,
+     H: tl.constexpr,
+     BT: tl.constexpr,
+     REVERSE: tl.constexpr,
+     IS_VARLEN: tl.constexpr,
+     HEAD_FIRST: tl.constexpr,
+ ):
+     i_nh = tl.program_id(0)
+     i_n, i_h = i_nh // H, i_nh % H
+     if IS_VARLEN:
+         bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
+     else:
+         bos, eos = i_n * T, i_n * T + T
+     T = eos - bos
+
+     b_z = tl.zeros([], dtype=tl.float32)
+     NT = tl.cdiv(T, BT)
+     for i_c in range(NT):
+         i_t = NT-1-i_c if REVERSE else i_c
+         if HEAD_FIRST:
+             p_s = tl.make_block_ptr(s + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+             p_o = tl.make_block_ptr(o + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+         else:
+             p_s = tl.make_block_ptr(s + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
+             p_o = tl.make_block_ptr(o + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
+         b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32)
+         b_o = tl.cumsum(b_s, axis=0)
+         b_ss = tl.sum(b_s, 0)
+         if REVERSE:
+             b_o = -b_o + b_ss + b_s
+         b_o += b_z
+         # carry the running total across chunks
+         b_z += b_ss
+         tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))
+
+
+ @triton.heuristics({
+     'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
+ })
+ @triton.autotune(
+     configs=[
+         triton.Config({'BT': BT}, num_warps=num_warps, num_stages=num_stages)
+         for BT in [16, 32, 64, 128]
+         for num_warps in [2, 4, 8]
+         for num_stages in [1, 2, 3, 4]
+     ],
+     key=['B', 'H', 'S', 'IS_VARLEN', 'REVERSE']
+ )
+ @triton.jit(do_not_specialize=['T'])
+ def chunk_global_cumsum_vector_kernel(
+     s,
+     z,
+     cu_seqlens,
+     T,
+     B: tl.constexpr,
+     H: tl.constexpr,
+     S: tl.constexpr,
+     BT: tl.constexpr,
+     BS: tl.constexpr,
+     REVERSE: tl.constexpr,
+     IS_VARLEN: tl.constexpr,
+     HEAD_FIRST: tl.constexpr,
+ ):
+     i_s, i_nh = tl.program_id(0), tl.program_id(1)
+     i_n, i_h = i_nh // H, i_nh % H
+     if IS_VARLEN:
+         bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
+     else:
+         bos, eos = i_n * T, i_n * T + T
+     T = eos - bos
+
+     o_i = tl.arange(0, BT)
+     if REVERSE:
+         m_s = tl.where(o_i[:, None] <= o_i[None, :], 1., 0.)
+     else:
+         m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)
+
+     b_z = tl.zeros([BS], dtype=tl.float32)
+     NT = tl.cdiv(T, BT)
+     for i_c in range(NT):
+         i_t = NT-1-i_c if REVERSE else i_c
+         if HEAD_FIRST:
+             p_s = tl.make_block_ptr(s + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
+             p_z = tl.make_block_ptr(z + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
+         else:
+             p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
+             p_z = tl.make_block_ptr(z + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
+         # [BT, BS]
+         b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
+         b_c = b_z[None, :] + tl.dot(m_s, b_s, allow_tf32=False)
+         tl.store(p_z, b_c.to(p_z.dtype.element_ty), boundary_check=(0, 1))
+         # carry the running total across chunks
+         b_z += tl.sum(b_s, 0)
+
+
+ def chunk_local_cumsum_scalar(
+     g: torch.Tensor,
+     chunk_size: int,
+     reverse: bool = False,
+     cu_seqlens: Optional[torch.Tensor] = None,
+     head_first: bool = False,
+     output_dtype: Optional[torch.dtype] = torch.float
+ ) -> torch.Tensor:
+     if head_first:
+         B, H, T = g.shape
+     else:
+         B, T, H = g.shape
+     assert chunk_size == 2**(chunk_size.bit_length()-1), "chunk_size must be a power of 2"
+     BT = chunk_size
+     chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
+     NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
+     g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
+     grid = (NT, B * H)
+     chunk_local_cumsum_scalar_kernel[grid](
+         g_org,
+         g,
+         cu_seqlens,
+         chunk_indices,
+         T=T,
+         B=B,
+         H=H,
+         BT=BT,
+         HEAD_FIRST=head_first,
+         REVERSE=reverse
+     )
+     return g
+
+
+ def chunk_local_cumsum_vector(
+     g: torch.Tensor,
+     chunk_size: int,
+     reverse: bool = False,
+     cu_seqlens: Optional[torch.Tensor] = None,
+     head_first: bool = False,
+     output_dtype: Optional[torch.dtype] = torch.float
+ ) -> torch.Tensor:
+     if head_first:
+         B, H, T, S = g.shape
+     else:
+         B, T, H, S = g.shape
+     BT = chunk_size
+     chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
+     NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
+     assert chunk_size == 2**(chunk_size.bit_length()-1), "chunk_size must be a power of 2"
+
+     g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
+
+     def grid(meta): return (triton.cdiv(meta['S'], meta['BS']), NT, B * H)
+     # keep the cumulative normalizer in fp32
+     # this kernel is equivalent to
+     # g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1)
+     chunk_local_cumsum_vector_kernel[grid](
+         g_org,
+         g,
+         cu_seqlens,
+         chunk_indices,
+         T=T,
+         B=B,
+         H=H,
+         S=S,
+         BT=BT,
+         HEAD_FIRST=head_first,
+         REVERSE=reverse
+     )
+     return g
+
+
+ @input_guard
+ def chunk_global_cumsum_scalar(
+     s: torch.Tensor,
+     reverse: bool = False,
+     cu_seqlens: Optional[torch.Tensor] = None,
+     head_first: bool = False,
+     output_dtype: Optional[torch.dtype] = torch.float
+ ) -> torch.Tensor:
+     if head_first:
+         B, H, T = s.shape
+     else:
+         B, T, H = s.shape
+     N = len(cu_seqlens) - 1 if cu_seqlens is not None else B
+
+     z = torch.empty_like(s, dtype=output_dtype or s.dtype)
+     grid = (N * H,)
+     chunk_global_cumsum_scalar_kernel[grid](
+         s,
+         z,
+         cu_seqlens,
+         T=T,
+         B=B,
+         H=H,
+         HEAD_FIRST=head_first,
+         REVERSE=reverse
+     )
+     return z
+
+
+ @input_guard
+ def chunk_global_cumsum_vector(
+     s: torch.Tensor,
+     reverse: bool = False,
+     cu_seqlens: Optional[torch.Tensor] = None,
+     head_first: bool = False,
+     output_dtype: Optional[torch.dtype] = torch.float
+ ) -> torch.Tensor:
+     if head_first:
+         B, H, T, S = s.shape
+     else:
+         B, T, H, S = s.shape
+     N = len(cu_seqlens) - 1 if cu_seqlens is not None else B
+     BS = min(32, triton.next_power_of_2(S))
+
+     z = torch.empty_like(s, dtype=output_dtype or s.dtype)
+     grid = (triton.cdiv(S, BS), N * H)
+     chunk_global_cumsum_vector_kernel[grid](
+         s,
+         z,
+         cu_seqlens,
+         T=T,
+         B=B,
+         H=H,
+         S=S,
+         BS=BS,
+         HEAD_FIRST=head_first,
+         REVERSE=reverse
+     )
+     return z
+
+
+ @input_guard
+ def chunk_global_cumsum(
+     s: torch.Tensor,
+     reverse: bool = False,
+     cu_seqlens: Optional[torch.Tensor] = None,
+     head_first: bool = False,
+     output_dtype: Optional[torch.dtype] = torch.float
+ ) -> torch.Tensor:
+     if cu_seqlens is not None:
+         assert s.shape[0] == 1, "Only batch size 1 is supported when cu_seqlens are provided"
+     if len(s.shape) == 3:
+         return chunk_global_cumsum_scalar(s, reverse, cu_seqlens, head_first, output_dtype)
+     elif len(s.shape) == 4:
+         return chunk_global_cumsum_vector(s, reverse, cu_seqlens, head_first, output_dtype)
+     else:
+         raise ValueError(
+             f"Unsupported input shape {s.shape}, "
+             f"which should be [B, T, H]/[B, T, H, D] if `head_first=False` "
+             f"or [B, H, T]/[B, H, T, D] otherwise"
+         )
+
+
+ @input_guard
+ def chunk_local_cumsum(
+     g: torch.Tensor,
+     chunk_size: int,
+     reverse: bool = False,
+     cu_seqlens: Optional[torch.Tensor] = None,
+     head_first: bool = False,
+     output_dtype: Optional[torch.dtype] = torch.float,
+     **kwargs
+ ) -> torch.Tensor:
+     if not head_first and g.shape[1] < g.shape[2]:
+         warnings.warn(
+             f"Input tensor shape suggests potential format mismatch: seq_len ({g.shape[1]}) < num_heads ({g.shape[2]}). "
+             "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
+             "when head_first=False was specified. "
+             "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
+         )
+     if cu_seqlens is not None:
+         assert g.shape[0] == 1, "Only batch size 1 is supported when cu_seqlens are provided"
+     if len(g.shape) == 3:
+         return chunk_local_cumsum_scalar(g, chunk_size, reverse, cu_seqlens, head_first, output_dtype)
+     elif len(g.shape) == 4:
+         return chunk_local_cumsum_vector(g, chunk_size, reverse, cu_seqlens, head_first, output_dtype)
+     else:
+         raise ValueError(
+             f"Unsupported input shape {g.shape}, "
+             f"which should be (B, T, H, D) if `head_first=False` "
+             f"or (B, H, T, D) otherwise"
+         )
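As the inline comment in `chunk_local_cumsum_vector` notes, these kernels compute a cumulative sum that restarts at every chunk. A hedged pure-PyTorch reference for the simplest case (scalar gates, `head_first=False`, no varlen, no reverse; `ref_chunk_local_cumsum` is an illustrative name, not part of the module):

```python
import torch

def ref_chunk_local_cumsum(g: torch.Tensor, chunk_size: int) -> torch.Tensor:
    # g: [B, T, H]; inclusive cumsum restarted at every chunk of length `chunk_size`
    B, T, H = g.shape
    assert T % chunk_size == 0, "reference assumes T divisible by chunk_size"
    return (
        g.float()
        .view(B, T // chunk_size, chunk_size, H)
        .cumsum(dim=2)
        .view(B, T, H)
    )

g = torch.randn(2, 128, 4)
out = ref_chunk_local_cumsum(g, chunk_size=64)
```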
fla3/ops/utils/index.py ADDED
@@ -0,0 +1,83 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+ import torch
+ import torch.nn.functional as F
+ import triton
+ import triton.language as tl
+
+ from ...utils import tensor_cache
+
+
+ @triton.autotune(
+     configs=[
+         triton.Config({}, num_warps=num_warps)
+         for num_warps in [4, 8, 16, 32]
+     ],
+     key=['B'],
+ )
+ @triton.jit
+ def prepare_position_ids_kernel(
+     y,
+     cu_seqlens,
+     B: tl.constexpr
+ ):
+     i_n = tl.program_id(0)
+     bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
+     T = eos - bos
+
+     o = tl.arange(0, B)
+     for i in range(0, tl.cdiv(T, B) * B, B):
+         o_i = o + i
+         tl.store(y + bos + o_i, o_i, o_i < T)
+
+
+ @tensor_cache
+ def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
+     return cu_seqlens[1:] - cu_seqlens[:-1]
+
+
+ @tensor_cache
+ def prepare_lens_from_mask(mask: torch.BoolTensor) -> torch.LongTensor:
+     return mask.sum(dim=-1, dtype=torch.int32)
+
+
+ @tensor_cache
+ def prepare_cu_seqlens_from_mask(mask: torch.BoolTensor, out_dtype: torch.dtype = torch.int32) -> torch.LongTensor:
+     return F.pad(prepare_lens_from_mask(mask).cumsum(dim=0, dtype=out_dtype), (1, 0))
+
+
+ @tensor_cache
+ def prepare_position_ids(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
+     return torch.cat([
+         torch.arange(n, dtype=cu_seqlens.dtype, device=cu_seqlens.device)
+         for n in prepare_lens(cu_seqlens).unbind()
+     ])
+
+
+ @tensor_cache
+ def prepare_sequence_ids(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
+     return prepare_position_ids(cu_seqlens).eq(0).cumsum(0) - 1
+
+
+ @tensor_cache
+ def prepare_token_indices(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
+     position_ids = prepare_position_ids(cu_seqlens)
+     return torch.stack([prepare_sequence_ids(cu_seqlens), position_ids], 1).to(cu_seqlens)
+
+
+ @tensor_cache
+ def prepare_chunk_indices(
+     cu_seqlens: torch.LongTensor,
+     chunk_size: int
+ ) -> torch.LongTensor:
+     indices = torch.cat([torch.arange(n) for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()])
+     return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens)
+
+
+ @tensor_cache
+ def prepare_chunk_offsets(
+     cu_seqlens: torch.LongTensor,
+     chunk_size: int
+ ) -> torch.LongTensor:
+     return torch.cat([cu_seqlens.new_tensor([0]), triton.cdiv(prepare_lens(cu_seqlens), chunk_size)]).cumsum(-1)
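To make the varlen helpers concrete, here is a small worked example (the values follow directly from the definitions above):

```python
import torch

cu_seqlens = torch.tensor([0, 5, 12])  # two packed sequences of lengths 5 and 7
# prepare_lens(cu_seqlens)         -> tensor([5, 7])
# prepare_position_ids(cu_seqlens) -> tensor([0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 5, 6])
# prepare_sequence_ids(cu_seqlens) -> tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1])
# prepare_chunk_indices(cu_seqlens, chunk_size=4) pairs (sequence id, chunk id):
#   tensor([[0, 0], [0, 1], [1, 0], [1, 1]])
# prepare_chunk_offsets(cu_seqlens, chunk_size=4) -> tensor([0, 2, 4])
```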
fla3/ops/utils/logcumsumexp.py ADDED
@@ -0,0 +1,52 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+ import triton
+ import triton.language as tl
+
+ from ...ops.utils.op import exp, log
+
+
+ @triton.autotune(
+     configs=[
+         triton.Config({'BT': BT}, num_warps=num_warps)
+         for BT in [16, 32, 64]
+         for num_warps in [2, 4, 8]
+     ],
+     key=['S']
+ )
+ @triton.jit(do_not_specialize=['T'])
+ def logcumsumexp_fwd_kernel(
+     s,
+     z,
+     T,
+     S: tl.constexpr,
+     BT: tl.constexpr
+ ):
+     i_bh = tl.program_id(0)
+     o_i = tl.arange(0, BT)
+     m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)
+
+     b_mp = tl.full([S,], float('-inf'), dtype=tl.float32)
+     b_zp = tl.zeros([S,], dtype=tl.float32)
+     for i_t in range(tl.cdiv(T, BT)):
+         p_s = tl.make_block_ptr(s + i_bh * T*S, (T, S), (S, 1), (i_t * BT, 0), (BT, S), (1, 0))
+         p_z = tl.make_block_ptr(z + i_bh * T*S, (T, S), (S, 1), (i_t * BT, 0), (BT, S), (1, 0))
+
+         # [BT, S]
+         b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
+         # [S,]
+         b_mc = tl.max(b_s, 0)
+         b_mc = tl.maximum(b_mp, b_mc)
+         b_zp = b_zp * exp(b_mp - b_mc)
+         # [BT, S]
+         b_s = exp(b_s - b_mc)
+         b_z = tl.dot(m_s, b_s, allow_tf32=False) + b_zp
+         # [S,]
+         b_zc = tl.max(b_z, 0)
+         b_mp = b_mc
+         b_zp = b_zc
+         # [BT, S]
+         # small eps to prevent underflows
+         b_z = log(tl.where(b_z != 0, b_z, 1e-20)) + b_mc
+         tl.store(p_z, b_z.to(p_z.dtype.element_ty), boundary_check=(0, 1))
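Semantically, the streaming kernel above computes a numerically stable cumulative logsumexp over the time axis; PyTorch offers a direct one-line reference to check against (for the flattened `[B*H, T, S]` layout the kernel assumes):

```python
import torch

s = torch.randn(8, 100, 16)          # [B*H, T, S]
ref = torch.logcumsumexp(s, dim=1)   # target of logcumsumexp_fwd_kernel, up to fp error
```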
fla3/ops/utils/matmul.py ADDED
@@ -0,0 +1,245 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+ # code adapted from
+ # https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
+
+ from typing import Optional
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ from ...ops.utils.op import exp
+ from ...utils import input_guard
+
+
+ # `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
+ #   - A list of `triton.Config` objects that define different configurations of
+ #     meta-parameters (e.g., `BM`) and compilation options (e.g., `num_warps`) to try
+ #   - An auto-tuning *key* whose change in values will trigger evaluation of all the
+ #     provided configs
+ @triton.heuristics({
+     'HAS_ALPHA': lambda args: args['alpha'] is not None,
+     'HAS_BETA': lambda args: args['beta'] is not None
+ })
+ @triton.autotune(
+     configs=[
+         triton.Config({'BM': 128, 'BK': 64, 'BN': 256, 'G': 4}, num_stages=3, num_warps=8),
+         triton.Config({'BM': 64, 'BK': 32, 'BN': 256, 'G': 4}, num_stages=4, num_warps=4),
+         triton.Config({'BM': 128, 'BK': 32, 'BN': 128, 'G': 4}, num_stages=4, num_warps=4),
+         triton.Config({'BM': 128, 'BK': 32, 'BN': 64, 'G': 4}, num_stages=4, num_warps=4),
+         triton.Config({'BM': 64, 'BK': 32, 'BN': 128, 'G': 4}, num_stages=4, num_warps=4),
+         triton.Config({'BM': 128, 'BK': 32, 'BN': 32, 'G': 4}, num_stages=4, num_warps=4),
+         triton.Config({'BM': 64, 'BK': 32, 'BN': 32, 'G': 4}, num_stages=5, num_warps=2),
+         triton.Config({'BM': 32, 'BK': 32, 'BN': 64, 'G': 4}, num_stages=5, num_warps=2),
+         # Good configs for fp8 inputs.
+         # triton.Config({'BM': 128, 'BK': 128, 'BN': 256, 'G': 4}, num_stages=3, num_warps=8),
+         # triton.Config({'BM': 256, 'BK': 128, 'BN': 128, 'G': 4}, num_stages=3, num_warps=8),
+         # triton.Config({'BM': 256, 'BK': 128, 'BN': 64, 'G': 4}, num_stages=4, num_warps=4),
+         # triton.Config({'BM': 64, 'BK': 128, 'BN': 256, 'G': 4}, num_stages=4, num_warps=4),
+         # triton.Config({'BM': 128, 'BK': 128, 'BN': 128, 'G': 4}, num_stages=4, num_warps=4),
+         # triton.Config({'BM': 128, 'BK': 64, 'BN': 64, 'G': 4}, num_stages=4, num_warps=4),
+         # triton.Config({'BM': 64, 'BK': 64, 'BN': 128, 'G': 4}, num_stages=4, num_warps=4),
+         # triton.Config({'BM': 128, 'BK': 64, 'BN': 32, 'G': 4}, num_stages=4, num_warps=4)
+     ],
+     key=['M', 'N', 'K']
+ )
+ @triton.jit
+ def matmul_kernel(
+     # Pointers to matrices
+     a,
+     b,
+     c,
+     input,
+     alpha,
+     beta,
+     # Matrix dimensions
+     M,
+     N,
+     K,
+     # The stride variables represent how much to increase the ptr by when moving by 1
+     # element in a particular dimension. E.g. `stride_am` is how much to increase `a`
+     # by to get the element one row down (A has M rows).
+     stride_ab, stride_am, stride_ak,  # a: batch, M, K
+     stride_bk, stride_bn,  # b: K, N
+     stride_cb, stride_cm, stride_cn,  # c: batch, M, N
+     # Meta-parameters
+     BM: tl.constexpr,
+     BK: tl.constexpr,
+     BN: tl.constexpr,
+     G: tl.constexpr,
+     ACTIVATION: tl.constexpr,
+     HAS_INPUT: tl.constexpr,
+     HAS_ALPHA: tl.constexpr,
+     HAS_BETA: tl.constexpr,
+     ALLOW_TF32: tl.constexpr,
+     X_DIM: tl.constexpr = 1,
+ ):
+     """Kernel for computing the matmul C = A x B.
+     A has shape (M, K), B has shape (K, N) and C has shape (M, N)
+     """
+     # -----------------------------------------------------------
+     # Map program ids `pid` to the block of C it should compute.
+     # This is done in a grouped ordering to promote L2 data reuse
+     # (see the `L2 Cache Optimizations` section of the Triton tutorial).
+     i_b, i_m, i_n = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+
+     NM, NN = tl.num_programs(1), tl.num_programs(2)
+     i_m, i_n = tl.swizzle2d(i_m, i_n, NM, NN, G)
+
+     # ----------------------------------------------------------
+     # Create pointers for the first blocks of A and B.
+     # We will advance these pointers as we move in the K direction
+     # and accumulate:
+     # `p_a` is a block of [BM, BK] pointers
+     # `p_b` is a block of [BK, BN] pointers
+     # (see the `Pointer Arithmetic` section of the Triton tutorial)
+     a_batch_ptr = a + i_b * stride_ab
+     o_am = (i_m * BM + tl.arange(0, BM)) % M
+     o_bn = (i_n * BN + tl.arange(0, BN)) % N
+     o_k = tl.arange(0, BK)
+
+     p_a = a_batch_ptr + (o_am[:, None] * stride_am + o_k[None, :] * stride_ak)
+     p_b = b + (o_k[:, None] * stride_bk + o_bn[None, :] * stride_bn)
+
+     b_acc = tl.zeros((BM, BN), dtype=tl.float32)
+     for k in range(0, tl.cdiv(K, BK)):
+         # Load the next block of A and B, generating a mask by checking the K dimension.
+         # If it is out of bounds, set it to 0.
+         b_a = tl.load(p_a, mask=o_k[None, :] < K - k * BK, other=0.0)
111
+ b_b = tl.load(p_b, mask=o_k[:, None] < K - k * BK, other=0.0)
112
+ # We accumulate along the K dimension.
113
+ b_acc = tl.dot(b_a, b_b, acc=b_acc, allow_tf32=ALLOW_TF32)
114
+ # Advance the ptrs to the next K block.
115
+ p_a += BK * stride_ak
116
+ p_b += BK * stride_bk
117
+
118
+ o_cm = i_m * BM + tl.arange(0, BM)
119
+ o_cn = i_n * BN + tl.arange(0, BN)
120
+ mask = (o_cm[:, None] < M) & (o_cn[None, :] < N)
121
+
122
+ b_c = b_acc
123
+ # You can fuse arbitrary activation functions here
124
+ # while the b_acc is still in FP32!
125
+ if ACTIVATION == "leaky_relu":
126
+ b_c = leaky_relu(b_c)
127
+ elif ACTIVATION == "relu":
128
+ b_c = relu(b_c)
129
+ elif ACTIVATION == "sigmoid":
130
+ b_c = sigmoid(b_c)
131
+ elif ACTIVATION == "tanh":
132
+ b_c = tanh(b_c)
133
+
134
+ if HAS_ALPHA:
135
+ b_c *= tl.load(alpha)
136
+
137
+ if HAS_INPUT:
138
+ p_i = input + (stride_cm * o_cm[:, None] if X_DIM == 2 else 0) + stride_cn * o_cn[None, :]
139
+ mask_p = (o_cn[None, :] < N) if X_DIM == 1 else mask
140
+ b_i = tl.load(p_i, mask=mask_p, other=0.0).to(tl.float32)
141
+ if HAS_BETA:
142
+ b_i *= tl.load(beta)
143
+ b_c += b_i
144
+
145
+ # -----------------------------------------------------------
146
+ # Write back the block of the output matrix C with masks.
147
+ c_batch_ptr = c + i_b * stride_cb
148
+ p_c = c_batch_ptr + stride_cm * o_cm[:, None] + stride_cn * o_cn[None, :]
149
+ tl.store(p_c, b_c.to(c.dtype.element_ty), mask=mask)
150
+
151
+
152
+ # We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `matmul_kernel`.
153
+ @triton.jit
154
+ def leaky_relu(x):
155
+ return tl.where(x >= 0, x, 0.01 * x)
156
+
157
+
158
+ @triton.jit
159
+ def sigmoid(x):
160
+ # σ(x) = 1 / (1 + exp(-x))
161
+ return 1.0 / (1.0 + exp(-x))
162
+
163
+
164
+ @triton.jit
165
+ def tanh(x):
166
+ # tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
167
+ # 2 * sigmoid(2x) - 1
168
+ return (exp(x) - exp(-x)) / (exp(x) + exp(-x))
169
+
170
+
171
+ @triton.jit
172
+ def relu(x):
173
+ # ReLU(x) = max(0, x)
174
+ return tl.maximum(x, 0.0)
175
+
176
+
177
+ @input_guard
178
+ def matmul(a, b, activation=''):
179
+ assert a.dim() in [2, 3], "a must be 2D or 3D"
180
+ assert b.dim() == 2, "b must be 2D"
181
+ assert a.shape[-1] == b.shape[0], f"Incompatible dimensions: A {a.shape}, B {b.shape}"
182
+
183
+ if a.dim() == 2:
184
+ a_dim = 2
185
+ a = a.unsqueeze(0).contiguous() # (1, M, K)
186
+ else:
187
+ a_dim = 3
188
+ allow_tf32 = False if a.dtype == torch.float32 else True
189
+
190
+ B, M, K = a.shape[0], a.shape[1], a.shape[2]
191
+ K_b, N = b.shape
192
+ assert K == K_b, f"Incompatible K dimension: A {K} vs B {K_b}"
193
+ c = a.new_empty(B, M, N)
194
+
195
+ def grid(meta): return (B, triton.cdiv(M, meta['BM']), triton.cdiv(N, meta['BN']))
196
+ matmul_kernel[grid](
197
+ a, b, c, None, None, None,
198
+ M, N, K,
199
+ a.stride(0), a.stride(1), a.stride(2), # stride_ab, stride_am, stride_ak
200
+ b.stride(0), b.stride(1), # stride_bk, stride_bn (b.dim() == 2)
201
+ c.stride(0), c.stride(1), c.stride(2), # stride_cb, stride_cm, stride_cn
202
+ ACTIVATION=activation,
203
+ ALLOW_TF32=allow_tf32,
204
+ HAS_INPUT=False,
205
+ )
206
+ return c.squeeze(0) if a_dim == 2 else c
207
+
208
+
209
+ @input_guard
210
+ def addmm(
211
+ x: torch.Tensor,
212
+ a: torch.Tensor,
213
+ b: torch.Tensor,
214
+ alpha: Optional[float] = None,
215
+ beta: Optional[float] = None,
216
+ ) -> torch.Tensor:
217
+ assert a.dim() in [2, 3], "a must be 2D or 3D"
218
+ assert b.dim() == 2, "b must be 2D"
219
+ assert a.shape[-1] == b.shape[0], f"Incompatible dimensions: A {a.shape}, B {b.shape}"
220
+
221
+ if a.dim() == 2:
222
+ a_dim = 2
223
+ a = a.unsqueeze(0).contiguous() # (1, M, K)
224
+ else:
225
+ a_dim = 3
226
+ allow_tf32 = False if a.dtype == torch.float32 else True
227
+
228
+ B, M, K = a.shape[0], a.shape[1], a.shape[2]
229
+ K_b, N = b.shape
230
+ assert K == K_b, f"Incompatible K dimension: A {K} vs B {K_b}"
231
+ c = a.new_empty(B, M, N)
232
+
233
+ def grid(meta): return (B, triton.cdiv(M, meta['BM']), triton.cdiv(N, meta['BN']))
234
+ matmul_kernel[grid](
235
+ a, b, c, x, alpha, beta,
236
+ M, N, K,
237
+ a.stride(0), a.stride(1), a.stride(2), # stride_ab, stride_am, stride_ak
238
+ b.stride(0), b.stride(1), # stride_bk, stride_bn (b.dim() == 2)
239
+ c.stride(0), c.stride(1), c.stride(2), # stride_cb, stride_cm, stride_cn
240
+ ACTIVATION=None,
241
+ ALLOW_TF32=allow_tf32,
242
+ HAS_INPUT=True,
243
+ X_DIM=x.dim(),
244
+ )
245
+ return c.squeeze(0) if a_dim == 2 else c
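As a quick sanity check, the two wrappers line up with their PyTorch counterparts; a minimal sketch, assuming a CUDA device and that the module is importable as `fla3.ops.utils.matmul`:

```python
import torch

from fla3.ops.utils.matmul import addmm, matmul

a = torch.randn(64, 32, device='cuda', dtype=torch.float16)
b = torch.randn(32, 48, device='cuda', dtype=torch.float16)
x = torch.randn(48, device='cuda', dtype=torch.float16)

# C = A @ B, with the optional fused activation left disabled
torch.testing.assert_close(matmul(a, b), a @ b, rtol=1e-2, atol=1e-2)
# C = x + A @ B (alpha/beta default to None, i.e. a plain residual add)
torch.testing.assert_close(addmm(x, a, b), torch.addmm(x, a, b), rtol=1e-2, atol=1e-2)
```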
fla3/ops/utils/op.py ADDED
@@ -0,0 +1,39 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2024, Songlin Yang, Yu Zhang
+
+ import os
+
+ import triton
+ import triton.language as tl
+ import triton.language.extra.libdevice as tldevice
+
+ from ...utils import is_gather_supported
+
+ if os.environ.get('FLA_USE_FAST_OPS', '0') == '1':
+     div = tldevice.fast_dividef
+     exp = tldevice.fast_expf
+     log = tldevice.fast_logf
+     log2 = tldevice.fast_log2f
+ else:
+     @triton.jit
+     def div_normal(x, y):
+         return x / y
+     div = div_normal
+     exp = tl.exp
+     log = tl.log
+     log2 = tl.log2
+
+
+ @triton.jit
+ def safe_exp(x):
+     # clamp positive inputs to -inf so the result never exceeds 1
+     return exp(tl.where(x <= 0, x, float('-inf')))
+
+
+ if not is_gather_supported:
+     @triton.jit
+     def gather(src, index, axis, _builder=None):
+         # This is a fallback implementation used when tl.gather is not supported.
+         # In order to pass the Triton compiler, there is no actual gather operation.
+         return src
+ else:
+     gather = tl.gather
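The fast-math switch is read once at import time, so it has to be set before the module is first imported; a minimal sketch, assuming the package is importable as `fla3`:

```python
import os

# opt into the approximate libdevice routines (fast_expf, fast_logf, ...)
# before the module binds its `exp`/`log`/`log2`/`div` aliases
os.environ['FLA_USE_FAST_OPS'] = '1'

from fla3.ops.utils.op import exp, log  # noqa: E402
```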
fla3/ops/utils/pack.py ADDED
@@ -0,0 +1,208 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+ # Code adapted from https://github.com/mayank31398/cute-kernels
+
+ from typing import Optional
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ from ...ops.utils.index import prepare_lens
+ from ...utils import input_guard
+
+
+ @triton.autotune(
+     configs=[
+         triton.Config({}, num_warps=num_warps)
+         for num_warps in [4, 8, 16, 32]
+     ],
+     key=['D', 'PADDING_SIDE', 'PACK']
+ )
+ @triton.jit
+ def packunpack_sequence_kernel(
+     x,
+     y,
+     cu_seqlens,
+     S,
+     D,
+     BD: tl.constexpr,
+     PADDING_SIDE: tl.constexpr,
+     PACK: tl.constexpr,
+ ):
+     i_d, i_s, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+     bos, eos = tl.load(cu_seqlens + i_b), tl.load(cu_seqlens + i_b + 1)
+
+     T = eos - bos
+     if PADDING_SIDE == 'left':
+         NP = S - T
+         if i_s < NP:
+             return
+         i_t = bos + (i_s - NP)
+     else:
+         if i_s >= T:
+             return
+         i_t = bos + i_s
+
+     o_d = i_d * BD + tl.arange(0, BD)
+     mask = o_d < D
+
+     if PACK:
+         b_x = tl.load(x + (i_b * S + i_s) * D + o_d, mask=mask)
+         tl.store(y + i_t * D + o_d, b_x, mask=mask)
+     else:
+         b_x = tl.load(x + i_t * D + o_d, mask=mask)
+         tl.store(y + (i_b * S + i_s) * D + o_d, b_x, mask=mask)
+
+
+ def pack_sequence_fwdbwd(
+     x: torch.Tensor,
+     cu_seqlens: torch.Tensor,
+     padding_side: str,
+ ) -> torch.Tensor:
+     B, S = x.shape[:2]
+     D = x.numel() // (B * S)
+     BD = min(triton.next_power_of_2(D), 4096)
+     ND = triton.cdiv(D, BD)
+
+     y = torch.empty(cu_seqlens[-1].item(), *x.shape[2:], device=x.device, dtype=x.dtype)
+     packunpack_sequence_kernel[ND, S, B](
+         x=x,
+         y=y,
+         cu_seqlens=cu_seqlens,
+         S=S,
+         D=D,
+         BD=BD,
+         PADDING_SIDE=padding_side,
+         PACK=True,
+     )
+     return y
+
+
+ def unpack_sequence_fwdbwd(
+     x: torch.Tensor,
+     cu_seqlens: torch.Tensor,
+     padding_side: str,
+     desired_shape: torch.Size,
+ ) -> torch.Tensor:
+     if desired_shape is None:
+         desired_shape = (len(cu_seqlens) - 1, prepare_lens(cu_seqlens).max().item(), *x.shape[1:])
+     y = torch.zeros(desired_shape, device=x.device, dtype=x.dtype)
+     B, S = y.shape[:2]
+     D = y.numel() // (B * S)
+     BD = min(triton.next_power_of_2(D), 4096)
+     ND = triton.cdiv(D, BD)
+
+     packunpack_sequence_kernel[ND, S, B](
+         x=x,
+         y=y,
+         cu_seqlens=cu_seqlens,
+         S=S,
+         D=D,
+         BD=BD,
+         PADDING_SIDE=padding_side,
+         PACK=False,
+     )
+     return y
+
+
+ class PackSequenceFunction(torch.autograd.Function):
+
+     @staticmethod
+     @input_guard
+     def forward(
+         ctx,
+         x: torch.Tensor,
+         cu_seqlens: torch.Tensor,
+         padding_side: str,
+     ) -> torch.Tensor:
+         assert padding_side in ['left', 'right']
+         assert x.ndim >= 2
+
+         ctx.cu_seqlens = cu_seqlens
+         ctx.padding_side = padding_side
+         ctx.desired_shape = x.shape
+
+         y = pack_sequence_fwdbwd(
+             x=x,
+             cu_seqlens=cu_seqlens,
+             padding_side=padding_side,
+         )
+         return y
+
+     @staticmethod
+     @input_guard
+     def backward(ctx, dy: torch.Tensor) -> tuple[torch.Tensor | None]:
+         dx = unpack_sequence_fwdbwd(
+             x=dy,
+             cu_seqlens=ctx.cu_seqlens,
+             padding_side=ctx.padding_side,
+             desired_shape=ctx.desired_shape,
+         )
+         return dx, *[None] * 10
+
+
+ class UnpackSequenceFunction(torch.autograd.Function):
+
+     @staticmethod
+     @input_guard
+     def forward(
+         ctx,
+         x: torch.Tensor,
+         cu_seqlens: torch.Tensor,
+         padding_side: str,
+         desired_shape: Optional[torch.Size] = None,
+     ) -> torch.Tensor:
+         assert padding_side in ['left', 'right']
+         assert x.ndim >= 2
+         if desired_shape is not None:
+             assert desired_shape[0] == cu_seqlens.shape[0] - 1
+             assert desired_shape[2:] == x.shape[1:]
+
+         ctx.cu_seqlens = cu_seqlens
+         ctx.padding_side = padding_side
+
+         y = unpack_sequence_fwdbwd(
+             x=x,
+             cu_seqlens=cu_seqlens,
+             padding_side=padding_side,
+             desired_shape=desired_shape,
+         )
+         return y
+
+     @staticmethod
+     @input_guard
+     def backward(ctx, dy: torch.Tensor) -> tuple[torch.Tensor | None]:
+         dx = pack_sequence_fwdbwd(
+             x=dy,
+             cu_seqlens=ctx.cu_seqlens,
+             padding_side=ctx.padding_side,
+         )
+         return dx, None, None, None
+
+
+ def pack_sequence(
+     x: torch.Tensor,
+     cu_seqlens: torch.Tensor,
+     padding_side: str = 'left'
+ ) -> torch.Tensor:
+     return PackSequenceFunction.apply(
+         x,
+         cu_seqlens,
+         padding_side,
+     )
+
+
+ def unpack_sequence(
+     x: torch.Tensor,
+     cu_seqlens: torch.Tensor,
+     padding_side: str = 'left',
+     desired_shape: Optional[torch.Size] = None,
+ ) -> torch.Tensor:
+     return UnpackSequenceFunction.apply(
+         x,
+         cu_seqlens,
+         padding_side,
+         desired_shape,
+     )
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from ...ops.utils.index import prepare_chunk_indices
11
+ from ...utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
12
+
13
+
14
+ @triton.heuristics({
15
+ 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
16
+ })
17
+ @triton.autotune(
18
+ configs=[
19
+ triton.Config({'BD': BD}, num_warps=num_warps)
20
+ for BD in [16, 32, 64, 128]
21
+ for num_warps in [1, 2, 4, 8]
22
+ ],
23
+ key=['BT']
24
+ )
25
+ @triton.jit(do_not_specialize=['T'])
26
+ def mean_pooling_fwd_kernel(
27
+ x,
28
+ o,
29
+ cu_seqlens,
30
+ chunk_indices,
31
+ T,
32
+ H: tl.constexpr,
33
+ D: tl.constexpr,
34
+ BT: tl.constexpr,
35
+ BD: tl.constexpr,
36
+ IS_VARLEN: tl.constexpr
37
+ ):
38
+ i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
39
+ i_b, i_h = i_bh // H, i_bh % H
40
+ if IS_VARLEN:
41
+ i_tg = i_t
42
+ i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
43
+ bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
44
+ T = eos - bos
45
+ NT = tl.cdiv(T, BT)
46
+ else:
47
+ NT = tl.cdiv(T, BT)
48
+ i_tg = i_b * NT + i_t
49
+ bos, eos = i_b * T, i_b * T + T
50
+
51
+ p_x = tl.make_block_ptr(x + (bos * H + i_h) * D, (T, D), (H*D, 1), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
52
+ p_o = tl.make_block_ptr(o + (i_tg * H + i_h) * D, (D,), (1,), (i_d * BD,), (BD,), (0,))
53
+ # [BT, BD]
54
+ b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)
55
+ # [BD]
56
+ b_o = tl.sum(b_x, axis=0) / min(BT, T - i_t * BT)
57
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))
58
+
59
+
60
+ @triton.heuristics({
61
+ 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
62
+ })
63
+ @triton.autotune(
64
+ configs=[
65
+ triton.Config({'BD': BD}, num_warps=num_warps)
66
+ for BD in [16, 32, 64, 128]
67
+ for num_warps in [1, 2, 4, 8]
68
+ ],
69
+ key=['BT']
70
+ )
71
+ @triton.jit(do_not_specialize=['T'])
72
+ def mean_pooling_bwd_kernel(
73
+ do,
74
+ dx,
75
+ cu_seqlens,
76
+ chunk_indices,
77
+ T,
78
+ H: tl.constexpr,
79
+ D: tl.constexpr,
80
+ BT: tl.constexpr,
81
+ BD: tl.constexpr,
82
+ IS_VARLEN: tl.constexpr
83
+ ):
84
+ i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
85
+ i_b, i_h = i_bh // H, i_bh % H
86
+ if IS_VARLEN:
87
+ i_tg = i_t
88
+ i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
89
+ bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
90
+ T = eos - bos
91
+ NT = tl.cdiv(T, BT)
92
+ else:
93
+ NT = tl.cdiv(T, BT)
94
+ i_tg = i_b * NT + i_t
95
+ bos, eos = i_b * T, i_b * T + T
96
+
97
+ p_dx = tl.make_block_ptr(dx + (bos * H + i_h) * D, (T, D), (H*D, 1), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
98
+ p_do = tl.make_block_ptr(do + (i_tg * H + i_h) * D, (D,), (1,), (i_d * BD,), (BD,), (0,))
99
+ # [BD]
100
+ b_do = tl.load(p_do, boundary_check=(0,)).to(tl.float32)
101
+ # [BT, BD]
102
+ b_dx = b_do / tl.full((BT,), min(BT, T - i_t * BT), dtype=tl.float32)[:, None]
103
+ tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), boundary_check=(0, 1))
104
+
105
+
106
+ def mean_pooling_fwd(
107
+ x: torch.Tensor,
108
+ chunk_size: int,
109
+ cu_seqlens: Optional[torch.LongTensor] = None
110
+ ) -> torch.Tensor:
111
+ B, T, H, D = x.shape
112
+ BT = chunk_size
113
+ chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
114
+ NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
115
+
116
+ o = x.new_empty(B, NT, H, D)
117
+ def grid(meta): return (triton.cdiv(D, meta['BD']), NT, B * H)
118
+ mean_pooling_fwd_kernel[grid](
119
+ x,
120
+ o,
121
+ cu_seqlens,
122
+ chunk_indices,
123
+ T=T,
124
+ H=H,
125
+ D=D,
126
+ BT=BT,
127
+ )
128
+ return o
129
+
130
+
131
+ def mean_pooling_bwd(
132
+ do: torch.Tensor,
133
+ batch_size: int,
134
+ seq_len: int,
135
+ chunk_size: int,
136
+ cu_seqlens: Optional[torch.LongTensor] = None
137
+ ) -> torch.Tensor:
138
+ B, T, H, D = batch_size, seq_len, *do.shape[-2:]
139
+ BT = chunk_size
140
+ chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
141
+ NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
142
+
143
+ dx = do.new_empty(B, T, H, D)
144
+ def grid(meta): return (triton.cdiv(D, meta['BD']), NT, B * H)
145
+ mean_pooling_bwd_kernel[grid](
146
+ do,
147
+ dx,
148
+ cu_seqlens,
149
+ chunk_indices,
150
+ T=T,
151
+ H=H,
152
+ D=D,
153
+ BT=BT,
154
+ )
155
+ return dx
156
+
157
+
158
+ class MeanPoolingFunction(torch.autograd.Function):
159
+
160
+ @staticmethod
161
+ @input_guard
162
+ @autocast_custom_fwd
163
+ def forward(
164
+ ctx,
165
+ x: torch.Tensor,
166
+ chunk_size: int,
167
+ cu_seqlens: Optional[torch.LongTensor] = None
168
+ ) -> torch.Tensor:
169
+ o = mean_pooling_fwd(x, chunk_size, cu_seqlens)
170
+ ctx.batch_size = x.shape[0]
171
+ ctx.seq_len = x.shape[1]
172
+ ctx.chunk_size = chunk_size
173
+ ctx.cu_seqlens = cu_seqlens
174
+ return o
175
+
176
+ @staticmethod
177
+ @input_guard
178
+ @autocast_custom_bwd
179
+ def backward(
180
+ ctx, do
181
+ ) -> Tuple[torch.Tensor, None, None]:
182
+ batch_size = ctx.batch_size
183
+ seq_len = ctx.seq_len
184
+ chunk_size = ctx.chunk_size
185
+ cu_seqlens = ctx.cu_seqlens
186
+ dx = mean_pooling_bwd(do, batch_size, seq_len, chunk_size, cu_seqlens)
187
+ return dx, None, None
188
+
189
+
190
+ def mean_pooling(
191
+ x: torch.Tensor,
192
+ chunk_size: int,
193
+ cu_seqlens: Optional[torch.LongTensor] = None,
194
+ head_first: bool = False
195
+ ) -> torch.Tensor:
196
+ if head_first:
197
+ x = x.transpose(1, 2)
198
+ if cu_seqlens is not None:
199
+ if x.shape[0] != 1:
200
+ raise ValueError(
201
+ f"The batch size is expected to be 1 rather than {x.shape[0]} when using `cu_seqlens`."
202
+ f"Please ..tten variable-length inputs before processing."
203
+ )
204
+ o = MeanPoolingFunction.apply(x, chunk_size, cu_seqlens)
205
+ if head_first:
206
+ o = o.transpose(1, 2)
207
+ return o
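For the fixed-length case, `mean_pooling` reduces each chunk of `chunk_size` timesteps to its mean. A reference check in plain PyTorch (a sketch, assuming `T` divides evenly by the chunk size):

```python
import torch

from fla3.ops.utils.pooling import mean_pooling

B, T, H, D, chunk_size = 2, 128, 4, 64, 32
x = torch.randn(B, T, H, D, device='cuda')

o = mean_pooling(x, chunk_size)  # [B, T // chunk_size, H, D]
ref = x.view(B, T // chunk_size, chunk_size, H, D).mean(2)
torch.testing.assert_close(o, ref, rtol=1e-4, atol=1e-4)
```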
fla3/ops/utils/softmax.py ADDED
@@ -0,0 +1,111 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2024, Songlin Yang, Yu Zhang
+
+ from typing import Optional
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ from ...ops.utils.op import exp
+
+
+ @triton.autotune(
+     configs=[
+         triton.Config({}, num_warps=1),
+         triton.Config({}, num_warps=2),
+         triton.Config({}, num_warps=4),
+         triton.Config({}, num_warps=8),
+         triton.Config({}, num_warps=16),
+         triton.Config({}, num_warps=32)
+     ],
+     key=['D']
+ )
+ @triton.jit
+ def softmax_fwd_kernel(
+     x,
+     p,
+     D: tl.constexpr,
+     B: tl.constexpr
+ ):
+     i_n = tl.program_id(0)
+     o_d = tl.arange(0, B)
+     m_d = o_d < D
+
+     b_x = tl.load(x + i_n * D + o_d, mask=m_d, other=-float('inf'))
+     b_m = tl.max(b_x, 0)
+     b_x = exp(b_x - b_m)
+     b_p = b_x / tl.sum(b_x, 0)
+
+     tl.store(p + i_n * D + o_d, b_p.to(p.dtype.element_ty), mask=m_d)
+
+
+ @triton.autotune(
+     configs=[
+         triton.Config({}, num_warps=1),
+         triton.Config({}, num_warps=2),
+         triton.Config({}, num_warps=4),
+         triton.Config({}, num_warps=8),
+         triton.Config({}, num_warps=16),
+         triton.Config({}, num_warps=32)
+     ],
+     key=['D']
+ )
+ @triton.jit
+ def softmax_bwd_kernel(
+     p,
+     dp,
+     ds,
+     D: tl.constexpr,
+     B: tl.constexpr
+ ):
+     i_n = tl.program_id(0)
+     o_d = tl.arange(0, B)
+     m_d = o_d < D
+
+     b_p = tl.load(p + i_n * D + o_d, mask=m_d, other=0.)
+     b_dp = tl.load(dp + i_n * D + o_d, mask=m_d, other=0.)
+     b_pp = tl.sum(b_p * b_dp, 0)
+     b_ds = b_p * b_dp - b_p * b_pp
+     tl.store(ds + i_n * D + o_d, b_ds.to(ds.dtype.element_ty), mask=m_d)
+
+
+ def softmax_fwd(
+     x: torch.Tensor,
+     dtype: Optional[torch.dtype] = torch.float
+ ) -> torch.Tensor:
+     shape = x.shape
+     x = x.view(-1, x.shape[-1])
+
+     N, D = x.shape
+     B = triton.next_power_of_2(D)
+
+     p = torch.empty_like(x, dtype=dtype)
+     softmax_fwd_kernel[(N,)](
+         x=x,
+         p=p,
+         D=D,
+         B=B
+     )
+     return p.view(*shape)
+
+
+ def softmax_bwd(
+     p: torch.Tensor,
+     dp: torch.Tensor,
+     dtype: Optional[torch.dtype] = torch.float
+ ) -> torch.Tensor:
+     shape = p.shape
+     p = p.view(-1, p.shape[-1])
+     ds = torch.empty_like(p, dtype=dtype)
+
+     N, D = p.shape
+     B = triton.next_power_of_2(D)
+     softmax_bwd_kernel[(N,)](
+         p=p,
+         dp=dp,
+         ds=ds,
+         D=D,
+         B=B
+     )
+     return ds.view(*shape)
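Both kernels pad the feature dimension up to the next power of two and mask the tail, so a non-power-of-two `D` works. A quick check against the PyTorch reference:

```python
import torch

from fla3.ops.utils.softmax import softmax_bwd, softmax_fwd

x = torch.randn(8, 100, device='cuda')  # D = 100 is deliberately not a power of 2
p = softmax_fwd(x)
torch.testing.assert_close(p, torch.softmax(x.float(), dim=-1))

# softmax_bwd computes ds = p * (dp - sum(p * dp, dim=-1, keepdim=True))
dp = torch.randn_like(p)
ds_ref = p * (dp - (p * dp).sum(-1, keepdim=True))
torch.testing.assert_close(softmax_bwd(p, dp), ds_ref)
```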
fla3/ops/utils/solve_tril.py ADDED
@@ -0,0 +1,276 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+ from typing import Optional
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ from ...ops.utils.index import prepare_chunk_indices
+ from ...utils import input_guard
+
+
+ @triton.heuristics({
+     'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
+ })
+ @triton.autotune(
+     configs=[
+         triton.Config({}, num_warps=num_warps, num_stages=num_stages)
+         for num_warps in [1, 2, 4, 8]
+         for num_stages in [2, 3, 4, 5]
+     ],
+     key=['BT'],
+ )
+ @triton.jit(do_not_specialize=['T'])
+ def solve_tril_16x16_kernel(
+     A,
+     Ad,
+     cu_seqlens,
+     chunk_indices,
+     T,
+     H: tl.constexpr,
+     BT: tl.constexpr,
+     IS_VARLEN: tl.constexpr,
+ ):
+     i_t, i_bh = tl.program_id(0), tl.program_id(1)
+     i_b, i_h = i_bh // H, i_bh % H
+     if IS_VARLEN:
+         i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
+         bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
+         T = eos - bos
+     else:
+         bos, eos = i_b * T, i_b * T + T
+
+     A = A + (bos*H + i_h) * BT
+     Ad = Ad + (bos*H + i_h) * 16
+
+     offset = (i_t * 16) % BT
+     p_A = tl.make_block_ptr(A, (T, BT), (H*BT, 1), (i_t * 16, offset), (16, 16), (1, 0))
+     p_Ai = tl.make_block_ptr(Ad, (T, 16), (H*16, 1), (i_t * 16, 0), (16, 16), (1, 0))
+     b_A = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32)
+     b_A = -tl.where(tl.arange(0, 16)[:, None] > tl.arange(0, 16)[None, :], b_A, 0)
+
+     o_i = tl.arange(0, 16)
+     for i in range(1, min(16, T-i_t*16)):
+         b_a = -tl.load(A + (i_t * 16 + i) * H*BT + o_i + offset)
+         b_a = b_a + tl.sum(b_a[:, None] * b_A, 0)
+         mask = o_i == i
+         b_A = tl.where(mask[:, None], b_a, b_A)
+     b_A += o_i[:, None] == o_i[None, :]
+     tl.store(p_Ai, b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+
+
+ @triton.heuristics({
+     'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
+ })
+ @triton.autotune(
+     configs=[
+         triton.Config({}, num_warps=num_warps, num_stages=num_stages)
+         for num_warps in [1, 2, 4, 8]
+         for num_stages in [2, 3, 4, 5]
+     ],
+     key=['H', 'BT', 'IS_VARLEN'],
+ )
+ @triton.jit(do_not_specialize=['T'])
+ def merge_16x16_to_32x32_inverse_kernel(
+     A,
+     Ad,
+     Ai,
+     cu_seqlens,
+     chunk_indices,
+     T,
+     H: tl.constexpr,
+     BT: tl.constexpr,
+     IS_VARLEN: tl.constexpr
+ ):
+     i_t, i_bh = tl.program_id(0), tl.program_id(1)
+     i_b, i_h = i_bh // H, i_bh % H
+     if IS_VARLEN:
+         i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
+         bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
+         T = eos - bos
+     else:
+         bos, eos = i_b * T, i_b * T + T
+
+     A += (bos*H + i_h) * 32
+     Ad += (bos*H + i_h) * 16
+     Ai += (bos*H + i_h) * 32
+
+     p_A_21 = tl.make_block_ptr(A, (T, 32), (H*32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0))
+     p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (H*16, 1), (i_t * 32, 0), (16, 16), (1, 0))
+     p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (H*16, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0))
+     p_Ai_11 = tl.make_block_ptr(Ai, (T, 32), (H*32, 1), (i_t * 32, 0), (16, 16), (1, 0))
+     p_Ai_22 = tl.make_block_ptr(Ai, (T, 32), (H*32, 1), (i_t * 32 + 16, 16), (16, 16), (1, 0))
+     p_Ai_21 = tl.make_block_ptr(Ai, (T, 32), (H*32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0))
+
+     A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
+     Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32)
+     Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32)
+     Ai_21 = -tl.dot(tl.dot(Ai_22, A_21, input_precision='ieee'), Ai_11, input_precision='ieee')
+     tl.store(p_Ai_11, Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+     tl.store(p_Ai_22, Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+     tl.store(p_Ai_21, Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+
+
+ @triton.heuristics({
+     'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
+ })
+ @triton.autotune(
+     configs=[
+         triton.Config({}, num_warps=num_warps, num_stages=num_stages)
+         for num_warps in [2, 4, 8]
+         for num_stages in [2, 3, 4, 5]
+     ],
+     key=['H', 'BT', 'IS_VARLEN'],
+ )
+ @triton.jit(do_not_specialize=['T'])
+ def merge_16x16_to_64x64_inverse_kernel(
+     A,
+     Ad,
+     Ai,
+     cu_seqlens,
+     chunk_indices,
+     T,
+     H: tl.constexpr,
+     BT: tl.constexpr,
+     IS_VARLEN: tl.constexpr
+ ):
+     i_t, i_bh = tl.program_id(0), tl.program_id(1)
+     i_b, i_h = i_bh // H, i_bh % H
+     if IS_VARLEN:
+         i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
+         bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
+         T = eos - bos
+     else:
+         bos, eos = i_b * T, i_b * T + T
+
+     A += (bos*H + i_h) * 64
+     Ad += (bos*H + i_h) * 16
+     Ai += (bos*H + i_h) * 64
+
+     p_A_21 = tl.make_block_ptr(A, (T, 64), (H*64, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0))
+     p_A_32 = tl.make_block_ptr(A, (T, 64), (H*64, 1), (i_t * 64 + 32, 16), (16, 16), (1, 0))
+     p_A_31 = tl.make_block_ptr(A, (T, 64), (H*64, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0))
+     p_A_43 = tl.make_block_ptr(A, (T, 64), (H*64, 1), (i_t * 64 + 48, 32), (16, 16), (1, 0))
+     p_A_42 = tl.make_block_ptr(A, (T, 64), (H*64, 1), (i_t * 64 + 48, 16), (16, 16), (1, 0))
+     p_A_41 = tl.make_block_ptr(A, (T, 64), (H*64, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0))
+     p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (H*16, 1), (i_t * 64, 0), (16, 16), (1, 0))
+     p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (H*16, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0))
+     p_Ad_33 = tl.make_block_ptr(Ad, (T, 16), (H*16, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0))
+     p_Ad_44 = tl.make_block_ptr(Ad, (T, 16), (H*16, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0))
+
+     A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
+     A_32 = tl.load(p_A_32, boundary_check=(0, 1)).to(tl.float32)
+     A_31 = tl.load(p_A_31, boundary_check=(0, 1)).to(tl.float32)
+     A_43 = tl.load(p_A_43, boundary_check=(0, 1)).to(tl.float32)
+     A_42 = tl.load(p_A_42, boundary_check=(0, 1)).to(tl.float32)
+     A_41 = tl.load(p_A_41, boundary_check=(0, 1)).to(tl.float32)
+
+     Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32)
+     Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32)
+     Ai_33 = tl.load(p_Ad_33, boundary_check=(0, 1)).to(tl.float32)
+     Ai_44 = tl.load(p_Ad_44, boundary_check=(0, 1)).to(tl.float32)
+
+     Ai_21 = -tl.dot(tl.dot(Ai_22, A_21, input_precision='ieee'), Ai_11, input_precision='ieee')
+     Ai_32 = -tl.dot(tl.dot(Ai_33, A_32, input_precision='ieee'), Ai_22, input_precision='ieee')
+     Ai_43 = -tl.dot(tl.dot(Ai_44, A_43, input_precision='ieee'), Ai_33, input_precision='ieee')
+
+     Ai_31 = -tl.dot(
+         Ai_33,
+         tl.dot(A_31, Ai_11, input_precision='ieee') +
+         tl.dot(A_32, Ai_21, input_precision='ieee'),
+         input_precision='ieee'
+     )
+     Ai_42 = -tl.dot(
+         Ai_44,
+         tl.dot(A_42, Ai_22, input_precision='ieee') +
+         tl.dot(A_43, Ai_32, input_precision='ieee'),
+         input_precision='ieee'
+     )
+     Ai_41 = -tl.dot(
+         Ai_44,
+         tl.dot(A_41, Ai_11, input_precision='ieee') +
+         tl.dot(A_42, Ai_21, input_precision='ieee') +
+         tl.dot(A_43, Ai_31, input_precision='ieee'),
+         input_precision='ieee'
+     )
+
+     p_Ai_11 = tl.make_block_ptr(Ai, (T, 64), (H*64, 1), (i_t * 64, 0), (16, 16), (1, 0))
+     p_Ai_22 = tl.make_block_ptr(Ai, (T, 64), (H*64, 1), (i_t * 64 + 16, 16), (16, 16), (1, 0))
+     p_Ai_33 = tl.make_block_ptr(Ai, (T, 64), (H*64, 1), (i_t * 64 + 32, 32), (16, 16), (1, 0))
+     p_Ai_44 = tl.make_block_ptr(Ai, (T, 64), (H*64, 1), (i_t * 64 + 48, 48), (16, 16), (1, 0))
+     p_Ai_21 = tl.make_block_ptr(Ai, (T, 64), (H*64, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0))
+     p_Ai_31 = tl.make_block_ptr(Ai, (T, 64), (H*64, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0))
+     p_Ai_32 = tl.make_block_ptr(Ai, (T, 64), (H*64, 1), (i_t * 64 + 32, 16), (16, 16), (1, 0))
+     p_Ai_41 = tl.make_block_ptr(Ai, (T, 64), (H*64, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0))
+     p_Ai_42 = tl.make_block_ptr(Ai, (T, 64), (H*64, 1), (i_t * 64 + 48, 16), (16, 16), (1, 0))
+     p_Ai_43 = tl.make_block_ptr(Ai, (T, 64), (H*64, 1), (i_t * 64 + 48, 32), (16, 16), (1, 0))
+     tl.store(p_Ai_11, Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+     tl.store(p_Ai_22, Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+     tl.store(p_Ai_33, Ai_33.to(p_Ai_33.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+     tl.store(p_Ai_44, Ai_44.to(p_Ai_44.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+     tl.store(p_Ai_21, Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+     tl.store(p_Ai_31, Ai_31.to(p_Ai_31.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+     tl.store(p_Ai_32, Ai_32.to(p_Ai_32.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+     tl.store(p_Ai_41, Ai_41.to(p_Ai_41.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+     tl.store(p_Ai_42, Ai_42.to(p_Ai_42.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+     tl.store(p_Ai_43, Ai_43.to(p_Ai_43.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+
+
+ @input_guard
+ def solve_tril(
+     A: torch.Tensor,
+     cu_seqlens: Optional[torch.Tensor] = None,
+     output_dtype: torch.dtype = torch.float
+ ) -> torch.Tensor:
+     """
+     Compute (I + A)^-1, where A is strictly lower triangular block-wise,
+     i.e., `A.triu() == 0` within each block.
+
+     Args:
+         A (torch.Tensor):
+             [B, T, H, BT], where BT is the block size (16, 32 or 64).
+         cu_seqlens (torch.Tensor):
+             The cumulative sequence lengths of the input tensor.
+             Default: None.
+         output_dtype (torch.dtype):
+             The dtype of the output tensor. Default: `torch.float`
+
+     Returns:
+         (I + A)^-1 with the same shape as A.
+     """
+     assert A.shape[-1] in [16, 32, 64]
+
+     B, T, H, BT = A.shape
+     Ad = torch.empty(B, T, H, 16, device=A.device, dtype=torch.float if BT != 16 else output_dtype)
+
+     chunk_indices = prepare_chunk_indices(cu_seqlens, 16) if cu_seqlens is not None else None
+     NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, 16)
+     solve_tril_16x16_kernel[NT, B * H](
+         A=A,
+         Ad=Ad,
+         cu_seqlens=cu_seqlens,
+         chunk_indices=chunk_indices,
+         T=T,
+         H=H,
+         BT=BT,
+     )
+     if BT == 16:
+         return Ad
+
+     Ai = torch.zeros(B, T, H, BT, device=A.device, dtype=output_dtype)
+     merge_fn = merge_16x16_to_32x32_inverse_kernel if BT == 32 else merge_16x16_to_64x64_inverse_kernel
+     chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
+     NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT)
+     merge_fn[NT, B * H](
+         A=A,
+         Ad=Ad,
+         Ai=Ai,
+         cu_seqlens=cu_seqlens,
+         chunk_indices=chunk_indices,
+         T=T,
+         H=H,
+         BT=BT,
+     )
+     return Ai
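The kernels invert each `BT x BT` block independently, so the result can be verified block by block against `torch.inverse`; a sketch for the fixed-length case:

```python
import torch

from fla3.ops.utils.solve_tril import solve_tril

B, T, H, BT = 1, 128, 2, 64
NT = T // BT
# build strictly lower triangular blocks, then lay them out as [B, T, H, BT]
A = torch.randn(B, H, NT, BT, BT, device='cuda').tril(-1)
A = A.permute(0, 2, 3, 1, 4).reshape(B, T, H, BT)

Ai = solve_tril(A)
eye = torch.eye(BT, device=A.device)
for h in range(H):
    for t in range(NT):
        blk = A[0, t * BT:(t + 1) * BT, h]
        ref = torch.inverse(eye + blk)
        torch.testing.assert_close(Ai[0, t * BT:(t + 1) * BT, h], ref, rtol=1e-3, atol=1e-3)
```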
flame/__init__.py ADDED
File without changes
flame/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (167 Bytes)
flame/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (167 Bytes)
flame/__pycache__/data.cpython-310.pyc ADDED
Binary file (8.17 kB)
flame/__pycache__/data.cpython-312.pyc ADDED
Binary file (14.9 kB)
flame/__pycache__/logging.cpython-310.pyc ADDED
Binary file (3.56 kB)
flame/__pycache__/logging.cpython-312.pyc ADDED
Binary file (6.44 kB)
flame/__pycache__/parser.cpython-310.pyc ADDED
Binary file (2.89 kB)
flame/__pycache__/parser.cpython-312.pyc ADDED
Binary file (4.07 kB)
flame/data.py ADDED
@@ -0,0 +1,246 @@
+ # -*- coding: utf-8 -*-
+
+ from __future__ import annotations
+
+ from copy import deepcopy
+ from dataclasses import dataclass
+ from typing import Any, Dict, Iterable, List, Union
+
+ import numpy as np
+ import torch
+ from datasets import Dataset, IterableDataset
+ from flame.logging import get_logger
+ from transformers import PreTrainedTokenizer
+
+ logger = get_logger(__name__)
+
+
+ class HuggingfaceDataset(IterableDataset):
+
+     def __init__(
+         self,
+         dataset: Dataset,
+         tokenizer: PreTrainedTokenizer,
+         context_len: int = 2048,
+         rank: int = 0,
+         world_size: int = 1,
+         buffer_size: int = 1024
+     ) -> HuggingfaceDataset:
+
+         self.dataset = dataset
+         self.tokenizer = tokenizer
+
+         self.data = dataset.shard(world_size, rank)
+         self.context_len = context_len
+         self.rank = rank
+         self.world_size = world_size
+         self.buffer_size = buffer_size
+
+         # pick the narrowest integer dtype that can hold the vocabulary
+         if tokenizer.vocab_size < torch.iinfo(torch.int16).max:
+             self.dtype = torch.int16
+         elif tokenizer.vocab_size < torch.iinfo(torch.int32).max:
+             self.dtype = torch.int32
+         else:
+             self.dtype = torch.int64
+         self.states = None
+         self.buffer = torch.tensor([], dtype=self.dtype)
+         self.tokens = []
+         self.rand_id = 0
+         self.token_id = 0
+         self.rng_state = None
+         self._epoch = 0
+
+     def __iter__(self):
+         g = torch.Generator()
+         g.manual_seed(self._epoch + self.rank)
+         if self.rng_state is not None:
+             g.set_state(self.rng_state)
+
+         rand_it = self.randint(0, self.buffer_size, g=g)
+         if self.states is not None:
+             self.data.load_state_dict(self.states)
+
+         # max number of tokens allowed in the chunk buffer
+         n_tokens = self.buffer_size * self.context_len
+
+         while True:
+             for sample in self.tokenize(self.data):
+                 # keep appending the samples to the token buffer
+                 self.tokens += sample
+                 # if the token buffer is full, start sampling
+                 # NOTE: we first convert the token ids to a tensor of shape [n_chunks, context_len] for efficiency
+                 if len(self.buffer) == 0 and len(self.tokens) >= n_tokens:
+                     self.buffer = torch.tensor(self.tokens[:n_tokens], dtype=self.dtype).view(self.buffer_size, -1)
+                     self.tokens = self.tokens[n_tokens:]
+                 if len(self.buffer) == self.buffer_size:
+                     yield from self.sample(rand_it)
+
+             n_chunks = len(self.tokens) // self.context_len
+             # handle the leftover tokens in the buffer
+             if n_chunks > 0:
+                 n_tokens = n_chunks * self.context_len
+                 indices = torch.randperm(n_chunks, generator=g).tolist()
+                 self.buffer = torch.tensor(self.tokens[:n_tokens], dtype=torch.long).view(n_chunks, -1)
+                 self.tokens = self.tokens[n_tokens:]
+                 for i in indices:
+                     yield {'input_ids': self.buffer[i]}
+
+     def tokenize(self, data, batch_size: int = 64):
+         texts, states = [], []
+         for sample in data:
+             texts.append(sample['text'])
+             states.append(self.data.state_dict())
+             if len(texts) == batch_size:
+                 for s, tokenized in zip(states, self.tokenizer(texts, return_attention_mask=False)['input_ids']):
+                     self.states = s
+                     yield tokenized
+                 texts, states = [], []
+         if len(texts) > 0:
+             for s, tokenized in zip(states, self.tokenizer(texts, return_attention_mask=False)['input_ids']):
+                 self.states = s
+                 yield tokenized
+
+     def sample(self, indices):
+         n_tokens = (len(self.tokens) // self.context_len) * self.context_len
+         while self.token_id < n_tokens:
+             i = next(indices)
+             start, end = self.token_id, self.token_id + self.context_len
+             self.token_id += self.context_len
+             yield {'input_ids': self.buffer[i].to(torch.long)}
+             self.buffer[i] = torch.tensor(self.tokens[start:end], dtype=self.dtype)
+         self.token_id = 0
+         self.tokens = self.tokens[n_tokens:]
+
+     def randint(
+         self,
+         low: int,
+         high: int,
+         batch_size: int = 1024,
+         g: torch.Generator = torch.Generator()
+     ) -> Iterable[int]:
+         indices = torch.empty(batch_size, dtype=torch.long)
+         while True:
+             # record the generator states before sampling
+             self.rng_state = g.get_state()
+             indices = torch.randint(low, high, (batch_size,), out=indices, generator=g)
+             for i in indices[self.rand_id:].tolist():
+                 self.rand_id += 1
+                 yield i
+             self.rand_id = 0
+
+     def set_epoch(self, epoch):
+         self._epoch = epoch
+         if hasattr(self.dataset, "set_epoch"):
+             self.dataset.set_epoch(epoch)
+
+     def state_dict(self):
+         return {
+             'states': self.states,
+             'buffer': self.buffer.clone(),
+             'tokens': deepcopy(self.tokens),
+             'rand_id': self.rand_id,
+             'token_id': self.token_id,
+             'rng_state': self.rng_state,
+             'epoch': self._epoch
+         }
+
+     def load_state_dict(self, state_dict):
+         self.states = state_dict['states']
+         self.buffer = state_dict['buffer'].clone()
+         self.tokens = deepcopy(state_dict['tokens'])
+         self.rand_id = state_dict['rand_id']
+         self.token_id = state_dict['token_id']
+         self.rng_state = state_dict['rng_state'].clone() if state_dict['rng_state'] is not None else None
+         self._epoch = state_dict['epoch']
+
+
+ @dataclass
+ class DataCollatorForLanguageModeling:
+     """
+     Data collator used for language modeling.
+
+     Args:
+         tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
+             The tokenizer used for encoding the data.
+         varlen (`bool`):
+             Whether to return sequences with variable lengths.
+             If `True`, the offsets indicating the start and end of each sequence will be returned.
+             For example, if the sequence lengths are `[4, 8, 12]`,
+             the returned `input_ids` will be a long flattened tensor of shape `[1, 24]`, with `offsets` being `[0, 4, 12, 24]`.
+             If `False`, the `input_ids` with shape `[batch_size, seq_len]` will be returned directly.
+         return_tensors (`str`):
+             The type of Tensor to return. Allowable values are "pt".
+     """
+
+     tokenizer: PreTrainedTokenizer
+     varlen: bool = False
+     return_tensors: str = "pt"
+
+     def __call__(
+         self,
+         examples: List[Union[List[int], Dict[str, Any]]]
+     ) -> Dict[str, Any]:
+         if not isinstance(examples[0], Dict):
+             examples = [{'input_ids': example} for example in examples]
+
+         def tensorize(example: Dict[str, Any]) -> Dict[str, Any]:
+             tensorized = {}
+             for key in ['input_ids', 'offsets']:
+                 if key not in example:
+                     continue
+                 if isinstance(example[key], List):
+                     tensorized[key] = torch.tensor(example[key], dtype=torch.long)
+                 elif isinstance(example[key], np.ndarray):
+                     tensorized[key] = torch.from_numpy(example[key])
+                 else:
+                     tensorized[key] = example[key]
+             return tensorized
+
+         examples = list(map(tensorize, examples))
+
+         if not self.varlen:
+             length_of_first = examples[0]['input_ids'].size(0)
+             # Check if padding is necessary.
+             if all(example['input_ids'].size(0) == length_of_first for example in examples):
+                 batch = {
+                     'input_ids': torch.stack([example['input_ids'] for example in examples], dim=0),
+                 }
+             else:
+                 # If yes, check if we have a `pad_token`.
+                 if self.tokenizer._pad_token is None:
+                     raise ValueError(
+                         f"You are attempting to pad samples but the tokenizer you are using "
+                         f"({self.tokenizer.__class__.__name__}) does not have a pad token."
+                     )
+                 batch = self.tokenizer.pad(examples, return_tensors=self.return_tensors, return_attention_mask=False)
+         else:
+             if len(examples) > 1:
+                 raise ValueError("The batch size must be 1 for variable length inputs.")
+             batch = {
+                 'input_ids': torch.cat([example['input_ids'] for example in examples], dim=0).unsqueeze(0)
+             }
+             if 'offsets' in examples[0]:
+                 batch['offsets'] = torch.cat([example['offsets'] for example in examples], dim=0).unsqueeze(0)
+             else:
+                 # determine boundaries by bos/eos positions
+                 if self.tokenizer.add_bos_token:
+                     offsets = []
+                     if batch['input_ids'][0, 0] != self.tokenizer.bos_token_id:
+                         offsets.append(torch.tensor([0], dtype=torch.long))
+                     offsets.append(torch.where(batch['input_ids'].eq(self.tokenizer.bos_token_id))[1])
+                     offsets.append(torch.tensor([len(batch['input_ids'][0])], dtype=torch.long))
+                     batch['offsets'] = torch.cat(offsets, dim=0)
+                 elif self.tokenizer.add_eos_token:
+                     offsets = [torch.tensor([0], dtype=torch.long)]
+                     offsets.append(torch.where(batch['input_ids'].eq(self.tokenizer.eos_token_id))[1] + 1)
+                     if batch['input_ids'][0, -1] != self.tokenizer.eos_token_id:
+                         offsets.append(torch.tensor([len(batch['input_ids'][0])], dtype=torch.long))
+                     batch['offsets'] = torch.cat(offsets, dim=0)
+                 else:
+                     raise ValueError("You must allow the tokenizer to add either a bos or eos token as separators.")
+
+         labels = batch['input_ids'].clone()
+         if self.tokenizer.pad_token_id is not None:
+             labels[labels == self.tokenizer.pad_token_id] = -100
+         batch["labels"] = labels
+         return batch
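A minimal sketch of the collator in its fixed-length mode (the tokenizer name is just this repo's default from `flame/parser.py`; any tokenizer with a pad token works):

```python
from transformers import AutoTokenizer

from flame.data import DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained('fla-hub/gla-1.3B-100B')
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, varlen=False)

batch = collator([{'input_ids': [1, 2, 3, 4]}, {'input_ids': [5, 6, 7, 8]}])
# batch['input_ids'] is a [2, 4] tensor;
# batch['labels'] equals input_ids with pad positions replaced by -100
```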
flame/logging.py ADDED
@@ -0,0 +1,118 @@
+ # -*- coding: utf-8 -*-
+
+ import json
+ import logging
+ import os
+ import sys
+ import time
+
+ from transformers.trainer_callback import (ExportableState, TrainerCallback,
+                                            TrainerControl, TrainerState)
+ from transformers.training_args import TrainingArguments
+
+
+ def get_logger(name: str = None) -> logging.Logger:
+     formatter = logging.Formatter(
+         fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
+     )
+     handler = logging.StreamHandler(sys.stdout)
+     handler.setFormatter(formatter)
+
+     logger = logging.getLogger(name)
+     # only rank 0 emits INFO-level logs to avoid duplicates in distributed runs
+     if 'RANK' in os.environ and int(os.environ['RANK']) == 0:
+         logger.setLevel(logging.INFO)
+         logger.addHandler(handler)
+
+     return logger
+
+
+ logger = get_logger(__name__)
+
+ LOG_FILE_NAME = "trainer_log.jsonl"
+
+
+ class LogCallback(TrainerCallback, ExportableState):
+     def __init__(self, start_time: float = None, elapsed_time: float = None):
+
+         self.start_time = time.time() if start_time is None else start_time
+         self.elapsed_time = 0 if elapsed_time is None else elapsed_time
+         self.last_time = self.start_time
+
+     def on_train_begin(
+         self,
+         args: TrainingArguments,
+         state: TrainerState,
+         control: TrainerControl,
+         **kwargs
+     ):
+         r"""
+         Event called at the beginning of training.
+         """
+         if state.is_local_process_zero:
+             if not args.resume_from_checkpoint:
+                 self.start_time = time.time()
+                 self.elapsed_time = 0
+             else:
+                 self.start_time = state.stateful_callbacks['LogCallback']['start_time']
+                 self.elapsed_time = state.stateful_callbacks['LogCallback']['elapsed_time']
+
+         if args.save_on_each_node:
+             if not state.is_local_process_zero:
+                 return
+         else:
+             if not state.is_world_process_zero:
+                 return
+
+         self.last_time = time.time()
+         if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)) and args.overwrite_output_dir:
+             logger.warning("Previous log file in this folder will be deleted.")
+             os.remove(os.path.join(args.output_dir, LOG_FILE_NAME))
+
+     def on_log(
+         self,
+         args: TrainingArguments,
+         state: TrainerState,
+         control: TrainerControl,
+         logs,
+         **kwargs
+     ):
+         if args.save_on_each_node:
+             if not state.is_local_process_zero:
+                 return
+         else:
+             if not state.is_world_process_zero:
+                 return
+
+         self.elapsed_time += time.time() - self.last_time
+         self.last_time = time.time()
+         if 'num_input_tokens_seen' in logs:
+             logs['num_tokens'] = logs.pop('num_input_tokens_seen')
+             state.log_history[-1].pop('num_input_tokens_seen')
+             throughput = logs['num_tokens'] / args.world_size / self.elapsed_time
+             state.log_history[-1]['throughput'] = logs['throughput'] = throughput
+         state.stateful_callbacks["LogCallback"] = self.state()
+
+         logs = dict(
+             current_steps=state.global_step,
+             total_steps=state.max_steps,
+             loss=state.log_history[-1].get("loss", None),
+             eval_loss=state.log_history[-1].get("eval_loss", None),
+             predict_loss=state.log_history[-1].get("predict_loss", None),
+             learning_rate=state.log_history[-1].get("learning_rate", None),
+             epoch=state.log_history[-1].get("epoch", None),
+             percentage=round(state.global_step / state.max_steps * 100, 2) if state.max_steps != 0 else 100,
+         )
+
+         os.makedirs(args.output_dir, exist_ok=True)
+         with open(os.path.join(args.output_dir, LOG_FILE_NAME), "a", encoding="utf-8") as f:
+             f.write(json.dumps(logs) + "\n")
+
+     def state(self) -> dict:
+         return {
+             'start_time': self.start_time,
+             'elapsed_time': self.elapsed_time
+         }
+
+     @classmethod
+     def from_state(cls, state):
+         return cls(state['start_time'], state['elapsed_time'])
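Each record that `LogCallback.on_log` appends is one JSON object per line, so the log can be tailed or parsed trivially; a sketch (the path depends on whatever `--output_dir` was set to):

```python
import json

# read back the per-step records written to <output_dir>/trainer_log.jsonl
with open('exp/trainer_log.jsonl', encoding='utf-8') as f:
    for line in f:
        record = json.loads(line)
        print(record['current_steps'], record['total_steps'], record['loss'])
```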
flame/parser.py ADDED
@@ -0,0 +1,94 @@
+ # -*- coding: utf-8 -*-
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass, field
+ from typing import Optional
+
+ import transformers
+ from transformers import HfArgumentParser, TrainingArguments
+
+ from flame.logging import get_logger
+
+ logger = get_logger(__name__)
+
+
+ @dataclass
+ class TrainingArguments(TrainingArguments):
+
+     model_name_or_path: str = field(
+         default=None,
+         metadata={
+             "help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."
+         },
+     )
+     tokenizer: str = field(
+         default="fla-hub/gla-1.3B-100B",
+         metadata={"help": "Name of the tokenizer to use."}
+     )
+     use_fast_tokenizer: bool = field(
+         default=False,
+         metadata={"help": "Whether or not to use one of the fast tokenizers (backed by the tokenizers library)."},
+     )
+     from_config: bool = field(
+         default=True,
+         metadata={"help": "Whether to initialize models from scratch."},
+     )
+     dataset: Optional[str] = field(
+         default=None,
+         metadata={"help": "The dataset(s) to use. Use commas to separate multiple datasets."},
+     )
+     dataset_name: Optional[str] = field(
+         default=None,
+         metadata={"help": "The name of provided dataset(s) to use."},
+     )
+     cache_dir: str = field(
+         default=None,
+         metadata={"help": "Path to the cached tokenized dataset."},
+     )
+     split: str = field(
+         default="train",
+         metadata={"help": "Which dataset split to use for training and evaluation."},
+     )
+     streaming: bool = field(
+         default=False,
+         metadata={"help": "Enable dataset streaming."},
+     )
+     hf_hub_token: Optional[str] = field(
+         default=None,
+         metadata={"help": "Auth token to log in with Hugging Face Hub."},
+     )
+     preprocessing_num_workers: Optional[int] = field(
+         default=None,
+         metadata={"help": "The number of processes to use for the pre-processing."},
+     )
+     buffer_size: int = field(
+         default=2048,
+         metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."},
+     )
+     context_length: int = field(
+         default=2048,
+         metadata={"help": "The context length of the tokenized inputs in the dataset."},
+     )
+     varlen: bool = field(
+         default=False,
+         metadata={"help": "Enable training with variable length inputs."},
+     )
+
+
+ def get_train_args():
+     parser = HfArgumentParser(TrainingArguments)
+     args, unknown_args = parser.parse_args_into_dataclasses(return_remaining_strings=True)
+
+     if unknown_args:
+         print(parser.format_help())
+         print("Got unknown args, potentially deprecated arguments: {}".format(unknown_args))
+         raise ValueError("Some specified arguments are not used by the HfArgumentParser: {}".format(unknown_args))
+
+     if args.should_log:
+         transformers.utils.logging.set_verbosity(args.get_process_log_level())
+         transformers.utils.logging.enable_default_handler()
+         transformers.utils.logging.enable_explicit_format()
+     # set seeds manually
+     transformers.set_seed(args.seed)
+     return args
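Since `get_train_args` goes through `HfArgumentParser`, every dataclass field above doubles as a CLI flag; a minimal sketch of driving it programmatically (the argv values are illustrative placeholders):

```python
import sys

from flame.parser import get_train_args

# HfArgumentParser reads flags from sys.argv
sys.argv = [
    'train.py',
    '--output_dir', 'exp/debug',
    '--dataset', 'my_dataset',      # illustrative placeholder
    '--context_length', '4096',
]
args = get_train_args()
print(args.tokenizer, args.context_length, args.varlen)
```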