zaydzuhri commited on
Commit
8a45d34
·
verified ·
1 Parent(s): d092603

Add files using upload-large-folder tool

Browse files
Files changed (50) hide show
  1. profile_trace/iteration_10752/rank5_trace.json +0 -0
  2. profile_trace/iteration_11776/rank3_trace.json +0 -0
  3. profile_trace/iteration_11776/rank4_trace.json +0 -0
  4. profile_trace/iteration_11776/rank7_trace.json +0 -0
  5. profile_trace/iteration_12288/rank5_trace.json +0 -0
  6. profile_trace/iteration_13824/rank2_trace.json +0 -0
  7. profile_trace/iteration_13824/rank4_trace.json +0 -0
  8. profile_trace/iteration_13824/rank6_trace.json +0 -0
  9. profile_trace/iteration_14848/rank2_trace.json +0 -0
  10. profile_trace/iteration_14848/rank4_trace.json +0 -0
  11. profile_trace/iteration_14848/rank7_trace.json +0 -0
  12. profile_trace/iteration_21504/rank0_trace.json +0 -0
  13. profile_trace/iteration_21504/rank3_trace.json +0 -0
  14. profile_trace/iteration_28160/rank2_trace.json +0 -0
  15. profile_trace/iteration_28160/rank4_trace.json +0 -0
  16. profile_trace/iteration_28160/rank6_trace.json +0 -0
  17. profile_trace/iteration_31744/rank3_trace.json +0 -0
  18. profile_trace/iteration_33792/rank1_trace.json +0 -0
  19. profile_trace/iteration_33792/rank2_trace.json +0 -0
  20. profile_trace/iteration_33792/rank4_trace.json +0 -0
  21. profile_trace/iteration_33792/rank5_trace.json +0 -0
  22. profile_trace/iteration_33792/rank6_trace.json +0 -0
  23. profile_trace/iteration_33792/rank7_trace.json +0 -0
  24. profile_trace/iteration_512/rank0_trace.json +0 -0
  25. profile_trace/iteration_512/rank1_trace.json +0 -0
  26. profile_trace/iteration_512/rank3_trace.json +0 -0
  27. profile_trace/iteration_8192/rank0_trace.json +0 -0
  28. profile_trace/iteration_8192/rank1_trace.json +0 -0
  29. profile_trace/iteration_8192/rank5_trace.json +0 -0
  30. torchtitan/components/__pycache__/checkpoint.cpython-312.pyc +0 -0
  31. torchtitan/components/__pycache__/dataloader.cpython-312.pyc +0 -0
  32. torchtitan/components/__pycache__/float8.cpython-312.pyc +0 -0
  33. torchtitan/components/loss.py +29 -0
  34. torchtitan/experiments/deepseek_v3/indices.py +195 -0
  35. torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_barrier.py +159 -0
  36. torchtitan/experiments/flux/__pycache__/__init__.cpython-312.pyc +0 -0
  37. torchtitan/experiments/flux/__pycache__/utils.cpython-312.pyc +0 -0
  38. torchtitan/experiments/flux/dataset/__pycache__/tokenizer.cpython-312.pyc +0 -0
  39. torchtitan/experiments/flux/dataset/tokenizer.py +64 -0
  40. torchtitan/experiments/flux/model/__pycache__/hf_embedder.cpython-312.pyc +0 -0
  41. torchtitan/experiments/flux/model/__pycache__/layers.cpython-312.pyc +0 -0
  42. torchtitan/experiments/flux/train_configs/debug_model.toml +68 -0
  43. torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/unit_test_backwards.py +174 -0
  44. torchtitan/experiments/llama4/__pycache__/__init__.cpython-312.pyc +0 -0
  45. torchtitan/experiments/llama4/model/__pycache__/moe.cpython-312.pyc +0 -0
  46. torchtitan/experiments/llama4/model/args.py +109 -0
  47. torchtitan/experiments/multimodal/__init__.py +37 -0
  48. torchtitan/experiments/multimodal/tests/test_multimodal_model.py +128 -0
  49. torchtitan/experiments/simple_fsdp/__pycache__/__init__.cpython-312.pyc +0 -0
  50. torchtitan/models/__pycache__/norms.cpython-312.pyc +0 -0
profile_trace/iteration_10752/rank5_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_11776/rank3_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_11776/rank4_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_11776/rank7_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_12288/rank5_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_13824/rank2_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_13824/rank4_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_13824/rank6_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_14848/rank2_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_14848/rank4_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_14848/rank7_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_21504/rank0_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_21504/rank3_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_28160/rank2_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_28160/rank4_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_28160/rank6_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_31744/rank3_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_33792/rank1_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_33792/rank2_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_33792/rank4_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_33792/rank5_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_33792/rank6_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_33792/rank7_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_512/rank0_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_512/rank1_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_512/rank3_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_8192/rank0_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_8192/rank1_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_8192/rank5_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
torchtitan/components/__pycache__/checkpoint.cpython-312.pyc ADDED
Binary file (33.1 kB). View file
 
torchtitan/components/__pycache__/dataloader.cpython-312.pyc ADDED
Binary file (3.79 kB). View file
 
torchtitan/components/__pycache__/float8.cpython-312.pyc ADDED
Binary file (6.2 kB). View file
 
torchtitan/components/loss.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Callable, TypeAlias
8
+
9
+ import torch
10
+
11
+ from torchtitan.config_manager import JobConfig
12
+ from torchtitan.tools.logging import logger
13
+
14
+ LossFunction: TypeAlias = Callable[..., torch.Tensor]
15
+
16
+
17
def cross_entropy_loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Standard next-token cross-entropy loss for Transformer training.

    Collapses the leading (batch, seq) dimensions of both tensors and
    upcasts the logits to float32 before computing the loss.
    """
    flat_logits = pred.flatten(0, 1).float()
    flat_labels = labels.flatten(0, 1)
    return torch.nn.functional.cross_entropy(flat_logits, flat_labels)
22
+
23
+
24
def build_cross_entropy_loss(job_config: JobConfig):
    """Return the cross-entropy loss function for training.

    When ``job_config.training.compile`` is set, the loss function is wrapped
    in ``torch.compile`` before being returned.
    """
    if job_config.training.compile:
        logger.info("Compiling the loss function with torch.compile")
        return torch.compile(cross_entropy_loss)
    return cross_entropy_loss
torchtitan/experiments/deepseek_v3/indices.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import triton
9
+ import triton.language as tl
10
+
11
+
12
+ __all__ = ["generate_permute_indices"]
13
+
14
+
15
@triton.jit
def fill_indices_kernel(
    tokens_per_expert_group_ptr,  # *Pointer* to first input vector.
    start_index_values_ptr,  # *Pointer* to second input vector.
    write_offsets_ptr,  # *Pointer* to third input vector.
    output_ptr,  # *Pointer* to output vector.
    experts_per_rank,  # Number of experts per rank.
    num_ranks,  # Number of expert ranks.
):
    """Scatter source-token indices into the permuted output layout.

    For each local expert, walks the ranks in order and writes the contiguous
    run of indices [start_index, start_index + length) for that (rank, expert)
    slot into `output_ptr`, starting at the expert's write offset.
    """
    # There are multiple 'programs' processing different data. We identify which program
    # we are here:
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    # The total number of programs in the launch grid.
    num_programs = tl.num_programs(axis=0)
    # We map the programs (blocks) to the experts: program `pid` handles experts
    # pid, pid + num_programs, pid + 2 * num_programs, ...
    for expert_id in tl.range(pid, experts_per_rank, step=num_programs):
        # Read this expert's write offset.
        write_offset = tl.load(write_offsets_ptr + expert_id)
        # Loop over the ranks.
        for r in tl.range(num_ranks):
            # Slot in the tokens_per_expert_group array.
            i = r * experts_per_rank + expert_id
            start_index = tl.load(start_index_values_ptr + i)
            length = tl.load(tokens_per_expert_group_ptr + i)
            # Write the indices, one scalar store per element.
            for l in tl.range(length):
                val = start_index + l
                tl.store(output_ptr + write_offset + l, val)
            # Advance past the chunk just written so the next rank's tokens
            # land immediately after it.
            write_offset += length
44
+
45
+
46
def fill_indices(
    tokens_per_expert_group: torch.Tensor,
    start_index_values: torch.Tensor,
    write_offsets: torch.Tensor,
    experts_per_rank: int,
    num_ranks: int,
    max_len: int,
):
    """Build the permuted index vector on GPU via `fill_indices_kernel`.

    Allocates a `max_len`-long int32 output filled with -1 (the padding
    sentinel) on the same device as the inputs; slots not covered by any
    expert's tokens keep the -1 value.
    """
    device = tokens_per_expert_group.device
    # Preallocate the output; untouched positions remain -1.
    permuted_indices = torch.full((max_len,), -1, dtype=torch.int32, device=device)
    # Analogous to CUDA launch grids. It can be either Tuple[int], or
    # Callable(metaparameters) -> Tuple[int]. Here a 1D grid with a single
    # program does all the work (TODO: bump this value).
    grid = lambda meta: (1,)
    # Each torch.tensor argument is implicitly passed as a pointer to its
    # first element.
    fill_indices_kernel[grid](
        tokens_per_expert_group,
        start_index_values,
        write_offsets,
        permuted_indices,
        experts_per_rank,
        num_ranks,
    )
    return permuted_indices
71
+
72
+
73
def fill_indices_cpu(
    tokens_per_expert_group: torch.Tensor,
    start_index_values: torch.Tensor,
    write_offsets: torch.Tensor,
    experts_per_rank: int,
    num_ranks: int,
    max_len: int,
):
    """CPU reference implementation of `fill_indices`.

    Produces a `max_len`-long int32 vector, -1 everywhere except where each
    (rank, expert) chunk of source-token indices is written at the expert's
    offset.
    """
    # Preallocate the output; untouched positions keep the -1 sentinel.
    out = torch.full((max_len,), -1, dtype=torch.int32)
    for expert in range(experts_per_rank):
        # Running position inside this expert's aligned region.
        cursor = write_offsets[expert]
        for rank in range(num_ranks):
            slot = rank * experts_per_rank + expert
            first = start_index_values[slot]
            count = tokens_per_expert_group[slot]
            # Copy the contiguous run of source indices for this chunk.
            out[cursor : cursor + count] = torch.arange(first, first + count)
            cursor += count
    return out
98
+
99
+
100
def generate_permute_indices(
    tokens_per_expert_group: torch.Tensor,
    experts_per_rank: int,
    num_ranks: int,
    max_len: int,
    alignment: int,
    use_cpu: bool = False,
):
    """Build permutation indices and aligned per-expert token counts.

    `tokens_per_expert_group` has shape (num_ranks * experts_per_rank,) and
    holds, for every (rank, expert) pair, how many tokens that rank routed to
    that expert. For example:

        From: |       rank 0      |       rank 1      |
        To:   | E0 | E1 | E2 | E3 | E0 | E1 | E2 | E3 |
              |  4 |  2 |  1 |  3 |  1 |  2 |  3 |  4 |

    Args:
        tokens_per_expert_group: number of tokens for each expert from all ranks.
        experts_per_rank: number of experts per rank.
        num_ranks: number of ranks.
        max_len: maximum length of the output index vector. If greater than
            the total number of tokens, the remaining indices are set to -1.
        alignment: alignment for each returned element in `m_sizes`
            (usually comes from group gemm).
        use_cpu: run the CPU reference path instead of the Triton kernel.

    Returns:
        permuted_indices: permutation indices.
        m_sizes: per-expert token counts, rounded up to `alignment`.
    """
    # Exclusive prefix sum: where each (rank, expert) chunk starts in the input.
    start_index_values = (
        torch.cumsum(tokens_per_expert_group, 0) - tokens_per_expert_group
    )
    # Total tokens per expert, summed over all ranks.
    chunk_size_per_expert = tokens_per_expert_group.view(num_ranks, -1).sum(0)
    # Round each expert's chunk size up to the alignment boundary.
    m_sizes = ((chunk_size_per_expert + alignment - 1) // alignment * alignment).to(
        torch.int32
    )
    # Exclusive prefix sum of the aligned sizes: each expert's write offset.
    write_offsets = torch.cumsum(m_sizes, 0) - m_sizes
    # Dispatch to the CPU reference or the Triton kernel.
    if use_cpu:
        fill = fill_indices_cpu
    else:
        fill = fill_indices
    permuted_indices = fill(
        tokens_per_expert_group,
        start_index_values,
        write_offsets,
        experts_per_rank,
        num_ranks,
        max_len,
    )
    return permuted_indices, m_sizes
155
+
156
+
157
+ # Below is for testing only
158
+
159
+
160
def test():
    """Smoke test: GPU kernel output matches the CPU reference and every
    entry of `m_sizes` is a multiple of `alignment`. Requires a CUDA device.
    """
    device = torch.device("cuda", 0)
    experts_per_rank = 4
    num_ranks = 4
    # Uniform load: every rank sends 4 tokens to every expert.
    tokens_per_expert_group = torch.full(
        (num_ranks * experts_per_rank,), 4, dtype=torch.int32, device=device
    )
    max_len = 128
    alignment = 32
    # Use the GPU kernel
    permuted_indices_gpu, m_sizes = generate_permute_indices(
        tokens_per_expert_group, experts_per_rank, num_ranks, max_len, alignment
    )
    # Use the CPU method
    permuted_indices_cpu, _ = generate_permute_indices(
        tokens_per_expert_group,
        experts_per_rank,
        num_ranks,
        max_len,
        alignment,
        use_cpu=True,
    )
    # GPU and CPU paths must agree exactly.
    assert torch.equal(permuted_indices_gpu.cpu(), permuted_indices_cpu)
    # Every m_sizes entry must be a multiple of `alignment`. Compare against
    # zeros of the SAME dtype: torch.equal returns False on a dtype mismatch,
    # so the previous default-float zeros made this assertion always fail.
    assert torch.equal(
        torch.remainder(m_sizes, alignment),
        torch.zeros(experts_per_rank, dtype=m_sizes.dtype, device=device),
    )
    # Print the results
    print(permuted_indices_gpu)
    print(m_sizes)
    print("Success")


if __name__ == "__main__":
    test()
torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_barrier.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from .triton_utils import get_flat_bid, get_flat_tid
11
+
12
+
13
@triton.jit
def send_signal(addrs, sem: tl.constexpr):
    """Raise this rank's flag in a peer's signal pad.

    For each address in `addrs`, spins on an atomic CAS flipping the 32-bit
    slot from 0 to 1; the loop retries while the previous value is still 1
    (i.e. the prior signal has not been consumed yet).

    `sem` selects the CAS memory ordering:
      - "relaxed": no ordering beyond atomicity.
      - "acq_rel": release semantics on the store side, so writes made before
        signaling become visible to the peer that consumes the signal.

    Raises:
        RuntimeError: if `sem` is not one of the two recognized values.
    """
    if sem == "relaxed":
        tl.inline_asm_elementwise(
            """
            {
                .reg .u32   %tmp32_<1>;
                .reg .pred  %p<1>;

                send_signal:
                    atom.global.relaxed.sys.cas.b32 %tmp32_0, [$1], 0, 1;
                    setp.eq.u32 %p0, %tmp32_0, 0;
                    @!%p0 bra send_signal;
            }
            """,
            "=r, l",
            [addrs],
            dtype=tl.int32,
            is_pure=False,
            pack=1,
        )
    elif sem == "acq_rel":
        tl.inline_asm_elementwise(
            """
            {
                .reg .u32   %tmp32_<1>;
                .reg .pred  %p<1>;

                send_signal:
                    atom.global.release.sys.cas.b32 %tmp32_0, [$1], 0, 1;
                    setp.eq.u32 %p0, %tmp32_0, 0;
                    @!%p0 bra send_signal;
            }
            """,
            "=r, l",
            [addrs],
            dtype=tl.int32,
            is_pure=False,
            pack=1,
        )
    else:
        raise RuntimeError(f"Unrecognized sem: {sem}")
55
+
56
+
57
@triton.jit
def wait_signal(addrs, sem: tl.constexpr):
    """Consume a peer's flag in the local signal pad.

    For each address in `addrs`, spins on an atomic CAS flipping the 32-bit
    slot from 1 back to 0; the loop retries until the observed value is 1,
    i.e. until the peer has signaled. Resetting to 0 restores the zero-filled
    pad state for the next barrier round.

    `sem` selects the CAS memory ordering:
      - "relaxed": no ordering beyond atomicity.
      - "acq_rel": acquire semantics on the load side, making the signaling
        peer's prior writes visible after the wait returns.

    Raises:
        RuntimeError: if `sem` is not one of the two recognized values.
    """
    if sem == "relaxed":
        tl.inline_asm_elementwise(
            """
            {
                .reg .u32   %tmp32_<1>;
                .reg .pred  %p<1>;

                wait_signal:
                    atom.global.sys.relaxed.cas.b32 %tmp32_0, [$1], 1, 0;
                    setp.eq.u32 %p0, %tmp32_0, 1;
                    @!%p0 bra wait_signal;
            }
            """,
            "=r, l",
            [addrs],
            dtype=tl.int32,
            is_pure=False,
            pack=1,
        )
    elif sem == "acq_rel":
        tl.inline_asm_elementwise(
            """
            {
                .reg .u32   %tmp32_<1>;
                .reg .pred  %p<1>;

                wait_signal:
                    atom.global.sys.acquire.cas.b32 %tmp32_0, [$1], 1, 0;
                    setp.eq.u32 %p0, %tmp32_0, 1;
                    @!%p0 bra wait_signal;
            }
            """,
            "=r, l",
            [addrs],
            dtype=tl.int32,
            is_pure=False,
            pack=1,
        )
    else:
        raise RuntimeError(f"Unrecognized sem: {sem}")
99
+
100
+
101
@triton.jit
def blockwise_barrier(
    signal_pad_ptrs,  # per-rank signal-pad base pointers, stored as uint64
    block_id,  # block index to synchronize on; None -> use flattened blockIdx
    rank: tl.constexpr,
    world_size: tl.constexpr,
    sem: tl.constexpr,  # "relaxed" or "acq_rel"; forwarded to send/wait
):
    """
    Synchronizes blocks with matching block_id across participating devices.

    Note: the function itself is not a system level barrier/fence. It is a
    building block for expressing different synchronization patterns.

    Pattern 0: Ensures that all writes to symm_mem buffers from previous
    kernels across all devices are visible to the current kernel:

        blockwise_barrier(..., sem="relaxed")
        sync_threads()

    Pattern 1: Ensures that all writes to symm_mem buffers from the current
    block are visible to all remote blocks with matching blockIdx:

        sync_threads()
        blockwise_barrier(..., sem="acq_rel")
        sync_threads()

    Pattern 2: Ensures that symm_mem buffers read by the current kernel are safe
    for writing by subsequent kernels across all devices.

        sync_threads()
        blockwise_barrier(..., sem="relaxed")

    CUDA graph friendliness:

    This barrier operates through atomic operations on a zero-filled signal
    pad, which resets to a zero-filled state after each successful
    synchronization. This design eliminates the need for incrementing a
    flag from host.
    """
    if block_id is None:
        block_id = get_flat_bid()
    flat_tid = get_flat_tid()

    # The pad holds one uint32 slot per (block, peer-rank) pair.
    remote_ranks = tl.arange(0, world_size)
    signal_pad_ptrs = signal_pad_ptrs.to(tl.pointer_type(tl.uint64))
    # Addresses of this rank's slot in every peer's pad (where we signal).
    remote_signal_pad_addrs = tl.load(signal_pad_ptrs + remote_ranks).to(
        tl.pointer_type(tl.uint32)
    )
    send_addrs = remote_signal_pad_addrs + block_id * world_size + rank

    # Addresses of every peer's slot in our own pad (where we wait).
    local_signal_pad_addr = tl.load(signal_pad_ptrs + rank).to(
        tl.pointer_type(tl.uint32)
    )
    wait_addrs = local_signal_pad_addr + block_id * world_size + remote_ranks

    # Only the first `world_size` threads participate: thread t handles peer t.
    if flat_tid < world_size:
        send_signal(send_addrs, sem)
        wait_signal(wait_addrs, sem)
torchtitan/experiments/flux/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (2.08 kB). View file
 
torchtitan/experiments/flux/__pycache__/utils.cpython-312.pyc ADDED
Binary file (7.31 kB). View file
 
torchtitan/experiments/flux/dataset/__pycache__/tokenizer.cpython-312.pyc ADDED
Binary file (2.21 kB). View file
 
torchtitan/experiments/flux/dataset/tokenizer.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
8
+ # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
9
+
10
+
11
+ from typing import List
12
+
13
+ from torchtitan.components.tokenizer import Tokenizer
14
+ from transformers import CLIPTokenizer, T5Tokenizer
15
+
16
+
17
class FluxTokenizer(Tokenizer):
    """
    Tokenizing and encoding/decoding text using the T5 or CLIP tokenizer.

    Args:
        model_path (str): Path to the tokenizer from hugging face.
        max_length (int): Maximum token sequence length; prompts are
            truncated and padded to this length in `encode`.
    """

    def __init__(self, model_path: str = "t5-small", max_length: int = 77):
        super().__init__()
        self._n_words = 8  # TODO(jianiw): check
        self._max_length = max_length

        # CLIP checkpoints are published under the "openai/" namespace;
        # any other path is treated as a T5 tokenizer.
        self.is_clip = model_path.startswith("openai")

        if self.is_clip:
            self._tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(
                model_path, max_length=max_length
            )
        else:
            self._tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(
                model_path, max_length=max_length
            )

    def encode(
        self,
        s: str,
    ) -> List[int]:
        """
        Encode the prompt text into tokens.

        NOTE(review): with `return_tensors="pt"` the HF tokenizer returns a
        torch.Tensor (shape (1, max_length)), not a List[int] as annotated —
        confirm callers expect a tensor.
        """
        tokens = self._tokenizer(
            s,
            truncation=True,
            max_length=self._max_length,
            return_length=False,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",  # return pytorch tensors, default return List[int]
        )["input_ids"]
        return tokens

    def decode(self, t: List[int]) -> str:
        """
        Decode function. This function will not be called.
        """
        return self._tokenizer.decode(t)
torchtitan/experiments/flux/model/__pycache__/hf_embedder.cpython-312.pyc ADDED
Binary file (1.95 kB). View file
 
torchtitan/experiments/flux/model/__pycache__/layers.cpython-312.pyc ADDED
Binary file (17.7 kB). View file
 
torchtitan/experiments/flux/train_configs/debug_model.toml ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ [job]
3
+ dump_folder = "./outputs"
4
+ description = "Flux debug model"
5
+ print_args = false
6
+ use_for_integration_test = true
7
+
8
+ [profiling]
9
+ enable_profiling = false
10
+ save_traces_folder = "profile_trace"
11
+ profile_freq = 10
12
+ enable_memory_snapshot = false
13
+ save_memory_snapshot_folder = "memory_snapshot"
14
+
15
+ [metrics]
16
+ log_freq = 1
17
+ disable_color_printing = false
18
+ enable_tensorboard = false
19
+ save_tb_folder = "tb"
20
+ enable_wandb = false
21
+
22
+ [model]
23
+ name = "flux"
24
+ flavor = "flux-debug"
25
+ norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm
26
+ # test tokenizer.model, for debug purpose only
27
+ # tokenizer_path = "./tests/assets/test_tiktoken.model"
28
+ # converters = "float8"
29
+
30
+
31
+ [optimizer]
32
+ name = "AdamW"
33
+ lr = 8e-4
34
+ eps = 1e-8
35
+
36
+ [lr_scheduler]
37
+ warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps
38
+ decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps
39
+ decay_type = "linear"
40
+ lr_min = 0.0
41
+
42
+ [training]
43
+ batch_size = 32
44
+ seq_len = 512
45
+ max_norm = 1.0 # grad norm clipping
46
+ steps = 10
47
+ compile = false
48
+ dataset = "cc12m"
49
+ guidance = 3.5
50
+ seed = 0
51
+
52
+ [encoder]
53
+ t5_encoder="google/t5-v1_1-small"
54
+ clip_encoder="openai/clip-vit-large-patch14"
55
+ max_t5_encoding_len=512
56
+ auto_encoder_path="torchtitan/experiments/flux/assets/autoencoder/ae.safetensors" # Autoencoder to use for image
57
+
58
+ [parallelism]
59
+ data_parallel_replicate_degree = 1
60
+ data_parallel_shard_degree = 1
61
+ fsdp_reshard_after_forward = "default" # default / never / always
62
+ tensor_parallel_degree = 1
63
+ enable_async_tensor_parallel = false
64
+ pipeline_parallel_degree = 1
65
+ context_parallel_degree = 1
66
+
67
+ [experimental]
68
+ custom_args_module = "torchtitan.experiments.flux.flux_argparser"
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/unit_test_backwards.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+ import logging
9
+ import unittest
10
+ from typing import Tuple
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+
15
+ from mg_grouped_gemm import (
16
+ grouped_gemm_backward,
17
+ grouped_gemm_dw_tma,
18
+ grouped_gemm_dx_tma,
19
+ grouped_gemm_forward,
20
+ mg_grouped_gemm,
21
+ )
22
+
23
+ from reference_utils import (
24
+ analyze_tensor_differences,
25
+ compute_reference_backward,
26
+ compute_reference_forward,
27
+ )
28
+
29
+
30
class TestMG_GroupedGEMM_Backward(unittest.TestCase):
    """Unit tests for the M*G grouped-GEMM backward pass.

    Compares `grouped_gemm_forward` / `grouped_gemm_backward` against the
    pure-reference implementations from `reference_utils`. Requires CUDA.
    """

    def setUp(self) -> None:
        torch.manual_seed(2020)  # Set seed for reproducibility

    def _run_grouped_gemm_backward_test(
        self,
        shape: Tuple[int, int, int, int],
        device: torch.device,
        dtype: torch.dtype = torch.bfloat16,
        atol: float = 1e-5,
        rtol: float = 1.6e-2,
    ) -> None:
        """Run forward + backward for one (G, M, N, K) shape and verify the
        output and both gradients against the reference implementation.

        NOTE(review): `atol` / `rtol` are accepted but never used in this
        body — tolerances presumably live inside `analyze_tensor_differences`;
        confirm whether they should be forwarded.
        """
        G, M, N, K = shape
        # Set up inputs for forward pass
        # In M*G grouping, input is [M*G, K] and weights are [N, K]
        a = torch.randn(M * G, K, dtype=dtype, device=device, requires_grad=True)
        b = torch.randn(N, K, dtype=dtype, device=device, requires_grad=True)

        # Create equal-sized groups for simplicity
        m_size = M
        m_sizes = torch.full((G,), m_size, device=device, dtype=torch.int32)

        # Run forward pass with our implementation
        result = grouped_gemm_forward(a, b, m_sizes)
        # Ensure result has correct shape
        self.assertTrue(result.shape == (M * G, N))

        # Compute expected result using reference implementation
        expected_result = compute_reference_forward(a, b, m_sizes)

        # Verify forward pass correctness
        forward_close = analyze_tensor_differences(
            result, expected_result, "Forward output"
        )
        self.assertTrue(forward_close)

        # Create a gradient for backpropagation
        grad_output = torch.randn_like(result)

        # Compute gradients using our custom backward implementation
        grad_a, grad_b = grouped_gemm_backward(grad_output, a, b, m_sizes)

        # Compute expected gradients using reference implementation
        expected_grad_a, expected_grad_b = compute_reference_backward(
            a, b, m_sizes, grad_output
        )

        # Verify gradient correctness
        grad_a_close = analyze_tensor_differences(grad_a, expected_grad_a, "grad_x")
        grad_b_close = analyze_tensor_differences(grad_b, expected_grad_b, "grad_w")

        self.assertTrue(grad_a_close)
        self.assertTrue(grad_b_close)

    def test_MG_grouped_gemm_backward_bf16(self) -> None:
        """Sweep small G/M sizes at fixed N = K = 1024.

        NOTE(review): the method name and log line say BF16, but `dtype` is
        `torch.float16` — confirm which precision is intended.
        """
        for G in (1, 8, 16):
            for M in (512, 1024):
                print(f"Testing BF16 M*G GroupGeMM Backward with G={G}, M={M}")
                self._run_grouped_gemm_backward_test(
                    (G, M, 1024, 1024),
                    torch.device("cuda"),
                    dtype=torch.float16,
                    atol=1e-2,
                    rtol=1e-2,
                )

    def test_MG_grouped_gemm_backward_deepseek_shapes(self) -> None:
        """Test backward pass with shapes from Deepseek model.

        NOTE(review): prints "BF16" but runs `torch.float16` — confirm.
        """
        deepseek_shapes = [
            (4, 2048, 4096, 7168),  # G, M, N, K
            (4, 2048, 7168, 2048),
            (8, 512, 4096, 7168),
            (8, 512, 7168, 2048),
        ]

        device = torch.device("cuda")

        for shape in deepseek_shapes:
            G, M, N, K = shape
            print(
                f"Testing BF16 M*G Deepseek Backward shape: G={G}, M={M}, N={N}, K={K}"
            )
            self._run_grouped_gemm_backward_test(
                shape, device, dtype=torch.float16, atol=1e-2, rtol=1e-2
            )

    def test_MG_dx(self) -> None:
        """Test specifically the dx (gradient w.r.t. input) computation."""
        G, M, N, K = 4, 512, 1024, 2048
        device = torch.device("cuda")
        dtype = torch.bfloat16

        # Set up inputs
        a = torch.randn(M * G, K, dtype=dtype, device=device, requires_grad=True)
        b = torch.randn(N, K, dtype=dtype, device=device, requires_grad=True)

        # Create equal-sized groups
        m_size = M
        m_sizes = torch.full((G,), m_size, device=device, dtype=torch.int32)

        # Forward pass
        result = grouped_gemm_forward(a, b, m_sizes)

        # Create gradient for backward
        grad_output = torch.randn_like(result)

        # Compute gradient using our optimized function; dw is ignored here.
        grad_a, _ = grouped_gemm_backward(grad_output, a, b, m_sizes)

        # Compute expected gradient using reference implementation
        expected_grad_a, _ = compute_reference_backward(a, b, m_sizes, grad_output)

        # Verify gradient
        dx_close = analyze_tensor_differences(grad_a, expected_grad_a, "grad_a (dx)")
        self.assertTrue(dx_close)

    def test_MG_dw(self) -> None:
        """Test specifically the dw (gradient w.r.t. weights) computation."""
        G, M, N, K = 4, 512, 1024, 2048
        device = torch.device("cuda")
        dtype = torch.bfloat16

        # Set up inputs
        a = torch.randn(M * G, K, dtype=dtype, device=device, requires_grad=True)
        b = torch.randn(N, K, dtype=dtype, device=device, requires_grad=True)

        # Create equal-sized groups
        m_size = M
        m_sizes = torch.full((G,), m_size, device=device, dtype=torch.int32)

        # Forward pass
        result = grouped_gemm_forward(a, b, m_sizes)

        # Create gradient for backward
        grad_output = torch.randn_like(result)

        # Compute gradient using our optimized function; dx is ignored here.
        _, grad_b = grouped_gemm_backward(grad_output, a, b, m_sizes)

        # Compute expected gradient using reference implementation
        _, expected_grad_b = compute_reference_backward(a, b, m_sizes, grad_output)

        # Verify gradient
        dw_close = analyze_tensor_differences(grad_b, expected_grad_b, "grad_b (dw)")
        self.assertTrue(dw_close)
torchtitan/experiments/llama4/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (1.66 kB). View file
 
torchtitan/experiments/llama4/model/__pycache__/moe.cpython-312.pyc ADDED
Binary file (10.5 kB). View file
 
torchtitan/experiments/llama4/model/args.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ from dataclasses import dataclass
9
+ from typing import Optional
10
+
11
+ from torch import nn
12
+ from torchtitan.components.tokenizer import Tokenizer
13
+ from torchtitan.config_manager import JobConfig
14
+
15
+ from torchtitan.protocols.train_spec import BaseModelArgs
16
+ from torchtitan.tools.logging import logger
17
+
18
+
19
+ @dataclass
20
+ class TransformerModelArgs(BaseModelArgs):
21
+ dim: int = 4096
22
+ n_layers: int = 32
23
+ n_heads: int = 32
24
+ n_kv_heads: Optional[int] = None
25
+ vocab_size: int = -1 # defined later by tokenizer
26
+ multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
27
+ ffn_dim_multiplier: Optional[float] = None
28
+ norm_eps: float = 1e-5
29
+ rope_theta: float = 10000
30
+
31
+ max_seq_len: int = 2048
32
+ # If `True`, then each transformer block init uses its layer ID, and if
33
+ # `False`, each uses the total number of transformer blocks
34
+ depth_init: bool = True
35
+ norm_type: str = "rmsnorm"
36
+
37
+ use_flex_attn: bool = False
38
+ attn_mask_type: str = "causal"
39
+ eos_id: int = 0
40
+
41
+ # MoE args
42
+ moe_enabled: bool = True
43
+ num_experts: int = 8
44
+ use_shared_expert: bool = True
45
+ auto_scale_hidden_dim: bool = True
46
+ # frequency of using MoE layer instead of feedforward layer in a transformer block
47
+ interleave_moe_layer_step: int = 2
48
+ # token-choice
49
+ top_k: int = 1
50
+
51
+ def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
52
+ self.norm_type = job_config.model.norm_type
53
+ self.vocab_size = tokenizer.n_words
54
+ self.max_seq_len = job_config.training.seq_len
55
+ self.use_flex_attn = job_config.model.use_flex_attn
56
+
57
+ def get_nparams_and_flops(
58
+ self, model: nn.Module, seq_len: int
59
+ ) -> tuple[int, float]:
60
+ nparams_embedding = 0
61
+ nparams_moe_router = 0
62
+ nparams_shared_expert = 0
63
+ nparams_experts = 0
64
+ nparams_dense = 0
65
+
66
+ for name, p in model.named_parameters():
67
+ if "embedding" in name:
68
+ nparams_embedding += p.numel()
69
+ nparams_dense += p.numel()
70
+ elif "moe.shared_expert" in name:
71
+ nparams_shared_expert += p.numel()
72
+ elif "moe.router" in name:
73
+ nparams_moe_router += p.numel()
74
+ elif "moe.experts" in name:
75
+ nparams_experts += p.numel()
76
+ else:
77
+ nparams_dense += p.numel()
78
+
79
+ nparams_sparse = nparams_moe_router + nparams_shared_expert + nparams_experts
80
+ nparams = nparams_dense + nparams_sparse
81
+ nparams_sparse_active = (
82
+ nparams_moe_router
83
+ + nparams_shared_expert
84
+ + nparams_experts * self.top_k // self.num_experts
85
+ )
86
+
87
+ logger.info(
88
+ f"Total parameter count: dense {nparams_dense:,}, "
89
+ f"sparse {nparams_sparse:,}, active {nparams_dense + nparams_sparse_active:,}"
90
+ )
91
+
92
+ l, h, q, t = (
93
+ self.n_layers,
94
+ self.n_heads,
95
+ self.dim // self.n_heads,
96
+ seq_len,
97
+ )
98
+ # Reasoning behind the factor of 12 for the self-attention part of the formula:
99
+ # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6)
100
+ # 2. the flash attention does 1 more matmul recomputation in the backward
101
+ # but recomputation should not be counted in calculating MFU (+0)
102
+ # 3. each matmul performs 1 multiplication and 1 addition (*2)
103
+ # 4. we follow the convention and do not account for sparsity in causal attention
104
+ num_flops_per_token = (
105
+ 6 * (nparams_dense - nparams_embedding + nparams_sparse_active)
106
+ + 12 * l * h * q * t
107
+ )
108
+
109
+ return nparams, num_flops_per_token
torchtitan/experiments/multimodal/__init__.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from mm_dataset import build_mm_dataloader
8
+
9
+ from torchtitan.components.loss import build_cross_entropy_loss
10
+ from torchtitan.components.lr_scheduler import build_lr_schedulers
11
+ from torchtitan.components.optimizer import build_optimizers
12
+ from torchtitan.datasets.tokenizer.tiktoken import build_tiktoken_tokenizer
13
+ from torchtitan.models.llama3 import parallelize_llama, pipeline_llama
14
+ from torchtitan.protocols.train_spec import register_train_spec, TrainSpec
15
+
16
+ from .model import ModelArgs, MultimodalDecoder, VisionEncoder
17
+
18
+ __all__ = ["VisionEncoder", "ModelArgs", "MultimodalDecoder"]
19
+
20
+ llama4_mm_configs = {
21
+ # TODO: add configs for llama4 multimodal
22
+ }
23
+
24
+ register_train_spec(
25
+ TrainSpec(
26
+ name="llama4_multimodal",
27
+ cls=MultimodalDecoder,
28
+ config=llama4_mm_configs,
29
+ parallelize_fn=parallelize_llama,
30
+ pipelining_fn=pipeline_llama,
31
+ build_optimizers_fn=build_optimizers,
32
+ build_lr_schedulers_fn=build_lr_schedulers,
33
+ build_dataloader_fn=build_mm_dataloader,
34
+ build_tokenizer_fn=build_tiktoken_tokenizer,
35
+ build_loss_fn=build_cross_entropy_loss,
36
+ )
37
+ )
torchtitan/experiments/multimodal/tests/test_multimodal_model.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import pytest
8
+ import torch
9
+
10
+ from torchtitan.experiments.llama_multimodal import (
11
+ ModelArgs,
12
+ MultimodalDecoder,
13
+ VisionEncoder,
14
+ )
15
+
16
+ from .test_utils import fixed_init_model, fixed_init_tensor
17
+
18
+
19
+ @pytest.fixture
20
+ def encoder_config():
21
+ return ModelArgs(
22
+ encoder_embed_dim=32,
23
+ encoder_num_layers=2,
24
+ encoder_num_heads=4,
25
+ tile_size=49,
26
+ patch_size=9,
27
+ max_num_tiles=4,
28
+ in_channels=3,
29
+ return_intermediates=[0, 1],
30
+ num_layers_projection=2,
31
+ decoder_embed_dim=128,
32
+ )
33
+
34
+
35
+ @pytest.fixture
36
+ def decoder_config():
37
+ return ModelArgs(
38
+ decoder_embed_dim=512,
39
+ vocab_size=10000,
40
+ fusion_interval=2,
41
+ num_special_tokens=3,
42
+ decoder_num_layers=6,
43
+ decoder_num_heads=8,
44
+ decoder_num_kv_heads=4,
45
+ max_seq_len=512,
46
+ rope_theta=50000.0,
47
+ )
48
+
49
+
50
+ class TestMultimodalModelVisionEncoder:
51
+ @pytest.fixture(autouse=True)
52
+ def setup_class(self, encoder_config):
53
+ self.model_args = encoder_config
54
+ self.batch_size = 1
55
+ self.num_imgs = 2
56
+ self.num_tiles = 4
57
+ self.aspect_ratio = torch.tensor([[1, 3], [2, 2]]).reshape(
58
+ self.batch_size, self.num_imgs, 2
59
+ )
60
+ image = torch.rand(
61
+ (
62
+ self.batch_size,
63
+ self.num_imgs,
64
+ self.num_tiles,
65
+ self.model_args.in_channels,
66
+ self.model_args.tile_size,
67
+ self.model_args.tile_size,
68
+ )
69
+ )
70
+ self.image = fixed_init_tensor(image.shape, min_val=-1, max_val=1)
71
+
72
+ def test_llama_mm_vision_encoder(self):
73
+ model = VisionEncoder(self.model_args)
74
+ fixed_init_model(model, min_val=-1, max_val=1)
75
+ output = model(self.image, self.aspect_ratio)
76
+ expected_shape = (
77
+ self.batch_size,
78
+ self.num_imgs * self.num_tiles * (model.vit.patches_per_tile + 1),
79
+ self.model_args.decoder_embed_dim,
80
+ )
81
+ assert (
82
+ output.shape == expected_shape
83
+ ), f"Expected shape {expected_shape}, but got {output.shape}"
84
+
85
+ # TODO: Need to ensure numerical stability before doing convergence test.
86
+ # output.mean() = 3.994, we need to debug why it is not close to 5.28800, which is
87
+ # the test value from the original torch tune test
88
+ # assert torch.allclose(
89
+ # output.mean(), torch.tensor(5.28800), atol=1e-3, rtol=1e-3
90
+ # )
91
+
92
+
93
+ class TestMultimodalModelDecoder:
94
+ @pytest.fixture(autouse=True)
95
+ def setup_class(self, decoder_config):
96
+ self.model_args = decoder_config
97
+ self.batch_size = 1
98
+ self.decoder_embed_dim = self.model_args.decoder_embed_dim
99
+ self.vocab_size = self.model_args.vocab_size
100
+ self.seq_len = 128
101
+ self.input = {
102
+ "tokens": torch.arange(self.batch_size * self.seq_len).reshape(
103
+ self.batch_size, self.seq_len
104
+ ),
105
+ "encoder_input": fixed_init_tensor(
106
+ (self.batch_size, self.seq_len, self.decoder_embed_dim),
107
+ min_val=-1,
108
+ max_val=1,
109
+ ),
110
+ "encoder_mask": None,
111
+ }
112
+
113
+ @torch.no_grad()
114
+ def test_llama_mm_decoder(self):
115
+ model = MultimodalDecoder(self.model_args)
116
+ fixed_init_model(model, min_val=-1, max_val=1)
117
+ output = model(**self.input)
118
+ expected_shape = (self.batch_size, self.seq_len, self.vocab_size)
119
+ assert (
120
+ output.shape == expected_shape
121
+ ), f"Expected shape {expected_shape}, but got {output.shape}"
122
+
123
+ # TODO: Need to ensure numerical stability before doing convergence test.
124
+ # output.mean() = -0.0134, we need to debug why it is not close to -9.47548e-5, which is
125
+ # the test value from the original torch tune test
126
+ # assert torch.allclose(
127
+ # output.mean(), torch.tensor(-9.47548e-5), atol=1e-3, rtol=1e-3
128
+ # )
torchtitan/experiments/simple_fsdp/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (1.11 kB). View file
 
torchtitan/models/__pycache__/norms.cpython-312.pyc ADDED
Binary file (1.39 kB). View file