zaydzuhri commited on
Commit
22487a1
·
verified ·
1 Parent(s): 271d75e

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. fla/models/delta_net/__pycache__/__init__.cpython-312.pyc +0 -0
  2. fla/models/delta_net/__pycache__/configuration_delta_net.cpython-312.pyc +0 -0
  3. fla/models/forgetting_transformer/__pycache__/configuration_forgetting_transformer.cpython-312.pyc +0 -0
  4. fla/models/gla/__pycache__/modeling_gla.cpython-312.pyc +0 -0
  5. fla/models/gla/configuration_gla.py +95 -0
  6. fla/models/linear_attn/__pycache__/modeling_linear_attn.cpython-312.pyc +0 -0
  7. fla/models/retnet/__pycache__/configuration_retnet.cpython-312.pyc +0 -0
  8. fla/models/rwkv6/__pycache__/__init__.cpython-312.pyc +0 -0
  9. fla/models/rwkv7/__pycache__/__init__.cpython-312.pyc +0 -0
  10. fla/models/samba/__pycache__/configuration_samba.cpython-312.pyc +0 -0
  11. fla/models/samba/__pycache__/modeling_samba.cpython-312.pyc +0 -0
  12. fla/models/transformer/__pycache__/__init__.cpython-312.pyc +0 -0
  13. fla/models/transformer_top/__pycache__/__init__.cpython-312.pyc +0 -0
  14. fla/models/transformer_top/__pycache__/configuration_transformer.cpython-312.pyc +0 -0
  15. fla/modules/__pycache__/activations.cpython-312.pyc +0 -0
  16. fla/modules/__pycache__/convolution.cpython-312.pyc +0 -0
  17. fla/modules/__pycache__/fused_bitlinear.cpython-312.pyc +0 -0
  18. fla/modules/__pycache__/mlp.cpython-312.pyc +0 -0
  19. fla/modules/__pycache__/rotary.cpython-312.pyc +0 -0
  20. tb/20260101-0922/wandb/run-20260101_092219--dsmtp.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202601010919/files/requirements.txt +169 -0
  21. tb/20260101-0922/wandb/run-20260101_092219--dsmtp.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202601010919/files/wandb-metadata.json +146 -0
  22. torchtitan/experiments/deepseek_v3/attn_mask_utils.py +397 -0
  23. torchtitan/experiments/deepseek_v3/inference.sh +15 -0
  24. torchtitan/experiments/deepseek_v3/model.py +1325 -0
  25. torchtitan/experiments/deepseek_v3/symm_mem_recipes/__init__.py +11 -0
  26. torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_barrier.py +159 -0
  27. torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_on_device_all_to_all_v.py +260 -0
  28. torchtitan/experiments/deepseek_v3/train.py +142 -0
  29. torchtitan/experiments/flux/README.md +23 -0
  30. torchtitan/experiments/flux/__init__.py +122 -0
  31. torchtitan/experiments/flux/flux_argparser.py +42 -0
  32. torchtitan/experiments/flux/loss.py +27 -0
  33. torchtitan/experiments/flux/parallelize_flux.py +26 -0
  34. torchtitan/experiments/flux/requirements.txt +2 -0
  35. torchtitan/experiments/flux/scripts/download_autoencoder.py +61 -0
  36. torchtitan/experiments/flux/tests/test_flux_dataloader.py +103 -0
  37. torchtitan/experiments/flux/train.py +224 -0
  38. torchtitan/experiments/flux/train_configs/debug_model.toml +68 -0
  39. torchtitan/experiments/flux/utils.py +203 -0
  40. torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/__init__.py +13 -0
  41. torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/fast_debug_ao.py +299 -0
  42. torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/tma_autotuning.py +240 -0
  43. torchtitan/experiments/llama4/README.md +29 -0
  44. torchtitan/experiments/llama4/__pycache__/__init__.cpython-312.pyc +0 -0
  45. torchtitan/experiments/llama4/infra/__pycache__/parallelize_llama.cpython-312.pyc +0 -0
  46. torchtitan/experiments/llama4/infra/expert_parallel.py +145 -0
  47. torchtitan/experiments/llama4/infra/parallelize_llama.py +159 -0
  48. torchtitan/experiments/llama4/model/__pycache__/args.cpython-312.pyc +0 -0
  49. torchtitan/experiments/llama4/model/__pycache__/model.cpython-312.pyc +0 -0
  50. torchtitan/experiments/llama4/model/__pycache__/moe.cpython-312.pyc +0 -0
fla/models/delta_net/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (729 Bytes). View file
 
fla/models/delta_net/__pycache__/configuration_delta_net.cpython-312.pyc ADDED
Binary file (3.62 kB). View file
 
fla/models/forgetting_transformer/__pycache__/configuration_forgetting_transformer.cpython-312.pyc ADDED
Binary file (2.52 kB). View file
 
fla/models/gla/__pycache__/modeling_gla.cpython-312.pyc ADDED
Binary file (18.6 kB). View file
 
fla/models/gla/configuration_gla.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Dict, Optional
4
+
5
+ from transformers.configuration_utils import PretrainedConfig
6
+
7
+
8
class GLAConfig(PretrainedConfig):
    """
    Configuration class for GLA (Gated Linear Attention) models.

    Notable parameters:
        hidden_size: model width (embedding dimension).
        expand_k / expand_v: expansion ratios for the key / value projections;
            fractional values are allowed (e.g. the default ``expand_k=0.5``),
            hence the ``float`` annotation.
        attn: optional dict configuring hybrid (softmax) attention layers.
            Must contain ``'layers'`` (indices of the hybrid layers) and
            ``'num_heads'``; missing optional keys are filled with defaults.
        **kwargs: forwarded to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: if ``attn`` is provided but is not a dict or is missing
            the required ``'layers'`` / ``'num_heads'`` keys.
    """

    model_type = 'gla'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        hidden_size: int = 2048,
        # NOTE: annotated as float — the default (0.5) and typical usage are
        # fractional expansion ratios, the original `int` annotation was wrong.
        expand_k: float = 0.5,
        expand_v: float = 1,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        num_hidden_layers: int = 24,
        num_heads: int = 4,
        num_kv_heads: Optional[int] = None,
        feature_map: Optional[str] = None,
        attn_mode: str = "chunk",
        use_short_conv: bool = False,
        conv_size: int = 4,
        use_output_gate: bool = True,
        clamp_min: Optional[float] = None,
        hidden_act: str = "swish",
        max_position_embeddings: int = 2048,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-6,
        use_gk: bool = True,
        use_gv: bool = False,
        attn: Optional[Dict] = None,
        use_cache: bool = True,
        pad_token_id: Optional[int] = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        initializer_range: float = 0.006,
        fuse_norm: bool = True,
        fuse_swiglu: bool = True,
        fuse_cross_entropy: bool = True,
        vocab_size: int = 32000,
        **kwargs
    ):
        self.hidden_size = hidden_size
        self.expand_k = expand_k
        self.expand_v = expand_v
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.feature_map = feature_map
        self.attn_mode = attn_mode
        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.use_output_gate = use_output_gate
        self.clamp_min = clamp_min
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.elementwise_affine = elementwise_affine
        self.norm_eps = norm_eps
        self.use_gk = use_gk
        self.use_gv = use_gv
        self.attn = attn
        self.use_cache = use_cache
        self.initializer_range = initializer_range

        self.fuse_norm = fuse_norm
        self.fuse_swiglu = fuse_swiglu
        self.fuse_cross_entropy = fuse_cross_entropy
        self.vocab_size = vocab_size

        if attn is not None:
            # Validate required hybrid-attention keys before filling defaults.
            # Use the builtin `dict` for the check (typing.Dict is deprecated
            # as an isinstance target).
            if not isinstance(attn, dict):
                raise ValueError("attn must be a dictionary")
            if 'layers' not in attn:
                raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
            if 'num_heads' not in attn:
                raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
            # Fill optional settings with defaults (MHA by default, no sliding
            # window, standard RoPE base).
            attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
            attn['qkv_bias'] = attn.get('qkv_bias', False)
            attn['window_size'] = attn.get('window_size', None)
            attn['rope_theta'] = attn.get('rope_theta', 10000.)

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
fla/models/linear_attn/__pycache__/modeling_linear_attn.cpython-312.pyc ADDED
Binary file (18.5 kB). View file
 
fla/models/retnet/__pycache__/configuration_retnet.cpython-312.pyc ADDED
Binary file (3.76 kB). View file
 
fla/models/rwkv6/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (715 Bytes). View file
 
fla/models/rwkv7/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (715 Bytes). View file
 
fla/models/samba/__pycache__/configuration_samba.cpython-312.pyc ADDED
Binary file (3.42 kB). View file
 
fla/models/samba/__pycache__/modeling_samba.cpython-312.pyc ADDED
Binary file (20.9 kB). View file
 
fla/models/transformer/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (756 Bytes). View file
 
fla/models/transformer_top/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (777 Bytes). View file
 
fla/models/transformer_top/__pycache__/configuration_transformer.cpython-312.pyc ADDED
Binary file (2.83 kB). View file
 
fla/modules/__pycache__/activations.cpython-312.pyc ADDED
Binary file (23 kB). View file
 
fla/modules/__pycache__/convolution.cpython-312.pyc ADDED
Binary file (21.1 kB). View file
 
fla/modules/__pycache__/fused_bitlinear.cpython-312.pyc ADDED
Binary file (23.7 kB). View file
 
fla/modules/__pycache__/mlp.cpython-312.pyc ADDED
Binary file (6.26 kB). View file
 
fla/modules/__pycache__/rotary.cpython-312.pyc ADDED
Binary file (23.2 kB). View file
 
tb/20260101-0922/wandb/run-20260101_092219--dsmtp.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202601010919/files/requirements.txt ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ flame==0.1.0
2
+ fsspec==2025.10.0
3
+ aiohappyeyeballs==2.6.1
4
+ ipykernel==7.1.0
5
+ smmap==5.0.2
6
+ pybind11==3.0.1
7
+ tabulate==0.9.0
8
+ parso==0.8.5
9
+ yarl==1.22.0
10
+ asttokens==3.0.1
11
+ pandas==2.3.3
12
+ xxhash==3.6.0
13
+ pathvalidate==3.3.1
14
+ Werkzeug==3.1.4
15
+ regex==2025.11.3
16
+ inquirerpy==0.3.4
17
+ click==8.3.1
18
+ idna==3.11
19
+ pydantic==2.12.5
20
+ pexpect==4.9.0
21
+ typepy==1.3.4
22
+ certifi==2025.11.12
23
+ wcwidth==0.2.14
24
+ triton==3.2.0
25
+ hf-xet==1.2.0
26
+ joblib==1.5.3
27
+ tqdm==4.67.1
28
+ nvidia-nvtx-cu12==12.4.127
29
+ setuptools==80.9.0
30
+ lxml==6.0.2
31
+ nvidia-cufft-cu12==11.2.1.3
32
+ evaluate==0.4.6
33
+ Markdown==3.10
34
+ chardet==5.2.0
35
+ multiprocess==0.70.18
36
+ tensorboard==2.20.0
37
+ nvidia-nvjitlink-cu12==12.4.127
38
+ flame==0.1.0
39
+ matplotlib-inline==0.2.1
40
+ Cython==3.2.3
41
+ tensorboard-data-server==0.7.2
42
+ nvidia-cusparse-cu12==12.3.1.170
43
+ lm_eval==0.4.9.1
44
+ pure_eval==0.2.3
45
+ protobuf==6.33.2
46
+ DataProperty==1.1.0
47
+ nvidia-cudnn-cu12==9.1.0.70
48
+ accelerate==1.12.0
49
+ psutil==7.1.3
50
+ Jinja2==3.1.6
51
+ scikit-learn==1.8.0
52
+ nvidia-nccl-cu12==2.21.5
53
+ typing_extensions==4.15.0
54
+ pyzmq==27.1.0
55
+ mpmath==1.3.0
56
+ annotated-types==0.7.0
57
+ propcache==0.4.1
58
+ wandb==0.23.1
59
+ requests==2.32.5
60
+ ipython==9.8.0
61
+ more-itertools==10.8.0
62
+ nvidia-cuda-runtime-cu12==12.4.127
63
+ sacrebleu==2.5.1
64
+ httpx==0.28.1
65
+ huggingface-hub==0.36.0
66
+ MarkupSafe==3.0.3
67
+ nvidia-cusolver-cu12==11.6.1.9
68
+ gitdb==4.0.12
69
+ torchdata==0.11.0
70
+ sentry-sdk==2.48.0
71
+ sympy==1.13.1
72
+ safetensors==0.7.0
73
+ httpcore==1.0.9
74
+ portalocker==3.2.0
75
+ attrs==25.4.0
76
+ typing-inspection==0.4.2
77
+ ptyprocess==0.7.0
78
+ nvidia-cublas-cu12==12.4.5.8
79
+ numexpr==2.14.1
80
+ executing==2.2.1
81
+ networkx==3.6.1
82
+ threadpoolctl==3.6.0
83
+ nvidia-cusparselt-cu12==0.6.2
84
+ einops==0.8.1
85
+ zstandard==0.25.0
86
+ comm==0.2.3
87
+ six==1.17.0
88
+ packaging==25.0
89
+ tqdm-multiprocess==0.0.11
90
+ numpy==2.3.5
91
+ colorama==0.4.6
92
+ nvidia-cuda-cupti-cu12==12.4.127
93
+ jupyter_client==8.7.0
94
+ scipy==1.16.3
95
+ tornado==6.5.4
96
+ nltk==3.9.2
97
+ antlr4-python3-runtime==4.11.0
98
+ jupyter_core==5.9.1
99
+ sqlitedict==2.1.0
100
+ tzdata==2025.3
101
+ pytz==2025.2
102
+ Pygments==2.19.2
103
+ python-dotenv==1.2.1
104
+ cmake==4.2.0
105
+ tiktoken==0.12.0
106
+ PyYAML==6.0.3
107
+ datasets==4.4.1
108
+ pillow==12.0.0
109
+ math-verify==0.8.0
110
+ dill==0.4.0
111
+ nvidia-cuda-nvrtc-cu12==12.4.127
112
+ anyio==4.12.0
113
+ prompt_toolkit==3.0.52
114
+ filelock==3.20.1
115
+ jedi==0.19.2
116
+ frozenlist==1.8.0
117
+ tokenizers==0.21.4
118
+ grpcio==1.76.0
119
+ ninja==1.13.0
120
+ mbstrdecoder==1.1.4
121
+ flash-attn==2.7.3
122
+ aiosignal==1.4.0
123
+ tabledata==1.3.4
124
+ h11==0.16.0
125
+ absl-py==2.3.1
126
+ latex2sympy2_extended==1.10.2
127
+ torch==2.6.0
128
+ nest_asyncio==1.6.0
129
+ pip==25.3
130
+ aiohttp==3.13.2
131
+ pfzy==0.3.4
132
+ platformdirs==4.5.1
133
+ wheel==0.45.1
134
+ peft==0.17.0
135
+ debugpy==1.8.19
136
+ ipython_pygments_lexers==1.1.1
137
+ rouge_score==0.1.2
138
+ multidict==6.7.0
139
+ tcolorpy==0.1.7
140
+ nvidia-curand-cu12==10.3.5.147
141
+ pydantic_core==2.41.5
142
+ pytablewriter==1.2.1
143
+ charset-normalizer==3.4.4
144
+ transformers==4.51.3
145
+ word2number==1.1
146
+ jsonlines==4.0.0
147
+ stack_data==0.6.3
148
+ urllib3==2.6.2
149
+ decorator==5.2.1
150
+ python-dateutil==2.9.0.post0
151
+ pyarrow==22.0.0
152
+ traitlets==5.14.3
153
+ GitPython==3.1.45
154
+ tomli==2.0.1
155
+ more-itertools==10.3.0
156
+ inflect==7.3.1
157
+ zipp==3.19.2
158
+ jaraco.functools==4.0.1
159
+ autocommand==2.2.2
160
+ jaraco.collections==5.1.0
161
+ platformdirs==4.2.2
162
+ backports.tarfile==1.2.0
163
+ importlib_metadata==8.0.0
164
+ jaraco.text==3.12.1
165
+ typing_extensions==4.12.2
166
+ jaraco.context==5.3.0
167
+ typeguard==4.3.0
168
+ packaging==24.2
169
+ wheel==0.45.1
tb/20260101-0922/wandb/run-20260101_092219--dsmtp.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine-202601010919/files/wandb-metadata.json ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "os": "Linux-6.8.0-85-generic-x86_64-with-glibc2.39",
3
+ "python": "CPython 3.12.12",
4
+ "startedAt": "2026-01-01T09:22:19.321743Z",
5
+ "args": [
6
+ "--job.config_file",
7
+ "flame/models/fla.toml",
8
+ "--job.dump_folder",
9
+ "exp/dsmtp.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine",
10
+ "--model.config",
11
+ "configs/dsmtp_transformer_7B.json",
12
+ "--model.tokenizer_path",
13
+ "fla-hub/transformer-1.3B-100B",
14
+ "--optimizer.name",
15
+ "AdamW",
16
+ "--optimizer.eps",
17
+ "1e-15",
18
+ "--optimizer.lr",
19
+ "2e-5",
20
+ "--lr_scheduler.warmup_steps",
21
+ "400",
22
+ "--lr_scheduler.lr_min",
23
+ "0.1",
24
+ "--lr_scheduler.decay_type",
25
+ "cosine",
26
+ "--training.batch_size",
27
+ "8",
28
+ "--training.seq_len",
29
+ "4096",
30
+ "--training.context_len",
31
+ "4096",
32
+ "--training.gradient_accumulation_steps",
33
+ "2",
34
+ "--training.steps",
35
+ "40000",
36
+ "--training.max_norm",
37
+ "1.0",
38
+ "--training.skip_nan_inf",
39
+ "--training.dataset",
40
+ "/root/.cache/zaydzuhri___stack-edu-python/default",
41
+ "--training.dataset_split",
42
+ "train",
43
+ "--training.num_workers",
44
+ "32",
45
+ "--training.prefetch_factor",
46
+ "2",
47
+ "--training.seed",
48
+ "79",
49
+ "--training.compile",
50
+ "--checkpoint.interval",
51
+ "8000",
52
+ "--checkpoint.load_step",
53
+ "-1",
54
+ "--metrics.log_freq",
55
+ "5",
56
+ "--checkpoint.hf_upload_enabled",
57
+ "--checkpoint.hf_repo_base_name",
58
+ "zaydzuhri/dsmtp-code-7B-4096-batch8x2-steps40000",
59
+ "--comm.init_timeout_seconds",
60
+ "6000",
61
+ "--comm.train_timeout_seconds",
62
+ "6000"
63
+ ],
64
+ "program": "-m flame.train",
65
+ "git": {
66
+ "remote": "https://github.com/zaydzuhri/flame.git",
67
+ "commit": "5bcd6b6423606e07b92dd2644ecc24d908d2c7a4"
68
+ },
69
+ "email": "zaydzuhri@gmail.com",
70
+ "root": "exp/dsmtp.code.7B.batch8.seqlen4096.context4096.warmup400.update2.steps40000.lr2e-5.cosine/tb/20260101-0922",
71
+ "host": "rentals-6z3zwezo0sfapf3y-697b4fc787-gh86h",
72
+ "executable": "/root/miniconda3/envs/flame-env/bin/python3.12",
73
+ "cpu_count": 64,
74
+ "cpu_count_logical": 128,
75
+ "gpu": "NVIDIA H200",
76
+ "gpu_count": 8,
77
+ "disk": {
78
+ "/": {
79
+ "total": "3246163542016",
80
+ "used": "1652645769216"
81
+ }
82
+ },
83
+ "memory": {
84
+ "total": "1913835118592"
85
+ },
86
+ "gpu_nvidia": [
87
+ {
88
+ "name": "NVIDIA H200",
89
+ "memoryTotal": "150754820096",
90
+ "cudaCores": 16896,
91
+ "architecture": "Hopper",
92
+ "uuid": "GPU-bf7aa6f4-2ee0-0ff7-3851-1b40dafcde1f"
93
+ },
94
+ {
95
+ "name": "NVIDIA H200",
96
+ "memoryTotal": "150754820096",
97
+ "cudaCores": 16896,
98
+ "architecture": "Hopper",
99
+ "uuid": "GPU-24e3a14c-3196-7560-5e54-cd031aa25f76"
100
+ },
101
+ {
102
+ "name": "NVIDIA H200",
103
+ "memoryTotal": "150754820096",
104
+ "cudaCores": 16896,
105
+ "architecture": "Hopper",
106
+ "uuid": "GPU-3e484efe-97e7-7b5b-e6b7-d1dc17ed2765"
107
+ },
108
+ {
109
+ "name": "NVIDIA H200",
110
+ "memoryTotal": "150754820096",
111
+ "cudaCores": 16896,
112
+ "architecture": "Hopper",
113
+ "uuid": "GPU-7b9f4a41-11cd-03b7-0065-5e4dab09ddd4"
114
+ },
115
+ {
116
+ "name": "NVIDIA H200",
117
+ "memoryTotal": "150754820096",
118
+ "cudaCores": 16896,
119
+ "architecture": "Hopper",
120
+ "uuid": "GPU-3f34c938-ca85-e68c-f501-59e20f64e14c"
121
+ },
122
+ {
123
+ "name": "NVIDIA H200",
124
+ "memoryTotal": "150754820096",
125
+ "cudaCores": 16896,
126
+ "architecture": "Hopper",
127
+ "uuid": "GPU-9d38c94f-6fc0-6735-7f0d-3e359e2562cb"
128
+ },
129
+ {
130
+ "name": "NVIDIA H200",
131
+ "memoryTotal": "150754820096",
132
+ "cudaCores": 16896,
133
+ "architecture": "Hopper",
134
+ "uuid": "GPU-a7fc49a5-17e0-6e8f-feee-4ffa6e526637"
135
+ },
136
+ {
137
+ "name": "NVIDIA H200",
138
+ "memoryTotal": "150754820096",
139
+ "cudaCores": 16896,
140
+ "architecture": "Hopper",
141
+ "uuid": "GPU-731d225c-6930-54d4-db0f-f53e16eaeb2e"
142
+ }
143
+ ],
144
+ "cudaVersion": "12.8",
145
+ "writerId": "allj4tgslt7j35odul2j6cl35jpfxh7k"
146
+ }
torchtitan/experiments/deepseek_v3/attn_mask_utils.py ADDED
@@ -0,0 +1,397 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # This code is based on src/transformers/modeling_attn_mask_utils.py of
8
+ # huggingface/transformers. It has been modified from its original forms to
9
+ # contain only the necessary utilities.
10
+
11
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
12
+ #
13
+ # Licensed under the Apache License, Version 2.0 (the "License");
14
+ # you may not use this file except in compliance with the License.
15
+ # You may obtain a copy of the License at
16
+ #
17
+ # http://www.apache.org/licenses/LICENSE-2.0
18
+ #
19
+ # Unless required by applicable law or agreed to in writing, software
20
+ # distributed under the License is distributed on an "AS IS" BASIS,
21
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22
+ # See the License for the specific language governing permissions and
23
+ # limitations under the License.
24
+ from dataclasses import dataclass
25
+ from typing import List, Optional, Tuple, Union
26
+
27
+ import torch
28
+
29
+
30
+ @dataclass
31
+ class AttentionMaskConverter:
32
+ """
33
+ A utility attention mask class that allows one to:
34
+ - Create a causal 4d mask
35
+ - Create a causal 4d mask with slided window
36
+ - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length,
37
+ key_value_length) that can be multiplied with attention scores
38
+
39
+ Examples:
40
+
41
+ ```python
42
+ >>> import torch
43
+ >>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter
44
+
45
+ >>> converter = AttentionMaskConverter(True)
46
+ >>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32)
47
+ tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
48
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
49
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
50
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, -3.4028e+38],
51
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, 0.0000e+00]]]])
52
+ ```
53
+
54
+ Parameters:
55
+ is_causal (`bool`):
56
+ Whether the attention mask should be a uni-directional (causal) or bi-directional mask.
57
+
58
+ sliding_window (`int`, *optional*):
59
+ Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer.
60
+ """
61
+
62
+ is_causal: bool
63
+ sliding_window: int
64
+
65
+ def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):
66
+ self.is_causal = is_causal
67
+ self.sliding_window = sliding_window
68
+
69
+ if self.sliding_window is not None and self.sliding_window <= 0:
70
+ raise ValueError(
71
+ "Make sure that when passing `sliding_window` that its value is a strictly positive integer, "
72
+ f"not `{self.sliding_window}`"
73
+ )
74
+
75
+ def to_causal_4d(
76
+ self,
77
+ batch_size: int,
78
+ query_length: int,
79
+ key_value_length: int,
80
+ dtype: torch.dtype,
81
+ device: Union[torch.device, "str"] = "cpu",
82
+ ) -> Optional[torch.Tensor]:
83
+ """
84
+ Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
85
+ bias to upper right hand triangular matrix (causal mask).
86
+ """
87
+ if not self.is_causal:
88
+ raise ValueError(
89
+ f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True."
90
+ )
91
+
92
+ # If shape is not cached, create a new causal mask and cache it
93
+ input_shape = (batch_size, query_length)
94
+ past_key_values_length = key_value_length - query_length
95
+
96
+ # create causal mask
97
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
98
+ causal_4d_mask = None
99
+ if input_shape[-1] > 1 or self.sliding_window is not None:
100
+ causal_4d_mask = self._make_causal_mask(
101
+ input_shape,
102
+ dtype,
103
+ device=device,
104
+ past_key_values_length=past_key_values_length,
105
+ sliding_window=self.sliding_window,
106
+ )
107
+
108
+ return causal_4d_mask
109
+
110
+ def to_4d(
111
+ self,
112
+ attention_mask_2d: torch.Tensor,
113
+ query_length: int,
114
+ dtype: torch.dtype,
115
+ key_value_length: Optional[int] = None,
116
+ ) -> torch.Tensor:
117
+ """
118
+ Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
119
+ key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
120
+ causal, a causal mask will be added.
121
+ """
122
+ input_shape = (attention_mask_2d.shape[0], query_length)
123
+
124
+ # create causal mask
125
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
126
+ causal_4d_mask = None
127
+ if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
128
+ if key_value_length is None:
129
+ raise ValueError(
130
+ "This attention mask converter is causal. Make sure to pass "
131
+ "`key_value_length` to correctly create a causal mask."
132
+ )
133
+
134
+ past_key_values_length = key_value_length - query_length
135
+ causal_4d_mask = self._make_causal_mask(
136
+ input_shape,
137
+ dtype,
138
+ device=attention_mask_2d.device,
139
+ past_key_values_length=past_key_values_length,
140
+ sliding_window=self.sliding_window,
141
+ )
142
+ elif self.sliding_window is not None:
143
+ raise NotImplementedError(
144
+ "Sliding window is currently only implemented for causal masking"
145
+ )
146
+
147
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
148
+ expanded_attn_mask = self._expand_mask(
149
+ attention_mask_2d, dtype, tgt_len=input_shape[-1]
150
+ ).to(attention_mask_2d.device)
151
+
152
+ if causal_4d_mask is not None:
153
+ expanded_attn_mask = causal_4d_mask.masked_fill(
154
+ expanded_attn_mask.bool(), torch.finfo(dtype).min
155
+ )
156
+
157
+ # expanded_attn_mask + causal_4d_mask can cause some overflow
158
+ expanded_4d_mask = expanded_attn_mask
159
+
160
+ return expanded_4d_mask
161
+
162
+ @staticmethod
163
+ def _make_causal_mask(
164
+ input_ids_shape: torch.Size,
165
+ dtype: torch.dtype,
166
+ device: torch.device,
167
+ past_key_values_length: int = 0,
168
+ sliding_window: Optional[int] = None,
169
+ ):
170
+ """
171
+ Make causal mask used for bi-directional self-attention.
172
+ """
173
+ bsz, tgt_len = input_ids_shape
174
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
175
+ mask_cond = torch.arange(mask.size(-1), device=device)
176
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
177
+
178
+ mask = mask.to(dtype)
179
+
180
+ if past_key_values_length > 0:
181
+ mask = torch.cat(
182
+ [
183
+ torch.zeros(
184
+ tgt_len, past_key_values_length, dtype=dtype, device=device
185
+ ),
186
+ mask,
187
+ ],
188
+ dim=-1,
189
+ )
190
+
191
+ # add lower triangular sliding window mask if necessary
192
+ if sliding_window is not None:
193
+ diagonal = past_key_values_length - sliding_window - 1
194
+
195
+ context_mask = torch.tril(
196
+ torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal
197
+ )
198
+ mask.masked_fill_(context_mask, torch.finfo(dtype).min)
199
+
200
+ return mask[None, None, :, :].expand(
201
+ bsz, 1, tgt_len, tgt_len + past_key_values_length
202
+ )
203
+
204
+ @staticmethod
205
+ def _expand_mask(
206
+ mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
207
+ ):
208
+ """
209
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
210
+ """
211
+ bsz, src_len = mask.size()
212
+ tgt_len = tgt_len if tgt_len is not None else src_len
213
+
214
+ expanded_mask = (
215
+ mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
216
+ )
217
+
218
+ inverted_mask = 1.0 - expanded_mask
219
+
220
+ return inverted_mask.masked_fill(
221
+ inverted_mask.to(torch.bool), torch.finfo(dtype).min
222
+ )
223
+
224
+ @staticmethod
225
+ def _unmask_unattended(
226
+ expanded_mask: torch.FloatTensor,
227
+ min_dtype: float,
228
+ ):
229
+ # fmt: off
230
+ """
231
+ Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when
232
+ using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
233
+ Details: https://github.com/pytorch/pytorch/issues/110213
234
+
235
+ `expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len].
236
+ `attention_mask` is [bsz, src_seq_len].
237
+
238
+ The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case
239
+ of alibi attention bias.
240
+
241
+ For example, if `expanded_mask` is (e.g. here left-padding case)
242
+ ```
243
+ [[[[0, 0, 0],
244
+ [0, 0, 0],
245
+ [0, 0, 1]]],
246
+ [[[1, 0, 0],
247
+ [1, 1, 0],
248
+ [1, 1, 1]]],
249
+ [[[0, 0, 0],
250
+ [0, 1, 0],
251
+ [0, 1, 1]]]]
252
+ ```
253
+ then the modified `expanded_mask` will be
254
+ ```
255
+ [[[[1, 1, 1], <-- modified
256
+ [1, 1, 1], <-- modified
257
+ [0, 0, 1]]],
258
+ [[[1, 0, 0],
259
+ [1, 1, 0],
260
+ [1, 1, 1]]],
261
+ [[[1, 1, 1], <-- modified
262
+ [0, 1, 0],
263
+ [0, 1, 1]]]]
264
+ ```
265
+ """
266
+ # fmt: on
267
+ if expanded_mask.dtype == torch.bool:
268
+ raise ValueError(
269
+ "AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor."
270
+ )
271
+
272
+ return expanded_mask.mul(
273
+ ~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True)
274
+ )
275
+
276
+ @staticmethod
277
+ def _ignore_causal_mask_sdpa(
278
+ attention_mask: Optional[torch.Tensor],
279
+ inputs_embeds: torch.Tensor,
280
+ past_key_values_length: int,
281
+ sliding_window: Optional[int] = None,
282
+ is_training: bool = False,
283
+ ) -> bool:
284
+ """
285
+ Detects whether the optional user-specified attention_mask & the automatically created causal mask can be
286
+ ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument.
287
+
288
+ In case no token is masked in the `attention_mask` argument, if `query_length == 1` or
289
+ `key_value_length == query_length`, we rather rely on SDPA `is_causal` argument to use causal/non-causal masks,
290
+ allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is
291
+ passed).
292
+ """
293
+
294
+ _, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1]
295
+ key_value_length = query_length + past_key_values_length
296
+
297
+ is_tracing = (
298
+ torch.jit.is_tracing()
299
+ or isinstance(inputs_embeds, torch.fx.Proxy)
300
+ or is_torchdynamo_compiling()
301
+ )
302
+
303
+ ignore_causal_mask = False
304
+
305
+ if attention_mask is None:
306
+ # TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input
307
+ # shape, thus SDPA's `is_causal` argument is rightfully updated
308
+ # (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363). However, when using
309
+ # `torch.export` or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is
310
+ # hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code `is_causal=True`
311
+ # which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108).
312
+ # Thus, we only set `ignore_causal_mask = True` if the model is set to training.
313
+ #
314
+ # Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal`
315
+ # ("TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor").
316
+ if (
317
+ (is_training or not is_tracing)
318
+ and (query_length == 1 or key_value_length == query_length)
319
+ and (sliding_window is None or key_value_length < sliding_window)
320
+ ):
321
+ ignore_causal_mask = True
322
+ elif sliding_window is None or key_value_length < sliding_window:
323
+ if len(attention_mask.shape) == 4:
324
+ return False
325
+ elif not is_tracing and torch.all(attention_mask == 1):
326
+ if query_length == 1 or key_value_length == query_length:
327
+ # For query_length == 1, causal attention and bi-directional attention are the same.
328
+ ignore_causal_mask = True
329
+
330
+ # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore
331
+ # the attention mask, as SDPA causal mask generation may be wrong. We will set `is_causal=False` in
332
+ # SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
333
+ # Reference: https://github.com/pytorch/pytorch/issues/108108
334
+ # TODO: maybe revisit this with https://github.com/pytorch/pytorch/pull/114823 in PyTorch 2.3.
335
+
336
+ return ignore_causal_mask
337
+
338
+
339
def _prepare_4d_causal_attention_mask(
    attention_mask: Optional[torch.Tensor],
    input_shape: Union[torch.Size, Tuple, List],
    inputs_embeds: torch.Tensor,
    past_key_values_length: int,
    sliding_window: Optional[int] = None,
):
    """
    Build the `(batch_size, 1, query_length, key_value_length)` additive causal
    mask consumed by the attention layers.

    Behavior by input:
    - 2D `attention_mask` of shape `(batch_size, key_value_length)`: expanded
      to 4D and combined with a causal (optionally sliding-window) mask.
    - 4D `attention_mask`: validated against the expected shape, then inverted
      and filled with the dtype minimum (0 where attended).
    - `None` (or any other rank): a pure causal 4D mask is created from
      `input_shape`, `inputs_embeds` (dtype/device) and the cache length.

    Args:
        attention_mask: optional 2D or 4D mask tensor.
        input_shape: `(batch_size, query_length)`.
        inputs_embeds: embedded inputs; supplies dtype and device.
        past_key_values_length: length of the key/value cache.
        sliding_window: optional window size for windowed attention.

    Raises:
        ValueError: when a 4D mask does not have the expected shape.
    """
    converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
    key_value_length = input_shape[-1] + past_key_values_length

    mask_rank = attention_mask.dim() if attention_mask is not None else None

    if mask_rank == 2:
        # Expand the padding mask to 4D and merge with the causal mask.
        return converter.to_4d(
            attention_mask,
            input_shape[-1],
            key_value_length=key_value_length,
            dtype=inputs_embeds.dtype,
        )

    if mask_rank == 4:
        expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
        if tuple(attention_mask.shape) != expected_shape:
            raise ValueError(
                f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
            )
        # Shape is valid: invert (1 -> attend) and fill blocked positions
        # with the most negative representable value.
        flipped = 1.0 - attention_mask
        return flipped.masked_fill(
            flipped.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
        )

    # No mask supplied (or unsupported rank): fall back to a plain causal mask.
    return converter.to_causal_4d(
        input_shape[0],
        input_shape[-1],
        key_value_length,
        dtype=inputs_embeds.dtype,
        device=inputs_embeds.device,
    )
torchtitan/experiments/deepseek_v3/inference.sh ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
#!/usr/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Launch distributed inference for the DeepSeek-V3 model via generate.py.
# Usage: [NGPU=<n>] ./inference.sh ["your prompt here"]

# Number of GPUs (ranks) on this node; override with the NGPU env var.
NGPU=${NGPU:-"4"}

# Get the prompt from command line argument or use a default
prompt="${1:-What is 2+2?}"

# Run the model with the prompt; --standalone runs a single-node rendezvous.
torchrun --standalone --nproc-per-node ${NGPU} generate.py "$prompt"
torchtitan/experiments/deepseek_v3/model.py ADDED
@@ -0,0 +1,1325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # This code is based on model definition of `deepseek-ai/DeepSeek-V3-Base` on
8
+ # Hugging Face Model Hub. Url:
9
+ # https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/blob/main/modeling_deepseek.py
10
+ # https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/resolve/main/configuration_deepseek.py
11
+ #
12
+ # It has been modified from its original forms to accommodate naming convention
13
+ # and usage patterns of the TorchTitan project.
14
+
15
+ # Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
16
+ #
17
+ # Licensed under the Apache License, Version 2.0 (the "License");
18
+ # you may not use this file except in compliance with the License.
19
+ # You may obtain a copy of the License at
20
+ #
21
+ # http://www.apache.org/licenses/LICENSE-2.0
22
+ #
23
+ # Unless required by applicable law or agreed to in writing, software
24
+ # distributed under the License is distributed on an "AS IS" BASIS,
25
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26
+ # See the License for the specific language governing permissions and
27
+ # limitations under the License.
28
+ """ PyTorch DeepSeek model."""
29
+ import math
30
+ from typing import Optional, Tuple
31
+
32
+ import torch
33
+ import torch.distributed as dist
34
+
35
+ import torch.distributed._symmetric_memory as symm_mem
36
+ import torch.nn.functional as F
37
+ import torch.utils.checkpoint
38
+
39
+ from attn_mask_utils import _prepare_4d_causal_attention_mask
40
+ from indices import generate_permute_indices
41
+ from model_config import ModelArgs
42
+ from symm_mem_recipes import OnDeviceAllToAllV
43
+ from torch import nn
44
+ from torch.distributed._functional_collectives import all_to_all_single_autograd
45
+
46
+ from torchtitan.experiments.kernels.triton_mg_group_gemm.torchao_pr import (
47
+ ALIGN_SIZE_M,
48
+ grouped_gemm_forward,
49
+ )
50
+
51
+ # Get model parallel subgroup by name:
52
+ # e.g. "pp", "ep", None
53
+ def get_group(dim_name: Optional[str] = None) -> dist.ProcessGroup:
54
+ glob = torch.distributed.device_mesh._mesh_resources.get_current_mesh()
55
+ return glob.get_group(dim_name)
56
+
57
+
58
class RMSNorm(nn.Module):
    """Root-mean-square layer norm (no mean subtraction), computed in fp32.

    Normalizes the last dimension by its RMS and applies a learned
    per-channel scale, returning the result in the input dtype.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        # Do the reduction in fp32 for numerical stability.
        x = hidden_states.to(torch.float32)
        inv_rms = torch.rsqrt(
            x.pow(2).mean(dim=-1, keepdim=True) + self.variance_epsilon
        )
        normalized = (x * inv_rms).to(orig_dtype)
        return self.weight * normalized
70
+
71
+
72
class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) with a lazily-grown cos/sin cache.

    Precomputes `cos_cached`/`sin_cached` of shape `[seq_len, dim]` for
    positions up to `max_position_embeddings`, regrowing the cache on demand
    when a longer sequence is seen.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # Per-channel inverse frequencies: base^(-2i/dim) for i in [0, dim/2).
        inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the cos/sin tables for positions [0, seq_len)."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )

        freqs = torch.outer(t, self.inv_freq.to(t.device))
        # Different from paper, but it uses a different permutation in order
        # to obtain the same calculation.
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len is None:
            # Bug fix: previously a `None` seq_len raised a TypeError on the
            # `>` comparison below; default to the input's sequence dim.
            seq_len = x.shape[-2]
        if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
112
+
113
+
114
class LinearScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
    ):
        # Must be set before super().__init__(), which builds the cache.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Rebuild cos/sin tables with positions compressed by `scaling_factor`."""
        self.max_seq_len_cached = seq_len
        positions = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )
        # Linear position interpolation: divide indices to stretch the context.
        positions = positions / self.scaling_factor

        angles = torch.outer(positions, self.inv_freq)
        # Duplicate along the channel dim so cos/sin span the full head size.
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached", table.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", table.sin().to(dtype), persistent=False)
140
+
141
+
142
# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Deepseek
class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
    ):
        # Must be set before super().__init__(), which calls
        # _set_cos_sin_cache() and reads scaling_factor.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Rebuild cos/sin tables; enlarge the rotary `base` (NTK-aware)
        when seq_len exceeds the trained context length."""
        self.max_seq_len_cached = seq_len

        if seq_len > self.max_position_embeddings:
            # NTK scaling: grow the base so low-frequency channels cover the
            # extended context instead of compressing position indices.
            base = self.base * (
                (self.scaling_factor * seq_len / self.max_position_embeddings)
                - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))
            inv_freq = 1.0 / (
                base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
            )
            # NOTE(review): `inv_freq` stays rescaled even if a later call
            # passes a shorter seq_len — confirm this matches the upstream
            # HF implementation's intent.
            self.register_buffer("inv_freq", inv_freq, persistent=False)

        t = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )

        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
179
+
180
+
181
# Inverse dim formula to find dim based on number of rotations
def yarn_find_correction_dim(
    num_rotations, dim, base=10000, max_position_embeddings=2048
):
    """Return the (fractional) rotary dim index whose wavelength completes
    `num_rotations` full turns over `max_position_embeddings` positions."""
    numerator = dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))
    denominator = 2 * math.log(base)
    return numerator / denominator
188
+
189
+
190
# Find dim range bounds based on rotations
def yarn_find_correction_range(
    low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
):
    """Return the integer [low, high] rotary-dim indices bracketing the
    given rotation counts, clamped to [0, dim - 1]."""
    raw_low = yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
    raw_high = yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
    # Clamp values just in case
    return max(math.floor(raw_low), 0), min(math.ceil(raw_high), dim - 1)
201
+
202
+
203
def yarn_get_mscale(scale=1, mscale=1):
    """Attention-magnitude correction factor for YaRN; identity when the
    context is not extended (scale <= 1)."""
    return 1.0 if scale <= 1 else 0.1 * mscale * math.log(scale) + 1.0
207
+
208
+
209
def yarn_linear_ramp_mask(min, max, dim):
    """Length-`dim` ramp rising linearly from 0 at index `min` to 1 at
    index `max`, clamped to [0, 1].

    NOTE(review): parameter names shadow the builtins; kept for
    keyword-call compatibility.
    """
    if min == max:
        max += 0.001  # Prevent singularity

    positions = torch.arange(dim, dtype=torch.float32)
    return torch.clamp((positions - min) / (max - min), 0, 1)
216
+
217
+
218
class YarnRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding with YaRN scaling: per-dimension blend of original
    (extrapolated) and interpolated frequencies, plus a magnitude rescale
    of the cos/sin tables."""

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
        original_max_position_embeddings=4096,
        beta_fast=32,
        beta_slow=1,
        mscale=1,
        mscale_all_dim=0,
    ):
        # All YaRN hyper-parameters must be set before super().__init__(),
        # which immediately builds the cos/sin cache.
        self.scaling_factor = scaling_factor
        self.original_max_position_embeddings = original_max_position_embeddings
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        self.mscale = mscale
        self.mscale_all_dim = mscale_all_dim
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Build YaRN-interpolated cos/sin tables for positions [0, seq_len)."""
        self.max_seq_len_cached = seq_len
        dim = self.dim

        # Unscaled (extrapolated) inverse frequencies ...
        freq_extra = 1.0 / (
            self.base
            ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
        )
        # ... and position-interpolated ones (divided by scaling_factor).
        freq_inter = 1.0 / (
            self.scaling_factor
            * self.base
            ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
        )

        # Dim-index window over which to blend the two frequency sets.
        low, high = yarn_find_correction_range(
            self.beta_fast,
            self.beta_slow,
            dim,
            self.base,
            self.original_max_position_embeddings,
        )
        inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(
            device=device, dtype=torch.float32
        )
        # Blend: mask==1 keeps freq_extra (dims below `low`), mask==0 uses
        # freq_inter (dims above `high`), linear in between.
        inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        t = torch.arange(seq_len, device=device, dtype=torch.float32)

        freqs = torch.outer(t, inv_freq)

        # Magnitude correction factor (ratio of the two mscale settings).
        _mscale = float(
            yarn_get_mscale(self.scaling_factor, self.mscale)
            / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
        )

        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer(
            "cos_cached", (emb.cos() * _mscale).to(dtype), persistent=False
        )
        self.register_buffer(
            "sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False
        )
283
+
284
+
285
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input: (a, b) -> (-b, a)."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
291
+
292
+
293
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
            used to pass offsetted position ids when working with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)

    # De-interleave channel pairs: [x0, x1, x2, x3, ...] ->
    # [x0, x2, ..., x1, x3, ...], converting the interleaved layout into the
    # half-split layout that rotate_half expects.
    b, h, s, d = q.shape
    q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    b, h, s, d = k.shape
    k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
327
+
328
+
329
class MLP(nn.Module):
    """SwiGLU-style feed-forward block: down_proj(SiLU(gate_proj(x)) * up_proj(x)).

    Sizes default to `config.hidden_size` / `config.intermediate_size` but can
    be overridden (e.g. for per-expert intermediate sizes in MoE layers).
    """

    # Shared activation instance; also referenced externally as `MLP.act_fn`.
    act_fn = nn.SiLU()

    def __init__(self, config, hidden_size=None, intermediate_size=None):
        super().__init__()
        self.config = config
        self.hidden_size = hidden_size if hidden_size is not None else config.hidden_size
        self.intermediate_size = (
            intermediate_size
            if intermediate_size is not None
            else config.intermediate_size
        )

        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(gated)
347
+
348
+
349
class MoEGate(nn.Module):
    """Token-to-expert router.

    Scores every token against `n_routed_experts` (sigmoid or softmax),
    optionally restricts the choice to the best `topk_group` expert groups
    ("noaux_tc"), and returns the top-k expert indices plus their (optionally
    normalized, always scaled) weights.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts
        self.routed_scaling_factor = config.routed_scaling_factor
        self.scoring_func = config.scoring_func
        # NOTE(review): `seq_aux` is stored but not used in this forward —
        # presumably consumed elsewhere or vestigial; confirm.
        self.seq_aux = config.seq_aux
        self.topk_method = config.topk_method
        self.n_group = config.n_group
        self.topk_group = config.topk_group

        # topk selection algorithm
        self.norm_topk_prob = config.norm_topk_prob
        self.gating_dim = config.hidden_size
        # Router projection: [n_routed_experts, hidden_size].
        self.weight = nn.Parameter(
            torch.empty((self.n_routed_experts, self.gating_dim))
        )
        if self.topk_method == "noaux_tc":
            self.e_score_correction_bias = nn.Parameter(
                # Changed from torch.empty to torch.rand to avoid non-even
                # distribution for runs without actual weights
                torch.rand((self.n_routed_experts))
            )
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Kaiming-initialize the router weight."""
        import torch.nn.init as init

        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states):
        """Route a [bsz, seq_len, hidden] batch.

        Returns:
            topk_idx: [bsz * seq_len, top_k] selected expert indices.
            topk_weight: [bsz * seq_len, top_k] routing weights (scaled by
                `routed_scaling_factor`).
        """
        bsz, seq_len, h = hidden_states.shape
        # compute gating score (router math always in fp32)
        hidden_states = hidden_states.view(-1, h)
        logits = F.linear(
            hidden_states.type(torch.float32), self.weight.type(torch.float32), None
        )
        if self.scoring_func == "sigmoid":
            scores = logits.sigmoid()
        elif self.scoring_func == "softmax":
            scores = logits.softmax(dim=-1, dtype=torch.float32)
        else:
            raise NotImplementedError(
                f"insupportable scoring function for MoE gating: {self.scoring_func}"
            )

        # select top-k experts
        if self.topk_method == "noaux_tc":
            # Bias-corrected scores used only for *selection*; the returned
            # weights below are gathered from the uncorrected `scores`.
            scores_for_choice = scores.view(
                bsz * seq_len, -1
            ) + self.e_score_correction_bias.unsqueeze(0)
            # Group score = sum of the 2 best experts within each group.
            group_scores = (
                scores_for_choice.view(bsz * seq_len, self.n_group, -1)
                .topk(2, dim=-1)[0]
                .sum(dim=-1)
            )  # [n, n_group]
            group_idx = torch.topk(
                group_scores, k=self.topk_group, dim=-1, sorted=False
            )[
                1
            ]  # [n, top_k_group]
            group_mask = torch.zeros_like(group_scores)  # [n, n_group]
            group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
            # Expand the group mask to a per-expert mask.
            score_mask = (
                group_mask.unsqueeze(-1)
                .expand(
                    bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group
                )
                .reshape(bsz * seq_len, -1)
            )  # [n, e]
            # Zero out experts in non-selected groups, then pick top-k.
            tmp_scores = scores_for_choice.masked_fill(
                ~score_mask.bool(), 0.0
            )  # [n, e]
            _, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False)
            topk_weight = scores.gather(1, topk_idx)
        elif self.topk_method == "greedy":
            topk_weight, topk_idx = torch.topk(
                scores, k=self.top_k, dim=-1, sorted=False
            )
        else:
            raise NotImplementedError(
                f"insupportable TopK function for MoE gating: {self.topk_method}"
            )

        # norm gate to sum 1
        if self.top_k > 1 and self.norm_topk_prob:
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator
        topk_weight = (
            topk_weight * self.routed_scaling_factor
        )  # must multiply the scaling factor

        return topk_idx, topk_weight
444
+
445
+
446
+ class MoE(nn.Module):
447
+ """
448
+ A mixed expert module containing shared experts.
449
+ """
450
+
451
+ # Class attributes:
452
+ # Two shuffle method supported:
453
+ # 1. "torch_all_to_all"
454
+ # 2. "symm_mem" (see `setup_symm_mem` below)
455
+ shuffle_method = "torch_all_to_all"
456
+
457
+ # Symmetric memory buffers shared by all MoE instances across layers
458
+ token_send_buf: Optional[torch.Tensor] = None
459
+ token_gather_buf: Optional[torch.Tensor] = None
460
+
461
    def __init__(self, config):
        """Build this rank's shard of routed experts plus the gate and
        (optionally) shared experts.

        Requires an "ep" (expert-parallel) process group of size
        `config.ep_size`; each rank owns `n_routed_experts // ep_size`
        experts.
        """
        super().__init__()
        self.config = config
        self.num_experts_per_tok = config.num_experts_per_tok

        # ep_size is the number of ranks in expert dimension
        if config.ep_size <= 1:
            raise ValueError(
                "For code simplicity, this model only supports distributed experts, "
                "thus EP size must be > 1, please modify your model config"
            )
        self.ep_group = get_group("ep")
        assert config.ep_size == self.ep_group.size()
        self.ep_size = config.ep_size
        self.ep_rank = self.ep_group.rank()
        self.experts_per_rank = config.n_routed_experts // config.ep_size
        # Use ModuleDict instead of ModuleList to preserve absolute expert
        # IDs while avoiding `None` experts. The absolute expert IDs match
        # with checkpoint FQNs.
        self.experts = nn.ModuleDict()
        for i in range(self.experts_per_rank):
            abs_expert_id = self.ep_rank * self.experts_per_rank + i
            self.experts[str(abs_expert_id)] = MLP(
                config, intermediate_size=config.moe_intermediate_size
            )
        self.gate = MoEGate(config)
        if config.n_shared_experts is not None:
            # Shared experts are replicated (not routed); their combined
            # width is n_shared_experts * moe_intermediate_size.
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
            self.shared_experts = MLP(
                config=config, intermediate_size=intermediate_size
            )
492
+
493
    def combine_experts(self, submod_name):
        """Concatenate the named linear's weight across all local experts
        into a single `{submod_name}_weight` parameter (for grouped GEMM),
        clearing the per-expert weights in the process."""
        all_weights = []
        for expert in self.experts.values():
            lin = expert.get_submodule(submod_name)
            all_weights.append(lin.weight)
            # Detach the weight from the per-expert module so it is only
            # owned by the combined parameter below.
            lin.weight = None

        concat_weight = torch.cat(all_weights)
        self.register_parameter(f"{submod_name}_weight", nn.Parameter(concat_weight))
502
+
503
    # This function is used to create a symm mem buffer for MoE's. It is for
    # shuffling tokens fully "on-device", as compared to traditional torch
    # all_to_all APIs which require a GPU-to-CPU sync of the splits. If a user
    # calls this function, the `shuffle_method` would switch from
    # `torch_all_to_all` to `symm_mem`.
    def setup_symm_mem(self, dtype: torch.dtype, device: torch.device):
        """Switch to symmetric-memory token shuffling and prepare the
        grouped-GEMM weights and class-level send/gather buffers."""
        # Switch shuffle method
        self.shuffle_method = "symm_mem"

        # Combine expert weights so the grouped GEMM path can run one GEMM
        # per projection instead of one per expert.
        print("Combining expert weights for Group GEMM")
        self.combine_experts("gate_proj")
        self.combine_experts("up_proj")
        self.combine_experts("down_proj")

        # Assuming worst case, 2x tokens are routed to one EP rank
        overflow = 2
        OnDeviceAllToAllV.max_output_len = (
            self.config.max_seq_len * self.num_experts_per_tok * overflow
        )

        # Symmetric memory buffers are shared by all MoE instances across
        # layers, we only need to initialize them once
        if MoE.token_send_buf is not None:
            return

        # Input buffer for DP-to-EP shuffle
        MoE.token_send_buf = symm_mem.empty(
            self.config.max_seq_len
            * self.num_experts_per_tok,  # seq len * top k (flattened)
            self.config.hidden_size,  # hidden dim
            dtype=dtype,
            device=device,
        )
        # Input buffer for EP-to-DP shuffle (sized with the overflow factor)
        MoE.token_gather_buf = symm_mem.empty(
            self.config.max_seq_len
            * self.num_experts_per_tok  # seq len * top k (flattened)
            * overflow,
            self.config.hidden_size,  # hidden dim
            dtype=dtype,
            device=device,
        )
        print(f"EP rank [{self.ep_rank}]: Created Symmetric Memory for MoE")
547
+
548
    def get_send_buf(self):
        """Return the shared DP-to-EP send buffer, detached from autograd."""
        # [Why detach?] During a first forward-backward step, the buffer would
        # be included in a computational graph. In a second step, autograd will
        # return an error saying "Trying to backward through the graph a second
        # time (or directly access saved tensors more than once)". This is
        # because the buffer is still in the graph, and autograd is trying to
        # backward through the graph a second time. To avoid this, we detach the
        # buffer from the graph. `detach()` returns a new tensor, which shares
        # the same storage with the original one.
        self.token_send_buf.grad = None
        return self.token_send_buf.detach()
559
+
560
    def get_gather_buf(self):
        """Return the shared EP-to-DP gather buffer, detached from autograd."""
        # See [Why detach?] in `get_send_buf`
        self.token_gather_buf.grad = None
        return self.token_gather_buf.detach()
564
+
565
+ def forward(self, hidden_states):
566
+ identity = hidden_states
567
+ orig_shape = hidden_states.shape
568
+ # for each token, select top-k experts, and compute the weight for each expert
569
+ topk_idx, topk_weight = self.gate(hidden_states)
570
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
571
+ if self.shuffle_method == "symm_mem":
572
+ y = self.moe_on_device(hidden_states, topk_idx, topk_weight)
573
+ else: # "torch_all_to_all"
574
+ y = self.moe_forward(hidden_states, topk_idx, topk_weight)
575
+
576
+ y = y.view(*orig_shape)
577
+ if self.config.n_shared_experts is not None:
578
+ y = y + self.shared_experts(identity)
579
+ return y
580
+
581
    def moe_forward(self, x, topk_ids, topk_weight):
        """Shuffle tokens DP->EP, run local experts, shuffle back EP->DP,
        and combine expert outputs with the routing weights.

        Args:
            x: flattened tokens, [num_tokens, hidden].
            topk_ids: [num_tokens, top_k] expert indices from the gate.
            topk_weight: [num_tokens, top_k] routing weights.
        """
        # This part sorts the token indices so that tokens routed to the same expert reside consecutively.
        # An implication is that tokens to the same "expert group" (i.e., device) are also consecutive.
        # Since this is an "artificial" index creation (final outcome being
        # `idxs`), we don't need gradients here.
        with torch.no_grad():
            # [seq_len, n_routed_experts]
            cnts = topk_ids.new_zeros((topk_ids.shape[0], self.config.n_routed_experts))
            # Fill 1 to the selected experts
            cnts.scatter_(1, topk_ids, 1)
            tokens_per_expert = cnts.sum(dim=0)
            # Token indices for each expert
            idxs = topk_ids.view(-1).argsort()
            sorted_tokens_shape = idxs.shape + x.shape[1:]

        # Each token appears top_k times in `idxs`; dividing recovers the
        # original token index.
        sorted_tokens = x[idxs // topk_ids.shape[1]]
        assert sorted_tokens.shape == sorted_tokens_shape

        # This part exchanges the information about the number of tokens sent
        # and received by each expert. We can understand this information as
        # "side band", which is not part of the actual data. Thus no gradient
        # is needed.
        with torch.no_grad():
            # Sum the tokens over local experts, then we get tokens per EP rank,
            # which is the input splits
            tokens_per_expert_group = tokens_per_expert.new_empty(
                tokens_per_expert.shape[0]
            )
            dist.all_to_all_single(
                tokens_per_expert_group, tokens_per_expert, group=self.ep_group
            )
            input_splits = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)

        # DP to EP token shuffle. This part needs gradient.
        if self.shuffle_method == "symm_mem":
            # Move input to the `token_send_buf` symm mem
            token_send_buf = self.get_send_buf()
            token_send_buf[: idxs.shape[0]].copy_(sorted_tokens)
            # Note: `out=` avoids copy, but it is not differentiable
            # torch.index_select(x, 0, idxs // topk_ids.shape[1], out=self.token_send_buf[: idxs.shape[0]])
            token_gather_buf, output_splits = OnDeviceAllToAllV.apply(
                token_send_buf,
                input_splits,
                self.ep_group,
            )
            with torch.no_grad():
                # Received tokens from all other ranks. TODO: use mask instead
                received = output_splits.sum()
            # TODO: don't use `received`
            gathered_tokens = token_gather_buf[:received]
        else:  # "torch_all_to_all"
            # Prepare input and output splits
            with torch.no_grad():
                output_splits = tokens_per_expert_group.view(self.ep_size, -1).sum(
                    dim=1
                )
            gathered_tokens = all_to_all_single_autograd(
                sorted_tokens,
                output_splits.tolist(),
                input_splits.tolist(),
                self.ep_group,
            )

        # This part prepares a 1D tensor with the same length as
        # `gathered_tokens`. The 1D tensor is filled with local expert IDs which
        # the tokens in `gathered_tokens` are headed for. This part doesn't need
        # gradient.
        with torch.no_grad():
            gatherd_idxs = (
                torch.arange(
                    tokens_per_expert_group.numel(),
                    device=tokens_per_expert_group.device,
                )
                % self.experts_per_rank
            )
            gatherd_idxs = gatherd_idxs.repeat_interleave(tokens_per_expert_group)

        # Prepare buffer for tokens processed by experts
        if self.shuffle_method == "symm_mem":
            # Take necessary space from `token_gather_buf` symm mem because we are
            # going to send them out after expert processing
            processed_tokens = self.get_gather_buf()[: gathered_tokens.shape[0]]
        else:  # "torch_all_to_all"
            processed_tokens = torch.empty_like(gathered_tokens)

        # This part processes the tokens routed to the local experts.
        # TODO: can we use group GEMM here?
        for i, expert in enumerate(self.experts.values()):
            processed_tokens[gatherd_idxs == i] = expert(
                gathered_tokens[gatherd_idxs == i]
            )

        # Now shuffle the tokens back to their original owner, i.e. EP to DP shuffle.
        # The input/output splits are just a reverse of the previous shuffle.
        if self.shuffle_method == "symm_mem":
            token_return_buf, _ = OnDeviceAllToAllV.apply(
                processed_tokens,
                output_splits,
                self.ep_group,
            )
            returned_tokens = token_return_buf[: sorted_tokens_shape[0]]
        else:  # "torch_all_to_all"
            returned_tokens = all_to_all_single_autograd(
                processed_tokens,
                input_splits.tolist(),
                output_splits.tolist(),
                self.ep_group,
            )

        # Undo the expert sort, then weight and sum the top-k expert outputs
        # per token.
        output_tokens = torch.empty_like(returned_tokens)
        output_tokens[idxs] = returned_tokens
        final_out = (
            output_tokens.view(*topk_ids.shape, -1)
            .type(topk_weight.dtype)
            .mul_(topk_weight.unsqueeze(dim=-1))
            .sum(dim=1)
            .type(returned_tokens.dtype)
        )
        return final_out
700
+
701
    def moe_on_device(self, x, topk_ids, topk_weight):
        """Single-call MoE forward using on-device (symmetric-memory) all-to-all.

        Args:
            x: [num_tokens, dim] input activations.
            topk_ids: [num_tokens, top_k] expert indices chosen per token.
            topk_weight: [num_tokens, top_k] routing weights per chosen expert.

        Returns:
            [num_tokens, dim] tensor: expert outputs combined by `topk_weight`.

        Flow: sort tokens by expert -> A2A dispatch to expert-owning ranks ->
        permute into per-expert contiguous groups -> 3 grouped GEMMs (SwiGLU)
        -> A2A return shuffle -> unsort and weighted-combine.
        """
        # This part sorts the token indices so that tokens routed to the same expert reside consecutively.
        # An implication is that tokens to the same "expert group" (i.e., device) are also consecutive.
        # Since this is an "artificial" index creation (final outcome being
        # `idxs`), we don't need gradients here.
        with torch.no_grad():
            # [seq_len, n_routed_experts]
            cnts = topk_ids.new_zeros((topk_ids.shape[0], self.config.n_routed_experts))
            # Fill 1 to the selected experts
            cnts.scatter_(1, topk_ids, 1)
            tokens_per_expert = cnts.sum(dim=0)
            # Token indices for each expert
            idxs = topk_ids.view(-1).argsort()
            sorted_tokens_shape = idxs.shape + x.shape[1:]

        # `idxs // top_k` maps each (token, choice) slot back to its token row,
        # duplicating each token once per selected expert.
        sorted_tokens = x[idxs // topk_ids.shape[1]]
        assert sorted_tokens.shape == sorted_tokens_shape

        # This part exchanges the information about the number of tokens sent and
        # received by each expert. We can understand this information as "side
        # band", which is not part of the actual data. Thus no gradient is
        # needed.
        with torch.no_grad():
            # Sum the tokens over local experts, then we get tokens per EP rank,
            # which is the input splits
            tokens_per_expert_group = tokens_per_expert.new_empty(
                tokens_per_expert.shape[0]
            )
            dist.all_to_all_single(
                tokens_per_expert_group, tokens_per_expert, group=self.ep_group
            )
            input_splits = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)

        # Move input to the `token_send_buf` symm mem
        token_send_buf = self.get_send_buf()
        token_send_buf[: idxs.shape[0]].copy_(sorted_tokens)
        # Note: `out=` avoids copy, but it is not differentiable
        # torch.index_select(x, 0, idxs // topk_ids.shape[1], out=self.token_send_buf[: idxs.shape[0]])
        token_gather_buf, output_splits = OnDeviceAllToAllV.apply(
            token_send_buf,
            input_splits,
            self.ep_group,
        )

        # We need to permute the received tokens so that tokens for the same expert are contiguous.
        # This part prepares a 1D tensor `permuted_indices` for such permutation.
        # This part doesn't need gradient.
        with torch.no_grad():
            permuted_indices, m_sizes = generate_permute_indices(
                tokens_per_expert_group,
                self.experts_per_rank,
                self.ep_size,
                token_gather_buf.shape[0],
                ALIGN_SIZE_M,
            )

        # Permute the received tokens so that tokens for the same expert are contiguous.
        contig_tokens = token_gather_buf[permuted_indices]

        # Run the first grouped GEMM
        w1 = self.get_parameter("gate_proj_weight")
        gate_proj = grouped_gemm_forward(contig_tokens, w1, m_sizes)

        # Run the second grouped GEMM
        w3 = self.get_parameter("up_proj_weight")
        up_proj = grouped_gemm_forward(contig_tokens, w3, m_sizes)

        # Apply activation (SwiGLU-style gating: act(gate) * up)
        hidden_outputs = MLP.act_fn(gate_proj) * up_proj

        # Run the third grouped GEMM
        w2 = self.get_parameter("down_proj_weight")
        hidden_outputs = grouped_gemm_forward(hidden_outputs, w2, m_sizes)

        # Prepare buffer for tokens processed by experts
        # Take necessary space from `token_gather_buf` symm mem because we are
        # going to send them out after expert processing
        processed_tokens = self.get_gather_buf()

        # Move into Symmetric Memory for the return shuffle
        processed_tokens[permuted_indices] = hidden_outputs

        # Now shuffle the tokens back to their original owner, i.e. EP to DP shuffle.
        # The input/output splits are just a reverse of the previous shuffle.
        token_return_buf, _ = OnDeviceAllToAllV.apply(
            processed_tokens,
            output_splits,
            self.ep_group,
        )
        returned_tokens = token_return_buf[: sorted_tokens_shape[0]]

        # Unsort back to (token, choice) order, then combine the top-k expert
        # outputs per token with their routing weights.
        output_tokens = torch.empty_like(returned_tokens)
        output_tokens[idxs] = returned_tokens
        final_out = (
            output_tokens.view(*topk_ids.shape, -1)
            .type(topk_weight.dtype)
            .mul_(topk_weight.unsqueeze(dim=-1))
            .sum(dim=1)
            .type(returned_tokens.dtype)
        )
        return final_out
802
+
803
+
804
class Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    DeepSeek-style multi-latent attention: queries and keys/values go through
    low-rank down/up projections, and each head dim is split into a
    non-positional part (`qk_nope_head_dim`) and a RoPE-carrying part
    (`qk_rope_head_dim`). The attention itself is computed with
    `torch.nn.functional.scaled_dot_product_attention`.
    """

    def __init__(self, config: ModelArgs, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads

        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.q_lora_rank = config.q_lora_rank
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.kv_lora_rank = config.kv_lora_rank
        self.v_head_dim = config.v_head_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
        # Full per-head query/key dim = non-positional part + RoPE part.
        self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim

        self.is_causal = True

        # Query projection: a single dense projection, or a low-rank
        # down-projection + RMSNorm + up-projection when q_lora_rank is set.
        if self.q_lora_rank is None:
            self.q_proj = nn.Linear(
                self.hidden_size, self.num_heads * self.q_head_dim, bias=False
            )
        else:
            self.q_a_proj = nn.Linear(
                self.hidden_size, config.q_lora_rank, bias=config.attention_bias
            )
            self.q_a_layernorm = RMSNorm(config.q_lora_rank)
            self.q_b_proj = nn.Linear(
                config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
            )

        # KV down-projection also emits the single shared (multi-query) RoPE
        # key slice of size qk_rope_head_dim.
        self.kv_a_proj_with_mqa = nn.Linear(
            self.hidden_size,
            config.kv_lora_rank + config.qk_rope_head_dim,
            bias=config.attention_bias,
        )
        self.kv_a_layernorm = RMSNorm(config.kv_lora_rank)
        self.kv_b_proj = nn.Linear(
            config.kv_lora_rank,
            self.num_heads
            * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
            bias=False,
        )

        self.o_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=config.attention_bias,
        )
        self._init_rope()

        self.softmax_scale = self.q_head_dim ** (-0.5)
        if self.config.rope_scaling is not None:
            # YaRN-style mscale correction applied to the softmax scale.
            mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
            scaling_factor = self.config.rope_scaling["factor"]
            if mscale_all_dim:
                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
                self.softmax_scale = self.softmax_scale * mscale * mscale

    def _init_rope(self):
        """Instantiate the rotary embedding matching `config.rope_scaling`.

        Raises:
            ValueError: if `rope_scaling["type"]` is not one of
                "linear", "dynamic", or "yarn".
        """
        if self.config.rope_scaling is None:
            self.rotary_emb = RotaryEmbedding(
                self.qk_rope_head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = LinearScalingRotaryEmbedding(
                    self.qk_rope_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = DynamicNTKScalingRotaryEmbedding(
                    self.qk_rope_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "yarn":
                # Forward only the YaRN-specific keys actually present in the
                # config; the rest fall back to YarnRotaryEmbedding defaults.
                kwargs = {
                    key: self.config.rope_scaling[key]
                    for key in [
                        "original_max_position_embeddings",
                        "beta_fast",
                        "beta_slow",
                        "mscale",
                        "mscale_all_dim",
                    ]
                    if key in self.config.rope_scaling
                }
                self.rotary_emb = YarnRotaryEmbedding(
                    self.qk_rope_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                    **kwargs,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Compute attention over `hidden_states` of shape (bsz, q_len, hidden).

        Args:
            hidden_states: input activations, (batch, seq_len, hidden_size).
            attention_mask: optional 2D padding mask (batch, seq_len); expanded
                to 4D internally. When None, SDPA applies a causal mask.
            position_ids: optional positions for RoPE.

        Returns:
            The projected attention output, (batch, seq_len, hidden_size).
        """
        bsz, q_len, _ = hidden_states.size()

        if self.q_lora_rank is None:
            q = self.q_proj(hidden_states)
        else:
            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )

        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        compressed_kv, k_pe = torch.split(
            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
        )
        # k_pe is shared across heads (multi-query): head dim of 1, broadcast
        # when written into key_states below.
        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
        kv = (
            self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
            .transpose(1, 2)
        )

        k_nope, value_states = torch.split(
            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
        )
        kv_seq_len = value_states.shape[-2]

        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)

        # Reassemble full query/key heads: [nope part | rope part].
        query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
        query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
        query_states[:, :, :, self.qk_nope_head_dim :] = q_pe

        key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
        key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
        key_states[:, :, :, self.qk_nope_head_dim :] = k_pe

        if attention_mask is not None:
            # Attention mask was made 4D because the `attn_weights` above is 4D.
            # We probably can make this mask smarter if we want to pack sequences
            # together, instead of using padding. This optimization can be used in
            # inference. For training, if we want to pack sequences, data loader
            # will pass in a mask containing such info.
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask,  # None, or user provided mask in 2D
                (bsz, q_len),
                hidden_states,
                0,  # past_key_values_length, 0 when training
            )
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query=query_states,
            key=key_states,
            value=value_states,
            attn_mask=attention_mask,
            # Fix: only apply attention dropout while training. SDPA applies
            # `dropout_p` unconditionally, so passing the raw config value
            # would also drop attention weights at eval/inference time.
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=attention_mask is None,
            scale=self.softmax_scale,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
        attn_output = self.o_proj(attn_output)

        return attn_output
990
+
991
+
992
class DecoderLayer(nn.Module):
    """One pre-norm transformer decoder layer: attention + (MoE or dense) FFN.

    The feed-forward sub-module is an MoE block on layers at/after
    `first_k_dense_replace` that fall on the `moe_layer_freq` stride (when
    routed experts are configured); otherwise a plain dense MLP.
    """

    def __init__(self, config: ModelArgs, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = Attention(config=config, layer_idx=layer_idx)

        use_moe = (
            config.n_routed_experts is not None
            and layer_idx >= config.first_k_dense_replace
            and layer_idx % config.moe_layer_freq == 0
        )
        self.mlp = MoE(config) if use_moe else MLP(config)

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
        """
        # Attention sub-block with residual connection (pre-norm).
        attn_out = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block with residual connection (pre-norm).
        ffn_out = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + ffn_out

        return hidden_states
1045
+
1046
+
1047
+ Deepseek_INPUTS_DOCSTRING = r"""
1048
+ Args:
1049
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1050
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1051
+ it.
1052
+
1053
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1054
+ [`PreTrainedTokenizer.__call__`] for details.
1055
+
1056
+ [What are input IDs?](../glossary#input-ids)
1057
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1058
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1059
+
1060
+ - 1 for tokens that are **not masked**,
1061
+ - 0 for tokens that are **masked**.
1062
+
1063
+ [What are attention masks?](../glossary#attention-mask)
1064
+
1065
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1066
+ [`PreTrainedTokenizer.__call__`] for details.
1067
+
1068
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
1069
+ `past_key_values`).
1070
+
1071
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
1072
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
1073
+ information on the default strategy.
1074
+
1075
+ - 1 indicates the head is **not masked**,
1076
+ - 0 indicates the head is **masked**.
1077
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1078
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
1079
+ config.n_positions - 1]`.
1080
+
1081
+ [What are position IDs?](../glossary#position-ids)
1082
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
1083
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
1084
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
1085
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
1086
+
1087
+ Two formats are allowed:
1088
+ - a [`~cache_utils.Cache`] instance;
1089
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
1090
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
1091
+ cache format.
1092
+
1093
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
1094
+ legacy cache format will be returned.
1095
+
1096
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
1097
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
1098
+ of shape `(batch_size, sequence_length)`.
1099
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1100
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1101
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1102
+ model's internal embedding lookup matrix.
1103
+ use_cache (`bool`, *optional*):
1104
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1105
+ `past_key_values`).
1106
+ output_attentions (`bool`, *optional*):
1107
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1108
+ tensors for more detail.
1109
+ output_hidden_states (`bool`, *optional*):
1110
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1111
+ more detail.
1112
+ return_dict (`bool`, *optional*):
1113
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1114
+ """
1115
+
1116
+
1117
class DeepseekModel(torch.nn.Module):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DecoderLayer`]

    Built for pipeline parallelism: only the parts belonging to
    `config.stage_idx` (of `config.num_stages`) are instantiated. The
    embedding exists only on stage 0 and the final norm only on the last
    stage; other stages consume/produce raw hidden states.

    Args:
        config: ModelArgs
    """

    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # Creating model parts related to my stage
        assert (
            config.stage_idx < config.num_stages
        ), f"Stage {config.stage_idx} is not in the model"
        print(f"Creating model stage {config.stage_idx} of {config.num_stages}")

        # Token embedding lives only on the first pipeline stage.
        self.embed_tokens = (
            nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
            if config.stage_idx == 0
            else None
        )

        # ModuleDict keyed by the *global* layer id so checkpoints keep stable
        # parameter names regardless of which stage owns the layer.
        self.layers = torch.nn.ModuleDict()
        division = config.num_hidden_layers // config.num_stages
        residual = config.num_hidden_layers % config.num_stages
        # Some earlier stages may have 1 more layer than latter stages because
        # the division may have residual; this is more even than giving the
        # entire residual to the last stage.
        layers_per_stage = [
            division + 1 if stage < residual else division
            for stage in range(config.num_stages)
        ]
        assert sum(layers_per_stage) == config.num_hidden_layers
        layer_id_start = sum(layers_per_stage[: config.stage_idx])
        layer_id_end = layer_id_start + layers_per_stage[config.stage_idx]
        for layer_id in range(layer_id_start, layer_id_end):
            self.layers[str(layer_id)] = DecoderLayer(config, layer_id)

        # Final RMSNorm lives only on the last pipeline stage.
        self.norm = (
            RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
            if config.stage_idx == config.num_stages - 1
            else None
        )

        # Initialize weights and apply final processing
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Normal(0, initializer_range) init for Linear/Embedding weights;
        zero biases and zero the padding embedding row."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def forward(
        self,
        tokens: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """Run this stage's slice of the decoder.

        `tokens` is token ids on stage 0, and already-embedded hidden states
        on later stages (where `embed_tokens` is None).
        """
        # Embedding
        hidden_states = (
            self.embed_tokens(tokens) if self.embed_tokens is not None else tokens
        )

        # decoder layers
        for decoder_layer in self.layers.values():
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
            )

        # Final norm only on the last stage; intermediate stages pass raw
        # hidden states to the next stage.
        hidden_states = (
            self.norm(hidden_states) if self.norm is not None else hidden_states
        )
        return hidden_states
1202
+
1203
+
1204
class DeepseekForCausalLM(torch.nn.Module):
    """Causal-LM wrapper: pipeline-stage `DeepseekModel` plus an LM head that
    exists only on the last pipeline stage."""

    def __init__(self, config):
        super().__init__()
        self.model = DeepseekModel(config)
        # The unembedding projection is only materialized on the final stage;
        # earlier stages return hidden states unchanged.
        self.lm_head = (
            nn.Linear(config.hidden_size, config.vocab_size, bias=False)
            if config.stage_idx == config.num_stages - 1
            else None
        )

        # Initialize weights and apply final processing
        # self.post_init()

    def forward(
        self,
        tokens: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> Tuple:
        r"""
        Example:

        ```python
        >>> from transformers import AutoTokenizer, DeepseekForCausalLM

        >>> model = DeepseekForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        hidden_states = self.model(
            tokens,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )

        # Logits on the last stage; pass-through hidden states elsewhere.
        logits = (
            self.lm_head(hidden_states) if self.lm_head is not None else hidden_states
        )
        return logits

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        **kwargs,
    ):
        # HF-generate style hook: crop `input_ids`/`attention_mask` against the
        # KV cache and synthesize `position_ids` from the padding mask.
        if past_key_values is not None:
            # Assuming isinstance(past_key_values, Cache):
            cache_length = past_key_values.get_seq_length()
            past_length = past_key_values.seen_tokens
            max_cache_length = past_key_values.get_max_length()

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            # input)
            if (
                attention_mask is not None
                and attention_mask.shape[1] > input_ids.shape[1]
            ):
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            # Padding positions get a dummy value of 1 (they are masked out anyway).
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        # Reorder every layer's cached key/value tensors along the batch dim
        # to follow beam-search reordering.
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(
                    past_state.index_select(0, beam_idx.to(past_state.device))
                    for past_state in layer_past
                ),
            )
        return reordered_past

    # Setup Symmetric Memory for MoE token shuffle.
    # Supports inference currently.
    def setup_symm_mem(self, dtype: torch.dtype, device: torch.device):
        # Delegate to each MoE feed-forward block; dense-MLP layers need none.
        for layer in self.model.layers.values():
            if not isinstance(layer.mlp, MoE):
                continue
            layer.mlp.setup_symm_mem(dtype, device)
torchtitan/experiments/deepseek_v3/symm_mem_recipes/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from .triton_on_device_all_to_all_v import OnDeviceAllToAllV
8
+
9
+ __all__ = [
10
+ "OnDeviceAllToAllV",
11
+ ]
torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_barrier.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from .triton_utils import get_flat_bid, get_flat_tid
11
+
12
+
13
@triton.jit
def send_signal(addrs, sem: tl.constexpr):
    """Set each 32-bit signal slot in `addrs` from 0 to 1 on the remote peer.

    Spins on a sys-scope CAS(0 -> 1) until the slot is free (i.e. the previous
    signal has been consumed by `wait_signal`). `sem` selects the memory
    ordering of the atomic: "relaxed" or "acq_rel" (release on the send side).
    """
    if sem == "relaxed":
        tl.inline_asm_elementwise(
            """
            {
                .reg .u32   %tmp32_<1>;
                .reg .pred  %p<1>;

                send_signal:
                    atom.global.relaxed.sys.cas.b32 %tmp32_0, [$1], 0, 1;
                    setp.eq.u32 %p0, %tmp32_0, 0;
                    @!%p0 bra send_signal;
            }
            """,
            "=r, l",
            [addrs],
            dtype=tl.int32,
            is_pure=False,
            pack=1,
        )
    elif sem == "acq_rel":
        tl.inline_asm_elementwise(
            """
            {
                .reg .u32   %tmp32_<1>;
                .reg .pred  %p<1>;

                send_signal:
                    atom.global.release.sys.cas.b32 %tmp32_0, [$1], 0, 1;
                    setp.eq.u32 %p0, %tmp32_0, 0;
                    @!%p0 bra send_signal;
            }
            """,
            "=r, l",
            [addrs],
            dtype=tl.int32,
            is_pure=False,
            pack=1,
        )
    else:
        raise RuntimeError(f"Unrecognized sem: {sem}")
55
+
56
+
57
@triton.jit
def wait_signal(addrs, sem: tl.constexpr):
    """Consume each 32-bit signal slot in `addrs`, resetting it from 1 to 0.

    Spins on a sys-scope CAS(1 -> 0) until the peer's `send_signal` has set
    the slot. `sem` selects the memory ordering of the atomic: "relaxed" or
    "acq_rel" (acquire on the wait side). The CAS reset leaves the signal pad
    zero-filled after a successful round, which keeps the barrier
    CUDA-graph friendly.
    """
    if sem == "relaxed":
        tl.inline_asm_elementwise(
            """
            {
                .reg .u32   %tmp32_<1>;
                .reg .pred  %p<1>;

                wait_signal:
                    atom.global.sys.relaxed.cas.b32 %tmp32_0, [$1], 1, 0;
                    setp.eq.u32 %p0, %tmp32_0, 1;
                    @!%p0 bra wait_signal;
            }
            """,
            "=r, l",
            [addrs],
            dtype=tl.int32,
            is_pure=False,
            pack=1,
        )
    elif sem == "acq_rel":
        tl.inline_asm_elementwise(
            """
            {
                .reg .u32   %tmp32_<1>;
                .reg .pred  %p<1>;

                wait_signal:
                    atom.global.sys.acquire.cas.b32 %tmp32_0, [$1], 1, 0;
                    setp.eq.u32 %p0, %tmp32_0, 1;
                    @!%p0 bra wait_signal;
            }
            """,
            "=r, l",
            [addrs],
            dtype=tl.int32,
            is_pure=False,
            pack=1,
        )
    else:
        raise RuntimeError(f"Unrecognized sem: {sem}")
99
+
100
+
101
@triton.jit
def blockwise_barrier(
    signal_pad_ptrs,
    block_id,
    rank: tl.constexpr,
    world_size: tl.constexpr,
    sem: tl.constexpr,
):
    """
    Synchronizes blocks with matching block_id across participating devices.

    Note: the function itself is not a system level barrier/fence. It is a
    building block for expressing different synchronization patterns.

    Pattern 0: Ensures that all writes to symm_mem buffers from previous
    kernels across all devices are visible to the current kernel:

        blockwise_barrier(..., sem="relaxed")
        sync_threads()

    Pattern 1: Ensures that all writes to symm_mem buffers from the current
    block are visible to all remote blocks with matching blockIdx:

        sync_threads()
        blockwise_barrier(..., sem="acq_rel")
        sync_threads()

    Pattern 2: Ensures that symm_mem buffers read by the current kernel are safe
    for writing by subsequent kernels across all devices.

        sync_threads()
        blockwise_barrier(..., sem="relaxed")

    CUDA graph friendliness:

        This barrier operates through atomic operations on a zero-filled signal
        pad, which resets to a zero-filled state after each successful
        synchronization. This design eliminates the need for incrementing a
        flag from host.
    """
    # Default to the caller's flattened block index when no explicit id given.
    if block_id is None:
        block_id = get_flat_bid()
    flat_tid = get_flat_tid()

    # `signal_pad_ptrs` is an array of per-rank signal-pad base addresses
    # (one uint64 pointer per rank in the group).
    remote_ranks = tl.arange(0, world_size)
    signal_pad_ptrs = signal_pad_ptrs.to(tl.pointer_type(tl.uint64))
    remote_signal_pad_addrs = tl.load(signal_pad_ptrs + remote_ranks).to(
        tl.pointer_type(tl.uint32)
    )
    # Slot [block_id, rank] on each remote pad: "rank has arrived at block_id".
    send_addrs = remote_signal_pad_addrs + block_id * world_size + rank

    local_signal_pad_addr = tl.load(signal_pad_ptrs + rank).to(
        tl.pointer_type(tl.uint32)
    )
    # Slots [block_id, 0..world_size) on our own pad: arrivals from all peers.
    wait_addrs = local_signal_pad_addr + block_id * world_size + remote_ranks

    # One thread per remote rank performs the signal/wait pair.
    if flat_tid < world_size:
        send_signal(send_addrs, sem)
        wait_signal(wait_addrs, sem)
torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_on_device_all_to_all_v.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.distributed as dist
9
+ import torch.distributed._symmetric_memory as symm_mem
10
+ import triton
11
+ import triton.language as tl
12
+
13
+ from .triton_barrier import blockwise_barrier
14
+ from .triton_utils import sync_threads
15
+
16
+
17
+ @triton.jit
18
+ def _exchange_row_offsets(
19
+ split_sizes_ptrs,
20
+ rank: tl.constexpr,
21
+ world_size: tl.constexpr,
22
+ BLOCKS_PER_REMOTE_RANK: tl.constexpr,
23
+ ):
24
+ remote_rank = tl.program_id(0) // BLOCKS_PER_REMOTE_RANK
25
+
26
+ # split_sizes_ptr for all ranks
27
+ # All these vector stacks into split_sizes_matrix
28
+ split_sizes_ptrs = split_sizes_ptrs.to(tl.pointer_type(tl.uint64))
29
+
30
+ # split_sizes_matrix[remote_rank, :]
31
+ input_split_sizes_ptr = tl.load(split_sizes_ptrs + remote_rank).to(
32
+ tl.pointer_type(tl.int64)
33
+ )
34
+
35
+ offsets_ = tl.arange(0, world_size)
36
+ input_split_sizes = tl.load(
37
+ input_split_sizes_ptr + offsets_, mask=offsets_ <= rank, other=0
38
+ )
39
+
40
+ num_rows = tl.load(input_split_sizes_ptr + rank)
41
+ input_row_offset = tl.sum(input_split_sizes) - num_rows
42
+
43
+ # split_sizes_matrix[:, rank]
44
+ output_split_sizes_ptrs = (
45
+ tl.load(split_sizes_ptrs + offsets_).to(tl.pointer_type(tl.int64)) + rank
46
+ )
47
+ output_split_sizes = tl.load(
48
+ output_split_sizes_ptrs, mask=offsets_ <= remote_rank, other=0
49
+ )
50
+ output_row_offset = tl.sum(output_split_sizes) - num_rows
51
+
52
+ return input_row_offset, output_row_offset, num_rows
53
+
54
+
55
+ @triton.jit
56
+ def on_device_all_to_all_v_kernel(
57
+ output_ptr,
58
+ output_splits_ptr,
59
+ input_ptrs,
60
+ input_splits_ptr,
61
+ signal_pad_ptrs,
62
+ dim: tl.constexpr, # Separate dim for easier vectorization
63
+ rank: tl.constexpr,
64
+ world_size: tl.constexpr,
65
+ BLOCKS_PER_REMOTE_RANK: tl.constexpr,
66
+ UNROLL_FACTOR: tl.constexpr,
67
+ BLOCK_SIZE: tl.constexpr,
68
+ ):
69
+ blockwise_barrier(signal_pad_ptrs, None, rank, world_size, sem="relaxed")
70
+ sync_threads()
71
+
72
+ remote_rank = tl.program_id(0) // BLOCKS_PER_REMOTE_RANK
73
+ block_offset = tl.program_id(0) % BLOCKS_PER_REMOTE_RANK
74
+
75
+ input_row_offset, output_row_offset, num_rows = _exchange_row_offsets(
76
+ input_splits_ptr, rank, world_size, BLOCKS_PER_REMOTE_RANK
77
+ )
78
+
79
+ output_splits_ptr = output_splits_ptr.to(tl.pointer_type(tl.uint64))
80
+ if block_offset == 0:
81
+ # Update output_splits
82
+ tl.store(output_splits_ptr + remote_rank, num_rows)
83
+
84
+ input_ptr = (
85
+ tl.load(input_ptrs.to(tl.pointer_type(tl.uint64)) + remote_rank).to(
86
+ tl.pointer_type(tl.bfloat16)
87
+ )
88
+ + input_row_offset * dim
89
+ )
90
+ output_ptr = output_ptr + output_row_offset * dim
91
+
92
+ outer_loop_step = BLOCK_SIZE * UNROLL_FACTOR
93
+ outer_loop_iters_per_rank = tl.cdiv(
94
+ tl.cdiv(num_rows * dim, outer_loop_step), BLOCKS_PER_REMOTE_RANK
95
+ )
96
+ numel_per_rank = outer_loop_step * outer_loop_iters_per_rank
97
+ offset = numel_per_rank * block_offset
98
+ end = tl.minimum(numel_per_rank * (block_offset + 1), num_rows * dim)
99
+
100
+ unroll_region_size = (end - offset) // outer_loop_step * outer_loop_step
101
+ for i in tl.range(offset, offset + unroll_region_size, outer_loop_step):
102
+ datas = []
103
+ for j in tl.range(
104
+ i,
105
+ i + outer_loop_step,
106
+ BLOCK_SIZE,
107
+ loop_unroll_factor=UNROLL_FACTOR,
108
+ ):
109
+ offsets = j + tl.arange(0, BLOCK_SIZE)
110
+ data = tl.load(input_ptr + offsets)
111
+ tl.store(output_ptr + offsets, data)
112
+
113
+ offset += unroll_region_size
114
+ while offset < end:
115
+ offsets = offset + tl.arange(0, BLOCK_SIZE)
116
+ mask = offsets < num_rows * dim
117
+ data = tl.load(input_ptr + offsets, mask=mask)
118
+ tl.store(output_ptr + offsets, data, mask=mask)
119
+ offset += BLOCK_SIZE
120
+
121
+ sync_threads()
122
+ blockwise_barrier(signal_pad_ptrs, None, rank, world_size, sem="relaxed")
123
+ return
124
+
125
+
126
def _on_device_all_to_all_v(
    output: torch.Tensor,
    output_splits: torch.Tensor,
    input: torch.Tensor,
    input_splits: torch.Tensor,
    group: dist.ProcessGroup = dist.group.WORLD,
    BLOCKS_PER_REMOTE_RANK=8,
    UNROLL_FACTOR: int = 8,
    BLOCK_SIZE: int = 16384,
):
    """Launch the on-device all-to-all-v Triton kernel.

    Args:
        output: pre-allocated 2-D destination tensor (rows x dim); must be
            large enough to hold all rows received from the group.
        output_splits: 1-D tensor; filled with the number of rows received
            from each rank.
        input: 2-D source tensor; must be allocated from symmetric memory so
            peers can read it directly (rendezvous'd below).
        input_splits: 1-D tensor, rows destined for each rank; must also be a
            symmetric-memory tensor.
        group: process group scoping the exchange.
        BLOCKS_PER_REMOTE_RANK / UNROLL_FACTOR / BLOCK_SIZE: kernel tuning
            knobs (blocks dedicated per peer, inner-loop unrolling, elements
            processed per tile).

    Returns:
        `output`, filled in place.
    """
    assert output.dim() == 2, f"{output.shape}"
    assert input.dim() == 2, f"{input.shape}"
    assert output.shape[1] == input.shape[1]

    dim = output.shape[1]
    # Rendezvous exposes each rank's buffer and signal-pad device pointers to
    # its peers; all ranks must call this collectively.
    input_hdl = symm_mem.rendezvous(input, group=group)
    input_splits_hdl = symm_mem.rendezvous(input_splits, group=group)

    # One group of BLOCKS_PER_REMOTE_RANK program instances per peer rank.
    num_blocks = input_hdl.world_size * BLOCKS_PER_REMOTE_RANK
    kernel = on_device_all_to_all_v_kernel[(num_blocks, 1, 1)](
        output,
        output_splits,
        input_hdl.buffer_ptrs_dev,
        input_splits_hdl.buffer_ptrs_dev,
        input_hdl.signal_pad_ptrs_dev,
        dim=dim,
        rank=input_hdl.rank,
        world_size=input_hdl.world_size,
        BLOCKS_PER_REMOTE_RANK=BLOCKS_PER_REMOTE_RANK,
        UNROLL_FACTOR=UNROLL_FACTOR,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=16,
    )
    # log_triton_kernel(kernel)
    return output
161
+
162
+
163
class OnDeviceAllToAllV(torch.autograd.Function):
    """Autograd-aware all-to-all-v implemented with on-device (symmetric
    memory) kernels rather than host-initiated collectives.

    Class-level buffers are shared across calls; `max_output_len` must be set
    by the user before the first forward.
    """

    # A symmetric memory holding the grad_output during backward
    grad_output_buf = None
    # A symmetric memory for exchanging split sizes during both forward and backward
    splits_buf = None
    # Maximum output length (need to be set before use of OnDeviceAllToAllV)
    max_output_len = None

    @staticmethod
    def forward(
        ctx,
        input: torch.Tensor,
        input_splits: torch.Tensor,
        group: dist.ProcessGroup = dist.group.WORLD,
    ):
        """
        Args:
            input: input tensor with data for all ranks concatenated.
                NOTE(review): `input` is passed straight to the symm-mem
                rendezvous, so it must itself be a symmetric-memory
                allocation — confirm against callers.
            input_splits: input splits of shape (group.world_size,)
            group: process group to scope the collective.

        Returns:
            (output, output_splits). `output` has `max_output_len` rows; only
            `sum(output_splits)` of them carry received data.
        """
        # Initialize input splits buffer (one time only)
        if OnDeviceAllToAllV.splits_buf is None:
            OnDeviceAllToAllV.splits_buf = symm_mem.empty(
                *input_splits.shape,
                dtype=input_splits.dtype,
                device=input_splits.device,
            )

        if OnDeviceAllToAllV.max_output_len is None:
            raise RuntimeError(
                "Please set max output length via `OnDeviceAllToAllV.max_output_len = ...`"
            )

        # Allocate output buffer
        output = input.new_empty(OnDeviceAllToAllV.max_output_len, *input.shape[1:])
        # Allocate output splits tensor
        output_splits = torch.empty_like(input_splits)
        # Copy input splits to the (symmetric) buffer so peers can read them
        OnDeviceAllToAllV.splits_buf.copy_(input_splits)

        # Shuffle input to output
        _on_device_all_to_all_v(
            output, output_splits, input, OnDeviceAllToAllV.splits_buf, group=group
        )

        # Output splits in forward is the input splits in backward
        ctx.save_for_backward(output_splits)
        ctx.group = group
        ctx.input_shape = input.shape
        return output, output_splits

    @staticmethod
    def backward(ctx, grad_output, grad_splits):
        """
        Backward is implemented as a shuffle of the output's gradients to the input.
        Args:
            `grad_output`: output's gradients passed from the downstream.
            `grad_splits`: unused.

        Returns:
            (grad_input, None, None) matching forward's
            (input, input_splits, group) arguments.
        """

        # Initialize grad_output buffer (one time only)
        if OnDeviceAllToAllV.grad_output_buf is None:
            assert (
                OnDeviceAllToAllV.max_output_len is not None
            ), "`max_output_len` not set"
            OnDeviceAllToAllV.grad_output_buf = symm_mem.empty(
                OnDeviceAllToAllV.max_output_len,
                *grad_output.shape[1:],
                dtype=grad_output.dtype,
                device=grad_output.device,
            )

        # TODO: is there a way to tell autograd to feed grad_output directly to
        # our symm_mem buffer?
        # Stage the incoming gradients into symmetric memory so peers can pull.
        OnDeviceAllToAllV.grad_output_buf.narrow(0, 0, grad_output.shape[0]).copy_(
            grad_output
        )

        # Size info: forward's output splits are backward's input splits.
        (grad_output_splits,) = ctx.saved_tensors
        OnDeviceAllToAllV.splits_buf.copy_(grad_output_splits)
        grad_input_splits = torch.empty_like(grad_output_splits)  # unused
        grad_input = grad_output.new_empty(*ctx.input_shape)

        # Shuffle gradients back to the input
        _on_device_all_to_all_v(
            grad_input,
            grad_input_splits,
            OnDeviceAllToAllV.grad_output_buf,
            OnDeviceAllToAllV.splits_buf,
            group=ctx.group,
        )
        return grad_input, None, None
257
+
258
+
259
# Functional alias so call sites can write
# `on_device_all_to_all_v(input, input_splits, group)` instead of
# `OnDeviceAllToAllV.apply(...)`.
on_device_all_to_all_v = OnDeviceAllToAllV.apply
torchtitan/experiments/deepseek_v3/train.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # torchrun --standalone --nproc-per-node 8 run.py
8
+ import torch
9
+ import torch.distributed as dist
10
+ from checkpoint import load_weights_from_hf
11
+ from model import DeepseekForCausalLM
12
+ from model_config import deepseek_config_registry
13
+
14
+ from torch.distributed.device_mesh import DeviceMesh
15
+ from torch.distributed.fsdp import fully_shard
16
+ from torch.distributed.pipelining import PipelineStage, Schedule1F1B
17
+
18
+
19
# Use DeepSeek-V2-Lite as a proxy.
# This id is both the key into `deepseek_config_registry` and the HF repo
# that `load_weights_from_hf` pulls weights from.
model_id = "deepseek-ai/DeepSeek-V2-Lite"
21
+
22
+
23
# Run full model
def run_full_model(
    mesh: DeviceMesh,
):
    """Build, parallelize (PP + EP + FSDP/HSDP) and run a few synthetic
    forward/backward steps of the DeepSeek model.

    Args:
        mesh: 3-D device mesh with dims ("pp", "ep", "fsdp").
    """
    rank = dist.get_rank()
    device_count = torch.cuda.device_count()
    # Map ranks round-robin onto local GPUs.
    device = torch.device("cuda", rank % device_count)

    pp_mesh = mesh["pp"]
    ep_mesh = mesh["ep"]
    pp_rank = pp_mesh.get_local_rank()
    ep_rank = ep_mesh.get_local_rank()
    pp_size = pp_mesh.size()
    ep_size = ep_mesh.size()

    # Get model configs
    model_args = deepseek_config_registry[model_id]
    # [Note]: I am making the model smaller for testing / avoiding OOM. If you
    # have sufficient GPUs for model parallelism, you can remove this line.
    model_args.num_hidden_layers = 16

    # Apply model parallelism: expert parallel degree plus this rank's
    # pipeline-stage assignment.
    model_args.ep_size = ep_size
    model_args.num_stages = pp_size
    model_args.stage_idx = pp_rank
    print(model_args)

    # Instantiate model under the target device/mesh contexts.
    with device, mesh:
        model = DeepseekForCausalLM(model_args)

    # Load weights
    load_weights_from_hf(model, model_id, device)
    model.train()

    # Apply data parallelism
    fsdp_mesh = mesh["fsdp"]
    hsdp_mesh = mesh["ep", "fsdp"]
    # Using `reshard_after_forward=False` to implement Zero-2, i.e. sharding the
    # optimizer (Zero-1) and gradients (Zero-2), but not the model weights.
    # Reason: the MoE is "sparsely activated" compared to the dense model, thus
    # it would be uneconomical to re-gather the weights.
    for layer in model.model.layers.values():
        # Apply FSDP to experts
        if hasattr(layer.mlp, "experts"):
            for expert in layer.mlp.experts.values():
                fully_shard(expert, mesh=fsdp_mesh, reshard_after_forward=False)
        # Apply HSDP to other parts such as attention, layernorm, because they
        # are doing DDP on EP dimension
        fully_shard(layer, mesh=hsdp_mesh, reshard_after_forward=False)

    # Apply HSDP on root model (lm_head, embeddings, etc)
    fully_shard(model, mesh=hsdp_mesh, reshard_after_forward=False)

    # Synthetic setting
    microbatches = pp_size * 2

    # Use Symmetric Memory for MoE token shuffle.
    # TODO: we are rewriting `moe_on_device` function. `setup_symm_mem` is
    # currently supported for forward only. See `generate.py`.
    # model.setup_symm_mem(torch.bfloat16, device)

    # Example inputs (seeded per EP rank so data-parallel replicas differ).
    torch.manual_seed(ep_rank)
    bs = 4
    seqlen = 128
    x = torch.randint(model_args.vocab_size, (microbatches * bs, seqlen), device=device)
    # Random soft labels of shape (batch, seq, vocab).
    label = torch.rand(microbatches * bs, seqlen, model_args.vocab_size, device=device)

    # Create loss function
    # NOTE(review): F.cross_entropy expects the class dimension at position 1;
    # with (batch, seq, vocab) logits/targets this relies on the output layout
    # matching the loss contract — confirm shapes.
    loss_fn = torch.nn.functional.cross_entropy

    # Run forward and backward
    steps = 2
    for _ in range(steps):
        if pp_size > 1:
            # Create pipeline stage
            stage = PipelineStage(
                model,
                pp_rank,
                pp_size,
                device,
                group=pp_mesh.get_group(),
            )

            # Create pipeline schedule
            losses = []
            pp_schedule = Schedule1F1B(stage, microbatches, loss_fn=loss_fn)

            # First stage feeds inputs; last stage supplies targets and
            # collects per-microbatch losses; middle stages just step.
            # (The schedule runs backward internally.)
            if pp_rank == 0:
                y = pp_schedule.step(x)
            elif pp_rank == pp_size - 1:
                y = pp_schedule.step(target=label, losses=losses)
                loss = torch.mean(torch.stack(losses))
            else:
                pp_schedule.step()
        else:
            # Single-stage path: run the whole model and backprop manually.
            y = model(x)
            loss = loss_fn(y, label)
            loss.backward()

        if pp_rank == pp_size - 1:
            print(f"logits: {y.shape}")
            print(f"{loss=}")

        if pp_rank == 0:
            # Spot-check that gradients reached the first stage's weights.
            param = model.get_parameter("model.layers.0.self_attn.q_proj.weight")
            print(f"{torch.linalg.norm(param.grad)=}")

        model.zero_grad()

    print("Backward done")
135
+
136
+
137
if __name__ == "__main__":
    # 3-D mesh: pipeline (pp) x expert (ep) x data (fsdp) parallelism.
    # Requires exactly 8 ranks (2 * 2 * 2); launch with e.g.
    # `torchrun --standalone --nproc-per-node 8 train.py`.
    mesh = dist.init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("pp", "ep", "fsdp"))

    run_full_model(mesh)

    dist.destroy_process_group()
torchtitan/experiments/flux/README.md ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # FLUX model in torchtitan
2
+
3
+ ## Overview
4
+
5
+ ## Usage
6
+ First, download the autoencoder model from HuggingFace with your own access token:
7
+ ```bash
8
+ python torchtitan/experiments/flux/scripts/download_autoencoder.py --repo_id black-forest-labs/FLUX.1-dev --ae_path ae.safetensors --hf_token <your_access_token>
9
+ ```
10
+ This step will download the autoencoder model from HuggingFace and save it to the `torchtitan/experiments/flux/assets/autoencoder/ae.safetensors` file.
11
+
12
+ Run the following command to train the model on a single GPU:
13
+ ```bash
14
+ PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True torchrun --nproc_per_node=1 torchtitan/experiments/flux/train.py --job.config_file torchtitan/experiments/flux/train_configs/debug_model.toml
15
+ ```
16
+
17
+ ## TODO
18
+ - [ ] Support for multiple GPUs is coming soon (FSDP, etc.)
19
+ - [ ] Implement test cases in CI for FLUX model. Adding more unit tests for FLUX model (eg, unit test for preprocessor, etc)
20
+ - [ ] More parallelism support (Tensor Parallelism, Context Parallelism, etc.)
21
+ - [ ] Support for distributed checkpointing and loading
22
+ - [ ] Implement init_weights() function to initialize the model weights
23
+ - [ ] Implement the num_flops_per_token calculation in get_nparams_and_flops() function
torchtitan/experiments/flux/__init__.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved.
8
+
9
+ from torchtitan.components.lr_scheduler import build_lr_schedulers
10
+ from torchtitan.components.optimizer import build_optimizers
11
+ from torchtitan.experiments.flux.dataset.flux_dataset import build_flux_dataloader
12
+ from torchtitan.experiments.flux.loss import build_mse_loss
13
+ from torchtitan.experiments.flux.model.autoencoder import AutoEncoderParams
14
+ from torchtitan.experiments.flux.parallelize_flux import parallelize_flux
15
+ from torchtitan.protocols.train_spec import register_train_spec, TrainSpec
16
+
17
+ from .model.model import FluxModel, FluxModelArgs
18
+
19
__all__ = [
    "FluxModelArgs",
    "FluxModel",
    "flux_configs",
    "parallelize_flux",
]


# Canonical Flux model configurations, keyed by `--model.flavor`.
# All flavors share the same autoencoder configuration.
flux_configs = {
    # Full-size dev model (guidance-distilled: guidance_embed=True).
    "flux-dev": FluxModelArgs(
        in_channels=64,
        out_channels=64,
        vec_in_dim=768,
        context_in_dim=512,
        hidden_size=3072,
        mlp_ratio=4.0,
        num_heads=24,
        depth=19,
        depth_single_blocks=38,
        axes_dim=(16, 56, 56),
        theta=10_000,
        qkv_bias=True,
        guidance_embed=True,
        autoencoder_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=(1, 2, 4, 4),
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    # Schnell model: no guidance embedding.
    # NOTE(review): context_in_dim differs from flux-dev (4096 vs 512); this
    # tracks the T5 encoder width used per flavor — verify when swapping encoders.
    "flux-schnell": FluxModelArgs(
        in_channels=64,
        out_channels=64,
        vec_in_dim=768,
        context_in_dim=4096,
        hidden_size=3072,
        mlp_ratio=4.0,
        num_heads=24,
        depth=19,
        depth_single_blocks=38,
        axes_dim=(16, 56, 56),
        theta=10_000,
        qkv_bias=True,
        guidance_embed=False,
        autoencoder_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=(1, 2, 4, 4),
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    # Tiny configuration for CI / local debugging.
    "flux-debug": FluxModelArgs(
        in_channels=64,
        out_channels=64,
        vec_in_dim=768,
        context_in_dim=512,
        hidden_size=512,
        mlp_ratio=4.0,
        num_heads=4,
        depth=2,
        depth_single_blocks=2,
        axes_dim=(16, 56, 56),
        theta=10_000,
        qkv_bias=True,
        guidance_embed=True,
        autoencoder_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=(1, 2, 4, 4),
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
}


# Register Flux with torchtitan's train-spec registry so `--model.name flux`
# resolves to these components (no pipelining/tokenizer; MSE loss).
register_train_spec(
    TrainSpec(
        name="flux",
        cls=FluxModel,
        config=flux_configs,
        parallelize_fn=parallelize_flux,
        pipelining_fn=None,
        build_optimizers_fn=build_optimizers,
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_flux_dataloader,
        build_tokenizer_fn=None,
        build_loss_fn=build_mse_loss,
    )
)
torchtitan/experiments/flux/flux_argparser.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import argparse
8
+
9
+ import torch
10
+
11
+
12
def _torch_dtype(value: str) -> torch.dtype:
    """argparse converter mapping 'bfloat16' or 'torch.bfloat16' to a torch.dtype."""
    name = value.removeprefix("torch.")
    dtype = getattr(torch, name, None)
    if not isinstance(dtype, torch.dtype):
        raise argparse.ArgumentTypeError(f"{value!r} is not a valid torch dtype")
    return dtype


def extend_parser(parser: argparse.ArgumentParser) -> None:
    """Register Flux-specific command-line arguments on *parser*.

    Adds `--training.guidance` and the `--encoder.*` options consumed by the
    Flux trainer and dataloader.
    """
    parser.add_argument(
        "--training.guidance",
        type=float,
        default=3.5,
        help="guidance value used for guidance distillation",
    )
    parser.add_argument(
        "--encoder.t5_encoder",
        type=str,
        default="google/t5-v1_1-small",
        help="T5 encoder to use, HuggingFace model name.",
    )
    parser.add_argument(
        "--encoder.clip_encoder",
        type=str,
        default="openai/clip-vit-large-patch14",
        help="Clip encoder to use, HuggingFace model name.",
    )
    parser.add_argument(
        "--encoder.encoder_dtype",
        # BUGFIX: was `type=torch.dtype`, which is not callable on a string —
        # argparse would raise "cannot create 'torch.dtype' instances" whenever
        # the flag was actually passed. Use a proper string->dtype converter.
        type=_torch_dtype,
        default=torch.bfloat16,
        help="Which dtype to load for autoencoder. ",
    )
    parser.add_argument(
        "--encoder.max_t5_encoding_len",
        type=int,
        default=512,
        help="Maximum length of the T5 encoding.",
    )
torchtitan/experiments/flux/loss.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Callable, TypeAlias
8
+
9
+ import torch
10
+
11
+ from torchtitan.config_manager import JobConfig
12
+ from torchtitan.tools.logging import logger
13
+
14
+ LossFunction: TypeAlias = Callable[..., torch.Tensor]
15
+
16
+
17
def mse_loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Mean-squared-error loss in fp32, with gradients blocked through labels."""
    prediction_fp32 = pred.float()
    target_fp32 = labels.detach().float()
    return torch.nn.functional.mse_loss(prediction_fp32, target_fp32)
20
+
21
+
22
def build_mse_loss(job_config: JobConfig):
    """Return the MSE loss function, optionally wrapped in torch.compile.

    Args:
        job_config: job configuration; only `training.compile` is consulted.

    Returns:
        A `LossFunction` callable (``mse_loss`` or its compiled wrapper).
    """
    loss_fn = mse_loss
    if job_config.training.compile:
        logger.info("Compiling the loss function with torch.compile")
        loss_fn = torch.compile(loss_fn)
    return loss_fn
torchtitan/experiments/flux/parallelize_flux.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # This file applies the PT-D parallelisms (except pipeline parallelism) and various
8
+ # training techniques (e.g. activation checkpointing and compile) to the Llama model.
9
+
10
+
11
+ import torch.nn as nn
12
+
13
+ from torch.distributed.device_mesh import DeviceMesh
14
+
15
+ from torchtitan.config_manager import JobConfig
16
+ from torchtitan.distributed import ParallelDims
17
+
18
+
19
def parallelize_flux(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
    """Apply parallelisms and training techniques to the Flux model.

    Currently a placeholder: no parallelism is applied and the model is
    returned unchanged. `world_mesh`, `parallel_dims` and `job_config` are
    accepted for interface compatibility with the train-spec registry.
    """
    # TODO: Add model parallel strategy here
    return model
torchtitan/experiments/flux/requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ transformers
2
+ einops
torchtitan/experiments/flux/scripts/download_autoencoder.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Optional
8
+
9
+ from requests.exceptions import HTTPError
10
+
11
+
12
def hf_download(
    repo_id: str, file_path: str, local_dir: str, hf_token: Optional[str] = None
) -> None:
    """Download a single file from a HuggingFace Hub repository.

    Args:
        repo_id: repository to download from (e.g. "black-forest-labs/FLUX.1-dev").
        file_path: file path relative to the repository root.
        local_dir: local directory the file is saved under.
        hf_token: optional HF access token for gated/private repos.

    On a 401 response the error is reported (best-effort) and swallowed so the
    caller gets an actionable message; any other HTTP error propagates.
    """
    from huggingface_hub import hf_hub_download

    try:
        hf_hub_download(
            repo_id=repo_id,
            filename=file_path,
            local_dir=local_dir,
            local_dir_use_symlinks=False,
            token=hf_token,
        )
    except HTTPError as e:
        # `response` can be None on some transport failures — guard before use.
        if e.response is not None and e.response.status_code == 401:
            print(
                "You need to pass a valid `--hf_token=...` to download private checkpoints."
            )
        else:
            # Bare `raise` (not `raise e`) preserves the original traceback.
            raise
32
+
33
+
34
if __name__ == "__main__":
    import argparse

    # BUGFIX: the description said "Download tokenizer", but this script
    # downloads the Flux autoencoder.
    parser = argparse.ArgumentParser(description="Download autoencoder from HuggingFace.")
    parser.add_argument(
        "--repo_id",
        type=str,
        default="black-forest-labs/FLUX.1-dev",
        help="Repository ID to download from. default to Flux-dev model",
    )
    parser.add_argument(
        "--ae_path",
        type=str,
        default="ae.safetensors",
        help="the autoencoder path relative to repo_id",
    )
    parser.add_argument(
        "--hf_token", type=str, default=None, help="HuggingFace API token"
    )
    parser.add_argument(
        "--local_dir",
        type=str,
        default="torchtitan/experiments/flux/assets/autoencoder/",
        help="local directory to save the autoencoder",
    )

    args = parser.parse_args()
    hf_download(args.repo_id, args.ae_path, args.local_dir, args.hf_token)
torchtitan/experiments/flux/tests/test_flux_dataloader.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import sys
8
+
9
+ from torchtitan.config_manager import JobConfig
10
+ from torchtitan.experiments.flux.dataset.flux_dataset import build_flux_dataloader
11
+ from torchtitan.tools.profiling import (
12
+ maybe_enable_memory_snapshot,
13
+ maybe_enable_profiling,
14
+ )
15
+
16
+
17
class TestFluxDataLoader:
    """Smoke test for the Flux dataloader.

    NOTE(review): requires network access (cc12m dataset, HF tokenizers); it
    is an integration test rather than a unit test.
    """

    def test_flux_dataloader(self):
        dataset_name = "cc12m"
        batch_size = 32
        world_size = 4
        rank = 0

        num_steps = 10

        # Register Flux-specific CLI flags before parsing the job config.
        path = "torchtitan.experiments.flux.flux_argparser"
        sys.argv.append(f"--experimental.custom_args_module={path}")
        config = JobConfig()
        config.maybe_add_custom_args()
        config.parse_args(
            [
                # Profiling options
                # "--profiling.enable_profiling",
                # "--profiling.profile_freq",
                # "5",
                # "--profiling.enable_memory_snapshot",
                # "--profiling.save_memory_snapshot_folder",
                # "memory_snapshot_flux",
                "--training.dataset",
                dataset_name,
                "--training.batch_size",
                str(batch_size),
                "--encoder.t5_encoder",
                "google/t5-v1_1-small",
                "--encoder.clip_encoder",
                "openai/clip-vit-large-patch14",
                "--encoder.max_t5_encoding_len",
                "512",
            ]
        )

        # Optionally profile the iteration loop (both managers are no-ops
        # unless enabled via the commented flags above).
        with maybe_enable_profiling(
            config, global_step=0
        ) as torch_profiler, maybe_enable_memory_snapshot(
            config, global_step=0
        ) as memory_profiler:
            dl = self._build_dataloader(
                config,
                world_size,
                rank,
            )
            dl = iter(dl)

            for i in range(0, num_steps):
                input_data, labels = next(dl)
                print(f"Step {i} image size: {labels.shape}")
                if torch_profiler:
                    torch_profiler.step()
                if memory_profiler:
                    memory_profiler.step()

                print(len(input_data["clip_tokens"]))
                for k, v in input_data.items():
                    print(f"Step {i} {k} value: {type(v), v.shape}")

                assert len(input_data) == 2  # (clip_encodings, t5_encodings)
                assert labels.shape == (batch_size, 3, 256, 256)
                # assert input_data["clip_tokens"].shape[0] == batch_size
                # assert input_data["t5_tokens"].shape == (batch_size, 512, 512)

            if torch_profiler:
                torch_profiler.step()
            if memory_profiler:
                memory_profiler.step(exit_ctx=True)

    def test_preprocess(self):
        # TODO
        pass

    def _build_dataloader(
        self,
        job_config,
        world_size,
        rank,
    ):
        # Thin wrapper mapping the test's parallel coordinates onto the real
        # dataloader builder (no tokenizer, finite iteration).
        return build_flux_dataloader(
            dp_world_size=world_size,
            dp_rank=rank,
            job_config=job_config,
            tokenizer=None,
            infinite=False,
        )
torchtitan/experiments/flux/train.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+ from typing import Optional
9
+
10
+ import torch
11
+
12
+ from torchtitan.config_manager import JobConfig
13
+ from torchtitan.distributed import utils as dist_utils
14
+ from torchtitan.experiments.flux.model.autoencoder import load_ae
15
+ from torchtitan.experiments.flux.model.hf_embedder import FluxEmbedder
16
+ from torchtitan.experiments.flux.model.model import FluxModel
17
+ from torchtitan.experiments.flux.utils import (
18
+ create_position_encoding_for_latents,
19
+ pack_latents,
20
+ preprocess_flux_data,
21
+ unpack_latents,
22
+ )
23
+ from torchtitan.tools.logging import init_logger, logger
24
+ from torchtitan.train import Trainer
25
+
26
+
27
class FluxTrainer(Trainer):
    """Trainer specialization for the Flux flow-matching text-to-image model.

    On top of the base `Trainer`, it owns three frozen preprocessing modules
    (image autoencoder, CLIP and T5 text encoders) and replaces `train_step`
    with the rectified-flow objective: predict the velocity `noise - latent`
    at a randomly sampled interpolation timestep.
    """

    def __init__(self, job_config: JobConfig):
        super().__init__(job_config)

        self.preprocess_fn = preprocess_flux_data
        # self.dtype = job_config.encoder.dtype
        # TODO: plumb the dtype through job config instead of hard-coding bf16.
        self._dtype = torch.bfloat16
        self._seed = job_config.training.seed
        self._guidance = job_config.training.guidance

        # load components
        model_config = self.train_spec.config[job_config.model.flavor]
        # Autoencoder is loaded on CPU; preprocessing moves tensors as needed.
        self.autoencoder = load_ae(
            job_config.encoder.auto_encoder_path,
            model_config.autoencoder_params,
            device="cpu",
            dtype=self._dtype,
        )
        self.clip_encoder = FluxEmbedder(version=job_config.encoder.clip_encoder).to(
            dtype=self._dtype
        )
        self.t5_encoder = FluxEmbedder(version=job_config.encoder.t5_encoder).to(
            dtype=self._dtype
        )

    def _predict_noise(
        self,
        model: FluxModel,
        latents: torch.Tensor,
        clip_encodings: torch.Tensor,
        t5_encodings: torch.Tensor,
        timesteps: torch.Tensor,
        guidance: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Use Flux's flow-matching model to predict the noise in image latents.
        Args:
            model (FluxFlowModel): The Flux flow model.
            latents (Tensor): Image encodings from the Flux autoencoder.
                Shape: [bsz, 16, latent height, latent width]
            clip_encodings (Tensor): CLIP text encodings.
                Shape: [bsz, 768]
            t5_encodings (Tensor): T5 text encodings.
                Shape: [bsz, sequence length, 256 or 512]
            timesteps (Tensor): The amount of noise (0 to 1).
                Shape: [bsz]
            guidance (Optional[Tensor]): The guidance value (1.5 to 4) if guidance-enabled model.
                Shape: [bsz]
                Default: None
        Returns:
            Tensor: The noise prediction.
            Shape: [bsz, 16, latent height, latent width]
        """
        bsz, _, latent_height, latent_width = latents.shape

        POSITION_DIM = 3  # constant for Flux flow model
        # Position-encoding / patching prep carries no gradient.
        with torch.no_grad():
            # Create positional encodings
            latent_pos_enc = create_position_encoding_for_latents(
                bsz, latent_height, latent_width, POSITION_DIM
            )
            text_pos_enc = torch.zeros(bsz, t5_encodings.shape[1], POSITION_DIM)

            # Convert latent into a sequence of patches
            latents = pack_latents(latents)

        # Predict noise (gradients flow through the model call).
        latent_noise_pred = model(
            img=latents,
            img_ids=latent_pos_enc.to(latents),
            txt=t5_encodings.to(latents),
            txt_ids=text_pos_enc.to(latents),
            y=clip_encodings.to(latents),
            timesteps=timesteps.to(latents),
            guidance=guidance.to(latents) if guidance is not None else None,
        )

        # Convert sequence of patches to latent shape
        latent_noise_pred = unpack_latents(
            latent_noise_pred, latent_height, latent_width
        )

        return latent_noise_pred

    def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor):
        """Run one rectified-flow optimization step on a batch.

        Args:
            input_dict: raw batch (text tokens etc.); the target image is
                injected under "image" and replaced by encoder outputs.
            labels: raw images; encoded to latents by the autoencoder below.
        """
        # generate t5 and clip encodings (and image latents) on device
        input_dict["image"] = labels
        input_dict = self.preprocess_fn(
            device=self.device,
            dtype=self._dtype,
            autoencoder=self.autoencoder,
            clip_encoder=self.clip_encoder,
            t5_encoder=self.t5_encoder,
            batch=input_dict,
            offload=True,
        )
        # After preprocessing, the regression target lives in latent space.
        labels = input_dict["img_encodings"]

        self.optimizers.zero_grad()

        # Keep these variables local to shorten the code as these are
        # the major variables that are used in the training loop.
        model_parts = self.model_parts
        world_mesh = self.world_mesh
        parallel_dims = self.parallel_dims

        clip_encodings = input_dict["clip_encodings"]
        t5_encodings = input_dict["t5_encodings"]

        bsz = labels.shape[0]

        with torch.no_grad():
            # Rectified-flow sample: interpolate latents and noise at a random
            # timestep; the regression target is the velocity (noise - data).
            noise = torch.randn_like(labels)
            timesteps = torch.rand((bsz,)).to(labels)
            sigmas = timesteps.view(-1, 1, 1, 1)
            noisy_latents = (1 - sigmas) * labels + sigmas * noise
            guidance = torch.full((bsz,), self._guidance).to(labels)

        target = noise - labels

        assert len(model_parts) == 1
        # TODO(jianiw): model_parts will be wrapped by FSDP, which will calculate
        model_parts[0] = model_parts[0].to(dtype=self._dtype)

        pred = self._predict_noise(
            model_parts[0],
            noisy_latents,
            clip_encodings,
            t5_encodings,
            timesteps,
            guidance,
        )
        loss = self.loss_fn(pred, target)
        # Free the large intermediates before backward to reduce peak memory.
        del (pred, noise, target)
        loss.backward()

        dist_utils.clip_grad_norm_(
            [p for m in model_parts for p in m.parameters()],
            self.job_config.training.max_norm,
            foreach=True,
            pp_mesh=self.world_mesh["pp"] if parallel_dims.pp_enabled else None,
        )
        self.checkpointer.maybe_wait_for_staging()
        self.optimizers.step()
        self.lr_schedulers.step()

        # log metrics
        if not self.metrics_processor.should_log(self.step):
            return

        if (
            parallel_dims.dp_replicate_enabled
            or parallel_dims.dp_shard_enabled
            or parallel_dims.cp_enabled
        ):
            # Reduce the loss across data/context-parallel ranks for logging.
            loss = loss.detach()
            global_avg_loss, global_max_loss = (
                dist_utils.dist_mean(loss, world_mesh["dp_cp"]),
                dist_utils.dist_max(loss, world_mesh["dp_cp"]),
            )
        else:
            global_avg_loss = global_max_loss = loss.item()

        self.metrics_processor.log(self.step, global_avg_loss, global_max_loss)
196
+
197
+
198
if __name__ == "__main__":
    init_logger()
    config = JobConfig()
    config.maybe_add_custom_args()
    config.parse_args()
    trainer: Optional[FluxTrainer] = None

    try:
        trainer = FluxTrainer(config)
        if config.checkpoint.create_seed_checkpoint:
            # BUGFIX: was `assert int(os.environ["WORLD_SIZE"])`, which only
            # checked that the value was non-zero. The intent (per the
            # message) is to require exactly one device.
            assert (
                int(os.environ["WORLD_SIZE"]) == 1
            ), "Must create seed checkpoint using a single device, to disable sharding."
            assert (
                config.checkpoint.enable_checkpoint
            ), "Must enable checkpointing when creating a seed checkpoint."
            trainer.checkpointer.save(curr_step=0, force=True)
            logger.info("Created seed checkpoint")
        else:
            trainer.train()
    finally:
        # Always release trainer resources, then tear down the process group.
        if trainer:
            trainer.close()

    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()
        logger.info("Process group destroyed.")
torchtitan/experiments/flux/train_configs/debug_model.toml ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ [job]
3
+ dump_folder = "./outputs"
4
+ description = "Flux debug model"
5
+ print_args = false
6
+ use_for_integration_test = true
7
+
8
+ [profiling]
9
+ enable_profiling = false
10
+ save_traces_folder = "profile_trace"
11
+ profile_freq = 10
12
+ enable_memory_snapshot = false
13
+ save_memory_snapshot_folder = "memory_snapshot"
14
+
15
+ [metrics]
16
+ log_freq = 1
17
+ disable_color_printing = false
18
+ enable_tensorboard = false
19
+ save_tb_folder = "tb"
20
+ enable_wandb = false
21
+
22
+ [model]
23
+ name = "flux"
24
+ flavor = "flux-debug"
25
+ norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm
26
+ # test tokenizer.model, for debug purpose only
27
+ # tokenizer_path = "./tests/assets/test_tiktoken.model"
28
+ # converters = "float8"
29
+
30
+
31
+ [optimizer]
32
+ name = "AdamW"
33
+ lr = 8e-4
34
+ eps = 1e-8
35
+
36
+ [lr_scheduler]
37
+ warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps
38
+ decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps
39
+ decay_type = "linear"
40
+ lr_min = 0.0
41
+
42
+ [training]
43
+ batch_size = 32
44
+ seq_len = 512
45
+ max_norm = 1.0 # grad norm clipping
46
+ steps = 10
47
+ compile = false
48
+ dataset = "cc12m"
49
+ guidance = 3.5
50
+ seed = 0
51
+
52
+ [encoder]
53
+ t5_encoder="google/t5-v1_1-small"
54
+ clip_encoder="openai/clip-vit-large-patch14"
55
+ max_t5_encoding_len=512
56
+ auto_encoder_path="torchtitan/experiments/flux/assets/autoencoder/ae.safetensors" # Autoencoder to use for image
57
+
58
+ [parallelism]
59
+ data_parallel_replicate_degree = 1
60
+ data_parallel_shard_degree = 1
61
+ fsdp_reshard_after_forward = "default" # default / never / always
62
+ tensor_parallel_degree = 1
63
+ enable_async_tensor_parallel = false
64
+ pipeline_parallel_degree = 1
65
+ context_parallel_degree = 1
66
+
67
+ [experimental]
68
+ custom_args_module = "torchtitan.experiments.flux.flux_argparser"
torchtitan/experiments/flux/utils.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Optional
8
+
9
+ import torch
10
+
11
+ from torch import Tensor
12
+
13
+ from torchtitan.experiments.flux.model.autoencoder import AutoEncoder
14
+ from torchtitan.experiments.flux.model.hf_embedder import FluxEmbedder
15
+
16
+
17
+ def preprocess_flux_data(
18
+ # arguments from the recipe
19
+ device: torch.device,
20
+ dtype: torch.dtype,
21
+ *,
22
+ # arguments from the config
23
+ autoencoder: Optional[AutoEncoder],
24
+ clip_encoder: FluxEmbedder,
25
+ t5_encoder: FluxEmbedder,
26
+ batch: dict[str, Tensor],
27
+ offload: bool = False,
28
+ ) -> dict[str, Tensor]:
29
+ """
30
+ Take a batch of inputs and encoder as input and return a batch of preprocessed data.
31
+
32
+ Args:
33
+ device (torch.device): device to do preprocessing on
34
+ dtype (torch.dtype): data type to do preprocessing in
35
+ autoencoer(AutoEncoder): autoencoder to use for preprocessing
36
+ clip_encoder
37
+ t5_encoder
38
+ batch (dict[str, Tensor]): batch of data to preprocess
39
+
40
+ Returns:
41
+ dict[str, Tensor]: batch of preprocessed data
42
+ """
43
+
44
+ # The input of encoder should be torch.int type
45
+ if offload:
46
+ clip_encoder.to(device)
47
+ t5_encoder.to(device)
48
+ if autoencoder is not None:
49
+ autoencoder.to(device)
50
+
51
+ clip_tokens = batch["clip_tokens"].squeeze().to(device=device, dtype=torch.int)
52
+ t5_tokens = batch["t5_tokens"].squeeze().to(device=device, dtype=torch.int)
53
+
54
+ clip_text_encodings = clip_encoder(clip_tokens)
55
+ t5_text_encodings = t5_encoder(t5_tokens)
56
+
57
+ if autoencoder is not None:
58
+ images = batch["image"].to(device=device, dtype=dtype)
59
+ img_encodings = autoencoder.encode(images)
60
+ batch["img_encodings"] = img_encodings.to(device=device, dtype=dtype)
61
+
62
+ batch["clip_encodings"] = clip_text_encodings.to(dtype)
63
+ batch["t5_encodings"] = t5_text_encodings.to(dtype)
64
+
65
+ # offload encoders to cpu after preprocessing
66
+ if offload:
67
+ clip_encoder.to("cpu")
68
+ t5_encoder.to("cpu")
69
+ if autoencoder is not None:
70
+ autoencoder.to("cpu")
71
+
72
+ return batch
73
+
74
+
75
+ def generate_noise_latent(
76
+ bsz: int,
77
+ height: int,
78
+ width: int,
79
+ device: str | torch.device,
80
+ dtype: torch.dtype,
81
+ seed: int,
82
+ ) -> Tensor:
83
+ """Generate noise latents for the Flux flow model.
84
+
85
+ Args:
86
+ bsz (int): batch_size.
87
+ height (int): The height of the image.
88
+ width (int): The width of the image.
89
+ device (str | torch.device): The device to use.
90
+ dtype (torch.dtype): The dtype to use.
91
+ seed (int): The seed to use for randomize.
92
+
93
+ Returns:
94
+ Tensor: The noise latents.
95
+ Shape: [num_samples, LATENT_CHANNELS, height // IMG_LATENT_SIZE_RATIO, width // IMG_LATENT_SIZE_RATIO]
96
+
97
+ """
98
+ LATENT_CHANNELS, IMAGE_LATENT_SIZE_RATIO = 16, 8
99
+ return torch.randn(
100
+ bsz,
101
+ LATENT_CHANNELS,
102
+ height // IMAGE_LATENT_SIZE_RATIO,
103
+ width // IMAGE_LATENT_SIZE_RATIO,
104
+ dtype=dtype,
105
+ generator=torch.Generator().manual_seed(seed),
106
+ ).to(device)
107
+
108
+
109
+ def create_position_encoding_for_latents(
110
+ bsz: int, latent_height: int, latent_width: int, position_dim: int = 3
111
+ ) -> Tensor:
112
+ """
113
+ Create the packed latents' position encodings for the Flux flow model.
114
+
115
+ Args:
116
+ bsz (int): The batch size.
117
+ latent_height (int): The height of the latent.
118
+ latent_width (int): The width of the latent.
119
+
120
+ Returns:
121
+ Tensor: The position encodings.
122
+ Shape: [bsz, (latent_height // PATCH_HEIGHT) * (latent_width // PATCH_WIDTH), POSITION_DIM)
123
+ """
124
+ PATCH_HEIGHT, PATCH_WIDTH = 2, 2
125
+
126
+ height = latent_height // PATCH_HEIGHT
127
+ width = latent_width // PATCH_WIDTH
128
+
129
+ position_encoding = torch.zeros(height, width, position_dim)
130
+
131
+ row_indices = torch.arange(height)
132
+ position_encoding[:, :, 1] = row_indices.unsqueeze(1)
133
+
134
+ col_indices = torch.arange(width)
135
+ position_encoding[:, :, 2] = col_indices.unsqueeze(0)
136
+
137
+ # Flatten and repeat for the full batch
138
+ # [height, width, 3] -> [bsz, height * width, 3]
139
+ position_encoding = position_encoding.view(1, height * width, position_dim)
140
+ position_encoding = position_encoding.repeat(bsz, 1, 1)
141
+
142
+ return position_encoding
143
+
144
+
145
+ def pack_latents(x: Tensor) -> Tensor:
146
+ """
147
+ Rearrange latents from an image-like format into a sequence of patches.
148
+ Equivalent to `einops.rearrange("b c (h ph) (w pw) -> b (h w) (c ph pw)")`.
149
+
150
+ Args:
151
+ x (Tensor): The unpacked latents.
152
+ Shape: [bsz, ch, latent height, latent width]
153
+
154
+ Returns:
155
+ Tensor: The packed latents.
156
+ Shape: (bsz, (latent_height // ph) * (latent_width // pw), ch * ph * pw)
157
+ """
158
+ PATCH_HEIGHT, PATCH_WIDTH = 2, 2
159
+
160
+ b, c, latent_height, latent_width = x.shape
161
+ h = latent_height // PATCH_HEIGHT
162
+ w = latent_width // PATCH_WIDTH
163
+
164
+ # [b, c, h*ph, w*ph] -> [b, c, h, w, ph, pw]
165
+ x = x.unfold(2, PATCH_HEIGHT, PATCH_HEIGHT).unfold(3, PATCH_WIDTH, PATCH_WIDTH)
166
+
167
+ # [b, c, h, w, ph, PW] -> [b, h, w, c, ph, PW]
168
+ x = x.permute(0, 2, 3, 1, 4, 5)
169
+
170
+ # [b, h, w, c, ph, PW] -> [b, h*w, c*ph*PW]
171
+ return x.reshape(b, h * w, c * PATCH_HEIGHT * PATCH_WIDTH)
172
+
173
+
174
+ def unpack_latents(x: Tensor, latent_height: int, latent_width: int) -> Tensor:
175
+ """
176
+ Rearrange latents from a sequence of patches into an image-like format.
177
+ Equivalent to `einops.rearrange("b (h w) (c ph pw) -> b c (h ph) (w pw)")`.
178
+
179
+ Args:
180
+ x (Tensor): The packed latents.
181
+ Shape: (bsz, (latent_height // ph) * (latent_width // pw), ch * ph * pw)
182
+ latent_height (int): The height of the unpacked latents.
183
+ latent_width (int): The width of the unpacked latents.
184
+
185
+ Returns:
186
+ Tensor: The unpacked latents.
187
+ Shape: [bsz, ch, latent height, latent width]
188
+ """
189
+ PATCH_HEIGHT, PATCH_WIDTH = 2, 2
190
+
191
+ b, _, c_ph_pw = x.shape
192
+ h = latent_height // PATCH_HEIGHT
193
+ w = latent_width // PATCH_WIDTH
194
+ c = c_ph_pw // (PATCH_HEIGHT * PATCH_WIDTH)
195
+
196
+ # [b, h*w, c*ph*pw] -> [b, h, w, c, ph, pw]
197
+ x = x.reshape(b, h, w, c, PATCH_HEIGHT, PATCH_WIDTH)
198
+
199
+ # [b, h, w, c, ph, pw] -> [b, c, h, ph, w, pw]
200
+ x = x.permute(0, 3, 1, 4, 2, 5)
201
+
202
+ # [b, c, h, ph, w, pw] -> [b, c, h*ph, w*pw]
203
+ return x.reshape(b, c, h * PATCH_HEIGHT, w * PATCH_WIDTH)
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from .mg_grouped_gemm import grouped_gemm_forward
8
+ from .tma_autotuning import ALIGN_SIZE_M
9
+
10
+ __all__ = [
11
+ "grouped_gemm_forward",
12
+ "ALIGN_SIZE_M",
13
+ ]
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/fast_debug_ao.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+ import logging
9
+
10
+ import numpy as np
11
+ import torch
12
+
13
+ from reference_utils import (
14
+ analyze_tensor_differences,
15
+ compute_reference_backward,
16
+ compute_reference_forward,
17
+ )
18
+
19
+ # Configure logging
20
+ logging.basicConfig(
21
+ level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
22
+ )
23
+
24
+ # Import grouped GEMM implementations
25
+ try:
26
+ from mg_grouped_gemm import grouped_gemm_backward, grouped_gemm_forward
27
+
28
+ except ImportError:
29
+ logging.error(
30
+ "Error importing grouped GEMM modules. Make sure the implementation files are in the correct path."
31
+ )
32
+ raise
33
+
34
+
35
+ def test_forward_pass():
36
+ """
37
+ A simple test for the M*G grouped GEMM forward pass with detailed error handling.
38
+
39
+ In M*G grouping:
40
+ - M dimension is partitioned into G groups (M_total = sum(M_sizes))
41
+ - N dimension is the same for all groups
42
+ """
43
+ try:
44
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
45
+
46
+ # Test parameters for DeepSeek-like models
47
+ G = 1 # Number of groups
48
+ M_sizes = [
49
+ 2048,
50
+ ] # 2048, 2048, 2048] # Group sizes (will be adjusted)
51
+ M_total = sum(M_sizes) # Total M dimension
52
+ N = 4096 # Output dimension (same for all groups)
53
+ K = 7168 # Hidden dimension
54
+
55
+ # Create group sizes tensor
56
+ m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)
57
+
58
+ # Create input and weight tensors - using float16 for higher precision
59
+ x = torch.randn(M_total, K, dtype=torch.float16, device=device)
60
+ w = torch.randn(N, K, dtype=torch.float16, device=device)
61
+
62
+ # Log the setup
63
+ logging.info(f"Test setup - G: {G}, M_total: {M_total}, N: {N}, K: {K}")
64
+ logging.info(f"Group sizes: {m_sizes}")
65
+ logging.info(f"Input x shape: {x.shape}")
66
+ logging.info(f"Weight w shape: {w.shape}")
67
+
68
+ # Run forward pass
69
+ logging.info("Running forward pass with grouped GEMM")
70
+ result = grouped_gemm_forward(x, w, m_sizes)
71
+ logging.info(f"Forward result shape: {result.shape}")
72
+
73
+ # Compute reference result
74
+ logging.info("Computing reference result with PyTorch")
75
+ reference_result = compute_reference_forward(x, w, m_sizes)
76
+
77
+ # Compare results
78
+ logging.info("Comparing with PyTorch reference")
79
+ forward_close = analyze_tensor_differences(
80
+ result, reference_result, "Forward output"
81
+ )
82
+
83
+ return forward_close
84
+
85
+ except Exception as e:
86
+ logging.error(f"Test failed with error: {e}")
87
+ import traceback
88
+
89
+ logging.error(traceback.format_exc())
90
+ return False
91
+
92
+
93
+ def test_backward_pass():
94
+ """
95
+ A simple test for the M*G grouped GEMM backward pass with detailed error handling.
96
+
97
+ In M*G grouping:
98
+ - M dimension is partitioned into G groups (M_total = sum(M_sizes))
99
+ - N dimension is the same for all groups
100
+ """
101
+ try:
102
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
103
+
104
+ # Test parameters for DeepSeek-like models
105
+ G = 4 # Number of groups
106
+ M_sizes = [2048, 2048, 2048, 2048] # Group sizes (will be adjusted)
107
+ M_total = sum(M_sizes) # Total M dimension
108
+ N = 4096 # Output dimension (same for all groups)
109
+ K = 7168 # Hidden dimension
110
+
111
+ # Create group sizes tensor
112
+ m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)
113
+
114
+ # Create input and weight tensors - using float16 for higher precision
115
+ x = torch.randn(
116
+ M_total, K, dtype=torch.float16, device=device, requires_grad=True
117
+ )
118
+ w = torch.randn(N, K, dtype=torch.float16, device=device, requires_grad=True)
119
+
120
+ # Log the setup
121
+ logging.info(f"Test setup - G: {G}, M_total: {M_total}, N: {N}, K: {K}")
122
+ logging.info(f"Group sizes: {m_sizes}")
123
+ logging.info(f"Input x shape: {x.shape}")
124
+ logging.info(f"Weight w shape: {w.shape}")
125
+
126
+ # Step 1: Run forward pass
127
+ logging.info("Running forward pass")
128
+ result = grouped_gemm_forward(x, w, m_sizes)
129
+ logging.info(f"Forward result shape: {result.shape}")
130
+
131
+ # Create a gradient for backpropagation
132
+ grad_output = torch.randn_like(result)
133
+ logging.info(f"Created gradient with shape: {grad_output.shape}")
134
+
135
+ # Step 2: Run backward pass directly
136
+ logging.info("Running backward pass directly")
137
+ grad_x, grad_w = grouped_gemm_backward(grad_output, x, w, m_sizes)
138
+
139
+ # Verify gradient shapes
140
+ logging.info(
141
+ f"Gradient shapes - grad_x: {grad_x.shape}, grad_w: {grad_w.shape}"
142
+ )
143
+
144
+ # Step 3: Verify gradient computation using PyTorch's autograd
145
+ logging.info("Running PyTorch reference implementation")
146
+
147
+ # Compute reference gradients
148
+ x_ref_grad, w_ref_grad = compute_reference_backward(x, w, m_sizes, grad_output)
149
+
150
+ # Compare gradients
151
+ logging.info("Comparing gradients with PyTorch reference")
152
+ grad_x_close = analyze_tensor_differences(grad_x, x_ref_grad, "grad_x")
153
+ grad_w_close = analyze_tensor_differences(grad_w, w_ref_grad, "grad_w")
154
+
155
+ # Log overall result
156
+ if grad_x_close and grad_w_close:
157
+ logging.info("✓ SUCCESS: Gradients match the PyTorch reference")
158
+ else:
159
+ logging.error("✗ FAILURE: Gradient mismatch detected")
160
+
161
+ return grad_x_close and grad_w_close
162
+
163
+ except Exception as e:
164
+ logging.error(f"Test failed with error: {e}")
165
+ import traceback
166
+
167
+ logging.error(traceback.format_exc())
168
+ return False
169
+
170
+
171
+ def test_multiple_deepseek_configs():
172
+ """
173
+ Test multiple DeepSeek model configurations with both forward and backward pass verification.
174
+ """
175
+ # DeepSeek configurations: (G, M, K, N)
176
+ configs = [
177
+ (4, 8192, 7168, 4096), # Config 1
178
+ (4, 8192, 2048, 7168), # Config 2
179
+ (8, 4096, 7168, 4096), # Config 3
180
+ (8, 4096, 2048, 7168), # Config 4
181
+ ]
182
+
183
+ results = []
184
+
185
+ for config_idx, (G, M, K, N) in enumerate(configs):
186
+ logging.info(f"\n\n===== Testing DeepSeek Config {config_idx+1} =====")
187
+ logging.info(f"G={G}, M={M}, K={K}, N={N}")
188
+
189
+ try:
190
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
191
+
192
+ # Create even group sizes
193
+ base_size = M // G
194
+ remainder = M % G
195
+ M_sizes = [base_size + (1 if i < remainder else 0) for i in range(G)]
196
+ m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)
197
+
198
+ # Create input and weight tensors using float16 for higher precision
199
+ x = torch.randn(
200
+ M, K, dtype=torch.float16, device=device, requires_grad=True
201
+ )
202
+ w = torch.randn(
203
+ N, K, dtype=torch.float16, device=device, requires_grad=True
204
+ )
205
+
206
+ logging.info(f"Input x shape: {x.shape}, Weight w shape: {w.shape}")
207
+
208
+ # Run forward pass
209
+ result = grouped_gemm_forward(x, w, m_sizes)
210
+ logging.info(f"Forward result shape: {result.shape}")
211
+
212
+ # ===== FORWARD PASS VERIFICATION =====
213
+ # Compute reference forward result
214
+ reference_result = compute_reference_forward(x, w, m_sizes)
215
+
216
+ # Compare forward results
217
+ forward_close = analyze_tensor_differences(
218
+ result, reference_result, "Forward output"
219
+ )
220
+
221
+ # ===== BACKWARD PASS VERIFICATION =====
222
+ # Create gradient for backpropagation
223
+ grad_output = torch.randn_like(result)
224
+
225
+ # Run backward pass
226
+ grad_x, grad_w = grouped_gemm_backward(grad_output, x, w, m_sizes)
227
+
228
+ # Compute reference gradients
229
+ x_ref_grad, w_ref_grad = compute_reference_backward(
230
+ x, w, m_sizes, grad_output
231
+ )
232
+
233
+ # Compare backward results
234
+ grad_x_close = analyze_tensor_differences(grad_x, x_ref_grad, "grad_x")
235
+ grad_w_close = analyze_tensor_differences(grad_w, w_ref_grad, "grad_w")
236
+
237
+ # Overall config result
238
+ backward_close = grad_x_close and grad_w_close
239
+ config_success = forward_close and backward_close
240
+ results.append(
241
+ (config_idx + 1, config_success, forward_close, backward_close)
242
+ )
243
+
244
+ # Log overall config result
245
+ if config_success:
246
+ logging.info(f"✓ SUCCESS: Config {config_idx+1} passed all tests!")
247
+ else:
248
+ logging.error(
249
+ f"✗ FAILURE: Config {config_idx+1} failed one or more tests"
250
+ )
251
+
252
+ except Exception as e:
253
+ logging.error(f"Config {config_idx+1} test failed with error: {e}")
254
+ import traceback
255
+
256
+ logging.error(traceback.format_exc())
257
+ results.append((config_idx + 1, False, False, False))
258
+
259
+ # Summary
260
+ logging.info("\n===== Test Results Summary =====")
261
+ for config_idx, overall_success, forward_success, backward_success in results:
262
+ overall_status = "✓ PASSED" if overall_success else "✗ FAILED"
263
+ forward_status = "✓ PASSED" if forward_success else "✗ FAILED"
264
+ backward_status = "✓ PASSED" if backward_success else "✗ FAILED"
265
+
266
+ logging.info(f"Config {config_idx}: {overall_status}")
267
+ logging.info(f" - Forward pass: {forward_status}")
268
+ logging.info(f" - Backward pass: {backward_status}")
269
+
270
+ return all(overall_success for _, overall_success, _, _ in results)
271
+
272
+
273
+ if __name__ == "__main__":
274
+ logging.info(
275
+ "Running verification for both forward and backward pass of M*G grouped GEMM"
276
+ )
277
+
278
+ # Run basic forward pass test
279
+ logging.info("\n===== Running basic forward pass test =====")
280
+ success_forward = test_forward_pass()
281
+ logging.info(f"Basic forward test {'succeeded' if success_forward else 'failed'}")
282
+
283
+ # Run basic backward pass test
284
+ logging.info("\n===== Running basic backward pass test =====")
285
+ success_backward = test_backward_pass()
286
+ logging.info(f"Basic backward test {'succeeded' if success_backward else 'failed'}")
287
+
288
+ # Run multiple DeepSeek configs with forward and backward verification
289
+ logging.info("\n===== Running tests for all DeepSeek configs =====")
290
+ success_configs = test_multiple_deepseek_configs()
291
+ logging.info(
292
+ f"DeepSeek configs tests {'all succeeded' if success_configs else 'had failures'}"
293
+ )
294
+
295
+ # Overall result
296
+ overall_success = success_forward and success_backward and success_configs
297
+ logging.info(
298
+ f"\nOverall test result: {'SUCCESS' if overall_success else 'FAILURE'}"
299
+ )
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/tma_autotuning.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # credit - TMAHelper class, AutoTuning are derived from FBGemm:
8
+ # https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm
9
+
10
+ # pyre-unsafe
11
+ import functools
12
+
13
+ import os
14
+ import sys
15
+ from typing import Any, Dict, Optional, Tuple
16
+
17
+ import torch
18
+
19
+ import triton
20
+ import triton.language as tl
21
+ from triton import Config as TConfig
22
+
23
+ from triton.runtime import driver # @manual
24
+
25
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
26
+
27
+
28
+ # ===== Supporting utils, CUDA and TMA =====
29
+
30
+
31
+ class CudaUtils:
32
+ @staticmethod
33
+ def is_cuda() -> bool:
34
+ """Check if Triton is running on CUDA backend."""
35
+ return driver.active.get_current_target().backend == "cuda"
36
+
37
+ @staticmethod
38
+ def verify_tma() -> bool:
39
+ """Check if TMA is supported on the current device."""
40
+ return (
41
+ CudaUtils.is_cuda()
42
+ and torch.cuda.is_available()
43
+ and torch.cuda.get_device_capability()[0] >= 9
44
+ )
45
+
46
+ @staticmethod
47
+ def get_num_sms() -> int:
48
+ """Get the number of streaming multiprocessors on the current device."""
49
+ if not CudaUtils.is_cuda():
50
+ raise RuntimeError("Triton is not running on CUDA backend")
51
+ if not torch.cuda.is_available():
52
+ raise RuntimeError("CUDA is not available")
53
+ return torch.cuda.get_device_properties("cuda").multi_processor_count
54
+
55
+
56
+ class TmaDescriptorHelper:
57
+ """Helper class for managing TMA descriptors in Triton kernels."""
58
+
59
+ class KernelParamWrapper:
60
+ """Wrapper to implement the TmaDescKernelParam interface."""
61
+
62
+ def __init__(self, desc: torch.Tensor):
63
+ self.desc = desc
64
+
65
+ def tma_desc_cpu_ptr(self) -> int:
66
+ """Return the CPU pointer to the TMA descriptor."""
67
+ return self.desc.data_ptr()
68
+
69
+ def __init__(self, tma_size: int = 128):
70
+ """Initialize the TMA descriptor helper.
71
+
72
+ Args:
73
+ tma_size: Size of the TMA descriptor in bytes
74
+ """
75
+ if not CudaUtils.verify_tma():
76
+ raise RuntimeError(
77
+ "TMA not supported on this device (requires Hopper or newer)"
78
+ )
79
+ if "nv_tma_desc_type" not in dir(tl):
80
+ raise RuntimeError(
81
+ "TMA grid constant descriptors not supported in your Triton version"
82
+ )
83
+
84
+ self.tma_size = tma_size
85
+ self.fill_1d_tma_descriptor_inner = driver.active.utils.fill_1d_tma_descriptor
86
+ self.fill_2d_tma_descriptor_inner = driver.active.utils.fill_2d_tma_descriptor
87
+ self.descriptors: Dict[str, torch.Tensor] = {}
88
+
89
+ def init_tma_descriptor(self, name: str) -> None:
90
+ """Initialize a TMA descriptor with the given name.
91
+
92
+ Call this method outside of the lambda function for grid size.
93
+ """
94
+ self.descriptors[name] = torch.empty(
95
+ self.tma_size, device="cpu", dtype=torch.int8
96
+ )
97
+
98
+ def fill_1d_tma_descriptor(
99
+ self, name: str, ptr: int, dim: int, block_dim: int, element_size: int
100
+ ) -> None:
101
+ """Fill a 1D TMA descriptor.
102
+
103
+ Call this method inside the lambda function for grid size.
104
+ """
105
+ if name not in self.descriptors:
106
+ raise ValueError(f"TMA descriptor '{name}' not initialized")
107
+
108
+ desc_x = self.descriptors[name]
109
+ if desc_x.data_ptr() % 64 != 0:
110
+ raise ValueError("TMA descriptor must be 64-byte aligned")
111
+ self.fill_1d_tma_descriptor_inner(
112
+ ptr, dim, block_dim, element_size, desc_x.data_ptr()
113
+ )
114
+
115
+ def fill_2d_tma_descriptor(
116
+ self,
117
+ name: str,
118
+ ptr: int,
119
+ dim1: int,
120
+ dim0: int,
121
+ block_dim1: int,
122
+ block_dim0: int,
123
+ element_size: int,
124
+ ) -> None:
125
+ """Fill a 2D TMA descriptor.
126
+
127
+ Call this method inside the lambda function for grid size.
128
+ """
129
+ if name not in self.descriptors:
130
+ raise ValueError(f"TMA descriptor '{name}' not initialized")
131
+
132
+ desc_x = self.descriptors[name]
133
+ if desc_x.data_ptr() % 64 != 0:
134
+ raise ValueError("TMA descriptor must be 64-byte aligned")
135
+ self.fill_2d_tma_descriptor_inner(
136
+ ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr()
137
+ )
138
+
139
+ def get_tma_descriptor_kernel_param(self, name: str) -> KernelParamWrapper:
140
+ """Get the TMA descriptor kernel parameter for the given name."""
141
+ if name not in self.descriptors or self.descriptors[name] is None:
142
+ raise ValueError(f"TMA descriptor '{name}' not initialized")
143
+ return self.KernelParamWrapper(self.descriptors[name])
144
+
145
+
146
+ # ====== Autotuning utilities ======
147
+ ALIGN_SIZE_M = 128
148
+
149
+ _NV_CONFIGS = [
150
+ triton.Config(
151
+ {
152
+ "BLOCK_SIZE_M": block_size_m,
153
+ "BLOCK_SIZE_N": block_size_n,
154
+ "BLOCK_SIZE_K": block_size_k,
155
+ },
156
+ num_stages=num_stages,
157
+ num_warps=num_warps,
158
+ num_ctas=num_ctas,
159
+ )
160
+ for block_size_m in [ALIGN_SIZE_M, ]
161
+ for block_size_n in [64, 128, 256]
162
+ for block_size_k in [64, 128, 256]
163
+ for num_stages in [3, 4]
164
+ for num_warps in [4, 8]
165
+ for num_ctas in [1]
166
+ ]
167
+
168
+
169
+ def early_config_prune(configs, named_args, dtsize=None, dtype=None, **kwargs):
170
+ device = torch.cuda.current_device()
171
+ # Check for all possible pointer parameter names
172
+ if "grad_input_ptr" in named_args:
173
+ ptr_name = "grad_input_ptr"
174
+ elif "c_ptr" in named_args:
175
+ ptr_name = "c_ptr"
176
+ elif "grad_weight_ptr" in named_args:
177
+ ptr_name = "grad_weight_ptr"
178
+ else:
179
+ raise KeyError("No recognized pointer parameter found in kernel arguments")
180
+
181
+ if dtsize is None:
182
+ dtsize = named_args[ptr_name].element_size()
183
+ if dtype is None:
184
+ dtype = named_args[ptr_name].dtype
185
+
186
+ pruned_configs = []
187
+ for config in configs:
188
+ kw = config.kwargs
189
+ BLOCK_M, BLOCK_N, BLOCK_K, num_stages = (
190
+ kw["BLOCK_SIZE_M"],
191
+ kw["BLOCK_SIZE_N"],
192
+ kw["BLOCK_SIZE_K"],
193
+ config.num_stages,
194
+ )
195
+ G, M, N, K = (
196
+ named_args["G"],
197
+ named_args["M_BUCKET"],
198
+ named_args["N"],
199
+ named_args["K"],
200
+ )
201
+
202
+ # 1. make sure we have enough smem
203
+ max_shared_memory = driver.active.utils.get_device_properties(device)[
204
+ "max_shared_mem"
205
+ ]
206
+
207
+ required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
208
+ if required_shared_memory > max_shared_memory:
209
+ continue
210
+
211
+ M_PER_GROUP = M // G
212
+ MIN_M_TILES = 64
213
+ # 2. make sure we don't load M tiles that are too big
214
+ if BLOCK_M > MIN_M_TILES and BLOCK_M > (M_PER_GROUP * 2):
215
+ continue
216
+ # 3. make sure we don't load N tiles that are too small
217
+ if BLOCK_M < 128 and BLOCK_M < (M_PER_GROUP // 2):
218
+ continue
219
+
220
+ num_sm = driver.active.utils.get_device_properties(device)[
221
+ "multiprocessor_count"
222
+ ]
223
+ N_TILES = N // BLOCK_N
224
+ MIN_N_TILES = 64
225
+ # 4. make sure we don't load N tiles that are too big
226
+ if BLOCK_N > MIN_N_TILES and M * N_TILES < num_sm:
227
+ continue
228
+ # 5. make sure we don't load N tiles that are too small
229
+ if BLOCK_N < 128 and M * N_TILES > 2 * num_sm:
230
+ continue
231
+ # 6. make sure K can be evenly divided
232
+ if K % BLOCK_K != 0:
233
+ continue
234
+
235
+ pruned_configs.append(config)
236
+
237
+ return pruned_configs
238
+
239
+
240
+ # ======== End Autotuning utilities ========
torchtitan/experiments/llama4/README.md ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ **The Llama 4 folder is still under development.**
2
+
3
+ #### Available features
4
+ - Llama 4 model definition (text-only), including the MoE architecture with token-choice routing
5
+ - Basic FSDP, TP, PP, CP support
6
+ - DCP checkpoint conversion scripts
7
+
8
+ #### Download Llama 4 tokenizer
9
+ ```bash
10
+ # Llama 4 tokenizer.model
11
+ python scripts/download_tokenizer.py --repo_id meta-llama/Llama-4-Scout-17B-16E --tokenizer_path "" --hf_token=...
12
+ ```
13
+
14
+ #### To be added
15
+ - Modeling
16
+ - iRoPE implementation
17
+ - load balance loss for token-choice MoE
18
+ - alternative expert-choice MoE
19
+ - multimodal support
20
+ - Kernel integration
21
+ - efficient bfloat16 GroupedGEMM kernels (from PyTorch core)
22
+ - efficient float8 GroupedGEMM kernels (from torchao)
23
+ - Parallelism
24
+ - performant TP implementation and torch.compile support for MoE layers
25
+ - Context Parallel support for FlexAttention, iRoPE, and multimodal inputs
26
+ - Expert Parallel support
27
+ - Testing
28
+ - perfomance and loss converging tests
29
+ - CI integration
torchtitan/experiments/llama4/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (1.66 kB). View file
 
torchtitan/experiments/llama4/infra/__pycache__/parallelize_llama.cpython-312.pyc ADDED
Binary file (5.55 kB). View file
 
torchtitan/experiments/llama4/infra/expert_parallel.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ from functools import partial
9
+ from typing import Optional, Tuple
10
+
11
+ import torch.nn as nn
12
+ from torch.distributed.tensor import (
13
+ DeviceMesh,
14
+ distribute_module,
15
+ distribute_tensor,
16
+ DTensor,
17
+ Partial,
18
+ Replicate,
19
+ Shard,
20
+ )
21
+ from torch.distributed.tensor.parallel import ParallelStyle
22
+ from torch.distributed.tensor.placement_types import Placement
23
+
24
+
25
+ # implementation of Tensor Parallel on the non-shared experts in MoE
26
+ class TensorParallel(ParallelStyle):
27
+ def __init__(
28
+ self,
29
+ *,
30
+ input_layouts: Optional[Tuple[Optional[Placement]]] = None,
31
+ output_layout: Optional[Placement] = None,
32
+ use_local_output: bool = True,
33
+ ):
34
+ super().__init__()
35
+ self.input_layouts = input_layouts or (Replicate(), None)
36
+ self.output_layout = output_layout or Partial()
37
+ self.desired_input_layouts = (Replicate(), None)
38
+ self.use_local_output = use_local_output
39
+
40
+ @staticmethod
41
+ def _prepare_input_fn(
42
+ input_layouts, desired_input_layouts, mod, inputs, device_mesh
43
+ ):
44
+ # TODO: figure out dynamo support for instance method and switch this to instance method
45
+
46
+ # annotate module input placements/sharding with input_layouts
47
+ input_tensor, input_layout, desired_input_layout = (
48
+ inputs[0],
49
+ input_layouts[0],
50
+ desired_input_layouts[0],
51
+ )
52
+ if not isinstance(input_tensor, DTensor):
53
+ input_tensor = DTensor.from_local(
54
+ input_tensor, device_mesh, (input_layout,), run_check=False
55
+ )
56
+
57
+ if input_layouts != desired_input_layouts:
58
+ input_tensor = input_tensor.redistribute(
59
+ placements=(desired_input_layout,), async_op=True
60
+ )
61
+ return (input_tensor, *inputs[1:])
62
+
63
+ def _partition_fn(self, name, module, device_mesh):
64
+ module.register_parameter(
65
+ "w1", nn.Parameter(distribute_tensor(module.w1, device_mesh, [Shard(2)]))
66
+ ) # Column-wise sharding
67
+ module.register_parameter(
68
+ "w2",
69
+ nn.Parameter(distribute_tensor(module.w2, device_mesh, [Shard(1)])),
70
+ ) # Row-wise sharding
71
+ module.register_parameter(
72
+ "w3",
73
+ nn.Parameter(distribute_tensor(module.w3, device_mesh, [Shard(2)])),
74
+ ) # Column-wise sharding
75
+
76
+ @staticmethod
77
+ def _prepare_output_fn(output_layout, use_local_output, mod, outputs, device_mesh):
78
+ if outputs.placements != (output_layout,):
79
+ outputs = outputs.redistribute(placements=(output_layout,), async_op=True)
80
+ # back to local tensor
81
+ return outputs.to_local() if use_local_output else outputs
82
+
83
+ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
84
+ return distribute_module(
85
+ module,
86
+ device_mesh,
87
+ self._partition_fn,
88
+ partial(
89
+ self._prepare_input_fn, self.input_layouts, self.desired_input_layouts
90
+ ),
91
+ partial(self._prepare_output_fn, self.output_layout, self.use_local_output),
92
+ )
93
+
94
+
95
+ # NOTE: This is to achieve replicate computation on the gate module in the MoE router.
96
+ # It does nothing other than (1) setting the module parameters as DTensors on the given mesh
97
+ # and (2) inserting hooks to module boundary to change torch.Tensor to DTensor and back.
98
+ # TODO: The reason we need this wrapping is to ensure all parameters are on the same 1D/2D mesh,
99
+ # which is assumed by (1) gradient norm clipping, and (2) optimizer fused implementation.
100
+ class NoParallel(ParallelStyle):
101
+ def __init__(
102
+ self,
103
+ *,
104
+ input_layout: Optional[Placement] = None,
105
+ output_layout: Optional[Placement] = None,
106
+ use_local_output: bool = True,
107
+ ):
108
+ super().__init__()
109
+ self.input_layout = input_layout or Replicate()
110
+ self.output_layout = output_layout or Replicate()
111
+ self.desired_input_layout = Replicate()
112
+ self.use_local_output = use_local_output
113
+
114
+ @staticmethod
115
+ def _prepare_input_fn(input_layout, desired_input_layout, mod, inputs, device_mesh):
116
+ # annotate module input placements/sharding with input_layouts
117
+ input_tensor = inputs[0]
118
+ if not isinstance(input_tensor, DTensor):
119
+ input_tensor = DTensor.from_local(
120
+ input_tensor, device_mesh, (input_layout,), run_check=False
121
+ )
122
+
123
+ if input_layout != desired_input_layout:
124
+ input_tensor = input_tensor.redistribute(
125
+ placements=(desired_input_layout,), async_op=True
126
+ )
127
+ return (input_tensor, *inputs[1:])
128
+
129
+ @staticmethod
130
+ def _prepare_output_fn(output_layout, use_local_output, mod, outputs, device_mesh):
131
+ if outputs.placements != (output_layout,):
132
+ outputs = outputs.redistribute(placements=(output_layout,), async_op=True)
133
+ # back to local tensor
134
+ return outputs.to_local() if use_local_output else outputs
135
+
136
+ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
137
+ return distribute_module(
138
+ module,
139
+ device_mesh,
140
+ None,
141
+ partial(
142
+ self._prepare_input_fn, self.input_layout, self.desired_input_layout
143
+ ),
144
+ partial(self._prepare_output_fn, self.output_layout, self.use_local_output),
145
+ )
torchtitan/experiments/llama4/infra/parallelize_llama.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.distributed.device_mesh import DeviceMesh
11
+
12
+ from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
13
+ from torchtitan.distributed import ParallelDims
14
+
15
+ from torchtitan.models.llama3.parallelize_llama import (
16
+ apply_ac,
17
+ apply_compile,
18
+ apply_ddp,
19
+ apply_fsdp,
20
+ apply_tp,
21
+ )
22
+ from torchtitan.tools.logging import logger
23
+
24
+
25
+ def parallelize_llama(
26
+ model: nn.Module,
27
+ world_mesh: DeviceMesh,
28
+ parallel_dims: ParallelDims,
29
+ job_config: JobConfig,
30
+ ):
31
+ """
32
+ Apply tensor parallelism, activation checkpointing, torch.compile, and data
33
+ parallelism to the model.
34
+
35
+ NOTE: The passed-in model preferably should be on meta device. Otherwise,
36
+ the model must fit on GPU or CPU memory.
37
+ """
38
+
39
+ if parallel_dims.tp_enabled:
40
+ if (
41
+ job_config.parallelism.enable_async_tensor_parallel
42
+ and not job_config.training.compile
43
+ ):
44
+ raise RuntimeError("Async TP requires --training.compile")
45
+
46
+ enable_float8_linear = "float8" in job_config.model.converters
47
+ float8_is_rowwise = job_config.float8.recipe_name in (
48
+ "rowwise",
49
+ "rowwise_with_gw_hp",
50
+ )
51
+
52
+ # For now, float8 all-gather with TP is only supported for tensorwise
53
+ # float8 scaling recipes. For rowwise recipes, we use regular TP and
54
+ # all-gather happens in high precision.
55
+ enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise
56
+
57
+ apply_tp(
58
+ model,
59
+ world_mesh["tp"],
60
+ loss_parallel=parallel_dims.loss_parallel_enabled,
61
+ enable_float8_tensorwise_tp=enable_float8_tensorwise_tp,
62
+ enable_async_tp=job_config.parallelism.enable_async_tensor_parallel,
63
+ )
64
+
65
+ apply_moe_tp(model, world_mesh["tp"])
66
+
67
+ if job_config.activation_checkpoint.mode != "none":
68
+ if (
69
+ job_config.activation_checkpoint.mode == "selective"
70
+ and job_config.model.use_flex_attn
71
+ ):
72
+ raise ValueError(
73
+ "FlexAttention is not compatible with selective AC yet. "
74
+ "See https://github.com/pytorch/pytorch/issues/147879"
75
+ )
76
+ apply_ac(model, job_config.activation_checkpoint)
77
+
78
+ # turn on per-TransformerBlock compile after AC wrapping and before FSDP
79
+ if job_config.training.compile:
80
+ apply_compile(model)
81
+
82
+ # NOTE: needed for torch.compile to work with dynamic shapes in token-choice MoE
83
+ torch._dynamo.config.capture_scalar_outputs = True
84
+
85
+ if (
86
+ parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled
87
+ ): # apply FSDP or HSDP, potentially with Context Parallel
88
+ if parallel_dims.dp_replicate_enabled:
89
+ dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
90
+ else:
91
+ dp_mesh_dim_names = ("dp_shard_cp",)
92
+
93
+ apply_fsdp(
94
+ model,
95
+ world_mesh[tuple(dp_mesh_dim_names)],
96
+ param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
97
+ reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
98
+ pp_enabled=parallel_dims.pp_enabled,
99
+ cpu_offload=job_config.training.enable_cpu_offload,
100
+ reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward,
101
+ )
102
+
103
+ if parallel_dims.dp_replicate_enabled:
104
+ logger.info("Applied HSDP to the model")
105
+ else:
106
+ logger.info("Applied FSDP to the model")
107
+
108
+ if parallel_dims.cp_enabled:
109
+ logger.info("Applied Context Parallel to the model")
110
+
111
+ if job_config.training.enable_cpu_offload:
112
+ logger.info("Applied CPU Offloading to the model")
113
+ elif parallel_dims.dp_replicate_enabled:
114
+ if world_mesh.ndim > 1:
115
+ raise RuntimeError("DDP has not supported > 1D parallelism")
116
+ apply_ddp(
117
+ model,
118
+ world_mesh,
119
+ enable_compile=job_config.training.compile,
120
+ enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
121
+ )
122
+
123
+ return model
124
+
125
+
126
+ def apply_moe_tp(
127
+ model: nn.Module,
128
+ tp_mesh: DeviceMesh,
129
+ ):
130
+ from torch.distributed.tensor import Partial, Replicate, Shard
131
+ from torch.distributed.tensor.parallel import (
132
+ parallelize_module,
133
+ PrepareModuleInputOutput,
134
+ )
135
+
136
+ from .expert_parallel import NoParallel, TensorParallel
137
+
138
+ for _, transformer_block in model.layers.items():
139
+ moe_layer_plan = {
140
+ # input / output sharding on the seqlen dim
141
+ # all-gather for input, reduce-scatter for output
142
+ "moe": PrepareModuleInputOutput(
143
+ input_layouts=(Shard(1),),
144
+ desired_input_layouts=(Replicate(),),
145
+ use_local_input=True,
146
+ output_layouts=(Partial(),),
147
+ desired_output_layouts=(Shard(1),),
148
+ ),
149
+ # replicate computation for the router
150
+ "moe.router.gate": NoParallel(),
151
+ # input Replicate, output Partial
152
+ "moe.experts": TensorParallel(),
153
+ "moe.shared_expert": TensorParallel(),
154
+ }
155
+ parallelize_module(
156
+ module=transformer_block,
157
+ device_mesh=tp_mesh,
158
+ parallelize_plan=moe_layer_plan,
159
+ )
torchtitan/experiments/llama4/model/__pycache__/args.cpython-312.pyc ADDED
Binary file (4.1 kB). View file
 
torchtitan/experiments/llama4/model/__pycache__/model.cpython-312.pyc ADDED
Binary file (23.2 kB). View file
 
torchtitan/experiments/llama4/model/__pycache__/moe.cpython-312.pyc ADDED
Binary file (10.5 kB). View file