zaydzuhri commited on
Commit
2b2620a
·
verified ·
1 Parent(s): ae2fe24

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. flame/__pycache__/__init__.cpython-312.pyc +0 -0
  2. flame/__pycache__/data.cpython-312.pyc +0 -0
  3. flame/__pycache__/train.cpython-312.pyc +0 -0
  4. flame/components/__init__.py +0 -0
  5. flame/components/__pycache__/__init__.cpython-312.pyc +0 -0
  6. flame/components/__pycache__/checkpoint.cpython-312.pyc +0 -0
  7. flame/components/checkpoint.py +59 -0
  8. flame/models/__init__.py +0 -0
  9. flame/models/__pycache__/__init__.cpython-312.pyc +0 -0
  10. flame/models/__pycache__/parallelize_fla.cpython-312.pyc +0 -0
  11. flame/models/__pycache__/pipeline_fla.cpython-312.pyc +0 -0
  12. flame/models/fla.toml +67 -0
  13. flame/models/parallelize_fla.py +550 -0
  14. flame/models/pipeline_fla.py +162 -0
  15. flame/tools/__init__.py +0 -0
  16. flame/tools/__pycache__/__init__.cpython-312.pyc +0 -0
  17. flame/tools/__pycache__/utils.cpython-312.pyc +0 -0
  18. flame/tools/utils.py +41 -0
  19. flame/utils/__init__.py +0 -0
  20. flame/utils/__pycache__/__init__.cpython-312.pyc +0 -0
  21. flame/utils/__pycache__/checkpoint.cpython-312.pyc +0 -0
  22. flame/utils/__pycache__/convert_dcp_to_hf.cpython-312.pyc +0 -0
  23. flame/utils/__pycache__/convert_hf_to_dcp.cpython-312.pyc +0 -0
  24. flame/utils/__pycache__/hf_utils.cpython-312.pyc +0 -0
  25. flame/utils/checkpoint.py +50 -0
  26. flame/utils/convert_dcp_to_hf.py +66 -0
  27. flame/utils/convert_hf_to_dcp.py +34 -0
  28. flame/utils/hf_utils.py +77 -0
  29. logs/none_enyj3lod/attempt_0/5/stderr.log +0 -0
  30. logs/none_enyj3lod/attempt_0/7/stderr.log +0 -0
  31. profile_trace/iteration_15872/rank6_trace.json +0 -0
  32. profile_trace/iteration_16384/rank2_trace.json +0 -0
  33. profile_trace/iteration_16384/rank3_trace.json +0 -0
  34. profile_trace/iteration_16384/rank5_trace.json +0 -0
  35. profile_trace/iteration_19968/rank3_trace.json +0 -0
  36. profile_trace/iteration_19968/rank7_trace.json +0 -0
  37. profile_trace/iteration_3072/rank2_trace.json +0 -0
  38. profile_trace/iteration_32256/rank1_trace.json +0 -0
  39. profile_trace/iteration_37376/rank0_trace.json +0 -0
  40. profile_trace/iteration_37376/rank6_trace.json +0 -0
  41. profile_trace/iteration_9216/rank2_trace.json +0 -0
  42. profile_trace/iteration_9216/rank3_trace.json +0 -0
  43. profile_trace/iteration_9216/rank4_trace.json +0 -0
  44. profile_trace/iteration_9216/rank6_trace.json +0 -0
  45. profile_trace/iteration_9216/rank7_trace.json +0 -0
  46. profile_trace/iteration_9728/rank1_trace.json +0 -0
  47. tb/20250911-1415/wandb/run-20250911_141551-mtp_transformer-mtp.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509111411/files/output.log +0 -0
  48. tb/20250911-1415/wandb/run-20250911_141551-mtp_transformer-mtp.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509111411/files/wandb-metadata.json +146 -0
  49. tb/20250911-1415/wandb/run-20250911_141551-mtp_transformer-mtp.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509111411/logs/debug-core.log +16 -0
  50. torchtitan/__pycache__/config_manager.cpython-312.pyc +0 -0
flame/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (156 Bytes). View file
 
flame/__pycache__/data.cpython-312.pyc ADDED
Binary file (31.3 kB). View file
 
flame/__pycache__/train.cpython-312.pyc ADDED
Binary file (38.1 kB). View file
 
flame/components/__init__.py ADDED
File without changes
flame/components/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (141 Bytes). View file
 
flame/components/__pycache__/checkpoint.cpython-312.pyc ADDED
Binary file (3.21 kB). View file
 
flame/components/checkpoint.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from dataclasses import dataclass, field
8
+ from datetime import timedelta
9
+ from io import BytesIO
10
+ from typing import Any, Dict, List
11
+
12
+ import torch
13
+ from torch.distributed.checkpoint.stateful import Stateful
14
+
15
+
16
+ @dataclass
17
+ class TrainState(Stateful):
18
+ step: int = 0
19
+ skipped_step: int = 0
20
+ token: int = 0
21
+ elapsed: timedelta = timedelta(0)
22
+ global_avg_losses: List[float] = field(default_factory=list)
23
+ global_max_losses: List[float] = field(default_factory=list)
24
+ log_steps: List[int] = field(default_factory=list)
25
+
26
+ def state_dict(self) -> Dict[str, Any]:
27
+ # Only checkpoint global_avg_losses and global_max_losses per log frequency
28
+ # to avoid sync overhead in every iteration.
29
+ global_avg_losses_bytes = BytesIO()
30
+ torch.save(self.global_avg_losses, global_avg_losses_bytes)
31
+ global_max_losses_bytes = BytesIO()
32
+ torch.save(self.global_max_losses, global_max_losses_bytes)
33
+ log_steps_bytes = BytesIO()
34
+ torch.save(self.log_steps, log_steps_bytes)
35
+ return {
36
+ "step": torch.tensor(self.step, dtype=torch.int32),
37
+ "skipped_step": torch.tensor(self.skipped_step, dtype=torch.int32),
38
+ "token": torch.tensor(self.token, dtype=torch.int64),
39
+ "elapsed": self.elapsed,
40
+ "global_avg_losses": global_avg_losses_bytes,
41
+ "global_max_losses": global_max_losses_bytes,
42
+ "log_steps": log_steps_bytes,
43
+ }
44
+
45
+ def load_state_dict(self, state_dict) -> None:
46
+ self.step = state_dict["step"].item()
47
+ self.skipped_step = state_dict.get("skipped_step", 0).item()
48
+ self.token = state_dict["token"].item()
49
+ self.elapsed = state_dict["elapsed"]
50
+ state_dict["global_avg_losses"].seek(0)
51
+ self.global_avg_losses = torch.load(
52
+ state_dict["global_avg_losses"], weights_only=False
53
+ )
54
+ state_dict["global_max_losses"].seek(0)
55
+ self.global_max_losses = torch.load(
56
+ state_dict["global_max_losses"], weights_only=False
57
+ )
58
+ state_dict["log_steps"].seek(0)
59
+ self.log_steps = torch.load(state_dict["log_steps"], weights_only=False)
flame/models/__init__.py ADDED
File without changes
flame/models/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (137 Bytes). View file
 
flame/models/__pycache__/parallelize_fla.cpython-312.pyc ADDED
Binary file (22.1 kB). View file
 
flame/models/__pycache__/pipeline_fla.cpython-312.pyc ADDED
Binary file (5.75 kB). View file
 
flame/models/fla.toml ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [model]
2
+ config = "fla-hub/transformer-1.3B-100B"
3
+ tokenizer_path = "fla-hub/transformer-1.3B-100B"
4
+
5
+ [job]
6
+ dump_folder = "exp"
7
+ print_args = true
8
+
9
+ [training]
10
+ batch_size = 32
11
+ seq_len = 2048
12
+ context_len = 2048
13
+ gradient_accumulation_steps = 1
14
+ steps = 20480
15
+ max_norm = 1.0
16
+ skip_nan_inf = true
17
+ data_parallel_replicate_degree = 1
18
+ data_parallel_shard_degree = -1
19
+ tensor_parallel_degree = 1
20
+ compile = false
21
+ dataset = "HuggingFaceFW/fineweb-edu"
22
+ dataset_name = "default"
23
+ num_workers = 32
24
+ pin_memory = false
25
+ persistent_workers = false
26
+ prefetch_factor = 2
27
+ seed = 42
28
+ varlen = false
29
+
30
+ [optimizer]
31
+ name = "AdamW"
32
+ eps = 1e-15
33
+ lr = 3e-4
34
+
35
+ [lr_scheduler]
36
+ warmup_steps = 1024
37
+ decay_type = "cosine"
38
+ lr_min = 0.1
39
+
40
+ [checkpoint]
41
+ enable_checkpoint = true
42
+ folder = "checkpoint"
43
+ interval_type = "steps"
44
+ interval = 2048
45
+ model_weights_only = false
46
+ export_dtype = "float32"
47
+ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
48
+
49
+ [profiling]
50
+ enable_profiling = true
51
+ save_traces_folder = "profile_trace"
52
+ profile_freq = 512
53
+
54
+ [metrics]
55
+ log_freq = 32
56
+ enable_wandb = true
57
+
58
+ [experimental]
59
+ context_parallel_degree = 1
60
+ pipeline_parallel_degree = 1
61
+
62
+ [float8]
63
+ enable_fsdp_float8_all_gather = false
64
+ precompute_float8_dynamic_scale_for_fsdp = false
65
+
66
+ [activation_checkpoint]
67
+ mode = "none"
flame/models/parallelize_fla.py ADDED
@@ -0,0 +1,550 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # This file applies the PT-D parallelisms (except pipeline parallelism) and various
8
+ # training techniques (e.g. activation checkpointing and compile) to the Llama model.
9
+
10
+ from collections import defaultdict
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from torch.distributed import DeviceMesh
15
+ from torch.distributed._composable.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard
16
+ from torch.distributed._composable.replicate import replicate
17
+ from torch.distributed._tensor import Replicate, Shard
18
+ from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper as ptd_checkpoint_wrapper
19
+ from torch.distributed.tensor.parallel import (
20
+ ColwiseParallel,
21
+ PrepareModuleInput,
22
+ PrepareModuleOutput,
23
+ RowwiseParallel,
24
+ SequenceParallel,
25
+ parallelize_module
26
+ )
27
+
28
+ from fla.modules.fused_linear_cross_entropy import LinearLossParallel
29
+ from fla.modules.mlp import SwiGLULinearParallel
30
+ from fla.modules.parallel import PrepareModuleWeight
31
+ from torchtitan.config_manager import TORCH_DTYPE_MAP, JobConfig
32
+ from torchtitan.distributed.parallel_dims import ParallelDims
33
+ from torchtitan.tools.logging import logger
34
+
35
+
36
+ def parallelize_fla(
37
+ model: nn.Module,
38
+ world_mesh: DeviceMesh,
39
+ parallel_dims: ParallelDims,
40
+ job_config: JobConfig,
41
+ ):
42
+ """
43
+ Apply tensor parallelism, activation checkpointing, torch.compile, and data
44
+ parallelism to the model.
45
+
46
+ NOTE: The passed-in model preferably should be on meta device. Otherwise,
47
+ the model must fit on GPU or CPU memory.
48
+ """
49
+
50
+ if parallel_dims.tp_enabled:
51
+ if (
52
+ job_config.experimental.enable_async_tensor_parallel
53
+ and not job_config.training.compile
54
+ ):
55
+ raise RuntimeError("Async TP requires --training.compile")
56
+ enable_float8_linear = "float8" in job_config.model.converters
57
+ apply_tp(
58
+ model,
59
+ world_mesh["tp"],
60
+ loss_parallel=parallel_dims.loss_parallel_enabled,
61
+ enable_float8=enable_float8_linear,
62
+ enable_async_tp=job_config.experimental.enable_async_tensor_parallel,
63
+ )
64
+
65
+ if job_config.activation_checkpoint.mode != "none":
66
+ apply_ac(model, job_config.activation_checkpoint)
67
+
68
+ # turn on per-block compile after AC wrapping and before FSDP
69
+ if job_config.training.compile:
70
+ apply_compile(model)
71
+
72
+ if (
73
+ parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled
74
+ ): # apply FSDP or HSDP, potentially with Context Parallel
75
+ if parallel_dims.dp_replicate_enabled:
76
+ dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
77
+ else:
78
+ dp_mesh_dim_names = ("dp_shard_cp",)
79
+
80
+ apply_fsdp(
81
+ model,
82
+ world_mesh[tuple(dp_mesh_dim_names)],
83
+ param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
84
+ reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
85
+ pp_enabled=parallel_dims.pp_enabled,
86
+ cpu_offload=job_config.training.enable_cpu_offload,
87
+ reshard_after_forward_policy=job_config.training.fsdp_reshard_after_forward,
88
+ )
89
+
90
+ if parallel_dims.dp_replicate_enabled:
91
+ logger.info("Applied HSDP to the model")
92
+ else:
93
+ logger.info("Applied FSDP to the model")
94
+
95
+ if parallel_dims.cp_enabled:
96
+ logger.info("Applied Context Parallel to the model")
97
+
98
+ if job_config.training.enable_cpu_offload:
99
+ logger.info("Applied CPU Offloading to the model")
100
+ elif parallel_dims.dp_replicate_enabled:
101
+ if world_mesh.ndim > 1:
102
+ raise RuntimeError("DDP has not supported > 1D parallelism")
103
+ apply_ddp(
104
+ model,
105
+ world_mesh,
106
+ enable_compile=job_config.training.compile,
107
+ enable_compiled_autograd=job_config.experimental.enable_compiled_autograd,
108
+ )
109
+
110
+
111
+ class TPPlan:
112
+ def __init__(
113
+ self,
114
+ model=None,
115
+ loss_parallel=False,
116
+ enable_float8=False,
117
+ ):
118
+ self.model = model
119
+ self.loss_parallel = loss_parallel
120
+ self.enable_float8 = enable_float8
121
+ self.base_model_prefix = getattr(model, "base_model_prefix", "model")
122
+
123
+ # TODO(vkuzo): once float8 configuration supports delayed scaling,
124
+ # add a check here to enforce supported float8 all-gather configurations
125
+ # TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there
126
+ try:
127
+ from torchao.float8.float8_tensor_parallel import (
128
+ Float8ColwiseParallel,
129
+ Float8RowwiseParallel,
130
+ PrepareFloat8ModuleInput
131
+ )
132
+ except ImportError:
133
+ Float8ColwiseParallel = None
134
+ Float8RowwiseParallel = None
135
+ PrepareFloat8ModuleInput = None
136
+ if self.enable_float8 and Float8ColwiseParallel is not None:
137
+ self.rowwise_parallel = Float8RowwiseParallel
138
+ self.colwise_parallel = Float8ColwiseParallel
139
+ self.prepare_module_input = PrepareFloat8ModuleInput
140
+ self.prepare_module_output = PrepareModuleOutput
141
+ else:
142
+ self.rowwise_parallel = RowwiseParallel
143
+ self.colwise_parallel = ColwiseParallel
144
+ self.prepare_module_input = PrepareModuleInput
145
+ self.prepare_module_output = PrepareModuleOutput
146
+
147
+ @property
148
+ def model_plan(self):
149
+ plans = {
150
+ f"{self.base_model_prefix}.embeddings": RowwiseParallel(
151
+ input_layouts=Replicate(),
152
+ output_layouts=Shard(1),
153
+ ),
154
+ f"{self.base_model_prefix}.norm": SequenceParallel(),
155
+ }
156
+ if self.loss_parallel:
157
+ plans.update(
158
+ {
159
+ "lm_head": ColwiseParallel(
160
+ input_layouts=Shard(1),
161
+ output_layouts=Shard(-1) if self.loss_parallel else Replicate(),
162
+ use_local_output=not self.loss_parallel,
163
+ ),
164
+ }
165
+ )
166
+ else:
167
+ plans.update(
168
+ {
169
+ "lm_head": PrepareModuleWeight(layouts=Replicate()),
170
+ "criterion": LinearLossParallel(),
171
+ }
172
+ )
173
+ return plans
174
+
175
+ @property
176
+ def layer_plan(self):
177
+ return {
178
+ "attn_norm": SequenceParallel(),
179
+ **self.attn_plan,
180
+ "mlp_norm": SequenceParallel(),
181
+ **self.mlp_plan,
182
+ }
183
+
184
+ @property
185
+ def attn_plan(self):
186
+ raise NotImplementedError(
187
+ f"TP plans for token mixing layers of {self.model.config.model_type} not implemented"
188
+ )
189
+
190
+ @property
191
+ def mlp_plan(self):
192
+ return {
193
+ "mlp": self.prepare_module_input(
194
+ input_layouts=(Shard(1),),
195
+ desired_input_layouts=(Replicate(),),
196
+ ),
197
+ "mlp.gate_proj": self.colwise_parallel(),
198
+ "mlp.up_proj": self.colwise_parallel(),
199
+ "mlp.down_proj": self.rowwise_parallel(output_layouts=Shard(1)),
200
+ "mlp.swiglu_linear": SwiGLULinearParallel(output_layouts=Shard(1)),
201
+ }
202
+
203
+
204
+ class TransformerTPPlan(TPPlan):
205
+
206
+ @property
207
+ def attn_plan(self):
208
+ return {
209
+ "attn": self.prepare_module_input(
210
+ input_kwarg_layouts={"hidden_states": Shard(1)},
211
+ desired_input_kwarg_layouts={"hidden_states": Replicate()},
212
+ ),
213
+ "attn.q_proj": self.colwise_parallel(),
214
+ "attn.k_proj": self.colwise_parallel(),
215
+ "attn.v_proj": self.colwise_parallel(),
216
+ "attn.o_proj": self.rowwise_parallel(output_layouts=Shard(1)),
217
+ }
218
+
219
+
220
+ class GLATPPlan(TPPlan):
221
+
222
+ @property
223
+ def attn_plan(self):
224
+ return {
225
+ "attn": self.prepare_module_input(
226
+ input_kwarg_layouts={"hidden_states": Shard(1)},
227
+ desired_input_kwarg_layouts={"hidden_states": Replicate()},
228
+ ),
229
+ "attn.q_proj": self.colwise_parallel(),
230
+ "attn.k_proj": self.colwise_parallel(),
231
+ "attn.v_proj": self.colwise_parallel(),
232
+ "attn.g_proj": self.colwise_parallel(),
233
+ "attn.gk_proj.0": PrepareModuleWeight(layouts=Replicate()),
234
+ "attn.gk_proj.1": self.colwise_parallel(),
235
+ "attn.g_norm": SequenceParallel(sequence_dim=-1),
236
+ "attn.o_proj": self.rowwise_parallel(output_layouts=Shard(1)),
237
+ }
238
+
239
+
240
+ TP_PLAN_MAP = {"transformer": TransformerTPPlan, "gla": GLATPPlan}
241
+
242
+
243
+ def apply_tp(
244
+ model: nn.Module,
245
+ tp_mesh: DeviceMesh,
246
+ loss_parallel: bool,
247
+ enable_float8: bool,
248
+ enable_async_tp: bool,
249
+ ):
250
+ """Apply tensor parallelism."""
251
+ # 1. Parallelize the embedding and shard its outputs (which are the first
252
+ # transformer block's inputs)
253
+ # 2. Parallelize the root norm layer over the sequence dim
254
+ # 3. Parallelize the final linear output layer
255
+ tp_plan = TP_PLAN_MAP[model.config.model_type](
256
+ model, loss_parallel=loss_parallel, enable_float8=enable_float8
257
+ )
258
+ parallelize_module(model, tp_mesh, tp_plan.model_plan)
259
+
260
+ blocks = get_blocks(model)
261
+ if blocks is None:
262
+ logger.warning("No block found for tensor parallelism")
263
+ else:
264
+ for _, block in enumerate(blocks):
265
+ parallelize_module(
266
+ module=block,
267
+ device_mesh=tp_mesh,
268
+ parallelize_plan=tp_plan.layer_plan,
269
+ )
270
+
271
+ if enable_async_tp:
272
+ from torch.distributed._symmetric_memory import enable_symm_mem_for_group
273
+
274
+ torch._inductor.config._micro_pipeline_tp = True
275
+ enable_symm_mem_for_group(tp_mesh.get_group().group_name)
276
+
277
+ logger.info(
278
+ f"Applied {'Float8 ' if enable_float8 else ''}{'Async ' if enable_async_tp else ''}"
279
+ "Tensor Parallelism to the model"
280
+ )
281
+
282
+
283
+ # for selective op activation checkpointing
284
+ _save_list = {
285
+ torch.ops.aten.mm.default,
286
+ torch.ops.aten._scaled_dot_product_efficient_attention.default,
287
+ torch.ops.aten._scaled_dot_product_flash_attention.default,
288
+ torch.ops._c10d_functional.reduce_scatter_tensor.default,
289
+ # for low precision training, it's useful to always save
290
+ # the result of max, since the absolute maximum is
291
+ # used to compute the scaling factor for quantization.
292
+ torch.ops.aten.max.default,
293
+ }
294
+
295
+
296
+ def _apply_ac_to_block(module: nn.Module, ac_config):
297
+ valid_ac_modes = ("full", "selective")
298
+ if ac_config.mode not in valid_ac_modes:
299
+ raise ValueError(
300
+ f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}"
301
+ )
302
+
303
+ if ac_config.mode == "full":
304
+ return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
305
+
306
+ assert ac_config.mode == "selective", f"{ac_config.mode}"
307
+ use_op_sac = ac_config.selective_ac_option == "op"
308
+ use_layer_sac = ac_config.selective_ac_option.isdigit()
309
+ if not use_op_sac and not use_layer_sac:
310
+ raise ValueError(
311
+ f"Invalid selective AC option: {ac_config.selective_ac_option}. "
312
+ f"Valid options: 'op' or a positive int representing layer frequency"
313
+ )
314
+ if use_op_sac:
315
+ from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts
316
+
317
+ def _get_custom_policy(meta):
318
+ def _custom_policy(ctx, func, *args, **kwargs):
319
+ mode = "recompute" if ctx.is_recompute else "forward"
320
+ mm_count_key = f"{mode}_mm_count"
321
+ if func == torch.ops.aten.mm.default:
322
+ meta[mm_count_key] += 1
323
+ # Saves output of all compute ops, except every second mm
324
+ to_save = func in _save_list and not (
325
+ func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
326
+ )
327
+ return (
328
+ CheckpointPolicy.MUST_SAVE
329
+ if to_save
330
+ else CheckpointPolicy.PREFER_RECOMPUTE
331
+ )
332
+
333
+ return _custom_policy
334
+
335
+ def selective_checkpointing_context_fn():
336
+ meta = defaultdict(int)
337
+ return create_selective_checkpoint_contexts(_get_custom_policy(meta))
338
+
339
+ return ptd_checkpoint_wrapper(
340
+ module,
341
+ context_fn=selective_checkpointing_context_fn,
342
+ preserve_rng_state=False,
343
+ )
344
+ elif use_layer_sac:
345
+ # Checkpoint every `ac_freq` of the modules passed to this function
346
+ ac_freq = int(ac_config.selective_ac_option)
347
+ ptd_checkpoint_wrapper.__dict__.setdefault("_count", 0)
348
+ ptd_checkpoint_wrapper._count += 1
349
+ if not ac_freq or ptd_checkpoint_wrapper._count % ac_freq == 0:
350
+ return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
351
+ else:
352
+ return module
353
+
354
+
355
+ def apply_ac(model: nn.Module, ac_config):
356
+ """Apply activation checkpointing to the model."""
357
+ blocks = get_blocks(model)
358
+ if blocks is None:
359
+ logger.warning("No block found for activation checkpointing")
360
+ return
361
+
362
+ for layer_id, block in blocks.named_children():
363
+ block = _apply_ac_to_block(block, ac_config)
364
+ blocks.register_module(layer_id, block)
365
+
366
+ logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
367
+
368
+
369
+ def apply_compile(model: nn.Module):
370
+ """
371
+ Apply torch.compile to each block, which makes compilation efficient due to
372
+ repeated structure. Alternatively one can compile the whole model (after applying DP).
373
+ """
374
+
375
+ blocks = get_blocks(model)
376
+ if blocks is None:
377
+ logger.warning("No block found for torch.compile")
378
+ else:
379
+ for layer_id, block in blocks.named_children():
380
+ block = torch.compile(block)
381
+ blocks.register_module(layer_id, block)
382
+ logger.info("Compiling each block with torch.compile")
383
+
384
+ real_model = get_model(model)
385
+
386
+ logger.info("Compiling the embedding, norm, and lm_head layers with torch.compile")
387
+ embeddings_key = get_components_name(real_model, "tok_embeddings")
388
+ if embeddings_key is not None:
389
+ embeddings = torch.compile(getattr(real_model, embeddings_key), fullgraph=True)
390
+ real_model.register_module(embeddings_key, embeddings)
391
+
392
+ norm_key = get_components_name(real_model, "norm")
393
+ if norm_key is not None:
394
+ norm = torch.compile(getattr(real_model, norm_key), fullgraph=True)
395
+ real_model.register_module(norm_key, norm)
396
+
397
+ lm_head_key = get_components_name(model, "lm_head")
398
+ if lm_head_key is not None:
399
+ lm_head = torch.compile(getattr(model, lm_head_key), fullgraph=True)
400
+ model.register_module(lm_head_key, lm_head)
401
+
402
+ logger.info("Compiling the entire model with torch.compile")
403
+ model = torch.compile(model)
404
+
405
+
406
+ def apply_fsdp(
407
+ model: nn.Module,
408
+ dp_mesh: DeviceMesh,
409
+ param_dtype: torch.dtype,
410
+ reduce_dtype: torch.dtype,
411
+ pp_enabled: bool,
412
+ cpu_offload: bool = False,
413
+ reshard_after_forward_policy: str = "default",
414
+ ):
415
+ """
416
+ Apply data parallelism (via FSDP2) to the model.
417
+
418
+ Args:
419
+ model (nn.Module): The model to apply data parallelism to.
420
+ dp_mesh (DeviceMesh): The device mesh to use for data parallelism.
421
+ param_dtype (torch.dtype): The data type to use for model parameters.
422
+ reduce_dtype (torch.dtype): The data type to use for reduction operations.
423
+ pp_enabled (bool): Whether pipeline parallelism is enabled.
424
+ cpu_offload (bool, optional): Whether to offload model parameters to CPU. Defaults to False.
425
+ reshard_after_forward_policy (str, optional):
426
+ The policy to use for resharding after forward pass. Defaults to "default".
427
+ Other options: "never", "always".
428
+ - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios.
429
+ - "always" will enable `reshard_after_forward` for all forward passes.
430
+ - "never" will disable `reshard_after_forward` for all forward passes.
431
+
432
+ """
433
+ mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
434
+ fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
435
+ if cpu_offload:
436
+ fsdp_config["offload_policy"] = CPUOffloadPolicy()
437
+
438
+ blocks = get_blocks(model)
439
+ if blocks is None:
440
+ logger.warning("No block found for FSDP")
441
+ else:
442
+ total_blocks = len(blocks)
443
+ for layer_id, block in enumerate(blocks):
444
+ if reshard_after_forward_policy == "always":
445
+ reshard_after_forward = True
446
+ elif reshard_after_forward_policy == "never":
447
+ reshard_after_forward = False
448
+ elif reshard_after_forward_policy == "default":
449
+ if pp_enabled:
450
+ # For PP, do not reshard after forward to avoid per-microbatch
451
+ # all-gathers, which can be expensive and non-overlapped
452
+ reshard_after_forward = False
453
+ else:
454
+ # As an optimization, do not reshard after forward for the last
455
+ # transformer block since FSDP would prefetch it immediately
456
+ reshard_after_forward = int(layer_id) < total_blocks - 1
457
+ else:
458
+ raise ValueError(
459
+ f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}."
460
+ )
461
+ fully_shard(
462
+ block,
463
+ **fsdp_config,
464
+ reshard_after_forward=reshard_after_forward,
465
+ )
466
+
467
+ fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)
468
+
469
+
470
+ def apply_ddp(
471
+ model: nn.Module,
472
+ dp_mesh: DeviceMesh,
473
+ enable_compile: bool,
474
+ enable_compiled_autograd: bool,
475
+ ):
476
+ if enable_compile:
477
+ if enable_compiled_autograd:
478
+ torch._dynamo.config.optimize_ddp = (
479
+ "python_reducer_without_compiled_forward"
480
+ )
481
+ else:
482
+ torch._dynamo.config.optimize_ddp = "ddp_optimizer"
483
+
484
+ replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
485
+
486
+ logger.info("Applied DDP to the model")
487
+
488
+
489
+ def get_model(model):
490
+ base_model_prefix = getattr(model, "base_model_prefix", "model")
491
+ if not hasattr(model, base_model_prefix):
492
+ return None
493
+ model = getattr(model, base_model_prefix)
494
+ return model
495
+
496
+
497
+ def get_blocks(model):
498
+ # TODO[flame]: adapt for network not using 'layers' attribute
499
+ model = get_model(model)
500
+ if not hasattr(model, "layers"):
501
+ logger.warning('no "layers" in model can be found')
502
+ return None
503
+ return model.layers
504
+
505
+
506
+ def get_components_name(model, component_name):
507
+ """
508
+ We try to catch tok_embeddings, norm layers and lm_head layers
509
+ We do not catch the layer names in the blocks, for blocks see `get_blocks`
510
+ We assume the model has the following structure:
511
+ LlamaForCausalLM:
512
+ Model:
513
+ embed_tokens,
514
+ layers,
515
+ norm,
516
+ lm_head
517
+ ***
518
+ so, to search 'tok_embeddings' and 'norm' we need to pass `get_model(model)`
519
+ and for 'lm_head' we need to pass `model`
520
+ ***
521
+ """
522
+
523
+ if component_name == "tok_embeddings":
524
+ if hasattr(model, "tok_embeddings"):
525
+ return "tok_embeddings"
526
+ elif hasattr(model, "embed_tokens"):
527
+ return "embed_tokens"
528
+ elif hasattr(model, "embeddings"):
529
+ return "embeddings"
530
+ else:
531
+ logger.warning("No tok_embeddings found in model")
532
+ return None
533
+
534
+ elif component_name == "norm":
535
+ if hasattr(model, "norm"):
536
+ return "norm"
537
+ elif hasattr(model, "norms"):
538
+ return "norms"
539
+ elif hasattr(model, "layernorm"):
540
+ return "layernorm"
541
+ else:
542
+ logger.warning("No norm found in model")
543
+ return None
544
+
545
+ elif component_name == "lm_head":
546
+ if hasattr(model, "lm_head"):
547
+ return "lm_head"
548
+ else:
549
+ logger.warning("No lm_head found in model")
550
+ return None
flame/models/pipeline_fla.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # This file applies the PT-D pipeline parallelism to the Llama model.
8
+
9
+ import copy
10
+ from typing import Callable, Optional, Union
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from torch.distributed import DeviceMesh
15
+ from torch.distributed.pipelining import PipelineStage
16
+ from torch.distributed.pipelining.schedules import ScheduleZBVZeroBubble, _PipelineSchedule, get_schedule_class
17
+ from transformers import PretrainedConfig
18
+
19
+ from flame.models.parallelize_fla import get_blocks, get_components_name, get_model
20
+ from torchtitan.config_manager import JobConfig
21
+ from torchtitan.distributed.parallel_dims import ParallelDims
22
+ from torchtitan.distributed.pipeline import build_pipeline_schedule, generate_split_points, stage_ids_this_rank
23
+ from torchtitan.tools.logging import logger
24
+
25
+ DeviceType = Union[int, str, torch.device]
26
+
27
+
28
+ def pipeline_fla(
29
+ model: nn.Module,
30
+ pp_mesh: DeviceMesh,
31
+ parallel_dims: ParallelDims,
32
+ job_config: JobConfig,
33
+ device: DeviceType,
34
+ model_config: PretrainedConfig,
35
+ loss_fn: Callable[..., torch.Tensor],
36
+ ) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]:
37
+ stages, models = pipeline_fla_manual_split(
38
+ model, pp_mesh, parallel_dims, job_config, device, model_config
39
+ )
40
+
41
+ pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)
42
+
43
+ # This is used in the train loop to determine whether to pass in the input_ids and labels
44
+ has_first_stage = False
45
+ has_last_stage = False
46
+ for stage in stages:
47
+ if stage.is_first:
48
+ has_first_stage = True
49
+ if stage.is_last:
50
+ has_last_stage = True
51
+
52
+ return pp_schedule, models, has_first_stage, has_last_stage
53
+
54
+
55
+ def pipeline_fla_manual_split(
56
+ whole_model: nn.Module,
57
+ pp_mesh: DeviceMesh,
58
+ parallel_dims: ParallelDims,
59
+ job_config: JobConfig,
60
+ device: DeviceType,
61
+ model_config: PretrainedConfig,
62
+ ) -> tuple[list[PipelineStage], list[nn.Module]]:
63
+ """
64
+ This API extracts one torch.nn.Module objects for the part of the model configured to run inside this stage.
65
+
66
+ It wraps the model chunk in a ManualPipelineStage object and returns both the stage and model objects.
67
+
68
+ The stage object is used to create a pipeline schedule, and the model object can be used for applying SPMD
69
+ parallelism.
70
+ """
71
+ pp_rank = pp_mesh.get_local_rank()
72
+ pp_size = pp_mesh.size()
73
+
74
+ splits = (
75
+ job_config.experimental.pipeline_parallel_split_points
76
+ or generate_split_points(
77
+ job_config, parallel_dims.pp, model_config.num_hidden_layers
78
+ )
79
+ )
80
+
81
+ def _build_stage(
82
+ stage_idx: int,
83
+ start_layer: Optional[str],
84
+ stop_layer: Optional[str],
85
+ is_first: bool = False,
86
+ is_last: bool = False,
87
+ ) -> tuple[PipelineStage, nn.Module]:
88
+ model = copy.deepcopy(whole_model)
89
+ if not is_first:
90
+ # we do `model.tok_embeddings = None` here
91
+ real_model = get_model(model)
92
+ tok_embeddings_name = get_components_name(real_model, "tok_embeddings")
93
+ setattr(real_model, tok_embeddings_name, None)
94
+
95
+ drop_layers = start_layer is not None
96
+ # Get module dictionary from get_blocks(model)
97
+ # and Create a list of keys before modifying dictionary
98
+ module_dict = get_blocks(model)._modules # Store reference
99
+ layer_names = list(module_dict.keys())
100
+
101
+ # Iterate over the list of keys instead of `_modules.items()`
102
+ for name in layer_names:
103
+ # Dynamically determine prefix (blocks.* or layers.*)
104
+ prefix = start_layer.split(".")[0] if start_layer else "layers"
105
+ layer_name = f"{prefix}.{name}" # Construct the correct name format
106
+
107
+ # Ensure `drop_layers` activation is based on actual naming
108
+ if layer_name == start_layer:
109
+ drop_layers = False
110
+ if layer_name == stop_layer:
111
+ drop_layers = True
112
+
113
+ # Delete layer if drop_layers is active
114
+ if drop_layers:
115
+ del module_dict[name] # Safe deletion from stored dictionary
116
+
117
+ if not is_last:
118
+ # we do `model.norm = None` and `model.output = None`
119
+ real_model = get_model(model)
120
+ norm_name = get_components_name(real_model, "norm")
121
+ setattr(real_model, norm_name, None)
122
+
123
+ head_name = get_components_name(model, "lm_head")
124
+ setattr(model, head_name, None)
125
+
126
+ stage = PipelineStage(
127
+ model,
128
+ stage_idx,
129
+ num_stages,
130
+ device,
131
+ group=pp_mesh.get_group("pp"),
132
+ )
133
+ return stage, model
134
+
135
+ num_stages = len(splits) + 1
136
+ stage_idx = pp_rank
137
+
138
+ stages = []
139
+ models = []
140
+
141
+ schedule_class = get_schedule_class(
142
+ job_config.experimental.pipeline_parallel_schedule
143
+ )
144
+ style = "v" if schedule_class == ScheduleZBVZeroBubble else "loop"
145
+
146
+ for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style=style):
147
+ start_layer = splits[stage_idx - 1] if stage_idx > 0 else None
148
+ stop_layer = splits[stage_idx] if stage_idx < num_stages - 1 else None
149
+ stage, model_chunk = _build_stage(
150
+ stage_idx,
151
+ start_layer,
152
+ stop_layer,
153
+ is_first=stage_idx == 0,
154
+ is_last=stage_idx == num_stages - 1,
155
+ )
156
+ logger.info(
157
+ f"PP rank {pp_rank} is building stage_idx {stage_idx}"
158
+ f" with start_layer {start_layer}, stop_layer {stop_layer}"
159
+ )
160
+ stages.append(stage)
161
+ models.append(model_chunk)
162
+ return stages, models
flame/tools/__init__.py ADDED
File without changes
flame/tools/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (136 Bytes). View file
 
flame/tools/__pycache__/utils.cpython-312.pyc ADDED
Binary file (2.14 kB). View file
 
flame/tools/utils.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from torch import nn
8
+ from torchtitan.tools.logging import logger
9
+
10
+
11
+ def get_nparams_and_flops(model: nn.Module, model_config, seq_len: int) -> tuple[int, int]:
12
+ nparams = sum(p.numel() for p in model.parameters())
13
+ nparams_embedding = sum(
14
+ sum(p.numel() for p in m.parameters())
15
+ for m in model.children()
16
+ if isinstance(m, nn.Embedding)
17
+ )
18
+
19
+ if hasattr(model_config, "num_heads"):
20
+ num_heads = model_config.num_heads
21
+ elif hasattr(model_config, "num_attention_heads"):
22
+ num_heads = model_config.num_attention_heads
23
+ else:
24
+ num_heads = 1
25
+ logger.warning("num_heads not found in model_config, defaulting to 1. ")
26
+
27
+ l, h, q, t = (
28
+ model_config.num_hidden_layers,
29
+ num_heads,
30
+ model_config.hidden_size // num_heads,
31
+ seq_len,
32
+ )
33
+ # Reasoning behind the factor of 12 for the self-attention part of the formula:
34
+ # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6)
35
+ # 2. the flash attention does 1 more matmul recomputation in the backward
36
+ # but recomputation should not be counted in calculating MFU (+0)
37
+ # 3. each matmul performs 1 multiplication and 1 addition (*2)
38
+ # 4. we follow the convention and do not account for sparsity in causal attention
39
+ num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
40
+
41
+ return nparams, num_flops_per_token
flame/utils/__init__.py ADDED
File without changes
flame/utils/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (136 Bytes). View file
 
flame/utils/__pycache__/checkpoint.cpython-312.pyc ADDED
Binary file (4.07 kB). View file
 
flame/utils/__pycache__/convert_dcp_to_hf.cpython-312.pyc ADDED
Binary file (3.73 kB). View file
 
flame/utils/__pycache__/convert_hf_to_dcp.cpython-312.pyc ADDED
Binary file (1.92 kB). View file
 
flame/utils/__pycache__/hf_utils.cpython-312.pyc ADDED
Binary file (4.46 kB). View file
 
flame/utils/checkpoint.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import re
4
+ import shutil
5
+ from torchtitan.tools.logging import logger
6
+
7
+
8
+ def cleanup_local_checkpoints(checkpoint_dir: str, keep_latest_k: int):
9
+ """Removes older checkpoint directories locally, keeping only the latest k for both DCP and HF formats."""
10
+ if keep_latest_k <= 0:
11
+ return # Keep all checkpoints
12
+
13
+ logger.info(f"Cleaning up local checkpoints in {checkpoint_dir}, keeping latest {keep_latest_k}")
14
+
15
+ # Cleanup DCP checkpoints (step-*)
16
+ dcp_checkpoints = sorted(
17
+ glob.glob(os.path.join(checkpoint_dir, "step-*")),
18
+ key=lambda x: int(re.search(r"step-(\d+)", os.path.basename(x)).group(1)) if re.search(r"step-(\d+)", os.path.basename(x)) and not x.endswith("-hf") else -1,
19
+ reverse=True
20
+ )
21
+ # Filter out HF format directories
22
+ dcp_checkpoints = [d for d in dcp_checkpoints if not d.endswith("-hf")]
23
+
24
+ if len(dcp_checkpoints) > keep_latest_k:
25
+ checkpoints_to_delete = dcp_checkpoints[keep_latest_k:]
26
+ logger.info(f"Deleting {len(checkpoints_to_delete)} old DCP checkpoints: {[os.path.basename(c) for c in checkpoints_to_delete]}")
27
+ for ckpt_path in checkpoints_to_delete:
28
+ if os.path.isdir(ckpt_path): # Ensure it's a directory
29
+ try:
30
+ shutil.rmtree(ckpt_path)
31
+ except OSError as e:
32
+ logger.error(f"Error removing directory {ckpt_path}: {e}")
33
+
34
+
35
+ # Cleanup HF checkpoints (step-*-hf)
36
+ hf_checkpoints = sorted(
37
+ glob.glob(os.path.join(checkpoint_dir, "step-*-hf")),
38
+ key=lambda x: int(re.search(r"step-(\d+)-hf", os.path.basename(x)).group(1)) if re.search(r"step-(\d+)-hf", os.path.basename(x)) else -1,
39
+ reverse=True
40
+ )
41
+
42
+ if len(hf_checkpoints) > keep_latest_k:
43
+ checkpoints_to_delete = hf_checkpoints[keep_latest_k:]
44
+ logger.info(f"Deleting {len(checkpoints_to_delete)} old HF checkpoints: {[os.path.basename(c) for c in checkpoints_to_delete]}")
45
+ for ckpt_path in checkpoints_to_delete:
46
+ if os.path.isdir(ckpt_path): # Ensure it's a directory
47
+ try:
48
+ shutil.rmtree(ckpt_path)
49
+ except OSError as e:
50
+ logger.error(f"Error removing directory {ckpt_path}: {e}")
flame/utils/convert_dcp_to_hf.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ import argparse
5
+ import io
6
+ import os
7
+ import tempfile
8
+ from datetime import timedelta
9
+
10
+ import torch
11
+ import torch.serialization
12
+ from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
13
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
14
+
15
+ import fla # noqa
16
+ from torchtitan.tools.logging import init_logger, logger
17
+
18
+
19
+ @torch.inference_mode()
20
+ def save_pretrained(
21
+ path: str,
22
+ step: int,
23
+ config: str,
24
+ tokenizer: str
25
+ ):
26
+ logger.info(f"Loading the config from {config}")
27
+ config = AutoConfig.from_pretrained(config, trust_remote_code=True)
28
+
29
+ logger.info(f"Saving the config to {path}")
30
+ config.save_pretrained(path)
31
+ logger.info(f"Loading the tokenizer from {tokenizer}")
32
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True)
33
+ logger.info(f"Saving the tokenizer to {path}")
34
+ tokenizer.save_pretrained(path)
35
+
36
+ with tempfile.TemporaryDirectory() as tmpdir:
37
+ # base_checkpoint_dir = os.path.dirname(path)
38
+ base_checkpoint_dir = path
39
+ checkpoint = os.path.join(base_checkpoint_dir, f'checkpoint/step-{step}')
40
+ checkpoint_path = os.path.join(tmpdir, 'checkpoint.pt')
41
+ logger.info(f"Saving the distributed checkpoint to {checkpoint_path}")
42
+ dcp_to_torch_save(checkpoint, checkpoint_path)
43
+
44
+ logger.info(f"Initializing the model from config\n{config}")
45
+ model = AutoModelForCausalLM.from_config(config)
46
+ logger.info(model)
47
+ logger.info("Loading state dict from the checkpoint")
48
+
49
+ # Add datetime.timedelta and io.BytesIO to safe globals
50
+ torch.serialization.add_safe_globals([timedelta, io.BytesIO])
51
+ # torch.load now with default weights_only=True will work
52
+ model.load_state_dict(torch.load(checkpoint_path, map_location='cpu')['model'])
53
+
54
+ logger.info(f"Saving the model to {path}")
55
+ model.save_pretrained(path)
56
+
57
+
58
+ if __name__ == "__main__":
59
+ init_logger()
60
+ parser = argparse.ArgumentParser("Convert DCP format model weights to huggingface-style.")
61
+ parser.add_argument("--path", type=str, required=True)
62
+ parser.add_argument("--step", type=int, required=True)
63
+ parser.add_argument("--config", type=str, required=True)
64
+ parser.add_argument("--tokenizer", type=str, required=True)
65
+ args = parser.parse_args()
66
+ save_pretrained(args.path, args.step, args.config, args.tokenizer)
flame/utils/convert_hf_to_dcp.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ import argparse
5
+ from pathlib import Path
6
+
7
+ import torch
8
+ import torch.distributed.checkpoint as DCP
9
+ from transformers import AutoModelForCausalLM
10
+
11
+ import fla # noqa
12
+ from torchtitan.tools.logging import init_logger, logger
13
+
14
+
15
+ @torch.inference_mode()
16
+ def convert_hf_weights(model: str, checkpoint: str):
17
+ logger.info(f"Loading model from {model}")
18
+ model = AutoModelForCausalLM.from_pretrained(model)
19
+ state_dict = model.state_dict()
20
+
21
+ logger.info(f"Writing to DCP at '{checkpoint}'")
22
+ checkpoint.mkdir(parents=True, exist_ok=True)
23
+ storage_writer = DCP.filesystem.FileSystemWriter(checkpoint, thread_count=8)
24
+ DCP.save({"model": state_dict}, storage_writer=storage_writer)
25
+
26
+
27
+ if __name__ == "__main__":
28
+ init_logger()
29
+ parser = argparse.ArgumentParser(description="Convert huggingface-style model weights to DCP format.")
30
+ parser.add_argument("--model", type=str, required=True)
31
+ parser.add_argument("--checkpoint", type=Path, required=True)
32
+ args = parser.parse_args()
33
+
34
+ convert_hf_weights(args.model, args.checkpoint)
flame/utils/hf_utils.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ from huggingface_hub import HfApi, HfFolder, logging as hf_logging, create_repo
4
+ from torchtitan.tools.logging import logger
5
+
6
+ def upload_checkpoint_to_hf(
7
+ local_path: str,
8
+ step: int,
9
+ hf_repo_id_for_run: str,
10
+ hf_keep_latest_k: int,
11
+ upload_format: str
12
+ ):
13
+ """Uploads a checkpoint directory to HF Hub and manages retention."""
14
+ if not os.path.isdir(local_path):
15
+ logger.error(f"Local path for upload does not exist or is not a directory: {local_path}")
16
+ return
17
+
18
+ api = HfApi()
19
+ token = HfFolder.get_token()
20
+ if not token:
21
+ logger.warning("Hugging Face Hub token not found. Skipping upload. Login via `huggingface-cli login` or set HF_TOKEN.")
22
+ return
23
+
24
+ # --- Ensure the specific repository for this run exists ---
25
+ try:
26
+ logger.info(f"Ensuring repository {hf_repo_id_for_run} exists...")
27
+ # Use create_repo which handles creation only if it doesn't exist
28
+ create_repo(repo_id=hf_repo_id_for_run, token=token, repo_type="model", exist_ok=True)
29
+ logger.info(f"Repository {hf_repo_id_for_run} ensured.")
30
+ except Exception as e:
31
+ logger.error(f"Failed to create or ensure repository {hf_repo_id_for_run}: {e}", exc_info=True)
32
+ return # Stop if repo interaction fails
33
+
34
+ commit_message = f"Upload {upload_format.upper()} checkpoint step {step}"
35
+ path_in_repo = f"step-{step}"
36
+
37
+ logger.info(f"Uploading {local_path} to {hf_repo_id_for_run}/{path_in_repo} on Hugging Face Hub...")
38
+ try:
39
+ api.upload_folder(
40
+ folder_path=local_path,
41
+ path_in_repo=path_in_repo,
42
+ repo_id=hf_repo_id_for_run,
43
+ repo_type="model",
44
+ commit_message=commit_message,
45
+ token=token,
46
+ )
47
+ logger.info(f"Successfully uploaded step {step} to {hf_repo_id_for_run}.")
48
+ except Exception as e:
49
+ logger.error(f"Failed to upload checkpoint step {step} to {hf_repo_id_for_run}: {e}", exc_info=True)
50
+ if hf_keep_latest_k > 0:
51
+ logger.info(f"Cleaning up old checkpoints on {hf_repo_id_for_run}, keeping latest {hf_keep_latest_k}")
52
+ try:
53
+ repo_files = api.list_repo_tree(hf_repo_id_for_run, repo_type="model", token=token, recursive=False)
54
+ step_folders = [
55
+ item.path for item in repo_files
56
+ if item.path.startswith("step-") and item.path[5:].isdigit()
57
+ ]
58
+
59
+ step_folders.sort(key=lambda x: int(x.split('-')[1]), reverse=True)
60
+
61
+ if len(step_folders) > hf_keep_latest_k:
62
+ folders_to_delete = step_folders[hf_keep_latest_k:]
63
+ logger.info(f"Found {len(step_folders)} checkpoints on Hub. Deleting {len(folders_to_delete)} older ones: {folders_to_delete}")
64
+ for folder in folders_to_delete:
65
+ # Deleting requires repo_id, path_in_repo, and token
66
+ api.delete_folder(
67
+ repo_id=hf_repo_id_for_run,
68
+ path_in_repo=folder,
69
+ repo_type="model",
70
+ commit_message=f"Delete old checkpoint {folder}",
71
+ token=token
72
+ )
73
+ logger.info("Hub cleanup complete.")
74
+ else:
75
+ logger.info("No old checkpoints found on Hub to delete.")
76
+ except Exception as e:
77
+ logger.error(f"Error during Hub checkpoint cleanup for {hf_repo_id_for_run}: {e}", exc_info=True)
logs/none_enyj3lod/attempt_0/5/stderr.log ADDED
The diff for this file is too large to render. See raw diff
 
logs/none_enyj3lod/attempt_0/7/stderr.log ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_15872/rank6_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_16384/rank2_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_16384/rank3_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_16384/rank5_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_19968/rank3_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_19968/rank7_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_3072/rank2_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_32256/rank1_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_37376/rank0_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_37376/rank6_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_9216/rank2_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_9216/rank3_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_9216/rank4_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_9216/rank6_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_9216/rank7_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_9728/rank1_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
tb/20250911-1415/wandb/run-20250911_141551-mtp_transformer-mtp.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509111411/files/output.log ADDED
The diff for this file is too large to render. See raw diff
 
tb/20250911-1415/wandb/run-20250911_141551-mtp_transformer-mtp.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509111411/files/wandb-metadata.json ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "os": "Linux-6.8.0-62-generic-x86_64-with-glibc2.39",
3
+ "python": "CPython 3.12.11",
4
+ "startedAt": "2025-09-11T14:15:51.409164Z",
5
+ "args": [
6
+ "--job.config_file",
7
+ "flame/models/fla.toml",
8
+ "--job.dump_folder",
9
+ "exp/mtp.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine",
10
+ "--model.config",
11
+ "configs/mtp_transformer_7B.json",
12
+ "--model.tokenizer_path",
13
+ "fla-hub/transformer-1.3B-100B",
14
+ "--optimizer.name",
15
+ "AdamW",
16
+ "--optimizer.eps",
17
+ "1e-15",
18
+ "--optimizer.lr",
19
+ "2e-5",
20
+ "--lr_scheduler.warmup_steps",
21
+ "400",
22
+ "--lr_scheduler.lr_min",
23
+ "0.1",
24
+ "--lr_scheduler.decay_type",
25
+ "cosine",
26
+ "--training.batch_size",
27
+ "8",
28
+ "--training.seq_len",
29
+ "4096",
30
+ "--training.context_len",
31
+ "4096",
32
+ "--training.gradient_accumulation_steps",
33
+ "2",
34
+ "--training.steps",
35
+ "40000",
36
+ "--training.max_norm",
37
+ "1.0",
38
+ "--training.skip_nan_inf",
39
+ "--training.dataset",
40
+ "/home/cvm/.cache/zaydzuhri___stack-edu-python/default",
41
+ "--training.dataset_split",
42
+ "train",
43
+ "--training.num_workers",
44
+ "32",
45
+ "--training.prefetch_factor",
46
+ "2",
47
+ "--training.seed",
48
+ "79",
49
+ "--training.compile",
50
+ "--checkpoint.interval",
51
+ "5000",
52
+ "--checkpoint.load_step",
53
+ "-1",
54
+ "--metrics.log_freq",
55
+ "5",
56
+ "--checkpoint.hf_upload_enabled",
57
+ "--checkpoint.hf_repo_base_name",
58
+ "zaydzuhri/mtp-code-7B-4096-batch8x2-steps40000",
59
+ "--comm.init_timeout_seconds",
60
+ "6000",
61
+ "--comm.train_timeout_seconds",
62
+ "6000"
63
+ ],
64
+ "program": "-m flame.train",
65
+ "git": {
66
+ "remote": "https://github.com/zaydzuhri/flame.git",
67
+ "commit": "aa4d5932e54fad8a568e10aa6895e69e0664fcf1"
68
+ },
69
+ "email": "zaydzuhri@gmail.com",
70
+ "root": "exp/mtp.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine/tb/20250911-1415",
71
+ "host": "cvm-hnlakfcy",
72
+ "executable": "/home/cvm/miniconda3/envs/flame-env/bin/python3.12",
73
+ "cpu_count": 64,
74
+ "cpu_count_logical": 128,
75
+ "gpu": "NVIDIA H200",
76
+ "gpu_count": 8,
77
+ "disk": {
78
+ "/": {
79
+ "total": "3242363822080",
80
+ "used": "1142457630720"
81
+ }
82
+ },
83
+ "memory": {
84
+ "total": "1913832992768"
85
+ },
86
+ "gpu_nvidia": [
87
+ {
88
+ "name": "NVIDIA H200",
89
+ "memoryTotal": "150754820096",
90
+ "cudaCores": 16896,
91
+ "architecture": "Hopper",
92
+ "uuid": "GPU-248746a8-c843-17da-da73-f7e913ce8534"
93
+ },
94
+ {
95
+ "name": "NVIDIA H200",
96
+ "memoryTotal": "150754820096",
97
+ "cudaCores": 16896,
98
+ "architecture": "Hopper",
99
+ "uuid": "GPU-dd71b7fe-465a-c7fc-695c-92022644d1e4"
100
+ },
101
+ {
102
+ "name": "NVIDIA H200",
103
+ "memoryTotal": "150754820096",
104
+ "cudaCores": 16896,
105
+ "architecture": "Hopper",
106
+ "uuid": "GPU-fa231ade-f7f2-7b4b-7038-1b6ba3478565"
107
+ },
108
+ {
109
+ "name": "NVIDIA H200",
110
+ "memoryTotal": "150754820096",
111
+ "cudaCores": 16896,
112
+ "architecture": "Hopper",
113
+ "uuid": "GPU-6c677375-a50c-d5ca-a517-8bab66e768e5"
114
+ },
115
+ {
116
+ "name": "NVIDIA H200",
117
+ "memoryTotal": "150754820096",
118
+ "cudaCores": 16896,
119
+ "architecture": "Hopper",
120
+ "uuid": "GPU-e98c9a5d-ed96-14c6-fcd7-e32e074c4ec6"
121
+ },
122
+ {
123
+ "name": "NVIDIA H200",
124
+ "memoryTotal": "150754820096",
125
+ "cudaCores": 16896,
126
+ "architecture": "Hopper",
127
+ "uuid": "GPU-0325ab0c-c935-f0f5-8488-1504586966c0"
128
+ },
129
+ {
130
+ "name": "NVIDIA H200",
131
+ "memoryTotal": "150754820096",
132
+ "cudaCores": 16896,
133
+ "architecture": "Hopper",
134
+ "uuid": "GPU-bd82a8ad-cbaf-446f-5012-c8841749fe95"
135
+ },
136
+ {
137
+ "name": "NVIDIA H200",
138
+ "memoryTotal": "150754820096",
139
+ "cudaCores": 16896,
140
+ "architecture": "Hopper",
141
+ "uuid": "GPU-2aeed443-dce7-f05e-6010-086fa04a9413"
142
+ }
143
+ ],
144
+ "cudaVersion": "12.8",
145
+ "writerId": "173cboaedkpy0bqi2wup4pr1w4lb2n87"
146
+ }
tb/20250911-1415/wandb/run-20250911_141551-mtp_transformer-mtp.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509111411/logs/debug-core.log ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {"time":"2025-09-11T14:15:51.448889635Z","level":"INFO","msg":"main: starting server","port-filename":"/tmp/tmpme1m9m73/port-2338706.txt","pid":2338706,"log-level":0,"disable-analytics":false,"shutdown-on-parent-exit":false,"enable-dcgm-profiling":false}
2
+ {"time":"2025-09-11T14:15:51.449765208Z","level":"INFO","msg":"server: will exit if parent process dies","ppid":2338706}
3
+ {"time":"2025-09-11T14:15:51.449708027Z","level":"INFO","msg":"server: accepting connections","addr":{"Name":"/tmp/wandb-2338706-2345351-2853185180/socket","Net":"unix"}}
4
+ {"time":"2025-09-11T14:15:51.617113051Z","level":"INFO","msg":"connection: ManageConnectionData: new connection created","id":"1(@)"}
5
+ {"time":"2025-09-11T14:15:51.621427641Z","level":"INFO","msg":"handleInformInit: received","streamId":"mtp_transformer-mtp.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509111411","id":"1(@)"}
6
+ {"time":"2025-09-11T14:15:51.915371225Z","level":"INFO","msg":"handleInformInit: stream started","streamId":"mtp_transformer-mtp.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509111411","id":"1(@)"}
7
+ {"time":"2025-09-14T18:19:06.56412841Z","level":"INFO","msg":"handleInformFinish: finish message received","streamId":"mtp_transformer-mtp.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509111411","id":"1(@)"}
8
+ {"time":"2025-09-14T18:19:06.565499094Z","level":"INFO","msg":"handleInformFinish: stream closed","streamId":"mtp_transformer-mtp.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509111411","id":"1(@)"}
9
+ {"time":"2025-09-14T18:19:14.63018152Z","level":"INFO","msg":"handleInformTeardown: server teardown initiated","id":"1(@)"}
10
+ {"time":"2025-09-14T18:19:14.630252543Z","level":"INFO","msg":"handleInformTeardown: server shutdown complete","id":"1(@)"}
11
+ {"time":"2025-09-14T18:19:14.630261701Z","level":"INFO","msg":"server is shutting down"}
12
+ {"time":"2025-09-14T18:19:14.630333303Z","level":"INFO","msg":"server: listener closed","addr":{"Name":"/tmp/wandb-2338706-2345351-2853185180/socket","Net":"unix"}}
13
+ {"time":"2025-09-14T18:19:14.630331706Z","level":"INFO","msg":"connection: closing","id":"1(@)"}
14
+ {"time":"2025-09-14T18:19:14.630454034Z","level":"INFO","msg":"connection: closed successfully","id":"1(@)"}
15
+ {"time":"2025-09-14T18:19:14.630460521Z","level":"INFO","msg":"connection: ManageConnectionData: connection closed","id":"1(@)"}
16
+ {"time":"2025-09-14T18:19:14.630484673Z","level":"INFO","msg":"server is closed"}
torchtitan/__pycache__/config_manager.cpython-312.pyc ADDED
Binary file (38.5 kB). View file