Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- fla/models/delta_net/__pycache__/__init__.cpython-312.pyc +0 -0
- fla/models/delta_net/__pycache__/configuration_delta_net.cpython-312.pyc +0 -0
- fla/models/forgetting_transformer/__pycache__/configuration_forgetting_transformer.cpython-312.pyc +0 -0
- fla/models/gla/__pycache__/modeling_gla.cpython-312.pyc +0 -0
- fla/models/gla/configuration_gla.py +95 -0
- fla/models/linear_attn/__pycache__/modeling_linear_attn.cpython-312.pyc +0 -0
- fla/models/retnet/__pycache__/configuration_retnet.cpython-312.pyc +0 -0
- fla/models/rwkv6/__pycache__/__init__.cpython-312.pyc +0 -0
- fla/models/rwkv7/__pycache__/__init__.cpython-312.pyc +0 -0
- fla/models/samba/__pycache__/configuration_samba.cpython-312.pyc +0 -0
- fla/models/samba/__pycache__/modeling_samba.cpython-312.pyc +0 -0
- fla/models/transformer/__pycache__/__init__.cpython-312.pyc +0 -0
- fla/models/transformer_top/__pycache__/__init__.cpython-312.pyc +0 -0
- fla/models/transformer_top/__pycache__/configuration_transformer.cpython-312.pyc +0 -0
- fla/modules/__pycache__/activations.cpython-312.pyc +0 -0
- fla/modules/__pycache__/convolution.cpython-312.pyc +0 -0
- fla/modules/__pycache__/fused_bitlinear.cpython-312.pyc +0 -0
- fla/modules/__pycache__/mlp.cpython-312.pyc +0 -0
- fla/modules/__pycache__/rotary.cpython-312.pyc +0 -0
- tb/20260101-0922/wandb/run-20260101_092219--dsmtp.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202601010919/files/requirements.txt +169 -0
- tb/20260101-0922/wandb/run-20260101_092219--dsmtp.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202601010919/files/wandb-metadata.json +146 -0
- torchtitan/experiments/deepseek_v3/attn_mask_utils.py +397 -0
- torchtitan/experiments/deepseek_v3/inference.sh +15 -0
- torchtitan/experiments/deepseek_v3/model.py +1325 -0
- torchtitan/experiments/deepseek_v3/symm_mem_recipes/__init__.py +11 -0
- torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_barrier.py +159 -0
- torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_on_device_all_to_all_v.py +260 -0
- torchtitan/experiments/deepseek_v3/train.py +142 -0
- torchtitan/experiments/flux/README.md +23 -0
- torchtitan/experiments/flux/__init__.py +122 -0
- torchtitan/experiments/flux/flux_argparser.py +42 -0
- torchtitan/experiments/flux/loss.py +27 -0
- torchtitan/experiments/flux/parallelize_flux.py +26 -0
- torchtitan/experiments/flux/requirements.txt +2 -0
- torchtitan/experiments/flux/scripts/download_autoencoder.py +61 -0
- torchtitan/experiments/flux/tests/test_flux_dataloader.py +103 -0
- torchtitan/experiments/flux/train.py +224 -0
- torchtitan/experiments/flux/train_configs/debug_model.toml +68 -0
- torchtitan/experiments/flux/utils.py +203 -0
- torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/__init__.py +13 -0
- torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/fast_debug_ao.py +299 -0
- torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/tma_autotuning.py +240 -0
- torchtitan/experiments/llama4/README.md +29 -0
- torchtitan/experiments/llama4/__pycache__/__init__.cpython-312.pyc +0 -0
- torchtitan/experiments/llama4/infra/__pycache__/parallelize_llama.cpython-312.pyc +0 -0
- torchtitan/experiments/llama4/infra/expert_parallel.py +145 -0
- torchtitan/experiments/llama4/infra/parallelize_llama.py +159 -0
- torchtitan/experiments/llama4/model/__pycache__/args.cpython-312.pyc +0 -0
- torchtitan/experiments/llama4/model/__pycache__/model.cpython-312.pyc +0 -0
- torchtitan/experiments/llama4/model/__pycache__/moe.cpython-312.pyc +0 -0
fla/models/delta_net/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (729 Bytes). View file
|
|
|
fla/models/delta_net/__pycache__/configuration_delta_net.cpython-312.pyc
ADDED
|
Binary file (3.62 kB). View file
|
|
|
fla/models/forgetting_transformer/__pycache__/configuration_forgetting_transformer.cpython-312.pyc
ADDED
|
Binary file (2.52 kB). View file
|
|
|
fla/models/gla/__pycache__/modeling_gla.cpython-312.pyc
ADDED
|
Binary file (18.6 kB). View file
|
|
|
fla/models/gla/configuration_gla.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
|
| 3 |
+
from typing import Dict, Optional
|
| 4 |
+
|
| 5 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class GLAConfig(PretrainedConfig):
|
| 9 |
+
|
| 10 |
+
model_type = 'gla'
|
| 11 |
+
keys_to_ignore_at_inference = ['past_key_values']
|
| 12 |
+
|
| 13 |
+
def __init__(
|
| 14 |
+
self,
|
| 15 |
+
hidden_size: int = 2048,
|
| 16 |
+
expand_k: int = 0.5,
|
| 17 |
+
expand_v: int = 1,
|
| 18 |
+
hidden_ratio: Optional[int] = 4,
|
| 19 |
+
intermediate_size: Optional[int] = None,
|
| 20 |
+
num_hidden_layers: int = 24,
|
| 21 |
+
num_heads: int = 4,
|
| 22 |
+
num_kv_heads: Optional[int] = None,
|
| 23 |
+
feature_map: Optional[str] = None,
|
| 24 |
+
attn_mode: str = "chunk",
|
| 25 |
+
use_short_conv: bool = False,
|
| 26 |
+
conv_size: int = 4,
|
| 27 |
+
use_output_gate: bool = True,
|
| 28 |
+
clamp_min: Optional[float] = None,
|
| 29 |
+
hidden_act: str = "swish",
|
| 30 |
+
max_position_embeddings: int = 2048,
|
| 31 |
+
elementwise_affine: Optional[bool] = True,
|
| 32 |
+
norm_eps: float = 1e-6,
|
| 33 |
+
use_gk: bool = True,
|
| 34 |
+
use_gv: bool = False,
|
| 35 |
+
attn: Optional[Dict] = None,
|
| 36 |
+
use_cache: bool = True,
|
| 37 |
+
pad_token_id: int = None,
|
| 38 |
+
bos_token_id: int = 1,
|
| 39 |
+
eos_token_id: int = 2,
|
| 40 |
+
tie_word_embeddings: bool = False,
|
| 41 |
+
initializer_range: float = 0.006,
|
| 42 |
+
fuse_norm: bool = True,
|
| 43 |
+
fuse_swiglu: bool = True,
|
| 44 |
+
fuse_cross_entropy: bool = True,
|
| 45 |
+
vocab_size: int = 32000,
|
| 46 |
+
**kwargs
|
| 47 |
+
):
|
| 48 |
+
self.hidden_size = hidden_size
|
| 49 |
+
self.expand_k = expand_k
|
| 50 |
+
self.expand_v = expand_v
|
| 51 |
+
self.hidden_ratio = hidden_ratio
|
| 52 |
+
self.intermediate_size = intermediate_size
|
| 53 |
+
self.num_hidden_layers = num_hidden_layers
|
| 54 |
+
self.num_heads = num_heads
|
| 55 |
+
self.num_kv_heads = num_kv_heads
|
| 56 |
+
self.feature_map = feature_map
|
| 57 |
+
self.attn_mode = attn_mode
|
| 58 |
+
self.use_short_conv = use_short_conv
|
| 59 |
+
self.conv_size = conv_size
|
| 60 |
+
self.use_output_gate = use_output_gate
|
| 61 |
+
self.clamp_min = clamp_min
|
| 62 |
+
self.hidden_act = hidden_act
|
| 63 |
+
self.max_position_embeddings = max_position_embeddings
|
| 64 |
+
self.elementwise_affine = elementwise_affine
|
| 65 |
+
self.norm_eps = norm_eps
|
| 66 |
+
self.use_gk = use_gk
|
| 67 |
+
self.use_gv = use_gv
|
| 68 |
+
self.attn = attn
|
| 69 |
+
self.use_cache = use_cache
|
| 70 |
+
self.initializer_range = initializer_range
|
| 71 |
+
|
| 72 |
+
self.fuse_norm = fuse_norm
|
| 73 |
+
self.fuse_swiglu = fuse_swiglu
|
| 74 |
+
self.fuse_cross_entropy = fuse_cross_entropy
|
| 75 |
+
self.vocab_size = vocab_size
|
| 76 |
+
|
| 77 |
+
if attn is not None:
|
| 78 |
+
if not isinstance(attn, Dict):
|
| 79 |
+
raise ValueError("attn must be a dictionary")
|
| 80 |
+
if 'layers' not in attn:
|
| 81 |
+
raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
|
| 82 |
+
if 'num_heads' not in attn:
|
| 83 |
+
raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
|
| 84 |
+
attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
|
| 85 |
+
attn['qkv_bias'] = attn.get('qkv_bias', False)
|
| 86 |
+
attn['window_size'] = attn.get('window_size', None)
|
| 87 |
+
attn['rope_theta'] = attn.get('rope_theta', 10000.)
|
| 88 |
+
|
| 89 |
+
super().__init__(
|
| 90 |
+
pad_token_id=pad_token_id,
|
| 91 |
+
bos_token_id=bos_token_id,
|
| 92 |
+
eos_token_id=eos_token_id,
|
| 93 |
+
tie_word_embeddings=tie_word_embeddings,
|
| 94 |
+
**kwargs,
|
| 95 |
+
)
|
fla/models/linear_attn/__pycache__/modeling_linear_attn.cpython-312.pyc
ADDED
|
Binary file (18.5 kB). View file
|
|
|
fla/models/retnet/__pycache__/configuration_retnet.cpython-312.pyc
ADDED
|
Binary file (3.76 kB). View file
|
|
|
fla/models/rwkv6/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (715 Bytes). View file
|
|
|
fla/models/rwkv7/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (715 Bytes). View file
|
|
|
fla/models/samba/__pycache__/configuration_samba.cpython-312.pyc
ADDED
|
Binary file (3.42 kB). View file
|
|
|
fla/models/samba/__pycache__/modeling_samba.cpython-312.pyc
ADDED
|
Binary file (20.9 kB). View file
|
|
|
fla/models/transformer/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (756 Bytes). View file
|
|
|
fla/models/transformer_top/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (777 Bytes). View file
|
|
|
fla/models/transformer_top/__pycache__/configuration_transformer.cpython-312.pyc
ADDED
|
Binary file (2.83 kB). View file
|
|
|
fla/modules/__pycache__/activations.cpython-312.pyc
ADDED
|
Binary file (23 kB). View file
|
|
|
fla/modules/__pycache__/convolution.cpython-312.pyc
ADDED
|
Binary file (21.1 kB). View file
|
|
|
fla/modules/__pycache__/fused_bitlinear.cpython-312.pyc
ADDED
|
Binary file (23.7 kB). View file
|
|
|
fla/modules/__pycache__/mlp.cpython-312.pyc
ADDED
|
Binary file (6.26 kB). View file
|
|
|
fla/modules/__pycache__/rotary.cpython-312.pyc
ADDED
|
Binary file (23.2 kB). View file
|
|
|
tb/20260101-0922/wandb/run-20260101_092219--dsmtp.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202601010919/files/requirements.txt
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
flame==0.1.0
|
| 2 |
+
fsspec==2025.10.0
|
| 3 |
+
aiohappyeyeballs==2.6.1
|
| 4 |
+
ipykernel==7.1.0
|
| 5 |
+
smmap==5.0.2
|
| 6 |
+
pybind11==3.0.1
|
| 7 |
+
tabulate==0.9.0
|
| 8 |
+
parso==0.8.5
|
| 9 |
+
yarl==1.22.0
|
| 10 |
+
asttokens==3.0.1
|
| 11 |
+
pandas==2.3.3
|
| 12 |
+
xxhash==3.6.0
|
| 13 |
+
pathvalidate==3.3.1
|
| 14 |
+
Werkzeug==3.1.4
|
| 15 |
+
regex==2025.11.3
|
| 16 |
+
inquirerpy==0.3.4
|
| 17 |
+
click==8.3.1
|
| 18 |
+
idna==3.11
|
| 19 |
+
pydantic==2.12.5
|
| 20 |
+
pexpect==4.9.0
|
| 21 |
+
typepy==1.3.4
|
| 22 |
+
certifi==2025.11.12
|
| 23 |
+
wcwidth==0.2.14
|
| 24 |
+
triton==3.2.0
|
| 25 |
+
hf-xet==1.2.0
|
| 26 |
+
joblib==1.5.3
|
| 27 |
+
tqdm==4.67.1
|
| 28 |
+
nvidia-nvtx-cu12==12.4.127
|
| 29 |
+
setuptools==80.9.0
|
| 30 |
+
lxml==6.0.2
|
| 31 |
+
nvidia-cufft-cu12==11.2.1.3
|
| 32 |
+
evaluate==0.4.6
|
| 33 |
+
Markdown==3.10
|
| 34 |
+
chardet==5.2.0
|
| 35 |
+
multiprocess==0.70.18
|
| 36 |
+
tensorboard==2.20.0
|
| 37 |
+
nvidia-nvjitlink-cu12==12.4.127
|
| 38 |
+
flame==0.1.0
|
| 39 |
+
matplotlib-inline==0.2.1
|
| 40 |
+
Cython==3.2.3
|
| 41 |
+
tensorboard-data-server==0.7.2
|
| 42 |
+
nvidia-cusparse-cu12==12.3.1.170
|
| 43 |
+
lm_eval==0.4.9.1
|
| 44 |
+
pure_eval==0.2.3
|
| 45 |
+
protobuf==6.33.2
|
| 46 |
+
DataProperty==1.1.0
|
| 47 |
+
nvidia-cudnn-cu12==9.1.0.70
|
| 48 |
+
accelerate==1.12.0
|
| 49 |
+
psutil==7.1.3
|
| 50 |
+
Jinja2==3.1.6
|
| 51 |
+
scikit-learn==1.8.0
|
| 52 |
+
nvidia-nccl-cu12==2.21.5
|
| 53 |
+
typing_extensions==4.15.0
|
| 54 |
+
pyzmq==27.1.0
|
| 55 |
+
mpmath==1.3.0
|
| 56 |
+
annotated-types==0.7.0
|
| 57 |
+
propcache==0.4.1
|
| 58 |
+
wandb==0.23.1
|
| 59 |
+
requests==2.32.5
|
| 60 |
+
ipython==9.8.0
|
| 61 |
+
more-itertools==10.8.0
|
| 62 |
+
nvidia-cuda-runtime-cu12==12.4.127
|
| 63 |
+
sacrebleu==2.5.1
|
| 64 |
+
httpx==0.28.1
|
| 65 |
+
huggingface-hub==0.36.0
|
| 66 |
+
MarkupSafe==3.0.3
|
| 67 |
+
nvidia-cusolver-cu12==11.6.1.9
|
| 68 |
+
gitdb==4.0.12
|
| 69 |
+
torchdata==0.11.0
|
| 70 |
+
sentry-sdk==2.48.0
|
| 71 |
+
sympy==1.13.1
|
| 72 |
+
safetensors==0.7.0
|
| 73 |
+
httpcore==1.0.9
|
| 74 |
+
portalocker==3.2.0
|
| 75 |
+
attrs==25.4.0
|
| 76 |
+
typing-inspection==0.4.2
|
| 77 |
+
ptyprocess==0.7.0
|
| 78 |
+
nvidia-cublas-cu12==12.4.5.8
|
| 79 |
+
numexpr==2.14.1
|
| 80 |
+
executing==2.2.1
|
| 81 |
+
networkx==3.6.1
|
| 82 |
+
threadpoolctl==3.6.0
|
| 83 |
+
nvidia-cusparselt-cu12==0.6.2
|
| 84 |
+
einops==0.8.1
|
| 85 |
+
zstandard==0.25.0
|
| 86 |
+
comm==0.2.3
|
| 87 |
+
six==1.17.0
|
| 88 |
+
packaging==25.0
|
| 89 |
+
tqdm-multiprocess==0.0.11
|
| 90 |
+
numpy==2.3.5
|
| 91 |
+
colorama==0.4.6
|
| 92 |
+
nvidia-cuda-cupti-cu12==12.4.127
|
| 93 |
+
jupyter_client==8.7.0
|
| 94 |
+
scipy==1.16.3
|
| 95 |
+
tornado==6.5.4
|
| 96 |
+
nltk==3.9.2
|
| 97 |
+
antlr4-python3-runtime==4.11.0
|
| 98 |
+
jupyter_core==5.9.1
|
| 99 |
+
sqlitedict==2.1.0
|
| 100 |
+
tzdata==2025.3
|
| 101 |
+
pytz==2025.2
|
| 102 |
+
Pygments==2.19.2
|
| 103 |
+
python-dotenv==1.2.1
|
| 104 |
+
cmake==4.2.0
|
| 105 |
+
tiktoken==0.12.0
|
| 106 |
+
PyYAML==6.0.3
|
| 107 |
+
datasets==4.4.1
|
| 108 |
+
pillow==12.0.0
|
| 109 |
+
math-verify==0.8.0
|
| 110 |
+
dill==0.4.0
|
| 111 |
+
nvidia-cuda-nvrtc-cu12==12.4.127
|
| 112 |
+
anyio==4.12.0
|
| 113 |
+
prompt_toolkit==3.0.52
|
| 114 |
+
filelock==3.20.1
|
| 115 |
+
jedi==0.19.2
|
| 116 |
+
frozenlist==1.8.0
|
| 117 |
+
tokenizers==0.21.4
|
| 118 |
+
grpcio==1.76.0
|
| 119 |
+
ninja==1.13.0
|
| 120 |
+
mbstrdecoder==1.1.4
|
| 121 |
+
flash-attn==2.7.3
|
| 122 |
+
aiosignal==1.4.0
|
| 123 |
+
tabledata==1.3.4
|
| 124 |
+
h11==0.16.0
|
| 125 |
+
absl-py==2.3.1
|
| 126 |
+
latex2sympy2_extended==1.10.2
|
| 127 |
+
torch==2.6.0
|
| 128 |
+
nest_asyncio==1.6.0
|
| 129 |
+
pip==25.3
|
| 130 |
+
aiohttp==3.13.2
|
| 131 |
+
pfzy==0.3.4
|
| 132 |
+
platformdirs==4.5.1
|
| 133 |
+
wheel==0.45.1
|
| 134 |
+
peft==0.17.0
|
| 135 |
+
debugpy==1.8.19
|
| 136 |
+
ipython_pygments_lexers==1.1.1
|
| 137 |
+
rouge_score==0.1.2
|
| 138 |
+
multidict==6.7.0
|
| 139 |
+
tcolorpy==0.1.7
|
| 140 |
+
nvidia-curand-cu12==10.3.5.147
|
| 141 |
+
pydantic_core==2.41.5
|
| 142 |
+
pytablewriter==1.2.1
|
| 143 |
+
charset-normalizer==3.4.4
|
| 144 |
+
transformers==4.51.3
|
| 145 |
+
word2number==1.1
|
| 146 |
+
jsonlines==4.0.0
|
| 147 |
+
stack_data==0.6.3
|
| 148 |
+
urllib3==2.6.2
|
| 149 |
+
decorator==5.2.1
|
| 150 |
+
python-dateutil==2.9.0.post0
|
| 151 |
+
pyarrow==22.0.0
|
| 152 |
+
traitlets==5.14.3
|
| 153 |
+
GitPython==3.1.45
|
| 154 |
+
tomli==2.0.1
|
| 155 |
+
more-itertools==10.3.0
|
| 156 |
+
inflect==7.3.1
|
| 157 |
+
zipp==3.19.2
|
| 158 |
+
jaraco.functools==4.0.1
|
| 159 |
+
autocommand==2.2.2
|
| 160 |
+
jaraco.collections==5.1.0
|
| 161 |
+
platformdirs==4.2.2
|
| 162 |
+
backports.tarfile==1.2.0
|
| 163 |
+
importlib_metadata==8.0.0
|
| 164 |
+
jaraco.text==3.12.1
|
| 165 |
+
typing_extensions==4.12.2
|
| 166 |
+
jaraco.context==5.3.0
|
| 167 |
+
typeguard==4.3.0
|
| 168 |
+
packaging==24.2
|
| 169 |
+
wheel==0.45.1
|
tb/20260101-0922/wandb/run-20260101_092219--dsmtp.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202601010919/files/wandb-metadata.json
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"os": "Linux-6.8.0-85-generic-x86_64-with-glibc2.39",
|
| 3 |
+
"python": "CPython 3.12.12",
|
| 4 |
+
"startedAt": "2026-01-01T09:22:19.321743Z",
|
| 5 |
+
"args": [
|
| 6 |
+
"--job.config_file",
|
| 7 |
+
"flame/models/fla.toml",
|
| 8 |
+
"--job.dump_folder",
|
| 9 |
+
"exp/dsmtp.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine",
|
| 10 |
+
"--model.config",
|
| 11 |
+
"configs/dsmtp_transformer_7B.json",
|
| 12 |
+
"--model.tokenizer_path",
|
| 13 |
+
"fla-hub/transformer-1.3B-100B",
|
| 14 |
+
"--optimizer.name",
|
| 15 |
+
"AdamW",
|
| 16 |
+
"--optimizer.eps",
|
| 17 |
+
"1e-15",
|
| 18 |
+
"--optimizer.lr",
|
| 19 |
+
"2e-5",
|
| 20 |
+
"--lr_scheduler.warmup_steps",
|
| 21 |
+
"400",
|
| 22 |
+
"--lr_scheduler.lr_min",
|
| 23 |
+
"0.1",
|
| 24 |
+
"--lr_scheduler.decay_type",
|
| 25 |
+
"cosine",
|
| 26 |
+
"--training.batch_size",
|
| 27 |
+
"8",
|
| 28 |
+
"--training.seq_len",
|
| 29 |
+
"4096",
|
| 30 |
+
"--training.context_len",
|
| 31 |
+
"4096",
|
| 32 |
+
"--training.gradient_accumulation_steps",
|
| 33 |
+
"2",
|
| 34 |
+
"--training.steps",
|
| 35 |
+
"40000",
|
| 36 |
+
"--training.max_norm",
|
| 37 |
+
"1.0",
|
| 38 |
+
"--training.skip_nan_inf",
|
| 39 |
+
"--training.dataset",
|
| 40 |
+
"/root/.cache/zaydzuhri___stack-edu-python/default",
|
| 41 |
+
"--training.dataset_split",
|
| 42 |
+
"train",
|
| 43 |
+
"--training.num_workers",
|
| 44 |
+
"32",
|
| 45 |
+
"--training.prefetch_factor",
|
| 46 |
+
"2",
|
| 47 |
+
"--training.seed",
|
| 48 |
+
"79",
|
| 49 |
+
"--training.compile",
|
| 50 |
+
"--checkpoint.interval",
|
| 51 |
+
"8000",
|
| 52 |
+
"--checkpoint.load_step",
|
| 53 |
+
"-1",
|
| 54 |
+
"--metrics.log_freq",
|
| 55 |
+
"5",
|
| 56 |
+
"--checkpoint.hf_upload_enabled",
|
| 57 |
+
"--checkpoint.hf_repo_base_name",
|
| 58 |
+
"zaydzuhri/dsmtp-code-7B-4096-batch8x2-steps40000",
|
| 59 |
+
"--comm.init_timeout_seconds",
|
| 60 |
+
"6000",
|
| 61 |
+
"--comm.train_timeout_seconds",
|
| 62 |
+
"6000"
|
| 63 |
+
],
|
| 64 |
+
"program": "-m flame.train",
|
| 65 |
+
"git": {
|
| 66 |
+
"remote": "https://github.com/zaydzuhri/flame.git",
|
| 67 |
+
"commit": "5bcd6b6423606e07b92dd2644ecc24d908d2c7a4"
|
| 68 |
+
},
|
| 69 |
+
"email": "zaydzuhri@gmail.com",
|
| 70 |
+
"root": "exp/dsmtp.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine/tb/20260101-0922",
|
| 71 |
+
"host": "rentals-6z3zwezo0sfapf3y-697b4fc787-gh86h",
|
| 72 |
+
"executable": "/root/miniconda3/envs/flame-env/bin/python3.12",
|
| 73 |
+
"cpu_count": 64,
|
| 74 |
+
"cpu_count_logical": 128,
|
| 75 |
+
"gpu": "NVIDIA H200",
|
| 76 |
+
"gpu_count": 8,
|
| 77 |
+
"disk": {
|
| 78 |
+
"/": {
|
| 79 |
+
"total": "3246163542016",
|
| 80 |
+
"used": "1652645769216"
|
| 81 |
+
}
|
| 82 |
+
},
|
| 83 |
+
"memory": {
|
| 84 |
+
"total": "1913835118592"
|
| 85 |
+
},
|
| 86 |
+
"gpu_nvidia": [
|
| 87 |
+
{
|
| 88 |
+
"name": "NVIDIA H200",
|
| 89 |
+
"memoryTotal": "150754820096",
|
| 90 |
+
"cudaCores": 16896,
|
| 91 |
+
"architecture": "Hopper",
|
| 92 |
+
"uuid": "GPU-bf7aa6f4-2ee0-0ff7-3851-1b40dafcde1f"
|
| 93 |
+
},
|
| 94 |
+
{
|
| 95 |
+
"name": "NVIDIA H200",
|
| 96 |
+
"memoryTotal": "150754820096",
|
| 97 |
+
"cudaCores": 16896,
|
| 98 |
+
"architecture": "Hopper",
|
| 99 |
+
"uuid": "GPU-24e3a14c-3196-7560-5e54-cd031aa25f76"
|
| 100 |
+
},
|
| 101 |
+
{
|
| 102 |
+
"name": "NVIDIA H200",
|
| 103 |
+
"memoryTotal": "150754820096",
|
| 104 |
+
"cudaCores": 16896,
|
| 105 |
+
"architecture": "Hopper",
|
| 106 |
+
"uuid": "GPU-3e484efe-97e7-7b5b-e6b7-d1dc17ed2765"
|
| 107 |
+
},
|
| 108 |
+
{
|
| 109 |
+
"name": "NVIDIA H200",
|
| 110 |
+
"memoryTotal": "150754820096",
|
| 111 |
+
"cudaCores": 16896,
|
| 112 |
+
"architecture": "Hopper",
|
| 113 |
+
"uuid": "GPU-7b9f4a41-11cd-03b7-0065-5e4dab09ddd4"
|
| 114 |
+
},
|
| 115 |
+
{
|
| 116 |
+
"name": "NVIDIA H200",
|
| 117 |
+
"memoryTotal": "150754820096",
|
| 118 |
+
"cudaCores": 16896,
|
| 119 |
+
"architecture": "Hopper",
|
| 120 |
+
"uuid": "GPU-3f34c938-ca85-e68c-f501-59e20f64e14c"
|
| 121 |
+
},
|
| 122 |
+
{
|
| 123 |
+
"name": "NVIDIA H200",
|
| 124 |
+
"memoryTotal": "150754820096",
|
| 125 |
+
"cudaCores": 16896,
|
| 126 |
+
"architecture": "Hopper",
|
| 127 |
+
"uuid": "GPU-9d38c94f-6fc0-6735-7f0d-3e359e2562cb"
|
| 128 |
+
},
|
| 129 |
+
{
|
| 130 |
+
"name": "NVIDIA H200",
|
| 131 |
+
"memoryTotal": "150754820096",
|
| 132 |
+
"cudaCores": 16896,
|
| 133 |
+
"architecture": "Hopper",
|
| 134 |
+
"uuid": "GPU-a7fc49a5-17e0-6e8f-feee-4ffa6e526637"
|
| 135 |
+
},
|
| 136 |
+
{
|
| 137 |
+
"name": "NVIDIA H200",
|
| 138 |
+
"memoryTotal": "150754820096",
|
| 139 |
+
"cudaCores": 16896,
|
| 140 |
+
"architecture": "Hopper",
|
| 141 |
+
"uuid": "GPU-731d225c-6930-54d4-db0f-f53e16eaeb2e"
|
| 142 |
+
}
|
| 143 |
+
],
|
| 144 |
+
"cudaVersion": "12.8",
|
| 145 |
+
"writerId": "allj4tgslt7j35odul2j6cl35jpfxh7k"
|
| 146 |
+
}
|
torchtitan/experiments/deepseek_v3/attn_mask_utils.py
ADDED
|
@@ -0,0 +1,397 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# This code is based on src/transformers/modeling_attn_mask_utils.py of
|
| 8 |
+
# huggingface/transformers. It has been modified from its original forms to
|
| 9 |
+
# contain only the necessary utilities.
|
| 10 |
+
|
| 11 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
| 12 |
+
#
|
| 13 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 14 |
+
# you may not use this file except in compliance with the License.
|
| 15 |
+
# You may obtain a copy of the License at
|
| 16 |
+
#
|
| 17 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 18 |
+
#
|
| 19 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 20 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 21 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 22 |
+
# See the License for the specific language governing permissions and
|
| 23 |
+
# limitations under the License.
|
| 24 |
+
from dataclasses import dataclass
|
| 25 |
+
from typing import List, Optional, Tuple, Union
|
| 26 |
+
|
| 27 |
+
import torch
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@dataclass
|
| 31 |
+
class AttentionMaskConverter:
|
| 32 |
+
"""
|
| 33 |
+
A utility attention mask class that allows one to:
|
| 34 |
+
- Create a causal 4d mask
|
| 35 |
+
- Create a causal 4d mask with slided window
|
| 36 |
+
- Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length,
|
| 37 |
+
key_value_length) that can be multiplied with attention scores
|
| 38 |
+
|
| 39 |
+
Examples:
|
| 40 |
+
|
| 41 |
+
```python
|
| 42 |
+
>>> import torch
|
| 43 |
+
>>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
| 44 |
+
|
| 45 |
+
>>> converter = AttentionMaskConverter(True)
|
| 46 |
+
>>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32)
|
| 47 |
+
tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
|
| 48 |
+
[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
|
| 49 |
+
[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
|
| 50 |
+
[-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, -3.4028e+38],
|
| 51 |
+
[-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, 0.0000e+00]]]])
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
Parameters:
|
| 55 |
+
is_causal (`bool`):
|
| 56 |
+
Whether the attention mask should be a uni-directional (causal) or bi-directional mask.
|
| 57 |
+
|
| 58 |
+
sliding_window (`int`, *optional*):
|
| 59 |
+
Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer.
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
is_causal: bool
|
| 63 |
+
sliding_window: int
|
| 64 |
+
|
| 65 |
+
def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):
|
| 66 |
+
self.is_causal = is_causal
|
| 67 |
+
self.sliding_window = sliding_window
|
| 68 |
+
|
| 69 |
+
if self.sliding_window is not None and self.sliding_window <= 0:
|
| 70 |
+
raise ValueError(
|
| 71 |
+
"Make sure that when passing `sliding_window` that its value is a strictly positive integer, "
|
| 72 |
+
f"not `{self.sliding_window}`"
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
def to_causal_4d(
|
| 76 |
+
self,
|
| 77 |
+
batch_size: int,
|
| 78 |
+
query_length: int,
|
| 79 |
+
key_value_length: int,
|
| 80 |
+
dtype: torch.dtype,
|
| 81 |
+
device: Union[torch.device, "str"] = "cpu",
|
| 82 |
+
) -> Optional[torch.Tensor]:
|
| 83 |
+
"""
|
| 84 |
+
Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
|
| 85 |
+
bias to upper right hand triangular matrix (causal mask).
|
| 86 |
+
"""
|
| 87 |
+
if not self.is_causal:
|
| 88 |
+
raise ValueError(
|
| 89 |
+
f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True."
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
# If shape is not cached, create a new causal mask and cache it
|
| 93 |
+
input_shape = (batch_size, query_length)
|
| 94 |
+
past_key_values_length = key_value_length - query_length
|
| 95 |
+
|
| 96 |
+
# create causal mask
|
| 97 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
| 98 |
+
causal_4d_mask = None
|
| 99 |
+
if input_shape[-1] > 1 or self.sliding_window is not None:
|
| 100 |
+
causal_4d_mask = self._make_causal_mask(
|
| 101 |
+
input_shape,
|
| 102 |
+
dtype,
|
| 103 |
+
device=device,
|
| 104 |
+
past_key_values_length=past_key_values_length,
|
| 105 |
+
sliding_window=self.sliding_window,
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
return causal_4d_mask
|
| 109 |
+
|
| 110 |
+
def to_4d(
|
| 111 |
+
self,
|
| 112 |
+
attention_mask_2d: torch.Tensor,
|
| 113 |
+
query_length: int,
|
| 114 |
+
dtype: torch.dtype,
|
| 115 |
+
key_value_length: Optional[int] = None,
|
| 116 |
+
) -> torch.Tensor:
|
| 117 |
+
"""
|
| 118 |
+
Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
|
| 119 |
+
key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
|
| 120 |
+
causal, a causal mask will be added.
|
| 121 |
+
"""
|
| 122 |
+
input_shape = (attention_mask_2d.shape[0], query_length)
|
| 123 |
+
|
| 124 |
+
# create causal mask
|
| 125 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
| 126 |
+
causal_4d_mask = None
|
| 127 |
+
if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
|
| 128 |
+
if key_value_length is None:
|
| 129 |
+
raise ValueError(
|
| 130 |
+
"This attention mask converter is causal. Make sure to pass "
|
| 131 |
+
"`key_value_length` to correctly create a causal mask."
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
past_key_values_length = key_value_length - query_length
|
| 135 |
+
causal_4d_mask = self._make_causal_mask(
|
| 136 |
+
input_shape,
|
| 137 |
+
dtype,
|
| 138 |
+
device=attention_mask_2d.device,
|
| 139 |
+
past_key_values_length=past_key_values_length,
|
| 140 |
+
sliding_window=self.sliding_window,
|
| 141 |
+
)
|
| 142 |
+
elif self.sliding_window is not None:
|
| 143 |
+
raise NotImplementedError(
|
| 144 |
+
"Sliding window is currently only implemented for causal masking"
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
| 148 |
+
expanded_attn_mask = self._expand_mask(
|
| 149 |
+
attention_mask_2d, dtype, tgt_len=input_shape[-1]
|
| 150 |
+
).to(attention_mask_2d.device)
|
| 151 |
+
|
| 152 |
+
if causal_4d_mask is not None:
|
| 153 |
+
expanded_attn_mask = causal_4d_mask.masked_fill(
|
| 154 |
+
expanded_attn_mask.bool(), torch.finfo(dtype).min
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
# expanded_attn_mask + causal_4d_mask can cause some overflow
|
| 158 |
+
expanded_4d_mask = expanded_attn_mask
|
| 159 |
+
|
| 160 |
+
return expanded_4d_mask
|
| 161 |
+
|
| 162 |
+
@staticmethod
|
| 163 |
+
def _make_causal_mask(
|
| 164 |
+
input_ids_shape: torch.Size,
|
| 165 |
+
dtype: torch.dtype,
|
| 166 |
+
device: torch.device,
|
| 167 |
+
past_key_values_length: int = 0,
|
| 168 |
+
sliding_window: Optional[int] = None,
|
| 169 |
+
):
|
| 170 |
+
"""
|
| 171 |
+
Make causal mask used for bi-directional self-attention.
|
| 172 |
+
"""
|
| 173 |
+
bsz, tgt_len = input_ids_shape
|
| 174 |
+
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
|
| 175 |
+
mask_cond = torch.arange(mask.size(-1), device=device)
|
| 176 |
+
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
| 177 |
+
|
| 178 |
+
mask = mask.to(dtype)
|
| 179 |
+
|
| 180 |
+
if past_key_values_length > 0:
|
| 181 |
+
mask = torch.cat(
|
| 182 |
+
[
|
| 183 |
+
torch.zeros(
|
| 184 |
+
tgt_len, past_key_values_length, dtype=dtype, device=device
|
| 185 |
+
),
|
| 186 |
+
mask,
|
| 187 |
+
],
|
| 188 |
+
dim=-1,
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
# add lower triangular sliding window mask if necessary
|
| 192 |
+
if sliding_window is not None:
|
| 193 |
+
diagonal = past_key_values_length - sliding_window - 1
|
| 194 |
+
|
| 195 |
+
context_mask = torch.tril(
|
| 196 |
+
torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal
|
| 197 |
+
)
|
| 198 |
+
mask.masked_fill_(context_mask, torch.finfo(dtype).min)
|
| 199 |
+
|
| 200 |
+
return mask[None, None, :, :].expand(
|
| 201 |
+
bsz, 1, tgt_len, tgt_len + past_key_values_length
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
@staticmethod
|
| 205 |
+
def _expand_mask(
|
| 206 |
+
mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
|
| 207 |
+
):
|
| 208 |
+
"""
|
| 209 |
+
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
| 210 |
+
"""
|
| 211 |
+
bsz, src_len = mask.size()
|
| 212 |
+
tgt_len = tgt_len if tgt_len is not None else src_len
|
| 213 |
+
|
| 214 |
+
expanded_mask = (
|
| 215 |
+
mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
inverted_mask = 1.0 - expanded_mask
|
| 219 |
+
|
| 220 |
+
return inverted_mask.masked_fill(
|
| 221 |
+
inverted_mask.to(torch.bool), torch.finfo(dtype).min
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
@staticmethod
|
| 225 |
+
def _unmask_unattended(
|
| 226 |
+
expanded_mask: torch.FloatTensor,
|
| 227 |
+
min_dtype: float,
|
| 228 |
+
):
|
| 229 |
+
# fmt: off
|
| 230 |
+
"""
|
| 231 |
+
Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when
|
| 232 |
+
using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
| 233 |
+
Details: https://github.com/pytorch/pytorch/issues/110213
|
| 234 |
+
|
| 235 |
+
`expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len].
|
| 236 |
+
`attention_mask` is [bsz, src_seq_len].
|
| 237 |
+
|
| 238 |
+
The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case
|
| 239 |
+
of alibi attention bias.
|
| 240 |
+
|
| 241 |
+
For example, if `expanded_mask` is (e.g. here left-padding case)
|
| 242 |
+
```
|
| 243 |
+
[[[[0, 0, 0],
|
| 244 |
+
[0, 0, 0],
|
| 245 |
+
[0, 0, 1]]],
|
| 246 |
+
[[[1, 0, 0],
|
| 247 |
+
[1, 1, 0],
|
| 248 |
+
[1, 1, 1]]],
|
| 249 |
+
[[[0, 0, 0],
|
| 250 |
+
[0, 1, 0],
|
| 251 |
+
[0, 1, 1]]]]
|
| 252 |
+
```
|
| 253 |
+
then the modified `expanded_mask` will be
|
| 254 |
+
```
|
| 255 |
+
[[[[1, 1, 1], <-- modified
|
| 256 |
+
[1, 1, 1], <-- modified
|
| 257 |
+
[0, 0, 1]]],
|
| 258 |
+
[[[1, 0, 0],
|
| 259 |
+
[1, 1, 0],
|
| 260 |
+
[1, 1, 1]]],
|
| 261 |
+
[[[1, 1, 1], <-- modified
|
| 262 |
+
[0, 1, 0],
|
| 263 |
+
[0, 1, 1]]]]
|
| 264 |
+
```
|
| 265 |
+
"""
|
| 266 |
+
# fmt: on
|
| 267 |
+
if expanded_mask.dtype == torch.bool:
|
| 268 |
+
raise ValueError(
|
| 269 |
+
"AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor."
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
+
return expanded_mask.mul(
|
| 273 |
+
~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True)
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
@staticmethod
def _ignore_causal_mask_sdpa(
    attention_mask: Optional[torch.Tensor],
    inputs_embeds: torch.Tensor,
    past_key_values_length: int,
    sliding_window: Optional[int] = None,
    is_training: bool = False,
) -> bool:
    """
    Decide whether the user mask / auto-built causal mask can be dropped so
    PyTorch SDPA can run with its `is_causal` argument instead, which allows
    dispatch to the flash-attention kernel (unavailable with a custom
    `attn_mask`). Returns True only when no token is actually masked and
    `query_length == 1` or `key_value_length == query_length`.
    """
    query_length = inputs_embeds.shape[1]
    key_value_length = query_length + past_key_values_length

    # Tracing/compiling back-ends cannot take a data-dependent `is_causal`:
    # jit.trace raises on a Tensor-valued bool, and export hard-codes the
    # value picked at trace time (see pytorch/pytorch#108108). Hence the
    # `is_training or not is_tracing` gate below.
    is_tracing = (
        torch.jit.is_tracing()
        or isinstance(inputs_embeds, torch.fx.Proxy)
        or is_torchdynamo_compiling()
    )

    if attention_mask is None:
        return (
            (is_training or not is_tracing)
            and (query_length == 1 or key_value_length == query_length)
            and (sliding_window is None or key_value_length < sliding_window)
        )

    if sliding_window is None or key_value_length < sliding_window:
        # A custom 4D mask must always be honored as-is.
        if len(attention_mask.shape) == 4:
            return False
        if not is_tracing and torch.all(attention_mask == 1):
            # For query_length == 1, causal and bi-directional attention
            # coincide, so the mask is redundant.
            if query_length == 1 or key_value_length == query_length:
                return True
            # For q_len > 1 with a cache, SDPA's own causal mask generation
            # may be wrong; keep the explicit Transformers mask.
            # Reference: https://github.com/pytorch/pytorch/issues/108108

    return False
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
def _prepare_4d_causal_attention_mask(
    attention_mask: Optional[torch.Tensor],
    input_shape: Union[torch.Size, Tuple, List],
    inputs_embeds: torch.Tensor,
    past_key_values_length: int,
    sliding_window: Optional[int] = None,
):
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)`
    from a 2D mask of shape `(batch_size, key_value_length)`.

    Args:
        attention_mask (`torch.Tensor` or `None`):
            A 2D attention mask of shape `(batch_size, key_value_length)`
            (a pre-built 4D mask is validated and inverted instead).
        input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
            `(batch_size, query_length)`.
        inputs_embeds (`torch.Tensor`):
            The embedded inputs; supplies dtype and device for the mask.
        past_key_values_length (`int`):
            The length of the key value cache.
        sliding_window (`int`, *optional*):
            If the model uses windowed attention, a sliding window should be passed.
    """
    converter = AttentionMaskConverter(
        is_causal=True, sliding_window=sliding_window
    )
    key_value_length = input_shape[-1] + past_key_values_length

    if attention_mask is None:
        # No user mask: just the causal (optionally windowed) mask.
        return converter.to_causal_4d(
            input_shape[0],
            input_shape[-1],
            key_value_length,
            dtype=inputs_embeds.dtype,
            device=inputs_embeds.device,
        )

    if len(attention_mask.shape) == 2:
        # Standard padding mask: expand to 4D and merge with the causal mask.
        return converter.to_4d(
            attention_mask,
            input_shape[-1],
            key_value_length=key_value_length,
            dtype=inputs_embeds.dtype,
        )

    if len(attention_mask.shape) == 4:
        expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
        if tuple(attention_mask.shape) != expected_shape:
            raise ValueError(
                f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
            )
        # Shape checks out: invert (1 = keep) and push masked slots to -inf.
        inverted_mask = 1.0 - attention_mask
        return inverted_mask.masked_fill(
            inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
        )

    # Other ranks pass through unchanged (matches the original control flow).
    return attention_mask
|
torchtitan/experiments/deepseek_v3/inference.sh
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
#!/usr/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Number of GPUs to launch on; override via env, e.g. `NGPU=8 ./inference.sh`.
NGPU=${NGPU:-"4"}

# Get the prompt from command line argument or use a default
prompt="${1:-What is 2+2?}"

# Run the model with the prompt.
# NGPU is quoted to avoid word splitting if the env var contains whitespace.
torchrun --standalone --nproc-per-node "${NGPU}" generate.py "$prompt"
|
torchtitan/experiments/deepseek_v3/model.py
ADDED
|
@@ -0,0 +1,1325 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# This code is based on model definition of `deepseek-ai/DeepSeek-V3-Base` on
|
| 8 |
+
# Hugging Face Model Hub. Url:
|
| 9 |
+
# https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/blob/main/modeling_deepseek.py
|
| 10 |
+
# https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/resolve/main/configuration_deepseek.py
|
| 11 |
+
#
|
| 12 |
+
# It has been modified from its original forms to accommodate naming convention
|
| 13 |
+
# and usage patterns of the TorchTitan project.
|
| 14 |
+
|
| 15 |
+
# Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
|
| 16 |
+
#
|
| 17 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 18 |
+
# you may not use this file except in compliance with the License.
|
| 19 |
+
# You may obtain a copy of the License at
|
| 20 |
+
#
|
| 21 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 22 |
+
#
|
| 23 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 24 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 25 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 26 |
+
# See the License for the specific language governing permissions and
|
| 27 |
+
# limitations under the License.
|
| 28 |
+
""" PyTorch DeepSeek model."""
|
| 29 |
+
import math
|
| 30 |
+
from typing import Optional, Tuple
|
| 31 |
+
|
| 32 |
+
import torch
|
| 33 |
+
import torch.distributed as dist
|
| 34 |
+
|
| 35 |
+
import torch.distributed._symmetric_memory as symm_mem
|
| 36 |
+
import torch.nn.functional as F
|
| 37 |
+
import torch.utils.checkpoint
|
| 38 |
+
|
| 39 |
+
from attn_mask_utils import _prepare_4d_causal_attention_mask
|
| 40 |
+
from indices import generate_permute_indices
|
| 41 |
+
from model_config import ModelArgs
|
| 42 |
+
from symm_mem_recipes import OnDeviceAllToAllV
|
| 43 |
+
from torch import nn
|
| 44 |
+
from torch.distributed._functional_collectives import all_to_all_single_autograd
|
| 45 |
+
|
| 46 |
+
from torchtitan.experiments.kernels.triton_mg_group_gemm.torchao_pr import (
|
| 47 |
+
ALIGN_SIZE_M,
|
| 48 |
+
grouped_gemm_forward,
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
# Get model parallel subgroup by name:
|
| 52 |
+
# e.g. "pp", "ep", None
|
| 53 |
+
def get_group(dim_name: Optional[str] = None) -> dist.ProcessGroup:
    """Return the process group for `dim_name` (e.g. "pp", "ep", or None)
    on the currently active device mesh.

    NOTE(review): relies on the private `_mesh_resources` registry —
    confirm it exists in the installed torch version.
    """
    current_mesh = torch.distributed.device_mesh._mesh_resources.get_current_mesh()
    return current_mesh.get_group(dim_name)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class RMSNorm(nn.Module):
    """Root-mean-square layer norm (no mean subtraction), computed in fp32."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        # Learnable per-channel scale, initialized to the identity.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        # Normalize in fp32 for numerical stability, cast back afterwards.
        states = hidden_states.to(torch.float32)
        mean_square = states.pow(2).mean(-1, keepdim=True)
        states = states * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * states.to(orig_dtype)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class RotaryEmbedding(nn.Module):
    """Precomputes and caches cos/sin tables for rotary position embeddings."""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # Per-pair inverse frequencies: base^(-2i/dim) for i in [0, dim/2).
        inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        angles = torch.outer(positions, self.inv_freq.to(positions.device))
        # Different from paper, but it uses a different permutation in order
        # to obtain the same calculation.
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached", table.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", table.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # Grow the cached tables lazily when a longer sequence shows up.
        if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class LinearScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
    ):
        # Must be set before super().__init__, which triggers cache building.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        # Positions are compressed by the scaling factor before the angles
        # are computed, stretching the usable context linearly.
        positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        positions = positions / self.scaling_factor

        angles = torch.outer(positions, self.inv_freq)
        # Different from paper, but it uses a different permutation in order
        # to obtain the same calculation.
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached", table.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", table.sin().to(dtype), persistent=False)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Deepseek
class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
    ):
        # Must be set before super().__init__, which triggers cache building.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len

        # Beyond the training context, rescale the rotary base ("NTK-aware"
        # interpolation) and recompute the inverse frequencies accordingly.
        if seq_len > self.max_position_embeddings:
            base = self.base * (
                (self.scaling_factor * seq_len / self.max_position_embeddings)
                - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))
            inv_freq = 1.0 / (
                base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)

        positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        angles = torch.outer(positions, self.inv_freq)
        # Different from paper, but it uses a different permutation in order
        # to obtain the same calculation.
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached", table.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", table.sin().to(dtype), persistent=False)
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
# Inverse dim formula to find dim based on number of rotations
def yarn_find_correction_dim(
    num_rotations, dim, base=10000, max_position_embeddings=2048
):
    """Fractional dimension index at which a RoPE channel completes
    `num_rotations` full rotations over `max_position_embeddings` tokens."""
    numerator = dim * math.log(
        max_position_embeddings / (num_rotations * 2 * math.pi)
    )
    return numerator / (2 * math.log(base))


# Find dim range bounds based on rotations
def yarn_find_correction_range(
    low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
):
    """Integer [low, high] dimension bounds between the fast/slow rotation
    cutoffs, clamped to the valid channel range [0, dim - 1]."""
    low = math.floor(
        yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
    )
    high = math.ceil(
        yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
    )
    # Clamp values just in case
    return max(low, 0), min(high, dim - 1)
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def yarn_get_mscale(scale=1, mscale=1):
    """YaRN attention-magnitude correction; identity when `scale` <= 1."""
    return 1.0 if scale <= 1 else 0.1 * mscale * math.log(scale) + 1.0
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def yarn_linear_ramp_mask(min, max, dim):
    """Length-`dim` float32 ramp rising linearly from 0 (at index `min`) to 1
    (at index `max`), clamped to [0, 1].

    NOTE: the parameter names shadow the `min`/`max` builtins; they are kept
    unchanged for call-compatibility with the original API.
    """
    if min == max:
        max += 0.001  # Prevent singularity
    ramp = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
    return torch.clamp(ramp, 0, 1)
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
class YarnRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding with YaRN scaling: NTK-by-parts frequency interpolation
    plus an attention-magnitude correction applied to the cos/sin tables."""

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
        original_max_position_embeddings=4096,
        beta_fast=32,
        beta_slow=1,
        mscale=1,
        mscale_all_dim=0,
    ):
        # All YaRN knobs must exist before super().__init__ builds the cache.
        self.scaling_factor = scaling_factor
        self.original_max_position_embeddings = original_max_position_embeddings
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        self.mscale = mscale
        self.mscale_all_dim = mscale_all_dim
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        dim = self.dim

        # Unscaled (extrapolated) vs. scaled (interpolated) inverse frequencies.
        exponents = torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim
        freq_extra = 1.0 / (self.base**exponents)
        freq_inter = 1.0 / (self.scaling_factor * self.base**exponents)

        # Blend the two per channel with a linear ramp between the cutoffs.
        low, high = yarn_find_correction_range(
            self.beta_fast,
            self.beta_slow,
            dim,
            self.base,
            self.original_max_position_embeddings,
        )
        inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(
            device=device, dtype=torch.float32
        )
        inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        positions = torch.arange(seq_len, device=device, dtype=torch.float32)
        angles = torch.outer(positions, inv_freq)

        # Magnitude correction ratio applied to both tables.
        _mscale = float(
            yarn_get_mscale(self.scaling_factor, self.mscale)
            / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
        )

        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer(
            "cos_cached", (table.cos() * _mscale).to(dtype), persistent=False
        )
        self.register_buffer(
            "sin_cached", (table.sin() * _mscale).to(dtype), persistent=False
        )
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
| 294 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
| 295 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
| 296 |
+
|
| 297 |
+
Args:
|
| 298 |
+
q (`torch.Tensor`): The query tensor.
|
| 299 |
+
k (`torch.Tensor`): The key tensor.
|
| 300 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
| 301 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
| 302 |
+
position_ids (`torch.Tensor`):
|
| 303 |
+
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
|
| 304 |
+
used to pass offsetted position ids when working with a KV-cache.
|
| 305 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
| 306 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
| 307 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
| 308 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
| 309 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
| 310 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
| 311 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
| 312 |
+
Returns:
|
| 313 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
| 314 |
+
"""
|
| 315 |
+
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
|
| 316 |
+
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
|
| 317 |
+
|
| 318 |
+
b, h, s, d = q.shape
|
| 319 |
+
q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
|
| 320 |
+
|
| 321 |
+
b, h, s, d = k.shape
|
| 322 |
+
k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
|
| 323 |
+
|
| 324 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 325 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 326 |
+
return q_embed, k_embed
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
class MLP(nn.Module):
|
| 330 |
+
act_fn = nn.SiLU()
|
| 331 |
+
|
| 332 |
+
def __init__(self, config, hidden_size=None, intermediate_size=None):
|
| 333 |
+
super().__init__()
|
| 334 |
+
self.config = config
|
| 335 |
+
self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
|
| 336 |
+
self.intermediate_size = (
|
| 337 |
+
config.intermediate_size if intermediate_size is None else intermediate_size
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 341 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 342 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
| 343 |
+
|
| 344 |
+
def forward(self, x):
|
| 345 |
+
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
| 346 |
+
return down_proj
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
class MoEGate(nn.Module):
    """Token-to-expert router for the MoE layer.

    Scores each token against all routed experts and returns, per token, the
    indices of the selected top-k experts and their (scaled) gate weights.
    Supports plain greedy top-k and DeepSeek-V3-style group-limited routing
    ("noaux_tc"), where experts are first filtered by group before the final
    per-expert top-k.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Number of experts activated per token.
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts
        self.routed_scaling_factor = config.routed_scaling_factor
        self.scoring_func = config.scoring_func
        self.seq_aux = config.seq_aux
        self.topk_method = config.topk_method
        self.n_group = config.n_group
        self.topk_group = config.topk_group

        # topk selection algorithm
        self.norm_topk_prob = config.norm_topk_prob
        self.gating_dim = config.hidden_size
        # Gating projection: [n_routed_experts, hidden_size].
        self.weight = nn.Parameter(
            torch.empty((self.n_routed_experts, self.gating_dim))
        )
        if self.topk_method == "noaux_tc":
            # Per-expert bias added to the scores for expert *selection* only
            # (the returned weights are gathered from the unbiased scores).
            self.e_score_correction_bias = nn.Parameter(
                # Changed from torch.empty to torch.rand to avoid a non-even
                # selection distribution for runs without actual weights.
                torch.rand((self.n_routed_experts))
            )
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialize the gating weight with Kaiming-uniform."""
        import torch.nn.init as init

        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states):
        """Route tokens to experts.

        Args:
            hidden_states: [bsz, seq_len, hidden_size] input activations.

        Returns:
            topk_idx: [bsz * seq_len, top_k] selected expert indices.
            topk_weight: [bsz * seq_len, top_k] gate weights, already
                multiplied by `routed_scaling_factor`.
        """
        bsz, seq_len, h = hidden_states.shape
        # compute gating score (in float32 for numerical stability)
        hidden_states = hidden_states.view(-1, h)
        logits = F.linear(
            hidden_states.type(torch.float32), self.weight.type(torch.float32), None
        )
        if self.scoring_func == "sigmoid":
            scores = logits.sigmoid()
        elif self.scoring_func == "softmax":
            scores = logits.softmax(dim=-1, dtype=torch.float32)
        else:
            raise NotImplementedError(
                f"insupportable scoring function for MoE gating: {self.scoring_func}"
            )

        # select top-k experts
        if self.topk_method == "noaux_tc":
            # Bias the scores for expert *choice* only; final weights below are
            # gathered from the unbiased `scores`.
            scores_for_choice = scores.view(
                bsz * seq_len, -1
            ) + self.e_score_correction_bias.unsqueeze(0)
            # Score each expert group by the sum of its 2 highest expert scores.
            group_scores = (
                scores_for_choice.view(bsz * seq_len, self.n_group, -1)
                .topk(2, dim=-1)[0]
                .sum(dim=-1)
            )  # [n, n_group]
            # Keep only the `topk_group` best groups per token.
            group_idx = torch.topk(
                group_scores, k=self.topk_group, dim=-1, sorted=False
            )[
                1
            ]  # [n, top_k_group]
            group_mask = torch.zeros_like(group_scores)  # [n, n_group]
            group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
            # Expand the group mask to a per-expert mask.
            score_mask = (
                group_mask.unsqueeze(-1)
                .expand(
                    bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group
                )
                .reshape(bsz * seq_len, -1)
            )  # [n, e]
            # Zero out experts in masked-out groups so they cannot win top-k.
            tmp_scores = scores_for_choice.masked_fill(
                ~score_mask.bool(), 0.0
            )  # [n, e]
            _, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False)
            topk_weight = scores.gather(1, topk_idx)
        elif self.topk_method == "greedy":
            topk_weight, topk_idx = torch.topk(
                scores, k=self.top_k, dim=-1, sorted=False
            )
        else:
            raise NotImplementedError(
                f"insupportable TopK function for MoE gating: {self.topk_method}"
            )

        # norm gate weights to sum to 1 (only meaningful when top_k > 1)
        if self.top_k > 1 and self.norm_topk_prob:
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator
        topk_weight = (
            topk_weight * self.routed_scaling_factor
        )  # must multiply the scaling factor

        return topk_idx, topk_weight
| 444 |
+
|
| 445 |
+
|
| 446 |
+
class MoE(nn.Module):
    """
    A mixed expert module containing shared experts.

    Experts are sharded across the expert-parallel (EP) group; tokens are
    shuffled to the rank holding their selected experts, processed, and
    shuffled back (DP -> EP -> DP).
    """

    # Class attributes:
    # Two shuffle methods supported:
    # 1. "torch_all_to_all"
    # 2. "symm_mem" (see `setup_symm_mem` below)
    shuffle_method = "torch_all_to_all"

    # Symmetric memory buffers shared by all MoE instances across layers
    token_send_buf: Optional[torch.Tensor] = None
    token_gather_buf: Optional[torch.Tensor] = None

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_experts_per_tok = config.num_experts_per_tok

        # ep_size is the number of ranks in the expert dimension
        if config.ep_size <= 1:
            raise ValueError(
                "For code simplicity, this model only supports distributed experts, "
                "thus EP size must be > 1, please modify your model config"
            )
        self.ep_group = get_group("ep")
        assert config.ep_size == self.ep_group.size()
        self.ep_size = config.ep_size
        self.ep_rank = self.ep_group.rank()
        self.experts_per_rank = config.n_routed_experts // config.ep_size
        # Use ModuleDict instead of ModuleList to preserve absolute expert
        # IDs while avoiding `None` experts. The absolute expert IDs match
        # with checkpoint FQNs.
        self.experts = nn.ModuleDict()
        for i in range(self.experts_per_rank):
            abs_expert_id = self.ep_rank * self.experts_per_rank + i
            self.experts[str(abs_expert_id)] = MLP(
                config, intermediate_size=config.moe_intermediate_size
            )
        self.gate = MoEGate(config)
        if config.n_shared_experts is not None:
            # Shared experts are applied to every token (no routing); sized as
            # n_shared_experts fused into one wider MLP.
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
            self.shared_experts = MLP(
                config=config, intermediate_size=intermediate_size
            )

    def combine_experts(self, submod_name):
        """Concatenate the `submod_name` weights of all local experts into one
        fused parameter (`{submod_name}_weight`) for grouped GEMM, detaching
        the per-expert weights."""
        all_weights = []
        for expert in self.experts.values():
            lin = expert.get_submodule(submod_name)
            all_weights.append(lin.weight)
            lin.weight = None

        concat_weight = torch.cat(all_weights)
        self.register_parameter(f"{submod_name}_weight", nn.Parameter(concat_weight))

    # This function is used to create a symm mem buffer for MoE's. It is for
    # shuffling tokens fully "on-device", as compared to traditional torch
    # all_to_all APIs which require a GPU-to-CPU sync of the splits. If a user
    # calls this function, the `shuffle_method` would switch from
    # `torch_all_to_all` to `symm_mem`.
    def setup_symm_mem(self, dtype: torch.dtype, device: torch.device):
        # Switch shuffle method (instance attribute shadows the class default)
        self.shuffle_method = "symm_mem"

        # Combine expert weights so each projection becomes one grouped GEMM
        print("Combining expert weights for Group GEMM")
        self.combine_experts("gate_proj")
        self.combine_experts("up_proj")
        self.combine_experts("down_proj")

        # Assuming the worst case, 2x tokens are routed to one EP rank
        overflow = 2
        OnDeviceAllToAllV.max_output_len = (
            self.config.max_seq_len * self.num_experts_per_tok * overflow
        )

        # Symmetric memory buffers are shared by all MoE instances across
        # layers, we only need to initialize them once
        if MoE.token_send_buf is not None:
            return

        # Input buffer for DP-to-EP shuffle
        MoE.token_send_buf = symm_mem.empty(
            self.config.max_seq_len
            * self.num_experts_per_tok,  # seq len * top k (flattened)
            self.config.hidden_size,  # hidden dim
            dtype=dtype,
            device=device,
        )
        # Input buffer for EP-to-DP shuffle (oversized by `overflow`)
        MoE.token_gather_buf = symm_mem.empty(
            self.config.max_seq_len
            * self.num_experts_per_tok  # seq len * top k (flattened)
            * overflow,
            self.config.hidden_size,  # hidden dim
            dtype=dtype,
            device=device,
        )
        print(f"EP rank [{self.ep_rank}]: Created Symmetric Memory for MoE")

    def get_send_buf(self):
        # [Why detach?] During a first forward-backward step, the buffer would
        # be included in a computational graph. In a second step, autograd will
        # return an error saying "Trying to backward through the graph a second
        # time (or directly access saved tensors more than once)". This is
        # because the buffer is still in the graph, and autograd is trying to
        # backward through the graph a second time. To avoid this, we detach the
        # buffer from the graph. `detach()` returns a new tensor, which shares
        # the same storage with the original one.
        self.token_send_buf.grad = None
        return self.token_send_buf.detach()

    def get_gather_buf(self):
        # See [Why detach?] in `get_send_buf`
        self.token_gather_buf.grad = None
        return self.token_gather_buf.detach()

    def forward(self, hidden_states):
        """Route tokens through the selected experts (plus shared experts).

        Input/output shape: [bsz, seq_len, hidden_size].
        """
        identity = hidden_states
        orig_shape = hidden_states.shape
        # for each token, select top-k experts, and compute the weight for each expert
        topk_idx, topk_weight = self.gate(hidden_states)
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        if self.shuffle_method == "symm_mem":
            y = self.moe_on_device(hidden_states, topk_idx, topk_weight)
        else:  # "torch_all_to_all"
            y = self.moe_forward(hidden_states, topk_idx, topk_weight)

        y = y.view(*orig_shape)
        if self.config.n_shared_experts is not None:
            # Shared experts see every token, on top of the routed output
            y = y + self.shared_experts(identity)
        return y

    def moe_forward(self, x, topk_ids, topk_weight):
        """Per-expert loop implementation of the routed-expert computation.

        Handles both shuffle methods; for grouped-GEMM symm-mem execution see
        `moe_on_device`.
        """
        # This part sorts the token indices so that tokens routed to the same expert reside consecutively.
        # An implication is that tokens to the same "expert group" (i.e., device) are also consecutive.
        # Since this is an "artificial" index creation (final outcome being
        # `idxs`), we don't need gradients here.
        with torch.no_grad():
            # [seq_len, n_routed_experts]
            cnts = topk_ids.new_zeros((topk_ids.shape[0], self.config.n_routed_experts))
            # Fill 1 to the selected experts
            cnts.scatter_(1, topk_ids, 1)
            tokens_per_expert = cnts.sum(dim=0)
            # Token indices for each expert
            idxs = topk_ids.view(-1).argsort()
            sorted_tokens_shape = idxs.shape + x.shape[1:]

        # Gather tokens in expert order (idxs // top_k maps back to token row)
        sorted_tokens = x[idxs // topk_ids.shape[1]]
        assert sorted_tokens.shape == sorted_tokens_shape

        # This part exchanges the information about the number of tokens sent
        # and received by each expert. We can understand this information as
        # "side band", which is not part of the actual data. Thus no gradient
        # is needed.
        with torch.no_grad():
            # Sum the tokens over local experts, then we get tokens per EP rank,
            # which is the input splits
            tokens_per_expert_group = tokens_per_expert.new_empty(
                tokens_per_expert.shape[0]
            )
            dist.all_to_all_single(
                tokens_per_expert_group, tokens_per_expert, group=self.ep_group
            )
            input_splits = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)

        # DP to EP token shuffle. This part needs gradient.
        if self.shuffle_method == "symm_mem":
            # Move input to the `token_send_buf` symm mem
            token_send_buf = self.get_send_buf()
            token_send_buf[: idxs.shape[0]].copy_(sorted_tokens)
            # Note: `out=` avoids copy, but it is not differentiable
            # torch.index_select(x, 0, idxs // topk_ids.shape[1], out=self.token_send_buf[: idxs.shape[0]])
            token_gather_buf, output_splits = OnDeviceAllToAllV.apply(
                token_send_buf,
                input_splits,
                self.ep_group,
            )
            with torch.no_grad():
                # Received tokens from all other ranks. TODO: use mask instead
                received = output_splits.sum()
                # TODO: don't use `received`
                gathered_tokens = token_gather_buf[:received]
        else:  # "torch_all_to_all"
            # Prepare input and output splits
            with torch.no_grad():
                output_splits = tokens_per_expert_group.view(self.ep_size, -1).sum(
                    dim=1
                )
            gathered_tokens = all_to_all_single_autograd(
                sorted_tokens,
                output_splits.tolist(),
                input_splits.tolist(),
                self.ep_group,
            )

        # This part prepares a 1D tensor with the same length as
        # `gathered_tokens`. The 1D tensor is filled with local expert IDs which
        # the tokens in `gathered_tokens` are headed for. This part doesn't need
        # gradient.
        with torch.no_grad():
            gatherd_idxs = (
                torch.arange(
                    tokens_per_expert_group.numel(),
                    device=tokens_per_expert_group.device,
                )
                % self.experts_per_rank
            )
            gatherd_idxs = gatherd_idxs.repeat_interleave(tokens_per_expert_group)

        # Prepare buffer for tokens processed by experts
        if self.shuffle_method == "symm_mem":
            # Take necessary space from `token_gather_buf` symm mem because we are
            # going to send them out after expert processing
            processed_tokens = self.get_gather_buf()[: gathered_tokens.shape[0]]
        else:  # "torch_all_to_all"
            processed_tokens = torch.empty_like(gathered_tokens)

        # This part processes the tokens routed to the local experts.
        # TODO: can we use group GEMM here?
        for i, expert in enumerate(self.experts.values()):
            processed_tokens[gatherd_idxs == i] = expert(
                gathered_tokens[gatherd_idxs == i]
            )

        # Now shuffle the tokens back to their original owner, i.e. EP to DP shuffle.
        # The input/output splits are just a reverse of the previous shuffle.
        if self.shuffle_method == "symm_mem":
            token_return_buf, _ = OnDeviceAllToAllV.apply(
                processed_tokens,
                output_splits,
                self.ep_group,
            )
            returned_tokens = token_return_buf[: sorted_tokens_shape[0]]
        else:  # "torch_all_to_all"
            returned_tokens = all_to_all_single_autograd(
                processed_tokens,
                input_splits.tolist(),
                output_splits.tolist(),
                self.ep_group,
            )

        # Undo the expert-order sort, then combine the top-k expert outputs per
        # token, weighted by the gate weights.
        output_tokens = torch.empty_like(returned_tokens)
        output_tokens[idxs] = returned_tokens
        final_out = (
            output_tokens.view(*topk_ids.shape, -1)
            .type(topk_weight.dtype)
            .mul_(topk_weight.unsqueeze(dim=-1))
            .sum(dim=1)
            .type(returned_tokens.dtype)
        )
        return final_out

    def moe_on_device(self, x, topk_ids, topk_weight):
        """Symm-mem implementation of the routed-expert computation, using
        on-device all-to-all-v and grouped GEMMs over the fused expert weights
        (requires a prior `setup_symm_mem` call)."""
        # This part sorts the token indices so that tokens routed to the same expert reside consecutively.
        # An implication is that tokens to the same "expert group" (i.e., device) are also consecutive.
        # Since this is an "artificial" index creation (final outcome being
        # `idxs`), we don't need gradients here.
        with torch.no_grad():
            # [seq_len, n_routed_experts]
            cnts = topk_ids.new_zeros((topk_ids.shape[0], self.config.n_routed_experts))
            # Fill 1 to the selected experts
            cnts.scatter_(1, topk_ids, 1)
            tokens_per_expert = cnts.sum(dim=0)
            # Token indices for each expert
            idxs = topk_ids.view(-1).argsort()
            sorted_tokens_shape = idxs.shape + x.shape[1:]

        sorted_tokens = x[idxs // topk_ids.shape[1]]
        assert sorted_tokens.shape == sorted_tokens_shape

        # This part exchanges the information about the number of tokens sent
        # and received by each expert. We can understand this information as
        # "side band", which is not part of the actual data. Thus no gradient
        # is needed.
        with torch.no_grad():
            # Sum the tokens over local experts, then we get tokens per EP rank,
            # which is the input splits
            tokens_per_expert_group = tokens_per_expert.new_empty(
                tokens_per_expert.shape[0]
            )
            dist.all_to_all_single(
                tokens_per_expert_group, tokens_per_expert, group=self.ep_group
            )
            input_splits = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)

        # Move input to the `token_send_buf` symm mem
        token_send_buf = self.get_send_buf()
        token_send_buf[: idxs.shape[0]].copy_(sorted_tokens)
        # Note: `out=` avoids copy, but it is not differentiable
        # torch.index_select(x, 0, idxs // topk_ids.shape[1], out=self.token_send_buf[: idxs.shape[0]])
        token_gather_buf, output_splits = OnDeviceAllToAllV.apply(
            token_send_buf,
            input_splits,
            self.ep_group,
        )

        # We need to permute the received tokens so that tokens for the same expert are contiguous.
        # This part prepares a 1D tensor `permuted_indices` for such permutation.
        # This part doesn't need gradient.
        with torch.no_grad():
            permuted_indices, m_sizes = generate_permute_indices(
                tokens_per_expert_group,
                self.experts_per_rank,
                self.ep_size,
                token_gather_buf.shape[0],
                ALIGN_SIZE_M,
            )

        # Permute the received tokens so that tokens for the same expert are contiguous.
        contig_tokens = token_gather_buf[permuted_indices]

        # Run the first grouped GEMM
        w1 = self.get_parameter("gate_proj_weight")
        gate_proj = grouped_gemm_forward(contig_tokens, w1, m_sizes)

        # Run the second grouped GEMM
        w3 = self.get_parameter("up_proj_weight")
        up_proj = grouped_gemm_forward(contig_tokens, w3, m_sizes)

        # Apply activation (SwiGLU gating, same as MLP.forward)
        hidden_outputs = MLP.act_fn(gate_proj) * up_proj

        # Run the third grouped GEMM
        w2 = self.get_parameter("down_proj_weight")
        hidden_outputs = grouped_gemm_forward(hidden_outputs, w2, m_sizes)

        # Prepare buffer for tokens processed by experts
        # Take necessary space from `token_gather_buf` symm mem because we are
        # going to send them out after expert processing
        processed_tokens = self.get_gather_buf()

        # Move into Symmetric Memory for the return shuffle
        processed_tokens[permuted_indices] = hidden_outputs

        # Now shuffle the tokens back to their original owner, i.e. EP to DP shuffle.
        # The input/output splits are just a reverse of the previous shuffle.
        token_return_buf, _ = OnDeviceAllToAllV.apply(
            processed_tokens,
            output_splits,
            self.ep_group,
        )
        returned_tokens = token_return_buf[: sorted_tokens_shape[0]]

        # Undo the expert-order sort, then combine the top-k expert outputs per
        # token, weighted by the gate weights.
        output_tokens = torch.empty_like(returned_tokens)
        output_tokens[idxs] = returned_tokens
        final_out = (
            output_tokens.view(*topk_ids.shape, -1)
            .type(topk_weight.dtype)
            .mul_(topk_weight.unsqueeze(dim=-1))
            .sum(dim=1)
            .type(returned_tokens.dtype)
        )
        return final_out
| 802 |
+
|
| 803 |
+
|
| 804 |
+
class Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    DeepSeek-style Multi-head Latent Attention (MLA): queries and keys are
    split into a non-positional part (`qk_nope_head_dim`) and a RoPE part
    (`qk_rope_head_dim`); keys/values are produced from a low-rank compressed
    projection (`kv_lora_rank`), and queries optionally from one
    (`q_lora_rank`).
    """

    def __init__(self, config: ModelArgs, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads

        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.q_lora_rank = config.q_lora_rank
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.kv_lora_rank = config.kv_lora_rank
        self.v_head_dim = config.v_head_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
        # Full per-head query/key dim = non-positional part + RoPE part
        self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim

        self.is_causal = True

        if self.q_lora_rank is None:
            # Direct query projection
            self.q_proj = nn.Linear(
                self.hidden_size, self.num_heads * self.q_head_dim, bias=False
            )
        else:
            # Low-rank (LoRA-style) query projection: down, norm, up
            self.q_a_proj = nn.Linear(
                self.hidden_size, config.q_lora_rank, bias=config.attention_bias
            )
            self.q_a_layernorm = RMSNorm(config.q_lora_rank)
            self.q_b_proj = nn.Linear(
                config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
            )

        # Joint KV down-projection; the extra qk_rope_head_dim slice is the
        # shared (multi-query) RoPE key component.
        self.kv_a_proj_with_mqa = nn.Linear(
            self.hidden_size,
            config.kv_lora_rank + config.qk_rope_head_dim,
            bias=config.attention_bias,
        )
        self.kv_a_layernorm = RMSNorm(config.kv_lora_rank)
        self.kv_b_proj = nn.Linear(
            config.kv_lora_rank,
            self.num_heads
            * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
            bias=False,
        )

        self.o_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=config.attention_bias,
        )
        self._init_rope()

        self.softmax_scale = self.q_head_dim ** (-0.5)
        if self.config.rope_scaling is not None:
            # YaRN mscale correction to the softmax scale
            mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
            scaling_factor = self.config.rope_scaling["factor"]
            if mscale_all_dim:
                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
                self.softmax_scale = self.softmax_scale * mscale * mscale

    def _init_rope(self):
        """Build the rotary embedding module per `config.rope_scaling`
        (none / "linear" / "dynamic" / "yarn")."""
        if self.config.rope_scaling is None:
            self.rotary_emb = RotaryEmbedding(
                self.qk_rope_head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = LinearScalingRotaryEmbedding(
                    self.qk_rope_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = DynamicNTKScalingRotaryEmbedding(
                    self.qk_rope_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "yarn":
                # Forward only the YaRN knobs that are present in the config
                kwargs = {
                    key: self.config.rope_scaling[key]
                    for key in [
                        "original_max_position_embeddings",
                        "beta_fast",
                        "beta_slow",
                        "mscale",
                        "mscale_all_dim",
                    ]
                    if key in self.config.rope_scaling
                }
                self.rotary_emb = YarnRotaryEmbedding(
                    self.qk_rope_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                    **kwargs,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """Compute attention output for `hidden_states`.

        Args:
            hidden_states: [bsz, q_len, hidden_size] input activations.
            attention_mask: optional 2D padding mask; expanded to a 4D causal
                mask below. When None, SDPA's built-in causal masking is used.
            position_ids: positions used to index the RoPE cos/sin caches.

        Returns:
            [bsz, q_len, hidden_size] attention output.
        """
        bsz, q_len, _ = hidden_states.size()

        if self.q_lora_rank is None:
            q = self.q_proj(hidden_states)
        else:
            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )

        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        compressed_kv, k_pe = torch.split(
            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
        )
        # Single shared RoPE key head (multi-query style), broadcast over heads
        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
        kv = (
            self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
            .transpose(1, 2)
        )

        k_nope, value_states = torch.split(
            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
        )
        kv_seq_len = value_states.shape[-2]

        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)

        # Reassemble full queries/keys: [non-positional | RoPE] per head
        query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
        query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
        query_states[:, :, :, self.qk_nope_head_dim :] = q_pe

        key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
        key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
        key_states[:, :, :, self.qk_nope_head_dim :] = k_pe

        if attention_mask is not None:
            # Attention mask was made 4D because the `attn_weights` above is 4D.
            # We probably can make this mask smarter if we want to pack sequences
            # together, instead of using padding. This optimization can be used in
            # inference. For training, if we want to pack sequences, data loader
            # will pass in a mask containing such info.
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask,  # None, or user provided mask in 2D
                (bsz, q_len),
                hidden_states,
                0,  # past_key_values_length, 0 when training
            )
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query=query_states,
            key=key_states,
            value=value_states,
            attn_mask=attention_mask,
            # Fix: only apply attention dropout during training; previously the
            # dropout probability was passed unconditionally, which also
            # randomly dropped attention weights in eval/inference mode.
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=attention_mask is None,
            scale=self.softmax_scale,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
        attn_output = self.o_proj(attn_output)

        return attn_output
| 990 |
+
|
| 991 |
+
|
| 992 |
+
class DecoderLayer(nn.Module):
    """One transformer decoder layer: pre-norm self-attention followed by a
    pre-norm feed-forward block (dense MLP or MoE depending on layer index)."""

    def __init__(self, config: ModelArgs, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = Attention(config=config, layer_idx=layer_idx)

        # A layer is MoE when routed experts are configured, the layer index is
        # past the initial dense layers, and it falls on the MoE frequency grid.
        use_moe = (
            config.n_routed_experts is not None
            and layer_idx >= config.first_k_dense_replace
            and layer_idx % config.moe_layer_freq == 0
        )
        self.mlp = MoE(config) if use_moe else MLP(config)

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
        """
        # Self-attention sub-block with residual connection (pre-norm).
        attn_out = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block with residual connection (pre-norm).
        ffn_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + ffn_out
|
| 1045 |
+
|
| 1046 |
+
|
| 1047 |
+
Deepseek_INPUTS_DOCSTRING = r"""
|
| 1048 |
+
Args:
|
| 1049 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 1050 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
| 1051 |
+
it.
|
| 1052 |
+
|
| 1053 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 1054 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 1055 |
+
|
| 1056 |
+
[What are input IDs?](../glossary#input-ids)
|
| 1057 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1058 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 1059 |
+
|
| 1060 |
+
- 1 for tokens that are **not masked**,
|
| 1061 |
+
- 0 for tokens that are **masked**.
|
| 1062 |
+
|
| 1063 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 1064 |
+
|
| 1065 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 1066 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 1067 |
+
|
| 1068 |
+
If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
|
| 1069 |
+
`past_key_values`).
|
| 1070 |
+
|
| 1071 |
+
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
| 1072 |
+
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
| 1073 |
+
information on the default strategy.
|
| 1074 |
+
|
| 1075 |
+
- 1 indicates the head is **not masked**,
|
| 1076 |
+
- 0 indicates the head is **masked**.
|
| 1077 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1078 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
| 1079 |
+
config.n_positions - 1]`.
|
| 1080 |
+
|
| 1081 |
+
[What are position IDs?](../glossary#position-ids)
|
| 1082 |
+
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
|
| 1083 |
+
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
| 1084 |
+
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
|
| 1085 |
+
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
| 1086 |
+
|
| 1087 |
+
Two formats are allowed:
|
| 1088 |
+
- a [`~cache_utils.Cache`] instance;
|
| 1089 |
+
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
| 1090 |
+
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
|
| 1091 |
+
cache format.
|
| 1092 |
+
|
| 1093 |
+
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
|
| 1094 |
+
legacy cache format will be returned.
|
| 1095 |
+
|
| 1096 |
+
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
| 1097 |
+
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
| 1098 |
+
of shape `(batch_size, sequence_length)`.
|
| 1099 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
| 1100 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
| 1101 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
| 1102 |
+
model's internal embedding lookup matrix.
|
| 1103 |
+
use_cache (`bool`, *optional*):
|
| 1104 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
| 1105 |
+
`past_key_values`).
|
| 1106 |
+
output_attentions (`bool`, *optional*):
|
| 1107 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 1108 |
+
tensors for more detail.
|
| 1109 |
+
output_hidden_states (`bool`, *optional*):
|
| 1110 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 1111 |
+
more detail.
|
| 1112 |
+
return_dict (`bool`, *optional*):
|
| 1113 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 1114 |
+
"""
|
| 1115 |
+
|
| 1116 |
+
|
| 1117 |
+
class DeepseekModel(torch.nn.Module):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DecoderLayer`]

    Args:
        config: ModelArgs
    """

    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # Only the parts of the model owned by this pipeline stage are created.
        assert (
            config.stage_idx < config.num_stages
        ), f"Stage {config.stage_idx} is not in the model"
        print(f"Creating model stage {config.stage_idx} of {config.num_stages}")

        # Token embedding lives on the first stage only.
        if config.stage_idx == 0:
            self.embed_tokens = nn.Embedding(
                config.vocab_size, config.hidden_size, self.padding_idx
            )
        else:
            self.embed_tokens = None

        self.layers = torch.nn.ModuleDict()
        # Split the decoder layers across stages as evenly as possible: the
        # first (num_hidden_layers % num_stages) stages each take one extra
        # layer, which is more even than giving the whole remainder to the
        # last stage.
        base, extra = divmod(config.num_hidden_layers, config.num_stages)
        layers_per_stage = [
            base + (1 if stage < extra else 0) for stage in range(config.num_stages)
        ]
        assert sum(layers_per_stage) == config.num_hidden_layers
        first_layer = sum(layers_per_stage[: config.stage_idx])
        last_layer = first_layer + layers_per_stage[config.stage_idx]
        for layer_id in range(first_layer, last_layer):
            self.layers[str(layer_id)] = DecoderLayer(config, layer_id)

        # Final RMSNorm lives on the last stage only.
        if config.stage_idx == config.num_stages - 1:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = None

        # Initialize weights and apply final processing
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Normal(0, initializer_range) init for Linear/Embedding; zero bias
        and zero the padding embedding row."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def forward(
        self,
        tokens: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """Run this stage's slice of the decoder stack."""
        # The first stage embeds token ids; later stages receive hidden
        # states directly from the previous stage.
        if self.embed_tokens is not None:
            hidden_states = self.embed_tokens(tokens)
        else:
            hidden_states = tokens

        # Decoder layers owned by this stage.
        for layer in self.layers.values():
            hidden_states = layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
            )

        # The last stage applies the final norm.
        if self.norm is not None:
            hidden_states = self.norm(hidden_states)
        return hidden_states
|
| 1202 |
+
|
| 1203 |
+
|
| 1204 |
+
class DeepseekForCausalLM(torch.nn.Module):
    """Causal-LM wrapper around `DeepseekModel` for pipeline-parallel use:
    only the last stage owns `lm_head`; earlier stages return hidden states
    for the next stage."""

    def __init__(self, config):
        super().__init__()
        self.model = DeepseekModel(config)
        # Vocabulary projection exists on the final pipeline stage only.
        self.lm_head = (
            nn.Linear(config.hidden_size, config.vocab_size, bias=False)
            if config.stage_idx == config.num_stages - 1
            else None
        )

        # Initialize weights and apply final processing
        # self.post_init()

    def forward(
        self,
        tokens: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        r"""Run the model; returns logits on the last stage, hidden states on
        earlier stages.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, DeepseekForCausalLM

        >>> model = DeepseekForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        hidden_states = self.model(
            tokens,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )

        # Non-final stages pass hidden states through unchanged.
        logits = (
            self.lm_head(hidden_states) if self.lm_head is not None else hidden_states
        )
        return logits

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        **kwargs,
    ):
        """Trim `input_ids`/`attention_mask` against the KV cache and assemble
        the kwargs dict for the next decoding step.

        NOTE(review): relies on the Cache API (`get_seq_length`,
        `seen_tokens`, `get_max_length`) — `seen_tokens`/`get_max_length` are
        deprecated in recent transformers releases; confirm against the
        pinned version.
        """
        if past_key_values is not None:
            # Assuming isinstance(past_key_values, Cache):
            cache_length = past_key_values.get_seq_length()
            past_length = past_key_values.seen_tokens
            max_cache_length = past_key_values.get_max_length()

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as
            # input)
            if (
                attention_mask is not None
                and attention_mask.shape[1] > input_ids.shape[1]
            ):
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Reorder each layer's cached states along the batch dim to follow
        `beam_idx` (beam-search bookkeeping, legacy tuple cache format)."""
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(
                    past_state.index_select(0, beam_idx.to(past_state.device))
                    for past_state in layer_past
                ),
            )
        return reordered_past

    # Setup Symmetric Memory for MoE token shuffle.
    # Supports inference currently.
    def setup_symm_mem(self, dtype: torch.dtype, device: torch.device):
        """Allocate symmetric-memory buffers for every MoE layer on this stage."""
        for layer in self.model.layers.values():
            # Dense-MLP layers need no token shuffle buffers.
            if not isinstance(layer.mlp, MoE):
                continue
            layer.mlp.setup_symm_mem(dtype, device)
|
torchtitan/experiments/deepseek_v3/symm_mem_recipes/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .triton_on_device_all_to_all_v import OnDeviceAllToAllV

# Public API of the symm_mem_recipes package.
__all__ = [
    "OnDeviceAllToAllV",
]
|
torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_barrier.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import triton
|
| 8 |
+
import triton.language as tl
|
| 9 |
+
|
| 10 |
+
from .triton_utils import get_flat_bid, get_flat_tid
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@triton.jit
def send_signal(addrs, sem: tl.constexpr):
    """Flip each 32-bit flag at `addrs` from 0 to 1 via a CAS loop, spinning
    until the peer has consumed the previous signal (flag back at 0).

    `sem` selects the atomic's memory ordering: "relaxed" (no ordering) or
    "acq_rel" (release semantics on the signaling side).
    """
    if sem == "relaxed":
        tl.inline_asm_elementwise(
            """
            {
                .reg .u32   %tmp32_<1>;
                .reg .pred  %p<1>;

                send_signal:
                    atom.global.relaxed.sys.cas.b32 %tmp32_0, [$1], 0, 1;
                    setp.eq.u32 %p0, %tmp32_0, 0;
                    @!%p0 bra send_signal;
            }
            """,
            "=r, l",
            [addrs],
            dtype=tl.int32,
            is_pure=False,
            pack=1,
        )
    elif sem == "acq_rel":
        tl.inline_asm_elementwise(
            """
            {
                .reg .u32   %tmp32_<1>;
                .reg .pred  %p<1>;

                send_signal:
                    atom.global.release.sys.cas.b32 %tmp32_0, [$1], 0, 1;
                    setp.eq.u32 %p0, %tmp32_0, 0;
                    @!%p0 bra send_signal;
            }
            """,
            "=r, l",
            [addrs],
            dtype=tl.int32,
            is_pure=False,
            pack=1,
        )
    else:
        raise RuntimeError(f"Unrecognized sem: {sem}")
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
@triton.jit
def wait_signal(addrs, sem: tl.constexpr):
    """Spin on each 32-bit flag at `addrs` until it becomes 1, then reset it
    to 0 (CAS 1 -> 0), consuming the peer's signal.

    `sem` selects the atomic's memory ordering: "relaxed" (no ordering) or
    "acq_rel" (acquire semantics on the waiting side).
    """
    if sem == "relaxed":
        tl.inline_asm_elementwise(
            """
            {
                .reg .u32   %tmp32_<1>;
                .reg .pred  %p<1>;

                wait_signal:
                    atom.global.sys.relaxed.cas.b32 %tmp32_0, [$1], 1, 0;
                    setp.eq.u32 %p0, %tmp32_0, 1;
                    @!%p0 bra wait_signal;
            }
            """,
            "=r, l",
            [addrs],
            dtype=tl.int32,
            is_pure=False,
            pack=1,
        )
    elif sem == "acq_rel":
        tl.inline_asm_elementwise(
            """
            {
                .reg .u32   %tmp32_<1>;
                .reg .pred  %p<1>;

                wait_signal:
                    atom.global.sys.acquire.cas.b32 %tmp32_0, [$1], 1, 0;
                    setp.eq.u32 %p0, %tmp32_0, 1;
                    @!%p0 bra wait_signal;
            }
            """,
            "=r, l",
            [addrs],
            dtype=tl.int32,
            is_pure=False,
            pack=1,
        )
    else:
        raise RuntimeError(f"Unrecognized sem: {sem}")
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
@triton.jit
def blockwise_barrier(
    signal_pad_ptrs,
    block_id,
    rank: tl.constexpr,
    world_size: tl.constexpr,
    sem: tl.constexpr,
):
    """
    Synchronizes blocks with matching block_id across participating devices.

    Note: the function itself is not a system level barrier/fence. It is a
    building block for expressing different synchronization patterns.

    Pattern 0: Ensures that all writes to symm_mem buffers from previous
    kernels across all devices are visible to the current kernel:

        blockwise_barrier(..., sem="relaxed")
        sync_threads()

    Pattern 1: Ensures that all writes to symm_mem buffers from the current
    block are visible to all remote blocks with matching blockIdx:

        sync_threads()
        blockwise_barrier(..., sem="acq_rel")
        sync_threads()

    Pattern 2: Ensures that symm_mem buffers read by the current kernel are safe
    for writing by subsequent kernels across all devices.

        sync_threads()
        blockwise_barrier(..., sem="relaxed")

    CUDA graph friendliness:

    This barrier operates through atomic operations on a zero-filled signal
    pad, which resets to a zero-filled state after each successful
    synchronization. This design eliminates the need for incrementing a
    flag from host.
    """
    if block_id is None:
        block_id = get_flat_bid()
    flat_tid = get_flat_tid()

    remote_ranks = tl.arange(0, world_size)
    # signal_pad_ptrs is an array of per-rank pad addresses stored as uint64.
    signal_pad_ptrs = signal_pad_ptrs.to(tl.pointer_type(tl.uint64))
    remote_signal_pad_addrs = tl.load(signal_pad_ptrs + remote_ranks).to(
        tl.pointer_type(tl.uint32)
    )
    # Each block owns a row of `world_size` u32 flags in every rank's pad;
    # slot [block_id, rank] on a remote pad is where this rank signals it.
    send_addrs = remote_signal_pad_addrs + block_id * world_size + rank

    local_signal_pad_addr = tl.load(signal_pad_ptrs + rank).to(
        tl.pointer_type(tl.uint32)
    )
    # Slot [block_id, r] on the local pad is where remote rank r signals us.
    wait_addrs = local_signal_pad_addr + block_id * world_size + remote_ranks

    # One thread per peer rank performs the signal/wait pair; the rest of the
    # block relies on the caller's surrounding sync_threads().
    if flat_tid < world_size:
        send_signal(send_addrs, sem)
        wait_signal(wait_addrs, sem)
|
torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_on_device_all_to_all_v.py
ADDED
|
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.distributed as dist
|
| 9 |
+
import torch.distributed._symmetric_memory as symm_mem
|
| 10 |
+
import triton
|
| 11 |
+
import triton.language as tl
|
| 12 |
+
|
| 13 |
+
from .triton_barrier import blockwise_barrier
|
| 14 |
+
from .triton_utils import sync_threads
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@triton.jit
def _exchange_row_offsets(
    split_sizes_ptrs,
    rank: tl.constexpr,
    world_size: tl.constexpr,
    BLOCKS_PER_REMOTE_RANK: tl.constexpr,
):
    """For the remote rank this block serves, compute where its rows for us
    start in the remote input buffer, where they land in our local output,
    and how many rows there are.

    Returns `(input_row_offset, output_row_offset, num_rows)`.
    """
    remote_rank = tl.program_id(0) // BLOCKS_PER_REMOTE_RANK

    # split_sizes_ptr for all ranks
    # All these vector stacks into split_sizes_matrix
    split_sizes_ptrs = split_sizes_ptrs.to(tl.pointer_type(tl.uint64))

    # split_sizes_matrix[remote_rank, :]
    input_split_sizes_ptr = tl.load(split_sizes_ptrs + remote_rank).to(
        tl.pointer_type(tl.int64)
    )

    offsets_ = tl.arange(0, world_size)
    # Inclusive prefix sum over ranks <= `rank`, minus our own count, gives
    # the start row of our chunk inside the remote rank's input.
    input_split_sizes = tl.load(
        input_split_sizes_ptr + offsets_, mask=offsets_ <= rank, other=0
    )

    num_rows = tl.load(input_split_sizes_ptr + rank)
    input_row_offset = tl.sum(input_split_sizes) - num_rows

    # split_sizes_matrix[:, rank]
    output_split_sizes_ptrs = (
        tl.load(split_sizes_ptrs + offsets_).to(tl.pointer_type(tl.int64)) + rank
    )
    # Same prefix-sum trick along the column gives the destination offset in
    # our local output buffer.
    output_split_sizes = tl.load(
        output_split_sizes_ptrs, mask=offsets_ <= remote_rank, other=0
    )
    output_row_offset = tl.sum(output_split_sizes) - num_rows

    return input_row_offset, output_row_offset, num_rows
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
@triton.jit
def on_device_all_to_all_v_kernel(
    output_ptr,
    output_splits_ptr,
    input_ptrs,
    input_splits_ptr,
    signal_pad_ptrs,
    dim: tl.constexpr,  # Separate dim for easier vectorization
    rank: tl.constexpr,
    world_size: tl.constexpr,
    BLOCKS_PER_REMOTE_RANK: tl.constexpr,
    UNROLL_FACTOR: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Pull-style all-to-all-v: each group of BLOCKS_PER_REMOTE_RANK blocks
    copies this rank's rows out of one remote rank's input buffer into the
    local output, framed by entry/exit barriers."""
    # Entry barrier: make peers' prior writes to symm_mem buffers visible.
    blockwise_barrier(signal_pad_ptrs, None, rank, world_size, sem="relaxed")
    sync_threads()

    remote_rank = tl.program_id(0) // BLOCKS_PER_REMOTE_RANK
    block_offset = tl.program_id(0) % BLOCKS_PER_REMOTE_RANK

    input_row_offset, output_row_offset, num_rows = _exchange_row_offsets(
        input_splits_ptr, rank, world_size, BLOCKS_PER_REMOTE_RANK
    )

    output_splits_ptr = output_splits_ptr.to(tl.pointer_type(tl.uint64))
    if block_offset == 0:
        # Update output_splits
        tl.store(output_splits_ptr + remote_rank, num_rows)

    # Payload is addressed as bfloat16 — NOTE(review): confirm callers only
    # pass bf16 tensors through this path.
    input_ptr = (
        tl.load(input_ptrs.to(tl.pointer_type(tl.uint64)) + remote_rank).to(
            tl.pointer_type(tl.bfloat16)
        )
        + input_row_offset * dim
    )
    output_ptr = output_ptr + output_row_offset * dim

    # Partition num_rows * dim elements among the blocks serving this remote
    # rank, in chunks of BLOCK_SIZE * UNROLL_FACTOR.
    outer_loop_step = BLOCK_SIZE * UNROLL_FACTOR
    outer_loop_iters_per_rank = tl.cdiv(
        tl.cdiv(num_rows * dim, outer_loop_step), BLOCKS_PER_REMOTE_RANK
    )
    numel_per_rank = outer_loop_step * outer_loop_iters_per_rank
    offset = numel_per_rank * block_offset
    end = tl.minimum(numel_per_rank * (block_offset + 1), num_rows * dim)

    # Bulk copy: full, mask-free, unrolled chunks.
    unroll_region_size = (end - offset) // outer_loop_step * outer_loop_step
    for i in tl.range(offset, offset + unroll_region_size, outer_loop_step):
        datas = []  # NOTE(review): appears unused — confirm it can be dropped
        for j in tl.range(
            i,
            i + outer_loop_step,
            BLOCK_SIZE,
            loop_unroll_factor=UNROLL_FACTOR,
        ):
            offsets = j + tl.arange(0, BLOCK_SIZE)
            data = tl.load(input_ptr + offsets)
            tl.store(output_ptr + offsets, data)

    # Tail copy: masked, one BLOCK_SIZE chunk at a time.
    offset += unroll_region_size
    while offset < end:
        offsets = offset + tl.arange(0, BLOCK_SIZE)
        mask = offsets < num_rows * dim
        data = tl.load(input_ptr + offsets, mask=mask)
        tl.store(output_ptr + offsets, data, mask=mask)
        offset += BLOCK_SIZE

    # Exit barrier: our reads are done; peers may safely reuse their buffers.
    sync_threads()
    blockwise_barrier(signal_pad_ptrs, None, rank, world_size, sem="relaxed")
    return
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def _on_device_all_to_all_v(
    output: torch.Tensor,
    output_splits: torch.Tensor,
    input: torch.Tensor,
    input_splits: torch.Tensor,
    group: dist.ProcessGroup = dist.group.WORLD,
    BLOCKS_PER_REMOTE_RANK=8,
    UNROLL_FACTOR: int = 8,
    BLOCK_SIZE: int = 16384,
):
    """Launch the on-device all-to-all-v Triton kernel.

    `input` and `input_splits` must be symmetric-memory tensors (they are
    rendezvous'd here). `output` receives rows grouped by source rank and
    `output_splits` is filled with the per-rank received row counts.
    Returns `output`.
    """
    assert output.dim() == 2, f"{output.shape}"
    assert input.dim() == 2, f"{input.shape}"
    assert output.shape[1] == input.shape[1]

    dim = output.shape[1]
    input_hdl = symm_mem.rendezvous(input, group=group)
    input_splits_hdl = symm_mem.rendezvous(input_splits, group=group)

    # One group of BLOCKS_PER_REMOTE_RANK blocks per peer rank.
    num_blocks = input_hdl.world_size * BLOCKS_PER_REMOTE_RANK
    kernel = on_device_all_to_all_v_kernel[(num_blocks, 1, 1)](
        output,
        output_splits,
        input_hdl.buffer_ptrs_dev,
        input_splits_hdl.buffer_ptrs_dev,
        input_hdl.signal_pad_ptrs_dev,
        dim=dim,
        rank=input_hdl.rank,
        world_size=input_hdl.world_size,
        BLOCKS_PER_REMOTE_RANK=BLOCKS_PER_REMOTE_RANK,
        UNROLL_FACTOR=UNROLL_FACTOR,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=16,
    )
    # log_triton_kernel(kernel)
    return output
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
class OnDeviceAllToAllV(torch.autograd.Function):
|
| 164 |
+
# A symmetric memory holding the grad_output during backward
|
| 165 |
+
grad_output_buf = None
|
| 166 |
+
# A symmetric memory for exchanges split sizes during both forward and backward
|
| 167 |
+
splits_buf = None
|
| 168 |
+
# Maximum output length (need to be set before use of OnDeviceAllToAllV)
|
| 169 |
+
max_output_len = None
|
| 170 |
+
|
| 171 |
+
@staticmethod
|
| 172 |
+
def forward(
|
| 173 |
+
ctx,
|
| 174 |
+
input: torch.Tensor,
|
| 175 |
+
input_splits: torch.Tensor,
|
| 176 |
+
group: dist.ProcessGroup = dist.group.WORLD,
|
| 177 |
+
):
|
| 178 |
+
"""
|
| 179 |
+
Args:
|
| 180 |
+
input: input tensor with data for all ranks concatenated.
|
| 181 |
+
input_splits: input splits of shape (group.world_size,)
|
| 182 |
+
group: process group to scope the collective.
|
| 183 |
+
"""
|
| 184 |
+
# Initialize input splits buffer (one time only)
|
| 185 |
+
if OnDeviceAllToAllV.splits_buf is None:
|
| 186 |
+
OnDeviceAllToAllV.splits_buf = symm_mem.empty(
|
| 187 |
+
*input_splits.shape,
|
| 188 |
+
dtype=input_splits.dtype,
|
| 189 |
+
device=input_splits.device,
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
if OnDeviceAllToAllV.max_output_len is None:
|
| 193 |
+
raise RuntimeError(
|
| 194 |
+
"Please set max output length via `OnDeviceAllToAllV.max_output_len = ...`"
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
# Allocate output buffer
|
| 198 |
+
output = input.new_empty(OnDeviceAllToAllV.max_output_len, *input.shape[1:])
|
| 199 |
+
# Allocate output splits tensor
|
| 200 |
+
output_splits = torch.empty_like(input_splits)
|
| 201 |
+
# Copy input splits to the buffer
|
| 202 |
+
OnDeviceAllToAllV.splits_buf.copy_(input_splits)
|
| 203 |
+
|
| 204 |
+
# Shuffle input to output
|
| 205 |
+
_on_device_all_to_all_v(
|
| 206 |
+
output, output_splits, input, OnDeviceAllToAllV.splits_buf, group=group
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
# Output splits in forward is the input splits in backward
|
| 210 |
+
ctx.save_for_backward(output_splits)
|
| 211 |
+
ctx.group = group
|
| 212 |
+
ctx.input_shape = input.shape
|
| 213 |
+
return output, output_splits
|
| 214 |
+
|
| 215 |
+
    @staticmethod
    def backward(ctx, grad_output, grad_splits):
        """
        Backward is implemented as a shuffle of the output's gradients to the input.

        Args:
            `grad_output`: output's gradients passed from the downstream.
            `grad_splits`: unused.

        Returns:
            (grad_input, None, None) — gradients only flow to `input`.
        """

        # Initialize grad_output buffer (one time only).
        # NOTE(review): `grad_output_buf` appears to be a class attribute declared
        # alongside `splits_buf` above this excerpt — confirm.
        if OnDeviceAllToAllV.grad_output_buf is None:
            assert (
                OnDeviceAllToAllV.max_output_len is not None
            ), "`max_output_len` not set"
            OnDeviceAllToAllV.grad_output_buf = symm_mem.empty(
                OnDeviceAllToAllV.max_output_len,
                *grad_output.shape[1:],
                dtype=grad_output.dtype,
                device=grad_output.device,
            )

        # TODO: is there a way to tell autograd to feed grad_output directly to
        # our symm_mem buffer?
        # Stage the incoming gradients into the symmetric-memory buffer (only the
        # first grad_output.shape[0] rows are meaningful).
        OnDeviceAllToAllV.grad_output_buf.narrow(0, 0, grad_output.shape[0]).copy_(
            grad_output
        )

        # Size info: the forward's output splits are this pass's send splits.
        (grad_output_splits,) = ctx.saved_tensors
        OnDeviceAllToAllV.splits_buf.copy_(grad_output_splits)
        grad_input_splits = torch.empty_like(grad_output_splits)  # unused
        grad_input = grad_output.new_empty(*ctx.input_shape)

        # Shuffle gradients back to the input (inverse of the forward exchange).
        _on_device_all_to_all_v(
            grad_input,
            grad_input_splits,
            OnDeviceAllToAllV.grad_output_buf,
            OnDeviceAllToAllV.splits_buf,
            group=ctx.group,
        )
        return grad_input, None, None
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
# Alias: functional entry point. Prefer calling
# `on_device_all_to_all_v(input, input_splits, group)` over invoking
# `OnDeviceAllToAllV.apply` directly.
on_device_all_to_all_v = OnDeviceAllToAllV.apply
|
torchtitan/experiments/deepseek_v3/train.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# torchrun --standalone --nproc-per-node 8 run.py
|
| 8 |
+
import torch
|
| 9 |
+
import torch.distributed as dist
|
| 10 |
+
from checkpoint import load_weights_from_hf
|
| 11 |
+
from model import DeepseekForCausalLM
|
| 12 |
+
from model_config import deepseek_config_registry
|
| 13 |
+
|
| 14 |
+
from torch.distributed.device_mesh import DeviceMesh
|
| 15 |
+
from torch.distributed.fsdp import fully_shard
|
| 16 |
+
from torch.distributed.pipelining import PipelineStage, Schedule1F1B
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# Use DeepSeek-V2-Lite as a proxy
model_id = "deepseek-ai/DeepSeek-V2-Lite"


# Run full model
def run_full_model(
    mesh: DeviceMesh,
):
    """Train a (shrunk) DeepSeek model for a couple of synthetic steps using
    3-D parallelism: pipeline ("pp"), expert ("ep") and FSDP ("fsdp") mesh dims.

    Args:
        mesh: 3-D device mesh with dims named ("pp", "ep", "fsdp").
    """
    rank = dist.get_rank()
    device_count = torch.cuda.device_count()
    # Map global rank onto a local CUDA device.
    device = torch.device("cuda", rank % device_count)

    pp_mesh = mesh["pp"]
    ep_mesh = mesh["ep"]
    pp_rank = pp_mesh.get_local_rank()
    ep_rank = ep_mesh.get_local_rank()
    pp_size = pp_mesh.size()
    ep_size = ep_mesh.size()

    # Get model configs
    model_args = deepseek_config_registry[model_id]
    # [Note]: I am making the model smaller for testing / avoiding OOM. If you
    # have sufficient GPUs for model parallelism, you can remove this line.
    model_args.num_hidden_layers = 16

    # Apply model parallelism
    model_args.ep_size = ep_size
    model_args.num_stages = pp_size
    model_args.stage_idx = pp_rank
    print(model_args)

    # Instantiate model (inside device/mesh context so parameters land on the
    # right device and mesh-aware modules pick up the parallel layout).
    with device, mesh:
        model = DeepseekForCausalLM(model_args)

    # Load weights
    load_weights_from_hf(model, model_id, device)
    model.train()

    # Apply data parallelism
    fsdp_mesh = mesh["fsdp"]
    hsdp_mesh = mesh["ep", "fsdp"]
    # Using `reshard_after_forward=False` to implement Zero-2, i.e. sharding the
    # optimizer (Zero-1) and gradients (Zero-2), but not the model weights.
    # Reason: the MoE is "sparsely activated" compared to the dense model, thus
    # it would be uneconomical to re-gather the weights.
    for layer in model.model.layers.values():
        # Apply FSDP to experts
        if hasattr(layer.mlp, "experts"):
            for expert in layer.mlp.experts.values():
                fully_shard(expert, mesh=fsdp_mesh, reshard_after_forward=False)
        # Apply HSDP to other parts such as attention, layernorm, because they
        # are doing DDP on EP dimension
        fully_shard(layer, mesh=hsdp_mesh, reshard_after_forward=False)

    # Apply HSDP on root model (lm_head, embeddings, etc)
    fully_shard(model, mesh=hsdp_mesh, reshard_after_forward=False)

    # Synthetic setting
    microbatches = pp_size * 2

    # Use Symmetric Memory for MoE token shuffle.
    # TODO: we are rewriting `moe_on_device` function. `setup_symm_mem` is
    # currently supported for forward only. See `generate.py`.
    # model.setup_symm_mem(torch.bfloat16, device)

    # Example inputs (seeded per EP rank so ranks see different data).
    torch.manual_seed(ep_rank)
    bs = 4
    seqlen = 128
    x = torch.randint(model_args.vocab_size, (microbatches * bs, seqlen), device=device)
    label = torch.rand(microbatches * bs, seqlen, model_args.vocab_size, device=device)

    # Create loss function
    # NOTE(review): cross_entropy with a floating-point target expects the class
    # dimension at dim 1, but `label` has vocab at dim 2 — confirm this is the
    # intended (synthetic) usage.
    loss_fn = torch.nn.functional.cross_entropy

    # Run forward and backward
    steps = 2
    for _ in range(steps):
        if pp_size > 1:
            # Create pipeline stage
            # NOTE(review): the stage and schedule are rebuilt every step;
            # presumably they could be hoisted out of the loop — confirm.
            stage = PipelineStage(
                model,
                pp_rank,
                pp_size,
                device,
                group=pp_mesh.get_group(),
            )

            # Create pipeline schedule
            losses = []
            pp_schedule = Schedule1F1B(stage, microbatches, loss_fn=loss_fn)

            # Only the first stage feeds inputs; only the last stage receives
            # targets and collects per-microbatch losses.
            if pp_rank == 0:
                y = pp_schedule.step(x)
            elif pp_rank == pp_size - 1:
                y = pp_schedule.step(target=label, losses=losses)
                loss = torch.mean(torch.stack(losses))
            else:
                pp_schedule.step()
        else:
            # No pipeline parallelism: plain forward/backward.
            y = model(x)
            loss = loss_fn(y, label)
            loss.backward()

        if pp_rank == pp_size - 1:
            print(f"logits: {y.shape}")
            print(f"{loss=}")

        if pp_rank == 0:
            # Spot-check that gradients reached an early-layer parameter.
            param = model.get_parameter("model.layers.0.self_attn.q_proj.weight")
            print(f"{torch.linalg.norm(param.grad)=}")

        model.zero_grad()

    print("Backward done")
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
if __name__ == "__main__":
    # 2 (pp) x 2 (ep) x 2 (fsdp) mesh — requires 8 GPUs
    # (launch with: torchrun --standalone --nproc-per-node 8 ...).
    mesh = dist.init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("pp", "ep", "fsdp"))

    run_full_model(mesh)

    dist.destroy_process_group()
|
torchtitan/experiments/flux/README.md
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# FLUX model in torchtitan
|
| 2 |
+
|
| 3 |
+
## Overview
|
| 4 |
+
|
| 5 |
+
## Usage
|
| 6 |
+
First, download the autoencoder model from HuggingFace with your own access token:
|
| 7 |
+
```bash
|
| 8 |
+
python torchtitan/experiments/flux/scripts/download_autoencoder.py --repo_id black-forest-labs/FLUX.1-dev --ae_path ae.safetensors --hf_token <your_access_token>
|
| 9 |
+
```
|
| 10 |
+
This step will download the autoencoder model from HuggingFace and save it to the `torchtitan/experiments/flux/assets/autoencoder/ae.safetensors` file.
|
| 11 |
+
|
| 12 |
+
Run the following command to train the model on a single GPU:
|
| 13 |
+
```bash
|
| 14 |
+
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True torchrun --nproc_per_node=1 torchtitan/experiments/flux/train.py --job.config_file torchtitan/experiments/flux/train_configs/debug_model.toml
|
| 15 |
+
```
|
| 16 |
+
|
| 17 |
+
## TODO
|
| 18 |
+
- [ ] Support for multiple GPUs is coming soon (FSDP, etc.)
|
| 19 |
+
- [ ] Implement test cases in CI for the FLUX model. Add more unit tests for the FLUX model (e.g., a unit test for the preprocessor)
|
| 20 |
+
- [ ] More parallelism support (Tensor Parallelism, Context Parallelism, etc.)
|
| 21 |
+
- [ ] Support for distributed checkpointing and loading
|
| 22 |
+
- [ ] Implement init_weights() function to initialize the model weights
|
| 23 |
+
- [ ] Implement the num_flops_per_token calculation in get_nparams_and_flops() function
|
torchtitan/experiments/flux/__init__.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
#
|
| 7 |
+
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.
|
| 8 |
+
|
| 9 |
+
from torchtitan.components.lr_scheduler import build_lr_schedulers
|
| 10 |
+
from torchtitan.components.optimizer import build_optimizers
|
| 11 |
+
from torchtitan.experiments.flux.dataset.flux_dataset import build_flux_dataloader
|
| 12 |
+
from torchtitan.experiments.flux.loss import build_mse_loss
|
| 13 |
+
from torchtitan.experiments.flux.model.autoencoder import AutoEncoderParams
|
| 14 |
+
from torchtitan.experiments.flux.parallelize_flux import parallelize_flux
|
| 15 |
+
from torchtitan.protocols.train_spec import register_train_spec, TrainSpec
|
| 16 |
+
|
| 17 |
+
from .model.model import FluxModel, FluxModelArgs
|
| 18 |
+
|
| 19 |
+
__all__ = [
    "FluxModelArgs",
    "FluxModel",
    "flux_configs",
    "parallelize_flux",
]


# Named FLUX model configurations, keyed by the `--model.flavor` job-config value.
flux_configs = {
    # Full-size model with guidance embedding enabled.
    "flux-dev": FluxModelArgs(
        in_channels=64,
        out_channels=64,
        vec_in_dim=768,
        context_in_dim=512,
        hidden_size=3072,
        mlp_ratio=4.0,
        num_heads=24,
        depth=19,
        depth_single_blocks=38,
        axes_dim=(16, 56, 56),
        theta=10_000,
        qkv_bias=True,
        guidance_embed=True,
        autoencoder_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=(1, 2, 4, 4),
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    # Schnell variant: larger T5 context (4096) and no guidance embedding.
    "flux-schnell": FluxModelArgs(
        in_channels=64,
        out_channels=64,
        vec_in_dim=768,
        context_in_dim=4096,
        hidden_size=3072,
        mlp_ratio=4.0,
        num_heads=24,
        depth=19,
        depth_single_blocks=38,
        axes_dim=(16, 56, 56),
        theta=10_000,
        qkv_bias=True,
        guidance_embed=False,
        autoencoder_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=(1, 2, 4, 4),
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    # Tiny configuration for smoke tests and debugging.
    "flux-debug": FluxModelArgs(
        in_channels=64,
        out_channels=64,
        vec_in_dim=768,
        context_in_dim=512,
        hidden_size=512,
        mlp_ratio=4.0,
        num_heads=4,
        depth=2,
        depth_single_blocks=2,
        axes_dim=(16, 56, 56),
        theta=10_000,
        qkv_bias=True,
        guidance_embed=True,
        autoencoder_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=(1, 2, 4, 4),
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
}


# Register the FLUX training spec so torchtitan can build the model, optimizer,
# dataloader and loss from the "flux" spec name.
register_train_spec(
    TrainSpec(
        name="flux",
        cls=FluxModel,
        config=flux_configs,
        parallelize_fn=parallelize_flux,
        pipelining_fn=None,  # pipeline parallelism not supported yet
        build_optimizers_fn=build_optimizers,
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_flux_dataloader,
        build_tokenizer_fn=None,  # tokenization is handled by the encoders
        build_loss_fn=build_mse_loss,
    )
)
|
torchtitan/experiments/flux/flux_argparser.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import argparse
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def _parse_dtype(value: str) -> torch.dtype:
    """Convert a dtype name from the command line into a ``torch.dtype``.

    Accepts plain names ("bfloat16", "float16") or "torch."-prefixed names
    ("torch.bfloat16").

    Raises:
        argparse.ArgumentTypeError: if the name does not denote a torch dtype.
    """
    name = value.removeprefix("torch.")
    dtype = getattr(torch, name, None)
    if not isinstance(dtype, torch.dtype):
        raise argparse.ArgumentTypeError(f"invalid torch dtype: {value!r}")
    return dtype


def extend_parser(parser: argparse.ArgumentParser) -> None:
    """Register FLUX-specific command-line options on *parser*.

    Adds the guidance-distillation value, the T5/CLIP encoder model names,
    the encoder dtype, and the maximum T5 encoding length.
    """
    parser.add_argument(
        "--training.guidance",
        type=float,
        default=3.5,
        help="guidance value used for guidance distillation",
    )
    parser.add_argument(
        "--encoder.t5_encoder",
        type=str,
        default="google/t5-v1_1-small",
        help="T5 encoder to use, HuggingFace model name.",
    )
    parser.add_argument(
        "--encoder.clip_encoder",
        type=str,
        default="openai/clip-vit-large-patch14",
        help="Clip encoder to use, HuggingFace model name.",
    )
    parser.add_argument(
        "--encoder.encoder_dtype",
        # BUGFIX: the original used `type=torch.dtype`, which argparse calls on
        # the argument string; `torch.dtype` is not constructible from a string,
        # so any value passed on the command line raised. Parse the name instead.
        # The default remains the dtype object itself (argparse only applies
        # `type` to string values), so existing behavior is preserved.
        type=_parse_dtype,
        default=torch.bfloat16,
        help="Which dtype to load for autoencoder. ",
    )
    parser.add_argument(
        "--encoder.max_t5_encoding_len",
        type=int,
        default=512,
        help="Maximum length of the T5 encoding.",
    )
|
torchtitan/experiments/flux/loss.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from typing import Callable, TypeAlias
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from torchtitan.config_manager import JobConfig
|
| 12 |
+
from torchtitan.tools.logging import logger
|
| 13 |
+
|
| 14 |
+
LossFunction: TypeAlias = Callable[..., torch.Tensor]
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def mse_loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Mean-squared-error loss in float32, with gradients blocked on the target."""
    prediction = pred.float()
    target = labels.float().detach()
    return torch.nn.functional.mse_loss(prediction, target)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def build_mse_loss(job_config: JobConfig):
    """Return the MSE loss function, compiled with torch.compile when the
    job config enables training compilation."""
    if not job_config.training.compile:
        return mse_loss
    logger.info("Compiling the loss function with torch.compile")
    return torch.compile(mse_loss)
|
torchtitan/experiments/flux/parallelize_flux.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# This file applies the PT-D parallelisms (except pipeline parallelism) and various
|
| 8 |
+
# training techniques (e.g. activation checkpointing and compile) to the Llama model.
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
|
| 13 |
+
from torch.distributed.device_mesh import DeviceMesh
|
| 14 |
+
|
| 15 |
+
from torchtitan.config_manager import JobConfig
|
| 16 |
+
from torchtitan.distributed import ParallelDims
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def parallelize_flux(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
    """Apply parallelism and training techniques to the FLUX model.

    Currently a no-op placeholder that returns the model unchanged; the
    signature matches torchtitan's `parallelize_fn` protocol so strategies
    (FSDP, TP, activation checkpointing, ...) can be added here later.

    Args:
        model: the FLUX model to parallelize.
        world_mesh: global device mesh.
        parallel_dims: requested parallelism degrees.
        job_config: full job configuration.

    Returns:
        The (currently unmodified) model.
    """
    # TODO: Add model parallel strategy here
    return model
|
torchtitan/experiments/flux/requirements.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
transformers
|
| 2 |
+
einops
|
torchtitan/experiments/flux/scripts/download_autoencoder.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from typing import Optional
|
| 8 |
+
|
| 9 |
+
from requests.exceptions import HTTPError
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def hf_download(
    repo_id: str, file_path: str, local_dir: str, hf_token: Optional[str] = None
) -> None:
    """Download a single file from a HuggingFace Hub repository.

    Args:
        repo_id: repository to download from (e.g. "black-forest-labs/FLUX.1-dev").
        file_path: path of the file relative to the repository root.
        local_dir: local directory to save the file into.
        hf_token: optional HuggingFace API token for private repositories.

    Raises:
        requests.exceptions.HTTPError: for HTTP failures other than 401
            (a 401 only prints a hint about passing `--hf_token`).
    """
    # Imported lazily so the module can be imported without huggingface_hub.
    from huggingface_hub import hf_hub_download

    try:
        hf_hub_download(
            repo_id=repo_id,
            filename=file_path,
            local_dir=local_dir,
            local_dir_use_symlinks=False,
            token=hf_token,
        )
    except HTTPError as e:
        if e.response.status_code == 401:
            print(
                "You need to pass a valid `--hf_token=...` to download private checkpoints."
            )
        else:
            # BUGFIX: use a bare `raise` instead of `raise e` so the original
            # traceback and exception context are preserved.
            raise
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
if __name__ == "__main__":
    import argparse

    # CLI wrapper around hf_download for fetching the FLUX autoencoder weights.
    # NOTE(review): the description says "tokenizer" but the script downloads
    # the autoencoder — looks like copy-paste from another script; confirm.
    parser = argparse.ArgumentParser(description="Download tokenizer from HuggingFace.")
    parser.add_argument(
        "--repo_id",
        type=str,
        default="black-forest-labs/FLUX.1-dev",
        help="Repository ID to download from. default to Flux-dev model",
    )
    parser.add_argument(
        "--ae_path",
        type=str,
        default="ae.safetensors",
        help="the autoencoder path relative to repo_id",
    )
    parser.add_argument(
        "--hf_token", type=str, default=None, help="HuggingFace API token"
    )
    parser.add_argument(
        "--local_dir",
        type=str,
        default="torchtitan/experiments/flux/assets/autoencoder/",
        help="local directory to save the autoencoder",
    )

    args = parser.parse_args()
    hf_download(args.repo_id, args.ae_path, args.local_dir, args.hf_token)
|
torchtitan/experiments/flux/tests/test_flux_dataloader.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import sys
|
| 8 |
+
|
| 9 |
+
from torchtitan.config_manager import JobConfig
|
| 10 |
+
from torchtitan.experiments.flux.dataset.flux_dataset import build_flux_dataloader
|
| 11 |
+
from torchtitan.tools.profiling import (
|
| 12 |
+
maybe_enable_memory_snapshot,
|
| 13 |
+
maybe_enable_profiling,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class TestFluxDataLoader:
    """Smoke tests for the FLUX dataloader.

    NOTE(review): these tests pull real HF datasets/encoders and mutate
    sys.argv; presumably meant for manual runs rather than hermetic CI —
    confirm.
    """

    def test_flux_dataloader(self):
        # Dataset / sharding parameters for the smoke run.
        dataset_name = "cc12m"
        batch_size = 32
        world_size = 4
        rank = 0

        num_steps = 10

        # Register the FLUX-specific CLI options via the custom-args module,
        # then build the job config from an explicit argument list.
        path = "torchtitan.experiments.flux.flux_argparser"
        sys.argv.append(f"--experimental.custom_args_module={path}")
        config = JobConfig()
        config.maybe_add_custom_args()
        config.parse_args(
            [
                # Profiling options
                # "--profiling.enable_profiling",
                # "--profiling.profile_freq",
                # "5",
                # "--profiling.enable_memory_snapshot",
                # "--profiling.save_memory_snapshot_folder",
                # "memory_snapshot_flux",
                "--training.dataset",
                dataset_name,
                "--training.batch_size",
                str(batch_size),
                "--encoder.t5_encoder",
                "google/t5-v1_1-small",
                "--encoder.clip_encoder",
                "openai/clip-vit-large-patch14",
                "--encoder.max_t5_encoding_len",
                "512",
            ]
        )

        # Optionally profile the iteration loop (enabled via the commented
        # profiling options above).
        with maybe_enable_profiling(
            config, global_step=0
        ) as torch_profiler, maybe_enable_memory_snapshot(
            config, global_step=0
        ) as memory_profiler:
            dl = self._build_dataloader(
                config,
                world_size,
                rank,
            )
            dl = iter(dl)

            # Pull a fixed number of batches and sanity-check their shapes.
            for i in range(0, num_steps):
                input_data, labels = next(dl)
                print(f"Step {i} image size: {labels.shape}")
                if torch_profiler:
                    torch_profiler.step()
                if memory_profiler:
                    memory_profiler.step()

                print(len(input_data["clip_tokens"]))
                for k, v in input_data.items():
                    print(f"Step {i} {k} value: {type(v), v.shape}")

                assert len(input_data) == 2  # (clip_encodings, t5_encodings)
                assert labels.shape == (batch_size, 3, 256, 256)
                # assert input_data["clip_tokens"].shape[0] == batch_size
                # assert input_data["t5_tokens"].shape == (batch_size, 512, 512)

            if torch_profiler:
                torch_profiler.step()
            if memory_profiler:
                memory_profiler.step(exit_ctx=True)

    def test_preprocess(self):
        # TODO
        pass

    def _build_dataloader(
        self,
        job_config,
        world_size,
        rank,
    ):
        # Thin wrapper: build a finite FLUX dataloader for one DP shard.
        return build_flux_dataloader(
            dp_world_size=world_size,
            dp_rank=rank,
            job_config=job_config,
            tokenizer=None,
            infinite=False,
        )
|
torchtitan/experiments/flux/train.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
from typing import Optional
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
from torchtitan.config_manager import JobConfig
|
| 13 |
+
from torchtitan.distributed import utils as dist_utils
|
| 14 |
+
from torchtitan.experiments.flux.model.autoencoder import load_ae
|
| 15 |
+
from torchtitan.experiments.flux.model.hf_embedder import FluxEmbedder
|
| 16 |
+
from torchtitan.experiments.flux.model.model import FluxModel
|
| 17 |
+
from torchtitan.experiments.flux.utils import (
|
| 18 |
+
create_position_encoding_for_latents,
|
| 19 |
+
pack_latents,
|
| 20 |
+
preprocess_flux_data,
|
| 21 |
+
unpack_latents,
|
| 22 |
+
)
|
| 23 |
+
from torchtitan.tools.logging import init_logger, logger
|
| 24 |
+
from torchtitan.train import Trainer
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class FluxTrainer(Trainer):
|
| 28 |
+
    def __init__(self, job_config: JobConfig):
        """Set up the base Trainer plus the frozen FLUX preprocessing
        components: autoencoder, CLIP encoder and T5 encoder."""
        super().__init__(job_config)

        # Converts raw (image, text) batches into latents and text encodings
        # before each train step.
        self.preprocess_fn = preprocess_flux_data
        # self.dtype = job_config.encoder.dtype
        # NOTE(review): dtype is hard-coded rather than read from the job
        # config (see commented line above) — confirm intended.
        self._dtype = torch.bfloat16
        self._seed = job_config.training.seed
        self._guidance = job_config.training.guidance

        # load components
        model_config = self.train_spec.config[job_config.model.flavor]
        # NOTE(review): assumes `job_config.encoder.auto_encoder_path` exists;
        # the bundled flux_argparser does not register that option — confirm.
        self.autoencoder = load_ae(
            job_config.encoder.auto_encoder_path,
            model_config.autoencoder_params,
            device="cpu",
            dtype=self._dtype,
        )
        self.clip_encoder = FluxEmbedder(version=job_config.encoder.clip_encoder).to(
            dtype=self._dtype
        )
        self.t5_encoder = FluxEmbedder(version=job_config.encoder.t5_encoder).to(
            dtype=self._dtype
        )
|
| 51 |
+
|
| 52 |
+
    def _predict_noise(
        self,
        model: FluxModel,
        latents: torch.Tensor,
        clip_encodings: torch.Tensor,
        t5_encodings: torch.Tensor,
        timesteps: torch.Tensor,
        guidance: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Use Flux's flow-matching model to predict the noise in image latents.

        Args:
            model (FluxFlowModel): The Flux flow model.
            latents (Tensor): Image encodings from the Flux autoencoder.
                Shape: [bsz, 16, latent height, latent width]
            clip_encodings (Tensor): CLIP text encodings.
                Shape: [bsz, 768]
            t5_encodings (Tensor): T5 text encodings.
                Shape: [bsz, sequence length, 256 or 512]
            timesteps (Tensor): The amount of noise (0 to 1).
                Shape: [bsz]
            guidance (Optional[Tensor]): The guidance value (1.5 to 4) if guidance-enabled model.
                Shape: [bsz]
                Default: None

        Returns:
            Tensor: The noise prediction.
                Shape: [bsz, 16, latent height, latent width]
        """
        bsz, _, latent_height, latent_width = latents.shape

        POSITION_DIM = 3  # constant for Flux flow model
        with torch.no_grad():
            # Create positional encodings: per-patch ids for the image tokens,
            # zeros for the text tokens.
            latent_pos_enc = create_position_encoding_for_latents(
                bsz, latent_height, latent_width, POSITION_DIM
            )
            text_pos_enc = torch.zeros(bsz, t5_encodings.shape[1], POSITION_DIM)

            # Convert latent into a sequence of patches
            latents = pack_latents(latents)

            # Predict noise. `.to(latents)` moves each input onto the latents'
            # device and dtype.
            latent_noise_pred = model(
                img=latents,
                img_ids=latent_pos_enc.to(latents),
                txt=t5_encodings.to(latents),
                txt_ids=text_pos_enc.to(latents),
                y=clip_encodings.to(latents),
                timesteps=timesteps.to(latents),
                guidance=guidance.to(latents) if guidance is not None else None,
            )

            # Convert sequence of patches back to latent shape
            latent_noise_pred = unpack_latents(
                latent_noise_pred, latent_height, latent_width
            )

        return latent_noise_pred
|
| 112 |
+
|
| 113 |
+
def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor):
    """Run one flow-matching training step for the Flux model.

    Encodes the raw batch (image -> latents, text -> CLIP/T5 encodings),
    builds a noised latent at a random timestep, predicts the noise with the
    model, back-propagates the loss, clips gradients, and steps the
    optimizer(s) and LR scheduler(s). Periodically logs loss metrics.

    Args:
        input_dict: raw batch; must contain the tokenized text inputs expected
            by ``self.preprocess_fn`` (the image is injected below).
        labels: the raw image tensor; replaced by its autoencoder latents
            after preprocessing.
    """
    # Encode image/text: run the autoencoder, T5 and CLIP encoders on device
    # (offload=True moves the encoders back to CPU afterwards).
    input_dict["image"] = labels
    input_dict = self.preprocess_fn(
        device=self.device,
        dtype=self._dtype,
        autoencoder=self.autoencoder,
        clip_encoder=self.clip_encoder,
        t5_encoder=self.t5_encoder,
        batch=input_dict,
        offload=True,
    )
    # From here on, "labels" are the image latents, not raw pixels.
    labels = input_dict["img_encodings"]

    self.optimizers.zero_grad()

    # Keep these variables local to shorten the code as these are
    # the major variables that are used in the training loop.
    model_parts = self.model_parts
    world_mesh = self.world_mesh
    parallel_dims = self.parallel_dims

    # Text encodings produced by the preprocessing step above.
    clip_encodings = input_dict["clip_encodings"]
    t5_encodings = input_dict["t5_encodings"]

    bsz = labels.shape[0]

    # Sample the flow-matching inputs without tracking gradients:
    # a per-sample timestep in [0, 1) and a linear interpolation between
    # the clean latents and Gaussian noise at that timestep.
    with torch.no_grad():
        noise = torch.randn_like(labels)
        timesteps = torch.rand((bsz,)).to(labels)
        sigmas = timesteps.view(-1, 1, 1, 1)
        noisy_latents = (1 - sigmas) * labels + sigmas * noise
        guidance = torch.full((bsz,), self._guidance).to(labels)

    # Flow-matching regression target: the velocity (noise - data).
    target = noise - labels

    # Pipeline parallelism is not supported here; exactly one model part.
    assert len(model_parts) == 1
    # TODO(jianiw): model_parts will be wrapped by FSDP, which will calculate
    model_parts[0] = model_parts[0].to(dtype=self._dtype)

    pred = self._predict_noise(
        model_parts[0],
        noisy_latents,
        clip_encodings,
        t5_encodings,
        timesteps,
        guidance,
    )
    loss = self.loss_fn(pred, target)
    # pred has the unpacked latent shape [bsz, 16, latent_h, latent_w]
    # (see _predict_noise); free the large tensors before backward to
    # avoid peaking memory.
    del (pred, noise, target)
    loss.backward()

    # Clip the global gradient norm across all model parts (pp_mesh is only
    # needed when pipeline parallelism is enabled).
    dist_utils.clip_grad_norm_(
        [p for m in model_parts for p in m.parameters()],
        self.job_config.training.max_norm,
        foreach=True,
        pp_mesh=self.world_mesh["pp"] if parallel_dims.pp_enabled else None,
    )
    # Make sure any async checkpoint staging has finished before mutating
    # the parameters.
    self.checkpointer.maybe_wait_for_staging()
    self.optimizers.step()
    self.lr_schedulers.step()

    # log metrics
    if not self.metrics_processor.should_log(self.step):
        return

    # With data/context parallelism, reduce the loss across the dp_cp mesh;
    # otherwise the local scalar is already the global value.
    if (
        parallel_dims.dp_replicate_enabled
        or parallel_dims.dp_shard_enabled
        or parallel_dims.cp_enabled
    ):
        loss = loss.detach()
        global_avg_loss, global_max_loss = (
            dist_utils.dist_mean(loss, world_mesh["dp_cp"]),
            dist_utils.dist_max(loss, world_mesh["dp_cp"]),
        )
    else:
        global_avg_loss = global_max_loss = loss.item()

    self.metrics_processor.log(self.step, global_avg_loss, global_max_loss)
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
if __name__ == "__main__":
    # Script entry point: parse the job config, then either create a seed
    # checkpoint (single-rank only) or run the full training loop.
    init_logger()
    config = JobConfig()
    config.maybe_add_custom_args()
    config.parse_args()
    trainer: Optional[FluxTrainer] = None

    try:
        trainer = FluxTrainer(config)
        if config.checkpoint.create_seed_checkpoint:
            # BUGFIX: the original `assert int(os.environ["WORLD_SIZE"])` was
            # vacuously true for any non-zero world size; the intent (per the
            # message) is to require exactly one rank so no sharding happens.
            assert (
                int(os.environ["WORLD_SIZE"]) == 1
            ), "Must create seed checkpoint using a single device, to disable sharding."
            assert (
                config.checkpoint.enable_checkpoint
            ), "Must enable checkpointing when creating a seed checkpoint."
            trainer.checkpointer.save(curr_step=0, force=True)
            logger.info("Created seed checkpoint")
        else:
            trainer.train()
    finally:
        # Always release trainer resources and tear down the process group,
        # even when training raises.
        if trainer:
            trainer.close()

        if torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()
            logger.info("Process group destroyed.")
|
torchtitan/experiments/flux/train_configs/debug_model.toml
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
[job]
|
| 3 |
+
dump_folder = "./outputs"
|
| 4 |
+
description = "Flux debug model"
|
| 5 |
+
print_args = false
|
| 6 |
+
use_for_integration_test = true
|
| 7 |
+
|
| 8 |
+
[profiling]
|
| 9 |
+
enable_profiling = false
|
| 10 |
+
save_traces_folder = "profile_trace"
|
| 11 |
+
profile_freq = 10
|
| 12 |
+
enable_memory_snapshot = false
|
| 13 |
+
save_memory_snapshot_folder = "memory_snapshot"
|
| 14 |
+
|
| 15 |
+
[metrics]
|
| 16 |
+
log_freq = 1
|
| 17 |
+
disable_color_printing = false
|
| 18 |
+
enable_tensorboard = false
|
| 19 |
+
save_tb_folder = "tb"
|
| 20 |
+
enable_wandb = false
|
| 21 |
+
|
| 22 |
+
[model]
|
| 23 |
+
name = "flux"
|
| 24 |
+
flavor = "flux-debug"
|
| 25 |
+
norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm
|
| 26 |
+
# test tokenizer.model, for debug purpose only
|
| 27 |
+
# tokenizer_path = "./tests/assets/test_tiktoken.model"
|
| 28 |
+
# converters = "float8"
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
[optimizer]
|
| 32 |
+
name = "AdamW"
|
| 33 |
+
lr = 8e-4
|
| 34 |
+
eps = 1e-8
|
| 35 |
+
|
| 36 |
+
[lr_scheduler]
|
| 37 |
+
warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps
|
| 38 |
+
decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps
|
| 39 |
+
decay_type = "linear"
|
| 40 |
+
lr_min = 0.0
|
| 41 |
+
|
| 42 |
+
[training]
|
| 43 |
+
batch_size = 32
|
| 44 |
+
seq_len = 512
|
| 45 |
+
max_norm = 1.0 # grad norm clipping
|
| 46 |
+
steps = 10
|
| 47 |
+
compile = false
|
| 48 |
+
dataset = "cc12m"
|
| 49 |
+
guidance = 3.5
|
| 50 |
+
seed = 0
|
| 51 |
+
|
| 52 |
+
[encoder]
|
| 53 |
+
t5_encoder="google/t5-v1_1-small"
|
| 54 |
+
clip_encoder="openai/clip-vit-large-patch14"
|
| 55 |
+
max_t5_encoding_len=512
|
| 56 |
+
auto_encoder_path="torchtitan/experiments/flux/assets/autoencoder/ae.safetensors" # Autoencoder to use for image
|
| 57 |
+
|
| 58 |
+
[parallelism]
|
| 59 |
+
data_parallel_replicate_degree = 1
|
| 60 |
+
data_parallel_shard_degree = 1
|
| 61 |
+
fsdp_reshard_after_forward = "default" # default / never / always
|
| 62 |
+
tensor_parallel_degree = 1
|
| 63 |
+
enable_async_tensor_parallel = false
|
| 64 |
+
pipeline_parallel_degree = 1
|
| 65 |
+
context_parallel_degree = 1
|
| 66 |
+
|
| 67 |
+
[experimental]
|
| 68 |
+
custom_args_module = "torchtitan.experiments.flux.flux_argparser"
|
torchtitan/experiments/flux/utils.py
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from typing import Optional
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from torch import Tensor
|
| 12 |
+
|
| 13 |
+
from torchtitan.experiments.flux.model.autoencoder import AutoEncoder
|
| 14 |
+
from torchtitan.experiments.flux.model.hf_embedder import FluxEmbedder
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def preprocess_flux_data(
    # arguments from the recipe
    device: torch.device,
    dtype: torch.dtype,
    *,
    # arguments from the config
    autoencoder: Optional[AutoEncoder],
    clip_encoder: FluxEmbedder,
    t5_encoder: FluxEmbedder,
    batch: dict[str, Tensor],
    offload: bool = False,
) -> dict[str, Tensor]:
    """
    Take a batch of inputs and encoders and return the batch augmented with
    preprocessed (encoded) data.

    Args:
        device (torch.device): device to do preprocessing on
        dtype (torch.dtype): data type to do preprocessing in
        autoencoder (Optional[AutoEncoder]): autoencoder used to encode images
            into latents; when None, image encoding is skipped
        clip_encoder (FluxEmbedder): CLIP text encoder
        t5_encoder (FluxEmbedder): T5 text encoder
        batch (dict[str, Tensor]): batch of data to preprocess; must contain
            "clip_tokens" and "t5_tokens", plus "image" when autoencoder is set
        offload (bool): if True, move the encoders to `device` for the duration
            of this call and back to CPU afterwards

    Returns:
        dict[str, Tensor]: the same batch dict, mutated in place with
            "clip_encodings", "t5_encodings" and, when an autoencoder is
            provided, "img_encodings"
    """

    # The input of encoder should be torch.int type
    if offload:
        clip_encoder.to(device)
        t5_encoder.to(device)
        if autoencoder is not None:
            autoencoder.to(device)

    # NOTE(review): a bare `.squeeze()` removes *every* size-1 dimension,
    # including the batch dimension when batch size is 1 — confirm callers
    # never pass bsz == 1, or squeeze a specific dim instead.
    clip_tokens = batch["clip_tokens"].squeeze().to(device=device, dtype=torch.int)
    t5_tokens = batch["t5_tokens"].squeeze().to(device=device, dtype=torch.int)

    clip_text_encodings = clip_encoder(clip_tokens)
    t5_text_encodings = t5_encoder(t5_tokens)

    # Encode raw images into autoencoder latents when an autoencoder is given.
    if autoencoder is not None:
        images = batch["image"].to(device=device, dtype=dtype)
        img_encodings = autoencoder.encode(images)
        batch["img_encodings"] = img_encodings.to(device=device, dtype=dtype)

    batch["clip_encodings"] = clip_text_encodings.to(dtype)
    batch["t5_encodings"] = t5_text_encodings.to(dtype)

    # offload encoders to cpu after preprocessing
    if offload:
        clip_encoder.to("cpu")
        t5_encoder.to("cpu")
        if autoencoder is not None:
            autoencoder.to("cpu")

    return batch
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def generate_noise_latent(
|
| 76 |
+
bsz: int,
|
| 77 |
+
height: int,
|
| 78 |
+
width: int,
|
| 79 |
+
device: str | torch.device,
|
| 80 |
+
dtype: torch.dtype,
|
| 81 |
+
seed: int,
|
| 82 |
+
) -> Tensor:
|
| 83 |
+
"""Generate noise latents for the Flux flow model.
|
| 84 |
+
|
| 85 |
+
Args:
|
| 86 |
+
bsz (int): batch_size.
|
| 87 |
+
height (int): The height of the image.
|
| 88 |
+
width (int): The width of the image.
|
| 89 |
+
device (str | torch.device): The device to use.
|
| 90 |
+
dtype (torch.dtype): The dtype to use.
|
| 91 |
+
seed (int): The seed to use for randomize.
|
| 92 |
+
|
| 93 |
+
Returns:
|
| 94 |
+
Tensor: The noise latents.
|
| 95 |
+
Shape: [num_samples, LATENT_CHANNELS, height // IMG_LATENT_SIZE_RATIO, width // IMG_LATENT_SIZE_RATIO]
|
| 96 |
+
|
| 97 |
+
"""
|
| 98 |
+
LATENT_CHANNELS, IMAGE_LATENT_SIZE_RATIO = 16, 8
|
| 99 |
+
return torch.randn(
|
| 100 |
+
bsz,
|
| 101 |
+
LATENT_CHANNELS,
|
| 102 |
+
height // IMAGE_LATENT_SIZE_RATIO,
|
| 103 |
+
width // IMAGE_LATENT_SIZE_RATIO,
|
| 104 |
+
dtype=dtype,
|
| 105 |
+
generator=torch.Generator().manual_seed(seed),
|
| 106 |
+
).to(device)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def create_position_encoding_for_latents(
    bsz: int, latent_height: int, latent_width: int, position_dim: int = 3
) -> Tensor:
    """
    Build the packed latents' position encodings for the Flux flow model.

    Each 2x2 latent patch gets a (0, row, col) coordinate triple; the grid is
    flattened row-major and tiled across the batch.

    Args:
        bsz (int): batch size.
        latent_height (int): latent height (divided by the patch height).
        latent_width (int): latent width (divided by the patch width).
        position_dim (int): size of the last axis (channel 0 stays zero).

    Returns:
        Tensor: position encodings of shape
            [bsz, (latent_height // 2) * (latent_width // 2), position_dim].
    """
    PATCH_HEIGHT, PATCH_WIDTH = 2, 2

    n_rows = latent_height // PATCH_HEIGHT
    n_cols = latent_width // PATCH_WIDTH

    pos_enc = torch.zeros(n_rows, n_cols, position_dim)
    # Channel 1 holds the patch-row index, channel 2 the patch-column index;
    # channel 0 is left at zero (reserved for the text stream).
    pos_enc[..., 1] = torch.arange(n_rows)[:, None]
    pos_enc[..., 2] = torch.arange(n_cols)[None, :]

    # [rows, cols, d] -> [1, rows*cols, d] -> [bsz, rows*cols, d]
    return pos_enc.reshape(1, n_rows * n_cols, position_dim).repeat(bsz, 1, 1)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def pack_latents(x: Tensor) -> Tensor:
    """
    Rearrange latents from an image-like format into a sequence of patches.
    Equivalent to `einops.rearrange("b c (h ph) (w pw) -> b (h w) (c ph pw)")`.

    Args:
        x (Tensor): The unpacked latents.
            Shape: [bsz, ch, latent height, latent width]

    Returns:
        Tensor: The packed latents.
            Shape: (bsz, (latent_height // ph) * (latent_width // pw), ch * ph * pw)
    """
    PATCH_HEIGHT, PATCH_WIDTH = 2, 2

    bsz, ch, latent_height, latent_width = x.shape
    n_h = latent_height // PATCH_HEIGHT
    n_w = latent_width // PATCH_WIDTH

    # Split each spatial axis into (patch index, within-patch offset):
    # [b, c, h*ph, w*pw] -> [b, c, h, ph, w, pw]
    patches = x.reshape(bsz, ch, n_h, PATCH_HEIGHT, n_w, PATCH_WIDTH)

    # Bring the patch grid to the front and the per-patch dims to the back:
    # [b, c, h, ph, w, pw] -> [b, h, w, c, ph, pw]
    patches = patches.permute(0, 2, 4, 1, 3, 5)

    # Flatten the grid into a sequence and each patch into a feature vector:
    # [b, h, w, c, ph, pw] -> [b, h*w, c*ph*pw]
    return patches.reshape(bsz, n_h * n_w, ch * PATCH_HEIGHT * PATCH_WIDTH)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def unpack_latents(x: Tensor, latent_height: int, latent_width: int) -> Tensor:
    """
    Rearrange latents from a sequence of patches into an image-like format.
    Equivalent to `einops.rearrange("b (h w) (c ph pw) -> b c (h ph) (w pw)")`.

    Args:
        x (Tensor): The packed latents.
            Shape: (bsz, (latent_height // ph) * (latent_width // pw), ch * ph * pw)
        latent_height (int): The height of the unpacked latents.
        latent_width (int): The width of the unpacked latents.

    Returns:
        Tensor: The unpacked latents.
            Shape: [bsz, ch, latent height, latent width]
    """
    PATCH_HEIGHT, PATCH_WIDTH = 2, 2
    patch_area = PATCH_HEIGHT * PATCH_WIDTH

    bsz, _, packed_ch = x.shape
    n_h = latent_height // PATCH_HEIGHT
    n_w = latent_width // PATCH_WIDTH
    ch = packed_ch // patch_area

    # Expand the sequence back into a patch grid with per-patch dims:
    # [b, h*w, c*ph*pw] -> [b, h, w, c, ph, pw]
    grid = x.reshape(bsz, n_h, n_w, ch, PATCH_HEIGHT, PATCH_WIDTH)

    # Interleave patch indices with within-patch offsets:
    # [b, h, w, c, ph, pw] -> [b, c, h, ph, w, pw]
    grid = grid.permute(0, 3, 1, 4, 2, 5)

    # Merge each (patch index, offset) pair back into a spatial axis:
    # [b, c, h, ph, w, pw] -> [b, c, h*ph, w*pw]
    return grid.reshape(bsz, ch, n_h * PATCH_HEIGHT, n_w * PATCH_WIDTH)
|
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from .mg_grouped_gemm import grouped_gemm_forward
|
| 8 |
+
from .tma_autotuning import ALIGN_SIZE_M
|
| 9 |
+
|
| 10 |
+
__all__ = [
|
| 11 |
+
"grouped_gemm_forward",
|
| 12 |
+
"ALIGN_SIZE_M",
|
| 13 |
+
]
|
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/fast_debug_ao.py
ADDED
|
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# pyre-unsafe
|
| 8 |
+
import logging
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
from reference_utils import (
|
| 14 |
+
analyze_tensor_differences,
|
| 15 |
+
compute_reference_backward,
|
| 16 |
+
compute_reference_forward,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
# Configure logging
|
| 20 |
+
logging.basicConfig(
|
| 21 |
+
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
# Import grouped GEMM implementations
|
| 25 |
+
try:
|
| 26 |
+
from mg_grouped_gemm import grouped_gemm_backward, grouped_gemm_forward
|
| 27 |
+
|
| 28 |
+
except ImportError:
|
| 29 |
+
logging.error(
|
| 30 |
+
"Error importing grouped GEMM modules. Make sure the implementation files are in the correct path."
|
| 31 |
+
)
|
| 32 |
+
raise
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def test_forward_pass():
    """
    A simple test for the M*G grouped GEMM forward pass with detailed error handling.

    In M*G grouping:
    - M dimension is partitioned into G groups (M_total = sum(M_sizes))
    - N dimension is the same for all groups

    Returns:
        bool: True if the Triton forward output matches the PyTorch reference
        (per analyze_tensor_differences), False on mismatch or any exception.
    """
    try:
        # NOTE(review): the kernel presumably requires CUDA; the CPU fallback
        # device here is only useful for import-time smoke testing — confirm.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Test parameters for DeepSeek-like models
        G = 1  # Number of groups
        M_sizes = [
            2048,
        ]  # 2048, 2048, 2048] # Group sizes (will be adjusted)
        M_total = sum(M_sizes)  # Total M dimension
        N = 4096  # Output dimension (same for all groups)
        K = 7168  # Hidden dimension

        # Create group sizes tensor
        m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)

        # Create input and weight tensors - using float16 for higher precision
        x = torch.randn(M_total, K, dtype=torch.float16, device=device)
        w = torch.randn(N, K, dtype=torch.float16, device=device)

        # Log the setup
        logging.info(f"Test setup - G: {G}, M_total: {M_total}, N: {N}, K: {K}")
        logging.info(f"Group sizes: {m_sizes}")
        logging.info(f"Input x shape: {x.shape}")
        logging.info(f"Weight w shape: {w.shape}")

        # Run forward pass
        logging.info("Running forward pass with grouped GEMM")
        result = grouped_gemm_forward(x, w, m_sizes)
        logging.info(f"Forward result shape: {result.shape}")

        # Compute reference result
        logging.info("Computing reference result with PyTorch")
        reference_result = compute_reference_forward(x, w, m_sizes)

        # Compare results
        logging.info("Comparing with PyTorch reference")
        forward_close = analyze_tensor_differences(
            result, reference_result, "Forward output"
        )

        return forward_close

    except Exception as e:
        # Any failure (kernel launch, shape mismatch, ...) is logged with a
        # traceback and reported as a failed test rather than propagated.
        logging.error(f"Test failed with error: {e}")
        import traceback

        logging.error(traceback.format_exc())
        return False
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def test_backward_pass():
    """
    A simple test for the M*G grouped GEMM backward pass with detailed error handling.

    In M*G grouping:
    - M dimension is partitioned into G groups (M_total = sum(M_sizes))
    - N dimension is the same for all groups

    Returns:
        bool: True if both grad_x and grad_w match the PyTorch reference
        gradients, False on mismatch or any exception.
    """
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Test parameters for DeepSeek-like models
        G = 4  # Number of groups
        M_sizes = [2048, 2048, 2048, 2048]  # Group sizes (will be adjusted)
        M_total = sum(M_sizes)  # Total M dimension
        N = 4096  # Output dimension (same for all groups)
        K = 7168  # Hidden dimension

        # Create group sizes tensor
        m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)

        # Create input and weight tensors - using float16 for higher precision
        x = torch.randn(
            M_total, K, dtype=torch.float16, device=device, requires_grad=True
        )
        w = torch.randn(N, K, dtype=torch.float16, device=device, requires_grad=True)

        # Log the setup
        logging.info(f"Test setup - G: {G}, M_total: {M_total}, N: {N}, K: {K}")
        logging.info(f"Group sizes: {m_sizes}")
        logging.info(f"Input x shape: {x.shape}")
        logging.info(f"Weight w shape: {w.shape}")

        # Step 1: Run forward pass
        logging.info("Running forward pass")
        result = grouped_gemm_forward(x, w, m_sizes)
        logging.info(f"Forward result shape: {result.shape}")

        # Create a gradient for backpropagation
        grad_output = torch.randn_like(result)
        logging.info(f"Created gradient with shape: {grad_output.shape}")

        # Step 2: Run backward pass directly (not via autograd), so the
        # kernel's gradient path is exercised explicitly.
        logging.info("Running backward pass directly")
        grad_x, grad_w = grouped_gemm_backward(grad_output, x, w, m_sizes)

        # Verify gradient shapes
        logging.info(
            f"Gradient shapes - grad_x: {grad_x.shape}, grad_w: {grad_w.shape}"
        )

        # Step 3: Verify gradient computation using PyTorch's autograd
        logging.info("Running PyTorch reference implementation")

        # Compute reference gradients
        x_ref_grad, w_ref_grad = compute_reference_backward(x, w, m_sizes, grad_output)

        # Compare gradients
        logging.info("Comparing gradients with PyTorch reference")
        grad_x_close = analyze_tensor_differences(grad_x, x_ref_grad, "grad_x")
        grad_w_close = analyze_tensor_differences(grad_w, w_ref_grad, "grad_w")

        # Log overall result
        if grad_x_close and grad_w_close:
            logging.info("✓ SUCCESS: Gradients match the PyTorch reference")
        else:
            logging.error("✗ FAILURE: Gradient mismatch detected")

        return grad_x_close and grad_w_close

    except Exception as e:
        # Any failure is logged with a traceback and reported as a failed
        # test rather than propagated.
        logging.error(f"Test failed with error: {e}")
        import traceback

        logging.error(traceback.format_exc())
        return False
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def test_multiple_deepseek_configs():
    """
    Test multiple DeepSeek model configurations with both forward and backward pass verification.

    Each config is a (G, M, K, N) tuple; M is split as evenly as possible
    across the G groups. A per-config exception is caught and recorded as a
    failure, so all configs always run.

    Returns:
        bool: True only if every config passes both forward and backward checks.
    """
    # DeepSeek configurations: (G, M, K, N)
    configs = [
        (4, 8192, 7168, 4096),  # Config 1
        (4, 8192, 2048, 7168),  # Config 2
        (8, 4096, 7168, 4096),  # Config 3
        (8, 4096, 2048, 7168),  # Config 4
    ]

    # Per-config results as (index, overall, forward_ok, backward_ok) tuples.
    results = []

    for config_idx, (G, M, K, N) in enumerate(configs):
        logging.info(f"\n\n===== Testing DeepSeek Config {config_idx+1} =====")
        logging.info(f"G={G}, M={M}, K={K}, N={N}")

        try:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

            # Create even group sizes; the first `remainder` groups absorb
            # one extra row each when M is not divisible by G.
            base_size = M // G
            remainder = M % G
            M_sizes = [base_size + (1 if i < remainder else 0) for i in range(G)]
            m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)

            # Create input and weight tensors using float16 for higher precision
            x = torch.randn(
                M, K, dtype=torch.float16, device=device, requires_grad=True
            )
            w = torch.randn(
                N, K, dtype=torch.float16, device=device, requires_grad=True
            )

            logging.info(f"Input x shape: {x.shape}, Weight w shape: {w.shape}")

            # Run forward pass
            result = grouped_gemm_forward(x, w, m_sizes)
            logging.info(f"Forward result shape: {result.shape}")

            # ===== FORWARD PASS VERIFICATION =====
            # Compute reference forward result
            reference_result = compute_reference_forward(x, w, m_sizes)

            # Compare forward results
            forward_close = analyze_tensor_differences(
                result, reference_result, "Forward output"
            )

            # ===== BACKWARD PASS VERIFICATION =====
            # Create gradient for backpropagation
            grad_output = torch.randn_like(result)

            # Run backward pass
            grad_x, grad_w = grouped_gemm_backward(grad_output, x, w, m_sizes)

            # Compute reference gradients
            x_ref_grad, w_ref_grad = compute_reference_backward(
                x, w, m_sizes, grad_output
            )

            # Compare backward results
            grad_x_close = analyze_tensor_differences(grad_x, x_ref_grad, "grad_x")
            grad_w_close = analyze_tensor_differences(grad_w, w_ref_grad, "grad_w")

            # Overall config result
            backward_close = grad_x_close and grad_w_close
            config_success = forward_close and backward_close
            results.append(
                (config_idx + 1, config_success, forward_close, backward_close)
            )

            # Log overall config result
            if config_success:
                logging.info(f"✓ SUCCESS: Config {config_idx+1} passed all tests!")
            else:
                logging.error(
                    f"✗ FAILURE: Config {config_idx+1} failed one or more tests"
                )

        except Exception as e:
            # Record the failure for this config and continue with the rest.
            logging.error(f"Config {config_idx+1} test failed with error: {e}")
            import traceback

            logging.error(traceback.format_exc())
            results.append((config_idx + 1, False, False, False))

    # Summary
    logging.info("\n===== Test Results Summary =====")
    for config_idx, overall_success, forward_success, backward_success in results:
        overall_status = "✓ PASSED" if overall_success else "✗ FAILED"
        forward_status = "✓ PASSED" if forward_success else "✗ FAILED"
        backward_status = "✓ PASSED" if backward_success else "✗ FAILED"

        logging.info(f"Config {config_idx}: {overall_status}")
        logging.info(f"  - Forward pass: {forward_status}")
        logging.info(f"  - Backward pass: {backward_status}")

    return all(overall_success for _, overall_success, _, _ in results)
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
if __name__ == "__main__":
    # Script entry point: run the basic forward test, the basic backward
    # test, then all DeepSeek configs, and log a combined verdict.
    logging.info(
        "Running verification for both forward and backward pass of M*G grouped GEMM"
    )

    # Run basic forward pass test
    logging.info("\n===== Running basic forward pass test =====")
    success_forward = test_forward_pass()
    logging.info(f"Basic forward test {'succeeded' if success_forward else 'failed'}")

    # Run basic backward pass test
    logging.info("\n===== Running basic backward pass test =====")
    success_backward = test_backward_pass()
    logging.info(f"Basic backward test {'succeeded' if success_backward else 'failed'}")

    # Run multiple DeepSeek configs with forward and backward verification
    logging.info("\n===== Running tests for all DeepSeek configs =====")
    success_configs = test_multiple_deepseek_configs()
    logging.info(
        f"DeepSeek configs tests {'all succeeded' if success_configs else 'had failures'}"
    )

    # Overall result (note: no non-zero exit code on failure; callers must
    # parse the log to detect a failed run).
    overall_success = success_forward and success_backward and success_configs
    logging.info(
        f"\nOverall test result: {'SUCCESS' if overall_success else 'FAILURE'}"
    )
|
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/tma_autotuning.py
ADDED
|
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# credit - TMAHelper class, AutoTuning are derived from FBGemm:
|
| 8 |
+
# https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm
|
| 9 |
+
|
| 10 |
+
# pyre-unsafe
|
| 11 |
+
import functools
|
| 12 |
+
|
| 13 |
+
import os
|
| 14 |
+
import sys
|
| 15 |
+
from typing import Any, Dict, Optional, Tuple
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
|
| 19 |
+
import triton
|
| 20 |
+
import triton.language as tl
|
| 21 |
+
from triton import Config as TConfig
|
| 22 |
+
|
| 23 |
+
from triton.runtime import driver # @manual
|
| 24 |
+
|
| 25 |
+
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# ===== Supporting utils, CUDA and TMA =====
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class CudaUtils:
    """Static helpers for querying the Triton/CUDA backend and device."""

    @staticmethod
    def is_cuda() -> bool:
        """Return True when Triton's active backend is CUDA."""
        target = driver.active.get_current_target()
        return target.backend == "cuda"

    @staticmethod
    def verify_tma() -> bool:
        """Return True when TMA is usable: CUDA backend, a visible device,
        and compute capability >= 9 (Hopper or newer)."""
        if not CudaUtils.is_cuda():
            return False
        if not torch.cuda.is_available():
            return False
        major, _minor = torch.cuda.get_device_capability()
        return major >= 9

    @staticmethod
    def get_num_sms() -> int:
        """Return the streaming-multiprocessor count of the current device.

        Raises:
            RuntimeError: if Triton is not on the CUDA backend or CUDA is
                unavailable.
        """
        if not CudaUtils.is_cuda():
            raise RuntimeError("Triton is not running on CUDA backend")
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is not available")
        props = torch.cuda.get_device_properties("cuda")
        return props.multi_processor_count
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class TmaDescriptorHelper:
    """Helper class for managing TMA descriptors in Triton kernels.

    Descriptors are small CPU-side int8 buffers that get filled with TMA
    metadata and are handed to kernels via ``KernelParamWrapper``.
    """

    class KernelParamWrapper:
        """Wrapper to implement the TmaDescKernelParam interface."""

        def __init__(self, desc: torch.Tensor):
            self.desc = desc

        def tma_desc_cpu_ptr(self) -> int:
            """Return the CPU pointer to the TMA descriptor."""
            return self.desc.data_ptr()

    def __init__(self, tma_size: int = 128):
        """Initialize the TMA descriptor helper.

        Args:
            tma_size: Size of the TMA descriptor in bytes

        Raises:
            RuntimeError: if the device lacks TMA support (pre-Hopper) or the
                installed Triton has no grid-constant descriptor support.
        """
        if not CudaUtils.verify_tma():
            raise RuntimeError(
                "TMA not supported on this device (requires Hopper or newer)"
            )
        if "nv_tma_desc_type" not in dir(tl):
            raise RuntimeError(
                "TMA grid constant descriptors not supported in your Triton version"
            )

        self.tma_size = tma_size
        self.fill_1d_tma_descriptor_inner = driver.active.utils.fill_1d_tma_descriptor
        self.fill_2d_tma_descriptor_inner = driver.active.utils.fill_2d_tma_descriptor
        # name -> CPU int8 buffer holding the packed descriptor bytes
        self.descriptors: Dict[str, torch.Tensor] = {}

    def _aligned_descriptor(self, name: str) -> torch.Tensor:
        """Return the descriptor buffer for ``name`` after validating it.

        Shared by the 1D and 2D fill paths (previously duplicated inline).

        Raises:
            ValueError: if the descriptor was never initialized or its
                storage is not 64-byte aligned (a TMA hardware requirement).
        """
        if name not in self.descriptors:
            raise ValueError(f"TMA descriptor '{name}' not initialized")
        desc = self.descriptors[name]
        if desc.data_ptr() % 64 != 0:
            raise ValueError("TMA descriptor must be 64-byte aligned")
        return desc

    def init_tma_descriptor(self, name: str) -> None:
        """Initialize a TMA descriptor with the given name.

        Call this method outside of the lambda function for grid size.
        """
        self.descriptors[name] = torch.empty(
            self.tma_size, device="cpu", dtype=torch.int8
        )

    def fill_1d_tma_descriptor(
        self, name: str, ptr: int, dim: int, block_dim: int, element_size: int
    ) -> None:
        """Fill a 1D TMA descriptor.

        Call this method inside the lambda function for grid size.
        """
        desc_x = self._aligned_descriptor(name)
        self.fill_1d_tma_descriptor_inner(
            ptr, dim, block_dim, element_size, desc_x.data_ptr()
        )

    def fill_2d_tma_descriptor(
        self,
        name: str,
        ptr: int,
        dim1: int,
        dim0: int,
        block_dim1: int,
        block_dim0: int,
        element_size: int,
    ) -> None:
        """Fill a 2D TMA descriptor.

        Call this method inside the lambda function for grid size.
        """
        desc_x = self._aligned_descriptor(name)
        self.fill_2d_tma_descriptor_inner(
            ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr()
        )

    def get_tma_descriptor_kernel_param(self, name: str) -> KernelParamWrapper:
        """Get the TMA descriptor kernel parameter for the given name."""
        if name not in self.descriptors or self.descriptors[name] is None:
            raise ValueError(f"TMA descriptor '{name}' not initialized")
        return self.KernelParamWrapper(self.descriptors[name])
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
# ====== Autotuning utilities ======
|
| 147 |
+
ALIGN_SIZE_M = 128

# Candidate tile/stage/warp combinations searched during autotuning on
# NVIDIA GPUs. BLOCK_SIZE_M is pinned to ALIGN_SIZE_M and num_ctas to 1.
_NV_CONFIGS = []
for _bn in (64, 128, 256):
    for _bk in (64, 128, 256):
        for _stages in (3, 4):
            for _warps in (4, 8):
                _NV_CONFIGS.append(
                    triton.Config(
                        {
                            "BLOCK_SIZE_M": ALIGN_SIZE_M,
                            "BLOCK_SIZE_N": _bn,
                            "BLOCK_SIZE_K": _bk,
                        },
                        num_stages=_stages,
                        num_warps=_warps,
                        num_ctas=1,
                    )
                )
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def early_config_prune(configs, named_args, dtsize=None, dtype=None, **kwargs):
    """Prune autotuning configs that cannot work for the given grouped-GEMM args.

    Filters ``configs`` by shared-memory capacity and by block-size heuristics
    relative to the problem shape (G, M_BUCKET, N, K) found in ``named_args``.

    Args:
        configs: candidate ``triton.Config`` objects to filter.
        named_args: kernel arguments; must contain one of the recognized output
            pointer tensors plus "G", "M_BUCKET", "N", "K".
        dtsize: element size in bytes; inferred from the output tensor if None.
        dtype: element dtype; inferred from the output tensor if None.

    Returns:
        The subset of ``configs`` that pass all pruning checks.
    """
    device = torch.cuda.current_device()
    # Check for all possible pointer parameter names; which one is present
    # depends on whether this is a forward, grad-input, or grad-weight kernel.
    if "grad_input_ptr" in named_args:
        ptr_name = "grad_input_ptr"
    elif "c_ptr" in named_args:
        ptr_name = "c_ptr"
    elif "grad_weight_ptr" in named_args:
        ptr_name = "grad_weight_ptr"
    else:
        raise KeyError("No recognized pointer parameter found in kernel arguments")

    if dtsize is None:
        dtsize = named_args[ptr_name].element_size()
    if dtype is None:
        dtype = named_args[ptr_name].dtype

    pruned_configs = []
    for config in configs:
        kw = config.kwargs
        BLOCK_M, BLOCK_N, BLOCK_K, num_stages = (
            kw["BLOCK_SIZE_M"],
            kw["BLOCK_SIZE_N"],
            kw["BLOCK_SIZE_K"],
            config.num_stages,
        )
        G, M, N, K = (
            named_args["G"],
            named_args["M_BUCKET"],
            named_args["N"],
            named_args["K"],
        )

        # 1. make sure we have enough smem for the software-pipelined tiles
        max_shared_memory = driver.active.utils.get_device_properties(device)[
            "max_shared_mem"
        ]

        required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
        if required_shared_memory > max_shared_memory:
            continue

        M_PER_GROUP = M // G
        MIN_M_TILES = 64
        # 2. make sure we don't load M tiles that are too big
        if BLOCK_M > MIN_M_TILES and BLOCK_M > (M_PER_GROUP * 2):
            continue
        # 3. make sure we don't load M tiles that are too small
        # (this condition checks BLOCK_M; the previous comment saying
        # "N tiles" was misleading)
        if BLOCK_M < 128 and BLOCK_M < (M_PER_GROUP // 2):
            continue

        num_sm = driver.active.utils.get_device_properties(device)[
            "multiprocessor_count"
        ]
        N_TILES = N // BLOCK_N
        MIN_N_TILES = 64
        # 4. make sure we don't load N tiles that are too big
        # (prune large BLOCK_N when there aren't enough tiles to fill the SMs)
        if BLOCK_N > MIN_N_TILES and M * N_TILES < num_sm:
            continue
        # 5. make sure we don't load N tiles that are too small
        if BLOCK_N < 128 and M * N_TILES > 2 * num_sm:
            continue
        # 6. make sure K can be evenly divided (kernel assumes no K remainder)
        if K % BLOCK_K != 0:
            continue

        pruned_configs.append(config)

    return pruned_configs
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
# ======== End Autotuning utilities ========
|
torchtitan/experiments/llama4/README.md
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
**The Llama 4 folder is still under development.**
|
| 2 |
+
|
| 3 |
+
#### Available features
|
| 4 |
+
- Llama 4 model definition (text-only), including the MoE architecture with token-choice routing
|
| 5 |
+
- Basic FSDP, TP, PP, CP support
|
| 6 |
+
- DCP checkpoint conversion scripts
|
| 7 |
+
|
| 8 |
+
#### Download Llama 4 tokenizer
|
| 9 |
+
```bash
|
| 10 |
+
# Llama 4 tokenizer.model
|
| 11 |
+
python scripts/download_tokenizer.py --repo_id meta-llama/Llama-4-Scout-17B-16E --tokenizer_path "" --hf_token=...
|
| 12 |
+
```
|
| 13 |
+
|
| 14 |
+
#### To be added
|
| 15 |
+
- Modeling
|
| 16 |
+
- iRoPE implementation
|
| 17 |
+
- load balance loss for token-choice MoE
|
| 18 |
+
- alternative expert-choice MoE
|
| 19 |
+
- multimodal support
|
| 20 |
+
- Kernel integration
|
| 21 |
+
- efficient bfloat16 GroupedGEMM kernels (from PyTorch core)
|
| 22 |
+
- efficient float8 GroupedGEMM kernels (from torchao)
|
| 23 |
+
- Parallelism
|
| 24 |
+
- performant TP implementation and torch.compile support for MoE layers
|
| 25 |
+
- Context Parallel support for FlexAttention, iRoPE, and multimodal inputs
|
| 26 |
+
- Expert Parallel support
|
| 27 |
+
- Testing
|
| 28 |
+
  - performance and loss converging tests
|
| 29 |
+
- CI integration
|
torchtitan/experiments/llama4/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (1.66 kB). View file
|
|
|
torchtitan/experiments/llama4/infra/__pycache__/parallelize_llama.cpython-312.pyc
ADDED
|
Binary file (5.55 kB). View file
|
|
|
torchtitan/experiments/llama4/infra/expert_parallel.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
from functools import partial
|
| 9 |
+
from typing import Optional, Tuple
|
| 10 |
+
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
from torch.distributed.tensor import (
|
| 13 |
+
DeviceMesh,
|
| 14 |
+
distribute_module,
|
| 15 |
+
distribute_tensor,
|
| 16 |
+
DTensor,
|
| 17 |
+
Partial,
|
| 18 |
+
Replicate,
|
| 19 |
+
Shard,
|
| 20 |
+
)
|
| 21 |
+
from torch.distributed.tensor.parallel import ParallelStyle
|
| 22 |
+
from torch.distributed.tensor.placement_types import Placement
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# implementation of Tensor Parallel on the non-shared experts in MoE
|
| 26 |
+
# implementation of Tensor Parallel on the non-shared experts in MoE
class TensorParallel(ParallelStyle):
    """Tensor Parallel style for the grouped (non-shared) experts in MoE.

    Shards w1/w3 column-wise and w2 row-wise across the TP mesh. Inputs are
    replicated before the expert computation and outputs are left as Partial
    for the enclosing module's output preparation to reduce.
    """

    def __init__(
        self,
        *,
        input_layouts: Optional[Tuple[Optional[Placement]]] = None,
        output_layout: Optional[Placement] = None,
        use_local_output: bool = True,
    ):
        """
        Args:
            input_layouts: placements annotating the (first) module input;
                defaults to ``(Replicate(), None)``.
            output_layout: placement of the module output; defaults to
                ``Partial()`` (partial sums to be reduced downstream).
            use_local_output: if True, unwrap the output DTensor to a local
                tensor before returning it to the caller.
        """
        super().__init__()
        self.input_layouts = input_layouts or (Replicate(), None)
        self.output_layout = output_layout or Partial()
        self.desired_input_layouts = (Replicate(), None)
        self.use_local_output = use_local_output

    @staticmethod
    def _prepare_input_fn(
        input_layouts, desired_input_layouts, mod, inputs, device_mesh
    ):
        # TODO: figure out dynamo support for instance method and switch this to instance method

        # annotate module input placements/sharding with input_layouts
        input_tensor, input_layout, desired_input_layout = (
            inputs[0],
            input_layouts[0],
            desired_input_layouts[0],
        )
        if not isinstance(input_tensor, DTensor):
            input_tensor = DTensor.from_local(
                input_tensor, device_mesh, (input_layout,), run_check=False
            )

        # FIX: compare the first-element layouts actually used by the
        # redistribute below, instead of the whole tuples (the original
        # compared `input_layouts != desired_input_layouts`). This matches
        # the sibling NoParallel style.
        if input_layout != desired_input_layout:
            input_tensor = input_tensor.redistribute(
                placements=(desired_input_layout,), async_op=True
            )
        return (input_tensor, *inputs[1:])

    def _partition_fn(self, name, module, device_mesh):
        # w1/w3 are sharded on the output-feature dim (column-wise), w2 on the
        # input-feature dim (row-wise), so the per-rank expert matmuls compose.
        module.register_parameter(
            "w1", nn.Parameter(distribute_tensor(module.w1, device_mesh, [Shard(2)]))
        )  # Column-wise sharding
        module.register_parameter(
            "w2",
            nn.Parameter(distribute_tensor(module.w2, device_mesh, [Shard(1)])),
        )  # Row-wise sharding
        module.register_parameter(
            "w3",
            nn.Parameter(distribute_tensor(module.w3, device_mesh, [Shard(2)])),
        )  # Column-wise sharding

    @staticmethod
    def _prepare_output_fn(output_layout, use_local_output, mod, outputs, device_mesh):
        # Redistribute to the annotated output layout only when necessary.
        if outputs.placements != (output_layout,):
            outputs = outputs.redistribute(placements=(output_layout,), async_op=True)
        # back to local tensor
        return outputs.to_local() if use_local_output else outputs

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        return distribute_module(
            module,
            device_mesh,
            self._partition_fn,
            partial(
                self._prepare_input_fn, self.input_layouts, self.desired_input_layouts
            ),
            partial(self._prepare_output_fn, self.output_layout, self.use_local_output),
        )
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
# NOTE: This is to achieve replicate computation on the gate module in the MoE router.
|
| 96 |
+
# It does nothing other than (1) setting the module parameters as DTensors on the given mesh
|
| 97 |
+
# and (2) inserting hooks to module boundary to change torch.Tensor to DTensor and back.
|
| 98 |
+
# TODO: The reason we need this wrapping is to ensure all parameters are on the same 1D/2D mesh,
|
| 99 |
+
# which is assumed by (1) gradient norm clipping, and (2) optimizer fused implementation.
|
| 100 |
+
class NoParallel(ParallelStyle):
    """Parallel style that replicates a module's computation on the mesh.

    Parameters stay unsharded; the hooks only wrap the first input as a
    DTensor on the way in and (optionally) unwrap the output DTensor to a
    local tensor on the way out.
    """

    def __init__(
        self,
        *,
        input_layout: Optional[Placement] = None,
        output_layout: Optional[Placement] = None,
        use_local_output: bool = True,
    ):
        super().__init__()
        # Default everything to Replicate so computation is identical on all ranks.
        self.input_layout = Replicate() if input_layout is None else input_layout
        self.output_layout = Replicate() if output_layout is None else output_layout
        self.desired_input_layout = Replicate()
        self.use_local_output = use_local_output

    @staticmethod
    def _prepare_input_fn(input_layout, desired_input_layout, mod, inputs, device_mesh):
        # Wrap the first positional input as a DTensor if it isn't one already.
        first = inputs[0]
        if not isinstance(first, DTensor):
            first = DTensor.from_local(
                first, device_mesh, (input_layout,), run_check=False
            )

        # Redistribute only when the annotated layout differs from the desired one.
        if input_layout != desired_input_layout:
            first = first.redistribute(
                placements=(desired_input_layout,), async_op=True
            )
        return (first, *inputs[1:])

    @staticmethod
    def _prepare_output_fn(output_layout, use_local_output, mod, outputs, device_mesh):
        # Convert the output to the requested layout, then optionally unwrap
        # it back to a plain local tensor for the caller.
        if outputs.placements != (output_layout,):
            outputs = outputs.redistribute(placements=(output_layout,), async_op=True)
        return outputs.to_local() if use_local_output else outputs

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        input_hook = partial(
            self._prepare_input_fn, self.input_layout, self.desired_input_layout
        )
        output_hook = partial(
            self._prepare_output_fn, self.output_layout, self.use_local_output
        )
        # No partition function: parameters are left unsharded on the mesh.
        return distribute_module(module, device_mesh, None, input_hook, output_hook)
|
torchtitan/experiments/llama4/infra/parallelize_llama.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
from torch.distributed.device_mesh import DeviceMesh
|
| 11 |
+
|
| 12 |
+
from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
|
| 13 |
+
from torchtitan.distributed import ParallelDims
|
| 14 |
+
|
| 15 |
+
from torchtitan.models.llama3.parallelize_llama import (
|
| 16 |
+
apply_ac,
|
| 17 |
+
apply_compile,
|
| 18 |
+
apply_ddp,
|
| 19 |
+
apply_fsdp,
|
| 20 |
+
apply_tp,
|
| 21 |
+
)
|
| 22 |
+
from torchtitan.tools.logging import logger
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def parallelize_llama(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
    """
    Apply tensor parallelism, activation checkpointing, torch.compile, and data
    parallelism to the model.

    NOTE: The passed-in model preferably should be on meta device. Otherwise,
    the model must fit on GPU or CPU memory.

    Args:
        model: the model to parallelize in place.
        world_mesh: global DeviceMesh that the parallelism dims slice into.
        parallel_dims: which parallelism dimensions (tp/dp/cp/pp) are enabled.
        job_config: run configuration (compile, AC mode, float8, offload, ...).

    Returns:
        The same model object, with parallelism applied.
    """

    if parallel_dims.tp_enabled:
        # Async TP relies on compile-time fusion, so it is invalid without it.
        if (
            job_config.parallelism.enable_async_tensor_parallel
            and not job_config.training.compile
        ):
            raise RuntimeError("Async TP requires --training.compile")

        enable_float8_linear = "float8" in job_config.model.converters
        float8_is_rowwise = job_config.float8.recipe_name in (
            "rowwise",
            "rowwise_with_gw_hp",
        )

        # For now, float8 all-gather with TP is only supported for tensorwise
        # float8 scaling recipes. For rowwise recipes, we use regular TP and
        # all-gather happens in high precision.
        enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise

        apply_tp(
            model,
            world_mesh["tp"],
            loss_parallel=parallel_dims.loss_parallel_enabled,
            enable_float8_tensorwise_tp=enable_float8_tensorwise_tp,
            enable_async_tp=job_config.parallelism.enable_async_tensor_parallel,
        )

        # MoE layers need their own TP plan in addition to the dense-layer TP
        # applied above.
        apply_moe_tp(model, world_mesh["tp"])

    if job_config.activation_checkpoint.mode != "none":
        if (
            job_config.activation_checkpoint.mode == "selective"
            and job_config.model.use_flex_attn
        ):
            raise ValueError(
                "FlexAttention is not compatible with selective AC yet. "
                "See https://github.com/pytorch/pytorch/issues/147879"
            )
        apply_ac(model, job_config.activation_checkpoint)

    # turn on per-TransformerBlock compile after AC wrapping and before FSDP
    if job_config.training.compile:
        apply_compile(model)

        # NOTE: needed for torch.compile to work with dynamic shapes in token-choice MoE
        torch._dynamo.config.capture_scalar_outputs = True

    if (
        parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled
    ):  # apply FSDP or HSDP, potentially with Context Parallel
        # HSDP uses a 2D (replicate, shard) mesh; plain FSDP shards on one dim.
        if parallel_dims.dp_replicate_enabled:
            dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
        else:
            dp_mesh_dim_names = ("dp_shard_cp",)

        apply_fsdp(
            model,
            world_mesh[tuple(dp_mesh_dim_names)],
            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
            pp_enabled=parallel_dims.pp_enabled,
            cpu_offload=job_config.training.enable_cpu_offload,
            reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward,
        )

        if parallel_dims.dp_replicate_enabled:
            logger.info("Applied HSDP to the model")
        else:
            logger.info("Applied FSDP to the model")

        if parallel_dims.cp_enabled:
            logger.info("Applied Context Parallel to the model")

        if job_config.training.enable_cpu_offload:
            logger.info("Applied CPU Offloading to the model")
    elif parallel_dims.dp_replicate_enabled:
        # DDP path: only valid when no other parallelism dimension is in use.
        if world_mesh.ndim > 1:
            raise RuntimeError("DDP has not supported > 1D parallelism")
        apply_ddp(
            model,
            world_mesh,
            enable_compile=job_config.training.compile,
            enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
        )

    return model
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def apply_moe_tp(
    model: nn.Module,
    tp_mesh: DeviceMesh,
):
    """Apply the MoE tensor-parallel plan to every transformer block."""
    from torch.distributed.tensor import Partial, Replicate, Shard
    from torch.distributed.tensor.parallel import (
        parallelize_module,
        PrepareModuleInputOutput,
    )

    from .expert_parallel import NoParallel, TensorParallel

    def _build_moe_plan():
        # Fresh ParallelStyle objects for each transformer block.
        return {
            # input / output sharding on the seqlen dim
            # all-gather for input, reduce-scatter for output
            "moe": PrepareModuleInputOutput(
                input_layouts=(Shard(1),),
                desired_input_layouts=(Replicate(),),
                use_local_input=True,
                output_layouts=(Partial(),),
                desired_output_layouts=(Shard(1),),
            ),
            # replicate computation for the router
            "moe.router.gate": NoParallel(),
            # input Replicate, output Partial
            "moe.experts": TensorParallel(),
            "moe.shared_expert": TensorParallel(),
        }

    for block in model.layers.values():
        parallelize_module(
            module=block,
            device_mesh=tp_mesh,
            parallelize_plan=_build_moe_plan(),
        )
|
torchtitan/experiments/llama4/model/__pycache__/args.cpython-312.pyc
ADDED
|
Binary file (4.1 kB). View file
|
|
|
torchtitan/experiments/llama4/model/__pycache__/model.cpython-312.pyc
ADDED
|
Binary file (23.2 kB). View file
|
|
|
torchtitan/experiments/llama4/model/__pycache__/moe.cpython-312.pyc
ADDED
|
Binary file (10.5 kB). View file
|
|
|