Commit af9535c (verified), committed by Erland
1 Parent(s): 933075b

Add files using upload-large-folder tool

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. README.md +512 -0
  2. config.json +33 -0
  3. configs/delta_net_1B.json +29 -0
  4. configs/delta_net_340M.json +27 -0
  5. configs/gla_340M.json +24 -0
  6. configs/gla_7B.json +25 -0
  7. configs/gsa_340M.json +29 -0
  8. configs/hgrn2_340M.json +20 -0
  9. configs/rectified_transformer_340M.json +19 -0
  10. configs/scaled_softpick_transformer_120M.json +19 -0
  11. configs/scaled_vanilla_transformer_120M.json +19 -0
  12. configs/scaled_vanilla_transformer_340M.json +19 -0
  13. configs/softpick_transformer_1B.json +23 -0
  14. download_checkpoint.py +35 -0
  15. fla/__pycache__/__init__.cpython-311.pyc +0 -0
  16. fla/__pycache__/utils.cpython-311.pyc +0 -0
  17. fla/layers/__init__.py +44 -0
  18. fla/layers/__pycache__/__init__.cpython-311.pyc +0 -0
  19. fla/layers/abc.py +218 -0
  20. fla/layers/attn.py +490 -0
  21. fla/layers/based.py +96 -0
  22. fla/layers/bitattn.py +192 -0
  23. fla/layers/forgetting_attn.py +109 -0
  24. fla/layers/gated_deltaproduct.py +351 -0
  25. fla/ops/utils/cumsum.py +400 -0
  26. flame/components/__init__.py +0 -0
  27. flame/config_manager.py +940 -0
  28. flame/data.py +570 -0
  29. flame/models/__init__.py +0 -0
  30. flame/models/__pycache__/__init__.cpython-311.pyc +0 -0
  31. flame/models/__pycache__/parallelize_fla.cpython-311.pyc +0 -0
  32. flame/models/__pycache__/pipeline_fla.cpython-311.pyc +0 -0
  33. flame/tools/__pycache__/utils.cpython-311.pyc +0 -0
  34. flame/utils/__init__.py +0 -0
  35. flame/utils/__pycache__/checkpoint.cpython-311.pyc +0 -0
  36. flame/utils/__pycache__/convert_dcp_to_hf.cpython-311.pyc +0 -0
  37. flame/utils/__pycache__/hf_utils.cpython-311.pyc +0 -0
  38. flame/utils/checkpoint.py +50 -0
  39. flame/utils/hf_utils.py +77 -0
  40. generation_config.json +6 -0
  41. logs/none_xprcuk_o/attempt_0/0/stdout.log +0 -0
  42. logs/none_xprcuk_o/attempt_0/1/stderr.log +0 -0
  43. logs/none_xprcuk_o/attempt_0/2/stderr.log +0 -0
  44. logs/none_xprcuk_o/attempt_0/2/stdout.log +0 -0
  45. logs/none_xprcuk_o/attempt_0/3/stdout.log +0 -0
  46. measure_sink_rate.py +137 -0
  47. passkey_retrieval.py +158 -0
  48. pyproject.toml +42 -0
  49. register_softpick.py +72 -0
  50. setup.py +50 -0
README.md ADDED
@@ -0,0 +1,512 @@
1
+ <div align="center">
2
+
3
+ # 🔥 Flame: Flash Linear Attention Made Easy
4
+ # This is a fork for the paper:
5
+ # Softpick: No Attention Sink, No Massive Activations with Rectified Softmax
6
+
7
+ </div>
8
+
9
+ ## Instructions for Softpick Attention
10
+
11
+ This fork only works with an older commit of torchtitan and flame, so the setup looks like this:
12
+
13
+ ```bash
14
+ git clone https://github.com/zaydzuhri/flame.git
15
+ cd flame
16
+ git checkout softpick-attention
17
+ git submodule update --init --recursive --remote
18
+ cd 3rdparty/torchtitan
19
+ git checkout 4f532e0
20
+ cd ../../
21
+
22
+ pip install .
23
+ pip install flash-attn --no-build-isolation
24
+ ```
25
+ The flash-linear-attention submodule has been changed to link to our fork: https://github.com/zaydzuhri/flash-linear-attention/tree/softpick-attention
26
+ So there is no need to clone it manually.
27
+
28
+ Then prepare the fineweb-edu 100B sample the same way as described in the flame repo guide below.
29
+
30
+ These are the training commands used in the paper:
31
+ ```bash
32
+ NGPU=8 bash train.sh --job.config_file flame/models/fla.toml --job.dump_folder exp/vanilla.340M.batch16.seqlen4096.context4096.warmup1000.update1.steps100000.lr3e-4.cosine --model.config configs/vanilla_transformer_340M.json --model.tokenizer_path fla-hub/transformer-1.3B-100B --optimizer.name AdamW --optimizer.eps 1e-15 --optimizer.lr 3e-4 --lr_scheduler.warmup_steps 1000 --lr_scheduler.lr_min 0.1 --lr_scheduler.decay_type cosine --training.batch_size 16 --training.seq_len 4096 --training.context_len 4096 --training.gradient_accumulation_steps 1 --training.steps 100000 --training.max_norm 1.0 --training.skip_nan_inf --training.dataset ~/.cache/HuggingFaceFW___fineweb-edu/sample-100BT --training.dataset_split train --training.num_workers 32 --training.prefetch_factor 2 --training.seed 79 --training.compile --checkpoint.interval 10000 --checkpoint.load_step -1 --metrics.log_freq 5 --checkpoint.hf_upload_enabled --checkpoint.hf_repo_base_name "zaydzuhri/vanilla-340M-4096-batch16-steps100000" --comm.init_timeout_seconds 600 --comm.train_timeout_seconds 300
33
+
34
+ NGPU=8 bash train.sh --job.config_file flame/models/fla.toml --job.dump_folder exp/softpick.340M.batch16.seqlen4096.context4096.warmup1000.update1.steps100000.lr3e-4.cosine --model.config configs/softpick_transformer_340M.json --model.tokenizer_path fla-hub/transformer-1.3B-100B --optimizer.name AdamW --optimizer.eps 1e-15 --optimizer.lr 3e-4 --lr_scheduler.warmup_steps 1000 --lr_scheduler.lr_min 0.1 --lr_scheduler.decay_type cosine --training.batch_size 16 --training.seq_len 4096 --training.context_len 4096 --training.gradient_accumulation_steps 1 --training.steps 100000 --training.max_norm 1.0 --training.skip_nan_inf --training.dataset ~/.cache/HuggingFaceFW___fineweb-edu/sample-100BT --training.dataset_split train --training.num_workers 32 --training.prefetch_factor 2 --training.seed 79 --training.compile --checkpoint.interval 10000 --checkpoint.load_step -1 --metrics.log_freq 5 --checkpoint.hf_upload_enabled --checkpoint.hf_repo_base_name "zaydzuhri/softpick-340M-4096-batch16-steps100000" --comm.init_timeout_seconds 600 --comm.train_timeout_seconds 300
35
+ ```
36
+
37
+ And the same for the extra experiments in the appendix:
38
+ ```bash
39
+ NGPU=8 bash train.sh --job.config_file flame/models/fla.toml --job.dump_folder exp/rectified.340M.batch16.seqlen4096.context4096.warmup1000.update1.steps100000.lr3e-4.cosine --model.config configs/rectified_transformer_340M.json --model.tokenizer_path fla-hub/transformer-1.3B-100B --optimizer.name AdamW --optimizer.eps 1e-15 --optimizer.lr 3e-4 --lr_scheduler.warmup_steps 1000 --lr_scheduler.lr_min 0.1 --lr_scheduler.decay_type cosine --training.batch_size 16 --training.seq_len 4096 --training.context_len 4096 --training.gradient_accumulation_steps 1 --training.steps 100000 --training.max_norm 1.0 --training.skip_nan_inf --training.dataset ~/.cache/HuggingFaceFW___fineweb-edu/sample-100BT --training.dataset_split train --training.num_workers 32 --training.prefetch_factor 2 --training.seed 79 --training.compile --checkpoint.interval 10000 --checkpoint.load_step -1 --metrics.log_freq 5 --checkpoint.hf_upload_enabled --checkpoint.hf_repo_base_name "zaydzuhri/rectified-340M-4096-batch16-steps100000" --comm.init_timeout_seconds 600 --comm.train_timeout_seconds 300
40
+
41
+ NGPU=8 bash train.sh --job.config_file flame/models/fla.toml --job.dump_folder exp/softpick.scaled.340M.batch16.seqlen4096.context4096.warmup1000.update1.steps100000.lr3e-4.cosine --model.config configs/softpick_scaled_transformer_340M.json --model.tokenizer_path fla-hub/transformer-1.3B-100B --optimizer.name AdamW --optimizer.eps 1e-15 --optimizer.lr 3e-4 --lr_scheduler.warmup_steps 1000 --lr_scheduler.lr_min 0.1 --lr_scheduler.decay_type cosine --training.batch_size 16 --training.seq_len 4096 --training.context_len 4096 --training.gradient_accumulation_steps 1 --training.steps 100000 --training.max_norm 1.0 --training.skip_nan_inf --training.dataset ~/.cache/HuggingFaceFW___fineweb-edu/sample-100BT --training.dataset_split train --training.num_workers 32 --training.prefetch_factor 2 --training.seed 79 --training.compile --checkpoint.interval 10000 --checkpoint.load_step -1 --metrics.log_freq 5 --checkpoint.hf_upload_enabled --checkpoint.hf_repo_base_name "zaydzuhri/softpick-scaled-340M-4096-batch16-steps100000" --comm.init_timeout_seconds 600 --comm.train_timeout_seconds 300
42
+ ```
43
+
44
+ Feel free to DM @zmkzmkz on X for any questions regarding the paper or this code!
45
+
46
+ ## Flame
47
+
48
+ Welcome to 🔥 `flame`, a minimal and efficient framework built on `torchtitan` for training Flash Linear Attention (FLA) models (and more broadly, arbitrary autoregressive language models) with blazing efficiency.
49
+
50
+ **Feature Highlights:**
51
+
52
+ - 🚀 Minimal, easy-to-use, extensible training framework
53
+ - 🤗 Seamless integration with `fla` and `transformers`
54
+ - 🔄 Zero-cost data preprocessing: online tokenization, dataset shuffling, and multi-dataset support
55
+ - 🔮 4D parallelism (coming soon)
56
+
57
+ ## Setup
58
+
59
+ To get started, clone the `flame` repository and install the required dependencies:
60
+
61
+ ```bash
62
+ git clone https://github.com/fla-org/flame.git
63
+ cd flame
64
+ pip install .
65
+ ```
66
+
67
+ `flame` keeps its dependencies minimal, including only `fla` and `torchtitan` as submodules.
68
+ After installation, initialize and update the submodules:
69
+ ```sh
70
+ git submodule update --init --recursive
71
+ ```
72
+
73
+ ## Dataset Preparation
74
+ To download the dataset to your local disk, create a new Python file with the following content and execute it:
75
+
76
+ ```py
77
+ from datasets import load_dataset
78
+
79
+ # load fineweb-edu with parallel processing
80
+ dataset = load_dataset("HuggingFaceFW/fineweb-edu", name="default", num_proc=64, cache_dir="/your/cache/path")
81
+
82
+ # or load a subset with roughly 100B tokens, suitable for small- or medium-sized experiments
83
+ dataset = load_dataset("HuggingFaceFW/fineweb-edu", name="sample-100BT", num_proc=64, cache_dir="/your/cache/path")
84
+ ```
85
+
86
+ ## Training Recipes
87
+
88
+ Here's an example of training a 340M FLA Transformer model with a LLaMA-like architecture from scratch on a 100BT subset of the Fineweb-edu corpus in streaming mode.
89
+
90
+ > [!WARNING]
91
+ > If the dataset is not downloaded beforehand, the streaming mode will attempt to fetch it from a remote server and download it on-the-fly, which can be highly unstable during training due to network issues.
92
+ > For stable training, ensure the dataset is downloaded locally (see [**Dataset Preparation**](#dataset-preparation)). Otherwise, we assume you are only testing the new corpus.
93
+
94
+ ```sh
95
+ bash train.sh \
96
+ --job.config_file flame/models/fla.toml \
97
+ --job.dump_folder exp/transformer-340M-4K-10B/batch1.seqlen65536.context4096.warmup1024.update1.steps20480.lr3e-4.cosine \
98
+ --model.config configs/transformer_340M.json \
99
+ --model.tokenizer_path fla-hub/transformer-1.3B-100B \
100
+ --optimizer.name AdamW \
101
+ --optimizer.eps 1e-15 \
102
+ --optimizer.lr 3e-4 \
103
+ --lr_scheduler.warmup_steps 1024 \
104
+ --lr_scheduler.lr_min 0.1 \
105
+ --lr_scheduler.decay_type cosine \
106
+ --training.batch_size 1 \
107
+ --training.seq_len 65536 \
108
+ --training.context_len 4096 \
109
+ --training.varlen \
110
+ --training.gradient_accumulation_steps 1 \
111
+ --training.steps 20480 \
112
+ --training.max_norm 1.0 \
113
+ --training.skip_nan_inf \
114
+ --training.dataset HuggingFaceFW/fineweb-edu \
115
+ --training.dataset_name sample-100BT \
116
+ --training.dataset_split train \
117
+ --training.streaming \
118
+ --training.num_workers 32 \
119
+ --training.prefetch_factor 2 \
120
+ --training.seed 42 \
121
+ --training.compile \
122
+ --checkpoint.interval 2048 \
123
+ --checkpoint.load_step -1 \
124
+ --checkpoint.keep_latest_k 2 \
125
+ --metrics.log_freq 1
126
+ ```
127
+
128
+ You can specify the number of GPUs by setting the environment variable `NGPU`, which defaults to 8.
129
+ **For single-GPU debugging, set `NGPU=1`.**
130
+
131
+ We provide several [config files](https://github.com/fla-org/flame/tree/main/configs) for different models.
132
+ By default, the learning rate is set to 3e-4 with a cosine scheduler. Other schedulers, such as WSD (wsd), are also supported.
133
+
134
+ **Key parameters:**
135
+ - `--lr_scheduler.decay_ratio`: The proportion of the steps allocated to the decay phase. The learning rate will remain stable after the warmup period and only start decaying during the last `decay_ratio` portion of the total training steps, which is known as the Warmup-Stable-Decay (WSD) schedule.
136
+ - `--lr_scheduler.warmup_steps`: The number of steps for the learning rate warmup phase.
137
+ - `--training.steps`: Total number of training steps.
138
+ - `--training.batch_size`: Batch size per device; must be 1 if `--training.varlen` is set.
139
+ - `--training.seq_len`: The length of each sequence in the batch, which is concatenated from multiple samples.
140
+ - `--training.context_len`: The max allowed length of a sample. For non-varlen mode, this is equivalent to `seq_len`.
141
+ - `--training.varlen`: Whether to conduct variable-length sequence training.
142
+ - `--training.gradient_accumulation_steps`: Number of gradient accumulation steps.
143
+
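To make the interaction between `warmup_steps`, `decay_ratio`, and `steps` concrete, here is a minimal sketch of a Warmup-Stable-Decay multiplier. It is an illustration only, not the scheduler code used by `flame`: the function name `wsd_lr_factor`, the linear warmup and decay shapes, and the treatment of `lr_min` as a ratio of the peak learning rate are all assumptions.

```python
def wsd_lr_factor(step, total_steps, warmup_steps, decay_ratio, lr_min=0.1):
    """Return the multiplier applied to the peak learning rate at `step`.

    Hypothetical helper: linear warmup, constant plateau, then linear decay
    to `lr_min` over the last `decay_ratio` fraction of training.
    """
    decay_steps = int(total_steps * decay_ratio)
    decay_start = total_steps - decay_steps
    if step < warmup_steps:
        # Warmup: ramp linearly from 0 up to 1.
        return step / max(1, warmup_steps)
    if step < decay_start:
        # Stable phase: hold the peak learning rate.
        return 1.0
    # Decay phase: interpolate linearly from 1 down to lr_min.
    progress = (step - decay_start) / max(1, decay_steps)
    return 1.0 - (1.0 - lr_min) * min(1.0, progress)


peak_lr = 3e-4
for s in (0, 512, 1024, 10000, 18432, 20480):
    factor = wsd_lr_factor(s, total_steps=20480, warmup_steps=1024, decay_ratio=0.1)
    print(s, peak_lr * factor)
```

With these example values, the learning rate ramps up for the first 1024 steps, stays at the peak until step 18432, and decays only over the final 2048 steps.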
144
+ > [!WARNING]
145
+ > The total number of sequences processed per step, referred to as `global_batch_size`, is calculated as batch_size × gradient_accumulation_steps × num_gpus.
146
+ > Each step processes `global_batch_size * seq_len` tokens.
147
+ > Monitor the values of `global_batch_size`, `warmup_steps`, and `steps` carefully when modifying any of the hyperparameters!
148
+
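As a quick sanity check of the formula in the warning above, the following sketch computes the tokens consumed per step for two of the configurations shown in this README. The helper name `tokens_per_step` is hypothetical, and 8 GPUs are assumed in both cases (the `NGPU` default and the value used in the Softpick commands).

```python
def tokens_per_step(batch_size, gradient_accumulation_steps, num_gpus, seq_len):
    """Tokens consumed by one optimizer step, following the formula above."""
    global_batch_size = batch_size * gradient_accumulation_steps * num_gpus
    return global_batch_size * seq_len


# Varlen recipe above: batch_size=1, seq_len=65536, 8 GPUs, 20480 steps.
recipe = tokens_per_step(1, 1, 8, 65536)
# Softpick paper commands: batch_size=16, seq_len=4096, 8 GPUs, 100000 steps.
paper = tokens_per_step(16, 1, 8, 4096)

print(recipe, recipe * 20480)   # 524288 tokens per step, 10737418240 in total
print(paper, paper * 100000)    # 524288 tokens per step, 52428800000 in total
```

Both setups process the same number of tokens per step; they differ only in how those tokens are split between batch size and sequence length, and in the number of steps.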
149
+ For a detailed explanation of all parameters, run:
150
+
151
+ ```sh
152
+ bash train.sh -h
153
+ ```
154
+
155
+ <details>
156
+ <summary>Usage</summary>
157
+
158
+ ```py
159
+ options:
160
+ -h, --help show this help message and exit
161
+ --job.config_file JOB.CONFIG_FILE
162
+ Job config file
163
+ --job.dump_folder JOB.DUMP_FOLDER
164
+ Folder to dump job outputs
165
+ --job.description JOB.DESCRIPTION
166
+ Description of the job
167
+ --job.use_for_integration_test
168
+ Add this config to the integration test suite
169
+ --job.print_args Print the args to terminal
170
+ --model.config MODEL.CONFIG
171
+ Path to the model config
172
+ --model.norm_type MODEL.NORM_TYPE
173
+ Type of layer normalization to use [layernorm,
174
+ np_layernorm, rmsnorm, fused_rmsnorm]
175
+ --model.tokenizer_path MODEL.TOKENIZER_PATH
176
+ Tokenizer path
177
+ --profiling.enable_profiling
178
+ Whether to enable pytorch profiler
179
+ --profiling.save_traces_folder PROFILING.SAVE_TRACES_FOLDER
180
+ Trace files location
181
+ --profiling.profile_freq PROFILING.PROFILE_FREQ
182
+ How often to collect profiler traces, in iterations
183
+ --profiling.enable_memory_snapshot
184
+ Whether to dump memory snapshot
185
+ --profiling.save_memory_snapshot_folder PROFILING.SAVE_MEMORY_SNAPSHOT_FOLDER
186
+ Memory snapshot files location
187
+ --optimizer.name OPTIMIZER.NAME
188
+ Optimizer to use
189
+ --optimizer.eps OPTIMIZER.EPS
190
+ Epsilon value for the optimizer.
191
+ --optimizer.fused Whether the fused implementation(CUDA only) is used.
192
+ --optimizer.scheduler {wsd,cosine,linear}
193
+ Scheduler to use. Currently supported: wsd, cosine,
194
+ and linear.
195
+ --optimizer.lr OPTIMIZER.LR
196
+ Learning rate to use
197
+ --optimizer.min_lr_ratio OPTIMIZER.MIN_LR_RATIO
198
+ Min lr ratio for lr scheduler
199
+ --optimizer.early_step_in_backward
200
+ Whether to apply optimizer in the backward. Caution,
201
+ optimizer_in_backward is not compatible with gradients
202
+ clipping, users should not call
203
+ register_post_accumulate_grad_hook after the optimizer
204
+ is built.
205
+ --training.batch_size TRAINING.BATCH_SIZE
206
+ Batch size
207
+ --training.seq_len TRAINING.SEQ_LEN
208
+ Sequence length
209
+ --training.context_len TRAINING.CONTEXT_LEN
210
+ Max length allowed for each sequence
211
+ --training.varlen Whether to take sequences of variable length as input
212
+ --training.warmup_steps TRAINING.WARMUP_STEPS
213
+ Steps for lr scheduler warmup, normally 1/5 of
214
+ --training.steps
215
+ --training.gradient_accumulation_steps TRAINING.GRADIENT_ACCUMULATION_STEPS
216
+ Number of steps to accumulate gradients before
217
+ updating parameters
218
+ --training.steps TRAINING.STEPS
219
+ How many train steps to run
220
+ --training.max_norm TRAINING.MAX_NORM
221
+ Max norm for gradient clipping
222
+ --training.skip_nan_inf
223
+ Skip batch updates when NaN or INF gradients are
224
+ encountered during training
225
+ --training.dataset TRAINING.DATASET
226
+ Dataset to use, with comma separated values
227
+ --training.dataset_name TRAINING.DATASET_NAME
228
+ The name of the dataset config, with comma separated
229
+ values if provided
230
+ --training.dataset_split TRAINING.DATASET_SPLIT
231
+ Dataset split to use, with comma separated values if
232
+ provided
233
+ --training.data_dir TRAINING.DATA_DIR
234
+ Data dirs to use, with comma separated values if
235
+ provided
236
+ --training.data_files TRAINING.DATA_FILES
237
+ Data files to use, with comma separated values if
238
+ provided
239
+ --training.data_probs TRAINING.DATA_PROBS
240
+ Data sampling probabilities, with comma separated
241
+ values if provided
242
+ --training.streaming Whether to load dataset in streaming mode, used for
243
+ huge dataset
244
+ --training.num_workers TRAINING.NUM_WORKERS
245
+ Number of subprocesses to use for data loading. 0
246
+ means that the data will be loaded in the main
247
+ process.
248
+ --training.prefetch_factor TRAINING.PREFETCH_FACTOR
249
+ Number of batches loaded in advance by each worker. 2
250
+ means there will be a total of 2 * num_workers batches
251
+ prefetched across all workers.
252
+ --training.data_parallel_replicate_degree TRAINING.DATA_PARALLEL_REPLICATE_DEGREE
253
+ The `data_parallel_replicate_degree` argument
254
+ specifies the degree of data parallelism for weight
255
+ replication. When this value is greater than 1,
256
+ weights will be replicated across
257
+ `data_parallel_replicate_degree` ranks. If
258
+ `data_parallel_shard_degree` is also greater than 1,
259
+ the parallelism method used is HSDP (Hybrid Sharded
260
+ Data Parallelism). Otherwise, the parallelism method
261
+ used is DDP (Distributed Data Parallelism). 1 means
262
+ disabled.
263
+ --training.data_parallel_shard_degree TRAINING.DATA_PARALLEL_SHARD_DEGREE
264
+ The `data_parallel_shard_degree` argument specifies
265
+ the degree of data parallelism for weight sharding.
266
+ When this value is greater than 1, weights will be
267
+ sharded across `data_parallel_shard_degree` ranks. If
268
+ `data_parallel_replicate_degree` is also greater than
269
+ 1, the parallelism method used is HSDP (Hybrid Sharded
270
+ Data Parallelism). Otherwise, the parallelism method
271
+ used is FSDP (Fully Sharded Data Parallelism). -1
272
+ means leftover ranks will be used (After
273
+ DP_REPLICATE/SP/PP). Note that only
274
+ `data_parallel_shard_degree` can be negative. 1 means
275
+ disabled.
276
+ --training.enable_cpu_offload
277
+ Whether to apply CPU offloading of parameters,
278
+ gradients, and optimizer states in FSDP
279
+ --training.tensor_parallel_degree TRAINING.TENSOR_PARALLEL_DEGREE
280
+ Tensor Parallelism degree. 1 means disabled.
281
+ --training.disable_loss_parallel
282
+ Whether to apply loss parallel when sequence parallel
283
+ is enabled
284
+ --training.mixed_precision_param {bfloat16,float32}
285
+ torch dtype to use for parameters when applying mixed
286
+ precision via FSDP. This feature only takes effect
287
+ when data_parallel_shard_degree > 1
288
+ --training.mixed_precision_reduce {float32}
289
+ torch dtype to use for reductions when applying mixed
290
+ precision via FSDP. This feature only takes effect
291
+ when data_parallel_shard_degree > 1
292
+ --training.compile Whether to compile the model
293
+ --training.gc_freq TRAINING.GC_FREQ
294
+ Python garbage control scheduling interval, in steps
295
+ --training.seed TRAINING.SEED
296
+ Choose the base RNG seed used for training
297
+ --training.deterministic
298
+ Use deterministic algorithms wherever possible, may be
299
+ slower
300
+ --metrics.log_freq METRICS.LOG_FREQ
301
+ How often to log metrics to TensorBoard, in iterations
302
+ --metrics.enable_tensorboard
303
+ Whether to log metrics to TensorBoard
304
+ --metrics.disable_color_printing
305
+ Whether to disable color printing in logs
306
+ --metrics.save_tb_folder METRICS.SAVE_TB_FOLDER
307
+ Folder to dump TensorBoard states
308
+ --metrics.rank_0_only
309
+ Whether to save TensorBoard metrics only for rank 0 or
310
+ for all ranks. When pipeline_parallel_degree is > 1,
311
+ this option uses the 0th rank of the last stage
312
+ pipeline group, which is the only stage that computes
313
+ loss metrics.
314
+ --metrics.enable_wandb
315
+ Whether to log metrics to Weights & Biases
316
+ --experimental.enable_async_tensor_parallel
317
+ Whether to apply async tensor parallel (currently only
318
+ effective when compile is enabled)
319
+ --experimental.pipeline_parallel_degree EXPERIMENTAL.PIPELINE_PARALLEL_DEGREE
320
+ Pipeline Parallelism degree, or number of ranks. 1
321
+ means disabled. If using looped schedules, this still
322
+ specifies the number of physical ranks, not the number
323
+ of stages. Stages per rank are inferred from split
324
+ points degree, and schedule.
325
+ --experimental.pipeline_parallel_split_points EXPERIMENTAL.PIPELINE_PARALLEL_SPLIT_POINTS [EXPERIMENTAL.PIPELINE_PARALLEL_SPLIT_POINTS ...]
326
+ Specify comma-separated names of modules to use as the
327
+ beginning of a split point. e.g. "layers.0,layers.2"
328
+ will cause the model to be split into 3 stages, the
329
+ first containing all the layers up to layers.0, the
330
+ second containing layers.0 and up to layers.2, the
331
+ third containing layers.2 and all the remaining
332
+ layers. Note: fully-automated splitting may be enabled
333
+ in the future, but currently the split points must be
334
+ specified manually.
335
+ --experimental.pipeline_parallel_schedule EXPERIMENTAL.PIPELINE_PARALLEL_SCHEDULE
336
+ Specify the Pipeline Parallel schedule to use. The
337
+ supported schedules are: https://github.com/pytorch/py
338
+ torch/blob/de4c2a3b4e89d96334dc678d1c3f2ae51a6630a0/to
339
+ rch/distributed/pipelining/schedules.py#L2161. The
340
+ schedule must be compatible with the split points and
341
+ stages_per_rank. Looped schedules (e.g.
342
+ Interleaved1F1B) require specifying
343
+ pipeline_parallel_degree = number of ranks, and
344
+ split_points = number of stages - 1
345
+ --experimental.pipeline_parallel_schedule_csv EXPERIMENTAL.PIPELINE_PARALLEL_SCHEDULE_CSV
346
+ Specify the path to the pipeline parallel schedule csv
347
+ file to use. The pipeline_parallel_schedule argument
348
+ must be either PipelineScheduleSingle,
349
+ PipelineScheduleMulti, or _PipelineScheduleRuntime.
350
+ --experimental.pipeline_parallel_microbatches EXPERIMENTAL.PIPELINE_PARALLEL_MICROBATCHES
351
+ How many microbatches to split the global training
352
+ batch into when using pipeline parallelism. The global
353
+ training batch size must be evenly divisible by the
354
+ number of microbatches. The default value will be the
355
+ number of pipeline stages, if unspecified.
356
+ --experimental.enable_compiled_autograd
357
+ Enable CompiledAutograd to compile the backward.
358
+ --experimental.context_parallel_degree EXPERIMENTAL.CONTEXT_PARALLEL_DEGREE
359
+ Context parallelism degree. 1 means disabled.
360
+ --experimental.context_parallel_rotate_method EXPERIMENTAL.CONTEXT_PARALLEL_ROTATE_METHOD
361
+ The collective to use in context parallel SDPA for kv
362
+ shards exchange. 'allgather' means to all-gather all
363
+ kv shards on ranks after the first sub-SDPA
364
+ computation, 'alltoall' means to all-to-all shuffle
365
+ the kv shards. The default value is 'allgather'.
366
+ --checkpoint.enable_checkpoint
367
+ Whether to enable checkpoint
368
+ --checkpoint.folder CHECKPOINT.FOLDER
369
+ The folder to store the checkpoints. When
370
+ enable_checkpoint is set to true, checkpoints will be
371
+ in {--job.dump_folder}/{--checkpoint.folder}.
372
+ --checkpoint.interval_type CHECKPOINT.INTERVAL_TYPE
373
+ Checkpointing interval unit of measurement ['step',
374
+ 'seconds']
375
+ --checkpoint.interval CHECKPOINT.INTERVAL
376
+ Checkpointing interval, in steps or seconds depending
377
+ on --checkpoint.interval_type
378
+ --checkpoint.model_weights_only
379
+ When model_weights_only=True, only model weights will
380
+ be saved at the end of training. With this,
381
+ checkpoints can be loaded using `torch.load(...,
382
+ weights_only=True)` after conversion. When
383
+ model_weights_only=False, the full checkpoint will be
384
+ saved. A full checkpoint includes model, optimizer and
385
+ train_state, which can be used to resume training. The
386
+ default value is false.
387
+ --checkpoint.export_dtype {float16,bfloat16,float32}
388
+ Converts to the specified precision when training
389
+ completes and model_weights_only=true. Currently
390
+ supports float32, float16, and bfloat16. The default
391
+ value is float32.
392
+ --checkpoint.create_seed_checkpoint
393
+ Initializes the full model without applying
394
+ parallelisms, and then saves it as a seed checkpoint.
395
+ Note: requires user to call train.py without
396
+ specifying any parallelisms, e.g. NGPU=1. Could be
397
+ implemented as a separate script, but this way shares
398
+ more code.
399
+ --checkpoint.async_mode CHECKPOINT.ASYNC_MODE
400
+ Which async checkpoint mode to use. Currently there
401
+ are 3 different modes. 1. "disabled": synchronized
402
+ checkpointing will be used. 2. "async":
403
+ torch.distributed.checkpoint.async_save will be used.
404
+ 3. "async_with_pinned_mem": this option utilizes a
405
+ dedicated pinned memory space and creates a separate
406
+ process for faster GPU->CPU transfer performance and
407
+ eliminating GIL contention. The cost is increased CPU
408
+ memory usage. If insufficient CPU memory is available,
409
+ performance may degrade due to memory paging. For most
410
+ users, "async" should suffice as the performance
411
+ overhead is typically small (on the order of tens of
412
+ seconds) compared to checkpointing frequency. This
413
+ mode can be employed to pursue near-zero checkpointing
414
+ times (e.g., < 1 second) given appropriate hardware
415
+ support such as ample CPU memory and fast PCIe.
416
+ "disabled" is the default mode.
417
+ --checkpoint.keep_latest_k CHECKPOINT.KEEP_LATEST_K
418
+ Keeps only the latest k checkpoints, and purges older
419
+ ones. If 0, keep all checkpoints. 0 is the default
420
+ value.
421
+ --checkpoint.load_step CHECKPOINT.LOAD_STEP
422
+ Load the checkpoint at the specified step. If -1, load
423
+ the latest checkpoint.
424
+ --float8.enable_float8_linear
425
+ If true, swaps `torch.nn.Linear` with `Float8Linear`.
426
+ This feature requires you to install 'torchao' which
427
+ can be found here: https://github.com/pytorch/ao
428
+ --float8.enable_fsdp_float8_all_gather
429
+ Whether enable float8 all-gather in FSDP
430
+ --float8.precompute_float8_dynamic_scale_for_fsdp
431
+ Whether precompute float8 scales dynamically for FSDP
432
+ --float8.scaling_type_input {dynamic,delayed}
433
+ float8 scaling for input, dynamic (default) or delayed
434
+ --float8.scaling_type_weight FLOAT8.SCALING_TYPE_WEIGHT
435
+ float8 scaling for weight, dynamic (default) or delayed
436
+ --float8.scaling_type_grad_output FLOAT8.SCALING_TYPE_GRAD_OUTPUT
437
+ float8 scaling for grad output, dynamic (default) or delayed
438
+ --comm.init_timeout_seconds COMM.INIT_TIMEOUT_SECONDS
439
+ Timeout for communication operations, during
440
+ initialization and first train step.
441
+ --comm.train_timeout_seconds COMM.TRAIN_TIMEOUT_SECONDS
442
+ Timeout for communication operations after the first
443
+ train step -- usually a tighter bound than during
444
+ initialization.
445
+ --comm.trace_buf_size COMM.TRACE_BUF_SIZE
446
+ Flight recorder ring buffer size, >0 means recording
447
+ by default, 0 means disabled
448
+ --memory_estimation.enabled
449
+ Whether to estimate memory usage for FSDP
450
+ --memory_estimation.disable_fake_mode
451
+ Whether to estimate memory under FakeTensorMode
452
+ ```
453
+ </details>
454
+
455
+ ### Training with `torch.compile`
456
+
457
+ Since `torch` 2.0, `torch.compile` has been available as a way to accelerate training with minimal code changes.
458
+ In `flame`, you can enable `torch.compile` by adding the `--training.compile` flag to your training command.
459
+
460
+ However, `fla` integrates numerous fused kernels for acceleration, which may conflict with `torch.compile`.
461
+ We are actively working on resolving these issues to make compilation transparent to users.
462
+ In the meantime, please ensure you are using the latest dependencies.
463
+
464
+ Specifically, **we recommend using `torch>=2.6` and `triton>=3.0`**.
465
+
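One way to check an environment against the recommended minimums before enabling `--training.compile` is sketched below. The helpers `version_tuple` and `meets_minimum` are hypothetical and not part of `flame`; they compare only the leading numeric components of a version string, so local suffixes such as `+cu124` are ignored.

```python
import re


def version_tuple(version):
    """Parse the leading numeric components, e.g. '2.6.0+cu124' -> (2, 6, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def meets_minimum(installed, minimum):
    """True if `installed` is at least `minimum`, comparing numeric parts only."""
    a, b = version_tuple(installed), version_tuple(minimum)
    width = max(len(a), len(b))
    # Pad with zeros so that '3.0' and '3.0.0' compare as equal.
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


print(meets_minimum("2.6.0+cu124", "2.6"))  # True
print(meets_minimum("2.5.1", "2.6"))        # False
```

In a real environment, the installed versions could be read with `importlib.metadata.version("torch")` and `importlib.metadata.version("triton")` and passed to `meets_minimum`.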
466
+ ### Training with multiple datasets
467
+
468
+ If you wish to train a model with all-round capabilities (e.g., code, math, and multilingual ability), it's necessary to train on multiple datasets.
469
+ `flame` allows training with multiple datasets easily.
470
+ For example, you can specify the following arguments to train on 6 datasets with different proportions:
471
+
472
+ ```sh
473
+ --training.dataset HuggingFaceFW/fineweb-edu,opencsg/Fineweb-Edu-Chinese-V2.1,OpenCoder-LLM/opc-fineweb-code-corpus,math-ai/AutoMathText,EleutherAI/proof-pile-2,OpenCoder-LLM/opc-fineweb-math-corpus \
474
+ --training.data_probs 0.6,0.15,0.15,0.014,0.058,0.028 \
475
+ ```
476
+
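The two arguments above are positional: the i-th probability applies to the i-th dataset. The sketch below pairs them up and samples source datasets according to the weights. It is an illustration under assumptions, not the data loader in `flame`: the helper `parse_mixture` is hypothetical, and whether `flame` itself requires the probabilities to sum to 1 is not stated here.

```python
import math
import random


def parse_mixture(datasets, data_probs):
    """Pair comma-separated dataset names with their sampling probabilities."""
    names = datasets.split(",")
    probs = [float(p) for p in data_probs.split(",")]
    if len(names) != len(probs):
        raise ValueError(f"{len(names)} datasets but {len(probs)} probabilities")
    if not math.isclose(sum(probs), 1.0, abs_tol=1e-6):
        raise ValueError(f"probabilities sum to {sum(probs)}, expected 1.0")
    return dict(zip(names, probs))


mixture = parse_mixture(
    "HuggingFaceFW/fineweb-edu,opencsg/Fineweb-Edu-Chinese-V2.1,"
    "OpenCoder-LLM/opc-fineweb-code-corpus,math-ai/AutoMathText,"
    "EleutherAI/proof-pile-2,OpenCoder-LLM/opc-fineweb-math-corpus",
    "0.6,0.15,0.15,0.014,0.058,0.028",
)

# Draw a few source datasets according to the mixture weights.
rng = random.Random(0)
print(rng.choices(list(mixture), weights=list(mixture.values()), k=5))
```

The six probabilities in the example sum to 1, so the check passes; a mismatch between the number of datasets and the number of probabilities would be reported before any data is loaded.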
477
+ ### ~Finalizing training~
478
+
479
+ > [!NOTE]
480
+ > Since our latest updates, the training script performs this conversion automatically.
481
+
482
+ Once training is complete, you may want to convert the distributed checkpoints (DCPs) into the 🤗 format for broader use.
483
+ To facilitate this, we provide a straightforward conversion script:
484
+
485
+ ```sh
486
+ python -m flame.utils.convert_dcp_to_hf --path <path_to_model> --step <step> --config <path_to_config> --tokenizer <path_to_tokenizer>
487
+ ```
488
+ After this, your model will be in the 🤗 format, ready to be shared or deployed.
489
+ You can then publish your model with `huggingface_hub` for wider accessibility.
490
+
491
+ ### Continual training
492
+
493
+ If you wish to build upon a strong pre-trained model (in 🤗 format) and continue training, we also offer a script to convert the 🤗 format model back into DCP format.
494
+ This allows you to seamlessly resume training with `flame`.
495
+ ```sh
496
+ python -m flame.utils.convert_hf_to_dcp --model <path_to_hf> --checkpoint <path_to_dcp/checkpoint/step-0>
497
+ ```
498
+ Here, `<path_to_dcp>` is the directory where your distributed checkpoints will be stored.
499
+ The checkpoint is intentionally saved at `<step-0>` within the checkpoint folder to ensure it is loadable by `flame` during the initial training step, similar to how a seed checkpoint is handled.
500
+
501
+ Once the conversion is complete, you can proceed with training using `flame` as usual, continuing from where the pretrained model left off.
502
+
503
+ ## Multi-node training
504
+
505
+ If you have access to GPUs on multiple nodes, consider using them to speed up training.
506
+ This process is straightforward and well-documented in the PyTorch [docs](https://pytorch.org/docs/stable/elastic/run.html).
507
+
508
+ To set up multi-node training:
509
+ * Set the environment variables `MASTER_ADDR=<ip>` and `MASTER_PORT=<port>` before running the training script across all nodes.
510
+ * If you're using a job scheduler like Slurm, it will handle these variables for you.
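The two variables above can be checked on each node before launching. The following is a minimal sketch, assuming only that `MASTER_ADDR` and `MASTER_PORT` are exported as described; the helper name `rendezvous_endpoint` and the sample address are hypothetical and not part of `flame` or `torchtitan`.

```python
import os


def rendezvous_endpoint(env=None) -> str:
    """Return "<addr>:<port>" from MASTER_ADDR / MASTER_PORT, or fail loudly."""
    env = os.environ if env is None else env
    addr = env.get("MASTER_ADDR")
    port = env.get("MASTER_PORT")
    if not addr or not port:
        raise RuntimeError("MASTER_ADDR and MASTER_PORT must be set on every node")
    # int() rejects a non-numeric port before the launcher does.
    return f"{addr}:{int(port)}"


# Example with an explicit mapping instead of the real environment.
print(rendezvous_endpoint({"MASTER_ADDR": "10.0.0.1", "MASTER_PORT": "29500"}))
```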
511
+
512
+ `torchtitan` provides a [Slurm script](https://github.com/pytorch/torchtitan/blob/main/multinode_trainer.slurm) for multi-node training, which you can use as a reference or starting point.
config.json ADDED
@@ -0,0 +1,33 @@
1
+ {
2
+ "architectures": [
3
+ "TransformerForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attn_impl": "parallel_softpick_attn",
7
+ "bos_token_id": 1,
8
+ "elementwise_affine": true,
9
+ "eos_token_id": 2,
10
+ "fuse_cross_entropy": false,
11
+ "fuse_norm": false,
12
+ "fuse_swiglu": true,
13
+ "hidden_act": "swish",
14
+ "hidden_ratio": 4,
15
+ "hidden_size": 768,
16
+ "initializer_range": 0.02,
17
+ "intermediate_size": null,
18
+ "max_position_embeddings": 4096,
19
+ "model_type": "transformer",
20
+ "norm_eps": 1e-06,
21
+ "num_heads": 12,
22
+ "num_hidden_layers": 14,
23
+ "num_kv_heads": null,
24
+ "qk_norm": false,
25
+ "qkv_bias": false,
26
+ "rope_theta": 10000.0,
27
+ "tie_word_embeddings": true,
28
+ "torch_dtype": "float32",
29
+ "transformers_version": "4.51.3",
30
+ "use_cache": true,
31
+ "vocab_size": 32000,
32
+ "window_size": null
33
+ }
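To make the attention-related fields of this config concrete, the sketch below derives the per-head sizes from `hidden_size`, `num_heads`, and `num_kv_heads`. It mirrors the arithmetic in `Attention.__init__` from `fla/layers/attn.py` in this commit; the helper name `attention_dims` is hypothetical and only a subset of the config is reproduced.

```python
import json

# Subset of the config.json added in this commit.
CONFIG = json.loads('{"hidden_size": 768, "num_heads": 12, "num_kv_heads": null}')


def attention_dims(cfg: dict) -> dict:
    """Derive head and key/value sizes the same way Attention.__init__ does."""
    num_heads = cfg["num_heads"]
    num_kv_heads = cfg["num_kv_heads"]
    if num_kv_heads is None:
        # A null num_kv_heads falls back to standard multi-head attention.
        num_kv_heads = num_heads
    head_dim = cfg["hidden_size"] // num_heads
    return {
        "head_dim": head_dim,
        "num_kv_groups": num_heads // num_kv_heads,
        "kv_dim": num_kv_heads * head_dim,
    }


dims = attention_dims(CONFIG)
print(dims)
```

With `num_kv_heads` left as `null`, the key and value projections keep the full `hidden_size` width, so no grouped-query sharing takes place.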
configs/delta_net_1B.json ADDED
@@ -0,0 +1,29 @@
1
+ {
2
+ "attn": null,
3
+ "attn_mode": "chunk",
4
+ "bos_token_id": 1,
5
+ "conv_size": 4,
6
+ "eos_token_id": 2,
7
+ "expand_k": 1,
8
+ "expand_v": 1,
9
+ "fuse_cross_entropy": true,
10
+ "fuse_norm": true,
11
+ "hidden_act": "swish",
12
+ "hidden_ratio": 4,
13
+ "hidden_size": 2048,
14
+ "initializer_range": 0.006,
15
+ "intermediate_size": null,
16
+ "model_type": "delta_net",
17
+ "norm_eps": 1e-06,
18
+ "num_heads": 16,
19
+ "num_hidden_layers": 24,
20
+ "pad_token_id": 2,
21
+ "qk_activation": "silu",
22
+ "qk_norm": "l2",
23
+ "tie_word_embeddings": false,
24
+ "use_beta": true,
25
+ "use_cache": true,
26
+ "use_gate": false,
27
+ "use_output_norm": true,
28
+ "use_short_conv": true
29
+ }
configs/delta_net_340M.json ADDED
@@ -0,0 +1,27 @@
1
+ {
2
+ "attn_mode": "chunk",
3
+ "bos_token_id": 1,
4
+ "conv_size": 4,
5
+ "eos_token_id": 2,
6
+ "expand_k": 1,
7
+ "expand_v": 1,
8
+ "fuse_cross_entropy": true,
9
+ "hidden_act": "swish",
10
+ "hidden_ratio": 4,
11
+ "hidden_size": 1024,
12
+ "initializer_range": 0.006,
13
+ "intermediate_size": null,
14
+ "model_type": "delta_net",
15
+ "norm_eps": 1e-06,
16
+ "norm_first": false,
17
+ "num_heads": 8,
18
+ "num_hidden_layers": 24,
19
+ "qk_activation": "silu",
20
+ "qk_norm": "l2",
21
+ "tie_word_embeddings": false,
22
+ "use_beta": true,
23
+ "use_cache": true,
24
+ "use_gate": false,
25
+ "use_output_norm": true,
26
+ "use_short_conv": true
27
+ }
configs/gla_340M.json ADDED
@@ -0,0 +1,24 @@
1
+ {
2
+ "attn_mode": "chunk",
3
+ "bos_token_id": 1,
4
+ "clamp_min": null,
5
+ "eos_token_id": 2,
6
+ "expand_k": 0.5,
7
+ "expand_v": 1,
8
+ "fuse_cross_entropy": true,
9
+ "fuse_norm": true,
10
+ "hidden_act": "swish",
11
+ "hidden_ratio": 4,
12
+ "hidden_size": 1024,
13
+ "initializer_range": 0.006,
14
+ "intermediate_size": null,
15
+ "model_type": "gla",
16
+ "num_heads": 4,
17
+ "num_hidden_layers": 24,
18
+ "norm_eps": 1e-06,
19
+ "tie_word_embeddings": false,
20
+ "use_cache": true,
21
+ "use_gk": true,
22
+ "use_gv": false,
23
+ "vocab_size": 32000
24
+ }
configs/gla_7B.json ADDED
@@ -0,0 +1,25 @@
1
+ {
2
+ "attn": null,
3
+ "attn_mode": "chunk",
4
+ "bos_token_id": 1,
5
+ "eos_token_id": 2,
6
+ "expand_k": 0.5,
7
+ "expand_v": 1,
8
+ "fuse_cross_entropy": true,
9
+ "fuse_norm": true,
10
+ "hidden_act": "swish",
11
+ "hidden_ratio": 4,
12
+ "hidden_size": 4096,
13
+ "initializer_range": 0.006,
14
+ "intermediate_size": 11008,
15
+ "model_type": "gla",
16
+ "norm_eps": 1e-06,
17
+ "num_heads": 16,
18
+ "num_hidden_layers": 32,
19
+ "tie_word_embeddings": false,
20
+ "use_cache": true,
21
+ "use_gk": true,
22
+ "use_gv": false,
23
+ "use_output_gate": true,
24
+ "use_short_conv": false
25
+ }
configs/gsa_340M.json ADDED
@@ -0,0 +1,29 @@
1
+ {
2
+ "bos_token_id": 1,
3
+ "conv_size": 4,
4
+ "eos_token_id": 2,
5
+ "expand_k": 1,
6
+ "expand_v": 1,
7
+ "elementwise_affine": false,
8
+ "feature_map": "swish",
9
+ "fuse_cross_entropy": true,
10
+ "fuse_norm": true,
11
+ "gate_logit_normalizer": 4,
12
+ "hidden_act": "swish",
13
+ "hidden_ratio": 4,
14
+ "hidden_size": 1024,
15
+ "initializer_range": 0.006,
16
+ "intermediate_size": null,
17
+ "model_type": "gsa",
18
+ "num_heads": 4,
19
+ "num_hidden_layers": 24,
20
+ "num_slots": 64,
21
+ "norm_eps": 1e-06,
22
+ "share_conv_kernel": true,
23
+ "tie_word_embeddings": false,
24
+ "use_cache": true,
25
+ "use_norm": true,
26
+ "use_output_gate": true,
27
+ "use_rope": false,
28
+ "use_short_conv": false
29
+ }
configs/hgrn2_340M.json ADDED
@@ -0,0 +1,20 @@
1
+ {
2
+ "attn_mode": "chunk",
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "expand_ratio": 128,
6
+ "fuse_cross_entropy": true,
7
+ "fuse_norm": true,
8
+ "hidden_act": "swish",
9
+ "hidden_ratio": 4,
10
+ "hidden_size": 1024,
11
+ "initializer_range": 0.006,
12
+ "intermediate_size": null,
13
+ "model_type": "hgrn2",
14
+ "num_heads": 8,
15
+ "num_hidden_layers": 24,
16
+ "norm_eps": 1e-06,
17
+ "tie_word_embeddings": false,
18
+ "use_cache": true,
19
+ "vocab_size": 32000
20
+ }
configs/rectified_transformer_340M.json ADDED
@@ -0,0 +1,19 @@
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "hidden_act": "swish",
8
+ "hidden_size": 1024,
9
+ "initializer_range": 0.006,
10
+ "max_position_embeddings": 8192,
11
+ "model_type": "transformer",
12
+ "num_heads": 16,
13
+ "num_hidden_layers": 24,
14
+ "norm_eps": 1e-06,
15
+ "tie_word_embeddings": false,
16
+ "use_cache": true,
17
+ "vocab_size": 32000,
18
+ "attn_impl": "parallel_rectified_attn"
19
+ }
configs/scaled_softpick_transformer_120M.json ADDED
@@ -0,0 +1,19 @@
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": false,
7
+ "hidden_act": "swish",
8
+ "hidden_size": 768,
9
+ "initializer_range": 0.02,
10
+ "max_position_embeddings": 4096,
11
+ "model_type": "transformer",
12
+ "num_heads": 12,
13
+ "num_hidden_layers": 14,
14
+ "norm_eps": 1e-06,
15
+ "tie_word_embeddings": true,
16
+ "use_cache": true,
17
+ "vocab_size": 32000,
18
+ "attn_impl": "parallel_scaled_softpick_attn"
19
+ }
configs/scaled_vanilla_transformer_120M.json ADDED
@@ -0,0 +1,19 @@
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": false,
7
+ "hidden_act": "swish",
8
+ "hidden_size": 768,
9
+ "initializer_range": 0.02,
10
+ "max_position_embeddings": 4096,
11
+ "model_type": "transformer",
12
+ "num_heads": 12,
13
+ "num_hidden_layers": 14,
14
+ "norm_eps": 1e-06,
15
+ "tie_word_embeddings": true,
16
+ "use_cache": true,
17
+ "vocab_size": 32000,
18
+ "attn_impl": "parallel_scaled_attn"
19
+ }
configs/scaled_vanilla_transformer_340M.json ADDED
@@ -0,0 +1,19 @@
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "hidden_act": "swish",
8
+ "hidden_size": 1024,
9
+ "initializer_range": 0.006,
10
+ "max_position_embeddings": 8192,
11
+ "model_type": "transformer",
12
+ "num_heads": 16,
13
+ "num_hidden_layers": 24,
14
+ "norm_eps": 1e-06,
15
+ "tie_word_embeddings": false,
16
+ "use_cache": true,
17
+ "vocab_size": 32000,
18
+ "attn_impl": "parallel_scaled_attn"
19
+ }
configs/softpick_transformer_1B.json ADDED
@@ -0,0 +1,23 @@
1
+ {
2
+ "bos_token_id": 1,
3
+ "elementwise_affine": true,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "fuse_swiglu": true,
8
+ "hidden_act": "swish",
9
+ "hidden_ratio": 4,
10
+ "hidden_size": 2048,
11
+ "initializer_range": 0.006,
12
+ "intermediate_size": null,
13
+ "max_position_embeddings": 8192,
14
+ "model_type": "transformer",
15
+ "norm_eps": 1e-06,
16
+ "num_heads": 32,
17
+ "num_hidden_layers": 32,
18
+ "num_kv_heads": null,
19
+ "pad_token_id": 2,
20
+ "rope_theta": 10000.0,
21
+ "tie_word_embeddings": false,
22
+ "attn_impl": "parallel_softpick_attn"
23
+ }
download_checkpoint.py ADDED
@@ -0,0 +1,35 @@
1
+ import argparse
2
+ import os
3
+ from huggingface_hub import HfApi, HfFolder, snapshot_download
4
+
5
+ def main(args):
6
+ api = HfApi()
7
+ token = HfFolder.get_token()
8
+ experiment_checkpoint_folder = os.path.join(args.experiment_checkpoint_folder, "checkpoint")
9
+ os.makedirs(
10
+ experiment_checkpoint_folder,
11
+ exist_ok=True
12
+ )
13
+
14
+ snapshot_download(
15
+ repo_id=args.repo_id,
16
+ token=token,
17
+ local_dir=experiment_checkpoint_folder,
18
+ )
19
+
20
+ if __name__ == "__main__":
21
+ parser = argparse.ArgumentParser(description="Download a checkpoint from Hugging Face Hub.")
22
+ parser.add_argument(
23
+ "--repo_id",
24
+ type=str,
25
+ required=True,
26
+ help="The repository ID on Hugging Face Hub.",
27
+ )
28
+ parser.add_argument(
29
+ "--experiment_checkpoint_folder",
30
+ type=str,
31
+ required=True,
32
+ help="The local directory to save the downloaded checkpoint.",
33
+ )
34
+ args = parser.parse_args()
35
+ main(args)
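The script above always downloads into a `checkpoint` subfolder of the directory it is given. The sketch below rebuilds the same two arguments and shows the resulting local path without contacting the Hub; the repository ID `user/model` and the folder `exp/softpick` are placeholders, and `build_parser` is a hypothetical helper.

```python
import argparse
import os


def build_parser() -> argparse.ArgumentParser:
    """Recreate the two required arguments of download_checkpoint.py."""
    parser = argparse.ArgumentParser(description="Download a checkpoint from Hugging Face Hub.")
    parser.add_argument("--repo_id", type=str, required=True)
    parser.add_argument("--experiment_checkpoint_folder", type=str, required=True)
    return parser


# Placeholder values; no network access happens in this sketch.
args = build_parser().parse_args(
    ["--repo_id", "user/model", "--experiment_checkpoint_folder", "exp/softpick"]
)
local_dir = os.path.join(args.experiment_checkpoint_folder, "checkpoint")
print(f"{args.repo_id} would be downloaded into {local_dir}")
```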
fla/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (2.46 kB).
 
fla/__pycache__/utils.cpython-311.pyc ADDED
Binary file (13.9 kB).
 
fla/layers/__init__.py ADDED
@@ -0,0 +1,44 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from .abc import ABCAttention
5
+ from .attn import Attention
6
+ from .based import BasedLinearAttention
7
+ from .bitattn import BitAttention
8
+ from .delta_net import DeltaNet
9
+ from .forgetting_attn import ForgettingAttention
10
+ from .gated_deltanet import GatedDeltaNet
11
+ from .gated_deltaproduct import GatedDeltaProduct
12
+ from .gla import GatedLinearAttention
13
+ from .gsa import GatedSlotAttention
14
+ from .hgrn import HGRNAttention
15
+ from .hgrn2 import HGRN2Attention
16
+ from .lightnet import LightNetAttention
17
+ from .linear_attn import LinearAttention
18
+ from .multiscale_retention import MultiScaleRetention
19
+ from .nsa import NativeSparseAttention
20
+ from .rebased import ReBasedLinearAttention
21
+ from .rwkv6 import RWKV6Attention
22
+ from .rwkv7 import RWKV7Attention
23
+
24
+ __all__ = [
25
+ 'ABCAttention',
26
+ 'Attention',
27
+ 'BasedLinearAttention',
28
+ 'BitAttention',
29
+ 'DeltaNet',
30
+ 'ForgettingAttention',
31
+ 'GatedDeltaNet',
32
+ 'GatedDeltaProduct',
33
+ 'GatedLinearAttention',
34
+ 'GatedSlotAttention',
35
+ 'HGRNAttention',
36
+ 'HGRN2Attention',
37
+ 'LightNetAttention',
38
+ 'LinearAttention',
39
+ 'MultiScaleRetention',
40
+ 'NativeSparseAttention',
41
+ 'ReBasedLinearAttention',
42
+ 'RWKV6Attention',
43
+ 'RWKV7Attention',
44
+ ]
fla/layers/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.52 kB).
 
fla/layers/abc.py ADDED
@@ -0,0 +1,218 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from einops import rearrange
12
+
13
+ from fla.modules import FusedRMSNormGated, RMSNorm, RotaryEmbedding, ShortConvolution
14
+ from fla.modules.activations import swiglu, swish
15
+ from fla.ops.abc.chunk import chunk_abc
16
+
17
+ if TYPE_CHECKING:
18
+ from fla.models.utils import Cache
19
+
20
+
21
+ class ABCAttention(nn.Module):
22
+
23
+ def __init__(
24
+ self,
25
+ hidden_size: int = 1024,
26
+ expand_k: float = 0.5,
27
+ expand_v: float = 1.0,
28
+ num_heads: int = 4,
29
+ use_short_conv: bool = False,
30
+ conv_size: int = 4,
31
+ conv_bias: bool = False,
32
+ num_slots: Optional[int] = None,
33
+ elementwise_affine: Optional[bool] = True,
34
+ norm_eps: float = 1e-5,
35
+ gate_low_rank_dim: int = 16,
36
+ gate_logit_normalizer: int = 16,
37
+ use_rope: bool = True,
38
+ use_input_gate: bool = False,
39
+ use_output_gate: bool = True,
40
+ use_norm: bool = True,
41
+ clamp_min: Optional[float] = -32,
42
+ clamp_max: Optional[float] = 32,
43
+ layer_idx: Optional[int] = None,
44
+ **kwargs
45
+ ) -> ABCAttention:
46
+ super().__init__()
47
+
48
+ self.hidden_size = hidden_size
49
+ self.expand_k = expand_k
50
+ self.expand_v = expand_v
51
+ self.num_heads = num_heads
52
+ self.key_dim = int(self.hidden_size * self.expand_k)
53
+ self.value_dim = int(self.hidden_size * self.expand_v)
54
+ self.head_k_dim = self.key_dim // self.num_heads
55
+ self.head_v_dim = self.value_dim // self.num_heads
56
+
57
+ self.use_short_conv = use_short_conv
58
+ self.conv_size = conv_size
59
+ self.conv_bias = conv_bias
60
+
61
+ self.gate_low_rank_dim = gate_low_rank_dim
62
+ self.gate_logit_normalizer = gate_logit_normalizer
63
+
64
+ self.use_rope = use_rope
65
+ self.use_input_gate = use_input_gate
66
+ self.use_output_gate = use_output_gate
67
+ self.use_norm = use_norm
68
+
69
+ if num_slots is None:
70
+ num_slots = self.head_k_dim
71
+ self.num_slots = num_slots
72
+
73
+ self.norm_eps = norm_eps
74
+
75
+ self.clamp_min = clamp_min
76
+ self.clamp_max = clamp_max
77
+ self.layer_idx = layer_idx
78
+
79
+ if layer_idx is None:
80
+ warnings.warn(
81
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will lead "
82
+ "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
83
+ "when creating this class."
84
+ )
85
+
86
+ self.q_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False)
87
+ self.k_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False)
88
+ self.v_proj = nn.Linear(self.hidden_size, self.value_dim, bias=False)
89
+
90
+ if use_output_gate:
91
+ self.g_proj = nn.Linear(self.hidden_size, self.value_dim, bias=False)
92
+ self.s_proj = nn.Linear(self.hidden_size, self.num_heads * self.num_slots, bias=False)
93
+ self.o_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)
94
+
95
+ if use_short_conv:
96
+ self.conv_size = conv_size
97
+ self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
98
+ self.k_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
99
+ self.v_conv1d = ShortConvolution(self.value_dim, conv_size, activation='silu')
100
+
101
+ if self.use_norm:
102
+ if self.use_output_gate:
103
+ self.g_norm = FusedRMSNormGated(
104
+ hidden_size=self.head_v_dim,
105
+ elementwise_affine=elementwise_affine,
106
+ eps=norm_eps
107
+ )
108
+ else:
109
+ self.g_norm = RMSNorm(
110
+ hidden_size=self.head_v_dim,
111
+ elementwise_affine=elementwise_affine,
112
+ eps=norm_eps
113
+ )
114
+
115
+ if self.use_rope:
116
+ self.rotary = RotaryEmbedding(self.head_k_dim)
117
+
118
+ def forward(
119
+ self,
120
+ hidden_states: torch.Tensor,
121
+ attention_mask: Optional[torch.Tensor] = None,
122
+ past_key_values: Optional[Cache] = None,
123
+ use_cache: Optional[bool] = False,
124
+ output_attentions: Optional[bool] = False,
125
+ **kwargs
126
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
127
+ if attention_mask is not None:
128
+ assert len(attention_mask.shape) == 2, (
129
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
130
+ "for padding purposes (0 indicating padding). "
131
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
132
+ )
133
+
134
+ last_state = None
135
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
136
+ last_state = past_key_values[self.layer_idx]
137
+
138
+ cu_seqlens = kwargs.get('cu_seqlens', None)
139
+ if cu_seqlens is not None:
140
+ raise NotImplementedError("Training with cu_seqlens is not supported yet for ABCAttention")
141
+ if self.use_short_conv:
142
+ conv_state_q, conv_state_k, conv_state_v = None, None, None
143
+ if last_state is not None:
144
+ conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
145
+ conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
146
+ q, conv_state_q = self.q_conv1d(
147
+ x=self.q_proj(hidden_states),
148
+ mask=conv_mask,
149
+ cache=conv_state_q,
150
+ output_final_state=use_cache,
151
+ cu_seqlens=cu_seqlens
152
+ )
153
+ k, conv_state_k = self.k_conv1d(
154
+ x=self.k_proj(hidden_states),
155
+ mask=conv_mask,
156
+ cache=conv_state_k,
157
+ output_final_state=use_cache,
158
+ cu_seqlens=cu_seqlens
159
+ )
160
+ v, conv_state_v = self.v_conv1d(
161
+ x=self.v_proj(hidden_states),
162
+ mask=conv_mask,
163
+ cache=conv_state_v,
164
+ output_final_state=use_cache,
165
+ cu_seqlens=cu_seqlens
166
+ )
167
+ else:
168
+ q = self.q_proj(hidden_states)
169
+ k = self.k_proj(hidden_states)
170
+ v = self.v_proj(hidden_states)
171
+
172
+ if self.use_input_gate:
173
+ q, k, v = map(lambda x: swish(x), (q, k, v))
174
+ # dealing with left-padding
175
+ if attention_mask is not None:
176
+ v = v.mul_(attention_mask[:, -v.shape[-2]:, None])
177
+
178
+ q, k = map(lambda x: rearrange(x, '... (h d) -> ... h d', d=self.head_k_dim), (q, k))
179
+ v = rearrange(v, '... (h d) -> ... h d', d=self.head_v_dim)
180
+ if self.use_rope:
181
+ seqlen_offset = 0
182
+ if past_key_values is not None:
183
+ seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
184
+ q, k = self.rotary(q, k, seqlen_offset=seqlen_offset)
185
+
186
+ s = rearrange(self.s_proj(hidden_states), '... (h m) -> ... h m', m=self.num_slots)
187
+ s = s.clamp_(self.clamp_min, self.clamp_max)
188
+
189
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
190
+ o, recurrent_state = chunk_abc(
191
+ q=q,
192
+ k=k,
193
+ v=v,
194
+ s=s,
195
+ initial_state=recurrent_state,
196
+ output_final_state=use_cache,
197
+ head_first=False
198
+ )
199
+ if past_key_values is not None:
200
+ past_key_values.update(
201
+ recurrent_state=recurrent_state,
202
+ conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
203
+ layer_idx=self.layer_idx,
204
+ offset=q.shape[1]
205
+ )
206
+
207
+ if self.use_norm and not self.use_output_gate:
208
+ o = self.g_norm(o)
209
+ elif self.use_output_gate:
210
+ g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', d=self.head_v_dim)
211
+ o = self.g_norm(o, g) if self.use_norm else swiglu(g, o)
212
+ o = rearrange(o, '... h d -> ... (h d)')
213
+ o = self.o_proj(o)
214
+
215
+ return o, None, past_key_values
216
+
217
+ def state_size(self, seq_len: int = 2048):
218
+ return 2 * self.num_slots * self.hidden_size
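The dimension bookkeeping in `ABCAttention.__init__` and `state_size` can be reproduced without PyTorch. The sketch below copies that arithmetic into a plain function (the name `abc_dims` is hypothetical) and evaluates it for the constructor defaults shown above.

```python
def abc_dims(hidden_size=1024, expand_k=0.5, expand_v=1.0, num_heads=4, num_slots=None):
    """Mirror the size computations of ABCAttention.__init__ and state_size."""
    key_dim = int(hidden_size * expand_k)
    value_dim = int(hidden_size * expand_v)
    head_k_dim = key_dim // num_heads
    head_v_dim = value_dim // num_heads
    if num_slots is None:
        # Same fallback as the layer: one slot per key channel of a head.
        num_slots = head_k_dim
    return {
        "key_dim": key_dim,
        "value_dim": value_dim,
        "head_k_dim": head_k_dim,
        "head_v_dim": head_v_dim,
        "num_slots": num_slots,
        "state_size": 2 * num_slots * hidden_size,
    }


defaults = abc_dims()
print(defaults)
```

Note that `state_size` as written does not depend on `seq_len`, which reflects the fixed-size recurrent state of this layer.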
fla/layers/attn.py ADDED
@@ -0,0 +1,490 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import torch.utils.checkpoint
13
+ from einops import rearrange
14
+ from transformers.utils import logging
15
+
16
+ from fla.modules import RMSNorm, RotaryEmbedding
17
+ from fla.ops import parallel_attn, parallel_rectified_attn, parallel_softpick_attn, naive_attn, naive_rectified_attn, naive_softpick_attn
18
+
19
+ if TYPE_CHECKING:
20
+ from fla.models.utils import Cache
21
+
22
+ try:
23
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
24
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
25
+ except ImportError:
26
+ warnings.warn(
27
+ "Flash Attention is not installed. Please install it via `pip install flash-attn --no-build-isolation`",
28
+ category=ImportWarning
29
+ )
30
+ flash_attn_func = None
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
+
35
+ class Attention(nn.Module):
36
+
37
+ def __init__(
38
+ self,
39
+ hidden_size: int = 2048,
40
+ num_heads: int = 32,
41
+ num_kv_heads: Optional[int] = None,
42
+ qkv_bias: bool = False,
43
+ qk_norm: bool = False,
44
+ window_size: Optional[int] = None,
45
+ rope_theta: Optional[float] = 10000.,
46
+ max_position_embeddings: Optional[int] = None,
47
+ layer_idx: int = None,
48
+ attn_impl: str = "flash_attn",
49
+ ):
50
+ super().__init__()
51
+
52
+ self.hidden_size = hidden_size
53
+ self.num_heads = num_heads
54
+ if num_kv_heads is None:
55
+ self.num_kv_heads = self.num_heads
56
+ else:
57
+ self.num_kv_heads = num_kv_heads
58
+ self.num_kv_groups = num_heads // self.num_kv_heads
59
+ self.head_dim = self.hidden_size // self.num_heads
60
+ self.kv_dim = self.num_kv_heads * self.head_dim
61
+ self.qkv_bias = qkv_bias
62
+ self.qk_norm = qk_norm
63
+
64
+ self.window_size = window_size
65
+ self.rope_theta = rope_theta
66
+ self.max_position_embeddings = max_position_embeddings
67
+ self.layer_idx = layer_idx
68
+ self.attn_impl = attn_impl
69
+
70
+ self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=self.qkv_bias)
71
+ self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
72
+ self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
73
+ self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
74
+
75
+ if "scaled" in self.attn_impl:
76
+ self.s = nn.Parameter(torch.empty(self.num_heads, 1))
77
+ self.register_buffer("logn", torch.log(torch.arange(2, self.max_position_embeddings*4+2, dtype=self.s.dtype)[:, None, None]))
78
+
79
+ if qk_norm:
80
+ self.q_norm = RMSNorm(self.head_dim)
81
+ self.k_norm = RMSNorm(self.head_dim)
82
+
83
+ self.rotary = RotaryEmbedding(dim=self.head_dim, base=self.rope_theta)
84
+
85
+ def reset_parameters(self):
86
+ if "scaled" in self.attn_impl:
87
+ nn.init.constant_(self.s, 0.3)
88
+ self.logn.copy_(torch.log(torch.arange(2, self.max_position_embeddings*4+2, dtype=self.s.dtype)[:, None, None]))
89
+
90
+ def forward(
91
+ self,
92
+ hidden_states: torch.Tensor,
93
+ attention_mask: Optional[torch.LongTensor] = None,
94
+ past_key_values: Optional[Cache] = None,
95
+ output_attentions: bool = False,
96
+ use_cache: bool = False,
97
+ **kwargs,
98
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
99
+ if attention_mask is not None:
100
+ assert len(attention_mask.shape) == 2, (
101
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
102
+ "for padding purposes (0 indicating padding). "
103
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
104
+ )
105
+
106
+ batch_size, q_len, _ = hidden_states.size()
107
+
108
+ q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
109
+
110
+ q = rearrange(q, '... (h d) -> ... h d', d=self.head_dim)
111
+ k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim)
112
+ v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)
113
+
114
+ if self.qk_norm:
115
+ q, k = self.q_norm(q), self.k_norm(k)
116
+
117
+ # equivalent to cu_seqlens in `flash_attn`
118
+ cu_seqlens = kwargs.get('cu_seqlens', None)
119
+
120
+ seqlen_offset, max_seqlen = 0, q_len
121
+ if past_key_values is not None:
122
+ seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
123
+ max_seqlen = q.shape[1] + seqlen_offset
124
+
125
+ if attention_mask is not None:
126
+ # to account for the offsets of padding tokens
127
+ seqlen_offset = seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1]
128
+ max_seqlen = q.shape[1] + max(seqlen_offset)
129
+
130
+ if self.max_position_embeddings is not None:
131
+ max_seqlen = max(max_seqlen, self.max_position_embeddings)
132
+ q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens)
133
+
134
+ if past_key_values is not None:
135
+ cache_has_content = past_key_values.get_seq_length(self.layer_idx) > 0
136
+ k_cached, v_cached = past_key_values.update(
137
+ attn_state=(k.flatten(-2, -1), v.flatten(-2, -1)),
138
+ layer_idx=self.layer_idx,
139
+ offset=q_len,
140
+ cache_kwargs=dict(window_size=self.window_size)
141
+ )['attn_state']
142
+ if cache_has_content:
143
+ k, v = k_cached, v_cached
144
+ k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim)
145
+ v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)
146
+
147
+ # if flash_attn_func is None:
148
+ # raise ImportError("Please install Flash Attention via `pip install flash-attn --no-build-isolation` first")
149
+
150
+ if "scaled" in self.attn_impl:
151
+ k_len = k.shape[1]
152
+ q = q * self.s.to(q.dtype) * self.logn[k_len-q_len:k_len].to(q.dtype)
153
+
154
+ # Contains at least one padding token in the sequence
155
+ if self.attn_impl == "flash_attn":
156
+ if attention_mask is not None:
157
+ q, k, v, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(q, k, v, attention_mask, q_len)
158
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
159
+ max_seqlen_q, max_seqlen_k = max_seq_lens
160
+ o = flash_attn_varlen_func(
161
+ q, k, v,
162
+ cu_seqlens_q=cu_seqlens_q,
163
+ cu_seqlens_k=cu_seqlens_k,
164
+ max_seqlen_q=max_seqlen_q,
165
+ max_seqlen_k=max_seqlen_k,
166
+ causal=True,
167
+ window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
168
+ )
169
+ o = pad_input(o, indices_q, batch_size, q_len)
170
+ elif cu_seqlens is not None:
171
+ o = flash_attn_varlen_func(
172
+ q.squeeze(0), k.squeeze(0), v.squeeze(0),
173
+ cu_seqlens_q=cu_seqlens,
174
+ cu_seqlens_k=cu_seqlens,
175
+ max_seqlen_q=max_seqlen,
176
+ max_seqlen_k=max_seqlen,
177
+ causal=True,
178
+ window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
179
+ ).unsqueeze(0)
180
+ else:
181
+ o = flash_attn_func(
182
+ q, k, v,
183
+ causal=True,
184
+ window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
185
+ )
186
+ elif self.attn_impl == "parallel_attn":
187
+ o = parallel_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
188
+ elif self.attn_impl == "parallel_scaled_attn":
189
+ o = parallel_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
190
+ elif self.attn_impl == "parallel_rectified_attn":
191
+ o = parallel_rectified_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
192
+ elif self.attn_impl == "parallel_softpick_attn":
193
+ o = parallel_softpick_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
194
+ elif self.attn_impl == "parallel_scaled_softpick_attn":
195
+ o = parallel_softpick_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
196
+ elif self.attn_impl == "naive_attn":
197
+ o, attentions = naive_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
198
+ elif self.attn_impl == "naive_scaled_attn":
199
+ o, attentions = naive_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
200
+ elif self.attn_impl == "naive_rectified_attn":
201
+ o, attentions = naive_rectified_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
202
+ elif self.attn_impl == "naive_softpick_attn":
203
+ o, attentions = naive_softpick_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
204
+ elif self.attn_impl == "naive_scaled_softpick_attn":
205
+ o, attentions = naive_softpick_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
206
+ else:
207
+ raise ValueError(f"Unknown attention implementation: {self.attn_impl}")
208
+
209
+ o = o.reshape(batch_size, q_len, -1)
210
+ o = self.o_proj(o)
211
+
212
+ if not output_attentions or "parallel" in self.attn_impl or "flash" in self.attn_impl:
213
+ attentions = None
214
+
215
+ return o, attentions, past_key_values
216
+
217
+ def _upad_input(self, q, k, v, attention_mask, q_len):
218
+ batch_size, seq_len, num_key_value_heads, head_dim = k.shape
219
+ cache_mask = attention_mask[:, -seq_len:]
220
+ seqlens = cache_mask.sum(-1, dtype=torch.int32)
221
+ indices_k = torch.nonzero(cache_mask.flatten(), as_tuple=False).flatten()
222
+ max_seqlen_k = seqlens.max().item()
223
+ cu_seqlens_k = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
224
+
225
+ k = index_first_axis(k.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k)
226
+ v = index_first_axis(v.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k)
227
+ if q_len == seq_len:
228
+ q = index_first_axis(q.reshape(batch_size * seq_len, self.num_heads, head_dim), indices_k)
229
+ cu_seqlens_q = cu_seqlens_k
230
+ max_seqlen_q = max_seqlen_k
231
+ indices_q = indices_k
232
+ elif q_len == 1:
233
+ max_seqlen_q = 1
234
+ # There is a memcpy here, that is very bad.
235
+ cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device)
236
+ indices_q = cu_seqlens_q[:-1]
237
+ q = q.squeeze(1)
238
+ else:
239
+ # The -q_len: slice assumes left padding.
240
+ attention_mask = attention_mask[:, -q_len:]
241
+ q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, attention_mask)
242
+
243
+ return q, k, v, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
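The key-side metadata that `_upad_input` computes from the padding mask can be illustrated with plain lists. The sketch below is a simplified stand-in for the tensor code (the helper name `unpad_metadata` is hypothetical): it returns the flat indices of non-padding tokens, the cumulative sequence lengths, and the longest sequence, which correspond to `indices_k`, `cu_seqlens_k`, and `max_seqlen_k`.

```python
from itertools import accumulate


def unpad_metadata(mask: list[list[int]]):
    """Compute flat token indices, cumulative lengths, and max length from a 0-1 mask."""
    width = len(mask[0])
    seqlens = [sum(row) for row in mask]
    # Position of every kept token in the flattened [batch * seq_len] layout.
    indices = [b * width + t for b, row in enumerate(mask) for t, keep in enumerate(row) if keep]
    # Leading zero matches F.pad(cumsum(seqlens), (1, 0)).
    cu_seqlens = [0] + list(accumulate(seqlens))
    return indices, cu_seqlens, max(seqlens)


# Two sequences, the first left-padded by one token.
mask = [[0, 1, 1],
        [1, 1, 1]]
indices, cu_seqlens, max_seqlen = unpad_metadata(mask)
print(indices, cu_seqlens, max_seqlen)
```

Sequence `i` then occupies rows `cu_seqlens[i]` to `cu_seqlens[i + 1]` of the packed tensor, which is the layout the variable-length attention call expects.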
244
+
245
+ class StochasticSoftpickAttention(nn.Module):
246
+
247
+ def __init__(
248
+ self,
249
+ hidden_size: int = 2048,
250
+ num_heads: int = 32,
251
+ num_kv_heads: Optional[int] = None,
252
+ qkv_bias: bool = False,
253
+ qk_norm: bool = False,
254
+ window_size: Optional[int] = None,
255
+ rope_theta: Optional[float] = 10000.,
256
+ max_position_embeddings: Optional[int] = None,
257
+ layer_idx: int = None,
258
+ attn_impl: str = "flash_attn",
259
+ stochastic_p: float = 0.5,
260
+ ):
261
+ super().__init__()
262
+
263
+ self.hidden_size = hidden_size
264
+ self.num_heads = num_heads
265
+ if num_kv_heads is None:
266
+ self.num_kv_heads = self.num_heads
267
+ else:
268
+ self.num_kv_heads = num_kv_heads
269
+ self.num_kv_groups = num_heads // self.num_kv_heads
270
+ self.head_dim = self.hidden_size // self.num_heads
271
+ self.kv_dim = self.num_kv_heads * self.head_dim
272
+ self.qkv_bias = qkv_bias
273
+ self.qk_norm = qk_norm
274
+
275
+ self.window_size = window_size
276
+ self.rope_theta = rope_theta
277
+ self.max_position_embeddings = max_position_embeddings
278
+ self.layer_idx = layer_idx
279
+ self.attn_impl = attn_impl
280
+ self.stochastic_value = stochastic_p
281
+ self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=self.qkv_bias)
282
+ self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
283
+ self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
284
+ self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
285
+
286
+ if "scaled" in self.attn_impl:
287
+ self.s = nn.Parameter(torch.empty(self.num_heads, 1))
288
+ self.register_buffer("logn", torch.log(torch.arange(2, self.max_position_embeddings*4+2, dtype=self.s.dtype)[:, None, None]))
289
+
290
+ if qk_norm:
291
+ self.q_norm = RMSNorm(self.head_dim)
292
+ self.k_norm = RMSNorm(self.head_dim)
293
+
294
+ self.rotary = RotaryEmbedding(dim=self.head_dim, base=self.rope_theta)
295
+
296
+ def reset_parameters(self):
297
+ if "scaled" in self.attn_impl:
298
+ nn.init.constant_(self.s, 0.3)
299
+ self.logn.copy_(torch.log(torch.arange(2, self.max_position_embeddings*4+2, dtype=self.s.dtype)[:, None, None]))
300
+
301
+
302
+ def forward(
303
+ self,
304
+ hidden_states: torch.Tensor,
305
+ attention_mask: Optional[torch.LongTensor] = None,
306
+ past_key_values: Optional[Cache] = None,
307
+ output_attentions: bool = False,
308
+ use_cache: bool = False,
309
+ **kwargs,
310
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
311
+ if attention_mask is not None:
312
+ assert len(attention_mask.shape) == 2, (
313
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
314
+ "for padding purposes (0 indicating padding). "
315
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
316
+ )
317
+
318
+ batch_size, q_len, _ = hidden_states.size()
319
+
320
+ q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
321
+
322
+ q = rearrange(q, '... (h d) -> ... h d', d=self.head_dim)
323
+ k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim)
324
+ v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)
325
+
326
+ if self.qk_norm:
327
+ q, k = self.q_norm(q), self.k_norm(k)
328
+
329
+ # equivalent to cu_seqlens in `flash_attn`
330
+ cu_seqlens = kwargs.get('cu_seqlens', None)
331
+
332
+ seqlen_offset, max_seqlen = 0, q_len
333
+ if past_key_values is not None:
334
+ seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
335
+ max_seqlen = q.shape[1] + seqlen_offset
336
+
337
+ if attention_mask is not None:
338
+ # subtract each sequence's padding-token count from its offset
339
+ seqlen_offset = seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1]
340
+ max_seqlen = q.shape[1] + max(seqlen_offset)
341
+
342
+ if self.max_position_embeddings is not None:
343
+ max_seqlen = max(max_seqlen, self.max_position_embeddings)
344
+ q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens)
345
+
346
+ if past_key_values is not None:
347
+ cache_has_content = past_key_values.get_seq_length(self.layer_idx) > 0
348
+ k_cached, v_cached = past_key_values.update(
349
+ attn_state=(k.flatten(-2, -1), v.flatten(-2, -1)),
350
+ layer_idx=self.layer_idx,
351
+ offset=q_len,
352
+ cache_kwargs=dict(window_size=self.window_size)
353
+ )['attn_state']
354
+ if cache_has_content:
355
+ k, v = k_cached, v_cached
356
+ k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim)
357
+ v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)
358
+
359
+ # if flash_attn_func is None:
360
+ # raise ImportError("Please install Flash Attention via `pip install flash-attn --no-build-isolation` first")
361
+
362
+ if "scaled" in self.attn_impl:
363
+ k_len = k.shape[1]
364
+ q = q * self.s.to(q.dtype) * self.logn[k_len-q_len:k_len].to(q.dtype)
365
+
366
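+ attentions = None  # default; only the naive_* kernels below return attention weights
367
+ # Draw once per forward call: with probability `stochastic_p` run the configured variant, otherwise fall back to `parallel_attn`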
368
+ p = torch.rand(1, device=q.device)
369
+ stochastic_p = torch.tensor(self.stochastic_value, dtype=torch.float32, device=q.device)
370
+ cond = p < stochastic_p
371
+ if self.attn_impl == "flash_attn":
372
+ if attention_mask is not None:
373
+ q, k, v, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(q, k, v, attention_mask, q_len)
374
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
375
+ max_seqlen_q, max_seqlen_k = max_seq_lens
376
+ o = flash_attn_varlen_func(
377
+ q, k, v,
378
+ cu_seqlens_q=cu_seqlens_q,
379
+ cu_seqlens_k=cu_seqlens_k,
380
+ max_seqlen_q=max_seqlen_q,
381
+ max_seqlen_k=max_seqlen_k,
382
+ causal=True,
383
+ window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
384
+ )
385
+ o = pad_input(o, indices_q, batch_size, q_len)
386
+ elif cu_seqlens is not None:
387
+ o = flash_attn_varlen_func(
388
+ q.squeeze(0), k.squeeze(0), v.squeeze(0),
389
+ cu_seqlens_q=cu_seqlens,
390
+ cu_seqlens_k=cu_seqlens,
391
+ max_seqlen_q=max_seqlen,
392
+ max_seqlen_k=max_seqlen,
393
+ causal=True,
394
+ window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
395
+ ).unsqueeze(0)
396
+ else:
397
+ o = flash_attn_func(
398
+ q, k, v,
399
+ causal=True,
400
+ window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
401
+ )
402
+
403
+ elif self.attn_impl == "parallel_attn":
404
+ if cond:
405
+ o = parallel_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
406
+ else:
407
+ o = parallel_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
408
+ elif self.attn_impl == "parallel_scaled_attn":
409
+ if cond:
410
+ o = parallel_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
411
+ else:
412
+ o = parallel_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
413
+ elif self.attn_impl == "parallel_rectified_attn":
414
+ if cond:
415
+ o = parallel_rectified_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
416
+ else:
417
+ o = parallel_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
418
+ elif self.attn_impl == "parallel_softpick_attn":
419
+ if cond:
420
+ o = parallel_softpick_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
421
+ else:
422
+ o = parallel_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
423
+ elif self.attn_impl == "parallel_scaled_softpick_attn":
424
+ if cond:
425
+ o = parallel_softpick_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
426
+ else:
427
+ o = parallel_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
428
+ elif self.attn_impl == "naive_attn":
429
+ if cond:
430
+ o, attentions = naive_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
431
+ else:
432
+ o = parallel_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
433
+ elif self.attn_impl == "naive_scaled_attn":
434
+ if cond:
435
+ o, attentions = naive_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
436
+ else:
437
+ o = parallel_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
438
+ elif self.attn_impl == "naive_rectified_attn":
439
+ if cond:
440
+ o, attentions = naive_rectified_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
441
+ else:
442
+ o = parallel_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
443
+ elif self.attn_impl == "naive_softpick_attn":
444
+ if cond:
445
+ o, attentions = naive_softpick_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
446
+ else:
447
+ o = parallel_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
448
+ elif self.attn_impl == "naive_scaled_softpick_attn":
449
+ if cond:
450
+ o, attentions = naive_softpick_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
451
+ else:
452
+ o = parallel_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
453
+ else:
454
+ raise ValueError(f"Unknown attention implementation: {self.attn_impl}")
455
+
456
+ o = o.reshape(batch_size, q_len, -1)
457
+ o = self.o_proj(o)
458
+
459
+ if not output_attentions or "parallel" in self.attn_impl or "flash" in self.attn_impl:
460
+ attentions = None
461
+
462
+ return o, attentions, past_key_values
463
+
464
+ def _upad_input(self, q, k, v, attention_mask, q_len):
465
+ batch_size, seq_len, num_key_value_heads, head_dim = k.shape
466
+ cache_mask = attention_mask[:, -seq_len:]
467
+ seqlens = cache_mask.sum(-1, dtype=torch.int32)
468
+ indices_k = torch.nonzero(cache_mask.flatten(), as_tuple=False).flatten()
469
+ max_seqlen_k = seqlens.max().item()
470
+ cu_seqlens_k = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
471
+
472
+ k = index_first_axis(k.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k)
473
+ v = index_first_axis(v.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k)
474
+ if q_len == seq_len:
475
+ q = index_first_axis(q.reshape(batch_size * seq_len, self.num_heads, head_dim), indices_k)
476
+ cu_seqlens_q = cu_seqlens_k
477
+ max_seqlen_q = max_seqlen_k
478
+ indices_q = indices_k
479
+ elif q_len == 1:
480
+ max_seqlen_q = 1
481
+ # There is a memcpy here, that is very bad.
482
+ cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device)
483
+ indices_q = cu_seqlens_q[:-1]
484
+ q = q.squeeze(1)
485
+ else:
486
+ # The -q_len: slice assumes left padding.
487
+ attention_mask = attention_mask[:, -q_len:]
488
+ q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, attention_mask)
489
+
490
+ return q, k, v, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
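Note on `_upad_input`: the bookkeeping it passes to `flash_attn_varlen_func` (flat token indices, cumulative sequence lengths, longest sequence) can be reproduced without tensors. The block below is a minimal pure-Python sketch under that assumption, not code from this commit; `unpad_metadata` is a hypothetical helper name and the mask is a plain list of 0/1 rows.

```python
from itertools import accumulate


def unpad_metadata(attention_mask):
    """Compute varlen metadata from a 0/1 padding mask of shape [batch, seq_len].

    Hypothetical helper for illustration only; mirrors the quantities that
    `_upad_input` derives with torch ops (indices_k, cu_seqlens_k, max_seqlen_k).
    """
    # Number of real (non-padding) tokens per sequence.
    seqlens = [sum(row) for row in attention_mask]
    # Positions of real tokens in the flattened [batch * seq_len] layout.
    flat = [m for row in attention_mask for m in row]
    indices = [i for i, m in enumerate(flat) if m]
    # Cumulative lengths with a leading zero, as the varlen kernels expect.
    cu_seqlens = [0] + list(accumulate(seqlens))
    return indices, cu_seqlens, max(seqlens)


# Left-padded batch of three sequences with 2, 4 and 3 real tokens.
mask = [
    [0, 0, 1, 1],
    [1, 1, 1, 1],
    [0, 1, 1, 1],
]
indices, cu_seqlens, max_seqlen = unpad_metadata(mask)
print(indices)     # [2, 3, 4, 5, 6, 7, 9, 10, 11]
print(cu_seqlens)  # [0, 2, 6, 9]
print(max_seqlen)  # 4
```

Sequence `i` occupies rows `cu_seqlens[i]:cu_seqlens[i + 1]` of the packed tensor, which is why padding never reaches the kernel.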
fla/layers/based.py ADDED
@@ -0,0 +1,96 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ """
5
+ Linear attention in Based.
6
+ https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py
7
+ """
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from einops import rearrange
12
+
13
+ from fla.modules.feature_map import TaylorFeatureMap
14
+ from fla.ops.based import parallel_based
15
+ from fla.ops.linear_attn import chunk_linear_attn, fused_chunk_linear_attn
16
+
17
+
18
+ class BasedLinearAttention(nn.Module):
19
+
20
+ def __init__(
21
+ self,
22
+ hidden_size: int,
23
+ feature_dim: int = 16,
24
+ num_key_value_heads: int = 12,
25
+ num_heads: int = 12,
26
+ feature_name: str = "taylor_exp",
27
+ eps: float = 1e-12,
28
+ causal: bool = True,
29
+ mode: str = "parallel",
30
+ ):
31
+ super().__init__()
32
+
33
+ self.hidden_size = hidden_size
34
+ self.mode = mode
35
+ self.feature_name = feature_name
36
+ self.feature_dim = feature_dim
37
+ self.num_key_value_heads = num_key_value_heads
38
+ self.num_heads = num_heads
39
+ self.head_dim = self.hidden_size // self.num_key_value_heads
40
+ assert self.hidden_size % self.head_dim == 0
41
+ self.causal = causal
42
+
43
+ self.q_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
44
+ self.k_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
45
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
46
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
47
+ self.dropout = nn.Identity()
48
+ self.feature_map = TaylorFeatureMap(feature_dim)
49
+ self.eps = eps
50
+
51
+ def forward(self, hidden_states: torch.Tensor, **kwargs):
52
+ mode = self.mode
53
+ q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
54
+ q, k, v = (rearrange(x, "... (h d) -> ... h d", d=d) for x, d in ((q, self.feature_dim), (k, self.feature_dim), (v, self.head_dim)))
55
+ if mode == "fused_chunk":
56
+ q, k = self.feature_map(q), self.feature_map(k)
57
+ o, _ = fused_chunk_linear_attn(q, k, v, normalize=True, scale=1, head_first=False)
58
+ elif mode == 'chunk':
59
+ q, k = self.feature_map(q), self.feature_map(k)
60
+ o, _ = chunk_linear_attn(q, k, v, normalize=True, scale=1, head_first=False)
61
+ elif mode == 'parallel':
62
+ assert q.shape[-1] <= 128
63
+ o = parallel_based(q, k, v, scale=1, use_norm=True, head_first=False)
64
+ o = rearrange(o, 'b t h d -> b t (h d)')
65
+ o = self.o_proj(o)
66
+ o = self.dropout(o)
67
+ return o
68
+
69
+ # https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py#L119
70
+
71
+ def forward_reference(self, hidden_states: torch.Tensor, filters: torch.Tensor = None, *args, **kwargs):
72
+ """
73
+ hidden_states (torch.Tensor): input tensor of shape (b, t, d)
74
+ returns (torch.Tensor): output tensor of shape (b, t, d)
75
+ """
76
+ # hidden_states = hidden_states.transpose(1, 2)
77
+ b, t, _ = hidden_states.size()
78
+ q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
79
+
80
+ q = q.view(b, t, self.num_heads, self.feature_dim).transpose(1, 2)
81
+ k = k.view(b, t, self.num_key_value_heads, self.feature_dim).transpose(1, 2)
82
+ v = v.view(b, t, self.num_key_value_heads, self.head_dim).transpose(1, 2)
83
+
84
+ # Linear attention
85
+ q, k = self.feature_map(q), self.feature_map(k)
86
+ q, k, v = q.unsqueeze(-2), k.unsqueeze(-2), v.unsqueeze(-1)
87
+
88
+ # Compute attention
89
+ if self.causal:
90
+ y = ((q * (k * v).cumsum(2)).sum(-1) / ((q * k.cumsum(2)).sum(-1) + self.eps))
91
+ else:
92
+ y = ((q * (k * v).sum(2, True)).sum(-1) / ((q * k.sum(2, True)).sum(-1) + self.eps))
93
+ y = rearrange(y, 'b h t d -> b t (h d)')
94
+ y = self.o_proj(y.to(hidden_states.dtype))
95
+ y = self.dropout(y)
96
+ return y.to(hidden_states.dtype)
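Note on `forward_reference`: its causal branch replaces the quadratic attention matrix with running sums over `k` and over the outer products of `k` and `v`. The block below is a minimal pure-Python sketch of that identity for a single head, assuming `q` and `k` have already been passed through a non-negative feature map; the function names are hypothetical and nothing here comes from the `fla` kernels.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def causal_linear_attention(q, k, v, eps=1e-12):
    """Recurrent form. q, k: [t][f] feature-mapped rows; v: [t][d] rows."""
    f, d = len(q[0]), len(v[0])
    kv = [[0.0] * d for _ in range(f)]  # running sum of outer(k_s, v_s), s <= t
    z = [0.0] * f                       # running sum of k_s, s <= t
    out = []
    for q_t, k_t, v_t in zip(q, k, v):
        for i in range(f):
            z[i] += k_t[i]
            for j in range(d):
                kv[i][j] += k_t[i] * v_t[j]
        denom = dot(q_t, z) + eps
        out.append([sum(q_t[i] * kv[i][j] for i in range(f)) / denom
                    for j in range(d)])
    return out


def quadratic_reference(q, k, v, eps=1e-12):
    """Same result computed from explicit causal weights q_t . k_s."""
    out = []
    for t in range(len(q)):
        w = [dot(q[t], k[s]) for s in range(t + 1)]
        denom = sum(w) + eps
        out.append([sum(w[s] * v[s][j] for s in range(t + 1)) / denom
                    for j in range(len(v[0]))])
    return out


q = [[1.0, 0.5], [0.2, 0.3], [0.4, 0.1]]
k = [[0.3, 0.7], [0.9, 0.1], [0.5, 0.5]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
fast = causal_linear_attention(q, k, v)
slow = quadratic_reference(q, k, v)
```

The recurrent form keeps only an `f x d` state per head, so its cost grows linearly with sequence length instead of quadratically.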
fla/layers/bitattn.py ADDED
@@ -0,0 +1,192 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import torch.utils.checkpoint
13
+ from einops import rearrange
14
+ from transformers.utils import logging
15
+
16
+ from fla.modules import RotaryEmbedding
17
+ from fla.modules.fused_bitlinear import FusedBitLinear
18
+
19
+ if TYPE_CHECKING:
20
+ from fla.models.utils import Cache
21
+
22
+ try:
23
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
24
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
25
+ except ImportError:
26
+ warnings.warn(
27
+ "Flash Attention is not installed. Please install it via `pip install flash-attn --no-build-isolation`",
28
+ category=ImportWarning
29
+ )
30
+ flash_attn_func = None
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
+
35
+ class BitAttention(nn.Module):
36
+
37
+ def __init__(
38
+ self,
39
+ hidden_size: int = 2048,
40
+ num_heads: int = 32,
41
+ num_kv_heads: Optional[int] = None,
42
+ window_size: Optional[int] = None,
43
+ rope_theta: Optional[float] = 10000.,
44
+ max_position_embeddings: Optional[int] = None,
45
+ norm_eps: float = 1e-5,
46
+ layer_idx: int = None
47
+ ):
48
+ super().__init__()
49
+
50
+ self.num_heads = num_heads
51
+ if num_kv_heads is None:
52
+ self.num_kv_heads = self.num_heads
53
+ else:
54
+ self.num_kv_heads = num_kv_heads
55
+ self.num_kv_groups = num_heads // self.num_kv_heads
56
+ self.hidden_size = hidden_size
57
+ self.head_dim = self.hidden_size // self.num_heads
58
+ self.kv_dim = self.num_kv_heads * self.head_dim
59
+ self.kv_dim = self.num_kv_heads * self.head_dim
60
+ self.window_size = window_size
61
+ self.rope_theta = rope_theta
62
+ self.max_position_embeddings = max_position_embeddings
63
+ self.layer_idx = layer_idx
64
+
65
+ self.q_proj = FusedBitLinear(self.hidden_size, self.hidden_size, bias=False)
66
+ self.k_proj = FusedBitLinear(self.hidden_size, self.kv_dim, bias=False)
67
+ self.v_proj = FusedBitLinear(self.hidden_size, self.kv_dim, bias=False)
68
+ self.o_proj = FusedBitLinear(self.hidden_size, self.hidden_size, bias=False)
69
+
70
+ self.rotary = RotaryEmbedding(dim=self.head_dim, base=self.rope_theta)
71
+
72
+ def forward(
73
+ self,
74
+ hidden_states: torch.Tensor,
75
+ attention_mask: Optional[torch.LongTensor] = None,
76
+ past_key_values: Optional[Cache] = None,
77
+ output_attentions: bool = False,
78
+ use_cache: bool = False,
79
+ **kwargs,
80
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
81
+ if attention_mask is not None:
82
+ assert len(attention_mask.shape) == 2, (
83
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
84
+ "for padding purposes (0 indicating padding). "
85
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
86
+ )
87
+
88
+ batch_size, q_len, _ = hidden_states.size()
89
+
90
+ q = rearrange(self.q_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
91
+ k = rearrange(self.k_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
92
+ v = rearrange(self.v_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
93
+
94
+ # equivalent to cu_seqlens in `flash_attn`
95
+ cu_seqlens = kwargs.get('cu_seqlens', None)
96
+
97
+ seqlen_offset, max_seqlen = 0, q_len
98
+ if past_key_values is not None:
99
+ seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
100
+ max_seqlen = q.shape[1] + seqlen_offset
101
+
102
+ if attention_mask is not None:
103
+ # subtract each sequence's padding-token count from its offset
104
+ seqlen_offset = seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1]
105
+ max_seqlen = q.shape[1] + max(seqlen_offset)
106
+
107
+ if self.max_position_embeddings is not None:
108
+ max_seqlen = max(max_seqlen, self.max_position_embeddings)
109
+ q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens)
110
+
111
+ if past_key_values is not None:
112
+ cache_has_content = past_key_values.get_seq_length(self.layer_idx) > 0
113
+ k_cached, v_cached = past_key_values.update(
114
+ attn_state=(k.flatten(-2, -1), v.flatten(-2, -1)),
115
+ layer_idx=self.layer_idx,
116
+ offset=q_len,
117
+ cache_kwargs=dict(window_size=self.window_size)
118
+ )['attn_state']
119
+ if cache_has_content:
120
+ k, v = k_cached, v_cached
121
+ k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim)
122
+ v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)
123
+
124
+ if flash_attn_func is None:
125
+ raise ImportError("Please install Flash Attention via `pip install flash-attn --no-build-isolation` first")
126
+
127
+ # Contains at least one padding token in the sequence
128
+ if attention_mask is not None:
129
+ q, k, v, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(q, k, v, attention_mask, q_len)
130
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
131
+ max_seqlen_q, max_seqlen_k = max_seq_lens
132
+ o = flash_attn_varlen_func(
133
+ q, k, v,
134
+ cu_seqlens_q=cu_seqlens_q,
135
+ cu_seqlens_k=cu_seqlens_k,
136
+ max_seqlen_q=max_seqlen_q,
137
+ max_seqlen_k=max_seqlen_k,
138
+ causal=True,
139
+ window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
140
+ )
141
+ o = pad_input(o, indices_q, batch_size, q_len)
142
+ elif cu_seqlens is not None:
143
+ o = flash_attn_varlen_func(
144
+ q.squeeze(0), k.squeeze(0), v.squeeze(0),
145
+ cu_seqlens_q=cu_seqlens,
146
+ cu_seqlens_k=cu_seqlens,
147
+ max_seqlen_q=max_seqlen,
148
+ max_seqlen_k=max_seqlen,
149
+ causal=True,
150
+ window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
151
+ ).unsqueeze(0)
152
+ else:
153
+ o = flash_attn_func(
154
+ q, k, v,
155
+ causal=True,
156
+ window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
157
+ )
158
+ o = o.reshape(batch_size, q_len, -1)
159
+ o = self.o_proj(o)
160
+
161
+ # flash-attn never materializes attention weights, so there is nothing to return here
161
+ attentions = None
163
+
164
+ return o, attentions, past_key_values
165
+
166
+ def _upad_input(self, q, k, v, attention_mask, q_len):
167
+ batch_size, seq_len, num_key_value_heads, head_dim = k.shape
168
+ cache_mask = attention_mask[:, -seq_len:]
169
+ seqlens = cache_mask.sum(-1, dtype=torch.int32)
170
+ indices_k = torch.nonzero(cache_mask.flatten(), as_tuple=False).flatten()
171
+ max_seqlen_k = seqlens.max().item()
172
+ cu_seqlens_k = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
173
+
174
+ k = index_first_axis(k.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k)
175
+ v = index_first_axis(v.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k)
176
+ if q_len == seq_len:
177
+ q = index_first_axis(q.reshape(batch_size * seq_len, self.num_heads, head_dim), indices_k)
178
+ cu_seqlens_q = cu_seqlens_k
179
+ max_seqlen_q = max_seqlen_k
180
+ indices_q = indices_k
181
+ elif q_len == 1:
182
+ max_seqlen_q = 1
183
+ # There is a memcpy here, that is very bad.
184
+ cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device)
185
+ indices_q = cu_seqlens_q[:-1]
186
+ q = q.squeeze(1)
187
+ else:
188
+ # The -q_len: slice assumes left padding.
189
+ attention_mask = attention_mask[:, -q_len:]
190
+ q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, attention_mask)
191
+
192
+ return q, k, v, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
fla/layers/forgetting_attn.py ADDED
@@ -0,0 +1,109 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ from typing import TYPE_CHECKING, Optional, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import torch.utils.checkpoint
12
+ from einops import rearrange
13
+ from transformers.utils import logging
14
+
15
+ from fla.modules import GroupNorm
16
+ from fla.ops.forgetting_attn.parallel import parallel_forgetting_attn
17
+
18
+ if TYPE_CHECKING:
19
+ from fla.models.utils import Cache
20
+
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+
25
+ class ForgettingAttention(nn.Module):
26
+
27
+ def __init__(
28
+ self,
29
+ hidden_size: int = 2048,
30
+ num_heads: int = 32,
31
+ num_kv_heads: Optional[int] = None,
32
+ qkv_bias: bool = False,
33
+ qk_norm: bool = False,
34
+ window_size: Optional[int] = None,
35
+ use_output_gate: bool = False,
36
+ layer_idx: int = None
37
+ ):
38
+ super().__init__()
39
+
40
+ self.hidden_size = hidden_size
41
+ self.num_heads = num_heads
42
+ if num_kv_heads is None:
43
+ self.num_kv_heads = self.num_heads
44
+ else:
45
+ self.num_kv_heads = num_kv_heads
46
+ self.num_kv_groups = num_heads // self.num_kv_heads
47
+ self.head_dim = self.hidden_size // self.num_heads
48
+ self.kv_dim = self.num_kv_heads * self.head_dim
49
+ self.qkv_bias = qkv_bias
50
+ self.qk_norm = qk_norm
51
+
52
+ self.window_size = window_size
53
+ self.use_output_gate = use_output_gate
54
+ self.layer_idx = layer_idx
55
+
56
+ self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=self.qkv_bias)
57
+ self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
58
+ self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
59
+ self.f_proj = nn.Linear(self.hidden_size, self.num_heads, bias=True)
60
+
61
+ if use_output_gate:
62
+ self.g_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
63
+ self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
64
+
65
+ if qk_norm:
66
+ self.q_norm = GroupNorm(
67
+ num_groups=self.num_heads,
68
+ hidden_size=self.hidden_size,
69
+ is_rms_norm=True,
70
+ )
71
+ self.k_norm = GroupNorm(
72
+ num_groups=self.num_kv_heads,
73
+ hidden_size=self.kv_dim,
74
+ is_rms_norm=True,
75
+ )
76
+
77
+ def forward(
78
+ self,
79
+ hidden_states: torch.Tensor,
80
+ attention_mask: Optional[torch.LongTensor] = None,
81
+ past_key_values: Optional[Cache] = None,
82
+ output_attentions: bool = False,
83
+ use_cache: bool = False,
84
+ **kwargs,
85
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
86
+ if attention_mask is not None:
87
+ assert len(attention_mask.shape) == 2, (
88
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
89
+ "for padding purposes (0 indicating padding). "
90
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
91
+ )
92
+
93
+ cu_seqlens = kwargs.get('cu_seqlens', None)
94
+ q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
95
+ f = F.logsigmoid(self.f_proj(hidden_states).float())
96
+ if self.qk_norm:
97
+ q, k = self.q_norm(q), self.k_norm(k)
98
+
99
+ q = rearrange(q, '... (h d) -> ... h d', d=self.head_dim)
100
+ k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim)
101
+ v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)
102
+
103
+ o = parallel_forgetting_attn(q, k, v, f, cu_seqlens=cu_seqlens)
104
+ o = rearrange(o, '... h d -> ... (h d)')
105
+ if self.use_output_gate:
106
+ o = self.g_proj(hidden_states).sigmoid() * o
107
+ o = self.o_proj(o)
108
+
109
+ return o, None, past_key_values
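Note on `ForgettingAttention.forward`: the layer passes `f = logsigmoid(f_proj(x))` to `parallel_forgetting_attn` alongside `q`, `k` and `v`. The block below is a naive single-head pure-Python sketch of what such a kernel plausibly computes. It rests on two assumptions that this diff does not confirm: that the kernel adds the sum of log forget gates between key position `j` and query position `i` to the attention logit, and that it scales logits by `head_dim ** -0.5`. The function names are hypothetical.

```python
import math


def log_sigmoid(x):
    # Numerically stable log(sigmoid(x)).
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def naive_forgetting_attention(q, k, v, f_logits):
    """q, k: [t][d] rows; v: [t][e] rows; f_logits: [t] pre-sigmoid gate values."""
    t, d = len(q), len(q[0])
    scale = d ** -0.5  # assumed default scale
    # cum[i + 1] = sum of log f_l for l <= i
    cum = [0.0]
    for x in f_logits:
        cum.append(cum[-1] + log_sigmoid(x))
    out = []
    for i in range(t):
        # Decay bias for key j is sum_{l=j+1..i} log f_l (zero when j == i).
        logits = [
            scale * sum(a * b for a, b in zip(q[i], k[j])) + (cum[i + 1] - cum[j + 1])
            for j in range(i + 1)
        ]
        m = max(logits)
        w = [math.exp(x - m) for x in logits]
        s = sum(w)
        out.append([sum(w[j] * v[j][c] for j in range(i + 1)) / s
                    for c in range(len(v[0]))])
    return out


q = [[0.1, 0.2], [0.3, -0.1], [0.0, 0.5]]
k = [[0.2, 0.1], [-0.3, 0.4], [0.5, 0.5]]
v = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
closed = naive_forgetting_attention(q, k, v, [-50.0] * 3)  # gates near 0
opened = naive_forgetting_attention(q, k, v, [50.0] * 3)   # gates near 1
```

With gates near 0 each position attends almost only to itself; with gates near 1 the bias vanishes and the result approaches ordinary causal softmax attention.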
fla/layers/gated_deltaproduct.py ADDED
@@ -0,0 +1,351 @@
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from einops import rearrange
10
+
11
+ from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution
12
+ from fla.ops.delta_rule import chunk_delta_rule
13
+ from fla.ops.gated_delta_rule import chunk_gated_delta_rule
14
+
15
+ if TYPE_CHECKING:
16
+ from transformers.processing_utils import Unpack
17
+
18
+ from fla.models.utils import Cache
19
+
20
+
21
+ def elu_p1(x):
22
+ return (F.elu(x, 1.0, False) + 1.0).to(x)
23
+
24
+
25
+ def sum_norm(x):
26
+ return (x / x.sum(-1, keepdim=True)).to(x)
27
+
28
+
29
+ def interleave_multiple_sequences(*sequences):
30
+ """
31
+ Interleave multiple sequences together.
32
+ For example, with sequences [A1, A2], [B1, B2], [C1, C2],
33
+ returns [A1, B1, C1, A2, B2, C2]
34
+ """
35
+ if isinstance(sequences[0], (list, tuple)):
36
+ sequences = sequences[0]
37
+
38
+ if len(sequences) == 1:
39
+ return sequences[0]
40
+
41
+ # All sequences should have the same shape
42
+ assert all(s.shape == sequences[0].shape for s in sequences)
43
+
44
+ # Get the original shape
45
+ batch_size, seq_len, *rest = sequences[0].shape
46
+
47
+ # Stack sequences along a new dimension
48
+ stacked = torch.stack(sequences, dim=2)
49
+
50
+ # Reshape to interleave
51
+ reshaped = stacked.view(batch_size, seq_len * len(sequences), *rest)
52
+
53
+ return reshaped
54
+
55
+
56
+ class GatedDeltaProduct(nn.Module):
57
+ """
58
+ Generalized version of GatedDoubleDeltaNet that supports an arbitrary number of Householder transformations.
59
+ """
60
+
61
+ def __init__(
62
+ self,
63
+ hidden_size: int = 2048,
64
+ expand_v: float = 2,
65
+ head_dim: int = 256,
66
+ num_heads: int = 6,
67
+ num_householder: int = 2, # New parameter for number of householder transformations
68
+ mode: str = "chunk",
69
+ use_gate: bool = True,
70
+ use_forget_gate: bool = True, # True: Gated DeltaProduct; False: plain DeltaProduct
71
+ use_short_conv: bool = True,
72
+ conv_size: int = 4,
73
+ conv_bias: bool = False,
74
+ layer_idx: int | None = None,
75
+ norm_eps: float = 1e-5,
76
+ allow_neg_eigval: bool = False, # eigenvalue range of the (Gated) DeltaProduct transition: [-1, 1] when True, [0, 1] when False
77
+ **kwargs,
78
+ ) -> None:
79
+ super().__init__()
80
+
81
+ self.mode = mode
82
+ self.hidden_size = hidden_size
83
+ self.expand_v = expand_v
84
+ self.use_gate = use_gate
85
+ self.use_short_conv = use_short_conv
86
+ self.conv_size = conv_size
87
+ self.conv_bias = conv_bias
88
+ self.head_dim = head_dim
89
+ self.num_heads = num_heads
90
+ self.num_householder = num_householder
91
+ self.allow_neg_eigval = allow_neg_eigval
92
+ self.use_forget_gate = use_forget_gate
93
+ self.key_dim = self.num_heads * self.head_dim
94
+ self.value_dim = int(self.key_dim * self.expand_v)
95
+ self.head_qk_dim = head_dim
96
+ self.head_v_dim = int(head_dim * self.expand_v)
97
+ self.layer_idx = layer_idx
98
+ self.silu = nn.SiLU()
99
+ assert mode in ["chunk", "fused_recurrent"], f"Not supported mode `{mode}`."
100
+ # Create multiple projection layers for each householder transformation
101
+ self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
102
+
103
+ self.k_projs = nn.ModuleList(
104
+ [
105
+ nn.Linear(hidden_size, self.key_dim, bias=False)
106
+ for _ in range(num_householder)
107
+ ]
108
+ )
109
+ self.v_projs = nn.ModuleList(
110
+ [
111
+ nn.Linear(hidden_size, self.value_dim, bias=False)
112
+ for _ in range(num_householder)
113
+ ]
114
+ )
115
+ self.b_projs = nn.ModuleList(
116
+ [
117
+ nn.Linear(hidden_size, self.num_heads, bias=False)
118
+ for _ in range(num_householder)
119
+ ]
120
+ )
121
+ if use_short_conv:
122
+ self.q_conv1ds = nn.ModuleList(
123
+ [
124
+ ShortConvolution(
125
+ hidden_size=self.key_dim,
126
+ kernel_size=conv_size,
127
+ activation="silu",
128
+ )
129
+ for _ in range(num_householder)
130
+ ]
131
+ )
132
+ self.k_conv1ds = nn.ModuleList(
133
+ [
134
+ ShortConvolution(
135
+ hidden_size=self.key_dim,
136
+ kernel_size=conv_size,
137
+ activation="silu",
138
+ )
139
+ for _ in range(num_householder)
140
+ ]
141
+ )
142
+ self.v_conv1ds = nn.ModuleList(
143
+ [
144
+ ShortConvolution(
145
+ hidden_size=self.value_dim,
146
+ kernel_size=conv_size,
147
+ activation="silu",
148
+ )
149
+ for _ in range(num_householder)
150
+ ]
151
+ )
152
+
153
+ if self.use_forget_gate:
154
+ self.a_proj = nn.Linear(hidden_size, self.num_heads, bias=False)
155
+ A = torch.empty(self.num_heads, dtype=torch.float32).uniform_(0, 16)
156
+ A_log = torch.log(A)
157
+ self.A_log = nn.Parameter(A_log)
158
+ self.A_log._no_weight_decay = True
159
+
160
+ # Initialize dt parameters
161
+ dt_min = 0.001
162
+ dt_max = 0.1
163
+ dt_init_floor = 1e-4
164
+ dt = torch.exp(
165
+ torch.rand(self.num_heads) * (math.log(dt_max) - math.log(dt_min))
166
+ + math.log(dt_min)
167
+ )
168
+ dt = torch.clamp(dt, min=dt_init_floor)
169
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
170
+ self.dt_bias = nn.Parameter(inv_dt)
171
+ self.dt_bias._no_weight_decay = True
172
+
173
+ if use_gate:
174
+ self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
175
+ self.o_norm = FusedRMSNormSwishGate(self.head_v_dim, eps=norm_eps)
176
+ else:
177
+ self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps)
178
+ self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
179
+ self.k_id = torch.nn.Identity()
180
+ self.apply(self._initialize_weights)
181
+
182
+ def _initialize_weights(self, module: nn.Module):
183
+ if getattr(module, "_is_hf_initialized", False):
184
+ return
185
+ if isinstance(module, nn.Linear):
186
+ nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
187
+ if module.bias is not None:
188
+ nn.init.zeros_(module.bias)
189
+ module._is_hf_initialized = True
190
+
191
+ def forward(
192
+ self,
193
+ hidden_states: torch.Tensor,
194
+ attention_mask: Optional[torch.Tensor] = None,
195
+ past_key_values: Optional[Cache] = None,
196
+ use_cache: Optional[bool] = False,
197
+ output_attentions: Optional[bool] = False,
198
+ **kwargs: Unpack[Dict],
199
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
200
+ if attention_mask is not None:
201
+ assert len(attention_mask.shape) == 2, (
202
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
203
+ "for padding purposes (0 indicating padding)."
204
+ )
205
+
206
+ mode = (
207
+ "chunk" # 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
208
+ )
209
+ if self.training:
210
+ assert mode == "chunk", "Only chunk mode is supported in training."
211
+
212
+ last_state = None
213
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
214
+ last_state = past_key_values[self.layer_idx]
215
+
216
+ # Process each householder transformation
217
+ ks, vs, betas = [], [], []
218
+ conv_states = []
219
+
220
+ for i in range(self.num_householder):
221
+ if self.use_short_conv:
222
+ conv_state_q, conv_state_k, conv_state_v = None, None, None
223
+ if last_state is not None:
224
+ conv_state_q, conv_state_k, conv_state_v = last_state["conv_state"][
225
+ i
226
+ ]
227
+ conv_mask = (
228
+ attention_mask[:, -hidden_states.shape[1]:]
229
+ if attention_mask is not None
230
+ else None
231
+ )
232
+
233
+ k, conv_state_k = self.k_conv1ds[i](
234
+ x=self.k_projs[i](hidden_states),
235
+ mask=conv_mask,
236
+ cache=conv_state_k,
237
+ output_final_state=use_cache,
238
+ )
239
+ v, conv_state_v = self.v_conv1ds[i](
240
+ x=self.v_projs[i](hidden_states),
241
+ mask=conv_mask,
242
+ cache=conv_state_v,
243
+ output_final_state=use_cache,
244
+ )
245
+ conv_states.append((conv_state_q, conv_state_k, conv_state_v))
246
+ else:
247
+ k = self.silu(self.k_projs[i](hidden_states))
248
+ v = self.silu(self.v_projs[i](hidden_states))
249
+
250
+ ks.append(k)
251
+ vs.append(v)
252
+
253
+ beta = self.b_projs[i](
254
+ hidden_states
255
+ ).sigmoid() # bs, sequence_length, num_heads
256
+ if attention_mask is not None:
257
+ beta = beta.mul(attention_mask[:, -hidden_states.shape[1]:, None])
258
+ if self.allow_neg_eigval:
259
+ beta = beta * 2
260
+ betas.append(beta)
261
+
262
+ if self.use_short_conv:
263
+ q, conv_state_q = self.q_conv1ds[0](
264
+ x=self.q_proj(hidden_states),
265
+ mask=conv_mask,
266
+ cache=conv_state_q,
267
+ output_final_state=use_cache,
268
+ )
269
+ else:
270
+ q = self.silu(self.q_proj(hidden_states))
271
+ q = interleave_multiple_sequences(
272
+ [torch.zeros_like(q)] * (self.num_householder - 1) + [q]
273
+ )
274
+ # Interleave all sequences
275
+ k = interleave_multiple_sequences(ks)
276
+ v = interleave_multiple_sequences(vs)
277
+ beta = interleave_multiple_sequences(betas)
278
+
279
+ q, k, v = (
280
+ rearrange(x, "b t (h d) -> b t h d", h=self.num_heads) for x in (q, k, v)
281
+ )
282
+
283
+ recurrent_state = (
284
+ last_state["recurrent_state"] if last_state is not None else None
285
+ )
286
+ offsets = kwargs.get("offsets")
287
+
288
+ if mode == "chunk":
289
+ if self.use_forget_gate:
290
+ g = -self.A_log.float().exp() * F.softplus(
291
+ self.a_proj(hidden_states).float() + self.dt_bias
292
+ )
293
+ if attention_mask is not None:
294
+ g = g.mul(attention_mask[:, -g.shape[-2]:, None])
295
+
296
+ # Interleave g with zeros for non-first transformations
297
+ g = interleave_multiple_sequences(
298
+ [g] + [torch.zeros_like(g)] * (self.num_householder - 1)
299
+ )
300
+
301
+ o, recurrent_state = chunk_gated_delta_rule(
302
+ q=q,
303
+ k=k,
304
+ v=v,
305
+ g=g,
306
+ beta=beta,
307
+ initial_state=recurrent_state,
308
+ output_final_state=use_cache,
309
+ cu_seqlens=offsets,
310
+ head_first=False,
311
+ use_qk_l2norm_in_kernel=True
312
+ )
313
+ else:
314
+ o, recurrent_state = chunk_delta_rule(
315
+ q=q,
316
+ k=k,
317
+ v=v,
318
+ beta=beta,
319
+ initial_state=recurrent_state,
320
+ output_final_state=use_cache,
321
+ cu_seqlens=offsets,
322
+ head_first=False,
323
+ use_qk_l2norm_in_kernel=True
324
+ )
325
+ else:
326
+ raise NotImplementedError(f"Not supported mode `{mode}`.")
327
+
328
+ # Take every nth element for n householder transformations
329
+ o = o[:, self.num_householder - 1:: self.num_householder, :]
330
+
331
+ if past_key_values is not None:
332
+ past_key_values.update(
333
+ recurrent_state=recurrent_state,
334
+ conv_state=conv_states if self.use_short_conv else None,
335
+ layer_idx=self.layer_idx,
336
+ offset=q.shape[2],
337
+ )
338
+
339
+ if self.use_gate:
340
+ g = rearrange(
341
+ self.g_proj(hidden_states),
342
+ "... (h d) -> ... h d",
343
+ h=self.num_heads,
344
+ )
345
+ o = self.o_norm(o, g)
346
+ else:
347
+ o = self.o_norm(o)
348
+ o = rearrange(o, "b t h d -> b t (h d)")
349
+ o = self.o_proj(o)
350
+
351
+ return o, None, past_key_values
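Note on the interleaving used in `forward` above: each token is expanded into `num_householder` sub-steps along the time axis, the query is non-zero only on the last sub-step, the forget gate is applied only on the first, and the output is read back with the stride slice `o[:, num_householder - 1::num_householder]`. The snippet below is a minimal sketch of that indexing on plain Python lists. The helper names `interleave` and `readout` are hypothetical, and the round-robin ordering is an assumption about `interleave_multiple_sequences`, which is defined outside this excerpt.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def interleave(seqs: Sequence[Sequence[T]]) -> List[T]:
    """Round-robin merge of equally long sequences along the time axis.

    With n sequences of length t the result has length t * n and is laid out
    as [s0[0], s1[0], ..., s0[1], s1[1], ...].
    """
    if len({len(s) for s in seqs}) != 1:
        raise ValueError("all sequences must have the same length")
    return [s[t] for t in range(len(seqs[0])) for s in seqs]


def readout(o: Sequence[T], n: int) -> List[T]:
    """Keep the last sub-step of every token, i.e. o[n - 1::n]."""
    return list(o[n - 1::n])


num_householder = 3
q = ["q0", "q1"]                      # one query per token
ks = [["k0a", "k1a"], ["k0b", "k1b"], ["k0c", "k1c"]]  # one key list per sub-step

# Queries are zero on all but the last sub-step of each token.
q_long = interleave([[0] * len(q)] * (num_householder - 1) + [q])
k_long = interleave(ks)

print(q_long)  # [0, 0, 'q0', 0, 0, 'q1']
print(k_long)  # ['k0a', 'k0b', 'k0c', 'k1a', 'k1b', 'k1c']
print(readout(list(range(6)), num_householder))  # [2, 5]
```

The readout positions line up with the non-zero query positions, so the layer returns one output per original token even though the kernel runs over a sequence that is `num_householder` times longer.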
fla/ops/utils/cumsum.py ADDED
@@ -0,0 +1,400 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.utils import check_shared_mem, input_guard
11
+
12
+ BS_LIST = [32, 64] if check_shared_mem() else [16, 32]
13
+
14
+
15
+ @triton.heuristics({
16
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
17
+ })
18
+ @triton.autotune(
19
+ configs=[
20
+ triton.Config({}, num_warps=num_warps)
21
+ for num_warps in [1, 2, 4, 8]
22
+ ],
23
+ key=['BT']
24
+ )
25
+ @triton.jit(do_not_specialize=['T'])
26
+ def chunk_local_cumsum_scalar_kernel(
27
+ s,
28
+ o,
29
+ offsets,
30
+ indices,
31
+ T,
32
+ H: tl.constexpr,
33
+ BT: tl.constexpr,
34
+ HEAD_FIRST: tl.constexpr,
35
+ USE_OFFSETS: tl.constexpr,
36
+ REVERSE: tl.constexpr
37
+ ):
38
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
39
+ i_b, i_h = i_bh // H, i_bh % H
40
+ if USE_OFFSETS:
41
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
42
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
43
+ T = eos - bos
44
+ else:
45
+ bos, eos = i_b * T, i_b * T + T
46
+
47
+ if HEAD_FIRST:
48
+ p_s = tl.make_block_ptr(s + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,))
49
+ p_o = tl.make_block_ptr(o + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,))
50
+ else:
51
+ p_s = tl.make_block_ptr(s + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
52
+ p_o = tl.make_block_ptr(o + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
53
+ # [BT]
54
+ b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32)
55
+ b_o = tl.cumsum(b_s, axis=0)
56
+ if REVERSE:
57
+ b_z = tl.sum(b_s, axis=0)
58
+ b_o = -b_o + b_z[None] + b_s
59
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))
60
+
61
+
62
+ @triton.heuristics({
63
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
64
+ })
65
+ @triton.autotune(
66
+ configs=[
67
+ triton.Config({'BS': BS}, num_warps=num_warps)
68
+ for BS in BS_LIST
69
+ for num_warps in [2, 4, 8]
70
+ ],
71
+ key=['S', 'BT'],
72
+ )
73
+ @triton.jit(do_not_specialize=['T'])
74
+ def chunk_local_cumsum_vector_kernel(
75
+ s,
76
+ o,
77
+ offsets,
78
+ indices,
79
+ T,
80
+ H: tl.constexpr,
81
+ S: tl.constexpr,
82
+ BT: tl.constexpr,
83
+ BS: tl.constexpr,
84
+ HEAD_FIRST: tl.constexpr,
85
+ USE_OFFSETS: tl.constexpr,
86
+ REVERSE: tl.constexpr
87
+ ):
88
+ i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
89
+ i_b, i_h = i_bh // H, i_bh % H
90
+ if USE_OFFSETS:
91
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
92
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
93
+ T = eos - bos
94
+ else:
95
+ bos, eos = i_b * T, i_b * T + T
96
+
97
+ o_i = tl.arange(0, BT)
98
+ if REVERSE:
99
+ m_s = tl.where(o_i[:, None] <= o_i[None, :], 1., 0.)
100
+ else:
101
+ m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)
102
+
103
+ if HEAD_FIRST:
104
+ p_s = tl.make_block_ptr(s + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
105
+ p_o = tl.make_block_ptr(o + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
106
+ else:
107
+ p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
108
+ p_o = tl.make_block_ptr(o + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
109
+ # [BT, BS]
110
+ b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
111
+ b_o = tl.dot(m_s, b_s, allow_tf32=False)
112
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
113
+
114
+
115
+ @triton.heuristics({
116
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
117
+ })
118
+ @triton.autotune(
119
+ configs=[
120
+ triton.Config({'BT': 16}, num_warps=2),
121
+ triton.Config({'BT': 32}, num_warps=4),
122
+ triton.Config({'BT': 32}, num_warps=2),
123
+ triton.Config({'BT': 64}, num_warps=8),
124
+ triton.Config({'BT': 64}, num_warps=4),
125
+ ],
126
+ key=[]
127
+ )
128
+ @triton.jit(do_not_specialize=['T'])
129
+ def chunk_global_cumsum_scalar_kernel(
130
+ s,
131
+ o,
132
+ offsets,
133
+ T,
134
+ H: tl.constexpr,
135
+ BT: tl.constexpr,
136
+ HEAD_FIRST: tl.constexpr,
137
+ USE_OFFSETS: tl.constexpr,
138
+ REVERSE: tl.constexpr
139
+ ):
140
+ i_bh = tl.program_id(0)
141
+ i_b, i_h = i_bh // H, i_bh % H
142
+ if USE_OFFSETS:
143
+ bos, eos = tl.load(offsets + i_b).to(tl.int32), tl.load(offsets + i_b + 1).to(tl.int32)
144
+ else:
145
+ bos, eos = i_b * T, i_b * T + T
146
+ T = eos - bos
147
+
148
+ b_z = tl.zeros([], dtype=tl.float32)
149
+ NT = tl.cdiv(T, BT)
150
+ for i_c in range(NT):
151
+ i_t = NT-1-i_c if REVERSE else i_c
152
+ if HEAD_FIRST:
153
+ p_s = tl.make_block_ptr(s + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,))
154
+ p_o = tl.make_block_ptr(o + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,))
155
+ else:
156
+ p_s = tl.make_block_ptr(s + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
157
+ p_o = tl.make_block_ptr(o + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
158
+ b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32)
159
+ b_o = tl.cumsum(b_s, axis=0)
160
+ b_ss = tl.sum(b_s, 0)
161
+ if REVERSE:
162
+ b_o = -b_o + b_ss + b_s
163
+ b_o += b_z
164
+ if i_c >= 0:
165
+ b_z += b_ss
166
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))
167
+
168
+
169
+ @triton.heuristics({
170
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
171
+ })
172
+ @triton.autotune(
173
+ configs=[
174
+ triton.Config({'BT': BT}, num_warps=num_warps)
175
+ for BT in [16, 32, 64]
176
+ for num_warps in [2, 4, 8]
177
+ ],
178
+ key=['S']
179
+ )
180
+ @triton.jit(do_not_specialize=['T'])
181
+ def chunk_global_cumsum_vector_kernel(
182
+ s,
183
+ z,
184
+ offsets,
185
+ T,
186
+ H: tl.constexpr,
187
+ S: tl.constexpr,
188
+ BT: tl.constexpr,
189
+ BS: tl.constexpr,
190
+ HEAD_FIRST: tl.constexpr,
191
+ USE_OFFSETS: tl.constexpr,
192
+ REVERSE: tl.constexpr
193
+ ):
194
+ i_s, i_bh = tl.program_id(0), tl.program_id(1)
195
+ i_b, i_h = i_bh // H, i_bh % H
196
+ if USE_OFFSETS:
197
+ bos, eos = tl.load(offsets + i_b).to(tl.int32), tl.load(offsets + i_b + 1).to(tl.int32)
198
+ else:
199
+ bos, eos = i_b * T, i_b * T + T
200
+ T = eos - bos
201
+
202
+ o_i = tl.arange(0, BT)
203
+ if REVERSE:
204
+ m_s = tl.where(o_i[:, None] <= o_i[None, :], 1., 0.)
205
+ else:
206
+ m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)
207
+
208
+ b_z = tl.zeros([BS], dtype=tl.float32)
209
+ NT = tl.cdiv(T, BT)
210
+ for i_c in range(NT):
211
+ i_t = NT-1-i_c if REVERSE else i_c
212
+ if HEAD_FIRST:
213
+ p_s = tl.make_block_ptr(s + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
214
+ p_z = tl.make_block_ptr(z + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
215
+ else:
216
+ p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
217
+ p_z = tl.make_block_ptr(z + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
218
+ # [BT, BS]
219
+ b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
220
+ b_c = b_z[None, :] + tl.dot(m_s, b_s, allow_tf32=False)
221
+ tl.store(p_z, b_c.to(p_z.dtype.element_ty), boundary_check=(0, 1))
222
+ if i_c >= 0:
223
+ b_z += tl.sum(b_s, 0)
224
+
225
+
226
+ def chunk_local_cumsum_scalar(
227
+ g: torch.Tensor,
228
+ chunk_size: int,
229
+ reverse: bool = False,
230
+ offsets: Optional[torch.Tensor] = None,
231
+ indices: Optional[torch.Tensor] = None,
232
+ head_first: bool = True,
233
+ output_dtype: Optional[torch.dtype] = torch.float
234
+ ) -> torch.Tensor:
235
+ if head_first:
236
+ B, H, T = g.shape
237
+ else:
238
+ B, T, H = g.shape
239
+ if offsets is not None:
240
+ B = len(offsets) - 1
241
+ assert chunk_size == 2**(chunk_size.bit_length()-1), "chunk_size must be a power of 2"
242
+ BT = chunk_size
243
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
244
+ g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
245
+ grid = (NT, B * H)
246
+ chunk_local_cumsum_scalar_kernel[grid](
247
+ g_org,
248
+ g,
249
+ offsets,
250
+ indices,
251
+ T=T,
252
+ H=H,
253
+ BT=BT,
254
+ HEAD_FIRST=head_first,
255
+ REVERSE=reverse
256
+ )
257
+ return g
258
+
259
+
260
+ def chunk_local_cumsum_vector(
261
+ g: torch.Tensor,
262
+ chunk_size: int,
263
+ reverse: bool = False,
264
+ offsets: Optional[torch.Tensor] = None,
265
+ indices: Optional[torch.Tensor] = None,
266
+ head_first: bool = True,
267
+ output_dtype: Optional[torch.dtype] = torch.float
268
+ ) -> torch.Tensor:
269
+ if head_first:
270
+ B, H, T, S = g.shape
271
+ else:
272
+ B, T, H, S = g.shape
273
+ BT = chunk_size
274
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
275
+ assert chunk_size == 2**(chunk_size.bit_length()-1), "chunk_size must be a power of 2"
276
+
277
+ g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
278
+ def grid(meta): return (triton.cdiv(meta['S'], meta['BS']), NT, B * H)
279
+ # keep cummulative normalizer in fp32
280
+ # this kernel is equivalent to
281
+ # g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1)
282
+ chunk_local_cumsum_vector_kernel[grid](
283
+ g_org,
284
+ g,
285
+ offsets,
286
+ indices,
287
+ T=T,
288
+ H=H,
289
+ S=S,
290
+ BT=BT,
291
+ HEAD_FIRST=head_first,
292
+ REVERSE=reverse
293
+ )
294
+ return g
295
+
296
+
297
+ @input_guard
298
+ def chunk_global_cumsum_scalar(
299
+ s: torch.Tensor,
300
+ dtype: Optional[torch.dtype] = None,
301
+ reverse: bool = False,
302
+ offsets: Optional[torch.Tensor] = None,
303
+ head_first: bool = True,
304
+ output_dtype: Optional[torch.dtype] = torch.float
305
+ ) -> torch.Tensor:
306
+ dtype = dtype or s.dtype
307
+ if head_first:
308
+ B, H, T = s.shape
309
+ else:
310
+ B, T, H = s.shape
311
+ if offsets is not None:
312
+ B = len(offsets) - 1
313
+ grid = (B * H,)
314
+ z = torch.empty_like(s, dtype=output_dtype or dtype)
315
+ chunk_global_cumsum_scalar_kernel[grid](
316
+ s,
317
+ z,
318
+ offsets,
319
+ T=T,
320
+ H=H,
321
+ HEAD_FIRST=head_first,
322
+ REVERSE=reverse
323
+ )
324
+ return z
325
+
326
+
327
+ @input_guard
328
+ def chunk_global_cumsum_vector(
329
+ s: torch.Tensor,
330
+ dtype: Optional[torch.dtype] = None,
331
+ reverse: bool = False,
332
+ offsets: Optional[torch.Tensor] = None,
333
+ head_first: bool = True,
334
+ output_dtype: Optional[torch.dtype] = torch.float
335
+ ) -> torch.Tensor:
336
+ dtype = dtype or s.dtype
337
+ if head_first:
338
+ B, H, T, S = s.shape
339
+ else:
340
+ B, T, H, S = s.shape
341
+ BS = min(32, triton.next_power_of_2(S))
342
+ if offsets is not None:
343
+ B = len(offsets) - 1
344
+ grid = (triton.cdiv(S, BS), B * H)
345
+ z = torch.empty_like(s, dtype=output_dtype or dtype)
346
+ chunk_global_cumsum_vector_kernel[grid](
347
+ s,
348
+ z,
349
+ offsets,
350
+ T=T,
351
+ H=H,
352
+ S=S,
353
+ BS=BS,
354
+ HEAD_FIRST=head_first,
355
+ REVERSE=reverse
356
+ )
357
+ return z
358
+
359
+
360
+ @input_guard
361
+ def chunk_global_cumsum(
362
+ s: torch.Tensor,
363
+ dtype: Optional[torch.dtype] = None,
364
+ reverse: bool = False,
365
+ offsets: Optional[torch.Tensor] = None,
366
+ head_first: bool = True,
367
+ output_dtype: Optional[torch.dtype] = torch.float
368
+ ) -> torch.Tensor:
369
+ if offsets is not None:
370
+ assert s.shape[0] == 1, "Only batch size 1 is supported when offsets are provided"
371
+ if len(s.shape) == 3:
372
+ return chunk_global_cumsum_scalar(s, dtype, reverse, offsets, head_first, output_dtype)
373
+ elif len(s.shape) == 4:
374
+ return chunk_global_cumsum_vector(s, dtype, reverse, offsets, head_first, output_dtype)
375
+ else:
376
+ raise ValueError(f"Unsupported input shape {s.shape}. "
377
+ f"which should be [B, H, T]/[B, H, T, D] if `head_first=True` "
378
+ f"or [B, T, H]/[B, T, H, D] otherwise")
379
+
380
+
381
+ @input_guard
382
+ def chunk_local_cumsum(
383
+ g: torch.Tensor,
384
+ chunk_size: int,
385
+ reverse: bool = False,
386
+ offsets: Optional[torch.Tensor] = None,
387
+ indices: Optional[torch.Tensor] = None,
388
+ head_first: bool = True,
389
+ output_dtype: Optional[torch.dtype] = torch.float
390
+ ) -> torch.Tensor:
391
+ if offsets is not None:
392
+ assert g.shape[0] == 1, "Only batch size 1 is supported when offsets are provided"
393
+ if len(g.shape) == 3:
394
+ return chunk_local_cumsum_scalar(g, chunk_size, reverse, offsets, indices, head_first, output_dtype)
395
+ elif len(g.shape) == 4:
396
+ return chunk_local_cumsum_vector(g, chunk_size, reverse, offsets, indices, head_first, output_dtype)
397
+ else:
398
+ raise ValueError(f"Unsupported input shape {g.shape}, "
399
+ f"which should be [B, H, T]/[B, H, T, D] if `head_first=True` "
400
+ f"or [B, T, H]/[B, T, H, D] otherwise")
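Note on the semantics of the cumsum kernels above: the local variants restart the running sum at every chunk of `chunk_size` positions, the global variants carry it across the whole sequence, and `reverse=True` produces inclusive suffix sums (the kernels compute these as `total - prefix + x`). The following is a minimal pure-Python reference for a single 1-D sequence, intended only as a sketch for checking outputs. The function names are hypothetical, and it ignores batching, heads, variable-length offsets, and the fp32 accumulation done by the Triton kernels.

```python
from itertools import accumulate
from typing import List, Sequence


def local_cumsum_ref(x: Sequence[int], chunk_size: int, reverse: bool = False) -> List[int]:
    """Cumulative sum that restarts at every chunk boundary."""
    out: List[int] = []
    for start in range(0, len(x), chunk_size):
        chunk = list(x[start:start + chunk_size])
        if reverse:
            # Inclusive suffix sum: out[i] = sum(chunk[i:])
            out.extend(list(accumulate(chunk[::-1]))[::-1])
        else:
            # Inclusive prefix sum: out[i] = sum(chunk[:i + 1])
            out.extend(accumulate(chunk))
    return out


def global_cumsum_ref(x: Sequence[int], reverse: bool = False) -> List[int]:
    """Cumulative sum over the whole sequence (one chunk spanning everything)."""
    return local_cumsum_ref(x, max(len(x), 1), reverse)


x = [1, 2, 3, 4, 5]
print(local_cumsum_ref(x, 2))                # [1, 3, 3, 7, 5]
print(local_cumsum_ref(x, 2, reverse=True))  # [3, 2, 7, 4, 5]
print(global_cumsum_ref(x))                  # [1, 3, 6, 10, 15]
print(global_cumsum_ref(x, reverse=True))    # [15, 14, 12, 9, 5]
```

A last chunk shorter than `chunk_size` is handled by the slice here; the kernels handle the same case through `boundary_check`.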
flame/components/__init__.py ADDED
File without changes
flame/config_manager.py ADDED
@@ -0,0 +1,940 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import argparse
8
+ import sys
9
+ from collections import defaultdict
10
+ from typing import Tuple
11
+
12
+ import torch
13
+
14
+ try:
15
+ import tomllib
16
+ except ModuleNotFoundError:
17
+ import tomli as tomllib
18
+
19
+ from torchtitan.tools.logging import logger
20
+
21
+ TORCH_DTYPE_MAP = {
22
+ "float16": torch.float16,
23
+ "float32": torch.float32,
24
+ "bfloat16": torch.bfloat16,
25
+ }
26
+
27
+
28
+ def string_list(raw_arg):
29
+ """Comma-separated string list argument."""
30
+ return [s.strip() for s in raw_arg.split(",") if s.strip()]
31
+
32
+
33
+ def check_string_list_argument(args_dict: dict[str, any], fullargname: str):
34
+ section, name = fullargname.split(".")
35
+ # Split string list which are still raw strings.
36
+ if (
37
+ section in args_dict
38
+ and name in args_dict[section]
39
+ and isinstance(args_dict[section][name], str)
40
+ ):
41
+ sec = args_dict[section]
42
+ sec[name] = string_list(sec[name])
43
+
44
+
45
+ class JobConfig:
46
+ """
47
+ A helper class to manage the train configuration.
48
+ Semantics:
49
+ - Default config is loaded from a toml file. If no toml file is provided,
50
+ then the default config is loaded from argparse defaults.
51
+ - if toml file has missing keys, they are filled with argparse defaults.
52
+ - if additional explicit cmd args are provided in addition to the toml
53
+ file, they will override the toml config and the argparse defaults
54
+
55
+ precedence order: cmdline > toml > argparse default
56
+
57
+ Arg parsing semantics:
58
+
59
+ Each argument starts with <prefix>_ which is the section name in the toml file
60
+ followed by name of the option in the toml file. For ex,
61
+ model.name translates to:
62
+ [model]
63
+ name
64
+ in the toml file
65
+ """
66
+
67
+ def __init__(self):
68
+ self.args_dict = None
69
+ # main parser
70
+ self.parser = argparse.ArgumentParser(description="torchtitan arg parser.")
71
+
72
+ self.parser.add_argument(
73
+ "--job.config_file",
74
+ type=str,
75
+ default=None,
76
+ help="Job config file",
77
+ )
78
+
79
+ # job level configs
80
+ self.parser.add_argument(
81
+ "--job.dump_folder",
82
+ type=str,
83
+ default="./torchtitan/outputs",
84
+ help="Folder to dump job outputs",
85
+ )
86
+ self.parser.add_argument(
87
+ "--job.description",
88
+ type=str,
89
+ default="default job",
90
+ help="Description of the job",
91
+ )
92
+ self.parser.add_argument(
93
+ "--job.use_for_integration_test",
94
+ action="store_true",
95
+ help="Add this config to the integration test suite",
96
+ )
97
+ self.parser.add_argument(
98
+ "--job.print_args",
99
+ action="store_true",
100
+ help="Print the args to terminal",
101
+ )
102
+
103
+ # model configs
104
+ self.parser.add_argument(
105
+ "--model.name",
106
+ type=str,
107
+ default="fla",
108
+ help="Which model to train",
109
+ )
110
+ self.parser.add_argument(
111
+ "--model.config",
112
+ type=str,
113
+ default="fla-hub/transformer-1.3B-100B",
114
+ help="Path to the model config",
115
+ )
116
+ self.parser.add_argument(
117
+ "--model.tokenizer_path",
118
+ type=str,
119
+ default="fla-hub/transformer-1.3B-100B",
120
+ help="Tokenizer path",
121
+ )
122
+ self.parser.add_argument(
123
+ "--model.converters",
124
+ type=string_list,
125
+ nargs="+",
126
+ default=[],
127
+ help="""
128
+ Comma separated list of converters to apply to the model.
129
+ For instance, the `float8` converter swaps `torch.nn.Linear`
130
+ with `Float8Linear`. This feature requires you to install 'torchao'
131
+ which can be found here: https://github.com/pytorch/ao
132
+ """,
133
+ )
134
+ self.parser.add_argument(
135
+ "--model.print_after_conversion",
136
+ action="store_true",
137
+ help="""
138
+ If true, model definition will be printed to stdout after all model
139
+ converters have been applied.
140
+ """,
141
+ )
142
+
143
+ # profiling configs
144
+ self.parser.add_argument(
145
+ "--profiling.enable_profiling",
146
+ action="store_true",
147
+ help="Whether to enable pytorch profiler",
148
+ )
149
+ self.parser.add_argument(
150
+ "--profiling.save_traces_folder",
151
+ type=str,
152
+ default="profile_traces",
153
+ help="Trace files location",
154
+ )
155
+ self.parser.add_argument(
156
+ "--profiling.profile_freq",
157
+ type=int,
158
+ default=10,
159
+ help="How often to collect profiler traces, in iterations",
160
+ )
161
+ self.parser.add_argument(
162
+ "--profiling.enable_memory_snapshot",
163
+ action="store_true",
164
+ help="Whether to dump memory snapshot",
165
+ )
166
+ self.parser.add_argument(
167
+ "--profiling.save_memory_snapshot_folder",
168
+ type=str,
169
+ default="memory_snapshot",
170
+ help="Memory snapshot files location",
171
+ )
172
+
173
+ # optimizer configs
174
+ self.parser.add_argument(
175
+ "--optimizer.name", type=str, default="AdamW", help="Optimizer to use"
176
+ )
177
+ self.parser.add_argument(
178
+ "--optimizer.eps",
179
+ type=float,
180
+ default=1e-8,
181
+ help="Epsilon value for the optimizer.",
182
+ )
183
+ self.parser.add_argument(
184
+ "--optimizer.lr", type=float, default=8e-4, help="Learning rate to use"
185
+ )
186
+ self.parser.add_argument(
187
+ "--optimizer.implementation",
188
+ type=str,
189
+ default="fused",
190
+ choices=["for-loop", "foreach", "fused"],
191
+ help="""
192
+ Specify which optimizer implementation to use:
193
+ - 'fused': Use fused implementation (CUDA only) for best performance.
194
+ - 'foreach': Use some horizontal fusion of tensors for better performance.
195
+ - 'for-loop': Use the default implementation for the optimizer (slowest).
196
+ - more info: https://pytorch.org/docs/stable/optim.html
197
+ """,
198
+ )
199
+ self.parser.add_argument(
200
+ "--optimizer.early_step_in_backward",
201
+ action="store_true",
202
+ help="""
203
+ Whether to apply optimizer in the backward. Caution, optimizer_in_backward
204
+ is not compatible with gradients clipping, users should not call
205
+ register_post_accumulate_grad_hook after the optimizer is built.""",
206
+ )
207
+
208
+ # lr scheduler configs
209
+ self.parser.add_argument(
210
+ "--lr_scheduler.warmup_steps",
211
+ type=int,
212
+ default=200,
213
+ help="Steps for lr scheduler warmup, normally 1/5 of --training.steps",
214
+ )
215
+ self.parser.add_argument(
216
+ "--lr_scheduler.decay_ratio",
217
+ type=float,
218
+ default=None,
219
+ help="""
220
+ Controls the proportion of the training steps allocated to the learning rate decay phase.
221
+
222
+ If `None`, the learning rate will begin decaying immediately after the warmup period.
223
+ Otherwise, the learning rate will remain stable after the warmup period and
224
+ only start decaying during the last `decay_ratio` portion of the total training steps.
225
+
226
+ This is known as the Warmup-Stable-Decay (WSD) schedule, as described in https://arxiv.org/abs/2404.06395.
227
+ """,
228
+ )
229
+ self.parser.add_argument(
230
+ "--lr_scheduler.decay_type",
231
+ type=str,
232
+ default="linear",
233
+ choices=["linear", "sqrt", "cosine"],
234
+ help="""
235
+ Learning rate decay type to use during training:
236
+ - 'linear': linearly decays learning rate from initial to final value
237
+ - 'sqrt': decays learning rate following a 1 minus square root curve
238
+ - 'cosine': smoothly decays learning rate following a cosine curve
239
+ """,
240
+ )
241
+ self.parser.add_argument(
242
+ "--lr_scheduler.lr_min",
243
+ type=float,
244
+ default=0.0,
245
+ help="""
246
+ Min lr ratio for lr scheduler.
247
+
248
+ If provided, the range of decay factor is scaled from 1 to `lr_min`
249
+ to ensure the learning rate does not drop below `optimizer.lr * lr_scheduler.lr_min`.
250
+ """,
251
+ )
252
+
253
+ # training configs
254
+ self.parser.add_argument(
255
+ "--training.batch_size", type=int, default=8, help="Batch size"
256
+ )
257
+ self.parser.add_argument(
258
+ "--training.seq_len", type=int, default=2048, help="Sequence length"
259
+ )
260
+ self.parser.add_argument(
261
+ "--training.context_len",
262
+ type=int,
263
+ default=2048,
264
+ help="Max length allowed for each sequence",
265
+ )
266
+ self.parser.add_argument(
267
+ "--training.varlen",
268
+ action="store_true",
269
+ help="Whether to take sequences of variable length as input",
270
+ )
271
+ self.parser.add_argument(
272
+ "--training.gradient_accumulation_steps",
273
+ type=int,
274
+ default=1,
275
+ help="Number of steps to accumulate gradients before updating parameters",
276
+ )
277
+ self.parser.add_argument(
278
+ "--training.steps",
279
+ type=int,
280
+ default=10000,
281
+ help="How many train steps to run",
282
+ )
283
+ self.parser.add_argument(
284
+ "--training.max_norm",
285
+ type=float,
286
+ default=1.0,
287
+ help="Max norm for gradient clipping",
288
+ )
289
+ self.parser.add_argument(
290
+ "--training.skip_nan_inf",
291
+ action="store_true",
292
+ help="Skip batch updates when NaN or INF gradients are encountered during training",
293
+ )
294
+ self.parser.add_argument(
295
+ "--training.dataset",
296
+ default="HuggingFaceFW/fineweb-edu",
297
+ help="Dataset to use, with comma separated values",
298
+ )
299
+ self.parser.add_argument(
300
+ "--training.dataset_name",
301
+ default=None,
302
+ help="The name of the dataset config, with comma separated values if provided",
303
+ )
304
+ self.parser.add_argument(
305
+ "--training.dataset_split",
306
+ default=None,
307
+ help="Dataset split to use, with comma separated values if provided",
308
+ )
309
+ self.parser.add_argument(
310
+ "--training.data_dir",
311
+ default=None,
312
+ help="Data dirs to use, with comma separated values if provided",
313
+ )
314
+ self.parser.add_argument(
315
+ "--training.data_files",
316
+ default=None,
317
+ help="Data files to use, with comma separated values if provided",
318
+ )
319
+ self.parser.add_argument(
320
+ "--training.data_probs",
321
+ default=None,
322
+ help="Data sampling probabilities, with comma separated values if provided",
323
+ )
324
+ self.parser.add_argument(
325
+ "--training.streaming",
326
+ action="store_true",
327
+ help="Whether to load dataset in streaming mode, used for huge dataset",
328
+ )
329
+ self.parser.add_argument(
330
+ "--training.num_workers",
331
+ type=int,
332
+ default=32,
333
+ help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
334
+ )
335
+ self.parser.add_argument(
336
+ "--training.prefetch_factor",
337
+ type=int,
338
+ default=2,
339
+ help="Number of batches loaded in advance by each worker. "
340
+ "2 means there will be a total of 2 * num_workers batches prefetched across all workers.",
341
+ )
342
+ self.parser.add_argument(
343
+ "--training.data_parallel_replicate_degree",
344
+ type=int,
345
+ default=1,
346
+ help="""
347
+ The `data_parallel_replicate_degree` argument specifies the degree of
348
+ data parallelism for weight replication. When this value is greater
349
+ than 1, weights will be replicated across `data_parallel_replicate_degree`
350
+ ranks. If `data_parallel_shard_degree` is also greater than 1, the parallelism
351
+ method used is HSDP (Hybrid Sharded Data Parallelism). Otherwise, the
352
+ parallelism method used is DDP (Distributed Data Parallelism).
353
+ 1 means disabled.""",
354
+ )
355
+ self.parser.add_argument(
356
+ "--training.data_parallel_shard_degree",
357
+ type=int,
358
+ default=-1,
359
+ help="""
360
+ The `data_parallel_shard_degree` argument specifies the degree of data
361
+ parallelism for weight sharding. When this value is greater than 1, weights
362
+ will be sharded across `data_parallel_shard_degree` ranks. If
363
+ `data_parallel_replicate_degree` is also greater than 1, the parallelism
364
+ method used is HSDP (Hybrid Sharded Data Parallelism). Otherwise, the
365
+ parallelism method used is FSDP (Fully Sharded Data Parallelism).
366
+
367
+ -1 means leftover ranks will be used (After DP_REPLICATE/SP/PP). Note that
368
+ only `data_parallel_shard_degree` can be negative. 1 means disabled.""",
369
+ )
370
+ self.parser.add_argument(
371
+ "--training.enable_cpu_offload",
372
+ action="store_true",
373
+ help="""
374
+ Whether to apply CPU offloading of parameters, gradients, and optimizer states in FSDP""",
375
+ )
376
+ self.parser.add_argument(
377
+ "--training.tensor_parallel_degree",
378
+ type=int,
379
+ default=1,
380
+ help="Tensor Parallelism degree. 1 means disabled.",
381
+ )
382
+ self.parser.add_argument(
383
+ "--training.disable_loss_parallel",
384
+ action="store_true",
385
+ help="Whether to disable loss parallel when sequence parallel is enabled",
386
+ )
387
+ self.parser.add_argument(
388
+ "--training.fsdp_reshard_after_forward",
389
+ type=str,
390
+ default="default",
391
+ choices=["default", "always", "never"],
392
+ help="""
393
+ `reshard_after_forward` specifies the policy for applying `reshard_after_forward`
394
+ within an FSDP setup. `reshard_after_forward` controls parameter behavior after forward,
395
+ trading off memory and communication. See torch's `fully_shard` API for more documentation
396
+ on `reshard_after_forward`.
397
+ The supported policies include "default", "always" and "never":
398
+ - "default" applies default resharding behavior, implementing "smart defaults" for known optimal
399
+ scenarios.
400
+ - "always" will enable `reshard_after_forward` for all forward passes.
401
+ - "never" will disable `reshard_after_forward` for all forward passes.
402
+ """,
403
+ )
404
+ self.parser.add_argument(
405
+ "--training.mixed_precision_param",
406
+ type=str,
407
+ default="bfloat16",
408
+ choices=["bfloat16", "float32"],
409
+ help="""
410
+ torch dtype to use for parameters when applying mixed precision via FSDP.
411
+ This feature only takes effect when data_parallel_shard_degree > 1
412
+ """,
413
+ )
414
+ self.parser.add_argument(
415
+ "--training.mixed_precision_reduce",
416
+ type=str,
417
+ default="float32",
418
+ choices=["float32"],
419
+ help="""
420
+ torch dtype to use for reductions when applying mixed precision via FSDP.
421
+ This feature only takes effect when data_parallel_shard_degree > 1
422
+ """,
423
+ )
424
+ self.parser.add_argument(
425
+ "--training.compile",
426
+ action="store_true",
427
+ help="Whether to compile the model",
428
+ )
429
+ self.parser.add_argument(
430
+ "--training.gc_freq",
431
+ type=int,
432
+ default=50,
433
+ help="Python garbage collection scheduling interval, in steps",
434
+ )
435
+ self.parser.add_argument(
436
+ "--training.seed",
437
+ type=int,
438
+ default=42,
439
+ help="Choose the base RNG seed used for training",
440
+ )
441
+ self.parser.add_argument(
442
+ "--training.deterministic",
443
+ action="store_true",
444
+ help="Use deterministic algorithms wherever possible, may be slower",
445
+ )
446
+ # metrics configs
447
+ self.parser.add_argument(
448
+ "--metrics.log_freq",
449
+ type=int,
450
+ default=10,
451
+ help="How often to log metrics to TensorBoard, in iterations",
452
+ )
453
+ self.parser.add_argument(
454
+ "--metrics.enable_tensorboard",
455
+ action="store_true",
456
+ help="Whether to log metrics to TensorBoard",
457
+ )
458
+ self.parser.add_argument(
459
+ "--metrics.disable_color_printing",
460
+ action="store_true",
461
+ help="Whether to disable color printing in logs",
462
+ )
463
+ self.parser.add_argument(
464
+ "--metrics.save_tb_folder",
465
+ type=str,
466
+ default="tb",
467
+ help="Folder to dump TensorBoard states",
468
+ )
469
+ self.parser.add_argument(
470
+ "--metrics.save_for_all_ranks",
471
+ action="store_true",
472
+ default=False,
473
+ help="""
474
+ Whether to save TensorBoard/Wandb metrics only for rank 0 or for all ranks.
475
+ When this option is False and pipeline_parallel_degree is > 1, the metrics
476
+ component uses the 0th rank of the last stage pipeline group, which is the
477
+ only stage that computes loss metrics.
478
+ """,
479
+ )
480
+ self.parser.add_argument(
481
+ "--metrics.enable_wandb",
482
+ action="store_true",
483
+ help="Whether to log metrics to Weights & Biases",
484
+ )
485
+
486
+ self.parser.add_argument(
487
+ "--experimental.enable_async_tensor_parallel",
488
+ action="store_true",
489
+ help="Whether to apply async tensor parallel (currently only effective when compile is enabled)",
490
+ )
491
+ self.parser.add_argument(
492
+ "--experimental.pipeline_parallel_degree",
493
+ type=int,
494
+ default=1,
495
+ help="""
496
+ Pipeline Parallelism degree, or number of ranks. 1 means disabled.
497
+ If using looped schedules, this still specifies the number of physical ranks, not the number
498
+ of stages. Stages per rank are inferred from the split points, degree, and schedule.""",
499
+ )
500
+ self.parser.add_argument(
501
+ "--experimental.pipeline_parallel_split_points",
502
+ type=string_list,
503
+ nargs="+",
504
+ default=[],
505
+ help="""
506
+ Specify comma-separated names of modules to use as the beginning of a split point.
507
+
508
+ e.g. "layers.0,layers.2" will cause the model to be split into 3 stages,
509
+ the first containing all the layers up to layers.0,
510
+ the second containing layers.0 and up to layers.2,
511
+ the third containing layers.2 and all the remaining layers.
512
+
513
+ Note: fully-automated splitting may be enabled in the future,
514
+ but currently the split points must be specified manually.""",
515
+ )
516
+ self.parser.add_argument(
517
+ "--experimental.pipeline_parallel_schedule",
518
+ type=str,
519
+ default="1F1B",
520
+ help="""
521
+ Specify the Pipeline Parallel schedule to use. The supported schedules are:
522
+ https://github.com/pytorch/pytorch/blob/de4c2a3b4e89d96334dc678d1c3f2ae51a6630a0/torch/distributed/pipelining/schedules.py#L2161.
523
+ The schedule must be compatible with the split points and stages_per_rank.
524
+
525
+ Looped schedules (e.g. Interleaved1F1B) require specifying pipeline_parallel_degree = number of ranks,
526
+ and split_points = number of stages - 1
527
+ """,
528
+ )
529
+ self.parser.add_argument(
530
+ "--experimental.pipeline_parallel_schedule_csv",
531
+ type=str,
532
+ default="",
533
+ help="""
534
+ Specify the path to the pipeline parallel schedule csv file to use.
535
+ The pipeline_parallel_schedule argument must be either
536
+ PipelineScheduleSingle, PipelineScheduleMulti, or _PipelineScheduleRuntime.
537
+ """,
538
+ )
539
+
540
+ self.parser.add_argument(
541
+ "--experimental.pipeline_parallel_microbatches",
542
+ type=int,
543
+ default=None,
544
+ help="""
545
+ How many microbatches to split the global training batch into when using pipeline parallelism.
546
+
547
+ The global training batch size must be evenly divisible by the number of microbatches.
548
+
549
+ The default value will be the number of pipeline stages, if unspecified.
550
+ """,
551
+ )
552
+ self.parser.add_argument(
553
+ "--experimental.enable_compiled_autograd",
554
+ action="store_true",
555
+ help="Enable CompiledAutograd to compile the backward.",
556
+ )
557
+ self.parser.add_argument(
558
+ "--experimental.context_parallel_degree",
559
+ type=int,
560
+ default=1,
561
+ help="Context parallelism degree. 1 means disabled.",
562
+ )
563
+ self.parser.add_argument(
564
+ "--experimental.context_parallel_rotate_method",
565
+ type=str,
566
+ default="allgather",
567
+ help="""
568
+ The collective to use in context parallel SDPA for kv shards exchange.
569
+
570
+ 'allgather' means to all-gather all kv shards on ranks after the first sub-SDPA computation,
571
+
572
+ 'alltoall' means to all-to-all shuffle the kv shards.
573
+
574
+ The default value is 'allgather'.
575
+ """,
576
+ )
577
+ # I'm not particularly fond of this. Users can choose to write their own wrapper
578
+ # module and import the TorchTitan training loop and execute it, which looks cleaner.
579
+ # One reason to provide this option is to allow users to use the existing run script.
580
+ # While the script is pretty trivial now, we may add more logic when integrating
581
+ # with TorchFT.
582
+ # This option is subject to change and may be deleted in the future.
583
+ self.parser.add_argument(
584
+ "--experimental.custom_model_path",
585
+ type=str,
586
+ default="",
587
+ help="""
588
+ The --custom_model_path option allows specifying a custom path to a model module
589
+ that is not natively implemented within TorchTitan.
590
+ Acceptable values are the file system path to the module (e.g., my_models/model_x) or the
591
+ dotted import module (e.g., some_package.model_x).
592
+ """,
593
+ )
594
+ # checkpointing configs
595
+ self.parser.add_argument(
596
+ "--checkpoint.enable_checkpoint",
597
+ action="store_true",
598
+ help="Whether to enable checkpointing",
599
+ )
600
+ self.parser.add_argument(
601
+ "--checkpoint.folder",
602
+ type=str,
603
+ default="checkpoint",
604
+ help="""
605
+ The folder to store the checkpoints.
606
+ When enable_checkpoint is set to true, checkpoints will be in {--job.dump_folder}/{--checkpoint.folder}.
607
+ """,
608
+ )
609
+ self.parser.add_argument(
610
+ "--checkpoint.interval",
611
+ type=int,
612
+ default=500,
613
+ help="Checkpointing interval in steps.",
614
+ )
615
+ self.parser.add_argument(
616
+ "--checkpoint.model_weights_only",
617
+ action="store_true",
618
+ help="""
619
+ When model_weights_only=True, only model weights will be saved at the end of training.
620
+ With this, checkpoints can be loaded using `torch.load(..., weights_only=True)` after conversion.
621
+ When model_weights_only=False, the full checkpoint will be saved.
622
+ A full checkpoint includes model, optimizer and train_state, which can be used to resume training.
623
+ The default value is false.
624
+ """,
625
+ )
626
+ self.parser.add_argument(
627
+ "--checkpoint.export_dtype",
628
+ type=str,
629
+ default="float32",
630
+ choices=["float16", "bfloat16", "float32"],
631
+ help="""
632
+ Converts to the specified precision when training completes and model_weights_only=true.
633
+ Currently supports float32, float16, and bfloat16.
634
+ The default value is float32.
635
+ """,
636
+ )
637
+ self.parser.add_argument(
638
+ "--checkpoint.create_seed_checkpoint",
639
+ action="store_true",
640
+ help="""
641
+ Initializes the full model without applying parallelisms, and then saves it as a seed checkpoint.
642
+ Note: requires user to call train.py without specifying any parallelisms, e.g. NGPU=1.
643
+ Could be implemented as a separate script, but this way shares more code.
644
+ """,
645
+ )
646
+ self.parser.add_argument(
647
+ "--checkpoint.async_mode",
648
+ type=str,
649
+ default="disabled",
650
+ help="""
651
+ Which async checkpoint mode to use. Currently there are 3 different modes.
652
+ 1. "disabled": synchronized checkpointing will be used.
653
+ 2. "async": torch.distributed.checkpoint.async_save will be used.
654
+ 3. "async_with_pinned_mem": this option utilizes a dedicated pinned memory
655
+ space and creates a separate process for faster GPU->CPU transfer
656
+ performance and eliminating GIL contention. The cost is increased CPU
657
+ memory usage. If insufficient CPU memory is available, performance may
658
+ degrade due to memory paging. For most users, "async" should suffice as
659
+ the performance overhead is typically small (on the order of tens of
660
+ seconds) compared to checkpointing frequency. This mode can be employed
661
+ to pursue near-zero checkpointing times (e.g., < 1 second) given
662
+ appropriate hardware support such as ample CPU memory and fast PCIe.
663
+
664
+ "disabled" is the default mode.
665
+ """,
666
+ )
667
+ self.parser.add_argument(
668
+ "--checkpoint.keep_latest_k",
669
+ type=int,
670
+ default=0,
671
+ help="""
672
+ Keeps only the latest k checkpoints, purging older ones. If 0, keep all checkpoints.
673
+ 0 is the default value. k cannot be 1 as the last one may be in the process of being
674
+ saved. As a result, the metadata of the last one may not be ready yet.
675
+ """,
676
+ )
677
+ self.parser.add_argument(
678
+ "--checkpoint.load_step",
679
+ type=int,
680
+ default=-1,
681
+ help="Load the checkpoint at the specified step. If -1, load the latest checkpoint.",
682
+ )
683
+ self.parser.add_argument(
684
+ "--checkpoint.exclude_from_loading",
685
+ type=string_list,
686
+ nargs="*",
687
+ default=[],
688
+ help="""
689
+ Exclude specific keys from being loaded from the checkpoint.
690
+ Provide a comma-separated list of keys to exclude, e.g. 'optimizer,lr_scheduler,dataloader'.
691
+ This will load the model only, excluding the specified keys.
692
+ """,
693
+ )
694
+ self.parser.add_argument(
695
+ "--checkpoint.convert_to_hf_on_save",
696
+ action="store_true",
697
+ help="""
698
+ If true, automatically convert the saved DCP checkpoint to Hugging Face format
699
+ in a parallel directory (e.g., step-1000-hf) after each save.
700
+ """,
701
+ )
702
+ self.parser.add_argument(
703
+ "--checkpoint.hf_upload_enabled",
704
+ action="store_true",
705
+ help="Enable uploading converted Hugging Face checkpoints to the Hub.",
706
+ )
707
+ self.parser.add_argument(
708
+ "--checkpoint.hf_repo_base_name",
709
+ type=str,
710
+ default=None,
711
+ help="Hugging Face Hub repository ID to upload checkpoints to (e.g., 'username/repo').",
712
+ )
713
+ self.parser.add_argument(
714
+ "--checkpoint.hf_upload_format",
715
+ type=str,
716
+ default="dcp",
717
+ choices=["dcp", "hf"],
718
+ help="""
719
+ Format to upload to Hugging Face Hub. 'dcp' for DCP format, 'hf' for Hugging Face format.
720
+ Note: 'hf' is only supported for models with a single pipeline stage.
721
+ """,
722
+ )
723
+ # activation checkpointing configs
724
+ self.parser.add_argument(
725
+ "--activation_checkpoint.mode",
726
+ type=str,
727
+ default="selective",
728
+ help="Type of activation checkpointing to use ['none', 'full', 'selective']",
729
+ )
730
+ self.parser.add_argument(
731
+ "--activation_checkpoint.selective_ac_option",
732
+ type=str,
733
+ default="2", # 2 = checkpoint every other layer
734
+ help="""
735
+ Selective activation checkpointing options ['int', 'op'].
736
+ 'int' (e.g., 2) for every nth layer, or 'op' for op level ac.
737
+ """,
738
+ )
739
+
740
+ self.parser.add_argument(
741
+ "--activation_offload.mode",
742
+ type=str,
743
+ default="none",
744
+ help="""
745
+ Whether to use activation offloading. Options are ['none', 'full'].
746
+ """,
747
+ )
748
+
749
+ # float8 configs
750
+ self.parser.add_argument(
751
+ "--float8.enable_fsdp_float8_all_gather",
752
+ action="store_true",
753
+ help="Whether to enable float8 all-gather in FSDP, recommended for tensorwise scaling",
754
+ )
755
+ self.parser.add_argument(
756
+ "--float8.precompute_float8_dynamic_scale_for_fsdp",
757
+ action="store_true",
758
+ help="Whether to precompute float8 scales dynamically for FSDP, recommended for tensorwise scaling",
759
+ )
760
+ self.parser.add_argument(
761
+ "--float8.force_recompute_fp8_weight_in_bwd",
762
+ action="store_true",
763
+ help="""
764
+ Whether to force the recomputation of FP8 weights during backward pass.
765
+ When using FSDP with tensorwise scaling, it is recommended to enable
766
+ `force_recompute_fp8_weight_in_bwd` to prevent saving unsharded FP8 weights
767
+ for backward computation.
768
+ """,
769
+ )
770
+ self.parser.add_argument(
771
+ "--float8.recipe_name",
772
+ type=str,
773
+ default=None,
774
+ choices=["tensorwise", "rowwise", "rowwise_with_gw_hp"],
775
+ help="""
776
+ If specified, creates float8 config from recipe name, valid choices are
777
+ `tensorwise`, `rowwise` and `rowwise_with_gw_hp`.
778
+ """,
779
+ )
780
+
781
+ # communications library settings
782
+ self.parser.add_argument(
783
+ "--comm.init_timeout_seconds",
784
+ type=int,
785
+ default=300,
786
+ help="Timeout for communication operations, during initialization and first train step.",
787
+ )
788
+ self.parser.add_argument(
789
+ "--comm.train_timeout_seconds",
790
+ type=int,
791
+ default=100,
792
+ help=(
793
+ "Timeout for communication operations after the first train step -- "
794
+ "usually a tighter bound than during initialization."
795
+ ),
796
+ )
797
+ self.parser.add_argument(
798
+ "--comm.trace_buf_size",
799
+ type=int,
800
+ default=20000,
801
+ help="Flight recorder ring buffer size, >0 means recording by default, 0 means disabled",
802
+ )
803
+
804
+ # memory estimation settings
805
+ self.parser.add_argument(
806
+ "--memory_estimation.enabled",
807
+ help="Whether to estimate memory usage for FSDP",
808
+ action="store_true",
809
+ )
810
+
811
+ self.parser.add_argument(
812
+ "--memory_estimation.disable_fake_mode",
813
+ help="Whether to disable FakeTensorMode when estimating memory usage",
814
+ action="store_true",
815
+ )
816
+
817
+ self.parser.add_argument(
818
+ "--fault_tolerance.enable",
819
+ action="store_true",
820
+ help="""
821
+ Enable TorchFT integration. When TorchFT is enabled, HSDP will be used.
822
+ And --training.data_parallel_replicate_degree should be 1 and
823
+ --fault_tolerance.group_size will be used to control the maximum
824
+ replicate group size as the replicate group size is dynamic.
825
+
826
+ Note that this is still an experimental feature.
827
+ """,
828
+ )
829
+
830
+ self.parser.add_argument(
831
+ "--fault_tolerance.replica_id",
832
+ type=int,
833
+ default=0,
834
+ help="The TorchFT replica ID of this run.",
835
+ )
836
+
837
+ self.parser.add_argument(
838
+ "--fault_tolerance.group_size",
839
+ type=int,
840
+ default=0,
841
+ help="""
842
+ The number of TorchFT replicate groups. This number will be used for
843
+ dataloader to split the dataset across the replicate groups and FSDP
844
+ dimension
845
+ """,
846
+ )
847
+
848
+ self.parser.add_argument(
849
+ "--fault_tolerance.min_replica_size",
850
+ type=int,
851
+ default=1,
852
+ help="The minimum number of FT replicas for each step.",
853
+ )
854
+
855
+ def to_dict(self):
856
+ return self.args_dict
857
+
858
+ def parse_args(self, args_list: list = sys.argv[1:]):
859
+ args, cmd_args = self.parse_args_from_command_line(args_list)
860
+ config_file = getattr(args, "job.config_file", None)
861
+ # build up a two level dict
862
+ args_dict = self._args_to_two_level_dict(args)
863
+ if config_file is not None:
864
+ try:
865
+ with open(config_file, "rb") as f:
866
+ for k, v in tomllib.load(f).items():
867
+ # to prevent overwrite of non-specified keys
868
+ args_dict[k] |= v
869
+ except (FileNotFoundError, tomllib.TOMLDecodeError) as e:
870
+ logger.exception(
871
+ f"Error while loading the configuration file: {config_file}"
872
+ )
873
+ logger.exception(f"Error details: {str(e)}")
874
+ raise e
875
+
876
+ # Checking string-list arguments are properly split into a list
877
+ # if split-points came from 'args' (from cmd line) it would have already been parsed into a list by that parser
878
+ string_list_argnames = self._get_string_list_argument_names()
879
+ for n in string_list_argnames:
880
+ check_string_list_argument(args_dict, n)
881
+
882
+ # override args dict with cmd_args
883
+ cmd_args_dict = self._args_to_two_level_dict(cmd_args)
884
+ for section, section_args in cmd_args_dict.items():
885
+ for k, v in section_args.items():
886
+ args_dict[section][k] = v
887
+
888
+ self.args_dict = args_dict
889
+
890
+ for k, v in args_dict.items():
891
+ class_type = type(k.title(), (), v)
892
+ setattr(self, k, class_type())
893
+ self._validate_config()
894
+
895
+ def _args_to_two_level_dict(self, args: argparse.Namespace) -> defaultdict:
896
+ args_dict = defaultdict(defaultdict)
897
+ for k, v in vars(args).items():
898
+ first_level_key, second_level_key = k.split(".", 1)
899
+ args_dict[first_level_key][second_level_key] = v
900
+ return args_dict
901
+
902
+ def _validate_config(self) -> None:
903
+ # TODO: Add more mandatory validations
904
+ assert self.model.config
905
+ assert self.model.tokenizer_path
906
+
907
+ def _get_string_list_argument_names(self) -> list[str]:
908
+ """Get the parser argument names of type `string_list`."""
909
+ string_list_args = [
910
+ v.dest for v in self.parser._actions if v.type is string_list
911
+ ]
912
+ return string_list_args
913
+
914
+ def parse_args_from_command_line(
915
+ self, args_list
916
+ ) -> Tuple[argparse.Namespace, argparse.Namespace]:
917
+ """
918
+ Parse command line arguments and return the parsed args and the command line only args
919
+ """
920
+ args = self.parser.parse_args(args_list)
921
+ string_list_argnames = set(self._get_string_list_argument_names())
922
+
923
+ # aux parser to parse the command line only args, with no defaults from main parser
924
+ aux_parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS)
925
+ for arg, val in vars(args).items():
926
+ if isinstance(val, bool):
927
+ aux_parser.add_argument(
928
+ "--" + arg, action="store_true" if val else "store_false"
929
+ )
930
+ elif arg in string_list_argnames:
931
+ # without this special case, type inference breaks here,
932
+ # since the inferred type is just 'list' and it ends up flattening
933
+ # e.g. from ["layers.0", "layers.1"] into ["l", "a", "y", "e", "r", "s", ".0", ...]
934
+ aux_parser.add_argument("--" + arg, type=string_list)
935
+ else:
936
+ aux_parser.add_argument("--" + arg, type=type(val))
937
+
938
+ cmd_args, _ = aux_parser.parse_known_args(args_list)
939
+
940
+ return args, cmd_args
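The precedence applied by `parse_args` above (parser defaults, then config-file values, then flags given explicitly on the command line) depends on a second parser built with `argument_default=argparse.SUPPRESS`, so that only flags actually present in `args_list` appear in its namespace. Below is a minimal standalone sketch of that idea. It assumes a made-up two-option parser and a plain dict standing in for the parsed TOML file; `build_parser`, `to_two_level` and `resolve` are hypothetical helpers for illustration, not functions from this commit.

```python
import argparse
from collections import defaultdict


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical two-option parser; the real one defines many more flags.
    parser = argparse.ArgumentParser()
    parser.add_argument("--training.steps", type=int, default=10000)
    parser.add_argument("--training.compile", action="store_true")
    return parser


def to_two_level(ns: argparse.Namespace) -> defaultdict:
    # "training.steps" -> {"training": {"steps": ...}}
    out = defaultdict(dict)
    for key, value in vars(ns).items():
        section, name = key.split(".", 1)
        out[section][name] = value
    return out


def resolve(args_list: list, file_values: dict) -> dict:
    args = build_parser().parse_args(args_list)

    # Auxiliary parser: SUPPRESS omits every flag that was not typed explicitly.
    aux = argparse.ArgumentParser(argument_default=argparse.SUPPRESS)
    for name, value in vars(args).items():
        if isinstance(value, bool):
            action = "store_true" if value else "store_false"
            aux.add_argument("--" + name, action=action)
        else:
            aux.add_argument("--" + name, type=type(value))
    cmd_args, _ = aux.parse_known_args(args_list)

    merged = to_two_level(args)  # 1. parser defaults
    for section, values in file_values.items():  # 2. config-file values
        merged[section].update(values)
    for section, values in to_two_level(cmd_args).items():  # 3. explicit flags
        merged[section].update(values)
    return {section: dict(values) for section, values in merged.items()}
```

Calling `resolve(["--training.steps", "20"], {"training": {"steps": 500}})` keeps the command-line value, while `resolve([], {"training": {"steps": 500}})` falls back to the file value; a flag mentioned in neither place keeps its parser default.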
flame/data.py ADDED
@@ -0,0 +1,570 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import copy
6
+ import pickle
7
+ from copy import deepcopy
8
+ from dataclasses import dataclass
9
+ from typing import Any, Callable, Dict, Iterable, List, Optional, Union
10
+
11
+ import datasets
12
+ import numpy as np
13
+ import torch
14
+ from datasets import Dataset, IterableDataset
15
+ from datasets.iterable_dataset import ShufflingConfig
16
+ from torch.distributed.checkpoint.stateful import Stateful
17
+ from torchdata.stateful_dataloader import StatefulDataLoader
18
+ from transformers import PreTrainedTokenizer
19
+
20
+ from torchtitan.tools.logging import logger
21
+
22
+
23
+ class BufferShuffledIterableDataset(IterableDataset):
24
+ def __init__(
25
+ self,
26
+ dataset: Dataset,
27
+ tokenizer: PreTrainedTokenizer,
28
+ seq_len: int = 2048,
29
+ rank: int = 0,
30
+ world_size: int = 1,
31
+ buffer_size: int = 1024,
32
+ ) -> None:
33
+ self.dataset = dataset
34
+ self.tokenizer = tokenizer
35
+
36
+ self.data = dataset.shard(world_size, rank)
37
+ self.seq_len = seq_len
38
+
39
+ self.rank = rank
40
+ self.world_size = world_size
41
+ self.buffer_size = buffer_size
42
+
43
+ if tokenizer.vocab_size < torch.iinfo(torch.int16).max:
44
+ self.dtype = torch.int16
45
+ elif tokenizer.vocab_size < torch.iinfo(torch.int32).max:
46
+ self.dtype = torch.int32
47
+ else:
48
+ self.dtype = torch.int64
49
+ self.states = None
50
+ self.buffer = torch.tensor([], dtype=self.dtype)
51
+ self.tokens = []
52
+ self.rand_id = 0
53
+ self.token_id = 0
54
+ self.rng_state = None
55
+ self._epoch = 0
56
+
57
+ def __iter__(self):
58
+ g = torch.Generator()
59
+ g.manual_seed(self._epoch + self.rank)
60
+ if self.rng_state is not None:
61
+ g.set_state(self.rng_state)
62
+
63
+ rand_it = self.randint(0, self.buffer_size, g=g)
64
+ if self.states is not None:
65
+ self.data.load_state_dict(self.states)
66
+
67
+ # max number of tokens allowed in the chunk buffer
68
+ n_tokens = self.buffer_size * self.seq_len
69
+
70
+ while True:
71
+ for sample in self.tokenize(self.data):
72
+ # keep appending the samples to the token buffer
73
+ self.tokens += sample
74
+ # if the token buffer is full, start sampling
75
+ # NOTE: we first convert the token ids to a tensor of shape [n_chunks, seq_len] for efficiency
76
+ if len(self.buffer) == 0 and len(self.tokens) >= n_tokens:
77
+ self.buffer = torch.tensor(self.tokens[:n_tokens], dtype=self.dtype).view(self.buffer_size, -1)
78
+ self.tokens = self.tokens[n_tokens:]
79
+ if len(self.buffer) == self.buffer_size:
80
+ yield from self.sample(rand_it)
81
+
82
+ n_chunks = len(self.tokens) // self.seq_len
83
+ # handle the leftover tokens in the buffer
84
+ if n_chunks > 0:
85
+ n_tokens = n_chunks * self.seq_len
86
+ indices = torch.randperm(n_chunks, generator=g).tolist()
87
+ self.buffer = torch.tensor(self.tokens[:n_tokens], dtype=torch.long).view(n_chunks, -1)
88
+ self.tokens = self.tokens[n_tokens:]
89
+ for i in indices:
90
+ yield {'input_ids': self.buffer[i]}
91
+
92
+ def tokenize(self, data, batch_size: int = 64):
93
+ texts, states = [], []
94
+ for sample in data:
95
+ texts.append(sample['text'])
96
+ states.append(self.data.state_dict())
97
+ if len(texts) == batch_size:
98
+ for s, tokenized in zip(states, self.tokenizer(texts, return_attention_mask=False)['input_ids']):
99
+ self.states = s
100
+ yield tokenized
101
+ texts, states = [], []
102
+ if len(texts) > 0:
103
+ for s, tokenized in zip(states, self.tokenizer(texts, return_attention_mask=False)['input_ids']):
104
+ self.states = s
105
+ yield tokenized
106
+
107
+ def sample(self, indices):
108
+ n_tokens = (len(self.tokens) // self.seq_len) * self.seq_len
109
+ while self.token_id < n_tokens:
110
+ i = next(indices)
111
+ start, end = self.token_id, self.token_id + self.seq_len
112
+ self.token_id += self.seq_len
113
+ yield {'input_ids': self.buffer[i].to(torch.long)}
114
+ self.buffer[i] = torch.tensor(self.tokens[start:end], dtype=self.dtype)
115
+ self.token_id = 0
116
+ self.tokens = self.tokens[n_tokens:]
117
+
118
+ def randint(self, low: int, high: int, buffer_size: int = 1024, g: torch.Generator = torch.Generator()) -> Iterable[int]:
119
+ indices = torch.empty(buffer_size, dtype=torch.long)
120
+ while True:
121
+ # record the generator states before sampling
122
+ self.rng_state = g.get_state()
123
+ indices = torch.randint(low, high, (buffer_size,), out=indices, generator=g)
124
+ for i in indices[self.rand_id:].tolist():
125
+ self.rand_id += 1
126
+ yield i
127
+ self.rand_id = 0
128
+
129
+ def set_epoch(self, epoch):
130
+ self._epoch = epoch
131
+ if hasattr(self.dataset, 'set_epoch'):
132
+ self.dataset.set_epoch(epoch)
133
+
134
+ def state_dict(self):
135
+ return {
136
+ 'states': self.states,
137
+ 'buffer': self.buffer.clone(),
138
+ 'tokens': deepcopy(self.tokens),
139
+ 'rand_id': self.rand_id,
140
+ 'token_id': self.token_id,
141
+ 'rng_state': self.rng_state,
142
+ 'epoch': self._epoch,
143
+ }
144
+
145
+ def load_state_dict(self, state_dict):
146
+ self.states = state_dict['states']
147
+ self.buffer = state_dict['buffer'].clone()
148
+ self.tokens = deepcopy(state_dict['tokens'])
149
+ self.rand_id = state_dict['rand_id']
150
+ self.token_id = state_dict['token_id']
151
+ self.rng_state = state_dict['rng_state'].clone() if state_dict['rng_state'] is not None else None
152
+ self._epoch = state_dict['epoch']
153
+
154
+
155
+ class OnlineTokenizedIterableDataset(IterableDataset):
156
+ def __init__(
157
+ self, dataset: Dataset, tokenizer: PreTrainedTokenizer, seq_len: int = 2048, rank: int = 0, world_size: int = 1
158
+ ) -> None:
159
+ self.dataset = dataset
160
+ self.tokenizer = tokenizer
161
+
162
+ self.data = dataset.shard(world_size, rank)
163
+ self.seq_len = seq_len
164
+ self.rank = rank
165
+ self.world_size = world_size
166
+
167
+ self.states = None
168
+ self.tokens = []
169
+
170
+ def __iter__(self):
171
+ if self.states is not None:
172
+ self.data.load_state_dict(self.states)
173
+
174
+ while True:
175
+ for sample in self.tokenize(self.data):
176
+ # keep appending the samples to the token buffer
177
+ self.tokens += sample
178
+
179
+ while len(self.tokens) >= self.seq_len:
180
+ input_ids = torch.tensor(self.tokens[:self.seq_len], dtype=torch.long)
181
+ self.tokens = self.tokens[self.seq_len:]
182
+ yield {'input_ids': input_ids}
183
+
184
+ def tokenize(self, data, buffer_size: int = 64):
185
+ buffer, states = [], []
186
+ for sample in data:
187
+ if sample.get('text', None) is not None:
188
+ buffer.append(sample['text'])
189
+ elif sample.get('content', None) is not None:
190
+ buffer.append(sample['content'])
191
+ else:
192
+ raise ValueError(f"No 'text' or 'content' field found in sample:\n{sample}")
193
+ states.append(self.data.state_dict())
194
+ if len(buffer) == buffer_size:
195
+ for s, tokenized in zip(states, self.tokenizer(buffer, return_attention_mask=False)['input_ids']):
196
+ self.states = s
197
+ yield tokenized
198
+ buffer, states = [], []
199
+ if len(buffer) > 0:
200
+ for s, tokenized in zip(states, self.tokenizer(buffer, return_attention_mask=False)['input_ids']):
201
+ self.states = s
202
+ yield tokenized
203
+
204
+ def state_dict(self):
205
+ return {'states': self.states, 'tokens': deepcopy(self.tokens)}
206
+
207
+ def load_state_dict(self, state_dict):
208
+ self.states = state_dict['states']
209
+ self.tokens = deepcopy(state_dict['tokens'])
210
+
211
+
212
+ class BufferShuffledExamplesIterable(datasets.iterable_dataset.BufferShuffledExamplesIterable):
213
+ def __init__(self, *args, **kwargs):
214
+ super().__init__(*args, **kwargs)
215
+
216
+ def _init_state_dict(self) -> dict:
217
+ self._state_dict = self.ex_iterable._init_state_dict()
218
+ self._state_dict['mem_buffer'] = ([],)
219
+ self._state_dict['bit_generator_state'] = self.generator.bit_generator.state
220
+ self._state_dict['bit_generator_index_offset'] = 0
221
+ self._state_dict['bit_generator_index_offset_shuffle'] = 0
222
+ return self._state_dict
223
+
224
+ def __iter__(self):
225
+ buffer_size = self.buffer_size
226
+ rng = deepcopy(self.generator)
227
+ # this is the shuffle buffer that we keep in memory
228
+ mem_buffer = self._state_dict['mem_buffer'][0]
229
+ # this is an infinite iterator that randomly samples the index of the source to pick examples from
230
+ index_offset = self._state_dict['bit_generator_index_offset'] if self._state_dict else 0
231
+ if self._state_dict:
232
+ rng.bit_generator.state = self._state_dict['bit_generator_state']
233
+ indices_iterator = self._iter_random_indices(rng, buffer_size, random_batch_size=buffer_size)
234
+ # skip already consumed ones
235
+ for _ in range(index_offset):
236
+ i = next(indices_iterator)
237
+
238
+ for x in self.ex_iterable:
239
+ if len(mem_buffer) < buffer_size: # if the buffer is not full, keep filling the buffer
240
+ mem_buffer.append(x)
241
+ else: # otherwise, pick an example from it
242
+ i = next(indices_iterator)
243
+ index_offset = (index_offset + 1) % buffer_size
244
+ if self._state_dict:
245
+ self._state_dict['bit_generator_index_offset'] = index_offset
246
+ if index_offset == 0:
247
+ self._state_dict['bit_generator_state'] = rng.bit_generator.state
248
+ selected = mem_buffer[i]
249
+ mem_buffer[i] = x # replace the picked example by a new one
250
+ yield selected
251
+
252
+ index_offset = self._state_dict['bit_generator_index_offset_shuffle'] if self._state_dict else 0
253
+ if self._state_dict:
254
+ rng.bit_generator.state = self._state_dict['bit_generator_state']
255
+
256
+ # when we run out of examples, we shuffle the remaining examples in the buffer and yield them
257
+ for i in rng.permutation(len(mem_buffer))[index_offset:].tolist():
258
+ index_offset = index_offset + 1
259
+ if self._state_dict:
260
+ self._state_dict['bit_generator_index_offset_shuffle'] = index_offset
261
+ yield mem_buffer[i]
262
+
263
+ def shuffle_data_sources(self, generator: np.random.Generator) -> BufferShuffledExamplesIterable:
264
+ """Shuffle the wrapped examples iterable as well as the shuffling buffer."""
265
+ return BufferShuffledExamplesIterable(
266
+ self.ex_iterable.shuffle_data_sources(generator), buffer_size=self.buffer_size, generator=generator
267
+ )
268
+
269
+ def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> BufferShuffledExamplesIterable:
270
+ """Keep only the requested shard."""
271
+ return BufferShuffledExamplesIterable(
272
+ self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous),
273
+ buffer_size=self.buffer_size,
274
+ generator=self.generator,
275
+ )
276
+
277
+ def load_state_dict(self, state_dict: dict) -> dict:
278
+ def _inner_load_state_dict(state, new_state):
279
+ if new_state is not None and isinstance(state, dict):
280
+ for key in new_state:
281
+ state[key] = _inner_load_state_dict(state[key], new_state[key])
282
+ return state
283
+ elif new_state is not None and isinstance(state, list):
284
+ for i in range(len(state)):
285
+ state[i] = _inner_load_state_dict(state[i], new_state[i])
286
+ return state
287
+ return new_state
288
+
289
+ return _inner_load_state_dict(self._state_dict, state_dict)
290
+
291
+
292
+ def shuffle(
293
+ dataset: IterableDataset,
294
+ seed: int = 42,
295
+ generator: np.random.Generator = None,
296
+ buffer_size: int = 1024,
297
+ ):
298
+ generator = np.random.default_rng(seed) if generator is None else deepcopy(generator)
299
+ return IterableDataset(
300
+ ex_iterable=BufferShuffledExamplesIterable(dataset._ex_iterable, buffer_size=buffer_size, generator=generator),
301
+ info=dataset._info.copy(),
302
+ split=dataset._split,
303
+ formatting=dataset._formatting,
304
+ shuffling=ShufflingConfig(generator=generator, _original_seed=seed),
305
+ distributed=copy.deepcopy(dataset._distributed),
306
+ token_per_repo_id=dataset._token_per_repo_id,
307
+ )
308
+
309
+
310
+ @dataclass
311
+ class DataCollatorForLanguageModeling:
312
+ """
313
+ Data collator used for language modeling. Inputs are dynamically padded if `varlen=False`.
314
+ If `varlen=True`, sequences are expected to be concatenated, and labels match inputs.
315
+
316
+ Args:
317
+ tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
318
+ The tokenizer used for encoding the data.
319
+ context_len (`int`, optional):
320
+ When `varlen=True`, sequences longer than this length within a document
321
+ (as determined by `cu_seqlens`) will be further chunked.
322
+ varlen (`bool`):
323
+ Whether to handle variable length concatenated sequences (`True`) or padded batches (`False`).
324
+
325
+ Returns:
326
+ A dictionary with the following keys:
327
+ - `input_ids`: Tensor of input IDs. Shape `[batch_size, seq_len]` if `varlen=False`, `[1, total_len]` if `varlen=True`.
328
+ - `labels`: Tensor of labels. Shape matches `input_ids`. Padding positions are masked with -100 if `varlen=False`.
329
+ - `attention_mask`: Tensor indicating non-padding tokens (only if `varlen=False`). Shape matches `input_ids`.
330
+ - `cu_seqlens`: Tensor of cumulative sequence lengths (only if `varlen=True`). Shape `[1, num_sequences + 1]`.
331
+
332
+ NOTE: When `varlen=True`, the `batch_size` must be 1.
333
+ """
334
+
335
+ tokenizer: PreTrainedTokenizer
336
+ context_len: Optional[int] = None
337
+ varlen: bool = False
338
+
339
+ def __call__(self, examples: List[Union[List[int], Dict[str, Any]]]) -> Dict[str, Any]:
340
+ if not isinstance(examples[0], Dict):
341
+ examples = [{'input_ids': example} for example in examples]
342
+
343
+ def tensorize(example: Dict[str, Any]) -> Dict[str, Any]:
344
+ tensorized = {}
345
+ for key in ['input_ids', 'cu_seqlens']:
346
+ if key not in example:
347
+ continue
348
+ if isinstance(example[key], List):
349
+ tensorized[key] = torch.tensor(example[key], dtype=torch.long)
350
+ elif isinstance(example[key], np.ndarray):
351
+ tensorized[key] = torch.from_numpy(example[key])
352
+ else:
353
+ tensorized[key] = example[key]
354
+ return tensorized
355
+
356
+ examples = list(map(tensorize, examples))
357
+
358
+ if not self.varlen:
359
+ # --- Handling for varlen=False (Batch Padding) ---
360
+ length_of_first = examples[0]['input_ids'].size(0)
361
+ needs_padding = not all(example['input_ids'].size(0) == length_of_first for example in examples)
362
+
363
+ if needs_padding:
364
+ # Check for pad token if padding is actually required
365
+ if self.tokenizer.pad_token_id is None:
366
+ raise ValueError(
367
+ f'You are attempting to pad samples but the tokenizer you are using '
368
+ f'({self.tokenizer.__class__.__name__}) does not have a pad token.'
369
+ )
370
+ # Pad using the tokenizer, ensuring attention_mask is returned
371
+ batch = self.tokenizer.pad(examples, return_tensors='pt', return_attention_mask=True)
372
+ else:
373
+ # No padding needed, stack directly and create a full attention mask
374
+ input_ids = torch.stack([example['input_ids'] for example in examples], dim=0)
375
+ batch = {
376
+ 'input_ids': input_ids,
377
+ # Create attention mask of all ones
378
+ 'attention_mask': torch.ones_like(input_ids),
379
+ }
380
+
381
+ # Create labels by cloning input_ids
382
+ labels = batch['input_ids'].clone()
383
+ # Mask labels only where attention_mask is 0 (padding positions)
384
+ if 'attention_mask' in batch:
385
+ labels[batch['attention_mask'] == 0] = -100
386
+ batch['labels'] = labels
387
+
388
+ else:
389
+ # --- Handling for varlen=True (Concatenated Sequences) ---
390
+ if len(examples) > 1:
391
+ raise ValueError('The batch size must be 1 for inputs with variable lengths (varlen=True).')
392
+
393
+ batch = {'input_ids': torch.cat([example['input_ids'] for example in examples], dim=0).unsqueeze(0)}
394
+
395
+ # --- cu_seqlens calculation logic remains the same ---
396
+ if 'cu_seqlens' in examples[0]:
397
+ batch['cu_seqlens'] = (
398
+ torch.cat([example['cu_seqlens'] for example in examples], dim=0).unsqueeze(0).to(dtype=torch.int32)
399
+ ) # Ensure int32
400
+ else:
401
+ # determine boundaries by bos/eos positions
402
+ # Check for bos_token_id first
403
+ if self.tokenizer.bos_token_id is not None:
404
+ cu_seqlens = []
405
+ # Handle case where the sequence doesn't start with BOS
406
+ if batch['input_ids'][0, 0] != self.tokenizer.bos_token_id:
407
+ cu_seqlens.append(torch.tensor([0], device=batch['input_ids'].device)) # Match device
408
+ # Find all BOS token positions
409
+ bos_positions = torch.where(batch['input_ids'].eq(self.tokenizer.bos_token_id))[1]
410
+ # Ensure bos_positions is on the correct device if empty
411
+ if bos_positions.numel() == 0 and len(cu_seqlens) > 0:
412
+ cu_seqlens.append(bos_positions.to(cu_seqlens[0].device))
413
+ elif bos_positions.numel() > 0:
414
+ cu_seqlens.append(bos_positions)
415
+ # Add the end of the entire batch
416
+ cu_seqlens.append(
417
+ torch.tensor([batch['input_ids'].size(1)], device=batch['input_ids'].device)
418
+ ) # Match device and use size(1)
419
+ # Filter out empty tensors before cat
420
+ cu_seqlens = [t for t in cu_seqlens if t.numel() > 0]
421
+ if not cu_seqlens: # Handle case where input is empty or has no BOS
422
+ batch['cu_seqlens'] = torch.tensor(
423
+ [0, batch['input_ids'].size(1)], dtype=torch.int32, device=batch['input_ids'].device
424
+ )
425
+ else:
426
+ batch['cu_seqlens'] = torch.cat(cu_seqlens, dim=0).to(dtype=torch.int32)
427
+
428
+ # Else, check for eos_token_id
429
+ elif self.tokenizer.eos_token_id is not None:
430
+ cu_seqlens = [torch.tensor([0], device=batch['input_ids'].device)] # Match device
431
+ # Find positions *after* EOS tokens
432
+ eos_positions = torch.where(batch['input_ids'].eq(self.tokenizer.eos_token_id))[1] + 1
433
+ # Ensure eos_positions is on the correct device if empty
434
+ if eos_positions.numel() > 0:
435
+ cu_seqlens.append(eos_positions)
436
+ # Handle case where the sequence doesn't end with EOS
437
+ if batch['input_ids'][0, -1] != self.tokenizer.eos_token_id:
438
+ # Only add the final length if the last found EOS wasn't already the end
439
+ if eos_positions.numel() == 0 or eos_positions[-1] != batch['input_ids'].size(1):
440
+ cu_seqlens.append(
441
+ torch.tensor([batch['input_ids'].size(1)], device=batch['input_ids'].device)
442
+ ) # Match device and use size(1)
443
+ # Filter out empty tensors before cat
444
+ cu_seqlens = [t for t in cu_seqlens if t.numel() > 0]
445
+ if not cu_seqlens: # Handle case where input is empty or has no EOS
446
+ batch['cu_seqlens'] = torch.tensor(
447
+ [0, batch['input_ids'].size(1)], dtype=torch.int32, device=batch['input_ids'].device
448
+ )
449
+ else:
450
+ batch['cu_seqlens'] = torch.cat(cu_seqlens, dim=0).to(dtype=torch.int32)
451
+ # Else, neither BOS nor EOS is usable
452
+ else:
453
+ raise ValueError(
454
+ 'For varlen=True without precomputed cu_seqlens, the tokenizer must have either a bos_token_id '
455
+ 'or an eos_token_id defined to act as sequence separators.'
456
+ )
457
+
458
+ # --- cu_seqlens validation checks remain the same ---
459
+ if batch['cu_seqlens'].numel() < 2:
460
+ raise ValueError(f'Calculated cu_seqlens must have at least start and end: {batch["cu_seqlens"]}')
461
+ if not torch.all(batch['cu_seqlens'][1:] >= batch['cu_seqlens'][:-1]):
462
+ raise ValueError(f'Calculated cu_seqlens are not monotonically increasing: {batch["cu_seqlens"]}')
463
+ if batch['cu_seqlens'][0] != 0:
464
+ raise ValueError(f'Calculated cu_seqlens do not start at 0: {batch["cu_seqlens"]}')
465
+ if batch['cu_seqlens'][-1] != batch['input_ids'].size(1):
466
+ # Allow empty sequence case where cu_seqlens=[0, 0] and input_ids.size(1)=0
467
+ if not (batch['cu_seqlens'].tolist() == [0, 0] and batch['input_ids'].size(1) == 0):
468
+ raise ValueError(
469
+ f'Calculated cu_seqlens do not end at total length {batch["input_ids"].size(1)}: '
470
+ f'{batch["cu_seqlens"]}'
471
+ )
472
+
473
+ # --- context_len splitting logic remains the same ---
474
+ if self.context_len is not None:
475
+ # This logic splits sequences based on context_len *after* initial boundaries are found
476
+ bos = batch['cu_seqlens'][:-1].tolist()
477
+ eos = batch['cu_seqlens'][1:].tolist()
478
+ # Handle empty sequences between boundaries
479
+ split_boundaries = []
480
+ for i, j in zip(bos, eos):
481
+ if i < j: # Only process non-empty sequences
482
+ split_boundaries.append(torch.arange(i, j, self.context_len, device=batch['input_ids'].device))
483
+ # Add the final end point if it wasn't included by arange
484
+ final_end_point = torch.tensor([batch['input_ids'].size(1)], device=batch['input_ids'].device)
485
+ # Concatenate all boundaries
486
+ if not split_boundaries: # Handle case of completely empty input
487
+ batch['cu_seqlens'] = torch.tensor([0, 0], dtype=torch.int32, device=batch['input_ids'].device)
488
+ else:
489
+ batch['cu_seqlens'] = torch.cat(split_boundaries + [final_end_point]).to(dtype=torch.int32)
490
+ # Ensure uniqueness and sort, as arange might duplicate the endpoint
491
+ batch['cu_seqlens'] = torch.unique(batch['cu_seqlens'])
492
+
493
+ # Create labels directly from input_ids, NO padding mask needed for varlen
494
+ labels = batch['input_ids'].clone()
495
+ batch['labels'] = labels
496
+
497
+ return batch
498
+
499
+
500
+ class ParallelAwareDataLoader(StatefulDataLoader, Stateful):
501
+ """
502
+ A wrapper around the StatefulDataLoader that ensures that the state is stored only once per DP rank.
503
+ """
504
+
505
+ def __init__(
506
+ self,
507
+ rank: int,
508
+ dataset: IterableDataset,
509
+ batch_size: int,
510
+ collate_fn: Callable,
511
+ num_workers: int = 0,
512
+ pin_memory: bool = False,
513
+ prefetch_factor: int = 2,
514
+ persistent_workers: bool = False,
515
+ snapshot_every_n_steps: Optional[int] = 1,
516
+ ):
517
+ super().__init__(
518
+ dataset=dataset,
519
+ batch_size=batch_size,
520
+ collate_fn=collate_fn,
521
+ num_workers=num_workers,
522
+ pin_memory=pin_memory,
523
+ prefetch_factor=prefetch_factor,
524
+ persistent_workers=persistent_workers,
525
+ snapshot_every_n_steps=snapshot_every_n_steps,
526
+ )
527
+ self.rank = rank
528
+
529
+ def state_dict(self) -> Dict[str, Any]:
530
+ # Store state only for dp rank to avoid replicating the same state across other dimensions
531
+ return {f'rank_{self.rank}': pickle.dumps(super().state_dict())}
532
+
533
+ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
534
+ # State being empty is valid
535
+ if not state_dict:
536
+ return
537
+
538
+ if f'rank_{self.rank}' not in state_dict:
539
+ logger.warning(f'DataLoader state is empty for dp rank {self.rank}, expected key rank_{self.rank}')
540
+ return
541
+ super().load_state_dict(pickle.loads(state_dict[f'rank_{self.rank}']))
542
+
543
+
544
+ def build_dataloader(
545
+ dataset: IterableDataset,
546
+ tokenizer: PreTrainedTokenizer,
547
+ rank: int,
548
+ world_size: int,
549
+ batch_size: int,
550
+ seq_len: int,
551
+ context_len: Optional[int] = None,
552
+ varlen: bool = False,
553
+ num_workers: int = 0,
554
+ pin_memory: bool = False,
555
+ persistent_workers: bool = False,
556
+ snapshot_every_n_steps: Optional[int] = 1,
557
+ ):
558
+ dataset = OnlineTokenizedIterableDataset(
559
+ dataset=dataset, tokenizer=tokenizer, seq_len=seq_len, rank=rank, world_size=world_size
560
+ )
561
+ return ParallelAwareDataLoader(
562
+ rank=rank,
563
+ dataset=dataset,
564
+ batch_size=batch_size,
565
+ collate_fn=DataCollatorForLanguageModeling(tokenizer=tokenizer, context_len=context_len, varlen=varlen),
566
+ num_workers=num_workers,
567
+ pin_memory=pin_memory,
568
+ persistent_workers=persistent_workers,
569
+ snapshot_every_n_steps=snapshot_every_n_steps,
570
+ )
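Note on the `varlen=True` path in `DataCollatorForLanguageModeling` above: the BOS branch builds `cu_seqlens` from position 0, every BOS position, and the total length, and then, when `context_len` is set, re-splits each document into chunks of at most `context_len` tokens. The following is a minimal pure-Python sketch of that boundary logic, not the collator itself. The helper name `cu_seqlens_from_bos` is hypothetical, it works on plain lists rather than tensors, and returning `[0, 0]` for empty input is an assumption of this sketch.

```python
from typing import List, Optional


def cu_seqlens_from_bos(
    input_ids: List[int],
    bos_token_id: int,
    context_len: Optional[int] = None,
) -> List[int]:
    """Return cumulative sequence boundaries for one packed sequence.

    Hypothetical helper for illustration only; it mirrors the BOS branch
    of the collator using Python lists instead of tensors.
    """
    total = len(input_ids)
    if total == 0:
        # Assumption of this sketch: empty input maps to [0, 0].
        return [0, 0]

    # A document starts at position 0 and at every BOS token.
    starts = sorted({0} | {i for i, tok in enumerate(input_ids) if tok == bos_token_id})
    bounds = starts + [total]
    if context_len is None:
        return bounds

    # Re-split each document into chunks of at most `context_len` tokens.
    split = set()
    for begin, end in zip(bounds[:-1], bounds[1:]):
        split.update(range(begin, end, context_len))
    split.add(total)
    return sorted(split)


if __name__ == "__main__":
    packed = [1, 5, 6, 7, 1, 8, 9]  # two documents, BOS id assumed to be 1
    print(cu_seqlens_from_bos(packed, bos_token_id=1))
    print(cu_seqlens_from_bos(packed, bos_token_id=1, context_len=3))
```

Using a set for the split boundaries plays the same role as the `torch.unique` call in the collator: a chunk start that coincides with a document end is kept only once.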
flame/models/__init__.py ADDED
File without changes
flame/models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (172 Bytes). View file
 
flame/models/__pycache__/parallelize_fla.cpython-311.pyc ADDED
Binary file (23.6 kB). View file
 
flame/models/__pycache__/pipeline_fla.cpython-311.pyc ADDED
Binary file (6.39 kB). View file
 
flame/tools/__pycache__/utils.cpython-311.pyc ADDED
Binary file (2.38 kB). View file
 
flame/utils/__init__.py ADDED
File without changes
flame/utils/__pycache__/checkpoint.cpython-311.pyc ADDED
Binary file (5 kB). View file
 
flame/utils/__pycache__/convert_dcp_to_hf.cpython-311.pyc ADDED
Binary file (4.14 kB). View file
 
flame/utils/__pycache__/hf_utils.cpython-311.pyc ADDED
Binary file (5.13 kB). View file
 
flame/utils/checkpoint.py ADDED
@@ -0,0 +1,50 @@
1
+ import os
2
+ import glob
3
+ import re
4
+ import shutil
5
+ from torchtitan.tools.logging import logger
6
+
7
+
8
+ def cleanup_local_checkpoints(checkpoint_dir: str, keep_latest_k: int):
9
+ """Removes older checkpoint directories locally, keeping only the latest k for both DCP and HF formats."""
10
+ if keep_latest_k <= 0:
11
+ return # Keep all checkpoints
12
+
13
+ logger.info(f"Cleaning up local checkpoints in {checkpoint_dir}, keeping latest {keep_latest_k}")
14
+
15
+ # Cleanup DCP checkpoints (step-*)
16
+ dcp_checkpoints = sorted(
17
+ glob.glob(os.path.join(checkpoint_dir, "step-*")),
18
+ key=lambda x: int(re.search(r"step-(\d+)", os.path.basename(x)).group(1)) if re.search(r"step-(\d+)", os.path.basename(x)) and not x.endswith("-hf") else -1,
19
+ reverse=True
20
+ )
21
+ # Filter out HF format directories
22
+ dcp_checkpoints = [d for d in dcp_checkpoints if not d.endswith("-hf")]
23
+
24
+ if len(dcp_checkpoints) > keep_latest_k:
25
+ checkpoints_to_delete = dcp_checkpoints[keep_latest_k:]
26
+ logger.info(f"Deleting {len(checkpoints_to_delete)} old DCP checkpoints: {[os.path.basename(c) for c in checkpoints_to_delete]}")
27
+ for ckpt_path in checkpoints_to_delete:
28
+ if os.path.isdir(ckpt_path): # Ensure it's a directory
29
+ try:
30
+ shutil.rmtree(ckpt_path)
31
+ except OSError as e:
32
+ logger.error(f"Error removing directory {ckpt_path}: {e}")
33
+
34
+
35
+ # Cleanup HF checkpoints (step-*-hf)
36
+ hf_checkpoints = sorted(
37
+ glob.glob(os.path.join(checkpoint_dir, "step-*-hf")),
38
+ key=lambda x: int(re.search(r"step-(\d+)-hf", os.path.basename(x)).group(1)) if re.search(r"step-(\d+)-hf", os.path.basename(x)) else -1,
39
+ reverse=True
40
+ )
41
+
42
+ if len(hf_checkpoints) > keep_latest_k:
43
+ checkpoints_to_delete = hf_checkpoints[keep_latest_k:]
44
+ logger.info(f"Deleting {len(checkpoints_to_delete)} old HF checkpoints: {[os.path.basename(c) for c in checkpoints_to_delete]}")
45
+ for ckpt_path in checkpoints_to_delete:
46
+ if os.path.isdir(ckpt_path): # Ensure it's a directory
47
+ try:
48
+ shutil.rmtree(ckpt_path)
49
+ except OSError as e:
50
+ logger.error(f"Error removing directory {ckpt_path}: {e}")
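Note on `cleanup_local_checkpoints` above: the retention rule is to parse the step number from each directory name, sort descending, and delete everything past the first `keep_latest_k` entries, handling the `step-N` (DCP) and `step-N-hf` (HF) families separately. The sketch below isolates only the selection step so it can be checked without touching the filesystem. The name `select_stale_checkpoints` is hypothetical, and matching the whole basename against one anchored regex is a simplification made for this sketch rather than what the file does.

```python
import re
from typing import Iterable, List

# Matches "step-123" and "step-123-hf"; anything else is ignored.
_STEP_RE = re.compile(r"^step-(\d+)(-hf)?$")


def select_stale_checkpoints(
    names: Iterable[str],
    keep_latest_k: int,
    hf: bool = False,
) -> List[str]:
    """Return the checkpoint names that fall outside the newest `keep_latest_k`.

    Hypothetical helper for illustration: `hf=False` selects among DCP
    directories ("step-N"), `hf=True` among HF ones ("step-N-hf").
    """
    if keep_latest_k <= 0:
        return []  # same convention as the original: keep everything

    candidates = []
    for name in names:
        match = _STEP_RE.match(name)
        if match and bool(match.group(2)) == hf:
            candidates.append((int(match.group(1)), name))

    # Newest (largest step) first, then drop the ones we keep.
    candidates.sort(reverse=True)
    return [name for _, name in candidates[keep_latest_k:]]


if __name__ == "__main__":
    listing = ["step-100", "step-20", "step-3", "step-100-hf", "step-20-hf", "notes"]
    print(select_stale_checkpoints(listing, keep_latest_k=2))
    print(select_stale_checkpoints(listing, keep_latest_k=1, hf=True))
```

Sorting on the parsed integer rather than the string matters here: a lexicographic sort would rank `step-3` above `step-20`.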
flame/utils/hf_utils.py ADDED
@@ -0,0 +1,77 @@
1
+ import os
2
+ import re
3
+ from huggingface_hub import HfApi, HfFolder, logging as hf_logging, create_repo
4
+ from torchtitan.tools.logging import logger
5
+
6
+ def upload_checkpoint_to_hf(
7
+ local_path: str,
8
+ step: int,
9
+ hf_repo_id_for_run: str,
10
+ hf_keep_latest_k: int,
11
+ upload_format: str
12
+ ):
13
+ """Uploads a checkpoint directory to HF Hub and manages retention."""
14
+ if not os.path.isdir(local_path):
15
+ logger.error(f"Local path for upload does not exist or is not a directory: {local_path}")
16
+ return
17
+
18
+ api = HfApi()
19
+ token = HfFolder.get_token()
20
+ if not token:
21
+ logger.warning("Hugging Face Hub token not found. Skipping upload. Login via `huggingface-cli login` or set HF_TOKEN.")
22
+ return
23
+
24
+ # --- Ensure the specific repository for this run exists ---
25
+ try:
26
+ logger.info(f"Ensuring repository {hf_repo_id_for_run} exists...")
27
+ # Use create_repo which handles creation only if it doesn't exist
28
+ create_repo(repo_id=hf_repo_id_for_run, token=token, repo_type="model", exist_ok=True)
29
+ logger.info(f"Repository {hf_repo_id_for_run} ensured.")
30
+ except Exception as e:
31
+ logger.error(f"Failed to create or ensure repository {hf_repo_id_for_run}: {e}", exc_info=True)
32
+ return # Stop if repo interaction fails
33
+
34
+ commit_message = f"Upload {upload_format.upper()} checkpoint step {step}"
35
+ path_in_repo = f"step-{step}"
36
+
37
+ logger.info(f"Uploading {local_path} to {hf_repo_id_for_run}/{path_in_repo} on Hugging Face Hub...")
38
+ try:
39
+ api.upload_folder(
40
+ folder_path=local_path,
41
+ path_in_repo=path_in_repo,
42
+ repo_id=hf_repo_id_for_run,
43
+ repo_type="model",
44
+ commit_message=commit_message,
45
+ token=token,
46
+ )
47
+ logger.info(f"Successfully uploaded step {step} to {hf_repo_id_for_run}.")
48
+ except Exception as e:
49
+ logger.error(f"Failed to upload checkpoint step {step} to {hf_repo_id_for_run}: {e}", exc_info=True)
50
+ if hf_keep_latest_k > 0:
51
+ logger.info(f"Cleaning up old checkpoints on {hf_repo_id_for_run}, keeping latest {hf_keep_latest_k}")
52
+ try:
53
+ repo_files = api.list_repo_tree(hf_repo_id_for_run, repo_type="model", token=token, recursive=False)
54
+ step_folders = [
55
+ item.path for item in repo_files
56
+ if item.path.startswith("step-") and item.path[5:].isdigit()
57
+ ]
58
+
59
+ step_folders.sort(key=lambda x: int(x.split('-')[1]), reverse=True)
60
+
61
+ if len(step_folders) > hf_keep_latest_k:
62
+ folders_to_delete = step_folders[hf_keep_latest_k:]
63
+ logger.info(f"Found {len(step_folders)} checkpoints on Hub. Deleting {len(folders_to_delete)} older ones: {folders_to_delete}")
64
+ for folder in folders_to_delete:
65
+ # Deleting requires repo_id, path_in_repo, and token
66
+ api.delete_folder(
67
+ repo_id=hf_repo_id_for_run,
68
+ path_in_repo=folder,
69
+ repo_type="model",
70
+ commit_message=f"Delete old checkpoint {folder}",
71
+ token=token
72
+ )
73
+ logger.info("Hub cleanup complete.")
74
+ else:
75
+ logger.info("No old checkpoints found on Hub to delete.")
76
+ except Exception as e:
77
+ logger.error(f"Error during Hub checkpoint cleanup for {hf_repo_id_for_run}: {e}", exc_info=True)
generation_config.json ADDED
@@ -0,0 +1,6 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "transformers_version": "4.51.3"
6
+ }
logs/none_xprcuk_o/attempt_0/0/stdout.log ADDED
File without changes
logs/none_xprcuk_o/attempt_0/1/stderr.log ADDED
The diff for this file is too large to render. See raw diff
 
logs/none_xprcuk_o/attempt_0/2/stderr.log ADDED
The diff for this file is too large to render. See raw diff
 
logs/none_xprcuk_o/attempt_0/2/stdout.log ADDED
File without changes
logs/none_xprcuk_o/attempt_0/3/stdout.log ADDED
File without changes
measure_sink_rate.py ADDED
@@ -0,0 +1,137 @@
1
+ import fla
2
+ import torch
3
+ import numpy as np
4
+ import argparse
5
+ from tqdm import tqdm
6
+ from transformers import AutoModelForCausalLM, AutoTokenizer
7
+ from datasets import load_dataset
8
+
9
+ def calculate_sink_rate(attention_maps, epsilon=0.3):
10
+ """
11
+ Calculate sink rate using the formula:
12
+ sink_rate = (1/(L*H))*sum_L_H(1_((1/T)*sum_T(a_l_h_1_t) > epsilon))
13
+
14
+ Where:
15
+ - L is the number of layers
16
+ - H is the number of attention heads
17
+ - T is the sequence length
18
+ - 1_() is the indicator function
19
+ - a_ is the attention score at that index
20
+ - epsilon is the threshold (default: 0.3)
21
+
22
+ Args:
23
+ attention_maps: Attention maps from the model as a list with length L of tensors with shape [batch, heads, seq_len, seq_len]
24
+ epsilon: Threshold for attention
25
+
26
+ Returns:
27
+ sink_rate: The calculated sink rate
28
+ """
29
+ sink_rate = 0
30
+ for i, attention in enumerate(attention_maps):
31
+ # Extract attention on first token (BOS) across all heads
32
+ first_token_attention = attention[:, :, :, 0] # [batch, heads, seq_len]
33
+ # print("first token attentions", first_token_attention)
34
+
35
+ # Calculate mean attention on first token across sequence length
36
+ mean_first_token_attention = first_token_attention.mean(dim=-1) # [batch, heads]
37
+ # print("mean first token attentions", mean_first_token_attention)
38
+
39
+ # Apply indicator function - whether mean attention > epsilon
40
+ indicator = (mean_first_token_attention > epsilon).float() # [batch, heads]
41
+ # print("indicator", indicator)
42
+
43
+ # Average across heads
44
+ batch_sink_rates = indicator.mean(dim=(1)) # [batch]
45
+
46
+ # Average across batch
47
+ sink_rate += batch_sink_rates.mean().item()
48
+
49
+ # Normalize by number of layers
50
+ num_layers = len(attention_maps)
51
+ sink_rate /= num_layers
52
+
53
+ return sink_rate
54
+
55
+ def main(args):
56
+ # Set device
57
+ device = torch.device("cuda" if torch.cuda.is_available() and not args.cpu else "cpu")
58
+ print(f"Using device: {device}")
59
+
60
+ # Load model and tokenizer
61
+ print(f"Loading model: {args.model_name}")
62
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name)
63
+ model = AutoModelForCausalLM.from_pretrained(args.model_name, return_dict_in_generate=True, output_attentions=True).half()
64
+ model.to(device)
65
+ model.eval()
66
+
67
+ # Add padding token if it doesn't exist
68
+ if tokenizer.pad_token is None:
69
+ tokenizer.pad_token = tokenizer.eos_token
70
+
71
+ # Load dataset
72
+ print(f"Loading dataset: {args.dataset_name}")
73
+ dataset = load_dataset(args.dataset_name, split=args.split)
74
+
75
+ # Take n_samples from the dataset
76
+ if args.n_samples < len(dataset):
77
+ dataset = dataset.select(range(args.n_samples))
78
+ else:
79
+ args.n_samples = len(dataset)
80
+
81
+ print(f"Processing {args.n_samples} samples...")
82
+
83
+ # Process samples
84
+ sink_rate = 0
85
+
86
+ for i in tqdm(range(0, args.n_samples, args.batch_size)):
87
+ batch_samples = dataset[i:min(i + args.batch_size, args.n_samples)]
88
+
89
+ # Tokenize
90
+ encodings = tokenizer(
91
+ batch_samples["text"],
92
+ padding=True,
93
+ truncation=True,
94
+ max_length=args.max_length,
95
+ return_tensors="pt"
96
+ ).to(device)
97
+
98
+ # Forward pass
99
+ with torch.no_grad():
100
+ outputs = model(**encodings)
101
+
102
+ # Calculate sink rate for this batch
103
+ batch_sink_rate = calculate_sink_rate(outputs.attentions, args.epsilon)
104
+ outputs = None # Free memory (drops the attention maps held by `outputs`)
105
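+ sink_rate += batch_sink_rate * encodings["input_ids"].size(0) # weight by batch size so dividing by n_samples below gives a per-sample mean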
106
+
107
+ # Average sink rate
108
+ sink_rate /= args.n_samples
109
+
110
+ print(f"Sink Rate (ε={args.epsilon}): {sink_rate:.4f}")
111
+
112
+ # Optional: Save sink rate results
113
+ if args.output_file:
114
+ with open(args.output_file, 'w') as f:
115
+ f.write(f"Model: {args.model_name}\n")
116
+ f.write(f"Dataset: {args.dataset_name}\n")
117
+ f.write(f"Split: {args.split}\n")
118
+ f.write(f"Samples: {args.n_samples}\n")
119
+ f.write(f"Epsilon: {args.epsilon}\n")
120
+ f.write(f"Sink Rate: {sink_rate:.4f}\n")
121
+
122
+ print(f"Results saved to {args.output_file}")
123
+
124
+ if __name__ == "__main__":
125
+ parser = argparse.ArgumentParser(description="Measure Sink Rate for Transformer Models")
126
+ parser.add_argument("--model_name", type=str, default="gpt2", help="Huggingface model name")
127
+ parser.add_argument("--dataset_name", type=str, default="wikitext", help="Huggingface dataset name")
128
+ parser.add_argument("--split", type=str, default="test", help="Dataset split to use")
129
+ parser.add_argument("--n_samples", type=int, default=100, help="Number of samples to process")
130
+ parser.add_argument("--batch_size", type=int, default=4, help="Batch size for processing")
131
+ parser.add_argument("--max_length", type=int, default=512, help="Maximum sequence length")
132
+ parser.add_argument("--epsilon", type=float, default=0.3, help="Threshold for sink rate calculation")
133
+ parser.add_argument("--output_file", type=str, default="", help="File to save results")
134
+ parser.add_argument("--cpu", action="store_true", help="Force CPU usage")
135
+
136
+ args = parser.parse_args()
137
+ main(args)
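Note on `calculate_sink_rate` above: the docstring formula counts, over all layers and heads, the fraction of heads whose mean attention on the first token exceeds `epsilon`. The sketch below restates that formula with nested Python lists so the arithmetic can be followed by hand. The name `sink_rate_from_lists` and the `attn[layer][head][query][key]` layout for a single sequence are assumptions of this sketch; the script itself works on batched tensors.

```python
from typing import List

# attn[layer][head][query][key] -> attention weight (one sequence, no batch dim)
Attention = List[List[List[List[float]]]]


def sink_rate_from_lists(attn: Attention, epsilon: float = 0.3) -> float:
    """Fraction of (layer, head) pairs whose mean attention on key 0 exceeds epsilon.

    Hypothetical helper for illustration; it follows the docstring formula
    sink_rate = (1 / (L * H)) * sum over l, h of 1[(1 / T) * sum_t a[l][h][t][0] > epsilon].
    """
    flags = []
    for layer in attn:
        for head in layer:
            # Mean over query positions of the weight placed on the first key.
            mean_on_first = sum(row[0] for row in head) / len(head)
            flags.append(1.0 if mean_on_first > epsilon else 0.0)
    return sum(flags) / len(flags)


if __name__ == "__main__":
    # Head that sends all attention to the first token (a clear sink).
    sink_head = [[1.0, 0.0, 0.0, 0.0] for _ in range(4)]
    # Head where each query attends only to itself (mean on first token = 0.25).
    diag_head = [[1.0 if k == q else 0.0 for k in range(4)] for q in range(4)]
    print(sink_rate_from_lists([[sink_head, diag_head]]))
```

With causal attention the first query can only attend to the first key, so even a head with no sink behaviour contributes `1/T` to the mean; for short sequences this can push a head over a small `epsilon`.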
passkey_retrieval.py ADDED
@@ -0,0 +1,158 @@
1
+ import torch
2
+ import argparse
3
+ import random
4
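+ from numpy import random # NOTE: shadows the stdlib `random` imported above; get_state/seed/randint/set_state below are NumPy's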
5
+ from tqdm import tqdm
6
+ import fla
7
+ import transformers
8
+ from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
9
+
10
+ def parse_config():
11
+ parser = argparse.ArgumentParser(description='arg parser')
12
+ parser.add_argument('--base_model', type=str, default="/data1/pretrained-models/llama-7b-hf")
13
+ parser.add_argument('--cache_dir', type=str, default="./cache")
14
+ parser.add_argument('--context_size', type=int, default=-1, help='context size during fine-tuning')
15
+ parser.add_argument('--num_tests', type=int, default=10, help='number of repeated tests for each length')
16
+ parser.add_argument('--test_k_tokens', type=int, default=2, help='test length')
17
+
18
+ args = parser.parse_args()
19
+ return args
20
+
21
+
22
+ def generate_prompt_landmark(tokenizer, n_garbage, seed):
23
+ """Generates a text prompt and inserts a passkey at a random position."""
24
+ rnd_state = random.get_state()
25
+ random.seed(seed)
26
+ n_garbage_prefix = random.randint(0, n_garbage)
27
+ n_garbage_suffix = n_garbage - n_garbage_prefix
28
+
29
+ task_description = "There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there."
30
+ garbage = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again."
31
+ garbage_inf = " ".join([garbage] * 100000)
32
+ assert len(garbage_inf) >= n_garbage
33
+ garbage_prefix = garbage_inf[:n_garbage_prefix]
34
+ garbage_suffix = garbage_inf[:n_garbage_suffix]
35
+ pass_key = random.randint(1, 50000)
36
+
37
+
38
+ information_line = f"The pass key is {pass_key}. Remember it. {pass_key} is the pass key."
39
+ final_question = "What is the pass key? The pass key is"
40
+ lines = [
41
+ task_description,
42
+ garbage_prefix,
43
+ information_line,
44
+ garbage_suffix, # generate: 1k token
45
+ final_question,
46
+ ]
47
+
48
+ position = []
49
+ current_pos = 0
50
+ block = 0
51
+
52
+ for line in lines:
53
+ input_ids = tokenizer(line, return_tensors="pt").input_ids
54
+ line_length = input_ids.size(1) # Get the length of the sequence
55
+ position.append((current_pos, current_pos + line_length - 1)) # Store the start and end positions
56
+ if line.startswith("The pass key is"):
57
+ block = current_pos // 256
58
+ current_pos += line_length # Update the current position
59
+
60
+ random.set_state(rnd_state)
61
+ return "\n".join(lines), str(pass_key), position[2], block
62
+
63
+
64
+ def passkey_retrieval_test(model, tokenizer, device, use_cache=False, n_garbage=60000, seed=666, sequence=None):
65
+ prompt, answer, position, block = generate_prompt_landmark(tokenizer, n_garbage, seed+n_garbage) # use seed + n_garbage as the seed to avoid repeated prompts
66
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids
67
+ input_ids = input_ids.to(device)
68
+ len_token = input_ids.shape[-1]
69
+
70
+ answer_ids = tokenizer(answer, return_tensors="pt").input_ids[:, 1:] # drop BOS
71
+
72
+ generation_output = model.generate(
73
+ input_ids=input_ids, max_new_tokens=answer_ids.shape[-1], num_beams=1, use_cache=use_cache
74
+ )
75
+
76
+ model_answer = generation_output[0, -answer_ids.shape[-1]:].cpu()
77
+
78
+ is_correct = (model_answer == answer_ids[0]).all().item()
79
+ is_split = is_number_in_range(sequence, position)
80
+
81
+
82
+ print(f"The correct answer is {tokenizer.decode(answer_ids[0].cpu())}")
83
+ print(f"The model answer is {tokenizer.decode(model_answer.cpu())}, is_correct : {is_correct}")
84
+ return is_correct, is_split, len_token, position
85
+
86
+
87
+ def generate_sequence():
88
+ sequence = [0]
89
+ increment = 256
90
+
91
+ for i in range(1, 201):
92
+ next_number = sequence[-1] + increment
93
+ sequence.append(next_number)
94
+
95
+ return sequence
96
+
97
+ def is_number_in_range(sequence, position):
98
+ for num in sequence:
99
+ if position[0] <= num and num <= position[1]:
100
+ return True
101
+ return False
102
+
103
+ def main(args):
104
+ device = "cuda:0"
105
+ torch.cuda.set_device(device)
106
+
107
+ print("base model", args.base_model)
108
+
109
+ # Set RoPE scaling factor
110
+ # config = AutoConfig.from_pretrained(
111
+ # args.base_model,
112
+ # cache_dir=args.cache_dir,
113
+ # )
114
+
115
+ model = AutoModelForCausalLM.from_pretrained(
116
+ args.base_model,
117
+ # config=config,
118
+ # cache_dir=args.cache_dir,
119
+ # torch_dtype=torch.float16,
120
+ )
121
+ model = model.to('cuda:0').half()
122
+
123
+ model.resize_token_embeddings(32001)
124
+
125
+ tokenizer = AutoTokenizer.from_pretrained(
126
+ args.base_model,
127
+ # cache_dir=args.cache_dir,
128
+ # model_max_length=args.context_size,
129
+ # padding_side="right",
130
+ # use_fast=False,
131
+ )
132
+
133
+
134
+ sequence = generate_sequence()
135
+
136
+
137
+ # This is a rough ratio to control the number of texts and tokens
138
+ n_garbage = int(3.75 * args.test_k_tokens * 1024 // 1024 * 1024)
139
+ passed_tests = 0
140
+ total_tokens = 0
141
+ for j in tqdm(range(args.num_tests)):
142
+ is_correct, is_split, len_tokens, position = passkey_retrieval_test(model, tokenizer, device, use_cache=True, n_garbage=n_garbage, seed=j, sequence=sequence)
143
+
144
+ passed_tests += is_correct
145
+ total_tokens += len_tokens
146
+ if is_correct:
147
+ print(f" Success: {position},\tis_split: {is_split}", end="", flush=True)
148
+ else:
149
+ print(f" [Fails]: {position},\tis_split: {is_split}", end="", flush=True)
150
+ avg_tokens = total_tokens//args.num_tests
151
+ accuracy = float(passed_tests)/args.num_tests
152
+ print("Accuracy on the token length %d is %f, max GPU memory allocated %f GB"%(avg_tokens, accuracy, torch.cuda.max_memory_allocated(0) / 1024 / 1024 / 1024), flush=True)
153
+
154
+
155
+
156
+ if __name__ == "__main__":
157
+ args = parse_config()
158
+ main(args)
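The prompt construction above can be sketched without a tokenizer or model. This is a minimal illustration, not the script itself: `build_passkey_prompt` and `crosses_block_boundary` are hypothetical names, and the sketch assumes the standard-library `random` module with a private `random.Random` instance, whereas the script's `get_state`/`set_state` calls suggest it uses NumPy's `random` (whose `randint` excludes the upper bound).

```python
import random


def build_passkey_prompt(n_garbage: int, seed: int) -> tuple[str, str]:
    """Return (prompt, passkey) with the passkey line hidden in filler text."""
    # A private generator keeps the global random state untouched, which is
    # what the save/restore of the random state achieves in the script.
    rng = random.Random(seed)
    n_prefix = rng.randint(0, n_garbage)
    n_suffix = n_garbage - n_prefix

    filler = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again."
    repeats = n_garbage // len(filler) + 1  # enough copies to cover n_garbage characters
    filler_long = " ".join([filler] * repeats)

    pass_key = rng.randint(1, 50000)
    lines = [
        "There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. "
        "I will quiz you about the important information there.",
        filler_long[:n_prefix],
        f"The pass key is {pass_key}. Remember it. {pass_key} is the pass key.",
        filler_long[:n_suffix],
        "What is the pass key? The pass key is",
    ]
    return "\n".join(lines), str(pass_key)


def crosses_block_boundary(start: int, end: int, block: int = 256) -> bool:
    """True if some multiple of `block` lies in the inclusive span [start, end]."""
    first_multiple = -(-start // block) * block  # smallest multiple of block >= start
    return first_multiple <= end
```

Unlike `is_number_in_range`, which scans a precomputed list of 201 multiples of 256, `crosses_block_boundary` is not capped at position 51200; within that range the two checks agree.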
pyproject.toml ADDED
@@ -0,0 +1,42 @@
1
+ [project]
2
+ name = "flame"
3
+ dynamic = ["version"]
4
+ description = "A minimal training framework for scaling FLA models"
5
+ readme = "README.md"
6
+ authors = [
7
+ { name = "Songlin Yang", email = "yangsl66@mit.edu" },
8
+ { name = "Yu Zhang", email = "yzhang.cs@outlook.com" },
9
+ ]
10
+ license = { file = "LICENSE" }
11
+ classifiers = [
12
+ "Programming Language :: Python :: 3",
13
+ "License :: OSI Approved :: MIT License",
14
+ "Operating System :: OS Independent",
15
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
16
+ ]
17
+ requires-python = ">=3.10"
18
+ dependencies = [
19
+ 'torch>=2.5',
20
+ 'torchdata',
21
+ 'transformers>=4.45.0',
22
+ 'triton>=3.0',
23
+ 'datasets>=3.3.0',
24
+ 'einops',
25
+ 'ninja',
26
+ 'wandb',
27
+ 'tiktoken',
28
+ 'tensorboard',
29
+ ]
30
+
31
+ [project.optional-dependencies]
32
+ dev = ["pytest"]
33
+
34
+ [project.urls]
35
+ Homepage = "https://github.com/fla-org/flame"
36
+
37
+ [build-system]
38
+ requires = ["setuptools>=45", "wheel", "ninja", "torch"]
39
+
40
+ [tool.isort]
41
+ line_length = 127
42
+ multi_line_output = 3
register_softpick.py ADDED
@@ -0,0 +1,72 @@
1
+ from einops import rearrange
2
+ from typing import Optional
3
+ import torch
4
+ from transformers import AttentionInterface
5
+ from torch.nn import functional as F
6
+
7
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
8
+ """
9
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
10
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
11
+ """
12
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
13
+ if n_rep == 1:
14
+ return hidden_states
15
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
16
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
17
+
18
+ def softpick(x, dim=-1, eps=1e-8):
19
+ # softpick function: relu(exp(x)-1) / sum(abs(exp(x)-1))
20
+ # numerically stable version
21
+ x_m = torch.max(x, dim=dim, keepdim=True).values
22
+ x_m_e_m = torch.exp(-x_m)
23
+ x_e_1 = torch.exp(x - x_m) - x_m_e_m
24
+ r_x_e_1 = F.relu(x_e_1)
25
+ a_x_e_1 = torch.where(x.isfinite(), torch.abs(x_e_1), 0)
26
+ return r_x_e_1 / (torch.sum(a_x_e_1, dim=dim, keepdim=True) + eps) # epsilon is only useful if all inputs are EXACTLY 0. we might not even need it
27
+
28
+ def naive_softpick_attn(
29
+ module: torch.nn.Module, # required arg
30
+ query: torch.Tensor, # required arg
31
+ key: torch.Tensor, # required arg
32
+ value: torch.Tensor, # required arg
33
+ attention_mask: Optional[torch.Tensor], # required arg
34
+ *args,
35
+ scale: Optional[float] = None,
36
+ cu_seqlens: Optional[torch.LongTensor] = None,
37
+ head_first: bool = False,
38
+ **kwargs
39
+ ) -> tuple[torch.Tensor, torch.Tensor]:
40
+ head_dim = query.shape[-1]
41
+
42
+ # In transformers, the shape is (batch_size, num_heads, seq_len, head_dim)
43
+ num_query_heads = query.shape[1]
44
+ num_key_value_heads = key.shape[1]
45
+
46
+
47
+ if num_query_heads != num_key_value_heads:
48
+ # MQA or GQA
49
+ key = repeat_kv(key, num_query_heads // num_key_value_heads)
50
+ value = repeat_kv(value, num_query_heads // num_key_value_heads)
51
+
52
+ if scale is None:
53
+ scale = 1.0 / (head_dim ** 0.5)
54
+ if not head_first:
55
+ query, key, value = map(lambda x: rearrange(x, 'b t h d -> b h t d'), (query, key, value))
56
+ query_len = query.shape[-2]
57
+ key_len = key.shape[-2]
58
+ mask = torch.tril(torch.ones(key_len, key_len, device=query.device))
59
+ wei = torch.matmul(query, key.transpose(2, 3)) # shape: (batch_size, num_heads, query_len, key_len)
60
+ wei = wei * scale
61
+ wei = wei.masked_fill(mask[key_len-query_len:key_len, :key_len] == 0, float('-inf'))
62
+ wei = softpick(wei.float(), dim=-1).to(query.dtype)
63
+ o = torch.matmul(wei, value) # shape: (batch_size, num_heads, q_len, head_dim)
64
+ if not head_first:
65
+ o = rearrange(o, 'b h t d -> b t h d')
66
+ return o, wei
67
+
68
+ def softpick_attention(*args, **kwargs):
69
+ # print("Using softpick attention") # NOTE: Add print statement here to check whether we actually use softpick or not
70
+ return naive_softpick_attn(*args, **kwargs)
71
+
72
+ AttentionInterface.register("softpick", softpick_attention)
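The `softpick` function above computes relu(exp(x) - 1) / sum(abs(exp(x) - 1)) in a shifted form. A pure-Python version over a single row of scores may make the arithmetic easier to follow. This is only a sketch: `softpick_1d` is a hypothetical name, and it assumes at least one finite score of moderate magnitude, so that `math.exp` does not overflow.

```python
import math


def softpick_1d(scores, eps=1e-8):
    """Pure-Python softpick over one row of attention scores.

    Entries equal to -inf (masked positions) get weight 0 and are left
    out of the denominator.
    """
    m = max(scores)
    shift = math.exp(-m)
    # exp(x - m) - exp(-m) equals (exp(x) - 1) * exp(-m); the common factor
    # exp(-m) cancels between numerator and denominator (up to eps).
    shifted = [math.exp(x - m) - shift for x in scores]
    numer = [max(v, 0.0) for v in shifted]  # relu
    denom = sum(abs(v) for x, v in zip(scores, shifted) if math.isfinite(x))
    return [n / (denom + eps) for n in numer]


if __name__ == "__main__":
    print(softpick_1d([1.0, -1.0]))
    print(softpick_1d([0.0, 0.0, 0.0]))
```

Because the numerator keeps only the positive parts while the denominator sums absolute values, the weights in a row can add up to less than 1, and any score at or below 0 gets a weight of exactly 0.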
setup.py ADDED
@@ -0,0 +1,50 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import ast
4
+ import os
5
+ import re
6
+ from pathlib import Path
7
+
8
+ from setuptools import find_packages, setup
9
+
10
+ with open('README.md') as f:
11
+ long_description = f.read()
12
+
13
+
14
+ def get_package_version():
15
+ with open(Path(os.path.dirname(os.path.abspath(__file__))) / 'flame' / '__init__.py') as f:
16
+ version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
17
+ return ast.literal_eval(version_match.group(1))
18
+
19
+
20
+ setup(
21
+ name='flame',
22
+ version=get_package_version(),
23
+ description='A minimal training framework for scaling FLA models',
24
+ long_description=long_description,
25
+ long_description_content_type='text/markdown',
26
+ author='Songlin Yang, Yu Zhang',
27
+ author_email='yangsl66@mit.edu, yzhang.cs@outlook.com',
28
+ url='https://github.com/fla-org/flame',
29
+ packages=find_packages(),
30
+ license='MIT',
31
+ classifiers=[
32
+ 'Programming Language :: Python :: 3',
33
+ 'License :: OSI Approved :: MIT License',
34
+ 'Operating System :: OS Independent',
35
+ 'Topic :: Scientific/Engineering :: Artificial Intelligence'
36
+ ],
37
+ python_requires='>=3.10',
38
+ install_requires=[
39
+ 'torch>=2.5',
40
+ 'torchdata',
41
+ 'transformers>=4.45.0',
42
+ 'triton>=3.0',
43
+ 'datasets>=3.3.0',
44
+ 'einops',
45
+ 'ninja',
46
+ 'wandb',
47
+ 'tiktoken',
48
+ 'tensorboard',
49
+ ],
50
+ )
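`get_package_version` reads `__version__` from `flame/__init__.py` with a regular expression and evaluates the right-hand side with `ast.literal_eval`, so the package is never imported during the build. The same idea can be tried on an in-memory string. The sketch below uses a hypothetical helper, `parse_version`, and adds an explicit error for the no-match case, which `setup.py` does not handle.

```python
import ast
import re


def parse_version(source: str) -> str:
    """Extract the value assigned to __version__ in a module's source text."""
    match = re.search(r"^__version__\s*=\s*(.*)$", source, re.MULTILINE)
    if match is None:
        # setup.py would instead fail with an AttributeError on None.group
        raise ValueError("no __version__ assignment found")
    # literal_eval turns the quoted literal into a plain str without executing code
    return ast.literal_eval(match.group(1))


if __name__ == "__main__":
    sample = '"""Package docstring."""\n__version__ = "0.1.0"\n'
    print(parse_version(sample))
```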