| <div align="center"> | |
| # 🔥 Flame: Flash Linear Attention Made Easy | |
| </div> | |
| > [!IMPORTANT] | |
| > The `flame` project has been migrated to a new project built on torchtitan. | |
| > Please visit the [new repository](https://github.com/fla-org/flame) for details and updates. | |
| > | |
| > The code here is now **archived as legacy**, and no future updates will be synchronized here. | |
| A minimal framework for training FLA models, whether from scratch or through finetuning. | |
| Built on the robust infrastructure of 🤗, `flame` enables you to train large language models with just a few lines of code: | |
| we use `datasets` for data processing, `transformers` for model definitions, and `accelerate`[^1] for seamless distributed training. | |
| In this README, we will guide you through the process of using `flame` to train GLA models. | |
| ## Setup | |
| To get started, you'll need to install the required packages. | |
| Both `fla` and `flame` have minimal dependencies. | |
| Clone the `fla` repository and install the necessary packages as follows: | |
| ```bash | |
git clone https://github.com/sustcsonglin/flash-linear-attention.git
cd flash-linear-attention
pip install .
| pip install accelerate | |
| ``` | |
| > [!CAUTION] | |
| > The 🤗 `tokenizers` have some [memory leak issues](https://github.com/huggingface/tokenizers/issues/1539) when processing very long documents. | |
| > To address this, please ensure you install `tokenizers>=0.20.4`. | |
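If in doubt, you can pin the version explicitly and verify the installation in one step (a quick sanity check; the version bound matches the caution above):

```bash
pip install 'tokenizers>=0.20.4'
python -c "import fla, tokenizers; print(tokenizers.__version__)"
```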
| ## Preprocessing | |
| Before training, you need to download and pre-tokenize your dataset. | |
| We provide a straightforward script for this. | |
For instance, to tokenize the 10B-token sample of the `fineweb-edu` dataset, run:
| ```bash | |
| python preprocess.py \ | |
| --dataset HuggingFaceFW/fineweb-edu \ | |
| --name sample-10BT \ | |
| --split train \ | |
| --context_length 2048 | |
| ``` | |
| This will cache the processed dataset at `data/HuggingFaceFW/fineweb-edu/sample-10BT/train`. | |
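To confirm that tokenization went as expected, you can peek at the cached output. A minimal sketch, assuming `preprocess.py` saves the cache via `datasets.save_to_disk` and produces an `input_ids` column (adjust the loader and column name if your version differs):

```bash
python - <<'EOF'
from datasets import load_from_disk

ds = load_from_disk('data/HuggingFaceFW/fineweb-edu/sample-10BT/train')
print(ds)                       # row count and column names
print(len(ds[0]['input_ids']))  # should equal --context_length (2048)
EOF
```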
GLA uses a subset of SlimPajama for pretraining [in the paper](https://proceedings.mlr.press/v235/yang24ab.html).
| Given the size of the dataset, the fastest way to download it is using `git lfs` (refer to [this issue](https://huggingface.co/datasets/cerebras/SlimPajama-627B/discussions/2)). | |
| ```bash | |
| git lfs install | |
| git clone https://huggingface.co/datasets/cerebras/SlimPajama-627B --depth 1 | |
| python preprocess.py \ | |
| --dataset SlimPajama-627B \ | |
| --split train \ | |
| --context_length 2048 | |
| ``` | |
| ## Training from scratch | |
To train a 340M model from scratch, execute the following command:
| ```bash | |
| bash train.sh \ | |
| type=gla \ | |
| lr=3e-4 \ | |
| scheduler=cosine_with_min_lr \ | |
| batch=32 \ | |
| update=1 \ | |
| warmup=1024 \ | |
| steps=20480 \ | |
| context=2048 \ | |
| gpus=8 \ | |
| nodes=1 \ | |
| path=exp/gla-340M-10B \ | |
| project=fla \ | |
| model=configs/gla_340M.json \ | |
| data=HuggingFaceFW/fineweb-edu \ | |
| name=sample-10BT \ | |
| cache=data/HuggingFaceFW/fineweb-edu/sample-10BT/train | |
| ``` | |
| Key parameters: | |
| Argument  | Description                   | Default              |
| | :-------- | :---------------------------- | -------------------- | | |
| | lr | `learning_rate` | `3e-4` | | |
| | scheduler | `lr_scheduler_type` | `cosine_with_min_lr` | | |
| | batch | `batch_size` | `32` | | |
| | update | `gradient_accumulation_steps` | `1` | | |
| | context | `context_length` | `2048` | | |
| | gpus | `num_gpus_per_node` | `8` | | |
| | nodes | `num_nodes` | `1` | | |
| | warmup | `warmup_steps` | `1024` | | |
| | steps | `max_steps` | `20480` | | |
The learning rate defaults to `3e-4`, paired with a cosine scheduler.
| Other scheduler types like WSD (`warmup_stable_decay`)[^2] are also supported. | |
| The total number of tokens processed per batch, referred to as `global_batch_size`, is calculated as | |
| `batch_size × gradient_accumulation_steps × context_length × num_gpus_per_node × num_nodes`. | |
For instance, in the 340M example above, the `global_batch_size` comes to $32 \times 1 \times 2048 \times 8 \times 1 = 524{,}288$ (0.5M tokens).
| The `warmup_steps` parameter indicates the number of steps for the learning rate warmup phase, while `max_steps` represents the maximum number of training steps. | |
| Each step processes `global_batch_size` tokens. | |
Consequently, `warmup=1024` and `steps=20480` correspond to processing roughly 0.5B and 10B tokens, respectively.
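These numbers are easy to double-check with plain arithmetic (no `flame` code involved):

```bash
python - <<'EOF'
batch, update, context, gpus, nodes = 32, 1, 2048, 8, 1
global_batch_size = batch * update * context * gpus * nodes
print(global_batch_size)                     # 524288 (~0.5M tokens per step)
print(1024 * global_batch_size / 1e9, 'B')   # warmup: ~0.54B tokens
print(20480 * global_batch_size / 1e9, 'B')  # total: ~10.74B tokens
EOF
```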
:warning: Monitor the values of `global_batch_size`, `warmup_steps`, and `max_steps` carefully when modifying any of these hyperparameters!
| `flame` also supports resuming interrupted training by specifying the checkpoint path. | |
| Simply use the following command: | |
| ```bash | |
| bash train.sh \ | |
| type=gla \ | |
| lr=3e-4 \ | |
| steps=20480 \ | |
| batch=32 \ | |
| update=1 \ | |
| warmup=1024 \ | |
| context=2048 \ | |
| gpus=8 \ | |
| nodes=1 \ | |
| path=exp/gla-340M-10B \ | |
| project=fla \ | |
| model=configs/gla_340M.json \ | |
| data=HuggingFaceFW/fineweb-edu \ | |
| name=sample-10BT \ | |
| cache=data/HuggingFaceFW/fineweb-edu/sample-10BT/train \ | |
| checkpoint=exp/gla-340M-10B/checkpoint-8192 | |
| ``` | |
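The checkpoint directories follow the `checkpoint-<step>` naming convention (as the path above suggests), so listing them is an easy way to pick a resume point:

```bash
ls -d exp/gla-340M-10B/checkpoint-*
```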
| You can also use `wandb` to monitor your training process effectively. | |
|  | |
| ## Continual Pretraining | |
| `flame` supports continual training from a pretrained checkpoint. | |
Below, we provide an example of how to finetune Mistral-7B into a GLA model.
| You can follow similar steps to reproduce the results in the [GSA paper](https://arxiv.org/abs/2409.07146): | |
1. Initialize a brand-new GLA-7B model from the config and copy the matched pretrained weights from Mistral-7B (a quick load check for the result appears after these steps):
| ```bash | |
| cd ../utils | |
| python convert_from_llama.py \ | |
| --model mistralai/Mistral-7B-v0.1 \ | |
| --config ../training/configs/gla_7B.json \ | |
| --output ../training/converted/gla-7B | |
| cd - | |
| ``` | |
| 2. Directly launch training from the converted checkpoint: | |
| ```bash | |
| bash train.sh \ | |
| type=gla \ | |
| lr=3e-5 \ | |
| steps=10240 \ | |
| batch=4 \ | |
| update=8 \ | |
| warmup=512 \ | |
| context=2048 \ | |
| path=exp/gla-7B-20B \ | |
| project=fla \ | |
| model=converted/gla-7B \ | |
| data=SlimPajama-627B \ | |
| cache=data/SlimPajama-627B/train | |
| ``` | |
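Before actually launching step 2, it can be worth confirming that the converted checkpoint from step 1 loads. A minimal sketch, assuming `import fla` registers the GLA classes with the 🤗 auto classes (as the `fla` README describes):

```bash
python - <<'EOF'
import fla  # registers GLA model classes with transformers' Auto* classes
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained('converted/gla-7B')
print(f'{sum(p.numel() for p in model.parameters()) / 1e9:.2f}B parameters')
EOF
```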
| Please be aware that finetuning on a single node may not be the most efficient approach. | |
If available, consider training across multiple nodes for better throughput.
| You can find guidance on how to launch a multi-node job in the [accelerate tutorial](https://github.com/huggingface/accelerate/blob/main/examples/slurm/submit_multinode.sh). | |
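For reference, the bare `accelerate launch` flags involved look roughly like the sketch below. The entry script name `train.py` is a placeholder (substitute whatever `train.sh` ultimately invokes), and the rank and address variables would come from your scheduler:

```bash
# On each of 2 nodes with 8 GPUs apiece (16 processes in total):
accelerate launch \
    --num_machines 2 \
    --num_processes 16 \
    --machine_rank $NODE_RANK \
    --main_process_ip $MASTER_ADDR \
    --main_process_port 29500 \
    train.py ...
```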
| [^1]: The `accelerate` library supports various distributed frameworks, like `deepspeed` and `megatron` for large-scale training. We use `deepspeed` in our case. | |
| [^2]: https://arxiv.org/abs/2404.06395 | |