CarlOwOs committed on
Commit
8533328
·
verified ·
1 Parent(s): 95c12ea

Add files using upload-large-folder tool

LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2023-2025 Songlin Yang, Yu Zhang
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,519 @@
1
+ <div align="center">
2
+
3
+ # 🔥 Flame: Flash Language Modeling Made Easy
4
+
5
+ [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/fla-org/flame)
6
+
7
+ </div>
8
+
9
+ Welcome to 🔥 `flame`, a minimal framework built on `torchtitan` for training language models with blazing efficiency.
10
+
11
+ **Feature Highlights:**
12
+
13
+ - 🚀 Minimal, easy-to-use, extensible training framework
14
+ - 🤗 Seamless integration with `fla` and `transformers`
15
+ - 🔄 Zero-cost data preprocessing: online tokenization, dataset shuffling, and multi-dataset support
16
+ - 🔮 4D parallelism (coming soon)
17
+
18
+ ## Setup
19
+
20
+ To get started, clone the `flame` repository and install the required dependencies:
21
+
22
+ ```bash
23
+ git clone https://github.com/fla-org/flame.git
24
+ cd flame
25
+ pip install .
26
+ ```
27
+
28
+ Then install the latest version of `fla`:
29
+ ```
30
+ pip uninstall flash-linear-attention && pip install -U --no-use-pep517 git+https://github.com/fla-org/flash-linear-attention
31
+ ```
32
+
33
+ **Important:** install the pinned version of `torchtitan`:
34
+ ```
35
+ pip install git+https://github.com/pytorch/torchtitan.git@0b44d4c
36
+ ```
37
+
38
+
39
+ ## Dataset Preparation
40
+ To download the dataset to your local disk, create a new Python file with the following content and execute it:
41
+
42
+ ```py
43
+ from datasets import load_dataset
44
+
45
+ # load fineweb-edu with parallel processing
46
+ dataset = load_dataset("HuggingFaceFW/fineweb-edu", name="default", num_proc=64, cache_dir="/your/cache/path")
47
+
48
+ # or load a subset with roughly 100B tokens, suitable for small- or medium-sized experiments
49
+ dataset = load_dataset("HuggingFaceFW/fineweb-edu", name="sample-100BT", num_proc=64, cache_dir="/your/cache/path")
50
+ ```
51
+
52
+ ## Training Recipes
53
+
54
+ Here's an example of training a 340M FLA Transformer model with a LLaMA-like architecture from scratch on a 100BT subset of the Fineweb-edu corpus ~~in streaming mode~~. (Do not use streaming mode if you are concerned about resuming training.)
55
+
56
+ > [!WARNING]
57
+ > If the dataset is not downloaded beforehand, the streaming mode will attempt to fetch it from a remote server and download it on-the-fly, which can be highly unstable during training due to network issues.
58
+ > For stable training, ensure the dataset is downloaded locally (see [**Dataset Preparation**](#dataset-preparation)). Otherwise, we assume you are only testing the new corpus.
59
+
60
+ ```sh
61
+ bash train.sh \
62
+ --job.config_file flame/models/fla.toml \
63
+ --job.dump_folder exp/transformer-340M-4K-10B/batch1.seqlen65536.context4096.warmup1024.update1.steps20480.lr1e-3.cosine \
64
+ --model.config configs/transformer_340M.json \
65
+ --model.tokenizer_path fla-hub/transformer-1.3B-100B \
66
+ --optimizer.name AdamW \
67
+ --optimizer.eps 1e-15 \
68
+ --optimizer.lr 1e-3 \
69
+ --lr_scheduler.warmup_steps 1024 \
70
+ --lr_scheduler.lr_min 0.1 \
71
+ --lr_scheduler.decay_type cosine \
72
+ --training.batch_size 1 \
73
+ --training.seq_len 65536 \
74
+ --training.context_len 4096 \
75
+ --training.varlen \
76
+ --training.gradient_accumulation_steps 1 \
77
+ --training.steps 20480 \
78
+ --training.max_norm 1.0 \
79
+ --training.skip_nan_inf \
80
+ --training.dataset HuggingFaceFW/fineweb-edu \
81
+ --training.dataset_name sample-100BT \
82
+ --training.dataset_split train \
83
+ --training.num_workers 32 \
84
+ --training.prefetch_factor 2 \
85
+ --training.seed 42 \
86
+ --training.compile \
87
+ --checkpoint.interval 2048 \
88
+ --checkpoint.load_step -1 \
89
+ --checkpoint.keep_latest_k 2 \
90
+ --metrics.log_freq 1
91
+ ```
92
+
93
+ You can specify the number of GPUs by setting the environment variable `NGPU`, which defaults to 8.
94
+ **For single-GPU debugging, set `NGPU=1`.**
95
+
96
+ We provide several [config files](https://github.com/fla-org/flame/tree/main/configs) for different models.
97
+ By default, the learning rate is set to 1e-3 with a cosine scheduler. Other schedulers, such as WSD (wsd), are also supported.
98
+
99
+ **Key parameters:**
100
+ - `--lr_scheduler.decay_ratio`: The proportion of the steps allocated to the decay phase. The learning rate will remain stable after the warmup period and only start decaying during the last `decay_ratio` portion of the total training steps, which is known as the Warmup-Stable-Decay (WSD) schedule.
101
+ - `--lr_scheduler.warmup_steps`: The number of steps for the learning rate warmup phase.
102
+ - `--training.steps`: Total number of training steps.
103
+ - `--training.batch_size`: Batch size per device, must be 1 if `--training.varlen` is set.
104
+ - `--training.seq_len`: The length of each sequence in the batch, which is concatenated from multiple samples.
105
+ - `--training.context_len`: The max allowed length of a sample. For non-varlen mode, this is equivalent to `seq_len`.
106
+ - `--training.varlen`: Whether to conduct variable-length sequence training.
107
+ - `--training.gradient_accumulation_steps`: Number of gradient accumulation steps.
108
+
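The interplay of `warmup_steps`, `decay_ratio`, and `steps` in a WSD schedule can be sketched as a small function. This is only an illustration, not `flame`'s scheduler: the name `wsd_lr_factor`, the linear shape of the warmup and decay phases, and the reading of `lr_min` as a fraction of the peak learning rate are all assumptions.

```python
def wsd_lr_factor(step: int, total_steps: int, warmup_steps: int,
                  decay_ratio: float, lr_min: float = 0.1) -> float:
    """Multiplier applied to the peak learning rate at a 0-indexed step."""
    decay_steps = int(total_steps * decay_ratio)
    decay_start = total_steps - decay_steps
    if step < warmup_steps:
        # Warmup: ramp linearly up to the peak learning rate.
        return (step + 1) / warmup_steps
    if step < decay_start:
        # Stable: hold the peak learning rate.
        return 1.0
    # Decay: fall linearly from the peak towards lr_min (assumed shape).
    progress = (step - decay_start) / max(1, decay_steps)
    return 1.0 - (1.0 - lr_min) * progress


# steps and warmup_steps come from the example command;
# decay_ratio=0.1 is a hypothetical value chosen for illustration.
factors = [wsd_lr_factor(s, 20480, 1024, 0.1) for s in (0, 1023, 10000, 18432, 20479)]
print(factors)
```

Under these assumptions the learning rate stays at its peak from step 1,023 until step 18,432 and only decays over the final 2,048 steps.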
109
+ > [!WARNING]
110
+ > The total number of sequences processed per optimizer step, referred to as `global_batch_size`, is calculated as `batch_size × gradient_accumulation_steps × num_gpus`.
111
+ > Each step processes `global_batch_size * seq_len` tokens.
112
+ > Monitor the value of `global_batch_size`, `warmup_steps`, and `steps` carefully when modifying any of the hyperparameters!
113
+
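As a quick sanity check, the token budget of the example command above can be worked out by hand. The helper name `tokens_per_step` is hypothetical, and `num_gpus=8` assumes the default `NGPU`.

```python
def tokens_per_step(batch_size: int, gradient_accumulation_steps: int,
                    num_gpus: int, seq_len: int) -> int:
    """Tokens consumed by one optimizer step across all devices."""
    global_batch_size = batch_size * gradient_accumulation_steps * num_gpus
    return global_batch_size * seq_len


# Values taken from the example command, with the default of 8 GPUs.
per_step = tokens_per_step(batch_size=1, gradient_accumulation_steps=1,
                           num_gpus=8, seq_len=65536)
total = per_step * 20480  # --training.steps

print(per_step)  # 524288
print(total)     # 10737418240
```

That is roughly 10B tokens for the whole run, which appears to be what the `10B` in the example `dump_folder` name refers to. Halving `NGPU` without doubling `gradient_accumulation_steps` would halve this budget.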
114
+ For a detailed explanation of all parameters, run:
115
+
116
+ ```sh
117
+ bash train.sh -h
118
+ ```
119
+
120
+ <details>
121
+ <summary>Usage</summary>
122
+
123
+ ```text
124
+ options:
125
+ -h, --help show this help message and exit
126
+ --job.config_file JOB.CONFIG_FILE
127
+ Job config file
128
+ --job.dump_folder JOB.DUMP_FOLDER
129
+ Folder to dump job outputs
130
+ --job.description JOB.DESCRIPTION
131
+ Description of the job
132
+ --job.use_for_integration_test
133
+ Add this config to the integration test suite
134
+ --job.print_args Print the args to terminal
135
+ --model.config MODEL.CONFIG
136
+ Path to the model config
137
+ --model.norm_type MODEL.NORM_TYPE
138
+ Type of layer normalization to use [layernorm,
139
+ np_layernorm, rmsnorm, fused_rmsnorm]
140
+ --model.tokenizer_path MODEL.TOKENIZER_PATH
141
+ Tokenizer path
142
+ --profiling.enable_profiling
143
+ Whether to enable pytorch profiler
144
+ --profiling.save_traces_folder PROFILING.SAVE_TRACES_FOLDER
145
+ Trace files location
146
+ --profiling.profile_freq PROFILING.PROFILE_FREQ
147
+ How often to collect profiler traces, in iterations
148
+ --profiling.enable_memory_snapshot
149
+ Whether to dump memory snapshot
150
+ --profiling.save_memory_snapshot_folder PROFILING.SAVE_MEMORY_SNAPSHOT_FOLDER
151
+ Memory snapshot files location
152
+ --optimizer.name OPTIMIZER.NAME
153
+ Optimizer to use
154
+ --optimizer.eps OPTIMIZER.EPS
155
+ Epsilon value for the optimizer.
156
+ --optimizer.fused Whether the fused implementation(CUDA only) is used.
157
+ --optimizer.scheduler {wsd,cosine,linear}
158
+ Scheduler to use. Currently supported: wsd, cosine,
159
+ and linear.
160
+ --optimizer.lr OPTIMIZER.LR
161
+ Learning rate to use
162
+ --optimizer.min_lr_ratio OPTIMIZER.MIN_LR_RATIO
163
+ Min lr ratio for lr scheduler
164
+ --optimizer.early_step_in_backward
165
+ Whether to apply optimizer in the backward. Caution,
166
+ optimizer_in_backward is not compatible with gradients
167
+ clipping, users should not call
168
+ register_post_accumulate_grad_hook after the optimizer
169
+ is built.
170
+ --training.batch_size TRAINING.BATCH_SIZE
171
+ Batch size
172
+ --training.seq_len TRAINING.SEQ_LEN
173
+ Sequence length
174
+ --training.context_len TRAINING.CONTEXT_LEN
175
+ Max length allowed for each sequence
176
+ --training.varlen Whether to take sequences of variable length as input
177
+ --training.warmup_steps TRAINING.WARMUP_STEPS
178
+ Steps for lr scheduler warmup, normally 1/5 of
179
+ --training.steps
180
+ --training.gradient_accumulation_steps TRAINING.GRADIENT_ACCUMULATION_STEPS
181
+ Number of steps to accumulate gradients before
182
+ updating parameters
183
+ --training.steps TRAINING.STEPS
184
+ How many train steps to run
185
+ --training.max_norm TRAINING.MAX_NORM
186
+ Max norm for gradient clipping
187
+ --training.skip_nan_inf
188
+ Skip batch updates when NaN or INF gradients are
189
+ encountered during training
190
+ --training.dataset TRAINING.DATASET
191
+ Dataset to use, with comma separated values
192
+ --training.dataset_name TRAINING.DATASET_NAME
193
+ The name of the dataset config, with comma separated
194
+ values if provided
195
+ --training.dataset_split TRAINING.DATASET_SPLIT
196
+ Dataset split to use, with comma separated values if
197
+ provided
198
+ --training.data_dir TRAINING.DATA_DIR
199
+ Data dirs to use, with comma separated values if
200
+ provided
201
+ --training.data_files TRAINING.DATA_FILES
202
+ Data files to use, with comma separated values if
203
+ provided
204
+ --training.data_probs TRAINING.DATA_PROBS
205
+ Data sampling probabilities, with comma separated
206
+ values if provided
207
+ --training.streaming Whether to load dataset in streaming mode, used for
208
+ huge dataset
209
+ --training.num_workers TRAINING.NUM_WORKERS
210
+ Number of subprocesses to use for data loading. 0
211
+ means that the data will be loaded in the main
212
+ process.
213
+ --training.prefetch_factor TRAINING.PREFETCH_FACTOR
214
+ Number of batches loaded in advance by each worker. 2
215
+ means there will be a total of 2 * num_workers batches
216
+ prefetched across all workers.
217
+ --training.data_parallel_replicate_degree TRAINING.DATA_PARALLEL_REPLICATE_DEGREE
218
+ The `data_parallel_replicate_degree` argument
219
+ specifies the degree of data parallelism for weight
220
+ replication. When this value is greater than 1,
221
+ weights will be replicated across
222
+ `data_parallel_replicate_degree` ranks. If
223
+ `data_parallel_shard_degree` is also greater than 1,
224
+ the parallelism method used is HSDP (Hybrid Sharded
225
+ Data Parallelism). Otherwise, the parallelism method
226
+ used is DDP (Distributed Data Parallelism). 1 means
227
+ disabled.
228
+ --training.data_parallel_shard_degree TRAINING.DATA_PARALLEL_SHARD_DEGREE
229
+ The `data_parallel_shard_degree` argument specifies
230
+ the degree of data parallelism for weight sharding.
231
+ When this value is greater than 1, weights will be
232
+ sharded across `data_parallel_shard_degree` ranks. If
233
+ `data_parallel_replicate_degree` is also greater than
234
+ 1, the parallelism method used is HSDP (Hybrid Sharded
235
+ Data Parallelism). Otherwise, the parallelism method
236
+ used is FSDP (Fully Sharded Data Parallelism). -1
237
+ means leftover ranks will be used (After
238
+ DP_REPLICATE/SP/PP). Note that only
239
+ `data_parallel_shard_degree` can be negative. 1 means
240
+ disabled.
241
+ --training.enable_cpu_offload
242
+ Whether to apply CPU offloading of parameters,
243
+ gradients, and optimizer states in FSDP
244
+ --training.tensor_parallel_degree TRAINING.TENSOR_PARALLEL_DEGREE
245
+ Tensor Parallelism degree. 1 means disabled.
246
+ --training.disable_loss_parallel
247
+ Whether to apply loss parallel when sequence parallel
248
+ is enabled
249
+ --training.mixed_precision_param {bfloat16,float32}
250
+ torch dtype to use for parameters when applying mixed
251
+ precision via FSDP. This feature only takes effect
252
+ when data_parallel_shard_degree > 1
253
+ --training.mixed_precision_reduce {float32}
254
+ torch dtype to use for reductions when applying mixed
255
+ precision via FSDP. This feature only takes effect
256
+ when data_parallel_shard_degree > 1
257
+ --training.compile Whether to compile the model
258
+ --training.gc_freq TRAINING.GC_FREQ
259
+ Python garbage control scheduling interval, in steps
260
+ --training.seed TRAINING.SEED
261
+ Choose the base RNG seed used for training
262
+ --training.deterministic
263
+ Use deterministic algorithms wherever possible, may be
264
+ slower
265
+ --metrics.log_freq METRICS.LOG_FREQ
266
+ How often to log metrics to TensorBoard, in iterations
267
+ --metrics.enable_tensorboard
268
+ Whether to log metrics to TensorBoard
269
+ --metrics.disable_color_printing
270
+ Whether to disable color printing in logs
271
+ --metrics.save_tb_folder METRICS.SAVE_TB_FOLDER
272
+ Folder to dump TensorBoard states
273
+ --metrics.rank_0_only
274
+ Whether to save TensorBoard metrics only for rank 0 or
275
+ for all ranks. When pipeline_parallel_degree is > 1,
276
+ this option uses the 0th rank of the last stage
277
+ pipeline group, which is the only stage that computes
278
+ loss metrics.
279
+ --metrics.enable_wandb
280
+ Whether to log metrics to Weights & Biases
281
+ --experimental.enable_async_tensor_parallel
282
+ Whether to apply async tensor parallel (currently only
283
+ effective when compile is enabled)
284
+ --experimental.pipeline_parallel_degree EXPERIMENTAL.PIPELINE_PARALLEL_DEGREE
285
+ Pipeline Parallelism degree, or number of ranks. 1
286
+ means disabled. If using looped schedules, this still
287
+ specifies the number of physical ranks, not the number
288
+ of stages. Stages per rank are inferred from split
289
+ points degree, and schedule.
290
+ --experimental.pipeline_parallel_split_points EXPERIMENTAL.PIPELINE_PARALLEL_SPLIT_POINTS [EXPERIMENTAL.PIPELINE_PARALLEL_SPLIT_POINTS ...]
291
+ Specify comma-separated names of modules to use as the
292
+ beginning of a split point. e.g. "layers.0,layers.2"
293
+ will cause the model to be split into 3 stages, the
294
+ first containing all the layers up to layers.0, the
295
+ second containing layers.0 and up to layers.2, the
296
+ third containing layers.2 and all the remaining
297
+ layers. Note: fully-automated splitting may be enabled
298
+ in the future, but currently the split points must be
299
+ specified manually.
300
+ --experimental.pipeline_parallel_schedule EXPERIMENTAL.PIPELINE_PARALLEL_SCHEDULE
301
+ Specify the Pipeline Parallel schedule to use. The
302
+ supported schedules are: https://github.com/pytorch/py
303
+ torch/blob/de4c2a3b4e89d96334dc678d1c3f2ae51a6630a0/to
304
+ rch/distributed/pipelining/schedules.py#L2161. The
305
+ schedule must be compatible with the split points and
306
+ stages_per_rank. Looped schedules (e.g.
307
+ Interleaved1F1B) require specifying
308
+ pipeline_parallel_degree = number of ranks, and
309
+ split_points = number of stages - 1
310
+ --experimental.pipeline_parallel_schedule_csv EXPERIMENTAL.PIPELINE_PARALLEL_SCHEDULE_CSV
311
+ Specify the path to the pipeline parallel schedule csv
312
+ file to use. The pipeline_parallel_schedule argument
313
+ must be either PipelineScheduleSingle,
314
+ PipelineScheduleMulti, or _PipelineScheduleRuntime.
315
+ --experimental.pipeline_parallel_microbatches EXPERIMENTAL.PIPELINE_PARALLEL_MICROBATCHES
316
+ How many microbatches to split the global training
317
+ batch into when using pipeline parallelism. The global
318
+ training batch size must be evenly divisible by the
319
+ number of microbatches. The default value will be the
320
+ number of pipeline stages, if unspecified.
321
+ --experimental.enable_compiled_autograd
322
+ Enable CompiledAutograd to compile the backward.
323
+ --experimental.context_parallel_degree EXPERIMENTAL.CONTEXT_PARALLEL_DEGREE
324
+ Context parallelism degree. 1 means disabled.
325
+ --experimental.context_parallel_rotate_method EXPERIMENTAL.CONTEXT_PARALLEL_ROTATE_METHOD
326
+ The collective to use in context parallel SDPA for kv
327
+ shards exchange. 'allgather' means to all-gather all
328
+ kv shards on ranks after the first sub-SDPA
329
+ computation, 'alltoall' means to all-to-all shuffle
330
+ the kv shards. The default value is 'allgather'.
331
+ --checkpoint.enable_checkpoint
332
+ Whether to enable checkpoint
333
+ --checkpoint.folder CHECKPOINT.FOLDER
334
+ The folder to store the checkpoints. When
335
+ enable_checkpoint is set to true, checkpoints will be
336
+ in {--job.dump_folder}/{--checkpoint.folder}.
337
+ --checkpoint.interval_type CHECKPOINT.INTERVAL_TYPE
338
+ Checkpointing interval unit of measurement ['step',
339
+ 'seconds']
340
+ --checkpoint.interval CHECKPOINT.INTERVAL
341
+ Checkpointing interval, in steps or seconds depending
342
+ on --checkpoint.interval_type
343
+ --checkpoint.model_weights_only
344
+ When model_weights_only=True, only model weights will
345
+ be saved at the end of training. With this,
346
+ checkpoints can be loaded using `torch.load(...,
347
+ weights_only=True)` after conversion. When
348
+ model_weights_only=False, the full checkpoint will be
349
+ saved. A full checkpoint includes model, optimizer and
350
+ train_state, which can be used to resume training. The
351
+ default value is false.
352
+ --checkpoint.export_dtype {float16,bfloat16,float32}
353
+ Converts to the specified precision when training
354
+ completes and model_weights_only=true. Currently
355
+ supports float32, float16, and bfloat16. The default
356
+ value is float32.
357
+ --checkpoint.create_seed_checkpoint
358
+ Initializes the full model without applying
359
+ parallelisms, and then saves it as a seed checkpoint.
360
+ Note: requires user to call train.py without
361
+ specifying any parallelisms, e.g. NGPU=1. Could be
362
+ implemented as a separate script, but this way shares
363
+ more code.
364
+ --checkpoint.async_mode CHECKPOINT.ASYNC_MODE
365
+ Which async checkpoint mode to use. Currently there
366
+ are 3 different modes. 1. "disabled": synchronized
367
+ checkpointing will be used. 2. "async":
368
+ torch.distributed.checkpoint.async_save will be used.
369
+ 3. "async_with_pinned_mem": this option utilizes a
370
+ dedicated pinned memory space and creates a separate
371
+ process for faster GPU->CPU transfer performance and
372
+ eliminating GIL contention. The cost is increased CPU
373
+ memory usage. If insufficient CPU memory is available,
374
+ performance may degrade due to memory paging. For most
375
+ users, "async" should suffice as the performance
376
+ overhead is typically small (on the order of tens of
377
+ seconds) compared to checkpointing frequency. This
378
+ mode can be employed to pursue near-zero checkpointing
379
+ times (e.g., < 1 second) given appropriate hardware
380
+ support such as ample CPU memory and fast PCIe.
381
+ "disabled" is the default mode.
382
+ --checkpoint.keep_latest_k CHECKPOINT.KEEP_LATEST_K
383
+ Keeps only the latest k checkpoints, and purging older
384
+ ones. If 0, keep all checkpoints. 0 is the default
385
+ value.
386
+ --checkpoint.load_step CHECKPOINT.LOAD_STEP
387
+ Load the checkpoint at the specified step. If -1, load
388
+ the latest checkpoint.
389
+ --float8.enable_float8_linear
390
+ If true, swaps `torch.nn.Linear` with `Float8Linear`.
391
+ This feature requires you to install 'torchao' which
392
+ can be found here: https://github.com/pytorch/ao
393
+ --float8.enable_fsdp_float8_all_gather
394
+ Whether enable float8 all-gather in FSDP
395
+ --float8.precompute_float8_dynamic_scale_for_fsdp
396
+ Whether precompute float8 scales dynamically for FSDP
397
+ --float8.scaling_type_input {dynamic,delayed}
398
+ float8 scaling for input, dynamic (default) or delayed
399
+ --float8.scaling_type_weight FLOAT8.SCALING_TYPE_WEIGHT
400
+ float8 scaling for weight, dynamic (default) or delayed
400
+ --float8.scaling_type_grad_output FLOAT8.SCALING_TYPE_GRAD_OUTPUT
401
+ float8 scaling for grad_output, dynamic (default) or delayed
403
+ --comm.init_timeout_seconds COMM.INIT_TIMEOUT_SECONDS
404
+ Timeout for communication operations, during
405
+ initialization and first train step.
406
+ --comm.train_timeout_seconds COMM.TRAIN_TIMEOUT_SECONDS
407
+ Timeout for communication operations after the first
408
+ train step -- usually a tighter bound than during
409
+ initialization.
410
+ --comm.trace_buf_size COMM.TRACE_BUF_SIZE
411
+ Flight recorder ring buffer size, >0 means recording
412
+ by default, 0 means disabled
413
+ --memory_estimation.enabled
414
+ Whether to estimate memory usage for FSDP
415
+ --memory_estimation.disable_fake_mode
416
+ Whether to estimate memory under FakeTensorMode
417
+ ```
418
+ </details>
419
+
420
+ ### Training with variable-length inputs
421
+ When you set the `--training.varlen` flag, you're enabling a more efficient training method that packs multiple documents together into a single long sequence, eliminating the need for padding.
422
+ This is particularly useful when your dataset contains documents of varying lengths.
423
+ Let's break down how `--training.seq_len` and `--training.context_len` work in this mode.
424
+
425
+ * `--training.seq_len` (Packed Sequence Length): This is the total length of the final sequence fed to the model on one device. Instead of processing one document at a time, the dataloader takes multiple documents (each split into samples no longer than `context_len`), concatenates them end-to-end, and creates a single long sequence of length `seq_len`.
426
+ * `--training.context_len` (Sample Length): This parameter defines the maximum number of tokens for a single sample. A document longer than `context_len` is split rather than discarded. For example, if `--training.context_len` is set to 4,096, a document with 5,000 tokens is cut down to its first 4,096 tokens and the remaining 904 tokens become another independent sample, while a document with 3,000 tokens remains unchanged.
427
+
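The packing described above can be sketched in a few lines. This is a simplified illustration rather than `flame`'s dataloader: the function name `pack_documents`, the cumulative-offset list `cu_seqlens` used to mark sample boundaries, and the choice to drop a trailing partial sequence are assumptions.

```python
from typing import Iterable, Iterator


def pack_documents(docs: Iterable[list[int]], seq_len: int,
                   context_len: int) -> Iterator[tuple[list[int], list[int]]]:
    """Yield (tokens, cu_seqlens) pairs, each holding exactly seq_len tokens."""
    tokens: list[int] = []
    cu_seqlens: list[int] = [0]
    for doc in docs:
        # Split each document into samples of at most context_len tokens.
        for start in range(0, len(doc), context_len):
            sample = doc[start:start + context_len]
            while sample:
                # Fill the current packed sequence; spill any overflow
                # into the next one as its own sample.
                room = seq_len - len(tokens)
                tokens.extend(sample[:room])
                cu_seqlens.append(len(tokens))
                sample = sample[room:]
                if len(tokens) == seq_len:
                    yield tokens, cu_seqlens
                    tokens, cu_seqlens = [], [0]
    # A trailing partial sequence is dropped in this sketch.


# Three toy "documents" of 5, 3, and 4 token ids.
docs = [list(range(5)), [10, 11, 12], [20, 21, 22, 23]]
packed = list(pack_documents(docs, seq_len=8, context_len=4))
print(packed)
```

Here the 5-token document is split into samples of 4 and 1 tokens, which are packed together with the 3-token document into one sequence of 8 tokens. No padding is needed, and the boundaries let attention be restricted to each sample.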
428
+ ### Training with `torch.compile`
429
+
430
+ `torch.compile` was introduced in PyTorch 2.0 to accelerate training with minimal code changes.
431
+ In `flame`, you can enable `torch.compile` by adding the `--training.compile` flag to your training command.
432
+
433
+ However, `fla` integrates numerous fused kernels for acceleration, which may conflict with `torch.compile`.
434
+ We are actively working on resolving these issues to make compilation transparent to users.
435
+ In the meantime, please ensure you are using the latest dependencies.
436
+
437
+ Specifically, **we recommend using `torch>=2.6` and `triton>=3.0`**.
438
+
439
+ ### Training with multiple datasets
440
+
441
+ If you wish to train a model with all-round capabilities (e.g., code, math, and multilingual ability), it's necessary to train on multiple datasets.
442
+ `flame` allows training with multiple datasets easily.
443
+ For example, you can specify the following arguments to train on 6 datasets with different proportions:
444
+
445
+ ```sh
446
+ --training.dataset HuggingFaceFW/fineweb-edu,opencsg/Fineweb-Edu-Chinese-V2.1,OpenCoder-LLM/opc-fineweb-code-corpus,math-ai/AutoMathText,EleutherAI/proof-pile-2,OpenCoder-LLM/opc-fineweb-math-corpus \
447
+ --training.data_probs 0.6,0.15,0.15,0.014,0.058,0.028 \
448
+ ```
449
+
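One way to read `--training.data_probs` is as the probability that each next sample is drawn from the corresponding dataset. The sketch below simulates that interpretation with the standard library only. It is an assumption about the sampling semantics, not `flame`'s implementation, and it loads no data.

```python
import random
from collections import Counter

datasets = [
    "HuggingFaceFW/fineweb-edu",
    "opencsg/Fineweb-Edu-Chinese-V2.1",
    "OpenCoder-LLM/opc-fineweb-code-corpus",
    "math-ai/AutoMathText",
    "EleutherAI/proof-pile-2",
    "OpenCoder-LLM/opc-fineweb-math-corpus",
]
probs = [0.6, 0.15, 0.15, 0.014, 0.058, 0.028]

# One probability per dataset, and together they should sum to 1.
assert len(datasets) == len(probs)
assert abs(sum(probs) - 1.0) < 1e-9

# Simulate which dataset each of 100,000 samples would come from.
rng = random.Random(42)
draws = rng.choices(datasets, weights=probs, k=100_000)
counts = Counter(draws)
for name, p in zip(datasets, probs):
    print(f"{name}: target {p:.3f}, observed {counts[name] / len(draws):.3f}")
```

The observed fractions should land close to the targets, which makes this a cheap way to check a mixture before launching a long run.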
450
+ ### ~Finalizing training~
451
+
452
+ > [!NOTE]
453
+ > Since our latest updates, the training script performs this conversion automatically.
454
+
455
+ Once training is complete, you may want to convert the distributed checkpoints (DCPs) into the 🤗 format for broader use.
456
+ To facilitate this, we provide a straightforward conversion script:
457
+
458
+ ```sh
459
+ python -m flame.utils.convert_dcp_to_hf --path <path_to_model> --step <step> --config <path_to_config> --tokenizer <path_to_tokenizer>
460
+ ```
461
+ After this, your model will be in the 🤗 format, ready to be shared or deployed.
462
+ You can then easily publish your model using the `huggingface_hub` for wider accessibility.
463
+
464
+ ### Continual training
465
+
466
+ If you wish to build upon a strong pre-trained model (in 🤗 format) and continue training, we also offer a script to convert the 🤗 format model back into DCP format.
467
+ This allows you to seamlessly resume training with `flame`.
468
+ ```sh
469
+ python -m flame.utils.convert_hf_to_dcp --model <path_to_hf> --checkpoint <path_to_dcp/checkpoint/step-0>
470
+ ```
471
+ Here, `<path_to_dcp>` is the directory where your distributed checkpoints will be stored.
472
+ The checkpoint is intentionally saved at `<step-0>` within the checkpoint folder to ensure it is loadable by `flame` during the initial training step, similar to how a seed checkpoint is handled.
473
+
474
+ Once the conversion is complete, you can proceed with training using `flame` as usual, continuing from where the pretrained model left off.
475
+
476
+ ## Multi-node training
477
+
478
+ If you have access to multiple GPU nodes, consider leveraging them for optimal performance.
479
+ This process is straightforward and well-documented in the PyTorch [docs](https://pytorch.org/docs/stable/elastic/run.html).
480
+
481
+ To set up multi-node training:
482
+ * Set the environment variables `MASTER_ADDR=<ip>` and `MASTER_PORT=<port>` before running the training script across all nodes.
483
+ * If you're using a job scheduler like Slurm, it will handle these variables for you.
484
+
485
+ `torchtitan` provides a [Slurm script](https://github.com/pytorch/torchtitan/blob/main/multinode_trainer.slurm) for multi-node training, which you can use as a reference or starting point.
486
+
487
+ ## Custom models
488
+
489
+ `flame` supports custom model architectures through seamless integration with the Hugging Face `transformers` library. To add your own model:
490
+
491
+ 1. Create a new model directory under `custom_models/` (see `custom_models/sba` for a complete example)
492
+ 2. Implement your model classes and configuration:
493
+ - Define a config class inheriting from `PretrainedConfig` (see `custom_models/sba/config_sba.py` for an example)
494
+ - Create model classes inheriting from `PreTrainedModel` (see `custom_models/sba/modeling_sba.py` for an example)
495
+ 3. Register your models in `__init__.py`:
496
+ - Import your model classes and config classes
497
+ - Register your models with the `AutoModelForCausalLM`, `AutoModel` and `AutoConfig` classes (see `custom_models/sba/__init__.py` for an example)
498
+ 4. Create a config file for your custom model. You only need to set `model_type` to the name you registered for your custom model (example: `configs/sba_340m.json`).
499
+ 5. Train your custom model with the `flame.train.py` script as usual.
500
+
501
+
502
+
503
+
504
+
505
+
506
+
507
+ ## Citation
508
+
509
+ If you find `flame` helpful for your work, please consider citing it.
510
+
511
+ ```bib
512
+ @software{yang2025flame,
513
+ title = {Flame: Flash Language Modeling Made Easy},
514
+ author = {Zhang, Yu and Yang, Songlin},
515
+ url = {https://github.com/fla-org/flame},
516
+ month = jan,
517
+ year = {2025}
518
+ }
519
+ ```
added_tokens.json ADDED
@@ -0,0 +1,28 @@
1
+ {
2
+ "</think>": 151668,
3
+ "</tool_call>": 151658,
4
+ "</tool_response>": 151666,
5
+ "<think>": 151667,
6
+ "<tool_call>": 151657,
7
+ "<tool_response>": 151665,
8
+ "<|box_end|>": 151649,
9
+ "<|box_start|>": 151648,
10
+ "<|endoftext|>": 151643,
11
+ "<|file_sep|>": 151664,
12
+ "<|fim_middle|>": 151660,
13
+ "<|fim_pad|>": 151662,
14
+ "<|fim_prefix|>": 151659,
15
+ "<|fim_suffix|>": 151661,
16
+ "<|im_end|>": 151645,
17
+ "<|im_start|>": 151644,
18
+ "<|image_pad|>": 151655,
19
+ "<|object_ref_end|>": 151647,
20
+ "<|object_ref_start|>": 151646,
21
+ "<|quad_end|>": 151651,
22
+ "<|quad_start|>": 151650,
23
+ "<|repo_name|>": 151663,
24
+ "<|video_pad|>": 151656,
25
+ "<|vision_end|>": 151653,
26
+ "<|vision_pad|>": 151654,
27
+ "<|vision_start|>": 151652
28
+ }
chat_template.jinja ADDED
@@ -0,0 +1,89 @@
1
+ {%- if tools %}
2
+ {{- '<|im_start|>system\n' }}
3
+ {%- if messages[0].role == 'system' %}
4
+ {{- messages[0].content + '\n\n' }}
5
+ {%- endif %}
6
+ {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
7
+ {%- for tool in tools %}
8
+ {{- "\n" }}
9
+ {{- tool | tojson }}
10
+ {%- endfor %}
11
+ {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
12
+ {%- else %}
13
+ {%- if messages[0].role == 'system' %}
14
+ {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
15
+ {%- endif %}
16
+ {%- endif %}
17
+ {%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
18
+ {%- for message in messages[::-1] %}
19
+ {%- set index = (messages|length - 1) - loop.index0 %}
20
+ {%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
21
+ {%- set ns.multi_step_tool = false %}
22
+ {%- set ns.last_query_index = index %}
23
+ {%- endif %}
24
+ {%- endfor %}
25
+ {%- for message in messages %}
26
+ {%- if message.content is string %}
27
+ {%- set content = message.content %}
28
+ {%- else %}
29
+ {%- set content = '' %}
30
+ {%- endif %}
31
+ {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
32
+ {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
33
+ {%- elif message.role == "assistant" %}
34
+ {%- set reasoning_content = '' %}
35
+ {%- if message.reasoning_content is string %}
36
+ {%- set reasoning_content = message.reasoning_content %}
37
+ {%- else %}
38
+ {%- if '</think>' in content %}
39
+ {%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
40
+ {%- set content = content.split('</think>')[-1].lstrip('\n') %}
41
+ {%- endif %}
42
+ {%- endif %}
43
+ {%- if loop.index0 > ns.last_query_index %}
44
+ {%- if loop.last or (not loop.last and reasoning_content) %}
45
+ {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
46
+ {%- else %}
47
+ {{- '<|im_start|>' + message.role + '\n' + content }}
48
+ {%- endif %}
49
+ {%- else %}
50
+ {{- '<|im_start|>' + message.role + '\n' + content }}
51
+ {%- endif %}
52
+ {%- if message.tool_calls %}
53
+ {%- for tool_call in message.tool_calls %}
54
+ {%- if (loop.first and content) or (not loop.first) %}
55
+ {{- '\n' }}
56
+ {%- endif %}
57
+ {%- if tool_call.function %}
58
+ {%- set tool_call = tool_call.function %}
59
+ {%- endif %}
60
+ {{- '<tool_call>\n{"name": "' }}
61
+ {{- tool_call.name }}
62
+ {{- '", "arguments": ' }}
63
+ {%- if tool_call.arguments is string %}
64
+ {{- tool_call.arguments }}
65
+ {%- else %}
66
+ {{- tool_call.arguments | tojson }}
67
+ {%- endif %}
68
+ {{- '}\n</tool_call>' }}
69
+ {%- endfor %}
70
+ {%- endif %}
71
+ {{- '<|im_end|>\n' }}
72
+ {%- elif message.role == "tool" %}
73
+ {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
74
+ {{- '<|im_start|>user' }}
75
+ {%- endif %}
76
+ {{- '\n<tool_response>\n' }}
77
+ {{- content }}
78
+ {{- '\n</tool_response>' }}
79
+ {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
80
+ {{- '<|im_end|>\n' }}
81
+ {%- endif %}
82
+ {%- endif %}
83
+ {%- endfor %}
84
+ {%- if add_generation_prompt %}
85
+ {{- '<|im_start|>assistant\n' }}
86
+ {%- if enable_thinking is defined and enable_thinking is false %}
87
+ {{- '<think>\n\n</think>\n\n' }}
88
+ {%- endif %}
89
+ {%- endif %}
config.json ADDED
@@ -0,0 +1,44 @@
1
+ {
2
+ "architectures": [
3
+ "GSAForCausalLM"
4
+ ],
5
+ "attn": null,
6
+ "bos_token_id": 151643,
7
+ "clamp_max": null,
8
+ "clamp_min": null,
9
+ "conv_size": 4,
10
+ "dtype": "bfloat16",
11
+ "elementwise_affine": false,
12
+ "eos_token_id": 151645,
13
+ "expand_k": 1,
14
+ "expand_v": 1,
15
+ "feature_map": "swish",
16
+ "fuse_cross_entropy": true,
17
+ "fuse_linear_cross_entropy": false,
18
+ "fuse_norm": true,
19
+ "fuse_swiglu": true,
20
+ "gate_logit_normalizer": 8,
21
+ "hidden_act": "swish",
22
+ "hidden_ratio": 4,
23
+ "hidden_size": 5120,
24
+ "initializer_range": 0.02,
25
+ "intermediate_size": 17408,
26
+ "max_position_embeddings": 40960,
27
+ "model_type": "gsa",
28
+ "norm_eps": 1e-06,
29
+ "num_heads": 40,
30
+ "num_hidden_layers": 40,
31
+ "num_kv_heads": 8,
32
+ "num_slots": 256,
33
+ "rope_theta": 1000000,
34
+ "share_conv_kernel": true,
35
+ "tie_word_embeddings": true,
36
+ "transformers_version": "4.57.3",
37
+ "use_cache": true,
38
+ "use_l2warp": false,
39
+ "use_norm": true,
40
+ "use_output_gate": true,
41
+ "use_rope": false,
42
+ "use_short_conv": false,
43
+ "vocab_size": 151936
44
+ }
configs/delta_net_1B.json ADDED
@@ -0,0 +1,29 @@
1
+ {
2
+ "attn": null,
3
+ "attn_mode": "chunk",
4
+ "bos_token_id": 1,
5
+ "conv_size": 4,
6
+ "eos_token_id": 2,
7
+ "expand_k": 1,
8
+ "expand_v": 1,
9
+ "fuse_cross_entropy": true,
10
+ "fuse_norm": true,
11
+ "hidden_act": "swish",
12
+ "hidden_ratio": 4,
13
+ "hidden_size": 2048,
14
+ "initializer_range": 0.02,
15
+ "intermediate_size": null,
16
+ "model_type": "delta_net",
17
+ "norm_eps": 1e-06,
18
+ "num_heads": 16,
19
+ "num_hidden_layers": 24,
20
+ "pad_token_id": 2,
21
+ "qk_activation": "silu",
22
+ "qk_norm": "l2",
23
+ "tie_word_embeddings": false,
24
+ "use_beta": true,
25
+ "use_cache": true,
26
+ "use_gate": false,
27
+ "use_output_norm": true,
28
+ "use_short_conv": true
29
+ }
configs/delta_net_340M.json ADDED
@@ -0,0 +1,26 @@
1
+ {
2
+ "attn_mode": "chunk",
3
+ "bos_token_id": 1,
4
+ "conv_size": 4,
5
+ "eos_token_id": 2,
6
+ "expand_k": 1,
7
+ "expand_v": 1,
8
+ "fuse_cross_entropy": true,
9
+ "hidden_act": "swish",
10
+ "hidden_ratio": 4,
11
+ "hidden_size": 1024,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": null,
14
+ "model_type": "delta_net",
15
+ "norm_eps": 1e-06,
16
+ "num_heads": 8,
17
+ "num_hidden_layers": 24,
18
+ "qk_activation": "silu",
19
+ "qk_norm": "l2",
20
+ "tie_word_embeddings": false,
21
+ "use_beta": true,
22
+ "use_cache": true,
23
+ "use_gate": false,
24
+ "use_output_norm": true,
25
+ "use_short_conv": true
26
+ }
configs/gated_deltanet_1B.json ADDED
@@ -0,0 +1,22 @@
1
+ {
2
+ "attn_mode": "chunk",
3
+ "bos_token_id": 1,
4
+ "conv_size": 4,
5
+ "eos_token_id": 2,
6
+ "expand_v": 2,
7
+ "fuse_cross_entropy": true,
8
+ "head_dim": 256,
9
+ "hidden_act": "swish",
10
+ "hidden_ratio": 4,
11
+ "hidden_size": 2048,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": null,
14
+ "model_type": "gated_deltanet",
15
+ "norm_eps": 1e-06,
16
+ "num_heads": 6,
17
+ "num_hidden_layers": 21,
18
+ "tie_word_embeddings": false,
19
+ "use_cache": true,
20
+ "use_gate": true,
21
+ "use_short_conv": true
22
+ }
configs/gated_deltanet_340M.json ADDED
@@ -0,0 +1,22 @@
1
+ {
2
+ "attn_mode": "chunk",
3
+ "bos_token_id": 1,
4
+ "conv_size": 4,
5
+ "eos_token_id": 2,
6
+ "expand_v": 2,
7
+ "fuse_cross_entropy": true,
8
+ "head_dim": 256,
9
+ "hidden_act": "swish",
10
+ "hidden_ratio": 4,
11
+ "hidden_size": 1024,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": null,
14
+ "model_type": "gated_deltanet",
15
+ "norm_eps": 1e-06,
16
+ "num_heads": 6,
17
+ "num_hidden_layers": 21,
18
+ "tie_word_embeddings": false,
19
+ "use_cache": true,
20
+ "use_gate": true,
21
+ "use_short_conv": true
22
+ }
configs/gay_14B.json ADDED
@@ -0,0 +1,32 @@
1
+ {
2
+ "bos_token_id": 151643,
3
+ "conv_size": 4,
4
+ "eos_token_id": 151645,
5
+ "expand_k": 1,
6
+ "expand_v": 1,
7
+ "elementwise_affine": false,
8
+ "feature_map": "swish",
9
+ "fuse_cross_entropy": true,
10
+ "fuse_norm": true,
11
+ "gate_logit_normalizer": 8,
12
+ "hidden_act": "swish",
13
+ "hidden_size": 5120,
14
+ "initializer_range": 0.02,
15
+ "intermediate_size": 17408,
16
+ "max_position_embeddings": 40960,
17
+ "model_type": "gsa",
18
+ "num_heads": 40,
19
+ "num_hidden_layers": 40,
20
+ "num_kv_heads": 8,
21
+ "num_slots": 256,
22
+ "norm_eps": 1e-06,
23
+ "share_conv_kernel": true,
24
+ "tie_word_embeddings": true,
25
+ "torch_dtype": "bfloat16",
26
+ "use_cache": true,
27
+ "use_norm": true,
28
+ "use_output_gate": true,
29
+ "use_rope": false,
30
+ "use_short_conv": false,
31
+ "vocab_size": 151936
32
+ }
configs/gay_1B.json ADDED
@@ -0,0 +1,32 @@
1
+ {
2
+ "bos_token_id": 151643,
3
+ "conv_size": 4,
4
+ "eos_token_id": 151645,
5
+ "expand_k": 1,
6
+ "expand_v": 1,
7
+ "elementwise_affine": false,
8
+ "feature_map": "swish",
9
+ "fuse_cross_entropy": true,
10
+ "fuse_norm": true,
11
+ "gate_logit_normalizer": 8,
12
+ "hidden_act": "swish",
13
+ "hidden_size": 2048,
14
+ "initializer_range": 0.02,
15
+ "intermediate_size": 6144,
16
+ "max_position_embeddings": 40960,
17
+ "model_type": "gsa",
18
+ "num_heads": 16,
19
+ "num_hidden_layers": 28,
20
+ "num_kv_heads": 8,
21
+ "num_slots": 256,
22
+ "norm_eps": 1e-06,
23
+ "share_conv_kernel": true,
24
+ "tie_word_embeddings": true,
25
+ "torch_dtype": "bfloat16",
26
+ "use_cache": true,
27
+ "use_norm": true,
28
+ "use_output_gate": true,
29
+ "use_rope": false,
30
+ "use_short_conv": false,
31
+ "vocab_size": 151936
32
+ }
configs/gayted_deltanet_1B.json ADDED
@@ -0,0 +1,25 @@
1
+ {
2
+ "attn_mode": "chunk",
3
+ "bos_token_id": 151643,
4
+ "conv_size": 4,
5
+ "eos_token_id": 151645,
6
+ "expand_v": 1,
7
+ "fuse_cross_entropy": true,
8
+ "head_dim": 128,
9
+ "hidden_act": "swish",
10
+ "hidden_size": 2048,
11
+ "initializer_range": 0.02,
12
+ "intermediate_size": 6144,
13
+ "max_position_embeddings": 40960,
14
+ "model_type": "gated_deltanet",
15
+ "norm_eps": 1e-06,
16
+ "num_heads": 16,
17
+ "num_hidden_layers": 28,
18
+ "num_v_heads": 8,
19
+ "tie_word_embeddings": true,
20
+ "torch_dtype": "bfloat16",
21
+ "use_cache": true,
22
+ "use_gate": true,
23
+ "use_short_conv": true,
24
+ "vocab_size": 151936
25
+ }
configs/gla_340M.json ADDED
@@ -0,0 +1,24 @@
1
+ {
2
+ "attn_mode": "chunk",
3
+ "bos_token_id": 1,
4
+ "clamp_min": null,
5
+ "eos_token_id": 2,
6
+ "expand_k": 0.5,
7
+ "expand_v": 1,
8
+ "fuse_cross_entropy": true,
9
+ "fuse_norm": true,
10
+ "hidden_act": "swish",
11
+ "hidden_ratio": 4,
12
+ "hidden_size": 1024,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": null,
15
+ "model_type": "gla",
16
+ "num_heads": 4,
17
+ "num_hidden_layers": 24,
18
+ "norm_eps": 1e-06,
19
+ "tie_word_embeddings": false,
20
+ "use_cache": true,
21
+ "use_gk": true,
22
+ "use_gv": false,
23
+ "vocab_size": 32000
24
+ }
configs/gla_7B.json ADDED
@@ -0,0 +1,25 @@
1
+ {
2
+ "attn": null,
3
+ "attn_mode": "chunk",
4
+ "bos_token_id": 1,
5
+ "eos_token_id": 2,
6
+ "expand_k": 0.5,
7
+ "expand_v": 1,
8
+ "fuse_cross_entropy": true,
9
+ "fuse_norm": true,
10
+ "hidden_act": "swish",
11
+ "hidden_ratio": 4,
12
+ "hidden_size": 4096,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 11008,
15
+ "model_type": "gla",
16
+ "norm_eps": 1e-06,
17
+ "num_heads": 16,
18
+ "num_hidden_layers": 32,
19
+ "tie_word_embeddings": false,
20
+ "use_cache": true,
21
+ "use_gk": true,
22
+ "use_gv": false,
23
+ "use_output_gate": true,
24
+ "use_short_conv": false
25
+ }
configs/gsa_1B.json ADDED
@@ -0,0 +1,29 @@
1
+ {
2
+ "bos_token_id": 1,
3
+ "conv_size": 4,
4
+ "eos_token_id": 2,
5
+ "expand_k": 1,
6
+ "expand_v": 1,
7
+ "elementwise_affine": false,
8
+ "feature_map": "swish",
9
+ "fuse_cross_entropy": true,
10
+ "fuse_norm": true,
11
+ "gate_logit_normalizer": 8,
12
+ "hidden_act": "swish",
13
+ "hidden_ratio": 4,
14
+ "hidden_size": 2048,
15
+ "initializer_range": 0.02,
16
+ "intermediate_size": null,
17
+ "model_type": "gsa",
18
+ "num_heads": 4,
19
+ "num_hidden_layers": 24,
20
+ "num_slots": 64,
21
+ "norm_eps": 1e-06,
22
+ "share_conv_kernel": true,
23
+ "tie_word_embeddings": false,
24
+ "use_cache": true,
25
+ "use_norm": true,
26
+ "use_output_gate": true,
27
+ "use_rope": false,
28
+ "use_short_conv": false
29
+ }
configs/gsa_340M.json ADDED
@@ -0,0 +1,29 @@
1
+ {
2
+ "bos_token_id": 1,
3
+ "conv_size": 4,
4
+ "eos_token_id": 2,
5
+ "expand_k": 1,
6
+ "expand_v": 1,
7
+ "elementwise_affine": false,
8
+ "feature_map": "swish",
9
+ "fuse_cross_entropy": true,
10
+ "fuse_norm": true,
11
+ "gate_logit_normalizer": 4,
12
+ "hidden_act": "swish",
13
+ "hidden_ratio": 4,
14
+ "hidden_size": 1024,
15
+ "initializer_range": 0.02,
16
+ "intermediate_size": null,
17
+ "model_type": "gsa",
18
+ "num_heads": 4,
19
+ "num_hidden_layers": 24,
20
+ "num_slots": 64,
21
+ "norm_eps": 1e-06,
22
+ "share_conv_kernel": true,
23
+ "tie_word_embeddings": false,
24
+ "use_cache": true,
25
+ "use_norm": true,
26
+ "use_output_gate": true,
27
+ "use_rope": false,
28
+ "use_short_conv": false
29
+ }
configs/hgrn2_340M.json ADDED
@@ -0,0 +1,20 @@
1
+ {
2
+ "attn_mode": "chunk",
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "expand_ratio": 128,
6
+ "fuse_cross_entropy": true,
7
+ "fuse_norm": true,
8
+ "hidden_act": "swish",
9
+ "hidden_ratio": 4,
10
+ "hidden_size": 1024,
11
+ "initializer_range": 0.02,
12
+ "intermediate_size": null,
13
+ "model_type": "hgrn2",
14
+ "num_heads": 8,
15
+ "num_hidden_layers": 24,
16
+ "norm_eps": 1e-06,
17
+ "tie_word_embeddings": false,
18
+ "use_cache": true,
19
+ "vocab_size": 32000
20
+ }
configs/mamba2_1B.json ADDED
@@ -0,0 +1,32 @@
1
+ {
2
+ "bos_token_id": 1,
3
+ "chunk_size": 256,
4
+ "conv_kernel": 4,
5
+ "eos_token_id": 2,
6
+ "expand": 2,
7
+ "fuse_cross_entropy": true,
8
+ "fuse_norm": true,
9
+ "head_dim": 64,
10
+ "hidden_act": "silu",
11
+ "hidden_size": 2048,
12
+ "initializer_range": 0.02,
13
+ "norm_eps": 1e-05,
14
+ "model_type": "mamba2",
15
+ "n_groups": 1,
16
+ "num_hidden_layers": 48,
17
+ "pad_token_id": 0,
18
+ "rescale_prenorm_residual": true,
19
+ "residual_in_fp32": true,
20
+ "rms_norm": true,
21
+ "state_size": 128,
22
+ "tie_word_embeddings": false,
23
+ "time_step_floor": 0.0001,
24
+ "time_step_max": 0.1,
25
+ "time_step_min": 0.001,
26
+ "time_step_rank": 128,
27
+ "transformers_version": "4.50.1",
28
+ "use_bias": false,
29
+ "use_cache": true,
30
+ "use_conv_bias": true,
31
+ "vocab_size": 32000
32
+ }
configs/mamba2_340M.json ADDED
@@ -0,0 +1,32 @@
1
+ {
2
+ "bos_token_id": 1,
3
+ "chunk_size": 256,
4
+ "conv_kernel": 4,
5
+ "eos_token_id": 2,
6
+ "expand": 2,
7
+ "fuse_cross_entropy": true,
8
+ "fuse_norm": true,
9
+ "head_dim": 64,
10
+ "hidden_act": "silu",
11
+ "hidden_size": 1024,
12
+ "initializer_range": 0.02,
13
+ "norm_eps": 1e-05,
14
+ "model_type": "mamba2",
15
+ "n_groups": 1,
16
+ "num_hidden_layers": 48,
17
+ "pad_token_id": 0,
18
+ "rescale_prenorm_residual": true,
19
+ "residual_in_fp32": true,
20
+ "rms_norm": true,
21
+ "state_size": 128,
22
+ "tie_word_embeddings": false,
23
+ "time_step_floor": 0.0001,
24
+ "time_step_max": 0.1,
25
+ "time_step_min": 0.001,
26
+ "time_step_rank": 128,
27
+ "transformers_version": "4.50.1",
28
+ "use_bias": false,
29
+ "use_cache": true,
30
+ "use_conv_bias": true,
31
+ "vocab_size": 32000
32
+ }
configs/mamba_1B.json ADDED
@@ -0,0 +1,30 @@
1
+ {
2
+ "bos_token_id": 1,
3
+ "conv_kernel": 4,
4
+ "eos_token_id": 2,
5
+ "expand": 2,
6
+ "fuse_cross_entropy": true,
7
+ "fuse_norm": true,
8
+ "hidden_act": "silu",
9
+ "hidden_size": 2048,
10
+ "initializer_range": 0.02,
11
+ "model_type": "mamba",
12
+ "norm_eps": 1e-05,
13
+ "num_hidden_layers": 48,
14
+ "pad_token_id": 0,
15
+ "rescale_prenorm_residual": false,
16
+ "residual_in_fp32": false,
17
+ "state_size": 16,
18
+ "tie_word_embeddings": false,
19
+ "time_step_floor": 0.0001,
20
+ "time_step_init_scheme": "random",
21
+ "time_step_max": 0.1,
22
+ "time_step_min": 0.001,
23
+ "time_step_rank": 128,
24
+ "time_step_scale": 1.0,
25
+ "transformers_version": "4.50.1",
26
+ "use_bias": false,
27
+ "use_cache": true,
28
+ "use_conv_bias": true,
29
+ "vocab_size": 32000
30
+ }
configs/mamba_340M.json ADDED
@@ -0,0 +1,30 @@
1
+ {
2
+ "bos_token_id": 1,
3
+ "conv_kernel": 4,
4
+ "eos_token_id": 2,
5
+ "expand": 2,
6
+ "fuse_cross_entropy": true,
7
+ "fuse_norm": true,
8
+ "hidden_act": "silu",
9
+ "hidden_size": 1024,
10
+ "initializer_range": 0.02,
11
+ "model_type": "mamba",
12
+ "norm_eps": 1e-05,
13
+ "num_hidden_layers": 48,
14
+ "pad_token_id": 0,
15
+ "rescale_prenorm_residual": false,
16
+ "residual_in_fp32": false,
17
+ "state_size": 16,
18
+ "tie_word_embeddings": false,
19
+ "time_step_floor": 0.0001,
20
+ "time_step_init_scheme": "random",
21
+ "time_step_max": 0.1,
22
+ "time_step_min": 0.001,
23
+ "time_step_rank": 128,
24
+ "time_step_scale": 1.0,
25
+ "transformers_version": "4.50.1",
26
+ "use_bias": false,
27
+ "use_cache": true,
28
+ "use_conv_bias": true,
29
+ "vocab_size": 32000
30
+ }
configs/routmem_1.7B.json ADDED
@@ -0,0 +1,35 @@
1
+ {
2
+ "bos_token_id": 151643,
3
+ "conv_size": 4,
4
+ "eos_token_id": 151645,
5
+ "expand_k": 1,
6
+ "expand_v": 1,
7
+ "elementwise_affine": false,
8
+ "feature_map": "swish",
9
+ "fuse_cross_entropy": true,
10
+ "fuse_norm": true,
11
+ "gate_logit_normalizer": 8,
12
+ "hidden_act": "swish",
13
+ "hidden_size": 2048,
14
+ "initializer_range": 0.02,
15
+ "intermediate_size": 6144,
16
+ "max_position_embeddings": 40960,
17
+ "model_type": "routmem",
18
+ "num_heads": 16,
19
+ "num_hidden_layers": 28,
20
+ "num_kv_heads": 8,
21
+ "num_slots": 256,
22
+ "norm_eps": 1e-06,
23
+ "share_conv_kernel": true,
24
+ "tie_word_embeddings": true,
25
+ "torch_dtype": "bfloat16",
26
+ "use_cache": true,
27
+ "use_norm": true,
28
+ "use_output_gate": true,
29
+ "use_rope": false,
30
+ "use_short_conv": false,
31
+ "vocab_size": 151936,
32
+ "add_gumbel_noise": true,
33
+ "router_score": "sigmoid",
34
+ "router_type": "lin"
35
+ }
configs/routmem_14B.json ADDED
@@ -0,0 +1,35 @@
1
+ {
2
+ "bos_token_id": 151643,
3
+ "conv_size": 4,
4
+ "eos_token_id": 151645,
5
+ "expand_k": 1,
6
+ "expand_v": 1,
7
+ "elementwise_affine": false,
8
+ "feature_map": "swish",
9
+ "fuse_cross_entropy": true,
10
+ "fuse_norm": true,
11
+ "gate_logit_normalizer": 8,
12
+ "hidden_act": "swish",
13
+ "hidden_size": 5120,
14
+ "initializer_range": 0.02,
15
+ "intermediate_size": 17408,
16
+ "max_position_embeddings": 40960,
17
+ "model_type": "routmem",
18
+ "num_heads": 40,
19
+ "num_hidden_layers": 40,
20
+ "num_kv_heads": 8,
21
+ "num_slots": 256,
22
+ "norm_eps": 1e-06,
23
+ "share_conv_kernel": true,
24
+ "tie_word_embeddings": true,
25
+ "torch_dtype": "bfloat16",
26
+ "use_cache": true,
27
+ "use_norm": true,
28
+ "use_output_gate": true,
29
+ "use_rope": false,
30
+ "use_short_conv": false,
31
+ "vocab_size": 151936,
32
+ "add_gumbel_noise": true,
33
+ "router_score": "sigmoid",
34
+ "router_type": "lin"
35
+ }
configs/routmem_340M.json ADDED
@@ -0,0 +1,33 @@
1
+ {
2
+ "bos_token_id": 1,
3
+ "conv_size": 4,
4
+ "eos_token_id": 2,
5
+ "expand_k": 1,
6
+ "expand_v": 1,
7
+ "elementwise_affine": false,
8
+ "feature_map": "swish",
9
+ "fuse_cross_entropy": true,
10
+ "fuse_norm": true,
11
+ "gate_logit_normalizer": 4,
12
+ "hidden_act": "swish",
13
+ "hidden_ratio": 4,
14
+ "hidden_size": 1024,
15
+ "initializer_range": 0.02,
16
+ "intermediate_size": null,
17
+ "model_type": "routmem",
18
+ "num_heads": 4,
19
+ "num_hidden_layers": 24,
20
+ "num_slots": 256,
21
+ "norm_eps": 1e-06,
22
+ "share_conv_kernel": true,
23
+ "tie_word_embeddings": false,
24
+ "use_cache": true,
25
+ "use_norm": true,
26
+ "use_output_gate": true,
27
+ "use_rope": false,
28
+ "use_short_conv": false,
29
+ "bias_rmm": false,
30
+ "add_gumbel_noise": true,
31
+ "router_score": "sigmoid",
32
+ "router_type": "lin"
33
+ }
configs/samba_1B.json ADDED
@@ -0,0 +1,52 @@
1
+ {
2
+ "attn": {
3
+ "layers": [
4
+ 1,
5
+ 3,
6
+ 5,
7
+ 7,
8
+ 9,
9
+ 11,
10
+ 13,
11
+ 15,
12
+ 17
13
+ ],
14
+ "num_heads": 18,
15
+ "num_kv_heads": 18,
16
+ "qkv_bias": false,
17
+ "rope_theta": 10000.0,
18
+ "window_size": 2048
19
+ },
20
+ "bos_token_id": 1,
21
+ "conv_kernel": 4,
22
+ "eos_token_id": 2,
23
+ "expand": 2,
24
+ "fuse_cross_entropy": true,
25
+ "fuse_norm": true,
26
+ "fuse_swiglu": true,
27
+ "hidden_act": "swish",
28
+ "hidden_ratio": 4,
29
+ "hidden_size": 2304,
30
+ "initializer_range": 0.02,
31
+ "intermediate_size": 4608,
32
+ "max_position_embeddings": 2048,
33
+ "model_type": "samba",
34
+ "norm_eps": 1e-05,
35
+ "num_hidden_layers": 18,
36
+ "pad_token_id": 0,
37
+ "rescale_prenorm_residual": false,
38
+ "residual_in_fp32": false,
39
+ "state_size": 16,
40
+ "tie_word_embeddings": false,
41
+ "time_step_floor": 0.0001,
42
+ "time_step_init_scheme": "random",
43
+ "time_step_max": 0.1,
44
+ "time_step_min": 0.001,
45
+ "time_step_rank": 144,
46
+ "time_step_scale": 1.0,
47
+ "transformers_version": "4.50.1",
48
+ "use_bias": false,
49
+ "use_cache": true,
50
+ "use_conv_bias": true,
51
+ "vocab_size": 32000
52
+ }
configs/sba_340m.json ADDED
@@ -0,0 +1,18 @@
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "hidden_act": "swish",
8
+ "hidden_size": 1024,
9
+ "initializer_range": 0.006,
10
+ "max_position_embeddings": 8192,
11
+ "model_type": "sba",
12
+ "num_heads": 16,
13
+ "num_hidden_layers": 24,
14
+ "norm_eps": 1e-06,
15
+ "tie_word_embeddings": false,
16
+ "use_cache": true,
17
+ "vocab_size": 32000
18
+ }
configs/transformer_1B.json ADDED
@@ -0,0 +1,22 @@
1
+ {
2
+ "bos_token_id": 1,
3
+ "elementwise_affine": true,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "fuse_swiglu": true,
8
+ "hidden_act": "swish",
9
+ "hidden_ratio": 4,
10
+ "hidden_size": 2048,
11
+ "initializer_range": 0.02,
12
+ "intermediate_size": null,
13
+ "max_position_embeddings": 8192,
14
+ "model_type": "transformer",
15
+ "norm_eps": 1e-06,
16
+ "num_heads": 32,
17
+ "num_hidden_layers": 24,
18
+ "num_kv_heads": null,
19
+ "pad_token_id": 2,
20
+ "rope_theta": 10000.0,
21
+ "tie_word_embeddings": false
22
+ }
configs/transformer_340M.json ADDED
@@ -0,0 +1,18 @@
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "hidden_act": "swish",
8
+ "hidden_size": 1024,
9
+ "initializer_range": 0.02,
10
+ "max_position_embeddings": 8192,
11
+ "model_type": "transformer",
12
+ "num_heads": 16,
13
+ "num_hidden_layers": 24,
14
+ "norm_eps": 1e-06,
15
+ "tie_word_embeddings": false,
16
+ "use_cache": true,
17
+ "vocab_size": 32000
18
+ }
configs/transformer_7B.json ADDED
@@ -0,0 +1,21 @@
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "hidden_act": "swish",
8
+ "hidden_ratio": 4,
9
+ "hidden_size": 4096,
10
+ "initializer_range": 0.02,
11
+ "intermediate_size": 14336,
12
+ "model_type": "transformer",
13
+ "norm_eps": 1e-06,
14
+ "num_heads": 32,
15
+ "num_hidden_layers": 32,
16
+ "num_kv_heads": 8,
17
+ "rope_theta": 10000.0,
18
+ "tie_word_embeddings": false,
19
+ "use_cache": true,
20
+ "window_size": null
21
+ }
flame/__init__.py ADDED
@@ -0,0 +1 @@
1
+ __version__ = "0.1.0"
flame/components/__init__.py ADDED
File without changes
flame/components/checkpoint.py ADDED
@@ -0,0 +1,59 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from dataclasses import dataclass, field
8
+ from datetime import timedelta
9
+ from io import BytesIO
10
+ from typing import Any, Dict, List
11
+
12
+ import torch
13
+ from torch.distributed.checkpoint.stateful import Stateful
14
+
15
+
16
+ @dataclass
17
+ class TrainState(Stateful):
18
+ step: int = 0
19
+ skipped_step: int = 0
20
+ token: int = 0
21
+ elapsed: timedelta = timedelta(0)
22
+ global_avg_losses: List[float] = field(default_factory=list)
23
+ global_max_losses: List[float] = field(default_factory=list)
24
+ log_steps: List[int] = field(default_factory=list)
25
+
26
+ def state_dict(self) -> Dict[str, Any]:
27
+ # Only checkpoint global_avg_losses and global_max_losses per log frequency
28
+ # to avoid sync overhead in every iteration.
29
+ global_avg_losses_bytes = BytesIO()
30
+ torch.save(self.global_avg_losses, global_avg_losses_bytes)
31
+ global_max_losses_bytes = BytesIO()
32
+ torch.save(self.global_max_losses, global_max_losses_bytes)
33
+ log_steps_bytes = BytesIO()
34
+ torch.save(self.log_steps, log_steps_bytes)
35
+ return {
36
+ "step": torch.tensor(self.step, dtype=torch.int32),
37
+ "skipped_step": torch.tensor(self.skipped_step, dtype=torch.int32),
38
+ "token": torch.tensor(self.token, dtype=torch.int64),
39
+ "elapsed": self.elapsed,
40
+ "global_avg_losses": global_avg_losses_bytes,
41
+ "global_max_losses": global_max_losses_bytes,
42
+ "log_steps": log_steps_bytes,
43
+ }
44
+
45
+ def load_state_dict(self, state_dict) -> None:
46
+ self.step = state_dict["step"].item()
47
+ self.skipped_step = state_dict.get("skipped_step", torch.tensor(0)).item()
48
+ self.token = state_dict["token"].item()
49
+ self.elapsed = state_dict["elapsed"]
50
+ state_dict["global_avg_losses"].seek(0)
51
+ self.global_avg_losses = torch.load(
52
+ state_dict["global_avg_losses"], weights_only=False
53
+ )
54
+ state_dict["global_max_losses"].seek(0)
55
+ self.global_max_losses = torch.load(
56
+ state_dict["global_max_losses"], weights_only=False
57
+ )
58
+ state_dict["log_steps"].seek(0)
59
+ self.log_steps = torch.load(state_dict["log_steps"], weights_only=False)
flame/config_manager.py ADDED
@@ -0,0 +1,960 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import argparse
8
+ import sys
9
+ from collections import defaultdict
10
+ from typing import Tuple
11
+
12
+ import torch
13
+
14
+ try:
15
+ import tomllib
16
+ except ModuleNotFoundError:
17
+ import tomli as tomllib
18
+
19
+ from torchtitan.tools.logging import logger
20
+
21
+ TORCH_DTYPE_MAP = {
22
+ "float16": torch.float16,
23
+ "float32": torch.float32,
24
+ "bfloat16": torch.bfloat16,
25
+ }
26
+
27
+
28
+ def string_list(raw_arg):
29
+ """Comma-separated string list argument."""
30
+ return [s.strip() for s in raw_arg.split(",") if s.strip()]
31
+
32
+
33
+ def check_string_list_argument(args_dict: dict[str, any], fullargname: str):
34
+ section, name = fullargname.split(".")
35
+ # Split string-list arguments that are still raw strings.
36
+ if (
37
+ section in args_dict
38
+ and name in args_dict[section]
39
+ and isinstance(args_dict[section][name], str)
40
+ ):
41
+ sec = args_dict[section]
42
+ sec[name] = string_list(sec[name])
43
+
44
+
45
+ class JobConfig:
46
+ """
47
+ A helper class to manage the train configuration.
48
+ Semantics:
49
+ - Default config is loaded from a toml file. If no toml file is provided,
50
+ then the default config is loaded from argparse defaults.
51
+ - if toml file has missing keys, they are filled with argparse defaults.
52
+ - if additional explicit cmd args are provided in addition to the toml
53
+ file, they will override the toml config and the argparse defaults
54
+
55
+ precedence order: cmdline > toml > argparse default
56
+
57
+ Arg parsing semantics:
58
+
59
+ Each argument starts with <prefix>_ which is the section name in the toml file
60
+ followed by name of the option in the toml file. For ex,
61
+ model.name translates to:
62
+ [model]
63
+ name
64
+ in the toml file
65
+ """
66
+
67
+ def __init__(self):
68
+ self.args_dict = None
69
+ # main parser
70
+ self.parser = argparse.ArgumentParser(description="torchtitan arg parser.")
71
+
72
+ self.parser.add_argument(
73
+ "--job.config_file",
74
+ type=str,
75
+ default=None,
76
+ help="Job config file",
77
+ )
78
+
79
+ # job level configs
80
+ self.parser.add_argument(
81
+ "--job.dump_folder",
82
+ type=str,
83
+ default="./torchtitan/outputs",
84
+ help="Folder to dump job outputs",
85
+ )
86
+ self.parser.add_argument(
87
+ "--job.description",
88
+ type=str,
89
+ default="default job",
90
+ help="Description of the job",
91
+ )
92
+ self.parser.add_argument(
93
+ "--job.use_for_integration_test",
94
+ action="store_true",
95
+ help="Add this config to the integration test suite",
96
+ )
97
+ self.parser.add_argument(
98
+ "--job.print_args",
99
+ action="store_true",
100
+ help="Print the args to terminal",
101
+ )
102
+
103
+ # model configs
104
+ self.parser.add_argument(
105
+ "--model.name",
106
+ type=str,
107
+ default="fla",
108
+ help="Which model to train",
109
+ )
110
+ self.parser.add_argument(
111
+ "--model.config",
112
+ type=str,
113
+ default="fla-hub/transformer-1.3B-100B",
114
+ help="Path to the model config",
115
+ )
116
+ self.parser.add_argument(
117
+ "--model.tokenizer_path",
118
+ type=str,
119
+ default="fla-hub/transformer-1.3B-100B",
120
+ help="Tokenizer path",
121
+ )
122
+ self.parser.add_argument(
123
+ "--model.converters",
124
+ type=string_list,
125
+ nargs="+",
126
+ default=[],
127
+ help="""
128
+ Comma separated list of converters to apply to the model.
129
+ For instance, the `float8` converter swaps `torch.nn.Linear`
130
+ with `Float8Linear`. This feature requires you to install 'torchao'
131
+ which can be found here: https://github.com/pytorch/ao
132
+ """,
133
+ )
134
+ self.parser.add_argument(
135
+ "--model.print_after_conversion",
136
+ action="store_true",
137
+ help="""
138
+ If true, model definition will be printed to stdout after all model
139
+ converters have been applied.
140
+ """,
141
+ )
142
+
143
+ # profiling configs
144
+ self.parser.add_argument(
145
+ "--profiling.enable_profiling",
146
+ action="store_true",
147
+ help="Whether to enable pytorch profiler",
148
+ )
149
+ self.parser.add_argument(
150
+ "--profiling.save_traces_folder",
151
+ type=str,
152
+ default="profile_traces",
153
+ help="Trace files location",
154
+ )
155
+ self.parser.add_argument(
156
+ "--profiling.profile_freq",
157
+ type=int,
158
+ default=10,
159
+ help="How often to collect profiler traces, in iterations",
160
+ )
161
+ self.parser.add_argument(
162
+ "--profiling.enable_memory_snapshot",
163
+ action="store_true",
164
+ help="Whether to dump memory snapshot",
165
+ )
166
+ self.parser.add_argument(
167
+ "--profiling.save_memory_snapshot_folder",
168
+ type=str,
169
+ default="memory_snapshot",
170
+ help="Memory snapshot files location",
171
+ )
172
+
173
+ # optimizer configs
174
+ self.parser.add_argument(
175
+ "--optimizer.name", type=str, default="AdamW", help="Optimizer to use"
176
+ )
177
+ self.parser.add_argument(
178
+ "--optimizer.eps",
179
+ type=float,
180
+ default=1e-8,
181
+ help="Epsilon value for the optimizer.",
182
+ )
183
+ self.parser.add_argument(
184
+ "--optimizer.lr", type=float, default=8e-4, help="Learning rate to use"
185
+ )
186
+ self.parser.add_argument(
187
+ "--optimizer.beta1", type=float, default=0.9,
188
+ help="Exponential moving average hyperparameters to use"
189
+ )
190
+ self.parser.add_argument(
191
+ "--optimizer.beta2", type=float, default=0.95,
192
+ help="Exponential moving average hyperparameters to use"
193
+ )
194
+ self.parser.add_argument(
195
+ "--optimizer.weight_decay", type=float, default=0.1,
196
+ help="Weight decay to use"
197
+ )
198
+ self.parser.add_argument(
199
+ "--optimizer.implementation",
200
+ type=str,
201
+ default="fused",
202
+ choices=["for-loop", "foreach", "fused"],
203
+ help="""
204
+ Specify which optimizer implementation to use:
205
+ - 'fused': Use fused implementation (CUDA only) for best performance.
206
+ - 'foreach': Use some horizontal fusion of tensors for better performance.
207
+ - 'for-loop': Use the default implementation for the optimizer (slowest).
208
+ - more info: https://pytorch.org/docs/stable/optim.html
209
+ """,
210
+ )
211
+ self.parser.add_argument(
212
+ "--optimizer.early_step_in_backward",
213
+ action="store_true",
214
+ help="""
215
+ Whether to apply optimizer in the backward. Caution, optimizer_in_backward
216
+ is not compatible with gradient clipping; users should not call
217
+ register_post_accumulate_grad_hook after the optimizer is built.""",
218
+ )
219
+
220
+ # lr scheduler configs
221
+ self.parser.add_argument(
222
+ "--lr_scheduler.warmup_steps",
223
+ type=int,
224
+ default=200,
225
+ help="Steps for lr scheduler warmup, normally 1/5 of --training.steps",
226
+ )
227
+ self.parser.add_argument(
228
+ "--lr_scheduler.decay_ratio",
229
+ type=float,
230
+ default=None,
231
+ help="""
232
+ Controls the proportion of the training steps allocated to the learning rate decay phase.
233
+
234
+ If `None`, the learning rate will begin decaying immediately after the warmup period.
235
+ Otherwise, the learning rate will remain stable after the warmup period and
236
+ only start decaying during the last `decay_ratio` portion of the total training steps.
237
+
238
+ This is known as the Warmup-Stable-Decay (WSD) schedule, as described in https://arxiv.org/abs/2404.06395.
239
+ """,
240
+ )
241
+ self.parser.add_argument(
242
+ "--lr_scheduler.decay_type",
243
+ type=str,
244
+ default="linear",
245
+ choices=["linear", "sqrt", "cosine"],
246
+ help="""
247
+ Learning rate decay type to use during training:
248
+ - 'linear': linearly decays learning rate from initial to final value
249
+ - 'sqrt': decays learning rate following a 1 minus square root curve
250
+ - 'cosine': smoothly decays learning rate following a cosine curve
251
+ """,
252
+ )
253
+ self.parser.add_argument(
254
+ "--lr_scheduler.lr_min",
255
+ type=float,
256
+ default=0.0,
257
+ help="""
258
+ Min lr ratio for lr scheduler.
259
+
260
+ If provided, the range of decay factor is scaled from 1 to `lr_min`
261
+ to ensure the learning rate does not drop below `optimizer.lr * lr_scheduler.lr_min`.
262
+ """,
263
+ )
264
+
265
+ # training configs
266
+ self.parser.add_argument(
267
+ "--training.batch_size", type=int, default=8, help="Batch size"
268
+ )
269
+ self.parser.add_argument(
270
+ "--training.seq_len", type=int, default=2048, help="Sequence length"
271
+ )
272
+ self.parser.add_argument(
273
+ "--training.context_len",
274
+ type=int,
275
+ default=2048,
276
+ help="Max length allowed for each sequence",
277
+ )
278
+ self.parser.add_argument(
279
+ "--training.varlen",
280
+ action="store_true",
281
+ help="Whether to take sequences of variable length as input",
282
+ )
283
+ self.parser.add_argument(
284
+ "--training.gradient_accumulation_steps",
285
+ type=int,
286
+ default=1,
287
+ help="Number of steps to accumulate gradients before updating parameters",
288
+ )
289
+ self.parser.add_argument(
290
+ "--training.steps",
291
+ type=int,
292
+ default=10000,
293
+ help="How many train steps to run",
294
+ )
295
+ self.parser.add_argument(
296
+ "--training.max_norm",
297
+ type=float,
298
+ default=1.0,
299
+ help="Max norm for gradient clipping",
300
+ )
301
+ self.parser.add_argument(
302
+ "--training.skip_nan_inf",
303
+ action="store_true",
304
+ help="Skip batch updates when NaN or INF gradients are encountered during training",
305
+ )
306
+ self.parser.add_argument(
307
+ "--training.dataset",
308
+ default="HuggingFaceFW/fineweb-edu",
309
+ help="Dataset to use, with comma separated values",
310
+ )
311
+ self.parser.add_argument(
312
+ "--training.dataset_name",
313
+ default=None,
314
+ help="The name of the dataset config, with comma separated values if provided",
315
+ )
316
+ self.parser.add_argument(
317
+ "--training.dataset_split",
318
+ default=None,
319
+ help="Dataset split to use, with comma separated values if provided",
320
+ )
321
+ self.parser.add_argument(
322
+ "--training.data_dir",
323
+ default=None,
324
+ help="Data dirs to use, with comma separated values if provided",
325
+ )
326
+ self.parser.add_argument(
327
+ "--training.data_files",
328
+ default=None,
329
+ help="Data files to use, with comma separated values if provided",
330
+ )
331
+ self.parser.add_argument(
332
+ "--training.data_probs",
333
+ default=None,
334
+ help="Data sampling probabilities, with comma separated values if provided",
335
+ )
336
+ self.parser.add_argument(
337
+ "--training.streaming",
338
+ action="store_true",
339
+ help="Whether to load the dataset in streaming mode, useful for huge datasets",
340
+ )
341
+ self.parser.add_argument(
342
+ "--training.num_workers",
343
+ type=int,
344
+ default=32,
345
+ help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
346
+ )
347
+ self.parser.add_argument(
348
+ "--training.prefetch_factor",
349
+ type=int,
350
+ default=2,
351
+ help="Number of batches loaded in advance by each worker. "
352
+ "2 means there will be a total of 2 * num_workers batches prefetched across all workers.",
353
+ )
354
+ self.parser.add_argument(
355
+ "--training.data_parallel_replicate_degree",
356
+ type=int,
357
+ default=1,
358
+ help="""
359
+ The `data_parallel_replicate_degree` argument specifies the degree of
360
+ data parallelism for weight replication. When this value is greater
361
+ than 1, weights will be replicated across `data_parallel_replicate_degree`
362
+ ranks. If `data_parallel_shard_degree` is also greater than 1, the parallelism
363
+ method used is HSDP (Hybrid Sharded Data Parallelism). Otherwise, the
364
+ parallelism method used is DDP (Distributed Data Parallelism).
365
+ 1 means disabled.""",
366
+ )
367
+ self.parser.add_argument(
368
+ "--training.data_parallel_shard_degree",
369
+ type=int,
370
+ default=-1,
371
+ help="""
372
+ The `data_parallel_shard_degree` argument specifies the degree of data
373
+ parallelism for weight sharding. When this value is greater than 1, weights
374
+ will be sharded across `data_parallel_shard_degree` ranks. If
375
+ `data_parallel_replicate_degree` is also greater than 1, the parallelism
376
+ method used is HSDP (Hybrid Sharded Data Parallelism). Otherwise, the
377
+ parallelism method used is FSDP (Fully Sharded Data Parallelism).
378
+
379
+ -1 means leftover ranks will be used (After DP_REPLICATE/SP/PP). Note that
380
+ only `data_parallel_shard_degree` can be negative. 1 means disabled.""",
381
+ )
382
+ self.parser.add_argument(
383
+ "--training.enable_cpu_offload",
384
+ action="store_true",
385
+ help="""
386
+ Whether to apply CPU offloading of parameters, gradients, and optimizer states in FSDP""",
387
+ )
388
+ self.parser.add_argument(
389
+ "--training.tensor_parallel_degree",
390
+ type=int,
391
+ default=1,
392
+ help="Tensor Parallelism degree. 1 means disabled.",
393
+ )
394
+ self.parser.add_argument(
395
+ "--training.disable_loss_parallel",
396
+ action="store_true",
397
+ help="Whether to disable loss parallel when sequence parallel is enabled",
398
+ )
399
+ self.parser.add_argument(
400
+ "--training.fsdp_reshard_after_forward",
401
+ type=str,
402
+ default="default",
403
+ choices=["default", "always", "never"],
404
+ help="""
405
+ `reshard_after_forward` specifies the policy for applying `reshard_after_forward`
406
+ within an FSDP setup. `reshard_after_forward` controls parameter behavior after forward,
407
+ trading off memory and communication. See torch's `fully_shard` API for more documentation
408
+ on `reshard_after_forward`.
409
+ The supported policies include "default", "always" and "never":
410
+ - "default" applies default resharding behavior, implementing "smart defaults" for known optimal
411
+ scenarios.
412
+ - "always" will enable `reshard_after_forward` for all forward passes.
413
+ - "never" will disable `reshard_after_forward` for all forward passes.
414
+ """,
415
+ )
416
+ self.parser.add_argument(
417
+ "--training.mixed_precision_param",
418
+ type=str,
419
+ default="bfloat16",
420
+ choices=["bfloat16", "float32"],
421
+ help="""
422
+ torch dtype to use for parameters when applying mixed precision via fully_shard or torch.autocast.
423
+ This feature takes effect via fully_shard when data_parallel_shard_degree > 1 or
424
+ context_parallel_degree > 1; it takes effect via torch.autocast when data_parallel_replicate_degree >= 1
425
+ and no other parallelism is enabled, i.e. under DDP or single-device training.
426
+ """,
427
+ )
428
+ self.parser.add_argument(
429
+ "--training.mixed_precision_reduce",
430
+ type=str,
431
+ default="float32",
432
+ choices=["float32"],
433
+ help="""
434
+ torch dtype to use for reductions when applying mixed precision via FSDP.
435
+ This feature only takes effect when data_parallel_shard_degree > 1
436
+ """,
437
+ )
438
+ self.parser.add_argument(
439
+ "--training.compile",
440
+ action="store_true",
441
+ help="Whether to compile the model",
442
+ )
443
+ self.parser.add_argument(
444
+ "--training.gc_freq",
445
+ type=int,
446
+ default=50,
447
+ help="Python garbage collection scheduling interval, in steps",
448
+ )
449
+ self.parser.add_argument(
450
+ "--training.seed",
451
+ type=int,
452
+ default=42,
453
+ help="Choose the base RNG seed used for training",
454
+ )
455
+ self.parser.add_argument(
456
+ "--training.deterministic",
457
+ action="store_true",
458
+ help="Use deterministic algorithms wherever possible, may be slower",
459
+ )
460
+ # metrics configs
461
+ self.parser.add_argument(
462
+ "--metrics.log_freq",
463
+ type=int,
464
+ default=10,
465
+ help="How often to log metrics to TensorBoard, in iterations",
466
+ )
467
+ self.parser.add_argument(
468
+ "--metrics.enable_tensorboard",
469
+ action="store_true",
470
+ help="Whether to log metrics to TensorBoard",
471
+ )
472
+ self.parser.add_argument(
473
+ "--metrics.disable_color_printing",
474
+ action="store_true",
475
+ help="Whether to disable color printing in logs",
476
+ )
477
+ self.parser.add_argument(
478
+ "--metrics.save_tb_folder",
479
+ type=str,
480
+ default="tb",
481
+ help="Folder to dump TensorBoard states",
482
+ )
483
+ self.parser.add_argument(
484
+ "--metrics.save_for_all_ranks",
485
+ action="store_true",
486
+ default=False,
487
+ help="""
488
+ Whether to save TensorBoard/Wandb metrics only for rank 0 or for all ranks.
489
+ When this option is False and pipeline_parallel_degree is > 1, the metrics
490
+ component uses the 0th rank of the last stage pipeline group, which is the
491
+ only stage that computes loss metrics.
492
+ """,
493
+ )
494
+ self.parser.add_argument(
495
+ "--metrics.enable_wandb",
496
+ action="store_true",
497
+ help="Whether to log metrics to Weights & Biases",
498
+ )
499
+
500
+ self.parser.add_argument(
501
+ "--experimental.enable_async_tensor_parallel",
502
+ action="store_true",
503
+ help="Whether to apply async tensor parallel (currently only effective when compile is enabled)",
504
+ )
505
+ self.parser.add_argument(
506
+ "--experimental.pipeline_parallel_degree",
507
+ type=int,
508
+ default=1,
509
+ help="""
510
+ Pipeline Parallelism degree, or number of ranks. 1 means disabled.
511
+ If using looped schedules, this still specifies the number of physical ranks, not the number
512
+ of stages. Stages per rank are inferred from the split points, the degree, and the schedule.""",
513
+ )
514
+ self.parser.add_argument(
515
+ "--experimental.pipeline_parallel_split_points",
516
+ type=string_list,
517
+ nargs="+",
518
+ default=[],
519
+ help="""
520
+ Specify comma-separated names of modules to use as the beginning of a split point.
521
+
522
+ e.g. "layers.0,layers.2" will cause the model to be split into 3 stages,
523
+ the first containing all the layers up to layers.0,
524
+ the second containing layers.0 and up to layers.2,
525
+ the third containing layers.2 and all the remaining layers.
526
+
527
+ Note: fully-automated splitting may be enabled in the future,
528
+ but currently the split points must be specified manually.""",
529
+ )
530
+ self.parser.add_argument(
531
+ "--experimental.pipeline_parallel_schedule",
532
+ type=str,
533
+ default="1F1B",
534
+ help="""
535
+ Specify the Pipeline Parallel schedule to use. The supported schedules are:
536
+ https://github.com/pytorch/pytorch/blob/de4c2a3b4e89d96334dc678d1c3f2ae51a6630a0/torch/distributed/pipelining/schedules.py#L2161.
537
+ The schedule must be compatible with the split points and stages_per_rank.
538
+
539
+ Looped schedules (e.g. Interleaved1F1B) require specifying pipeline_parallel_degree = number of ranks,
540
+ and split_points = number of stages - 1
541
+ """,
542
+ )
543
+ self.parser.add_argument(
544
+ "--experimental.pipeline_parallel_schedule_csv",
545
+ type=str,
546
+ default="",
547
+ help="""
548
+ Specify the path to the pipeline parallel schedule csv file to use.
549
+ The pipeline_parallel_schedule argument must be either
550
+ PipelineScheduleSingle, PipelineScheduleMulti, or _PipelineScheduleRuntime.
551
+ """,
552
+ )
553
+
554
+ self.parser.add_argument(
555
+ "--experimental.pipeline_parallel_microbatches",
556
+ type=int,
557
+ default=None,
558
+ help="""
559
+ How many microbatches to split the global training batch into when using pipeline parallelism.
560
+
561
+ The global training batch size must be evenly divisible by the number of microbatches.
562
+
563
+ The default value will be the number of pipeline stages, if unspecified.
564
+ """,
565
+ )
566
+ self.parser.add_argument(
567
+ "--experimental.enable_compiled_autograd",
568
+ action="store_true",
569
+ help="Enable CompiledAutograd to compile the backward.",
570
+ )
571
+ self.parser.add_argument(
572
+ "--experimental.context_parallel_degree",
573
+ type=int,
574
+ default=1,
575
+ help="Context parallelism degree. 1 means disabled.",
576
+ )
577
+ self.parser.add_argument(
578
+ "--experimental.context_parallel_rotate_method",
579
+ type=str,
580
+ default="allgather",
581
+ help="""
582
+ The collective to use in context parallel SDPA for kv shards exchange.
583
+
584
+ 'allgather' means to all-gather all kv shards on ranks after the first sub-SDPA computation,
585
+
586
+ 'alltoall' means to all-to-all shuffle the kv shards.
587
+
588
+ The default value is 'allgather'.
589
+ """,
590
+ )
591
+ # I'm not particularly fond of this. Users can choose to write their own wrapper
592
+ # module and import TorchTitan training loop and execute it, which looks cleaner.
593
+ # One reason to provide this option is to allow users to use the existing run script.
594
+ # While the script is pretty trivial now, we may add more logic when integrating
595
+ # with TorchFT.
596
+ # This option is subject to change and may be deleted in the future.
597
+ self.parser.add_argument(
598
+ "--experimental.custom_model_path",
599
+ type=str,
600
+ default="",
601
+ help="""
602
+ The --experimental.custom_model_path option allows specifying a custom path to a model module
603
+ that is not natively implemented within TorchTitan.
604
+ Acceptable values are the file system path to the module (e.g., my_models/model_x)
605
+ or the dotted import path (e.g., some_package.model_x).
606
+ """,
607
+ )
608
+ # checkpointing configs
609
+ self.parser.add_argument(
610
+ "--checkpoint.enable_checkpoint",
611
+ action="store_true",
612
+ help="Whether to enable checkpoint",
613
+ )
614
+ self.parser.add_argument(
615
+ "--checkpoint.folder",
616
+ type=str,
617
+ default="checkpoint",
618
+ help="""
619
+ The folder to store the checkpoints.
620
+ When enable_checkpoint is set to true, checkpoints will be in {--job.dump_folder}/{--checkpoint.folder}.
621
+ """,
622
+ )
623
+ self.parser.add_argument(
624
+ "--checkpoint.initial_load_path", type=str, default=None,
625
+ help="""
626
+ This option specifies the path to the initial checkpoint to load, which is
627
+ particularly useful for resuming training from a previous run with a
628
+ different output path or when loading a checkpoint from a pre-trained model.
629
+ If the checkpoint folder for the current run is not empty,
630
+ located at {--job.dump_folder}/{--checkpoint.folder}, this option will be ignored.
631
+ This feature allows users to load an initial checkpoint from a different folder and
632
+ continue training, saving new checkpoints to the specified folder without affecting
633
+ the existing ones.
634
+
635
+ Note that the path should contain the full path to the checkpoint folder,
636
+ including the step number, if any; for example,
637
+ "//pre_train/checkpoints/llama3/llama3_8b/step_10000".
638
+ """
639
+ )
640
+ self.parser.add_argument(
641
+ "--checkpoint.initial_load_model_weights_only",
642
+ dest='checkpoint.initial_load_model_weights_only', action="store_true", default=True,
643
+ help="""
644
+ This option specifies if only the model weights should be loaded during the initial
645
+ checkpoint load. The option is only used when `initial_load_path` is specified, and
646
+ only applies to a model_weights_only checkpoint. Loading a periodic checkpoint
647
+ may lead to unexpected behavior if this option is set to True.
648
+ If False, the checkpoint at `initial_load_path` is treated as a standard training
649
+ checkpoint, including optimizer and training states.
650
+ The default setting for this option is True. Note that you will have to use
651
+ `--checkpoint.no_initial_load_model_weights_only` to override the default setting.
652
+ """
653
+ )
654
+ self.parser.add_argument(
655
+ "--checkpoint.no_initial_load_model_weights_only",
656
+ dest='checkpoint.initial_load_model_weights_only', action="store_false",
657
+ )
658
+ self.parser.add_argument(
659
+ "--checkpoint.interval",
660
+ type=int,
661
+ default=500,
662
+ help="Checkpointing interval in steps.",
663
+ )
664
+ self.parser.add_argument(
665
+ "--checkpoint.last_save_model_weights_only",
666
+ action="store_true",
667
+ help="""
668
+ When last_save_model_weights_only=True, only model weights will be saved at the end of training,
669
+ the last save. With this, checkpoints can be loaded using `torch.load(..., weights_only=True)`
670
+ after conversion. When last_save_model_weights_only=False, the full checkpoint will be saved.
671
+ A full checkpoint includes model, optimizer and train_state, which can be used to resume training.
672
+ The default value is false.
673
+ """,
674
+ )
675
+ self.parser.add_argument(
676
+ "--checkpoint.export_dtype",
677
+ type=str,
678
+ default="float32",
679
+ choices=["float16", "bfloat16", "float32"],
680
+ help="""
681
+ Converts to the specified precision when training completes and model_weights_only=true.
682
+ Currently supports float32, float16, and bfloat16.
683
+ The default value is float32.
684
+ """,
685
+ )
686
+ self.parser.add_argument(
687
+ "--checkpoint.create_seed_checkpoint",
688
+ action="store_true",
689
+ help="""
690
+ Initializes the full model without applying parallelisms, and then saves it as a seed checkpoint.
691
+ Note: requires user to call train.py without specifying any parallelisms, e.g. NGPU=1.
692
+ Could be implemented as a separate script, but this way shares more code.
693
+ """,
694
+ )
695
+ self.parser.add_argument(
696
+ "--checkpoint.async_mode",
697
+ type=str,
698
+ default="disabled",
699
+ help="""
700
+ Which async checkpoint mode to use. Currently there are 3 different modes.
701
+ 1. "disabled": synchronized checkpointing will be used.
702
+ 2. "async": torch.distributed.checkpoint.async_save will be used.
703
+ 3. "async_with_pinned_mem": this option utilizes a dedicated pinned memory
704
+ space and creates a separate process for faster GPU->CPU transfer
705
+ performance and eliminating GIL contention. The cost is increased CPU
706
+ memory usage. If insufficient CPU memory is available, performance may
707
+ degrade due to memory paging. For most users, "async" should suffice as
708
+ the performance overhead is typically small (on the order of tens of
709
+ seconds) compared to checkpointing frequency. This mode can be employed
710
+ to pursue near-zero checkpointing times (e.g., < 1 second) given
711
+ appropriate hardware support such as ample CPU memory and fast PCIe.
712
+
713
+ "disabled" is the default mode.
714
+ """,
715
+ )
716
+ self.parser.add_argument(
717
+ "--checkpoint.keep_latest_k",
718
+ type=int,
719
+ default=0,
720
+ help="""
721
+ Keeps only the latest k checkpoints and purges older ones. If 0, keep all checkpoints.
722
+ 0 is the default value. k cannot be 1 as the last one may be in the process of being
723
+ saved. As a result, the metadata of the last one may not be ready yet.
724
+ """,
725
+ )
726
+ self.parser.add_argument(
727
+ "--checkpoint.load_step",
728
+ type=int,
729
+ default=-1,
730
+ help="Load the checkpoint at the specified step. If -1, load the latest checkpoint.",
731
+ )
732
+ self.parser.add_argument(
733
+ "--checkpoint.exclude_from_loading",
734
+ type=string_list,
735
+ nargs="*",
736
+ default=[],
737
+ help="""
738
+ Exclude specific keys from being loaded from the checkpoint.
739
+ Provide a comma-separated list of keys to exclude, e.g. 'optimizer,lr_scheduler,dataloader'.
740
+ This will load the model only, excluding the specified keys.
741
+ """,
742
+ )
743
+ # activation checkpointing configs
744
+ self.parser.add_argument(
745
+ "--activation_checkpoint.mode",
746
+ type=str,
747
+ default="selective",
748
+ help="Type of activation checkpointing to use ['none', 'full', 'selective']",
749
+ )
750
+ self.parser.add_argument(
751
+ "--activation_checkpoint.selective_ac_option",
752
+ type=str,
753
+ default="2", # 2 = checkpoint every other layer
754
+ help="""
755
+ Selective activation checkpointing options ['int', 'op'].
756
+ 'int' (e.g., 2) for every nth layer, or 'op' for op level ac.
757
+ """,
758
+ )
759
+
760
+ self.parser.add_argument(
761
+ "--activation_offload.mode",
762
+ type=str,
763
+ default="none",
764
+ help="""
765
+ Whether to use activation offloading. Options are ['none', 'full'].
766
+ """,
767
+ )
768
+
769
+ # float8 configs
770
+ self.parser.add_argument(
771
+ "--float8.enable_fsdp_float8_all_gather",
772
+ action="store_true",
773
+ help="Whether to enable float8 all-gather in FSDP, recommended for tensorwise scaling",
774
+ )
775
+ self.parser.add_argument(
776
+ "--float8.precompute_float8_dynamic_scale_for_fsdp",
777
+ action="store_true",
778
+ help="Whether to precompute float8 scales dynamically for FSDP, recommended for tensorwise scaling",
779
+ )
780
+ self.parser.add_argument(
781
+ "--float8.force_recompute_fp8_weight_in_bwd",
782
+ action="store_true",
783
+ help="""
784
+ Whether to force the recomputation of FP8 weights during backward pass.
785
+ When using FSDP with tensorwise scaling, it is recommended to enable
786
+ `force_recompute_fp8_weight_in_bwd` to prevent saving unsharded FP8 weights
787
+ for backward computation.
788
+ """,
789
+ )
790
+ self.parser.add_argument(
791
+ "--float8.recipe_name",
792
+ type=str,
793
+ default=None,
794
+ choices=["tensorwise", "rowwise", "rowwise_with_gw_hp"],
795
+ help="""
796
+ If specified, creates float8 config from recipe name, valid choices are
797
+ `tensorwise`, `rowwise` and `rowwise_with_gw_hp`.
798
+ """,
799
+ )
800
+
801
+ # communications library settings
802
+ self.parser.add_argument(
803
+ "--comm.init_timeout_seconds",
804
+ type=int,
805
+ default=300,
806
+ help="Timeout for communication operations, during initialization and first train step.",
807
+ )
808
+ self.parser.add_argument(
809
+ "--comm.train_timeout_seconds",
810
+ type=int,
811
+ default=100,
812
+ help=(
813
+ "Timeout for communication operations after the first train step -- "
814
+ "usually a tighter bound than during initialization."
815
+ ),
816
+ )
817
+ self.parser.add_argument(
818
+ "--comm.trace_buf_size",
819
+ type=int,
820
+ default=20000,
821
+ help="Flight recorder ring buffer size, >0 means recording by default, 0 means disabled",
822
+ )
823
+
824
+ # memory estimation settings
825
+ self.parser.add_argument(
826
+ "--memory_estimation.enabled",
827
+ help="Whether to estimate memory usage for FSDP",
828
+ action="store_true",
829
+ )
830
+
831
+ self.parser.add_argument(
832
+ "--memory_estimation.disable_fake_mode",
833
+ help="Whether to disable FakeTensorMode when estimating memory",
834
+ action="store_true",
835
+ )
836
+
837
+ self.parser.add_argument(
838
+ "--fault_tolerance.enable",
839
+ action="store_true",
840
+ help="""
841
+ Enable TorchFT integration. When TorchFT is enabled, HSDP will be used.
842
+ And --training.data_parallel_replicate_degree should be 1 and
843
+ --fault_tolerance.group_size will be used to control the maximum
844
+ replicate group size as the replicate group size is dynamic.
845
+
846
+ Note that this is still an experimental feature.
847
+ """,
848
+ )
849
+
850
+ self.parser.add_argument(
851
+ "--fault_tolerance.replica_id",
852
+ type=int,
853
+ default=0,
854
+ help="The TorchFT replica ID of this run.",
855
+ )
856
+
857
+ self.parser.add_argument(
858
+ "--fault_tolerance.group_size",
859
+ type=int,
860
+ default=0,
861
+ help="""
862
+ The number of TorchFT replicate groups. This number will be used for
863
+ dataloader to split the dataset across the replicate groups and FSDP
864
+ dimension
865
+ """,
866
+ )
867
+
868
+ self.parser.add_argument(
869
+ "--fault_tolerance.min_replica_size",
870
+ type=int,
871
+ default=1,
872
+ help="The minimum number of FT replicas for each step.",
873
+ )
874
+
875
+ def to_dict(self):
876
+ return self.args_dict
877
+
878
+ def parse_args(self, args_list: list = sys.argv[1:]):
879
+ args, cmd_args = self.parse_args_from_command_line(args_list)
880
+ config_file = getattr(args, "job.config_file", None)
881
+ # build up a two level dict
882
+ args_dict = self._args_to_two_level_dict(args)
883
+ if config_file is not None:
884
+ try:
885
+ with open(config_file, "rb") as f:
886
+ for k, v in tomllib.load(f).items():
887
+ # to prevent overwrite of non-specified keys
888
+ args_dict[k] |= v
889
+ except (FileNotFoundError, tomllib.TOMLDecodeError) as e:
890
+ logger.exception(
891
+ f"Error while loading the configuration file: {config_file}"
892
+ )
893
+ logger.exception(f"Error details: {str(e)}")
894
+ raise e
895
+
896
+ # Checking string-list arguments are properly split into a list
897
+ # if split-points came from 'args' (from cmd line) it would have already been parsed into a list by that parser
898
+ string_list_argnames = self._get_string_list_argument_names()
899
+ for n in string_list_argnames:
900
+ check_string_list_argument(args_dict, n)
901
+
902
+ # override args dict with cmd_args
903
+ cmd_args_dict = self._args_to_two_level_dict(cmd_args)
904
+ for section, section_args in cmd_args_dict.items():
905
+ for k, v in section_args.items():
906
+ args_dict[section][k] = v
907
+
908
+ self.args_dict = args_dict
909
+
910
+ for k, v in args_dict.items():
911
+ class_type = type(k.title(), (), v)
912
+ setattr(self, k, class_type())
913
+ self._validate_config()
914
+
915
+ def _args_to_two_level_dict(self, args: argparse.Namespace) -> defaultdict:
916
+ args_dict = defaultdict(defaultdict)
917
+ for k, v in vars(args).items():
918
+ first_level_key, second_level_key = k.split(".", 1)
919
+ args_dict[first_level_key][second_level_key] = v
920
+ return args_dict
921
+
922
+ def _validate_config(self) -> None:
923
+ # TODO: Add more mandatory validations
924
+ assert self.model.config
925
+ assert self.model.tokenizer_path
926
+
927
+ def _get_string_list_argument_names(self) -> list[str]:
928
+ """Get the parser argument names of type `string_list`."""
929
+ string_list_args = [
930
+ v.dest for v in self.parser._actions if v.type is string_list
931
+ ]
932
+ return string_list_args
933
+
934
+ def parse_args_from_command_line(
935
+ self, args_list
936
+ ) -> Tuple[argparse.Namespace, argparse.Namespace]:
937
+ """
938
+ Parse command line arguments and return the parsed args and the command line only args
939
+ """
940
+ args = self.parser.parse_args(args_list)
941
+ string_list_argnames = set(self._get_string_list_argument_names())
942
+
943
+ # aux parser to parse the command line only args, with no defaults from main parser
944
+ aux_parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS)
945
+ for arg, val in vars(args).items():
946
+ if isinstance(val, bool):
947
+ aux_parser.add_argument(
948
+ "--" + arg, action="store_true" if val else "store_false"
949
+ )
950
+ elif arg in string_list_argnames:
951
+ # without this special case, type inference breaks here,
952
+ # since the inferred type is just 'list' and it ends up flattening
953
+ # e.g. from ["layers.0", "layers.1"] into ["l", "a", "y", "e", "r", "s", ".0", ...]
954
+ aux_parser.add_argument("--" + arg, type=string_list)
955
+ else:
956
+ aux_parser.add_argument("--" + arg, type=type(val))
957
+
958
+ cmd_args, _ = aux_parser.parse_known_args(args_list)
959
+
960
+ return args, cmd_args
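
Note on the file above: `parse_args` layers three sources, with parser defaults at the bottom, values from the TOML config file in the middle, and flags explicitly passed on the command line on top. The following is a minimal standalone sketch of that layering, not the code from this commit. The helper names `to_two_level` and `resolve` are hypothetical, only three of the flags are reproduced, and the config file is stood in for by a plain dict rather than loaded with `tomllib`.

```python
import argparse
from collections import defaultdict


def to_two_level(ns: argparse.Namespace) -> dict:
    """Turn {'training.steps': 10} into {'training': {'steps': 10}}."""
    out = defaultdict(dict)
    for key, value in vars(ns).items():
        section, name = key.split(".", 1)
        out[section][name] = value
    return out


def resolve(argv: list[str], file_cfg: dict) -> dict:
    """Hypothetical helper: merge defaults, file values, and explicit flags."""
    # Layer 1: the main parser supplies a default for every option.
    parser = argparse.ArgumentParser()
    parser.add_argument("--training.steps", type=int, default=10000)
    parser.add_argument("--training.seq_len", type=int, default=2048)
    parser.add_argument("--optimizer.lr", type=float, default=8e-4)
    merged = to_two_level(parser.parse_args(argv))

    # Layer 2: config-file values overwrite defaults key by key,
    # leaving keys the file does not mention untouched.
    for section, values in file_cfg.items():
        merged[section] |= values

    # Layer 3: a parser built with SUPPRESS reports only the flags that
    # were actually passed, so these can safely override the file values.
    aux = argparse.ArgumentParser(argument_default=argparse.SUPPRESS)
    aux.add_argument("--training.steps", type=int)
    aux.add_argument("--training.seq_len", type=int)
    aux.add_argument("--optimizer.lr", type=float)
    explicit, _ = aux.parse_known_args(argv)
    for section, values in to_two_level(explicit).items():
        merged[section].update(values)

    return {section: dict(values) for section, values in merged.items()}
```

For example, `resolve(["--optimizer.lr", "1e-3"], {"training": {"steps": 500}, "optimizer": {"lr": 3e-4}})` takes `steps` from the file, keeps the default `seq_len`, and lets the command-line learning rate win over the file's value. Without the second parser, a default from the main parser would be indistinguishable from an explicitly passed value and could not be ranked against the file.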
flame/data.py ADDED
@@ -0,0 +1,756 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import copy
6
+ import pickle
7
+ from copy import deepcopy
8
+ from dataclasses import dataclass
9
+ from typing import Any, Callable, Dict, Iterable, List, Optional, Union
10
+
11
+ import datasets
12
+ import numpy as np
13
+ import torch
14
+ from datasets import Dataset, IterableDataset, interleave_datasets, load_dataset
15
+ from datasets.iterable_dataset import ShufflingConfig
16
+ from torch.distributed.checkpoint.stateful import Stateful
17
+ from torchdata.stateful_dataloader import StatefulDataLoader
18
+ from transformers import PreTrainedTokenizer
19
+
20
+ from torchtitan.tools import utils
21
+ from torchtitan.tools.logging import logger
22
+
23
+
24
+ class BufferShuffledIterableDataset(IterableDataset):
25
+ def __init__(
26
+ self,
27
+ dataset: Dataset,
28
+ tokenizer: PreTrainedTokenizer,
29
+ seq_len: int = 2048,
30
+ rank: int = 0,
31
+ world_size: int = 1,
32
+ buffer_size: int = 1024,
33
+ ) -> None:
34
+ self.dataset = dataset
35
+ self.tokenizer = tokenizer
36
+
37
+ self.data = dataset.shard(world_size, rank)
38
+ self.seq_len = seq_len
39
+
40
+ self.rank = rank
41
+ self.world_size = world_size
42
+ self.buffer_size = buffer_size
43
+
44
+ if tokenizer.vocab_size < torch.iinfo(torch.uint16).max:
45
+ self.dtype = torch.uint16
46
+ elif tokenizer.vocab_size < torch.iinfo(torch.uint32).max:
47
+ self.dtype = torch.uint32
48
+ else:
49
+ self.dtype = torch.uint64
50
+ self.states = None
51
+ self.buffer = torch.tensor([], dtype=self.dtype)
52
+ self.tokens = []
53
+ self.rand_id = 0
54
+ self.token_id = 0
55
+ self.rng_state = None
56
+ self._epoch = 0
57
+
58
+ def __iter__(self):
59
+ g = torch.Generator()
60
+ g.manual_seed(self._epoch + self.rank)
61
+ if self.rng_state is not None:
62
+ g.set_state(self.rng_state)
63
+
64
+ rand_it = self.randint(0, self.buffer_size, g=g)
65
+ if self.states is not None:
66
+ self.data.load_state_dict(self.states)
67
+
68
+ # max number of tokens allowed in the chunk buffer
69
+ n_tokens = self.buffer_size * self.seq_len
70
+
71
+ while True:
72
+ for sample in self.tokenize(self.data):
73
+ # keep appending the samples to the token buffer
74
+ self.tokens += sample
75
+ # if the token buffer is full, start sampling
76
+ # NOTE: we first convert the token ids to a tensor of shape [n_chunks, seq_len] for efficiency
77
+ if len(self.buffer) == 0 and len(self.tokens) >= n_tokens:
78
+ self.buffer = torch.tensor(self.tokens[:n_tokens], dtype=self.dtype).view(self.buffer_size, -1)
79
+ self.tokens = self.tokens[n_tokens:]
80
+ if len(self.buffer) == self.buffer_size:
81
+ yield from self.sample(rand_it)
82
+
83
+ n_chunks = len(self.tokens) // self.seq_len
84
+ # handle the tokens left over in the buffer
85
+ if n_chunks > 0:
86
+ n_tokens = n_chunks * self.seq_len
87
+ indices = torch.randperm(n_chunks, generator=g).tolist()
88
+ self.buffer = torch.tensor(self.tokens[:n_tokens], dtype=torch.long).view(n_chunks, -1)
89
+ self.tokens = self.tokens[n_tokens:]
90
+ for i in indices:
91
+ yield {'input_ids': self.buffer[i]}
92
+
93
+ def tokenize(self, data, batch_size: int = 64):
94
+ texts, states = [], []
95
+ for sample in data:
96
+ texts.append(sample['text'])
97
+ states.append(self.data.state_dict())
98
+ if len(texts) == batch_size:
99
+ for s, tokenized in zip(states, self.tokenizer(texts, return_attention_mask=False)['input_ids']):
100
+ self.states = s
101
+ yield tokenized
102
+ texts, states = [], []
103
+ if len(texts) > 0:
104
+ for s, tokenized in zip(states, self.tokenizer(texts, return_attention_mask=False)['input_ids']):
105
+ self.states = s
106
+ yield tokenized
107
+
108
+ def sample(self, indices):
109
+ n_tokens = (len(self.tokens) // self.seq_len) * self.seq_len
110
+ while self.token_id < n_tokens:
111
+ i = next(indices)
112
+ start, end = self.token_id, self.token_id + self.seq_len
113
+ self.token_id += self.seq_len
114
+ yield {'input_ids': self.buffer[i].to(torch.long)}
115
+ self.buffer[i] = torch.tensor(self.tokens[start:end], dtype=self.dtype)
116
+ self.token_id = 0
117
+ self.tokens = self.tokens[n_tokens:]
118
+
119
+ def randint(self, low: int, high: int, buffer_size: int = 1024, g: torch.Generator = torch.Generator()) -> Iterable[int]:
120
+ indices = torch.empty(buffer_size, dtype=torch.long)
121
+ while True:
122
+ # record the generator states before sampling
123
+ self.rng_state = g.get_state()
124
+ indices = torch.randint(low, high, (buffer_size,), out=indices, generator=g)
125
+ for i in indices[self.rand_id:].tolist():
126
+ self.rand_id += 1
127
+ yield i
128
+ self.rand_id = 0
129
+
130
+ def set_epoch(self, epoch):
131
+ self._epoch = epoch
132
+ if hasattr(self.dataset, 'set_epoch'):
133
+ self.dataset.set_epoch(epoch)
134
+
135
+ def state_dict(self):
136
+ return {
137
+ 'states': self.states,
138
+ 'buffer': self.buffer.clone(),
139
+ 'tokens': deepcopy(self.tokens),
140
+ 'rand_id': self.rand_id,
141
+ 'token_id': self.token_id,
142
+ 'rng_state': self.rng_state,
143
+ 'epoch': self._epoch,
144
+ }
145
+
146
+ def load_state_dict(self, state_dict):
147
+ self.states = state_dict['states']
148
+ self.buffer = state_dict['buffer'].clone()
149
+ self.tokens = deepcopy(state_dict['tokens'])
150
+ self.rand_id = state_dict['rand_id']
151
+ self.token_id = state_dict['token_id']
152
+ self.rng_state = state_dict['rng_state'].clone() if state_dict['rng_state'] is not None else None
153
+ self._epoch = state_dict['epoch']
154
+
155
+
156
+ class OnlineTokenizedIterableDataset(IterableDataset):
157
+ def __init__(
158
+ self, dataset: Dataset, tokenizer: PreTrainedTokenizer, seq_len: int = 2048, rank: int = 0, world_size: int = 1
159
+ ) -> None:
160
+ self.dataset = dataset
161
+ self.tokenizer = tokenizer
162
+
163
+ self.data = dataset.shard(world_size, rank)
164
+ self.seq_len = seq_len
165
+ self.rank = rank
166
+ self.world_size = world_size
167
+
168
+ self.states = None
169
+ self.tokens = []
170
+
171
+ def __iter__(self):
172
+ if self.states is not None:
173
+ self.data.load_state_dict(self.states)
174
+
175
+ while True:
176
+ for sample in self.tokenize(self.data):
177
+ # keep appending the samples to the token buffer
178
+ self.tokens += sample
179
+
180
+ while len(self.tokens) >= self.seq_len:
181
+ input_ids = torch.tensor(self.tokens[:self.seq_len], dtype=torch.long)
182
+ self.tokens = self.tokens[self.seq_len:]
183
+ yield {'input_ids': input_ids}
184
+
185
+ def tokenize(self, data, buffer_size: int = 64):
186
+ buffer, states = [], []
187
+ for sample in data:
188
+ if sample.get('text', None) is not None:
189
+ buffer.append(sample['text'])
190
+ elif sample.get('content', None) is not None:
191
+ buffer.append(sample['content'])
192
+ else:
193
+ raise ValueError(f"No 'text' or 'content' field found in sample:\n{sample}")
194
+ states.append(self.data.state_dict())
195
+ if len(buffer) == buffer_size:
196
+ for s, tokenized in zip(states, self.tokenizer(buffer, return_attention_mask=False)['input_ids']):
197
+ self.states = s
198
+ yield tokenized
199
+ buffer, states = [], []
200
+ if len(buffer) > 0:
201
+ for s, tokenized in zip(states, self.tokenizer(buffer, return_attention_mask=False)['input_ids']):
202
+ self.states = s
203
+ yield tokenized
204
+
205
+ def state_dict(self):
206
+ return {'states': self.states, 'tokens': deepcopy(self.tokens)}
207
+
208
+ def load_state_dict(self, state_dict):
209
+ self.states = state_dict['states']
210
+ self.tokens = deepcopy(state_dict['tokens'])
211
+
212
+
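The core of `OnlineTokenizedIterableDataset.__iter__` is a packing loop: token ids from consecutive documents are appended to one running list, and a fixed-length chunk is emitted whenever at least `seq_len` tokens are available. The sketch below shows that loop in plain Python. The function name `pack_tokens` is hypothetical, and unlike the class above it makes a single pass over the input and does not track resumable state.

```python
from typing import Iterable, Iterator, List


def pack_tokens(docs: Iterable[List[int]], seq_len: int) -> Iterator[List[int]]:
    """Concatenate tokenized documents and yield chunks of exactly seq_len tokens."""
    tokens: List[int] = []
    for doc in docs:
        # keep appending the document's token ids to the running buffer
        tokens += doc
        # emit as many full chunks as the buffer currently holds
        while len(tokens) >= seq_len:
            yield tokens[:seq_len]
            tokens = tokens[seq_len:]
    # any remainder shorter than seq_len is dropped in this sketch


chunks = list(pack_tokens([[1, 2, 3], [4, 5], [6, 7, 8, 9]], seq_len=4))
```

Note that chunks may cross document boundaries, which is why the collator later needs BOS/EOS positions (or `cu_seqlens`) to recover them.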
213
+ class BufferShuffledExamplesIterable(datasets.iterable_dataset.BufferShuffledExamplesIterable):
214
+ def __init__(self, *args, **kwargs):
215
+ super().__init__(*args, **kwargs)
216
+
217
+ def _init_state_dict(self) -> dict:
218
+ self._state_dict = self.ex_iterable._init_state_dict()
219
+ self._state_dict['mem_buffer'] = ([],)
220
+ self._state_dict['bit_generator_state'] = self.generator.bit_generator.state
221
+ self._state_dict['bit_generator_index_offset'] = 0
222
+ self._state_dict['bit_generator_index_offset_shuffle'] = 0
223
+ return self._state_dict
224
+
225
+ def __iter__(self):
226
+ buffer_size = self.buffer_size
227
+ rng = deepcopy(self.generator)
228
+ # this is the shuffle buffer that we keep in memory
229
+ mem_buffer = self._state_dict['mem_buffer'][0]
230
+ # this is an infinite iterator that randomly samples the index of the source to pick examples from
231
+ index_offset = self._state_dict['bit_generator_index_offset'] if self._state_dict else 0
232
+ if self._state_dict:
233
+ rng.bit_generator.state = self._state_dict['bit_generator_state']
234
+ indices_iterator = self._iter_random_indices(rng, buffer_size, random_batch_size=buffer_size)
235
+ # skip already consumed ones
236
+ for _ in range(index_offset):
237
+ i = next(indices_iterator)
238
+
239
+ for x in self.ex_iterable:
240
+ if len(mem_buffer) < buffer_size: # if the buffer is not full, keep filling the buffer
241
+ mem_buffer.append(x)
242
+ else: # otherwise, pick an example from it
243
+ i = next(indices_iterator)
244
+ index_offset = (index_offset + 1) % buffer_size
245
+ if self._state_dict:
246
+ self._state_dict['bit_generator_index_offset'] = index_offset
247
+ if index_offset == 0:
248
+ self._state_dict['bit_generator_state'] = rng.bit_generator.state
249
+ selected = mem_buffer[i]
250
+ mem_buffer[i] = x # replace the picked example by a new one
251
+ yield selected
252
+
253
+ index_offset = self._state_dict['bit_generator_index_offset_shuffle'] if self._state_dict else 0
254
+ if self._state_dict:
255
+ rng.bit_generator.state = self._state_dict['bit_generator_state']
256
+
257
+ # when we run out of examples, we shuffle the remaining examples in the buffer and yield them
258
+ for i in rng.permutation(len(mem_buffer))[index_offset:].tolist():
259
+ index_offset = index_offset + 1
260
+ if self._state_dict:
261
+ self._state_dict['bit_generator_index_offset_shuffle'] = index_offset
262
+ yield mem_buffer[i]
263
+
264
+ def shuffle_data_sources(self, generator: np.random.Generator) -> BufferShuffledExamplesIterable:
265
+ """Shuffle the wrapped examples iterable as well as the shuffling buffer."""
266
+ return BufferShuffledExamplesIterable(
267
+ self.ex_iterable.shuffle_data_sources(generator), buffer_size=self.buffer_size, generator=generator
268
+ )
269
+
270
+ def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> BufferShuffledExamplesIterable:
271
+ """Keep only the requested shard."""
272
+ return BufferShuffledExamplesIterable(
273
+ self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous),
274
+ buffer_size=self.buffer_size,
275
+ generator=self.generator,
276
+ )
277
+
278
+ def load_state_dict(self, state_dict: dict) -> dict:
279
+ def _inner_load_state_dict(state, new_state):
280
+ if new_state is not None and isinstance(state, dict):
281
+ for key in new_state:
282
+ state[key] = _inner_load_state_dict(state[key], new_state[key])
283
+ return state
284
+ elif new_state is not None and isinstance(state, list):
285
+ for i in range(len(state)):
286
+ state[i] = _inner_load_state_dict(state[i], new_state[i])
287
+ return state
288
+ return new_state
289
+
290
+ return _inner_load_state_dict(self._state_dict, state_dict)
291
+
292
+
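Both shuffling classes above build on the same streaming technique: fill a fixed-size buffer, then for every new element swap it in at a random slot and emit the element that was there, and finally shuffle and drain whatever is left. A minimal sketch using only the standard library follows. The name `buffer_shuffle` and the use of `random.Random` are choices made here for illustration; the classes above use torch or NumPy generators and additionally checkpoint their state.

```python
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def buffer_shuffle(items: Iterable[T], buffer_size: int, seed: int = 0) -> Iterator[T]:
    """Approximately shuffle a stream using a bounded in-memory buffer."""
    rng = random.Random(seed)
    buffer: List[T] = []
    for item in items:
        if len(buffer) < buffer_size:
            # the buffer is not full yet, keep filling it
            buffer.append(item)
            continue
        # pick a random slot, emit its content, and replace it with the new item
        i = rng.randrange(buffer_size)
        yield buffer[i]
        buffer[i] = item
    # the source is exhausted: shuffle the remaining items and emit them
    rng.shuffle(buffer)
    yield from buffer


shuffled = list(buffer_shuffle(range(20), buffer_size=4, seed=0))
```

Every input element is emitted exactly once, and the first emitted element always comes from the first `buffer_size` inputs, so a larger buffer gives a shuffle closer to uniform at the cost of memory.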
293
+ def shuffle(
294
+ dataset: IterableDataset,
295
+ seed: int = 42,
296
+ generator: np.random.Generator = None,
297
+ buffer_size: int = 1024,
298
+ ):
299
+ generator = np.random.default_rng(seed) if generator is None else deepcopy(generator)
300
+ return IterableDataset(
301
+ ex_iterable=BufferShuffledExamplesIterable(dataset._ex_iterable, buffer_size=buffer_size, generator=generator),
302
+ info=dataset._info.copy(),
303
+ split=dataset._split,
304
+ formatting=dataset._formatting,
305
+ shuffling=ShufflingConfig(generator=generator, _original_seed=seed),
306
+ distributed=copy.deepcopy(dataset._distributed),
307
+ token_per_repo_id=dataset._token_per_repo_id,
308
+ )
309
+
310
+
311
+ @dataclass
312
+ class DataCollatorForLanguageModeling:
313
+ """
314
+ Data collator used for language modeling. Inputs are dynamically padded if `varlen=False`.
315
+ If `varlen=True`, sequences are expected to be concatenated, and labels match inputs.
316
+
317
+ Args:
318
+ tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
319
+ The tokenizer used for encoding the data.
320
+ context_len (`int`, optional):
321
+ When `varlen=True`, sequences longer than this length within a document
322
+ (as determined by `cu_seqlens`) will be further chunked.
323
+ varlen (`bool`):
324
+ Whether to handle variable length concatenated sequences (`True`) or padded batches (`False`).
325
+
326
+ Returns:
327
+ A dictionary with the following keys:
328
+ - `input_ids`: Tensor of input IDs. Shape `[batch_size, seq_len]` if `varlen=False`, `[1, total_len]` if `varlen=True`.
329
+ - `labels`: Tensor of labels. Shape matches `input_ids`. Padding positions are masked with -100 if `varlen=False`.
330
+ - `attention_mask`: Tensor indicating non-padding tokens (only if `varlen=False`). Shape matches `input_ids`.
331
+ - `cu_seqlens`: Tensor of cumulative sequence lengths (only if `varlen=True`). Shape `[1, num_sequences + 1]`.
332
+
333
+ NOTE: When `varlen=True`, the `batch_size` must be 1.
334
+ """
335
+
336
+ tokenizer: PreTrainedTokenizer
337
+ context_len: Optional[int] = None
338
+ varlen: bool = False
339
+
340
+ def __call__(self, examples: List[Union[List[int], Dict[str, Any]]]) -> Dict[str, Any]:
341
+ if not isinstance(examples[0], Dict):
342
+ examples = [{'input_ids': example} for example in examples]
343
+
344
+ def tensorize(example: Dict[str, Any]) -> Dict[str, Any]:
345
+ tensorized = {}
346
+ for key in ['input_ids', 'cu_seqlens']:
347
+ if key not in example:
348
+ continue
349
+ if isinstance(example[key], List):
350
+ tensorized[key] = torch.tensor(example[key], dtype=torch.long)
351
+ elif isinstance(example[key], np.ndarray):
352
+ tensorized[key] = torch.from_numpy(example[key])
353
+ else:
354
+ tensorized[key] = example[key]
355
+ return tensorized
356
+
357
+ examples = list(map(tensorize, examples))
358
+
359
+ if not self.varlen:
360
+ # --- Handling for varlen=False (Batch Padding) ---
361
+ length_of_first = examples[0]['input_ids'].size(0)
362
+ needs_padding = not all(example['input_ids'].size(0) == length_of_first for example in examples)
363
+
364
+ if needs_padding:
365
+ # Check for pad token if padding is actually required
366
+ if self.tokenizer.pad_token_id is None:
367
+ raise ValueError(
368
+ f'You are attempting to pad samples but the tokenizer you are using '
369
+ f'({self.tokenizer.__class__.__name__}) does not have a pad token.'
370
+ )
371
+ # Pad using the tokenizer, ensuring attention_mask is returned
372
+ batch = self.tokenizer.pad(examples, return_tensors='pt', return_attention_mask=True)
373
+ else:
374
+ # No padding needed, stack directly and create a full attention mask
375
+ input_ids = torch.stack([example['input_ids'] for example in examples], dim=0)
376
+ batch = {
377
+ 'input_ids': input_ids,
378
+ # Create attention mask of all ones
379
+ 'attention_mask': torch.ones_like(input_ids),
380
+ }
381
+
382
+ # Create labels by cloning input_ids
383
+ labels = batch['input_ids'].clone()
384
+ # Mask labels only where attention_mask is 0 (padding positions)
385
+ if 'attention_mask' in batch:
386
+ labels[batch['attention_mask'] == 0] = -100
387
+ batch['labels'] = labels
388
+
389
+ else:
390
+ # --- Handling for varlen=True (Concatenated Sequences) ---
391
+ if len(examples) > 1:
392
+ raise ValueError('The batch size must be 1 for inputs with variable lengths (varlen=True).')
393
+
394
+ batch = {'input_ids': torch.cat([example['input_ids'] for example in examples], dim=0).unsqueeze(0)}
395
+
396
+ # --- cu_seqlens calculation logic remains the same ---
397
+ if 'cu_seqlens' in examples[0]:
398
+ batch['cu_seqlens'] = (
399
+ torch.cat([example['cu_seqlens'] for example in examples], dim=0).unsqueeze(0).to(dtype=torch.int32)
400
+ ) # Ensure int32
401
+ else:
402
+ # determine boundaries by bos/eos positions
403
+ # Check for bos_token_id first
404
+ if self.tokenizer.bos_token_id is not None:
405
+ cu_seqlens = []
406
+ # Handle case where the sequence doesn't start with BOS
407
+ if batch['input_ids'][0, 0] != self.tokenizer.bos_token_id:
408
+ cu_seqlens.append(torch.tensor([0], device=batch['input_ids'].device)) # Match device
409
+ # Find all BOS token positions
410
+ bos_positions = torch.where(batch['input_ids'].eq(self.tokenizer.bos_token_id))[1]
411
+ # Ensure bos_positions is on the correct device if empty
412
+ if bos_positions.numel() == 0 and len(cu_seqlens) > 0:
413
+ cu_seqlens.append(bos_positions.to(cu_seqlens[0].device))
414
+ elif bos_positions.numel() > 0:
415
+ cu_seqlens.append(bos_positions)
416
+ # Add the end of the entire batch
417
+ cu_seqlens.append(
418
+ torch.tensor([batch['input_ids'].size(1)], device=batch['input_ids'].device)
419
+ ) # Match device and use size(1)
420
+ # Filter out empty tensors before cat
421
+ cu_seqlens = [t for t in cu_seqlens if t.numel() > 0]
422
+ if not cu_seqlens: # Handle case where input is empty or has no BOS
423
+ batch['cu_seqlens'] = torch.tensor(
424
+ [0, batch['input_ids'].size(1)], dtype=torch.int32, device=batch['input_ids'].device
425
+ )
426
+ else:
427
+ batch['cu_seqlens'] = torch.cat(cu_seqlens, dim=0).to(dtype=torch.int32)
428
+
429
+ # Else, check for eos_token_id
430
+ elif self.tokenizer.eos_token_id is not None:
431
+ cu_seqlens = [torch.tensor([0], device=batch['input_ids'].device)] # Match device
432
+ # Find positions *after* EOS tokens
433
+ eos_positions = torch.where(batch['input_ids'].eq(self.tokenizer.eos_token_id))[1] + 1
434
+ # Ensure eos_positions is on the correct device if empty
435
+ if eos_positions.numel() > 0:
436
+ cu_seqlens.append(eos_positions)
437
+ # Handle case where the sequence doesn't end with EOS
438
+ if batch['input_ids'][0, -1] != self.tokenizer.eos_token_id:
439
+ # Only add the final length if the last found EOS wasn't already the end
440
+ if eos_positions.numel() == 0 or eos_positions[-1] != batch['input_ids'].size(1):
441
+ cu_seqlens.append(
442
+ torch.tensor([batch['input_ids'].size(1)], device=batch['input_ids'].device)
443
+ ) # Match device and use size(1)
444
+ # Filter out empty tensors before cat
445
+ cu_seqlens = [t for t in cu_seqlens if t.numel() > 0]
446
+ if not cu_seqlens: # Handle case where input is empty or has no EOS
447
+ batch['cu_seqlens'] = torch.tensor(
448
+ [0, batch['input_ids'].size(1)], dtype=torch.int32, device=batch['input_ids'].device
449
+ )
450
+ else:
451
+ batch['cu_seqlens'] = torch.cat(cu_seqlens, dim=0).to(dtype=torch.int32)
452
+ # Else, neither BOS nor EOS is usable
453
+ else:
454
+ raise ValueError(
455
+ 'For varlen=True without precomputed cu_seqlens, the tokenizer must have either a bos_token_id '
456
+ 'or an eos_token_id defined to act as sequence separators.'
457
+ )
458
+
459
+ # --- cu_seqlens validation checks remain the same ---
460
+ if batch['cu_seqlens'].numel() < 2:
461
+ raise ValueError(f'Calculated cu_seqlens must have at least start and end: {batch["cu_seqlens"]}')
462
+ if not torch.all(batch['cu_seqlens'][1:] >= batch['cu_seqlens'][:-1]):
463
+ raise ValueError(f'Calculated cu_seqlens are not monotonically increasing: {batch["cu_seqlens"]}')
464
+ if batch['cu_seqlens'][0] != 0:
465
+ raise ValueError(f'Calculated cu_seqlens do not start at 0: {batch["cu_seqlens"]}')
466
+ if batch['cu_seqlens'][-1] != batch['input_ids'].size(1):
467
+ # Allow empty sequence case where cu_seqlens=[0, 0] and input_ids.size(1)=0
468
+ if not (batch['cu_seqlens'].tolist() == [0, 0] and batch['input_ids'].size(1) == 0):
469
+ raise ValueError(
470
+ f'Calculated cu_seqlens do not end at total length {batch["input_ids"].size(1)}: '
471
+ f'{batch["cu_seqlens"]}'
472
+ )
473
+
474
+ # --- context_len splitting logic remains the same ---
475
+ if self.context_len is not None:
476
+ # This logic splits sequences based on context_len *after* initial boundaries are found
477
+ bos = batch['cu_seqlens'][:-1].tolist()
478
+ eos = batch['cu_seqlens'][1:].tolist()
479
+ # Handle empty sequences between boundaries
480
+ split_boundaries = []
481
+ for i, j in zip(bos, eos):
482
+ if i < j: # Only process non-empty sequences
483
+ split_boundaries.append(torch.arange(i, j, self.context_len, device=batch['input_ids'].device))
484
+ # Add the final end point if it wasn't included by arange
485
+ final_end_point = torch.tensor([batch['input_ids'].size(1)], device=batch['input_ids'].device)
486
+ # Concatenate all boundaries
487
+ if not split_boundaries: # Handle case of completely empty input
488
+ batch['cu_seqlens'] = torch.tensor([0, 0], dtype=torch.int32, device=batch['input_ids'].device)
489
+ else:
490
+ batch['cu_seqlens'] = torch.cat(split_boundaries + [final_end_point]).to(dtype=torch.int32)
491
+ # Ensure uniqueness and sort, as arange might duplicate the endpoint
492
+ batch['cu_seqlens'] = torch.unique(batch['cu_seqlens'])
493
+
494
+ # Create labels directly from input_ids, NO padding mask needed for varlen
495
+ labels = batch['input_ids'].clone()
496
+ batch['labels'] = labels
497
+
498
+ return batch
499
+
500
+
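The `varlen=True` branch of the collator derives `cu_seqlens` from BOS positions and then optionally splits each document into pieces of at most `context_len` tokens. The following plain-Python sketch mirrors that boundary computation on a flat list of token ids. The function name `cu_seqlens_from_bos` is hypothetical, the EOS fallback is omitted, and the explicit empty-input guard is an assumption of this sketch.

```python
from typing import List, Optional


def cu_seqlens_from_bos(
    input_ids: List[int], bos_id: int, context_len: Optional[int] = None
) -> List[int]:
    """Cumulative sequence boundaries for a packed sequence, split at BOS tokens."""
    n = len(input_ids)
    if n == 0:
        return [0, 0]
    # every BOS token starts a new document
    bounds = [i for i, t in enumerate(input_ids) if t == bos_id]
    # the packed sequence may start in the middle of a document
    if not bounds or bounds[0] != 0:
        bounds.insert(0, 0)
    bounds.append(n)
    if context_len is not None:
        split: List[int] = []
        for start, end in zip(bounds[:-1], bounds[1:]):
            # chunk each document into pieces of at most context_len tokens
            split.extend(range(start, end, context_len))
        split.append(n)
        bounds = sorted(set(split))
    return bounds


ids = [1, 5, 6, 7, 1, 8, 9]  # token id 1 plays the role of BOS here
plain = cu_seqlens_from_bos(ids, bos_id=1)
chunked = cu_seqlens_from_bos(ids, bos_id=1, context_len=3)
```

With `context_len=3`, the four-token first document is split into pieces of three and one tokens, while the three-token second document stays whole.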
501
+ class ParallelAwareDataLoader(StatefulDataLoader, Stateful):
502
+ """
503
+ A wrapper around the StatefulDataLoader that ensures that the state is stored only once per DP rank.
504
+ """
505
+
506
+ def __init__(
507
+ self,
508
+ rank: int,
509
+ dataset: IterableDataset,
510
+ batch_size: int,
511
+ collate_fn: Callable,
512
+ num_workers: int = 0,
513
+ pin_memory: bool = False,
514
+ prefetch_factor: int = 2,
515
+ persistent_workers: bool = False,
516
+ snapshot_every_n_steps: Optional[int] = 1,
517
+ ):
518
+ super().__init__(
519
+ dataset=dataset,
520
+ batch_size=batch_size,
521
+ collate_fn=collate_fn,
522
+ num_workers=num_workers,
523
+ pin_memory=pin_memory,
524
+ prefetch_factor=prefetch_factor,
525
+ persistent_workers=persistent_workers,
526
+ snapshot_every_n_steps=snapshot_every_n_steps,
527
+ )
528
+ self.rank = rank
529
+
530
+ def state_dict(self) -> Dict[str, Any]:
531
+ # Store state only for dp rank to avoid replicating the same state across other dimensions
532
+ return {f'rank_{self.rank}': pickle.dumps(super().state_dict())}
533
+
534
+ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
535
+ # State being empty is valid
536
+ if not state_dict:
537
+ return
538
+
539
+ if f'rank_{self.rank}' not in state_dict:
540
+ logger.warning(f'DataLoader state is empty for dp rank {self.rank}, expected key rank_{self.rank}')
541
+ return
542
+ super().load_state_dict(pickle.loads(state_dict[f'rank_{self.rank}']))
543
+
544
+
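`ParallelAwareDataLoader` stores its state under a rank-specific key and pickles the payload, so that each data-parallel rank only restores its own entry. The sketch below isolates that pattern without any torch dependency; the class `RankScopedState` and its `state` attribute are hypothetical stand-ins for the dataloader and its internal state.

```python
import pickle
from typing import Any, Dict, Optional


class RankScopedState:
    """Keeps a state dict that is saved and restored under a per-rank key."""

    def __init__(self, rank: int, state: Optional[Dict[str, Any]] = None):
        self.rank = rank
        self.state = dict(state) if state is not None else {}

    def state_dict(self) -> Dict[str, bytes]:
        # one opaque blob per rank, keyed by the rank id
        return {f"rank_{self.rank}": pickle.dumps(self.state)}

    def load_state_dict(self, state_dict: Dict[str, bytes]) -> None:
        # an empty checkpoint is valid and leaves the state untouched
        if not state_dict:
            return
        key = f"rank_{self.rank}"
        if key not in state_dict:
            # nothing was saved for this rank; keep the current state
            return
        self.state = pickle.loads(state_dict[key])


saved = RankScopedState(rank=1, state={"step": 7}).state_dict()
same_rank = RankScopedState(rank=1)
same_rank.load_state_dict(saved)
other_rank = RankScopedState(rank=0)
other_rank.load_state_dict(saved)
```

As with any use of `pickle`, such checkpoints should only be loaded from trusted sources.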
545
+ def build_dataset(
546
+ dataset: str,
547
+ dataset_name: str = None,
548
+ dataset_split: str = 'train',
549
+ data_dir: str = None,
550
+ data_files: str = None,
551
+ data_probs: Optional[str] = None,  # comma-separated sampling probabilities, e.g. "0.7,0.3"
552
+ streaming: bool = False,
553
+ dp_degree: Optional[int] = None,
554
+ num_workers: int = 32,
555
+ seed: Optional[int] = None,
556
+ ) -> IterableDataset:
557
+ color = utils.Color
558
+ min_num_shards = dp_degree * num_workers if dp_degree else None
559
+ if len(dataset.split(',')) == 1:
+ # keep the original path string: `dataset` is rebound to the loaded dataset below
+ dataset_path = dataset
560
+ dataset = load_dataset(
561
+ path=dataset_path,
562
+ name=dataset_name,
563
+ split=dataset_split,
564
+ data_dir=data_dir,
565
+ data_files=data_files,
566
+ trust_remote_code=True,
567
+ streaming=streaming,
568
+ num_proc=num_workers if not streaming else None,
569
+ )
570
+ logger.info(f"Shuffling the dataset with seed {seed}")
571
+ if not streaming:
572
+ # the state of a map-style dataset is recoverable after shuffling
573
+ if seed is not None:
574
+ dataset = dataset.shuffle(seed=seed)
575
+ if min_num_shards is not None:
576
+ dataset = dataset.to_iterable_dataset(num_shards=min_num_shards)
577
+ else:
578
+ if min_num_shards is not None and dataset.num_shards < min_num_shards:
579
+ logger.warning(
580
+ f"{color.red}"
581
+ f"Dataset {dataset_path} has insufficient shards ({dataset.num_shards}). "
582
+ f"Need {min_num_shards} shards minimum for {dp_degree} data parallel workers × "
583
+ f"{num_workers} dataloader workers. "
584
+ f"Disabling the streaming mode and resharding dataset to {min_num_shards} shards."
585
+ f"{color.reset}"
586
+ )
587
+ dataset = load_dataset(
588
+ path=dataset_path,
589
+ name=dataset_name,
590
+ split=dataset_split,
591
+ data_dir=data_dir,
592
+ data_files=data_files,
593
+ trust_remote_code=True,
594
+ streaming=False,
595
+ num_proc=num_workers,
596
+ )
597
+ if seed is not None:
598
+ dataset = dataset.shuffle(seed=seed)
599
+ dataset = dataset.to_iterable_dataset(num_shards=min_num_shards)
600
+ else:
601
+ if seed is not None:
602
+ dataset = shuffle(dataset, seed=seed)
603
+ else:
604
+ datasets = dataset.split(",")
605
+ if dataset_name is not None:
606
+ dataset_names = [
607
+ name or None for name in dataset_name.split(",")
608
+ ]
609
+ assert len(dataset_names) == len(datasets), (
610
+ "The number of dataset names must match the number of datasets"
611
+ )
612
+ else:
613
+ dataset_names = [None] * len(datasets)
614
+ if dataset_split is not None:
615
+ dataset_splits = [split or "train" for split in dataset_split.split(",")]
616
+ assert len(dataset_splits) == len(datasets), (
617
+ "The number of dataset splits must match the number of datasets"
618
+ )
619
+ else:
620
+ dataset_splits = ["train"] * len(datasets)
621
+ if data_dir is not None:
622
+ data_dirs = [
623
+ data_dir or None for data_dir in data_dir.split(",")
624
+ ]
625
+ assert len(data_dirs) == len(datasets), (
626
+ "The number of data dirs must match the number of datasets"
627
+ )
628
+ else:
629
+ data_dirs = [None] * len(datasets)
630
+ if data_files is not None:
631
+ data_files = data_files.split(",")
632
+ assert len(data_files) == len(datasets), (
633
+ "The number of data files must match the number of datasets"
634
+ )
635
+ else:
636
+ data_files = [None] * len(datasets)
637
+ if data_probs is not None:
638
+ data_probs = [float(p) for p in data_probs.split(",")]
639
+ assert len(data_probs) == len(datasets), (
640
+ "The number of data probabilities must match the number of datasets"
641
+ )
642
+ else:
643
+ raise ValueError(
644
+ "Data sampling probabilities are required if using multiple datasets"
645
+ )
646
+
647
+ subsets = []
648
+ for i, prob in enumerate(data_probs):
649
+ subset = load_dataset(
650
+ path=datasets[i],
651
+ name=dataset_names[i],
652
+ split=dataset_splits[i],
653
+ data_dir=data_dirs[i],
654
+ data_files=data_files[i],
655
+ trust_remote_code=True,
656
+ streaming=streaming,
657
+ num_proc=(
658
+ num_workers
659
+ if not streaming
660
+ else None
661
+ ),
662
+ )
663
+ logger.info(
664
+ f"Subset {color.cyan}{datasets[i]}"
665
+ + (f":{dataset_names[i]} " if dataset_names[i] else " ")
666
+ + f"(p = {prob:.3f}){color.reset}:\n"
667
+ + f"{subset}"
668
+ )
669
+
670
+ logger.info(f"Shuffling the dataset with seed {seed}")
671
+ if not streaming:
672
+ # the state of a map-style dataset is recoverable after shuffling
673
+ if seed is not None:
674
+ subset = subset.shuffle(seed=seed)
675
+ if min_num_shards is not None:
676
+ subset = subset.to_iterable_dataset(num_shards=min_num_shards)
677
+ else:
678
+ if min_num_shards is not None and subset.num_shards < min_num_shards:
679
+ logger.warning(
680
+ f"{color.red}"
681
+ f"Dataset {datasets[i]} has insufficient shards ({subset.num_shards}). "
682
+ f"Need {min_num_shards} shards minimum for desired data parallel workers × "
683
+ f"{num_workers} dataloader workers. "
684
+ f"Resharding dataset to {min_num_shards} shards and disabling streaming mode."
685
+ f"{color.reset}"
686
+ )
687
+ # again, it's ok to directly shuffle the map-style dataset
688
+ # we expect an error to be raised if the map-style dataset still does not have enough data shards
689
+ subset = load_dataset(
690
+ path=datasets[i],
691
+ name=dataset_names[i],
692
+ split=dataset_splits[i],
693
+ data_dir=data_dirs[i],
694
+ data_files=data_files[i],
695
+ trust_remote_code=True,
696
+ streaming=False,
697
+ num_proc=num_workers,
698
+ )
699
+ if seed is not None:
700
+ subset = subset.shuffle(seed=seed)
701
+ subset = subset.to_iterable_dataset(num_shards=min_num_shards)
702
+ else:
703
+ # we use a relatively small buffer size here, since interleaving already provides some randomness
704
+ if seed is not None:
705
+ subset = shuffle(subset, seed=seed, buffer_size=max(128, 1024 // len(datasets)))
706
+
707
+ if "text" in subset.column_names:
708
+ subset = subset.select_columns("text")
709
+ elif "content" in subset.column_names:
710
+ subset = subset.select_columns("content")
711
+ else:
712
+ raise ValueError(
713
+ f"Subset {datasets[i]} has no 'text' or 'content' column"
714
+ )
715
+ subsets.append(subset)
716
+
717
+ logger.info(
718
+ f"Interleaving {len(subsets)} datasets with probabilities {data_probs}"
719
+ )
720
+ dataset = interleave_datasets(
721
+ datasets=subsets,
722
+ probabilities=data_probs,
723
+ stopping_strategy="all_exhausted",
724
+ seed=seed,
725
+ )
726
+ logger.info(f"{dataset}")
727
+ return dataset
728
+
729
+
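When several datasets are given, `build_dataset` expects every per-dataset option as a comma-separated string of the same length as the dataset list, with empty entries falling back to a default. The sketch below reproduces just that parsing and validation step for paths, splits, and probabilities. The function name `parse_mixture` and the dataset paths in the example are hypothetical, and the sketch raises `ValueError` where the code above uses `assert`.

```python
from typing import List, Optional, Tuple


def parse_mixture(
    dataset: str, data_probs: Optional[str], dataset_split: Optional[str] = None
) -> List[Tuple[str, str, float]]:
    """Parse comma-separated dataset specs into (path, split, probability) triples."""
    paths = dataset.split(",")
    if data_probs is None:
        raise ValueError("Data sampling probabilities are required if using multiple datasets")
    probs = [float(p) for p in data_probs.split(",")]
    if len(probs) != len(paths):
        raise ValueError("The number of data probabilities must match the number of datasets")
    if dataset_split is not None:
        # an empty entry such as "train," falls back to the "train" split
        splits = [s or "train" for s in dataset_split.split(",")]
        if len(splits) != len(paths):
            raise ValueError("The number of dataset splits must match the number of datasets")
    else:
        splits = ["train"] * len(paths)
    return list(zip(paths, splits, probs))


# "org/web" and "org/code" are placeholder dataset paths
mixture = parse_mixture("org/web,org/code", "0.8,0.2", "train,")
```

The resulting probabilities are what would be handed to `interleave_datasets(..., probabilities=...)` once each subset has been loaded.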
730
+ def build_dataloader(
731
+ dataset: IterableDataset,
732
+ tokenizer: PreTrainedTokenizer,
733
+ rank: int,
734
+ world_size: int,
735
+ batch_size: int,
736
+ seq_len: int,
737
+ context_len: Optional[int] = None,
738
+ varlen: bool = False,
739
+ num_workers: int = 0,
740
+ pin_memory: bool = False,
741
+ persistent_workers: bool = False,
742
+ snapshot_every_n_steps: Optional[int] = 1,
743
+ ):
744
+ dataset = OnlineTokenizedIterableDataset(
745
+ dataset=dataset, tokenizer=tokenizer, seq_len=seq_len, rank=rank, world_size=world_size
746
+ )
747
+ return ParallelAwareDataLoader(
748
+ rank=rank,
749
+ dataset=dataset,
750
+ batch_size=batch_size,
751
+ collate_fn=DataCollatorForLanguageModeling(tokenizer=tokenizer, context_len=context_len, varlen=varlen),
752
+ num_workers=num_workers,
753
+ pin_memory=pin_memory,
754
+ persistent_workers=persistent_workers,
755
+ snapshot_every_n_steps=snapshot_every_n_steps,
756
+ )
flame/models/fla.toml ADDED
@@ -0,0 +1,67 @@
+ [model]
+ config = "fla-hub/transformer-1.3B-100B"
+ tokenizer_path = "fla-hub/transformer-1.3B-100B"
+
+ [job]
+ dump_folder = "exp"
+ print_args = true
+
+ [training]
+ batch_size = 32
+ seq_len = 2048
+ context_len = 2048
+ gradient_accumulation_steps = 1
+ steps = 20480
+ max_norm = 1.0
+ skip_nan_inf = true
+ data_parallel_replicate_degree = 1
+ data_parallel_shard_degree = -1
+ tensor_parallel_degree = 1
+ compile = false
+ dataset = "HuggingFaceFW/fineweb-edu"
+ dataset_name = "default"
+ num_workers = 32
+ pin_memory = false
+ persistent_workers = false
+ prefetch_factor = 2
+ seed = 42
+ varlen = false
+
+ [optimizer]
+ name = "AdamW"
+ eps = 1e-15
+ lr = 3e-4
+
+ [lr_scheduler]
+ warmup_steps = 1024
+ decay_type = "cosine"
+ lr_min = 0.1
+
+ [checkpoint]
+ enable_checkpoint = true
+ folder = "checkpoint"
+ interval_type = "steps"
+ interval = 2048
+ model_weights_only = false
+ export_dtype = "float32"
+ async_mode = "disabled"  # ["disabled", "async", "async_with_pinned_mem"]
+
+ [profiling]
+ enable_profiling = true
+ save_traces_folder = "profile_trace"
+ profile_freq = 512
+
+ [metrics]
+ log_freq = 32
+ enable_wandb = true
+
+ [experimental]
+ context_parallel_degree = 1
+ pipeline_parallel_degree = 1
+
+ [float8]
+ enable_fsdp_float8_all_gather = false
+ precompute_float8_dynamic_scale_for_fsdp = false
+
+ [activation_checkpoint]
+ mode = "none"
flame/train.py ADDED
@@ -0,0 +1,624 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import json
+ import os
+ import time
+ from datetime import timedelta
+
+ import fla  # noqa
+ import fla.models.gsa
+ import fla.models.routmem
+ import torch
+ from fla.modules.fused_linear_cross_entropy import FusedLinearCrossEntropyLoss
+ from fla.ops.utils import prepare_position_ids
+ from torch.distributed.elastic.multiprocessing.errors import record
+ from torchtitan.components.checkpoint import CheckpointManager
+ from torchtitan.components.ft import FTParallelDims, init_ft_manager
+ from torchtitan.components.loss import build_cross_entropy_loss
+ from torchtitan.components.lr_scheduler import build_lr_schedulers
+ from torchtitan.components.metrics import build_device_memory_monitor, build_metrics_processor, ensure_pp_loss_visible
+ from torchtitan.components.optimizer import build_optimizers
+ from torchtitan.distributed import ParallelDims
+ from torchtitan.distributed import utils as dist_utils
+ from torchtitan.protocols.model_converter import build_model_converters
+ from torchtitan.protocols.train_spec import TrainSpec, get_train_spec, register_train_spec
+ from torchtitan.tools import utils
+ from torchtitan.tools.logging import init_logger, logger
+ from torchtitan.tools.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+
+ import custom_models
+ from flame.components.checkpoint import TrainState
+ from flame.config_manager import JobConfig
+ from flame.data import build_dataloader, build_dataset
+ from flame.models.parallelize_fla import parallelize_fla
+ from flame.models.pipeline_fla import pipeline_fla
+ from flame.tools.utils import get_nparams_and_flops
+
+
+ def build_tokenizer(job_config: JobConfig) -> AutoTokenizer:
+     return AutoTokenizer.from_pretrained(job_config.model.tokenizer_path)
+
+
+ register_train_spec(
+     TrainSpec(
+         name="fla",
+         cls=AutoModelForCausalLM,
+         config=AutoConfig,
+         parallelize_fn=parallelize_fla,
+         pipelining_fn=pipeline_fla,
+         build_optimizers_fn=build_optimizers,
+         build_lr_schedulers_fn=build_lr_schedulers,
+         build_dataloader_fn=build_dataloader,
+         build_tokenizer_fn=build_tokenizer,
+         build_loss_fn=build_cross_entropy_loss,
+     )
+ )
+
+
+ # Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html
+ @record
+ def main(job_config: JobConfig):
+     logger.info(f"Starting job: {job_config.job.description}")
+
+     if job_config.experimental.custom_model_path:
+         utils.import_module_from_path(job_config.experimental.custom_model_path)
+
+     # used for colorful printing
+     color = utils.NoColor if job_config.metrics.disable_color_printing else utils.Color
+
+     if job_config.job.print_args:
+         logger.info(
+             f"{color.green}{json.dumps(job_config.to_dict(), indent=2, sort_keys=True)}{color.reset}"
+         )
+
+     # take control of garbage collection to avoid stragglers
+     gc_handler = utils.GarbageCollection(gc_freq=job_config.training.gc_freq)
+
+     device_module, device_type = utils.device_module, utils.device_type
+     device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}")
+     # Device has to be set before creating TorchFT manager.
+     device_module.set_device(device)
+     ft_manager = init_ft_manager(job_config)
+
+     # init distributed
+     world_size = int(os.environ["WORLD_SIZE"])
+     if not ft_manager.enabled:
+         parallel_dims = ParallelDims(
+             dp_shard=job_config.training.data_parallel_shard_degree,
+             dp_replicate=job_config.training.data_parallel_replicate_degree,
+             cp=job_config.experimental.context_parallel_degree,
+             tp=job_config.training.tensor_parallel_degree,
+             pp=job_config.experimental.pipeline_parallel_degree,
+             world_size=world_size,
+             enable_loss_parallel=not job_config.training.disable_loss_parallel,
+         )
+     else:
+         parallel_dims = FTParallelDims(
+             dp_shard=job_config.training.data_parallel_shard_degree,
+             dp_replicate=job_config.training.data_parallel_replicate_degree,
+             cp=job_config.experimental.context_parallel_degree,
+             tp=job_config.training.tensor_parallel_degree,
+             pp=job_config.experimental.pipeline_parallel_degree,
+             world_size=world_size,
+             enable_loss_parallel=not job_config.training.disable_loss_parallel,
+             ft_manager=ft_manager,
+         )
+     dist_utils.init_distributed(job_config)
+     # initialize device memory monitor and get peak flops for MFU calculation
+     device_memory_monitor = build_device_memory_monitor()
+     gpu_peak_flops = utils.get_peak_flops(device_memory_monitor.device_name)
+     logger.info(f"Peak FLOPS used for computing MFU: {gpu_peak_flops:.3e}")
+
+     # build meshes
+     world_mesh = parallel_dims.build_mesh(device_type=device_type)
+     if parallel_dims.dp_enabled:
+         dp_mesh = world_mesh["dp"]
+         dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank()
+     else:
+         dp_degree, dp_rank = 1, 0
+
+     if parallel_dims.pp_enabled:
+         raise NotImplementedError(
+             "Pipeline parallelism is not supported in this version"
+         )
+         """
+         ! TODO[flame]: We need to fix the pipeline parallelism for flame
+         [x] Match the key of models' components with the actual naming
+         [ ] Fix the post-init and tie-embedding for pipeline parallelism, HF's transformer automatically
+             forces to tie if head is None, we need to handle this case
+         [ ]
+         """
+         pp_mesh = world_mesh["pp"]
+
+     # Set random seed, and maybe enable deterministic mode (mainly for debugging, expect perf loss)
+     dist_utils.set_determinism(
+         world_mesh, device, job_config.training.seed, job_config.training.deterministic
+     )
+     train_spec = get_train_spec(job_config.model.name)
+
+     logger.info("Loading tokenizer...")
+     tokenizer = AutoTokenizer.from_pretrained(
+         job_config.model.tokenizer_path,
+         trust_remote_code=True,
+         model_max_length=int(1e10),
+     )
+     logger.info(f"{tokenizer}")
+     # Parenthesize the conditional so the dataset is still logged when dataset_name is None
+     logger.info(
+         f"Loading dataset {job_config.training.dataset}"
+         + (
+             f":{job_config.training.dataset_name}"
+             if job_config.training.dataset_name is not None
+             else ""
+         )
+     )
+     dataset = build_dataset(
+         dataset=job_config.training.dataset,
+         dataset_name=job_config.training.dataset_name,
+         dataset_split=job_config.training.dataset_split,
+         data_dir=job_config.training.data_dir,
+         data_files=job_config.training.data_files,
+         data_probs=job_config.training.data_probs,
+         streaming=job_config.training.streaming,
+         dp_degree=dp_degree,
+         num_workers=job_config.training.num_workers,
+         seed=job_config.training.seed,
+     )
+
+     logger.info("Building dataloader...")
+     dataloader = build_dataloader(
+         dataset=dataset,
+         tokenizer=tokenizer,
+         rank=dp_rank,
+         world_size=dp_degree,
+         batch_size=job_config.training.batch_size,
+         seq_len=job_config.training.seq_len,
+         context_len=job_config.training.context_len,
+         varlen=job_config.training.varlen,
+         num_workers=job_config.training.num_workers,
+         pin_memory=job_config.training.pin_memory,
+         persistent_workers=job_config.training.persistent_workers,
+         snapshot_every_n_steps=job_config.checkpoint.interval,
+     )
+
+     logger.info(f"Loading model config from {job_config.model.config}")
+     model_config = AutoConfig.from_pretrained(job_config.model.config)
+     # set the model configs from training inputs:
+     # 1. norm type to decide which norm layer to use
+     # 2. disable fused norm if TP is enabled
+     # 3. vocab size from tokenizer
+     # 4. context_len based on inputs
+     if parallel_dims.tp_enabled:
+         if model_config.fuse_norm:
+             logger.warning(
+                 f"{color.red}"
+                 f"Fused norm is not compatible with tensor parallelism. "
+                 f"Disabling it for now."
+                 f"{color.reset}"
+             )
+             model_config.fuse_norm = False
+     if parallel_dims.loss_parallel_enabled:
+         if model_config.fuse_linear_cross_entropy:
+             logger.warning(
+                 f"{color.red}"
+                 f"Loss parallel enabled. Disabling fused cross entropy for now."
+                 f"{color.reset}"
+             )
+             model_config.fuse_linear_cross_entropy = False
+     model_config.vocab_size = max(tokenizer.vocab_size, model_config.vocab_size)
+
+     logger.info(
+         f"Building model from the config\n{color.green}{model_config}{color.reset}"
+     )
+     with torch.device("meta"):
+         model = AutoModelForCausalLM.from_config(model_config)
+         if (
+             getattr(model_config, "fuse_linear_cross_entropy", False)
+             and FusedLinearCrossEntropyLoss is not None
+         ):
+             model.criterion = FusedLinearCrossEntropyLoss(
+                 num_chunks=8 // parallel_dims.tp
+             )
+         # defer weight initialization until after parallelisms are applied
+         model.apply(lambda m: setattr(m, "_is_hf_initialized", False))
+     logger.info(f"{color.blue}\n{model}{color.reset}\n")
+
+     # Build the collection of model converters. No-op if `model.converters` empty
+     model_converters = build_model_converters(job_config, parallel_dims)
+     model_converters.convert(model)
+
+     # calculate model size and flops per token
+     model_param_count, num_flops_per_token = get_nparams_and_flops(
+         model, model_config, job_config.training.context_len
+     )
+
+     # move sharded model to CPU/GPU and initialize weights via DTensor
+     if job_config.checkpoint.create_seed_checkpoint:
+         init_device = "cpu"
+     elif job_config.training.enable_cpu_offload:
+         init_device = "cpu"
+     else:
+         init_device = device_type
+
+     # apply parallelisms and initialization
+     if parallel_dims.pp_enabled:
+         # apply PT-D Pipeline Parallel
+         (
+             pp_schedule,
+             model_parts,
+             has_first_stage,
+             has_last_stage,
+         ) = train_spec.pipelining_fn(
+             model,
+             pp_mesh,
+             parallel_dims,
+             job_config,
+             device,
+             model_config,
+             train_spec.loss_fn,
+         )
+         # when PP is enabled, `model` obj is no longer used after this point, model_parts is used instead
+         del model
+
+         # For PP with looped schedules, each item in model_parts is one stage-model-chunk.
+         # We need to iterate through model_parts to apply SPMD parallelisms, compilation,
+         # optimizer, and checkpointing
+         for m in model_parts:
+             # apply SPMD-style PT-D techniques
+             train_spec.parallelize_fn(m, world_mesh, parallel_dims, job_config)
+             m.to_empty(device=init_device)
+             with torch.no_grad():
+                 m.post_init()
+             m.train()
+
+         # confirm that user will be able to view loss metrics on the console
+         ensure_pp_loss_visible(parallel_dims, job_config, color)
+     else:
+         # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
+         train_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config)
+         model.to_empty(device=init_device)
+         with torch.no_grad():
+             model.post_init()
+         model.train()
+
+         model_parts = [model]
+
+     device_mem_stats = device_memory_monitor.get_peak_stats()
+     logger.info(
+         f"{device_type.upper()} memory usage for model: "
+         f"{device_mem_stats.max_reserved_gib:.2f}GiB"
+         f"({device_mem_stats.max_reserved_pct:.2f}%)"
+     )
+
+     # build optimizer after applying parallelisms to the model
+     optimizers = train_spec.build_optimizers_fn(model_parts, job_config, ft_manager)
+     lr_schedulers = train_spec.build_lr_schedulers_fn(optimizers, job_config)
+     # Post optimizer step model converters hook.
+     # e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2
+     # where it issues a single all-reduce for all parameters at once for better performance
+     optimizers.register_step_post_hook(
+         lambda *args, **kwargs: model_converters.post_optimizer_hook(model_parts)
+     )
+
+     train_state = TrainState()
+
+     # load initial checkpoint
+     checkpoint = CheckpointManager(
+         dataloader=dataloader,
+         model_parts=model_parts,
+         optimizers=optimizers,
+         lr_schedulers=lr_schedulers,
+         states={"train_state": train_state},
+         job_config=job_config,
+         ft_manager=ft_manager,
+     )
+
+     if job_config.checkpoint.create_seed_checkpoint:
+         assert world_size == 1, (
+             "Must create seed checkpoint using a single device, to disable sharding"
+         )
+         assert job_config.checkpoint.enable_checkpoint, (
+             "Must enable checkpointing when creating a seed checkpoint"
+         )
+         checkpoint.save(curr_step=0, force=True)
+         logger.info("Created seed checkpoint")
+         return
+
+     checkpoint.load(step=job_config.checkpoint.load_step)
+     metric_logger = build_metrics_processor(job_config, parallel_dims)
+     # Set dependent attributes for metric_logger
+     metric_logger.num_flops_per_token = num_flops_per_token
+     metric_logger.optimizers = optimizers  # Pass optimizers if needed by logger logic
+     metric_logger.lr_schedulers = (
+         lr_schedulers  # Pass schedulers if needed by logger logic
+     )
+
+     # plot losses loaded from checkpoint (if any) to TensorBoard
+     # NOTE: Loss info after the last log step before checkpoint saving will not be plotted.
+     # This can be avoided by setting checkpoint.interval to be a multiple of metrics.log_freq
+     if train_state.step > 0 and len(metric_logger.data_loading_times) > 0:
+         for idx, step in enumerate(train_state.log_steps):
+             metric_logger.log(
+                 step,
+                 global_avg_loss=train_state.global_avg_losses[idx],
+                 global_max_loss=train_state.global_max_losses[idx],
+             )
+
+     data_iterator = iter(dataloader)
+
+     train_context = dist_utils.get_train_context(
+         parallel_dims.loss_parallel_enabled,
+         job_config.experimental.enable_compiled_autograd,
+     )
+     maybe_enable_amp = dist_utils.maybe_enable_amp(
+         parallel_dims,
+         job_config.training.mixed_precision_param,
+         device_type,
+     )
+
+     # variables used to keep info for metrics logging
+     device_memory_monitor.reset_peak_stats()
+
+     global_batch_size = (
+         job_config.training.batch_size
+         * dp_degree
+         * job_config.training.gradient_accumulation_steps
+     )
+     num_tokens_per_step = global_batch_size * job_config.training.seq_len
+     # train loop
+     logger.info(f"{color.red}***** Running training *****{color.reset}")
+     logger.info(f"{color.green} Training starts at step {train_state.step + 1}")
+     logger.info(
+         f"{color.green} Number of tokens per sequence = {job_config.training.seq_len:,}"
+     )
+     logger.info(
+         f"{color.green} Gradient Accumulation steps = {job_config.training.gradient_accumulation_steps}"
+     )
+     logger.info(
+         f"{color.green} Instantaneous batch size (per device) = {job_config.training.batch_size:,}"
+     )
+     logger.info(
+         f"{color.green} Global batch size (w. parallel, distributed & accumulation) = {global_batch_size:,}"
+         f" ({num_tokens_per_step:,} tokens)"
+     )
+     logger.info(
+         f"{color.green} Total optimization steps = {job_config.training.steps:,} "
+         f"({job_config.training.steps * num_tokens_per_step:,} tokens)"
+     )
+     logger.info(
+         f"{color.green} Warmup steps = {job_config.lr_scheduler.warmup_steps:,}"
+         f" ({job_config.lr_scheduler.warmup_steps * num_tokens_per_step:,} tokens)"
+     )
+     logger.info(
+         f"{color.green} Number of parameters = {model_param_count:,} {color.reset}"
+     )
+
+     with (
+         maybe_enable_profiling(
+             job_config, global_step=train_state.step
+         ) as torch_profiler,
+         maybe_enable_memory_snapshot(
+             job_config, global_step=train_state.step
+         ) as memory_profiler,
+     ):
+         while train_state.step < job_config.training.steps:
+             train_state.step += 1
+             gc_handler.run(train_state.step)
+
+             optimizers.zero_grad()
+
+             losses = []
+             # do gradient accumulation if enabled
+             for _ in range(job_config.training.gradient_accumulation_steps):
+                 # get batch
+                 data_load_start = time.perf_counter()
+                 batch = next(data_iterator)
+                 input_ids, labels = batch["input_ids"], batch["labels"]
+
+                 # Update metrics processor state before forward/backward
+                 metric_logger.ntokens_since_last_log += labels.numel()
+                 metric_logger.data_loading_times.append(
+                     time.perf_counter() - data_load_start
+                 )
+
+                 input_ids = input_ids.to(device_type)
+
+                 """
+                 TODO[flame]: We need to carefully handle the position_ids for TP/CP.
+                 Depending on the model's PE, the position_ids might be different.
+
+                 e.g. for TP
+                     For RoPE, all ranks have the same position_ids. [FOR HF model]
+                     For sinusoidal, each rank has the corresponding chunked position_ids. [FOR HF model]
+
+                 e.g. for CP, [optional_context_parallel_ctx should automatically distribute the position_ids]
+                     Each rank has the corresponding chunked position_ids. [FOR All model]
+
+                 """
+                 labels = labels.to(device_type)
+                 cu_seqlens = (
+                     batch["cu_seqlens"].to(device_type)
+                     if "cu_seqlens" in batch
+                     else None
+                 )
+                 if cu_seqlens is not None:
+                     position_ids = prepare_position_ids(cu_seqlens).to(torch.int32)
+                 else:
+                     position_ids = (
+                         torch.arange(0, input_ids.shape[1], device=device_type)
+                         .repeat(input_ids.shape[0], 1)
+                         .to(torch.int32)
+                     )
+                 # apply context parallelism if cp is enabled
+                 # ensure CP handles the separate freqs_cis buffer for each pp stage
+                 optional_context_parallel_ctx = (
+                     dist_utils.create_context_parallel_ctx(
+                         cp_mesh=world_mesh["cp"],
+                         cp_buffers=[input_ids, labels, position_ids],
+                         cp_seq_dims=[1, 1, 1],
+                         cp_no_restore_buffers={input_ids, labels, position_ids},
+                         cp_rotate_method=job_config.experimental.context_parallel_rotate_method,
+                     )
+                     if parallel_dims.cp_enabled
+                     else None
+                 )
+
+                 # #! TODO[flame], we should distribute the position_ids as well with CP
+                 if parallel_dims.pp_enabled:
+                     raise NotImplementedError(
+                         "Pipeline parallelism is not supported in this version"
+                     )
+                     # Pipeline Parallel forward / backward inside step() call
+                     with train_context(optional_context_parallel_ctx):
+                         targets, losses = (
+                             (labels, []) if has_last_stage else (None, None)
+                         )
+
+                         if has_first_stage:
+                             pp_schedule.step(input_ids, target=targets, losses=losses)
+                         else:
+                             pp_schedule.step(target=targets, losses=losses)
+
+                     # accumulate losses across pipeline microbatches
+                     # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
+                     loss = (
+                         torch.mean(torch.stack(losses)).to(device)
+                         if has_last_stage
+                         else torch.tensor([-1.0], device=device)
+                     )
+                 else:
+                     # Non-PP forward / backward
+                     with train_context(optional_context_parallel_ctx):
+                         with maybe_enable_amp:
+                             output = model(
+                                 input_ids=input_ids,
+                                 labels=labels,
+                                 position_ids=position_ids,
+                                 cu_seqlens=cu_seqlens,
+                             )
+                         loss = (
+                             output.loss
+                             / job_config.training.gradient_accumulation_steps
+                         )
+                         loss.backward()
+
+                 losses.append(loss)
+             loss = sum(losses)
+
+             # clip gradients
+             grad_norm = dist_utils.clip_grad_norm_(
+                 [p for m in model_parts for p in m.parameters()],
+                 job_config.training.max_norm,
+                 foreach=True,
+                 pp_mesh=pp_mesh if parallel_dims.pp_enabled else None,
+             )
+
+             # optimizer step
+             checkpoint.maybe_wait_for_staging()
+             if job_config.training.skip_nan_inf and (
+                 grad_norm.isnan() or grad_norm.isinf()
+             ):
+                 logger.warning(
+                     f"Skipping optimizer step - detected invalid gradient norm: {grad_norm:.4f}"
+                 )
+                 optimizers.zero_grad()
+                 train_state.skipped_step += 1
+             else:
+                 optimizers.step()
+             lr_schedulers.step()
+
+             # log metrics - Use MetricsProcessor
+             if metric_logger.should_log(train_state.step):
+                 if (
+                     parallel_dims.dp_replicate_enabled
+                     or parallel_dims.dp_shard_enabled
+                     or parallel_dims.cp_enabled
+                 ):
+                     loss = loss.detach()
+                     # Use dist_mean/max on the accumulated loss for the step
+                     global_avg_loss, global_max_loss = (
+                         dist_utils.dist_mean(
+                             loss,
+                             world_mesh["dp_cp"],
+                         ),
+                         dist_utils.dist_max(
+                             loss,
+                             world_mesh["dp_cp"],
+                         ),
+                     )
+                 else:
+                     # Scale back the loss before logging
+                     global_avg_loss = global_max_loss = loss.item()
+
+                 # Update train state tokens and elapsed time
+                 time_now = time.perf_counter()
+                 time_delta = (
+                     time_now - metric_logger.time_last_log
+                 )  # Use metric_logger's time
+                 train_state.token += (
+                     metric_logger.ntokens_since_last_log  # Use tokens tracked by metric_logger
+                     * parallel_dims.world_size
+                     / parallel_dims.non_data_parallel_size
+                 )
+                 train_state.elapsed += timedelta(seconds=time_delta)
+                 train_state.log_steps.append(train_state.step)
+                 train_state.global_avg_losses.append(global_avg_loss)
+                 train_state.global_max_losses.append(global_max_loss)
+
+                 # Log using the metric processor
+                 last_lr = lr_schedulers.schedulers[0].get_last_lr()[0]
+                 eta = (
+                     train_state.elapsed
+                     * (job_config.training.steps - train_state.step)
+                     / train_state.step
+                 )
+                 metric_logger.log(
+                     train_state.step,
+                     global_avg_loss,
+                     global_max_loss,
+                     extra_metrics={
+                         "optimizer/lr": last_lr,
+                         "optimizer/grad_norm": grad_norm.item(),
+                         "optimizer/skipped_step": train_state.skipped_step,
+                     },
+                 )
+
+                 logger.info(
+                     f"{color.blue}lr: {last_lr:.4e} gnorm: {grad_norm:5.2f} "
+                     f"{color.magenta}[{str(train_state.elapsed).split('.')[0]:>8}<{str(eta).split('.')[0]:>8}]{color.reset}"
+                 )
+
+             checkpoint.save(
+                 train_state.step, force=(train_state.step == job_config.training.steps)
+             )
+
+             # signal the profiler that the next profiling step has started
+             if torch_profiler:
+                 torch_profiler.step()
+             if memory_profiler:
+                 memory_profiler.step()
+
+             # reduce timeout after first train step for faster signal
+             # (assuming lazy init and compilation are finished)
+             if train_state.step == 1:
+                 dist_utils.set_pg_timeouts(
+                     timeout=timedelta(seconds=job_config.comm.train_timeout_seconds),
+                     world_mesh=world_mesh,
+                 )
+
+     if torch.distributed.get_rank() == 0:
+         logger.info("Sleeping 2 seconds for other ranks to complete")
+         time.sleep(2)
+
+     metric_logger.close()
+     logger.info("Training completed")
+
+
+ if __name__ == "__main__":
+     init_logger()
+     config = JobConfig()
+     config.parse_args()
+     main(config)
+     torch.distributed.destroy_process_group()
flame/utils/__init__.py ADDED
File without changes
flame/utils/convert_dcp_to_hf.py ADDED
@@ -0,0 +1,67 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+ import argparse
+ import io
+ import os
+ import tempfile
+ from datetime import timedelta
+
+ import fla  # noqa
+ import fla.models.gsa
+
+ import torch
+ import torch.serialization
+ from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
+ from torchtitan.tools.logging import init_logger, logger
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+
+ import custom_models
+
+
+ @torch.inference_mode()
+ def save_pretrained(
+     path: str,
+     step: int,
+     config: str,
+     tokenizer: str
+ ):
+     logger.info(f"Loading the config from {config}")
+     config = AutoConfig.from_pretrained(config, trust_remote_code=True)
+
+     logger.info(f"Saving the config to {path}")
+     config.save_pretrained(path)
+     logger.info(f"Loading the tokenizer from {tokenizer}")
+     tokenizer = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True)
+     logger.info(f"Saving the tokenizer to {path}")
+     tokenizer.save_pretrained(path)
+
+     with tempfile.TemporaryDirectory() as tmpdir:
+         checkpoint = os.path.join(path, f'checkpoint/step-{step}')
+         checkpoint_path = os.path.join(tmpdir, 'checkpoint.pt')
+         logger.info(f"Saving the distributed checkpoint to {checkpoint_path}")
+         dcp_to_torch_save(checkpoint, checkpoint_path)
+
+         logger.info(f"Initializing the model from config\n{config}")
+         model = AutoModelForCausalLM.from_config(config)
+         logger.info(model)
+         logger.info("Loading state dict from the checkpoint")
+
+         # Add datetime.timedelta and io.BytesIO to safe globals
+         torch.serialization.add_safe_globals([timedelta, io.BytesIO])
+         # torch.load now with default weights_only=True will work
+         model.load_state_dict(torch.load(checkpoint_path, map_location='cpu')['model'])
+
+         logger.info(f"Saving the model to {path}")
+         model.save_pretrained(path)
+
+
+ if __name__ == "__main__":
+     init_logger()
+     parser = argparse.ArgumentParser("Convert DCP format model weights to huggingface-style.")
+     parser.add_argument("--path", type=str, required=True)
+     parser.add_argument("--step", type=int, required=True)
+     parser.add_argument("--config", type=str, required=True)
+     parser.add_argument("--tokenizer", type=str, required=True)
+     args = parser.parse_args()
+     save_pretrained(args.path, args.step, args.config, args.tokenizer)
flame/utils/convert_hf_to_dcp.py ADDED
@@ -0,0 +1,36 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+ import argparse
+ from pathlib import Path
+
+ import torch
+ import torch.distributed.checkpoint as DCP
+ from transformers import AutoModelForCausalLM
+
+ import fla  # noqa
+ import fla.models.gsa
+ import fla.models.routmem
+ from torchtitan.tools.logging import init_logger, logger
+
+
+ @torch.inference_mode()
+ def convert_hf_weights(model: str, checkpoint: Path):
+     logger.info(f"Loading model from {model}")
+     model = AutoModelForCausalLM.from_pretrained(model)
+     state_dict = model.state_dict()
+
+     logger.info(f"Writing to DCP at '{checkpoint}'")
+     # `checkpoint` is a pathlib.Path (see the --checkpoint argument), so mkdir is available
+     checkpoint.mkdir(parents=True, exist_ok=True)
+     storage_writer = DCP.filesystem.FileSystemWriter(checkpoint, thread_count=8)
+     DCP.save(state_dict, storage_writer=storage_writer)
+
+
+ if __name__ == "__main__":
+     init_logger()
+     parser = argparse.ArgumentParser(description="Convert huggingface-style model weights to DCP format.")
+     parser.add_argument("--model", type=str, required=True)
+     parser.add_argument("--checkpoint", type=Path, required=True)
+     args = parser.parse_args()
+
+     convert_hf_weights(args.model, args.checkpoint)
flame/utils/preprocess.py ADDED
@@ -0,0 +1,122 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ import argparse
5
+ from typing import Any, Dict, List
6
+
7
+ from transformers import AutoTokenizer, PreTrainedTokenizer
8
+
9
+ from flame.data import build_dataset
10
+ from torchtitan.tools.logging import init_logger, logger
11
+
12
+
13
+ def tokenize(
14
+ examples: Dict[str, List[Any]],
15
+ tokenizer: PreTrainedTokenizer,
16
+ ) -> Dict:
17
+ if 'text' in examples:
18
+ samples = examples['text']
19
+ elif 'content' in examples:
20
+ samples = examples['content']
21
+ else:
22
+         raise ValueError(f'No "text" or "content" field found in examples:\n{examples}')
+     input_ids = tokenizer(samples)['input_ids']
+     bits_per_token = [len(sample.encode(encoding='utf-8')) * 8 / len(input_ids[i]) for i, sample in enumerate(samples)]
+     return {'input_ids': input_ids, 'bits_per_token': bits_per_token}
+
+
+ if __name__ == '__main__':
+     init_logger()
+     parser = argparse.ArgumentParser(description='Preprocess the dataset.')
+     parser.add_argument(
+         '--dataset',
+         default='HuggingFaceFW/fineweb-edu',
+         help='Dataset to use, with comma separated values',
+     )
+     parser.add_argument(
+         '--dataset_name',
+         default='sample-100BT',
+         help='The name of the dataset config, with comma separated values if provided',
+     )
+     parser.add_argument(
+         '--dataset_split',
+         default='train',
+         help='Dataset split to use, with comma separated values if provided',
+     )
+     parser.add_argument(
+         '--data_dir',
+         default=None,
+         help='Data dirs to use, with comma separated values if provided',
+     )
+     parser.add_argument(
+         '--data_files',
+         default=None,
+         help='Data files to use, with comma separated values if provided',
+     )
+     parser.add_argument(
+         '--data_probs',
+         default=None,
+         help='Data sampling probabilities, with comma separated values if provided',
+     )
+     parser.add_argument(
+         '--streaming',
+         action='store_true',
+         help='Whether to use streaming mode',
+     )
+     parser.add_argument(
+         '--num_workers',
+         type=int,
+         default=64,
+         help='Number of workers to use for preprocessing',
+     )
+     parser.add_argument(
+         '--seed',
+         type=int,
+         default=42,
+         help='Random seed for preprocessing',
+     )
+     parser.add_argument(
+         '--path',
+         default='data',
+         help='Path to save the preprocessed dataset',
+     )
+     parser.add_argument(
+         '--tokenizer',
+         default='fla-hub/transformer-1.3B-100B',
+         help='Tokenizer to use',
+     )
+     parser.add_argument(
+         "--batch_size",
+         type=int,
+         default=2048,
+         help="Batch size for processing"
+     )
+     args = parser.parse_args()
+
+     logger.info(f'Loading tokenizer {args.tokenizer}')
+     tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
+     logger.info(f'{tokenizer}')
+     logger.info(f'Loading dataset {args.dataset} {args.dataset_name} {args.dataset_split}')
+     dataset = build_dataset(
+         dataset=args.dataset,
+         dataset_name=args.dataset_name,
+         dataset_split=args.dataset_split,
+         data_dir=args.data_dir,
+         data_files=args.data_files,
+         data_probs=args.data_probs,
+         streaming=args.streaming,
+         num_workers=args.num_workers,
+         seed=args.seed,
+     )
+     logger.info(f'Tokenizing and processing the dataset with batch size {args.batch_size}')
+     dataset = dataset.map(
+         lambda examples: tokenize(examples, tokenizer),
+         batched=True,
+         batch_size=args.batch_size,
+         remove_columns=list(next(iter(dataset)).keys()),
+         num_proc=args.num_workers,
+         desc="Running tokenizer on dataset"
+     )
+     logger.info(f'{dataset}')
+     logger.info(f'Saving tokenized dataset to {args.path}')
+     dataset.save_to_disk(args.path)
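The `bits_per_token` value computed in `tokenize` above is each sample's UTF-8 size in bits divided by its token count. Below is a minimal sketch of that arithmetic. It assumes a hypothetical whitespace tokenizer, `toy_tokenize`, as a stand-in for `AutoTokenizer` so that it runs without downloading anything. The guard against empty token lists is an addition of this sketch; the script as committed has no such guard.

```python
from typing import Callable, Dict, List


def toy_tokenize(text: str) -> List[int]:
    """Hypothetical stand-in tokenizer: one fake token id per whitespace-separated word."""
    return [len(word) for word in text.split()]


def bits_per_token(samples: List[str], encode: Callable[[str], List[int]]) -> Dict[str, list]:
    """Mirror the script's ratio: UTF-8 bits of a sample divided by its token count."""
    input_ids = [encode(sample) for sample in samples]
    ratios = []
    for sample, ids in zip(samples, input_ids):
        n_bits = len(sample.encode('utf-8')) * 8
        # Guard added for this sketch: an empty token list would divide by zero.
        ratios.append(n_bits / len(ids) if ids else 0.0)
    return {'input_ids': input_ids, 'bits_per_token': ratios}


batch = bits_per_token(['hello world', 'naïve café', ''], toy_tokenize)
print(batch['bits_per_token'])
```

Non-ASCII characters such as `ï` and `é` occupy two bytes in UTF-8, so the ratio tracks encoded size rather than character count.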
generation_config.json ADDED
@@ -0,0 +1,6 @@
+ {
+   "_from_model_config": true,
+   "bos_token_id": 151643,
+   "eos_token_id": 151645,
+   "transformers_version": "4.57.3"
+ }
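The two token IDs in this config conventionally mark the start of a sequence (`bos_token_id`) and the point where decoding stops (`eos_token_id`). Below is a minimal sketch of how a generated ID sequence could be cut at the end-of-sequence token. It assumes the JSON is available as a string, and `truncate_at_eos` is a hypothetical helper written for illustration, not a `transformers` function.

```python
import json
from typing import List

# Contents of generation_config.json as added in this commit.
GENERATION_CONFIG = json.loads("""
{
  "_from_model_config": true,
  "bos_token_id": 151643,
  "eos_token_id": 151645,
  "transformers_version": "4.57.3"
}
""")


def truncate_at_eos(token_ids: List[int], eos_token_id: int) -> List[int]:
    """Hypothetical helper: keep tokens up to and including the first EOS token."""
    if eos_token_id in token_ids:
        return token_ids[: token_ids.index(eos_token_id) + 1]
    # No EOS produced: return the sequence unchanged.
    return token_ids


# Made-up ID sequence: BOS, two ordinary tokens, EOS, then trailing tokens.
generated = [151643, 11, 22, 151645, 33, 44]
trimmed = truncate_at_eos(generated, GENERATION_CONFIG['eos_token_id'])
print(trimmed)
```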
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors.index.json ADDED
@@ -0,0 +1,410 @@
1
+ {
2
+ "metadata": {
3
+ "total_parameters": 14409815040,
4
+ "total_size": 28819630080
5
+ },
6
+ "weight_map": {
7
+ "model.embeddings.weight": "model-00001-of-00006.safetensors",
8
+ "model.layers.0.attn.f_proj.weight": "model-00001-of-00006.safetensors",
9
+ "model.layers.0.attn.k_proj.weight": "model-00001-of-00006.safetensors",
10
+ "model.layers.0.attn.o_proj.weight": "model-00001-of-00006.safetensors",
11
+ "model.layers.0.attn.q_proj.weight": "model-00001-of-00006.safetensors",
12
+ "model.layers.0.attn.v_proj.weight": "model-00001-of-00006.safetensors",
13
+ "model.layers.0.attn_norm.weight": "model-00001-of-00006.safetensors",
14
+ "model.layers.0.mlp.down_proj.weight": "model-00001-of-00006.safetensors",
15
+ "model.layers.0.mlp.gate_proj.weight": "model-00001-of-00006.safetensors",
16
+ "model.layers.0.mlp.up_proj.weight": "model-00001-of-00006.safetensors",
17
+ "model.layers.0.mlp_norm.weight": "model-00001-of-00006.safetensors",
18
+ "model.layers.1.attn.f_proj.weight": "model-00001-of-00006.safetensors",
19
+ "model.layers.1.attn.k_proj.weight": "model-00001-of-00006.safetensors",
20
+ "model.layers.1.attn.o_proj.weight": "model-00001-of-00006.safetensors",
21
+ "model.layers.1.attn.q_proj.weight": "model-00001-of-00006.safetensors",
22
+ "model.layers.1.attn.v_proj.weight": "model-00001-of-00006.safetensors",
23
+ "model.layers.1.attn_norm.weight": "model-00001-of-00006.safetensors",
24
+ "model.layers.1.mlp.down_proj.weight": "model-00001-of-00006.safetensors",
25
+ "model.layers.1.mlp.gate_proj.weight": "model-00001-of-00006.safetensors",
26
+ "model.layers.1.mlp.up_proj.weight": "model-00001-of-00006.safetensors",
27
+ "model.layers.1.mlp_norm.weight": "model-00001-of-00006.safetensors",
28
+ "model.layers.10.attn.f_proj.weight": "model-00002-of-00006.safetensors",
29
+ "model.layers.10.attn.k_proj.weight": "model-00002-of-00006.safetensors",
30
+ "model.layers.10.attn.o_proj.weight": "model-00002-of-00006.safetensors",
31
+ "model.layers.10.attn.q_proj.weight": "model-00002-of-00006.safetensors",
32
+ "model.layers.10.attn.v_proj.weight": "model-00002-of-00006.safetensors",
33
+ "model.layers.10.attn_norm.weight": "model-00002-of-00006.safetensors",
34
+ "model.layers.10.mlp.down_proj.weight": "model-00002-of-00006.safetensors",
35
+ "model.layers.10.mlp.gate_proj.weight": "model-00002-of-00006.safetensors",
36
+ "model.layers.10.mlp.up_proj.weight": "model-00002-of-00006.safetensors",
37
+ "model.layers.10.mlp_norm.weight": "model-00002-of-00006.safetensors",
38
+ "model.layers.11.attn.f_proj.weight": "model-00002-of-00006.safetensors",
39
+ "model.layers.11.attn.k_proj.weight": "model-00002-of-00006.safetensors",
40
+ "model.layers.11.attn.o_proj.weight": "model-00002-of-00006.safetensors",
41
+ "model.layers.11.attn.q_proj.weight": "model-00002-of-00006.safetensors",
42
+ "model.layers.11.attn.v_proj.weight": "model-00002-of-00006.safetensors",
43
+ "model.layers.11.attn_norm.weight": "model-00002-of-00006.safetensors",
44
+ "model.layers.11.mlp.down_proj.weight": "model-00002-of-00006.safetensors",
45
+ "model.layers.11.mlp.gate_proj.weight": "model-00002-of-00006.safetensors",
46
+ "model.layers.11.mlp.up_proj.weight": "model-00002-of-00006.safetensors",
47
+ "model.layers.11.mlp_norm.weight": "model-00002-of-00006.safetensors",
48
+ "model.layers.12.attn.f_proj.weight": "model-00002-of-00006.safetensors",
49
+ "model.layers.12.attn.k_proj.weight": "model-00002-of-00006.safetensors",
50
+ "model.layers.12.attn.o_proj.weight": "model-00002-of-00006.safetensors",
51
+ "model.layers.12.attn.q_proj.weight": "model-00002-of-00006.safetensors",
52
+ "model.layers.12.attn.v_proj.weight": "model-00002-of-00006.safetensors",
53
+ "model.layers.12.attn_norm.weight": "model-00002-of-00006.safetensors",
54
+ "model.layers.12.mlp.down_proj.weight": "model-00003-of-00006.safetensors",
55
+ "model.layers.12.mlp.gate_proj.weight": "model-00003-of-00006.safetensors",
56
+ "model.layers.12.mlp.up_proj.weight": "model-00003-of-00006.safetensors",
57
+ "model.layers.12.mlp_norm.weight": "model-00002-of-00006.safetensors",
58
+ "model.layers.13.attn.f_proj.weight": "model-00003-of-00006.safetensors",
59
+ "model.layers.13.attn.k_proj.weight": "model-00003-of-00006.safetensors",
60
+ "model.layers.13.attn.o_proj.weight": "model-00003-of-00006.safetensors",
61
+ "model.layers.13.attn.q_proj.weight": "model-00003-of-00006.safetensors",
62
+ "model.layers.13.attn.v_proj.weight": "model-00003-of-00006.safetensors",
63
+ "model.layers.13.attn_norm.weight": "model-00003-of-00006.safetensors",
64
+ "model.layers.13.mlp.down_proj.weight": "model-00003-of-00006.safetensors",
65
+ "model.layers.13.mlp.gate_proj.weight": "model-00003-of-00006.safetensors",
66
+ "model.layers.13.mlp.up_proj.weight": "model-00003-of-00006.safetensors",
67
+ "model.layers.13.mlp_norm.weight": "model-00003-of-00006.safetensors",
68
+ "model.layers.14.attn.f_proj.weight": "model-00003-of-00006.safetensors",
69
+ "model.layers.14.attn.k_proj.weight": "model-00003-of-00006.safetensors",
70
+ "model.layers.14.attn.o_proj.weight": "model-00003-of-00006.safetensors",
71
+ "model.layers.14.attn.q_proj.weight": "model-00003-of-00006.safetensors",
72
+ "model.layers.14.attn.v_proj.weight": "model-00003-of-00006.safetensors",
73
+ "model.layers.14.attn_norm.weight": "model-00003-of-00006.safetensors",
74
+ "model.layers.14.mlp.down_proj.weight": "model-00003-of-00006.safetensors",
75
+ "model.layers.14.mlp.gate_proj.weight": "model-00003-of-00006.safetensors",
76
+ "model.layers.14.mlp.up_proj.weight": "model-00003-of-00006.safetensors",
77
+ "model.layers.14.mlp_norm.weight": "model-00003-of-00006.safetensors",
78
+ "model.layers.15.attn.f_proj.weight": "model-00003-of-00006.safetensors",
79
+ "model.layers.15.attn.k_proj.weight": "model-00003-of-00006.safetensors",
80
+ "model.layers.15.attn.o_proj.weight": "model-00003-of-00006.safetensors",
81
+ "model.layers.15.attn.q_proj.weight": "model-00003-of-00006.safetensors",
82
+ "model.layers.15.attn.v_proj.weight": "model-00003-of-00006.safetensors",
83
+ "model.layers.15.attn_norm.weight": "model-00003-of-00006.safetensors",
84
+ "model.layers.15.mlp.down_proj.weight": "model-00003-of-00006.safetensors",
85
+ "model.layers.15.mlp.gate_proj.weight": "model-00003-of-00006.safetensors",
86
+ "model.layers.15.mlp.up_proj.weight": "model-00003-of-00006.safetensors",
87
+ "model.layers.15.mlp_norm.weight": "model-00003-of-00006.safetensors",
88
+ "model.layers.16.attn.f_proj.weight": "model-00003-of-00006.safetensors",
89
+ "model.layers.16.attn.k_proj.weight": "model-00003-of-00006.safetensors",
90
+ "model.layers.16.attn.o_proj.weight": "model-00003-of-00006.safetensors",
91
+ "model.layers.16.attn.q_proj.weight": "model-00003-of-00006.safetensors",
92
+ "model.layers.16.attn.v_proj.weight": "model-00003-of-00006.safetensors",
93
+ "model.layers.16.attn_norm.weight": "model-00003-of-00006.safetensors",
94
+ "model.layers.16.mlp.down_proj.weight": "model-00003-of-00006.safetensors",
95
+ "model.layers.16.mlp.gate_proj.weight": "model-00003-of-00006.safetensors",
96
+ "model.layers.16.mlp.up_proj.weight": "model-00003-of-00006.safetensors",
97
+ "model.layers.16.mlp_norm.weight": "model-00003-of-00006.safetensors",
98
+ "model.layers.17.attn.f_proj.weight": "model-00003-of-00006.safetensors",
99
+ "model.layers.17.attn.k_proj.weight": "model-00003-of-00006.safetensors",
100
+ "model.layers.17.attn.o_proj.weight": "model-00003-of-00006.safetensors",
101
+ "model.layers.17.attn.q_proj.weight": "model-00003-of-00006.safetensors",
102
+ "model.layers.17.attn.v_proj.weight": "model-00003-of-00006.safetensors",
103
+ "model.layers.17.attn_norm.weight": "model-00003-of-00006.safetensors",
104
+ "model.layers.17.mlp.down_proj.weight": "model-00003-of-00006.safetensors",
105
+ "model.layers.17.mlp.gate_proj.weight": "model-00003-of-00006.safetensors",
106
+ "model.layers.17.mlp.up_proj.weight": "model-00003-of-00006.safetensors",
107
+ "model.layers.17.mlp_norm.weight": "model-00003-of-00006.safetensors",
108
+ "model.layers.18.attn.f_proj.weight": "model-00003-of-00006.safetensors",
109
+ "model.layers.18.attn.k_proj.weight": "model-00003-of-00006.safetensors",
110
+ "model.layers.18.attn.o_proj.weight": "model-00003-of-00006.safetensors",
111
+ "model.layers.18.attn.q_proj.weight": "model-00003-of-00006.safetensors",
112
+ "model.layers.18.attn.v_proj.weight": "model-00003-of-00006.safetensors",
113
+ "model.layers.18.attn_norm.weight": "model-00003-of-00006.safetensors",
114
+ "model.layers.18.mlp.down_proj.weight": "model-00003-of-00006.safetensors",
115
+ "model.layers.18.mlp.gate_proj.weight": "model-00003-of-00006.safetensors",
116
+ "model.layers.18.mlp.up_proj.weight": "model-00003-of-00006.safetensors",
117
+ "model.layers.18.mlp_norm.weight": "model-00003-of-00006.safetensors",
118
+ "model.layers.19.attn.f_proj.weight": "model-00003-of-00006.safetensors",
119
+ "model.layers.19.attn.k_proj.weight": "model-00003-of-00006.safetensors",
120
+ "model.layers.19.attn.o_proj.weight": "model-00003-of-00006.safetensors",
121
+ "model.layers.19.attn.q_proj.weight": "model-00003-of-00006.safetensors",
122
+ "model.layers.19.attn.v_proj.weight": "model-00003-of-00006.safetensors",
123
+ "model.layers.19.attn_norm.weight": "model-00003-of-00006.safetensors",
124
+ "model.layers.19.mlp.down_proj.weight": "model-00004-of-00006.safetensors",
125
+ "model.layers.19.mlp.gate_proj.weight": "model-00003-of-00006.safetensors",
126
+ "model.layers.19.mlp.up_proj.weight": "model-00004-of-00006.safetensors",
127
+ "model.layers.19.mlp_norm.weight": "model-00003-of-00006.safetensors",
128
+ "model.layers.2.attn.f_proj.weight": "model-00001-of-00006.safetensors",
129
+ "model.layers.2.attn.k_proj.weight": "model-00001-of-00006.safetensors",
130
+ "model.layers.2.attn.o_proj.weight": "model-00001-of-00006.safetensors",
131
+ "model.layers.2.attn.q_proj.weight": "model-00001-of-00006.safetensors",
132
+ "model.layers.2.attn.v_proj.weight": "model-00001-of-00006.safetensors",
133
+ "model.layers.2.attn_norm.weight": "model-00001-of-00006.safetensors",
134
+ "model.layers.2.mlp.down_proj.weight": "model-00001-of-00006.safetensors",
135
+ "model.layers.2.mlp.gate_proj.weight": "model-00001-of-00006.safetensors",
136
+ "model.layers.2.mlp.up_proj.weight": "model-00001-of-00006.safetensors",
137
+ "model.layers.2.mlp_norm.weight": "model-00001-of-00006.safetensors",
138
+ "model.layers.20.attn.f_proj.weight": "model-00004-of-00006.safetensors",
139
+ "model.layers.20.attn.k_proj.weight": "model-00004-of-00006.safetensors",
140
+ "model.layers.20.attn.o_proj.weight": "model-00004-of-00006.safetensors",
141
+ "model.layers.20.attn.q_proj.weight": "model-00004-of-00006.safetensors",
142
+ "model.layers.20.attn.v_proj.weight": "model-00004-of-00006.safetensors",
143
+ "model.layers.20.attn_norm.weight": "model-00004-of-00006.safetensors",
144
+ "model.layers.20.mlp.down_proj.weight": "model-00004-of-00006.safetensors",
145
+ "model.layers.20.mlp.gate_proj.weight": "model-00004-of-00006.safetensors",
146
+ "model.layers.20.mlp.up_proj.weight": "model-00004-of-00006.safetensors",
147
+ "model.layers.20.mlp_norm.weight": "model-00004-of-00006.safetensors",
148
+ "model.layers.21.attn.f_proj.weight": "model-00004-of-00006.safetensors",
149
+ "model.layers.21.attn.k_proj.weight": "model-00004-of-00006.safetensors",
150
+ "model.layers.21.attn.o_proj.weight": "model-00004-of-00006.safetensors",
151
+ "model.layers.21.attn.q_proj.weight": "model-00004-of-00006.safetensors",
152
+ "model.layers.21.attn.v_proj.weight": "model-00004-of-00006.safetensors",
153
+ "model.layers.21.attn_norm.weight": "model-00004-of-00006.safetensors",
154
+ "model.layers.21.mlp.down_proj.weight": "model-00004-of-00006.safetensors",
155
+ "model.layers.21.mlp.gate_proj.weight": "model-00004-of-00006.safetensors",
156
+ "model.layers.21.mlp.up_proj.weight": "model-00004-of-00006.safetensors",
157
+ "model.layers.21.mlp_norm.weight": "model-00004-of-00006.safetensors",
158
+ "model.layers.22.attn.f_proj.weight": "model-00004-of-00006.safetensors",
159
+ "model.layers.22.attn.k_proj.weight": "model-00004-of-00006.safetensors",
160
+ "model.layers.22.attn.o_proj.weight": "model-00004-of-00006.safetensors",
161
+ "model.layers.22.attn.q_proj.weight": "model-00004-of-00006.safetensors",
162
+ "model.layers.22.attn.v_proj.weight": "model-00004-of-00006.safetensors",
163
+ "model.layers.22.attn_norm.weight": "model-00004-of-00006.safetensors",
164
+ "model.layers.22.mlp.down_proj.weight": "model-00004-of-00006.safetensors",
165
+ "model.layers.22.mlp.gate_proj.weight": "model-00004-of-00006.safetensors",
166
+ "model.layers.22.mlp.up_proj.weight": "model-00004-of-00006.safetensors",
167
+ "model.layers.22.mlp_norm.weight": "model-00004-of-00006.safetensors",
168
+ "model.layers.23.attn.f_proj.weight": "model-00004-of-00006.safetensors",
169
+ "model.layers.23.attn.k_proj.weight": "model-00004-of-00006.safetensors",
170
+ "model.layers.23.attn.o_proj.weight": "model-00004-of-00006.safetensors",
171
+ "model.layers.23.attn.q_proj.weight": "model-00004-of-00006.safetensors",
172
+ "model.layers.23.attn.v_proj.weight": "model-00004-of-00006.safetensors",
173
+ "model.layers.23.attn_norm.weight": "model-00004-of-00006.safetensors",
174
+ "model.layers.23.mlp.down_proj.weight": "model-00004-of-00006.safetensors",
175
+ "model.layers.23.mlp.gate_proj.weight": "model-00004-of-00006.safetensors",
176
+ "model.layers.23.mlp.up_proj.weight": "model-00004-of-00006.safetensors",
177
+ "model.layers.23.mlp_norm.weight": "model-00004-of-00006.safetensors",
178
+ "model.layers.24.attn.f_proj.weight": "model-00004-of-00006.safetensors",
179
+ "model.layers.24.attn.k_proj.weight": "model-00004-of-00006.safetensors",
180
+ "model.layers.24.attn.o_proj.weight": "model-00004-of-00006.safetensors",
181
+ "model.layers.24.attn.q_proj.weight": "model-00004-of-00006.safetensors",
182
+ "model.layers.24.attn.v_proj.weight": "model-00004-of-00006.safetensors",
183
+ "model.layers.24.attn_norm.weight": "model-00004-of-00006.safetensors",
184
+ "model.layers.24.mlp.down_proj.weight": "model-00004-of-00006.safetensors",
185
+ "model.layers.24.mlp.gate_proj.weight": "model-00004-of-00006.safetensors",
186
+ "model.layers.24.mlp.up_proj.weight": "model-00004-of-00006.safetensors",
187
+ "model.layers.24.mlp_norm.weight": "model-00004-of-00006.safetensors",
188
+ "model.layers.25.attn.f_proj.weight": "model-00004-of-00006.safetensors",
189
+ "model.layers.25.attn.k_proj.weight": "model-00004-of-00006.safetensors",
190
+ "model.layers.25.attn.o_proj.weight": "model-00004-of-00006.safetensors",
191
+ "model.layers.25.attn.q_proj.weight": "model-00004-of-00006.safetensors",
192
+ "model.layers.25.attn.v_proj.weight": "model-00004-of-00006.safetensors",
193
+ "model.layers.25.attn_norm.weight": "model-00004-of-00006.safetensors",
194
+ "model.layers.25.mlp.down_proj.weight": "model-00004-of-00006.safetensors",
195
+ "model.layers.25.mlp.gate_proj.weight": "model-00004-of-00006.safetensors",
196
+ "model.layers.25.mlp.up_proj.weight": "model-00004-of-00006.safetensors",
197
+ "model.layers.25.mlp_norm.weight": "model-00004-of-00006.safetensors",
198
+ "model.layers.26.attn.f_proj.weight": "model-00004-of-00006.safetensors",
199
+ "model.layers.26.attn.k_proj.weight": "model-00004-of-00006.safetensors",
200
+ "model.layers.26.attn.o_proj.weight": "model-00004-of-00006.safetensors",
201
+ "model.layers.26.attn.q_proj.weight": "model-00004-of-00006.safetensors",
202
+ "model.layers.26.attn.v_proj.weight": "model-00004-of-00006.safetensors",
203
+ "model.layers.26.attn_norm.weight": "model-00004-of-00006.safetensors",
204
+ "model.layers.26.mlp.down_proj.weight": "model-00005-of-00006.safetensors",
205
+ "model.layers.26.mlp.gate_proj.weight": "model-00004-of-00006.safetensors",
206
+ "model.layers.26.mlp.up_proj.weight": "model-00004-of-00006.safetensors",
207
+ "model.layers.26.mlp_norm.weight": "model-00004-of-00006.safetensors",
208
+ "model.layers.27.attn.f_proj.weight": "model-00005-of-00006.safetensors",
209
+ "model.layers.27.attn.k_proj.weight": "model-00005-of-00006.safetensors",
210
+ "model.layers.27.attn.o_proj.weight": "model-00005-of-00006.safetensors",
211
+ "model.layers.27.attn.q_proj.weight": "model-00005-of-00006.safetensors",
212
+ "model.layers.27.attn.v_proj.weight": "model-00005-of-00006.safetensors",
213
+ "model.layers.27.attn_norm.weight": "model-00005-of-00006.safetensors",
214
+ "model.layers.27.mlp.down_proj.weight": "model-00005-of-00006.safetensors",
215
+ "model.layers.27.mlp.gate_proj.weight": "model-00005-of-00006.safetensors",
216
+ "model.layers.27.mlp.up_proj.weight": "model-00005-of-00006.safetensors",
217
+ "model.layers.27.mlp_norm.weight": "model-00005-of-00006.safetensors",
218
+ "model.layers.28.attn.f_proj.weight": "model-00005-of-00006.safetensors",
219
+ "model.layers.28.attn.k_proj.weight": "model-00005-of-00006.safetensors",
220
+ "model.layers.28.attn.o_proj.weight": "model-00005-of-00006.safetensors",
221
+ "model.layers.28.attn.q_proj.weight": "model-00005-of-00006.safetensors",
222
+ "model.layers.28.attn.v_proj.weight": "model-00005-of-00006.safetensors",
223
+ "model.layers.28.attn_norm.weight": "model-00005-of-00006.safetensors",
224
+ "model.layers.28.mlp.down_proj.weight": "model-00005-of-00006.safetensors",
225
+ "model.layers.28.mlp.gate_proj.weight": "model-00005-of-00006.safetensors",
226
+ "model.layers.28.mlp.up_proj.weight": "model-00005-of-00006.safetensors",
227
+ "model.layers.28.mlp_norm.weight": "model-00005-of-00006.safetensors",
228
+ "model.layers.29.attn.f_proj.weight": "model-00005-of-00006.safetensors",
229
+ "model.layers.29.attn.k_proj.weight": "model-00005-of-00006.safetensors",
230
+ "model.layers.29.attn.o_proj.weight": "model-00005-of-00006.safetensors",
231
+ "model.layers.29.attn.q_proj.weight": "model-00005-of-00006.safetensors",
232
+ "model.layers.29.attn.v_proj.weight": "model-00005-of-00006.safetensors",
233
+ "model.layers.29.attn_norm.weight": "model-00005-of-00006.safetensors",
234
+ "model.layers.29.mlp.down_proj.weight": "model-00005-of-00006.safetensors",
235
+ "model.layers.29.mlp.gate_proj.weight": "model-00005-of-00006.safetensors",
236
+ "model.layers.29.mlp.up_proj.weight": "model-00005-of-00006.safetensors",
237
+ "model.layers.29.mlp_norm.weight": "model-00005-of-00006.safetensors",
238
+ "model.layers.3.attn.f_proj.weight": "model-00001-of-00006.safetensors",
239
+ "model.layers.3.attn.k_proj.weight": "model-00001-of-00006.safetensors",
240
+ "model.layers.3.attn.o_proj.weight": "model-00001-of-00006.safetensors",
241
+ "model.layers.3.attn.q_proj.weight": "model-00001-of-00006.safetensors",
242
+ "model.layers.3.attn.v_proj.weight": "model-00001-of-00006.safetensors",
243
+ "model.layers.3.attn_norm.weight": "model-00001-of-00006.safetensors",
244
+ "model.layers.3.mlp.down_proj.weight": "model-00001-of-00006.safetensors",
245
+ "model.layers.3.mlp.gate_proj.weight": "model-00001-of-00006.safetensors",
246
+ "model.layers.3.mlp.up_proj.weight": "model-00001-of-00006.safetensors",
247
+ "model.layers.3.mlp_norm.weight": "model-00001-of-00006.safetensors",
248
+ "model.layers.30.attn.f_proj.weight": "model-00005-of-00006.safetensors",
249
+ "model.layers.30.attn.k_proj.weight": "model-00005-of-00006.safetensors",
250
+ "model.layers.30.attn.o_proj.weight": "model-00005-of-00006.safetensors",
251
+ "model.layers.30.attn.q_proj.weight": "model-00005-of-00006.safetensors",
252
+ "model.layers.30.attn.v_proj.weight": "model-00005-of-00006.safetensors",
253
+ "model.layers.30.attn_norm.weight": "model-00005-of-00006.safetensors",
254
+ "model.layers.30.mlp.down_proj.weight": "model-00005-of-00006.safetensors",
255
+ "model.layers.30.mlp.gate_proj.weight": "model-00005-of-00006.safetensors",
256
+ "model.layers.30.mlp.up_proj.weight": "model-00005-of-00006.safetensors",
257
+ "model.layers.30.mlp_norm.weight": "model-00005-of-00006.safetensors",
258
+ "model.layers.31.attn.f_proj.weight": "model-00005-of-00006.safetensors",
259
+ "model.layers.31.attn.k_proj.weight": "model-00005-of-00006.safetensors",
260
+ "model.layers.31.attn.o_proj.weight": "model-00005-of-00006.safetensors",
261
+ "model.layers.31.attn.q_proj.weight": "model-00005-of-00006.safetensors",
262
+ "model.layers.31.attn.v_proj.weight": "model-00005-of-00006.safetensors",
263
+ "model.layers.31.attn_norm.weight": "model-00005-of-00006.safetensors",
264
+ "model.layers.31.mlp.down_proj.weight": "model-00005-of-00006.safetensors",
265
+ "model.layers.31.mlp.gate_proj.weight": "model-00005-of-00006.safetensors",
266
+ "model.layers.31.mlp.up_proj.weight": "model-00005-of-00006.safetensors",
267
+ "model.layers.31.mlp_norm.weight": "model-00005-of-00006.safetensors",
268
+ "model.layers.32.attn.f_proj.weight": "model-00005-of-00006.safetensors",
269
+ "model.layers.32.attn.k_proj.weight": "model-00005-of-00006.safetensors",
270
+ "model.layers.32.attn.o_proj.weight": "model-00005-of-00006.safetensors",
271
+ "model.layers.32.attn.q_proj.weight": "model-00005-of-00006.safetensors",
272
+ "model.layers.32.attn.v_proj.weight": "model-00005-of-00006.safetensors",
273
+ "model.layers.32.attn_norm.weight": "model-00005-of-00006.safetensors",
274
+ "model.layers.32.mlp.down_proj.weight": "model-00005-of-00006.safetensors",
275
+ "model.layers.32.mlp.gate_proj.weight": "model-00005-of-00006.safetensors",
276
+ "model.layers.32.mlp.up_proj.weight": "model-00005-of-00006.safetensors",
277
+ "model.layers.32.mlp_norm.weight": "model-00005-of-00006.safetensors",
278
+ "model.layers.33.attn.f_proj.weight": "model-00005-of-00006.safetensors",
279
+ "model.layers.33.attn.k_proj.weight": "model-00005-of-00006.safetensors",
280
+ "model.layers.33.attn.o_proj.weight": "model-00005-of-00006.safetensors",
281
+ "model.layers.33.attn.q_proj.weight": "model-00005-of-00006.safetensors",
282
+ "model.layers.33.attn.v_proj.weight": "model-00005-of-00006.safetensors",
283
+ "model.layers.33.attn_norm.weight": "model-00005-of-00006.safetensors",
284
+ "model.layers.33.mlp.down_proj.weight": "model-00005-of-00006.safetensors",
285
+ "model.layers.33.mlp.gate_proj.weight": "model-00005-of-00006.safetensors",
286
+ "model.layers.33.mlp.up_proj.weight": "model-00005-of-00006.safetensors",
287
+ "model.layers.33.mlp_norm.weight": "model-00005-of-00006.safetensors",
288
+ "model.layers.34.attn.f_proj.weight": "model-00006-of-00006.safetensors",
289
+ "model.layers.34.attn.k_proj.weight": "model-00006-of-00006.safetensors",
290
+ "model.layers.34.attn.o_proj.weight": "model-00006-of-00006.safetensors",
291
+ "model.layers.34.attn.q_proj.weight": "model-00006-of-00006.safetensors",
292
+ "model.layers.34.attn.v_proj.weight": "model-00006-of-00006.safetensors",
293
+ "model.layers.34.attn_norm.weight": "model-00005-of-00006.safetensors",
294
+ "model.layers.34.mlp.down_proj.weight": "model-00006-of-00006.safetensors",
295
+ "model.layers.34.mlp.gate_proj.weight": "model-00006-of-00006.safetensors",
296
+ "model.layers.34.mlp.up_proj.weight": "model-00006-of-00006.safetensors",
297
+ "model.layers.34.mlp_norm.weight": "model-00006-of-00006.safetensors",
298
+ "model.layers.35.attn.f_proj.weight": "model-00006-of-00006.safetensors",
299
+ "model.layers.35.attn.k_proj.weight": "model-00006-of-00006.safetensors",
300
+ "model.layers.35.attn.o_proj.weight": "model-00006-of-00006.safetensors",
301
+ "model.layers.35.attn.q_proj.weight": "model-00006-of-00006.safetensors",
302
+ "model.layers.35.attn.v_proj.weight": "model-00006-of-00006.safetensors",
303
+ "model.layers.35.attn_norm.weight": "model-00006-of-00006.safetensors",
304
+ "model.layers.35.mlp.down_proj.weight": "model-00006-of-00006.safetensors",
305
+ "model.layers.35.mlp.gate_proj.weight": "model-00006-of-00006.safetensors",
306
+ "model.layers.35.mlp.up_proj.weight": "model-00006-of-00006.safetensors",
307
+ "model.layers.35.mlp_norm.weight": "model-00006-of-00006.safetensors",
308
+ "model.layers.36.attn.f_proj.weight": "model-00006-of-00006.safetensors",
309
+ "model.layers.36.attn.k_proj.weight": "model-00006-of-00006.safetensors",
310
+ "model.layers.36.attn.o_proj.weight": "model-00006-of-00006.safetensors",
311
+ "model.layers.36.attn.q_proj.weight": "model-00006-of-00006.safetensors",
312
+ "model.layers.36.attn.v_proj.weight": "model-00006-of-00006.safetensors",
313
+ "model.layers.36.attn_norm.weight": "model-00006-of-00006.safetensors",
314
+ "model.layers.36.mlp.down_proj.weight": "model-00006-of-00006.safetensors",
315
+ "model.layers.36.mlp.gate_proj.weight": "model-00006-of-00006.safetensors",
316
+ "model.layers.36.mlp.up_proj.weight": "model-00006-of-00006.safetensors",
317
+ "model.layers.36.mlp_norm.weight": "model-00006-of-00006.safetensors",
318
+ "model.layers.37.attn.f_proj.weight": "model-00006-of-00006.safetensors",
319
+ "model.layers.37.attn.k_proj.weight": "model-00006-of-00006.safetensors",
320
+ "model.layers.37.attn.o_proj.weight": "model-00006-of-00006.safetensors",
321
+ "model.layers.37.attn.q_proj.weight": "model-00006-of-00006.safetensors",
322
+ "model.layers.37.attn.v_proj.weight": "model-00006-of-00006.safetensors",
323
+ "model.layers.37.attn_norm.weight": "model-00006-of-00006.safetensors",
324
+ "model.layers.37.mlp.down_proj.weight": "model-00006-of-00006.safetensors",
325
+ "model.layers.37.mlp.gate_proj.weight": "model-00006-of-00006.safetensors",
326
+ "model.layers.37.mlp.up_proj.weight": "model-00006-of-00006.safetensors",
327
+ "model.layers.37.mlp_norm.weight": "model-00006-of-00006.safetensors",
328
+ "model.layers.38.attn.f_proj.weight": "model-00006-of-00006.safetensors",
329
+ "model.layers.38.attn.k_proj.weight": "model-00006-of-00006.safetensors",
330
+ "model.layers.38.attn.o_proj.weight": "model-00006-of-00006.safetensors",
331
+ "model.layers.38.attn.q_proj.weight": "model-00006-of-00006.safetensors",
332
+ "model.layers.38.attn.v_proj.weight": "model-00006-of-00006.safetensors",
333
+ "model.layers.38.attn_norm.weight": "model-00006-of-00006.safetensors",
334
+ "model.layers.38.mlp.down_proj.weight": "model-00006-of-00006.safetensors",
335
+ "model.layers.38.mlp.gate_proj.weight": "model-00006-of-00006.safetensors",
336
+ "model.layers.38.mlp.up_proj.weight": "model-00006-of-00006.safetensors",
337
+ "model.layers.38.mlp_norm.weight": "model-00006-of-00006.safetensors",
338
+ "model.layers.39.attn.f_proj.weight": "model-00006-of-00006.safetensors",
339
+ "model.layers.39.attn.k_proj.weight": "model-00006-of-00006.safetensors",
340
+ "model.layers.39.attn.o_proj.weight": "model-00006-of-00006.safetensors",
341
+ "model.layers.39.attn.q_proj.weight": "model-00006-of-00006.safetensors",
342
+ "model.layers.39.attn.v_proj.weight": "model-00006-of-00006.safetensors",
343
+ "model.layers.39.attn_norm.weight": "model-00006-of-00006.safetensors",
344
+ "model.layers.39.mlp.down_proj.weight": "model-00006-of-00006.safetensors",
345
+ "model.layers.39.mlp.gate_proj.weight": "model-00006-of-00006.safetensors",
346
+ "model.layers.39.mlp.up_proj.weight": "model-00006-of-00006.safetensors",
347
+ "model.layers.39.mlp_norm.weight": "model-00006-of-00006.safetensors",
348
+ "model.layers.4.attn.f_proj.weight": "model-00001-of-00006.safetensors",
349
+ "model.layers.4.attn.k_proj.weight": "model-00001-of-00006.safetensors",
350
+ "model.layers.4.attn.o_proj.weight": "model-00001-of-00006.safetensors",
351
+ "model.layers.4.attn.q_proj.weight": "model-00001-of-00006.safetensors",
352
+ "model.layers.4.attn.v_proj.weight": "model-00001-of-00006.safetensors",
353
+     "model.layers.4.attn_norm.weight": "model-00001-of-00006.safetensors",
+     "model.layers.4.mlp.down_proj.weight": "model-00001-of-00006.safetensors",
+     "model.layers.4.mlp.gate_proj.weight": "model-00001-of-00006.safetensors",
+     "model.layers.4.mlp.up_proj.weight": "model-00001-of-00006.safetensors",
+     "model.layers.4.mlp_norm.weight": "model-00001-of-00006.safetensors",
+     "model.layers.5.attn.f_proj.weight": "model-00002-of-00006.safetensors",
+     "model.layers.5.attn.k_proj.weight": "model-00002-of-00006.safetensors",
+     "model.layers.5.attn.o_proj.weight": "model-00002-of-00006.safetensors",
+     "model.layers.5.attn.q_proj.weight": "model-00002-of-00006.safetensors",
+     "model.layers.5.attn.v_proj.weight": "model-00002-of-00006.safetensors",
+     "model.layers.5.attn_norm.weight": "model-00001-of-00006.safetensors",
+     "model.layers.5.mlp.down_proj.weight": "model-00002-of-00006.safetensors",
+     "model.layers.5.mlp.gate_proj.weight": "model-00002-of-00006.safetensors",
+     "model.layers.5.mlp.up_proj.weight": "model-00002-of-00006.safetensors",
+     "model.layers.5.mlp_norm.weight": "model-00002-of-00006.safetensors",
+     "model.layers.6.attn.f_proj.weight": "model-00002-of-00006.safetensors",
+     "model.layers.6.attn.k_proj.weight": "model-00002-of-00006.safetensors",
+     "model.layers.6.attn.o_proj.weight": "model-00002-of-00006.safetensors",
+     "model.layers.6.attn.q_proj.weight": "model-00002-of-00006.safetensors",
+     "model.layers.6.attn.v_proj.weight": "model-00002-of-00006.safetensors",
+     "model.layers.6.attn_norm.weight": "model-00002-of-00006.safetensors",
+     "model.layers.6.mlp.down_proj.weight": "model-00002-of-00006.safetensors",
+     "model.layers.6.mlp.gate_proj.weight": "model-00002-of-00006.safetensors",
+     "model.layers.6.mlp.up_proj.weight": "model-00002-of-00006.safetensors",
+     "model.layers.6.mlp_norm.weight": "model-00002-of-00006.safetensors",
+     "model.layers.7.attn.f_proj.weight": "model-00002-of-00006.safetensors",
+     "model.layers.7.attn.k_proj.weight": "model-00002-of-00006.safetensors",
+     "model.layers.7.attn.o_proj.weight": "model-00002-of-00006.safetensors",
+     "model.layers.7.attn.q_proj.weight": "model-00002-of-00006.safetensors",
+     "model.layers.7.attn.v_proj.weight": "model-00002-of-00006.safetensors",
+     "model.layers.7.attn_norm.weight": "model-00002-of-00006.safetensors",
+     "model.layers.7.mlp.down_proj.weight": "model-00002-of-00006.safetensors",
+     "model.layers.7.mlp.gate_proj.weight": "model-00002-of-00006.safetensors",
+     "model.layers.7.mlp.up_proj.weight": "model-00002-of-00006.safetensors",
+     "model.layers.7.mlp_norm.weight": "model-00002-of-00006.safetensors",
+     "model.layers.8.attn.f_proj.weight": "model-00002-of-00006.safetensors",
+     "model.layers.8.attn.k_proj.weight": "model-00002-of-00006.safetensors",
+     "model.layers.8.attn.o_proj.weight": "model-00002-of-00006.safetensors",
+     "model.layers.8.attn.q_proj.weight": "model-00002-of-00006.safetensors",
+     "model.layers.8.attn.v_proj.weight": "model-00002-of-00006.safetensors",
+     "model.layers.8.attn_norm.weight": "model-00002-of-00006.safetensors",
+     "model.layers.8.mlp.down_proj.weight": "model-00002-of-00006.safetensors",
+     "model.layers.8.mlp.gate_proj.weight": "model-00002-of-00006.safetensors",
+     "model.layers.8.mlp.up_proj.weight": "model-00002-of-00006.safetensors",
+     "model.layers.8.mlp_norm.weight": "model-00002-of-00006.safetensors",
+     "model.layers.9.attn.f_proj.weight": "model-00002-of-00006.safetensors",
+     "model.layers.9.attn.k_proj.weight": "model-00002-of-00006.safetensors",
+     "model.layers.9.attn.o_proj.weight": "model-00002-of-00006.safetensors",
+     "model.layers.9.attn.q_proj.weight": "model-00002-of-00006.safetensors",
+     "model.layers.9.attn.v_proj.weight": "model-00002-of-00006.safetensors",
+     "model.layers.9.attn_norm.weight": "model-00002-of-00006.safetensors",
+     "model.layers.9.mlp.down_proj.weight": "model-00002-of-00006.safetensors",
+     "model.layers.9.mlp.gate_proj.weight": "model-00002-of-00006.safetensors",
+     "model.layers.9.mlp.up_proj.weight": "model-00002-of-00006.safetensors",
+     "model.layers.9.mlp_norm.weight": "model-00002-of-00006.safetensors",
+     "model.norm.weight": "model-00006-of-00006.safetensors"
+   }
+ }
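The index above pairs each parameter name with the shard file that stores it. A minimal sketch of how such a map could be grouped by shard follows. It assumes the entries sit under a top-level `weight_map` key (the key name is not visible in this part of the diff), and `group_by_shard` and `sample_index` are illustrative names rather than anything in the repository.

```python
import json
from collections import defaultdict


def group_by_shard(index: dict) -> dict[str, list[str]]:
    """Return {shard file: sorted parameter names} for a sharded-checkpoint index."""
    shards: dict[str, list[str]] = defaultdict(list)
    # "weight_map" is an assumed key name holding the name -> shard pairs.
    for name, shard in index["weight_map"].items():
        shards[shard].append(name)
    return {shard: sorted(names) for shard, names in shards.items()}


# A handful of entries copied from the index above; a real index lists every tensor.
sample_index = json.loads("""
{
  "weight_map": {
    "model.layers.4.mlp_norm.weight": "model-00001-of-00006.safetensors",
    "model.layers.5.attn.f_proj.weight": "model-00002-of-00006.safetensors",
    "model.layers.5.attn_norm.weight": "model-00001-of-00006.safetensors",
    "model.norm.weight": "model-00006-of-00006.safetensors"
  }
}
""")

by_shard = group_by_shard(sample_index)
for shard, names in sorted(by_shard.items()):
    print(shard, len(names))
```

Grouping this way makes it easy to see that a single layer can straddle two shards: in the entries above, `model.layers.5.attn_norm.weight` is stored in shard 1 while the rest of layer 5 is in shard 2.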
pyproject.toml ADDED
@@ -0,0 +1,55 @@
+ [project]
+ name = "flame"
+ dynamic = ["version"]
+ description = "A minimal training framework for scaling FLA models"
+ readme = "README.md"
+ authors = [
+     { name = "Songlin Yang", email = "yangsl66@mit.edu" },
+     { name = "Yu Zhang", email = "yzhang.cs@outlook.com" },
+ ]
+ license = { file = "LICENSE" }
+ classifiers = [
+     "Programming Language :: Python :: 3",
+     "License :: OSI Approved :: MIT License",
+     "Operating System :: OS Independent",
+     "Topic :: Scientific/Engineering :: Artificial Intelligence",
+ ]
+ requires-python = ">=3.10"
+ dependencies = [
+     'flash-linear-attention',
+     'torch>=2.5',
+     'torchdata',
+     'transformers>=4.45.0',
+     'triton>=3.0',
+     'datasets>=3.3.0',
+     'einops',
+     'ninja',
+     'wandb',
+     'tiktoken',
+     'tensorboard',
+     "tyro>=1.0.3",
+     "torchtitan",
+     "psutil>=7.2.1",
+     "cmake>=4.2.1",
+     "packaging>=25.0",
+     "setuptools>=80.9.0",
+     "wheel>=0.45.1",
+     "flash-attn>=2.8.3",
+     "ipython>=8.37.0",
+ ]
+
+ [project.optional-dependencies]
+ dev = ["pytest"]
+
+ [project.urls]
+ Homepage = "https://github.com/fla-org/flame"
+
+ [build-system]
+ requires = ["setuptools>=45", "wheel", "ninja", "torch"]
+
+ [tool.isort]
+ line_length = 127
+ multi_line_output = 3
+
+ [tool.uv.sources]
+ torchtitan = { git = "https://github.com/pytorch/torchtitan.git", rev = "0b44d4c" }
setup.py ADDED
@@ -0,0 +1,51 @@
+ # -*- coding: utf-8 -*-
+
+ import ast
+ import os
+ import re
+ from pathlib import Path
+
+ from setuptools import find_packages, setup
+
+ with open('README.md') as f:
+     long_description = f.read()
+
+
+ def get_package_version():
+     with open(Path(os.path.dirname(os.path.abspath(__file__))) / 'flame' / '__init__.py') as f:
+         version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
+     return ast.literal_eval(version_match.group(1))
+
+
+ setup(
+     name='flame',
+     version=get_package_version(),
+     description='A minimal training framework for scaling FLA models',
+     long_description=long_description,
+     long_description_content_type='text/markdown',
+     author='Songlin Yang, Yu Zhang',
+     author_email='yangsl66@mit.edu, yzhang.cs@outlook.com',
+     url='https://github.com/fla-org/flame',
+     packages=find_packages(),
+     license='MIT',
+     classifiers=[
+         'Programming Language :: Python :: 3',
+         'License :: OSI Approved :: MIT License',
+         'Operating System :: OS Independent',
+         'Topic :: Scientific/Engineering :: Artificial Intelligence'
+     ],
+     python_requires='>=3.10',
+     install_requires=[
+         'flash-linear-attention',
+         'torch>=2.5',
+         'torchdata',
+         'transformers>=4.45.0',
+         'triton>=3.0',
+         'datasets>=3.3.0',
+         'einops',
+         'ninja',
+         'wandb',
+         'tiktoken',
+         'tensorboard',
+     ],
+ )
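The version lookup in `get_package_version` combines a multiline regex with `ast.literal_eval`. The sketch below exercises the same two steps on an in-memory string instead of `flame/__init__.py`. The helper name `parse_version`, the explicit `None` check, and the sample version string are assumptions added for illustration and are not part of `setup.py`.

```python
import ast
import re


def parse_version(source: str) -> str:
    """Extract the value assigned to __version__ in a module's source text."""
    # Same pattern as setup.py: capture everything after "__version__ =" on that line.
    match = re.search(r"^__version__\s*=\s*(.*)$", source, re.MULTILINE)
    if match is None:
        # setup.py has no such guard; added here so a missing assignment fails clearly.
        raise ValueError("no __version__ assignment found")
    # The captured text is a Python string literal (quotes included), so evaluate it safely.
    return ast.literal_eval(match.group(1))


# Hypothetical module contents; the real version lives in flame/__init__.py.
example_source = 'import os\n__version__ = "0.1.0"\n'
version = parse_version(example_source)
print(version)
```

Using `ast.literal_eval` rather than importing the package means the version can be read before the package's dependencies are installed.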
special_tokens_map.json ADDED
@@ -0,0 +1,31 @@
+ {
+   "additional_special_tokens": [
+     "<|im_start|>",
+     "<|im_end|>",
+     "<|object_ref_start|>",
+     "<|object_ref_end|>",
+     "<|box_start|>",
+     "<|box_end|>",
+     "<|quad_start|>",
+     "<|quad_end|>",
+     "<|vision_start|>",
+     "<|vision_end|>",
+     "<|vision_pad|>",
+     "<|image_pad|>",
+     "<|video_pad|>"
+   ],
+   "eos_token": {
+     "content": "<|im_end|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": {
+     "content": "<|endoftext|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
tokenizer_config.json ADDED
@@ -0,0 +1,239 @@
+ {
+   "add_bos_token": false,
+   "add_prefix_space": false,
+   "added_tokens_decoder": {
+     "151643": {
+       "content": "<|endoftext|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151644": {
+       "content": "<|im_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151645": {
+       "content": "<|im_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151646": {
+       "content": "<|object_ref_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151647": {
+       "content": "<|object_ref_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151648": {
+       "content": "<|box_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151649": {
+       "content": "<|box_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151650": {
+       "content": "<|quad_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151651": {
+       "content": "<|quad_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151652": {
+       "content": "<|vision_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151653": {
+       "content": "<|vision_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151654": {
+       "content": "<|vision_pad|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151655": {
+       "content": "<|image_pad|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151656": {
+       "content": "<|video_pad|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151657": {
+       "content": "<tool_call>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151658": {
+       "content": "</tool_call>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151659": {
+       "content": "<|fim_prefix|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151660": {
+       "content": "<|fim_middle|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151661": {
+       "content": "<|fim_suffix|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151662": {
+       "content": "<|fim_pad|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151663": {
+       "content": "<|repo_name|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151664": {
+       "content": "<|file_sep|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151665": {
+       "content": "<tool_response>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151666": {
+       "content": "</tool_response>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151667": {
+       "content": "<think>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151668": {
+       "content": "</think>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     }
+   },
+   "additional_special_tokens": [
+     "<|im_start|>",
+     "<|im_end|>",
+     "<|object_ref_start|>",
+     "<|object_ref_end|>",
+     "<|box_start|>",
+     "<|box_end|>",
+     "<|quad_start|>",
+     "<|quad_end|>",
+     "<|vision_start|>",
+     "<|vision_end|>",
+     "<|vision_pad|>",
+     "<|image_pad|>",
+     "<|video_pad|>"
+   ],
+   "bos_token": null,
+   "clean_up_tokenization_spaces": false,
+   "eos_token": "<|im_end|>",
+   "errors": "replace",
+   "extra_special_tokens": {},
+   "model_max_length": 131072,
+   "pad_token": "<|endoftext|>",
+   "split_special_tokens": false,
+   "tokenizer_class": "Qwen2Tokenizer",
+   "unk_token": null
+ }
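In `added_tokens_decoder`, each token ID maps to an entry whose `special` flag separates control tokens such as `<|im_end|>` from ordinary added tokens such as `<think>`. The following is a rough sketch of splitting the table on that flag, using plain dictionaries only. `split_added_tokens` and `sample_config` are hypothetical names, and the sample keeps just four of the entries shown above.

```python
def split_added_tokens(config: dict) -> tuple[dict[int, str], dict[int, str]]:
    """Split added_tokens_decoder into ({id: token} special, {id: token} non-special)."""
    special: dict[int, str] = {}
    regular: dict[int, str] = {}
    for token_id, entry in config["added_tokens_decoder"].items():
        # JSON object keys are strings, so convert the ID back to an int.
        target = special if entry["special"] else regular
        target[int(token_id)] = entry["content"]
    return special, regular


# Four entries copied from tokenizer_config.json, trimmed to the fields used here.
sample_config = {
    "added_tokens_decoder": {
        "151643": {"content": "<|endoftext|>", "special": True},
        "151645": {"content": "<|im_end|>", "special": True},
        "151657": {"content": "<tool_call>", "special": False},
        "151667": {"content": "<think>", "special": False},
    }
}

special_tokens, regular_tokens = split_added_tokens(sample_config)
print(sorted(special_tokens))
print(sorted(regular_tokens))
```

In the full file, IDs 151643 through 151656 are flagged `special` and IDs 151657 through 151668 are not.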
train.sh ADDED
@@ -0,0 +1,122 @@
+ #!/usr/bin/bash
+
+ params=""
+ if [ $# -ne 0 ]; then
+   params="$*"
+ fi
+
+ # use envs as local params for convenience
+ # e.g.
+ # NNODE=1 NGPU=8 LOG_RANK=0 ./train.sh
+ NNODE=${NNODE:-"1"}
+ NGPU=${NGPU:-"8"}
+ LOG_RANK=${LOG_RANK:-0}
+
+ if [[ -z "${MASTER_ADDR}" ]]; then
+   export MASTER_ADDR="localhost"
+ fi
+ if [[ -z "${MASTER_PORT}" ]]; then
+   export MASTER_PORT="0"
+ fi
+
+ : '
+ Usage:
+
+ bash train.sh -h
+
+ Training a 340M model:
+
+ NNODE=1 NGPU=8 LOG_RANK=0 bash train.sh \
+   --job.config_file flame/models/fla.toml \
+   --job.dump_folder exp/transformer-340M-10B/batch32.seqlen2048.warmup1024.update1.steps20480.lr3e-4 \
+   --model.config configs/transformer_340M.json \
+   --model.tokenizer_path fla-hub/transformer-1.3B-100B \
+   --optimizer.name AdamW \
+   --optimizer.eps 1e-15 \
+   --optimizer.lr 3e-4 \
+   --lr_scheduler.warmup_steps 1024 \
+   --lr_scheduler.lr_min 0.1 \
+   --lr_scheduler.decay_type cosine \
+   --training.batch_size 32 \
+   --training.seq_len 2048 \
+   --training.gradient_accumulation_steps 1 \
+   --training.steps 20480 \
+   --training.max_norm 1.0 \
+   --training.skip_nan_inf \
+   --training.dataset HuggingFaceFW/fineweb-edu \
+   --training.dataset_name default \
+   --training.dataset_split train \
+   --training.streaming \
+   --training.num_workers 32 \
+   --training.prefetch_factor 2 \
+   --training.seed 42 \
+   --training.compile \
+   --training.tensor_parallel_degree 1 \
+   --training.disable_loss_parallel \
+   --checkpoint.interval 2048 \
+   --checkpoint.load_step -1 \
+   --metrics.log_freq 1
+ '
+
+ echo "Launching training..."
+
+ set -x
+ path=$(grep -oP '(?<=--job.dump_folder )[^ ]+' <<< "$params")
+ steps=$(grep -oP '(?<=--training.steps )[^ ]+' <<< "$params")
+ config=$(grep -oP '(?<=--model.config )[^ ]+' <<< "$params")
+ tokenizer=$(grep -oP '(?<=--model.tokenizer_path )[^ ]+' <<< "$params")
+ echo "Using Python at: $(which python)"
+ model=$(
+   python -c "import fla, sys; import fla.models.gsa; import fla.models.routmem; from transformers import AutoConfig; print(AutoConfig.from_pretrained(sys.argv[1]).to_json_string())" "$config" | jq -r '.model_type'
+ )
+
+ mkdir -p $path
+ cp * $path
+ cp -r configs $path
+ cp -r flame $path
+ cp -r 3rdparty/flash-linear-attention/fla $path
+ cp -r 3rdparty/torchtitan/torchtitan $path
+
+ # for offline systems
+ # export TRANSFORMERS_OFFLINE=1
+ # export HF_DATASETS_OFFLINE=1
+ # export HF_HUB_OFFLINE=1
+ if [ "$date" == "" ]; then
+   date=$(date +%Y%m%d%H%M)
+ fi
+ RUN_NAME="$model-$(basename $path)"
+ RUN_ID="$RUN_NAME-$date"
+
+ export WANDB_RESUME=allow
+ if [[ -z "${WANDB_PROJECT}" ]]; then
+   export WANDB_PROJECT="fla"
+ fi
+ if [[ -z "${WANDB_NAME}" ]]; then
+   export WANDB_NAME="$RUN_NAME"
+ fi
+ if [[ -z "${WANDB_RUN_ID}" ]]; then
+   export WANDB_RUN_ID="$RUN_ID"
+ fi
+
+ PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \
+ torchrun --nnodes=${NNODE} \
+   --nproc_per_node=${NGPU} \
+   --rdzv_backend c10d \
+   --rdzv_endpoint "${MASTER_ADDR}:${MASTER_PORT}" \
+   --local-ranks-filter ${LOG_RANK} \
+   --role rank \
+   --tee 3 \
+   --log-dir $path/logs \
+   -m flame.train \
+   $params
+
+ echo "TRAINING DONE!"
+ echo "Converting the DCP checkpoints to HF format..."
+
+ python -m flame.utils.convert_dcp_to_hf \
+   --path $path \
+   --step $steps \
+   --config $config \
+   --tokenizer $tokenizer
+
+ echo "RUNNING DONE!"
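`train.sh` recovers `path`, `steps`, `config` and `tokenizer` from the joined argument string with `grep -oP` lookbehind patterns, which take the first space-delimited word after each flag. Below is a Python sketch of the same extraction, offered as an approximation rather than a drop-in replacement: `extract_flag` is a hypothetical helper, and unlike the shell patterns it escapes the dots in the flag name so they match literally.

```python
from __future__ import annotations

import re


def extract_flag(params: str, flag: str) -> str | None:
    """Return the space-delimited value following `flag` in `params`, or None if absent."""
    # Fixed-width lookbehind on "<flag> " mirrors grep -oP '(?<=--flag )[^ ]+'.
    match = re.search(rf"(?<={re.escape(flag)} )[^ ]+", params)
    return match.group(0) if match else None


# A shortened argument string in the style of the usage example in train.sh.
params = (
    "--job.dump_folder exp/transformer-340M-10B "
    "--model.config configs/transformer_340M.json "
    "--training.steps 20480"
)

print(extract_flag(params, "--job.dump_folder"))
print(extract_flag(params, "--training.steps"))
print(extract_flag(params, "--model.tokenizer_path"))
```

Because values are cut at the first space, this approach, like the shell version, only works for paths and values that contain no spaces.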
uv.lock ADDED
The diff for this file is too large to render. See raw diff
 
vocab.json ADDED
The diff for this file is too large to render. See raw diff