zaydzuhri committed
Commit 8cb4047 · verified · 1 parent: 683df89

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes. See the raw diff for the complete changeset.
Files changed (50)
  1. fla/models/lightnet/__pycache__/modeling_lightnet.cpython-312.pyc +0 -0
  2. fla/models/retnet/__pycache__/configuration_retnet.cpython-312.pyc +0 -0
  3. flame/__pycache__/config_manager.cpython-312.pyc +0 -0
  4. flame/__pycache__/train.cpython-312.pyc +0 -0
  5. flame/components/__init__.py +0 -0
  6. flame/components/__pycache__/__init__.cpython-312.pyc +0 -0
  7. flame/components/__pycache__/checkpoint.cpython-312.pyc +0 -0
  8. flame/components/checkpoint.py +59 -0
  9. flame/models/__init__.py +0 -0
  10. flame/models/__pycache__/__init__.cpython-312.pyc +0 -0
  11. flame/models/__pycache__/parallelize_fla.cpython-312.pyc +0 -0
  12. flame/models/__pycache__/pipeline_fla.cpython-312.pyc +0 -0
  13. flame/models/activation_offloading.py +447 -0
  14. flame/models/fla.toml +67 -0
  15. flame/models/parallelize_fla.py +550 -0
  16. flame/models/pipeline_fla.py +162 -0
  17. flame/tools/__init__.py +0 -0
  18. flame/tools/__pycache__/__init__.cpython-312.pyc +0 -0
  19. flame/tools/__pycache__/utils.cpython-312.pyc +0 -0
  20. flame/tools/utils.py +41 -0
  21. flame/utils/__init__.py +0 -0
  22. flame/utils/__pycache__/__init__.cpython-312.pyc +0 -0
  23. flame/utils/__pycache__/checkpoint.cpython-312.pyc +0 -0
  24. flame/utils/__pycache__/convert_dcp_to_hf.cpython-312.pyc +0 -0
  25. flame/utils/__pycache__/convert_hf_to_dcp.cpython-312.pyc +0 -0
  26. flame/utils/__pycache__/hf_utils.cpython-312.pyc +0 -0
  27. flame/utils/checkpoint.py +50 -0
  28. flame/utils/convert_dcp_to_hf.py +66 -0
  29. flame/utils/convert_hf_to_dcp.py +34 -0
  30. flame/utils/hf_utils.py +77 -0
  31. logs/none_ewbp5xc1/attempt_0/1/stderr.log +0 -0
  32. profile_trace/iteration_1024/rank0_trace.json +0 -0
  33. profile_trace/iteration_1024/rank1_trace.json +0 -0
  34. profile_trace/iteration_1024/rank5_trace.json +0 -0
  35. profile_trace/iteration_1024/rank6_trace.json +0 -0
  36. profile_trace/iteration_1024/rank7_trace.json +0 -0
  37. profile_trace/iteration_1536/rank4_trace.json +0 -0
  38. profile_trace/iteration_20992/rank5_trace.json +0 -0
  39. profile_trace/iteration_23552/rank6_trace.json +0 -0
  40. profile_trace/iteration_2560/rank5_trace.json +0 -0
  41. profile_trace/iteration_2560/rank7_trace.json +0 -0
  42. profile_trace/iteration_29696/rank2_trace.json +0 -0
  43. profile_trace/iteration_29696/rank6_trace.json +0 -0
  44. profile_trace/iteration_30720/rank6_trace.json +0 -0
  45. profile_trace/iteration_3584/rank0_trace.json +0 -0
  46. profile_trace/iteration_3584/rank4_trace.json +0 -0
  47. profile_trace/iteration_3584/rank5_trace.json +0 -0
  48. profile_trace/iteration_3584/rank7_trace.json +0 -0
  49. tb/20250901-0749/wandb/run-20250901_074914-top_transformer-top.code.1B.batch16.seqlen4096.context4096.warmup400.update1.steps40000.lr5e-5.cosine-202509010747/files/wandb-metadata.json +146 -0
  50. tb/20250901-0749/wandb/run-20250901_074914-top_transformer-top.code.1B.batch16.seqlen4096.context4096.warmup400.update1.steps40000.lr5e-5.cosine-202509010747/logs/debug-internal.log +17 -0
fla/models/lightnet/__pycache__/modeling_lightnet.cpython-312.pyc ADDED
Binary file (18.4 kB).
 
fla/models/retnet/__pycache__/configuration_retnet.cpython-312.pyc ADDED
Binary file (3.76 kB).
 
flame/__pycache__/config_manager.cpython-312.pyc ADDED
Binary file (36.9 kB).
 
flame/__pycache__/train.cpython-312.pyc ADDED
Binary file (38.1 kB).
 
flame/components/__init__.py ADDED
File without changes
flame/components/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (141 Bytes).
 
flame/components/__pycache__/checkpoint.cpython-312.pyc ADDED
Binary file (3.21 kB).
 
flame/components/checkpoint.py ADDED
@@ -0,0 +1,59 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field
from datetime import timedelta
from io import BytesIO
from typing import Any, Dict, List

import torch
from torch.distributed.checkpoint.stateful import Stateful


@dataclass
class TrainState(Stateful):
    step: int = 0
    skipped_step: int = 0
    token: int = 0
    elapsed: timedelta = timedelta(0)
    global_avg_losses: List[float] = field(default_factory=list)
    global_max_losses: List[float] = field(default_factory=list)
    log_steps: List[int] = field(default_factory=list)

    def state_dict(self) -> Dict[str, Any]:
        # Only checkpoint global_avg_losses and global_max_losses per log frequency
        # to avoid sync overhead in every iteration.
        global_avg_losses_bytes = BytesIO()
        torch.save(self.global_avg_losses, global_avg_losses_bytes)
        global_max_losses_bytes = BytesIO()
        torch.save(self.global_max_losses, global_max_losses_bytes)
        log_steps_bytes = BytesIO()
        torch.save(self.log_steps, log_steps_bytes)
        return {
            "step": torch.tensor(self.step, dtype=torch.int32),
            "skipped_step": torch.tensor(self.skipped_step, dtype=torch.int32),
            "token": torch.tensor(self.token, dtype=torch.int64),
            "elapsed": self.elapsed,
            "global_avg_losses": global_avg_losses_bytes,
            "global_max_losses": global_max_losses_bytes,
            "log_steps": log_steps_bytes,
        }

    def load_state_dict(self, state_dict) -> None:
        self.step = state_dict["step"].item()
        # default to a zero tensor so checkpoints saved before "skipped_step" existed
        # still load (a plain int default would have no .item())
        self.skipped_step = state_dict.get("skipped_step", torch.tensor(0)).item()
        self.token = state_dict["token"].item()
        self.elapsed = state_dict["elapsed"]
        state_dict["global_avg_losses"].seek(0)
        self.global_avg_losses = torch.load(
            state_dict["global_avg_losses"], weights_only=False
        )
        state_dict["global_max_losses"].seek(0)
        self.global_max_losses = torch.load(
            state_dict["global_max_losses"], weights_only=False
        )
        state_dict["log_steps"].seek(0)
        self.log_steps = torch.load(state_dict["log_steps"], weights_only=False)
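
For readers unfamiliar with torch.distributed.checkpoint's Stateful protocol, here is a minimal single-process sketch (not part of this commit) of how the TrainState above round-trips through its own state_dict/load_state_dict; the values are invented and no distributed checkpointer is involved:

# illustrative round-trip only; real runs hand this object to the checkpointer
state = TrainState(step=100, token=1_000_000, log_steps=[32, 64])
saved = state.state_dict()       # scalars as tensors, lists as BytesIO buffers
restored = TrainState()
restored.load_state_dict(saved)  # buffers are rewound and torch.load-ed back
assert restored.step == 100 and restored.log_steps == [32, 64]
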
flame/models/__init__.py ADDED
File without changes
flame/models/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (137 Bytes).
 
flame/models/__pycache__/parallelize_fla.cpython-312.pyc ADDED
Binary file (22.1 kB).
 
flame/models/__pycache__/pipeline_fla.cpython-312.pyc ADDED
Binary file (5.75 kB).
 
flame/models/activation_offloading.py ADDED
@@ -0,0 +1,447 @@
# Adapted from https://github.com/pytorch/torchtune/blob/main/torchtune/training/_activation_offloading.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
from typing import Union
from warnings import warn

import psutil
import torch
from torch import nn
from torch.autograd.graph import saved_tensors_hooks

from torchtitan.tools.logging import logger

try:
    import torchao
    from torchao.dtypes.nf4tensor import NF4Tensor
except ImportError:
    torchao = None
    NF4Tensor = None
    logger.warning("torchao not found. ")

# from torchtune.modules import TiedLinear


class OffloadActivations(saved_tensors_hooks):
    """Context manager under which activation tensors created in the forward pass will be offloaded.

    Enable the memory efficiency technique of activation offloading, where activations bigger than
    min_offload_size bytes will be offloaded to CPU in the forward and brought back in the backward.
    This is in contrast to maintaining the activation on GPU VRAM throughout the program.

    This manager contains the option of using one additional CUDA stream to handle the communication
    between CUDA and CPU, which is intended to overlap with the default computation stream to improve
    runtime. We designed synchronization with a few heuristics for optimizing the tradeoff between
    runtime vs memory usage.

    Args:
        use_pin_memory (bool): Whether or not the offloaded Tensor will be placed in pinned
            memory on the CPU. Pinned memory allows the Tensor to be moved back onto GPU more quickly
            but is a limited resource. Default: True.

        use_streams (bool): Whether or not to use streams for performance optimization where
            the communications get overlapped with the computation. Requires a torch build
            after torch-2.5.0. Default: True.

        max_fwd_stash_size (int): The maximum size of the forward stash, or the maximum number of
            consecutive activations to keep alive during the forward pass. This number must be at
            least 1. Keeping alive more activations will potentially allow more overlap between the
            communication and compute streams at the cost of increasing memory usage. Keeping alive
            fewer activations will conserve memory, but may cause poor overlap between the streams,
            increasing runtime. Default: 5.

        min_offload_size (int): The minimum number of bytes a Tensor must be in order to qualify
            for offloading. If the tensor is too small, we do not want to waste bandwidth and resources
            moving it to CPU and back. Default: 1024 bytes.

    Raises:
        ValueError: if max_fwd_stash_size is not at least 1.

    Example:
        >>> with OffloadActivations():
        >>>     logits = model(inputs)
        >>>     loss = ...
        >>>     loss.backward()
    """

    def __init__(
        self,
        use_pin_memory: bool = True,
        use_streams: bool = True,
        max_fwd_stash_size: int = 5,
        min_offload_size: int = 1024,
    ) -> None:

        self.use_streams: bool = use_streams

        self.min_tensor_size_bytes = (
            min_offload_size  # we don't want to bother with small tensors
        )
        self.tracker = (
            {}
        )  # tensor_id => (new_tensor, if_modified) ---> track what saved/offloaded tensors are where
        self.tensor_id: int = 0
        self.is_first_forward_call = True
        self.is_first_backward_call = True
        self.is_first_forward_pass = True

        # managing cpu memory
        self.use_pin_memory: bool = use_pin_memory
        self.virtual_memory_safe_pct = (
            60  # we should not exceed this percentage of memory
        )

        self.s0 = torch.cuda.default_stream()  # comp stream

        # for streaming
        if self.use_streams:
            self.s1 = torch.cuda.Stream()  # comms stream
            self.fwd_stash = {}  # tensor_id => (activation, ev1)
            if max_fwd_stash_size < 1:
                raise ValueError(
                    f"max_fwd_stash_size should be at least 1 but is {max_fwd_stash_size}"
                )
            self.max_fwd_stash_size = max_fwd_stash_size
            self.bwd_tensor_stash = {}  # tensor_id => activation
            self.bwd_ev_stash = {}  # tensor_id => ev0
            self.curr_graph_id = None
            self.curr_autograd_node = None

        # -------- platform util functions -------- #
        def verify_sufficient_virtual_memory():
            curr_pct = get_cpu_ram_pct()
            if curr_pct > self.virtual_memory_safe_pct:
                warn(
                    f"***** WARNING: {curr_pct=}% > {self.virtual_memory_safe_pct=}% of virtual memory used"
                )

        def get_cpu_ram_pct() -> float:
            # get the percentage of memory used by the system
            return psutil.virtual_memory().percent

        def get_tensor_id() -> int:
            # create a unique id for each tensor we are managing
            self.tensor_id += 1
            return self.tensor_id

        def get_num_bytes_tensor(x: torch.Tensor) -> int:
            # get the number of bytes in a tensor, for memory management purposes
            return (
                x.element_size() * x.nelement()
            )  # x.element_size() * x._base_storage().nbytes()

        # -------- core pack / unpack work -------- #
        def pack_tensor(activation: torch.Tensor) -> int:
            # activations are passed in during forward pass - from here we take over and return a unique id
            if self.is_first_forward_call:
                assert (
                    len(self.tracker) == 0
                ), "backward pass should have cleared tracker of all tensors"

                # set training phase trackers
                self.is_first_forward_call = False
                self.is_first_backward_call = True

            # query for basic tensor info
            num_bytes = get_num_bytes_tensor(activation)
            tensor_id = get_tensor_id()

            # only offload hefty bois if they're activations on CUDA (our heuristic
            # for that is to check if they're not params or buffers)!
            if (
                activation.is_cuda
                and num_bytes >= self.min_tensor_size_bytes
                and (
                    not isinstance(activation, torch.nn.Parameter)
                    and not isinstance(activation, torch.nn.Buffer)
                )
            ):
                if self.use_streams:
                    # First, sync back and dereference previously offloaded tensors
                    # as the offloading should be done sufficiently long ago.
                    for id in [k for k in self.fwd_stash.keys()]:
                        if id <= tensor_id - self.max_fwd_stash_size:
                            _, ev = self.fwd_stash[id]
                            self.s0.wait_event(ev)
                            del self.fwd_stash[id]
                        else:
                            break

                    # Sync in, offload, and add an event to sync back later
                    self.s1.wait_stream(self.s0)

                stream = self.s1 if self.use_streams else self.s0
                with torch.cuda.stream(stream):
                    try:
                        cpu_tensor = torch.empty_like(
                            activation, pin_memory=self.use_pin_memory, device="cpu"
                        )
                    except NotImplementedError as e:
                        if (
                            isinstance(activation, NF4Tensor)
                            and torchao.__version__ < "0.6.0.dev20240917"
                        ):
                            raise RuntimeError(
                                "Offloading NF4Tensors requires torchao-0.6.0.dev20240917 or later"
                            ) from e
                        raise e
                    cpu_tensor.copy_(activation, non_blocking=True)
                    self.tracker[tensor_id] = (
                        cpu_tensor,
                        True,
                    )  # True = (in future) modified

                if self.use_streams:
                    event = self.s1.record_event()

                    # Stash to keep activation alive til s1 is done
                    self.fwd_stash[tensor_id] = (activation, event)
            else:
                self.tracker[tensor_id] = (
                    activation,
                    False,
                )  # False = not modified, tensor is as is

            return tensor_id

        def unpack_tensor_single_stream(unpack_tensor_id: int) -> torch.Tensor:
            # backward pass - we are called with the tensor_id, which
            # we will use to retrieve the saved/offloaded tensor
            if self.is_first_backward_call:
                if self.is_first_forward_pass:
                    self.is_first_forward_pass = False
                    if self.use_pin_memory:
                        verify_sufficient_virtual_memory()

                self.is_first_backward_call = False
                self.is_first_forward_call = True

            assert (
                unpack_tensor_id in self.tracker
            ), f"untracked tensor with id {unpack_tensor_id}"

            maybe_gpu_tensor, modified = self.tracker[unpack_tensor_id]
            if modified:
                gpu_tensor = maybe_gpu_tensor.to("cuda", non_blocking=True)
                maybe_gpu_tensor = gpu_tensor

            # clear tensor from tracking
            del self.tracker[unpack_tensor_id]
            return maybe_gpu_tensor

        def unpack_tensor_with_streams(unpack_tensor_id: int) -> torch.Tensor:
            # backward pass - we are called with the tensor_id, which
            # we will use to retrieve the saved/offloaded tensor
            if self.is_first_backward_call:
                self.curr_graph_id = torch._C._current_graph_task_id()

                def wait_and_del_remaining_references() -> None:
                    for id in [k for k in self.bwd_tensor_stash.keys()]:
                        event = self.bwd_ev_stash[id]
                        self.s1.wait_event(event)
                        del self.bwd_tensor_stash[id]

                # Register a callback to the end of autograd to clean everything up
                torch.autograd.variable.Variable._execution_engine.queue_callback(
                    wait_and_del_remaining_references
                )

                if self.is_first_forward_pass:
                    self.is_first_forward_pass = False
                    if self.use_pin_memory:
                        verify_sufficient_virtual_memory()

                self.is_first_backward_call = False
                self.is_first_forward_call = True

            assert (
                unpack_tensor_id in self.tracker
            ), f"untracked tensor with id {unpack_tensor_id}"

            maybe_gpu_tensor, modified = self.tracker[unpack_tensor_id]
            if modified:
                # Get data on the current autograd node
                graph_id = torch._C._current_graph_task_id()
                node = torch._C._current_autograd_node()
                prev_node_ids = []

                # If we're on a new node, mark prev node's tensors to be freed later
                if graph_id == self.curr_graph_id and self.curr_autograd_node != node:
                    self.curr_autograd_node = node
                    prev_node_ids = [id for id in self.bwd_tensor_stash.keys()]

                brought_back_from_cpu = True
                if unpack_tensor_id in self.fwd_stash:
                    maybe_gpu_tensor = self.fwd_stash[unpack_tensor_id][0]
                    brought_back_from_cpu = False
                else:
                    # Kick off the process to bring tensors back
                    with torch.cuda.stream(self.s1):
                        gpu_tensor = maybe_gpu_tensor.to("cuda", non_blocking=True)
                        maybe_gpu_tensor = gpu_tensor

                    # Tell comp stream to wait for the info to be loaded before executing
                    self.s0.wait_stream(self.s1)

                    # Stash the tensor to keep memory alive until compute stream is complete
                    self.bwd_tensor_stash[unpack_tensor_id] = maybe_gpu_tensor

                # Note: [Track views of the unpacked]
                # Why do we get the use count of the unpacked tensor here? We want an
                # initial count to compare to later, during the post-hook of the
                # backward node, when we need to decide whether we're allowed to free
                # the tensor yet. In what obscure cases must we delay freeing the
                # tensor (and thus call record_stream)?
                # 1. Any of the outputs of the backward node is a view of the unpacked
                #    tensor.
                # 2. In the case that this unpacked tensor will be used in a
                #    checkpointed region, if one of the recomputed saved tensors ends
                #    up as a view of the unpacked tensor.
                # 3. The user abuses the system somehow and manually relies on the
                #    unpacked tensor to exist after the backward node has executed.
                storage_refcount = torch._C._storage_Use_Count(
                    maybe_gpu_tensor.untyped_storage()._cdata
                )

                def hook(outputs, inputs):
                    # create events for the current node inputs/outputs if they were streamed in
                    if brought_back_from_cpu:
                        # See Note: [Track views of the unpacked]
                        # IF any of the outputs is a view of the tensor, OR if a view of
                        # the tensor has been saved as a part of checkpoint's recompute
                        # process, OR the user has abusedly incurred a reference on the
                        # unpacked tensor, THEN the tensor might be used later and we
                        # cannot presume to delete it after only the current node is
                        # done! So we use our frenemy, record_stream, to ensure the
                        # Tensor stays unmessed with until it's done getting used in the
                        # compute stream (s0 here). Note that the con here is we introduce
                        # non-deterministic (thus higher) memory usage, but this case
                        # should not happen often.
                        unpacked_tensor = self.bwd_tensor_stash[unpack_tensor_id]
                        if (
                            torch._C._storage_Use_Count(
                                unpacked_tensor.untyped_storage()._cdata
                            )
                            > storage_refcount
                        ):
                            unpacked_tensor.record_stream(self.s0)
                            del self.bwd_tensor_stash[unpack_tensor_id]
                        else:
                            event = self.s0.record_event()
                            self.bwd_ev_stash[unpack_tensor_id] = event

                    # if there are still things in the fwd_stash, get rid of them as we're in bwd now
                    for id in [k for k in self.fwd_stash.keys()]:
                        _, ev = self.fwd_stash[id]
                        self.s0.wait_event(ev)
                        del self.fwd_stash[id]

                    # wait on prev node's events and del those
                    for id in prev_node_ids:
                        event = self.bwd_ev_stash[id]
                        self.s1.wait_event(event)
                        del self.bwd_tensor_stash[id]

                    return outputs

                node.register_hook(hook)

            # clear tensor from tracking
            del self.tracker[unpack_tensor_id]
            return maybe_gpu_tensor

        unpack_tensor = (
            unpack_tensor_with_streams
            if self.use_streams
            else unpack_tensor_single_stream
        )
        super().__init__(pack_tensor, unpack_tensor)


class NoOpManager(saved_tensors_hooks):
    """
    A saved_tensors_hook manager used to disable any other saved_tensors_hook manager
    applied before. This relies on the behavior that only the most recently registered
    saved_tensors_hook will run.

    One example usage is to opt a local region of code out of activations offloading,
    which is usually applied globally to best track state.
    """

    def __init__(self) -> None:
        def noop(tensor):
            return tensor

        super().__init__(noop, noop)


def get_act_offloading_ctx_manager(
    model: nn.Module, enable_activation_offloading: bool
) -> Union[OffloadActivations, contextlib.nullcontext]:
    """Returns the activation offloading context manager for the model, which will be
    a null context if enable_activation_offloading is False.

    If activation offloading is enabled, we return the OffloadActivations context manager.
    If activation offloading is disabled, we return a NoOpManager context manager.

    Args:
        model (nn.Module): the model to wrap with the activation offloading context manager.
        enable_activation_offloading (bool): whether or not to enable activation offloading
            for the model.

    Returns:
        contextlib.ContextDecorator: the activation offloading context manager for the model.

    Raises:
        NotImplementedError: If the model is a multimodal model and activation offloading is enabled.
    """
    if enable_activation_offloading:
        activations_handling_ctx = OffloadActivations()

        # Below is our hack to disable offloading the last output Linear in every
        # step, as the cost for offloading the activation and then soon after bringing
        # it back is expensive. Moreover, due to heuristics in our streaming API,
        # we actually use more memory if we offload it as it interferes with chunkedCE.
        output_head_detected = False
        noop_ctx = NoOpManager()

        if hasattr(model, "output"):
            if isinstance(model.output, nn.Module):
                model.output.register_forward_pre_hook(
                    lambda *args: noop_ctx.__enter__()
                )
                model.output.register_forward_hook(
                    lambda *args: noop_ctx.__exit__(), always_call=True
                )
                print("registering hooks for model.output ============ ")
                output_head_detected = True
            # ================================
            # ! TODO[flame] check if we need to deal with TiedLinear
            # The following code appears in `torchtune`
            # elif isinstance(model.output, TiedLinear):
            #     model.output.linear.register_forward_pre_hook(
            #         lambda *args: noop_ctx.__enter__()
            #     )
            #     model.output.linear.register_forward_hook(
            #         lambda *args: noop_ctx.__exit__(), always_call=True
            #     )
            #     output_head_detected = True

        if not output_head_detected:
            logger.warning(
                "During activation offloading, no output head was detected. "
                "If your model has an output head, it will be offloaded. "
                "This usually greatly slows training, given the large vocabulary size. "
                "To change this behavior, set your output head as model.output and make it "
                "an nn.Module."
            )

    else:
        activations_handling_ctx = contextlib.nullcontext()

    return activations_handling_ctx
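
As a usage illustration only (not from this commit), a small sketch of wrapping a toy forward/backward in the context manager above; the module and sizes are invented and a CUDA device is required:

import torch
import torch.nn as nn

# toy CUDA model; activations larger than min_offload_size are parked on CPU
model = nn.Sequential(nn.Linear(4096, 4096), nn.GELU(), nn.Linear(4096, 8)).cuda()
x = torch.randn(64, 4096, device="cuda")
with OffloadActivations(use_streams=True, min_offload_size=1024):
    loss = model(x).sum()
    loss.backward()  # offloaded activations are streamed back during backward
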
flame/models/fla.toml ADDED
@@ -0,0 +1,67 @@
[model]
config = "fla-hub/transformer-1.3B-100B"
tokenizer_path = "fla-hub/transformer-1.3B-100B"

[job]
dump_folder = "exp"
print_args = true

[training]
batch_size = 32
seq_len = 2048
context_len = 2048
gradient_accumulation_steps = 1
steps = 20480
max_norm = 1.0
skip_nan_inf = true
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
tensor_parallel_degree = 1
compile = false
dataset = "HuggingFaceFW/fineweb-edu"
dataset_name = "default"
num_workers = 32
pin_memory = false
persistent_workers = false
prefetch_factor = 2
seed = 42
varlen = false

[optimizer]
name = "AdamW"
eps = 1e-15
lr = 3e-4

[lr_scheduler]
warmup_steps = 1024
decay_type = "cosine"
lr_min = 0.1

[checkpoint]
enable_checkpoint = true
folder = "checkpoint"
interval_type = "steps"
interval = 2048
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled"  # ["disabled", "async", "async_with_pinned_mem"]

[profiling]
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 512

[metrics]
log_freq = 32
enable_wandb = true

[experimental]
context_parallel_degree = 1
pipeline_parallel_degree = 1

[float8]
enable_fsdp_float8_all_gather = false
precompute_float8_dynamic_scale_for_fsdp = false

[activation_checkpoint]
mode = "none"
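
A small sketch (not part of the commit) of reading this config with the standard-library tomllib, assuming the file path above; handy for double-checking derived quantities such as tokens per optimizer step:

import tomllib  # Python 3.11+

with open("flame/models/fla.toml", "rb") as f:
    cfg = tomllib.load(f)
tokens_per_step = cfg["training"]["batch_size"] * cfg["training"]["seq_len"]
print(cfg["model"]["config"], tokens_per_step)  # fla-hub/transformer-1.3B-100B 65536
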
flame/models/parallelize_fla.py ADDED
@@ -0,0 +1,550 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# This file applies the PT-D parallelisms (except pipeline parallelism) and various
# training techniques (e.g. activation checkpointing and compile) to the Llama model.

from collections import defaultdict

import torch
import torch.nn as nn
from torch.distributed import DeviceMesh
from torch.distributed._composable.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard
from torch.distributed._composable.replicate import replicate
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper as ptd_checkpoint_wrapper
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    PrepareModuleOutput,
    RowwiseParallel,
    SequenceParallel,
    parallelize_module
)

from fla.modules.fused_linear_cross_entropy import LinearLossParallel
from fla.modules.mlp import SwiGLULinearParallel
from fla.modules.parallel import PrepareModuleWeight
from torchtitan.config_manager import TORCH_DTYPE_MAP, JobConfig
from torchtitan.distributed.parallel_dims import ParallelDims
from torchtitan.tools.logging import logger


def parallelize_fla(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
    """
    Apply tensor parallelism, activation checkpointing, torch.compile, and data
    parallelism to the model.

    NOTE: The passed-in model preferably should be on meta device. Otherwise,
    the model must fit on GPU or CPU memory.
    """

    if parallel_dims.tp_enabled:
        if (
            job_config.experimental.enable_async_tensor_parallel
            and not job_config.training.compile
        ):
            raise RuntimeError("Async TP requires --training.compile")
        enable_float8_linear = "float8" in job_config.model.converters
        apply_tp(
            model,
            world_mesh["tp"],
            loss_parallel=parallel_dims.loss_parallel_enabled,
            enable_float8=enable_float8_linear,
            enable_async_tp=job_config.experimental.enable_async_tensor_parallel,
        )

    if job_config.activation_checkpoint.mode != "none":
        apply_ac(model, job_config.activation_checkpoint)

    # turn on per-block compile after AC wrapping and before FSDP
    if job_config.training.compile:
        apply_compile(model)

    if (
        parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled
    ):  # apply FSDP or HSDP, potentially with Context Parallel
        if parallel_dims.dp_replicate_enabled:
            dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
        else:
            dp_mesh_dim_names = ("dp_shard_cp",)

        apply_fsdp(
            model,
            world_mesh[tuple(dp_mesh_dim_names)],
            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
            pp_enabled=parallel_dims.pp_enabled,
            cpu_offload=job_config.training.enable_cpu_offload,
            reshard_after_forward_policy=job_config.training.fsdp_reshard_after_forward,
        )

        if parallel_dims.dp_replicate_enabled:
            logger.info("Applied HSDP to the model")
        else:
            logger.info("Applied FSDP to the model")

        if parallel_dims.cp_enabled:
            logger.info("Applied Context Parallel to the model")

        if job_config.training.enable_cpu_offload:
            logger.info("Applied CPU Offloading to the model")
    elif parallel_dims.dp_replicate_enabled:
        if world_mesh.ndim > 1:
            raise RuntimeError("DDP has not supported > 1D parallelism")
        apply_ddp(
            model,
            world_mesh,
            enable_compile=job_config.training.compile,
            enable_compiled_autograd=job_config.experimental.enable_compiled_autograd,
        )


class TPPlan:
    def __init__(
        self,
        model=None,
        loss_parallel=False,
        enable_float8=False,
    ):
        self.model = model
        self.loss_parallel = loss_parallel
        self.enable_float8 = enable_float8
        self.base_model_prefix = getattr(model, "base_model_prefix", "model")

        # TODO(vkuzo): once float8 configuration supports delayed scaling,
        # add a check here to enforce supported float8 all-gather configurations
        # TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there
        try:
            from torchao.float8.float8_tensor_parallel import (
                Float8ColwiseParallel,
                Float8RowwiseParallel,
                PrepareFloat8ModuleInput
            )
        except ImportError:
            Float8ColwiseParallel = None
            Float8RowwiseParallel = None
            PrepareFloat8ModuleInput = None
        if self.enable_float8 and Float8ColwiseParallel is not None:
            self.rowwise_parallel = Float8RowwiseParallel
            self.colwise_parallel = Float8ColwiseParallel
            self.prepare_module_input = PrepareFloat8ModuleInput
            self.prepare_module_output = PrepareModuleOutput
        else:
            self.rowwise_parallel = RowwiseParallel
            self.colwise_parallel = ColwiseParallel
            self.prepare_module_input = PrepareModuleInput
            self.prepare_module_output = PrepareModuleOutput

    @property
    def model_plan(self):
        plans = {
            f"{self.base_model_prefix}.embeddings": RowwiseParallel(
                input_layouts=Replicate(),
                output_layouts=Shard(1),
            ),
            f"{self.base_model_prefix}.norm": SequenceParallel(),
        }
        if self.loss_parallel:
            plans.update(
                {
                    "lm_head": ColwiseParallel(
                        input_layouts=Shard(1),
                        output_layouts=Shard(-1) if self.loss_parallel else Replicate(),
                        use_local_output=not self.loss_parallel,
                    ),
                }
            )
        else:
            plans.update(
                {
                    "lm_head": PrepareModuleWeight(layouts=Replicate()),
                    "criterion": LinearLossParallel(),
                }
            )
        return plans

    @property
    def layer_plan(self):
        return {
            "attn_norm": SequenceParallel(),
            **self.attn_plan,
            "mlp_norm": SequenceParallel(),
            **self.mlp_plan,
        }

    @property
    def attn_plan(self):
        raise NotImplementedError(
            f"TP plans for token mixing layers of {self.model.config.model_type} not implemented"
        )

    @property
    def mlp_plan(self):
        return {
            "mlp": self.prepare_module_input(
                input_layouts=(Shard(1),),
                desired_input_layouts=(Replicate(),),
            ),
            "mlp.gate_proj": self.colwise_parallel(),
            "mlp.up_proj": self.colwise_parallel(),
            "mlp.down_proj": self.rowwise_parallel(output_layouts=Shard(1)),
            "mlp.swiglu_linear": SwiGLULinearParallel(output_layouts=Shard(1)),
        }


class TransformerTPPlan(TPPlan):

    @property
    def attn_plan(self):
        return {
            "attn": self.prepare_module_input(
                input_kwarg_layouts={"hidden_states": Shard(1)},
                desired_input_kwarg_layouts={"hidden_states": Replicate()},
            ),
            "attn.q_proj": self.colwise_parallel(),
            "attn.k_proj": self.colwise_parallel(),
            "attn.v_proj": self.colwise_parallel(),
            "attn.o_proj": self.rowwise_parallel(output_layouts=Shard(1)),
        }


class GLATPPlan(TPPlan):

    @property
    def attn_plan(self):
        return {
            "attn": self.prepare_module_input(
                input_kwarg_layouts={"hidden_states": Shard(1)},
                desired_input_kwarg_layouts={"hidden_states": Replicate()},
            ),
            "attn.q_proj": self.colwise_parallel(),
            "attn.k_proj": self.colwise_parallel(),
            "attn.v_proj": self.colwise_parallel(),
            "attn.g_proj": self.colwise_parallel(),
            "attn.gk_proj.0": PrepareModuleWeight(layouts=Replicate()),
            "attn.gk_proj.1": self.colwise_parallel(),
            "attn.g_norm": SequenceParallel(sequence_dim=-1),
            "attn.o_proj": self.rowwise_parallel(output_layouts=Shard(1)),
        }


TP_PLAN_MAP = {"transformer": TransformerTPPlan, "gla": GLATPPlan}


def apply_tp(
    model: nn.Module,
    tp_mesh: DeviceMesh,
    loss_parallel: bool,
    enable_float8: bool,
    enable_async_tp: bool,
):
    """Apply tensor parallelism."""
    # 1. Parallelize the embedding and shard its outputs (which are the first
    # transformer block's inputs)
    # 2. Parallelize the root norm layer over the sequence dim
    # 3. Parallelize the final linear output layer
    tp_plan = TP_PLAN_MAP[model.config.model_type](
        model, loss_parallel=loss_parallel, enable_float8=enable_float8
    )
    parallelize_module(model, tp_mesh, tp_plan.model_plan)

    blocks = get_blocks(model)
    if blocks is None:
        logger.warning("No block found for tensor parallelism")
    else:
        for _, block in enumerate(blocks):
            parallelize_module(
                module=block,
                device_mesh=tp_mesh,
                parallelize_plan=tp_plan.layer_plan,
            )

    if enable_async_tp:
        from torch.distributed._symmetric_memory import enable_symm_mem_for_group

        torch._inductor.config._micro_pipeline_tp = True
        enable_symm_mem_for_group(tp_mesh.get_group().group_name)

    logger.info(
        f"Applied {'Float8 ' if enable_float8 else ''}{'Async ' if enable_async_tp else ''}"
        "Tensor Parallelism to the model"
    )


# for selective op activation checkpointing
_save_list = {
    torch.ops.aten.mm.default,
    torch.ops.aten._scaled_dot_product_efficient_attention.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
    torch.ops._c10d_functional.reduce_scatter_tensor.default,
    # for low precision training, it's useful to always save
    # the result of max, since the absolute maximum is
    # used to compute the scaling factor for quantization.
    torch.ops.aten.max.default,
}


def _apply_ac_to_block(module: nn.Module, ac_config):
    valid_ac_modes = ("full", "selective")
    if ac_config.mode not in valid_ac_modes:
        raise ValueError(
            f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}"
        )

    if ac_config.mode == "full":
        return ptd_checkpoint_wrapper(module, preserve_rng_state=False)

    assert ac_config.mode == "selective", f"{ac_config.mode}"
    use_op_sac = ac_config.selective_ac_option == "op"
    use_layer_sac = ac_config.selective_ac_option.isdigit()
    if not use_op_sac and not use_layer_sac:
        raise ValueError(
            f"Invalid selective AC option: {ac_config.selective_ac_option}. "
            f"Valid options: 'op' or a positive int representing layer frequency"
        )
    if use_op_sac:
        from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts

        def _get_custom_policy(meta):
            def _custom_policy(ctx, func, *args, **kwargs):
                mode = "recompute" if ctx.is_recompute else "forward"
                mm_count_key = f"{mode}_mm_count"
                if func == torch.ops.aten.mm.default:
                    meta[mm_count_key] += 1
                # Saves output of all compute ops, except every second mm
                to_save = func in _save_list and not (
                    func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
                )
                return (
                    CheckpointPolicy.MUST_SAVE
                    if to_save
                    else CheckpointPolicy.PREFER_RECOMPUTE
                )

            return _custom_policy

        def selective_checkpointing_context_fn():
            meta = defaultdict(int)
            return create_selective_checkpoint_contexts(_get_custom_policy(meta))

        return ptd_checkpoint_wrapper(
            module,
            context_fn=selective_checkpointing_context_fn,
            preserve_rng_state=False,
        )
    elif use_layer_sac:
        # Checkpoint every `ac_freq` of the modules passed to this function
        ac_freq = int(ac_config.selective_ac_option)
        ptd_checkpoint_wrapper.__dict__.setdefault("_count", 0)
        ptd_checkpoint_wrapper._count += 1
        if not ac_freq or ptd_checkpoint_wrapper._count % ac_freq == 0:
            return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
        else:
            return module


def apply_ac(model: nn.Module, ac_config):
    """Apply activation checkpointing to the model."""
    blocks = get_blocks(model)
    if blocks is None:
        logger.warning("No block found for activation checkpointing")
        return

    for layer_id, block in blocks.named_children():
        block = _apply_ac_to_block(block, ac_config)
        blocks.register_module(layer_id, block)

    logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")


def apply_compile(model: nn.Module):
    """
    Apply torch.compile to each block, which makes compilation efficient due to
    repeated structure. Alternatively one can compile the whole model (after applying DP).
    """

    blocks = get_blocks(model)
    if blocks is None:
        logger.warning("No block found for torch.compile")
    else:
        for layer_id, block in blocks.named_children():
            block = torch.compile(block)
            blocks.register_module(layer_id, block)
        logger.info("Compiling each block with torch.compile")

    real_model = get_model(model)

    logger.info("Compiling the embedding, norm, and lm_head layers with torch.compile")
    embeddings_key = get_components_name(real_model, "tok_embeddings")
    if embeddings_key is not None:
        embeddings = torch.compile(getattr(real_model, embeddings_key), fullgraph=True)
        real_model.register_module(embeddings_key, embeddings)

    norm_key = get_components_name(real_model, "norm")
    if norm_key is not None:
        norm = torch.compile(getattr(real_model, norm_key), fullgraph=True)
        real_model.register_module(norm_key, norm)

    lm_head_key = get_components_name(model, "lm_head")
    if lm_head_key is not None:
        lm_head = torch.compile(getattr(model, lm_head_key), fullgraph=True)
        model.register_module(lm_head_key, lm_head)

    logger.info("Compiling the entire model with torch.compile")
    model = torch.compile(model)


def apply_fsdp(
    model: nn.Module,
    dp_mesh: DeviceMesh,
    param_dtype: torch.dtype,
    reduce_dtype: torch.dtype,
    pp_enabled: bool,
    cpu_offload: bool = False,
    reshard_after_forward_policy: str = "default",
):
    """
    Apply data parallelism (via FSDP2) to the model.

    Args:
        model (nn.Module): The model to apply data parallelism to.
        dp_mesh (DeviceMesh): The device mesh to use for data parallelism.
        param_dtype (torch.dtype): The data type to use for model parameters.
        reduce_dtype (torch.dtype): The data type to use for reduction operations.
        pp_enabled (bool): Whether pipeline parallelism is enabled.
        cpu_offload (bool, optional): Whether to offload model parameters to CPU. Defaults to False.
        reshard_after_forward_policy (str, optional):
            The policy to use for resharding after forward pass. Defaults to "default".
            Other options: "never", "always".
            - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios.
            - "always" will enable `reshard_after_forward` for all forward passes.
            - "never" will disable `reshard_after_forward` for all forward passes.

    """
    mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
    fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
    if cpu_offload:
        fsdp_config["offload_policy"] = CPUOffloadPolicy()

    blocks = get_blocks(model)
    if blocks is None:
        logger.warning("No block found for FSDP")
    else:
        total_blocks = len(blocks)
        for layer_id, block in enumerate(blocks):
            if reshard_after_forward_policy == "always":
                reshard_after_forward = True
            elif reshard_after_forward_policy == "never":
                reshard_after_forward = False
            elif reshard_after_forward_policy == "default":
                if pp_enabled:
                    # For PP, do not reshard after forward to avoid per-microbatch
                    # all-gathers, which can be expensive and non-overlapped
                    reshard_after_forward = False
                else:
                    # As an optimization, do not reshard after forward for the last
                    # transformer block since FSDP would prefetch it immediately
                    reshard_after_forward = int(layer_id) < total_blocks - 1
            else:
                raise ValueError(
                    f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}."
                )
            fully_shard(
                block,
                **fsdp_config,
                reshard_after_forward=reshard_after_forward,
            )

    fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)


def apply_ddp(
    model: nn.Module,
    dp_mesh: DeviceMesh,
    enable_compile: bool,
    enable_compiled_autograd: bool,
):
    if enable_compile:
        if enable_compiled_autograd:
            torch._dynamo.config.optimize_ddp = (
                "python_reducer_without_compiled_forward"
            )
        else:
            torch._dynamo.config.optimize_ddp = "ddp_optimizer"

    replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)

    logger.info("Applied DDP to the model")


def get_model(model):
    base_model_prefix = getattr(model, "base_model_prefix", "model")
    if not hasattr(model, base_model_prefix):
        return None
    model = getattr(model, base_model_prefix)
    return model


def get_blocks(model):
    # TODO[flame]: adapt for network not using 'layers' attribute
    model = get_model(model)
    if not hasattr(model, "layers"):
        logger.warning('no "layers" in model can be found')
        return None
    return model.layers


def get_components_name(model, component_name):
    """
    We try to catch tok_embeddings, norm layers and lm_head layers
    We do not catch the layer names in the blocks, for blocks see `get_blocks`
    We assume the model has the following structure:
        LlamaForCausalLM:
            Model:
                embed_tokens,
                layers,
                norm,
            lm_head
    ***
    so, to search 'tok_embeddings' and 'norm' we need to pass `get_model(model)`
    and for 'lm_head' we need to pass `model`
    ***
    """

    if component_name == "tok_embeddings":
        if hasattr(model, "tok_embeddings"):
            return "tok_embeddings"
        elif hasattr(model, "embed_tokens"):
            return "embed_tokens"
        elif hasattr(model, "embeddings"):
            return "embeddings"
        else:
            logger.warning("No tok_embeddings found in model")
            return None

    elif component_name == "norm":
        if hasattr(model, "norm"):
            return "norm"
        elif hasattr(model, "norms"):
            return "norms"
        elif hasattr(model, "layernorm"):
            return "layernorm"
        else:
            logger.warning("No norm found in model")
            return None

    elif component_name == "lm_head":
        if hasattr(model, "lm_head"):
            return "lm_head"
        else:
            logger.warning("No lm_head found in model")
            return None
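
To make the layer-frequency branch of _apply_ac_to_block above concrete, a dependency-free sketch (not from this commit) of which blocks end up wrapped when selective_ac_option = "2" over eight blocks; the block count is made up:

# mirrors the `_count % ac_freq == 0` rule used by the layer-frequency SAC branch
ac_freq = 2
count = 0
for layer_id in range(8):
    count += 1
    wrapped = (not ac_freq) or (count % ac_freq == 0)
    print(f"block {layer_id}: {'checkpointed' if wrapped else 'left as-is'}")
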
flame/models/pipeline_fla.py ADDED
@@ -0,0 +1,162 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# This file applies the PT-D pipeline parallelism to the Llama model.

import copy
from typing import Callable, Optional, Union

import torch
import torch.nn as nn
from torch.distributed import DeviceMesh
from torch.distributed.pipelining import PipelineStage
from torch.distributed.pipelining.schedules import ScheduleZBVZeroBubble, _PipelineSchedule, get_schedule_class
from transformers import PretrainedConfig

from flame.models.parallelize_fla import get_blocks, get_components_name, get_model
from torchtitan.config_manager import JobConfig
from torchtitan.distributed.parallel_dims import ParallelDims
from torchtitan.distributed.pipeline import build_pipeline_schedule, generate_split_points, stage_ids_this_rank
from torchtitan.tools.logging import logger

DeviceType = Union[int, str, torch.device]


def pipeline_fla(
    model: nn.Module,
    pp_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
    device: DeviceType,
    model_config: PretrainedConfig,
    loss_fn: Callable[..., torch.Tensor],
) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]:
    stages, models = pipeline_fla_manual_split(
        model, pp_mesh, parallel_dims, job_config, device, model_config
    )

    pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)

    # This is used in the train loop to determine whether to pass in the input_ids and labels
    has_first_stage = False
    has_last_stage = False
    for stage in stages:
        if stage.is_first:
            has_first_stage = True
        if stage.is_last:
            has_last_stage = True

    return pp_schedule, models, has_first_stage, has_last_stage


def pipeline_fla_manual_split(
    whole_model: nn.Module,
    pp_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
    device: DeviceType,
    model_config: PretrainedConfig,
) -> tuple[list[PipelineStage], list[nn.Module]]:
    """
    This API extracts one torch.nn.Module object for the part of the model configured to run inside this stage.

    It wraps the model chunk in a ManualPipelineStage object and returns both the stage and model objects.

    The stage object is used to create a pipeline schedule, and the model object can be used for applying SPMD
    parallelism.
    """
    pp_rank = pp_mesh.get_local_rank()
    pp_size = pp_mesh.size()

    splits = (
        job_config.experimental.pipeline_parallel_split_points
        or generate_split_points(
            job_config, parallel_dims.pp, model_config.num_hidden_layers
        )
    )

    def _build_stage(
        stage_idx: int,
        start_layer: Optional[str],
        stop_layer: Optional[str],
        is_first: bool = False,
        is_last: bool = False,
    ) -> tuple[PipelineStage, nn.Module]:
        model = copy.deepcopy(whole_model)
        if not is_first:
            # we do `model.tok_embeddings = None` here
            real_model = get_model(model)
            tok_embeddings_name = get_components_name(real_model, "tok_embeddings")
            setattr(real_model, tok_embeddings_name, None)

        drop_layers = start_layer is not None
        # Get module dictionary from get_blocks(model)
        # and create a list of keys before modifying the dictionary
        module_dict = get_blocks(model)._modules  # Store reference
        layer_names = list(module_dict.keys())

        # Iterate over the list of keys instead of `_modules.items()`
        for name in layer_names:
            # Dynamically determine prefix (blocks.* or layers.*)
            prefix = start_layer.split(".")[0] if start_layer else "layers"
            layer_name = f"{prefix}.{name}"  # Construct the correct name format

            # Ensure `drop_layers` activation is based on actual naming
            if layer_name == start_layer:
                drop_layers = False
            if layer_name == stop_layer:
                drop_layers = True

            # Delete layer if drop_layers is active
            if drop_layers:
                del module_dict[name]  # Safe deletion from stored dictionary

        if not is_last:
            # we do `model.norm = None` and `model.output = None`
            real_model = get_model(model)
            norm_name = get_components_name(real_model, "norm")
            setattr(real_model, norm_name, None)

            head_name = get_components_name(model, "lm_head")
            setattr(model, head_name, None)

        stage = PipelineStage(
            model,
            stage_idx,
            num_stages,
            device,
            group=pp_mesh.get_group("pp"),
        )
        return stage, model

    num_stages = len(splits) + 1
    stage_idx = pp_rank

    stages = []
    models = []

    schedule_class = get_schedule_class(
        job_config.experimental.pipeline_parallel_schedule
    )
    style = "v" if schedule_class == ScheduleZBVZeroBubble else "loop"

    for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style=style):
        start_layer = splits[stage_idx - 1] if stage_idx > 0 else None
        stop_layer = splits[stage_idx] if stage_idx < num_stages - 1 else None
        stage, model_chunk = _build_stage(
            stage_idx,
            start_layer,
            stop_layer,
            is_first=stage_idx == 0,
            is_last=stage_idx == num_stages - 1,
        )
        logger.info(
            f"PP rank {pp_rank} is building stage_idx {stage_idx}"
            f" with start_layer {start_layer}, stop_layer {stop_layer}"
        )
        stages.append(stage)
        models.append(model_chunk)
    return stages, models
flame/tools/__init__.py ADDED
File without changes
flame/tools/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (136 Bytes).
 
flame/tools/__pycache__/utils.cpython-312.pyc ADDED
Binary file (2.14 kB).
 
flame/tools/utils.py ADDED
@@ -0,0 +1,41 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torch import nn
from torchtitan.tools.logging import logger


def get_nparams_and_flops(model: nn.Module, model_config, seq_len: int) -> tuple[int, int]:
    nparams = sum(p.numel() for p in model.parameters())
    nparams_embedding = sum(
        sum(p.numel() for p in m.parameters())
        for m in model.children()
        if isinstance(m, nn.Embedding)
    )

    if hasattr(model_config, "num_heads"):
        num_heads = model_config.num_heads
    elif hasattr(model_config, "num_attention_heads"):
        num_heads = model_config.num_attention_heads
    else:
        num_heads = 1
        logger.warning("num_heads not found in model_config, defaulting to 1. ")

    l, h, q, t = (
        model_config.num_hidden_layers,
        num_heads,
        model_config.hidden_size // num_heads,
        seq_len,
    )
    # Reasoning behind the factor of 12 for the self-attention part of the formula:
    # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6)
    # 2. the flash attention does 1 more matmul recomputation in the backward
    #    but recomputation should not be counted in calculating MFU (+0)
    # 3. each matmul performs 1 multiplication and 1 addition (*2)
    # 4. we follow the convention and do not account for sparsity in causal attention
    num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t

    return nparams, num_flops_per_token
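
A quick back-of-the-envelope check of the formula above with invented 1.3B-class hyperparameters (none of these numbers are read from this repo's configs):

l, h, q, t = 24, 32, 64, 2048                    # layers, heads, head dim, seq len
nparams, nparams_embedding = 1_340_000_000, 65_000_000
num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
print(f"{num_flops_per_token / 1e9:.2f} GFLOPs per token")
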
flame/utils/__init__.py ADDED
File without changes
flame/utils/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (136 Bytes).
 
flame/utils/__pycache__/checkpoint.cpython-312.pyc ADDED
Binary file (4.07 kB).
 
flame/utils/__pycache__/convert_dcp_to_hf.cpython-312.pyc ADDED
Binary file (3.73 kB).
 
flame/utils/__pycache__/convert_hf_to_dcp.cpython-312.pyc ADDED
Binary file (1.92 kB).
 
flame/utils/__pycache__/hf_utils.cpython-312.pyc ADDED
Binary file (4.46 kB).
 
flame/utils/checkpoint.py ADDED
@@ -0,0 +1,50 @@
+import os
+import glob
+import re
+import shutil
+from torchtitan.tools.logging import logger
+
+
+def cleanup_local_checkpoints(checkpoint_dir: str, keep_latest_k: int):
+    """Removes older checkpoint directories locally, keeping only the latest k for both DCP and HF formats."""
+    if keep_latest_k <= 0:
+        return  # Keep all checkpoints
+
+    logger.info(f"Cleaning up local checkpoints in {checkpoint_dir}, keeping latest {keep_latest_k}")
+
+    # Cleanup DCP checkpoints (step-*)
+    dcp_checkpoints = sorted(
+        glob.glob(os.path.join(checkpoint_dir, "step-*")),
+        key=lambda x: int(re.search(r"step-(\d+)", os.path.basename(x)).group(1)) if re.search(r"step-(\d+)", os.path.basename(x)) and not x.endswith("-hf") else -1,
+        reverse=True
+    )
+    # Filter out HF format directories
+    dcp_checkpoints = [d for d in dcp_checkpoints if not d.endswith("-hf")]
+
+    if len(dcp_checkpoints) > keep_latest_k:
+        checkpoints_to_delete = dcp_checkpoints[keep_latest_k:]
+        logger.info(f"Deleting {len(checkpoints_to_delete)} old DCP checkpoints: {[os.path.basename(c) for c in checkpoints_to_delete]}")
+        for ckpt_path in checkpoints_to_delete:
+            if os.path.isdir(ckpt_path):  # Ensure it's a directory
+                try:
+                    shutil.rmtree(ckpt_path)
+                except OSError as e:
+                    logger.error(f"Error removing directory {ckpt_path}: {e}")
+
+
+    # Cleanup HF checkpoints (step-*-hf)
+    hf_checkpoints = sorted(
+        glob.glob(os.path.join(checkpoint_dir, "step-*-hf")),
+        key=lambda x: int(re.search(r"step-(\d+)-hf", os.path.basename(x)).group(1)) if re.search(r"step-(\d+)-hf", os.path.basename(x)) else -1,
+        reverse=True
+    )
+
+    if len(hf_checkpoints) > keep_latest_k:
+        checkpoints_to_delete = hf_checkpoints[keep_latest_k:]
+        logger.info(f"Deleting {len(checkpoints_to_delete)} old HF checkpoints: {[os.path.basename(c) for c in checkpoints_to_delete]}")
+        for ckpt_path in checkpoints_to_delete:
+            if os.path.isdir(ckpt_path):  # Ensure it's a directory
+                try:
+                    shutil.rmtree(ckpt_path)
+                except OSError as e:
+                    logger.error(f"Error removing directory {ckpt_path}: {e}")
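
The retention logic can be sanity-checked against a throwaway directory layout; a minimal sketch with fabricated step folders:

```python
# Illustrative only: fake checkpoint folders pruned by cleanup_local_checkpoints.
import os
import tempfile

from flame.utils.checkpoint import cleanup_local_checkpoints

with tempfile.TemporaryDirectory() as root:
    for step in (1000, 2000, 3000):
        os.makedirs(os.path.join(root, f"step-{step}"))      # DCP format
        os.makedirs(os.path.join(root, f"step-{step}-hf"))   # HF format

    # Keep only the newest checkpoint of each format; older ones are rmtree'd.
    cleanup_local_checkpoints(root, keep_latest_k=1)
    print(sorted(os.listdir(root)))  # expected: ['step-3000', 'step-3000-hf']
```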
flame/utils/convert_dcp_to_hf.py ADDED
@@ -0,0 +1,66 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+import argparse
+import io
+import os
+import tempfile
+from datetime import timedelta
+
+import torch
+import torch.serialization
+from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+
+import fla  # noqa
+from torchtitan.tools.logging import init_logger, logger
+
+
+@torch.inference_mode()
+def save_pretrained(
+    path: str,
+    step: int,
+    config: str,
+    tokenizer: str
+):
+    logger.info(f"Loading the config from {config}")
+    config = AutoConfig.from_pretrained(config, trust_remote_code=True)
+
+    logger.info(f"Saving the config to {path}")
+    config.save_pretrained(path)
+    logger.info(f"Loading the tokenizer from {tokenizer}")
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True)
+    logger.info(f"Saving the tokenizer to {path}")
+    tokenizer.save_pretrained(path)
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        # base_checkpoint_dir = os.path.dirname(path)
+        base_checkpoint_dir = path
+        checkpoint = os.path.join(base_checkpoint_dir, f'checkpoint/step-{step}')
+        checkpoint_path = os.path.join(tmpdir, 'checkpoint.pt')
+        logger.info(f"Saving the distributed checkpoint to {checkpoint_path}")
+        dcp_to_torch_save(checkpoint, checkpoint_path)
+
+        logger.info(f"Initializing the model from config\n{config}")
+        model = AutoModelForCausalLM.from_config(config)
+        logger.info(model)
+        logger.info("Loading state dict from the checkpoint")
+
+        # Add datetime.timedelta and io.BytesIO to safe globals
+        torch.serialization.add_safe_globals([timedelta, io.BytesIO])
+        # torch.load now with default weights_only=True will work
+        model.load_state_dict(torch.load(checkpoint_path, map_location='cpu')['model'])
+
+        logger.info(f"Saving the model to {path}")
+        model.save_pretrained(path)
+
+
+if __name__ == "__main__":
+    init_logger()
+    parser = argparse.ArgumentParser("Convert DCP format model weights to huggingface-style.")
+    parser.add_argument("--path", type=str, required=True)
+    parser.add_argument("--step", type=int, required=True)
+    parser.add_argument("--config", type=str, required=True)
+    parser.add_argument("--tokenizer", type=str, required=True)
+    args = parser.parse_args()
+    save_pretrained(args.path, args.step, args.config, args.tokenizer)
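
A hypothetical invocation, assuming a DCP checkpoint already exists under `<path>/checkpoint/step-<step>`; the paths below are placeholders, though the config and tokenizer ids mirror the training args used elsewhere in this run:

```python
# Placeholder paths: adjust to the actual dump folder and step.
from flame.utils.convert_dcp_to_hf import save_pretrained

save_pretrained(
    path="exp/my-run",                         # expects exp/my-run/checkpoint/step-40000 to exist
    step=40000,
    config="configs/top_transformer_1B.json",
    tokenizer="fla-hub/transformer-1.3B-100B",
)
```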
flame/utils/convert_hf_to_dcp.py ADDED
@@ -0,0 +1,34 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+import argparse
+from pathlib import Path
+
+import torch
+import torch.distributed.checkpoint as DCP
+from transformers import AutoModelForCausalLM
+
+import fla  # noqa
+from torchtitan.tools.logging import init_logger, logger
+
+
+@torch.inference_mode()
+def convert_hf_weights(model: str, checkpoint: Path):
+    logger.info(f"Loading model from {model}")
+    model = AutoModelForCausalLM.from_pretrained(model)
+    state_dict = model.state_dict()
+
+    logger.info(f"Writing to DCP at '{checkpoint}'")
+    checkpoint.mkdir(parents=True, exist_ok=True)
+    storage_writer = DCP.filesystem.FileSystemWriter(checkpoint, thread_count=8)
+    DCP.save({"model": state_dict}, storage_writer=storage_writer)
+
+
+if __name__ == "__main__":
+    init_logger()
+    parser = argparse.ArgumentParser(description="Convert huggingface-style model weights to DCP format.")
+    parser.add_argument("--model", type=str, required=True)
+    parser.add_argument("--checkpoint", type=Path, required=True)
+    args = parser.parse_args()
+
+    convert_hf_weights(args.model, args.checkpoint)
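
The reverse direction follows the same pattern; a hedged sketch with placeholder paths:

```python
# Placeholder model id and output path.
from pathlib import Path

from flame.utils.convert_hf_to_dcp import convert_hf_weights

convert_hf_weights(
    model="fla-hub/transformer-1.3B-100B",            # any HF-style causal LM id or local directory
    checkpoint=Path("exp/my-run/checkpoint/step-0"),  # DCP shards are written here
)
```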
flame/utils/hf_utils.py ADDED
@@ -0,0 +1,77 @@
+import os
+import re
+from huggingface_hub import HfApi, HfFolder, logging as hf_logging, create_repo
+from torchtitan.tools.logging import logger
+
+def upload_checkpoint_to_hf(
+    local_path: str,
+    step: int,
+    hf_repo_id_for_run: str,
+    hf_keep_latest_k: int,
+    upload_format: str
+):
+    """Uploads a checkpoint directory to HF Hub and manages retention."""
+    if not os.path.isdir(local_path):
+        logger.error(f"Local path for upload does not exist or is not a directory: {local_path}")
+        return
+
+    api = HfApi()
+    token = HfFolder.get_token()
+    if not token:
+        logger.warning("Hugging Face Hub token not found. Skipping upload. Login via `huggingface-cli login` or set HF_TOKEN.")
+        return
+
+    # --- Ensure the specific repository for this run exists ---
+    try:
+        logger.info(f"Ensuring repository {hf_repo_id_for_run} exists...")
+        # Use create_repo which handles creation only if it doesn't exist
+        create_repo(repo_id=hf_repo_id_for_run, token=token, repo_type="model", exist_ok=True)
+        logger.info(f"Repository {hf_repo_id_for_run} ensured.")
+    except Exception as e:
+        logger.error(f"Failed to create or ensure repository {hf_repo_id_for_run}: {e}", exc_info=True)
+        return  # Stop if repo interaction fails
+
+    commit_message = f"Upload {upload_format.upper()} checkpoint step {step}"
+    path_in_repo = f"step-{step}"
+
+    logger.info(f"Uploading {local_path} to {hf_repo_id_for_run}/{path_in_repo} on Hugging Face Hub...")
+    try:
+        api.upload_folder(
+            folder_path=local_path,
+            path_in_repo=path_in_repo,
+            repo_id=hf_repo_id_for_run,
+            repo_type="model",
+            commit_message=commit_message,
+            token=token,
+        )
+        logger.info(f"Successfully uploaded step {step} to {hf_repo_id_for_run}.")
+    except Exception as e:
+        logger.error(f"Failed to upload checkpoint step {step} to {hf_repo_id_for_run}: {e}", exc_info=True)
+    if hf_keep_latest_k > 0:
+        logger.info(f"Cleaning up old checkpoints on {hf_repo_id_for_run}, keeping latest {hf_keep_latest_k}")
+        try:
+            repo_files = api.list_repo_tree(hf_repo_id_for_run, repo_type="model", token=token, recursive=False)
+            step_folders = [
+                item.path for item in repo_files
+                if item.path.startswith("step-") and item.path[5:].isdigit()
+            ]
+
+            step_folders.sort(key=lambda x: int(x.split('-')[1]), reverse=True)
+
+            if len(step_folders) > hf_keep_latest_k:
+                folders_to_delete = step_folders[hf_keep_latest_k:]
+                logger.info(f"Found {len(step_folders)} checkpoints on Hub. Deleting {len(folders_to_delete)} older ones: {folders_to_delete}")
+                for folder in folders_to_delete:
+                    # Deleting requires repo_id, path_in_repo, and token
+                    api.delete_folder(
+                        repo_id=hf_repo_id_for_run,
+                        path_in_repo=folder,
+                        repo_type="model",
+                        commit_message=f"Delete old checkpoint {folder}",
+                        token=token
+                    )
+                logger.info("Hub cleanup complete.")
+            else:
+                logger.info("No old checkpoints found on Hub to delete.")
+        except Exception as e:
+            logger.error(f"Error during Hub checkpoint cleanup for {hf_repo_id_for_run}: {e}", exc_info=True)
logs/none_ewbp5xc1/attempt_0/1/stderr.log ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_1024/rank0_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_1024/rank1_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_1024/rank5_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_1024/rank6_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_1024/rank7_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_1536/rank4_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_20992/rank5_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_23552/rank6_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_2560/rank5_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_2560/rank7_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_29696/rank2_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_29696/rank6_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_30720/rank6_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_3584/rank0_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_3584/rank4_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_3584/rank5_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_3584/rank7_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
tb/20250901-0749/wandb/run-20250901_074914-top_transformer-top.code.1B.batch16.seqlen4096.context4096.warmup400.update1.steps40000.lr5e-5.cosine-202509010747/files/wandb-metadata.json ADDED
@@ -0,0 +1,146 @@
+{
+    "os": "Linux-6.8.0-62-generic-x86_64-with-glibc2.39",
+    "python": "CPython 3.12.11",
+    "startedAt": "2025-09-01T07:49:14.031224Z",
+    "args": [
+        "--job.config_file",
+        "flame/models/fla.toml",
+        "--job.dump_folder",
+        "exp/top.code.1B.batch16.seqlen4096.context4096.warmup400.update1.steps40000.lr5e-5.cosine",
+        "--model.config",
+        "configs/top_transformer_1B.json",
+        "--model.tokenizer_path",
+        "fla-hub/transformer-1.3B-100B",
+        "--optimizer.name",
+        "AdamW",
+        "--optimizer.eps",
+        "1e-15",
+        "--optimizer.lr",
+        "5e-5",
+        "--lr_scheduler.warmup_steps",
+        "400",
+        "--lr_scheduler.lr_min",
+        "0.1",
+        "--lr_scheduler.decay_type",
+        "cosine",
+        "--training.batch_size",
+        "16",
+        "--training.seq_len",
+        "4096",
+        "--training.context_len",
+        "4096",
+        "--training.gradient_accumulation_steps",
+        "1",
+        "--training.steps",
+        "40000",
+        "--training.max_norm",
+        "1.0",
+        "--training.skip_nan_inf",
+        "--training.dataset",
+        "/home/cvm/.cache/zaydzuhri___stack-edu-python/default",
+        "--training.dataset_split",
+        "train",
+        "--training.num_workers",
+        "32",
+        "--training.prefetch_factor",
+        "2",
+        "--training.seed",
+        "79",
+        "--training.compile",
+        "--checkpoint.interval",
+        "5000",
+        "--checkpoint.load_step",
+        "-1",
+        "--metrics.log_freq",
+        "5",
+        "--checkpoint.hf_upload_enabled",
+        "--checkpoint.hf_repo_base_name",
+        "zaydzuhri/top-code-1B-4096-batch16x1-steps40000",
+        "--comm.init_timeout_seconds",
+        "1600",
+        "--comm.train_timeout_seconds",
+        "1600"
+    ],
+    "program": "-m flame.train",
+    "git": {
+        "remote": "https://github.com/zaydzuhri/flame.git",
+        "commit": "aa4d5932e54fad8a568e10aa6895e69e0664fcf1"
+    },
+    "email": "zaydzuhri@gmail.com",
+    "root": "exp/top.code.1B.batch16.seqlen4096.context4096.warmup400.update1.steps40000.lr5e-5.cosine/tb/20250901-0749",
+    "host": "cvm-gncv9hlh",
+    "executable": "/home/cvm/miniconda3/envs/flame-env/bin/python3.12",
+    "cpu_count": 64,
+    "cpu_count_logical": 128,
+    "gpu": "NVIDIA H200",
+    "gpu_count": 8,
+    "disk": {
+        "/": {
+            "total": "3242363822080",
+            "used": "1307996758016"
+        }
+    },
+    "memory": {
+        "total": "1913833021440"
+    },
+    "gpu_nvidia": [
+        {
+            "name": "NVIDIA H200",
+            "memoryTotal": "150754820096",
+            "cudaCores": 16896,
+            "architecture": "Hopper",
+            "uuid": "GPU-eddf9f4c-ffde-5f10-3c76-12ebce1f042b"
+        },
+        {
+            "name": "NVIDIA H200",
+            "memoryTotal": "150754820096",
+            "cudaCores": 16896,
+            "architecture": "Hopper",
+            "uuid": "GPU-b532c850-7343-8f67-7eb1-a69024695a99"
+        },
+        {
+            "name": "NVIDIA H200",
+            "memoryTotal": "150754820096",
+            "cudaCores": 16896,
+            "architecture": "Hopper",
+            "uuid": "GPU-751a6bdf-72f3-4f5a-fefd-d2b98c338579"
+        },
+        {
+            "name": "NVIDIA H200",
+            "memoryTotal": "150754820096",
+            "cudaCores": 16896,
+            "architecture": "Hopper",
+            "uuid": "GPU-0cd9d3c7-1d2e-1925-91eb-8ec99a4ed277"
+        },
+        {
+            "name": "NVIDIA H200",
+            "memoryTotal": "150754820096",
+            "cudaCores": 16896,
+            "architecture": "Hopper",
+            "uuid": "GPU-fba7e7ab-8340-13b0-b893-c3686cfec728"
+        },
+        {
+            "name": "NVIDIA H200",
+            "memoryTotal": "150754820096",
+            "cudaCores": 16896,
+            "architecture": "Hopper",
+            "uuid": "GPU-12ca11c0-9080-3877-2bd5-3775573a4134"
+        },
+        {
+            "name": "NVIDIA H200",
+            "memoryTotal": "150754820096",
+            "cudaCores": 16896,
+            "architecture": "Hopper",
+            "uuid": "GPU-32b3ec8b-9dc8-c6f6-5c19-74fa2ce10ffd"
+        },
+        {
+            "name": "NVIDIA H200",
+            "memoryTotal": "150754820096",
+            "cudaCores": 16896,
+            "architecture": "Hopper",
+            "uuid": "GPU-d0021141-e4f4-14ab-c2ab-0ef3e30d6dd5"
+        }
+    ],
+    "cudaVersion": "12.8",
+    "writerId": "da7dvih583ith342zcw0cwucsgured2u"
+}
tb/20250901-0749/wandb/run-20250901_074914-top_transformer-top.code.1B.batch16.seqlen4096.context4096.warmup400.update1.steps40000.lr5e-5.cosine-202509010747/logs/debug-internal.log ADDED
@@ -0,0 +1,17 @@
+{"time":"2025-09-01T07:49:14.247972294Z","level":"INFO","msg":"stream: starting","core version":"0.21.1"}
+{"time":"2025-09-01T07:49:14.545288881Z","level":"INFO","msg":"stream: created new stream","id":"top_transformer-top.code.1B.batch16.seqlen4096.context4096.warmup400.update1.steps40000.lr5e-5.cosine-202509010747"}
+{"time":"2025-09-01T07:49:14.545362953Z","level":"INFO","msg":"stream: started","id":"top_transformer-top.code.1B.batch16.seqlen4096.context4096.warmup400.update1.steps40000.lr5e-5.cosine-202509010747"}
+{"time":"2025-09-01T07:49:14.54541562Z","level":"INFO","msg":"writer: started","stream_id":"top_transformer-top.code.1B.batch16.seqlen4096.context4096.warmup400.update1.steps40000.lr5e-5.cosine-202509010747"}
+{"time":"2025-09-01T07:49:14.545435817Z","level":"INFO","msg":"sender: started","stream_id":"top_transformer-top.code.1B.batch16.seqlen4096.context4096.warmup400.update1.steps40000.lr5e-5.cosine-202509010747"}
+{"time":"2025-09-01T07:49:14.545490133Z","level":"INFO","msg":"handler: started","stream_id":"top_transformer-top.code.1B.batch16.seqlen4096.context4096.warmup400.update1.steps40000.lr5e-5.cosine-202509010747"}
+{"time":"2025-09-01T12:39:44.49607374Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/top_transformer-top.code.1B.batch16.seqlen4096.context4096.warmup400.update1.steps40000.lr5e-5.cosine-202509010747/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
+{"time":"2025-09-01T12:57:09.402167829Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/top_transformer-top.code.1B.batch16.seqlen4096.context4096.warmup400.update1.steps40000.lr5e-5.cosine-202509010747/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
+{"time":"2025-09-01T20:38:44.471380019Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/top_transformer-top.code.1B.batch16.seqlen4096.context4096.warmup400.update1.steps40000.lr5e-5.cosine-202509010747/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
+{"time":"2025-09-01T22:25:18.669785309Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/top_transformer-top.code.1B.batch16.seqlen4096.context4096.warmup400.update1.steps40000.lr5e-5.cosine-202509010747/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
+{"time":"2025-09-01T22:55:35.532603708Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/top_transformer-top.code.1B.batch16.seqlen4096.context4096.warmup400.update1.steps40000.lr5e-5.cosine-202509010747/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
+{"time":"2025-09-02T07:07:34.089412209Z","level":"INFO","msg":"fileTransfer: Close: file transfer manager closed"}
+{"time":"2025-09-02T07:07:34.291787824Z","level":"INFO","msg":"handler: operation stats","stats":{}}
+{"time":"2025-09-02T07:07:34.295689194Z","level":"INFO","msg":"stream: closing","id":"top_transformer-top.code.1B.batch16.seqlen4096.context4096.warmup400.update1.steps40000.lr5e-5.cosine-202509010747"}
+{"time":"2025-09-02T07:07:34.295726455Z","level":"INFO","msg":"handler: closed","stream_id":"top_transformer-top.code.1B.batch16.seqlen4096.context4096.warmup400.update1.steps40000.lr5e-5.cosine-202509010747"}
+{"time":"2025-09-02T07:07:34.295770415Z","level":"INFO","msg":"sender: closed","stream_id":"top_transformer-top.code.1B.batch16.seqlen4096.context4096.warmup400.update1.steps40000.lr5e-5.cosine-202509010747"}
+{"time":"2025-09-02T07:07:34.29578361Z","level":"INFO","msg":"stream: closed","id":"top_transformer-top.code.1B.batch16.seqlen4096.context4096.warmup400.update1.steps40000.lr5e-5.cosine-202509010747"}