Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- flame/__pycache__/__init__.cpython-312.pyc +0 -0
- flame/__pycache__/config_manager.cpython-312.pyc +0 -0
- flame/components/__pycache__/checkpoint.cpython-312.pyc +0 -0
- flame/components/checkpoint.py +59 -0
- flame/models/__init__.py +0 -0
- flame/models/__pycache__/__init__.cpython-312.pyc +0 -0
- flame/models/fla.toml +67 -0
- flame/models/parallelize_fla.py +550 -0
- flame/models/pipeline_fla.py +162 -0
- flame/tools/__pycache__/utils.cpython-312.pyc +0 -0
- flame/tools/utils.py +41 -0
- flame/utils/__init__.py +0 -0
- flame/utils/__pycache__/__init__.cpython-312.pyc +0 -0
- flame/utils/__pycache__/convert_dcp_to_hf.cpython-312.pyc +0 -0
- flame/utils/convert_dcp_to_hf.py +66 -0
- flame/utils/convert_hf_to_dcp.py +34 -0
- flame/utils/hf_utils.py +77 -0
- logs/none_g37i6vbo/attempt_0/6/stderr.log +0 -0
- logs/none_lyv0rec_/attempt_0/0/stdout.log +33 -0
- logs/none_lyv0rec_/attempt_0/7/stderr.log +0 -0
- logs/none_lyv0rec_/attempt_0/7/stdout.log +0 -0
- tb/20250909-0619/wandb/debug.log +21 -0
- tb/20250909-0619/wandb/run-20250909_061919-top_transformer-top.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509090614/files/output.log +0 -0
- tb/20250909-0619/wandb/run-20250909_061919-top_transformer-top.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509090614/files/requirements.txt +207 -0
- tb/20250909-0619/wandb/run-20250909_061919-top_transformer-top.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509090614/logs/debug-internal.log +10 -0
- tb/20250909-0619/wandb/run-20250909_061919-top_transformer-top.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509090614/logs/debug.log +21 -0
- torchtitan/components/__pycache__/dataloader.cpython-312.pyc +0 -0
- torchtitan/components/__pycache__/lr_scheduler.cpython-312.pyc +0 -0
- torchtitan/components/__pycache__/metrics.cpython-312.pyc +0 -0
- torchtitan/components/__pycache__/tokenizer.cpython-312.pyc +0 -0
- torchtitan/components/metrics.py +435 -0
- torchtitan/experiments/deepseek_v3/LICENSE-CODE +21 -0
- torchtitan/experiments/deepseek_v3/README.md +40 -0
- torchtitan/experiments/deepseek_v3/checkpoint.py +154 -0
- torchtitan/experiments/deepseek_v3/download.py +70 -0
- torchtitan/experiments/deepseek_v3/model.py +1325 -0
- torchtitan/experiments/deepseek_v3/requirements.txt +5 -0
- torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_utils.py +63 -0
- torchtitan/experiments/flux/README.md +23 -0
- torchtitan/experiments/flux/__pycache__/__init__.cpython-312.pyc +0 -0
- torchtitan/experiments/flux/dataset/flux_dataset.py +267 -0
- torchtitan/experiments/flux/dataset/tokenizer.py +64 -0
- torchtitan/experiments/flux/model/__pycache__/layers.cpython-312.pyc +0 -0
- torchtitan/experiments/flux/model/hf_embedder.py +40 -0
- torchtitan/experiments/flux/model/model.py +177 -0
- torchtitan/experiments/flux/tests/test_flux_dataloader.py +103 -0
- torchtitan/experiments/flux/tests/test_generate_image.py +252 -0
- torchtitan/experiments/flux/train_configs/debug_model.toml +68 -0
- torchtitan/experiments/kernels/triton_mg_group_gemm/benchmark.py +630 -0
- torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/unit_test_forwards.py +82 -0
flame/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (156 Bytes). View file
|
|
|
flame/__pycache__/config_manager.cpython-312.pyc
ADDED
|
Binary file (36.9 kB). View file
|
|
|
flame/components/__pycache__/checkpoint.cpython-312.pyc
ADDED
|
Binary file (3.21 kB). View file
|
|
|
flame/components/checkpoint.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from dataclasses import dataclass, field
|
| 8 |
+
from datetime import timedelta
|
| 9 |
+
from io import BytesIO
|
| 10 |
+
from typing import Any, Dict, List
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from torch.distributed.checkpoint.stateful import Stateful
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@dataclass
|
| 17 |
+
class TrainState(Stateful):
|
| 18 |
+
step: int = 0
|
| 19 |
+
skipped_step: int = 0
|
| 20 |
+
token: int = 0
|
| 21 |
+
elapsed: timedelta = timedelta(0)
|
| 22 |
+
global_avg_losses: List[float] = field(default_factory=list)
|
| 23 |
+
global_max_losses: List[float] = field(default_factory=list)
|
| 24 |
+
log_steps: List[int] = field(default_factory=list)
|
| 25 |
+
|
| 26 |
+
def state_dict(self) -> Dict[str, Any]:
|
| 27 |
+
# Only checkpoint global_avg_losses and global_max_losses per log frequency
|
| 28 |
+
# to avoid sync overhead in every iteration.
|
| 29 |
+
global_avg_losses_bytes = BytesIO()
|
| 30 |
+
torch.save(self.global_avg_losses, global_avg_losses_bytes)
|
| 31 |
+
global_max_losses_bytes = BytesIO()
|
| 32 |
+
torch.save(self.global_max_losses, global_max_losses_bytes)
|
| 33 |
+
log_steps_bytes = BytesIO()
|
| 34 |
+
torch.save(self.log_steps, log_steps_bytes)
|
| 35 |
+
return {
|
| 36 |
+
"step": torch.tensor(self.step, dtype=torch.int32),
|
| 37 |
+
"skipped_step": torch.tensor(self.skipped_step, dtype=torch.int32),
|
| 38 |
+
"token": torch.tensor(self.token, dtype=torch.int64),
|
| 39 |
+
"elapsed": self.elapsed,
|
| 40 |
+
"global_avg_losses": global_avg_losses_bytes,
|
| 41 |
+
"global_max_losses": global_max_losses_bytes,
|
| 42 |
+
"log_steps": log_steps_bytes,
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
def load_state_dict(self, state_dict) -> None:
|
| 46 |
+
self.step = state_dict["step"].item()
|
| 47 |
+
self.skipped_step = state_dict.get("skipped_step", 0).item()
|
| 48 |
+
self.token = state_dict["token"].item()
|
| 49 |
+
self.elapsed = state_dict["elapsed"]
|
| 50 |
+
state_dict["global_avg_losses"].seek(0)
|
| 51 |
+
self.global_avg_losses = torch.load(
|
| 52 |
+
state_dict["global_avg_losses"], weights_only=False
|
| 53 |
+
)
|
| 54 |
+
state_dict["global_max_losses"].seek(0)
|
| 55 |
+
self.global_max_losses = torch.load(
|
| 56 |
+
state_dict["global_max_losses"], weights_only=False
|
| 57 |
+
)
|
| 58 |
+
state_dict["log_steps"].seek(0)
|
| 59 |
+
self.log_steps = torch.load(state_dict["log_steps"], weights_only=False)
|
flame/models/__init__.py
ADDED
|
File without changes
|
flame/models/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (137 Bytes). View file
|
|
|
flame/models/fla.toml
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[model]
|
| 2 |
+
config = "fla-hub/transformer-1.3B-100B"
|
| 3 |
+
tokenizer_path = "fla-hub/transformer-1.3B-100B"
|
| 4 |
+
|
| 5 |
+
[job]
|
| 6 |
+
dump_folder = "exp"
|
| 7 |
+
print_args = true
|
| 8 |
+
|
| 9 |
+
[training]
|
| 10 |
+
batch_size = 32
|
| 11 |
+
seq_len = 2048
|
| 12 |
+
context_len = 2048
|
| 13 |
+
gradient_accumulation_steps = 1
|
| 14 |
+
steps = 20480
|
| 15 |
+
max_norm = 1.0
|
| 16 |
+
skip_nan_inf = true
|
| 17 |
+
data_parallel_replicate_degree = 1
|
| 18 |
+
data_parallel_shard_degree = -1
|
| 19 |
+
tensor_parallel_degree = 1
|
| 20 |
+
compile = false
|
| 21 |
+
dataset = "HuggingFaceFW/fineweb-edu"
|
| 22 |
+
dataset_name = "default"
|
| 23 |
+
num_workers = 32
|
| 24 |
+
pin_memory = false
|
| 25 |
+
persistent_workers = false
|
| 26 |
+
prefetch_factor = 2
|
| 27 |
+
seed = 42
|
| 28 |
+
varlen = false
|
| 29 |
+
|
| 30 |
+
[optimizer]
|
| 31 |
+
name = "AdamW"
|
| 32 |
+
eps = 1e-15
|
| 33 |
+
lr = 3e-4
|
| 34 |
+
|
| 35 |
+
[lr_scheduler]
|
| 36 |
+
warmup_steps = 1024
|
| 37 |
+
decay_type = "cosine"
|
| 38 |
+
lr_min = 0.1
|
| 39 |
+
|
| 40 |
+
[checkpoint]
|
| 41 |
+
enable_checkpoint = true
|
| 42 |
+
folder = "checkpoint"
|
| 43 |
+
interval_type = "steps"
|
| 44 |
+
interval = 2048
|
| 45 |
+
model_weights_only = false
|
| 46 |
+
export_dtype = "float32"
|
| 47 |
+
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
|
| 48 |
+
|
| 49 |
+
[profiling]
|
| 50 |
+
enable_profiling = true
|
| 51 |
+
save_traces_folder = "profile_trace"
|
| 52 |
+
profile_freq = 512
|
| 53 |
+
|
| 54 |
+
[metrics]
|
| 55 |
+
log_freq = 32
|
| 56 |
+
enable_wandb = true
|
| 57 |
+
|
| 58 |
+
[experimental]
|
| 59 |
+
context_parallel_degree = 1
|
| 60 |
+
pipeline_parallel_degree = 1
|
| 61 |
+
|
| 62 |
+
[float8]
|
| 63 |
+
enable_fsdp_float8_all_gather = false
|
| 64 |
+
precompute_float8_dynamic_scale_for_fsdp = false
|
| 65 |
+
|
| 66 |
+
[activation_checkpoint]
|
| 67 |
+
mode = "none"
|
flame/models/parallelize_fla.py
ADDED
|
@@ -0,0 +1,550 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# This file applies the PT-D parallelisms (except pipeline parallelism) and various
|
| 8 |
+
# training techniques (e.g. activation checkpointing and compile) to the Llama model.
|
| 9 |
+
|
| 10 |
+
from collections import defaultdict
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
from torch.distributed import DeviceMesh
|
| 15 |
+
from torch.distributed._composable.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard
|
| 16 |
+
from torch.distributed._composable.replicate import replicate
|
| 17 |
+
from torch.distributed._tensor import Replicate, Shard
|
| 18 |
+
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper as ptd_checkpoint_wrapper
|
| 19 |
+
from torch.distributed.tensor.parallel import (
|
| 20 |
+
ColwiseParallel,
|
| 21 |
+
PrepareModuleInput,
|
| 22 |
+
PrepareModuleOutput,
|
| 23 |
+
RowwiseParallel,
|
| 24 |
+
SequenceParallel,
|
| 25 |
+
parallelize_module
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
from fla.modules.fused_linear_cross_entropy import LinearLossParallel
|
| 29 |
+
from fla.modules.mlp import SwiGLULinearParallel
|
| 30 |
+
from fla.modules.parallel import PrepareModuleWeight
|
| 31 |
+
from torchtitan.config_manager import TORCH_DTYPE_MAP, JobConfig
|
| 32 |
+
from torchtitan.distributed.parallel_dims import ParallelDims
|
| 33 |
+
from torchtitan.tools.logging import logger
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def parallelize_fla(
|
| 37 |
+
model: nn.Module,
|
| 38 |
+
world_mesh: DeviceMesh,
|
| 39 |
+
parallel_dims: ParallelDims,
|
| 40 |
+
job_config: JobConfig,
|
| 41 |
+
):
|
| 42 |
+
"""
|
| 43 |
+
Apply tensor parallelism, activation checkpointing, torch.compile, and data
|
| 44 |
+
parallelism to the model.
|
| 45 |
+
|
| 46 |
+
NOTE: The passed-in model preferably should be on meta device. Otherwise,
|
| 47 |
+
the model must fit on GPU or CPU memory.
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
if parallel_dims.tp_enabled:
|
| 51 |
+
if (
|
| 52 |
+
job_config.experimental.enable_async_tensor_parallel
|
| 53 |
+
and not job_config.training.compile
|
| 54 |
+
):
|
| 55 |
+
raise RuntimeError("Async TP requires --training.compile")
|
| 56 |
+
enable_float8_linear = "float8" in job_config.model.converters
|
| 57 |
+
apply_tp(
|
| 58 |
+
model,
|
| 59 |
+
world_mesh["tp"],
|
| 60 |
+
loss_parallel=parallel_dims.loss_parallel_enabled,
|
| 61 |
+
enable_float8=enable_float8_linear,
|
| 62 |
+
enable_async_tp=job_config.experimental.enable_async_tensor_parallel,
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
if job_config.activation_checkpoint.mode != "none":
|
| 66 |
+
apply_ac(model, job_config.activation_checkpoint)
|
| 67 |
+
|
| 68 |
+
# turn on per-block compile after AC wrapping and before FSDP
|
| 69 |
+
if job_config.training.compile:
|
| 70 |
+
apply_compile(model)
|
| 71 |
+
|
| 72 |
+
if (
|
| 73 |
+
parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled
|
| 74 |
+
): # apply FSDP or HSDP, potentially with Context Parallel
|
| 75 |
+
if parallel_dims.dp_replicate_enabled:
|
| 76 |
+
dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
|
| 77 |
+
else:
|
| 78 |
+
dp_mesh_dim_names = ("dp_shard_cp",)
|
| 79 |
+
|
| 80 |
+
apply_fsdp(
|
| 81 |
+
model,
|
| 82 |
+
world_mesh[tuple(dp_mesh_dim_names)],
|
| 83 |
+
param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
|
| 84 |
+
reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
|
| 85 |
+
pp_enabled=parallel_dims.pp_enabled,
|
| 86 |
+
cpu_offload=job_config.training.enable_cpu_offload,
|
| 87 |
+
reshard_after_forward_policy=job_config.training.fsdp_reshard_after_forward,
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
if parallel_dims.dp_replicate_enabled:
|
| 91 |
+
logger.info("Applied HSDP to the model")
|
| 92 |
+
else:
|
| 93 |
+
logger.info("Applied FSDP to the model")
|
| 94 |
+
|
| 95 |
+
if parallel_dims.cp_enabled:
|
| 96 |
+
logger.info("Applied Context Parallel to the model")
|
| 97 |
+
|
| 98 |
+
if job_config.training.enable_cpu_offload:
|
| 99 |
+
logger.info("Applied CPU Offloading to the model")
|
| 100 |
+
elif parallel_dims.dp_replicate_enabled:
|
| 101 |
+
if world_mesh.ndim > 1:
|
| 102 |
+
raise RuntimeError("DDP has not supported > 1D parallelism")
|
| 103 |
+
apply_ddp(
|
| 104 |
+
model,
|
| 105 |
+
world_mesh,
|
| 106 |
+
enable_compile=job_config.training.compile,
|
| 107 |
+
enable_compiled_autograd=job_config.experimental.enable_compiled_autograd,
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
class TPPlan:
|
| 112 |
+
def __init__(
|
| 113 |
+
self,
|
| 114 |
+
model=None,
|
| 115 |
+
loss_parallel=False,
|
| 116 |
+
enable_float8=False,
|
| 117 |
+
):
|
| 118 |
+
self.model = model
|
| 119 |
+
self.loss_parallel = loss_parallel
|
| 120 |
+
self.enable_float8 = enable_float8
|
| 121 |
+
self.base_model_prefix = getattr(model, "base_model_prefix", "model")
|
| 122 |
+
|
| 123 |
+
# TODO(vkuzo): once float8 configuration supports delayed scaling,
|
| 124 |
+
# add a check here to enforce supported float8 all-gather configurations
|
| 125 |
+
# TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there
|
| 126 |
+
try:
|
| 127 |
+
from torchao.float8.float8_tensor_parallel import (
|
| 128 |
+
Float8ColwiseParallel,
|
| 129 |
+
Float8RowwiseParallel,
|
| 130 |
+
PrepareFloat8ModuleInput
|
| 131 |
+
)
|
| 132 |
+
except ImportError:
|
| 133 |
+
Float8ColwiseParallel = None
|
| 134 |
+
Float8RowwiseParallel = None
|
| 135 |
+
PrepareFloat8ModuleInput = None
|
| 136 |
+
if self.enable_float8 and Float8ColwiseParallel is not None:
|
| 137 |
+
self.rowwise_parallel = Float8RowwiseParallel
|
| 138 |
+
self.colwise_parallel = Float8ColwiseParallel
|
| 139 |
+
self.prepare_module_input = PrepareFloat8ModuleInput
|
| 140 |
+
self.prepare_module_output = PrepareModuleOutput
|
| 141 |
+
else:
|
| 142 |
+
self.rowwise_parallel = RowwiseParallel
|
| 143 |
+
self.colwise_parallel = ColwiseParallel
|
| 144 |
+
self.prepare_module_input = PrepareModuleInput
|
| 145 |
+
self.prepare_module_output = PrepareModuleOutput
|
| 146 |
+
|
| 147 |
+
@property
|
| 148 |
+
def model_plan(self):
|
| 149 |
+
plans = {
|
| 150 |
+
f"{self.base_model_prefix}.embeddings": RowwiseParallel(
|
| 151 |
+
input_layouts=Replicate(),
|
| 152 |
+
output_layouts=Shard(1),
|
| 153 |
+
),
|
| 154 |
+
f"{self.base_model_prefix}.norm": SequenceParallel(),
|
| 155 |
+
}
|
| 156 |
+
if self.loss_parallel:
|
| 157 |
+
plans.update(
|
| 158 |
+
{
|
| 159 |
+
"lm_head": ColwiseParallel(
|
| 160 |
+
input_layouts=Shard(1),
|
| 161 |
+
output_layouts=Shard(-1) if self.loss_parallel else Replicate(),
|
| 162 |
+
use_local_output=not self.loss_parallel,
|
| 163 |
+
),
|
| 164 |
+
}
|
| 165 |
+
)
|
| 166 |
+
else:
|
| 167 |
+
plans.update(
|
| 168 |
+
{
|
| 169 |
+
"lm_head": PrepareModuleWeight(layouts=Replicate()),
|
| 170 |
+
"criterion": LinearLossParallel(),
|
| 171 |
+
}
|
| 172 |
+
)
|
| 173 |
+
return plans
|
| 174 |
+
|
| 175 |
+
@property
|
| 176 |
+
def layer_plan(self):
|
| 177 |
+
return {
|
| 178 |
+
"attn_norm": SequenceParallel(),
|
| 179 |
+
**self.attn_plan,
|
| 180 |
+
"mlp_norm": SequenceParallel(),
|
| 181 |
+
**self.mlp_plan,
|
| 182 |
+
}
|
| 183 |
+
|
| 184 |
+
@property
|
| 185 |
+
def attn_plan(self):
|
| 186 |
+
raise NotImplementedError(
|
| 187 |
+
f"TP plans for token mixing layers of {self.model.config.model_type} not implemented"
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
@property
|
| 191 |
+
def mlp_plan(self):
|
| 192 |
+
return {
|
| 193 |
+
"mlp": self.prepare_module_input(
|
| 194 |
+
input_layouts=(Shard(1),),
|
| 195 |
+
desired_input_layouts=(Replicate(),),
|
| 196 |
+
),
|
| 197 |
+
"mlp.gate_proj": self.colwise_parallel(),
|
| 198 |
+
"mlp.up_proj": self.colwise_parallel(),
|
| 199 |
+
"mlp.down_proj": self.rowwise_parallel(output_layouts=Shard(1)),
|
| 200 |
+
"mlp.swiglu_linear": SwiGLULinearParallel(output_layouts=Shard(1)),
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
class TransformerTPPlan(TPPlan):
|
| 205 |
+
|
| 206 |
+
@property
|
| 207 |
+
def attn_plan(self):
|
| 208 |
+
return {
|
| 209 |
+
"attn": self.prepare_module_input(
|
| 210 |
+
input_kwarg_layouts={"hidden_states": Shard(1)},
|
| 211 |
+
desired_input_kwarg_layouts={"hidden_states": Replicate()},
|
| 212 |
+
),
|
| 213 |
+
"attn.q_proj": self.colwise_parallel(),
|
| 214 |
+
"attn.k_proj": self.colwise_parallel(),
|
| 215 |
+
"attn.v_proj": self.colwise_parallel(),
|
| 216 |
+
"attn.o_proj": self.rowwise_parallel(output_layouts=Shard(1)),
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
class GLATPPlan(TPPlan):
|
| 221 |
+
|
| 222 |
+
@property
|
| 223 |
+
def attn_plan(self):
|
| 224 |
+
return {
|
| 225 |
+
"attn": self.prepare_module_input(
|
| 226 |
+
input_kwarg_layouts={"hidden_states": Shard(1)},
|
| 227 |
+
desired_input_kwarg_layouts={"hidden_states": Replicate()},
|
| 228 |
+
),
|
| 229 |
+
"attn.q_proj": self.colwise_parallel(),
|
| 230 |
+
"attn.k_proj": self.colwise_parallel(),
|
| 231 |
+
"attn.v_proj": self.colwise_parallel(),
|
| 232 |
+
"attn.g_proj": self.colwise_parallel(),
|
| 233 |
+
"attn.gk_proj.0": PrepareModuleWeight(layouts=Replicate()),
|
| 234 |
+
"attn.gk_proj.1": self.colwise_parallel(),
|
| 235 |
+
"attn.g_norm": SequenceParallel(sequence_dim=-1),
|
| 236 |
+
"attn.o_proj": self.rowwise_parallel(output_layouts=Shard(1)),
|
| 237 |
+
}
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
TP_PLAN_MAP = {"transformer": TransformerTPPlan, "gla": GLATPPlan}
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def apply_tp(
|
| 244 |
+
model: nn.Module,
|
| 245 |
+
tp_mesh: DeviceMesh,
|
| 246 |
+
loss_parallel: bool,
|
| 247 |
+
enable_float8: bool,
|
| 248 |
+
enable_async_tp: bool,
|
| 249 |
+
):
|
| 250 |
+
"""Apply tensor parallelism."""
|
| 251 |
+
# 1. Parallelize the embedding and shard its outputs (which are the first
|
| 252 |
+
# transformer block's inputs)
|
| 253 |
+
# 2. Parallelize the root norm layer over the sequence dim
|
| 254 |
+
# 3. Parallelize the final linear output layer
|
| 255 |
+
tp_plan = TP_PLAN_MAP[model.config.model_type](
|
| 256 |
+
model, loss_parallel=loss_parallel, enable_float8=enable_float8
|
| 257 |
+
)
|
| 258 |
+
parallelize_module(model, tp_mesh, tp_plan.model_plan)
|
| 259 |
+
|
| 260 |
+
blocks = get_blocks(model)
|
| 261 |
+
if blocks is None:
|
| 262 |
+
logger.warning("No block found for tensor parallelism")
|
| 263 |
+
else:
|
| 264 |
+
for _, block in enumerate(blocks):
|
| 265 |
+
parallelize_module(
|
| 266 |
+
module=block,
|
| 267 |
+
device_mesh=tp_mesh,
|
| 268 |
+
parallelize_plan=tp_plan.layer_plan,
|
| 269 |
+
)
|
| 270 |
+
|
| 271 |
+
if enable_async_tp:
|
| 272 |
+
from torch.distributed._symmetric_memory import enable_symm_mem_for_group
|
| 273 |
+
|
| 274 |
+
torch._inductor.config._micro_pipeline_tp = True
|
| 275 |
+
enable_symm_mem_for_group(tp_mesh.get_group().group_name)
|
| 276 |
+
|
| 277 |
+
logger.info(
|
| 278 |
+
f"Applied {'Float8 ' if enable_float8 else ''}{'Async ' if enable_async_tp else ''}"
|
| 279 |
+
"Tensor Parallelism to the model"
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
# for selective op activation checkpointing
|
| 284 |
+
_save_list = {
|
| 285 |
+
torch.ops.aten.mm.default,
|
| 286 |
+
torch.ops.aten._scaled_dot_product_efficient_attention.default,
|
| 287 |
+
torch.ops.aten._scaled_dot_product_flash_attention.default,
|
| 288 |
+
torch.ops._c10d_functional.reduce_scatter_tensor.default,
|
| 289 |
+
# for low precision training, it's useful to always save
|
| 290 |
+
# the result of max, since the absolute maximum is
|
| 291 |
+
# used to compute the scaling factor for quantization.
|
| 292 |
+
torch.ops.aten.max.default,
|
| 293 |
+
}
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
def _apply_ac_to_block(module: nn.Module, ac_config):
|
| 297 |
+
valid_ac_modes = ("full", "selective")
|
| 298 |
+
if ac_config.mode not in valid_ac_modes:
|
| 299 |
+
raise ValueError(
|
| 300 |
+
f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}"
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
if ac_config.mode == "full":
|
| 304 |
+
return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
|
| 305 |
+
|
| 306 |
+
assert ac_config.mode == "selective", f"{ac_config.mode}"
|
| 307 |
+
use_op_sac = ac_config.selective_ac_option == "op"
|
| 308 |
+
use_layer_sac = ac_config.selective_ac_option.isdigit()
|
| 309 |
+
if not use_op_sac and not use_layer_sac:
|
| 310 |
+
raise ValueError(
|
| 311 |
+
f"Invalid selective AC option: {ac_config.selective_ac_option}. "
|
| 312 |
+
f"Valid options: 'op' or a positive int representing layer frequency"
|
| 313 |
+
)
|
| 314 |
+
if use_op_sac:
|
| 315 |
+
from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts
|
| 316 |
+
|
| 317 |
+
def _get_custom_policy(meta):
|
| 318 |
+
def _custom_policy(ctx, func, *args, **kwargs):
|
| 319 |
+
mode = "recompute" if ctx.is_recompute else "forward"
|
| 320 |
+
mm_count_key = f"{mode}_mm_count"
|
| 321 |
+
if func == torch.ops.aten.mm.default:
|
| 322 |
+
meta[mm_count_key] += 1
|
| 323 |
+
# Saves output of all compute ops, except every second mm
|
| 324 |
+
to_save = func in _save_list and not (
|
| 325 |
+
func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
|
| 326 |
+
)
|
| 327 |
+
return (
|
| 328 |
+
CheckpointPolicy.MUST_SAVE
|
| 329 |
+
if to_save
|
| 330 |
+
else CheckpointPolicy.PREFER_RECOMPUTE
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
return _custom_policy
|
| 334 |
+
|
| 335 |
+
def selective_checkpointing_context_fn():
|
| 336 |
+
meta = defaultdict(int)
|
| 337 |
+
return create_selective_checkpoint_contexts(_get_custom_policy(meta))
|
| 338 |
+
|
| 339 |
+
return ptd_checkpoint_wrapper(
|
| 340 |
+
module,
|
| 341 |
+
context_fn=selective_checkpointing_context_fn,
|
| 342 |
+
preserve_rng_state=False,
|
| 343 |
+
)
|
| 344 |
+
elif use_layer_sac:
|
| 345 |
+
# Checkpoint every `ac_freq` of the modules passed to this function
|
| 346 |
+
ac_freq = int(ac_config.selective_ac_option)
|
| 347 |
+
ptd_checkpoint_wrapper.__dict__.setdefault("_count", 0)
|
| 348 |
+
ptd_checkpoint_wrapper._count += 1
|
| 349 |
+
if not ac_freq or ptd_checkpoint_wrapper._count % ac_freq == 0:
|
| 350 |
+
return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
|
| 351 |
+
else:
|
| 352 |
+
return module
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
def apply_ac(model: nn.Module, ac_config):
|
| 356 |
+
"""Apply activation checkpointing to the model."""
|
| 357 |
+
blocks = get_blocks(model)
|
| 358 |
+
if blocks is None:
|
| 359 |
+
logger.warning("No block found for activation checkpointing")
|
| 360 |
+
return
|
| 361 |
+
|
| 362 |
+
for layer_id, block in blocks.named_children():
|
| 363 |
+
block = _apply_ac_to_block(block, ac_config)
|
| 364 |
+
blocks.register_module(layer_id, block)
|
| 365 |
+
|
| 366 |
+
logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
def apply_compile(model: nn.Module):
|
| 370 |
+
"""
|
| 371 |
+
Apply torch.compile to each block, which makes compilation efficient due to
|
| 372 |
+
repeated structure. Alternatively one can compile the whole model (after applying DP).
|
| 373 |
+
"""
|
| 374 |
+
|
| 375 |
+
blocks = get_blocks(model)
|
| 376 |
+
if blocks is None:
|
| 377 |
+
logger.warning("No block found for torch.compile")
|
| 378 |
+
else:
|
| 379 |
+
for layer_id, block in blocks.named_children():
|
| 380 |
+
block = torch.compile(block)
|
| 381 |
+
blocks.register_module(layer_id, block)
|
| 382 |
+
logger.info("Compiling each block with torch.compile")
|
| 383 |
+
|
| 384 |
+
real_model = get_model(model)
|
| 385 |
+
|
| 386 |
+
logger.info("Compiling the embedding, norm, and lm_head layers with torch.compile")
|
| 387 |
+
embeddings_key = get_components_name(real_model, "tok_embeddings")
|
| 388 |
+
if embeddings_key is not None:
|
| 389 |
+
embeddings = torch.compile(getattr(real_model, embeddings_key), fullgraph=True)
|
| 390 |
+
real_model.register_module(embeddings_key, embeddings)
|
| 391 |
+
|
| 392 |
+
norm_key = get_components_name(real_model, "norm")
|
| 393 |
+
if norm_key is not None:
|
| 394 |
+
norm = torch.compile(getattr(real_model, norm_key), fullgraph=True)
|
| 395 |
+
real_model.register_module(norm_key, norm)
|
| 396 |
+
|
| 397 |
+
lm_head_key = get_components_name(model, "lm_head")
|
| 398 |
+
if lm_head_key is not None:
|
| 399 |
+
lm_head = torch.compile(getattr(model, lm_head_key), fullgraph=True)
|
| 400 |
+
model.register_module(lm_head_key, lm_head)
|
| 401 |
+
|
| 402 |
+
logger.info("Compiling the entire model with torch.compile")
|
| 403 |
+
model = torch.compile(model)
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
def apply_fsdp(
|
| 407 |
+
model: nn.Module,
|
| 408 |
+
dp_mesh: DeviceMesh,
|
| 409 |
+
param_dtype: torch.dtype,
|
| 410 |
+
reduce_dtype: torch.dtype,
|
| 411 |
+
pp_enabled: bool,
|
| 412 |
+
cpu_offload: bool = False,
|
| 413 |
+
reshard_after_forward_policy: str = "default",
|
| 414 |
+
):
|
| 415 |
+
"""
|
| 416 |
+
Apply data parallelism (via FSDP2) to the model.
|
| 417 |
+
|
| 418 |
+
Args:
|
| 419 |
+
model (nn.Module): The model to apply data parallelism to.
|
| 420 |
+
dp_mesh (DeviceMesh): The device mesh to use for data parallelism.
|
| 421 |
+
param_dtype (torch.dtype): The data type to use for model parameters.
|
| 422 |
+
reduce_dtype (torch.dtype): The data type to use for reduction operations.
|
| 423 |
+
pp_enabled (bool): Whether pipeline parallelism is enabled.
|
| 424 |
+
cpu_offload (bool, optional): Whether to offload model parameters to CPU. Defaults to False.
|
| 425 |
+
reshard_after_forward_policy (str, optional):
|
| 426 |
+
The policy to use for resharding after forward pass. Defaults to "default".
|
| 427 |
+
Other options: "never", "always".
|
| 428 |
+
- "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios.
|
| 429 |
+
- "always" will enable `reshard_after_forward` for all forward passes.
|
| 430 |
+
- "never" will disable `reshard_after_forward` for all forward passes.
|
| 431 |
+
|
| 432 |
+
"""
|
| 433 |
+
mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
|
| 434 |
+
fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
|
| 435 |
+
if cpu_offload:
|
| 436 |
+
fsdp_config["offload_policy"] = CPUOffloadPolicy()
|
| 437 |
+
|
| 438 |
+
blocks = get_blocks(model)
|
| 439 |
+
if blocks is None:
|
| 440 |
+
logger.warning("No block found for FSDP")
|
| 441 |
+
else:
|
| 442 |
+
total_blocks = len(blocks)
|
| 443 |
+
for layer_id, block in enumerate(blocks):
|
| 444 |
+
if reshard_after_forward_policy == "always":
|
| 445 |
+
reshard_after_forward = True
|
| 446 |
+
elif reshard_after_forward_policy == "never":
|
| 447 |
+
reshard_after_forward = False
|
| 448 |
+
elif reshard_after_forward_policy == "default":
|
| 449 |
+
if pp_enabled:
|
| 450 |
+
# For PP, do not reshard after forward to avoid per-microbatch
|
| 451 |
+
# all-gathers, which can be expensive and non-overlapped
|
| 452 |
+
reshard_after_forward = False
|
| 453 |
+
else:
|
| 454 |
+
# As an optimization, do not reshard after forward for the last
|
| 455 |
+
# transformer block since FSDP would prefetch it immediately
|
| 456 |
+
reshard_after_forward = int(layer_id) < total_blocks - 1
|
| 457 |
+
else:
|
| 458 |
+
raise ValueError(
|
| 459 |
+
f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}."
|
| 460 |
+
)
|
| 461 |
+
fully_shard(
|
| 462 |
+
block,
|
| 463 |
+
**fsdp_config,
|
| 464 |
+
reshard_after_forward=reshard_after_forward,
|
| 465 |
+
)
|
| 466 |
+
|
| 467 |
+
fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
def apply_ddp(
    model: nn.Module,
    dp_mesh: DeviceMesh,
    enable_compile: bool,
    enable_compiled_autograd: bool,
):
    """Replicate *model* across *dp_mesh* with DDP.

    When torch.compile is enabled, the dynamo DDP-optimization mode is set
    first so the reducer strategy matches the compiled-autograd setting.
    """
    if enable_compile:
        ddp_mode = (
            "python_reducer_without_compiled_forward"
            if enable_compiled_autograd
            else "ddp_optimizer"
        )
        torch._dynamo.config.optimize_ddp = ddp_mode

    replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)

    logger.info("Applied DDP to the model")
|
| 487 |
+
|
| 488 |
+
|
| 489 |
+
def get_model(model):
    """Return the inner base model (e.g. ``model.model``), or None if absent.

    The attribute name is taken from ``model.base_model_prefix`` when present,
    defaulting to ``"model"`` (the transformers convention).
    """
    prefix = getattr(model, "base_model_prefix", "model")
    return getattr(model, prefix) if hasattr(model, prefix) else None
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
def get_blocks(model):
    """Return the transformer-block container (``layers``) of *model*, or None."""
    # TODO[flame]: adapt for network not using 'layers' attribute
    inner = get_model(model)
    if hasattr(inner, "layers"):
        return inner.layers
    logger.warning('no "layers" in model can be found')
    return None
|
| 504 |
+
|
| 505 |
+
|
| 506 |
+
def get_components_name(model, component_name):
    """
    Resolve the attribute name *model* actually uses for a well-known component.

    We try to catch tok_embeddings, norm layers and lm_head layers.
    We do not catch the layer names in the blocks, for blocks see `get_blocks`.
    We assume the model has the following structure:
    LlamaForCausalLM:
        Model:
            embed_tokens,
            layers,
            norm,
        lm_head
    ***
    so, to search 'tok_embeddings' and 'norm' we need to pass `get_model(model)`
    and for 'lm_head' we need to pass `model`
    ***
    """
    # Known aliases per component, in preference order.
    aliases = {
        "tok_embeddings": ("tok_embeddings", "embed_tokens", "embeddings"),
        "norm": ("norm", "norms", "layernorm"),
        "lm_head": ("lm_head",),
    }
    candidates = aliases.get(component_name)
    if candidates is None:
        # Unknown component name: mirror the original fall-through (no warning).
        return None
    for attr in candidates:
        if hasattr(model, attr):
            return attr
    logger.warning(f"No {component_name} found in model")
    return None
|
flame/models/pipeline_fla.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# This file applies the PT-D pipeline parallelism to the Llama model.
|
| 8 |
+
|
| 9 |
+
import copy
|
| 10 |
+
from typing import Callable, Optional, Union
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
from torch.distributed import DeviceMesh
|
| 15 |
+
from torch.distributed.pipelining import PipelineStage
|
| 16 |
+
from torch.distributed.pipelining.schedules import ScheduleZBVZeroBubble, _PipelineSchedule, get_schedule_class
|
| 17 |
+
from transformers import PretrainedConfig
|
| 18 |
+
|
| 19 |
+
from flame.models.parallelize_fla import get_blocks, get_components_name, get_model
|
| 20 |
+
from torchtitan.config_manager import JobConfig
|
| 21 |
+
from torchtitan.distributed.parallel_dims import ParallelDims
|
| 22 |
+
from torchtitan.distributed.pipeline import build_pipeline_schedule, generate_split_points, stage_ids_this_rank
|
| 23 |
+
from torchtitan.tools.logging import logger
|
| 24 |
+
|
| 25 |
+
DeviceType = Union[int, str, torch.device]
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def pipeline_fla(
    model: nn.Module,
    pp_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
    device: DeviceType,
    model_config: PretrainedConfig,
    loss_fn: Callable[..., torch.Tensor],
) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]:
    """Split *model* into pipeline stages and build the PP schedule.

    Returns ``(schedule, model_chunks, has_first_stage, has_last_stage)``.
    """
    stages, models = pipeline_fla_manual_split(
        model, pp_mesh, parallel_dims, job_config, device, model_config
    )

    pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)

    # This is used in the train loop to determine whether to pass in the input_ids and labels
    has_first_stage = any(stage.is_first for stage in stages)
    has_last_stage = any(stage.is_last for stage in stages)

    return pp_schedule, models, has_first_stage, has_last_stage
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def pipeline_fla_manual_split(
    whole_model: nn.Module,
    pp_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
    device: DeviceType,
    model_config: PretrainedConfig,
) -> tuple[list[PipelineStage], list[nn.Module]]:
    """
    This API extracts one torch.nn.Module objects for the part of the model configured to run inside this stage.

    It wraps the model chunk in a ManualPipelineStage object and returns both the stage and model objects.

    The stage object is used to create a pipeline schedule, and the model object can be used for applying SPMD
    parallelism.
    """
    pp_rank = pp_mesh.get_local_rank()
    pp_size = pp_mesh.size()

    # Split points: either explicitly configured, or generated from the layer count.
    splits = (
        job_config.experimental.pipeline_parallel_split_points
        or generate_split_points(
            job_config, parallel_dims.pp, model_config.num_hidden_layers
        )
    )

    def _build_stage(
        stage_idx: int,
        start_layer: Optional[str],
        stop_layer: Optional[str],
        is_first: bool = False,
        is_last: bool = False,
    ) -> tuple[PipelineStage, nn.Module]:
        # Build one stage's model chunk by deep-copying the full model and
        # deleting the parts that belong to other stages.
        # NOTE: this closure reads `num_stages`, which is assigned below the
        # function definition but before any call — valid, just non-obvious.
        model = copy.deepcopy(whole_model)
        if not is_first:
            # we do `model.tok_embeddings = None` here
            real_model = get_model(model)
            tok_embeddings_name = get_components_name(real_model, "tok_embeddings")
            setattr(real_model, tok_embeddings_name, None)

        # Layers before `start_layer` and from `stop_layer` onward are dropped.
        drop_layers = start_layer is not None
        # Get module dictionary from get_blocks(model)
        # and Create a list of keys before modifying dictionary
        module_dict = get_blocks(model)._modules  # Store reference
        layer_names = list(module_dict.keys())

        # Iterate over the list of keys instead of `_modules.items()`
        for name in layer_names:
            # Dynamically determine prefix (blocks.* or layers.*)
            prefix = start_layer.split(".")[0] if start_layer else "layers"
            layer_name = f"{prefix}.{name}"  # Construct the correct name format

            # Ensure `drop_layers` activation is based on actual naming
            if layer_name == start_layer:
                drop_layers = False
            if layer_name == stop_layer:
                drop_layers = True

            # Delete layer if drop_layers is active
            if drop_layers:
                del module_dict[name]  # Safe deletion from stored dictionary

        if not is_last:
            # we do `model.norm = None` and `model.output = None`
            real_model = get_model(model)
            norm_name = get_components_name(real_model, "norm")
            setattr(real_model, norm_name, None)

            # lm_head lives on the outer wrapper, not the inner base model.
            head_name = get_components_name(model, "lm_head")
            setattr(model, head_name, None)

        stage = PipelineStage(
            model,
            stage_idx,
            num_stages,
            device,
            group=pp_mesh.get_group("pp"),
        )
        return stage, model

    num_stages = len(splits) + 1
    stage_idx = pp_rank

    stages = []
    models = []

    # ZBV schedules assign stages to ranks in a "v" pattern; others loop.
    schedule_class = get_schedule_class(
        job_config.experimental.pipeline_parallel_schedule
    )
    style = "v" if schedule_class == ScheduleZBVZeroBubble else "loop"

    for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style=style):
        # Stage i keeps layers in [splits[i-1], splits[i]); boundaries are open
        # at the ends (first stage has no start, last stage has no stop).
        start_layer = splits[stage_idx - 1] if stage_idx > 0 else None
        stop_layer = splits[stage_idx] if stage_idx < num_stages - 1 else None
        stage, model_chunk = _build_stage(
            stage_idx,
            start_layer,
            stop_layer,
            is_first=stage_idx == 0,
            is_last=stage_idx == num_stages - 1,
        )
        logger.info(
            f"PP rank {pp_rank} is building stage_idx {stage_idx}"
            f" with start_layer {start_layer}, stop_layer {stop_layer}"
        )
        stages.append(stage)
        models.append(model_chunk)
    return stages, models
|
flame/tools/__pycache__/utils.cpython-312.pyc
ADDED
|
Binary file (2.14 kB). View file
|
|
|
flame/tools/utils.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from torch import nn
|
| 8 |
+
from torchtitan.tools.logging import logger
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def get_nparams_and_flops(model: nn.Module, model_config, seq_len: int) -> tuple[int, int]:
    """Count parameters and estimate training FLOPs per token.

    Returns ``(nparams, num_flops_per_token)``.  The estimate is the standard
    ``6 * N_non_embedding + 12 * l * h * q * t`` formula (dense matmuls plus
    self-attention), where embedding parameters are excluded from the 6N term.
    """
    nparams = sum(p.numel() for p in model.parameters())

    # Parameters held by top-level nn.Embedding children (excluded from 6N).
    nparams_embedding = 0
    for child in model.children():
        if isinstance(child, nn.Embedding):
            nparams_embedding += sum(p.numel() for p in child.parameters())

    # Head count: accept either `num_heads` or the HF-style `num_attention_heads`.
    if hasattr(model_config, "num_heads"):
        num_heads = model_config.num_heads
    elif hasattr(model_config, "num_attention_heads"):
        num_heads = model_config.num_attention_heads
    else:
        logger.warning("num_heads not found in model_config, defaulting to 1. ")
        num_heads = 1

    n_layers = model_config.num_hidden_layers
    head_dim = model_config.hidden_size // num_heads

    # Reasoning behind the factor of 12 for the self-attention part of the formula:
    # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6)
    # 2. the flash attention does 1 more matmul recomputation in the backward
    #    but recomputation should not be counted in calculating MFU (+0)
    # 3. each matmul performs 1 multiplication and 1 addition (*2)
    # 4. we follow the convention and do not account for sparsity in causal attention
    attn_flops = 12 * n_layers * num_heads * head_dim * seq_len
    num_flops_per_token = 6 * (nparams - nparams_embedding) + attn_flops

    return nparams, num_flops_per_token
|
flame/utils/__init__.py
ADDED
|
File without changes
|
flame/utils/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (136 Bytes). View file
|
|
|
flame/utils/__pycache__/convert_dcp_to_hf.cpython-312.pyc
ADDED
|
Binary file (3.73 kB). View file
|
|
|
flame/utils/convert_dcp_to_hf.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import io
|
| 6 |
+
import os
|
| 7 |
+
import tempfile
|
| 8 |
+
from datetime import timedelta
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.serialization
|
| 12 |
+
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
|
| 13 |
+
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
|
| 14 |
+
|
| 15 |
+
import fla # noqa
|
| 16 |
+
from torchtitan.tools.logging import init_logger, logger
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@torch.inference_mode()
def save_pretrained(
    path: str,
    step: int,
    config: str,
    tokenizer: str
):
    """Convert the DCP checkpoint at ``<path>/checkpoint/step-<step>`` into a
    HuggingFace-style model directory at *path*.

    ``config`` and ``tokenizer`` are names/paths accepted by
    ``AutoConfig.from_pretrained`` / ``AutoTokenizer.from_pretrained``.
    """
    logger.info(f"Loading the config from {config}")
    hf_config = AutoConfig.from_pretrained(config, trust_remote_code=True)

    logger.info(f"Saving the config to {path}")
    hf_config.save_pretrained(path)

    logger.info(f"Loading the tokenizer from {tokenizer}")
    hf_tokenizer = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True)
    logger.info(f"Saving the tokenizer to {path}")
    hf_tokenizer.save_pretrained(path)

    with tempfile.TemporaryDirectory() as tmpdir:
        # The DCP checkpoint lives under the run directory itself.
        # base_checkpoint_dir = os.path.dirname(path)
        dcp_dir = os.path.join(path, f'checkpoint/step-{step}')
        torch_ckpt = os.path.join(tmpdir, 'checkpoint.pt')
        logger.info(f"Saving the distributed checkpoint to {torch_ckpt}")
        dcp_to_torch_save(dcp_dir, torch_ckpt)

        logger.info(f"Initializing the model from config\n{hf_config}")
        model = AutoModelForCausalLM.from_config(hf_config)
        logger.info(model)
        logger.info("Loading state dict from the checkpoint")

        # Add datetime.timedelta and io.BytesIO to safe globals so that
        # torch.load with the default weights_only=True can deserialize them.
        torch.serialization.add_safe_globals([timedelta, io.BytesIO])
        model.load_state_dict(torch.load(torch_ckpt, map_location='cpu')['model'])

        logger.info(f"Saving the model to {path}")
        model.save_pretrained(path)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
if __name__ == "__main__":
    # CLI entry point: convert a DCP training checkpoint into a
    # HuggingFace-style directory (config + tokenizer + weights).
    init_logger()
    parser = argparse.ArgumentParser("Convert DCP format model weights to huggingface-style.")
    parser.add_argument("--path", type=str, required=True)  # run dir containing checkpoint/step-<step>
    parser.add_argument("--step", type=int, required=True)  # training step to convert
    parser.add_argument("--config", type=str, required=True)  # HF config name or path
    parser.add_argument("--tokenizer", type=str, required=True)  # HF tokenizer name or path
    args = parser.parse_args()
    save_pretrained(args.path, args.step, args.config, args.tokenizer)
|
flame/utils/convert_hf_to_dcp.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.distributed.checkpoint as DCP
|
| 9 |
+
from transformers import AutoModelForCausalLM
|
| 10 |
+
|
| 11 |
+
import fla # noqa
|
| 12 |
+
from torchtitan.tools.logging import init_logger, logger
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@torch.inference_mode()
def convert_hf_weights(model: str, checkpoint: Path):
    """Load a HuggingFace model and re-save its weights in DCP format.

    Args:
        model: model name or path accepted by ``AutoModelForCausalLM.from_pretrained``.
        checkpoint: output directory for the DCP shards (created if missing).
            Must be a ``Path`` — argparse passes ``type=Path``; the annotation
            previously said ``str``, but ``.mkdir`` below is a pathlib method.
    """
    logger.info(f"Loading model from {model}")
    model = AutoModelForCausalLM.from_pretrained(model)
    state_dict = model.state_dict()

    logger.info(f"Writing to DCP at '{checkpoint}'")
    checkpoint.mkdir(parents=True, exist_ok=True)
    storage_writer = DCP.filesystem.FileSystemWriter(checkpoint, thread_count=8)
    DCP.save({"model": state_dict}, storage_writer=storage_writer)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
if __name__ == "__main__":
    # CLI entry point: convert HF-style weights into a DCP checkpoint directory.
    init_logger()
    parser = argparse.ArgumentParser(description="Convert huggingface-style model weights to DCP format.")
    parser.add_argument("--model", type=str, required=True)  # HF model name or path
    parser.add_argument("--checkpoint", type=Path, required=True)  # output DCP directory
    args = parser.parse_args()

    convert_hf_weights(args.model, args.checkpoint)
|
flame/utils/hf_utils.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
from huggingface_hub import HfApi, HfFolder, logging as hf_logging, create_repo
|
| 4 |
+
from torchtitan.tools.logging import logger
|
| 5 |
+
|
| 6 |
+
def upload_checkpoint_to_hf(
    local_path: str,
    step: int,
    hf_repo_id_for_run: str,
    hf_keep_latest_k: int,
    upload_format: str
):
    """Uploads a checkpoint directory to HF Hub and manages retention.

    Args:
        local_path: local checkpoint directory to upload.
        step: training step; stored under ``step-<step>`` in the repo.
        hf_repo_id_for_run: target Hub repo id (created if missing).
        hf_keep_latest_k: keep only this many latest ``step-*`` folders on the
            Hub; 0 or negative disables cleanup.
        upload_format: label used in the commit message (e.g. "dcp", "hf").
    """
    if not os.path.isdir(local_path):
        logger.error(f"Local path for upload does not exist or is not a directory: {local_path}")
        return

    api = HfApi()
    token = HfFolder.get_token()
    if not token:
        logger.warning("Hugging Face Hub token not found. Skipping upload. Login via `huggingface-cli login` or set HF_TOKEN.")
        return

    # --- Ensure the specific repository for this run exists ---
    try:
        logger.info(f"Ensuring repository {hf_repo_id_for_run} exists...")
        # Use create_repo which handles creation only if it doesn't exist
        create_repo(repo_id=hf_repo_id_for_run, token=token, repo_type="model", exist_ok=True)
        logger.info(f"Repository {hf_repo_id_for_run} ensured.")
    except Exception as e:
        logger.error(f"Failed to create or ensure repository {hf_repo_id_for_run}: {e}", exc_info=True)
        return  # Stop if repo interaction fails

    commit_message = f"Upload {upload_format.upper()} checkpoint step {step}"
    path_in_repo = f"step-{step}"

    logger.info(f"Uploading {local_path} to {hf_repo_id_for_run}/{path_in_repo} on Hugging Face Hub...")
    try:
        api.upload_folder(
            folder_path=local_path,
            path_in_repo=path_in_repo,
            repo_id=hf_repo_id_for_run,
            repo_type="model",
            commit_message=commit_message,
            token=token,
        )
        logger.info(f"Successfully uploaded step {step} to {hf_repo_id_for_run}.")
    except Exception as e:
        logger.error(f"Failed to upload checkpoint step {step} to {hf_repo_id_for_run}: {e}", exc_info=True)
        # BUGFIX: previously execution fell through to the retention cleanup
        # below even when the upload failed, which could delete older Hub
        # checkpoints while the newest one never landed. Stop here instead.
        return

    if hf_keep_latest_k > 0:
        logger.info(f"Cleaning up old checkpoints on {hf_repo_id_for_run}, keeping latest {hf_keep_latest_k}")
        try:
            repo_files = api.list_repo_tree(hf_repo_id_for_run, repo_type="model", token=token, recursive=False)
            # Only top-level folders named step-<int> participate in retention.
            step_folders = [
                item.path for item in repo_files
                if item.path.startswith("step-") and item.path[5:].isdigit()
            ]

            # Newest first, so the tail of the list is what gets deleted.
            step_folders.sort(key=lambda x: int(x.split('-')[1]), reverse=True)

            if len(step_folders) > hf_keep_latest_k:
                folders_to_delete = step_folders[hf_keep_latest_k:]
                logger.info(f"Found {len(step_folders)} checkpoints on Hub. Deleting {len(folders_to_delete)} older ones: {folders_to_delete}")
                for folder in folders_to_delete:
                    # Deleting requires repo_id, path_in_repo, and token
                    api.delete_folder(
                        repo_id=hf_repo_id_for_run,
                        path_in_repo=folder,
                        repo_type="model",
                        commit_message=f"Delete old checkpoint {folder}",
                        token=token
                    )
                logger.info("Hub cleanup complete.")
            else:
                logger.info("No old checkpoints found on Hub to delete.")
        except Exception as e:
            logger.error(f"Error during Hub checkpoint cleanup for {hf_repo_id_for_run}: {e}", exc_info=True)
|
logs/none_g37i6vbo/attempt_0/6/stderr.log
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
logs/none_lyv0rec_/attempt_0/0/stdout.log
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[2m2025-09-10T00:25:50.402942Z[0m [33m WARN[0m [33mStatus Code: 502. Retrying..., [1;33mrequest_id[0m[33m: ""[0m
|
| 2 |
+
[2;3mat[0m /home/runner/work/xet-core/xet-core/cas_client/src/http_client.rs:220
|
| 3 |
+
|
| 4 |
+
[2m2025-09-10T00:25:50.448322Z[0m [33m WARN[0m [33mStatus Code: 502. Retrying..., [1;33mrequest_id[0m[33m: ""[0m
|
| 5 |
+
[2;3mat[0m /home/runner/work/xet-core/xet-core/cas_client/src/http_client.rs:220
|
| 6 |
+
|
| 7 |
+
[2m2025-09-10T00:26:01.892901Z[0m [33m WARN[0m [33mStatus Code: 504. Retrying..., [1;33mrequest_id[0m[33m: ""[0m
|
| 8 |
+
[2;3mat[0m /home/runner/work/xet-core/xet-core/cas_client/src/http_client.rs:220
|
| 9 |
+
|
| 10 |
+
[2m2025-09-10T00:26:01.894451Z[0m [33m WARN[0m [33mStatus Code: 504. Retrying..., [1;33mrequest_id[0m[33m: ""[0m
|
| 11 |
+
[2;3mat[0m /home/runner/work/xet-core/xet-core/cas_client/src/http_client.rs:220
|
| 12 |
+
|
| 13 |
+
[2m2025-09-10T00:26:46.358405Z[0m [33m WARN[0m [33mStatus Code: 504. Retrying..., [1;33mrequest_id[0m[33m: ""[0m
|
| 14 |
+
[2;3mat[0m /home/runner/work/xet-core/xet-core/cas_client/src/http_client.rs:220
|
| 15 |
+
|
| 16 |
+
[2m2025-09-10T00:26:50.304225Z[0m [33m WARN[0m [33mStatus Code: 502. Retrying..., [1;33mrequest_id[0m[33m: ""[0m
|
| 17 |
+
[2;3mat[0m /home/runner/work/xet-core/xet-core/cas_client/src/http_client.rs:220
|
| 18 |
+
|
| 19 |
+
[2m2025-09-10T00:27:00.830860Z[0m [33m WARN[0m [33mStatus Code: 504. Retrying..., [1;33mrequest_id[0m[33m: ""[0m
|
| 20 |
+
[2;3mat[0m /home/runner/work/xet-core/xet-core/cas_client/src/http_client.rs:220
|
| 21 |
+
|
| 22 |
+
[2m2025-09-10T00:28:33.662622Z[0m [33m WARN[0m [33mStatus Code: 502. Retrying..., [1;33mrequest_id[0m[33m: ""[0m
|
| 23 |
+
[2;3mat[0m /home/runner/work/xet-core/xet-core/cas_client/src/http_client.rs:220
|
| 24 |
+
|
| 25 |
+
[2m2025-09-10T00:37:21.678500Z[0m [33m WARN[0m [33mStatus Code: 502. Retrying..., [1;33mrequest_id[0m[33m: ""[0m
|
| 26 |
+
[2;3mat[0m /home/runner/work/xet-core/xet-core/cas_client/src/http_client.rs:220
|
| 27 |
+
|
| 28 |
+
[2m2025-09-10T00:37:33.396089Z[0m [33m WARN[0m [33mStatus Code: 504. Retrying..., [1;33mrequest_id[0m[33m: ""[0m
|
| 29 |
+
[2;3mat[0m /home/runner/work/xet-core/xet-core/cas_client/src/http_client.rs:220
|
| 30 |
+
|
| 31 |
+
[2m2025-09-10T00:38:21.672469Z[0m [33m WARN[0m [33mStatus Code: 502. Retrying..., [1;33mrequest_id[0m[33m: ""[0m
|
| 32 |
+
[2;3mat[0m /home/runner/work/xet-core/xet-core/cas_client/src/http_client.rs:220
|
| 33 |
+
|
logs/none_lyv0rec_/attempt_0/7/stderr.log
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
logs/none_lyv0rec_/attempt_0/7/stdout.log
ADDED
|
File without changes
|
tb/20250909-0619/wandb/debug.log
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
2025-09-09 06:19:19,818 INFO MainThread:795439 [wandb_setup.py:_flush():80] Current SDK version is 0.21.0
|
| 2 |
+
2025-09-09 06:19:19,818 INFO MainThread:795439 [wandb_setup.py:_flush():80] Configure stats pid to 795439
|
| 3 |
+
2025-09-09 06:19:19,818 INFO MainThread:795439 [wandb_setup.py:_flush():80] Loading settings from /home/cvm/.config/wandb/settings
|
| 4 |
+
2025-09-09 06:19:19,818 INFO MainThread:795439 [wandb_setup.py:_flush():80] Loading settings from /home/cvm/flame/wandb/settings
|
| 5 |
+
2025-09-09 06:19:19,818 INFO MainThread:795439 [wandb_setup.py:_flush():80] Loading settings from environment variables
|
| 6 |
+
2025-09-09 06:19:19,818 INFO MainThread:795439 [wandb_init.py:setup_run_log_directory():703] Logging user logs to exp/top.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine/tb/20250909-0619/wandb/run-20250909_061919-top_transformer-top.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509090614/logs/debug.log
|
| 7 |
+
2025-09-09 06:19:19,818 INFO MainThread:795439 [wandb_init.py:setup_run_log_directory():704] Logging internal logs to exp/top.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine/tb/20250909-0619/wandb/run-20250909_061919-top_transformer-top.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509090614/logs/debug-internal.log
|
| 8 |
+
2025-09-09 06:19:19,818 INFO MainThread:795439 [wandb_init.py:init():830] calling init triggers
|
| 9 |
+
2025-09-09 06:19:19,818 INFO MainThread:795439 [wandb_init.py:init():835] wandb.init called with sweep_config: {}
|
| 10 |
+
config: {'_wandb': {}}
|
| 11 |
+
2025-09-09 06:19:19,818 INFO MainThread:795439 [wandb_init.py:init():871] starting backend
|
| 12 |
+
2025-09-09 06:19:20,025 INFO MainThread:795439 [wandb_init.py:init():874] sending inform_init request
|
| 13 |
+
2025-09-09 06:19:20,027 INFO MainThread:795439 [wandb_init.py:init():882] backend started and connected
|
| 14 |
+
2025-09-09 06:19:20,033 INFO MainThread:795439 [wandb_init.py:init():953] updated telemetry
|
| 15 |
+
2025-09-09 06:19:20,039 INFO MainThread:795439 [wandb_init.py:init():977] communicating run to backend with 90.0 second timeout
|
| 16 |
+
2025-09-09 06:19:20,682 INFO MainThread:795439 [wandb_init.py:init():1029] starting run threads in backend
|
| 17 |
+
2025-09-09 06:19:20,815 INFO MainThread:795439 [wandb_run.py:_console_start():2458] atexit reg
|
| 18 |
+
2025-09-09 06:19:20,815 INFO MainThread:795439 [wandb_run.py:_redirect():2306] redirect: wrap_raw
|
| 19 |
+
2025-09-09 06:19:20,815 INFO MainThread:795439 [wandb_run.py:_redirect():2375] Wrapping output streams.
|
| 20 |
+
2025-09-09 06:19:20,815 INFO MainThread:795439 [wandb_run.py:_redirect():2398] Redirects installed.
|
| 21 |
+
2025-09-09 06:19:20,817 INFO MainThread:795439 [wandb_init.py:init():1075] run started, returning control to user process
|
tb/20250909-0619/wandb/run-20250909_061919-top_transformer-top.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509090614/files/output.log
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tb/20250909-0619/wandb/run-20250909_061919-top_transformer-top.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509090614/files/requirements.txt
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
flame==0.1.0
|
| 2 |
+
pluggy==1.6.0
|
| 3 |
+
triton==3.2.0
|
| 4 |
+
sympy==1.13.1
|
| 5 |
+
wcwidth==0.2.13
|
| 6 |
+
nvidia-cusolver-cu12==11.6.1.9
|
| 7 |
+
peft==0.17.0
|
| 8 |
+
smart_open==7.3.0.post1
|
| 9 |
+
cymem==2.0.11
|
| 10 |
+
spacy-legacy==3.0.12
|
| 11 |
+
h11==0.16.0
|
| 12 |
+
pytablewriter==1.2.1
|
| 13 |
+
idna==3.10
|
| 14 |
+
regex==2025.7.34
|
| 15 |
+
antlr4-python3-runtime==4.13.2
|
| 16 |
+
wandb==0.21.0
|
| 17 |
+
nvidia-cuda-cupti-cu12==12.4.127
|
| 18 |
+
sentencepiece==0.2.1
|
| 19 |
+
zstandard==0.23.0
|
| 20 |
+
pybind11==3.0.0
|
| 21 |
+
inquirerpy==0.3.4
|
| 22 |
+
contourpy==1.3.3
|
| 23 |
+
Pygments==2.19.2
|
| 24 |
+
sniffio==1.3.1
|
| 25 |
+
Jinja2==3.1.6
|
| 26 |
+
packaging==25.0
|
| 27 |
+
Markdown==3.8.2
|
| 28 |
+
astunparse==1.6.3
|
| 29 |
+
spacy==3.8.7
|
| 30 |
+
pyparsing==3.2.3
|
| 31 |
+
networkx==3.5
|
| 32 |
+
ninja==1.11.1.4
|
| 33 |
+
tf-slim==1.1.0
|
| 34 |
+
PyYAML==6.0.2
|
| 35 |
+
smmap==5.0.2
|
| 36 |
+
tiktoken==0.9.0
|
| 37 |
+
flatbuffers==25.2.10
|
| 38 |
+
tensorflow==2.20.0
|
| 39 |
+
langcodes==3.5.0
|
| 40 |
+
nvidia-cuda-nvrtc-cu12==12.4.127
|
| 41 |
+
numexpr==2.11.0
|
| 42 |
+
charset-normalizer==3.4.3
|
| 43 |
+
frozenlist==1.7.0
|
| 44 |
+
setuptools==80.9.0
|
| 45 |
+
cycler==0.12.1
|
| 46 |
+
weasel==0.4.1
|
| 47 |
+
tzdata==2025.2
|
| 48 |
+
sacrebleu==2.5.1
|
| 49 |
+
rouge_score==0.1.2
|
| 50 |
+
requests==2.32.5
|
| 51 |
+
nvidia-nvjitlink-cu12==12.4.127
|
| 52 |
+
grpcio==1.74.0
|
| 53 |
+
nvidia-cusparse-cu12==12.3.1.170
|
| 54 |
+
mdurl==0.1.2
|
| 55 |
+
pandas==2.3.1
|
| 56 |
+
preshed==3.0.10
|
| 57 |
+
attrs==25.3.0
|
| 58 |
+
tensorboard-data-server==0.7.2
|
| 59 |
+
aiohappyeyeballs==2.6.1
|
| 60 |
+
keras==3.11.2
|
| 61 |
+
wrapt==1.17.3
|
| 62 |
+
aiosignal==1.4.0
|
| 63 |
+
tcolorpy==0.1.7
|
| 64 |
+
platformdirs==4.3.8
|
| 65 |
+
tqdm-multiprocess==0.0.11
|
| 66 |
+
python-dotenv==1.1.1
|
| 67 |
+
wasabi==1.1.3
|
| 68 |
+
google-pasta==0.2.0
|
| 69 |
+
optree==0.17.0
|
| 70 |
+
MarkupSafe==3.0.2
|
| 71 |
+
colorlog==6.9.0
|
| 72 |
+
nvidia-cufft-cu12==11.2.1.3
|
| 73 |
+
lm_eval==0.4.9.1
|
| 74 |
+
lxml==6.0.0
|
| 75 |
+
protobuf==6.32.0
|
| 76 |
+
radgraph==0.1.18
|
| 77 |
+
scipy==1.16.1
|
| 78 |
+
click==8.2.1
|
| 79 |
+
wheel==0.45.1
|
| 80 |
+
marisa-trie==1.3.0
|
| 81 |
+
pathvalidate==3.3.1
|
| 82 |
+
nvidia-nccl-cu12==2.21.5
|
| 83 |
+
evaluate==0.4.5
|
| 84 |
+
nvidia-cuda-runtime-cu12==12.4.127
|
| 85 |
+
transformers==4.51.3
|
| 86 |
+
aenum==3.1.15
|
| 87 |
+
typing-inspection==0.4.1
|
| 88 |
+
gitdb==4.0.12
|
| 89 |
+
iniconfig==2.1.0
|
| 90 |
+
multidict==6.6.3
|
| 91 |
+
huggingface-hub==0.34.4
|
| 92 |
+
tokenizers==0.21.4
|
| 93 |
+
tabledata==1.3.4
|
| 94 |
+
mbstrdecoder==1.1.4
|
| 95 |
+
Werkzeug==3.1.3
|
| 96 |
+
accelerate==1.10.0
|
| 97 |
+
hf-xet==1.1.8
|
| 98 |
+
tensorboard==2.20.0
|
| 99 |
+
ml_dtypes==0.5.3
|
| 100 |
+
pytest==8.4.1
|
| 101 |
+
namex==0.1.0
|
| 102 |
+
pillow==11.3.0
|
| 103 |
+
datasets==3.6.0
|
| 104 |
+
tqdm==4.67.1
|
| 105 |
+
murmurhash==1.0.13
|
| 106 |
+
fonttools==4.59.1
|
| 107 |
+
absl-py==2.3.1
|
| 108 |
+
multiprocess==0.70.16
|
| 109 |
+
fsspec==2025.3.0
|
| 110 |
+
transformers==4.51.3
|
| 111 |
+
dill==0.3.8
|
| 112 |
+
propcache==0.3.2
|
| 113 |
+
jsonpickle==4.1.1
|
| 114 |
+
BLEURT==0.0.2
|
| 115 |
+
yarl==1.20.1
|
| 116 |
+
portalocker==3.2.0
|
| 117 |
+
httpx==0.27.2
|
| 118 |
+
numpy==2.3.2
|
| 119 |
+
mpmath==1.3.0
|
| 120 |
+
pyarrow==21.0.0
|
| 121 |
+
matplotlib==3.10.5
|
| 122 |
+
typepy==1.3.4
|
| 123 |
+
pycountry==24.6.1
|
| 124 |
+
word2number==1.1
|
| 125 |
+
psutil==7.0.0
|
| 126 |
+
catalogue==2.0.10
|
| 127 |
+
latex2sympy2_extended==1.0.6
|
| 128 |
+
pydantic_core==2.33.2
|
| 129 |
+
threadpoolctl==3.6.0
|
| 130 |
+
spacy-loggers==1.0.5
|
| 131 |
+
certifi==2025.8.3
|
| 132 |
+
confection==0.1.5
|
| 133 |
+
flame==0.1.0
|
| 134 |
+
pfzy==0.3.4
|
| 135 |
+
safetensors==0.6.2
|
| 136 |
+
pip==25.1
|
| 137 |
+
DataProperty==1.1.0
|
| 138 |
+
lighteval==0.10.1.dev0
|
| 139 |
+
jsonlines==4.0.0
|
| 140 |
+
scikit-learn==1.7.1
|
| 141 |
+
torch==2.6.0
|
| 142 |
+
pytz==2025.2
|
| 143 |
+
python-dateutil==2.9.0.post0
|
| 144 |
+
nltk==3.9.1
|
| 145 |
+
sqlitedict==2.1.0
|
| 146 |
+
gast==0.6.0
|
| 147 |
+
nvidia-curand-cu12==10.3.5.147
|
| 148 |
+
rich==14.1.0
|
| 149 |
+
sentry-sdk==2.33.2
|
| 150 |
+
nvidia-cusparselt-cu12==0.6.2
|
| 151 |
+
kiwisolver==1.4.9
|
| 152 |
+
appdirs==1.4.4
|
| 153 |
+
bert-score==0.3.13
|
| 154 |
+
blis==1.3.0
|
| 155 |
+
GitPython==3.1.45
|
| 156 |
+
chardet==5.2.0
|
| 157 |
+
more-itertools==10.7.0
|
| 158 |
+
filelock==3.19.1
|
| 159 |
+
transformers==4.51.3
|
| 160 |
+
httpcore==1.0.9
|
| 161 |
+
termcolor==3.1.0
|
| 162 |
+
typer==0.16.1
|
| 163 |
+
einops==0.8.1
|
| 164 |
+
torchdata==0.11.0
|
| 165 |
+
six==1.17.0
|
| 166 |
+
colorama==0.4.6
|
| 167 |
+
aiohttp==3.12.14
|
| 168 |
+
srsly==2.5.1
|
| 169 |
+
urllib3==2.5.0
|
| 170 |
+
nvidia-cublas-cu12==12.4.5.8
|
| 171 |
+
cloudpathlib==0.21.1
|
| 172 |
+
h5py==3.14.0
|
| 173 |
+
thinc==8.3.6
|
| 174 |
+
markdown-it-py==4.0.0
|
| 175 |
+
flash-attn==2.7.3
|
| 176 |
+
prompt_toolkit==3.0.52
|
| 177 |
+
nvidia-nvtx-cu12==12.4.127
|
| 178 |
+
en_core_web_sm==3.8.0
|
| 179 |
+
xxhash==3.5.0
|
| 180 |
+
anyio==4.10.0
|
| 181 |
+
joblib==1.5.1
|
| 182 |
+
pydantic==2.11.7
|
| 183 |
+
opt_einsum==3.4.0
|
| 184 |
+
dotmap==1.3.30
|
| 185 |
+
language_data==1.3.0
|
| 186 |
+
shellingham==1.5.4
|
| 187 |
+
nvidia-cudnn-cu12==9.1.0.70
|
| 188 |
+
typing_extensions==4.14.1
|
| 189 |
+
libclang==18.1.1
|
| 190 |
+
tabulate==0.9.0
|
| 191 |
+
annotated-types==0.7.0
|
| 192 |
+
jaraco.context==5.3.0
|
| 193 |
+
autocommand==2.2.2
|
| 194 |
+
more-itertools==10.3.0
|
| 195 |
+
tomli==2.0.1
|
| 196 |
+
jaraco.functools==4.0.1
|
| 197 |
+
zipp==3.19.2
|
| 198 |
+
backports.tarfile==1.2.0
|
| 199 |
+
wheel==0.45.1
|
| 200 |
+
platformdirs==4.2.2
|
| 201 |
+
inflect==7.3.1
|
| 202 |
+
typing_extensions==4.12.2
|
| 203 |
+
jaraco.text==3.12.1
|
| 204 |
+
typeguard==4.3.0
|
| 205 |
+
importlib_metadata==8.0.0
|
| 206 |
+
packaging==24.2
|
| 207 |
+
jaraco.collections==5.1.0
|
tb/20250909-0619/wandb/run-20250909_061919-top_transformer-top.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509090614/logs/debug-internal.log
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"time":"2025-09-09T06:19:20.029854482Z","level":"INFO","msg":"stream: starting","core version":"0.21.0"}
|
| 2 |
+
{"time":"2025-09-09T06:19:20.338868384Z","level":"INFO","msg":"stream: created new stream","id":"top_transformer-top.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509090614"}
|
| 3 |
+
{"time":"2025-09-09T06:19:20.338942945Z","level":"INFO","msg":"stream: started","id":"top_transformer-top.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509090614"}
|
| 4 |
+
{"time":"2025-09-09T06:19:20.338955936Z","level":"INFO","msg":"handler: started","stream_id":"top_transformer-top.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509090614"}
|
| 5 |
+
{"time":"2025-09-09T06:19:20.33900181Z","level":"INFO","msg":"writer: Do: started","stream_id":"top_transformer-top.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509090614"}
|
| 6 |
+
{"time":"2025-09-09T06:19:20.339014387Z","level":"INFO","msg":"sender: started","stream_id":"top_transformer-top.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509090614"}
|
| 7 |
+
{"time":"2025-09-09T16:55:51.461783187Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
|
| 8 |
+
{"time":"2025-09-09T17:52:23.968650788Z","level":"INFO","msg":"api: retrying HTTP error","status":502,"url":"https://api.wandb.ai/files/zaydzuhri/fla/top_transformer-top.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509090614/file_stream","body":"\n<html><head>\n<meta http-equiv=\"content-type\" content=\"text/html;charset=utf-8\">\n<title>502 Server Error</title>\n</head>\n<body text=#000000 bgcolor=#ffffff>\n<h1>Error: Server Error</h1>\n<h2>The server encountered a temporary error and could not complete your request.<p>Please try again in 30 seconds.</h2>\n<h2></h2>\n</body></html>\n"}
|
| 9 |
+
{"time":"2025-09-09T22:51:18.011409168Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/top_transformer-top.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509090614/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
|
| 10 |
+
{"time":"2025-09-09T22:58:20.165767227Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/top_transformer-top.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509090614/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
|
tb/20250909-0619/wandb/run-20250909_061919-top_transformer-top.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509090614/logs/debug.log
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
2025-09-09 06:19:19,818 INFO MainThread:795439 [wandb_setup.py:_flush():80] Current SDK version is 0.21.0
|
| 2 |
+
2025-09-09 06:19:19,818 INFO MainThread:795439 [wandb_setup.py:_flush():80] Configure stats pid to 795439
|
| 3 |
+
2025-09-09 06:19:19,818 INFO MainThread:795439 [wandb_setup.py:_flush():80] Loading settings from /home/cvm/.config/wandb/settings
|
| 4 |
+
2025-09-09 06:19:19,818 INFO MainThread:795439 [wandb_setup.py:_flush():80] Loading settings from /home/cvm/flame/wandb/settings
|
| 5 |
+
2025-09-09 06:19:19,818 INFO MainThread:795439 [wandb_setup.py:_flush():80] Loading settings from environment variables
|
| 6 |
+
2025-09-09 06:19:19,818 INFO MainThread:795439 [wandb_init.py:setup_run_log_directory():703] Logging user logs to exp/top.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine/tb/20250909-0619/wandb/run-20250909_061919-top_transformer-top.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509090614/logs/debug.log
|
| 7 |
+
2025-09-09 06:19:19,818 INFO MainThread:795439 [wandb_init.py:setup_run_log_directory():704] Logging internal logs to exp/top.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine/tb/20250909-0619/wandb/run-20250909_061919-top_transformer-top.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202509090614/logs/debug-internal.log
|
| 8 |
+
2025-09-09 06:19:19,818 INFO MainThread:795439 [wandb_init.py:init():830] calling init triggers
|
| 9 |
+
2025-09-09 06:19:19,818 INFO MainThread:795439 [wandb_init.py:init():835] wandb.init called with sweep_config: {}
|
| 10 |
+
config: {'_wandb': {}}
|
| 11 |
+
2025-09-09 06:19:19,818 INFO MainThread:795439 [wandb_init.py:init():871] starting backend
|
| 12 |
+
2025-09-09 06:19:20,025 INFO MainThread:795439 [wandb_init.py:init():874] sending inform_init request
|
| 13 |
+
2025-09-09 06:19:20,027 INFO MainThread:795439 [wandb_init.py:init():882] backend started and connected
|
| 14 |
+
2025-09-09 06:19:20,033 INFO MainThread:795439 [wandb_init.py:init():953] updated telemetry
|
| 15 |
+
2025-09-09 06:19:20,039 INFO MainThread:795439 [wandb_init.py:init():977] communicating run to backend with 90.0 second timeout
|
| 16 |
+
2025-09-09 06:19:20,682 INFO MainThread:795439 [wandb_init.py:init():1029] starting run threads in backend
|
| 17 |
+
2025-09-09 06:19:20,815 INFO MainThread:795439 [wandb_run.py:_console_start():2458] atexit reg
|
| 18 |
+
2025-09-09 06:19:20,815 INFO MainThread:795439 [wandb_run.py:_redirect():2306] redirect: wrap_raw
|
| 19 |
+
2025-09-09 06:19:20,815 INFO MainThread:795439 [wandb_run.py:_redirect():2375] Wrapping output streams.
|
| 20 |
+
2025-09-09 06:19:20,815 INFO MainThread:795439 [wandb_run.py:_redirect():2398] Redirects installed.
|
| 21 |
+
2025-09-09 06:19:20,817 INFO MainThread:795439 [wandb_init.py:init():1075] run started, returning control to user process
|
torchtitan/components/__pycache__/dataloader.cpython-312.pyc
ADDED
|
Binary file (3.79 kB). View file
|
|
|
torchtitan/components/__pycache__/lr_scheduler.cpython-312.pyc
ADDED
|
Binary file (7.71 kB). View file
|
|
|
torchtitan/components/__pycache__/metrics.cpython-312.pyc
ADDED
|
Binary file (19.6 kB). View file
|
|
|
torchtitan/components/__pycache__/tokenizer.cpython-312.pyc
ADDED
|
Binary file (1.09 kB). View file
|
|
|
torchtitan/components/metrics.py
ADDED
|
@@ -0,0 +1,435 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import time
|
| 9 |
+
from collections import namedtuple
|
| 10 |
+
from datetime import datetime
|
| 11 |
+
from typing import Any
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 15 |
+
from torchtitan.components.lr_scheduler import LRSchedulersContainer
|
| 16 |
+
from torchtitan.components.optimizer import OptimizersContainer
|
| 17 |
+
from torchtitan.config_manager import JobConfig
|
| 18 |
+
from torchtitan.distributed import ParallelDims
|
| 19 |
+
from torchtitan.tools import utils
|
| 20 |
+
from torchtitan.tools.logging import logger
|
| 21 |
+
from torchtitan.tools.utils import Color, device_module, device_type
|
| 22 |
+
|
| 23 |
+
# named tuple for passing device memory stats for logging
|
| 24 |
+
DeviceMemStats = namedtuple(
|
| 25 |
+
"DeviceMemStats",
|
| 26 |
+
[
|
| 27 |
+
"max_active_gib",
|
| 28 |
+
"max_active_pct",
|
| 29 |
+
"max_reserved_gib",
|
| 30 |
+
"max_reserved_pct",
|
| 31 |
+
"num_alloc_retries",
|
| 32 |
+
"num_ooms",
|
| 33 |
+
],
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class DeviceMemoryMonitor:
|
| 38 |
+
def __init__(self, device: str = f"{device_type}:0"):
|
| 39 |
+
self.device = torch.device(device) # device object
|
| 40 |
+
self.device_name = device_module.get_device_name(self.device)
|
| 41 |
+
self.device_index = device_module.current_device()
|
| 42 |
+
self.device_capacity = device_module.get_device_properties(
|
| 43 |
+
self.device
|
| 44 |
+
).total_memory
|
| 45 |
+
self.device_capacity_gib = self._to_gib(self.device_capacity)
|
| 46 |
+
|
| 47 |
+
device_module.reset_peak_memory_stats()
|
| 48 |
+
device_module.empty_cache()
|
| 49 |
+
|
| 50 |
+
def _to_gib(self, memory_in_bytes):
|
| 51 |
+
# NOTE: GiB (gibibyte) is 1024, vs GB is 1000
|
| 52 |
+
_gib_in_bytes = 1024 * 1024 * 1024
|
| 53 |
+
memory_in_gib = memory_in_bytes / _gib_in_bytes
|
| 54 |
+
return memory_in_gib
|
| 55 |
+
|
| 56 |
+
def _to_pct(self, memory):
|
| 57 |
+
return 100 * memory / self.device_capacity
|
| 58 |
+
|
| 59 |
+
def get_peak_stats(self):
|
| 60 |
+
device_info = device_module.memory_stats(self.device)
|
| 61 |
+
|
| 62 |
+
max_active = device_info.get("active_bytes.all.peak", -1)
|
| 63 |
+
max_active_gib = self._to_gib(max_active)
|
| 64 |
+
max_active_pct = self._to_pct(max_active)
|
| 65 |
+
|
| 66 |
+
max_reserved = device_info.get("reserved_bytes.all.peak", -1)
|
| 67 |
+
max_reserved_gib = self._to_gib(max_reserved)
|
| 68 |
+
max_reserved_pct = self._to_pct(max_reserved)
|
| 69 |
+
|
| 70 |
+
num_retries = device_info.get("num_alloc_retries", -1)
|
| 71 |
+
num_ooms = device_info.get("num_ooms", -1)
|
| 72 |
+
|
| 73 |
+
if num_retries > 0:
|
| 74 |
+
logger.warning(
|
| 75 |
+
f"{num_retries} {device_type.upper()} memory allocation retries."
|
| 76 |
+
)
|
| 77 |
+
if num_ooms > 0:
|
| 78 |
+
logger.warning(f"{num_ooms} {device_type.upper()} OOM errors thrown.")
|
| 79 |
+
|
| 80 |
+
return DeviceMemStats(
|
| 81 |
+
max_active_gib,
|
| 82 |
+
max_active_pct,
|
| 83 |
+
max_reserved_gib,
|
| 84 |
+
max_reserved_pct,
|
| 85 |
+
num_retries,
|
| 86 |
+
num_ooms,
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
def reset_peak_stats(self):
|
| 90 |
+
device_module.reset_peak_memory_stats()
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def build_device_memory_monitor():
|
| 94 |
+
device_memory_monitor = DeviceMemoryMonitor(device_type)
|
| 95 |
+
logger.info(
|
| 96 |
+
f"{device_type.upper()} capacity: {device_memory_monitor.device_name} "
|
| 97 |
+
f"with {device_memory_monitor.device_capacity_gib:.2f}GiB memory"
|
| 98 |
+
)
|
| 99 |
+
return device_memory_monitor
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class BaseLogger:
|
| 103 |
+
"""Logger that does nothing, used when logging is disabled."""
|
| 104 |
+
|
| 105 |
+
def log(self, metrics: dict[str, Any], step: int) -> None:
|
| 106 |
+
pass
|
| 107 |
+
|
| 108 |
+
def close(self) -> None:
|
| 109 |
+
pass
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class TensorBoardLogger(BaseLogger):
|
| 113 |
+
"""Logger implementation for TensorBoard."""
|
| 114 |
+
|
| 115 |
+
def __init__(self, log_dir: str, tag: str | None = None):
|
| 116 |
+
self.tag = tag
|
| 117 |
+
self.writer = SummaryWriter(log_dir, max_queue=1000)
|
| 118 |
+
logger.info(f"TensorBoard logging enabled. Logs will be saved at {log_dir}")
|
| 119 |
+
|
| 120 |
+
def log(self, metrics: dict[str, Any], step: int) -> None:
|
| 121 |
+
for k, v in metrics.items():
|
| 122 |
+
tag = k if self.tag is None else f"{self.tag}/{k}"
|
| 123 |
+
self.writer.add_scalar(tag, v, step)
|
| 124 |
+
|
| 125 |
+
def close(self) -> None:
|
| 126 |
+
self.writer.close()
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class WandBLogger(BaseLogger):
|
| 130 |
+
"""Logger implementation for Weights & Biases."""
|
| 131 |
+
|
| 132 |
+
def __init__(self, log_dir: str, tag: str | None = None):
|
| 133 |
+
# Import wandb here to avoid startup import
|
| 134 |
+
import wandb
|
| 135 |
+
|
| 136 |
+
self.wandb = wandb
|
| 137 |
+
self.tag = tag
|
| 138 |
+
|
| 139 |
+
# Create logging directory
|
| 140 |
+
os.makedirs(log_dir, exist_ok=True)
|
| 141 |
+
|
| 142 |
+
self.wandb.init(
|
| 143 |
+
project=os.getenv("WANDB_PROJECT", "torchtitan"),
|
| 144 |
+
dir=log_dir,
|
| 145 |
+
)
|
| 146 |
+
logger.info("WandB logging enabled")
|
| 147 |
+
|
| 148 |
+
def log(self, metrics: dict[str, Any], step: int) -> None:
|
| 149 |
+
wandb_metrics = {
|
| 150 |
+
(k if self.tag is None else f"{self.tag}/{k}"): v
|
| 151 |
+
for k, v in metrics.items()
|
| 152 |
+
}
|
| 153 |
+
self.wandb.log(wandb_metrics, step=step)
|
| 154 |
+
|
| 155 |
+
def close(self) -> None:
|
| 156 |
+
if self.wandb.run is not None:
|
| 157 |
+
self.wandb.finish()
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def ensure_pp_loss_visible(
|
| 161 |
+
parallel_dims: ParallelDims, job_config: JobConfig, color: Color
|
| 162 |
+
) -> None:
|
| 163 |
+
"""
|
| 164 |
+
Ensures that the loss is visible on the console for pipeline-parallel training.
|
| 165 |
+
|
| 166 |
+
For pipeline-parallel training, the loss is only visible on the last pipeline stage.
|
| 167 |
+
This function checks if the appropriate rank is included in the LOG_RANK environment
|
| 168 |
+
variable and warns if it's not.
|
| 169 |
+
"""
|
| 170 |
+
|
| 171 |
+
# V Block Schedules return loss on rank 0
|
| 172 |
+
if job_config.parallelism.pipeline_parallel_schedule == "ZBVZeroBubble":
|
| 173 |
+
return
|
| 174 |
+
|
| 175 |
+
# Calculate the rank where loss is visible (first rank of the last pipeline stage)
|
| 176 |
+
world_size = parallel_dims.world_size
|
| 177 |
+
pp_size = parallel_dims.pp
|
| 178 |
+
loss_visible_rank = (world_size // pp_size) * (pp_size - 1)
|
| 179 |
+
|
| 180 |
+
# Check if the loss-visible rank is included in LOG_RANK environment variable
|
| 181 |
+
env_logged_ranks = os.environ.get("LOG_RANK", "").split(",")
|
| 182 |
+
if env_logged_ranks == [""]:
|
| 183 |
+
env_logged_ranks = []
|
| 184 |
+
|
| 185 |
+
if str(loss_visible_rank) not in env_logged_ranks:
|
| 186 |
+
logger.warning(
|
| 187 |
+
f"{color.red}Pipeline Parallel loss is not visible. "
|
| 188 |
+
f"Please add {color.yellow}rank {loss_visible_rank}{color.red} "
|
| 189 |
+
f"to LOG_RANK environment variable in run_train.sh.{color.reset}"
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def _get_metrics_rank(
|
| 194 |
+
parallel_dims: ParallelDims,
|
| 195 |
+
job_config: JobConfig,
|
| 196 |
+
) -> int:
|
| 197 |
+
"""
|
| 198 |
+
Determines which rank should log metrics.
|
| 199 |
+
|
| 200 |
+
Returns:
|
| 201 |
+
int: The rank responsible for logging metrics:
|
| 202 |
+
- Rank 0 for non-pipeline-parallel configs
|
| 203 |
+
- Rank 0 for pipeline-parallel 'ZBVZeroBubble' schedule
|
| 204 |
+
- The first rank of the last pipeline stage for other pipeline-parallel schedules
|
| 205 |
+
"""
|
| 206 |
+
# Early return for non-pipeline-parallel configurations
|
| 207 |
+
if not parallel_dims.pp_enabled:
|
| 208 |
+
return 0
|
| 209 |
+
|
| 210 |
+
# V Block Schedules return loss on rank 0
|
| 211 |
+
if job_config.parallelism.pipeline_parallel_schedule == "ZBVZeroBubble":
|
| 212 |
+
return 0
|
| 213 |
+
|
| 214 |
+
# Calculate first rank of the last pipeline stage
|
| 215 |
+
world_size = parallel_dims.world_size
|
| 216 |
+
pp_size = parallel_dims.pp
|
| 217 |
+
return (world_size // pp_size) * (pp_size - 1)
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def _build_metric_logger(
|
| 221 |
+
job_config: JobConfig, parallel_dims: ParallelDims, tag: str | None = None
|
| 222 |
+
) -> BaseLogger:
|
| 223 |
+
"""
|
| 224 |
+
Build an appropriate metric logger based on configuration.
|
| 225 |
+
"""
|
| 226 |
+
metrics_config = job_config.metrics
|
| 227 |
+
|
| 228 |
+
# Log initial config state
|
| 229 |
+
logger.debug(
|
| 230 |
+
f"Building logger with config: wandb={metrics_config.enable_wandb}, "
|
| 231 |
+
f"tensorboard={metrics_config.enable_tensorboard}"
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
# Check if any logging backend is enabled
|
| 235 |
+
has_logging_enabled = (
|
| 236 |
+
metrics_config.enable_tensorboard or metrics_config.enable_wandb
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
# Determine if this rank should log
|
| 240 |
+
should_log = has_logging_enabled
|
| 241 |
+
if (not metrics_config.save_for_all_ranks) and should_log:
|
| 242 |
+
metrics_rank = _get_metrics_rank(parallel_dims, job_config)
|
| 243 |
+
should_log = torch.distributed.get_rank() == metrics_rank
|
| 244 |
+
|
| 245 |
+
logger.debug(
|
| 246 |
+
f"Logging decision: has_logging_enabled={has_logging_enabled}, should_log={should_log}"
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
if not should_log:
|
| 250 |
+
logger.debug("Returning BaseLogger due to should_log=False")
|
| 251 |
+
return BaseLogger()
|
| 252 |
+
|
| 253 |
+
# Setup logging directory
|
| 254 |
+
dump_dir = job_config.job.dump_folder
|
| 255 |
+
base_log_dir = os.path.join(
|
| 256 |
+
dump_dir, metrics_config.save_tb_folder, datetime.now().strftime("%Y%m%d-%H%M")
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
if metrics_config.save_for_all_ranks:
|
| 260 |
+
base_log_dir = os.path.join(
|
| 261 |
+
base_log_dir, f"rank_{torch.distributed.get_rank()}"
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
# Create loggers in priority order
|
| 265 |
+
if metrics_config.enable_wandb:
|
| 266 |
+
logger.debug("Attempting to create WandB logger")
|
| 267 |
+
try:
|
| 268 |
+
return WandBLogger(base_log_dir, tag)
|
| 269 |
+
except Exception as e:
|
| 270 |
+
if "No module named 'wandb'" in str(e):
|
| 271 |
+
logger.error(
|
| 272 |
+
"Failed to create WandB logger: No module named 'wandb'. Please install it using 'pip install wandb'."
|
| 273 |
+
)
|
| 274 |
+
else:
|
| 275 |
+
logger.error(f"Failed to create WandB logger: {e}")
|
| 276 |
+
|
| 277 |
+
if metrics_config.enable_tensorboard:
|
| 278 |
+
logger.debug("Creating TensorBoard logger")
|
| 279 |
+
return TensorBoardLogger(base_log_dir, tag)
|
| 280 |
+
|
| 281 |
+
logger.debug("No loggers enabled, returning BaseLogger")
|
| 282 |
+
return BaseLogger()
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
class MetricsProcessor:
|
| 286 |
+
"""Metrics processor to processes the metrics and log metrics.
|
| 287 |
+
|
| 288 |
+
The current MetricsProcessor log some metrics to STDOUT and some metrics to
|
| 289 |
+
TensorBoard or WandB.
|
| 290 |
+
|
| 291 |
+
Args:
|
| 292 |
+
job_config (JobConfig): Job configuration.
|
| 293 |
+
parallel_dims (ParallelDims): Parallel dimensions.
|
| 294 |
+
tag (Optional[str]): Tag to use for TensorBoard or WandB. Defaults to None.
|
| 295 |
+
"""
|
| 296 |
+
|
| 297 |
+
logger: BaseLogger
|
| 298 |
+
parallel_dims: ParallelDims
|
| 299 |
+
job_config: JobConfig
|
| 300 |
+
device_memory_monitor: DeviceMemoryMonitor
|
| 301 |
+
color: utils.NoColor | utils.Color
|
| 302 |
+
|
| 303 |
+
gpu_peak_flops: int
|
| 304 |
+
ntokens_since_last_log: int
|
| 305 |
+
data_loading_times: list[float]
|
| 306 |
+
time_last_log: float
|
| 307 |
+
|
| 308 |
+
num_flops_per_token: int
|
| 309 |
+
optimizers: OptimizersContainer | None
|
| 310 |
+
lr_schedulers: LRSchedulersContainer | None
|
| 311 |
+
|
| 312 |
+
def __init__(
|
| 313 |
+
self,
|
| 314 |
+
job_config: JobConfig,
|
| 315 |
+
parallel_dims: ParallelDims,
|
| 316 |
+
tag: str | None = None,
|
| 317 |
+
):
|
| 318 |
+
self.logger = _build_metric_logger(job_config, parallel_dims, tag)
|
| 319 |
+
self.parallel_dims = parallel_dims
|
| 320 |
+
self.job_config = job_config
|
| 321 |
+
self.device_memory_monitor = build_device_memory_monitor()
|
| 322 |
+
# used for colorful printing
|
| 323 |
+
self.color = (
|
| 324 |
+
utils.NoColor()
|
| 325 |
+
if job_config.metrics.disable_color_printing
|
| 326 |
+
else utils.Color()
|
| 327 |
+
)
|
| 328 |
+
|
| 329 |
+
self.gpu_peak_flops = utils.get_peak_flops(
|
| 330 |
+
self.device_memory_monitor.device_name
|
| 331 |
+
)
|
| 332 |
+
self.ntokens_since_last_log = 0
|
| 333 |
+
self.data_loading_times = []
|
| 334 |
+
self.time_last_log = time.perf_counter()
|
| 335 |
+
self.device_memory_monitor.reset_peak_stats()
|
| 336 |
+
|
| 337 |
+
# These variables have to be set later as they depend on other components or model.
|
| 338 |
+
self.num_flops_per_token = -1
|
| 339 |
+
self.optimizers = None
|
| 340 |
+
self.lr_schedulers = None
|
| 341 |
+
|
| 342 |
+
def should_log(self, step: int) -> bool:
|
| 343 |
+
return step == 1 or step % self.job_config.metrics.log_freq == 0
|
| 344 |
+
|
| 345 |
+
def log(
|
| 346 |
+
self,
|
| 347 |
+
step: int,
|
| 348 |
+
global_avg_loss: float,
|
| 349 |
+
global_max_loss: float,
|
| 350 |
+
extra_metrics: dict[str, Any] | None = None,
|
| 351 |
+
):
|
| 352 |
+
assert self.num_flops_per_token > 0, "num_flops_per_token must be set"
|
| 353 |
+
|
| 354 |
+
time_delta = time.perf_counter() - self.time_last_log
|
| 355 |
+
|
| 356 |
+
# tokens per second per device, abbreviated as tps
|
| 357 |
+
tps = self.ntokens_since_last_log / (
|
| 358 |
+
time_delta * self.parallel_dims.non_data_parallel_size
|
| 359 |
+
)
|
| 360 |
+
# model FLOPS utilization
|
| 361 |
+
# For its definition and calculation, please refer to the PaLM paper:
|
| 362 |
+
# https://arxiv.org/abs/2204.02311
|
| 363 |
+
mfu = 100 * self.num_flops_per_token * tps / self.gpu_peak_flops
|
| 364 |
+
tflops = self.num_flops_per_token * tps / 1e12
|
| 365 |
+
|
| 366 |
+
time_end_to_end = time_delta / self.job_config.metrics.log_freq
|
| 367 |
+
time_data_loading = sum(self.data_loading_times) / len(self.data_loading_times)
|
| 368 |
+
time_data_loading_pct = 100 * sum(self.data_loading_times) / time_delta
|
| 369 |
+
|
| 370 |
+
device_mem_stats = self.device_memory_monitor.get_peak_stats()
|
| 371 |
+
|
| 372 |
+
metrics = {
|
| 373 |
+
"loss_metrics/global_avg_loss": global_avg_loss,
|
| 374 |
+
"loss_metrics/global_max_loss": global_max_loss,
|
| 375 |
+
"throughput(tps)": tps,
|
| 376 |
+
"tflops": tflops,
|
| 377 |
+
"mfu(%)": mfu,
|
| 378 |
+
"time_metrics/end_to_end(s)": time_end_to_end,
|
| 379 |
+
"time_metrics/data_loading(s)": time_data_loading,
|
| 380 |
+
"time_metrics/data_loading(%)": time_data_loading_pct,
|
| 381 |
+
"memory/max_active(GiB)": device_mem_stats.max_active_gib,
|
| 382 |
+
"memory/max_active(%)": device_mem_stats.max_active_pct,
|
| 383 |
+
"memory/max_reserved(GiB)": device_mem_stats.max_reserved_gib,
|
| 384 |
+
"memory/max_reserved(%)": device_mem_stats.max_reserved_pct,
|
| 385 |
+
"memory/num_alloc_retries": device_mem_stats.num_alloc_retries,
|
| 386 |
+
"memory/num_ooms": device_mem_stats.num_ooms,
|
| 387 |
+
}
|
| 388 |
+
|
| 389 |
+
if extra_metrics:
|
| 390 |
+
metrics.update(extra_metrics)
|
| 391 |
+
|
| 392 |
+
self.logger.log(metrics, step)
|
| 393 |
+
|
| 394 |
+
color = self.color
|
| 395 |
+
construct_string = str(
|
| 396 |
+
f"{color.red}step: {step:2} "
|
| 397 |
+
f"{color.green}loss: {global_avg_loss:7.4f} "
|
| 398 |
+
f"{color.yellow}memory: {device_mem_stats.max_reserved_gib:5.2f}GiB"
|
| 399 |
+
f"({device_mem_stats.max_reserved_pct:.2f}%) "
|
| 400 |
+
f"{color.blue}tps: {round(tps):,} "
|
| 401 |
+
f"{color.cyan}tflops: {tflops:,.2f} "
|
| 402 |
+
f"{color.magenta}mfu: {mfu:.2f}%{color.reset}"
|
| 403 |
+
)
|
| 404 |
+
|
| 405 |
+
if extra_metrics:
|
| 406 |
+
for k, v in extra_metrics.items():
|
| 407 |
+
if "loss" in k:
|
| 408 |
+
construct_string += f" {color.white}{k.lstrip('loss_metrics/')}: {v:7.4f}"
|
| 409 |
+
logger.info(
|
| 410 |
+
construct_string
|
| 411 |
+
)
|
| 412 |
+
|
| 413 |
+
self.ntokens_since_last_log = 0
|
| 414 |
+
self.data_loading_times.clear()
|
| 415 |
+
self.time_last_log = time.perf_counter()
|
| 416 |
+
self.device_memory_monitor.reset_peak_stats()
|
| 417 |
+
|
| 418 |
+
def close(self):
|
| 419 |
+
self.logger.close()
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
def build_metrics_processor(
|
| 423 |
+
job_config: JobConfig, parallel_dims: ParallelDims, tag: str | None = None
|
| 424 |
+
) -> MetricsProcessor:
|
| 425 |
+
"""Create a metrics processor.
|
| 426 |
+
|
| 427 |
+
Args:
|
| 428 |
+
job_config (JobConfig): Job configuration.
|
| 429 |
+
parallel_dims (ParallelDims): Parallel dimensions.
|
| 430 |
+
tag (Optional[str]): Tag to use for TensorBoard or WandB. Defaults to None.
|
| 431 |
+
|
| 432 |
+
Returns:
|
| 433 |
+
MetricsProcessor: A metrics processor.
|
| 434 |
+
"""
|
| 435 |
+
return MetricsProcessor(job_config, parallel_dims, tag)
|
torchtitan/experiments/deepseek_v3/LICENSE-CODE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2023 DeepSeek
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
torchtitan/experiments/deepseek_v3/README.md
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Running DeepSeek in Titan (experimental)
|
| 2 |
+
|
| 3 |
+
This folder contains a DeepSeek model supporting v2 and v3 as well as kernels
|
| 4 |
+
and scripts needed to run it.
|
| 5 |
+
|
| 6 |
+
## Inference
|
| 7 |
+
|
| 8 |
+
### Prerequisites:
|
| 9 |
+
|
| 10 |
+
You will need to download a DeepSeek model's weights if you want to run a
|
| 11 |
+
pre-trained checkpoint. We provided a script to download the weights from
|
| 12 |
+
HuggingFace Model Hub:
|
| 13 |
+
```bash
|
| 14 |
+
python download.py [vX]
|
| 15 |
+
```
|
| 16 |
+
where `vX` can be v2 or v3, both are supported. You may be required to create a
|
| 17 |
+
HuggingFace account and log in first.
|
| 18 |
+
|
| 19 |
+
### Running inference:
|
| 20 |
+
|
| 21 |
+
The inference script is in `generate.py`. You can run it with the following
|
| 22 |
+
command:
|
| 23 |
+
```bash
|
| 24 |
+
torchrun --standalone --nproc-per-node 4 generate.py
|
| 25 |
+
```
|
| 26 |
+
This will run inference on the `DeepSeek-V2-Lite-Chat` model using 4 GPUs by
|
| 27 |
+
default.
|
| 28 |
+
|
| 29 |
+
Alternatively, you can run inference by using `bash inference.sh`, optionally
|
| 30 |
+
followed by your prompt.
|
| 31 |
+
|
| 32 |
+
## Training
|
| 33 |
+
|
| 34 |
+
The training script is in `train.py`. You can run it with the following
|
| 35 |
+
```bash
|
| 36 |
+
torchrun --standalone --nproc-per-node 8 train.py
|
| 37 |
+
```
|
| 38 |
+
|
| 39 |
+
This will run training on the `DeepSeek-V2-Lite-Chat` model using 8 GPUs by
|
| 40 |
+
default, with pipeline parallel, expert parallel, and data parallel enabled.
|
torchtitan/experiments/deepseek_v3/checkpoint.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import json
|
| 8 |
+
import logging
|
| 9 |
+
import os
|
| 10 |
+
from typing import Dict, Optional, Set, Tuple
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from safetensors import safe_open
|
| 14 |
+
|
| 15 |
+
from transformers.utils import cached_file
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)

# Name of the HF safetensors index file that maps parameter names to shard files.
_DEFAULT_SAFETENSOR_FILE_NAME = "model.safetensors.index.json"


def read_weights_from_json(file_path: str) -> Optional[Dict[str, str]]:
    """Read the `weight_map` from a safetensors index JSON file.

    Args:
        file_path (str): Path to the `model.safetensors.index.json` file.

    Returns:
        Optional[Dict[str, str]]: Mapping of parameter name -> shard file
        name, or None if the file is missing/unparseable or has no
        `weight_map` dictionary.
    """
    try:
        with open(file_path, "r") as file:
            data = json.load(file)

        if "weight_map" in data and isinstance(data["weight_map"], dict):
            return data["weight_map"]
        else:
            logger.info("No 'weight_map' dictionary found in the JSON file.")
            return None
    # Fix: the original `except (json.JSONDecodeError, Exception)` was
    # redundant — Exception already subsumes JSONDecodeError. Behavior is
    # unchanged: any read/parse failure is logged and None is returned.
    except Exception as e:
        logger.info(f"An error occurred while reading the JSON file: {str(e)}")
        return None
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def get_hf_weight_map_and_path(
    model_id: str,
) -> Tuple[Dict[str, str], str]:
    """Get the weight map for a given HF model id and also the cache path for loading the weights.

    Args:
        model_id (str): Hugging Face model id (e.g. "deepseek-ai/DeepSeek-V2").

    Returns:
        Tuple[Dict[str, str], str]: (parameter-name -> shard-file map,
        directory containing the cached shard files).

    Raises:
        Exception: Re-raises whatever `cached_file` raised when the model's
        index file is not present in the local HF cache.
    """
    try:
        index_file = cached_file(model_id, _DEFAULT_SAFETENSOR_FILE_NAME)
    except Exception:
        # Fix: the hint previously had an unterminated backtick.
        logger.error(
            f"Model `{model_id}` not found in HF cache. "
            f"You can download the model using `python download.py {model_id}`"
        )
        # Bare `raise` preserves the original traceback.
        raise

    weight_map = read_weights_from_json(index_file)
    weight_path = os.path.dirname(index_file)
    logger.info(f"Loading weights from: {weight_path}")
    return weight_map, weight_path
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def get_needed_files(
    state_dict: Dict[str, torch.Tensor], weight_map: Dict[str, str]
) -> Set[str]:
    """Collect the set of shard files that hold parameters of `state_dict`.

    A parameter absent from the weight map is tolerated unless its name ends
    with "weight", in which case a ValueError is raised.
    """
    required: Set[str] = set()
    for name in state_dict:
        shard = weight_map.get(name)
        if shard:
            required.add(shard)
        elif name.endswith("weight"):
            raise ValueError(
                f"Parameter {name} not found in weight map, please check..."
            )
    logger.info(f"Needed files: {required}")
    return required
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def load_safetensor_file(
    full_path: str, device: torch.device
) -> Dict[str, torch.Tensor]:
    """Read every tensor from one safetensors file onto `device`.

    Returns a dict of tensor name -> tensor.
    """
    loaded: Dict[str, torch.Tensor] = {}
    with safe_open(full_path, framework="pt", device=device) as reader:
        for name in reader.keys():
            loaded[name] = reader.get_tensor(name)
    logger.info(f"Loaded {len(loaded)} tensors from {full_path}")
    return loaded
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def load_safetensor_weights(
    model: torch.nn.Module,
    weight_map: Dict[str, str],
    file_location: str,
    device: torch.device,
):
    """
    Load safetensor weights into a `nn.Module`.

    Args:
        model (Module): The PyTorch module to load weights into. It may be a
            model chunk or a full model.
        weight_map (Dict[str, str]): Mapping of model parameters to file names.
        file_location (str): Directory containing the weight files.
        device (torch.device): The device to load tensors onto.

    Raises:
        ValueError: If a checkpoint tensor's shape does not match the model's.
        RuntimeError: If some model parameters were never found in any file.
    """
    model_state_dict = model.state_dict()
    needed_files = get_needed_files(model_state_dict, weight_map)
    updated_states: Set[str] = set()

    for file in needed_files:
        full_path = os.path.join(file_location, file)
        try:
            # Tensors are staged on CPU first and moved to `device` per key.
            checkpoint = load_safetensor_file(full_path, "cpu")
        except FileNotFoundError:
            logger.error(f"File not found: {full_path}")
            # Fix: skip this file. Previously execution fell through with
            # `checkpoint` unbound (NameError) or stale from a prior loop.
            continue
        except Exception as e:
            logger.error(f"Error during checkpoint processing of {full_path}: {str(e)}")
            continue

        matched_keys = set(checkpoint.keys()) & set(model_state_dict.keys())
        for key in matched_keys:
            # Check shape
            if model_state_dict[key].shape != checkpoint[key].shape:
                raise ValueError(
                    f"Shape mismatch for {key}: "
                    f"model needs {model_state_dict[key].shape}, but "
                    f"checkpoint has {checkpoint[key].shape}"
                )
            model_state_dict[key] = checkpoint[key].to(device)

        updated_states.update(matched_keys)

    # Every parameter of the model must have been filled by some shard.
    missing_keys = set(model_state_dict.keys()) - updated_states
    if missing_keys:
        raise RuntimeError(
            f"Partially updated state dict. Missing parameters: {missing_keys}"
        )

    # assign=True swaps the loaded tensors in directly instead of copying.
    model.load_state_dict(model_state_dict, strict=False, assign=True)
    logger.info(f"Successfully loaded {len(updated_states)} weights into model")
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def load_weights_from_hf(
    model: torch.nn.Module,
    distribution: str,
    device: torch.device,
):
    """
    Load the weights from Hugging Face format (index file + multiple safetensor
    files), and fill into `model`. Model config is needed b/c we permute
    wq and wk weights based on attn heads.
    """
    # Resolve the shard index and cache directory, then stream the shards in.
    weight_map, weight_path = get_hf_weight_map_and_path(distribution)
    load_safetensor_weights(model, weight_map, weight_path, device)
|
torchtitan/experiments/deepseek_v3/download.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# Usage:
|
| 8 |
+
# Downloads a given model to the HF Cache. Pass in a listed option ala "v3" or your own custom model path.
|
| 9 |
+
# python download.py {model_id} [custom_model_path]
|
| 10 |
+
# Examples:
|
| 11 |
+
# python download.py v2 # Use predefined model: deepseek-ai/DeepSeek-V2
|
| 12 |
+
# python download.py custom "deepseek-ai/new-model" # Download a custom model path
|
| 13 |
+
|
| 14 |
+
# Available models:
|
| 15 |
+
# "v2-lite-chat": "deepseek-ai/DeepSeek-V2-Lite-Chat",
|
| 16 |
+
# "v2-lite": "deepseek-ai/DeepSeek-V2-Lite",
|
| 17 |
+
# "v2": "deepseek-ai/DeepSeek-V2",
|
| 18 |
+
# "v3": "deepseek-ai/deepseek-v3",
|
| 19 |
+
# "v3-0324": "deepseek-ai/DeepSeek-V3-0324",
|
| 20 |
+
# "custom": None, # Placeholder for custom models
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
import sys
|
| 24 |
+
|
| 25 |
+
from transformers import AutoModelForCausalLM
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# Short aliases accepted on the command line -> Hugging Face model ids.
# "custom" is a sentinel: the actual model path is taken from argv[2].
MODELS = {
    "v2-lite-chat": "deepseek-ai/DeepSeek-V2-Lite-Chat",
    "v2-lite": "deepseek-ai/DeepSeek-V2-Lite",
    "v2": "deepseek-ai/DeepSeek-V2",
    "v3": "deepseek-ai/deepseek-v3",
    "v3-0324": "deepseek-ai/DeepSeek-V3-0324",
    "custom": None,  # For custom (any) models
}
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def print_usage():
    """Print the downloader's usage help, then exit with status 1."""
    lines = [
        "Usage:",
        "  python download.py [model_version]",
        "  python download.py custom [custom_model_path]",
        "\nAvailable predefined models:",
    ]
    # List every alias except the "custom" placeholder.
    lines.extend(
        f"  {key}: {model}" for key, model in MODELS.items() if key != "custom"
    )
    lines.extend(
        [
            "\nFor custom models:",
            "  custom: Specify your own model path",
            '  Example: python download.py custom "organization/model-name"',
        ]
    )
    # One print per original line is equivalent to a single joined print.
    print("\n".join(lines))
    sys.exit(1)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# Process command line arguments
# NOTE: this runs at import time — the script has no main() guard.
if len(sys.argv) < 2 or sys.argv[1] not in MODELS:
    print_usage()

if sys.argv[1] == "custom":
    # "custom" requires exactly one extra argument: the HF model path.
    if len(sys.argv) != 3:
        print("Error: Custom model requires a model path")
        print_usage()
    model_id = sys.argv[2]
    print(f"Using custom model: {model_id}")
else:
    model_id = MODELS[sys.argv[1]]
    print(f"Downloading model: {model_id}")

# Instantiating the model populates the HF cache with its weights.
# NOTE(review): trust_remote_code=True executes the repo's Python — only
# safe because these are known DeepSeek repositories.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    trust_remote_code=True,
)
|
torchtitan/experiments/deepseek_v3/model.py
ADDED
|
@@ -0,0 +1,1325 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# This code is based on model definition of `deepseek-ai/DeepSeek-V3-Base` on
|
| 8 |
+
# Hugging Face Model Hub. Url:
|
| 9 |
+
# https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/blob/main/modeling_deepseek.py
|
| 10 |
+
# https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/resolve/main/configuration_deepseek.py
|
| 11 |
+
#
|
| 12 |
+
# It has been modified from its original forms to accommodate naming convention
|
| 13 |
+
# and usage patterns of the TorchTitan project.
|
| 14 |
+
|
| 15 |
+
# Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
|
| 16 |
+
#
|
| 17 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 18 |
+
# you may not use this file except in compliance with the License.
|
| 19 |
+
# You may obtain a copy of the License at
|
| 20 |
+
#
|
| 21 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 22 |
+
#
|
| 23 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 24 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 25 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 26 |
+
# See the License for the specific language governing permissions and
|
| 27 |
+
# limitations under the License.
|
| 28 |
+
""" PyTorch DeepSeek model."""
|
| 29 |
+
import math
|
| 30 |
+
from typing import Optional, Tuple
|
| 31 |
+
|
| 32 |
+
import torch
|
| 33 |
+
import torch.distributed as dist
|
| 34 |
+
|
| 35 |
+
import torch.distributed._symmetric_memory as symm_mem
|
| 36 |
+
import torch.nn.functional as F
|
| 37 |
+
import torch.utils.checkpoint
|
| 38 |
+
|
| 39 |
+
from attn_mask_utils import _prepare_4d_causal_attention_mask
|
| 40 |
+
from indices import generate_permute_indices
|
| 41 |
+
from model_config import ModelArgs
|
| 42 |
+
from symm_mem_recipes import OnDeviceAllToAllV
|
| 43 |
+
from torch import nn
|
| 44 |
+
from torch.distributed._functional_collectives import all_to_all_single_autograd
|
| 45 |
+
|
| 46 |
+
from torchtitan.experiments.kernels.triton_mg_group_gemm.torchao_pr import (
|
| 47 |
+
ALIGN_SIZE_M,
|
| 48 |
+
grouped_gemm_forward,
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
# Get model parallel subgroup by name:
|
| 52 |
+
# e.g. "pp", "ep", None
|
| 53 |
+
def get_group(dim_name: Optional[str] = None) -> dist.ProcessGroup:
    """Return the process group for `dim_name` (e.g. "pp", "ep") of the
    current device mesh; None selects the mesh's default group.

    NOTE(review): reaches into the private `_mesh_resources` API — may break
    across torch versions; confirm against the installed release.
    """
    glob = torch.distributed.device_mesh._mesh_resources.get_current_mesh()
    return glob.get_group(dim_name)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class RMSNorm(nn.Module):
    """Root-mean-square layer norm: scales by 1/RMS of the last dim, then by a
    learned per-channel weight. No mean subtraction and no bias."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        # Learnable per-channel gain, initialized to 1.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        # Normalize in float32 for numerical stability, cast back afterwards.
        x = hidden_states.to(torch.float32)
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        normed = x * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normed.to(orig_dtype)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class RotaryEmbedding(nn.Module):
    """Precomputes and caches cos/sin tables for rotary position embeddings.

    Subclasses override `_set_cos_sin_cache` to implement scaled variants
    (linear, dynamic-NTK, YaRN).
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # Standard RoPE inverse frequencies: base^(-2i/dim) for even i.
        inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        # Rebuild cos/sin caches for sequence lengths up to `seq_len`.
        self.max_seq_len_cached = seq_len
        t = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )

        # Outer product: (seq_len, dim/2) angles per position/frequency.
        freqs = torch.outer(t, self.inv_freq.to(t.device))
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # Grow the cache lazily if a longer sequence arrives.
        if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        # Returns (cos, sin), each shaped (seq_len, dim), in x's dtype.
        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class LinearScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
    ):
        # Must be set before super().__init__ — the base constructor calls
        # _set_cos_sin_cache, which reads scaling_factor.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )
        # Linear scaling: compress position indices by the scaling factor so
        # longer contexts map into the trained positional range.
        t = t / self.scaling_factor

        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Deepseek
class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
    ):
        # Must be set before super().__init__ — the base constructor calls
        # _set_cos_sin_cache, which reads scaling_factor.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len

        # Beyond the trained context, stretch the RoPE base (NTK-aware
        # interpolation) instead of compressing positions.
        if seq_len > self.max_position_embeddings:
            base = self.base * (
                (self.scaling_factor * seq_len / self.max_position_embeddings)
                - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))
            inv_freq = 1.0 / (
                base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)

        t = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )

        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
# Inverse dim formula to find dim based on number of rotations
|
| 182 |
+
def yarn_find_correction_dim(
|
| 183 |
+
num_rotations, dim, base=10000, max_position_embeddings=2048
|
| 184 |
+
):
|
| 185 |
+
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
|
| 186 |
+
2 * math.log(base)
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
# Find dim range bounds based on rotations
|
| 191 |
+
def yarn_find_correction_range(
|
| 192 |
+
low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
|
| 193 |
+
):
|
| 194 |
+
low = math.floor(
|
| 195 |
+
yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
|
| 196 |
+
)
|
| 197 |
+
high = math.ceil(
|
| 198 |
+
yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
|
| 199 |
+
)
|
| 200 |
+
return max(low, 0), min(high, dim - 1) # Clamp values just in case
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def yarn_get_mscale(scale=1, mscale=1):
    """YaRN attention-magnitude correction: 1.0 when not extrapolating
    (scale <= 1), otherwise a logarithmic boost controlled by `mscale`."""
    return 1.0 if scale <= 1 else 0.1 * mscale * math.log(scale) + 1.0
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def yarn_linear_ramp_mask(min, max, dim):
    """Return a float32 tensor of length `dim` ramping linearly from 0 at
    index `min` to 1 at index `max`, clamped outside that span.

    NOTE: parameter names shadow the builtins; kept for interface stability.
    """
    lo, hi = min, max
    if lo == hi:
        hi += 0.001  # Prevent singularity (division by zero)

    ramp = (torch.arange(dim, dtype=torch.float32) - lo) / (hi - lo)
    return ramp.clamp(0, 1)
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
class YarnRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding with YaRN scaling: interpolated frequencies blended
    per-dimension via a linear ramp, plus an attention-magnitude correction."""

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
        original_max_position_embeddings=4096,
        beta_fast=32,
        beta_slow=1,
        mscale=1,
        mscale_all_dim=0,
    ):
        # All YaRN hyperparameters must be set before super().__init__,
        # which immediately calls _set_cos_sin_cache below.
        self.scaling_factor = scaling_factor
        self.original_max_position_embeddings = original_max_position_embeddings
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        self.mscale = mscale
        self.mscale_all_dim = mscale_all_dim
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        dim = self.dim

        # Extrapolated (original) and interpolated (scaled) inverse frequencies.
        freq_extra = 1.0 / (
            self.base
            ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
        )
        freq_inter = 1.0 / (
            self.scaling_factor
            * self.base
            ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
        )

        # Dimension band over which to blend between the two frequency sets.
        low, high = yarn_find_correction_range(
            self.beta_fast,
            self.beta_slow,
            dim,
            self.base,
            self.original_max_position_embeddings,
        )
        # mask == 1 keeps the original frequency; 0 uses the interpolated one.
        inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(
            device=device, dtype=torch.float32
        )
        inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        t = torch.arange(seq_len, device=device, dtype=torch.float32)

        freqs = torch.outer(t, inv_freq)

        # Attention-temperature correction ratio from the YaRN paper.
        _mscale = float(
            yarn_get_mscale(self.scaling_factor, self.mscale)
            / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
        )

        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer(
            "cos_cached", (emb.cos() * _mscale).to(dtype), persistent=False
        )
        self.register_buffer(
            "sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False
        )
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    # (a, b) -> (-b, a) along the last dimension.
    return torch.cat((-second, first), dim=-1)
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Apply Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
            used to pass offsetted position ids when working with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The dimension along which cos[position_ids] and sin[position_ids] are unsqueezed so they broadcast
            against q and k. With q/k shaped [batch_size, heads, seq_len, head_dim] use 1; with
            [batch_size, seq_len, heads, head_dim] use 2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """

    def _unshuffle_pairs(t):
        # Reorder pairwise-interleaved rotary dims into two contiguous halves,
        # matching the (cos, sin) layout produced by the embedding cache.
        b, h, s, d = t.shape
        return t.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)

    q = _unshuffle_pairs(q)
    k = _unshuffle_pairs(k)

    rotated_q = q * cos + rotate_half(q) * sin
    rotated_k = k * cos + rotate_half(k) * sin
    return rotated_q, rotated_k
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
class MLP(nn.Module):
    """Gated feed-forward block: ``down_proj(silu(gate_proj(x)) * up_proj(x))``."""

    # Parameter-free activation shared by all instances; MoE's grouped-GEMM
    # path also reads this as `MLP.act_fn`.
    act_fn = nn.SiLU()

    def __init__(self, config, hidden_size=None, intermediate_size=None):
        super().__init__()
        self.config = config
        # Explicit sizes override the config defaults (used by MoE experts).
        self.hidden_size = (
            hidden_size if hidden_size is not None else config.hidden_size
        )
        self.intermediate_size = (
            intermediate_size
            if intermediate_size is not None
            else config.intermediate_size
        )

        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(gated)
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
class MoEGate(nn.Module):
    """Router for the MoE layer.

    Scores each token against per-expert centroids (``self.weight``) and picks
    the top-k experts. Supports sigmoid/softmax scoring and two top-k methods:
    ``noaux_tc`` (group-limited routing where a bias-corrected score is used
    only for expert *selection*, while the raw score is used as the weight)
    and plain ``greedy`` top-k.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Number of experts activated per token.
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts
        # Multiplier applied to the final routing weights.
        self.routed_scaling_factor = config.routed_scaling_factor
        self.scoring_func = config.scoring_func
        self.seq_aux = config.seq_aux
        self.topk_method = config.topk_method
        # Experts are partitioned into n_group groups; at most topk_group
        # groups are eligible per token under "noaux_tc".
        self.n_group = config.n_group
        self.topk_group = config.topk_group

        # topk selection algorithm
        self.norm_topk_prob = config.norm_topk_prob
        self.gating_dim = config.hidden_size
        # One centroid row per routed expert.
        self.weight = nn.Parameter(
            torch.empty((self.n_routed_experts, self.gating_dim))
        )
        if self.topk_method == "noaux_tc":
            self.e_score_correction_bias = nn.Parameter(
                # Changed from torch.empty to torch.rand to avoid non-even
                # distribution for runs without actual weights
                torch.rand((self.n_routed_experts))
            )
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialize the gating centroids (Kaiming uniform)."""
        import torch.nn.init as init

        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states):
        """Route tokens to experts.

        Returns ``(topk_idx, topk_weight)`` with shape
        ``[bsz * seq_len, top_k]`` each: selected expert indices and their
        (scaled) routing weights.
        """
        bsz, seq_len, h = hidden_states.shape
        # compute gating score (in float32 for numerical stability)
        hidden_states = hidden_states.view(-1, h)
        logits = F.linear(
            hidden_states.type(torch.float32), self.weight.type(torch.float32), None
        )
        if self.scoring_func == "sigmoid":
            scores = logits.sigmoid()
        elif self.scoring_func == "softmax":
            scores = logits.softmax(dim=-1, dtype=torch.float32)
        else:
            raise NotImplementedError(
                f"insupportable scoring function for MoE gating: {self.scoring_func}"
            )

        # select top-k experts
        if self.topk_method == "noaux_tc":
            # Bias-corrected scores are used only to *choose* experts; the
            # returned weights are gathered from the raw scores below.
            scores_for_choice = scores.view(
                bsz * seq_len, -1
            ) + self.e_score_correction_bias.unsqueeze(0)
            # Score each group by the sum of its top-2 expert scores.
            group_scores = (
                scores_for_choice.view(bsz * seq_len, self.n_group, -1)
                .topk(2, dim=-1)[0]
                .sum(dim=-1)
            )  # [n, n_group]
            group_idx = torch.topk(
                group_scores, k=self.topk_group, dim=-1, sorted=False
            )[
                1
            ]  # [n, top_k_group]
            group_mask = torch.zeros_like(group_scores)  # [n, n_group]
            group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
            # Expand the group mask to per-expert granularity.
            score_mask = (
                group_mask.unsqueeze(-1)
                .expand(
                    bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group
                )
                .reshape(bsz * seq_len, -1)
            )  # [n, e]
            # Zero out experts in non-selected groups before the final top-k.
            tmp_scores = scores_for_choice.masked_fill(
                ~score_mask.bool(), 0.0
            )  # [n, e]
            _, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False)
            topk_weight = scores.gather(1, topk_idx)
        elif self.topk_method == "greedy":
            topk_weight, topk_idx = torch.topk(
                scores, k=self.top_k, dim=-1, sorted=False
            )
        else:
            raise NotImplementedError(
                f"insupportable TopK function for MoE gating: {self.topk_method}"
            )

        # norm gate to sum 1
        if self.top_k > 1 and self.norm_topk_prob:
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator
        topk_weight = (
            topk_weight * self.routed_scaling_factor
        )  # must multiply the scaling factor

        return topk_idx, topk_weight
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
class MoE(nn.Module):
    """
    A mixed expert module containing shared experts.

    Experts are sharded across an expert-parallel (EP) process group; each
    rank hosts ``n_routed_experts // ep_size`` experts. Tokens are shuffled
    DP->EP, processed by local experts, then shuffled back EP->DP.
    """

    # Class attributes:
    # Two shuffle methods supported:
    # 1. "torch_all_to_all"
    # 2. "symm_mem" (see `setup_symm_mem` below)
    shuffle_method = "torch_all_to_all"

    # Symmetric memory buffers shared by all MoE instances across layers
    token_send_buf: Optional[torch.Tensor] = None
    token_gather_buf: Optional[torch.Tensor] = None

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_experts_per_tok = config.num_experts_per_tok

        # ep_size is the number of ranks in expert dimension
        if config.ep_size <= 1:
            raise ValueError(
                "For code simplicity, this model only supports distributed experts, "
                "thus EP size must be > 1, please modify your model config"
            )
        self.ep_group = get_group("ep")
        assert config.ep_size == self.ep_group.size()
        self.ep_size = config.ep_size
        self.ep_rank = self.ep_group.rank()
        self.experts_per_rank = config.n_routed_experts // config.ep_size
        # Use ModuleDict instead of ModuleList to preserve absolute expert
        # IDs while avoiding `None` experts. The absolute expert IDs match
        # with checkpoint FQNs.
        self.experts = nn.ModuleDict()
        for i in range(self.experts_per_rank):
            abs_expert_id = self.ep_rank * self.experts_per_rank + i
            self.experts[str(abs_expert_id)] = MLP(
                config, intermediate_size=config.moe_intermediate_size
            )
        self.gate = MoEGate(config)
        if config.n_shared_experts is not None:
            # Shared experts run on every token, in addition to routed experts.
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
            self.shared_experts = MLP(
                config=config, intermediate_size=intermediate_size
            )

    def combine_experts(self, submod_name):
        """Concatenate the `submod_name` weights of all local experts into a
        single parameter named f"{submod_name}_weight" (for grouped GEMM).

        The per-expert weights are detached from their Linear modules
        (set to None) so the concatenated tensor becomes the sole owner.
        """
        all_weights = []
        for expert in self.experts.values():
            lin = expert.get_submodule(submod_name)
            all_weights.append(lin.weight)
            lin.weight = None

        concat_weight = torch.cat(all_weights)
        self.register_parameter(f"{submod_name}_weight", nn.Parameter(concat_weight))

    # This function is used to create a symm mem buffer for MoE's. It is for
    # shuffling tokens fully "on-device", as compared to traditional torch
    # all_to_all APIs which require a GPU-to-CPU sync of the splits. If a user
    # calls this function, the `shuffle_method` would switch from
    # `torch_all_to_all` to `symm_mem`.
    def setup_symm_mem(self, dtype: torch.dtype, device: torch.device):
        """Enable the symmetric-memory shuffle path and allocate its buffers."""
        # Switch shuffle method
        self.shuffle_method = "symm_mem"

        # Combine expert weights
        print("Combining expert weights for Group GEMM")
        self.combine_experts("gate_proj")
        self.combine_experts("up_proj")
        self.combine_experts("down_proj")

        # Assuming worst case, 2x tokens are routed to one EP rank
        overflow = 2
        OnDeviceAllToAllV.max_output_len = (
            self.config.max_seq_len * self.num_experts_per_tok * overflow
        )

        # Symmetric memory buffers are shared by all MoE instances across
        # layers, we only need to initialize them once
        if MoE.token_send_buf is not None:
            return

        # Input buffer for DP-to-EP shuffle
        MoE.token_send_buf = symm_mem.empty(
            self.config.max_seq_len
            * self.num_experts_per_tok,  # seq len * top k (flattened)
            self.config.hidden_size,  # hidden dim
            dtype=dtype,
            device=device,
        )
        # Input buffer for EP-to-DP shuffle
        MoE.token_gather_buf = symm_mem.empty(
            self.config.max_seq_len
            * self.num_experts_per_tok  # seq len * top k (flattened)
            * overflow,
            self.config.hidden_size,  # hidden dim
            dtype=dtype,
            device=device,
        )
        print(f"EP rank [{self.ep_rank}]: Created Symmetric Memory for MoE")

    def get_send_buf(self):
        """Return the shared DP->EP send buffer, detached from any prior graph."""
        # [Why detach?] During a first forward-backward step, the buffer would
        # be included in a computational graph. In a second step, autograd will
        # return an error saying "Trying to backward through the graph a second
        # time (or directly access saved tensors more than once)". This is
        # because the buffer is still in the graph, and autograd is trying to
        # backward through the graph a second time. To avoid this, we detach the
        # buffer from the graph. `detach()` returns a new tensor, which shares
        # the same storage with the original one.
        self.token_send_buf.grad = None
        return self.token_send_buf.detach()

    def get_gather_buf(self):
        """Return the shared EP->DP gather buffer, detached from any prior graph."""
        # See [Why detach?] in `get_send_buf`
        self.token_gather_buf.grad = None
        return self.token_gather_buf.detach()

    def forward(self, hidden_states):
        """Route, dispatch, process and combine tokens; add shared-expert output."""
        identity = hidden_states
        orig_shape = hidden_states.shape
        # for each token, select top-k experts, and compute the weight for each expert
        topk_idx, topk_weight = self.gate(hidden_states)
        # Flatten to [num_tokens, hidden] for the shuffle paths.
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        if self.shuffle_method == "symm_mem":
            y = self.moe_on_device(hidden_states, topk_idx, topk_weight)
        else:  # "torch_all_to_all"
            y = self.moe_forward(hidden_states, topk_idx, topk_weight)

        y = y.view(*orig_shape)
        if self.config.n_shared_experts is not None:
            y = y + self.shared_experts(identity)
        return y

    def moe_forward(self, x, topk_ids, topk_weight):
        """Dispatch tokens to experts via all_to_all, run experts, and combine.

        Args:
            x: flattened tokens, [num_tokens, hidden].
            topk_ids: [num_tokens, top_k] selected expert indices.
            topk_weight: [num_tokens, top_k] routing weights.
        Returns:
            Combined expert output, [num_tokens, hidden].
        """
        # This part sorts the token indices so that tokens routed to the same expert reside consecutively.
        # An implication is that tokens to the same "expert group" (i.e., device) are also consecutive.
        # Since this is an "artificial" index creation (final outcome being
        # `idxs`), we don't need gradients here.
        with torch.no_grad():
            # [seq_len, n_routed_experts]
            cnts = topk_ids.new_zeros((topk_ids.shape[0], self.config.n_routed_experts))
            # Fill 1 to the selected experts
            cnts.scatter_(1, topk_ids, 1)
            tokens_per_expert = cnts.sum(dim=0)
            # Token indices for each expert
            idxs = topk_ids.view(-1).argsort()
            sorted_tokens_shape = idxs.shape + x.shape[1:]

        # Gather token copies in expert order (idxs // top_k maps a flattened
        # (token, slot) index back to the source token row).
        sorted_tokens = x[idxs // topk_ids.shape[1]]
        assert sorted_tokens.shape == sorted_tokens_shape

        # This part exchanges the information about the number of tokens sent and
        # received by each expert. We can understand this information as "side
        # band", which is not part of the actual data. Thus no gradient is
        # needed.
        with torch.no_grad():
            # Sum the tokens over local experts, then we get tokens per EP rank,
            # which is the input splits
            tokens_per_expert_group = tokens_per_expert.new_empty(
                tokens_per_expert.shape[0]
            )
            dist.all_to_all_single(
                tokens_per_expert_group, tokens_per_expert, group=self.ep_group
            )
            input_splits = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)

        # DP to EP token shuffle. This part needs gradient.
        if self.shuffle_method == "symm_mem":
            # Move input to the `token_send_buf` symm mem
            token_send_buf = self.get_send_buf()
            token_send_buf[: idxs.shape[0]].copy_(sorted_tokens)
            # Note: `out=` avoids copy, but it is not differentiable
            # torch.index_select(x, 0, idxs // topk_ids.shape[1], out=self.token_send_buf[: idxs.shape[0]])
            token_gather_buf, output_splits = OnDeviceAllToAllV.apply(
                token_send_buf,
                input_splits,
                self.ep_group,
            )
            with torch.no_grad():
                # Received tokens from all other ranks. TODO: use mask instead
                received = output_splits.sum()
                # TODO: don't use `received`
                gathered_tokens = token_gather_buf[:received]
        else:  # "torch_all_to_all"
            # Prepare input and output splits
            with torch.no_grad():
                output_splits = tokens_per_expert_group.view(self.ep_size, -1).sum(
                    dim=1
                )
            gathered_tokens = all_to_all_single_autograd(
                sorted_tokens,
                output_splits.tolist(),
                input_splits.tolist(),
                self.ep_group,
            )

        # This part prepares a 1D tensor with the same length as
        # `gathered_tokens`. The 1D tensor is filled with local expert IDs which
        # the tokens in `gathered_tokens` are headed for. This part doesn't need
        # gradient.
        with torch.no_grad():
            gatherd_idxs = (
                torch.arange(
                    tokens_per_expert_group.numel(),
                    device=tokens_per_expert_group.device,
                )
                % self.experts_per_rank
            )
            gatherd_idxs = gatherd_idxs.repeat_interleave(tokens_per_expert_group)

        # Prepare buffer for tokens processed by experts
        if self.shuffle_method == "symm_mem":
            # Take necessary space from `token_gather_buf` symm mem because we are
            # going to send them out after expert processing
            processed_tokens = self.get_gather_buf()[: gathered_tokens.shape[0]]
        else:  # "torch_all_to_all"
            processed_tokens = torch.empty_like(gathered_tokens)

        # This part processes the tokens routed to the local experts.
        # TODO: can we use group GEMM here?
        for i, expert in enumerate(self.experts.values()):
            processed_tokens[gatherd_idxs == i] = expert(
                gathered_tokens[gatherd_idxs == i]
            )

        # Now shuffle the tokens back to their original owner, i.e. EP to DP shuffle.
        # The input/output splits are just a reverse of the previous shuffle.
        if self.shuffle_method == "symm_mem":
            token_return_buf, _ = OnDeviceAllToAllV.apply(
                processed_tokens,
                output_splits,
                self.ep_group,
            )
            returned_tokens = token_return_buf[: sorted_tokens_shape[0]]
        else:  # "torch_all_to_all"
            returned_tokens = all_to_all_single_autograd(
                processed_tokens,
                input_splits.tolist(),
                output_splits.tolist(),
                self.ep_group,
            )

        # Undo the expert-order sort, then weight and sum the top-k expert
        # outputs per token.
        output_tokens = torch.empty_like(returned_tokens)
        output_tokens[idxs] = returned_tokens
        final_out = (
            output_tokens.view(*topk_ids.shape, -1)
            .type(topk_weight.dtype)
            .mul_(topk_weight.unsqueeze(dim=-1))
            .sum(dim=1)
            .type(returned_tokens.dtype)
        )
        return final_out

    def moe_on_device(self, x, topk_ids, topk_weight):
        """Symmetric-memory variant of `moe_forward`: on-device all-to-all-v
        shuffles plus grouped GEMM over the combined expert weights.

        Requires `setup_symm_mem` to have been called first.
        """
        # This part sorts the token indices so that tokens routed to the same expert reside consecutively.
        # An implication is that tokens to the same "expert group" (i.e., device) are also consecutive.
        # Since this is an "artificial" index creation (final outcome being
        # `idxs`), we don't need gradients here.
        with torch.no_grad():
            # [seq_len, n_routed_experts]
            cnts = topk_ids.new_zeros((topk_ids.shape[0], self.config.n_routed_experts))
            # Fill 1 to the selected experts
            cnts.scatter_(1, topk_ids, 1)
            tokens_per_expert = cnts.sum(dim=0)
            # Token indices for each expert
            idxs = topk_ids.view(-1).argsort()
            sorted_tokens_shape = idxs.shape + x.shape[1:]

        sorted_tokens = x[idxs // topk_ids.shape[1]]
        assert sorted_tokens.shape == sorted_tokens_shape

        # This part exchanges the information about the number of tokens sent and
        # received by each expert. We can understand this information as "side
        # band", which is not part of the actual data. Thus no gradient is
        # needed.
        with torch.no_grad():
            # Sum the tokens over local experts, then we get tokens per EP rank,
            # which is the input splits
            tokens_per_expert_group = tokens_per_expert.new_empty(
                tokens_per_expert.shape[0]
            )
            dist.all_to_all_single(
                tokens_per_expert_group, tokens_per_expert, group=self.ep_group
            )
            input_splits = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)

        # Move input to the `token_send_buf` symm mem
        token_send_buf = self.get_send_buf()
        token_send_buf[: idxs.shape[0]].copy_(sorted_tokens)
        # Note: `out=` avoids copy, but it is not differentiable
        # torch.index_select(x, 0, idxs // topk_ids.shape[1], out=self.token_send_buf[: idxs.shape[0]])
        token_gather_buf, output_splits = OnDeviceAllToAllV.apply(
            token_send_buf,
            input_splits,
            self.ep_group,
        )

        # We need to permute the received tokens so that tokens for the same expert are contiguous.
        # This part prepares a 1D tensor `permuted_indices` for such permutation.
        # This part doesn't need gradient.
        with torch.no_grad():
            permuted_indices, m_sizes = generate_permute_indices(
                tokens_per_expert_group,
                self.experts_per_rank,
                self.ep_size,
                token_gather_buf.shape[0],
                ALIGN_SIZE_M,
            )

        # Permute the received tokens so that tokens for the same expert are contiguous.
        contig_tokens = token_gather_buf[permuted_indices]

        # Run the first grouped GEMM
        w1 = self.get_parameter("gate_proj_weight")
        gate_proj = grouped_gemm_forward(contig_tokens, w1, m_sizes)

        # Run the second grouped GEMM
        w3 = self.get_parameter("up_proj_weight")
        up_proj = grouped_gemm_forward(contig_tokens, w3, m_sizes)

        # Apply activation
        hidden_outputs = MLP.act_fn(gate_proj) * up_proj

        # Run the third grouped GEMM
        w2 = self.get_parameter("down_proj_weight")
        hidden_outputs = grouped_gemm_forward(hidden_outputs, w2, m_sizes)

        # Prepare buffer for tokens processed by experts
        # Take necessary space from `token_gather_buf` symm mem because we are
        # going to send them out after expert processing
        processed_tokens = self.get_gather_buf()

        # Move into Symmetric Memory for the return shuffle
        processed_tokens[permuted_indices] = hidden_outputs

        # Now shuffle the tokens back to their original owner, i.e. EP to DP shuffle.
        # The input/output splits are just a reverse of the previous shuffle.
        token_return_buf, _ = OnDeviceAllToAllV.apply(
            processed_tokens,
            output_splits,
            self.ep_group,
        )
        returned_tokens = token_return_buf[: sorted_tokens_shape[0]]

        # Undo the expert-order sort, then weight and sum the top-k expert
        # outputs per token.
        output_tokens = torch.empty_like(returned_tokens)
        output_tokens[idxs] = returned_tokens
        final_out = (
            output_tokens.view(*topk_ids.shape, -1)
            .type(topk_weight.dtype)
            .mul_(topk_weight.unsqueeze(dim=-1))
            .sum(dim=1)
            .type(returned_tokens.dtype)
        )
        return final_out
|
| 802 |
+
|
| 803 |
+
|
| 804 |
+
class Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Multi-head Latent Attention (MLA) variant: queries (optionally) and
    key/values are projected through low-rank bottlenecks
    (``q_lora_rank`` / ``kv_lora_rank``), and RoPE is applied only to a
    ``qk_rope_head_dim``-sized slice of each head; the remaining
    ``qk_nope_head_dim`` dims are position-free.

    Fix vs. the previous revision: ``F.scaled_dot_product_attention`` applies
    ``dropout_p`` unconditionally — it knows nothing about the module's
    train/eval state — so dropout is now gated on ``self.training`` to keep
    inference deterministic.
    """

    def __init__(self, config: ModelArgs, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads

        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.q_lora_rank = config.q_lora_rank
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.kv_lora_rank = config.kv_lora_rank
        self.v_head_dim = config.v_head_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
        # Per-head query/key dim = position-free part + RoPE part.
        self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim

        self.is_causal = True

        # Query path: direct projection, or low-rank (LoRA-style) bottleneck.
        if self.q_lora_rank is None:
            self.q_proj = nn.Linear(
                self.hidden_size, self.num_heads * self.q_head_dim, bias=False
            )
        else:
            self.q_a_proj = nn.Linear(
                self.hidden_size, config.q_lora_rank, bias=config.attention_bias
            )
            self.q_a_layernorm = RMSNorm(config.q_lora_rank)
            self.q_b_proj = nn.Linear(
                config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
            )

        # KV path: compressed latent plus a shared (multi-query) RoPE key slice.
        self.kv_a_proj_with_mqa = nn.Linear(
            self.hidden_size,
            config.kv_lora_rank + config.qk_rope_head_dim,
            bias=config.attention_bias,
        )
        self.kv_a_layernorm = RMSNorm(config.kv_lora_rank)
        self.kv_b_proj = nn.Linear(
            config.kv_lora_rank,
            self.num_heads
            * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
            bias=False,
        )

        self.o_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=config.attention_bias,
        )
        self._init_rope()

        self.softmax_scale = self.q_head_dim ** (-0.5)
        if self.config.rope_scaling is not None:
            # YaRN attention-magnitude correction folded into the softmax scale.
            mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
            scaling_factor = self.config.rope_scaling["factor"]
            if mscale_all_dim:
                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
                self.softmax_scale = self.softmax_scale * mscale * mscale

    def _init_rope(self):
        """Create the rotary embedding for the RoPE slice, honoring config.rope_scaling."""
        if self.config.rope_scaling is None:
            self.rotary_emb = RotaryEmbedding(
                self.qk_rope_head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = LinearScalingRotaryEmbedding(
                    self.qk_rope_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = DynamicNTKScalingRotaryEmbedding(
                    self.qk_rope_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "yarn":
                # Forward only the YaRN knobs present in the config.
                kwargs = {
                    key: self.config.rope_scaling[key]
                    for key in [
                        "original_max_position_embeddings",
                        "beta_fast",
                        "beta_slow",
                        "mscale",
                        "mscale_all_dim",
                    ]
                    if key in self.config.rope_scaling
                }
                self.rotary_emb = YarnRotaryEmbedding(
                    self.qk_rope_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                    **kwargs,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Run MLA attention over `hidden_states`.

        Args:
            hidden_states: [bsz, q_len, hidden_size] input.
            attention_mask: optional 2D mask, expanded to 4D causal form here.
            position_ids: positions used to index the RoPE cos/sin caches.
        Returns:
            [bsz, q_len, hidden_size] attention output.
        """
        bsz, q_len, _ = hidden_states.size()

        if self.q_lora_rank is None:
            q = self.q_proj(hidden_states)
        else:
            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
        # Split each query head into position-free and RoPE parts.
        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )

        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        # Latent KV plus a single shared RoPE key slice (MQA-style).
        compressed_kv, k_pe = torch.split(
            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
        )
        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
        kv = (
            self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
            .transpose(1, 2)
        )

        k_nope, value_states = torch.split(
            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
        )
        kv_seq_len = value_states.shape[-2]

        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)

        # Reassemble full heads: [nope | rope]. k_pe (1 head) broadcasts
        # across all key heads via the slice assignment.
        query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
        query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
        query_states[:, :, :, self.qk_nope_head_dim :] = q_pe

        key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
        key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
        key_states[:, :, :, self.qk_nope_head_dim :] = k_pe

        if attention_mask is not None:
            # Attention mask was made 4D because the `attn_weights` above is 4D.
            # We probably can make this mask smarter if we want to pack sequences
            # together, instead of using padding. This optimization can be used in
            # inference. For training, if we want to pack sequences, data loader
            # will pass in a mask containing such info.
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask,  # None, or user provided mask in 2D
                (bsz, q_len),
                hidden_states,
                0,  # past_key_values_length, 0 when training
            )
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query=query_states,
            key=key_states,
            value=value_states,
            attn_mask=attention_mask,
            # SDPA applies dropout unconditionally; gate it on training mode
            # so eval/inference stays deterministic.
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=attention_mask is None,
            scale=self.softmax_scale,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
        attn_output = self.o_proj(attn_output)

        return attn_output
|
| 990 |
+
|
| 991 |
+
|
| 992 |
+
class DecoderLayer(nn.Module):
    """One transformer decoder block: self-attention and a feed-forward
    network, each wrapped in a pre-norm residual connection."""

    def __init__(self, config: ModelArgs, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = Attention(config=config, layer_idx=layer_idx)

        # MoE replaces the dense MLP only after the first
        # `first_k_dense_replace` layers, and then only on every
        # `moe_layer_freq`-th layer.
        use_moe = (
            config.n_routed_experts is not None
            and layer_idx >= config.first_k_dense_replace
            and layer_idx % config.moe_layer_freq == 0
        )
        if use_moe:
            self.mlp = MoE(config)
        else:
            self.mlp = MLP(config)

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
        """
        # Attention sub-block (pre-norm residual).
        attn_out = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block (pre-norm residual).
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + mlp_out

        return hidden_states
|
| 1045 |
+
|
| 1046 |
+
|
| 1047 |
+
Deepseek_INPUTS_DOCSTRING = r"""
|
| 1048 |
+
Args:
|
| 1049 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 1050 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
| 1051 |
+
it.
|
| 1052 |
+
|
| 1053 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 1054 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 1055 |
+
|
| 1056 |
+
[What are input IDs?](../glossary#input-ids)
|
| 1057 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1058 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 1059 |
+
|
| 1060 |
+
- 1 for tokens that are **not masked**,
|
| 1061 |
+
- 0 for tokens that are **masked**.
|
| 1062 |
+
|
| 1063 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 1064 |
+
|
| 1065 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 1066 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 1067 |
+
|
| 1068 |
+
If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
|
| 1069 |
+
`past_key_values`).
|
| 1070 |
+
|
| 1071 |
+
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
| 1072 |
+
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
| 1073 |
+
information on the default strategy.
|
| 1074 |
+
|
| 1075 |
+
- 1 indicates the head is **not masked**,
|
| 1076 |
+
- 0 indicates the head is **masked**.
|
| 1077 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1078 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
| 1079 |
+
config.n_positions - 1]`.
|
| 1080 |
+
|
| 1081 |
+
[What are position IDs?](../glossary#position-ids)
|
| 1082 |
+
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
|
| 1083 |
+
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
| 1084 |
+
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
|
| 1085 |
+
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
| 1086 |
+
|
| 1087 |
+
Two formats are allowed:
|
| 1088 |
+
- a [`~cache_utils.Cache`] instance;
|
| 1089 |
+
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
| 1090 |
+
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
|
| 1091 |
+
cache format.
|
| 1092 |
+
|
| 1093 |
+
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
|
| 1094 |
+
legacy cache format will be returned.
|
| 1095 |
+
|
| 1096 |
+
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
| 1097 |
+
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
| 1098 |
+
of shape `(batch_size, sequence_length)`.
|
| 1099 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
| 1100 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
| 1101 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
| 1102 |
+
model's internal embedding lookup matrix.
|
| 1103 |
+
use_cache (`bool`, *optional*):
|
| 1104 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
| 1105 |
+
`past_key_values`).
|
| 1106 |
+
output_attentions (`bool`, *optional*):
|
| 1107 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 1108 |
+
tensors for more detail.
|
| 1109 |
+
output_hidden_states (`bool`, *optional*):
|
| 1110 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 1111 |
+
more detail.
|
| 1112 |
+
return_dict (`bool`, *optional*):
|
| 1113 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 1114 |
+
"""
|
| 1115 |
+
|
| 1116 |
+
|
| 1117 |
+
class DeepseekModel(torch.nn.Module):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DecoderLayer`]

    Args:
        config: ModelArgs
    """

    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # Creating model parts related to my stage
        assert (
            config.stage_idx < config.num_stages
        ), f"Stage {config.stage_idx} is not in the model"
        print(f"Creating model stage {config.stage_idx} of {config.num_stages}")

        # The token embedding exists only on the first pipeline stage; later
        # stages receive hidden states from the previous stage (see forward()).
        self.embed_tokens = (
            nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
            if config.stage_idx == 0
            else None
        )

        self.layers = torch.nn.ModuleDict()
        division = config.num_hidden_layers // config.num_stages
        residual = config.num_hidden_layers % config.num_stages
        # Some earlier stages may have 1 more layer than latter stages because
        # the division may have residual; this is more even than giving the
        # entire residual to the last stage.
        layers_per_stage = [
            division + 1 if stage < residual else division
            for stage in range(config.num_stages)
        ]
        assert sum(layers_per_stage) == config.num_hidden_layers
        # Global layer ids owned by this stage: [layer_id_start, layer_id_end).
        # Keys in `self.layers` are the *global* layer ids as strings, so
        # checkpoints line up across stages.
        layer_id_start = sum(layers_per_stage[: config.stage_idx])
        layer_id_end = layer_id_start + layers_per_stage[config.stage_idx]
        for layer_id in range(layer_id_start, layer_id_end):
            self.layers[str(layer_id)] = DecoderLayer(config, layer_id)

        # The final RMSNorm exists only on the last pipeline stage.
        self.norm = (
            RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
            if config.stage_idx == config.num_stages - 1
            else None
        )

        # Initialize weights and apply final processing
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize Linear and Embedding weights from N(0, initializer_range);
        zero Linear biases and the padding embedding row."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def forward(
        self,
        tokens: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """Run this pipeline stage's share of the decoder.

        On stage 0, `tokens` holds token ids and is embedded here; on later
        stages `tokens` is already the hidden-state tensor emitted by the
        previous stage and is passed through unchanged.
        """
        # Embedding
        hidden_states = (
            self.embed_tokens(tokens) if self.embed_tokens is not None else tokens
        )

        # decoder layers
        for decoder_layer in self.layers.values():
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
            )

        # Final norm is applied only on the last stage (self.norm is None elsewhere).
        hidden_states = (
            self.norm(hidden_states) if self.norm is not None else hidden_states
        )
        return hidden_states
|
| 1202 |
+
|
| 1203 |
+
|
| 1204 |
+
class DeepseekForCausalLM(torch.nn.Module):
    """Deepseek decoder with a language-modeling head, aware of pipeline
    stages: only the last stage owns `lm_head` and produces logits."""

    def __init__(self, config):
        super().__init__()
        self.model = DeepseekModel(config)
        # LM head exists only on the last pipeline stage; earlier stages
        # return hidden states unchanged from forward().
        self.lm_head = (
            nn.Linear(config.hidden_size, config.vocab_size, bias=False)
            if config.stage_idx == config.num_stages - 1
            else None
        )

        # Initialize weights and apply final processing
        # self.post_init()

    def forward(
        self,
        tokens: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        r"""
        Example:

        ```python
        >>> from transformers import AutoTokenizer, DeepseekForCausalLM

        >>> model = DeepseekForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        hidden_states = self.model(
            tokens,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )

        # Project to vocabulary logits only on the last stage; intermediate
        # stages hand hidden states to the next stage.
        logits = (
            self.lm_head(hidden_states) if self.lm_head is not None else hidden_states
        )
        return logits

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        **kwargs,
    ):
        """Trim `input_ids`/`attention_mask` against the KV cache and derive
        `position_ids` for HuggingFace-style autoregressive generation.

        NOTE(review): assumes `past_key_values` implements the HF `Cache`
        API (`get_seq_length`, `seen_tokens`, `get_max_length`) — confirm
        against the transformers version in use.
        """
        if past_key_values is not None:
            # Assuming isinstance(past_key_values, Cache):
            cache_length = past_key_values.get_seq_length()
            past_length = past_key_values.seen_tokens
            max_cache_length = past_key_values.get_max_length()

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as
            # input)
            if (
                attention_mask is not None
                and attention_mask.shape[1] > input_ids.shape[1]
            ):
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Reorder legacy-format cached key/value states along the batch
        dimension according to `beam_idx` (beam-search bookkeeping)."""
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(
                    past_state.index_select(0, beam_idx.to(past_state.device))
                    for past_state in layer_past
                ),
            )
        return reordered_past

    # Setup Symmetric Memory for MoE token shuffle.
    # Supports inference currently.
    def setup_symm_mem(self, dtype: torch.dtype, device: torch.device):
        # Only MoE layers need symmetric-memory buffers; dense MLP layers
        # are skipped.
        for layer in self.model.layers.values():
            if not isinstance(layer.mlp, MoE):
                continue
            layer.mlp.setup_symm_mem(dtype, device)
|
torchtitan/experiments/deepseek_v3/requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
transformers
|
| 2 |
+
accelerate
|
| 3 |
+
torchdata >= 0.8.0
|
| 4 |
+
datasets >= 2.21.0
|
| 5 |
+
tomli >= 1.1.0 ; python_version < "3.11"
|
torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_utils.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import triton
|
| 8 |
+
import triton.language as tl
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@triton.jit
def get_tid():
    # Read the CUDA thread-index registers (%tid.x/y/z) via inline PTX.
    # Returns a 3-tuple of uint32: (tid_x, tid_y, tid_z) for the calling thread.
    return tl.inline_asm_elementwise(
        """
        mov.u32 $0, %tid.x;
        mov.u32 $1, %tid.y;
        mov.u32 $2, %tid.z;
        """,
        "=r,=r,=r",
        [],
        dtype=(tl.uint32, tl.uint32, tl.uint32),
        is_pure=True,
        pack=1,
    )
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@triton.jit
def get_ntid():
    # Read the CUDA block-dimension registers (%ntid.x/y/z) via inline PTX.
    # Returns a 3-tuple of uint32: the number of threads per block in each dim.
    return tl.inline_asm_elementwise(
        """
        mov.u32 $0, %ntid.x;
        mov.u32 $1, %ntid.y;
        mov.u32 $2, %ntid.z;
        """,
        "=r,=r,=r",
        [],
        dtype=(tl.uint32, tl.uint32, tl.uint32),
        is_pure=True,
        pack=1,
    )
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@triton.jit
def get_flat_tid():
    # Linearize the 3-D thread index into a single flat index within the
    # block, with x fastest-varying: z * (ny * nx) + y * nx + x.
    tid_x, tid_y, tid_z = get_tid()
    ntid_x, ntid_y, _ = get_ntid()
    return tid_z * ntid_y * ntid_x + tid_y * ntid_x + tid_x
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@triton.jit
def get_flat_bid():
    # Linearize the 3-D program (block) id into a single flat index, with
    # axis 0 fastest-varying — mirrors get_flat_tid's layout.
    return (
        tl.program_id(2) * tl.num_programs(1) * tl.num_programs(0)
        + tl.program_id(1) * tl.num_programs(0)
        + tl.program_id(0)
    )
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
@triton.jit
def sync_threads():
    # Block-wide barrier via PTX `bar.sync 0`: every thread in the CTA waits
    # here until all have arrived (equivalent to CUDA's __syncthreads()).
    # is_pure=False prevents the compiler from hoisting/removing the barrier.
    tl.inline_asm_elementwise(
        "bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1
    )
|
torchtitan/experiments/flux/README.md
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# FLUX model in torchtitan
|
| 2 |
+
|
| 3 |
+
## Overview
|
| 4 |
+
|
| 5 |
+
## Usage
|
| 6 |
+
First, download the autoencoder model from HuggingFace with your own access token:
|
| 7 |
+
```bash
|
| 8 |
+
python torchtitan/experiments/flux/scripts/download_autoencoder.py --repo_id black-forest-labs/FLUX.1-dev --ae_path ae.safetensors --hf_token <your_access_token>
|
| 9 |
+
```
|
| 10 |
+
This step will download the autoencoder model from HuggingFace and save it to the `torchtitan/experiments/flux/assets/autoencoder/ae.safetensors` file.
|
| 11 |
+
|
| 12 |
+
Run the following command to train the model on a single GPU:
|
| 13 |
+
```bash
|
| 14 |
+
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True torchrun --nproc_per_node=1 torchtitan/experiments/flux/train.py --job.config_file torchtitan/experiments/flux/train_configs/debug_model.toml
|
| 15 |
+
```
|
| 16 |
+
|
| 17 |
+
## TODO
|
| 18 |
+
- [ ] Support for multiple GPUs is coming soon (FSDP, etc.)
|
| 19 |
+
- [ ] Implement test cases in CI for FLUX model. Adding more unit tests for FLUX model (eg, unit test for preprocessor, etc)
|
| 20 |
+
- [ ] More parallelism support (Tensor Parallelism, Context Parallelism, etc.)
|
| 21 |
+
- [ ] Support for distributed checkpointing and loading
|
| 22 |
+
- [ ] Implement init_weights() function to initialize the model weights
|
| 23 |
+
- [ ] Implement the num_flops_per_token calculation in get_nparams_and_flops() function
|
torchtitan/experiments/flux/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (2.08 kB). View file
|
|
|
torchtitan/experiments/flux/dataset/flux_dataset.py
ADDED
|
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import math
|
| 8 |
+
import random
|
| 9 |
+
from dataclasses import dataclass
|
| 10 |
+
from typing import Any, Callable, Optional
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
|
| 16 |
+
from datasets import Dataset, load_dataset
|
| 17 |
+
from datasets.distributed import split_dataset_by_node
|
| 18 |
+
from PIL import Image
|
| 19 |
+
|
| 20 |
+
from torch.distributed.checkpoint.stateful import Stateful
|
| 21 |
+
|
| 22 |
+
from torch.utils.data import IterableDataset
|
| 23 |
+
from torchtitan.components.dataloader import ParallelAwareDataloader
|
| 24 |
+
|
| 25 |
+
from torchtitan.config_manager import JobConfig
|
| 26 |
+
from torchtitan.experiments.flux.dataset.tokenizer import FluxTokenizer
|
| 27 |
+
from torchtitan.tools.logging import logger
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _process_cc12m_image(
    img: Image.Image,
    output_size: int = 256,
) -> Optional[torch.Tensor]:
    """Resize then random-crop a CC12M image to a square of `output_size`.

    The shorter side is resized to `output_size` (preserving aspect ratio)
    and a random crop along the longer side produces the final square.

    Args:
        img: Source PIL image.
        output_size: Side length of the square output, in pixels.

    Returns:
        A float tensor of shape (C, output_size, output_size) with values in
        [0, 1], or None if the image is rejected (either side smaller than
        `output_size`, or grayscale mode "L").
    """
    width, height = img.size
    # Skip low resolution images
    if width < output_size or height < output_size:
        return None

    if width >= height:
        # Resize so height == output_size, then crop horizontally.
        new_width, new_height = math.ceil(output_size / height * width), output_size
        img = img.resize((new_width, new_height))
        left = random.randint(0, new_width - output_size)
        resized_img = img.crop((left, 0, left + output_size, output_size))
    else:
        # Resize so width == output_size, then crop vertically.
        new_width, new_height = (
            output_size,
            math.ceil(output_size / width * height),
        )
        img = img.resize((new_width, new_height))
        # BUG FIX: the vertical offset must range over the resized height.
        # The original used `new_width - output_size`, which is always 0 here
        # (new_width == output_size), pinning every crop to the top edge.
        lower = random.randint(0, new_height - output_size)
        resized_img = img.crop((0, lower, output_size, lower + output_size))

    assert resized_img.size[0] == resized_img.size[1] == output_size

    # Skip grayscale images (a 2-D "L" array would break the HWC->CHW
    # transpose below).
    if resized_img.mode == "L":
        return None

    # HWC uint8 -> CHW float in [0, 1].
    np_img = np.array(resized_img).transpose((2, 0, 1))
    tensor_img = torch.tensor(np_img).float() / 255.0

    return tensor_img
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def _flux_data_processor(
    sample: dict[str, Any],
    t5_tokenizer: FluxTokenizer,
    clip_tokenizer: FluxTokenizer,
    output_size: int = 256,
) -> dict[str, Any]:
    """Turn one raw CC12M sample into Flux model inputs.

    The sample's "jpg" field is resized/cropped to `output_size` (None if
    rejected) and its "txt" caption is tokenized with both tokenizers.

    Args:
        sample: A sample from dataset
        t5_tokenizer: T5 tokenizer
        clip_tokenizer: CLIP tokenizer
        output_size: The output image size
    """
    processed_image = _process_cc12m_image(sample["jpg"], output_size=output_size)

    caption = sample["txt"]
    t5_token_ids = t5_tokenizer.encode(caption)
    clip_token_ids = clip_tokenizer.encode(caption)

    return {
        "image": processed_image,
        "clip_tokens": clip_token_ids,  # type: List[int]
        "t5_tokens": t5_token_ids,  # type: List[int]
    }
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
@dataclass
class TextToImageDatasetConfig:
    """Registry entry describing how to load and preprocess one
    text-to-image dataset (see DATASETS below)."""

    # Default dataset location (e.g. a HuggingFace hub repo id); may be
    # overridden by an explicit dataset_path at load time.
    path: str
    # Callable mapping a path to a (possibly streaming) dataset object.
    loader: Callable
    # Callable converting a raw sample + tokenizers into model inputs.
    data_processor: Callable
|
| 112 |
+
|
| 113 |
+
# Registry of supported text-to-image datasets, keyed by lowercase name.
# _validate_dataset() looks names up here.
DATASETS = {
    "cc12m": TextToImageDatasetConfig(
        path="pixparse/cc12m-wds",
        # Streaming avoids materializing the full dataset locally.
        loader=lambda path: load_dataset(path, split="train", streaming=True),
        data_processor=_flux_data_processor,
    ),
}
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def _validate_dataset(
    dataset_name: str, dataset_path: Optional[str] = None
) -> tuple[str, Callable, Callable]:
    """Resolve a dataset name to its (path, loader, data_processor) triple.

    An explicit `dataset_path` overrides the registry default. Raises
    ValueError for names not present in DATASETS.
    """
    config = DATASETS.get(dataset_name)
    if config is None:
        raise ValueError(
            f"Dataset {dataset_name} is not supported. "
            f"Supported datasets are: {list(DATASETS.keys())}"
        )

    path = dataset_path or config.path
    logger.info(f"Preparing {dataset_name} dataset from {path}")
    return path, config.loader, config.data_processor
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class FluxDataset(IterableDataset, Stateful):
    """Dataset for FLUX text-to-image model.

    Yields `(inputs, labels)` pairs where `inputs` holds tokenized captions
    and `labels` is the preprocessed image tensor. Checkpointable via the
    Stateful protocol (`state_dict` / `load_state_dict`).

    Args:
        dataset_name (str): Name of the dataset.
        dataset_path (str): Path to the dataset.
        t5_tokenizer (FluxTokenizer): Tokenizer for the T5 encoder.
        clip_tokenizer (FluxTokenizer): Tokenizer for the CLIP encoder.
        job_config (JobConfig): Optional job configuration.
        dp_rank (int): Data parallel rank.
        dp_world_size (int): Data parallel world size.
        infinite (bool): Whether to loop over the dataset infinitely.
    """

    def __init__(
        self,
        dataset_name: str,
        dataset_path: Optional[str],
        t5_tokenizer: FluxTokenizer,
        clip_tokenizer: FluxTokenizer,
        job_config: Optional[JobConfig] = None,
        dp_rank: int = 0,
        dp_world_size: int = 1,
        infinite: bool = False,
    ) -> None:

        # Force lowercase for consistent comparison
        dataset_name = dataset_name.lower()

        path, dataset_loader, data_processor = _validate_dataset(
            dataset_name, dataset_path
        )
        ds = dataset_loader(path)

        self.dataset_name = dataset_name
        # Shard across data-parallel ranks.
        self._data = split_dataset_by_node(ds, dp_rank, dp_world_size)

        self._t5_tokenizer = t5_tokenizer
        self._clip_tokenizer = clip_tokenizer
        self._data_processor = data_processor
        self.job_config = job_config

        self.infinite = infinite

        # Variables for checkpointing
        self._sample_idx = 0
        self._all_samples: list[dict[str, Any]] = []

    def _get_data_iter(self):
        # Map-style dataset fully consumed -> nothing left to iterate.
        if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
            return iter([])

        it = iter(self._data)
        # Fast-forward past samples already consumed (checkpoint resume).
        for _ in range(self._sample_idx):
            next(it)
        return it

    def __iter__(self):
        while True:
            for sample in self._get_data_iter():
                # Use the dataset-specific preprocessor
                sample_dict = self._data_processor(
                    sample, self._t5_tokenizer, self._clip_tokenizer, output_size=256
                )

                # skip low quality image or image with color channel = 1
                if sample_dict["image"] is None:
                    logger.warning(
                        f"Low quality image {sample['__key__']} is skipped in Flux Dataloader"
                    )
                    continue

                # BUG FIX: was `self._all_samples.extend(sample_dict)`, which
                # iterates the dict and appends its *keys* ("image",
                # "clip_tokens", ...) instead of the sample. `_all_samples`
                # is declared list[dict[str, Any]], so append the dict itself.
                # NOTE: the pop() below removes "image" from this same stored
                # object, so checkpoint state does not retain image tensors.
                self._all_samples.append(sample_dict)
                self._sample_idx += 1

                labels = sample_dict.pop("image")
                yield sample_dict, labels

            if not self.infinite:
                logger.warning(f"Dataset {self.dataset_name} has run out of data")
                break
            else:
                # Reset offset for the next iteration
                self._sample_idx = 0
                logger.warning(f"Dataset {self.dataset_name} is being re-looped")

    def load_state_dict(self, state_dict):
        """Restore iteration position and sample history from a checkpoint."""
        self._sample_idx = state_dict["sample_idx"]
        self._all_samples = state_dict["all_samples"]

    def state_dict(self):
        """Capture iteration position and sample history for checkpointing."""
        return {
            "all_samples": self._all_samples,
            "sample_idx": self._sample_idx,
        }
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def build_flux_dataloader(
    dp_world_size: int,
    dp_rank: int,
    job_config: JobConfig,
    # This parameter is not used, keep it for compatibility
    tokenizer: FluxTokenizer | None,
    infinite: bool = True,
) -> ParallelAwareDataloader:
    """Build a data loader for HuggingFace datasets."""
    training_cfg = job_config.training
    encoder_cfg = job_config.encoder

    # T5 tokenizer uses the configured max encoding length; CLIP's context
    # length is fixed at 77 tokens.
    t5_tok = FluxTokenizer(
        encoder_cfg.t5_encoder, max_length=encoder_cfg.max_t5_encoding_len
    )
    clip_tok = FluxTokenizer(encoder_cfg.clip_encoder, max_length=77)

    dataset = FluxDataset(
        dataset_name=training_cfg.dataset,
        dataset_path=training_cfg.dataset_path,
        t5_tokenizer=t5_tok,
        clip_tokenizer=clip_tok,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        infinite=infinite,
    )

    return ParallelAwareDataloader(
        dataset=dataset,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        batch_size=training_cfg.batch_size,
    )
|
torchtitan/experiments/flux/dataset/tokenizer.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 8 |
+
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
from typing import List
|
| 12 |
+
|
| 13 |
+
from torchtitan.components.tokenizer import Tokenizer
|
| 14 |
+
from transformers import CLIPTokenizer, T5Tokenizer
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class FluxTokenizer(Tokenizer):
    """
    Tokenizing and encoding/decoding text using the T5 or CLIP tokenizer.

    Args:
        model_path (str): Hugging Face path of the tokenizer to load.
        max_length (int): Prompts are padded/truncated to this many tokens.
    """

    def __init__(self, model_path: str = "t5-small", max_length: int = 77):
        super().__init__()
        self._n_words = 8  # TODO(jianiw): check
        self._max_length = max_length

        # Hugging Face CLIP checkpoints are published under the "openai/" namespace.
        self.is_clip = model_path.startswith("openai")

        tokenizer_cls = CLIPTokenizer if self.is_clip else T5Tokenizer
        self._tokenizer = tokenizer_cls.from_pretrained(
            model_path, max_length=max_length
        )

    def encode(self, s: str) -> List[int]:
        """
        Encode the prompt text into tokens, padded/truncated to the
        configured maximum length.
        """
        encoded = self._tokenizer(
            s,
            truncation=True,
            max_length=self._max_length,
            return_length=False,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",  # return pytorch tensors, default return List[int]
        )
        return encoded["input_ids"]

    def decode(self, t: List[int]) -> str:
        """
        Decode function. This function will not be called.
        """
        return self._tokenizer.decode(t)
|
torchtitan/experiments/flux/model/__pycache__/layers.cpython-312.pyc
ADDED
|
Binary file (17.7 kB). View file
|
|
|
torchtitan/experiments/flux/model/hf_embedder.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from torch import nn, Tensor
|
| 8 |
+
from transformers import CLIPTextModel, T5EncoderModel
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class FluxEmbedder(nn.Module):
    """Frozen Hugging Face text encoder (CLIP or T5) used for conditioning."""

    def __init__(self, version: str, **hf_kwargs):
        super().__init__()
        self.is_clip = version.startswith("openai")
        # CLIP conditioning uses the pooled output; T5 uses per-token states.
        self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"

        encoder_cls = CLIPTextModel if self.is_clip else T5EncoderModel
        self.hf_module = encoder_cls.from_pretrained(version, **hf_kwargs)

        # Inference-only: freeze the encoder so it never receives gradients.
        self.hf_module = self.hf_module.eval().requires_grad_(False)

    def forward(self, batch_tokens: Tensor) -> Tensor:
        """
        batch_tokens: [bsz, embedding_length]

        Per the original notes: embedding_length is 768 for the T5 encoder
        and 256 for CLIP — TODO confirm against the tokenizer settings.
        """
        outputs = self.hf_module(
            input_ids=batch_tokens.to(self.hf_module.device),
            attention_mask=None,
            output_hidden_states=False,
        )
        return outputs[self.output_key]
|
torchtitan/experiments/flux/model/model.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from dataclasses import dataclass, field
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from torch import nn, Tensor
|
| 12 |
+
from torchtitan.components.tokenizer import Tokenizer
|
| 13 |
+
from torchtitan.config_manager import JobConfig
|
| 14 |
+
|
| 15 |
+
from torchtitan.experiments.flux.model.autoencoder import AutoEncoderParams
|
| 16 |
+
from torchtitan.experiments.flux.model.layers import (
|
| 17 |
+
DoubleStreamBlock,
|
| 18 |
+
EmbedND,
|
| 19 |
+
LastLayer,
|
| 20 |
+
MLPEmbedder,
|
| 21 |
+
SingleStreamBlock,
|
| 22 |
+
timestep_embedding,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
from torchtitan.protocols.train_spec import BaseModelArgs, ModelProtocol
|
| 26 |
+
from torchtitan.tools.logging import logger
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclass
class FluxModelArgs(BaseModelArgs):
    """Configuration for the Flux flow-matching transformer."""

    in_channels: int = 64  # channels of packed latent patches fed to img_in
    out_channels: int = 64  # channels produced by the final layer
    vec_in_dim: int = 768  # input dim of the vector_in embedder (CLIP conditioning)
    context_in_dim: int = 512  # input dim of txt_in (T5 text-context features)
    hidden_size: int = 3072  # transformer width
    mlp_ratio: float = 4.0  # MLP hidden size as a multiple of hidden_size
    num_heads: int = 24  # attention heads; must divide hidden_size
    depth: int = 19  # number of DoubleStreamBlocks
    depth_single_blocks: int = 38  # number of SingleStreamBlocks
    axes_dim: tuple = (16, 56, 56)  # per-axis rotary dims; must sum to hidden_size // num_heads
    theta: int = 10_000  # rotary embedding base frequency
    qkv_bias: bool = True  # bias on QKV projections in double blocks
    guidance_embed: bool = True  # embed a guidance scale (guidance-distilled models)
    autoencoder_params: AutoEncoderParams = field(default_factory=AutoEncoderParams)

    def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
        """Override fields from the job config; `tokenizer` is unused here."""
        # context_in_dim is the same as the T5 embedding dimension
        # NOTE(review): this assigns max_t5_encoding_len (a sequence length)
        # to context_in_dim (a feature dimension). The two coincide at 512
        # for t5-v1_1-small — confirm this is intentional and not a mix-up.
        self.context_in_dim = job_config.encoder.max_t5_encoding_len

    def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
        """Return (parameter count, FLOPs); FLOPs is a placeholder 1 for now."""
        # TODO(jianiw): Add the number of flops for the autoencoder
        nparams = sum(p.numel() for p in model.parameters())
        logger.warning("FLUX model haven't implement get_nparams_and_flops() function")
        return nparams, 1
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class FluxModel(nn.Module, ModelProtocol):
    """
    Transformer model for flow matching on sequences.

    Args:
        model_args: FluxModelArgs.

    Attributes:
        model_args (FluxModelArgs): Model configuration arguments.
    """

    def __init__(self, model_args: FluxModelArgs):
        super().__init__()

        self.model_args = model_args
        self.in_channels = model_args.in_channels
        self.out_channels = model_args.out_channels
        # Attention requires an even head split of the hidden dimension.
        if model_args.hidden_size % model_args.num_heads != 0:
            raise ValueError(
                f"Hidden size {model_args.hidden_size} must be divisible by num_heads {model_args.num_heads}"
            )
        # Per-head dim; the rotary axes must partition it exactly.
        pe_dim = model_args.hidden_size // model_args.num_heads
        if sum(model_args.axes_dim) != pe_dim:
            raise ValueError(
                f"Got {model_args.axes_dim} but expected positional dim {pe_dim}"
            )
        self.hidden_size = model_args.hidden_size
        self.num_heads = model_args.num_heads
        self.pe_embedder = EmbedND(
            dim=pe_dim, theta=model_args.theta, axes_dim=model_args.axes_dim
        )
        # Project image latents and text context into the transformer width.
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(model_args.vec_in_dim, self.hidden_size)
        # Guidance embedding is only active for guidance-distilled variants.
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
            if model_args.guidance_embed
            else nn.Identity()
        )
        self.txt_in = nn.Linear(model_args.context_in_dim, self.hidden_size)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=model_args.mlp_ratio,
                    qkv_bias=model_args.qkv_bias,
                )
                for _ in range(model_args.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(
                    self.hidden_size, self.num_heads, mlp_ratio=model_args.mlp_ratio
                )
                for _ in range(model_args.depth_single_blocks)
            ]
        )

        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

    def init_weights(self, buffer_device=None):
        # TODO(jianiw): replace placeholder with real weight init
        # Uniform [0, 0.1) across ALL parameters is a stand-in only.
        for param in self.parameters():
            param.data.uniform_(0, 0.1)

    def forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor | None = None,
    ) -> Tensor:
        """
        Denoise one step: predict the flow update for the packed image
        latents `img`, conditioned on text (`txt`), pooled vector `y`,
        the current `timesteps`, and (optionally) a guidance strength.
        """
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)
        # Global modulation vector: timestep (+ guidance) (+ pooled text).
        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.model_args.guidance_embed:
            if guidance is None:
                raise ValueError(
                    "Didn't get guidance strength for guidance distilled model."
                )
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        # Positional ids for the concatenated (text, image) sequence.
        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        for block in self.double_blocks:
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)

        # Single-stream blocks process text and image as one sequence.
        img = torch.cat((txt, img), 1)
        for block in self.single_blocks:
            img = block(img, vec=vec, pe=pe)
        # Drop the text prefix; keep only the image tokens.
        img = img[:, txt.shape[1] :, ...]

        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
        return img

    @classmethod
    def from_model_args(cls, model_args: FluxModelArgs) -> "FluxModel":
        """
        Initialize a Flux model from a FluxModelArgs object.

        Args:
            model_args (FluxModelArgs): Model configuration arguments.

        Returns:
            FluxModel: FluxModel model.

        """
        return cls(model_args)
|
torchtitan/experiments/flux/tests/test_flux_dataloader.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import sys
|
| 8 |
+
|
| 9 |
+
from torchtitan.config_manager import JobConfig
|
| 10 |
+
from torchtitan.experiments.flux.dataset.flux_dataset import build_flux_dataloader
|
| 11 |
+
from torchtitan.tools.profiling import (
|
| 12 |
+
maybe_enable_memory_snapshot,
|
| 13 |
+
maybe_enable_profiling,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class TestFluxDataLoader:
    """Smoke tests for the Flux data loader."""

    def test_flux_dataloader(self):
        dataset_name = "cc12m"
        batch_size = 32
        world_size = 4
        rank = 0

        num_steps = 10

        # Register the Flux-specific CLI arguments before parsing.
        path = "torchtitan.experiments.flux.flux_argparser"
        sys.argv.append(f"--experimental.custom_args_module={path}")
        config = JobConfig()
        config.maybe_add_custom_args()
        config.parse_args(
            [
                "--training.dataset",
                dataset_name,
                "--training.batch_size",
                str(batch_size),
                "--encoder.t5_encoder",
                "google/t5-v1_1-small",
                "--encoder.clip_encoder",
                "openai/clip-vit-large-patch14",
                "--encoder.max_t5_encoding_len",
                "512",
            ]
        )

        with maybe_enable_profiling(
            config, global_step=0
        ) as torch_profiler, maybe_enable_memory_snapshot(
            config, global_step=0
        ) as memory_profiler:
            dl = iter(self._build_dataloader(config, world_size, rank))

            for i in range(num_steps):
                input_data, labels = next(dl)
                print(f"Step {i} image size: {labels.shape}")
                if torch_profiler:
                    torch_profiler.step()
                if memory_profiler:
                    memory_profiler.step()

                print(len(input_data["clip_tokens"]))
                for k, v in input_data.items():
                    print(f"Step {i} {k} value: {type(v), v.shape}")

                # Each batch carries exactly the two encoder token streams.
                assert len(input_data) == 2  # (clip_encodings, t5_encodings)
                assert labels.shape == (batch_size, 3, 256, 256)

            if torch_profiler:
                torch_profiler.step()
            if memory_profiler:
                memory_profiler.step(exit_ctx=True)

    def test_preprocess(self):
        # TODO
        pass

    def _build_dataloader(self, job_config, world_size, rank):
        return build_flux_dataloader(
            dp_world_size=world_size,
            dp_rank=rank,
            job_config=job_config,
            tokenizer=None,
            infinite=False,
        )
|
torchtitan/experiments/flux/tests/test_generate_image.py
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import math
|
| 8 |
+
import os
|
| 9 |
+
import time
|
| 10 |
+
from typing import Callable
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from einops import rearrange
|
| 14 |
+
|
| 15 |
+
from PIL import ExifTags, Image
|
| 16 |
+
|
| 17 |
+
from torch import Tensor
|
| 18 |
+
|
| 19 |
+
from torchtitan.experiments.flux.dataset.tokenizer import FluxTokenizer
|
| 20 |
+
|
| 21 |
+
from torchtitan.experiments.flux.model.autoencoder import (
|
| 22 |
+
AutoEncoder,
|
| 23 |
+
AutoEncoderParams,
|
| 24 |
+
load_ae,
|
| 25 |
+
)
|
| 26 |
+
from torchtitan.experiments.flux.model.hf_embedder import FluxEmbedder
|
| 27 |
+
|
| 28 |
+
from torchtitan.experiments.flux.model.model import FluxModel, FluxModelArgs
|
| 29 |
+
from torchtitan.experiments.flux.utils import (
|
| 30 |
+
create_position_encoding_for_latents,
|
| 31 |
+
generate_noise_latent,
|
| 32 |
+
pack_latents,
|
| 33 |
+
preprocess_flux_data,
|
| 34 |
+
unpack_latents,
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def time_shift(mu: float, sigma: float, t: Tensor):
    """Warp timesteps t in (0, 1] by shift strength mu and exponent sigma."""
    exp_mu = math.exp(mu)
    return exp_mu / (exp_mu + (1 / t - 1) ** sigma)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def get_lin_function(
    x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
) -> Callable[[float], float]:
    """Return the line through (x1, y1) and (x2, y2) as a callable."""
    slope = (y2 - y1) / (x2 - x1)
    intercept = y1 - slope * x1

    def line(x: float) -> float:
        return slope * x + intercept

    return line
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> list[float]:
    """Build the denoising timestep schedule from 1 down to 0."""
    # num_steps + 1 points so the schedule includes the final zero.
    timesteps = torch.linspace(1, 0, num_steps + 1)

    # Shift the schedule toward high timesteps for higher-resolution images.
    if shift:
        # mu is linearly interpolated from the image sequence length.
        estimate_mu = get_lin_function(y1=base_shift, y2=max_shift)
        timesteps = time_shift(estimate_mu(image_seq_len), 1.0, timesteps)

    return timesteps.tolist()
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class TestGenerateImage:
    def test_generate_image(self):
        """
        Run a forward pass of flux model to generate an image.
        """
        # Fixed generation settings for the smoke test.
        name = "flux-dev"
        img_width = 512
        img_height = 512
        seed = None
        prompt = (
            "a photo of a forest with mist swirling around the tree trunks. The word "
            '"FLUX" is painted over it in big, red brush strokes with visible texture'
        )
        device = "cuda"
        num_steps = None
        loop = False
        guidance = 3.5
        output_dir = "output"
        add_sampling_metadata = True

        # "|" separates a primary prompt from optional additional prompts.
        prompt = prompt.split("|")
        if len(prompt) == 1:
            prompt = prompt[0]
            additional_prompts = None
        else:
            additional_prompts = prompt[1:]
            prompt = prompt[0]

        assert not (
            (additional_prompts is not None) and loop
        ), "Do not provide additional prompts and set loop to True"

        torch_device = torch.device(device)
        if num_steps is None:
            num_steps = 30

        # allow for packing and conversion to latent space
        img_height = 16 * (img_height // 16)
        img_width = 16 * (img_width // 16)

        # init all components
        model = FluxModel(FluxModelArgs()).to(device=torch_device, dtype=torch.bfloat16)

        ae = load_ae(
            ckpt_path="assets/autoencoder/ae.safetensors",
            autoencoder_params=AutoEncoderParams(),
            device=torch_device,
            dtype=torch.bfloat16,
        )
        clip_tokenizer = FluxTokenizer(
            model_path="openai/clip-vit-large-patch14", max_length=77
        )
        t5_tokenizer = FluxTokenizer(model_path="google/t5-v1_1-small", max_length=512)
        clip_encoder = FluxEmbedder(version="openai/clip-vit-large-patch14").to(
            torch_device, dtype=torch.bfloat16
        )
        t5_encoder = FluxEmbedder(version="google/t5-v1_1-small").to(
            torch_device, dtype=torch.bfloat16
        )

        rng = torch.Generator(device="cpu")

        # No explicit seed: draw one so the output filename is reproducible.
        if seed is None:
            seed = rng.seed()
        print(f"Generating with seed {seed}:\n{prompt}")
        t0 = time.perf_counter()
        # NOTE(review): output_dir is never created here — img.save will fail
        # if "output/" does not already exist. Confirm the directory is
        # provisioned elsewhere.
        output_name = os.path.join(output_dir, f"img_{seed}.jpg")

        # Tokenize the prompt, on CPU
        clip_tokens = clip_tokenizer.encode(prompt)
        t5_tokens = t5_tokenizer.encode(prompt)

        # Encode tokens into conditioning tensors (autoencoder unused here).
        batch = preprocess_flux_data(
            device=torch_device,
            dtype=torch.bfloat16,
            autoencoder=None,
            clip_encoder=clip_encoder,
            t5_encoder=t5_encoder,
            batch={
                "clip_tokens": clip_tokens,
                "t5_tokens": t5_tokens,
            },
        )

        img = self._generate_images(
            device=torch_device,
            dtype=torch.bfloat16,
            model=model,
            decoder=ae,
            img_width=img_width,
            img_height=img_height,
            denoising_steps=num_steps,
            seed=seed,
            clip_encodings=batch["clip_encodings"],
            t5_encodings=batch["t5_encodings"],
            guidance=guidance,
        )

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        t1 = time.perf_counter()

        print(f"Done in {t1 - t0:.1f}s.")

        self._save_image(name, output_name, img, add_sampling_metadata, prompt)

    def _generate_images(
        self,
        device: torch.device,
        dtype: torch.dtype,
        model: FluxModel,
        decoder: AutoEncoder,
        # image params:
        img_width: int,
        img_height: int,
        # sampling params:
        denoising_steps: int,
        seed: int,
        clip_encodings: torch.Tensor,
        t5_encodings: torch.Tensor,
        guidance: float = 4.0,
    ):
        """Run the denoising loop from pure noise and decode to pixels."""

        bsz = clip_encodings.shape[0]
        latents = generate_noise_latent(bsz, img_height, img_width, device, dtype, seed)
        _, latent_channels, latent_height, latent_width = latents.shape

        # create denoising schedule
        # NOTE(review): get_schedule's second argument is an image sequence
        # length, but latent_channels (a channel count) is passed here —
        # presumably this should be derived from latent_height * latent_width;
        # confirm against the reference sampler.
        timesteps = get_schedule(denoising_steps, latent_channels, shift=True)

        # create positional encodings
        POSITION_DIM = 3  # constant for Flux flow model
        latent_pos_enc = create_position_encoding_for_latents(
            bsz, latent_height, latent_width, POSITION_DIM
        ).to(latents)
        # Text tokens carry zero positional ids (one per T5 token).
        text_pos_enc = torch.zeros(bsz, t5_encodings.shape[1], POSITION_DIM).to(latents)

        # convert img-like latents into sequences of patches
        latents = pack_latents(latents)

        # this is ignored for schnell
        guidance_vec = torch.full((bsz,), guidance, device=device, dtype=dtype)
        # Euler integration over consecutive schedule points.
        for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
            t_vec = torch.full((bsz,), t_curr, dtype=dtype, device=device)
            pred = model(
                img=latents,
                img_ids=latent_pos_enc,
                txt=t5_encodings,
                txt_ids=text_pos_enc,
                y=clip_encodings,
                timesteps=t_vec,
                guidance=guidance_vec,
            )

            latents = latents + (t_prev - t_curr) * pred

        # convert sequences of patches into img-like latents
        latents = unpack_latents(latents, latent_height, latent_width)

        img = decoder.decode(latents)
        return img

    def _save_image(
        self,
        name: str,
        output_name: str,
        x: torch.Tensor,
        add_sampling_metadata: bool,
        prompt: str,
    ):
        """Save the first image of batch `x` as a JPEG with EXIF metadata."""
        print(f"Saving {output_name}")
        # bring into PIL format and save
        x = x.clamp(-1, 1)
        x = rearrange(x[0], "c h w -> h w c")

        # Map [-1, 1] floats to [0, 255] bytes.
        img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())

        exif_data = Image.Exif()
        exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
        exif_data[ExifTags.Base.Make] = "Black Forest Labs"
        exif_data[ExifTags.Base.Model] = name
        if add_sampling_metadata:
            exif_data[ExifTags.Base.ImageDescription] = prompt
        img.save(output_name, exif=exif_data, quality=95, subsampling=0)
|
torchtitan/experiments/flux/train_configs/debug_model.toml
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
[job]
|
| 3 |
+
dump_folder = "./outputs"
|
| 4 |
+
description = "Flux debug model"
|
| 5 |
+
print_args = false
|
| 6 |
+
use_for_integration_test = true
|
| 7 |
+
|
| 8 |
+
[profiling]
|
| 9 |
+
enable_profiling = false
|
| 10 |
+
save_traces_folder = "profile_trace"
|
| 11 |
+
profile_freq = 10
|
| 12 |
+
enable_memory_snapshot = false
|
| 13 |
+
save_memory_snapshot_folder = "memory_snapshot"
|
| 14 |
+
|
| 15 |
+
[metrics]
|
| 16 |
+
log_freq = 1
|
| 17 |
+
disable_color_printing = false
|
| 18 |
+
enable_tensorboard = false
|
| 19 |
+
save_tb_folder = "tb"
|
| 20 |
+
enable_wandb = false
|
| 21 |
+
|
| 22 |
+
[model]
|
| 23 |
+
name = "flux"
|
| 24 |
+
flavor = "flux-debug"
|
| 25 |
+
norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm
|
| 26 |
+
# test tokenizer.model, for debug purpose only
|
| 27 |
+
# tokenizer_path = "./tests/assets/test_tiktoken.model"
|
| 28 |
+
# converters = "float8"
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
[optimizer]
|
| 32 |
+
name = "AdamW"
|
| 33 |
+
lr = 8e-4
|
| 34 |
+
eps = 1e-8
|
| 35 |
+
|
| 36 |
+
[lr_scheduler]
|
| 37 |
+
warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps
|
| 38 |
+
decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps
|
| 39 |
+
decay_type = "linear"
|
| 40 |
+
lr_min = 0.0
|
| 41 |
+
|
| 42 |
+
[training]
|
| 43 |
+
batch_size = 32
|
| 44 |
+
seq_len = 512
|
| 45 |
+
max_norm = 1.0 # grad norm clipping
|
| 46 |
+
steps = 10
|
| 47 |
+
compile = false
|
| 48 |
+
dataset = "cc12m"
|
| 49 |
+
guidance = 3.5
|
| 50 |
+
seed = 0
|
| 51 |
+
|
| 52 |
+
[encoder]
|
| 53 |
+
t5_encoder="google/t5-v1_1-small"
|
| 54 |
+
clip_encoder="openai/clip-vit-large-patch14"
|
| 55 |
+
max_t5_encoding_len=512
|
| 56 |
+
auto_encoder_path="torchtitan/experiments/flux/assets/autoencoder/ae.safetensors" # Autoencoder to use for image
|
| 57 |
+
|
| 58 |
+
[parallelism]
|
| 59 |
+
data_parallel_replicate_degree = 1
|
| 60 |
+
data_parallel_shard_degree = 1
|
| 61 |
+
fsdp_reshard_after_forward = "default" # default / never / always
|
| 62 |
+
tensor_parallel_degree = 1
|
| 63 |
+
enable_async_tensor_parallel = false
|
| 64 |
+
pipeline_parallel_degree = 1
|
| 65 |
+
context_parallel_degree = 1
|
| 66 |
+
|
| 67 |
+
[experimental]
|
| 68 |
+
custom_args_module = "torchtitan.experiments.flux.flux_argparser"
|
torchtitan/experiments/kernels/triton_mg_group_gemm/benchmark.py
ADDED
|
@@ -0,0 +1,630 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 8 |
+
# All rights reserved.
|
| 9 |
+
#
|
| 10 |
+
# Benchmark comparing reference PyTorch vs optimized M*G group GEMM implementation
|
| 11 |
+
|
| 12 |
+
import argparse
|
| 13 |
+
import logging
|
| 14 |
+
import time
|
| 15 |
+
|
| 16 |
+
# from typing import Dict, List, Optional, Tuple
|
| 17 |
+
|
| 18 |
+
import matplotlib.pyplot as plt
|
| 19 |
+
import numpy as np
|
| 20 |
+
import torch
|
| 21 |
+
import triton
|
| 22 |
+
|
| 23 |
+
# import triton.language as tl
|
| 24 |
+
|
| 25 |
+
# Configure logging
|
| 26 |
+
# Configure logging: timestamped INFO-level messages for all benchmark output.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

# Try to import the optimized implementations
try:
    from torchao_pr.mg_grouped_gemm import grouped_gemm_forward

except ImportError:
    logging.error(
        "Error importing MG grouped GEMM modules. Make sure the implementation files are in the correct path."
    )
    # Re-raise so a missing kernel fails fast instead of silently benchmarking nothing.
    raise
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def compute_reference_forward(x, w, m_sizes):
    """
    Reference PyTorch implementation of the M*G grouped GEMM forward pass.

    Each group owns a contiguous slab of rows of ``x``; every group is
    multiplied by the same shared weight matrix ``w``. Rows not covered by
    ``m_sizes`` (if the sizes sum to less than M) stay zero.

    Args:
        x (torch.Tensor): Input tensor of shape (M, K)
        w (torch.Tensor): Weight tensor of shape (N, K)
        m_sizes (torch.Tensor): Group sizes tensor of shape (G)

    Returns:
        torch.Tensor: Output tensor of shape (M, N)
    """
    out = torch.zeros((x.shape[0], w.shape[0]), dtype=x.dtype, device=x.device)

    offset = 0
    for size in m_sizes.tolist():
        if size > 0:
            rows = slice(offset, offset + size)
            # Per-group GEMM: (size, K) @ (K, N) -> (size, N)
            out[rows] = x[rows] @ w.T
            offset += size

    return out
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["N"],  # We'll vary the output dimension
        x_vals=[1024, 2048, 4096, 8192, 16384],  # Different output dimensions to test
        # x_vals=[8192, 16384],
        line_arg="provider",  # We'll compare different providers
        line_vals=["pytorch_reference", "M*G grouped GEMM"],
        line_names=["PyTorch Reference", "M*G grouped Kernel"],
        styles=[("blue", "-"), ("red", "-")],
        ylabel="TFLOPS",  # We'll measure TFLOPS
        plot_name="mg_grouped_gemm_comparison",
        args={
            "M": 8192,  # Batch dimension, fixed for all tests
            "K": 7168,  # Hidden dimension, fixed for all tests
            "G": 8,  # Number of groups
            "dtype": torch.float16,
            "device": "cuda",
        },
    )
)
def benchmark_forward(M, K, N, G, provider, dtype=torch.float16, device="cuda"):
    """
    Benchmark the forward pass of the grouped GEMM implementation.

    Args:
        M (int): Total batch size dimension
        K (int): Hidden dimension
        N (int): Output dimension
        G (int): Number of groups
        provider (str): Provider to use ('pytorch_reference' or the kernel)
        dtype (torch.dtype): Data type to use
        device (str): Device to use

    Returns:
        float: Performance in TFLOPS
    """
    # Create group sizes for M dimension (balanced across groups); the first
    # `remainder` groups each take one extra row so the sizes sum to M.
    base_size, remainder = divmod(M, G)
    m_sizes = torch.tensor(
        [base_size + (1 if i < remainder else 0) for i in range(G)],
        device=device,
        dtype=torch.int32,
    )

    # Use logging for consistency with the rest of this module.
    logging.info(f"N: {N}, M: {M}, K: {K}, G: {G}, dtype: {dtype}, device: {device}")

    # Create input and weight tensors
    x = torch.randn(M, K, dtype=dtype, device=device)
    w = torch.randn(N, K, dtype=dtype, device=device)

    # Select the implementation once so warmup and timing code is shared
    # instead of duplicated per provider.
    fn = compute_reference_forward if provider == "pytorch_reference" else grouped_gemm_forward

    # Warmup (also triggers any lazy kernel compilation)
    torch.cuda.synchronize()
    fn(x, w, m_sizes)
    torch.cuda.synchronize()

    # Timed region: average over 10 runs, synchronizing once at the end so
    # we measure sustained throughput rather than per-launch latency.
    start_time = time.time()
    for _ in range(10):
        fn(x, w, m_sizes)
    torch.cuda.synchronize()
    avg_time = (time.time() - start_time) / 10

    # For GEMM: 2 * M * N * K FLOPs (multiply-add counts as 2 FLOPs)
    flops = 2 * M * N * K

    # Convert to TFLOPS (tera-FLOPS)
    return flops / avg_time / 1e12
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["G"],  # We'll vary the number of groups
        x_vals=[1, 2, 4, 8, 16],  # Different numbers of groups to test
        line_arg="provider",  # We'll compare different providers
        line_vals=["pytorch_reference", "optimized_kernel"],
        line_names=["PyTorch Reference", "Optimized Kernel"],
        styles=[("blue", "-"), ("red", "-")],
        ylabel="TFLOPS",  # We'll measure TFLOPS
        plot_name="mg_grouped_gemm_group_scaling",
        args={
            "M": 8192,  # Batch dimension, fixed for all tests
            "K": 4096,  # Hidden dimension, fixed for all tests
            "N": 8192,  # Output dimension, fixed for all tests
            "dtype": torch.float16,
            "device": "cuda",
        },
    )
)
def benchmark_forward_groups(M, K, N, G, provider, dtype=torch.float16, device="cuda"):
    """
    Benchmark how performance scales with the number of groups.

    Args:
        M (int): Total batch size dimension
        K (int): Hidden dimension
        N (int): Output dimension
        G (int): Number of groups
        provider (str): Provider to use ('pytorch_reference' or 'optimized_kernel')
        dtype (torch.dtype): Data type to use
        device (str): Device to use

    Returns:
        float: Performance in TFLOPS
    """
    # Balanced group sizes over M; first `remainder` groups get one extra row.
    base_size, remainder = divmod(M, G)
    m_sizes = torch.tensor(
        [base_size + (1 if i < remainder else 0) for i in range(G)],
        device=device,
        dtype=torch.int32,
    )

    # Create input and weight tensors
    x = torch.randn(M, K, dtype=dtype, device=device)
    w = torch.randn(N, K, dtype=dtype, device=device)

    # Select the implementation once; the warmup/timing logic below is
    # shared instead of being duplicated per provider.
    fn = compute_reference_forward if provider == "pytorch_reference" else grouped_gemm_forward

    # Warmup
    torch.cuda.synchronize()
    fn(x, w, m_sizes)
    torch.cuda.synchronize()

    # Average over 10 runs; sync once at the end of the timed region.
    start_time = time.time()
    for _ in range(10):
        fn(x, w, m_sizes)
    torch.cuda.synchronize()
    avg_time = (time.time() - start_time) / 10

    # Calculate FLOPs and TFLOPS
    flops = 2 * M * N * K
    return flops / avg_time / 1e12
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["group_balance"],  # We'll vary the group balance factor
        x_vals=[
            0.0,
            0.25,
            0.5,
            0.75,
            0.9,
        ],  # Different imbalance factors (0 = balanced, 1 = max imbalance)
        line_arg="provider",  # We'll compare different providers
        line_vals=["pytorch_reference", "optimized_kernel"],
        line_names=["PyTorch Reference", "Optimized Kernel"],
        styles=[("blue", "-"), ("red", "-")],
        ylabel="TFLOPS",  # We'll measure TFLOPS
        plot_name="mg_grouped_gemm_imbalance",
        args={
            "M": 8192,  # Batch dimension, fixed for all tests
            "K": 4096,  # Hidden dimension, fixed for all tests
            "N": 8192,  # Output dimension, fixed for all tests
            "G": 4,  # Number of groups
            "dtype": torch.float16,
            "device": "cuda",
        },
    )
)
def benchmark_imbalance(
    M, K, N, G, group_balance, provider, dtype=torch.float16, device="cuda"
):
    """
    Benchmark how performance is affected by imbalanced group sizes.

    Args:
        M (int): Total batch size dimension
        K (int): Hidden dimension
        N (int): Output dimension
        G (int): Number of groups
        group_balance (float): Balance factor from 0 to 1 (0 = balanced, 1 = max imbalance)
        provider (str): Provider to use ('pytorch_reference' or 'optimized_kernel')
        dtype (torch.dtype): Data type to use
        device (str): Device to use

    Returns:
        float: Performance in TFLOPS
    """
    if group_balance == 0 or G == 1:
        # Balanced case. A single group is trivially balanced; routing G == 1
        # here also avoids the G - 1 == 0 division in the imbalanced branch.
        base_size, remainder = divmod(M, G)
        M_sizes = [base_size + (1 if i < remainder else 0) for i in range(G)]
    else:
        # Imbalanced case: the first half of the groups get more rows, the
        # second half fewer, with the skew scaled by `group_balance`.
        remaining = M
        M_sizes = []
        for g in range(G):
            # Size this group would get if the rest were split evenly.
            balanced_size = remaining // (G - g)
            frac = g / (G - 1)  # position of this group in [0, 1]

            if g < G // 2:
                # First half of groups get more
                size = balanced_size + int(balanced_size * group_balance * (1 - frac))
            else:
                # Second half of groups get less
                size = balanced_size - int(balanced_size * group_balance * (frac - 0.5))

            # Ensure we don't go below 1 or take more than remaining
            size = max(1, min(size, remaining))
            M_sizes.append(size)
            remaining -= size

        # Handle any remaining elements
        if remaining > 0:
            M_sizes[-1] += remaining

    m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)

    # Create input and weight tensors
    x = torch.randn(M, K, dtype=dtype, device=device)
    w = torch.randn(N, K, dtype=dtype, device=device)

    # Shared warmup/timing path for both providers (previously duplicated).
    fn = compute_reference_forward if provider == "pytorch_reference" else grouped_gemm_forward

    torch.cuda.synchronize()
    fn(x, w, m_sizes)  # Warmup
    torch.cuda.synchronize()

    start_time = time.time()
    for _ in range(10):
        fn(x, w, m_sizes)
    torch.cuda.synchronize()
    avg_time = (time.time() - start_time) / 10

    # Calculate FLOPs and TFLOPS
    flops = 2 * M * N * K
    return flops / avg_time / 1e12
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
def benchmark_model_configs():
    """
    Benchmark common model configurations used in DeepSeek-like models.

    Runs both the PyTorch reference and the optimized kernel on a fixed set
    of (M, K, N, G) shapes, recording average latency, TFLOPS, speedup and
    peak CUDA memory, then logs a summary table.

    Returns:
        list[dict]: One result record per configuration.
    """
    # Model configurations: (M, K, N, G)
    configs = [
        (8192, 7168, 4096, 4),  # Config 1
        (8192, 2048, 7168, 4),  # Config 2
        (4096, 7168, 4096, 8),  # Config 3
        (4096, 2048, 7168, 8),  # Config 4
    ]

    results = []

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float16

    def _measure(fn, x, w, m_sizes):
        # Shared measurement helper (was duplicated for reference/kernel):
        # warmup, then time 10 runs and track peak memory of the timed region.
        torch.cuda.synchronize()
        fn(x, w, m_sizes)  # Warmup
        torch.cuda.synchronize()

        torch.cuda.reset_peak_memory_stats()
        start = time.time()
        for _ in range(10):
            fn(x, w, m_sizes)
        torch.cuda.synchronize()
        elapsed = (time.time() - start) / 10
        peak_mb = torch.cuda.max_memory_allocated() / (1024**2)  # MB
        return elapsed, peak_mb

    for config_idx, (M, K, N, G) in enumerate(configs):
        logging.info(f"\n===== Benchmarking DeepSeek Config {config_idx + 1} =====")
        logging.info(f"M={M}, K={K}, N={N}, G={G}")

        # Create group sizes for M dimension (balanced across groups)
        base_size, remainder = divmod(M, G)
        m_sizes = torch.tensor(
            [base_size + (1 if i < remainder else 0) for i in range(G)],
            device=device,
            dtype=torch.int32,
        )

        # Create tensors
        x = torch.randn(M, K, dtype=dtype, device=device)
        w = torch.randn(N, K, dtype=dtype, device=device)

        logging.info("Benchmarking PyTorch reference...")
        pt_time, pt_memory = _measure(compute_reference_forward, x, w, m_sizes)

        logging.info("Benchmarking optimized kernel...")
        opt_time, opt_memory = _measure(grouped_gemm_forward, x, w, m_sizes)

        # Calculate FLOPs and speedup
        flops = 2 * M * N * K
        pt_tflops = flops / pt_time / 1e12
        opt_tflops = flops / opt_time / 1e12
        speedup = pt_time / opt_time

        # Store results
        results.append(
            {
                "config": f"Config {config_idx + 1}",
                "dimensions": f"M={M}, K={K}, N={N}, G={G}",
                "pt_time_ms": pt_time * 1000,
                "opt_time_ms": opt_time * 1000,
                "pt_tflops": pt_tflops,
                "opt_tflops": opt_tflops,
                "speedup": speedup,
                "pt_memory_mb": pt_memory,
                "opt_memory_mb": opt_memory,
                "memory_savings": (
                    (pt_memory - opt_memory) / pt_memory * 100 if pt_memory > 0 else 0
                ),
            }
        )

        logging.info(
            f"PyTorch Reference: {pt_time * 1000:.2f} ms, {pt_tflops:.2f} TFLOPS, {pt_memory:.2f} MB"
        )
        logging.info(
            f"Optimized Kernel: {opt_time * 1000:.2f} ms, {opt_tflops:.2f} TFLOPS, {opt_memory:.2f} MB"
        )
        logging.info(
            f"Speedup: {speedup:.2f}x, Memory savings: {results[-1]['memory_savings']:.2f}%"
        )

    # Print summary table
    logging.info("\n===== Benchmark Results Summary =====")
    logging.info(
        f"{'Config':<10} | {'Time (ms)':<20} | {'TFLOPS':<20} | {'Speedup':<10} | {'Memory (MB)':<20} | {'Memory Saved':<12}"
    )
    logging.info(
        f"{'':<10} | {'PyTorch':<9} {'Kernel':<9} | {'PyTorch':<9} {'Kernel':<9} | {'':<10} | "
        f"{'PyTorch':<9} {'Kernel':<9} | {'':<12}"
    )
    logging.info("-" * 100)

    for result in results:
        logging.info(
            f"{result['config']:<10} | "
            f"{result['pt_time_ms']:<9.2f} {result['opt_time_ms']:<9.2f} | "
            f"{result['pt_tflops']:<9.2f} {result['opt_tflops']:<9.2f} | "
            f"{result['speedup']:<10.2f} | "
            f"{result['pt_memory_mb']:<9.2f} {result['opt_memory_mb']:<9.2f} | "
            f"{result['memory_savings']:<12.2f}%"
        )

    return results
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
def plot_benchmark_results(results):
    """
    Plot benchmark results as bar charts.

    Left panel: TFLOPS of both implementations per configuration.
    Right panel: speedup factor of the optimized kernel, annotated per bar.
    The figure is saved to 'mg_grouped_gemm_benchmark_results.png'.
    """
    # Pull the series we plot out of the result records.
    labels = [entry["config"] for entry in results]
    ref_tflops = [entry["pt_tflops"] for entry in results]
    kernel_tflops = [entry["opt_tflops"] for entry in results]
    speedup_vals = [entry["speedup"] for entry in results]

    fig, (tflops_ax, speedup_ax) = plt.subplots(1, 2, figsize=(12, 5))

    positions = np.arange(len(labels))
    bar_width = 0.35

    # Side-by-side TFLOPS bars for the two implementations.
    tflops_ax.bar(positions - bar_width / 2, ref_tflops, bar_width, label="PyTorch Reference")
    tflops_ax.bar(positions + bar_width / 2, kernel_tflops, bar_width, label="Optimized Kernel")
    tflops_ax.set_xlabel("Model Configuration")
    tflops_ax.set_ylabel("TFLOPS")
    tflops_ax.set_title("Performance Comparison (Higher is Better)")
    tflops_ax.set_xticks(positions)
    tflops_ax.set_xticklabels(labels)
    tflops_ax.legend()
    tflops_ax.grid(axis="y", linestyle="--", alpha=0.7)

    # Speedup bars, one per configuration.
    speedup_ax.bar(positions, speedup_vals, width=0.6, color="green")
    speedup_ax.set_xlabel("Model Configuration")
    speedup_ax.set_ylabel("Speedup (x)")
    speedup_ax.set_title("Speedup Factor (Higher is Better)")
    speedup_ax.set_xticks(positions)
    speedup_ax.set_xticklabels(labels)
    speedup_ax.grid(axis="y", linestyle="--", alpha=0.7)

    # Annotate each speedup bar with its value.
    for idx, factor in enumerate(speedup_vals):
        speedup_ax.text(idx, factor + 0.1, f"{factor:.2f}x", ha="center")

    plt.tight_layout()
    plt.savefig("mg_grouped_gemm_benchmark_results.png")
    logging.info(
        "Benchmark results plot saved to 'mg_grouped_gemm_benchmark_results.png'"
    )
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
def compare_mg_implementations():
    """
    Combine the M*G and N*G benchmark results for comparison.

    Loads the CSVs produced by earlier benchmark runs and writes a two-panel
    comparison plot to 'mg_vs_ng_comparison.png'. Best-effort: any failure
    (missing files, unexpected schema) is logged and swallowed.
    """
    # Only run this if both NG and MG benchmarks have been run
    try:
        import pandas as pd

        # Try to load previous benchmark results
        # NOTE(review): assumes both CSVs have 'config', 'speedup',
        # 'implementation' and 'tflops' columns — verify against the files
        # actually written by the Triton perf reports.
        mg_results = pd.read_csv("mg_grouped_gemm_benchmark_results.csv")
        ng_results = pd.read_csv("ng_grouped_gemm_benchmark_results.csv")

        # Create comparison plot
        fig, axes = plt.subplots(1, 2, figsize=(14, 6))

        # Plot speedup comparison
        # NOTE(review): groupby() sorts by key, while unique() preserves file
        # order — tick labels may be misaligned if the CSV is not sorted.
        configs = mg_results["config"].unique()
        mg_speedups = mg_results.groupby("config")["speedup"].mean()
        ng_speedups = ng_results.groupby("config")["speedup"].mean()

        x = np.arange(len(configs))
        width = 0.35

        axes[0].bar(x - width / 2, mg_speedups, width, label="M*G Grouping")
        axes[0].bar(x + width / 2, ng_speedups, width, label="N*G Grouping")
        axes[0].set_xlabel("Model Configuration")
        axes[0].set_ylabel("Speedup (x)")
        axes[0].set_title("Speedup Comparison: M*G vs N*G")
        axes[0].set_xticks(x)
        axes[0].set_xticklabels(configs)
        axes[0].legend()
        axes[0].grid(axis="y", linestyle="--", alpha=0.7)

        # Plot TFLOPS comparison for optimized kernels
        mg_tflops = (
            mg_results[mg_results["implementation"] == "optimized"]
            .groupby("config")["tflops"]
            .mean()
        )
        ng_tflops = (
            ng_results[ng_results["implementation"] == "optimized"]
            .groupby("config")["tflops"]
            .mean()
        )

        axes[1].bar(x - width / 2, mg_tflops, width, label="M*G Grouping")
        axes[1].bar(x + width / 2, ng_tflops, width, label="N*G Grouping")
        axes[1].set_xlabel("Model Configuration")
        axes[1].set_ylabel("TFLOPS")
        axes[1].set_title("Performance Comparison: M*G vs N*G")
        axes[1].set_xticks(x)
        axes[1].set_xticklabels(configs)
        axes[1].legend()
        axes[1].grid(axis="y", linestyle="--", alpha=0.7)

        plt.tight_layout()
        plt.savefig("mg_vs_ng_comparison.png")
        logging.info("Comparison plot saved to 'mg_vs_ng_comparison.png'")

    except Exception as e:
        # Deliberately broad: this comparison is optional, so any failure is
        # reported as a hint rather than crashing the benchmark run.
        logging.error(f"Could not create comparison plot: {e}")
        logging.info(
            "Run both M*G and N*G benchmarks first to generate comparison plots"
        )
|
| 584 |
+
|
| 585 |
+
|
| 586 |
+
if __name__ == "__main__":
    # CLI entry point: each flag selects an independent benchmark suite.
    parser = argparse.ArgumentParser(
        description="Benchmark M*G Grouped GEMM implementations"
    )
    parser.add_argument("--run-all", action="store_true", help="Run all benchmarks")
    parser.add_argument(
        "--triton-bench", action="store_true", help="Run Triton performance reports"
    )
    parser.add_argument(
        "--model-configs", action="store_true", help="Benchmark model configurations"
    )
    parser.add_argument(
        "--compare-mg-ng",
        action="store_true",
        help="Compare M*G and N*G implementations",
    )
    args = parser.parse_args()

    # Check if CUDA is available
    if not torch.cuda.is_available():
        logging.error(
            "CUDA is not available. This benchmark requires a CUDA-capable GPU."
        )
        # `exit()` is the site/interactive helper and may be absent under -S;
        # raising SystemExit is the robust way to set the exit code.
        raise SystemExit(1)

    if args.run_all or args.model_configs:
        # Benchmark model configurations
        logging.info("Running benchmark for model configurations...")
        results = benchmark_model_configs()
        plot_benchmark_results(results)

    if args.run_all or args.triton_bench:
        # Run Triton performance reports
        logging.info("Running Triton performance reports...")
        benchmark_forward.run(save_path="mg_grouped_gemm_benchmark_results")
        benchmark_forward_groups.run(save_path="mg_grouped_gemm_benchmark_results")
        benchmark_imbalance.run(save_path="mg_grouped_gemm_benchmark_results")
        logging.info(
            "Triton performance reports saved to 'mg_grouped_gemm_benchmark_results' directory"
        )

    if args.run_all or args.compare_mg_ng:
        # Compare M*G and N*G implementations
        logging.info("Comparing M*G and N*G implementations...")
        compare_mg_implementations()
|
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/unit_test_forwards.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# pyre-unsafe
|
| 8 |
+
import logging
|
| 9 |
+
import unittest
|
| 10 |
+
from typing import Tuple
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
|
| 15 |
+
from mg_grouped_gemm import grouped_gemm_forward
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class TestMG_GroupedGEMM(unittest.TestCase):
    """Unit tests for the M*G grouped GEMM forward kernel (BF16, CUDA)."""

    def setUp(self) -> None:
        # Fixed seed so the randomized inputs are reproducible across runs.
        torch.manual_seed(2020)

    def _run_grouped_gemm_test(
        self,
        shape: Tuple[int, int, int, int],
        device: torch.device,
        dtype: torch.dtype = torch.bfloat16,
        atol: float = 1e-5,
        rtol: float = 1.6e-2,
    ) -> None:
        """Run the kernel on equal-sized groups and compare with a per-group matmul."""
        G, M, N, K = shape
        # In M*G grouping, input is [M*G, K] and weights are [N*G, K]
        inputs = torch.randn(M * G, K, dtype=dtype, device=device)
        weights = torch.randn(N * G, K, dtype=dtype, device=device)

        # Equal-sized groups: every group owns exactly M rows of the input.
        m_sizes = torch.full((G,), M, device=device, dtype=torch.int32)

        result = grouped_gemm_forward(inputs, weights, m_sizes)
        self.assertTrue(result.shape == (M * G, N))

        # Reference: multiply each group's row slab by its own weight block.
        expected_result = torch.zeros(M * G, N, dtype=dtype, device=device)
        for g in range(G):
            row_lo, row_hi = g * M, (g + 1) * M
            w_block = weights[N * g : N * (g + 1), :]
            expected_result[row_lo:row_hi, :] = inputs[row_lo:row_hi, :] @ w_block.T

        # Convert result to match input dtype if needed, then compare.
        torch.testing.assert_close(result.to(dtype), expected_result, atol=atol, rtol=rtol)

    def test_MG_grouped_gemm_bf16(self) -> None:
        for G in (1, 4, 16):
            for M in (128, 512, 1024):
                print(f"Testing BF16 M*G GroupGeMM with G={G}, M={M}")
                self._run_grouped_gemm_test(
                    (G, M, 1024, 1024),
                    torch.device("cuda"),
                    dtype=torch.bfloat16,
                    atol=1e-5,
                    rtol=1.6e-2,
                )

    def test_MG_grouped_gemm_deepseek_shapes(self) -> None:
        """Test with shapes from Deepseek model."""
        deepseek_shapes = [
            (4, 2048, 4096, 7168),  # G, M, N, K
            (4, 2048, 7168, 2048),
            (8, 512, 4096, 7168),
            (8, 512, 7168, 2048),
        ]

        device = torch.device("cuda")

        for G, M, N, K in deepseek_shapes:
            print(f"Testing BF16 M*G Deepseek shape: G={G}, M={M}, N={N}, K={K}")
            self._run_grouped_gemm_test(
                (G, M, N, K), device, dtype=torch.bfloat16, atol=1e-5, rtol=1.6e-2
            )
|