zaydzuhri commited on
Commit
7c7a501
·
verified ·
1 Parent(s): 801b04c

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. LICENSE +21 -0
  2. README.md +471 -0
  3. config.json +34 -0
  4. configs/delta_net_1B.json +29 -0
  5. configs/delta_net_340M.json +27 -0
  6. configs/gla_340M.json +24 -0
  7. configs/gla_7B.json +25 -0
  8. configs/gsa_340M.json +29 -0
  9. configs/mtp_transformer_340M.json +19 -0
  10. configs/top_transformer_1B.json +24 -0
  11. configs/top_transformer_340M.json +20 -0
  12. configs/transformer_120M.json +18 -0
  13. configs/transformer_7B.json +21 -0
  14. fla/utils.py +223 -0
  15. generation_config.json +6 -0
  16. logs/none_ro0qpaac/attempt_0/0/stderr.log +0 -0
  17. logs/none_ro0qpaac/attempt_0/1/stderr.log +0 -0
  18. logs/none_ro0qpaac/attempt_0/3/stderr.log +0 -0
  19. logs/none_ro0qpaac/attempt_0/5/stderr.log +0 -0
  20. logs/none_ro0qpaac/attempt_0/6/stderr.log +0 -0
  21. logs/none_ro0qpaac/attempt_0/7/stderr.log +0 -0
  22. pyproject.toml +43 -0
  23. setup.py +51 -0
  24. special_tokens_map.json +23 -0
  25. tokenizer.json +0 -0
  26. tokenizer_config.json +44 -0
  27. torchtitan/components/__pycache__/checkpoint.cpython-312.pyc +0 -0
  28. torchtitan/components/__pycache__/dataloader.cpython-312.pyc +0 -0
  29. torchtitan/components/__pycache__/float8.cpython-312.pyc +0 -0
  30. torchtitan/components/__pycache__/lr_scheduler.cpython-312.pyc +0 -0
  31. torchtitan/components/__pycache__/metrics.cpython-312.pyc +0 -0
  32. torchtitan/components/__pycache__/optimizer.cpython-312.pyc +0 -0
  33. torchtitan/components/__pycache__/tokenizer.cpython-312.pyc +0 -0
  34. torchtitan/distributed/utils.py +311 -0
  35. torchtitan/experiments/deepseek_v3/README.md +40 -0
  36. torchtitan/experiments/deepseek_v3/indices.py +195 -0
  37. torchtitan/experiments/deepseek_v3/symm_mem_recipes/__init__.py +11 -0
  38. torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_on_device_all_to_all_v.py +260 -0
  39. torchtitan/experiments/deepseek_v3/train.py +142 -0
  40. torchtitan/experiments/flux/README.md +23 -0
  41. torchtitan/experiments/flux/dataset/flux_dataset.py +267 -0
  42. torchtitan/experiments/flux/dataset/tokenizer.py +64 -0
  43. torchtitan/experiments/flux/model/hf_embedder.py +40 -0
  44. torchtitan/experiments/flux/model/layers.py +286 -0
  45. torchtitan/experiments/flux/tests/test_flux_dataloader.py +103 -0
  46. torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/fast_debug_ao.py +299 -0
  47. torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/mg_grouped_gemm.py +1304 -0
  48. torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/reference_utils.py +126 -0
  49. torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/tma_autotuning.py +240 -0
  50. torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/unit_test_forwards.py +82 -0
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023-2025 Songlin Yang, Yu Zhang
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,471 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div align="center">
2
+
3
+ # 🔥 Flame: Flash Linear Attention Made Easy
4
+
5
+ </div>
6
+
7
+ Welcome to 🔥 `flame`, a minimal and efficient framework built on `torchtitan` for training Flash Linear Attention (FLA) models (and more broadly, arbitrary autoregressive language models) with blazing efficiency.
8
+
9
+ **Feature Highlights:**
10
+
11
+ - 🚀 Minimal, easy-to-use, extensible training framework
12
+ - 🤗 Seamless integration with `fla` and `transformers`
13
+ - 🔄 Zero-cost data preprocessing: online tokenization, dataset shuffling, and multiple datasets support
14
+ - 🔮 4D parallelism (coming soon)
15
+
16
+ ## Setup
17
+
18
+ To get started, clone the `flame` repository and install the required dependencies:
19
+
20
+ ```bash
21
+ git clone https://github.com/fla-org/flame.git
22
+ cd flame
23
+ pip install .
24
+ ```
25
+
26
+ `flame` manages minimal dependencies, only including `fla` and `torchtitan` as submodules.
27
+ After installation, initialize and update the submodules:
28
+ ```sh
29
+ git submodule update --init --recursive
30
+ ```
31
+
32
+ ## Dataset Preparation
33
+ To download the dataset to your local disk, create a new Python file with the following content and execute it:
34
+
35
+ ```py
36
+ from datasets import load_dataset
37
+
38
+ # load fineweb-edu with parallel processing
39
+ dataset = load_dataset("HuggingFaceFW/fineweb-edu", name="default", num_proc=64, cache_dir="/your/cache/path")
40
+
41
+ # or load a subset with roughly 100B tokens, suitable for small- or medium-sized experiments
42
+ dataset = load_dataset("HuggingFaceFW/fineweb-edu", name="sample-100BT", num_proc=64, cache_dir="/your/cache/path")
43
+ ```
44
+
45
+ ## Training Recipes
46
+
47
+ Here's an example of training a 340M FLA Transformer model with a LLaMA-like architecture from scratch on a 100BT subset of the Fineweb-edu corpus in streaming mode.
48
+
49
+ > [!WARNING]
50
+ > If the dataset is not downloaded beforehand, the streaming mode will attempt to fetch it from a remote server and download it on-the-fly, which can be highly unstable during training due to network issues.
51
+ > For stable training, ensure the dataset is downloaded locally (see [**Dataset Preparation**](#dataset-preparation)). Otherwise, we assume you are only testing the new corpus.
52
+
53
+ ```sh
54
+ bash train.sh \
55
+ --job.config_file flame/models/fla.toml \
56
+ --job.dump_folder exp/transformer-340M-4K-10B/batch1.seqlen65536.context4096.warmup1024.update1.steps20480.lr3e-4.cosine \
57
+ --model.config configs/transformer_340M.json \
58
+ --model.tokenizer_path fla-hub/transformer-1.3B-100B \
59
+ --optimizer.name AdamW \
60
+ --optimizer.eps 1e-15 \
61
+ --optimizer.lr 3e-4 \
62
+ --lr_scheduler.warmup_steps 1024 \
63
+ --lr_scheduler.lr_min 0.1 \
64
+ --lr_scheduler.decay_type cosine \
65
+ --training.batch_size 1 \
66
+ --training.seq_len 65536 \
67
+ --training.context_len 4096 \
68
+ --training.varlen \
69
+ --training.gradient_accumulation_steps 1 \
70
+ --training.steps 20480 \
71
+ --training.max_norm 1.0 \
72
+ --training.skip_nan_inf \
73
+ --training.dataset HuggingFaceFW/fineweb-edu \
74
+ --training.dataset_name sample-100BT \
75
+ --training.dataset_split train \
76
+ --training.streaming \
77
+ --training.num_workers 32 \
78
+ --training.prefetch_factor 2 \
79
+ --training.seed 42 \
80
+ --training.compile \
81
+ --checkpoint.interval 2048 \
82
+ --checkpoint.load_step -1 \
83
+ --checkpoint.keep_latest_k 2 \
84
+ --metrics.log_freq 1
85
+ ```
86
+
87
+ You can specify the number of GPUs by setting the environment variable `NGPU`, which defaults to 8.
88
+ **For single-GPU debugging, set `NGPU=1`.**
89
+
90
+ We provide several [config files](https://github.com/fla-org/flame/tree/main/configs) for different models.
91
+ By default, the learning rate is set to 3e-4 with a cosine scheduler. Other schedulers, such as WSD (wsd), are also supported.
92
+
93
+ **Key parameters:**
94
+ - `--lr_scheduler.decay_ratio`: The proportion of the steps allocated to the decay phase. The learning rate will remain stable after the warmup period and only start decaying during the last `decay_ratio` portion of the total training steps, which is known as the Warmup-Stable-Decay (WSD) schedule.
95
+ - `--lr_scheduler.warmup_steps`: The number of steps for the learning rate warmup phase.
96
+ - `--training.steps`: Total number of training steps.
97
+ - `--training.batch_size`: Batch size per device, must be 1 if `--training.varlen` is set.
98
+ - `--training.seq_len`: The length of each sequence in the batch, which is concatenated from multiple samples.
99
+ - `--training.context_len`: The max allowed length of a sample. For non-varlen mode, this is equivalent to `seq_len`.
100
+ - `--training.varlen`: Whether to conduct variable-length sequence training.
101
+ - `--training.gradient_accumulation_steps`: Number of gradient accumulation steps.
102
+
103
+ > [!WARNING]
104
+ > The total number of tokens processed per batch, referred to as `global_batch_size`, is calculated as batch_size × gradient_accumulation_steps × num_gpus.
105
+ > Each step processes `global_batch_size * seq_len` tokens.
106
+ > Monitor the value of `global_batch_size`, `warmup_steps`, and `steps` carefully when modifying any of the hyperparameters!
107
+
108
+ For a detailed explanation of all parameters, run:
109
+
110
+ ```sh
111
+ bash train.sh -h
112
+ ```
113
+
114
+ <details>
115
+ <summary>Usage</summary>
116
+
117
+ ```py
118
+ options:
119
+ -h, --help show this help message and exit
120
+ --job.config_file JOB.CONFIG_FILE
121
+ Job config file
122
+ --job.dump_folder JOB.DUMP_FOLDER
123
+ Folder to dump job outputs
124
+ --job.description JOB.DESCRIPTION
125
+ Description of the job
126
+ --job.use_for_integration_test
127
+ Add this config to the integration test suite
128
+ --job.print_args Print the args to terminal
129
+ --model.config MODEL.CONFIG
130
+ Path to the model config
131
+ --model.norm_type MODEL.NORM_TYPE
132
+ Type of layer normalization to use [layernorm,
133
+ np_layernorm, rmsnorm, fused_rmsnorm]
134
+ --model.tokenizer_path MODEL.TOKENIZER_PATH
135
+ Tokenizer path
136
+ --profiling.enable_profiling
137
+ Whether to enable pytorch profiler
138
+ --profiling.save_traces_folder PROFILING.SAVE_TRACES_FOLDER
139
+ Trace files location
140
+ --profiling.profile_freq PROFILING.PROFILE_FREQ
141
+ How often to collect profiler traces, in iterations
142
+ --profiling.enable_memory_snapshot
143
+ Whether to dump memory snapshot
144
+ --profiling.save_memory_snapshot_folder PROFILING.SAVE_MEMORY_SNAPSHOT_FOLDER
145
+ Memory snapshot files location
146
+ --optimizer.name OPTIMIZER.NAME
147
+ Optimizer to use
148
+ --optimizer.eps OPTIMIZER.EPS
149
+ Epsilon value for the optimizer.
150
+ --optimizer.fused Whether the fused implementation(CUDA only) is used.
151
+ --optimizer.scheduler {wsd,cosine,linear}
152
+ Scheduler to use. Currently supported: wsd, cosine,
153
+ and linear.
154
+ --optimizer.lr OPTIMIZER.LR
155
+ Learning rate to use
156
+ --optimizer.min_lr_ratio OPTIMIZER.MIN_LR_RATIO
157
+ Min lr ratio for lr scheduler
158
+ --optimizer.early_step_in_backward
159
+ Whether to apply optimizer in the backward. Caution,
160
+ optimizer_in_backward is not compatible with gradients
161
+ clipping, users should not call
162
+ register_post_accumulate_grad_hook after the optimizer
163
+ is built.
164
+ --training.batch_size TRAINING.BATCH_SIZE
165
+ Batch size
166
+ --training.seq_len TRAINING.SEQ_LEN
167
+ Sequence length
168
+ --training.context_len TRAINING.CONTEXT_LEN
169
+ Max length allowed for each sequence
170
+ --training.varlen Whether to take sequences of variable length as input
171
+ --training.warmup_steps TRAINING.WARMUP_STEPS
172
+ Steps for lr scheduler warmup, normally 1/5 of
173
+ --training.steps
174
+ --training.gradient_accumulation_steps TRAINING.GRADIENT_ACCUMULATION_STEPS
175
+ Number of steps to accumulate gradients before
176
+ updating parameters
177
+ --training.steps TRAINING.STEPS
178
+ How many train steps to run
179
+ --training.max_norm TRAINING.MAX_NORM
180
+ Max norm for gradient clipping
181
+ --training.skip_nan_inf
182
+ Skip batch updates when NaN or INF gradients are
183
+ encountered during training
184
+ --training.dataset TRAINING.DATASET
185
+ Dataset to use, with comma separated values
186
+ --training.dataset_name TRAINING.DATASET_NAME
187
+ The name of the dataset config, with comma separated
188
+ values if provided
189
+ --training.dataset_split TRAINING.DATASET_SPLIT
190
+ Dataset split to use, with comma separated values if
191
+ provided
192
+ --training.data_dir TRAINING.DATA_DIR
193
+ Data dirs to use, with comma separated values if
194
+ provided
195
+ --training.data_files TRAINING.DATA_FILES
196
+ Data files to use, with comma separated values if
197
+ provided
198
+ --training.data_probs TRAINING.DATA_PROBS
199
+ Data sampling probabilities, with comma separated
200
+ values if provided
201
+ --training.streaming Whether to load dataset in streaming mode, used for
202
+ huge dataset
203
+ --training.num_workers TRAINING.NUM_WORKERS
204
+ Number of subprocesses to use for data loading. 0
205
+ means that the data will be loaded in the main
206
+ process.
207
+ --training.prefetch_factor TRAINING.PREFETCH_FACTOR
208
+ Number of batches loaded in advance by each worker. 2
209
+ means there will be a total of 2 * num_workers batches
210
+ prefetched across all workers.
211
+ --training.data_parallel_replicate_degree TRAINING.DATA_PARALLEL_REPLICATE_DEGREE
212
+ The `data_parallel_replicate_degree` argument
213
+ specifies the degree of data parallelism for weight
214
+ replication. When this value is greater than 1,
215
+ weights will be replicated across
216
+ `data_parallel_replicate_degree` ranks. If
217
+ `data_parallel_shard_degree` is also greater than 1,
218
+ the parallelism method used is HSDP (Hybrid Sharded
219
+ Data Parallelism). Otherwise, the parallelism method
220
+ used is DDP (Distributed Data Parallelism). 1 means
221
+ disabled.
222
+ --training.data_parallel_shard_degree TRAINING.DATA_PARALLEL_SHARD_DEGREE
223
+ The `data_parallel_shard_degree` argument specifies
224
+ the degree of data parallelism for weight sharding.
225
+ When this value is greater than 1, weights will be
226
+ sharded across `data_parallel_shard_degree` ranks. If
227
+ `data_parallel_replicate_degree` is also greater than
228
+ 1, the parallelism method used is HSDP (Hybrid Sharded
229
+ Data Parallelism). Otherwise, the parallelism method
230
+ used is FSDP (Fully Sharded Data Parallelism). -1
231
+ means leftover ranks will be used (After
232
+ DP_REPLICATE/SP/PP). Note that only
233
+ `data_parallel_shard_degree` can be negative. 1 means
234
+ disabled.
235
+ --training.enable_cpu_offload
236
+ Whether to apply CPU offloading of parameters,
237
+ gradients, and optimizer states in FSDP
238
+ --training.tensor_parallel_degree TRAINING.TENSOR_PARALLEL_DEGREE
239
+ Tensor Parallelism degree. 1 means disabled.
240
+ --training.disable_loss_parallel
241
+ Whether to apply loss parallel when sequence parallel
242
+ is enabled
243
+ --training.mixed_precision_param {bfloat16,float32}
244
+ torch dtype to use for parameters when applying mixed
245
+ precision via FSDP. This feature only takes effect
246
+ when data_parallel_shard_degree > 1
247
+ --training.mixed_precision_reduce {float32}
248
+ torch dtype to use for reductions when applying mixed
249
+ precision via FSDP. This feature only takes effect
250
+ when data_parallel_shard_degree > 1
251
+ --training.compile Whether to compile the model
252
+ --training.gc_freq TRAINING.GC_FREQ
253
+ Python garbage control scheduling interval, in steps
254
+ --training.seed TRAINING.SEED
255
+ Choose the base RNG seed used for training
256
+ --training.deterministic
257
+ Use deterministic algorithms wherever possible, may be
258
+ slower
259
+ --metrics.log_freq METRICS.LOG_FREQ
260
+ How often to log metrics to TensorBoard, in iterations
261
+ --metrics.enable_tensorboard
262
+ Whether to log metrics to TensorBoard
263
+ --metrics.disable_color_printing
264
+ Whether to disable color printing in logs
265
+ --metrics.save_tb_folder METRICS.SAVE_TB_FOLDER
266
+ Folder to dump TensorBoard states
267
+ --metrics.rank_0_only
268
+ Whether to save TensorBoard metrics only for rank 0 or
269
+ for all ranks. When pipeline_parallel_degree is > 1,
270
+ this option uses the 0th rank of the last stage
271
+ pipeline group, which is the only stage that computes
272
+ loss metrics.
273
+ --metrics.enable_wandb
274
+ Whether to log metrics to Weights & Biases
275
+ --experimental.enable_async_tensor_parallel
276
+ Whether to apply async tensor parallel (currently only
277
+ effective when compile is enabled)
278
+ --experimental.pipeline_parallel_degree EXPERIMENTAL.PIPELINE_PARALLEL_DEGREE
279
+ Pipeline Parallelism degree, or number of ranks. 1
280
+ means disabled. If using looped schedules, this still
281
+ specifies the number of physical ranks, not the number
282
+ of stages. Stages per rank are inferred from split
283
+ points degree, and schedule.
284
+ --experimental.pipeline_parallel_split_points EXPERIMENTAL.PIPELINE_PARALLEL_SPLIT_POINTS [EXPERIMENTAL.PIPELINE_PARALLEL_SPLIT_POINTS ...]
285
+ Specify comma-separated names of modules to use as the
286
+ beginning of a split point. e.g. "layers.0,layers.2"
287
+ will cause the model to be split into 3 stages, the
288
+ first containing all the layers up to layers.0, the
289
+ second containing layers.0 and up to layers.2, the
290
+ third containing layers.2 and all the remaining
291
+ layers. Note: fully-automated splitting may be enabled
292
+ in the future, but currently the split points must be
293
+ specified manually.
294
+ --experimental.pipeline_parallel_schedule EXPERIMENTAL.PIPELINE_PARALLEL_SCHEDULE
295
+ Specify the Pipeline Parallel schedule to use. The
296
+ supported schedules are: https://github.com/pytorch/py
297
+ torch/blob/de4c2a3b4e89d96334dc678d1c3f2ae51a6630a0/to
298
+ rch/distributed/pipelining/schedules.py#L2161. The
299
+ schedule must be compatible with the split points and
300
+ stages_per_rank. Looped schedules (e.g.
301
+ Interleaved1F1B) require specifying
302
+ pipeline_parallel_degree = number of ranks, and
303
+ split_points = number of stages - 1
304
+ --experimental.pipeline_parallel_schedule_csv EXPERIMENTAL.PIPELINE_PARALLEL_SCHEDULE_CSV
305
+ Specify the path to the pipeline parallel schedule csv
306
+ file to use. The pipeline_parallel_schedule argument
307
+ must be either PipelineScheduleSingle,
308
+ PipelineScheduleMulti, or _PipelineScheduleRuntime.
309
+ --experimental.pipeline_parallel_microbatches EXPERIMENTAL.PIPELINE_PARALLEL_MICROBATCHES
310
+ How many microbatches to split the global training
311
+ batch into when using pipeline parallelism. The global
312
+ training batch size must be evenly divisible by the
313
+ number of microbatches. The default value will be the
314
+ number of pipeline stages, if unspecified.
315
+ --experimental.enable_compiled_autograd
316
+ Enable CompiledAutograd to compile the backward.
317
+ --experimental.context_parallel_degree EXPERIMENTAL.CONTEXT_PARALLEL_DEGREE
318
+ Context parallelism degree. 1 means disabled.
319
+ --experimental.context_parallel_rotate_method EXPERIMENTAL.CONTEXT_PARALLEL_ROTATE_METHOD
320
+ The collective to use in context parallel SDPA for kv
321
+ shards exchange. 'allgather' means to all-gather all
322
+ kv shards on ranks after the first sub-SDPA
323
+ computation, 'alltoall' means to all-to-all shuffle
324
+ the kv shards. The default value is 'allgather'.
325
+ --checkpoint.enable_checkpoint
326
+ Whether to enable checkpoint
327
+ --checkpoint.folder CHECKPOINT.FOLDER
328
+ The folder to store the checkpoints. When
329
+ enable_checkpoint is set to true, checkpoints will be
330
+ in {--job.dump_folder}/{--checkpoint.folder}.
331
+ --checkpoint.interval_type CHECKPOINT.INTERVAL_TYPE
332
+ Checkpointing interval unit of measurement ['step',
333
+ 'seconds']
334
+ --checkpoint.interval CHECKPOINT.INTERVAL
335
+ Checkpointing interval, in steps or seconds depending
336
+ on --checkpoint.interval_type
337
+ --checkpoint.model_weights_only
338
+ When model_weights_only=True, only model weights will
339
+ be saved at the end of training. With this,
340
+ checkpoints can be loaded using `torch.load(...,
341
+ weights_only=True)` after conversion. When
342
+ model_weights_only=False, the full checkpoint will be
343
+ saved. A full checkpoint includes model, optimizer and
344
+ train_state, which can be used to resume training. The
345
+ default value is false.
346
+ --checkpoint.export_dtype {float16,bfloat16,float32}
347
+ Converts to the specified precision when training
348
+ completes and model_weights_only=true. Currently
349
+ supports float32, float16, and bfloat16. The default
350
+ value is float32.
351
+ --checkpoint.create_seed_checkpoint
352
+ Initializes the full model without applying
353
+ parallelisms, and then saves it as a seed checkpoint.
354
+ Note: requires user to call train.py without
355
+ specifying any parallelisms, e.g. NGPU=1. Could be
356
+ implemented as a separate script, but this way shares
357
+ more code.
358
+ --checkpoint.async_mode CHECKPOINT.ASYNC_MODE
359
+ Which async checkpoint mode to use. Currently there
360
+ are 3 different modes. 1. "disabled": synchronized
361
+ checkpointing will be used. 2. "async":
362
+ torch.distributed.checkpoint.async_save will be used.
363
+ 3. "async_with_pinned_mem": this option utilizes a
364
+ dedicated pinned memory space and creates a separate
365
+ process for faster GPU->CPU transfer performance and
366
+ eliminating GIL contention. The cost is increased CPU
367
+ memory usage. If insufficient CPU memory is available,
368
+ performance may degrade due to memory paging. For most
369
+ users, "async" should suffice as the performance
370
+ overhead is typically small (on the order of tens of
371
+ seconds) compared to checkpointing frequency. This
372
+ mode can be employed to pursue near-zero checkpointing
373
+ times (e.g., < 1 second) given appropriate hardware
374
+ support such as ample CPU memory and fast PCIe.
375
+ "disabled" is the default mode.
376
+ --checkpoint.keep_latest_k CHECKPOINT.KEEP_LATEST_K
377
+ Keeps only the latest k checkpoints, purging older
378
+ ones. If 0, keep all checkpoints. 0 is the default
379
+ value.
380
+ --checkpoint.load_step CHECKPOINT.LOAD_STEP
381
+ Load the checkpoint at the specified step. If -1, load
382
+ the latest checkpoint.
383
+ --float8.enable_float8_linear
384
+ If true, swaps `torch.nn.Linear` with `Float8Linear`.
385
+ This feature requires you to install 'torchao' which
386
+ can be found here: https://github.com/pytorch/ao
387
+ --float8.enable_fsdp_float8_all_gather
388
+ Whether enable float8 all-gather in FSDP
389
+ --float8.precompute_float8_dynamic_scale_for_fsdp
390
+ Whether precompute float8 scales dynamically for FSDP
391
+ --float8.scaling_type_input {dynamic,delayed}
392
+ float8 scaling for input, dynamic (default) or delayed
393
+ --float8.scaling_type_weight FLOAT8.SCALING_TYPE_WEIGHT
394
+ float8 scaling for weight, dynamic (default) or delayed
395
+ --float8.scaling_type_grad_output FLOAT8.SCALING_TYPE_GRAD_OUTPUT
396
+ float8 scaling for grad_output, dynamic (default) or delayed
397
+ --comm.init_timeout_seconds COMM.INIT_TIMEOUT_SECONDS
398
+ Timeout for communication operations, during
399
+ initialization and first train step.
400
+ --comm.train_timeout_seconds COMM.TRAIN_TIMEOUT_SECONDS
401
+ Timeout for communication operations after the first
402
+ train step -- usually a tighter bound than during
403
+ initialization.
404
+ --comm.trace_buf_size COMM.TRACE_BUF_SIZE
405
+ Flight recorder ring buffer size, >0 means recording
406
+ by default, 0 means disabled
407
+ --memory_estimation.enabled
408
+ Whether to estimate memory usage for FSDP
409
+ --memory_estimation.disable_fake_mode
410
+ Whether to estimate memory under FakeTensorMode
411
+ ```
412
+ </details>
413
+
414
+ ### Training with `torch.compile`
415
+
416
+ Starting from `torch 2.0`, `torch.compile` has been introduced as a new feature to seamlessly accelerate training processes.
417
+ In `flame`, one can simply enable `torch.compile` by adding `--training.compile` flag to your training script.
418
+
419
+ However, `fla` has integrated numerous fused kernels for acceleration, which may potentially conflict with `torch.compile`.
420
+ We are actively working on resolving these issues to make compilation transparent to users.
421
+ In the meantime, please ensure you are using the latest dependencies.
422
+
423
+ Specifically, **we recommend using `torch>=2.6` and `triton>=3.0`**.
424
+
425
+ ### Training with multiple datasets
426
+
427
+ If you wish to train a model with all-round capabilities (e.g., code, math, and multilingual ability), it's necessary to train on multiple datasets.
428
+ `flame` allows training with multiple datasets easily.
429
+ For example, you can specify the following arguments to train on 6 datasets with different proportions:
430
+
431
+ ```sh
432
+ --training.dataset HuggingFaceFW/fineweb-edu,opencsg/Fineweb-Edu-Chinese-V2.1,OpenCoder-LLM/opc-fineweb-code-corpus,math-ai/AutoMathText,EleutherAI/proof-pile-2,OpenCoder-LLM/opc-fineweb-math-corpus \
433
+ --training.data_probs 0.6,0.15,0.15,0.014,0.058,0.028 \
434
+ ```
435
+
436
+ ### ~Finalizing training~
437
+
438
+ > [!NOTE]
439
+ > We have done this conversion automatically in the training script since our latest updates.
440
+
441
+ Once training is complete, you may want to convert the distributed checkpoints (DCPs) into the 🤗 format for broader use.
442
+ To facilitate this, we provide a straightforward conversion script:
443
+
444
+ ```sh
445
+ python -m flame.utils.convert_dcp_to_hf --path <path_to_model> --step <step> --config <path_to_config> --tokenizer <path_to_tokenizer>
446
+ ```
447
+ After this, your model will be in the 🤗 format, ready to be shared or deployed.
448
+ You can then easily publish your model using the `huggingface_hub` for wider accessibility.
449
+
450
+ ### Continual training
451
+
452
+ If you wish to build upon a strong pre-trained model (in 🤗 format) and continue training, we also offer a script to convert the 🤗 format model back into DCP format.
453
+ This allows you to seamlessly resume training with `flame`.
454
+ ```sh
455
+ python -m flame.utils.convert_hf_to_dcp --model <path_to_hf> --checkpoint <path_to_dcp/checkpoint/step-0>
456
+ ```
457
+ Here, `<path_to_dcp>` is the directory where your distributed checkpoints will be stored.
458
+ The checkpoint is intentionally saved at `<step-0>` within the checkpoint folder to ensure it is loadable by `flame` during the initial training step, similar to how a seed checkpoint is handled.
459
+
460
+ Once the conversion is complete, you can proceed with training using `flame` as usual, continuing from where the pretrained model left off.
461
+
462
+ ## Multi-node training
463
+
464
+ If you have access to multi-node GPUs, consider leveraging them for optimal performance.
465
+ This process is straightforward and well-documented in the PyTorch [docs](https://pytorch.org/docs/stable/elastic/run.html).
466
+
467
+ To set up multi-node training:
468
+ * Set the environment variables `MASTER_ADDR=<ip>` and `MASTER_PORT=<port>` before running the training script across all nodes.
469
+ * If you're using a job scheduler like Slurm, it will handle these variables for you.
470
+
471
+ `torchtitan` provides a [Slurm script](https://github.com/pytorch/torchtitan/blob/main/multinode_trainer.slurm) for multi-node training, which you can use as a reference or starting point.
config.json ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "MTPTransformerForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "bos_token_id": 1,
7
+ "elementwise_affine": true,
8
+ "eos_token_id": 2,
9
+ "fuse_cross_entropy": true,
10
+ "fuse_norm": true,
11
+ "fuse_swiglu": true,
12
+ "hidden_act": "swish",
13
+ "hidden_ratio": 4,
14
+ "hidden_size": 1024,
15
+ "initializer_range": 0.006,
16
+ "intermediate_size": null,
17
+ "max_position_embeddings": 8192,
18
+ "model_type": "mtp_transformer",
19
+ "n_future_tokens": 3,
20
+ "norm_eps": 1e-06,
21
+ "num_heads": 16,
22
+ "num_hidden_layers": 24,
23
+ "num_kv_heads": null,
24
+ "qk_norm": false,
25
+ "qkv_bias": false,
26
+ "rope_theta": 10000.0,
27
+ "tie_word_embeddings": false,
28
+ "torch_dtype": "float32",
29
+ "transformers_version": "4.51.3",
30
+ "use_cache": true,
31
+ "use_custom_backward": false,
32
+ "vocab_size": 32000,
33
+ "window_size": null
34
+ }
configs/delta_net_1B.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "attn": null,
3
+ "attn_mode": "chunk",
4
+ "bos_token_id": 1,
5
+ "conv_size": 4,
6
+ "eos_token_id": 2,
7
+ "expand_k": 1,
8
+ "expand_v": 1,
9
+ "fuse_cross_entropy": true,
10
+ "fuse_norm": true,
11
+ "hidden_act": "swish",
12
+ "hidden_ratio": 4,
13
+ "hidden_size": 2048,
14
+ "initializer_range": 0.006,
15
+ "intermediate_size": null,
16
+ "model_type": "delta_net",
17
+ "norm_eps": 1e-06,
18
+ "num_heads": 16,
19
+ "num_hidden_layers": 24,
20
+ "pad_token_id": 2,
21
+ "qk_activation": "silu",
22
+ "qk_norm": "l2",
23
+ "tie_word_embeddings": false,
24
+ "use_beta": true,
25
+ "use_cache": true,
26
+ "use_gate": false,
27
+ "use_output_norm": true,
28
+ "use_short_conv": true
29
+ }
configs/delta_net_340M.json ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "attn_mode": "chunk",
3
+ "bos_token_id": 1,
4
+ "conv_size": 4,
5
+ "eos_token_id": 2,
6
+ "expand_k": 1,
7
+ "expand_v": 1,
8
+ "fuse_cross_entropy": true,
9
+ "hidden_act": "swish",
10
+ "hidden_ratio": 4,
11
+ "hidden_size": 1024,
12
+ "initializer_range": 0.006,
13
+ "intermediate_size": null,
14
+ "model_type": "delta_net",
15
+ "norm_eps": 1e-06,
16
+ "norm_first": false,
17
+ "num_heads": 8,
18
+ "num_hidden_layers": 24,
19
+ "qk_activation": "silu",
20
+ "qk_norm": "l2",
21
+ "tie_word_embeddings": false,
22
+ "use_beta": true,
23
+ "use_cache": true,
24
+ "use_gate": false,
25
+ "use_output_norm": true,
26
+ "use_short_conv": true
27
+ }
configs/gla_340M.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "attn_mode": "chunk",
3
+ "bos_token_id": 1,
4
+ "clamp_min": null,
5
+ "eos_token_id": 2,
6
+ "expand_k": 0.5,
7
+ "expand_v": 1,
8
+ "fuse_cross_entropy": true,
9
+ "fuse_norm": true,
10
+ "hidden_act": "swish",
11
+ "hidden_ratio": 4,
12
+ "hidden_size": 1024,
13
+ "initializer_range": 0.006,
14
+ "intermediate_size": null,
15
+ "model_type": "gla",
16
+ "num_heads": 4,
17
+ "num_hidden_layers": 24,
18
+ "norm_eps": 1e-06,
19
+ "tie_word_embeddings": false,
20
+ "use_cache": true,
21
+ "use_gk": true,
22
+ "use_gv": false,
23
+ "vocab_size": 32000
24
+ }
configs/gla_7B.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "attn": null,
3
+ "attn_mode": "chunk",
4
+ "bos_token_id": 1,
5
+ "eos_token_id": 2,
6
+ "expand_k": 0.5,
7
+ "expand_v": 1,
8
+ "fuse_cross_entropy": true,
9
+ "fuse_norm": true,
10
+ "hidden_act": "swish",
11
+ "hidden_ratio": 4,
12
+ "hidden_size": 4096,
13
+ "initializer_range": 0.006,
14
+ "intermediate_size": 11008,
15
+ "model_type": "gla",
16
+ "norm_eps": 1e-06,
17
+ "num_heads": 16,
18
+ "num_hidden_layers": 32,
19
+ "tie_word_embeddings": false,
20
+ "use_cache": true,
21
+ "use_gk": true,
22
+ "use_gv": false,
23
+ "use_output_gate": true,
24
+ "use_short_conv": false
25
+ }
configs/gsa_340M.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 1,
3
+ "conv_size": 4,
4
+ "eos_token_id": 2,
5
+ "expand_k": 1,
6
+ "expand_v": 1,
7
+ "elementwise_affine": false,
8
+ "feature_map": "swish",
9
+ "fuse_cross_entropy": true,
10
+ "fuse_norm": true,
11
+ "gate_logit_normalizer": 4,
12
+ "hidden_act": "swish",
13
+ "hidden_ratio": 4,
14
+ "hidden_size": 1024,
15
+ "initializer_range": 0.006,
16
+ "intermediate_size": null,
17
+ "model_type": "gsa",
18
+ "num_heads": 4,
19
+ "num_hidden_layers": 24,
20
+ "num_slots": 64,
21
+ "norm_eps": 1e-06,
22
+ "share_conv_kernel": true,
23
+ "tie_word_embeddings": false,
24
+ "use_cache": true,
25
+ "use_norm": true,
26
+ "use_output_gate": true,
27
+ "use_rope": false,
28
+ "use_short_conv": false
29
+ }
configs/mtp_transformer_340M.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "hidden_act": "swish",
8
+ "hidden_size": 1024,
9
+ "initializer_range": 0.006,
10
+ "max_position_embeddings": 8192,
11
+ "model_type": "mtp_transformer",
12
+ "num_heads": 16,
13
+ "num_hidden_layers": 24,
14
+ "norm_eps": 1e-06,
15
+ "tie_word_embeddings": false,
16
+ "use_cache": true,
17
+ "vocab_size": 32000,
18
+ "n_future_tokens": 3
19
+ }
configs/top_transformer_1B.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 1,
3
+ "elementwise_affine": true,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "fuse_swiglu": true,
8
+ "hidden_act": "swish",
9
+ "hidden_ratio": 4,
10
+ "hidden_size": 2048,
11
+ "initializer_range": 0.006,
12
+ "intermediate_size": null,
13
+ "max_position_embeddings": 8192,
14
+ "model_type": "top_transformer",
15
+ "norm_eps": 1e-06,
16
+ "num_heads": 32,
17
+ "num_hidden_layers": 32,
18
+ "num_kv_heads": null,
19
+ "pad_token_id": 2,
20
+ "rope_theta": 10000.0,
21
+ "tie_word_embeddings": false,
22
+ "use_top_loss": true,
23
+ "top_window_size": 4096
24
+ }
configs/top_transformer_340M.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "hidden_act": "swish",
8
+ "hidden_size": 1024,
9
+ "initializer_range": 0.006,
10
+ "max_position_embeddings": 8192,
11
+ "model_type": "top_transformer",
12
+ "num_heads": 16,
13
+ "num_hidden_layers": 24,
14
+ "norm_eps": 1e-06,
15
+ "tie_word_embeddings": false,
16
+ "use_cache": true,
17
+ "vocab_size": 32000,
18
+ "use_top_loss": true,
19
+ "top_window_size": 4096
20
+ }
configs/transformer_120M.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": false,
7
+ "hidden_act": "swish",
8
+ "hidden_size": 768,
9
+ "initializer_range": 0.02,
10
+ "max_position_embeddings": 4096,
11
+ "model_type": "transformer",
12
+ "num_heads": 12,
13
+ "num_hidden_layers": 14,
14
+ "norm_eps": 1e-06,
15
+ "tie_word_embeddings": true,
16
+ "use_cache": true,
17
+ "vocab_size": 32000
18
+ }
configs/transformer_7B.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "hidden_act": "swish",
8
+ "hidden_ratio": 4,
9
+ "hidden_size": 4096,
10
+ "initializer_range": 0.006,
11
+ "intermediate_size": 14336,
12
+ "model_type": "transformer",
13
+ "norm_eps": 1e-06,
14
+ "num_heads": 32,
15
+ "num_hidden_layers": 30,
16
+ "num_kv_heads": 8,
17
+ "rope_theta": 10000.0,
18
+ "tie_word_embeddings": false,
19
+ "use_cache": true,
20
+ "window_size": null
21
+ }
fla/utils.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import contextlib
4
+ import functools
5
+ import os
6
+ from enum import Enum
7
+ from functools import lru_cache
8
+ from typing import Any, Callable, Dict, Literal, Optional, Tuple
9
+
10
+ import torch
11
+ import triton
12
+ from packaging import version
13
+
14
+
15
def tensor_cache(
    fn: Callable[..., torch.Tensor]
) -> Callable[..., torch.Tensor]:
    """
    Decorator that memoizes the single most recent call of ``fn``.

    The wrapper remembers the last positional/keyword arguments together with
    the produced result. Arguments are compared by *identity* (``is``), which
    is cheap and safe for tensors; a repeated call with the very same objects
    returns the cached value without re-executing ``fn``.

    Args:
        fn (Callable[..., torch.Tensor]):
            Function taking tensor inputs and returning tensor outputs.

    Returns:
        Callable[..., torch.Tensor]:
            ``fn`` wrapped with a one-entry cache.
    """
    cached_args: Optional[Tuple] = None
    cached_kwargs: Optional[Dict] = None
    cached_out: Any = None

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        nonlocal cached_args, cached_kwargs, cached_out

        hit = (
            cached_args is not None
            and cached_kwargs is not None
            and len(args) == len(cached_args)
            and len(kwargs) == len(cached_kwargs)
            and all(x is y for x, y in zip(args, cached_args))
            and all(k in cached_kwargs and v is cached_kwargs[k] for k, v in kwargs.items())
        )
        if hit:
            return cached_out

        cached_out = fn(*args, **kwargs)
        cached_args, cached_kwargs = args, kwargs
        return cached_out

    return wrapper
52
+
53
+
54
def input_guard(
    fn: Callable[..., torch.Tensor]
) -> Callable[..., torch.Tensor]:
    """
    Decorator that makes every tensor argument contiguous and runs ``fn``
    under the device context of the first tensor argument found.
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        def as_contiguous(v):
            return v.contiguous() if isinstance(v, torch.Tensor) else v

        fixed_args = tuple(as_contiguous(a) for a in args)
        fixed_kwargs = {k: as_contiguous(v) for k, v in kwargs.items()}

        # Pick the first tensor (positional first, then keyword) to decide
        # which device context to enter; fall back to a no-op context.
        anchor = next(
            (a for a in args if isinstance(a, torch.Tensor)),
            next((v for v in kwargs.values() if isinstance(v, torch.Tensor)), None),
        )
        if anchor is not None:
            ctx = custom_device_ctx(anchor.device.index)
        else:
            ctx = contextlib.nullcontext()

        with ctx:
            return fn(*fixed_args, **fixed_kwargs)

    return wrapper
86
+
87
+
88
# Public alias of `input_guard`, kept so callers using the older name still work.
contiguous = input_guard
89
+
90
+
91
def require_version(version, hint):
    """
    Decorator performing a runtime dependency-version check, using the exact
    pip specifier syntax, before invoking the wrapped function. Tensor
    arguments are additionally made contiguous.
    """
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(ctx, *args, **kwargs):
            # Imported lazily so `transformers` stays an optional dependency.
            from transformers.utils.versions import require_version as _require_version
            _require_version(version, hint)

            def as_contiguous(v):
                return v.contiguous() if isinstance(v, torch.Tensor) else v

            return fn(
                ctx,
                *(as_contiguous(a) for a in args),
                **{k: as_contiguous(v) for k, v in kwargs.items()},
            )
        return wrapper
    return decorator
105
+
106
+
107
def checkpoint(fn):
    """
    Decorator that runs ``fn`` under activation checkpointing
    (``torch.utils.checkpoint.checkpoint``): intermediate activations are
    recomputed during the backward pass instead of being stored.

    Extra keyword arguments passed to the wrapped function (e.g.
    ``use_reentrant``) are forwarded to ``torch.utils.checkpoint.checkpoint``.
    """
    # functools.wraps preserves the wrapped function's name/docstring, which
    # the original implementation lost.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return torch.utils.checkpoint.checkpoint(fn, *args, **kwargs)
    return wrapper
111
+
112
+
113
@lru_cache(maxsize=None)
def check_pytorch_version(version_s: str = '2.4') -> bool:
    """Return True iff the installed PyTorch is at least ``version_s``; memoized."""
    required = version.parse(version_s)
    return version.parse(torch.__version__) >= required
116
+
117
+
118
+ def _cpu_device_warning():
119
+ import warnings
120
+ warnings.warn(('Triton is not supported on current platform, roll back to CPU.'), stacklevel=1)
121
+
122
+
123
@lru_cache(maxsize=None)
def get_multiprocessor_count(tensor_idx: int = 0) -> int:
    # Return the multiprocessor (SM) count of the current CUDA device, or -1
    # (with a CPU-fallback warning) when the query fails.
    # NOTE(review): `tensor_idx` is not forwarded to the query below; it only
    # differentiates lru_cache entries — presumably safe only on homogeneous
    # multi-GPU hosts, as the comment below suggests. Verify before reuse.
    try:
        # Only works if Homogeneous hardware
        # TEMPORARY FIX since old version introduce graph break
        return torch.cuda.get_device_properties().multi_processor_count
    except BaseException:
        _cpu_device_warning()
        return -1
132
+
133
+
134
@lru_cache(maxsize=None)
def get_available_device() -> str:
    # Return the active triton backend name (e.g. 'cuda', 'hip', 'xpu'),
    # or 'cpu' with a warning when triton has no usable driver.
    try:
        return triton.runtime.driver.active.get_current_target().backend
    except BaseException:
        _cpu_device_warning()
        return 'cpu'
141
+
142
+
143
@lru_cache(maxsize=None)
def _check_platform() -> Literal['nvidia', 'amd', 'intel', 'musa']:
    """Map the active triton backend name onto a GPU vendor identifier."""
    backend_to_vendor = {
        'cuda': 'nvidia',
        'hip': 'amd',
        'xpu': 'intel',
    }
    device = get_available_device()
    # Unknown backends (e.g. 'musa' or the 'cpu' fallback) pass through unchanged.
    return backend_to_vendor.get(device, device)
154
+
155
+
156
# For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'.
# However, the torch backend is 'cuda' for both Nvidia and AMD GPUs.
# Therefore, we need to check the triton backend to determine the actual GPU vendor.
device = get_available_device() if get_available_device() != 'hip' else 'cuda'
# Torch submodule matching the device string (e.g. `torch.cuda`, `torch.xpu`).
device_torch_lib = getattr(torch, device)
device_platform = _check_platform()

# Vendor flags resolved once at import time.
is_amd = (device_platform == 'amd')
is_intel = (device_platform == 'intel')
is_nvidia = (device_platform == 'nvidia')
# The device-name queries below are short-circuit guarded by the vendor flags,
# so they only run when the matching accelerator stack is present.
is_intel_alchemist = (is_intel and 'Intel(R) Arc(TM) A' in torch.xpu.get_device_name(0))
is_nvidia_hopper = (is_nvidia and ('NVIDIA H' in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9))
# CUDA-graph usage is opt-in via the FLA_USE_CUDA_GRAPH environment variable.
use_cuda_graph = (is_nvidia and os.environ.get('FLA_USE_CUDA_GRAPH', '0') == '1')

# Nvidia Ampere or newer, haven't check AMD and intel yet.
is_tf32_supported = (is_nvidia and torch.cuda.get_device_capability(0)[0] >= 8)
# `tl.gather` only exists in newer triton releases.
is_gather_supported = hasattr(triton.language, 'gather')
173
+
174
+
175
def get_all_max_shared_mem():
    # Return the maximum shared memory (in bytes) for every visible device,
    # or [-1] with a CPU-fallback warning when the triton driver query fails.
    try:
        return [
            triton.runtime.driver.active.utils.get_device_properties(i)['max_shared_mem']
            for i in range(device_torch_lib.device_count())
        ]
    except BaseException:
        _cpu_device_warning()
        return [-1]
184
+
185
+
186
class Backend(Enum):
    """Shared-memory capacity (bytes) per GPU architecture."""

    ADA = 101376  # RTX 4090
    AMPERE = 166912  # A100
    HOPPER = 232448  # H100
    DEFAULT = 102400  # Default

    @classmethod
    def get_shared_memory(cls, arch: str) -> int:
        """Return the shared-memory budget for ``arch`` (case-insensitive), falling back to DEFAULT."""
        member = cls.__members__.get(arch.upper())
        return (member if member is not None else cls.DEFAULT).value
198
+
199
+
200
@lru_cache(maxsize=None)
def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool:
    """Whether device ``tensor_idx`` offers at least the shared-memory budget of ``arch``."""
    try:
        available = get_all_max_shared_mem()[tensor_idx]
        return available >= Backend.get_shared_memory(arch)
    except Exception:
        # Any failure (bad index, driver error) is treated as "not enough memory".
        return False
208
+
209
+
210
# Select the autocast helpers and the device-context factory based on the
# installed PyTorch version.
if check_pytorch_version('2.4'):
    # torch>=2.4 exposes device-generic `torch.amp.custom_fwd/bwd`; bind the
    # detected device type. The 'cpu' fallback still targets 'cuda' here.
    device = 'cuda' if device == 'cpu' else device
    autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=device)
    autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=device)

    def custom_device_ctx(index: int):
        # Device context from the detected torch backend (cuda/xpu/...).
        return device_torch_lib.device(index)
else:
    # Older torch only has per-backend amp helpers; CUDA is required.
    assert device == 'cuda', 'Only cuda device is supported for PyTorch version < 2.4.0.'
    autocast_custom_fwd = device_torch_lib.amp.custom_fwd
    autocast_custom_bwd = device_torch_lib.amp.custom_bwd

    def custom_device_ctx(index: int):
        return torch.cuda.device(index)
generation_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "transformers_version": "4.51.3"
6
+ }
logs/none_ro0qpaac/attempt_0/0/stderr.log ADDED
The diff for this file is too large to render. See raw diff
 
logs/none_ro0qpaac/attempt_0/1/stderr.log ADDED
The diff for this file is too large to render. See raw diff
 
logs/none_ro0qpaac/attempt_0/3/stderr.log ADDED
The diff for this file is too large to render. See raw diff
 
logs/none_ro0qpaac/attempt_0/5/stderr.log ADDED
The diff for this file is too large to render. See raw diff
 
logs/none_ro0qpaac/attempt_0/6/stderr.log ADDED
The diff for this file is too large to render. See raw diff
 
logs/none_ro0qpaac/attempt_0/7/stderr.log ADDED
The diff for this file is too large to render. See raw diff
 
pyproject.toml ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "flame"
3
+ dynamic = ["version"]
4
+ description = "A minimal training framework for scaling FLA models"
5
+ readme = "README.md"
6
+ authors = [
7
+ { name = "Songlin Yang", email = "yangsl66@mit.edu" },
8
+ { name = "Yu Zhang", email = "yzhang.cs@outlook.com" },
9
+ ]
10
+ license = { file = "LICENSE" }
11
+ classifiers = [
12
+ "Programming Language :: Python :: 3",
13
+ "License :: OSI Approved :: MIT License",
14
+ "Operating System :: OS Independent",
15
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
16
+ ]
17
+ requires-python = ">=3.10"
18
+ dependencies = [
19
+ 'torch==2.6',
20
+ 'torchdata',
21
+ 'transformers==4.51.3',
22
+ 'triton>=3.0',
23
+ 'datasets>=3.3.0',
24
+ 'einops',
25
+ 'ninja',
26
+ 'wandb',
27
+ 'tiktoken',
28
+ 'tensorboard',
29
+ 'python-dotenv'
30
+ ]
31
+
32
+ [project.optional-dependencies]
33
+ dev = ["pytest"]
34
+
35
+ [project.urls]
36
+ Homepage = "https://github.com/fla-org/flame"
37
+
38
+ [build-system]
39
+ requires = ["setuptools>=45", "wheel", "ninja", "torch"]
40
+
41
+ [tool.isort]
42
+ line_length = 127
43
+ multi_line_output = 3
setup.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import ast
4
+ import os
5
+ import re
6
+ from pathlib import Path
7
+
8
+ from setuptools import find_packages, setup
9
+
10
# Read the long description from the README; an explicit encoding avoids
# locale-dependent decoding errors on some platforms.
with open('README.md', encoding='utf-8') as f:
    long_description = f.read()
12
+
13
+
14
def get_package_version():
    """
    Extract ``__version__`` from ``flame/__init__.py`` without importing the
    package (importing could pull in heavy dependencies at build time).

    Raises:
        RuntimeError: if no ``__version__`` assignment is found.
    """
    init_path = Path(os.path.dirname(os.path.abspath(__file__))) / 'flame' / '__init__.py'
    with open(init_path, encoding='utf-8') as f:
        version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
    if version_match is None:
        # Fail with a clear message instead of an AttributeError on None.
        raise RuntimeError(f"Unable to find __version__ in {init_path}")
    # literal_eval strips the quotes around the version string safely.
    return ast.literal_eval(version_match.group(1))
18
+
19
+
20
# Install metadata; kept in sync with pyproject.toml, which declares the same
# dependencies and classifiers.
setup(
    name='flame',
    version=get_package_version(),
    description='A minimal training framework for scaling FLA models',
    long_description=long_description,
    long_description_content_type='text/markdown',
    author='Songlin Yang, Yu Zhang',
    author_email='yangsl66@mit.edu, yzhang.cs@outlook.com',
    url='https://github.com/fla-org/flame',
    packages=find_packages(),
    license='MIT',
    classifiers=[
        'Programming Language :: Python :: 3',
        'License :: OSI Approved :: MIT License',
        'Operating System :: OS Independent',
        'Topic :: Scientific/Engineering :: Artificial Intelligence'
    ],
    python_requires='>=3.10',
    install_requires=[
        'torch==2.6',
        'torchdata',
        'transformers==4.51.3',
        'triton>=3.0',
        'datasets>=3.3.0',
        'einops',
        'ninja',
        'wandb',
        'tiktoken',
        'tensorboard',
        'python-dotenv'
    ],
)
special_tokens_map.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "</s>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "unk_token": {
17
+ "content": "<unk>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ }
23
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "add_prefix_space": null,
5
+ "added_tokens_decoder": {
6
+ "0": {
7
+ "content": "<unk>",
8
+ "lstrip": false,
9
+ "normalized": false,
10
+ "rstrip": false,
11
+ "single_word": false,
12
+ "special": true
13
+ },
14
+ "1": {
15
+ "content": "<s>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false,
20
+ "special": true
21
+ },
22
+ "2": {
23
+ "content": "</s>",
24
+ "lstrip": false,
25
+ "normalized": false,
26
+ "rstrip": false,
27
+ "single_word": false,
28
+ "special": true
29
+ }
30
+ },
31
+ "additional_special_tokens": [],
32
+ "bos_token": "<s>",
33
+ "clean_up_tokenization_spaces": false,
34
+ "eos_token": "</s>",
35
+ "extra_special_tokens": {},
36
+ "legacy": true,
37
+ "model_max_length": 1000000000000000019884624838656,
38
+ "pad_token": null,
39
+ "sp_model_kwargs": {},
40
+ "spaces_between_special_tokens": false,
41
+ "tokenizer_class": "LlamaTokenizer",
42
+ "unk_token": "<unk>",
43
+ "use_default_system_prompt": false
44
+ }
torchtitan/components/__pycache__/checkpoint.cpython-312.pyc ADDED
Binary file (33.1 kB). View file
 
torchtitan/components/__pycache__/dataloader.cpython-312.pyc ADDED
Binary file (3.79 kB). View file
 
torchtitan/components/__pycache__/float8.cpython-312.pyc ADDED
Binary file (6.2 kB). View file
 
torchtitan/components/__pycache__/lr_scheduler.cpython-312.pyc ADDED
Binary file (7.71 kB). View file
 
torchtitan/components/__pycache__/metrics.cpython-312.pyc ADDED
Binary file (19.6 kB). View file
 
torchtitan/components/__pycache__/optimizer.cpython-312.pyc ADDED
Binary file (14.5 kB). View file
 
torchtitan/components/__pycache__/tokenizer.cpython-312.pyc ADDED
Binary file (1.09 kB). View file
 
torchtitan/distributed/utils.py ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import contextlib
8
+ import math
9
+ import os
10
+ from collections.abc import Generator, Iterable
11
+ from datetime import timedelta
12
+
13
+ import torch
14
+ import torch.distributed._functional_collectives as funcol
15
+ import torch.distributed.distributed_c10d as c10d
16
+ from torch import distributed as dist
17
+ from torch.distributed.device_mesh import DeviceMesh
18
+ from torch.distributed.tensor import DTensor
19
+
20
+ from torchtitan.components.ft import ft_clip_grad_norm_util, ft_dist_reduce
21
+ from torchtitan.tools.logging import logger
22
+ from torchtitan.tools.utils import device_module, device_type
23
+
24
+
25
def _dist_reduce(x: torch.Tensor, reduceOp: str, mesh: DeviceMesh) -> float:
    # All-reduce a single-element tensor over `mesh` with `reduceOp` and
    # return the resulting Python scalar.
    # Remove FT replicate dimension if it exists.
    x, reduceOp, mesh = ft_dist_reduce(x, reduceOp, mesh)

    if isinstance(x, DTensor):
        # functional collectives do not support DTensor inputs
        x = x.full_tensor()
    assert x.numel() == 1  # required by `.item()`
    return funcol.all_reduce(x, reduceOp=reduceOp, group=mesh).item()
34
+
35
+
36
def dist_max(x: torch.Tensor, mesh: DeviceMesh) -> float:
    """Global maximum of the single-element tensor ``x`` across ``mesh``."""
    return _dist_reduce(x, reduceOp=c10d.ReduceOp.MAX.name, mesh=mesh)
38
+
39
+
40
def dist_mean(x: torch.Tensor, mesh: DeviceMesh) -> float:
    """Global mean of the single-element tensor ``x`` across ``mesh``."""
    return _dist_reduce(x, reduceOp=c10d.ReduceOp.AVG.name, mesh=mesh)
42
+
43
+
44
def set_determinism(
    world_mesh: DeviceMesh | None,
    device: torch.device,
    seed: int | None = None,
    deterministic: bool = False,
) -> None:
    """
    Set the same DTensor manual seed for all ranks within the same DTensor SPMD group, but different
    seeds across PP groups (if applicable).

    Currently, does not set seeds for the CUDA RNG since TorchTitan always uses DTensor for SPMD parallelisms,
    and DTensor manages its own RNG tracker, but we could extend to support both if needed.

    Set Determinism flags for increased reproducibility with loss of performance.
    """
    if deterministic:
        logger.info("Deterministic algorithm enabled (expect perf degradation).")
        torch.use_deterministic_algorithms(True)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        # env var for deterministic CuBLAS
        # https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

    if not world_mesh:
        if seed is not None:
            torch.manual_seed(seed)
            os.environ["PYTHONHASHSEED"] = str(seed % 2**32)
            logger.debug(f"Single-process job using seed: {seed}")
        return

    # to ensure we can control which ranks have same or different seeds, all ranks agree on a starting seed.
    # if user provides one, we use this. Otherwise rank 0 rolls the dice and everyone else uses that.
    if seed is None:
        # Extract the seed for torch's main generator on rank 0 and standardizes on using that to build
        # seeds for unique SPMD groups
        seed_tensor = torch.get_rng_state()[:8].to(device)
        torch.distributed.broadcast(seed_tensor, src=0)
        # Reinterpret the 8 broadcast bytes as one uint64 seed value.
        seed = seed_tensor.to("cpu").view(torch.uint64).item()

    # For PP + SPMD cases, we want to separate the world into the SPMD mesh and the PP mesh,
    # and choose a unique seed for each rank on the PP mesh.
    if c10d.get_world_size() > 1 and "pp" in world_mesh.mesh_dim_names:
        pp_mesh = world_mesh["pp"]
        seed += pp_mesh.get_local_rank()
        seed %= 2**64

        logger.debug(
            f"PP rank {pp_mesh.get_local_rank()}, Global rank {c10d.get_rank()} using seed: {seed}"
        )
        spmd_mesh_dims = list(
            filter(lambda name: name != "pp", world_mesh.mesh_dim_names)
        )
        spmd_mesh = world_mesh[spmd_mesh_dims] if len(spmd_mesh_dims) else None
    else:
        spmd_mesh = world_mesh
        logger.debug(f"Global Rank {c10d.get_rank()} using seed: {seed}")

    # The native RNGs and python RNG may not be important, except for the 1-D PP case, but we seed them for consistency.
    torch.manual_seed(seed)
    # PYTHONHASHSEED can be a decimal number in the range [0, 2**32 - 1]
    os.environ["PYTHONHASHSEED"] = str(seed % 2**32)

    # As long as we are not in the 1-D (PP-only) case, we will have a seed to use for all ranks of the SPMD mesh.
    # IF PP is also used, this seed is unique per PP rank.
    if spmd_mesh and spmd_mesh.get_coordinate() is not None:
        torch.distributed.tensor._random.manual_seed(seed, spmd_mesh)
111
+
112
+
113
def create_context_parallel_ctx(
    cp_mesh: DeviceMesh,
    cp_buffers: list[torch.Tensor],
    cp_seq_dims: list[int],
    cp_no_restore_buffers: set[torch.Tensor],
    cp_rotate_method: str,
):
    """
    Build a context-parallel context manager sharding ``cp_buffers`` along
    their sequence dimensions over ``cp_mesh``.

    Args:
        cp_mesh: device mesh for the context-parallel dimension.
        cp_buffers: tensors to shard along their sequence dimension.
        cp_seq_dims: per-buffer sequence-dimension indices (parallel to ``cp_buffers``).
        cp_no_restore_buffers: buffers whose original content need not be restored on exit.
        cp_rotate_method: rotation strategy forwarded to ``set_rotate_method``.

    Raises:
        ImportError: if this PyTorch build lacks the experimental Context Parallel API.
    """
    try:
        from torch.distributed.tensor.experimental import context_parallel
        from torch.distributed.tensor.experimental._attention import set_rotate_method
    except ImportError as e:
        # Fail loudly: the original code only printed a message and then fell
        # through to `set_rotate_method`, producing a confusing NameError.
        raise ImportError(
            f"PyTorch version {torch.__version__} does not include the experimental "
            "Context Parallel API. Please update to a newer version."
        ) from e

    set_rotate_method(cp_rotate_method)
    return context_parallel(
        cp_mesh,
        buffers=cp_buffers,
        buffer_seq_dims=cp_seq_dims,
        no_restore_buffers=cp_no_restore_buffers,
    )
136
+
137
+
138
def get_train_context(
    enable_loss_parallel: bool, enable_compiled_autograd: bool
) -> Generator[None, None, None]:
    # Build a reusable training context manager that (optionally) combines
    # loss parallelism, compiled autograd, and a context-parallel region.
    @contextlib.contextmanager
    def context(cp_context: Generator[None, None, None] | None = None):
        with contextlib.ExitStack() as stack:
            if enable_loss_parallel:
                stack.enter_context(torch.distributed.tensor.parallel.loss_parallel())

            if enable_compiled_autograd:
                stack.enter_context(
                    torch._dynamo.utils.maybe_enable_compiled_autograd(True)
                )

            if cp_context is not None:
                from torch.nn.attention import sdpa_kernel, SDPBackend

                # Restrict SDPA to this fixed set of backends while the
                # context-parallel region is active.
                stack.enter_context(
                    sdpa_kernel(
                        [
                            SDPBackend.FLASH_ATTENTION,
                            SDPBackend.EFFICIENT_ATTENTION,
                            SDPBackend.CUDNN_ATTENTION,
                        ]
                    )
                )
                stack.enter_context(cp_context)

            yield

    return context
169
+
170
+
171
def init_distributed(job_config):
    """
    Initialize the default process group, configuring NCCL flight-recorder and
    error-handling environment variables from ``job_config`` beforehand.
    """
    def _warn_overwrite_env(env, val):
        # Overwrite an env var, warning if the user had set it themselves.
        if env in os.environ:
            logger.warning(
                f"ENV[{env}] = {os.environ[env]} will be overridden to {val} based on job config"
            )
        os.environ[env] = val

    def _get_distributed_backend(job_config):
        # Pick the backend for the current device type; append gloo for CPU
        # when CPU offload is enabled.
        backend = "nccl"
        if device_type in torch.distributed.Backend.default_device_backend_map:
            backend = torch.distributed.Backend.default_device_backend_map.get(
                device_type
            )
        if job_config.training.enable_cpu_offload:
            backend = f"{device_type}:{backend},cpu:gloo"
        return backend

    TRACE_BUFFER_SIZE = "TORCH_NCCL_TRACE_BUFFER_SIZE"
    TRACE_FILE = "TORCH_NCCL_DEBUG_INFO_TEMP_FILE"
    DUMP_ON_TIMEOUT = "TORCH_NCCL_DUMP_ON_TIMEOUT"
    ASYNC_ERROR_HANDLING = "TORCH_NCCL_ASYNC_ERROR_HANDLING"
    SKIP_CLEANUP = "3"

    # FlightRecorder is incompatible with =1 mode where watchdog aborts work, must use =3 (skipcleanup)
    # to get flight recorder dumps. See https://github.com/pytorch/pytorch/issues/121055
    # This could be done only when flight recorder is enabled, but its nice to be consistent to avoid subtle
    # behavior differences
    _warn_overwrite_env(ASYNC_ERROR_HANDLING, SKIP_CLEANUP)

    # enable torch nccl flight recorder in the mode that would dump files if timeout is detected
    _warn_overwrite_env(TRACE_BUFFER_SIZE, str(job_config.comm.trace_buf_size))
    if job_config.comm.trace_buf_size > 0:
        # dump on timeout by default if trace buffer is enabled
        _warn_overwrite_env(DUMP_ON_TIMEOUT, "1")
        dump_dir = f"{job_config.job.dump_folder}/comm_trace"
        os.makedirs(dump_dir, exist_ok=True)
        _warn_overwrite_env(TRACE_FILE, f"{dump_dir}/rank_")

    # to mitigate the memory issue that collectives using
    # async_op=True hold memory longer than they should
    # such as those in tensor parallelism
    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

    torch.distributed.init_process_group(
        backend=_get_distributed_backend(job_config),
        timeout=timedelta(seconds=job_config.comm.init_timeout_seconds),
    )
219
+
220
+
221
def set_pg_timeouts(timeout: timedelta, world_mesh: DeviceMesh) -> None:
    """
    Sets the timeout for all PGs in the provided mesh, and the default (world) group.

    Note: synchronizes via a barrier, before changing the timeouts. This is important, because
    otherwise you may face a race where the slow rank has not reached the timeout reduction point
    yet due to slow operations permitted under the old timeout value, but other faster ranks may
    start issuing collectives under the new shorter timeout and then immediately timeout.
    """
    logger.info(
        f"Synchronizing and adjusting timeout for all ProcessGroups to {timeout}"
    )
    # Ensure that all the ranks have reached the point of setting the new timeout-
    # otherwise, some ranks may issue collectives with the new/shorter timeout and
    # those may time out, before other ranks have finished with initialization done
    # under the old/slow timeout.
    torch.distributed.barrier(device_ids=[device_module.current_device()])
    device_module.synchronize()

    groups = [world_mesh.get_group(mesh_dim) for mesh_dim in range(world_mesh.ndim)]

    # None represents the 'default' PG, not part of the mesh
    groups.append(None)
    for group in groups:
        # NOTE(review): uses a private c10d API (_set_pg_timeout); may break
        # across torch versions — confirm on upgrade.
        torch.distributed.distributed_c10d._set_pg_timeout(timeout, group)
246
+
247
+
248
+ @torch.no_grad()
249
+ def clip_grad_norm_(
250
+ parameters: torch.Tensor | Iterable[torch.Tensor],
251
+ max_norm: float,
252
+ norm_type: float = 2.0,
253
+ error_if_nonfinite: bool = False,
254
+ foreach: bool | None = None,
255
+ pp_mesh: DeviceMesh | None = None,
256
+ ) -> torch.Tensor:
257
+ """
258
+ Clip the gradient norm of an iterable of parameters.
259
+
260
+ Gradient norm clipping requires computing the gradient norm over the entire model.
261
+ `torch.nn.utils.clip_grad_norm_` only computes gradient norm along DP/FSDP/TP dimensions.
262
+ We need to manually reduce the gradient norm across PP stages.
263
+ See https://github.com/pytorch/torchtitan/issues/596 for details.
264
+
265
+ Args:
266
+ parameters: an iterable of Tensors or a single Tensor that will have gradients normalized
267
+ max_norm (float): max norm of the gradients
268
+ norm_type (float): type of the used p-norm. Can be ``'inf'`` for
269
+ infinity norm.
270
+ error_if_nonfinite (bool): if True, an error is thrown if the total
271
+ norm of the gradients from :attr:`parameters` is ``nan``,
272
+ ``inf``, or ``-inf``. Default: False (will switch to True in the future)
273
+ foreach (bool): use the faster foreach-based implementation.
274
+ If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
275
+ fall back to the slow implementation for other device types.
276
+ Default: ``None``
277
+ pp_mesh: pipeline parallel device mesh. If not None, will reduce gradient norm across PP stages.
278
+
279
+ Returns:
280
+ Total norm of the parameter gradients (viewed as a single vector).
281
+
282
+ """
283
+ grads = [p.grad for p in parameters if p.grad is not None]
284
+ total_norm = torch.nn.utils.get_total_norm(
285
+ grads, norm_type, error_if_nonfinite, foreach
286
+ )
287
+
288
+ # If total_norm is a DTensor, the placements must be `torch.distributed._tensor.ops.math_ops._NormPartial`.
289
+ # We can simply reduce the DTensor to get the total norm in this tensor's process group
290
+ # and then convert it to a local tensor.
291
+ # NOTE: It has two purposes:
292
+ # 1. to make sure the total norm is computed correctly when PP is used (see below)
293
+ # 2. to return a reduced total_norm tensor whose .item() would return the correct value
294
+ if isinstance(total_norm, DTensor):
295
+ # Will reach here if any non-PP parallelism is used.
296
+ # If only using PP, total_norm will be a local tensor.
297
+
298
+ # Remove FT replicate dimension if it exists.
299
+ total_norm = ft_clip_grad_norm_util(total_norm)
300
+ total_norm = total_norm.full_tensor()
301
+
302
+ if pp_mesh is not None:
303
+ if math.isinf(norm_type):
304
+ dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=pp_mesh.get_group())
305
+ else:
306
+ total_norm **= norm_type
307
+ dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=pp_mesh.get_group())
308
+ total_norm **= 1.0 / norm_type
309
+
310
+ torch.nn.utils.clip_grads_with_norm_(parameters, max_norm, total_norm, foreach)
311
+ return total_norm
torchtitan/experiments/deepseek_v3/README.md ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Running DeepSeek in Titan (experimental)
2
+
3
+ This folder contains a DeepSeek model supporting v2 and v3 as well as kernels
4
+ and scripts needed to run it.
5
+
6
+ ## Inference
7
+
8
+ ### Prerequisites:
9
+
10
+ You will need to download a DeepSeek model's weights if you want to run a
11
+ pre-trained checkpoint. We provided a script to download the weights from
12
+ HuggingFace Model Hub:
13
+ ```bash
14
+ python download.py [vX]
15
+ ```
16
+ where `vX` can be v2 or v3, both are supported. You may be required to create a
17
+ HuggingFace account and log in first.
18
+
19
+ ### Running inference:
20
+
21
+ The inference script is in `generate.py`. You can run it with the following
22
+ command:
23
+ ```bash
24
+ torchrun --standalone --nproc-per-node 4 generate.py
25
+ ```
26
+ This will run inference on the `DeepSeek-V2-Lite-Chat` model using 4 GPUs by
27
+ default.
28
+
29
+ Alternatively, you can run inference by using `bash inference.sh`, optionally
30
+ followed by your prompt.
31
+
32
+ ## Training
33
+
34
+ The training script is in `train.py`. You can run it with the following command:
35
+ ```bash
36
+ torchrun --standalone --nproc-per-node 8 train.py
37
+ ```
38
+
39
+ This will run training on the `DeepSeek-V2-Lite-Chat` model using 8 GPUs by
40
+ default, with pipeline parallel, expert parallel, and data parallel enabled.
torchtitan/experiments/deepseek_v3/indices.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import triton
9
+ import triton.language as tl
10
+
11
+
12
+ __all__ = ["generate_permute_indices"]
13
+
14
+
15
+ @triton.jit
16
+ def fill_indices_kernel(
17
+ tokens_per_expert_group_ptr, # *Pointer* to first input vector.
18
+ start_index_values_ptr, # *Pointer* to second input vector.
19
+ write_offsets_ptr, # *Pointer* to third input vector.
20
+ output_ptr, # *Pointer* to output vector.
21
+ experts_per_rank, # Number of experts per rank.
22
+ num_ranks, # Number of expert ranks.
23
+ ):
24
+ # There are multiple 'programs' processing different data. We identify which program
25
+ # we are here:
26
+ pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0.
27
+ # The total number of programs in the launch grid.
28
+ num_programs = tl.num_programs(axis=0)
29
+ # We map the programs (blocks) to the experts.
30
+ for expert_id in tl.range(pid, experts_per_rank, step=num_programs):
31
+ # Read this expert's write offset.
32
+ write_offset = tl.load(write_offsets_ptr + expert_id)
33
+ # Loop over the ranks.
34
+ for r in tl.range(num_ranks):
35
+ # Slot in the tokens_per_expert_group array.
36
+ i = r * experts_per_rank + expert_id
37
+ start_index = tl.load(start_index_values_ptr + i)
38
+ length = tl.load(tokens_per_expert_group_ptr + i)
39
+ # Write the indices.
40
+ for l in tl.range(length):
41
+ val = start_index + l
42
+ tl.store(output_ptr + write_offset + l, val)
43
+ write_offset += length
44
+
45
+
46
+ def fill_indices(
47
+ tokens_per_expert_group: torch.Tensor,
48
+ start_index_values: torch.Tensor,
49
+ write_offsets: torch.Tensor,
50
+ experts_per_rank: int,
51
+ num_ranks: int,
52
+ max_len: int,
53
+ ):
54
+ # We need to preallocate the output.
55
+ permuted_indices = torch.full(
56
+ (max_len,), -1, dtype=torch.int32, device=tokens_per_expert_group.device
57
+ )
58
+ # Analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int].
59
+ # In this case, we use a 1D grid where the size is the number of blocks (TODO: bump this value).
60
+ grid = lambda meta: (1,)
61
+ # Each torch.tensor object is implicitly converted into a pointer to its first element.
62
+ fill_indices_kernel[grid](
63
+ tokens_per_expert_group,
64
+ start_index_values,
65
+ write_offsets,
66
+ permuted_indices,
67
+ experts_per_rank,
68
+ num_ranks,
69
+ )
70
+ return permuted_indices
71
+
72
+
73
+ def fill_indices_cpu(
74
+ tokens_per_expert_group: torch.Tensor,
75
+ start_index_values: torch.Tensor,
76
+ write_offsets: torch.Tensor,
77
+ experts_per_rank: int,
78
+ num_ranks: int,
79
+ max_len: int,
80
+ ):
81
+ # We need to preallocate the output.
82
+ permuted_indices = torch.full((max_len,), -1, dtype=torch.int32)
83
+ # Fill the permuted indices
84
+ # For each local expert
85
+ for e in range(experts_per_rank):
86
+ write_start = write_offsets[e]
87
+ # For each remote rank
88
+ for r in range(num_ranks):
89
+ i = r * experts_per_rank + e
90
+ start_index = start_index_values[i]
91
+ length = tokens_per_expert_group[i]
92
+ # Fill in the indices
93
+ permuted_indices[write_start : write_start + length] = torch.arange(
94
+ start_index, start_index + length
95
+ )
96
+ write_start += length
97
+ return permuted_indices
98
+
99
+
100
+ def generate_permute_indices(
101
+ tokens_per_expert_group: torch.Tensor,
102
+ experts_per_rank: int,
103
+ num_ranks: int,
104
+ max_len: int,
105
+ alignment: int,
106
+ use_cpu: bool = False,
107
+ ):
108
+ # Prepare permutation indices and the number of tokens for each expert. The
109
+ # permutation indices are the indices of the tokens for each expert. The
110
+ # number of tokens for each expert is the sum of the number of tokens for
111
+ # such experts from all ranks. This number is aligned to the provided
112
+ # alignment requirement (usually comes from group gemm).
113
+
114
+ # Args:
115
+ # tokens_per_expert_group: number of tokens for each expert from all ranks.
116
+ # experts_per_rank: number of experts per rank.
117
+ # num_ranks: number of ranks.
118
+ # max_len: maximum length of the output index vector. If greater than
119
+ # total number of tokens, the remaining indices are set to -1.
120
+ # alignment: alignment for each returned element in `m_sizes`.
121
+ # use_cpu: whether to use cpu or gpu.
122
+ # Returns:
123
+ # permuted_indices: permutation indices.
124
+ # m_sizes: number of tokens for each expert.
125
+
126
+ # `tokens_per_expert_group` is of shape (num_ranks * experts_per_rank,), for example:
127
+ # From: | rank 0 | rank 1 |
128
+ # To: | E0 | E1 | E2 | E3 | E0 | E1 | E2 | E3 |
129
+ # | 4 | 2 | 1 | 3 | 1 | 2 | 3 | 4 |
130
+
131
+ # Prefix sum to get the start index value of each expert
132
+ start_index_values = (
133
+ torch.cumsum(tokens_per_expert_group, 0) - tokens_per_expert_group
134
+ )
135
+ # Chunk sizes for each expert
136
+ chunk_size_per_expert = tokens_per_expert_group.view(num_ranks, -1).sum(0)
137
+ # Align the chunk sizes to the given alignment
138
+ m_sizes = ((chunk_size_per_expert + alignment - 1) // alignment * alignment).to(
139
+ torch.int32
140
+ )
141
+ # Perform another prefix sum to get the write offset of each expert in `permuted_indices`
142
+ write_offsets = torch.cumsum(m_sizes, 0) - m_sizes
143
+ # Select the method to fill the permuted indices
144
+ fill_fn = fill_indices_cpu if use_cpu else fill_indices
145
+ # Fill the permuted indices
146
+ permuted_indices = fill_fn(
147
+ tokens_per_expert_group,
148
+ start_index_values,
149
+ write_offsets,
150
+ experts_per_rank,
151
+ num_ranks,
152
+ max_len,
153
+ )
154
+ return permuted_indices, m_sizes
155
+
156
+
157
+ # Below is for testing only
158
+
159
+
160
+ def test():
161
+ device = torch.device("cuda", 0)
162
+ experts_per_rank = 4
163
+ num_ranks = 4
164
+ tokens_per_expert_group = torch.full(
165
+ (num_ranks * experts_per_rank,), 4, dtype=torch.int32, device=device
166
+ )
167
+ max_len = 128
168
+ alignment = 32
169
+ # Use the GPU kernel
170
+ permuted_indices_gpu, m_sizes = generate_permute_indices(
171
+ tokens_per_expert_group, experts_per_rank, num_ranks, max_len, alignment
172
+ )
173
+ # Use the CPU method
174
+ permuted_indices_cpu, _ = generate_permute_indices(
175
+ tokens_per_expert_group,
176
+ experts_per_rank,
177
+ num_ranks,
178
+ max_len,
179
+ alignment,
180
+ use_cpu=True,
181
+ )
182
+ # Check that the results are the same
183
+ assert torch.equal(permuted_indices_gpu.cpu(), permuted_indices_cpu)
184
+ assert torch.equal(
185
+ torch.remainder(m_sizes, alignment),
186
+ torch.zeros(experts_per_rank, device=device),
187
+ )
188
+ # Print the results
189
+ print(permuted_indices_gpu)
190
+ print(m_sizes)
191
+ print("Success")
192
+
193
+
194
+ if __name__ == "__main__":
195
+ test()
torchtitan/experiments/deepseek_v3/symm_mem_recipes/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from .triton_on_device_all_to_all_v import OnDeviceAllToAllV
8
+
9
+ __all__ = [
10
+ "OnDeviceAllToAllV",
11
+ ]
torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_on_device_all_to_all_v.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.distributed as dist
9
+ import torch.distributed._symmetric_memory as symm_mem
10
+ import triton
11
+ import triton.language as tl
12
+
13
+ from .triton_barrier import blockwise_barrier
14
+ from .triton_utils import sync_threads
15
+
16
+
17
+ @triton.jit
18
+ def _exchange_row_offsets(
19
+ split_sizes_ptrs,
20
+ rank: tl.constexpr,
21
+ world_size: tl.constexpr,
22
+ BLOCKS_PER_REMOTE_RANK: tl.constexpr,
23
+ ):
24
+ remote_rank = tl.program_id(0) // BLOCKS_PER_REMOTE_RANK
25
+
26
+ # split_sizes_ptr for all ranks
27
+ # All these vector stacks into split_sizes_matrix
28
+ split_sizes_ptrs = split_sizes_ptrs.to(tl.pointer_type(tl.uint64))
29
+
30
+ # split_sizes_matrix[remote_rank, :]
31
+ input_split_sizes_ptr = tl.load(split_sizes_ptrs + remote_rank).to(
32
+ tl.pointer_type(tl.int64)
33
+ )
34
+
35
+ offsets_ = tl.arange(0, world_size)
36
+ input_split_sizes = tl.load(
37
+ input_split_sizes_ptr + offsets_, mask=offsets_ <= rank, other=0
38
+ )
39
+
40
+ num_rows = tl.load(input_split_sizes_ptr + rank)
41
+ input_row_offset = tl.sum(input_split_sizes) - num_rows
42
+
43
+ # split_sizes_matrix[:, rank]
44
+ output_split_sizes_ptrs = (
45
+ tl.load(split_sizes_ptrs + offsets_).to(tl.pointer_type(tl.int64)) + rank
46
+ )
47
+ output_split_sizes = tl.load(
48
+ output_split_sizes_ptrs, mask=offsets_ <= remote_rank, other=0
49
+ )
50
+ output_row_offset = tl.sum(output_split_sizes) - num_rows
51
+
52
+ return input_row_offset, output_row_offset, num_rows
53
+
54
+
55
+ @triton.jit
56
+ def on_device_all_to_all_v_kernel(
57
+ output_ptr,
58
+ output_splits_ptr,
59
+ input_ptrs,
60
+ input_splits_ptr,
61
+ signal_pad_ptrs,
62
+ dim: tl.constexpr, # Separate dim for easier vectorization
63
+ rank: tl.constexpr,
64
+ world_size: tl.constexpr,
65
+ BLOCKS_PER_REMOTE_RANK: tl.constexpr,
66
+ UNROLL_FACTOR: tl.constexpr,
67
+ BLOCK_SIZE: tl.constexpr,
68
+ ):
69
+ blockwise_barrier(signal_pad_ptrs, None, rank, world_size, sem="relaxed")
70
+ sync_threads()
71
+
72
+ remote_rank = tl.program_id(0) // BLOCKS_PER_REMOTE_RANK
73
+ block_offset = tl.program_id(0) % BLOCKS_PER_REMOTE_RANK
74
+
75
+ input_row_offset, output_row_offset, num_rows = _exchange_row_offsets(
76
+ input_splits_ptr, rank, world_size, BLOCKS_PER_REMOTE_RANK
77
+ )
78
+
79
+ output_splits_ptr = output_splits_ptr.to(tl.pointer_type(tl.uint64))
80
+ if block_offset == 0:
81
+ # Update output_splits
82
+ tl.store(output_splits_ptr + remote_rank, num_rows)
83
+
84
+ input_ptr = (
85
+ tl.load(input_ptrs.to(tl.pointer_type(tl.uint64)) + remote_rank).to(
86
+ tl.pointer_type(tl.bfloat16)
87
+ )
88
+ + input_row_offset * dim
89
+ )
90
+ output_ptr = output_ptr + output_row_offset * dim
91
+
92
+ outer_loop_step = BLOCK_SIZE * UNROLL_FACTOR
93
+ outer_loop_iters_per_rank = tl.cdiv(
94
+ tl.cdiv(num_rows * dim, outer_loop_step), BLOCKS_PER_REMOTE_RANK
95
+ )
96
+ numel_per_rank = outer_loop_step * outer_loop_iters_per_rank
97
+ offset = numel_per_rank * block_offset
98
+ end = tl.minimum(numel_per_rank * (block_offset + 1), num_rows * dim)
99
+
100
+ unroll_region_size = (end - offset) // outer_loop_step * outer_loop_step
101
+ for i in tl.range(offset, offset + unroll_region_size, outer_loop_step):
102
+ datas = []
103
+ for j in tl.range(
104
+ i,
105
+ i + outer_loop_step,
106
+ BLOCK_SIZE,
107
+ loop_unroll_factor=UNROLL_FACTOR,
108
+ ):
109
+ offsets = j + tl.arange(0, BLOCK_SIZE)
110
+ data = tl.load(input_ptr + offsets)
111
+ tl.store(output_ptr + offsets, data)
112
+
113
+ offset += unroll_region_size
114
+ while offset < end:
115
+ offsets = offset + tl.arange(0, BLOCK_SIZE)
116
+ mask = offsets < num_rows * dim
117
+ data = tl.load(input_ptr + offsets, mask=mask)
118
+ tl.store(output_ptr + offsets, data, mask=mask)
119
+ offset += BLOCK_SIZE
120
+
121
+ sync_threads()
122
+ blockwise_barrier(signal_pad_ptrs, None, rank, world_size, sem="relaxed")
123
+ return
124
+
125
+
126
+ def _on_device_all_to_all_v(
127
+ output: torch.Tensor,
128
+ output_splits: torch.Tensor,
129
+ input: torch.Tensor,
130
+ input_splits: torch.Tensor,
131
+ group: dist.ProcessGroup = dist.group.WORLD,
132
+ BLOCKS_PER_REMOTE_RANK=8,
133
+ UNROLL_FACTOR: int = 8,
134
+ BLOCK_SIZE: int = 16384,
135
+ ):
136
+ assert output.dim() == 2, f"{output.shape}"
137
+ assert input.dim() == 2, f"{input.shape}"
138
+ assert output.shape[1] == input.shape[1]
139
+
140
+ dim = output.shape[1]
141
+ input_hdl = symm_mem.rendezvous(input, group=group)
142
+ input_splits_hdl = symm_mem.rendezvous(input_splits, group=group)
143
+
144
+ num_blocks = input_hdl.world_size * BLOCKS_PER_REMOTE_RANK
145
+ kernel = on_device_all_to_all_v_kernel[(num_blocks, 1, 1)](
146
+ output,
147
+ output_splits,
148
+ input_hdl.buffer_ptrs_dev,
149
+ input_splits_hdl.buffer_ptrs_dev,
150
+ input_hdl.signal_pad_ptrs_dev,
151
+ dim=dim,
152
+ rank=input_hdl.rank,
153
+ world_size=input_hdl.world_size,
154
+ BLOCKS_PER_REMOTE_RANK=BLOCKS_PER_REMOTE_RANK,
155
+ UNROLL_FACTOR=UNROLL_FACTOR,
156
+ BLOCK_SIZE=BLOCK_SIZE,
157
+ num_warps=16,
158
+ )
159
+ # log_triton_kernel(kernel)
160
+ return output
161
+
162
+
163
+ class OnDeviceAllToAllV(torch.autograd.Function):
164
+ # A symmetric memory holding the grad_output during backward
165
+ grad_output_buf = None
166
+ # A symmetric memory for exchanges split sizes during both forward and backward
167
+ splits_buf = None
168
+ # Maximum output length (need to be set before use of OnDeviceAllToAllV)
169
+ max_output_len = None
170
+
171
+ @staticmethod
172
+ def forward(
173
+ ctx,
174
+ input: torch.Tensor,
175
+ input_splits: torch.Tensor,
176
+ group: dist.ProcessGroup = dist.group.WORLD,
177
+ ):
178
+ """
179
+ Args:
180
+ input: input tensor with data for all ranks concatenated.
181
+ input_splits: input splits of shape (group.world_size,)
182
+ group: process group to scope the collective.
183
+ """
184
+ # Initialize input splits buffer (one time only)
185
+ if OnDeviceAllToAllV.splits_buf is None:
186
+ OnDeviceAllToAllV.splits_buf = symm_mem.empty(
187
+ *input_splits.shape,
188
+ dtype=input_splits.dtype,
189
+ device=input_splits.device,
190
+ )
191
+
192
+ if OnDeviceAllToAllV.max_output_len is None:
193
+ raise RuntimeError(
194
+ "Please set max output length via `OnDeviceAllToAllV.max_output_len = ...`"
195
+ )
196
+
197
+ # Allocate output buffer
198
+ output = input.new_empty(OnDeviceAllToAllV.max_output_len, *input.shape[1:])
199
+ # Allocate output splits tensor
200
+ output_splits = torch.empty_like(input_splits)
201
+ # Copy input splits to the buffer
202
+ OnDeviceAllToAllV.splits_buf.copy_(input_splits)
203
+
204
+ # Shuffle input to output
205
+ _on_device_all_to_all_v(
206
+ output, output_splits, input, OnDeviceAllToAllV.splits_buf, group=group
207
+ )
208
+
209
+ # Output splits in forward is the input splits in backward
210
+ ctx.save_for_backward(output_splits)
211
+ ctx.group = group
212
+ ctx.input_shape = input.shape
213
+ return output, output_splits
214
+
215
+ @staticmethod
216
+ def backward(ctx, grad_output, grad_splits):
217
+ """
218
+ Backward is implemented as a shuffle of the output's gradients to the input.
219
+ Args:
220
+ `grad_output`: output's gradients passed from the downstream.
221
+ `grad_splits`: unused.
222
+ """
223
+
224
+ # Initialize grad_output buffer (one time only)
225
+ if OnDeviceAllToAllV.grad_output_buf is None:
226
+ assert (
227
+ OnDeviceAllToAllV.max_output_len is not None
228
+ ), "`max_output_len` not set"
229
+ OnDeviceAllToAllV.grad_output_buf = symm_mem.empty(
230
+ OnDeviceAllToAllV.max_output_len,
231
+ *grad_output.shape[1:],
232
+ dtype=grad_output.dtype,
233
+ device=grad_output.device,
234
+ )
235
+
236
+ # TODO: is there a way to tell autograd to feed grad_output directly to
237
+ # our symm_mem buffer?
238
+ OnDeviceAllToAllV.grad_output_buf.narrow(0, 0, grad_output.shape[0]).copy_(
239
+ grad_output
240
+ )
241
+
242
+ # Size info
243
+ (grad_output_splits,) = ctx.saved_tensors
244
+ OnDeviceAllToAllV.splits_buf.copy_(grad_output_splits)
245
+ grad_input_splits = torch.empty_like(grad_output_splits) # unused
246
+ grad_input = grad_output.new_empty(*ctx.input_shape)
247
+
248
+ # Shuffle gradients back to the input
249
+ _on_device_all_to_all_v(
250
+ grad_input,
251
+ grad_input_splits,
252
+ OnDeviceAllToAllV.grad_output_buf,
253
+ OnDeviceAllToAllV.splits_buf,
254
+ group=ctx.group,
255
+ )
256
+ return grad_input, None, None
257
+
258
+
259
+ # Alias
260
+ on_device_all_to_all_v = OnDeviceAllToAllV.apply
torchtitan/experiments/deepseek_v3/train.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # torchrun --standalone --nproc-per-node 8 run.py
8
+ import torch
9
+ import torch.distributed as dist
10
+ from checkpoint import load_weights_from_hf
11
+ from model import DeepseekForCausalLM
12
+ from model_config import deepseek_config_registry
13
+
14
+ from torch.distributed.device_mesh import DeviceMesh
15
+ from torch.distributed.fsdp import fully_shard
16
+ from torch.distributed.pipelining import PipelineStage, Schedule1F1B
17
+
18
+
19
+ # Use DeepSeek-V2-Lite as a proxy
20
+ model_id = "deepseek-ai/DeepSeek-V2-Lite"
21
+
22
+
23
+ # Run full model
24
+ def run_full_model(
25
+ mesh: DeviceMesh,
26
+ ):
27
+ rank = dist.get_rank()
28
+ device_count = torch.cuda.device_count()
29
+ device = torch.device("cuda", rank % device_count)
30
+
31
+ pp_mesh = mesh["pp"]
32
+ ep_mesh = mesh["ep"]
33
+ pp_rank = pp_mesh.get_local_rank()
34
+ ep_rank = ep_mesh.get_local_rank()
35
+ pp_size = pp_mesh.size()
36
+ ep_size = ep_mesh.size()
37
+
38
+ # Get model configs
39
+ model_args = deepseek_config_registry[model_id]
40
+ # [Note]: I am making the model smaller for testing / avoiding OOM. If you
41
+ # have sufficient GPUs for model parallelism, you can remove this line.
42
+ model_args.num_hidden_layers = 16
43
+
44
+ # Apply model parallelism
45
+ model_args.ep_size = ep_size
46
+ model_args.num_stages = pp_size
47
+ model_args.stage_idx = pp_rank
48
+ print(model_args)
49
+
50
+ # Instantiate model
51
+ with device, mesh:
52
+ model = DeepseekForCausalLM(model_args)
53
+
54
+ # Load weights
55
+ load_weights_from_hf(model, model_id, device)
56
+ model.train()
57
+
58
+ # Apply data parallelism
59
+ fsdp_mesh = mesh["fsdp"]
60
+ hsdp_mesh = mesh["ep", "fsdp"]
61
+ # Using `reshard_after_forward=False` to implement Zero-2, i.e. sharding the
62
+ # optimizer (Zero-1) and gradients (Zero-2), but not the model weights.
63
+ # Reason: the MoE is "sparsely activated" compared to the dense model, thus
64
+ # it will be ineconomical re-gather the weights.
65
+ for layer in model.model.layers.values():
66
+ # Apply FSDP to experts
67
+ if hasattr(layer.mlp, "experts"):
68
+ for expert in layer.mlp.experts.values():
69
+ fully_shard(expert, mesh=fsdp_mesh, reshard_after_forward=False)
70
+ # Apply HSDP to other parts such as attention, layernorm, because they
71
+ # are doing DDP on EP dimension
72
+ fully_shard(layer, mesh=hsdp_mesh, reshard_after_forward=False)
73
+
74
+ # Apply HSDP on root model (lm_head, embeddings, etc)
75
+ fully_shard(model, mesh=hsdp_mesh, reshard_after_forward=False)
76
+
77
+ # Synthetic setting
78
+ microbatches = pp_size * 2
79
+
80
+ # Use Symmetric Memory for MoE token shuffle.
81
+ # TODO: we are rewriting `moe_on_device` function. `setup_symm_mem` is
82
+ # currently supported for forward only. See `generate.py`.
83
+ # model.setup_symm_mem(torch.bfloat16, device)
84
+
85
+ # Example inputs
86
+ torch.manual_seed(ep_rank)
87
+ bs = 4
88
+ seqlen = 128
89
+ x = torch.randint(model_args.vocab_size, (microbatches * bs, seqlen), device=device)
90
+ label = torch.rand(microbatches * bs, seqlen, model_args.vocab_size, device=device)
91
+
92
+ # Create loss function
93
+ loss_fn = torch.nn.functional.cross_entropy
94
+
95
+ # Run forward and backward
96
+ steps = 2
97
+ for _ in range(steps):
98
+ if pp_size > 1:
99
+ # Create pipeline stage
100
+ stage = PipelineStage(
101
+ model,
102
+ pp_rank,
103
+ pp_size,
104
+ device,
105
+ group=pp_mesh.get_group(),
106
+ )
107
+
108
+ # Create pipeline schedule
109
+ losses = []
110
+ pp_schedule = Schedule1F1B(stage, microbatches, loss_fn=loss_fn)
111
+
112
+ if pp_rank == 0:
113
+ y = pp_schedule.step(x)
114
+ elif pp_rank == pp_size - 1:
115
+ y = pp_schedule.step(target=label, losses=losses)
116
+ loss = torch.mean(torch.stack(losses))
117
+ else:
118
+ pp_schedule.step()
119
+ else:
120
+ y = model(x)
121
+ loss = loss_fn(y, label)
122
+ loss.backward()
123
+
124
+ if pp_rank == pp_size - 1:
125
+ print(f"logits: {y.shape}")
126
+ print(f"{loss=}")
127
+
128
+ if pp_rank == 0:
129
+ param = model.get_parameter("model.layers.0.self_attn.q_proj.weight")
130
+ print(f"{torch.linalg.norm(param.grad)=}")
131
+
132
+ model.zero_grad()
133
+
134
+ print("Backward done")
135
+
136
+
137
+ if __name__ == "__main__":
138
+ mesh = dist.init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("pp", "ep", "fsdp"))
139
+
140
+ run_full_model(mesh)
141
+
142
+ dist.destroy_process_group()
torchtitan/experiments/flux/README.md ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # FLUX model in torchtitan
2
+
3
+ ## Overview
4
+
5
+ ## Usage
6
+ First, download the autoencoder model from HuggingFace with your own access token:
7
+ ```bash
8
+ python torchtitan/experiments/flux/scripts/download_autoencoder.py --repo_id black-forest-labs/FLUX.1-dev --ae_path ae.safetensors --hf_token <your_access_token>
9
+ ```
10
+ This step will download the autoencoder model from HuggingFace and save it to the `torchtitan/experiments/flux/assets/autoencoder/ae.safetensors` file.
11
+
12
+ Run the following command to train the model on a single GPU:
13
+ ```bash
14
+ PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True torchrun --nproc_per_node=1 torchtitan/experiments/flux/train.py --job.config_file torchtitan/experiments/flux/train_configs/debug_model.toml
15
+ ```
16
+
17
+ ## TODO
18
+ - [ ] Support for multiple GPUs is coming soon (FSDP, etc)
19
+ - [ ] Implement test cases in CI for FLUX model. Adding more unit tests for FLUX model (eg, unit test for preprocessor, etc)
20
+ - [ ] More parallelism support (Tensor Parallelism, Context Parallelism, etc)
21
+ - [ ] Support for distributed checkpointing and loading
22
+ - [ ] Implement init_weights() function to initialize the model weights
23
+ - [ ] Implement the num_flops_per_token calculation in get_nparams_and_flops() function
torchtitan/experiments/flux/dataset/flux_dataset.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ import random
9
+ from dataclasses import dataclass
10
+ from typing import Any, Callable, Optional
11
+
12
+ import numpy as np
13
+
14
+ import torch
15
+
16
+ from datasets import Dataset, load_dataset
17
+ from datasets.distributed import split_dataset_by_node
18
+ from PIL import Image
19
+
20
+ from torch.distributed.checkpoint.stateful import Stateful
21
+
22
+ from torch.utils.data import IterableDataset
23
+ from torchtitan.components.dataloader import ParallelAwareDataloader
24
+
25
+ from torchtitan.config_manager import JobConfig
26
+ from torchtitan.experiments.flux.dataset.tokenizer import FluxTokenizer
27
+ from torchtitan.tools.logging import logger
28
+
29
+
30
+ def _process_cc12m_image(
31
+ img: Image.Image,
32
+ output_size: int = 256,
33
+ ) -> Optional[torch.Tensor]:
34
+ """Process CC12M image to the desired size."""
35
+
36
+ width, height = img.size
37
+ # Skip low resolution images
38
+ if width < output_size or height < output_size:
39
+ return None
40
+
41
+ if width >= height:
42
+ # resize height to be equal to output_size, then crop
43
+ new_width, new_height = math.ceil(output_size / height * width), output_size
44
+ img = img.resize((new_width, new_height))
45
+ left = random.randint(0, new_width - output_size)
46
+ resized_img = img.crop((left, 0, left + output_size, output_size))
47
+ else:
48
+ # resize width to be equal to output_size, the crop
49
+ new_width, new_height = (
50
+ output_size,
51
+ math.ceil(output_size / width * height),
52
+ )
53
+ img = img.resize((new_width, new_height))
54
+ lower = random.randint(0, new_width - output_size)
55
+ resized_img = img.crop((0, lower, output_size, lower + output_size))
56
+
57
+ assert resized_img.size[0] == resized_img.size[1] == output_size
58
+
59
+ # Skip grayscale images
60
+ if resized_img.mode == "L":
61
+ return None
62
+
63
+ np_img = np.array(resized_img).transpose((2, 0, 1))
64
+ tensor_img = torch.tensor(np_img).float() / 255.0
65
+
66
+ # NOTE: The following commented code is an alternative way
67
+ # img_transform = transforms.Compose(
68
+ # [
69
+ # transforms.Resize(max(output_size, output_size)),
70
+ # transforms.CenterCrop((output_size, output_size)),
71
+ # transforms.ToTensor(),
72
+ # ]
73
+ # )
74
+ # tensor_img = img_transform(img)
75
+
76
+ return tensor_img
77
+
78
+
79
+ def _flux_data_processor(
80
+ sample: dict[str, Any],
81
+ t5_tokenizer: FluxTokenizer,
82
+ clip_tokenizer: FluxTokenizer,
83
+ output_size: int = 256,
84
+ ) -> dict[str, Any]:
85
+ """
86
+ Preprocess CC12M dataset sample image and text for Flux model.
87
+
88
+ Args:
89
+ sample: A sample from dataset
90
+ t5_encoder: T5 encoder
91
+ clip_encoder: CLIP encoder
92
+ output_size: The output image size
93
+
94
+ """
95
+ img = _process_cc12m_image(sample["jpg"], output_size=output_size)
96
+ t5_tokens = t5_tokenizer.encode(sample["txt"])
97
+ clip_tokens = clip_tokenizer.encode(sample["txt"])
98
+
99
+ return {
100
+ "image": img,
101
+ "clip_tokens": clip_tokens, # type: List[int]
102
+ "t5_tokens": t5_tokens, # type: List[int]
103
+ }
104
+
105
+
106
+ @dataclass
107
+ class TextToImageDatasetConfig:
108
+ path: str
109
+ loader: Callable
110
+ data_processor: Callable
111
+
112
+
113
+ DATASETS = {
114
+ "cc12m": TextToImageDatasetConfig(
115
+ path="pixparse/cc12m-wds",
116
+ loader=lambda path: load_dataset(path, split="train", streaming=True),
117
+ data_processor=_flux_data_processor,
118
+ ),
119
+ }
120
+
121
+
122
+ def _validate_dataset(
123
+ dataset_name: str, dataset_path: Optional[str] = None
124
+ ) -> tuple[str, Callable, Callable]:
125
+ """Validate dataset name and path."""
126
+ if dataset_name not in DATASETS:
127
+ raise ValueError(
128
+ f"Dataset {dataset_name} is not supported. "
129
+ f"Supported datasets are: {list(DATASETS.keys())}"
130
+ )
131
+
132
+ config = DATASETS[dataset_name]
133
+ path = dataset_path or config.path
134
+ logger.info(f"Preparing {dataset_name} dataset from {path}")
135
+ return path, config.loader, config.data_processor
136
+
137
+
138
+ class FluxDataset(IterableDataset, Stateful):
139
+ """Dataset for FLUX text-to-image model.
140
+
141
+ Args:
142
+ dataset_name (str): Name of the dataset.
143
+ dataset_path (str): Path to the dataset.
144
+ model_transform (Transform): Callable that applies model-specific preprocessing to the sample.
145
+ dp_rank (int): Data parallel rank.
146
+ dp_world_size (int): Data parallel world size.
147
+ infinite (bool): Whether to loop over the dataset infinitely.
148
+ """
149
+
150
+ def __init__(
151
+ self,
152
+ dataset_name: str,
153
+ dataset_path: Optional[str],
154
+ t5_tokenizer: FluxTokenizer,
155
+ clip_tokenizer: FluxTokenizer,
156
+ job_config: Optional[JobConfig] = None,
157
+ dp_rank: int = 0,
158
+ dp_world_size: int = 1,
159
+ infinite: bool = False,
160
+ ) -> None:
161
+
162
+ # Force lowercase for consistent comparison
163
+ dataset_name = dataset_name.lower()
164
+
165
+ path, dataset_loader, data_processor = _validate_dataset(
166
+ dataset_name, dataset_path
167
+ )
168
+ ds = dataset_loader(path)
169
+
170
+ self.dataset_name = dataset_name
171
+ self._data = split_dataset_by_node(ds, dp_rank, dp_world_size)
172
+
173
+ self._t5_tokenizer = t5_tokenizer
174
+ self._clip_tokenizer = clip_tokenizer
175
+ self._data_processor = data_processor
176
+ self.job_config = job_config
177
+
178
+ self.infinite = infinite
179
+
180
+ # Variables for checkpointing
181
+ self._sample_idx = 0
182
+ self._all_samples: list[dict[str, Any]] = []
183
+
184
+ def _get_data_iter(self):
185
+ if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
186
+ return iter([])
187
+
188
+ it = iter(self._data)
189
+ for _ in range(self._sample_idx):
190
+ next(it)
191
+ return it
192
+
193
+ def __iter__(self):
194
+ while True:
195
+ for sample in self._get_data_iter():
196
+ # Use the dataset-specific preprocessor
197
+ sample_dict = self._data_processor(
198
+ sample, self._t5_tokenizer, self._clip_tokenizer, output_size=256
199
+ )
200
+
201
+ # skip low quality image or image with color channel = 1
202
+ if sample_dict["image"] is None:
203
+ logger.warning(
204
+ f"Low quality image {sample['__key__']} is skipped in Flux Dataloader"
205
+ )
206
+ continue
207
+
208
+ self._all_samples.extend(sample_dict)
209
+ self._sample_idx += 1
210
+
211
+ labels = sample_dict.pop("image")
212
+ yield sample_dict, labels
213
+
214
+ if not self.infinite:
215
+ logger.warning(f"Dataset {self.dataset_name} has run out of data")
216
+ break
217
+ else:
218
+ # Reset offset for the next iteration
219
+ self._sample_idx = 0
220
+ logger.warning(f"Dataset {self.dataset_name} is being re-looped")
221
+
222
+ def load_state_dict(self, state_dict):
223
+ self._sample_idx = state_dict["sample_idx"]
224
+ self._all_samples = state_dict["all_samples"]
225
+
226
+ def state_dict(self):
227
+ return {
228
+ "all_samples": self._all_samples,
229
+ "sample_idx": self._sample_idx,
230
+ }
231
+
232
+
233
+ def build_flux_dataloader(
234
+ dp_world_size: int,
235
+ dp_rank: int,
236
+ job_config: JobConfig,
237
+ # This parameter is not used, keep it for compatibility
238
+ tokenizer: FluxTokenizer | None,
239
+ infinite: bool = True,
240
+ ) -> ParallelAwareDataloader:
241
+ """Build a data loader for HuggingFace datasets."""
242
+ dataset_name = job_config.training.dataset
243
+ dataset_path = job_config.training.dataset_path
244
+ batch_size = job_config.training.batch_size
245
+
246
+ t5_encoder_name = job_config.encoder.t5_encoder
247
+ clip_encoder_name = job_config.encoder.clip_encoder
248
+ max_t5_encoding_len = job_config.encoder.max_t5_encoding_len
249
+
250
+ ds = FluxDataset(
251
+ dataset_name=dataset_name,
252
+ dataset_path=dataset_path,
253
+ t5_tokenizer=FluxTokenizer(t5_encoder_name, max_length=max_t5_encoding_len),
254
+ clip_tokenizer=FluxTokenizer(
255
+ clip_encoder_name, max_length=77
256
+ ), # fix max_length for CLIP
257
+ dp_rank=dp_rank,
258
+ dp_world_size=dp_world_size,
259
+ infinite=infinite,
260
+ )
261
+
262
+ return ParallelAwareDataloader(
263
+ dataset=ds,
264
+ dp_rank=dp_rank,
265
+ dp_world_size=dp_world_size,
266
+ batch_size=batch_size,
267
+ )
torchtitan/experiments/flux/dataset/tokenizer.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
8
+ # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
9
+
10
+
11
+ from typing import List
12
+
13
+ from torchtitan.components.tokenizer import Tokenizer
14
+ from transformers import CLIPTokenizer, T5Tokenizer
15
+
16
+
17
class FluxTokenizer(Tokenizer):
    """
    Tokenizing and encoding/decoding text using the T5 or CLIP tokenizer.

    The backend is chosen from the model path: paths starting with "openai"
    load a ``CLIPTokenizer``; anything else loads a ``T5Tokenizer``.

    Args:
        model_path (str): Path to the tokenizer on Hugging Face.
        max_length (int): Fixed sequence length used for truncation/padding.
    """

    def __init__(self, model_path: str = "t5-small", max_length: int = 77):
        super().__init__()
        self._n_words = 8  # TODO(jianiw): check
        self._max_length = max_length

        # Heuristic: OpenAI-hosted paths (e.g. "openai/clip-vit-large-patch14")
        # are CLIP; everything else is treated as T5.
        self.is_clip = model_path.startswith("openai")

        if self.is_clip:
            self._tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(
                model_path, max_length=max_length
            )
        else:
            self._tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(
                model_path, max_length=max_length
            )

    def encode(
        self,
        s: str,
    ) -> List[int]:
        """
        Encode the prompt text into tokens.

        NOTE(review): with ``return_tensors="pt"`` below this actually returns
        a 2-D torch tensor of token ids, not a ``List[int]`` — confirm callers
        before changing the annotation. Output is always truncated/padded to
        ``self._max_length``.
        """
        tokens = self._tokenizer(
            s,
            truncation=True,
            max_length=self._max_length,
            return_length=False,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",  # return pytorch tensors, default return List[int]
        )["input_ids"]
        return tokens

    def decode(self, t: List[int]) -> str:
        """
        Decode token ids back into text. This function will not be called.
        """
        return self._tokenizer.decode(t)
torchtitan/experiments/flux/model/hf_embedder.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from torch import nn, Tensor
8
+ from transformers import CLIPTextModel, T5EncoderModel
9
+
10
+
11
class FluxEmbedder(nn.Module):
    """Frozen Hugging Face text encoder (CLIP or T5) used for conditioning.

    Versions starting with "openai" load a ``CLIPTextModel`` (pooled output);
    anything else loads a ``T5EncoderModel`` (per-token hidden states).
    """

    def __init__(self, version: str, **hf_kwargs):
        super().__init__()
        self.is_clip = version.startswith("openai")
        # CLIP yields one pooled vector per prompt; T5 yields the full
        # per-token hidden-state sequence.
        self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"

        if self.is_clip:
            self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(
                version, **hf_kwargs
            )
        else:
            self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(
                version, **hf_kwargs
            )

        # Encoder is inference-only: put it in eval mode and freeze all params.
        self.hf_module = self.hf_module.eval().requires_grad_(False)

    def forward(self, batch_tokens: Tensor) -> Tensor:
        """
        batch_tokens: [bsz, embedding_length]

        For T5 Encoder, embedding_length is 768
        For CLIP, embedding_length is 256
        """
        outputs = self.hf_module(
            input_ids=batch_tokens.to(self.hf_module.device),
            attention_mask=None,
            output_hidden_states=False,
        )
        return outputs[self.output_key]
torchtitan/experiments/flux/model/layers.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # imported from black-forest-labs/FLUX
8
+ import math
9
+ from dataclasses import dataclass
10
+
11
+ import torch
12
+ from einops import rearrange
13
+ from torch import nn, Tensor
14
+
15
+ from torchtitan.experiments.flux.model.math import attention, rope
16
+
17
+
18
class EmbedND(nn.Module):
    """Multi-axis rotary position embedding.

    Applies ``rope`` independently to each positional axis and concatenates
    the resulting per-axis frequency tensors.
    """

    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
        super().__init__()
        self.dim = dim  # total dim; not referenced in forward, kept for callers
        self.theta = theta  # rope base frequency parameter
        self.axes_dim = axes_dim  # per-axis rotary widths

    def forward(self, ids: Tensor) -> Tensor:
        # ids[..., i] holds the positions along axis i; each axis gets its
        # own rotary table of width axes_dim[i].
        n_axes = ids.shape[-1]
        emb = torch.cat(
            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
            dim=-3,
        )

        # Insert a broadcast (head) axis after batch.
        return emb.unsqueeze(1)
33
+
34
+
35
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
    """
    Create sinusoidal timestep embeddings.
    :param t: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param time_factor: multiplier applied to ``t`` before embedding.
    :return: an (N, D) Tensor of positional embeddings.
    """
    scaled = time_factor * t
    half = dim // 2

    # Geometric frequency ladder from 1 down to 1/max_period.
    exponents = torch.arange(start=0, end=half, dtype=torch.float32) / half
    freqs = torch.exp(-math.log(max_period) * exponents).to(scaled.device)

    angles = scaled[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)

    # Odd output dims get a single zero column of padding.
    if dim % 2:
        pad = torch.zeros_like(embedding[:, :1])
        embedding = torch.cat([embedding, pad], dim=-1)

    # Match the caller's floating dtype (e.g. fp16/bf16 inputs).
    if torch.is_floating_point(t):
        embedding = embedding.to(t)
    return embedding
59
+
60
+
61
class MLPEmbedder(nn.Module):
    """Two-layer MLP (Linear -> SiLU -> Linear) mapping ``in_dim`` to
    ``hidden_dim``; the second layer keeps the hidden width."""

    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.in_layer(x)
        activated = self.silu(hidden)
        return self.out_layer(activated)
70
+
71
+
72
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalization with a learnable per-channel scale.

    Normalization statistics are computed in float32 for stability, then the
    result is cast back to the input dtype before scaling.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor):
        orig_dtype = x.dtype
        x32 = x.float()
        # 1 / RMS over the last (channel) axis; eps guards against zeros.
        inv_rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + 1e-6)
        return (x32 * inv_rms).to(dtype=orig_dtype) * self.scale
82
+
83
+
84
class QKNorm(torch.nn.Module):
    """RMS-normalize attention queries and keys, then cast both to v's dtype."""

    def __init__(self, dim: int):
        super().__init__()
        self.query_norm = RMSNorm(dim)  # TODO(jianiw): switch to pytorch nn.RMSNorm
        self.key_norm = RMSNorm(dim)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
        # v itself is untouched; it only supplies the target dtype/device.
        normed_q = self.query_norm(q).to(v)
        normed_k = self.key_norm(k).to(v)
        return normed_q, normed_k
94
+
95
+
96
class SelfAttention(nn.Module):
    """Multi-head self-attention with RMS-normalized Q/K and a fused QKV
    projection; position information arrives via ``pe`` (see ``attention``)."""

    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads

        # Single matmul producing Q, K and V concatenated along the last axis.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.norm = QKNorm(head_dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: Tensor, pe: Tensor) -> Tensor:
        # x: (B, L, dim); pe is the position embedding forwarded to `attention`.
        qkv = self.qkv(x)
        # Unpack the fused projection into per-head tensors: (3, B, H, L, D).
        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)
        x = attention(q, k, v, pe=pe)
        x = self.proj(x)
        return x
113
+
114
+
115
+ @dataclass
116
+ class ModulationOut:
117
+ shift: Tensor
118
+ scale: Tensor
119
+ gate: Tensor
120
+
121
+
122
+ class Modulation(nn.Module):
123
+ def __init__(self, dim: int, double: bool):
124
+ super().__init__()
125
+ self.is_double = double
126
+ self.multiplier = 6 if double else 3
127
+ self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
128
+
129
+ def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
130
+ out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(
131
+ self.multiplier, dim=-1
132
+ )
133
+
134
+ return (
135
+ ModulationOut(*out[:3]),
136
+ ModulationOut(*out[3:]) if self.is_double else None,
137
+ )
138
+
139
+
140
class DoubleStreamBlock(nn.Module):
    """Dual-stream transformer block over image and text tokens.

    Each stream has its own modulation, attention projections and MLP, but the
    two streams attend jointly: their Q/K/V are concatenated along the sequence
    axis (text first) before a single ``attention`` call, then split back.
    """

    def __init__(
        self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False
    ):
        super().__init__()

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        # --- image-stream parameters ---
        self.img_mod = Modulation(hidden_size, double=True)
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_attn = SelfAttention(
            dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
        )

        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

        # --- text-stream parameters (mirror of the image stream) ---
        self.txt_mod = Modulation(hidden_size, double=True)
        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_attn = SelfAttention(
            dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
        )

        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

    def forward(
        self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor
    ) -> tuple[Tensor, Tensor]:
        # Each stream gets two modulation triples: mod1 for the attention
        # branch, mod2 for the MLP branch.
        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)

        # prepare image for attention
        img_modulated = self.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = self.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = rearrange(
            img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
        )
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(
            txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
        )
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention over the joint sequence (text tokens first)
        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)

        attn = attention(q, k, v, pe=pe)
        # Split the joint result back into text and image portions.
        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

        # calculate the img blocks (gated residual updates)
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        img = img + img_mod2.gate * self.img_mlp(
            (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
        )

        # calculate the txt blocks (gated residual updates)
        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt = txt + txt_mod2.gate * self.txt_mlp(
            (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
        )
        return img, txt
219
+
220
+
221
class SingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.

    A single fused linear produces both the QKV projection and the MLP input;
    a second fused linear consumes the concatenated attention output and
    activated MLP stream.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float | None = None,
    ):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        # NOTE(review): self.scale is computed but not referenced in forward
        # here — presumably consumed inside `attention`; confirm.
        self.scale = qk_scale or head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
        # proj and mlp_out
        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)

        self.norm = QKNorm(head_dim)

        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False)

    def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
        # Single (shift, scale, gate) triple; the second slot is always None.
        mod, _ = self.modulation(vec)
        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
        # One matmul feeds both the attention (qkv) and MLP (mlp) branches.
        qkv, mlp = torch.split(
            self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1
        )

        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)

        # compute attention
        attn = attention(q, k, v, pe=pe)
        # compute activation in mlp stream, cat again and run second linear layer
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        # Gated residual connection.
        return x + mod.gate * output
269
+
270
+
271
class LastLayer(nn.Module):
    """Final adaLN-modulated projection from hidden states to patch outputs
    of size ``patch_size * patch_size * out_channels``."""

    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(
            hidden_size, patch_size * patch_size * out_channels, bias=True
        )
        # Maps the conditioning vector to a (shift, scale) pair.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
        # [:, None, :] broadcasts the per-sample modulation over the sequence.
        modulated = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
        return self.linear(modulated)
torchtitan/experiments/flux/tests/test_flux_dataloader.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import sys
8
+
9
+ from torchtitan.config_manager import JobConfig
10
+ from torchtitan.experiments.flux.dataset.flux_dataset import build_flux_dataloader
11
+ from torchtitan.tools.profiling import (
12
+ maybe_enable_memory_snapshot,
13
+ maybe_enable_profiling,
14
+ )
15
+
16
+
17
class TestFluxDataLoader:
    """Smoke tests for the Flux dataloader (optionally run under profilers)."""

    def test_flux_dataloader(self):
        # Fixed test configuration.
        dataset_name = "cc12m"
        batch_size = 32
        world_size = 4
        rank = 0

        num_steps = 10

        # Register flux-specific CLI args before building the job config.
        path = "torchtitan.experiments.flux.flux_argparser"
        sys.argv.append(f"--experimental.custom_args_module={path}")
        config = JobConfig()
        config.maybe_add_custom_args()
        config.parse_args(
            [
                # Profiling options
                # "--profiling.enable_profiling",
                # "--profiling.profile_freq",
                # "5",
                # "--profiling.enable_memory_snapshot",
                # "--profiling.save_memory_snapshot_folder",
                # "memory_snapshot_flux",
                "--training.dataset",
                dataset_name,
                "--training.batch_size",
                str(batch_size),
                "--encoder.t5_encoder",
                "google/t5-v1_1-small",
                "--encoder.clip_encoder",
                "openai/clip-vit-large-patch14",
                "--encoder.max_t5_encoding_len",
                "512",
            ]
        )

        # Both profilers are no-ops unless enabled via the flags above.
        with maybe_enable_profiling(
            config, global_step=0
        ) as torch_profiler, maybe_enable_memory_snapshot(
            config, global_step=0
        ) as memory_profiler:
            dl = self._build_dataloader(
                config,
                world_size,
                rank,
            )
            dl = iter(dl)

            for i in range(0, num_steps):
                input_data, labels = next(dl)
                print(f"Step {i} image size: {labels.shape}")
                if torch_profiler:
                    torch_profiler.step()
                if memory_profiler:
                    memory_profiler.step()

                print(len(input_data["clip_tokens"]))
                for k, v in input_data.items():
                    print(f"Step {i} {k} value: {type(v), v.shape}")

                assert len(input_data) == 2  # (clip_encodings, t5_encodings)
                assert labels.shape == (batch_size, 3, 256, 256)
                # assert input_data["clip_tokens"].shape[0] == batch_size
                # assert input_data["t5_tokens"].shape == (batch_size, 512, 512)

            # Final profiler ticks after the sampling loop.
            if torch_profiler:
                torch_profiler.step()
            if memory_profiler:
                memory_profiler.step(exit_ctx=True)

    def test_preprocess(self):
        # TODO
        pass

    def _build_dataloader(
        self,
        job_config,
        world_size,
        rank,
    ):
        # The tokenizer arg is unused by build_flux_dataloader; pass None.
        return build_flux_dataloader(
            dp_world_size=world_size,
            dp_rank=rank,
            job_config=job_config,
            tokenizer=None,
            infinite=False,
        )
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/fast_debug_ao.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+ import logging
9
+
10
+ import numpy as np
11
+ import torch
12
+
13
+ from reference_utils import (
14
+ analyze_tensor_differences,
15
+ compute_reference_backward,
16
+ compute_reference_forward,
17
+ )
18
+
19
+ # Configure logging
20
+ logging.basicConfig(
21
+ level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
22
+ )
23
+
24
+ # Import grouped GEMM implementations
25
+ try:
26
+ from mg_grouped_gemm import grouped_gemm_backward, grouped_gemm_forward
27
+
28
+ except ImportError:
29
+ logging.error(
30
+ "Error importing grouped GEMM modules. Make sure the implementation files are in the correct path."
31
+ )
32
+ raise
33
+
34
+
35
def test_forward_pass():
    """
    A simple test for the M*G grouped GEMM forward pass with detailed error handling.

    In M*G grouping:
    - M dimension is partitioned into G groups (M_total = sum(M_sizes))
    - N dimension is the same for all groups

    Returns True when the Triton result matches the PyTorch reference.
    """
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Test parameters for DeepSeek-like models
        G = 1  # Number of groups
        M_sizes = [
            2048,
        ]  # 2048, 2048, 2048] # Group sizes (will be adjusted)
        M_total = sum(M_sizes)  # Total M dimension
        N = 4096  # Output dimension (same for all groups)
        K = 7168  # Hidden dimension

        # Create group sizes tensor
        m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)

        # Create input and weight tensors - using float16 for higher precision
        x = torch.randn(M_total, K, dtype=torch.float16, device=device)
        w = torch.randn(N, K, dtype=torch.float16, device=device)

        # Log the setup
        logging.info(f"Test setup - G: {G}, M_total: {M_total}, N: {N}, K: {K}")
        logging.info(f"Group sizes: {m_sizes}")
        logging.info(f"Input x shape: {x.shape}")
        logging.info(f"Weight w shape: {w.shape}")

        # Run forward pass
        logging.info("Running forward pass with grouped GEMM")
        result = grouped_gemm_forward(x, w, m_sizes)
        logging.info(f"Forward result shape: {result.shape}")

        # Compute reference result
        logging.info("Computing reference result with PyTorch")
        reference_result = compute_reference_forward(x, w, m_sizes)

        # Compare results
        logging.info("Comparing with PyTorch reference")
        forward_close = analyze_tensor_differences(
            result, reference_result, "Forward output"
        )

        return forward_close

    except Exception as e:
        # Broad catch is deliberate: this is a standalone harness and any
        # failure should be logged with a traceback rather than crash the run.
        logging.error(f"Test failed with error: {e}")
        import traceback

        logging.error(traceback.format_exc())
        return False
91
+
92
+
93
def test_backward_pass():
    """
    A simple test for the M*G grouped GEMM backward pass with detailed error handling.

    In M*G grouping:
    - M dimension is partitioned into G groups (M_total = sum(M_sizes))
    - N dimension is the same for all groups

    Returns True when both grad_x and grad_w match the PyTorch reference.
    """
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Test parameters for DeepSeek-like models
        G = 4  # Number of groups
        M_sizes = [2048, 2048, 2048, 2048]  # Group sizes (will be adjusted)
        M_total = sum(M_sizes)  # Total M dimension
        N = 4096  # Output dimension (same for all groups)
        K = 7168  # Hidden dimension

        # Create group sizes tensor
        m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)

        # Create input and weight tensors - using float16 for higher precision
        x = torch.randn(
            M_total, K, dtype=torch.float16, device=device, requires_grad=True
        )
        w = torch.randn(N, K, dtype=torch.float16, device=device, requires_grad=True)

        # Log the setup
        logging.info(f"Test setup - G: {G}, M_total: {M_total}, N: {N}, K: {K}")
        logging.info(f"Group sizes: {m_sizes}")
        logging.info(f"Input x shape: {x.shape}")
        logging.info(f"Weight w shape: {w.shape}")

        # Step 1: Run forward pass
        logging.info("Running forward pass")
        result = grouped_gemm_forward(x, w, m_sizes)
        logging.info(f"Forward result shape: {result.shape}")

        # Create a gradient for backpropagation
        grad_output = torch.randn_like(result)
        logging.info(f"Created gradient with shape: {grad_output.shape}")

        # Step 2: Run backward pass directly (not via autograd)
        logging.info("Running backward pass directly")
        grad_x, grad_w = grouped_gemm_backward(grad_output, x, w, m_sizes)

        # Verify gradient shapes
        logging.info(
            f"Gradient shapes - grad_x: {grad_x.shape}, grad_w: {grad_w.shape}"
        )

        # Step 3: Verify gradient computation using PyTorch's autograd
        logging.info("Running PyTorch reference implementation")

        # Compute reference gradients
        x_ref_grad, w_ref_grad = compute_reference_backward(x, w, m_sizes, grad_output)

        # Compare gradients
        logging.info("Comparing gradients with PyTorch reference")
        grad_x_close = analyze_tensor_differences(grad_x, x_ref_grad, "grad_x")
        grad_w_close = analyze_tensor_differences(grad_w, w_ref_grad, "grad_w")

        # Log overall result
        if grad_x_close and grad_w_close:
            logging.info("✓ SUCCESS: Gradients match the PyTorch reference")
        else:
            logging.error("✗ FAILURE: Gradient mismatch detected")

        return grad_x_close and grad_w_close

    except Exception as e:
        # Broad catch is deliberate: log with traceback, report failure.
        logging.error(f"Test failed with error: {e}")
        import traceback

        logging.error(traceback.format_exc())
        return False
169
+
170
+
171
def test_multiple_deepseek_configs():
    """
    Test multiple DeepSeek model configurations with both forward and backward pass verification.

    Returns True only if every configuration passes both its forward and
    backward comparisons against the PyTorch reference.
    """
    # DeepSeek configurations: (G, M, K, N)
    configs = [
        (4, 8192, 7168, 4096),  # Config 1
        (4, 8192, 2048, 7168),  # Config 2
        (8, 4096, 7168, 4096),  # Config 3
        (8, 4096, 2048, 7168),  # Config 4
    ]

    # Each entry: (config number, overall ok, forward ok, backward ok).
    results = []

    for config_idx, (G, M, K, N) in enumerate(configs):
        logging.info(f"\n\n===== Testing DeepSeek Config {config_idx+1} =====")
        logging.info(f"G={G}, M={M}, K={K}, N={N}")

        try:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

            # Create even group sizes; spread any remainder over the first groups.
            base_size = M // G
            remainder = M % G
            M_sizes = [base_size + (1 if i < remainder else 0) for i in range(G)]
            m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)

            # Create input and weight tensors using float16 for higher precision
            x = torch.randn(
                M, K, dtype=torch.float16, device=device, requires_grad=True
            )
            w = torch.randn(
                N, K, dtype=torch.float16, device=device, requires_grad=True
            )

            logging.info(f"Input x shape: {x.shape}, Weight w shape: {w.shape}")

            # Run forward pass
            result = grouped_gemm_forward(x, w, m_sizes)
            logging.info(f"Forward result shape: {result.shape}")

            # ===== FORWARD PASS VERIFICATION =====
            # Compute reference forward result
            reference_result = compute_reference_forward(x, w, m_sizes)

            # Compare forward results
            forward_close = analyze_tensor_differences(
                result, reference_result, "Forward output"
            )

            # ===== BACKWARD PASS VERIFICATION =====
            # Create gradient for backpropagation
            grad_output = torch.randn_like(result)

            # Run backward pass
            grad_x, grad_w = grouped_gemm_backward(grad_output, x, w, m_sizes)

            # Compute reference gradients
            x_ref_grad, w_ref_grad = compute_reference_backward(
                x, w, m_sizes, grad_output
            )

            # Compare backward results
            grad_x_close = analyze_tensor_differences(grad_x, x_ref_grad, "grad_x")
            grad_w_close = analyze_tensor_differences(grad_w, w_ref_grad, "grad_w")

            # Overall config result
            backward_close = grad_x_close and grad_w_close
            config_success = forward_close and backward_close
            results.append(
                (config_idx + 1, config_success, forward_close, backward_close)
            )

            # Log overall config result
            if config_success:
                logging.info(f"✓ SUCCESS: Config {config_idx+1} passed all tests!")
            else:
                logging.error(
                    f"✗ FAILURE: Config {config_idx+1} failed one or more tests"
                )

        except Exception as e:
            # A failing config is recorded rather than aborting the sweep.
            logging.error(f"Config {config_idx+1} test failed with error: {e}")
            import traceback

            logging.error(traceback.format_exc())
            results.append((config_idx + 1, False, False, False))

    # Summary
    logging.info("\n===== Test Results Summary =====")
    for config_idx, overall_success, forward_success, backward_success in results:
        overall_status = "✓ PASSED" if overall_success else "✗ FAILED"
        forward_status = "✓ PASSED" if forward_success else "✗ FAILED"
        backward_status = "✓ PASSED" if backward_success else "✗ FAILED"

        logging.info(f"Config {config_idx}: {overall_status}")
        logging.info(f" - Forward pass: {forward_status}")
        logging.info(f" - Backward pass: {backward_status}")

    return all(overall_success for _, overall_success, _, _ in results)
271
+
272
+
273
if __name__ == "__main__":
    # Script entry point: run the basic forward test, the basic backward
    # test, then the full DeepSeek-config sweep, and report a combined result.
    logging.info(
        "Running verification for both forward and backward pass of M*G grouped GEMM"
    )

    # Run basic forward pass test
    logging.info("\n===== Running basic forward pass test =====")
    success_forward = test_forward_pass()
    logging.info(f"Basic forward test {'succeeded' if success_forward else 'failed'}")

    # Run basic backward pass test
    logging.info("\n===== Running basic backward pass test =====")
    success_backward = test_backward_pass()
    logging.info(f"Basic backward test {'succeeded' if success_backward else 'failed'}")

    # Run multiple DeepSeek configs with forward and backward verification
    logging.info("\n===== Running tests for all DeepSeek configs =====")
    success_configs = test_multiple_deepseek_configs()
    logging.info(
        f"DeepSeek configs tests {'all succeeded' if success_configs else 'had failures'}"
    )

    # Overall result
    overall_success = success_forward and success_backward and success_configs
    logging.info(
        f"\nOverall test result: {'SUCCESS' if overall_success else 'FAILURE'}"
    )
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/mg_grouped_gemm.py ADDED
@@ -0,0 +1,1304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # credit - flat index forward kernel is derived from FBGemm:
8
+ # https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm
9
+
10
+ # pyre-unsafe
11
+ import functools
12
+ import logging
13
+
14
+ import os
15
+ import sys
16
+ from typing import Any, Dict, Optional, Tuple
17
+
18
+ import torch
19
+
20
+ import triton
21
+ import triton.language as tl
22
+ from triton import Config as TConfig
23
+
24
+ from triton.runtime import driver # @manual
25
+
26
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
27
+
28
+ from tma_autotuning import (
29
+ ALIGN_SIZE_M,
30
+ _NV_CONFIGS,
31
+ CudaUtils,
32
+ early_config_prune,
33
+ TmaDescriptorHelper,
34
+ )
35
+
36
+
37
+ # Configure logging
38
+ logging.basicConfig(
39
+ level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
40
+ )
41
+
42
+ # ============== Start Triton Kernels ===============
43
+
44
+
45
@triton.autotune(
    configs=_NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_mg_forward_hopper(
    a_desc_ptr,  # TMA descriptor for inputs x, logically [M_total, K]
    b_desc_ptr,  # TMA descriptor for stacked weights w, logically [N, K]
    c_ptr,  # output tensor; written as [M_total, N // G]
    workspace,  # per-SM scratch holding on-device TMA store descriptors
    m_sizes,  # int tensor [G] of per-group row counts (M*G grouping)
    # problem sizes
    G: tl.constexpr,
    M_BUCKET: tl.constexpr,  # autotune key only (power-of-2 bucket of M_total); not read in the body
    N: tl.constexpr,
    K: tl.constexpr,
    # config
    NUM_SMS: tl.constexpr,
    TMA_SIZE: tl.constexpr,  # bytes of one TMA descriptor inside `workspace`
    USE_EPILOGUE_SUBTILING: tl.constexpr,  # split the epilogue into two half-width TMA stores
    # tiles
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
) -> None:
    """
    Flat index style forward kernel for Hopper.
    For simplicity, we always use TMA Load and TMA Store.

    Persistent-kernel scheme: each program walks the flat tile index space
    (tbidx, tbidx + NUM_SMS, ...) across all groups.  Group g multiplies its
    rows only against its own weight slice of n_size = N // G rows, i.e.
    y_g = x_g @ w_g.T, producing an output of shape [M_total, N // G].
    """
    tbidx = tl.program_id(0)  # thread block index

    c_dtype = c_ptr.dtype.element_ty  # output dtype

    c_desc_ptr = workspace + (tbidx * TMA_SIZE)  # for TMA Store

    M_end = 0
    M_start = 0
    processed_tiles = 0
    # Size of individual weight matrix (per-group slice of the stacked weights)
    n_size = N // G
    n_start = 0

    for g in range(G):
        # Move down along groups
        # reset to new M offset
        M_start = M_end
        m_size = tl.load(m_sizes + g)
        M_end = M_start + m_size
        n_start = n_size * g

        if m_size > 0:
            # Process this group

            # Acquire hold on c_desc_ptr for TMA Store.
            # The store descriptor is rebuilt per group because each group's
            # output sub-matrix has its own base address and row extent.
            tl.extra.cuda.experimental_device_tensormap_create2d(
                desc_ptr=c_desc_ptr,
                global_address=c_ptr + M_start * n_size,
                load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
                global_size=[m_size, n_size],
                element_ty=c_dtype,
            )
            tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)

            # tiles for this group
            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
            num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
            group_num_tiles = num_m_tiles * num_n_tiles

            # Claim every flat tile index that falls inside this group's range.
            while tbidx >= processed_tiles and tbidx < (
                processed_tiles + group_num_tiles
            ):
                group_index = tbidx - processed_tiles

                # columnwise tile ordering (M varies fastest)
                tile_m_index = group_index % num_m_tiles
                tile_n_index = group_index // num_m_tiles

                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

                m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32)
                n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)
                # weight rows live at this group's slice of the stacked [N, K] tensor
                global_n_offset = (n_start + n_offset).to(tl.int32)

                for k_offset in range(0, K, BLOCK_SIZE_K):
                    # input block [M,K]
                    a = tl._experimental_descriptor_load(
                        a_desc_ptr,
                        [m_offset, k_offset],
                        [BLOCK_SIZE_M, BLOCK_SIZE_K],
                        c_dtype,
                    )
                    # weight block [N, K]
                    b = tl._experimental_descriptor_load(
                        b_desc_ptr,
                        [global_n_offset, k_offset],
                        [BLOCK_SIZE_N, BLOCK_SIZE_K],
                        c_dtype,
                    )

                    # y += a @ b.T, accumulated in fp32
                    accumulator += tl.dot(a, b.T)

                # Store using TMA

                # store offset is relative to this group's descriptor base
                m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)

                if USE_EPILOGUE_SUBTILING:
                    # split the accumulator in half along N and issue two
                    # narrower stores (can overlap epilogue with compute)
                    acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2))
                    acc = tl.permute(acc, (0, 2, 1))
                    acc0, acc1 = tl.split(acc)
                    c0 = acc0.to(c_dtype)
                    tl._experimental_descriptor_store(
                        c_desc_ptr, c0, [m_offset, n_offset]
                    )
                    c1 = acc1.to(c_dtype)
                    tl._experimental_descriptor_store(
                        c_desc_ptr, c1, [m_offset, n_offset + BLOCK_SIZE_N // 2]
                    )
                else:
                    tl._experimental_descriptor_store(
                        c_desc_ptr,
                        accumulator.to(c_dtype),
                        [m_offset, n_offset],
                    )
                # move to next tile in group
                tbidx += NUM_SMS
            # Update the total tiles count for the next group
            processed_tiles += group_num_tiles
173
+
174
+
175
@triton.autotune(
    configs=_NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_mg_forward_tma(
    a_desc_ptr,  # TMA descriptor for inputs x, logically [M_total, K]
    b_desc_ptr,  # TMA descriptor for weights w, logically [N, K]
    c_ptr,  # output tensor [M_total, N]
    workspace,  # per-SM scratch holding on-device TMA store descriptors
    m_sizes,  # int tensor [G] of per-group row counts
    a_scale_ptr,  # FP8 scale for a — currently unused in this kernel body
    b_scale_ptr,  # FP8 scale for b — currently unused in this kernel body
    # problem sizes
    G: tl.constexpr,
    M_BUCKET: tl.constexpr,  # autotune key only; not read in the body
    N: tl.constexpr,
    K: tl.constexpr,
    # config
    NUM_SMS: tl.constexpr,
    USE_TMA_LOAD: tl.constexpr,  # currently unused — loads always go through TMA here
    USE_TMA_STORE: tl.constexpr,  # currently unused — stores always go through TMA here
    TMA_SIZE: tl.constexpr,  # bytes of one TMA descriptor inside `workspace`
    USE_FP8: tl.constexpr,  # currently unused in this kernel body
    # tiles
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
) -> None:
    """
    Flat index style forward kernel.
    For simplicity, we always use TMA Load and TMA Store.

    Unlike the Hopper variant above, every group multiplies against the FULL
    weight matrix (n_size = N), i.e. y_g = x_g @ w.T, output [M_total, N].
    """
    tbidx = tl.program_id(0)  # thread block index

    c_dtype = c_ptr.dtype.element_ty

    c_desc_ptr = workspace + (tbidx * TMA_SIZE)

    M_end = 0
    processed_tiles = 0

    for g in range(G):
        # Move down along groups
        # reset to new M offset
        M_start = M_end
        m_size = tl.load(m_sizes + g)
        M_end = M_start + m_size

        if m_size > 0:
            # Process this group
            n_size = N

            # TMA Store prep — descriptor rebuilt per group since each group's
            # output sub-matrix has its own base address and row extent
            tl.extra.cuda.experimental_device_tensormap_create2d(
                desc_ptr=c_desc_ptr,
                global_address=c_ptr + M_start * N,
                load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
                global_size=[m_size, n_size],
                element_ty=c_dtype,
            )
            tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)

            # tiles for this group
            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
            num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
            group_num_tiles = num_m_tiles * num_n_tiles

            # Claim every flat tile index that falls inside this group's range.
            while tbidx >= processed_tiles and tbidx < (
                processed_tiles + group_num_tiles
            ):
                group_index = tbidx - processed_tiles

                tile_m_index = group_index % num_m_tiles
                tile_n_index = group_index // num_m_tiles

                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

                m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32)
                n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)

                for k_offset in range(0, K, BLOCK_SIZE_K):
                    # input block [M,K]
                    a = tl._experimental_descriptor_load(
                        a_desc_ptr,
                        [m_offset, k_offset],
                        [BLOCK_SIZE_M, BLOCK_SIZE_K],
                        c_dtype,
                    )
                    # weight block [N, K]
                    b = tl._experimental_descriptor_load(
                        b_desc_ptr,
                        [n_offset, k_offset],
                        [BLOCK_SIZE_N, BLOCK_SIZE_K],
                        c_dtype,
                    )

                    # y += a @ b.T, accumulated in fp32
                    accumulator += tl.dot(a, b.T)

                # Store using TMA

                # store row offset is relative to this group's descriptor base
                m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)
                # n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)

                tl._experimental_descriptor_store(
                    c_desc_ptr,
                    accumulator.to(c_dtype),
                    [m_offset, n_offset],
                )

                # Move to the next tile
                tbidx += NUM_SMS
            # Update the total tiles count for the next group
            processed_tiles += group_num_tiles
290
+
291
+
292
@triton.autotune(
    configs=_NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_mg_forward_no_tma(
    a_ptr,  # inputs, [M_total, K] row-major
    b_ptr,  # weights, [N, K] row-major
    c_ptr,  # output, [M_total, N] row-major
    workspace,  # unused in the non-TMA path (kept for signature parity)
    m_sizes,  # int tensor [G] of per-group row counts
    # problem sizes
    G: tl.constexpr,
    M_BUCKET: tl.constexpr,  # autotune key only; not read in the body
    N: tl.constexpr,
    K: tl.constexpr,
    # config
    NUM_SMS: tl.constexpr,
    USE_TMA_LOAD: tl.constexpr,  # unused — this kernel never uses TMA
    USE_TMA_STORE: tl.constexpr,  # unused — this kernel never uses TMA
    TMA_SIZE: tl.constexpr,  # unused — this kernel never uses TMA
    # tiles
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
) -> None:
    """
    Flat index style forward kernel.
    For bc and Ampere, we never use TMA Load and TMA Store.

    Computes y_g = x_g @ w.T per group (n_size = N) with masked global
    loads/stores instead of TMA descriptors.

    Fixes vs. the original:
    - the store mask combined the two boundary conditions with Python `and`,
      which is invalid on 2-D Triton tensors; replaced with bitwise `&`.
    - masked loads now pass `other=0.0` so out-of-bounds lanes contribute
      zeros to `tl.dot` instead of undefined values.
    - loads are additionally masked along K, so a K that is not a multiple
      of BLOCK_SIZE_K no longer reads out of bounds.
    """
    tbidx = tl.program_id(0)  # thread block index

    c_dtype = c_ptr.dtype.element_ty

    M_end = 0
    processed_tiles = 0

    for g in range(G):
        # Move down along groups: reset to this group's M offset
        M_start = M_end
        m_size = tl.load(m_sizes + g)
        M_end = M_start + m_size

        if m_size > 0:
            # Process this group
            n_size = N

            # tiles for this group
            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
            num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
            group_num_tiles = num_m_tiles * num_n_tiles

            # Claim every flat tile index that falls inside this group's range.
            while tbidx >= processed_tiles and tbidx < (
                processed_tiles + group_num_tiles
            ):
                group_index = tbidx - processed_tiles

                tile_m_index = group_index % num_m_tiles
                tile_n_index = group_index // num_m_tiles

                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

                offs_am = tile_m_index * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                offs_bn = tile_n_index * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                offs_k = tl.arange(0, BLOCK_SIZE_K)

                # row-major addressing: a is [M_total, K], b is [N, K]
                a_ptrs = a_ptr + (M_start + offs_am[:, None]) * K + offs_k[None, :]
                b_ptrs = b_ptr + (offs_bn[:, None]) * K + offs_k[None, :]

                for k_offset in range(0, K, BLOCK_SIZE_K):
                    # Load with bounds checking on both M/N and K; masked-off
                    # lanes must be zero so they do not pollute the dot product.
                    k_mask = (k_offset + offs_k[None, :]) < K
                    a = tl.load(
                        a_ptrs,
                        mask=(offs_am[:, None] < m_size) & k_mask,
                        other=0.0,
                    )
                    b = tl.load(
                        b_ptrs,
                        mask=(offs_bn[:, None] < n_size) & k_mask,
                        other=0.0,
                    )

                    # Main matmul: y += a @ b.T, accumulated in fp32
                    accumulator += tl.dot(a, b.T)

                    # Update pointers for next K block
                    a_ptrs += BLOCK_SIZE_K
                    b_ptrs += BLOCK_SIZE_K

                # Store without TMA
                offs_am = tile_m_index * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                offs_bn = tile_n_index * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

                c = accumulator.to(c_dtype)

                tl.store(
                    c_ptr
                    + (M_start + offs_am[:, None]) * N  # Row stride is N
                    + offs_bn[None, :],  # Column offset
                    c,
                    # bitwise & — Python `and` is not defined for tensor masks
                    mask=(offs_am[:, None] < m_size) & (offs_bn[None, :] < n_size),
                )
                # Move to the next tile
                tbidx += NUM_SMS
            # Update the total tiles count for the next group
            processed_tiles += group_num_tiles
396
+
397
+
398
+ """
399
+ Backward pass for grouped GEMM with Triton, where grouping is M*G
400
+ We compute gradients with respect to both input (`grad_x`) and weights (`grad_w`).
401
+ """
402
+
403
+
404
+ # ---- dx flat linear indexed ----
405
@triton.autotune(
    configs=_NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_mg_dx_tma(
    grad_output_desc_ptr,  # TMA descriptor over grad_Y, logically [MG, N]
    w_desc_ptr,  # TMA descriptor over weights, logically [N, K]
    grad_input_ptr,  # output grad_x [MG, K]
    workspace,  # per-SM scratch holding on-device TMA store descriptors
    m_sizes,  # group sizes [G]
    # problem sizes
    G: tl.constexpr,
    M_BUCKET: tl.constexpr,  # autotune key only; not read in the body
    N: tl.constexpr,
    K: tl.constexpr,
    # config
    NUM_SMS: tl.constexpr,
    USE_TMA_LOAD: tl.constexpr,  # currently unused — loads always go through TMA here
    USE_TMA_STORE: tl.constexpr,  # currently unused — stores always go through TMA here
    TMA_SIZE: tl.constexpr,  # bytes of one TMA descriptor inside `workspace`
    # tiles
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
) -> None:
    """
    TMA-optimized kernel for computing gradients with respect to input (dx).
    For the forward pass Y = X @ W.T, the backward for input is:
        grad_X = grad_Y @ W

    This maps to [MG, N] @ [N, K] -> [MG, K]

    Key differences from forward:
    1. W is used directly and not transposed
    2. The reduction dimension is now N (not K)
    3. Output is [M, K] instead of [M, N]

    Uses the same persistent flat-tile-index scheme as the forward kernels.
    """
    tbidx = tl.program_id(0)  # thread block index

    c_dtype = grad_input_ptr.dtype.element_ty
    c_desc_ptr = workspace + (tbidx * TMA_SIZE)

    M_end = 0
    processed_tiles = 0

    for g in range(G):
        # Move down along groups - same as forward
        M_start = M_end
        m_size = tl.load(m_sizes + g)
        M_end = M_start + m_size

        if m_size > 0:
            # Process this group
            # tiles for this group - now producing [M, K] output
            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
            num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
            group_num_tiles = num_m_tiles * num_k_tiles

            # TMA Store prep for [M, K] output — descriptor rebuilt per group
            # since each group's output sub-matrix has its own base and extent
            tl.extra.cuda.experimental_device_tensormap_create2d(
                desc_ptr=c_desc_ptr,
                global_address=grad_input_ptr + M_start * K,
                load_size=[BLOCK_SIZE_M, BLOCK_SIZE_K],
                global_size=[m_size, K],
                element_ty=c_dtype,
            )
            tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)

            # Claim every flat tile index that falls inside this group's range.
            while tbidx >= processed_tiles and tbidx < (
                processed_tiles + group_num_tiles
            ):
                group_index = tbidx - processed_tiles

                # Different tiling scheme for [M, K] output
                tile_m_index = group_index % num_m_tiles
                tile_k_index = group_index // num_m_tiles

                # for grad_input block [M, K]
                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)

                # Position in full matrix
                m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32)
                k_offset = (tile_k_index * BLOCK_SIZE_K).to(tl.int32)

                # reduce along N dimension (instead of K in forward)
                for n_offset in range(0, N, BLOCK_SIZE_N):
                    # grad_output block [M, N]
                    grad_output = tl._experimental_descriptor_load(
                        grad_output_desc_ptr,
                        [m_offset, n_offset],
                        [BLOCK_SIZE_M, BLOCK_SIZE_N],
                        c_dtype,
                    )

                    # weight block [N, K] - no transpose needed
                    w = tl._experimental_descriptor_load(
                        w_desc_ptr,
                        [n_offset, k_offset],
                        [BLOCK_SIZE_N, BLOCK_SIZE_K],
                        c_dtype,
                    )

                    # grad_x = grad_output @ w
                    # reducing along N dimension
                    accumulator += tl.dot(grad_output, w)

                # Store using TMA
                # store row offset is relative to this group's descriptor base
                m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)
                # k_offset = (tile_k_index * BLOCK_SIZE_K).to(tl.int32)

                tl._experimental_descriptor_store(
                    c_desc_ptr,
                    accumulator.to(c_dtype),
                    [m_offset, k_offset],
                )

                # Move to the next tile
                tbidx += NUM_SMS

            # Update the total tiles count for the next group
            processed_tiles += group_num_tiles
528
+
529
+
530
+ # ---- dw flat linear indexed ----
531
+
532
+
533
@triton.autotune(
    configs=_NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_mg_dw_tma(
    x_desc_ptr,  # input descriptor [M_total, K] (raw pointer when USE_TMA_LOAD is False)
    grad_output_desc_ptr,  # grad_output descriptor [M_total, N] (raw pointer when USE_TMA_LOAD is False)
    grad_weight_ptr,  # output grad_w [N, K]
    workspace,  # TMA store descriptor (or dummy buffer when USE_TMA_STORE is False)
    m_sizes,  # group sizes [G]
    # problem sizes
    G: tl.constexpr,
    M_BUCKET: tl.constexpr,  # autotune key only; not read in the body
    N: tl.constexpr,
    K: tl.constexpr,
    # config
    NUM_SMS: tl.constexpr,
    USE_TMA_LOAD: tl.constexpr,  # TMA vs. masked-pointer loads
    USE_TMA_STORE: tl.constexpr,  # TMA vs. masked-pointer store
    TMA_SIZE: tl.constexpr,  # currently unused in this kernel body
    # tiles
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,  # block size for reduction dimension
) -> None:
    """
    Improved TMA-optimized kernel for computing gradients with respect to weights (dw).
    Uses flat index structure similar to forward.

    For the forward pass Y = X @ W.T,
    the backward for weights is:
        grad_W = grad_Y.T @ X

    Where:
    - grad_Y is [MG, N]
    - X is [MG, K]
    - grad_W is [N, K]
    - we return [N,K]

    Unlike the dx kernel, every program owns whole output tiles of grad_W and
    reduces over ALL groups (the full M dimension) to produce them.
    """
    # Get thread block index
    tbidx = tl.program_id(0)

    # Get output data type
    c_dtype = grad_weight_ptr.dtype.element_ty

    # Calculate number of output tiles
    num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
    num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
    total_output_tiles = num_n_tiles * num_k_tiles

    # Process tiles in strided manner across SMs
    for tile_idx in range(tbidx, total_output_tiles, NUM_SMS):
        # Calculate tile indices
        tile_n_idx = tile_idx % num_n_tiles
        tile_k_idx = tile_idx // num_n_tiles

        # Calculate global offsets
        n_offset = tile_n_idx * BLOCK_SIZE_N
        k_offset = tile_k_idx * BLOCK_SIZE_K

        # Initialize accumulator for this output tile [N, K]
        accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype=tl.float32)

        # Process each group (grad_W sums contributions from every group's rows)
        M_end = 0
        for g in range(G):
            # Get group boundaries
            M_start = M_end
            m_size = tl.load(m_sizes + g)
            M_end = M_start + m_size

            # Only process if group is non-empty
            if m_size > 0:
                # Process this group in chunks along the M dimension
                for m_offset in range(0, m_size, BLOCK_SIZE_M):
                    # Calculate actual block size (handling boundary)
                    m_block_size = tl.minimum(BLOCK_SIZE_M, m_size - m_offset)

                    # Only process if we have actual work to do
                    if m_block_size > 0:
                        # Global offset for this chunk
                        m_global_offset = M_start + m_offset

                        if USE_TMA_LOAD:
                            # Load input chunk [M_chunk, K] using TMA
                            x_block = tl._experimental_descriptor_load(
                                x_desc_ptr,
                                [m_global_offset, k_offset],
                                [BLOCK_SIZE_M, BLOCK_SIZE_K],
                                c_dtype,
                            )

                            # Load grad_output chunk [M_chunk, N] using TMA
                            grad_output_block = tl._experimental_descriptor_load(
                                grad_output_desc_ptr,
                                [m_global_offset, n_offset],
                                [BLOCK_SIZE_M, BLOCK_SIZE_N],
                                c_dtype,
                            )

                            # Apply masks for valid regions (TMA loads a full
                            # block; rows past the chunk end must not contribute)
                            offs_m = tl.arange(0, BLOCK_SIZE_M)
                            m_mask = offs_m < m_block_size

                            # Zero out invalid elements
                            x_block = tl.where(m_mask[:, None], x_block, 0.0)
                            grad_output_block = tl.where(
                                m_mask[:, None], grad_output_block, 0.0
                            )
                        else:
                            # Manual load with bounds checking
                            offs_m = tl.arange(0, BLOCK_SIZE_M)
                            offs_n = tl.arange(0, BLOCK_SIZE_N)
                            offs_k = tl.arange(0, BLOCK_SIZE_K)

                            # Create masks
                            m_mask = offs_m < m_block_size
                            n_mask = offs_n < N - n_offset
                            k_mask = offs_k < K - k_offset

                            # Combined masks
                            mk_mask = m_mask[:, None] & k_mask[None, :]
                            mn_mask = m_mask[:, None] & n_mask[None, :]

                            # Global offsets for loading
                            m_global_offs = m_global_offset + offs_m

                            # Load x block [M_chunk, K]
                            # (here x_desc_ptr is a raw tensor pointer)
                            x_block = tl.load(
                                x_desc_ptr
                                + m_global_offs[:, None] * K
                                + (k_offset + offs_k)[None, :],
                                mask=mk_mask,
                                other=0.0,
                            )

                            # Load grad_output block [M_chunk, N]
                            grad_output_block = tl.load(
                                grad_output_desc_ptr
                                + m_global_offs[:, None] * N
                                + (n_offset + offs_n)[None, :],
                                mask=mn_mask,
                                other=0.0,
                            )

                        # Compute partial contribution: grad_W += grad_Y.T @ X
                        # transpose grad_output for the matmul
                        contribution = tl.dot(
                            grad_output_block.to(tl.float32).T,  # [N, M_chunk]
                            x_block.to(tl.float32),  # [M_chunk, K]
                        )

                        # Accumulate
                        accumulator += contribution

        # Store the result
        if USE_TMA_STORE:
            # Store using TMA
            tl._experimental_descriptor_store(
                workspace,  # TMA store descriptor
                accumulator.to(c_dtype),
                [n_offset, k_offset],
            )
        else:
            # Manual store with bounds checking
            offs_n = tl.arange(0, BLOCK_SIZE_N)
            offs_k = tl.arange(0, BLOCK_SIZE_K)

            # Create masks for bounds checking
            n_mask = offs_n < N - n_offset
            k_mask = offs_k < K - k_offset
            output_mask = n_mask[:, None] & k_mask[None, :]

            # Store the result
            tl.store(
                grad_weight_ptr
                + (n_offset + offs_n)[:, None] * K
                + (k_offset + offs_k)[None, :],
                accumulator.to(c_dtype),
                mask=output_mask,
            )
716
+
717
+
718
+ # ======== End Triton kernels ========
719
+
720
+ # ======== Triton wrapper functions ========
721
+
722
+ # ----- main forward pass wrapper -----
723
+
724
+
725
def grouped_gemm_forward(
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    tma_size: int = 128,
) -> torch.Tensor:
    """
    M*G style grouped GEMM with TMA and Float8 support.
    # Removed for now - FP8 support is triggered by passing x_scale and w_scale tensors.

    Args:
        x: input activations, shape [M_total, K], contiguous.
        w: stacked per-group weights, shape [N, K], contiguous; group g owns
            the row slice [g * (N // G), (g + 1) * (N // G)).
        m_sizes: per-group row counts, shape [G].
        tma_size: size in bytes of one TMA descriptor.

    Returns:
        y: output of shape [M_total, N // G] — per-group x_g @ w_g.T, as
            computed by the Hopper kernel (NOT [M_total, N]; the original
            comment here was stale).
    """
    if not CudaUtils.verify_tma():
        raise NotImplementedError("Grouped GEMM without TMA is not supported yet")

    G = m_sizes.shape[0]

    assert x.is_contiguous()
    assert w.is_contiguous()
    assert m_sizes.is_contiguous()

    # Total input size is [M_total, K] where M_total is the sum of all group sizes
    M_total, K = x.shape
    N = w.shape[0]  # total stacked weight rows across all groups

    assert K == w.shape[1], f"Input K ({K}) must match weight K ({w.shape[1]})"

    # Verify that all group sizes are multiples of ALIGN_SIZE_M
    # This check is commented out because it will involve a GPU-CPU sync
    # assert torch.remainder(m_sizes, ALIGN_SIZE_M).max() == 0, "Group sizes must be a multiple of ALIGN_SIZE_M"

    # Output is [M_total, N // G]: each group's rows are multiplied against
    # only that group's N // G weight slice.
    y = torch.empty((M_total, N // G), device=x.device, dtype=x.dtype)

    if M_total == 0:
        return y

    NUM_SMS = CudaUtils.get_num_sms()
    USE_TMA_LOAD = True
    USE_TMA_STORE = True
    USE_EPILOGUE_SUBTILING = False

    # TMA descriptor helper
    desc_helper = None
    desc_x = x
    desc_w = w
    workspace = None

    if USE_TMA_LOAD:
        desc_helper = TmaDescriptorHelper(tma_size=tma_size)
        desc_helper.init_tma_descriptor("x")
        desc_helper.init_tma_descriptor("w")
        desc_x = desc_helper.get_tma_descriptor_kernel_param("x")
        desc_w = desc_helper.get_tma_descriptor_kernel_param("w")

    if USE_TMA_STORE:
        # Fix: size from `tma_size` directly so the store workspace does not
        # dereference desc_helper (None when USE_TMA_LOAD is False).  Same
        # value as desc_helper.tma_size in the current configuration.
        workspace = torch.empty(
            NUM_SMS * tma_size,
            device=x.device,
            dtype=torch.uint8,
        )

    def grid(META):
        # Fill the host-side TMA descriptors once the autotuner has picked
        # block sizes, then launch one persistent program per SM.
        if USE_TMA_LOAD:
            nonlocal desc_helper
            desc_helper.fill_2d_tma_descriptor(
                "x",
                x.data_ptr(),
                M_total,
                K,
                META["BLOCK_SIZE_M"],
                META["BLOCK_SIZE_K"],
                x.element_size(),
            )

            desc_helper.fill_2d_tma_descriptor(
                "w",
                w.data_ptr(),
                N,
                K,
                META["BLOCK_SIZE_N"],
                META["BLOCK_SIZE_K"],
                w.element_size(),
            )
        return (NUM_SMS,)

    M_BUCKET = triton.next_power_of_2(M_total)

    _kernel_mg_forward_hopper[grid](
        desc_x,
        desc_w,
        y,
        workspace,
        m_sizes,
        G,
        M_BUCKET,
        N,
        K,
        NUM_SMS,
        TMA_SIZE=tma_size,
        USE_EPILOGUE_SUBTILING=USE_EPILOGUE_SUBTILING,
    )

    return y
828
+
829
+
830
+ # ======== Improved Backward =============
831
def grouped_gemm_backward(
    grad_output: torch.Tensor,
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    use_tma: bool = True,
    tma_size: int = 128,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Unified backward pass for grouped GeMM with M*G grouping.
    Uses optimized TMA-based implementations for both dx and dw when available.

    Args:
        grad_output: Gradient of output, shape [M_total, N]
        x: Input tensor from forward pass, shape [M_total, K]
        w: Weight tensor from forward pass, shape [N, K]
        m_sizes: Group sizes tensor, shape [G]
        use_tma: Whether to try using TMA acceleration (if available)
        tma_size: Size of TMA descriptor in bytes

    Returns:
        Tuple of gradients with respect to x and w: (grad_x, grad_w)

    Raises:
        ValueError: on any shape mismatch between the inputs.
    """
    logging.info("Starting unified grouped_gemm_backward")

    # do this once, seems expensive
    NUM_SMS = CudaUtils.get_num_sms()

    # Basic validation
    G = m_sizes.shape[0]
    M_total, K_x = x.shape
    M_grad, N = grad_output.shape
    N_w, K_w = w.shape

    # Check dimensions
    if K_x != K_w:
        raise ValueError(f"K dimension mismatch: x has K={K_x}, w has K={K_w}")
    if M_total != M_grad:
        raise ValueError(
            f"M dimension mismatch: x has M={M_total}, grad_output has M={M_grad}"
        )

    # Check total M matches sum of group sizes (note: .item() is a GPU-CPU sync)
    sum_m_sizes = m_sizes.sum().item()
    if M_total != sum_m_sizes:
        raise ValueError(
            f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})"
        )

    # Make sure inputs are contiguous
    grad_output = grad_output.contiguous()
    x = x.contiguous()
    w = w.contiguous()
    m_sizes = m_sizes.contiguous()

    # Check TMA support
    can_use_tma = use_tma and CudaUtils.verify_tma()
    if use_tma and not can_use_tma:
        logging.info("TMA requested but not supported on this device")
        # NOTE(review): this flag is not consulted below — the TMA helpers are
        # still invoked and will raise NotImplementedError without TMA support.
        use_tma = False

    # Compute grad_x using flat linear implementation
    try:
        # fixed: plain strings instead of placeholder-free f-strings (F541)
        logging.info("Computing grad_x with flat linear kernel")

        # Use TMA-optimized implementation
        grad_x = grouped_gemm_dx_tma(
            grad_output=grad_output,
            w=w,
            m_sizes=m_sizes,
            num_sms=NUM_SMS,
            tma_size=tma_size,
        )

    except Exception as e:
        logging.error(f"Error in grad_x computation: {e}")
        raise

    # Compute grad_w using flat linear style implementation
    try:
        logging.info("Computing grad_w with flat linear kernel")

        grad_w = grouped_gemm_dw_tma(
            x, grad_output, m_sizes, num_sms=NUM_SMS, tma_size=tma_size
        )
    except Exception as e:
        logging.error(f"Error in grad_w computation: {e}")
        raise

    return grad_x, grad_w
922
+
923
+
924
+ # ----- dx backward pass wrapper -----
925
+
926
+
927
def grouped_gemm_dx_tma(
    grad_output: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    num_sms: int = 132,
    tma_size: int = 128,
) -> torch.Tensor:
    """
    Optimized backward pass wrapper for computing gradient with respect to
    input (dx) using TMA patterns similar to the forward pass.

    (Fix: the original carried two nearly identical copies of this docstring;
    the second was a dead bare-string statement and has been removed.)

    Args:
        grad_output: Gradient of output, shape [M_total, N]
        w: Weight tensor, shape [N, K]
        m_sizes: Group sizes tensor, shape [G]
        num_sms: number of persistent programs to launch (one per SM)
        tma_size: Size of TMA descriptor in bytes

    Returns:
        grad_x: Gradient with respect to x, shape [M_total, K]

    Raises:
        NotImplementedError: if the device lacks TMA support.
    """
    if not CudaUtils.verify_tma():
        raise NotImplementedError("Optimized dx computation requires TMA support")

    G = m_sizes.shape[0]

    assert grad_output.is_contiguous()
    assert w.is_contiguous()
    assert m_sizes.is_contiguous()

    M_total, N_grad = grad_output.shape
    N_w, K = w.shape

    # Check dimensions
    assert N_grad == N_w, f"Grad_output N ({N_grad}) must match weight N ({N_w})"

    # Verify that the sum of m_sizes matches M_total (GPU-CPU sync)
    sum_m_sizes = m_sizes.sum().item()
    assert (
        M_total == sum_m_sizes
    ), f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})"

    # Create output tensor (grad_x) with shape [M_total, K]
    grad_x = torch.empty(
        (M_total, K), device=grad_output.device, dtype=grad_output.dtype
    )

    NUM_SMS = num_sms  # CudaUtils.get_num_sms()
    USE_TMA_LOAD = True
    USE_TMA_STORE = True

    # Set up TMA descriptors
    desc_helper = TmaDescriptorHelper(tma_size=tma_size)
    desc_helper.init_tma_descriptor("grad_output")
    desc_helper.init_tma_descriptor("w")
    desc_grad_output = desc_helper.get_tma_descriptor_kernel_param("grad_output")
    desc_w = desc_helper.get_tma_descriptor_kernel_param("w")

    # Allocate workspace for TMA store (one descriptor slot per SM)
    workspace = torch.empty(
        NUM_SMS * desc_helper.tma_size,
        device=grad_output.device,
        dtype=torch.uint8,
    )

    def grid(META):
        # Fill TMA descriptors with appropriate dimensions once the
        # autotuner has picked block sizes
        desc_helper.fill_2d_tma_descriptor(
            "grad_output",
            grad_output.data_ptr(),
            M_total,
            N_grad,
            META["BLOCK_SIZE_M"],
            META["BLOCK_SIZE_N"],
            grad_output.element_size(),
        )

        desc_helper.fill_2d_tma_descriptor(
            "w",
            w.data_ptr(),
            N_w,
            K,
            META["BLOCK_SIZE_N"],
            META["BLOCK_SIZE_K"],
            w.element_size(),
        )
        return (NUM_SMS,)

    M_BUCKET = triton.next_power_of_2(M_total)

    # Launch the flat linear kernel for computing grad_x
    _kernel_mg_dx_tma[grid](
        desc_grad_output,
        desc_w,
        grad_x,
        workspace,
        m_sizes,
        G,
        M_BUCKET,
        N_grad,  # N dimension is now the reduction dimension
        K,
        NUM_SMS,
        USE_TMA_LOAD,
        USE_TMA_STORE,
        TMA_SIZE=tma_size,
    )

    return grad_x
1053
+
1054
+
1055
+ # ======== dw wrapper function ==========
1056
+
1057
+
1058
def grouped_gemm_dw_tma(
    x: torch.Tensor,
    grad_output: torch.Tensor,
    m_sizes: torch.Tensor,
    num_sms: int = 132,
    tma_size: int = 128,
) -> torch.Tensor:
    """
    Optimized flat linear kernel computation of gradients with respect to weights (dw) using TMA.
    For the forward pass Y = X @ W.T, the backward for weights is:
        grad_W = grad_Y.T @ X

    Args:
        x: Input tensor, shape [M_total, K]
        grad_output: Gradient of output, shape [M_total, N]
        m_sizes: Group sizes tensor, shape [G]; must sum to M_total
        num_sms: Number of persistent blocks to launch (default 132 —
            presumably an H100 SM count; TODO confirm against the actual device)
        tma_size: Size of TMA descriptor in bytes

    Returns:
        grad_w: Gradient with respect to weights, shape [N, K]
    """
    # Check TMA support
    # NOTE(review): currently unused — USE_TMA_LOAD/STORE below are hardcoded
    # True (see TODO); wire this in once a non-TMA fallback path is validated.
    has_tma_support = CudaUtils.verify_tma()

    # Get group count
    G = m_sizes.shape[0]

    # Ensure contiguous tensors (TMA descriptors require flat, contiguous layouts)
    x = x.contiguous()
    grad_output = grad_output.contiguous()
    m_sizes = m_sizes.contiguous()

    # Get dimensions
    M_total, K_x = x.shape
    M_grad, N = grad_output.shape

    # Check dimensions
    assert M_total == M_grad, f"x M ({M_total}) must match grad_output M ({M_grad})"

    # Verify that the sum of m_sizes matches M_total
    sum_m_sizes = m_sizes.sum().item()
    assert (
        sum_m_sizes == M_total
    ), f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})"

    # Create output tensor (grad_w) with shape [N, K]; zeros so groups with
    # m_size == 0 contribute nothing
    grad_w = torch.zeros((N, K_x), device=x.device, dtype=x.dtype)

    NUM_SMS = num_sms

    # TODO - hardcoded for now...but should set TMA flags based on hardware support
    USE_TMA_LOAD = True  # has_tma_support
    USE_TMA_STORE = True  # has_tma_support

    # Set up TMA descriptors or direct pointers
    if USE_TMA_LOAD or USE_TMA_STORE:
        desc_helper = TmaDescriptorHelper(tma_size=tma_size)

        if USE_TMA_LOAD:
            desc_helper.init_tma_descriptor("x")
            desc_helper.init_tma_descriptor("grad_output")
            x_desc = desc_helper.get_tma_descriptor_kernel_param("x")
            grad_output_desc = desc_helper.get_tma_descriptor_kernel_param(
                "grad_output"
            )
        else:
            x_desc = x
            grad_output_desc = grad_output

        if USE_TMA_STORE:
            desc_helper.init_tma_descriptor("grad_w")
            workspace = desc_helper.get_tma_descriptor_kernel_param("grad_w")
        else:
            # Dummy 1-byte buffer so the kernel signature stays uniform
            workspace = torch.empty(1, device=x.device, dtype=torch.uint8)
    else:
        # If not using TMA, just use the tensors directly
        x_desc = x
        grad_output_desc = grad_output
        workspace = torch.empty(1, device=x.device, dtype=torch.uint8)

    # M_BUCKET for grid size (power-of-two bucket keeps autotune cache keys stable)
    M_BUCKET = triton.next_power_of_2(M_total)

    # Define grid for kernel launch.
    # Descriptors must be filled inside this callable: the BLOCK_SIZE_* values
    # come from the autotuner-selected META, which is only known at launch time.
    def grid(META):
        if USE_TMA_LOAD or USE_TMA_STORE:

            if USE_TMA_LOAD:
                desc_helper.fill_2d_tma_descriptor(
                    "x",
                    x.data_ptr(),
                    M_total,
                    K_x,
                    META["BLOCK_SIZE_M"],
                    META["BLOCK_SIZE_K"],
                    x.element_size(),
                )

                desc_helper.fill_2d_tma_descriptor(
                    "grad_output",
                    grad_output.data_ptr(),
                    M_total,
                    N,
                    META["BLOCK_SIZE_M"],
                    META["BLOCK_SIZE_N"],
                    grad_output.element_size(),
                )

            if USE_TMA_STORE:
                desc_helper.fill_2d_tma_descriptor(
                    "grad_w",
                    grad_w.data_ptr(),
                    N,
                    K_x,
                    META["BLOCK_SIZE_N"],
                    META["BLOCK_SIZE_K"],
                    grad_w.element_size(),
                )

        # Return grid size - one block per SM for balanced work distribution
        return (NUM_SMS,)

    # Launch the optimized kernel
    _kernel_mg_dw_tma[grid](
        x_desc,
        grad_output_desc,
        grad_w,
        workspace,
        m_sizes,
        G,
        M_BUCKET,
        N,
        K_x,
        NUM_SMS,
        USE_TMA_LOAD,
        USE_TMA_STORE,
        TMA_SIZE=tma_size,
    )

    return grad_w
1199
+
1200
+
1201
+ # ======== End Backwards Wrapper Functions =============
1202
+
1203
+ # ======== PyTorch wrapper functions ========
1204
+
1205
+
1206
class GroupedGEMM_mg(torch.autograd.Function):
    """
    Autograd function for GroupedGEMM with M*G grouping.

    Forward computes Y = grouped(X @ W.T); backward dispatches to the
    grouped dx/dw kernels via ``grouped_gemm_backward``.
    """

    @staticmethod
    def forward(ctx, x, w, m_sizes, use_tma=True, tma_size=128):
        """
        Forward pass of GroupedGEMM.

        Args:
            x: Input tensor, shape [M_total, K]
            w: Weight tensor, shape [N, K]
            m_sizes: Tensor of shape [G] containing the size of each group
            use_tma: Whether to try using TMA acceleration (if available)
            tma_size: Size of TMA descriptor in bytes

        Returns:
            Output tensor, shape [M_total, N]
        """
        # Use regular forward without quantization
        output = grouped_gemm_forward(
            x=x, w=w, m_sizes=m_sizes, tma_size=tma_size, using_fp8=False
        )

        # Save inputs and parameters for backward pass.
        # BUGFIX: the original called save_for_backward twice; only one call
        # is needed (the second silently overwrote the first).
        ctx.save_for_backward(x, w, m_sizes)
        ctx.use_tma = use_tma
        ctx.tma_size = tma_size

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward pass of M*G GroupedGEMM.

        Args:
            grad_output: Gradient of output, shape [M_total, N]

        Returns:
            Tuple of gradients, one per forward input:
                - grad_x: Gradient with respect to x, shape [M_total, K]
                - grad_w: Gradient with respect to w, shape [N, K]
                - None: for m_sizes (not differentiable)
                - None: for use_tma (not differentiable)
                - None: for tma_size (not differentiable)
        """
        # Retrieve saved tensors and parameters
        x, w, m_sizes = ctx.saved_tensors

        use_tma = ctx.use_tma
        tma_size = ctx.tma_size

        # Compute gradients using the unified implementation
        grad_x, grad_w = grouped_gemm_backward(
            grad_output=grad_output,
            x=x,
            w=w,
            m_sizes=m_sizes,
            use_tma=use_tma,
            tma_size=tma_size,
        )

        # BUGFIX: autograd requires exactly one gradient per forward input
        # (x, w, m_sizes, use_tma, tma_size) — the original returned only
        # four values, which raises at backward time.
        return grad_x, grad_w, None, None, None
1279
+
1280
+
1281
def mg_grouped_gemm(
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    use_tma: bool = True,
    tma_size: int = 128,
    using_fp8: bool = False,
) -> torch.Tensor:
    """
    Unified differentiable grouped GEMM operation for M*G grouped GEMM.

    Args:
        x: Input tensor, shape [M_total, K]
        w: Weight tensor, shape [N, K]
        m_sizes: Tensor of shape [G] containing the size of each group
        use_tma: Whether to try using TMA acceleration (if available)
        tma_size: Size of TMA descriptor in bytes
        using_fp8: Accepted for interface compatibility; the autograd path
            always runs the non-quantized forward (using_fp8=False), so this
            flag is currently ignored.

    Returns:
        Output tensor, shape [M_total, N]
    """
    # BUGFIX: GroupedGEMM_mg.forward accepts (x, w, m_sizes, use_tma, tma_size);
    # the original also passed `using_fp8` as a sixth argument, which raised a
    # TypeError on every call.
    return GroupedGEMM_mg.apply(x, w, m_sizes, use_tma, tma_size)
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/reference_utils.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+ import logging
9
+
10
+ import numpy as np
11
+ import torch
12
+
13
+ # Configure logging
14
+ logging.basicConfig(
15
+ level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
16
+ )
17
+
18
+
19
def compute_reference_forward(x, w, m_sizes):
    """
    Pure-PyTorch reference for the grouped forward pass.

    Each group of rows of ``x`` (sizes given by ``m_sizes``) is multiplied by
    the full weight matrix: y_g = x_g @ w.T. Rows not covered by any group
    stay zero.

    Args:
        x (torch.Tensor): Input tensor of shape (M, K)
        w (torch.Tensor): Weight tensor of shape (N, K)
        m_sizes (torch.Tensor): Group sizes tensor of shape (G)

    Returns:
        torch.Tensor: Reference output tensor of shape (M, N)
    """
    out = torch.zeros((x.shape[0], w.shape[0]), dtype=x.dtype, device=x.device)

    offset = 0
    for size in (int(s) for s in m_sizes):
        if size > 0:
            rows = slice(offset, offset + size)
            # y_g = x_g @ w.T for this group's row span
            out[rows] = torch.matmul(x[rows], w.T)
            offset += size

    return out
52
+
53
+
54
def compute_reference_backward(x, w, m_sizes, grad_output):
    """
    Pure-PyTorch reference for the grouped backward pass, via autograd.

    Args:
        x (torch.Tensor): Input tensor of shape (M, K)
        w (torch.Tensor): Weight tensor of shape (N, K)
        m_sizes (torch.Tensor): Group sizes tensor of shape (G)
        grad_output (torch.Tensor): Upstream gradient of shape (M, N)

    Returns:
        tuple: (grad_x, grad_w) gradient tensors
    """
    # Autograd-enabled copies so gradients don't leak onto the caller's tensors
    x_ref = x.detach().clone().requires_grad_(True)
    w_ref = w.detach().clone().requires_grad_(True)

    # Re-run the reference forward and backprop the supplied gradient through it
    out = compute_reference_forward(x_ref, w_ref, m_sizes)
    out.backward(grad_output)

    return x_ref.grad, w_ref.grad
78
+
79
+
80
def analyze_tensor_differences(actual, expected, name):
    """
    Log diagnostics comparing ``actual`` against ``expected`` and return the verdict.

    Reports the location/magnitude of the largest absolute difference, the
    fraction of zeros in each tensor, and any NaNs in ``actual``.

    Args:
        actual (torch.Tensor): Actual tensor
        expected (torch.Tensor): Expected tensor
        name (str): Name of the tensor for logging

    Returns:
        bool: True if tensors are close enough
    """
    # Loose tolerances chosen for float16-class precision
    rtol = 0.5
    atol = 0.5

    # Locate the single worst element
    abs_err = (actual - expected).abs()
    worst_flat = abs_err.argmax().item()
    worst_idx = np.unravel_index(worst_flat, actual.shape)
    worst_err = abs_err.max().item()

    logging.info(f"Largest {name} difference: {worst_err} at {worst_idx}")
    logging.info(f"Values: {actual[worst_idx].item()} vs {expected[worst_idx].item()}")

    is_close = torch.allclose(actual, expected, rtol=rtol, atol=atol)

    if is_close:
        logging.info(f"✓ SUCCESS: {name} matches PyTorch reference")
    else:
        logging.error(f"✗ FAILURE: {name} mismatch detected")

    # Zero-fraction diagnostics (useful for spotting unwritten output regions)
    zeros_actual = (actual == 0).sum().item()
    zeros_expected = (expected == 0).sum().item()
    logging.info(
        f"Zeros in {name} (actual): {zeros_actual}/{actual.numel()} ({zeros_actual/actual.numel()*100:.2f}%)"
    )
    logging.info(
        f"Zeros in {name} (expected): {zeros_expected}/{expected.numel()} ({zeros_expected/expected.numel()*100:.2f}%)"
    )

    # NaNs in the computed result are always worth flagging
    nan_actual = torch.isnan(actual).sum().item()
    if nan_actual > 0:
        logging.error(f"NaN values detected in {name}: {nan_actual}")

    return is_close
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/tma_autotuning.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # credit - TMAHelper class, AutoTuning are derived from FBGemm:
8
+ # https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm
9
+
10
+ # pyre-unsafe
11
+ import functools
12
+
13
+ import os
14
+ import sys
15
+ from typing import Any, Dict, Optional, Tuple
16
+
17
+ import torch
18
+
19
+ import triton
20
+ import triton.language as tl
21
+ from triton import Config as TConfig
22
+
23
+ from triton.runtime import driver # @manual
24
+
25
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
26
+
27
+
28
+ # ===== Supporting utils, CUDA and TMA =====
29
+
30
+
31
class CudaUtils:
    """Static probes for CUDA backend, TMA capability, and SM count."""

    @staticmethod
    def is_cuda() -> bool:
        """Check if Triton is running on CUDA backend."""
        target = driver.active.get_current_target()
        return target.backend == "cuda"

    @staticmethod
    def verify_tma() -> bool:
        """Check if TMA is supported on the current device (Hopper+, CC >= 9)."""
        if not CudaUtils.is_cuda():
            return False
        if not torch.cuda.is_available():
            return False
        major, _ = torch.cuda.get_device_capability()
        return major >= 9

    @staticmethod
    def get_num_sms() -> int:
        """Get the number of streaming multiprocessors on the current device."""
        if not CudaUtils.is_cuda():
            raise RuntimeError("Triton is not running on CUDA backend")
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is not available")
        props = torch.cuda.get_device_properties("cuda")
        return props.multi_processor_count
54
+
55
+
56
class TmaDescriptorHelper:
    """Helper class for managing TMA descriptors in Triton kernels.

    Allocates CPU-side descriptor buffers (init_tma_descriptor), serializes
    tensor/block parameters into them via the Triton driver utils
    (fill_*_tma_descriptor), and exposes them to kernels as launch params
    (get_tma_descriptor_kernel_param).
    """

    class KernelParamWrapper:
        """Wrapper to implement the TmaDescKernelParam interface."""

        def __init__(self, desc: torch.Tensor):
            # CPU int8 buffer holding the serialized TMA descriptor
            self.desc = desc

        def tma_desc_cpu_ptr(self) -> int:
            """Return the CPU pointer to the TMA descriptor."""
            return self.desc.data_ptr()

    def __init__(self, tma_size: int = 128):
        """Initialize the TMA descriptor helper.

        Args:
            tma_size: Size of the TMA descriptor in bytes

        Raises:
            RuntimeError: if the device lacks TMA support (pre-Hopper) or the
                installed Triton has no grid-constant descriptor support.
        """
        if not CudaUtils.verify_tma():
            raise RuntimeError(
                "TMA not supported on this device (requires Hopper or newer)"
            )
        if "nv_tma_desc_type" not in dir(tl):
            raise RuntimeError(
                "TMA grid constant descriptors not supported in your Triton version"
            )

        self.tma_size = tma_size
        # Bind the driver's fill helpers once; they write descriptor params
        # into the CPU buffers allocated by init_tma_descriptor.
        self.fill_1d_tma_descriptor_inner = driver.active.utils.fill_1d_tma_descriptor
        self.fill_2d_tma_descriptor_inner = driver.active.utils.fill_2d_tma_descriptor
        # name -> CPU descriptor buffer
        self.descriptors: Dict[str, torch.Tensor] = {}

    def init_tma_descriptor(self, name: str) -> None:
        """Initialize a TMA descriptor with the given name.

        Call this method outside of the lambda function for grid size.
        """
        self.descriptors[name] = torch.empty(
            self.tma_size, device="cpu", dtype=torch.int8
        )

    def fill_1d_tma_descriptor(
        self, name: str, ptr: int, dim: int, block_dim: int, element_size: int
    ) -> None:
        """Fill a 1D TMA descriptor.

        Call this method inside the lambda function for grid size.

        Args:
            name: Descriptor name previously passed to init_tma_descriptor.
            ptr: Device pointer of the tensor being described.
            dim: Tensor length in elements.
            block_dim: Block length in elements.
            element_size: Bytes per element.
        """
        if name not in self.descriptors:
            raise ValueError(f"TMA descriptor '{name}' not initialized")

        desc_x = self.descriptors[name]
        # Hardware requires the descriptor buffer to be 64-byte aligned
        if desc_x.data_ptr() % 64 != 0:
            raise ValueError("TMA descriptor must be 64-byte aligned")
        self.fill_1d_tma_descriptor_inner(
            ptr, dim, block_dim, element_size, desc_x.data_ptr()
        )

    def fill_2d_tma_descriptor(
        self,
        name: str,
        ptr: int,
        dim1: int,
        dim0: int,
        block_dim1: int,
        block_dim0: int,
        element_size: int,
    ) -> None:
        """Fill a 2D TMA descriptor.

        Call this method inside the lambda function for grid size.

        Args:
            name: Descriptor name previously passed to init_tma_descriptor.
            ptr: Device pointer of the tensor being described.
            dim1: Outer (row) extent in elements.
            dim0: Inner (contiguous) extent in elements.
            block_dim1: Block extent along dim1.
            block_dim0: Block extent along dim0.
            element_size: Bytes per element.
        """
        if name not in self.descriptors:
            raise ValueError(f"TMA descriptor '{name}' not initialized")

        desc_x = self.descriptors[name]
        # Hardware requires the descriptor buffer to be 64-byte aligned
        if desc_x.data_ptr() % 64 != 0:
            raise ValueError("TMA descriptor must be 64-byte aligned")
        self.fill_2d_tma_descriptor_inner(
            ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr()
        )

    def get_tma_descriptor_kernel_param(self, name: str) -> KernelParamWrapper:
        """Get the TMA descriptor kernel parameter for the given name."""
        if name not in self.descriptors or self.descriptors[name] is None:
            raise ValueError(f"TMA descriptor '{name}' not initialized")
        return self.KernelParamWrapper(self.descriptors[name])
144
+
145
+
146
# ====== Autotuning utilities ======

# BLOCK_SIZE_M is pinned to this value in all candidate configs below
ALIGN_SIZE_M = 128

# Cartesian product of candidate tile shapes / pipeline depths for NVIDIA
# autotuning; pruned at launch time by early_config_prune.
_NV_CONFIGS = [
    triton.Config(
        {
            "BLOCK_SIZE_M": block_size_m,
            "BLOCK_SIZE_N": block_size_n,
            "BLOCK_SIZE_K": block_size_k,
        },
        num_stages=num_stages,
        num_warps=num_warps,
        num_ctas=num_ctas,
    )
    for block_size_m in [ALIGN_SIZE_M, ]
    for block_size_n in [64, 128, 256]
    for block_size_k in [64, 128, 256]
    for num_stages in [3, 4]
    for num_warps in [4, 8]
    for num_ctas in [1]
]
167
+
168
+
169
def early_config_prune(configs, named_args, dtsize=None, dtype=None, **kwargs):
    """Prune autotune configs that cannot run (or would clearly waste work)
    on the current device, before any benchmarking happens.

    Args:
        configs: candidate triton.Config objects (BLOCK_SIZE_M/N/K, stages, warps).
        named_args: kernel launch arguments; must contain one of the known
            output-pointer names plus G, M_BUCKET, N, K.
        dtsize: element size in bytes; inferred from the output tensor if None.
        dtype: element dtype; inferred from the output tensor if None.

    Returns:
        The subset of ``configs`` that passes all heuristics.
    """
    device = torch.cuda.current_device()
    # Check for all possible pointer parameter names — this pruner is shared
    # by the forward (c_ptr), dx (grad_input_ptr) and dw (grad_weight_ptr) kernels.
    if "grad_input_ptr" in named_args:
        ptr_name = "grad_input_ptr"
    elif "c_ptr" in named_args:
        ptr_name = "c_ptr"
    elif "grad_weight_ptr" in named_args:
        ptr_name = "grad_weight_ptr"
    else:
        raise KeyError("No recognized pointer parameter found in kernel arguments")

    # Infer element size / dtype from the output tensor when not supplied
    if dtsize is None:
        dtsize = named_args[ptr_name].element_size()
    if dtype is None:
        dtype = named_args[ptr_name].dtype

    pruned_configs = []
    for config in configs:
        kw = config.kwargs
        BLOCK_M, BLOCK_N, BLOCK_K, num_stages = (
            kw["BLOCK_SIZE_M"],
            kw["BLOCK_SIZE_N"],
            kw["BLOCK_SIZE_K"],
            config.num_stages,
        )
        G, M, N, K = (
            named_args["G"],
            named_args["M_BUCKET"],
            named_args["N"],
            named_args["K"],
        )

        # 1. make sure we have enough smem
        max_shared_memory = driver.active.utils.get_device_properties(device)[
            "max_shared_mem"
        ]

        # Estimated pipeline staging footprint: one A tile + one B tile per stage
        required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
        if required_shared_memory > max_shared_memory:
            continue

        M_PER_GROUP = M // G
        MIN_M_TILES = 64
        # 2. make sure we don't load M tiles that are too big
        # (relative to the per-group M; oversized tiles are mostly masked out)
        if BLOCK_M > MIN_M_TILES and BLOCK_M > (M_PER_GROUP * 2):
            continue
        # 3. make sure we don't load M tiles that are too small
        # NOTE(review): the upstream comment said "N tiles" here, but the
        # check is clearly on BLOCK_M; behavior kept as-is.
        if BLOCK_M < 128 and BLOCK_M < (M_PER_GROUP // 2):
            continue

        num_sm = driver.active.utils.get_device_properties(device)[
            "multiprocessor_count"
        ]
        N_TILES = N // BLOCK_N
        MIN_N_TILES = 64
        # 4. make sure we don't load N tiles that are too big
        # NOTE(review): `M * N_TILES` mixes a row count with a tile count —
        # possibly intended M_TILES * N_TILES; confirm against upstream FBGEMM.
        if BLOCK_N > MIN_N_TILES and M * N_TILES < num_sm:
            continue
        # 5. make sure we don't load N tiles that are too small
        if BLOCK_N < 128 and M * N_TILES > 2 * num_sm:
            continue
        # 6. make sure K can be evenly divided
        # (the kernels assume no remainder tile along K)
        if K % BLOCK_K != 0:
            continue

        pruned_configs.append(config)

    return pruned_configs
238
+
239
+
240
+ # ======== End Autotuning utilities ========
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/unit_test_forwards.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+ import logging
9
+ import unittest
10
+ from typing import Tuple
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+
15
+ from mg_grouped_gemm import grouped_gemm_forward
16
+
17
+
18
class TestMG_GroupedGEMM(unittest.TestCase):
    """Forward-pass tests for the M*G grouped GEMM kernel (requires CUDA)."""

    def setUp(self) -> None:
        # Fixed seed so each test sees identical random inputs
        torch.manual_seed(2020)

    def _run_grouped_gemm_test(
        self,
        shape: Tuple[int, int, int, int],
        device: torch.device,
        dtype: torch.dtype = torch.bfloat16,
        atol: float = 1e-5,
        rtol: float = 1.6e-2,
    ) -> None:
        """Run one forward pass and compare against a per-group matmul reference."""
        G, M, N, K = shape
        # M*G grouping: activations are [M*G, K], weights are [N*G, K]
        a = torch.randn(M * G, K, dtype=dtype, device=device)
        b = torch.randn(N * G, K, dtype=dtype, device=device)

        # Equal-sized groups for simplicity
        m_sizes = torch.full((G,), M, device=device, dtype=torch.int32)

        result = grouped_gemm_forward(a, b, m_sizes)
        self.assertTrue(result.shape == (M * G, N))

        # Reference: each row block multiplied by its own weight slice
        expected_result = torch.zeros(M * G, N, dtype=dtype, device=device)
        row_start = 0
        for g in range(G):
            row_end = row_start + m_sizes[g]
            w_g = b[N * g : N * (g + 1), :]
            expected_result[row_start:row_end, :] = a[row_start:row_end, :] @ w_g.T
            row_start = row_end

        torch.testing.assert_close(
            result.to(dtype), expected_result, atol=atol, rtol=rtol
        )

    def test_MG_grouped_gemm_bf16(self) -> None:
        for G in (1, 4, 16):
            for M in (128, 512, 1024):
                print(f"Testing BF16 M*G GroupGeMM with G={G}, M={M}")
                self._run_grouped_gemm_test(
                    (G, M, 1024, 1024),
                    torch.device("cuda"),
                    dtype=torch.bfloat16,
                    atol=1e-5,
                    rtol=1.6e-2,
                )

    def test_MG_grouped_gemm_deepseek_shapes(self) -> None:
        """Test with shapes from Deepseek model."""
        deepseek_shapes = [
            (4, 2048, 4096, 7168),  # G, M, N, K
            (4, 2048, 7168, 2048),
            (8, 512, 4096, 7168),
            (8, 512, 7168, 2048),
        ]

        device = torch.device("cuda")

        for shape in deepseek_shapes:
            G, M, N, K = shape
            print(f"Testing BF16 M*G Deepseek shape: G={G}, M={M}, N={N}, K={K}")
            self._run_grouped_gemm_test(
                shape, device, dtype=torch.bfloat16, atol=1e-5, rtol=1.6e-2
            )