Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes. See the raw diff for the full set.
- fla/models/lightnet/__pycache__/modeling_lightnet.cpython-312.pyc +0 -0
- fla/models/retnet/__pycache__/configuration_retnet.cpython-312.pyc +0 -0
- flame/__pycache__/config_manager.cpython-312.pyc +0 -0
- flame/__pycache__/train.cpython-312.pyc +0 -0
- flame/components/__init__.py +0 -0
- flame/components/__pycache__/__init__.cpython-312.pyc +0 -0
- flame/components/__pycache__/checkpoint.cpython-312.pyc +0 -0
- flame/components/checkpoint.py +59 -0
- flame/models/__init__.py +0 -0
- flame/models/__pycache__/__init__.cpython-312.pyc +0 -0
- flame/models/__pycache__/parallelize_fla.cpython-312.pyc +0 -0
- flame/models/__pycache__/pipeline_fla.cpython-312.pyc +0 -0
- flame/models/activation_offloading.py +447 -0
- flame/models/fla.toml +67 -0
- flame/models/parallelize_fla.py +550 -0
- flame/models/pipeline_fla.py +162 -0
- flame/tools/__init__.py +0 -0
- flame/tools/__pycache__/__init__.cpython-312.pyc +0 -0
- flame/tools/__pycache__/utils.cpython-312.pyc +0 -0
- flame/tools/utils.py +41 -0
- flame/utils/__init__.py +0 -0
- flame/utils/__pycache__/__init__.cpython-312.pyc +0 -0
- flame/utils/__pycache__/checkpoint.cpython-312.pyc +0 -0
- flame/utils/__pycache__/convert_dcp_to_hf.cpython-312.pyc +0 -0
- flame/utils/__pycache__/convert_hf_to_dcp.cpython-312.pyc +0 -0
- flame/utils/__pycache__/hf_utils.cpython-312.pyc +0 -0
- flame/utils/checkpoint.py +50 -0
- flame/utils/convert_dcp_to_hf.py +66 -0
- flame/utils/convert_hf_to_dcp.py +34 -0
- flame/utils/hf_utils.py +77 -0
- logs/none_ewbp5xc1/attempt_0/1/stderr.log +0 -0
- profile_trace/iteration_1024/rank0_trace.json +0 -0
- profile_trace/iteration_1024/rank1_trace.json +0 -0
- profile_trace/iteration_1024/rank5_trace.json +0 -0
- profile_trace/iteration_1024/rank6_trace.json +0 -0
- profile_trace/iteration_1024/rank7_trace.json +0 -0
- profile_trace/iteration_1536/rank4_trace.json +0 -0
- profile_trace/iteration_20992/rank5_trace.json +0 -0
- profile_trace/iteration_23552/rank6_trace.json +0 -0
- profile_trace/iteration_2560/rank5_trace.json +0 -0
- profile_trace/iteration_2560/rank7_trace.json +0 -0
- profile_trace/iteration_29696/rank2_trace.json +0 -0
- profile_trace/iteration_29696/rank6_trace.json +0 -0
- profile_trace/iteration_30720/rank6_trace.json +0 -0
- profile_trace/iteration_3584/rank0_trace.json +0 -0
- profile_trace/iteration_3584/rank4_trace.json +0 -0
- profile_trace/iteration_3584/rank5_trace.json +0 -0
- profile_trace/iteration_3584/rank7_trace.json +0 -0
- tb/20250901-0749/wandb/run-20250901_074914-top_transformer-top.code.1B.batch16.seqlen4096.context4096.warmup400.update1.steps40000.lr5e-5.cosine-202509010747/files/wandb-metadata.json +146 -0
- tb/20250901-0749/wandb/run-20250901_074914-top_transformer-top.code.1B.batch16.seqlen4096.context4096.warmup400.update1.steps40000.lr5e-5.cosine-202509010747/logs/debug-internal.log +17 -0
fla/models/lightnet/__pycache__/modeling_lightnet.cpython-312.pyc
ADDED
Binary file (18.4 kB)

fla/models/retnet/__pycache__/configuration_retnet.cpython-312.pyc
ADDED
Binary file (3.76 kB)

flame/__pycache__/config_manager.cpython-312.pyc
ADDED
Binary file (36.9 kB)

flame/__pycache__/train.cpython-312.pyc
ADDED
Binary file (38.1 kB)

flame/components/__init__.py
ADDED
File without changes

flame/components/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (141 Bytes)

flame/components/__pycache__/checkpoint.cpython-312.pyc
ADDED
Binary file (3.21 kB)

flame/components/checkpoint.py
ADDED
@@ -0,0 +1,59 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field
from datetime import timedelta
from io import BytesIO
from typing import Any, Dict, List

import torch
from torch.distributed.checkpoint.stateful import Stateful


@dataclass
class TrainState(Stateful):
    step: int = 0
    skipped_step: int = 0
    token: int = 0
    elapsed: timedelta = timedelta(0)
    global_avg_losses: List[float] = field(default_factory=list)
    global_max_losses: List[float] = field(default_factory=list)
    log_steps: List[int] = field(default_factory=list)

    def state_dict(self) -> Dict[str, Any]:
        # Only checkpoint global_avg_losses and global_max_losses per log frequency
        # to avoid sync overhead in every iteration.
        global_avg_losses_bytes = BytesIO()
        torch.save(self.global_avg_losses, global_avg_losses_bytes)
        global_max_losses_bytes = BytesIO()
        torch.save(self.global_max_losses, global_max_losses_bytes)
        log_steps_bytes = BytesIO()
        torch.save(self.log_steps, log_steps_bytes)
        return {
            "step": torch.tensor(self.step, dtype=torch.int32),
            "skipped_step": torch.tensor(self.skipped_step, dtype=torch.int32),
            "token": torch.tensor(self.token, dtype=torch.int64),
            "elapsed": self.elapsed,
            "global_avg_losses": global_avg_losses_bytes,
            "global_max_losses": global_max_losses_bytes,
            "log_steps": log_steps_bytes,
        }

    def load_state_dict(self, state_dict) -> None:
        self.step = state_dict["step"].item()
        self.skipped_step = state_dict.get("skipped_step", 0).item()
        self.token = state_dict["token"].item()
        self.elapsed = state_dict["elapsed"]
        state_dict["global_avg_losses"].seek(0)
        self.global_avg_losses = torch.load(
            state_dict["global_avg_losses"], weights_only=False
        )
        state_dict["global_max_losses"].seek(0)
        self.global_max_losses = torch.load(
            state_dict["global_max_losses"], weights_only=False
        )
        state_dict["log_steps"].seek(0)
        self.log_steps = torch.load(state_dict["log_steps"], weights_only=False)

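Because TrainState subclasses Stateful, torch.distributed.checkpoint (DCP) will call its state_dict()/load_state_dict() automatically when it appears in the checkpointed state. Below is a minimal illustrative sketch of that wiring; the checkpoint path and the assumption that a process group is already initialized are mine, not part of this diff.

# Illustrative only: assumes torch.distributed is initialized and that the
# checkpoint directory "checkpoint/step-1000" is just an example path.
import torch.distributed.checkpoint as dcp

from flame.components.checkpoint import TrainState

train_state = TrainState()
state = {"train_state": train_state}

dcp.save(state, checkpoint_id="checkpoint/step-1000")  # serializes via TrainState.state_dict()
dcp.load(state, checkpoint_id="checkpoint/step-1000")  # restores in place via load_state_dict()
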
flame/models/__init__.py
ADDED
File without changes

flame/models/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (137 Bytes)

flame/models/__pycache__/parallelize_fla.cpython-312.pyc
ADDED
Binary file (22.1 kB)

flame/models/__pycache__/pipeline_fla.cpython-312.pyc
ADDED
Binary file (5.75 kB)

flame/models/activation_offloading.py
ADDED
@@ -0,0 +1,447 @@
# Adapted from https://github.com/pytorch/torchtune/blob/main/torchtune/training/_activation_offloading.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
from typing import Union
from warnings import warn

import psutil
import torch
from torch import nn
from torch.autograd.graph import saved_tensors_hooks

from torchtitan.tools.logging import logger

try:
    import torchao
    from torchao.dtypes.nf4tensor import NF4Tensor
except ImportError:
    torchao = None
    NF4Tensor = None
    logger.warning("torchao not found. ")

# from torchtune.modules import TiedLinear


class OffloadActivations(saved_tensors_hooks):
    """Context manager under which activation tensors created in the forward pass will be offloaded.

    Enable the memory efficiency technique of activation offloading, where activations bigger than
    min_offload_size bytes will be offloaded to CPU in the forward and brought back in the backward.
    This is in contrast to maintaining the activation on GPU VRAM throughout the program.

    This manager contains the option of using one additional CUDA stream to handle the communication
    between CUDA and CPU, which is intended to overlap with the default computation stream to improve
    runtime. We designed synchronization with a few heuristics for optimizing the tradeoff between
    runtime vs memory usage.

    Args:
        use_pin_memory (bool): Whether or not the offloaded Tensor will be placed in pinned
            memory on the CPU. Pinned memory allows the Tensor to be moved back onto GPU more quickly
            but is a limited resource. Default: True.

        use_streams (bool): Whether or not to use streams for performance optimization where
            the communications get overlapped with the computation. Requires a torch build
            after torch-2.5.0.]. Default: True.

        max_fwd_stash_size (int): The maximum size of the forward stash, or the maximum number of
            consecutive activations to keep alive during the forward pass. This number must be at
            least 1. Keeping alive more activations will potentially allow more overlap between the
            communication and compute streams at the cost of increasing memory usage. Keeping alive
            fewer activations will conserve memory, but may cause poor overlap between the streams,
            increasing runtime. Default: 5.

        min_offload_size (int): The minimum number of bytes a Tensor must be in order to qualify
            for offloading. If the tensor is too small, we do not want to waste bandwidth and resources
            moving it to CPU and back. Default: 1024 bytes.

    Raises:
        ValueError: if max_fwd_stash_size is not at least 1.

    Example:
        >>> with OffloadActivations():
        >>>     logits = model(inputs)
        >>>     loss = ...
        >>>     loss.backward()
    """

    def __init__(
        self,
        use_pin_memory: bool = True,
        use_streams: bool = True,
        max_fwd_stash_size: int = 5,
        min_offload_size: int = 1024,
    ) -> None:

        self.use_streams: bool = use_streams

        self.min_tensor_size_bytes = (
            min_offload_size  # we don't want to bother with small tensors
        )
        self.tracker = (
            {}
        )  # tensor_id => (new_tensor, if_modified) ---> track what saved/offloaded tensors are where
        self.tensor_id: int = 0
        self.is_first_forward_call = True
        self.is_first_backward_call = True
        self.is_first_forward_pass = True

        # managing cpu memory
        self.use_pin_memory: bool = use_pin_memory
        self.virtual_memory_safe_pct = (
            60  # we should not exceed this percentage of memory
        )

        self.s0 = torch.cuda.default_stream()  # comp stream

        # for streaming
        if self.use_streams:
            self.s1 = torch.cuda.Stream()  # comms stream
            self.fwd_stash = {}  # tensor_id => (activation, ev1)
            if max_fwd_stash_size < 1:
                raise ValueError(
                    f"max_fwd_stash_size should be at least 1 but is {max_fwd_stash_size}"
                )
            self.max_fwd_stash_size = max_fwd_stash_size
            self.bwd_tensor_stash = {}  # tensor_id => activation
            self.bwd_ev_stash = {}  # tensor_id => ev0
            self.curr_graph_id = None
            self.curr_autograd_node = None

        # -------- platform util functions -------- #
        def verify_sufficient_virtual_memory():
            curr_pct = get_cpu_ram_pct()
            if curr_pct > self.virtual_memory_safe_pct:
                warn(
                    f"***** WARNING: {curr_pct=}% > {self.virtual_memory_safe_pct=}% of virtual memory used"
                )

        def get_cpu_ram_pct() -> float:
            # get the percentage of memory used by the system
            return psutil.virtual_memory().percent

        def get_tensor_id() -> int:
            # create a unique id for each tensor we are managing
            self.tensor_id += 1
            return self.tensor_id

        def get_num_bytes_tensor(x: torch.Tensor) -> int:
            # get the number of bytes in a tensor, for memory management purposes
            return (
                x.element_size() * x.nelement()
            )  # x.element_size() * x._base_storage().nbytes()

        # -------- core pack / unpack work -------- #
        def pack_tensor(activation: torch.Tensor) -> int:
            # activations are passed in during forward pass - from here we take over and return a unique id
            if self.is_first_forward_call:
                assert (
                    len(self.tracker) == 0
                ), "backward pass should have cleared tracker of all tensors"

                # set training phase trackers
                self.is_first_forward_call = False
                self.is_first_backward_call = True

            # query for basic tensor info
            num_bytes = get_num_bytes_tensor(activation)
            tensor_id = get_tensor_id()

            # only offload hefty bois if they're activations on CUDA (our heuristic
            # for that is to check if they're not params or buffers)!
            if (
                activation.is_cuda
                and num_bytes >= self.min_tensor_size_bytes
                and (
                    not isinstance(activation, torch.nn.Parameter)
                    and not isinstance(activation, torch.nn.Buffer)
                )
            ):
                if self.use_streams:
                    # First, sync back and dereference previously offloaded tensors
                    # as the offloading should be done sufficiently long ago.
                    for id in [k for k in self.fwd_stash.keys()]:
                        if id <= tensor_id - self.max_fwd_stash_size:
                            _, ev = self.fwd_stash[id]
                            self.s0.wait_event(ev)
                            del self.fwd_stash[id]
                        else:
                            break

                    # Sync in, offload, and add an event to sync back later
                    self.s1.wait_stream(self.s0)

                stream = self.s1 if self.use_streams else self.s0
                with torch.cuda.stream(stream):
                    try:
                        cpu_tensor = torch.empty_like(
                            activation, pin_memory=self.use_pin_memory, device="cpu"
                        )
                    except NotImplementedError as e:
                        if (
                            isinstance(activation, NF4Tensor)
                            and torchao.__version__ < "0.6.0.dev20240917"
                        ):
                            raise RuntimeError(
                                "Offloading NF4Tensors requires torchao-0.6.0.dev20240917 or later"
                            ) from e
                        raise e
                    cpu_tensor.copy_(activation, non_blocking=True)
                    self.tracker[tensor_id] = (
                        cpu_tensor,
                        True,
                    )  # True = (in future) modified

                if self.use_streams:
                    event = self.s1.record_event()

                    # Stash to keep activation alive til s1 is done
                    self.fwd_stash[tensor_id] = (activation, event)
            else:
                self.tracker[tensor_id] = (
                    activation,
                    False,
                )  # False = not modified, tensor is as is

            return tensor_id

        def unpack_tensor_single_stream(unpack_tensor_id: int) -> torch.Tensor:
            # backward pass - we are called with the tensor_id, which
            # we will use to retrieve the saved/offloaded tensor
            if self.is_first_backward_call:
                if self.is_first_forward_pass:
                    self.is_first_forward_pass = False
                    if self.use_pin_memory:
                        verify_sufficient_virtual_memory()

                self.is_first_backward_call = False
                self.is_first_forward_call = True

            assert (
                unpack_tensor_id in self.tracker
            ), f"untracked tensor with id {unpack_tensor_id}"

            maybe_gpu_tensor, modified = self.tracker[unpack_tensor_id]
            if modified:
                gpu_tensor = maybe_gpu_tensor.to("cuda", non_blocking=True)
                maybe_gpu_tensor = gpu_tensor

            # clear tensor from tracking
            del self.tracker[unpack_tensor_id]
            return maybe_gpu_tensor

        def unpack_tensor_with_streams(unpack_tensor_id: int) -> torch.Tensor:
            # backward pass - we are called with the tensor_id, which
            # we will use to retrieve the saved/offloaded tensor
            if self.is_first_backward_call:
                self.curr_graph_id = torch._C._current_graph_task_id()

                def wait_and_del_remaining_references() -> None:
                    for id in [k for k in self.bwd_tensor_stash.keys()]:
                        event = self.bwd_ev_stash[id]
                        self.s1.wait_event(event)
                        del self.bwd_tensor_stash[id]

                # Register a callback to the end of autograd to clean everything up
                torch.autograd.variable.Variable._execution_engine.queue_callback(
                    wait_and_del_remaining_references
                )

                if self.is_first_forward_pass:
                    self.is_first_forward_pass = False
                    if self.use_pin_memory:
                        verify_sufficient_virtual_memory()

                self.is_first_backward_call = False
                self.is_first_forward_call = True

            assert (
                unpack_tensor_id in self.tracker
            ), f"untracked tensor with id {unpack_tensor_id}"

            maybe_gpu_tensor, modified = self.tracker[unpack_tensor_id]
            if modified:
                # Get data on the current autograd node
                graph_id = torch._C._current_graph_task_id()
                node = torch._C._current_autograd_node()
                prev_node_ids = []

                # If we're on a new node, mark prev node's tensors to be freed later
                if graph_id == self.curr_graph_id and self.curr_autograd_node != node:
                    self.curr_autograd_node = node
                    prev_node_ids = [id for id in self.bwd_tensor_stash.keys()]

                brought_back_from_cpu = True
                if unpack_tensor_id in self.fwd_stash:
                    maybe_gpu_tensor = self.fwd_stash[unpack_tensor_id][0]
                    brought_back_from_cpu = False
                else:
                    # Kick off the process to bring tensors back
                    with torch.cuda.stream(self.s1):
                        gpu_tensor = maybe_gpu_tensor.to("cuda", non_blocking=True)
                        maybe_gpu_tensor = gpu_tensor

                    # Tell comp stream to wait for the info to be loaded before executing
                    self.s0.wait_stream(self.s1)

                    # Stash the tensor to keep memory alive until compute stream is complete
                    self.bwd_tensor_stash[unpack_tensor_id] = maybe_gpu_tensor

                # Note: [Track views of the unpacked]
                # Why do we get the use count of the unpacked tensor here? We want an
                # initial count to compare to later, during the post-hook of the
                # backward node, when we need to decide whether we're allowed to free
                # the tensor yet. In what obscure cases must we delay freeing the
                # tensor (and thus call record_stream)?
                # 1. Any of the outputs of the backward node is a view of the unpacked
                #    tensor.
                # 2. In the case that this unpacked tensor will be used in a
                #    checkpointed region, if one of the recomputed saved tensors ends
                #    up as a view of the unpacked tensor.
                # 3. The user abuses the system somehow and manually relies on the
                #    unpacked tensor to exist after the backward node has executed.
                storage_refcount = torch._C._storage_Use_Count(
                    maybe_gpu_tensor.untyped_storage()._cdata
                )

                def hook(outputs, inputs):
                    # create events for the current node inputs/outputs if they were streamed in
                    if brought_back_from_cpu:
                        # See Note: [Track views of the unpacked]
                        # IF any of the outputs is a view of the tensor, OR if a view of
                        # the tensor has been saved as a part of checkpoint's recompute
                        # process, OR the user has abusedly incurred a reference on the
                        # unpacked tensor, THEN the tensor might be used later and we
                        # cannot presume to delete it after only the current node is
                        # done! So we use our frenemy, record_stream, to ensure the
                        # Tensor stays unmessed with until it's done getting used in the
                        # compute stream (s0 here). Note that the con here is we introduce
                        # non-deterministic (thus higher) memory usage, but this case
                        # should not happen often.
                        unpacked_tensor = self.bwd_tensor_stash[unpack_tensor_id]
                        if (
                            torch._C._storage_Use_Count(
                                unpacked_tensor.untyped_storage()._cdata
                            )
                            > storage_refcount
                        ):
                            unpacked_tensor.record_stream(self.s0)
                            del self.bwd_tensor_stash[unpack_tensor_id]
                        else:
                            event = self.s0.record_event()
                            self.bwd_ev_stash[unpack_tensor_id] = event

                    # if there are still things in the fwd_stash, get rid of them as we're in bwd now
                    for id in [k for k in self.fwd_stash.keys()]:
                        _, ev = self.fwd_stash[id]
                        self.s0.wait_event(ev)
                        del self.fwd_stash[id]

                    # wait on prev node's events and del those
                    for id in prev_node_ids:
                        event = self.bwd_ev_stash[id]
                        self.s1.wait_event(event)
                        del self.bwd_tensor_stash[id]

                    return outputs

                node.register_hook(hook)

            # clear tensor from tracking
            del self.tracker[unpack_tensor_id]
            return maybe_gpu_tensor

        unpack_tensor = (
            unpack_tensor_with_streams
            if self.use_streams
            else unpack_tensor_single_stream
        )
        super().__init__(pack_tensor, unpack_tensor)


class NoOpManager(saved_tensors_hooks):
    """
    A saved_tensors_hook manager used to disable any other saved_tensors_hook manager
    applied before. This relies on the behavior that only the most recently registered
    saved_tensors_hook will run.

    One example usage is to opt a local region of code out of activations offloading,
    which is usually applied globally to best track state.
    """

    def __init__(self) -> None:
        def noop(tensor):
            return tensor

        super().__init__(noop, noop)


def get_act_offloading_ctx_manager(
    model: nn.Module, enable_activation_offloading: bool
) -> Union[OffloadActivations, contextlib.nullcontext]:
    """Returns the activation offloading context manager for the model, which will be
    a null context if enable_activation_offloading is False.

    If activation offloading is enabled, we return the OffloadActivations context manager.
    If activation offloading is disabled, we return a NoOpManager context manager.

    Args:
        model (nn.Module): the model to wrap with the activation offloading context manager.
        enable_activation_offloading (bool): whether or not to enable activation offloading
            for the model.

    Returns:
        contextlib.ContextDecorator: the activation offloading context manager for the model.

    Raises:
        NotImplementedError: If the model is a multimodal model and activation offloading is enabled.
    """
    if enable_activation_offloading:
        activations_handling_ctx = OffloadActivations()

        # Below is our hack to disable offloading the last output Linear in every
        # step, as the cost for offloading the activation and then soon after bringing
        # it back is expensive. Moreover, due to heuristics in our streaming API,
        # we actually use more memory if we offload it as it interferes with chunkedCE.
        output_head_detected = False
        noop_ctx = NoOpManager()

        if hasattr(model, "output"):
            if isinstance(model.output, nn.Module):
                model.output.register_forward_pre_hook(
                    lambda *args: noop_ctx.__enter__()
                )
                model.output.register_forward_hook(
                    lambda *args: noop_ctx.__exit__(), always_call=True
                )
                print("registering hooks for model.output ============ ")
                output_head_detected = True
            # ================================
            # ! TODO[flame] check if we need to detal with TiedLinear
            # The following code appears in `torchtune`
            # elif isinstance(model.output, TiedLinear):
            #     model.output.linear.register_forward_pre_hook(
            #         lambda *args: noop_ctx.__enter__()
            #     )
            #     model.output.linear.register_forward_hook(
            #         lambda *args: noop_ctx.__exit__(), always_call=True
            #     )
            #     output_head_detected = True

        if not output_head_detected:
            logger.warning(
                "During activation offloading, no output head was detected. "
                "If your model has an output head, it will be offloaded. "
                "This usually greatly slows training, given the large vocabulary size. "
                "To change this behavior, set your output head as model.output and make it "
                "an nn.Module."
            )

    else:
        activations_handling_ctx = contextlib.nullcontext()

    return activations_handling_ctx

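A minimal usage sketch of the helper above, mirroring the docstring example: large activations saved for backward are copied to (optionally pinned) CPU memory during the forward pass and streamed back during backward. The model, batch, and loss_fn names are placeholders; the actual training loop in flame/train.py may wire this up differently.

# Illustrative only: `model`, `batch`, and `loss_fn` are placeholders.
from flame.models.activation_offloading import get_act_offloading_ctx_manager

offload_ctx = get_act_offloading_ctx_manager(model, enable_activation_offloading=True)

with offload_ctx:
    # Tensors saved for backward inside this region are offloaded to CPU;
    # the unpack hooks bring them back onto the GPU when backward needs them.
    logits = model(batch)
    loss = loss_fn(logits, batch)
    loss.backward()
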
flame/models/fla.toml
ADDED
@@ -0,0 +1,67 @@
[model]
config = "fla-hub/transformer-1.3B-100B"
tokenizer_path = "fla-hub/transformer-1.3B-100B"

[job]
dump_folder = "exp"
print_args = true

[training]
batch_size = 32
seq_len = 2048
context_len = 2048
gradient_accumulation_steps = 1
steps = 20480
max_norm = 1.0
skip_nan_inf = true
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
tensor_parallel_degree = 1
compile = false
dataset = "HuggingFaceFW/fineweb-edu"
dataset_name = "default"
num_workers = 32
pin_memory = false
persistent_workers = false
prefetch_factor = 2
seed = 42
varlen = false

[optimizer]
name = "AdamW"
eps = 1e-15
lr = 3e-4

[lr_scheduler]
warmup_steps = 1024
decay_type = "cosine"
lr_min = 0.1

[checkpoint]
enable_checkpoint = true
folder = "checkpoint"
interval_type = "steps"
interval = 2048
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

[profiling]
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 512

[metrics]
log_freq = 32
enable_wandb = true

[experimental]
context_parallel_degree = 1
pipeline_parallel_degree = 1

[float8]
enable_fsdp_float8_all_gather = false
precompute_float8_dynamic_scale_for_fsdp = false

[activation_checkpoint]
mode = "none"

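For quick inspection of this config outside the trainer, a small standard-library sketch is shown below; how the flame/torchtitan entry point itself consumes the file is not shown in this diff, so this is only for reading the values.

# Illustrative only: inspect fla.toml with the standard library (Python 3.11+).
import tomllib

with open("flame/models/fla.toml", "rb") as f:
    cfg = tomllib.load(f)

tokens_per_step = cfg["training"]["batch_size"] * cfg["training"]["seq_len"]  # 32 * 2048 = 65536
print(cfg["model"]["config"], tokens_per_step)
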
flame/models/parallelize_fla.py
ADDED
|
@@ -0,0 +1,550 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# This file applies the PT-D parallelisms (except pipeline parallelism) and various
|
| 8 |
+
# training techniques (e.g. activation checkpointing and compile) to the Llama model.
|
| 9 |
+
|
| 10 |
+
from collections import defaultdict
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
from torch.distributed import DeviceMesh
|
| 15 |
+
from torch.distributed._composable.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard
|
| 16 |
+
from torch.distributed._composable.replicate import replicate
|
| 17 |
+
from torch.distributed._tensor import Replicate, Shard
|
| 18 |
+
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper as ptd_checkpoint_wrapper
|
| 19 |
+
from torch.distributed.tensor.parallel import (
|
| 20 |
+
ColwiseParallel,
|
| 21 |
+
PrepareModuleInput,
|
| 22 |
+
PrepareModuleOutput,
|
| 23 |
+
RowwiseParallel,
|
| 24 |
+
SequenceParallel,
|
| 25 |
+
parallelize_module
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
from fla.modules.fused_linear_cross_entropy import LinearLossParallel
|
| 29 |
+
from fla.modules.mlp import SwiGLULinearParallel
|
| 30 |
+
from fla.modules.parallel import PrepareModuleWeight
|
| 31 |
+
from torchtitan.config_manager import TORCH_DTYPE_MAP, JobConfig
|
| 32 |
+
from torchtitan.distributed.parallel_dims import ParallelDims
|
| 33 |
+
from torchtitan.tools.logging import logger
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def parallelize_fla(
|
| 37 |
+
model: nn.Module,
|
| 38 |
+
world_mesh: DeviceMesh,
|
| 39 |
+
parallel_dims: ParallelDims,
|
| 40 |
+
job_config: JobConfig,
|
| 41 |
+
):
|
| 42 |
+
"""
|
| 43 |
+
Apply tensor parallelism, activation checkpointing, torch.compile, and data
|
| 44 |
+
parallelism to the model.
|
| 45 |
+
|
| 46 |
+
NOTE: The passed-in model preferably should be on meta device. Otherwise,
|
| 47 |
+
the model must fit on GPU or CPU memory.
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
if parallel_dims.tp_enabled:
|
| 51 |
+
if (
|
| 52 |
+
job_config.experimental.enable_async_tensor_parallel
|
| 53 |
+
and not job_config.training.compile
|
| 54 |
+
):
|
| 55 |
+
raise RuntimeError("Async TP requires --training.compile")
|
| 56 |
+
enable_float8_linear = "float8" in job_config.model.converters
|
| 57 |
+
apply_tp(
|
| 58 |
+
model,
|
| 59 |
+
world_mesh["tp"],
|
| 60 |
+
loss_parallel=parallel_dims.loss_parallel_enabled,
|
| 61 |
+
enable_float8=enable_float8_linear,
|
| 62 |
+
enable_async_tp=job_config.experimental.enable_async_tensor_parallel,
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
if job_config.activation_checkpoint.mode != "none":
|
| 66 |
+
apply_ac(model, job_config.activation_checkpoint)
|
| 67 |
+
|
| 68 |
+
# turn on per-block compile after AC wrapping and before FSDP
|
| 69 |
+
if job_config.training.compile:
|
| 70 |
+
apply_compile(model)
|
| 71 |
+
|
| 72 |
+
if (
|
| 73 |
+
parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled
|
| 74 |
+
): # apply FSDP or HSDP, potentially with Context Parallel
|
| 75 |
+
if parallel_dims.dp_replicate_enabled:
|
| 76 |
+
dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
|
| 77 |
+
else:
|
| 78 |
+
dp_mesh_dim_names = ("dp_shard_cp",)
|
| 79 |
+
|
| 80 |
+
apply_fsdp(
|
| 81 |
+
model,
|
| 82 |
+
world_mesh[tuple(dp_mesh_dim_names)],
|
| 83 |
+
param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
|
| 84 |
+
reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
|
| 85 |
+
pp_enabled=parallel_dims.pp_enabled,
|
| 86 |
+
cpu_offload=job_config.training.enable_cpu_offload,
|
| 87 |
+
reshard_after_forward_policy=job_config.training.fsdp_reshard_after_forward,
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
if parallel_dims.dp_replicate_enabled:
|
| 91 |
+
logger.info("Applied HSDP to the model")
|
| 92 |
+
else:
|
| 93 |
+
logger.info("Applied FSDP to the model")
|
| 94 |
+
|
| 95 |
+
if parallel_dims.cp_enabled:
|
| 96 |
+
logger.info("Applied Context Parallel to the model")
|
| 97 |
+
|
| 98 |
+
if job_config.training.enable_cpu_offload:
|
| 99 |
+
logger.info("Applied CPU Offloading to the model")
|
| 100 |
+
elif parallel_dims.dp_replicate_enabled:
|
| 101 |
+
if world_mesh.ndim > 1:
|
| 102 |
+
raise RuntimeError("DDP has not supported > 1D parallelism")
|
| 103 |
+
apply_ddp(
|
| 104 |
+
model,
|
| 105 |
+
world_mesh,
|
| 106 |
+
enable_compile=job_config.training.compile,
|
| 107 |
+
enable_compiled_autograd=job_config.experimental.enable_compiled_autograd,
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
class TPPlan:
|
| 112 |
+
def __init__(
|
| 113 |
+
self,
|
| 114 |
+
model=None,
|
| 115 |
+
loss_parallel=False,
|
| 116 |
+
enable_float8=False,
|
| 117 |
+
):
|
| 118 |
+
self.model = model
|
| 119 |
+
self.loss_parallel = loss_parallel
|
| 120 |
+
self.enable_float8 = enable_float8
|
| 121 |
+
self.base_model_prefix = getattr(model, "base_model_prefix", "model")
|
| 122 |
+
|
| 123 |
+
# TODO(vkuzo): once float8 configuration supports delayed scaling,
|
| 124 |
+
# add a check here to enforce supported float8 all-gather configurations
|
| 125 |
+
# TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there
|
| 126 |
+
try:
|
| 127 |
+
from torchao.float8.float8_tensor_parallel import (
|
| 128 |
+
Float8ColwiseParallel,
|
| 129 |
+
Float8RowwiseParallel,
|
| 130 |
+
PrepareFloat8ModuleInput
|
| 131 |
+
)
|
| 132 |
+
except ImportError:
|
| 133 |
+
Float8ColwiseParallel = None
|
| 134 |
+
Float8RowwiseParallel = None
|
| 135 |
+
PrepareFloat8ModuleInput = None
|
| 136 |
+
if self.enable_float8 and Float8ColwiseParallel is not None:
|
| 137 |
+
self.rowwise_parallel = Float8RowwiseParallel
|
| 138 |
+
self.colwise_parallel = Float8ColwiseParallel
|
| 139 |
+
self.prepare_module_input = PrepareFloat8ModuleInput
|
| 140 |
+
self.prepare_module_output = PrepareModuleOutput
|
| 141 |
+
else:
|
| 142 |
+
self.rowwise_parallel = RowwiseParallel
|
| 143 |
+
self.colwise_parallel = ColwiseParallel
|
| 144 |
+
self.prepare_module_input = PrepareModuleInput
|
| 145 |
+
self.prepare_module_output = PrepareModuleOutput
|
| 146 |
+
|
| 147 |
+
@property
|
| 148 |
+
def model_plan(self):
|
| 149 |
+
plans = {
|
| 150 |
+
f"{self.base_model_prefix}.embeddings": RowwiseParallel(
|
| 151 |
+
input_layouts=Replicate(),
|
| 152 |
+
output_layouts=Shard(1),
|
| 153 |
+
),
|
| 154 |
+
f"{self.base_model_prefix}.norm": SequenceParallel(),
|
| 155 |
+
}
|
| 156 |
+
if self.loss_parallel:
|
| 157 |
+
plans.update(
|
| 158 |
+
{
|
| 159 |
+
"lm_head": ColwiseParallel(
|
| 160 |
+
input_layouts=Shard(1),
|
| 161 |
+
output_layouts=Shard(-1) if self.loss_parallel else Replicate(),
|
| 162 |
+
use_local_output=not self.loss_parallel,
|
| 163 |
+
),
|
| 164 |
+
}
|
| 165 |
+
)
|
| 166 |
+
else:
|
| 167 |
+
plans.update(
|
| 168 |
+
{
|
| 169 |
+
"lm_head": PrepareModuleWeight(layouts=Replicate()),
|
| 170 |
+
"criterion": LinearLossParallel(),
|
| 171 |
+
}
|
| 172 |
+
)
|
| 173 |
+
return plans
|
| 174 |
+
|
| 175 |
+
@property
|
| 176 |
+
def layer_plan(self):
|
| 177 |
+
return {
|
| 178 |
+
"attn_norm": SequenceParallel(),
|
| 179 |
+
**self.attn_plan,
|
| 180 |
+
"mlp_norm": SequenceParallel(),
|
| 181 |
+
**self.mlp_plan,
|
| 182 |
+
}
|
| 183 |
+
|
| 184 |
+
@property
|
| 185 |
+
def attn_plan(self):
|
| 186 |
+
raise NotImplementedError(
|
| 187 |
+
f"TP plans for token mixing layers of {self.model.config.model_type} not implemented"
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
@property
|
| 191 |
+
def mlp_plan(self):
|
| 192 |
+
return {
|
| 193 |
+
"mlp": self.prepare_module_input(
|
| 194 |
+
input_layouts=(Shard(1),),
|
| 195 |
+
desired_input_layouts=(Replicate(),),
|
| 196 |
+
),
|
| 197 |
+
"mlp.gate_proj": self.colwise_parallel(),
|
| 198 |
+
"mlp.up_proj": self.colwise_parallel(),
|
| 199 |
+
"mlp.down_proj": self.rowwise_parallel(output_layouts=Shard(1)),
|
| 200 |
+
"mlp.swiglu_linear": SwiGLULinearParallel(output_layouts=Shard(1)),
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
class TransformerTPPlan(TPPlan):
|
| 205 |
+
|
| 206 |
+
@property
|
| 207 |
+
def attn_plan(self):
|
| 208 |
+
return {
|
| 209 |
+
"attn": self.prepare_module_input(
|
| 210 |
+
input_kwarg_layouts={"hidden_states": Shard(1)},
|
| 211 |
+
desired_input_kwarg_layouts={"hidden_states": Replicate()},
|
| 212 |
+
),
|
| 213 |
+
"attn.q_proj": self.colwise_parallel(),
|
| 214 |
+
"attn.k_proj": self.colwise_parallel(),
|
| 215 |
+
"attn.v_proj": self.colwise_parallel(),
|
| 216 |
+
"attn.o_proj": self.rowwise_parallel(output_layouts=Shard(1)),
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
class GLATPPlan(TPPlan):
|
| 221 |
+
|
| 222 |
+
@property
|
| 223 |
+
def attn_plan(self):
|
| 224 |
+
return {
|
| 225 |
+
"attn": self.prepare_module_input(
|
| 226 |
+
input_kwarg_layouts={"hidden_states": Shard(1)},
|
| 227 |
+
desired_input_kwarg_layouts={"hidden_states": Replicate()},
|
| 228 |
+
),
|
| 229 |
+
"attn.q_proj": self.colwise_parallel(),
|
| 230 |
+
"attn.k_proj": self.colwise_parallel(),
|
| 231 |
+
"attn.v_proj": self.colwise_parallel(),
|
| 232 |
+
"attn.g_proj": self.colwise_parallel(),
|
| 233 |
+
"attn.gk_proj.0": PrepareModuleWeight(layouts=Replicate()),
|
| 234 |
+
"attn.gk_proj.1": self.colwise_parallel(),
|
| 235 |
+
"attn.g_norm": SequenceParallel(sequence_dim=-1),
|
| 236 |
+
"attn.o_proj": self.rowwise_parallel(output_layouts=Shard(1)),
|
| 237 |
+
}
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
TP_PLAN_MAP = {"transformer": TransformerTPPlan, "gla": GLATPPlan}
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def apply_tp(
|
| 244 |
+
model: nn.Module,
|
| 245 |
+
tp_mesh: DeviceMesh,
|
| 246 |
+
loss_parallel: bool,
|
| 247 |
+
enable_float8: bool,
|
| 248 |
+
enable_async_tp: bool,
|
| 249 |
+
):
|
| 250 |
+
"""Apply tensor parallelism."""
|
| 251 |
+
# 1. Parallelize the embedding and shard its outputs (which are the first
|
| 252 |
+
# transformer block's inputs)
|
| 253 |
+
# 2. Parallelize the root norm layer over the sequence dim
|
| 254 |
+
# 3. Parallelize the final linear output layer
|
| 255 |
+
tp_plan = TP_PLAN_MAP[model.config.model_type](
|
| 256 |
+
model, loss_parallel=loss_parallel, enable_float8=enable_float8
|
| 257 |
+
)
|
| 258 |
+
parallelize_module(model, tp_mesh, tp_plan.model_plan)
|
| 259 |
+
|
| 260 |
+
blocks = get_blocks(model)
|
| 261 |
+
if blocks is None:
|
| 262 |
+
logger.warning("No block found for tensor parallelism")
|
| 263 |
+
else:
|
| 264 |
+
for _, block in enumerate(blocks):
|
| 265 |
+
parallelize_module(
|
| 266 |
+
module=block,
|
| 267 |
+
device_mesh=tp_mesh,
|
| 268 |
+
parallelize_plan=tp_plan.layer_plan,
|
| 269 |
+
)
|
| 270 |
+
|
| 271 |
+
if enable_async_tp:
|
| 272 |
+
from torch.distributed._symmetric_memory import enable_symm_mem_for_group
|
| 273 |
+
|
| 274 |
+
torch._inductor.config._micro_pipeline_tp = True
|
| 275 |
+
enable_symm_mem_for_group(tp_mesh.get_group().group_name)
|
| 276 |
+
|
| 277 |
+
logger.info(
|
| 278 |
+
f"Applied {'Float8 ' if enable_float8 else ''}{'Async ' if enable_async_tp else ''}"
|
| 279 |
+
"Tensor Parallelism to the model"
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
# for selective op activation checkpointing
|
| 284 |
+
_save_list = {
|
| 285 |
+
torch.ops.aten.mm.default,
|
| 286 |
+
torch.ops.aten._scaled_dot_product_efficient_attention.default,
|
| 287 |
+
torch.ops.aten._scaled_dot_product_flash_attention.default,
|
| 288 |
+
torch.ops._c10d_functional.reduce_scatter_tensor.default,
|
| 289 |
+
# for low precision training, it's useful to always save
|
| 290 |
+
# the result of max, since the absolute maximum is
|
| 291 |
+
# used to compute the scaling factor for quantization.
|
| 292 |
+
torch.ops.aten.max.default,
|
| 293 |
+
}
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
def _apply_ac_to_block(module: nn.Module, ac_config):
|
| 297 |
+
valid_ac_modes = ("full", "selective")
|
| 298 |
+
if ac_config.mode not in valid_ac_modes:
|
| 299 |
+
raise ValueError(
|
| 300 |
+
f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}"
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
if ac_config.mode == "full":
|
| 304 |
+
return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
|
| 305 |
+
|
| 306 |
+
assert ac_config.mode == "selective", f"{ac_config.mode}"
|
| 307 |
+
use_op_sac = ac_config.selective_ac_option == "op"
|
| 308 |
+
use_layer_sac = ac_config.selective_ac_option.isdigit()
|
| 309 |
+
if not use_op_sac and not use_layer_sac:
|
| 310 |
+
raise ValueError(
|
| 311 |
+
f"Invalid selective AC option: {ac_config.selective_ac_option}. "
|
| 312 |
+
f"Valid options: 'op' or a positive int representing layer frequency"
|
| 313 |
+
)
|
| 314 |
+
if use_op_sac:
|
| 315 |
+
from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts
|
| 316 |
+
|
| 317 |
+
def _get_custom_policy(meta):
|
| 318 |
+
def _custom_policy(ctx, func, *args, **kwargs):
|
| 319 |
+
mode = "recompute" if ctx.is_recompute else "forward"
|
| 320 |
+
mm_count_key = f"{mode}_mm_count"
|
| 321 |
+
if func == torch.ops.aten.mm.default:
|
| 322 |
+
meta[mm_count_key] += 1
|
| 323 |
+
# Saves output of all compute ops, except every second mm
|
| 324 |
+
to_save = func in _save_list and not (
|
| 325 |
+
func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
|
| 326 |
+
)
|
| 327 |
+
return (
|
| 328 |
+
CheckpointPolicy.MUST_SAVE
|
| 329 |
+
if to_save
|
| 330 |
+
else CheckpointPolicy.PREFER_RECOMPUTE
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
return _custom_policy
|
| 334 |
+
|
| 335 |
+
def selective_checkpointing_context_fn():
|
| 336 |
+
meta = defaultdict(int)
|
| 337 |
+
return create_selective_checkpoint_contexts(_get_custom_policy(meta))
|
| 338 |
+
|
| 339 |
+
return ptd_checkpoint_wrapper(
|
| 340 |
+
module,
|
| 341 |
+
context_fn=selective_checkpointing_context_fn,
|
| 342 |
+
preserve_rng_state=False,
|
| 343 |
+
)
|
| 344 |
+
elif use_layer_sac:
|
| 345 |
+
# Checkpoint every `ac_freq` of the modules passed to this function
|
| 346 |
+
ac_freq = int(ac_config.selective_ac_option)
|
| 347 |
+
ptd_checkpoint_wrapper.__dict__.setdefault("_count", 0)
|
| 348 |
+
ptd_checkpoint_wrapper._count += 1
|
| 349 |
+
if not ac_freq or ptd_checkpoint_wrapper._count % ac_freq == 0:
|
| 350 |
+
return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
|
| 351 |
+
else:
|
| 352 |
+
return module
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
def apply_ac(model: nn.Module, ac_config):
|
| 356 |
+
"""Apply activation checkpointing to the model."""
|
| 357 |
+
blocks = get_blocks(model)
|
| 358 |
+
if blocks is None:
|
| 359 |
+
logger.warning("No block found for activation checkpointing")
|
| 360 |
+
return
|
| 361 |
+
|
| 362 |
+
for layer_id, block in blocks.named_children():
|
| 363 |
+
block = _apply_ac_to_block(block, ac_config)
|
| 364 |
+
blocks.register_module(layer_id, block)
|
| 365 |
+
|
| 366 |
+
logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
def apply_compile(model: nn.Module):
|
| 370 |
+
"""
|
| 371 |
+
Apply torch.compile to each block, which makes compilation efficient due to
|
| 372 |
+
repeated structure. Alternatively one can compile the whole model (after applying DP).
|
| 373 |
+
"""
|
| 374 |
+
|
| 375 |
+
blocks = get_blocks(model)
|
| 376 |
+
if blocks is None:
|
| 377 |
+
logger.warning("No block found for torch.compile")
|
| 378 |
+
else:
|
| 379 |
+
for layer_id, block in blocks.named_children():
|
| 380 |
+
block = torch.compile(block)
|
| 381 |
+
blocks.register_module(layer_id, block)
|
| 382 |
+
logger.info("Compiling each block with torch.compile")
|
| 383 |
+
|
| 384 |
+
real_model = get_model(model)
|
| 385 |
+
|
| 386 |
+
logger.info("Compiling the embedding, norm, and lm_head layers with torch.compile")
|
| 387 |
+
embeddings_key = get_components_name(real_model, "tok_embeddings")
|
| 388 |
+
if embeddings_key is not None:
|
| 389 |
+
embeddings = torch.compile(getattr(real_model, embeddings_key), fullgraph=True)
|
| 390 |
+
real_model.register_module(embeddings_key, embeddings)
|
| 391 |
+
|
| 392 |
+
norm_key = get_components_name(real_model, "norm")
|
| 393 |
+
if norm_key is not None:
|
| 394 |
+
norm = torch.compile(getattr(real_model, norm_key), fullgraph=True)
|
| 395 |
+
real_model.register_module(norm_key, norm)
|
| 396 |
+
|
| 397 |
+
lm_head_key = get_components_name(model, "lm_head")
|
| 398 |
+
if lm_head_key is not None:
|
| 399 |
+
lm_head = torch.compile(getattr(model, lm_head_key), fullgraph=True)
|
| 400 |
+
model.register_module(lm_head_key, lm_head)
|
| 401 |
+
|
| 402 |
+
logger.info("Compiling the entire model with torch.compile")
|
| 403 |
+
model = torch.compile(model)
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
def apply_fsdp(
|
| 407 |
+
model: nn.Module,
|
| 408 |
+
dp_mesh: DeviceMesh,
|
| 409 |
+
param_dtype: torch.dtype,
|
| 410 |
+
reduce_dtype: torch.dtype,
|
| 411 |
+
pp_enabled: bool,
|
| 412 |
+
cpu_offload: bool = False,
|
| 413 |
+
reshard_after_forward_policy: str = "default",
|
| 414 |
+
):
|
| 415 |
+
"""
|
| 416 |
+
Apply data parallelism (via FSDP2) to the model.
|
| 417 |
+
|
| 418 |
+
Args:
|
| 419 |
+
model (nn.Module): The model to apply data parallelism to.
|
| 420 |
+
dp_mesh (DeviceMesh): The device mesh to use for data parallelism.
|
| 421 |
+
param_dtype (torch.dtype): The data type to use for model parameters.
|
| 422 |
+
        reduce_dtype (torch.dtype): The data type to use for reduction operations.
        pp_enabled (bool): Whether pipeline parallelism is enabled.
        cpu_offload (bool, optional): Whether to offload model parameters to CPU. Defaults to False.
        reshard_after_forward_policy (str, optional):
            The policy to use for resharding after the forward pass. Defaults to "default".
            Other options: "never", "always".
            - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios.
            - "always" will enable `reshard_after_forward` for all forward passes.
            - "never" will disable `reshard_after_forward` for all forward passes.

    """
    mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
    fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
    if cpu_offload:
        fsdp_config["offload_policy"] = CPUOffloadPolicy()

    blocks = get_blocks(model)
    if blocks is None:
        logger.warning("No block found for FSDP")
    else:
        total_blocks = len(blocks)
        for layer_id, block in enumerate(blocks):
            if reshard_after_forward_policy == "always":
                reshard_after_forward = True
            elif reshard_after_forward_policy == "never":
                reshard_after_forward = False
            elif reshard_after_forward_policy == "default":
                if pp_enabled:
                    # For PP, do not reshard after forward to avoid per-microbatch
                    # all-gathers, which can be expensive and non-overlapped
                    reshard_after_forward = False
                else:
                    # As an optimization, do not reshard after forward for the last
                    # transformer block since FSDP would prefetch it immediately
                    reshard_after_forward = int(layer_id) < total_blocks - 1
            else:
                raise ValueError(
                    f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}."
                )
            fully_shard(
                block,
                **fsdp_config,
                reshard_after_forward=reshard_after_forward,
            )

    fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)


def apply_ddp(
    model: nn.Module,
    dp_mesh: DeviceMesh,
    enable_compile: bool,
    enable_compiled_autograd: bool,
):
    if enable_compile:
        if enable_compiled_autograd:
            torch._dynamo.config.optimize_ddp = (
                "python_reducer_without_compiled_forward"
            )
        else:
            torch._dynamo.config.optimize_ddp = "ddp_optimizer"

    replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)

    logger.info("Applied DDP to the model")


def get_model(model):
    base_model_prefix = getattr(model, "base_model_prefix", "model")
    if not hasattr(model, base_model_prefix):
        return None
    model = getattr(model, base_model_prefix)
    return model


def get_blocks(model):
    # TODO[flame]: adapt for networks that do not use a 'layers' attribute
    model = get_model(model)
    if not hasattr(model, "layers"):
        logger.warning('no "layers" attribute found in model')
        return None
    return model.layers


def get_components_name(model, component_name):
    """
    We try to catch the tok_embeddings, norm and lm_head layers.
    We do not catch the layer names inside the blocks; for blocks see `get_blocks`.
    We assume the model has the following structure:
        LlamaForCausalLM:
            Model:
                embed_tokens,
                layers,
                norm,
            lm_head
    ***
    So, to search for 'tok_embeddings' and 'norm' we need to pass `get_model(model)`,
    and for 'lm_head' we need to pass `model`.
    ***
    """

    if component_name == "tok_embeddings":
        if hasattr(model, "tok_embeddings"):
            return "tok_embeddings"
        elif hasattr(model, "embed_tokens"):
            return "embed_tokens"
        elif hasattr(model, "embeddings"):
            return "embeddings"
        else:
            logger.warning("No tok_embeddings found in model")
            return None

    elif component_name == "norm":
        if hasattr(model, "norm"):
            return "norm"
        elif hasattr(model, "norms"):
            return "norms"
        elif hasattr(model, "layernorm"):
            return "layernorm"
        else:
            logger.warning("No norm found in model")
            return None

    elif component_name == "lm_head":
        if hasattr(model, "lm_head"):
            return "lm_head"
        else:
            logger.warning("No lm_head found in model")
            return None

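A quick way to see what these helpers return (a hedged sketch, not part of the upload; the tiny config below is illustrative only): on a stock HuggingFace LlamaForCausalLM, get_model follows base_model_prefix to the backbone, get_blocks returns its layer list, and get_components_name resolves the attribute names used above.

from transformers import LlamaConfig, LlamaForCausalLM
from flame.models.parallelize_fla import get_blocks, get_components_name, get_model

config = LlamaConfig(vocab_size=1000, hidden_size=64, intermediate_size=128,
                     num_hidden_layers=2, num_attention_heads=4)
lm = LlamaForCausalLM(config)

backbone = get_model(lm)   # follows lm.base_model_prefix ("model") to lm.model
blocks = get_blocks(lm)    # backbone.layers, an nn.ModuleList with 2 blocks
print(get_components_name(backbone, "tok_embeddings"))  # "embed_tokens"
print(get_components_name(backbone, "norm"))            # "norm"
print(get_components_name(lm, "lm_head"))               # "lm_head"
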
flame/models/pipeline_fla.py
ADDED
@@ -0,0 +1,162 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# This file applies the PT-D pipeline parallelism to the Llama model.

import copy
from typing import Callable, Optional, Union

import torch
import torch.nn as nn
from torch.distributed import DeviceMesh
from torch.distributed.pipelining import PipelineStage
from torch.distributed.pipelining.schedules import ScheduleZBVZeroBubble, _PipelineSchedule, get_schedule_class
from transformers import PretrainedConfig

from flame.models.parallelize_fla import get_blocks, get_components_name, get_model
from torchtitan.config_manager import JobConfig
from torchtitan.distributed.parallel_dims import ParallelDims
from torchtitan.distributed.pipeline import build_pipeline_schedule, generate_split_points, stage_ids_this_rank
from torchtitan.tools.logging import logger

DeviceType = Union[int, str, torch.device]


def pipeline_fla(
    model: nn.Module,
    pp_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
    device: DeviceType,
    model_config: PretrainedConfig,
    loss_fn: Callable[..., torch.Tensor],
) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]:
    stages, models = pipeline_fla_manual_split(
        model, pp_mesh, parallel_dims, job_config, device, model_config
    )

    pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)

    # This is used in the train loop to determine whether to pass in the input_ids and labels
    has_first_stage = False
    has_last_stage = False
    for stage in stages:
        if stage.is_first:
            has_first_stage = True
        if stage.is_last:
            has_last_stage = True

    return pp_schedule, models, has_first_stage, has_last_stage


def pipeline_fla_manual_split(
    whole_model: nn.Module,
    pp_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
    device: DeviceType,
    model_config: PretrainedConfig,
) -> tuple[list[PipelineStage], list[nn.Module]]:
    """
    This API extracts one torch.nn.Module object for the part of the model configured to run inside this stage.

    It wraps the model chunk in a ManualPipelineStage object and returns both the stage and model objects.

    The stage object is used to create a pipeline schedule, and the model object can be used for applying SPMD
    parallelism.
    """
    pp_rank = pp_mesh.get_local_rank()
    pp_size = pp_mesh.size()

    splits = (
        job_config.experimental.pipeline_parallel_split_points
        or generate_split_points(
            job_config, parallel_dims.pp, model_config.num_hidden_layers
        )
    )

    def _build_stage(
        stage_idx: int,
        start_layer: Optional[str],
        stop_layer: Optional[str],
        is_first: bool = False,
        is_last: bool = False,
    ) -> tuple[PipelineStage, nn.Module]:
        model = copy.deepcopy(whole_model)
        if not is_first:
            # we do `model.tok_embeddings = None` here
            real_model = get_model(model)
            tok_embeddings_name = get_components_name(real_model, "tok_embeddings")
            setattr(real_model, tok_embeddings_name, None)

        drop_layers = start_layer is not None
        # Get the module dictionary from get_blocks(model)
        # and create a list of keys before modifying the dictionary
        module_dict = get_blocks(model)._modules  # Store reference
        layer_names = list(module_dict.keys())

        # Iterate over the list of keys instead of `_modules.items()`
        for name in layer_names:
            # Dynamically determine the prefix (blocks.* or layers.*)
            prefix = start_layer.split(".")[0] if start_layer else "layers"
            layer_name = f"{prefix}.{name}"  # Construct the correct name format

            # Ensure `drop_layers` activation is based on actual naming
            if layer_name == start_layer:
                drop_layers = False
            if layer_name == stop_layer:
                drop_layers = True

            # Delete the layer if drop_layers is active
            if drop_layers:
                del module_dict[name]  # Safe deletion from the stored dictionary

        if not is_last:
            # we do `model.norm = None` and `model.output = None`
            real_model = get_model(model)
            norm_name = get_components_name(real_model, "norm")
            setattr(real_model, norm_name, None)

            head_name = get_components_name(model, "lm_head")
            setattr(model, head_name, None)

        stage = PipelineStage(
            model,
            stage_idx,
            num_stages,
            device,
            group=pp_mesh.get_group("pp"),
        )
        return stage, model

    num_stages = len(splits) + 1
    stage_idx = pp_rank

    stages = []
    models = []

    schedule_class = get_schedule_class(
        job_config.experimental.pipeline_parallel_schedule
    )
    style = "v" if schedule_class == ScheduleZBVZeroBubble else "loop"

    for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style=style):
        start_layer = splits[stage_idx - 1] if stage_idx > 0 else None
        stop_layer = splits[stage_idx] if stage_idx < num_stages - 1 else None
        stage, model_chunk = _build_stage(
            stage_idx,
            start_layer,
            stop_layer,
            is_first=stage_idx == 0,
            is_last=stage_idx == num_stages - 1,
        )
        logger.info(
            f"PP rank {pp_rank} is building stage_idx {stage_idx}"
            f" with start_layer {start_layer}, stop_layer {stop_layer}"
        )
        stages.append(stage)
        models.append(model_chunk)
    return stages, models

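To make the splitting rule in `_build_stage` concrete, here is a hedged, torch-free sketch (split points and layer count are illustrative): blocks before `start_layer` and from `stop_layer` onwards are deleted, so each stage keeps a contiguous slice of layers.

splits = ["layers.3", "layers.6"]   # illustrative split points for a toy 8-layer model
num_stages = len(splits) + 1
for stage_idx in range(num_stages):
    start = splits[stage_idx - 1] if stage_idx > 0 else None
    stop = splits[stage_idx] if stage_idx < num_stages - 1 else None
    drop = start is not None
    kept = []
    for name in (f"layers.{i}" for i in range(8)):
        if name == start:
            drop = False
        if name == stop:
            drop = True
        if not drop:
            kept.append(name)
    print(stage_idx, kept)
# stage 0 keeps layers 0-2, stage 1 keeps layers 3-5, stage 2 keeps layers 6-7
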
flame/tools/__init__.py
ADDED
|
File without changes
|
flame/tools/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (136 Bytes). View file
|
|
|
flame/tools/__pycache__/utils.cpython-312.pyc
ADDED
|
Binary file (2.14 kB). View file
|
|
|
flame/tools/utils.py
ADDED
@@ -0,0 +1,41 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torch import nn
from torchtitan.tools.logging import logger


def get_nparams_and_flops(model: nn.Module, model_config, seq_len: int) -> tuple[int, int]:
    nparams = sum(p.numel() for p in model.parameters())
    nparams_embedding = sum(
        sum(p.numel() for p in m.parameters())
        for m in model.children()
        if isinstance(m, nn.Embedding)
    )

    if hasattr(model_config, "num_heads"):
        num_heads = model_config.num_heads
    elif hasattr(model_config, "num_attention_heads"):
        num_heads = model_config.num_attention_heads
    else:
        num_heads = 1
        logger.warning("num_heads not found in model_config, defaulting to 1.")

    l, h, q, t = (
        model_config.num_hidden_layers,
        num_heads,
        model_config.hidden_size // num_heads,
        seq_len,
    )
    # Reasoning behind the factor of 12 for the self-attention part of the formula:
    # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6)
    # 2. the flash attention does 1 more matmul recomputation in the backward
    #    but recomputation should not be counted in calculating MFU (+0)
    # 3. each matmul performs 1 multiplication and 1 addition (*2)
    # 4. we follow the convention and do not account for sparsity in causal attention
    num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t

    return nparams, num_flops_per_token

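Plugging illustrative (not actual) sizes into the formula above shows the scale of the two terms: the dense-parameter term 6 * (nparams - nparams_embedding) and the attention term 12 * l * h * q * t.

nparams, nparams_embedding = 1_300_000_000, 130_000_000   # illustrative ~1.3B model
l, h, q, t = 24, 16, 2048 // 16, 4096                     # layers, heads, head dim, seq len
num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
print(num_flops_per_token)  # 7_020_000_000 + 2_415_919_104 = 9_435_919_104 (~9.4 GFLOPs/token)
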
flame/utils/__init__.py
ADDED
|
File without changes
|
flame/utils/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (136 Bytes). View file
|
|
|
flame/utils/__pycache__/checkpoint.cpython-312.pyc
ADDED
|
Binary file (4.07 kB). View file
|
|
|
flame/utils/__pycache__/convert_dcp_to_hf.cpython-312.pyc
ADDED
|
Binary file (3.73 kB). View file
|
|
|
flame/utils/__pycache__/convert_hf_to_dcp.cpython-312.pyc
ADDED
|
Binary file (1.92 kB). View file
|
|
|
flame/utils/__pycache__/hf_utils.cpython-312.pyc
ADDED
|
Binary file (4.46 kB). View file
|
|
|
flame/utils/checkpoint.py
ADDED
@@ -0,0 +1,50 @@
import os
import glob
import re
import shutil
from torchtitan.tools.logging import logger


def cleanup_local_checkpoints(checkpoint_dir: str, keep_latest_k: int):
    """Removes older checkpoint directories locally, keeping only the latest k for both DCP and HF formats."""
    if keep_latest_k <= 0:
        return  # Keep all checkpoints

    logger.info(f"Cleaning up local checkpoints in {checkpoint_dir}, keeping latest {keep_latest_k}")

    # Cleanup DCP checkpoints (step-*)
    dcp_checkpoints = sorted(
        glob.glob(os.path.join(checkpoint_dir, "step-*")),
        key=lambda x: int(re.search(r"step-(\d+)", os.path.basename(x)).group(1)) if re.search(r"step-(\d+)", os.path.basename(x)) and not x.endswith("-hf") else -1,
        reverse=True
    )
    # Filter out HF format directories
    dcp_checkpoints = [d for d in dcp_checkpoints if not d.endswith("-hf")]

    if len(dcp_checkpoints) > keep_latest_k:
        checkpoints_to_delete = dcp_checkpoints[keep_latest_k:]
        logger.info(f"Deleting {len(checkpoints_to_delete)} old DCP checkpoints: {[os.path.basename(c) for c in checkpoints_to_delete]}")
        for ckpt_path in checkpoints_to_delete:
            if os.path.isdir(ckpt_path):  # Ensure it's a directory
                try:
                    shutil.rmtree(ckpt_path)
                except OSError as e:
                    logger.error(f"Error removing directory {ckpt_path}: {e}")

    # Cleanup HF checkpoints (step-*-hf)
    hf_checkpoints = sorted(
        glob.glob(os.path.join(checkpoint_dir, "step-*-hf")),
        key=lambda x: int(re.search(r"step-(\d+)-hf", os.path.basename(x)).group(1)) if re.search(r"step-(\d+)-hf", os.path.basename(x)) else -1,
        reverse=True
    )

    if len(hf_checkpoints) > keep_latest_k:
        checkpoints_to_delete = hf_checkpoints[keep_latest_k:]
        logger.info(f"Deleting {len(checkpoints_to_delete)} old HF checkpoints: {[os.path.basename(c) for c in checkpoints_to_delete]}")
        for ckpt_path in checkpoints_to_delete:
            if os.path.isdir(ckpt_path):  # Ensure it's a directory
                try:
                    shutil.rmtree(ckpt_path)
                except OSError as e:
                    logger.error(f"Error removing directory {ckpt_path}: {e}")

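A hedged usage sketch (directory names are illustrative): with keep_latest_k=2, only the two newest `step-*` DCP directories and the two newest `step-*-hf` directories survive.

from flame.utils.checkpoint import cleanup_local_checkpoints

# exp/my-run/checkpoint/ contains step-1000, step-2000, step-3000, step-2000-hf, step-3000-hf
cleanup_local_checkpoints("exp/my-run/checkpoint", keep_latest_k=2)
# removes step-1000; keeps step-2000, step-3000, step-2000-hf, step-3000-hf
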
flame/utils/convert_dcp_to_hf.py
ADDED
@@ -0,0 +1,66 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

import argparse
import io
import os
import tempfile
from datetime import timedelta

import torch
import torch.serialization
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

import fla  # noqa
from torchtitan.tools.logging import init_logger, logger


@torch.inference_mode()
def save_pretrained(
    path: str,
    step: int,
    config: str,
    tokenizer: str
):
    logger.info(f"Loading the config from {config}")
    config = AutoConfig.from_pretrained(config, trust_remote_code=True)

    logger.info(f"Saving the config to {path}")
    config.save_pretrained(path)
    logger.info(f"Loading the tokenizer from {tokenizer}")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True)
    logger.info(f"Saving the tokenizer to {path}")
    tokenizer.save_pretrained(path)

    with tempfile.TemporaryDirectory() as tmpdir:
        # base_checkpoint_dir = os.path.dirname(path)
        base_checkpoint_dir = path
        checkpoint = os.path.join(base_checkpoint_dir, f'checkpoint/step-{step}')
        checkpoint_path = os.path.join(tmpdir, 'checkpoint.pt')
        logger.info(f"Saving the distributed checkpoint to {checkpoint_path}")
        dcp_to_torch_save(checkpoint, checkpoint_path)

        logger.info(f"Initializing the model from config\n{config}")
        model = AutoModelForCausalLM.from_config(config)
        logger.info(model)
        logger.info("Loading state dict from the checkpoint")

        # Add datetime.timedelta and io.BytesIO to safe globals
        torch.serialization.add_safe_globals([timedelta, io.BytesIO])
        # torch.load now with default weights_only=True will work
        model.load_state_dict(torch.load(checkpoint_path, map_location='cpu')['model'])

        logger.info(f"Saving the model to {path}")
        model.save_pretrained(path)


if __name__ == "__main__":
    init_logger()
    parser = argparse.ArgumentParser("Convert DCP format model weights to huggingface-style.")
    parser.add_argument("--path", type=str, required=True)
    parser.add_argument("--step", type=int, required=True)
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--tokenizer", type=str, required=True)
    args = parser.parse_args()
    save_pretrained(args.path, args.step, args.config, args.tokenizer)

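A hedged example of driving the converter from Python (the path and step are illustrative; the config and tokenizer ids echo the ones used elsewhere in this run). The DCP checkpoint is expected under `<path>/checkpoint/step-<step>`.

from flame.utils.convert_dcp_to_hf import save_pretrained

save_pretrained(
    path="exp/my-run",
    step=5000,
    config="configs/top_transformer_1B.json",
    tokenizer="fla-hub/transformer-1.3B-100B",
)
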
flame/utils/convert_hf_to_dcp.py
ADDED
@@ -0,0 +1,34 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

import argparse
from pathlib import Path

import torch
import torch.distributed.checkpoint as DCP
from transformers import AutoModelForCausalLM

import fla  # noqa
from torchtitan.tools.logging import init_logger, logger


@torch.inference_mode()
def convert_hf_weights(model: str, checkpoint: Path):
    logger.info(f"Loading model from {model}")
    model = AutoModelForCausalLM.from_pretrained(model)
    state_dict = model.state_dict()

    logger.info(f"Writing to DCP at '{checkpoint}'")
    checkpoint.mkdir(parents=True, exist_ok=True)
    storage_writer = DCP.filesystem.FileSystemWriter(checkpoint, thread_count=8)
    DCP.save({"model": state_dict}, storage_writer=storage_writer)


if __name__ == "__main__":
    init_logger()
    parser = argparse.ArgumentParser(description="Convert huggingface-style model weights to DCP format.")
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--checkpoint", type=Path, required=True)
    args = parser.parse_args()

    convert_hf_weights(args.model, args.checkpoint)

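And the reverse direction, as a hedged sketch (model id and target directory are illustrative): HF weights are written out as a DCP checkpoint directory that the trainer can resume from.

from pathlib import Path
from flame.utils.convert_hf_to_dcp import convert_hf_weights

convert_hf_weights("fla-hub/transformer-1.3B-100B", Path("exp/my-run/checkpoint/step-0"))
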
flame/utils/hf_utils.py
ADDED
@@ -0,0 +1,77 @@
import os
import re
from huggingface_hub import HfApi, HfFolder, logging as hf_logging, create_repo
from torchtitan.tools.logging import logger

def upload_checkpoint_to_hf(
    local_path: str,
    step: int,
    hf_repo_id_for_run: str,
    hf_keep_latest_k: int,
    upload_format: str
):
    """Uploads a checkpoint directory to HF Hub and manages retention."""
    if not os.path.isdir(local_path):
        logger.error(f"Local path for upload does not exist or is not a directory: {local_path}")
        return

    api = HfApi()
    token = HfFolder.get_token()
    if not token:
        logger.warning("Hugging Face Hub token not found. Skipping upload. Login via `huggingface-cli login` or set HF_TOKEN.")
        return

    # --- Ensure the specific repository for this run exists ---
    try:
        logger.info(f"Ensuring repository {hf_repo_id_for_run} exists...")
        # Use create_repo which handles creation only if it doesn't exist
        create_repo(repo_id=hf_repo_id_for_run, token=token, repo_type="model", exist_ok=True)
        logger.info(f"Repository {hf_repo_id_for_run} ensured.")
    except Exception as e:
        logger.error(f"Failed to create or ensure repository {hf_repo_id_for_run}: {e}", exc_info=True)
        return  # Stop if repo interaction fails

    commit_message = f"Upload {upload_format.upper()} checkpoint step {step}"
    path_in_repo = f"step-{step}"

    logger.info(f"Uploading {local_path} to {hf_repo_id_for_run}/{path_in_repo} on Hugging Face Hub...")
    try:
        api.upload_folder(
            folder_path=local_path,
            path_in_repo=path_in_repo,
            repo_id=hf_repo_id_for_run,
            repo_type="model",
            commit_message=commit_message,
            token=token,
        )
        logger.info(f"Successfully uploaded step {step} to {hf_repo_id_for_run}.")
    except Exception as e:
        logger.error(f"Failed to upload checkpoint step {step} to {hf_repo_id_for_run}: {e}", exc_info=True)
    if hf_keep_latest_k > 0:
        logger.info(f"Cleaning up old checkpoints on {hf_repo_id_for_run}, keeping latest {hf_keep_latest_k}")
        try:
            repo_files = api.list_repo_tree(hf_repo_id_for_run, repo_type="model", token=token, recursive=False)
            step_folders = [
                item.path for item in repo_files
                if item.path.startswith("step-") and item.path[5:].isdigit()
            ]

            step_folders.sort(key=lambda x: int(x.split('-')[1]), reverse=True)

            if len(step_folders) > hf_keep_latest_k:
                folders_to_delete = step_folders[hf_keep_latest_k:]
                logger.info(f"Found {len(step_folders)} checkpoints on Hub. Deleting {len(folders_to_delete)} older ones: {folders_to_delete}")
                for folder in folders_to_delete:
                    # Deleting requires repo_id, path_in_repo, and token
                    api.delete_folder(
                        repo_id=hf_repo_id_for_run,
                        path_in_repo=folder,
                        repo_type="model",
                        commit_message=f"Delete old checkpoint {folder}",
                        token=token
                    )
                logger.info("Hub cleanup complete.")
            else:
                logger.info("No old checkpoints found on Hub to delete.")
        except Exception as e:
            logger.error(f"Error during Hub checkpoint cleanup for {hf_repo_id_for_run}: {e}", exc_info=True)

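A hedged usage sketch (repo id and local path are illustrative): it assumes a valid Hub token from `huggingface-cli login` or HF_TOKEN, uploads the folder under `step-<step>`, and then prunes all but the newest k `step-*` folders in the run's repo.

from flame.utils.hf_utils import upload_checkpoint_to_hf

upload_checkpoint_to_hf(
    local_path="exp/my-run/checkpoint/step-5000-hf",
    step=5000,
    hf_repo_id_for_run="username/my-run-checkpoints",
    hf_keep_latest_k=2,
    upload_format="hf",
)
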
logs/none_ewbp5xc1/attempt_0/1/stderr.log
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_1024/rank0_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_1024/rank1_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_1024/rank5_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_1024/rank6_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_1024/rank7_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_1536/rank4_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_20992/rank5_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_23552/rank6_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_2560/rank5_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_2560/rank7_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_29696/rank2_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_29696/rank6_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_30720/rank6_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_3584/rank0_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_3584/rank4_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_3584/rank5_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_3584/rank7_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tb/20250901-0749/wandb/run-20250901_074914-top_transformer-top.code.1B.batch16.seqlen4096.context4096.warmup400.update1.steps40000.lr5e-5.cosine-202509010747/files/wandb-metadata.json
ADDED
@@ -0,0 +1,146 @@
{
    "os": "Linux-6.8.0-62-generic-x86_64-with-glibc2.39",
    "python": "CPython 3.12.11",
    "startedAt": "2025-09-01T07:49:14.031224Z",
    "args": [
        "--job.config_file",
        "flame/models/fla.toml",
        "--job.dump_folder",
        "exp/top.code.1B.batch16.seqlen4096.context4096.warmup400.update1.steps40000.lr5e-5.cosine",
        "--model.config",
        "configs/top_transformer_1B.json",
        "--model.tokenizer_path",
        "fla-hub/transformer-1.3B-100B",
        "--optimizer.name",
        "AdamW",
        "--optimizer.eps",
        "1e-15",
        "--optimizer.lr",
        "5e-5",
        "--lr_scheduler.warmup_steps",
        "400",
        "--lr_scheduler.lr_min",
        "0.1",
        "--lr_scheduler.decay_type",
        "cosine",
        "--training.batch_size",
        "16",
        "--training.seq_len",
        "4096",
        "--training.context_len",
        "4096",
        "--training.gradient_accumulation_steps",
        "1",
        "--training.steps",
        "40000",
        "--training.max_norm",
        "1.0",
        "--training.skip_nan_inf",
        "--training.dataset",
        "/home/cvm/.cache/zaydzuhri___stack-edu-python/default",
        "--training.dataset_split",
        "train",
        "--training.num_workers",
        "32",
        "--training.prefetch_factor",
        "2",
        "--training.seed",
        "79",
        "--training.compile",
        "--checkpoint.interval",
        "5000",
        "--checkpoint.load_step",
        "-1",
        "--metrics.log_freq",
        "5",
        "--checkpoint.hf_upload_enabled",
        "--checkpoint.hf_repo_base_name",
        "zaydzuhri/top-code-1B-4096-batch16x1-steps40000",
        "--comm.init_timeout_seconds",
        "1600",
        "--comm.train_timeout_seconds",
        "1600"
    ],
    "program": "-m flame.train",
    "git": {
        "remote": "https://github.com/zaydzuhri/flame.git",
        "commit": "aa4d5932e54fad8a568e10aa6895e69e0664fcf1"
    },
    "email": "zaydzuhri@gmail.com",
    "root": "exp/top.code.1B.batch16.seqlen4096.context4096.warmup400.update1.steps40000.lr5e-5.cosine/tb/20250901-0749",
    "host": "cvm-gncv9hlh",
    "executable": "/home/cvm/miniconda3/envs/flame-env/bin/python3.12",
    "cpu_count": 64,
    "cpu_count_logical": 128,
    "gpu": "NVIDIA H200",
    "gpu_count": 8,
    "disk": {
        "/": {
            "total": "3242363822080",
            "used": "1307996758016"
        }
    },
    "memory": {
        "total": "1913833021440"
    },
    "gpu_nvidia": [
        {
            "name": "NVIDIA H200",
            "memoryTotal": "150754820096",
            "cudaCores": 16896,
            "architecture": "Hopper",
            "uuid": "GPU-eddf9f4c-ffde-5f10-3c76-12ebce1f042b"
        },
        {
            "name": "NVIDIA H200",
            "memoryTotal": "150754820096",
            "cudaCores": 16896,
            "architecture": "Hopper",
            "uuid": "GPU-b532c850-7343-8f67-7eb1-a69024695a99"
        },
        {
            "name": "NVIDIA H200",
            "memoryTotal": "150754820096",
            "cudaCores": 16896,
            "architecture": "Hopper",
            "uuid": "GPU-751a6bdf-72f3-4f5a-fefd-d2b98c338579"
        },
        {
            "name": "NVIDIA H200",
            "memoryTotal": "150754820096",
            "cudaCores": 16896,
            "architecture": "Hopper",
            "uuid": "GPU-0cd9d3c7-1d2e-1925-91eb-8ec99a4ed277"
        },
        {
            "name": "NVIDIA H200",
            "memoryTotal": "150754820096",
            "cudaCores": 16896,
            "architecture": "Hopper",
            "uuid": "GPU-fba7e7ab-8340-13b0-b893-c3686cfec728"
        },
        {
            "name": "NVIDIA H200",
            "memoryTotal": "150754820096",
            "cudaCores": 16896,
            "architecture": "Hopper",
            "uuid": "GPU-12ca11c0-9080-3877-2bd5-3775573a4134"
        },
        {
            "name": "NVIDIA H200",
            "memoryTotal": "150754820096",
            "cudaCores": 16896,
            "architecture": "Hopper",
            "uuid": "GPU-32b3ec8b-9dc8-c6f6-5c19-74fa2ce10ffd"
        },
        {
            "name": "NVIDIA H200",
            "memoryTotal": "150754820096",
            "cudaCores": 16896,
            "architecture": "Hopper",
            "uuid": "GPU-d0021141-e4f4-14ab-c2ab-0ef3e30d6dd5"
        }
    ],
    "cudaVersion": "12.8",
    "writerId": "da7dvih583ith342zcw0cwucsgured2u"
}

tb/20250901-0749/wandb/run-20250901_074914-top_transformer-top.code.1B.batch16.seqlen4096.context4096.warmup400.update1.steps40000.lr5e-5.cosine-202509010747/logs/debug-internal.log
ADDED
@@ -0,0 +1,17 @@
{"time":"2025-09-01T07:49:14.247972294Z","level":"INFO","msg":"stream: starting","core version":"0.21.1"}
{"time":"2025-09-01T07:49:14.545288881Z","level":"INFO","msg":"stream: created new stream","id":"top_transformer-top.code.1B.batch16.seqlen4096.context4096.warmup400.update1.steps40000.lr5e-5.cosine-202509010747"}
{"time":"2025-09-01T07:49:14.545362953Z","level":"INFO","msg":"stream: started","id":"top_transformer-top.code.1B.batch16.seqlen4096.context4096.warmup400.update1.steps40000.lr5e-5.cosine-202509010747"}
{"time":"2025-09-01T07:49:14.54541562Z","level":"INFO","msg":"writer: started","stream_id":"top_transformer-top.code.1B.batch16.seqlen4096.context4096.warmup400.update1.steps40000.lr5e-5.cosine-202509010747"}
{"time":"2025-09-01T07:49:14.545435817Z","level":"INFO","msg":"sender: started","stream_id":"top_transformer-top.code.1B.batch16.seqlen4096.context4096.warmup400.update1.steps40000.lr5e-5.cosine-202509010747"}
{"time":"2025-09-01T07:49:14.545490133Z","level":"INFO","msg":"handler: started","stream_id":"top_transformer-top.code.1B.batch16.seqlen4096.context4096.warmup400.update1.steps40000.lr5e-5.cosine-202509010747"}
{"time":"2025-09-01T12:39:44.49607374Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/top_transformer-top.code.1B.batch16.seqlen4096.context4096.warmup400.update1.steps40000.lr5e-5.cosine-202509010747/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
{"time":"2025-09-01T12:57:09.402167829Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/top_transformer-top.code.1B.batch16.seqlen4096.context4096.warmup400.update1.steps40000.lr5e-5.cosine-202509010747/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
{"time":"2025-09-01T20:38:44.471380019Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/top_transformer-top.code.1B.batch16.seqlen4096.context4096.warmup400.update1.steps40000.lr5e-5.cosine-202509010747/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
{"time":"2025-09-01T22:25:18.669785309Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/top_transformer-top.code.1B.batch16.seqlen4096.context4096.warmup400.update1.steps40000.lr5e-5.cosine-202509010747/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
{"time":"2025-09-01T22:55:35.532603708Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/top_transformer-top.code.1B.batch16.seqlen4096.context4096.warmup400.update1.steps40000.lr5e-5.cosine-202509010747/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
{"time":"2025-09-02T07:07:34.089412209Z","level":"INFO","msg":"fileTransfer: Close: file transfer manager closed"}
{"time":"2025-09-02T07:07:34.291787824Z","level":"INFO","msg":"handler: operation stats","stats":{}}
{"time":"2025-09-02T07:07:34.295689194Z","level":"INFO","msg":"stream: closing","id":"top_transformer-top.code.1B.batch16.seqlen4096.context4096.warmup400.update1.steps40000.lr5e-5.cosine-202509010747"}
{"time":"2025-09-02T07:07:34.295726455Z","level":"INFO","msg":"handler: closed","stream_id":"top_transformer-top.code.1B.batch16.seqlen4096.context4096.warmup400.update1.steps40000.lr5e-5.cosine-202509010747"}
{"time":"2025-09-02T07:07:34.295770415Z","level":"INFO","msg":"sender: closed","stream_id":"top_transformer-top.code.1B.batch16.seqlen4096.context4096.warmup400.update1.steps40000.lr5e-5.cosine-202509010747"}
{"time":"2025-09-02T07:07:34.29578361Z","level":"INFO","msg":"stream: closed","id":"top_transformer-top.code.1B.batch16.seqlen4096.context4096.warmup400.update1.steps40000.lr5e-5.cosine-202509010747"}