Upload 10 files

- README.md +519 -3
- config.json +40 -0
- generation_config.json +6 -0
- pyproject.toml +43 -0
- setup.py +51 -0
- special_tokens_map.json +23 -0
- tokenizer.json +0 -0
- tokenizer_config.json +44 -0
- train.sh +130 -0
- train_restart.sh +130 -0
README.md
CHANGED
@@ -1,3 +1,519 @@
<div align="center">

# 🔥 Flame: Flash Language Modeling Made Easy

[](https://deepwiki.com/fla-org/flame)

</div>

Welcome to 🔥 `flame`, a minimal framework built on `torchtitan` for training language models with blazing efficiency.

**Feature Highlights:**

- 🚀 Minimal, easy-to-use, extensible training framework
- 🤗 Seamless integration with `fla` and `transformers`
- 🔄 Zero-cost data preprocessing: online tokenization, dataset shuffling, and support for multiple datasets
- 🔮 4D parallelism (coming soon)

## Setup

To get started, clone the `flame` repository and install the required dependencies:

```bash
git clone https://github.com/fla-org/flame.git
cd flame
pip install .
```

Install the latest version of `fla`:

```bash
pip uninstall flash-linear-attention && pip install -U --no-use-pep517 git+https://github.com/fla-org/flash-linear-attention
```

[Important] Install the pinned version of `torchtitan`:

```bash
pip install git+https://github.com/pytorch/torchtitan.git@0b44d4c
```

## Dataset Preparation

To download the dataset to your local disk, create a new Python file with the following content and execute it:

```py
from datasets import load_dataset

# load fineweb-edu with parallel processing
dataset = load_dataset("HuggingFaceFW/fineweb-edu", name="default", num_proc=64, cache_dir="/your/cache/path")

# or load a subset with roughly 100B tokens, suitable for small- or medium-sized experiments
dataset = load_dataset("HuggingFaceFW/fineweb-edu", name="sample-100BT", num_proc=64, cache_dir="/your/cache/path")
```

## Training Recipes

Here's an example of training a 340M FLA Transformer model with a LLaMA-like architecture from scratch on a 100BT subset of the Fineweb-edu corpus ~~in streaming mode~~. (Do not use streaming mode if you are concerned about resuming training.)

> [!WARNING]
> If the dataset is not downloaded beforehand, streaming mode will attempt to fetch it from a remote server and download it on the fly, which can be highly unstable during training due to network issues.
> For stable training, ensure the dataset is downloaded locally (see [**Dataset Preparation**](#dataset-preparation)). Otherwise, we assume you are only testing a new corpus.

```sh
bash train.sh \
  --job.config_file flame/models/fla.toml \
  --job.dump_folder exp/transformer-340M-4K-10B/batch1.seqlen65536.context4096.warmup1024.update1.steps20480.lr1e-3.cosine \
  --model.config configs/transformer_340M.json \
  --model.tokenizer_path fla-hub/transformer-1.3B-100B \
  --optimizer.name AdamW \
  --optimizer.eps 1e-15 \
  --optimizer.lr 1e-3 \
  --lr_scheduler.warmup_steps 1024 \
  --lr_scheduler.lr_min 0.1 \
  --lr_scheduler.decay_type cosine \
  --training.batch_size 1 \
  --training.seq_len 65536 \
  --training.context_len 4096 \
  --training.varlen \
  --training.gradient_accumulation_steps 1 \
  --training.steps 20480 \
  --training.max_norm 1.0 \
  --training.skip_nan_inf \
  --training.dataset HuggingFaceFW/fineweb-edu \
  --training.dataset_name sample-100BT \
  --training.dataset_split train \
  --training.num_workers 32 \
  --training.prefetch_factor 2 \
  --training.seed 42 \
  --training.compile \
  --checkpoint.interval 2048 \
  --checkpoint.load_step -1 \
  --checkpoint.keep_latest_k 2 \
  --metrics.log_freq 1
```

You can specify the number of GPUs by setting the environment variable `NGPU`, which defaults to 8.
**For single-GPU debugging, set `NGPU=1`.**

We provide several [config files](https://github.com/fla-org/flame/tree/main/configs) for different models.
By default, the learning rate is set to 1e-3 with a cosine scheduler. Other schedulers, such as WSD (wsd), are also supported.

**Key parameters:**
- `--lr_scheduler.decay_ratio`: The proportion of the steps allocated to the decay phase. The learning rate remains stable after the warmup period and only starts decaying during the last `decay_ratio` portion of the total training steps, which is known as the Warmup-Stable-Decay (WSD) schedule (a schematic sketch follows this list).
- `--lr_scheduler.warmup_steps`: The number of steps for the learning rate warmup phase.
- `--training.steps`: Total number of training steps.
- `--training.batch_size`: Batch size per device; must be 1 if `--training.varlen` is set.
- `--training.seq_len`: The length of each sequence in the batch, which is concatenated from multiple samples.
- `--training.context_len`: The max allowed length of a sample. For non-varlen mode, this is equivalent to `seq_len`.
- `--training.varlen`: Whether to conduct variable-length sequence training.
- `--training.gradient_accumulation_steps`: Number of gradient accumulation steps.
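
To build intuition for the WSD shape, here is a purely schematic sketch. It is not `flame`'s actual scheduler (which follows `torchtitan`); the linear decay form and the function/argument names below are illustrative assumptions only:

```py
# Schematic warmup-stable-decay (WSD) curve; illustration only, not flame's implementation.
def wsd_lr(step, steps=20480, warmup_steps=1024, decay_ratio=0.1, lr=1e-3, min_lr_ratio=0.1):
    decay_start = int(steps * (1 - decay_ratio))      # decay only over the last `decay_ratio` of training
    if step < warmup_steps:                           # 1) linear warmup
        return lr * step / warmup_steps
    if step < decay_start:                            # 2) stable phase at the peak learning rate
        return lr
    progress = (step - decay_start) / (steps - decay_start)
    return lr * (1 - progress * (1 - min_lr_ratio))   # 3) decay towards min_lr_ratio * lr
```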

> [!WARNING]
> The total number of tokens processed per batch, referred to as `global_batch_size`, is calculated as `batch_size × gradient_accumulation_steps × num_gpus`.
> Each step processes `global_batch_size * seq_len` tokens.
> Monitor the value of `global_batch_size`, `warmup_steps`, and `steps` carefully when modifying any of the hyperparameters!
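
As a concrete sanity check, the example command above (assuming 8 GPUs) works out as follows; this is plain arithmetic, nothing `flame`-specific:

```py
# Token accounting for the example command above.
batch_size = 1                    # --training.batch_size
gradient_accumulation_steps = 1   # --training.gradient_accumulation_steps
num_gpus = 8                      # NGPU
seq_len = 65536                   # --training.seq_len
steps = 20480                     # --training.steps

global_batch_size = batch_size * gradient_accumulation_steps * num_gpus   # 8
tokens_per_step = global_batch_size * seq_len                             # 524,288 tokens per step
total_tokens = tokens_per_step * steps                                    # ~10.7B tokens over the run
```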

For a detailed explanation of all parameters, run:

```sh
bash train.sh -h
```

<details>
<summary>Usage</summary>

```py
options:
  -h, --help            show this help message and exit
  --job.config_file JOB.CONFIG_FILE
                        Job config file
  --job.dump_folder JOB.DUMP_FOLDER
                        Folder to dump job outputs
  --job.description JOB.DESCRIPTION
                        Description of the job
  --job.use_for_integration_test
                        Add this config to the integration test suite
  --job.print_args      Print the args to terminal
  --model.config MODEL.CONFIG
                        Path to the model config
  --model.norm_type MODEL.NORM_TYPE
                        Type of layer normalization to use [layernorm,
                        np_layernorm, rmsnorm, fused_rmsnorm]
  --model.tokenizer_path MODEL.TOKENIZER_PATH
                        Tokenizer path
  --profiling.enable_profiling
                        Whether to enable pytorch profiler
  --profiling.save_traces_folder PROFILING.SAVE_TRACES_FOLDER
                        Trace files location
  --profiling.profile_freq PROFILING.PROFILE_FREQ
                        How often to collect profiler traces, in iterations
  --profiling.enable_memory_snapshot
                        Whether to dump memory snapshot
  --profiling.save_memory_snapshot_folder PROFILING.SAVE_MEMORY_SNAPSHOT_FOLDER
                        Memeory snapshot files location
  --optimizer.name OPTIMIZER.NAME
                        Optimizer to use
  --optimizer.eps OPTIMIZER.EPS
                        Epsilon value for the optimizer.
  --optimizer.fused     Whether the fused implementation(CUDA only) is used.
  --optimizer.scheduler {wsd,cosine,linear}
                        Scheduler to use. Currently supported: wsd, cosine,
                        and linear.
  --optimizer.lr OPTIMIZER.LR
                        Learning rate to use
  --optimizer.min_lr_ratio OPTIMIZER.MIN_LR_RATIO
                        Min lr ratio for lr scheduler
  --optimizer.early_step_in_backward
                        Whether to apply optimizer in the backward. Caution,
                        optimizer_in_backward is not compatible with gradients
                        clipping, users should not call
                        register_post_accumulate_grad_hook after the optimizer
                        is built.
  --training.batch_size TRAINING.BATCH_SIZE
                        Batch size
  --training.seq_len TRAINING.SEQ_LEN
                        Sequence length
  --training.context_len TRAINING.CONTEXT_LEN
                        Max length allowed for each sequence
  --training.varlen     Whether to take sequences of variable length as input
  --training.warmup_steps TRAINING.WARMUP_STEPS
                        Steps for lr scheduler warmup, normally 1/5 of
                        --training.steps
  --training.gradient_accumulation_steps TRAINING.GRADIENT_ACCUMULATION_STEPS
                        Number of steps to accumulate gradients before
                        updating parameters
  --training.steps TRAINING.STEPS
                        How many train steps to run
  --training.max_norm TRAINING.MAX_NORM
                        Max norm for gradient clipping
  --training.skip_nan_inf
                        Skip batch updates when NaN or INF gradients are
                        encountered during training
  --training.dataset TRAINING.DATASET
                        Dataset to use, with comma separated values
  --training.dataset_name TRAINING.DATASET_NAME
                        The name of the dataset config, with comma separated
                        values if provided
  --training.dataset_split TRAINING.DATASET_SPLIT
                        Dataset split to use, with comma separated values if
                        provided
  --training.data_dir TRAINING.DATA_DIR
                        Data dirs to use, with comma separated values if
                        provided
  --training.data_files TRAINING.DATA_FILES
                        Data files to use, with comma separated values if
                        provided
  --training.data_probs TRAINING.DATA_PROBS
                        Data sampling probabilities, with comma separated
                        values if provided
  --training.streaming  Whether to load dataset in streaming mode, used for
                        huge dataset
  --training.num_workers TRAINING.NUM_WORKERS
                        Number of subprocesses to use for data loading. 0
                        means that the data will be loaded in the main
                        process.
  --training.prefetch_factor TRAINING.PREFETCH_FACTOR
                        Number of batches loaded in advance by each worker.2
                        means there will be a total of 2 * num_workers batches
                        prefetched across all workers.
  --training.data_parallel_replicate_degree TRAINING.DATA_PARALLEL_REPLICATE_DEGREE
                        The `data_parallel_replicate_degree` argument
                        specifies the degree of data parallelism for weight
                        replication. When this value is greater than 1,
                        weights will be replicated across
                        `data_parallel_replicate_degree` ranks. If
                        `data_parallel_shard_degree` is also greater than 1,
                        the parallelism method used is HSDP (Hybrid Sharded
                        Data Parallelism). Otherwise, the parallelism method
                        used is DDP (Distributed Data Parallelism). 1 means
                        disabled.
  --training.data_parallel_shard_degree TRAINING.DATA_PARALLEL_SHARD_DEGREE
                        The `data_parallel_shard_degree` argument specifies
                        the degree of data parallelism for weight sharding.
                        When this value is greater than 1, weights will be
                        sharded across `data_parallel_shard_degree` ranks. If
                        `data_parallel_replicate_degree` is also greater than
                        1, the parallelism method used is HSDP (Hybrid Sharded
                        Data Parallelism). Otherwise, the parallelism method
                        used is FSDP (Fully Sharded Data Parallelism). -1
                        means leftover ranks will be used (After
                        DP_REPLICATE/SP/PP). Note that only
                        `data_parallel_shard_degree` can be negative. 1 means
                        disabled.
  --training.enable_cpu_offload
                        Whether to apply CPU offloading of parameters,
                        gradients, and optimizer states in FSDP
  --training.tensor_parallel_degree TRAINING.TENSOR_PARALLEL_DEGREE
                        Tensor Parallelism degree. 1 means disabled.
  --training.disable_loss_parallel
                        Whether to apply loss parallel when sequence parallel
                        is enabled
  --training.mixed_precision_param {bfloat16,float32}
                        torch dtype to use for parameters when applying mixed
                        precision via FSDP. This feature only takes effect
                        when data_parallel_shard_degree > 1
  --training.mixed_precision_reduce {float32}
                        torch dtype to use for reductions when applying mixed
                        precision via FSDP. This feature only takes effect
                        when data_parallel_shard_degree > 1
  --training.compile    Whether to compile the model
  --training.gc_freq TRAINING.GC_FREQ
                        Python garbage control scheduling interval, in steps
  --training.seed TRAINING.SEED
                        Choose the base RNG seed used for training
  --training.deterministic
                        Use deterministic algorithms wherever possible, may be
                        slower
  --metrics.log_freq METRICS.LOG_FREQ
                        How often to log metrics to TensorBoard, in iterations
  --metrics.enable_tensorboard
                        Whether to log metrics to TensorBoard
  --metrics.disable_color_printing
                        Whether to disable color printing in logs
  --metrics.save_tb_folder METRICS.SAVE_TB_FOLDER
                        Folder to dump TensorBoard states
  --metrics.rank_0_only
                        Whether to save TensorBoard metrics only for rank 0 or
                        for all ranks. When pipeline_parallel_degree is > 1,
                        this option uses the 0th rank of the last stage
                        pipeline group, which is the only stage that computes
                        loss metrics.
  --metrics.enable_wandb
                        Whether to log metrics to Weights & Biases
  --experimental.enable_async_tensor_parallel
                        Whether to apply async tensor parallel (currently only
                        effective when compile is enabled)
  --experimental.pipeline_parallel_degree EXPERIMENTAL.PIPELINE_PARALLEL_DEGREE
                        Pipeline Parallelism degree, or number of ranks. 1
                        means disabled. If using looped schedules, this still
                        specifies the number of physical ranks, not the number
                        of stages. Stages per rank are inferred from split
                        points degree, and schedule.
  --experimental.pipeline_parallel_split_points EXPERIMENTAL.PIPELINE_PARALLEL_SPLIT_POINTS [EXPERIMENTAL.PIPELINE_PARALLEL_SPLIT_POINTS ...]
                        Specify comma-separated names of modules to use as the
                        beginning of a split point. e.g. "layers.0,layers.2"
                        will cause the model to be split into 3 stages, the
                        first containing all the layers up to layers.0, the
                        second containing layers.0 and up to layers.2, the
                        third containing layers.2 and all the remaining
                        layers. Note: fully-automated splitting may be enabled
                        in the future, but currently the split points must be
                        specified manually.
  --experimental.pipeline_parallel_schedule EXPERIMENTAL.PIPELINE_PARALLEL_SCHEDULE
                        Specify the Pipeline Parallel schedule to use. The
                        supported schedules are: https://github.com/pytorch/py
                        torch/blob/de4c2a3b4e89d96334dc678d1c3f2ae51a6630a0/to
                        rch/distributed/pipelining/schedules.py#L2161. The
                        schedule must be compatible with the split points and
                        stages_per_rank. Looped schedules (e.g.
                        Interleaved1F1B) require specifying
                        pipeline_parallel_degree = number of ranks, and
                        split_points = number of stages - 1
  --experimental.pipeline_parallel_schedule_csv EXPERIMENTAL.PIPELINE_PARALLEL_SCHEDULE_CSV
                        Specify the path to the pipeline parallel schedule csv
                        file to use. The pipeline_parallel_schedule argument
                        must be either PipelineScheduleSingle,
                        PipelineScheduleMulti, or _PipelineScheduleRuntime.
  --experimental.pipeline_parallel_microbatches EXPERIMENTAL.PIPELINE_PARALLEL_MICROBATCHES
                        How many microbatches to split the global training
                        batch into when using pipeline parallelism. The global
                        training batch size must be evenly divisible by the
                        number of microbatches. The default value will be the
                        number of pipeline stages, if unspecified.
  --experimental.enable_compiled_autograd
                        Enable CompiledAutograd to compile the backward.
  --experimental.context_parallel_degree EXPERIMENTAL.CONTEXT_PARALLEL_DEGREE
                        Context parallelism degree. 1 means disabled.
  --experimental.context_parallel_rotate_method EXPERIMENTAL.CONTEXT_PARALLEL_ROTATE_METHOD
                        The collective to use in context parallel SDPA for kv
                        shards exchange. 'allgather' means to all-gather all
                        kv shards on ranks after the first sub-SDPA
                        computation, 'alltoall' means to all-to-all shuffle
                        the kv shards. The default value is 'allgather'.
  --checkpoint.enable_checkpoint
                        Whether to enable checkpoint
  --checkpoint.folder CHECKPOINT.FOLDER
                        The folder to store the checkpoints. When
                        enable_checkpoint is set to true, checkpoints will be
                        in {--job.dump_folder}/{--checkpoint.folder}.
  --checkpoint.interval_type CHECKPOINT.INTERVAL_TYPE
                        Checkpointing interval unit of measurement ['step',
                        'seconds']
  --checkpoint.interval CHECKPOINT.INTERVAL
                        Checkpointing interval, in steps or seconds depending
                        on --checkpoint.interval_type
  --checkpoint.model_weights_only
                        When model_weights_only=True, only model weights will
                        be saved at the end of training. With this,
                        checkpoints can be loaded using `torch.load(...,
                        weights_only=True)` after conversion. When
                        model_weights_only=False, the full checkpoint will be
                        saved. A full checkpoint includes model, optimizer and
                        train_state, which can be used to resume training. The
                        default value is false.
  --checkpoint.export_dtype {float16,bfloat16,float32}
                        Converts to the specified precision when training
                        completes and model_weights_only=true. Currently
                        supports float32, float16, and bfloat16. The default
                        value is float32.
  --checkpoint.create_seed_checkpoint
                        Initializes the full model without applying
                        parallelisms, and then saves it as a seed checkpoint.
                        Note: requires user to call train.py without
                        specifying any parallelisms, e.g. NGPU=1. Could be
                        implemented as a separate script, but this way shares
                        more code.
  --checkpoint.async_mode CHECKPOINT.ASYNC_MODE
                        Which async checkpoint mode to use. Currently there
                        are 3 different modes. 1. "disabled": synchronized
                        checkpointing will be used. 2. "async":
                        torch.distributed.checkpoint.async_save will be used.
                        1. "async_with_pinned_mem": this option utilizes a
                        dedicated pinned memory space and creates a separate
                        process for faster GPU->CPU transfer performance and
                        eliminating GIL contention. The cost is increased CPU
                        memory usage. If insufficient CPU memory is available,
                        performance may degrade due to memory paging. For most
                        users, "async" should suffice as the performance
                        overhead is typically small (on the order of tens of
                        seconds) compared to checkpointing frequency. This
                        mode can be employed to pursue near-zero checkpointing
                        times (e.g., < 1 second) given appropriate hardware
                        support such as ample CPU memory and fast PCIe.
                        "disabled" is the default mode.
  --checkpoint.keep_latest_k CHECKPOINT.KEEP_LATEST_K
                        Keeps only the latest k checkpoints, and purging older
                        ones. If 0, keep all checkpoints. 0 is the default
                        value.
  --checkpoint.load_step CHECKPOINT.LOAD_STEP
                        Load the checkpoint at the specified step. If -1, load
                        the latest checkpoint.
  --float8.enable_float8_linear
                        If true, swaps `torch.nn.Linear` with `Float8Linear`.
                        This feature requires you to install 'torchao' which
                        can be found here: https://github.com/pytorch/ao
  --float8.enable_fsdp_float8_all_gather
                        Whether enable float8 all-gather in FSDP
  --float8.precompute_float8_dynamic_scale_for_fsdp
                        Whether precompute float8 scales dynamically for FSDP
  --float8.scaling_type_input {dynamic,delayed}
                        float8 scaling for input, dynamic (default) or delayed
  --float8.scaling_type_weight FLOAT8.SCALING_TYPE_WEIGHT
                        float8 scaling for input, dynamic (default) or delayed
  --float8.scaling_type_grad_output FLOAT8.SCALING_TYPE_GRAD_OUTPUT
                        float8 scaling for input, dynamic (default) or delayed
  --comm.init_timeout_seconds COMM.INIT_TIMEOUT_SECONDS
                        Timeout for communication operations, during
                        initialization and first train step.
  --comm.train_timeout_seconds COMM.TRAIN_TIMEOUT_SECONDS
                        Timeout for communication operations after the first
                        train step -- usually a tighter bound than during
                        initialization.
  --comm.trace_buf_size COMM.TRACE_BUF_SIZE
                        Flight recorder ring buffer size, >0 means recording
                        by default, 0 means disabled
  --memory_estimation.enabled
                        Whether to estimate memory usage for FSDP
  --memory_estimation.disable_fake_mode
                        Whether to estimate memory under FakeTensorMode
```

</details>

### Training with variable-length inputs
When you set the `--training.varlen` flag, you're enabling a more efficient training method that packs multiple documents together into a single long sequence, eliminating the need for padding.
This is particularly useful when your dataset contains documents of varying lengths.
Let's break down how `--training.seq_len` and `--training.context_len` work in this mode (a small illustrative sketch follows the two definitions).

* `--training.seq_len` (Packed Sequence Length): This is the total length of the final sequence fed to the model on one device. Instead of processing one document at a time, the dataloader takes multiple documents (each split into sequences no longer than `context_len`), concatenates them end-to-end, and creates a single long sequence of length `seq_len`.
* `--training.context_len` (Sample Length): This parameter defines the maximum number of tokens for a single document or sample. If a document from the dataset is longer than `context_len`, it is split. For example, if `--training.context_len` is set to 4,096, a document with 5,000 tokens is cut into its first 4,096 tokens, with the remaining tokens becoming another independent sequence, while a document with 3,000 tokens remains unchanged.
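
The following minimal sketch illustrates this packing behaviour; it is not `flame`'s actual dataloader, and the function name and exact splitting policy are assumptions for illustration only:

```py
# Illustrative packing sketch, not flame's actual dataloader.
def pack_documents(docs, seq_len=65536, context_len=4096):
    """Split each document into chunks of at most `context_len` tokens,
    then concatenate chunks until a packed sequence reaches `seq_len`."""
    packed, lengths = [], []              # token buffer and per-chunk lengths (for varlen attention)
    for doc in docs:                      # each `doc` is a list of token ids
        for i in range(0, len(doc), context_len):
            chunk = doc[i:i + context_len]
            if len(packed) + len(chunk) > seq_len:
                yield packed, lengths     # emit one packed training sequence
                packed, lengths = [], []
            packed.extend(chunk)
            lengths.append(len(chunk))
    if packed:
        yield packed, lengths
```

With `context_len=4096`, a 5,000-token document contributes a 4,096-token chunk plus a 904-token chunk, while a 3,000-token document stays whole, as described above.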

### Training with `torch.compile`

Starting from `torch 2.0`, `torch.compile` has been introduced as a new feature to seamlessly accelerate training processes.
In `flame`, you can simply enable `torch.compile` by adding the `--training.compile` flag to your training script.

However, `fla` has integrated numerous fused kernels for acceleration, which may potentially conflict with `torch.compile`.
We are actively working on resolving these issues to make compilation transparent to users.
In the meantime, please ensure you are using the latest dependencies.

Specifically, **we recommend using `torch>=2.6` and `triton>=3.0`**.

### Training with multiple datasets

If you wish to train a model with all-round capabilities (e.g., code, math, and multilingual ability), it's necessary to train on multiple datasets.
`flame` allows training with multiple datasets easily.
For example, you can specify the following arguments to train on 6 datasets with different proportions:

```sh
--training.dataset HuggingFaceFW/fineweb-edu,opencsg/Fineweb-Edu-Chinese-V2.1,OpenCoder-LLM/opc-fineweb-code-corpus,math-ai/AutoMathText,EleutherAI/proof-pile-2,OpenCoder-LLM/opc-fineweb-math-corpus \
--training.data_probs 0.6,0.15,0.15,0.014,0.058,0.028 \
```

### ~Finalizing training~

> [!NOTE]
> We have done this conversion automatically in the training script since our latest updates.

Once training is complete, you may want to convert the distributed checkpoints (DCPs) into the 🤗 format for broader use.
To facilitate this, we provide a straightforward conversion script:

```sh
python -m flame.utils.convert_dcp_to_hf --path <path_to_model> --step <step> --config <path_to_config> --tokenizer <path_to_tokenizer>
```

After this, your model will be in the 🤗 format, ready to be shared or deployed.
You can then easily publish your model using the `huggingface_hub` for wider accessibility.
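
For example, loading the exported checkpoint back with `transformers` and pushing it to the Hub could look like the sketch below; the path and repo name are placeholders, and importing `fla` is what registers its architectures with `transformers`:

```py
import fla  # registers FLA model classes with transformers
from transformers import AutoModelForCausalLM, AutoTokenizer

path = "<path_to_exported_hf_model>"  # placeholder: wherever the converted HF files were written
model = AutoModelForCausalLM.from_pretrained(path)
tokenizer = AutoTokenizer.from_pretrained(path)

# optionally publish to the Hugging Face Hub
model.push_to_hub("<your-username>/<your-model-name>")
tokenizer.push_to_hub("<your-username>/<your-model-name>")
```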

### Continual training

If you wish to build upon a strong pre-trained model (in 🤗 format) and continue training, we also offer a script to convert the 🤗 format model back into DCP format.
This allows you to seamlessly resume training with `flame`.

```sh
python -m flame.utils.convert_hf_to_dcp --model <path_to_hf> --checkpoint <path_to_dcp/checkpoint/step-0>
```

Here, `<path_to_dcp>` is the directory where your distributed checkpoints will be stored.
The checkpoint is intentionally saved at `<step-0>` within the checkpoint folder to ensure it is loadable by `flame` during the initial training step, similar to how a seed checkpoint is handled.

Once the conversion is complete, you can proceed with training using `flame` as usual, continuing from where the pretrained model left off.

## Multi-node training

If you have access to multi-node GPUs, consider leveraging them for optimal performance.
This process is straightforward and well-documented in the PyTorch [docs](https://pytorch.org/docs/stable/elastic/run.html).

To set up multi-node training:
* Set the environment variables `MASTER_ADDR=<ip>` and `MASTER_PORT=<port>` before running the training script across all nodes.
* If you're using a job scheduler like Slurm, it will handle these variables for you.

`torchtitan` provides a [Slurm script](https://github.com/pytorch/torchtitan/blob/main/multinode_trainer.slurm) for multi-node training, which you can use as a reference or starting point.

## Custom models

`flame` supports custom model architectures through seamless integration with the Hugging Face `transformers` library. To add your own model:

1. Create a new model directory under `custom_models/` (see `custom_models/sba` for a complete example)
2. Implement your model classes and configuration:
   - Define a config class inheriting from `PretrainedConfig` (see `custom_models/sba/config_sba.py` for an example)
   - Create model classes inheriting from `PreTrainedModel` (see `custom_models/sba/modeling_sba.py` for an example)
3. Register your models in `__init__.py` (a registration sketch follows this list):
   - Import your model classes and config classes
   - Register your models with the `AutoModelForCausalLM`, `AutoModel` and `AutoConfig` classes (see `custom_models/sba/__init__.py` for an example)
4. Create a config file for your custom model; you only need to set `model_type` to the name you registered (example: `configs/sba_340m.json`).
5. Training is then straightforward: simply launch `flame.train` (e.g., via `train.sh`) with your custom config to train your custom model.
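
As a rough sketch of steps 2 and 3, a registration module could look like the following; the `my_model` names are hypothetical placeholders rather than the actual `sba` example:

```py
# custom_models/my_model/__init__.py (hypothetical names, for illustration only)
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from .configuration_my_model import MyModelConfig
from .modeling_my_model import MyModelForCausalLM, MyModelModel

# make configs with `"model_type": "my_model"` resolvable by the Auto* classes
AutoConfig.register("my_model", MyModelConfig)
AutoModel.register(MyModelConfig, MyModelModel)
AutoModelForCausalLM.register(MyModelConfig, MyModelForCausalLM)
```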

## Citation

If you find `flame` helpful for your work, please consider citing it.

```bib
@software{yang2025flame,
  title  = {Flame: Flash Language Modeling Made Easy},
  author = {Zhang, Yu and Yang, Songlin},
  url    = {https://github.com/fla-org/flame},
  month  = jan,
  year   = {2025}
}
```

config.json
ADDED
@@ -0,0 +1,40 @@
{
  "P_rank": 1,
  "architectures": [
    "HamiltonForCausalLM"
  ],
  "attn": null,
  "attn_mode": "chunk",
  "bos_token_id": 1,
  "conv_bias_spatial": true,
  "conv_bias_temporal": true,
  "conv_size_spatial": 4,
  "conv_size_temporal": 16,
  "decay_weight": 4.0,
  "dropout": 0.0,
  "dtype": "float32",
  "eos_token_id": 2,
  "expand_spatial": 4,
  "fuse_cross_entropy": true,
  "fuse_norm": true,
  "fuse_swiglu": true,
  "head_dim": 64,
  "hidden_act": "swish",
  "hidden_ratio": 3,
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": null,
  "max_position_embeddings": 8192,
  "mnu_bias": true,
  "model_type": "hamilton",
  "norm_eps": 1e-06,
  "num_heads": 16,
  "num_hidden_layers": 24,
  "task": "text",
  "tie_word_embeddings": false,
  "transformers_version": "4.57.3",
  "use_cache": true,
  "use_l2warp": false,
  "use_mlp": true,
  "vocab_size": 32000
}

generation_config.json
ADDED
@@ -0,0 +1,6 @@
{
  "_from_model_config": true,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "transformers_version": "4.57.3"
}

pyproject.toml
ADDED
@@ -0,0 +1,43 @@
[project]
name = "flame"
dynamic = ["version"]
description = "A minimal training framework for scaling FLA models"
readme = "README.md"
authors = [
    { name = "Songlin Yang", email = "yangsl66@mit.edu" },
    { name = "Yu Zhang", email = "yzhang.cs@outlook.com" },
]
license = { file = "LICENSE" }
classifiers = [
    "Programming Language :: Python :: 3",
    "License :: OSI Approved :: MIT License",
    "Operating System :: OS Independent",
    "Topic :: Scientific/Engineering :: Artificial Intelligence",
]
requires-python = ">=3.10"
dependencies = [
    'flash-linear-attention',
    'torch>=2.5',
    'torchdata',
    'transformers>=4.45.0',
    'triton>=3.0',
    'datasets>=3.3.0',
    'einops',
    'ninja',
    'wandb',
    'tiktoken',
    'tensorboard',
]

[project.optional-dependencies]
dev = ["pytest"]

[project.urls]
Homepage = "https://github.com/fla-org/flame"

[build-system]
requires = ["setuptools>=45", "wheel", "ninja", "torch"]

[tool.isort]
line_length = 127
multi_line_output = 3

setup.py
ADDED
@@ -0,0 +1,51 @@
# -*- coding: utf-8 -*-

import ast
import os
import re
from pathlib import Path

from setuptools import find_packages, setup

with open('README.md') as f:
    long_description = f.read()


def get_package_version():
    with open(Path(os.path.dirname(os.path.abspath(__file__))) / 'flame' / '__init__.py') as f:
        version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
        return ast.literal_eval(version_match.group(1))


setup(
    name='flame',
    version=get_package_version(),
    description='A minimal training framework for scaling FLA models',
    long_description=long_description,
    long_description_content_type='text/markdown',
    author='Songlin Yang, Yu Zhang',
    author_email='yangsl66@mit.edu, yzhang.cs@outlook.com',
    url='https://github.com/fla-org/flame',
    packages=find_packages(),
    license='MIT',
    classifiers=[
        'Programming Language :: Python :: 3',
        'License :: OSI Approved :: MIT License',
        'Operating System :: OS Independent',
        'Topic :: Scientific/Engineering :: Artificial Intelligence'
    ],
    python_requires='>=3.10',
    install_requires=[
        'flash-linear-attention',
        'torch>=2.5',
        'torchdata',
        'transformers>=4.45.0',
        'triton>=3.0',
        'datasets>=3.3.0',
        'einops',
        'ninja',
        'wandb',
        'tiktoken',
        'tensorboard',
    ],
)

special_tokens_map.json
ADDED
@@ -0,0 +1,23 @@
{
  "bos_token": {
    "content": "<s>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "eos_token": {
    "content": "</s>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "unk_token": {
    "content": "<unk>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  }
}

tokenizer.json
ADDED
The diff for this file is too large to render.

tokenizer_config.json
ADDED
@@ -0,0 +1,44 @@
{
  "add_bos_token": true,
  "add_eos_token": false,
  "add_prefix_space": null,
  "added_tokens_decoder": {
    "0": {
      "content": "<unk>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "1": {
      "content": "<s>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "2": {
      "content": "</s>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    }
  },
  "additional_special_tokens": [],
  "bos_token": "<s>",
  "clean_up_tokenization_spaces": false,
  "eos_token": "</s>",
  "extra_special_tokens": {},
  "legacy": false,
  "model_max_length": 1000000000000000019884624838656,
  "pad_token": null,
  "sp_model_kwargs": {},
  "spaces_between_special_tokens": false,
  "tokenizer_class": "LlamaTokenizer",
  "unk_token": "<unk>",
  "use_default_system_prompt": false
}

train.sh
ADDED
@@ -0,0 +1,130 @@
#!/usr/bin/bash

export HF_HOME="/root/workspace/huggingface_cache"

export HF_ENDPOINT=https://hf-mirror.com
# export HF_HOME="../../autodl-fs/hf_cache"

params=""
if [ $# -ne 0 ]; then
    params="$*"
fi

# use envs as local params for convenience
# e.g.
# NNODE=1 NGPU=8 LOG_RANK=0 ./train.sh
NNODE=${NNODE:-"1"}
NGPU=${NGPU:-"4"}
DEVICES=${DEVICES:-"0,1,2,3"}

LOG_RANK=${LOG_RANK:-0}

if [[ -z "${MASTER_ADDR}" ]]; then
    export MASTER_ADDR="localhost"
fi
if [[ -z "${MASTER_PORT}" ]]; then
    export MASTER_PORT="0"
fi

: '
Usage:

bash train.sh -h

Training a 340M model:

NNODE=1 NGPU=8 LOG_RANK=0 bash train.sh \
  --job.config_file flame/models/fla.toml \
  --job.dump_folder exp/transformer-340M-10B/batch32.seqlen2048.warmup1024.update1.steps20480.lr3e-4 \
  --model.config configs/transformer_340M.json \
  --model.tokenizer_path fla-hub/transformer-1.3B-100B \
  --optimizer.name AdamW \
  --optimizer.eps 1e-15 \
  --optimizer.lr 3e-4 \
  --lr_scheduler.warmup_steps 1024 \
  --lr_scheduler.lr_min 0.1 \
  --lr_scheduler.decay_type cosine \
  --training.batch_size 32 \
  --training.seq_len 2048 \
  --training.gradient_accumulation_steps 1 \
  --training.steps 20480 \
  --training.max_norm 1.0 \
  --training.skip_nan_inf \
  --training.dataset HuggingFaceFW/fineweb-edu \
  --training.dataset_name default \
  --training.dataset_split train \
  --training.streaming \
  --training.num_workers 32 \
  --training.prefetch_factor 2 \
  --training.seed 42 \
  --training.compile \
  --training.tensor_parallel_degree 1 \
  --training.disable_loss_parallel \
  --checkpoint.interval 2048 \
  --checkpoint.load_step -1 \
  --metrics.log_freq 1
'

echo "Launching training..."

set -x
path=$(grep -oP '(?<=--job.dump_folder )[^ ]+' <<< "$params")
steps=$(grep -oP '(?<=--training.steps )[^ ]+' <<< "$params")
config=$(grep -oP '(?<=--model.config )[^ ]+' <<< "$params")
tokenizer=$(grep -oP '(?<=--model.tokenizer_path )[^ ]+' <<< "$params")
model=$(
    python -c "import fla, sys; from transformers import AutoConfig; print(AutoConfig.from_pretrained(sys.argv[1]).to_json_string())" "$config" | jq -r '.model_type'
)

mkdir -p $path
cp * $path
cp -r configs $path
cp -r flame $path
cp -r 3rdparty/flash-linear-attention/fla $path
cp -r 3rdparty/torchtitan/torchtitan $path

# for offline systems
# export TRANSFORMERS_OFFLINE=1
# export HF_DATASETS_OFFLINE=1
# export HF_HUB_OFFLINE=1
if [ "$date" == "" ]; then
    date=$(date +%Y%m%d%H%M)
fi
RUN_NAME="$model-$(basename $path)"
RUN_ID="$RUN_NAME-$date"

export WANDB_RESUME=allow
if [[ -z "${WANDB_PROJECT}" ]]; then
    export WANDB_PROJECT="fla"
fi
if [[ -z "${WANDB_NAME}" ]]; then
    export WANDB_NAME="$RUN_NAME"
fi
if [[ -z "${WANDB_RUN_ID}" ]]; then
    export WANDB_RUN_ID="$RUN_ID"
fi

# systemd-run --scope --user -p MemoryHigh=80G \
CUDA_VISIBLE_DEVICES=${DEVICES} \
PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \
torchrun --nnodes=${NNODE} \
    --nproc_per_node=${NGPU} \
    --rdzv_backend c10d \
    --rdzv_endpoint "${MASTER_ADDR}:${MASTER_PORT}" \
    --local-ranks-filter ${LOG_RANK} \
    --role rank \
    --tee 3 \
    --log-dir $path/logs \
    -m flame.train \
    $params

echo "TRAINING DONE!"
echo "Converting the DCP checkpoints to HF format..."

python -m flame.utils.convert_dcp_to_hf \
    --path $path \
    --step $steps \
    --config $config \
    --tokenizer $tokenizer

echo "RUNNING DONE!"

train_restart.sh
ADDED
@@ -0,0 +1,130 @@
#!/usr/bin/bash

export HF_HOME="/root/workspace/huggingface_cache"

export HF_ENDPOINT=https://hf-mirror.com
# export HF_HOME="../../autodl-fs/hf_cache"

params=""
if [ $# -ne 0 ]; then
    params="$*"
fi

# use envs as local params for convenience
# e.g.
# NNODE=1 NGPU=8 LOG_RANK=0 ./train.sh
NNODE=${NNODE:-"1"}
NGPU=${NGPU:-"3"}
DEVICES=${DEVICES:-"0,1,2"}

LOG_RANK=${LOG_RANK:-0}

if [[ -z "${MASTER_ADDR}" ]]; then
    export MASTER_ADDR="localhost"
fi
if [[ -z "${MASTER_PORT}" ]]; then
    export MASTER_PORT="0"
fi

: '
Usage:

bash train.sh -h

Training a 340M model:

NNODE=1 NGPU=8 LOG_RANK=0 bash train.sh \
  --job.config_file flame/models/fla.toml \
  --job.dump_folder exp/transformer-340M-10B/batch32.seqlen2048.warmup1024.update1.steps20480.lr3e-4 \
  --model.config configs/transformer_340M.json \
  --model.tokenizer_path fla-hub/transformer-1.3B-100B \
  --optimizer.name AdamW \
  --optimizer.eps 1e-15 \
  --optimizer.lr 3e-4 \
  --lr_scheduler.warmup_steps 1024 \
  --lr_scheduler.lr_min 0.1 \
  --lr_scheduler.decay_type cosine \
  --training.batch_size 32 \
  --training.seq_len 2048 \
  --training.gradient_accumulation_steps 1 \
  --training.steps 20480 \
  --training.max_norm 1.0 \
  --training.skip_nan_inf \
  --training.dataset HuggingFaceFW/fineweb-edu \
  --training.dataset_name default \
  --training.dataset_split train \
  --training.streaming \
  --training.num_workers 32 \
  --training.prefetch_factor 2 \
  --training.seed 42 \
  --training.compile \
  --training.tensor_parallel_degree 1 \
  --training.disable_loss_parallel \
  --checkpoint.interval 2048 \
  --checkpoint.load_step -1 \
  --metrics.log_freq 1
'

echo "Launching training..."

set -x
path=$(grep -oP '(?<=--job.dump_folder )[^ ]+' <<< "$params")
steps=$(grep -oP '(?<=--training.steps )[^ ]+' <<< "$params")
config=$(grep -oP '(?<=--model.config )[^ ]+' <<< "$params")
tokenizer=$(grep -oP '(?<=--model.tokenizer_path )[^ ]+' <<< "$params")
model=$(
    python -c "import fla, sys; from transformers import AutoConfig; print(AutoConfig.from_pretrained(sys.argv[1]).to_json_string())" "$config" | jq -r '.model_type'
)

mkdir -p $path
cp * $path
cp -r configs $path
cp -r flame $path
cp -r 3rdparty/flash-linear-attention/fla $path
cp -r 3rdparty/torchtitan/torchtitan $path

# for offline systems
# export TRANSFORMERS_OFFLINE=1
# export HF_DATASETS_OFFLINE=1
# export HF_HUB_OFFLINE=1
if [ "$date" == "" ]; then
    date=$(date +%Y%m%d%H%M)
fi
RUN_NAME="$model-$(basename $path)"
RUN_ID="$RUN_NAME-$date"

export WANDB_RESUME=allow
if [[ -z "${WANDB_PROJECT}" ]]; then
    export WANDB_PROJECT="fla"
fi
if [[ -z "${WANDB_NAME}" ]]; then
    export WANDB_NAME="$RUN_NAME"
fi
if [[ -z "${WANDB_RUN_ID}" ]]; then
    export WANDB_RUN_ID="$RUN_ID"
fi

# systemd-run --scope --user -p MemoryHigh=80G \
CUDA_VISIBLE_DEVICES=${DEVICES} \
PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \
torchrun --nnodes=${NNODE} \
    --nproc_per_node=${NGPU} \
    --rdzv_backend c10d \
    --rdzv_endpoint "${MASTER_ADDR}:${MASTER_PORT}" \
    --local-ranks-filter ${LOG_RANK} \
    --role rank \
    --tee 3 \
    --log-dir $path/logs \
    -m flame.train_restart \
    $params

echo "TRAINING DONE!"
echo "Converting the DCP checkpoints to HF format..."

python -m flame.utils.convert_dcp_to_hf \
    --path $path \
    --step $steps \
    --config $config \
    --tokenizer $tokenizer

echo "RUNNING DONE!"