Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- flame/__pycache__/__init__.cpython-312.pyc +0 -0
- flame/__pycache__/data.cpython-312.pyc +0 -0
- flame/__pycache__/train.cpython-312.pyc +0 -0
- flame/components/__init__.py +0 -0
- flame/components/__pycache__/__init__.cpython-312.pyc +0 -0
- flame/components/__pycache__/checkpoint.cpython-312.pyc +0 -0
- flame/components/checkpoint.py +59 -0
- flame/models/__init__.py +0 -0
- flame/models/__pycache__/__init__.cpython-312.pyc +0 -0
- flame/models/__pycache__/parallelize_fla.cpython-312.pyc +0 -0
- flame/models/__pycache__/pipeline_fla.cpython-312.pyc +0 -0
- flame/models/fla.toml +67 -0
- flame/models/parallelize_fla.py +550 -0
- flame/models/pipeline_fla.py +162 -0
- flame/tools/__init__.py +0 -0
- flame/tools/__pycache__/__init__.cpython-312.pyc +0 -0
- flame/tools/__pycache__/utils.cpython-312.pyc +0 -0
- flame/tools/utils.py +41 -0
- flame/utils/__init__.py +0 -0
- flame/utils/__pycache__/__init__.cpython-312.pyc +0 -0
- flame/utils/__pycache__/checkpoint.cpython-312.pyc +0 -0
- flame/utils/__pycache__/convert_dcp_to_hf.cpython-312.pyc +0 -0
- flame/utils/__pycache__/convert_hf_to_dcp.cpython-312.pyc +0 -0
- flame/utils/__pycache__/hf_utils.cpython-312.pyc +0 -0
- flame/utils/checkpoint.py +50 -0
- flame/utils/convert_dcp_to_hf.py +66 -0
- flame/utils/convert_hf_to_dcp.py +34 -0
- flame/utils/hf_utils.py +77 -0
- logs/none_enyj3lod/attempt_0/5/stderr.log +0 -0
- logs/none_enyj3lod/attempt_0/7/stderr.log +0 -0
- profile_trace/iteration_15872/rank6_trace.json +0 -0
- profile_trace/iteration_16384/rank2_trace.json +0 -0
- profile_trace/iteration_16384/rank3_trace.json +0 -0
- profile_trace/iteration_16384/rank5_trace.json +0 -0
- profile_trace/iteration_19968/rank3_trace.json +0 -0
- profile_trace/iteration_19968/rank7_trace.json +0 -0
- profile_trace/iteration_3072/rank2_trace.json +0 -0
- profile_trace/iteration_32256/rank1_trace.json +0 -0
- profile_trace/iteration_37376/rank0_trace.json +0 -0
- profile_trace/iteration_37376/rank6_trace.json +0 -0
- profile_trace/iteration_9216/rank2_trace.json +0 -0
- profile_trace/iteration_9216/rank3_trace.json +0 -0
- profile_trace/iteration_9216/rank4_trace.json +0 -0
- profile_trace/iteration_9216/rank6_trace.json +0 -0
- profile_trace/iteration_9216/rank7_trace.json +0 -0
- profile_trace/iteration_9728/rank1_trace.json +0 -0
- tb/20250911-1415/wandb/run-20250911_141551-mtp_transformer-mtp.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509111411/files/output.log +0 -0
- tb/20250911-1415/wandb/run-20250911_141551-mtp_transformer-mtp.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509111411/files/wandb-metadata.json +146 -0
- tb/20250911-1415/wandb/run-20250911_141551-mtp_transformer-mtp.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509111411/logs/debug-core.log +16 -0
- torchtitan/__pycache__/config_manager.cpython-312.pyc +0 -0
flame/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (156 Bytes). View file
|
|
|
flame/__pycache__/data.cpython-312.pyc
ADDED
|
Binary file (31.3 kB). View file
|
|
|
flame/__pycache__/train.cpython-312.pyc
ADDED
|
Binary file (38.1 kB). View file
|
|
|
flame/components/__init__.py
ADDED
|
File without changes
|
flame/components/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (141 Bytes). View file
|
|
|
flame/components/__pycache__/checkpoint.cpython-312.pyc
ADDED
|
Binary file (3.21 kB). View file
|
|
|
flame/components/checkpoint.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from dataclasses import dataclass, field
|
| 8 |
+
from datetime import timedelta
|
| 9 |
+
from io import BytesIO
|
| 10 |
+
from typing import Any, Dict, List
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from torch.distributed.checkpoint.stateful import Stateful
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@dataclass
|
| 17 |
+
class TrainState(Stateful):
|
| 18 |
+
step: int = 0
|
| 19 |
+
skipped_step: int = 0
|
| 20 |
+
token: int = 0
|
| 21 |
+
elapsed: timedelta = timedelta(0)
|
| 22 |
+
global_avg_losses: List[float] = field(default_factory=list)
|
| 23 |
+
global_max_losses: List[float] = field(default_factory=list)
|
| 24 |
+
log_steps: List[int] = field(default_factory=list)
|
| 25 |
+
|
| 26 |
+
def state_dict(self) -> Dict[str, Any]:
|
| 27 |
+
# Only checkpoint global_avg_losses and global_max_losses per log frequency
|
| 28 |
+
# to avoid sync overhead in every iteration.
|
| 29 |
+
global_avg_losses_bytes = BytesIO()
|
| 30 |
+
torch.save(self.global_avg_losses, global_avg_losses_bytes)
|
| 31 |
+
global_max_losses_bytes = BytesIO()
|
| 32 |
+
torch.save(self.global_max_losses, global_max_losses_bytes)
|
| 33 |
+
log_steps_bytes = BytesIO()
|
| 34 |
+
torch.save(self.log_steps, log_steps_bytes)
|
| 35 |
+
return {
|
| 36 |
+
"step": torch.tensor(self.step, dtype=torch.int32),
|
| 37 |
+
"skipped_step": torch.tensor(self.skipped_step, dtype=torch.int32),
|
| 38 |
+
"token": torch.tensor(self.token, dtype=torch.int64),
|
| 39 |
+
"elapsed": self.elapsed,
|
| 40 |
+
"global_avg_losses": global_avg_losses_bytes,
|
| 41 |
+
"global_max_losses": global_max_losses_bytes,
|
| 42 |
+
"log_steps": log_steps_bytes,
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
def load_state_dict(self, state_dict) -> None:
|
| 46 |
+
self.step = state_dict["step"].item()
|
| 47 |
+
self.skipped_step = state_dict.get("skipped_step", 0).item()
|
| 48 |
+
self.token = state_dict["token"].item()
|
| 49 |
+
self.elapsed = state_dict["elapsed"]
|
| 50 |
+
state_dict["global_avg_losses"].seek(0)
|
| 51 |
+
self.global_avg_losses = torch.load(
|
| 52 |
+
state_dict["global_avg_losses"], weights_only=False
|
| 53 |
+
)
|
| 54 |
+
state_dict["global_max_losses"].seek(0)
|
| 55 |
+
self.global_max_losses = torch.load(
|
| 56 |
+
state_dict["global_max_losses"], weights_only=False
|
| 57 |
+
)
|
| 58 |
+
state_dict["log_steps"].seek(0)
|
| 59 |
+
self.log_steps = torch.load(state_dict["log_steps"], weights_only=False)
|
flame/models/__init__.py
ADDED
|
File without changes
|
flame/models/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (137 Bytes). View file
|
|
|
flame/models/__pycache__/parallelize_fla.cpython-312.pyc
ADDED
|
Binary file (22.1 kB). View file
|
|
|
flame/models/__pycache__/pipeline_fla.cpython-312.pyc
ADDED
|
Binary file (5.75 kB). View file
|
|
|
flame/models/fla.toml
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Training configuration consumed by the flame/torchtitan job config manager.

[model]
config = "fla-hub/transformer-1.3B-100B"
tokenizer_path = "fla-hub/transformer-1.3B-100B"

[job]
dump_folder = "exp"
print_args = true

[training]
batch_size = 32
seq_len = 2048
context_len = 2048
gradient_accumulation_steps = 1
steps = 20480
max_norm = 1.0  # gradient-clipping norm
skip_nan_inf = true
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1  # -1 presumably means "use all remaining ranks" — verify against config_manager
tensor_parallel_degree = 1
compile = false
dataset = "HuggingFaceFW/fineweb-edu"
dataset_name = "default"
num_workers = 32
pin_memory = false
persistent_workers = false
prefetch_factor = 2
seed = 42
varlen = false

[optimizer]
name = "AdamW"
eps = 1e-15
lr = 3e-4

[lr_scheduler]
warmup_steps = 1024
decay_type = "cosine"
lr_min = 0.1  # NOTE(review): presumably a fraction of peak lr, not an absolute LR — confirm

[checkpoint]
enable_checkpoint = true
folder = "checkpoint"
interval_type = "steps"
interval = 2048
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

[profiling]
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 512

[metrics]
log_freq = 32
enable_wandb = true

[experimental]
context_parallel_degree = 1
pipeline_parallel_degree = 1

[float8]
enable_fsdp_float8_all_gather = false
precompute_float8_dynamic_scale_for_fsdp = false

[activation_checkpoint]
mode = "none"  # ["none", "full", "selective"]
|
flame/models/parallelize_fla.py
ADDED
|
@@ -0,0 +1,550 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# This file applies the PT-D parallelisms (except pipeline parallelism) and various
|
| 8 |
+
# training techniques (e.g. activation checkpointing and compile) to the Llama model.
|
| 9 |
+
|
| 10 |
+
from collections import defaultdict
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
from torch.distributed import DeviceMesh
|
| 15 |
+
from torch.distributed._composable.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard
|
| 16 |
+
from torch.distributed._composable.replicate import replicate
|
| 17 |
+
from torch.distributed._tensor import Replicate, Shard
|
| 18 |
+
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper as ptd_checkpoint_wrapper
|
| 19 |
+
from torch.distributed.tensor.parallel import (
|
| 20 |
+
ColwiseParallel,
|
| 21 |
+
PrepareModuleInput,
|
| 22 |
+
PrepareModuleOutput,
|
| 23 |
+
RowwiseParallel,
|
| 24 |
+
SequenceParallel,
|
| 25 |
+
parallelize_module
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
from fla.modules.fused_linear_cross_entropy import LinearLossParallel
|
| 29 |
+
from fla.modules.mlp import SwiGLULinearParallel
|
| 30 |
+
from fla.modules.parallel import PrepareModuleWeight
|
| 31 |
+
from torchtitan.config_manager import TORCH_DTYPE_MAP, JobConfig
|
| 32 |
+
from torchtitan.distributed.parallel_dims import ParallelDims
|
| 33 |
+
from torchtitan.tools.logging import logger
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def parallelize_fla(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
    """
    Apply tensor parallelism, activation checkpointing, torch.compile, and data
    parallelism to the model.

    Order matters: TP first, then AC wrapping, then per-block compile, then
    FSDP/HSDP (or DDP when only replication is requested).

    NOTE: The passed-in model preferably should be on meta device. Otherwise,
    the model must fit on GPU or CPU memory.
    """

    if parallel_dims.tp_enabled:
        # Async TP is driven by inductor's micro-pipeline pass (see apply_tp),
        # so it cannot work without torch.compile.
        if (
            job_config.experimental.enable_async_tensor_parallel
            and not job_config.training.compile
        ):
            raise RuntimeError("Async TP requires --training.compile")
        enable_float8_linear = "float8" in job_config.model.converters
        apply_tp(
            model,
            world_mesh["tp"],
            loss_parallel=parallel_dims.loss_parallel_enabled,
            enable_float8=enable_float8_linear,
            enable_async_tp=job_config.experimental.enable_async_tensor_parallel,
        )

    if job_config.activation_checkpoint.mode != "none":
        apply_ac(model, job_config.activation_checkpoint)

    # turn on per-block compile after AC wrapping and before FSDP
    if job_config.training.compile:
        apply_compile(model)

    if (
        parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled
    ):  # apply FSDP or HSDP, potentially with Context Parallel
        if parallel_dims.dp_replicate_enabled:
            # HSDP: replicate over one mesh dim, shard over the combined
            # dp_shard + context-parallel dim.
            dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
        else:
            dp_mesh_dim_names = ("dp_shard_cp",)

        apply_fsdp(
            model,
            world_mesh[tuple(dp_mesh_dim_names)],
            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
            pp_enabled=parallel_dims.pp_enabled,
            cpu_offload=job_config.training.enable_cpu_offload,
            reshard_after_forward_policy=job_config.training.fsdp_reshard_after_forward,
        )

        if parallel_dims.dp_replicate_enabled:
            logger.info("Applied HSDP to the model")
        else:
            logger.info("Applied FSDP to the model")

        if parallel_dims.cp_enabled:
            logger.info("Applied Context Parallel to the model")

        if job_config.training.enable_cpu_offload:
            logger.info("Applied CPU Offloading to the model")
    elif parallel_dims.dp_replicate_enabled:
        # Pure DDP fallback: only valid on a 1-D mesh.
        if world_mesh.ndim > 1:
            raise RuntimeError("DDP has not supported > 1D parallelism")
        apply_ddp(
            model,
            world_mesh,
            enable_compile=job_config.training.compile,
            enable_compiled_autograd=job_config.experimental.enable_compiled_autograd,
        )
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
class TPPlan:
    """Base builder of tensor-parallel plans for fla models.

    Subclasses supply :attr:`attn_plan` for their token-mixing layer; the
    MLP plan and the top-level (embeddings / norm / lm_head) plan live here.
    """

    def __init__(
        self,
        model=None,
        loss_parallel=False,
        enable_float8=False,
    ):
        self.model = model
        self.loss_parallel = loss_parallel
        self.enable_float8 = enable_float8
        # HF-style causal-LM wrappers expose the backbone under this name.
        self.base_model_prefix = getattr(model, "base_model_prefix", "model")

        # TODO(vkuzo): once float8 configuration supports delayed scaling,
        # add a check here to enforce supported float8 all-gather configurations
        # TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there
        try:
            from torchao.float8.float8_tensor_parallel import (
                Float8ColwiseParallel,
                Float8RowwiseParallel,
                PrepareFloat8ModuleInput
            )
        except ImportError:
            # torchao unavailable: fall back to the plain parallel styles.
            Float8ColwiseParallel = None
            Float8RowwiseParallel = None
            PrepareFloat8ModuleInput = None
        if self.enable_float8 and Float8ColwiseParallel is not None:
            self.rowwise_parallel = Float8RowwiseParallel
            self.colwise_parallel = Float8ColwiseParallel
            self.prepare_module_input = PrepareFloat8ModuleInput
            self.prepare_module_output = PrepareModuleOutput
        else:
            self.rowwise_parallel = RowwiseParallel
            self.colwise_parallel = ColwiseParallel
            self.prepare_module_input = PrepareModuleInput
            self.prepare_module_output = PrepareModuleOutput

    @property
    def model_plan(self):
        # Plan for the top-level modules: embeddings, final norm, and the head.
        plans = {
            f"{self.base_model_prefix}.embeddings": RowwiseParallel(
                input_layouts=Replicate(),
                output_layouts=Shard(1),
            ),
            f"{self.base_model_prefix}.norm": SequenceParallel(),
        }
        if self.loss_parallel:
            # NOTE(review): within this branch self.loss_parallel is always
            # True, so the inner conditionals below always evaluate to
            # Shard(-1) / use_local_output=False — they look like leftovers
            # from before the branch was introduced.
            plans.update(
                {
                    "lm_head": ColwiseParallel(
                        input_layouts=Shard(1),
                        output_layouts=Shard(-1) if self.loss_parallel else Replicate(),
                        use_local_output=not self.loss_parallel,
                    ),
                }
            )
        else:
            # Without loss parallelism, lm_head weights stay replicated and
            # the fused linear+cross-entropy criterion handles parallelism.
            plans.update(
                {
                    "lm_head": PrepareModuleWeight(layouts=Replicate()),
                    "criterion": LinearLossParallel(),
                }
            )
        return plans

    @property
    def layer_plan(self):
        # Per-block plan: sequence-parallel norms plus the attn/mlp shardings.
        return {
            "attn_norm": SequenceParallel(),
            **self.attn_plan,
            "mlp_norm": SequenceParallel(),
            **self.mlp_plan,
        }

    @property
    def attn_plan(self):
        # Subclass responsibility: plan for the token-mixing (attention) layer.
        raise NotImplementedError(
            f"TP plans for token mixing layers of {self.model.config.model_type} not implemented"
        )

    @property
    def mlp_plan(self):
        # Classic megatron-style MLP sharding: colwise in, rowwise out.
        return {
            "mlp": self.prepare_module_input(
                input_layouts=(Shard(1),),
                desired_input_layouts=(Replicate(),),
            ),
            "mlp.gate_proj": self.colwise_parallel(),
            "mlp.up_proj": self.colwise_parallel(),
            "mlp.down_proj": self.rowwise_parallel(output_layouts=Shard(1)),
            "mlp.swiglu_linear": SwiGLULinearParallel(output_layouts=Shard(1)),
        }
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
class TransformerTPPlan(TPPlan):
    """TP plan for standard softmax-attention transformer blocks."""

    @property
    def attn_plan(self):
        # Column-shard the q/k/v projections, row-shard the output projection,
        # and redistribute the block input from sequence-sharded to replicated.
        plan = {
            "attn": self.prepare_module_input(
                input_kwarg_layouts={"hidden_states": Shard(1)},
                desired_input_kwarg_layouts={"hidden_states": Replicate()},
            ),
        }
        for proj in ("q_proj", "k_proj", "v_proj"):
            plan[f"attn.{proj}"] = self.colwise_parallel()
        plan["attn.o_proj"] = self.rowwise_parallel(output_layouts=Shard(1))
        return plan
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
class GLATPPlan(TPPlan):
    """TP plan for GLA (gated linear attention) token-mixing layers."""

    @property
    def attn_plan(self):
        # Redistribute the sequence-sharded block input to replicated, then
        # column-shard the projections and row-shard the output projection.
        plan = {
            "attn": self.prepare_module_input(
                input_kwarg_layouts={"hidden_states": Shard(1)},
                desired_input_kwarg_layouts={"hidden_states": Replicate()},
            ),
        }
        for proj in ("q_proj", "k_proj", "v_proj", "g_proj"):
            plan[f"attn.{proj}"] = self.colwise_parallel()
        # Two-stage gate projection: stage 0 keeps replicated weights, stage 1
        # is column-sharded like the other projections.
        plan["attn.gk_proj.0"] = PrepareModuleWeight(layouts=Replicate())
        plan["attn.gk_proj.1"] = self.colwise_parallel()
        # Gate norm runs sequence-parallel over the last dimension.
        plan["attn.g_norm"] = SequenceParallel(sequence_dim=-1)
        plan["attn.o_proj"] = self.rowwise_parallel(output_layouts=Shard(1))
        return plan
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
# Maps a config's `model_type` string to its tensor-parallel plan class.
TP_PLAN_MAP = {"transformer": TransformerTPPlan, "gla": GLATPPlan}
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def apply_tp(
    model: nn.Module,
    tp_mesh: DeviceMesh,
    loss_parallel: bool,
    enable_float8: bool,
    enable_async_tp: bool,
):
    """Apply tensor parallelism to ``model`` over ``tp_mesh``.

    The plan class is selected from ``TP_PLAN_MAP`` by the model's
    ``config.model_type`` (raises KeyError for unsupported types).
    """
    # 1. Parallelize the embedding and shard its outputs (which are the first
    # transformer block's inputs)
    # 2. Parallelize the root norm layer over the sequence dim
    # 3. Parallelize the final linear output layer
    tp_plan = TP_PLAN_MAP[model.config.model_type](
        model, loss_parallel=loss_parallel, enable_float8=enable_float8
    )
    parallelize_module(model, tp_mesh, tp_plan.model_plan)

    blocks = get_blocks(model)
    if blocks is None:
        logger.warning("No block found for tensor parallelism")
    else:
        # Apply the per-layer plan (attn + mlp sharding) to every block.
        for _, block in enumerate(blocks):
            parallelize_module(
                module=block,
                device_mesh=tp_mesh,
                parallelize_plan=tp_plan.layer_plan,
            )

    if enable_async_tp:
        # Async TP uses inductor's micro-pipeline pass and requires symmetric
        # memory to be enabled for the TP process group.
        from torch.distributed._symmetric_memory import enable_symm_mem_for_group

        torch._inductor.config._micro_pipeline_tp = True
        enable_symm_mem_for_group(tp_mesh.get_group().group_name)

    logger.info(
        f"Applied {'Float8 ' if enable_float8 else ''}{'Async ' if enable_async_tp else ''}"
        "Tensor Parallelism to the model"
    )
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
# for selective op activation checkpointing
# Ops whose outputs the selective-AC policy marks MUST_SAVE (everything else
# defaults to recompute); see _apply_ac_to_block's _custom_policy, which also
# exempts every second mm from this list.
_save_list = {
    torch.ops.aten.mm.default,
    torch.ops.aten._scaled_dot_product_efficient_attention.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
    torch.ops._c10d_functional.reduce_scatter_tensor.default,
    # for low precision training, it's useful to always save
    # the result of max, since the absolute maximum is
    # used to compute the scaling factor for quantization.
    torch.ops.aten.max.default,
}
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
def _apply_ac_to_block(module: nn.Module, ac_config):
    """Wrap one transformer block with activation checkpointing.

    Modes:
      - "full": always wrap the block.
      - "selective" + "op": wrap with a per-op save/recompute policy.
      - "selective" + digit string N: wrap every Nth block, return the rest
        unchanged.

    Returns the (possibly wrapped) module; raises ValueError on bad config.
    """
    valid_ac_modes = ("full", "selective")
    if ac_config.mode not in valid_ac_modes:
        raise ValueError(
            f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}"
        )

    if ac_config.mode == "full":
        return ptd_checkpoint_wrapper(module, preserve_rng_state=False)

    assert ac_config.mode == "selective", f"{ac_config.mode}"
    use_op_sac = ac_config.selective_ac_option == "op"
    use_layer_sac = ac_config.selective_ac_option.isdigit()
    if not use_op_sac and not use_layer_sac:
        raise ValueError(
            f"Invalid selective AC option: {ac_config.selective_ac_option}. "
            f"Valid options: 'op' or a positive int representing layer frequency"
        )
    if use_op_sac:
        from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts

        def _get_custom_policy(meta):
            # `meta` accumulates per-checkpoint-region op counts, keyed
            # separately for the forward and the recompute passes.
            def _custom_policy(ctx, func, *args, **kwargs):
                mode = "recompute" if ctx.is_recompute else "forward"
                mm_count_key = f"{mode}_mm_count"
                if func == torch.ops.aten.mm.default:
                    meta[mm_count_key] += 1
                # Saves output of all compute ops, except every second mm
                to_save = func in _save_list and not (
                    func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
                )
                return (
                    CheckpointPolicy.MUST_SAVE
                    if to_save
                    else CheckpointPolicy.PREFER_RECOMPUTE
                )

            return _custom_policy

        def selective_checkpointing_context_fn():
            # Fresh counters for each checkpointed region.
            meta = defaultdict(int)
            return create_selective_checkpoint_contexts(_get_custom_policy(meta))

        return ptd_checkpoint_wrapper(
            module,
            context_fn=selective_checkpointing_context_fn,
            preserve_rng_state=False,
        )
    elif use_layer_sac:
        # Checkpoint every `ac_freq` of the modules passed to this function
        ac_freq = int(ac_config.selective_ac_option)
        # NOTE: the call counter is stashed on the wrapper function object so
        # the every-Nth selection is shared across all invocations.
        ptd_checkpoint_wrapper.__dict__.setdefault("_count", 0)
        ptd_checkpoint_wrapper._count += 1
        if not ac_freq or ptd_checkpoint_wrapper._count % ac_freq == 0:
            return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
        else:
            return module
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
def apply_ac(model: nn.Module, ac_config):
    """Wrap each transformer block with activation checkpointing per ac_config."""
    blocks = get_blocks(model)
    if blocks is None:
        logger.warning("No block found for activation checkpointing")
        return

    # Re-register every child so the wrapped module replaces the original
    # under the same name.
    for name, child in blocks.named_children():
        blocks.register_module(name, _apply_ac_to_block(child, ac_config))

    logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
def apply_compile(model: nn.Module):
    """
    Apply torch.compile to each block, which makes compilation efficient due to
    repeated structure. Alternatively one can compile the whole model (after applying DP).
    """

    blocks = get_blocks(model)
    if blocks is None:
        logger.warning("No block found for torch.compile")
    else:
        # Compile block-by-block and re-register under the same name.
        for layer_id, block in blocks.named_children():
            block = torch.compile(block)
            blocks.register_module(layer_id, block)
        logger.info("Compiling each block with torch.compile")

    real_model = get_model(model)

    logger.info("Compiling the embedding, norm, and lm_head layers with torch.compile")
    # Embedding and norm live on the inner base model; lm_head on the wrapper.
    embeddings_key = get_components_name(real_model, "tok_embeddings")
    if embeddings_key is not None:
        embeddings = torch.compile(getattr(real_model, embeddings_key), fullgraph=True)
        real_model.register_module(embeddings_key, embeddings)

    norm_key = get_components_name(real_model, "norm")
    if norm_key is not None:
        norm = torch.compile(getattr(real_model, norm_key), fullgraph=True)
        real_model.register_module(norm_key, norm)

    lm_head_key = get_components_name(model, "lm_head")
    if lm_head_key is not None:
        lm_head = torch.compile(getattr(model, lm_head_key), fullgraph=True)
        model.register_module(lm_head_key, lm_head)

    logger.info("Compiling the entire model with torch.compile")
    # NOTE(review): this rebinds only the local name — the compiled whole-model
    # wrapper is discarded and the caller's model is unchanged. Confirm whether
    # whole-model compilation is actually intended; if so, the wrapper must be
    # returned or registered.
    model = torch.compile(model)
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
def apply_fsdp(
    model: nn.Module,
    dp_mesh: DeviceMesh,
    param_dtype: torch.dtype,
    reduce_dtype: torch.dtype,
    pp_enabled: bool,
    cpu_offload: bool = False,
    reshard_after_forward_policy: str = "default",
):
    """
    Apply data parallelism (via FSDP2) to the model.

    Args:
        model (nn.Module): The model to apply data parallelism to.
        dp_mesh (DeviceMesh): The device mesh to use for data parallelism.
        param_dtype (torch.dtype): The data type to use for model parameters.
        reduce_dtype (torch.dtype): The data type to use for reduction operations.
        pp_enabled (bool): Whether pipeline parallelism is enabled.
        cpu_offload (bool, optional): Whether to offload model parameters to CPU. Defaults to False.
        reshard_after_forward_policy (str, optional):
            The policy to use for resharding after forward pass. Defaults to "default".
            Other options: "never", "always".
            - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios.
            - "always" will enable `reshard_after_forward` for all forward passes.
            - "never" will disable `reshard_after_forward` for all forward passes.

    Raises:
        ValueError: If ``reshard_after_forward_policy`` is not one of the above.
    """
    mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
    fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
    if cpu_offload:
        fsdp_config["offload_policy"] = CPUOffloadPolicy()

    blocks = get_blocks(model)
    if blocks is None:
        logger.warning("No block found for FSDP")
    else:
        total_blocks = len(blocks)
        for layer_id, block in enumerate(blocks):
            if reshard_after_forward_policy == "always":
                reshard_after_forward = True
            elif reshard_after_forward_policy == "never":
                reshard_after_forward = False
            elif reshard_after_forward_policy == "default":
                if pp_enabled:
                    # For PP, do not reshard after forward to avoid per-microbatch
                    # all-gathers, which can be expensive and non-overlapped
                    reshard_after_forward = False
                else:
                    # As an optimization, do not reshard after forward for the last
                    # transformer block since FSDP would prefetch it immediately
                    reshard_after_forward = int(layer_id) < total_blocks - 1
            else:
                raise ValueError(
                    f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}."
                )
            # Shard each block as its own FSDP parameter group.
            fully_shard(
                block,
                **fsdp_config,
                reshard_after_forward=reshard_after_forward,
            )

    # Root-level fully_shard covers parameters not owned by any block group
    # (e.g. embeddings / norm / head).
    fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
def apply_ddp(
    model: nn.Module,
    dp_mesh: DeviceMesh,
    enable_compile: bool,
    enable_compiled_autograd: bool,
):
    """Replicate ``model`` across ``dp_mesh`` with DDP.

    When ``torch.compile`` is enabled, the dynamo DDP optimization mode is
    selected first: the python reducer (without compiled forward) when
    compiled autograd is requested, otherwise the standard ddp_optimizer.
    """
    if enable_compile:
        # Choose the dynamo/DDP interaction mode before wrapping the model.
        mode = (
            "python_reducer_without_compiled_forward"
            if enable_compiled_autograd
            else "ddp_optimizer"
        )
        torch._dynamo.config.optimize_ddp = mode

    replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)

    logger.info("Applied DDP to the model")
|
| 487 |
+
|
| 488 |
+
|
| 489 |
+
def get_model(model):
    """Return the inner base model, or ``None`` when it is absent.

    Follows the HF convention: ``base_model_prefix`` names the attribute
    holding the backbone (falls back to ``"model"`` when unset).
    """
    prefix = getattr(model, "base_model_prefix", "model")
    return getattr(model, prefix) if hasattr(model, prefix) else None
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
def get_blocks(model):
    """Return the stack of transformer blocks (``<base>.layers``) or ``None``."""
    # TODO[flame]: adapt for network not using 'layers' attribute
    base = get_model(model)
    if hasattr(base, "layers"):
        return base.layers
    logger.warning('no "layers" in model can be found')
    return None
|
| 504 |
+
|
| 505 |
+
|
| 506 |
+
def get_components_name(model, component_name):
    """
    Resolve the actual attribute name used for a well-known model component.

    We try to catch tok_embeddings, norm layers and lm_head layers.
    We do not catch the layer names in the blocks, for blocks see `get_blocks`.
    We assume the model has the following structure:
        LlamaForCausalLM:
            Model:
                embed_tokens,
                layers,
                norm,
            lm_head
    ***
    so, to search 'tok_embeddings' and 'norm' we need to pass `get_model(model)`
    and for 'lm_head' we need to pass `model`
    ***
    Returns the first matching attribute name, or ``None`` (with a warning)
    when none of the known aliases is present.
    """
    # Known aliases per component, tried in order of preference.
    candidates = {
        "tok_embeddings": ("tok_embeddings", "embed_tokens", "embeddings"),
        "norm": ("norm", "norms", "layernorm"),
        "lm_head": ("lm_head",),
    }
    for attr in candidates.get(component_name, ()):
        if hasattr(model, attr):
            return attr

    # Nothing matched: warn with the component-specific message.
    if component_name == "tok_embeddings":
        logger.warning("No tok_embeddings found in model")
    elif component_name == "norm":
        logger.warning("No norm found in model")
    elif component_name == "lm_head":
        logger.warning("No lm_head found in model")
    return None
|
flame/models/pipeline_fla.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# This file applies the PT-D pipeline parallelism to the Llama model.
|
| 8 |
+
|
| 9 |
+
import copy
|
| 10 |
+
from typing import Callable, Optional, Union
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
from torch.distributed import DeviceMesh
|
| 15 |
+
from torch.distributed.pipelining import PipelineStage
|
| 16 |
+
from torch.distributed.pipelining.schedules import ScheduleZBVZeroBubble, _PipelineSchedule, get_schedule_class
|
| 17 |
+
from transformers import PretrainedConfig
|
| 18 |
+
|
| 19 |
+
from flame.models.parallelize_fla import get_blocks, get_components_name, get_model
|
| 20 |
+
from torchtitan.config_manager import JobConfig
|
| 21 |
+
from torchtitan.distributed.parallel_dims import ParallelDims
|
| 22 |
+
from torchtitan.distributed.pipeline import build_pipeline_schedule, generate_split_points, stage_ids_this_rank
|
| 23 |
+
from torchtitan.tools.logging import logger
|
| 24 |
+
|
| 25 |
+
DeviceType = Union[int, str, torch.device]
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def pipeline_fla(
    model: nn.Module,
    pp_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
    device: DeviceType,
    model_config: PretrainedConfig,
    loss_fn: Callable[..., torch.Tensor],
) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]:
    """Split the model into pipeline stages and build the PP schedule.

    Returns the schedule, this rank's model chunks, and two flags telling the
    train loop whether this rank hosts the first and/or last stage (used to
    decide whether input_ids and labels must be fed locally).
    """
    stages, models = pipeline_fla_manual_split(
        model, pp_mesh, parallel_dims, job_config, device, model_config
    )

    pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)

    # This is used in the train loop to determine whether to pass in the
    # input_ids and labels.
    has_first_stage = any(stage.is_first for stage in stages)
    has_last_stage = any(stage.is_last for stage in stages)

    return pp_schedule, models, has_first_stage, has_last_stage
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def pipeline_fla_manual_split(
    whole_model: nn.Module,
    pp_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
    device: DeviceType,
    model_config: PretrainedConfig,
) -> tuple[list[PipelineStage], list[nn.Module]]:
    """
    This API extracts one torch.nn.Module objects for the part of the model configured to run inside this stage.

    It wraps the model chunk in a ManualPipelineStage object and returns both the stage and model objects.

    The stage object is used to create a pipeline schedule, and the model object can be used for applying SPMD
    parallelism.
    """
    pp_rank = pp_mesh.get_local_rank()
    pp_size = pp_mesh.size()

    # Split points come from config when given, otherwise they are generated
    # evenly over the model's hidden layers.
    splits = (
        job_config.experimental.pipeline_parallel_split_points
        or generate_split_points(
            job_config, parallel_dims.pp, model_config.num_hidden_layers
        )
    )

    def _build_stage(
        stage_idx: int,
        start_layer: Optional[str],
        stop_layer: Optional[str],
        is_first: bool = False,
        is_last: bool = False,
    ) -> tuple[PipelineStage, nn.Module]:
        # Build one stage by deep-copying the whole model and deleting the
        # parts (embeddings / blocks / norm / head) that belong to other stages.
        # NOTE: `num_stages` is a closure variable defined after this function;
        # it is bound by the time _build_stage is called below.
        model = copy.deepcopy(whole_model)
        if not is_first:
            # we do `model.tok_embeddings = None` here
            real_model = get_model(model)
            tok_embeddings_name = get_components_name(real_model, "tok_embeddings")
            setattr(real_model, tok_embeddings_name, None)

        # Layers before `start_layer` and from `stop_layer` onward are dropped.
        drop_layers = start_layer is not None
        # Get module dictionary from get_blocks(model)
        # and Create a list of keys before modifying dictionary
        module_dict = get_blocks(model)._modules  # Store reference
        layer_names = list(module_dict.keys())

        # Iterate over the list of keys instead of `_modules.items()`
        for name in layer_names:
            # Dynamically determine prefix (blocks.* or layers.*)
            prefix = start_layer.split(".")[0] if start_layer else "layers"
            layer_name = f"{prefix}.{name}"  # Construct the correct name format

            # Ensure `drop_layers` activation is based on actual naming
            if layer_name == start_layer:
                drop_layers = False
            if layer_name == stop_layer:
                drop_layers = True

            # Delete layer if drop_layers is active
            if drop_layers:
                del module_dict[name]  # Safe deletion from stored dictionary

        if not is_last:
            # we do `model.norm = None` and `model.output = None`
            real_model = get_model(model)
            norm_name = get_components_name(real_model, "norm")
            setattr(real_model, norm_name, None)

            # lm_head lives on the outer (CausalLM) wrapper, not the base model.
            head_name = get_components_name(model, "lm_head")
            setattr(model, head_name, None)

        stage = PipelineStage(
            model,
            stage_idx,
            num_stages,
            device,
            group=pp_mesh.get_group("pp"),
        )
        return stage, model

    num_stages = len(splits) + 1
    stage_idx = pp_rank

    stages = []
    models = []

    # ZBV schedules interleave stages in a "v" pattern; all others loop.
    schedule_class = get_schedule_class(
        job_config.experimental.pipeline_parallel_schedule
    )
    style = "v" if schedule_class == ScheduleZBVZeroBubble else "loop"

    # Build every stage assigned to this PP rank (possibly more than one for
    # interleaved schedules).
    for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style=style):
        start_layer = splits[stage_idx - 1] if stage_idx > 0 else None
        stop_layer = splits[stage_idx] if stage_idx < num_stages - 1 else None
        stage, model_chunk = _build_stage(
            stage_idx,
            start_layer,
            stop_layer,
            is_first=stage_idx == 0,
            is_last=stage_idx == num_stages - 1,
        )
        logger.info(
            f"PP rank {pp_rank} is building stage_idx {stage_idx}"
            f" with start_layer {start_layer}, stop_layer {stop_layer}"
        )
        stages.append(stage)
        models.append(model_chunk)
    return stages, models
|
flame/tools/__init__.py
ADDED
|
File without changes
|
flame/tools/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (136 Bytes). View file
|
|
|
flame/tools/__pycache__/utils.cpython-312.pyc
ADDED
|
Binary file (2.14 kB). View file
|
|
|
flame/tools/utils.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from torch import nn
|
| 8 |
+
from torchtitan.tools.logging import logger
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def get_nparams_and_flops(model: nn.Module, model_config, seq_len: int) -> tuple[int, int]:
    """Count model parameters and estimate FLOPs per token (for MFU logging).

    Args:
        model: The full model (parameters counted via ``model.parameters()``).
        model_config: HF-style config providing ``num_hidden_layers``,
            ``hidden_size`` and ``num_heads``/``num_attention_heads``.
        seq_len: Training sequence length used in the attention FLOP term.

    Returns:
        Tuple of (total parameter count, estimated FLOPs per token).
    """
    nparams = sum(p.numel() for p in model.parameters())
    # Count embedding parameters anywhere in the module tree, not only among
    # direct children: HF CausalLM models nest the embedding under the base
    # model (e.g. `model.model.embed_tokens`), which `model.children()` misses.
    # `modules()` deduplicates, so shared/tied embeddings are counted once.
    nparams_embedding = sum(
        sum(p.numel() for p in m.parameters())
        for m in model.modules()
        if isinstance(m, nn.Embedding)
    )

    if hasattr(model_config, "num_heads"):
        num_heads = model_config.num_heads
    elif hasattr(model_config, "num_attention_heads"):
        num_heads = model_config.num_attention_heads
    else:
        num_heads = 1
        logger.warning("num_heads not found in model_config, defaulting to 1. ")

    l, h, q, t = (
        model_config.num_hidden_layers,
        num_heads,
        model_config.hidden_size // num_heads,
        seq_len,
    )
    # Reasoning behind the factor of 12 for the self-attention part of the formula:
    # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6)
    # 2. the flash attention does 1 more matmul recomputation in the backward
    #    but recomputation should not be counted in calculating MFU (+0)
    # 3. each matmul performs 1 multiplication and 1 addition (*2)
    # 4. we follow the convention and do not account for sparsity in causal attention
    num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t

    return nparams, num_flops_per_token
|
flame/utils/__init__.py
ADDED
|
File without changes
|
flame/utils/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (136 Bytes). View file
|
|
|
flame/utils/__pycache__/checkpoint.cpython-312.pyc
ADDED
|
Binary file (4.07 kB). View file
|
|
|
flame/utils/__pycache__/convert_dcp_to_hf.cpython-312.pyc
ADDED
|
Binary file (3.73 kB). View file
|
|
|
flame/utils/__pycache__/convert_hf_to_dcp.cpython-312.pyc
ADDED
|
Binary file (1.92 kB). View file
|
|
|
flame/utils/__pycache__/hf_utils.cpython-312.pyc
ADDED
|
Binary file (4.46 kB). View file
|
|
|
flame/utils/checkpoint.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import glob
|
| 3 |
+
import re
|
| 4 |
+
import shutil
|
| 5 |
+
from torchtitan.tools.logging import logger
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def cleanup_local_checkpoints(checkpoint_dir: str, keep_latest_k: int):
    """Removes older checkpoint directories locally, keeping only the latest k for both DCP and HF formats.

    Args:
        checkpoint_dir: Directory containing ``step-<N>`` (DCP) and
            ``step-<N>-hf`` (HF) checkpoint subdirectories.
        keep_latest_k: Number of newest checkpoints of each format to keep;
            values <= 0 keep everything.
    """
    if keep_latest_k <= 0:
        return  # Keep all checkpoints

    logger.info(f"Cleaning up local checkpoints in {checkpoint_dir}, keeping latest {keep_latest_k}")

    def _step_num(path: str, pattern: str) -> int:
        # Extract the step number used for recency ordering; unparsable names
        # sort last (-1), matching the original behavior.
        m = re.search(pattern, os.path.basename(path))
        return int(m.group(1)) if m else -1

    def _delete_old(checkpoints: list, label: str) -> None:
        # Delete everything beyond the newest `keep_latest_k` entries.
        if len(checkpoints) <= keep_latest_k:
            return
        checkpoints_to_delete = checkpoints[keep_latest_k:]
        logger.info(f"Deleting {len(checkpoints_to_delete)} old {label} checkpoints: {[os.path.basename(c) for c in checkpoints_to_delete]}")
        for ckpt_path in checkpoints_to_delete:
            if os.path.isdir(ckpt_path):  # Ensure it's a directory
                try:
                    shutil.rmtree(ckpt_path)
                except OSError as e:
                    logger.error(f"Error removing directory {ckpt_path}: {e}")

    # Cleanup DCP checkpoints (step-*), excluding HF-format step-*-hf dirs.
    dcp_checkpoints = [
        p for p in glob.glob(os.path.join(checkpoint_dir, "step-*"))
        if not p.endswith("-hf")
    ]
    dcp_checkpoints.sort(key=lambda p: _step_num(p, r"step-(\d+)"), reverse=True)
    _delete_old(dcp_checkpoints, "DCP")

    # Cleanup HF checkpoints (step-*-hf)
    hf_checkpoints = glob.glob(os.path.join(checkpoint_dir, "step-*-hf"))
    hf_checkpoints.sort(key=lambda p: _step_num(p, r"step-(\d+)-hf"), reverse=True)
    _delete_old(hf_checkpoints, "HF")
|
flame/utils/convert_dcp_to_hf.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import io
|
| 6 |
+
import os
|
| 7 |
+
import tempfile
|
| 8 |
+
from datetime import timedelta
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.serialization
|
| 12 |
+
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
|
| 13 |
+
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
|
| 14 |
+
|
| 15 |
+
import fla # noqa
|
| 16 |
+
from torchtitan.tools.logging import init_logger, logger
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@torch.inference_mode()
def save_pretrained(
    path: str,
    step: int,
    config: str,
    tokenizer: str
):
    """Convert a DCP checkpoint into a Hugging Face-style model directory.

    Reads the distributed checkpoint at ``<path>/checkpoint/step-<step>``,
    converts it to a single torch-save file in a temp dir, loads it into a
    freshly initialized model, and writes config, tokenizer and weights to
    ``path``.

    Args:
        path: Output directory; also the base dir containing ``checkpoint/``.
        step: Training step of the DCP checkpoint to convert.
        config: Model-config path/name loadable by ``AutoConfig``.
        tokenizer: Tokenizer path/name loadable by ``AutoTokenizer``.
    """
    logger.info(f"Loading the config from {config}")
    config = AutoConfig.from_pretrained(config, trust_remote_code=True)

    logger.info(f"Saving the config to {path}")
    config.save_pretrained(path)
    logger.info(f"Loading the tokenizer from {tokenizer}")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True)
    logger.info(f"Saving the tokenizer to {path}")
    tokenizer.save_pretrained(path)

    with tempfile.TemporaryDirectory() as tmpdir:
        # base_checkpoint_dir = os.path.dirname(path)
        base_checkpoint_dir = path
        checkpoint = os.path.join(base_checkpoint_dir, f'checkpoint/step-{step}')
        checkpoint_path = os.path.join(tmpdir, 'checkpoint.pt')
        logger.info(f"Saving the distributed checkpoint to {checkpoint_path}")
        # Collapse the sharded DCP checkpoint into a single torch-save file.
        dcp_to_torch_save(checkpoint, checkpoint_path)

        logger.info(f"Initializing the model from config\n{config}")
        model = AutoModelForCausalLM.from_config(config)
        logger.info(model)
        logger.info("Loading state dict from the checkpoint")

        # Add datetime.timedelta and io.BytesIO to safe globals
        torch.serialization.add_safe_globals([timedelta, io.BytesIO])
        # torch.load now with default weights_only=True will work
        model.load_state_dict(torch.load(checkpoint_path, map_location='cpu')['model'])

        logger.info(f"Saving the model to {path}")
        model.save_pretrained(path)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
if __name__ == "__main__":
    # CLI entry point: convert the DCP checkpoint at
    # <path>/checkpoint/step-<step> into a HF-style model directory at <path>.
    init_logger()
    parser = argparse.ArgumentParser("Convert DCP format model weights to huggingface-style.")
    parser.add_argument("--path", type=str, required=True)
    parser.add_argument("--step", type=int, required=True)
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--tokenizer", type=str, required=True)
    args = parser.parse_args()
    save_pretrained(args.path, args.step, args.config, args.tokenizer)
|
flame/utils/convert_hf_to_dcp.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.distributed.checkpoint as DCP
|
| 9 |
+
from transformers import AutoModelForCausalLM
|
| 10 |
+
|
| 11 |
+
import fla # noqa
|
| 12 |
+
from torchtitan.tools.logging import init_logger, logger
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@torch.inference_mode()
def convert_hf_weights(model: str, checkpoint: Path):
    """Convert a huggingface-style model's weights into a DCP checkpoint.

    Args:
        model: HF model path/name loadable by ``AutoModelForCausalLM``.
        checkpoint: Output directory for the DCP checkpoint. Must be a
            ``Path`` (argparse passes ``type=Path``; ``.mkdir`` is called on it).
    """
    logger.info(f"Loading model from {model}")
    model = AutoModelForCausalLM.from_pretrained(model)
    state_dict = model.state_dict()

    logger.info(f"Writing to DCP at '{checkpoint}'")
    checkpoint.mkdir(parents=True, exist_ok=True)
    storage_writer = DCP.filesystem.FileSystemWriter(checkpoint, thread_count=8)
    DCP.save({"model": state_dict}, storage_writer=storage_writer)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
if __name__ == "__main__":
    # CLI entry point: convert a HF checkpoint into DCP format at --checkpoint.
    init_logger()
    parser = argparse.ArgumentParser(description="Convert huggingface-style model weights to DCP format.")
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--checkpoint", type=Path, required=True)
    args = parser.parse_args()

    convert_hf_weights(args.model, args.checkpoint)
|
flame/utils/hf_utils.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
from huggingface_hub import HfApi, HfFolder, logging as hf_logging, create_repo
|
| 4 |
+
from torchtitan.tools.logging import logger
|
| 5 |
+
|
| 6 |
+
def upload_checkpoint_to_hf(
    local_path: str,
    step: int,
    hf_repo_id_for_run: str,
    hf_keep_latest_k: int,
    upload_format: str
):
    """Uploads a checkpoint directory to HF Hub and manages retention.

    Args:
        local_path: Local checkpoint directory to upload.
        step: Training step; uploaded under ``step-<step>`` in the repo.
        hf_repo_id_for_run: Target Hub repo id (created if missing).
        hf_keep_latest_k: Keep only this many ``step-*`` folders on the Hub;
            values <= 0 disable Hub-side cleanup.
        upload_format: Checkpoint format label used in the commit message.
    """
    if not os.path.isdir(local_path):
        logger.error(f"Local path for upload does not exist or is not a directory: {local_path}")
        return

    api = HfApi()
    token = HfFolder.get_token()
    if not token:
        logger.warning("Hugging Face Hub token not found. Skipping upload. Login via `huggingface-cli login` or set HF_TOKEN.")
        return

    # --- Ensure the specific repository for this run exists ---
    try:
        logger.info(f"Ensuring repository {hf_repo_id_for_run} exists...")
        # Use create_repo which handles creation only if it doesn't exist
        create_repo(repo_id=hf_repo_id_for_run, token=token, repo_type="model", exist_ok=True)
        logger.info(f"Repository {hf_repo_id_for_run} ensured.")
    except Exception as e:
        logger.error(f"Failed to create or ensure repository {hf_repo_id_for_run}: {e}", exc_info=True)
        return  # Stop if repo interaction fails

    commit_message = f"Upload {upload_format.upper()} checkpoint step {step}"
    path_in_repo = f"step-{step}"

    logger.info(f"Uploading {local_path} to {hf_repo_id_for_run}/{path_in_repo} on Hugging Face Hub...")
    try:
        api.upload_folder(
            folder_path=local_path,
            path_in_repo=path_in_repo,
            repo_id=hf_repo_id_for_run,
            repo_type="model",
            commit_message=commit_message,
            token=token,
        )
        logger.info(f"Successfully uploaded step {step} to {hf_repo_id_for_run}.")
    except Exception as e:
        logger.error(f"Failed to upload checkpoint step {step} to {hf_repo_id_for_run}: {e}", exc_info=True)
    # NOTE(review): cleanup below runs even if the upload above failed —
    # confirm this best-effort behavior is intended.
    if hf_keep_latest_k > 0:
        logger.info(f"Cleaning up old checkpoints on {hf_repo_id_for_run}, keeping latest {hf_keep_latest_k}")
        try:
            # List top-level entries only; step folders live at the repo root.
            repo_files = api.list_repo_tree(hf_repo_id_for_run, repo_type="model", token=token, recursive=False)
            step_folders = [
                item.path for item in repo_files
                if item.path.startswith("step-") and item.path[5:].isdigit()
            ]

            # Newest first, by numeric step.
            step_folders.sort(key=lambda x: int(x.split('-')[1]), reverse=True)

            if len(step_folders) > hf_keep_latest_k:
                folders_to_delete = step_folders[hf_keep_latest_k:]
                logger.info(f"Found {len(step_folders)} checkpoints on Hub. Deleting {len(folders_to_delete)} older ones: {folders_to_delete}")
                for folder in folders_to_delete:
                    # Deleting requires repo_id, path_in_repo, and token
                    api.delete_folder(
                        repo_id=hf_repo_id_for_run,
                        path_in_repo=folder,
                        repo_type="model",
                        commit_message=f"Delete old checkpoint {folder}",
                        token=token
                    )
                logger.info("Hub cleanup complete.")
            else:
                logger.info("No old checkpoints found on Hub to delete.")
        except Exception as e:
            logger.error(f"Error during Hub checkpoint cleanup for {hf_repo_id_for_run}: {e}", exc_info=True)
|
logs/none_enyj3lod/attempt_0/5/stderr.log
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
logs/none_enyj3lod/attempt_0/7/stderr.log
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_15872/rank6_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_16384/rank2_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_16384/rank3_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_16384/rank5_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_19968/rank3_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_19968/rank7_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_3072/rank2_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_32256/rank1_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_37376/rank0_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_37376/rank6_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_9216/rank2_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_9216/rank3_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_9216/rank4_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_9216/rank6_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_9216/rank7_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
profile_trace/iteration_9728/rank1_trace.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tb/20250911-1415/wandb/run-20250911_141551-mtp_transformer-mtp.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509111411/files/output.log
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tb/20250911-1415/wandb/run-20250911_141551-mtp_transformer-mtp.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509111411/files/wandb-metadata.json
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"os": "Linux-6.8.0-62-generic-x86_64-with-glibc2.39",
|
| 3 |
+
"python": "CPython 3.12.11",
|
| 4 |
+
"startedAt": "2025-09-11T14:15:51.409164Z",
|
| 5 |
+
"args": [
|
| 6 |
+
"--job.config_file",
|
| 7 |
+
"flame/models/fla.toml",
|
| 8 |
+
"--job.dump_folder",
|
| 9 |
+
"exp/mtp.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine",
|
| 10 |
+
"--model.config",
|
| 11 |
+
"configs/mtp_transformer_7B.json",
|
| 12 |
+
"--model.tokenizer_path",
|
| 13 |
+
"fla-hub/transformer-1.3B-100B",
|
| 14 |
+
"--optimizer.name",
|
| 15 |
+
"AdamW",
|
| 16 |
+
"--optimizer.eps",
|
| 17 |
+
"1e-15",
|
| 18 |
+
"--optimizer.lr",
|
| 19 |
+
"2e-5",
|
| 20 |
+
"--lr_scheduler.warmup_steps",
|
| 21 |
+
"400",
|
| 22 |
+
"--lr_scheduler.lr_min",
|
| 23 |
+
"0.1",
|
| 24 |
+
"--lr_scheduler.decay_type",
|
| 25 |
+
"cosine",
|
| 26 |
+
"--training.batch_size",
|
| 27 |
+
"8",
|
| 28 |
+
"--training.seq_len",
|
| 29 |
+
"4096",
|
| 30 |
+
"--training.context_len",
|
| 31 |
+
"4096",
|
| 32 |
+
"--training.gradient_accumulation_steps",
|
| 33 |
+
"2",
|
| 34 |
+
"--training.steps",
|
| 35 |
+
"40000",
|
| 36 |
+
"--training.max_norm",
|
| 37 |
+
"1.0",
|
| 38 |
+
"--training.skip_nan_inf",
|
| 39 |
+
"--training.dataset",
|
| 40 |
+
"/home/cvm/.cache/zaydzuhri___stack-edu-python/default",
|
| 41 |
+
"--training.dataset_split",
|
| 42 |
+
"train",
|
| 43 |
+
"--training.num_workers",
|
| 44 |
+
"32",
|
| 45 |
+
"--training.prefetch_factor",
|
| 46 |
+
"2",
|
| 47 |
+
"--training.seed",
|
| 48 |
+
"79",
|
| 49 |
+
"--training.compile",
|
| 50 |
+
"--checkpoint.interval",
|
| 51 |
+
"5000",
|
| 52 |
+
"--checkpoint.load_step",
|
| 53 |
+
"-1",
|
| 54 |
+
"--metrics.log_freq",
|
| 55 |
+
"5",
|
| 56 |
+
"--checkpoint.hf_upload_enabled",
|
| 57 |
+
"--checkpoint.hf_repo_base_name",
|
| 58 |
+
"zaydzuhri/mtp-code-7B-4096-batch8x2-steps40000",
|
| 59 |
+
"--comm.init_timeout_seconds",
|
| 60 |
+
"6000",
|
| 61 |
+
"--comm.train_timeout_seconds",
|
| 62 |
+
"6000"
|
| 63 |
+
],
|
| 64 |
+
"program": "-m flame.train",
|
| 65 |
+
"git": {
|
| 66 |
+
"remote": "https://github.com/zaydzuhri/flame.git",
|
| 67 |
+
"commit": "aa4d5932e54fad8a568e10aa6895e69e0664fcf1"
|
| 68 |
+
},
|
| 69 |
+
"email": "zaydzuhri@gmail.com",
|
| 70 |
+
"root": "exp/mtp.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine/tb/20250911-1415",
|
| 71 |
+
"host": "cvm-hnlakfcy",
|
| 72 |
+
"executable": "/home/cvm/miniconda3/envs/flame-env/bin/python3.12",
|
| 73 |
+
"cpu_count": 64,
|
| 74 |
+
"cpu_count_logical": 128,
|
| 75 |
+
"gpu": "NVIDIA H200",
|
| 76 |
+
"gpu_count": 8,
|
| 77 |
+
"disk": {
|
| 78 |
+
"/": {
|
| 79 |
+
"total": "3242363822080",
|
| 80 |
+
"used": "1142457630720"
|
| 81 |
+
}
|
| 82 |
+
},
|
| 83 |
+
"memory": {
|
| 84 |
+
"total": "1913832992768"
|
| 85 |
+
},
|
| 86 |
+
"gpu_nvidia": [
|
| 87 |
+
{
|
| 88 |
+
"name": "NVIDIA H200",
|
| 89 |
+
"memoryTotal": "150754820096",
|
| 90 |
+
"cudaCores": 16896,
|
| 91 |
+
"architecture": "Hopper",
|
| 92 |
+
"uuid": "GPU-248746a8-c843-17da-da73-f7e913ce8534"
|
| 93 |
+
},
|
| 94 |
+
{
|
| 95 |
+
"name": "NVIDIA H200",
|
| 96 |
+
"memoryTotal": "150754820096",
|
| 97 |
+
"cudaCores": 16896,
|
| 98 |
+
"architecture": "Hopper",
|
| 99 |
+
"uuid": "GPU-dd71b7fe-465a-c7fc-695c-92022644d1e4"
|
| 100 |
+
},
|
| 101 |
+
{
|
| 102 |
+
"name": "NVIDIA H200",
|
| 103 |
+
"memoryTotal": "150754820096",
|
| 104 |
+
"cudaCores": 16896,
|
| 105 |
+
"architecture": "Hopper",
|
| 106 |
+
"uuid": "GPU-fa231ade-f7f2-7b4b-7038-1b6ba3478565"
|
| 107 |
+
},
|
| 108 |
+
{
|
| 109 |
+
"name": "NVIDIA H200",
|
| 110 |
+
"memoryTotal": "150754820096",
|
| 111 |
+
"cudaCores": 16896,
|
| 112 |
+
"architecture": "Hopper",
|
| 113 |
+
"uuid": "GPU-6c677375-a50c-d5ca-a517-8bab66e768e5"
|
| 114 |
+
},
|
| 115 |
+
{
|
| 116 |
+
"name": "NVIDIA H200",
|
| 117 |
+
"memoryTotal": "150754820096",
|
| 118 |
+
"cudaCores": 16896,
|
| 119 |
+
"architecture": "Hopper",
|
| 120 |
+
"uuid": "GPU-e98c9a5d-ed96-14c6-fcd7-e32e074c4ec6"
|
| 121 |
+
},
|
| 122 |
+
{
|
| 123 |
+
"name": "NVIDIA H200",
|
| 124 |
+
"memoryTotal": "150754820096",
|
| 125 |
+
"cudaCores": 16896,
|
| 126 |
+
"architecture": "Hopper",
|
| 127 |
+
"uuid": "GPU-0325ab0c-c935-f0f5-8488-1504586966c0"
|
| 128 |
+
},
|
| 129 |
+
{
|
| 130 |
+
"name": "NVIDIA H200",
|
| 131 |
+
"memoryTotal": "150754820096",
|
| 132 |
+
"cudaCores": 16896,
|
| 133 |
+
"architecture": "Hopper",
|
| 134 |
+
"uuid": "GPU-bd82a8ad-cbaf-446f-5012-c8841749fe95"
|
| 135 |
+
},
|
| 136 |
+
{
|
| 137 |
+
"name": "NVIDIA H200",
|
| 138 |
+
"memoryTotal": "150754820096",
|
| 139 |
+
"cudaCores": 16896,
|
| 140 |
+
"architecture": "Hopper",
|
| 141 |
+
"uuid": "GPU-2aeed443-dce7-f05e-6010-086fa04a9413"
|
| 142 |
+
}
|
| 143 |
+
],
|
| 144 |
+
"cudaVersion": "12.8",
|
| 145 |
+
"writerId": "173cboaedkpy0bqi2wup4pr1w4lb2n87"
|
| 146 |
+
}
|
tb/20250911-1415/wandb/run-20250911_141551-mtp_transformer-mtp.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509111411/logs/debug-core.log
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"time":"2025-09-11T14:15:51.448889635Z","level":"INFO","msg":"main: starting server","port-filename":"/tmp/tmpme1m9m73/port-2338706.txt","pid":2338706,"log-level":0,"disable-analytics":false,"shutdown-on-parent-exit":false,"enable-dcgm-profiling":false}
|
| 2 |
+
{"time":"2025-09-11T14:15:51.449765208Z","level":"INFO","msg":"server: will exit if parent process dies","ppid":2338706}
|
| 3 |
+
{"time":"2025-09-11T14:15:51.449708027Z","level":"INFO","msg":"server: accepting connections","addr":{"Name":"/tmp/wandb-2338706-2345351-2853185180/socket","Net":"unix"}}
|
| 4 |
+
{"time":"2025-09-11T14:15:51.617113051Z","level":"INFO","msg":"connection: ManageConnectionData: new connection created","id":"1(@)"}
|
| 5 |
+
{"time":"2025-09-11T14:15:51.621427641Z","level":"INFO","msg":"handleInformInit: received","streamId":"mtp_transformer-mtp.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509111411","id":"1(@)"}
|
| 6 |
+
{"time":"2025-09-11T14:15:51.915371225Z","level":"INFO","msg":"handleInformInit: stream started","streamId":"mtp_transformer-mtp.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509111411","id":"1(@)"}
|
| 7 |
+
{"time":"2025-09-14T18:19:06.56412841Z","level":"INFO","msg":"handleInformFinish: finish message received","streamId":"mtp_transformer-mtp.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509111411","id":"1(@)"}
|
| 8 |
+
{"time":"2025-09-14T18:19:06.565499094Z","level":"INFO","msg":"handleInformFinish: stream closed","streamId":"mtp_transformer-mtp.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509111411","id":"1(@)"}
|
| 9 |
+
{"time":"2025-09-14T18:19:14.63018152Z","level":"INFO","msg":"handleInformTeardown: server teardown initiated","id":"1(@)"}
|
| 10 |
+
{"time":"2025-09-14T18:19:14.630252543Z","level":"INFO","msg":"handleInformTeardown: server shutdown complete","id":"1(@)"}
|
| 11 |
+
{"time":"2025-09-14T18:19:14.630261701Z","level":"INFO","msg":"server is shutting down"}
|
| 12 |
+
{"time":"2025-09-14T18:19:14.630333303Z","level":"INFO","msg":"server: listener closed","addr":{"Name":"/tmp/wandb-2338706-2345351-2853185180/socket","Net":"unix"}}
|
| 13 |
+
{"time":"2025-09-14T18:19:14.630331706Z","level":"INFO","msg":"connection: closing","id":"1(@)"}
|
| 14 |
+
{"time":"2025-09-14T18:19:14.630454034Z","level":"INFO","msg":"connection: closed successfully","id":"1(@)"}
|
| 15 |
+
{"time":"2025-09-14T18:19:14.630460521Z","level":"INFO","msg":"connection: ManageConnectionData: connection closed","id":"1(@)"}
|
| 16 |
+
{"time":"2025-09-14T18:19:14.630484673Z","level":"INFO","msg":"server is closed"}
|
torchtitan/__pycache__/config_manager.cpython-312.pyc
ADDED
|
Binary file (38.5 kB). View file
|
|
|