<div align="center">

# 🔥 Flame: Flash Linear Attention Made Easy

</div>

Welcome to 🔥 `flame`, a minimal framework built on `torchtitan` for training Flash Linear Attention (FLA) models (and, more broadly, arbitrary autoregressive language models) with blazing efficiency.

**Feature Highlights:**

- Minimal, easy-to-use, extensible training framework
- Seamless integration with `fla` and `transformers`
- Zero-cost data preprocessing: online tokenization, dataset shuffling, and support for multiple datasets
- 4D parallelism (coming soon)
## Setup

To get started, clone the `flame` repository and install the required dependencies:

```bash
git clone https://github.com/fla-org/flame.git
cd flame
pip install .
```
`flame` keeps its dependencies minimal, including only `fla` and `torchtitan` as submodules.
After installation, initialize and update the submodules:

```sh
git submodule update --init --recursive
```
## Dataset Preparation

To download the dataset to your local disk, create a new Python file with the following content and execute it:

```py
from datasets import load_dataset

# load fineweb-edu with parallel processing
dataset = load_dataset("HuggingFaceFW/fineweb-edu", name="default", num_proc=64, cache_dir="/your/cache/path")

# or load a subset with roughly 100B tokens, suitable for small- or medium-sized experiments
dataset = load_dataset("HuggingFaceFW/fineweb-edu", name="sample-100BT", num_proc=64, cache_dir="/your/cache/path")
```
## Training Recipes

Here's an example of training a 340M FLA Transformer model with a LLaMA-like architecture from scratch on a 100BT subset of the Fineweb-edu corpus in streaming mode.

> [!WARNING]
> If the dataset has not been downloaded beforehand, streaming mode will fetch it from the remote server on the fly, which can be highly unstable during training due to network issues.
> For stable training, make sure the dataset is downloaded locally first (see [**Dataset Preparation**](#dataset-preparation)); rely on on-the-fly streaming only when quickly testing a new corpus.
```sh
bash train.sh \
  --job.config_file flame/models/fla.toml \
  --job.dump_folder exp/transformer-340M-4K-10B/batch1.seqlen65536.context4096.warmup1024.update1.steps20480.lr3e-4.cosine \
  --model.config configs/transformer_340M.json \
  --model.tokenizer_path fla-hub/transformer-1.3B-100B \
  --optimizer.name AdamW \
  --optimizer.eps 1e-15 \
  --optimizer.lr 3e-4 \
  --lr_scheduler.warmup_steps 1024 \
  --lr_scheduler.lr_min 0.1 \
  --lr_scheduler.decay_type cosine \
  --training.batch_size 1 \
  --training.seq_len 65536 \
  --training.context_len 4096 \
  --training.varlen \
  --training.gradient_accumulation_steps 1 \
  --training.steps 20480 \
  --training.max_norm 1.0 \
  --training.skip_nan_inf \
  --training.dataset HuggingFaceFW/fineweb-edu \
  --training.dataset_name sample-100BT \
  --training.dataset_split train \
  --training.streaming \
  --training.num_workers 32 \
  --training.prefetch_factor 2 \
  --training.seed 42 \
  --training.compile \
  --checkpoint.interval 2048 \
  --checkpoint.load_step -1 \
  --checkpoint.keep_latest_k 2 \
  --metrics.log_freq 1
```
You can specify the number of GPUs via the environment variable `NGPU`, which defaults to 8.
**For single-GPU debugging, set `NGPU=1`**, as in the sketch below.
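For example, a quick single-GPU smoke test of the recipe above might look like the following (a sketch only: the flag list is shortened, and we assume options not given on the command line fall back to the values in `flame/models/fla.toml`):

```sh
NGPU=1 bash train.sh \
  --job.config_file flame/models/fla.toml \
  --job.dump_folder exp/debug \
  --model.config configs/transformer_340M.json \
  --model.tokenizer_path fla-hub/transformer-1.3B-100B \
  --training.batch_size 1 \
  --training.seq_len 4096 \
  --training.context_len 4096 \
  --training.steps 32 \
  --training.dataset HuggingFaceFW/fineweb-edu \
  --training.dataset_name sample-100BT \
  --training.dataset_split train \
  --training.streaming \
  --metrics.log_freq 1
```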
We provide several [config files](https://github.com/fla-org/flame/tree/main/configs) for different models.
By default, the learning rate is set to 3e-4 with a cosine scheduler. Other schedulers, such as warmup-stable-decay (`wsd`), are also supported.
**Key parameters:**

- `--lr_scheduler.decay_ratio`: The proportion of steps allocated to the decay phase. The learning rate remains stable after the warmup period and only starts decaying during the final `decay_ratio` portion of the total training steps; this is known as the Warmup-Stable-Decay (WSD) schedule.
- `--lr_scheduler.warmup_steps`: The number of steps for the learning-rate warmup phase.
- `--training.steps`: Total number of training steps.
- `--training.batch_size`: Batch size per device; must be 1 if `--training.varlen` is set.
- `--training.seq_len`: The length of each input sequence in the batch, which is concatenated from multiple samples.
- `--training.context_len`: The maximum allowed length of a single sample. In non-varlen mode, this is equivalent to `seq_len`.
- `--training.varlen`: Whether to conduct variable-length sequence training.
- `--training.gradient_accumulation_steps`: Number of gradient accumulation steps.
> [!WARNING]
> The total number of sequences processed per optimizer step, referred to as `global_batch_size`, is calculated as `batch_size × gradient_accumulation_steps × num_gpus`.
> Each step therefore processes `global_batch_size × seq_len` tokens.
> Monitor `global_batch_size`, `warmup_steps`, and `steps` carefully when modifying any of these hyperparameters!
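As a concrete check, here is the arithmetic for the example command above (with the default `NGPU=8`):

```sh
# global_batch_size = batch_size * gradient_accumulation_steps * num_gpus
echo $(( 1 * 1 * 8 ))          # 8 sequences per optimizer step
# tokens per step = global_batch_size * seq_len
echo $(( 8 * 65536 ))          # 524288 (~0.5M) tokens per step
# total training tokens = steps * tokens per step
echo $(( 20480 * 8 * 65536 ))  # 10737418240, i.e. the ~10B in the dump folder name
```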
For a detailed explanation of all parameters, run:

```sh
bash train.sh -h
```
<details>
<summary>Usage</summary>

```
options:
  -h, --help            show this help message and exit
  --job.config_file JOB.CONFIG_FILE
                        Job config file
  --job.dump_folder JOB.DUMP_FOLDER
                        Folder to dump job outputs
  --job.description JOB.DESCRIPTION
                        Description of the job
  --job.use_for_integration_test
                        Add this config to the integration test suite
  --job.print_args      Print the args to terminal
  --model.config MODEL.CONFIG
                        Path to the model config
  --model.norm_type MODEL.NORM_TYPE
                        Type of layer normalization to use [layernorm,
                        np_layernorm, rmsnorm, fused_rmsnorm]
  --model.tokenizer_path MODEL.TOKENIZER_PATH
                        Tokenizer path
  --profiling.enable_profiling
                        Whether to enable pytorch profiler
  --profiling.save_traces_folder PROFILING.SAVE_TRACES_FOLDER
                        Trace files location
  --profiling.profile_freq PROFILING.PROFILE_FREQ
                        How often to collect profiler traces, in iterations
  --profiling.enable_memory_snapshot
                        Whether to dump memory snapshot
  --profiling.save_memory_snapshot_folder PROFILING.SAVE_MEMORY_SNAPSHOT_FOLDER
                        Memory snapshot files location
  --optimizer.name OPTIMIZER.NAME
                        Optimizer to use
  --optimizer.eps OPTIMIZER.EPS
                        Epsilon value for the optimizer.
  --optimizer.fused     Whether the fused implementation (CUDA only) is used.
  --optimizer.scheduler {wsd,cosine,linear}
                        Scheduler to use. Currently supported: wsd, cosine,
                        and linear.
  --optimizer.lr OPTIMIZER.LR
                        Learning rate to use
  --optimizer.min_lr_ratio OPTIMIZER.MIN_LR_RATIO
                        Min lr ratio for lr scheduler
  --optimizer.early_step_in_backward
                        Whether to apply optimizer in the backward. Caution:
                        optimizer_in_backward is not compatible with gradient
                        clipping; users should not call
                        register_post_accumulate_grad_hook after the
                        optimizer is built.
  --training.batch_size TRAINING.BATCH_SIZE
                        Batch size
  --training.seq_len TRAINING.SEQ_LEN
                        Sequence length
  --training.context_len TRAINING.CONTEXT_LEN
                        Max length allowed for each sequence
  --training.varlen     Whether to take sequences of variable length as input
  --training.warmup_steps TRAINING.WARMUP_STEPS
                        Steps for lr scheduler warmup, normally 1/5 of
                        --training.steps
  --training.gradient_accumulation_steps TRAINING.GRADIENT_ACCUMULATION_STEPS
                        Number of steps to accumulate gradients before
                        updating parameters
  --training.steps TRAINING.STEPS
                        How many train steps to run
  --training.max_norm TRAINING.MAX_NORM
                        Max norm for gradient clipping
  --training.skip_nan_inf
                        Skip batch updates when NaN or INF gradients are
                        encountered during training
  --training.dataset TRAINING.DATASET
                        Dataset to use, with comma separated values
  --training.dataset_name TRAINING.DATASET_NAME
                        The name of the dataset config, with comma separated
                        values if provided
  --training.dataset_split TRAINING.DATASET_SPLIT
                        Dataset split to use, with comma separated values if
                        provided
  --training.data_dir TRAINING.DATA_DIR
                        Data dirs to use, with comma separated values if
                        provided
  --training.data_files TRAINING.DATA_FILES
                        Data files to use, with comma separated values if
                        provided
  --training.data_probs TRAINING.DATA_PROBS
                        Data sampling probabilities, with comma separated
                        values if provided
  --training.streaming  Whether to load dataset in streaming mode, used for
                        huge datasets
  --training.num_workers TRAINING.NUM_WORKERS
                        Number of subprocesses to use for data loading. 0
                        means that the data will be loaded in the main
                        process.
  --training.prefetch_factor TRAINING.PREFETCH_FACTOR
                        Number of batches loaded in advance by each worker. 2
                        means there will be a total of 2 * num_workers
                        batches prefetched across all workers.
  --training.data_parallel_replicate_degree TRAINING.DATA_PARALLEL_REPLICATE_DEGREE
                        The `data_parallel_replicate_degree` argument
                        specifies the degree of data parallelism for weight
                        replication. When this value is greater than 1,
                        weights will be replicated across
                        `data_parallel_replicate_degree` ranks. If
                        `data_parallel_shard_degree` is also greater than 1,
                        the parallelism method used is HSDP (Hybrid Sharded
                        Data Parallelism). Otherwise, the parallelism method
                        used is DDP (Distributed Data Parallelism). 1 means
                        disabled.
  --training.data_parallel_shard_degree TRAINING.DATA_PARALLEL_SHARD_DEGREE
                        The `data_parallel_shard_degree` argument specifies
                        the degree of data parallelism for weight sharding.
                        When this value is greater than 1, weights will be
                        sharded across `data_parallel_shard_degree` ranks. If
                        `data_parallel_replicate_degree` is also greater than
                        1, the parallelism method used is HSDP (Hybrid
                        Sharded Data Parallelism). Otherwise, the parallelism
                        method used is FSDP (Fully Sharded Data Parallelism).
                        -1 means leftover ranks will be used (after
                        DP_REPLICATE/SP/PP). Note that only
                        `data_parallel_shard_degree` can be negative. 1 means
                        disabled.
  --training.enable_cpu_offload
                        Whether to apply CPU offloading of parameters,
                        gradients, and optimizer states in FSDP
  --training.tensor_parallel_degree TRAINING.TENSOR_PARALLEL_DEGREE
                        Tensor Parallelism degree. 1 means disabled.
  --training.disable_loss_parallel
                        Whether to apply loss parallel when sequence parallel
                        is enabled
  --training.mixed_precision_param {bfloat16,float32}
                        torch dtype to use for parameters when applying mixed
                        precision via FSDP. This feature only takes effect
                        when data_parallel_shard_degree > 1
  --training.mixed_precision_reduce {float32}
                        torch dtype to use for reductions when applying mixed
                        precision via FSDP. This feature only takes effect
                        when data_parallel_shard_degree > 1
  --training.compile    Whether to compile the model
  --training.gc_freq TRAINING.GC_FREQ
                        Python garbage control scheduling interval, in steps
  --training.seed TRAINING.SEED
                        Choose the base RNG seed used for training
  --training.deterministic
                        Use deterministic algorithms wherever possible; may
                        be slower
  --metrics.log_freq METRICS.LOG_FREQ
                        How often to log metrics to TensorBoard, in
                        iterations
  --metrics.enable_tensorboard
                        Whether to log metrics to TensorBoard
  --metrics.disable_color_printing
                        Whether to disable color printing in logs
  --metrics.save_tb_folder METRICS.SAVE_TB_FOLDER
                        Folder to dump TensorBoard states
  --metrics.rank_0_only
                        Whether to save TensorBoard metrics only for rank 0
                        or for all ranks. When pipeline_parallel_degree is >
                        1, this option uses the 0th rank of the last stage
                        pipeline group, which is the only stage that computes
                        loss metrics.
  --metrics.enable_wandb
                        Whether to log metrics to Weights & Biases
  --experimental.enable_async_tensor_parallel
                        Whether to apply async tensor parallel (currently
                        only effective when compile is enabled)
  --experimental.pipeline_parallel_degree EXPERIMENTAL.PIPELINE_PARALLEL_DEGREE
                        Pipeline Parallelism degree, or number of ranks. 1
                        means disabled. If using looped schedules, this still
                        specifies the number of physical ranks, not the
                        number of stages. Stages per rank are inferred from
                        split points, degree, and schedule.
  --experimental.pipeline_parallel_split_points EXPERIMENTAL.PIPELINE_PARALLEL_SPLIT_POINTS [EXPERIMENTAL.PIPELINE_PARALLEL_SPLIT_POINTS ...]
                        Specify comma-separated names of modules to use as
                        the beginning of a split point. e.g.
                        "layers.0,layers.2" will cause the model to be split
                        into 3 stages: the first containing all the layers up
                        to layers.0, the second containing layers.0 and up to
                        layers.2, and the third containing layers.2 and all
                        the remaining layers. Note: fully-automated splitting
                        may be enabled in the future, but currently the split
                        points must be specified manually.
  --experimental.pipeline_parallel_schedule EXPERIMENTAL.PIPELINE_PARALLEL_SCHEDULE
                        Specify the Pipeline Parallel schedule to use. The
                        supported schedules are listed in
                        https://github.com/pytorch/pytorch/blob/de4c2a3b4e89d96334dc678d1c3f2ae51a6630a0/torch/distributed/pipelining/schedules.py#L2161.
                        The schedule must be compatible with the split points
                        and stages_per_rank. Looped schedules (e.g.
                        Interleaved1F1B) require specifying
                        pipeline_parallel_degree = number of ranks, and
                        split_points = number of stages - 1
  --experimental.pipeline_parallel_schedule_csv EXPERIMENTAL.PIPELINE_PARALLEL_SCHEDULE_CSV
                        Specify the path to the pipeline parallel schedule
                        csv file to use. The pipeline_parallel_schedule
                        argument must be either PipelineScheduleSingle,
                        PipelineScheduleMulti, or _PipelineScheduleRuntime.
  --experimental.pipeline_parallel_microbatches EXPERIMENTAL.PIPELINE_PARALLEL_MICROBATCHES
                        How many microbatches to split the global training
                        batch into when using pipeline parallelism. The
                        global training batch size must be evenly divisible
                        by the number of microbatches. The default value will
                        be the number of pipeline stages, if unspecified.
  --experimental.enable_compiled_autograd
                        Enable CompiledAutograd to compile the backward.
  --experimental.context_parallel_degree EXPERIMENTAL.CONTEXT_PARALLEL_DEGREE
                        Context parallelism degree. 1 means disabled.
  --experimental.context_parallel_rotate_method EXPERIMENTAL.CONTEXT_PARALLEL_ROTATE_METHOD
                        The collective to use in context parallel SDPA for kv
                        shards exchange. 'allgather' means to all-gather all
                        kv shards on ranks after the first sub-SDPA
                        computation; 'alltoall' means to all-to-all shuffle
                        the kv shards. The default value is 'allgather'.
  --checkpoint.enable_checkpoint
                        Whether to enable checkpoint
  --checkpoint.folder CHECKPOINT.FOLDER
                        The folder to store the checkpoints. When
                        enable_checkpoint is set to true, checkpoints will be
                        in {--job.dump_folder}/{--checkpoint.folder}.
  --checkpoint.interval_type CHECKPOINT.INTERVAL_TYPE
                        Checkpointing interval unit of measurement ['step',
                        'seconds']
  --checkpoint.interval CHECKPOINT.INTERVAL
                        Checkpointing interval, in steps or seconds depending
                        on --checkpoint.interval_type
  --checkpoint.model_weights_only
                        When model_weights_only=True, only model weights will
                        be saved at the end of training. With this,
                        checkpoints can be loaded using `torch.load(...,
                        weights_only=True)` after conversion. When
                        model_weights_only=False, the full checkpoint will be
                        saved. A full checkpoint includes model, optimizer,
                        and train_state, which can be used to resume
                        training. The default value is false.
  --checkpoint.export_dtype {float16,bfloat16,float32}
                        Converts to the specified precision when training
                        completes and model_weights_only=true. Currently
                        supports float32, float16, and bfloat16. The default
                        value is float32.
  --checkpoint.create_seed_checkpoint
                        Initializes the full model without applying
                        parallelisms, and then saves it as a seed checkpoint.
                        Note: requires user to call train.py without
                        specifying any parallelisms, e.g. NGPU=1. Could be
                        implemented as a separate script, but this way shares
                        more code.
  --checkpoint.async_mode CHECKPOINT.ASYNC_MODE
                        Which async checkpoint mode to use. Currently there
                        are 3 different modes. 1. "disabled": synchronized
                        checkpointing will be used. 2. "async":
                        torch.distributed.checkpoint.async_save will be used.
                        3. "async_with_pinned_mem": this option utilizes a
                        dedicated pinned memory space and creates a separate
                        process for faster GPU->CPU transfer performance and
                        eliminating GIL contention. The cost is increased CPU
                        memory usage. If insufficient CPU memory is
                        available, performance may degrade due to memory
                        paging. For most users, "async" should suffice as the
                        performance overhead is typically small (on the order
                        of tens of seconds) compared to checkpointing
                        frequency. This mode can be employed to pursue
                        near-zero checkpointing times (e.g., < 1 second)
                        given appropriate hardware support such as ample CPU
                        memory and fast PCIe. "disabled" is the default mode.
  --checkpoint.keep_latest_k CHECKPOINT.KEEP_LATEST_K
                        Keeps only the latest k checkpoints, purging older
                        ones. If 0, keep all checkpoints. 0 is the default
                        value.
  --checkpoint.load_step CHECKPOINT.LOAD_STEP
                        Load the checkpoint at the specified step. If -1,
                        load the latest checkpoint.
  --float8.enable_float8_linear
                        If true, swaps `torch.nn.Linear` with `Float8Linear`.
                        This feature requires you to install 'torchao', which
                        can be found here: https://github.com/pytorch/ao
  --float8.enable_fsdp_float8_all_gather
                        Whether to enable float8 all-gather in FSDP
  --float8.precompute_float8_dynamic_scale_for_fsdp
                        Whether to precompute float8 scales dynamically for
                        FSDP
  --float8.scaling_type_input {dynamic,delayed}
                        float8 scaling for input, dynamic (default) or
                        delayed
  --float8.scaling_type_weight FLOAT8.SCALING_TYPE_WEIGHT
                        float8 scaling for weight, dynamic (default) or
                        delayed
  --float8.scaling_type_grad_output FLOAT8.SCALING_TYPE_GRAD_OUTPUT
                        float8 scaling for grad_output, dynamic (default) or
                        delayed
  --comm.init_timeout_seconds COMM.INIT_TIMEOUT_SECONDS
                        Timeout for communication operations during
                        initialization and first train step.
  --comm.train_timeout_seconds COMM.TRAIN_TIMEOUT_SECONDS
                        Timeout for communication operations after the first
                        train step -- usually a tighter bound than during
                        initialization.
  --comm.trace_buf_size COMM.TRACE_BUF_SIZE
                        Flight recorder ring buffer size; >0 means recording
                        by default, 0 means disabled
  --memory_estimation.enabled
                        Whether to estimate memory usage for FSDP
  --memory_estimation.disable_fake_mode
                        Whether to estimate memory under FakeTensorMode
```

</details>
### Training with `torch.compile`

Introduced in `torch` 2.0, `torch.compile` can seamlessly accelerate training.
In `flame`, you can enable `torch.compile` simply by adding the `--training.compile` flag to your training script.

However, `fla` has integrated numerous fused kernels for acceleration, which may potentially conflict with `torch.compile`.
We are actively working on resolving these issues to make compilation fully transparent to users.
In the meantime, please ensure you are using the latest dependencies; specifically, **we recommend `torch>=2.6` and `triton>=3.0`**.
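You can quickly verify the installed versions (a simple sanity check, assuming both packages are importable in your environment):

```sh
python -c "import torch, triton; print(torch.__version__, triton.__version__)"
```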
### Training with multiple datasets

If you wish to train a model with all-round capabilities (e.g., code, math, and multilingual abilities), it's necessary to train on multiple datasets.
`flame` makes it easy to train on multiple datasets.
For example, you can specify the following arguments to train on 6 datasets with different proportions:

```sh
--training.dataset HuggingFaceFW/fineweb-edu,opencsg/Fineweb-Edu-Chinese-V2.1,OpenCoder-LLM/opc-fineweb-code-corpus,math-ai/AutoMathText,EleutherAI/proof-pile-2,OpenCoder-LLM/opc-fineweb-math-corpus \
--training.data_probs 0.6,0.15,0.15,0.014,0.058,0.028 \
```
### ~Finalizing training~

> [!NOTE]
> As of our latest updates, the training script performs this conversion automatically.

Once training is complete, you may want to convert the distributed checkpoints (DCPs) into the 🤗 format for broader use.
To facilitate this, we provide a straightforward conversion script:

```sh
python -m flame.utils.convert_dcp_to_hf --path <path_to_model> --step <step> --config <path_to_config> --tokenizer <path_to_tokenizer>
```
After this, your model will be in the 🤗 format, ready to be shared or deployed.
You can then easily publish your model using `huggingface_hub` for wider accessibility, as sketched below.
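For instance, a minimal publishing sketch with the `huggingface_hub` CLI (the repo id `your-username/transformer-340M` and the local path are placeholders):

```sh
huggingface-cli login                                                   # authenticate once
huggingface-cli upload your-username/transformer-340M <path_to_model>  # upload the converted model folder
```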
### Continual training

If you wish to build upon a strong pre-trained model (in 🤗 format) and continue training, we also offer a script to convert the 🤗 model back into DCP format.
This allows you to seamlessly resume training with `flame`.

```sh
python -m flame.utils.convert_hf_to_dcp --model <path_to_hf> --checkpoint <path_to_dcp>/checkpoint/step-0
```

Here, `<path_to_dcp>` is the directory where your distributed checkpoints will be stored.
The checkpoint is intentionally saved at `step-0` within the checkpoint folder to ensure it is loadable by `flame` during the initial training step, similar to how a seed checkpoint is handled.

Once the conversion is complete, you can proceed with training using `flame` as usual, continuing from where the pretrained model left off.
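Concretely, a resumed run might look like this (a sketch: `exp/my-model` is a placeholder dump folder, and we assume `--checkpoint.folder` keeps its default value of `checkpoint`, so the converted `step-0` checkpoint is picked up automatically):

```sh
# convert the HF model into a step-0 DCP checkpoint under the dump folder
python -m flame.utils.convert_hf_to_dcp \
  --model <path_to_hf> \
  --checkpoint exp/my-model/checkpoint/step-0
# then launch training with the same dump folder, plus the usual training flags
bash train.sh \
  --job.config_file flame/models/fla.toml \
  --job.dump_folder exp/my-model
```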
## Multi-node training

If you have access to multiple GPU nodes, consider leveraging them for optimal performance.
This process is straightforward and well-documented in the PyTorch [docs](https://pytorch.org/docs/stable/elastic/run.html).

To set up multi-node training:

* Set the environment variables `MASTER_ADDR=<ip>` and `MASTER_PORT=<port>` before running the training script on all nodes, as in the sketch below.
* If you're using a job scheduler like Slurm, it will handle these variables for you. `torchtitan` provides a [Slurm script](https://github.com/pytorch/torchtitan/blob/main/multinode_trainer.slurm) for multi-node training, which you can use as a reference or starting point.
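For example, on each node you might run (a sketch: the address and port are placeholders, and we assume `train.sh` forwards these variables to `torchrun`):

```sh
export MASTER_ADDR=192.168.1.1  # IP of the rank-0 node (placeholder)
export MASTER_PORT=29500        # any free port (placeholder)
NGPU=8 bash train.sh --job.config_file flame/models/fla.toml  # plus the training flags shown above
```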