Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- LICENSE +21 -0
- README.md +471 -0
- config.json +34 -0
- configs/delta_net_1B.json +29 -0
- configs/delta_net_340M.json +27 -0
- configs/gla_340M.json +24 -0
- configs/gla_7B.json +25 -0
- configs/gsa_340M.json +29 -0
- configs/mtp_transformer_340M.json +19 -0
- configs/top_transformer_1B.json +24 -0
- configs/top_transformer_340M.json +20 -0
- configs/transformer_120M.json +18 -0
- configs/transformer_7B.json +21 -0
- fla/utils.py +223 -0
- generation_config.json +6 -0
- logs/none_ro0qpaac/attempt_0/0/stderr.log +0 -0
- logs/none_ro0qpaac/attempt_0/1/stderr.log +0 -0
- logs/none_ro0qpaac/attempt_0/3/stderr.log +0 -0
- logs/none_ro0qpaac/attempt_0/5/stderr.log +0 -0
- logs/none_ro0qpaac/attempt_0/6/stderr.log +0 -0
- logs/none_ro0qpaac/attempt_0/7/stderr.log +0 -0
- pyproject.toml +43 -0
- setup.py +51 -0
- special_tokens_map.json +23 -0
- tokenizer.json +0 -0
- tokenizer_config.json +44 -0
- torchtitan/components/__pycache__/checkpoint.cpython-312.pyc +0 -0
- torchtitan/components/__pycache__/dataloader.cpython-312.pyc +0 -0
- torchtitan/components/__pycache__/float8.cpython-312.pyc +0 -0
- torchtitan/components/__pycache__/lr_scheduler.cpython-312.pyc +0 -0
- torchtitan/components/__pycache__/metrics.cpython-312.pyc +0 -0
- torchtitan/components/__pycache__/optimizer.cpython-312.pyc +0 -0
- torchtitan/components/__pycache__/tokenizer.cpython-312.pyc +0 -0
- torchtitan/distributed/utils.py +311 -0
- torchtitan/experiments/deepseek_v3/README.md +40 -0
- torchtitan/experiments/deepseek_v3/indices.py +195 -0
- torchtitan/experiments/deepseek_v3/symm_mem_recipes/__init__.py +11 -0
- torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_on_device_all_to_all_v.py +260 -0
- torchtitan/experiments/deepseek_v3/train.py +142 -0
- torchtitan/experiments/flux/README.md +23 -0
- torchtitan/experiments/flux/dataset/flux_dataset.py +267 -0
- torchtitan/experiments/flux/dataset/tokenizer.py +64 -0
- torchtitan/experiments/flux/model/hf_embedder.py +40 -0
- torchtitan/experiments/flux/model/layers.py +286 -0
- torchtitan/experiments/flux/tests/test_flux_dataloader.py +103 -0
- torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/fast_debug_ao.py +299 -0
- torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/mg_grouped_gemm.py +1304 -0
- torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/reference_utils.py +126 -0
- torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/tma_autotuning.py +240 -0
- torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/unit_test_forwards.py +82 -0
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2023-2025 Songlin Yang, Yu Zhang
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
ADDED
|
@@ -0,0 +1,471 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<div align="center">
|
| 2 |
+
|
| 3 |
+
# 🔥 Flame: Flash Linear Attention Made Easy
|
| 4 |
+
|
| 5 |
+
</div>
|
| 6 |
+
|
| 7 |
+
Welcome to 🔥 `flame`, a minimal and efficient framework built on `torchtitan` for training Flash Linear Attention (FLA) models (and more broadly, arbitrary autoregressive language models) with blazing efficiency.
|
| 8 |
+
|
| 9 |
+
**Feature Highlights:**
|
| 10 |
+
|
| 11 |
+
- 🚀 Minimal, easy-to-use, extensible training framework
|
| 12 |
+
- 🤗 Seamless integration with `fla` and `transformers`
|
| 13 |
+
- 🔄 Zero-cost data preprocessing: online tokenization, dataset shuffling, and multiple datasets support
|
| 14 |
+
- 🔮 4D parallelism (coming soon)
|
| 15 |
+
|
| 16 |
+
## Setup
|
| 17 |
+
|
| 18 |
+
To get started, clone the `flame` repository and install the required dependencies:
|
| 19 |
+
|
| 20 |
+
```bash
|
| 21 |
+
git clone https://github.com/fla-org/flame.git
|
| 22 |
+
cd flame
|
| 23 |
+
pip install .
|
| 24 |
+
```
|
| 25 |
+
|
| 26 |
+
`flame` manages minimal dependencies, only including `fla` and `torchtitan` as submodules.
|
| 27 |
+
After installation, initialize and update the submodules:
|
| 28 |
+
```sh
|
| 29 |
+
git submodule update --init --recursive
|
| 30 |
+
```
|
| 31 |
+
|
| 32 |
+
## Dataset Preparation
|
| 33 |
+
To download the dataset to your local disk, create a new Python file with the following content and execute it:
|
| 34 |
+
|
| 35 |
+
```py
|
| 36 |
+
from datasets import load_dataset
|
| 37 |
+
|
| 38 |
+
# load fineweb-edu with parallel processing
|
| 39 |
+
dataset = load_dataset("HuggingFaceFW/fineweb-edu", name="default", num_proc=64, cache_dir="/your/cache/path")
|
| 40 |
+
|
| 41 |
+
# or load a subset with roughly 100B tokens, suitable for small- or medium-sized experiments
|
| 42 |
+
dataset = load_dataset("HuggingFaceFW/fineweb-edu", name="sample-100BT", num_proc=64, cache_dir="/your/cache/path")
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
## Training Recipes
|
| 46 |
+
|
| 47 |
+
Here's an example of training a 340M FLA Transformer model with a LLaMA-like architecture from scratch on a 100BT subset of the Fineweb-edu corpus in streaming mode.
|
| 48 |
+
|
| 49 |
+
> [!WARNING]
|
| 50 |
+
> If the dataset is not downloaded beforehand, the streaming mode will attempt to fetch it from a remote server and download it on-the-fly, which can be highly unstable during training due to network issues.
|
| 51 |
+
> For stable training, ensure the dataset is downloaded locally (see [**Dataset Preparation**](#dataset-preparation)). Otherwise, we assume you are only testing the new corpus.
|
| 52 |
+
|
| 53 |
+
```sh
|
| 54 |
+
bash train.sh \
|
| 55 |
+
--job.config_file flame/models/fla.toml \
|
| 56 |
+
--job.dump_folder exp/transformer-340M-4K-10B/batch1.seqlen65536.context4096.warmup1024.update1.steps20480.lr3e-4.cosine \
|
| 57 |
+
--model.config configs/transformer_340M.json \
|
| 58 |
+
--model.tokenizer_path fla-hub/transformer-1.3B-100B \
|
| 59 |
+
--optimizer.name AdamW \
|
| 60 |
+
--optimizer.eps 1e-15 \
|
| 61 |
+
--optimizer.lr 3e-4 \
|
| 62 |
+
--lr_scheduler.warmup_steps 1024 \
|
| 63 |
+
--lr_scheduler.lr_min 0.1 \
|
| 64 |
+
--lr_scheduler.decay_type cosine \
|
| 65 |
+
--training.batch_size 1 \
|
| 66 |
+
--training.seq_len 65536 \
|
| 67 |
+
--training.context_len 4096 \
|
| 68 |
+
--training.varlen \
|
| 69 |
+
--training.gradient_accumulation_steps 1 \
|
| 70 |
+
--training.steps 20480 \
|
| 71 |
+
--training.max_norm 1.0 \
|
| 72 |
+
--training.skip_nan_inf \
|
| 73 |
+
--training.dataset HuggingFaceFW/fineweb-edu \
|
| 74 |
+
--training.dataset_name sample-100BT \
|
| 75 |
+
--training.dataset_split train \
|
| 76 |
+
--training.streaming \
|
| 77 |
+
--training.num_workers 32 \
|
| 78 |
+
--training.prefetch_factor 2 \
|
| 79 |
+
--training.seed 42 \
|
| 80 |
+
--training.compile \
|
| 81 |
+
--checkpoint.interval 2048 \
|
| 82 |
+
--checkpoint.load_step -1 \
|
| 83 |
+
--checkpoint.keep_latest_k 2 \
|
| 84 |
+
--metrics.log_freq 1
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
You can specify the number of GPUs by setting the environment variable `NGPU`, which defaults to 8.
|
| 88 |
+
**For single-GPU debugging, set `NGPU=1`.**
|
| 89 |
+
|
| 90 |
+
We provide several [config files](https://github.com/fla-org/flame/tree/main/configs) for different models.
|
| 91 |
+
By default, the learning rate is set to 3e-4 with a cosine scheduler. Other schedulers, such as WSD (wsd), are also supported.
|
| 92 |
+
|
| 93 |
+
**Key parameters:**
|
| 94 |
+
- `--lr_scheduler.decay_ratio`: The proportion of the steps allocated to the decay phase. The learning rate will remain stable after the warmup period and only start decaying during the last `decay_ratio` portion of the total training steps, which is known as the Warmup-Stable-Decay (WSD) schedule.
|
| 95 |
+
- `--lr_scheduler.warmup_steps`: The number of steps for the learning rate warmup phase.
|
| 96 |
+
- `--training.steps`: Total number of training steps.
|
| 97 |
+
- `--training.batch_size`: Batch size per device, must be 1 if `--training.varlen` is set.
|
| 98 |
+
- `--training.seq_len`: The length of each sequence in the batch, which is concatenated from multiple samples.
|
| 99 |
+
- `--training.context_len`: The max allowed length of a sample. For non-varlen mode, this is equivalent to `seq_len`.
|
| 100 |
+
- `--training.varlen`: Whether to conduct variable-length sequence training.
|
| 101 |
+
- `--training.gradient_accumulation_steps`: Number of gradient accumulation steps.
|
| 102 |
+
|
| 103 |
+
> [!WARNING]
|
| 104 |
+
> The total number of tokens processed per batch, referred to as `global_batch_size`, is calculated as batch_size × gradient_accumulation_steps × num_gpus.
|
| 105 |
+
> Each step processes `global_batch_size * seq_len` tokens.
|
| 106 |
+
> Monitor the value of `global_batch_size`, `warmup_steps`, and `steps` carefully when modifying any of the hyperparameters!
|
| 107 |
+
|
| 108 |
+
For a detailed explanation of all parameters, run:
|
| 109 |
+
|
| 110 |
+
```sh
|
| 111 |
+
bash train.sh -h
|
| 112 |
+
```
|
| 113 |
+
|
| 114 |
+
<details>
|
| 115 |
+
<summary>Usage</summary>
|
| 116 |
+
|
| 117 |
+
```py
|
| 118 |
+
options:
|
| 119 |
+
-h, --help show this help message and exit
|
| 120 |
+
--job.config_file JOB.CONFIG_FILE
|
| 121 |
+
Job config file
|
| 122 |
+
--job.dump_folder JOB.DUMP_FOLDER
|
| 123 |
+
Folder to dump job outputs
|
| 124 |
+
--job.description JOB.DESCRIPTION
|
| 125 |
+
Description of the job
|
| 126 |
+
--job.use_for_integration_test
|
| 127 |
+
Add this config to the integration test suite
|
| 128 |
+
--job.print_args Print the args to terminal
|
| 129 |
+
--model.config MODEL.CONFIG
|
| 130 |
+
Path to the model config
|
| 131 |
+
--model.norm_type MODEL.NORM_TYPE
|
| 132 |
+
Type of layer normalization to use [layernorm,
|
| 133 |
+
np_layernorm, rmsnorm, fused_rmsnorm]
|
| 134 |
+
--model.tokenizer_path MODEL.TOKENIZER_PATH
|
| 135 |
+
Tokenizer path
|
| 136 |
+
--profiling.enable_profiling
|
| 137 |
+
Whether to enable pytorch profiler
|
| 138 |
+
--profiling.save_traces_folder PROFILING.SAVE_TRACES_FOLDER
|
| 139 |
+
Trace files location
|
| 140 |
+
--profiling.profile_freq PROFILING.PROFILE_FREQ
|
| 141 |
+
How often to collect profiler traces, in iterations
|
| 142 |
+
--profiling.enable_memory_snapshot
|
| 143 |
+
Whether to dump memory snapshot
|
| 144 |
+
--profiling.save_memory_snapshot_folder PROFILING.SAVE_MEMORY_SNAPSHOT_FOLDER
|
| 145 |
+
Memeory snapshot files location
|
| 146 |
+
--optimizer.name OPTIMIZER.NAME
|
| 147 |
+
Optimizer to use
|
| 148 |
+
--optimizer.eps OPTIMIZER.EPS
|
| 149 |
+
Epsilon value for the optimizer.
|
| 150 |
+
--optimizer.fused Whether the fused implementation(CUDA only) is used.
|
| 151 |
+
--optimizer.scheduler {wsd,cosine,linear}
|
| 152 |
+
Scheduler to use. Currently supported: wsd, cosine,
|
| 153 |
+
and linear.
|
| 154 |
+
--optimizer.lr OPTIMIZER.LR
|
| 155 |
+
Learning rate to use
|
| 156 |
+
--optimizer.min_lr_ratio OPTIMIZER.MIN_LR_RATIO
|
| 157 |
+
Min lr ratio for lr scheduler
|
| 158 |
+
--optimizer.early_step_in_backward
|
| 159 |
+
Whether to apply optimizer in the backward. Caution,
|
| 160 |
+
optimizer_in_backward is not compatible with gradients
|
| 161 |
+
clipping, users should not call
|
| 162 |
+
register_post_accumulate_grad_hook after the optimizer
|
| 163 |
+
is built.
|
| 164 |
+
--training.batch_size TRAINING.BATCH_SIZE
|
| 165 |
+
Batch size
|
| 166 |
+
--training.seq_len TRAINING.SEQ_LEN
|
| 167 |
+
Sequence length
|
| 168 |
+
--training.context_len TRAINING.CONTEXT_LEN
|
| 169 |
+
Max length allowed for each sequence
|
| 170 |
+
--training.varlen Whether to take sequences of variable length as input
|
| 171 |
+
--training.warmup_steps TRAINING.WARMUP_STEPS
|
| 172 |
+
Steps for lr scheduler warmup, normally 1/5 of
|
| 173 |
+
--training.steps
|
| 174 |
+
--training.gradient_accumulation_steps TRAINING.GRADIENT_ACCUMULATION_STEPS
|
| 175 |
+
Number of steps to accumulate gradients before
|
| 176 |
+
updating parameters
|
| 177 |
+
--training.steps TRAINING.STEPS
|
| 178 |
+
How many train steps to run
|
| 179 |
+
--training.max_norm TRAINING.MAX_NORM
|
| 180 |
+
Max norm for gradient clipping
|
| 181 |
+
--training.skip_nan_inf
|
| 182 |
+
Skip batch updates when NaN or INF gradients are
|
| 183 |
+
encountered during training
|
| 184 |
+
--training.dataset TRAINING.DATASET
|
| 185 |
+
Dataset to use, with comma separated values
|
| 186 |
+
--training.dataset_name TRAINING.DATASET_NAME
|
| 187 |
+
The name of the dataset config, with comma separated
|
| 188 |
+
values if provided
|
| 189 |
+
--training.dataset_split TRAINING.DATASET_SPLIT
|
| 190 |
+
Dataset split to use, with comma separated values if
|
| 191 |
+
provided
|
| 192 |
+
--training.data_dir TRAINING.DATA_DIR
|
| 193 |
+
Data dirs to use, with comma separated values if
|
| 194 |
+
provided
|
| 195 |
+
--training.data_files TRAINING.DATA_FILES
|
| 196 |
+
Data files to use, with comma separated values if
|
| 197 |
+
provided
|
| 198 |
+
--training.data_probs TRAINING.DATA_PROBS
|
| 199 |
+
Data sampling probabilities, with comma separated
|
| 200 |
+
values if provided
|
| 201 |
+
--training.streaming Whether to load dataset in streaming mode, used for
|
| 202 |
+
huge dataset
|
| 203 |
+
--training.num_workers TRAINING.NUM_WORKERS
|
| 204 |
+
Number of subprocesses to use for data loading. 0
|
| 205 |
+
means that the data will be loaded in the main
|
| 206 |
+
process.
|
| 207 |
+
--training.prefetch_factor TRAINING.PREFETCH_FACTOR
|
| 208 |
+
Number of batches loaded in advance by each worker.2
|
| 209 |
+
means there will be a total of 2 * num_workers batches
|
| 210 |
+
prefetched across all workers.
|
| 211 |
+
--training.data_parallel_replicate_degree TRAINING.DATA_PARALLEL_REPLICATE_DEGREE
|
| 212 |
+
The `data_parallel_replicate_degree` argument
|
| 213 |
+
specifies the degree of data parallelism for weight
|
| 214 |
+
replication. When this value is greater than 1,
|
| 215 |
+
weights will be replicated across
|
| 216 |
+
`data_parallel_replicate_degree` ranks. If
|
| 217 |
+
`data_parallel_shard_degree` is also greater than 1,
|
| 218 |
+
the parallelism method used is HSDP (Hybrid Sharded
|
| 219 |
+
Data Parallelism). Otherwise, the parallelism method
|
| 220 |
+
used is DDP (Distributed Data Parallelism). 1 means
|
| 221 |
+
disabled.
|
| 222 |
+
--training.data_parallel_shard_degree TRAINING.DATA_PARALLEL_SHARD_DEGREE
|
| 223 |
+
The `data_parallel_shard_degree` argument specifies
|
| 224 |
+
the degree of data parallelism for weight sharding.
|
| 225 |
+
When this value is greater than 1, weights will be
|
| 226 |
+
sharded across `data_parallel_shard_degree` ranks. If
|
| 227 |
+
`data_parallel_replicate_degree` is also greater than
|
| 228 |
+
1, the parallelism method used is HSDP (Hybrid Sharded
|
| 229 |
+
Data Parallelism). Otherwise, the parallelism method
|
| 230 |
+
used is FSDP (Fully Sharded Data Parallelism). -1
|
| 231 |
+
means leftover ranks will be used (After
|
| 232 |
+
DP_REPLICATE/SP/PP). Note that only
|
| 233 |
+
`data_parallel_shard_degree` can be negative. 1 means
|
| 234 |
+
disabled.
|
| 235 |
+
--training.enable_cpu_offload
|
| 236 |
+
Whether to apply CPU offloading of parameters,
|
| 237 |
+
gradients, and optimizer states in FSDP
|
| 238 |
+
--training.tensor_parallel_degree TRAINING.TENSOR_PARALLEL_DEGREE
|
| 239 |
+
Tensor Parallelism degree. 1 means disabled.
|
| 240 |
+
--training.disable_loss_parallel
|
| 241 |
+
Whether to apply loss parallel when sequence parallel
|
| 242 |
+
is enabled
|
| 243 |
+
--training.mixed_precision_param {bfloat16,float32}
|
| 244 |
+
torch dtype to use for parameters when applying mixed
|
| 245 |
+
precision via FSDP. This feature only takes effect
|
| 246 |
+
when data_parallel_shard_degree > 1
|
| 247 |
+
--training.mixed_precision_reduce {float32}
|
| 248 |
+
torch dtype to use for reductions when applying mixed
|
| 249 |
+
precision via FSDP. This feature only takes effect
|
| 250 |
+
when data_parallel_shard_degree > 1
|
| 251 |
+
--training.compile Whether to compile the model
|
| 252 |
+
--training.gc_freq TRAINING.GC_FREQ
|
| 253 |
+
Python garbage control scheduling interval, in steps
|
| 254 |
+
--training.seed TRAINING.SEED
|
| 255 |
+
Choose the base RNG seed used for training
|
| 256 |
+
--training.deterministic
|
| 257 |
+
Use deterministic algorithms wherever possible, may be
|
| 258 |
+
slower
|
| 259 |
+
--metrics.log_freq METRICS.LOG_FREQ
|
| 260 |
+
How often to log metrics to TensorBoard, in iterations
|
| 261 |
+
--metrics.enable_tensorboard
|
| 262 |
+
Whether to log metrics to TensorBoard
|
| 263 |
+
--metrics.disable_color_printing
|
| 264 |
+
Whether to disable color printing in logs
|
| 265 |
+
--metrics.save_tb_folder METRICS.SAVE_TB_FOLDER
|
| 266 |
+
Folder to dump TensorBoard states
|
| 267 |
+
--metrics.rank_0_only
|
| 268 |
+
Whether to save TensorBoard metrics only for rank 0 or
|
| 269 |
+
for all ranks. When pipeline_parallel_degree is > 1,
|
| 270 |
+
this option uses the 0th rank of the last stage
|
| 271 |
+
pipeline group, which is the only stage that computes
|
| 272 |
+
loss metrics.
|
| 273 |
+
--metrics.enable_wandb
|
| 274 |
+
Whether to log metrics to Weights & Biases
|
| 275 |
+
--experimental.enable_async_tensor_parallel
|
| 276 |
+
Whether to apply async tensor parallel (currently only
|
| 277 |
+
effective when compile is enabled)
|
| 278 |
+
--experimental.pipeline_parallel_degree EXPERIMENTAL.PIPELINE_PARALLEL_DEGREE
|
| 279 |
+
Pipeline Parallelism degree, or number of ranks. 1
|
| 280 |
+
means disabled. If using looped schedules, this still
|
| 281 |
+
specifies the number of physical ranks, not the number
|
| 282 |
+
of stages. Stages per rank are inferred from split
|
| 283 |
+
points degree, and schedule.
|
| 284 |
+
--experimental.pipeline_parallel_split_points EXPERIMENTAL.PIPELINE_PARALLEL_SPLIT_POINTS [EXPERIMENTAL.PIPELINE_PARALLEL_SPLIT_POINTS ...]
|
| 285 |
+
Specify comma-separated names of modules to use as the
|
| 286 |
+
beginning of a split point. e.g. "layers.0,layers.2"
|
| 287 |
+
will cause the model to be split into 3 stages, the
|
| 288 |
+
first containing all the layers up to layers.0, the
|
| 289 |
+
second containing layers.0 and up to layers.2, the
|
| 290 |
+
third containing layers.2 and all the remaining
|
| 291 |
+
layers. Note: fully-automated splitting may be enabled
|
| 292 |
+
in the future, but currently the split points must be
|
| 293 |
+
specified manually.
|
| 294 |
+
--experimental.pipeline_parallel_schedule EXPERIMENTAL.PIPELINE_PARALLEL_SCHEDULE
|
| 295 |
+
Specify the Pipeline Parallel schedule to use. The
|
| 296 |
+
supported schedules are: https://github.com/pytorch/py
|
| 297 |
+
torch/blob/de4c2a3b4e89d96334dc678d1c3f2ae51a6630a0/to
|
| 298 |
+
rch/distributed/pipelining/schedules.py#L2161. The
|
| 299 |
+
schedule must be compatible with the split points and
|
| 300 |
+
stages_per_rank. Looped schedules (e.g.
|
| 301 |
+
Interleaved1F1B) require specifying
|
| 302 |
+
pipeline_parallel_degree = number of ranks, and
|
| 303 |
+
split_points = number of stages - 1
|
| 304 |
+
--experimental.pipeline_parallel_schedule_csv EXPERIMENTAL.PIPELINE_PARALLEL_SCHEDULE_CSV
|
| 305 |
+
Specify the path to the pipeline parallel schedule csv
|
| 306 |
+
file to use. The pipeline_parallel_schedule argument
|
| 307 |
+
must be either PipelineScheduleSingle,
|
| 308 |
+
PipelineScheduleMulti, or _PipelineScheduleRuntime.
|
| 309 |
+
--experimental.pipeline_parallel_microbatches EXPERIMENTAL.PIPELINE_PARALLEL_MICROBATCHES
|
| 310 |
+
How many microbatches to split the global training
|
| 311 |
+
batch into when using pipeline parallelism. The global
|
| 312 |
+
training batch size must be evenly divisible by the
|
| 313 |
+
number of microbatches. The default value will be the
|
| 314 |
+
number of pipeline stages, if unspecified.
|
| 315 |
+
--experimental.enable_compiled_autograd
|
| 316 |
+
Enable CompiledAutograd to compile the backward.
|
| 317 |
+
--experimental.context_parallel_degree EXPERIMENTAL.CONTEXT_PARALLEL_DEGREE
|
| 318 |
+
Context parallelism degree. 1 means disabled.
|
| 319 |
+
--experimental.context_parallel_rotate_method EXPERIMENTAL.CONTEXT_PARALLEL_ROTATE_METHOD
|
| 320 |
+
The collective to use in context parallel SDPA for kv
|
| 321 |
+
shards exchange. 'allgather' means to all-gather all
|
| 322 |
+
kv shards on ranks after the first sub-SDPA
|
| 323 |
+
computation, 'alltoall' means to all-to-all shuffle
|
| 324 |
+
the kv shards. The default value is 'allgather'.
|
| 325 |
+
--checkpoint.enable_checkpoint
|
| 326 |
+
Whether to enable checkpoint
|
| 327 |
+
--checkpoint.folder CHECKPOINT.FOLDER
|
| 328 |
+
The folder to store the checkpoints. When
|
| 329 |
+
enable_checkpoint is set to true, checkpoints will be
|
| 330 |
+
in {--job.dump_folder}/{--checkpoint.folder}.
|
| 331 |
+
--checkpoint.interval_type CHECKPOINT.INTERVAL_TYPE
|
| 332 |
+
Checkpointing interval unit of measurement ['step',
|
| 333 |
+
'seconds']
|
| 334 |
+
--checkpoint.interval CHECKPOINT.INTERVAL
|
| 335 |
+
Checkpointing interval, in steps or seconds depending
|
| 336 |
+
on --checkpoint.interval_type
|
| 337 |
+
--checkpoint.model_weights_only
|
| 338 |
+
When model_weights_only=True, only model weights will
|
| 339 |
+
be saved at the end of training. With this,
|
| 340 |
+
checkpoints can be loaded using `torch.load(...,
|
| 341 |
+
weights_only=True)` after conversion. When
|
| 342 |
+
model_weights_only=False, the full checkpoint will be
|
| 343 |
+
saved. A full checkpoint includes model, optimizer and
|
| 344 |
+
train_state, which can be used to resume training. The
|
| 345 |
+
default value is false.
|
| 346 |
+
--checkpoint.export_dtype {float16,bfloat16,float32}
|
| 347 |
+
Converts to the specified precision when training
|
| 348 |
+
completes and model_weights_only=true. Currently
|
| 349 |
+
supports float32, float16, and bfloat16. The default
|
| 350 |
+
value is float32.
|
| 351 |
+
--checkpoint.create_seed_checkpoint
|
| 352 |
+
Initializes the full model without applying
|
| 353 |
+
parallelisms, and then saves it as a seed checkpoint.
|
| 354 |
+
Note: requires user to call train.py without
|
| 355 |
+
specifying any parallelisms, e.g. NGPU=1. Could be
|
| 356 |
+
implemented as a separate script, but this way shares
|
| 357 |
+
more code.
|
| 358 |
+
--checkpoint.async_mode CHECKPOINT.ASYNC_MODE
|
| 359 |
+
Which async checkpoint mode to use. Currently there
|
| 360 |
+
are 3 different modes. 1. "disabled": synchronized
|
| 361 |
+
checkpointing will be used. 2. "async":
|
| 362 |
+
torch.distributed.checkpoint.async_save will be used.
|
| 363 |
+
1. "async_with_pinned_mem": this option utilizes a
|
| 364 |
+
dedicated pinned memory space and creates a separate
|
| 365 |
+
process for faster GPU->CPU transfer performance and
|
| 366 |
+
eliminating GIL contention. The cost is increased CPU
|
| 367 |
+
memory usage. If insufficient CPU memory is available,
|
| 368 |
+
performance may degrade due to memory paging. For most
|
| 369 |
+
users, "async" should suffice as the performance
|
| 370 |
+
overhead is typically small (on the order of tens of
|
| 371 |
+
seconds) compared to checkpointing frequency. This
|
| 372 |
+
mode can be employed to pursue near-zero checkpointing
|
| 373 |
+
times (e.g., < 1 second) given appropriate hardware
|
| 374 |
+
support such as ample CPU memory and fast PCIe.
|
| 375 |
+
"disabled" is the default mode.
|
| 376 |
+
--checkpoint.keep_latest_k CHECKPOINT.KEEP_LATEST_K
|
| 377 |
+
Keeps only the latest k checkpoints, and purging older
|
| 378 |
+
ones. If 0, keep all checkpoints. 0 is the default
|
| 379 |
+
value.
|
| 380 |
+
--checkpoint.load_step CHECKPOINT.LOAD_STEP
|
| 381 |
+
Load the checkpoint at the specified step. If -1, load
|
| 382 |
+
the latest checkpoint.
|
| 383 |
+
--float8.enable_float8_linear
|
| 384 |
+
If true, swaps `torch.nn.Linear` with `Float8Linear`.
|
| 385 |
+
This feature requires you to install 'torchao' which
|
| 386 |
+
can be found here: https://github.com/pytorch/ao
|
| 387 |
+
--float8.enable_fsdp_float8_all_gather
|
| 388 |
+
Whether enable float8 all-gather in FSDP
|
| 389 |
+
--float8.precompute_float8_dynamic_scale_for_fsdp
|
| 390 |
+
Whether precompute float8 scales dynamically for FSDP
|
| 391 |
+
--float8.scaling_type_input {dynamic,delayed}
|
| 392 |
+
float8 scaling for input, dynamic (default) or delayed
|
| 393 |
+
--float8.scaling_type_weight FLOAT8.SCALING_TYPE_WEIGHT
|
| 394 |
+
float8 scaling for input, dynamic (default) or delayed
|
| 395 |
+
--float8.scaling_type_grad_output FLOAT8.SCALING_TYPE_GRAD_OUTPUT
|
| 396 |
+
float8 scaling for input, dynamic (default) or delayed
|
| 397 |
+
--comm.init_timeout_seconds COMM.INIT_TIMEOUT_SECONDS
|
| 398 |
+
Timeout for communication operations, during
|
| 399 |
+
initialization and first train step.
|
| 400 |
+
--comm.train_timeout_seconds COMM.TRAIN_TIMEOUT_SECONDS
|
| 401 |
+
Timeout for communication operations after the first
|
| 402 |
+
train step -- usually a tighter bound than during
|
| 403 |
+
initialization.
|
| 404 |
+
--comm.trace_buf_size COMM.TRACE_BUF_SIZE
|
| 405 |
+
Flight recorder ring buffer size, >0 means recording
|
| 406 |
+
by default, 0 means disabled
|
| 407 |
+
--memory_estimation.enabled
|
| 408 |
+
Whether to estimate memory usage for FSDP
|
| 409 |
+
--memory_estimation.disable_fake_mode
|
| 410 |
+
Whether to estimate memory under FakeTensorMode
|
| 411 |
+
```
|
| 412 |
+
</details>
|
| 413 |
+
|
| 414 |
+
### Training with `torch.compile`
|
| 415 |
+
|
| 416 |
+
Starting from `torch 2.0`, `torch.compile` has been introduced as a new feature to seamlessly accelerate training processes.
|
| 417 |
+
In `flame`, one can simply enable `torch.compile` by adding `--training.compile` flag to your training script.
|
| 418 |
+
|
| 419 |
+
However, `fla` has integrated numerous fused kernels for acceleration, which may potentially conflict with `torch.compile`.
|
| 420 |
+
We are actively working on resolving these issues to make compilation transparent to users.
|
| 421 |
+
In the meantime, please ensure you are using the latest dependencies.
|
| 422 |
+
|
| 423 |
+
Specifically, **we recommend using `torch>=2.6` and `triton>=3.0`**.
|
| 424 |
+
|
| 425 |
+
### Training with multiple datasets
|
| 426 |
+
|
| 427 |
+
If you wish to train a model with all-round capabilities (e.g., code, math, and multilingual ability), it's necessary to train on multiple datasets.
|
| 428 |
+
`flame` allows training with multiple datasets easily.
|
| 429 |
+
For example, you can specify the following arguments to train on 6 datasets with different proportions:
|
| 430 |
+
|
| 431 |
+
```sh
|
| 432 |
+
--training.dataset HuggingFaceFW/fineweb-edu,opencsg/Fineweb-Edu-Chinese-V2.1,OpenCoder-LLM/opc-fineweb-code-corpus,math-ai/AutoMathText,EleutherAI/proof-pile-2,OpenCoder-LLM/opc-fineweb-math-corpus \
|
| 433 |
+
--training.data_probs 0.6,0.15,0.15,0.014,0.058,0.028 \
|
| 434 |
+
```
|
| 435 |
+
|
| 436 |
+
### ~Finalizing training~
|
| 437 |
+
|
| 438 |
+
> [!NOTE]
|
| 439 |
+
> We have done this conversion automatically in the training script since our latest updates.
|
| 440 |
+
|
| 441 |
+
Once training is complete, you may want to convert the distributed checkpoints (DCPs) into the 🤗 format for broader use.
|
| 442 |
+
To facilitate this, we provide a straightforward conversion script:
|
| 443 |
+
|
| 444 |
+
```sh
|
| 445 |
+
python -m flame.utils.convert_dcp_to_hf --path <path_to_model> --step <step> --config <path_to_config> --tokenizer <path_to_tokenizer>
|
| 446 |
+
```
|
| 447 |
+
After this, your model will be in the 🤗 format, ready to be shared or deployed.
|
| 448 |
+
You can then easily publish your model using the `huggingface_hub` for wider accessibility.
|
| 449 |
+
|
| 450 |
+
### Continual training
|
| 451 |
+
|
| 452 |
+
If you wish to build upon a strong pre-trained model (in 🤗 format) and continue training, we also offer a script to convert the 🤗 format model back into DCP format.
|
| 453 |
+
This allows you to seamlessly resume training with `flame`.
|
| 454 |
+
```sh
|
| 455 |
+
python -m flame.utils.convert_hf_to_dcp --model <path_to_hf> --checkpoint <path_to_dcp/checkpoint/step-0>
|
| 456 |
+
```
|
| 457 |
+
Here, `<path_to_dcp>` is the directory where your distributed checkpoints will be stored.
|
| 458 |
+
The checkpoint is intentionally saved at `<step-0>` within the checkpoint folder to ensure it is loadable by `flame` during the initial training step, similar to how a seed checkpoint is handled.
|
| 459 |
+
|
| 460 |
+
Once the conversion is complete, you can proceed with training using `flame` as usual, continuing from where the pretrained model left off.
|
| 461 |
+
|
| 462 |
+
## Multi-node training
|
| 463 |
+
|
| 464 |
+
If you have access to multi-node GPUs, consider leveraging them for optimal performance.
|
| 465 |
+
This process is straightforward and well-documented in the PyTorch [docs](https://pytorch.org/docs/stable/elastic/run.html).
|
| 466 |
+
|
| 467 |
+
To set up multi-node training:
|
| 468 |
+
* Set the environment variables `MASTER_ADDR=<ip>` and `MASTER_PORT=<port>` before running the training script across all nodes.
|
| 469 |
+
* If you're using a job scheduler like Slurm, it will handle these variables for you.
|
| 470 |
+
|
| 471 |
+
`torchtitan` provides a [Slurm script](https://github.com/pytorch/torchtitan/blob/main/multinode_trainer.slurm) for multi-node training, which you can use as a reference or starting point.
|
config.json
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"MTPTransformerForCausalLM"
|
| 4 |
+
],
|
| 5 |
+
"attention_bias": false,
|
| 6 |
+
"bos_token_id": 1,
|
| 7 |
+
"elementwise_affine": true,
|
| 8 |
+
"eos_token_id": 2,
|
| 9 |
+
"fuse_cross_entropy": true,
|
| 10 |
+
"fuse_norm": true,
|
| 11 |
+
"fuse_swiglu": true,
|
| 12 |
+
"hidden_act": "swish",
|
| 13 |
+
"hidden_ratio": 4,
|
| 14 |
+
"hidden_size": 1024,
|
| 15 |
+
"initializer_range": 0.006,
|
| 16 |
+
"intermediate_size": null,
|
| 17 |
+
"max_position_embeddings": 8192,
|
| 18 |
+
"model_type": "mtp_transformer",
|
| 19 |
+
"n_future_tokens": 3,
|
| 20 |
+
"norm_eps": 1e-06,
|
| 21 |
+
"num_heads": 16,
|
| 22 |
+
"num_hidden_layers": 24,
|
| 23 |
+
"num_kv_heads": null,
|
| 24 |
+
"qk_norm": false,
|
| 25 |
+
"qkv_bias": false,
|
| 26 |
+
"rope_theta": 10000.0,
|
| 27 |
+
"tie_word_embeddings": false,
|
| 28 |
+
"torch_dtype": "float32",
|
| 29 |
+
"transformers_version": "4.51.3",
|
| 30 |
+
"use_cache": true,
|
| 31 |
+
"use_custom_backward": false,
|
| 32 |
+
"vocab_size": 32000,
|
| 33 |
+
"window_size": null
|
| 34 |
+
}
|
configs/delta_net_1B.json
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"attn": null,
|
| 3 |
+
"attn_mode": "chunk",
|
| 4 |
+
"bos_token_id": 1,
|
| 5 |
+
"conv_size": 4,
|
| 6 |
+
"eos_token_id": 2,
|
| 7 |
+
"expand_k": 1,
|
| 8 |
+
"expand_v": 1,
|
| 9 |
+
"fuse_cross_entropy": true,
|
| 10 |
+
"fuse_norm": true,
|
| 11 |
+
"hidden_act": "swish",
|
| 12 |
+
"hidden_ratio": 4,
|
| 13 |
+
"hidden_size": 2048,
|
| 14 |
+
"initializer_range": 0.006,
|
| 15 |
+
"intermediate_size": null,
|
| 16 |
+
"model_type": "delta_net",
|
| 17 |
+
"norm_eps": 1e-06,
|
| 18 |
+
"num_heads": 16,
|
| 19 |
+
"num_hidden_layers": 24,
|
| 20 |
+
"pad_token_id": 2,
|
| 21 |
+
"qk_activation": "silu",
|
| 22 |
+
"qk_norm": "l2",
|
| 23 |
+
"tie_word_embeddings": false,
|
| 24 |
+
"use_beta": true,
|
| 25 |
+
"use_cache": true,
|
| 26 |
+
"use_gate": false,
|
| 27 |
+
"use_output_norm": true,
|
| 28 |
+
"use_short_conv": true
|
| 29 |
+
}
|
configs/delta_net_340M.json
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"attn_mode": "chunk",
|
| 3 |
+
"bos_token_id": 1,
|
| 4 |
+
"conv_size": 4,
|
| 5 |
+
"eos_token_id": 2,
|
| 6 |
+
"expand_k": 1,
|
| 7 |
+
"expand_v": 1,
|
| 8 |
+
"fuse_cross_entropy": true,
|
| 9 |
+
"hidden_act": "swish",
|
| 10 |
+
"hidden_ratio": 4,
|
| 11 |
+
"hidden_size": 1024,
|
| 12 |
+
"initializer_range": 0.006,
|
| 13 |
+
"intermediate_size": null,
|
| 14 |
+
"model_type": "delta_net",
|
| 15 |
+
"norm_eps": 1e-06,
|
| 16 |
+
"norm_first": false,
|
| 17 |
+
"num_heads": 8,
|
| 18 |
+
"num_hidden_layers": 24,
|
| 19 |
+
"qk_activation": "silu",
|
| 20 |
+
"qk_norm": "l2",
|
| 21 |
+
"tie_word_embeddings": false,
|
| 22 |
+
"use_beta": true,
|
| 23 |
+
"use_cache": true,
|
| 24 |
+
"use_gate": false,
|
| 25 |
+
"use_output_norm": true,
|
| 26 |
+
"use_short_conv": true
|
| 27 |
+
}
|
configs/gla_340M.json
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"attn_mode": "chunk",
|
| 3 |
+
"bos_token_id": 1,
|
| 4 |
+
"clamp_min": null,
|
| 5 |
+
"eos_token_id": 2,
|
| 6 |
+
"expand_k": 0.5,
|
| 7 |
+
"expand_v": 1,
|
| 8 |
+
"fuse_cross_entropy": true,
|
| 9 |
+
"fuse_norm": true,
|
| 10 |
+
"hidden_act": "swish",
|
| 11 |
+
"hidden_ratio": 4,
|
| 12 |
+
"hidden_size": 1024,
|
| 13 |
+
"initializer_range": 0.006,
|
| 14 |
+
"intermediate_size": null,
|
| 15 |
+
"model_type": "gla",
|
| 16 |
+
"num_heads": 4,
|
| 17 |
+
"num_hidden_layers": 24,
|
| 18 |
+
"norm_eps": 1e-06,
|
| 19 |
+
"tie_word_embeddings": false,
|
| 20 |
+
"use_cache": true,
|
| 21 |
+
"use_gk": true,
|
| 22 |
+
"use_gv": false,
|
| 23 |
+
"vocab_size": 32000
|
| 24 |
+
}
|
configs/gla_7B.json
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"attn": null,
|
| 3 |
+
"attn_mode": "chunk",
|
| 4 |
+
"bos_token_id": 1,
|
| 5 |
+
"eos_token_id": 2,
|
| 6 |
+
"expand_k": 0.5,
|
| 7 |
+
"expand_v": 1,
|
| 8 |
+
"fuse_cross_entropy": true,
|
| 9 |
+
"fuse_norm": true,
|
| 10 |
+
"hidden_act": "swish",
|
| 11 |
+
"hidden_ratio": 4,
|
| 12 |
+
"hidden_size": 4096,
|
| 13 |
+
"initializer_range": 0.006,
|
| 14 |
+
"intermediate_size": 11008,
|
| 15 |
+
"model_type": "gla",
|
| 16 |
+
"norm_eps": 1e-06,
|
| 17 |
+
"num_heads": 16,
|
| 18 |
+
"num_hidden_layers": 32,
|
| 19 |
+
"tie_word_embeddings": false,
|
| 20 |
+
"use_cache": true,
|
| 21 |
+
"use_gk": true,
|
| 22 |
+
"use_gv": false,
|
| 23 |
+
"use_output_gate": true,
|
| 24 |
+
"use_short_conv": false
|
| 25 |
+
}
|
configs/gsa_340M.json
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token_id": 1,
|
| 3 |
+
"conv_size": 4,
|
| 4 |
+
"eos_token_id": 2,
|
| 5 |
+
"expand_k": 1,
|
| 6 |
+
"expand_v": 1,
|
| 7 |
+
"elementwise_affine": false,
|
| 8 |
+
"feature_map": "swish",
|
| 9 |
+
"fuse_cross_entropy": true,
|
| 10 |
+
"fuse_norm": true,
|
| 11 |
+
"gate_logit_normalizer": 4,
|
| 12 |
+
"hidden_act": "swish",
|
| 13 |
+
"hidden_ratio": 4,
|
| 14 |
+
"hidden_size": 1024,
|
| 15 |
+
"initializer_range": 0.006,
|
| 16 |
+
"intermediate_size": null,
|
| 17 |
+
"model_type": "gsa",
|
| 18 |
+
"num_heads": 4,
|
| 19 |
+
"num_hidden_layers": 24,
|
| 20 |
+
"num_slots": 64,
|
| 21 |
+
"norm_eps": 1e-06,
|
| 22 |
+
"share_conv_kernel": true,
|
| 23 |
+
"tie_word_embeddings": false,
|
| 24 |
+
"use_cache": true,
|
| 25 |
+
"use_norm": true,
|
| 26 |
+
"use_output_gate": true,
|
| 27 |
+
"use_rope": false,
|
| 28 |
+
"use_short_conv": false
|
| 29 |
+
}
|
configs/mtp_transformer_340M.json
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"attention_bias": false,
|
| 3 |
+
"bos_token_id": 1,
|
| 4 |
+
"eos_token_id": 2,
|
| 5 |
+
"fuse_cross_entropy": true,
|
| 6 |
+
"fuse_norm": true,
|
| 7 |
+
"hidden_act": "swish",
|
| 8 |
+
"hidden_size": 1024,
|
| 9 |
+
"initializer_range": 0.006,
|
| 10 |
+
"max_position_embeddings": 8192,
|
| 11 |
+
"model_type": "mtp_transformer",
|
| 12 |
+
"num_heads": 16,
|
| 13 |
+
"num_hidden_layers": 24,
|
| 14 |
+
"norm_eps": 1e-06,
|
| 15 |
+
"tie_word_embeddings": false,
|
| 16 |
+
"use_cache": true,
|
| 17 |
+
"vocab_size": 32000,
|
| 18 |
+
"n_future_tokens": 3
|
| 19 |
+
}
|
configs/top_transformer_1B.json
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token_id": 1,
|
| 3 |
+
"elementwise_affine": true,
|
| 4 |
+
"eos_token_id": 2,
|
| 5 |
+
"fuse_cross_entropy": true,
|
| 6 |
+
"fuse_norm": true,
|
| 7 |
+
"fuse_swiglu": true,
|
| 8 |
+
"hidden_act": "swish",
|
| 9 |
+
"hidden_ratio": 4,
|
| 10 |
+
"hidden_size": 2048,
|
| 11 |
+
"initializer_range": 0.006,
|
| 12 |
+
"intermediate_size": null,
|
| 13 |
+
"max_position_embeddings": 8192,
|
| 14 |
+
"model_type": "top_transformer",
|
| 15 |
+
"norm_eps": 1e-06,
|
| 16 |
+
"num_heads": 32,
|
| 17 |
+
"num_hidden_layers": 32,
|
| 18 |
+
"num_kv_heads": null,
|
| 19 |
+
"pad_token_id": 2,
|
| 20 |
+
"rope_theta": 10000.0,
|
| 21 |
+
"tie_word_embeddings": false,
|
| 22 |
+
"use_top_loss": true,
|
| 23 |
+
"top_window_size": 4096
|
| 24 |
+
}
|
configs/top_transformer_340M.json
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"attention_bias": false,
|
| 3 |
+
"bos_token_id": 1,
|
| 4 |
+
"eos_token_id": 2,
|
| 5 |
+
"fuse_cross_entropy": true,
|
| 6 |
+
"fuse_norm": true,
|
| 7 |
+
"hidden_act": "swish",
|
| 8 |
+
"hidden_size": 1024,
|
| 9 |
+
"initializer_range": 0.006,
|
| 10 |
+
"max_position_embeddings": 8192,
|
| 11 |
+
"model_type": "top_transformer",
|
| 12 |
+
"num_heads": 16,
|
| 13 |
+
"num_hidden_layers": 24,
|
| 14 |
+
"norm_eps": 1e-06,
|
| 15 |
+
"tie_word_embeddings": false,
|
| 16 |
+
"use_cache": true,
|
| 17 |
+
"vocab_size": 32000,
|
| 18 |
+
"use_top_loss": true,
|
| 19 |
+
"top_window_size": 4096
|
| 20 |
+
}
|
configs/transformer_120M.json
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"attention_bias": false,
|
| 3 |
+
"bos_token_id": 1,
|
| 4 |
+
"eos_token_id": 2,
|
| 5 |
+
"fuse_cross_entropy": true,
|
| 6 |
+
"fuse_norm": false,
|
| 7 |
+
"hidden_act": "swish",
|
| 8 |
+
"hidden_size": 768,
|
| 9 |
+
"initializer_range": 0.02,
|
| 10 |
+
"max_position_embeddings": 4096,
|
| 11 |
+
"model_type": "transformer",
|
| 12 |
+
"num_heads": 12,
|
| 13 |
+
"num_hidden_layers": 14,
|
| 14 |
+
"norm_eps": 1e-06,
|
| 15 |
+
"tie_word_embeddings": true,
|
| 16 |
+
"use_cache": true,
|
| 17 |
+
"vocab_size": 32000
|
| 18 |
+
}
|
configs/transformer_7B.json
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"attention_bias": false,
|
| 3 |
+
"bos_token_id": 1,
|
| 4 |
+
"eos_token_id": 2,
|
| 5 |
+
"fuse_cross_entropy": true,
|
| 6 |
+
"fuse_norm": true,
|
| 7 |
+
"hidden_act": "swish",
|
| 8 |
+
"hidden_ratio": 4,
|
| 9 |
+
"hidden_size": 4096,
|
| 10 |
+
"initializer_range": 0.006,
|
| 11 |
+
"intermediate_size": 14336,
|
| 12 |
+
"model_type": "transformer",
|
| 13 |
+
"norm_eps": 1e-06,
|
| 14 |
+
"num_heads": 32,
|
| 15 |
+
"num_hidden_layers": 30,
|
| 16 |
+
"num_kv_heads": 8,
|
| 17 |
+
"rope_theta": 10000.0,
|
| 18 |
+
"tie_word_embeddings": false,
|
| 19 |
+
"use_cache": true,
|
| 20 |
+
"window_size": null
|
| 21 |
+
}
|
fla/utils.py
ADDED
|
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
|
| 3 |
+
import contextlib
|
| 4 |
+
import functools
|
| 5 |
+
import os
|
| 6 |
+
from enum import Enum
|
| 7 |
+
from functools import lru_cache
|
| 8 |
+
from typing import Any, Callable, Dict, Literal, Optional, Tuple
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import triton
|
| 12 |
+
from packaging import version
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def tensor_cache(
    fn: Callable[..., torch.Tensor]
) -> Callable[..., torch.Tensor]:
    """
    A decorator that caches the most recent result of a function with tensor inputs.

    Only a single entry is kept: if the function is called again with the exact same
    input objects (compared by identity, not value), the cached output is returned
    instead of recomputing.

    Args:
        fn (Callable[..., torch.Tensor]):
            The function to be decorated. It should take tensor inputs and return tensor outputs.

    Returns:
        Callable[..., torch.Tensor]:
            A wrapped version of the input function with single-entry caching.
    """
    cached_args: Optional[Tuple] = None
    cached_kwargs: Optional[Dict] = None
    cached_result: Any = None

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        nonlocal cached_args, cached_kwargs, cached_result

        # Identity comparison (`is`) is deliberate: it is cheap and avoids
        # element-wise tensor equality checks.
        if (
            cached_args is not None
            and cached_kwargs is not None
            and len(args) == len(cached_args)
            and len(kwargs) == len(cached_kwargs)
            and all(x is y for x, y in zip(args, cached_args))
            and all(k in cached_kwargs and cached_kwargs[k] is v for k, v in kwargs.items())
        ):
            return cached_result

        cached_result = fn(*args, **kwargs)
        cached_args, cached_kwargs = args, kwargs
        return cached_result

    return wrapper
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def input_guard(
    fn: Callable[..., torch.Tensor]
) -> Callable[..., torch.Tensor]:
    """
    A decorator to make sure all input tensors are contiguous and set the device based on input tensors.
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # Lazily make every tensor argument contiguous; non-tensors pass through.
        safe_args = (a.contiguous() if isinstance(a, torch.Tensor) else a for a in args)
        safe_kwargs = {name: (val.contiguous() if isinstance(val, torch.Tensor) else val)
                       for name, val in kwargs.items()}

        # The first tensor found (positional args first, then keyword args)
        # decides which device context to enter.
        ref = next((a for a in args if isinstance(a, torch.Tensor)), None)
        if ref is None:
            ref = next((v for v in kwargs.values() if isinstance(v, torch.Tensor)), None)

        device_ctx = custom_device_ctx(ref.device.index) if ref is not None else contextlib.nullcontext()
        with device_ctx:
            return fn(*safe_args, **safe_kwargs)

    return wrapper


contiguous = input_guard
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def require_version(version, hint):
    """
    Perform a runtime check of the dependency versions, using the exact same syntax used by pip.
    """
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(ctx, *args, **kwargs):
            # Imported lazily so transformers is only required when the wrapped
            # function is actually invoked; aliased to avoid shadowing this decorator.
            from transformers.utils.versions import require_version as _require_version
            _require_version(version, hint)
            safe_args = (a.contiguous() if isinstance(a, torch.Tensor) else a for a in args)
            safe_kwargs = {k: (v.contiguous() if isinstance(v, torch.Tensor) else v) for k, v in kwargs.items()}
            return fn(ctx, *safe_args, **safe_kwargs)
        return wrapper
    return decorator
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def checkpoint(fn):
    """
    Wrap ``fn`` with activation checkpointing (:func:`torch.utils.checkpoint.checkpoint`),
    trading compute for memory by recomputing activations during the backward pass.

    Args:
        fn: The function whose forward activations should not be stored.

    Returns:
        A wrapper with the same call signature that runs ``fn`` under checkpointing.
    """
    # `functools.wraps` preserves the wrapped function's metadata (`__name__`,
    # `__doc__`, ...); the original decorator silently discarded it.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return torch.utils.checkpoint.checkpoint(fn, *args, **kwargs)
    return wrapper
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
@lru_cache(maxsize=None)
def check_pytorch_version(version_s: str = '2.4') -> bool:
    """Return True if the installed PyTorch version is at least ``version_s``."""
    installed = version.parse(torch.__version__)
    required = version.parse(version_s)
    return installed >= required
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def _cpu_device_warning():
|
| 119 |
+
import warnings
|
| 120 |
+
warnings.warn(('Triton is not supported on current platform, roll back to CPU.'), stacklevel=1)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
@lru_cache(maxsize=None)
def get_multiprocessor_count(tensor_idx: int = 0) -> int:
    """Return the multiprocessor (SM) count of the current CUDA device, or -1 on CPU-only platforms."""
    # NOTE(review): `tensor_idx` is currently unused — the property query targets
    # the current device only; assumes homogeneous hardware. TODO confirm intent.
    try:
        # Only works if Homogeneous hardware
        # TEMPORARY FIX since old version introduce graph break
        return torch.cuda.get_device_properties().multi_processor_count
    except BaseException:
        _cpu_device_warning()
        return -1
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
@lru_cache(maxsize=None)
def get_available_device() -> str:
    """Return the active Triton backend name ('cuda', 'hip', 'xpu', ...), or 'cpu' when Triton is unavailable."""
    try:
        backend = triton.runtime.driver.active.get_current_target().backend
    except BaseException:
        _cpu_device_warning()
        return 'cpu'
    return backend
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
@lru_cache(maxsize=None)
def _check_platform() -> Literal['nvidia', 'amd', 'intel', 'musa']:
    """Map the Triton backend name to a GPU vendor string; unknown backends pass through unchanged."""
    backend = get_available_device()
    vendor_map = {'cuda': 'nvidia', 'hip': 'amd', 'xpu': 'intel'}
    return vendor_map.get(backend, backend)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
# For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'.
# However, the torch backend is 'cuda' for both Nvidia and AMD GPUs.
# Therefore, we need to check the triton backend to determine the actual GPU vendor.
# (`get_available_device` is lru_cached, so repeated calls are free.)
device = 'cuda' if get_available_device() == 'hip' else get_available_device()
device_torch_lib = getattr(torch, device)
device_platform = _check_platform()

is_amd = device_platform == 'amd'
is_intel = device_platform == 'intel'
is_nvidia = device_platform == 'nvidia'
# Short-circuit keeps the vendor-specific device-name queries from running on other platforms.
is_intel_alchemist = is_intel and 'Intel(R) Arc(TM) A' in torch.xpu.get_device_name(0)
is_nvidia_hopper = is_nvidia and ('NVIDIA H' in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9)
use_cuda_graph = is_nvidia and os.environ.get('FLA_USE_CUDA_GRAPH', '0') == '1'

# Nvidia Ampere or newer, haven't check AMD and intel yet.
is_tf32_supported = is_nvidia and torch.cuda.get_device_capability(0)[0] >= 8
is_gather_supported = hasattr(triton.language, 'gather')
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def get_all_max_shared_mem():
    """Return the max shared memory (bytes) of every visible device, or [-1] when Triton is unavailable."""
    try:
        props = triton.runtime.driver.active.utils.get_device_properties
        return [props(i)['max_shared_mem'] for i in range(device_torch_lib.device_count())]
    except BaseException:
        _cpu_device_warning()
        return [-1]
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
class Backend(Enum):
    """Shared-memory capacity (bytes) for known GPU architectures."""

    ADA = 101376  # RTX 4090
    AMPERE = 166912  # A100
    HOPPER = 232448  # H100
    DEFAULT = 102400  # Default

    @classmethod
    def get_shared_memory(cls, arch: str) -> int:
        """Look up the shared-memory size for ``arch`` (case-insensitive); unknown names fall back to DEFAULT."""
        member = cls.__members__.get(arch.upper())
        return (cls.DEFAULT if member is None else member).value
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
@lru_cache(maxsize=None)
def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool:
    """Return True if device ``tensor_idx`` has at least as much shared memory as ``arch`` requires; False on any failure."""
    try:
        available = get_all_max_shared_mem()[tensor_idx]
        required = Backend.get_shared_memory(arch)
        return available >= required
    except Exception:
        # Missing device, bad index, etc. — treat as "not enough shared memory".
        return False
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
# PyTorch >= 2.4 exposes device-generic `torch.amp.custom_fwd/bwd`; older
# versions only provide the CUDA-specific variants.
if check_pytorch_version('2.4'):
    # A CPU-only fallback still binds the autocast helpers to 'cuda'.
    device = 'cuda' if device == 'cpu' else device
    autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=device)
    autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=device)

    def custom_device_ctx(index: int):
        """Return a device context manager for device ``index`` on the detected backend."""
        return device_torch_lib.device(index)
else:
    assert device == 'cuda', 'Only cuda device is supported for PyTorch version < 2.4.0.'
    autocast_custom_fwd = device_torch_lib.amp.custom_fwd
    autocast_custom_bwd = device_torch_lib.amp.custom_bwd

    def custom_device_ctx(index: int):
        """Return a CUDA device context manager for device ``index``."""
        return torch.cuda.device(index)
|
generation_config.json
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_from_model_config": true,
|
| 3 |
+
"bos_token_id": 1,
|
| 4 |
+
"eos_token_id": 2,
|
| 5 |
+
"transformers_version": "4.51.3"
|
| 6 |
+
}
|
logs/none_ro0qpaac/attempt_0/0/stderr.log
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
logs/none_ro0qpaac/attempt_0/1/stderr.log
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
logs/none_ro0qpaac/attempt_0/3/stderr.log
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
logs/none_ro0qpaac/attempt_0/5/stderr.log
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
logs/none_ro0qpaac/attempt_0/6/stderr.log
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
logs/none_ro0qpaac/attempt_0/7/stderr.log
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
pyproject.toml
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "flame"
|
| 3 |
+
dynamic = ["version"]
|
| 4 |
+
description = "A minimal training framework for scaling FLA models"
|
| 5 |
+
readme = "README.md"
|
| 6 |
+
authors = [
|
| 7 |
+
{ name = "Songlin Yang", email = "yangsl66@mit.edu" },
|
| 8 |
+
{ name = "Yu Zhang", email = "yzhang.cs@outlook.com" },
|
| 9 |
+
]
|
| 10 |
+
license = { file = "LICENSE" }
|
| 11 |
+
classifiers = [
|
| 12 |
+
"Programming Language :: Python :: 3",
|
| 13 |
+
"License :: OSI Approved :: MIT License",
|
| 14 |
+
"Operating System :: OS Independent",
|
| 15 |
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
| 16 |
+
]
|
| 17 |
+
requires-python = ">=3.10"
|
| 18 |
+
dependencies = [
|
| 19 |
+
'torch==2.6',
|
| 20 |
+
'torchdata',
|
| 21 |
+
'transformers==4.51.3',
|
| 22 |
+
'triton>=3.0',
|
| 23 |
+
'datasets>=3.3.0',
|
| 24 |
+
'einops',
|
| 25 |
+
'ninja',
|
| 26 |
+
'wandb',
|
| 27 |
+
'tiktoken',
|
| 28 |
+
'tensorboard',
|
| 29 |
+
'python-dotenv'
|
| 30 |
+
]
|
| 31 |
+
|
| 32 |
+
[project.optional-dependencies]
|
| 33 |
+
dev = ["pytest"]
|
| 34 |
+
|
| 35 |
+
[project.urls]
|
| 36 |
+
Homepage = "https://github.com/fla-org/flame"
|
| 37 |
+
|
| 38 |
+
[build-system]
|
| 39 |
+
requires = ["setuptools>=45", "wheel", "ninja", "torch"]
|
| 40 |
+
|
| 41 |
+
[tool.isort]
|
| 42 |
+
line_length = 127
|
| 43 |
+
multi_line_output = 3
|
setup.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
|
| 3 |
+
import ast
|
| 4 |
+
import os
|
| 5 |
+
import re
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
from setuptools import find_packages, setup
|
| 9 |
+
|
| 10 |
+
# Long description for PyPI comes straight from the README.
long_description = Path('README.md').read_text()
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def get_package_version():
    """Read the ``__version__`` string from ``flame/__init__.py``.

    Returns:
        The version as a Python literal (typically a str).

    Raises:
        RuntimeError: If no ``__version__`` assignment is found, instead of the
            opaque ``AttributeError`` the previous implementation raised on a
            ``None`` match.
    """
    init_path = Path(os.path.dirname(os.path.abspath(__file__))) / 'flame' / '__init__.py'
    with open(init_path) as f:
        version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
    if version_match is None:
        raise RuntimeError(f'Unable to find __version__ in {init_path}')
    return ast.literal_eval(version_match.group(1))
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# Runtime dependencies; keep in sync with pyproject.toml.
_REQUIREMENTS = [
    'torch==2.6',
    'torchdata',
    'transformers==4.51.3',
    'triton>=3.0',
    'datasets>=3.3.0',
    'einops',
    'ninja',
    'wandb',
    'tiktoken',
    'tensorboard',
    'python-dotenv',
]

setup(
    name='flame',
    version=get_package_version(),
    description='A minimal training framework for scaling FLA models',
    long_description=long_description,
    long_description_content_type='text/markdown',
    author='Songlin Yang, Yu Zhang',
    author_email='yangsl66@mit.edu, yzhang.cs@outlook.com',
    url='https://github.com/fla-org/flame',
    packages=find_packages(),
    license='MIT',
    classifiers=[
        'Programming Language :: Python :: 3',
        'License :: OSI Approved :: MIT License',
        'Operating System :: OS Independent',
        'Topic :: Scientific/Engineering :: Artificial Intelligence'
    ],
    python_requires='>=3.10',
    install_requires=_REQUIREMENTS,
)
|
special_tokens_map.json
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token": {
|
| 3 |
+
"content": "<s>",
|
| 4 |
+
"lstrip": false,
|
| 5 |
+
"normalized": false,
|
| 6 |
+
"rstrip": false,
|
| 7 |
+
"single_word": false
|
| 8 |
+
},
|
| 9 |
+
"eos_token": {
|
| 10 |
+
"content": "</s>",
|
| 11 |
+
"lstrip": false,
|
| 12 |
+
"normalized": false,
|
| 13 |
+
"rstrip": false,
|
| 14 |
+
"single_word": false
|
| 15 |
+
},
|
| 16 |
+
"unk_token": {
|
| 17 |
+
"content": "<unk>",
|
| 18 |
+
"lstrip": false,
|
| 19 |
+
"normalized": false,
|
| 20 |
+
"rstrip": false,
|
| 21 |
+
"single_word": false
|
| 22 |
+
}
|
| 23 |
+
}
|
tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_bos_token": true,
|
| 3 |
+
"add_eos_token": false,
|
| 4 |
+
"add_prefix_space": null,
|
| 5 |
+
"added_tokens_decoder": {
|
| 6 |
+
"0": {
|
| 7 |
+
"content": "<unk>",
|
| 8 |
+
"lstrip": false,
|
| 9 |
+
"normalized": false,
|
| 10 |
+
"rstrip": false,
|
| 11 |
+
"single_word": false,
|
| 12 |
+
"special": true
|
| 13 |
+
},
|
| 14 |
+
"1": {
|
| 15 |
+
"content": "<s>",
|
| 16 |
+
"lstrip": false,
|
| 17 |
+
"normalized": false,
|
| 18 |
+
"rstrip": false,
|
| 19 |
+
"single_word": false,
|
| 20 |
+
"special": true
|
| 21 |
+
},
|
| 22 |
+
"2": {
|
| 23 |
+
"content": "</s>",
|
| 24 |
+
"lstrip": false,
|
| 25 |
+
"normalized": false,
|
| 26 |
+
"rstrip": false,
|
| 27 |
+
"single_word": false,
|
| 28 |
+
"special": true
|
| 29 |
+
}
|
| 30 |
+
},
|
| 31 |
+
"additional_special_tokens": [],
|
| 32 |
+
"bos_token": "<s>",
|
| 33 |
+
"clean_up_tokenization_spaces": false,
|
| 34 |
+
"eos_token": "</s>",
|
| 35 |
+
"extra_special_tokens": {},
|
| 36 |
+
"legacy": true,
|
| 37 |
+
"model_max_length": 1000000000000000019884624838656,
|
| 38 |
+
"pad_token": null,
|
| 39 |
+
"sp_model_kwargs": {},
|
| 40 |
+
"spaces_between_special_tokens": false,
|
| 41 |
+
"tokenizer_class": "LlamaTokenizer",
|
| 42 |
+
"unk_token": "<unk>",
|
| 43 |
+
"use_default_system_prompt": false
|
| 44 |
+
}
|
torchtitan/components/__pycache__/checkpoint.cpython-312.pyc
ADDED
|
Binary file (33.1 kB). View file
|
|
|
torchtitan/components/__pycache__/dataloader.cpython-312.pyc
ADDED
|
Binary file (3.79 kB). View file
|
|
|
torchtitan/components/__pycache__/float8.cpython-312.pyc
ADDED
|
Binary file (6.2 kB). View file
|
|
|
torchtitan/components/__pycache__/lr_scheduler.cpython-312.pyc
ADDED
|
Binary file (7.71 kB). View file
|
|
|
torchtitan/components/__pycache__/metrics.cpython-312.pyc
ADDED
|
Binary file (19.6 kB). View file
|
|
|
torchtitan/components/__pycache__/optimizer.cpython-312.pyc
ADDED
|
Binary file (14.5 kB). View file
|
|
|
torchtitan/components/__pycache__/tokenizer.cpython-312.pyc
ADDED
|
Binary file (1.09 kB). View file
|
|
|
torchtitan/distributed/utils.py
ADDED
|
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import contextlib
|
| 8 |
+
import math
|
| 9 |
+
import os
|
| 10 |
+
from collections.abc import Generator, Iterable
|
| 11 |
+
from datetime import timedelta
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torch.distributed._functional_collectives as funcol
|
| 15 |
+
import torch.distributed.distributed_c10d as c10d
|
| 16 |
+
from torch import distributed as dist
|
| 17 |
+
from torch.distributed.device_mesh import DeviceMesh
|
| 18 |
+
from torch.distributed.tensor import DTensor
|
| 19 |
+
|
| 20 |
+
from torchtitan.components.ft import ft_clip_grad_norm_util, ft_dist_reduce
|
| 21 |
+
from torchtitan.tools.logging import logger
|
| 22 |
+
from torchtitan.tools.utils import device_module, device_type
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _dist_reduce(x: torch.Tensor, reduceOp: str, mesh: DeviceMesh) -> float:
    """Reduce a single-element tensor across `mesh` and return it as a float.

    Args:
        x: single-element tensor to reduce (may be a DTensor).
        reduceOp: name of a ``c10d.ReduceOp`` (e.g. ``"MAX"``, ``"AVG"``).
        mesh: device mesh whose group the reduction runs over.
    """
    # Strip the fault-tolerance replicate dimension first (no-op when FT is off).
    x, reduceOp, mesh = ft_dist_reduce(x, reduceOp, mesh)

    # Functional collectives cannot take DTensor inputs, so materialize the
    # full local tensor before reducing.
    if isinstance(x, DTensor):
        x = x.full_tensor()

    # `.item()` below requires exactly one element.
    assert x.numel() == 1
    reduced = funcol.all_reduce(x, reduceOp=reduceOp, group=mesh)
    return reduced.item()
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def dist_max(x: torch.Tensor, mesh: DeviceMesh) -> float:
    """All-reduce the scalar tensor `x` over `mesh` with MAX; return a float."""
    op_name = c10d.ReduceOp.MAX.name
    return _dist_reduce(x, reduceOp=op_name, mesh=mesh)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def dist_mean(x: torch.Tensor, mesh: DeviceMesh) -> float:
    """All-reduce the scalar tensor `x` over `mesh` with AVG; return a float."""
    op_name = c10d.ReduceOp.AVG.name
    return _dist_reduce(x, reduceOp=op_name, mesh=mesh)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def set_determinism(
    world_mesh: DeviceMesh | None,
    device: torch.device,
    seed: int | None = None,
    deterministic: bool = False,
) -> None:
    """
    Set the same DTensor manual seed for all ranks within the same DTensor SPMD group, but different
    seeds across PP groups (if applicable).

    Currently, does not set seeds for the CUDA RNG since TorchTitan always uses DTensor for SPMD parallelisms,
    and DTensor manages its own RNG tracker, but we could extend to support both if needed.

    Set Determinism flags for increased reproducibility with loss of performance.

    Args:
        world_mesh: full device mesh for the job, or None/empty for a
            single-process run.
        device: local device; used only to stage the seed tensor for the
            rank-0 broadcast.
        seed: base seed. If None, rank 0's current RNG state supplies it.
        deterministic: if True, force deterministic kernels (slower).
    """
    if deterministic:
        logger.info("Deterministic algorithm enabled (expect perf degradation).")
        torch.use_deterministic_algorithms(True)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        # env var for deterministic CuBLAS
        # https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

    if not world_mesh:
        if seed is not None:
            torch.manual_seed(seed)
            os.environ["PYTHONHASHSEED"] = str(seed % 2**32)
            logger.debug(f"Single-process job using seed: {seed}")
        return

    # to ensure we can control which ranks have same or different seeds, all ranks agree on a starting seed.
    # if user provides one, we use this. Otherwise rank 0 rolls the dice and everyone else uses that.
    if seed is None:
        # Extract the seed for torch's main generator on rank 0 and standardizes on using that to build
        # seeds for unique SPMD groups
        seed_tensor = torch.get_rng_state()[:8].to(device)
        torch.distributed.broadcast(seed_tensor, src=0)
        # Reinterpret the 8 broadcast bytes as one uint64 seed.
        seed = seed_tensor.to("cpu").view(torch.uint64).item()

    # For PP + SPMD cases, we want to separate the world into the SPMD mesh and the PP mesh,
    # and choose a unique seed for each rank on the PP mesh.
    if c10d.get_world_size() > 1 and "pp" in world_mesh.mesh_dim_names:
        pp_mesh = world_mesh["pp"]
        # Offset the agreed-upon seed by the PP rank so each stage differs.
        seed += pp_mesh.get_local_rank()
        seed %= 2**64

        logger.debug(
            f"PP rank {pp_mesh.get_local_rank()}, Global rank {c10d.get_rank()} using seed: {seed}"
        )
        spmd_mesh_dims = list(
            filter(lambda name: name != "pp", world_mesh.mesh_dim_names)
        )
        spmd_mesh = world_mesh[spmd_mesh_dims] if len(spmd_mesh_dims) else None
    else:
        spmd_mesh = world_mesh
        logger.debug(f"Global Rank {c10d.get_rank()} using seed: {seed}")

    # The native RNGs and python RNG may not be important, except for the 1-D PP case, but we seed them for consistency.
    torch.manual_seed(seed)
    # PYTHONHASHSEED can be a decimal number in the range [0, 2**32 - 1]
    os.environ["PYTHONHASHSEED"] = str(seed % 2**32)

    # As long as we are not in the 1-D (PP-only) case, we will have a seed to use for all ranks of the SPMD mesh.
    # IF PP is also used, this seed is unique per PP rank.
    if spmd_mesh and spmd_mesh.get_coordinate() is not None:
        torch.distributed.tensor._random.manual_seed(seed, spmd_mesh)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def create_context_parallel_ctx(
    cp_mesh: DeviceMesh,
    cp_buffers: list[torch.Tensor],
    cp_seq_dims: list[int],
    cp_no_restore_buffers: set[torch.Tensor],
    cp_rotate_method: str,
):
    """
    Build a context-parallel context manager that shards `cp_buffers` along
    their sequence dims over `cp_mesh`.

    Args:
        cp_mesh: device mesh for the context-parallel dimension.
        cp_buffers: tensors to shard (e.g. inputs, position buffers).
        cp_seq_dims: sequence dim of each buffer in `cp_buffers`.
        cp_no_restore_buffers: buffers that need not be restored on exit.
        cp_rotate_method: KV rotation strategy passed to `set_rotate_method`.

    Raises:
        ImportError: if this PyTorch build lacks the experimental Context
            Parallel API.
    """
    try:
        from torch.distributed.tensor.experimental import context_parallel
        from torch.distributed.tensor.experimental._attention import set_rotate_method
    except ImportError as e:
        # BUG FIX: previously this only printed a message and fell through,
        # which then crashed with a confusing NameError on `set_rotate_method`.
        # Raise a clear, actionable error instead.
        raise ImportError(
            f"PyTorch version {torch.__version__} does not include the experimental "
            "Context Parallel API. Please update to a newer version."
        ) from e

    set_rotate_method(cp_rotate_method)
    return context_parallel(
        cp_mesh,
        buffers=cp_buffers,
        buffer_seq_dims=cp_seq_dims,
        no_restore_buffers=cp_no_restore_buffers,
    )
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def get_train_context(
|
| 139 |
+
enable_loss_parallel: bool, enable_compiled_autograd: bool
|
| 140 |
+
) -> Generator[None, None, None]:
|
| 141 |
+
@contextlib.contextmanager
|
| 142 |
+
def context(cp_context: Generator[None, None, None] | None = None):
|
| 143 |
+
with contextlib.ExitStack() as stack:
|
| 144 |
+
if enable_loss_parallel:
|
| 145 |
+
stack.enter_context(torch.distributed.tensor.parallel.loss_parallel())
|
| 146 |
+
|
| 147 |
+
if enable_compiled_autograd:
|
| 148 |
+
stack.enter_context(
|
| 149 |
+
torch._dynamo.utils.maybe_enable_compiled_autograd(True)
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
if cp_context is not None:
|
| 153 |
+
from torch.nn.attention import sdpa_kernel, SDPBackend
|
| 154 |
+
|
| 155 |
+
stack.enter_context(
|
| 156 |
+
sdpa_kernel(
|
| 157 |
+
[
|
| 158 |
+
SDPBackend.FLASH_ATTENTION,
|
| 159 |
+
SDPBackend.EFFICIENT_ATTENTION,
|
| 160 |
+
SDPBackend.CUDNN_ATTENTION,
|
| 161 |
+
]
|
| 162 |
+
)
|
| 163 |
+
)
|
| 164 |
+
stack.enter_context(cp_context)
|
| 165 |
+
|
| 166 |
+
yield
|
| 167 |
+
|
| 168 |
+
return context
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def init_distributed(job_config):
    """
    Initialize torch.distributed for this job.

    Sets NCCL flight-recorder / error-handling env vars from `job_config.comm`,
    then creates the default process group with a backend derived from the
    local device type (paired with a gloo CPU backend when CPU offload is on).

    Args:
        job_config: job configuration; reads `comm.trace_buf_size`,
            `comm.init_timeout_seconds`, `job.dump_folder`, and
            `training.enable_cpu_offload`.
    """

    def _warn_overwrite_env(env, val):
        # Force-set an env var, warning if the user had already set it.
        if env in os.environ:
            logger.warning(
                f"ENV[{env}] = {os.environ[env]} will be overridden to {val} based on job config"
            )
        os.environ[env] = val

    def _get_distributed_backend(job_config):
        # Default to nccl, but prefer the device-type mapping when available.
        backend = "nccl"
        if device_type in torch.distributed.Backend.default_device_backend_map:
            backend = torch.distributed.Backend.default_device_backend_map.get(
                device_type
            )
        # CPU offload needs a gloo PG alongside the accelerator backend.
        if job_config.training.enable_cpu_offload:
            backend = f"{device_type}:{backend},cpu:gloo"
        return backend

    TRACE_BUFFER_SIZE = "TORCH_NCCL_TRACE_BUFFER_SIZE"
    TRACE_FILE = "TORCH_NCCL_DEBUG_INFO_TEMP_FILE"
    DUMP_ON_TIMEOUT = "TORCH_NCCL_DUMP_ON_TIMEOUT"
    ASYNC_ERROR_HANDLING = "TORCH_NCCL_ASYNC_ERROR_HANDLING"
    SKIP_CLEANUP = "3"

    # FlightRecorder is incompatible with =1 mode where watchdog aborts work, must use =3 (skipcleanup)
    # to get flight recorder dumps. See https://github.com/pytorch/pytorch/issues/121055
    # This could be done only when flight recorder is enabled, but its nice to be consistent to avoid subtle
    # behavior differences
    _warn_overwrite_env(ASYNC_ERROR_HANDLING, SKIP_CLEANUP)

    # enable torch nccl flight recorder in the mode that would dump files if timeout is detected
    _warn_overwrite_env(TRACE_BUFFER_SIZE, str(job_config.comm.trace_buf_size))
    if job_config.comm.trace_buf_size > 0:
        # dump on timeout by default if trace buffer is enabled
        _warn_overwrite_env(DUMP_ON_TIMEOUT, "1")
        dump_dir = f"{job_config.job.dump_folder}/comm_trace"
        os.makedirs(dump_dir, exist_ok=True)
        _warn_overwrite_env(TRACE_FILE, f"{dump_dir}/rank_")

    # to mitigate the memory issue that collectives using
    # async_op=True hold memory longer than they should
    # such as those in tensor parallelism
    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

    torch.distributed.init_process_group(
        backend=_get_distributed_backend(job_config),
        timeout=timedelta(seconds=job_config.comm.init_timeout_seconds),
    )
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def set_pg_timeouts(timeout, world_mesh):
    """
    Sets the timeout for all PGs in the provided mesh, and the default (world) group.

    Note: synchronizes via a barrier, before changing the timeouts. This is important, because
    otherwise you may face a race where the slow rank has not reached the timeout reduction point
    yet due to slow operations permitted under the old timeout value, but other faster ranks may
    start issuing collectives under the new shorter timeout and then immediately timeout.
    """
    logger.info(
        f"Synchronizing and adjusting timeout for all ProcessGroups to {timeout}"
    )
    # Make every rank reach this point before shrinking the timeout; a slow
    # rank still finishing init under the old timeout must not race faster
    # ranks already issuing collectives under the new, shorter one.
    torch.distributed.barrier(device_ids=[device_module.current_device()])
    device_module.synchronize()

    # One group per mesh dim, plus None for the default (world) PG which is
    # not part of the mesh.
    pgs = [world_mesh.get_group(dim) for dim in range(world_mesh.ndim)]
    pgs.append(None)

    for pg in pgs:
        torch.distributed.distributed_c10d._set_pg_timeout(timeout, pg)
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
@torch.no_grad()
|
| 249 |
+
def clip_grad_norm_(
|
| 250 |
+
parameters: torch.Tensor | Iterable[torch.Tensor],
|
| 251 |
+
max_norm: float,
|
| 252 |
+
norm_type: float = 2.0,
|
| 253 |
+
error_if_nonfinite: bool = False,
|
| 254 |
+
foreach: bool | None = None,
|
| 255 |
+
pp_mesh: DeviceMesh | None = None,
|
| 256 |
+
) -> torch.Tensor:
|
| 257 |
+
"""
|
| 258 |
+
Clip the gradient norm of an iterable of parameters.
|
| 259 |
+
|
| 260 |
+
Gradient norm clipping requires computing the gradient norm over the entire model.
|
| 261 |
+
`torch.nn.utils.clip_grad_norm_` only computes gradient norm along DP/FSDP/TP dimensions.
|
| 262 |
+
We need to manually reduce the gradient norm across PP stages.
|
| 263 |
+
See https://github.com/pytorch/torchtitan/issues/596 for details.
|
| 264 |
+
|
| 265 |
+
Args:
|
| 266 |
+
parameters: an iterable of Tensors or a single Tensor that will have gradients normalized
|
| 267 |
+
max_norm (float): max norm of the gradients
|
| 268 |
+
norm_type (float): type of the used p-norm. Can be ``'inf'`` for
|
| 269 |
+
infinity norm.
|
| 270 |
+
error_if_nonfinite (bool): if True, an error is thrown if the total
|
| 271 |
+
norm of the gradients from :attr:`parameters` is ``nan``,
|
| 272 |
+
``inf``, or ``-inf``. Default: False (will switch to True in the future)
|
| 273 |
+
foreach (bool): use the faster foreach-based implementation.
|
| 274 |
+
If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
|
| 275 |
+
fall back to the slow implementation for other device types.
|
| 276 |
+
Default: ``None``
|
| 277 |
+
pp_mesh: pipeline parallel device mesh. If not None, will reduce gradient norm across PP stages.
|
| 278 |
+
|
| 279 |
+
Returns:
|
| 280 |
+
Total norm of the parameter gradients (viewed as a single vector).
|
| 281 |
+
|
| 282 |
+
"""
|
| 283 |
+
grads = [p.grad for p in parameters if p.grad is not None]
|
| 284 |
+
total_norm = torch.nn.utils.get_total_norm(
|
| 285 |
+
grads, norm_type, error_if_nonfinite, foreach
|
| 286 |
+
)
|
| 287 |
+
|
| 288 |
+
# If total_norm is a DTensor, the placements must be `torch.distributed._tensor.ops.math_ops._NormPartial`.
|
| 289 |
+
# We can simply reduce the DTensor to get the total norm in this tensor's process group
|
| 290 |
+
# and then convert it to a local tensor.
|
| 291 |
+
# NOTE: It has two purposes:
|
| 292 |
+
# 1. to make sure the total norm is computed correctly when PP is used (see below)
|
| 293 |
+
# 2. to return a reduced total_norm tensor whose .item() would return the correct value
|
| 294 |
+
if isinstance(total_norm, DTensor):
|
| 295 |
+
# Will reach here if any non-PP parallelism is used.
|
| 296 |
+
# If only using PP, total_norm will be a local tensor.
|
| 297 |
+
|
| 298 |
+
# Remove FT replicate dimension if it exists.
|
| 299 |
+
total_norm = ft_clip_grad_norm_util(total_norm)
|
| 300 |
+
total_norm = total_norm.full_tensor()
|
| 301 |
+
|
| 302 |
+
if pp_mesh is not None:
|
| 303 |
+
if math.isinf(norm_type):
|
| 304 |
+
dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=pp_mesh.get_group())
|
| 305 |
+
else:
|
| 306 |
+
total_norm **= norm_type
|
| 307 |
+
dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=pp_mesh.get_group())
|
| 308 |
+
total_norm **= 1.0 / norm_type
|
| 309 |
+
|
| 310 |
+
torch.nn.utils.clip_grads_with_norm_(parameters, max_norm, total_norm, foreach)
|
| 311 |
+
return total_norm
|
torchtitan/experiments/deepseek_v3/README.md
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Running DeepSeek in Titan (experimental)
|
| 2 |
+
|
| 3 |
+
This folder contains a DeepSeek model supporting v2 and v3 as well as kernels
|
| 4 |
+
and scripts needed to run it.
|
| 5 |
+
|
| 6 |
+
## Inference
|
| 7 |
+
|
| 8 |
+
### Prerequisites:
|
| 9 |
+
|
| 10 |
+
You will need to download a DeepSeek model's weights if you want to run a
|
| 11 |
+
pre-trained checkpoint. We provide a script to download the weights from
|
| 12 |
+
HuggingFace Model Hub:
|
| 13 |
+
```bash
|
| 14 |
+
python download.py [vX]
|
| 15 |
+
```
|
| 16 |
+
where `vX` can be v2 or v3, both are supported. You may be required to create a
|
| 17 |
+
HuggingFace account and log in first.
|
| 18 |
+
|
| 19 |
+
### Running inference:
|
| 20 |
+
|
| 21 |
+
The inference script is in `generate.py`. You can run it with the following
|
| 22 |
+
command:
|
| 23 |
+
```bash
|
| 24 |
+
torchrun --standalone --nproc-per-node 4 generate.py
|
| 25 |
+
```
|
| 26 |
+
This will run inference on the `DeepSeek-V2-Lite-Chat` model using 4 GPUs by
|
| 27 |
+
default.
|
| 28 |
+
|
| 29 |
+
Alternatively, you can run inference by using `bash inference.sh`, optionally
|
| 30 |
+
followed by your prompt.
|
| 31 |
+
|
| 32 |
+
## Training
|
| 33 |
+
|
| 34 |
+
The training script is in `train.py`. You can run it with the following command:
|
| 35 |
+
```bash
|
| 36 |
+
torchrun --standalone --nproc-per-node 8 train.py
|
| 37 |
+
```
|
| 38 |
+
|
| 39 |
+
This will run training on the `DeepSeek-V2-Lite-Chat` model using 8 GPUs by
|
| 40 |
+
default, with pipeline parallel, expert parallel, and data parallel enabled.
|
torchtitan/experiments/deepseek_v3/indices.py
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import triton
|
| 9 |
+
import triton.language as tl
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
__all__ = ["generate_permute_indices"]
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@triton.jit
def fill_indices_kernel(
    tokens_per_expert_group_ptr,  # *Pointer* to first input vector.
    start_index_values_ptr,  # *Pointer* to second input vector.
    write_offsets_ptr,  # *Pointer* to third input vector.
    output_ptr,  # *Pointer* to output vector.
    experts_per_rank,  # Number of experts per rank.
    num_ranks,  # Number of expert ranks.
):
    """
    For each local expert, write the token indices it should gather —
    concatenated across all ranks — into `output_ptr`, starting at that
    expert's entry in `write_offsets_ptr`.
    """
    # There are multiple 'programs' processing different data. We identify which program
    # we are here:
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    # The total number of programs in the launch grid.
    num_programs = tl.num_programs(axis=0)
    # We map the programs (blocks) to the experts: each program strides over
    # the experts with step `num_programs`.
    for expert_id in tl.range(pid, experts_per_rank, step=num_programs):
        # Read this expert's write offset.
        write_offset = tl.load(write_offsets_ptr + expert_id)
        # Loop over the ranks.
        for r in tl.range(num_ranks):
            # Slot in the tokens_per_expert_group array.
            i = r * experts_per_rank + expert_id
            start_index = tl.load(start_index_values_ptr + i)
            length = tl.load(tokens_per_expert_group_ptr + i)
            # Write the indices: a contiguous run [start_index, start_index + length).
            for l in tl.range(length):
                val = start_index + l
                tl.store(output_ptr + write_offset + l, val)
            write_offset += length
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def fill_indices(
    tokens_per_expert_group: torch.Tensor,
    start_index_values: torch.Tensor,
    write_offsets: torch.Tensor,
    experts_per_rank: int,
    num_ranks: int,
    max_len: int,
):
    """
    GPU path: build the permuted index vector via the Triton kernel.

    Slots beyond the total token count keep the -1 fill value.
    """
    # Preallocate the output, filled with the -1 sentinel.
    out = torch.full(
        (max_len,), -1, dtype=torch.int32, device=tokens_per_expert_group.device
    )
    # 1D launch grid with a single block for now (TODO: bump this value).
    launch_grid = (1,)
    # Each torch.Tensor argument decays to a pointer to its first element.
    fill_indices_kernel[launch_grid](
        tokens_per_expert_group,
        start_index_values,
        write_offsets,
        out,
        experts_per_rank,
        num_ranks,
    )
    return out
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def fill_indices_cpu(
    tokens_per_expert_group: torch.Tensor,
    start_index_values: torch.Tensor,
    write_offsets: torch.Tensor,
    experts_per_rank: int,
    num_ranks: int,
    max_len: int,
):
    """
    CPU reference path: build the permuted index vector without Triton.

    For each local expert, concatenates the index runs contributed by every
    rank into a new int32 tensor of length `max_len`; unused slots stay -1.
    """
    permuted = torch.full((max_len,), -1, dtype=torch.int32)
    for expert in range(experts_per_rank):
        # Where this expert's block begins in the output.
        cursor = write_offsets[expert]
        for rank in range(num_ranks):
            # Slot for (rank, expert) in the flat per-rank layout.
            slot = rank * experts_per_rank + expert
            begin = start_index_values[slot]
            count = tokens_per_expert_group[slot]
            # Emit the contiguous run [begin, begin + count).
            permuted[cursor : cursor + count] = torch.arange(begin, begin + count)
            cursor += count
    return permuted
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def generate_permute_indices(
    tokens_per_expert_group: torch.Tensor,
    experts_per_rank: int,
    num_ranks: int,
    max_len: int,
    alignment: int,
    use_cpu: bool = False,
):
    """
    Prepare permutation indices and the number of tokens for each expert.

    The permutation indices are the indices of the tokens for each expert. The
    number of tokens for each expert is the sum of the number of tokens for
    such experts from all ranks. This number is aligned to the provided
    alignment requirement (usually comes from group gemm).

    Args:
        tokens_per_expert_group: number of tokens for each expert from all ranks.
        experts_per_rank: number of experts per rank.
        num_ranks: number of ranks.
        max_len: maximum length of the output index vector. If greater than the
            total number of tokens, the remaining indices are set to -1.
        alignment: alignment for each returned element in `m_sizes`.
        use_cpu: whether to use cpu or gpu.

    Returns:
        permuted_indices: permutation indices.
        m_sizes: number of tokens for each expert.

    `tokens_per_expert_group` is of shape (num_ranks * experts_per_rank,), e.g.:

        From: |       rank 0      |       rank 1      |
        To:   | E0 | E1 | E2 | E3 | E0 | E1 | E2 | E3 |
              |  4 |  2 |  1 |  3 |  1 |  2 |  3 |  4 |
    """
    # Prefix sum to get the start index value of each expert
    start_index_values = (
        torch.cumsum(tokens_per_expert_group, 0) - tokens_per_expert_group
    )
    # Chunk sizes for each expert: sum over ranks of that expert's token counts.
    chunk_size_per_expert = tokens_per_expert_group.view(num_ranks, -1).sum(0)
    # Align the chunk sizes to the given alignment (round up).
    m_sizes = ((chunk_size_per_expert + alignment - 1) // alignment * alignment).to(
        torch.int32
    )
    # Perform another prefix sum to get the write offset of each expert in `permuted_indices`
    write_offsets = torch.cumsum(m_sizes, 0) - m_sizes
    # Select the method to fill the permuted indices
    fill_fn = fill_indices_cpu if use_cpu else fill_indices
    # Fill the permuted indices
    permuted_indices = fill_fn(
        tokens_per_expert_group,
        start_index_values,
        write_offsets,
        experts_per_rank,
        num_ranks,
        max_len,
    )
    return permuted_indices, m_sizes
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
# Below is for testing only
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def test():
    """Smoke-test GPU vs CPU index generation and the alignment guarantee."""
    device = torch.device("cuda", 0)
    experts_per_rank = 4
    num_ranks = 4
    tokens_per_expert_group = torch.full(
        (num_ranks * experts_per_rank,), 4, dtype=torch.int32, device=device
    )
    max_len = 128
    alignment = 32
    # Use the GPU kernel
    permuted_indices_gpu, m_sizes = generate_permute_indices(
        tokens_per_expert_group, experts_per_rank, num_ranks, max_len, alignment
    )
    # Use the CPU method
    permuted_indices_cpu, _ = generate_permute_indices(
        tokens_per_expert_group,
        experts_per_rank,
        num_ranks,
        max_len,
        alignment,
        use_cpu=True,
    )
    # Check that the results are the same
    assert torch.equal(permuted_indices_gpu.cpu(), permuted_indices_cpu)
    # BUG FIX: `torch.equal` requires matching dtypes; `m_sizes` is int32 while
    # `torch.zeros(...)` defaults to float32, so the original comparison was
    # always False. Give the reference zeros the same dtype as `m_sizes`.
    assert torch.equal(
        torch.remainder(m_sizes, alignment),
        torch.zeros(experts_per_rank, dtype=m_sizes.dtype, device=device),
    )
    # Print the results
    print(permuted_indices_gpu)
    print(m_sizes)
    print("Success")


if __name__ == "__main__":
    test()
|
torchtitan/experiments/deepseek_v3/symm_mem_recipes/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from .triton_on_device_all_to_all_v import OnDeviceAllToAllV
|
| 8 |
+
|
| 9 |
+
__all__ = [
|
| 10 |
+
"OnDeviceAllToAllV",
|
| 11 |
+
]
|
torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_on_device_all_to_all_v.py
ADDED
|
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.distributed as dist
|
| 9 |
+
import torch.distributed._symmetric_memory as symm_mem
|
| 10 |
+
import triton
|
| 11 |
+
import triton.language as tl
|
| 12 |
+
|
| 13 |
+
from .triton_barrier import blockwise_barrier
|
| 14 |
+
from .triton_utils import sync_threads
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@triton.jit
def _exchange_row_offsets(
    split_sizes_ptrs,
    rank: tl.constexpr,
    world_size: tl.constexpr,
    BLOCKS_PER_REMOTE_RANK: tl.constexpr,
):
    """
    For the remote rank this program serves, compute:
    the row offset into that rank's input buffer, the row offset into our
    output buffer, and the number of rows to copy.
    """
    remote_rank = tl.program_id(0) // BLOCKS_PER_REMOTE_RANK

    # split_sizes_ptr for all ranks
    # All these vector stacks into split_sizes_matrix
    split_sizes_ptrs = split_sizes_ptrs.to(tl.pointer_type(tl.uint64))

    # split_sizes_matrix[remote_rank, :]
    input_split_sizes_ptr = tl.load(split_sizes_ptrs + remote_rank).to(
        tl.pointer_type(tl.int64)
    )

    offsets_ = tl.arange(0, world_size)
    # Masked load: only entries up to and including `rank` contribute to the
    # prefix sum below.
    input_split_sizes = tl.load(
        input_split_sizes_ptr + offsets_, mask=offsets_ <= rank, other=0
    )

    num_rows = tl.load(input_split_sizes_ptr + rank)
    # Exclusive prefix sum: rows destined for ranks before us in the remote
    # rank's input buffer.
    input_row_offset = tl.sum(input_split_sizes) - num_rows

    # split_sizes_matrix[:, rank]
    output_split_sizes_ptrs = (
        tl.load(split_sizes_ptrs + offsets_).to(tl.pointer_type(tl.int64)) + rank
    )
    output_split_sizes = tl.load(
        output_split_sizes_ptrs, mask=offsets_ <= remote_rank, other=0
    )
    # Exclusive prefix sum over source ranks before `remote_rank` gives our
    # output write position.
    output_row_offset = tl.sum(output_split_sizes) - num_rows

    return input_row_offset, output_row_offset, num_rows
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
@triton.jit
def on_device_all_to_all_v_kernel(
    output_ptr,
    output_splits_ptr,
    input_ptrs,
    input_splits_ptr,
    signal_pad_ptrs,
    dim: tl.constexpr,  # Separate dim for easier vectorization
    rank: tl.constexpr,
    world_size: tl.constexpr,
    BLOCKS_PER_REMOTE_RANK: tl.constexpr,
    UNROLL_FACTOR: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Pull this rank's rows from every peer's symmetric-memory input buffer into
    `output_ptr`, recording the received per-peer row counts in
    `output_splits_ptr`. Barriers at entry and exit synchronize all peers.
    """
    blockwise_barrier(signal_pad_ptrs, None, rank, world_size, sem="relaxed")
    sync_threads()

    # Each group of BLOCKS_PER_REMOTE_RANK programs services one remote rank.
    remote_rank = tl.program_id(0) // BLOCKS_PER_REMOTE_RANK
    block_offset = tl.program_id(0) % BLOCKS_PER_REMOTE_RANK

    input_row_offset, output_row_offset, num_rows = _exchange_row_offsets(
        input_splits_ptr, rank, world_size, BLOCKS_PER_REMOTE_RANK
    )

    output_splits_ptr = output_splits_ptr.to(tl.pointer_type(tl.uint64))
    if block_offset == 0:
        # Update output_splits (only one block per remote rank writes it).
        tl.store(output_splits_ptr + remote_rank, num_rows)

    # NOTE: elements are read as bfloat16 — assumes the exchanged tensors are
    # bf16 (confirm against callers).
    input_ptr = (
        tl.load(input_ptrs.to(tl.pointer_type(tl.uint64)) + remote_rank).to(
            tl.pointer_type(tl.bfloat16)
        )
        + input_row_offset * dim
    )
    output_ptr = output_ptr + output_row_offset * dim

    # Partition this remote rank's num_rows * dim elements evenly across its
    # BLOCKS_PER_REMOTE_RANK programs, in units of BLOCK_SIZE * UNROLL_FACTOR.
    outer_loop_step = BLOCK_SIZE * UNROLL_FACTOR
    outer_loop_iters_per_rank = tl.cdiv(
        tl.cdiv(num_rows * dim, outer_loop_step), BLOCKS_PER_REMOTE_RANK
    )
    numel_per_rank = outer_loop_step * outer_loop_iters_per_rank
    offset = numel_per_rank * block_offset
    end = tl.minimum(numel_per_rank * (block_offset + 1), num_rows * dim)

    # Fast path: whole multiples of outer_loop_step copied without masks.
    unroll_region_size = (end - offset) // outer_loop_step * outer_loop_step
    for i in tl.range(offset, offset + unroll_region_size, outer_loop_step):
        datas = []
        for j in tl.range(
            i,
            i + outer_loop_step,
            BLOCK_SIZE,
            loop_unroll_factor=UNROLL_FACTOR,
        ):
            offsets = j + tl.arange(0, BLOCK_SIZE)
            data = tl.load(input_ptr + offsets)
            tl.store(output_ptr + offsets, data)

    # Masked tail: remaining elements that don't fill an unrolled step.
    offset += unroll_region_size
    while offset < end:
        offsets = offset + tl.arange(0, BLOCK_SIZE)
        mask = offsets < num_rows * dim
        data = tl.load(input_ptr + offsets, mask=mask)
        tl.store(output_ptr + offsets, data, mask=mask)
        offset += BLOCK_SIZE

    sync_threads()
    blockwise_barrier(signal_pad_ptrs, None, rank, world_size, sem="relaxed")
    return
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def _on_device_all_to_all_v(
    output: torch.Tensor,
    output_splits: torch.Tensor,
    input: torch.Tensor,
    input_splits: torch.Tensor,
    group: dist.ProcessGroup = dist.group.WORLD,
    BLOCKS_PER_REMOTE_RANK=8,
    UNROLL_FACTOR: int = 8,
    BLOCK_SIZE: int = 16384,
):
    """Launch the on-device all-to-all-v Triton kernel.

    Rows of `input`, grouped per destination rank by `input_splits`, are
    pulled from each peer's symmetric-memory buffer into `output`; the
    kernel also writes the per-rank received row counts into
    `output_splits`.

    Args:
        output: 2-D destination tensor; also returned for convenience.
        output_splits: tensor the kernel fills with per-rank received counts.
        input: 2-D source tensor; must be a symmetric-memory allocation
            (it is rendezvous'd below).
        input_splits: per-rank row counts to send; must also be a
            symmetric-memory allocation.
        group: process group to scope the collective.
        BLOCKS_PER_REMOTE_RANK: CTAs cooperating on each remote rank's copy.
        UNROLL_FACTOR: inner-loop unroll factor forwarded to the kernel.
        BLOCK_SIZE: elements copied per block iteration.
    """
    assert output.dim() == 2, f"{output.shape}"
    assert input.dim() == 2, f"{input.shape}"
    assert output.shape[1] == input.shape[1]

    dim = output.shape[1]
    # Rendezvous yields handles exposing peer buffer / signal-pad device
    # pointers that the kernel dereferences directly.
    input_hdl = symm_mem.rendezvous(input, group=group)
    input_splits_hdl = symm_mem.rendezvous(input_splits, group=group)

    # 1-D grid: BLOCKS_PER_REMOTE_RANK blocks assigned to each peer rank.
    num_blocks = input_hdl.world_size * BLOCKS_PER_REMOTE_RANK
    kernel = on_device_all_to_all_v_kernel[(num_blocks, 1, 1)](
        output,
        output_splits,
        input_hdl.buffer_ptrs_dev,
        input_splits_hdl.buffer_ptrs_dev,
        input_hdl.signal_pad_ptrs_dev,
        dim=dim,
        rank=input_hdl.rank,
        world_size=input_hdl.world_size,
        BLOCKS_PER_REMOTE_RANK=BLOCKS_PER_REMOTE_RANK,
        UNROLL_FACTOR=UNROLL_FACTOR,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=16,
    )
    # log_triton_kernel(kernel)
    return output
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
class OnDeviceAllToAllV(torch.autograd.Function):
    """Autograd-enabled all-to-all-v shuffle built on symmetric memory.

    Class-level buffers are allocated lazily on first use and then reused,
    so the first call's split shape/dtype and `max_output_len` fix the
    buffer sizes for the lifetime of the process.
    """

    # A symmetric memory buffer holding the grad_output during backward
    grad_output_buf = None
    # A symmetric memory buffer for exchanging split sizes during both
    # forward and backward
    splits_buf = None
    # Maximum output length (must be set before use of OnDeviceAllToAllV)
    max_output_len = None

    @staticmethod
    def forward(
        ctx,
        input: torch.Tensor,
        input_splits: torch.Tensor,
        group: dist.ProcessGroup = dist.group.WORLD,
    ):
        """
        Args:
            input: input tensor with data for all ranks concatenated.
            input_splits: input splits of shape (group.world_size,)
            group: process group to scope the collective.

        Returns:
            (output, output_splits): `output` is allocated with
            `max_output_len` rows; `output_splits` holds the per-rank
            received row counts written by the kernel.
        """
        # Initialize input splits buffer (one time only)
        if OnDeviceAllToAllV.splits_buf is None:
            OnDeviceAllToAllV.splits_buf = symm_mem.empty(
                *input_splits.shape,
                dtype=input_splits.dtype,
                device=input_splits.device,
            )

        if OnDeviceAllToAllV.max_output_len is None:
            raise RuntimeError(
                "Please set max output length via `OnDeviceAllToAllV.max_output_len = ...`"
            )

        # Allocate output buffer
        output = input.new_empty(OnDeviceAllToAllV.max_output_len, *input.shape[1:])
        # Allocate output splits tensor
        output_splits = torch.empty_like(input_splits)
        # Copy input splits to the (symmetric-memory) buffer so peers can read it
        OnDeviceAllToAllV.splits_buf.copy_(input_splits)

        # Shuffle input to output
        _on_device_all_to_all_v(
            output, output_splits, input, OnDeviceAllToAllV.splits_buf, group=group
        )

        # Output splits in forward is the input splits in backward
        ctx.save_for_backward(output_splits)
        ctx.group = group
        ctx.input_shape = input.shape
        return output, output_splits

    @staticmethod
    def backward(ctx, grad_output, grad_splits):
        """
        Backward is implemented as a shuffle of the output's gradients to the input.

        Args:
            `grad_output`: output's gradients passed from the downstream.
            `grad_splits`: unused.
        """

        # Initialize grad_output buffer (one time only)
        if OnDeviceAllToAllV.grad_output_buf is None:
            assert (
                OnDeviceAllToAllV.max_output_len is not None
            ), "`max_output_len` not set"
            OnDeviceAllToAllV.grad_output_buf = symm_mem.empty(
                OnDeviceAllToAllV.max_output_len,
                *grad_output.shape[1:],
                dtype=grad_output.dtype,
                device=grad_output.device,
            )

        # TODO: is there a way to tell autograd to feed grad_output directly to
        # our symm_mem buffer?
        OnDeviceAllToAllV.grad_output_buf.narrow(0, 0, grad_output.shape[0]).copy_(
            grad_output
        )

        # Size info: forward's output splits are backward's input splits
        (grad_output_splits,) = ctx.saved_tensors
        OnDeviceAllToAllV.splits_buf.copy_(grad_output_splits)
        grad_input_splits = torch.empty_like(grad_output_splits)  # unused
        grad_input = grad_output.new_empty(*ctx.input_shape)

        # Shuffle gradients back to the input
        _on_device_all_to_all_v(
            grad_input,
            grad_input_splits,
            OnDeviceAllToAllV.grad_output_buf,
            OnDeviceAllToAllV.splits_buf,
            group=ctx.group,
        )
        return grad_input, None, None


# Alias
on_device_all_to_all_v = OnDeviceAllToAllV.apply
|
torchtitan/experiments/deepseek_v3/train.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# torchrun --standalone --nproc-per-node 8 run.py
|
| 8 |
+
import torch
|
| 9 |
+
import torch.distributed as dist
|
| 10 |
+
from checkpoint import load_weights_from_hf
|
| 11 |
+
from model import DeepseekForCausalLM
|
| 12 |
+
from model_config import deepseek_config_registry
|
| 13 |
+
|
| 14 |
+
from torch.distributed.device_mesh import DeviceMesh
|
| 15 |
+
from torch.distributed.fsdp import fully_shard
|
| 16 |
+
from torch.distributed.pipelining import PipelineStage, Schedule1F1B
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# Use DeepSeek-V2-Lite as a proxy
|
| 20 |
+
model_id = "deepseek-ai/DeepSeek-V2-Lite"
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# Run full model
|
| 24 |
+
def run_full_model(
    mesh: DeviceMesh,
):
    """Build, shard and train-step a shrunk DeepSeek-V2-Lite model.

    Expects `mesh` to carry three dims named "pp" (pipeline), "ep" (expert
    parallel) and "fsdp" (data parallel). Runs `steps` synthetic
    forward/backward iterations and prints loss / gradient diagnostics.
    """
    rank = dist.get_rank()
    device_count = torch.cuda.device_count()
    # Map global rank onto the local GPUs round-robin.
    device = torch.device("cuda", rank % device_count)

    pp_mesh = mesh["pp"]
    ep_mesh = mesh["ep"]
    pp_rank = pp_mesh.get_local_rank()
    ep_rank = ep_mesh.get_local_rank()
    pp_size = pp_mesh.size()
    ep_size = ep_mesh.size()

    # Get model configs
    model_args = deepseek_config_registry[model_id]
    # [Note]: I am making the model smaller for testing / avoiding OOM. If you
    # have sufficient GPUs for model parallelism, you can remove this line.
    model_args.num_hidden_layers = 16

    # Apply model parallelism: expert-parallel degree plus this rank's
    # pipeline-stage assignment.
    model_args.ep_size = ep_size
    model_args.num_stages = pp_size
    model_args.stage_idx = pp_rank
    print(model_args)

    # Instantiate model on the target device, inside the mesh context
    with device, mesh:
        model = DeepseekForCausalLM(model_args)

    # Load weights
    load_weights_from_hf(model, model_id, device)
    model.train()

    # Apply data parallelism
    fsdp_mesh = mesh["fsdp"]
    hsdp_mesh = mesh["ep", "fsdp"]
    # Using `reshard_after_forward=False` to implement Zero-2, i.e. sharding the
    # optimizer (Zero-1) and gradients (Zero-2), but not the model weights.
    # Reason: the MoE is "sparsely activated" compared to the dense model, thus
    # it would be uneconomical to re-gather the weights.
    for layer in model.model.layers.values():
        # Apply FSDP to experts
        if hasattr(layer.mlp, "experts"):
            for expert in layer.mlp.experts.values():
                fully_shard(expert, mesh=fsdp_mesh, reshard_after_forward=False)
        # Apply HSDP to other parts such as attention, layernorm, because they
        # are doing DDP on EP dimension
        fully_shard(layer, mesh=hsdp_mesh, reshard_after_forward=False)

    # Apply HSDP on root model (lm_head, embeddings, etc)
    fully_shard(model, mesh=hsdp_mesh, reshard_after_forward=False)

    # Synthetic setting
    microbatches = pp_size * 2

    # Use Symmetric Memory for MoE token shuffle.
    # TODO: we are rewriting `moe_on_device` function. `setup_symm_mem` is
    # currently supported for forward only. See `generate.py`.
    # model.setup_symm_mem(torch.bfloat16, device)

    # Example inputs, seeded per EP rank so each expert-parallel group sees
    # different synthetic data.
    torch.manual_seed(ep_rank)
    bs = 4
    seqlen = 128
    x = torch.randint(model_args.vocab_size, (microbatches * bs, seqlen), device=device)
    label = torch.rand(microbatches * bs, seqlen, model_args.vocab_size, device=device)

    # Create loss function
    loss_fn = torch.nn.functional.cross_entropy

    # Run forward and backward
    steps = 2
    for _ in range(steps):
        if pp_size > 1:
            # Create pipeline stage
            stage = PipelineStage(
                model,
                pp_rank,
                pp_size,
                device,
                group=pp_mesh.get_group(),
            )

            # Create pipeline schedule; backward happens inside .step()
            losses = []
            pp_schedule = Schedule1F1B(stage, microbatches, loss_fn=loss_fn)

            if pp_rank == 0:
                y = pp_schedule.step(x)
            elif pp_rank == pp_size - 1:
                y = pp_schedule.step(target=label, losses=losses)
                loss = torch.mean(torch.stack(losses))
            else:
                pp_schedule.step()
        else:
            y = model(x)
            loss = loss_fn(y, label)
            loss.backward()

        if pp_rank == pp_size - 1:
            print(f"logits: {y.shape}")
            print(f"{loss=}")

        if pp_rank == 0:
            param = model.get_parameter("model.layers.0.self_attn.q_proj.weight")
            print(f"{torch.linalg.norm(param.grad)=}")

        model.zero_grad()

    print("Backward done")
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
if __name__ == "__main__":
    # 8 GPUs arranged as 2 (pipeline) x 2 (expert) x 2 (FSDP) parallel dims;
    # matches the `torchrun --nproc-per-node 8` invocation noted at the top.
    mesh = dist.init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("pp", "ep", "fsdp"))

    run_full_model(mesh)

    dist.destroy_process_group()
|
torchtitan/experiments/flux/README.md
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# FLUX model in torchtitan
|
| 2 |
+
|
| 3 |
+
## Overview
|
| 4 |
+
|
| 5 |
+
## Usage
|
| 6 |
+
First, download the autoencoder model from HuggingFace with your own access token:
|
| 7 |
+
```bash
|
| 8 |
+
python torchtitan/experiments/flux/scripts/download_autoencoder.py --repo_id black-forest-labs/FLUX.1-dev --ae_path ae.safetensors --hf_token <your_access_token>
|
| 9 |
+
```
|
| 10 |
+
This step will download the autoencoder model from HuggingFace and save it to the `torchtitan/experiments/flux/assets/autoencoder/ae.safetensors` file.
|
| 11 |
+
|
| 12 |
+
Run the following command to train the model on a single GPU:
|
| 13 |
+
```bash
|
| 14 |
+
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True torchrun --nproc_per_node=1 torchtitan/experiments/flux/train.py --job.config_file torchtitan/experiments/flux/train_configs/debug_model.toml
|
| 15 |
+
```
|
| 16 |
+
|
| 17 |
+
## TODO
|
| 18 |
+
- [ ] Support for multiple GPUs is coming soon (FSDP, etc.)
|
| 19 |
+
- [ ] Implement test cases in CI for FLUX model. Adding more unit tests for FLUX model (eg, unit test for preprocessor, etc)
|
| 20 |
+
- [ ] More parallelism support (Tensor Parallelism, Context Parallelism, etc.)
|
| 21 |
+
- [ ] Support for distributed checkpointing and loading
|
| 22 |
+
- [ ] Implement init_weights() function to initialize the model weights
|
| 23 |
+
- [ ] Implement the num_flops_per_token calculation in get_nparams_and_flops() function
|
torchtitan/experiments/flux/dataset/flux_dataset.py
ADDED
|
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import math
|
| 8 |
+
import random
|
| 9 |
+
from dataclasses import dataclass
|
| 10 |
+
from typing import Any, Callable, Optional
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
|
| 16 |
+
from datasets import Dataset, load_dataset
|
| 17 |
+
from datasets.distributed import split_dataset_by_node
|
| 18 |
+
from PIL import Image
|
| 19 |
+
|
| 20 |
+
from torch.distributed.checkpoint.stateful import Stateful
|
| 21 |
+
|
| 22 |
+
from torch.utils.data import IterableDataset
|
| 23 |
+
from torchtitan.components.dataloader import ParallelAwareDataloader
|
| 24 |
+
|
| 25 |
+
from torchtitan.config_manager import JobConfig
|
| 26 |
+
from torchtitan.experiments.flux.dataset.tokenizer import FluxTokenizer
|
| 27 |
+
from torchtitan.tools.logging import logger
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _process_cc12m_image(
|
| 31 |
+
img: Image.Image,
|
| 32 |
+
output_size: int = 256,
|
| 33 |
+
) -> Optional[torch.Tensor]:
|
| 34 |
+
"""Process CC12M image to the desired size."""
|
| 35 |
+
|
| 36 |
+
width, height = img.size
|
| 37 |
+
# Skip low resolution images
|
| 38 |
+
if width < output_size or height < output_size:
|
| 39 |
+
return None
|
| 40 |
+
|
| 41 |
+
if width >= height:
|
| 42 |
+
# resize height to be equal to output_size, then crop
|
| 43 |
+
new_width, new_height = math.ceil(output_size / height * width), output_size
|
| 44 |
+
img = img.resize((new_width, new_height))
|
| 45 |
+
left = random.randint(0, new_width - output_size)
|
| 46 |
+
resized_img = img.crop((left, 0, left + output_size, output_size))
|
| 47 |
+
else:
|
| 48 |
+
# resize width to be equal to output_size, the crop
|
| 49 |
+
new_width, new_height = (
|
| 50 |
+
output_size,
|
| 51 |
+
math.ceil(output_size / width * height),
|
| 52 |
+
)
|
| 53 |
+
img = img.resize((new_width, new_height))
|
| 54 |
+
lower = random.randint(0, new_width - output_size)
|
| 55 |
+
resized_img = img.crop((0, lower, output_size, lower + output_size))
|
| 56 |
+
|
| 57 |
+
assert resized_img.size[0] == resized_img.size[1] == output_size
|
| 58 |
+
|
| 59 |
+
# Skip grayscale images
|
| 60 |
+
if resized_img.mode == "L":
|
| 61 |
+
return None
|
| 62 |
+
|
| 63 |
+
np_img = np.array(resized_img).transpose((2, 0, 1))
|
| 64 |
+
tensor_img = torch.tensor(np_img).float() / 255.0
|
| 65 |
+
|
| 66 |
+
# NOTE: The following commented code is an alternative way
|
| 67 |
+
# img_transform = transforms.Compose(
|
| 68 |
+
# [
|
| 69 |
+
# transforms.Resize(max(output_size, output_size)),
|
| 70 |
+
# transforms.CenterCrop((output_size, output_size)),
|
| 71 |
+
# transforms.ToTensor(),
|
| 72 |
+
# ]
|
| 73 |
+
# )
|
| 74 |
+
# tensor_img = img_transform(img)
|
| 75 |
+
|
| 76 |
+
return tensor_img
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def _flux_data_processor(
    sample: dict[str, Any],
    t5_tokenizer: FluxTokenizer,
    clip_tokenizer: FluxTokenizer,
    output_size: int = 256,
) -> dict[str, Any]:
    """Preprocess one CC12M sample (image + caption) for the Flux model.

    Args:
        sample: Raw sample with "jpg" (PIL image) and "txt" (caption) keys.
        t5_tokenizer: Tokenizer producing T5 token ids for the caption.
        clip_tokenizer: Tokenizer producing CLIP token ids for the caption.
        output_size: Side length of the square output image.

    Returns:
        Dict with "image" (tensor, or None when the sample should be
        skipped), "clip_tokens" and "t5_tokens".
    """
    caption = sample["txt"]
    image = _process_cc12m_image(sample["jpg"], output_size=output_size)
    t5_tokens = t5_tokenizer.encode(caption)
    clip_tokens = clip_tokenizer.encode(caption)

    return {
        "image": image,
        "clip_tokens": clip_tokens,  # type: List[int]
        "t5_tokens": t5_tokens,  # type: List[int]
    }
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
@dataclass
class TextToImageDatasetConfig:
    """Configuration describing how to load and preprocess one dataset."""

    # HuggingFace hub id (or local path) of the dataset.
    path: str
    # Callable(path) -> dataset object (e.g. a streaming HF dataset).
    loader: Callable
    # Callable(sample, t5_tokenizer, clip_tokenizer, output_size) -> dict.
    data_processor: Callable
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
# Registry of supported text-to-image datasets, keyed by the lowercase name
# passed via `training.dataset`.
DATASETS = {
    "cc12m": TextToImageDatasetConfig(
        path="pixparse/cc12m-wds",
        # Stream the webdataset split so it is not fully downloaded up front.
        loader=lambda path: load_dataset(path, split="train", streaming=True),
        data_processor=_flux_data_processor,
    ),
}
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def _validate_dataset(
    dataset_name: str, dataset_path: Optional[str] = None
) -> tuple[str, Callable, Callable]:
    """Resolve a dataset name into its path, loader, and sample processor.

    Raises:
        ValueError: if `dataset_name` is not registered in DATASETS.
    """
    config = DATASETS.get(dataset_name)
    if config is None:
        raise ValueError(
            f"Dataset {dataset_name} is not supported. "
            f"Supported datasets are: {list(DATASETS.keys())}"
        )

    # An explicit path overrides the registry default.
    path = dataset_path if dataset_path else config.path
    logger.info(f"Preparing {dataset_name} dataset from {path}")
    return path, config.loader, config.data_processor
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class FluxDataset(IterableDataset, Stateful):
    """Dataset for FLUX text-to-image model.

    Streams (tokens, image) pairs from a supported text-to-image dataset,
    applying the dataset-specific preprocessor to each raw sample.

    Args:
        dataset_name (str): Name of the dataset (must be in DATASETS).
        dataset_path (str): Optional path override for the dataset.
        t5_tokenizer (FluxTokenizer): Tokenizer for the T5 text encoder.
        clip_tokenizer (FluxTokenizer): Tokenizer for the CLIP text encoder.
        job_config (JobConfig): Optional job configuration (stored as-is).
        dp_rank (int): Data parallel rank.
        dp_world_size (int): Data parallel world size.
        infinite (bool): Whether to loop over the dataset infinitely.
    """

    def __init__(
        self,
        dataset_name: str,
        dataset_path: Optional[str],
        t5_tokenizer: FluxTokenizer,
        clip_tokenizer: FluxTokenizer,
        job_config: Optional[JobConfig] = None,
        dp_rank: int = 0,
        dp_world_size: int = 1,
        infinite: bool = False,
    ) -> None:

        # Force lowercase for consistent comparison
        dataset_name = dataset_name.lower()

        path, dataset_loader, data_processor = _validate_dataset(
            dataset_name, dataset_path
        )
        ds = dataset_loader(path)

        self.dataset_name = dataset_name
        # Shard the (possibly streaming) dataset across data-parallel ranks.
        self._data = split_dataset_by_node(ds, dp_rank, dp_world_size)

        self._t5_tokenizer = t5_tokenizer
        self._clip_tokenizer = clip_tokenizer
        self._data_processor = data_processor
        self.job_config = job_config

        self.infinite = infinite

        # Variables for checkpointing.
        # NOTE(review): `_all_samples` grows without bound over a long run
        # and is serialized in state_dict() — consider capping it.
        self._sample_idx = 0
        self._all_samples: list[dict[str, Any]] = []

    def _get_data_iter(self):
        # For a finite, materialized dataset that has been fully consumed,
        # return an empty iterator instead of skipping past the end.
        if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
            return iter([])

        it = iter(self._data)
        # Fast-forward past samples already consumed (checkpoint resume).
        for _ in range(self._sample_idx):
            next(it)
        return it

    def __iter__(self):
        while True:
            for sample in self._get_data_iter():
                # Use the dataset-specific preprocessor
                sample_dict = self._data_processor(
                    sample, self._t5_tokenizer, self._clip_tokenizer, output_size=256
                )

                # skip low quality image or image with color channel = 1
                if sample_dict["image"] is None:
                    logger.warning(
                        f"Low quality image {sample['__key__']} is skipped in Flux Dataloader"
                    )
                    continue

                # FIX: was `extend(sample_dict)`, which iterates the dict and
                # stored only its *key strings* ("image", "clip_tokens", ...);
                # append the sample itself. The `pop` below mutates this same
                # dict object, so the stored entry keeps only the token fields.
                self._all_samples.append(sample_dict)
                self._sample_idx += 1

                labels = sample_dict.pop("image")
                yield sample_dict, labels

            if not self.infinite:
                logger.warning(f"Dataset {self.dataset_name} has run out of data")
                break
            else:
                # Reset offset for the next iteration
                self._sample_idx = 0
                logger.warning(f"Dataset {self.dataset_name} is being re-looped")

    def load_state_dict(self, state_dict):
        self._sample_idx = state_dict["sample_idx"]
        self._all_samples = state_dict["all_samples"]

    def state_dict(self):
        return {
            "all_samples": self._all_samples,
            "sample_idx": self._sample_idx,
        }
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def build_flux_dataloader(
    dp_world_size: int,
    dp_rank: int,
    job_config: JobConfig,
    # This parameter is not used, keep it for compatibility
    tokenizer: FluxTokenizer | None,
    infinite: bool = True,
) -> ParallelAwareDataloader:
    """Build a data loader for HuggingFace datasets."""
    training_cfg = job_config.training
    encoder_cfg = job_config.encoder

    t5_tok = FluxTokenizer(
        encoder_cfg.t5_encoder, max_length=encoder_cfg.max_t5_encoding_len
    )
    # CLIP uses a fixed 77-token context length.
    clip_tok = FluxTokenizer(encoder_cfg.clip_encoder, max_length=77)

    dataset = FluxDataset(
        dataset_name=training_cfg.dataset,
        dataset_path=training_cfg.dataset_path,
        t5_tokenizer=t5_tok,
        clip_tokenizer=clip_tok,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        infinite=infinite,
    )

    return ParallelAwareDataloader(
        dataset=dataset,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        batch_size=training_cfg.batch_size,
    )
|
torchtitan/experiments/flux/dataset/tokenizer.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 8 |
+
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
from typing import List
|
| 12 |
+
|
| 13 |
+
from torchtitan.components.tokenizer import Tokenizer
|
| 14 |
+
from transformers import CLIPTokenizer, T5Tokenizer
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class FluxTokenizer(Tokenizer):
    """
    Tokenizing and encoding/decoding text using the T5 or Clip tokenizer.

    Args:
        model_path (str): Path to the tokenizer from hugging face.
        max_length (int): Sequence length every encoded prompt is
            padded/truncated to.
    """

    def __init__(self, model_path: str = "t5-small", max_length: int = 77):
        super().__init__()
        self._n_words = 8  # TODO(jianiw): check
        self._max_length = max_length

        # CLIP checkpoints live under the "openai/..." namespace; anything
        # else is treated as a T5 tokenizer.
        self.is_clip = model_path.startswith("openai")

        if self.is_clip:
            self._tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(
                model_path, max_length=max_length
            )
        else:
            self._tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(
                model_path, max_length=max_length
            )

    def encode(
        self,
        s: str,
    ) -> List[int]:
        """
        Encode the prompt text into tokens.

        NOTE(review): because `return_tensors="pt"` is passed below, this
        actually returns a torch tensor of token ids (padded/truncated to
        `max_length`), not a plain List[int] as the annotation suggests —
        align the annotation or the call.
        """
        tokens = self._tokenizer(
            s,
            truncation=True,
            max_length=self._max_length,
            return_length=False,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",  # return pytorch tensors, default return List[int]
        )["input_ids"]
        return tokens

    def decode(self, t: List[int]) -> str:
        """
        Decode token ids back into text. This function will not be called.
        """
        return self._tokenizer.decode(t)
|
torchtitan/experiments/flux/model/hf_embedder.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from torch import nn, Tensor
|
| 8 |
+
from transformers import CLIPTextModel, T5EncoderModel
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class FluxEmbedder(nn.Module):
    """Thin wrapper around a frozen HuggingFace text encoder (CLIP or T5).

    The encoder type is selected from the checkpoint name: an "openai/..."
    version loads a CLIPTextModel (pooled output), anything else a
    T5EncoderModel (per-token hidden states).
    """

    def __init__(self, version: str, **hf_kwargs):
        super().__init__()
        self.is_clip = version.startswith("openai")
        # CLIP contributes a single pooled vector per prompt; T5 contributes
        # the full last-layer hidden-state sequence.
        self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"

        if self.is_clip:
            self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(
                version, **hf_kwargs
            )
        else:
            self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(
                version, **hf_kwargs
            )

        # The text encoder is frozen: eval mode, gradients disabled.
        self.hf_module = self.hf_module.eval().requires_grad_(False)

    def forward(self, batch_tokens: Tensor) -> Tensor:
        """
        batch_tokens: [bsz, embedding_length]

        For T5 Encoder, embedding_length is 768
        For CLIP, embedding_length is 256
        (NOTE(review): these numbers look like tokenizer max lengths — verify
        against the tokenizer configuration used by the dataloader.)
        """
        outputs = self.hf_module(
            input_ids=batch_tokens.to(self.hf_module.device),
            attention_mask=None,
            output_hidden_states=False,
        )
        return outputs[self.output_key]
|
torchtitan/experiments/flux/model/layers.py
ADDED
|
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# imported from black-forest-labs/FLUX
|
| 8 |
+
import math
|
| 9 |
+
from dataclasses import dataclass
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from einops import rearrange
|
| 13 |
+
from torch import nn, Tensor
|
| 14 |
+
|
| 15 |
+
from torchtitan.experiments.flux.model.math import attention, rope
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class EmbedND(nn.Module):
|
| 19 |
+
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
|
| 20 |
+
super().__init__()
|
| 21 |
+
self.dim = dim
|
| 22 |
+
self.theta = theta
|
| 23 |
+
self.axes_dim = axes_dim
|
| 24 |
+
|
| 25 |
+
def forward(self, ids: Tensor) -> Tensor:
|
| 26 |
+
n_axes = ids.shape[-1]
|
| 27 |
+
emb = torch.cat(
|
| 28 |
+
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
|
| 29 |
+
dim=-3,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
return emb.unsqueeze(1)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
|
| 36 |
+
"""
|
| 37 |
+
Create sinusoidal timestep embeddings.
|
| 38 |
+
:param t: a 1-D Tensor of N indices, one per batch element.
|
| 39 |
+
These may be fractional.
|
| 40 |
+
:param dim: the dimension of the output.
|
| 41 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
| 42 |
+
:return: an (N, D) Tensor of positional embeddings.
|
| 43 |
+
"""
|
| 44 |
+
t = time_factor * t
|
| 45 |
+
half = dim // 2
|
| 46 |
+
freqs = torch.exp(
|
| 47 |
+
-math.log(max_period)
|
| 48 |
+
* torch.arange(start=0, end=half, dtype=torch.float32)
|
| 49 |
+
/ half
|
| 50 |
+
).to(t.device)
|
| 51 |
+
|
| 52 |
+
args = t[:, None].float() * freqs[None]
|
| 53 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
| 54 |
+
if dim % 2:
|
| 55 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
| 56 |
+
if torch.is_floating_point(t):
|
| 57 |
+
embedding = embedding.to(t)
|
| 58 |
+
return embedding
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class MLPEmbedder(nn.Module):
|
| 62 |
+
def __init__(self, in_dim: int, hidden_dim: int):
|
| 63 |
+
super().__init__()
|
| 64 |
+
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
|
| 65 |
+
self.silu = nn.SiLU()
|
| 66 |
+
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
| 67 |
+
|
| 68 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 69 |
+
return self.out_layer(self.silu(self.in_layer(x)))
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class RMSNorm(torch.nn.Module):
|
| 73 |
+
def __init__(self, dim: int):
|
| 74 |
+
super().__init__()
|
| 75 |
+
self.scale = nn.Parameter(torch.ones(dim))
|
| 76 |
+
|
| 77 |
+
def forward(self, x: Tensor):
|
| 78 |
+
x_dtype = x.dtype
|
| 79 |
+
x = x.float()
|
| 80 |
+
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
|
| 81 |
+
return (x * rrms).to(dtype=x_dtype) * self.scale
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class QKNorm(torch.nn.Module):
|
| 85 |
+
def __init__(self, dim: int):
|
| 86 |
+
super().__init__()
|
| 87 |
+
self.query_norm = RMSNorm(dim) # TODO(jianiw): switch to pytorch nn.RMSNorm
|
| 88 |
+
self.key_norm = RMSNorm(dim)
|
| 89 |
+
|
| 90 |
+
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
|
| 91 |
+
q = self.query_norm(q)
|
| 92 |
+
k = self.key_norm(k)
|
| 93 |
+
return q.to(v), k.to(v)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class SelfAttention(nn.Module):
|
| 97 |
+
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
|
| 98 |
+
super().__init__()
|
| 99 |
+
self.num_heads = num_heads
|
| 100 |
+
head_dim = dim // num_heads
|
| 101 |
+
|
| 102 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 103 |
+
self.norm = QKNorm(head_dim)
|
| 104 |
+
self.proj = nn.Linear(dim, dim)
|
| 105 |
+
|
| 106 |
+
def forward(self, x: Tensor, pe: Tensor) -> Tensor:
|
| 107 |
+
qkv = self.qkv(x)
|
| 108 |
+
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
| 109 |
+
q, k = self.norm(q, k, v)
|
| 110 |
+
x = attention(q, k, v, pe=pe)
|
| 111 |
+
x = self.proj(x)
|
| 112 |
+
return x
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
@dataclass
|
| 116 |
+
class ModulationOut:
|
| 117 |
+
shift: Tensor
|
| 118 |
+
scale: Tensor
|
| 119 |
+
gate: Tensor
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class Modulation(nn.Module):
|
| 123 |
+
def __init__(self, dim: int, double: bool):
|
| 124 |
+
super().__init__()
|
| 125 |
+
self.is_double = double
|
| 126 |
+
self.multiplier = 6 if double else 3
|
| 127 |
+
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
|
| 128 |
+
|
| 129 |
+
def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
|
| 130 |
+
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(
|
| 131 |
+
self.multiplier, dim=-1
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
return (
|
| 135 |
+
ModulationOut(*out[:3]),
|
| 136 |
+
ModulationOut(*out[3:]) if self.is_double else None,
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
class DoubleStreamBlock(nn.Module):
|
| 141 |
+
def __init__(
|
| 142 |
+
self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False
|
| 143 |
+
):
|
| 144 |
+
super().__init__()
|
| 145 |
+
|
| 146 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
| 147 |
+
self.num_heads = num_heads
|
| 148 |
+
self.hidden_size = hidden_size
|
| 149 |
+
self.img_mod = Modulation(hidden_size, double=True)
|
| 150 |
+
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 151 |
+
self.img_attn = SelfAttention(
|
| 152 |
+
dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 156 |
+
self.img_mlp = nn.Sequential(
|
| 157 |
+
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
| 158 |
+
nn.GELU(approximate="tanh"),
|
| 159 |
+
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
self.txt_mod = Modulation(hidden_size, double=True)
|
| 163 |
+
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 164 |
+
self.txt_attn = SelfAttention(
|
| 165 |
+
dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 169 |
+
self.txt_mlp = nn.Sequential(
|
| 170 |
+
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
| 171 |
+
nn.GELU(approximate="tanh"),
|
| 172 |
+
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
def forward(
|
| 176 |
+
self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor
|
| 177 |
+
) -> tuple[Tensor, Tensor]:
|
| 178 |
+
img_mod1, img_mod2 = self.img_mod(vec)
|
| 179 |
+
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
| 180 |
+
|
| 181 |
+
# prepare image for attention
|
| 182 |
+
img_modulated = self.img_norm1(img)
|
| 183 |
+
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
| 184 |
+
img_qkv = self.img_attn.qkv(img_modulated)
|
| 185 |
+
img_q, img_k, img_v = rearrange(
|
| 186 |
+
img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
|
| 187 |
+
)
|
| 188 |
+
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
| 189 |
+
|
| 190 |
+
# prepare txt for attention
|
| 191 |
+
txt_modulated = self.txt_norm1(txt)
|
| 192 |
+
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
| 193 |
+
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
| 194 |
+
txt_q, txt_k, txt_v = rearrange(
|
| 195 |
+
txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
|
| 196 |
+
)
|
| 197 |
+
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
| 198 |
+
|
| 199 |
+
# run actual attention
|
| 200 |
+
q = torch.cat((txt_q, img_q), dim=2)
|
| 201 |
+
k = torch.cat((txt_k, img_k), dim=2)
|
| 202 |
+
v = torch.cat((txt_v, img_v), dim=2)
|
| 203 |
+
|
| 204 |
+
attn = attention(q, k, v, pe=pe)
|
| 205 |
+
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
| 206 |
+
|
| 207 |
+
# calculate the img bloks
|
| 208 |
+
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
|
| 209 |
+
img = img + img_mod2.gate * self.img_mlp(
|
| 210 |
+
(1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
# calculate the txt bloks
|
| 214 |
+
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
| 215 |
+
txt = txt + txt_mod2.gate * self.txt_mlp(
|
| 216 |
+
(1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
|
| 217 |
+
)
|
| 218 |
+
return img, txt
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
class SingleStreamBlock(nn.Module):
|
| 222 |
+
"""
|
| 223 |
+
A DiT block with parallel linear layers as described in
|
| 224 |
+
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
|
| 225 |
+
"""
|
| 226 |
+
|
| 227 |
+
def __init__(
|
| 228 |
+
self,
|
| 229 |
+
hidden_size: int,
|
| 230 |
+
num_heads: int,
|
| 231 |
+
mlp_ratio: float = 4.0,
|
| 232 |
+
qk_scale: float | None = None,
|
| 233 |
+
):
|
| 234 |
+
super().__init__()
|
| 235 |
+
self.hidden_dim = hidden_size
|
| 236 |
+
self.num_heads = num_heads
|
| 237 |
+
head_dim = hidden_size // num_heads
|
| 238 |
+
self.scale = qk_scale or head_dim**-0.5
|
| 239 |
+
|
| 240 |
+
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
| 241 |
+
# qkv and mlp_in
|
| 242 |
+
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
|
| 243 |
+
# proj and mlp_out
|
| 244 |
+
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
|
| 245 |
+
|
| 246 |
+
self.norm = QKNorm(head_dim)
|
| 247 |
+
|
| 248 |
+
self.hidden_size = hidden_size
|
| 249 |
+
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 250 |
+
|
| 251 |
+
self.mlp_act = nn.GELU(approximate="tanh")
|
| 252 |
+
self.modulation = Modulation(hidden_size, double=False)
|
| 253 |
+
|
| 254 |
+
def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
|
| 255 |
+
mod, _ = self.modulation(vec)
|
| 256 |
+
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
|
| 257 |
+
qkv, mlp = torch.split(
|
| 258 |
+
self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
| 262 |
+
q, k = self.norm(q, k, v)
|
| 263 |
+
|
| 264 |
+
# compute attention
|
| 265 |
+
attn = attention(q, k, v, pe=pe)
|
| 266 |
+
# compute activation in mlp stream, cat again and run second linear layer
|
| 267 |
+
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
| 268 |
+
return x + mod.gate * output
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
class LastLayer(nn.Module):
|
| 272 |
+
def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
|
| 273 |
+
super().__init__()
|
| 274 |
+
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 275 |
+
self.linear = nn.Linear(
|
| 276 |
+
hidden_size, patch_size * patch_size * out_channels, bias=True
|
| 277 |
+
)
|
| 278 |
+
self.adaLN_modulation = nn.Sequential(
|
| 279 |
+
nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
|
| 283 |
+
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
|
| 284 |
+
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
|
| 285 |
+
x = self.linear(x)
|
| 286 |
+
return x
|
torchtitan/experiments/flux/tests/test_flux_dataloader.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import sys
|
| 8 |
+
|
| 9 |
+
from torchtitan.config_manager import JobConfig
|
| 10 |
+
from torchtitan.experiments.flux.dataset.flux_dataset import build_flux_dataloader
|
| 11 |
+
from torchtitan.tools.profiling import (
|
| 12 |
+
maybe_enable_memory_snapshot,
|
| 13 |
+
maybe_enable_profiling,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class TestFluxDataLoader:
|
| 18 |
+
def test_flux_dataloader(self):
|
| 19 |
+
dataset_name = "cc12m"
|
| 20 |
+
batch_size = 32
|
| 21 |
+
world_size = 4
|
| 22 |
+
rank = 0
|
| 23 |
+
|
| 24 |
+
num_steps = 10
|
| 25 |
+
|
| 26 |
+
path = "torchtitan.experiments.flux.flux_argparser"
|
| 27 |
+
sys.argv.append(f"--experimental.custom_args_module={path}")
|
| 28 |
+
config = JobConfig()
|
| 29 |
+
config.maybe_add_custom_args()
|
| 30 |
+
config.parse_args(
|
| 31 |
+
[
|
| 32 |
+
# Profiling options
|
| 33 |
+
# "--profiling.enable_profiling",
|
| 34 |
+
# "--profiling.profile_freq",
|
| 35 |
+
# "5",
|
| 36 |
+
# "--profiling.enable_memory_snapshot",
|
| 37 |
+
# "--profiling.save_memory_snapshot_folder",
|
| 38 |
+
# "memory_snapshot_flux",
|
| 39 |
+
"--training.dataset",
|
| 40 |
+
dataset_name,
|
| 41 |
+
"--training.batch_size",
|
| 42 |
+
str(batch_size),
|
| 43 |
+
"--encoder.t5_encoder",
|
| 44 |
+
"google/t5-v1_1-small",
|
| 45 |
+
"--encoder.clip_encoder",
|
| 46 |
+
"openai/clip-vit-large-patch14",
|
| 47 |
+
"--encoder.max_t5_encoding_len",
|
| 48 |
+
"512",
|
| 49 |
+
]
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
with maybe_enable_profiling(
|
| 53 |
+
config, global_step=0
|
| 54 |
+
) as torch_profiler, maybe_enable_memory_snapshot(
|
| 55 |
+
config, global_step=0
|
| 56 |
+
) as memory_profiler:
|
| 57 |
+
dl = self._build_dataloader(
|
| 58 |
+
config,
|
| 59 |
+
world_size,
|
| 60 |
+
rank,
|
| 61 |
+
)
|
| 62 |
+
dl = iter(dl)
|
| 63 |
+
|
| 64 |
+
for i in range(0, num_steps):
|
| 65 |
+
input_data, labels = next(dl)
|
| 66 |
+
print(f"Step {i} image size: {labels.shape}")
|
| 67 |
+
if torch_profiler:
|
| 68 |
+
torch_profiler.step()
|
| 69 |
+
if memory_profiler:
|
| 70 |
+
memory_profiler.step()
|
| 71 |
+
|
| 72 |
+
print(len(input_data["clip_tokens"]))
|
| 73 |
+
for k, v in input_data.items():
|
| 74 |
+
print(f"Step {i} {k} value: {type(v), v.shape}")
|
| 75 |
+
|
| 76 |
+
assert len(input_data) == 2 # (clip_encodings, t5_encodings)
|
| 77 |
+
assert labels.shape == (batch_size, 3, 256, 256)
|
| 78 |
+
# assert input_data["clip_tokens"].shape[0] == batch_size
|
| 79 |
+
# assert input_data["t5_tokens"].shape == (batch_size, 512, 512)
|
| 80 |
+
|
| 81 |
+
if torch_profiler:
|
| 82 |
+
torch_profiler.step()
|
| 83 |
+
if memory_profiler:
|
| 84 |
+
memory_profiler.step(exit_ctx=True)
|
| 85 |
+
|
| 86 |
+
def test_preprocess(self):
|
| 87 |
+
# TODO
|
| 88 |
+
pass
|
| 89 |
+
|
| 90 |
+
def _build_dataloader(
|
| 91 |
+
self,
|
| 92 |
+
job_config,
|
| 93 |
+
world_size,
|
| 94 |
+
rank,
|
| 95 |
+
):
|
| 96 |
+
|
| 97 |
+
return build_flux_dataloader(
|
| 98 |
+
dp_world_size=world_size,
|
| 99 |
+
dp_rank=rank,
|
| 100 |
+
job_config=job_config,
|
| 101 |
+
tokenizer=None,
|
| 102 |
+
infinite=False,
|
| 103 |
+
)
|
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/fast_debug_ao.py
ADDED
|
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# pyre-unsafe
|
| 8 |
+
import logging
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
from reference_utils import (
|
| 14 |
+
analyze_tensor_differences,
|
| 15 |
+
compute_reference_backward,
|
| 16 |
+
compute_reference_forward,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
# Configure logging
|
| 20 |
+
logging.basicConfig(
|
| 21 |
+
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
# Import grouped GEMM implementations
|
| 25 |
+
try:
|
| 26 |
+
from mg_grouped_gemm import grouped_gemm_backward, grouped_gemm_forward
|
| 27 |
+
|
| 28 |
+
except ImportError:
|
| 29 |
+
logging.error(
|
| 30 |
+
"Error importing grouped GEMM modules. Make sure the implementation files are in the correct path."
|
| 31 |
+
)
|
| 32 |
+
raise
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def test_forward_pass():
|
| 36 |
+
"""
|
| 37 |
+
A simple test for the M*G grouped GEMM forward pass with detailed error handling.
|
| 38 |
+
|
| 39 |
+
In M*G grouping:
|
| 40 |
+
- M dimension is partitioned into G groups (M_total = sum(M_sizes))
|
| 41 |
+
- N dimension is the same for all groups
|
| 42 |
+
"""
|
| 43 |
+
try:
|
| 44 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 45 |
+
|
| 46 |
+
# Test parameters for DeepSeek-like models
|
| 47 |
+
G = 1 # Number of groups
|
| 48 |
+
M_sizes = [
|
| 49 |
+
2048,
|
| 50 |
+
] # 2048, 2048, 2048] # Group sizes (will be adjusted)
|
| 51 |
+
M_total = sum(M_sizes) # Total M dimension
|
| 52 |
+
N = 4096 # Output dimension (same for all groups)
|
| 53 |
+
K = 7168 # Hidden dimension
|
| 54 |
+
|
| 55 |
+
# Create group sizes tensor
|
| 56 |
+
m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)
|
| 57 |
+
|
| 58 |
+
# Create input and weight tensors - using float16 for higher precision
|
| 59 |
+
x = torch.randn(M_total, K, dtype=torch.float16, device=device)
|
| 60 |
+
w = torch.randn(N, K, dtype=torch.float16, device=device)
|
| 61 |
+
|
| 62 |
+
# Log the setup
|
| 63 |
+
logging.info(f"Test setup - G: {G}, M_total: {M_total}, N: {N}, K: {K}")
|
| 64 |
+
logging.info(f"Group sizes: {m_sizes}")
|
| 65 |
+
logging.info(f"Input x shape: {x.shape}")
|
| 66 |
+
logging.info(f"Weight w shape: {w.shape}")
|
| 67 |
+
|
| 68 |
+
# Run forward pass
|
| 69 |
+
logging.info("Running forward pass with grouped GEMM")
|
| 70 |
+
result = grouped_gemm_forward(x, w, m_sizes)
|
| 71 |
+
logging.info(f"Forward result shape: {result.shape}")
|
| 72 |
+
|
| 73 |
+
# Compute reference result
|
| 74 |
+
logging.info("Computing reference result with PyTorch")
|
| 75 |
+
reference_result = compute_reference_forward(x, w, m_sizes)
|
| 76 |
+
|
| 77 |
+
# Compare results
|
| 78 |
+
logging.info("Comparing with PyTorch reference")
|
| 79 |
+
forward_close = analyze_tensor_differences(
|
| 80 |
+
result, reference_result, "Forward output"
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
return forward_close
|
| 84 |
+
|
| 85 |
+
except Exception as e:
|
| 86 |
+
logging.error(f"Test failed with error: {e}")
|
| 87 |
+
import traceback
|
| 88 |
+
|
| 89 |
+
logging.error(traceback.format_exc())
|
| 90 |
+
return False
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def test_backward_pass():
|
| 94 |
+
"""
|
| 95 |
+
A simple test for the M*G grouped GEMM backward pass with detailed error handling.
|
| 96 |
+
|
| 97 |
+
In M*G grouping:
|
| 98 |
+
- M dimension is partitioned into G groups (M_total = sum(M_sizes))
|
| 99 |
+
- N dimension is the same for all groups
|
| 100 |
+
"""
|
| 101 |
+
try:
|
| 102 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 103 |
+
|
| 104 |
+
# Test parameters for DeepSeek-like models
|
| 105 |
+
G = 4 # Number of groups
|
| 106 |
+
M_sizes = [2048, 2048, 2048, 2048] # Group sizes (will be adjusted)
|
| 107 |
+
M_total = sum(M_sizes) # Total M dimension
|
| 108 |
+
N = 4096 # Output dimension (same for all groups)
|
| 109 |
+
K = 7168 # Hidden dimension
|
| 110 |
+
|
| 111 |
+
# Create group sizes tensor
|
| 112 |
+
m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)
|
| 113 |
+
|
| 114 |
+
# Create input and weight tensors - using float16 for higher precision
|
| 115 |
+
x = torch.randn(
|
| 116 |
+
M_total, K, dtype=torch.float16, device=device, requires_grad=True
|
| 117 |
+
)
|
| 118 |
+
w = torch.randn(N, K, dtype=torch.float16, device=device, requires_grad=True)
|
| 119 |
+
|
| 120 |
+
# Log the setup
|
| 121 |
+
logging.info(f"Test setup - G: {G}, M_total: {M_total}, N: {N}, K: {K}")
|
| 122 |
+
logging.info(f"Group sizes: {m_sizes}")
|
| 123 |
+
logging.info(f"Input x shape: {x.shape}")
|
| 124 |
+
logging.info(f"Weight w shape: {w.shape}")
|
| 125 |
+
|
| 126 |
+
# Step 1: Run forward pass
|
| 127 |
+
logging.info("Running forward pass")
|
| 128 |
+
result = grouped_gemm_forward(x, w, m_sizes)
|
| 129 |
+
logging.info(f"Forward result shape: {result.shape}")
|
| 130 |
+
|
| 131 |
+
# Create a gradient for backpropagation
|
| 132 |
+
grad_output = torch.randn_like(result)
|
| 133 |
+
logging.info(f"Created gradient with shape: {grad_output.shape}")
|
| 134 |
+
|
| 135 |
+
# Step 2: Run backward pass directly
|
| 136 |
+
logging.info("Running backward pass directly")
|
| 137 |
+
grad_x, grad_w = grouped_gemm_backward(grad_output, x, w, m_sizes)
|
| 138 |
+
|
| 139 |
+
# Verify gradient shapes
|
| 140 |
+
logging.info(
|
| 141 |
+
f"Gradient shapes - grad_x: {grad_x.shape}, grad_w: {grad_w.shape}"
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
# Step 3: Verify gradient computation using PyTorch's autograd
|
| 145 |
+
logging.info("Running PyTorch reference implementation")
|
| 146 |
+
|
| 147 |
+
# Compute reference gradients
|
| 148 |
+
x_ref_grad, w_ref_grad = compute_reference_backward(x, w, m_sizes, grad_output)
|
| 149 |
+
|
| 150 |
+
# Compare gradients
|
| 151 |
+
logging.info("Comparing gradients with PyTorch reference")
|
| 152 |
+
grad_x_close = analyze_tensor_differences(grad_x, x_ref_grad, "grad_x")
|
| 153 |
+
grad_w_close = analyze_tensor_differences(grad_w, w_ref_grad, "grad_w")
|
| 154 |
+
|
| 155 |
+
# Log overall result
|
| 156 |
+
if grad_x_close and grad_w_close:
|
| 157 |
+
logging.info("✓ SUCCESS: Gradients match the PyTorch reference")
|
| 158 |
+
else:
|
| 159 |
+
logging.error("✗ FAILURE: Gradient mismatch detected")
|
| 160 |
+
|
| 161 |
+
return grad_x_close and grad_w_close
|
| 162 |
+
|
| 163 |
+
except Exception as e:
|
| 164 |
+
logging.error(f"Test failed with error: {e}")
|
| 165 |
+
import traceback
|
| 166 |
+
|
| 167 |
+
logging.error(traceback.format_exc())
|
| 168 |
+
return False
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def test_multiple_deepseek_configs():
|
| 172 |
+
"""
|
| 173 |
+
Test multiple DeepSeek model configurations with both forward and backward pass verification.
|
| 174 |
+
"""
|
| 175 |
+
# DeepSeek configurations: (G, M, K, N)
|
| 176 |
+
configs = [
|
| 177 |
+
(4, 8192, 7168, 4096), # Config 1
|
| 178 |
+
(4, 8192, 2048, 7168), # Config 2
|
| 179 |
+
(8, 4096, 7168, 4096), # Config 3
|
| 180 |
+
(8, 4096, 2048, 7168), # Config 4
|
| 181 |
+
]
|
| 182 |
+
|
| 183 |
+
results = []
|
| 184 |
+
|
| 185 |
+
for config_idx, (G, M, K, N) in enumerate(configs):
|
| 186 |
+
logging.info(f"\n\n===== Testing DeepSeek Config {config_idx+1} =====")
|
| 187 |
+
logging.info(f"G={G}, M={M}, K={K}, N={N}")
|
| 188 |
+
|
| 189 |
+
try:
|
| 190 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 191 |
+
|
| 192 |
+
# Create even group sizes
|
| 193 |
+
base_size = M // G
|
| 194 |
+
remainder = M % G
|
| 195 |
+
M_sizes = [base_size + (1 if i < remainder else 0) for i in range(G)]
|
| 196 |
+
m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)
|
| 197 |
+
|
| 198 |
+
# Create input and weight tensors using float16 for higher precision
|
| 199 |
+
x = torch.randn(
|
| 200 |
+
M, K, dtype=torch.float16, device=device, requires_grad=True
|
| 201 |
+
)
|
| 202 |
+
w = torch.randn(
|
| 203 |
+
N, K, dtype=torch.float16, device=device, requires_grad=True
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
logging.info(f"Input x shape: {x.shape}, Weight w shape: {w.shape}")
|
| 207 |
+
|
| 208 |
+
# Run forward pass
|
| 209 |
+
result = grouped_gemm_forward(x, w, m_sizes)
|
| 210 |
+
logging.info(f"Forward result shape: {result.shape}")
|
| 211 |
+
|
| 212 |
+
# ===== FORWARD PASS VERIFICATION =====
|
| 213 |
+
# Compute reference forward result
|
| 214 |
+
reference_result = compute_reference_forward(x, w, m_sizes)
|
| 215 |
+
|
| 216 |
+
# Compare forward results
|
| 217 |
+
forward_close = analyze_tensor_differences(
|
| 218 |
+
result, reference_result, "Forward output"
|
| 219 |
+
)
|
| 220 |
+
|
| 221 |
+
# ===== BACKWARD PASS VERIFICATION =====
|
| 222 |
+
# Create gradient for backpropagation
|
| 223 |
+
grad_output = torch.randn_like(result)
|
| 224 |
+
|
| 225 |
+
# Run backward pass
|
| 226 |
+
grad_x, grad_w = grouped_gemm_backward(grad_output, x, w, m_sizes)
|
| 227 |
+
|
| 228 |
+
# Compute reference gradients
|
| 229 |
+
x_ref_grad, w_ref_grad = compute_reference_backward(
|
| 230 |
+
x, w, m_sizes, grad_output
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
# Compare backward results
|
| 234 |
+
grad_x_close = analyze_tensor_differences(grad_x, x_ref_grad, "grad_x")
|
| 235 |
+
grad_w_close = analyze_tensor_differences(grad_w, w_ref_grad, "grad_w")
|
| 236 |
+
|
| 237 |
+
# Overall config result
|
| 238 |
+
backward_close = grad_x_close and grad_w_close
|
| 239 |
+
config_success = forward_close and backward_close
|
| 240 |
+
results.append(
|
| 241 |
+
(config_idx + 1, config_success, forward_close, backward_close)
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
# Log overall config result
|
| 245 |
+
if config_success:
|
| 246 |
+
logging.info(f"✓ SUCCESS: Config {config_idx+1} passed all tests!")
|
| 247 |
+
else:
|
| 248 |
+
logging.error(
|
| 249 |
+
f"✗ FAILURE: Config {config_idx+1} failed one or more tests"
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
except Exception as e:
|
| 253 |
+
logging.error(f"Config {config_idx+1} test failed with error: {e}")
|
| 254 |
+
import traceback
|
| 255 |
+
|
| 256 |
+
logging.error(traceback.format_exc())
|
| 257 |
+
results.append((config_idx + 1, False, False, False))
|
| 258 |
+
|
| 259 |
+
# Summary
|
| 260 |
+
logging.info("\n===== Test Results Summary =====")
|
| 261 |
+
for config_idx, overall_success, forward_success, backward_success in results:
|
| 262 |
+
overall_status = "✓ PASSED" if overall_success else "✗ FAILED"
|
| 263 |
+
forward_status = "✓ PASSED" if forward_success else "✗ FAILED"
|
| 264 |
+
backward_status = "✓ PASSED" if backward_success else "✗ FAILED"
|
| 265 |
+
|
| 266 |
+
logging.info(f"Config {config_idx}: {overall_status}")
|
| 267 |
+
logging.info(f" - Forward pass: {forward_status}")
|
| 268 |
+
logging.info(f" - Backward pass: {backward_status}")
|
| 269 |
+
|
| 270 |
+
return all(overall_success for _, overall_success, _, _ in results)
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
if __name__ == "__main__":
|
| 274 |
+
logging.info(
|
| 275 |
+
"Running verification for both forward and backward pass of M*G grouped GEMM"
|
| 276 |
+
)
|
| 277 |
+
|
| 278 |
+
# Run basic forward pass test
|
| 279 |
+
logging.info("\n===== Running basic forward pass test =====")
|
| 280 |
+
success_forward = test_forward_pass()
|
| 281 |
+
logging.info(f"Basic forward test {'succeeded' if success_forward else 'failed'}")
|
| 282 |
+
|
| 283 |
+
# Run basic backward pass test
|
| 284 |
+
logging.info("\n===== Running basic backward pass test =====")
|
| 285 |
+
success_backward = test_backward_pass()
|
| 286 |
+
logging.info(f"Basic backward test {'succeeded' if success_backward else 'failed'}")
|
| 287 |
+
|
| 288 |
+
# Run multiple DeepSeek configs with forward and backward verification
|
| 289 |
+
logging.info("\n===== Running tests for all DeepSeek configs =====")
|
| 290 |
+
success_configs = test_multiple_deepseek_configs()
|
| 291 |
+
logging.info(
|
| 292 |
+
f"DeepSeek configs tests {'all succeeded' if success_configs else 'had failures'}"
|
| 293 |
+
)
|
| 294 |
+
|
| 295 |
+
# Overall result
|
| 296 |
+
overall_success = success_forward and success_backward and success_configs
|
| 297 |
+
logging.info(
|
| 298 |
+
f"\nOverall test result: {'SUCCESS' if overall_success else 'FAILURE'}"
|
| 299 |
+
)
|
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/mg_grouped_gemm.py
ADDED
|
@@ -0,0 +1,1304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# credit - flat index forward kernel is derived from FBGemm:
|
| 8 |
+
# https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm
|
| 9 |
+
|
| 10 |
+
# pyre-unsafe
|
| 11 |
+
import functools
|
| 12 |
+
import logging
|
| 13 |
+
|
| 14 |
+
import os
|
| 15 |
+
import sys
|
| 16 |
+
from typing import Any, Dict, Optional, Tuple
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
import triton
|
| 21 |
+
import triton.language as tl
|
| 22 |
+
from triton import Config as TConfig
|
| 23 |
+
|
| 24 |
+
from triton.runtime import driver # @manual
|
| 25 |
+
|
| 26 |
+
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
| 27 |
+
|
| 28 |
+
from tma_autotuning import (
|
| 29 |
+
ALIGN_SIZE_M,
|
| 30 |
+
_NV_CONFIGS,
|
| 31 |
+
CudaUtils,
|
| 32 |
+
early_config_prune,
|
| 33 |
+
TmaDescriptorHelper,
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# Configure logging
|
| 38 |
+
logging.basicConfig(
|
| 39 |
+
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
# ============== Start Triton Kernels ===============
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
@triton.autotune(
|
| 46 |
+
configs=_NV_CONFIGS,
|
| 47 |
+
key=["G", "M_BUCKET", "N", "K"],
|
| 48 |
+
prune_configs_by={"early_config_prune": early_config_prune},
|
| 49 |
+
)
|
| 50 |
+
@triton.jit
|
| 51 |
+
def _kernel_mg_forward_hopper(
|
| 52 |
+
a_desc_ptr,
|
| 53 |
+
b_desc_ptr,
|
| 54 |
+
c_ptr,
|
| 55 |
+
workspace,
|
| 56 |
+
m_sizes,
|
| 57 |
+
# problem sizes
|
| 58 |
+
G: tl.constexpr,
|
| 59 |
+
M_BUCKET: tl.constexpr,
|
| 60 |
+
N: tl.constexpr,
|
| 61 |
+
K: tl.constexpr,
|
| 62 |
+
# config
|
| 63 |
+
NUM_SMS: tl.constexpr,
|
| 64 |
+
TMA_SIZE: tl.constexpr,
|
| 65 |
+
USE_EPILOGUE_SUBTILING: tl.constexpr,
|
| 66 |
+
# tiles
|
| 67 |
+
BLOCK_SIZE_M: tl.constexpr,
|
| 68 |
+
BLOCK_SIZE_N: tl.constexpr,
|
| 69 |
+
BLOCK_SIZE_K: tl.constexpr,
|
| 70 |
+
) -> None:
|
| 71 |
+
"""
|
| 72 |
+
Flat index style forward kernel for Hopper.
|
| 73 |
+
For simplicity, we always use TMA Load and TMA Store
|
| 74 |
+
"""
|
| 75 |
+
tbidx = tl.program_id(0) # thread block index
|
| 76 |
+
|
| 77 |
+
c_dtype = c_ptr.dtype.element_ty # output dtype
|
| 78 |
+
|
| 79 |
+
c_desc_ptr = workspace + (tbidx * TMA_SIZE) # for TMA Store
|
| 80 |
+
|
| 81 |
+
M_end = 0
|
| 82 |
+
M_start = 0
|
| 83 |
+
processed_tiles = 0
|
| 84 |
+
# Size of individual weight matrix
|
| 85 |
+
n_size = N // G
|
| 86 |
+
n_start = 0
|
| 87 |
+
|
| 88 |
+
for g in range(G):
|
| 89 |
+
# Move down along groups
|
| 90 |
+
# reset to new M offset
|
| 91 |
+
M_start = M_end
|
| 92 |
+
m_size = tl.load(m_sizes + g)
|
| 93 |
+
M_end = M_start + m_size
|
| 94 |
+
n_start = n_size * g
|
| 95 |
+
|
| 96 |
+
if m_size > 0:
|
| 97 |
+
# Process this group
|
| 98 |
+
|
| 99 |
+
# Acquire hold on c_desc_ptr for TMA Store
|
| 100 |
+
tl.extra.cuda.experimental_device_tensormap_create2d(
|
| 101 |
+
desc_ptr=c_desc_ptr,
|
| 102 |
+
global_address=c_ptr + M_start * n_size,
|
| 103 |
+
load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
|
| 104 |
+
global_size=[m_size, n_size],
|
| 105 |
+
element_ty=c_dtype,
|
| 106 |
+
)
|
| 107 |
+
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
|
| 108 |
+
|
| 109 |
+
# tiles for this group
|
| 110 |
+
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
|
| 111 |
+
num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
|
| 112 |
+
group_num_tiles = num_m_tiles * num_n_tiles
|
| 113 |
+
|
| 114 |
+
while tbidx >= processed_tiles and tbidx < (
|
| 115 |
+
processed_tiles + group_num_tiles
|
| 116 |
+
):
|
| 117 |
+
group_index = tbidx - processed_tiles
|
| 118 |
+
|
| 119 |
+
# columnwise
|
| 120 |
+
tile_m_index = group_index % num_m_tiles
|
| 121 |
+
tile_n_index = group_index // num_m_tiles
|
| 122 |
+
|
| 123 |
+
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
| 124 |
+
|
| 125 |
+
m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32)
|
| 126 |
+
n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)
|
| 127 |
+
global_n_offset = (n_start + n_offset).to(tl.int32)
|
| 128 |
+
|
| 129 |
+
for k_offset in range(0, K, BLOCK_SIZE_K):
|
| 130 |
+
# input block [M,K]
|
| 131 |
+
a = tl._experimental_descriptor_load(
|
| 132 |
+
a_desc_ptr,
|
| 133 |
+
[m_offset, k_offset],
|
| 134 |
+
[BLOCK_SIZE_M, BLOCK_SIZE_K],
|
| 135 |
+
c_dtype,
|
| 136 |
+
)
|
| 137 |
+
# weight block [N, K]
|
| 138 |
+
b = tl._experimental_descriptor_load(
|
| 139 |
+
b_desc_ptr,
|
| 140 |
+
[global_n_offset, k_offset],
|
| 141 |
+
[BLOCK_SIZE_N, BLOCK_SIZE_K],
|
| 142 |
+
c_dtype,
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
accumulator += tl.dot(a, b.T)
|
| 146 |
+
|
| 147 |
+
# Store using TMA
|
| 148 |
+
|
| 149 |
+
m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)
|
| 150 |
+
|
| 151 |
+
if USE_EPILOGUE_SUBTILING:
|
| 152 |
+
acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2))
|
| 153 |
+
acc = tl.permute(acc, (0, 2, 1))
|
| 154 |
+
acc0, acc1 = tl.split(acc)
|
| 155 |
+
c0 = acc0.to(c_dtype)
|
| 156 |
+
tl._experimental_descriptor_store(
|
| 157 |
+
c_desc_ptr, c0, [m_offset, n_offset]
|
| 158 |
+
)
|
| 159 |
+
c1 = acc1.to(c_dtype)
|
| 160 |
+
tl._experimental_descriptor_store(
|
| 161 |
+
c_desc_ptr, c1, [m_offset, n_offset + BLOCK_SIZE_N // 2]
|
| 162 |
+
)
|
| 163 |
+
else:
|
| 164 |
+
tl._experimental_descriptor_store(
|
| 165 |
+
c_desc_ptr,
|
| 166 |
+
accumulator.to(c_dtype),
|
| 167 |
+
[m_offset, n_offset],
|
| 168 |
+
)
|
| 169 |
+
# move to next tile in group
|
| 170 |
+
tbidx += NUM_SMS
|
| 171 |
+
# Update the total tiles count for the next group
|
| 172 |
+
processed_tiles += group_num_tiles
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
@triton.autotune(
|
| 176 |
+
configs=_NV_CONFIGS,
|
| 177 |
+
key=["G", "M_BUCKET", "N", "K"],
|
| 178 |
+
prune_configs_by={"early_config_prune": early_config_prune},
|
| 179 |
+
)
|
| 180 |
+
@triton.jit
|
| 181 |
+
def _kernel_mg_forward_tma(
|
| 182 |
+
a_desc_ptr,
|
| 183 |
+
b_desc_ptr,
|
| 184 |
+
c_ptr,
|
| 185 |
+
workspace,
|
| 186 |
+
m_sizes,
|
| 187 |
+
a_scale_ptr,
|
| 188 |
+
b_scale_ptr,
|
| 189 |
+
# problem sizes
|
| 190 |
+
G: tl.constexpr,
|
| 191 |
+
M_BUCKET: tl.constexpr,
|
| 192 |
+
N: tl.constexpr,
|
| 193 |
+
K: tl.constexpr,
|
| 194 |
+
# config
|
| 195 |
+
NUM_SMS: tl.constexpr,
|
| 196 |
+
USE_TMA_LOAD: tl.constexpr,
|
| 197 |
+
USE_TMA_STORE: tl.constexpr,
|
| 198 |
+
TMA_SIZE: tl.constexpr,
|
| 199 |
+
USE_FP8: tl.constexpr,
|
| 200 |
+
# tiles
|
| 201 |
+
BLOCK_SIZE_M: tl.constexpr,
|
| 202 |
+
BLOCK_SIZE_N: tl.constexpr,
|
| 203 |
+
BLOCK_SIZE_K: tl.constexpr,
|
| 204 |
+
) -> None:
|
| 205 |
+
"""
|
| 206 |
+
Flat index style forward kernel.
|
| 207 |
+
For simplicity, we always use TMA Load and TMA Store
|
| 208 |
+
"""
|
| 209 |
+
tbidx = tl.program_id(0) # thread block index
|
| 210 |
+
|
| 211 |
+
c_dtype = c_ptr.dtype.element_ty
|
| 212 |
+
|
| 213 |
+
c_desc_ptr = workspace + (tbidx * TMA_SIZE)
|
| 214 |
+
|
| 215 |
+
M_end = 0
|
| 216 |
+
processed_tiles = 0
|
| 217 |
+
|
| 218 |
+
for g in range(G):
|
| 219 |
+
# Move down along groups
|
| 220 |
+
# reset to new M offset
|
| 221 |
+
M_start = M_end
|
| 222 |
+
m_size = tl.load(m_sizes + g)
|
| 223 |
+
M_end = M_start + m_size
|
| 224 |
+
|
| 225 |
+
if m_size > 0:
|
| 226 |
+
# Process this group
|
| 227 |
+
n_size = N
|
| 228 |
+
|
| 229 |
+
# TMA Store prep
|
| 230 |
+
tl.extra.cuda.experimental_device_tensormap_create2d(
|
| 231 |
+
desc_ptr=c_desc_ptr,
|
| 232 |
+
global_address=c_ptr + M_start * N,
|
| 233 |
+
load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
|
| 234 |
+
global_size=[m_size, n_size],
|
| 235 |
+
element_ty=c_dtype,
|
| 236 |
+
)
|
| 237 |
+
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
|
| 238 |
+
|
| 239 |
+
# tiles for this group
|
| 240 |
+
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
|
| 241 |
+
num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
|
| 242 |
+
group_num_tiles = num_m_tiles * num_n_tiles
|
| 243 |
+
|
| 244 |
+
while tbidx >= processed_tiles and tbidx < (
|
| 245 |
+
processed_tiles + group_num_tiles
|
| 246 |
+
):
|
| 247 |
+
group_index = tbidx - processed_tiles
|
| 248 |
+
|
| 249 |
+
tile_m_index = group_index % num_m_tiles
|
| 250 |
+
tile_n_index = group_index // num_m_tiles
|
| 251 |
+
|
| 252 |
+
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
| 253 |
+
|
| 254 |
+
m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32)
|
| 255 |
+
n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)
|
| 256 |
+
|
| 257 |
+
for k_offset in range(0, K, BLOCK_SIZE_K):
|
| 258 |
+
# input block [M,K]
|
| 259 |
+
a = tl._experimental_descriptor_load(
|
| 260 |
+
a_desc_ptr,
|
| 261 |
+
[m_offset, k_offset],
|
| 262 |
+
[BLOCK_SIZE_M, BLOCK_SIZE_K],
|
| 263 |
+
c_dtype,
|
| 264 |
+
)
|
| 265 |
+
# weight block [N, K]
|
| 266 |
+
b = tl._experimental_descriptor_load(
|
| 267 |
+
b_desc_ptr,
|
| 268 |
+
[n_offset, k_offset],
|
| 269 |
+
[BLOCK_SIZE_N, BLOCK_SIZE_K],
|
| 270 |
+
c_dtype,
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
accumulator += tl.dot(a, b.T)
|
| 274 |
+
|
| 275 |
+
# Store using TMA
|
| 276 |
+
|
| 277 |
+
m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)
|
| 278 |
+
# n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)
|
| 279 |
+
|
| 280 |
+
tl._experimental_descriptor_store(
|
| 281 |
+
c_desc_ptr,
|
| 282 |
+
accumulator.to(c_dtype),
|
| 283 |
+
[m_offset, n_offset],
|
| 284 |
+
)
|
| 285 |
+
|
| 286 |
+
# Move to the next tile
|
| 287 |
+
tbidx += NUM_SMS
|
| 288 |
+
# Update the total tiles count for the next group
|
| 289 |
+
processed_tiles += group_num_tiles
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
@triton.autotune(
|
| 293 |
+
configs=_NV_CONFIGS,
|
| 294 |
+
key=["G", "M_BUCKET", "N", "K"],
|
| 295 |
+
prune_configs_by={"early_config_prune": early_config_prune},
|
| 296 |
+
)
|
| 297 |
+
@triton.jit
|
| 298 |
+
def _kernel_mg_forward_no_tma(
|
| 299 |
+
a_ptr,
|
| 300 |
+
b_ptr,
|
| 301 |
+
c_ptr,
|
| 302 |
+
workspace,
|
| 303 |
+
m_sizes,
|
| 304 |
+
# problem sizes
|
| 305 |
+
G: tl.constexpr,
|
| 306 |
+
M_BUCKET: tl.constexpr,
|
| 307 |
+
N: tl.constexpr,
|
| 308 |
+
K: tl.constexpr,
|
| 309 |
+
# config
|
| 310 |
+
NUM_SMS: tl.constexpr,
|
| 311 |
+
USE_TMA_LOAD: tl.constexpr,
|
| 312 |
+
USE_TMA_STORE: tl.constexpr,
|
| 313 |
+
TMA_SIZE: tl.constexpr,
|
| 314 |
+
# tiles
|
| 315 |
+
BLOCK_SIZE_M: tl.constexpr,
|
| 316 |
+
BLOCK_SIZE_N: tl.constexpr,
|
| 317 |
+
BLOCK_SIZE_K: tl.constexpr,
|
| 318 |
+
) -> None:
|
| 319 |
+
"""
|
| 320 |
+
Flat index style forward kernel.
|
| 321 |
+
For bc and Ampere, we never use TMA Load and TMA Store
|
| 322 |
+
"""
|
| 323 |
+
tbidx = tl.program_id(0) # thread block index
|
| 324 |
+
|
| 325 |
+
c_dtype = c_ptr.dtype.element_ty
|
| 326 |
+
c_desc_ptr = None
|
| 327 |
+
|
| 328 |
+
M_end = 0
|
| 329 |
+
processed_tiles = 0
|
| 330 |
+
|
| 331 |
+
for g in range(G):
|
| 332 |
+
# Move down along groups
|
| 333 |
+
# reset to new M offset
|
| 334 |
+
M_start = M_end
|
| 335 |
+
m_size = tl.load(m_sizes + g)
|
| 336 |
+
M_end = M_start + m_size
|
| 337 |
+
|
| 338 |
+
if m_size > 0:
|
| 339 |
+
# Process this group
|
| 340 |
+
n_size = N
|
| 341 |
+
|
| 342 |
+
# tiles for this group
|
| 343 |
+
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
|
| 344 |
+
num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
|
| 345 |
+
group_num_tiles = num_m_tiles * num_n_tiles
|
| 346 |
+
|
| 347 |
+
while tbidx >= processed_tiles and tbidx < (
|
| 348 |
+
processed_tiles + group_num_tiles
|
| 349 |
+
):
|
| 350 |
+
group_index = tbidx - processed_tiles
|
| 351 |
+
|
| 352 |
+
tile_m_index = group_index % num_m_tiles
|
| 353 |
+
tile_n_index = group_index // num_m_tiles
|
| 354 |
+
|
| 355 |
+
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
| 356 |
+
|
| 357 |
+
m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32)
|
| 358 |
+
n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)
|
| 359 |
+
|
| 360 |
+
offs_am = tile_m_index * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
| 361 |
+
offs_bn = tile_n_index * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 362 |
+
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
| 363 |
+
|
| 364 |
+
a_ptrs = a_ptr + (M_start + offs_am[:, None]) * K + offs_k[None, :]
|
| 365 |
+
b_ptrs = b_ptr + (offs_bn[:, None]) * K + offs_k[None, :]
|
| 366 |
+
|
| 367 |
+
for k_offset in range(0, K, BLOCK_SIZE_K):
|
| 368 |
+
# Load with bounds checking
|
| 369 |
+
a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size)
|
| 370 |
+
b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size)
|
| 371 |
+
|
| 372 |
+
# Main matmul
|
| 373 |
+
accumulator += tl.dot(a, b.T)
|
| 374 |
+
|
| 375 |
+
# Update pointers for next block
|
| 376 |
+
a_ptrs += BLOCK_SIZE_K
|
| 377 |
+
b_ptrs += BLOCK_SIZE_K
|
| 378 |
+
|
| 379 |
+
# Store without TMA
|
| 380 |
+
offs_am = tile_m_index * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
| 381 |
+
offs_bn = tile_n_index * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 382 |
+
|
| 383 |
+
c = accumulator.to(c_dtype)
|
| 384 |
+
|
| 385 |
+
tl.store(
|
| 386 |
+
c_ptr
|
| 387 |
+
+ (M_start + offs_am[:, None]) * N # Row stride is N
|
| 388 |
+
+ offs_bn[None, :], # Column offset
|
| 389 |
+
c,
|
| 390 |
+
mask=offs_am[:, None] < m_size and offs_bn[None, :] < n_size,
|
| 391 |
+
)
|
| 392 |
+
# Move to the next tile
|
| 393 |
+
tbidx += NUM_SMS
|
| 394 |
+
# Update the total tiles count for the next group
|
| 395 |
+
processed_tiles += group_num_tiles
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
"""
|
| 399 |
+
Backward pass for grouped GEMM with Triton, where grouping is M*G
|
| 400 |
+
We compute gradients with respect to both input (`grad_x`) and weights (`grad_w`).
|
| 401 |
+
"""
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
# ---- dx flat linear indexed ----
|
| 405 |
+
@triton.autotune(
|
| 406 |
+
configs=_NV_CONFIGS,
|
| 407 |
+
key=["G", "M_BUCKET", "N", "K"],
|
| 408 |
+
prune_configs_by={"early_config_prune": early_config_prune},
|
| 409 |
+
)
|
| 410 |
+
@triton.jit
|
| 411 |
+
def _kernel_mg_dx_tma(
|
| 412 |
+
grad_output_desc_ptr, # [MG, N]
|
| 413 |
+
w_desc_ptr, # [N, K]
|
| 414 |
+
grad_input_ptr, # output grad_x [MG, K]
|
| 415 |
+
workspace, # for TMA store
|
| 416 |
+
m_sizes, # group sizes [G]
|
| 417 |
+
# problem sizes
|
| 418 |
+
G: tl.constexpr,
|
| 419 |
+
M_BUCKET: tl.constexpr,
|
| 420 |
+
N: tl.constexpr,
|
| 421 |
+
K: tl.constexpr,
|
| 422 |
+
# config
|
| 423 |
+
NUM_SMS: tl.constexpr,
|
| 424 |
+
USE_TMA_LOAD: tl.constexpr,
|
| 425 |
+
USE_TMA_STORE: tl.constexpr,
|
| 426 |
+
TMA_SIZE: tl.constexpr,
|
| 427 |
+
# tiles
|
| 428 |
+
BLOCK_SIZE_M: tl.constexpr,
|
| 429 |
+
BLOCK_SIZE_N: tl.constexpr,
|
| 430 |
+
BLOCK_SIZE_K: tl.constexpr,
|
| 431 |
+
) -> None:
|
| 432 |
+
"""
|
| 433 |
+
TMA-optimized kernel for computing gradients with respect to input (dx).
|
| 434 |
+
For the forward pass Y = X @ W.T, the backward for input is:
|
| 435 |
+
grad_X = grad_Y @ W
|
| 436 |
+
|
| 437 |
+
This maps to [MG, N] @ [N, K] -> [MG, K]
|
| 438 |
+
|
| 439 |
+
Key differences from forward:
|
| 440 |
+
1. W is used directly and not transposed
|
| 441 |
+
2. The reduction dimension is now N (not K)
|
| 442 |
+
3. Output is [M, K] instead of [M, N]
|
| 443 |
+
"""
|
| 444 |
+
tbidx = tl.program_id(0) # thread block index
|
| 445 |
+
|
| 446 |
+
c_dtype = grad_input_ptr.dtype.element_ty
|
| 447 |
+
c_desc_ptr = workspace + (tbidx * TMA_SIZE)
|
| 448 |
+
|
| 449 |
+
M_end = 0
|
| 450 |
+
processed_tiles = 0
|
| 451 |
+
|
| 452 |
+
for g in range(G):
|
| 453 |
+
# Move down along groups - same as forward
|
| 454 |
+
M_start = M_end
|
| 455 |
+
m_size = tl.load(m_sizes + g)
|
| 456 |
+
M_end = M_start + m_size
|
| 457 |
+
|
| 458 |
+
if m_size > 0:
|
| 459 |
+
# Process this group
|
| 460 |
+
# tiles for this group - now producing [M, K] output
|
| 461 |
+
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
|
| 462 |
+
num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
|
| 463 |
+
group_num_tiles = num_m_tiles * num_k_tiles
|
| 464 |
+
|
| 465 |
+
# TMA Store prep for [M, K] output
|
| 466 |
+
tl.extra.cuda.experimental_device_tensormap_create2d(
|
| 467 |
+
desc_ptr=c_desc_ptr,
|
| 468 |
+
global_address=grad_input_ptr + M_start * K,
|
| 469 |
+
load_size=[BLOCK_SIZE_M, BLOCK_SIZE_K],
|
| 470 |
+
global_size=[m_size, K],
|
| 471 |
+
element_ty=c_dtype,
|
| 472 |
+
)
|
| 473 |
+
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
|
| 474 |
+
|
| 475 |
+
while tbidx >= processed_tiles and tbidx < (
|
| 476 |
+
processed_tiles + group_num_tiles
|
| 477 |
+
):
|
| 478 |
+
group_index = tbidx - processed_tiles
|
| 479 |
+
|
| 480 |
+
# Different tiling scheme for [M, K] output
|
| 481 |
+
tile_m_index = group_index % num_m_tiles
|
| 482 |
+
tile_k_index = group_index // num_m_tiles
|
| 483 |
+
|
| 484 |
+
# for grad_input block [M, K]
|
| 485 |
+
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
|
| 486 |
+
|
| 487 |
+
# Position in full matrix
|
| 488 |
+
m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32)
|
| 489 |
+
k_offset = (tile_k_index * BLOCK_SIZE_K).to(tl.int32)
|
| 490 |
+
|
| 491 |
+
# reduce along N dimension (instead of K in forward)
|
| 492 |
+
for n_offset in range(0, N, BLOCK_SIZE_N):
|
| 493 |
+
# grad_output block [M, N]
|
| 494 |
+
grad_output = tl._experimental_descriptor_load(
|
| 495 |
+
grad_output_desc_ptr,
|
| 496 |
+
[m_offset, n_offset],
|
| 497 |
+
[BLOCK_SIZE_M, BLOCK_SIZE_N],
|
| 498 |
+
c_dtype,
|
| 499 |
+
)
|
| 500 |
+
|
| 501 |
+
# weight block [N, K] - no transpose needed
|
| 502 |
+
w = tl._experimental_descriptor_load(
|
| 503 |
+
w_desc_ptr,
|
| 504 |
+
[n_offset, k_offset],
|
| 505 |
+
[BLOCK_SIZE_N, BLOCK_SIZE_K],
|
| 506 |
+
c_dtype,
|
| 507 |
+
)
|
| 508 |
+
|
| 509 |
+
# grad_x = grad_output @ w
|
| 510 |
+
# reducing along N dimension
|
| 511 |
+
accumulator += tl.dot(grad_output, w)
|
| 512 |
+
|
| 513 |
+
# Store using TMA
|
| 514 |
+
m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)
|
| 515 |
+
# k_offset = (tile_k_index * BLOCK_SIZE_K).to(tl.int32)
|
| 516 |
+
|
| 517 |
+
tl._experimental_descriptor_store(
|
| 518 |
+
c_desc_ptr,
|
| 519 |
+
accumulator.to(c_dtype),
|
| 520 |
+
[m_offset, k_offset],
|
| 521 |
+
)
|
| 522 |
+
|
| 523 |
+
# Move to the next tile
|
| 524 |
+
tbidx += NUM_SMS
|
| 525 |
+
|
| 526 |
+
# Update the total tiles count for the next group
|
| 527 |
+
processed_tiles += group_num_tiles
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
# ---- dw flat linear indexed ----
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
@triton.autotune(
|
| 534 |
+
configs=_NV_CONFIGS,
|
| 535 |
+
key=["G", "M_BUCKET", "N", "K"],
|
| 536 |
+
prune_configs_by={"early_config_prune": early_config_prune},
|
| 537 |
+
)
|
| 538 |
+
@triton.jit
|
| 539 |
+
def _kernel_mg_dw_tma(
|
| 540 |
+
x_desc_ptr, # input descriptor [M_total, K]
|
| 541 |
+
grad_output_desc_ptr, # grad_output descriptor [M_total, N]
|
| 542 |
+
grad_weight_ptr, # output grad_w [N, K]
|
| 543 |
+
workspace, # workspace for TMA store
|
| 544 |
+
m_sizes, # group sizes [G]
|
| 545 |
+
# problem sizes
|
| 546 |
+
G: tl.constexpr,
|
| 547 |
+
M_BUCKET: tl.constexpr,
|
| 548 |
+
N: tl.constexpr,
|
| 549 |
+
K: tl.constexpr,
|
| 550 |
+
# config
|
| 551 |
+
NUM_SMS: tl.constexpr,
|
| 552 |
+
USE_TMA_LOAD: tl.constexpr,
|
| 553 |
+
USE_TMA_STORE: tl.constexpr,
|
| 554 |
+
TMA_SIZE: tl.constexpr,
|
| 555 |
+
# tiles
|
| 556 |
+
BLOCK_SIZE_N: tl.constexpr,
|
| 557 |
+
BLOCK_SIZE_K: tl.constexpr,
|
| 558 |
+
BLOCK_SIZE_M: tl.constexpr, # block size for reduction dimension
|
| 559 |
+
) -> None:
|
| 560 |
+
"""
|
| 561 |
+
Improved TMA-optimized kernel for computing gradients with respect to weights (dw).
|
| 562 |
+
Uses flat index structure similar to forward.
|
| 563 |
+
|
| 564 |
+
For the forward pass Y = X @ W.T,
|
| 565 |
+
the backward for weights is:
|
| 566 |
+
grad_W = grad_Y.T @ X
|
| 567 |
+
|
| 568 |
+
Where:
|
| 569 |
+
- grad_Y is [MG, N]
|
| 570 |
+
- X is [MG, K]
|
| 571 |
+
- grad_W is [N, K]
|
| 572 |
+
- we return [N,K]
|
| 573 |
+
"""
|
| 574 |
+
# Get thread block index l
|
| 575 |
+
tbidx = tl.program_id(0)
|
| 576 |
+
|
| 577 |
+
# Get output data type
|
| 578 |
+
c_dtype = grad_weight_ptr.dtype.element_ty
|
| 579 |
+
|
| 580 |
+
# Calculate number of output tiles
|
| 581 |
+
num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
|
| 582 |
+
num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
|
| 583 |
+
total_output_tiles = num_n_tiles * num_k_tiles
|
| 584 |
+
|
| 585 |
+
# Process tiles in strided manner across SMs
|
| 586 |
+
for tile_idx in range(tbidx, total_output_tiles, NUM_SMS):
|
| 587 |
+
# Calculate tile indices
|
| 588 |
+
tile_n_idx = tile_idx % num_n_tiles
|
| 589 |
+
tile_k_idx = tile_idx // num_n_tiles
|
| 590 |
+
|
| 591 |
+
# Calculate global offsets
|
| 592 |
+
n_offset = tile_n_idx * BLOCK_SIZE_N
|
| 593 |
+
k_offset = tile_k_idx * BLOCK_SIZE_K
|
| 594 |
+
|
| 595 |
+
# Initialize accumulator for this output tile [N, K]
|
| 596 |
+
accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype=tl.float32)
|
| 597 |
+
|
| 598 |
+
# Process each group
|
| 599 |
+
M_end = 0
|
| 600 |
+
for g in range(G):
|
| 601 |
+
# Get group boundaries
|
| 602 |
+
M_start = M_end
|
| 603 |
+
m_size = tl.load(m_sizes + g)
|
| 604 |
+
M_end = M_start + m_size
|
| 605 |
+
|
| 606 |
+
# Only process if group is non-empty
|
| 607 |
+
if m_size > 0:
|
| 608 |
+
# Process this group in chunks along the M dimension
|
| 609 |
+
for m_offset in range(0, m_size, BLOCK_SIZE_M):
|
| 610 |
+
# Calculate actual block size (handling boundary)
|
| 611 |
+
m_block_size = tl.minimum(BLOCK_SIZE_M, m_size - m_offset)
|
| 612 |
+
|
| 613 |
+
# Only process if we have actual work to do
|
| 614 |
+
if m_block_size > 0:
|
| 615 |
+
# Global offset for this chunk
|
| 616 |
+
m_global_offset = M_start + m_offset
|
| 617 |
+
|
| 618 |
+
if USE_TMA_LOAD:
|
| 619 |
+
# Load input chunk [M_chunk, K] using TMA
|
| 620 |
+
x_block = tl._experimental_descriptor_load(
|
| 621 |
+
x_desc_ptr,
|
| 622 |
+
[m_global_offset, k_offset],
|
| 623 |
+
[BLOCK_SIZE_M, BLOCK_SIZE_K],
|
| 624 |
+
c_dtype,
|
| 625 |
+
)
|
| 626 |
+
|
| 627 |
+
# Load grad_output chunk [M_chunk, N] using TMA
|
| 628 |
+
grad_output_block = tl._experimental_descriptor_load(
|
| 629 |
+
grad_output_desc_ptr,
|
| 630 |
+
[m_global_offset, n_offset],
|
| 631 |
+
[BLOCK_SIZE_M, BLOCK_SIZE_N],
|
| 632 |
+
c_dtype,
|
| 633 |
+
)
|
| 634 |
+
|
| 635 |
+
# Apply masks for valid regions
|
| 636 |
+
offs_m = tl.arange(0, BLOCK_SIZE_M)
|
| 637 |
+
m_mask = offs_m < m_block_size
|
| 638 |
+
|
| 639 |
+
# Zero out invalid elements
|
| 640 |
+
x_block = tl.where(m_mask[:, None], x_block, 0.0)
|
| 641 |
+
grad_output_block = tl.where(
|
| 642 |
+
m_mask[:, None], grad_output_block, 0.0
|
| 643 |
+
)
|
| 644 |
+
else:
|
| 645 |
+
# Manual load with bounds checking
|
| 646 |
+
offs_m = tl.arange(0, BLOCK_SIZE_M)
|
| 647 |
+
offs_n = tl.arange(0, BLOCK_SIZE_N)
|
| 648 |
+
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
| 649 |
+
|
| 650 |
+
# Create masks
|
| 651 |
+
m_mask = offs_m < m_block_size
|
| 652 |
+
n_mask = offs_n < N - n_offset
|
| 653 |
+
k_mask = offs_k < K - k_offset
|
| 654 |
+
|
| 655 |
+
# Combined masks
|
| 656 |
+
mk_mask = m_mask[:, None] & k_mask[None, :]
|
| 657 |
+
mn_mask = m_mask[:, None] & n_mask[None, :]
|
| 658 |
+
|
| 659 |
+
# Global offsets for loading
|
| 660 |
+
m_global_offs = m_global_offset + offs_m
|
| 661 |
+
|
| 662 |
+
# Load x block [M_chunk, K]
|
| 663 |
+
x_block = tl.load(
|
| 664 |
+
x_desc_ptr
|
| 665 |
+
+ m_global_offs[:, None] * K
|
| 666 |
+
+ (k_offset + offs_k)[None, :],
|
| 667 |
+
mask=mk_mask,
|
| 668 |
+
other=0.0,
|
| 669 |
+
)
|
| 670 |
+
|
| 671 |
+
# Load grad_output block [M_chunk, N]
|
| 672 |
+
grad_output_block = tl.load(
|
| 673 |
+
grad_output_desc_ptr
|
| 674 |
+
+ m_global_offs[:, None] * N
|
| 675 |
+
+ (n_offset + offs_n)[None, :],
|
| 676 |
+
mask=mn_mask,
|
| 677 |
+
other=0.0,
|
| 678 |
+
)
|
| 679 |
+
|
| 680 |
+
# Compute partial contribution: grad_W += grad_Y.T @ X
|
| 681 |
+
# transpose grad_output for the matmul
|
| 682 |
+
contribution = tl.dot(
|
| 683 |
+
grad_output_block.to(tl.float32).T, # [N, M_chunk]
|
| 684 |
+
x_block.to(tl.float32), # [M_chunk, K]
|
| 685 |
+
)
|
| 686 |
+
|
| 687 |
+
# Accumulate
|
| 688 |
+
accumulator += contribution
|
| 689 |
+
|
| 690 |
+
# Store the result
|
| 691 |
+
if USE_TMA_STORE:
|
| 692 |
+
# Store using TMA
|
| 693 |
+
tl._experimental_descriptor_store(
|
| 694 |
+
workspace, # TMA store descriptor
|
| 695 |
+
accumulator.to(c_dtype),
|
| 696 |
+
[n_offset, k_offset],
|
| 697 |
+
)
|
| 698 |
+
else:
|
| 699 |
+
# Manual store with bounds checking
|
| 700 |
+
offs_n = tl.arange(0, BLOCK_SIZE_N)
|
| 701 |
+
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
| 702 |
+
|
| 703 |
+
# Create masks for bounds checking
|
| 704 |
+
n_mask = offs_n < N - n_offset
|
| 705 |
+
k_mask = offs_k < K - k_offset
|
| 706 |
+
output_mask = n_mask[:, None] & k_mask[None, :]
|
| 707 |
+
|
| 708 |
+
# Store the result
|
| 709 |
+
tl.store(
|
| 710 |
+
grad_weight_ptr
|
| 711 |
+
+ (n_offset + offs_n)[:, None] * K
|
| 712 |
+
+ (k_offset + offs_k)[None, :],
|
| 713 |
+
accumulator.to(c_dtype),
|
| 714 |
+
mask=output_mask,
|
| 715 |
+
)
|
| 716 |
+
|
| 717 |
+
|
| 718 |
+
# ======== End Triton kernels ========
|
| 719 |
+
|
| 720 |
+
# ======== Triton wrapper functions ========
|
| 721 |
+
|
| 722 |
+
# ----- main forward pass wrapper -----
|
| 723 |
+
|
| 724 |
+
|
| 725 |
+
def grouped_gemm_forward(
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    tma_size: int = 128,
) -> torch.Tensor:
    """
    M*G style grouped GEMM with TMA and Float8 support.
    # Removed for now - FP8 support is triggered by passing x_scale and w_scale tensors.

    Args:
        x: Input activations, shape [M_total, K]; rows of all groups concatenated.
        w: Weight tensor, shape [N, K].
        m_sizes: Per-group row counts, shape [G]; must sum to M_total.
        tma_size: Size of each TMA descriptor in bytes.

    Returns:
        Output tensor produced by the Hopper forward kernel.
        NOTE(review): allocated as [M_total, N // G] although the docstring text
        and the kernel launch below use N — confirm the intended output layout.
    """
    # TMA (Hopper+) is mandatory for this path; there is no non-TMA fallback.
    if not CudaUtils.verify_tma():
        raise NotImplementedError("Grouped GEMM without TMA is not supported yet")

    G = m_sizes.shape[0]

    # TMA descriptors describe flat, contiguous 2D layouts.
    assert x.is_contiguous()
    assert w.is_contiguous()
    assert m_sizes.is_contiguous()

    # Total input size is now [M_total, K] where M_total is the sum of all group sizes
    M_total, K = x.shape
    N = w.shape[0]  # N is now the same for all groups

    assert K == w.shape[1], f"Input K ({K}) must match weight K ({w.shape[1]})"

    # Verify that all group sizes are multiples of ALIGN_SIZE_M
    # This check is commented out because it will involve a GPU-CPU sync
    # assert torch.remainder(m_sizes, ALIGN_SIZE_M).max() == 0, "Group sizes must be a multiple of ALIGN_SIZE_M"

    # Create output tensor with correct shape [M_total, N]
    # NOTE(review): see docstring — N // G here vs N in the kernel launch.
    y = torch.empty((M_total, N // G), device=x.device, dtype=x.dtype)

    # Nothing to compute for an empty batch; skip descriptor setup and launch.
    if M_total == 0:
        return y

    NUM_SMS = CudaUtils.get_num_sms()
    USE_TMA_LOAD = True
    USE_TMA_STORE = True
    USE_EPILOGUE_SUBTILING = False

    # TMA descriptor helper
    desc_helper = None
    desc_x = x
    desc_w = w
    workspace = None

    if USE_TMA_LOAD:
        desc_helper = TmaDescriptorHelper(tma_size=tma_size)
        desc_helper.init_tma_descriptor("x")
        desc_helper.init_tma_descriptor("w")
        desc_x = desc_helper.get_tma_descriptor_kernel_param("x")
        desc_w = desc_helper.get_tma_descriptor_kernel_param("w")

    if USE_TMA_STORE:
        # One descriptor-sized byte slot per SM for the persistent-kernel store.
        # NOTE(review): assumes USE_TMA_LOAD created desc_helper above;
        # with USE_TMA_LOAD=False this would dereference None — confirm.
        workspace = torch.empty(
            NUM_SMS * desc_helper.tma_size,
            device=x.device,
            dtype=torch.uint8,
        )

    def grid(META):
        # Grid callback: fill descriptors only after autotuning has chosen
        # BLOCK_SIZE_* — descriptors encode the tile shape.
        if USE_TMA_LOAD:
            nonlocal desc_helper
            desc_helper.fill_2d_tma_descriptor(
                "x",
                x.data_ptr(),
                M_total,
                K,
                META["BLOCK_SIZE_M"],
                META["BLOCK_SIZE_K"],
                x.element_size(),
            )

            desc_helper.fill_2d_tma_descriptor(
                "w",
                w.data_ptr(),
                N,
                K,
                META["BLOCK_SIZE_N"],
                META["BLOCK_SIZE_K"],
                w.element_size(),
            )
        # Persistent kernel: one program per SM.
        return (NUM_SMS,)

    # Power-of-two bucket of M_total keeps autotuning cache keys stable.
    M_BUCKET = triton.next_power_of_2(M_total)

    _kernel_mg_forward_hopper[grid](
        desc_x,
        desc_w,
        y,
        workspace,
        m_sizes,
        G,
        M_BUCKET,
        N,
        K,
        NUM_SMS,
        TMA_SIZE=tma_size,
        USE_EPILOGUE_SUBTILING=USE_EPILOGUE_SUBTILING,
    )

    return y
|
| 828 |
+
|
| 829 |
+
|
| 830 |
+
# ======== Improved Backward =============
|
| 831 |
+
def grouped_gemm_backward(
    grad_output: torch.Tensor,
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    use_tma: bool = True,
    tma_size: int = 128,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Unified backward pass for grouped GEMM with M*G grouping.
    Uses optimized TMA-based implementations for both dx and dw when available.

    Args:
        grad_output: Gradient of output, shape [M_total, N]
        x: Input tensor from forward pass, shape [M_total, K]
        w: Weight tensor from forward pass, shape [N, K]
        m_sizes: Group sizes tensor, shape [G]
        use_tma: Whether to try using TMA acceleration (if available)
        tma_size: Size of TMA descriptor in bytes

    Returns:
        Tuple of gradients with respect to x and w: (grad_x, grad_w)

    Raises:
        ValueError: on any shape mismatch between the inputs.
    """
    logging.info("Starting unified grouped_gemm_backward")

    # Query the device once up front; the property lookup is comparatively
    # expensive and both sub-kernels need the same SM count.
    NUM_SMS = CudaUtils.get_num_sms()

    # Basic validation
    M_total, K_x = x.shape
    M_grad, N = grad_output.shape
    N_w, K_w = w.shape

    # Check dimensions
    if K_x != K_w:
        raise ValueError(f"K dimension mismatch: x has K={K_x}, w has K={K_w}")
    if M_total != M_grad:
        raise ValueError(
            f"M dimension mismatch: x has M={M_total}, grad_output has M={M_grad}"
        )

    # Check total M matches sum of group sizes (note: .item() forces a
    # GPU->CPU sync).
    sum_m_sizes = m_sizes.sum().item()
    if M_total != sum_m_sizes:
        raise ValueError(
            f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})"
        )

    # Make sure inputs are contiguous (required for TMA descriptors).
    grad_output = grad_output.contiguous()
    x = x.contiguous()
    w = w.contiguous()
    m_sizes = m_sizes.contiguous()

    # Check TMA support
    can_use_tma = use_tma and CudaUtils.verify_tma()
    if use_tma and not can_use_tma:
        logging.info("TMA requested but not supported on this device")
        use_tma = False

    # Compute grad_x using flat linear implementation.
    # (Fix: the former f-strings carried no placeholders; use lazy %-style
    # args so formatting is skipped when the level is disabled.)
    try:
        logging.info("Computing grad_x with flat linear kernel")

        # Use TMA-optimized implementation
        grad_x = grouped_gemm_dx_tma(
            grad_output=grad_output,
            w=w,
            m_sizes=m_sizes,
            num_sms=NUM_SMS,
            tma_size=tma_size,
        )

    except Exception as e:
        logging.error("Error in grad_x computation: %s", e)
        raise

    # Compute grad_w using flat linear style implementation
    try:
        logging.info("Computing grad_w with flat linear kernel")

        grad_w = grouped_gemm_dw_tma(
            x, grad_output, m_sizes, num_sms=NUM_SMS, tma_size=tma_size
        )
    except Exception as e:
        logging.error("Error in grad_w computation: %s", e)
        raise

    return grad_x, grad_w
|
| 922 |
+
|
| 923 |
+
|
| 924 |
+
# ----- dx backward pass wrapper -----
|
| 925 |
+
|
| 926 |
+
|
| 927 |
+
def grouped_gemm_dx_tma(
    grad_output: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    num_sms: int = 132,
    tma_size: int = 128,
) -> torch.Tensor:
    """
    Optimized backward pass for computing the gradient with respect to the
    input (dx) using TMA patterns similar to the forward pass.

    For the forward pass Y = X @ W.T, the input gradient is grad_X = grad_Y @ W.
    (Fix: the original carried two consecutive near-identical docstrings; the
    second was a dead string-expression statement and has been removed.)

    Args:
        grad_output: Gradient of output, shape [M_total, N]
        w: Weight tensor, shape [N, K]
        m_sizes: Group sizes tensor, shape [G]
        num_sms: Number of persistent programs to launch (one per SM)
        tma_size: Size of TMA descriptor in bytes

    Returns:
        grad_x: Gradient with respect to x, shape [M_total, K]

    Raises:
        NotImplementedError: if the device lacks TMA support.
    """
    if not CudaUtils.verify_tma():
        raise NotImplementedError("Optimized dx computation requires TMA support")

    G = m_sizes.shape[0]

    # TMA descriptors require flat, contiguous layouts.
    assert grad_output.is_contiguous()
    assert w.is_contiguous()
    assert m_sizes.is_contiguous()

    M_total, N_grad = grad_output.shape
    N_w, K = w.shape

    # Check dimensions
    assert N_grad == N_w, f"Grad_output N ({N_grad}) must match weight N ({N_w})"

    # Verify that the sum of m_sizes matches M_total (note: GPU->CPU sync).
    sum_m_sizes = m_sizes.sum().item()
    assert (
        M_total == sum_m_sizes
    ), f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})"

    # Create output tensor (grad_x) with shape [M_total, K]
    grad_x = torch.empty(
        (M_total, K), device=grad_output.device, dtype=grad_output.dtype
    )

    NUM_SMS = num_sms  # CudaUtils.get_num_sms()
    USE_TMA_LOAD = True
    USE_TMA_STORE = True

    # Set up TMA descriptors
    desc_helper = TmaDescriptorHelper(tma_size=tma_size)
    desc_helper.init_tma_descriptor("grad_output")
    desc_helper.init_tma_descriptor("w")
    desc_grad_output = desc_helper.get_tma_descriptor_kernel_param("grad_output")
    desc_w = desc_helper.get_tma_descriptor_kernel_param("w")

    # Allocate workspace for TMA store: one descriptor-sized slot per SM.
    workspace = torch.empty(
        NUM_SMS * desc_helper.tma_size,
        device=grad_output.device,
        dtype=torch.uint8,
    )

    def grid(META):
        # Fill TMA descriptors here, once autotuning has fixed BLOCK_SIZE_*.
        desc_helper.fill_2d_tma_descriptor(
            "grad_output",
            grad_output.data_ptr(),
            M_total,
            N_grad,
            META["BLOCK_SIZE_M"],
            META["BLOCK_SIZE_N"],
            grad_output.element_size(),
        )

        desc_helper.fill_2d_tma_descriptor(
            "w",
            w.data_ptr(),
            N_w,
            K,
            META["BLOCK_SIZE_N"],
            META["BLOCK_SIZE_K"],
            w.element_size(),
        )
        return (NUM_SMS,)

    # Power-of-two bucket keeps autotune cache keys stable across M_total.
    M_BUCKET = triton.next_power_of_2(M_total)

    # Launch the flat linear kernel for computing grad_x
    _kernel_mg_dx_tma[grid](
        desc_grad_output,
        desc_w,
        grad_x,
        workspace,
        m_sizes,
        G,
        M_BUCKET,
        N_grad,  # N dimension is now the reduction dimension
        K,
        NUM_SMS,
        USE_TMA_LOAD,
        USE_TMA_STORE,
        TMA_SIZE=tma_size,
    )

    return grad_x
|
| 1053 |
+
|
| 1054 |
+
|
| 1055 |
+
# ======== dw wrapper function ==========
|
| 1056 |
+
|
| 1057 |
+
|
| 1058 |
+
def grouped_gemm_dw_tma(
    x: torch.Tensor,
    grad_output: torch.Tensor,
    m_sizes: torch.Tensor,
    num_sms: int = 132,
    tma_size: int = 128,
) -> torch.Tensor:
    """
    Optimized flat linear kernel computation of gradients with respect to weights (dw) using TMA.
    For the forward pass Y = X @ W.T, the backward for weights is:
        grad_W = grad_Y.T @ X

    Args:
        x: Input tensor, shape [M_total, K]
        grad_output: Gradient of output, shape [M_total, N]
        m_sizes: Group sizes tensor, shape [G]
        num_sms: persistent programs to launch (default 132; H100 SM count)
        tma_size: Size of TMA descriptor in bytes

    Returns:
        grad_w: Gradient with respect to weights, shape [N, K]
    """
    # Check TMA support
    # NOTE(review): result is currently unused — USE_TMA_LOAD/STORE below are
    # hardcoded True (see TODO), and TmaDescriptorHelper.__init__ re-verifies.
    has_tma_support = CudaUtils.verify_tma()

    # Get group count
    G = m_sizes.shape[0]

    # Ensure contiguous tensors (required by the TMA descriptors)
    x = x.contiguous()
    grad_output = grad_output.contiguous()
    m_sizes = m_sizes.contiguous()

    # Get dimensions
    M_total, K_x = x.shape
    M_grad, N = grad_output.shape

    # Check dimensions
    assert M_total == M_grad, f"x M ({M_total}) must match grad_output M ({M_grad})"

    # Verify that the sum of m_sizes matches M_total (.item() -> GPU-CPU sync)
    sum_m_sizes = m_sizes.sum().item()
    assert (
        sum_m_sizes == M_total
    ), f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})"

    # Create output tensor (grad_w) with shape [N, K]; zero-initialized since
    # the kernel accumulates group contributions into it.
    grad_w = torch.zeros((N, K_x), device=x.device, dtype=x.dtype)

    NUM_SMS = num_sms

    # TODO - hardcoded for now...but should set TMA flags based on hardware support
    USE_TMA_LOAD = True  # has_tma_support
    USE_TMA_STORE = True  # has_tma_support

    # Set up TMA descriptors or direct pointers
    if USE_TMA_LOAD or USE_TMA_STORE:
        desc_helper = TmaDescriptorHelper(tma_size=tma_size)

        if USE_TMA_LOAD:
            desc_helper.init_tma_descriptor("x")
            desc_helper.init_tma_descriptor("grad_output")
            x_desc = desc_helper.get_tma_descriptor_kernel_param("x")
            grad_output_desc = desc_helper.get_tma_descriptor_kernel_param(
                "grad_output"
            )
        else:
            x_desc = x
            grad_output_desc = grad_output

        if USE_TMA_STORE:
            # NOTE: for dw, `workspace` carries the grad_w TMA store descriptor
            # itself — unlike forward/dx, which allocate a per-SM byte buffer.
            desc_helper.init_tma_descriptor("grad_w")
            workspace = desc_helper.get_tma_descriptor_kernel_param("grad_w")
        else:
            # Dummy placeholder so the kernel signature stays uniform.
            workspace = torch.empty(1, device=x.device, dtype=torch.uint8)
    else:
        # If not using TMA, just use the tensors directly
        x_desc = x
        grad_output_desc = grad_output
        workspace = torch.empty(1, device=x.device, dtype=torch.uint8)

    # M_BUCKET for grid size (power-of-two bucket stabilizes autotune keys)
    M_BUCKET = triton.next_power_of_2(M_total)

    # Define grid for kernel launch; descriptors are filled here because the
    # BLOCK_SIZE_* values are only known after autotuning selects a config.
    def grid(META):
        if USE_TMA_LOAD or USE_TMA_STORE:

            if USE_TMA_LOAD:
                desc_helper.fill_2d_tma_descriptor(
                    "x",
                    x.data_ptr(),
                    M_total,
                    K_x,
                    META["BLOCK_SIZE_M"],
                    META["BLOCK_SIZE_K"],
                    x.element_size(),
                )

                desc_helper.fill_2d_tma_descriptor(
                    "grad_output",
                    grad_output.data_ptr(),
                    M_total,
                    N,
                    META["BLOCK_SIZE_M"],
                    META["BLOCK_SIZE_N"],
                    grad_output.element_size(),
                )

            if USE_TMA_STORE:
                desc_helper.fill_2d_tma_descriptor(
                    "grad_w",
                    grad_w.data_ptr(),
                    N,
                    K_x,
                    META["BLOCK_SIZE_N"],
                    META["BLOCK_SIZE_K"],
                    grad_w.element_size(),
                )

        # Return grid size - one block per SM for balanced work distribution
        return (NUM_SMS,)

    # Launch the optimized kernel
    _kernel_mg_dw_tma[grid](
        x_desc,
        grad_output_desc,
        grad_w,
        workspace,
        m_sizes,
        G,
        M_BUCKET,
        N,
        K_x,
        NUM_SMS,
        USE_TMA_LOAD,
        USE_TMA_STORE,
        TMA_SIZE=tma_size,
    )

    return grad_w
|
| 1199 |
+
|
| 1200 |
+
|
| 1201 |
+
# ======== End Backwards Wrapper Functions =============
|
| 1202 |
+
|
| 1203 |
+
# ======== PyTorch wrapper functions ========
|
| 1204 |
+
|
| 1205 |
+
|
| 1206 |
+
class GroupedGEMM_mg(torch.autograd.Function):
    """
    Autograd function for GroupedGEMM with M*G grouping.
    Supports both standard and FP8 quantized operations.
    """

    @staticmethod
    def forward(ctx, x, w, m_sizes, use_tma=True, tma_size=128):
        """
        Forward pass of GroupedGEMM.

        Args:
            x: Input tensor, shape [M_total, K]
            w: Weight tensor, shape [N, K]
            m_sizes: Tensor of shape [G] containing the size of each group
            use_tma: Whether to try using TMA acceleration (if available)
            tma_size: Size of TMA descriptor in bytes

        Returns:
            Output tensor, shape [M_total, N]
        """
        # Fix: grouped_gemm_forward() has no `using_fp8` parameter; passing it
        # raised a TypeError. FP8 is not wired up here yet, so drop the kwarg.
        output = grouped_gemm_forward(x=x, w=w, m_sizes=m_sizes, tma_size=tma_size)

        # Save inputs and parameters for backward pass.
        # Fix: save_for_backward() was called twice; call it exactly once.
        ctx.save_for_backward(x, w, m_sizes)
        ctx.use_tma = use_tma
        ctx.tma_size = tma_size

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward pass of M*G GroupedGEMM.

        Args:
            grad_output: Gradient of output, shape [M_total, N]

        Returns:
            Tuple of gradients:
            - grad_x: Gradient with respect to x, shape [M_total, K]
            - grad_w: Gradient with respect to w, shape [N, K]
            - None: Gradient with respect to m_sizes (not differentiable)
            - None: Gradient with respect to use_tma (not differentiable)
            - None: Gradient with respect to tma_size (not differentiable)
        """
        # Retrieve saved tensors and parameters
        x, w, m_sizes = ctx.saved_tensors

        use_tma = ctx.use_tma
        tma_size = ctx.tma_size

        # Compute gradients using the unified implementation
        grad_x, grad_w = grouped_gemm_backward(
            grad_output=grad_output,
            x=x,
            w=w,
            m_sizes=m_sizes,
            use_tma=use_tma,
            tma_size=tma_size,
        )

        # Fix: autograd requires one gradient per forward() input (x, w,
        # m_sizes, use_tma, tma_size) — the original returned only four
        # values, which raises at runtime.
        return grad_x, grad_w, None, None, None
|
| 1279 |
+
|
| 1280 |
+
|
| 1281 |
+
def mg_grouped_gemm(
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    use_tma: bool = True,
    tma_size: int = 128,
    using_fp8: bool = False,
) -> torch.Tensor:
    """
    Unified differentiable grouped GEMM operation for M*G grouped GEMM.
    Supports both standard precision and FP8 quantized operations.

    Args:
        x: Input tensor, shape [M_total, K]
        w: Weight tensor, shape [N, K]
        m_sizes: Tensor of shape [G] containing the size of each group
        use_tma: Whether to try using TMA acceleration (if available)
        tma_size: Size of TMA descriptor in bytes
        using_fp8: Accepted for backward compatibility but currently ignored —
            FP8 support is not wired through the autograd function yet.

    Returns:
        Output tensor, shape [M_total, N]
    """
    # Fix: GroupedGEMM_mg.forward() accepts (x, w, m_sizes, use_tma, tma_size);
    # forwarding a sixth `using_fp8` argument raised a TypeError at call time.
    return GroupedGEMM_mg.apply(x, w, m_sizes, use_tma, tma_size)
|
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/reference_utils.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# pyre-unsafe
|
| 8 |
+
import logging
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
# Configure logging
# NOTE: basicConfig at import time configures the process-wide root logger;
# it is a no-op if the embedding application already installed handlers.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def compute_reference_forward(x, w, m_sizes):
    """
    Reference forward pass for the grouped GEMM, built from plain PyTorch ops.

    Each group of rows in ``x`` (sized per ``m_sizes``) is multiplied by the
    shared weight: ``y_g = x_g @ w.T``.

    Args:
        x (torch.Tensor): Input tensor of shape (M, K)
        w (torch.Tensor): Weight tensor of shape (N, K)
        m_sizes (torch.Tensor): Group sizes tensor of shape (G)

    Returns:
        torch.Tensor: Reference output tensor of shape (M, N)
    """
    out = torch.zeros((x.shape[0], w.shape[0]), dtype=x.dtype, device=x.device)

    # Walk the concatenated rows group by group.
    row = 0
    for size in m_sizes.tolist():
        if size <= 0:
            # Empty groups contribute nothing and do not advance the cursor
            # (size is never negative; the guard mirrors the original check).
            continue
        nxt = row + size
        out[row:nxt] = x[row:nxt] @ w.T
        row = nxt

    return out
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def compute_reference_backward(x, w, m_sizes, grad_output):
    """
    Reference backward pass computed via PyTorch autograd.

    Runs the same group-wise forward as ``compute_reference_forward`` on
    autograd-enabled copies of the inputs, then backpropagates ``grad_output``.

    Args:
        x (torch.Tensor): Input tensor of shape (M, K)
        w (torch.Tensor): Weight tensor of shape (N, K)
        m_sizes (torch.Tensor): Group sizes tensor of shape (G)
        grad_output (torch.Tensor): Gradient tensor of shape (M, N)

    Returns:
        tuple: (grad_x, grad_w) gradient tensors
    """
    # Detached, gradient-tracking copies so the caller's tensors are untouched.
    x_ref = x.detach().clone().requires_grad_(True)
    w_ref = w.detach().clone().requires_grad_(True)

    # Group-wise forward (inlined so autograd records every per-group matmul).
    output = torch.zeros(
        (x_ref.shape[0], w_ref.shape[0]), dtype=x_ref.dtype, device=x_ref.device
    )
    row = 0
    for size in m_sizes.tolist():
        if size > 0:
            nxt = row + size
            output[row:nxt] = torch.matmul(x_ref[row:nxt], w_ref.T)
            row = nxt

    # Backpropagate the provided upstream gradient.
    output.backward(grad_output)

    return x_ref.grad, w_ref.grad
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def analyze_tensor_differences(actual, expected, name):
    """
    Compare an actual tensor against its expected reference and log diagnostics.

    Logs the location/magnitude of the worst mismatch; on failure also logs
    zero-element counts and a NaN check.

    Args:
        actual (torch.Tensor): Actual tensor
        expected (torch.Tensor): Expected tensor
        name (str): Name of the tensor for logging

    Returns:
        bool: True if tensors are close enough
    """
    # Loose tolerances appropriate for float16 comparisons.
    rtol = 0.5  # Relative tolerance for float16
    atol = 0.5  # Absolute tolerance for float16

    # Locate the single worst absolute difference for the log line.
    abs_diff = (actual - expected).abs()
    worst = np.unravel_index(abs_diff.argmax().item(), actual.shape)
    worst_val = abs_diff.max().item()

    logging.info(f"Largest {name} difference: {worst_val} at {worst}")
    logging.info(f"Values: {actual[worst].item()} vs {expected[worst].item()}")

    is_close = torch.allclose(actual, expected, rtol=rtol, atol=atol)

    if is_close:
        logging.info(f"✓ SUCCESS: {name} matches PyTorch reference")
        return is_close

    logging.error(f"✗ FAILURE: {name} mismatch detected")

    # Failure diagnostics: zero-element population on both sides.
    zeros_actual = (actual == 0).sum().item()
    zeros_expected = (expected == 0).sum().item()
    logging.info(
        f"Zeros in {name} (actual): {zeros_actual}/{actual.numel()} ({zeros_actual/actual.numel()*100:.2f}%)"
    )
    logging.info(
        f"Zeros in {name} (expected): {zeros_expected}/{expected.numel()} ({zeros_expected/expected.numel()*100:.2f}%)"
    )

    # NaN presence check on the actual values.
    nan_actual = torch.isnan(actual).sum().item()
    if nan_actual > 0:
        logging.error(f"NaN values detected in {name}: {nan_actual}")

    return is_close
|
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/tma_autotuning.py
ADDED
|
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# credit - TMAHelper class, AutoTuning are derived from FBGemm:
|
| 8 |
+
# https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm
|
| 9 |
+
|
| 10 |
+
# pyre-unsafe
|
| 11 |
+
import functools
|
| 12 |
+
|
| 13 |
+
import os
|
| 14 |
+
import sys
|
| 15 |
+
from typing import Any, Dict, Optional, Tuple
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
|
| 19 |
+
import triton
|
| 20 |
+
import triton.language as tl
|
| 21 |
+
from triton import Config as TConfig
|
| 22 |
+
|
| 23 |
+
from triton.runtime import driver # @manual
|
| 24 |
+
|
| 25 |
+
# Make sibling modules in this directory importable regardless of CWD.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# ===== Supporting utils, CUDA and TMA =====
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class CudaUtils:
    """Static helpers for querying the CUDA / Triton execution environment."""

    @staticmethod
    def is_cuda() -> bool:
        """Check if Triton is running on CUDA backend."""
        return driver.active.get_current_target().backend == "cuda"

    @staticmethod
    def verify_tma() -> bool:
        """Check if TMA is supported on the current device.

        TMA (Tensor Memory Accelerator) requires Hopper, i.e. compute
        capability >= 9. Short-circuit order matters: the backend check runs
        before any torch.cuda query.
        """
        return (
            CudaUtils.is_cuda()
            and torch.cuda.is_available()
            and torch.cuda.get_device_capability()[0] >= 9
        )

    @staticmethod
    def get_num_sms() -> int:
        """Get the number of streaming multiprocessors on the current device.

        Raises:
            RuntimeError: if Triton is not on the CUDA backend or CUDA is
                unavailable.
        """
        if not CudaUtils.is_cuda():
            raise RuntimeError("Triton is not running on CUDA backend")
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is not available")
        return torch.cuda.get_device_properties("cuda").multi_processor_count
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class TmaDescriptorHelper:
|
| 57 |
+
"""Helper class for managing TMA descriptors in Triton kernels."""
|
| 58 |
+
|
| 59 |
+
class KernelParamWrapper:
|
| 60 |
+
"""Wrapper to implement the TmaDescKernelParam interface."""
|
| 61 |
+
|
| 62 |
+
def __init__(self, desc: torch.Tensor):
|
| 63 |
+
self.desc = desc
|
| 64 |
+
|
| 65 |
+
def tma_desc_cpu_ptr(self) -> int:
|
| 66 |
+
"""Return the CPU pointer to the TMA descriptor."""
|
| 67 |
+
return self.desc.data_ptr()
|
| 68 |
+
|
| 69 |
+
def __init__(self, tma_size: int = 128):
|
| 70 |
+
"""Initialize the TMA descriptor helper.
|
| 71 |
+
|
| 72 |
+
Args:
|
| 73 |
+
tma_size: Size of the TMA descriptor in bytes
|
| 74 |
+
"""
|
| 75 |
+
if not CudaUtils.verify_tma():
|
| 76 |
+
raise RuntimeError(
|
| 77 |
+
"TMA not supported on this device (requires Hopper or newer)"
|
| 78 |
+
)
|
| 79 |
+
if "nv_tma_desc_type" not in dir(tl):
|
| 80 |
+
raise RuntimeError(
|
| 81 |
+
"TMA grid constant descriptors not supported in your Triton version"
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
self.tma_size = tma_size
|
| 85 |
+
self.fill_1d_tma_descriptor_inner = driver.active.utils.fill_1d_tma_descriptor
|
| 86 |
+
self.fill_2d_tma_descriptor_inner = driver.active.utils.fill_2d_tma_descriptor
|
| 87 |
+
self.descriptors: Dict[str, torch.Tensor] = {}
|
| 88 |
+
|
| 89 |
+
def init_tma_descriptor(self, name: str) -> None:
|
| 90 |
+
"""Initialize a TMA descriptor with the given name.
|
| 91 |
+
|
| 92 |
+
Call this method outside of the lambda function for grid size.
|
| 93 |
+
"""
|
| 94 |
+
self.descriptors[name] = torch.empty(
|
| 95 |
+
self.tma_size, device="cpu", dtype=torch.int8
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
def fill_1d_tma_descriptor(
|
| 99 |
+
self, name: str, ptr: int, dim: int, block_dim: int, element_size: int
|
| 100 |
+
) -> None:
|
| 101 |
+
"""Fill a 1D TMA descriptor.
|
| 102 |
+
|
| 103 |
+
Call this method inside the lambda function for grid size.
|
| 104 |
+
"""
|
| 105 |
+
if name not in self.descriptors:
|
| 106 |
+
raise ValueError(f"TMA descriptor '{name}' not initialized")
|
| 107 |
+
|
| 108 |
+
desc_x = self.descriptors[name]
|
| 109 |
+
if desc_x.data_ptr() % 64 != 0:
|
| 110 |
+
raise ValueError("TMA descriptor must be 64-byte aligned")
|
| 111 |
+
self.fill_1d_tma_descriptor_inner(
|
| 112 |
+
ptr, dim, block_dim, element_size, desc_x.data_ptr()
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
def fill_2d_tma_descriptor(
    self,
    name: str,
    ptr: int,
    dim1: int,
    dim0: int,
    block_dim1: int,
    block_dim0: int,
    element_size: int,
) -> None:
    """Populate the 2D TMA descriptor registered under ``name``.

    Call this method inside the lambda function for grid size.

    Raises:
        ValueError: If ``name`` was never initialized or the backing
            buffer is not 64-byte aligned.
    """
    if name not in self.descriptors:
        raise ValueError(f"TMA descriptor '{name}' not initialized")

    # Reject a misaligned descriptor buffer before handing its address on.
    desc_addr = self.descriptors[name].data_ptr()
    if desc_addr & 63:
        raise ValueError("TMA descriptor must be 64-byte aligned")
    self.fill_2d_tma_descriptor_inner(
        ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_addr
    )
|
| 138 |
+
|
| 139 |
+
def get_tma_descriptor_kernel_param(self, name: str) -> KernelParamWrapper:
    """Return the kernel-parameter wrapper for the descriptor ``name``.

    Raises:
        ValueError: If ``name`` was never initialized (or was set to None).
    """
    # ``get`` covers both "missing key" and "stored as None" in one lookup.
    desc = self.descriptors.get(name)
    if desc is None:
        raise ValueError(f"TMA descriptor '{name}' not initialized")
    return self.KernelParamWrapper(desc)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
# ====== Autotuning utilities ======
# M tiles are kept at this fixed alignment size.
ALIGN_SIZE_M = 128

# Candidate tile/pipeline configurations for NVIDIA GPUs; unsuitable entries
# are expected to be filtered out later (see `early_config_prune` below).
_NV_CONFIGS = [
    triton.Config(
        {"BLOCK_SIZE_M": bm, "BLOCK_SIZE_N": bn, "BLOCK_SIZE_K": bk},
        num_stages=stages,
        num_warps=warps,
        num_ctas=ctas,
    )
    for bm in [ALIGN_SIZE_M]
    for bn in [64, 128, 256]
    for bk in [64, 128, 256]
    for stages in [3, 4]
    for warps in [4, 8]
    for ctas in [1]
]
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def early_config_prune(configs, named_args, dtsize=None, dtype=None, **kwargs):
    """Prune autotuning configs that cannot run (or run well) on this device.

    Args:
        configs: Candidate ``triton.Config`` objects to filter.
        named_args: Kernel arguments by name; must contain one of the
            recognized output pointers plus ``G``, ``M_BUCKET``, ``N``, ``K``.
        dtsize: Element size in bytes; inferred from the output tensor if None.
        dtype: Accepted for API compatibility; not used by the pruning rules.
        **kwargs: Ignored; accepted for autotuner-hook compatibility.

    Returns:
        The subset of ``configs`` passing all shared-memory and tile-size
        heuristics.

    Raises:
        KeyError: If no recognized pointer parameter is found in ``named_args``.
    """
    device = torch.cuda.current_device()
    # Check for all possible pointer parameter names
    if "grad_input_ptr" in named_args:
        ptr_name = "grad_input_ptr"
    elif "c_ptr" in named_args:
        ptr_name = "c_ptr"
    elif "grad_weight_ptr" in named_args:
        ptr_name = "grad_weight_ptr"
    else:
        raise KeyError("No recognized pointer parameter found in kernel arguments")

    if dtsize is None:
        dtsize = named_args[ptr_name].element_size()

    # Device properties are loop-invariant: query them once, not per config.
    props = driver.active.utils.get_device_properties(device)
    max_shared_memory = props["max_shared_mem"]
    num_sm = props["multiprocessor_count"]

    # Problem shape is also loop-invariant.
    G, M, N, K = (
        named_args["G"],
        named_args["M_BUCKET"],
        named_args["N"],
        named_args["K"],
    )
    M_PER_GROUP = M // G
    MIN_M_TILES = 64
    MIN_N_TILES = 64

    pruned_configs = []
    for config in configs:
        kw = config.kwargs
        BLOCK_M, BLOCK_N, BLOCK_K, num_stages = (
            kw["BLOCK_SIZE_M"],
            kw["BLOCK_SIZE_N"],
            kw["BLOCK_SIZE_K"],
            config.num_stages,
        )

        # 1. make sure we have enough smem for the pipelined A and B tiles
        required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
        if required_shared_memory > max_shared_memory:
            continue

        # 2. make sure we don't load M tiles that are too big
        if BLOCK_M > MIN_M_TILES and BLOCK_M > (M_PER_GROUP * 2):
            continue
        # 3. make sure we don't load M tiles that are too small
        if BLOCK_M < 128 and BLOCK_M < (M_PER_GROUP // 2):
            continue

        N_TILES = N // BLOCK_N
        # 4. make sure we don't load N tiles that are too big
        if BLOCK_N > MIN_N_TILES and M * N_TILES < num_sm:
            continue
        # 5. make sure we don't load N tiles that are too small
        if BLOCK_N < 128 and M * N_TILES > 2 * num_sm:
            continue
        # 6. make sure K can be evenly divided
        if K % BLOCK_K != 0:
            continue

        pruned_configs.append(config)

    return pruned_configs
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
# ======== End Autotuning utilities ========
|
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/unit_test_forwards.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# pyre-unsafe
|
| 8 |
+
import logging
|
| 9 |
+
import unittest
|
| 10 |
+
from typing import Tuple
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
|
| 15 |
+
from mg_grouped_gemm import grouped_gemm_forward
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class TestMG_GroupedGEMM(unittest.TestCase):
    """Forward-pass unit tests for the M*G grouped GEMM kernel."""

    def setUp(self) -> None:
        # Fixed seed so the randomly generated operands are reproducible.
        torch.manual_seed(2020)

    def _run_grouped_gemm_test(
        self,
        shape: Tuple[int, int, int, int],
        device: torch.device,
        dtype: torch.dtype = torch.bfloat16,
        atol: float = 1e-5,
        rtol: float = 1.6e-2,
    ) -> None:
        """Run one grouped-GEMM forward and compare against per-group matmuls."""
        G, M, N, K = shape
        # In M*G grouping, input is [M*G, K] and weights are [N*G, K]
        a = torch.randn(M * G, K, dtype=dtype, device=device)
        b = torch.randn(N * G, K, dtype=dtype, device=device)

        # Create equal-sized groups for simplicity
        m_sizes = torch.full((G,), M, device=device, dtype=torch.int32)

        result = grouped_gemm_forward(a, b, m_sizes)
        self.assertTrue(result.shape == (M * G, N))

        # Reference: each group's rows multiplied by that group's weight slice.
        expected_result = torch.zeros(M * G, N, dtype=dtype, device=device)
        row = 0
        for group in range(G):
            next_row = row + int(m_sizes[group])
            weight = b[group * N : (group + 1) * N, :]
            expected_result[row:next_row, :] = a[row:next_row, :] @ weight.T
            row = next_row

        # Convert result to match input dtype if needed
        torch.testing.assert_close(
            result.to(dtype), expected_result, atol=atol, rtol=rtol
        )

    def test_MG_grouped_gemm_bf16(self) -> None:
        """Sweep group counts and per-group sizes at fixed N=K=1024."""
        for G, M in ((g, m) for g in (1, 4, 16) for m in (128, 512, 1024)):
            print(f"Testing BF16 M*G GroupGeMM with G={G}, M={M}")
            self._run_grouped_gemm_test(
                (G, M, 1024, 1024),
                torch.device("cuda"),
                dtype=torch.bfloat16,
                atol=1e-5,
                rtol=1.6e-2,
            )

    def test_MG_grouped_gemm_deepseek_shapes(self) -> None:
        """Test with shapes from Deepseek model."""
        deepseek_shapes = [
            (4, 2048, 4096, 7168),  # G, M, N, K
            (4, 2048, 7168, 2048),
            (8, 512, 4096, 7168),
            (8, 512, 7168, 2048),
        ]

        device = torch.device("cuda")

        for G, M, N, K in deepseek_shapes:
            print(f"Testing BF16 M*G Deepseek shape: G={G}, M={M}, N={N}, K={K}")
            self._run_grouped_gemm_test(
                (G, M, N, K), device, dtype=torch.bfloat16, atol=1e-5, rtol=1.6e-2
            )
|