Erland commited on
Commit
1f3d7c4
·
verified ·
1 Parent(s): e64f7a1

Add files using upload-large-folder tool

Browse files
Files changed (50) hide show
  1. LICENSE +21 -0
  2. torchtitan/components/__pycache__/checkpoint.cpython-311.pyc +0 -0
  3. torchtitan/components/__pycache__/float8.cpython-311.pyc +0 -0
  4. torchtitan/components/__pycache__/metrics.cpython-311.pyc +0 -0
  5. torchtitan/distributed/__pycache__/utils.cpython-311.pyc +0 -0
  6. torchtitan/distributed/pipeline.py +201 -0
  7. torchtitan/experiments/deepseek_v3/LICENSE-CODE +21 -0
  8. torchtitan/experiments/deepseek_v3/README.md +40 -0
  9. torchtitan/experiments/deepseek_v3/attn_mask_utils.py +397 -0
  10. torchtitan/experiments/deepseek_v3/download.py +70 -0
  11. torchtitan/experiments/deepseek_v3/generate.py +308 -0
  12. torchtitan/experiments/deepseek_v3/indices.py +195 -0
  13. torchtitan/experiments/deepseek_v3/inference.sh +15 -0
  14. torchtitan/experiments/deepseek_v3/model_config.py +204 -0
  15. torchtitan/experiments/deepseek_v3/symm_mem_recipes/__init__.py +11 -0
  16. torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_barrier.py +159 -0
  17. torchtitan/experiments/deepseek_v3/train.py +142 -0
  18. torchtitan/experiments/flux/dataset/flux_dataset.py +267 -0
  19. torchtitan/experiments/flux/model/layers.py +286 -0
  20. torchtitan/experiments/flux/model/math.py +38 -0
  21. torchtitan/experiments/flux/model/model.py +177 -0
  22. torchtitan/experiments/flux/parallelize_flux.py +26 -0
  23. torchtitan/experiments/flux/requirements.txt +2 -0
  24. torchtitan/experiments/flux/scripts/download_autoencoder.py +61 -0
  25. torchtitan/experiments/flux/tests/test_flux_dataloader.py +103 -0
  26. torchtitan/experiments/flux/train.py +224 -0
  27. torchtitan/experiments/flux/train_configs/debug_model.toml +68 -0
  28. torchtitan/experiments/flux/utils.py +203 -0
  29. torchtitan/experiments/kernels/triton_mg_group_gemm/simpleMoE.py +885 -0
  30. torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/mg_grouped_gemm.py +1304 -0
  31. torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/tma_autotuning.py +240 -0
  32. torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/unit_test_forwards.py +82 -0
  33. torchtitan/experiments/llama4/__init__.py +70 -0
  34. torchtitan/experiments/llama4/infra/expert_parallel.py +145 -0
  35. torchtitan/experiments/llama4/infra/parallelize_llama.py +159 -0
  36. torchtitan/experiments/llama4/model/__pycache__/args.cpython-311.pyc +0 -0
  37. torchtitan/experiments/llama4/model/args.py +109 -0
  38. torchtitan/experiments/llama4/model/moe.py +228 -0
  39. torchtitan/experiments/llama4/scripts/REAME.md +17 -0
  40. torchtitan/experiments/llama4/scripts/convert_meta_to_dcp_with_gpus.sh +25 -0
  41. torchtitan/experiments/llama4/train_configs/debug_model.toml +74 -0
  42. torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml +65 -0
  43. torchtitan/experiments/multimodal/requirements.txt +1 -0
  44. torchtitan/experiments/multimodal/tests/__init__.py +5 -0
  45. torchtitan/experiments/multimodal/tests/test_utils.py +58 -0
  46. torchtitan/experiments/multimodal/tokenizer/tiktoken.py +232 -0
  47. torchtitan/experiments/simple_fsdp/__pycache__/__init__.cpython-311.pyc +0 -0
  48. torchtitan/experiments/simple_fsdp/tests/__init__.py +5 -0
  49. torchtitan/models/llama3/__pycache__/pipeline_llama.cpython-311.pyc +0 -0
  50. torchtitan/models/llama3/train_configs/debug_model.toml +74 -0
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023-2025 Songlin Yang, Yu Zhang
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
torchtitan/components/__pycache__/checkpoint.cpython-311.pyc ADDED
Binary file (35.4 kB). View file
 
torchtitan/components/__pycache__/float8.cpython-311.pyc ADDED
Binary file (6.54 kB). View file
 
torchtitan/components/__pycache__/metrics.cpython-311.pyc ADDED
Binary file (20.2 kB). View file
 
torchtitan/distributed/__pycache__/utils.cpython-311.pyc ADDED
Binary file (15.9 kB). View file
 
torchtitan/distributed/pipeline.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ import math
7
+ import os
8
+ from typing import Callable, Optional
9
+
10
+ from torch.distributed.pipelining.schedules import (
11
+ _PipelineSchedule,
12
+ _PipelineScheduleRuntime,
13
+ get_schedule_class,
14
+ PipelineScheduleMulti,
15
+ PipelineScheduleSingle,
16
+ )
17
+ from torch.distributed.pipelining.stage import PipelineStage
18
+
19
+ from torchtitan.config_manager import JobConfig
20
+ from torchtitan.tools.logging import logger
21
+
22
+
23
+ __all__ = ["build_pipeline_schedule", "generate_split_points", "stage_ids_this_rank"]
24
+
25
+
26
+ # TODO: It's unclear if this API is general enough to be used by other models.
27
+ # If not, we should move it to a Transformer-specific directory.
28
+ def generate_split_points(
29
+ schedule_str: str,
30
+ layers_per_stage: Optional[int],
31
+ pp_dim: int,
32
+ num_layers: int,
33
+ input_weight: int = 1,
34
+ output_weight: int = 1,
35
+ ) -> list[str]:
36
+ """
37
+ Generate a list of split points based on the number of layers and
38
+ pipeline parallel dimension, ensuring the first and last stages have the least layers.
39
+
40
+ Args:
41
+ schedule_str (str): The string of the schedule name.
42
+ layers_per_stage (int): The number of layers per stage.
43
+ pp_dim (int): The pipeline parallel dimension.
44
+ num_layers (int): The number of layers in the model.
45
+ input_output_weight (int): The number of layers to consider the input/output modules in the layer calculation.
46
+
47
+ Returns:
48
+ list[str]: A list of split point FQNs.
49
+ """
50
+
51
+ schedule_class = get_schedule_class(schedule_str)
52
+ is_single_stage_schedule = issubclass(schedule_class, PipelineScheduleSingle)
53
+ num_stages_per_rank = 1 if is_single_stage_schedule else 2
54
+
55
+ if layers_per_stage is not None:
56
+ total_stages = math.ceil(num_layers / layers_per_stage)
57
+ if total_stages % pp_dim != 0:
58
+ raise ValueError(
59
+ f"Number of stages ({total_stages}) must be divisible by the pipeline parallel dimension ({pp_dim})."
60
+ f"Each rank should have the same number of stages. "
61
+ )
62
+ num_stages_per_rank = total_stages // pp_dim
63
+
64
+ if is_single_stage_schedule and num_stages_per_rank != 1:
65
+ raise ValueError(
66
+ f"Number of stages per rank ({num_stages_per_rank}) must be 1 for single stage schedules."
67
+ )
68
+ elif not is_single_stage_schedule and num_stages_per_rank < 2:
69
+ raise ValueError(
70
+ f"Number of stages per rank ({num_stages_per_rank}) must be >= 2 for multi stage schedules."
71
+ )
72
+ else:
73
+ total_stages = pp_dim * num_stages_per_rank
74
+ if total_stages > num_layers:
75
+ raise ValueError("Total stages cannot be greater than the number of layers")
76
+
77
+ # Calculate effective number of layers including input and output weights
78
+ effective_num_layers = num_layers + input_weight + output_weight
79
+ base_layers_per_stage = effective_num_layers // total_stages
80
+
81
+ splits = [""] * (total_stages - 1)
82
+ current_layer_index = 0
83
+
84
+ # First stage
85
+ layers_on_first_stage = max(0, base_layers_per_stage - input_weight)
86
+ current_layer_index += layers_on_first_stage
87
+ splits[0] = "layers." + str(current_layer_index)
88
+
89
+ # Last stage
90
+ layers_on_last_stage = max(0, base_layers_per_stage - output_weight)
91
+ splits[-1] = "layers." + str(num_layers - layers_on_last_stage)
92
+
93
+ # Middle stages
94
+ remaining_layers = num_layers - layers_on_first_stage - layers_on_last_stage - 1
95
+ middle_stages = len(splits) - 2
96
+ layers_per_middle_stage = remaining_layers // middle_stages
97
+ # split remainder evenly across middle stages
98
+ remainder = remaining_layers % middle_stages
99
+
100
+ for i in range(1, middle_stages + 1):
101
+ current_layer_index += layers_per_middle_stage
102
+ if remainder > 0:
103
+ current_layer_index += 1
104
+ remainder -= 1
105
+ splits[i] = "layers." + str(current_layer_index)
106
+
107
+ logger.info(
108
+ f"No 'pipeline_parallel_split_points' provided so the generated splits are: {splits} "
109
+ "This may be sub-optimal as the number of layers per stage may be unbalanced."
110
+ )
111
+ return splits
112
+
113
+
114
+ def build_pipeline_schedule(
115
+ job_config: JobConfig, stages: list[PipelineStage], loss_fn: Callable
116
+ ) -> _PipelineSchedule:
117
+ """Builds a pipeline schedule for the given job configuration and stages.
118
+
119
+ Args:
120
+ job_config (JobConfig): The job configuration.
121
+ stages (list[PipelineStage]): The stages to be scheduled.
122
+ loss_fn (Callable): The loss function.
123
+
124
+ Returns:
125
+ _PipelineSchedule: The pipeline schedule for the given stages.
126
+ """
127
+ pp_schedule_csv = job_config.parallelism.pipeline_parallel_schedule_csv
128
+
129
+ # Validate that pp_schedule_csv is a valid path
130
+ if pp_schedule_csv:
131
+ if not os.path.isfile(pp_schedule_csv):
132
+ raise FileNotFoundError(
133
+ f"The specified path {pp_schedule_csv} does not exist or is not a file."
134
+ )
135
+ schedule_class = _PipelineScheduleRuntime
136
+ else:
137
+ schedule_class = get_schedule_class(
138
+ job_config.parallelism.pipeline_parallel_schedule
139
+ )
140
+
141
+ looped_schedule = issubclass(schedule_class, PipelineScheduleMulti)
142
+ microbatch_size = job_config.parallelism.pipeline_parallel_microbatch_size
143
+ batch_size = job_config.training.batch_size
144
+ # validate that the batch size is divisible by the microbatch_size otherwise we'll hang or error during training
145
+ if batch_size % microbatch_size != 0:
146
+ raise ValueError(
147
+ f"Batch size {job_config.training.batch_size} must be divisible by number of microbatches {n_microbatches}. "
148
+ "Update the config arguments for either batch_size or pipeline_parallel_microbatch_size."
149
+ )
150
+ n_microbatches = batch_size // microbatch_size
151
+ # We expect that the number of local stages (`len(stages)`) is the same across all ranks
152
+ num_total_stages = job_config.parallelism.pipeline_parallel_degree * len(stages)
153
+ if n_microbatches < num_total_stages:
154
+ logger.warning(
155
+ f"Number of microbatches ({n_microbatches}) is less than the total number "
156
+ f"of stages ({num_total_stages}) which may result in a bubble in the pipeline."
157
+ )
158
+
159
+ schedule = schedule_class(
160
+ stages if looped_schedule else stages[0],
161
+ n_microbatches=n_microbatches,
162
+ loss_fn=loss_fn,
163
+ )
164
+ logger.info(
165
+ f"Using pipeline schedule {job_config.parallelism.pipeline_parallel_schedule} "
166
+ f"with {n_microbatches} microbatches and {num_total_stages} stages."
167
+ )
168
+
169
+ if pp_schedule_csv:
170
+ assert schedule_class in [
171
+ PipelineScheduleSingle,
172
+ PipelineScheduleMulti,
173
+ _PipelineScheduleRuntime,
174
+ ], (
175
+ "Only PipelineScheduleSingle (single stage), PipelineScheduleMulti (multistage), "
176
+ "and _PipelineScheduleRuntime support csv schedules"
177
+ )
178
+ schedule._load_csv(pp_schedule_csv)
179
+
180
+ return schedule
181
+
182
+
183
# TODO(whc) should this be a utility inside torch.pipelining?
def stage_ids_this_rank(
    pp_rank: int, pp_size: int, num_stages: int, style: str = "loop"
) -> tuple[int]:
    """Compute the stage ids for the stages that will run on this pp rank for either a looped or V style schedule"""
    assert (
        num_stages % pp_size == 0
    ), f"num_stages {num_stages} must be evenly divisible by pp_size {pp_size}"
    per_rank = num_stages // pp_size

    if style == "loop":
        # Looped placement: rank r owns stages r, r + pp_size, r + 2*pp_size, ...
        owned = []
        for round_idx in range(per_rank):
            owned.append(pp_rank + round_idx * pp_size)
        return tuple(owned)
    elif style == "v":
        assert (
            per_rank == 2
        ), f"v schedules assume 2 stages per rank, got {per_rank}"
        # V placement: pair the ascending first-half stage ids with the
        # descending second-half ids, then pick this rank's pair.
        ascending = range(pp_size)
        descending = range(num_stages - 1, pp_size - 1, -1)
        pairs = list(zip(ascending, descending))
        return pairs[pp_rank]
torchtitan/experiments/deepseek_v3/LICENSE-CODE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 DeepSeek
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
torchtitan/experiments/deepseek_v3/README.md ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Running DeepSeek in Titan (experimental)
2
+
3
+ This folder contains a DeepSeek model supporting v2 and v3 as well as kernels
4
+ and scripts needed to run it.
5
+
6
+ ## Inference
7
+
8
+ ### Prerequisites:
9
+
10
+ You will need to download a DeepSeek model's weights if you want to run a
11
+ pre-trained checkpoint. We provided a script to download the weights from
12
+ HuggingFace Model Hub:
13
+ ```bash
14
+ python download.py [vX]
15
+ ```
16
+ where `vX` can be v2 or v3, both are supported. You may be required to create a
17
+ HuggingFace account and log in first.
18
+
19
+ ### Running inference:
20
+
21
+ The inference script is in `generate.py`. You can run it with the following
22
+ command:
23
+ ```bash
24
+ torchrun --standalone --nproc-per-node 4 generate.py
25
+ ```
26
+ This will run inference on the `DeepSeek-V2-Lite-Chat` model using 4 GPUs by
27
+ default.
28
+
29
+ Alternatively, you can run inference by using `bash inference.sh`, optionally
30
+ followed by your prompt.
31
+
32
+ ## Training
33
+
34
+ The training script is in `train.py`. You can run it with the following command:
35
+ ```bash
36
+ torchrun --standalone --nproc-per-node 8 train.py
37
+ ```
38
+
39
+ This will run training on the `DeepSeek-V2-Lite-Chat` model using 8 GPUs by
40
+ default, with pipeline parallel, expert parallel, and data parallel enabled.
torchtitan/experiments/deepseek_v3/attn_mask_utils.py ADDED
@@ -0,0 +1,397 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # This code is based on src/transformers/modeling_attn_mask_utils.py of
8
+ # huggingface/transformers. It has been modified from its original forms to
9
+ # contain only the necessary utilities.
10
+
11
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
12
+ #
13
+ # Licensed under the Apache License, Version 2.0 (the "License");
14
+ # you may not use this file except in compliance with the License.
15
+ # You may obtain a copy of the License at
16
+ #
17
+ # http://www.apache.org/licenses/LICENSE-2.0
18
+ #
19
+ # Unless required by applicable law or agreed to in writing, software
20
+ # distributed under the License is distributed on an "AS IS" BASIS,
21
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22
+ # See the License for the specific language governing permissions and
23
+ # limitations under the License.
24
+ from dataclasses import dataclass
25
+ from typing import List, Optional, Tuple, Union
26
+
27
+ import torch
28
+
29
+
30
+ @dataclass
31
+ class AttentionMaskConverter:
32
+ """
33
+ A utility attention mask class that allows one to:
34
+ - Create a causal 4d mask
35
+ - Create a causal 4d mask with slided window
36
+ - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length,
37
+ key_value_length) that can be multiplied with attention scores
38
+
39
+ Examples:
40
+
41
+ ```python
42
+ >>> import torch
43
+ >>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter
44
+
45
+ >>> converter = AttentionMaskConverter(True)
46
+ >>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32)
47
+ tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
48
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
49
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
50
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, -3.4028e+38],
51
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, 0.0000e+00]]]])
52
+ ```
53
+
54
+ Parameters:
55
+ is_causal (`bool`):
56
+ Whether the attention mask should be a uni-directional (causal) or bi-directional mask.
57
+
58
+ sliding_window (`int`, *optional*):
59
+ Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer.
60
+ """
61
+
62
+ is_causal: bool
63
+ sliding_window: int
64
+
65
+ def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):
66
+ self.is_causal = is_causal
67
+ self.sliding_window = sliding_window
68
+
69
+ if self.sliding_window is not None and self.sliding_window <= 0:
70
+ raise ValueError(
71
+ "Make sure that when passing `sliding_window` that its value is a strictly positive integer, "
72
+ f"not `{self.sliding_window}`"
73
+ )
74
+
75
+ def to_causal_4d(
76
+ self,
77
+ batch_size: int,
78
+ query_length: int,
79
+ key_value_length: int,
80
+ dtype: torch.dtype,
81
+ device: Union[torch.device, "str"] = "cpu",
82
+ ) -> Optional[torch.Tensor]:
83
+ """
84
+ Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
85
+ bias to upper right hand triangular matrix (causal mask).
86
+ """
87
+ if not self.is_causal:
88
+ raise ValueError(
89
+ f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True."
90
+ )
91
+
92
+ # If shape is not cached, create a new causal mask and cache it
93
+ input_shape = (batch_size, query_length)
94
+ past_key_values_length = key_value_length - query_length
95
+
96
+ # create causal mask
97
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
98
+ causal_4d_mask = None
99
+ if input_shape[-1] > 1 or self.sliding_window is not None:
100
+ causal_4d_mask = self._make_causal_mask(
101
+ input_shape,
102
+ dtype,
103
+ device=device,
104
+ past_key_values_length=past_key_values_length,
105
+ sliding_window=self.sliding_window,
106
+ )
107
+
108
+ return causal_4d_mask
109
+
110
+ def to_4d(
111
+ self,
112
+ attention_mask_2d: torch.Tensor,
113
+ query_length: int,
114
+ dtype: torch.dtype,
115
+ key_value_length: Optional[int] = None,
116
+ ) -> torch.Tensor:
117
+ """
118
+ Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
119
+ key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
120
+ causal, a causal mask will be added.
121
+ """
122
+ input_shape = (attention_mask_2d.shape[0], query_length)
123
+
124
+ # create causal mask
125
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
126
+ causal_4d_mask = None
127
+ if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
128
+ if key_value_length is None:
129
+ raise ValueError(
130
+ "This attention mask converter is causal. Make sure to pass "
131
+ "`key_value_length` to correctly create a causal mask."
132
+ )
133
+
134
+ past_key_values_length = key_value_length - query_length
135
+ causal_4d_mask = self._make_causal_mask(
136
+ input_shape,
137
+ dtype,
138
+ device=attention_mask_2d.device,
139
+ past_key_values_length=past_key_values_length,
140
+ sliding_window=self.sliding_window,
141
+ )
142
+ elif self.sliding_window is not None:
143
+ raise NotImplementedError(
144
+ "Sliding window is currently only implemented for causal masking"
145
+ )
146
+
147
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
148
+ expanded_attn_mask = self._expand_mask(
149
+ attention_mask_2d, dtype, tgt_len=input_shape[-1]
150
+ ).to(attention_mask_2d.device)
151
+
152
+ if causal_4d_mask is not None:
153
+ expanded_attn_mask = causal_4d_mask.masked_fill(
154
+ expanded_attn_mask.bool(), torch.finfo(dtype).min
155
+ )
156
+
157
+ # expanded_attn_mask + causal_4d_mask can cause some overflow
158
+ expanded_4d_mask = expanded_attn_mask
159
+
160
+ return expanded_4d_mask
161
+
162
+ @staticmethod
163
+ def _make_causal_mask(
164
+ input_ids_shape: torch.Size,
165
+ dtype: torch.dtype,
166
+ device: torch.device,
167
+ past_key_values_length: int = 0,
168
+ sliding_window: Optional[int] = None,
169
+ ):
170
+ """
171
+ Make causal mask used for bi-directional self-attention.
172
+ """
173
+ bsz, tgt_len = input_ids_shape
174
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
175
+ mask_cond = torch.arange(mask.size(-1), device=device)
176
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
177
+
178
+ mask = mask.to(dtype)
179
+
180
+ if past_key_values_length > 0:
181
+ mask = torch.cat(
182
+ [
183
+ torch.zeros(
184
+ tgt_len, past_key_values_length, dtype=dtype, device=device
185
+ ),
186
+ mask,
187
+ ],
188
+ dim=-1,
189
+ )
190
+
191
+ # add lower triangular sliding window mask if necessary
192
+ if sliding_window is not None:
193
+ diagonal = past_key_values_length - sliding_window - 1
194
+
195
+ context_mask = torch.tril(
196
+ torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal
197
+ )
198
+ mask.masked_fill_(context_mask, torch.finfo(dtype).min)
199
+
200
+ return mask[None, None, :, :].expand(
201
+ bsz, 1, tgt_len, tgt_len + past_key_values_length
202
+ )
203
+
204
+ @staticmethod
205
+ def _expand_mask(
206
+ mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
207
+ ):
208
+ """
209
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
210
+ """
211
+ bsz, src_len = mask.size()
212
+ tgt_len = tgt_len if tgt_len is not None else src_len
213
+
214
+ expanded_mask = (
215
+ mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
216
+ )
217
+
218
+ inverted_mask = 1.0 - expanded_mask
219
+
220
+ return inverted_mask.masked_fill(
221
+ inverted_mask.to(torch.bool), torch.finfo(dtype).min
222
+ )
223
+
224
+ @staticmethod
225
+ def _unmask_unattended(
226
+ expanded_mask: torch.FloatTensor,
227
+ min_dtype: float,
228
+ ):
229
+ # fmt: off
230
+ """
231
+ Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when
232
+ using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
233
+ Details: https://github.com/pytorch/pytorch/issues/110213
234
+
235
+ `expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len].
236
+ `attention_mask` is [bsz, src_seq_len].
237
+
238
+ The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case
239
+ of alibi attention bias.
240
+
241
+ For example, if `expanded_mask` is (e.g. here left-padding case)
242
+ ```
243
+ [[[[0, 0, 0],
244
+ [0, 0, 0],
245
+ [0, 0, 1]]],
246
+ [[[1, 0, 0],
247
+ [1, 1, 0],
248
+ [1, 1, 1]]],
249
+ [[[0, 0, 0],
250
+ [0, 1, 0],
251
+ [0, 1, 1]]]]
252
+ ```
253
+ then the modified `expanded_mask` will be
254
+ ```
255
+ [[[[1, 1, 1], <-- modified
256
+ [1, 1, 1], <-- modified
257
+ [0, 0, 1]]],
258
+ [[[1, 0, 0],
259
+ [1, 1, 0],
260
+ [1, 1, 1]]],
261
+ [[[1, 1, 1], <-- modified
262
+ [0, 1, 0],
263
+ [0, 1, 1]]]]
264
+ ```
265
+ """
266
+ # fmt: on
267
+ if expanded_mask.dtype == torch.bool:
268
+ raise ValueError(
269
+ "AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor."
270
+ )
271
+
272
+ return expanded_mask.mul(
273
+ ~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True)
274
+ )
275
+
276
+ @staticmethod
277
+ def _ignore_causal_mask_sdpa(
278
+ attention_mask: Optional[torch.Tensor],
279
+ inputs_embeds: torch.Tensor,
280
+ past_key_values_length: int,
281
+ sliding_window: Optional[int] = None,
282
+ is_training: bool = False,
283
+ ) -> bool:
284
+ """
285
+ Detects whether the optional user-specified attention_mask & the automatically created causal mask can be
286
+ ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument.
287
+
288
+ In case no token is masked in the `attention_mask` argument, if `query_length == 1` or
289
+ `key_value_length == query_length`, we rather rely on SDPA `is_causal` argument to use causal/non-causal masks,
290
+ allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is
291
+ passed).
292
+ """
293
+
294
+ _, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1]
295
+ key_value_length = query_length + past_key_values_length
296
+
297
+ is_tracing = (
298
+ torch.jit.is_tracing()
299
+ or isinstance(inputs_embeds, torch.fx.Proxy)
300
+ or is_torchdynamo_compiling()
301
+ )
302
+
303
+ ignore_causal_mask = False
304
+
305
+ if attention_mask is None:
306
+ # TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input
307
+ # shape, thus SDPA's `is_causal` argument is rightfully updated
308
+ # (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363). However, when using
309
+ # `torch.export` or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is
310
+ # hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code `is_causal=True`
311
+ # which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108).
312
+ # Thus, we only set `ignore_causal_mask = True` if the model is set to training.
313
+ #
314
+ # Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal`
315
+ # ("TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor").
316
+ if (
317
+ (is_training or not is_tracing)
318
+ and (query_length == 1 or key_value_length == query_length)
319
+ and (sliding_window is None or key_value_length < sliding_window)
320
+ ):
321
+ ignore_causal_mask = True
322
+ elif sliding_window is None or key_value_length < sliding_window:
323
+ if len(attention_mask.shape) == 4:
324
+ return False
325
+ elif not is_tracing and torch.all(attention_mask == 1):
326
+ if query_length == 1 or key_value_length == query_length:
327
+ # For query_length == 1, causal attention and bi-directional attention are the same.
328
+ ignore_causal_mask = True
329
+
330
+ # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore
331
+ # the attention mask, as SDPA causal mask generation may be wrong. We will set `is_causal=False` in
332
+ # SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
333
+ # Reference: https://github.com/pytorch/pytorch/issues/108108
334
+ # TODO: maybe revisit this with https://github.com/pytorch/pytorch/pull/114823 in PyTorch 2.3.
335
+
336
+ return ignore_causal_mask
337
+
338
+
339
def _prepare_4d_causal_attention_mask(
    attention_mask: Optional[torch.Tensor],
    input_shape: Union[torch.Size, Tuple, List],
    inputs_embeds: torch.Tensor,
    past_key_values_length: int,
    sliding_window: Optional[int] = None,
):
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`

    Args:
        attention_mask (`torch.Tensor` or `None`):
            A 2D attention mask of shape `(batch_size, key_value_length)`
        input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
            The input shape should be a tuple that defines `(batch_size, query_length)`.
        inputs_embeds (`torch.Tensor`):
            The embedded inputs as a torch Tensor.
        past_key_values_length (`int`):
            The length of the key value cache.
        sliding_window (`int`, *optional*):
            If the model uses windowed attention, a sliding window should be passed.
    """
    converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)

    key_value_length = input_shape[-1] + past_key_values_length

    # Decide by the rank of the user-provided mask (None when absent);
    # a 4d mask is passed through the layers.
    mask_rank = None if attention_mask is None else len(attention_mask.shape)

    if mask_rank == 2:
        # Expand the 2D padding mask into a 4D causal mask.
        attention_mask = converter.to_4d(
            attention_mask,
            input_shape[-1],
            key_value_length=key_value_length,
            dtype=inputs_embeds.dtype,
        )
    elif mask_rank == 4:
        expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
        if tuple(attention_mask.shape) != expected_shape:
            raise ValueError(
                f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
            )
        else:
            # The 4D mask has the correct shape - invert it and fill with negative infinity.
            inverted_mask = 1.0 - attention_mask
            attention_mask = inverted_mask.masked_fill(
                inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
            )
    else:
        # No usable user mask: build a pure causal 4D mask.
        attention_mask = converter.to_causal_4d(
            input_shape[0],
            input_shape[-1],
            key_value_length,
            dtype=inputs_embeds.dtype,
            device=inputs_embeds.device,
        )

    return attention_mask
torchtitan/experiments/deepseek_v3/download.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Usage:
8
+ # Downloads a given model to the HF Cache. Pass in a listed option ala "v3" or your own custom model path.
9
+ # python download.py {model_id} [custom_model_path]
10
+ # Examples:
11
+ # python download.py v2 # Use predefined model: deepseek-ai/DeepSeek-V2
12
+ # python download.py custom "deepseek-ai/new-model" # Download a custom model path
13
+
14
+ # Available models:
15
+ # "v2-lite-chat": "deepseek-ai/DeepSeek-V2-Lite-Chat",
16
+ # "v2-lite": "deepseek-ai/DeepSeek-V2-Lite",
17
+ # "v2": "deepseek-ai/DeepSeek-V2",
18
+ # "v3": "deepseek-ai/deepseek-v3",
19
+ # "v3-0324": "deepseek-ai/DeepSeek-V3-0324",
20
+ # "custom": None, # Placeholder for custom models
21
+
22
+
23
+ import sys
24
+
25
+ from transformers import AutoModelForCausalLM
26
+
27
+
28
# Mapping of short CLI aliases to HuggingFace Hub repo IDs.
# The "custom" alias is a sentinel: its path is supplied by the user on the
# command line (see the argv handling below), so it maps to None here.
MODELS = {
    "v2-lite-chat": "deepseek-ai/DeepSeek-V2-Lite-Chat",
    "v2-lite": "deepseek-ai/DeepSeek-V2-Lite",
    "v2": "deepseek-ai/DeepSeek-V2",
    "v3": "deepseek-ai/deepseek-v3",
    "v3-0324": "deepseek-ai/DeepSeek-V3-0324",
    "custom": None,  # For custom (any) models
}
36
+
37
+
38
def print_usage():
    """Print command-line help for this downloader and exit with status 1."""
    header = (
        "Usage:\n"
        "  python download.py [model_version]\n"
        "  python download.py custom [custom_model_path]\n"
        "\nAvailable predefined models:"
    )
    print(header)
    for alias, repo in MODELS.items():
        # The "custom" entry is a placeholder, not a downloadable checkpoint.
        if alias == "custom":
            continue
        print(f"  {alias}: {repo}")
    print("\nFor custom models:")
    print("  custom: Specify your own model path")
    print('  Example: python download.py custom "organization/model-name"')
    sys.exit(1)
50
+
51
+
52
# Process command line arguments.
# argv[1] must be one of the MODELS keys; anything else (or no args) prints help.
if len(sys.argv) < 2 or sys.argv[1] not in MODELS:
    print_usage()

if sys.argv[1] == "custom":
    # "custom" requires exactly one extra argument: the HF repo path to fetch.
    if len(sys.argv) != 3:
        print("Error: Custom model requires a model path")
        print_usage()
    model_id = sys.argv[2]
    print(f"Using custom model: {model_id}")
else:
    model_id = MODELS[sys.argv[1]]
    print(f"Downloading model: {model_id}")

# Instantiating the model forces the weights into the local HF cache.
# NOTE(review): device_map="auto" also materializes the model onto local
# devices; for a pure cache-warming download a snapshot download may be
# lighter-weight — confirm intent.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    trust_remote_code=True,
)
torchtitan/experiments/deepseek_v3/generate.py ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # torchrun --standalone --nproc-per-node 4 generate.py
8
+
9
+ # use inference.sh "Your Question Here?" to run inference with a single prompt.
10
+
11
+ import sys
12
+ from dataclasses import dataclass
13
+
14
+ import torch
15
+ import torch.distributed as dist
16
+
17
+ from checkpoint import load_weights_from_hf
18
+ from model import DeepseekForCausalLM
19
+ from model_config import deepseek_config_registry
20
+ from torch.distributed.device_mesh import DeviceMesh
21
+ from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
22
+ from torchtitan.tools.utils import Color
23
+ from transformers import AutoTokenizer
24
+
25
+ # Uncomment the model you want to run.
26
+ model_id, mesh_shape = "deepseek-ai/DeepSeek-V2-Lite-Chat", (1, 4)
27
+ # model_id, mesh_shape = "deepseek-ai/deepseek-v3", (8, 4)
28
+
29
+
30
def colorize_chat(text, user_color=None, assistant_color=None, output_color=None):
    """Parse and colorize chat output with optional colors for each role.

    Args:
        text: Raw transcript. Role changes are detected by lines beginning
            with "Output:", "User:", or "Assistant:"; other lines are treated
            as continuations of the current role's message.
        user_color: ANSI escape prefix for user messages, or None for no color.
        assistant_color: ANSI escape prefix for assistant messages, or None.
        output_color: ANSI escape prefix for "Output:" segments, or None.

    Returns:
        The transcript with each role's message wrapped in its color code and
        terminated by the module-level ``color.reset``.
    """
    lines = text.split("\n")
    result = []

    # Accumulator state: which role is currently speaking and its buffered lines.
    current_role = None
    current_content = []

    def _process_current_content():
        # Flush the buffered lines for the current role into one colored string.
        if not current_role or not current_content:
            return None

        content = "\n".join(current_content)
        if current_role == "output":
            return (
                f"Output: {output_color}{content}{color.reset}"
                if output_color
                else f"Output: {content}"
            )
        else:
            try:
                # Split "User: ..." / "Assistant: ..." so only the message
                # body (not the role label) gets colored.
                prefix, rest = current_content[0].split(":", 1)
                role_color = user_color if current_role == "user" else assistant_color
                if role_color:
                    formatted = f"{prefix}:{role_color}{rest}{color.reset}"
                    if len(current_content) > 1:
                        # Continuation lines carry the same role color.
                        formatted += (
                            f"{role_color}\n"
                            + "\n".join(current_content[1:])
                            + f"{color.reset}"
                        )
                    return formatted
            except ValueError:
                # First buffered line had no ":" — return it uncolored below.
                pass
            return content

    for line in lines:
        if line.startswith("Output:"):
            # Flush whatever role was active before switching to "output".
            if processed := _process_current_content():
                result.append(processed)
            current_role = "output"
            content = line[len("Output:") :].strip()
            if output_color:
                content = f"Output: {output_color}{content}{color.reset}"
            else:
                content = f"Output: {content}"
            result.append(content)
            current_content = []

        elif line.startswith("User:"):
            if processed := _process_current_content():
                result.append(processed)
            current_role = "user"
            current_content = [line]

        elif line.startswith("Assistant:"):
            if processed := _process_current_content():
                result.append(processed)
            current_role = "assistant"
            current_content = [line]

        else:
            if current_content:
                # Continuation line of the current role's message.
                current_content.append(line)
            elif line.strip() and current_role is None:
                # Handle system message at the beginning
                current_role = "output"
                if output_color:
                    result.append(f"Output: {output_color}{line.strip()}{color.reset}")
                else:
                    result.append(f"Output: {line.strip()}")

    # Process the last segment
    if processed := _process_current_content():
        result.append(processed)

    return "\n".join(result)
107
+
108
+
109
# Module-level ANSI color provider shared by all printing helpers in this file.
color = Color()
110
+
111
+
112
@dataclass
class DistConfig:
    """Distributed-topology handles shared across this inference script.

    The 2D device mesh is split into a pipeline-parallel ("pp") dimension and
    an expert-parallel ("ep") dimension.
    """

    mesh: DeviceMesh  # full 2D (pp, ep) device mesh
    pp_mesh: DeviceMesh  # 1D sub-mesh for pipeline parallelism
    ep_mesh: DeviceMesh  # 1D sub-mesh for expert parallelism
    pp_size: int  # number of pipeline stages
    ep_size: int  # number of expert-parallel ranks
    ep_rank: int  # this rank's position in the ep mesh
    pp_rank: int  # this rank's position in the pp mesh
    device: torch.device  # local CUDA device for this rank
122
+
123
+
124
def create_model(dist_config: DistConfig):
    """Build this rank's DeepSeek model shard and wrap it in a GPipe schedule.

    Looks up the config for the module-level ``model_id``, patches in this
    rank's parallelism settings, loads HF weights onto the local device, and
    returns ``(model, pp_schedule)``.
    """
    model_args = deepseek_config_registry[model_id]
    # NOTE(review): this mutates the shared registry entry in place — confirm
    # no other consumer relies on pristine defaults.
    model_args.ep_size = dist_config.ep_size
    model_args.num_stages = dist_config.pp_size
    model_args.stage_idx = dist_config.pp_rank
    model_args.max_seq_len = 16384

    # Construct under the local device and mesh contexts, then load weights
    # and set up symmetric memory buffers for expert all-to-all.
    with dist_config.device, dist_config.mesh:
        model = DeepseekForCausalLM(model_args)
        load_weights_from_hf(model, model_id, dist_config.device)
        model.eval()
        model.setup_symm_mem(torch.bfloat16, dist_config.device)

    # One pipeline stage per pp rank; GPipe with pp_size microbatches.
    stage = PipelineStage(
        model,
        dist_config.pp_rank,
        dist_config.pp_size,
        dist_config.device,
        group=dist_config.pp_mesh.get_group(),
    )
    pp_schedule = ScheduleGPipe(stage, dist_config.pp_size)
    return model, pp_schedule
146
+
147
+
148
def create_dist_config(mesh: DeviceMesh):
    """Derive a DistConfig (local device plus pp/ep sub-meshes) from *mesh*."""
    # Map this global rank onto a local CUDA device, round-robin over GPUs.
    local_device = torch.device(
        "cuda", dist.get_rank() % torch.cuda.device_count()
    )
    pp_mesh = mesh["pp"]
    ep_mesh = mesh["ep"]
    return DistConfig(
        mesh=mesh,
        pp_mesh=pp_mesh,
        ep_mesh=ep_mesh,
        pp_rank=pp_mesh.get_local_rank(),
        pp_size=pp_mesh.size(),
        ep_size=ep_mesh.size(),
        ep_rank=ep_mesh.get_local_rank(),
        device=local_device,
    )
164
+
165
+
166
def decode(tokenizer, x):
    """Decode the first row of token ids *x* into colorized chat text.

    Strips the BOS token, truncates at the first EOS token, then applies
    per-role colors via :func:`colorize_chat`.
    """
    output = tokenizer.decode(x[0])
    # Clean up the output by removing special tokens
    bos = tokenizer.bos_token
    output = output.replace(bos, "")
    # Truncate at end of sentence token
    eos_token = tokenizer.eos_token
    if eos_token and eos_token in output:
        output = output.split(eos_token)[0]
    colored_output = colorize_chat(
        output,
        user_color=color.green,
        assistant_color=color.cyan,
        output_color=color.blue,
    )
    return colored_output
182
+
183
+
184
@torch.inference_mode()
def generate(
    model,
    pp_schedule,
    tokenizer,
    dist_config,
    messages: list[dict],
    n_tokens: int = 50,
):
    """Greedy-decode up to *n_tokens* tokens, with or without pipeline parallelism.

    Every pp rank holds the full token buffer ``x``. After each forward pass
    the last stage writes the argmax token and broadcasts the buffer back to
    all pp ranks so the next step starts from identical inputs. Rank 0 prints
    the decoded, colorized transcript at the end.
    """
    rank = dist.get_rank()
    device = dist_config.device
    # One copy of the chat prompt per pipeline stage (acts as the microbatches).
    x = tokenizer.apply_chat_template(
        [messages] * dist_config.pp_size,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    next_idx = x.shape[-1]  # position where the next generated token is written
    # Pre-extend the buffer with zeros so tokens can be written in place.
    x = torch.cat([x, torch.zeros(x.shape[0], n_tokens, dtype=torch.int64)], dim=-1)
    x = x.to(device)

    for _ in range(n_tokens):
        if dist_config.pp_size > 1:
            if dist_config.pp_rank == 0:
                # First stage feeds tokens in, then receives the updated
                # buffer from the last stage.
                pp_schedule.step(x)
                torch.distributed.broadcast(
                    x,
                    group=dist_config.pp_mesh.get_group(),
                    group_src=dist_config.pp_size - 1,
                )
            elif dist_config.pp_rank == dist_config.pp_size - 1:
                # Last stage gets the logits: pick the next token greedily and
                # share the updated buffer with every other stage.
                preds = pp_schedule.step()
                next_token = torch.argmax(preds[:, next_idx - 1], dim=-1)
                x[:, next_idx] = next_token
                torch.distributed.broadcast(
                    x,
                    group=dist_config.pp_mesh.get_group(),
                    group_src=dist_config.pp_size - 1,
                )
            else:
                # Middle stages only relay activations, then sync the buffer.
                pp_schedule.step()
                torch.distributed.broadcast(
                    x,
                    group=dist_config.pp_mesh.get_group(),
                    group_src=dist_config.pp_size - 1,
                )

            next_idx += 1
        else:
            # Single-stage path: plain forward pass, greedy argmax.
            preds = model(x)
            next_token = torch.argmax(preds[:, next_idx - 1], dim=-1)
            x[:, next_idx] = next_token
            next_idx += 1

    if rank == 0:
        colored_output = decode(tokenizer, x)
        print(f"Without CUDA Graph:\n{colored_output}")
240
+
241
+
242
@torch.inference_mode()
def generate_with_cuda_graph(
    model,
    tokenizer,
    dist_config,
    messages: list[dict],
    n_tokens: int = 10,
):
    """Greedy-decode *n_tokens* tokens by replaying a captured CUDA graph.

    The forward pass is captured once; each replay re-reads the (in-place
    updated) input buffer ``x`` and refreshes the captured ``preds`` output.
    Rank 0 prints the decoded transcript at the end.

    NOTE(review): there is no warm-up forward pass before capture — CUDA graph
    capture generally requires prior warm-up on a side stream; confirm the
    model has already been run (e.g. via `generate`) before this is called.
    """
    rank = dist.get_rank()
    device = dist_config.device
    # Same prompt replication as in `generate`.
    x = tokenizer.apply_chat_template(
        [messages] * dist_config.pp_size,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    next_idx = x.shape[-1]  # write position for the next generated token
    x = torch.cat([x, torch.zeros(x.shape[0], n_tokens, dtype=torch.int64)], dim=-1)
    x = x.to(device)

    # Ensure all prior work is finished before graph capture.
    torch.cuda.synchronize()

    # Create CUDA graph
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        preds = model(x)

    # Run CUDA graph: each replay recomputes `preds` from the updated `x`.
    for _ in range(n_tokens):
        g.replay()
        next_token = torch.argmax(preds[:, next_idx - 1], dim=-1)
        x[:, next_idx] = next_token
        next_idx += 1

    if rank == 0:
        colored_output = decode(tokenizer, x)
        print(f"With CUDA Graph:\n{colored_output}")
278
+
279
+
280
if __name__ == "__main__":
    # Get user prompt from command line arguments
    user_prompt = "What is 2+2?"  # Default prompt
    if len(sys.argv) > 1:
        user_prompt = sys.argv[1]

    # Build the 2D (pipeline, expert) device mesh; shape is chosen with the
    # model at the top of this file.
    mesh = dist.init_device_mesh("cuda", mesh_shape, mesh_dim_names=("pp", "ep"))
    rank = dist.get_rank()
    if rank == 0:
        print(
            f"{color.yellow}Running inference with {model_id} on {mesh_shape} mesh{color.reset}"
        )

    dist_config = create_dist_config(mesh)
    model, pp_schedule = create_model(dist_config)
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": user_prompt},
    ]

    # Run the same prompt twice: eager decoding, then CUDA-graph decoding.
    generate(model, pp_schedule, tokenizer, dist_config, messages)
    generate_with_cuda_graph(model, tokenizer, dist_config, messages)

    if rank == 0:
        print(f"\n{color.yellow}Closing inference mesh...{color.reset}")

    dist.destroy_process_group()
torchtitan/experiments/deepseek_v3/indices.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import triton
9
+ import triton.language as tl
10
+
11
+
12
+ __all__ = ["generate_permute_indices"]
13
+
14
+
15
@triton.jit
def fill_indices_kernel(
    tokens_per_expert_group_ptr,  # *Pointer* to first input vector.
    start_index_values_ptr,  # *Pointer* to second input vector.
    write_offsets_ptr,  # *Pointer* to third input vector.
    output_ptr,  # *Pointer* to output vector.
    experts_per_rank,  # Number of experts per rank.
    num_ranks,  # Number of expert ranks.
):
    """Scatter per-(rank, expert) token-index ranges into the output buffer.

    Each program handles a strided subset of local experts; for each expert it
    walks all ranks, copying the contiguous range [start, start + length) for
    that (rank, expert) chunk into the expert's region of `output_ptr`,
    packed back to back starting at the expert's write offset.
    """
    # There are multiple 'programs' processing different data. We identify which program
    # we are here:
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    # The total number of programs in the launch grid.
    num_programs = tl.num_programs(axis=0)
    # We map the programs (blocks) to the experts.
    for expert_id in tl.range(pid, experts_per_rank, step=num_programs):
        # Read this expert's write offset.
        write_offset = tl.load(write_offsets_ptr + expert_id)
        # Loop over the ranks.
        for r in tl.range(num_ranks):
            # Slot in the tokens_per_expert_group array.
            i = r * experts_per_rank + expert_id
            start_index = tl.load(start_index_values_ptr + i)
            length = tl.load(tokens_per_expert_group_ptr + i)
            # Write the indices.
            for l in tl.range(length):
                val = start_index + l
                tl.store(output_ptr + write_offset + l, val)
            write_offset += length
44
+
45
+
46
def fill_indices(
    tokens_per_expert_group: torch.Tensor,
    start_index_values: torch.Tensor,
    write_offsets: torch.Tensor,
    experts_per_rank: int,
    num_ranks: int,
    max_len: int,
):
    """Launch the Triton kernel that scatters token indices per expert.

    Returns an int32 tensor of length ``max_len`` on the inputs' device;
    slots beyond the written indices keep the sentinel value -1.
    """
    # Preallocate the output, initialized to the -1 sentinel.
    permuted_indices = torch.full(
        (max_len,), -1, dtype=torch.int32, device=tokens_per_expert_group.device
    )

    # Analogous to a CUDA launch grid: 1D, currently a single program
    # (TODO: bump this value).
    def grid(meta):
        return (1,)

    # Each tensor argument decays to a pointer to its first element.
    fill_indices_kernel[grid](
        tokens_per_expert_group,
        start_index_values,
        write_offsets,
        permuted_indices,
        experts_per_rank,
        num_ranks,
    )
    return permuted_indices
71
+
72
+
73
def fill_indices_cpu(
    tokens_per_expert_group: torch.Tensor,
    start_index_values: torch.Tensor,
    write_offsets: torch.Tensor,
    experts_per_rank: int,
    num_ranks: int,
    max_len: int,
):
    """Pure-PyTorch reference implementation of ``fill_indices``.

    For each local expert, concatenates the index ranges contributed by every
    rank into a preallocated int32 buffer; unused tail slots stay -1.
    """
    # Output buffer, initialized to the -1 sentinel.
    permuted_indices = torch.full((max_len,), -1, dtype=torch.int32)
    for expert in range(experts_per_rank):
        # Where this expert's packed region begins in the output.
        cursor = write_offsets[expert]
        for remote_rank in range(num_ranks):
            # Flat slot of (remote_rank, expert) in the grouped counts.
            slot = remote_rank * experts_per_rank + expert
            begin = start_index_values[slot]
            count = tokens_per_expert_group[slot]
            # Copy the contiguous index range for this chunk.
            permuted_indices[cursor : cursor + count] = torch.arange(
                begin, begin + count
            )
            cursor += count
    return permuted_indices
98
+
99
+
100
def generate_permute_indices(
    tokens_per_expert_group: torch.Tensor,
    experts_per_rank: int,
    num_ranks: int,
    max_len: int,
    alignment: int,
    use_cpu: bool = False,
):
    """Prepare permutation indices and aligned per-expert token counts.

    The permutation indices gather each expert's tokens (from all ranks) into
    contiguous, alignment-padded regions; the per-expert count is the sum over
    ranks, rounded up to `alignment` (usually a group-gemm requirement).

    Args:
        tokens_per_expert_group: token counts for each expert from all ranks,
            shape (num_ranks * experts_per_rank,). Layout example:
            From: | rank 0 | rank 1 |
            To:   | E0 | E1 | E2 | E3 | E0 | E1 | E2 | E3 |
                  |  4 |  2 |  1 |  3 |  1 |  2 |  3 |  4 |
        experts_per_rank: number of experts per rank.
        num_ranks: number of ranks.
        max_len: length of the output index vector; slots beyond the total
            token count are set to -1.
        alignment: alignment for each returned element in `m_sizes`.
        use_cpu: whether to use the CPU reference or the GPU kernel.

    Returns:
        (permuted_indices, m_sizes): the permutation indices and the aligned
        number of tokens for each expert.
    """
    # Exclusive prefix sum: start offset of each (rank, expert) chunk in the input.
    start_index_values = (
        torch.cumsum(tokens_per_expert_group, 0) - tokens_per_expert_group
    )
    # Total tokens each local expert receives, summed over all ranks.
    tokens_per_expert = tokens_per_expert_group.view(num_ranks, -1).sum(0)
    # Round each expert's count up to the alignment boundary.
    m_sizes = ((tokens_per_expert + alignment - 1) // alignment * alignment).to(
        torch.int32
    )
    # Exclusive prefix sum again: where each expert starts writing in the output.
    write_offsets = torch.cumsum(m_sizes, 0) - m_sizes
    # Dispatch to the CPU reference or the Triton kernel.
    fill_fn = fill_indices_cpu if use_cpu else fill_indices
    permuted_indices = fill_fn(
        tokens_per_expert_group,
        start_index_values,
        write_offsets,
        experts_per_rank,
        num_ranks,
        max_len,
    )
    return permuted_indices, m_sizes
155
+
156
+
157
+ # Below is for testing only
158
+
159
+
160
def test():
    """Smoke test: the GPU (Triton) and CPU paths must agree, and every
    per-expert size must be a multiple of the alignment."""
    device = torch.device("cuda", 0)
    experts_per_rank = 4
    num_ranks = 4
    tokens_per_expert_group = torch.full(
        (num_ranks * experts_per_rank,), 4, dtype=torch.int32, device=device
    )
    max_len = 128
    alignment = 32
    # Use the GPU kernel
    permuted_indices_gpu, m_sizes = generate_permute_indices(
        tokens_per_expert_group, experts_per_rank, num_ranks, max_len, alignment
    )
    # Use the CPU method
    permuted_indices_cpu, _ = generate_permute_indices(
        tokens_per_expert_group,
        experts_per_rank,
        num_ranks,
        max_len,
        alignment,
        use_cpu=True,
    )
    # Check that the results are the same
    assert torch.equal(permuted_indices_gpu.cpu(), permuted_indices_cpu)
    # Check that every m_size is a multiple of `alignment`.
    # Bug fix: torch.equal returns False when dtypes differ; m_sizes is int32
    # while torch.zeros defaults to float32, so the old comparison against
    # float zeros always failed. Compare against int32 zeros instead.
    assert torch.equal(
        torch.remainder(m_sizes, alignment),
        torch.zeros(experts_per_rank, dtype=torch.int32, device=device),
    )
    # Print the results
    print(permuted_indices_gpu)
    print(m_sizes)
    print("Success")
192
+
193
+
194
+ if __name__ == "__main__":
195
+ test()
torchtitan/experiments/deepseek_v3/inference.sh ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Bug fix: the shebang must be the very first line of the file; the original
# had a blank line above it, so the interpreter directive was ignored when the
# script was executed directly.

# Number of GPUs to launch (override with: NGPU=8 ./inference.sh "...").
NGPU=${NGPU:-"4"}

# Get the prompt from command line argument or use a default
prompt="${1:-What is 2+2?}"

# Run the model with the prompt
torchrun --standalone --nproc-per-node ${NGPU} generate.py "$prompt"
torchtitan/experiments/deepseek_v3/model_config.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from dataclasses import dataclass, field
8
+
9
+
10
+ @dataclass
11
+ class ModelArgs:
12
+ r"""
13
+ This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate an DeepSeek
14
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
15
+ defaults will yield a similar configuration to that of the DeepSeek-V3.
16
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
17
+ documentation from [`PretrainedConfig`] for more information.
18
+ Args:
19
+ vocab_size (`int`, *optional*, defaults to 129280):
20
+ Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the
21
+ `inputs_ids` passed when calling [`DeepseekV3Model`]
22
+ hidden_size (`int`, *optional*, defaults to 4096):
23
+ Dimension of the hidden representations.
24
+ intermediate_size (`int`, *optional*, defaults to 11008):
25
+ Dimension of the MLP representations.
26
+ moe_intermediate_size (`int`, *optional*, defaults to 1407):
27
+ Dimension of the MoE representations.
28
+ num_hidden_layers (`int`, *optional*, defaults to 32):
29
+ Number of hidden layers in the Transformer decoder.
30
+ num_nextn_predict_layers (`int`, *optional*, defaults to 1):
31
+ Number of nextn predict layers in the DeepSeekV3 Model.
32
+ num_attention_heads (`int`, *optional*, defaults to 32):
33
+ Number of attention heads for each attention layer in the Transformer decoder.
34
+ n_shared_experts (`int`, *optional*, defaults to None):
35
+ Number of shared experts, None means dense model.
36
+ n_routed_experts (`int`, *optional*, defaults to None):
37
+ Number of routed experts, None means dense model.
38
+ routed_scaling_factor (`float`, *optional*, defaults to 1.0):
39
+ Scaling factor or routed experts.
40
+ topk_method (`str`, *optional*, defaults to `gready`):
41
+ Topk method used in routed gate.
42
+ n_group (`int`, *optional*, defaults to None):
43
+ Number of groups for routed experts.
44
+ topk_group (`int`, *optional*, defaults to None):
45
+ Number of selected groups for each token(for each token, ensuring the selected experts is only within
46
+ `topk_group` groups).
47
+ num_experts_per_tok (`int`, *optional*, defaults to None):
48
+ Number of selected experts, None means dense model.
49
+ moe_layer_freq (`int`, *optional*, defaults to 1):
50
+ The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers.
51
+ first_k_dense_replace (`int`, *optional*, defaults to 0):
52
+ Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head).
53
+ \--k dense layers--/
54
+ norm_topk_prob (`bool`, *optional*, defaults to False):
55
+ Whether to normalize the weights of the routed experts.
56
+ scoring_func (`str`, *optional*, defaults to 'softmax'):
57
+ Method of computing expert weights.
58
+ aux_loss_alpha (`float`, *optional*, defaults to 0.001):
59
+ Auxiliary loss weight coefficient.
60
+ seq_aux = (`bool`, *optional*, defaults to True):
61
+ Whether to compute the auxiliary loss for each individual sample.
62
+ num_key_value_heads (`int`, *optional*):
63
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
64
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
65
+ `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
66
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
67
+ by meanpooling all the original heads within that group. For more details checkout [this
68
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
69
+ `num_attention_heads`.
70
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
71
+ The non-linear activation function (function or string) in the decoder.
72
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
73
+ The maximum sequence length that this model might ever be used with.
74
+ initializer_range (`float`, *optional*, defaults to 0.02):
75
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
76
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
77
+ The epsilon used by the rms normalization layers.
78
+ use_cache (`bool`, *optional*, defaults to `True`):
79
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
80
+ relevant if `config.is_decoder=True`.
81
+ pad_token_id (`int`, *optional*):
82
+ Padding token id.
83
+ bos_token_id (`int`, *optional*, defaults to 1):
84
+ Beginning of stream token id.
85
+ eos_token_id (`int`, *optional*, defaults to 2):
86
+ End of stream token id.
87
+ pretraining_tp (`int`, *optional*, defaults to 1):
88
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
89
+ document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
90
+ necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
91
+ issue](https://github.com/pytorch/pytorch/issues/76232).
92
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
93
+ Whether to tie weight embeddings
94
+ rope_theta (`float`, *optional*, defaults to 10000.0):
95
+ The base period of the RoPE embeddings.
96
+ rope_scaling (`Dict`, *optional*):
97
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
98
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
99
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
100
+ `max_position_embeddings` to the expected new maximum.
101
+ attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
102
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
103
+ attention_dropout (`float`, *optional*, defaults to 0.0):
104
+ The dropout ratio for the attention probabilities.
105
+ """
106
+
107
+ vocab_size: int = 129280
108
+ hidden_size: int = 7168
109
+ intermediate_size: int = 18432
110
+ moe_intermediate_size: int = 2048
111
+ num_hidden_layers: int = 61
112
+ num_nextn_predict_layers: int = 1
113
+ num_attention_heads: int = 128
114
+ num_key_value_heads: int = 128
115
+ n_shared_experts: int = 1
116
+ n_routed_experts: int = 256
117
+ ep_size: int = 1
118
+ routed_scaling_factor: float = 2.5
119
+ kv_lora_rank: int = 512
120
+ q_lora_rank: int = 1536
121
+ qk_rope_head_dim: int = 64
122
+ v_head_dim: int = 128
123
+ qk_nope_head_dim: int = 128
124
+ topk_method: str = "noaux_tc"
125
+ n_group: int = 8
126
+ topk_group: int = 4
127
+ num_experts_per_tok: int = 8
128
+ moe_layer_freq: int = 1
129
+ first_k_dense_replace: int = 3
130
+ norm_topk_prob: bool = True
131
+ scoring_func: str = "sigmoid"
132
+ aux_loss_alpha: float = 0.001
133
+ seq_aux: bool = True
134
+ hidden_act: str = "silu"
135
+ max_position_embeddings: int = 163840
136
+ initializer_range: float = 0.02
137
+ rms_norm_eps: float = 1e-6
138
+ rope_theta: float = 10000.0
139
+ rope_scaling: dict = field(
140
+ default_factory=lambda: {
141
+ "beta_fast": 32,
142
+ "beta_slow": 1,
143
+ "factor": 40,
144
+ "mscale": 1.0,
145
+ "mscale_all_dim": 1.0,
146
+ "original_max_position_embeddings": 4096,
147
+ "type": "yarn",
148
+ }
149
+ )
150
+ attention_bias: bool = False
151
+ attention_dropout: float = 0.0
152
+ pad_token_id = None
153
+ # Added for symmetric memory
154
+ max_seq_len: int = 4096
155
+ dtype: str = "bfloat16"
156
+ # Added for pipeline parallel
157
+ num_stages: int = 1
158
+ stage_idx: int = 0
159
+
160
+
161
# This is the configuration for deepseek-ai/DeepSeek-V2-Lite.
deepseek_v2_lite_config = ModelArgs(
    vocab_size=102400,
    hidden_size=2048,
    intermediate_size=10944,
    moe_intermediate_size=1408,
    num_hidden_layers=27,
    num_attention_heads=16,
    num_key_value_heads=16,
    n_shared_experts=2,
    n_routed_experts=64,
    routed_scaling_factor=1.0,
    kv_lora_rank=512,
    # NOTE(review): None here (despite the int-typed field) appears to disable
    # the query LoRA path for the Lite model — confirm downstream handling.
    q_lora_rank=None,
    qk_rope_head_dim=64,
    v_head_dim=128,
    qk_nope_head_dim=128,
    topk_method="greedy",
    n_group=1,
    topk_group=1,
    num_experts_per_tok=6,
    first_k_dense_replace=1,
    norm_topk_prob=False,
    scoring_func="softmax",
    max_position_embeddings=4096,
    rope_scaling={
        "beta_fast": 32,
        "beta_slow": 1,
        "factor": 40,
        "mscale": 0.707,
        "mscale_all_dim": 0.707,
        "original_max_position_embeddings": 4096,
        "type": "yarn",
    },
)


# Model configuration registry.
# Key is the model distribution ID on HuggingFace Hub; the Lite and Lite-Chat
# checkpoints share a single (mutable) config instance.
deepseek_config_registry = {
    "deepseek-ai/DeepSeek-V2-Lite": deepseek_v2_lite_config,
    "deepseek-ai/DeepSeek-V2-Lite-Chat": deepseek_v2_lite_config,
    "deepseek-ai/deepseek-v3": ModelArgs(),
}
torchtitan/experiments/deepseek_v3/symm_mem_recipes/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from .triton_on_device_all_to_all_v import OnDeviceAllToAllV
8
+
9
+ __all__ = [
10
+ "OnDeviceAllToAllV",
11
+ ]
torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_barrier.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from .triton_utils import get_flat_bid, get_flat_tid
11
+
12
+
13
@triton.jit
def send_signal(addrs, sem: tl.constexpr):
    """Set each 32-bit signal flag pointed to by `addrs` from 0 to 1.

    The inline PTX spins on an atomic CAS(0 -> 1), so the store only lands
    once the waiter has consumed (reset) the previous signal — this is what
    makes the flags reusable across barrier rounds without host resets.

    `sem` selects the PTX memory ordering of the CAS:
      - "relaxed": no ordering guarantees (pure signaling).
      - "acq_rel": release semantics, publishing prior writes to the waiter.
    Any other value raises at compile (trace) time.
    """
    if sem == "relaxed":
        tl.inline_asm_elementwise(
            """
            {
                .reg .u32   %tmp32_<1>;
                .reg .pred  %p<1>;

                send_signal:
                    atom.global.relaxed.sys.cas.b32 %tmp32_0, [$1], 0, 1;
                    setp.eq.u32 %p0, %tmp32_0, 0;
                    @!%p0 bra send_signal;
            }
            """,
            "=r, l",
            [addrs],
            dtype=tl.int32,
            is_pure=False,
            pack=1,
        )
    elif sem == "acq_rel":
        tl.inline_asm_elementwise(
            """
            {
                .reg .u32   %tmp32_<1>;
                .reg .pred  %p<1>;

                send_signal:
                    atom.global.release.sys.cas.b32 %tmp32_0, [$1], 0, 1;
                    setp.eq.u32 %p0, %tmp32_0, 0;
                    @!%p0 bra send_signal;
            }
            """,
            "=r, l",
            [addrs],
            dtype=tl.int32,
            is_pure=False,
            pack=1,
        )
    else:
        raise RuntimeError(f"Unrecognized sem: {sem}")
55
+
56
+
57
@triton.jit
def wait_signal(addrs, sem: tl.constexpr):
    """Wait until each signal flag pointed to by `addrs` becomes 1, then reset it.

    The inline PTX spins on an atomic CAS(1 -> 0): it both observes the peer's
    signal and consumes it in a single atomic step, leaving the flag zeroed
    for the next barrier round (see `blockwise_barrier` docstring on CUDA
    graph friendliness).

    `sem` selects the PTX memory ordering of the CAS:
      - "relaxed": no ordering guarantees (pure signaling).
      - "acq_rel": acquire semantics, making the signaler's prior writes visible.
    Any other value raises at compile (trace) time.
    """
    if sem == "relaxed":
        tl.inline_asm_elementwise(
            """
            {
                .reg .u32   %tmp32_<1>;
                .reg .pred  %p<1>;

                wait_signal:
                    atom.global.sys.relaxed.cas.b32 %tmp32_0, [$1], 1, 0;
                    setp.eq.u32 %p0, %tmp32_0, 1;
                    @!%p0 bra wait_signal;
            }
            """,
            "=r, l",
            [addrs],
            dtype=tl.int32,
            is_pure=False,
            pack=1,
        )
    elif sem == "acq_rel":
        tl.inline_asm_elementwise(
            """
            {
                .reg .u32   %tmp32_<1>;
                .reg .pred  %p<1>;

                wait_signal:
                    atom.global.sys.acquire.cas.b32 %tmp32_0, [$1], 1, 0;
                    setp.eq.u32 %p0, %tmp32_0, 1;
                    @!%p0 bra wait_signal;
            }
            """,
            "=r, l",
            [addrs],
            dtype=tl.int32,
            is_pure=False,
            pack=1,
        )
    else:
        raise RuntimeError(f"Unrecognized sem: {sem}")
99
+
100
+
101
@triton.jit
def blockwise_barrier(
    signal_pad_ptrs,
    block_id,
    rank: tl.constexpr,
    world_size: tl.constexpr,
    sem: tl.constexpr,
):
    """
    Synchronizes blocks with matching block_id across participating devices.

    Note: the function itself is not a system level barrier/fence. It is a
    building block for expressing different synchronization patterns.

    Pattern 0: Ensures that all writes to symm_mem buffers from previous
    kernels across all devices are visible to the current kernel:

        blockwise_barrier(..., sem="relaxed")
        sync_threads()

    Pattern 1: Ensures that all writes to symm_mem buffers from the current
    block are visible to all remote blocks with matching blockIdx:

        sync_threads()
        blockwise_barrier(..., sem="acq_rel")
        sync_threads()

    Pattern 2: Ensures that symm_mem buffers read by the current kernel are safe
    for writing by subsequent kernels across all devices.

        sync_threads()
        blockwise_barrier(..., sem="relaxed")

    CUDA graph friendliness:

        This barrier operates through atomic operations on a zero-filled signal
        pad, which resets to a zero-filled state after each successful
        synchronization. This design eliminates the need for incrementing a
        flag from host.
    """
    if block_id is None:
        block_id = get_flat_bid()
    flat_tid = get_flat_tid()

    # `signal_pad_ptrs` is an array of `world_size` device pointers (one
    # signal pad per rank); each pad is laid out as [block_id][peer_rank]
    # 32-bit slots.
    remote_ranks = tl.arange(0, world_size)
    signal_pad_ptrs = signal_pad_ptrs.to(tl.pointer_type(tl.uint64))
    remote_signal_pad_addrs = tl.load(signal_pad_ptrs + remote_ranks).to(
        tl.pointer_type(tl.uint32)
    )
    # Slot reserved for *this* rank inside every peer's pad.
    send_addrs = remote_signal_pad_addrs + block_id * world_size + rank

    local_signal_pad_addr = tl.load(signal_pad_ptrs + rank).to(
        tl.pointer_type(tl.uint32)
    )
    # Slots in the local pad where every peer (including self) will signal.
    wait_addrs = local_signal_pad_addr + block_id * world_size + remote_ranks

    # One thread per peer performs the signal/wait pair; other threads fall
    # through (callers pair this with sync_threads(), see patterns above).
    if flat_tid < world_size:
        send_signal(send_addrs, sem)
        wait_signal(wait_addrs, sem)
torchtitan/experiments/deepseek_v3/train.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # torchrun --standalone --nproc-per-node 8 run.py
8
+ import torch
9
+ import torch.distributed as dist
10
+ from checkpoint import load_weights_from_hf
11
+ from model import DeepseekForCausalLM
12
+ from model_config import deepseek_config_registry
13
+
14
+ from torch.distributed.device_mesh import DeviceMesh
15
+ from torch.distributed.fsdp import fully_shard
16
+ from torch.distributed.pipelining import PipelineStage, Schedule1F1B
17
+
18
+
19
+ # Use DeepSeek-V2-Lite as a proxy
20
+ model_id = "deepseek-ai/DeepSeek-V2-Lite"
21
+
22
+
23
+ # Run full model
24
+ def run_full_model(
25
+ mesh: DeviceMesh,
26
+ ):
27
+ rank = dist.get_rank()
28
+ device_count = torch.cuda.device_count()
29
+ device = torch.device("cuda", rank % device_count)
30
+
31
+ pp_mesh = mesh["pp"]
32
+ ep_mesh = mesh["ep"]
33
+ pp_rank = pp_mesh.get_local_rank()
34
+ ep_rank = ep_mesh.get_local_rank()
35
+ pp_size = pp_mesh.size()
36
+ ep_size = ep_mesh.size()
37
+
38
+ # Get model configs
39
+ model_args = deepseek_config_registry[model_id]
40
+ # [Note]: I am making the model smaller for testing / avoiding OOM. If you
41
+ # have sufficient GPUs for model parallelism, you can remove this line.
42
+ model_args.num_hidden_layers = 16
43
+
44
+ # Apply model parallelism
45
+ model_args.ep_size = ep_size
46
+ model_args.num_stages = pp_size
47
+ model_args.stage_idx = pp_rank
48
+ print(model_args)
49
+
50
+ # Instantiate model
51
+ with device, mesh:
52
+ model = DeepseekForCausalLM(model_args)
53
+
54
+ # Load weights
55
+ load_weights_from_hf(model, model_id, device)
56
+ model.train()
57
+
58
+ # Apply data parallelism
59
+ fsdp_mesh = mesh["fsdp"]
60
+ hsdp_mesh = mesh["ep", "fsdp"]
61
+ # Using `reshard_after_forward=False` to implement Zero-2, i.e. sharding the
62
+ # optimizer (Zero-1) and gradients (Zero-2), but not the model weights.
63
+ # Reason: the MoE is "sparsely activated" compared to the dense model, thus
64
+ # it will be ineconomical re-gather the weights.
65
+ for layer in model.model.layers.values():
66
+ # Apply FSDP to experts
67
+ if hasattr(layer.mlp, "experts"):
68
+ for expert in layer.mlp.experts.values():
69
+ fully_shard(expert, mesh=fsdp_mesh, reshard_after_forward=False)
70
+ # Apply HSDP to other parts such as attention, layernorm, because they
71
+ # are doing DDP on EP dimension
72
+ fully_shard(layer, mesh=hsdp_mesh, reshard_after_forward=False)
73
+
74
+ # Apply HSDP on root model (lm_head, embeddings, etc)
75
+ fully_shard(model, mesh=hsdp_mesh, reshard_after_forward=False)
76
+
77
+ # Synthetic setting
78
+ microbatches = pp_size * 2
79
+
80
+ # Use Symmetric Memory for MoE token shuffle.
81
+ # TODO: we are rewriting `moe_on_device` function. `setup_symm_mem` is
82
+ # currently supported for forward only. See `generate.py`.
83
+ # model.setup_symm_mem(torch.bfloat16, device)
84
+
85
+ # Example inputs
86
+ torch.manual_seed(ep_rank)
87
+ bs = 4
88
+ seqlen = 128
89
+ x = torch.randint(model_args.vocab_size, (microbatches * bs, seqlen), device=device)
90
+ label = torch.rand(microbatches * bs, seqlen, model_args.vocab_size, device=device)
91
+
92
+ # Create loss function
93
+ loss_fn = torch.nn.functional.cross_entropy
94
+
95
+ # Run forward and backward
96
+ steps = 2
97
+ for _ in range(steps):
98
+ if pp_size > 1:
99
+ # Create pipeline stage
100
+ stage = PipelineStage(
101
+ model,
102
+ pp_rank,
103
+ pp_size,
104
+ device,
105
+ group=pp_mesh.get_group(),
106
+ )
107
+
108
+ # Create pipeline schedule
109
+ losses = []
110
+ pp_schedule = Schedule1F1B(stage, microbatches, loss_fn=loss_fn)
111
+
112
+ if pp_rank == 0:
113
+ y = pp_schedule.step(x)
114
+ elif pp_rank == pp_size - 1:
115
+ y = pp_schedule.step(target=label, losses=losses)
116
+ loss = torch.mean(torch.stack(losses))
117
+ else:
118
+ pp_schedule.step()
119
+ else:
120
+ y = model(x)
121
+ loss = loss_fn(y, label)
122
+ loss.backward()
123
+
124
+ if pp_rank == pp_size - 1:
125
+ print(f"logits: {y.shape}")
126
+ print(f"{loss=}")
127
+
128
+ if pp_rank == 0:
129
+ param = model.get_parameter("model.layers.0.self_attn.q_proj.weight")
130
+ print(f"{torch.linalg.norm(param.grad)=}")
131
+
132
+ model.zero_grad()
133
+
134
+ print("Backward done")
135
+
136
+
137
+ if __name__ == "__main__":
138
+ mesh = dist.init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("pp", "ep", "fsdp"))
139
+
140
+ run_full_model(mesh)
141
+
142
+ dist.destroy_process_group()
torchtitan/experiments/flux/dataset/flux_dataset.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ import random
9
+ from dataclasses import dataclass
10
+ from typing import Any, Callable, Optional
11
+
12
+ import numpy as np
13
+
14
+ import torch
15
+
16
+ from datasets import Dataset, load_dataset
17
+ from datasets.distributed import split_dataset_by_node
18
+ from PIL import Image
19
+
20
+ from torch.distributed.checkpoint.stateful import Stateful
21
+
22
+ from torch.utils.data import IterableDataset
23
+ from torchtitan.components.dataloader import ParallelAwareDataloader
24
+
25
+ from torchtitan.config_manager import JobConfig
26
+ from torchtitan.experiments.flux.dataset.tokenizer import FluxTokenizer
27
+ from torchtitan.tools.logging import logger
28
+
29
+
30
def _process_cc12m_image(
    img: Image.Image,
    output_size: int = 256,
) -> Optional[torch.Tensor]:
    """Resize and random-crop a CC12M image to a square tensor.

    The shorter side is resized to `output_size` and a random crop is taken
    along the longer side.

    Args:
        img: Source PIL image.
        output_size: Side length of the square output crop.

    Returns:
        Float tensor of shape (C, output_size, output_size) with values in
        [0, 1], or None if the image is rejected (too small or grayscale).
    """
    width, height = img.size
    # Skip low resolution images
    if width < output_size or height < output_size:
        return None

    if width >= height:
        # Resize so height == output_size, then crop randomly along width.
        new_width, new_height = math.ceil(output_size / height * width), output_size
        img = img.resize((new_width, new_height))
        left = random.randint(0, new_width - output_size)
        resized_img = img.crop((left, 0, left + output_size, output_size))
    else:
        # Resize so width == output_size, then crop randomly along height.
        new_width, new_height = (
            output_size,
            math.ceil(output_size / width * height),
        )
        img = img.resize((new_width, new_height))
        # Bug fix: the vertical crop offset must be drawn from the resized
        # *height*. The previous `new_width - output_size` was always 0 here
        # (new_width == output_size), so tall images were only ever cropped
        # from the top.
        lower = random.randint(0, new_height - output_size)
        resized_img = img.crop((0, lower, output_size, lower + output_size))

    assert resized_img.size[0] == resized_img.size[1] == output_size

    # Skip grayscale images: the HWC -> CHW transpose below assumes a
    # multi-channel image array.
    if resized_img.mode == "L":
        return None

    np_img = np.array(resized_img).transpose((2, 0, 1))
    tensor_img = torch.tensor(np_img).float() / 255.0

    return tensor_img
77
+
78
+
79
def _flux_data_processor(
    sample: dict[str, Any],
    t5_tokenizer: FluxTokenizer,
    clip_tokenizer: FluxTokenizer,
    output_size: int = 256,
) -> dict[str, Any]:
    """Convert one raw CC12M sample into model-ready Flux inputs.

    Args:
        sample: Raw dataset sample with "jpg" (PIL image) and "txt" (caption).
        t5_tokenizer: Tokenizer for the T5 text encoder.
        clip_tokenizer: Tokenizer for the CLIP text encoder.
        output_size: Side length of the square image crop.

    Returns:
        Dict with "image" (tensor or None if rejected), "clip_tokens" and
        "t5_tokens" (token id lists for each encoder).
    """
    caption = sample["txt"]
    image_tensor = _process_cc12m_image(sample["jpg"], output_size=output_size)
    t5_tokens = t5_tokenizer.encode(caption)
    clip_tokens = clip_tokenizer.encode(caption)

    return {
        "image": image_tensor,
        "clip_tokens": clip_tokens,  # type: List[int]
        "t5_tokens": t5_tokens,  # type: List[int]
    }
104
+
105
+
106
@dataclass
class TextToImageDatasetConfig:
    """Registry entry describing how to load and preprocess one dataset."""

    # Default HuggingFace Hub path (overridable by a user-supplied path).
    path: str
    # Callable taking the path and returning an (iterable) HF dataset.
    loader: Callable
    # Callable mapping one raw sample to Flux model inputs.
    data_processor: Callable


# Supported text-to-image datasets, keyed by lowercase short name.
DATASETS = {
    "cc12m": TextToImageDatasetConfig(
        path="pixparse/cc12m-wds",
        # Stream the train split so the whole dataset is never downloaded.
        loader=lambda path: load_dataset(path, split="train", streaming=True),
        data_processor=_flux_data_processor,
    ),
}
+
121
+
122
def _validate_dataset(
    dataset_name: str, dataset_path: Optional[str] = None
) -> tuple[str, Callable, Callable]:
    """Resolve a dataset name to its (path, loader, data_processor) triple.

    Raises:
        ValueError: If `dataset_name` is not registered in DATASETS.
    """
    if dataset_name not in DATASETS:
        supported = list(DATASETS.keys())
        raise ValueError(
            f"Dataset {dataset_name} is not supported. "
            f"Supported datasets are: {supported}"
        )

    entry = DATASETS[dataset_name]
    resolved_path = dataset_path or entry.path
    logger.info(f"Preparing {dataset_name} dataset from {resolved_path}")
    return resolved_path, entry.loader, entry.data_processor
136
+
137
+
138
class FluxDataset(IterableDataset, Stateful):
    """Dataset for FLUX text-to-image model.

    Yields (inputs, labels) pairs where inputs is a dict with "clip_tokens"
    and "t5_tokens" and labels is the preprocessed image tensor.

    Args:
        dataset_name (str): Name of the dataset.
        dataset_path (str): Path to the dataset.
        t5_tokenizer (FluxTokenizer): Tokenizer for the T5 text encoder.
        clip_tokenizer (FluxTokenizer): Tokenizer for the CLIP text encoder.
        job_config (JobConfig): Optional job configuration.
        dp_rank (int): Data parallel rank.
        dp_world_size (int): Data parallel world size.
        infinite (bool): Whether to loop over the dataset infinitely.
    """

    def __init__(
        self,
        dataset_name: str,
        dataset_path: Optional[str],
        t5_tokenizer: FluxTokenizer,
        clip_tokenizer: FluxTokenizer,
        job_config: Optional[JobConfig] = None,
        dp_rank: int = 0,
        dp_world_size: int = 1,
        infinite: bool = False,
    ) -> None:

        # Force lowercase for consistent comparison
        dataset_name = dataset_name.lower()

        path, dataset_loader, data_processor = _validate_dataset(
            dataset_name, dataset_path
        )
        ds = dataset_loader(path)

        self.dataset_name = dataset_name
        # Each DP rank sees a disjoint shard of the dataset.
        self._data = split_dataset_by_node(ds, dp_rank, dp_world_size)

        self._t5_tokenizer = t5_tokenizer
        self._clip_tokenizer = clip_tokenizer
        self._data_processor = data_processor
        self.job_config = job_config

        self.infinite = infinite

        # Variables for checkpointing
        self._sample_idx = 0
        self._all_samples: list[dict[str, Any]] = []

    def _get_data_iter(self):
        # Fresh iterator, fast-forwarded past already-consumed samples so
        # resuming from a checkpoint continues where we left off.
        if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
            return iter([])

        it = iter(self._data)
        for _ in range(self._sample_idx):
            next(it)
        return it

    def __iter__(self):
        while True:
            for sample in self._get_data_iter():
                # Use the dataset-specific preprocessor
                sample_dict = self._data_processor(
                    sample, self._t5_tokenizer, self._clip_tokenizer, output_size=256
                )

                # skip low quality image or image with color channel = 1
                if sample_dict["image"] is None:
                    logger.warning(
                        f"Low quality image {sample['__key__']} is skipped in Flux Dataloader"
                    )
                    continue

                # NOTE(review): `extend` on a dict adds only its *keys*
                # (three strings per sample) to _all_samples — presumably
                # `append` of the sample record was intended; confirm what
                # checkpoint consumers expect here. Note also that the dict
                # is mutated by the `pop` below.
                self._all_samples.extend(sample_dict)
                self._sample_idx += 1

                # The image tensor is the training target; the remaining
                # token fields are the model inputs.
                labels = sample_dict.pop("image")
                yield sample_dict, labels

            if not self.infinite:
                logger.warning(f"Dataset {self.dataset_name} has run out of data")
                break
            else:
                # Reset offset for the next iteration
                self._sample_idx = 0
                logger.warning(f"Dataset {self.dataset_name} is being re-looped")

    def load_state_dict(self, state_dict):
        # Restore checkpointed progress (mirror of state_dict()).
        self._sample_idx = state_dict["sample_idx"]
        self._all_samples = state_dict["all_samples"]

    def state_dict(self):
        return {
            "all_samples": self._all_samples,
            "sample_idx": self._sample_idx,
        }
231
+
232
+
233
def build_flux_dataloader(
    dp_world_size: int,
    dp_rank: int,
    job_config: JobConfig,
    # This parameter is not used, keep it for compatibility
    tokenizer: FluxTokenizer | None,
    infinite: bool = True,
) -> ParallelAwareDataloader:
    """Build a data loader for HuggingFace datasets."""
    # Dataset and batching options come from the [training] config section.
    dataset_name = job_config.training.dataset
    dataset_path = job_config.training.dataset_path
    batch_size = job_config.training.batch_size

    # Text-encoder options come from the [encoder] config section.
    t5_encoder_name = job_config.encoder.t5_encoder
    clip_encoder_name = job_config.encoder.clip_encoder
    max_t5_encoding_len = job_config.encoder.max_t5_encoding_len

    ds = FluxDataset(
        dataset_name=dataset_name,
        dataset_path=dataset_path,
        t5_tokenizer=FluxTokenizer(t5_encoder_name, max_length=max_t5_encoding_len),
        clip_tokenizer=FluxTokenizer(
            clip_encoder_name, max_length=77
        ),  # fix max_length for CLIP
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        infinite=infinite,
    )

    # ParallelAwareDataloader keeps per-rank checkpoint state for the shard.
    return ParallelAwareDataloader(
        dataset=ds,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        batch_size=batch_size,
    )
torchtitan/experiments/flux/model/layers.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # imported from black-forest-labs/FLUX
8
+ import math
9
+ from dataclasses import dataclass
10
+
11
+ import torch
12
+ from einops import rearrange
13
+ from torch import nn, Tensor
14
+
15
+ from torchtitan.experiments.flux.model.math import attention, rope
16
+
17
+
18
class EmbedND(nn.Module):
    """Multi-axis rotary position embedding.

    Builds one rope() embedding per coordinate axis of `ids` (each axis with
    its own dimension from `axes_dim`), concatenates them along the frequency
    dimension, and inserts a singleton head dimension.
    """

    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: Tensor) -> Tensor:
        axis_count = ids.shape[-1]
        per_axis = [
            rope(ids[..., axis], self.axes_dim[axis], self.theta)
            for axis in range(axis_count)
        ]
        combined = torch.cat(per_axis, dim=-3)
        return combined.unsqueeze(1)
33
+
34
+
35
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
    """
    Create sinusoidal timestep embeddings.
    :param t: a 1-D Tensor of N indices, one per batch element.
              These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param time_factor: multiplier applied to t before embedding.
    :return: an (N, D) Tensor of positional embeddings.
    """
    scaled = time_factor * t
    half = dim // 2
    # Geometric frequency ladder from 1 down to 1/max_period.
    exponents = torch.arange(start=0, end=half, dtype=torch.float32) / half
    freqs = torch.exp(-math.log(max_period) * exponents).to(scaled.device)

    angles = scaled[:, None].float() * freqs[None]
    emb = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
    if dim % 2:
        # Odd target dimension: pad a zero column to reach exactly `dim`.
        emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
    if torch.is_floating_point(t):
        # Match the caller's floating dtype/device.
        emb = emb.to(t)
    return emb
59
+
60
+
61
class MLPEmbedder(nn.Module):
    """Two-layer MLP embedder: Linear(in_dim -> hidden_dim), SiLU,
    Linear(hidden_dim -> hidden_dim)."""

    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.in_layer(x)
        activated = self.silu(hidden)
        return self.out_layer(activated)
70
+
71
+
72
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalization with a learned per-channel scale.

    Normalizes in float32 for numerical stability, casts back to the input
    dtype, then applies `scale`.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor):
        orig_dtype = x.dtype
        x32 = x.float()
        inv_rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + 1e-6)
        normalized = (x32 * inv_rms).to(dtype=orig_dtype)
        return normalized * self.scale
82
+
83
+
84
class QKNorm(torch.nn.Module):
    """Independently RMS-normalize query and key tensors.

    Outputs are cast to the dtype/device of `v` so that attention runs with
    uniform dtypes.
    """

    def __init__(self, dim: int):
        super().__init__()
        # TODO(jianiw): switch to pytorch nn.RMSNorm
        self.query_norm = RMSNorm(dim)
        self.key_norm = RMSNorm(dim)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
        normed_q = self.query_norm(q)
        normed_k = self.key_norm(k)
        return normed_q.to(v), normed_k.to(v)
94
+
95
+
96
class SelfAttention(nn.Module):
    """Multi-head self-attention with QK RMS-normalization and rotary
    position embedding (applied inside `attention` via `pe`)."""

    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads

        # Fused projection producing Q, K and V in a single matmul.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.norm = QKNorm(head_dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: Tensor, pe: Tensor) -> Tensor:
        qkv = self.qkv(x)
        # Split fused output: (B, L, 3*H*D) -> 3 tensors of (B, H, L, D).
        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)
        x = attention(q, k, v, pe=pe)
        x = self.proj(x)
        return x
113
+
114
+
115
@dataclass
class ModulationOut:
    """One set of adaLN modulation terms produced by `Modulation`."""

    # Additive shift applied after normalization.
    shift: Tensor
    # Multiplicative scale (used as `1 + scale`).
    scale: Tensor
    # Gate multiplying the branch output before the residual add.
    gate: Tensor
120
+
121
+
122
class Modulation(nn.Module):
    """Project a conditioning vector into shift/scale/gate modulation terms.

    With `double=True` two ModulationOut triples are produced (one per
    modulated sub-block of a DoubleStreamBlock); otherwise the second element
    of the returned tuple is None.
    """

    def __init__(self, dim: int, double: bool):
        super().__init__()
        self.is_double = double
        self.multiplier = 6 if double else 3
        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)

    def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
        projected = self.lin(nn.functional.silu(vec))
        # Insert a sequence dimension, then split into `multiplier` chunks.
        chunks = projected[:, None, :].chunk(self.multiplier, dim=-1)

        primary = ModulationOut(*chunks[:3])
        secondary = ModulationOut(*chunks[3:]) if self.is_double else None
        return primary, secondary
138
+
139
+
140
class DoubleStreamBlock(nn.Module):
    """MM-DiT style block processing image and text streams with separate
    weights, joined only in the shared attention step.

    Each stream gets its own Modulation (double=True -> two shift/scale/gate
    triples: one for attention, one for the MLP), LayerNorms, attention
    projections and MLP.
    """

    def __init__(
        self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False
    ):
        super().__init__()

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        # Image-stream parameters.
        self.img_mod = Modulation(hidden_size, double=True)
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_attn = SelfAttention(
            dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
        )

        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

        # Text-stream parameters (same structure, separate weights).
        self.txt_mod = Modulation(hidden_size, double=True)
        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_attn = SelfAttention(
            dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
        )

        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

    def forward(
        self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor
    ) -> tuple[Tensor, Tensor]:
        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)

        # prepare image for attention
        img_modulated = self.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = self.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = rearrange(
            img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
        )
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(
            txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
        )
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention: both streams attend jointly over the
        # concatenated [txt, img] sequence.
        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)

        attn = attention(q, k, v, pe=pe)
        # Split the joint attention output back into per-stream parts.
        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

        # calculate the img blocks: gated residual attention + gated MLP.
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        img = img + img_mod2.gate * self.img_mlp(
            (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
        )

        # calculate the txt blocks
        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt = txt + txt_mod2.gate * self.txt_mlp(
            (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
        )
        return img, txt
219
+
220
+
221
class SingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.

    Attention and MLP branches share one fused input projection (linear1)
    and one fused output projection (linear2), modulated by a single
    shift/scale/gate triple.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float | None = None,
    ):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        # NOTE: self.scale is stored but attention() delegates scaling to
        # scaled_dot_product_attention's default.
        self.scale = qk_scale or head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
        # proj and mlp_out
        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)

        self.norm = QKNorm(head_dim)

        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False)

    def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
        mod, _ = self.modulation(vec)
        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
        # One fused projection feeds both the attention branch (qkv) and the
        # MLP branch (mlp).
        qkv, mlp = torch.split(
            self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1
        )

        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)

        # compute attention
        attn = attention(q, k, v, pe=pe)
        # compute activation in mlp stream, cat again and run second linear layer
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        # Gated residual connection.
        return x + mod.gate * output
269
+
270
+
271
class LastLayer(nn.Module):
    """Final adaLN-modulated projection from hidden states to patch outputs.

    A SiLU+Linear head derives (shift, scale) from the conditioning vector,
    which modulate the final LayerNorm output before the pixel projection.
    """

    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(
            hidden_size, patch_size * patch_size * out_channels, bias=True
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
        modulated = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
        return self.linear(modulated)
torchtitan/experiments/flux/model/math.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ from einops import rearrange
9
+ from torch import Tensor
10
+
11
+
12
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    """Rotary-embedded scaled dot-product attention.

    Applies the precomputed rotation matrices `pe` to q and k, runs SDPA,
    and merges heads back: (B, H, L, D) -> (B, L, H*D).
    """
    q, k = apply_rope(q, k, pe)

    x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    x = rearrange(x, "B H L D -> B L (H D)")

    return x
19
+
20
+
21
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    """Build 2x2 rotary-embedding rotation matrices for positions `pos`.

    Args:
        pos: Position values of shape (..., n).
        dim: Embedding dimension; must be even (one 2x2 rotation per pair).
        theta: Rotary frequency base.

    Returns:
        Float tensor of shape (..., n, dim // 2, 2, 2) holding
        [[cos, -sin], [sin, cos]] per position and frequency.
    """
    assert dim % 2 == 0
    scale = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim
    omega = 1.0 / (theta**scale)
    out = torch.einsum("...n,d->...nd", pos, omega)
    out = torch.stack(
        [torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1
    )
    # Reshape the packed 4-vector into a 2x2 rotation matrix. Using
    # Tensor.unflatten instead of einops.rearrange("b n d (i j) -> b n d i j")
    # drops the third-party dependency for a trivial reshape and also lifts
    # the pattern's implicit rank-4 restriction on `out`.
    out = out.unflatten(-1, (2, 2))
    return out.float()
31
+
32
+
33
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    """Rotate query/key channel pairs by precomputed rotation matrices.

    Each pair of channels is treated as a 2-vector and multiplied by the
    corresponding 2x2 matrix in `freqs_cis` (see `rope`). Computation is done
    in float32 and cast back to the input dtype.
    """

    def _rotate(x: Tensor) -> Tensor:
        # (..., d) -> (..., d/2, 1, 2): expose channel pairs.
        pairs = x.float().reshape(*x.shape[:-1], -1, 1, 2)
        rotated = freqs_cis[..., 0] * pairs[..., 0] + freqs_cis[..., 1] * pairs[..., 1]
        return rotated.reshape(*x.shape).type_as(x)

    return _rotate(xq), _rotate(xk)
torchtitan/experiments/flux/model/model.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from dataclasses import dataclass, field
8
+
9
+ import torch
10
+
11
+ from torch import nn, Tensor
12
+ from torchtitan.components.tokenizer import Tokenizer
13
+ from torchtitan.config_manager import JobConfig
14
+
15
+ from torchtitan.experiments.flux.model.autoencoder import AutoEncoderParams
16
+ from torchtitan.experiments.flux.model.layers import (
17
+ DoubleStreamBlock,
18
+ EmbedND,
19
+ LastLayer,
20
+ MLPEmbedder,
21
+ SingleStreamBlock,
22
+ timestep_embedding,
23
+ )
24
+
25
+ from torchtitan.protocols.train_spec import BaseModelArgs, ModelProtocol
26
+ from torchtitan.tools.logging import logger
27
+
28
+
29
@dataclass
class FluxModelArgs(BaseModelArgs):
    """Configuration for the Flux flow-matching transformer.

    Field defaults correspond to the full-size Flux model; debug flavors
    override them. ``autoencoder_params`` configures the image autoencoder
    that produces the latents the transformer operates on.
    """

    in_channels: int = 64
    out_channels: int = 64
    vec_in_dim: int = 768
    context_in_dim: int = 512
    hidden_size: int = 3072
    mlp_ratio: float = 4.0
    num_heads: int = 24
    depth: int = 19
    depth_single_blocks: int = 38
    axes_dim: tuple = (16, 56, 56)
    theta: int = 10_000
    qkv_bias: bool = True
    guidance_embed: bool = True
    autoencoder_params: AutoEncoderParams = field(default_factory=AutoEncoderParams)

    def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
        """Sync model args with the parsed job config.

        context_in_dim is the same as the T5 embedding dimension.
        NOTE(review): the value assigned is ``max_t5_encoding_len`` (a sequence
        length) — confirm this is intentional and matches the T5 hidden size.
        """
        self.context_in_dim = job_config.encoder.max_t5_encoding_len

    def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
        """Return (total parameter count, FLOPs placeholder).

        FLOPs accounting is not implemented yet, so 1 is returned in its
        place; a warning is logged every call.
        """
        # TODO(jianiw): Add the number of flops for the autoencoder
        nparams = sum(p.numel() for p in model.parameters())
        logger.warning(
            "FLUX model has not implemented the get_nparams_and_flops() function"
        )
        return nparams, 1
55
+
56
+
57
class FluxModel(nn.Module, ModelProtocol):
    """
    Transformer model for flow matching on sequences.

    Args:
        model_args: FluxModelArgs.

    Attributes:
        model_args (FluxModelArgs): Model configuration arguments.
    """

    def __init__(self, model_args: FluxModelArgs):
        super().__init__()

        self.model_args = model_args
        self.in_channels = model_args.in_channels
        self.out_channels = model_args.out_channels
        # Per-head dimension must be an integer, and the rotary axes must
        # together span exactly that dimension.
        if model_args.hidden_size % model_args.num_heads != 0:
            raise ValueError(
                f"Hidden size {model_args.hidden_size} must be divisible by num_heads {model_args.num_heads}"
            )
        pe_dim = model_args.hidden_size // model_args.num_heads
        if sum(model_args.axes_dim) != pe_dim:
            raise ValueError(
                f"Got {model_args.axes_dim} but expected positional dim {pe_dim}"
            )
        self.hidden_size = model_args.hidden_size
        self.num_heads = model_args.num_heads
        # Multi-axis rotary positional embedder over the packed token ids.
        self.pe_embedder = EmbedND(
            dim=pe_dim, theta=model_args.theta, axes_dim=model_args.axes_dim
        )
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        # Timestep (and optional guidance) are embedded via a 256-dim
        # sinusoidal encoding followed by an MLP.
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(model_args.vec_in_dim, self.hidden_size)
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
            if model_args.guidance_embed
            else nn.Identity()
        )
        self.txt_in = nn.Linear(model_args.context_in_dim, self.hidden_size)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=model_args.mlp_ratio,
                    qkv_bias=model_args.qkv_bias,
                )
                for _ in range(model_args.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(
                    self.hidden_size, self.num_heads, mlp_ratio=model_args.mlp_ratio
                )
                for _ in range(model_args.depth_single_blocks)
            ]
        )

        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

    def init_weights(self, buffer_device=None):
        # TODO(jianiw): replace placeholder with real weight init
        # Placeholder: uniform init on every parameter; not a principled scheme.
        for param in self.parameters():
            param.data.uniform_(0, 0.1)

    def forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor | None = None,
    ) -> Tensor:
        """Predict the flow for packed image latents conditioned on text.

        Args:
            img: packed image latents, (bsz, img_seq_len, in_channels).
            img_ids: positional ids for the image tokens.
            txt: text (T5) encodings, (bsz, txt_seq_len, context_in_dim).
            txt_ids: positional ids for the text tokens.
            timesteps: diffusion timesteps, (bsz,).
            y: pooled (CLIP) text vector, (bsz, vec_in_dim).
            guidance: guidance strengths, (bsz,); required when
                ``guidance_embed`` is enabled.

        Returns:
            Tensor of shape (bsz, img_seq_len, out_channels).
        """
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.model_args.guidance_embed:
            if guidance is None:
                raise ValueError(
                    "Didn't get guidance strength for guidance distilled model."
                )
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        # Text ids come first, so text tokens occupy the leading positions of
        # the fused sequence below.
        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        for block in self.double_blocks:
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)

        # Fuse text + image streams for the single-stream blocks, then strip
        # the text prefix before the final projection.
        img = torch.cat((txt, img), 1)
        for block in self.single_blocks:
            img = block(img, vec=vec, pe=pe)
        img = img[:, txt.shape[1] :, ...]

        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
        return img

    @classmethod
    def from_model_args(cls, model_args: FluxModelArgs) -> "FluxModel":
        """
        Initialize a Flux model from a FluxModelArgs object.

        Args:
            model_args (FluxModelArgs): Model configuration arguments.

        Returns:
            FluxModel: FluxModel model.

        """
        return cls(model_args)
torchtitan/experiments/flux/parallelize_flux.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # This file applies the PT-D parallelisms (except pipeline parallelism) and various
8
+ # training techniques (e.g. activation checkpointing and compile) to the Flux model.
9
+
10
+
11
+ import torch.nn as nn
12
+
13
+ from torch.distributed.device_mesh import DeviceMesh
14
+
15
+ from torchtitan.config_manager import JobConfig
16
+ from torchtitan.distributed import ParallelDims
17
+
18
+
19
def parallelize_flux(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
    """Apply PT-D parallelisms and training techniques to the Flux model.

    Currently a placeholder: no parallelization is applied and the model is
    handed back unchanged.
    """
    # TODO: Add model parallel strategy here
    return model
torchtitan/experiments/flux/requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ transformers
2
+ einops
torchtitan/experiments/flux/scripts/download_autoencoder.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Optional
8
+
9
+ from requests.exceptions import HTTPError
10
+
11
+
12
def hf_download(
    repo_id: str, file_path: str, local_dir: str, hf_token: Optional[str] = None
) -> None:
    """Download a single file from a HuggingFace Hub repository.

    Args:
        repo_id: Hub repository id (e.g. "black-forest-labs/FLUX.1-dev").
        file_path: path of the file within the repository.
        local_dir: local directory to place the downloaded file in.
        hf_token: optional HF API token for private/gated repositories.

    A 401 response is reported with a hint about ``--hf_token``; any other
    HTTP error is re-raised.
    """
    from huggingface_hub import hf_hub_download

    try:
        hf_hub_download(
            repo_id=repo_id,
            filename=file_path,
            local_dir=local_dir,
            local_dir_use_symlinks=False,
            token=hf_token,
        )
    except HTTPError as e:
        if e.response.status_code == 401:
            print(
                "You need to pass a valid `--hf_token=...` to download private checkpoints."
            )
        else:
            # Bare re-raise preserves the original traceback (unlike `raise e`).
            raise
32
+
33
+
34
if __name__ == "__main__":
    import argparse

    # CLI wrapper around hf_download; defaults target the Flux-dev autoencoder.
    parser = argparse.ArgumentParser(description="Download tokenizer from HuggingFace.")
    cli_options = (
        (
            "--repo_id",
            "black-forest-labs/FLUX.1-dev",
            "Repository ID to download from. default to Flux-dev model",
        ),
        (
            "--ae_path",
            "ae.safetensors",
            "the autoencoder path relative to repo_id",
        ),
        ("--hf_token", None, "HuggingFace API token"),
        (
            "--local_dir",
            "torchtitan/experiments/flux/assets/autoencoder/",
            "local directory to save the autoencoder",
        ),
    )
    for flag, default, help_text in cli_options:
        parser.add_argument(flag, type=str, default=default, help=help_text)

    args = parser.parse_args()
    hf_download(args.repo_id, args.ae_path, args.local_dir, args.hf_token)
torchtitan/experiments/flux/tests/test_flux_dataloader.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import sys
8
+
9
+ from torchtitan.config_manager import JobConfig
10
+ from torchtitan.experiments.flux.dataset.flux_dataset import build_flux_dataloader
11
+ from torchtitan.tools.profiling import (
12
+ maybe_enable_memory_snapshot,
13
+ maybe_enable_profiling,
14
+ )
15
+
16
+
17
class TestFluxDataLoader:
    """Smoke test for the Flux dataloader.

    NOTE(review): this pulls the real "cc12m" dataset and HF encoder configs,
    so it needs network/dataset access — it is an integration test, not a
    unit test.
    """

    def test_flux_dataloader(self):
        """Iterate a few batches and sanity-check the returned shapes."""
        dataset_name = "cc12m"
        batch_size = 32
        world_size = 4
        rank = 0

        num_steps = 10

        # Register the Flux-specific CLI arguments before parsing the config.
        path = "torchtitan.experiments.flux.flux_argparser"
        sys.argv.append(f"--experimental.custom_args_module={path}")
        config = JobConfig()
        config.maybe_add_custom_args()
        config.parse_args(
            [
                # Profiling options
                # "--profiling.enable_profiling",
                # "--profiling.profile_freq",
                # "5",
                # "--profiling.enable_memory_snapshot",
                # "--profiling.save_memory_snapshot_folder",
                # "memory_snapshot_flux",
                "--training.dataset",
                dataset_name,
                "--training.batch_size",
                str(batch_size),
                "--encoder.t5_encoder",
                "google/t5-v1_1-small",
                "--encoder.clip_encoder",
                "openai/clip-vit-large-patch14",
                "--encoder.max_t5_encoding_len",
                "512",
            ]
        )

        # Profilers are no-ops unless enabled via the (commented) flags above.
        with maybe_enable_profiling(
            config, global_step=0
        ) as torch_profiler, maybe_enable_memory_snapshot(
            config, global_step=0
        ) as memory_profiler:
            dl = self._build_dataloader(
                config,
                world_size,
                rank,
            )
            dl = iter(dl)

            for i in range(0, num_steps):
                input_data, labels = next(dl)
                print(f"Step {i} image size: {labels.shape}")
                if torch_profiler:
                    torch_profiler.step()
                if memory_profiler:
                    memory_profiler.step()

                print(len(input_data["clip_tokens"]))
                for k, v in input_data.items():
                    print(f"Step {i} {k} value: {type(v), v.shape}")

                assert len(input_data) == 2  # (clip_encodings, t5_encodings)
                assert labels.shape == (batch_size, 3, 256, 256)
                # assert input_data["clip_tokens"].shape[0] == batch_size
                # assert input_data["t5_tokens"].shape == (batch_size, 512, 512)

            if torch_profiler:
                torch_profiler.step()
            if memory_profiler:
                memory_profiler.step(exit_ctx=True)

    def test_preprocess(self):
        # TODO
        pass

    def _build_dataloader(
        self,
        job_config,
        world_size,
        rank,
    ):
        # Non-infinite loader so the iteration above terminates with the data.
        return build_flux_dataloader(
            dp_world_size=world_size,
            dp_rank=rank,
            job_config=job_config,
            tokenizer=None,
            infinite=False,
        )
torchtitan/experiments/flux/train.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+ from typing import Optional
9
+
10
+ import torch
11
+
12
+ from torchtitan.config_manager import JobConfig
13
+ from torchtitan.distributed import utils as dist_utils
14
+ from torchtitan.experiments.flux.model.autoencoder import load_ae
15
+ from torchtitan.experiments.flux.model.hf_embedder import FluxEmbedder
16
+ from torchtitan.experiments.flux.model.model import FluxModel
17
+ from torchtitan.experiments.flux.utils import (
18
+ create_position_encoding_for_latents,
19
+ pack_latents,
20
+ preprocess_flux_data,
21
+ unpack_latents,
22
+ )
23
+ from torchtitan.tools.logging import init_logger, logger
24
+ from torchtitan.train import Trainer
25
+
26
+
27
class FluxTrainer(Trainer):
    """Trainer specialization for the Flux flow-matching model.

    Extends the generic torchtitan ``Trainer`` with Flux-specific data
    preprocessing (image autoencoder + CLIP/T5 text encoders) and a
    flow-matching training step that regresses the model output onto
    ``noise - latents``.
    """

    def __init__(self, job_config: JobConfig):
        super().__init__(job_config)

        self.preprocess_fn = preprocess_flux_data
        # self.dtype = job_config.encoder.dtype
        # NOTE(review): dtype is currently hard-coded to bfloat16 (see the
        # commented line above) — confirm this should come from the config.
        self._dtype = torch.bfloat16
        self._seed = job_config.training.seed
        self._guidance = job_config.training.guidance

        # load components
        # Encoders stay off the training device; train_step's preprocess call
        # moves them on and back off (offload=True).
        model_config = self.train_spec.config[job_config.model.flavor]
        self.autoencoder = load_ae(
            job_config.encoder.auto_encoder_path,
            model_config.autoencoder_params,
            device="cpu",
            dtype=self._dtype,
        )
        self.clip_encoder = FluxEmbedder(version=job_config.encoder.clip_encoder).to(
            dtype=self._dtype
        )
        self.t5_encoder = FluxEmbedder(version=job_config.encoder.t5_encoder).to(
            dtype=self._dtype
        )

    def _predict_noise(
        self,
        model: FluxModel,
        latents: torch.Tensor,
        clip_encodings: torch.Tensor,
        t5_encodings: torch.Tensor,
        timesteps: torch.Tensor,
        guidance: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Use Flux's flow-matching model to predict the noise in image latents.
        Args:
            model (FluxModel): The Flux flow model.
            latents (Tensor): Image encodings from the Flux autoencoder.
                Shape: [bsz, 16, latent height, latent width]
            clip_encodings (Tensor): CLIP text encodings.
                Shape: [bsz, 768]
            t5_encodings (Tensor): T5 text encodings.
                Shape: [bsz, sequence length, 256 or 512]
            timesteps (Tensor): The amount of noise (0 to 1).
                Shape: [bsz]
            guidance (Optional[Tensor]): The guidance value (1.5 to 4) if guidance-enabled model.
                Shape: [bsz]
                Default: None
        Returns:
            Tensor: The noise prediction.
                Shape: [bsz, 16, latent height, latent width]
        """
        bsz, _, latent_height, latent_width = latents.shape

        POSITION_DIM = 3  # constant for Flux flow model
        with torch.no_grad():
            # Create positional encodings
            latent_pos_enc = create_position_encoding_for_latents(
                bsz, latent_height, latent_width, POSITION_DIM
            )
            # Text tokens get all-zero position ids.
            text_pos_enc = torch.zeros(bsz, t5_encodings.shape[1], POSITION_DIM)

            # Convert latent into a sequence of patches
            latents = pack_latents(latents)

            # Predict noise; `.to(latents)` matches device and dtype of the
            # packed latents for every conditioning tensor.
            latent_noise_pred = model(
                img=latents,
                img_ids=latent_pos_enc.to(latents),
                txt=t5_encodings.to(latents),
                txt_ids=text_pos_enc.to(latents),
                y=clip_encodings.to(latents),
                timesteps=timesteps.to(latents),
                guidance=guidance.to(latents) if guidance is not None else None,
            )

            # Convert sequence of patches to latent shape
            latent_noise_pred = unpack_latents(
                latent_noise_pred, latent_height, latent_width
            )

        return latent_noise_pred

    def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor):
        """One flow-matching optimization step.

        ``labels`` arrives as raw images; preprocessing turns them into
        autoencoder latents and produces the CLIP/T5 text encodings.
        """
        # generate t5 and clip
        input_dict["image"] = labels
        input_dict = self.preprocess_fn(
            device=self.device,
            dtype=self._dtype,
            autoencoder=self.autoencoder,
            clip_encoder=self.clip_encoder,
            t5_encoder=self.t5_encoder,
            batch=input_dict,
            offload=True,
        )
        labels = input_dict["img_encodings"]

        self.optimizers.zero_grad()

        # Keep these variables local to shorten the code as these are
        # the major variables that are used in the training loop.
        model_parts = self.model_parts
        world_mesh = self.world_mesh
        parallel_dims = self.parallel_dims

        # image in latent space transformed by self.auto_encoder
        clip_encodings = input_dict["clip_encodings"]
        t5_encodings = input_dict["t5_encodings"]

        bsz = labels.shape[0]

        with torch.no_grad():
            # Sample a per-example noise level and linearly interpolate
            # between clean latents and noise (rectified-flow style input).
            noise = torch.randn_like(labels)
            timesteps = torch.rand((bsz,)).to(labels)
            sigmas = timesteps.view(-1, 1, 1, 1)
            noisy_latents = (1 - sigmas) * labels + sigmas * noise
            guidance = torch.full((bsz,), self._guidance).to(labels)

        # Flow-matching target: the velocity from latents toward noise.
        target = noise - labels

        assert len(model_parts) == 1
        # TODO(jianiw): model_parts will be wrapped by FSDP, which will cacluate
        model_parts[0] = model_parts[0].to(dtype=self._dtype)

        pred = self._predict_noise(
            model_parts[0],
            noisy_latents,
            clip_encodings,
            t5_encodings,
            timesteps,
            guidance,
        )
        loss = self.loss_fn(pred, target)
        # pred has latent shape [bsz, 16, latent height, latent width]
        # need to free to before bwd to avoid peaking memory
        del (pred, noise, target)
        loss.backward()

        dist_utils.clip_grad_norm_(
            [p for m in model_parts for p in m.parameters()],
            self.job_config.training.max_norm,
            foreach=True,
            pp_mesh=self.world_mesh["pp"] if parallel_dims.pp_enabled else None,
        )
        self.checkpointer.maybe_wait_for_staging()
        self.optimizers.step()
        self.lr_schedulers.step()

        # log metrics
        if not self.metrics_processor.should_log(self.step):
            return

        if (
            parallel_dims.dp_replicate_enabled
            or parallel_dims.dp_shard_enabled
            or parallel_dims.cp_enabled
        ):
            # Reduce loss across the data/context-parallel group for logging.
            loss = loss.detach()
            global_avg_loss, global_max_loss = (
                dist_utils.dist_mean(loss, world_mesh["dp_cp"]),
                dist_utils.dist_max(loss, world_mesh["dp_cp"]),
            )
        else:
            global_avg_loss = global_max_loss = loss.item()

        self.metrics_processor.log(self.step, global_avg_loss, global_max_loss)
196
+
197
+
198
if __name__ == "__main__":
    init_logger()
    config = JobConfig()
    config.maybe_add_custom_args()
    config.parse_args()
    trainer: Optional[FluxTrainer] = None

    try:
        trainer = FluxTrainer(config)
        if config.checkpoint.create_seed_checkpoint:
            # Fix: the original asserted `int(os.environ["WORLD_SIZE"])`,
            # which is truthy for ANY nonzero world size; the intent (per the
            # message) is to require exactly one device.
            assert (
                int(os.environ["WORLD_SIZE"]) == 1
            ), "Must create seed checkpoint using a single device, to disable sharding."
            assert (
                config.checkpoint.enable_checkpoint
            ), "Must enable checkpointing when creating a seed checkpoint."
            trainer.checkpointer.save(curr_step=0, force=True)
            logger.info("Created seed checkpoint")
        else:
            trainer.train()
    finally:
        if trainer:
            trainer.close()

        if torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()
            logger.info("Process group destroyed.")
torchtitan/experiments/flux/train_configs/debug_model.toml ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ [job]
3
+ dump_folder = "./outputs"
4
+ description = "Flux debug model"
5
+ print_args = false
6
+ use_for_integration_test = true
7
+
8
+ [profiling]
9
+ enable_profiling = false
10
+ save_traces_folder = "profile_trace"
11
+ profile_freq = 10
12
+ enable_memory_snapshot = false
13
+ save_memory_snapshot_folder = "memory_snapshot"
14
+
15
+ [metrics]
16
+ log_freq = 1
17
+ disable_color_printing = false
18
+ enable_tensorboard = false
19
+ save_tb_folder = "tb"
20
+ enable_wandb = false
21
+
22
+ [model]
23
+ name = "flux"
24
+ flavor = "flux-debug"
25
+ norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm
26
+ # test tokenizer.model, for debug purpose only
27
+ # tokenizer_path = "./tests/assets/test_tiktoken.model"
28
+ # converters = "float8"
29
+
30
+
31
+ [optimizer]
32
+ name = "AdamW"
33
+ lr = 8e-4
34
+ eps = 1e-8
35
+
36
+ [lr_scheduler]
37
+ warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps
38
+ decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps
39
+ decay_type = "linear"
40
+ lr_min = 0.0
41
+
42
+ [training]
43
+ batch_size = 32
44
+ seq_len = 512
45
+ max_norm = 1.0 # grad norm clipping
46
+ steps = 10
47
+ compile = false
48
+ dataset = "cc12m"
49
+ guidance = 3.5
50
+ seed = 0
51
+
52
+ [encoder]
53
+ t5_encoder="google/t5-v1_1-small"
54
+ clip_encoder="openai/clip-vit-large-patch14"
55
+ max_t5_encoding_len=512
56
+ auto_encoder_path="torchtitan/experiments/flux/assets/autoencoder/ae.safetensors" # autoencoder checkpoint used to encode images into latents
57
+
58
+ [parallelism]
59
+ data_parallel_replicate_degree = 1
60
+ data_parallel_shard_degree = 1
61
+ fsdp_reshard_after_forward = "default" # default / never / always
62
+ tensor_parallel_degree = 1
63
+ enable_async_tensor_parallel = false
64
+ pipeline_parallel_degree = 1
65
+ context_parallel_degree = 1
66
+
67
+ [experimental]
68
+ custom_args_module = "torchtitan.experiments.flux.flux_argparser"
torchtitan/experiments/flux/utils.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Optional
8
+
9
+ import torch
10
+
11
+ from torch import Tensor
12
+
13
+ from torchtitan.experiments.flux.model.autoencoder import AutoEncoder
14
+ from torchtitan.experiments.flux.model.hf_embedder import FluxEmbedder
15
+
16
+
17
def preprocess_flux_data(
    # arguments from the recipe
    device: torch.device,
    dtype: torch.dtype,
    *,
    # arguments from the config
    autoencoder: Optional[AutoEncoder],
    clip_encoder: FluxEmbedder,
    t5_encoder: FluxEmbedder,
    batch: dict[str, Tensor],
    offload: bool = False,
) -> dict[str, Tensor]:
    """
    Take a batch of inputs and encoders as input and return a batch of preprocessed data.

    Adds "clip_encodings" and "t5_encodings" keys to ``batch`` (and
    "img_encodings" when ``autoencoder`` is given). The batch dict is
    mutated in place and also returned.

    Args:
        device (torch.device): device to do preprocessing on
        dtype (torch.dtype): data type to do preprocessing in
        autoencoder (Optional[AutoEncoder]): autoencoder used to encode
            batch["image"] into latents; skipped when None
        clip_encoder (FluxEmbedder): CLIP text encoder for batch["clip_tokens"]
        t5_encoder (FluxEmbedder): T5 text encoder for batch["t5_tokens"]
        batch (dict[str, Tensor]): batch of data to preprocess
        offload (bool): move encoders to ``device`` before encoding and back
            to CPU afterwards

    Returns:
        dict[str, Tensor]: batch of preprocessed data
    """

    # The input of encoder should be torch.int type
    if offload:
        clip_encoder.to(device)
        t5_encoder.to(device)
        if autoencoder is not None:
            autoencoder.to(device)

    # NOTE(review): .squeeze() drops ALL size-1 dims — with batch size 1 this
    # would also remove the batch dimension; confirm intended token shapes.
    clip_tokens = batch["clip_tokens"].squeeze().to(device=device, dtype=torch.int)
    t5_tokens = batch["t5_tokens"].squeeze().to(device=device, dtype=torch.int)

    clip_text_encodings = clip_encoder(clip_tokens)
    t5_text_encodings = t5_encoder(t5_tokens)

    if autoencoder is not None:
        images = batch["image"].to(device=device, dtype=dtype)
        img_encodings = autoencoder.encode(images)
        batch["img_encodings"] = img_encodings.to(device=device, dtype=dtype)

    batch["clip_encodings"] = clip_text_encodings.to(dtype)
    batch["t5_encodings"] = t5_text_encodings.to(dtype)

    # offload encoders to cpu after preprocessing
    if offload:
        clip_encoder.to("cpu")
        t5_encoder.to("cpu")
        if autoencoder is not None:
            autoencoder.to("cpu")

    return batch
73
+
74
+
75
+ def generate_noise_latent(
76
+ bsz: int,
77
+ height: int,
78
+ width: int,
79
+ device: str | torch.device,
80
+ dtype: torch.dtype,
81
+ seed: int,
82
+ ) -> Tensor:
83
+ """Generate noise latents for the Flux flow model.
84
+
85
+ Args:
86
+ bsz (int): batch_size.
87
+ height (int): The height of the image.
88
+ width (int): The width of the image.
89
+ device (str | torch.device): The device to use.
90
+ dtype (torch.dtype): The dtype to use.
91
+ seed (int): The seed to use for randomize.
92
+
93
+ Returns:
94
+ Tensor: The noise latents.
95
+ Shape: [num_samples, LATENT_CHANNELS, height // IMG_LATENT_SIZE_RATIO, width // IMG_LATENT_SIZE_RATIO]
96
+
97
+ """
98
+ LATENT_CHANNELS, IMAGE_LATENT_SIZE_RATIO = 16, 8
99
+ return torch.randn(
100
+ bsz,
101
+ LATENT_CHANNELS,
102
+ height // IMAGE_LATENT_SIZE_RATIO,
103
+ width // IMAGE_LATENT_SIZE_RATIO,
104
+ dtype=dtype,
105
+ generator=torch.Generator().manual_seed(seed),
106
+ ).to(device)
107
+
108
+
109
def create_position_encoding_for_latents(
    bsz: int, latent_height: int, latent_width: int, position_dim: int = 3
) -> Tensor:
    """
    Create the packed latents' position encodings for the Flux flow model.

    Channel 1 holds the patch row index and channel 2 the patch column index;
    channel 0 is left at zero.

    Args:
        bsz (int): The batch size.
        latent_height (int): The height of the latent.
        latent_width (int): The width of the latent.
        position_dim (int): Number of position channels (default 3).

    Returns:
        Tensor of shape
        [bsz, (latent_height // PATCH_HEIGHT) * (latent_width // PATCH_WIDTH), position_dim].
    """
    PATCH_HEIGHT, PATCH_WIDTH = 2, 2

    n_rows = latent_height // PATCH_HEIGHT
    n_cols = latent_width // PATCH_WIDTH

    grid = torch.zeros(n_rows, n_cols, position_dim)
    grid[:, :, 1] = torch.arange(n_rows).unsqueeze(1)
    grid[:, :, 2] = torch.arange(n_cols).unsqueeze(0)

    # Flatten the spatial grid and replicate per batch element:
    # [n_rows, n_cols, D] -> [bsz, n_rows * n_cols, D]
    flat = grid.view(1, n_rows * n_cols, position_dim)
    return flat.repeat(bsz, 1, 1)
143
+
144
+
145
def pack_latents(x: Tensor) -> Tensor:
    """
    Rearrange latents from an image-like layout into a sequence of patches.
    Equivalent to `einops.rearrange("b c (h ph) (w pw) -> b (h w) (c ph pw)")`.

    Args:
        x (Tensor): The unpacked latents.
            Shape: [bsz, ch, latent height, latent width]

    Returns:
        Tensor: The packed latents.
            Shape: (bsz, (latent_height // ph) * (latent_width // pw), ch * ph * pw)
    """
    PATCH_HEIGHT, PATCH_WIDTH = 2, 2

    bsz, ch, lat_h, lat_w = x.shape
    n_h = lat_h // PATCH_HEIGHT
    n_w = lat_w // PATCH_WIDTH

    # Slice 2x2 windows from the spatial dims: (b, c, h, w, ph, pw).
    patches = x.unfold(2, PATCH_HEIGHT, PATCH_HEIGHT).unfold(
        3, PATCH_WIDTH, PATCH_WIDTH
    )

    # Move channels inside each patch: (b, h, w, c, ph, pw).
    patches = patches.permute(0, 2, 3, 1, 4, 5)

    # Collapse patch grid into a sequence and patch contents into features.
    return patches.reshape(bsz, n_h * n_w, ch * PATCH_HEIGHT * PATCH_WIDTH)
172
+
173
+
174
def unpack_latents(x: Tensor, latent_height: int, latent_width: int) -> Tensor:
    """
    Rearrange latents from a sequence of patches back to an image-like layout.
    Equivalent to `einops.rearrange("b (h w) (c ph pw) -> b c (h ph) (w pw)")`.
    Inverse of ``pack_latents``.

    Args:
        x (Tensor): The packed latents.
            Shape: (bsz, (latent_height // ph) * (latent_width // pw), ch * ph * pw)
        latent_height (int): The height of the unpacked latents.
        latent_width (int): The width of the unpacked latents.

    Returns:
        Tensor: The unpacked latents.
            Shape: [bsz, ch, latent height, latent width]
    """
    PATCH_HEIGHT, PATCH_WIDTH = 2, 2

    bsz, _, packed_features = x.shape
    n_h = latent_height // PATCH_HEIGHT
    n_w = latent_width // PATCH_WIDTH
    ch = packed_features // (PATCH_HEIGHT * PATCH_WIDTH)

    # Expand sequence/feature dims back into the patch grid:
    # (b, h*w, c*ph*pw) -> (b, h, w, c, ph, pw)
    patches = x.reshape(bsz, n_h, n_w, ch, PATCH_HEIGHT, PATCH_WIDTH)

    # Interleave patch offsets with the grid: (b, c, h, ph, w, pw).
    patches = patches.permute(0, 3, 1, 4, 2, 5)

    # Merge (h, ph) and (w, pw) into full spatial dims.
    return patches.reshape(bsz, ch, n_h * PATCH_HEIGHT, n_w * PATCH_WIDTH)
torchtitan/experiments/kernels/triton_mg_group_gemm/simpleMoE.py ADDED
@@ -0,0 +1,885 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import argparse
8
+ import logging
9
+ import math
10
+ import time
11
+
12
+ from typing import Dict, List, Tuple
13
+
14
+ # import numpy as np
15
+ import torch #
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ import torch.optim as optim
19
+
20
+ # from torchao_pr.mg_grouped_gemm import mg_grouped_gemm
21
+
22
+ # Configure logging
23
+ logging.basicConfig(
24
+ level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
25
+ )
26
+
27
+ # Try to import the optimized MG GEMM implementation
28
+ try:
29
+ from torchao_pr.mg_grouped_gemm import ( # grouped_gemm_backward,
30
+ grouped_gemm_forward,
31
+ )
32
+
33
+ has_mg_gemm = True
34
+ except ImportError:
35
+ logging.warning("MG GEMM implementation not found. Will use manual looping only.")
36
+ has_mg_gemm = False
37
+
38
+
39
class Router(nn.Module):
    """
    Router module that assigns each token to its top-k experts.
    """

    def __init__(self, input_dim: int, num_experts: int, top_k: int = 2):
        super().__init__()
        self.input_dim = input_dim
        self.num_experts = num_experts
        self.top_k = top_k

        # Routing layer: produces one logit per expert for every token
        self.router = nn.Linear(input_dim, num_experts)

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
        """
        Route input tokens to experts.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, input_dim)

        Returns:
            Tuple containing:
                - router_logits: raw routing logits of shape
                  (batch_size * seq_len, num_experts)
                - dispatch_tensor: dense (tokens, experts) tensor holding the
                  normalized top-k routing weights; zero for unselected experts
                - expert_indices: per-expert 1-D index tensors listing the
                  tokens routed to that expert
        """
        # Flatten batch and sequence dimensions
        x_flat = x.reshape(-1, self.input_dim)  # (batch_size * seq_len, input_dim)

        # Compute routing probabilities
        router_logits = self.router(x_flat)  # (batch_size * seq_len, num_experts)
        router_probs = F.softmax(router_logits, dim=-1)

        # Select top-k experts per token and renormalize their weights to sum to 1
        top_k_probs, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)
        top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)

        # Scatter the normalized weights back into a dense (tokens, experts)
        # tensor; entries for experts outside the top-k stay zero.
        dispatch_tensor = torch.zeros_like(router_probs)
        dispatch_tensor.scatter_(1, top_k_indices, top_k_probs)

        # For each expert, collect the indices of tokens routed to it
        expert_indices = [
            torch.nonzero(dispatch_tensor[:, expert_idx] > 0, as_tuple=True)[0]
            for expert_idx in range(self.num_experts)
        ]

        return router_logits, dispatch_tensor, expert_indices
102
+
103
+
104
class Expert(nn.Module):
    """
    A single feed-forward expert: linear -> GELU -> linear, without biases.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim, bias=False)
        self.activation = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, output_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the two-layer MLP to the last dimension of x."""
        return self.fc2(self.activation(self.fc1(x)))
120
+
121
+
122
class MixtureOfExperts(nn.Module):
    """
    Mixture of Experts layer with support for both manual looping and grouped GEMM.

    Two execution paths share the same interface:
      - forward_manual_loop: one nn.Module per expert, dispatched in Python.
      - forward_mg_gemm: all expert weights fused into two big parameter
        tensors, multiplied with the Triton grouped-GEMM kernel.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_experts: int,
        top_k: int = 2,
        use_mg_gemm: bool = False,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.num_experts = num_experts
        self.top_k = top_k
        # Silently fall back to the loop path if the kernel failed to import
        self.use_mg_gemm = use_mg_gemm and has_mg_gemm

        # Router
        self.router = Router(input_dim, num_experts, top_k)

        # Create expert modules
        if self.use_mg_gemm:
            # For MG GEMM, we need a single weight tensor for all experts
            # (experts stacked along dim 0), scaled like nn.Linear's default init.
            # First layer (input -> hidden)
            self.expert_fc1_weight = nn.Parameter(
                torch.randn(num_experts * hidden_dim, input_dim) / math.sqrt(input_dim)
            )
            # self.expert_fc1_bias = nn.Parameter(torch.zeros(num_experts * hidden_dim))

            # Second layer (hidden -> output)
            self.expert_fc2_weight = nn.Parameter(
                torch.randn(num_experts * output_dim, hidden_dim)
                / math.sqrt(hidden_dim)
            )
            # self.expert_fc2_bias = nn.Parameter(torch.zeros(num_experts * output_dim))
        else:
            # For manual looping, create separate experts
            self.experts = nn.ModuleList(
                [Expert(input_dim, hidden_dim, output_dim) for _ in range(num_experts)]
            )

    def forward_manual_loop(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass using manual looping over experts.

        Args:
            x (torch.Tensor): Input of shape (batch_size, seq_len, input_dim)

        Returns:
            Tuple of:
                - output: tensor of shape (batch_size, seq_len, output_dim)
                - router_logits: raw router logits for the auxiliary loss
        """
        batch_size, seq_len, _ = x.shape
        x_flat = x.reshape(-1, self.input_dim)  # (batch_size * seq_len, input_dim)

        # Get routing information
        router_logits, dispatch_tensor, expert_indices = self.router(x)

        # Initialize output tensor
        final_output = torch.zeros(
            batch_size * seq_len, self.output_dim, device=x.device
        )

        # Process each expert
        for expert_idx, indices in enumerate(expert_indices):
            if indices.numel() > 0:
                # Get tokens routed to this expert
                expert_inputs = x_flat[indices]  # (num_tokens_for_expert, input_dim)

                # Process tokens through expert
                expert_outputs = self.experts[expert_idx](
                    expert_inputs
                )  # (num_tokens_for_expert, output_dim)

                # Scale outputs by router probabilities
                scaled_outputs = expert_outputs * dispatch_tensor[
                    indices, expert_idx
                ].unsqueeze(1)

                # Add to final output (index_add_ accumulates tokens selected
                # by multiple experts when top_k > 1)
                final_output.index_add_(0, indices, scaled_outputs)

        # Reshape back to original dimensions
        output = final_output.reshape(batch_size, seq_len, self.output_dim)

        return output, router_logits

    def forward_mg_gemm(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass using the Triton grouped-GEMM kernel.

        Tokens are first gathered into one contiguous buffer ordered by expert
        (group sizes given by m_sizes), passed through two grouped GEMMs with
        a GELU in between, then scattered back to their original positions.

        NOTE(review): the per-expert column slicing below assumes
        grouped_gemm_forward concatenates every expert's output columns for
        each row — confirm against the kernel's actual output layout.
        """
        batch_size, seq_len, _ = x.shape
        x_flat = x.reshape(-1, self.input_dim)  # (batch_size * seq_len, input_dim)
        total_tokens = batch_size * seq_len

        # Get routing information
        router_logits, dispatch_tensor, expert_indices = self.router(x)

        # Get token counts for each expert (group sizes for the grouped GEMM)
        token_counts = [indices.numel() for indices in expert_indices]
        m_sizes = torch.tensor(token_counts, dtype=torch.int32, device=x.device)

        print(f"Token counts per expert: {token_counts}")
        print(f"m_sizes: {m_sizes}")

        # Create the combined input tensor: tokens grouped contiguously by expert
        combined_input = torch.zeros(sum(token_counts), self.input_dim, device=x.device)

        start_idx = 0
        for expert_idx, indices in enumerate(expert_indices):
            if indices.numel() > 0:
                end_idx = start_idx + indices.numel()
                combined_input[start_idx:end_idx] = x_flat[indices]
                start_idx = end_idx

        print(f"combined_input shape: {combined_input.shape}")

        # First layer: input -> hidden
        fc1_weight_reshaped = self.expert_fc1_weight.reshape(
            self.num_experts, self.hidden_dim, self.input_dim
        )
        fc1_weight_combined = fc1_weight_reshaped.reshape(-1, self.input_dim)

        print(f"fc1_weight_combined shape: {fc1_weight_combined.shape}")

        # Run the grouped GEMM
        hidden_outputs = grouped_gemm_forward(
            combined_input, fc1_weight_combined, m_sizes
        )

        print(f"hidden_outputs shape after first GEMM: {hidden_outputs.shape}")

        # Apply activation
        hidden_outputs = F.gelu(hidden_outputs)

        print(f"hidden_outputs shape after activation: {hidden_outputs.shape}")

        # Second layer: hidden -> output
        # Reshape hidden_outputs to match expected dimensions: keep, for each
        # group of rows, only the column block belonging to that group's expert
        reshaped_hidden_outputs = []
        start_idx = 0

        for expert_idx, count in enumerate(token_counts):
            if count > 0:
                end_idx = start_idx + count
                # Take this expert's outputs and reshape to [count, hidden_dim]
                expert_output = hidden_outputs[
                    start_idx:end_idx,
                    expert_idx * self.hidden_dim : (expert_idx + 1) * self.hidden_dim,
                ]
                reshaped_hidden_outputs.append(expert_output)
                start_idx = end_idx

        # Concatenate all reshaped outputs
        hidden_outputs = torch.cat(reshaped_hidden_outputs, dim=0)

        # Reshape expert weights for second layer
        fc2_weight_reshaped = self.expert_fc2_weight.reshape(
            self.num_experts, self.output_dim, self.hidden_dim
        )
        fc2_weight_combined = fc2_weight_reshaped.reshape(-1, self.hidden_dim)

        print(f"fc2_weight_combined shape: {fc2_weight_combined.shape}")

        # Run the second grouped GEMM
        expert_outputs_combined = grouped_gemm_forward(
            hidden_outputs, fc2_weight_combined, m_sizes
        )

        # Initialize final output tensor with correct shape
        final_output = torch.zeros(total_tokens, self.output_dim, device=x.device)

        # Distribute the outputs back to the original token positions
        start_idx = 0
        for expert_idx, indices in enumerate(expert_indices):
            if indices.numel() > 0:
                end_idx = start_idx + indices.numel()
                # Get this expert's outputs
                expert_outputs = expert_outputs_combined[start_idx:end_idx]

                print(
                    f"Expert {expert_idx} - indices shape: {indices.shape}, expert_outputs shape: {expert_outputs.shape}"
                )

                # Scale outputs by router probabilities
                scaled_outputs = expert_outputs * dispatch_tensor[
                    indices, expert_idx
                ].unsqueeze(1)

                # Ensure dimensions match before using index_add_
                # (defensive truncation — presumably the second GEMM can return
                # wider rows than output_dim; verify against the kernel)
                if scaled_outputs.shape[1] != final_output.shape[1]:
                    # print(
                    #     f"Reshaping: Dimension mismatch: scaled_outputs {scaled_outputs.shape}, final_output {final_output.shape}"
                    # )
                    # Reshape if needed - make sure output_dim is correct
                    scaled_outputs = scaled_outputs[:, : self.output_dim]

                # Add to final output
                final_output.index_add_(0, indices, scaled_outputs)

                start_idx = end_idx

        # Reshape back to original dimensions
        output = final_output.reshape(batch_size, seq_len, self.output_dim)

        return output, router_logits

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Dispatch to the grouped-GEMM path only when the kernel imported successfully
        if self.use_mg_gemm and has_mg_gemm:
            return self.forward_mg_gemm(x)
        else:
            return self.forward_manual_loop(x)
329
+
330
+
331
class MoEModel(nn.Module):
    """
    Minimal language-model-style network built around a single MoE layer:
    token embedding -> mixture of experts -> vocabulary projection.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_dim: int,
        hidden_dim: int,
        num_experts: int,
        top_k: int = 2,
        use_mg_gemm: bool = False,
    ):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.moe_layer = MixtureOfExperts(
            input_dim=embed_dim,
            hidden_dim=hidden_dim,
            output_dim=embed_dim,
            num_experts=num_experts,
            top_k=top_k,
            use_mg_gemm=use_mg_gemm,
        )
        self.output_layer = nn.Linear(embed_dim, vocab_size)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Map token ids of shape (batch_size, seq_len) to vocabulary logits.

        Returns:
            Tuple of:
                - logits: (batch_size, seq_len, vocab_size)
                - router_logits: raw router logits for the auxiliary loss
        """
        token_embeddings = self.embedding(x)  # (batch_size, seq_len, embed_dim)
        moe_output, router_logits = self.moe_layer(token_embeddings)
        vocab_logits = self.output_layer(moe_output)
        return vocab_logits, router_logits
365
+
366
+
367
def compute_load_balancing_loss(
    router_logits: torch.Tensor, num_experts: int
) -> torch.Tensor:
    """
    Compute the load balancing auxiliary loss for MoE training.

    The penalty is the sum of squared per-expert probability fractions scaled
    by num_experts: it equals 1.0 when routing is perfectly uniform and grows
    toward num_experts as routing collapses onto a single expert.

    Args:
        router_logits (torch.Tensor): Router logits of shape
            (batch_size * seq_len, num_experts)
        num_experts (int): Number of experts

    Returns:
        torch.Tensor: Scalar load balancing loss
    """
    # Get per-token router probabilities
    router_probs = F.softmax(
        router_logits, dim=-1
    )  # (batch_size * seq_len, num_experts)

    # Fraction of total probability mass assigned to each expert
    router_probs_sum = router_probs.sum(dim=0)  # (num_experts,)
    router_probs_sum = router_probs_sum / router_probs_sum.sum()

    # Squared-fraction penalty; minimized by uniform routing across experts
    return num_experts * torch.sum(router_probs_sum * router_probs_sum)
398
+
399
+
400
def generate_sample_data(
    batch_size: int, seq_len: int, vocab_size: int, device: str = "cuda"
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Draw one random (input, target) token batch for the synthetic task.

    Args:
        batch_size (int): Batch size
        seq_len (int): Sequence length
        vocab_size (int): Vocabulary size
        device (str): Device to allocate the tensors on

    Returns:
        Tuple of input tokens and target tokens, each of shape
        (batch_size, seq_len) with values in [0, vocab_size).
    """
    shape = (batch_size, seq_len)

    def draw() -> torch.Tensor:
        # Uniform random token ids over the vocabulary
        return torch.randint(0, vocab_size, shape, device=device)

    return draw(), draw()
422
+
423
+
424
def train_epoch(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    batch_size: int,
    seq_len: int,
    vocab_size: int,
    num_batches: int,
    device: str,
    load_balance_coef: float = 0.01,
) -> Dict[str, float]:
    """
    Run one training epoch over freshly sampled random batches.

    Args:
        model (nn.Module): Model to train
        optimizer (torch.optim.Optimizer): Optimizer
        batch_size (int): Batch size
        seq_len (int): Sequence length
        vocab_size (int): Vocabulary size
        num_batches (int): Number of batches per epoch
        device (str): Device to use
        load_balance_coef (float): Coefficient for the load balancing loss

    Returns:
        Dict with averaged "loss" and "acc" plus wall-clock "time" in seconds
    """
    model.train()
    epoch_start = time.time()
    loss_sum = 0.0
    acc_sum = 0.0
    tokens_per_batch = batch_size * seq_len

    for batch_idx in range(num_batches):
        inputs, targets = generate_sample_data(batch_size, seq_len, vocab_size, device)

        optimizer.zero_grad()
        logits, router_logits = model(inputs)

        # Flatten to (tokens, vocab) / (tokens,) for cross entropy
        flat_logits = logits.reshape(-1, vocab_size)
        flat_targets = targets.reshape(-1)

        ce_loss = F.cross_entropy(flat_logits, flat_targets)
        # Auxiliary loss nudging the router toward uniform expert usage
        lb_loss = compute_load_balancing_loss(
            router_logits, model.moe_layer.num_experts
        )
        loss = ce_loss + load_balance_coef * lb_loss

        loss.backward()
        optimizer.step()

        # Token-level accuracy for this batch
        correct = (flat_logits.argmax(dim=-1) == flat_targets).float().sum()
        acc = correct / tokens_per_batch

        loss_sum += loss.item()
        acc_sum += acc.item()

        # Progress log every 10 batches
        if (batch_idx + 1) % 10 == 0:
            logging.info(
                f"Batch {batch_idx + 1}/{num_batches} | "
                f"Loss: {loss.item():.4f} | "
                f"CE Loss: {ce_loss.item():.4f} | "
                f"LB Loss: {lb_loss.item():.4f} | "
                f"Acc: {acc.item():.4f}"
            )

    return {
        "loss": loss_sum / num_batches,
        "acc": acc_sum / num_batches,
        "time": time.time() - epoch_start,
    }
508
+
509
+
510
def evaluate(
    model: nn.Module,
    batch_size: int,
    seq_len: int,
    vocab_size: int,
    num_batches: int,
    device: str,
) -> Dict[str, float]:
    """
    Evaluate the model on freshly sampled random batches (no gradients).

    Args:
        model (nn.Module): Model to evaluate
        batch_size (int): Batch size
        seq_len (int): Sequence length
        vocab_size (int): Vocabulary size
        num_batches (int): Number of batches for evaluation
        device (str): Device to use

    Returns:
        Dict with averaged "loss" and "acc" over all batches
    """
    model.eval()
    loss_sum = 0.0
    acc_sum = 0.0
    tokens_per_batch = batch_size * seq_len

    with torch.no_grad():
        for _ in range(num_batches):
            inputs, targets = generate_sample_data(
                batch_size, seq_len, vocab_size, device
            )

            logits, _router_logits = model(inputs)

            # Flatten to (tokens, vocab) / (tokens,) for cross entropy
            flat_logits = logits.reshape(-1, vocab_size)
            flat_targets = targets.reshape(-1)

            loss = F.cross_entropy(flat_logits, flat_targets)

            # Token-level accuracy for this batch
            correct = (flat_logits.argmax(dim=-1) == flat_targets).float().sum()
            acc = correct / tokens_per_batch

            loss_sum += loss.item()
            acc_sum += acc.item()

    return {"loss": loss_sum / num_batches, "acc": acc_sum / num_batches}
567
+
568
+
569
def _maybe_cuda_sync() -> None:
    """Wait for outstanding CUDA work before/after timing; no-op on CPU-only hosts."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def measure_performance(
    model: nn.Module,
    batch_size: int,
    seq_len: int,
    vocab_size: int,
    num_batches: int,
    device: str,
) -> Dict[str, float]:
    """
    Measure forward and backward pass performance.

    Args:
        model (nn.Module): Model to evaluate
        batch_size (int): Batch size
        seq_len (int): Sequence length
        vocab_size (int): Vocabulary size
        num_batches (int): Number of batches for measurement
        device (str): Device to use

    Returns:
        Dict with "forward_time", "backward_time" and "total_time",
        each as a per-batch average in milliseconds
    """
    model.train()

    # Create dummy optimizer (only used to clear gradients between iterations)
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Warmup so lazy initialization / kernel compilation is excluded from timing
    for _ in range(5):
        inputs, targets = generate_sample_data(batch_size, seq_len, vocab_size, device)
        logits, router_logits = model(inputs)
        loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))
        loss.backward()
        optimizer.zero_grad()

    # Measure forward pass time.
    # NOTE: torch.cuda.synchronize() was previously called unconditionally,
    # which raises on CPU-only machines even though --device cpu is supported.
    _maybe_cuda_sync()
    forward_start = time.time()

    for _ in range(num_batches):
        inputs, targets = generate_sample_data(batch_size, seq_len, vocab_size, device)
        with torch.no_grad():
            logits, router_logits = model(inputs)

    _maybe_cuda_sync()
    forward_end = time.time()
    forward_time = (forward_end - forward_start) / num_batches

    # Measure backward pass time (includes the forward needed to build the graph)
    _maybe_cuda_sync()
    backward_start = time.time()

    for _ in range(num_batches):
        inputs, targets = generate_sample_data(batch_size, seq_len, vocab_size, device)
        logits, router_logits = model(inputs)
        loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))
        loss.backward()
        optimizer.zero_grad()

    _maybe_cuda_sync()
    backward_end = time.time()
    backward_time = (backward_end - backward_start) / num_batches

    return {
        "forward_time": forward_time * 1000,  # Convert to ms
        "backward_time": backward_time * 1000,  # Convert to ms
        "total_time": (forward_time + backward_time) * 1000,  # Convert to ms
    }
637
+
638
+
639
def compare_methods(args):
    """
    Compare manual looping and MG GEMM implementations.

    Builds one model per implementation, times forward/backward passes with
    measure_performance, and logs a side-by-side report plus speedups.
    """
    device = torch.device(args.device)

    def build_model(use_mg: bool) -> MoEModel:
        # Construct an MoE model with the CLI configuration on the target device
        return MoEModel(
            vocab_size=args.vocab_size,
            embed_dim=args.embed_dim,
            hidden_dim=args.hidden_dim,
            num_experts=args.num_experts,
            top_k=args.top_k,
            use_mg_gemm=use_mg,
        ).to(device)

    manual_model = build_model(False)
    mg_model = build_model(True) if has_mg_gemm else None

    # Measure performance
    logging.info("Measuring performance of manual looping method...")
    manual_perf = measure_performance(
        manual_model,
        args.batch_size,
        args.seq_len,
        args.vocab_size,
        args.perf_batches,
        device,
    )

    if mg_model is not None:
        logging.info("Measuring performance of MG GEMM method...")
        mg_perf = measure_performance(
            mg_model,
            args.batch_size,
            args.seq_len,
            args.vocab_size,
            args.perf_batches,
            device,
        )
    else:
        mg_perf = {"forward_time": 0, "backward_time": 0, "total_time": 0}

    # Log results
    logging.info("\n===== Performance Comparison =====")
    logging.info("Model Configuration:")
    for label, value in (
        ("Batch Size", args.batch_size),
        ("Sequence Length", args.seq_len),
        ("Embed Dimension", args.embed_dim),
        ("Hidden Dimension", args.hidden_dim),
        ("Number of Experts", args.num_experts),
        ("Top-K", args.top_k),
    ):
        logging.info(f" - {label}: {value}")
    logging.info("")

    def report_timings(title: str, perf: Dict[str, float]) -> None:
        # Emit one timing section of the comparison report
        logging.info(f"{title}:")
        logging.info(f" - Forward Time: {perf['forward_time']:.2f} ms")
        logging.info(f" - Backward Time: {perf['backward_time']:.2f} ms")
        logging.info(f" - Total Time: {perf['total_time']:.2f} ms")
        logging.info("")

    report_timings("Manual Looping Method", manual_perf)

    if mg_model is not None:
        report_timings("MG GEMM Method", mg_perf)

        def speedup(key: str) -> float:
            # Guard against division by zero when a timing is missing
            return manual_perf[key] / mg_perf[key] if mg_perf[key] > 0 else 0

        logging.info("Speedup (MG GEMM vs Manual):")
        logging.info(f" - Forward Speedup: {speedup('forward_time'):.2f}x")
        logging.info(f" - Backward Speedup: {speedup('backward_time'):.2f}x")
        logging.info(f" - Total Speedup: {speedup('total_time'):.2f}x")
    else:
        logging.info("MG GEMM method not available.")
738
+
739
+
740
def train_model(args):
    """
    Train an MoE model.

    Builds the model from CLI arguments, then alternates train_epoch and
    evaluate for args.epochs epochs, logging metrics after each.
    """
    device = torch.device(args.device)
    use_mg = args.use_mg_gemm and has_mg_gemm

    # Create model
    model = MoEModel(
        vocab_size=args.vocab_size,
        embed_dim=args.embed_dim,
        hidden_dim=args.hidden_dim,
        num_experts=args.num_experts,
        top_k=args.top_k,
        use_mg_gemm=use_mg,
    ).to(device)

    # Create optimizer
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Log model information
    logging.info("Model configuration:")
    for line in (
        f" - Vocabulary Size: {args.vocab_size}",
        f" - Embedding Dimension: {args.embed_dim}",
        f" - Hidden Dimension: {args.hidden_dim}",
        f" - Number of Experts: {args.num_experts}",
        f" - Top-K: {args.top_k}",
        f" - Using MG GEMM: {use_mg}",
    ):
        logging.info(line)

    # Training loop
    for epoch in range(args.epochs):
        logging.info(f"\nEpoch {epoch + 1}/{args.epochs}")

        # Train
        train_metrics = train_epoch(
            model=model,
            optimizer=optimizer,
            batch_size=args.batch_size,
            seq_len=args.seq_len,
            vocab_size=args.vocab_size,
            num_batches=args.train_batches,
            device=device,
            load_balance_coef=args.load_balance_coef,
        )

        # Evaluate
        eval_metrics = evaluate(
            model=model,
            batch_size=args.batch_size,
            seq_len=args.seq_len,
            vocab_size=args.vocab_size,
            num_batches=args.eval_batches,
            device=device,
        )

        # Log metrics
        logging.info(
            f"Train Loss: {train_metrics['loss']:.4f} | Train Acc: {train_metrics['acc']:.4f}"
        )
        logging.info(
            f"Eval Loss: {eval_metrics['loss']:.4f} | Eval Acc: {eval_metrics['acc']:.4f}"
        )
        logging.info(f"Epoch Time: {train_metrics['time']:.2f} seconds")
802
+
803
+
804
if __name__ == "__main__":
    # CLI entry point: parse arguments, then either benchmark the two MoE
    # implementations (--compare, the default) or train a model (--train).
    parser = argparse.ArgumentParser(description="Train MoE model")

    # Model parameters
    parser.add_argument("--vocab_size", type=int, default=10000, help="Vocabulary size")
    parser.add_argument(
        "--embed_dim", type=int, default=512, help="Embedding dimension"
    )
    parser.add_argument(
        "--hidden_dim", type=int, default=1024, help="Hidden dimension in experts"
    )
    parser.add_argument("--num_experts", type=int, default=8, help="Number of experts")
    parser.add_argument(
        "--top_k", type=int, default=2, help="Top-k experts to route to"
    )

    # Training parameters
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
    parser.add_argument("--seq_len", type=int, default=128, help="Sequence length")
    parser.add_argument("--epochs", type=int, default=3, help="Number of epochs")
    parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
    parser.add_argument(
        "--train_batches",
        type=int,
        default=100,
        help="Number of training batches per epoch",
    )
    parser.add_argument(
        "--eval_batches", type=int, default=20, help="Number of evaluation batches"
    )
    parser.add_argument(
        "--perf_batches",
        type=int,
        default=50,
        help="Number of batches for performance testing",
    )
    parser.add_argument(
        "--load_balance_coef",
        type=float,
        default=0.01,
        help="Load balancing loss coefficient",
    )

    # Runtime parameters
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to use (cuda or cpu)",
    )
    parser.add_argument(
        "--use_mg_gemm",
        action="store_true",
        help="Use MG GEMM implementation if available",
    )
    parser.add_argument(
        "--compare",
        action="store_true",
        help="Compare manual and MG GEMM implementations",
    )
    parser.add_argument("--train", action="store_true", help="Train the model")

    args = parser.parse_args()

    # Check for CUDA: an explicit --device cuda is downgraded with a warning
    if args.device == "cuda" and not torch.cuda.is_available():
        logging.warning("CUDA not available, using CPU instead.")
        args.device = "cpu"

    # Log basic information
    logging.info(f"PyTorch version: {torch.__version__}")
    logging.info(f"Device: {args.device}")
    logging.info(f"MG GEMM available: {has_mg_gemm}")

    # Run the requested action (--compare wins if both flags are given)
    if args.compare:
        compare_methods(args)
    elif args.train:
        train_model(args)
    else:
        # Default to comparison if no action specified
        compare_methods(args)
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/mg_grouped_gemm.py ADDED
@@ -0,0 +1,1304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # credit - flat index forward kernel is derived from FBGemm:
8
+ # https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm
9
+
10
+ # pyre-unsafe
11
+ import functools
12
+ import logging
13
+
14
+ import os
15
+ import sys
16
+ from typing import Any, Dict, Optional, Tuple
17
+
18
+ import torch
19
+
20
+ import triton
21
+ import triton.language as tl
22
+ from triton import Config as TConfig
23
+
24
+ from triton.runtime import driver # @manual
25
+
26
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
27
+
28
+ from tma_autotuning import (
29
+ ALIGN_SIZE_M,
30
+ _NV_CONFIGS,
31
+ CudaUtils,
32
+ early_config_prune,
33
+ TmaDescriptorHelper,
34
+ )
35
+
36
+
37
+ # Configure logging
38
+ logging.basicConfig(
39
+ level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
40
+ )
41
+
42
+ # ============== Start Triton Kernels ===============
43
+
44
+
@triton.autotune(
    configs=_NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_mg_forward_hopper(
    a_desc_ptr,  # TMA load descriptor for inputs x, logically [M_total, K]
    b_desc_ptr,  # TMA load descriptor for weights, logically [N, K]
    c_ptr,  # raw output pointer; a TMA store descriptor is built per group
    workspace,  # per-program scratch for on-device TMA store descriptors
    m_sizes,  # int tensor [G]: number of rows belonging to each group
    # problem sizes
    G: tl.constexpr,
    M_BUCKET: tl.constexpr,  # autotune cache key only (not used in the body)
    N: tl.constexpr,
    K: tl.constexpr,
    # config
    NUM_SMS: tl.constexpr,
    TMA_SIZE: tl.constexpr,
    USE_EPILOGUE_SUBTILING: tl.constexpr,
    # tiles
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
) -> None:
    """
    Flat index style forward kernel for Hopper.
    For simplicity, we always use TMA Load and TMA Store

    Persistent-kernel scheme: each program (one per SM) owns the flat tile
    indices tbidx, tbidx + NUM_SMS, tbidx + 2*NUM_SMS, ... across all groups.
    Groups are laid out contiguously along M.

    NOTE(review): this variant treats the weight as G stacked matrices of
    n_size = N // G rows each (see `n_start = n_size * g`), so the output row
    stride is N // G, not N — confirm against the caller's expected shape.
    """
    tbidx = tl.program_id(0)  # thread block index

    c_dtype = c_ptr.dtype.element_ty  # output dtype

    c_desc_ptr = workspace + (tbidx * TMA_SIZE)  # for TMA Store

    M_end = 0
    M_start = 0
    processed_tiles = 0
    # Size of individual weight matrix
    n_size = N // G
    n_start = 0

    for g in range(G):
        # Move down along groups
        # reset to new M offset
        M_start = M_end
        m_size = tl.load(m_sizes + g)
        M_end = M_start + m_size
        n_start = n_size * g

        if m_size > 0:
            # Process this group

            # Acquire hold on c_desc_ptr for TMA Store
            # (descriptor is group-local: based at this group's first row)
            tl.extra.cuda.experimental_device_tensormap_create2d(
                desc_ptr=c_desc_ptr,
                global_address=c_ptr + M_start * n_size,
                load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
                global_size=[m_size, n_size],
                element_ty=c_dtype,
            )
            tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)

            # tiles for this group
            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
            num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
            group_num_tiles = num_m_tiles * num_n_tiles

            # Claim every flat tile index owned by this program that falls
            # inside this group's tile range.
            while tbidx >= processed_tiles and tbidx < (
                processed_tiles + group_num_tiles
            ):
                group_index = tbidx - processed_tiles

                # columnwise
                tile_m_index = group_index % num_m_tiles
                tile_n_index = group_index // num_m_tiles

                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

                m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32)
                n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)
                global_n_offset = (n_start + n_offset).to(tl.int32)

                for k_offset in range(0, K, BLOCK_SIZE_K):
                    # input block [M,K]
                    a = tl._experimental_descriptor_load(
                        a_desc_ptr,
                        [m_offset, k_offset],
                        [BLOCK_SIZE_M, BLOCK_SIZE_K],
                        c_dtype,
                    )
                    # weight block [N, K]
                    b = tl._experimental_descriptor_load(
                        b_desc_ptr,
                        [global_n_offset, k_offset],
                        [BLOCK_SIZE_N, BLOCK_SIZE_K],
                        c_dtype,
                    )

                    # y = x @ w.T, hence the transposed weight tile
                    accumulator += tl.dot(a, b.T)

                # Store using TMA

                # The store descriptor is group-local, so re-base the row offset
                m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)

                if USE_EPILOGUE_SUBTILING:
                    # Split the tile and issue two half-width stores
                    # (presumably to reduce epilogue register pressure — TODO confirm)
                    acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2))
                    acc = tl.permute(acc, (0, 2, 1))
                    acc0, acc1 = tl.split(acc)
                    c0 = acc0.to(c_dtype)
                    tl._experimental_descriptor_store(
                        c_desc_ptr, c0, [m_offset, n_offset]
                    )
                    c1 = acc1.to(c_dtype)
                    tl._experimental_descriptor_store(
                        c_desc_ptr, c1, [m_offset, n_offset + BLOCK_SIZE_N // 2]
                    )
                else:
                    tl._experimental_descriptor_store(
                        c_desc_ptr,
                        accumulator.to(c_dtype),
                        [m_offset, n_offset],
                    )
                # move to next tile in group
                tbidx += NUM_SMS
            # Update the total tiles count for the next group
            processed_tiles += group_num_tiles
174
+
@triton.autotune(
    configs=_NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_mg_forward_tma(
    a_desc_ptr,  # TMA load descriptor for inputs x, logically [M_total, K]
    b_desc_ptr,  # TMA load descriptor for weights, logically [N, K]
    c_ptr,  # raw output pointer [M_total, N]
    workspace,  # per-program scratch for on-device TMA store descriptors
    m_sizes,  # int tensor [G]: number of rows belonging to each group
    a_scale_ptr,  # NOTE(review): not referenced in this kernel body (FP8 scale?)
    b_scale_ptr,  # NOTE(review): not referenced in this kernel body
    # problem sizes
    G: tl.constexpr,
    M_BUCKET: tl.constexpr,  # autotune cache key only
    N: tl.constexpr,
    K: tl.constexpr,
    # config
    NUM_SMS: tl.constexpr,
    USE_TMA_LOAD: tl.constexpr,  # NOTE(review): unused — loads are always TMA here
    USE_TMA_STORE: tl.constexpr,  # NOTE(review): unused — stores are always TMA here
    TMA_SIZE: tl.constexpr,
    USE_FP8: tl.constexpr,  # NOTE(review): not referenced in this kernel body
    # tiles
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
) -> None:
    """
    Flat index style forward kernel.
    For simplicity, we always use TMA Load and TMA Store

    Persistent-kernel scheme identical to the Hopper variant, except every
    group multiplies against the full [N, K] weight (n_size = N), so the
    output is [M_total, N].
    """
    tbidx = tl.program_id(0)  # thread block index

    c_dtype = c_ptr.dtype.element_ty

    c_desc_ptr = workspace + (tbidx * TMA_SIZE)

    M_end = 0
    processed_tiles = 0

    for g in range(G):
        # Move down along groups
        # reset to new M offset
        M_start = M_end
        m_size = tl.load(m_sizes + g)
        M_end = M_start + m_size

        if m_size > 0:
            # Process this group
            n_size = N

            # TMA Store prep (descriptor based at this group's first row)
            tl.extra.cuda.experimental_device_tensormap_create2d(
                desc_ptr=c_desc_ptr,
                global_address=c_ptr + M_start * N,
                load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
                global_size=[m_size, n_size],
                element_ty=c_dtype,
            )
            tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)

            # tiles for this group
            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
            num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
            group_num_tiles = num_m_tiles * num_n_tiles

            # Claim every flat tile index owned by this program in this group
            while tbidx >= processed_tiles and tbidx < (
                processed_tiles + group_num_tiles
            ):
                group_index = tbidx - processed_tiles

                tile_m_index = group_index % num_m_tiles
                tile_n_index = group_index // num_m_tiles

                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

                m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32)
                n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)

                for k_offset in range(0, K, BLOCK_SIZE_K):
                    # input block [M,K]
                    a = tl._experimental_descriptor_load(
                        a_desc_ptr,
                        [m_offset, k_offset],
                        [BLOCK_SIZE_M, BLOCK_SIZE_K],
                        c_dtype,
                    )
                    # weight block [N, K]
                    b = tl._experimental_descriptor_load(
                        b_desc_ptr,
                        [n_offset, k_offset],
                        [BLOCK_SIZE_N, BLOCK_SIZE_K],
                        c_dtype,
                    )

                    # y = x @ w.T, hence the transposed weight tile
                    accumulator += tl.dot(a, b.T)

                # Store using TMA

                # Store descriptor is group-local, so re-base the row offset
                m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)
                # n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)

                tl._experimental_descriptor_store(
                    c_desc_ptr,
                    accumulator.to(c_dtype),
                    [m_offset, n_offset],
                )

                # Move to the next tile
                tbidx += NUM_SMS
            # Update the total tiles count for the next group
            processed_tiles += group_num_tiles
291
+
@triton.autotune(
    configs=_NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_mg_forward_no_tma(
    a_ptr,  # inputs x, raw pointer [M_total, K]
    b_ptr,  # weights, raw pointer [N, K]
    c_ptr,  # output, raw pointer [M_total, N]
    workspace,  # unused in the non-TMA path (kept for a uniform launch signature)
    m_sizes,  # int tensor [G]: number of rows belonging to each group
    # problem sizes
    G: tl.constexpr,
    M_BUCKET: tl.constexpr,  # autotune cache key only
    N: tl.constexpr,
    K: tl.constexpr,
    # config
    NUM_SMS: tl.constexpr,
    USE_TMA_LOAD: tl.constexpr,  # unused: this variant never uses TMA
    USE_TMA_STORE: tl.constexpr,  # unused: this variant never uses TMA
    TMA_SIZE: tl.constexpr,
    # tiles
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
) -> None:
    """
    Flat index style forward kernel for devices without TMA (e.g. Ampere).
    Uses masked tl.load / tl.store instead of TMA descriptors.

    NOTE: the K loop applies no mask along K, so K must be a multiple of
    BLOCK_SIZE_K (the same assumption the TMA variants make).
    """
    tbidx = tl.program_id(0)  # thread block index

    c_dtype = c_ptr.dtype.element_ty

    M_end = 0
    processed_tiles = 0

    for g in range(G):
        # Move down along groups: reset to this group's M offset
        M_start = M_end
        m_size = tl.load(m_sizes + g)
        M_end = M_start + m_size

        if m_size > 0:
            # Process this group
            n_size = N

            # tiles for this group
            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
            num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
            group_num_tiles = num_m_tiles * num_n_tiles

            # Claim every flat tile index owned by this program in this group
            while tbidx >= processed_tiles and tbidx < (
                processed_tiles + group_num_tiles
            ):
                group_index = tbidx - processed_tiles

                tile_m_index = group_index % num_m_tiles
                tile_n_index = group_index // num_m_tiles

                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

                offs_am = tile_m_index * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                offs_bn = tile_n_index * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                offs_k = tl.arange(0, BLOCK_SIZE_K)

                a_ptrs = a_ptr + (M_start + offs_am[:, None]) * K + offs_k[None, :]
                b_ptrs = b_ptr + (offs_bn[:, None]) * K + offs_k[None, :]

                for k_offset in range(0, K, BLOCK_SIZE_K):
                    # Load with bounds checking; other=0.0 makes padded rows
                    # contribute nothing to the dot product (a bare masked
                    # load would feed undefined values into tl.dot).
                    a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size, other=0.0)
                    b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size, other=0.0)

                    # Main matmul: y = x @ w.T
                    accumulator += tl.dot(a, b.T)

                    # Advance along K for the next block
                    a_ptrs += BLOCK_SIZE_K
                    b_ptrs += BLOCK_SIZE_K

                # Store without TMA. Mask both dimensions with elementwise `&`
                # (Python `and` is not defined elementwise on Triton blocks).
                c = accumulator.to(c_dtype)

                tl.store(
                    c_ptr
                    + (M_start + offs_am[:, None]) * N  # Row stride is N
                    + offs_bn[None, :],  # Column offset
                    c,
                    mask=(offs_am[:, None] < m_size) & (offs_bn[None, :] < n_size),
                )
                # Move to the next tile
                tbidx += NUM_SMS
            # Update the total tiles count for the next group
            processed_tiles += group_num_tiles
397
+
"""
Backward pass for grouped GEMM with Triton, where the grouping is along M*G.
We compute gradients with respect to both the input (`grad_x`) and the weights (`grad_w`).
"""
402
+
403
+
404
+ # ---- dx flat linear indexed ----
@triton.autotune(
    configs=_NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_mg_dx_tma(
    grad_output_desc_ptr,  # [MG, N]
    w_desc_ptr,  # [N, K]
    grad_input_ptr,  # output grad_x [MG, K]
    workspace,  # for TMA store
    m_sizes,  # group sizes [G]
    # problem sizes
    G: tl.constexpr,
    M_BUCKET: tl.constexpr,  # autotune cache key only
    N: tl.constexpr,
    K: tl.constexpr,
    # config
    NUM_SMS: tl.constexpr,
    USE_TMA_LOAD: tl.constexpr,  # NOTE(review): unused — loads are always TMA here
    USE_TMA_STORE: tl.constexpr,  # NOTE(review): unused — stores are always TMA here
    TMA_SIZE: tl.constexpr,
    # tiles
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
) -> None:
    """
    TMA-optimized kernel for computing gradients with respect to input (dx).
    For the forward pass Y = X @ W.T, the backward for input is:
        grad_X = grad_Y @ W

    This maps to [MG, N] @ [N, K] -> [MG, K]

    Key differences from forward:
    1. W is used directly and not transposed
    2. The reduction dimension is now N (not K)
    3. Output is [M, K] instead of [M, N]
    """
    tbidx = tl.program_id(0)  # thread block index

    c_dtype = grad_input_ptr.dtype.element_ty
    c_desc_ptr = workspace + (tbidx * TMA_SIZE)  # this program's store descriptor

    M_end = 0
    processed_tiles = 0

    for g in range(G):
        # Move down along groups - same as forward
        M_start = M_end
        m_size = tl.load(m_sizes + g)
        M_end = M_start + m_size

        if m_size > 0:
            # Process this group
            # tiles for this group - now producing [M, K] output
            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
            num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
            group_num_tiles = num_m_tiles * num_k_tiles

            # TMA Store prep for [M, K] output
            # (descriptor based at this group's first output row)
            tl.extra.cuda.experimental_device_tensormap_create2d(
                desc_ptr=c_desc_ptr,
                global_address=grad_input_ptr + M_start * K,
                load_size=[BLOCK_SIZE_M, BLOCK_SIZE_K],
                global_size=[m_size, K],
                element_ty=c_dtype,
            )
            tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)

            # Claim every flat tile index owned by this program in this group
            while tbidx >= processed_tiles and tbidx < (
                processed_tiles + group_num_tiles
            ):
                group_index = tbidx - processed_tiles

                # Different tiling scheme for [M, K] output
                tile_m_index = group_index % num_m_tiles
                tile_k_index = group_index // num_m_tiles

                # for grad_input block [M, K]
                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)

                # Position in full matrix
                m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32)
                k_offset = (tile_k_index * BLOCK_SIZE_K).to(tl.int32)

                # reduce along N dimension (instead of K in forward)
                for n_offset in range(0, N, BLOCK_SIZE_N):
                    # grad_output block [M, N]
                    grad_output = tl._experimental_descriptor_load(
                        grad_output_desc_ptr,
                        [m_offset, n_offset],
                        [BLOCK_SIZE_M, BLOCK_SIZE_N],
                        c_dtype,
                    )

                    # weight block [N, K] - no transpose needed
                    w = tl._experimental_descriptor_load(
                        w_desc_ptr,
                        [n_offset, k_offset],
                        [BLOCK_SIZE_N, BLOCK_SIZE_K],
                        c_dtype,
                    )

                    # grad_x = grad_output @ w
                    # reducing along N dimension
                    accumulator += tl.dot(grad_output, w)

                # Store using TMA
                # Store descriptor is group-local, so re-base the row offset
                m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)
                # k_offset = (tile_k_index * BLOCK_SIZE_K).to(tl.int32)

                tl._experimental_descriptor_store(
                    c_desc_ptr,
                    accumulator.to(c_dtype),
                    [m_offset, k_offset],
                )

                # Move to the next tile
                tbidx += NUM_SMS

            # Update the total tiles count for the next group
            processed_tiles += group_num_tiles
529
+
530
+ # ---- dw flat linear indexed ----
531
+
532
+
@triton.autotune(
    configs=_NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_mg_dw_tma(
    x_desc_ptr,  # input descriptor [M_total, K] (raw pointer when USE_TMA_LOAD=False)
    grad_output_desc_ptr,  # grad_output descriptor [M_total, N] (raw pointer when USE_TMA_LOAD=False)
    grad_weight_ptr,  # output grad_w [N, K]
    workspace,  # TMA store descriptor for grad_w (unused when USE_TMA_STORE=False)
    m_sizes,  # group sizes [G]
    # problem sizes
    G: tl.constexpr,
    M_BUCKET: tl.constexpr,  # autotune cache key only
    N: tl.constexpr,
    K: tl.constexpr,
    # config
    NUM_SMS: tl.constexpr,
    USE_TMA_LOAD: tl.constexpr,
    USE_TMA_STORE: tl.constexpr,
    TMA_SIZE: tl.constexpr,
    # tiles
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,  # block size for reduction dimension
) -> None:
    """
    Improved TMA-optimized kernel for computing gradients with respect to weights (dw).
    Uses flat index structure similar to forward.

    For the forward pass Y = X @ W.T,
    the backward for weights is:
        grad_W = grad_Y.T @ X

    Where:
    - grad_Y is [MG, N]
    - X is [MG, K]
    - grad_W is [N, K]
    - we return [N,K]

    Unlike dx, every output tile reduces over rows from ALL groups, so each
    program owns whole [N, K] tiles and loops over groups internally.
    """
    # Get thread block index l
    tbidx = tl.program_id(0)

    # Get output data type
    c_dtype = grad_weight_ptr.dtype.element_ty

    # Calculate number of output tiles
    num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
    num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
    total_output_tiles = num_n_tiles * num_k_tiles

    # Process tiles in strided manner across SMs
    for tile_idx in range(tbidx, total_output_tiles, NUM_SMS):
        # Calculate tile indices
        tile_n_idx = tile_idx % num_n_tiles
        tile_k_idx = tile_idx // num_n_tiles

        # Calculate global offsets
        n_offset = tile_n_idx * BLOCK_SIZE_N
        k_offset = tile_k_idx * BLOCK_SIZE_K

        # Initialize accumulator for this output tile [N, K]
        accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype=tl.float32)

        # Process each group
        M_end = 0
        for g in range(G):
            # Get group boundaries
            M_start = M_end
            m_size = tl.load(m_sizes + g)
            M_end = M_start + m_size

            # Only process if group is non-empty
            if m_size > 0:
                # Process this group in chunks along the M dimension
                for m_offset in range(0, m_size, BLOCK_SIZE_M):
                    # Calculate actual block size (handling boundary)
                    m_block_size = tl.minimum(BLOCK_SIZE_M, m_size - m_offset)

                    # Only process if we have actual work to do
                    if m_block_size > 0:
                        # Global offset for this chunk
                        m_global_offset = M_start + m_offset

                        if USE_TMA_LOAD:
                            # Load input chunk [M_chunk, K] using TMA
                            x_block = tl._experimental_descriptor_load(
                                x_desc_ptr,
                                [m_global_offset, k_offset],
                                [BLOCK_SIZE_M, BLOCK_SIZE_K],
                                c_dtype,
                            )

                            # Load grad_output chunk [M_chunk, N] using TMA
                            grad_output_block = tl._experimental_descriptor_load(
                                grad_output_desc_ptr,
                                [m_global_offset, n_offset],
                                [BLOCK_SIZE_M, BLOCK_SIZE_N],
                                c_dtype,
                            )

                            # Apply masks for valid regions
                            offs_m = tl.arange(0, BLOCK_SIZE_M)
                            m_mask = offs_m < m_block_size

                            # Zero out rows past the chunk boundary so they do
                            # not contribute to the reduction
                            x_block = tl.where(m_mask[:, None], x_block, 0.0)
                            grad_output_block = tl.where(
                                m_mask[:, None], grad_output_block, 0.0
                            )
                        else:
                            # Manual load with bounds checking
                            # (x_desc_ptr / grad_output_desc_ptr are raw tensor
                            # pointers in this branch, not TMA descriptors)
                            offs_m = tl.arange(0, BLOCK_SIZE_M)
                            offs_n = tl.arange(0, BLOCK_SIZE_N)
                            offs_k = tl.arange(0, BLOCK_SIZE_K)

                            # Create masks
                            m_mask = offs_m < m_block_size
                            n_mask = offs_n < N - n_offset
                            k_mask = offs_k < K - k_offset

                            # Combined masks
                            mk_mask = m_mask[:, None] & k_mask[None, :]
                            mn_mask = m_mask[:, None] & n_mask[None, :]

                            # Global offsets for loading
                            m_global_offs = m_global_offset + offs_m

                            # Load x block [M_chunk, K]
                            x_block = tl.load(
                                x_desc_ptr
                                + m_global_offs[:, None] * K
                                + (k_offset + offs_k)[None, :],
                                mask=mk_mask,
                                other=0.0,
                            )

                            # Load grad_output block [M_chunk, N]
                            grad_output_block = tl.load(
                                grad_output_desc_ptr
                                + m_global_offs[:, None] * N
                                + (n_offset + offs_n)[None, :],
                                mask=mn_mask,
                                other=0.0,
                            )

                        # Compute partial contribution: grad_W += grad_Y.T @ X
                        # transpose grad_output for the matmul
                        contribution = tl.dot(
                            grad_output_block.to(tl.float32).T,  # [N, M_chunk]
                            x_block.to(tl.float32),  # [M_chunk, K]
                        )

                        # Accumulate
                        accumulator += contribution

        # Store the result
        if USE_TMA_STORE:
            # Store using TMA
            tl._experimental_descriptor_store(
                workspace,  # TMA store descriptor
                accumulator.to(c_dtype),
                [n_offset, k_offset],
            )
        else:
            # Manual store with bounds checking
            offs_n = tl.arange(0, BLOCK_SIZE_N)
            offs_k = tl.arange(0, BLOCK_SIZE_K)

            # Create masks for bounds checking
            n_mask = offs_n < N - n_offset
            k_mask = offs_k < K - k_offset
            output_mask = n_mask[:, None] & k_mask[None, :]

            # Store the result
            tl.store(
                grad_weight_ptr
                + (n_offset + offs_n)[:, None] * K
                + (k_offset + offs_k)[None, :],
                accumulator.to(c_dtype),
                mask=output_mask,
            )
716
+
717
+
718
+ # ======== End Triton kernels ========
719
+
720
+ # ======== Triton wrapper functions ========
721
+
722
+ # ----- main forward pass wrapper -----
723
+
724
+
def grouped_gemm_forward(
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    tma_size: int = 128,
) -> torch.Tensor:
    """
    M*G style grouped GEMM with TMA and Float8 support.
    # Removed for now - FP8 support is triggered by passing x_scale and w_scale tensors.

    Args:
        x: concatenated input rows for all groups, shape [M_total, K]
        w: weight tensor, shape [N, K]; the Hopper kernel treats it as G
           stacked [N // G, K] expert matrices
        m_sizes: rows per group, shape [G]
        tma_size: byte size of each TMA descriptor

    Returns:
        Output tensor of shape [M_total, N // G].

    Raises:
        NotImplementedError: if the device does not support TMA.
    """
    if not CudaUtils.verify_tma():
        raise NotImplementedError("Grouped GEMM without TMA is not supported yet")

    G = m_sizes.shape[0]

    assert x.is_contiguous()
    assert w.is_contiguous()
    assert m_sizes.is_contiguous()

    # Total input size is now [M_total, K] where M_total is the sum of all group sizes
    M_total, K = x.shape
    N = w.shape[0]  # N is now the same for all groups

    assert K == w.shape[1], f"Input K ({K}) must match weight K ({w.shape[1]})"

    # Verify that all group sizes are multiples of ALIGN_SIZE_M
    # This check is commented out because it will involve a GPU-CPU sync
    # assert torch.remainder(m_sizes, ALIGN_SIZE_M).max() == 0, "Group sizes must be a multiple of ALIGN_SIZE_M"

    # Output is [M_total, N // G]: each row only receives its own group's
    # expert output (the Hopper kernel slices w into G chunks of N // G rows).
    y = torch.empty((M_total, N // G), device=x.device, dtype=x.dtype)

    if M_total == 0:
        return y

    NUM_SMS = CudaUtils.get_num_sms()
    USE_TMA_LOAD = True
    USE_TMA_STORE = True
    USE_EPILOGUE_SUBTILING = False

    # TMA descriptor helper
    desc_helper = None
    desc_x = x
    desc_w = w
    workspace = None

    if USE_TMA_LOAD:
        desc_helper = TmaDescriptorHelper(tma_size=tma_size)
        desc_helper.init_tma_descriptor("x")
        desc_helper.init_tma_descriptor("w")
        desc_x = desc_helper.get_tma_descriptor_kernel_param("x")
        desc_w = desc_helper.get_tma_descriptor_kernel_param("w")

    if USE_TMA_STORE:
        # NOTE(review): assumes USE_TMA_LOAD is also True — desc_helper would
        # still be None here otherwise.
        workspace = torch.empty(
            NUM_SMS * desc_helper.tma_size,
            device=x.device,
            dtype=torch.uint8,
        )

    def grid(META):
        # Descriptors depend on the autotuned block sizes, so they are filled
        # lazily inside the grid callback.
        if USE_TMA_LOAD:
            nonlocal desc_helper
            desc_helper.fill_2d_tma_descriptor(
                "x",
                x.data_ptr(),
                M_total,
                K,
                META["BLOCK_SIZE_M"],
                META["BLOCK_SIZE_K"],
                x.element_size(),
            )

            desc_helper.fill_2d_tma_descriptor(
                "w",
                w.data_ptr(),
                N,
                K,
                META["BLOCK_SIZE_N"],
                META["BLOCK_SIZE_K"],
                w.element_size(),
            )
        return (NUM_SMS,)

    M_BUCKET = triton.next_power_of_2(M_total)

    # One persistent program per SM; each walks the flat tile index space.
    _kernel_mg_forward_hopper[grid](
        desc_x,
        desc_w,
        y,
        workspace,
        m_sizes,
        G,
        M_BUCKET,
        N,
        K,
        NUM_SMS,
        TMA_SIZE=tma_size,
        USE_EPILOGUE_SUBTILING=USE_EPILOGUE_SUBTILING,
    )

    return y
828
+
829
+
830
+ # ======== Improved Backward =============
def grouped_gemm_backward(
    grad_output: torch.Tensor,
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    use_tma: bool = True,
    tma_size: int = 128,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Unified backward pass for grouped GeMM with M*G grouping.
    Uses optimized TMA-based implementations for both dx and dw when available.

    Args:
        grad_output: Gradient of output, shape [M_total, N]
        x: Input tensor from forward pass, shape [M_total, K]
        w: Weight tensor from forward pass, shape [N, K]
        m_sizes: Group sizes tensor, shape [G]
        use_tma: Whether to try using TMA acceleration (if available)
        tma_size: Size of TMA descriptor in bytes

    Returns:
        Tuple of gradients with respect to x and w: (grad_x, grad_w)

    Raises:
        ValueError: if the tensor shapes are mutually inconsistent or
            m_sizes does not sum to M_total.
    """
    logging.info("Starting unified grouped_gemm_backward")

    # Query once up front; the SM count is reused by both kernel launches.
    NUM_SMS = CudaUtils.get_num_sms()

    # Basic shape validation
    M_total, K_x = x.shape
    M_grad, N = grad_output.shape
    N_w, K_w = w.shape

    if K_x != K_w:
        raise ValueError(f"K dimension mismatch: x has K={K_x}, w has K={K_w}")
    if M_total != M_grad:
        raise ValueError(
            f"M dimension mismatch: x has M={M_total}, grad_output has M={M_grad}"
        )

    # Check total M matches sum of group sizes (.item() forces a GPU-CPU sync)
    sum_m_sizes = m_sizes.sum().item()
    if M_total != sum_m_sizes:
        raise ValueError(
            f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})"
        )

    # The kernels require contiguous operands
    grad_output = grad_output.contiguous()
    x = x.contiguous()
    w = w.contiguous()
    m_sizes = m_sizes.contiguous()

    # Both paths below are TMA implementations and raise NotImplementedError
    # themselves if TMA is missing; we only log here so the error surfaces
    # from the kernel wrappers with full context.
    if use_tma and not CudaUtils.verify_tma():
        logging.info("TMA requested but not supported on this device")

    # grad_x = grad_output @ w, computed groupwise with the flat-index kernel
    try:
        logging.info("Computing grad_x with flat linear kernel")
        grad_x = grouped_gemm_dx_tma(
            grad_output=grad_output,
            w=w,
            m_sizes=m_sizes,
            num_sms=NUM_SMS,
            tma_size=tma_size,
        )
    except Exception:
        logging.exception("Error in grad_x computation")
        raise

    # grad_w = grad_output.T @ x, accumulated over all groups
    try:
        logging.info("Computing grad_w with flat linear kernel")
        grad_w = grouped_gemm_dw_tma(
            x, grad_output, m_sizes, num_sms=NUM_SMS, tma_size=tma_size
        )
    except Exception:
        logging.exception("Error in grad_w computation")
        raise

    return grad_x, grad_w
922
+
923
+
924
+ # ----- dx backward pass wrapper -----
925
+
926
+
def grouped_gemm_dx_tma(
    grad_output: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    num_sms: int = 132,
    tma_size: int = 128,
) -> torch.Tensor:
    """
    Optimized backward pass wrapper for computing gradient with respect to
    input (dx) using TMA patterns similar to the forward pass.

    For the forward pass Y = X @ W.T, the gradient with respect to the input
    is grad_X = grad_Y @ W, computed groupwise: [M_total, N] @ [N, K] ->
    [M_total, K].

    Args:
        grad_output: Gradient of output, shape [M_total, N]
        w: Weight tensor, shape [N, K]
        m_sizes: Group sizes tensor, shape [G]
        num_sms: Number of persistent programs to launch (default 132 = H100)
        tma_size: Size of TMA descriptor in bytes

    Returns:
        grad_x: Gradient with respect to x, shape [M_total, K]

    Raises:
        NotImplementedError: if the device does not support TMA.
    """
    if not CudaUtils.verify_tma():
        raise NotImplementedError("Optimized dx computation requires TMA support")

    G = m_sizes.shape[0]

    assert grad_output.is_contiguous()
    assert w.is_contiguous()
    assert m_sizes.is_contiguous()

    M_total, N_grad = grad_output.shape
    N_w, K = w.shape

    # Check dimensions
    assert N_grad == N_w, f"Grad_output N ({N_grad}) must match weight N ({N_w})"

    # Verify that the sum of m_sizes matches M_total (forces a GPU-CPU sync)
    sum_m_sizes = m_sizes.sum().item()
    assert (
        M_total == sum_m_sizes
    ), f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})"

    # Create output tensor (grad_x) with shape [M_total, K]
    grad_x = torch.empty(
        (M_total, K), device=grad_output.device, dtype=grad_output.dtype
    )

    NUM_SMS = num_sms  # CudaUtils.get_num_sms()
    USE_TMA_LOAD = True
    USE_TMA_STORE = True

    # Set up TMA descriptors
    desc_helper = TmaDescriptorHelper(tma_size=tma_size)
    desc_helper.init_tma_descriptor("grad_output")
    desc_helper.init_tma_descriptor("w")
    desc_grad_output = desc_helper.get_tma_descriptor_kernel_param("grad_output")
    desc_w = desc_helper.get_tma_descriptor_kernel_param("w")

    # Allocate workspace for the kernel's on-device TMA store descriptors
    workspace = torch.empty(
        NUM_SMS * desc_helper.tma_size,
        device=grad_output.device,
        dtype=torch.uint8,
    )

    def grid(META):
        # Fill TMA descriptors once the autotuner has chosen block sizes
        desc_helper.fill_2d_tma_descriptor(
            "grad_output",
            grad_output.data_ptr(),
            M_total,
            N_grad,
            META["BLOCK_SIZE_M"],
            META["BLOCK_SIZE_N"],
            grad_output.element_size(),
        )

        desc_helper.fill_2d_tma_descriptor(
            "w",
            w.data_ptr(),
            N_w,
            K,
            META["BLOCK_SIZE_N"],
            META["BLOCK_SIZE_K"],
            w.element_size(),
        )
        return (NUM_SMS,)

    M_BUCKET = triton.next_power_of_2(M_total)

    # Launch the flat linear kernel for computing grad_x
    _kernel_mg_dx_tma[grid](
        desc_grad_output,
        desc_w,
        grad_x,
        workspace,
        m_sizes,
        G,
        M_BUCKET,
        N_grad,  # N dimension is now the reduction dimension
        K,
        NUM_SMS,
        USE_TMA_LOAD,
        USE_TMA_STORE,
        TMA_SIZE=tma_size,
    )

    return grad_x
1053
+
1054
+
1055
+ # ======== dw wrapper function ==========
1056
+
1057
+
def grouped_gemm_dw_tma(
    x: torch.Tensor,
    grad_output: torch.Tensor,
    m_sizes: torch.Tensor,
    num_sms: int = 132,
    tma_size: int = 128,
) -> torch.Tensor:
    """
    Optimized flat linear kernel computation of gradients with respect to
    weights (dw), using TMA where the hardware supports it.

    For the forward pass Y = X @ W.T, the backward for weights is:
        grad_W = grad_Y.T @ X

    Args:
        x: Input tensor, shape [M_total, K]
        grad_output: Gradient of output, shape [M_total, N]
        m_sizes: Group sizes tensor, shape [G]
        num_sms: Number of persistent programs to launch (default 132 = H100)
        tma_size: Size of TMA descriptor in bytes

    Returns:
        grad_w: Gradient with respect to weights, shape [N, K]
    """
    # Check TMA support once; the kernel contains masked load/store fallback
    # paths for the non-TMA case.
    has_tma_support = CudaUtils.verify_tma()

    # Get group count
    G = m_sizes.shape[0]

    # Ensure contiguous tensors
    x = x.contiguous()
    grad_output = grad_output.contiguous()
    m_sizes = m_sizes.contiguous()

    # Get dimensions
    M_total, K_x = x.shape
    M_grad, N = grad_output.shape

    # Check dimensions
    assert M_total == M_grad, f"x M ({M_total}) must match grad_output M ({M_grad})"

    # Verify that the sum of m_sizes matches M_total (forces a GPU-CPU sync)
    sum_m_sizes = m_sizes.sum().item()
    assert (
        sum_m_sizes == M_total
    ), f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})"

    # Create output tensor (grad_w) with shape [N, K]; zero-initialized since
    # the kernel accumulates contributions across groups.
    grad_w = torch.zeros((N, K_x), device=x.device, dtype=x.dtype)

    NUM_SMS = num_sms

    # Use TMA only when the hardware actually supports it (resolves the
    # previous hardcoded-True TODO); on TMA hardware behavior is unchanged,
    # elsewhere the kernel's manual load/store fallback paths are taken.
    USE_TMA_LOAD = has_tma_support
    USE_TMA_STORE = has_tma_support

    # Set up TMA descriptors or direct pointers
    if USE_TMA_LOAD or USE_TMA_STORE:
        desc_helper = TmaDescriptorHelper(tma_size=tma_size)

        if USE_TMA_LOAD:
            desc_helper.init_tma_descriptor("x")
            desc_helper.init_tma_descriptor("grad_output")
            x_desc = desc_helper.get_tma_descriptor_kernel_param("x")
            grad_output_desc = desc_helper.get_tma_descriptor_kernel_param(
                "grad_output"
            )
        else:
            x_desc = x
            grad_output_desc = grad_output

        if USE_TMA_STORE:
            desc_helper.init_tma_descriptor("grad_w")
            workspace = desc_helper.get_tma_descriptor_kernel_param("grad_w")
        else:
            workspace = torch.empty(1, device=x.device, dtype=torch.uint8)
    else:
        # If not using TMA, just use the tensors directly
        x_desc = x
        grad_output_desc = grad_output
        workspace = torch.empty(1, device=x.device, dtype=torch.uint8)

    # M_BUCKET for grid size
    M_BUCKET = triton.next_power_of_2(M_total)

    # Define grid for kernel launch; descriptors are filled lazily because
    # they depend on the autotuned block sizes.
    def grid(META):
        if USE_TMA_LOAD or USE_TMA_STORE:

            if USE_TMA_LOAD:
                desc_helper.fill_2d_tma_descriptor(
                    "x",
                    x.data_ptr(),
                    M_total,
                    K_x,
                    META["BLOCK_SIZE_M"],
                    META["BLOCK_SIZE_K"],
                    x.element_size(),
                )

                desc_helper.fill_2d_tma_descriptor(
                    "grad_output",
                    grad_output.data_ptr(),
                    M_total,
                    N,
                    META["BLOCK_SIZE_M"],
                    META["BLOCK_SIZE_N"],
                    grad_output.element_size(),
                )

            if USE_TMA_STORE:
                desc_helper.fill_2d_tma_descriptor(
                    "grad_w",
                    grad_w.data_ptr(),
                    N,
                    K_x,
                    META["BLOCK_SIZE_N"],
                    META["BLOCK_SIZE_K"],
                    grad_w.element_size(),
                )

        # Return grid size - one block per SM for balanced work distribution
        return (NUM_SMS,)

    # Launch the optimized kernel
    _kernel_mg_dw_tma[grid](
        x_desc,
        grad_output_desc,
        grad_w,
        workspace,
        m_sizes,
        G,
        M_BUCKET,
        N,
        K_x,
        NUM_SMS,
        USE_TMA_LOAD,
        USE_TMA_STORE,
        TMA_SIZE=tma_size,
    )

    return grad_w
1199
+
1200
+
1201
+ # ======== End Backwards Wrapper Functions =============
1202
+
1203
+ # ======== PyTorch wrapper functions ========
1204
+
1205
+
1206
class GroupedGEMM_mg(torch.autograd.Function):
    """
    Autograd function for GroupedGEMM with M*G grouping.
    Supports both standard and FP8 quantized operations.
    """

    @staticmethod
    def forward(ctx, x, w, m_sizes, use_tma=True, tma_size=128, using_fp8=False):
        """
        Forward pass of GroupedGEMM.

        Args:
            x: Input tensor, shape [M_total, K]
            w: Weight tensor, shape [N, K]
            m_sizes: Tensor of shape [G] containing the size of each group
            use_tma: Whether to try using TMA acceleration (if available)
            tma_size: Size of TMA descriptor in bytes
            using_fp8: Whether to use FP8 quantization (accepted so that
                `mg_grouped_gemm` can forward it; quantization is not wired
                up in this path yet, so the flag is currently ignored)

        Returns:
            Output tensor, shape [M_total, N]
        """
        # NOTE: regular forward without quantization for now; `using_fp8`
        # is accepted but forced off until the FP8 path is wired up.
        output = grouped_gemm_forward(
            x=x, w=w, m_sizes=m_sizes, tma_size=tma_size, using_fp8=False
        )

        # Save inputs and parameters for the backward pass.
        # save_for_backward must be called exactly once per forward.
        ctx.save_for_backward(x, w, m_sizes)
        ctx.use_tma = use_tma
        ctx.tma_size = tma_size

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward pass of M*G GroupedGEMM.

        Args:
            grad_output: Gradient of output, shape [M_total, N]

        Returns:
            Tuple of gradients, one entry per forward input:
            - grad_x: Gradient with respect to x, shape [M_total, K]
            - grad_w: Gradient with respect to w, shape [N, K]
            - None: for m_sizes (not differentiable)
            - None: for use_tma (not differentiable)
            - None: for tma_size (not differentiable)
            - None: for using_fp8 (not differentiable)
        """
        # Retrieve saved tensors and parameters
        x, w, m_sizes = ctx.saved_tensors

        use_tma = ctx.use_tma
        tma_size = ctx.tma_size

        # Compute gradients using the unified implementation
        grad_x, grad_w = grouped_gemm_backward(
            grad_output=grad_output,
            x=x,
            w=w,
            m_sizes=m_sizes,
            use_tma=use_tma,
            tma_size=tma_size,
        )

        # autograd requires exactly one return value per forward input:
        # (x, w, m_sizes, use_tma, tma_size, using_fp8)
        return grad_x, grad_w, None, None, None, None
1279
+
1280
+
1281
def mg_grouped_gemm(
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    use_tma: bool = True,
    tma_size: int = 128,
    using_fp8: bool = False,
) -> torch.Tensor:
    """Differentiable grouped GEMM for M*G-grouped inputs.

    Thin functional wrapper around the ``GroupedGEMM_mg`` autograd Function,
    supporting standard precision and (optionally) FP8-quantized operation.

    Args:
        x: Input tensor of shape [M_total, K].
        w: Weight tensor of shape [N, K].
        m_sizes: Tensor of shape [G] holding each group's row count.
        use_tma: Attempt TMA acceleration when the hardware supports it.
        tma_size: TMA descriptor size in bytes.
        using_fp8: Request FP8 quantization.

    Returns:
        Output tensor of shape [M_total, N].
    """
    result = GroupedGEMM_mg.apply(x, w, m_sizes, use_tma, tma_size, using_fp8)
    return result
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/tma_autotuning.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # credit - TMAHelper class, AutoTuning are derived from FBGemm:
8
+ # https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm
9
+
10
+ # pyre-unsafe
11
+ import functools
12
+
13
+ import os
14
+ import sys
15
+ from typing import Any, Dict, Optional, Tuple
16
+
17
+ import torch
18
+
19
+ import triton
20
+ import triton.language as tl
21
+ from triton import Config as TConfig
22
+
23
+ from triton.runtime import driver # @manual
24
+
25
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
26
+
27
+
28
+ # ===== Supporting utils, CUDA and TMA =====
29
+
30
+
31
class CudaUtils:
    """Static helpers for querying the CUDA backend and current device."""

    @staticmethod
    def is_cuda() -> bool:
        """True when Triton's active driver targets the CUDA backend."""
        target = driver.active.get_current_target()
        return target.backend == "cuda"

    @staticmethod
    def verify_tma() -> bool:
        """True when TMA is supported (CUDA backend, compute capability 9+)."""
        if not CudaUtils.is_cuda():
            return False
        if not torch.cuda.is_available():
            return False
        major, _minor = torch.cuda.get_device_capability()
        return major >= 9

    @staticmethod
    def get_num_sms() -> int:
        """Number of streaming multiprocessors on the current device.

        Raises:
            RuntimeError: if Triton is not on the CUDA backend, or CUDA is
                unavailable.
        """
        if not CudaUtils.is_cuda():
            raise RuntimeError("Triton is not running on CUDA backend")
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is not available")
        props = torch.cuda.get_device_properties("cuda")
        return props.multi_processor_count
54
+
55
+
56
class TmaDescriptorHelper:
    """Helper class for managing TMA descriptors in Triton kernels.

    Descriptors are allocated once up front (``init_tma_descriptor``) as CPU
    int8 buffers, then filled inside the autotuner's grid lambda — where the
    chosen block sizes are finally known — via the ``fill_*`` methods.
    """

    class KernelParamWrapper:
        """Wrapper to implement the TmaDescKernelParam interface."""

        def __init__(self, desc: torch.Tensor):
            self.desc = desc

        def tma_desc_cpu_ptr(self) -> int:
            """Return the CPU pointer to the TMA descriptor."""
            return self.desc.data_ptr()

    def __init__(self, tma_size: int = 128):
        """Initialize the TMA descriptor helper.

        Args:
            tma_size: Size of the TMA descriptor in bytes

        Raises:
            RuntimeError: if the device lacks TMA support (pre-Hopper) or the
                installed Triton has no grid-constant descriptor support.
        """
        if not CudaUtils.verify_tma():
            raise RuntimeError(
                "TMA not supported on this device (requires Hopper or newer)"
            )
        if "nv_tma_desc_type" not in dir(tl):
            raise RuntimeError(
                "TMA grid constant descriptors not supported in your Triton version"
            )

        self.tma_size = tma_size
        # Driver-level helpers that write the raw descriptor bytes in place.
        self.fill_1d_tma_descriptor_inner = driver.active.utils.fill_1d_tma_descriptor
        self.fill_2d_tma_descriptor_inner = driver.active.utils.fill_2d_tma_descriptor
        # name -> CPU int8 buffer holding that descriptor's bytes.
        self.descriptors: Dict[str, torch.Tensor] = {}

    def init_tma_descriptor(self, name: str) -> None:
        """Initialize a TMA descriptor with the given name.

        Call this method outside of the lambda function for grid size.
        """
        self.descriptors[name] = torch.empty(
            self.tma_size, device="cpu", dtype=torch.int8
        )

    def fill_1d_tma_descriptor(
        self, name: str, ptr: int, dim: int, block_dim: int, element_size: int
    ) -> None:
        """Fill a 1D TMA descriptor.

        Call this method inside the lambda function for grid size.
        """
        if name not in self.descriptors:
            raise ValueError(f"TMA descriptor '{name}' not initialized")

        desc_x = self.descriptors[name]
        # Descriptor buffers must be 64-byte aligned for the hardware.
        if desc_x.data_ptr() % 64 != 0:
            raise ValueError("TMA descriptor must be 64-byte aligned")
        self.fill_1d_tma_descriptor_inner(
            ptr, dim, block_dim, element_size, desc_x.data_ptr()
        )

    def fill_2d_tma_descriptor(
        self,
        name: str,
        ptr: int,
        dim1: int,
        dim0: int,
        block_dim1: int,
        block_dim0: int,
        element_size: int,
    ) -> None:
        """Fill a 2D TMA descriptor.

        Call this method inside the lambda function for grid size.
        NOTE(review): callers pass (rows, cols) as (dim1, dim0); presumably
        dim0 is the contiguous (inner) extent — confirm against driver docs.
        """
        if name not in self.descriptors:
            raise ValueError(f"TMA descriptor '{name}' not initialized")

        desc_x = self.descriptors[name]
        # Descriptor buffers must be 64-byte aligned for the hardware.
        if desc_x.data_ptr() % 64 != 0:
            raise ValueError("TMA descriptor must be 64-byte aligned")
        self.fill_2d_tma_descriptor_inner(
            ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr()
        )

    def get_tma_descriptor_kernel_param(self, name: str) -> KernelParamWrapper:
        """Get the TMA descriptor kernel parameter for the given name."""
        if name not in self.descriptors or self.descriptors[name] is None:
            raise ValueError(f"TMA descriptor '{name}' not initialized")
        return self.KernelParamWrapper(self.descriptors[name])
144
+
145
+
146
# ====== Autotuning utilities ======

# Fixed M tile size — the single BLOCK_SIZE_M candidate in the sweep below.
ALIGN_SIZE_M = 128

# Cartesian product of NVIDIA tiling candidates handed to the autotuner;
# infeasible combinations are filtered at tune time by `early_config_prune`.
_NV_CONFIGS = [
    triton.Config(
        {
            "BLOCK_SIZE_M": block_size_m,
            "BLOCK_SIZE_N": block_size_n,
            "BLOCK_SIZE_K": block_size_k,
        },
        num_stages=num_stages,
        num_warps=num_warps,
        num_ctas=num_ctas,
    )
    for block_size_m in [ALIGN_SIZE_M, ]
    for block_size_n in [64, 128, 256]
    for block_size_k in [64, 128, 256]
    for num_stages in [3, 4]
    for num_warps in [4, 8]
    for num_ctas in [1]
]
167
+
168
+
169
def early_config_prune(configs, named_args, dtsize=None, dtype=None, **kwargs):
    """Prune autotune configs that cannot work (or are clearly wasteful) for
    the given kernel arguments, before any benchmarking runs.

    Args:
        configs: Candidate ``triton.Config`` objects (e.g. ``_NV_CONFIGS``).
        named_args: Kernel launch arguments by name; must contain one of the
            recognized output-pointer names plus G, M_BUCKET, N, K.
        dtsize: Element size in bytes; inferred from the output tensor if None.
        dtype: Element dtype; inferred from the output tensor if None.

    Returns:
        The subset of ``configs`` passing all feasibility heuristics.

    Raises:
        KeyError: if no recognized output pointer is present in named_args.
    """
    device = torch.cuda.current_device()
    # Check for all possible pointer parameter names — the kernels served by
    # this pruner name their output tensor differently.
    if "grad_input_ptr" in named_args:
        ptr_name = "grad_input_ptr"
    elif "c_ptr" in named_args:
        ptr_name = "c_ptr"
    elif "grad_weight_ptr" in named_args:
        ptr_name = "grad_weight_ptr"
    else:
        raise KeyError("No recognized pointer parameter found in kernel arguments")

    if dtsize is None:
        dtsize = named_args[ptr_name].element_size()
    if dtype is None:
        dtype = named_args[ptr_name].dtype

    pruned_configs = []
    for config in configs:
        kw = config.kwargs
        BLOCK_M, BLOCK_N, BLOCK_K, num_stages = (
            kw["BLOCK_SIZE_M"],
            kw["BLOCK_SIZE_N"],
            kw["BLOCK_SIZE_K"],
            config.num_stages,
        )
        # NOTE: M here is M_BUCKET, a bucketed (power-of-2) bound on M_total.
        G, M, N, K = (
            named_args["G"],
            named_args["M_BUCKET"],
            named_args["N"],
            named_args["K"],
        )

        # 1. make sure we have enough smem
        max_shared_memory = driver.active.utils.get_device_properties(device)[
            "max_shared_mem"
        ]

        # A- and B-tiles are each resident for `num_stages` pipeline stages.
        required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
        if required_shared_memory > max_shared_memory:
            continue

        M_PER_GROUP = M // G
        MIN_M_TILES = 64
        # 2. make sure we don't load M tiles that are too big
        if BLOCK_M > MIN_M_TILES and BLOCK_M > (M_PER_GROUP * 2):
            continue
        # 3. make sure we don't load M tiles that are too small
        if BLOCK_M < 128 and BLOCK_M < (M_PER_GROUP // 2):
            continue

        num_sm = driver.active.utils.get_device_properties(device)[
            "multiprocessor_count"
        ]
        N_TILES = N // BLOCK_N
        MIN_N_TILES = 64
        # 4. make sure we don't load N tiles that are too big
        # (big BLOCK_N is only worthwhile when there are enough tiles to fill SMs)
        if BLOCK_N > MIN_N_TILES and M * N_TILES < num_sm:
            continue
        # 5. make sure we don't load N tiles that are too small
        if BLOCK_N < 128 and M * N_TILES > 2 * num_sm:
            continue
        # 6. make sure K can be evenly divided
        if K % BLOCK_K != 0:
            continue

        pruned_configs.append(config)

    return pruned_configs
238
+
239
+
240
+ # ======== End Autotuning utilities ========
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/unit_test_forwards.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+ import logging
9
+ import unittest
10
+ from typing import Tuple
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+
15
+ from mg_grouped_gemm import grouped_gemm_forward
16
+
17
+
18
class TestMG_GroupedGEMM(unittest.TestCase):
    """CUDA tests for ``grouped_gemm_forward`` with M*G grouping, checked
    against a per-group dense-matmul reference."""

    def setUp(self) -> None:
        # Fixed seed so the random inputs (and thus tolerances) reproduce.
        torch.manual_seed(2020)

    def _run_grouped_gemm_test(
        self,
        shape: Tuple[int, int, int, int],
        device: torch.device,
        dtype: torch.dtype = torch.bfloat16,
        atol: float = 1e-5,
        rtol: float = 1.6e-2,
    ) -> None:
        """Run one forward check for a (G, M, N, K) shape with equal groups.

        Group g owns rows [M*g, M*(g+1)) of the input and rows
        [N*g, N*(g+1)) of the weight matrix.
        """
        G, M, N, K = shape
        # In M*G grouping, input is [M*G, K] and weights are [N*G, K]
        a = torch.randn(M * G, K, dtype=dtype, device=device)
        b = torch.randn(N * G, K, dtype=dtype, device=device)

        # Create equal-sized groups for simplicity
        m_size = M
        m_sizes = torch.full((G,), m_size, device=device, dtype=torch.int32)

        result = grouped_gemm_forward(a, b, m_sizes)
        self.assertTrue(result.shape == (M * G, N))

        # Reference: per-group dense matmul against that group's weight slice.
        expected_result = torch.zeros(M * G, N, dtype=dtype, device=device)
        m_start = 0
        for g in range(G):
            m_end = m_start + m_sizes[g]
            b_slice = b[N * g : N * (g + 1), :]
            expected_result[m_start:m_end, :] = a[m_start:m_end, :] @ b_slice.T
            m_start = m_end

        # Convert result to match input dtype if needed
        result = result.to(dtype)
        torch.testing.assert_close(result, expected_result, atol=atol, rtol=rtol)

    def test_MG_grouped_gemm_bf16(self) -> None:
        """Sweep group counts and per-group row counts at fixed N=K=1024."""
        for G in (1, 4, 16):
            for M in (128, 512, 1024):
                print(f"Testing BF16 M*G GroupGeMM with G={G}, M={M}")
                self._run_grouped_gemm_test(
                    (G, M, 1024, 1024),
                    torch.device("cuda"),
                    dtype=torch.bfloat16,
                    atol=1e-5,
                    rtol=1.6e-2,
                )

    def test_MG_grouped_gemm_deepseek_shapes(self) -> None:
        """Test with shapes from Deepseek model."""
        deepseek_shapes = [
            (4, 2048, 4096, 7168),  # G, M, N, K
            (4, 2048, 7168, 2048),
            (8, 512, 4096, 7168),
            (8, 512, 7168, 2048),
        ]

        device = torch.device("cuda")

        for shape in deepseek_shapes:
            G, M, N, K = shape
            print(f"Testing BF16 M*G Deepseek shape: G={G}, M={M}, N={N}, K={K}")
            self._run_grouped_gemm_test(
                shape, device, dtype=torch.bfloat16, atol=1e-5, rtol=1.6e-2
            )
torchtitan/experiments/llama4/__init__.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from torchtitan.components.loss import build_cross_entropy_loss
8
+ from torchtitan.components.lr_scheduler import build_lr_schedulers
9
+ from torchtitan.components.optimizer import build_optimizers
10
+ from torchtitan.datasets.hf_datasets import build_hf_dataloader
11
+ from torchtitan.datasets.tokenizer.tiktoken import build_tiktoken_tokenizer
12
+ from torchtitan.models.llama3 import pipeline_llama
13
+ from torchtitan.protocols.train_spec import register_train_spec, TrainSpec
14
+
15
+ from .infra.parallelize_llama import parallelize_llama
16
+ from .model.args import TransformerModelArgs
17
+ from .model.model import Transformer
18
+
19
__all__ = [
    "TransformerModelArgs",
    "Transformer",
    "llama4_configs",
]


# Named Llama 4 model flavors. "debugmodel" is a tiny config for smoke tests;
# "17bx16e" / "17bx128e" are MoE variants with 16 and 128 routed experts
# respectively (otherwise identical dense dimensions).
llama4_configs = {
    "debugmodel": TransformerModelArgs(
        dim=256,
        n_layers=8,
        n_heads=16,
        rope_theta=500000,
    ),
    "17bx16e": TransformerModelArgs(
        dim=5120,
        n_layers=48,
        n_heads=40,
        n_kv_heads=8,
        ffn_dim_multiplier=1.2,
        multiple_of=2048,
        rope_theta=500000,
        num_experts=16,
        # step of 1 -> an MoE layer in every transformer block
        interleave_moe_layer_step=1,
    ),
    "17bx128e": TransformerModelArgs(
        dim=5120,
        n_layers=48,
        n_heads=40,
        n_kv_heads=8,
        ffn_dim_multiplier=1.2,
        multiple_of=2048,
        rope_theta=500000,
        num_experts=128,
    ),
}


# Register the "llama4" train spec: reuses llama3's pipelining function but a
# llama4-specific parallelize_llama (which additionally applies MoE TP).
register_train_spec(
    TrainSpec(
        name="llama4",
        cls=Transformer,
        config=llama4_configs,
        parallelize_fn=parallelize_llama,
        pipelining_fn=pipeline_llama,
        build_optimizers_fn=build_optimizers,
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_hf_dataloader,
        build_tokenizer_fn=build_tiktoken_tokenizer,
        build_loss_fn=build_cross_entropy_loss,
    )
)
torchtitan/experiments/llama4/infra/expert_parallel.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ from functools import partial
9
+ from typing import Optional, Tuple
10
+
11
+ import torch.nn as nn
12
+ from torch.distributed.tensor import (
13
+ DeviceMesh,
14
+ distribute_module,
15
+ distribute_tensor,
16
+ DTensor,
17
+ Partial,
18
+ Replicate,
19
+ Shard,
20
+ )
21
+ from torch.distributed.tensor.parallel import ParallelStyle
22
+ from torch.distributed.tensor.placement_types import Placement
23
+
24
+
25
+ # implementation of Tensor Parallel on the non-shared experts in MoE
26
class TensorParallel(ParallelStyle):
    """ParallelStyle for the non-shared experts in MoE: shards an expert
    module's ``w1``/``w3`` parameters column-wise (Shard(2)) and ``w2``
    row-wise (Shard(1)) across the TP mesh.

    Inputs are brought to Replicate; outputs come back as Partial (pending
    reduction) by default.
    """

    def __init__(
        self,
        *,
        input_layouts: Optional[Tuple[Optional[Placement]]] = None,
        output_layout: Optional[Placement] = None,
        use_local_output: bool = True,
    ):
        super().__init__()
        # Second tuple entry covers an optional second positional arg,
        # which is passed through untouched.
        self.input_layouts = input_layouts or (Replicate(), None)
        self.output_layout = output_layout or Partial()
        self.desired_input_layouts = (Replicate(), None)
        self.use_local_output = use_local_output

    @staticmethod
    def _prepare_input_fn(
        input_layouts, desired_input_layouts, mod, inputs, device_mesh
    ):
        # TODO: figure out dynamo support for instance method and switch this to instance method

        # annotate module input placements/sharding with input_layouts
        input_tensor, input_layout, desired_input_layout = (
            inputs[0],
            input_layouts[0],
            desired_input_layouts[0],
        )
        if not isinstance(input_tensor, DTensor):
            input_tensor = DTensor.from_local(
                input_tensor, device_mesh, (input_layout,), run_check=False
            )

        # redistribute to the desired layout if the caller's layout differs
        if input_layouts != desired_input_layouts:
            input_tensor = input_tensor.redistribute(
                placements=(desired_input_layout,), async_op=True
            )
        return (input_tensor, *inputs[1:])

    def _partition_fn(self, name, module, device_mesh):
        # Re-register each weight as a sharded DTensor parameter.
        module.register_parameter(
            "w1", nn.Parameter(distribute_tensor(module.w1, device_mesh, [Shard(2)]))
        )  # Column-wise sharding
        module.register_parameter(
            "w2",
            nn.Parameter(distribute_tensor(module.w2, device_mesh, [Shard(1)])),
        )  # Row-wise sharding
        module.register_parameter(
            "w3",
            nn.Parameter(distribute_tensor(module.w3, device_mesh, [Shard(2)])),
        )  # Column-wise sharding

    @staticmethod
    def _prepare_output_fn(output_layout, use_local_output, mod, outputs, device_mesh):
        # redistribute to the requested output layout if needed
        if outputs.placements != (output_layout,):
            outputs = outputs.redistribute(placements=(output_layout,), async_op=True)
        # back to local tensor
        return outputs.to_local() if use_local_output else outputs

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        return distribute_module(
            module,
            device_mesh,
            self._partition_fn,
            partial(
                self._prepare_input_fn, self.input_layouts, self.desired_input_layouts
            ),
            partial(self._prepare_output_fn, self.output_layout, self.use_local_output),
        )
93
+
94
+
95
# NOTE: This is to achieve replicate computation on the gate module in the MoE router.
# It does nothing other than (1) setting the module parameters as DTensors on the given mesh
# and (2) inserting hooks to module boundary to change torch.Tensor to DTensor and back.
# TODO: The reason we need this wrapping is to ensure all parameters are on the same 1D/2D mesh,
# which is assumed by (1) gradient norm clipping, and (2) optimizer fused implementation.
class NoParallel(ParallelStyle):
    """ParallelStyle that replicates a module's computation: no parameter
    sharding, just DTensor wrapping at the module boundary."""

    def __init__(
        self,
        *,
        input_layout: Optional[Placement] = None,
        output_layout: Optional[Placement] = None,
        use_local_output: bool = True,
    ):
        super().__init__()
        # Both layouts default to Replicate — activations are mirrored on
        # every rank of the mesh.
        self.input_layout = input_layout or Replicate()
        self.output_layout = output_layout or Replicate()
        self.desired_input_layout = Replicate()
        self.use_local_output = use_local_output

    @staticmethod
    def _prepare_input_fn(input_layout, desired_input_layout, mod, inputs, device_mesh):
        # annotate module input placements/sharding with input_layouts
        input_tensor = inputs[0]
        if not isinstance(input_tensor, DTensor):
            input_tensor = DTensor.from_local(
                input_tensor, device_mesh, (input_layout,), run_check=False
            )

        # redistribute to Replicate if the caller handed us something else
        if input_layout != desired_input_layout:
            input_tensor = input_tensor.redistribute(
                placements=(desired_input_layout,), async_op=True
            )
        return (input_tensor, *inputs[1:])

    @staticmethod
    def _prepare_output_fn(output_layout, use_local_output, mod, outputs, device_mesh):
        if outputs.placements != (output_layout,):
            outputs = outputs.redistribute(placements=(output_layout,), async_op=True)
        # back to local tensor
        return outputs.to_local() if use_local_output else outputs

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        # partition_fn is None: parameters are not sharded (see NOTE above).
        return distribute_module(
            module,
            device_mesh,
            None,
            partial(
                self._prepare_input_fn, self.input_layout, self.desired_input_layout
            ),
            partial(self._prepare_output_fn, self.output_layout, self.use_local_output),
        )
torchtitan/experiments/llama4/infra/parallelize_llama.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.distributed.device_mesh import DeviceMesh
11
+
12
+ from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
13
+ from torchtitan.distributed import ParallelDims
14
+
15
+ from torchtitan.models.llama3.parallelize_llama import (
16
+ apply_ac,
17
+ apply_compile,
18
+ apply_ddp,
19
+ apply_fsdp,
20
+ apply_tp,
21
+ )
22
+ from torchtitan.tools.logging import logger
23
+
24
+
25
def parallelize_llama(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
    """
    Apply tensor parallelism, activation checkpointing, torch.compile, and data
    parallelism to the model.

    Order of application: TP (including MoE TP) -> AC -> compile ->
    FSDP/HSDP (optionally with CP) or DDP.

    NOTE: The passed-in model preferably should be on meta device. Otherwise,
    the model must fit on GPU or CPU memory.

    Raises:
        RuntimeError: async TP without compile, or DDP with a >1D world mesh.
        ValueError: FlexAttention combined with selective AC.
    """

    if parallel_dims.tp_enabled:
        if (
            job_config.parallelism.enable_async_tensor_parallel
            and not job_config.training.compile
        ):
            raise RuntimeError("Async TP requires --training.compile")

        enable_float8_linear = "float8" in job_config.model.converters
        float8_is_rowwise = job_config.float8.recipe_name in (
            "rowwise",
            "rowwise_with_gw_hp",
        )

        # For now, float8 all-gather with TP is only supported for tensorwise
        # float8 scaling recipes. For rowwise recipes, we use regular TP and
        # all-gather happens in high precision.
        enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise

        apply_tp(
            model,
            world_mesh["tp"],
            loss_parallel=parallel_dims.loss_parallel_enabled,
            enable_float8_tensorwise_tp=enable_float8_tensorwise_tp,
            enable_async_tp=job_config.parallelism.enable_async_tensor_parallel,
        )

        # MoE submodules get their own TP plan on top of the dense-layer TP.
        apply_moe_tp(model, world_mesh["tp"])

    if job_config.activation_checkpoint.mode != "none":
        if (
            job_config.activation_checkpoint.mode == "selective"
            and job_config.model.use_flex_attn
        ):
            raise ValueError(
                "FlexAttention is not compatible with selective AC yet. "
                "See https://github.com/pytorch/pytorch/issues/147879"
            )
        apply_ac(model, job_config.activation_checkpoint)

    # turn on per-TransformerBlock compile after AC wrapping and before FSDP
    if job_config.training.compile:
        apply_compile(model)

        # NOTE: needed for torch.compile to work with dynamic shapes in token-choice MoE
        torch._dynamo.config.capture_scalar_outputs = True

    if (
        parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled
    ):  # apply FSDP or HSDP, potentially with Context Parallel
        if parallel_dims.dp_replicate_enabled:
            # HSDP: replicate across one mesh dim, shard across the other.
            dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
        else:
            dp_mesh_dim_names = ("dp_shard_cp",)

        apply_fsdp(
            model,
            world_mesh[tuple(dp_mesh_dim_names)],
            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
            pp_enabled=parallel_dims.pp_enabled,
            cpu_offload=job_config.training.enable_cpu_offload,
            reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward,
        )

        if parallel_dims.dp_replicate_enabled:
            logger.info("Applied HSDP to the model")
        else:
            logger.info("Applied FSDP to the model")

        if parallel_dims.cp_enabled:
            logger.info("Applied Context Parallel to the model")

        if job_config.training.enable_cpu_offload:
            logger.info("Applied CPU Offloading to the model")
    elif parallel_dims.dp_replicate_enabled:
        if world_mesh.ndim > 1:
            raise RuntimeError("DDP has not supported > 1D parallelism")
        apply_ddp(
            model,
            world_mesh,
            enable_compile=job_config.training.compile,
            enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
        )

    return model
124
+
125
+
126
def apply_moe_tp(
    model: nn.Module,
    tp_mesh: DeviceMesh,
):
    """Apply a tensor-parallel plan to each transformer block's MoE module.

    The plan shards activations on the sequence dim at the MoE boundary
    (all-gather on input, reduce-scatter on output), replicates the router
    gate, and tensor-parallelizes the routed and shared experts.
    """
    from torch.distributed.tensor import Partial, Replicate, Shard
    from torch.distributed.tensor.parallel import (
        parallelize_module,
        PrepareModuleInputOutput,
    )

    from .expert_parallel import NoParallel, TensorParallel

    for _, transformer_block in model.layers.items():
        moe_layer_plan = {
            # input / output sharding on the seqlen dim
            # all-gather for input, reduce-scatter for output
            "moe": PrepareModuleInputOutput(
                input_layouts=(Shard(1),),
                desired_input_layouts=(Replicate(),),
                use_local_input=True,
                output_layouts=(Partial(),),
                desired_output_layouts=(Shard(1),),
            ),
            # replicate computation for the router
            "moe.router.gate": NoParallel(),
            # input Replicate, output Partial
            "moe.experts": TensorParallel(),
            "moe.shared_expert": TensorParallel(),
        }
        parallelize_module(
            module=transformer_block,
            device_mesh=tp_mesh,
            parallelize_plan=moe_layer_plan,
        )
torchtitan/experiments/llama4/model/__pycache__/args.cpython-311.pyc ADDED
Binary file (4.43 kB). View file
 
torchtitan/experiments/llama4/model/args.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ from dataclasses import dataclass
9
+ from typing import Optional
10
+
11
+ from torch import nn
12
+ from torchtitan.components.tokenizer import Tokenizer
13
+ from torchtitan.config_manager import JobConfig
14
+
15
+ from torchtitan.protocols.train_spec import BaseModelArgs
16
+ from torchtitan.tools.logging import logger
17
+
18
+
19
@dataclass
class TransformerModelArgs(BaseModelArgs):
    """Configuration for the Llama 4 Transformer (dense + MoE layers)."""

    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    # KV heads for grouped-query attention; presumably None falls back to
    # n_heads — confirm in the model code.
    n_kv_heads: Optional[int] = None
    vocab_size: int = -1  # defined later by tokenizer
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-5
    rope_theta: float = 10000

    max_seq_len: int = 2048
    # If `True`, then each transformer block init uses its layer ID, and if
    # `False`, each uses the total number of transformer blocks
    depth_init: bool = True
    norm_type: str = "rmsnorm"

    use_flex_attn: bool = False
    attn_mask_type: str = "causal"
    eos_id: int = 0

    # MoE args
    moe_enabled: bool = True
    num_experts: int = 8
    use_shared_expert: bool = True
    auto_scale_hidden_dim: bool = True
    # frequency of using MoE layer instead of feedforward layer in a transformer block
    interleave_moe_layer_step: int = 2
    # token-choice routing: number of experts activated per token
    top_k: int = 1

    def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
        """Fill in the fields that depend on the job config and tokenizer."""
        self.norm_type = job_config.model.norm_type
        self.vocab_size = tokenizer.n_words
        self.max_seq_len = job_config.training.seq_len
        self.use_flex_attn = job_config.model.use_flex_attn

    def get_nparams_and_flops(
        self, model: nn.Module, seq_len: int
    ) -> tuple[int, float]:
        """Count parameters and estimate FLOPs per token.

        Returns:
            (total parameter count, FLOPs-per-token estimate). The FLOPs
            estimate charges only the *active* sparse parameters, i.e.
            top_k of num_experts routed experts per token.
        """
        nparams_embedding = 0
        nparams_moe_router = 0
        nparams_shared_expert = 0
        nparams_experts = 0
        nparams_dense = 0

        # Bucket every parameter into dense vs. the three MoE categories
        # by substring-matching the parameter name.
        for name, p in model.named_parameters():
            if "embedding" in name:
                nparams_embedding += p.numel()
                nparams_dense += p.numel()
            elif "moe.shared_expert" in name:
                nparams_shared_expert += p.numel()
            elif "moe.router" in name:
                nparams_moe_router += p.numel()
            elif "moe.experts" in name:
                nparams_experts += p.numel()
            else:
                nparams_dense += p.numel()

        nparams_sparse = nparams_moe_router + nparams_shared_expert + nparams_experts
        nparams = nparams_dense + nparams_sparse
        # Only top_k of num_experts routed experts run per token; router and
        # shared expert always run.
        nparams_sparse_active = (
            nparams_moe_router
            + nparams_shared_expert
            + nparams_experts * self.top_k // self.num_experts
        )

        logger.info(
            f"Total parameter count: dense {nparams_dense:,}, "
            f"sparse {nparams_sparse:,}, active {nparams_dense + nparams_sparse_active:,}"
        )

        # l = layers, h = heads, q = per-head dim, t = sequence length
        l, h, q, t = (
            self.n_layers,
            self.n_heads,
            self.dim // self.n_heads,
            seq_len,
        )
        # Reasoning behind the factor of 12 for the self-attention part of the formula:
        # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6)
        # 2. the flash attention does 1 more matmul recomputation in the backward
        # but recomputation should not be counted in calculating MFU (+0)
        # 3. each matmul performs 1 multiplication and 1 addition (*2)
        # 4. we follow the convention and do not account for sparsity in causal attention
        num_flops_per_token = (
            6 * (nparams_dense - nparams_embedding + nparams_sparse_active)
            + 12 * l * h * q * t
        )

        return nparams, num_flops_per_token
torchtitan/experiments/llama4/model/moe.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from torch import nn
10
+
11
+ from .args import TransformerModelArgs
12
+
13
+
14
+ class GroupedExperts(nn.Module):
15
+ def __init__(
16
+ self,
17
+ dim: int,
18
+ hidden_dim: int,
19
+ num_experts: int,
20
+ ):
21
+ super().__init__()
22
+ self.num_experts = num_experts
23
+ self.w1 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
24
+ self.w2 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim))
25
+ self.w3 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
26
+
27
+ def forward(
28
+ self,
29
+ x: torch.Tensor,
30
+ num_local_tokens_per_expert: torch.Tensor | None = None,
31
+ ) -> torch.Tensor:
32
+ if num_local_tokens_per_expert is not None:
33
+ # a tuple of tensors indexed by experts
34
+ # each with shape (tokens_per_expert(varying), dim)
35
+ x = torch.split(
36
+ x,
37
+ split_size_or_sections=num_local_tokens_per_expert.tolist(),
38
+ dim=0,
39
+ )
40
+ out_experts_splits = []
41
+ for expert_idx, x_expert in enumerate(x):
42
+ w1, w2, w3 = (
43
+ self.w1[expert_idx],
44
+ self.w2[expert_idx],
45
+ self.w3[expert_idx],
46
+ )
47
+ h = F.silu(torch.matmul(x_expert, w1))
48
+ h = h * torch.matmul(x_expert, w3)
49
+ h = torch.matmul(h, w2)
50
+ # h shape (tokens_per_expert(varying), dim)
51
+ out_experts_splits.append(h)
52
+ out = torch.cat(out_experts_splits, dim=0)
53
+
54
+ # TODO:optimize with GroupedGEMM
55
+ # https://github.com/pytorch/pytorch/pull/150374
56
+ # _gouped_mm requires shapes to be multiple of 8
57
+ # offsets = torch.cumsum(num_local_tokens_per_expert, dim=0, dtype=torch.int32)
58
+ # h = F.silu(torch._grouped_mm(x, self.w1.transpose(-2, -1), offs=offsets, out_dtype=torch.bfloat16))
59
+ # h = h * torch._grouped_mm(x, self.w3.transpose(-2, -1), offs=offsets, out_dtype=torch.bfloat16)
60
+ # out = torch._grouped_mm(h, self.w2.transpose(-2, -1), offs=offsets, out_dtype=torch.bfloat16)
61
+ else:
62
+ # x shape (num_experts, tokens_per_expert, dim)
63
+ h = F.silu(torch.bmm(x, self.w1))
64
+ h = h * torch.bmm(x, self.w3)
65
+ # out shape (num_experts, tokens_per_expert, dim)
66
+ out = torch.bmm(h, self.w2)
67
+ return out
68
+
69
+ def init_weights(self, init_std: float):
70
+ nn.init.trunc_normal_(self.w1, mean=0.0, std=0.02)
71
+ nn.init.trunc_normal_(self.w2, mean=0.0, std=init_std)
72
+ nn.init.trunc_normal_(self.w3, mean=0.0, std=init_std)
73
+
74
+
75
class TokenChoiceTopKRouter(nn.Module):
    """This class implements token-choice routing. In token-choice top-K routing, each token is
    routed to top K experts based on the router scores.

    Args:
        dim (int): Dimension of input tokens.
        num_experts (int): Number of experts in each moe layer.
        top_k (int): Number of experts each token will be routed to in token-choice routing.
        use_sigmoid (bool): Whether to use sigmoid or softmax for router scores. Default is False.
    """

    def __init__(
        self,
        dim: int,
        num_experts: int,
        top_k: int,
        use_sigmoid: bool = False,
    ):
        super().__init__()
        # Linear gate producing one score per expert for each token.
        self.gate = nn.Linear(dim, num_experts, bias=False)
        self.num_experts = num_experts
        self.top_k = top_k
        self.use_sigmoid = use_sigmoid

    def forward(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            x (torch.Tensor): Input tensor with shape ``(bs*slen, dim)``.

        Returns:
            top_scores (torch.Tensor):
                Router scores of the selected experts, grouped by expert index,
                with shape ``(bs*slen*top_k,)``.
            token_indices_experts_sorted (torch.Tensor):
                Token indices for routed_input with shape ``(bs*slen*top_k,)``.
            num_local_tokens_per_expert (torch.Tensor):
                Integer number of tokens assigned to each expert with shape
                ``(num_experts,)``.
        """
        # scores shape (bs*slen, num_experts)
        scores = self.gate(x)

        # By default, sigmoid or softmax is performed in float32 to avoid loss explosion
        if self.use_sigmoid:
            scores = torch.sigmoid(scores.to(torch.float32)).to(x.dtype)
        else:
            scores = F.softmax(scores.to(torch.float32), dim=1).to(x.dtype)

        # top scores shape (bs*slen, top_k)
        top_scores, selected_experts_indices = torch.topk(scores, k=self.top_k, dim=1)
        # top_scores /= top_scores.sum(dim=-1, keepdim=True).to(x.dtype)

        # Count how many tokens go to each expert. torch.bincount (rather than
        # torch.histc) is used because histc is only implemented for
        # floating-point tensors — it raises on the int64 indices from topk —
        # and would return float counts, which torch.split cannot take as
        # split sizes downstream in GroupedExperts.
        num_local_tokens_per_expert = torch.bincount(
            selected_experts_indices.view(-1), minlength=self.num_experts
        )
        # Group tokens by expert index (stable sort keeps per-expert token
        # order); token_indices_experts_sorted shape (bs*slen*top_k,)
        token_indices_experts_sorted = torch.argsort(
            selected_experts_indices.view(-1), stable=True
        )
        top_scores = top_scores.view(-1)[token_indices_experts_sorted]
        # Map flat (token, k) positions back to token indices.
        token_indices_experts_sorted = token_indices_experts_sorted // self.top_k

        return top_scores, token_indices_experts_sorted, num_local_tokens_per_expert

    def init_weights(self, init_std: float):
        nn.init.trunc_normal_(self.gate.weight, mean=0.0, std=init_std)
146
+
147
+
148
# TODO: implement load balancing auxiliary loss for token-choice routing
class MoE(nn.Module):
    """Token-choice mixture-of-experts layer with an optional shared expert."""

    def __init__(self, model_args: TransformerModelArgs):
        super().__init__()
        dim = model_args.dim

        # Same hidden-size rule as the dense FFN: 2/3 of 4*dim, optionally
        # rescaled by ffn_dim_multiplier.
        hidden_dim = int(2 * (4 * model_args.dim) / 3)
        if model_args.ffn_dim_multiplier is not None:
            hidden_dim = int(model_args.ffn_dim_multiplier * hidden_dim)

        num_experts = model_args.num_experts

        hidden_dim_denom = 1
        if model_args.auto_scale_hidden_dim:
            hidden_dim_denom = model_args.top_k + int(model_args.use_shared_expert)

        if model_args.auto_scale_hidden_dim:
            # Spread the FFN width over the experts active per token, rounded
            # up to the next multiple of `multiple_of`.
            hidden_dim = int(hidden_dim / hidden_dim_denom)
            hidden_dim += -hidden_dim % model_args.multiple_of

        self.experts = GroupedExperts(
            dim=dim, hidden_dim=hidden_dim, num_experts=num_experts
        )
        self.router = TokenChoiceTopKRouter(
            dim=dim, num_experts=num_experts, top_k=model_args.top_k
        )
        self.shared_expert = (
            GroupedExperts(dim=dim, hidden_dim=hidden_dim, num_experts=1)
            if model_args.use_shared_expert
            else None
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Input tensor with shape ``(bs, slen, dim)``.

        Returns:
            out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``.
        """
        bs, slen, dim = x.shape
        tokens = x.reshape(bs * slen, dim)

        # top_scores / token_indices: (bs*slen*top_k,)
        # num_local_tokens_per_expert: (num_experts,)
        top_scores, token_indices, num_local_tokens_per_expert = self.router(tokens)

        # Expand indices to (bs*slen*top_k, dim) so whole rows can be gathered.
        token_indices = token_indices.reshape(-1, 1).expand(-1, dim)
        routed_input = torch.gather(tokens, dim=0, index=token_indices)
        # Scale each routed copy by its router score before the experts run.
        routed_input = routed_input * top_scores.reshape(-1, 1)

        # shape (bs*slen*top_k, dim)
        routed_output = self.experts(routed_input, num_local_tokens_per_expert)

        # Start from the shared expert's output (or zeros), then accumulate
        # each token's routed expert outputs back into its row.
        if self.shared_expert is None:
            out = torch.zeros_like(tokens)
        else:
            out = self.shared_expert(tokens.reshape(1, bs * slen, dim)).reshape(
                bs * slen, dim
            )

        out = out.scatter_add(dim=0, index=token_indices, src=routed_output)
        return out.reshape(bs, slen, dim)

    def init_weights(self, init_std: float):
        self.experts.init_weights(init_std)
        self.router.init_weights(init_std)
        if self.shared_expert is not None:
            self.shared_expert.init_weights(init_std)
torchtitan/experiments/llama4/scripts/REAME.md ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## How to convert a Llama 4 checkpoint for use in torchtitan
2
+
3
+ To continue training from an existing model checkpoint, the checkpoint must be in the DCP format expected by the checkpoint manager.
4
+ This folder contains the scripts for converting officially released Llama 4 checkpoints into the expected DCP format, from original Meta format, or from HuggingFace format, using GPUs.
5
+
6
+ #### Example usage
7
+
8
+ From Meta format:
9
+ ```bash
10
+ CONFIG_FILE=../train_configs/llama4_16.toml ./convert_meta_to_dcp.sh --checkpoint.enable_checkpoint --checkpoint.convert_path=[checkpoint_folder] --checkpoint.convert_load_every_n_ranks=8
11
+ ```
12
+
13
+
14
+ From HuggingFace format:
15
+ ```bash
16
+ CONFIG_FILE=../train_configs/llama4_16.toml ./convert_hf_to_dcp_with_gpus.sh --checkpoint.enable_checkpoint --checkpoint.convert_path=[checkpoint_folder] --checkpoint.convert_load_every_n_ranks=8
17
+ ```
torchtitan/experiments/llama4/scripts/convert_meta_to_dcp_with_gpus.sh ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Launch the Meta-format -> DCP checkpoint conversion script across NGPU
# ranks with torchrun.

set -ex

# use envs as local overrides for convenience
# e.g.
# LOG_RANK=0,1 NGPU=4 ./convert_meta_to_dcp_with_gpus.sh
NGPU=${NGPU:-"8"}
LOG_RANK=${LOG_RANK:-0,1,2,3,4,5,6,7}
# NOTE(review): README examples use llama4_16.toml — confirm this default
# config filename actually exists alongside the train configs.
CONFIG_FILE=${CONFIG_FILE:-"../train_configs/llama4_17bx16e.toml"}

# Any extra CLI arguments are forwarded verbatim to the conversion script.
overrides=""
if [ $# -ne 0 ]; then
overrides="$*"
fi

# expandable_segments reduces CUDA allocator fragmentation while loading
# large checkpoint tensors.
PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \
torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
convert_meta_to_dcp_with_gpus_meta.py --job.config_file ${CONFIG_FILE} $overrides
torchtitan/experiments/llama4/train_configs/debug_model.toml ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [job]
2
+ dump_folder = "./outputs"
3
+ description = "Llama 4 debug training"
4
+ print_args = false
5
+ use_for_integration_test = true
6
+
7
+ [profiling]
8
+ enable_profiling = false
9
+ save_traces_folder = "profile_trace"
10
+ profile_freq = 10
11
+ enable_memory_snapshot = false
12
+ save_memory_snapshot_folder = "memory_snapshot"
13
+
14
+ [metrics]
15
+ log_freq = 1
16
+ disable_color_printing = false
17
+ enable_tensorboard = false
18
+ save_tb_folder = "tb"
19
+ enable_wandb = false
20
+
21
+ [model]
22
+ name = "llama4"
23
+ flavor = "debugmodel"
24
+ norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm
25
+ # test tokenizer.model, for debug purpose only
26
+ tokenizer_path = "./tests/assets/test_tiktoken.model"
27
+ # converters = "float8"
28
+ use_flex_attn = false
29
+ attn_mask_type = "causal" # causal / block_causal
30
+
31
+ [optimizer]
32
+ name = "AdamW"
33
+ lr = 4e-3
34
+ eps = 1e-15
35
+
36
+ [lr_scheduler]
37
+ warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps
38
+ decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps
39
+ decay_type = "linear"
40
+ lr_min = 0.1
41
+
42
+ [training]
43
+ batch_size = 8
44
+ seq_len = 2048
45
+ max_norm = 1.0 # grad norm clipping
46
+ steps = 10
47
+ compile = false
48
+ dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M)
49
+
50
+ [parallelism]
51
+ data_parallel_replicate_degree = 1
52
+ data_parallel_shard_degree = -1
53
+ fsdp_reshard_after_forward = "default" # default / never / always
54
+ tensor_parallel_degree = 1
55
+ enable_async_tensor_parallel = false
56
+ pipeline_parallel_degree = 1
57
+ context_parallel_degree = 1
58
+
59
+ [checkpoint]
60
+ enable_checkpoint = false
61
+ folder = "checkpoint"
62
+ interval = 10
63
+ model_weights_only = false
64
+ export_dtype = "float32"
65
+ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
66
+
67
+ [activation_checkpoint]
68
+ mode = 'none' # ['none', 'selective', 'full']
69
+ selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy
70
+
71
+ [float8]
72
+ enable_fsdp_float8_all_gather = false
73
+ precompute_float8_dynamic_scale_for_fsdp = false
74
+ filter_fqns = "output,router.gate"
torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # TODO: this toml config is still under development
2
+
3
+ [job]
4
+ dump_folder = "./outputs"
5
+ description = "Llama 4 Maverick 17Bx128E training"
6
+
7
+ [profiling]
8
+ enable_profiling = false
9
+ save_traces_folder = "profile_trace"
10
+ profile_freq = 100
11
+
12
+ [metrics]
13
+ log_freq = 10
14
+ enable_tensorboard = false
15
+ save_tb_folder = "tb"
16
+
17
+ [model]
18
+ name = "llama4"
19
+ flavor = "17bx128e"
20
+ norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm
21
+ tokenizer_path = "./assets/tokenizer/tokenizer.model"
22
+ # converters = "float8"
23
+
24
+ [optimizer]
25
+ name = "AdamW"
26
+ lr = 4e-3
27
+ eps = 1e-15
28
+
29
+ [lr_scheduler]
30
+ warmup_steps = 600
31
+ lr_min = 0.1
32
+
33
+ [training]
34
+ batch_size = 1
35
+ seq_len = 8192
36
+ max_norm = 1.0 # grad norm clipping
37
+ steps = 3000
38
+ compile = false
39
+ dataset = "c4"
40
+
41
+ [parallelism]
42
+ data_parallel_replicate_degree = 1
43
+ data_parallel_shard_degree = -1
44
+ tensor_parallel_degree = 8
45
+ enable_async_tensor_parallel = false
46
+ pipeline_parallel_degree = 4
47
+ # pipeline_parallel_schedule = "interleaved1f1b"
48
+ # pipeline_parallel_microbatches = 2
49
+ context_parallel_degree = 1
50
+
51
+ [checkpoint]
52
+ enable_checkpoint = false
53
+ folder = "checkpoint"
54
+ interval = 500
55
+ model_weights_only = false
56
+ export_dtype = "float32"
57
+ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
58
+
59
+ [activation_checkpoint]
60
+ mode = 'full' # ['none', 'selective', 'full']
61
+
62
+ [float8]
63
+ enable_fsdp_float8_all_gather = false
64
+ precompute_float8_dynamic_scale_for_fsdp = false
65
+ filter_fqns = "output,router.gate"
torchtitan/experiments/multimodal/requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ torchvision
torchtitan/experiments/multimodal/tests/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
torchtitan/experiments/multimodal/tests/test_utils.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+
9
+ from typing import Optional, Union
10
+
11
+ import torch
12
+ from torch import nn
13
+
14
+
15
def fixed_init_tensor(
    shape: torch.Size,
    min_val: Union[float, int] = 0.0,
    max_val: Union[float, int] = 1.0,
    nonlinear: bool = False,
    dtype: torch.dtype = torch.float,
):
    """
    Utility for generating deterministic tensors of a given shape. In general stuff
    like torch.ones, torch.eye, etc can result in trivial outputs. This utility
    generates a range tensor [min_val, max_val) of a specified dtype, applies
    a sine function if nonlinear=True, then reshapes to the appropriate shape.

    Args:
        shape (torch.Size): Target shape of the returned tensor.
        min_val (float|int): Inclusive lower bound of the range. Default 0.0.
        max_val (float|int): Exclusive upper bound of the range. Default 1.0.
        nonlinear (bool): Apply ``torch.sin`` elementwise if True. Default False.
        dtype (torch.dtype): dtype of the returned tensor. Default torch.float.

    Returns:
        torch.Tensor: Deterministic tensor of the requested shape and dtype.
    """
    n_elements = math.prod(shape)
    if n_elements == 0:
        # Zero-sized shapes would otherwise divide by zero when computing the
        # step size below.
        return torch.empty(shape, dtype=dtype)
    step_size = (max_val - min_val) / n_elements
    x = torch.arange(min_val, max_val, step_size, dtype=dtype)
    # Float rounding in arange can occasionally produce one extra element,
    # which would make the reshape fail; truncate to exactly n_elements.
    x = x[:n_elements].reshape(shape)
    if nonlinear:
        return torch.sin(x)
    return x
35
+
36
+
37
@torch.no_grad
def fixed_init_model(
    model: nn.Module,
    min_val: Union[float, int] = 0.0,
    max_val: Union[float, int] = 1.0,
    nonlinear: bool = False,
    dtype: Optional[torch.dtype] = None,
):
    """
    Deterministically initialize every parameter of ``model`` in place via
    :func:`fixed_init_tensor`; see that docstring for the meaning of each
    argument. When ``dtype`` is None, each parameter keeps its own dtype.
    """
    for _, param in model.named_parameters():
        target_dtype = param.dtype if dtype is None else dtype
        values = fixed_init_tensor(
            param.shape,
            min_val=min_val,
            max_val=max_val,
            nonlinear=nonlinear,
            dtype=target_dtype,
        )
        param.copy_(values)
torchtitan/experiments/multimodal/tokenizer/tiktoken.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
8
+ # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
9
+
10
+ import os
11
+ from pathlib import Path
12
+ from typing import (
13
+ AbstractSet,
14
+ Any,
15
+ cast,
16
+ Collection,
17
+ Dict,
18
+ Iterator,
19
+ List,
20
+ Literal,
21
+ Mapping,
22
+ Optional,
23
+ Sequence,
24
+ Union,
25
+ )
26
+
27
+ import tiktoken
28
+ import torch
29
+ from tiktoken.load import load_tiktoken_bpe
30
+
31
+ from torchtitan.components.tokenizer import Tokenizer
32
+ from torchtitan.config_manager import JobConfig
33
+ from torchtitan.tools.logging import logger
34
+
35
+ IMAGE_TOKEN_ID = 128256
36
+ IGNORE_INDEX = -100
37
+
38
+
39
class TikTokenizer(Tokenizer):
    """
    Tokenizing and encoding/decoding text using the Tiktoken tokenizer,
    extended with a dedicated ``<|image|>`` special token for multimodal samples.

    Args:
        model_path (str): The path to the Tiktoken model file.
    """

    # Maps each special-token string (e.g. "<|eot_id|>") to its token ID.
    special_tokens: Dict[str, int]

    # Number of token IDs reserved for special tokens after the base vocab.
    num_reserved_special_tokens = 256

    # Regex pre-tokenization pattern used to split text before BPE merging.
    pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"  # noqa: E501, B950

    def __init__(self, model_path: str):
        super().__init__(model_path)
        assert os.path.isfile(model_path), model_path

        mergeable_ranks = load_tiktoken_bpe(model_path)
        num_base_tokens = len(mergeable_ranks)
        # First 10 reserved slots get named special tokens; the remaining
        # slots (out of num_reserved_special_tokens) are numbered placeholders.
        special_tokens = [
            "<|begin_of_text|>",
            "<|end_of_text|>",
            "<|reserved_special_token_0|>",
            "<|reserved_special_token_1|>",
            "<|reserved_special_token_2|>",
            "<|reserved_special_token_3|>",
            "<|start_header_id|>",
            "<|end_header_id|>",
            "<|reserved_special_token_4|>",
            "<|eot_id|>",  # end of turn
        ] + [
            f"<|reserved_special_token_{i}|>"
            for i in range(5, self.num_reserved_special_tokens - 5)
        ]
        # Special tokens occupy IDs immediately after the base BPE vocab.
        self.special_tokens = {
            token: num_base_tokens + i for i, token in enumerate(special_tokens)
        }
        # The image placeholder gets a fixed ID (IMAGE_TOKEN_ID = 128256).
        self.special_tokens["<|image|>"] = IMAGE_TOKEN_ID
        self.model = tiktoken.Encoding(
            name=Path(model_path).name,
            pat_str=self.pat_str,
            mergeable_ranks=mergeable_ranks,
            special_tokens=self.special_tokens,
        )

        self._n_words: int = self.model.n_vocab
        # BOS / EOS token IDs
        self.bos_id: int = self.special_tokens["<|begin_of_text|>"]
        self.eos_id: int = self.special_tokens["<|end_of_text|>"]
        # -1: no dedicated pad token in the vocab; callers treat this as "unset".
        self.pad_id: int = -1
        self.image_id = IMAGE_TOKEN_ID
        # Token IDs that terminate generation.
        self.stop_tokens = {
            self.special_tokens["<|end_of_text|>"],
            self.special_tokens["<|eot_id|>"],
        }
        logger.info(
            f"TikTokenizer built: #words {self.n_words}, BOS ID {self.bos_id}, EOS ID {self.eos_id}, IMAGE ID {self.image_id}"
        )

    def encode(
        self,
        s: str,
        *,
        bos: bool,
        eos: bool,
        allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None,
        disallowed_special: Optional[Union[Literal["all"], Collection[str]]] = None,
    ) -> List[int]:
        """
        Encodes a string into a list of token IDs.

        Args:
            s (str): The input string to be encoded.
            bos (bool): Whether to prepend the beginning-of-sequence token.
            eos (bool): Whether to append the end-of-sequence token.
            allowed_special ("all"|set[str]): allowed special tokens in string
            disallowed_special ("all"|set[str]): special tokens that raise an error when in string

        Returns:
            list[int]: A list of token IDs.

        By default, setting disallowed_special=() encodes a string by ignoring
        special tokens. Specifically:
        - Setting `disallowed_special` to () will cause all text corresponding
        to special tokens to be encoded as natural text (instead of raising
        an error).
        - Setting `allowed_special` to "all" will treat all text corresponding
        to special tokens to be encoded as special tokens.
        """
        assert type(s) is str
        # Defaults: no special tokens allowed, none disallowed.
        allowed_special = allowed_special or set()
        disallowed_special = disallowed_special or ()

        # The tiktoken tokenizer can handle <=400k chars without
        # pyo3_runtime.PanicException.
        TIKTOKEN_MAX_ENCODE_CHARS = 400_000

        # https://github.com/openai/tiktoken/issues/195
        # Here we iterate over subsequences and split if we exceed the limit
        # of max consecutive non-whitespace or whitespace characters.
        MAX_NO_WHITESPACES_CHARS = 25_000

        # Lazily yield chunks small enough for tiktoken to encode safely.
        substrs = (
            substr
            for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
            for substr in self._split_whitespaces_or_nonwhitespaces(
                s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
            )
        )
        t: List[int] = []
        for substr in substrs:
            t.extend(
                self.model.encode(
                    substr,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
            )
        if bos:
            t.insert(0, self.bos_id)
        if eos:
            t.append(self.eos_id)
        return t

    def decode(self, t: Sequence[int]) -> str:
        """
        Decodes a list of token IDs into a string.

        Args:
            t (List[int]): The list of token IDs to be decoded.

        Returns:
            str: The decoded string.
        """
        # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
        return self.model.decode(cast(List[int], t))

    @staticmethod
    def _split_whitespaces_or_nonwhitespaces(
        s: str, max_consecutive_slice_len: int
    ) -> Iterator[str]:
        """
        Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
        consecutive whitespaces or consecutive non-whitespaces.
        """
        current_slice_len = 0
        current_slice_is_space = s[0].isspace() if len(s) > 0 else False
        slice_start = 0

        for i in range(len(s)):
            is_now_space = s[i].isspace()

            # XOR: the character class flipped (space <-> non-space), so the
            # run-length counter restarts.
            if current_slice_is_space ^ is_now_space:
                current_slice_len = 1
                current_slice_is_space = is_now_space
            else:
                current_slice_len += 1
                if current_slice_len > max_consecutive_slice_len:
                    yield s[slice_start:i]
                    slice_start = i
                    current_slice_len = 1
        yield s[slice_start:]

    def encode_multimodal(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """
        Tokenizes ``sample["text"]`` and adds shifted ``tokens``/``labels``
        entries to the sample, masking BOS, EOS and ``image_id`` positions in
        the labels with IGNORE_INDEX. Returns the (mutated) sample mapping.
        """
        # TODO(tj.solergibert) Should we keep `input_ids` OR `tokens` across this class, VisionCrossAttentionMask & the collator?
        # For me it makes more sense to split `tokens` between `input_ids` & `labels` as in train.py BUT the `MultimodalDecoder`
        # & everything else expects `tokens`
        text = sample["text"]
        tokens = self.encode(
            text, bos=True, eos=True, allowed_special=set(["<|image|>"])
        )
        # Next-token prediction: inputs are tokens[:-1], labels are tokens[1:].
        input_ids = torch.LongTensor(tokens[:-1])
        labels = torch.LongTensor(tokens[1:])
        labels = torch.where(
            torch.isin(
                labels, torch.LongTensor([self.bos_id, self.eos_id, self.image_id])
            ),
            IGNORE_INDEX,
            labels,
        )

        assert len(input_ids) == len(labels)  # TODO(tj.solergibert) Delete

        sample.update({"tokens": input_ids, "labels": labels})

        return sample
229
+
230
+
231
def build_tiktoken_tokenizer(job_config: JobConfig) -> TikTokenizer:
    """Construct a :class:`TikTokenizer` from the tokenizer path in ``job_config``."""
    return TikTokenizer(job_config.model.tokenizer_path)
torchtitan/experiments/simple_fsdp/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.33 kB). View file
 
torchtitan/experiments/simple_fsdp/tests/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
torchtitan/models/llama3/__pycache__/pipeline_llama.cpython-311.pyc ADDED
Binary file (5.96 kB). View file
 
torchtitan/models/llama3/train_configs/debug_model.toml ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # torchtitan Config.toml
2
+
3
+ [job]
4
+ dump_folder = "./outputs"
5
+ description = "Llama 3 debug training"
6
+ print_args = false
7
+ use_for_integration_test = true
8
+
9
+ [profiling]
10
+ enable_profiling = false
11
+ save_traces_folder = "profile_trace"
12
+ profile_freq = 10
13
+ enable_memory_snapshot = false
14
+ save_memory_snapshot_folder = "memory_snapshot"
15
+
16
+ [metrics]
17
+ log_freq = 1
18
+ disable_color_printing = false
19
+ enable_tensorboard = false
20
+ save_tb_folder = "tb"
21
+ enable_wandb = false
22
+
23
+ [model]
24
+ name = "llama3"
25
+ flavor = "debugmodel"
26
+ norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm
27
+ # test tokenizer.model, for debug purpose only
28
+ tokenizer_path = "./tests/assets/test_tiktoken.model"
29
+ # converters = "float8"
30
+
31
+ [optimizer]
32
+ name = "AdamW"
33
+ lr = 8e-4
34
+ eps = 1e-8
35
+
36
+ [lr_scheduler]
37
+ warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps
38
+ decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps
39
+ decay_type = "linear"
40
+ lr_min = 0.0
41
+
42
+ [training]
43
+ batch_size = 8
44
+ seq_len = 2048
45
+ max_norm = 1.0 # grad norm clipping
46
+ steps = 10
47
+ compile = false
48
+ dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M)
49
+
50
+ [parallelism]
51
+ data_parallel_replicate_degree = 1
52
+ data_parallel_shard_degree = -1
53
+ fsdp_reshard_after_forward = "default" # default / never / always
54
+ tensor_parallel_degree = 1
55
+ enable_async_tensor_parallel = false
56
+ pipeline_parallel_degree = 1
57
+ context_parallel_degree = 1
58
+
59
+ [checkpoint]
60
+ enable_checkpoint = false
61
+ folder = "checkpoint"
62
+ interval = 10
63
+ model_weights_only = false
64
+ export_dtype = "float32"
65
+ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
66
+
67
+ [activation_checkpoint]
68
+ mode = 'selective' # ['none', 'selective', 'full']
69
+ selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy
70
+
71
+ [float8]
72
+ enable_fsdp_float8_all_gather = false
73
+ precompute_float8_dynamic_scale_for_fsdp = false
74
+ filter_fqns = "output"