ZhenbinWang commited on
Commit
0caebb9
·
verified ·
1 Parent(s): d3adf8d

Upload 10 files

Browse files
Files changed (10) hide show
  1. README.md +519 -3
  2. config.json +40 -0
  3. generation_config.json +6 -0
  4. pyproject.toml +43 -0
  5. setup.py +51 -0
  6. special_tokens_map.json +23 -0
  7. tokenizer.json +0 -0
  8. tokenizer_config.json +44 -0
  9. train.sh +130 -0
  10. train_restart.sh +130 -0
README.md CHANGED
@@ -1,3 +1,519 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div align="center">
2
+
3
+ # 🔥 Flame: Flash Language Modeling Made Easy
4
+
5
+ [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/fla-org/flame)
6
+
7
+ </div>
8
+
9
+ Welcome to 🔥 `flame`, a minimal and efficient framework built on `torchtitan` for language models with blazing efficiency.
10
+
11
+ **Feature Highlights:**
12
+
13
+ - 🚀 Minimal, easy-to-use, extensible training framework
14
+ - 🤗 Seamless integration with `fla` and `transformers`
15
+ - 🔄 Zero-cost data preprocessing: online tokenization, dataset shuffling, and multiple datasets support
16
+ - 🔮 4D parallelism (coming soon)
17
+
18
+ ## Setup
19
+
20
+ To get started, clone the `flame` repository and install the required dependencies:
21
+
22
+ ```bash
23
+ git clone https://github.com/fla-org/flame.git
24
+ cd flame
25
+ pip install .
26
+ ```
27
+
28
+ Install the latest version of fla
29
+ ```
30
+ pip uninstall flash-linear-attention && pip install -U --no-use-pep517 git+https://github.com/fla-org/flash-linear-attention
31
+ ```
32
+
33
+ [Important] Install specific version of torchtitan
34
+ ```
35
+ pip install git+https://github.com/pytorch/torchtitan.git@0b44d4c
36
+ ```
37
+
38
+
39
+ ## Dataset Preparation
40
+ To download the dataset to your local disk, create a new Python file with the following content and execute it:
41
+
42
+ ```py
43
+ from datasets import load_dataset
44
+
45
+ # load fineweb-edu with parallel processing
46
+ dataset = load_dataset("HuggingFaceFW/fineweb-edu", name="default", num_proc=64, cache_dir="/your/cache/path")
47
+
48
+ # or load a subset with roughly 100B tokens, suitable for small- or medium-sized experiments
49
+ dataset = load_dataset("HuggingFaceFW/fineweb-edu", name="sample-100BT", num_proc=64, cache_dir="/your/cache/path")
50
+ ```
51
+
52
+ ## Training Recipes
53
+
54
+ Here's an example of training a 340M FLA Transformer model with a LLaMA-like architecture from scratch on a 100BT subset of the Fineweb-edu corpus ~~in streaming mode~~. (Do not use streaming mode if you are concerned about resuming training.)
55
+
56
+ > [!WARNING]
57
+ > If the dataset is not downloaded beforehand, the streaming mode will attempt to fetch it from a remote server and download it on-the-fly, which can be highly unstable during training due to network issues.
58
+ > For stable training, ensure the dataset is downloaded locally (see [**Dataset Preparation**](#dataset-preparation)). Otherwise, we assume you are only testing the new corpus.
59
+
60
+ ```sh
61
+ bash train.sh \
62
+ --job.config_file flame/models/fla.toml \
63
+ --job.dump_folder exp/transformer-340M-4K-10B/batch1.seqlen65536.context4096.warmup1024.update1.steps20480.lr1e-3.cosine \
64
+ --model.config configs/transformer_340M.json \
65
+ --model.tokenizer_path fla-hub/transformer-1.3B-100B \
66
+ --optimizer.name AdamW \
67
+ --optimizer.eps 1e-15 \
68
+ --optimizer.lr 1e-3 \
69
+ --lr_scheduler.warmup_steps 1024 \
70
+ --lr_scheduler.lr_min 0.1 \
71
+ --lr_scheduler.decay_type cosine \
72
+ --training.batch_size 1 \
73
+ --training.seq_len 65536 \
74
+ --training.context_len 4096 \
75
+ --training.varlen \
76
+ --training.gradient_accumulation_steps 1 \
77
+ --training.steps 20480 \
78
+ --training.max_norm 1.0 \
79
+ --training.skip_nan_inf \
80
+ --training.dataset HuggingFaceFW/fineweb-edu \
81
+ --training.dataset_name sample-100BT \
82
+ --training.dataset_split train \
83
+ --training.num_workers 32 \
84
+ --training.prefetch_factor 2 \
85
+ --training.seed 42 \
86
+ --training.compile \
87
+ --checkpoint.interval 2048 \
88
+ --checkpoint.load_step -1 \
89
+ --checkpoint.keep_latest_k 2 \
90
+ --metrics.log_freq 1
91
+ ```
92
+
93
+ You can specify the number of GPUs by setting the environment variable `NGPU`, which defaults to 8.
94
+ **For single-GPU debugging, set `NGPU=1`.**
95
+
96
+ We provide several [config files](https://github.com/fla-org/flame/tree/main/configs) for different models.
97
+ By default, the learning rate is set to 1e-3 with a cosine scheduler. Other schedulers, such as WSD (wsd), are also supported.
98
+
99
+ **Key parameters:**
100
+ - `--lr_scheduler.decay_ratio`: The proportion of the steps allocated to the decay phase. The learning rate will remain stable after the warmup period and only start decaying during the last `decay_ratio` portion of the total training steps, which is known as the Warmup-Stable-Decay (WSD) schedule.
101
+ - `--lr_scheduler.warmup_steps`: The number of steps for the learning rate warmup phase.
102
+ - `--training.steps`: Total number of training steps.
103
+ - `--training.batch_size`: Batch size per device, must be 1 if `--training.varlen` is set.
104
+ - `--training.seq_len`: The length of each sequence in the batch, which is concatenated from multiple samples.
105
+ - `--training.context_len`: The max allowed length of a sample. For non-varlen mode, this is equivalent to `seq_len`.
106
+ - `--training.varlen`: Whether to conduct variable-length sequence training.
107
+ - `--training.gradient_accumulation_steps`: Number of gradient accumulation steps.
108
+
109
+ > [!WARNING]
110
+ > The total number of tokens processed per batch, referred to as `global_batch_size`, is calculated as batch_size × gradient_accumulation_steps × num_gpus.
111
+ > Each step processes `global_batch_size * seq_len` tokens.
112
+ > Monitor the value of `global_batch_size`, `warmup_steps`, and `steps` carefully when modifying any of the hyperparameters!
113
+
114
+ For a detailed explanation of all parameters, run:
115
+
116
+ ```sh
117
+ bash train.sh -h
118
+ ```
119
+
120
+ <details>
121
+ <summary>Usage</summary>
122
+
123
+ ```py
124
+ options:
125
+ -h, --help show this help message and exit
126
+ --job.config_file JOB.CONFIG_FILE
127
+ Job config file
128
+ --job.dump_folder JOB.DUMP_FOLDER
129
+ Folder to dump job outputs
130
+ --job.description JOB.DESCRIPTION
131
+ Description of the job
132
+ --job.use_for_integration_test
133
+ Add this config to the integration test suite
134
+ --job.print_args Print the args to terminal
135
+ --model.config MODEL.CONFIG
136
+ Path to the model config
137
+ --model.norm_type MODEL.NORM_TYPE
138
+ Type of layer normalization to use [layernorm,
139
+ np_layernorm, rmsnorm, fused_rmsnorm]
140
+ --model.tokenizer_path MODEL.TOKENIZER_PATH
141
+ Tokenizer path
142
+ --profiling.enable_profiling
143
+ Whether to enable pytorch profiler
144
+ --profiling.save_traces_folder PROFILING.SAVE_TRACES_FOLDER
145
+ Trace files location
146
+ --profiling.profile_freq PROFILING.PROFILE_FREQ
147
+ How often to collect profiler traces, in iterations
148
+ --profiling.enable_memory_snapshot
149
+ Whether to dump memory snapshot
150
+ --profiling.save_memory_snapshot_folder PROFILING.SAVE_MEMORY_SNAPSHOT_FOLDER
151
+ Memeory snapshot files location
152
+ --optimizer.name OPTIMIZER.NAME
153
+ Optimizer to use
154
+ --optimizer.eps OPTIMIZER.EPS
155
+ Epsilon value for the optimizer.
156
+ --optimizer.fused Whether the fused implementation(CUDA only) is used.
157
+ --optimizer.scheduler {wsd,cosine,linear}
158
+ Scheduler to use. Currently supported: wsd, cosine,
159
+ and linear.
160
+ --optimizer.lr OPTIMIZER.LR
161
+ Learning rate to use
162
+ --optimizer.min_lr_ratio OPTIMIZER.MIN_LR_RATIO
163
+ Min lr ratio for lr scheduler
164
+ --optimizer.early_step_in_backward
165
+ Whether to apply optimizer in the backward. Caution,
166
+ optimizer_in_backward is not compatible with gradients
167
+ clipping, users should not call
168
+ register_post_accumulate_grad_hook after the optimizer
169
+ is built.
170
+ --training.batch_size TRAINING.BATCH_SIZE
171
+ Batch size
172
+ --training.seq_len TRAINING.SEQ_LEN
173
+ Sequence length
174
+ --training.context_len TRAINING.CONTEXT_LEN
175
+ Max length allowed for each sequence
176
+ --training.varlen Whether to take sequences of variable length as input
177
+ --training.warmup_steps TRAINING.WARMUP_STEPS
178
+ Steps for lr scheduler warmup, normally 1/5 of
179
+ --training.steps
180
+ --training.gradient_accumulation_steps TRAINING.GRADIENT_ACCUMULATION_STEPS
181
+ Number of steps to accumulate gradients before
182
+ updating parameters
183
+ --training.steps TRAINING.STEPS
184
+ How many train steps to run
185
+ --training.max_norm TRAINING.MAX_NORM
186
+ Max norm for gradient clipping
187
+ --training.skip_nan_inf
188
+ Skip batch updates when NaN or INF gradients are
189
+ encountered during training
190
+ --training.dataset TRAINING.DATASET
191
+ Dataset to use, with comma separated values
192
+ --training.dataset_name TRAINING.DATASET_NAME
193
+ The name of the dataset config, with comma separated
194
+ values if provided
195
+ --training.dataset_split TRAINING.DATASET_SPLIT
196
+ Dataset split to use, with comma separated values if
197
+ provided
198
+ --training.data_dir TRAINING.DATA_DIR
199
+ Data dirs to use, with comma separated values if
200
+ provided
201
+ --training.data_files TRAINING.DATA_FILES
202
+ Data files to use, with comma separated values if
203
+ provided
204
+ --training.data_probs TRAINING.DATA_PROBS
205
+ Data sampling probabilities, with comma separated
206
+ values if provided
207
+ --training.streaming Whether to load dataset in streaming mode, used for
208
+ huge dataset
209
+ --training.num_workers TRAINING.NUM_WORKERS
210
+ Number of subprocesses to use for data loading. 0
211
+ means that the data will be loaded in the main
212
+ process.
213
+ --training.prefetch_factor TRAINING.PREFETCH_FACTOR
214
+ Number of batches loaded in advance by each worker.2
215
+ means there will be a total of 2 * num_workers batches
216
+ prefetched across all workers.
217
+ --training.data_parallel_replicate_degree TRAINING.DATA_PARALLEL_REPLICATE_DEGREE
218
+ The `data_parallel_replicate_degree` argument
219
+ specifies the degree of data parallelism for weight
220
+ replication. When this value is greater than 1,
221
+ weights will be replicated across
222
+ `data_parallel_replicate_degree` ranks. If
223
+ `data_parallel_shard_degree` is also greater than 1,
224
+ the parallelism method used is HSDP (Hybrid Sharded
225
+ Data Parallelism). Otherwise, the parallelism method
226
+ used is DDP (Distributed Data Parallelism). 1 means
227
+ disabled.
228
+ --training.data_parallel_shard_degree TRAINING.DATA_PARALLEL_SHARD_DEGREE
229
+ The `data_parallel_shard_degree` argument specifies
230
+ the degree of data parallelism for weight sharding.
231
+ When this value is greater than 1, weights will be
232
+ sharded across `data_parallel_shard_degree` ranks. If
233
+ `data_parallel_replicate_degree` is also greater than
234
+ 1, the parallelism method used is HSDP (Hybrid Sharded
235
+ Data Parallelism). Otherwise, the parallelism method
236
+ used is FSDP (Fully Sharded Data Parallelism). -1
237
+ means leftover ranks will be used (After
238
+ DP_REPLICATE/SP/PP). Note that only
239
+ `data_parallel_shard_degree` can be negative. 1 means
240
+ disabled.
241
+ --training.enable_cpu_offload
242
+ Whether to apply CPU offloading of parameters,
243
+ gradients, and optimizer states in FSDP
244
+ --training.tensor_parallel_degree TRAINING.TENSOR_PARALLEL_DEGREE
245
+ Tensor Parallelism degree. 1 means disabled.
246
+ --training.disable_loss_parallel
247
+ Whether to apply loss parallel when sequence parallel
248
+ is enabled
249
+ --training.mixed_precision_param {bfloat16,float32}
250
+ torch dtype to use for parameters when applying mixed
251
+ precision via FSDP. This feature only takes effect
252
+ when data_parallel_shard_degree > 1
253
+ --training.mixed_precision_reduce {float32}
254
+ torch dtype to use for reductions when applying mixed
255
+ precision via FSDP. This feature only takes effect
256
+ when data_parallel_shard_degree > 1
257
+ --training.compile Whether to compile the model
258
+ --training.gc_freq TRAINING.GC_FREQ
259
+ Python garbage control scheduling interval, in steps
260
+ --training.seed TRAINING.SEED
261
+ Choose the base RNG seed used for training
262
+ --training.deterministic
263
+ Use deterministic algorithms wherever possible, may be
264
+ slower
265
+ --metrics.log_freq METRICS.LOG_FREQ
266
+ How often to log metrics to TensorBoard, in iterations
267
+ --metrics.enable_tensorboard
268
+ Whether to log metrics to TensorBoard
269
+ --metrics.disable_color_printing
270
+ Whether to disable color printing in logs
271
+ --metrics.save_tb_folder METRICS.SAVE_TB_FOLDER
272
+ Folder to dump TensorBoard states
273
+ --metrics.rank_0_only
274
+ Whether to save TensorBoard metrics only for rank 0 or
275
+ for all ranks. When pipeline_parallel_degree is > 1,
276
+ this option uses the 0th rank of the last stage
277
+ pipeline group, which is the only stage that computes
278
+ loss metrics.
279
+ --metrics.enable_wandb
280
+ Whether to log metrics to Weights & Biases
281
+ --experimental.enable_async_tensor_parallel
282
+ Whether to apply async tensor parallel (currently only
283
+ effective when compile is enabled)
284
+ --experimental.pipeline_parallel_degree EXPERIMENTAL.PIPELINE_PARALLEL_DEGREE
285
+ Pipeline Parallelism degree, or number of ranks. 1
286
+ means disabled. If using looped schedules, this still
287
+ specifies the number of physical ranks, not the number
288
+ of stages. Stages per rank are inferred from split
289
+ points degree, and schedule.
290
+ --experimental.pipeline_parallel_split_points EXPERIMENTAL.PIPELINE_PARALLEL_SPLIT_POINTS [EXPERIMENTAL.PIPELINE_PARALLEL_SPLIT_POINTS ...]
291
+ Specify comma-separated names of modules to use as the
292
+ beginning of a split point. e.g. "layers.0,layers.2"
293
+ will cause the model to be split into 3 stages, the
294
+ first containing all the layers up to layers.0, the
295
+ second containing layers.0 and up to layers.2, the
296
+ third containing layers.2 and all the remaining
297
+ layers. Note: fully-automated splitting may be enabled
298
+ in the future, but currently the split points must be
299
+ specified manually.
300
+ --experimental.pipeline_parallel_schedule EXPERIMENTAL.PIPELINE_PARALLEL_SCHEDULE
301
+ Specify the Pipeline Parallel schedule to use. The
302
+ supported schedules are: https://github.com/pytorch/py
303
+ torch/blob/de4c2a3b4e89d96334dc678d1c3f2ae51a6630a0/to
304
+ rch/distributed/pipelining/schedules.py#L2161. The
305
+ schedule must be compatible with the split points and
306
+ stages_per_rank. Looped schedules (e.g.
307
+ Interleaved1F1B) require specifying
308
+ pipeline_parallel_degree = number of ranks, and
309
+ split_points = number of stages - 1
310
+ --experimental.pipeline_parallel_schedule_csv EXPERIMENTAL.PIPELINE_PARALLEL_SCHEDULE_CSV
311
+ Specify the path to the pipeline parallel schedule csv
312
+ file to use. The pipeline_parallel_schedule argument
313
+ must be either PipelineScheduleSingle,
314
+ PipelineScheduleMulti, or _PipelineScheduleRuntime.
315
+ --experimental.pipeline_parallel_microbatches EXPERIMENTAL.PIPELINE_PARALLEL_MICROBATCHES
316
+ How many microbatches to split the global training
317
+ batch into when using pipeline parallelism. The global
318
+ training batch size must be evenly divisible by the
319
+ number of microbatches. The default value will be the
320
+ number of pipeline stages, if unspecified.
321
+ --experimental.enable_compiled_autograd
322
+ Enable CompiledAutograd to compile the backward.
323
+ --experimental.context_parallel_degree EXPERIMENTAL.CONTEXT_PARALLEL_DEGREE
324
+ Context parallelism degree. 1 means disabled.
325
+ --experimental.context_parallel_rotate_method EXPERIMENTAL.CONTEXT_PARALLEL_ROTATE_METHOD
326
+ The collective to use in context parallel SDPA for kv
327
+ shards exchange. 'allgather' means to all-gather all
328
+ kv shards on ranks after the first sub-SDPA
329
+ computation, 'alltoall' means to all-to-all shuffle
330
+ the kv shards. The default value is 'allgather'.
331
+ --checkpoint.enable_checkpoint
332
+ Whether to enable checkpoint
333
+ --checkpoint.folder CHECKPOINT.FOLDER
334
+ The folder to store the checkpoints. When
335
+ enable_checkpoint is set to true, checkpoints will be
336
+ in {--job.dump_folder}/{--checkpoint.folder}.
337
+ --checkpoint.interval_type CHECKPOINT.INTERVAL_TYPE
338
+ Checkpointing interval unit of measurement ['step',
339
+ 'seconds']
340
+ --checkpoint.interval CHECKPOINT.INTERVAL
341
+ Checkpointing interval, in steps or seconds depending
342
+ on --checkpoint.interval_type
343
+ --checkpoint.model_weights_only
344
+ When model_weights_only=True, only model weights will
345
+ be saved at the end of training. With this,
346
+ checkpoints can be loaded using `torch.load(...,
347
+ weights_only=True)` after conversion. When
348
+ model_weights_only=False, the full checkpoint will be
349
+ saved. A full checkpoint includes model, optimizer and
350
+ train_state, which can be used to resume training. The
351
+ default value is false.
352
+ --checkpoint.export_dtype {float16,bfloat16,float32}
353
+ Converts to the specified precision when training
354
+ completes and model_weights_only=true. Currently
355
+ supports float32, float16, and bfloat16. The default
356
+ value is float32.
357
+ --checkpoint.create_seed_checkpoint
358
+ Initializes the full model without applying
359
+ parallelisms, and then saves it as a seed checkpoint.
360
+ Note: requires user to call train.py without
361
+ specifying any parallelisms, e.g. NGPU=1. Could be
362
+ implemented as a separate script, but this way shares
363
+ more code.
364
+ --checkpoint.async_mode CHECKPOINT.ASYNC_MODE
365
+ Which async checkpoint mode to use. Currently there
366
+ are 3 different modes. 1. "disabled": synchronized
367
+ checkpointing will be used. 2. "async":
368
+ torch.distributed.checkpoint.async_save will be used.
369
+ 1. "async_with_pinned_mem": this option utilizes a
370
+ dedicated pinned memory space and creates a separate
371
+ process for faster GPU->CPU transfer performance and
372
+ eliminating GIL contention. The cost is increased CPU
373
+ memory usage. If insufficient CPU memory is available,
374
+ performance may degrade due to memory paging. For most
375
+ users, "async" should suffice as the performance
376
+ overhead is typically small (on the order of tens of
377
+ seconds) compared to checkpointing frequency. This
378
+ mode can be employed to pursue near-zero checkpointing
379
+ times (e.g., < 1 second) given appropriate hardware
380
+ support such as ample CPU memory and fast PCIe.
381
+ "disabled" is the default mode.
382
+ --checkpoint.keep_latest_k CHECKPOINT.KEEP_LATEST_K
383
+ Keeps only the latest k checkpoints, and purging older
384
+ ones. If 0, keep all checkpoints. 0 is the default
385
+ value.
386
+ --checkpoint.load_step CHECKPOINT.LOAD_STEP
387
+ Load the checkpoint at the specified step. If -1, load
388
+ the latest checkpoint.
389
+ --float8.enable_float8_linear
390
+ If true, swaps `torch.nn.Linear` with `Float8Linear`.
391
+ This feature requires you to install 'torchao' which
392
+ can be found here: https://github.com/pytorch/ao
393
+ --float8.enable_fsdp_float8_all_gather
394
+ Whether enable float8 all-gather in FSDP
395
+ --float8.precompute_float8_dynamic_scale_for_fsdp
396
+ Whether precompute float8 scales dynamically for FSDP
397
+ --float8.scaling_type_input {dynamic,delayed}
398
+ float8 scaling for input, dynamic (default) or delayed
399
+ --float8.scaling_type_weight FLOAT8.SCALING_TYPE_WEIGHT
400
+ float8 scaling for input, dynamic (default) or delayed
401
+ --float8.scaling_type_grad_output FLOAT8.SCALING_TYPE_GRAD_OUTPUT
402
+ float8 scaling for input, dynamic (default) or delayed
403
+ --comm.init_timeout_seconds COMM.INIT_TIMEOUT_SECONDS
404
+ Timeout for communication operations, during
405
+ initialization and first train step.
406
+ --comm.train_timeout_seconds COMM.TRAIN_TIMEOUT_SECONDS
407
+ Timeout for communication operations after the first
408
+ train step -- usually a tighter bound than during
409
+ initialization.
410
+ --comm.trace_buf_size COMM.TRACE_BUF_SIZE
411
+ Flight recorder ring buffer size, >0 means recording
412
+ by default, 0 means disabled
413
+ --memory_estimation.enabled
414
+ Whether to estimate memory usage for FSDP
415
+ --memory_estimation.disable_fake_mode
416
+ Whether to estimate memory under FakeTensorMode
417
+ ```
418
+ </details>
419
+
420
+ ### Training with variable-length inputs
421
+ When you set the `--training.varlen` flag, you're enabling a more efficient training method that packs multiple documents together into a single long sequence, eliminating the need for padding.
422
+ This is particularly useful when your dataset contains documents of varying lengths.
423
+ Let's break down how `--training.seq_len` and `--training.context_len` work in this mode.
424
+
425
+ * `--training.seq_len` (Packed Sequence Length): This is the total length of the final sequence fed to the model on one device. Instead of processing one document at a time, the dataloader takes multiple documents (each split to sequences no longer than `context_len`), concatenates them end-to-end, and creates a single long sequence of length `seq_len`.
426
+ * `--training.context_len` (Sample Length): This parameter defines the maximum number of tokens for a single document or sample. If a document from the dataset is longer than `context_len`, it will be truncated. For example, if `--training.context_len` is set to 4,096, a document with 5,000 tokens will be cut down to its first 4,096 tokens, leaving the left tokens as another independent sequence, while a document with 3000 tokens remains unchanged.
427
+
428
+ ### Training with `torch.compile`
429
+
430
+ Starting from `torch 2.0`, `torch.compile` has been introduced as a new feature to seamlessly accelerate training processes.
431
+ In `flame`, one can simply enable `torch.compile` by adding `--training.compile` flag to your training script.
432
+
433
+ However, `fla` has integrated numerous fused kernels for acceleration, which may potentially conflict with `torch.compile`.
434
+ We are actively working on resolving these issues to make compilation transparent to users.
435
+ In the meantime, please ensure you are using the latest dependencies.
436
+
437
+ Specifically, **we recommend using `torch>=2.6` and `triton>=3.0`**.
438
+
439
+ ### Training with multiple datasets
440
+
441
+ If you wish to train a model with all-round capabilities (e.g., code, math, and multilingual ability), it's necessary to train on multiple datasets.
442
+ `flame` allows training with multiple datasets easily.
443
+ For example, you can specify the following arguments to train on 6 datasets with different proportions:
444
+
445
+ ```sh
446
+ --training.dataset HuggingFaceFW/fineweb-edu,opencsg/Fineweb-Edu-Chinese-V2.1,OpenCoder-LLM/opc-fineweb-code-corpus,math-ai/AutoMathText,EleutherAI/proof-pile-2,OpenCoder-LLM/opc-fineweb-math-corpus \
447
+ --training.data_probs 0.6,0.15,0.15,0.014,0.058,0.028 \
448
+ ```
449
+
450
+ ### ~Finalizing training~
451
+
452
+ > [!NOTE]
453
+ > We have done this conversion automatically in the training script since our latest updates.
454
+
455
+ Once training is complete, you may want to convert the distributed checkpoints (DCPs) into the 🤗 format for broader use.
456
+ To facilitate this, we provide a straightforward conversion script:
457
+
458
+ ```sh
459
+ python -m flame.utils.convert_dcp_to_hf --path <path_to_model> --step <step> --config <path_to_config> --tokenizer <path_to_tokenizer>
460
+ ```
461
+ After this, your model will be in the 🤗 format, ready to be shared or deployed.
462
+ You can then easily publish your model using the `huggingface_hub` for wider accessibility.
463
+
464
+ ### Continual training
465
+
466
+ If you wish to build upon a strong pre-trained model (in 🤗 format) and continue training, we also offer a script to convert the 🤗 format model back into DCP format.
467
+ This allows you to seamlessly resume training with `flame`.
468
+ ```sh
469
+ python -m flame.utils.convert_hf_to_dcp --model <path_to_hf> --checkpoint <path_to_dcp/checkpoint/step-0>
470
+ ```
471
+ Here, `<path_to_dcp>` is the directory where your distributed checkpoints will be stored.
472
+ The checkpoint is intentionally saved at `<step-0>` within the checkpoint folder to ensure it is loadable by `flame` during the initial training step, similar to how a seed checkpoint is handled.
473
+
474
+ Once the conversion is complete, you can proceed with training using `flame` as usual, continuing from where the pretrained model left off.
475
+
476
+ ## Multi-node training
477
+
478
+ If you have access to multi-node GPUs, consider leveraging them for optimal performance.
479
+ This process is straightforward and well-documented in the PyTorch [docs](https://pytorch.org/docs/stable/elastic/run.html).
480
+
481
+ To set up multi-node training:
482
+ * Set the environment variables `MASTER_ADDR=<ip>` and `MASTER_PORT=<port>` before running the training script across all nodes.
483
+ * If you're using a job scheduler like Slurm, it will handle these variables for you.
484
+
485
+ `torchtitan` provides a [Slurm script](https://github.com/pytorch/torchtitan/blob/main/multinode_trainer.slurm) for multi-node training, which you can use as a reference or starting point.
486
+
487
+ ## Custom models
488
+
489
+ `flame` supports custom model architectures through seamless integration with the Hugging Face `transformers` library. To add your own model:
490
+
491
+ 1. Create a new model directory under `custom_models/` (see `custom_models/sba` for a complete example)
492
+ 2. Implement your model classes and configuration:
493
+ - Define a config class inheriting from `PretrainedConfig` (see `custom_models/sba/config_sba.py` for an example)
494
+ - Create model classes inheriting from `PreTrainedModel` (see `custom_models/sba/modeling_sba.py` for an example)
495
+ 3. Register your models in `__init__.py`:
496
+ - Import your model classes and config classes
497
+ - Register your models with the `AutoModelForCausalLM`, `AutoModel` and `AutoConfig` classes (see `custom_models/sba/__init__.py` for an example)
498
+ 4. Create a config file for your custom model, just need to specify the `model_type` to the one you just named for your custom model (example: `configs/sba_340m.json`).
499
+ 5. Training is extremely simple, you can just use the `flame.train.py` script to train your custom model.
500
+
501
+
502
+
503
+
504
+
505
+
506
+
507
+ ## Citation
508
+
509
+ If you find `flame` helpful for your work, please consider citing it.
510
+
511
+ ```bib
512
+ @software{yang2025flame,
513
+ title = {Flame: Flash Language Modeling Made Easy},
514
+ author = {Zhang, Yu and Yang, Songlin},
515
+ url = {https://github.com/fla-org/flame},
516
+ month = jan,
517
+ year = {2025}
518
+ }
519
+ ```
config.json ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "P_rank": 1,
3
+ "architectures": [
4
+ "HamiltonForCausalLM"
5
+ ],
6
+ "attn": null,
7
+ "attn_mode": "chunk",
8
+ "bos_token_id": 1,
9
+ "conv_bias_spatial": true,
10
+ "conv_bias_temporal": true,
11
+ "conv_size_spatial": 4,
12
+ "conv_size_temporal": 16,
13
+ "decay_weight": 4.0,
14
+ "dropout": 0.0,
15
+ "dtype": "float32",
16
+ "eos_token_id": 2,
17
+ "expand_spatial": 4,
18
+ "fuse_cross_entropy": true,
19
+ "fuse_norm": true,
20
+ "fuse_swiglu": true,
21
+ "head_dim": 64,
22
+ "hidden_act": "swish",
23
+ "hidden_ratio": 3,
24
+ "hidden_size": 1024,
25
+ "initializer_range": 0.02,
26
+ "intermediate_size": null,
27
+ "max_position_embeddings": 8192,
28
+ "mnu_bias": true,
29
+ "model_type": "hamilton",
30
+ "norm_eps": 1e-06,
31
+ "num_heads": 16,
32
+ "num_hidden_layers": 24,
33
+ "task": "text",
34
+ "tie_word_embeddings": false,
35
+ "transformers_version": "4.57.3",
36
+ "use_cache": true,
37
+ "use_l2warp": false,
38
+ "use_mlp": true,
39
+ "vocab_size": 32000
40
+ }
generation_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "transformers_version": "4.57.3"
6
+ }
pyproject.toml ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "flame"
3
+ dynamic = ["version"]
4
+ description = "A minimal training framework for scaling FLA models"
5
+ readme = "README.md"
6
+ authors = [
7
+ { name = "Songlin Yang", email = "yangsl66@mit.edu" },
8
+ { name = "Yu Zhang", email = "yzhang.cs@outlook.com" },
9
+ ]
10
+ license = { file = "LICENSE" }
11
+ classifiers = [
12
+ "Programming Language :: Python :: 3",
13
+ "License :: OSI Approved :: MIT License",
14
+ "Operating System :: OS Independent",
15
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
16
+ ]
17
+ requires-python = ">=3.10"
18
+ dependencies = [
19
+ 'flash-linear-attention',
20
+ 'torch>=2.5',
21
+ 'torchdata',
22
+ 'transformers>=4.45.0',
23
+ 'triton>=3.0',
24
+ 'datasets>=3.3.0',
25
+ 'einops',
26
+ 'ninja',
27
+ 'wandb',
28
+ 'tiktoken',
29
+ 'tensorboard',
30
+ ]
31
+
32
+ [project.optional-dependencies]
33
+ dev = ["pytest"]
34
+
35
+ [project.urls]
36
+ Homepage = "https://github.com/fla-org/flame"
37
+
38
+ [build-system]
39
+ requires = ["setuptools>=45", "wheel", "ninja", "torch"]
40
+
41
+ [tool.isort]
42
+ line_length = 127
43
+ multi_line_output = 3
setup.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import ast
4
+ import os
5
+ import re
6
+ from pathlib import Path
7
+
8
+ from setuptools import find_packages, setup
9
+
10
+ with open('README.md') as f:
11
+ long_description = f.read()
12
+
13
+
14
+ def get_package_version():
15
+ with open(Path(os.path.dirname(os.path.abspath(__file__))) / 'flame' / '__init__.py') as f:
16
+ version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
17
+ return ast.literal_eval(version_match.group(1))
18
+
19
+
20
+ setup(
21
+ name='flame',
22
+ version=get_package_version(),
23
+ description='A minimal training framework for scaling FLA models',
24
+ long_description=long_description,
25
+ long_description_content_type='text/markdown',
26
+ author='Songlin Yang, Yu Zhang',
27
+ author_email='yangsl66@mit.edu, yzhang.cs@outlook.com',
28
+ url='https://github.com/fla-org/flame',
29
+ packages=find_packages(),
30
+ license='MIT',
31
+ classifiers=[
32
+ 'Programming Language :: Python :: 3',
33
+ 'License :: OSI Approved :: MIT License',
34
+ 'Operating System :: OS Independent',
35
+ 'Topic :: Scientific/Engineering :: Artificial Intelligence'
36
+ ],
37
+ python_requires='>=3.10',
38
+ install_requires=[
39
+ 'flash-linear-attention',
40
+ 'torch>=2.5',
41
+ 'torchdata',
42
+ 'transformers>=4.45.0',
43
+ 'triton>=3.0',
44
+ 'datasets>=3.3.0',
45
+ 'einops',
46
+ 'ninja',
47
+ 'wandb',
48
+ 'tiktoken',
49
+ 'tensorboard',
50
+ ],
51
+ )
special_tokens_map.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "</s>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "unk_token": {
17
+ "content": "<unk>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ }
23
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "add_prefix_space": null,
5
+ "added_tokens_decoder": {
6
+ "0": {
7
+ "content": "<unk>",
8
+ "lstrip": false,
9
+ "normalized": false,
10
+ "rstrip": false,
11
+ "single_word": false,
12
+ "special": true
13
+ },
14
+ "1": {
15
+ "content": "<s>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false,
20
+ "special": true
21
+ },
22
+ "2": {
23
+ "content": "</s>",
24
+ "lstrip": false,
25
+ "normalized": false,
26
+ "rstrip": false,
27
+ "single_word": false,
28
+ "special": true
29
+ }
30
+ },
31
+ "additional_special_tokens": [],
32
+ "bos_token": "<s>",
33
+ "clean_up_tokenization_spaces": false,
34
+ "eos_token": "</s>",
35
+ "extra_special_tokens": {},
36
+ "legacy": false,
37
+ "model_max_length": 1000000000000000019884624838656,
38
+ "pad_token": null,
39
+ "sp_model_kwargs": {},
40
+ "spaces_between_special_tokens": false,
41
+ "tokenizer_class": "LlamaTokenizer",
42
+ "unk_token": "<unk>",
43
+ "use_default_system_prompt": false
44
+ }
train.sh ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/bash
2
+
3
+ export HF_HOME="/root/workspace/huggingface_cache"
4
+
5
+ export HF_ENDPOINT=https://hf-mirror.com
6
+ # export HF_HOME="../../autodl-fs/hf_cache"
7
+
8
+ params=""
9
+ if [ $# -ne 0 ]; then
10
+ params="$*"
11
+ fi
12
+
13
+ # use envs as local params for convenience
14
+ # e.g.
15
+ # NNODE=1 NGPU=8 LOG_RANK=0 ./train.sh
16
+ NNODE=${NNODE:-"1"}
17
+ NGPU=${NGPU:-"4"}
18
+ DEVICES=${DEVICES:-"0,1,2,3"}
19
+
20
+ LOG_RANK=${LOG_RANK:-0}
21
+
22
+ if [[ -z "${MASTER_ADDR}" ]]; then
23
+ export MASTER_ADDR="localhost"
24
+ fi
25
+ if [[ -z "${MASTER_PORT}" ]]; then
26
+ export MASTER_PORT="0"
27
+ fi
28
+
29
+ : '
30
+ Usage:
31
+
32
+ bash train.sh -h
33
+
34
+ Training a 340M model:
35
+
36
+ NNODE=1 NGPU=8 LOG_RANK=0 bash train.sh \
37
+ --job.config_file flame/models/fla.toml \
38
+ --job.dump_folder exp/transformer-340M-10B/batch32.seqlen2048.warmup1024.update1.steps20480.lr3e-4 \
39
+ --model.config configs/transformer_340M.json \
40
+ --model.tokenizer_path fla-hub/transformer-1.3B-100B \
41
+ --optimizer.name AdamW \
42
+ --optimizer.eps 1e-15 \
43
+ --optimizer.lr 3e-4 \
44
+ --lr_scheduler.warmup_steps 1024 \
45
+ --lr_scheduler.lr_min 0.1 \
46
+ --lr_scheduler.decay_type cosine \
47
+ --training.batch_size 32 \
48
+ --training.seq_len 2048 \
49
+ --training.gradient_accumulation_steps 1 \
50
+ --training.steps 20480 \
51
+ --training.max_norm 1.0 \
52
+ --training.skip_nan_inf \
53
+ --training.dataset HuggingFaceFW/fineweb-edu \
54
+ --training.dataset_name default \
55
+ --training.dataset_split train \
56
+ --training.streaming \
57
+ --training.num_workers 32 \
58
+ --training.prefetch_factor 2 \
59
+ --training.seed 42 \
60
+ --training.compile \
61
+ --training.tensor_parallel_degree 1 \
62
+ --training.disable_loss_parallel \
63
+ --checkpoint.interval 2048 \
64
+ --checkpoint.load_step -1 \
65
+ --metrics.log_freq 1
66
+ '
67
+
68
+ echo "Launching training..."
69
+
70
+ set -x
71
+ path=$(grep -oP '(?<=--job.dump_folder )[^ ]+' <<< "$params")
72
+ steps=$(grep -oP '(?<=--training.steps )[^ ]+' <<< "$params")
73
+ config=$(grep -oP '(?<=--model.config )[^ ]+' <<< "$params")
74
+ tokenizer=$(grep -oP '(?<=--model.tokenizer_path )[^ ]+' <<< "$params")
75
+ model=$(
76
+ python -c "import fla, sys; from transformers import AutoConfig; print(AutoConfig.from_pretrained(sys.argv[1]).to_json_string())" "$config" | jq -r '.model_type'
77
+ )
78
+
79
+ mkdir -p $path
80
+ cp * $path
81
+ cp -r configs $path
82
+ cp -r flame $path
83
+ cp -r 3rdparty/flash-linear-attention/fla $path
84
+ cp -r 3rdparty/torchtitan/torchtitan $path
85
+
86
+ # for offline systems
87
+ # export TRANSFORMERS_OFFLINE=1
88
+ # export HF_DATASETS_OFFLINE=1
89
+ # export HF_HUB_OFFLINE=1
90
+ if [ "$date" == "" ]; then
91
+ date=$(date +%Y%m%d%H%M)
92
+ fi
93
+ RUN_NAME="$model-$(basename $path)"
94
+ RUN_ID="$RUN_NAME-$date"
95
+
96
+ export WANDB_RESUME=allow
97
+ if [[ -z "${WANDB_PROJECT}" ]]; then
98
+ export WANDB_PROJECT="fla"
99
+ fi
100
+ if [[ -z "${WANDB_NAME}" ]]; then
101
+ export WANDB_NAME="$RUN_NAME"
102
+ fi
103
+ if [[ -z "${WANDB_RUN_ID}" ]]; then
104
+ export WANDB_RUN_ID="$RUN_ID"
105
+ fi
106
+
107
+ CUDA_VISIBLE_DEVICES=${DEVICES} \
108
+ PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \
109
+ # systemd-run --scope --user -p MemoryHigh=80G \
110
+ torchrun --nnodes=${NNODE} \
111
+ --nproc_per_node=${NGPU} \
112
+ --rdzv_backend c10d \
113
+ --rdzv_endpoint "${MASTER_ADDR}:${MASTER_PORT}" \
114
+ --local-ranks-filter ${LOG_RANK} \
115
+ --role rank \
116
+ --tee 3 \
117
+ --log-dir $path/logs \
118
+ -m flame.train \
119
+ $params
120
+
121
+ echo "TRAINING DONE!"
122
+ echo "Converting the DCP checkpoints to HF format..."
123
+
124
+ python -m flame.utils.convert_dcp_to_hf \
125
+ --path $path \
126
+ --step $steps \
127
+ --config $config \
128
+ --tokenizer $tokenizer
129
+
130
+ echo "RUNNING DONE!"
train_restart.sh ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/bash
2
+
3
+ export HF_HOME="/root/workspace/huggingface_cache"
4
+
5
+ export HF_ENDPOINT=https://hf-mirror.com
6
+ # export HF_HOME="../../autodl-fs/hf_cache"
7
+
8
+ params=""
9
+ if [ $# -ne 0 ]; then
10
+ params="$*"
11
+ fi
12
+
13
+ # use envs as local params for convenience
14
+ # e.g.
15
+ # NNODE=1 NGPU=8 LOG_RANK=0 ./train.sh
16
+ NNODE=${NNODE:-"1"}
17
+ NGPU=${NGPU:-"3"}
18
+ DEVICES=${DEVICES:-"0,1,2"}
19
+
20
+ LOG_RANK=${LOG_RANK:-0}
21
+
22
+ if [[ -z "${MASTER_ADDR}" ]]; then
23
+ export MASTER_ADDR="localhost"
24
+ fi
25
+ if [[ -z "${MASTER_PORT}" ]]; then
26
+ export MASTER_PORT="0"
27
+ fi
28
+
29
+ : '
30
+ Usage:
31
+
32
+ bash train.sh -h
33
+
34
+ Training a 340M model:
35
+
36
+ NNODE=1 NGPU=8 LOG_RANK=0 bash train.sh \
37
+ --job.config_file flame/models/fla.toml \
38
+ --job.dump_folder exp/transformer-340M-10B/batch32.seqlen2048.warmup1024.update1.steps20480.lr3e-4 \
39
+ --model.config configs/transformer_340M.json \
40
+ --model.tokenizer_path fla-hub/transformer-1.3B-100B \
41
+ --optimizer.name AdamW \
42
+ --optimizer.eps 1e-15 \
43
+ --optimizer.lr 3e-4 \
44
+ --lr_scheduler.warmup_steps 1024 \
45
+ --lr_scheduler.lr_min 0.1 \
46
+ --lr_scheduler.decay_type cosine \
47
+ --training.batch_size 32 \
48
+ --training.seq_len 2048 \
49
+ --training.gradient_accumulation_steps 1 \
50
+ --training.steps 20480 \
51
+ --training.max_norm 1.0 \
52
+ --training.skip_nan_inf \
53
+ --training.dataset HuggingFaceFW/fineweb-edu \
54
+ --training.dataset_name default \
55
+ --training.dataset_split train \
56
+ --training.streaming \
57
+ --training.num_workers 32 \
58
+ --training.prefetch_factor 2 \
59
+ --training.seed 42 \
60
+ --training.compile \
61
+ --training.tensor_parallel_degree 1 \
62
+ --training.disable_loss_parallel \
63
+ --checkpoint.interval 2048 \
64
+ --checkpoint.load_step -1 \
65
+ --metrics.log_freq 1
66
+ '
67
+
68
+ echo "Launching training..."
69
+
70
+ set -x
71
+ path=$(grep -oP '(?<=--job.dump_folder )[^ ]+' <<< "$params")
72
+ steps=$(grep -oP '(?<=--training.steps )[^ ]+' <<< "$params")
73
+ config=$(grep -oP '(?<=--model.config )[^ ]+' <<< "$params")
74
+ tokenizer=$(grep -oP '(?<=--model.tokenizer_path )[^ ]+' <<< "$params")
75
+ model=$(
76
+ python -c "import fla, sys; from transformers import AutoConfig; print(AutoConfig.from_pretrained(sys.argv[1]).to_json_string())" "$config" | jq -r '.model_type'
77
+ )
78
+
79
+ mkdir -p $path
80
+ cp * $path
81
+ cp -r configs $path
82
+ cp -r flame $path
83
+ cp -r 3rdparty/flash-linear-attention/fla $path
84
+ cp -r 3rdparty/torchtitan/torchtitan $path
85
+
86
+ # for offline systems
87
+ # export TRANSFORMERS_OFFLINE=1
88
+ # export HF_DATASETS_OFFLINE=1
89
+ # export HF_HUB_OFFLINE=1
90
+ if [ "$date" == "" ]; then
91
+ date=$(date +%Y%m%d%H%M)
92
+ fi
93
+ RUN_NAME="$model-$(basename $path)"
94
+ RUN_ID="$RUN_NAME-$date"
95
+
96
+ export WANDB_RESUME=allow
97
+ if [[ -z "${WANDB_PROJECT}" ]]; then
98
+ export WANDB_PROJECT="fla"
99
+ fi
100
+ if [[ -z "${WANDB_NAME}" ]]; then
101
+ export WANDB_NAME="$RUN_NAME"
102
+ fi
103
+ if [[ -z "${WANDB_RUN_ID}" ]]; then
104
+ export WANDB_RUN_ID="$RUN_ID"
105
+ fi
106
+
107
+ CUDA_VISIBLE_DEVICES=${DEVICES} \
108
+ PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \
109
+ # systemd-run --scope --user -p MemoryHigh=80G \
110
+ torchrun --nnodes=${NNODE} \
111
+ --nproc_per_node=${NGPU} \
112
+ --rdzv_backend c10d \
113
+ --rdzv_endpoint "${MASTER_ADDR}:${MASTER_PORT}" \
114
+ --local-ranks-filter ${LOG_RANK} \
115
+ --role rank \
116
+ --tee 3 \
117
+ --log-dir $path/logs \
118
+ -m flame.train_restart \
119
+ $params
120
+
121
+ echo "TRAINING DONE!"
122
+ echo "Converting the DCP checkpoints to HF format..."
123
+
124
+ python -m flame.utils.convert_dcp_to_hf \
125
+ --path $path \
126
+ --step $steps \
127
+ --config $config \
128
+ --tokenizer $tokenizer
129
+
130
+ echo "RUNNING DONE!"