zaydzuhri commited on
Commit
86c6113
·
verified ·
1 Parent(s): 0fa019d

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. flame/__pycache__/__init__.cpython-312.pyc +0 -0
  2. flame/__pycache__/config_manager.cpython-312.pyc +0 -0
  3. flame/components/__pycache__/checkpoint.cpython-312.pyc +0 -0
  4. flame/components/checkpoint.py +59 -0
  5. flame/models/__init__.py +0 -0
  6. flame/models/__pycache__/__init__.cpython-312.pyc +0 -0
  7. flame/models/fla.toml +67 -0
  8. flame/models/parallelize_fla.py +550 -0
  9. flame/models/pipeline_fla.py +162 -0
  10. flame/tools/__pycache__/utils.cpython-312.pyc +0 -0
  11. flame/tools/utils.py +41 -0
  12. flame/utils/__init__.py +0 -0
  13. flame/utils/__pycache__/__init__.cpython-312.pyc +0 -0
  14. flame/utils/__pycache__/convert_dcp_to_hf.cpython-312.pyc +0 -0
  15. flame/utils/convert_dcp_to_hf.py +66 -0
  16. flame/utils/convert_hf_to_dcp.py +34 -0
  17. flame/utils/hf_utils.py +77 -0
  18. logs/none_g37i6vbo/attempt_0/6/stderr.log +0 -0
  19. logs/none_lyv0rec_/attempt_0/0/stdout.log +33 -0
  20. logs/none_lyv0rec_/attempt_0/7/stderr.log +0 -0
  21. logs/none_lyv0rec_/attempt_0/7/stdout.log +0 -0
  22. tb/20250909-0619/wandb/debug.log +21 -0
  23. tb/20250909-0619/wandb/run-20250909_061919-top_transformer-top.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509090614/files/output.log +0 -0
  24. tb/20250909-0619/wandb/run-20250909_061919-top_transformer-top.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509090614/files/requirements.txt +207 -0
  25. tb/20250909-0619/wandb/run-20250909_061919-top_transformer-top.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509090614/logs/debug-internal.log +10 -0
  26. tb/20250909-0619/wandb/run-20250909_061919-top_transformer-top.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509090614/logs/debug.log +21 -0
  27. torchtitan/components/__pycache__/dataloader.cpython-312.pyc +0 -0
  28. torchtitan/components/__pycache__/lr_scheduler.cpython-312.pyc +0 -0
  29. torchtitan/components/__pycache__/metrics.cpython-312.pyc +0 -0
  30. torchtitan/components/__pycache__/tokenizer.cpython-312.pyc +0 -0
  31. torchtitan/components/metrics.py +435 -0
  32. torchtitan/experiments/deepseek_v3/LICENSE-CODE +21 -0
  33. torchtitan/experiments/deepseek_v3/README.md +40 -0
  34. torchtitan/experiments/deepseek_v3/checkpoint.py +154 -0
  35. torchtitan/experiments/deepseek_v3/download.py +70 -0
  36. torchtitan/experiments/deepseek_v3/model.py +1325 -0
  37. torchtitan/experiments/deepseek_v3/requirements.txt +5 -0
  38. torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_utils.py +63 -0
  39. torchtitan/experiments/flux/README.md +23 -0
  40. torchtitan/experiments/flux/__pycache__/__init__.cpython-312.pyc +0 -0
  41. torchtitan/experiments/flux/dataset/flux_dataset.py +267 -0
  42. torchtitan/experiments/flux/dataset/tokenizer.py +64 -0
  43. torchtitan/experiments/flux/model/__pycache__/layers.cpython-312.pyc +0 -0
  44. torchtitan/experiments/flux/model/hf_embedder.py +40 -0
  45. torchtitan/experiments/flux/model/model.py +177 -0
  46. torchtitan/experiments/flux/tests/test_flux_dataloader.py +103 -0
  47. torchtitan/experiments/flux/tests/test_generate_image.py +252 -0
  48. torchtitan/experiments/flux/train_configs/debug_model.toml +68 -0
  49. torchtitan/experiments/kernels/triton_mg_group_gemm/benchmark.py +630 -0
  50. torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/unit_test_forwards.py +82 -0
flame/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (156 Bytes). View file
 
flame/__pycache__/config_manager.cpython-312.pyc ADDED
Binary file (36.9 kB). View file
 
flame/components/__pycache__/checkpoint.cpython-312.pyc ADDED
Binary file (3.21 kB). View file
 
flame/components/checkpoint.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from dataclasses import dataclass, field
8
+ from datetime import timedelta
9
+ from io import BytesIO
10
+ from typing import Any, Dict, List
11
+
12
+ import torch
13
+ from torch.distributed.checkpoint.stateful import Stateful
14
+
15
+
16
+ @dataclass
17
+ class TrainState(Stateful):
18
+ step: int = 0
19
+ skipped_step: int = 0
20
+ token: int = 0
21
+ elapsed: timedelta = timedelta(0)
22
+ global_avg_losses: List[float] = field(default_factory=list)
23
+ global_max_losses: List[float] = field(default_factory=list)
24
+ log_steps: List[int] = field(default_factory=list)
25
+
26
+ def state_dict(self) -> Dict[str, Any]:
27
+ # Only checkpoint global_avg_losses and global_max_losses per log frequency
28
+ # to avoid sync overhead in every iteration.
29
+ global_avg_losses_bytes = BytesIO()
30
+ torch.save(self.global_avg_losses, global_avg_losses_bytes)
31
+ global_max_losses_bytes = BytesIO()
32
+ torch.save(self.global_max_losses, global_max_losses_bytes)
33
+ log_steps_bytes = BytesIO()
34
+ torch.save(self.log_steps, log_steps_bytes)
35
+ return {
36
+ "step": torch.tensor(self.step, dtype=torch.int32),
37
+ "skipped_step": torch.tensor(self.skipped_step, dtype=torch.int32),
38
+ "token": torch.tensor(self.token, dtype=torch.int64),
39
+ "elapsed": self.elapsed,
40
+ "global_avg_losses": global_avg_losses_bytes,
41
+ "global_max_losses": global_max_losses_bytes,
42
+ "log_steps": log_steps_bytes,
43
+ }
44
+
45
+ def load_state_dict(self, state_dict) -> None:
46
+ self.step = state_dict["step"].item()
47
+ self.skipped_step = state_dict.get("skipped_step", 0).item()
48
+ self.token = state_dict["token"].item()
49
+ self.elapsed = state_dict["elapsed"]
50
+ state_dict["global_avg_losses"].seek(0)
51
+ self.global_avg_losses = torch.load(
52
+ state_dict["global_avg_losses"], weights_only=False
53
+ )
54
+ state_dict["global_max_losses"].seek(0)
55
+ self.global_max_losses = torch.load(
56
+ state_dict["global_max_losses"], weights_only=False
57
+ )
58
+ state_dict["log_steps"].seek(0)
59
+ self.log_steps = torch.load(state_dict["log_steps"], weights_only=False)
flame/models/__init__.py ADDED
File without changes
flame/models/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (137 Bytes). View file
 
flame/models/fla.toml ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [model]
2
+ config = "fla-hub/transformer-1.3B-100B"
3
+ tokenizer_path = "fla-hub/transformer-1.3B-100B"
4
+
5
+ [job]
6
+ dump_folder = "exp"
7
+ print_args = true
8
+
9
+ [training]
10
+ batch_size = 32
11
+ seq_len = 2048
12
+ context_len = 2048
13
+ gradient_accumulation_steps = 1
14
+ steps = 20480
15
+ max_norm = 1.0
16
+ skip_nan_inf = true
17
+ data_parallel_replicate_degree = 1
18
+ data_parallel_shard_degree = -1
19
+ tensor_parallel_degree = 1
20
+ compile = false
21
+ dataset = "HuggingFaceFW/fineweb-edu"
22
+ dataset_name = "default"
23
+ num_workers = 32
24
+ pin_memory = false
25
+ persistent_workers = false
26
+ prefetch_factor = 2
27
+ seed = 42
28
+ varlen = false
29
+
30
+ [optimizer]
31
+ name = "AdamW"
32
+ eps = 1e-15
33
+ lr = 3e-4
34
+
35
+ [lr_scheduler]
36
+ warmup_steps = 1024
37
+ decay_type = "cosine"
38
+ lr_min = 0.1
39
+
40
+ [checkpoint]
41
+ enable_checkpoint = true
42
+ folder = "checkpoint"
43
+ interval_type = "steps"
44
+ interval = 2048
45
+ model_weights_only = false
46
+ export_dtype = "float32"
47
+ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
48
+
49
+ [profiling]
50
+ enable_profiling = true
51
+ save_traces_folder = "profile_trace"
52
+ profile_freq = 512
53
+
54
+ [metrics]
55
+ log_freq = 32
56
+ enable_wandb = true
57
+
58
+ [experimental]
59
+ context_parallel_degree = 1
60
+ pipeline_parallel_degree = 1
61
+
62
+ [float8]
63
+ enable_fsdp_float8_all_gather = false
64
+ precompute_float8_dynamic_scale_for_fsdp = false
65
+
66
+ [activation_checkpoint]
67
+ mode = "none"
flame/models/parallelize_fla.py ADDED
@@ -0,0 +1,550 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # This file applies the PT-D parallelisms (except pipeline parallelism) and various
8
+ # training techniques (e.g. activation checkpointing and compile) to the Llama model.
9
+
10
+ from collections import defaultdict
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from torch.distributed import DeviceMesh
15
+ from torch.distributed._composable.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard
16
+ from torch.distributed._composable.replicate import replicate
17
+ from torch.distributed._tensor import Replicate, Shard
18
+ from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper as ptd_checkpoint_wrapper
19
+ from torch.distributed.tensor.parallel import (
20
+ ColwiseParallel,
21
+ PrepareModuleInput,
22
+ PrepareModuleOutput,
23
+ RowwiseParallel,
24
+ SequenceParallel,
25
+ parallelize_module
26
+ )
27
+
28
+ from fla.modules.fused_linear_cross_entropy import LinearLossParallel
29
+ from fla.modules.mlp import SwiGLULinearParallel
30
+ from fla.modules.parallel import PrepareModuleWeight
31
+ from torchtitan.config_manager import TORCH_DTYPE_MAP, JobConfig
32
+ from torchtitan.distributed.parallel_dims import ParallelDims
33
+ from torchtitan.tools.logging import logger
34
+
35
+
36
+ def parallelize_fla(
37
+ model: nn.Module,
38
+ world_mesh: DeviceMesh,
39
+ parallel_dims: ParallelDims,
40
+ job_config: JobConfig,
41
+ ):
42
+ """
43
+ Apply tensor parallelism, activation checkpointing, torch.compile, and data
44
+ parallelism to the model.
45
+
46
+ NOTE: The passed-in model preferably should be on meta device. Otherwise,
47
+ the model must fit on GPU or CPU memory.
48
+ """
49
+
50
+ if parallel_dims.tp_enabled:
51
+ if (
52
+ job_config.experimental.enable_async_tensor_parallel
53
+ and not job_config.training.compile
54
+ ):
55
+ raise RuntimeError("Async TP requires --training.compile")
56
+ enable_float8_linear = "float8" in job_config.model.converters
57
+ apply_tp(
58
+ model,
59
+ world_mesh["tp"],
60
+ loss_parallel=parallel_dims.loss_parallel_enabled,
61
+ enable_float8=enable_float8_linear,
62
+ enable_async_tp=job_config.experimental.enable_async_tensor_parallel,
63
+ )
64
+
65
+ if job_config.activation_checkpoint.mode != "none":
66
+ apply_ac(model, job_config.activation_checkpoint)
67
+
68
+ # turn on per-block compile after AC wrapping and before FSDP
69
+ if job_config.training.compile:
70
+ apply_compile(model)
71
+
72
+ if (
73
+ parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled
74
+ ): # apply FSDP or HSDP, potentially with Context Parallel
75
+ if parallel_dims.dp_replicate_enabled:
76
+ dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
77
+ else:
78
+ dp_mesh_dim_names = ("dp_shard_cp",)
79
+
80
+ apply_fsdp(
81
+ model,
82
+ world_mesh[tuple(dp_mesh_dim_names)],
83
+ param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
84
+ reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
85
+ pp_enabled=parallel_dims.pp_enabled,
86
+ cpu_offload=job_config.training.enable_cpu_offload,
87
+ reshard_after_forward_policy=job_config.training.fsdp_reshard_after_forward,
88
+ )
89
+
90
+ if parallel_dims.dp_replicate_enabled:
91
+ logger.info("Applied HSDP to the model")
92
+ else:
93
+ logger.info("Applied FSDP to the model")
94
+
95
+ if parallel_dims.cp_enabled:
96
+ logger.info("Applied Context Parallel to the model")
97
+
98
+ if job_config.training.enable_cpu_offload:
99
+ logger.info("Applied CPU Offloading to the model")
100
+ elif parallel_dims.dp_replicate_enabled:
101
+ if world_mesh.ndim > 1:
102
+ raise RuntimeError("DDP has not supported > 1D parallelism")
103
+ apply_ddp(
104
+ model,
105
+ world_mesh,
106
+ enable_compile=job_config.training.compile,
107
+ enable_compiled_autograd=job_config.experimental.enable_compiled_autograd,
108
+ )
109
+
110
+
111
+ class TPPlan:
112
+ def __init__(
113
+ self,
114
+ model=None,
115
+ loss_parallel=False,
116
+ enable_float8=False,
117
+ ):
118
+ self.model = model
119
+ self.loss_parallel = loss_parallel
120
+ self.enable_float8 = enable_float8
121
+ self.base_model_prefix = getattr(model, "base_model_prefix", "model")
122
+
123
+ # TODO(vkuzo): once float8 configuration supports delayed scaling,
124
+ # add a check here to enforce supported float8 all-gather configurations
125
+ # TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there
126
+ try:
127
+ from torchao.float8.float8_tensor_parallel import (
128
+ Float8ColwiseParallel,
129
+ Float8RowwiseParallel,
130
+ PrepareFloat8ModuleInput
131
+ )
132
+ except ImportError:
133
+ Float8ColwiseParallel = None
134
+ Float8RowwiseParallel = None
135
+ PrepareFloat8ModuleInput = None
136
+ if self.enable_float8 and Float8ColwiseParallel is not None:
137
+ self.rowwise_parallel = Float8RowwiseParallel
138
+ self.colwise_parallel = Float8ColwiseParallel
139
+ self.prepare_module_input = PrepareFloat8ModuleInput
140
+ self.prepare_module_output = PrepareModuleOutput
141
+ else:
142
+ self.rowwise_parallel = RowwiseParallel
143
+ self.colwise_parallel = ColwiseParallel
144
+ self.prepare_module_input = PrepareModuleInput
145
+ self.prepare_module_output = PrepareModuleOutput
146
+
147
+ @property
148
+ def model_plan(self):
149
+ plans = {
150
+ f"{self.base_model_prefix}.embeddings": RowwiseParallel(
151
+ input_layouts=Replicate(),
152
+ output_layouts=Shard(1),
153
+ ),
154
+ f"{self.base_model_prefix}.norm": SequenceParallel(),
155
+ }
156
+ if self.loss_parallel:
157
+ plans.update(
158
+ {
159
+ "lm_head": ColwiseParallel(
160
+ input_layouts=Shard(1),
161
+ output_layouts=Shard(-1) if self.loss_parallel else Replicate(),
162
+ use_local_output=not self.loss_parallel,
163
+ ),
164
+ }
165
+ )
166
+ else:
167
+ plans.update(
168
+ {
169
+ "lm_head": PrepareModuleWeight(layouts=Replicate()),
170
+ "criterion": LinearLossParallel(),
171
+ }
172
+ )
173
+ return plans
174
+
175
+ @property
176
+ def layer_plan(self):
177
+ return {
178
+ "attn_norm": SequenceParallel(),
179
+ **self.attn_plan,
180
+ "mlp_norm": SequenceParallel(),
181
+ **self.mlp_plan,
182
+ }
183
+
184
+ @property
185
+ def attn_plan(self):
186
+ raise NotImplementedError(
187
+ f"TP plans for token mixing layers of {self.model.config.model_type} not implemented"
188
+ )
189
+
190
+ @property
191
+ def mlp_plan(self):
192
+ return {
193
+ "mlp": self.prepare_module_input(
194
+ input_layouts=(Shard(1),),
195
+ desired_input_layouts=(Replicate(),),
196
+ ),
197
+ "mlp.gate_proj": self.colwise_parallel(),
198
+ "mlp.up_proj": self.colwise_parallel(),
199
+ "mlp.down_proj": self.rowwise_parallel(output_layouts=Shard(1)),
200
+ "mlp.swiglu_linear": SwiGLULinearParallel(output_layouts=Shard(1)),
201
+ }
202
+
203
+
204
+ class TransformerTPPlan(TPPlan):
205
+
206
+ @property
207
+ def attn_plan(self):
208
+ return {
209
+ "attn": self.prepare_module_input(
210
+ input_kwarg_layouts={"hidden_states": Shard(1)},
211
+ desired_input_kwarg_layouts={"hidden_states": Replicate()},
212
+ ),
213
+ "attn.q_proj": self.colwise_parallel(),
214
+ "attn.k_proj": self.colwise_parallel(),
215
+ "attn.v_proj": self.colwise_parallel(),
216
+ "attn.o_proj": self.rowwise_parallel(output_layouts=Shard(1)),
217
+ }
218
+
219
+
220
+ class GLATPPlan(TPPlan):
221
+
222
+ @property
223
+ def attn_plan(self):
224
+ return {
225
+ "attn": self.prepare_module_input(
226
+ input_kwarg_layouts={"hidden_states": Shard(1)},
227
+ desired_input_kwarg_layouts={"hidden_states": Replicate()},
228
+ ),
229
+ "attn.q_proj": self.colwise_parallel(),
230
+ "attn.k_proj": self.colwise_parallel(),
231
+ "attn.v_proj": self.colwise_parallel(),
232
+ "attn.g_proj": self.colwise_parallel(),
233
+ "attn.gk_proj.0": PrepareModuleWeight(layouts=Replicate()),
234
+ "attn.gk_proj.1": self.colwise_parallel(),
235
+ "attn.g_norm": SequenceParallel(sequence_dim=-1),
236
+ "attn.o_proj": self.rowwise_parallel(output_layouts=Shard(1)),
237
+ }
238
+
239
+
240
+ TP_PLAN_MAP = {"transformer": TransformerTPPlan, "gla": GLATPPlan}
241
+
242
+
243
+ def apply_tp(
244
+ model: nn.Module,
245
+ tp_mesh: DeviceMesh,
246
+ loss_parallel: bool,
247
+ enable_float8: bool,
248
+ enable_async_tp: bool,
249
+ ):
250
+ """Apply tensor parallelism."""
251
+ # 1. Parallelize the embedding and shard its outputs (which are the first
252
+ # transformer block's inputs)
253
+ # 2. Parallelize the root norm layer over the sequence dim
254
+ # 3. Parallelize the final linear output layer
255
+ tp_plan = TP_PLAN_MAP[model.config.model_type](
256
+ model, loss_parallel=loss_parallel, enable_float8=enable_float8
257
+ )
258
+ parallelize_module(model, tp_mesh, tp_plan.model_plan)
259
+
260
+ blocks = get_blocks(model)
261
+ if blocks is None:
262
+ logger.warning("No block found for tensor parallelism")
263
+ else:
264
+ for _, block in enumerate(blocks):
265
+ parallelize_module(
266
+ module=block,
267
+ device_mesh=tp_mesh,
268
+ parallelize_plan=tp_plan.layer_plan,
269
+ )
270
+
271
+ if enable_async_tp:
272
+ from torch.distributed._symmetric_memory import enable_symm_mem_for_group
273
+
274
+ torch._inductor.config._micro_pipeline_tp = True
275
+ enable_symm_mem_for_group(tp_mesh.get_group().group_name)
276
+
277
+ logger.info(
278
+ f"Applied {'Float8 ' if enable_float8 else ''}{'Async ' if enable_async_tp else ''}"
279
+ "Tensor Parallelism to the model"
280
+ )
281
+
282
+
283
+ # for selective op activation checkpointing
284
+ _save_list = {
285
+ torch.ops.aten.mm.default,
286
+ torch.ops.aten._scaled_dot_product_efficient_attention.default,
287
+ torch.ops.aten._scaled_dot_product_flash_attention.default,
288
+ torch.ops._c10d_functional.reduce_scatter_tensor.default,
289
+ # for low precision training, it's useful to always save
290
+ # the result of max, since the absolute maximum is
291
+ # used to compute the scaling factor for quantization.
292
+ torch.ops.aten.max.default,
293
+ }
294
+
295
+
296
+ def _apply_ac_to_block(module: nn.Module, ac_config):
297
+ valid_ac_modes = ("full", "selective")
298
+ if ac_config.mode not in valid_ac_modes:
299
+ raise ValueError(
300
+ f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}"
301
+ )
302
+
303
+ if ac_config.mode == "full":
304
+ return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
305
+
306
+ assert ac_config.mode == "selective", f"{ac_config.mode}"
307
+ use_op_sac = ac_config.selective_ac_option == "op"
308
+ use_layer_sac = ac_config.selective_ac_option.isdigit()
309
+ if not use_op_sac and not use_layer_sac:
310
+ raise ValueError(
311
+ f"Invalid selective AC option: {ac_config.selective_ac_option}. "
312
+ f"Valid options: 'op' or a positive int representing layer frequency"
313
+ )
314
+ if use_op_sac:
315
+ from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts
316
+
317
+ def _get_custom_policy(meta):
318
+ def _custom_policy(ctx, func, *args, **kwargs):
319
+ mode = "recompute" if ctx.is_recompute else "forward"
320
+ mm_count_key = f"{mode}_mm_count"
321
+ if func == torch.ops.aten.mm.default:
322
+ meta[mm_count_key] += 1
323
+ # Saves output of all compute ops, except every second mm
324
+ to_save = func in _save_list and not (
325
+ func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
326
+ )
327
+ return (
328
+ CheckpointPolicy.MUST_SAVE
329
+ if to_save
330
+ else CheckpointPolicy.PREFER_RECOMPUTE
331
+ )
332
+
333
+ return _custom_policy
334
+
335
+ def selective_checkpointing_context_fn():
336
+ meta = defaultdict(int)
337
+ return create_selective_checkpoint_contexts(_get_custom_policy(meta))
338
+
339
+ return ptd_checkpoint_wrapper(
340
+ module,
341
+ context_fn=selective_checkpointing_context_fn,
342
+ preserve_rng_state=False,
343
+ )
344
+ elif use_layer_sac:
345
+ # Checkpoint every `ac_freq` of the modules passed to this function
346
+ ac_freq = int(ac_config.selective_ac_option)
347
+ ptd_checkpoint_wrapper.__dict__.setdefault("_count", 0)
348
+ ptd_checkpoint_wrapper._count += 1
349
+ if not ac_freq or ptd_checkpoint_wrapper._count % ac_freq == 0:
350
+ return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
351
+ else:
352
+ return module
353
+
354
+
355
+ def apply_ac(model: nn.Module, ac_config):
356
+ """Apply activation checkpointing to the model."""
357
+ blocks = get_blocks(model)
358
+ if blocks is None:
359
+ logger.warning("No block found for activation checkpointing")
360
+ return
361
+
362
+ for layer_id, block in blocks.named_children():
363
+ block = _apply_ac_to_block(block, ac_config)
364
+ blocks.register_module(layer_id, block)
365
+
366
+ logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
367
+
368
+
369
+ def apply_compile(model: nn.Module):
370
+ """
371
+ Apply torch.compile to each block, which makes compilation efficient due to
372
+ repeated structure. Alternatively one can compile the whole model (after applying DP).
373
+ """
374
+
375
+ blocks = get_blocks(model)
376
+ if blocks is None:
377
+ logger.warning("No block found for torch.compile")
378
+ else:
379
+ for layer_id, block in blocks.named_children():
380
+ block = torch.compile(block)
381
+ blocks.register_module(layer_id, block)
382
+ logger.info("Compiling each block with torch.compile")
383
+
384
+ real_model = get_model(model)
385
+
386
+ logger.info("Compiling the embedding, norm, and lm_head layers with torch.compile")
387
+ embeddings_key = get_components_name(real_model, "tok_embeddings")
388
+ if embeddings_key is not None:
389
+ embeddings = torch.compile(getattr(real_model, embeddings_key), fullgraph=True)
390
+ real_model.register_module(embeddings_key, embeddings)
391
+
392
+ norm_key = get_components_name(real_model, "norm")
393
+ if norm_key is not None:
394
+ norm = torch.compile(getattr(real_model, norm_key), fullgraph=True)
395
+ real_model.register_module(norm_key, norm)
396
+
397
+ lm_head_key = get_components_name(model, "lm_head")
398
+ if lm_head_key is not None:
399
+ lm_head = torch.compile(getattr(model, lm_head_key), fullgraph=True)
400
+ model.register_module(lm_head_key, lm_head)
401
+
402
+ logger.info("Compiling the entire model with torch.compile")
403
+ model = torch.compile(model)
404
+
405
+
406
+ def apply_fsdp(
407
+ model: nn.Module,
408
+ dp_mesh: DeviceMesh,
409
+ param_dtype: torch.dtype,
410
+ reduce_dtype: torch.dtype,
411
+ pp_enabled: bool,
412
+ cpu_offload: bool = False,
413
+ reshard_after_forward_policy: str = "default",
414
+ ):
415
+ """
416
+ Apply data parallelism (via FSDP2) to the model.
417
+
418
+ Args:
419
+ model (nn.Module): The model to apply data parallelism to.
420
+ dp_mesh (DeviceMesh): The device mesh to use for data parallelism.
421
+ param_dtype (torch.dtype): The data type to use for model parameters.
422
+ reduce_dtype (torch.dtype): The data type to use for reduction operations.
423
+ pp_enabled (bool): Whether pipeline parallelism is enabled.
424
+ cpu_offload (bool, optional): Whether to offload model parameters to CPU. Defaults to False.
425
+ reshard_after_forward_policy (str, optional):
426
+ The policy to use for resharding after forward pass. Defaults to "default".
427
+ Other options: "never", "always".
428
+ - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios.
429
+ - "always" will enable `reshard_after_forward` for all forward passes.
430
+ - "never" will disable `reshard_after_forward` for all forward passes.
431
+
432
+ """
433
+ mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
434
+ fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
435
+ if cpu_offload:
436
+ fsdp_config["offload_policy"] = CPUOffloadPolicy()
437
+
438
+ blocks = get_blocks(model)
439
+ if blocks is None:
440
+ logger.warning("No block found for FSDP")
441
+ else:
442
+ total_blocks = len(blocks)
443
+ for layer_id, block in enumerate(blocks):
444
+ if reshard_after_forward_policy == "always":
445
+ reshard_after_forward = True
446
+ elif reshard_after_forward_policy == "never":
447
+ reshard_after_forward = False
448
+ elif reshard_after_forward_policy == "default":
449
+ if pp_enabled:
450
+ # For PP, do not reshard after forward to avoid per-microbatch
451
+ # all-gathers, which can be expensive and non-overlapped
452
+ reshard_after_forward = False
453
+ else:
454
+ # As an optimization, do not reshard after forward for the last
455
+ # transformer block since FSDP would prefetch it immediately
456
+ reshard_after_forward = int(layer_id) < total_blocks - 1
457
+ else:
458
+ raise ValueError(
459
+ f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}."
460
+ )
461
+ fully_shard(
462
+ block,
463
+ **fsdp_config,
464
+ reshard_after_forward=reshard_after_forward,
465
+ )
466
+
467
+ fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)
468
+
469
+
470
+ def apply_ddp(
471
+ model: nn.Module,
472
+ dp_mesh: DeviceMesh,
473
+ enable_compile: bool,
474
+ enable_compiled_autograd: bool,
475
+ ):
476
+ if enable_compile:
477
+ if enable_compiled_autograd:
478
+ torch._dynamo.config.optimize_ddp = (
479
+ "python_reducer_without_compiled_forward"
480
+ )
481
+ else:
482
+ torch._dynamo.config.optimize_ddp = "ddp_optimizer"
483
+
484
+ replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
485
+
486
+ logger.info("Applied DDP to the model")
487
+
488
+
489
+ def get_model(model):
490
+ base_model_prefix = getattr(model, "base_model_prefix", "model")
491
+ if not hasattr(model, base_model_prefix):
492
+ return None
493
+ model = getattr(model, base_model_prefix)
494
+ return model
495
+
496
+
497
+ def get_blocks(model):
498
+ # TODO[flame]: adapt for network not using 'layers' attribute
499
+ model = get_model(model)
500
+ if not hasattr(model, "layers"):
501
+ logger.warning('no "layers" in model can be found')
502
+ return None
503
+ return model.layers
504
+
505
+
506
+ def get_components_name(model, component_name):
507
+ """
508
+ We try to catch tok_embeddings, norm layers and lm_head layers
509
+ We do not catch the layer names in the blocks, for blocks see `get_blocks`
510
+ We assume the model has the following structure:
511
+ LlamaForCausalLM:
512
+ Model:
513
+ embed_tokens,
514
+ layers,
515
+ norm,
516
+ lm_head
517
+ ***
518
+ so, to search 'tok_embeddings' and 'norm' we need to pass `get_model(model)`
519
+ and for 'lm_head' we need to pass `model`
520
+ ***
521
+ """
522
+
523
+ if component_name == "tok_embeddings":
524
+ if hasattr(model, "tok_embeddings"):
525
+ return "tok_embeddings"
526
+ elif hasattr(model, "embed_tokens"):
527
+ return "embed_tokens"
528
+ elif hasattr(model, "embeddings"):
529
+ return "embeddings"
530
+ else:
531
+ logger.warning("No tok_embeddings found in model")
532
+ return None
533
+
534
+ elif component_name == "norm":
535
+ if hasattr(model, "norm"):
536
+ return "norm"
537
+ elif hasattr(model, "norms"):
538
+ return "norms"
539
+ elif hasattr(model, "layernorm"):
540
+ return "layernorm"
541
+ else:
542
+ logger.warning("No norm found in model")
543
+ return None
544
+
545
+ elif component_name == "lm_head":
546
+ if hasattr(model, "lm_head"):
547
+ return "lm_head"
548
+ else:
549
+ logger.warning("No lm_head found in model")
550
+ return None
flame/models/pipeline_fla.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # This file applies the PT-D pipeline parallelism to the Llama model.
8
+
9
+ import copy
10
+ from typing import Callable, Optional, Union
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from torch.distributed import DeviceMesh
15
+ from torch.distributed.pipelining import PipelineStage
16
+ from torch.distributed.pipelining.schedules import ScheduleZBVZeroBubble, _PipelineSchedule, get_schedule_class
17
+ from transformers import PretrainedConfig
18
+
19
+ from flame.models.parallelize_fla import get_blocks, get_components_name, get_model
20
+ from torchtitan.config_manager import JobConfig
21
+ from torchtitan.distributed.parallel_dims import ParallelDims
22
+ from torchtitan.distributed.pipeline import build_pipeline_schedule, generate_split_points, stage_ids_this_rank
23
+ from torchtitan.tools.logging import logger
24
+
25
+ DeviceType = Union[int, str, torch.device]
26
+
27
+
28
+ def pipeline_fla(
29
+ model: nn.Module,
30
+ pp_mesh: DeviceMesh,
31
+ parallel_dims: ParallelDims,
32
+ job_config: JobConfig,
33
+ device: DeviceType,
34
+ model_config: PretrainedConfig,
35
+ loss_fn: Callable[..., torch.Tensor],
36
+ ) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]:
37
+ stages, models = pipeline_fla_manual_split(
38
+ model, pp_mesh, parallel_dims, job_config, device, model_config
39
+ )
40
+
41
+ pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)
42
+
43
+ # This is used in the train loop to determine whether to pass in the input_ids and labels
44
+ has_first_stage = False
45
+ has_last_stage = False
46
+ for stage in stages:
47
+ if stage.is_first:
48
+ has_first_stage = True
49
+ if stage.is_last:
50
+ has_last_stage = True
51
+
52
+ return pp_schedule, models, has_first_stage, has_last_stage
53
+
54
+
55
+ def pipeline_fla_manual_split(
56
+ whole_model: nn.Module,
57
+ pp_mesh: DeviceMesh,
58
+ parallel_dims: ParallelDims,
59
+ job_config: JobConfig,
60
+ device: DeviceType,
61
+ model_config: PretrainedConfig,
62
+ ) -> tuple[list[PipelineStage], list[nn.Module]]:
63
+ """
64
+ This API extracts one torch.nn.Module objects for the part of the model configured to run inside this stage.
65
+
66
+ It wraps the model chunk in a ManualPipelineStage object and returns both the stage and model objects.
67
+
68
+ The stage object is used to create a pipeline schedule, and the model object can be used for applying SPMD
69
+ parallelism.
70
+ """
71
+ pp_rank = pp_mesh.get_local_rank()
72
+ pp_size = pp_mesh.size()
73
+
74
+ splits = (
75
+ job_config.experimental.pipeline_parallel_split_points
76
+ or generate_split_points(
77
+ job_config, parallel_dims.pp, model_config.num_hidden_layers
78
+ )
79
+ )
80
+
81
+ def _build_stage(
82
+ stage_idx: int,
83
+ start_layer: Optional[str],
84
+ stop_layer: Optional[str],
85
+ is_first: bool = False,
86
+ is_last: bool = False,
87
+ ) -> tuple[PipelineStage, nn.Module]:
88
+ model = copy.deepcopy(whole_model)
89
+ if not is_first:
90
+ # we do `model.tok_embeddings = None` here
91
+ real_model = get_model(model)
92
+ tok_embeddings_name = get_components_name(real_model, "tok_embeddings")
93
+ setattr(real_model, tok_embeddings_name, None)
94
+
95
+ drop_layers = start_layer is not None
96
+ # Get module dictionary from get_blocks(model)
97
+ # and Create a list of keys before modifying dictionary
98
+ module_dict = get_blocks(model)._modules # Store reference
99
+ layer_names = list(module_dict.keys())
100
+
101
+ # Iterate over the list of keys instead of `_modules.items()`
102
+ for name in layer_names:
103
+ # Dynamically determine prefix (blocks.* or layers.*)
104
+ prefix = start_layer.split(".")[0] if start_layer else "layers"
105
+ layer_name = f"{prefix}.{name}" # Construct the correct name format
106
+
107
+ # Ensure `drop_layers` activation is based on actual naming
108
+ if layer_name == start_layer:
109
+ drop_layers = False
110
+ if layer_name == stop_layer:
111
+ drop_layers = True
112
+
113
+ # Delete layer if drop_layers is active
114
+ if drop_layers:
115
+ del module_dict[name] # Safe deletion from stored dictionary
116
+
117
+ if not is_last:
118
+ # we do `model.norm = None` and `model.output = None`
119
+ real_model = get_model(model)
120
+ norm_name = get_components_name(real_model, "norm")
121
+ setattr(real_model, norm_name, None)
122
+
123
+ head_name = get_components_name(model, "lm_head")
124
+ setattr(model, head_name, None)
125
+
126
+ stage = PipelineStage(
127
+ model,
128
+ stage_idx,
129
+ num_stages,
130
+ device,
131
+ group=pp_mesh.get_group("pp"),
132
+ )
133
+ return stage, model
134
+
135
+ num_stages = len(splits) + 1
136
+ stage_idx = pp_rank
137
+
138
+ stages = []
139
+ models = []
140
+
141
+ schedule_class = get_schedule_class(
142
+ job_config.experimental.pipeline_parallel_schedule
143
+ )
144
+ style = "v" if schedule_class == ScheduleZBVZeroBubble else "loop"
145
+
146
+ for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style=style):
147
+ start_layer = splits[stage_idx - 1] if stage_idx > 0 else None
148
+ stop_layer = splits[stage_idx] if stage_idx < num_stages - 1 else None
149
+ stage, model_chunk = _build_stage(
150
+ stage_idx,
151
+ start_layer,
152
+ stop_layer,
153
+ is_first=stage_idx == 0,
154
+ is_last=stage_idx == num_stages - 1,
155
+ )
156
+ logger.info(
157
+ f"PP rank {pp_rank} is building stage_idx {stage_idx}"
158
+ f" with start_layer {start_layer}, stop_layer {stop_layer}"
159
+ )
160
+ stages.append(stage)
161
+ models.append(model_chunk)
162
+ return stages, models
flame/tools/__pycache__/utils.cpython-312.pyc ADDED
Binary file (2.14 kB). View file
 
flame/tools/utils.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from torch import nn
8
+ from torchtitan.tools.logging import logger
9
+
10
+
11
def get_nparams_and_flops(model: nn.Module, model_config, seq_len: int) -> tuple[int, int]:
    """Estimate model size and per-token training FLOPs.

    Args:
        model: the model whose parameters are counted.
        model_config: HF-style config; must provide ``num_hidden_layers`` and
            ``hidden_size``, and ideally ``num_heads`` or ``num_attention_heads``.
        seq_len: training sequence length, used for the attention FLOPs term.

    Returns:
        A tuple ``(nparams, num_flops_per_token)``.
    """
    nparams = sum(p.numel() for p in model.parameters())
    # Scan ALL submodules, not just direct children: HF-style causal-LM models
    # nest the token embedding (e.g. under `model.embed_tokens`), which
    # `model.children()` would miss, silently overcounting the 6N term below.
    nparams_embedding = sum(
        sum(p.numel() for p in m.parameters())
        for m in model.modules()
        if isinstance(m, nn.Embedding)
    )

    if hasattr(model_config, "num_heads"):
        num_heads = model_config.num_heads
    elif hasattr(model_config, "num_attention_heads"):
        num_heads = model_config.num_attention_heads
    else:
        num_heads = 1
        logger.warning("num_heads not found in model_config, defaulting to 1. ")

    l, h, q, t = (
        model_config.num_hidden_layers,
        num_heads,
        model_config.hidden_size // num_heads,
        seq_len,
    )
    # Reasoning behind the factor of 12 for the self-attention part of the formula:
    # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6)
    # 2. the flash attention does 1 more matmul recomputation in the backward
    #    but recomputation should not be counted in calculating MFU (+0)
    # 3. each matmul performs 1 multiplication and 1 addition (*2)
    # 4. we follow the convention and do not account for sparsity in causal attention
    num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t

    return nparams, num_flops_per_token
flame/utils/__init__.py ADDED
File without changes
flame/utils/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (136 Bytes). View file
 
flame/utils/__pycache__/convert_dcp_to_hf.cpython-312.pyc ADDED
Binary file (3.73 kB). View file
 
flame/utils/convert_dcp_to_hf.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ import argparse
5
+ import io
6
+ import os
7
+ import tempfile
8
+ from datetime import timedelta
9
+
10
+ import torch
11
+ import torch.serialization
12
+ from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
13
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
14
+
15
+ import fla # noqa
16
+ from torchtitan.tools.logging import init_logger, logger
17
+
18
+
19
@torch.inference_mode()
def save_pretrained(
    path: str,
    step: int,
    config: str,
    tokenizer: str
):
    """Export a distributed (DCP) training checkpoint as a HF model directory.

    The config and tokenizer are copied into `path`, the DCP shards under
    `path`/checkpoint/step-{step} are consolidated into a single torch
    checkpoint inside a temporary directory, and the resulting weights are
    written back to `path` via `save_pretrained`.
    """
    logger.info(f"Loading the config from {config}")
    hf_config = AutoConfig.from_pretrained(config, trust_remote_code=True)

    logger.info(f"Saving the config to {path}")
    hf_config.save_pretrained(path)
    logger.info(f"Loading the tokenizer from {tokenizer}")
    hf_tokenizer = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True)
    logger.info(f"Saving the tokenizer to {path}")
    hf_tokenizer.save_pretrained(path)

    with tempfile.TemporaryDirectory() as staging_dir:
        dcp_dir = os.path.join(path, f'checkpoint/step-{step}')
        torch_ckpt = os.path.join(staging_dir, 'checkpoint.pt')
        logger.info(f"Saving the distributed checkpoint to {torch_ckpt}")
        dcp_to_torch_save(dcp_dir, torch_ckpt)

        logger.info(f"Initializing the model from config\n{hf_config}")
        model = AutoModelForCausalLM.from_config(hf_config)
        logger.info(model)
        logger.info("Loading state dict from the checkpoint")

        # Allowlist the non-tensor types stored in the checkpoint so that
        # torch.load with its default weights_only=True can deserialize it.
        torch.serialization.add_safe_globals([timedelta, io.BytesIO])
        model.load_state_dict(torch.load(torch_ckpt, map_location='cpu')['model'])

        logger.info(f"Saving the model to {path}")
        model.save_pretrained(path)
56
+
57
+
58
if __name__ == "__main__":
    init_logger()
    parser = argparse.ArgumentParser("Convert DCP format model weights to huggingface-style.")
    # All four arguments are mandatory; see `save_pretrained` for their meaning.
    for flag, flag_type in (
        ("--path", str),
        ("--step", int),
        ("--config", str),
        ("--tokenizer", str),
    ):
        parser.add_argument(flag, type=flag_type, required=True)
    args = parser.parse_args()
    save_pretrained(args.path, args.step, args.config, args.tokenizer)
flame/utils/convert_hf_to_dcp.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ import argparse
5
+ from pathlib import Path
6
+
7
+ import torch
8
+ import torch.distributed.checkpoint as DCP
9
+ from transformers import AutoModelForCausalLM
10
+
11
+ import fla # noqa
12
+ from torchtitan.tools.logging import init_logger, logger
13
+
14
+
15
@torch.inference_mode()
def convert_hf_weights(model: str, checkpoint: Path) -> None:
    """Convert a huggingface-style model into a DCP checkpoint directory.

    Args:
        model: model name or local path accepted by ``from_pretrained``.
        checkpoint: output directory for the DCP shards (created if missing).
            Note: this must be a ``pathlib.Path`` — ``.mkdir`` is called on it;
            the CLI below passes ``type=Path``.
    """
    logger.info(f"Loading model from {model}")
    # `model` is rebound from the name/path string to the loaded module.
    model = AutoModelForCausalLM.from_pretrained(model)
    state_dict = model.state_dict()

    logger.info(f"Writing to DCP at '{checkpoint}'")
    checkpoint.mkdir(parents=True, exist_ok=True)
    # thread_count=8: parallelize shard writing for large state dicts.
    storage_writer = DCP.filesystem.FileSystemWriter(checkpoint, thread_count=8)
    DCP.save({"model": state_dict}, storage_writer=storage_writer)
25
+
26
+
27
if __name__ == "__main__":
    init_logger()
    parser = argparse.ArgumentParser(
        description="Convert huggingface-style model weights to DCP format."
    )
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--checkpoint", type=Path, required=True)
    cli_args = parser.parse_args()

    convert_hf_weights(cli_args.model, cli_args.checkpoint)
flame/utils/hf_utils.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ from huggingface_hub import HfApi, HfFolder, logging as hf_logging, create_repo
4
+ from torchtitan.tools.logging import logger
5
+
6
def upload_checkpoint_to_hf(
    local_path: str,
    step: int,
    hf_repo_id_for_run: str,
    hf_keep_latest_k: int,
    upload_format: str
) -> None:
    """Uploads a checkpoint directory to HF Hub and manages retention.

    Args:
        local_path: Directory containing the checkpoint files to upload.
        step: Training step number; used as the remote folder name ``step-{step}``.
        hf_repo_id_for_run: Target Hub repository id (created if missing).
        hf_keep_latest_k: If > 0, delete all but the newest k ``step-*`` folders
            on the Hub after the upload attempt.
        upload_format: Label used only in the commit message (e.g. "dcp" or "hf").
    """
    if not os.path.isdir(local_path):
        logger.error(f"Local path for upload does not exist or is not a directory: {local_path}")
        return

    api = HfApi()
    token = HfFolder.get_token()
    if not token:
        logger.warning("Hugging Face Hub token not found. Skipping upload. Login via `huggingface-cli login` or set HF_TOKEN.")
        return

    # --- Ensure the specific repository for this run exists ---
    try:
        logger.info(f"Ensuring repository {hf_repo_id_for_run} exists...")
        # Use create_repo which handles creation only if it doesn't exist
        create_repo(repo_id=hf_repo_id_for_run, token=token, repo_type="model", exist_ok=True)
        logger.info(f"Repository {hf_repo_id_for_run} ensured.")
    except Exception as e:
        logger.error(f"Failed to create or ensure repository {hf_repo_id_for_run}: {e}", exc_info=True)
        return # Stop if repo interaction fails

    commit_message = f"Upload {upload_format.upper()} checkpoint step {step}"
    path_in_repo = f"step-{step}"

    logger.info(f"Uploading {local_path} to {hf_repo_id_for_run}/{path_in_repo} on Hugging Face Hub...")
    try:
        api.upload_folder(
            folder_path=local_path,
            path_in_repo=path_in_repo,
            repo_id=hf_repo_id_for_run,
            repo_type="model",
            commit_message=commit_message,
            token=token,
        )
        logger.info(f"Successfully uploaded step {step} to {hf_repo_id_for_run}.")
    except Exception as e:
        logger.error(f"Failed to upload checkpoint step {step} to {hf_repo_id_for_run}: {e}", exc_info=True)
    # NOTE(review): the retention cleanup below runs even when the upload above
    # failed — confirm this is intended (a failed upload can still delete older
    # step-* folders on the Hub).
    if hf_keep_latest_k > 0:
        logger.info(f"Cleaning up old checkpoints on {hf_repo_id_for_run}, keeping latest {hf_keep_latest_k}")
        try:
            repo_files = api.list_repo_tree(hf_repo_id_for_run, repo_type="model", token=token, recursive=False)
            # Top-level entries named "step-<int>" are checkpoint snapshots.
            step_folders = [
                item.path for item in repo_files
                if item.path.startswith("step-") and item.path[5:].isdigit()
            ]

            # Newest first, so the slice beyond k is what gets deleted.
            step_folders.sort(key=lambda x: int(x.split('-')[1]), reverse=True)

            if len(step_folders) > hf_keep_latest_k:
                folders_to_delete = step_folders[hf_keep_latest_k:]
                logger.info(f"Found {len(step_folders)} checkpoints on Hub. Deleting {len(folders_to_delete)} older ones: {folders_to_delete}")
                for folder in folders_to_delete:
                    # Deleting requires repo_id, path_in_repo, and token
                    api.delete_folder(
                        repo_id=hf_repo_id_for_run,
                        path_in_repo=folder,
                        repo_type="model",
                        commit_message=f"Delete old checkpoint {folder}",
                        token=token
                    )
                logger.info("Hub cleanup complete.")
            else:
                logger.info("No old checkpoints found on Hub to delete.")
        except Exception as e:
            logger.error(f"Error during Hub checkpoint cleanup for {hf_repo_id_for_run}: {e}", exc_info=True)
logs/none_g37i6vbo/attempt_0/6/stderr.log ADDED
The diff for this file is too large to render. See raw diff
 
logs/none_lyv0rec_/attempt_0/0/stdout.log ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2025-09-10T00:25:50.402942Z  WARN Status Code: 502. Retrying..., request_id: ""
2
+ at /home/runner/work/xet-core/xet-core/cas_client/src/http_client.rs:220
3
+
4
+ 2025-09-10T00:25:50.448322Z  WARN Status Code: 502. Retrying..., request_id: ""
5
+ at /home/runner/work/xet-core/xet-core/cas_client/src/http_client.rs:220
6
+
7
+ 2025-09-10T00:26:01.892901Z  WARN Status Code: 504. Retrying..., request_id: ""
8
+ at /home/runner/work/xet-core/xet-core/cas_client/src/http_client.rs:220
9
+
10
+ 2025-09-10T00:26:01.894451Z  WARN Status Code: 504. Retrying..., request_id: ""
11
+ at /home/runner/work/xet-core/xet-core/cas_client/src/http_client.rs:220
12
+
13
+ 2025-09-10T00:26:46.358405Z  WARN Status Code: 504. Retrying..., request_id: ""
14
+ at /home/runner/work/xet-core/xet-core/cas_client/src/http_client.rs:220
15
+
16
+ 2025-09-10T00:26:50.304225Z  WARN Status Code: 502. Retrying..., request_id: ""
17
+ at /home/runner/work/xet-core/xet-core/cas_client/src/http_client.rs:220
18
+
19
+ 2025-09-10T00:27:00.830860Z  WARN Status Code: 504. Retrying..., request_id: ""
20
+ at /home/runner/work/xet-core/xet-core/cas_client/src/http_client.rs:220
21
+
22
+ 2025-09-10T00:28:33.662622Z  WARN Status Code: 502. Retrying..., request_id: ""
23
+ at /home/runner/work/xet-core/xet-core/cas_client/src/http_client.rs:220
24
+
25
+ 2025-09-10T00:37:21.678500Z  WARN Status Code: 502. Retrying..., request_id: ""
26
+ at /home/runner/work/xet-core/xet-core/cas_client/src/http_client.rs:220
27
+
28
+ 2025-09-10T00:37:33.396089Z  WARN Status Code: 504. Retrying..., request_id: ""
29
+ at /home/runner/work/xet-core/xet-core/cas_client/src/http_client.rs:220
30
+
31
+ 2025-09-10T00:38:21.672469Z  WARN Status Code: 502. Retrying..., request_id: ""
32
+ at /home/runner/work/xet-core/xet-core/cas_client/src/http_client.rs:220
33
+
logs/none_lyv0rec_/attempt_0/7/stderr.log ADDED
The diff for this file is too large to render. See raw diff
 
logs/none_lyv0rec_/attempt_0/7/stdout.log ADDED
File without changes
tb/20250909-0619/wandb/debug.log ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2025-09-09 06:19:19,818 INFO MainThread:795439 [wandb_setup.py:_flush():80] Current SDK version is 0.21.0
2
+ 2025-09-09 06:19:19,818 INFO MainThread:795439 [wandb_setup.py:_flush():80] Configure stats pid to 795439
3
+ 2025-09-09 06:19:19,818 INFO MainThread:795439 [wandb_setup.py:_flush():80] Loading settings from /home/cvm/.config/wandb/settings
4
+ 2025-09-09 06:19:19,818 INFO MainThread:795439 [wandb_setup.py:_flush():80] Loading settings from /home/cvm/flame/wandb/settings
5
+ 2025-09-09 06:19:19,818 INFO MainThread:795439 [wandb_setup.py:_flush():80] Loading settings from environment variables
6
+ 2025-09-09 06:19:19,818 INFO MainThread:795439 [wandb_init.py:setup_run_log_directory():703] Logging user logs to exp/top.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine/tb/20250909-0619/wandb/run-20250909_061919-top_transformer-top.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509090614/logs/debug.log
7
+ 2025-09-09 06:19:19,818 INFO MainThread:795439 [wandb_init.py:setup_run_log_directory():704] Logging internal logs to exp/top.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine/tb/20250909-0619/wandb/run-20250909_061919-top_transformer-top.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509090614/logs/debug-internal.log
8
+ 2025-09-09 06:19:19,818 INFO MainThread:795439 [wandb_init.py:init():830] calling init triggers
9
+ 2025-09-09 06:19:19,818 INFO MainThread:795439 [wandb_init.py:init():835] wandb.init called with sweep_config: {}
10
+ config: {'_wandb': {}}
11
+ 2025-09-09 06:19:19,818 INFO MainThread:795439 [wandb_init.py:init():871] starting backend
12
+ 2025-09-09 06:19:20,025 INFO MainThread:795439 [wandb_init.py:init():874] sending inform_init request
13
+ 2025-09-09 06:19:20,027 INFO MainThread:795439 [wandb_init.py:init():882] backend started and connected
14
+ 2025-09-09 06:19:20,033 INFO MainThread:795439 [wandb_init.py:init():953] updated telemetry
15
+ 2025-09-09 06:19:20,039 INFO MainThread:795439 [wandb_init.py:init():977] communicating run to backend with 90.0 second timeout
16
+ 2025-09-09 06:19:20,682 INFO MainThread:795439 [wandb_init.py:init():1029] starting run threads in backend
17
+ 2025-09-09 06:19:20,815 INFO MainThread:795439 [wandb_run.py:_console_start():2458] atexit reg
18
+ 2025-09-09 06:19:20,815 INFO MainThread:795439 [wandb_run.py:_redirect():2306] redirect: wrap_raw
19
+ 2025-09-09 06:19:20,815 INFO MainThread:795439 [wandb_run.py:_redirect():2375] Wrapping output streams.
20
+ 2025-09-09 06:19:20,815 INFO MainThread:795439 [wandb_run.py:_redirect():2398] Redirects installed.
21
+ 2025-09-09 06:19:20,817 INFO MainThread:795439 [wandb_init.py:init():1075] run started, returning control to user process
tb/20250909-0619/wandb/run-20250909_061919-top_transformer-top.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509090614/files/output.log ADDED
The diff for this file is too large to render. See raw diff
 
tb/20250909-0619/wandb/run-20250909_061919-top_transformer-top.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509090614/files/requirements.txt ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ flame==0.1.0
2
+ pluggy==1.6.0
3
+ triton==3.2.0
4
+ sympy==1.13.1
5
+ wcwidth==0.2.13
6
+ nvidia-cusolver-cu12==11.6.1.9
7
+ peft==0.17.0
8
+ smart_open==7.3.0.post1
9
+ cymem==2.0.11
10
+ spacy-legacy==3.0.12
11
+ h11==0.16.0
12
+ pytablewriter==1.2.1
13
+ idna==3.10
14
+ regex==2025.7.34
15
+ antlr4-python3-runtime==4.13.2
16
+ wandb==0.21.0
17
+ nvidia-cuda-cupti-cu12==12.4.127
18
+ sentencepiece==0.2.1
19
+ zstandard==0.23.0
20
+ pybind11==3.0.0
21
+ inquirerpy==0.3.4
22
+ contourpy==1.3.3
23
+ Pygments==2.19.2
24
+ sniffio==1.3.1
25
+ Jinja2==3.1.6
26
+ packaging==25.0
27
+ Markdown==3.8.2
28
+ astunparse==1.6.3
29
+ spacy==3.8.7
30
+ pyparsing==3.2.3
31
+ networkx==3.5
32
+ ninja==1.11.1.4
33
+ tf-slim==1.1.0
34
+ PyYAML==6.0.2
35
+ smmap==5.0.2
36
+ tiktoken==0.9.0
37
+ flatbuffers==25.2.10
38
+ tensorflow==2.20.0
39
+ langcodes==3.5.0
40
+ nvidia-cuda-nvrtc-cu12==12.4.127
41
+ numexpr==2.11.0
42
+ charset-normalizer==3.4.3
43
+ frozenlist==1.7.0
44
+ setuptools==80.9.0
45
+ cycler==0.12.1
46
+ weasel==0.4.1
47
+ tzdata==2025.2
48
+ sacrebleu==2.5.1
49
+ rouge_score==0.1.2
50
+ requests==2.32.5
51
+ nvidia-nvjitlink-cu12==12.4.127
52
+ grpcio==1.74.0
53
+ nvidia-cusparse-cu12==12.3.1.170
54
+ mdurl==0.1.2
55
+ pandas==2.3.1
56
+ preshed==3.0.10
57
+ attrs==25.3.0
58
+ tensorboard-data-server==0.7.2
59
+ aiohappyeyeballs==2.6.1
60
+ keras==3.11.2
61
+ wrapt==1.17.3
62
+ aiosignal==1.4.0
63
+ tcolorpy==0.1.7
64
+ platformdirs==4.3.8
65
+ tqdm-multiprocess==0.0.11
66
+ python-dotenv==1.1.1
67
+ wasabi==1.1.3
68
+ google-pasta==0.2.0
69
+ optree==0.17.0
70
+ MarkupSafe==3.0.2
71
+ colorlog==6.9.0
72
+ nvidia-cufft-cu12==11.2.1.3
73
+ lm_eval==0.4.9.1
74
+ lxml==6.0.0
75
+ protobuf==6.32.0
76
+ radgraph==0.1.18
77
+ scipy==1.16.1
78
+ click==8.2.1
79
+ wheel==0.45.1
80
+ marisa-trie==1.3.0
81
+ pathvalidate==3.3.1
82
+ nvidia-nccl-cu12==2.21.5
83
+ evaluate==0.4.5
84
+ nvidia-cuda-runtime-cu12==12.4.127
85
+ transformers==4.51.3
86
+ aenum==3.1.15
87
+ typing-inspection==0.4.1
88
+ gitdb==4.0.12
89
+ iniconfig==2.1.0
90
+ multidict==6.6.3
91
+ huggingface-hub==0.34.4
92
+ tokenizers==0.21.4
93
+ tabledata==1.3.4
94
+ mbstrdecoder==1.1.4
95
+ Werkzeug==3.1.3
96
+ accelerate==1.10.0
97
+ hf-xet==1.1.8
98
+ tensorboard==2.20.0
99
+ ml_dtypes==0.5.3
100
+ pytest==8.4.1
101
+ namex==0.1.0
102
+ pillow==11.3.0
103
+ datasets==3.6.0
104
+ tqdm==4.67.1
105
+ murmurhash==1.0.13
106
+ fonttools==4.59.1
107
+ absl-py==2.3.1
108
+ multiprocess==0.70.16
109
+ fsspec==2025.3.0
110
+ transformers==4.51.3
111
+ dill==0.3.8
112
+ propcache==0.3.2
113
+ jsonpickle==4.1.1
114
+ BLEURT==0.0.2
115
+ yarl==1.20.1
116
+ portalocker==3.2.0
117
+ httpx==0.27.2
118
+ numpy==2.3.2
119
+ mpmath==1.3.0
120
+ pyarrow==21.0.0
121
+ matplotlib==3.10.5
122
+ typepy==1.3.4
123
+ pycountry==24.6.1
124
+ word2number==1.1
125
+ psutil==7.0.0
126
+ catalogue==2.0.10
127
+ latex2sympy2_extended==1.0.6
128
+ pydantic_core==2.33.2
129
+ threadpoolctl==3.6.0
130
+ spacy-loggers==1.0.5
131
+ certifi==2025.8.3
132
+ confection==0.1.5
133
+ flame==0.1.0
134
+ pfzy==0.3.4
135
+ safetensors==0.6.2
136
+ pip==25.1
137
+ DataProperty==1.1.0
138
+ lighteval==0.10.1.dev0
139
+ jsonlines==4.0.0
140
+ scikit-learn==1.7.1
141
+ torch==2.6.0
142
+ pytz==2025.2
143
+ python-dateutil==2.9.0.post0
144
+ nltk==3.9.1
145
+ sqlitedict==2.1.0
146
+ gast==0.6.0
147
+ nvidia-curand-cu12==10.3.5.147
148
+ rich==14.1.0
149
+ sentry-sdk==2.33.2
150
+ nvidia-cusparselt-cu12==0.6.2
151
+ kiwisolver==1.4.9
152
+ appdirs==1.4.4
153
+ bert-score==0.3.13
154
+ blis==1.3.0
155
+ GitPython==3.1.45
156
+ chardet==5.2.0
157
+ more-itertools==10.7.0
158
+ filelock==3.19.1
159
+ transformers==4.51.3
160
+ httpcore==1.0.9
161
+ termcolor==3.1.0
162
+ typer==0.16.1
163
+ einops==0.8.1
164
+ torchdata==0.11.0
165
+ six==1.17.0
166
+ colorama==0.4.6
167
+ aiohttp==3.12.14
168
+ srsly==2.5.1
169
+ urllib3==2.5.0
170
+ nvidia-cublas-cu12==12.4.5.8
171
+ cloudpathlib==0.21.1
172
+ h5py==3.14.0
173
+ thinc==8.3.6
174
+ markdown-it-py==4.0.0
175
+ flash-attn==2.7.3
176
+ prompt_toolkit==3.0.52
177
+ nvidia-nvtx-cu12==12.4.127
178
+ en_core_web_sm==3.8.0
179
+ xxhash==3.5.0
180
+ anyio==4.10.0
181
+ joblib==1.5.1
182
+ pydantic==2.11.7
183
+ opt_einsum==3.4.0
184
+ dotmap==1.3.30
185
+ language_data==1.3.0
186
+ shellingham==1.5.4
187
+ nvidia-cudnn-cu12==9.1.0.70
188
+ typing_extensions==4.14.1
189
+ libclang==18.1.1
190
+ tabulate==0.9.0
191
+ annotated-types==0.7.0
192
+ jaraco.context==5.3.0
193
+ autocommand==2.2.2
194
+ more-itertools==10.3.0
195
+ tomli==2.0.1
196
+ jaraco.functools==4.0.1
197
+ zipp==3.19.2
198
+ backports.tarfile==1.2.0
199
+ wheel==0.45.1
200
+ platformdirs==4.2.2
201
+ inflect==7.3.1
202
+ typing_extensions==4.12.2
203
+ jaraco.text==3.12.1
204
+ typeguard==4.3.0
205
+ importlib_metadata==8.0.0
206
+ packaging==24.2
207
+ jaraco.collections==5.1.0
tb/20250909-0619/wandb/run-20250909_061919-top_transformer-top.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509090614/logs/debug-internal.log ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {"time":"2025-09-09T06:19:20.029854482Z","level":"INFO","msg":"stream: starting","core version":"0.21.0"}
2
+ {"time":"2025-09-09T06:19:20.338868384Z","level":"INFO","msg":"stream: created new stream","id":"top_transformer-top.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509090614"}
3
+ {"time":"2025-09-09T06:19:20.338942945Z","level":"INFO","msg":"stream: started","id":"top_transformer-top.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509090614"}
4
+ {"time":"2025-09-09T06:19:20.338955936Z","level":"INFO","msg":"handler: started","stream_id":"top_transformer-top.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509090614"}
5
+ {"time":"2025-09-09T06:19:20.33900181Z","level":"INFO","msg":"writer: Do: started","stream_id":"top_transformer-top.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509090614"}
6
+ {"time":"2025-09-09T06:19:20.339014387Z","level":"INFO","msg":"sender: started","stream_id":"top_transformer-top.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509090614"}
7
+ {"time":"2025-09-09T16:55:51.461783187Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
8
+ {"time":"2025-09-09T17:52:23.968650788Z","level":"INFO","msg":"api: retrying HTTP error","status":502,"url":"https://api.wandb.ai/files/zaydzuhri/fla/top_transformer-top.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509090614/file_stream","body":"\n<html><head>\n<meta http-equiv=\"content-type\" content=\"text/html;charset=utf-8\">\n<title>502 Server Error</title>\n</head>\n<body text=#000000 bgcolor=#ffffff>\n<h1>Error: Server Error</h1>\n<h2>The server encountered a temporary error and could not complete your request.<p>Please try again in 30 seconds.</h2>\n<h2></h2>\n</body></html>\n"}
9
+ {"time":"2025-09-09T22:51:18.011409168Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/top_transformer-top.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509090614/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
10
+ {"time":"2025-09-09T22:58:20.165767227Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/top_transformer-top.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509090614/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
tb/20250909-0619/wandb/run-20250909_061919-top_transformer-top.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509090614/logs/debug.log ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2025-09-09 06:19:19,818 INFO MainThread:795439 [wandb_setup.py:_flush():80] Current SDK version is 0.21.0
2
+ 2025-09-09 06:19:19,818 INFO MainThread:795439 [wandb_setup.py:_flush():80] Configure stats pid to 795439
3
+ 2025-09-09 06:19:19,818 INFO MainThread:795439 [wandb_setup.py:_flush():80] Loading settings from /home/cvm/.config/wandb/settings
4
+ 2025-09-09 06:19:19,818 INFO MainThread:795439 [wandb_setup.py:_flush():80] Loading settings from /home/cvm/flame/wandb/settings
5
+ 2025-09-09 06:19:19,818 INFO MainThread:795439 [wandb_setup.py:_flush():80] Loading settings from environment variables
6
+ 2025-09-09 06:19:19,818 INFO MainThread:795439 [wandb_init.py:setup_run_log_directory():703] Logging user logs to exp/top.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine/tb/20250909-0619/wandb/run-20250909_061919-top_transformer-top.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509090614/logs/debug.log
7
+ 2025-09-09 06:19:19,818 INFO MainThread:795439 [wandb_init.py:setup_run_log_directory():704] Logging internal logs to exp/top.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine/tb/20250909-0619/wandb/run-20250909_061919-top_transformer-top.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509090614/logs/debug-internal.log
8
+ 2025-09-09 06:19:19,818 INFO MainThread:795439 [wandb_init.py:init():830] calling init triggers
9
+ 2025-09-09 06:19:19,818 INFO MainThread:795439 [wandb_init.py:init():835] wandb.init called with sweep_config: {}
10
+ config: {'_wandb': {}}
11
+ 2025-09-09 06:19:19,818 INFO MainThread:795439 [wandb_init.py:init():871] starting backend
12
+ 2025-09-09 06:19:20,025 INFO MainThread:795439 [wandb_init.py:init():874] sending inform_init request
13
+ 2025-09-09 06:19:20,027 INFO MainThread:795439 [wandb_init.py:init():882] backend started and connected
14
+ 2025-09-09 06:19:20,033 INFO MainThread:795439 [wandb_init.py:init():953] updated telemetry
15
+ 2025-09-09 06:19:20,039 INFO MainThread:795439 [wandb_init.py:init():977] communicating run to backend with 90.0 second timeout
16
+ 2025-09-09 06:19:20,682 INFO MainThread:795439 [wandb_init.py:init():1029] starting run threads in backend
17
+ 2025-09-09 06:19:20,815 INFO MainThread:795439 [wandb_run.py:_console_start():2458] atexit reg
18
+ 2025-09-09 06:19:20,815 INFO MainThread:795439 [wandb_run.py:_redirect():2306] redirect: wrap_raw
19
+ 2025-09-09 06:19:20,815 INFO MainThread:795439 [wandb_run.py:_redirect():2375] Wrapping output streams.
20
+ 2025-09-09 06:19:20,815 INFO MainThread:795439 [wandb_run.py:_redirect():2398] Redirects installed.
21
+ 2025-09-09 06:19:20,817 INFO MainThread:795439 [wandb_init.py:init():1075] run started, returning control to user process
torchtitan/components/__pycache__/dataloader.cpython-312.pyc ADDED
Binary file (3.79 kB). View file
 
torchtitan/components/__pycache__/lr_scheduler.cpython-312.pyc ADDED
Binary file (7.71 kB). View file
 
torchtitan/components/__pycache__/metrics.cpython-312.pyc ADDED
Binary file (19.6 kB). View file
 
torchtitan/components/__pycache__/tokenizer.cpython-312.pyc ADDED
Binary file (1.09 kB). View file
 
torchtitan/components/metrics.py ADDED
@@ -0,0 +1,435 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+ import time
9
+ from collections import namedtuple
10
+ from datetime import datetime
11
+ from typing import Any
12
+
13
+ import torch
14
+ from torch.utils.tensorboard import SummaryWriter
15
+ from torchtitan.components.lr_scheduler import LRSchedulersContainer
16
+ from torchtitan.components.optimizer import OptimizersContainer
17
+ from torchtitan.config_manager import JobConfig
18
+ from torchtitan.distributed import ParallelDims
19
+ from torchtitan.tools import utils
20
+ from torchtitan.tools.logging import logger
21
+ from torchtitan.tools.utils import Color, device_module, device_type
22
+
23
# named tuple for passing device memory stats for logging
# (produced by DeviceMemoryMonitor.get_peak_stats below)
DeviceMemStats = namedtuple(
    "DeviceMemStats",
    [
        "max_active_gib",  # peak active memory, in GiB
        "max_active_pct",  # peak active memory as % of device capacity
        "max_reserved_gib",  # peak reserved memory, in GiB
        "max_reserved_pct",  # peak reserved memory as % of device capacity
        "num_alloc_retries",  # allocator retry count from memory_stats
        "num_ooms",  # out-of-memory error count from memory_stats
    ],
)
35
+
36
+
37
class DeviceMemoryMonitor:
    """Track peak memory usage of a single accelerator device.

    Peak statistics accumulate until `reset_peak_stats` is called.
    """

    def __init__(self, device: str = f"{device_type}:0"):
        self.device = torch.device(device)  # device object
        self.device_name = device_module.get_device_name(self.device)
        self.device_index = device_module.current_device()
        self.device_capacity = device_module.get_device_properties(
            self.device
        ).total_memory
        self.device_capacity_gib = self._to_gib(self.device_capacity)

        # Start from a clean slate so the first report reflects this run only.
        device_module.reset_peak_memory_stats()
        device_module.empty_cache()

    def _to_gib(self, memory_in_bytes):
        # NOTE: GiB (gibibyte) is 1024**3 bytes, vs GB which is 1000**3
        return memory_in_bytes / (1024 * 1024 * 1024)

    def _to_pct(self, memory):
        # Fraction of total device capacity, expressed as a percentage.
        return 100 * memory / self.device_capacity

    def get_peak_stats(self):
        """Snapshot peak allocator stats, warning on retries and OOMs."""
        stats = device_module.memory_stats(self.device)

        max_active = stats.get("active_bytes.all.peak", -1)
        max_reserved = stats.get("reserved_bytes.all.peak", -1)
        num_retries = stats.get("num_alloc_retries", -1)
        num_ooms = stats.get("num_ooms", -1)

        if num_retries > 0:
            logger.warning(
                f"{num_retries} {device_type.upper()} memory allocation retries."
            )
        if num_ooms > 0:
            logger.warning(f"{num_ooms} {device_type.upper()} OOM errors thrown.")

        return DeviceMemStats(
            self._to_gib(max_active),
            self._to_pct(max_active),
            self._to_gib(max_reserved),
            self._to_pct(max_reserved),
            num_retries,
            num_ooms,
        )

    def reset_peak_stats(self):
        """Clear the accumulated peak statistics."""
        device_module.reset_peak_memory_stats()
91
+
92
+
93
def build_device_memory_monitor():
    """Create a DeviceMemoryMonitor for the current device and log its capacity."""
    monitor = DeviceMemoryMonitor(device_type)
    logger.info(
        f"{device_type.upper()} capacity: {monitor.device_name} "
        f"with {monitor.device_capacity_gib:.2f}GiB memory"
    )
    return monitor
100
+
101
+
102
class BaseLogger:
    """No-op metric logger used when logging is disabled."""

    def log(self, metrics: dict[str, Any], step: int) -> None:
        """Discard the given metrics."""
        return None

    def close(self) -> None:
        """Nothing to release."""
        return None
110
+
111
+
112
class TensorBoardLogger(BaseLogger):
    """Logger that writes scalar metrics to TensorBoard event files."""

    def __init__(self, log_dir: str, tag: str | None = None):
        self.tag = tag
        self.writer = SummaryWriter(log_dir, max_queue=1000)
        logger.info(f"TensorBoard logging enabled. Logs will be saved at {log_dir}")

    def log(self, metrics: dict[str, Any], step: int) -> None:
        """Write each metric as a scalar, prefixed with the tag if set."""
        prefix = self.tag
        for name, value in metrics.items():
            scalar_tag = name if prefix is None else f"{prefix}/{name}"
            self.writer.add_scalar(scalar_tag, value, step)

    def close(self) -> None:
        """Flush and close the underlying event-file writer."""
        self.writer.close()
127
+
128
+
129
class WandBLogger(BaseLogger):
    """Logger that forwards scalar metrics to Weights & Biases."""

    def __init__(self, log_dir: str, tag: str | None = None):
        # Deferred import keeps wandb off the startup path when unused.
        import wandb

        self.wandb = wandb
        self.tag = tag

        # wandb expects its run directory to exist already.
        os.makedirs(log_dir, exist_ok=True)

        self.wandb.init(
            project=os.getenv("WANDB_PROJECT", "torchtitan"),
            dir=log_dir,
        )
        logger.info("WandB logging enabled")

    def log(self, metrics: dict[str, Any], step: int) -> None:
        """Log all metrics for `step`, prefixing keys with the tag if set."""
        prefixed = {}
        for name, value in metrics.items():
            key = name if self.tag is None else f"{self.tag}/{name}"
            prefixed[key] = value
        self.wandb.log(prefixed, step=step)

    def close(self) -> None:
        # finish() is only meaningful while a run is active.
        if self.wandb.run is not None:
            self.wandb.finish()
158
+
159
+
160
def ensure_pp_loss_visible(
    parallel_dims: ParallelDims, job_config: JobConfig, color: Color
) -> None:
    """
    Ensures that the loss is visible on the console for pipeline-parallel training.

    With pipeline parallelism the loss only materializes on the last stage, so
    that stage's first rank must appear in the LOG_RANK environment variable for
    the loss to show up on the console; warn if it does not.
    """
    # V Block Schedules return loss on rank 0, which is always logged.
    if job_config.parallelism.pipeline_parallel_schedule == "ZBVZeroBubble":
        return

    # First rank of the last pipeline stage is where the loss lives.
    world_size = parallel_dims.world_size
    pp_size = parallel_dims.pp
    loss_visible_rank = (world_size // pp_size) * (pp_size - 1)

    # Ranks explicitly enabled for console logging.
    raw_ranks = os.environ.get("LOG_RANK", "")
    env_logged_ranks = raw_ranks.split(",") if raw_ranks else []

    if str(loss_visible_rank) not in env_logged_ranks:
        logger.warning(
            f"{color.red}Pipeline Parallel loss is not visible. "
            f"Please add {color.yellow}rank {loss_visible_rank}{color.red} "
            f"to LOG_RANK environment variable in run_train.sh.{color.reset}"
        )
191
+
192
+
193
def _get_metrics_rank(
    parallel_dims: ParallelDims,
    job_config: JobConfig,
) -> int:
    """
    Determine which global rank should log metrics.

    Returns:
        int: 0 when pipeline parallelism is disabled; 0 for the
        'ZBVZeroBubble' schedule (loss surfaces on rank 0); otherwise the
        first rank of the last pipeline stage.
    """
    if not parallel_dims.pp_enabled:
        return 0

    # V Block Schedules return loss on rank 0
    if job_config.parallelism.pipeline_parallel_schedule == "ZBVZeroBubble":
        return 0

    # Ranks are laid out stage-major: first rank of the last stage.
    pp_size = parallel_dims.pp
    return (parallel_dims.world_size // pp_size) * (pp_size - 1)
218
+
219
+
220
def _build_metric_logger(
    job_config: JobConfig, parallel_dims: ParallelDims, tag: str | None = None
) -> BaseLogger:
    """
    Build an appropriate metric logger based on configuration.

    Preference order: WandB, then TensorBoard, then a no-op BaseLogger.
    Unless ``metrics.save_for_all_ranks`` is set, only the metrics rank
    (see _get_metrics_rank) gets a real logger.
    """
    metrics_config = job_config.metrics

    # Log initial config state
    logger.debug(
        f"Building logger with config: wandb={metrics_config.enable_wandb}, "
        f"tensorboard={metrics_config.enable_tensorboard}"
    )

    has_logging_enabled = (
        metrics_config.enable_tensorboard or metrics_config.enable_wandb
    )

    # Restrict logging to the metrics rank unless per-rank saving is enabled.
    should_log = has_logging_enabled
    if should_log and not metrics_config.save_for_all_ranks:
        metrics_rank = _get_metrics_rank(parallel_dims, job_config)
        should_log = torch.distributed.get_rank() == metrics_rank

    logger.debug(
        f"Logging decision: has_logging_enabled={has_logging_enabled}, should_log={should_log}"
    )

    if not should_log:
        logger.debug("Returning BaseLogger due to should_log=False")
        return BaseLogger()

    # Timestamped log directory underneath the job dump folder.
    dump_dir = job_config.job.dump_folder
    base_log_dir = os.path.join(
        dump_dir, metrics_config.save_tb_folder, datetime.now().strftime("%Y%m%d-%H%M")
    )

    if metrics_config.save_for_all_ranks:
        base_log_dir = os.path.join(
            base_log_dir, f"rank_{torch.distributed.get_rank()}"
        )

    # Create loggers in priority order; fall through on WandB failure.
    if metrics_config.enable_wandb:
        logger.debug("Attempting to create WandB logger")
        try:
            return WandBLogger(base_log_dir, tag)
        except Exception as e:
            if "No module named 'wandb'" in str(e):
                logger.error(
                    "Failed to create WandB logger: No module named 'wandb'. Please install it using 'pip install wandb'."
                )
            else:
                logger.error(f"Failed to create WandB logger: {e}")

    if metrics_config.enable_tensorboard:
        logger.debug("Creating TensorBoard logger")
        return TensorBoardLogger(base_log_dir, tag)

    logger.debug("No loggers enabled, returning BaseLogger")
    return BaseLogger()
283
+
284
+
285
class MetricsProcessor:
    """Metrics processor that processes the metrics and logs them.

    The current MetricsProcessor logs a colorized summary line to STDOUT and
    the full metric set to TensorBoard or WandB (via the backend chosen in
    ``_build_metric_logger``).

    Args:
        job_config (JobConfig): Job configuration.
        parallel_dims (ParallelDims): Parallel dimensions.
        tag (Optional[str]): Tag to use for TensorBoard or WandB. Defaults to None.
    """

    logger: BaseLogger
    parallel_dims: ParallelDims
    job_config: JobConfig
    device_memory_monitor: DeviceMemoryMonitor
    color: utils.NoColor | utils.Color

    gpu_peak_flops: int
    ntokens_since_last_log: int
    data_loading_times: list[float]
    time_last_log: float

    num_flops_per_token: int
    optimizers: OptimizersContainer | None
    lr_schedulers: LRSchedulersContainer | None

    def __init__(
        self,
        job_config: JobConfig,
        parallel_dims: ParallelDims,
        tag: str | None = None,
    ):
        self.logger = _build_metric_logger(job_config, parallel_dims, tag)
        self.parallel_dims = parallel_dims
        self.job_config = job_config
        self.device_memory_monitor = build_device_memory_monitor()
        # used for colorful printing
        self.color = (
            utils.NoColor()
            if job_config.metrics.disable_color_printing
            else utils.Color()
        )

        self.gpu_peak_flops = utils.get_peak_flops(
            self.device_memory_monitor.device_name
        )
        self.ntokens_since_last_log = 0
        self.data_loading_times = []
        self.time_last_log = time.perf_counter()
        self.device_memory_monitor.reset_peak_stats()

        # These variables have to be set later as they depend on other components or model.
        self.num_flops_per_token = -1
        self.optimizers = None
        self.lr_schedulers = None

    def should_log(self, step: int) -> bool:
        """Whether metrics should be emitted at *step* (always on step 1)."""
        return step == 1 or step % self.job_config.metrics.log_freq == 0

    def log(
        self,
        step: int,
        global_avg_loss: float,
        global_max_loss: float,
        extra_metrics: dict[str, Any] | None = None,
    ):
        """Compute throughput/memory stats since the last call, ship them to
        the backend logger, print a one-line colored summary to STDOUT, and
        reset the rolling counters.

        Raises:
            AssertionError: if ``num_flops_per_token`` was never set.
        """
        assert self.num_flops_per_token > 0, "num_flops_per_token must be set"

        time_delta = time.perf_counter() - self.time_last_log

        # tokens per second per device, abbreviated as tps
        tps = self.ntokens_since_last_log / (
            time_delta * self.parallel_dims.non_data_parallel_size
        )
        # model FLOPS utilization
        # For its definition and calculation, please refer to the PaLM paper:
        # https://arxiv.org/abs/2204.02311
        mfu = 100 * self.num_flops_per_token * tps / self.gpu_peak_flops
        tflops = self.num_flops_per_token * tps / 1e12

        time_end_to_end = time_delta / self.job_config.metrics.log_freq
        time_data_loading = sum(self.data_loading_times) / len(self.data_loading_times)
        time_data_loading_pct = 100 * sum(self.data_loading_times) / time_delta

        device_mem_stats = self.device_memory_monitor.get_peak_stats()

        metrics = {
            "loss_metrics/global_avg_loss": global_avg_loss,
            "loss_metrics/global_max_loss": global_max_loss,
            "throughput(tps)": tps,
            "tflops": tflops,
            "mfu(%)": mfu,
            "time_metrics/end_to_end(s)": time_end_to_end,
            "time_metrics/data_loading(s)": time_data_loading,
            "time_metrics/data_loading(%)": time_data_loading_pct,
            "memory/max_active(GiB)": device_mem_stats.max_active_gib,
            "memory/max_active(%)": device_mem_stats.max_active_pct,
            "memory/max_reserved(GiB)": device_mem_stats.max_reserved_gib,
            "memory/max_reserved(%)": device_mem_stats.max_reserved_pct,
            "memory/num_alloc_retries": device_mem_stats.num_alloc_retries,
            "memory/num_ooms": device_mem_stats.num_ooms,
        }

        if extra_metrics:
            metrics.update(extra_metrics)

        self.logger.log(metrics, step)

        color = self.color
        construct_string = (
            f"{color.red}step: {step:2} "
            f"{color.green}loss: {global_avg_loss:7.4f} "
            f"{color.yellow}memory: {device_mem_stats.max_reserved_gib:5.2f}GiB"
            f"({device_mem_stats.max_reserved_pct:.2f}%) "
            f"{color.blue}tps: {round(tps):,} "
            f"{color.cyan}tflops: {tflops:,.2f} "
            f"{color.magenta}mfu: {mfu:.2f}%{color.reset}"
        )

        if extra_metrics:
            for k, v in extra_metrics.items():
                if "loss" in k:
                    # Bug fix: use removeprefix, not lstrip. lstrip strips a
                    # *character set*, which can also eat leading letters of
                    # the metric name itself (e.g. "loss_metrics/mtp_loss"
                    # would become "p_loss").
                    construct_string += (
                        f" {color.white}{k.removeprefix('loss_metrics/')}: {v:7.4f}"
                    )
        logger.info(construct_string)

        self.ntokens_since_last_log = 0
        self.data_loading_times.clear()
        self.time_last_log = time.perf_counter()
        self.device_memory_monitor.reset_peak_stats()

    def close(self):
        """Release the underlying metric logger backend."""
        self.logger.close()
420
+
421
+
422
def build_metrics_processor(
    job_config: JobConfig, parallel_dims: ParallelDims, tag: str | None = None
) -> MetricsProcessor:
    """Factory for :class:`MetricsProcessor`.

    Args:
        job_config (JobConfig): Job configuration.
        parallel_dims (ParallelDims): Parallel dimensions.
        tag (Optional[str]): Tag to use for TensorBoard or WandB. Defaults to None.

    Returns:
        MetricsProcessor: A freshly constructed metrics processor.
    """
    return MetricsProcessor(job_config, parallel_dims, tag)
torchtitan/experiments/deepseek_v3/LICENSE-CODE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 DeepSeek
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
torchtitan/experiments/deepseek_v3/README.md ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Running DeepSeek in Titan (experimental)
2
+
3
+ This folder contains a DeepSeek model supporting v2 and v3 as well as kernels
4
+ and scripts needed to run it.
5
+
6
+ ## Inference
7
+
8
+ ### Prerequisites:
9
+
10
+ You will need to download a DeepSeek model's weights if you want to run a
11
+ pre-trained checkpoint. We provided a script to download the weights from
12
+ HuggingFace Model Hub:
13
+ ```bash
14
+ python download.py [vX]
15
+ ```
16
+ where `vX` can be v2 or v3, both are supported. You may be required to create a
17
+ HuggingFace account and log in first.
18
+
19
+ ### Running inference:
20
+
21
+ The inference script is in `generate.py`. You can run it with the following
22
+ command:
23
+ ```bash
24
+ torchrun --standalone --nproc-per-node 4 generate.py
25
+ ```
26
+ This will run inference on the `DeepSeek-V2-Lite-Chat` model using 4 GPUs by
27
+ default.
28
+
29
+ Alternatively, you can run inference by using `bash inference.sh`, optionally
30
+ followed by your prompt.
31
+
32
+ ## Training
33
+
34
+ The training script is in `train.py`. You can run it with the following command:
35
+ ```bash
36
+ torchrun --standalone --nproc-per-node 8 train.py
37
+ ```
38
+
39
+ This will run training on the `DeepSeek-V2-Lite-Chat` model using 8 GPUs by
40
+ default, with pipeline parallel, expert parallel, and data parallel enabled.
torchtitan/experiments/deepseek_v3/checkpoint.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import json
8
+ import logging
9
+ import os
10
+ from typing import Dict, Optional, Set, Tuple
11
+
12
+ import torch
13
+ from safetensors import safe_open
14
+
15
+ from transformers.utils import cached_file
16
+
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ _DEFAULT_SAFETENSOR_FILE_NAME = "model.safetensors.index.json"
21
+
22
+
23
def read_weights_from_json(file_path: str) -> Optional[Dict[str, str]]:
    """Return the 'weight_map' dict from a safetensors index JSON, or None.

    Any read/parse failure is logged and reported as None (best-effort); the
    redundant ``(json.JSONDecodeError, Exception)`` tuple was collapsed to
    ``Exception`` (JSONDecodeError is already a subclass).
    """
    try:
        with open(file_path, "r") as file:
            data = json.load(file)

        if "weight_map" in data and isinstance(data["weight_map"], dict):
            return data["weight_map"]
        logger.info("No 'weight_map' dictionary found in the JSON file.")
        return None
    except Exception as e:
        # Broad on purpose: a missing/corrupt index should not crash callers.
        logger.info(f"An error occurred while reading the JSON file: {str(e)}")
        return None
36
+
37
+
38
def get_hf_weight_map_and_path(
    model_id: str,
) -> Tuple[Dict[str, str], str]:
    """Get the weight map for a given HF model id and also the cache path for loading the weights."""
    try:
        index_file = cached_file(model_id, _DEFAULT_SAFETENSOR_FILE_NAME)
    except Exception:
        # Fixed: the suggested command was missing its closing backtick.
        logger.error(
            f"Model `{model_id}` not found in HF cache. "
            f"You can download the model using `python download.py {model_id}`"
        )
        # Bare raise preserves the original traceback.
        raise

    weight_map = read_weights_from_json(index_file)
    weight_path = os.path.dirname(index_file)
    logger.info(f"Loading weights from: {weight_path}")
    return weight_map, weight_path
55
+
56
+
57
def get_needed_files(
    state_dict: Dict[str, torch.Tensor], weight_map: Dict[str, str]
) -> Set[str]:
    """Return the set of checkpoint shard files holding the params in *state_dict*.

    Raises:
        ValueError: if a weight parameter has no entry in *weight_map*.
    """
    needed_files: Set[str] = set()
    for param_name in state_dict:
        shard_file = weight_map.get(param_name)
        if shard_file:
            needed_files.add(shard_file)
        elif param_name.endswith("weight"):
            # Non-weight entries may legitimately be absent from the map.
            raise ValueError(
                f"Parameter {param_name} not found in weight map, please check..."
            )
    logger.info(f"Needed files: {needed_files}")
    return needed_files
71
+
72
+
73
def load_safetensor_file(
    full_path: str, device: torch.device
) -> Dict[str, torch.Tensor]:
    """Read every tensor from one safetensors file onto *device*."""
    with safe_open(full_path, framework="pt", device=device) as f:
        tensors = {name: f.get_tensor(name) for name in f.keys()}
    logger.info(f"Loaded {len(tensors)} tensors from {full_path}")
    return tensors
82
+
83
+
84
def load_safetensor_weights(
    model: torch.nn.Module,
    weight_map: Dict[str, str],
    file_location: str,
    device: torch.device,
):
    """
    Load safetensor weights into a `nn.Module`.

    Args:
        model (Module): The PyTorch module to load weights into. It may be a
            model chunk or a full model.
        weight_map (Dict[str, str]): Mapping of model parameters to file names.
        file_location (str): Directory containing the weight files.
        device (torch.device): The device to load tensors onto.

    Raises:
        ValueError: on a shape mismatch between model and checkpoint.
        RuntimeError: if some parameters were never found in any shard.
    """
    model_state_dict = model.state_dict()
    needed_files = get_needed_files(model_state_dict, weight_map)
    updated_states: Set[str] = set()

    for file in needed_files:
        full_path = os.path.join(file_location, file)
        try:
            checkpoint = load_safetensor_file(full_path, "cpu")
        except FileNotFoundError:
            logger.error(f"File not found: {full_path}")
            # Bug fix: skip this file — otherwise `checkpoint` below is
            # unbound on the first iteration (NameError) or stale from the
            # previous one (silently re-applies the wrong shard).
            continue
        except Exception as e:
            logger.error(f"Error during checkpoint processing of {full_path}: {str(e)}")
            continue

        matched_keys = set(checkpoint.keys()) & set(model_state_dict.keys())
        for key in matched_keys:
            # Check shape
            if model_state_dict[key].shape != checkpoint[key].shape:
                raise ValueError(
                    f"Shape mismatch for {key}: "
                    f"model needs {model_state_dict[key].shape}, but "
                    f"checkpoint has {checkpoint[key].shape}"
                )
            model_state_dict[key] = checkpoint[key].to(device)

        updated_states.update(matched_keys)

    # A skipped/failed shard surfaces here as missing parameters.
    missing_keys = set(model_state_dict.keys()) - updated_states
    if missing_keys:
        raise RuntimeError(
            f"Partially updated state dict. Missing parameters: {missing_keys}"
        )

    model.load_state_dict(model_state_dict, strict=False, assign=True)
    logger.info(f"Successfully loaded {len(updated_states)} weights into model")
134
+
135
+
136
def load_weights_from_hf(
    model: torch.nn.Module,
    distribution: str,
    device: torch.device,
):
    """
    Load the weights from Hugging Face format (index file + multiple safetensor
    files), and fill into `model`. Model config is needed b/c we permute
    wq and wk weights based on attn heads.
    """
    weight_map, weight_path = get_hf_weight_map_and_path(distribution)
    load_safetensor_weights(model, weight_map, weight_path, device)
torchtitan/experiments/deepseek_v3/download.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Usage:
8
+ # Downloads a given model to the HF Cache. Pass in a listed option ala "v3" or your own custom model path.
9
+ # python download.py {model_id} [custom_model_path]
10
+ # Examples:
11
+ # python download.py v2 # Use predefined model: deepseek-ai/DeepSeek-V2
12
+ # python download.py custom "deepseek-ai/new-model" # Download a custom model path
13
+
14
+ # Available models:
15
+ # "v2-lite-chat": "deepseek-ai/DeepSeek-V2-Lite-Chat",
16
+ # "v2-lite": "deepseek-ai/DeepSeek-V2-Lite",
17
+ # "v2": "deepseek-ai/DeepSeek-V2",
18
+ # "v3": "deepseek-ai/deepseek-v3",
19
+ # "v3-0324": "deepseek-ai/DeepSeek-V3-0324",
20
+ # "custom": None, # Placeholder for custom models
21
+
22
+
23
+ import sys
24
+
25
+ from transformers import AutoModelForCausalLM
26
+
27
+
28
# Predefined model repos keyed by short version alias; "custom" is a
# placeholder resolved from a user-supplied path on the command line.
MODELS = {
    "v2-lite-chat": "deepseek-ai/DeepSeek-V2-Lite-Chat",
    "v2-lite": "deepseek-ai/DeepSeek-V2-Lite",
    "v2": "deepseek-ai/DeepSeek-V2",
    "v3": "deepseek-ai/deepseek-v3",
    "v3-0324": "deepseek-ai/DeepSeek-V3-0324",
    "custom": None,  # For custom (any) models
}
36
+
37
+
38
def print_usage():
    """Print CLI help for download.py, then exit with status 1."""
    print("Usage:")
    print(" python download.py [model_version]")
    print(" python download.py custom [custom_model_path]")
    print("\nAvailable predefined models:")
    for alias, repo in MODELS.items():
        if alias != "custom":  # Skip the custom placeholder
            print(f" {alias}: {repo}")
    print("\nFor custom models:")
    print(" custom: Specify your own model path")
    print(' Example: python download.py custom "organization/model-name"')
    sys.exit(1)
50
+
51
+
52
# Process command line arguments
# First positional argument must be one of the MODELS keys.
if len(sys.argv) < 2 or sys.argv[1] not in MODELS:
    print_usage()

# "custom" requires an explicit model path as the second argument;
# otherwise resolve the predefined repo id from the alias.
if sys.argv[1] == "custom":
    if len(sys.argv) != 3:
        print("Error: Custom model requires a model path")
        print_usage()
    model_id = sys.argv[2]
    print(f"Using custom model: {model_id}")
else:
    model_id = MODELS[sys.argv[1]]
    print(f"Downloading model: {model_id}")

# Downloading via from_pretrained populates the HF cache as a side effect.
# NOTE(review): trust_remote_code executes repo-provided code — only safe
# for trusted model repos.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    trust_remote_code=True,
)
torchtitan/experiments/deepseek_v3/model.py ADDED
@@ -0,0 +1,1325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # This code is based on model definition of `deepseek-ai/DeepSeek-V3-Base` on
8
+ # Hugging Face Model Hub. Url:
9
+ # https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/blob/main/modeling_deepseek.py
10
+ # https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/resolve/main/configuration_deepseek.py
11
+ #
12
+ # It has been modified from its original forms to accommodate naming convention
13
+ # and usage patterns of the TorchTitan project.
14
+
15
+ # Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
16
+ #
17
+ # Licensed under the Apache License, Version 2.0 (the "License");
18
+ # you may not use this file except in compliance with the License.
19
+ # You may obtain a copy of the License at
20
+ #
21
+ # http://www.apache.org/licenses/LICENSE-2.0
22
+ #
23
+ # Unless required by applicable law or agreed to in writing, software
24
+ # distributed under the License is distributed on an "AS IS" BASIS,
25
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26
+ # See the License for the specific language governing permissions and
27
+ # limitations under the License.
28
+ """ PyTorch DeepSeek model."""
29
+ import math
30
+ from typing import Optional, Tuple
31
+
32
+ import torch
33
+ import torch.distributed as dist
34
+
35
+ import torch.distributed._symmetric_memory as symm_mem
36
+ import torch.nn.functional as F
37
+ import torch.utils.checkpoint
38
+
39
+ from attn_mask_utils import _prepare_4d_causal_attention_mask
40
+ from indices import generate_permute_indices
41
+ from model_config import ModelArgs
42
+ from symm_mem_recipes import OnDeviceAllToAllV
43
+ from torch import nn
44
+ from torch.distributed._functional_collectives import all_to_all_single_autograd
45
+
46
+ from torchtitan.experiments.kernels.triton_mg_group_gemm.torchao_pr import (
47
+ ALIGN_SIZE_M,
48
+ grouped_gemm_forward,
49
+ )
50
+
51
# Get model parallel subgroup by name:
# e.g. "pp", "ep", None
def get_group(dim_name: Optional[str] = None) -> dist.ProcessGroup:
    """Return the process group for *dim_name* on the current device mesh."""
    current_mesh = torch.distributed.device_mesh._mesh_resources.get_current_mesh()
    return current_mesh.get_group(dim_name)
56
+
57
+
58
class RMSNorm(nn.Module):
    """Root-mean-square layer norm (no mean subtraction), computed in fp32."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Normalize in float32 for numerical stability, then cast back.
        orig_dtype = hidden_states.dtype
        x = hidden_states.to(torch.float32)
        mean_square = x.pow(2).mean(-1, keepdim=True)
        normed = x * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normed.to(orig_dtype)
70
+
71
+
72
class RotaryEmbedding(nn.Module):
    """Rotary position embedding with a lazily-grown cos/sin cache."""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        exponents = torch.arange(0, self.dim, 2).float().to(device) / self.dim
        inv_freq = 1.0 / (self.base**exponents)
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        positions = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )

        freqs = torch.outer(positions, self.inv_freq.to(positions.device))
        # Different from paper, but it uses a different permutation in order
        # to obtain the same calculation.
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # Grow the cache if the request exceeds what was precomputed.
        if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
112
+
113
+
114
class LinearScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
    ):
        # Must be set before super().__init__, which builds the first cache.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        positions = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )
        # Linear scaling: compress positions by the scaling factor.
        positions = positions / self.scaling_factor

        freqs = torch.outer(positions, self.inv_freq)
        # Different from paper, but it uses a different permutation in order
        # to obtain the same calculation.
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
140
+
141
+
142
# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Deepseek
class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
    ):
        # Must be set before super().__init__, which builds the first cache.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len

        if seq_len > self.max_position_embeddings:
            # Beyond the trained context: enlarge the base (NTK-aware) so
            # low frequencies stretch to cover the longer sequence.
            scaled_base = self.base * (
                (self.scaling_factor * seq_len / self.max_position_embeddings)
                - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))
            inv_freq = 1.0 / (
                scaled_base
                ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)

        positions = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )

        freqs = torch.outer(positions, self.inv_freq)
        # Different from paper, but it uses a different permutation in order
        # to obtain the same calculation.
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
179
+
180
+
181
# Inverse dim formula to find dim based on number of rotations
def yarn_find_correction_dim(
    num_rotations, dim, base=10000, max_position_embeddings=2048
):
    """Dimension index at which a RoPE frequency completes *num_rotations*
    rotations over the maximum position range (inverse of the freq formula)."""
    numerator = dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))
    return numerator / (2 * math.log(base))
188
+
189
+
190
# Find dim range bounds based on rotations
def yarn_find_correction_range(
    low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
):
    """Return (low, high) dimension indices bracketing the rotation counts."""
    low_dim = yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
    high_dim = yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
    # Clamp values just in case
    return max(math.floor(low_dim), 0), min(math.ceil(high_dim), dim - 1)
201
+
202
+
203
def yarn_get_mscale(scale=1, mscale=1):
    """Attention-magnitude correction factor for YaRN; identity when the
    context is not being extended (scale <= 1)."""
    return 1.0 if scale <= 1 else 0.1 * mscale * math.log(scale) + 1.0
208
+
209
def yarn_linear_ramp_mask(min, max, dim):
    """Length-`dim` float32 ramp: 0 at index `min`, rising linearly to 1 at
    index `max`, clamped to [0, 1] outside that range.

    NOTE: `min`/`max` shadow the builtins; kept to preserve the public
    signature for keyword callers.
    """
    if min == max:
        max += 0.001  # Prevent singularity

    ramp = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
    return torch.clamp(ramp, 0, 1)
217
+
218
class YarnRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding with YaRN scaling.

    Blends unscaled ("extrapolated") frequencies with interpolated ones over
    a per-dimension ramp, and rescales cos/sin by an attention mscale factor.
    """

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
        original_max_position_embeddings=4096,
        beta_fast=32,
        beta_slow=1,
        mscale=1,
        mscale_all_dim=0,
    ):
        self.scaling_factor = scaling_factor
        self.original_max_position_embeddings = original_max_position_embeddings
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        self.mscale = mscale
        self.mscale_all_dim = mscale_all_dim
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        head_dim = self.dim

        exponents = (
            torch.arange(0, head_dim, 2, dtype=torch.float32, device=device) / head_dim
        )
        # Unscaled (extrapolation) vs. position-interpolated frequencies.
        freq_extra = 1.0 / self.base**exponents
        freq_inter = 1.0 / (self.scaling_factor * self.base**exponents)

        low, high = yarn_find_correction_range(
            self.beta_fast,
            self.beta_slow,
            head_dim,
            self.base,
            self.original_max_position_embeddings,
        )
        # Mask is 1 where extrapolation should dominate, 0 where interpolation
        # should, with a linear blend in between.
        inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, head_dim // 2).to(
            device=device, dtype=torch.float32
        )
        inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        positions = torch.arange(seq_len, device=device, dtype=torch.float32)
        angles = torch.outer(positions, inv_freq)

        # Attention-scale correction ratio between the two mscale settings.
        _mscale = float(
            yarn_get_mscale(self.scaling_factor, self.mscale)
            / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
        )

        emb = torch.cat((angles, angles), dim=-1)
        self.register_buffer(
            "cos_cached", (emb.cos() * _mscale).to(dtype), persistent=False
        )
        self.register_buffer(
            "sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False
        )
284
+
285
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
292
+
293
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            Position indices of the tokens, indexed into the cos/sin caches
            (can be offset when working with a KV-cache).
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which cos[position_ids] / sin[position_ids] are
            unsqueezed so they broadcast against q and k. Use 1 for
            [batch, heads, seq, head_dim] tensors, 2 for
            [batch, seq, heads, head_dim].
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated
        using the Rotary Position Embedding.
    """
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)

    def _deinterleave(t):
        # Reorder interleaved even/odd pairs into a first-half/second-half
        # layout so the rotate_half convention applies.
        b, h, s, d = t.shape
        return t.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    q = _deinterleave(q)
    k = _deinterleave(k)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
328
+
329
class MLP(nn.Module):
    """Gated (SwiGLU-style) feed-forward block: down(act(gate(x)) * up(x))."""

    # Shared activation instance; also referenced elsewhere as `MLP.act_fn`.
    act_fn = nn.SiLU()

    def __init__(self, config, hidden_size=None, intermediate_size=None):
        super().__init__()
        self.config = config
        # Explicit sizes override the config defaults (used for MoE experts).
        self.hidden_size = hidden_size if hidden_size is not None else config.hidden_size
        self.intermediate_size = (
            intermediate_size
            if intermediate_size is not None
            else config.intermediate_size
        )

        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(gated)
348
+
349
class MoEGate(nn.Module):
    """Router that scores every expert per token and selects the top-k.

    Supports sigmoid/softmax scoring and two selection methods:
    plain "greedy" top-k, or DeepSeek-V3-style "noaux_tc" group-limited
    selection with a learned per-expert correction bias.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts
        self.routed_scaling_factor = config.routed_scaling_factor
        self.scoring_func = config.scoring_func
        self.seq_aux = config.seq_aux
        self.topk_method = config.topk_method
        self.n_group = config.n_group
        self.topk_group = config.topk_group

        # topk selection algorithm
        self.norm_topk_prob = config.norm_topk_prob
        self.gating_dim = config.hidden_size
        self.weight = nn.Parameter(
            torch.empty((self.n_routed_experts, self.gating_dim))
        )
        if self.topk_method == "noaux_tc":
            # torch.rand (not torch.empty) so runs without real weights still
            # get a non-degenerate routing distribution.
            self.e_score_correction_bias = nn.Parameter(
                torch.rand((self.n_routed_experts))
            )
        self.reset_parameters()

    def reset_parameters(self) -> None:
        import torch.nn.init as init

        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states):
        bsz, seq_len, hidden_dim = hidden_states.shape
        n_tokens = bsz * seq_len

        # Gating logits for every (token, expert) pair, computed in fp32 for
        # numerical stability.
        flat = hidden_states.view(-1, hidden_dim)
        logits = F.linear(
            flat.type(torch.float32), self.weight.type(torch.float32), None
        )
        if self.scoring_func == "sigmoid":
            scores = logits.sigmoid()
        elif self.scoring_func == "softmax":
            scores = logits.softmax(dim=-1, dtype=torch.float32)
        else:
            raise NotImplementedError(
                f"insupportable scoring function for MoE gating: {self.scoring_func}"
            )

        # Select top-k experts.
        if self.topk_method == "noaux_tc":
            # Bias only influences *selection*; final weights use raw scores.
            scores_for_choice = scores.view(
                n_tokens, -1
            ) + self.e_score_correction_bias.unsqueeze(0)
            # Each group is scored by the sum of its two best experts.
            group_scores = (
                scores_for_choice.view(n_tokens, self.n_group, -1)
                .topk(2, dim=-1)[0]
                .sum(dim=-1)
            )  # [n, n_group]
            group_idx = torch.topk(
                group_scores, k=self.topk_group, dim=-1, sorted=False
            )[1]  # [n, top_k_group]
            group_mask = torch.zeros_like(group_scores)  # [n, n_group]
            group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
            # Expand the group mask down to individual experts.
            score_mask = (
                group_mask.unsqueeze(-1)
                .expand(n_tokens, self.n_group, self.n_routed_experts // self.n_group)
                .reshape(n_tokens, -1)
            )  # [n, e]
            masked_scores = scores_for_choice.masked_fill(
                ~score_mask.bool(), 0.0
            )  # [n, e]
            _, topk_idx = torch.topk(masked_scores, k=self.top_k, dim=-1, sorted=False)
            topk_weight = scores.gather(1, topk_idx)
        elif self.topk_method == "greedy":
            topk_weight, topk_idx = torch.topk(
                scores, k=self.top_k, dim=-1, sorted=False
            )
        else:
            raise NotImplementedError(
                f"insupportable TopK function for MoE gating: {self.topk_method}"
            )

        # Normalize the selected gate weights to sum to 1, then rescale.
        if self.top_k > 1 and self.norm_topk_prob:
            topk_weight = topk_weight / (
                topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            )
        topk_weight = (
            topk_weight * self.routed_scaling_factor
        )  # must multiply the scaling factor

        return topk_idx, topk_weight
445
+
446
class MoE(nn.Module):
    """
    A mixed expert module containing shared experts.

    Routed experts are sharded across an expert-parallel (EP) process group;
    each rank owns `n_routed_experts // ep_size` experts and tokens are
    shuffled between ranks with all-to-all collectives before/after expert
    computation.
    """

    # Class attributes:
    # Two shuffle method supported:
    # 1. "torch_all_to_all"
    # 2. "symm_mem" (see `setup_symm_mem` below)
    shuffle_method = "torch_all_to_all"

    # Symmetric memory buffers shared by all MoE instances across layers
    token_send_buf: Optional[torch.Tensor] = None
    token_gather_buf: Optional[torch.Tensor] = None

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_experts_per_tok = config.num_experts_per_tok

        # ep_size is the number of ranks in expert dimension
        if config.ep_size <= 1:
            raise ValueError(
                "For code simplicity, this model only supports distributed experts, "
                "thus EP size must be > 1, please modify your model config"
            )
        self.ep_group = get_group("ep")
        assert config.ep_size == self.ep_group.size()
        self.ep_size = config.ep_size
        self.ep_rank = self.ep_group.rank()
        self.experts_per_rank = config.n_routed_experts // config.ep_size
        # Use ModuleDict instead of ModuleList to preserve absolute expert
        # IDs while avoiding `None` experts. The absolute expert IDs match
        # with checkpoint FQNs.
        self.experts = nn.ModuleDict()
        for i in range(self.experts_per_rank):
            abs_expert_id = self.ep_rank * self.experts_per_rank + i
            self.experts[str(abs_expert_id)] = MLP(
                config, intermediate_size=config.moe_intermediate_size
            )
        self.gate = MoEGate(config)
        if config.n_shared_experts is not None:
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
            self.shared_experts = MLP(
                config=config, intermediate_size=intermediate_size
            )

    def combine_experts(self, submod_name):
        """Concatenate the `submod_name` weights of all local experts into a
        single parameter (for grouped GEMM), detaching them from the MLPs."""
        all_weights = []
        for expert in self.experts.values():
            lin = expert.get_submodule(submod_name)
            all_weights.append(lin.weight)
            lin.weight = None

        concat_weight = torch.cat(all_weights)
        self.register_parameter(f"{submod_name}_weight", nn.Parameter(concat_weight))

    # This function is used to create a symm mem buffer for MoE's. It is for
    # shuffling tokens fully "on-device", as compared to traditional torch
    # all_to_all APIs which require a GPU-to-CPU sync of the splits. If a user
    # calls this function, the `shuffle_method` would switch from
    # `torch_all_to_all` to `symm_mem`.
    def setup_symm_mem(self, dtype: torch.dtype, device: torch.device):
        """Enable the symmetric-memory shuffle path and allocate the shared
        send/gather buffers (once per process, shared across layers)."""
        # Switch shuffle method
        self.shuffle_method = "symm_mem"

        # Combine expert weights
        print("Combining expert weights for Group GEMM")
        self.combine_experts("gate_proj")
        self.combine_experts("up_proj")
        self.combine_experts("down_proj")

        # Assuming worst case, 2x tokens are routed to one EP rank
        overflow = 2
        OnDeviceAllToAllV.max_output_len = (
            self.config.max_seq_len * self.num_experts_per_tok * overflow
        )

        # Symmetric memory buffers are shared by all MoE instances across
        # layers, we only need to initialize them once
        if MoE.token_send_buf is not None:
            return

        # Input buffer for DP-to-EP shuffle
        MoE.token_send_buf = symm_mem.empty(
            self.config.max_seq_len
            * self.num_experts_per_tok,  # seq len * top k (flattened)
            self.config.hidden_size,  # hidden dim
            dtype=dtype,
            device=device,
        )
        # Input buffer for EP-to-DP shuffle
        MoE.token_gather_buf = symm_mem.empty(
            self.config.max_seq_len
            * self.num_experts_per_tok  # seq len * top k (flattened)
            * overflow,
            self.config.hidden_size,  # hidden dim
            dtype=dtype,
            device=device,
        )
        print(f"EP rank [{self.ep_rank}]: Created Symmetric Memory for MoE")

    def get_send_buf(self):
        """Return the shared send buffer, detached from any previous graph."""
        # [Why detach?] During a first forward-backward step, the buffer would
        # be included in a computational graph. In a second step, autograd will
        # return an error saying "Trying to backward through the graph a second
        # time (or directly access saved tensors more than once)". This is
        # because the buffer is still in the graph, and autograd is trying to
        # backward through the graph a second time. To avoid this, we detach the
        # buffer from the graph. `detach()` returns a new tensor, which shares
        # the same storage with the original one.
        self.token_send_buf.grad = None
        return self.token_send_buf.detach()

    def get_gather_buf(self):
        """Return the shared gather buffer, detached from any previous graph."""
        # See [Why detach?] in `get_send_buf`
        self.token_gather_buf.grad = None
        return self.token_gather_buf.detach()

    def forward(self, hidden_states):
        """Route tokens to experts, run them, and add the shared-expert path.

        `hidden_states` is expected to be (batch, seq, hidden); shape is
        restored before returning.
        """
        identity = hidden_states
        orig_shape = hidden_states.shape
        # for each token, select top-k experts, and compute the weight for each expert
        topk_idx, topk_weight = self.gate(hidden_states)
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        if self.shuffle_method == "symm_mem":
            y = self.moe_on_device(hidden_states, topk_idx, topk_weight)
        else:  # "torch_all_to_all"
            y = self.moe_forward(hidden_states, topk_idx, topk_weight)

        y = y.view(*orig_shape)
        if self.config.n_shared_experts is not None:
            y = y + self.shared_experts(identity)
        return y

    def moe_forward(self, x, topk_ids, topk_weight):
        """Per-expert dispatch via all-to-all, loop over local expert MLPs,
        then reverse shuffle and weighted combine."""
        # This part sorts the token indices so that tokens routed to the same expert reside consecutively.
        # An implication is that tokens to the same "expert group" (i.e., device) are also consecutive.
        # Since this is an "artificial" index creation (final outcome being
        # `idxs`), we don't need gradients here.
        with torch.no_grad():
            # [seq_len, n_routed_experts]
            cnts = topk_ids.new_zeros((topk_ids.shape[0], self.config.n_routed_experts))
            # Fill 1 to the selected experts
            cnts.scatter_(1, topk_ids, 1)
            tokens_per_expert = cnts.sum(dim=0)
            # Token indices for each expert
            idxs = topk_ids.view(-1).argsort()
            sorted_tokens_shape = idxs.shape + x.shape[1:]

        sorted_tokens = x[idxs // topk_ids.shape[1]]
        assert sorted_tokens.shape == sorted_tokens_shape

        # This part exchanges the information about the number of tokens sent
        # and received by each expert. We can understand this information as
        # "side band", which is not part of the actual data. Thus no gradient
        # is needed.
        with torch.no_grad():
            # Sum the tokens over local experts, then we get tokens per EP rank,
            # which is the input splits
            tokens_per_expert_group = tokens_per_expert.new_empty(
                tokens_per_expert.shape[0]
            )
            dist.all_to_all_single(
                tokens_per_expert_group, tokens_per_expert, group=self.ep_group
            )
            input_splits = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)

        # DP to EP token shuffle. This part needs gradient.
        if self.shuffle_method == "symm_mem":
            # Move input to the `token_send_buf` symm mem
            token_send_buf = self.get_send_buf()
            token_send_buf[: idxs.shape[0]].copy_(sorted_tokens)
            # Note: `out=` avoids copy, but it is not differentiable
            # torch.index_select(x, 0, idxs // topk_ids.shape[1], out=self.token_send_buf[: idxs.shape[0]])
            token_gather_buf, output_splits = OnDeviceAllToAllV.apply(
                token_send_buf,
                input_splits,
                self.ep_group,
            )
            with torch.no_grad():
                # Received tokens from all other ranks. TODO: use mask instead
                received = output_splits.sum()
            # TODO: don't use `received`
            gathered_tokens = token_gather_buf[:received]
        else:  # "torch_all_to_all"
            # Prepare input and output splits
            with torch.no_grad():
                output_splits = tokens_per_expert_group.view(self.ep_size, -1).sum(
                    dim=1
                )
            gathered_tokens = all_to_all_single_autograd(
                sorted_tokens,
                output_splits.tolist(),
                input_splits.tolist(),
                self.ep_group,
            )

        # This part prepares a 1D tensor with the same length as
        # `gathered_tokens`. The 1D tensor is filled with local expert IDs which
        # the tokens in `gathered_tokens` are headed for. This part doesn't need
        # gradient.
        with torch.no_grad():
            gatherd_idxs = (
                torch.arange(
                    tokens_per_expert_group.numel(),
                    device=tokens_per_expert_group.device,
                )
                % self.experts_per_rank
            )
            gatherd_idxs = gatherd_idxs.repeat_interleave(tokens_per_expert_group)

        # Prepare buffer for tokens processed by experts
        if self.shuffle_method == "symm_mem":
            # Take necessary space from `token_gather_buf` symm mem because we are
            # going to send them out after expert processing
            processed_tokens = self.get_gather_buf()[: gathered_tokens.shape[0]]
        else:  # "torch_all_to_all"
            processed_tokens = torch.empty_like(gathered_tokens)

        # This part processes the tokens routed to the local experts.
        # TODO: can we use group GEMM here?
        for i, expert in enumerate(self.experts.values()):
            processed_tokens[gatherd_idxs == i] = expert(
                gathered_tokens[gatherd_idxs == i]
            )

        # Now shuffle the tokens back to their original owner, i.e. EP to DP shuffle.
        # The input/output splits are just a reverse of the previous shuffle.
        if self.shuffle_method == "symm_mem":
            token_return_buf, _ = OnDeviceAllToAllV.apply(
                processed_tokens,
                output_splits,
                self.ep_group,
            )
            returned_tokens = token_return_buf[: sorted_tokens_shape[0]]
        else:  # "torch_all_to_all"
            returned_tokens = all_to_all_single_autograd(
                processed_tokens,
                input_splits.tolist(),
                output_splits.tolist(),
                self.ep_group,
            )

        # Undo the argsort, then apply the gate weights and sum over top-k.
        output_tokens = torch.empty_like(returned_tokens)
        output_tokens[idxs] = returned_tokens
        final_out = (
            output_tokens.view(*topk_ids.shape, -1)
            .type(topk_weight.dtype)
            .mul_(topk_weight.unsqueeze(dim=-1))
            .sum(dim=1)
            .type(returned_tokens.dtype)
        )
        return final_out

    def moe_on_device(self, x, topk_ids, topk_weight):
        """Symmetric-memory variant of `moe_forward`: on-device all-to-all-v
        shuffles plus grouped GEMMs over the combined expert weights."""
        # This part sorts the token indices so that tokens routed to the same expert reside consecutively.
        # An implication is that tokens to the same "expert group" (i.e., device) are also consecutive.
        # Since this is an "artificial" index creation (final outcome being
        # `idxs`), we don't need gradients here.
        with torch.no_grad():
            # [seq_len, n_routed_experts]
            cnts = topk_ids.new_zeros((topk_ids.shape[0], self.config.n_routed_experts))
            # Fill 1 to the selected experts
            cnts.scatter_(1, topk_ids, 1)
            tokens_per_expert = cnts.sum(dim=0)
            # Token indices for each expert
            idxs = topk_ids.view(-1).argsort()
            sorted_tokens_shape = idxs.shape + x.shape[1:]

        sorted_tokens = x[idxs // topk_ids.shape[1]]
        assert sorted_tokens.shape == sorted_tokens_shape

        # This part exchanges the information about the number of tokens sent
        # and received by each expert. We can understand this information as
        # "side band", which is not part of the actual data. Thus no gradient
        # is needed.
        with torch.no_grad():
            # Sum the tokens over local experts, then we get tokens per EP rank,
            # which is the input splits
            tokens_per_expert_group = tokens_per_expert.new_empty(
                tokens_per_expert.shape[0]
            )
            dist.all_to_all_single(
                tokens_per_expert_group, tokens_per_expert, group=self.ep_group
            )
            input_splits = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)

        # Move input to the `token_send_buf` symm mem
        token_send_buf = self.get_send_buf()
        token_send_buf[: idxs.shape[0]].copy_(sorted_tokens)
        # Note: `out=` avoids copy, but it is not differentiable
        # torch.index_select(x, 0, idxs // topk_ids.shape[1], out=self.token_send_buf[: idxs.shape[0]])
        token_gather_buf, output_splits = OnDeviceAllToAllV.apply(
            token_send_buf,
            input_splits,
            self.ep_group,
        )

        # We need to permute the received tokens so that tokens for the same expert are contiguous.
        # This part prepares a 1D tensor `permuted_indices` for such permutation.
        # This part doesn't need gradient.
        with torch.no_grad():
            permuted_indices, m_sizes = generate_permute_indices(
                tokens_per_expert_group,
                self.experts_per_rank,
                self.ep_size,
                token_gather_buf.shape[0],
                ALIGN_SIZE_M,
            )

        # Permute the received tokens so that tokens for the same expert are contiguous.
        contig_tokens = token_gather_buf[permuted_indices]

        # Run the first grouped GEMM
        w1 = self.get_parameter("gate_proj_weight")
        gate_proj = grouped_gemm_forward(contig_tokens, w1, m_sizes)

        # Run the second grouped GEMM
        w3 = self.get_parameter("up_proj_weight")
        up_proj = grouped_gemm_forward(contig_tokens, w3, m_sizes)

        # Apply activation
        hidden_outputs = MLP.act_fn(gate_proj) * up_proj

        # Run the third grouped GEMM
        w2 = self.get_parameter("down_proj_weight")
        hidden_outputs = grouped_gemm_forward(hidden_outputs, w2, m_sizes)

        # Prepare buffer for tokens processed by experts
        # Take necessary space from `token_gather_buf` symm mem because we are
        # going to send them out after expert processing
        processed_tokens = self.get_gather_buf()

        # Move into Symmetric Memory for the return shuffle
        processed_tokens[permuted_indices] = hidden_outputs

        # Now shuffle the tokens back to their original owner, i.e. EP to DP shuffle.
        # The input/output splits are just a reverse of the previous shuffle.
        token_return_buf, _ = OnDeviceAllToAllV.apply(
            processed_tokens,
            output_splits,
            self.ep_group,
        )
        returned_tokens = token_return_buf[: sorted_tokens_shape[0]]

        # Undo the argsort, then apply the gate weights and sum over top-k.
        output_tokens = torch.empty_like(returned_tokens)
        output_tokens[idxs] = returned_tokens
        final_out = (
            output_tokens.view(*topk_ids.shape, -1)
            .type(topk_weight.dtype)
            .mul_(topk_weight.unsqueeze(dim=-1))
            .sum(dim=1)
            .type(returned_tokens.dtype)
        )
        return final_out
803
+
804
class Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    DeepSeek-style multi-head latent attention: queries and keys/values are
    factored through low-rank ("lora") bottlenecks, and only
    `qk_rope_head_dim` of each head's dimensions carry rotary embeddings.
    """

    def __init__(self, config: ModelArgs, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads

        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.q_lora_rank = config.q_lora_rank
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.kv_lora_rank = config.kv_lora_rank
        self.v_head_dim = config.v_head_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
        # Per-head query/key dim = non-rotary part + rotary part.
        self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim

        self.is_causal = True

        if self.q_lora_rank is None:
            # Full-rank query projection.
            self.q_proj = nn.Linear(
                self.hidden_size, self.num_heads * self.q_head_dim, bias=False
            )
        else:
            # Low-rank query path: down-project, RMSNorm, up-project.
            self.q_a_proj = nn.Linear(
                self.hidden_size, config.q_lora_rank, bias=config.attention_bias
            )
            self.q_a_layernorm = RMSNorm(config.q_lora_rank)
            self.q_b_proj = nn.Linear(
                config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
            )

        # Joint projection producing the compressed KV plus the shared
        # (single-head, "mqa") rotary key dims.
        self.kv_a_proj_with_mqa = nn.Linear(
            self.hidden_size,
            config.kv_lora_rank + config.qk_rope_head_dim,
            bias=config.attention_bias,
        )
        self.kv_a_layernorm = RMSNorm(config.kv_lora_rank)
        self.kv_b_proj = nn.Linear(
            config.kv_lora_rank,
            self.num_heads
            * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
            bias=False,
        )

        self.o_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=config.attention_bias,
        )
        self._init_rope()

        self.softmax_scale = self.q_head_dim ** (-0.5)
        if self.config.rope_scaling is not None:
            # YaRN attention-scale correction folded into the softmax scale.
            mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
            scaling_factor = self.config.rope_scaling["factor"]
            if mscale_all_dim:
                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
                self.softmax_scale = self.softmax_scale * mscale * mscale

    def _init_rope(self):
        """Build the rotary embedding module per the configured scaling type.

        Raises:
            ValueError: if `rope_scaling["type"]` is not one of
                "linear", "dynamic", or "yarn".
        """
        if self.config.rope_scaling is None:
            self.rotary_emb = RotaryEmbedding(
                self.qk_rope_head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = LinearScalingRotaryEmbedding(
                    self.qk_rope_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = DynamicNTKScalingRotaryEmbedding(
                    self.qk_rope_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "yarn":
                # Forward only the YaRN knobs actually present in the config.
                kwargs = {
                    key: self.config.rope_scaling[key]
                    for key in [
                        "original_max_position_embeddings",
                        "beta_fast",
                        "beta_slow",
                        "mscale",
                        "mscale_all_dim",
                    ]
                    if key in self.config.rope_scaling
                }
                self.rotary_emb = YarnRotaryEmbedding(
                    self.qk_rope_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                    **kwargs,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Compute MLA attention over `hidden_states` (batch, seq, hidden).

        Returns the attention output of shape (batch, seq, hidden).
        """
        bsz, q_len, _ = hidden_states.size()

        # Query path: full-rank or low-rank depending on config.
        if self.q_lora_rank is None:
            q = self.q_proj(hidden_states)
        else:
            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )

        # KV path: split off the shared rotary key dims, expand the rest.
        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        compressed_kv, k_pe = torch.split(
            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
        )
        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
        kv = (
            self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
            .transpose(1, 2)
        )

        k_nope, value_states = torch.split(
            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
        )
        kv_seq_len = value_states.shape[-2]

        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

        # Only the rotary sub-dimensions receive position information.
        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)

        query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
        query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
        query_states[:, :, :, self.qk_nope_head_dim :] = q_pe

        # k_pe has a single head and broadcasts across all heads here.
        key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
        key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
        key_states[:, :, :, self.qk_nope_head_dim :] = k_pe

        if attention_mask is not None:
            # Attention mask was made 4D because the `attn_weights` above is 4D.
            # We probably can make this mask smarter if we want to pack sequences
            # together, instead of using padding. This optimization can be used in
            # inference. For training, if we want to pack sequences, data loader
            # will pass in a mask containing such info.
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask,  # None, or user provided mask in 2D
                (bsz, q_len),
                hidden_states,
                0,  # past_key_values_length, 0 when training
            )
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query=query_states,
            key=key_states,
            value=value_states,
            attn_mask=attention_mask,
            # Fix: SDPA is functional and applies dropout unconditionally —
            # only drop during training, matching nn.Dropout semantics.
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=attention_mask is None,
            scale=self.softmax_scale,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
        attn_output = self.o_proj(attn_output)

        return attn_output
991
+
992
class DecoderLayer(nn.Module):
    """One transformer block: pre-norm attention and a (dense or MoE)
    feed-forward, each wrapped in a residual connection."""

    def __init__(self, config: ModelArgs, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = Attention(config=config, layer_idx=layer_idx)

        # MoE layers start at `first_k_dense_replace` and recur every
        # `moe_layer_freq` layers; all other layers use a dense MLP.
        use_moe = (
            config.n_routed_experts is not None
            and layer_idx >= config.first_k_dense_replace
            and layer_idx % config.moe_layer_freq == 0
        )
        self.mlp = MoE(config) if use_moe else MLP(config)
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
        """
        # Attention sub-block with pre-norm residual.
        residual = hidden_states
        normed = self.input_layernorm(hidden_states)
        attn_out = self.self_attn(
            hidden_states=normed,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )
        hidden_states = residual + attn_out

        # Feed-forward sub-block with pre-norm residual.
        residual = hidden_states
        ffn_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return residual + ffn_out
1046
+
1047
+ Deepseek_INPUTS_DOCSTRING = r"""
1048
+ Args:
1049
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1050
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1051
+ it.
1052
+
1053
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1054
+ [`PreTrainedTokenizer.__call__`] for details.
1055
+
1056
+ [What are input IDs?](../glossary#input-ids)
1057
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1058
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1059
+
1060
+ - 1 for tokens that are **not masked**,
1061
+ - 0 for tokens that are **masked**.
1062
+
1063
+ [What are attention masks?](../glossary#attention-mask)
1064
+
1065
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1066
+ [`PreTrainedTokenizer.__call__`] for details.
1067
+
1068
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
1069
+ `past_key_values`).
1070
+
1071
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
1072
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
1073
+ information on the default strategy.
1074
+
1075
+ - 1 indicates the head is **not masked**,
1076
+ - 0 indicates the head is **masked**.
1077
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1078
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
1079
+ config.n_positions - 1]`.
1080
+
1081
+ [What are position IDs?](../glossary#position-ids)
1082
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
1083
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
1084
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
1085
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
1086
+
1087
+ Two formats are allowed:
1088
+ - a [`~cache_utils.Cache`] instance;
1089
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
1090
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
1091
+ cache format.
1092
+
1093
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
1094
+ legacy cache format will be returned.
1095
+
1096
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
1097
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
1098
+ of shape `(batch_size, sequence_length)`.
1099
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1100
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1101
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1102
+ model's internal embedding lookup matrix.
1103
+ use_cache (`bool`, *optional*):
1104
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1105
+ `past_key_values`).
1106
+ output_attentions (`bool`, *optional*):
1107
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1108
+ tensors for more detail.
1109
+ output_hidden_states (`bool`, *optional*):
1110
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1111
+ more detail.
1112
+ return_dict (`bool`, *optional*):
1113
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1114
+ """
1115
+
1116
+
1117
+ class DeepseekModel(torch.nn.Module):
1118
+ """
1119
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DecoderLayer`]
1120
+
1121
+ Args:
1122
+ config: ModelArgs
1123
+ """
1124
+
1125
+ def __init__(self, config: ModelArgs):
1126
+ super().__init__()
1127
+ self.config = config
1128
+ self.padding_idx = config.pad_token_id
1129
+ self.vocab_size = config.vocab_size
1130
+
1131
+ # Creating model parts related to my stage
1132
+ assert (
1133
+ config.stage_idx < config.num_stages
1134
+ ), f"Stage {config.stage_idx} is not in the model"
1135
+ print(f"Creating model stage {config.stage_idx} of {config.num_stages}")
1136
+
1137
+ self.embed_tokens = (
1138
+ nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1139
+ if config.stage_idx == 0
1140
+ else None
1141
+ )
1142
+
1143
+ self.layers = torch.nn.ModuleDict()
1144
+ division = config.num_hidden_layers // config.num_stages
1145
+ residual = config.num_hidden_layers % config.num_stages
1146
+ # Some earlier stages may have 1 more layer than latter stages because
1147
+ # the division may have residual; this is more even than giving the
1148
+ # entire residual to the last stage.
1149
+ layers_per_stage = [
1150
+ division + 1 if stage < residual else division
1151
+ for stage in range(config.num_stages)
1152
+ ]
1153
+ assert sum(layers_per_stage) == config.num_hidden_layers
1154
+ layer_id_start = sum(layers_per_stage[: config.stage_idx])
1155
+ layer_id_end = layer_id_start + layers_per_stage[config.stage_idx]
1156
+ for layer_id in range(layer_id_start, layer_id_end):
1157
+ self.layers[str(layer_id)] = DecoderLayer(config, layer_id)
1158
+
1159
+ self.norm = (
1160
+ RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1161
+ if config.stage_idx == config.num_stages - 1
1162
+ else None
1163
+ )
1164
+
1165
+ # Initialize weights and apply final processing
1166
+ self.apply(self._init_weights)
1167
+
1168
+ def _init_weights(self, module):
1169
+ std = self.config.initializer_range
1170
+ if isinstance(module, nn.Linear):
1171
+ module.weight.data.normal_(mean=0.0, std=std)
1172
+ if module.bias is not None:
1173
+ module.bias.data.zero_()
1174
+ elif isinstance(module, nn.Embedding):
1175
+ module.weight.data.normal_(mean=0.0, std=std)
1176
+ if module.padding_idx is not None:
1177
+ module.weight.data[module.padding_idx].zero_()
1178
+
1179
+ def forward(
1180
+ self,
1181
+ tokens: torch.Tensor,
1182
+ attention_mask: Optional[torch.Tensor] = None,
1183
+ position_ids: Optional[torch.LongTensor] = None,
1184
+ ) -> torch.Tensor:
1185
+ # Embedding
1186
+ hidden_states = (
1187
+ self.embed_tokens(tokens) if self.embed_tokens is not None else tokens
1188
+ )
1189
+
1190
+ # decoder layers
1191
+ for decoder_layer in self.layers.values():
1192
+ hidden_states = decoder_layer(
1193
+ hidden_states,
1194
+ attention_mask=attention_mask,
1195
+ position_ids=position_ids,
1196
+ )
1197
+
1198
+ hidden_states = (
1199
+ self.norm(hidden_states) if self.norm is not None else hidden_states
1200
+ )
1201
+ return hidden_states
1202
+
1203
+
1204
+ class DeepseekForCausalLM(torch.nn.Module):
1205
+ def __init__(self, config):
1206
+ super().__init__()
1207
+ self.model = DeepseekModel(config)
1208
+ self.lm_head = (
1209
+ nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1210
+ if config.stage_idx == config.num_stages - 1
1211
+ else None
1212
+ )
1213
+
1214
+ # Initialize weights and apply final processing
1215
+ # self.post_init()
1216
+
1217
+ def forward(
1218
+ self,
1219
+ tokens: torch.Tensor,
1220
+ attention_mask: Optional[torch.Tensor] = None,
1221
+ position_ids: Optional[torch.LongTensor] = None,
1222
+ ) -> Tuple:
1223
+ r"""
1224
+ Example:
1225
+
1226
+ ```python
1227
+ >>> from transformers import AutoTokenizer, DeepseekForCausalLM
1228
+
1229
+ >>> model = DeepseekForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1230
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1231
+
1232
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1233
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1234
+
1235
+ >>> # Generate
1236
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1237
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1238
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1239
+ ```"""
1240
+ hidden_states = self.model(
1241
+ tokens,
1242
+ attention_mask=attention_mask,
1243
+ position_ids=position_ids,
1244
+ )
1245
+
1246
+ logits = (
1247
+ self.lm_head(hidden_states) if self.lm_head is not None else hidden_states
1248
+ )
1249
+ return logits
1250
+
1251
+ def prepare_inputs_for_generation(
1252
+ self,
1253
+ input_ids,
1254
+ past_key_values=None,
1255
+ attention_mask=None,
1256
+ **kwargs,
1257
+ ):
1258
+ if past_key_values is not None:
1259
+ # Assuming isinstance(past_key_values, Cache):
1260
+ cache_length = past_key_values.get_seq_length()
1261
+ past_length = past_key_values.seen_tokens
1262
+ max_cache_length = past_key_values.get_max_length()
1263
+
1264
+ # Keep only the unprocessed tokens:
1265
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1266
+ # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as
1267
+ # input)
1268
+ if (
1269
+ attention_mask is not None
1270
+ and attention_mask.shape[1] > input_ids.shape[1]
1271
+ ):
1272
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1273
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1274
+ # input_ids based on the past_length.
1275
+ elif past_length < input_ids.shape[1]:
1276
+ input_ids = input_ids[:, past_length:]
1277
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1278
+
1279
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1280
+ if (
1281
+ max_cache_length is not None
1282
+ and attention_mask is not None
1283
+ and cache_length + input_ids.shape[1] > max_cache_length
1284
+ ):
1285
+ attention_mask = attention_mask[:, -max_cache_length:]
1286
+
1287
+ position_ids = kwargs.get("position_ids", None)
1288
+ if attention_mask is not None and position_ids is None:
1289
+ # create position_ids on the fly for batch generation
1290
+ position_ids = attention_mask.long().cumsum(-1) - 1
1291
+ position_ids.masked_fill_(attention_mask == 0, 1)
1292
+ if past_key_values:
1293
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1294
+
1295
+ model_inputs = {"input_ids": input_ids}
1296
+
1297
+ model_inputs.update(
1298
+ {
1299
+ "position_ids": position_ids,
1300
+ "past_key_values": past_key_values,
1301
+ "use_cache": kwargs.get("use_cache"),
1302
+ "attention_mask": attention_mask,
1303
+ }
1304
+ )
1305
+ return model_inputs
1306
+
1307
+ @staticmethod
1308
+ def _reorder_cache(past_key_values, beam_idx):
1309
+ reordered_past = ()
1310
+ for layer_past in past_key_values:
1311
+ reordered_past += (
1312
+ tuple(
1313
+ past_state.index_select(0, beam_idx.to(past_state.device))
1314
+ for past_state in layer_past
1315
+ ),
1316
+ )
1317
+ return reordered_past
1318
+
1319
+ # Setup Symmetric Memory for MoE token shuffle.
1320
+ # Supports inference currently.
1321
+ def setup_symm_mem(self, dtype: torch.dtype, device: torch.device):
1322
+ for layer in self.model.layers.values():
1323
+ if not isinstance(layer.mlp, MoE):
1324
+ continue
1325
+ layer.mlp.setup_symm_mem(dtype, device)
torchtitan/experiments/deepseek_v3/requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ transformers
2
+ accelerate
3
+ torchdata >= 0.8.0
4
+ datasets >= 2.21.0
5
+ tomli >= 1.1.0 ; python_version < "3.11"
torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_utils.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import triton
8
+ import triton.language as tl
9
+
10
+
11
+ @triton.jit
12
+ def get_tid():
13
+ return tl.inline_asm_elementwise(
14
+ """
15
+ mov.u32 $0, %tid.x;
16
+ mov.u32 $1, %tid.y;
17
+ mov.u32 $2, %tid.z;
18
+ """,
19
+ "=r,=r,=r",
20
+ [],
21
+ dtype=(tl.uint32, tl.uint32, tl.uint32),
22
+ is_pure=True,
23
+ pack=1,
24
+ )
25
+
26
+
27
+ @triton.jit
28
+ def get_ntid():
29
+ return tl.inline_asm_elementwise(
30
+ """
31
+ mov.u32 $0, %ntid.x;
32
+ mov.u32 $1, %ntid.y;
33
+ mov.u32 $2, %ntid.z;
34
+ """,
35
+ "=r,=r,=r",
36
+ [],
37
+ dtype=(tl.uint32, tl.uint32, tl.uint32),
38
+ is_pure=True,
39
+ pack=1,
40
+ )
41
+
42
+
43
+ @triton.jit
44
+ def get_flat_tid():
45
+ tid_x, tid_y, tid_z = get_tid()
46
+ ntid_x, ntid_y, _ = get_ntid()
47
+ return tid_z * ntid_y * ntid_x + tid_y * ntid_x + tid_x
48
+
49
+
50
+ @triton.jit
51
+ def get_flat_bid():
52
+ return (
53
+ tl.program_id(2) * tl.num_programs(1) * tl.num_programs(0)
54
+ + tl.program_id(1) * tl.num_programs(0)
55
+ + tl.program_id(0)
56
+ )
57
+
58
+
59
+ @triton.jit
60
+ def sync_threads():
61
+ tl.inline_asm_elementwise(
62
+ "bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1
63
+ )
torchtitan/experiments/flux/README.md ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # FLUX model in torchtitan
2
+
3
+ ## Overview
4
+
5
+ ## Usage
6
+ First, download the autoencoder model from HuggingFace with your own access token:
7
+ ```bash
8
+ python torchtitan/experiments/flux/scripts/download_autoencoder.py --repo_id black-forest-labs/FLUX.1-dev --ae_path ae.safetensors --hf_token <your_access_token>
9
+ ```
10
+ This step will download the autoencoder model from HuggingFace and save it to the `torchtitan/experiments/flux/assets/autoencoder/ae.safetensors` file.
11
+
12
+ Run the following command to train the model on a single GPU:
13
+ ```bash
14
+ PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True torchrun --nproc_per_node=1 torchtitan/experiments/flux/train.py --job.config_file torchtitan/experiments/flux/train_configs/debug_model.toml
15
+ ```
16
+
17
+ ## TODO
18
+ - [ ] Supporting for multiple GPUs is comming soon (FSDP, etc)
19
+ - [ ] Implement test cases in CI for FLUX model. Adding more unit tests for FLUX model (eg, unit test for preprocessor, etc)
20
+ - [ ] More parallesim support (Tensor Parallelism, Context Parallelism, etc)
21
+ - [ ] Support for distributed checkpointing and loading
22
+ - [ ] Implement init_weights() function to initialize the model weights
23
+ - [ ] Implement the num_flops_per_token calculation in get_nparams_and_flops() function
torchtitan/experiments/flux/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (2.08 kB). View file
 
torchtitan/experiments/flux/dataset/flux_dataset.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ import random
9
+ from dataclasses import dataclass
10
+ from typing import Any, Callable, Optional
11
+
12
+ import numpy as np
13
+
14
+ import torch
15
+
16
+ from datasets import Dataset, load_dataset
17
+ from datasets.distributed import split_dataset_by_node
18
+ from PIL import Image
19
+
20
+ from torch.distributed.checkpoint.stateful import Stateful
21
+
22
+ from torch.utils.data import IterableDataset
23
+ from torchtitan.components.dataloader import ParallelAwareDataloader
24
+
25
+ from torchtitan.config_manager import JobConfig
26
+ from torchtitan.experiments.flux.dataset.tokenizer import FluxTokenizer
27
+ from torchtitan.tools.logging import logger
28
+
29
+
30
+ def _process_cc12m_image(
31
+ img: Image.Image,
32
+ output_size: int = 256,
33
+ ) -> Optional[torch.Tensor]:
34
+ """Process CC12M image to the desired size."""
35
+
36
+ width, height = img.size
37
+ # Skip low resolution images
38
+ if width < output_size or height < output_size:
39
+ return None
40
+
41
+ if width >= height:
42
+ # resize height to be equal to output_size, then crop
43
+ new_width, new_height = math.ceil(output_size / height * width), output_size
44
+ img = img.resize((new_width, new_height))
45
+ left = random.randint(0, new_width - output_size)
46
+ resized_img = img.crop((left, 0, left + output_size, output_size))
47
+ else:
48
+ # resize width to be equal to output_size, the crop
49
+ new_width, new_height = (
50
+ output_size,
51
+ math.ceil(output_size / width * height),
52
+ )
53
+ img = img.resize((new_width, new_height))
54
+ lower = random.randint(0, new_width - output_size)
55
+ resized_img = img.crop((0, lower, output_size, lower + output_size))
56
+
57
+ assert resized_img.size[0] == resized_img.size[1] == output_size
58
+
59
+ # Skip grayscale images
60
+ if resized_img.mode == "L":
61
+ return None
62
+
63
+ np_img = np.array(resized_img).transpose((2, 0, 1))
64
+ tensor_img = torch.tensor(np_img).float() / 255.0
65
+
66
+ # NOTE: The following commented code is an alternative way
67
+ # img_transform = transforms.Compose(
68
+ # [
69
+ # transforms.Resize(max(output_size, output_size)),
70
+ # transforms.CenterCrop((output_size, output_size)),
71
+ # transforms.ToTensor(),
72
+ # ]
73
+ # )
74
+ # tensor_img = img_transform(img)
75
+
76
+ return tensor_img
77
+
78
+
79
+ def _flux_data_processor(
80
+ sample: dict[str, Any],
81
+ t5_tokenizer: FluxTokenizer,
82
+ clip_tokenizer: FluxTokenizer,
83
+ output_size: int = 256,
84
+ ) -> dict[str, Any]:
85
+ """
86
+ Preprocess CC12M dataset sample image and text for Flux model.
87
+
88
+ Args:
89
+ sample: A sample from dataset
90
+ t5_encoder: T5 encoder
91
+ clip_encoder: CLIP encoder
92
+ output_size: The output image size
93
+
94
+ """
95
+ img = _process_cc12m_image(sample["jpg"], output_size=output_size)
96
+ t5_tokens = t5_tokenizer.encode(sample["txt"])
97
+ clip_tokens = clip_tokenizer.encode(sample["txt"])
98
+
99
+ return {
100
+ "image": img,
101
+ "clip_tokens": clip_tokens, # type: List[int]
102
+ "t5_tokens": t5_tokens, # type: List[int]
103
+ }
104
+
105
+
106
+ @dataclass
107
+ class TextToImageDatasetConfig:
108
+ path: str
109
+ loader: Callable
110
+ data_processor: Callable
111
+
112
+
113
+ DATASETS = {
114
+ "cc12m": TextToImageDatasetConfig(
115
+ path="pixparse/cc12m-wds",
116
+ loader=lambda path: load_dataset(path, split="train", streaming=True),
117
+ data_processor=_flux_data_processor,
118
+ ),
119
+ }
120
+
121
+
122
+ def _validate_dataset(
123
+ dataset_name: str, dataset_path: Optional[str] = None
124
+ ) -> tuple[str, Callable, Callable]:
125
+ """Validate dataset name and path."""
126
+ if dataset_name not in DATASETS:
127
+ raise ValueError(
128
+ f"Dataset {dataset_name} is not supported. "
129
+ f"Supported datasets are: {list(DATASETS.keys())}"
130
+ )
131
+
132
+ config = DATASETS[dataset_name]
133
+ path = dataset_path or config.path
134
+ logger.info(f"Preparing {dataset_name} dataset from {path}")
135
+ return path, config.loader, config.data_processor
136
+
137
+
138
+ class FluxDataset(IterableDataset, Stateful):
139
+ """Dataset for FLUX text-to-image model.
140
+
141
+ Args:
142
+ dataset_name (str): Name of the dataset.
143
+ dataset_path (str): Path to the dataset.
144
+ model_transform (Transform): Callable that applies model-specific preprocessing to the sample.
145
+ dp_rank (int): Data parallel rank.
146
+ dp_world_size (int): Data parallel world size.
147
+ infinite (bool): Whether to loop over the dataset infinitely.
148
+ """
149
+
150
+ def __init__(
151
+ self,
152
+ dataset_name: str,
153
+ dataset_path: Optional[str],
154
+ t5_tokenizer: FluxTokenizer,
155
+ clip_tokenizer: FluxTokenizer,
156
+ job_config: Optional[JobConfig] = None,
157
+ dp_rank: int = 0,
158
+ dp_world_size: int = 1,
159
+ infinite: bool = False,
160
+ ) -> None:
161
+
162
+ # Force lowercase for consistent comparison
163
+ dataset_name = dataset_name.lower()
164
+
165
+ path, dataset_loader, data_processor = _validate_dataset(
166
+ dataset_name, dataset_path
167
+ )
168
+ ds = dataset_loader(path)
169
+
170
+ self.dataset_name = dataset_name
171
+ self._data = split_dataset_by_node(ds, dp_rank, dp_world_size)
172
+
173
+ self._t5_tokenizer = t5_tokenizer
174
+ self._clip_tokenizer = clip_tokenizer
175
+ self._data_processor = data_processor
176
+ self.job_config = job_config
177
+
178
+ self.infinite = infinite
179
+
180
+ # Variables for checkpointing
181
+ self._sample_idx = 0
182
+ self._all_samples: list[dict[str, Any]] = []
183
+
184
+ def _get_data_iter(self):
185
+ if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
186
+ return iter([])
187
+
188
+ it = iter(self._data)
189
+ for _ in range(self._sample_idx):
190
+ next(it)
191
+ return it
192
+
193
+ def __iter__(self):
194
+ while True:
195
+ for sample in self._get_data_iter():
196
+ # Use the dataset-specific preprocessor
197
+ sample_dict = self._data_processor(
198
+ sample, self._t5_tokenizer, self._clip_tokenizer, output_size=256
199
+ )
200
+
201
+ # skip low quality image or image with color channel = 1
202
+ if sample_dict["image"] is None:
203
+ logger.warning(
204
+ f"Low quality image {sample['__key__']} is skipped in Flux Dataloader"
205
+ )
206
+ continue
207
+
208
+ self._all_samples.extend(sample_dict)
209
+ self._sample_idx += 1
210
+
211
+ labels = sample_dict.pop("image")
212
+ yield sample_dict, labels
213
+
214
+ if not self.infinite:
215
+ logger.warning(f"Dataset {self.dataset_name} has run out of data")
216
+ break
217
+ else:
218
+ # Reset offset for the next iteration
219
+ self._sample_idx = 0
220
+ logger.warning(f"Dataset {self.dataset_name} is being re-looped")
221
+
222
+ def load_state_dict(self, state_dict):
223
+ self._sample_idx = state_dict["sample_idx"]
224
+ self._all_samples = state_dict["all_samples"]
225
+
226
+ def state_dict(self):
227
+ return {
228
+ "all_samples": self._all_samples,
229
+ "sample_idx": self._sample_idx,
230
+ }
231
+
232
+
233
+ def build_flux_dataloader(
234
+ dp_world_size: int,
235
+ dp_rank: int,
236
+ job_config: JobConfig,
237
+ # This parameter is not used, keep it for compatibility
238
+ tokenizer: FluxTokenizer | None,
239
+ infinite: bool = True,
240
+ ) -> ParallelAwareDataloader:
241
+ """Build a data loader for HuggingFace datasets."""
242
+ dataset_name = job_config.training.dataset
243
+ dataset_path = job_config.training.dataset_path
244
+ batch_size = job_config.training.batch_size
245
+
246
+ t5_encoder_name = job_config.encoder.t5_encoder
247
+ clip_encoder_name = job_config.encoder.clip_encoder
248
+ max_t5_encoding_len = job_config.encoder.max_t5_encoding_len
249
+
250
+ ds = FluxDataset(
251
+ dataset_name=dataset_name,
252
+ dataset_path=dataset_path,
253
+ t5_tokenizer=FluxTokenizer(t5_encoder_name, max_length=max_t5_encoding_len),
254
+ clip_tokenizer=FluxTokenizer(
255
+ clip_encoder_name, max_length=77
256
+ ), # fix max_length for CLIP
257
+ dp_rank=dp_rank,
258
+ dp_world_size=dp_world_size,
259
+ infinite=infinite,
260
+ )
261
+
262
+ return ParallelAwareDataloader(
263
+ dataset=ds,
264
+ dp_rank=dp_rank,
265
+ dp_world_size=dp_world_size,
266
+ batch_size=batch_size,
267
+ )
torchtitan/experiments/flux/dataset/tokenizer.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
8
+ # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
9
+
10
+
11
+ from typing import List
12
+
13
+ from torchtitan.components.tokenizer import Tokenizer
14
+ from transformers import CLIPTokenizer, T5Tokenizer
15
+
16
+
17
+ class FluxTokenizer(Tokenizer):
18
+ """
19
+ Tokenizing and encoding/decoding text using the T5 or Clip tokenizer.
20
+
21
+ Args:
22
+ model_path (str): Path to the tokenzier from hugging face.
23
+
24
+ """
25
+
26
+ def __init__(self, model_path: str = "t5-small", max_length: int = 77):
27
+ super().__init__()
28
+ self._n_words = 8 # TODO(jianiw): check
29
+ self._max_length = max_length
30
+
31
+ self.is_clip = model_path.startswith("openai")
32
+
33
+ if self.is_clip:
34
+ self._tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(
35
+ model_path, max_length=max_length
36
+ )
37
+ else:
38
+ self._tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(
39
+ model_path, max_length=max_length
40
+ )
41
+
42
+ def encode(
43
+ self,
44
+ s: str,
45
+ ) -> List[int]:
46
+ """
47
+ Encode the prompt text into tokens.
48
+ """
49
+ tokens = self._tokenizer(
50
+ s,
51
+ truncation=True,
52
+ max_length=self._max_length,
53
+ return_length=False,
54
+ return_overflowing_tokens=False,
55
+ padding="max_length",
56
+ return_tensors="pt", # return pytorch tensors, default return List[int]
57
+ )["input_ids"]
58
+ return tokens
59
+
60
+ def decode(self, t: List[int]) -> str:
61
+ """
62
+ Decode function. This function will not be called.
63
+ """
64
+ return self._tokenizer.decode(t)
torchtitan/experiments/flux/model/__pycache__/layers.cpython-312.pyc ADDED
Binary file (17.7 kB). View file
 
torchtitan/experiments/flux/model/hf_embedder.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from torch import nn, Tensor
8
+ from transformers import CLIPTextModel, T5EncoderModel
9
+
10
+
11
+ class FluxEmbedder(nn.Module):
12
+ def __init__(self, version: str, **hf_kwargs):
13
+ super().__init__()
14
+ self.is_clip = version.startswith("openai")
15
+ self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
16
+
17
+ if self.is_clip:
18
+ self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(
19
+ version, **hf_kwargs
20
+ )
21
+ else:
22
+ self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(
23
+ version, **hf_kwargs
24
+ )
25
+
26
+ self.hf_module = self.hf_module.eval().requires_grad_(False)
27
+
28
+ def forward(self, batch_tokens: Tensor) -> Tensor:
29
+ """
30
+ batch_tokens: [bsz, embedding_length]
31
+
32
+ For T5 Encoder, embeding_length is 768
33
+ For CLIP, embedding_length is 256
34
+ """
35
+ outputs = self.hf_module(
36
+ input_ids=batch_tokens.to(self.hf_module.device),
37
+ attention_mask=None,
38
+ output_hidden_states=False,
39
+ )
40
+ return outputs[self.output_key]
torchtitan/experiments/flux/model/model.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from dataclasses import dataclass, field
8
+
9
+ import torch
10
+
11
+ from torch import nn, Tensor
12
+ from torchtitan.components.tokenizer import Tokenizer
13
+ from torchtitan.config_manager import JobConfig
14
+
15
+ from torchtitan.experiments.flux.model.autoencoder import AutoEncoderParams
16
+ from torchtitan.experiments.flux.model.layers import (
17
+ DoubleStreamBlock,
18
+ EmbedND,
19
+ LastLayer,
20
+ MLPEmbedder,
21
+ SingleStreamBlock,
22
+ timestep_embedding,
23
+ )
24
+
25
+ from torchtitan.protocols.train_spec import BaseModelArgs, ModelProtocol
26
+ from torchtitan.tools.logging import logger
27
+
28
+
29
+ @dataclass
30
+ class FluxModelArgs(BaseModelArgs):
31
+ in_channels: int = 64
32
+ out_channels: int = 64
33
+ vec_in_dim: int = 768
34
+ context_in_dim: int = 512
35
+ hidden_size: int = 3072
36
+ mlp_ratio: float = 4.0
37
+ num_heads: int = 24
38
+ depth: int = 19
39
+ depth_single_blocks: int = 38
40
+ axes_dim: tuple = (16, 56, 56)
41
+ theta: int = 10_000
42
+ qkv_bias: bool = True
43
+ guidance_embed: bool = True
44
+ autoencoder_params: AutoEncoderParams = field(default_factory=AutoEncoderParams)
45
+
46
+ def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
47
+ # context_in_dim is the same as the T5 embedding dimension
48
+ self.context_in_dim = job_config.encoder.max_t5_encoding_len
49
+
50
+ def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
51
+ # TODO(jianiw): Add the number of flops for the autoencoder
52
+ nparams = sum(p.numel() for p in model.parameters())
53
+ logger.warning("FLUX model haven't implement get_nparams_and_flops() function")
54
+ return nparams, 1
55
+
56
+
57
+ class FluxModel(nn.Module, ModelProtocol):
58
+ """
59
+ Transformer model for flow matching on sequences.
60
+
61
+ Agrs:
62
+ model_args: FluxModelArgs.
63
+
64
+ Attributes:
65
+ model_args (TransformerModelArgs): Model configuration arguments.
66
+ """
67
+
68
+ def __init__(self, model_args: FluxModelArgs):
69
+ super().__init__()
70
+
71
+ self.model_args = model_args
72
+ self.in_channels = model_args.in_channels
73
+ self.out_channels = model_args.out_channels
74
+ if model_args.hidden_size % model_args.num_heads != 0:
75
+ raise ValueError(
76
+ f"Hidden size {model_args.hidden_size} must be divisible by num_heads {model_args.num_heads}"
77
+ )
78
+ pe_dim = model_args.hidden_size // model_args.num_heads
79
+ if sum(model_args.axes_dim) != pe_dim:
80
+ raise ValueError(
81
+ f"Got {model_args.axes_dim} but expected positional dim {pe_dim}"
82
+ )
83
+ self.hidden_size = model_args.hidden_size
84
+ self.num_heads = model_args.num_heads
85
+ self.pe_embedder = EmbedND(
86
+ dim=pe_dim, theta=model_args.theta, axes_dim=model_args.axes_dim
87
+ )
88
+ self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
89
+ self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
90
+ self.vector_in = MLPEmbedder(model_args.vec_in_dim, self.hidden_size)
91
+ self.guidance_in = (
92
+ MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
93
+ if model_args.guidance_embed
94
+ else nn.Identity()
95
+ )
96
+ self.txt_in = nn.Linear(model_args.context_in_dim, self.hidden_size)
97
+
98
+ self.double_blocks = nn.ModuleList(
99
+ [
100
+ DoubleStreamBlock(
101
+ self.hidden_size,
102
+ self.num_heads,
103
+ mlp_ratio=model_args.mlp_ratio,
104
+ qkv_bias=model_args.qkv_bias,
105
+ )
106
+ for _ in range(model_args.depth)
107
+ ]
108
+ )
109
+
110
+ self.single_blocks = nn.ModuleList(
111
+ [
112
+ SingleStreamBlock(
113
+ self.hidden_size, self.num_heads, mlp_ratio=model_args.mlp_ratio
114
+ )
115
+ for _ in range(model_args.depth_single_blocks)
116
+ ]
117
+ )
118
+
119
+ self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
120
+
121
+ def init_weights(self, buffer_device=None):
122
+ # TODO(jianiw): replace placeholder with real weight init
123
+ for param in self.parameters():
124
+ param.data.uniform_(0, 0.1)
125
+
126
+ def forward(
127
+ self,
128
+ img: Tensor,
129
+ img_ids: Tensor,
130
+ txt: Tensor,
131
+ txt_ids: Tensor,
132
+ timesteps: Tensor,
133
+ y: Tensor,
134
+ guidance: Tensor | None = None,
135
+ ) -> Tensor:
136
+ if img.ndim != 3 or txt.ndim != 3:
137
+ raise ValueError("Input img and txt tensors must have 3 dimensions.")
138
+
139
+ # running on sequences img
140
+ img = self.img_in(img)
141
+ vec = self.time_in(timestep_embedding(timesteps, 256))
142
+ if self.model_args.guidance_embed:
143
+ if guidance is None:
144
+ raise ValueError(
145
+ "Didn't get guidance strength for guidance distilled model."
146
+ )
147
+ vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
148
+ vec = vec + self.vector_in(y)
149
+ txt = self.txt_in(txt)
150
+
151
+ ids = torch.cat((txt_ids, img_ids), dim=1)
152
+ pe = self.pe_embedder(ids)
153
+
154
+ for block in self.double_blocks:
155
+ img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
156
+
157
+ img = torch.cat((txt, img), 1)
158
+ for block in self.single_blocks:
159
+ img = block(img, vec=vec, pe=pe)
160
+ img = img[:, txt.shape[1] :, ...]
161
+
162
+ img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
163
+ return img
164
+
165
+ @classmethod
166
+ def from_model_args(cls, model_args: FluxModelArgs) -> "FluxModel":
167
+ """
168
+ Initialize a Flux model from a FluxModelArgs object.
169
+
170
+ Args:
171
+ model_args (FluxModelArgs): Model configuration arguments.
172
+
173
+ Returns:
174
+ FluxModel: FluxModel model.
175
+
176
+ """
177
+ return cls(model_args)
torchtitan/experiments/flux/tests/test_flux_dataloader.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import sys
8
+
9
+ from torchtitan.config_manager import JobConfig
10
+ from torchtitan.experiments.flux.dataset.flux_dataset import build_flux_dataloader
11
+ from torchtitan.tools.profiling import (
12
+ maybe_enable_memory_snapshot,
13
+ maybe_enable_profiling,
14
+ )
15
+
16
+
17
class TestFluxDataLoader:
    """Smoke tests for the Flux dataloader (requires dataset + encoders)."""

    def test_flux_dataloader(self):
        """Iterate a few batches from the cc12m dataloader and check shapes."""
        dataset_name = "cc12m"
        batch_size = 32
        world_size = 4
        rank = 0

        num_steps = 10

        # NOTE(review): appending to sys.argv mutates global state and leaks
        # into any test run afterwards in the same process — confirm isolation.
        path = "torchtitan.experiments.flux.flux_argparser"
        sys.argv.append(f"--experimental.custom_args_module={path}")
        config = JobConfig()
        config.maybe_add_custom_args()
        config.parse_args(
            [
                # Profiling options
                # "--profiling.enable_profiling",
                # "--profiling.profile_freq",
                # "5",
                # "--profiling.enable_memory_snapshot",
                # "--profiling.save_memory_snapshot_folder",
                # "memory_snapshot_flux",
                "--training.dataset",
                dataset_name,
                "--training.batch_size",
                str(batch_size),
                "--encoder.t5_encoder",
                "google/t5-v1_1-small",
                "--encoder.clip_encoder",
                "openai/clip-vit-large-patch14",
                "--encoder.max_t5_encoding_len",
                "512",
            ]
        )

        with maybe_enable_profiling(
            config, global_step=0
        ) as torch_profiler, maybe_enable_memory_snapshot(
            config, global_step=0
        ) as memory_profiler:
            dl = self._build_dataloader(
                config,
                world_size,
                rank,
            )
            dl = iter(dl)

            for i in range(0, num_steps):
                input_data, labels = next(dl)
                print(f"Step {i} image size: {labels.shape}")
                if torch_profiler:
                    torch_profiler.step()
                if memory_profiler:
                    memory_profiler.step()

            # `i` and `input_data`/`labels` below refer to the last loop
            # iteration (loop-variable leakage is intentional here).
            print(len(input_data["clip_tokens"]))
            for k, v in input_data.items():
                print(f"Step {i} {k} value: {type(v), v.shape}")

            assert len(input_data) == 2  # (clip_encodings, t5_encodings)
            assert labels.shape == (batch_size, 3, 256, 256)
            # assert input_data["clip_tokens"].shape[0] == batch_size
            # assert input_data["t5_tokens"].shape == (batch_size, 512, 512)

            if torch_profiler:
                torch_profiler.step()
            if memory_profiler:
                memory_profiler.step(exit_ctx=True)

    def test_preprocess(self):
        # TODO
        pass

    def _build_dataloader(
        self,
        job_config,
        world_size,
        rank,
    ):
        # Helper: builds a non-infinite dataloader for one data-parallel rank.
        # `tokenizer=None` — the Flux dataloader constructs its own tokenizers
        # from the encoder config (presumably; verify in build_flux_dataloader).
        return build_flux_dataloader(
            dp_world_size=world_size,
            dp_rank=rank,
            job_config=job_config,
            tokenizer=None,
            infinite=False,
        )
torchtitan/experiments/flux/tests/test_generate_image.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ import os
9
+ import time
10
+ from typing import Callable
11
+
12
+ import torch
13
+ from einops import rearrange
14
+
15
+ from PIL import ExifTags, Image
16
+
17
+ from torch import Tensor
18
+
19
+ from torchtitan.experiments.flux.dataset.tokenizer import FluxTokenizer
20
+
21
+ from torchtitan.experiments.flux.model.autoencoder import (
22
+ AutoEncoder,
23
+ AutoEncoderParams,
24
+ load_ae,
25
+ )
26
+ from torchtitan.experiments.flux.model.hf_embedder import FluxEmbedder
27
+
28
+ from torchtitan.experiments.flux.model.model import FluxModel, FluxModelArgs
29
+ from torchtitan.experiments.flux.utils import (
30
+ create_position_encoding_for_latents,
31
+ generate_noise_latent,
32
+ pack_latents,
33
+ preprocess_flux_data,
34
+ unpack_latents,
35
+ )
36
+
37
+
38
+ def time_shift(mu: float, sigma: float, t: Tensor):
39
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
40
+
41
+
42
+ def get_lin_function(
43
+ x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
44
+ ) -> Callable[[float], float]:
45
+ m = (y2 - y1) / (x2 - x1)
46
+ b = y1 - m * x1
47
+ return lambda x: m * x + b
48
+
49
+
50
+ def get_schedule(
51
+ num_steps: int,
52
+ image_seq_len: int,
53
+ base_shift: float = 0.5,
54
+ max_shift: float = 1.15,
55
+ shift: bool = True,
56
+ ) -> list[float]:
57
+ # extra step for zero
58
+ timesteps = torch.linspace(1, 0, num_steps + 1)
59
+
60
+ # shifting the schedule to favor high timesteps for higher signal images
61
+ if shift:
62
+ # estimate mu based on linear estimation between two points
63
+ mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
64
+ timesteps = time_shift(mu, 1.0, timesteps)
65
+
66
+ return timesteps.tolist()
67
+
68
+
69
+ class TestGenerateImage:
70
+ def test_generate_image(self):
71
+ """
72
+ Run a forward pass of flux model to generate an image.
73
+ """
74
+ name = "flux-dev"
75
+ img_width = 512
76
+ img_height = 512
77
+ seed = None
78
+ prompt = (
79
+ "a photo of a forest with mist swirling around the tree trunks. The word "
80
+ '"FLUX" is painted over it in big, red brush strokes with visible texture'
81
+ )
82
+ device = "cuda"
83
+ num_steps = None
84
+ loop = False
85
+ guidance = 3.5
86
+ output_dir = "output"
87
+ add_sampling_metadata = True
88
+
89
+ prompt = prompt.split("|")
90
+ if len(prompt) == 1:
91
+ prompt = prompt[0]
92
+ additional_prompts = None
93
+ else:
94
+ additional_prompts = prompt[1:]
95
+ prompt = prompt[0]
96
+
97
+ assert not (
98
+ (additional_prompts is not None) and loop
99
+ ), "Do not provide additional prompts and set loop to True"
100
+
101
+ torch_device = torch.device(device)
102
+ if num_steps is None:
103
+ num_steps = 30
104
+
105
+ # allow for packing and conversion to latent space
106
+ img_height = 16 * (img_height // 16)
107
+ img_width = 16 * (img_width // 16)
108
+
109
+ # init all components
110
+ model = FluxModel(FluxModelArgs()).to(device=torch_device, dtype=torch.bfloat16)
111
+
112
+ ae = load_ae(
113
+ ckpt_path="assets/autoencoder/ae.safetensors",
114
+ autoencoder_params=AutoEncoderParams(),
115
+ device=torch_device,
116
+ dtype=torch.bfloat16,
117
+ )
118
+ clip_tokenizer = FluxTokenizer(
119
+ model_path="openai/clip-vit-large-patch14", max_length=77
120
+ )
121
+ t5_tokenizer = FluxTokenizer(model_path="google/t5-v1_1-small", max_length=512)
122
+ clip_encoder = FluxEmbedder(version="openai/clip-vit-large-patch14").to(
123
+ torch_device, dtype=torch.bfloat16
124
+ )
125
+ t5_encoder = FluxEmbedder(version="google/t5-v1_1-small").to(
126
+ torch_device, dtype=torch.bfloat16
127
+ )
128
+
129
+ rng = torch.Generator(device="cpu")
130
+
131
+ if seed is None:
132
+ seed = rng.seed()
133
+ print(f"Generating with seed {seed}:\n{prompt}")
134
+ t0 = time.perf_counter()
135
+ output_name = os.path.join(output_dir, f"img_{seed}.jpg")
136
+
137
+ # Tokenize the prompt, on CPU
138
+ clip_tokens = clip_tokenizer.encode(prompt)
139
+ t5_tokens = t5_tokenizer.encode(prompt)
140
+
141
+ batch = preprocess_flux_data(
142
+ device=torch_device,
143
+ dtype=torch.bfloat16,
144
+ autoencoder=None,
145
+ clip_encoder=clip_encoder,
146
+ t5_encoder=t5_encoder,
147
+ batch={
148
+ "clip_tokens": clip_tokens,
149
+ "t5_tokens": t5_tokens,
150
+ },
151
+ )
152
+
153
+ img = self._generate_images(
154
+ device=torch_device,
155
+ dtype=torch.bfloat16,
156
+ model=model,
157
+ decoder=ae,
158
+ img_width=img_width,
159
+ img_height=img_height,
160
+ denoising_steps=num_steps,
161
+ seed=seed,
162
+ clip_encodings=batch["clip_encodings"],
163
+ t5_encodings=batch["t5_encodings"],
164
+ guidance=guidance,
165
+ )
166
+
167
+ if torch.cuda.is_available():
168
+ torch.cuda.synchronize()
169
+ t1 = time.perf_counter()
170
+
171
+ print(f"Done in {t1 - t0:.1f}s.")
172
+
173
+ self._save_image(name, output_name, img, add_sampling_metadata, prompt)
174
+
175
+ def _generate_images(
176
+ self,
177
+ device: torch.device,
178
+ dtype: torch.dtype,
179
+ model: FluxModel,
180
+ decoder: AutoEncoder,
181
+ # image params:
182
+ img_width: int,
183
+ img_height: int,
184
+ # sampling params:
185
+ denoising_steps: int,
186
+ seed: int,
187
+ clip_encodings: torch.Tensor,
188
+ t5_encodings: torch.Tensor,
189
+ guidance: float = 4.0,
190
+ ):
191
+
192
+ bsz = clip_encodings.shape[0]
193
+ latents = generate_noise_latent(bsz, img_height, img_width, device, dtype, seed)
194
+ _, latent_channels, latent_height, latent_width = latents.shape
195
+
196
+ # create denoising schedule
197
+ timesteps = get_schedule(denoising_steps, latent_channels, shift=True)
198
+
199
+ # create positional encodings
200
+ POSITION_DIM = 3 # constant for Flux flow model
201
+ latent_pos_enc = create_position_encoding_for_latents(
202
+ bsz, latent_height, latent_width, POSITION_DIM
203
+ ).to(latents)
204
+ text_pos_enc = torch.zeros(bsz, t5_encodings.shape[1], POSITION_DIM).to(latents)
205
+
206
+ # convert img-like latents into sequences of patches
207
+ latents = pack_latents(latents)
208
+
209
+ # this is ignored for schnell
210
+ guidance_vec = torch.full((bsz,), guidance, device=device, dtype=dtype)
211
+ for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
212
+ t_vec = torch.full((bsz,), t_curr, dtype=dtype, device=device)
213
+ pred = model(
214
+ img=latents,
215
+ img_ids=latent_pos_enc,
216
+ txt=t5_encodings,
217
+ txt_ids=text_pos_enc,
218
+ y=clip_encodings,
219
+ timesteps=t_vec,
220
+ guidance=guidance_vec,
221
+ )
222
+
223
+ latents = latents + (t_prev - t_curr) * pred
224
+
225
+ # convert sequences of patches into img-like latents
226
+ latents = unpack_latents(latents, latent_height, latent_width)
227
+
228
+ img = decoder.decode(latents)
229
+ return img
230
+
231
+ def _save_image(
232
+ self,
233
+ name: str,
234
+ output_name: str,
235
+ x: torch.Tensor,
236
+ add_sampling_metadata: bool,
237
+ prompt: str,
238
+ ):
239
+ print(f"Saving {output_name}")
240
+ # bring into PIL format and save
241
+ x = x.clamp(-1, 1)
242
+ x = rearrange(x[0], "c h w -> h w c")
243
+
244
+ img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
245
+
246
+ exif_data = Image.Exif()
247
+ exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
248
+ exif_data[ExifTags.Base.Make] = "Black Forest Labs"
249
+ exif_data[ExifTags.Base.Model] = name
250
+ if add_sampling_metadata:
251
+ exif_data[ExifTags.Base.ImageDescription] = prompt
252
+ img.save(output_name, exif=exif_data, quality=95, subsampling=0)
torchtitan/experiments/flux/train_configs/debug_model.toml ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ [job]
3
+ dump_folder = "./outputs"
4
+ description = "Flux debug model"
5
+ print_args = false
6
+ use_for_integration_test = true
7
+
8
+ [profiling]
9
+ enable_profiling = false
10
+ save_traces_folder = "profile_trace"
11
+ profile_freq = 10
12
+ enable_memory_snapshot = false
13
+ save_memory_snapshot_folder = "memory_snapshot"
14
+
15
+ [metrics]
16
+ log_freq = 1
17
+ disable_color_printing = false
18
+ enable_tensorboard = false
19
+ save_tb_folder = "tb"
20
+ enable_wandb = false
21
+
22
+ [model]
23
+ name = "flux"
24
+ flavor = "flux-debug"
25
+ norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm
26
+ # test tokenizer.model, for debug purpose only
27
+ # tokenizer_path = "./tests/assets/test_tiktoken.model"
28
+ # converters = "float8"
29
+
30
+
31
+ [optimizer]
32
+ name = "AdamW"
33
+ lr = 8e-4
34
+ eps = 1e-8
35
+
36
+ [lr_scheduler]
37
+ warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps
38
+ decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps
39
+ decay_type = "linear"
40
+ lr_min = 0.0
41
+
42
+ [training]
43
+ batch_size = 32
44
+ seq_len = 512
45
+ max_norm = 1.0 # grad norm clipping
46
+ steps = 10
47
+ compile = false
48
+ dataset = "cc12m"
49
+ guidance = 3.5
50
+ seed = 0
51
+
52
+ [encoder]
53
+ t5_encoder="google/t5-v1_1-small"
54
+ clip_encoder="openai/clip-vit-large-patch14"
55
+ max_t5_encoding_len=512
56
+ auto_encoder_path="torchtitan/experiments/flux/assets/autoencoder/ae.safetensors" # Autoencoder to use for image
57
+
58
+ [parallelism]
59
+ data_parallel_replicate_degree = 1
60
+ data_parallel_shard_degree = 1
61
+ fsdp_reshard_after_forward = "default" # default / never / always
62
+ tensor_parallel_degree = 1
63
+ enable_async_tensor_parallel = false
64
+ pipeline_parallel_degree = 1
65
+ context_parallel_degree = 1
66
+
67
+ [experimental]
68
+ custom_args_module = "torchtitan.experiments.flux.flux_argparser"
torchtitan/experiments/kernels/triton_mg_group_gemm/benchmark.py ADDED
@@ -0,0 +1,630 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
8
+ # All rights reserved.
9
+ #
10
+ # Benchmark comparing reference PyTorch vs optimized M*G group GEMM implementation
11
+
12
+ import argparse
13
+ import logging
14
+ import time
15
+
16
+ # from typing import Dict, List, Optional, Tuple
17
+
18
+ import matplotlib.pyplot as plt
19
+ import numpy as np
20
+ import torch
21
+ import triton
22
+
23
+ # import triton.language as tl
24
+
25
+ # Configure logging
26
+ logging.basicConfig(
27
+ level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
28
+ )
29
+
30
+ # Try to import the optimized implementations
31
+ try:
32
+ from torchao_pr.mg_grouped_gemm import grouped_gemm_forward
33
+
34
+ except ImportError:
35
+ logging.error(
36
+ "Error importing MG grouped GEMM modules. Make sure the implementation files are in the correct path."
37
+ )
38
+ raise
39
+
40
+
41
+ def compute_reference_forward(x, w, m_sizes):
42
+ """
43
+ Reference PyTorch implementation of M*G grouped GEMM forward pass.
44
+
45
+ Args:
46
+ x (torch.Tensor): Input tensor of shape (M, K)
47
+ w (torch.Tensor): Weight tensor of shape (N, K)
48
+ m_sizes (torch.Tensor): Group sizes tensor of shape (G)
49
+
50
+ Returns:
51
+ torch.Tensor: Output tensor of shape (M, N)
52
+ """
53
+ result = torch.zeros((x.shape[0], w.shape[0]), dtype=x.dtype, device=x.device)
54
+
55
+ m_start = 0
56
+ for g in range(len(m_sizes)):
57
+ m_size = m_sizes[g].item()
58
+ if m_size > 0:
59
+ m_end = m_start + m_size
60
+
61
+ # Extract group input
62
+ x_g = x[m_start:m_end]
63
+
64
+ # Compute group output
65
+ y_g = torch.matmul(x_g, w.T)
66
+
67
+ # Store result
68
+ result[m_start:m_end] = y_g
69
+
70
+ # Update start index
71
+ m_start = m_end
72
+
73
+ return result
74
+
75
+
76
+ @triton.testing.perf_report(
77
+ triton.testing.Benchmark(
78
+ x_names=["N"], # We'll vary the output dimension
79
+ x_vals=[1024, 2048, 4096, 8192, 16384], # Different output dimensions to test
80
+ # x_vals=[8192, 16384],
81
+ line_arg="provider", # We'll compare different providers
82
+ line_vals=["pytorch_reference", "M*G grouped GEMM"],
83
+ line_names=["PyTorch Reference", "M*G grouped Kernel"],
84
+ styles=[("blue", "-"), ("red", "-")],
85
+ ylabel="TFLOPS", # We'll measure TFLOPS
86
+ plot_name="mg_grouped_gemm_comparison",
87
+ args={
88
+ "M": 8192, # Batch dimension, fixed for all tests
89
+ "K": 7168, # Hidden dimension, fixed for all tests
90
+ "G": 8, # Number of groups
91
+ "dtype": torch.float16,
92
+ "device": "cuda",
93
+ },
94
+ )
95
+ )
96
+ def benchmark_forward(M, K, N, G, provider, dtype=torch.float16, device="cuda"):
97
+ """
98
+ Benchmark the forward pass of the grouped GEMM implementation.
99
+
100
+ Args:
101
+ M (int): Total batch size dimension
102
+ K (int): Hidden dimension
103
+ N (int): Output dimension
104
+ G (int): Number of groups
105
+ provider (str): Provider to use ('pytorch_reference' or 'optimized_kernel')
106
+ dtype (torch.dtype): Data type to use
107
+ device (str): Device to use
108
+
109
+ Returns:
110
+ float: Performance in TFLOPS
111
+ """
112
+ # Create group sizes for M dimension (balanced across groups)
113
+ base_size = M // G
114
+ remainder = M % G
115
+ M_sizes = [base_size + (1 if i < remainder else 0) for i in range(G)]
116
+ m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)
117
+
118
+ print(f"N: {N}, M: {M}, K: {K}, G: {G}, dtype: {dtype}, device: {device}")
119
+
120
+ # Create input and weight tensors
121
+ x = torch.randn(M, K, dtype=dtype, device=device)
122
+ w = torch.randn(N, K, dtype=dtype, device=device)
123
+
124
+ # Pre-compute for PyTorch reference to ensure fair comparison
125
+ if provider == "pytorch_reference":
126
+ # Warmup
127
+ torch.cuda.synchronize()
128
+ compute_reference_forward(x, w, m_sizes)
129
+ torch.cuda.synchronize()
130
+
131
+ # Benchmark
132
+ start_time = time.time()
133
+ for _ in range(10): # Average over 10 runs
134
+ compute_reference_forward(x, w, m_sizes)
135
+ torch.cuda.synchronize()
136
+ end_time = time.time()
137
+ else: # Optimized kernel
138
+ # Warmup
139
+ torch.cuda.synchronize()
140
+ grouped_gemm_forward(x, w, m_sizes)
141
+ torch.cuda.synchronize()
142
+
143
+ # Benchmark
144
+ start_time = time.time()
145
+ for _ in range(10): # Average over 10 runs
146
+ grouped_gemm_forward(x, w, m_sizes)
147
+ torch.cuda.synchronize()
148
+ end_time = time.time()
149
+
150
+ # Calculate FLOPs
151
+ # For GEMM: 2 * M * N * K FLOPs (multiply-add counts as 2 FLOPs)
152
+ flops = 2 * M * N * K
153
+
154
+ # Convert to TFLOPS (tera-FLOPS)
155
+ avg_time = (end_time - start_time) / 10 # Average time per run
156
+ tflops = flops / avg_time / 1e12
157
+
158
+ return tflops
159
+
160
+
161
+ @triton.testing.perf_report(
162
+ triton.testing.Benchmark(
163
+ x_names=["G"], # We'll vary the number of groups
164
+ x_vals=[1, 2, 4, 8, 16], # Different numbers of groups to test
165
+ line_arg="provider", # We'll compare different providers
166
+ line_vals=["pytorch_reference", "optimized_kernel"],
167
+ line_names=["PyTorch Reference", "Optimized Kernel"],
168
+ styles=[("blue", "-"), ("red", "-")],
169
+ ylabel="TFLOPS", # We'll measure TFLOPS
170
+ plot_name="mg_grouped_gemm_group_scaling",
171
+ args={
172
+ "M": 8192, # Batch dimension, fixed for all tests
173
+ "K": 4096, # Hidden dimension, fixed for all tests
174
+ "N": 8192, # Output dimension, fixed for all tests
175
+ "dtype": torch.float16,
176
+ "device": "cuda",
177
+ },
178
+ )
179
+ )
180
+ def benchmark_forward_groups(M, K, N, G, provider, dtype=torch.float16, device="cuda"):
181
+ """
182
+ Benchmark how performance scales with number of groups.
183
+
184
+ Args:
185
+ M (int): Total batch size dimension
186
+ K (int): Hidden dimension
187
+ N (int): Output dimension
188
+ G (int): Number of groups
189
+ provider (str): Provider to use ('pytorch_reference' or 'optimized_kernel')
190
+ dtype (torch.dtype): Data type to use
191
+ device (str): Device to use
192
+
193
+ Returns:
194
+ float: Performance in TFLOPS
195
+ """
196
+ # Create group sizes for M dimension (balanced across groups)
197
+ base_size = M // G
198
+ remainder = M % G
199
+ M_sizes = [base_size + (1 if i < remainder else 0) for i in range(G)]
200
+ m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)
201
+
202
+ # Create input and weight tensors
203
+ x = torch.randn(M, K, dtype=dtype, device=device)
204
+ w = torch.randn(N, K, dtype=dtype, device=device)
205
+
206
+ # Benchmark logic - same as previous function
207
+ if provider == "pytorch_reference":
208
+ torch.cuda.synchronize()
209
+ compute_reference_forward(x, w, m_sizes)
210
+ torch.cuda.synchronize()
211
+
212
+ start_time = time.time()
213
+ for _ in range(10):
214
+ compute_reference_forward(x, w, m_sizes)
215
+ torch.cuda.synchronize()
216
+ end_time = time.time()
217
+ else:
218
+ torch.cuda.synchronize()
219
+ grouped_gemm_forward(x, w, m_sizes)
220
+ torch.cuda.synchronize()
221
+
222
+ start_time = time.time()
223
+ for _ in range(10):
224
+ grouped_gemm_forward(x, w, m_sizes)
225
+ torch.cuda.synchronize()
226
+ end_time = time.time()
227
+
228
+ # Calculate FLOPs and TFLOPS
229
+ flops = 2 * M * N * K
230
+ avg_time = (end_time - start_time) / 10
231
+ tflops = flops / avg_time / 1e12
232
+
233
+ return tflops
234
+
235
+
236
+ @triton.testing.perf_report(
237
+ triton.testing.Benchmark(
238
+ x_names=["group_balance"], # We'll vary the group balance factor
239
+ x_vals=[
240
+ 0.0,
241
+ 0.25,
242
+ 0.5,
243
+ 0.75,
244
+ 0.9,
245
+ ], # Different imbalance factors (0 = balanced, 1 = max imbalance)
246
+ line_arg="provider", # We'll compare different providers
247
+ line_vals=["pytorch_reference", "optimized_kernel"],
248
+ line_names=["PyTorch Reference", "Optimized Kernel"],
249
+ styles=[("blue", "-"), ("red", "-")],
250
+ ylabel="TFLOPS", # We'll measure TFLOPS
251
+ plot_name="mg_grouped_gemm_imbalance",
252
+ args={
253
+ "M": 8192, # Batch dimension, fixed for all tests
254
+ "K": 4096, # Hidden dimension, fixed for all tests
255
+ "N": 8192, # Output dimension, fixed for all tests
256
+ "G": 4, # Number of groups
257
+ "dtype": torch.float16,
258
+ "device": "cuda",
259
+ },
260
+ )
261
+ )
262
+ def benchmark_imbalance(
263
+ M, K, N, G, group_balance, provider, dtype=torch.float16, device="cuda"
264
+ ):
265
+ """
266
+ Benchmark how performance is affected by imbalanced group sizes.
267
+
268
+ Args:
269
+ M (int): Total batch size dimension
270
+ K (int): Hidden dimension
271
+ N (int): Output dimension
272
+ G (int): Number of groups
273
+ group_balance (float): Balance factor from 0 to 1 (0 = balanced, 1 = max imbalance)
274
+ provider (str): Provider to use ('pytorch_reference' or 'optimized_kernel')
275
+ dtype (torch.dtype): Data type to use
276
+ device (str): Device to use
277
+
278
+ Returns:
279
+ float: Performance in TFLOPS
280
+ """
281
+ # Create imbalanced group sizes for M dimension
282
+ if group_balance == 0:
283
+ # Balanced case
284
+ base_size = M // G
285
+ remainder = M % G
286
+ M_sizes = [base_size + (1 if i < remainder else 0) for i in range(G)]
287
+ else:
288
+ # Imbalanced case
289
+ # First group gets more elements, last group gets fewer
290
+ # The imbalance is controlled by the group_balance factor
291
+ remaining = M
292
+ M_sizes = []
293
+ for g in range(G):
294
+ # Interpolate from balanced to imbalanced based on group_balance
295
+ # For balanced (group_balance=0), each group gets M/G
296
+ # For imbalanced (group_balance=1), first group gets much more than last group
297
+ balanced_size = remaining // (G - g)
298
+
299
+ # Adjusting size based on position and imbalance factor
300
+ # First groups get more, last groups get less
301
+ if g < G // 2:
302
+ # First half of groups get more
303
+ adjustment = int(balanced_size * group_balance * (1 - g / (G - 1)))
304
+ size = balanced_size + adjustment
305
+ else:
306
+ # Second half of groups get less
307
+ adjustment = int(balanced_size * group_balance * ((g / (G - 1)) - 0.5))
308
+ size = balanced_size - adjustment
309
+
310
+ # Ensure we don't go below 1 or take more than remaining
311
+ size = max(1, min(size, remaining))
312
+ M_sizes.append(size)
313
+ remaining -= size
314
+
315
+ # Handle any remaining elements
316
+ if remaining > 0:
317
+ M_sizes[-1] += remaining
318
+
319
+ m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)
320
+
321
+ # Create input and weight tensors
322
+ x = torch.randn(M, K, dtype=dtype, device=device)
323
+ w = torch.randn(N, K, dtype=dtype, device=device)
324
+
325
+ # Benchmark logic
326
+ if provider == "pytorch_reference":
327
+ torch.cuda.synchronize()
328
+ compute_reference_forward(x, w, m_sizes)
329
+ torch.cuda.synchronize()
330
+
331
+ start_time = time.time()
332
+ for _ in range(10):
333
+ compute_reference_forward(x, w, m_sizes)
334
+ torch.cuda.synchronize()
335
+ end_time = time.time()
336
+ else:
337
+ torch.cuda.synchronize()
338
+ grouped_gemm_forward(x, w, m_sizes)
339
+ torch.cuda.synchronize()
340
+
341
+ start_time = time.time()
342
+ for _ in range(10):
343
+ grouped_gemm_forward(x, w, m_sizes)
344
+ torch.cuda.synchronize()
345
+ end_time = time.time()
346
+
347
+ # Calculate FLOPs and TFLOPS
348
+ flops = 2 * M * N * K
349
+ avg_time = (end_time - start_time) / 10
350
+ tflops = flops / avg_time / 1e12
351
+
352
+ return tflops
353
+
354
+
355
+ def benchmark_model_configs():
356
+ """
357
+ Benchmark common model configurations used in DeepSeek-like models.
358
+ """
359
+ # Model configurations: (M, K, N, G)
360
+ configs = [
361
+ (8192, 7168, 4096, 4), # Config 1
362
+ (8192, 2048, 7168, 4), # Config 2
363
+ (4096, 7168, 4096, 8), # Config 3
364
+ (4096, 2048, 7168, 8), # Config 4
365
+ ]
366
+
367
+ results = []
368
+
369
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
370
+ dtype = torch.float16
371
+
372
+ for config_idx, (M, K, N, G) in enumerate(configs):
373
+ logging.info(f"\n===== Benchmarking DeepSeek Config {config_idx + 1} =====")
374
+ logging.info(f"M={M}, K={K}, N={N}, G={G}")
375
+
376
+ # Create group sizes for M dimension
377
+ base_size = M // G
378
+ remainder = M % G
379
+ M_sizes = [base_size + (1 if i < remainder else 0) for i in range(G)]
380
+ m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)
381
+
382
+ # Create tensors
383
+ x = torch.randn(M, K, dtype=dtype, device=device)
384
+ w = torch.randn(N, K, dtype=dtype, device=device)
385
+
386
+ # Benchmark PyTorch reference
387
+ torch.cuda.synchronize()
388
+ compute_reference_forward(x, w, m_sizes) # Warmup
389
+ torch.cuda.synchronize()
390
+
391
+ logging.info("Benchmarking PyTorch reference...")
392
+ torch.cuda.reset_peak_memory_stats()
393
+ start_time = time.time()
394
+ for _ in range(10):
395
+ compute_reference_forward(x, w, m_sizes)
396
+ torch.cuda.synchronize()
397
+ end_time = time.time()
398
+ pt_time = (end_time - start_time) / 10
399
+ pt_memory = torch.cuda.max_memory_allocated() / (1024**2) # MB
400
+
401
+ # Benchmark optimized kernel
402
+ torch.cuda.synchronize()
403
+ grouped_gemm_forward(x, w, m_sizes) # Warmup
404
+ torch.cuda.synchronize()
405
+
406
+ logging.info("Benchmarking optimized kernel...")
407
+ torch.cuda.reset_peak_memory_stats()
408
+ start_time = time.time()
409
+ for _ in range(10):
410
+ grouped_gemm_forward(x, w, m_sizes)
411
+ torch.cuda.synchronize()
412
+ end_time = time.time()
413
+ opt_time = (end_time - start_time) / 10
414
+ opt_memory = torch.cuda.max_memory_allocated() / (1024**2) # MB
415
+
416
+ # Calculate FLOPs and speedup
417
+ flops = 2 * M * N * K
418
+ pt_tflops = flops / pt_time / 1e12
419
+ opt_tflops = flops / opt_time / 1e12
420
+ speedup = pt_time / opt_time
421
+
422
+ # Store results
423
+ results.append(
424
+ {
425
+ "config": f"Config {config_idx + 1}",
426
+ "dimensions": f"M={M}, K={K}, N={N}, G={G}",
427
+ "pt_time_ms": pt_time * 1000,
428
+ "opt_time_ms": opt_time * 1000,
429
+ "pt_tflops": pt_tflops,
430
+ "opt_tflops": opt_tflops,
431
+ "speedup": speedup,
432
+ "pt_memory_mb": pt_memory,
433
+ "opt_memory_mb": opt_memory,
434
+ "memory_savings": (
435
+ (pt_memory - opt_memory) / pt_memory * 100 if pt_memory > 0 else 0
436
+ ),
437
+ }
438
+ )
439
+
440
+ logging.info(
441
+ f"PyTorch Reference: {pt_time * 1000:.2f} ms, {pt_tflops:.2f} TFLOPS, {pt_memory:.2f} MB"
442
+ )
443
+ logging.info(
444
+ f"Optimized Kernel: {opt_time * 1000:.2f} ms, {opt_tflops:.2f} TFLOPS, {opt_memory:.2f} MB"
445
+ )
446
+ logging.info(
447
+ f"Speedup: {speedup:.2f}x, Memory savings: {results[-1]['memory_savings']:.2f}%"
448
+ )
449
+
450
+ # Print summary table
451
+ logging.info("\n===== Benchmark Results Summary =====")
452
+ logging.info(
453
+ f"{'Config':<10} | {'Time (ms)':<20} | {'TFLOPS':<20} | {'Speedup':<10} | {'Memory (MB)':<20} | {'Memory Saved':<12}"
454
+ )
455
+ logging.info(
456
+ f"{'':<10} | {'PyTorch':<9} {'Kernel':<9} | {'PyTorch':<9} {'Kernel':<9} | {'':<10} | "
457
+ f"{'PyTorch':<9} {'Kernel':<9} | {'':<12}"
458
+ )
459
+ logging.info("-" * 100)
460
+
461
+ for result in results:
462
+ logging.info(
463
+ f"{result['config']:<10} | "
464
+ f"{result['pt_time_ms']:<9.2f} {result['opt_time_ms']:<9.2f} | "
465
+ f"{result['pt_tflops']:<9.2f} {result['opt_tflops']:<9.2f} | "
466
+ f"{result['speedup']:<10.2f} | "
467
+ f"{result['pt_memory_mb']:<9.2f} {result['opt_memory_mb']:<9.2f} | "
468
+ f"{result['memory_savings']:<12.2f}%"
469
+ )
470
+
471
+ return results
472
+
473
+
474
+ def plot_benchmark_results(results):
475
+ """
476
+ Plot benchmark results as bar charts.
477
+ """
478
+ # Extract data
479
+ configs = [r["config"] for r in results]
480
+ pt_tflops = [r["pt_tflops"] for r in results]
481
+ opt_tflops = [r["opt_tflops"] for r in results]
482
+ speedups = [r["speedup"] for r in results]
483
+
484
+ # Create figure with subplots
485
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
486
+
487
+ # Plot TFLOPS comparison
488
+ x = np.arange(len(configs))
489
+ width = 0.35
490
+ ax1.bar(x - width / 2, pt_tflops, width, label="PyTorch Reference")
491
+ ax1.bar(x + width / 2, opt_tflops, width, label="Optimized Kernel")
492
+ ax1.set_xlabel("Model Configuration")
493
+ ax1.set_ylabel("TFLOPS")
494
+ ax1.set_title("Performance Comparison (Higher is Better)")
495
+ ax1.set_xticks(x)
496
+ ax1.set_xticklabels(configs)
497
+ ax1.legend()
498
+ ax1.grid(axis="y", linestyle="--", alpha=0.7)
499
+
500
+ # Plot speedup
501
+ ax2.bar(x, speedups, width=0.6, color="green")
502
+ ax2.set_xlabel("Model Configuration")
503
+ ax2.set_ylabel("Speedup (x)")
504
+ ax2.set_title("Speedup Factor (Higher is Better)")
505
+ ax2.set_xticks(x)
506
+ ax2.set_xticklabels(configs)
507
+ ax2.grid(axis="y", linestyle="--", alpha=0.7)
508
+
509
+ # Add speedup values on top of bars
510
+ for i, v in enumerate(speedups):
511
+ ax2.text(i, v + 0.1, f"{v:.2f}x", ha="center")
512
+
513
+ plt.tight_layout()
514
+ plt.savefig("mg_grouped_gemm_benchmark_results.png")
515
+ logging.info(
516
+ "Benchmark results plot saved to 'mg_grouped_gemm_benchmark_results.png'"
517
+ )
518
+
519
+
520
def compare_mg_implementations():
    """
    Combine the M*G and N*G benchmark results for comparison.

    Loads the CSV result files produced by the two benchmark runs, plots
    (1) mean speedup per configuration and (2) mean TFLOPS of the optimized
    kernels side by side, and saves the figure to 'mg_vs_ng_comparison.png'.
    If either CSV is missing or malformed the error is logged and the
    function returns without raising.
    """
    # Only run this if both NG and MG benchmarks have been run
    try:
        import pandas as pd

        # Try to load previous benchmark results
        mg_results = pd.read_csv("mg_grouped_gemm_benchmark_results.csv")
        ng_results = pd.read_csv("ng_grouped_gemm_benchmark_results.csv")

        # Create comparison plot
        fig, axes = plt.subplots(1, 2, figsize=(14, 6))

        # FIX: groupby(...).mean() returns a Series sorted by group key, which
        # need not match the order of unique() (file order) — and the N*G frame
        # may not contain exactly the same configs. Reindex every aggregate by
        # `configs` so each bar lines up with its x-tick label; configs absent
        # from a frame become NaN (a gap) instead of silently shifting bars.
        configs = mg_results["config"].unique()
        mg_speedups = mg_results.groupby("config")["speedup"].mean().reindex(configs)
        ng_speedups = ng_results.groupby("config")["speedup"].mean().reindex(configs)

        x = np.arange(len(configs))
        width = 0.35

        # Plot speedup comparison
        axes[0].bar(x - width / 2, mg_speedups, width, label="M*G Grouping")
        axes[0].bar(x + width / 2, ng_speedups, width, label="N*G Grouping")
        axes[0].set_xlabel("Model Configuration")
        axes[0].set_ylabel("Speedup (x)")
        axes[0].set_title("Speedup Comparison: M*G vs N*G")
        axes[0].set_xticks(x)
        axes[0].set_xticklabels(configs)
        axes[0].legend()
        axes[0].grid(axis="y", linestyle="--", alpha=0.7)

        # Plot TFLOPS comparison for optimized kernels (reindexed for the
        # same label-alignment reason as above).
        mg_tflops = (
            mg_results[mg_results["implementation"] == "optimized"]
            .groupby("config")["tflops"]
            .mean()
            .reindex(configs)
        )
        ng_tflops = (
            ng_results[ng_results["implementation"] == "optimized"]
            .groupby("config")["tflops"]
            .mean()
            .reindex(configs)
        )

        axes[1].bar(x - width / 2, mg_tflops, width, label="M*G Grouping")
        axes[1].bar(x + width / 2, ng_tflops, width, label="N*G Grouping")
        axes[1].set_xlabel("Model Configuration")
        axes[1].set_ylabel("TFLOPS")
        axes[1].set_title("Performance Comparison: M*G vs N*G")
        axes[1].set_xticks(x)
        axes[1].set_xticklabels(configs)
        axes[1].legend()
        axes[1].grid(axis="y", linestyle="--", alpha=0.7)

        plt.tight_layout()
        plt.savefig("mg_vs_ng_comparison.png")
        logging.info("Comparison plot saved to 'mg_vs_ng_comparison.png'")

    except Exception as e:
        logging.error(f"Could not create comparison plot: {e}")
        logging.info(
            "Run both M*G and N*G benchmarks first to generate comparison plots"
        )
584
+
585
+
586
+ if __name__ == "__main__":
587
+ parser = argparse.ArgumentParser(
588
+ description="Benchmark M*G Grouped GEMM implementations"
589
+ )
590
+ parser.add_argument("--run-all", action="store_true", help="Run all benchmarks")
591
+ parser.add_argument(
592
+ "--triton-bench", action="store_true", help="Run Triton performance reports"
593
+ )
594
+ parser.add_argument(
595
+ "--model-configs", action="store_true", help="Benchmark model configurations"
596
+ )
597
+ parser.add_argument(
598
+ "--compare-mg-ng",
599
+ action="store_true",
600
+ help="Compare M*G and N*G implementations",
601
+ )
602
+ args = parser.parse_args()
603
+
604
+ # Check if CUDA is available
605
+ if not torch.cuda.is_available():
606
+ logging.error(
607
+ "CUDA is not available. This benchmark requires a CUDA-capable GPU."
608
+ )
609
+ exit(1)
610
+
611
+ if args.run_all or args.model_configs:
612
+ # Benchmark model configurations
613
+ logging.info("Running benchmark for model configurations...")
614
+ results = benchmark_model_configs()
615
+ plot_benchmark_results(results)
616
+
617
+ if args.run_all or args.triton_bench:
618
+ # Run Triton performance reports
619
+ logging.info("Running Triton performance reports...")
620
+ benchmark_forward.run(save_path="mg_grouped_gemm_benchmark_results")
621
+ benchmark_forward_groups.run(save_path="mg_grouped_gemm_benchmark_results")
622
+ benchmark_imbalance.run(save_path="mg_grouped_gemm_benchmark_results")
623
+ logging.info(
624
+ "Triton performance reports saved to 'mg_grouped_gemm_benchmark_results' directory"
625
+ )
626
+
627
+ if args.run_all or args.compare_mg_ng:
628
+ # Compare M*G and N*G implementations
629
+ logging.info("Comparing M*G and N*G implementations...")
630
+ compare_mg_implementations()
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/unit_test_forwards.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+ import logging
9
+ import unittest
10
+ from typing import Tuple
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+
15
+ from mg_grouped_gemm import grouped_gemm_forward
16
+
17
+
18
class TestMG_GroupedGEMM(unittest.TestCase):
    """Forward-pass tests for the M*G grouped GEMM kernel against a
    per-group ``torch.matmul`` reference."""

    def setUp(self) -> None:
        # Fixed seed so the random operands are reproducible across runs.
        torch.manual_seed(2020)

    def _run_grouped_gemm_test(
        self,
        shape: Tuple[int, int, int, int],
        device: torch.device,
        dtype: torch.dtype = torch.bfloat16,
        atol: float = 1e-5,
        rtol: float = 1.6e-2,
    ) -> None:
        """Run one forward pass for ``shape = (G, M, N, K)`` and compare the
        kernel output against a group-by-group matmul reference."""
        G, M, N, K = shape
        # M*G grouping: activations are [M*G, K], weights are [N*G, K].
        lhs = torch.randn(M * G, K, dtype=dtype, device=device)
        rhs = torch.randn(N * G, K, dtype=dtype, device=device)

        # Equal-sized groups keep the reference computation simple.
        group_rows = torch.full((G,), M, device=device, dtype=torch.int32)

        out = grouped_gemm_forward(lhs, rhs, group_rows)
        self.assertTrue(out.shape == (M * G, N))

        # Reference: for each group g, multiply its row slab of `lhs` by the
        # transposed weight slab rhs[N*g : N*(g+1)].
        reference = torch.zeros(M * G, N, dtype=dtype, device=device)
        row = 0
        for g in range(G):
            nxt = row + group_rows[g]
            weight_slab = rhs[N * g : N * (g + 1), :]
            reference[row:nxt, :] = lhs[row:nxt, :] @ weight_slab.T
            row = nxt

        # Cast the kernel output to the input dtype before comparing.
        torch.testing.assert_close(out.to(dtype), reference, atol=atol, rtol=rtol)

    def test_MG_grouped_gemm_bf16(self) -> None:
        group_counts = (1, 4, 16)
        row_counts = (128, 512, 1024)
        for G in group_counts:
            for M in row_counts:
                print(f"Testing BF16 M*G GroupGeMM with G={G}, M={M}")
                self._run_grouped_gemm_test(
                    (G, M, 1024, 1024),
                    torch.device("cuda"),
                    dtype=torch.bfloat16,
                    atol=1e-5,
                    rtol=1.6e-2,
                )

    def test_MG_grouped_gemm_deepseek_shapes(self) -> None:
        """Test with shapes from Deepseek model."""
        deepseek_shapes = [
            (4, 2048, 4096, 7168),  # G, M, N, K
            (4, 2048, 7168, 2048),
            (8, 512, 4096, 7168),
            (8, 512, 7168, 2048),
        ]

        device = torch.device("cuda")

        for G, M, N, K in deepseek_shapes:
            print(f"Testing BF16 M*G Deepseek shape: G={G}, M={M}, N={N}, K={K}")
            self._run_grouped_gemm_test(
                (G, M, N, K), device, dtype=torch.bfloat16, atol=1e-5, rtol=1.6e-2
            )