Add files using upload-large-folder tool
Browse files- profile_trace/iteration_10752/rank5_trace.json +0 -0
- profile_trace/iteration_11776/rank3_trace.json +0 -0
- profile_trace/iteration_11776/rank4_trace.json +0 -0
- profile_trace/iteration_11776/rank7_trace.json +0 -0
- profile_trace/iteration_12288/rank5_trace.json +0 -0
- profile_trace/iteration_13824/rank2_trace.json +0 -0
- profile_trace/iteration_13824/rank4_trace.json +0 -0
- profile_trace/iteration_13824/rank6_trace.json +0 -0
- profile_trace/iteration_14848/rank2_trace.json +0 -0
- profile_trace/iteration_14848/rank4_trace.json +0 -0
- profile_trace/iteration_14848/rank7_trace.json +0 -0
- profile_trace/iteration_21504/rank0_trace.json +0 -0
- profile_trace/iteration_21504/rank3_trace.json +0 -0
- profile_trace/iteration_28160/rank2_trace.json +0 -0
- profile_trace/iteration_28160/rank4_trace.json +0 -0
- profile_trace/iteration_28160/rank6_trace.json +0 -0
- profile_trace/iteration_31744/rank3_trace.json +0 -0
- profile_trace/iteration_33792/rank1_trace.json +0 -0
- profile_trace/iteration_33792/rank2_trace.json +0 -0
- profile_trace/iteration_33792/rank4_trace.json +0 -0
- profile_trace/iteration_33792/rank5_trace.json +0 -0
- profile_trace/iteration_33792/rank6_trace.json +0 -0
- profile_trace/iteration_33792/rank7_trace.json +0 -0
- profile_trace/iteration_512/rank0_trace.json +0 -0
- profile_trace/iteration_512/rank1_trace.json +0 -0
- profile_trace/iteration_512/rank3_trace.json +0 -0
- profile_trace/iteration_8192/rank0_trace.json +0 -0
- profile_trace/iteration_8192/rank1_trace.json +0 -0
- profile_trace/iteration_8192/rank5_trace.json +0 -0
- torchtitan/components/__pycache__/checkpoint.cpython-312.pyc +0 -0
- torchtitan/components/__pycache__/dataloader.cpython-312.pyc +0 -0
- torchtitan/components/__pycache__/float8.cpython-312.pyc +0 -0
- torchtitan/components/loss.py +29 -0
- torchtitan/experiments/deepseek_v3/indices.py +195 -0
- torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_barrier.py +159 -0
- torchtitan/experiments/flux/__pycache__/__init__.cpython-312.pyc +0 -0
- torchtitan/experiments/flux/__pycache__/utils.cpython-312.pyc +0 -0
- torchtitan/experiments/flux/dataset/__pycache__/tokenizer.cpython-312.pyc +0 -0
- torchtitan/experiments/flux/dataset/tokenizer.py +64 -0
- torchtitan/experiments/flux/model/__pycache__/hf_embedder.cpython-312.pyc +0 -0
- torchtitan/experiments/flux/model/__pycache__/layers.cpython-312.pyc +0 -0
- torchtitan/experiments/flux/train_configs/debug_model.toml +68 -0
- torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/unit_test_backwards.py +174 -0
- torchtitan/experiments/llama4/__pycache__/__init__.cpython-312.pyc +0 -0
- torchtitan/experiments/llama4/model/__pycache__/moe.cpython-312.pyc +0 -0
- torchtitan/experiments/llama4/model/args.py +109 -0
- torchtitan/experiments/multimodal/__init__.py +37 -0
- torchtitan/experiments/multimodal/tests/test_multimodal_model.py +128 -0
- torchtitan/experiments/simple_fsdp/__pycache__/__init__.cpython-312.pyc +0 -0
- torchtitan/models/__pycache__/norms.cpython-312.pyc +0 -0
profile_trace/iteration_10752/rank5_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_11776/rank3_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_11776/rank4_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_11776/rank7_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_12288/rank5_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_13824/rank2_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_13824/rank4_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_13824/rank6_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_14848/rank2_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_14848/rank4_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_14848/rank7_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_21504/rank0_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_21504/rank3_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_28160/rank2_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_28160/rank4_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_28160/rank6_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_31744/rank3_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_33792/rank1_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_33792/rank2_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_33792/rank4_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_33792/rank5_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_33792/rank6_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_33792/rank7_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_512/rank0_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_512/rank1_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_512/rank3_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_8192/rank0_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_8192/rank1_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_8192/rank5_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
torchtitan/components/__pycache__/checkpoint.cpython-312.pyc
ADDED
|
Binary file (33.1 kB). View file
|
|
|
torchtitan/components/__pycache__/dataloader.cpython-312.pyc
ADDED
|
Binary file (3.79 kB). View file
|
|
|
torchtitan/components/__pycache__/float8.cpython-312.pyc
ADDED
|
Binary file (6.2 kB). View file
|
|
|
torchtitan/components/loss.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from typing import Callable, TypeAlias
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from torchtitan.config_manager import JobConfig
|
| 12 |
+
from torchtitan.tools.logging import logger
|
| 13 |
+
|
| 14 |
+
LossFunction: TypeAlias = Callable[..., torch.Tensor]
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def cross_entropy_loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
|
| 18 |
+
"""Common cross-entropy loss function for Transformer models training."""
|
| 19 |
+
return torch.nn.functional.cross_entropy(
|
| 20 |
+
pred.flatten(0, 1).float(), labels.flatten(0, 1)
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def build_cross_entropy_loss(job_config: JobConfig):
|
| 25 |
+
loss_fn = cross_entropy_loss
|
| 26 |
+
if job_config.training.compile:
|
| 27 |
+
logger.info("Compiling the loss function with torch.compile")
|
| 28 |
+
loss_fn = torch.compile(loss_fn)
|
| 29 |
+
return loss_fn
|
torchtitan/experiments/deepseek_v3/indices.py
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import triton
|
| 9 |
+
import triton.language as tl
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
__all__ = ["generate_permute_indices"]
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@triton.jit
|
| 16 |
+
def fill_indices_kernel(
|
| 17 |
+
tokens_per_expert_group_ptr, # *Pointer* to first input vector.
|
| 18 |
+
start_index_values_ptr, # *Pointer* to second input vector.
|
| 19 |
+
write_offsets_ptr, # *Pointer* to third input vector.
|
| 20 |
+
output_ptr, # *Pointer* to output vector.
|
| 21 |
+
experts_per_rank, # Number of experts per rank.
|
| 22 |
+
num_ranks, # Number of expert ranks.
|
| 23 |
+
):
|
| 24 |
+
# There are multiple 'programs' processing different data. We identify which program
|
| 25 |
+
# we are here:
|
| 26 |
+
pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0.
|
| 27 |
+
# The total number of programs in the launch grid.
|
| 28 |
+
num_programs = tl.num_programs(axis=0)
|
| 29 |
+
# We map the programs (blocks) to the experts.
|
| 30 |
+
for expert_id in tl.range(pid, experts_per_rank, step=num_programs):
|
| 31 |
+
# Read this expert's write offset.
|
| 32 |
+
write_offset = tl.load(write_offsets_ptr + expert_id)
|
| 33 |
+
# Loop over the ranks.
|
| 34 |
+
for r in tl.range(num_ranks):
|
| 35 |
+
# Slot in the tokens_per_expert_group array.
|
| 36 |
+
i = r * experts_per_rank + expert_id
|
| 37 |
+
start_index = tl.load(start_index_values_ptr + i)
|
| 38 |
+
length = tl.load(tokens_per_expert_group_ptr + i)
|
| 39 |
+
# Write the indices.
|
| 40 |
+
for l in tl.range(length):
|
| 41 |
+
val = start_index + l
|
| 42 |
+
tl.store(output_ptr + write_offset + l, val)
|
| 43 |
+
write_offset += length
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def fill_indices(
|
| 47 |
+
tokens_per_expert_group: torch.Tensor,
|
| 48 |
+
start_index_values: torch.Tensor,
|
| 49 |
+
write_offsets: torch.Tensor,
|
| 50 |
+
experts_per_rank: int,
|
| 51 |
+
num_ranks: int,
|
| 52 |
+
max_len: int,
|
| 53 |
+
):
|
| 54 |
+
# We need to preallocate the output.
|
| 55 |
+
permuted_indices = torch.full(
|
| 56 |
+
(max_len,), -1, dtype=torch.int32, device=tokens_per_expert_group.device
|
| 57 |
+
)
|
| 58 |
+
# Analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int].
|
| 59 |
+
# In this case, we use a 1D grid where the size is the number of blocks (TODO: bump this value).
|
| 60 |
+
grid = lambda meta: (1,)
|
| 61 |
+
# Each torch.tensor object is implicitly converted into a pointer to its first element.
|
| 62 |
+
fill_indices_kernel[grid](
|
| 63 |
+
tokens_per_expert_group,
|
| 64 |
+
start_index_values,
|
| 65 |
+
write_offsets,
|
| 66 |
+
permuted_indices,
|
| 67 |
+
experts_per_rank,
|
| 68 |
+
num_ranks,
|
| 69 |
+
)
|
| 70 |
+
return permuted_indices
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def fill_indices_cpu(
|
| 74 |
+
tokens_per_expert_group: torch.Tensor,
|
| 75 |
+
start_index_values: torch.Tensor,
|
| 76 |
+
write_offsets: torch.Tensor,
|
| 77 |
+
experts_per_rank: int,
|
| 78 |
+
num_ranks: int,
|
| 79 |
+
max_len: int,
|
| 80 |
+
):
|
| 81 |
+
# We need to preallocate the output.
|
| 82 |
+
permuted_indices = torch.full((max_len,), -1, dtype=torch.int32)
|
| 83 |
+
# Fill the permuted indices
|
| 84 |
+
# For each local expert
|
| 85 |
+
for e in range(experts_per_rank):
|
| 86 |
+
write_start = write_offsets[e]
|
| 87 |
+
# For each remote rank
|
| 88 |
+
for r in range(num_ranks):
|
| 89 |
+
i = r * experts_per_rank + e
|
| 90 |
+
start_index = start_index_values[i]
|
| 91 |
+
length = tokens_per_expert_group[i]
|
| 92 |
+
# Fill in the indices
|
| 93 |
+
permuted_indices[write_start : write_start + length] = torch.arange(
|
| 94 |
+
start_index, start_index + length
|
| 95 |
+
)
|
| 96 |
+
write_start += length
|
| 97 |
+
return permuted_indices
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def generate_permute_indices(
|
| 101 |
+
tokens_per_expert_group: torch.Tensor,
|
| 102 |
+
experts_per_rank: int,
|
| 103 |
+
num_ranks: int,
|
| 104 |
+
max_len: int,
|
| 105 |
+
alignment: int,
|
| 106 |
+
use_cpu: bool = False,
|
| 107 |
+
):
|
| 108 |
+
# Prepare permutation indices and the number of tokens for each expert. The
|
| 109 |
+
# permutation indices are the indices of the tokens for each expert. The
|
| 110 |
+
# number of tokens for each expert is the sum of the number of tokens for
|
| 111 |
+
# such experts from all ranks. This number is aligned to the provided
|
| 112 |
+
# alignment requirement (usually comes from group gemm).
|
| 113 |
+
|
| 114 |
+
# Args:
|
| 115 |
+
# tokens_per_expert_group: number of tokens for each expert from all ranks.
|
| 116 |
+
# experts_per_rank: number of experts per rank.
|
| 117 |
+
# num_ranks: number of ranks.
|
| 118 |
+
# max_len: maximum length of the output index vector. If greater than
|
| 119 |
+
# total number of tokens, the remaining indices are set to -1.
|
| 120 |
+
# alignment: alignment for each returned element in `m_sizes`.
|
| 121 |
+
# use_cpu: whether to use cpu or gpu.
|
| 122 |
+
# Returns:
|
| 123 |
+
# permuted_indices: permutation indices.
|
| 124 |
+
# m_sizes: number of tokens for each expert.
|
| 125 |
+
|
| 126 |
+
# `tokens_per_expert_group` is of shape (num_ranks * experts_per_rank,), for example:
|
| 127 |
+
# From: | rank 0 | rank 1 |
|
| 128 |
+
# To: | E0 | E1 | E2 | E3 | E0 | E1 | E2 | E3 |
|
| 129 |
+
# | 4 | 2 | 1 | 3 | 1 | 2 | 3 | 4 |
|
| 130 |
+
|
| 131 |
+
# Prefix sum to get the start index value of each expert
|
| 132 |
+
start_index_values = (
|
| 133 |
+
torch.cumsum(tokens_per_expert_group, 0) - tokens_per_expert_group
|
| 134 |
+
)
|
| 135 |
+
# Chunk sizes for each expert
|
| 136 |
+
chunk_size_per_expert = tokens_per_expert_group.view(num_ranks, -1).sum(0)
|
| 137 |
+
# Align the chunk sizes to the given alignment
|
| 138 |
+
m_sizes = ((chunk_size_per_expert + alignment - 1) // alignment * alignment).to(
|
| 139 |
+
torch.int32
|
| 140 |
+
)
|
| 141 |
+
# Perform another prefix sum to get the write offset of each expert in `permuted_indices`
|
| 142 |
+
write_offsets = torch.cumsum(m_sizes, 0) - m_sizes
|
| 143 |
+
# Select the method to fill the permuted indices
|
| 144 |
+
fill_fn = fill_indices_cpu if use_cpu else fill_indices
|
| 145 |
+
# Fill the permuted indices
|
| 146 |
+
permuted_indices = fill_fn(
|
| 147 |
+
tokens_per_expert_group,
|
| 148 |
+
start_index_values,
|
| 149 |
+
write_offsets,
|
| 150 |
+
experts_per_rank,
|
| 151 |
+
num_ranks,
|
| 152 |
+
max_len,
|
| 153 |
+
)
|
| 154 |
+
return permuted_indices, m_sizes
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
# Below is for testing only
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def test():
|
| 161 |
+
device = torch.device("cuda", 0)
|
| 162 |
+
experts_per_rank = 4
|
| 163 |
+
num_ranks = 4
|
| 164 |
+
tokens_per_expert_group = torch.full(
|
| 165 |
+
(num_ranks * experts_per_rank,), 4, dtype=torch.int32, device=device
|
| 166 |
+
)
|
| 167 |
+
max_len = 128
|
| 168 |
+
alignment = 32
|
| 169 |
+
# Use the GPU kernel
|
| 170 |
+
permuted_indices_gpu, m_sizes = generate_permute_indices(
|
| 171 |
+
tokens_per_expert_group, experts_per_rank, num_ranks, max_len, alignment
|
| 172 |
+
)
|
| 173 |
+
# Use the CPU method
|
| 174 |
+
permuted_indices_cpu, _ = generate_permute_indices(
|
| 175 |
+
tokens_per_expert_group,
|
| 176 |
+
experts_per_rank,
|
| 177 |
+
num_ranks,
|
| 178 |
+
max_len,
|
| 179 |
+
alignment,
|
| 180 |
+
use_cpu=True,
|
| 181 |
+
)
|
| 182 |
+
# Check that the results are the same
|
| 183 |
+
assert torch.equal(permuted_indices_gpu.cpu(), permuted_indices_cpu)
|
| 184 |
+
assert torch.equal(
|
| 185 |
+
torch.remainder(m_sizes, alignment),
|
| 186 |
+
torch.zeros(experts_per_rank, device=device),
|
| 187 |
+
)
|
| 188 |
+
# Print the results
|
| 189 |
+
print(permuted_indices_gpu)
|
| 190 |
+
print(m_sizes)
|
| 191 |
+
print("Success")
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
if __name__ == "__main__":
|
| 195 |
+
test()
|
torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_barrier.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import triton
|
| 8 |
+
import triton.language as tl
|
| 9 |
+
|
| 10 |
+
from .triton_utils import get_flat_bid, get_flat_tid
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@triton.jit
|
| 14 |
+
def send_signal(addrs, sem: tl.constexpr):
|
| 15 |
+
if sem == "relaxed":
|
| 16 |
+
tl.inline_asm_elementwise(
|
| 17 |
+
"""
|
| 18 |
+
{
|
| 19 |
+
.reg .u32 %tmp32_<1>;
|
| 20 |
+
.reg .pred %p<1>;
|
| 21 |
+
|
| 22 |
+
send_signal:
|
| 23 |
+
atom.global.relaxed.sys.cas.b32 %tmp32_0, [$1], 0, 1;
|
| 24 |
+
setp.eq.u32 %p0, %tmp32_0, 0;
|
| 25 |
+
@!%p0 bra send_signal;
|
| 26 |
+
}
|
| 27 |
+
""",
|
| 28 |
+
"=r, l",
|
| 29 |
+
[addrs],
|
| 30 |
+
dtype=tl.int32,
|
| 31 |
+
is_pure=False,
|
| 32 |
+
pack=1,
|
| 33 |
+
)
|
| 34 |
+
elif sem == "acq_rel":
|
| 35 |
+
tl.inline_asm_elementwise(
|
| 36 |
+
"""
|
| 37 |
+
{
|
| 38 |
+
.reg .u32 %tmp32_<1>;
|
| 39 |
+
.reg .pred %p<1>;
|
| 40 |
+
|
| 41 |
+
send_signal:
|
| 42 |
+
atom.global.release.sys.cas.b32 %tmp32_0, [$1], 0, 1;
|
| 43 |
+
setp.eq.u32 %p0, %tmp32_0, 0;
|
| 44 |
+
@!%p0 bra send_signal;
|
| 45 |
+
}
|
| 46 |
+
""",
|
| 47 |
+
"=r, l",
|
| 48 |
+
[addrs],
|
| 49 |
+
dtype=tl.int32,
|
| 50 |
+
is_pure=False,
|
| 51 |
+
pack=1,
|
| 52 |
+
)
|
| 53 |
+
else:
|
| 54 |
+
raise RuntimeError(f"Unrecognized sem: {sem}")
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
@triton.jit
|
| 58 |
+
def wait_signal(addrs, sem: tl.constexpr):
|
| 59 |
+
if sem == "relaxed":
|
| 60 |
+
tl.inline_asm_elementwise(
|
| 61 |
+
"""
|
| 62 |
+
{
|
| 63 |
+
.reg .u32 %tmp32_<1>;
|
| 64 |
+
.reg .pred %p<1>;
|
| 65 |
+
|
| 66 |
+
wait_signal:
|
| 67 |
+
atom.global.sys.relaxed.cas.b32 %tmp32_0, [$1], 1, 0;
|
| 68 |
+
setp.eq.u32 %p0, %tmp32_0, 1;
|
| 69 |
+
@!%p0 bra wait_signal;
|
| 70 |
+
}
|
| 71 |
+
""",
|
| 72 |
+
"=r, l",
|
| 73 |
+
[addrs],
|
| 74 |
+
dtype=tl.int32,
|
| 75 |
+
is_pure=False,
|
| 76 |
+
pack=1,
|
| 77 |
+
)
|
| 78 |
+
elif sem == "acq_rel":
|
| 79 |
+
tl.inline_asm_elementwise(
|
| 80 |
+
"""
|
| 81 |
+
{
|
| 82 |
+
.reg .u32 %tmp32_<1>;
|
| 83 |
+
.reg .pred %p<1>;
|
| 84 |
+
|
| 85 |
+
wait_signal:
|
| 86 |
+
atom.global.sys.acquire.cas.b32 %tmp32_0, [$1], 1, 0;
|
| 87 |
+
setp.eq.u32 %p0, %tmp32_0, 1;
|
| 88 |
+
@!%p0 bra wait_signal;
|
| 89 |
+
}
|
| 90 |
+
""",
|
| 91 |
+
"=r, l",
|
| 92 |
+
[addrs],
|
| 93 |
+
dtype=tl.int32,
|
| 94 |
+
is_pure=False,
|
| 95 |
+
pack=1,
|
| 96 |
+
)
|
| 97 |
+
else:
|
| 98 |
+
raise RuntimeError(f"Unrecognized sem: {sem}")
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
@triton.jit
|
| 102 |
+
def blockwise_barrier(
|
| 103 |
+
signal_pad_ptrs,
|
| 104 |
+
block_id,
|
| 105 |
+
rank: tl.constexpr,
|
| 106 |
+
world_size: tl.constexpr,
|
| 107 |
+
sem: tl.constexpr,
|
| 108 |
+
):
|
| 109 |
+
"""
|
| 110 |
+
Synchronizes blocks with matching block_id across participating devices.
|
| 111 |
+
|
| 112 |
+
Note: the function itself is not a system level barrier/fence. It is a
|
| 113 |
+
building block for expressing different synchronization patterns.
|
| 114 |
+
|
| 115 |
+
Pattern 0: Ensures that all writes to symm_mem buffers from previous
|
| 116 |
+
kernels across all devices are visible to the current kernel:
|
| 117 |
+
|
| 118 |
+
blockwise_barrier(..., sem="relaxed")
|
| 119 |
+
sync_threads()
|
| 120 |
+
|
| 121 |
+
Pattern 1: Ensures that all writes to symm_mem buffers from the current
|
| 122 |
+
block are visible to all remote blocks with matching blockIdx:
|
| 123 |
+
|
| 124 |
+
sync_threads()
|
| 125 |
+
blockwise_barrier(..., sem="acq_rel")
|
| 126 |
+
sync_threads()
|
| 127 |
+
|
| 128 |
+
Pattern 2: Ensures that symm_mem buffers read by the current kernel are safe
|
| 129 |
+
for writing by subsequent kernels across all devices.
|
| 130 |
+
|
| 131 |
+
sync_threads()
|
| 132 |
+
blockwise_barrier(..., sem="relaxed")
|
| 133 |
+
|
| 134 |
+
CUDA graph friendliness:
|
| 135 |
+
|
| 136 |
+
This barrier operates through atomic operations on a zero-filled signal
|
| 137 |
+
pad, which resets to a zero-filled state after each successful
|
| 138 |
+
synchronization. This design eliminates the need for incrementing a
|
| 139 |
+
flag from host.
|
| 140 |
+
"""
|
| 141 |
+
if block_id is None:
|
| 142 |
+
block_id = get_flat_bid()
|
| 143 |
+
flat_tid = get_flat_tid()
|
| 144 |
+
|
| 145 |
+
remote_ranks = tl.arange(0, world_size)
|
| 146 |
+
signal_pad_ptrs = signal_pad_ptrs.to(tl.pointer_type(tl.uint64))
|
| 147 |
+
remote_signal_pad_addrs = tl.load(signal_pad_ptrs + remote_ranks).to(
|
| 148 |
+
tl.pointer_type(tl.uint32)
|
| 149 |
+
)
|
| 150 |
+
send_addrs = remote_signal_pad_addrs + block_id * world_size + rank
|
| 151 |
+
|
| 152 |
+
local_signal_pad_addr = tl.load(signal_pad_ptrs + rank).to(
|
| 153 |
+
tl.pointer_type(tl.uint32)
|
| 154 |
+
)
|
| 155 |
+
wait_addrs = local_signal_pad_addr + block_id * world_size + remote_ranks
|
| 156 |
+
|
| 157 |
+
if flat_tid < world_size:
|
| 158 |
+
send_signal(send_addrs, sem)
|
| 159 |
+
wait_signal(wait_addrs, sem)
|
torchtitan/experiments/flux/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (2.08 kB). View file
|
|
|
torchtitan/experiments/flux/__pycache__/utils.cpython-312.pyc
ADDED
|
Binary file (7.31 kB). View file
|
|
|
torchtitan/experiments/flux/dataset/__pycache__/tokenizer.cpython-312.pyc
ADDED
|
Binary file (2.21 kB). View file
|
|
|
torchtitan/experiments/flux/dataset/tokenizer.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 8 |
+
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
from typing import List
|
| 12 |
+
|
| 13 |
+
from torchtitan.components.tokenizer import Tokenizer
|
| 14 |
+
from transformers import CLIPTokenizer, T5Tokenizer
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class FluxTokenizer(Tokenizer):
|
| 18 |
+
"""
|
| 19 |
+
Tokenizing and encoding/decoding text using the T5 or Clip tokenizer.
|
| 20 |
+
|
| 21 |
+
Args:
|
| 22 |
+
model_path (str): Path to the tokenzier from hugging face.
|
| 23 |
+
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
def __init__(self, model_path: str = "t5-small", max_length: int = 77):
|
| 27 |
+
super().__init__()
|
| 28 |
+
self._n_words = 8 # TODO(jianiw): check
|
| 29 |
+
self._max_length = max_length
|
| 30 |
+
|
| 31 |
+
self.is_clip = model_path.startswith("openai")
|
| 32 |
+
|
| 33 |
+
if self.is_clip:
|
| 34 |
+
self._tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(
|
| 35 |
+
model_path, max_length=max_length
|
| 36 |
+
)
|
| 37 |
+
else:
|
| 38 |
+
self._tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(
|
| 39 |
+
model_path, max_length=max_length
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
def encode(
|
| 43 |
+
self,
|
| 44 |
+
s: str,
|
| 45 |
+
) -> List[int]:
|
| 46 |
+
"""
|
| 47 |
+
Encode the prompt text into tokens.
|
| 48 |
+
"""
|
| 49 |
+
tokens = self._tokenizer(
|
| 50 |
+
s,
|
| 51 |
+
truncation=True,
|
| 52 |
+
max_length=self._max_length,
|
| 53 |
+
return_length=False,
|
| 54 |
+
return_overflowing_tokens=False,
|
| 55 |
+
padding="max_length",
|
| 56 |
+
return_tensors="pt", # return pytorch tensors, default return List[int]
|
| 57 |
+
)["input_ids"]
|
| 58 |
+
return tokens
|
| 59 |
+
|
| 60 |
+
def decode(self, t: List[int]) -> str:
|
| 61 |
+
"""
|
| 62 |
+
Decode function. This function will not be called.
|
| 63 |
+
"""
|
| 64 |
+
return self._tokenizer.decode(t)
|
torchtitan/experiments/flux/model/__pycache__/hf_embedder.cpython-312.pyc
ADDED
|
Binary file (1.95 kB). View file
|
|
|
torchtitan/experiments/flux/model/__pycache__/layers.cpython-312.pyc
ADDED
|
Binary file (17.7 kB). View file
|
|
|
torchtitan/experiments/flux/train_configs/debug_model.toml
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
[job]
|
| 3 |
+
dump_folder = "./outputs"
|
| 4 |
+
description = "Flux debug model"
|
| 5 |
+
print_args = false
|
| 6 |
+
use_for_integration_test = true
|
| 7 |
+
|
| 8 |
+
[profiling]
|
| 9 |
+
enable_profiling = false
|
| 10 |
+
save_traces_folder = "profile_trace"
|
| 11 |
+
profile_freq = 10
|
| 12 |
+
enable_memory_snapshot = false
|
| 13 |
+
save_memory_snapshot_folder = "memory_snapshot"
|
| 14 |
+
|
| 15 |
+
[metrics]
|
| 16 |
+
log_freq = 1
|
| 17 |
+
disable_color_printing = false
|
| 18 |
+
enable_tensorboard = false
|
| 19 |
+
save_tb_folder = "tb"
|
| 20 |
+
enable_wandb = false
|
| 21 |
+
|
| 22 |
+
[model]
|
| 23 |
+
name = "flux"
|
| 24 |
+
flavor = "flux-debug"
|
| 25 |
+
norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm
|
| 26 |
+
# test tokenizer.model, for debug purpose only
|
| 27 |
+
# tokenizer_path = "./tests/assets/test_tiktoken.model"
|
| 28 |
+
# converters = "float8"
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
[optimizer]
|
| 32 |
+
name = "AdamW"
|
| 33 |
+
lr = 8e-4
|
| 34 |
+
eps = 1e-8
|
| 35 |
+
|
| 36 |
+
[lr_scheduler]
|
| 37 |
+
warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps
|
| 38 |
+
decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps
|
| 39 |
+
decay_type = "linear"
|
| 40 |
+
lr_min = 0.0
|
| 41 |
+
|
| 42 |
+
[training]
|
| 43 |
+
batch_size = 32
|
| 44 |
+
seq_len = 512
|
| 45 |
+
max_norm = 1.0 # grad norm clipping
|
| 46 |
+
steps = 10
|
| 47 |
+
compile = false
|
| 48 |
+
dataset = "cc12m"
|
| 49 |
+
guidance = 3.5
|
| 50 |
+
seed = 0
|
| 51 |
+
|
| 52 |
+
[encoder]
|
| 53 |
+
t5_encoder="google/t5-v1_1-small"
|
| 54 |
+
clip_encoder="openai/clip-vit-large-patch14"
|
| 55 |
+
max_t5_encoding_len=512
|
| 56 |
+
auto_encoder_path="torchtitan/experiments/flux/assets/autoencoder/ae.safetensors" # Autoencoder to use for image
|
| 57 |
+
|
| 58 |
+
[parallelism]
|
| 59 |
+
data_parallel_replicate_degree = 1
|
| 60 |
+
data_parallel_shard_degree = 1
|
| 61 |
+
fsdp_reshard_after_forward = "default" # default / never / always
|
| 62 |
+
tensor_parallel_degree = 1
|
| 63 |
+
enable_async_tensor_parallel = false
|
| 64 |
+
pipeline_parallel_degree = 1
|
| 65 |
+
context_parallel_degree = 1
|
| 66 |
+
|
| 67 |
+
[experimental]
|
| 68 |
+
custom_args_module = "torchtitan.experiments.flux.flux_argparser"
|
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/unit_test_backwards.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# pyre-unsafe
|
| 8 |
+
import logging
|
| 9 |
+
import unittest
|
| 10 |
+
from typing import Tuple
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
|
| 15 |
+
from mg_grouped_gemm import (
|
| 16 |
+
grouped_gemm_backward,
|
| 17 |
+
grouped_gemm_dw_tma,
|
| 18 |
+
grouped_gemm_dx_tma,
|
| 19 |
+
grouped_gemm_forward,
|
| 20 |
+
mg_grouped_gemm,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
from reference_utils import (
|
| 24 |
+
analyze_tensor_differences,
|
| 25 |
+
compute_reference_backward,
|
| 26 |
+
compute_reference_forward,
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class TestMG_GroupedGEMM_Backward(unittest.TestCase):
    """Tests for the backward pass of the M*G grouped GEMM kernels.

    Each test runs the custom grouped-GEMM forward/backward and compares the
    outputs and gradients against the pure-PyTorch reference implementations
    from ``reference_utils``.
    """

    def setUp(self) -> None:
        # Fixed seed so the randomly generated inputs are reproducible.
        torch.manual_seed(2020)

    def _run_grouped_gemm_backward_test(
        self,
        shape: Tuple[int, int, int, int],
        device: torch.device,
        dtype: torch.dtype = torch.bfloat16,
        atol: float = 1e-5,
        rtol: float = 1.6e-2,
    ) -> None:
        """Run one forward + backward correctness check.

        Args:
            shape: ``(G, M, N, K)`` — group count, rows per group, output
                features, reduction dimension.
            device: device to run on (these tests use CUDA).
            dtype: dtype of the input and weight tensors.
            atol: absolute tolerance.
            rtol: relative tolerance.

        NOTE(review): ``atol``/``rtol`` are accepted but never forwarded to
        ``analyze_tensor_differences`` — confirm whether that helper applies
        its own tolerances, otherwise these parameters are silently ignored.
        """
        G, M, N, K = shape
        # In M*G grouping, the input is [M*G, K] and the weights are [N, K].
        a = torch.randn(M * G, K, dtype=dtype, device=device, requires_grad=True)
        b = torch.randn(N, K, dtype=dtype, device=device, requires_grad=True)

        # Create equal-sized groups for simplicity.
        m_sizes = torch.full((G,), M, device=device, dtype=torch.int32)

        # Forward pass with our implementation; output is [M*G, N].
        result = grouped_gemm_forward(a, b, m_sizes)
        self.assertTrue(result.shape == (M * G, N))

        # The forward result must match the reference before gradients are
        # worth comparing.
        expected_result = compute_reference_forward(a, b, m_sizes)
        forward_close = analyze_tensor_differences(
            result, expected_result, "Forward output"
        )
        self.assertTrue(forward_close)

        # Backward: compare the custom gradients against the reference.
        grad_output = torch.randn_like(result)
        grad_a, grad_b = grouped_gemm_backward(grad_output, a, b, m_sizes)
        expected_grad_a, expected_grad_b = compute_reference_backward(
            a, b, m_sizes, grad_output
        )

        grad_a_close = analyze_tensor_differences(grad_a, expected_grad_a, "grad_x")
        grad_b_close = analyze_tensor_differences(grad_b, expected_grad_b, "grad_w")

        self.assertTrue(grad_a_close)
        self.assertTrue(grad_b_close)

    def test_MG_grouped_gemm_backward_bf16(self) -> None:
        """Backward pass in bf16 across a sweep of group counts and sizes."""
        for G in (1, 8, 16):
            for M in (512, 1024):
                print(f"Testing BF16 M*G GroupGeMM Backward with G={G}, M={M}")
                self._run_grouped_gemm_backward_test(
                    (G, M, 1024, 1024),
                    torch.device("cuda"),
                    # BUG FIX: this test is named and logged as BF16 but
                    # previously ran with torch.float16; use bfloat16 so the
                    # dtype matches the test's stated intent.
                    dtype=torch.bfloat16,
                    atol=1e-2,
                    rtol=1e-2,
                )

    def test_MG_grouped_gemm_backward_deepseek_shapes(self) -> None:
        """Test backward pass with shapes from Deepseek model."""
        deepseek_shapes = [
            (4, 2048, 4096, 7168),  # G, M, N, K
            (4, 2048, 7168, 2048),
            (8, 512, 4096, 7168),
            (8, 512, 7168, 2048),
        ]

        device = torch.device("cuda")

        for shape in deepseek_shapes:
            G, M, N, K = shape
            print(
                f"Testing BF16 M*G Deepseek Backward shape: G={G}, M={M}, N={N}, K={K}"
            )
            # BUG FIX: the log line above says BF16, but the run previously
            # used torch.float16; use bfloat16 to match.
            self._run_grouped_gemm_backward_test(
                shape, device, dtype=torch.bfloat16, atol=1e-2, rtol=1e-2
            )

    def test_MG_dx(self) -> None:
        """Test specifically the dx (gradient w.r.t. input) computation."""
        G, M, N, K = 4, 512, 1024, 2048
        device = torch.device("cuda")
        dtype = torch.bfloat16

        # Set up inputs: [M*G, K] activations, [N, K] weights.
        a = torch.randn(M * G, K, dtype=dtype, device=device, requires_grad=True)
        b = torch.randn(N, K, dtype=dtype, device=device, requires_grad=True)

        # Create equal-sized groups.
        m_sizes = torch.full((G,), M, device=device, dtype=torch.int32)

        # Forward pass, then a random upstream gradient.
        result = grouped_gemm_forward(a, b, m_sizes)
        grad_output = torch.randn_like(result)

        # Only the input gradient is checked here.
        grad_a, _ = grouped_gemm_backward(grad_output, a, b, m_sizes)
        expected_grad_a, _ = compute_reference_backward(a, b, m_sizes, grad_output)

        dx_close = analyze_tensor_differences(grad_a, expected_grad_a, "grad_a (dx)")
        self.assertTrue(dx_close)

    def test_MG_dw(self) -> None:
        """Test specifically the dw (gradient w.r.t. weights) computation."""
        G, M, N, K = 4, 512, 1024, 2048
        device = torch.device("cuda")
        dtype = torch.bfloat16

        # Set up inputs: [M*G, K] activations, [N, K] weights.
        a = torch.randn(M * G, K, dtype=dtype, device=device, requires_grad=True)
        b = torch.randn(N, K, dtype=dtype, device=device, requires_grad=True)

        # Create equal-sized groups.
        m_sizes = torch.full((G,), M, device=device, dtype=torch.int32)

        # Forward pass, then a random upstream gradient.
        result = grouped_gemm_forward(a, b, m_sizes)
        grad_output = torch.randn_like(result)

        # Only the weight gradient is checked here.
        _, grad_b = grouped_gemm_backward(grad_output, a, b, m_sizes)
        _, expected_grad_b = compute_reference_backward(a, b, m_sizes, grad_output)

        dw_close = analyze_tensor_differences(grad_b, expected_grad_b, "grad_b (dw)")
        self.assertTrue(dw_close)
|
torchtitan/experiments/llama4/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (1.66 kB). View file
|
|
|
torchtitan/experiments/llama4/model/__pycache__/moe.cpython-312.pyc
ADDED
|
Binary file (10.5 kB). View file
|
|
|
torchtitan/experiments/llama4/model/args.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from typing import Optional
|
| 10 |
+
|
| 11 |
+
from torch import nn
|
| 12 |
+
from torchtitan.components.tokenizer import Tokenizer
|
| 13 |
+
from torchtitan.config_manager import JobConfig
|
| 14 |
+
|
| 15 |
+
from torchtitan.protocols.train_spec import BaseModelArgs
|
| 16 |
+
from torchtitan.tools.logging import logger
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass
class TransformerModelArgs(BaseModelArgs):
    """Model hyperparameters for the llama4 experiment, including MoE knobs."""

    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: Optional[int] = None
    vocab_size: int = -1  # defined later by tokenizer
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-5
    rope_theta: float = 10000

    max_seq_len: int = 2048
    # If `True`, then each transformer block init uses its layer ID, and if
    # `False`, each uses the total number of transformer blocks
    depth_init: bool = True
    norm_type: str = "rmsnorm"

    use_flex_attn: bool = False
    attn_mask_type: str = "causal"
    eos_id: int = 0

    # MoE args
    moe_enabled: bool = True
    num_experts: int = 8
    use_shared_expert: bool = True
    auto_scale_hidden_dim: bool = True
    # frequency of using MoE layer instead of feedforward layer in a transformer block
    interleave_moe_layer_step: int = 2
    # token-choice
    top_k: int = 1

    def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
        """Overwrite fields that depend on the runtime job config / tokenizer."""
        model_cfg = job_config.model
        self.norm_type = model_cfg.norm_type
        self.use_flex_attn = model_cfg.use_flex_attn
        self.vocab_size = tokenizer.n_words
        self.max_seq_len = job_config.training.seq_len

    def get_nparams_and_flops(
        self, model: nn.Module, seq_len: int
    ) -> tuple[int, float]:
        """Count parameters and estimate FLOPs per token (for MFU logging).

        Returns a tuple of (total parameter count, FLOPs per token).
        """
        nparams_embedding = 0
        nparams_moe_router = 0
        nparams_shared_expert = 0
        nparams_experts = 0
        nparams_dense = 0

        # Bucket every parameter by its role; embeddings are counted as
        # dense but also tracked separately so the FLOP estimate can
        # exclude them.
        for name, p in model.named_parameters():
            numel = p.numel()
            if "embedding" in name:
                nparams_embedding += numel
                nparams_dense += numel
            elif "moe.shared_expert" in name:
                nparams_shared_expert += numel
            elif "moe.router" in name:
                nparams_moe_router += numel
            elif "moe.experts" in name:
                nparams_experts += numel
            else:
                nparams_dense += numel

        nparams_sparse = nparams_moe_router + nparams_shared_expert + nparams_experts
        nparams = nparams_dense + nparams_sparse
        # Per token, only top_k of num_experts expert sets are active.
        nparams_sparse_active = (
            nparams_moe_router
            + nparams_shared_expert
            + nparams_experts * self.top_k // self.num_experts
        )

        logger.info(
            f"Total parameter count: dense {nparams_dense:,}, "
            f"sparse {nparams_sparse:,}, active {nparams_dense + nparams_sparse_active:,}"
        )

        layers = self.n_layers
        heads = self.n_heads
        head_dim = self.dim // self.n_heads
        tokens = seq_len
        # Reasoning behind the factor of 12 for the self-attention part:
        # 1. each self-attention has 2 matmuls in the forward and 4 in the
        #    backward (6)
        # 2. flash attention does 1 extra matmul recomputation in the
        #    backward, but recomputation is not counted toward MFU (+0)
        # 3. each matmul performs 1 multiply and 1 add (*2)
        # 4. by convention, sparsity of causal attention is not accounted for
        num_flops_per_token = (
            6 * (nparams_dense - nparams_embedding + nparams_sparse_active)
            + 12 * layers * heads * head_dim * tokens
        )

        return nparams, num_flops_per_token
|
torchtitan/experiments/multimodal/__init__.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# BUG FIX: this module is a package __init__, so the sibling dataloader
# module must be imported relatively (matching `from .model import ...`
# below). A bare `from mm_dataset import ...` only resolves if the package
# directory itself happens to be on sys.path.
from .mm_dataset import build_mm_dataloader

from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers
from torchtitan.datasets.tokenizer.tiktoken import build_tiktoken_tokenizer
from torchtitan.models.llama3 import parallelize_llama, pipeline_llama
from torchtitan.protocols.train_spec import register_train_spec, TrainSpec

from .model import ModelArgs, MultimodalDecoder, VisionEncoder

__all__ = ["VisionEncoder", "ModelArgs", "MultimodalDecoder"]

llama4_mm_configs = {
    # TODO: add configs for llama4 multimodal
}

# Register the llama4 multimodal train spec; parallelize/pipeline functions
# are reused from the llama3 model.
register_train_spec(
    TrainSpec(
        name="llama4_multimodal",
        cls=MultimodalDecoder,
        config=llama4_mm_configs,
        parallelize_fn=parallelize_llama,
        pipelining_fn=pipeline_llama,
        build_optimizers_fn=build_optimizers,
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_mm_dataloader,
        build_tokenizer_fn=build_tiktoken_tokenizer,
        build_loss_fn=build_cross_entropy_loss,
    )
)
|
torchtitan/experiments/multimodal/tests/test_multimodal_model.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import pytest
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
from torchtitan.experiments.llama_multimodal import (
|
| 11 |
+
ModelArgs,
|
| 12 |
+
MultimodalDecoder,
|
| 13 |
+
VisionEncoder,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
from .test_utils import fixed_init_model, fixed_init_tensor
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@pytest.fixture
def encoder_config():
    """Deliberately tiny VisionEncoder config so the test runs fast."""
    settings = dict(
        encoder_embed_dim=32,
        encoder_num_layers=2,
        encoder_num_heads=4,
        tile_size=49,
        patch_size=9,
        max_num_tiles=4,
        in_channels=3,
        return_intermediates=[0, 1],
        num_layers_projection=2,
        decoder_embed_dim=128,
    )
    return ModelArgs(**settings)
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@pytest.fixture
def decoder_config():
    """Deliberately small MultimodalDecoder config so the test runs fast."""
    settings = dict(
        decoder_embed_dim=512,
        vocab_size=10000,
        fusion_interval=2,
        num_special_tokens=3,
        decoder_num_layers=6,
        decoder_num_heads=8,
        decoder_num_kv_heads=4,
        max_seq_len=512,
        rope_theta=50000.0,
    )
    return ModelArgs(**settings)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class TestMultimodalModelVisionEncoder:
    """Shape smoke test for the multimodal VisionEncoder."""

    @pytest.fixture(autouse=True)
    def setup_class(self, encoder_config):
        args = encoder_config
        self.model_args = args
        self.batch_size = 1
        self.num_imgs = 2
        self.num_tiles = 4
        # Two images with different tile layouts: 1x3 and 2x2.
        self.aspect_ratio = torch.tensor([[1, 3], [2, 2]]).reshape(
            self.batch_size, self.num_imgs, 2
        )
        img_shape = (
            self.batch_size,
            self.num_imgs,
            self.num_tiles,
            args.in_channels,
            args.tile_size,
            args.tile_size,
        )
        # Deterministic fixed-initialized input in [-1, 1].
        self.image = fixed_init_tensor(torch.Size(img_shape), min_val=-1, max_val=1)

    def test_llama_mm_vision_encoder(self):
        encoder = VisionEncoder(self.model_args)
        fixed_init_model(encoder, min_val=-1, max_val=1)
        actual = encoder(self.image, self.aspect_ratio)
        # Each tile yields patches_per_tile patch tokens plus one CLS token.
        tokens_per_tile = encoder.vit.patches_per_tile + 1
        expected_shape = (
            self.batch_size,
            self.num_imgs * self.num_tiles * tokens_per_tile,
            self.model_args.decoder_embed_dim,
        )
        assert (
            actual.shape == expected_shape
        ), f"Expected shape {expected_shape}, but got {actual.shape}"

        # TODO: enable a numerical check once stability is understood.
        # Observed output.mean() = 3.994, which does not match 5.28800,
        # the value from the original torchtune test:
        # assert torch.allclose(
        #     output.mean(), torch.tensor(5.28800), atol=1e-3, rtol=1e-3
        # )
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class TestMultimodalModelDecoder:
    """Shape smoke test for the MultimodalDecoder."""

    @pytest.fixture(autouse=True)
    def setup_class(self, decoder_config):
        args = decoder_config
        self.model_args = args
        self.batch_size = 1
        self.seq_len = 128
        self.decoder_embed_dim = args.decoder_embed_dim
        self.vocab_size = args.vocab_size
        token_ids = torch.arange(self.batch_size * self.seq_len)
        self.input = {
            "tokens": token_ids.reshape(self.batch_size, self.seq_len),
            # Deterministic fixed-initialized encoder output in [-1, 1].
            "encoder_input": fixed_init_tensor(
                (self.batch_size, self.seq_len, self.decoder_embed_dim),
                min_val=-1,
                max_val=1,
            ),
            "encoder_mask": None,
        }

    @torch.no_grad()
    def test_llama_mm_decoder(self):
        decoder = MultimodalDecoder(self.model_args)
        fixed_init_model(decoder, min_val=-1, max_val=1)
        logits = decoder(**self.input)
        expected_shape = (self.batch_size, self.seq_len, self.vocab_size)
        assert (
            logits.shape == expected_shape
        ), f"Expected shape {expected_shape}, but got {logits.shape}"

        # TODO: enable a numerical check once stability is understood.
        # Observed output.mean() = -0.0134, which does not match -9.47548e-5,
        # the value from the original torchtune test:
        # assert torch.allclose(
        #     output.mean(), torch.tensor(-9.47548e-5), atol=1e-3, rtol=1e-3
        # )
|
torchtitan/experiments/simple_fsdp/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (1.11 kB). View file
|
|
|
torchtitan/models/__pycache__/norms.cpython-312.pyc
ADDED
|
Binary file (1.39 kB). View file
|
|
|