Add files using upload-large-folder tool
Browse files- LICENSE +21 -0
- torchtitan/components/__pycache__/checkpoint.cpython-311.pyc +0 -0
- torchtitan/components/__pycache__/float8.cpython-311.pyc +0 -0
- torchtitan/components/__pycache__/metrics.cpython-311.pyc +0 -0
- torchtitan/distributed/__pycache__/utils.cpython-311.pyc +0 -0
- torchtitan/distributed/pipeline.py +201 -0
- torchtitan/experiments/deepseek_v3/LICENSE-CODE +21 -0
- torchtitan/experiments/deepseek_v3/README.md +40 -0
- torchtitan/experiments/deepseek_v3/attn_mask_utils.py +397 -0
- torchtitan/experiments/deepseek_v3/download.py +70 -0
- torchtitan/experiments/deepseek_v3/generate.py +308 -0
- torchtitan/experiments/deepseek_v3/indices.py +195 -0
- torchtitan/experiments/deepseek_v3/inference.sh +15 -0
- torchtitan/experiments/deepseek_v3/model_config.py +204 -0
- torchtitan/experiments/deepseek_v3/symm_mem_recipes/__init__.py +11 -0
- torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_barrier.py +159 -0
- torchtitan/experiments/deepseek_v3/train.py +142 -0
- torchtitan/experiments/flux/dataset/flux_dataset.py +267 -0
- torchtitan/experiments/flux/model/layers.py +286 -0
- torchtitan/experiments/flux/model/math.py +38 -0
- torchtitan/experiments/flux/model/model.py +177 -0
- torchtitan/experiments/flux/parallelize_flux.py +26 -0
- torchtitan/experiments/flux/requirements.txt +2 -0
- torchtitan/experiments/flux/scripts/download_autoencoder.py +61 -0
- torchtitan/experiments/flux/tests/test_flux_dataloader.py +103 -0
- torchtitan/experiments/flux/train.py +224 -0
- torchtitan/experiments/flux/train_configs/debug_model.toml +68 -0
- torchtitan/experiments/flux/utils.py +203 -0
- torchtitan/experiments/kernels/triton_mg_group_gemm/simpleMoE.py +885 -0
- torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/mg_grouped_gemm.py +1304 -0
- torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/tma_autotuning.py +240 -0
- torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/unit_test_forwards.py +82 -0
- torchtitan/experiments/llama4/__init__.py +70 -0
- torchtitan/experiments/llama4/infra/expert_parallel.py +145 -0
- torchtitan/experiments/llama4/infra/parallelize_llama.py +159 -0
- torchtitan/experiments/llama4/model/__pycache__/args.cpython-311.pyc +0 -0
- torchtitan/experiments/llama4/model/args.py +109 -0
- torchtitan/experiments/llama4/model/moe.py +228 -0
- torchtitan/experiments/llama4/scripts/REAME.md +17 -0
- torchtitan/experiments/llama4/scripts/convert_meta_to_dcp_with_gpus.sh +25 -0
- torchtitan/experiments/llama4/train_configs/debug_model.toml +74 -0
- torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml +65 -0
- torchtitan/experiments/multimodal/requirements.txt +1 -0
- torchtitan/experiments/multimodal/tests/__init__.py +5 -0
- torchtitan/experiments/multimodal/tests/test_utils.py +58 -0
- torchtitan/experiments/multimodal/tokenizer/tiktoken.py +232 -0
- torchtitan/experiments/simple_fsdp/__pycache__/__init__.cpython-311.pyc +0 -0
- torchtitan/experiments/simple_fsdp/tests/__init__.py +5 -0
- torchtitan/models/llama3/__pycache__/pipeline_llama.cpython-311.pyc +0 -0
- torchtitan/models/llama3/train_configs/debug_model.toml +74 -0
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2023-2025 Songlin Yang, Yu Zhang
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
torchtitan/components/__pycache__/checkpoint.cpython-311.pyc
ADDED
|
Binary file (35.4 kB). View file
|
|
|
torchtitan/components/__pycache__/float8.cpython-311.pyc
ADDED
|
Binary file (6.54 kB). View file
|
|
|
torchtitan/components/__pycache__/metrics.cpython-311.pyc
ADDED
|
Binary file (20.2 kB). View file
|
|
|
torchtitan/distributed/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (15.9 kB). View file
|
|
|
torchtitan/distributed/pipeline.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
import math
|
| 7 |
+
import os
|
| 8 |
+
from typing import Callable, Optional
|
| 9 |
+
|
| 10 |
+
from torch.distributed.pipelining.schedules import (
|
| 11 |
+
_PipelineSchedule,
|
| 12 |
+
_PipelineScheduleRuntime,
|
| 13 |
+
get_schedule_class,
|
| 14 |
+
PipelineScheduleMulti,
|
| 15 |
+
PipelineScheduleSingle,
|
| 16 |
+
)
|
| 17 |
+
from torch.distributed.pipelining.stage import PipelineStage
|
| 18 |
+
|
| 19 |
+
from torchtitan.config_manager import JobConfig
|
| 20 |
+
from torchtitan.tools.logging import logger
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
__all__ = ["build_pipeline_schedule", "generate_split_points", "stage_ids_this_rank"]
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# TODO: It's unclear if this API is general enough to be used by other models.
|
| 27 |
+
# If not, we should move it to a Transformer-specific directory.
|
| 28 |
+
def generate_split_points(
|
| 29 |
+
schedule_str: str,
|
| 30 |
+
layers_per_stage: Optional[int],
|
| 31 |
+
pp_dim: int,
|
| 32 |
+
num_layers: int,
|
| 33 |
+
input_weight: int = 1,
|
| 34 |
+
output_weight: int = 1,
|
| 35 |
+
) -> list[str]:
|
| 36 |
+
"""
|
| 37 |
+
Generate a list of split points based on the number of layers and
|
| 38 |
+
pipeline parallel dimension, ensuring the first and last stages have the least layers.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
schedule_str (str): The string of the schedule name.
|
| 42 |
+
layers_per_stage (int): The number of layers per stage.
|
| 43 |
+
pp_dim (int): The pipeline parallel dimension.
|
| 44 |
+
num_layers (int): The number of layers in the model.
|
| 45 |
+
input_output_weight (int): The number of layers to consider the input/output modules in the layer calculation.
|
| 46 |
+
|
| 47 |
+
Returns:
|
| 48 |
+
list[str]: A list of split point FQNs.
|
| 49 |
+
"""
|
| 50 |
+
|
| 51 |
+
schedule_class = get_schedule_class(schedule_str)
|
| 52 |
+
is_single_stage_schedule = issubclass(schedule_class, PipelineScheduleSingle)
|
| 53 |
+
num_stages_per_rank = 1 if is_single_stage_schedule else 2
|
| 54 |
+
|
| 55 |
+
if layers_per_stage is not None:
|
| 56 |
+
total_stages = math.ceil(num_layers / layers_per_stage)
|
| 57 |
+
if total_stages % pp_dim != 0:
|
| 58 |
+
raise ValueError(
|
| 59 |
+
f"Number of stages ({total_stages}) must be divisible by the pipeline parallel dimension ({pp_dim})."
|
| 60 |
+
f"Each rank should have the same number of stages. "
|
| 61 |
+
)
|
| 62 |
+
num_stages_per_rank = total_stages // pp_dim
|
| 63 |
+
|
| 64 |
+
if is_single_stage_schedule and num_stages_per_rank != 1:
|
| 65 |
+
raise ValueError(
|
| 66 |
+
f"Number of stages per rank ({num_stages_per_rank}) must be 1 for single stage schedules."
|
| 67 |
+
)
|
| 68 |
+
elif not is_single_stage_schedule and num_stages_per_rank < 2:
|
| 69 |
+
raise ValueError(
|
| 70 |
+
f"Number of stages per rank ({num_stages_per_rank}) must be >= 2 for multi stage schedules."
|
| 71 |
+
)
|
| 72 |
+
else:
|
| 73 |
+
total_stages = pp_dim * num_stages_per_rank
|
| 74 |
+
if total_stages > num_layers:
|
| 75 |
+
raise ValueError("Total stages cannot be greater than the number of layers")
|
| 76 |
+
|
| 77 |
+
# Calculate effective number of layers including input and output weights
|
| 78 |
+
effective_num_layers = num_layers + input_weight + output_weight
|
| 79 |
+
base_layers_per_stage = effective_num_layers // total_stages
|
| 80 |
+
|
| 81 |
+
splits = [""] * (total_stages - 1)
|
| 82 |
+
current_layer_index = 0
|
| 83 |
+
|
| 84 |
+
# First stage
|
| 85 |
+
layers_on_first_stage = max(0, base_layers_per_stage - input_weight)
|
| 86 |
+
current_layer_index += layers_on_first_stage
|
| 87 |
+
splits[0] = "layers." + str(current_layer_index)
|
| 88 |
+
|
| 89 |
+
# Last stage
|
| 90 |
+
layers_on_last_stage = max(0, base_layers_per_stage - output_weight)
|
| 91 |
+
splits[-1] = "layers." + str(num_layers - layers_on_last_stage)
|
| 92 |
+
|
| 93 |
+
# Middle stages
|
| 94 |
+
remaining_layers = num_layers - layers_on_first_stage - layers_on_last_stage - 1
|
| 95 |
+
middle_stages = len(splits) - 2
|
| 96 |
+
layers_per_middle_stage = remaining_layers // middle_stages
|
| 97 |
+
# split remainder evenly across middle stages
|
| 98 |
+
remainder = remaining_layers % middle_stages
|
| 99 |
+
|
| 100 |
+
for i in range(1, middle_stages + 1):
|
| 101 |
+
current_layer_index += layers_per_middle_stage
|
| 102 |
+
if remainder > 0:
|
| 103 |
+
current_layer_index += 1
|
| 104 |
+
remainder -= 1
|
| 105 |
+
splits[i] = "layers." + str(current_layer_index)
|
| 106 |
+
|
| 107 |
+
logger.info(
|
| 108 |
+
f"No 'pipeline_parallel_split_points' provided so the generated splits are: {splits} "
|
| 109 |
+
"This may be sub-optimal as the number of layers per stage may be unbalanced."
|
| 110 |
+
)
|
| 111 |
+
return splits
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def build_pipeline_schedule(
|
| 115 |
+
job_config: JobConfig, stages: list[PipelineStage], loss_fn: Callable
|
| 116 |
+
) -> _PipelineSchedule:
|
| 117 |
+
"""Builds a pipeline schedule for the given job configuration and stages.
|
| 118 |
+
|
| 119 |
+
Args:
|
| 120 |
+
job_config (JobConfig): The job configuration.
|
| 121 |
+
stages (list[PipelineStage]): The stages to be scheduled.
|
| 122 |
+
loss_fn (Callable): The loss function.
|
| 123 |
+
|
| 124 |
+
Returns:
|
| 125 |
+
_PipelineSchedule: The pipeline schedule for the given stages.
|
| 126 |
+
"""
|
| 127 |
+
pp_schedule_csv = job_config.parallelism.pipeline_parallel_schedule_csv
|
| 128 |
+
|
| 129 |
+
# Validate that pp_schedule_csv is a valid path
|
| 130 |
+
if pp_schedule_csv:
|
| 131 |
+
if not os.path.isfile(pp_schedule_csv):
|
| 132 |
+
raise FileNotFoundError(
|
| 133 |
+
f"The specified path {pp_schedule_csv} does not exist or is not a file."
|
| 134 |
+
)
|
| 135 |
+
schedule_class = _PipelineScheduleRuntime
|
| 136 |
+
else:
|
| 137 |
+
schedule_class = get_schedule_class(
|
| 138 |
+
job_config.parallelism.pipeline_parallel_schedule
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
looped_schedule = issubclass(schedule_class, PipelineScheduleMulti)
|
| 142 |
+
microbatch_size = job_config.parallelism.pipeline_parallel_microbatch_size
|
| 143 |
+
batch_size = job_config.training.batch_size
|
| 144 |
+
# validate that the batch size is divisible by the microbatch_size otherwise we'll hang or error during training
|
| 145 |
+
if batch_size % microbatch_size != 0:
|
| 146 |
+
raise ValueError(
|
| 147 |
+
f"Batch size {job_config.training.batch_size} must be divisible by number of microbatches {n_microbatches}. "
|
| 148 |
+
"Update the config arguments for either batch_size or pipeline_parallel_microbatch_size."
|
| 149 |
+
)
|
| 150 |
+
n_microbatches = batch_size // microbatch_size
|
| 151 |
+
# We expect that the number of local stages (`len(stages)`) is the same across all ranks
|
| 152 |
+
num_total_stages = job_config.parallelism.pipeline_parallel_degree * len(stages)
|
| 153 |
+
if n_microbatches < num_total_stages:
|
| 154 |
+
logger.warning(
|
| 155 |
+
f"Number of microbatches ({n_microbatches}) is less than the total number "
|
| 156 |
+
f"of stages ({num_total_stages}) which may result in a bubble in the pipeline."
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
schedule = schedule_class(
|
| 160 |
+
stages if looped_schedule else stages[0],
|
| 161 |
+
n_microbatches=n_microbatches,
|
| 162 |
+
loss_fn=loss_fn,
|
| 163 |
+
)
|
| 164 |
+
logger.info(
|
| 165 |
+
f"Using pipeline schedule {job_config.parallelism.pipeline_parallel_schedule} "
|
| 166 |
+
f"with {n_microbatches} microbatches and {num_total_stages} stages."
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
if pp_schedule_csv:
|
| 170 |
+
assert schedule_class in [
|
| 171 |
+
PipelineScheduleSingle,
|
| 172 |
+
PipelineScheduleMulti,
|
| 173 |
+
_PipelineScheduleRuntime,
|
| 174 |
+
], (
|
| 175 |
+
"Only PipelineScheduleSingle (single stage), PipelineScheduleMulti (multistage), "
|
| 176 |
+
"and _PipelineScheduleRuntime support csv schedules"
|
| 177 |
+
)
|
| 178 |
+
schedule._load_csv(pp_schedule_csv)
|
| 179 |
+
|
| 180 |
+
return schedule
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
# TODO(whc) should this be a utility inside torch.pipelining?
|
| 184 |
+
def stage_ids_this_rank(
|
| 185 |
+
pp_rank: int, pp_size: int, num_stages: int, style: str = "loop"
|
| 186 |
+
) -> tuple[int]:
|
| 187 |
+
"""Compute the stage ids for the stages that will run on this pp rank for either a looped or V style schedule"""
|
| 188 |
+
assert (
|
| 189 |
+
num_stages % pp_size == 0
|
| 190 |
+
), f"num_stages {num_stages} must be evenly divisible by pp_size {pp_size}"
|
| 191 |
+
stages_per_rank = num_stages // pp_size
|
| 192 |
+
if style == "loop":
|
| 193 |
+
return tuple(pp_rank + s * pp_size for s in range(stages_per_rank))
|
| 194 |
+
elif style == "v":
|
| 195 |
+
assert (
|
| 196 |
+
stages_per_rank == 2
|
| 197 |
+
), f"v schedules assume 2 stages per rank, got {stages_per_rank}"
|
| 198 |
+
stage_v_pairs = list(
|
| 199 |
+
zip(range(pp_size), range(num_stages - 1, pp_size - 1, -1))
|
| 200 |
+
)
|
| 201 |
+
return stage_v_pairs[pp_rank]
|
torchtitan/experiments/deepseek_v3/LICENSE-CODE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2023 DeepSeek
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
torchtitan/experiments/deepseek_v3/README.md
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Running DeepSeek in Titan (experimental)
|
| 2 |
+
|
| 3 |
+
This folder contains a DeepSeek model supporting v2 and v3 as well as kernels
|
| 4 |
+
and scripts needed to run it.
|
| 5 |
+
|
| 6 |
+
## Inference
|
| 7 |
+
|
| 8 |
+
### Prerequisites:
|
| 9 |
+
|
| 10 |
+
You will need to download a DeepSeek model's weights if you want to run a
|
| 11 |
+
pre-trained checkpoint. We provided a script to download the weights from
|
| 12 |
+
HuggingFace Model Hub:
|
| 13 |
+
```bash
|
| 14 |
+
python download.py [vX]
|
| 15 |
+
```
|
| 16 |
+
where `vX` can be v2 or v3, both are supported. You may be required to create a
|
| 17 |
+
HuggingFace account and log in first.
|
| 18 |
+
|
| 19 |
+
### Running inference:
|
| 20 |
+
|
| 21 |
+
The inference script is in `generate.py`. You can run it with the following
|
| 22 |
+
command:
|
| 23 |
+
```bash
|
| 24 |
+
torchrun --standalone --nproc-per-node 4 generate.py
|
| 25 |
+
```
|
| 26 |
+
This will run inference on the `DeepSeek-V2-Lite-Chat` model using 4 GPUs by
|
| 27 |
+
default.
|
| 28 |
+
|
| 29 |
+
Alternatively, you can run inference by using `bash inference.sh`, optionally
|
| 30 |
+
followed by your prompt.
|
| 31 |
+
|
| 32 |
+
## Training
|
| 33 |
+
|
| 34 |
+
The training script is in `train.py`. You can run it by the following command:
|
| 35 |
+
```bash
|
| 36 |
+
torchrun --standalone --nproc-per-node 8 train.py
|
| 37 |
+
```
|
| 38 |
+
|
| 39 |
+
This will run training on the `DeepSeek-V2-Lite-Chat` model using 8 GPUs by
|
| 40 |
+
default, with pipeline parallel, expert parallel, and data parallel enabled.
|
torchtitan/experiments/deepseek_v3/attn_mask_utils.py
ADDED
|
@@ -0,0 +1,397 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# This code is based on src/transformers/modeling_attn_mask_utils.py of
|
| 8 |
+
# huggingface/transformers. It has been modified from its original forms to
|
| 9 |
+
# contain only the necessary utilities.
|
| 10 |
+
|
| 11 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
| 12 |
+
#
|
| 13 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 14 |
+
# you may not use this file except in compliance with the License.
|
| 15 |
+
# You may obtain a copy of the License at
|
| 16 |
+
#
|
| 17 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 18 |
+
#
|
| 19 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 20 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 21 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 22 |
+
# See the License for the specific language governing permissions and
|
| 23 |
+
# limitations under the License.
|
| 24 |
+
from dataclasses import dataclass
|
| 25 |
+
from typing import List, Optional, Tuple, Union
|
| 26 |
+
|
| 27 |
+
import torch
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@dataclass
|
| 31 |
+
class AttentionMaskConverter:
|
| 32 |
+
"""
|
| 33 |
+
A utility attention mask class that allows one to:
|
| 34 |
+
- Create a causal 4d mask
|
| 35 |
+
- Create a causal 4d mask with slided window
|
| 36 |
+
- Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length,
|
| 37 |
+
key_value_length) that can be multiplied with attention scores
|
| 38 |
+
|
| 39 |
+
Examples:
|
| 40 |
+
|
| 41 |
+
```python
|
| 42 |
+
>>> import torch
|
| 43 |
+
>>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
| 44 |
+
|
| 45 |
+
>>> converter = AttentionMaskConverter(True)
|
| 46 |
+
>>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32)
|
| 47 |
+
tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
|
| 48 |
+
[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
|
| 49 |
+
[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
|
| 50 |
+
[-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, -3.4028e+38],
|
| 51 |
+
[-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, 0.0000e+00]]]])
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
Parameters:
|
| 55 |
+
is_causal (`bool`):
|
| 56 |
+
Whether the attention mask should be a uni-directional (causal) or bi-directional mask.
|
| 57 |
+
|
| 58 |
+
sliding_window (`int`, *optional*):
|
| 59 |
+
Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer.
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
is_causal: bool
|
| 63 |
+
sliding_window: int
|
| 64 |
+
|
| 65 |
+
def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):
|
| 66 |
+
self.is_causal = is_causal
|
| 67 |
+
self.sliding_window = sliding_window
|
| 68 |
+
|
| 69 |
+
if self.sliding_window is not None and self.sliding_window <= 0:
|
| 70 |
+
raise ValueError(
|
| 71 |
+
"Make sure that when passing `sliding_window` that its value is a strictly positive integer, "
|
| 72 |
+
f"not `{self.sliding_window}`"
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
def to_causal_4d(
|
| 76 |
+
self,
|
| 77 |
+
batch_size: int,
|
| 78 |
+
query_length: int,
|
| 79 |
+
key_value_length: int,
|
| 80 |
+
dtype: torch.dtype,
|
| 81 |
+
device: Union[torch.device, "str"] = "cpu",
|
| 82 |
+
) -> Optional[torch.Tensor]:
|
| 83 |
+
"""
|
| 84 |
+
Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
|
| 85 |
+
bias to upper right hand triangular matrix (causal mask).
|
| 86 |
+
"""
|
| 87 |
+
if not self.is_causal:
|
| 88 |
+
raise ValueError(
|
| 89 |
+
f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True."
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
# If shape is not cached, create a new causal mask and cache it
|
| 93 |
+
input_shape = (batch_size, query_length)
|
| 94 |
+
past_key_values_length = key_value_length - query_length
|
| 95 |
+
|
| 96 |
+
# create causal mask
|
| 97 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
| 98 |
+
causal_4d_mask = None
|
| 99 |
+
if input_shape[-1] > 1 or self.sliding_window is not None:
|
| 100 |
+
causal_4d_mask = self._make_causal_mask(
|
| 101 |
+
input_shape,
|
| 102 |
+
dtype,
|
| 103 |
+
device=device,
|
| 104 |
+
past_key_values_length=past_key_values_length,
|
| 105 |
+
sliding_window=self.sliding_window,
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
return causal_4d_mask
|
| 109 |
+
|
| 110 |
+
def to_4d(
|
| 111 |
+
self,
|
| 112 |
+
attention_mask_2d: torch.Tensor,
|
| 113 |
+
query_length: int,
|
| 114 |
+
dtype: torch.dtype,
|
| 115 |
+
key_value_length: Optional[int] = None,
|
| 116 |
+
) -> torch.Tensor:
|
| 117 |
+
"""
|
| 118 |
+
Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
|
| 119 |
+
key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
|
| 120 |
+
causal, a causal mask will be added.
|
| 121 |
+
"""
|
| 122 |
+
input_shape = (attention_mask_2d.shape[0], query_length)
|
| 123 |
+
|
| 124 |
+
# create causal mask
|
| 125 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
| 126 |
+
causal_4d_mask = None
|
| 127 |
+
if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
|
| 128 |
+
if key_value_length is None:
|
| 129 |
+
raise ValueError(
|
| 130 |
+
"This attention mask converter is causal. Make sure to pass "
|
| 131 |
+
"`key_value_length` to correctly create a causal mask."
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
past_key_values_length = key_value_length - query_length
|
| 135 |
+
causal_4d_mask = self._make_causal_mask(
|
| 136 |
+
input_shape,
|
| 137 |
+
dtype,
|
| 138 |
+
device=attention_mask_2d.device,
|
| 139 |
+
past_key_values_length=past_key_values_length,
|
| 140 |
+
sliding_window=self.sliding_window,
|
| 141 |
+
)
|
| 142 |
+
elif self.sliding_window is not None:
|
| 143 |
+
raise NotImplementedError(
|
| 144 |
+
"Sliding window is currently only implemented for causal masking"
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
| 148 |
+
expanded_attn_mask = self._expand_mask(
|
| 149 |
+
attention_mask_2d, dtype, tgt_len=input_shape[-1]
|
| 150 |
+
).to(attention_mask_2d.device)
|
| 151 |
+
|
| 152 |
+
if causal_4d_mask is not None:
|
| 153 |
+
expanded_attn_mask = causal_4d_mask.masked_fill(
|
| 154 |
+
expanded_attn_mask.bool(), torch.finfo(dtype).min
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
# expanded_attn_mask + causal_4d_mask can cause some overflow
|
| 158 |
+
expanded_4d_mask = expanded_attn_mask
|
| 159 |
+
|
| 160 |
+
return expanded_4d_mask
|
| 161 |
+
|
| 162 |
+
@staticmethod
|
| 163 |
+
def _make_causal_mask(
|
| 164 |
+
input_ids_shape: torch.Size,
|
| 165 |
+
dtype: torch.dtype,
|
| 166 |
+
device: torch.device,
|
| 167 |
+
past_key_values_length: int = 0,
|
| 168 |
+
sliding_window: Optional[int] = None,
|
| 169 |
+
):
|
| 170 |
+
"""
|
| 171 |
+
Make causal mask used for bi-directional self-attention.
|
| 172 |
+
"""
|
| 173 |
+
bsz, tgt_len = input_ids_shape
|
| 174 |
+
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
|
| 175 |
+
mask_cond = torch.arange(mask.size(-1), device=device)
|
| 176 |
+
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
| 177 |
+
|
| 178 |
+
mask = mask.to(dtype)
|
| 179 |
+
|
| 180 |
+
if past_key_values_length > 0:
|
| 181 |
+
mask = torch.cat(
|
| 182 |
+
[
|
| 183 |
+
torch.zeros(
|
| 184 |
+
tgt_len, past_key_values_length, dtype=dtype, device=device
|
| 185 |
+
),
|
| 186 |
+
mask,
|
| 187 |
+
],
|
| 188 |
+
dim=-1,
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
# add lower triangular sliding window mask if necessary
|
| 192 |
+
if sliding_window is not None:
|
| 193 |
+
diagonal = past_key_values_length - sliding_window - 1
|
| 194 |
+
|
| 195 |
+
context_mask = torch.tril(
|
| 196 |
+
torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal
|
| 197 |
+
)
|
| 198 |
+
mask.masked_fill_(context_mask, torch.finfo(dtype).min)
|
| 199 |
+
|
| 200 |
+
return mask[None, None, :, :].expand(
|
| 201 |
+
bsz, 1, tgt_len, tgt_len + past_key_values_length
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
@staticmethod
|
| 205 |
+
def _expand_mask(
|
| 206 |
+
mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
|
| 207 |
+
):
|
| 208 |
+
"""
|
| 209 |
+
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
| 210 |
+
"""
|
| 211 |
+
bsz, src_len = mask.size()
|
| 212 |
+
tgt_len = tgt_len if tgt_len is not None else src_len
|
| 213 |
+
|
| 214 |
+
expanded_mask = (
|
| 215 |
+
mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
inverted_mask = 1.0 - expanded_mask
|
| 219 |
+
|
| 220 |
+
return inverted_mask.masked_fill(
|
| 221 |
+
inverted_mask.to(torch.bool), torch.finfo(dtype).min
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
@staticmethod
def _unmask_unattended(
    expanded_mask: torch.FloatTensor,
    min_dtype: float,
):
    """
    Zero out (i.e. fully attend) every row of `expanded_mask` that is entirely
    masked — for example the leading rows produced by left padding.

    This is required by F.scaled_dot_product_attention's memory-efficient
    path, which misbehaves on rows where every key is masked.
    Details: https://github.com/pytorch/pytorch/issues/110213

    `expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or
    [bsz, tgt_seq_len, src_seq_len]; `num_masks` is most often 1, but may be
    the number of heads in the case of alibi attention bias. Rows that are
    entirely `min_dtype` come back all-zero; every other row is unchanged.
    """
    if expanded_mask.dtype == torch.bool:
        raise ValueError(
            "AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor."
        )

    # A row in which every entry equals min_dtype is fully masked. Multiplying
    # such rows by 0 turns them into all-zero (fully attended) rows, while all
    # remaining rows are multiplied by 1 and left untouched.
    fully_masked_rows = torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True)
    return expanded_mask.mul(~fully_masked_rows)
|
| 275 |
+
|
| 276 |
+
@staticmethod
def _ignore_causal_mask_sdpa(
    attention_mask: Optional[torch.Tensor],
    inputs_embeds: torch.Tensor,
    past_key_values_length: int,
    sliding_window: Optional[int] = None,
    is_training: bool = False,
) -> bool:
    """
    Detects whether the optional user-specified attention_mask & the automatically created causal mask can be
    ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument.

    In case no token is masked in the `attention_mask` argument, if `query_length == 1` or
    `key_value_length == query_length`, we rather rely on SDPA `is_causal` argument to use causal/non-causal masks,
    allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is
    passed).

    Args:
        attention_mask: optional user-provided mask (2D, or 4D custom mask).
        inputs_embeds: embedded inputs; only its shape and proxy-ness are inspected.
        past_key_values_length: length of the KV cache.
        sliding_window: window size for models with windowed attention, if any.
        is_training: training mode flag, relaxes the tracing heuristic below.
    """

    _, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1]
    key_value_length = query_length + past_key_values_length

    # Under tracing/compilation `is_causal` must remain a static Python bool,
    # so the dynamic shortcuts below are applied more conservatively.
    is_tracing = (
        torch.jit.is_tracing()
        or isinstance(inputs_embeds, torch.fx.Proxy)
        or is_torchdynamo_compiling()
    )

    ignore_causal_mask = False

    if attention_mask is None:
        # TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input
        # shape, thus SDPA's `is_causal` argument is rightfully updated
        # (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363). However, when using
        # `torch.export` or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is
        # hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code `is_causal=True`
        # which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108).
        # Thus, we only set `ignore_causal_mask = True` if the model is set to training.
        #
        # Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal`
        # ("TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor").
        if (
            (is_training or not is_tracing)
            and (query_length == 1 or key_value_length == query_length)
            and (sliding_window is None or key_value_length < sliding_window)
        ):
            ignore_causal_mask = True
    elif sliding_window is None or key_value_length < sliding_window:
        # A user-supplied 4D mask must always be honored as-is.
        if len(attention_mask.shape) == 4:
            return False
        elif not is_tracing and torch.all(attention_mask == 1):
            if query_length == 1 or key_value_length == query_length:
                # For query_length == 1, causal attention and bi-directional attention are the same.
                ignore_causal_mask = True

            # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore
            # the attention mask, as SDPA causal mask generation may be wrong. We will set `is_causal=False` in
            # SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
            # Reference: https://github.com/pytorch/pytorch/issues/108108
            # TODO: maybe revisit this with https://github.com/pytorch/pytorch/pull/114823 in PyTorch 2.3.

    return ignore_causal_mask
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
def _prepare_4d_causal_attention_mask(
    attention_mask: Optional[torch.Tensor],
    input_shape: Union[torch.Size, Tuple, List],
    inputs_embeds: torch.Tensor,
    past_key_values_length: int,
    sliding_window: Optional[int] = None,
):
    """
    Build a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)`
    from an optional 2D mask of shape `(batch_size, key_value_length)`.

    Args:
        attention_mask (`torch.Tensor` or `None`):
            A 2D attention mask of shape `(batch_size, key_value_length)`, or
            an already-expanded 4D mask.
        input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
            The input shape, defining `(batch_size, query_length)`.
        inputs_embeds (`torch.Tensor`):
            The embedded inputs; only its dtype and device are used.
        past_key_values_length (`int`):
            The length of the key value cache.
        sliding_window (`int`, *optional*):
            If the model uses windowed attention, a sliding window should be passed.
    """
    converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)

    key_value_length = input_shape[-1] + past_key_values_length

    # Standard path: expand the user's 2D padding mask and fuse causality in.
    if attention_mask is not None and len(attention_mask.shape) == 2:
        return converter.to_4d(
            attention_mask,
            input_shape[-1],
            key_value_length=key_value_length,
            dtype=inputs_embeds.dtype,
        )

    # A 4D mask is passed straight through the layers after validation.
    if attention_mask is not None and len(attention_mask.shape) == 4:
        expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
        if tuple(attention_mask.shape) != expected_shape:
            raise ValueError(
                f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
            )
        # The 4D mask has the correct shape — invert it (1 = attend, 0 = block)
        # and fill blocked positions with dtype's most negative value.
        inverted_mask = 1.0 - attention_mask
        return inverted_mask.masked_fill(
            inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
        )

    # No usable user mask: build a plain causal (optionally windowed) mask.
    return converter.to_causal_4d(
        input_shape[0],
        input_shape[-1],
        key_value_length,
        dtype=inputs_embeds.dtype,
        device=inputs_embeds.device,
    )
|
torchtitan/experiments/deepseek_v3/download.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# Usage:
|
| 8 |
+
# Downloads a given model to the HF Cache. Pass in a listed option ala "v3" or your own custom model path.
|
| 9 |
+
# python download.py {model_id} [custom_model_path]
|
| 10 |
+
# Examples:
|
| 11 |
+
# python download.py v2 # Use predefined model: deepseek-ai/DeepSeek-V2
|
| 12 |
+
# python download.py custom "deepseek-ai/new-model" # Download a custom model path
|
| 13 |
+
|
| 14 |
+
# Available models:
|
| 15 |
+
# "v2-lite-chat": "deepseek-ai/DeepSeek-V2-Lite-Chat",
|
| 16 |
+
# "v2-lite": "deepseek-ai/DeepSeek-V2-Lite",
|
| 17 |
+
# "v2": "deepseek-ai/DeepSeek-V2",
|
| 18 |
+
# "v3": "deepseek-ai/deepseek-v3",
|
| 19 |
+
# "v3-0324": "deepseek-ai/DeepSeek-V3-0324",
|
| 20 |
+
# "custom": None, # Placeholder for custom models
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
import sys
|
| 24 |
+
|
| 25 |
+
from transformers import AutoModelForCausalLM
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# Short CLI aliases mapped to Hugging Face model paths. The "custom" key is a
# sentinel handled specially below: the user supplies the model path as the
# next command-line argument.
MODELS = {
    "v2-lite-chat": "deepseek-ai/DeepSeek-V2-Lite-Chat",
    "v2-lite": "deepseek-ai/DeepSeek-V2-Lite",
    "v2": "deepseek-ai/DeepSeek-V2",
    "v3": "deepseek-ai/deepseek-v3",
    "v3-0324": "deepseek-ai/DeepSeek-V3-0324",
    "custom": None,  # For custom (any) models
}
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def print_usage():
    """Print CLI usage and the table of predefined models, then exit(1)."""
    print("Usage:")
    print("  python download.py [model_version]")
    print("  python download.py custom [custom_model_path]")
    print("\nAvailable predefined models:")
    for key, model in MODELS.items():
        if key != "custom":  # Skip the custom placeholder
            print(f"  {key}: {model}")
    print("\nFor custom models:")
    print("  custom: Specify your own model path")
    print('  Example: python download.py custom "organization/model-name"')
    sys.exit(1)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# Process command line arguments
# The first argument must be one of the MODELS keys; anything else (or no
# argument at all) prints usage and exits.
if len(sys.argv) < 2 or sys.argv[1] not in MODELS:
    print_usage()

if sys.argv[1] == "custom":
    # "custom" requires exactly one extra argument: the HF model path.
    if len(sys.argv) != 3:
        print("Error: Custom model requires a model path")
        print_usage()
    model_id = sys.argv[2]
    print(f"Using custom model: {model_id}")
else:
    model_id = MODELS[sys.argv[1]]
    print(f"Downloading model: {model_id}")

# Instantiating the model triggers the download into the local HF cache.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    trust_remote_code=True,
)
|
torchtitan/experiments/deepseek_v3/generate.py
ADDED
|
@@ -0,0 +1,308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# torchrun --standalone --nproc-per-node 4 generate.py
|
| 8 |
+
|
| 9 |
+
# use inference.sh "Your Question Here?" to run inference with a single prompt.
|
| 10 |
+
|
| 11 |
+
import sys
|
| 12 |
+
from dataclasses import dataclass
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.distributed as dist
|
| 16 |
+
|
| 17 |
+
from checkpoint import load_weights_from_hf
|
| 18 |
+
from model import DeepseekForCausalLM
|
| 19 |
+
from model_config import deepseek_config_registry
|
| 20 |
+
from torch.distributed.device_mesh import DeviceMesh
|
| 21 |
+
from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
|
| 22 |
+
from torchtitan.tools.utils import Color
|
| 23 |
+
from transformers import AutoTokenizer
|
| 24 |
+
|
| 25 |
+
# Uncomment the model you want to run.
# mesh_shape is (pp, ep): pipeline-parallel size x expert-parallel size,
# matching the mesh_dim_names used in __main__ below.
model_id, mesh_shape = "deepseek-ai/DeepSeek-V2-Lite-Chat", (1, 4)
# model_id, mesh_shape = "deepseek-ai/deepseek-v3", (8, 4)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def colorize_chat(text, user_color=None, assistant_color=None, output_color=None):
    """Parse and colorize chat output with optional colors for each role.

    Splits `text` into segments that begin with "Output:", "User:" or
    "Assistant:" and wraps each segment's content in the corresponding ANSI
    color (if provided). Leading text before any role marker is treated as an
    "Output:" segment. Colors are ANSI escape strings; `color.reset` closes
    each colored span.
    """
    lines = text.split("\n")
    result = []

    # Parser state: the role of the segment being accumulated and its lines.
    current_role = None
    current_content = []

    def _process_current_content():
        # Flush the accumulated segment into a single colorized string, or
        # return None when there is nothing to flush.
        if not current_role or not current_content:
            return None

        content = "\n".join(current_content)
        if current_role == "output":
            return (
                f"Output: {output_color}{content}{color.reset}"
                if output_color
                else f"Output: {content}"
            )
        else:
            try:
                # Color only the text after the "User:"/"Assistant:" prefix on
                # the first line; continuation lines are colored wholesale.
                prefix, rest = current_content[0].split(":", 1)
                role_color = user_color if current_role == "user" else assistant_color
                if role_color:
                    formatted = f"{prefix}:{role_color}{rest}{color.reset}"
                    if len(current_content) > 1:
                        formatted += (
                            f"{role_color}\n"
                            + "\n".join(current_content[1:])
                            + f"{color.reset}"
                        )
                    return formatted
            except ValueError:
                # First line had no ":" to split on — fall through uncolored.
                pass
            return content

    for line in lines:
        if line.startswith("Output:"):
            # Flush the previous segment, then emit this output line directly.
            if processed := _process_current_content():
                result.append(processed)
            current_role = "output"
            content = line[len("Output:") :].strip()
            if output_color:
                content = f"Output: {output_color}{content}{color.reset}"
            else:
                content = f"Output: {content}"
            result.append(content)
            current_content = []

        elif line.startswith("User:"):
            if processed := _process_current_content():
                result.append(processed)
            current_role = "user"
            current_content = [line]

        elif line.startswith("Assistant:"):
            if processed := _process_current_content():
                result.append(processed)
            current_role = "assistant"
            current_content = [line]

        else:
            if current_content:
                # Continuation line of the current segment.
                current_content.append(line)
            elif line.strip() and current_role is None:
                # Handle system message at the beginning
                current_role = "output"
                if output_color:
                    result.append(f"Output: {output_color}{line.strip()}{color.reset}")
                else:
                    result.append(f"Output: {line.strip()}")

    # Process the last segment
    if processed := _process_current_content():
        result.append(processed)

    return "\n".join(result)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
# Module-level ANSI color palette shared by the printing helpers below.
color = Color()


@dataclass
class DistConfig:
    """Distributed topology handles for pipeline-parallel (pp) x
    expert-parallel (ep) inference, derived from a 2D device mesh."""

    mesh: DeviceMesh  # full 2D device mesh with dims ("pp", "ep")
    pp_mesh: DeviceMesh  # pipeline-parallel sub-mesh
    ep_mesh: DeviceMesh  # expert-parallel sub-mesh
    pp_size: int  # number of pipeline stages
    ep_size: int  # number of expert-parallel ranks
    ep_rank: int  # this rank's local index in the ep mesh
    pp_rank: int  # this rank's local index in the pp mesh (stage index)
    device: torch.device  # local CUDA device for this rank
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def create_model(dist_config: DistConfig):
    """Build this rank's pipeline stage of the model.

    Configures the registered model args with the run's expert-parallel and
    pipeline topology, constructs the model on the local device, loads HF
    weights, and wraps it in a PipelineStage driven by a GPipe schedule.

    Returns:
        (model, pp_schedule) tuple.
    """
    model_args = deepseek_config_registry[model_id]
    # Stamp this run's parallelism layout onto the model configuration.
    model_args.ep_size = dist_config.ep_size
    model_args.num_stages = dist_config.pp_size
    model_args.stage_idx = dist_config.pp_rank
    model_args.max_seq_len = 16384

    # Build and load directly on the local device, inside the mesh context.
    with dist_config.device, dist_config.mesh:
        model = DeepseekForCausalLM(model_args)
        load_weights_from_hf(model, model_id, dist_config.device)
        model.eval()
        model.setup_symm_mem(torch.bfloat16, dist_config.device)

    stage = PipelineStage(
        model,
        dist_config.pp_rank,
        dist_config.pp_size,
        dist_config.device,
        group=dist_config.pp_mesh.get_group(),
    )
    # One microbatch per pipeline stage.
    pp_schedule = ScheduleGPipe(stage, dist_config.pp_size)
    return model, pp_schedule
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def create_dist_config(mesh: DeviceMesh):
    """Derive a DistConfig (sub-meshes, ranks, sizes, local device) from `mesh`.

    The local CUDA device is chosen as global rank modulo the number of
    visible GPUs.
    """
    pp_mesh = mesh["pp"]
    ep_mesh = mesh["ep"]
    local_device = torch.device(
        "cuda", dist.get_rank() % torch.cuda.device_count()
    )

    return DistConfig(
        mesh=mesh,
        pp_mesh=pp_mesh,
        ep_mesh=ep_mesh,
        pp_rank=pp_mesh.get_local_rank(),
        pp_size=pp_mesh.size(),
        ep_size=ep_mesh.size(),
        ep_rank=ep_mesh.get_local_rank(),
        device=local_device,
    )
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def decode(tokenizer, x):
    """Decode the first sequence in `x` to text and colorize it by chat role.

    Strips the BOS token, truncates at the first EOS token, and delegates the
    role-based coloring to `colorize_chat`.
    """
    output = tokenizer.decode(x[0])
    # Clean up the output by removing special tokens
    bos = tokenizer.bos_token
    output = output.replace(bos, "")
    # Truncate at end of sentence token
    eos_token = tokenizer.eos_token
    if eos_token and eos_token in output:
        output = output.split(eos_token)[0]
    colored_output = colorize_chat(
        output,
        user_color=color.green,
        assistant_color=color.cyan,
        output_color=color.blue,
    )
    return colored_output
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
@torch.inference_mode()
def generate(
    model,
    pp_schedule,
    tokenizer,
    dist_config,
    messages: list[dict],
    n_tokens: int = 50,
):
    """Greedy-decode `n_tokens` tokens eagerly (no CUDA graph).

    With pipeline parallelism every stage runs the GPipe schedule each step;
    the last stage computes the next token and broadcasts the updated token
    buffer to the whole pp group so all ranks stay in sync. Rank 0 prints the
    decoded, colorized result.
    """
    rank = dist.get_rank()
    device = dist_config.device
    # Duplicate the conversation across pipeline ranks (one copy per stage).
    x = tokenizer.apply_chat_template(
        [messages] * dist_config.pp_size,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    next_idx = x.shape[-1]  # position where the first generated token goes
    # Pre-extend the buffer so generated tokens can be written in place.
    x = torch.cat([x, torch.zeros(x.shape[0], n_tokens, dtype=torch.int64)], dim=-1)
    x = x.to(device)

    for _ in range(n_tokens):
        if dist_config.pp_size > 1:
            if dist_config.pp_rank == 0:
                # First stage feeds the tokens in, then receives the updated
                # buffer from the last stage.
                pp_schedule.step(x)
                torch.distributed.broadcast(
                    x,
                    group=dist_config.pp_mesh.get_group(),
                    group_src=dist_config.pp_size - 1,
                )
            elif dist_config.pp_rank == dist_config.pp_size - 1:
                # Last stage holds the logits: pick the next token greedily
                # and share the buffer with the other stages.
                preds = pp_schedule.step()
                next_token = torch.argmax(preds[:, next_idx - 1], dim=-1)
                x[:, next_idx] = next_token
                torch.distributed.broadcast(
                    x,
                    group=dist_config.pp_mesh.get_group(),
                    group_src=dist_config.pp_size - 1,
                )
            else:
                # Middle stages only relay activations, then receive the
                # updated buffer.
                pp_schedule.step()
                torch.distributed.broadcast(
                    x,
                    group=dist_config.pp_mesh.get_group(),
                    group_src=dist_config.pp_size - 1,
                )

            next_idx += 1
        else:
            # Single-stage path: plain forward pass plus greedy argmax.
            preds = model(x)
            next_token = torch.argmax(preds[:, next_idx - 1], dim=-1)
            x[:, next_idx] = next_token
            next_idx += 1

    if rank == 0:
        colored_output = decode(tokenizer, x)
        print(f"Without CUDA Graph:\n{colored_output}")
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
@torch.inference_mode()
def generate_with_cuda_graph(
    model,
    tokenizer,
    dist_config,
    messages: list[dict],
    n_tokens: int = 10,
):
    """Greedy-decode `n_tokens` tokens by replaying a captured CUDA graph.

    The forward pass is captured once; each replay re-reads the same `x`
    buffer and re-writes the same `preds` buffer, so updating `x` in place
    between replays feeds the newly chosen token into the next iteration.
    NOTE(review): this path runs the plain `model(x)` forward only — it does
    not drive the pipeline schedule.
    """
    rank = dist.get_rank()
    device = dist_config.device
    x = tokenizer.apply_chat_template(
        [messages] * dist_config.pp_size,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    next_idx = x.shape[-1]  # position where the first generated token goes
    x = torch.cat([x, torch.zeros(x.shape[0], n_tokens, dtype=torch.int64)], dim=-1)
    x = x.to(device)

    # Graph capture requires all prior work on the device to be finished.
    torch.cuda.synchronize()

    # Create CUDA graph
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        preds = model(x)

    # Run CUDA graph
    for _ in range(n_tokens):
        g.replay()
        next_token = torch.argmax(preds[:, next_idx - 1], dim=-1)
        x[:, next_idx] = next_token
        next_idx += 1

    if rank == 0:
        colored_output = decode(tokenizer, x)
        print(f"With CUDA Graph:\n{colored_output}")
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
if __name__ == "__main__":
    # Get user prompt from command line arguments
    user_prompt = "What is 2+2?"  # Default prompt
    if len(sys.argv) > 1:
        user_prompt = sys.argv[1]

    # 2D mesh: pipeline stages ("pp") x expert-parallel ranks ("ep"),
    # one process per GPU (launched via torchrun).
    mesh = dist.init_device_mesh("cuda", mesh_shape, mesh_dim_names=("pp", "ep"))
    rank = dist.get_rank()
    if rank == 0:
        print(
            f"{color.yellow}Running inference with {model_id} on {mesh_shape} mesh{color.reset}"
        )

    dist_config = create_dist_config(mesh)
    model, pp_schedule = create_model(dist_config)
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": user_prompt},
    ]

    # Run the same prompt twice: eagerly, then through a captured CUDA graph.
    generate(model, pp_schedule, tokenizer, dist_config, messages)
    generate_with_cuda_graph(model, tokenizer, dist_config, messages)

    if rank == 0:
        print(f"\n{color.yellow}Closing inference mesh...{color.reset}")

    dist.destroy_process_group()
|
torchtitan/experiments/deepseek_v3/indices.py
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import triton
|
| 9 |
+
import triton.language as tl
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# Public API of this module; the fill helpers below are internal.
__all__ = ["generate_permute_indices"]
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@triton.jit
def fill_indices_kernel(
    tokens_per_expert_group_ptr,  # *Pointer* to first input vector.
    start_index_values_ptr,  # *Pointer* to second input vector.
    write_offsets_ptr,  # *Pointer* to third input vector.
    output_ptr,  # *Pointer* to output vector.
    experts_per_rank,  # Number of experts per rank.
    num_ranks,  # Number of expert ranks.
):
    """Write permutation indices: for each local expert, lay out the source
    indices of its (rank, expert) chunks contiguously starting at that
    expert's write offset. See `fill_indices_cpu` for reference semantics."""
    # There are multiple 'programs' processing different data. We identify which program
    # we are here:
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    # The total number of programs in the launch grid.
    num_programs = tl.num_programs(axis=0)
    # We map the programs (blocks) to the experts: each program strides over
    # experts pid, pid + num_programs, ...
    for expert_id in tl.range(pid, experts_per_rank, step=num_programs):
        # Read this expert's write offset.
        write_offset = tl.load(write_offsets_ptr + expert_id)
        # Loop over the ranks.
        for r in tl.range(num_ranks):
            # Slot in the tokens_per_expert_group array.
            i = r * experts_per_rank + expert_id
            start_index = tl.load(start_index_values_ptr + i)
            length = tl.load(tokens_per_expert_group_ptr + i)
            # Write the indices.
            for l in tl.range(length):
                val = start_index + l
                tl.store(output_ptr + write_offset + l, val)
            write_offset += length
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def fill_indices(
    tokens_per_expert_group: torch.Tensor,
    start_index_values: torch.Tensor,
    write_offsets: torch.Tensor,
    experts_per_rank: int,
    num_ranks: int,
    max_len: int,
):
    """GPU implementation of the permutation-index fill (see
    `fill_indices_cpu` for the reference semantics). Allocates the -1-padded
    output on the input's device and launches the Triton kernel to fill it."""
    # We need to preallocate the output.
    permuted_indices = torch.full(
        (max_len,), -1, dtype=torch.int32, device=tokens_per_expert_group.device
    )
    # Analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int].
    # In this case, we use a 1D grid where the size is the number of blocks (TODO: bump this value).
    grid = lambda meta: (1,)
    # Each torch.tensor object is implicitly converted into a pointer to its first element.
    fill_indices_kernel[grid](
        tokens_per_expert_group,
        start_index_values,
        write_offsets,
        permuted_indices,
        experts_per_rank,
        num_ranks,
    )
    return permuted_indices
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def fill_indices_cpu(
    tokens_per_expert_group: torch.Tensor,
    start_index_values: torch.Tensor,
    write_offsets: torch.Tensor,
    experts_per_rank: int,
    num_ranks: int,
    max_len: int,
):
    """
    CPU reference implementation of `fill_indices`.

    Builds a permutation vector of length `max_len` where, for each local
    expert, the token indices contributed by every rank are laid out
    contiguously starting at that expert's write offset. Slots that are never
    written stay -1 (padding introduced by alignment).

    Args:
        tokens_per_expert_group: (num_ranks * experts_per_rank,) token counts,
            grouped by rank then expert.
        start_index_values: exclusive prefix sum of `tokens_per_expert_group`;
            the read start of each (rank, expert) chunk.
        write_offsets: write start of each local expert in the output.
        experts_per_rank: number of experts owned by this rank.
        num_ranks: number of expert-parallel ranks.
        max_len: length of the returned index vector.

    Returns:
        1D int32 tensor of length `max_len` (-1 in unused slots).
    """
    # We need to preallocate the output.
    permuted_indices = torch.full((max_len,), -1, dtype=torch.int32)
    # Fill the permuted indices
    # For each local expert
    for e in range(experts_per_rank):
        # Convert to a Python int: indexing a tensor yields a 0-d *view*, and
        # `+=` on that view would silently mutate the caller's
        # `write_offsets` tensor in place.
        write_start = int(write_offsets[e])
        # For each remote rank
        for r in range(num_ranks):
            # Slot of (rank r, expert e) in the flattened per-rank layout.
            i = r * experts_per_rank + e
            start_index = int(start_index_values[i])
            length = int(tokens_per_expert_group[i])
            # Fill in the indices for this chunk.
            permuted_indices[write_start : write_start + length] = torch.arange(
                start_index, start_index + length
            )
            write_start += length
    return permuted_indices
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def generate_permute_indices(
    tokens_per_expert_group: torch.Tensor,
    experts_per_rank: int,
    num_ranks: int,
    max_len: int,
    alignment: int,
    use_cpu: bool = False,
):
    """
    Prepare permutation indices and the number of tokens for each expert.

    The permutation indices are the indices of the tokens for each expert. The
    number of tokens for each expert is the sum of the number of tokens for
    such experts from all ranks. This number is aligned to the provided
    alignment requirement (usually comes from group gemm).

    Args:
        tokens_per_expert_group: number of tokens for each expert from all ranks.
        experts_per_rank: number of experts per rank.
        num_ranks: number of ranks.
        max_len: maximum length of the output index vector. If greater than
            total number of tokens, the remaining indices are set to -1.
        alignment: alignment for each returned element in `m_sizes`.
        use_cpu: whether to use cpu or gpu.

    Returns:
        permuted_indices: permutation indices.
        m_sizes: number of tokens for each expert.

    `tokens_per_expert_group` is of shape (num_ranks * experts_per_rank,), e.g.:
        From: |       rank 0      |      rank 1       |
        To:   | E0 | E1 | E2 | E3 | E0 | E1 | E2 | E3 |
              |  4 |  2 |  1 |  3 |  1 |  2 |  3 |  4 |
    """
    # Prefix sum to get the start index value of each expert
    start_index_values = (
        torch.cumsum(tokens_per_expert_group, 0) - tokens_per_expert_group
    )
    # Chunk sizes for each expert
    chunk_size_per_expert = tokens_per_expert_group.view(num_ranks, -1).sum(0)
    # Align the chunk sizes to the given alignment
    m_sizes = ((chunk_size_per_expert + alignment - 1) // alignment * alignment).to(
        torch.int32
    )
    # Perform another prefix sum to get the write offset of each expert in `permuted_indices`
    write_offsets = torch.cumsum(m_sizes, 0) - m_sizes
    # Select the method to fill the permuted indices
    fill_fn = fill_indices_cpu if use_cpu else fill_indices
    # Fill the permuted indices
    permuted_indices = fill_fn(
        tokens_per_expert_group,
        start_index_values,
        write_offsets,
        experts_per_rank,
        num_ranks,
        max_len,
    )
    return permuted_indices, m_sizes
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
# Below is for testing only
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def test():
    """Smoke test: check that the GPU and CPU fill paths agree.

    Runs `generate_permute_indices` with both the default (GPU) fill kernel
    and the CPU reference implementation on a uniform token distribution,
    then asserts that the two outputs match and that every entry of
    `m_sizes` is a multiple of `alignment`.

    Requires a CUDA device.
    """
    device = torch.device("cuda", 0)
    experts_per_rank = 4
    num_ranks = 4
    # Uniform distribution: every expert on every rank receives 4 tokens.
    tokens_per_expert_group = torch.full(
        (num_ranks * experts_per_rank,), 4, dtype=torch.int32, device=device
    )
    max_len = 128
    alignment = 32
    # Use the GPU kernel
    permuted_indices_gpu, m_sizes = generate_permute_indices(
        tokens_per_expert_group, experts_per_rank, num_ranks, max_len, alignment
    )
    # Use the CPU method
    permuted_indices_cpu, _ = generate_permute_indices(
        tokens_per_expert_group,
        experts_per_rank,
        num_ranks,
        max_len,
        alignment,
        use_cpu=True,
    )
    # Check that the results are the same
    assert torch.equal(permuted_indices_gpu.cpu(), permuted_indices_cpu)
    # Every aligned chunk size must be a multiple of `alignment`.
    # BUG FIX: compare against a zeros tensor of the SAME dtype as `m_sizes`
    # (int32). `torch.equal` is dtype-sensitive, so the previous comparison
    # against a default float32 `torch.zeros(...)` could never succeed.
    assert torch.equal(
        torch.remainder(m_sizes, alignment),
        torch.zeros(experts_per_rank, dtype=m_sizes.dtype, device=device),
    )
    # Print the results
    print(permuted_indices_gpu)
    print(m_sizes)
    print("Success")
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
# Allow running this module directly as a quick correctness check:
#     python indices.py
if __name__ == "__main__":
    test()
|
torchtitan/experiments/deepseek_v3/inference.sh
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#!/usr/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# BUG FIX: the shebang must be the very first line of the file; previously a
# blank line preceded it, so the interpreter directive was ignored when the
# script was executed directly.

# Number of GPUs/ranks to launch (override with e.g. `NGPU=8 ./inference.sh`).
NGPU=${NGPU:-"4"}

# Get the prompt from command line argument or use a default
prompt="${1:-What is 2+2?}"

# Run the model with the prompt
torchrun --standalone --nproc-per-node ${NGPU} generate.py "$prompt"
|
torchtitan/experiments/deepseek_v3/model_config.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from dataclasses import dataclass, field
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@dataclass
class ModelArgs:
    r"""
    Configuration for DeepSeek models (V2 / V3 family), defining the model
    architecture for `DeepseekV3Model`-style implementations.

    Field defaults below correspond to DeepSeek-V3. Smaller variants (e.g.
    DeepSeek-V2-Lite) override fields at construction time — see
    `deepseek_v2_lite_config` in this module.

    Args:
        vocab_size: Vocabulary size; number of distinct token ids in
            `input_ids`.
        hidden_size: Dimension of the hidden representations.
        intermediate_size: Dimension of the dense MLP representations.
        moe_intermediate_size: Dimension of each MoE expert's MLP.
        num_hidden_layers: Number of Transformer decoder layers.
        num_nextn_predict_layers: Number of next-n-prediction (MTP) layers.
        num_attention_heads: Attention heads per layer.
        num_key_value_heads: Key/value heads for grouped-query attention
            (equal to `num_attention_heads` means plain MHA; 1 means MQA).
        n_shared_experts: Number of always-active shared experts.
        n_routed_experts: Number of routed experts.
        ep_size: Expert-parallel world size.
        routed_scaling_factor: Scaling factor applied to routed-expert output.
        kv_lora_rank / q_lora_rank: Low-rank compression ranks for the
            MLA key-value and query projections.
        qk_rope_head_dim / qk_nope_head_dim: Per-head query/key dims for the
            RoPE and non-RoPE parts, respectively.
        v_head_dim: Per-head value dimension.
        topk_method: Top-k selection method used by the routing gate.
        n_group / topk_group: Number of expert groups, and how many groups a
            token may select experts from.
        num_experts_per_tok: Experts selected per token.
        moe_layer_freq: Frequency of MoE layers among decoder layers.
        first_k_dense_replace: Number of leading dense (non-MoE) layers.
        norm_topk_prob: Whether to normalize routed-expert weights.
        scoring_func: Gate scoring function (e.g. "softmax", "sigmoid").
        aux_loss_alpha: Auxiliary load-balancing loss weight.
        seq_aux: Whether the aux loss is computed per individual sample.
        hidden_act: Activation function in the MLPs.
        max_position_embeddings: Maximum sequence length the rotary embedding
            table supports.
        initializer_range: Std-dev of the truncated-normal weight init.
        rms_norm_eps: Epsilon used by the RMSNorm layers.
        rope_theta: Base period of the RoPE embeddings.
        rope_scaling: RoPE scaling configuration dict (YaRN parameters here).
        attention_bias: Whether q/k/v/o projections use a bias.
        attention_dropout: Dropout ratio on attention probabilities.
        max_seq_len: Maximum sequence length for symmetric-memory buffers.
        dtype: Parameter/activation dtype name.
        num_stages / stage_idx: Pipeline-parallel stage count and this
            rank's stage index.
    """

    # --- Embedding / transformer trunk ---
    vocab_size: int = 129280
    hidden_size: int = 7168
    intermediate_size: int = 18432
    moe_intermediate_size: int = 2048
    num_hidden_layers: int = 61
    num_nextn_predict_layers: int = 1
    num_attention_heads: int = 128
    num_key_value_heads: int = 128
    # --- Mixture-of-experts ---
    n_shared_experts: int = 1
    n_routed_experts: int = 256
    ep_size: int = 1
    routed_scaling_factor: float = 2.5
    # --- Multi-head latent attention (MLA) ---
    kv_lora_rank: int = 512
    q_lora_rank: int = 1536
    qk_rope_head_dim: int = 64
    v_head_dim: int = 128
    qk_nope_head_dim: int = 128
    # --- Routing gate ---
    topk_method: str = "noaux_tc"
    n_group: int = 8
    topk_group: int = 4
    num_experts_per_tok: int = 8
    moe_layer_freq: int = 1
    first_k_dense_replace: int = 3
    norm_topk_prob: bool = True
    scoring_func: str = "sigmoid"
    aux_loss_alpha: float = 0.001
    seq_aux: bool = True
    # --- Misc architecture ---
    hidden_act: str = "silu"
    max_position_embeddings: int = 163840
    initializer_range: float = 0.02
    rms_norm_eps: float = 1e-6
    rope_theta: float = 10000.0
    # YaRN long-context RoPE scaling parameters.
    rope_scaling: dict = field(
        default_factory=lambda: {
            "beta_fast": 32,
            "beta_slow": 1,
            "factor": 40,
            "mscale": 1.0,
            "mscale_all_dim": 1.0,
            "original_max_position_embeddings": 4096,
            "type": "yarn",
        }
    )
    attention_bias: bool = False
    attention_dropout: float = 0.0
    # NOTE(review): unannotated, so this is a class attribute rather than a
    # dataclass field (it does not appear in __init__) — confirm intentional.
    pad_token_id = None
    # Added for symmetric memory
    max_seq_len: int = 4096
    dtype: str = "bfloat16"
    # Added for pipeline parallel
    num_stages: int = 1
    stage_idx: int = 0
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
# This is the configuration for deepseek-ai/DeepSeek-V2-Lite.
# It overrides the DeepSeek-V3 defaults of `ModelArgs` with the smaller
# V2-Lite architecture (27 layers, 64 routed experts, softmax gating).
deepseek_v2_lite_config = ModelArgs(
    vocab_size=102400,
    hidden_size=2048,
    intermediate_size=10944,
    moe_intermediate_size=1408,
    num_hidden_layers=27,
    num_attention_heads=16,
    num_key_value_heads=16,
    n_shared_experts=2,
    n_routed_experts=64,
    routed_scaling_factor=1.0,
    kv_lora_rank=512,
    # NOTE(review): None here despite the `int` annotation — V2-Lite has no
    # low-rank query compression; the model code presumably checks for None.
    q_lora_rank=None,
    qk_rope_head_dim=64,
    v_head_dim=128,
    qk_nope_head_dim=128,
    topk_method="greedy",
    n_group=1,
    topk_group=1,
    num_experts_per_tok=6,
    first_k_dense_replace=1,
    norm_topk_prob=False,
    scoring_func="softmax",
    max_position_embeddings=4096,
    # YaRN RoPE scaling, as shipped in the V2-Lite HF config.
    rope_scaling={
        "beta_fast": 32,
        "beta_slow": 1,
        "factor": 40,
        "mscale": 0.707,
        "mscale_all_dim": 0.707,
        "original_max_position_embeddings": 4096,
        "type": "yarn",
    },
)
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
# Model configuration registry
# Key is the model distribution ID on HuggingFace Hub.
# Both V2-Lite entries share the same architecture config; the V3 entry uses
# the `ModelArgs` defaults.
deepseek_config_registry = {
    "deepseek-ai/DeepSeek-V2-Lite": deepseek_v2_lite_config,
    "deepseek-ai/DeepSeek-V2-Lite-Chat": deepseek_v2_lite_config,
    "deepseek-ai/deepseek-v3": ModelArgs(),
}
|
torchtitan/experiments/deepseek_v3/symm_mem_recipes/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from .triton_on_device_all_to_all_v import OnDeviceAllToAllV
|
| 8 |
+
|
| 9 |
+
__all__ = [
|
| 10 |
+
"OnDeviceAllToAllV",
|
| 11 |
+
]
|
torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_barrier.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import triton
|
| 8 |
+
import triton.language as tl
|
| 9 |
+
|
| 10 |
+
from .triton_utils import get_flat_bid, get_flat_tid
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@triton.jit
def send_signal(addrs, sem: tl.constexpr):
    """Flip each 32-bit flag at `addrs` from 0 to 1, spinning until it succeeds.

    Each element performs an atomic compare-and-swap (expected 0, new value 1)
    on a system-scope flag and retries (PTX branch back to `send_signal:`)
    until the CAS observes 0 — i.e. until the previous consumer has reset the
    flag. Used with `wait_signal`, which flips the flag back from 1 to 0.

    Args:
        addrs: tensor of flag addresses (one per element).
        sem: memory-ordering semantics, "relaxed" or "acq_rel" (release on
            the sending side).
    """
    if sem == "relaxed":
        # Relaxed ordering: signal only, no memory-ordering guarantees.
        tl.inline_asm_elementwise(
            """
            {
                .reg .u32   %tmp32_<1>;
                .reg .pred  %p<1>;

            send_signal:
                atom.global.relaxed.sys.cas.b32 %tmp32_0, [$1], 0, 1;
                setp.eq.u32 %p0, %tmp32_0, 0;
                @!%p0 bra send_signal;
            }
            """,
            "=r, l",
            [addrs],
            dtype=tl.int32,
            is_pure=False,
            pack=1,
        )
    elif sem == "acq_rel":
        # Release semantics: prior writes from this thread become visible to
        # the peer that acquires this flag.
        tl.inline_asm_elementwise(
            """
            {
                .reg .u32   %tmp32_<1>;
                .reg .pred  %p<1>;

            send_signal:
                atom.global.release.sys.cas.b32 %tmp32_0, [$1], 0, 1;
                setp.eq.u32 %p0, %tmp32_0, 0;
                @!%p0 bra send_signal;
            }
            """,
            "=r, l",
            [addrs],
            dtype=tl.int32,
            is_pure=False,
            pack=1,
        )
    else:
        raise RuntimeError(f"Unrecognized sem: {sem}")
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
@triton.jit
def wait_signal(addrs, sem: tl.constexpr):
    """Spin until each 32-bit flag at `addrs` becomes 1, then reset it to 0.

    The inverse of `send_signal`: an atomic compare-and-swap (expected 1,
    new value 0) retried until it observes 1. Resetting to 0 re-arms the
    flag for the next barrier round, which is what makes the barrier CUDA
    graph friendly (no host-side counter increments).

    Args:
        addrs: tensor of flag addresses (one per element).
        sem: memory-ordering semantics, "relaxed" or "acq_rel" (acquire on
            the waiting side).
    """
    # NOTE(review): the qualifier order here (`.global.sys.relaxed` /
    # `.global.sys.acquire`) differs from `send_signal`'s
    # (`.global.relaxed.sys`) — presumably both are accepted by ptxas, but
    # worth confirming against the PTX ISA spec.
    if sem == "relaxed":
        tl.inline_asm_elementwise(
            """
            {
                .reg .u32   %tmp32_<1>;
                .reg .pred  %p<1>;

            wait_signal:
                atom.global.sys.relaxed.cas.b32 %tmp32_0, [$1], 1, 0;
                setp.eq.u32 %p0, %tmp32_0, 1;
                @!%p0 bra wait_signal;
            }
            """,
            "=r, l",
            [addrs],
            dtype=tl.int32,
            is_pure=False,
            pack=1,
        )
    elif sem == "acq_rel":
        # Acquire semantics: writes released by the signaling peer become
        # visible to this thread after the CAS succeeds.
        tl.inline_asm_elementwise(
            """
            {
                .reg .u32   %tmp32_<1>;
                .reg .pred  %p<1>;

            wait_signal:
                atom.global.sys.acquire.cas.b32 %tmp32_0, [$1], 1, 0;
                setp.eq.u32 %p0, %tmp32_0, 1;
                @!%p0 bra wait_signal;
            }
            """,
            "=r, l",
            [addrs],
            dtype=tl.int32,
            is_pure=False,
            pack=1,
        )
    else:
        raise RuntimeError(f"Unrecognized sem: {sem}")
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
@triton.jit
def blockwise_barrier(
    signal_pad_ptrs,
    block_id,
    rank: tl.constexpr,
    world_size: tl.constexpr,
    sem: tl.constexpr,
):
    """
    Synchronizes blocks with matching block_id across participating devices.

    Note: the function itself is not a system level barrier/fence. It is a
    building block for expressing different synchronization patterns.

    Pattern 0: Ensures that all writes to symm_mem buffers from previous
    kernels across all devices are visible to the current kernel:

        blockwise_barrier(..., sem="relaxed")
        sync_threads()

    Pattern 1: Ensures that all writes to symm_mem buffers from the current
    block are visible to all remote blocks with matching blockIdx:

        sync_threads()
        blockwise_barrier(..., sem="acq_rel")
        sync_threads()

    Pattern 2: Ensures that symm_mem buffers read by the current kernel are safe
    for writing by subsequent kernels across all devices.

        sync_threads()
        blockwise_barrier(..., sem="relaxed")

    CUDA graph friendliness:

    This barrier operates through atomic operations on a zero-filled signal
    pad, which resets to a zero-filled state after each successful
    synchronization. This design eliminates the need for incrementing a
    flag from host.

    Args:
        signal_pad_ptrs: device array of `world_size` pointers; entry r is
            rank r's signal pad (a uint32 grid of shape
            [num_blocks, world_size], laid out row-major).
        block_id: barrier slot to use; if None, the flattened block id of
            the current block is used.
        rank: this device's rank.
        world_size: number of participating devices.
        sem: "relaxed" or "acq_rel", forwarded to send/wait (see patterns).
    """
    if block_id is None:
        block_id = get_flat_bid()
    flat_tid = get_flat_tid()

    # One signaling lane per peer rank.
    remote_ranks = tl.arange(0, world_size)
    signal_pad_ptrs = signal_pad_ptrs.to(tl.pointer_type(tl.uint64))
    # Address of our slot (row `block_id`, column `rank`) in every peer's pad.
    remote_signal_pad_addrs = tl.load(signal_pad_ptrs + remote_ranks).to(
        tl.pointer_type(tl.uint32)
    )
    send_addrs = remote_signal_pad_addrs + block_id * world_size + rank

    # Addresses of all peers' slots in OUR pad (row `block_id`, one column
    # per remote rank) that we must see signaled before proceeding.
    local_signal_pad_addr = tl.load(signal_pad_ptrs + rank).to(
        tl.pointer_type(tl.uint32)
    )
    wait_addrs = local_signal_pad_addr + block_id * world_size + remote_ranks

    # Only the first `world_size` threads of the block participate; each
    # handles one peer (signal it, then wait for its signal).
    if flat_tid < world_size:
        send_signal(send_addrs, sem)
        wait_signal(wait_addrs, sem)
|
torchtitan/experiments/deepseek_v3/train.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# torchrun --standalone --nproc-per-node 8 run.py
|
| 8 |
+
import torch
|
| 9 |
+
import torch.distributed as dist
|
| 10 |
+
from checkpoint import load_weights_from_hf
|
| 11 |
+
from model import DeepseekForCausalLM
|
| 12 |
+
from model_config import deepseek_config_registry
|
| 13 |
+
|
| 14 |
+
from torch.distributed.device_mesh import DeviceMesh
|
| 15 |
+
from torch.distributed.fsdp import fully_shard
|
| 16 |
+
from torch.distributed.pipelining import PipelineStage, Schedule1F1B
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# Use DeepSeek-V2-Lite as a proxy
|
| 20 |
+
model_id = "deepseek-ai/DeepSeek-V2-Lite"
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# Run full model
|
| 24 |
+
def run_full_model(
    mesh: DeviceMesh,
):
    """Train a (shrunken) DeepSeek model for a few steps under 3-D parallelism.

    Demonstrates pipeline parallelism (PP) x expert parallelism (EP) x FSDP on
    synthetic data: builds the model from the registry config, loads HF
    weights, applies FSDP/HSDP sharding, then runs `steps` forward/backward
    iterations, either through a 1F1B pipeline schedule (pp_size > 1) or
    directly.

    Args:
        mesh: a 3-D device mesh with dims named ("pp", "ep", "fsdp").
    """
    rank = dist.get_rank()
    device_count = torch.cuda.device_count()
    # Map global rank onto a local CUDA device.
    device = torch.device("cuda", rank % device_count)

    pp_mesh = mesh["pp"]
    ep_mesh = mesh["ep"]
    pp_rank = pp_mesh.get_local_rank()
    ep_rank = ep_mesh.get_local_rank()
    pp_size = pp_mesh.size()
    ep_size = ep_mesh.size()

    # Get model configs
    model_args = deepseek_config_registry[model_id]
    # [Note]: I am making the model smaller for testing / avoiding OOM. If you
    # have sufficient GPUs for model parallelism, you can remove this line.
    model_args.num_hidden_layers = 16

    # Apply model parallelism: record EP degree and this rank's pipeline
    # stage so the model constructor builds only its shard.
    model_args.ep_size = ep_size
    model_args.num_stages = pp_size
    model_args.stage_idx = pp_rank
    print(model_args)

    # Instantiate model (under the device and mesh contexts so parameters
    # land on the right device/mesh).
    with device, mesh:
        model = DeepseekForCausalLM(model_args)

    # Load weights
    load_weights_from_hf(model, model_id, device)
    model.train()

    # Apply data parallelism
    fsdp_mesh = mesh["fsdp"]
    hsdp_mesh = mesh["ep", "fsdp"]
    # Using `reshard_after_forward=False` to implement Zero-2, i.e. sharding the
    # optimizer (Zero-1) and gradients (Zero-2), but not the model weights.
    # Reason: the MoE is "sparsely activated" compared to the dense model, thus
    # it will be ineconomical re-gather the weights.
    for layer in model.model.layers.values():
        # Apply FSDP to experts
        if hasattr(layer.mlp, "experts"):
            for expert in layer.mlp.experts.values():
                fully_shard(expert, mesh=fsdp_mesh, reshard_after_forward=False)
        # Apply HSDP to other parts such as attention, layernorm, because they
        # are doing DDP on EP dimension
        fully_shard(layer, mesh=hsdp_mesh, reshard_after_forward=False)

    # Apply HSDP on root model (lm_head, embeddings, etc)
    fully_shard(model, mesh=hsdp_mesh, reshard_after_forward=False)

    # Synthetic setting: two microbatches per pipeline stage.
    microbatches = pp_size * 2

    # Use Symmetric Memory for MoE token shuffle.
    # TODO: we are rewriting `moe_on_device` function. `setup_symm_mem` is
    # currently supported for forward only. See `generate.py`.
    # model.setup_symm_mem(torch.bfloat16, device)

    # Example inputs (seeded per EP rank so data-parallel replicas differ).
    torch.manual_seed(ep_rank)
    bs = 4
    seqlen = 128
    x = torch.randint(model_args.vocab_size, (microbatches * bs, seqlen), device=device)
    # Random soft targets; cross_entropy accepts class probabilities.
    label = torch.rand(microbatches * bs, seqlen, model_args.vocab_size, device=device)

    # Create loss function
    loss_fn = torch.nn.functional.cross_entropy

    # Run forward and backward
    steps = 2
    for _ in range(steps):
        if pp_size > 1:
            # Create pipeline stage
            stage = PipelineStage(
                model,
                pp_rank,
                pp_size,
                device,
                group=pp_mesh.get_group(),
            )

            # Create pipeline schedule (1F1B interleaves fwd/bwd microbatches;
            # backward is driven by the schedule itself).
            losses = []
            pp_schedule = Schedule1F1B(stage, microbatches, loss_fn=loss_fn)

            if pp_rank == 0:
                # First stage feeds the inputs.
                y = pp_schedule.step(x)
            elif pp_rank == pp_size - 1:
                # Last stage computes the loss over microbatches.
                y = pp_schedule.step(target=label, losses=losses)
                loss = torch.mean(torch.stack(losses))
            else:
                # Middle stages just relay activations/gradients.
                pp_schedule.step()
        else:
            # No pipeline: plain forward + backward.
            y = model(x)
            loss = loss_fn(y, label)
            loss.backward()

        if pp_rank == pp_size - 1:
            print(f"logits: {y.shape}")
            print(f"{loss=}")

        if pp_rank == 0:
            # Sanity check: gradients reached the first stage's parameters.
            param = model.get_parameter("model.layers.0.self_attn.q_proj.weight")
            print(f"{torch.linalg.norm(param.grad)=}")

        model.zero_grad()

    print("Backward done")
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
if __name__ == "__main__":
    # 3-D parallelism mesh: 2-way pipeline x 2-way expert x 2-way FSDP,
    # i.e. this script expects exactly 8 ranks (see the torchrun comment at
    # the top of the file).
    mesh = dist.init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("pp", "ep", "fsdp"))

    run_full_model(mesh)

    dist.destroy_process_group()
|
torchtitan/experiments/flux/dataset/flux_dataset.py
ADDED
|
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import math
|
| 8 |
+
import random
|
| 9 |
+
from dataclasses import dataclass
|
| 10 |
+
from typing import Any, Callable, Optional
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
|
| 16 |
+
from datasets import Dataset, load_dataset
|
| 17 |
+
from datasets.distributed import split_dataset_by_node
|
| 18 |
+
from PIL import Image
|
| 19 |
+
|
| 20 |
+
from torch.distributed.checkpoint.stateful import Stateful
|
| 21 |
+
|
| 22 |
+
from torch.utils.data import IterableDataset
|
| 23 |
+
from torchtitan.components.dataloader import ParallelAwareDataloader
|
| 24 |
+
|
| 25 |
+
from torchtitan.config_manager import JobConfig
|
| 26 |
+
from torchtitan.experiments.flux.dataset.tokenizer import FluxTokenizer
|
| 27 |
+
from torchtitan.tools.logging import logger
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _process_cc12m_image(
    img: Image.Image,
    output_size: int = 256,
) -> Optional[torch.Tensor]:
    """Resize and random-crop a CC12M image to a square `output_size` tensor.

    The shorter side is resized to `output_size` (preserving aspect ratio)
    and a random crop along the longer side produces a square image, which
    is returned as a CHW float tensor scaled to [0, 1].

    Args:
        img: the source PIL image.
        output_size: side length of the square output.

    Returns:
        A (C, output_size, output_size) float tensor, or None if the sample
        should be skipped (image smaller than `output_size` in either
        dimension, or grayscale).
    """
    width, height = img.size
    # Skip low resolution images
    if width < output_size or height < output_size:
        return None

    if width >= height:
        # Resize so height == output_size, then random-crop horizontally.
        new_width, new_height = math.ceil(output_size / height * width), output_size
        img = img.resize((new_width, new_height))
        left = random.randint(0, new_width - output_size)
        resized_img = img.crop((left, 0, left + output_size, output_size))
    else:
        # Resize so width == output_size, then random-crop vertically.
        new_width, new_height = (
            output_size,
            math.ceil(output_size / width * height),
        )
        img = img.resize((new_width, new_height))
        # BUG FIX: the vertical crop offset must range over the resized
        # *height*. The previous bound `new_width - output_size` is always 0
        # in this branch (new_width == output_size), so tall images were
        # always cropped from the very top instead of a random position.
        lower = random.randint(0, new_height - output_size)
        resized_img = img.crop((0, lower, output_size, lower + output_size))

    assert resized_img.size[0] == resized_img.size[1] == output_size

    # Skip grayscale images
    if resized_img.mode == "L":
        return None

    # HWC uint8 -> CHW float in [0, 1].
    # NOTE(review): assumes a 3-channel (RGB) image; an RGBA input would
    # yield a 4-channel tensor — confirm upstream data is RGB.
    np_img = np.array(resized_img).transpose((2, 0, 1))
    tensor_img = torch.tensor(np_img).float() / 255.0

    return tensor_img
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def _flux_data_processor(
    sample: dict[str, Any],
    t5_tokenizer: FluxTokenizer,
    clip_tokenizer: FluxTokenizer,
    output_size: int = 256,
) -> dict[str, Any]:
    """
    Preprocess a raw CC12M sample (image + caption) into Flux model inputs.

    Args:
        sample: Raw sample with the image under "jpg" and caption under "txt".
        t5_tokenizer: Tokenizer producing T5 token ids for the caption.
        clip_tokenizer: Tokenizer producing CLIP token ids for the caption.
        output_size: Side length of the square output image.

    Returns:
        Dict with "image" (tensor or None when skipped), "clip_tokens" and
        "t5_tokens" (both List[int]).
    """
    processed_image = _process_cc12m_image(sample["jpg"], output_size=output_size)
    caption = sample["txt"]
    t5_token_ids = t5_tokenizer.encode(caption)
    clip_token_ids = clip_tokenizer.encode(caption)

    return {
        "image": processed_image,
        "clip_tokens": clip_token_ids,  # type: List[int]
        "t5_tokens": t5_token_ids,  # type: List[int]
    }
@dataclass
class TextToImageDatasetConfig:
    """Configuration bundle describing one supported text-to-image dataset."""

    # Default dataset location (HF hub path; may be overridden by the user).
    path: str
    # Callable(path) -> dataset object; how to load/stream the raw dataset.
    loader: Callable
    # Per-sample preprocessor matching _flux_data_processor's signature.
    data_processor: Callable
# Registry of supported text-to-image datasets, keyed by lowercase name.
# Each entry bundles where the data lives, how to load it (streaming split),
# and how each raw sample is turned into Flux model inputs.
DATASETS = {
    "cc12m": TextToImageDatasetConfig(
        path="pixparse/cc12m-wds",
        loader=lambda path: load_dataset(path, split="train", streaming=True),
        data_processor=_flux_data_processor,
    ),
}
def _validate_dataset(
    dataset_name: str, dataset_path: Optional[str] = None
) -> tuple[str, Callable, Callable]:
    """Resolve a dataset name to (path, loader, data_processor).

    Raises:
        ValueError: If ``dataset_name`` is not a key in ``DATASETS``.
    """
    try:
        config = DATASETS[dataset_name]
    except KeyError:
        raise ValueError(
            f"Dataset {dataset_name} is not supported. "
            f"Supported datasets are: {list(DATASETS.keys())}"
        ) from None

    # User-supplied path wins over the registry default.
    path = dataset_path or config.path
    logger.info(f"Preparing {dataset_name} dataset from {path}")
    return path, config.loader, config.data_processor
class FluxDataset(IterableDataset, Stateful):
    """Dataset for FLUX text-to-image model.

    Streams (tokenized-text, image) pairs and supports checkpointing via
    ``state_dict`` / ``load_state_dict``.

    Args:
        dataset_name (str): Name of the dataset (must be a key in DATASETS).
        dataset_path (str): Optional override path to the dataset.
        t5_tokenizer (FluxTokenizer): Tokenizer producing T5 token ids.
        clip_tokenizer (FluxTokenizer): Tokenizer producing CLIP token ids.
        job_config (JobConfig): Optional job configuration.
        dp_rank (int): Data parallel rank.
        dp_world_size (int): Data parallel world size.
        infinite (bool): Whether to loop over the dataset infinitely.
    """

    def __init__(
        self,
        dataset_name: str,
        dataset_path: Optional[str],
        t5_tokenizer: FluxTokenizer,
        clip_tokenizer: FluxTokenizer,
        job_config: Optional[JobConfig] = None,
        dp_rank: int = 0,
        dp_world_size: int = 1,
        infinite: bool = False,
    ) -> None:

        # Force lowercase for consistent comparison
        dataset_name = dataset_name.lower()

        path, dataset_loader, data_processor = _validate_dataset(
            dataset_name, dataset_path
        )
        ds = dataset_loader(path)

        self.dataset_name = dataset_name
        # Shard the stream across data-parallel ranks.
        self._data = split_dataset_by_node(ds, dp_rank, dp_world_size)

        self._t5_tokenizer = t5_tokenizer
        self._clip_tokenizer = clip_tokenizer
        self._data_processor = data_processor
        self.job_config = job_config

        self.infinite = infinite

        # Variables for checkpointing
        self._sample_idx = 0
        self._all_samples: list[dict[str, Any]] = []

    def _get_data_iter(self):
        # A sized (map-style) dataset that has been fully consumed yields an
        # empty iterator instead of restarting from the beginning.
        if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
            return iter([])

        it = iter(self._data)
        # Skip samples already consumed before a checkpoint restore.
        for _ in range(self._sample_idx):
            next(it)
        return it

    def __iter__(self):
        while True:
            for sample in self._get_data_iter():
                # Use the dataset-specific preprocessor
                sample_dict = self._data_processor(
                    sample, self._t5_tokenizer, self._clip_tokenizer, output_size=256
                )

                # skip low quality image or image with color channel = 1
                if sample_dict["image"] is None:
                    logger.warning(
                        f"Low quality image {sample['__key__']} is skipped in Flux Dataloader"
                    )
                    continue

                # BUG FIX: `extend(sample_dict)` added the dict's *keys*
                # (strings) to the list; `append` stores the sample itself,
                # matching the declared list[dict[str, Any]] type.
                self._all_samples.append(sample_dict)
                self._sample_idx += 1

                labels = sample_dict.pop("image")
                yield sample_dict, labels

            if not self.infinite:
                logger.warning(f"Dataset {self.dataset_name} has run out of data")
                break
            else:
                # Reset offset for the next iteration
                self._sample_idx = 0
                logger.warning(f"Dataset {self.dataset_name} is being re-looped")

    def load_state_dict(self, state_dict):
        self._sample_idx = state_dict["sample_idx"]
        self._all_samples = state_dict["all_samples"]

    def state_dict(self):
        return {
            "all_samples": self._all_samples,
            "sample_idx": self._sample_idx,
        }
def build_flux_dataloader(
    dp_world_size: int,
    dp_rank: int,
    job_config: JobConfig,
    # This parameter is not used, keep it for compatibility
    tokenizer: FluxTokenizer | None,
    infinite: bool = True,
) -> ParallelAwareDataloader:
    """Build a data loader for HuggingFace datasets."""
    training_cfg = job_config.training
    encoder_cfg = job_config.encoder

    t5_tokenizer = FluxTokenizer(
        encoder_cfg.t5_encoder, max_length=encoder_cfg.max_t5_encoding_len
    )
    # CLIP always uses a fixed context length of 77 tokens.
    clip_tokenizer = FluxTokenizer(encoder_cfg.clip_encoder, max_length=77)

    dataset = FluxDataset(
        dataset_name=training_cfg.dataset,
        dataset_path=training_cfg.dataset_path,
        t5_tokenizer=t5_tokenizer,
        clip_tokenizer=clip_tokenizer,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        infinite=infinite,
    )

    return ParallelAwareDataloader(
        dataset=dataset,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        batch_size=training_cfg.batch_size,
    )
torchtitan/experiments/flux/model/layers.py
ADDED
|
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# imported from black-forest-labs/FLUX
|
| 8 |
+
import math
|
| 9 |
+
from dataclasses import dataclass
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from einops import rearrange
|
| 13 |
+
from torch import nn, Tensor
|
| 14 |
+
|
| 15 |
+
from torchtitan.experiments.flux.model.math import attention, rope
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class EmbedND(nn.Module):
    """Multi-axis rotary positional embedding.

    Applies ``rope`` independently per id axis and concatenates the results
    along the position-frequency dimension, then adds a head axis.
    """

    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: Tensor) -> Tensor:
        per_axis = [
            rope(ids[..., axis], self.axes_dim[axis], self.theta)
            for axis in range(ids.shape[-1])
        ]
        return torch.cat(per_axis, dim=-3).unsqueeze(1)
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
    """
    Create sinusoidal timestep embeddings.

    :param t: a 1-D Tensor of N indices, one per batch element.
        These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param time_factor: multiplier applied to ``t`` before embedding.
    :return: an (N, dim) Tensor of positional embeddings.
    """
    scaled = time_factor * t
    half = dim // 2
    # Geometric frequency ladder from 1 down to 1/max_period.
    exponents = torch.arange(start=0, end=half, dtype=torch.float32) / half
    freqs = torch.exp(-math.log(max_period) * exponents).to(scaled.device)

    args = scaled[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Pad odd dims with a zero column so the width is exactly `dim`.
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    if torch.is_floating_point(t):
        embedding = embedding.to(t)
    return embedding
class MLPEmbedder(nn.Module):
    """Two-layer MLP (Linear -> SiLU -> Linear) mapping in_dim to hidden_dim."""

    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.silu(self.in_layer(x))
        return self.out_layer(hidden)
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalization with a learnable per-channel scale."""

    def __init__(self, dim: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor):
        input_dtype = x.dtype
        x_fp32 = x.float()
        # Normalize by RMS over the last dimension (1e-6 eps for stability),
        # computing in fp32 then casting back to the input dtype.
        inv_rms = torch.rsqrt(x_fp32.pow(2).mean(dim=-1, keepdim=True) + 1e-6)
        return (x_fp32 * inv_rms).to(dtype=input_dtype) * self.scale
class QKNorm(torch.nn.Module):
    """Applies independent RMS normalization to query and key tensors."""

    def __init__(self, dim: int):
        super().__init__()
        self.query_norm = RMSNorm(dim)  # TODO(jianiw): switch to pytorch nn.RMSNorm
        self.key_norm = RMSNorm(dim)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
        # Normalize q and k, then match v's dtype/device for attention.
        normed_q = self.query_norm(q)
        normed_k = self.key_norm(k)
        return normed_q.to(v), normed_k.to(v)
class SelfAttention(nn.Module):
    """Multi-head self-attention with QK-RMSNorm and rotary embeddings."""

    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.norm = QKNorm(head_dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: Tensor, pe: Tensor) -> Tensor:
        # Fused QKV projection, split into per-head q/k/v.
        q, k, v = rearrange(
            self.qkv(x), "B L (K H D) -> K B H L D", K=3, H=self.num_heads
        )
        q, k = self.norm(q, k, v)
        attn_out = attention(q, k, v, pe=pe)
        return self.proj(attn_out)
@dataclass
class ModulationOut:
    """One set of adaptive-layernorm modulation tensors."""

    # Additive shift applied after normalization.
    shift: Tensor
    # Multiplicative scale (applied as `1 + scale`).
    scale: Tensor
    # Gate multiplying the branch output before the residual add.
    gate: Tensor
class Modulation(nn.Module):
    """Projects a conditioning vector into one or two ModulationOut triples."""

    def __init__(self, dim: int, double: bool):
        super().__init__()
        self.is_double = double
        # 3 tensors (shift / scale / gate) per modulation set.
        self.multiplier = 6 if double else 3
        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)

    def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
        # SiLU-activate, project, add a sequence axis, split into chunks.
        chunks = self.lin(nn.functional.silu(vec))[:, None, :].chunk(
            self.multiplier, dim=-1
        )
        first = ModulationOut(*chunks[:3])
        second = ModulationOut(*chunks[3:]) if self.is_double else None
        return first, second
class DoubleStreamBlock(nn.Module):
    """MM-DiT double-stream block.

    Image and text tokens keep separate weights (modulation, QKV, MLP) but
    interact through a single joint attention over the concatenated
    (text-first) sequence.
    """

    def __init__(
        self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False
    ):
        super().__init__()

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        # Image-stream parameters.
        self.img_mod = Modulation(hidden_size, double=True)
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_attn = SelfAttention(
            dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
        )

        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

        # Text-stream parameters (mirror the image stream).
        self.txt_mod = Modulation(hidden_size, double=True)
        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_attn = SelfAttention(
            dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
        )

        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

    def forward(
        self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor
    ) -> tuple[Tensor, Tensor]:
        """Run one double-stream step; returns the updated (img, txt) pair."""
        # Two modulation sets per stream: one for attention, one for the MLP.
        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)

        # prepare image for attention
        img_modulated = self.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = self.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = rearrange(
            img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
        )
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(
            txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
        )
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention over the concatenated (text-first) sequence
        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)

        attn = attention(q, k, v, pe=pe)
        # Split the joint attention output back into text and image parts.
        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

        # calculate the img blocks (gated residuals: attention, then MLP)
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        img = img + img_mod2.gate * self.img_mlp(
            (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
        )

        # calculate the txt blocks (gated residuals: attention, then MLP)
        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt = txt + txt_mod2.gate * self.txt_mlp(
            (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
        )
        return img, txt
class SingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float | None = None,
    ):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        # NOTE(review): self.scale is computed but never passed to the
        # attention call below, which relies on SDPA's own default scaling —
        # confirm whether qk_scale is meant to take effect.
        self.scale = qk_scale or head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in (fused into one projection)
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
        # proj and mlp_out (fused output projection)
        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)

        self.norm = QKNorm(head_dim)

        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False)

    def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
        """Modulated attention + MLP in parallel branches, gated residual add."""
        mod, _ = self.modulation(vec)
        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
        # One fused projection yields both the QKV input and the MLP input.
        qkv, mlp = torch.split(
            self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1
        )

        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)

        # compute attention
        attn = attention(q, k, v, pe=pe)
        # compute activation in mlp stream, cat again and run second linear layer
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        return x + mod.gate * output
class LastLayer(nn.Module):
    """Final adaLN-modulated projection from hidden states to output patches."""

    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(
            hidden_size, patch_size * patch_size * out_channels, bias=True
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
        # Derive shift/scale from the conditioning vector, modulate, project.
        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
        modulated = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
        return self.linear(modulated)
torchtitan/experiments/flux/model/math.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from einops import rearrange
|
| 9 |
+
from torch import Tensor
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    """Rotary-embedded scaled dot-product attention.

    Takes per-head tensors of shape (B, H, L, D) and returns the attention
    output with heads merged: (B, L, H*D).
    """
    q, k = apply_rope(q, k, pe)

    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    # Merge heads: (B, H, L, D) -> (B, L, H*D).
    batch, heads, length, head_dim = out.shape
    return out.transpose(1, 2).reshape(batch, length, heads * head_dim)
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    """Build 2x2 rotation matrices for rotary position embedding.

    Returns a float tensor of shape (*pos.shape, dim // 2, 2, 2).
    """
    assert dim % 2 == 0
    exponents = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim
    omega = 1.0 / (theta**exponents)
    angles = torch.einsum("...n,d->...nd", pos, omega)
    cos, sin = torch.cos(angles), torch.sin(angles)
    # Rows of each 2x2 rotation matrix: [cos, -sin] and [sin, cos].
    flat = torch.stack([cos, -sin, sin, cos], dim=-1)
    return flat.unflatten(-1, (2, 2)).float()
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    """Rotate query/key tensors by precomputed 2x2 rotation matrices."""

    def _rotate(x: Tensor) -> Tensor:
        # View the last dim as pairs: (..., D) -> (..., D//2, 1, 2).
        pairs = x.float().reshape(*x.shape[:-1], -1, 1, 2)
        rotated = freqs_cis[..., 0] * pairs[..., 0] + freqs_cis[..., 1] * pairs[..., 1]
        return rotated.reshape(*x.shape).type_as(x)

    return _rotate(xq), _rotate(xk)
torchtitan/experiments/flux/model/model.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from dataclasses import dataclass, field
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from torch import nn, Tensor
|
| 12 |
+
from torchtitan.components.tokenizer import Tokenizer
|
| 13 |
+
from torchtitan.config_manager import JobConfig
|
| 14 |
+
|
| 15 |
+
from torchtitan.experiments.flux.model.autoencoder import AutoEncoderParams
|
| 16 |
+
from torchtitan.experiments.flux.model.layers import (
|
| 17 |
+
DoubleStreamBlock,
|
| 18 |
+
EmbedND,
|
| 19 |
+
LastLayer,
|
| 20 |
+
MLPEmbedder,
|
| 21 |
+
SingleStreamBlock,
|
| 22 |
+
timestep_embedding,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
from torchtitan.protocols.train_spec import BaseModelArgs, ModelProtocol
|
| 26 |
+
from torchtitan.tools.logging import logger
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclass
class FluxModelArgs(BaseModelArgs):
    """Configuration for the Flux flow-matching transformer."""

    # Channels of the (packed) latent image input / output.
    in_channels: int = 64
    out_channels: int = 64
    # Dim of the pooled (CLIP) text conditioning vector.
    vec_in_dim: int = 768
    # Dim of the per-token text encoder states fed to txt_in.
    context_in_dim: int = 512
    hidden_size: int = 3072
    mlp_ratio: float = 4.0
    num_heads: int = 24
    # Number of double-stream blocks.
    depth: int = 19
    # Number of single-stream blocks.
    depth_single_blocks: int = 38
    # Per-axis rotary dims; must sum to hidden_size // num_heads.
    axes_dim: tuple = (16, 56, 56)
    theta: int = 10_000
    qkv_bias: bool = True
    guidance_embed: bool = True
    autoencoder_params: AutoEncoderParams = field(default_factory=AutoEncoderParams)

    def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
        # context_in_dim is the same as the T5 embedding dimension
        # NOTE(review): max_t5_encoding_len looks like a sequence length, not
        # an embedding dimension — confirm this assignment is intended.
        self.context_in_dim = job_config.encoder.max_t5_encoding_len

    def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
        # TODO(jianiw): Add the number of flops for the autoencoder
        nparams = sum(p.numel() for p in model.parameters())
        # FLOPs are not implemented yet; 1 is returned as a placeholder.
        logger.warning("FLUX model haven't implement get_nparams_and_flops() function")
        return nparams, 1
class FluxModel(nn.Module, ModelProtocol):
    """
    Transformer model for flow matching on sequences.

    Args:
        model_args: FluxModelArgs.

    Attributes:
        model_args (FluxModelArgs): Model configuration arguments.
    """

    def __init__(self, model_args: FluxModelArgs):
        super().__init__()

        self.model_args = model_args
        self.in_channels = model_args.in_channels
        self.out_channels = model_args.out_channels
        if model_args.hidden_size % model_args.num_heads != 0:
            raise ValueError(
                f"Hidden size {model_args.hidden_size} must be divisible by num_heads {model_args.num_heads}"
            )
        # Positional embedding dim equals the per-head dim; the per-axis
        # rotary dims must sum to it exactly.
        pe_dim = model_args.hidden_size // model_args.num_heads
        if sum(model_args.axes_dim) != pe_dim:
            raise ValueError(
                f"Got {model_args.axes_dim} but expected positional dim {pe_dim}"
            )
        self.hidden_size = model_args.hidden_size
        self.num_heads = model_args.num_heads
        self.pe_embedder = EmbedND(
            dim=pe_dim, theta=model_args.theta, axes_dim=model_args.axes_dim
        )
        # Input projections: image tokens, timestep, pooled text, guidance.
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(model_args.vec_in_dim, self.hidden_size)
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
            if model_args.guidance_embed
            else nn.Identity()
        )
        self.txt_in = nn.Linear(model_args.context_in_dim, self.hidden_size)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=model_args.mlp_ratio,
                    qkv_bias=model_args.qkv_bias,
                )
                for _ in range(model_args.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(
                    self.hidden_size, self.num_heads, mlp_ratio=model_args.mlp_ratio
                )
                for _ in range(model_args.depth_single_blocks)
            ]
        )

        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

    def init_weights(self, buffer_device=None):
        # TODO(jianiw): replace placeholder with real weight init
        for param in self.parameters():
            param.data.uniform_(0, 0.1)

    def forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor | None = None,
    ) -> Tensor:
        """Run the transformer on image tokens conditioned on text.

        Args:
            img: Image tokens, shape (N, L_img, in_channels).
            img_ids: Positional ids for image tokens.
            txt: Per-token text encoder states, shape (N, L_txt, context_in_dim).
            txt_ids: Positional ids for text tokens.
            timesteps: 1-D tensor of diffusion timesteps.
            y: Pooled text vector, shape (N, vec_in_dim).
            guidance: Guidance strength; required when guidance_embed is set.

        Returns:
            Tensor of shape (N, L_img, out_channels).
        """
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)
        # Conditioning vector: timestep (+ optional guidance) + pooled text.
        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.model_args.guidance_embed:
            if guidance is None:
                raise ValueError(
                    "Didn't get guidance strength for guidance distilled model."
                )
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        # Joint positional embedding over concatenated (text, image) ids.
        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        for block in self.double_blocks:
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)

        # Single-stream blocks process the concatenated sequence; the text
        # prefix is dropped afterwards.
        img = torch.cat((txt, img), 1)
        for block in self.single_blocks:
            img = block(img, vec=vec, pe=pe)
        img = img[:, txt.shape[1] :, ...]

        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
        return img

    @classmethod
    def from_model_args(cls, model_args: FluxModelArgs) -> "FluxModel":
        """
        Initialize a Flux model from a FluxModelArgs object.

        Args:
            model_args (FluxModelArgs): Model configuration arguments.

        Returns:
            FluxModel: FluxModel model.

        """
        return cls(model_args)
torchtitan/experiments/flux/parallelize_flux.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# This file applies the PT-D parallelisms (except pipeline parallelism) and various
|
| 8 |
+
# training techniques (e.g. activation checkpointing and compile) to the Llama model.
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
|
| 13 |
+
from torch.distributed.device_mesh import DeviceMesh
|
| 14 |
+
|
| 15 |
+
from torchtitan.config_manager import JobConfig
|
| 16 |
+
from torchtitan.distributed import ParallelDims
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def parallelize_flux(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
    """Apply PT-D parallelisms / training optimizations to the Flux model.

    Currently a no-op placeholder: the model is returned unchanged.
    """
    # TODO: Add model parallel strategy here
    return model
torchtitan/experiments/flux/requirements.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
transformers
|
| 2 |
+
einops
|
torchtitan/experiments/flux/scripts/download_autoencoder.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from typing import Optional
|
| 8 |
+
|
| 9 |
+
from requests.exceptions import HTTPError
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def hf_download(
    repo_id: str, file_path: str, local_dir: str, hf_token: Optional[str] = None
) -> None:
    """Download a single file from a HuggingFace Hub repo into ``local_dir``.

    A 401 Unauthorized response is reported with a hint about passing
    ``--hf_token`` instead of raising; every other HTTP error propagates.
    """
    # Imported lazily so merely importing this module does not require
    # huggingface_hub to be installed.
    from huggingface_hub import hf_hub_download

    try:
        hf_hub_download(
            repo_id=repo_id,
            filename=file_path,
            local_dir=local_dir,
            local_dir_use_symlinks=False,
            token=hf_token,
        )
    except HTTPError as err:
        if err.response.status_code != 401:
            raise err
        print(
            "You need to pass a valid `--hf_token=...` to download private checkpoints."
        )
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Download tokenizer from HuggingFace.")
    # All options are string-valued; register them from a (flag, default, help)
    # table to keep the parser definition compact.
    for flag, default, help_text in (
        (
            "--repo_id",
            "black-forest-labs/FLUX.1-dev",
            "Repository ID to download from. default to Flux-dev model",
        ),
        (
            "--ae_path",
            "ae.safetensors",
            "the autoencoder path relative to repo_id",
        ),
        ("--hf_token", None, "HuggingFace API token"),
        (
            "--local_dir",
            "torchtitan/experiments/flux/assets/autoencoder/",
            "local directory to save the autoencoder",
        ),
    ):
        parser.add_argument(flag, type=str, default=default, help=help_text)

    args = parser.parse_args()
    hf_download(args.repo_id, args.ae_path, args.local_dir, args.hf_token)
|
torchtitan/experiments/flux/tests/test_flux_dataloader.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import sys
|
| 8 |
+
|
| 9 |
+
from torchtitan.config_manager import JobConfig
|
| 10 |
+
from torchtitan.experiments.flux.dataset.flux_dataset import build_flux_dataloader
|
| 11 |
+
from torchtitan.tools.profiling import (
|
| 12 |
+
maybe_enable_memory_snapshot,
|
| 13 |
+
maybe_enable_profiling,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class TestFluxDataLoader:
    """Smoke tests for the Flux dataloader.

    NOTE(review): these tests hit the real cc12m streaming dataset and the HF
    tokenizers — they are integration tests, not unit tests; run manually.
    """

    def test_flux_dataloader(self):
        # Fixed test topology: 4-way data parallel, this process acts as rank 0.
        dataset_name = "cc12m"
        batch_size = 32
        world_size = 4
        rank = 0

        num_steps = 10

        # Register the Flux-specific CLI arguments (encoder.*, etc.) before
        # parsing, by pointing JobConfig at the custom argparser module.
        path = "torchtitan.experiments.flux.flux_argparser"
        sys.argv.append(f"--experimental.custom_args_module={path}")
        config = JobConfig()
        config.maybe_add_custom_args()
        config.parse_args(
            [
                # Profiling options
                # "--profiling.enable_profiling",
                # "--profiling.profile_freq",
                # "5",
                # "--profiling.enable_memory_snapshot",
                # "--profiling.save_memory_snapshot_folder",
                # "memory_snapshot_flux",
                "--training.dataset",
                dataset_name,
                "--training.batch_size",
                str(batch_size),
                "--encoder.t5_encoder",
                "google/t5-v1_1-small",
                "--encoder.clip_encoder",
                "openai/clip-vit-large-patch14",
                "--encoder.max_t5_encoding_len",
                "512",
            ]
        )

        # Both profilers are no-ops unless enabled via the (commented-out)
        # profiling flags above.
        with maybe_enable_profiling(
            config, global_step=0
        ) as torch_profiler, maybe_enable_memory_snapshot(
            config, global_step=0
        ) as memory_profiler:
            dl = self._build_dataloader(
                config,
                world_size,
                rank,
            )
            dl = iter(dl)

            for i in range(0, num_steps):
                input_data, labels = next(dl)
                print(f"Step {i} image size: {labels.shape}")
                if torch_profiler:
                    torch_profiler.step()
                if memory_profiler:
                    memory_profiler.step()

                print(len(input_data["clip_tokens"]))
                for k, v in input_data.items():
                    print(f"Step {i} {k} value: {type(v), v.shape}")

                assert len(input_data) == 2  # (clip_encodings, t5_encodings)
                # Images are expected as 3x256x256 RGB tensors.
                assert labels.shape == (batch_size, 3, 256, 256)
                # assert input_data["clip_tokens"].shape[0] == batch_size
                # assert input_data["t5_tokens"].shape == (batch_size, 512, 512)

            if torch_profiler:
                torch_profiler.step()
            if memory_profiler:
                memory_profiler.step(exit_ctx=True)

    def test_preprocess(self):
        # TODO
        pass

    def _build_dataloader(
        self,
        job_config,
        world_size,
        rank,
    ):
        # NOTE(review): tokenizer=None — presumably the loader builds its own
        # tokenizers from job_config.encoder.*; confirm against
        # build_flux_dataloader. infinite=False so the test run terminates.
        return build_flux_dataloader(
            dp_world_size=world_size,
            dp_rank=rank,
            job_config=job_config,
            tokenizer=None,
            infinite=False,
        )
|
torchtitan/experiments/flux/train.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
from typing import Optional
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
from torchtitan.config_manager import JobConfig
|
| 13 |
+
from torchtitan.distributed import utils as dist_utils
|
| 14 |
+
from torchtitan.experiments.flux.model.autoencoder import load_ae
|
| 15 |
+
from torchtitan.experiments.flux.model.hf_embedder import FluxEmbedder
|
| 16 |
+
from torchtitan.experiments.flux.model.model import FluxModel
|
| 17 |
+
from torchtitan.experiments.flux.utils import (
|
| 18 |
+
create_position_encoding_for_latents,
|
| 19 |
+
pack_latents,
|
| 20 |
+
preprocess_flux_data,
|
| 21 |
+
unpack_latents,
|
| 22 |
+
)
|
| 23 |
+
from torchtitan.tools.logging import init_logger, logger
|
| 24 |
+
from torchtitan.train import Trainer
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class FluxTrainer(Trainer):
    """Trainer specialization for the Flux flow-matching text-to-image model.

    Extends the generic torchtitan Trainer with frozen CLIP/T5 text encoders
    and an image autoencoder (used for on-the-fly batch preprocessing) and a
    flow-matching train_step that regresses predicted noise onto
    ``noise - image_latents``.
    """

    def __init__(self, job_config: JobConfig):
        super().__init__(job_config)

        self.preprocess_fn = preprocess_flux_data
        # self.dtype = job_config.encoder.dtype
        # NOTE(review): dtype is hard-coded to bf16 (see commented line above)
        # — presumably meant to come from config eventually; confirm.
        self._dtype = torch.bfloat16
        self._seed = job_config.training.seed
        self._guidance = job_config.training.guidance

        # load components
        # Encoders/autoencoder start on CPU; preprocess_flux_data moves them to
        # the training device per step when offload=True.
        model_config = self.train_spec.config[job_config.model.flavor]
        self.autoencoder = load_ae(
            job_config.encoder.auto_encoder_path,
            model_config.autoencoder_params,
            device="cpu",
            dtype=self._dtype,
        )
        self.clip_encoder = FluxEmbedder(version=job_config.encoder.clip_encoder).to(
            dtype=self._dtype
        )
        self.t5_encoder = FluxEmbedder(version=job_config.encoder.t5_encoder).to(
            dtype=self._dtype
        )

    def _predict_noise(
        self,
        model: FluxModel,
        latents: torch.Tensor,
        clip_encodings: torch.Tensor,
        t5_encodings: torch.Tensor,
        timesteps: torch.Tensor,
        guidance: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Use Flux's flow-matching model to predict the noise in image latents.
        Args:
            model (FluxFlowModel): The Flux flow model.
            latents (Tensor): Image encodings from the Flux autoencoder.
                Shape: [bsz, 16, latent height, latent width]
            clip_encodings (Tensor): CLIP text encodings.
                Shape: [bsz, 768]
            t5_encodings (Tensor): T5 text encodings.
                Shape: [bsz, sequence length, 256 or 512]
            timesteps (Tensor): The amount of noise (0 to 1).
                Shape: [bsz]
            guidance (Optional[Tensor]): The guidance value (1.5 to 4) if guidance-enabled model.
                Shape: [bsz]
                Default: None
            model_ctx (ContextManager): Optional context to wrap the model call (e.g. for activation offloading)
                Default: nullcontext
        Returns:
            Tensor: The noise prediction.
                Shape: [bsz, 16, latent height, latent width]
        """
        bsz, _, latent_height, latent_width = latents.shape

        POSITION_DIM = 3  # constant for Flux flow model
        with torch.no_grad():
            # Create positional encodings
            latent_pos_enc = create_position_encoding_for_latents(
                bsz, latent_height, latent_width, POSITION_DIM
            )
            text_pos_enc = torch.zeros(bsz, t5_encodings.shape[1], POSITION_DIM)

            # Convert latent into a sequence of patches
            latents = pack_latents(latents)

            # Predict noise
            # .to(latents) moves each auxiliary tensor onto the latents'
            # device/dtype in one call.
            latent_noise_pred = model(
                img=latents,
                img_ids=latent_pos_enc.to(latents),
                txt=t5_encodings.to(latents),
                txt_ids=text_pos_enc.to(latents),
                y=clip_encodings.to(latents),
                timesteps=timesteps.to(latents),
                guidance=guidance.to(latents) if guidance is not None else None,
            )

            # Convert sequence of patches to latent shape
            latent_noise_pred = unpack_latents(
                latent_noise_pred, latent_height, latent_width
            )

        return latent_noise_pred

    def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor):
        """Run one flow-matching training step on a batch of (text, image) data."""
        # generate t5 and clip
        # Raw images arrive as `labels`; stash them in the batch so the
        # preprocess fn can run them through the autoencoder.
        input_dict["image"] = labels
        input_dict = self.preprocess_fn(
            device=self.device,
            dtype=self._dtype,
            autoencoder=self.autoencoder,
            clip_encoder=self.clip_encoder,
            t5_encoder=self.t5_encoder,
            batch=input_dict,
            offload=True,
        )
        # From here on, "labels" means the autoencoder latents, not pixels.
        labels = input_dict["img_encodings"]

        self.optimizers.zero_grad()

        # Keep these variables local to shorten the code as these are
        # the major variables that are used in the training loop.
        model_parts = self.model_parts
        world_mesh = self.world_mesh
        parallel_dims = self.parallel_dims

        # image in latent space transformed by self.auto_encoder
        clip_encodings = input_dict["clip_encodings"]
        t5_encodings = input_dict["t5_encodings"]

        bsz = labels.shape[0]

        with torch.no_grad():
            # Linear interpolation between clean latents and noise at a random
            # per-sample timestep (rectified-flow style corruption).
            noise = torch.randn_like(labels)
            timesteps = torch.rand((bsz,)).to(labels)
            sigmas = timesteps.view(-1, 1, 1, 1)
            noisy_latents = (1 - sigmas) * labels + sigmas * noise
            guidance = torch.full((bsz,), self._guidance).to(labels)

        # Flow-matching regression target: velocity from image toward noise.
        target = noise - labels

        assert len(model_parts) == 1
        # TODO(jianiw): model_parts will be wrapped by FSDP, which will cacluate
        model_parts[0] = model_parts[0].to(dtype=self._dtype)

        pred = self._predict_noise(
            model_parts[0],
            noisy_latents,
            clip_encodings,
            t5_encodings,
            timesteps,
            guidance,
        )
        loss = self.loss_fn(pred, target)
        # pred.shape=(bs, seq_len, vocab_size)
        # need to free to before bwd to avoid peaking memory
        del (pred, noise, target)
        loss.backward()

        dist_utils.clip_grad_norm_(
            [p for m in model_parts for p in m.parameters()],
            self.job_config.training.max_norm,
            foreach=True,
            pp_mesh=self.world_mesh["pp"] if parallel_dims.pp_enabled else None,
        )
        self.checkpointer.maybe_wait_for_staging()
        self.optimizers.step()
        self.lr_schedulers.step()

        # log metrics
        if not self.metrics_processor.should_log(self.step):
            return

        # Reduce loss across data/context-parallel ranks when any are active.
        if (
            parallel_dims.dp_replicate_enabled
            or parallel_dims.dp_shard_enabled
            or parallel_dims.cp_enabled
        ):
            loss = loss.detach()
            global_avg_loss, global_max_loss = (
                dist_utils.dist_mean(loss, world_mesh["dp_cp"]),
                dist_utils.dist_max(loss, world_mesh["dp_cp"]),
            )
        else:
            global_avg_loss = global_max_loss = loss.item()

        self.metrics_processor.log(self.step, global_avg_loss, global_max_loss)
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
if __name__ == "__main__":
    init_logger()
    config = JobConfig()
    config.maybe_add_custom_args()
    config.parse_args()
    trainer: Optional[FluxTrainer] = None

    try:
        trainer = FluxTrainer(config)
        if config.checkpoint.create_seed_checkpoint:
            # Fix: the original asserted `int(os.environ["WORLD_SIZE"])`, which
            # is truthy for ANY nonzero world size and therefore never enforced
            # the single-device requirement. Compare against 1 explicitly.
            assert (
                int(os.environ["WORLD_SIZE"]) == 1
            ), "Must create seed checkpoint using a single device, to disable sharding."
            assert (
                config.checkpoint.enable_checkpoint
            ), "Must enable checkpointing when creating a seed checkpoint."
            trainer.checkpointer.save(curr_step=0, force=True)
            logger.info("Created seed checkpoint")
        else:
            trainer.train()
    finally:
        # Always release trainer resources and tear down the process group,
        # even when training raised.
        if trainer:
            trainer.close()

        if torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()
            logger.info("Process group destroyed.")
|
torchtitan/experiments/flux/train_configs/debug_model.toml
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
[job]
|
| 3 |
+
dump_folder = "./outputs"
|
| 4 |
+
description = "Flux debug model"
|
| 5 |
+
print_args = false
|
| 6 |
+
use_for_integration_test = true
|
| 7 |
+
|
| 8 |
+
[profiling]
|
| 9 |
+
enable_profiling = false
|
| 10 |
+
save_traces_folder = "profile_trace"
|
| 11 |
+
profile_freq = 10
|
| 12 |
+
enable_memory_snapshot = false
|
| 13 |
+
save_memory_snapshot_folder = "memory_snapshot"
|
| 14 |
+
|
| 15 |
+
[metrics]
|
| 16 |
+
log_freq = 1
|
| 17 |
+
disable_color_printing = false
|
| 18 |
+
enable_tensorboard = false
|
| 19 |
+
save_tb_folder = "tb"
|
| 20 |
+
enable_wandb = false
|
| 21 |
+
|
| 22 |
+
[model]
|
| 23 |
+
name = "flux"
|
| 24 |
+
flavor = "flux-debug"
|
| 25 |
+
norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm
|
| 26 |
+
# test tokenizer.model, for debug purpose only
|
| 27 |
+
# tokenizer_path = "./tests/assets/test_tiktoken.model"
|
| 28 |
+
# converters = "float8"
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
[optimizer]
|
| 32 |
+
name = "AdamW"
|
| 33 |
+
lr = 8e-4
|
| 34 |
+
eps = 1e-8
|
| 35 |
+
|
| 36 |
+
[lr_scheduler]
|
| 37 |
+
warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps
|
| 38 |
+
decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps
|
| 39 |
+
decay_type = "linear"
|
| 40 |
+
lr_min = 0.0
|
| 41 |
+
|
| 42 |
+
[training]
|
| 43 |
+
batch_size = 32
|
| 44 |
+
seq_len = 512
|
| 45 |
+
max_norm = 1.0 # grad norm clipping
|
| 46 |
+
steps = 10
|
| 47 |
+
compile = false
|
| 48 |
+
dataset = "cc12m"
|
| 49 |
+
guidance = 3.5
|
| 50 |
+
seed = 0
|
| 51 |
+
|
| 52 |
+
[encoder]
|
| 53 |
+
t5_encoder="google/t5-v1_1-small"
|
| 54 |
+
clip_encoder="openai/clip-vit-large-patch14"
|
| 55 |
+
max_t5_encoding_len=512
|
| 56 |
+
auto_encoder_path="torchtitan/experiments/flux/assets/autoencoder/ae.safetensors" # Autoencoder to use for image
|
| 57 |
+
|
| 58 |
+
[parallelism]
|
| 59 |
+
data_parallel_replicate_degree = 1
|
| 60 |
+
data_parallel_shard_degree = 1
|
| 61 |
+
fsdp_reshard_after_forward = "default" # default / never / always
|
| 62 |
+
tensor_parallel_degree = 1
|
| 63 |
+
enable_async_tensor_parallel = false
|
| 64 |
+
pipeline_parallel_degree = 1
|
| 65 |
+
context_parallel_degree = 1
|
| 66 |
+
|
| 67 |
+
[experimental]
|
| 68 |
+
custom_args_module = "torchtitan.experiments.flux.flux_argparser"
|
torchtitan/experiments/flux/utils.py
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from typing import Optional
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from torch import Tensor
|
| 12 |
+
|
| 13 |
+
from torchtitan.experiments.flux.model.autoencoder import AutoEncoder
|
| 14 |
+
from torchtitan.experiments.flux.model.hf_embedder import FluxEmbedder
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def preprocess_flux_data(
    # arguments from the recipe
    device: torch.device,
    dtype: torch.dtype,
    *,
    # arguments from the config
    autoencoder: Optional[AutoEncoder],
    clip_encoder: FluxEmbedder,
    t5_encoder: FluxEmbedder,
    batch: dict[str, Tensor],
    offload: bool = False,
) -> dict[str, Tensor]:
    """Encode a raw batch into CLIP/T5 text encodings and image latents.

    Adds "clip_encodings", "t5_encodings" and (when an autoencoder is given)
    "img_encodings" keys to ``batch`` in place and returns it. With
    ``offload=True`` the encoder modules are moved to ``device`` for the
    duration of the call and back to CPU afterwards.

    Args:
        device (torch.device): device to do preprocessing on
        dtype (torch.dtype): data type to do preprocessing in
        autoencoder (Optional[AutoEncoder]): image autoencoder, or None to
            skip image encoding
        clip_encoder (FluxEmbedder): CLIP text encoder
        t5_encoder (FluxEmbedder): T5 text encoder
        batch (dict[str, Tensor]): batch of data to preprocess

    Returns:
        dict[str, Tensor]: batch of preprocessed data
    """
    encoders = [clip_encoder, t5_encoder]
    if autoencoder is not None:
        encoders.append(autoencoder)

    if offload:
        for module in encoders:
            module.to(device)

    # Token ids must be integer tensors for the embedding lookups.
    clip_tokens = batch["clip_tokens"].squeeze().to(device=device, dtype=torch.int)
    t5_tokens = batch["t5_tokens"].squeeze().to(device=device, dtype=torch.int)

    clip_text_encodings = clip_encoder(clip_tokens)
    t5_text_encodings = t5_encoder(t5_tokens)

    if autoencoder is not None:
        images = batch["image"].to(device=device, dtype=dtype)
        batch["img_encodings"] = autoencoder.encode(images).to(
            device=device, dtype=dtype
        )

    batch["clip_encodings"] = clip_text_encodings.to(dtype)
    batch["t5_encodings"] = t5_text_encodings.to(dtype)

    # offload encoders to cpu after preprocessing
    if offload:
        for module in encoders:
            module.to("cpu")

    return batch
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def generate_noise_latent(
|
| 76 |
+
bsz: int,
|
| 77 |
+
height: int,
|
| 78 |
+
width: int,
|
| 79 |
+
device: str | torch.device,
|
| 80 |
+
dtype: torch.dtype,
|
| 81 |
+
seed: int,
|
| 82 |
+
) -> Tensor:
|
| 83 |
+
"""Generate noise latents for the Flux flow model.
|
| 84 |
+
|
| 85 |
+
Args:
|
| 86 |
+
bsz (int): batch_size.
|
| 87 |
+
height (int): The height of the image.
|
| 88 |
+
width (int): The width of the image.
|
| 89 |
+
device (str | torch.device): The device to use.
|
| 90 |
+
dtype (torch.dtype): The dtype to use.
|
| 91 |
+
seed (int): The seed to use for randomize.
|
| 92 |
+
|
| 93 |
+
Returns:
|
| 94 |
+
Tensor: The noise latents.
|
| 95 |
+
Shape: [num_samples, LATENT_CHANNELS, height // IMG_LATENT_SIZE_RATIO, width // IMG_LATENT_SIZE_RATIO]
|
| 96 |
+
|
| 97 |
+
"""
|
| 98 |
+
LATENT_CHANNELS, IMAGE_LATENT_SIZE_RATIO = 16, 8
|
| 99 |
+
return torch.randn(
|
| 100 |
+
bsz,
|
| 101 |
+
LATENT_CHANNELS,
|
| 102 |
+
height // IMAGE_LATENT_SIZE_RATIO,
|
| 103 |
+
width // IMAGE_LATENT_SIZE_RATIO,
|
| 104 |
+
dtype=dtype,
|
| 105 |
+
generator=torch.Generator().manual_seed(seed),
|
| 106 |
+
).to(device)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def create_position_encoding_for_latents(
    bsz: int, latent_height: int, latent_width: int, position_dim: int = 3
) -> Tensor:
    """
    Create the packed latents' position encodings for the Flux flow model.

    Channel 1 carries the patch-grid row index, channel 2 the column index;
    all other channels stay zero.

    Args:
        bsz (int): The batch size.
        latent_height (int): The height of the latent.
        latent_width (int): The width of the latent.

    Returns:
        Tensor: The position encodings.
            Shape: [bsz, (latent_height // PATCH_HEIGHT) * (latent_width // PATCH_WIDTH), POSITION_DIM)
    """
    patch_h, patch_w = 2, 2

    grid_h = latent_height // patch_h
    grid_w = latent_width // patch_w

    encoding = torch.zeros(grid_h, grid_w, position_dim)
    # Broadcast row ids down columns and column ids across rows.
    encoding[..., 1] = torch.arange(grid_h)[:, None]
    encoding[..., 2] = torch.arange(grid_w)[None, :]

    # Flatten the grid to a sequence and replicate it for every batch element.
    flat = encoding.reshape(1, grid_h * grid_w, position_dim)
    return flat.repeat(bsz, 1, 1)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def pack_latents(x: Tensor) -> Tensor:
    """
    Rearrange latents from an image-like format into a sequence of patches.
    Equivalent to `einops.rearrange("b c (h ph) (w pw) -> b (h w) (c ph pw)")`.

    Args:
        x (Tensor): The unpacked latents.
            Shape: [bsz, ch, latent height, latent width]

    Returns:
        Tensor: The packed latents.
            Shape: (bsz, (latent_height // ph) * (latent_width // pw), ch * ph * pw)
    """
    patch_h, patch_w = 2, 2

    bsz, ch, latent_height, latent_width = x.shape
    grid_h = latent_height // patch_h
    grid_w = latent_width // patch_w

    # Split each spatial dim into (grid cells, within-patch offset):
    # [b, c, h*ph, w*pw] -> [b, c, h, ph, w, pw]
    cells = x.reshape(bsz, ch, grid_h, patch_h, grid_w, patch_w)

    # Bring the grid dims forward: [b, c, h, ph, w, pw] -> [b, h, w, c, ph, pw]
    cells = cells.permute(0, 2, 4, 1, 3, 5)

    # Merge grid dims into the sequence axis and the rest into the feature axis.
    return cells.reshape(bsz, grid_h * grid_w, ch * patch_h * patch_w)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def unpack_latents(x: Tensor, latent_height: int, latent_width: int) -> Tensor:
    """
    Rearrange latents from a sequence of patches into an image-like format.
    Equivalent to `einops.rearrange("b (h w) (c ph pw) -> b c (h ph) (w pw)")`.
    Inverse of `pack_latents`.

    Args:
        x (Tensor): The packed latents.
            Shape: (bsz, (latent_height // ph) * (latent_width // pw), ch * ph * pw)
        latent_height (int): The height of the unpacked latents.
        latent_width (int): The width of the unpacked latents.

    Returns:
        Tensor: The unpacked latents.
            Shape: [bsz, ch, latent height, latent width]
    """
    patch_h, patch_w = 2, 2

    bsz, _, packed_ch = x.shape
    grid_h = latent_height // patch_h
    grid_w = latent_width // patch_w
    ch = packed_ch // (patch_h * patch_w)

    # [b, h*w, c*ph*pw] -> [b, h, w, c, ph, pw]
    patches = x.reshape(bsz, grid_h, grid_w, ch, patch_h, patch_w)

    # [b, h, w, c, ph, pw] -> [b, c, h, ph, w, pw]
    patches = patches.permute(0, 3, 1, 4, 2, 5)

    # [b, c, h, ph, w, pw] -> [b, c, h*ph, w*pw]
    return patches.reshape(bsz, ch, grid_h * patch_h, grid_w * patch_w)
|
torchtitan/experiments/kernels/triton_mg_group_gemm/simpleMoE.py
ADDED
|
@@ -0,0 +1,885 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import argparse
|
| 8 |
+
import logging
|
| 9 |
+
import math
|
| 10 |
+
import time
|
| 11 |
+
|
| 12 |
+
from typing import Dict, List, Tuple
|
| 13 |
+
|
| 14 |
+
# import numpy as np
|
| 15 |
+
import torch #
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
import torch.optim as optim
|
| 19 |
+
|
| 20 |
+
# from torchao_pr.mg_grouped_gemm import mg_grouped_gemm
|
| 21 |
+
|
| 22 |
+
# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

# Try to import the optimized MG GEMM implementation.
# `has_mg_gemm` records whether the Triton grouped-GEMM kernel is importable;
# the MoE layer falls back to per-expert manual looping when it is False.
try:
    from torchao_pr.mg_grouped_gemm import ( # grouped_gemm_backward,
        grouped_gemm_forward,
    )

    has_mg_gemm = True
except ImportError:
    # Soft dependency: the script still runs (manual-loop path only).
    logging.warning("MG GEMM implementation not found. Will use manual looping only.")
    has_mg_gemm = False
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class Router(nn.Module):
    """
    Router module that assigns tokens to experts.

    Projects each token onto per-expert logits, keeps the ``top_k`` experts
    per token, and renormalizes the kept probabilities so they sum to 1.
    """

    def __init__(self, input_dim: int, num_experts: int, top_k: int = 2):
        super().__init__()
        self.input_dim = input_dim
        self.num_experts = num_experts
        self.top_k = top_k

        # Routing layer: one logit per expert for every token.
        self.router = nn.Linear(input_dim, num_experts)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
        """
        Route input tokens to experts.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, input_dim)

        Returns:
            Tuple containing:
                - router_logits: Raw routing logits of shape
                  (batch_size * seq_len, num_experts)
                - dispatch_tensor: Dense (tokens, experts) tensor of routing
                  weights; each row has exactly ``top_k`` non-zero entries
                  that sum to 1
                - expert_indices: Per-expert 1-D index tensors listing the
                  flat token positions routed to that expert
        """
        # Flatten batch and sequence dimensions so each row is one token.
        x_flat = x.reshape(-1, self.input_dim)  # (batch_size * seq_len, input_dim)

        # Compute routing probabilities.
        router_logits = self.router(x_flat)  # (batch_size * seq_len, num_experts)
        router_probs = F.softmax(router_logits, dim=-1)

        # Keep the top-k experts per token and renormalize their weights.
        top_k_probs, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)
        top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)

        # Scatter the renormalized weights into a dense dispatch tensor
        # (one-hot-like representation of the top-k assignments).
        dispatch_tensor = torch.zeros_like(router_probs)
        dispatch_tensor.scatter_(1, top_k_indices, top_k_probs)

        # For each expert, collect the indices of tokens routed to it.
        expert_indices = []
        for expert_idx in range(self.num_experts):
            # Tokens with non-zero probability mass for this expert.
            indices = torch.nonzero(dispatch_tensor[:, expert_idx] > 0, as_tuple=True)[
                0
            ]
            expert_indices.append(indices)

        return router_logits, dispatch_tensor, expert_indices
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class Expert(nn.Module):
    """
    Individual expert module.

    A two-layer feed-forward network (Linear -> GELU -> Linear), both
    projections without bias.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim, bias=False)
        self.activation = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, output_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Project up, apply the nonlinearity, project back down.
        return self.fc2(self.activation(self.fc1(x)))
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class MixtureOfExperts(nn.Module):
    """
    Mixture of Experts layer with support for both manual looping and grouped GEMM.

    Two execution paths share the same `Router`:
      - manual loop: one `Expert` module per expert, tokens gathered and
        processed expert-by-expert;
      - MG GEMM: all expert weights stored as two fused parameter tensors and
        applied with `grouped_gemm_forward` (only when the kernel imported
        successfully, see module-level `has_mg_gemm`).
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_experts: int,
        top_k: int = 2,
        use_mg_gemm: bool = False,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.num_experts = num_experts
        self.top_k = top_k
        # The MG GEMM path is only usable when the kernel is importable.
        self.use_mg_gemm = use_mg_gemm and has_mg_gemm

        # Router
        self.router = Router(input_dim, num_experts, top_k)

        # Create expert modules
        if self.use_mg_gemm:
            # For MG GEMM, we need a single weight tensor for all experts.
            # First layer (input -> hidden); scaled by 1/sqrt(fan_in).
            self.expert_fc1_weight = nn.Parameter(
                torch.randn(num_experts * hidden_dim, input_dim) / math.sqrt(input_dim)
            )
            # self.expert_fc1_bias = nn.Parameter(torch.zeros(num_experts * hidden_dim))

            # Second layer (hidden -> output)
            self.expert_fc2_weight = nn.Parameter(
                torch.randn(num_experts * output_dim, hidden_dim)
                / math.sqrt(hidden_dim)
            )
            # self.expert_fc2_bias = nn.Parameter(torch.zeros(num_experts * output_dim))
        else:
            # For manual looping, create separate experts
            self.experts = nn.ModuleList(
                [Expert(input_dim, hidden_dim, output_dim) for _ in range(num_experts)]
            )

    def forward_manual_loop(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass using manual looping over experts.

        Args:
            x: Input of shape (batch_size, seq_len, input_dim).

        Returns:
            Tuple of (output, router_logits) where output has shape
            (batch_size, seq_len, output_dim).
        """
        batch_size, seq_len, _ = x.shape
        x_flat = x.reshape(-1, self.input_dim)  # (batch_size * seq_len, input_dim)

        # Get routing information
        router_logits, dispatch_tensor, expert_indices = self.router(x)

        # Initialize output tensor
        final_output = torch.zeros(
            batch_size * seq_len, self.output_dim, device=x.device
        )

        # Process each expert; experts with no assigned tokens are skipped.
        for expert_idx, indices in enumerate(expert_indices):
            if indices.numel() > 0:
                # Get tokens routed to this expert
                expert_inputs = x_flat[indices]  # (num_tokens_for_expert, input_dim)

                # Process tokens through expert
                expert_outputs = self.experts[expert_idx](
                    expert_inputs
                )  # (num_tokens_for_expert, output_dim)

                # Scale outputs by router probabilities
                scaled_outputs = expert_outputs * dispatch_tensor[
                    indices, expert_idx
                ].unsqueeze(1)

                # Add to final output; index_add_ accumulates top-k
                # contributions for tokens routed to multiple experts.
                final_output.index_add_(0, indices, scaled_outputs)

        # Reshape back to original dimensions
        output = final_output.reshape(batch_size, seq_len, self.output_dim)

        return output, router_logits

    def forward_mg_gemm(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass using the fused grouped-GEMM kernel.

        Tokens are packed expert-by-expert into one contiguous tensor, run
        through two grouped GEMMs (with GELU in between), then scattered back
        to their original positions weighted by the router probabilities.

        NOTE(review): the `print` calls below look like leftover debug output;
        they fire on every forward pass.
        """
        batch_size, seq_len, _ = x.shape
        x_flat = x.reshape(-1, self.input_dim)  # (batch_size * seq_len, input_dim)
        total_tokens = batch_size * seq_len

        # Get routing information
        router_logits, dispatch_tensor, expert_indices = self.router(x)

        # Get token counts for each expert (group sizes for the grouped GEMM)
        token_counts = [indices.numel() for indices in expert_indices]
        m_sizes = torch.tensor(token_counts, dtype=torch.int32, device=x.device)

        print(f"Token counts per expert: {token_counts}")
        print(f"m_sizes: {m_sizes}")

        # Create the combined input tensor: tokens packed expert-by-expert.
        # NOTE: with top_k > 1 a token appears once per expert it is routed to.
        combined_input = torch.zeros(sum(token_counts), self.input_dim, device=x.device)

        start_idx = 0
        for expert_idx, indices in enumerate(expert_indices):
            if indices.numel() > 0:
                end_idx = start_idx + indices.numel()
                combined_input[start_idx:end_idx] = x_flat[indices]
                start_idx = end_idx

        print(f"combined_input shape: {combined_input.shape}")

        # First layer: input -> hidden
        # Flatten the per-expert weight stack into the 2-D layout the
        # grouped GEMM expects.
        fc1_weight_reshaped = self.expert_fc1_weight.reshape(
            self.num_experts, self.hidden_dim, self.input_dim
        )
        fc1_weight_combined = fc1_weight_reshaped.reshape(-1, self.input_dim)

        print(f"fc1_weight_combined shape: {fc1_weight_combined.shape}")

        # Run the grouped GEMM
        hidden_outputs = grouped_gemm_forward(
            combined_input, fc1_weight_combined, m_sizes
        )

        print(f"hidden_outputs shape after first GEMM: {hidden_outputs.shape}")

        # Apply activation
        hidden_outputs = F.gelu(hidden_outputs)

        print(f"hidden_outputs shape after activation: {hidden_outputs.shape}")

        # Second layer: hidden -> output
        # Reshape hidden_outputs to match expected dimensions: slice out each
        # expert's hidden block for its own rows before the second GEMM.
        reshaped_hidden_outputs = []
        start_idx = 0

        for expert_idx, count in enumerate(token_counts):
            if count > 0:
                end_idx = start_idx + count
                # Take this expert's outputs and reshape to [count, hidden_dim]
                expert_output = hidden_outputs[
                    start_idx:end_idx,
                    expert_idx * self.hidden_dim : (expert_idx + 1) * self.hidden_dim,
                ]
                reshaped_hidden_outputs.append(expert_output)
                start_idx = end_idx

        # Concatenate all reshaped outputs
        hidden_outputs = torch.cat(reshaped_hidden_outputs, dim=0)

        # Reshape expert weights for second layer
        fc2_weight_reshaped = self.expert_fc2_weight.reshape(
            self.num_experts, self.output_dim, self.hidden_dim
        )
        fc2_weight_combined = fc2_weight_reshaped.reshape(-1, self.hidden_dim)

        print(f"fc2_weight_combined shape: {fc2_weight_combined.shape}")

        # Run the second grouped GEMM
        expert_outputs_combined = grouped_gemm_forward(
            hidden_outputs, fc2_weight_combined, m_sizes
        )

        # Initialize final output tensor with correct shape
        final_output = torch.zeros(total_tokens, self.output_dim, device=x.device)

        # Distribute the outputs back to the original token positions
        start_idx = 0
        for expert_idx, indices in enumerate(expert_indices):
            if indices.numel() > 0:
                end_idx = start_idx + indices.numel()
                # Get this expert's outputs
                expert_outputs = expert_outputs_combined[start_idx:end_idx]

                print(
                    f"Expert {expert_idx} - indices shape: {indices.shape}, expert_outputs shape: {expert_outputs.shape}"
                )

                # Scale outputs by router probabilities
                scaled_outputs = expert_outputs * dispatch_tensor[
                    indices, expert_idx
                ].unsqueeze(1)

                # Ensure dimensions match before using index_add_
                # (defensive: the GEMM output may be wider than output_dim).
                if scaled_outputs.shape[1] != final_output.shape[1]:
                    # print(
                    #     f"Reshaping: Dimension mismatch: scaled_outputs {scaled_outputs.shape}, final_output {final_output.shape}"
                    # )
                    # Reshape if needed - make sure output_dim is correct
                    scaled_outputs = scaled_outputs[:, : self.output_dim]

                # Add to final output
                final_output.index_add_(0, indices, scaled_outputs)

                start_idx = end_idx

        # Reshape back to original dimensions
        output = final_output.reshape(batch_size, seq_len, self.output_dim)

        return output, router_logits

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Dispatch to the fused kernel path only when it was both requested
        # and actually available at import time.
        if self.use_mg_gemm and has_mg_gemm:
            return self.forward_mg_gemm(x)
        else:
            return self.forward_manual_loop(x)
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
class MoEModel(nn.Module):
    """
    Simple model using MoE layers.

    Embedding -> MixtureOfExperts -> vocabulary projection.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_dim: int,
        hidden_dim: int,
        num_experts: int,
        top_k: int = 2,
        use_mg_gemm: bool = False,
    ):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.moe_layer = MixtureOfExperts(
            input_dim=embed_dim,
            hidden_dim=hidden_dim,
            output_dim=embed_dim,
            num_experts=num_experts,
            top_k=top_k,
            use_mg_gemm=use_mg_gemm,
        )
        self.output_layer = nn.Linear(embed_dim, vocab_size)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Map token ids (batch, seq) to (logits, router_logits)."""
        # Token ids -> embeddings: (batch_size, seq_len, embed_dim)
        token_embeddings = self.embedding(x)
        # MoE keeps the embedding width; router logits are returned
        # alongside so the caller can add a load-balancing loss.
        mixed, router_logits = self.moe_layer(token_embeddings)
        # Project back onto the vocabulary: (batch_size, seq_len, vocab_size)
        return self.output_layer(mixed), router_logits
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
def compute_load_balancing_loss(
    router_logits: torch.Tensor, num_experts: int
) -> torch.Tensor:
    """
    Compute the load balancing loss for MoE training.

    The loss is ``num_experts * sum(p_e^2)`` where ``p_e`` is the fraction of
    total routing probability mass assigned to expert ``e``. It is minimized
    (value 1.0) when routing is perfectly uniform across experts, and grows as
    routing concentrates on few experts.

    Args:
        router_logits (torch.Tensor): Router logits of shape
            (batch_size * seq_len, num_experts)
        num_experts (int): Number of experts

    Returns:
        torch.Tensor: Scalar load balancing loss
    """
    # Get router probabilities
    router_probs = F.softmax(
        router_logits, dim=-1
    )  # (batch_size * seq_len, num_experts)

    # Fraction of total probability mass routed to each expert,
    # normalized to sum to 1 across experts.
    router_probs_sum = router_probs.sum(dim=0)  # (num_experts,)
    router_probs_sum = router_probs_sum / router_probs_sum.sum()

    # Squared-fraction penalty: uniform routing gives exactly 1.0,
    # any imbalance gives a larger value.
    load_balancing_loss = num_experts * torch.sum(router_probs_sum * router_probs_sum)

    return load_balancing_loss
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
def generate_sample_data(
    batch_size: int, seq_len: int, vocab_size: int, device: str = "cuda"
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Generate sample data for training.

    Args:
        batch_size (int): Batch size
        seq_len (int): Sequence length
        vocab_size (int): Vocabulary size
        device (str): Device to use

    Returns:
        Tuple of input tokens and target tokens, each of shape
        (batch_size, seq_len) with values in [0, vocab_size).
    """
    shape = (batch_size, seq_len)
    # Inputs are drawn first, then targets, so the RNG stream order is fixed.
    inputs = torch.randint(0, vocab_size, shape, device=device)
    targets = torch.randint(0, vocab_size, shape, device=device)
    return inputs, targets
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
def train_epoch(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    batch_size: int,
    seq_len: int,
    vocab_size: int,
    num_batches: int,
    device: str,
    load_balance_coef: float = 0.01,
) -> Dict[str, float]:
    """
    Train the model for one epoch.

    Args:
        model (nn.Module): Model to train
        optimizer (torch.optim.Optimizer): Optimizer
        batch_size (int): Batch size
        seq_len (int): Sequence length
        vocab_size (int): Vocabulary size
        num_batches (int): Number of batches per epoch
        device (str): Device to use
        load_balance_coef (float): Coefficient for load balancing loss

    Returns:
        Dict containing training metrics: average "loss", average "acc",
        and wall-clock "time" for the epoch in seconds.
    """
    model.train()
    running_loss = 0.0
    running_acc = 0.0
    epoch_start = time.time()

    for step in range(num_batches):
        # Fresh random batch each step.
        inputs, targets = generate_sample_data(batch_size, seq_len, vocab_size, device)

        # Forward pass.
        optimizer.zero_grad()
        logits, router_logits = model(inputs)

        # Flatten (batch, seq) for token-level cross entropy.
        logits_flat = logits.reshape(-1, vocab_size)
        targets_flat = targets.reshape(-1)
        ce_loss = F.cross_entropy(logits_flat, targets_flat)

        # Auxiliary load-balancing term keeps routing spread across experts.
        lb_loss = compute_load_balancing_loss(
            router_logits, model.moe_layer.num_experts
        )
        loss = ce_loss + load_balance_coef * lb_loss

        # Backward pass and parameter update.
        loss.backward()
        optimizer.step()

        # Token-level accuracy for this batch.
        preds = logits_flat.argmax(dim=-1)
        correct = (preds == targets_flat).float().sum()
        acc = correct / (batch_size * seq_len)

        running_loss += loss.item()
        running_acc += acc.item()

        # Periodic progress report.
        if (step + 1) % 10 == 0:
            logging.info(
                f"Batch {step + 1}/{num_batches} | "
                f"Loss: {loss.item():.4f} | "
                f"CE Loss: {ce_loss.item():.4f} | "
                f"LB Loss: {lb_loss.item():.4f} | "
                f"Acc: {acc.item():.4f}"
            )

    return {
        "loss": running_loss / num_batches,
        "acc": running_acc / num_batches,
        "time": time.time() - epoch_start,
    }
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
def evaluate(
    model: nn.Module,
    batch_size: int,
    seq_len: int,
    vocab_size: int,
    num_batches: int,
    device: str,
) -> Dict[str, float]:
    """
    Evaluate the model.

    Args:
        model (nn.Module): Model to evaluate
        batch_size (int): Batch size
        seq_len (int): Sequence length
        vocab_size (int): Vocabulary size
        num_batches (int): Number of batches for evaluation
        device (str): Device to use

    Returns:
        Dict containing evaluation metrics: average "loss" and "acc".
    """
    model.eval()
    loss_sum = 0.0
    acc_sum = 0.0

    # No gradients needed during evaluation.
    with torch.no_grad():
        for _ in range(num_batches):
            inputs, targets = generate_sample_data(
                batch_size, seq_len, vocab_size, device
            )

            logits, router_logits = model(inputs)

            # Token-level cross entropy over the flattened batch.
            logits_flat = logits.reshape(-1, vocab_size)
            targets_flat = targets.reshape(-1)
            loss = F.cross_entropy(logits_flat, targets_flat)

            # Token-level accuracy.
            preds = logits_flat.argmax(dim=-1)
            correct = (preds == targets_flat).float().sum()
            acc = correct / (batch_size * seq_len)

            loss_sum += loss.item()
            acc_sum += acc.item()

    return {"loss": loss_sum / num_batches, "acc": acc_sum / num_batches}
|
| 567 |
+
|
| 568 |
+
|
| 569 |
+
def measure_performance(
    model: nn.Module,
    batch_size: int,
    seq_len: int,
    vocab_size: int,
    num_batches: int,
    device: str,
) -> Dict[str, float]:
    """
    Measure forward and backward pass performance.

    Args:
        model (nn.Module): Model to evaluate
        batch_size (int): Batch size
        seq_len (int): Sequence length
        vocab_size (int): Vocabulary size
        num_batches (int): Number of batches for measurement
        device (str): Device to use

    Returns:
        Dict containing performance metrics in milliseconds per batch:
        "forward_time", "backward_time", and "total_time".
    """
    model.train()

    # Create dummy optimizer (only used to clear gradients between steps).
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    def _sync() -> None:
        # torch.cuda.synchronize() raises on CPU-only builds; guarding it
        # keeps the timing loops usable when `device` is "cpu".
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    # Warmup: a few full steps so lazy initialization / kernel compilation
    # doesn't pollute the timed loops.
    for _ in range(5):
        inputs, targets = generate_sample_data(batch_size, seq_len, vocab_size, device)
        logits, router_logits = model(inputs)
        loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))
        loss.backward()
        optimizer.zero_grad()

    # Measure forward pass time (inference only, no autograd).
    _sync()
    forward_start = time.time()

    for _ in range(num_batches):
        inputs, targets = generate_sample_data(batch_size, seq_len, vocab_size, device)
        with torch.no_grad():
            logits, router_logits = model(inputs)

    _sync()
    forward_end = time.time()
    forward_time = (forward_end - forward_start) / num_batches

    # Measure backward pass time.
    # NOTE: this loop times forward + backward together, so the reported
    # "backward_time" includes a forward pass as well.
    _sync()
    backward_start = time.time()

    for _ in range(num_batches):
        inputs, targets = generate_sample_data(batch_size, seq_len, vocab_size, device)
        logits, router_logits = model(inputs)
        loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))
        loss.backward()
        optimizer.zero_grad()

    _sync()
    backward_end = time.time()
    backward_time = (backward_end - backward_start) / num_batches

    return {
        "forward_time": forward_time * 1000,  # Convert to ms
        "backward_time": backward_time * 1000,  # Convert to ms
        "total_time": (forward_time + backward_time) * 1000,  # Convert to ms
    }
|
| 637 |
+
|
| 638 |
+
|
| 639 |
+
def compare_methods(args):
    """
    Compare manual looping and MG GEMM implementations.

    Builds one model per routing implementation from the parsed CLI ``args``,
    times both with ``measure_performance``, and logs a side-by-side report
    (with per-phase speedups when the MG GEMM kernel is available).
    """
    device = torch.device(args.device)

    # Create models: identical configuration, differing only in the
    # expert-dispatch implementation.
    manual_model = MoEModel(
        vocab_size=args.vocab_size,
        embed_dim=args.embed_dim,
        hidden_dim=args.hidden_dim,
        num_experts=args.num_experts,
        top_k=args.top_k,
        use_mg_gemm=False,
    ).to(device)

    if has_mg_gemm:
        mg_model = MoEModel(
            vocab_size=args.vocab_size,
            embed_dim=args.embed_dim,
            hidden_dim=args.hidden_dim,
            num_experts=args.num_experts,
            top_k=args.top_k,
            use_mg_gemm=True,
        ).to(device)
    else:
        # Kernel unavailable: skip the MG GEMM half of the comparison.
        mg_model = None

    # Measure performance
    logging.info("Measuring performance of manual looping method...")
    manual_perf = measure_performance(
        manual_model,
        args.batch_size,
        args.seq_len,
        args.vocab_size,
        args.perf_batches,
        device,
    )

    if mg_model is not None:
        logging.info("Measuring performance of MG GEMM method...")
        mg_perf = measure_performance(
            mg_model,
            args.batch_size,
            args.seq_len,
            args.vocab_size,
            args.perf_batches,
            device,
        )
    else:
        # Placeholder timings; also make the speedup ratios below read as 0.
        mg_perf = {"forward_time": 0, "backward_time": 0, "total_time": 0}

    # Log results
    logging.info("\n===== Performance Comparison =====")
    logging.info("Model Configuration:")
    logging.info(f"  - Batch Size: {args.batch_size}")
    logging.info(f"  - Sequence Length: {args.seq_len}")
    logging.info(f"  - Embed Dimension: {args.embed_dim}")
    logging.info(f"  - Hidden Dimension: {args.hidden_dim}")
    logging.info(f"  - Number of Experts: {args.num_experts}")
    logging.info(f"  - Top-K: {args.top_k}")
    logging.info("")

    logging.info("Manual Looping Method:")
    logging.info(f"  - Forward Time: {manual_perf['forward_time']:.2f} ms")
    logging.info(f"  - Backward Time: {manual_perf['backward_time']:.2f} ms")
    logging.info(f"  - Total Time: {manual_perf['total_time']:.2f} ms")
    logging.info("")

    if mg_model is not None:
        logging.info("MG GEMM Method:")
        logging.info(f"  - Forward Time: {mg_perf['forward_time']:.2f} ms")
        logging.info(f"  - Backward Time: {mg_perf['backward_time']:.2f} ms")
        logging.info(f"  - Total Time: {mg_perf['total_time']:.2f} ms")
        logging.info("")

        # Calculate speedup; guard each ratio against a zero denominator.
        forward_speedup = (
            manual_perf["forward_time"] / mg_perf["forward_time"]
            if mg_perf["forward_time"] > 0
            else 0
        )
        backward_speedup = (
            manual_perf["backward_time"] / mg_perf["backward_time"]
            if mg_perf["backward_time"] > 0
            else 0
        )
        total_speedup = (
            manual_perf["total_time"] / mg_perf["total_time"]
            if mg_perf["total_time"] > 0
            else 0
        )

        logging.info("Speedup (MG GEMM vs Manual):")
        logging.info(f"  - Forward Speedup: {forward_speedup:.2f}x")
        logging.info(f"  - Backward Speedup: {backward_speedup:.2f}x")
        logging.info(f"  - Total Speedup: {total_speedup:.2f}x")
    else:
        logging.info("MG GEMM method not available.")
|
| 738 |
+
|
| 739 |
+
|
| 740 |
+
def train_model(args):
    """
    Train an MoE model.

    Builds an MoEModel from the parsed CLI arguments, then runs
    ``args.epochs`` epochs of train/evaluate on synthetic batches,
    logging loss, accuracy and per-epoch timing.
    """
    device = torch.device(args.device)

    # Resolve the GEMM backend once; reused for construction and logging.
    use_mg = args.use_mg_gemm and has_mg_gemm

    # Build the model on the target device.
    model = MoEModel(
        vocab_size=args.vocab_size,
        embed_dim=args.embed_dim,
        hidden_dim=args.hidden_dim,
        num_experts=args.num_experts,
        top_k=args.top_k,
        use_mg_gemm=use_mg,
    ).to(device)

    # Plain Adam over all parameters.
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Echo the configuration so each run is self-describing in its logs.
    logging.info("Model configuration:")
    logging.info(f" - Vocabulary Size: {args.vocab_size}")
    logging.info(f" - Embedding Dimension: {args.embed_dim}")
    logging.info(f" - Hidden Dimension: {args.hidden_dim}")
    logging.info(f" - Number of Experts: {args.num_experts}")
    logging.info(f" - Top-K: {args.top_k}")
    logging.info(f" - Using MG GEMM: {use_mg}")

    for epoch_idx in range(args.epochs):
        logging.info(f"\nEpoch {epoch_idx + 1}/{args.epochs}")

        # One optimization pass over synthetic training batches.
        train_metrics = train_epoch(
            model=model,
            optimizer=optimizer,
            batch_size=args.batch_size,
            seq_len=args.seq_len,
            vocab_size=args.vocab_size,
            num_batches=args.train_batches,
            device=device,
            load_balance_coef=args.load_balance_coef,
        )

        # Evaluation pass (no optimizer updates).
        eval_metrics = evaluate(
            model=model,
            batch_size=args.batch_size,
            seq_len=args.seq_len,
            vocab_size=args.vocab_size,
            num_batches=args.eval_batches,
            device=device,
        )

        # Report epoch metrics.
        logging.info(
            f"Train Loss: {train_metrics['loss']:.4f} | Train Acc: {train_metrics['acc']:.4f}"
        )
        logging.info(
            f"Eval Loss: {eval_metrics['loss']:.4f} | Eval Acc: {eval_metrics['acc']:.4f}"
        )
        logging.info(f"Epoch Time: {train_metrics['time']:.2f} seconds")
| 802 |
+
|
| 803 |
+
|
| 804 |
+
if __name__ == "__main__":
    # CLI: model size, training hyper-parameters, and runtime switches for
    # the manual-loop vs. MG-GEMM MoE script.
    parser = argparse.ArgumentParser(description="Train MoE model")

    # Model parameters
    parser.add_argument("--vocab_size", type=int, default=10000, help="Vocabulary size")
    parser.add_argument("--embed_dim", type=int, default=512, help="Embedding dimension")
    parser.add_argument("--hidden_dim", type=int, default=1024, help="Hidden dimension in experts")
    parser.add_argument("--num_experts", type=int, default=8, help="Number of experts")
    parser.add_argument("--top_k", type=int, default=2, help="Top-k experts to route to")

    # Training parameters
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
    parser.add_argument("--seq_len", type=int, default=128, help="Sequence length")
    parser.add_argument("--epochs", type=int, default=3, help="Number of epochs")
    parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
    parser.add_argument("--train_batches", type=int, default=100, help="Number of training batches per epoch")
    parser.add_argument("--eval_batches", type=int, default=20, help="Number of evaluation batches")
    parser.add_argument("--perf_batches", type=int, default=50, help="Number of batches for performance testing")
    parser.add_argument("--load_balance_coef", type=float, default=0.01, help="Load balancing loss coefficient")

    # Runtime parameters
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to use (cuda or cpu)",
    )
    parser.add_argument("--use_mg_gemm", action="store_true", help="Use MG GEMM implementation if available")
    parser.add_argument("--compare", action="store_true", help="Compare manual and MG GEMM implementations")
    parser.add_argument("--train", action="store_true", help="Train the model")

    args = parser.parse_args()

    # Fall back to CPU when CUDA was requested but is unavailable.
    if args.device == "cuda" and not torch.cuda.is_available():
        logging.warning("CUDA not available, using CPU instead.")
        args.device = "cpu"

    # Log basic environment information.
    logging.info(f"PyTorch version: {torch.__version__}")
    logging.info(f"Device: {args.device}")
    logging.info(f"MG GEMM available: {has_mg_gemm}")

    # Dispatch: --compare wins over --train; comparison is also the default
    # action when neither flag is given.
    if args.train and not args.compare:
        train_model(args)
    else:
        compare_methods(args)
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/mg_grouped_gemm.py
ADDED
|
@@ -0,0 +1,1304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# credit - flat index forward kernel is derived from FBGemm:
|
| 8 |
+
# https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm
|
| 9 |
+
|
| 10 |
+
# pyre-unsafe
|
| 11 |
+
import functools
|
| 12 |
+
import logging
|
| 13 |
+
|
| 14 |
+
import os
|
| 15 |
+
import sys
|
| 16 |
+
from typing import Any, Dict, Optional, Tuple
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
import triton
|
| 21 |
+
import triton.language as tl
|
| 22 |
+
from triton import Config as TConfig
|
| 23 |
+
|
| 24 |
+
from triton.runtime import driver # @manual
|
| 25 |
+
|
| 26 |
+
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
| 27 |
+
|
| 28 |
+
from tma_autotuning import (
|
| 29 |
+
ALIGN_SIZE_M,
|
| 30 |
+
_NV_CONFIGS,
|
| 31 |
+
CudaUtils,
|
| 32 |
+
early_config_prune,
|
| 33 |
+
TmaDescriptorHelper,
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# Configure logging
|
| 38 |
+
logging.basicConfig(
|
| 39 |
+
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
# ============== Start Triton Kernels ===============
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# Autotuned over the shared NVIDIA config space; the cache key includes the
# group count and the (bucketed) problem dimensions so each shape gets its
# own tuned tile configuration.
@triton.autotune(
    configs=_NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_mg_forward_hopper(
    a_desc_ptr,  # TMA descriptor for the stacked input activations
    b_desc_ptr,  # TMA descriptor for the stacked per-group weights
    c_ptr,  # raw output pointer; rows of group g start at M_start * n_size
    workspace,  # scratch buffer; each program gets TMA_SIZE bytes for its store descriptor
    m_sizes,  # int tensor of length G: number of rows owned by each group
    # problem sizes
    G: tl.constexpr,
    M_BUCKET: tl.constexpr,  # bucketed total M; only used as an autotune key
    N: tl.constexpr,  # total weight rows across all groups (per-group rows = N // G)
    K: tl.constexpr,
    # config
    NUM_SMS: tl.constexpr,  # launch grid size; tiles are strided across programs by this
    TMA_SIZE: tl.constexpr,  # bytes reserved per program in `workspace`
    USE_EPILOGUE_SUBTILING: tl.constexpr,  # store the result as two half-N stores
    # tiles
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
) -> None:
    """
    Flat index style forward kernel for Hopper.
    For simplicity, we always use TMA Load and TMA Store

    Tiles of all groups are flattened into one linear index space; each
    program starts at its program_id and strides by NUM_SMS, so work is
    balanced across SMs regardless of per-group sizes.  Per group g this
    computes C[g] = A[g] @ B[g].T, where A[g] is the m_sizes[g]-row slice
    of the stacked input and B[g] is the g-th (N // G)-row slice of the
    weights (assumes B is laid out [N, K] row-major — TODO confirm
    against the host-side descriptor setup).
    """
    tbidx = tl.program_id(0)  # thread block index

    c_dtype = c_ptr.dtype.element_ty  # output dtype

    c_desc_ptr = workspace + (tbidx * TMA_SIZE)  # for TMA Store

    M_end = 0
    M_start = 0
    processed_tiles = 0  # running count of tiles belonging to earlier groups
    # Size of individual weight matrix
    n_size = N // G
    n_start = 0

    for g in range(G):
        # Move down along groups
        # reset to new M offset
        M_start = M_end
        m_size = tl.load(m_sizes + g)
        M_end = M_start + m_size
        n_start = n_size * g  # row offset of this group's weights in B

        if m_size > 0:
            # Process this group

            # Acquire hold on c_desc_ptr for TMA Store
            # (the store descriptor is re-created per group because the
            # global address and extent depend on m_size / M_start)
            tl.extra.cuda.experimental_device_tensormap_create2d(
                desc_ptr=c_desc_ptr,
                global_address=c_ptr + M_start * n_size,
                load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
                global_size=[m_size, n_size],
                element_ty=c_dtype,
            )
            tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)

            # tiles for this group
            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
            num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
            group_num_tiles = num_m_tiles * num_n_tiles

            # Claim every tile of this group whose flat index maps to this
            # program (stride NUM_SMS).
            while tbidx >= processed_tiles and tbidx < (
                processed_tiles + group_num_tiles
            ):
                group_index = tbidx - processed_tiles

                # columnwise
                tile_m_index = group_index % num_m_tiles
                tile_n_index = group_index // num_m_tiles

                # fp32 accumulator regardless of input/output dtype
                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

                m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32)
                n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)
                global_n_offset = (n_start + n_offset).to(tl.int32)

                # Reduce along K in BLOCK_SIZE_K steps.
                for k_offset in range(0, K, BLOCK_SIZE_K):
                    # input block [M,K]
                    a = tl._experimental_descriptor_load(
                        a_desc_ptr,
                        [m_offset, k_offset],
                        [BLOCK_SIZE_M, BLOCK_SIZE_K],
                        c_dtype,
                    )
                    # weight block [N, K]
                    b = tl._experimental_descriptor_load(
                        b_desc_ptr,
                        [global_n_offset, k_offset],
                        [BLOCK_SIZE_N, BLOCK_SIZE_K],
                        c_dtype,
                    )

                    # Y = X @ W.T, hence the transpose on b.
                    accumulator += tl.dot(a, b.T)

                # Store using TMA

                # store offsets are relative to this group's descriptor,
                # so drop the M_start component
                m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)

                if USE_EPILOGUE_SUBTILING:
                    # Split the tile into two N/2 halves and store them
                    # separately (smaller epilogue working set).
                    acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2))
                    acc = tl.permute(acc, (0, 2, 1))
                    acc0, acc1 = tl.split(acc)
                    c0 = acc0.to(c_dtype)
                    tl._experimental_descriptor_store(
                        c_desc_ptr, c0, [m_offset, n_offset]
                    )
                    c1 = acc1.to(c_dtype)
                    tl._experimental_descriptor_store(
                        c_desc_ptr, c1, [m_offset, n_offset + BLOCK_SIZE_N // 2]
                    )
                else:
                    tl._experimental_descriptor_store(
                        c_desc_ptr,
                        accumulator.to(c_dtype),
                        [m_offset, n_offset],
                    )
                # move to next tile in group
                tbidx += NUM_SMS
            # Update the total tiles count for the next group
            processed_tiles += group_num_tiles
| 173 |
+
|
| 174 |
+
|
| 175 |
+
# Autotuned variant of the flat-index forward kernel where all groups share
# one [N, K] weight matrix (contrast with the Hopper kernel, which slices B
# per group).
@triton.autotune(
    configs=_NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_mg_forward_tma(
    a_desc_ptr,  # TMA descriptor for the stacked inputs [M_total, K]
    b_desc_ptr,  # TMA descriptor for the weights [N, K]
    c_ptr,  # raw output pointer [M_total, N]
    workspace,  # scratch; each program owns TMA_SIZE bytes for its store descriptor
    m_sizes,  # int tensor [G]: rows per group
    a_scale_ptr,  # fp8 scale for A — not used in this body (reserved for USE_FP8 path?)
    b_scale_ptr,  # fp8 scale for B — not used in this body; TODO confirm intent
    # problem sizes
    G: tl.constexpr,
    M_BUCKET: tl.constexpr,  # bucketed total M; autotune key only
    N: tl.constexpr,
    K: tl.constexpr,
    # config
    NUM_SMS: tl.constexpr,  # grid size; flat tile index strides by this
    USE_TMA_LOAD: tl.constexpr,  # not branched on here — TMA load is always used
    USE_TMA_STORE: tl.constexpr,  # not branched on here — TMA store is always used
    TMA_SIZE: tl.constexpr,  # bytes per program in `workspace`
    USE_FP8: tl.constexpr,  # not branched on in this body
    # tiles
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
) -> None:
    """
    Flat index style forward kernel.
    For simplicity, we always use TMA Load and TMA Store

    Computes, for each group g, the [m_sizes[g], N] slice of
    C = A @ B.T, flattening the tiles of all groups into one linear
    index space strided across programs by NUM_SMS.
    """
    tbidx = tl.program_id(0)  # thread block index

    c_dtype = c_ptr.dtype.element_ty

    c_desc_ptr = workspace + (tbidx * TMA_SIZE)

    M_end = 0
    processed_tiles = 0  # tiles consumed by earlier groups

    for g in range(G):
        # Move down along groups
        # reset to new M offset
        M_start = M_end
        m_size = tl.load(m_sizes + g)
        M_end = M_start + m_size

        if m_size > 0:
            # Process this group
            n_size = N  # the full weight matrix is shared by every group

            # TMA Store prep
            # (re-created per group: address/extent depend on M_start, m_size)
            tl.extra.cuda.experimental_device_tensormap_create2d(
                desc_ptr=c_desc_ptr,
                global_address=c_ptr + M_start * N,
                load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
                global_size=[m_size, n_size],
                element_ty=c_dtype,
            )
            tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)

            # tiles for this group
            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
            num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
            group_num_tiles = num_m_tiles * num_n_tiles

            # Claim tiles whose flat index belongs to this program.
            while tbidx >= processed_tiles and tbidx < (
                processed_tiles + group_num_tiles
            ):
                group_index = tbidx - processed_tiles

                # column-major tile order within the group
                tile_m_index = group_index % num_m_tiles
                tile_n_index = group_index // num_m_tiles

                # fp32 accumulator regardless of storage dtype
                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

                m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32)
                n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)

                # Reduce along K.
                for k_offset in range(0, K, BLOCK_SIZE_K):
                    # input block [M,K]
                    a = tl._experimental_descriptor_load(
                        a_desc_ptr,
                        [m_offset, k_offset],
                        [BLOCK_SIZE_M, BLOCK_SIZE_K],
                        c_dtype,
                    )
                    # weight block [N, K]
                    b = tl._experimental_descriptor_load(
                        b_desc_ptr,
                        [n_offset, k_offset],
                        [BLOCK_SIZE_N, BLOCK_SIZE_K],
                        c_dtype,
                    )

                    # Y = X @ W.T
                    accumulator += tl.dot(a, b.T)

                # Store using TMA

                # store offset is relative to this group's descriptor
                m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)
                # n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)

                tl._experimental_descriptor_store(
                    c_desc_ptr,
                    accumulator.to(c_dtype),
                    [m_offset, n_offset],
                )

                # Move to the next tile
                tbidx += NUM_SMS
            # Update the total tiles count for the next group
            processed_tiles += group_num_tiles
| 290 |
+
|
| 291 |
+
|
| 292 |
+
# Fallback forward kernel using plain pointer arithmetic (no TMA); intended
# for pre-Hopper parts per the docstring below.
@triton.autotune(
    configs=_NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_mg_forward_no_tma(
    a_ptr,  # stacked inputs, row-major [M_total, K]
    b_ptr,  # weights, row-major [N, K] (shared by all groups here)
    c_ptr,  # output, row-major [M_total, N]
    workspace,  # unused in this kernel (kept for signature parity with TMA variant)
    m_sizes,  # int tensor [G]: rows per group
    # problem sizes
    G: tl.constexpr,
    M_BUCKET: tl.constexpr,  # autotune key only
    N: tl.constexpr,
    K: tl.constexpr,
    # config
    NUM_SMS: tl.constexpr,  # grid size; flat tile index strides by this
    USE_TMA_LOAD: tl.constexpr,  # unused — this kernel never uses TMA
    USE_TMA_STORE: tl.constexpr,  # unused — this kernel never uses TMA
    TMA_SIZE: tl.constexpr,  # unused
    # tiles
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
) -> None:
    """
    Flat index style forward kernel.
    For bc and Ampere, we never use TMA Load and TMA Store

    Same flat-index tiling as the TMA variant: per group g, computes the
    [m_sizes[g], N] slice of C = A @ B.T with manual loads/stores.
    """
    tbidx = tl.program_id(0)  # thread block index

    c_dtype = c_ptr.dtype.element_ty
    c_desc_ptr = None  # no TMA descriptor in this path

    M_end = 0
    processed_tiles = 0  # tiles consumed by earlier groups

    for g in range(G):
        # Move down along groups
        # reset to new M offset
        M_start = M_end
        m_size = tl.load(m_sizes + g)
        M_end = M_start + m_size

        if m_size > 0:
            # Process this group
            n_size = N

            # tiles for this group
            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
            num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
            group_num_tiles = num_m_tiles * num_n_tiles

            while tbidx >= processed_tiles and tbidx < (
                processed_tiles + group_num_tiles
            ):
                group_index = tbidx - processed_tiles

                tile_m_index = group_index % num_m_tiles
                tile_n_index = group_index // num_m_tiles

                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

                # NOTE(review): m_offset/n_offset are computed but unused in
                # this non-TMA path (the offs_am/offs_bn vectors below carry
                # the addressing instead).
                m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32)
                n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)

                # per-element row/col indices within the group's tile
                offs_am = tile_m_index * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                offs_bn = tile_n_index * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                offs_k = tl.arange(0, BLOCK_SIZE_K)

                # row-major addressing: row stride K for both A and B
                a_ptrs = a_ptr + (M_start + offs_am[:, None]) * K + offs_k[None, :]
                b_ptrs = b_ptr + (offs_bn[:, None]) * K + offs_k[None, :]

                for k_offset in range(0, K, BLOCK_SIZE_K):
                    # Load with bounds checking
                    # NOTE(review): the masks bound only the M/N dims; the K
                    # dim is unmasked and no `other=` default is given, so
                    # this presumably assumes K % BLOCK_SIZE_K == 0 — confirm
                    # against the host-side launch constraints.
                    a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size)
                    b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size)

                    # Main matmul
                    accumulator += tl.dot(a, b.T)

                    # Update pointers for next block
                    a_ptrs += BLOCK_SIZE_K
                    b_ptrs += BLOCK_SIZE_K

                # Store without TMA
                offs_am = tile_m_index * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                offs_bn = tile_n_index * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

                c = accumulator.to(c_dtype)

                # NOTE(review): `and` between two mask tensors — conventional
                # Triton style is the elementwise `&`; verify this compiles
                # to the intended elementwise conjunction on the toolchain in use.
                tl.store(
                    c_ptr
                    + (M_start + offs_am[:, None]) * N  # Row stride is N
                    + offs_bn[None, :],  # Column offset
                    c,
                    mask=offs_am[:, None] < m_size and offs_bn[None, :] < n_size,
                )
                # Move to the next tile
                tbidx += NUM_SMS
            # Update the total tiles count for the next group
            processed_tiles += group_num_tiles
| 396 |
+
|
| 397 |
+
|
| 398 |
+
"""
|
| 399 |
+
Backward pass for grouped GEMM with Triton, where grouping is M*G
|
| 400 |
+
We compute gradients with respect to both input (`grad_x`) and weights (`grad_w`).
|
| 401 |
+
"""
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
# ---- dx flat linear indexed ----
|
| 405 |
+
# ---- dx flat linear indexed ----
@triton.autotune(
    configs=_NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_mg_dx_tma(
    grad_output_desc_ptr,  # [MG, N]
    w_desc_ptr,  # [N, K]
    grad_input_ptr,  # output grad_x [MG, K]
    workspace,  # for TMA store (TMA_SIZE bytes per program)
    m_sizes,  # group sizes [G]
    # problem sizes
    G: tl.constexpr,
    M_BUCKET: tl.constexpr,  # bucketed total M; autotune key only
    N: tl.constexpr,
    K: tl.constexpr,
    # config
    NUM_SMS: tl.constexpr,  # grid size; flat tile index strides by this
    USE_TMA_LOAD: tl.constexpr,  # not branched on here — TMA load is always used
    USE_TMA_STORE: tl.constexpr,  # not branched on here — TMA store is always used
    TMA_SIZE: tl.constexpr,
    # tiles
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
) -> None:
    """
    TMA-optimized kernel for computing gradients with respect to input (dx).
    For the forward pass Y = X @ W.T, the backward for input is:
        grad_X = grad_Y @ W

    This maps to [MG, N] @ [N, K] -> [MG, K]

    Key differences from forward:
    1. W is used directly and not transposed
    2. The reduction dimension is now N (not K)
    3. Output is [M, K] instead of [M, N]

    Uses the same flat tile-index scheme as the forward kernels: tiles of
    all groups form one linear index space strided across programs by
    NUM_SMS.
    """
    tbidx = tl.program_id(0)  # thread block index

    c_dtype = grad_input_ptr.dtype.element_ty
    c_desc_ptr = workspace + (tbidx * TMA_SIZE)

    M_end = 0
    processed_tiles = 0  # tiles consumed by earlier groups

    for g in range(G):
        # Move down along groups - same as forward
        M_start = M_end
        m_size = tl.load(m_sizes + g)
        M_end = M_start + m_size

        if m_size > 0:
            # Process this group
            # tiles for this group - now producing [M, K] output
            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
            num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
            group_num_tiles = num_m_tiles * num_k_tiles

            # TMA Store prep for [M, K] output
            # (re-created per group: address/extent depend on M_start, m_size)
            tl.extra.cuda.experimental_device_tensormap_create2d(
                desc_ptr=c_desc_ptr,
                global_address=grad_input_ptr + M_start * K,
                load_size=[BLOCK_SIZE_M, BLOCK_SIZE_K],
                global_size=[m_size, K],
                element_ty=c_dtype,
            )
            tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)

            # Claim tiles whose flat index belongs to this program.
            while tbidx >= processed_tiles and tbidx < (
                processed_tiles + group_num_tiles
            ):
                group_index = tbidx - processed_tiles

                # Different tiling scheme for [M, K] output
                tile_m_index = group_index % num_m_tiles
                tile_k_index = group_index // num_m_tiles

                # for grad_input block [M, K]
                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)

                # Position in full matrix
                m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32)
                k_offset = (tile_k_index * BLOCK_SIZE_K).to(tl.int32)

                # reduce along N dimension (instead of K in forward)
                for n_offset in range(0, N, BLOCK_SIZE_N):
                    # grad_output block [M, N]
                    grad_output = tl._experimental_descriptor_load(
                        grad_output_desc_ptr,
                        [m_offset, n_offset],
                        [BLOCK_SIZE_M, BLOCK_SIZE_N],
                        c_dtype,
                    )

                    # weight block [N, K] - no transpose needed
                    w = tl._experimental_descriptor_load(
                        w_desc_ptr,
                        [n_offset, k_offset],
                        [BLOCK_SIZE_N, BLOCK_SIZE_K],
                        c_dtype,
                    )

                    # grad_x = grad_output @ w
                    # reducing along N dimension
                    accumulator += tl.dot(grad_output, w)

                # Store using TMA
                # (offset is relative to this group's descriptor, so the
                # M_start component is dropped)
                m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)
                # k_offset = (tile_k_index * BLOCK_SIZE_K).to(tl.int32)

                tl._experimental_descriptor_store(
                    c_desc_ptr,
                    accumulator.to(c_dtype),
                    [m_offset, k_offset],
                )

                # Move to the next tile
                tbidx += NUM_SMS

            # Update the total tiles count for the next group
            processed_tiles += group_num_tiles
+
|
| 529 |
+
|
| 530 |
+
# ---- dw flat linear indexed ----
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
@triton.autotune(
    configs=_NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_mg_dw_tma(
    x_desc_ptr,  # input descriptor [M_total, K]
    grad_output_desc_ptr,  # grad_output descriptor [M_total, N]
    grad_weight_ptr,  # output grad_w [N, K]
    workspace,  # workspace for TMA store (store descriptor when USE_TMA_STORE)
    m_sizes,  # group sizes [G]
    # problem sizes
    G: tl.constexpr,
    M_BUCKET: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    # config
    NUM_SMS: tl.constexpr,
    USE_TMA_LOAD: tl.constexpr,
    USE_TMA_STORE: tl.constexpr,
    TMA_SIZE: tl.constexpr,
    # tiles
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,  # block size for reduction dimension
) -> None:
    """
    Improved TMA-optimized kernel for computing gradients with respect to weights (dw).
    Uses a flat index structure similar to the forward kernel: each program walks
    output tiles of grad_W with a stride of NUM_SMS, and for every tile reduces
    over the full M dimension, group by group.

    For the forward pass Y = X @ W.T, the backward for weights is:
        grad_W = grad_Y.T @ X

    Where:
        - grad_Y is [MG, N]
        - X is [MG, K]
        - grad_W is [N, K]
        - we return [N, K]
    """
    # Get thread block (program) index
    tbidx = tl.program_id(0)

    # Output data type — accumulation is done in fp32 and cast on store
    c_dtype = grad_weight_ptr.dtype.element_ty

    # Calculate number of output tiles over the [N, K] grad_W matrix
    num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
    num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
    total_output_tiles = num_n_tiles * num_k_tiles

    # Process tiles in strided manner across SMs (flat 1D tile index)
    for tile_idx in range(tbidx, total_output_tiles, NUM_SMS):
        # Decompose flat index; N is the fast-moving tile dimension
        tile_n_idx = tile_idx % num_n_tiles
        tile_k_idx = tile_idx // num_n_tiles

        # Calculate global element offsets of this tile
        n_offset = tile_n_idx * BLOCK_SIZE_N
        k_offset = tile_k_idx * BLOCK_SIZE_K

        # Initialize accumulator for this output tile [N, K]
        accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype=tl.float32)

        # Process each group; grad_W sums contributions from every group's rows
        M_end = 0
        for g in range(G):
            # Get group boundaries from the running prefix sum of m_sizes
            M_start = M_end
            m_size = tl.load(m_sizes + g)
            M_end = M_start + m_size

            # Only process if group is non-empty
            if m_size > 0:
                # Process this group in chunks along the M (reduction) dimension
                for m_offset in range(0, m_size, BLOCK_SIZE_M):
                    # Calculate actual block size (handling group boundary)
                    m_block_size = tl.minimum(BLOCK_SIZE_M, m_size - m_offset)

                    # Only process if we have actual work to do
                    if m_block_size > 0:
                        # Global row offset for this chunk
                        m_global_offset = M_start + m_offset

                        if USE_TMA_LOAD:
                            # Load input chunk [M_chunk, K] using TMA
                            x_block = tl._experimental_descriptor_load(
                                x_desc_ptr,
                                [m_global_offset, k_offset],
                                [BLOCK_SIZE_M, BLOCK_SIZE_K],
                                c_dtype,
                            )

                            # Load grad_output chunk [M_chunk, N] using TMA
                            grad_output_block = tl._experimental_descriptor_load(
                                grad_output_desc_ptr,
                                [m_global_offset, n_offset],
                                [BLOCK_SIZE_M, BLOCK_SIZE_N],
                                c_dtype,
                            )

                            # TMA loads the full BLOCK_SIZE_M rows; mask off rows
                            # beyond this group's boundary
                            offs_m = tl.arange(0, BLOCK_SIZE_M)
                            m_mask = offs_m < m_block_size

                            # Zero out invalid rows so they don't pollute the dot
                            x_block = tl.where(m_mask[:, None], x_block, 0.0)
                            grad_output_block = tl.where(
                                m_mask[:, None], grad_output_block, 0.0
                            )
                        else:
                            # Manual load with bounds checking
                            offs_m = tl.arange(0, BLOCK_SIZE_M)
                            offs_n = tl.arange(0, BLOCK_SIZE_N)
                            offs_k = tl.arange(0, BLOCK_SIZE_K)

                            # Create per-dimension masks
                            m_mask = offs_m < m_block_size
                            n_mask = offs_n < N - n_offset
                            k_mask = offs_k < K - k_offset

                            # Combined 2D masks
                            mk_mask = m_mask[:, None] & k_mask[None, :]
                            mn_mask = m_mask[:, None] & n_mask[None, :]

                            # Global row offsets for loading
                            m_global_offs = m_global_offset + offs_m

                            # Load x block [M_chunk, K]
                            # (x_desc_ptr is the raw tensor pointer in this path)
                            x_block = tl.load(
                                x_desc_ptr
                                + m_global_offs[:, None] * K
                                + (k_offset + offs_k)[None, :],
                                mask=mk_mask,
                                other=0.0,
                            )

                            # Load grad_output block [M_chunk, N]
                            grad_output_block = tl.load(
                                grad_output_desc_ptr
                                + m_global_offs[:, None] * N
                                + (n_offset + offs_n)[None, :],
                                mask=mn_mask,
                                other=0.0,
                            )

                        # Compute partial contribution: grad_W += grad_Y.T @ X
                        # (transpose grad_output for the matmul)
                        contribution = tl.dot(
                            grad_output_block.to(tl.float32).T,  # [N, M_chunk]
                            x_block.to(tl.float32),  # [M_chunk, K]
                        )

                        # Accumulate into the fp32 tile accumulator
                        accumulator += contribution

        # Store the finished [N, K] tile of grad_W
        if USE_TMA_STORE:
            # Store using TMA; workspace holds the store descriptor
            tl._experimental_descriptor_store(
                workspace,  # TMA store descriptor
                accumulator.to(c_dtype),
                [n_offset, k_offset],
            )
        else:
            # Manual store with bounds checking
            offs_n = tl.arange(0, BLOCK_SIZE_N)
            offs_k = tl.arange(0, BLOCK_SIZE_K)

            # Create masks for bounds checking at matrix edges
            n_mask = offs_n < N - n_offset
            k_mask = offs_k < K - k_offset
            output_mask = n_mask[:, None] & k_mask[None, :]

            # Store the result
            tl.store(
                grad_weight_ptr
                + (n_offset + offs_n)[:, None] * K
                + (k_offset + offs_k)[None, :],
                accumulator.to(c_dtype),
                mask=output_mask,
            )
|
| 716 |
+
|
| 717 |
+
|
| 718 |
+
# ======== End Triton kernels ========
|
| 719 |
+
|
| 720 |
+
# ======== Triton wrapper functions ========
|
| 721 |
+
|
| 722 |
+
# ----- main forward pass wrapper -----
|
| 723 |
+
|
| 724 |
+
|
| 725 |
+
def grouped_gemm_forward(
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    tma_size: int = 128,
) -> torch.Tensor:
    """
    M*G style grouped GEMM with TMA support.

    Computes Y = X @ W.T where rows of X are partitioned into G groups whose
    sizes are given by ``m_sizes``.

    Args:
        x: Input tensor, shape [M_total, K] (M_total = sum of m_sizes)
        w: Weight tensor, shape [N, K], shared layout across groups
        m_sizes: Group sizes tensor, shape [G]
        tma_size: Size of each TMA descriptor in bytes

    Returns:
        Output tensor, shape [M_total, N]

    Raises:
        NotImplementedError: If the device does not support TMA.

    Note:
        FP8 support was removed for now; it was previously triggered by passing
        x_scale / w_scale tensors.
    """
    if not CudaUtils.verify_tma():
        raise NotImplementedError("Grouped GEMM without TMA is not supported yet")

    G = m_sizes.shape[0]

    assert x.is_contiguous()
    assert w.is_contiguous()
    assert m_sizes.is_contiguous()

    # Total input size is [M_total, K] where M_total is the sum of all group sizes
    M_total, K = x.shape
    N = w.shape[0]  # N is the same for all groups

    assert K == w.shape[1], f"Input K ({K}) must match weight K ({w.shape[1]})"

    # Verify that all group sizes are multiples of ALIGN_SIZE_M.
    # This check is commented out because it would involve a GPU-CPU sync:
    # assert torch.remainder(m_sizes, ALIGN_SIZE_M).max() == 0

    # Create output tensor with shape [M_total, N].
    # BUGFIX: this previously allocated (M_total, N // G), which contradicts the
    # documented output shape and the backward wrappers, which require
    # grad_output.shape[1] == w.shape[0] (grad_output always has the output's
    # shape). TODO(review): confirm _kernel_mg_forward_hopper writes full-N rows.
    y = torch.empty((M_total, N), device=x.device, dtype=x.dtype)

    if M_total == 0:
        # Nothing to compute for an empty batch
        return y

    NUM_SMS = CudaUtils.get_num_sms()
    USE_TMA_LOAD = True
    USE_TMA_STORE = True
    USE_EPILOGUE_SUBTILING = False

    # TMA descriptor helper; descriptors are created here and filled inside the
    # grid callable because block sizes come from the autotuner META
    desc_helper = None
    desc_x = x
    desc_w = w
    workspace = None

    if USE_TMA_LOAD:
        desc_helper = TmaDescriptorHelper(tma_size=tma_size)
        desc_helper.init_tma_descriptor("x")
        desc_helper.init_tma_descriptor("w")
        desc_x = desc_helper.get_tma_descriptor_kernel_param("x")
        desc_w = desc_helper.get_tma_descriptor_kernel_param("w")

    if USE_TMA_STORE:
        # Per-SM scratch space for on-device TMA store descriptors
        workspace = torch.empty(
            NUM_SMS * desc_helper.tma_size,
            device=x.device,
            dtype=torch.uint8,
        )

    def grid(META):
        if USE_TMA_LOAD:
            nonlocal desc_helper
            desc_helper.fill_2d_tma_descriptor(
                "x",
                x.data_ptr(),
                M_total,
                K,
                META["BLOCK_SIZE_M"],
                META["BLOCK_SIZE_K"],
                x.element_size(),
            )

            desc_helper.fill_2d_tma_descriptor(
                "w",
                w.data_ptr(),
                N,
                K,
                META["BLOCK_SIZE_N"],
                META["BLOCK_SIZE_K"],
                w.element_size(),
            )
        # One persistent program per SM
        return (NUM_SMS,)

    # Power-of-two bucket so the autotuner cache key is stable across batch sizes
    M_BUCKET = triton.next_power_of_2(M_total)

    _kernel_mg_forward_hopper[grid](
        desc_x,
        desc_w,
        y,
        workspace,
        m_sizes,
        G,
        M_BUCKET,
        N,
        K,
        NUM_SMS,
        TMA_SIZE=tma_size,
        USE_EPILOGUE_SUBTILING=USE_EPILOGUE_SUBTILING,
    )

    return y
|
| 828 |
+
|
| 829 |
+
|
| 830 |
+
# ======== Improved Backward =============
|
| 831 |
+
def grouped_gemm_backward(
    grad_output: torch.Tensor,
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    use_tma: bool = True,
    tma_size: int = 128,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Unified backward pass for grouped GEMM with M*G grouping.
    Uses optimized TMA-based implementations for both dx and dw when available.

    Args:
        grad_output: Gradient of output, shape [M_total, N]
        x: Input tensor from forward pass, shape [M_total, K]
        w: Weight tensor from forward pass, shape [N, K]
        m_sizes: Group sizes tensor, shape [G]
        use_tma: Whether to try using TMA acceleration (if available)
        tma_size: Size of TMA descriptor in bytes

    Returns:
        Tuple of gradients with respect to x and w: (grad_x, grad_w)

    Raises:
        ValueError: If tensor dimensions are inconsistent with each other
            or with m_sizes.
    """
    logging.info("Starting unified grouped_gemm_backward")

    # Query SM count once up front — the device-property lookup is relatively
    # expensive, and both sub-kernels need it
    NUM_SMS = CudaUtils.get_num_sms()

    # Basic validation
    G = m_sizes.shape[0]
    M_total, K_x = x.shape
    M_grad, N = grad_output.shape
    N_w, K_w = w.shape

    # Check dimensions (K, M, and N must all be consistent)
    if K_x != K_w:
        raise ValueError(f"K dimension mismatch: x has K={K_x}, w has K={K_w}")
    if M_total != M_grad:
        raise ValueError(
            f"M dimension mismatch: x has M={M_total}, grad_output has M={M_grad}"
        )
    if N != N_w:
        # Raise a clear error here rather than tripping an assert deep in the
        # dx wrapper
        raise ValueError(
            f"N dimension mismatch: grad_output has N={N}, w has N={N_w}"
        )

    # Check total M matches sum of group sizes (note: .item() is a GPU-CPU sync)
    sum_m_sizes = m_sizes.sum().item()
    if M_total != sum_m_sizes:
        raise ValueError(
            f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})"
        )

    # Make sure inputs are contiguous
    grad_output = grad_output.contiguous()
    x = x.contiguous()
    w = w.contiguous()
    m_sizes = m_sizes.contiguous()

    # Check TMA support
    if use_tma and not CudaUtils.verify_tma():
        logging.info("TMA requested but not supported on this device")
        use_tma = False

    # Compute grad_x using the flat linear TMA implementation
    try:
        logging.info("Computing grad_x with flat linear kernel")

        grad_x = grouped_gemm_dx_tma(
            grad_output=grad_output,
            w=w,
            m_sizes=m_sizes,
            num_sms=NUM_SMS,
            tma_size=tma_size,
        )

    except Exception as e:
        logging.error("Error in grad_x computation: %s", e)
        raise

    # Compute grad_w using the flat linear style implementation
    try:
        logging.info("Computing grad_w with flat linear kernel")

        grad_w = grouped_gemm_dw_tma(
            x, grad_output, m_sizes, num_sms=NUM_SMS, tma_size=tma_size
        )
    except Exception as e:
        logging.error("Error in grad_w computation: %s", e)
        raise

    return grad_x, grad_w
|
| 922 |
+
|
| 923 |
+
|
| 924 |
+
# ----- dx backward pass wrapper -----
|
| 925 |
+
|
| 926 |
+
|
| 927 |
+
def grouped_gemm_dx_tma(
    grad_output: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    num_sms: int = 132,
    tma_size: int = 128,
) -> torch.Tensor:
    """
    Optimized backward pass wrapper for computing the gradient with respect to
    the input (dx), using TMA patterns similar to the forward pass.

    For the forward pass Y = X @ W.T, the input gradient is
    grad_X = grad_Y @ W, computed here as [M_total, K].

    Args:
        grad_output: Gradient of output, shape [M_total, N]
        w: Weight tensor, shape [N, K]
        m_sizes: Group sizes tensor, shape [G]
        num_sms: Number of SMs to launch persistent programs on
        tma_size: Size of TMA descriptor in bytes

    Returns:
        grad_x: Gradient with respect to x, shape [M_total, K]

    Raises:
        NotImplementedError: If the device does not support TMA.
    """
    # NOTE: a second, duplicated docstring (dead string statement) was removed here.
    if not CudaUtils.verify_tma():
        raise NotImplementedError("Optimized dx computation requires TMA support")

    G = m_sizes.shape[0]

    assert grad_output.is_contiguous()
    assert w.is_contiguous()
    assert m_sizes.is_contiguous()

    M_total, N_grad = grad_output.shape
    N_w, K = w.shape

    # Check dimensions
    assert N_grad == N_w, f"Grad_output N ({N_grad}) must match weight N ({N_w})"

    # Verify that the sum of m_sizes matches M_total (.item() is a GPU-CPU sync)
    sum_m_sizes = m_sizes.sum().item()
    assert (
        M_total == sum_m_sizes
    ), f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})"

    # Create output tensor (grad_x) with shape [M_total, K]
    grad_x = torch.empty(
        (M_total, K), device=grad_output.device, dtype=grad_output.dtype
    )

    NUM_SMS = num_sms  # caller usually passes CudaUtils.get_num_sms()
    USE_TMA_LOAD = True
    USE_TMA_STORE = True

    # Set up TMA descriptors; they are filled inside the grid callable because
    # block sizes come from the autotuner META
    desc_helper = TmaDescriptorHelper(tma_size=tma_size)
    desc_helper.init_tma_descriptor("grad_output")
    desc_helper.init_tma_descriptor("w")
    desc_grad_output = desc_helper.get_tma_descriptor_kernel_param("grad_output")
    desc_w = desc_helper.get_tma_descriptor_kernel_param("w")

    # Allocate per-SM workspace for TMA store descriptors
    workspace = torch.empty(
        NUM_SMS * desc_helper.tma_size,
        device=grad_output.device,
        dtype=torch.uint8,
    )

    def grid(META):
        # Fill TMA descriptors with block sizes chosen by the autotuner
        desc_helper.fill_2d_tma_descriptor(
            "grad_output",
            grad_output.data_ptr(),
            M_total,
            N_grad,
            META["BLOCK_SIZE_M"],
            META["BLOCK_SIZE_N"],
            grad_output.element_size(),
        )

        desc_helper.fill_2d_tma_descriptor(
            "w",
            w.data_ptr(),
            N_w,
            K,
            META["BLOCK_SIZE_N"],
            META["BLOCK_SIZE_K"],
            w.element_size(),
        )
        # One persistent program per SM
        return (NUM_SMS,)

    M_BUCKET = triton.next_power_of_2(M_total)

    # Launch the flat linear kernel for computing grad_x
    _kernel_mg_dx_tma[grid](
        desc_grad_output,
        desc_w,
        grad_x,
        workspace,
        m_sizes,
        G,
        M_BUCKET,
        N_grad,  # N dimension is the reduction dimension here
        K,
        NUM_SMS,
        USE_TMA_LOAD,
        USE_TMA_STORE,
        TMA_SIZE=tma_size,
    )

    return grad_x
|
| 1053 |
+
|
| 1054 |
+
|
| 1055 |
+
# ======== dw wrapper function ==========
|
| 1056 |
+
|
| 1057 |
+
|
| 1058 |
+
def grouped_gemm_dw_tma(
    x: torch.Tensor,
    grad_output: torch.Tensor,
    m_sizes: torch.Tensor,
    num_sms: int = 132,
    tma_size: int = 128,
) -> torch.Tensor:
    """
    Optimized flat linear kernel computation of gradients with respect to weights (dw) using TMA.
    For the forward pass Y = X @ W.T, the backward for weights is:
        grad_W = grad_Y.T @ X

    Args:
        x: Input tensor, shape [M_total, K]
        grad_output: Gradient of output, shape [M_total, N]
        m_sizes: Group sizes tensor, shape [G]
        num_sms: Number of SMs to launch persistent programs on
        tma_size: Size of TMA descriptor in bytes

    Returns:
        grad_w: Gradient with respect to weights, shape [N, K]
    """
    # Check TMA support
    # NOTE(review): currently unused — the USE_TMA_* flags below are hardcoded
    # to True instead of being derived from this probe.
    has_tma_support = CudaUtils.verify_tma()

    # Get group count
    G = m_sizes.shape[0]

    # Ensure contiguous tensors
    x = x.contiguous()
    grad_output = grad_output.contiguous()
    m_sizes = m_sizes.contiguous()

    # Get dimensions
    M_total, K_x = x.shape
    M_grad, N = grad_output.shape

    # Check dimensions
    assert M_total == M_grad, f"x M ({M_total}) must match grad_output M ({M_grad})"

    # Verify that the sum of m_sizes matches M_total (.item() is a GPU-CPU sync)
    sum_m_sizes = m_sizes.sum().item()
    assert (
        sum_m_sizes == M_total
    ), f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})"

    # Create output tensor (grad_w) with shape [N, K]
    grad_w = torch.zeros((N, K_x), device=x.device, dtype=x.dtype)

    NUM_SMS = num_sms

    # TODO - hardcoded for now...but should set TMA flags based on hardware support
    USE_TMA_LOAD = True  # has_tma_support
    USE_TMA_STORE = True  # has_tma_support

    # Set up TMA descriptors or direct tensor pointers. Descriptors are created
    # here and filled inside the grid callable (block sizes come from META).
    if USE_TMA_LOAD or USE_TMA_STORE:
        desc_helper = TmaDescriptorHelper(tma_size=tma_size)

        if USE_TMA_LOAD:
            desc_helper.init_tma_descriptor("x")
            desc_helper.init_tma_descriptor("grad_output")
            x_desc = desc_helper.get_tma_descriptor_kernel_param("x")
            grad_output_desc = desc_helper.get_tma_descriptor_kernel_param(
                "grad_output"
            )
        else:
            # Kernel falls back to manual tl.load on raw pointers
            x_desc = x
            grad_output_desc = grad_output

        if USE_TMA_STORE:
            # Kernel stores through this descriptor, passed as `workspace`
            desc_helper.init_tma_descriptor("grad_w")
            workspace = desc_helper.get_tma_descriptor_kernel_param("grad_w")
        else:
            # Dummy buffer; kernel uses manual tl.store instead
            workspace = torch.empty(1, device=x.device, dtype=torch.uint8)
    else:
        # If not using TMA, just use the tensors directly
        x_desc = x
        grad_output_desc = grad_output
        workspace = torch.empty(1, device=x.device, dtype=torch.uint8)

    # Power-of-two M bucket for a stable autotuner cache key
    M_BUCKET = triton.next_power_of_2(M_total)

    # Define grid for kernel launch; fills descriptors once META (block sizes)
    # is known
    def grid(META):
        if USE_TMA_LOAD or USE_TMA_STORE:

            if USE_TMA_LOAD:
                desc_helper.fill_2d_tma_descriptor(
                    "x",
                    x.data_ptr(),
                    M_total,
                    K_x,
                    META["BLOCK_SIZE_M"],
                    META["BLOCK_SIZE_K"],
                    x.element_size(),
                )

                desc_helper.fill_2d_tma_descriptor(
                    "grad_output",
                    grad_output.data_ptr(),
                    M_total,
                    N,
                    META["BLOCK_SIZE_M"],
                    META["BLOCK_SIZE_N"],
                    grad_output.element_size(),
                )

            if USE_TMA_STORE:
                desc_helper.fill_2d_tma_descriptor(
                    "grad_w",
                    grad_w.data_ptr(),
                    N,
                    K_x,
                    META["BLOCK_SIZE_N"],
                    META["BLOCK_SIZE_K"],
                    grad_w.element_size(),
                )

        # Return grid size - one block per SM for balanced work distribution
        return (NUM_SMS,)

    # Launch the optimized kernel
    _kernel_mg_dw_tma[grid](
        x_desc,
        grad_output_desc,
        grad_w,
        workspace,
        m_sizes,
        G,
        M_BUCKET,
        N,
        K_x,
        NUM_SMS,
        USE_TMA_LOAD,
        USE_TMA_STORE,
        TMA_SIZE=tma_size,
    )

    return grad_w
|
| 1199 |
+
|
| 1200 |
+
|
| 1201 |
+
# ======== End Backwards Wrapper Functions =============
|
| 1202 |
+
|
| 1203 |
+
# ======== PyTorch wrapper functions ========
|
| 1204 |
+
|
| 1205 |
+
|
| 1206 |
+
class GroupedGEMM_mg(torch.autograd.Function):
    """
    Autograd function for GroupedGEMM with M*G grouping.

    Wires ``grouped_gemm_forward`` / ``grouped_gemm_backward`` into autograd.
    """

    @staticmethod
    def forward(ctx, x, w, m_sizes, use_tma=True, tma_size=128):
        """
        Forward pass of GroupedGEMM.

        Args:
            ctx: Autograd context
            x: Input tensor, shape [M_total, K]
            w: Weight tensor, shape [N, K]
            m_sizes: Tensor of shape [G] containing the size of each group
            use_tma: Whether to try using TMA acceleration (if available)
            tma_size: Size of TMA descriptor in bytes

        Returns:
            Output tensor, shape [M_total, N]
        """
        # BUGFIX: this previously passed using_fp8=False, but
        # grouped_gemm_forward takes no such keyword (TypeError at call time).
        output = grouped_gemm_forward(x=x, w=w, m_sizes=m_sizes, tma_size=tma_size)

        # Save inputs and parameters for backward pass.
        # BUGFIX: save_for_backward was previously called twice; only the last
        # call takes effect, so the duplicate was pure noise.
        ctx.save_for_backward(x, w, m_sizes)
        ctx.use_tma = use_tma
        ctx.tma_size = tma_size

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward pass of M*G GroupedGEMM.

        Args:
            ctx: Autograd context
            grad_output: Gradient of output, shape [M_total, N]

        Returns:
            Tuple of gradients, one per forward input:
            - grad_x: Gradient with respect to x, shape [M_total, K]
            - grad_w: Gradient with respect to w, shape [N, K]
            - None: for m_sizes (not differentiable)
            - None: for use_tma (not differentiable)
            - None: for tma_size (not differentiable)
        """
        # Retrieve saved tensors and parameters
        x, w, m_sizes = ctx.saved_tensors
        use_tma = ctx.use_tma
        tma_size = ctx.tma_size

        # Compute gradients using the unified implementation
        grad_x, grad_w = grouped_gemm_backward(
            grad_output=grad_output,
            x=x,
            w=w,
            m_sizes=m_sizes,
            use_tma=use_tma,
            tma_size=tma_size,
        )

        # BUGFIX: forward() takes 5 inputs (x, w, m_sizes, use_tma, tma_size),
        # so backward must return exactly 5 gradients; it previously returned 4,
        # which makes autograd raise at runtime.
        return grad_x, grad_w, None, None, None
|
| 1279 |
+
|
| 1280 |
+
|
| 1281 |
+
def mg_grouped_gemm(
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    use_tma: bool = True,
    tma_size: int = 128,
    using_fp8: bool = False,
) -> torch.Tensor:
    """
    Unified differentiable grouped GEMM operation for M*G grouped GEMM.

    Args:
        x: Input tensor, shape [M_total, K]
        w: Weight tensor, shape [N, K]
        m_sizes: Tensor of shape [G] containing the size of each group
        use_tma: Whether to try using TMA acceleration (if available)
        tma_size: Size of TMA descriptor in bytes
        using_fp8: Kept for interface compatibility; FP8 support was removed
            from the forward pass, so True raises.

    Returns:
        Output tensor, shape [M_total, N]

    Raises:
        NotImplementedError: If ``using_fp8`` is True.
    """
    if using_fp8:
        # Fail loudly and clearly: the FP8 path no longer exists in
        # grouped_gemm_forward.
        raise NotImplementedError("FP8 grouped GEMM is not currently supported")

    # BUGFIX: apply() was previously called with 6 positional args while
    # GroupedGEMM_mg.forward accepts 5 (plus ctx), which raised a TypeError.
    return GroupedGEMM_mg.apply(x, w, m_sizes, use_tma, tma_size)
|
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/tma_autotuning.py
ADDED
|
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# credit - TMAHelper class, AutoTuning are derived from FBGemm:
|
| 8 |
+
# https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm
|
| 9 |
+
|
| 10 |
+
# pyre-unsafe
|
| 11 |
+
import functools
|
| 12 |
+
|
| 13 |
+
import os
|
| 14 |
+
import sys
|
| 15 |
+
from typing import Any, Dict, Optional, Tuple
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
|
| 19 |
+
import triton
|
| 20 |
+
import triton.language as tl
|
| 21 |
+
from triton import Config as TConfig
|
| 22 |
+
|
| 23 |
+
from triton.runtime import driver # @manual
|
| 24 |
+
|
| 25 |
+
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# ===== Supporting utils, CUDA and TMA =====
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class CudaUtils:
    """Static probes for CUDA backend, TMA capability, and SM count."""

    @staticmethod
    def is_cuda() -> bool:
        """Check if Triton is running on CUDA backend."""
        target = driver.active.get_current_target()
        return target.backend == "cuda"

    @staticmethod
    def verify_tma() -> bool:
        """Check if TMA is supported on the current device."""
        # TMA requires the CUDA backend, a visible device, and Hopper (SM 9.x)
        # or newer; evaluate the probes in the same short-circuit order as the
        # original conjunction.
        if not CudaUtils.is_cuda():
            return False
        if not torch.cuda.is_available():
            return False
        return torch.cuda.get_device_capability()[0] >= 9

    @staticmethod
    def get_num_sms() -> int:
        """Get the number of streaming multiprocessors on the current device."""
        if not CudaUtils.is_cuda():
            raise RuntimeError("Triton is not running on CUDA backend")
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is not available")
        props = torch.cuda.get_device_properties("cuda")
        return props.multi_processor_count
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class TmaDescriptorHelper:
    """Helper class for managing TMA descriptors in Triton kernels.

    Descriptors are allocated once on the CPU (``init_tma_descriptor``) and
    filled in place right before a kernel launch (``fill_*_tma_descriptor``),
    so grid-size lambdas can update them without reallocating.
    """

    class KernelParamWrapper:
        """Wrapper to implement the TmaDescKernelParam interface."""

        def __init__(self, desc: torch.Tensor):
            self.desc = desc

        def tma_desc_cpu_ptr(self) -> int:
            """Return the CPU pointer to the TMA descriptor."""
            return self.desc.data_ptr()

    def __init__(self, tma_size: int = 128):
        """Initialize the TMA descriptor helper.

        Args:
            tma_size: Size of the TMA descriptor in bytes

        Raises:
            RuntimeError: if the device or the Triton version lacks TMA support.
        """
        if not CudaUtils.verify_tma():
            raise RuntimeError(
                "TMA not supported on this device (requires Hopper or newer)"
            )
        if "nv_tma_desc_type" not in dir(tl):
            raise RuntimeError(
                "TMA grid constant descriptors not supported in your Triton version"
            )

        self.tma_size = tma_size
        self.fill_1d_tma_descriptor_inner = driver.active.utils.fill_1d_tma_descriptor
        self.fill_2d_tma_descriptor_inner = driver.active.utils.fill_2d_tma_descriptor
        self.descriptors: Dict[str, torch.Tensor] = {}

    def init_tma_descriptor(self, name: str) -> None:
        """Initialize a TMA descriptor with the given name.

        Call this method outside of the lambda function for grid size.
        """
        self.descriptors[name] = torch.empty(
            self.tma_size, device="cpu", dtype=torch.int8
        )

    def _validated_descriptor(self, name: str) -> torch.Tensor:
        """Return the named descriptor after checking existence and alignment.

        Shared validation for the fill_* methods (was previously duplicated).
        """
        if name not in self.descriptors:
            raise ValueError(f"TMA descriptor '{name}' not initialized")
        desc = self.descriptors[name]
        # The TMA hardware requires the descriptor at a 64-byte boundary.
        if desc.data_ptr() % 64 != 0:
            raise ValueError("TMA descriptor must be 64-byte aligned")
        return desc

    def fill_1d_tma_descriptor(
        self, name: str, ptr: int, dim: int, block_dim: int, element_size: int
    ) -> None:
        """Fill a 1D TMA descriptor.

        Call this method inside the lambda function for grid size.
        """
        desc = self._validated_descriptor(name)
        self.fill_1d_tma_descriptor_inner(
            ptr, dim, block_dim, element_size, desc.data_ptr()
        )

    def fill_2d_tma_descriptor(
        self,
        name: str,
        ptr: int,
        dim1: int,
        dim0: int,
        block_dim1: int,
        block_dim0: int,
        element_size: int,
    ) -> None:
        """Fill a 2D TMA descriptor.

        Call this method inside the lambda function for grid size.
        """
        desc = self._validated_descriptor(name)
        self.fill_2d_tma_descriptor_inner(
            ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc.data_ptr()
        )

    def get_tma_descriptor_kernel_param(self, name: str) -> KernelParamWrapper:
        """Get the TMA descriptor kernel parameter for the given name."""
        if name not in self.descriptors or self.descriptors[name] is None:
            raise ValueError(f"TMA descriptor '{name}' not initialized")
        return self.KernelParamWrapper(self.descriptors[name])
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
# ====== Autotuning utilities ======
|
| 147 |
+
# M-dimension tile alignment used by the candidate configs below.
ALIGN_SIZE_M = 128

# Candidate autotuning configs for NVIDIA GPUs: sweep block shapes, pipeline
# depth (num_stages) and warp count. BLOCK_SIZE_M is pinned to ALIGN_SIZE_M
# and num_ctas is kept at 1.
_NV_CONFIGS = [
    triton.Config(
        {
            "BLOCK_SIZE_M": block_size_m,
            "BLOCK_SIZE_N": block_size_n,
            "BLOCK_SIZE_K": block_size_k,
        },
        num_stages=num_stages,
        num_warps=num_warps,
        num_ctas=num_ctas,
    )
    for block_size_m in [ALIGN_SIZE_M, ]
    for block_size_n in [64, 128, 256]
    for block_size_k in [64, 128, 256]
    for num_stages in [3, 4]
    for num_warps in [4, 8]
    for num_ctas in [1]
]
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def early_config_prune(configs, named_args, dtsize=None, dtype=None, **kwargs):
    """Prune autotuning configs that cannot (or should not) run on this GPU.

    Args:
        configs: candidate ``triton.Config`` objects.
        named_args: kernel arguments by name; must contain one of the known
            output pointers plus ``G``, ``M_BUCKET``, ``N`` and ``K``.
        dtsize: element size in bytes; inferred from the pointer arg if None.
        dtype: element dtype; inferred from the pointer arg if None.

    Returns:
        The subset of ``configs`` passing all resource/shape checks.

    Raises:
        KeyError: if no recognized pointer parameter is found.
    """
    device = torch.cuda.current_device()
    # Check for all possible pointer parameter names
    if "grad_input_ptr" in named_args:
        ptr_name = "grad_input_ptr"
    elif "c_ptr" in named_args:
        ptr_name = "c_ptr"
    elif "grad_weight_ptr" in named_args:
        ptr_name = "grad_weight_ptr"
    else:
        raise KeyError("No recognized pointer parameter found in kernel arguments")

    if dtsize is None:
        dtsize = named_args[ptr_name].element_size()
    if dtype is None:
        dtype = named_args[ptr_name].dtype

    # Device properties and problem sizes are loop-invariant: query them once
    # instead of once (or twice) per candidate config.
    device_props = driver.active.utils.get_device_properties(device)
    max_shared_memory = device_props["max_shared_mem"]
    num_sm = device_props["multiprocessor_count"]
    G, M, N, K = (
        named_args["G"],
        named_args["M_BUCKET"],
        named_args["N"],
        named_args["K"],
    )
    M_PER_GROUP = M // G
    MIN_M_TILES = 64
    MIN_N_TILES = 64

    pruned_configs = []
    for config in configs:
        kw = config.kwargs
        BLOCK_M, BLOCK_N, BLOCK_K, num_stages = (
            kw["BLOCK_SIZE_M"],
            kw["BLOCK_SIZE_N"],
            kw["BLOCK_SIZE_K"],
            config.num_stages,
        )

        # 1. make sure we have enough smem for the software-pipelined tiles
        required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
        if required_shared_memory > max_shared_memory:
            continue

        # 2. make sure we don't load M tiles that are too big
        if BLOCK_M > MIN_M_TILES and BLOCK_M > (M_PER_GROUP * 2):
            continue
        # 3. make sure we don't load M tiles that are too small
        if BLOCK_M < 128 and BLOCK_M < (M_PER_GROUP // 2):
            continue

        N_TILES = N // BLOCK_N
        # 4. make sure we don't load N tiles that are too big
        if BLOCK_N > MIN_N_TILES and M * N_TILES < num_sm:
            continue
        # 5. make sure we don't load N tiles that are too small
        if BLOCK_N < 128 and M * N_TILES > 2 * num_sm:
            continue
        # 6. make sure K can be evenly divided
        if K % BLOCK_K != 0:
            continue

        pruned_configs.append(config)

    return pruned_configs
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
# ======== End Autotuning utilities ========
|
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/unit_test_forwards.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# pyre-unsafe
|
| 8 |
+
import logging
|
| 9 |
+
import unittest
|
| 10 |
+
from typing import Tuple
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
|
| 15 |
+
from mg_grouped_gemm import grouped_gemm_forward
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class TestMG_GroupedGEMM(unittest.TestCase):
    """Forward-pass correctness checks for the M*G grouped GEMM kernel."""

    def setUp(self) -> None:
        # Fixed seed so failures reproduce deterministically.
        torch.manual_seed(2020)

    def _run_grouped_gemm_test(
        self,
        shape: Tuple[int, int, int, int],
        device: torch.device,
        dtype: torch.dtype = torch.bfloat16,
        atol: float = 1e-5,
        rtol: float = 1.6e-2,
    ) -> None:
        """Compare the kernel output against a per-group matmul reference."""
        G, M, N, K = shape
        # In M*G grouping, input is [M*G, K] and weights are [N*G, K]
        x = torch.randn(M * G, K, dtype=dtype, device=device)
        w = torch.randn(N * G, K, dtype=dtype, device=device)

        # Equal-sized groups keep the reference loop simple.
        m_sizes = torch.full((G,), M, device=device, dtype=torch.int32)

        out = grouped_gemm_forward(x, w, m_sizes)
        self.assertTrue(out.shape == (M * G, N))

        expected = torch.zeros(M * G, N, dtype=dtype, device=device)
        row_start = 0
        for g in range(G):
            row_end = row_start + m_sizes[g]
            w_g = w[N * g : N * (g + 1), :]
            expected[row_start:row_end, :] = x[row_start:row_end, :] @ w_g.T
            row_start = row_end

        # Convert result to match input dtype if needed
        out = out.to(dtype)
        torch.testing.assert_close(out, expected, atol=atol, rtol=rtol)

    def test_MG_grouped_gemm_bf16(self) -> None:
        for G in (1, 4, 16):
            for M in (128, 512, 1024):
                print(f"Testing BF16 M*G GroupGeMM with G={G}, M={M}")
                self._run_grouped_gemm_test(
                    (G, M, 1024, 1024),
                    torch.device("cuda"),
                    dtype=torch.bfloat16,
                    atol=1e-5,
                    rtol=1.6e-2,
                )

    def test_MG_grouped_gemm_deepseek_shapes(self) -> None:
        """Test with shapes from Deepseek model."""
        deepseek_shapes = [
            (4, 2048, 4096, 7168),  # G, M, N, K
            (4, 2048, 7168, 2048),
            (8, 512, 4096, 7168),
            (8, 512, 7168, 2048),
        ]

        device = torch.device("cuda")

        for shape in deepseek_shapes:
            G, M, N, K = shape
            print(f"Testing BF16 M*G Deepseek shape: G={G}, M={M}, N={N}, K={K}")
            self._run_grouped_gemm_test(
                shape, device, dtype=torch.bfloat16, atol=1e-5, rtol=1.6e-2
            )
|
torchtitan/experiments/llama4/__init__.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from torchtitan.components.loss import build_cross_entropy_loss
|
| 8 |
+
from torchtitan.components.lr_scheduler import build_lr_schedulers
|
| 9 |
+
from torchtitan.components.optimizer import build_optimizers
|
| 10 |
+
from torchtitan.datasets.hf_datasets import build_hf_dataloader
|
| 11 |
+
from torchtitan.datasets.tokenizer.tiktoken import build_tiktoken_tokenizer
|
| 12 |
+
from torchtitan.models.llama3 import pipeline_llama
|
| 13 |
+
from torchtitan.protocols.train_spec import register_train_spec, TrainSpec
|
| 14 |
+
|
| 15 |
+
from .infra.parallelize_llama import parallelize_llama
|
| 16 |
+
from .model.args import TransformerModelArgs
|
| 17 |
+
from .model.model import Transformer
|
| 18 |
+
|
| 19 |
+
# Public API of the llama4 experiment package.
__all__ = [
    "TransformerModelArgs",
    "Transformer",
    "llama4_configs",
]


# Named model-size presets, selectable by name from the job config.
llama4_configs = {
    # Tiny config for debugging / smoke tests.
    "debugmodel": TransformerModelArgs(
        dim=256,
        n_layers=8,
        n_heads=16,
        rope_theta=500000,
    ),
    # 17B active parameters, 16 experts; MoE in every layer.
    "17bx16e": TransformerModelArgs(
        dim=5120,
        n_layers=48,
        n_heads=40,
        n_kv_heads=8,
        ffn_dim_multiplier=1.2,
        multiple_of=2048,
        rope_theta=500000,
        num_experts=16,
        interleave_moe_layer_step=1,
    ),
    # 17B active parameters, 128 experts; default MoE interleaving.
    "17bx128e": TransformerModelArgs(
        dim=5120,
        n_layers=48,
        n_heads=40,
        n_kv_heads=8,
        ffn_dim_multiplier=1.2,
        multiple_of=2048,
        rope_theta=500000,
        num_experts=128,
    ),
}


# Register the llama4 TrainSpec so the trainer can build this model family
# (reuses the llama3 pipelining function and standard torchtitan components).
register_train_spec(
    TrainSpec(
        name="llama4",
        cls=Transformer,
        config=llama4_configs,
        parallelize_fn=parallelize_llama,
        pipelining_fn=pipeline_llama,
        build_optimizers_fn=build_optimizers,
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_hf_dataloader,
        build_tokenizer_fn=build_tiktoken_tokenizer,
        build_loss_fn=build_cross_entropy_loss,
    )
)
|
torchtitan/experiments/llama4/infra/expert_parallel.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
from functools import partial
|
| 9 |
+
from typing import Optional, Tuple
|
| 10 |
+
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
from torch.distributed.tensor import (
|
| 13 |
+
DeviceMesh,
|
| 14 |
+
distribute_module,
|
| 15 |
+
distribute_tensor,
|
| 16 |
+
DTensor,
|
| 17 |
+
Partial,
|
| 18 |
+
Replicate,
|
| 19 |
+
Shard,
|
| 20 |
+
)
|
| 21 |
+
from torch.distributed.tensor.parallel import ParallelStyle
|
| 22 |
+
from torch.distributed.tensor.placement_types import Placement
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# implementation of Tensor Parallel on the non-shared experts in MoE
|
| 26 |
+
class TensorParallel(ParallelStyle):
|
| 27 |
+
def __init__(
|
| 28 |
+
self,
|
| 29 |
+
*,
|
| 30 |
+
input_layouts: Optional[Tuple[Optional[Placement]]] = None,
|
| 31 |
+
output_layout: Optional[Placement] = None,
|
| 32 |
+
use_local_output: bool = True,
|
| 33 |
+
):
|
| 34 |
+
super().__init__()
|
| 35 |
+
self.input_layouts = input_layouts or (Replicate(), None)
|
| 36 |
+
self.output_layout = output_layout or Partial()
|
| 37 |
+
self.desired_input_layouts = (Replicate(), None)
|
| 38 |
+
self.use_local_output = use_local_output
|
| 39 |
+
|
| 40 |
+
@staticmethod
|
| 41 |
+
def _prepare_input_fn(
|
| 42 |
+
input_layouts, desired_input_layouts, mod, inputs, device_mesh
|
| 43 |
+
):
|
| 44 |
+
# TODO: figure out dynamo support for instance method and switch this to instance method
|
| 45 |
+
|
| 46 |
+
# annotate module input placements/sharding with input_layouts
|
| 47 |
+
input_tensor, input_layout, desired_input_layout = (
|
| 48 |
+
inputs[0],
|
| 49 |
+
input_layouts[0],
|
| 50 |
+
desired_input_layouts[0],
|
| 51 |
+
)
|
| 52 |
+
if not isinstance(input_tensor, DTensor):
|
| 53 |
+
input_tensor = DTensor.from_local(
|
| 54 |
+
input_tensor, device_mesh, (input_layout,), run_check=False
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
if input_layouts != desired_input_layouts:
|
| 58 |
+
input_tensor = input_tensor.redistribute(
|
| 59 |
+
placements=(desired_input_layout,), async_op=True
|
| 60 |
+
)
|
| 61 |
+
return (input_tensor, *inputs[1:])
|
| 62 |
+
|
| 63 |
+
def _partition_fn(self, name, module, device_mesh):
|
| 64 |
+
module.register_parameter(
|
| 65 |
+
"w1", nn.Parameter(distribute_tensor(module.w1, device_mesh, [Shard(2)]))
|
| 66 |
+
) # Column-wise sharding
|
| 67 |
+
module.register_parameter(
|
| 68 |
+
"w2",
|
| 69 |
+
nn.Parameter(distribute_tensor(module.w2, device_mesh, [Shard(1)])),
|
| 70 |
+
) # Row-wise sharding
|
| 71 |
+
module.register_parameter(
|
| 72 |
+
"w3",
|
| 73 |
+
nn.Parameter(distribute_tensor(module.w3, device_mesh, [Shard(2)])),
|
| 74 |
+
) # Column-wise sharding
|
| 75 |
+
|
| 76 |
+
@staticmethod
|
| 77 |
+
def _prepare_output_fn(output_layout, use_local_output, mod, outputs, device_mesh):
|
| 78 |
+
if outputs.placements != (output_layout,):
|
| 79 |
+
outputs = outputs.redistribute(placements=(output_layout,), async_op=True)
|
| 80 |
+
# back to local tensor
|
| 81 |
+
return outputs.to_local() if use_local_output else outputs
|
| 82 |
+
|
| 83 |
+
def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
|
| 84 |
+
return distribute_module(
|
| 85 |
+
module,
|
| 86 |
+
device_mesh,
|
| 87 |
+
self._partition_fn,
|
| 88 |
+
partial(
|
| 89 |
+
self._prepare_input_fn, self.input_layouts, self.desired_input_layouts
|
| 90 |
+
),
|
| 91 |
+
partial(self._prepare_output_fn, self.output_layout, self.use_local_output),
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
# NOTE: This is to achieve replicate computation on the gate module in the MoE router.
|
| 96 |
+
# It does nothing other than (1) setting the module parameters as DTensors on the given mesh
|
| 97 |
+
# and (2) inserting hooks to module boundary to change torch.Tensor to DTensor and back.
|
| 98 |
+
# TODO: The reason we need this wrapping is to ensure all parameters are on the same 1D/2D mesh,
|
| 99 |
+
# which is assumed by (1) gradient norm clipping, and (2) optimizer fused implementation.
|
| 100 |
+
class NoParallel(ParallelStyle):
|
| 101 |
+
def __init__(
|
| 102 |
+
self,
|
| 103 |
+
*,
|
| 104 |
+
input_layout: Optional[Placement] = None,
|
| 105 |
+
output_layout: Optional[Placement] = None,
|
| 106 |
+
use_local_output: bool = True,
|
| 107 |
+
):
|
| 108 |
+
super().__init__()
|
| 109 |
+
self.input_layout = input_layout or Replicate()
|
| 110 |
+
self.output_layout = output_layout or Replicate()
|
| 111 |
+
self.desired_input_layout = Replicate()
|
| 112 |
+
self.use_local_output = use_local_output
|
| 113 |
+
|
| 114 |
+
@staticmethod
|
| 115 |
+
def _prepare_input_fn(input_layout, desired_input_layout, mod, inputs, device_mesh):
|
| 116 |
+
# annotate module input placements/sharding with input_layouts
|
| 117 |
+
input_tensor = inputs[0]
|
| 118 |
+
if not isinstance(input_tensor, DTensor):
|
| 119 |
+
input_tensor = DTensor.from_local(
|
| 120 |
+
input_tensor, device_mesh, (input_layout,), run_check=False
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
if input_layout != desired_input_layout:
|
| 124 |
+
input_tensor = input_tensor.redistribute(
|
| 125 |
+
placements=(desired_input_layout,), async_op=True
|
| 126 |
+
)
|
| 127 |
+
return (input_tensor, *inputs[1:])
|
| 128 |
+
|
| 129 |
+
@staticmethod
|
| 130 |
+
def _prepare_output_fn(output_layout, use_local_output, mod, outputs, device_mesh):
|
| 131 |
+
if outputs.placements != (output_layout,):
|
| 132 |
+
outputs = outputs.redistribute(placements=(output_layout,), async_op=True)
|
| 133 |
+
# back to local tensor
|
| 134 |
+
return outputs.to_local() if use_local_output else outputs
|
| 135 |
+
|
| 136 |
+
def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
|
| 137 |
+
return distribute_module(
|
| 138 |
+
module,
|
| 139 |
+
device_mesh,
|
| 140 |
+
None,
|
| 141 |
+
partial(
|
| 142 |
+
self._prepare_input_fn, self.input_layout, self.desired_input_layout
|
| 143 |
+
),
|
| 144 |
+
partial(self._prepare_output_fn, self.output_layout, self.use_local_output),
|
| 145 |
+
)
|
torchtitan/experiments/llama4/infra/parallelize_llama.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
from torch.distributed.device_mesh import DeviceMesh
|
| 11 |
+
|
| 12 |
+
from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
|
| 13 |
+
from torchtitan.distributed import ParallelDims
|
| 14 |
+
|
| 15 |
+
from torchtitan.models.llama3.parallelize_llama import (
|
| 16 |
+
apply_ac,
|
| 17 |
+
apply_compile,
|
| 18 |
+
apply_ddp,
|
| 19 |
+
apply_fsdp,
|
| 20 |
+
apply_tp,
|
| 21 |
+
)
|
| 22 |
+
from torchtitan.tools.logging import logger
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def parallelize_llama(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
    """
    Apply tensor parallelism, activation checkpointing, torch.compile, and data
    parallelism to the model.

    Order matters: TP is applied first, then AC wrapping, then per-block
    compile, and finally FSDP/HSDP (or DDP) as the outermost wrapper.

    NOTE: The passed-in model preferably should be on meta device. Otherwise,
    the model must fit on GPU or CPU memory.
    """

    if parallel_dims.tp_enabled:
        if (
            job_config.parallelism.enable_async_tensor_parallel
            and not job_config.training.compile
        ):
            raise RuntimeError("Async TP requires --training.compile")

        enable_float8_linear = "float8" in job_config.model.converters
        float8_is_rowwise = job_config.float8.recipe_name in (
            "rowwise",
            "rowwise_with_gw_hp",
        )

        # For now, float8 all-gather with TP is only supported for tensorwise
        # float8 scaling recipes. For rowwise recipes, we use regular TP and
        # all-gather happens in high precision.
        enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise

        apply_tp(
            model,
            world_mesh["tp"],
            loss_parallel=parallel_dims.loss_parallel_enabled,
            enable_float8_tensorwise_tp=enable_float8_tensorwise_tp,
            enable_async_tp=job_config.parallelism.enable_async_tensor_parallel,
        )

        # MoE modules get their own TP plan on the same mesh.
        apply_moe_tp(model, world_mesh["tp"])

    if job_config.activation_checkpoint.mode != "none":
        if (
            job_config.activation_checkpoint.mode == "selective"
            and job_config.model.use_flex_attn
        ):
            raise ValueError(
                "FlexAttention is not compatible with selective AC yet. "
                "See https://github.com/pytorch/pytorch/issues/147879"
            )
        apply_ac(model, job_config.activation_checkpoint)

    # turn on per-TransformerBlock compile after AC wrapping and before FSDP
    if job_config.training.compile:
        apply_compile(model)

        # NOTE: needed for torch.compile to work with dynamic shapes in token-choice MoE
        torch._dynamo.config.capture_scalar_outputs = True

    if (
        parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled
    ):  # apply FSDP or HSDP, potentially with Context Parallel
        if parallel_dims.dp_replicate_enabled:
            dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
        else:
            dp_mesh_dim_names = ("dp_shard_cp",)

        apply_fsdp(
            model,
            world_mesh[tuple(dp_mesh_dim_names)],
            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
            pp_enabled=parallel_dims.pp_enabled,
            cpu_offload=job_config.training.enable_cpu_offload,
            reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward,
        )

        if parallel_dims.dp_replicate_enabled:
            logger.info("Applied HSDP to the model")
        else:
            logger.info("Applied FSDP to the model")

        if parallel_dims.cp_enabled:
            logger.info("Applied Context Parallel to the model")

        if job_config.training.enable_cpu_offload:
            logger.info("Applied CPU Offloading to the model")
    elif parallel_dims.dp_replicate_enabled:
        # Pure DDP path; only valid for a 1-D (data-parallel-only) mesh.
        if world_mesh.ndim > 1:
            raise RuntimeError("DDP has not supported > 1D parallelism")
        apply_ddp(
            model,
            world_mesh,
            enable_compile=job_config.training.compile,
            enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
        )

    return model
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def apply_moe_tp(
    model: nn.Module,
    tp_mesh: DeviceMesh,
):
    """Apply tensor parallelism to the MoE modules of every transformer block."""
    from torch.distributed.tensor import Partial, Replicate, Shard
    from torch.distributed.tensor.parallel import (
        parallelize_module,
        PrepareModuleInputOutput,
    )

    from .expert_parallel import NoParallel, TensorParallel

    for transformer_block in model.layers.values():
        # Shard the MoE module's input/output along the sequence dim:
        # all-gather on the way in, reduce-scatter on the way out.
        moe_io = PrepareModuleInputOutput(
            input_layouts=(Shard(1),),
            desired_input_layouts=(Replicate(),),
            use_local_input=True,
            output_layouts=(Partial(),),
            desired_output_layouts=(Shard(1),),
        )
        parallelize_module(
            module=transformer_block,
            device_mesh=tp_mesh,
            parallelize_plan={
                "moe": moe_io,
                # replicate computation for the router gate
                "moe.router.gate": NoParallel(),
                # experts take replicated input and produce partial output
                "moe.experts": TensorParallel(),
                "moe.shared_expert": TensorParallel(),
            },
        )
|
torchtitan/experiments/llama4/model/__pycache__/args.cpython-311.pyc
ADDED
|
Binary file (4.43 kB). View file
|
|
|
torchtitan/experiments/llama4/model/args.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from typing import Optional
|
| 10 |
+
|
| 11 |
+
from torch import nn
|
| 12 |
+
from torchtitan.components.tokenizer import Tokenizer
|
| 13 |
+
from torchtitan.config_manager import JobConfig
|
| 14 |
+
|
| 15 |
+
from torchtitan.protocols.train_spec import BaseModelArgs
|
| 16 |
+
from torchtitan.tools.logging import logger
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass
class TransformerModelArgs(BaseModelArgs):
    """Configuration for the Llama 4 Transformer (dense and MoE variants)."""

    dim: int = 4096  # model (embedding) width
    n_layers: int = 32  # number of transformer blocks
    n_heads: int = 32  # attention heads
    n_kv_heads: Optional[int] = None  # KV heads; presumably falls back to n_heads when None -- confirm in model code
    vocab_size: int = -1  # defined later by tokenizer
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    ffn_dim_multiplier: Optional[float] = None  # optional scale on the FFN hidden size
    norm_eps: float = 1e-5  # epsilon for normalization layers
    rope_theta: float = 10000  # RoPE base frequency

    max_seq_len: int = 2048  # overwritten from job config in update_from_config
    # If `True`, then each transformer block init uses its layer ID, and if
    # `False`, each uses the total number of transformer blocks
    depth_init: bool = True
    norm_type: str = "rmsnorm"

    use_flex_attn: bool = False  # use FlexAttention instead of the default attention path
    attn_mask_type: str = "causal"
    eos_id: int = 0

    # MoE args
    moe_enabled: bool = True
    num_experts: int = 8
    use_shared_expert: bool = True
    auto_scale_hidden_dim: bool = True
    # frequency of using MoE layer instead of feedforward layer in a transformer block
    interleave_moe_layer_step: int = 2
    # token-choice
    top_k: int = 1  # number of routed experts active per token

    def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
        """Overwrite fields that depend on the job config and tokenizer."""
        self.norm_type = job_config.model.norm_type
        self.vocab_size = tokenizer.n_words
        self.max_seq_len = job_config.training.seq_len
        self.use_flex_attn = job_config.model.use_flex_attn

    def get_nparams_and_flops(
        self, model: nn.Module, seq_len: int
    ) -> tuple[int, float]:
        """Return (total parameter count, estimated FLOPs per token).

        Parameters are bucketed by name into embedding / router /
        shared-expert / routed-expert / dense groups; only top_k/num_experts
        of the routed-expert parameters count as "active" for the FLOPs
        estimate.
        """
        nparams_embedding = 0
        nparams_moe_router = 0
        nparams_shared_expert = 0
        nparams_experts = 0
        nparams_dense = 0

        for name, p in model.named_parameters():
            if "embedding" in name:
                # Embedding weights are counted in both buckets: they are part
                # of the dense total, and tracked separately to exclude them
                # from the 6N matmul term below.
                nparams_embedding += p.numel()
                nparams_dense += p.numel()
            elif "moe.shared_expert" in name:
                nparams_shared_expert += p.numel()
            elif "moe.router" in name:
                nparams_moe_router += p.numel()
            elif "moe.experts" in name:
                nparams_experts += p.numel()
            else:
                nparams_dense += p.numel()

        nparams_sparse = nparams_moe_router + nparams_shared_expert + nparams_experts
        nparams = nparams_dense + nparams_sparse
        # Only the routed fraction of expert weights is active per token.
        nparams_sparse_active = (
            nparams_moe_router
            + nparams_shared_expert
            + nparams_experts * self.top_k // self.num_experts
        )

        logger.info(
            f"Total parameter count: dense {nparams_dense:,}, "
            f"sparse {nparams_sparse:,}, active {nparams_dense + nparams_sparse_active:,}"
        )

        # layers, heads, head dim, sequence length
        l, h, q, t = (
            self.n_layers,
            self.n_heads,
            self.dim // self.n_heads,
            seq_len,
        )
        # Reasoning behind the factor of 12 for the self-attention part of the formula:
        # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6)
        # 2. the flash attention does 1 more matmul recomputation in the backward
        # but recomputation should not be counted in calculating MFU (+0)
        # 3. each matmul performs 1 multiplication and 1 addition (*2)
        # 4. we follow the convention and do not account for sparsity in causal attention
        num_flops_per_token = (
            6 * (nparams_dense - nparams_embedding + nparams_sparse_active)
            + 12 * l * h * q * t
        )

        return nparams, num_flops_per_token
|
torchtitan/experiments/llama4/model/moe.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from torch import nn
|
| 10 |
+
|
| 11 |
+
from .args import TransformerModelArgs
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class GroupedExperts(nn.Module):
|
| 15 |
+
def __init__(
|
| 16 |
+
self,
|
| 17 |
+
dim: int,
|
| 18 |
+
hidden_dim: int,
|
| 19 |
+
num_experts: int,
|
| 20 |
+
):
|
| 21 |
+
super().__init__()
|
| 22 |
+
self.num_experts = num_experts
|
| 23 |
+
self.w1 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
|
| 24 |
+
self.w2 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim))
|
| 25 |
+
self.w3 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
|
| 26 |
+
|
| 27 |
+
def forward(
|
| 28 |
+
self,
|
| 29 |
+
x: torch.Tensor,
|
| 30 |
+
num_local_tokens_per_expert: torch.Tensor | None = None,
|
| 31 |
+
) -> torch.Tensor:
|
| 32 |
+
if num_local_tokens_per_expert is not None:
|
| 33 |
+
# a tuple of tensors indexed by experts
|
| 34 |
+
# each with shape (tokens_per_expert(varying), dim)
|
| 35 |
+
x = torch.split(
|
| 36 |
+
x,
|
| 37 |
+
split_size_or_sections=num_local_tokens_per_expert.tolist(),
|
| 38 |
+
dim=0,
|
| 39 |
+
)
|
| 40 |
+
out_experts_splits = []
|
| 41 |
+
for expert_idx, x_expert in enumerate(x):
|
| 42 |
+
w1, w2, w3 = (
|
| 43 |
+
self.w1[expert_idx],
|
| 44 |
+
self.w2[expert_idx],
|
| 45 |
+
self.w3[expert_idx],
|
| 46 |
+
)
|
| 47 |
+
h = F.silu(torch.matmul(x_expert, w1))
|
| 48 |
+
h = h * torch.matmul(x_expert, w3)
|
| 49 |
+
h = torch.matmul(h, w2)
|
| 50 |
+
# h shape (tokens_per_expert(varying), dim)
|
| 51 |
+
out_experts_splits.append(h)
|
| 52 |
+
out = torch.cat(out_experts_splits, dim=0)
|
| 53 |
+
|
| 54 |
+
# TODO:optimize with GroupedGEMM
|
| 55 |
+
# https://github.com/pytorch/pytorch/pull/150374
|
| 56 |
+
# _gouped_mm requires shapes to be multiple of 8
|
| 57 |
+
# offsets = torch.cumsum(num_local_tokens_per_expert, dim=0, dtype=torch.int32)
|
| 58 |
+
# h = F.silu(torch._grouped_mm(x, self.w1.transpose(-2, -1), offs=offsets, out_dtype=torch.bfloat16))
|
| 59 |
+
# h = h * torch._grouped_mm(x, self.w3.transpose(-2, -1), offs=offsets, out_dtype=torch.bfloat16)
|
| 60 |
+
# out = torch._grouped_mm(h, self.w2.transpose(-2, -1), offs=offsets, out_dtype=torch.bfloat16)
|
| 61 |
+
else:
|
| 62 |
+
# x shape (num_experts, tokens_per_expert, dim)
|
| 63 |
+
h = F.silu(torch.bmm(x, self.w1))
|
| 64 |
+
h = h * torch.bmm(x, self.w3)
|
| 65 |
+
# out shape (num_experts, tokens_per_expert, dim)
|
| 66 |
+
out = torch.bmm(h, self.w2)
|
| 67 |
+
return out
|
| 68 |
+
|
| 69 |
+
def init_weights(self, init_std: float):
|
| 70 |
+
nn.init.trunc_normal_(self.w1, mean=0.0, std=0.02)
|
| 71 |
+
nn.init.trunc_normal_(self.w2, mean=0.0, std=init_std)
|
| 72 |
+
nn.init.trunc_normal_(self.w3, mean=0.0, std=init_std)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class TokenChoiceTopKRouter(nn.Module):
    """This class implements token-choice routing. In token-choice top-K routing, each token is
    routed to top K experts based on the router scores.

    Args:
        dim (int): Dimension of input tokens.
        num_experts (int): Number of experts in each moe layer.
        top_k (int): Number of experts each token will be routed to in token-choice routing.
        use_sigmoid (bool): Whether to use sigmoid or softmax for router scores. Default is False.
    """

    def __init__(
        self,
        dim: int,
        num_experts: int,
        top_k: int,
        use_sigmoid: bool = False,
    ):
        super().__init__()
        # Gate module to calculate the scores.
        self.gate = nn.Linear(dim, num_experts, bias=False)
        self.num_experts = num_experts
        self.top_k = top_k
        self.use_sigmoid = use_sigmoid

    def forward(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            x (torch.Tensor): Input tensor with shape ``(bs*slen, dim)``.

        Returns:
            top_scores (torch.Tensor):
                Router scores of the selected experts, grouped by expert index,
                with shape ``(bs*slen*top_k,)``.
            token_indices_experts_sorted (torch.Tensor):
                Token indices for the grouped scores with shape ``(bs*slen*top_k,)``.
            num_local_tokens_per_expert (torch.Tensor):
                Number of tokens assigned to each expert with shape ``(num_experts,)``.
        """
        # scores shape (bs*slen, num_experts)
        scores = self.gate(x)

        # By default, sigmoid or softmax is performed in float32 to avoid loss explosion
        if self.use_sigmoid:
            scores = torch.sigmoid(scores.to(torch.float32)).to(x.dtype)
        else:
            scores = F.softmax(scores.to(torch.float32), dim=1).to(x.dtype)

        # top scores shape (bs*slen, top_k)
        top_scores, selected_experts_indices = torch.topk(scores, k=self.top_k, dim=1)
        # NOTE: top-k scores are deliberately not re-normalized here.

        # Count tokens routed to each expert. bincount handles integer index
        # tensors on both CPU and CUDA, whereas torch.histc rejects integral
        # dtypes on CPU ("histc" not implemented for 'Long'); the counts stay
        # integer-valued so they can feed torch.split downstream.
        num_local_tokens_per_expert = torch.bincount(
            selected_experts_indices.view(-1), minlength=self.num_experts
        )
        # Group tokens together by expert indices from 0 to num_experts.
        # token_indices_experts_sorted shape (bs*slen*top_k,)
        token_indices_experts_sorted = torch.argsort(
            selected_experts_indices.view(-1), stable=True
        )
        # Reorder the scores to match the expert-grouped order.
        top_scores = top_scores.view(-1)[token_indices_experts_sorted]
        # Map flat (token, k) positions back to token indices.
        token_indices_experts_sorted = token_indices_experts_sorted // self.top_k

        return top_scores, token_indices_experts_sorted, num_local_tokens_per_expert

    def init_weights(self, init_std: float):
        nn.init.trunc_normal_(self.gate.weight, mean=0.0, std=init_std)
# TODO: implement load balancing auxiliary loss for token-choice routing
class MoE(nn.Module):
    # Mixture-of-Experts layer: a token-choice top-k router dispatches tokens
    # to a bank of grouped experts, optionally combined with one always-on
    # shared expert (implemented as a 1-expert GroupedExperts).
    def __init__(self, model_args: TransformerModelArgs):
        super().__init__()
        dim = model_args.dim
        hidden_dim = 4 * model_args.dim
        ffn_dim_multiplier = model_args.ffn_dim_multiplier
        # Same 2/3 scaling as the dense Llama SwiGLU FFN sizing.
        hidden_dim = int(2 * hidden_dim / 3)
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)

        num_experts = model_args.num_experts

        hidden_dim_denom = 1
        if model_args.auto_scale_hidden_dim:
            # Shrink each expert by the number of experts active per token
            # (top_k routed + the shared expert, if enabled).
            hidden_dim_denom = model_args.top_k + int(model_args.use_shared_expert)

        if model_args.auto_scale_hidden_dim:
            hidden_dim = int(hidden_dim / hidden_dim_denom)
            # Round hidden_dim up to the next multiple of `multiple_of`.
            hidden_dim += -hidden_dim % model_args.multiple_of

        self.experts = GroupedExperts(
            dim=dim, hidden_dim=hidden_dim, num_experts=num_experts
        )
        self.router = TokenChoiceTopKRouter(
            dim=dim, num_experts=num_experts, top_k=model_args.top_k
        )
        # Shared expert runs on every token, in addition to the routed experts.
        self.shared_expert = (
            GroupedExperts(dim=dim, hidden_dim=hidden_dim, num_experts=1)
            if model_args.use_shared_expert
            else None
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Input tensor with shape ``(bs, slen, dim)``.

        Returns:
            out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``.
        """
        bs, slen, dim = x.shape
        # top_scores and selected_indices shape (bs*slen*top_k,)
        # num_local_tokens_per_expert shape (num_experts,)
        (
            top_scores,
            token_indices,
            num_local_tokens_per_expert,
        ) = self.router(x.reshape(bs * slen, dim))

        # Broadcast each token index across the feature dim so gather/scatter
        # below move whole rows. shape (bs*slen*top_k, dim)
        token_indices = token_indices.reshape(-1, 1).expand(-1, dim)

        # Gather each routed token's hidden state; a token appears once per
        # expert it was routed to. shape (bs*slen*top_k, dim)
        routed_input = torch.gather(
            x.view(-1, dim),
            dim=0,
            index=token_indices,
        )
        # Scale by router scores *before* the expert computation.
        routed_input = routed_input * top_scores.reshape(-1, 1)

        # shape (bs*slen*top_k, dim)
        routed_output = self.experts(routed_input, num_local_tokens_per_expert)

        # shared expert
        if self.shared_expert is not None:
            out = self.shared_expert(x.reshape(1, bs * slen, dim)).reshape(
                bs * slen, dim
            )
        else:
            out = torch.zeros_like(x.reshape(bs * slen, dim))

        # Scatter-add routed expert outputs back to their token positions,
        # summing the top_k contributions on top of the shared output.
        out = out.scatter_add(dim=0, index=token_indices, src=routed_output)
        out = out.reshape(bs, slen, dim)
        return out

    def init_weights(self, init_std: float):
        self.experts.init_weights(init_std)
        self.router.init_weights(init_std)
        if self.shared_expert is not None:
            self.shared_expert.init_weights(init_std)
|
torchtitan/experiments/llama4/scripts/REAME.md
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## How to convert a Llama 4 checkpoint for use in torchtitan
|
| 2 |
+
|
| 3 |
+
To continue training from an existing model checkpoint, the checkpoint must be in the DCP format expected by the checkpoint manager.
|
| 4 |
+
This folder contains the scripts for converting officially released Llama 4 checkpoints into the expected DCP format, from original Meta format, or from HuggingFace format, using GPUs.
|
| 5 |
+
|
| 6 |
+
#### Example usage
|
| 7 |
+
|
| 8 |
+
From Meta format:
|
| 9 |
+
```bash
|
| 10 |
+
CONFIG_FILE=../train_configs/llama4_16.toml ./convert_meta_to_dcp.sh --checkpoint.enable_checkpoint --checkpoint.convert_path=[checkpoint_folder] --checkpoint.convert_load_every_n_ranks=8
|
| 11 |
+
```
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
From HuggingFace format:
|
| 15 |
+
```bash
|
| 16 |
+
CONFIG_FILE=../train_configs/llama4_16.toml ./convert_hf_to_dcp_with_gpus.sh --checkpoint.enable_checkpoint --checkpoint.convert_path=[checkpoint_folder] --checkpoint.convert_load_every_n_ranks=8
|
| 17 |
+
```
|
torchtitan/experiments/llama4/scripts/convert_meta_to_dcp_with_gpus.sh
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Launches the Meta-format -> DCP checkpoint conversion under torchrun.
# Any positional arguments are forwarded verbatim to the converter script
# as torchtitan config overrides.

set -ex

# use envs as local overrides for convenience
# e.g.
# LOG_RANK=0,1 NGPU=4 ./convert_meta_to_dcp_with_gpus.sh
NGPU=${NGPU:-"8"}                                 # processes (GPUs) per node
LOG_RANK=${LOG_RANK:-0,1,2,3,4,5,6,7}             # ranks whose logs are surfaced
CONFIG_FILE=${CONFIG_FILE:-"../train_configs/llama4_17bx16e.toml"}

# Collect any extra CLI args as config overrides.
overrides=""
if [ $# -ne 0 ]; then
    overrides="$*"
fi

# NOTE(review): expandable_segments presumably reduces CUDA allocator
# fragmentation while loading large checkpoints — confirm.
PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \
torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
convert_meta_to_dcp_with_gpus_meta.py --job.config_file ${CONFIG_FILE} $overrides
|
torchtitan/experiments/llama4/train_configs/debug_model.toml
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[job]
|
| 2 |
+
dump_folder = "./outputs"
|
| 3 |
+
description = "Llama 4 debug training"
|
| 4 |
+
print_args = false
|
| 5 |
+
use_for_integration_test = true
|
| 6 |
+
|
| 7 |
+
[profiling]
|
| 8 |
+
enable_profiling = false
|
| 9 |
+
save_traces_folder = "profile_trace"
|
| 10 |
+
profile_freq = 10
|
| 11 |
+
enable_memory_snapshot = false
|
| 12 |
+
save_memory_snapshot_folder = "memory_snapshot"
|
| 13 |
+
|
| 14 |
+
[metrics]
|
| 15 |
+
log_freq = 1
|
| 16 |
+
disable_color_printing = false
|
| 17 |
+
enable_tensorboard = false
|
| 18 |
+
save_tb_folder = "tb"
|
| 19 |
+
enable_wandb = false
|
| 20 |
+
|
| 21 |
+
[model]
|
| 22 |
+
name = "llama4"
|
| 23 |
+
flavor = "debugmodel"
|
| 24 |
+
norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm
|
| 25 |
+
# test tokenizer.model, for debug purpose only
|
| 26 |
+
tokenizer_path = "./tests/assets/test_tiktoken.model"
|
| 27 |
+
# converters = "float8"
|
| 28 |
+
use_flex_attn = false
|
| 29 |
+
attn_mask_type = "causal" # causal / block_causal
|
| 30 |
+
|
| 31 |
+
[optimizer]
|
| 32 |
+
name = "AdamW"
|
| 33 |
+
lr = 4e-3
|
| 34 |
+
eps = 1e-15
|
| 35 |
+
|
| 36 |
+
[lr_scheduler]
|
| 37 |
+
warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps
|
| 38 |
+
decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps
|
| 39 |
+
decay_type = "linear"
|
| 40 |
+
lr_min = 0.1
|
| 41 |
+
|
| 42 |
+
[training]
|
| 43 |
+
batch_size = 8
|
| 44 |
+
seq_len = 2048
|
| 45 |
+
max_norm = 1.0 # grad norm clipping
|
| 46 |
+
steps = 10
|
| 47 |
+
compile = false
|
| 48 |
+
dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M)
|
| 49 |
+
|
| 50 |
+
[parallelism]
|
| 51 |
+
data_parallel_replicate_degree = 1
|
| 52 |
+
data_parallel_shard_degree = -1
|
| 53 |
+
fsdp_reshard_after_forward = "default" # default / never / always
|
| 54 |
+
tensor_parallel_degree = 1
|
| 55 |
+
enable_async_tensor_parallel = false
|
| 56 |
+
pipeline_parallel_degree = 1
|
| 57 |
+
context_parallel_degree = 1
|
| 58 |
+
|
| 59 |
+
[checkpoint]
|
| 60 |
+
enable_checkpoint = false
|
| 61 |
+
folder = "checkpoint"
|
| 62 |
+
interval = 10
|
| 63 |
+
model_weights_only = false
|
| 64 |
+
export_dtype = "float32"
|
| 65 |
+
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
|
| 66 |
+
|
| 67 |
+
[activation_checkpoint]
|
| 68 |
+
mode = 'none' # ['none', 'selective', 'full']
|
| 69 |
+
selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy
|
| 70 |
+
|
| 71 |
+
[float8]
|
| 72 |
+
enable_fsdp_float8_all_gather = false
|
| 73 |
+
precompute_float8_dynamic_scale_for_fsdp = false
|
| 74 |
+
filter_fqns = "output,router.gate"
|
torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# TODO: this toml config is still under development
|
| 2 |
+
|
| 3 |
+
[job]
|
| 4 |
+
dump_folder = "./outputs"
|
| 5 |
+
description = "Llama 4 Maverick 17Bx128E training"
|
| 6 |
+
|
| 7 |
+
[profiling]
|
| 8 |
+
enable_profiling = false
|
| 9 |
+
save_traces_folder = "profile_trace"
|
| 10 |
+
profile_freq = 100
|
| 11 |
+
|
| 12 |
+
[metrics]
|
| 13 |
+
log_freq = 10
|
| 14 |
+
enable_tensorboard = false
|
| 15 |
+
save_tb_folder = "tb"
|
| 16 |
+
|
| 17 |
+
[model]
|
| 18 |
+
name = "llama4"
|
| 19 |
+
flavor = "17bx128e"
|
| 20 |
+
norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm
|
| 21 |
+
tokenizer_path = "./assets/tokenizer/tokenizer.model"
|
| 22 |
+
# converters = "float8"
|
| 23 |
+
|
| 24 |
+
[optimizer]
|
| 25 |
+
name = "AdamW"
|
| 26 |
+
lr = 4e-3
|
| 27 |
+
eps = 1e-15
|
| 28 |
+
|
| 29 |
+
[lr_scheduler]
|
| 30 |
+
warmup_steps = 600
|
| 31 |
+
lr_min = 0.1
|
| 32 |
+
|
| 33 |
+
[training]
|
| 34 |
+
batch_size = 1
|
| 35 |
+
seq_len = 8192
|
| 36 |
+
max_norm = 1.0 # grad norm clipping
|
| 37 |
+
steps = 3000
|
| 38 |
+
compile = false
|
| 39 |
+
dataset = "c4"
|
| 40 |
+
|
| 41 |
+
[parallelism]
|
| 42 |
+
data_parallel_replicate_degree = 1
|
| 43 |
+
data_parallel_shard_degree = -1
|
| 44 |
+
tensor_parallel_degree = 8
|
| 45 |
+
enable_async_tensor_parallel = false
|
| 46 |
+
pipeline_parallel_degree = 4
|
| 47 |
+
# pipeline_parallel_schedule = "interleaved1f1b"
|
| 48 |
+
# pipeline_parallel_microbatches = 2
|
| 49 |
+
context_parallel_degree = 1
|
| 50 |
+
|
| 51 |
+
[checkpoint]
|
| 52 |
+
enable_checkpoint = false
|
| 53 |
+
folder = "checkpoint"
|
| 54 |
+
interval = 500
|
| 55 |
+
model_weights_only = false
|
| 56 |
+
export_dtype = "float32"
|
| 57 |
+
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
|
| 58 |
+
|
| 59 |
+
[activation_checkpoint]
|
| 60 |
+
mode = 'full' # ['none', 'selective', 'full']
|
| 61 |
+
|
| 62 |
+
[float8]
|
| 63 |
+
enable_fsdp_float8_all_gather = false
|
| 64 |
+
precompute_float8_dynamic_scale_for_fsdp = false
|
| 65 |
+
filter_fqns = "output,router.gate"
|
torchtitan/experiments/multimodal/requirements.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
torchvision
|
torchtitan/experiments/multimodal/tests/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
torchtitan/experiments/multimodal/tests/test_utils.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import math
|
| 8 |
+
|
| 9 |
+
from typing import Optional, Union
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from torch import nn
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def fixed_init_tensor(
|
| 16 |
+
shape: torch.Size,
|
| 17 |
+
min_val: Union[float, int] = 0.0,
|
| 18 |
+
max_val: Union[float, int] = 1.0,
|
| 19 |
+
nonlinear: bool = False,
|
| 20 |
+
dtype: torch.dtype = torch.float,
|
| 21 |
+
):
|
| 22 |
+
"""
|
| 23 |
+
Utility for generating deterministic tensors of a given shape. In general stuff
|
| 24 |
+
like torch.ones, torch.eye, etc can result in trivial outputs. This utility
|
| 25 |
+
generates a range tensor [min_val, max_val) of a specified dtype, applies
|
| 26 |
+
a sine function if nonlinear=True, then reshapes to the appropriate shape.
|
| 27 |
+
"""
|
| 28 |
+
n_elements = math.prod(shape)
|
| 29 |
+
step_size = (max_val - min_val) / n_elements
|
| 30 |
+
x = torch.arange(min_val, max_val, step_size, dtype=dtype)
|
| 31 |
+
x = x.reshape(shape)
|
| 32 |
+
if nonlinear:
|
| 33 |
+
return torch.sin(x)
|
| 34 |
+
return x
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@torch.no_grad
|
| 38 |
+
def fixed_init_model(
|
| 39 |
+
model: nn.Module,
|
| 40 |
+
min_val: Union[float, int] = 0.0,
|
| 41 |
+
max_val: Union[float, int] = 1.0,
|
| 42 |
+
nonlinear: bool = False,
|
| 43 |
+
dtype: Optional[torch.dtype] = None,
|
| 44 |
+
):
|
| 45 |
+
"""
|
| 46 |
+
This utility initializes all parameters of a model deterministically using the
|
| 47 |
+
function fixed_init_tensor above. See that docstring for details of each parameter.
|
| 48 |
+
"""
|
| 49 |
+
for _, param in model.named_parameters():
|
| 50 |
+
param.copy_(
|
| 51 |
+
fixed_init_tensor(
|
| 52 |
+
param.shape,
|
| 53 |
+
min_val=min_val,
|
| 54 |
+
max_val=max_val,
|
| 55 |
+
nonlinear=nonlinear,
|
| 56 |
+
dtype=param.dtype if dtype is None else dtype,
|
| 57 |
+
)
|
| 58 |
+
)
|
torchtitan/experiments/multimodal/tokenizer/tiktoken.py
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 8 |
+
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from typing import (
|
| 13 |
+
AbstractSet,
|
| 14 |
+
Any,
|
| 15 |
+
cast,
|
| 16 |
+
Collection,
|
| 17 |
+
Dict,
|
| 18 |
+
Iterator,
|
| 19 |
+
List,
|
| 20 |
+
Literal,
|
| 21 |
+
Mapping,
|
| 22 |
+
Optional,
|
| 23 |
+
Sequence,
|
| 24 |
+
Union,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
import tiktoken
|
| 28 |
+
import torch
|
| 29 |
+
from tiktoken.load import load_tiktoken_bpe
|
| 30 |
+
|
| 31 |
+
from torchtitan.components.tokenizer import Tokenizer
|
| 32 |
+
from torchtitan.config_manager import JobConfig
|
| 33 |
+
from torchtitan.tools.logging import logger
|
| 34 |
+
|
| 35 |
+
IMAGE_TOKEN_ID = 128256  # id assigned to the "<|image|>" special token
IGNORE_INDEX = -100  # label value masked out of the loss in encode_multimodal


class TikTokenizer(Tokenizer):
    """
    Tokenizing and encoding/decoding text using the Tiktoken tokenizer.

    Args:
        model_path (str): The path to the Tiktoken model file.
    """

    special_tokens: Dict[str, int]

    num_reserved_special_tokens = 256

    pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"  # noqa: E501, B950

    def __init__(self, model_path: str):
        super().__init__(model_path)
        assert os.path.isfile(model_path), model_path

        mergeable_ranks = load_tiktoken_bpe(model_path)
        num_base_tokens = len(mergeable_ranks)
        # Named special tokens first, then the remaining reserved slots.
        special_tokens = [
            "<|begin_of_text|>",
            "<|end_of_text|>",
            "<|reserved_special_token_0|>",
            "<|reserved_special_token_1|>",
            "<|reserved_special_token_2|>",
            "<|reserved_special_token_3|>",
            "<|start_header_id|>",
            "<|end_header_id|>",
            "<|reserved_special_token_4|>",
            "<|eot_id|>",  # end of turn
        ] + [
            f"<|reserved_special_token_{i}|>"
            for i in range(5, self.num_reserved_special_tokens - 5)
        ]
        # Special-token ids are appended after the base BPE vocabulary.
        self.special_tokens = {
            token: num_base_tokens + i for i, token in enumerate(special_tokens)
        }
        self.special_tokens["<|image|>"] = IMAGE_TOKEN_ID
        self.model = tiktoken.Encoding(
            name=Path(model_path).name,
            pat_str=self.pat_str,
            mergeable_ranks=mergeable_ranks,
            special_tokens=self.special_tokens,
        )

        self._n_words: int = self.model.n_vocab
        # BOS / EOS token IDs
        self.bos_id: int = self.special_tokens["<|begin_of_text|>"]
        self.eos_id: int = self.special_tokens["<|end_of_text|>"]
        self.pad_id: int = -1
        self.image_id = IMAGE_TOKEN_ID
        self.stop_tokens = {
            self.special_tokens["<|end_of_text|>"],
            self.special_tokens["<|eot_id|>"],
        }
        logger.info(
            f"TikTokenizer built: #words {self.n_words}, BOS ID {self.bos_id}, EOS ID {self.eos_id}, IMAGE ID {self.image_id}"
        )

    def encode(
        self,
        s: str,
        *,
        bos: bool,
        eos: bool,
        allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None,
        disallowed_special: Optional[Union[Literal["all"], Collection[str]]] = None,
    ) -> List[int]:
        """
        Encodes a string into a list of token IDs.

        Args:
            s (str): The input string to be encoded.
            bos (bool): Whether to prepend the beginning-of-sequence token.
            eos (bool): Whether to append the end-of-sequence token.
            allowed_special ("all"|set[str]): allowed special tokens in string
            disallowed_special ("all"|set[str]): special tokens that raise an error when in string

        Returns:
            list[int]: A list of token IDs.

        By default, setting disallowed_special=() encodes a string by ignoring
        special tokens. Specifically:
        - Setting `disallowed_special` to () will cause all text corresponding
          to special tokens to be encoded as natural text (instead of raising
          an error).
        - Setting `allowed_special` to "all" will treat all text corresponding
          to special tokens to be encoded as special tokens.
        """
        assert type(s) is str
        allowed_special = allowed_special or set()
        disallowed_special = disallowed_special or ()

        # The tiktoken tokenizer can handle <=400k chars without
        # pyo3_runtime.PanicException.
        TIKTOKEN_MAX_ENCODE_CHARS = 400_000

        # https://github.com/openai/tiktoken/issues/195
        # Here we iterate over subsequences and split if we exceed the limit
        # of max consecutive non-whitespace or whitespace characters.
        MAX_NO_WHITESPACES_CHARS = 25_000

        # Lazily generate safe-sized substrings, then encode each in turn.
        substrs = (
            substr
            for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
            for substr in self._split_whitespaces_or_nonwhitespaces(
                s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
            )
        )
        t: List[int] = []
        for substr in substrs:
            t.extend(
                self.model.encode(
                    substr,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
            )
        if bos:
            t.insert(0, self.bos_id)
        if eos:
            t.append(self.eos_id)
        return t

    def decode(self, t: Sequence[int]) -> str:
        """
        Decodes a list of token IDs into a string.

        Args:
            t (List[int]): The list of token IDs to be decoded.

        Returns:
            str: The decoded string.
        """
        # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
        return self.model.decode(cast(List[int], t))

    @staticmethod
    def _split_whitespaces_or_nonwhitespaces(
        s: str, max_consecutive_slice_len: int
    ) -> Iterator[str]:
        """
        Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
        consecutive whitespaces or consecutive non-whitespaces.
        """
        current_slice_len = 0
        current_slice_is_space = s[0].isspace() if len(s) > 0 else False
        slice_start = 0

        for i in range(len(s)):
            is_now_space = s[i].isspace()

            # XOR: the character class flipped (space <-> non-space), so the
            # consecutive run restarts at length 1.
            if current_slice_is_space ^ is_now_space:
                current_slice_len = 1
                current_slice_is_space = is_now_space
            else:
                current_slice_len += 1
                if current_slice_len > max_consecutive_slice_len:
                    yield s[slice_start:i]
                    slice_start = i
                    current_slice_len = 1
        # Emit whatever remains after the last split point.
        yield s[slice_start:]

    def encode_multimodal(self, sample: Mapping[str, Any]) -> List[int]:
        """
        Tokenizes a `str` of text and creates `labels` masking BOS, EOS and `image_id` tokens.
        """
        # TODO(tj.solergibert) Should we keep `input_ids` OR `tokens` across this class, VisionCrossAttentionMask & the collator?
        # For me it makes more sense to split `tokens` between `input_ids` & `labels` as in train.py BUT the `MultimodalDecoder`
        # & everything else expects `tokens`
        text = sample["text"]
        tokens = self.encode(
            text, bos=True, eos=True, allowed_special=set(["<|image|>"])
        )
        # Standard next-token shift: inputs drop the last token, labels the first.
        input_ids = torch.LongTensor(tokens[:-1])
        labels = torch.LongTensor(tokens[1:])
        # Mask out BOS/EOS/image tokens from the loss.
        labels = torch.where(
            torch.isin(
                labels, torch.LongTensor([self.bos_id, self.eos_id, self.image_id])
            ),
            IGNORE_INDEX,
            labels,
        )

        assert len(input_ids) == len(labels)  # TODO(tj.solergibert) Delete

        sample.update({"tokens": input_ids, "labels": labels})

        return sample
+
def build_tiktoken_tokenizer(job_config: JobConfig) -> TikTokenizer:
|
| 232 |
+
return TikTokenizer(job_config.model.tokenizer_path)
|
torchtitan/experiments/simple_fsdp/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (1.33 kB). View file
|
|
|
torchtitan/experiments/simple_fsdp/tests/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
torchtitan/models/llama3/__pycache__/pipeline_llama.cpython-311.pyc
ADDED
|
Binary file (5.96 kB). View file
|
|
|
torchtitan/models/llama3/train_configs/debug_model.toml
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# torchtitan Config.toml
|
| 2 |
+
|
| 3 |
+
[job]
|
| 4 |
+
dump_folder = "./outputs"
|
| 5 |
+
description = "Llama 3 debug training"
|
| 6 |
+
print_args = false
|
| 7 |
+
use_for_integration_test = true
|
| 8 |
+
|
| 9 |
+
[profiling]
|
| 10 |
+
enable_profiling = false
|
| 11 |
+
save_traces_folder = "profile_trace"
|
| 12 |
+
profile_freq = 10
|
| 13 |
+
enable_memory_snapshot = false
|
| 14 |
+
save_memory_snapshot_folder = "memory_snapshot"
|
| 15 |
+
|
| 16 |
+
[metrics]
|
| 17 |
+
log_freq = 1
|
| 18 |
+
disable_color_printing = false
|
| 19 |
+
enable_tensorboard = false
|
| 20 |
+
save_tb_folder = "tb"
|
| 21 |
+
enable_wandb = false
|
| 22 |
+
|
| 23 |
+
[model]
|
| 24 |
+
name = "llama3"
|
| 25 |
+
flavor = "debugmodel"
|
| 26 |
+
norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm
|
| 27 |
+
# test tokenizer.model, for debug purpose only
|
| 28 |
+
tokenizer_path = "./tests/assets/test_tiktoken.model"
|
| 29 |
+
# converters = "float8"
|
| 30 |
+
|
| 31 |
+
[optimizer]
|
| 32 |
+
name = "AdamW"
|
| 33 |
+
lr = 8e-4
|
| 34 |
+
eps = 1e-8
|
| 35 |
+
|
| 36 |
+
[lr_scheduler]
|
| 37 |
+
warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps
|
| 38 |
+
decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps
|
| 39 |
+
decay_type = "linear"
|
| 40 |
+
lr_min = 0.0
|
| 41 |
+
|
| 42 |
+
[training]
|
| 43 |
+
batch_size = 8
|
| 44 |
+
seq_len = 2048
|
| 45 |
+
max_norm = 1.0 # grad norm clipping
|
| 46 |
+
steps = 10
|
| 47 |
+
compile = false
|
| 48 |
+
dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M)
|
| 49 |
+
|
| 50 |
+
[parallelism]
|
| 51 |
+
data_parallel_replicate_degree = 1
|
| 52 |
+
data_parallel_shard_degree = -1
|
| 53 |
+
fsdp_reshard_after_forward = "default" # default / never / always
|
| 54 |
+
tensor_parallel_degree = 1
|
| 55 |
+
enable_async_tensor_parallel = false
|
| 56 |
+
pipeline_parallel_degree = 1
|
| 57 |
+
context_parallel_degree = 1
|
| 58 |
+
|
| 59 |
+
[checkpoint]
|
| 60 |
+
enable_checkpoint = false
|
| 61 |
+
folder = "checkpoint"
|
| 62 |
+
interval = 10
|
| 63 |
+
model_weights_only = false
|
| 64 |
+
export_dtype = "float32"
|
| 65 |
+
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
|
| 66 |
+
|
| 67 |
+
[activation_checkpoint]
|
| 68 |
+
mode = 'selective' # ['none', 'selective', 'full']
|
| 69 |
+
selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy
|
| 70 |
+
|
| 71 |
+
[float8]
|
| 72 |
+
enable_fsdp_float8_all_gather = false
|
| 73 |
+
precompute_float8_dynamic_scale_for_fsdp = false
|
| 74 |
+
filter_fqns = "output"
|